Skip to content

Commit 6da936b

Browse files
authored
feat(mypy_plugin): item extracted results analyzed as typeddict (#87)
* feat(mypy_plugin): item extracted results analyzed as typeddict
* fix: compat issue
1 parent 83dcb84 commit 6da936b

File tree

3 files changed

+182
-18
lines changed

3 files changed

+182
-18
lines changed

data_extractor/contrib/mypy/__init__.py

Lines changed: 133 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
# Standard Library
2-
from typing import Callable, Dict, Optional, Type
2+
import logging
3+
4+
from functools import partial
5+
from typing import Callable, Dict, List, Optional, Type
36

47
# Third Party Library
58
from mypy.checker import TypeChecker
@@ -15,11 +18,21 @@
1518
TypeInfo,
1619
)
1720
from mypy.nodes import Var as VarExpr
18-
from mypy.plugin import FunctionContext, MethodSigContext, Plugin
21+
from mypy.options import Options
22+
from mypy.plugin import (
23+
DynamicClassDefContext,
24+
FunctionContext,
25+
MethodSigContext,
26+
Plugin,
27+
)
28+
from mypy.semanal import SemanticAnalyzerInterface
29+
from mypy.semanal_typeddict import TypedDictAnalyzer
1930
from mypy.traverser import TraverserVisitor
2031
from mypy.types import AnyType, CallableType, Instance
2132
from mypy.types import Type as MypyType
22-
from mypy.types import TypeOfAny, UninhabitedType, UnionType
33+
from mypy.types import TypedDictType, TypeOfAny, UninhabitedType, UnionType
34+
35+
logger = logging.getLogger(__name__)
2336

2437

2538
class RelationshipVisitor(TraverserVisitor):
@@ -81,7 +94,13 @@ def visit_assignment_stmt(self, o):
8194

8295

8396
class DataExtractorPlugin(Plugin):
84-
cache: Dict[str, Dict[str, str]] = {}
97+
cache: Dict[str, Dict[str, str]]
98+
item_typeddict_mapping: Dict[str, TypedDictType]
99+
100+
def __init__(self, options: Options) -> None:
101+
super().__init__(options)
102+
self.cache = {}
103+
self.item_typeddict_mapping = {}
85104

86105
def get_current_code(self, ctx: FunctionContext) -> MypyFile:
87106
api = ctx.api
@@ -112,12 +131,6 @@ def check_field_generic_type(self, ctx: FunctionContext) -> MypyType:
112131
if rv_type.args and not isinstance(rv_type.args[0], UninhabitedType):
113132
return rv_type
114133

115-
# # check parameter "type"
116-
# idx = ctx.callee_arg_names.index("type")
117-
# arg = ctx.args[type_idx]
118-
# if arg:
119-
# args = [arg[0].node]
120-
# else:
121134
if not self.options.disallow_any_generics:
122135
return self.apply_any_generic(type=rv_type)
123136
else:
@@ -139,11 +152,12 @@ def check_is_many(self, ctx: FunctionContext) -> bool:
139152

140153
return False
141154

142-
def prepare_type_annotations(self, ctx: FunctionContext) -> MypyType:
155+
def prepare_type_annotations(self, ctx: FunctionContext, fullname: str) -> MypyType:
156+
logger.debug("prepare_type_annotations %r", fullname)
157+
143158
# check parameter "is_many"
144159
expr = ctx.context
145160
assert isinstance(expr, CallExpr)
146-
147161
relationship = self.anal_code(self.get_current_code(ctx))
148162
lvalue_key = str((expr.line, expr.column))
149163
if lvalue_key in relationship:
@@ -171,24 +185,30 @@ def prepare_type_annotations(self, ctx: FunctionContext) -> MypyType:
171185
rv_type = self.check_field_generic_type(ctx)
172186
return rv_type
173187

174-
def is_extractor_cls(self, fullname: str) -> bool:
188+
def is_extractor_cls(self, fullname: str, is_item_subcls=False) -> bool:
175189
node = self.lookup_fully_qualified(fullname)
176190
if node is not None:
177191
typenode = node.node
178192
if isinstance(typenode, TypeInfo):
179-
return typenode.has_base("data_extractor.item.Field")
193+
if is_item_subcls:
194+
return typenode.has_base("data_extractor.item.Item")
195+
else:
196+
return typenode.has_base("data_extractor.item.Field")
180197

