metaflow 2.15.5__py2.py3-none-any.whl → 2.15.7__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- metaflow/_vendor/typeguard/_checkers.py +259 -95
- metaflow/_vendor/typeguard/_config.py +4 -4
- metaflow/_vendor/typeguard/_decorators.py +8 -12
- metaflow/_vendor/typeguard/_functions.py +33 -32
- metaflow/_vendor/typeguard/_pytest_plugin.py +40 -13
- metaflow/_vendor/typeguard/_suppression.py +3 -5
- metaflow/_vendor/typeguard/_transformer.py +84 -48
- metaflow/_vendor/typeguard/_union_transformer.py +1 -0
- metaflow/_vendor/typeguard/_utils.py +13 -9
- metaflow/_vendor/typing_extensions.py +1088 -500
- metaflow/_vendor/v3_7/__init__.py +1 -0
- metaflow/_vendor/v3_7/importlib_metadata/__init__.py +1063 -0
- metaflow/_vendor/v3_7/importlib_metadata/_adapters.py +68 -0
- metaflow/_vendor/v3_7/importlib_metadata/_collections.py +30 -0
- metaflow/_vendor/v3_7/importlib_metadata/_compat.py +71 -0
- metaflow/_vendor/v3_7/importlib_metadata/_functools.py +104 -0
- metaflow/_vendor/v3_7/importlib_metadata/_itertools.py +73 -0
- metaflow/_vendor/v3_7/importlib_metadata/_meta.py +48 -0
- metaflow/_vendor/v3_7/importlib_metadata/_text.py +99 -0
- metaflow/_vendor/v3_7/importlib_metadata/py.typed +0 -0
- metaflow/_vendor/v3_7/typeguard/__init__.py +48 -0
- metaflow/_vendor/v3_7/typeguard/_checkers.py +906 -0
- metaflow/_vendor/v3_7/typeguard/_config.py +108 -0
- metaflow/_vendor/v3_7/typeguard/_decorators.py +237 -0
- metaflow/_vendor/v3_7/typeguard/_exceptions.py +42 -0
- metaflow/_vendor/v3_7/typeguard/_functions.py +310 -0
- metaflow/_vendor/v3_7/typeguard/_importhook.py +213 -0
- metaflow/_vendor/v3_7/typeguard/_memo.py +48 -0
- metaflow/_vendor/v3_7/typeguard/_pytest_plugin.py +100 -0
- metaflow/_vendor/v3_7/typeguard/_suppression.py +88 -0
- metaflow/_vendor/v3_7/typeguard/_transformer.py +1207 -0
- metaflow/_vendor/v3_7/typeguard/_union_transformer.py +54 -0
- metaflow/_vendor/v3_7/typeguard/_utils.py +169 -0
- metaflow/_vendor/v3_7/typeguard/py.typed +0 -0
- metaflow/_vendor/v3_7/typing_extensions.py +3072 -0
- metaflow/_vendor/v3_7/zipp.py +329 -0
- metaflow/cmd/develop/stubs.py +1 -1
- metaflow/extension_support/__init__.py +1 -1
- metaflow/plugins/argo/argo_workflows.py +34 -11
- metaflow/plugins/argo/argo_workflows_deployer_objects.py +7 -6
- metaflow/plugins/pypi/utils.py +4 -0
- metaflow/runner/click_api.py +7 -2
- metaflow/vendor.py +1 -0
- metaflow/version.py +1 -1
- {metaflow-2.15.5.data → metaflow-2.15.7.data}/data/share/metaflow/devtools/Makefile +2 -2
- {metaflow-2.15.5.dist-info → metaflow-2.15.7.dist-info}/METADATA +4 -3
- {metaflow-2.15.5.dist-info → metaflow-2.15.7.dist-info}/RECORD +53 -27
- {metaflow-2.15.5.dist-info → metaflow-2.15.7.dist-info}/WHEEL +1 -1
- {metaflow-2.15.5.data → metaflow-2.15.7.data}/data/share/metaflow/devtools/Tiltfile +0 -0
- {metaflow-2.15.5.data → metaflow-2.15.7.data}/data/share/metaflow/devtools/pick_services.sh +0 -0
- {metaflow-2.15.5.dist-info → metaflow-2.15.7.dist-info}/entry_points.txt +0 -0
- {metaflow-2.15.5.dist-info → metaflow-2.15.7.dist-info/licenses}/LICENSE +0 -0
- {metaflow-2.15.5.dist-info → metaflow-2.15.7.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,1207 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import ast
|
4
|
+
import builtins
|
5
|
+
import sys
|
6
|
+
import typing
|
7
|
+
from ast import (
|
8
|
+
AST,
|
9
|
+
Add,
|
10
|
+
AnnAssign,
|
11
|
+
Assign,
|
12
|
+
AsyncFunctionDef,
|
13
|
+
Attribute,
|
14
|
+
AugAssign,
|
15
|
+
BinOp,
|
16
|
+
BitAnd,
|
17
|
+
BitOr,
|
18
|
+
BitXor,
|
19
|
+
Call,
|
20
|
+
ClassDef,
|
21
|
+
Constant,
|
22
|
+
Dict,
|
23
|
+
Div,
|
24
|
+
Expr,
|
25
|
+
Expression,
|
26
|
+
FloorDiv,
|
27
|
+
FunctionDef,
|
28
|
+
If,
|
29
|
+
Import,
|
30
|
+
ImportFrom,
|
31
|
+
Index,
|
32
|
+
List,
|
33
|
+
Load,
|
34
|
+
LShift,
|
35
|
+
MatMult,
|
36
|
+
Mod,
|
37
|
+
Module,
|
38
|
+
Mult,
|
39
|
+
Name,
|
40
|
+
NodeTransformer,
|
41
|
+
NodeVisitor,
|
42
|
+
Pass,
|
43
|
+
Pow,
|
44
|
+
Return,
|
45
|
+
RShift,
|
46
|
+
Starred,
|
47
|
+
Store,
|
48
|
+
Str,
|
49
|
+
Sub,
|
50
|
+
Subscript,
|
51
|
+
Tuple,
|
52
|
+
Yield,
|
53
|
+
YieldFrom,
|
54
|
+
alias,
|
55
|
+
copy_location,
|
56
|
+
expr,
|
57
|
+
fix_missing_locations,
|
58
|
+
keyword,
|
59
|
+
walk,
|
60
|
+
)
|
61
|
+
from collections import defaultdict
|
62
|
+
from collections.abc import Generator, Sequence
|
63
|
+
from contextlib import contextmanager
|
64
|
+
from copy import deepcopy
|
65
|
+
from dataclasses import dataclass, field
|
66
|
+
from typing import Any, ClassVar, cast, overload
|
67
|
+
|
68
|
+
if sys.version_info >= (3, 8):
|
69
|
+
from ast import NamedExpr
|
70
|
+
|
71
|
+
# Return annotations naming one of these types mark a function as a (possibly
# asynchronous) generator whose yield/send/return types can be read from the
# annotation's subscript
generator_names = (
    "typing.Generator", "collections.abc.Generator",
    "typing.Iterator", "collections.abc.Iterator",
    "typing.Iterable", "collections.abc.Iterable",
    "typing.AsyncIterator", "collections.abc.AsyncIterator",
    "typing.AsyncIterable", "collections.abc.AsyncIterable",
    "typing.AsyncGenerator", "collections.abc.AsyncGenerator",
)

# Fully qualified spellings of typing constructs that get special treatment
anytype_names = ("typing.Any", "typing_extensions.Any")
literal_names = ("typing.Literal", "typing_extensions.Literal")
annotated_names = ("typing.Annotated", "typing_extensions.Annotated")

# Decorators that opt a function out of instrumentation
ignore_decorators = ("typing.no_type_check", "typeguard.typeguard_ignore")

# Augmented-assignment operator node type -> name of the corresponding in-place
# operation (e.g. "iadd" for "+=")
aug_assign_functions = {
    # arithmetic operators
    Add: "iadd",
    Sub: "isub",
    Mult: "imul",
    MatMult: "imatmul",
    Div: "itruediv",
    FloorDiv: "ifloordiv",
    Mod: "imod",
    Pow: "ipow",
    # shift and bitwise operators
    LShift: "ilshift",
    RShift: "irshift",
    BitAnd: "iand",
    BitXor: "ixor",
    BitOr: "ior",
}
|
116
|
+
|
117
|
+
|
118
|
+
@dataclass
class TransformMemo:
    """
    Bookkeeping for one scope (the module, a class or a function) while the AST is
    being instrumented.

    Memos form a chain through ``parent`` that mirrors the nesting of scopes; name
    lookups (:meth:`name_matches`, :meth:`is_ignored_name`,
    :meth:`get_unused_name`) walk up that chain.
    """

    node: Module | ClassDef | FunctionDef | AsyncFunctionDef | None
    parent: TransformMemo | None
    # Names of the enclosing classes/functions, outermost first
    path: tuple[str, ...]
    # Qualified name of this scope (e.g. "Cls.method" or "outer.<locals>.inner"),
    # computed in __post_init__
    joined_path: Constant = field(init=False)
    return_annotation: expr | None = None
    yield_annotation: expr | None = None
    send_annotation: expr | None = None
    is_async: bool = False
    # Names bound or referenced in this scope; used to pick collision-free names
    local_names: set[str] = field(init=False, default_factory=set)
    # Local name -> fully qualified name it was imported as
    imported_names: dict[str, str] = field(init=False, default_factory=dict)
    # Names that cause an annotation (or the part using them) to be skipped
    ignored_names: set[str] = field(init=False, default_factory=set)
    # Imports needed by injected code: module -> {original name: Name node to use}
    load_names: defaultdict[str, dict[str, Name]] = field(
        init=False, default_factory=lambda: defaultdict(dict)
    )
    has_yield_expressions: bool = field(init=False, default=False)
    has_return_expressions: bool = field(init=False, default=False)
    # Name node of the call memo variable; created lazily by get_memo_name()
    memo_var_name: Name | None = field(init=False, default=None)
    should_instrument: bool = field(init=False, default=True)
    variable_annotations: dict[str, expr] = field(init=False, default_factory=dict)
    # Keyword arguments passed to @typechecked(...) on this scope
    configuration_overrides: dict[str, Any] = field(init=False, default_factory=dict)
    # Position in node.body where injected statements are inserted
    code_inject_index: int = field(init=False, default=0)

    def __post_init__(self) -> None:
        elements: list[str] = []
        memo = self
        while isinstance(memo.node, (ClassDef, FunctionDef, AsyncFunctionDef)):
            elements.insert(0, memo.node.name)
            if not memo.parent:
                break

            memo = memo.parent
            if isinstance(memo.node, (FunctionDef, AsyncFunctionDef)):
                elements.insert(0, "<locals>")

        self.joined_path = Constant(".".join(elements))

        # Figure out where to insert instrumentation code: after any docstring and
        # (module only) any __future__ imports, which must come first
        if self.node:
            for index, child in enumerate(self.node.body):
                if isinstance(child, ImportFrom) and child.module == "__future__":
                    continue
                elif isinstance(child, Expr):
                    if isinstance(child.value, Constant) and isinstance(
                        child.value.value, str
                    ):
                        continue  # docstring
                    elif sys.version_info < (3, 8) and isinstance(child.value, Str):
                        continue  # docstring (Python 3.7 AST)

                self.code_inject_index = index
                break
            else:
                # The body holds nothing but docstrings and/or __future__ imports:
                # inject after them rather than in front of them, so that e.g. a
                # docstring-only function keeps its docstring
                self.code_inject_index = len(self.node.body)

    def get_unused_name(self, name: str) -> str:
        """
        Return ``name`` with underscores appended until it clashes with no name in
        this scope or any enclosing one, and reserve the result in this scope.
        """
        memo: TransformMemo | None = self
        while memo is not None:
            if name in memo.local_names:
                # Clash: lengthen the name and restart the search from this scope
                memo = self
                name += "_"
            else:
                memo = memo.parent

        self.local_names.add(name)
        return name

    def is_ignored_name(self, expression: expr | Expr | None) -> bool:
        """
        Return ``True`` if ``expression`` is a name (or an attribute of a name) that
        is listed in ``ignored_names`` of this scope or any enclosing scope.
        """
        top_expression = (
            expression.value if isinstance(expression, Expr) else expression
        )

        if isinstance(top_expression, Attribute) and isinstance(
            top_expression.value, Name
        ):
            name = top_expression.value.id
        elif isinstance(top_expression, Name):
            name = top_expression.id
        else:
            return False

        memo: TransformMemo | None = self
        while memo is not None:
            if name in memo.ignored_names:
                return True

            memo = memo.parent

        return False

    def get_memo_name(self) -> Name:
        """
        Return the Name node referring to the call memo variable, creating it on
        first use. Its final ``id`` is assigned later, once the memo is injected.
        """
        if not self.memo_var_name:
            self.memo_var_name = Name(id="memo", ctx=Load())

        return self.memo_var_name

    def get_import(self, module: str, name: str) -> Name:
        """
        Return a Name node for ``module.name``, registering an import for it in
        ``load_names`` (under a collision-free alias) unless one already exists.
        """
        if module in self.load_names and name in self.load_names[module]:
            return self.load_names[module][name]

        qualified_name = f"{module}.{name}"
        if name in self.imported_names and self.imported_names[name] == qualified_name:
            return Name(id=name, ctx=Load())

        alias_name = self.get_unused_name(name)
        node = self.load_names[module][name] = Name(id=alias_name, ctx=Load())
        self.imported_names[name] = qualified_name
        return node

    def insert_imports(self, node: Module | FunctionDef | AsyncFunctionDef) -> None:
        """Insert imports needed by injected code."""
        if not self.load_names:
            return

        # Insert imports after any "from __future__ ..." imports and any docstring
        for modulename, names in self.load_names.items():
            aliases = [
                alias(orig_name, new_name.id if orig_name != new_name.id else None)
                for orig_name, new_name in sorted(names.items())
            ]
            node.body.insert(self.code_inject_index, ImportFrom(modulename, aliases, 0))

    def name_matches(self, expression: expr | Expr | None, *names: str) -> bool:
        """
        Return ``True`` if ``expression`` (or the value of a subscript / the callee
        of a call) resolves to one of the fully qualified ``names``.

        The leading name is translated through ``imported_names`` (falling back to
        ``builtins.<name>`` for builtins); enclosing scopes are tried in turn.
        """
        if expression is None:
            return False

        path: list[str] = []
        top_expression = (
            expression.value if isinstance(expression, Expr) else expression
        )

        if isinstance(top_expression, Subscript):
            top_expression = top_expression.value
        elif isinstance(top_expression, Call):
            top_expression = top_expression.func

        while isinstance(top_expression, Attribute):
            path.insert(0, top_expression.attr)
            top_expression = top_expression.value

        if not isinstance(top_expression, Name):
            return False

        if top_expression.id in self.imported_names:
            translated = self.imported_names[top_expression.id]
        elif hasattr(builtins, top_expression.id):
            translated = "builtins." + top_expression.id
        else:
            translated = top_expression.id

        path.insert(0, translated)
        joined_path = ".".join(path)
        if joined_path in names:
            return True
        elif self.parent:
            return self.parent.name_matches(expression, *names)
        else:
            return False

    def get_config_keywords(self) -> list[keyword]:
        """
        Return the configuration overrides as keyword nodes; for methods, the
        enclosing class's overrides apply first and this scope's override them.
        """
        if self.parent and isinstance(self.parent.node, ClassDef):
            overrides = self.parent.configuration_overrides.copy()
        else:
            overrides = {}

        overrides.update(self.configuration_overrides)
        return [keyword(key, value) for key, value in overrides.items()]
|
285
|
+
|
286
|
+
|
287
|
+
class NameCollector(NodeVisitor):
    """
    Gather the names bound directly in the visited code.

    Recorded are: names introduced by ``import`` / ``from ... import``, plain-name
    targets of ``=`` assignments and ``:=`` targets. Bodies of nested (non-async)
    function definitions and of class definitions are not descended into.
    """

    def __init__(self) -> None:
        self.names: set[str] = set()

    def visit_Import(self, node: Import) -> None:
        self.names.update(imported.asname or imported.name for imported in node.names)

    def visit_ImportFrom(self, node: ImportFrom) -> None:
        self.names.update(imported.asname or imported.name for imported in node.names)

    def visit_Assign(self, node: Assign) -> None:
        # Tuple, attribute and subscript targets are not recorded
        self.names.update(
            target.id for target in node.targets if isinstance(target, Name)
        )

    def visit_NamedExpr(self, node: NamedExpr) -> Any:
        target = node.target
        if isinstance(target, Name):
            self.names.add(target.id)

    def visit_FunctionDef(self, node: FunctionDef) -> None:
        """Do not look inside nested function bodies."""

    def visit_ClassDef(self, node: ClassDef) -> None:
        """Do not look inside class bodies."""
|
313
|
+
|
314
|
+
|
315
|
+
class GeneratorDetector(NodeVisitor):
    """
    Detects if a function node is a generator function.

    Call ``visit()`` on the function definition; afterwards ``contains_yields``
    tells whether its own body has a ``yield`` or ``yield from``. Yields inside
    nested functions or classes do not count.
    """

    contains_yields: bool = False
    in_root_function: bool = False

    def visit_Yield(self, node: Yield) -> Any:
        self.contains_yields = True

    def visit_YieldFrom(self, node: YieldFrom) -> Any:
        self.contains_yields = True

    def visit_ClassDef(self, node: ClassDef) -> Any:
        # A class body is a separate scope: don't descend into it
        return None

    def visit_FunctionDef(self, node: FunctionDef | AsyncFunctionDef) -> Any:
        if self.in_root_function:
            # Nested function: any yields in it make *it* a generator, not the root
            return None

        self.in_root_function = True
        self.generic_visit(node)
        self.in_root_function = False

    def visit_AsyncFunctionDef(self, node: AsyncFunctionDef) -> Any:
        self.visit_FunctionDef(node)
|
338
|
+
|
339
|
+
|
340
|
+
class AnnotationTransformer(NodeTransformer):
    """
    Rewrite an annotation expression so that the instrumented code can evaluate it
    at run time.

    * ``x | y`` unions become ``typing.Union[x, y]`` on Python < 3.10
    * ``dict``, ``list``, ``tuple``, ``set`` and ``frozenset`` become their
      ``typing`` counterparts on Python < 3.9
    * string forward references are parsed into expressions
    * parts that are not checked, such as ``Any`` or names listed as ignored in
      the memo, are erased; a top-level visit returning ``None`` means the whole
      annotation was erased

    ``Literal[...]`` subscripts are left untouched since their contents are values,
    not types.
    """

    type_substitutions: ClassVar[dict[str, tuple[str, str]]] = {
        "builtins.dict": ("typing", "Dict"),
        "builtins.list": ("typing", "List"),
        "builtins.tuple": ("typing", "Tuple"),
        "builtins.set": ("typing", "Set"),
        "builtins.frozenset": ("typing", "FrozenSet"),
    }

    def __init__(self, transformer: TypeguardTransformer):
        self.transformer = transformer
        self._memo = transformer._memo
        # Current nesting depth of visit() calls; 0 once the outermost call returns
        self._level = 0

    def visit(self, node: AST) -> Any:
        self._level += 1
        new_node = super().visit(node)
        self._level -= 1

        # A parsed string annotation (see visit_Constant) whose body was erased
        if isinstance(new_node, Expression) and not hasattr(new_node, "body"):
            return None

        # Return None if this new node matches a variation of typing.Any, but only
        # for the outermost node (Any nested inside a subscript is kept)
        if (
            self._level == 0
            and isinstance(new_node, expr)
            and self._memo.name_matches(new_node, *anytype_names)
        ):
            return None

        return new_node

    def generic_visit(self, node: AST) -> AST:
        if isinstance(node, expr) and self._memo.name_matches(node, *literal_names):
            return node

        return super().generic_visit(node)

    def visit_BinOp(self, node: BinOp) -> Any:
        self.generic_visit(node)

        if isinstance(node.op, BitOr):
            # Return Any if either side is Any
            if self._memo.name_matches(node.left, *anytype_names):
                return node.left
            elif self._memo.name_matches(node.right, *anytype_names):
                return node.right

            # PEP 604 unions can't be evaluated before Python 3.10
            if sys.version_info < (3, 10):
                union_name = self.transformer._get_import("typing", "Union")
                return Subscript(
                    value=union_name,
                    slice=Index(
                        Tuple(elts=[node.left, node.right], ctx=Load()), ctx=Load()
                    ),
                    ctx=Load(),
                )

        return node

    def visit_Attribute(self, node: Attribute) -> Any:
        if self._memo.is_ignored_name(node):
            return None

        return node

    def visit_Subscript(self, node: Subscript) -> Any:
        if self._memo.is_ignored_name(node.value):
            return None

        # The subscript of typing(_extensions).Literal can be any arbitrary string, so
        # don't try to evaluate it as code: leave the whole Literal[...] untouched
        if self._memo.name_matches(node.value, *literal_names):
            return node

        if node.slice:
            if isinstance(node.slice, Index):
                # Python 3.7 and 3.8
                slice_value = node.slice.value  # type: ignore[attr-defined]
            else:
                slice_value = node.slice

            if isinstance(slice_value, Tuple):
                # Items go through visit() (not generic_visit()) so that their
                # dedicated handlers run: PEP 604 unions are converted, string
                # forward references are parsed and ignored names are erased
                if self._memo.name_matches(node.value, *annotated_names):
                    # Only treat the first argument to typing.Annotated as a potential
                    # forward reference
                    items = cast(
                        typing.List[expr],
                        [self.visit(slice_value.elts[0])] + slice_value.elts[1:],
                    )
                else:
                    items = cast(
                        typing.List[expr],
                        [self.visit(item) for item in slice_value.elts],
                    )

                # If this is a Union and any of the items is Any, erase the entire
                # annotation
                if self._memo.name_matches(node.value, "typing.Union") and any(
                    isinstance(item, expr)
                    and self._memo.name_matches(item, *anytype_names)
                    for item in items
                ):
                    return None

                # If all items in the subscript were erased, erase the subscript
                # entirely (an empty tuple, as in Tuple[()], is kept as-is)
                if items and all(item is None for item in items):
                    return node.value

                # Stand in typing.Any for individually erased items
                for index, item in enumerate(items):
                    if item is None:
                        items[index] = self.transformer._get_import("typing", "Any")

                slice_value.elts = items
            else:
                self.generic_visit(node)

                # Find out whether the transformer erased the slice entirely
                if sys.version_info >= (3, 9):
                    slice_erased = not hasattr(node, "slice")
                else:
                    slice_erased = not hasattr(node.slice, "value")

                if slice_erased:
                    # Return the node value without the subscript, unless it's
                    # Optional, in which case erase the node entirely
                    if self._memo.name_matches(node.value, "typing.Optional"):
                        return None

                    return node.value

        return node

    def visit_Name(self, node: Name) -> Any:
        if self._memo.is_ignored_name(node):
            return None

        if sys.version_info < (3, 9):
            for typename, substitute in self.type_substitutions.items():
                if self._memo.name_matches(node, typename):
                    new_node = self.transformer._get_import(*substitute)
                    return copy_location(new_node, node)

        return node

    def visit_Call(self, node: Call) -> Any:
        # Don't recurse into calls
        return node

    def visit_Constant(self, node: Constant) -> Any:
        # String constants are forward references: parse and transform them
        if isinstance(node.value, str):
            expression = ast.parse(node.value, mode="eval")
            new_node = self.visit(expression)
            if new_node:
                return copy_location(new_node.body, node)
            else:
                return None

        return node

    def visit_Str(self, node: Str) -> Any:
        # Only used on Python 3.7
        expression = ast.parse(node.s, mode="eval")
        new_node = self.visit(expression)
        if new_node:
            return copy_location(new_node.body, node)
        else:
            return None
|
502
|
+
|
503
|
+
|
504
|
+
class TypeguardTransformer(NodeTransformer):
|
505
|
+
def __init__(
    self, target_path: Sequence[str] | None = None, target_lineno: int | None = None
) -> None:
    """
    :param target_path: name components of the single function to instrument
        (e.g. ``("MyClass", "method")``), or ``None`` to instrument the whole
        module
    :param target_lineno: line number of the target function's first line (its
        first decorator, or the ``def`` line), used to locate ``target_node``
    """
    self.target_lineno = target_lineno
    self.target_node: FunctionDef | AsyncFunctionDef | None = None
    self.names_used_in_annotations: set[str] = set()
    self._target_path = None if not target_path else tuple(target_path)
    # The module-level memo is also the initial "current" memo
    self._module_memo = TransformMemo(None, None, ())
    self._memo = self._module_memo
|
513
|
+
|
514
|
+
@contextmanager
def _use_memo(
    self, node: ClassDef | FunctionDef | AsyncFunctionDef
) -> Generator[None, Any, None]:
    """
    Make a new memo for ``node`` (a child of the current memo) the current memo
    while the ``with`` block runs.

    For a function that is to be instrumented, the converted return annotation is
    stored on the new memo. If the function contains ``yield`` and is annotated
    with one of ``generator_names``, the yield, send and return annotations are
    instead taken from the subscript (e.g. ``Generator[int, str, bool]``).

    The previous memo is restored when the block exits, even if it raises.
    """
    new_memo = TransformMemo(node, self._memo, self._memo.path + (node.name,))
    if isinstance(node, (FunctionDef, AsyncFunctionDef)):
        new_memo.should_instrument = (
            self._target_path is None or new_memo.path == self._target_path
        )
        if new_memo.should_instrument:
            # Check if the function is a generator function
            detector = GeneratorDetector()
            detector.visit(node)

            # Extract yield, send and return types where possible from a subscripted
            # annotation like Generator[int, str, bool]
            return_annotation = deepcopy(node.returns)
            if detector.contains_yields and new_memo.name_matches(
                return_annotation, *generator_names
            ):
                if isinstance(return_annotation, Subscript):
                    annotation_slice = return_annotation.slice

                    # Python < 3.9
                    if isinstance(annotation_slice, Index):
                        annotation_slice = (
                            annotation_slice.value  # type: ignore[attr-defined]
                        )

                    if isinstance(annotation_slice, Tuple):
                        items = annotation_slice.elts
                    else:
                        items = [annotation_slice]

                    if len(items) > 0:
                        new_memo.yield_annotation = self._convert_annotation(items[0])

                    if len(items) > 1:
                        new_memo.send_annotation = self._convert_annotation(items[1])

                    if len(items) > 2:
                        new_memo.return_annotation = self._convert_annotation(
                            items[2]
                        )
            else:
                new_memo.return_annotation = self._convert_annotation(
                    return_annotation
                )

    if isinstance(node, AsyncFunctionDef):
        new_memo.is_async = True

    old_memo = self._memo
    self._memo = new_memo
    try:
        yield
    finally:
        # Restore the enclosing scope's memo even if the block raised
        self._memo = old_memo
|
574
|
+
|
575
|
+
def _get_import(self, module: str, name: str) -> Name:
    """
    Return a Name node for ``module.name`` and register the import it requires.

    When a single target function is being instrumented the import is registered
    on the current memo; otherwise it goes to the module-level memo.
    """
    if self._target_path:
        owning_memo = self._memo
    else:
        owning_memo = self._module_memo

    return owning_memo.get_import(module, name)
|
578
|
+
|
579
|
+
@overload
def _convert_annotation(self, annotation: None) -> None:
    ...

@overload
def _convert_annotation(self, annotation: expr) -> expr:
    ...

def _convert_annotation(self, annotation: expr | None) -> expr | None:
    """
    Turn ``annotation`` into an expression the injected code can evaluate.

    PEP 604 unions and generic built-in collections are converted where necessary
    and forward references are undone (the given tree may be modified in place).
    Returns ``None`` when the annotation was erased. Names appearing in a
    converted annotation are added to ``names_used_in_annotations``.
    """
    if annotation is None:
        return None

    converted = cast(expr, AnnotationTransformer(self).visit(annotation))
    if not isinstance(converted, expr):
        # Erased entirely (e.g. the annotation was Any)
        return converted

    converted = ast.copy_location(converted, annotation)
    self.names_used_in_annotations.update(
        child.id for child in walk(converted) if isinstance(child, Name)
    )
    return converted
|
602
|
+
|
603
|
+
def visit_Name(self, node: Name) -> Name:
    """Reserve every referenced name in the current scope, leaving the node as is."""
    reserved = self._memo.local_names
    reserved.add(node.id)
    return node
|
606
|
+
|
607
|
+
def visit_Module(self, node: Module) -> Module:
    """
    Entry point for a whole module: visit every statement, add the imports needed
    by the injected code, then fill in missing source locations.
    """
    self.generic_visit(node)
    self._memo.insert_imports(node)
    return fix_missing_locations(node)
|
613
|
+
|
614
|
+
def visit_Import(self, node: Import) -> Import:
    """Record ``import x [as y]`` bindings so annotations using them can be resolved."""
    memo = self._memo
    for imported in node.names:
        bound_name = imported.asname or imported.name
        memo.local_names.add(bound_name)
        memo.imported_names[bound_name] = imported.name

    return node
|
620
|
+
|
621
|
+
def visit_ImportFrom(self, node: ImportFrom) -> ImportFrom:
    """Record ``from m import x [as y]`` bindings; star imports are skipped."""
    memo = self._memo
    for imported in node.names:
        if imported.name == "*":
            continue

        bound_name = imported.asname or imported.name
        memo.local_names.add(bound_name)
        # NOTE(review): for relative imports like "from . import x", node.module
        # is None, so this records "None.x"
        memo.imported_names[bound_name] = f"{node.module}.{imported.name}"

    return node
|
629
|
+
|
630
|
+
def visit_ClassDef(self, node: ClassDef) -> ClassDef | None:
    """
    Visit a class body under its own memo, stripping ``@typechecked`` from the class.

    Keyword arguments of such a decorator are kept as configuration overrides for
    the class. When a single target is being instrumented, top-level classes that
    are not on the target path are removed (``None`` is returned).
    """
    self._memo.local_names.add(node.name)

    # Eliminate top level classes not belonging to the target path
    outside_target = (
        self._target_path is not None
        and not self._memo.path
        and node.name != self._target_path[0]
    )
    if outside_target:
        return None

    with self._use_memo(node):
        typechecked_decorators = [
            decorator
            for decorator in node.decorator_list
            if self._memo.name_matches(decorator, "typeguard.typechecked")
        ]
        for decorator in typechecked_decorators:
            # Remove the decorator to prevent duplicate instrumentation
            node.decorator_list.remove(decorator)

            # Store any configuration overrides
            if isinstance(decorator, Call) and decorator.keywords:
                self._memo.configuration_overrides.update(
                    {kw.arg: kw.value for kw in decorator.keywords if kw.arg}
                )

        self.generic_visit(node)
        return node
|
655
|
+
|
656
|
+
def visit_FunctionDef(
    self, node: FunctionDef | AsyncFunctionDef
) -> FunctionDef | AsyncFunctionDef | None:
    """
    Injects type checks for function arguments, and for a return of None if the
    function is annotated to return something else than Any or None, and the body
    ends without an explicit "return".

    Returns ``None`` (dropping the function from the tree) for ``@overload``
    definitions and, when a single target is being instrumented, for top-level
    functions outside the target path. When instrumenting a whole module,
    functions decorated with one of ``ignore_decorators`` are returned untouched.
    """
    self._memo.local_names.add(node.name)

    # Eliminate top level functions not belonging to the target path
    if (
        self._target_path is not None
        and not self._memo.path
        and node.name != self._target_path[0]
    ):
        return None

    # Skip instrumentation if we're instrumenting the whole module and the function
    # contains either @no_type_check or @typeguard_ignore
    if self._target_path is None:
        for decorator in node.decorator_list:
            if self._memo.name_matches(decorator, *ignore_decorators):
                return node

    with self._use_memo(node):
        arg_annotations: dict[str, Any] = {}
        if self._target_path is None or self._memo.path == self._target_path:
            # Find line number we're supposed to match against
            if node.decorator_list:
                first_lineno = node.decorator_list[0].lineno
            else:
                first_lineno = node.lineno

            for decorator in node.decorator_list.copy():
                if self._memo.name_matches(decorator, "typing.overload"):
                    # Remove overloads entirely
                    return None
                elif self._memo.name_matches(decorator, "typeguard.typechecked"):
                    # Remove the decorator to prevent duplicate instrumentation
                    node.decorator_list.remove(decorator)

                    # Store any configuration overrides
                    if isinstance(decorator, Call) and decorator.keywords:
                        self._memo.configuration_overrides = {
                            kw.arg: kw.value for kw in decorator.keywords if kw.arg
                        }

            if self.target_lineno == first_lineno:
                assert self.target_node is None
                self.target_node = node
                if node.decorator_list and sys.version_info >= (3, 8):
                    self.target_lineno = node.decorator_list[0].lineno
                else:
                    self.target_lineno = node.lineno

            all_args = node.args.args + node.args.kwonlyargs
            if sys.version_info >= (3, 8):
                all_args.extend(node.args.posonlyargs)

            # Ensure that any type shadowed by the positional or keyword-only
            # argument names are ignored in this function
            for arg in all_args:
                self._memo.ignored_names.add(arg.arg)

            # Ensure that any type shadowed by the variable positional argument name
            # (e.g. "args" in *args) is ignored in this function
            if node.args.vararg:
                self._memo.ignored_names.add(node.args.vararg.arg)

            # Ensure that any type shadowed by the variable keyword argument name
            # (e.g. "kwargs" in **kwargs) is ignored in this function
            if node.args.kwarg:
                self._memo.ignored_names.add(node.args.kwarg.arg)

            # Annotations are always converted from deep copies: the conversion may
            # edit the tree in place, and the function's own signature must keep
            # its original annotations
            for arg in all_args:
                annotation = self._convert_annotation(deepcopy(arg.annotation))
                if annotation:
                    arg_annotations[arg.arg] = annotation

            if node.args.vararg:
                annotation_ = self._convert_annotation(
                    deepcopy(node.args.vararg.annotation)
                )
                if annotation_:
                    # *args is checked as tuple[<annotation>, ...]
                    if sys.version_info >= (3, 9):
                        container = Name("tuple", ctx=Load())
                    else:
                        container = self._get_import("typing", "Tuple")

                    subscript_slice: Tuple | Index = Tuple(
                        [
                            annotation_,
                            Constant(Ellipsis),
                        ],
                        ctx=Load(),
                    )
                    if sys.version_info < (3, 9):
                        subscript_slice = Index(subscript_slice, ctx=Load())

                    arg_annotations[node.args.vararg.arg] = Subscript(
                        container, subscript_slice, ctx=Load()
                    )

            if node.args.kwarg:
                annotation_ = self._convert_annotation(
                    deepcopy(node.args.kwarg.annotation)
                )
                if annotation_:
                    # **kwargs is checked as dict[str, <annotation>]
                    if sys.version_info >= (3, 9):
                        container = Name("dict", ctx=Load())
                    else:
                        container = self._get_import("typing", "Dict")

                    subscript_slice = Tuple(
                        [
                            Name("str", ctx=Load()),
                            annotation_,
                        ],
                        ctx=Load(),
                    )
                    if sys.version_info < (3, 9):
                        subscript_slice = Index(subscript_slice, ctx=Load())

                    arg_annotations[node.args.kwarg.arg] = Subscript(
                        container, subscript_slice, ctx=Load()
                    )

            if arg_annotations:
                self._memo.variable_annotations.update(arg_annotations)

        self.generic_visit(node)

        if arg_annotations:
            annotations_dict = Dict(
                keys=[Constant(key) for key in arg_annotations.keys()],
                values=[
                    Tuple([Name(key, ctx=Load()), annotation], ctx=Load())
                    for key, annotation in arg_annotations.items()
                ],
            )
            func_name = self._get_import(
                "typeguard._functions", "check_argument_types"
            )
            args = [
                self._memo.joined_path,
                annotations_dict,
                self._memo.get_memo_name(),
            ]
            node.body.insert(
                self._memo.code_inject_index, Expr(Call(func_name, args, []))
            )

        # Add a checked "return None" to the end if there's no explicit return
        # Skip if the return annotation is None or Any
        if (
            self._memo.return_annotation
            and (not self._memo.is_async or not self._memo.has_yield_expressions)
            and not isinstance(node.body[-1], Return)
            and (
                not isinstance(self._memo.return_annotation, Constant)
                or self._memo.return_annotation.value is not None
            )
        ):
            func_name = self._get_import("typeguard._functions", "check_return_type")
            return_node = Return(
                Call(
                    func_name,
                    [
                        self._memo.joined_path,
                        Constant(None),
                        self._memo.return_annotation,
                        self._memo.get_memo_name(),
                    ],
                    [],
                )
            )

            # Replace a placeholder "pass" at the end
            if isinstance(node.body[-1], Pass):
                copy_location(return_node, node.body[-1])
                del node.body[-1]

            node.body.append(return_node)

        # Insert code to create the call memo, if it was ever needed for this
        # function
        if self._memo.memo_var_name:
            memo_kwargs: dict[str, Any] = {}
            if self._memo.parent and isinstance(self._memo.parent.node, ClassDef):
                # Methods pass their class: the first argument for classmethods and
                # __new__, the first argument's __class__ otherwise; staticmethods
                # pass nothing
                for decorator in node.decorator_list:
                    if isinstance(decorator, Name) and decorator.id == "staticmethod":
                        break
                    elif (
                        isinstance(decorator, Name) and decorator.id == "classmethod"
                    ):
                        memo_kwargs["self_type"] = Name(
                            id=node.args.args[0].arg, ctx=Load()
                        )
                        break
                else:
                    if node.args.args:
                        if node.name == "__new__":
                            memo_kwargs["self_type"] = Name(
                                id=node.args.args[0].arg, ctx=Load()
                            )
                        else:
                            memo_kwargs["self_type"] = Attribute(
                                Name(id=node.args.args[0].arg, ctx=Load()),
                                "__class__",
                                ctx=Load(),
                            )

            config_keywords = self._memo.get_config_keywords()
            if config_keywords:
                memo_kwargs["config"] = Call(
                    self._get_import("dataclasses", "replace"),
                    [self._get_import("typeguard._config", "global_config")],
                    config_keywords,
                )

            self._memo.memo_var_name.id = self._memo.get_unused_name("memo")
            memo_store_name = Name(id=self._memo.memo_var_name.id, ctx=Store())
            globals_call = Call(Name(id="globals", ctx=Load()), [], [])
            locals_call = Call(Name(id="locals", ctx=Load()), [], [])
            memo_expr = Call(
                self._get_import("typeguard", "TypeCheckMemo"),
                [globals_call, locals_call],
                [keyword(key, value) for key, value in memo_kwargs.items()],
            )
            node.body.insert(
                self._memo.code_inject_index,
                Assign([memo_store_name], memo_expr),
            )

            self._memo.insert_imports(node)

            # Remove any placeholder "pass" at the end
            if isinstance(node.body[-1], Pass):
                del node.body[-1]

    return node
|
917
|
+
|
918
|
+
def visit_AsyncFunctionDef(
    self, node: AsyncFunctionDef
) -> FunctionDef | AsyncFunctionDef | None:
    """
    Instrument an ``async def`` function.

    Coroutine functions receive exactly the same treatment as regular
    functions, so ``node`` is forwarded to :meth:`visit_FunctionDef` and its
    result is returned unchanged.

    """
    delegate = self.visit_FunctionDef
    return delegate(node)
|
922
|
+
|
923
|
+
def visit_Return(self, node: Return) -> Return:
    """
    Inject a return value type check into a ``return`` statement.

    The statement's children are visited first. The statement is then replaced
    with ``return check_return_type(<path>, <value>, <annotation>, <memo>)``
    when all of these hold:

    * the current memo has a return annotation
    * instrumentation is enabled
    * the annotation is not an ignored name

    A bare ``return`` is checked as returning ``None``. The replacement node
    inherits the original statement's source location.

    """
    self.generic_visit(node)

    memo = self._memo
    annotation = memo.return_annotation
    if not annotation or not memo.should_instrument:
        return node
    if memo.is_ignored_name(annotation):
        return node

    checker = self._get_import("typeguard._functions", "check_return_type")
    returned_value = node.value or Constant(None)
    check_call = Call(
        checker,
        [memo.joined_path, returned_value, annotation, memo.get_memo_name()],
        [],
    )
    replacement = Return(check_call)
    copy_location(replacement, node)
    return replacement
|
949
|
+
|
950
|
+
def visit_Yield(self, node: Yield) -> Yield | Call:
    """
    Inject type checks into a ``yield`` expression.

    The current memo is first flagged as containing yield expressions. Then:

    * If a yield annotation applies, the yielded value (``None`` for a bare
      ``yield``) is wrapped in ``check_yield_type(...)``.
    * If a send annotation applies, the whole ``yield`` expression is wrapped
      in ``check_send_type(...)`` so that the value sent back into the
      generator is checked. That wrapping call, positioned at the original
      node, is returned instead of the ``Yield`` node.

    An annotation "applies" when it is set, instrumentation is enabled and it
    is not an ignored name.

    """
    self._memo.has_yield_expressions = True
    self.generic_visit(node)
    memo = self._memo

    def applies(annotation: expr | None) -> bool:
        return bool(
            annotation
            and memo.should_instrument
            and not memo.is_ignored_name(annotation)
        )

    if applies(memo.yield_annotation):
        check_yield = self._get_import("typeguard._functions", "check_yield_type")
        yielded = node.value or Constant(None)
        node.value = Call(
            check_yield,
            [memo.joined_path, yielded, memo.yield_annotation, memo.get_memo_name()],
            [],
        )

    if not applies(memo.send_annotation):
        return node

    check_send = self._get_import("typeguard._functions", "check_send_type")
    send_check = Call(
        check_send,
        [memo.joined_path, node, memo.send_annotation, memo.get_memo_name()],
        [],
    )
    copy_location(send_check, node)
    return send_check
|
998
|
+
|
999
|
+
def visit_AnnAssign(self, node: AnnAssign) -> Any:
    """
    Inject a type check into an annotated assignment inside a function body.

    Only statements inside a function whose target is a plain name are
    handled. For those:

    * The target name is added to the memo's ignored names.
    * If the converted annotation is usable, it is recorded in
      ``variable_annotations``, so later plain, walrus and augmented
      assignments to the same name can be checked against it.
    * If the statement also assigns a value, that value is wrapped in
      ``check_variable_assignment()``.

    """
    self.generic_visit(node)

    target = node.target
    in_function = isinstance(self._memo.node, (FunctionDef, AsyncFunctionDef))
    if not (in_function and node.annotation and isinstance(target, Name)):
        return node

    self._memo.ignored_names.add(target.id)
    converted = self._convert_annotation(deepcopy(node.annotation))
    if not converted:
        return node

    self._memo.variable_annotations[target.id] = converted
    if node.value:
        checker = self._get_import(
            "typeguard._functions", "check_variable_assignment"
        )
        node.value = Call(
            checker,
            [node.value, Constant(target.id), converted, self._memo.get_memo_name()],
            [],
        )

    return node
|
1032
|
+
|
1033
|
+
def visit_Assign(self, node: Assign) -> Any:
    """
    Inject a type check into a plain assignment inside a function body.

    Each assignment target must be a single name or a tuple of names, which
    may be starred. Other target kinds are skipped. Every assigned name is
    added to the memo's ignored names.

    The statement is instrumented only if at least one assigned name has an
    annotation recorded earlier in ``variable_annotations``. In that case the
    remaining, unannotated names are checked against ``typing.Any``. A single
    target holding a single name is routed through
    ``check_variable_assignment()``. Every other shape goes through
    ``check_multi_variable_assignment()``.

    """
    self.generic_visit(node)

    # Only instrument function-local assignments
    if not isinstance(self._memo.node, (FunctionDef, AsyncFunctionDef)):
        return node

    known_annotations = self._memo.variable_annotations
    target_maps: list[dict[Constant, expr | None]] = []
    needs_check = False
    for target in node.targets:
        names: Sequence[expr]
        if isinstance(target, Name):
            names = [target]
        elif isinstance(target, Tuple):
            names = target.elts
        else:
            continue

        name_map: dict[Constant, expr | None] = {}
        for element in names:
            star = ""
            if isinstance(element, Starred):
                element, star = element.value, "*"

            if not isinstance(element, Name):
                continue

            self._memo.ignored_names.add(element.id)
            found = known_annotations.get(element.id)
            if found:
                needs_check = True
            else:
                found = None
            name_map[Constant(star + element.id)] = found

        target_maps.append(name_map)

    if not needs_check:
        return node

    # Unannotated names are checked against typing.Any (imported lazily, once
    # per missing entry)
    for name_map in target_maps:
        missing = [key for key, value in name_map.items() if value is None]
        for key in missing:
            name_map[key] = self._get_import("typing", "Any")

    if len(target_maps) == 1 and len(target_maps[0]) == 1:
        checker = self._get_import(
            "typeguard._functions", "check_variable_assignment"
        )
        ((varname_node, annotation),) = target_maps[0].items()
        node.value = Call(
            checker,
            [node.value, varname_node, annotation, self._memo.get_memo_name()],
            [],
        )
    elif target_maps:
        checker = self._get_import(
            "typeguard._functions", "check_multi_variable_assignment"
        )
        targets_arg = List(
            [
                Dict(keys=list(mapping), values=list(mapping.values()))
                for mapping in target_maps
            ],
            ctx=Load(),
        )
        node.value = Call(
            checker, [node.value, targets_arg, self._memo.get_memo_name()], []
        )

    return node
|
1113
|
+
|
1114
|
+
def visit_NamedExpr(self, node: NamedExpr) -> Any:
    """
    Inject a type check into an assignment expression (``a := foo()``).

    Only walrus targets that are plain names inside a function body are
    handled. The name is added to the memo's ignored names. If an annotation
    for it was recorded earlier in ``variable_annotations``, the assigned
    value is wrapped in ``check_variable_assignment()``.

    """
    self.generic_visit(node)

    target = node.target
    # Only instrument function-local assignments
    in_function = isinstance(self._memo.node, (FunctionDef, AsyncFunctionDef))
    if not (in_function and isinstance(target, Name)):
        return node

    self._memo.ignored_names.add(target.id)

    # Nothing to check unless the name was annotated earlier
    annotation = self._memo.variable_annotations.get(target.id)
    if annotation is not None:
        checker = self._get_import(
            "typeguard._functions", "check_variable_assignment"
        )
        node.value = Call(
            checker,
            [node.value, Constant(target.id), annotation, self._memo.get_memo_name()],
            [],
        )

    return node
|
1144
|
+
|
1145
|
+
def visit_AugAssign(self, node: AugAssign) -> Any:
    """
    Inject a type check into an augmented assignment expression (``a += 1``).

    The rewrite happens only when the target is a plain name inside a function
    body and an annotation for that name was recorded earlier in
    ``variable_annotations``. In that case ``a <op>= b`` becomes
    ``a = check_variable_assignment(<opfunc>(a, b), "a", <annotation>, memo)``.
    Here ``<opfunc>`` is the ``operator`` module function named by
    ``aug_assign_functions`` for the statement's operator class.

    The statement is returned untouched if the target is not eligible, if
    there is no recorded annotation, or if the operator class has no entry in
    ``aug_assign_functions``.

    """
    self.generic_visit(node)

    # Only instrument function-local assignments
    if isinstance(self._memo.node, (FunctionDef, AsyncFunctionDef)) and isinstance(
        node.target, Name
    ):
        # Bail out if no matching annotation is found
        annotation = self._memo.variable_annotations.get(node.target.id)
        if annotation is None:
            return node

        # Bail out if the operator is not found (newer Python version?)
        try:
            operator_func_name = aug_assign_functions[node.op.__class__]
        except KeyError:
            return node

        operator_func = self._get_import("operator", operator_func_name)
        operator_call = Call(
            operator_func, [Name(node.target.id, ctx=Load()), node.value], []
        )
        check_call = Call(
            self._get_import("typeguard._functions", "check_variable_assignment"),
            [
                operator_call,
                Constant(node.target.id),
                annotation,
                self._memo.get_memo_name(),
            ],
            [],
        )
        assign = Assign(targets=[node.target], value=check_call)
        # Keep the original statement's position. Without it the synthesized
        # statement (and the calls inside it) would inherit the location of an
        # enclosing node, and a failed check would be reported on the wrong
        # line. This mirrors what visit_Return / visit_Yield do for their
        # replacement nodes.
        return copy_location(assign, node)

    return node
|
1184
|
+
|
1185
|
+
def visit_If(self, node: If) -> Any:
    """
    Block names collected from a module-level ``TYPE_CHECKING`` block.

    Names defined anywhere in the ``if`` statement are added to the memo's
    ignored names, so they won't be type checked. This covers both the
    ``if TYPE_CHECKING:`` form and the ``if typing.TYPE_CHECKING:`` form; both
    are resolved through ``name_matches`` against ``typing.TYPE_CHECKING``.

    If visiting the children left the ``if`` body empty, a ``pass`` is added
    so the node stays valid.

    """
    self.generic_visit(node)

    # Fix empty node body (caused by removal of classes/functions not on the
    # target path)
    if not node.body:
        node.body.append(Pass())

    # name_matches() resolves dotted Attribute chains as well as bare names,
    # so the "typing.TYPE_CHECKING" spelling must not be excluded here.
    if (
        self._memo is self._module_memo
        and isinstance(node.test, (Name, Attribute))
        and self._memo.name_matches(node.test, "typing.TYPE_CHECKING")
    ):
        collector = NameCollector()
        collector.visit(node)
        self._memo.ignored_names.update(collector.names)

    return node
|