Skip to content

Commit

Permalink
[stubgen] Preserve dataclass_transform decorator (#18418)
Browse files Browse the repository at this point in the history
Ref: #18081
  • Loading branch information
cdce8p authored Jan 7, 2025
1 parent 306c1af commit 20355d5
Show file tree
Hide file tree
Showing 2 changed files with 148 additions and 11 deletions.
29 changes: 26 additions & 3 deletions mypy/stubgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@
Var,
)
from mypy.options import Options as MypyOptions
from mypy.semanal_shared import find_dataclass_transform_spec
from mypy.sharedparse import MAGIC_METHODS_POS_ARGS_ONLY
from mypy.stubdoc import ArgSig, FunctionSig
from mypy.stubgenc import InspectionStubGenerator, generate_stub_for_c_module
Expand All @@ -139,6 +140,7 @@
has_yield_from_expression,
)
from mypy.types import (
DATACLASS_TRANSFORM_NAMES,
OVERLOAD_NAMES,
TPDICT_NAMES,
TYPED_NAMEDTUPLE_NAMES,
Expand Down Expand Up @@ -701,10 +703,13 @@ def process_decorator(self, o: Decorator) -> None:
"""
o.func.is_overload = False
for decorator in o.original_decorators:
if not isinstance(decorator, (NameExpr, MemberExpr)):
d = decorator
if isinstance(d, CallExpr):
d = d.callee
if not isinstance(d, (NameExpr, MemberExpr)):
continue
qualname = get_qualified_name(decorator)
fullname = self.get_fullname(decorator)
qualname = get_qualified_name(d)
fullname = self.get_fullname(d)
if fullname in (
"builtins.property",
"builtins.staticmethod",
Expand Down Expand Up @@ -739,6 +744,9 @@ def process_decorator(self, o: Decorator) -> None:
o.func.is_overload = True
elif qualname.endswith((".setter", ".deleter")):
self.add_decorator(qualname, require_name=False)
elif fullname in DATACLASS_TRANSFORM_NAMES:
p = AliasPrinter(self)
self._decorators.append(f"@{decorator.accept(p)}")

def get_fullname(self, expr: Expression) -> str:
"""Return the expression's full name."""
Expand Down Expand Up @@ -785,6 +793,8 @@ def visit_class_def(self, o: ClassDef) -> None:
self.add(f"{self._indent}{docstring}\n")
n = len(self._output)
self._vars.append([])
if self.analyzed and find_dataclass_transform_spec(o):
self.processing_dataclass = True
super().visit_class_def(o)
self.dedent()
self._vars.pop()
Expand Down Expand Up @@ -854,13 +864,26 @@ def get_class_decorators(self, cdef: ClassDef) -> list[str]:
decorators.append(d.accept(p))
self.import_tracker.require_name(get_qualified_name(d))
self.processing_dataclass = True
if self.is_dataclass_transform(d):
decorators.append(d.accept(p))
self.import_tracker.require_name(get_qualified_name(d))
return decorators

def is_dataclass(self, expr: Expression) -> bool:
if isinstance(expr, CallExpr):
expr = expr.callee
return self.get_fullname(expr) == "dataclasses.dataclass"

def is_dataclass_transform(self, expr: Expression) -> bool:
if isinstance(expr, CallExpr):
expr = expr.callee
if self.get_fullname(expr) in DATACLASS_TRANSFORM_NAMES:
return True
if find_dataclass_transform_spec(expr) is not None:
self.processing_dataclass = True
return True
return False

def visit_block(self, o: Block) -> None:
# Unreachable statements may be partially uninitialized and that may
# cause trouble.
Expand Down
130 changes: 122 additions & 8 deletions test-data/unit/stubgen.test
Original file line number Diff line number Diff line change
Expand Up @@ -3104,15 +3104,12 @@ class C:
x = attrs.field()

[out]
from _typeshed import Incomplete
import attrs

@attrs.define
class C:
x: Incomplete
x = ...
def __init__(self, x) -> None: ...
def __lt__(self, other): ...
def __le__(self, other): ...
def __gt__(self, other): ...
def __ge__(self, other): ...

[case testNamedTupleInClass]
from collections import namedtuple
Expand Down Expand Up @@ -4249,6 +4246,122 @@ class Y(missing.Base):
generated_kwargs_: float
def __init__(self, *generated_args__, generated_args, generated_args_, generated_kwargs, generated_kwargs_, **generated_kwargs__) -> None: ...

[case testDataclassTransform]
# dataclass_transform detection only works with sementic analysis.
# Test stubgen doesn't break too badly without it.
from typing_extensions import dataclass_transform

@typing_extensions.dataclass_transform(kw_only_default=True)
def create_model(cls):
return cls

@create_model
class X:
a: int
b: str = "hello"

@typing_extensions.dataclass_transform(kw_only_default=True)
class ModelBase: ...

class Y(ModelBase):
a: int
b: str = "hello"

@typing_extensions.dataclass_transform(kw_only_default=True)
class DCMeta(type): ...

class Z(metaclass=DCMeta):
a: int
b: str = "hello"

[out]
@typing_extensions.dataclass_transform(kw_only_default=True)
def create_model(cls): ...

class X:
a: int
b: str

@typing_extensions.dataclass_transform(kw_only_default=True)
class ModelBase: ...

class Y(ModelBase):
a: int
b: str

@typing_extensions.dataclass_transform(kw_only_default=True)
class DCMeta(type): ...

class Z(metaclass=DCMeta):
a: int
b: str

[case testDataclassTransformDecorator_semanal]
import typing_extensions

@typing_extensions.dataclass_transform(kw_only_default=True)
def create_model(cls):
return cls

@create_model
class X:
a: int
b: str = "hello"

[out]
import typing_extensions

@typing_extensions.dataclass_transform(kw_only_default=True)
def create_model(cls): ...

@create_model
class X:
a: int
b: str = ...
def __init__(self, *, a, b=...) -> None: ...

[case testDataclassTransformClass_semanal]
from typing_extensions import dataclass_transform

@dataclass_transform(kw_only_default=True)
class ModelBase: ...

class X(ModelBase):
a: int
b: str = "hello"

[out]
from typing_extensions import dataclass_transform

@dataclass_transform(kw_only_default=True)
class ModelBase: ...

class X(ModelBase):
a: int
b: str = ...
def __init__(self, *, a, b=...) -> None: ...

[case testDataclassTransformMetaclass_semanal]
from typing_extensions import dataclass_transform

@dataclass_transform(kw_only_default=True)
class DCMeta(type): ...

class X(metaclass=DCMeta):
a: int
b: str = "hello"

[out]
from typing_extensions import dataclass_transform

@dataclass_transform(kw_only_default=True)
class DCMeta(type): ...

class X(metaclass=DCMeta):
a: int
b: str = ...
def __init__(self, *, a, b=...) -> None: ...

[case testAlwaysUsePEP604Union]
import typing
import typing as t
Expand Down Expand Up @@ -4536,16 +4649,17 @@ def f5[T5 = int]() -> None: ...
# flags: --include-private --python-version=3.13
from typing_extensions import dataclass_transform

# TODO: preserve dataclass_transform decorator
@dataclass_transform()
class DCMeta(type): ...
class DC(metaclass=DCMeta):
x: str

[out]
from typing_extensions import dataclass_transform

@dataclass_transform()
class DCMeta(type): ...

class DC(metaclass=DCMeta):
x: str
def __init__(self, x) -> None: ...
def __replace__(self, *, x) -> None: ...

0 comments on commit 20355d5

Please sign in to comment.