181198
return False
182199

183200
def get_function_hook(
184201
self, fullname: str
185202
) -> Optional[Callable[[FunctionContext], MypyType]]:
203+
logger.debug("get_function_hook %r", fullname)
186204
if self.is_extractor_cls(fullname):
187-
return self.prepare_type_annotations
205+
return partial(self.prepare_type_annotations, fullname=fullname)
188206

189207
return super().get_function_hook(fullname)
190208

191-
def apply_is_many_on_extract_method(self, ctx: MethodSigContext) -> CallableType:
209+
def apply_is_many_on_extract_method(
210+
self, ctx: MethodSigContext, fullname: str
211+
) -> CallableType:
192212
origin: CallableType = ctx.default_signature
193213
origin_ret_type = origin.ret_type
194214
assert isinstance(origin_ret_type, UnionType)
@@ -211,6 +231,7 @@ def apply_is_many_on_extract_method(self, ctx: MethodSigContext) -> CallableType
211231
key = str((obj.line, obj.column))
212232
# metadata = obj.type.type.metadata
213233

234+
logger.debug("apply_is_many %r %r %r", fullname, key, metadata)
214235
if key in metadata:
215236
is_many = metadata[key]["is_many"]
216237
ret_type = origin_ret_type.items[int(is_many)]
@@ -227,14 +248,108 @@ def is_extract_method(self, fullname: str) -> bool:
227248
return self.is_extractor_cls(fullname[: -len(suffix)])
228249
return False
229250

251+
def apply_extract_method(
252+
self, ctx: MethodSigContext, fullname: str
253+
) -> CallableType:
254+
rv = self.apply_is_many_on_extract_method(ctx, fullname)
255+
256+
# apply item typeddict
257+
item_classname = fullname[: -len(".extract")]
258+
if item_classname in self.item_typeddict_mapping:
259+
logger.debug("apply_extract_method %r %r", fullname, rv.ret_type)
260+
original = rv.ret_type
261+
typeddict = self.item_typeddict_mapping[item_classname]
262+
ret_type: Optional[MypyType]
263+
if isinstance(original, AnyType):
264+
ret_type = typeddict
265+
else:
266+
assert isinstance(original, Instance)
267+
if original.type.name == "list":
268+
ret_type = original
269+
ret_type.args = (typeddict,)
270+
else:
271+
api = ctx.api
272+
assert isinstance(api, TypeChecker)
273+
api.fail(
274+
"Cant determine extract method return type", context=ctx.context
275+
)
276+
ret_type = None
277+
278+
if ret_type is not None:
279+
rv = rv.copy_modified(ret_type=ret_type)
280+
281+
logger.debug(
282+
"apply_extract_method %r %r %r", fullname, rv, self.item_typeddict_mapping
283+
)
284+
return rv
285+
230286
def get_method_signature_hook(
231287
self, fullname: str
232288
) -> Optional[Callable[[MethodSigContext], CallableType]]:
233289
if self.is_extract_method(fullname):
234-
return self.apply_is_many_on_extract_method
235-
290+
return partial(self.apply_extract_method, fullname=fullname)
236291
return super().get_method_signature_hook(fullname)
237292

