metaflow 2.11.16__py2.py3-none-any.whl → 2.12.1__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49):
  1. metaflow/__init__.py +5 -0
  2. metaflow/_vendor/importlib_metadata/__init__.py +1063 -0
  3. metaflow/_vendor/importlib_metadata/_adapters.py +68 -0
  4. metaflow/_vendor/importlib_metadata/_collections.py +30 -0
  5. metaflow/_vendor/importlib_metadata/_compat.py +71 -0
  6. metaflow/_vendor/importlib_metadata/_functools.py +104 -0
  7. metaflow/_vendor/importlib_metadata/_itertools.py +73 -0
  8. metaflow/_vendor/importlib_metadata/_meta.py +48 -0
  9. metaflow/_vendor/importlib_metadata/_text.py +99 -0
  10. metaflow/_vendor/importlib_metadata/py.typed +0 -0
  11. metaflow/_vendor/typeguard/__init__.py +48 -0
  12. metaflow/_vendor/typeguard/_checkers.py +906 -0
  13. metaflow/_vendor/typeguard/_config.py +108 -0
  14. metaflow/_vendor/typeguard/_decorators.py +237 -0
  15. metaflow/_vendor/typeguard/_exceptions.py +42 -0
  16. metaflow/_vendor/typeguard/_functions.py +307 -0
  17. metaflow/_vendor/typeguard/_importhook.py +213 -0
  18. metaflow/_vendor/typeguard/_memo.py +48 -0
  19. metaflow/_vendor/typeguard/_pytest_plugin.py +100 -0
  20. metaflow/_vendor/typeguard/_suppression.py +88 -0
  21. metaflow/_vendor/typeguard/_transformer.py +1193 -0
  22. metaflow/_vendor/typeguard/_union_transformer.py +54 -0
  23. metaflow/_vendor/typeguard/_utils.py +169 -0
  24. metaflow/_vendor/typeguard/py.typed +0 -0
  25. metaflow/_vendor/typing_extensions.py +3053 -0
  26. metaflow/cli.py +100 -43
  27. metaflow/cmd/develop/stubs.py +2 -0
  28. metaflow/decorators.py +16 -3
  29. metaflow/extension_support/__init__.py +2 -0
  30. metaflow/metaflow_config.py +21 -0
  31. metaflow/parameters.py +1 -0
  32. metaflow/plugins/argo/argo_workflows.py +10 -5
  33. metaflow/plugins/aws/batch/batch_decorator.py +3 -3
  34. metaflow/plugins/kubernetes/kubernetes_job.py +0 -5
  35. metaflow/runner/__init__.py +0 -0
  36. metaflow/runner/click_api.py +406 -0
  37. metaflow/runner/metaflow_runner.py +452 -0
  38. metaflow/runner/nbrun.py +246 -0
  39. metaflow/runner/subprocess_manager.py +552 -0
  40. metaflow/vendor.py +0 -1
  41. metaflow/version.py +1 -1
  42. {metaflow-2.11.16.dist-info → metaflow-2.12.1.dist-info}/METADATA +2 -2
  43. {metaflow-2.11.16.dist-info → metaflow-2.12.1.dist-info}/RECORD +48 -20
  44. metaflow/_vendor/v3_7/__init__.py +0 -1
  45. /metaflow/_vendor/{v3_7/zipp.py → zipp.py} +0 -0
  46. {metaflow-2.11.16.dist-info → metaflow-2.12.1.dist-info}/LICENSE +0 -0
  47. {metaflow-2.11.16.dist-info → metaflow-2.12.1.dist-info}/WHEEL +0 -0
  48. {metaflow-2.11.16.dist-info → metaflow-2.12.1.dist-info}/entry_points.txt +0 -0
  49. {metaflow-2.11.16.dist-info → metaflow-2.12.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,1193 @@
1
+ from __future__ import annotations
2
+
3
+ import ast
4
+ import builtins
5
+ import sys
6
+ import typing
7
+ from ast import (
8
+ AST,
9
+ Add,
10
+ AnnAssign,
11
+ Assign,
12
+ AsyncFunctionDef,
13
+ Attribute,
14
+ AugAssign,
15
+ BinOp,
16
+ BitAnd,
17
+ BitOr,
18
+ BitXor,
19
+ Call,
20
+ ClassDef,
21
+ Constant,
22
+ Dict,
23
+ Div,
24
+ Expr,
25
+ Expression,
26
+ FloorDiv,
27
+ FunctionDef,
28
+ If,
29
+ Import,
30
+ ImportFrom,
31
+ Index,
32
+ List,
33
+ Load,
34
+ LShift,
35
+ MatMult,
36
+ Mod,
37
+ Module,
38
+ Mult,
39
+ Name,
40
+ NodeTransformer,
41
+ NodeVisitor,
42
+ Pass,
43
+ Pow,
44
+ Return,
45
+ RShift,
46
+ Starred,
47
+ Store,
48
+ Str,
49
+ Sub,
50
+ Subscript,
51
+ Tuple,
52
+ Yield,
53
+ YieldFrom,
54
+ alias,
55
+ copy_location,
56
+ expr,
57
+ fix_missing_locations,
58
+ keyword,
59
+ walk,
60
+ )
61
+ from collections import defaultdict
62
+ from collections.abc import Generator, Sequence
63
+ from contextlib import contextmanager
64
+ from copy import deepcopy
65
+ from dataclasses import dataclass, field
66
+ from typing import Any, ClassVar, cast, overload
67
+
68
+ if sys.version_info >= (3, 8):
69
+ from ast import NamedExpr
70
+
71
# Fully qualified names of the generator/iterator types whose subscripts can carry
# the yield (and send/return) types of a generator function. Every type is listed
# under both ``typing`` and ``collections.abc``, typing first.
generator_names = tuple(
    f"{module}.{type_name}"
    for type_name in (
        "Generator",
        "Iterator",
        "Iterable",
        "AsyncIterator",
        "AsyncIterable",
        "AsyncGenerator",
    )
    for module in ("typing", "collections.abc")
)

# Spellings of typing.Any; annotations matching these are erased (never checked)
anytype_names = tuple(f"{module}.Any" for module in ("typing", "typing_extensions"))

# Spellings of typing.Literal, whose subscript must not be parsed as code
literal_names = tuple(
    f"{module}.Literal" for module in ("typing", "typing_extensions")
)

# Spellings of typing.Annotated, whose metadata arguments are left untouched
annotated_names = tuple(
    f"{module}.Annotated" for module in ("typing", "typing_extensions")
)

# Decorators that opt a function out of instrumentation
ignore_decorators = ("typing.no_type_check", "typeguard.typeguard_ignore")

# Augmented-assignment operator node type -> in-place operator name (these match
# the names of the in-place functions in the ``operator`` module)
aug_assign_functions = dict(
    [
        (Add, "iadd"),
        (Sub, "isub"),
        (Mult, "imul"),
        (MatMult, "imatmul"),
        (Div, "itruediv"),
        (FloorDiv, "ifloordiv"),
        (Mod, "imod"),
        (Pow, "ipow"),
        (LShift, "ilshift"),
        (RShift, "irshift"),
        (BitAnd, "iand"),
        (BitXor, "ixor"),
        (BitOr, "ior"),
    ]
)
116
+
117
+
118
@dataclass
class TransformMemo:
    """
    Per-scope bookkeeping used while transforming a module.

    There is one memo for the module and one for every class or function being
    visited. ``parent`` links them into a chain that mirrors the lexical nesting;
    lookups such as :meth:`name_matches`, :meth:`is_ignored_name` and
    :meth:`get_unused_name` walk that chain outwards.
    """

    node: Module | ClassDef | FunctionDef | AsyncFunctionDef | None
    parent: TransformMemo | None
    path: tuple[str, ...]
    # Qualified name of this scope (e.g. "outer.<locals>.inner"); set in __post_init__
    joined_path: Constant = field(init=False)
    return_annotation: expr | None = None
    yield_annotation: expr | None = None
    send_annotation: expr | None = None
    is_async: bool = False
    # Names bound or referenced in this scope; consulted to avoid alias collisions
    local_names: set[str] = field(init=False, default_factory=set)
    # Local name -> fully qualified name of the object it refers to
    imported_names: dict[str, str] = field(init=False, default_factory=dict)
    # Names that must not be treated as types here (e.g. shadowed by arguments)
    ignored_names: set[str] = field(init=False, default_factory=set)
    # Module -> {imported name -> Name node} for imports the injected code needs
    load_names: defaultdict[str, dict[str, Name]] = field(
        init=False, default_factory=lambda: defaultdict(dict)
    )
    has_yield_expressions: bool = field(init=False, default=False)
    has_return_expressions: bool = field(init=False, default=False)
    memo_var_name: Name | None = field(init=False, default=None)
    should_instrument: bool = field(init=False, default=True)
    variable_annotations: dict[str, expr] = field(init=False, default_factory=dict)
    configuration_overrides: dict[str, Any] = field(init=False, default_factory=dict)
    # Index in ``node.body`` at which injected statements are inserted
    code_inject_index: int = field(init=False, default=0)

    def __post_init__(self) -> None:
        elements: list[str] = []
        memo = self
        while isinstance(memo.node, (ClassDef, FunctionDef, AsyncFunctionDef)):
            elements.insert(0, memo.node.name)
            if not memo.parent:
                break

            memo = memo.parent
            if isinstance(memo.node, (FunctionDef, AsyncFunctionDef)):
                elements.insert(0, "<locals>")

        self.joined_path = Constant(".".join(elements))

        # Figure out where to insert instrumentation code: after any
        # "from __future__" imports (which must come first) and docstrings
        if self.node:
            for index, child in enumerate(self.node.body):
                if isinstance(child, ImportFrom) and child.module == "__future__":
                    continue

                if isinstance(child, Expr) and (
                    (
                        isinstance(child.value, Constant)
                        and isinstance(child.value.value, str)
                    )
                    or (sys.version_info < (3, 8) and isinstance(child.value, Str))
                ):
                    continue  # docstring

                self.code_inject_index = index
                break
            else:
                # The body consists solely of __future__ imports and/or docstrings
                # (or is empty): inject after all of them rather than at the front,
                # which would precede "from __future__" or displace the docstring
                self.code_inject_index = len(self.node.body)

    def get_unused_name(self, name: str) -> str:
        """
        Return ``name`` with enough trailing underscores appended that it clashes
        with no local name in this scope or any enclosing one, and reserve it in
        this scope.
        """
        memo: TransformMemo | None = self
        while memo is not None:
            if name in memo.local_names:
                # Collision: lengthen the name and restart the scan from this scope
                memo = self
                name += "_"
            else:
                memo = memo.parent

        self.local_names.add(name)
        return name

    def is_ignored_name(self, expression: expr | Expr | None) -> bool:
        """
        Return ``True`` if ``expression`` is a name (or an attribute of a name)
        listed in ``ignored_names`` of this scope or any enclosing scope.
        """
        top_expression = (
            expression.value if isinstance(expression, Expr) else expression
        )

        if isinstance(top_expression, Attribute) and isinstance(
            top_expression.value, Name
        ):
            name = top_expression.value.id
        elif isinstance(top_expression, Name):
            name = top_expression.id
        else:
            return False

        memo: TransformMemo | None = self
        while memo is not None:
            if name in memo.ignored_names:
                return True

            memo = memo.parent

        return False

    def get_memo_name(self) -> Name:
        """
        Return the (shared) Name node referring to the call memo variable.

        The final identifier is assigned later, when the memo assignment is
        actually generated.
        """
        if not self.memo_var_name:
            self.memo_var_name = Name(id="memo", ctx=Load())

        return self.memo_var_name

    def get_import(self, module: str, name: str) -> Name:
        """
        Return a Name node referring to ``module.name``, registering an import for
        it (under a collision-free alias if needed) unless one is already usable.
        """
        if module in self.load_names and name in self.load_names[module]:
            return self.load_names[module][name]

        qualified_name = f"{module}.{name}"
        if name in self.imported_names and self.imported_names[name] == qualified_name:
            return Name(id=name, ctx=Load())

        alias_name = self.get_unused_name(name)
        node = self.load_names[module][name] = Name(id=alias_name, ctx=Load())
        # Record the name that is actually bound (the alias), so that the user's own
        # binding of ``name`` (if any) is not clobbered and the injected alias is
        # recognized by name_matches()
        self.imported_names[alias_name] = qualified_name
        return node

    def insert_imports(self, node: Module | FunctionDef | AsyncFunctionDef) -> None:
        """Insert imports needed by injected code."""
        if not self.load_names:
            return

        # Insert imports after any "from __future__ ..." imports and any docstring
        for modulename, names in self.load_names.items():
            aliases = [
                alias(orig_name, new_name.id if orig_name != new_name.id else None)
                for orig_name, new_name in sorted(names.items())
            ]
            node.body.insert(self.code_inject_index, ImportFrom(modulename, aliases, 0))

    def name_matches(self, expression: expr | Expr | None, *names: str) -> bool:
        """
        Return ``True`` if ``expression`` (or the object it subscripts or calls)
        resolves to one of the fully qualified ``names``, consulting the imports
        of this scope and then of the enclosing scopes.
        """
        if expression is None:
            return False

        path: list[str] = []
        top_expression = (
            expression.value if isinstance(expression, Expr) else expression
        )

        if isinstance(top_expression, Subscript):
            top_expression = top_expression.value
        elif isinstance(top_expression, Call):
            top_expression = top_expression.func

        while isinstance(top_expression, Attribute):
            path.insert(0, top_expression.attr)
            top_expression = top_expression.value

        if not isinstance(top_expression, Name):
            return False

        if top_expression.id in self.imported_names:
            translated = self.imported_names[top_expression.id]
        elif hasattr(builtins, top_expression.id):
            translated = "builtins." + top_expression.id
        else:
            translated = top_expression.id

        path.insert(0, translated)
        joined_path = ".".join(path)
        if joined_path in names:
            return True
        elif self.parent:
            return self.parent.name_matches(expression, *names)
        else:
            return False

    def get_config_keywords(self) -> list[keyword]:
        """
        Return keyword nodes for the configuration overrides in effect: those of
        an enclosing class (if the parent scope is a class), updated with this
        scope's own overrides.
        """
        if self.parent and isinstance(self.parent.node, ClassDef):
            overrides = self.parent.configuration_overrides.copy()
        else:
            overrides = {}

        overrides.update(self.configuration_overrides)
        return [keyword(key, value) for key, value in overrides.items()]
285
+
286
+
287
class NameCollector(NodeVisitor):
    """
    Gather the names bound by the visited statements: imported names (their alias
    if one is given), plain ``name = ...`` assignment targets and ``:=`` targets.

    Nested function and class definitions are not entered, and their own names are
    not recorded.
    """

    def __init__(self) -> None:
        self.names: set[str] = set()

    def visit_Import(self, node: Import | ImportFrom) -> None:
        self.names.update(imported.asname or imported.name for imported in node.names)

    # "from x import y [as z]" binds names exactly like "import y [as z]"
    visit_ImportFrom = visit_Import

    def visit_Assign(self, node: Assign) -> None:
        # Only simple name targets count; tuple/attribute/subscript targets don't
        self.names.update(
            target.id for target in node.targets if isinstance(target, Name)
        )

    def visit_NamedExpr(self, node: NamedExpr) -> Any:
        target = node.target
        if isinstance(target, Name):
            self.names.add(target.id)

    def _skip_nested_scope(self, node: FunctionDef | ClassDef) -> None:
        # Nested scopes bind their own names; deliberately not descended into
        return None

    visit_FunctionDef = _skip_nested_scope
    visit_ClassDef = _skip_nested_scope
313
+
314
+
315
class GeneratorDetector(NodeVisitor):
    """
    Detects if a function node is a generator function.

    Visit the function node, then read ``contains_yields``: it is ``True`` when the
    function's own body (excluding nested functions and classes) contains a
    ``yield`` or ``yield from`` expression.
    """

    contains_yields: bool = False
    in_root_function: bool = False

    def _flag_yield(self, node: Yield | YieldFrom) -> Any:
        self.contains_yields = True

    visit_Yield = _flag_yield
    visit_YieldFrom = _flag_yield

    def visit_ClassDef(self, node: ClassDef) -> Any:
        # A nested class is a separate scope; its methods don't make us a generator
        return None

    def visit_FunctionDef(self, node: FunctionDef | AsyncFunctionDef) -> Any:
        # Only the first (root) function is scanned; nested functions are skipped
        if self.in_root_function:
            return None

        self.in_root_function = True
        self.generic_visit(node)
        self.in_root_function = False
        return None

    def visit_AsyncFunctionDef(self, node: AsyncFunctionDef) -> Any:
        self.visit_FunctionDef(node)
338
+
339
+
340
class AnnotationTransformer(NodeTransformer):
    """
    Rewrite an annotation expression into a form the injected checks can evaluate.

    * PEP 604 unions (``x | y``) become ``typing.Union[x, y]`` on Python < 3.10.
    * ``dict``/``list``/``tuple``/``set``/``frozenset`` names are replaced with
      their ``typing`` equivalents on Python < 3.9.
    * String forward references are parsed into expressions.
    * Anything matching ``typing.Any``, or a name in the memo's ``ignored_names``,
      is erased (``None``), and erasure propagates upwards where appropriate
      (e.g. ``Optional[Any]`` and ``Union[..., Any]`` are erased entirely).
    * Calls are left untouched; the subscript of ``Literal`` is not interpreted.
    """

    type_substitutions: ClassVar[dict[str, tuple[str, str]]] = {
        "builtins.dict": ("typing", "Dict"),
        "builtins.list": ("typing", "List"),
        "builtins.tuple": ("typing", "Tuple"),
        "builtins.set": ("typing", "Set"),
        "builtins.frozenset": ("typing", "FrozenSet"),
    }

    def __init__(self, transformer: TypeguardTransformer):
        self.transformer = transformer
        self._memo = transformer._memo

    def visit(self, node: AST) -> Any:
        new_node = super().visit(node)
        # A parsed forward reference whose body was erased erases the whole thing
        if isinstance(new_node, Expression) and not hasattr(new_node, "body"):
            return None

        # Return None if this new node matches a variation of typing.Any
        if isinstance(new_node, expr) and self._memo.name_matches(
            new_node, *anytype_names
        ):
            return None

        return new_node

    def visit_BinOp(self, node: BinOp) -> Any:
        self.generic_visit(node)

        if isinstance(node.op, BitOr):
            # If either side of the operation resolved to None, return None
            if not hasattr(node, "left") or not hasattr(node, "right"):
                return None

            if sys.version_info < (3, 10):
                union_name = self.transformer._get_import("typing", "Union")
                return Subscript(
                    value=union_name,
                    slice=Index(
                        Tuple(elts=[node.left, node.right], ctx=Load()), ctx=Load()
                    ),
                    ctx=Load(),
                )

        return node

    def visit_Attribute(self, node: Attribute) -> Any:
        if self._memo.is_ignored_name(node):
            return None

        return node

    def visit_Subscript(self, node: Subscript) -> Any:
        if self._memo.is_ignored_name(node.value):
            return None

        # The subscript of typing(_extensions).Literal can be any arbitrary string, so
        # don't try to evaluate it as code
        if not self._memo.name_matches(node.value, *literal_names) and node.slice:
            if isinstance(node.slice, Index):
                # Python 3.7 and 3.8
                slice_value = node.slice.value  # type: ignore[attr-defined]
            else:
                slice_value = node.slice

            if isinstance(slice_value, Tuple):
                # visit() (rather than generic_visit()) transforms each item itself:
                # nested subscripts with an erased slice collapse to their bare
                # container, forward references are parsed, and Any becomes None
                if self._memo.name_matches(node.value, *annotated_names):
                    # Only treat the first argument to typing.Annotated as a potential
                    # forward reference; the metadata that follows is left untouched
                    items = cast(
                        typing.List[expr],
                        [self.visit(slice_value.elts[0])] + slice_value.elts[1:],
                    )
                else:
                    items = cast(
                        typing.List[expr],
                        [self.visit(item) for item in slice_value.elts],
                    )

                # If this is a Union and any of the items is Any (or was erased to
                # None by visit()), erase the entire annotation
                if self._memo.name_matches(node.value, "typing.Union") and any(
                    item is None
                    or (
                        isinstance(item, expr)
                        and self._memo.name_matches(item, *anytype_names)
                    )
                    for item in items
                ):
                    return None

                # If all items in the subscript were Any, erase the subscript entirely
                # (an empty tuple, as in Tuple[()], is kept as-is)
                if items and all(item is None for item in items):
                    return node.value

                # Partially erased subscripts keep an explicit Any so arity is intact
                for index, item in enumerate(items):
                    if item is None:
                        items[index] = self.transformer._get_import("typing", "Any")

                slice_value.elts = items
            else:
                self.generic_visit(node)

                # Find out whether the transformer erased the slice entirely
                if sys.version_info >= (3, 9):
                    slice_erased = not hasattr(node, "slice")
                else:
                    slice_erased = not hasattr(node.slice, "value")

                if slice_erased:
                    # Optional[<erased>] means Optional[Any]: erase the node entirely.
                    # Otherwise just return the node value without the subscript.
                    if self._memo.name_matches(node.value, "typing.Optional"):
                        return None

                    return node.value

        return node

    def visit_Name(self, node: Name) -> Any:
        if self._memo.is_ignored_name(node):
            return None

        if sys.version_info < (3, 9):
            for typename, substitute in self.type_substitutions.items():
                if self._memo.name_matches(node, typename):
                    new_node = self.transformer._get_import(*substitute)
                    return copy_location(new_node, node)

        return node

    def visit_Call(self, node: Call) -> Any:
        # Don't recurse into calls
        return node

    def visit_Constant(self, node: Constant) -> Any:
        # A string constant in an annotation is a forward reference: parse it
        if isinstance(node.value, str):
            expression = ast.parse(node.value, mode="eval")
            new_node = self.visit(expression)
            if new_node:
                return copy_location(new_node.body, node)
            else:
                return None

        return node

    def visit_Str(self, node: Str) -> Any:
        # Only used on Python 3.7
        expression = ast.parse(node.s, mode="eval")
        new_node = self.visit(expression)
        if new_node:
            return copy_location(new_node.body, node)
        else:
            return None
488
+
489
+
490
+ class TypeguardTransformer(NodeTransformer):
491
+ def __init__(
492
+ self, target_path: Sequence[str] | None = None, target_lineno: int | None = None
493
+ ) -> None:
494
+ self._target_path = tuple(target_path) if target_path else None
495
+ self._memo = self._module_memo = TransformMemo(None, None, ())
496
+ self.names_used_in_annotations: set[str] = set()
497
+ self.target_node: FunctionDef | AsyncFunctionDef | None = None
498
+ self.target_lineno = target_lineno
499
+
500
+ @contextmanager
501
+ def _use_memo(
502
+ self, node: ClassDef | FunctionDef | AsyncFunctionDef
503
+ ) -> Generator[None, Any, None]:
504
+ new_memo = TransformMemo(node, self._memo, self._memo.path + (node.name,))
505
+ if isinstance(node, (FunctionDef, AsyncFunctionDef)):
506
+ new_memo.should_instrument = (
507
+ self._target_path is None or new_memo.path == self._target_path
508
+ )
509
+ if new_memo.should_instrument:
510
+ # Check if the function is a generator function
511
+ detector = GeneratorDetector()
512
+ detector.visit(node)
513
+
514
+ # Extract yield, send and return types where possible from a subscripted
515
+ # annotation like Generator[int, str, bool]
516
+ return_annotation = deepcopy(node.returns)
517
+ if detector.contains_yields and new_memo.name_matches(
518
+ return_annotation, *generator_names
519
+ ):
520
+ if isinstance(return_annotation, Subscript):
521
+ annotation_slice = return_annotation.slice
522
+
523
+ # Python < 3.9
524
+ if isinstance(annotation_slice, Index):
525
+ annotation_slice = (
526
+ annotation_slice.value # type: ignore[attr-defined]
527
+ )
528
+
529
+ if isinstance(annotation_slice, Tuple):
530
+ items = annotation_slice.elts
531
+ else:
532
+ items = [annotation_slice]
533
+
534
+ if len(items) > 0:
535
+ new_memo.yield_annotation = self._convert_annotation(
536
+ items[0]
537
+ )
538
+
539
+ if len(items) > 1:
540
+ new_memo.send_annotation = self._convert_annotation(
541
+ items[1]
542
+ )
543
+
544
+ if len(items) > 2:
545
+ new_memo.return_annotation = self._convert_annotation(
546
+ items[2]
547
+ )
548
+ else:
549
+ new_memo.return_annotation = self._convert_annotation(
550
+ return_annotation
551
+ )
552
+
553
+ if isinstance(node, AsyncFunctionDef):
554
+ new_memo.is_async = True
555
+
556
+ old_memo = self._memo
557
+ self._memo = new_memo
558
+ yield
559
+ self._memo = old_memo
560
+
561
+ def _get_import(self, module: str, name: str) -> Name:
562
+ memo = self._memo if self._target_path else self._module_memo
563
+ return memo.get_import(module, name)
564
+
565
+ @overload
566
+ def _convert_annotation(self, annotation: None) -> None:
567
+ ...
568
+
569
+ @overload
570
+ def _convert_annotation(self, annotation: expr) -> expr:
571
+ ...
572
+
573
+ def _convert_annotation(self, annotation: expr | None) -> expr | None:
574
+ if annotation is None:
575
+ return None
576
+
577
+ # Convert PEP 604 unions (x | y) and generic built-in collections where
578
+ # necessary, and undo forward references
579
+ new_annotation = cast(expr, AnnotationTransformer(self).visit(annotation))
580
+ if isinstance(new_annotation, expr):
581
+ new_annotation = ast.copy_location(new_annotation, annotation)
582
+
583
+ # Store names used in the annotation
584
+ names = {node.id for node in walk(new_annotation) if isinstance(node, Name)}
585
+ self.names_used_in_annotations.update(names)
586
+
587
+ return new_annotation
588
+
589
+ def visit_Name(self, node: Name) -> Name:
590
+ self._memo.local_names.add(node.id)
591
+ return node
592
+
593
+ def visit_Module(self, node: Module) -> Module:
594
+ self.generic_visit(node)
595
+ self._memo.insert_imports(node)
596
+
597
+ fix_missing_locations(node)
598
+ return node
599
+
600
+ def visit_Import(self, node: Import) -> Import:
601
+ for name in node.names:
602
+ self._memo.local_names.add(name.asname or name.name)
603
+ self._memo.imported_names[name.asname or name.name] = name.name
604
+
605
+ return node
606
+
607
+ def visit_ImportFrom(self, node: ImportFrom) -> ImportFrom:
608
+ for name in node.names:
609
+ if name.name != "*":
610
+ alias = name.asname or name.name
611
+ self._memo.local_names.add(alias)
612
+ self._memo.imported_names[alias] = f"{node.module}.{name.name}"
613
+
614
+ return node
615
+
616
+ def visit_ClassDef(self, node: ClassDef) -> ClassDef | None:
617
+ self._memo.local_names.add(node.name)
618
+
619
+ # Eliminate top level classes not belonging to the target path
620
+ if (
621
+ self._target_path is not None
622
+ and not self._memo.path
623
+ and node.name != self._target_path[0]
624
+ ):
625
+ return None
626
+
627
+ with self._use_memo(node):
628
+ for decorator in node.decorator_list.copy():
629
+ if self._memo.name_matches(decorator, "typeguard.typechecked"):
630
+ # Remove the decorator to prevent duplicate instrumentation
631
+ node.decorator_list.remove(decorator)
632
+
633
+ # Store any configuration overrides
634
+ if isinstance(decorator, Call) and decorator.keywords:
635
+ self._memo.configuration_overrides.update(
636
+ {kw.arg: kw.value for kw in decorator.keywords if kw.arg}
637
+ )
638
+
639
+ self.generic_visit(node)
640
+ return node
641
+
642
+ def visit_FunctionDef(
643
+ self, node: FunctionDef | AsyncFunctionDef
644
+ ) -> FunctionDef | AsyncFunctionDef | None:
645
+ """
646
+ Injects type checks for function arguments, and for a return of None if the
647
+ function is annotated to return something else than Any or None, and the body
648
+ ends without an explicit "return".
649
+
650
+ """
651
+ self._memo.local_names.add(node.name)
652
+
653
+ # Eliminate top level functions not belonging to the target path
654
+ if (
655
+ self._target_path is not None
656
+ and not self._memo.path
657
+ and node.name != self._target_path[0]
658
+ ):
659
+ return None
660
+
661
+ # Skip instrumentation if we're instrumenting the whole module and the function
662
+ # contains either @no_type_check or @typeguard_ignore
663
+ if self._target_path is None:
664
+ for decorator in node.decorator_list:
665
+ if self._memo.name_matches(decorator, *ignore_decorators):
666
+ return node
667
+
668
+ with self._use_memo(node):
669
+ arg_annotations: dict[str, Any] = {}
670
+ if self._target_path is None or self._memo.path == self._target_path:
671
+ # Find line number we're supposed to match against
672
+ if node.decorator_list:
673
+ first_lineno = node.decorator_list[0].lineno
674
+ else:
675
+ first_lineno = node.lineno
676
+
677
+ for decorator in node.decorator_list.copy():
678
+ if self._memo.name_matches(decorator, "typing.overload"):
679
+ # Remove overloads entirely
680
+ return None
681
+ elif self._memo.name_matches(decorator, "typeguard.typechecked"):
682
+ # Remove the decorator to prevent duplicate instrumentation
683
+ node.decorator_list.remove(decorator)
684
+
685
+ # Store any configuration overrides
686
+ if isinstance(decorator, Call) and decorator.keywords:
687
+ self._memo.configuration_overrides = {
688
+ kw.arg: kw.value for kw in decorator.keywords if kw.arg
689
+ }
690
+
691
+ if self.target_lineno == first_lineno:
692
+ assert self.target_node is None
693
+ self.target_node = node
694
+ if node.decorator_list and sys.version_info >= (3, 8):
695
+ self.target_lineno = node.decorator_list[0].lineno
696
+ else:
697
+ self.target_lineno = node.lineno
698
+
699
+ all_args = node.args.args + node.args.kwonlyargs
700
+ if sys.version_info >= (3, 8):
701
+ all_args.extend(node.args.posonlyargs)
702
+
703
+ # Ensure that any type shadowed by the positional or keyword-only
704
+ # argument names are ignored in this function
705
+ for arg in all_args:
706
+ self._memo.ignored_names.add(arg.arg)
707
+
708
+ # Ensure that any type shadowed by the variable positional argument name
709
+ # (e.g. "args" in *args) is ignored this function
710
+ if node.args.vararg:
711
+ self._memo.ignored_names.add(node.args.vararg.arg)
712
+
713
+ # Ensure that any type shadowed by the variable keywrod argument name
714
+ # (e.g. "kwargs" in *kwargs) is ignored this function
715
+ if node.args.kwarg:
716
+ self._memo.ignored_names.add(node.args.kwarg.arg)
717
+
718
+ for arg in all_args:
719
+ annotation = self._convert_annotation(deepcopy(arg.annotation))
720
+ if annotation:
721
+ arg_annotations[arg.arg] = annotation
722
+
723
+ if node.args.vararg:
724
+ annotation_ = self._convert_annotation(node.args.vararg.annotation)
725
+ if annotation_:
726
+ if sys.version_info >= (3, 9):
727
+ container = Name("tuple", ctx=Load())
728
+ else:
729
+ container = self._get_import("typing", "Tuple")
730
+
731
+ subscript_slice: Tuple | Index = Tuple(
732
+ [
733
+ annotation_,
734
+ Constant(Ellipsis),
735
+ ],
736
+ ctx=Load(),
737
+ )
738
+ if sys.version_info < (3, 9):
739
+ subscript_slice = Index(subscript_slice, ctx=Load())
740
+
741
+ arg_annotations[node.args.vararg.arg] = Subscript(
742
+ container, subscript_slice, ctx=Load()
743
+ )
744
+
745
+ if node.args.kwarg:
746
+ annotation_ = self._convert_annotation(node.args.kwarg.annotation)
747
+ if annotation_:
748
+ if sys.version_info >= (3, 9):
749
+ container = Name("dict", ctx=Load())
750
+ else:
751
+ container = self._get_import("typing", "Dict")
752
+
753
+ subscript_slice = Tuple(
754
+ [
755
+ Name("str", ctx=Load()),
756
+ annotation_,
757
+ ],
758
+ ctx=Load(),
759
+ )
760
+ if sys.version_info < (3, 9):
761
+ subscript_slice = Index(subscript_slice, ctx=Load())
762
+
763
+ arg_annotations[node.args.kwarg.arg] = Subscript(
764
+ container, subscript_slice, ctx=Load()
765
+ )
766
+
767
+ if arg_annotations:
768
+ self._memo.variable_annotations.update(arg_annotations)
769
+
770
+ self.generic_visit(node)
771
+
772
+ if arg_annotations:
773
+ annotations_dict = Dict(
774
+ keys=[Constant(key) for key in arg_annotations.keys()],
775
+ values=[
776
+ Tuple([Name(key, ctx=Load()), annotation], ctx=Load())
777
+ for key, annotation in arg_annotations.items()
778
+ ],
779
+ )
780
+ func_name = self._get_import(
781
+ "typeguard._functions", "check_argument_types"
782
+ )
783
+ args = [
784
+ self._memo.joined_path,
785
+ annotations_dict,
786
+ self._memo.get_memo_name(),
787
+ ]
788
+ node.body.insert(
789
+ self._memo.code_inject_index, Expr(Call(func_name, args, []))
790
+ )
791
+
792
+ # Add a checked "return None" to the end if there's no explicit return
793
+ # Skip if the return annotation is None or Any
794
+ if (
795
+ self._memo.return_annotation
796
+ and (not self._memo.is_async or not self._memo.has_yield_expressions)
797
+ and not isinstance(node.body[-1], Return)
798
+ and (
799
+ not isinstance(self._memo.return_annotation, Constant)
800
+ or self._memo.return_annotation.value is not None
801
+ )
802
+ ):
803
+ func_name = self._get_import(
804
+ "typeguard._functions", "check_return_type"
805
+ )
806
+ return_node = Return(
807
+ Call(
808
+ func_name,
809
+ [
810
+ self._memo.joined_path,
811
+ Constant(None),
812
+ self._memo.return_annotation,
813
+ self._memo.get_memo_name(),
814
+ ],
815
+ [],
816
+ )
817
+ )
818
+
819
+ # Replace a placeholder "pass" at the end
820
+ if isinstance(node.body[-1], Pass):
821
+ copy_location(return_node, node.body[-1])
822
+ del node.body[-1]
823
+
824
+ node.body.append(return_node)
825
+
826
+ # Insert code to create the call memo, if it was ever needed for this
827
+ # function
828
+ if self._memo.memo_var_name:
829
+ memo_kwargs: dict[str, Any] = {}
830
+ if self._memo.parent and isinstance(self._memo.parent.node, ClassDef):
831
+ for decorator in node.decorator_list:
832
+ if (
833
+ isinstance(decorator, Name)
834
+ and decorator.id == "staticmethod"
835
+ ):
836
+ break
837
+ elif (
838
+ isinstance(decorator, Name)
839
+ and decorator.id == "classmethod"
840
+ ):
841
+ memo_kwargs["self_type"] = Name(
842
+ id=node.args.args[0].arg, ctx=Load()
843
+ )
844
+ break
845
+ else:
846
+ if node.args.args:
847
+ if node.name == "__new__":
848
+ memo_kwargs["self_type"] = Name(
849
+ id=node.args.args[0].arg, ctx=Load()
850
+ )
851
+ else:
852
+ memo_kwargs["self_type"] = Attribute(
853
+ Name(id=node.args.args[0].arg, ctx=Load()),
854
+ "__class__",
855
+ ctx=Load(),
856
+ )
857
+
858
+ # Construct the function reference
859
+ # Nested functions get special treatment: the function name is added
860
+ # to free variables (and the closure of the resulting function)
861
+ names: list[str] = [node.name]
862
+ memo = self._memo.parent
863
+ while memo:
864
+ if isinstance(memo.node, (FunctionDef, AsyncFunctionDef)):
865
+ # This is a nested function. Use the function name as-is.
866
+ del names[:-1]
867
+ break
868
+ elif not isinstance(memo.node, ClassDef):
869
+ break
870
+
871
+ names.insert(0, memo.node.name)
872
+ memo = memo.parent
873
+
874
+ config_keywords = self._memo.get_config_keywords()
875
+ if config_keywords:
876
+ memo_kwargs["config"] = Call(
877
+ self._get_import("dataclasses", "replace"),
878
+ [self._get_import("typeguard._config", "global_config")],
879
+ config_keywords,
880
+ )
881
+
882
+ self._memo.memo_var_name.id = self._memo.get_unused_name("memo")
883
+ memo_store_name = Name(id=self._memo.memo_var_name.id, ctx=Store())
884
+ globals_call = Call(Name(id="globals", ctx=Load()), [], [])
885
+ locals_call = Call(Name(id="locals", ctx=Load()), [], [])
886
+ memo_expr = Call(
887
+ self._get_import("typeguard", "TypeCheckMemo"),
888
+ [globals_call, locals_call],
889
+ [keyword(key, value) for key, value in memo_kwargs.items()],
890
+ )
891
+ node.body.insert(
892
+ self._memo.code_inject_index,
893
+ Assign([memo_store_name], memo_expr),
894
+ )
895
+
896
+ self._memo.insert_imports(node)
897
+
898
+ # Rmove any placeholder "pass" at the end
899
+ if isinstance(node.body[-1], Pass):
900
+ del node.body[-1]
901
+
902
+ return node
903
+
904
+ def visit_AsyncFunctionDef(
905
+ self, node: AsyncFunctionDef
906
+ ) -> FunctionDef | AsyncFunctionDef | None:
907
+ return self.visit_FunctionDef(node)
908
+
909
def visit_Return(self, node: Return) -> Return:
    """
    Wrap the value of a ``return`` statement in a return type check.

    Child nodes are visited first. When the current memo has a return
    annotation, instrumentation is enabled and that annotation is not an
    ignored name, a brand new ``Return`` node is produced whose value is a call
    to ``typeguard._functions.check_return_type``. A bare ``return`` is
    checked as if it were ``return None``. The new node takes over the source
    location of the original one. In every other case the (visited) original
    node is returned.

    """
    self.generic_visit(node)

    memo = self._memo
    annotation = memo.return_annotation
    if (
        not annotation
        or not memo.should_instrument
        or memo.is_ignored_name(annotation)
    ):
        return node

    checker = self._get_import("typeguard._functions", "check_return_type")
    # A bare "return" has no value node; check it as "return None"
    returned_value = node.value or Constant(None)
    check_call = Call(
        checker,
        [memo.joined_path, returned_value, annotation, memo.get_memo_name()],
        [],
    )
    replacement = Return(check_call)
    copy_location(replacement, node)
    return replacement
935
+
936
def visit_Yield(self, node: Yield) -> Yield | Call:
    """
    Instrument a ``yield`` expression with yield- and send-type checks.

    The current memo is flagged as containing yield expressions before the
    children are visited. Afterwards:

    * if the memo has a yield annotation (instrumentation enabled, annotation
      not ignored), the yielded value -- ``None`` for a bare ``yield`` -- is
      wrapped in ``check_yield_type``;
    * if the memo has a send annotation under the same conditions, the whole
      ``yield`` expression is wrapped in ``check_send_type`` so that the value
      sent back into the generator gets checked. That wrapping call, located at
      the original ``yield``, is returned instead of the ``Yield`` node.

    """
    self._memo.has_yield_expressions = True
    self.generic_visit(node)

    memo = self._memo
    yield_type = memo.yield_annotation
    if (
        yield_type
        and memo.should_instrument
        and not memo.is_ignored_name(yield_type)
    ):
        yield_checker = self._get_import("typeguard._functions", "check_yield_type")
        # A bare "yield" produces None
        yielded = node.value or Constant(None)
        node.value = Call(
            yield_checker,
            [memo.joined_path, yielded, yield_type, memo.get_memo_name()],
            [],
        )

    send_type = memo.send_annotation
    if (
        not send_type
        or not memo.should_instrument
        or memo.is_ignored_name(send_type)
    ):
        return node

    send_checker = self._get_import("typeguard._functions", "check_send_type")
    wrapped = Call(
        send_checker,
        [memo.joined_path, node, send_type, memo.get_memo_name()],
        [],
    )
    copy_location(wrapped, node)
    return wrapped
984
+
985
def visit_AnnAssign(self, node: AnnAssign) -> Any:
    """
    Record a local variable annotation and check its assigned value.

    Only applies inside a function body, and only to plain-name targets with an
    annotation. The name is added to the memo's ignored names, and a deep copy
    of the annotation is passed through ``_convert_annotation``. If conversion
    yields something, it is remembered in ``variable_annotations``, which later
    assignments consult. If the statement also assigns a value, that value is
    wrapped in ``check_variable_assignment``.

    """
    self.generic_visit(node)

    memo = self._memo
    inside_function = isinstance(memo.node, (FunctionDef, AsyncFunctionDef))
    if not (inside_function and node.annotation and isinstance(node.target, Name)):
        return node

    var_name = node.target.id
    memo.ignored_names.add(var_name)
    # Work on a copy so the original annotation node in the tree stays intact
    converted = self._convert_annotation(deepcopy(node.annotation))
    if not converted:
        return node

    memo.variable_annotations[var_name] = converted
    if not node.value:
        # Annotation without assignment ("x: int"): nothing to check yet
        return node

    checker = self._get_import("typeguard._functions", "check_variable_assignment")
    node.value = Call(
        checker,
        [node.value, Constant(var_name), converted, memo.get_memo_name()],
        [],
    )
    return node
1018
+
1019
def visit_Assign(self, node: Assign) -> Any:
    """
    Inject a type check into a plain assignment inside a function body.

    Only ``Name`` targets and ``Tuple`` targets are considered. Inside a tuple,
    only plain and starred plain names are taken into account. Every such name
    is added to the memo's ignored names. The value is instrumented only if at
    least one of those names was annotated earlier in the function (i.e. is
    present in ``variable_annotations``). In that case, unannotated names are
    checked against ``typing.Any``.

    If there is exactly one target holding exactly one name, the value is
    wrapped in ``check_variable_assignment``. Otherwise -- several names, or
    chained targets such as ``a = b = ...`` -- the value is wrapped in
    ``check_multi_variable_assignment``. That call receives one
    ``{name: annotation}`` dict per target, with starred names prefixed by
    ``*``.

    """
    self.generic_visit(node)

    memo = self._memo
    # Only instrument function-local assignments
    if not isinstance(memo.node, (FunctionDef, AsyncFunctionDef)):
        return node

    # One {Constant(name): annotation-or-None} mapping per assignment target
    collected: list[dict[Constant, expr | None]] = []
    found_annotation = False
    for assign_target in node.targets:
        if isinstance(assign_target, Tuple):
            candidates: Sequence[expr] = assign_target.elts
        elif isinstance(assign_target, Name):
            candidates = [assign_target]
        else:
            # Attribute, subscript, list targets etc. are not instrumented
            continue

        mapping: dict[Constant, expr | None] = {}
        for candidate in candidates:
            star = ""
            if isinstance(candidate, Starred):
                candidate, star = candidate.value, "*"

            if not isinstance(candidate, Name):
                continue

            memo.ignored_names.add(candidate.id)
            known = memo.variable_annotations.get(candidate.id)
            if known:
                found_annotation = True
            mapping[Constant(star + candidate.id)] = known if known else None

        collected.append(mapping)

    if not found_annotation:
        return node

    # Unannotated names are checked against typing.Any
    collected = [
        {
            key: self._get_import("typing", "Any") if value is None else value
            for key, value in mapping.items()
        }
        for mapping in collected
    ]

    if len(collected) == 1 and len(collected[0]) == 1:
        # NOTE(review): a one-element tuple target such as "(a,) = value" also
        # lands here, so the whole right-hand side (not its single element) is
        # passed to check_variable_assignment against a's annotation -- confirm
        # this is intended.
        checker = self._get_import(
            "typeguard._functions", "check_variable_assignment"
        )
        [(var_constant, var_annotation)] = collected[0].items()
        node.value = Call(
            checker,
            [node.value, var_constant, var_annotation, memo.get_memo_name()],
            [],
        )
    elif collected:
        checker = self._get_import(
            "typeguard._functions", "check_multi_variable_assignment"
        )
        per_target_specs = List(
            [
                Dict(keys=list(mapping), values=list(mapping.values()))
                for mapping in collected
            ],
            ctx=Load(),
        )
        node.value = Call(
            checker, [node.value, per_target_specs, memo.get_memo_name()], []
        )

    return node
1099
+
1100
def visit_NamedExpr(self, node: NamedExpr) -> Any:
    """
    Inject a type check into an assignment expression (``a := foo()``).

    Inside a function body, a plain-name target is always added to the memo's
    ignored names. If that name has an annotation recorded in
    ``variable_annotations``, the assigned value is wrapped in
    ``check_variable_assignment``; otherwise the node is left as is.

    """
    self.generic_visit(node)

    memo = self._memo
    # Only instrument function-local assignments to plain names
    if not isinstance(memo.node, (FunctionDef, AsyncFunctionDef)):
        return node
    if not isinstance(node.target, Name):
        return node

    var_name = node.target.id
    memo.ignored_names.add(var_name)

    declared = memo.variable_annotations.get(var_name)
    if declared is None:
        return node

    checker = self._get_import("typeguard._functions", "check_variable_assignment")
    node.value = Call(
        checker,
        [node.value, Constant(var_name), declared, memo.get_memo_name()],
        [],
    )
    return node
1130
+
1131
def visit_AugAssign(self, node: AugAssign) -> Any:
    """
    Inject a type check into an augmented assignment (``a += 1``).

    Applies inside a function body when the target is a plain name that was
    annotated earlier (recorded in ``variable_annotations``). The statement is
    then rewritten as an ordinary assignment along the lines of::

        a = check_variable_assignment(<operator func>(a, 1), "a", <annotation>, <memo>)

    where ``<operator func>`` is the function from the ``operator`` module that
    ``aug_assign_functions`` maps the statement's operator class to. The new
    ``Assign`` node takes over the source location of the original statement.

    The node is returned unchanged when the target has no recorded annotation,
    or when the operator has no entry in ``aug_assign_functions``.

    :return: the replacement ``Assign`` node, or the original node

    """
    self.generic_visit(node)

    # Only instrument function-local assignments
    if isinstance(self._memo.node, (FunctionDef, AsyncFunctionDef)) and isinstance(
        node.target, Name
    ):
        # Bail out if no matching annotation is found
        annotation = self._memo.variable_annotations.get(node.target.id)
        if annotation is None:
            return node

        # Bail out if the operator is not found (newer Python version?)
        try:
            operator_func_name = aug_assign_functions[node.op.__class__]
        except KeyError:
            return node

        operator_func = self._get_import("operator", operator_func_name)
        operator_call = Call(
            operator_func, [Name(node.target.id, ctx=Load()), node.value], []
        )
        check_call = Call(
            self._get_import("typeguard._functions", "check_variable_assignment"),
            [
                operator_call,
                Constant(node.target.id),
                annotation,
                self._memo.get_memo_name(),
            ],
            [],
        )
        new_node = Assign(targets=[node.target], value=check_call)
        # The Assign replaces the AugAssign statement. Give it the original
        # statement's position, as visit_Return and visit_Yield do for their
        # replacement nodes; otherwise it would carry no location of its own.
        copy_location(new_node, node)
        return new_node

    return node
1170
+
1171
def visit_If(self, node: If) -> Any:
    """
    Keep names from a module-level ``if typing.TYPE_CHECKING:`` out of checks.

    After the children are visited, an ``if`` whose body ended up empty gets a
    ``pass`` so that it stays valid. Such a body results from removing
    classes/functions that are not on the target path. Then, if the current
    memo is the module-level memo and the test is a bare name matching
    ``typing.TYPE_CHECKING``, all names that ``NameCollector`` finds in the
    statement are added to the memo's ignored names, so they won't be type
    checked.

    """
    self.generic_visit(node)

    # An "if" needs at least one statement in its body
    if len(node.body) == 0:
        node.body.append(Pass())

    memo = self._memo
    if memo is not self._module_memo:
        return node

    condition = node.test
    if not isinstance(condition, Name) or not memo.name_matches(
        condition, "typing.TYPE_CHECKING"
    ):
        return node

    finder = NameCollector()
    finder.visit(node)
    memo.ignored_names.update(finder.names)
    return node