1
1
# Standard Library
2
- from typing import Callable , Dict , Optional , Type
2
+ import logging
3
+
4
+ from functools import partial
5
+ from typing import Callable , Dict , List , Optional , Type
3
6
4
7
# Third Party Library
5
8
from mypy .checker import TypeChecker
15
18
TypeInfo ,
16
19
)
17
20
from mypy .nodes import Var as VarExpr
18
- from mypy .plugin import FunctionContext , MethodSigContext , Plugin
21
+ from mypy .options import Options
22
+ from mypy .plugin import (
23
+ DynamicClassDefContext ,
24
+ FunctionContext ,
25
+ MethodSigContext ,
26
+ Plugin ,
27
+ )
28
+ from mypy .semanal import SemanticAnalyzerInterface
29
+ from mypy .semanal_typeddict import TypedDictAnalyzer
19
30
from mypy .traverser import TraverserVisitor
20
31
from mypy .types import AnyType , CallableType , Instance
21
32
from mypy .types import Type as MypyType
22
- from mypy .types import TypeOfAny , UninhabitedType , UnionType
33
+ from mypy .types import TypedDictType , TypeOfAny , UninhabitedType , UnionType
34
+
35
+ logger = logging .getLogger (__name__ )
23
36
24
37
25
38
class RelationshipVisitor (TraverserVisitor ):
@@ -81,7 +94,13 @@ def visit_assignment_stmt(self, o):
81
94
82
95
83
96
class DataExtractorPlugin (Plugin ):
84
- cache : Dict [str , Dict [str , str ]] = {}
97
+ cache : Dict [str , Dict [str , str ]]
98
+ item_typeddict_mapping : Dict [str , TypedDictType ]
99
+
100
+ def __init__ (self , options : Options ) -> None :
101
+ super ().__init__ (options )
102
+ self .cache = {}
103
+ self .item_typeddict_mapping = {}
85
104
86
105
def get_current_code (self , ctx : FunctionContext ) -> MypyFile :
87
106
api = ctx .api
@@ -112,12 +131,6 @@ def check_field_generic_type(self, ctx: FunctionContext) -> MypyType:
112
131
if rv_type .args and not isinstance (rv_type .args [0 ], UninhabitedType ):
113
132
return rv_type
114
133
115
- # # check parameter "type"
116
- # idx = ctx.callee_arg_names.index("type")
117
- # arg = ctx.args[type_idx]
118
- # if arg:
119
- # args = [arg[0].node]
120
- # else:
121
134
if not self .options .disallow_any_generics :
122
135
return self .apply_any_generic (type = rv_type )
123
136
else :
@@ -139,11 +152,12 @@ def check_is_many(self, ctx: FunctionContext) -> bool:
139
152
140
153
return False
141
154
142
- def prepare_type_annotations (self , ctx : FunctionContext ) -> MypyType :
155
+ def prepare_type_annotations (self , ctx : FunctionContext , fullname : str ) -> MypyType :
156
+ logger .debug ("prepare_type_annotations %r" , fullname )
157
+
143
158
# check parameter "is_many"
144
159
expr = ctx .context
145
160
assert isinstance (expr , CallExpr )
146
-
147
161
relationship = self .anal_code (self .get_current_code (ctx ))
148
162
lvalue_key = str ((expr .line , expr .column ))
149
163
if lvalue_key in relationship :
@@ -171,24 +185,30 @@ def prepare_type_annotations(self, ctx: FunctionContext) -> MypyType:
171
185
rv_type = self .check_field_generic_type (ctx )
172
186
return rv_type
173
187
174
- def is_extractor_cls (self , fullname : str ) -> bool :
188
+ def is_extractor_cls (self , fullname : str , is_item_subcls = False ) -> bool :
175
189
node = self .lookup_fully_qualified (fullname )
176
190
if node is not None :
177
191
typenode = node .node
178
192
if isinstance (typenode , TypeInfo ):
179
- return typenode .has_base ("data_extractor.item.Field" )
193
+ if is_item_subcls :
194
+ return typenode .has_base ("data_extractor.item.Item" )
195
+ else :
196
+ return typenode .has_base ("data_extractor.item.Field" )
180
197
181
198
return False
182
199
183
200
def get_function_hook (
184
201
self , fullname : str
185
202
) -> Optional [Callable [[FunctionContext ], MypyType ]]:
203
+ logger .debug ("get_function_hook %r" , fullname )
186
204
if self .is_extractor_cls (fullname ):
187
- return self .prepare_type_annotations
205
+ return partial ( self .prepare_type_annotations , fullname = fullname )
188
206
189
207
return super ().get_function_hook (fullname )
190
208
191
- def apply_is_many_on_extract_method (self , ctx : MethodSigContext ) -> CallableType :
209
+ def apply_is_many_on_extract_method (
210
+ self , ctx : MethodSigContext , fullname : str
211
+ ) -> CallableType :
192
212
origin : CallableType = ctx .default_signature
193
213
origin_ret_type = origin .ret_type
194
214
assert isinstance (origin_ret_type , UnionType )
@@ -211,6 +231,7 @@ def apply_is_many_on_extract_method(self, ctx: MethodSigContext) -> CallableType
211
231
key = str ((obj .line , obj .column ))
212
232
# metadata = obj.type.type.metadata
213
233
234
+ logger .debug ("apply_is_many %r %r %r" , fullname , key , metadata )
214
235
if key in metadata :
215
236
is_many = metadata [key ]["is_many" ]
216
237
ret_type = origin_ret_type .items [int (is_many )]
@@ -227,14 +248,108 @@ def is_extract_method(self, fullname: str) -> bool:
227
248
return self .is_extractor_cls (fullname [: - len (suffix )])
228
249
return False
229
250
251
+ def apply_extract_method (
252
+ self , ctx : MethodSigContext , fullname : str
253
+ ) -> CallableType :
254
+ rv = self .apply_is_many_on_extract_method (ctx , fullname )
255
+
256
+ # apply item typeddict
257
+ item_classname = fullname [: - len (".extract" )]
258
+ if item_classname in self .item_typeddict_mapping :
259
+ logger .debug ("apply_extract_method %r %r" , fullname , rv .ret_type )
260
+ original = rv .ret_type
261
+ typeddict = self .item_typeddict_mapping [item_classname ]
262
+ ret_type : Optional [MypyType ]
263
+ if isinstance (original , AnyType ):
264
+ ret_type = typeddict
265
+ else :
266
+ assert isinstance (original , Instance )
267
+ if original .type .name == "list" :
268
+ ret_type = original
269
+ ret_type .args = (typeddict ,)
270
+ else :
271
+ api = ctx .api
272
+ assert isinstance (api , TypeChecker )
273
+ api .fail (
274
+ "Cant determine extract method return type" , context = ctx .context
275
+ )
276
+ ret_type = None
277
+
278
+ if ret_type is not None :
279
+ rv = rv .copy_modified (ret_type = ret_type )
280
+
281
+ logger .debug (
282
+ "apply_extract_method %r %r %r" , fullname , rv , self .item_typeddict_mapping
283
+ )
284
+ return rv
285
+
230
286
def get_method_signature_hook (
231
287
self , fullname : str
232
288
) -> Optional [Callable [[MethodSigContext ], CallableType ]]:
233
289
if self .is_extract_method (fullname ):
234
- return self .apply_is_many_on_extract_method
235
-
290
+ return partial (self .apply_extract_method , fullname = fullname )
236
291
return super ().get_method_signature_hook (fullname )
237
292
293
+ def prepare_typeddict (self , ctx : DynamicClassDefContext , fullname : str ) -> None :
294
+ logger .debug ("prepare_typeddict %r" , fullname )
295
+ if fullname in self .item_typeddict_mapping :
296
+ return
297
+
298
+ api = ctx .api
299
+ assert isinstance (api , SemanticAnalyzerInterface )
300
+ analyzer = TypedDictAnalyzer (api .options , api , api .msg ) # type: ignore
301
+
302
+ items : List [str ] = []
303
+ types : List [MypyType ] = []
304
+ callee = ctx .call .callee
305
+ assert isinstance (callee , NameExpr )
306
+ node = callee .node
307
+ assert isinstance (node , TypeInfo )
308
+ for block in node .defn .defs .body :
309
+ if not isinstance (block , AssignmentStmt ):
310
+ continue
311
+
312
+ rvalue = block .rvalue
313
+ if not isinstance (rvalue , CallExpr ):
314
+ continue
315
+
316
+ rvalue_type : MypyType
317
+ callee = rvalue .callee
318
+ if isinstance (callee , IndexExpr ):
319
+ index = callee .index
320
+ assert isinstance (index , NameExpr )
321
+ name = index .fullname
322
+ assert name is not None
323
+ named_type = api .named_type_or_none (name , [])
324
+ assert named_type is not None
325
+ rvalue_type = named_type
326
+ else :
327
+ rvalue_type = AnyType (TypeOfAny .special_form )
328
+
329
+ for lvalue in block .lvalues :
330
+ assert isinstance (lvalue , NameExpr )
331
+ items .append (lvalue .name )
332
+ types .append (rvalue_type )
333
+
334
+ callee = ctx .call .callee
335
+ assert isinstance (callee , NameExpr )
336
+ typeinfo = analyzer .build_typeddict_typeinfo (
337
+ callee .name , items , types , set (items ), - 1
338
+ )
339
+ assert typeinfo .typeddict_type is not None
340
+ self .item_typeddict_mapping [fullname ] = typeinfo .typeddict_type
341
+ logger .debug ("prepare_typeddict %r %r" , fullname , self .item_typeddict_mapping )
342
+
343
+ def get_dynamic_class_hook (
344
+ self , fullname : str
345
+ ) -> Optional [Callable [[DynamicClassDefContext ], None ]]:
346
+ logger .debug ("dynamic_class_hook %r" , fullname )
347
+ if self .options .python_version >= (3 , 8 ):
348
+ if self .is_extractor_cls (fullname , is_item_subcls = True ):
349
+ return partial (self .prepare_typeddict , fullname = fullname )
350
+
351
+ return super ().get_dynamic_class_hook (fullname )
352
+
238
353
239
354
def plugin (version : str ) -> Type [Plugin ]:
240
355
return DataExtractorPlugin
0 commit comments