293+
def prepare_typeddict(self, ctx: DynamicClassDefContext, fullname: str) -> None:
294+
logger.debug("prepare_typeddict %r", fullname)
295+
if fullname in self.item_typeddict_mapping:
296+
return
297+
298+
api = ctx.api
299+
assert isinstance(api, SemanticAnalyzerInterface)
300+
analyzer = TypedDictAnalyzer(api.options, api, api.msg) # type: ignore
301+
302+
items: List[str] = []
303+
types: List[MypyType] = []
304+
callee = ctx.call.callee
305+
assert isinstance(callee, NameExpr)
306+
node = callee.node
307+
assert isinstance(node, TypeInfo)
308+
for block in node.defn.defs.body:
309+
if not isinstance(block, AssignmentStmt):
310+
continue
311+
312+
rvalue = block.rvalue
313+
if not isinstance(rvalue, CallExpr):
314+
continue
315+
316+
rvalue_type: MypyType
317+
callee = rvalue.callee
318+
if isinstance(callee, IndexExpr):
319+
index = callee.index
320+
assert isinstance(index, NameExpr)
321+
name = index.fullname
322+
assert name is not None
323+
named_type = api.named_type_or_none(name, [])
324+
assert named_type is not None
325+
rvalue_type = named_type
326+
else:
327+
rvalue_type = AnyType(TypeOfAny.special_form)
328+
329+
for lvalue in block.lvalues:
330+
assert isinstance(lvalue, NameExpr)
331+
items.append(lvalue.name)
332+
types.append(rvalue_type)
333+
334+
callee = ctx.call.callee
335+
assert isinstance(callee, NameExpr)
336+
typeinfo = analyzer.build_typeddict_typeinfo(
337+
callee.name, items, types, set(items), -1
338+
)
339+
assert typeinfo.typeddict_type is not None
340+
self.item_typeddict_mapping[fullname] = typeinfo.typeddict_type
341+
logger.debug("prepare_typeddict %r %r", fullname, self.item_typeddict_mapping)
342+
343+
def get_dynamic_class_hook(
344+
self, fullname: str
345+
) -> Optional[Callable[[DynamicClassDefContext], None]]:
346+
logger.debug("dynamic_class_hook %r", fullname)
347+
if self.options.python_version >= (3, 8):
348+
if self.is_extractor_cls(fullname, is_item_subcls=True):
349+
return partial(self.prepare_typeddict, fullname=fullname)
350+
351+
return super().get_dynamic_class_hook(fullname)
352+
238353

239354
def plugin(version: str) -> Type[Plugin]:
240355
return DataExtractorPlugin
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
- case: item_extracted_result_is_typeddict
2+
skip: sys.version_info.minor < 8
3+
main: |
4+
from tests.utils import D
5+
from data_extractor.item import Item, Field
6+
7+
class Point2D(Item):
8+
x = Field[int](D())
9+
y = Field[int](D())
10+
11+
p = Point2D(D())
12+
rv = p.extract({"x": 1, "y": 3})
13+
reveal_type(rv)
14+
out: |
15+
main:10: note: Revealed type is "TypedDict({'x': builtins.int, 'y': builtins.int})"
16+
- case: item_extracted_many_results_are_typeddict
17+
skip: sys.version_info.minor < 8
18+
main: |
19+
from tests.utils import D
20+
from data_extractor.item import Item, Field
21+
22+
class Point2D(Item):
23+
x = Field[int](D())
24+
y = Field[int](D())
25+
26+
p = Point2D(D(), is_many=True)
27+
rvs = p.extract([{"x": 1, "y": 3}])
28+
reveal_type(rvs)
29+
out: |
30+
main:10: note: Revealed type is "builtins.list[TypedDict({'x': builtins.int, 'y': builtins.int})]"

tests/typesafety/test_generic.yml

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,25 @@
133133
main:9: note: Revealed type is "builtins.str*"
134134
main:12: note: Revealed type is "builtins.list[builtins.str*]"
135135
- case: item_extract_with_flag_is_many
136+
skip: sys.version_info.minor < 8
137+
main: |
138+
from tests.utils import D
139+
from data_extractor import RV, Item
140+
141+
class C(Item[RV]):
142+
pass
143+
144+
f1 = C(D())
145+
rv = f1.extract([1])
146+
reveal_type(rv)
147+
f2 = C(D(), is_many=True)
148+
rvs = f2.extract([1])
149+
reveal_type(rvs)
150+
out: |
151+
main:9: note: Revealed type is "TypedDict({})"
152+
main:12: note: Revealed type is "builtins.list[TypedDict({})]"
153+
- case: item_extract_with_flag_is_many/compat
154+
skip: sys.version_info.minor >= 8
136155
main: |
137156
from tests.utils import D
138157
from data_extractor import RV, Item

0 commit comments

Comments
 (0)