metaflow 2.15.5__py2.py3-none-any.whl → 2.15.7__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. metaflow/_vendor/typeguard/_checkers.py +259 -95
  2. metaflow/_vendor/typeguard/_config.py +4 -4
  3. metaflow/_vendor/typeguard/_decorators.py +8 -12
  4. metaflow/_vendor/typeguard/_functions.py +33 -32
  5. metaflow/_vendor/typeguard/_pytest_plugin.py +40 -13
  6. metaflow/_vendor/typeguard/_suppression.py +3 -5
  7. metaflow/_vendor/typeguard/_transformer.py +84 -48
  8. metaflow/_vendor/typeguard/_union_transformer.py +1 -0
  9. metaflow/_vendor/typeguard/_utils.py +13 -9
  10. metaflow/_vendor/typing_extensions.py +1088 -500
  11. metaflow/_vendor/v3_7/__init__.py +1 -0
  12. metaflow/_vendor/v3_7/importlib_metadata/__init__.py +1063 -0
  13. metaflow/_vendor/v3_7/importlib_metadata/_adapters.py +68 -0
  14. metaflow/_vendor/v3_7/importlib_metadata/_collections.py +30 -0
  15. metaflow/_vendor/v3_7/importlib_metadata/_compat.py +71 -0
  16. metaflow/_vendor/v3_7/importlib_metadata/_functools.py +104 -0
  17. metaflow/_vendor/v3_7/importlib_metadata/_itertools.py +73 -0
  18. metaflow/_vendor/v3_7/importlib_metadata/_meta.py +48 -0
  19. metaflow/_vendor/v3_7/importlib_metadata/_text.py +99 -0
  20. metaflow/_vendor/v3_7/importlib_metadata/py.typed +0 -0
  21. metaflow/_vendor/v3_7/typeguard/__init__.py +48 -0
  22. metaflow/_vendor/v3_7/typeguard/_checkers.py +906 -0
  23. metaflow/_vendor/v3_7/typeguard/_config.py +108 -0
  24. metaflow/_vendor/v3_7/typeguard/_decorators.py +237 -0
  25. metaflow/_vendor/v3_7/typeguard/_exceptions.py +42 -0
  26. metaflow/_vendor/v3_7/typeguard/_functions.py +310 -0
  27. metaflow/_vendor/v3_7/typeguard/_importhook.py +213 -0
  28. metaflow/_vendor/v3_7/typeguard/_memo.py +48 -0
  29. metaflow/_vendor/v3_7/typeguard/_pytest_plugin.py +100 -0
  30. metaflow/_vendor/v3_7/typeguard/_suppression.py +88 -0
  31. metaflow/_vendor/v3_7/typeguard/_transformer.py +1207 -0
  32. metaflow/_vendor/v3_7/typeguard/_union_transformer.py +54 -0
  33. metaflow/_vendor/v3_7/typeguard/_utils.py +169 -0
  34. metaflow/_vendor/v3_7/typeguard/py.typed +0 -0
  35. metaflow/_vendor/v3_7/typing_extensions.py +3072 -0
  36. metaflow/_vendor/v3_7/zipp.py +329 -0
  37. metaflow/cmd/develop/stubs.py +1 -1
  38. metaflow/extension_support/__init__.py +1 -1
  39. metaflow/plugins/argo/argo_workflows.py +34 -11
  40. metaflow/plugins/argo/argo_workflows_deployer_objects.py +7 -6
  41. metaflow/plugins/pypi/utils.py +4 -0
  42. metaflow/runner/click_api.py +7 -2
  43. metaflow/vendor.py +1 -0
  44. metaflow/version.py +1 -1
  45. {metaflow-2.15.5.data → metaflow-2.15.7.data}/data/share/metaflow/devtools/Makefile +2 -2
  46. {metaflow-2.15.5.dist-info → metaflow-2.15.7.dist-info}/METADATA +4 -3
  47. {metaflow-2.15.5.dist-info → metaflow-2.15.7.dist-info}/RECORD +53 -27
  48. {metaflow-2.15.5.dist-info → metaflow-2.15.7.dist-info}/WHEEL +1 -1
  49. {metaflow-2.15.5.data → metaflow-2.15.7.data}/data/share/metaflow/devtools/Tiltfile +0 -0
  50. {metaflow-2.15.5.data → metaflow-2.15.7.data}/data/share/metaflow/devtools/pick_services.sh +0 -0
  51. {metaflow-2.15.5.dist-info → metaflow-2.15.7.dist-info}/entry_points.txt +0 -0
  52. {metaflow-2.15.5.dist-info → metaflow-2.15.7.dist-info/licenses}/LICENSE +0 -0
  53. {metaflow-2.15.5.dist-info → metaflow-2.15.7.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,1207 @@
1
+ from __future__ import annotations
2
+
3
+ import ast
4
+ import builtins
5
+ import sys
6
+ import typing
7
+ from ast import (
8
+ AST,
9
+ Add,
10
+ AnnAssign,
11
+ Assign,
12
+ AsyncFunctionDef,
13
+ Attribute,
14
+ AugAssign,
15
+ BinOp,
16
+ BitAnd,
17
+ BitOr,
18
+ BitXor,
19
+ Call,
20
+ ClassDef,
21
+ Constant,
22
+ Dict,
23
+ Div,
24
+ Expr,
25
+ Expression,
26
+ FloorDiv,
27
+ FunctionDef,
28
+ If,
29
+ Import,
30
+ ImportFrom,
31
+ Index,
32
+ List,
33
+ Load,
34
+ LShift,
35
+ MatMult,
36
+ Mod,
37
+ Module,
38
+ Mult,
39
+ Name,
40
+ NodeTransformer,
41
+ NodeVisitor,
42
+ Pass,
43
+ Pow,
44
+ Return,
45
+ RShift,
46
+ Starred,
47
+ Store,
48
+ Str,
49
+ Sub,
50
+ Subscript,
51
+ Tuple,
52
+ Yield,
53
+ YieldFrom,
54
+ alias,
55
+ copy_location,
56
+ expr,
57
+ fix_missing_locations,
58
+ keyword,
59
+ walk,
60
+ )
61
+ from collections import defaultdict
62
+ from collections.abc import Generator, Sequence
63
+ from contextlib import contextmanager
64
+ from copy import deepcopy
65
+ from dataclasses import dataclass, field
66
+ from typing import Any, ClassVar, cast, overload
67
+
68
+ if sys.version_info >= (3, 8):
69
+ from ast import NamedExpr
70
+
71
# Return annotations from which yield/send/return types may be extracted when
# the function is a generator: each ABC name under both "typing" and
# "collections.abc", in that order.
generator_names = tuple(
    f"{namespace}.{abc_name}"
    for abc_name in (
        "Generator",
        "Iterator",
        "Iterable",
        "AsyncIterator",
        "AsyncIterable",
        "AsyncGenerator",
    )
    for namespace in ("typing", "collections.abc")
)

# Spellings of typing.Any (erased from annotations, as it needs no checking)
anytype_names = tuple(f"{module}.Any" for module in ("typing", "typing_extensions"))

# Spellings of Literal, whose arguments are values and must not be rewritten
literal_names = tuple(
    f"{module}.Literal" for module in ("typing", "typing_extensions")
)

# Spellings of Annotated; only its first argument is treated as a type
annotated_names = tuple(
    f"{module}.Annotated" for module in ("typing", "typing_extensions")
)

# Decorators that opt a function out of instrumentation
ignore_decorators = ("typing.no_type_check", "typeguard.typeguard_ignore")

# Augmented-assignment operator node type -> name of the matching in-place
# function of the ``operator`` module (e.g. ``operator.iadd``)
aug_assign_functions = {
    # arithmetic
    Add: "iadd",
    Sub: "isub",
    Mult: "imul",
    MatMult: "imatmul",
    Div: "itruediv",
    FloorDiv: "ifloordiv",
    Mod: "imod",
    Pow: "ipow",
    # shifts
    LShift: "ilshift",
    RShift: "irshift",
    # bitwise
    BitAnd: "iand",
    BitXor: "ixor",
    BitOr: "ior",
}
116
+
117
+
118
@dataclass
class TransformMemo:
    """Bookkeeping for one scope (module, class or function) during instrumentation.

    Memos form a chain through ``parent``; name resolution (``name_matches``,
    ``is_ignored_name``, ``get_unused_name``) walks outward through that chain.
    """

    # AST node of this scope (None for a memo without a node, such as a
    # module-level memo created with TransformMemo(None, None, ()))
    node: Module | ClassDef | FunctionDef | AsyncFunctionDef | None
    # Memo of the enclosing scope, or None at the outermost level
    parent: TransformMemo | None
    # Names of the enclosing classes/functions leading to this node
    path: tuple[str, ...]
    # Dotted path of enclosing class/function names (with "<locals>" after each
    # function scope), wrapped in a Constant for use in generated calls
    joined_path: Constant = field(init=False)
    # Converted annotations for return, yield and send values (None = unchecked)
    return_annotation: expr | None = None
    yield_annotation: expr | None = None
    send_annotation: expr | None = None
    is_async: bool = False
    # Names bound in this scope; consulted to avoid clashes with generated names
    local_names: set[str] = field(init=False, default_factory=set)
    # Local name -> fully qualified name it was imported as
    imported_names: dict[str, str] = field(init=False, default_factory=dict)
    # Names (such as argument names) that must not be interpreted as types
    ignored_names: set[str] = field(init=False, default_factory=set)
    # module -> {imported name -> Name node used to refer to it}; consumed by
    # insert_imports()
    load_names: defaultdict[str, dict[str, Name]] = field(
        init=False, default_factory=lambda: defaultdict(dict)
    )
    has_yield_expressions: bool = field(init=False, default=False)
    has_return_expressions: bool = field(init=False, default=False)
    # Name node referring to the call memo variable, created on first request
    memo_var_name: Name | None = field(init=False, default=None)
    should_instrument: bool = field(init=False, default=True)
    variable_annotations: dict[str, expr] = field(init=False, default_factory=dict)
    configuration_overrides: dict[str, Any] = field(init=False, default_factory=dict)
    # Index into node.body at which generated statements are inserted
    code_inject_index: int = field(init=False, default=0)

    def __post_init__(self) -> None:
        # Build the dotted path (e.g. "outer.<locals>.inner") by walking outward
        # through the enclosing classes and functions
        elements: list[str] = []
        memo = self
        while isinstance(memo.node, (ClassDef, FunctionDef, AsyncFunctionDef)):
            elements.insert(0, memo.node.name)
            if not memo.parent:
                break

            memo = memo.parent
            if isinstance(memo.node, (FunctionDef, AsyncFunctionDef)):
                elements.insert(0, "<locals>")

        self.joined_path = Constant(".".join(elements))

        # Figure out where to insert instrumentation code: after any leading
        # "from __future__" imports and docstring
        if self.node:
            for index, child in enumerate(self.node.body):
                if isinstance(child, ImportFrom) and child.module == "__future__":
                    # (module only) __future__ imports must come first
                    continue
                elif isinstance(child, Expr):
                    if isinstance(child.value, Constant) and isinstance(
                        child.value.value, str
                    ):
                        continue  # docstring
                    elif sys.version_info < (3, 8) and isinstance(child.value, Str):
                        continue  # docstring

                self.code_inject_index = index
                break
            else:
                # The body holds nothing but __future__ imports and/or a docstring.
                # Inject after them; index 0 would push the docstring out of first
                # position (losing __doc__) or precede the __future__ imports.
                self.code_inject_index = len(self.node.body)

    def get_unused_name(self, name: str) -> str:
        """Return *name*, with underscores appended until it clashes with no local
        name of this memo or any enclosing one, and reserve it in this memo."""
        memo: TransformMemo | None = self
        while memo is not None:
            if name in memo.local_names:
                # Clash: lengthen the name and restart the search from this scope
                memo = self
                name += "_"
            else:
                memo = memo.parent

        self.local_names.add(name)
        return name

    def is_ignored_name(self, expression: expr | Expr | None) -> bool:
        """Return True if *expression* is a Name (or an attribute of a Name) whose
        base name is in ``ignored_names`` of this memo or any enclosing one."""
        top_expression = (
            expression.value if isinstance(expression, Expr) else expression
        )

        if isinstance(top_expression, Attribute) and isinstance(
            top_expression.value, Name
        ):
            name = top_expression.value.id
        elif isinstance(top_expression, Name):
            name = top_expression.id
        else:
            return False

        memo: TransformMemo | None = self
        while memo is not None:
            if name in memo.ignored_names:
                return True

            memo = memo.parent

        return False

    def get_memo_name(self) -> Name:
        """Return the Name node for the call memo variable, creating it on first
        use. Its id starts as "memo"; callers may rename it later."""
        if not self.memo_var_name:
            self.memo_var_name = Name(id="memo", ctx=Load())

        return self.memo_var_name

    def get_import(self, module: str, name: str) -> Name:
        """Return a Name node referring to *name* from *module*.

        Reuses a previously requested import, or an existing import of the same
        qualified name in this memo; otherwise reserves an unused alias and
        records it in ``load_names`` for ``insert_imports()``.
        """
        if module in self.load_names and name in self.load_names[module]:
            return self.load_names[module][name]

        qualified_name = f"{module}.{name}"
        if name in self.imported_names and self.imported_names[name] == qualified_name:
            return Name(id=name, ctx=Load())

        alias_name = self.get_unused_name(name)
        node = self.load_names[module][name] = Name(id=alias_name, ctx=Load())
        self.imported_names[name] = qualified_name
        return node

    def insert_imports(self, node: Module | FunctionDef | AsyncFunctionDef) -> None:
        """Insert imports needed by injected code."""
        if not self.load_names:
            return

        # Insert imports after any "from __future__ ..." imports and any docstring
        # (code_inject_index was computed for this position in __post_init__)
        for modulename, names in self.load_names.items():
            aliases = [
                alias(orig_name, new_name.id if orig_name != new_name.id else None)
                for orig_name, new_name in sorted(names.items())
            ]
            node.body.insert(self.code_inject_index, ImportFrom(modulename, aliases, 0))

    def name_matches(self, expression: expr | Expr | None, *names: str) -> bool:
        """Return True if *expression* refers to one of the qualified *names*.

        For a subscript the subscripted value is examined, for a call the callee.
        The leading name is translated through ``imported_names`` (or prefixed
        with "builtins." for builtins); if no match is found here, the enclosing
        memo is asked with its own imports.
        """
        if expression is None:
            return False

        path: list[str] = []
        top_expression = (
            expression.value if isinstance(expression, Expr) else expression
        )

        if isinstance(top_expression, Subscript):
            top_expression = top_expression.value
        elif isinstance(top_expression, Call):
            top_expression = top_expression.func

        while isinstance(top_expression, Attribute):
            path.insert(0, top_expression.attr)
            top_expression = top_expression.value

        if not isinstance(top_expression, Name):
            return False

        if top_expression.id in self.imported_names:
            translated = self.imported_names[top_expression.id]
        elif hasattr(builtins, top_expression.id):
            translated = "builtins." + top_expression.id
        else:
            translated = top_expression.id

        path.insert(0, translated)
        joined_path = ".".join(path)
        if joined_path in names:
            return True
        elif self.parent:
            return self.parent.name_matches(expression, *names)
        else:
            return False

    def get_config_keywords(self) -> list[keyword]:
        """Return keyword nodes for the configuration overrides in effect.

        The enclosing class's overrides (if the parent is a class) are merged
        with this memo's own, which take precedence.
        """
        if self.parent and isinstance(self.parent.node, ClassDef):
            overrides = self.parent.configuration_overrides.copy()
        else:
            overrides = {}

        overrides.update(self.configuration_overrides)
        return [keyword(key, value) for key, value in overrides.items()]
+ return [keyword(key, value) for key, value in overrides.items()]
285
+
286
+
287
class NameCollector(NodeVisitor):
    """Collect names bound by imports, bare-name assignments and walrus targets.

    Nested function and class definitions are not entered, and the children of
    the handled statements are not visited further.
    """

    def __init__(self) -> None:
        self.names: set[str] = set()

    def visit_Import(self, node: Import) -> None:
        # The alias if one is given, otherwise the imported name as written
        self.names.update(entry.asname or entry.name for entry in node.names)

    def visit_ImportFrom(self, node: ImportFrom) -> None:
        self.names.update(entry.asname or entry.name for entry in node.names)

    def visit_Assign(self, node: Assign) -> None:
        # Only bare-name targets; tuple, attribute and subscript targets are skipped
        self.names.update(
            target.id for target in node.targets if isinstance(target, Name)
        )

    def visit_NamedExpr(self, node: NamedExpr) -> Any:
        target = node.target
        if isinstance(target, Name):
            self.names.add(target.id)

    def visit_FunctionDef(self, node: FunctionDef) -> None:
        # Nested scope: its bindings are not names of the visited scope
        return None

    def visit_ClassDef(self, node: ClassDef) -> None:
        # Nested scope: its bindings are not names of the visited scope
        return None
+
314
+
315
class GeneratorDetector(NodeVisitor):
    """Detects if a function node is a generator function.

    After visiting a (possibly async) function definition, ``contains_yields`` is
    True when a ``yield`` or ``yield from`` occurs in its body. Nested functions
    and class bodies are not searched.
    """

    contains_yields: bool = False
    in_root_function: bool = False

    def visit_Yield(self, node: Yield) -> Any:
        self.contains_yields = True

    # "yield from" makes a function a generator just like "yield"
    visit_YieldFrom = visit_Yield

    def visit_ClassDef(self, node: ClassDef) -> Any:
        return None

    def visit_FunctionDef(self, node: FunctionDef | AsyncFunctionDef) -> Any:
        # Only the outermost function is scanned; yields in nested functions
        # belong to those functions
        if self.in_root_function:
            return None

        self.in_root_function = True
        self.generic_visit(node)
        self.in_root_function = False

    def visit_AsyncFunctionDef(self, node: AsyncFunctionDef) -> Any:
        self.visit_FunctionDef(node)
+
339
+
340
class AnnotationTransformer(NodeTransformer):
    """Rewrite an annotation AST so it can be evaluated at runtime by the checks.

    Depending on the running Python version, this converts PEP 604 unions
    (``X | Y``) to ``typing.Union[X, Y]`` and builtin generics to their ``typing``
    aliases, unquotes string forward references, and erases parts of the
    annotation that need no checking (``Any``, ignored names) by returning
    ``None`` for them.
    """

    # Builtin generics and the typing aliases used for them on Python < 3.9
    type_substitutions: ClassVar[dict[str, tuple[str, str]]] = {
        "builtins.dict": ("typing", "Dict"),
        "builtins.list": ("typing", "List"),
        "builtins.tuple": ("typing", "Tuple"),
        "builtins.set": ("typing", "Set"),
        "builtins.frozenset": ("typing", "FrozenSet"),
    }

    def __init__(self, transformer: TypeguardTransformer):
        self.transformer = transformer
        self._memo = transformer._memo
        # Nesting depth of visit() calls; back to 0 when the root call returns
        self._level = 0

    def visit(self, node: AST) -> Any:
        self._level += 1
        new_node = super().visit(node)
        self._level -= 1

        # A parsed string annotation (see visit_Constant) whose body was erased
        if isinstance(new_node, Expression) and not hasattr(new_node, "body"):
            return None

        # Return None if this new node matches a variation of typing.Any. Only
        # done for the outermost node, so a nested Any (Dict[str, Any]) stays.
        if (
            self._level == 0
            and isinstance(new_node, expr)
            and self._memo.name_matches(new_node, *anytype_names)
        ):
            return None

        return new_node

    def generic_visit(self, node: AST) -> AST:
        # Literal[...] arguments are values, not types: leave them untouched
        if isinstance(node, expr) and self._memo.name_matches(node, *literal_names):
            return node

        return super().generic_visit(node)

    def visit_BinOp(self, node: BinOp) -> Any:
        self.generic_visit(node)

        if isinstance(node.op, BitOr):
            # generic_visit() deletes a child attribute when visiting it returned
            # None, i.e. a member of the union was erased as an ignored type. The
            # union cannot be checked reliably then, so erase it entirely rather
            # than fail with AttributeError on node.left / node.right below.
            if not hasattr(node, "left") or not hasattr(node, "right"):
                return None

            # Return Any if either side is Any
            if self._memo.name_matches(node.left, *anytype_names):
                return node.left
            elif self._memo.name_matches(node.right, *anytype_names):
                return node.right

            # PEP 604 unions cannot be evaluated before Python 3.10
            if sys.version_info < (3, 10):
                union_name = self.transformer._get_import("typing", "Union")
                return Subscript(
                    value=union_name,
                    slice=Index(
                        Tuple(elts=[node.left, node.right], ctx=Load()), ctx=Load()
                    ),
                    ctx=Load(),
                )

        return node

    def visit_Attribute(self, node: Attribute) -> Any:
        if self._memo.is_ignored_name(node):
            return None

        return node

    def visit_Subscript(self, node: Subscript) -> Any:
        if self._memo.is_ignored_name(node.value):
            return None

        # The subscript of typing(_extensions).Literal can be any arbitrary string, so
        # don't try to evaluate it as code
        if node.slice:
            if isinstance(node.slice, Index):
                # Python 3.7 and 3.8
                slice_value = node.slice.value  # type: ignore[attr-defined]
            else:
                slice_value = node.slice

            if isinstance(slice_value, Tuple):
                # NOTE(review): items are processed with generic_visit() (which,
                # among other things, keeps Literal["a", "b"] members from being
                # parsed as code). generic_visit() returns the node it is given,
                # so the "item is None" handling below appears unreachable.
                # Confirm before relying on it.
                if self._memo.name_matches(node.value, *annotated_names):
                    # Only treat the first argument to typing.Annotated as a potential
                    # forward reference
                    items = cast(
                        typing.List[expr],
                        [self.generic_visit(slice_value.elts[0])]
                        + slice_value.elts[1:],
                    )
                else:
                    items = cast(
                        typing.List[expr],
                        [self.generic_visit(item) for item in slice_value.elts],
                    )

                # If this is a Union and any of the items is Any, erase the entire
                # annotation
                if self._memo.name_matches(node.value, "typing.Union") and any(
                    isinstance(item, expr)
                    and self._memo.name_matches(item, *anytype_names)
                    for item in items
                ):
                    return None

                # If all items in the subscript were Any, erase the subscript entirely
                if all(item is None for item in items):
                    return node.value

                for index, item in enumerate(items):
                    if item is None:
                        items[index] = self.transformer._get_import("typing", "Any")

                slice_value.elts = items
            else:
                self.generic_visit(node)

                # Determine whether the transformer erased the slice entirely.
                # generic_visit() deletes an erased slice; on Python < 3.9 the
                # slice is an Index whose "value" gets deleted instead.
                if sys.version_info >= (3, 9):
                    slice_erased = not hasattr(node, "slice")
                else:
                    slice_erased = not hasattr(node.slice, "value")

                # If so, return the node value without the subscript, unless it is
                # Optional, in which case erase the node entirely. A non-erased
                # Optional[X] is kept so that it is still type checked.
                if slice_erased:
                    if self._memo.name_matches(node.value, "typing.Optional"):
                        return None

                    return node.value

        return node

    def visit_Name(self, node: Name) -> Any:
        if self._memo.is_ignored_name(node):
            return None

        # Builtin generics are not subscriptable before Python 3.9
        if sys.version_info < (3, 9):
            for typename, substitute in self.type_substitutions.items():
                if self._memo.name_matches(node, typename):
                    new_node = self.transformer._get_import(*substitute)
                    return copy_location(new_node, node)

        return node

    def visit_Call(self, node: Call) -> Any:
        # Don't recurse into calls
        return node

    def visit_Constant(self, node: Constant) -> Any:
        # A string is a forward reference: parse it and transform the result
        if isinstance(node.value, str):
            expression = ast.parse(node.value, mode="eval")
            new_node = self.visit(expression)
            if new_node:
                return copy_location(new_node.body, node)
            else:
                return None

        return node

    def visit_Str(self, node: Str) -> Any:
        # Only used on Python 3.7
        expression = ast.parse(node.s, mode="eval")
        new_node = self.visit(expression)
        if new_node:
            return copy_location(new_node.body, node)
        else:
            return None
+
503
+
504
+ class TypeguardTransformer(NodeTransformer):
505
def __init__(
    self, target_path: Sequence[str] | None = None, target_lineno: int | None = None
) -> None:
    """Prepare a transformer for a module, optionally targeting one function.

    :param target_path: names leading to the single function to instrument
        (compared against each memo's ``path``); ``None`` or empty instruments
        the whole module
    :param target_lineno: line number of the target function's first line
        (its first decorator, if it has any)
    """
    self._target_path = None if not target_path else tuple(target_path)
    self._module_memo = TransformMemo(None, None, ())
    # The current memo starts out as the module-level memo
    self._memo = self._module_memo
    self.names_used_in_annotations: set[str] = set()
    self.target_node: FunctionDef | AsyncFunctionDef | None = None
    self.target_lineno = target_lineno
513
+
514
@contextmanager
def _use_memo(
    self, node: ClassDef | FunctionDef | AsyncFunctionDef
) -> Generator[None, Any, None]:
    """Make a new TransformMemo for *node* the current memo inside the ``with`` block.

    For a function that is to be instrumented, this also converts its return
    annotation. For generator functions annotated with a generator/iterator type,
    it extracts the yield, send and return annotations from the subscript.

    The new memo becomes current *before* the annotations are converted. Imports
    requested during conversion in targeted mode (see ``_get_import``) are then
    registered on the function's own memo, whose imports get inserted into the
    function. Previously they went to the enclosing (e.g. class) memo. The previous
    memo is restored even if the body of the ``with`` block raises.
    """
    new_memo = TransformMemo(node, self._memo, self._memo.path + (node.name,))
    old_memo = self._memo
    self._memo = new_memo
    try:
        if isinstance(node, (FunctionDef, AsyncFunctionDef)):
            new_memo.should_instrument = (
                self._target_path is None or new_memo.path == self._target_path
            )
            if new_memo.should_instrument:
                # Check if the function is a generator function
                detector = GeneratorDetector()
                detector.visit(node)

                # Extract yield, send and return types where possible from a
                # subscripted annotation like Generator[int, str, bool]
                return_annotation = deepcopy(node.returns)
                if detector.contains_yields and new_memo.name_matches(
                    return_annotation, *generator_names
                ):
                    if isinstance(return_annotation, Subscript):
                        annotation_slice = return_annotation.slice

                        # Python < 3.9
                        if isinstance(annotation_slice, Index):
                            annotation_slice = (
                                annotation_slice.value  # type: ignore[attr-defined]
                            )

                        if isinstance(annotation_slice, Tuple):
                            items = annotation_slice.elts
                        else:
                            items = [annotation_slice]

                        if len(items) > 0:
                            new_memo.yield_annotation = self._convert_annotation(
                                items[0]
                            )

                        if len(items) > 1:
                            new_memo.send_annotation = self._convert_annotation(
                                items[1]
                            )

                        if len(items) > 2:
                            new_memo.return_annotation = self._convert_annotation(
                                items[2]
                            )
                else:
                    new_memo.return_annotation = self._convert_annotation(
                        return_annotation
                    )

        if isinstance(node, AsyncFunctionDef):
            new_memo.is_async = True

        yield
    finally:
        self._memo = old_memo
574
+
575
+ def _get_import(self, module: str, name: str) -> Name:
576
+ memo = self._memo if self._target_path else self._module_memo
577
+ return memo.get_import(module, name)
578
+
579
@overload
def _convert_annotation(self, annotation: None) -> None:
    ...

@overload
def _convert_annotation(self, annotation: expr) -> expr:
    ...

def _convert_annotation(self, annotation: expr | None) -> expr | None:
    """Return a runtime-evaluable version of *annotation*, or None if it is erased.

    PEP 604 unions and generic built-in collections are converted where
    necessary, and forward references are undone, by AnnotationTransformer
    (which may modify the given node in place). Every name appearing in a
    non-erased result is added to ``names_used_in_annotations``.
    """
    if annotation is None:
        return None

    converted = cast(expr, AnnotationTransformer(self).visit(annotation))
    if not isinstance(converted, expr):
        # Erased (e.g. Any): nothing to locate or to collect names from
        return converted

    converted = ast.copy_location(converted, annotation)
    self.names_used_in_annotations.update(
        child.id for child in walk(converted) if isinstance(child, Name)
    )
    return converted
602
+
603
def visit_Name(self, node: Name) -> Name:
    """Record the name in the current memo's ``local_names`` (so generated names
    avoid it) and return the node unchanged."""
    local_names = self._memo.local_names
    local_names.add(node.id)
    return node
606
+
607
def visit_Module(self, node: Module) -> Module:
    """Instrument the module body, then insert the imports the injected code needs
    and fill in missing source locations."""
    self.generic_visit(node)
    module_memo = self._memo
    module_memo.insert_imports(node)
    fix_missing_locations(node)
    return node
613
+
614
def visit_Import(self, node: Import) -> Import:
    """Register each name bound by ``import ...`` as a local name and remember
    the module it refers to."""
    memo = self._memo
    for entry in node.names:
        bound = entry.asname or entry.name
        memo.local_names.add(bound)
        memo.imported_names[bound] = entry.name

    return node
620
+
621
def visit_ImportFrom(self, node: ImportFrom) -> ImportFrom:
    """Register each name bound by ``from ... import ...`` with its qualified name;
    star imports are skipped."""
    memo = self._memo
    for entry in (item for item in node.names if item.name != "*"):
        bound = entry.asname or entry.name
        memo.local_names.add(bound)
        memo.imported_names[bound] = f"{node.module}.{entry.name}"

    return node
629
+
630
def visit_ClassDef(self, node: ClassDef) -> ClassDef | None:
    """Process a class definition.

    Top-level classes outside the target path are dropped (None is returned).
    Otherwise, ``@typechecked`` decorators are removed and their keyword
    arguments recorded as configuration overrides of the class memo, and the
    class body is visited with that memo current.
    """
    self._memo.local_names.add(node.name)

    # Eliminate top level classes not belonging to the target path
    if self._target_path is not None:
        if not self._memo.path and node.name != self._target_path[0]:
            return None

    with self._use_memo(node):
        class_memo = self._memo
        for decorator in list(node.decorator_list):
            if not class_memo.name_matches(decorator, "typeguard.typechecked"):
                continue

            # Remove the decorator to prevent duplicate instrumentation
            node.decorator_list.remove(decorator)

            # Store any configuration overrides
            if isinstance(decorator, Call) and decorator.keywords:
                class_memo.configuration_overrides.update(
                    {kw.arg: kw.value for kw in decorator.keywords if kw.arg}
                )

        self.generic_visit(node)

    return node
655
+
656
+ def visit_FunctionDef(
657
+ self, node: FunctionDef | AsyncFunctionDef
658
+ ) -> FunctionDef | AsyncFunctionDef | None:
659
+ """
660
+ Injects type checks for function arguments, and for a return of None if the
661
+ function is annotated to return something else than Any or None, and the body
662
+ ends without an explicit "return".
663
+
664
+ """
665
+ self._memo.local_names.add(node.name)
666
+
667
+ # Eliminate top level functions not belonging to the target path
668
+ if (
669
+ self._target_path is not None
670
+ and not self._memo.path
671
+ and node.name != self._target_path[0]
672
+ ):
673
+ return None
674
+
675
+ # Skip instrumentation if we're instrumenting the whole module and the function
676
+ # contains either @no_type_check or @typeguard_ignore
677
+ if self._target_path is None:
678
+ for decorator in node.decorator_list:
679
+ if self._memo.name_matches(decorator, *ignore_decorators):
680
+ return node
681
+
682
+ with self._use_memo(node):
683
+ arg_annotations: dict[str, Any] = {}
684
+ if self._target_path is None or self._memo.path == self._target_path:
685
+ # Find line number we're supposed to match against
686
+ if node.decorator_list:
687
+ first_lineno = node.decorator_list[0].lineno
688
+ else:
689
+ first_lineno = node.lineno
690
+
691
+ for decorator in node.decorator_list.copy():
692
+ if self._memo.name_matches(decorator, "typing.overload"):
693
+ # Remove overloads entirely
694
+ return None
695
+ elif self._memo.name_matches(decorator, "typeguard.typechecked"):
696
+ # Remove the decorator to prevent duplicate instrumentation
697
+ node.decorator_list.remove(decorator)
698
+
699
+ # Store any configuration overrides
700
+ if isinstance(decorator, Call) and decorator.keywords:
701
+ self._memo.configuration_overrides = {
702
+ kw.arg: kw.value for kw in decorator.keywords if kw.arg
703
+ }
704
+
705
+ if self.target_lineno == first_lineno:
706
+ assert self.target_node is None
707
+ self.target_node = node
708
+ if node.decorator_list and sys.version_info >= (3, 8):
709
+ self.target_lineno = node.decorator_list[0].lineno
710
+ else:
711
+ self.target_lineno = node.lineno
712
+
713
+ all_args = node.args.args + node.args.kwonlyargs
714
+ if sys.version_info >= (3, 8):
715
+ all_args.extend(node.args.posonlyargs)
716
+
717
+ # Ensure that any type shadowed by the positional or keyword-only
718
+ # argument names are ignored in this function
719
+ for arg in all_args:
720
+ self._memo.ignored_names.add(arg.arg)
721
+
722
+ # Ensure that any type shadowed by the variable positional argument name
723
+ # (e.g. "args" in *args) is ignored this function
724
+ if node.args.vararg:
725
+ self._memo.ignored_names.add(node.args.vararg.arg)
726
+
727
+ # Ensure that any type shadowed by the variable keywrod argument name
728
+ # (e.g. "kwargs" in *kwargs) is ignored this function
729
+ if node.args.kwarg:
730
+ self._memo.ignored_names.add(node.args.kwarg.arg)
731
+
732
+ for arg in all_args:
733
+ annotation = self._convert_annotation(deepcopy(arg.annotation))
734
+ if annotation:
735
+ arg_annotations[arg.arg] = annotation
736
+
737
+ if node.args.vararg:
738
+ annotation_ = self._convert_annotation(node.args.vararg.annotation)
739
+ if annotation_:
740
+ if sys.version_info >= (3, 9):
741
+ container = Name("tuple", ctx=Load())
742
+ else:
743
+ container = self._get_import("typing", "Tuple")
744
+
745
+ subscript_slice: Tuple | Index = Tuple(
746
+ [
747
+ annotation_,
748
+ Constant(Ellipsis),
749
+ ],
750
+ ctx=Load(),
751
+ )
752
+ if sys.version_info < (3, 9):
753
+ subscript_slice = Index(subscript_slice, ctx=Load())
754
+
755
+ arg_annotations[node.args.vararg.arg] = Subscript(
756
+ container, subscript_slice, ctx=Load()
757
+ )
758
+
759
+ if node.args.kwarg:
760
+ annotation_ = self._convert_annotation(node.args.kwarg.annotation)
761
+ if annotation_:
762
+ if sys.version_info >= (3, 9):
763
+ container = Name("dict", ctx=Load())
764
+ else:
765
+ container = self._get_import("typing", "Dict")
766
+
767
+ subscript_slice = Tuple(
768
+ [
769
+ Name("str", ctx=Load()),
770
+ annotation_,
771
+ ],
772
+ ctx=Load(),
773
+ )
774
+ if sys.version_info < (3, 9):
775
+ subscript_slice = Index(subscript_slice, ctx=Load())
776
+
777
+ arg_annotations[node.args.kwarg.arg] = Subscript(
778
+ container, subscript_slice, ctx=Load()
779
+ )
780
+
781
+ if arg_annotations:
782
+ self._memo.variable_annotations.update(arg_annotations)
783
+
784
+ self.generic_visit(node)
785
+
786
+ if arg_annotations:
787
+ annotations_dict = Dict(
788
+ keys=[Constant(key) for key in arg_annotations.keys()],
789
+ values=[
790
+ Tuple([Name(key, ctx=Load()), annotation], ctx=Load())
791
+ for key, annotation in arg_annotations.items()
792
+ ],
793
+ )
794
+ func_name = self._get_import(
795
+ "typeguard._functions", "check_argument_types"
796
+ )
797
+ args = [
798
+ self._memo.joined_path,
799
+ annotations_dict,
800
+ self._memo.get_memo_name(),
801
+ ]
802
+ node.body.insert(
803
+ self._memo.code_inject_index, Expr(Call(func_name, args, []))
804
+ )
805
+
806
+ # Add a checked "return None" to the end if there's no explicit return
807
+ # Skip if the return annotation is None or Any
808
+ if (
809
+ self._memo.return_annotation
810
+ and (not self._memo.is_async or not self._memo.has_yield_expressions)
811
+ and not isinstance(node.body[-1], Return)
812
+ and (
813
+ not isinstance(self._memo.return_annotation, Constant)
814
+ or self._memo.return_annotation.value is not None
815
+ )
816
+ ):
817
+ func_name = self._get_import(
818
+ "typeguard._functions", "check_return_type"
819
+ )
820
+ return_node = Return(
821
+ Call(
822
+ func_name,
823
+ [
824
+ self._memo.joined_path,
825
+ Constant(None),
826
+ self._memo.return_annotation,
827
+ self._memo.get_memo_name(),
828
+ ],
829
+ [],
830
+ )
831
+ )
832
+
833
+ # Replace a placeholder "pass" at the end
834
+ if isinstance(node.body[-1], Pass):
835
+ copy_location(return_node, node.body[-1])
836
+ del node.body[-1]
837
+
838
+ node.body.append(return_node)
839
+
840
+ # Insert code to create the call memo, if it was ever needed for this
841
+ # function
842
+ if self._memo.memo_var_name:
843
+ memo_kwargs: dict[str, Any] = {}
844
+ if self._memo.parent and isinstance(self._memo.parent.node, ClassDef):
845
+ for decorator in node.decorator_list:
846
+ if (
847
+ isinstance(decorator, Name)
848
+ and decorator.id == "staticmethod"
849
+ ):
850
+ break
851
+ elif (
852
+ isinstance(decorator, Name)
853
+ and decorator.id == "classmethod"
854
+ ):
855
+ memo_kwargs["self_type"] = Name(
856
+ id=node.args.args[0].arg, ctx=Load()
857
+ )
858
+ break
859
+ else:
860
+ if node.args.args:
861
+ if node.name == "__new__":
862
+ memo_kwargs["self_type"] = Name(
863
+ id=node.args.args[0].arg, ctx=Load()
864
+ )
865
+ else:
866
+ memo_kwargs["self_type"] = Attribute(
867
+ Name(id=node.args.args[0].arg, ctx=Load()),
868
+ "__class__",
869
+ ctx=Load(),
870
+ )
871
+
872
+ # Construct the function reference
873
+ # Nested functions get special treatment: the function name is added
874
+ # to free variables (and the closure of the resulting function)
875
+ names: list[str] = [node.name]
876
+ memo = self._memo.parent
877
+ while memo:
878
+ if isinstance(memo.node, (FunctionDef, AsyncFunctionDef)):
879
+ # This is a nested function. Use the function name as-is.
880
+ del names[:-1]
881
+ break
882
+ elif not isinstance(memo.node, ClassDef):
883
+ break
884
+
885
+ names.insert(0, memo.node.name)
886
+ memo = memo.parent
887
+
888
+ config_keywords = self._memo.get_config_keywords()
889
+ if config_keywords:
890
+ memo_kwargs["config"] = Call(
891
+ self._get_import("dataclasses", "replace"),
892
+ [self._get_import("typeguard._config", "global_config")],
893
+ config_keywords,
894
+ )
895
+
896
+ self._memo.memo_var_name.id = self._memo.get_unused_name("memo")
897
+ memo_store_name = Name(id=self._memo.memo_var_name.id, ctx=Store())
898
+ globals_call = Call(Name(id="globals", ctx=Load()), [], [])
899
+ locals_call = Call(Name(id="locals", ctx=Load()), [], [])
900
+ memo_expr = Call(
901
+ self._get_import("typeguard", "TypeCheckMemo"),
902
+ [globals_call, locals_call],
903
+ [keyword(key, value) for key, value in memo_kwargs.items()],
904
+ )
905
+ node.body.insert(
906
+ self._memo.code_inject_index,
907
+ Assign([memo_store_name], memo_expr),
908
+ )
909
+
910
+ self._memo.insert_imports(node)
911
+
912
+ # Remove any placeholder "pass" at the end
913
+ if isinstance(node.body[-1], Pass):
914
+ del node.body[-1]
915
+
916
+ return node
917
+
918
def visit_AsyncFunctionDef(
    self, node: AsyncFunctionDef
) -> FunctionDef | AsyncFunctionDef | None:
    """
    Instrument an ``async def`` the same way as a regular ``def``.

    No coroutine-specific handling happens here: the node is handed to
    ``visit_FunctionDef`` and whatever that returns is passed straight back.

    """
    delegate = self.visit_FunctionDef
    return delegate(node)
922
+
923
def visit_Return(self, node: Return) -> Return:
    """
    Wrap the value of a ``return`` statement in a ``check_return_type()`` call.

    Child nodes are visited first. The statement is only rewritten when the
    current memo has a return annotation, instrumentation is enabled, and the
    annotation is not an ignored name; otherwise the node comes back unchanged.
    A bare ``return`` is treated as returning ``Constant(None)``. The new
    ``Return`` node takes over the source location of the original.

    """
    self.generic_visit(node)

    memo = self._memo
    annotation = memo.return_annotation
    if not annotation or not memo.should_instrument or memo.is_ignored_name(annotation):
        return node

    checker = self._get_import("typeguard._functions", "check_return_type")
    returned_value = node.value or Constant(None)
    check_call = Call(
        checker,
        [memo.joined_path, returned_value, annotation, memo.get_memo_name()],
        [],
    )
    replacement = Return(check_call)
    copy_location(replacement, node)
    return replacement
949
+
950
def visit_Yield(self, node: Yield) -> Yield | Call:
    """
    Add type checks around a ``yield`` expression.

    The current memo is first flagged as containing yield expressions and the
    child nodes are visited. Then:

    * with a yield annotation, the yielded value (``Constant(None)`` for a bare
      ``yield``) is replaced in place by a ``check_yield_type()`` call;
    * with a send annotation, the whole ``yield`` expression, whose result is
      the value sent back into the generator, is wrapped in a
      ``check_send_type()`` call. That call, carrying the original location, is
      returned instead of the node.

    Each check is added only while instrumentation is enabled and its
    annotation is not an ignored name.

    """
    self._memo.has_yield_expressions = True
    self.generic_visit(node)

    memo = self._memo

    def _should_check(annotation) -> bool:
        if not annotation or not memo.should_instrument:
            return False
        return not memo.is_ignored_name(annotation)

    yield_annotation = memo.yield_annotation
    if _should_check(yield_annotation):
        yield_checker = self._get_import("typeguard._functions", "check_yield_type")
        yielded = node.value or Constant(None)
        node.value = Call(
            yield_checker,
            [memo.joined_path, yielded, yield_annotation, memo.get_memo_name()],
            [],
        )

    send_annotation = memo.send_annotation
    if not _should_check(send_annotation):
        return node

    send_checker = self._get_import("typeguard._functions", "check_send_type")
    wrapper = Call(
        send_checker,
        [memo.joined_path, node, send_annotation, memo.get_memo_name()],
        [],
    )
    copy_location(wrapper, node)
    return wrapper
998
+
999
def visit_AnnAssign(self, node: AnnAssign) -> Any:
    """
    Record and check a function-local annotated assignment (``name: T = value``).

    Only plain name targets inside a function body are handled. The name is
    added to the memo's ignored names. A deep copy of the annotation is then
    converted; if the conversion yields something truthy, it is stored as the
    variable's annotation, and any assigned value is wrapped in a
    ``check_variable_assignment()`` call. The node is always returned.

    """
    self.generic_visit(node)

    memo = self._memo
    target = node.target
    in_function = isinstance(memo.node, (FunctionDef, AsyncFunctionDef))
    if not (in_function and node.annotation and isinstance(target, Name)):
        return node

    memo.ignored_names.add(target.id)
    converted = self._convert_annotation(deepcopy(node.annotation))
    if not converted:
        return node

    memo.variable_annotations[target.id] = converted
    if node.value:
        checker = self._get_import("typeguard._functions", "check_variable_assignment")
        node.value = Call(
            checker,
            [node.value, Constant(target.id), converted, memo.get_memo_name()],
            [],
        )

    return node
1032
+
1033
def visit_Assign(self, node: Assign) -> Any:
    """
    Type check a function-local assignment to previously annotated variables.

    Only ``Name`` and ``Tuple`` targets are inspected. Inside a tuple, a starred
    element is recorded under its name prefixed with ``*``, and elements that
    are not plain names are skipped. Every assigned name is added to the
    memo's ignored names. When at least one assigned name has a recorded
    annotation, names without one are mapped to ``typing.Any``, and the value
    is wrapped in:

    * ``check_variable_assignment()`` when there is exactly one target mapping
      holding exactly one name, or
    * ``check_multi_variable_assignment()`` with a list of per-target
      ``{name: annotation}`` dicts otherwise.

    """
    self.generic_visit(node)

    memo = self._memo
    # Only instrument function-local assignments
    if not isinstance(memo.node, (FunctionDef, AsyncFunctionDef)):
        return node

    targets: list[dict[Constant, expr | None]] = []
    check_required = False
    for target in node.targets:
        elements: Sequence[expr]
        if isinstance(target, Name):
            elements = [target]
        elif isinstance(target, Tuple):
            elements = target.elts
        else:
            continue

        # NOTE(review): non-Name elements (attributes, subscripts, nested
        # tuples) leave no entry here, so later names may not line up
        # positionally with the unpacked values - verify against the runtime
        # check_*_assignment helpers.
        names: dict[Constant, expr | None] = {}
        for element in elements:
            star = ""
            if isinstance(element, Starred):
                element, star = element.value, "*"
            if not isinstance(element, Name):
                continue

            memo.ignored_names.add(element.id)
            recorded = memo.variable_annotations.get(element.id)
            names[Constant(star + element.id)] = recorded if recorded else None
            if recorded:
                check_required = True

        targets.append(names)

    if not check_required:
        return node

    # Names without a recorded annotation are checked against typing.Any
    for names in targets:
        for key, recorded in names.items():
            if recorded is None:
                names[key] = self._get_import("typing", "Any")

    if len(targets) == 1 and len(targets[0]) == 1:
        # NOTE(review): this path is also taken for ``(x,) = value`` and for
        # ``x, obj.attr = value``; the whole value is then passed for the single
        # name - confirm check_variable_assignment copes with that.
        checker = self._get_import("typeguard._functions", "check_variable_assignment")
        (only_target,) = targets
        ((varname, annotation),) = only_target.items()
        node.value = Call(
            checker,
            [node.value, varname, annotation, memo.get_memo_name()],
            [],
        )
    elif targets:
        checker = self._get_import(
            "typeguard._functions", "check_multi_variable_assignment"
        )
        mappings = [
            Dict(keys=list(names), values=list(names.values())) for names in targets
        ]
        node.value = Call(
            checker,
            [node.value, List(mappings, ctx=Load()), memo.get_memo_name()],
            [],
        )

    return node
1113
+
1114
def visit_NamedExpr(self, node: NamedExpr) -> Any:
    """
    Type check the value bound by an assignment expression such as ``(a := foo())``.

    Only plain-name targets inside a function body are handled. The name is
    added to the memo's ignored names. If an annotation was recorded earlier
    for it, the bound value is wrapped in a ``check_variable_assignment()``
    call. The node is always returned.

    """
    self.generic_visit(node)

    memo = self._memo
    target = node.target
    in_function = isinstance(memo.node, (FunctionDef, AsyncFunctionDef))
    if not (in_function and isinstance(target, Name)):
        return node

    memo.ignored_names.add(target.id)

    # Nothing to check unless the variable was annotated earlier
    recorded = memo.variable_annotations.get(target.id)
    if recorded is None:
        return node

    checker = self._get_import("typeguard._functions", "check_variable_assignment")
    node.value = Call(
        checker,
        [node.value, Constant(target.id), recorded, memo.get_memo_name()],
        [],
    )
    return node
1144
+
1145
def visit_AugAssign(self, node: AugAssign) -> Any:
    """
    This injects a type check into an augmented assignment expression (a += 1).

    Inside a function body, ``name <op>= value``, where ``name`` has a recorded
    variable annotation, is replaced by a plain assignment of the form::

        name = check_variable_assignment(
            operator.<func>(name, value), "name", <annotation>, <memo>
        )

    ``<func>`` is the function name that ``aug_assign_functions`` maps the
    operator's class to. The original node is returned unchanged when the
    target is not a plain name, no annotation was recorded for it, or the
    operator class has no entry in ``aug_assign_functions``.

    """
    self.generic_visit(node)

    # Only instrument function-local assignments
    if isinstance(self._memo.node, (FunctionDef, AsyncFunctionDef)) and isinstance(
        node.target, Name
    ):
        # Bail out if no matching annotation is found
        annotation = self._memo.variable_annotations.get(node.target.id)
        if annotation is None:
            return node

        # Bail out if the operator is not found (newer Python version?)
        try:
            operator_func_name = aug_assign_functions[node.op.__class__]
        except KeyError:
            return node

        operator_func = self._get_import("operator", operator_func_name)
        operator_call = Call(
            operator_func, [Name(node.target.id, ctx=Load()), node.value], []
        )
        check_call = Call(
            self._get_import("typeguard._functions", "check_variable_assignment"),
            [
                operator_call,
                Constant(node.target.id),
                annotation,
                self._memo.get_memo_name(),
            ],
            [],
        )
        replacement = Assign(targets=[node.target], value=check_call)
        # The replacement statement is synthesized and has no source position
        # of its own. Copy the augmented assignment's position onto it, as
        # visit_Return and visit_Yield do for their replacements, so the
        # generated code keeps the original line/column information.
        return copy_location(replacement, node)

    return node
1184
+
1185
def visit_If(self, node: If) -> Any:
    """
    Keep names bound in a module-level ``if TYPE_CHECKING:`` block out of type checks.

    After the children are visited, an ``if`` whose body ended up empty gets a
    ``pass`` placeholder. When this ``if`` sits in the module memo and its
    test is a plain name that the memo resolves to ``typing.TYPE_CHECKING``,
    every name collected from the block is added to the memo's ignored names.

    NOTE(review): because of the ``isinstance(node.test, Name)`` guard, the
    attribute spelling ``if typing.TYPE_CHECKING:`` is not treated this way -
    confirm whether that is intended.

    """
    self.generic_visit(node)

    # Removing classes/functions that are not on the target path can leave the
    # body empty, which is not valid Python, so put a placeholder back.
    if not node.body:
        node.body.append(Pass())

    at_module_level = self._memo is self._module_memo
    if not at_module_level or not isinstance(node.test, Name):
        return node
    if not self._memo.name_matches(node.test, "typing.TYPE_CHECKING"):
        return node

    collector = NameCollector()
    collector.visit(node)
    self._memo.ignored_names.update(collector.names)
    return node