ob-metaflow 2.11.13.1__py2.py3-none-any.whl → 2.19.7.1rc0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (289)
  1. metaflow/R.py +10 -7
  2. metaflow/__init__.py +40 -25
  3. metaflow/_vendor/imghdr/__init__.py +186 -0
  4. metaflow/_vendor/importlib_metadata/__init__.py +1063 -0
  5. metaflow/_vendor/importlib_metadata/_adapters.py +68 -0
  6. metaflow/_vendor/importlib_metadata/_collections.py +30 -0
  7. metaflow/_vendor/importlib_metadata/_compat.py +71 -0
  8. metaflow/_vendor/importlib_metadata/_functools.py +104 -0
  9. metaflow/_vendor/importlib_metadata/_itertools.py +73 -0
  10. metaflow/_vendor/importlib_metadata/_meta.py +48 -0
  11. metaflow/_vendor/importlib_metadata/_text.py +99 -0
  12. metaflow/_vendor/importlib_metadata/py.typed +0 -0
  13. metaflow/_vendor/typeguard/__init__.py +48 -0
  14. metaflow/_vendor/typeguard/_checkers.py +1070 -0
  15. metaflow/_vendor/typeguard/_config.py +108 -0
  16. metaflow/_vendor/typeguard/_decorators.py +233 -0
  17. metaflow/_vendor/typeguard/_exceptions.py +42 -0
  18. metaflow/_vendor/typeguard/_functions.py +308 -0
  19. metaflow/_vendor/typeguard/_importhook.py +213 -0
  20. metaflow/_vendor/typeguard/_memo.py +48 -0
  21. metaflow/_vendor/typeguard/_pytest_plugin.py +127 -0
  22. metaflow/_vendor/typeguard/_suppression.py +86 -0
  23. metaflow/_vendor/typeguard/_transformer.py +1229 -0
  24. metaflow/_vendor/typeguard/_union_transformer.py +55 -0
  25. metaflow/_vendor/typeguard/_utils.py +173 -0
  26. metaflow/_vendor/typeguard/py.typed +0 -0
  27. metaflow/_vendor/typing_extensions.py +3641 -0
  28. metaflow/_vendor/v3_7/importlib_metadata/__init__.py +1063 -0
  29. metaflow/_vendor/v3_7/importlib_metadata/_adapters.py +68 -0
  30. metaflow/_vendor/v3_7/importlib_metadata/_collections.py +30 -0
  31. metaflow/_vendor/v3_7/importlib_metadata/_compat.py +71 -0
  32. metaflow/_vendor/v3_7/importlib_metadata/_functools.py +104 -0
  33. metaflow/_vendor/v3_7/importlib_metadata/_itertools.py +73 -0
  34. metaflow/_vendor/v3_7/importlib_metadata/_meta.py +48 -0
  35. metaflow/_vendor/v3_7/importlib_metadata/_text.py +99 -0
  36. metaflow/_vendor/v3_7/importlib_metadata/py.typed +0 -0
  37. metaflow/_vendor/v3_7/typeguard/__init__.py +48 -0
  38. metaflow/_vendor/v3_7/typeguard/_checkers.py +906 -0
  39. metaflow/_vendor/v3_7/typeguard/_config.py +108 -0
  40. metaflow/_vendor/v3_7/typeguard/_decorators.py +237 -0
  41. metaflow/_vendor/v3_7/typeguard/_exceptions.py +42 -0
  42. metaflow/_vendor/v3_7/typeguard/_functions.py +310 -0
  43. metaflow/_vendor/v3_7/typeguard/_importhook.py +213 -0
  44. metaflow/_vendor/v3_7/typeguard/_memo.py +48 -0
  45. metaflow/_vendor/v3_7/typeguard/_pytest_plugin.py +100 -0
  46. metaflow/_vendor/v3_7/typeguard/_suppression.py +88 -0
  47. metaflow/_vendor/v3_7/typeguard/_transformer.py +1207 -0
  48. metaflow/_vendor/v3_7/typeguard/_union_transformer.py +54 -0
  49. metaflow/_vendor/v3_7/typeguard/_utils.py +169 -0
  50. metaflow/_vendor/v3_7/typeguard/py.typed +0 -0
  51. metaflow/_vendor/v3_7/typing_extensions.py +3072 -0
  52. metaflow/_vendor/yaml/__init__.py +427 -0
  53. metaflow/_vendor/yaml/composer.py +139 -0
  54. metaflow/_vendor/yaml/constructor.py +748 -0
  55. metaflow/_vendor/yaml/cyaml.py +101 -0
  56. metaflow/_vendor/yaml/dumper.py +62 -0
  57. metaflow/_vendor/yaml/emitter.py +1137 -0
  58. metaflow/_vendor/yaml/error.py +75 -0
  59. metaflow/_vendor/yaml/events.py +86 -0
  60. metaflow/_vendor/yaml/loader.py +63 -0
  61. metaflow/_vendor/yaml/nodes.py +49 -0
  62. metaflow/_vendor/yaml/parser.py +589 -0
  63. metaflow/_vendor/yaml/reader.py +185 -0
  64. metaflow/_vendor/yaml/representer.py +389 -0
  65. metaflow/_vendor/yaml/resolver.py +227 -0
  66. metaflow/_vendor/yaml/scanner.py +1435 -0
  67. metaflow/_vendor/yaml/serializer.py +111 -0
  68. metaflow/_vendor/yaml/tokens.py +104 -0
  69. metaflow/cards.py +5 -0
  70. metaflow/cli.py +331 -785
  71. metaflow/cli_args.py +17 -0
  72. metaflow/cli_components/__init__.py +0 -0
  73. metaflow/cli_components/dump_cmd.py +96 -0
  74. metaflow/cli_components/init_cmd.py +52 -0
  75. metaflow/cli_components/run_cmds.py +546 -0
  76. metaflow/cli_components/step_cmd.py +334 -0
  77. metaflow/cli_components/utils.py +140 -0
  78. metaflow/client/__init__.py +1 -0
  79. metaflow/client/core.py +467 -73
  80. metaflow/client/filecache.py +75 -35
  81. metaflow/clone_util.py +7 -1
  82. metaflow/cmd/code/__init__.py +231 -0
  83. metaflow/cmd/develop/stub_generator.py +756 -288
  84. metaflow/cmd/develop/stubs.py +12 -28
  85. metaflow/cmd/main_cli.py +6 -4
  86. metaflow/cmd/make_wrapper.py +78 -0
  87. metaflow/datastore/__init__.py +1 -0
  88. metaflow/datastore/content_addressed_store.py +41 -10
  89. metaflow/datastore/datastore_set.py +11 -2
  90. metaflow/datastore/flow_datastore.py +156 -10
  91. metaflow/datastore/spin_datastore.py +91 -0
  92. metaflow/datastore/task_datastore.py +154 -39
  93. metaflow/debug.py +5 -0
  94. metaflow/decorators.py +404 -78
  95. metaflow/exception.py +8 -2
  96. metaflow/extension_support/__init__.py +527 -376
  97. metaflow/extension_support/_empty_file.py +2 -2
  98. metaflow/extension_support/plugins.py +49 -31
  99. metaflow/flowspec.py +482 -33
  100. metaflow/graph.py +210 -42
  101. metaflow/includefile.py +84 -40
  102. metaflow/lint.py +141 -22
  103. metaflow/meta_files.py +13 -0
  104. metaflow/{metadata → metadata_provider}/heartbeat.py +24 -8
  105. metaflow/{metadata → metadata_provider}/metadata.py +86 -1
  106. metaflow/metaflow_config.py +175 -28
  107. metaflow/metaflow_config_funcs.py +51 -3
  108. metaflow/metaflow_current.py +4 -10
  109. metaflow/metaflow_environment.py +139 -53
  110. metaflow/metaflow_git.py +115 -0
  111. metaflow/metaflow_profile.py +18 -0
  112. metaflow/metaflow_version.py +150 -66
  113. metaflow/mflog/__init__.py +4 -3
  114. metaflow/mflog/save_logs.py +2 -2
  115. metaflow/multicore_utils.py +31 -14
  116. metaflow/package/__init__.py +673 -0
  117. metaflow/packaging_sys/__init__.py +880 -0
  118. metaflow/packaging_sys/backend.py +128 -0
  119. metaflow/packaging_sys/distribution_support.py +153 -0
  120. metaflow/packaging_sys/tar_backend.py +99 -0
  121. metaflow/packaging_sys/utils.py +54 -0
  122. metaflow/packaging_sys/v1.py +527 -0
  123. metaflow/parameters.py +149 -28
  124. metaflow/plugins/__init__.py +74 -5
  125. metaflow/plugins/airflow/airflow.py +40 -25
  126. metaflow/plugins/airflow/airflow_cli.py +22 -5
  127. metaflow/plugins/airflow/airflow_decorator.py +1 -1
  128. metaflow/plugins/airflow/airflow_utils.py +5 -3
  129. metaflow/plugins/airflow/sensors/base_sensor.py +4 -4
  130. metaflow/plugins/airflow/sensors/external_task_sensor.py +2 -2
  131. metaflow/plugins/airflow/sensors/s3_sensor.py +2 -2
  132. metaflow/plugins/argo/argo_client.py +78 -33
  133. metaflow/plugins/argo/argo_events.py +6 -6
  134. metaflow/plugins/argo/argo_workflows.py +2410 -527
  135. metaflow/plugins/argo/argo_workflows_cli.py +571 -121
  136. metaflow/plugins/argo/argo_workflows_decorator.py +43 -12
  137. metaflow/plugins/argo/argo_workflows_deployer.py +106 -0
  138. metaflow/plugins/argo/argo_workflows_deployer_objects.py +453 -0
  139. metaflow/plugins/argo/capture_error.py +73 -0
  140. metaflow/plugins/argo/conditional_input_paths.py +35 -0
  141. metaflow/plugins/argo/exit_hooks.py +209 -0
  142. metaflow/plugins/argo/jobset_input_paths.py +15 -0
  143. metaflow/plugins/argo/param_val.py +19 -0
  144. metaflow/plugins/aws/aws_client.py +10 -3
  145. metaflow/plugins/aws/aws_utils.py +55 -2
  146. metaflow/plugins/aws/batch/batch.py +72 -5
  147. metaflow/plugins/aws/batch/batch_cli.py +33 -10
  148. metaflow/plugins/aws/batch/batch_client.py +4 -3
  149. metaflow/plugins/aws/batch/batch_decorator.py +102 -35
  150. metaflow/plugins/aws/secrets_manager/aws_secrets_manager_secrets_provider.py +13 -10
  151. metaflow/plugins/aws/step_functions/dynamo_db_client.py +0 -3
  152. metaflow/plugins/aws/step_functions/production_token.py +1 -1
  153. metaflow/plugins/aws/step_functions/step_functions.py +65 -8
  154. metaflow/plugins/aws/step_functions/step_functions_cli.py +101 -7
  155. metaflow/plugins/aws/step_functions/step_functions_decorator.py +1 -2
  156. metaflow/plugins/aws/step_functions/step_functions_deployer.py +97 -0
  157. metaflow/plugins/aws/step_functions/step_functions_deployer_objects.py +264 -0
  158. metaflow/plugins/azure/azure_exceptions.py +1 -1
  159. metaflow/plugins/azure/azure_secret_manager_secrets_provider.py +240 -0
  160. metaflow/plugins/azure/azure_tail.py +1 -1
  161. metaflow/plugins/azure/includefile_support.py +2 -0
  162. metaflow/plugins/cards/card_cli.py +66 -30
  163. metaflow/plugins/cards/card_creator.py +25 -1
  164. metaflow/plugins/cards/card_datastore.py +21 -49
  165. metaflow/plugins/cards/card_decorator.py +132 -8
  166. metaflow/plugins/cards/card_modules/basic.py +112 -17
  167. metaflow/plugins/cards/card_modules/bundle.css +1 -1
  168. metaflow/plugins/cards/card_modules/card.py +16 -1
  169. metaflow/plugins/cards/card_modules/chevron/renderer.py +1 -1
  170. metaflow/plugins/cards/card_modules/components.py +665 -28
  171. metaflow/plugins/cards/card_modules/convert_to_native_type.py +36 -7
  172. metaflow/plugins/cards/card_modules/json_viewer.py +232 -0
  173. metaflow/plugins/cards/card_modules/main.css +1 -0
  174. metaflow/plugins/cards/card_modules/main.js +68 -49
  175. metaflow/plugins/cards/card_modules/renderer_tools.py +1 -0
  176. metaflow/plugins/cards/card_modules/test_cards.py +26 -12
  177. metaflow/plugins/cards/card_server.py +39 -14
  178. metaflow/plugins/cards/component_serializer.py +2 -9
  179. metaflow/plugins/cards/metadata.py +22 -0
  180. metaflow/plugins/catch_decorator.py +9 -0
  181. metaflow/plugins/datastores/azure_storage.py +10 -1
  182. metaflow/plugins/datastores/gs_storage.py +6 -2
  183. metaflow/plugins/datastores/local_storage.py +12 -6
  184. metaflow/plugins/datastores/spin_storage.py +12 -0
  185. metaflow/plugins/datatools/local.py +2 -0
  186. metaflow/plugins/datatools/s3/s3.py +126 -75
  187. metaflow/plugins/datatools/s3/s3op.py +254 -121
  188. metaflow/plugins/env_escape/__init__.py +3 -3
  189. metaflow/plugins/env_escape/client_modules.py +102 -72
  190. metaflow/plugins/env_escape/server.py +7 -0
  191. metaflow/plugins/env_escape/stub.py +24 -5
  192. metaflow/plugins/events_decorator.py +343 -185
  193. metaflow/plugins/exit_hook/__init__.py +0 -0
  194. metaflow/plugins/exit_hook/exit_hook_decorator.py +46 -0
  195. metaflow/plugins/exit_hook/exit_hook_script.py +52 -0
  196. metaflow/plugins/gcp/__init__.py +1 -1
  197. metaflow/plugins/gcp/gcp_secret_manager_secrets_provider.py +11 -6
  198. metaflow/plugins/gcp/gs_tail.py +10 -6
  199. metaflow/plugins/gcp/includefile_support.py +3 -0
  200. metaflow/plugins/kubernetes/kube_utils.py +108 -0
  201. metaflow/plugins/kubernetes/kubernetes.py +411 -130
  202. metaflow/plugins/kubernetes/kubernetes_cli.py +168 -36
  203. metaflow/plugins/kubernetes/kubernetes_client.py +104 -2
  204. metaflow/plugins/kubernetes/kubernetes_decorator.py +246 -88
  205. metaflow/plugins/kubernetes/kubernetes_job.py +253 -581
  206. metaflow/plugins/kubernetes/kubernetes_jobsets.py +1071 -0
  207. metaflow/plugins/kubernetes/spot_metadata_cli.py +69 -0
  208. metaflow/plugins/kubernetes/spot_monitor_sidecar.py +109 -0
  209. metaflow/plugins/logs_cli.py +359 -0
  210. metaflow/plugins/{metadata → metadata_providers}/local.py +144 -84
  211. metaflow/plugins/{metadata → metadata_providers}/service.py +103 -26
  212. metaflow/plugins/metadata_providers/spin.py +16 -0
  213. metaflow/plugins/package_cli.py +36 -24
  214. metaflow/plugins/parallel_decorator.py +128 -11
  215. metaflow/plugins/parsers.py +16 -0
  216. metaflow/plugins/project_decorator.py +51 -5
  217. metaflow/plugins/pypi/bootstrap.py +357 -105
  218. metaflow/plugins/pypi/conda_decorator.py +82 -81
  219. metaflow/plugins/pypi/conda_environment.py +187 -52
  220. metaflow/plugins/pypi/micromamba.py +157 -47
  221. metaflow/plugins/pypi/parsers.py +268 -0
  222. metaflow/plugins/pypi/pip.py +88 -13
  223. metaflow/plugins/pypi/pypi_decorator.py +37 -1
  224. metaflow/plugins/pypi/utils.py +48 -2
  225. metaflow/plugins/resources_decorator.py +2 -2
  226. metaflow/plugins/secrets/__init__.py +3 -0
  227. metaflow/plugins/secrets/secrets_decorator.py +26 -181
  228. metaflow/plugins/secrets/secrets_func.py +49 -0
  229. metaflow/plugins/secrets/secrets_spec.py +101 -0
  230. metaflow/plugins/secrets/utils.py +74 -0
  231. metaflow/plugins/tag_cli.py +4 -7
  232. metaflow/plugins/test_unbounded_foreach_decorator.py +41 -6
  233. metaflow/plugins/timeout_decorator.py +3 -3
  234. metaflow/plugins/uv/__init__.py +0 -0
  235. metaflow/plugins/uv/bootstrap.py +128 -0
  236. metaflow/plugins/uv/uv_environment.py +72 -0
  237. metaflow/procpoll.py +1 -1
  238. metaflow/pylint_wrapper.py +5 -1
  239. metaflow/runner/__init__.py +0 -0
  240. metaflow/runner/click_api.py +717 -0
  241. metaflow/runner/deployer.py +470 -0
  242. metaflow/runner/deployer_impl.py +201 -0
  243. metaflow/runner/metaflow_runner.py +714 -0
  244. metaflow/runner/nbdeploy.py +132 -0
  245. metaflow/runner/nbrun.py +225 -0
  246. metaflow/runner/subprocess_manager.py +650 -0
  247. metaflow/runner/utils.py +335 -0
  248. metaflow/runtime.py +1078 -260
  249. metaflow/sidecar/sidecar_worker.py +1 -1
  250. metaflow/system/__init__.py +5 -0
  251. metaflow/system/system_logger.py +85 -0
  252. metaflow/system/system_monitor.py +108 -0
  253. metaflow/system/system_utils.py +19 -0
  254. metaflow/task.py +521 -225
  255. metaflow/tracing/__init__.py +7 -7
  256. metaflow/tracing/span_exporter.py +31 -38
  257. metaflow/tracing/tracing_modules.py +38 -43
  258. metaflow/tuple_util.py +27 -0
  259. metaflow/user_configs/__init__.py +0 -0
  260. metaflow/user_configs/config_options.py +563 -0
  261. metaflow/user_configs/config_parameters.py +598 -0
  262. metaflow/user_decorators/__init__.py +0 -0
  263. metaflow/user_decorators/common.py +144 -0
  264. metaflow/user_decorators/mutable_flow.py +512 -0
  265. metaflow/user_decorators/mutable_step.py +424 -0
  266. metaflow/user_decorators/user_flow_decorator.py +264 -0
  267. metaflow/user_decorators/user_step_decorator.py +749 -0
  268. metaflow/util.py +243 -27
  269. metaflow/vendor.py +23 -7
  270. metaflow/version.py +1 -1
  271. ob_metaflow-2.19.7.1rc0.data/data/share/metaflow/devtools/Makefile +355 -0
  272. ob_metaflow-2.19.7.1rc0.data/data/share/metaflow/devtools/Tiltfile +726 -0
  273. ob_metaflow-2.19.7.1rc0.data/data/share/metaflow/devtools/pick_services.sh +105 -0
  274. ob_metaflow-2.19.7.1rc0.dist-info/METADATA +87 -0
  275. ob_metaflow-2.19.7.1rc0.dist-info/RECORD +445 -0
  276. {ob_metaflow-2.11.13.1.dist-info → ob_metaflow-2.19.7.1rc0.dist-info}/WHEEL +1 -1
  277. {ob_metaflow-2.11.13.1.dist-info → ob_metaflow-2.19.7.1rc0.dist-info}/entry_points.txt +1 -0
  278. metaflow/_vendor/v3_5/__init__.py +0 -1
  279. metaflow/_vendor/v3_5/importlib_metadata/__init__.py +0 -644
  280. metaflow/_vendor/v3_5/importlib_metadata/_compat.py +0 -152
  281. metaflow/package.py +0 -188
  282. ob_metaflow-2.11.13.1.dist-info/METADATA +0 -85
  283. ob_metaflow-2.11.13.1.dist-info/RECORD +0 -308
  284. /metaflow/_vendor/{v3_5/zipp.py → zipp.py} +0 -0
  285. /metaflow/{metadata → metadata_provider}/__init__.py +0 -0
  286. /metaflow/{metadata → metadata_provider}/util.py +0 -0
  287. /metaflow/plugins/{metadata → metadata_providers}/__init__.py +0 -0
  288. {ob_metaflow-2.11.13.1.dist-info → ob_metaflow-2.19.7.1rc0.dist-info/licenses}/LICENSE +0 -0
  289. {ob_metaflow-2.11.13.1.dist-info → ob_metaflow-2.19.7.1rc0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,1229 @@
1
+ from __future__ import annotations
2
+
3
+ import ast
4
+ import builtins
5
+ import sys
6
+ import typing
7
+ from ast import (
8
+ AST,
9
+ Add,
10
+ AnnAssign,
11
+ Assign,
12
+ AsyncFunctionDef,
13
+ Attribute,
14
+ AugAssign,
15
+ BinOp,
16
+ BitAnd,
17
+ BitOr,
18
+ BitXor,
19
+ Call,
20
+ ClassDef,
21
+ Constant,
22
+ Dict,
23
+ Div,
24
+ Expr,
25
+ Expression,
26
+ FloorDiv,
27
+ FunctionDef,
28
+ If,
29
+ Import,
30
+ ImportFrom,
31
+ Index,
32
+ List,
33
+ Load,
34
+ LShift,
35
+ MatMult,
36
+ Mod,
37
+ Module,
38
+ Mult,
39
+ Name,
40
+ NamedExpr,
41
+ NodeTransformer,
42
+ NodeVisitor,
43
+ Pass,
44
+ Pow,
45
+ Return,
46
+ RShift,
47
+ Starred,
48
+ Store,
49
+ Sub,
50
+ Subscript,
51
+ Tuple,
52
+ Yield,
53
+ YieldFrom,
54
+ alias,
55
+ copy_location,
56
+ expr,
57
+ fix_missing_locations,
58
+ keyword,
59
+ walk,
60
+ )
61
+ from collections import defaultdict
62
+ from collections.abc import Generator, Sequence
63
+ from contextlib import contextmanager
64
+ from copy import deepcopy
65
+ from dataclasses import dataclass, field
66
+ from typing import Any, ClassVar, cast, overload
67
+
68
# Fully qualified names of the generator/iterator types a generator function's
# return annotation may use; for these, the subscript items are read as the
# yield, send and return types
generator_names = tuple(
    f"{module}.{type_name}"
    for type_name in (
        "Generator",
        "Iterator",
        "Iterable",
        "AsyncIterator",
        "AsyncIterable",
        "AsyncGenerator",
    )
    for module in ("typing", "collections.abc")
)
# Spellings of typing.Any; annotations matching these are not checked
anytype_names = tuple(
    f"{module}.Any" for module in ("typing", "typing_extensions")
)
# Spellings of typing.Literal; their subscripts are never parsed as code
literal_names = tuple(
    f"{module}.Literal" for module in ("typing", "typing_extensions")
)
# Spellings of typing.Annotated; only the first subscript item is an annotation
annotated_names = tuple(
    f"{module}.Annotated" for module in ("typing", "typing_extensions")
)
# Qualified names of decorators that presumably opt a function out of
# instrumentation (their use is outside this part of the file)
ignore_decorators = ("typing.no_type_check", "typeguard.typeguard_ignore")
# Augmented-assignment operator node type -> name of the matching in-place
# function in the standard library ``operator`` module ("i" + operator name)
aug_assign_functions = {
    op_type: f"i{suffix}"
    for op_type, suffix in (
        (Add, "add"),
        (Sub, "sub"),
        (Mult, "mul"),
        (MatMult, "matmul"),
        (Div, "truediv"),
        (FloorDiv, "floordiv"),
        (Mod, "mod"),
        (Pow, "pow"),
        (LShift, "lshift"),
        (RShift, "rshift"),
        (BitAnd, "and"),
        (BitXor, "xor"),
        (BitOr, "or"),
    )
}
113
+
114
+
115
@dataclass
class TransformMemo:
    """Bookkeeping for one scope (module, class or function) being transformed.

    ``parent`` is the memo of the enclosing scope (``None`` for the module) and
    ``path`` is the tuple of enclosing class/function names. Name lookups such
    as :meth:`get_unused_name`, :meth:`is_ignored_name` and
    :meth:`name_matches` also consult the ancestor memos.
    """

    node: Module | ClassDef | FunctionDef | AsyncFunctionDef | None
    parent: TransformMemo | None
    path: tuple[str, ...]
    # Dotted qualified name of ``node`` (e.g. "outer.<locals>.inner"),
    # computed in __post_init__
    joined_path: Constant = field(init=False)
    return_annotation: expr | None = None
    yield_annotation: expr | None = None
    send_annotation: expr | None = None
    is_async: bool = False
    # Names bound in this scope; names chosen for injected code avoid them
    local_names: set[str] = field(init=False, default_factory=set)
    # Local name -> fully qualified name it refers to
    imported_names: dict[str, str] = field(init=False, default_factory=dict)
    # Names (or attribute roots) that is_ignored_name() reports as ignored
    ignored_names: set[str] = field(init=False, default_factory=set)
    # Module -> {imported name -> Name node used to reference it};
    # insert_imports() turns this into "from module import ..." statements
    load_names: defaultdict[str, dict[str, Name]] = field(
        init=False, default_factory=lambda: defaultdict(dict)
    )
    has_yield_expressions: bool = field(init=False, default=False)
    has_return_expressions: bool = field(init=False, default=False)
    memo_var_name: Name | None = field(init=False, default=None)
    should_instrument: bool = field(init=False, default=True)
    variable_annotations: dict[str, expr] = field(init=False, default_factory=dict)
    # Keyword overrides emitted by get_config_keywords()
    configuration_overrides: dict[str, Any] = field(init=False, default_factory=dict)
    # Index into node.body at which injected statements are inserted
    code_inject_index: int = field(init=False, default=0)

    def __post_init__(self) -> None:
        """Compute ``joined_path`` and ``code_inject_index`` for ``node``."""
        elements: list[str] = []
        memo = self
        while isinstance(memo.node, (ClassDef, FunctionDef, AsyncFunctionDef)):
            elements.insert(0, memo.node.name)
            if not memo.parent:
                break

            memo = memo.parent
            # Names defined inside a function live in its "<locals>" namespace
            if isinstance(memo.node, (FunctionDef, AsyncFunctionDef)):
                elements.insert(0, "<locals>")

        self.joined_path = Constant(".".join(elements))

        # Figure out where to insert instrumentation code: after any leading
        # "from __future__" imports and docstrings
        if self.node:
            for index, child in enumerate(self.node.body):
                if isinstance(child, ImportFrom) and child.module == "__future__":
                    # (module only) __future__ imports must come first
                    continue
                elif (
                    isinstance(child, Expr)
                    and isinstance(child.value, Constant)
                    and isinstance(child.value.value, str)
                ):
                    continue  # docstring

                self.code_inject_index = index
                break
            else:
                # Every statement was skipped (or the body is empty): inject
                # at the end so injected code never precedes a __future__
                # import (SyntaxError) or displaces the docstring
                self.code_inject_index = len(self.node.body)

    def get_unused_name(self, name: str) -> str:
        """Return ``name``, suffixed with underscores as needed to be unique.

        The result is not in ``local_names`` of this memo or any ancestor, and
        it is added to this memo's ``local_names``.
        """
        memo: TransformMemo | None = self
        while memo is not None:
            if name in memo.local_names:
                # Collision: lengthen the name and re-check the whole chain
                memo = self
                name += "_"
            else:
                memo = memo.parent

        self.local_names.add(name)
        return name

    def is_ignored_name(self, expression: expr | Expr | None) -> bool:
        """Return True if the root name of ``expression`` is an ignored name.

        Only a bare ``Name`` or an ``Attribute`` directly on a ``Name`` is
        considered (an ``Expr`` statement is unwrapped first); the name is
        looked up in ``ignored_names`` of this memo and its ancestors.
        """
        top_expression = (
            expression.value if isinstance(expression, Expr) else expression
        )

        if isinstance(top_expression, Attribute) and isinstance(
            top_expression.value, Name
        ):
            name = top_expression.value.id
        elif isinstance(top_expression, Name):
            name = top_expression.id
        else:
            return False

        memo: TransformMemo | None = self
        while memo is not None:
            if name in memo.ignored_names:
                return True

            memo = memo.parent

        return False

    def get_memo_name(self) -> Name:
        """Return the ``Name`` node for the injected ``memo`` variable.

        The node is created on first use and reused afterwards.
        """
        if not self.memo_var_name:
            self.memo_var_name = Name(id="memo", ctx=Load())

        return self.memo_var_name

    def get_import(self, module: str, name: str) -> Name:
        """Return a ``Name`` node referring to ``module.name``.

        If the name was already requested, the same node is returned. If it is
        already imported under its own name, a new reference to it is
        returned. Otherwise a collision-free alias is chosen and registered in
        ``load_names`` so that :meth:`insert_imports` emits the import.
        """
        if module in self.load_names and name in self.load_names[module]:
            return self.load_names[module][name]

        qualified_name = f"{module}.{name}"
        if name in self.imported_names and self.imported_names[name] == qualified_name:
            return Name(id=name, ctx=Load())

        # Named alias_name so the ast ``alias`` node class is not shadowed
        alias_name = self.get_unused_name(name)
        node = self.load_names[module][name] = Name(id=alias_name, ctx=Load())
        self.imported_names[name] = qualified_name
        return node

    def insert_imports(self, node: Module | FunctionDef | AsyncFunctionDef) -> None:
        """Insert imports needed by injected code.

        One ``from <module> import ...`` statement per module in
        ``load_names`` is inserted into ``node.body`` at
        ``code_inject_index``, using ``as`` only where an alias was needed.
        """
        if not self.load_names:
            return

        # Insert imports after any "from __future__ ..." imports and any docstring
        for modulename, names in self.load_names.items():
            aliases = [
                alias(orig_name, new_name.id if orig_name != new_name.id else None)
                for orig_name, new_name in sorted(names.items())
            ]
            node.body.insert(self.code_inject_index, ImportFrom(modulename, aliases, 0))

    def name_matches(self, expression: expr | Expr | None, *names: str) -> bool:
        """Return True if ``expression`` refers to one of the qualified ``names``.

        An ``Expr`` statement, a subscript's value or a call's callee is
        unwrapped first. The leading name is translated through
        ``imported_names`` (or prefixed with ``builtins.`` for builtins) before
        comparing the dotted path; on a miss the parent memo is asked.
        """
        if expression is None:
            return False

        path: list[str] = []
        top_expression = (
            expression.value if isinstance(expression, Expr) else expression
        )

        if isinstance(top_expression, Subscript):
            top_expression = top_expression.value
        elif isinstance(top_expression, Call):
            top_expression = top_expression.func

        while isinstance(top_expression, Attribute):
            path.insert(0, top_expression.attr)
            top_expression = top_expression.value

        if not isinstance(top_expression, Name):
            return False

        if top_expression.id in self.imported_names:
            translated = self.imported_names[top_expression.id]
        elif hasattr(builtins, top_expression.id):
            translated = "builtins." + top_expression.id
        else:
            translated = top_expression.id

        path.insert(0, translated)
        joined_path = ".".join(path)
        if joined_path in names:
            return True
        elif self.parent:
            return self.parent.name_matches(expression, *names)
        else:
            return False

    def get_config_keywords(self) -> list[keyword]:
        """Build ``keyword`` nodes from the configuration overrides.

        When the parent scope is a class, its overrides form the base and this
        memo's overrides take precedence over them.
        """
        if self.parent and isinstance(self.parent.node, ClassDef):
            overrides = self.parent.configuration_overrides.copy()
        else:
            overrides = {}

        overrides.update(self.configuration_overrides)
        return [keyword(key, value) for key, value in overrides.items()]
281
+
282
+
283
class NameCollector(NodeVisitor):
    """Gather names bound by imports, plain assignments and walrus targets.

    The visited tree is walked recursively, except that the bodies of
    ``def`` and ``class`` statements are skipped (``async def`` bodies are
    still walked). For imports the ``as`` alias is recorded when present,
    otherwise the imported name exactly as written (e.g. ``"os.path"``).
    Only bare-name assignment targets are recorded, and assigned values are
    not searched.
    """

    def __init__(self) -> None:
        self.names: set[str] = set()

    def _add_import_aliases(self, node: Import | ImportFrom) -> None:
        # Shared by both import forms: prefer the "as" name over the original
        self.names.update(
            imported.asname or imported.name for imported in node.names
        )

    def visit_Import(self, node: Import) -> None:
        self._add_import_aliases(node)

    def visit_ImportFrom(self, node: ImportFrom) -> None:
        self._add_import_aliases(node)

    def visit_Assign(self, node: Assign) -> None:
        # Unpacking targets (tuples/lists) are not recorded
        self.names.update(
            target.id for target in node.targets if isinstance(target, Name)
        )

    def visit_NamedExpr(self, node: NamedExpr) -> Any:
        target = node.target
        if isinstance(target, Name):
            self.names.add(target.id)

    def visit_FunctionDef(self, node: FunctionDef) -> None:
        # Nested function scope: do not descend
        return None

    def visit_ClassDef(self, node: ClassDef) -> None:
        # Nested class scope: do not descend
        return None
309
+
310
+
311
class GeneratorDetector(NodeVisitor):
    """Detects if a function node is a generator function.

    Visit a function definition node; afterwards ``contains_yields`` is True
    if a ``yield`` or ``yield from`` expression occurs in it outside the
    bodies of nested ``def``, ``async def`` and ``class`` statements.
    """

    # Set once a yield / yield from expression has been seen
    contains_yields: bool = False
    # True while the outermost (root) function is being walked
    in_root_function: bool = False

    def visit_Yield(self, node: Yield) -> Any:
        self.contains_yields = True

    def visit_YieldFrom(self, node: YieldFrom) -> Any:
        self.contains_yields = True

    def visit_ClassDef(self, node: ClassDef) -> Any:
        # Class bodies are separate scopes and are never searched
        return None

    def visit_FunctionDef(self, node: FunctionDef | AsyncFunctionDef) -> Any:
        # A function met while already inside the root is a nested scope
        if self.in_root_function:
            return None

        self.in_root_function = True
        self.generic_visit(node)
        self.in_root_function = False

    def visit_AsyncFunctionDef(self, node: AsyncFunctionDef) -> Any:
        # Same root/nested handling as regular functions
        return self.visit_FunctionDef(node)
334
+
335
+
336
class AnnotationTransformer(NodeTransformer):
    """Rewrite a type annotation expression into a form checkable at run time.

    Visiting returns ``None`` for annotations that should not be checked at
    all (for example ``Any`` or a name the memo marks as ignored). String
    annotations are parsed into expressions. On older interpreters, syntax
    the runtime cannot evaluate is rewritten: ``X | Y`` becomes
    ``typing.Union[X, Y]`` before Python 3.10, and builtin generics such as
    ``list`` become their ``typing`` counterparts before Python 3.9.
    ``Literal[...]`` expressions and call expressions are left untouched.
    """

    # Builtin generics mapped to the (module, name) of the typing alias that
    # replaces them on Python < 3.9
    type_substitutions: ClassVar[dict[str, tuple[str, str]]] = {
        "builtins.dict": ("typing", "Dict"),
        "builtins.list": ("typing", "List"),
        "builtins.tuple": ("typing", "Tuple"),
        "builtins.set": ("typing", "Set"),
        "builtins.frozenset": ("typing", "FrozenSet"),
    }

    def __init__(self, transformer: TypeguardTransformer):
        self.transformer = transformer
        self._memo = transformer._memo
        # Number of visit() calls in progress; it is back to 0 right after
        # the outermost call's node has been visited
        self._level = 0

    def visit(self, node: AST) -> Any:
        """Visit ``node``; return the rewritten node, or ``None`` to erase it."""
        # Don't process Literals
        if isinstance(node, expr) and self._memo.name_matches(node, *literal_names):
            return node

        self._level += 1
        new_node = super().visit(node)
        self._level -= 1

        # An Expression wrapper whose body was erased leaves nothing to check
        if isinstance(new_node, Expression) and not hasattr(new_node, "body"):
            return None

        # Return None if this new node matches a variation of typing.Any
        if (
            self._level == 0
            and isinstance(new_node, expr)
            and self._memo.name_matches(new_node, *anytype_names)
        ):
            return None

        return new_node

    def visit_BinOp(self, node: BinOp) -> Any:
        """Handle ``X | Y`` unions; other binary operations pass through."""
        self.generic_visit(node)

        if isinstance(node.op, BitOr):
            # If either branch of the BinOp has been transformed to `None`, it means
            # that a type in the union was ignored, so the entire annotation should be
            # ignored
            if not hasattr(node, "left") or not hasattr(node, "right"):
                return None

            # Return Any if either side is Any
            if self._memo.name_matches(node.left, *anytype_names):
                return node.left
            elif self._memo.name_matches(node.right, *anytype_names):
                return node.right

            # The "|" operator on types is only evaluable from Python 3.10 on
            if sys.version_info < (3, 10):
                union_name = self.transformer._get_import("typing", "Union")
                return Subscript(
                    value=union_name,
                    slice=Index(
                        Tuple(elts=[node.left, node.right], ctx=Load()), ctx=Load()
                    ),
                    ctx=Load(),
                )

        return node

    def visit_Attribute(self, node: Attribute) -> Any:
        """Erase ``name.attr`` when ``name`` is ignored; otherwise keep it."""
        if self._memo.is_ignored_name(node):
            return None

        return node

    def visit_Subscript(self, node: Subscript) -> Any:
        """Transform the items of a subscripted annotation like ``X[A, B]``."""
        if self._memo.is_ignored_name(node.value):
            return None

        # The subscript of typing(_extensions).Literal can be any arbitrary string, so
        # don't try to evaluate it as code
        if node.slice:
            if isinstance(node.slice, Index):
                # Python 3.8
                slice_value = node.slice.value  # type: ignore[attr-defined]
            else:
                slice_value = node.slice

            if isinstance(slice_value, Tuple):
                if self._memo.name_matches(node.value, *annotated_names):
                    # Only treat the first argument to typing.Annotated as a potential
                    # forward reference
                    items = cast(
                        typing.List[expr],
                        [self.visit(slice_value.elts[0])] + slice_value.elts[1:],
                    )
                else:
                    items = cast(
                        typing.List[expr],
                        [self.visit(item) for item in slice_value.elts],
                    )

                # If this is a Union and any of the items is Any, erase the entire
                # annotation
                if self._memo.name_matches(node.value, "typing.Union") and any(
                    item is None
                    or (
                        isinstance(item, expr)
                        and self._memo.name_matches(item, *anytype_names)
                    )
                    for item in items
                ):
                    return None

                # If all items in the subscript were Any, erase the subscript entirely
                if all(item is None for item in items):
                    return node.value

                # Erased items inside a non-union subscript become explicit Any
                for index, item in enumerate(items):
                    if item is None:
                        items[index] = self.transformer._get_import("typing", "Any")

                slice_value.elts = items
            else:
                self.generic_visit(node)

                # If the transformer erased the slice entirely, just return the node
                # value without the subscript (unless it's Optional, in which case erase
                # the node entirely
                if self._memo.name_matches(
                    node.value, "typing.Optional"
                ) and not hasattr(node, "slice"):
                    return None
                if sys.version_info >= (3, 9) and not hasattr(node, "slice"):
                    return node.value
                elif sys.version_info < (3, 9) and not hasattr(node.slice, "value"):
                    return node.value

        return node

    def visit_Name(self, node: Name) -> Any:
        """Erase ignored names; substitute builtin generics before Python 3.9."""
        if self._memo.is_ignored_name(node):
            return None

        if sys.version_info < (3, 9):
            for typename, substitute in self.type_substitutions.items():
                if self._memo.name_matches(node, typename):
                    new_node = self.transformer._get_import(*substitute)
                    return copy_location(new_node, node)

        return node

    def visit_Call(self, node: Call) -> Any:
        # Don't recurse into calls
        return node

    def visit_Constant(self, node: Constant) -> Any:
        """Parse a string constant (forward reference) and visit the result.

        Non-string constants are returned unchanged. A string that is not a
        valid Python expression cannot be checked, so it is erased (``None``)
        instead of aborting the whole transformation.
        """
        if isinstance(node.value, str):
            try:
                expression = ast.parse(node.value, mode="eval")
            except (SyntaxError, ValueError):
                # Annotations may legally be arbitrary strings; one that does
                # not parse (ValueError: NUL bytes on older Pythons) is
                # treated as an unchecked annotation
                return None

            new_node = self.visit(expression)
            if new_node:
                return copy_location(new_node.body, node)
            else:
                return None

        return node
497
+
498
+
499
+ class TypeguardTransformer(NodeTransformer):
500
def __init__(
    self, target_path: Sequence[str] | None = None, target_lineno: int | None = None
) -> None:
    """Set up the transformer state.

    :param target_path: qualified name components of the one function to
        instrument; when ``None`` (or empty), every function is instrumented
    :param target_lineno: line number of the target function, kept as
        ``self.target_lineno``
    """
    # Normalise to a tuple so it compares equal to TransformMemo.path tuples
    self._target_path = None if not target_path else tuple(target_path)
    self.target_lineno = target_lineno
    self.target_node: FunctionDef | AsyncFunctionDef | None = None
    self.names_used_in_annotations: set[str] = set()

    # The module-level memo is also the current memo to begin with
    root_memo = TransformMemo(None, None, ())
    self._module_memo = root_memo
    self._memo = root_memo
508
+
509
+ def generic_visit(self, node: AST) -> AST:
510
+ has_non_empty_body_initially = bool(getattr(node, "body", None))
511
+ initial_type = type(node)
512
+
513
+ node = super().generic_visit(node)
514
+
515
+ if (
516
+ type(node) is initial_type
517
+ and has_non_empty_body_initially
518
+ and hasattr(node, "body")
519
+ and not node.body
520
+ ):
521
+ # If we have still the same node type after transformation
522
+ # but we've optimised it's body away, we add a `pass` statement.
523
+ node.body = [Pass()]
524
+
525
+ return node
526
+
527
+ @contextmanager
528
+ def _use_memo(
529
+ self, node: ClassDef | FunctionDef | AsyncFunctionDef
530
+ ) -> Generator[None, Any, None]:
531
+ new_memo = TransformMemo(node, self._memo, self._memo.path + (node.name,))
532
+ old_memo = self._memo
533
+ self._memo = new_memo
534
+
535
+ if isinstance(node, (FunctionDef, AsyncFunctionDef)):
536
+ new_memo.should_instrument = (
537
+ self._target_path is None or new_memo.path == self._target_path
538
+ )
539
+ if new_memo.should_instrument:
540
+ # Check if the function is a generator function
541
+ detector = GeneratorDetector()
542
+ detector.visit(node)
543
+
544
+ # Extract yield, send and return types where possible from a subscripted
545
+ # annotation like Generator[int, str, bool]
546
+ return_annotation = deepcopy(node.returns)
547
+ if detector.contains_yields and new_memo.name_matches(
548
+ return_annotation, *generator_names
549
+ ):
550
+ if isinstance(return_annotation, Subscript):
551
+ annotation_slice = return_annotation.slice
552
+
553
+ # Python < 3.9
554
+ if isinstance(annotation_slice, Index):
555
+ annotation_slice = (
556
+ annotation_slice.value # type: ignore[attr-defined]
557
+ )
558
+
559
+ if isinstance(annotation_slice, Tuple):
560
+ items = annotation_slice.elts
561
+ else:
562
+ items = [annotation_slice]
563
+
564
+ if len(items) > 0:
565
+ new_memo.yield_annotation = self._convert_annotation(
566
+ items[0]
567
+ )
568
+
569
+ if len(items) > 1:
570
+ new_memo.send_annotation = self._convert_annotation(
571
+ items[1]
572
+ )
573
+
574
+ if len(items) > 2:
575
+ new_memo.return_annotation = self._convert_annotation(
576
+ items[2]
577
+ )
578
+ else:
579
+ new_memo.return_annotation = self._convert_annotation(
580
+ return_annotation
581
+ )
582
+
583
+ if isinstance(node, AsyncFunctionDef):
584
+ new_memo.is_async = True
585
+
586
+ yield
587
+ self._memo = old_memo
588
+
589
+ def _get_import(self, module: str, name: str) -> Name:
590
+ memo = self._memo if self._target_path else self._module_memo
591
+ return memo.get_import(module, name)
592
+
593
+ @overload
594
+ def _convert_annotation(self, annotation: None) -> None: ...
595
+
596
+ @overload
597
+ def _convert_annotation(self, annotation: expr) -> expr: ...
598
+
599
+ def _convert_annotation(self, annotation: expr | None) -> expr | None:
600
+ if annotation is None:
601
+ return None
602
+
603
+ # Convert PEP 604 unions (x | y) and generic built-in collections where
604
+ # necessary, and undo forward references
605
+ new_annotation = cast(expr, AnnotationTransformer(self).visit(annotation))
606
+ if isinstance(new_annotation, expr):
607
+ new_annotation = ast.copy_location(new_annotation, annotation)
608
+
609
+ # Store names used in the annotation
610
+ names = {node.id for node in walk(new_annotation) if isinstance(node, Name)}
611
+ self.names_used_in_annotations.update(names)
612
+
613
+ return new_annotation
614
+
615
+ def visit_Name(self, node: Name) -> Name:
616
+ self._memo.local_names.add(node.id)
617
+ return node
618
+
619
+ def visit_Module(self, node: Module) -> Module:
620
+ self._module_memo = self._memo = TransformMemo(node, None, ())
621
+ self.generic_visit(node)
622
+ self._module_memo.insert_imports(node)
623
+
624
+ fix_missing_locations(node)
625
+ return node
626
+
627
+ def visit_Import(self, node: Import) -> Import:
628
+ for name in node.names:
629
+ self._memo.local_names.add(name.asname or name.name)
630
+ self._memo.imported_names[name.asname or name.name] = name.name
631
+
632
+ return node
633
+
634
+ def visit_ImportFrom(self, node: ImportFrom) -> ImportFrom:
635
+ for name in node.names:
636
+ if name.name != "*":
637
+ alias = name.asname or name.name
638
+ self._memo.local_names.add(alias)
639
+ self._memo.imported_names[alias] = f"{node.module}.{name.name}"
640
+
641
+ return node
642
+
643
+ def visit_ClassDef(self, node: ClassDef) -> ClassDef | None:
644
+ self._memo.local_names.add(node.name)
645
+
646
+ # Eliminate top level classes not belonging to the target path
647
+ if (
648
+ self._target_path is not None
649
+ and not self._memo.path
650
+ and node.name != self._target_path[0]
651
+ ):
652
+ return None
653
+
654
+ with self._use_memo(node):
655
+ for decorator in node.decorator_list.copy():
656
+ if self._memo.name_matches(decorator, "typeguard.typechecked"):
657
+ # Remove the decorator to prevent duplicate instrumentation
658
+ node.decorator_list.remove(decorator)
659
+
660
+ # Store any configuration overrides
661
+ if isinstance(decorator, Call) and decorator.keywords:
662
+ self._memo.configuration_overrides.update(
663
+ {kw.arg: kw.value for kw in decorator.keywords if kw.arg}
664
+ )
665
+
666
+ self.generic_visit(node)
667
+ return node
668
+
669
+ def visit_FunctionDef(
670
+ self, node: FunctionDef | AsyncFunctionDef
671
+ ) -> FunctionDef | AsyncFunctionDef | None:
672
+ """
673
+ Injects type checks for function arguments, and for a return of None if the
674
+ function is annotated to return something else than Any or None, and the body
675
+ ends without an explicit "return".
676
+
677
+ """
678
+ self._memo.local_names.add(node.name)
679
+
680
+ # Eliminate top level functions not belonging to the target path
681
+ if (
682
+ self._target_path is not None
683
+ and not self._memo.path
684
+ and node.name != self._target_path[0]
685
+ ):
686
+ return None
687
+
688
+ # Skip instrumentation if we're instrumenting the whole module and the function
689
+ # contains either @no_type_check or @typeguard_ignore
690
+ if self._target_path is None:
691
+ for decorator in node.decorator_list:
692
+ if self._memo.name_matches(decorator, *ignore_decorators):
693
+ return node
694
+
695
+ with self._use_memo(node):
696
+ arg_annotations: dict[str, Any] = {}
697
+ if self._target_path is None or self._memo.path == self._target_path:
698
+ # Find line number we're supposed to match against
699
+ if node.decorator_list:
700
+ first_lineno = node.decorator_list[0].lineno
701
+ else:
702
+ first_lineno = node.lineno
703
+
704
+ for decorator in node.decorator_list.copy():
705
+ if self._memo.name_matches(decorator, "typing.overload"):
706
+ # Remove overloads entirely
707
+ return None
708
+ elif self._memo.name_matches(decorator, "typeguard.typechecked"):
709
+ # Remove the decorator to prevent duplicate instrumentation
710
+ node.decorator_list.remove(decorator)
711
+
712
+ # Store any configuration overrides
713
+ if isinstance(decorator, Call) and decorator.keywords:
714
+ self._memo.configuration_overrides = {
715
+ kw.arg: kw.value for kw in decorator.keywords if kw.arg
716
+ }
717
+
718
+ if self.target_lineno == first_lineno:
719
+ assert self.target_node is None
720
+ self.target_node = node
721
+ if node.decorator_list:
722
+ self.target_lineno = node.decorator_list[0].lineno
723
+ else:
724
+ self.target_lineno = node.lineno
725
+
726
+ all_args = node.args.args + node.args.kwonlyargs + node.args.posonlyargs
727
+
728
+ # Ensure that any type shadowed by the positional or keyword-only
729
+ # argument names are ignored in this function
730
+ for arg in all_args:
731
+ self._memo.ignored_names.add(arg.arg)
732
+
733
+ # Ensure that any type shadowed by the variable positional argument name
734
+ # (e.g. "args" in *args) is ignored this function
735
+ if node.args.vararg:
736
+ self._memo.ignored_names.add(node.args.vararg.arg)
737
+
738
+ # Ensure that any type shadowed by the variable keywrod argument name
739
+ # (e.g. "kwargs" in *kwargs) is ignored this function
740
+ if node.args.kwarg:
741
+ self._memo.ignored_names.add(node.args.kwarg.arg)
742
+
743
+ for arg in all_args:
744
+ annotation = self._convert_annotation(deepcopy(arg.annotation))
745
+ if annotation:
746
+ arg_annotations[arg.arg] = annotation
747
+
748
+ if node.args.vararg:
749
+ annotation_ = self._convert_annotation(node.args.vararg.annotation)
750
+ if annotation_:
751
+ if sys.version_info >= (3, 9):
752
+ container = Name("tuple", ctx=Load())
753
+ else:
754
+ container = self._get_import("typing", "Tuple")
755
+
756
+ subscript_slice: Tuple | Index = Tuple(
757
+ [
758
+ annotation_,
759
+ Constant(Ellipsis),
760
+ ],
761
+ ctx=Load(),
762
+ )
763
+ if sys.version_info < (3, 9):
764
+ subscript_slice = Index(subscript_slice, ctx=Load())
765
+
766
+ arg_annotations[node.args.vararg.arg] = Subscript(
767
+ container, subscript_slice, ctx=Load()
768
+ )
769
+
770
+ if node.args.kwarg:
771
+ annotation_ = self._convert_annotation(node.args.kwarg.annotation)
772
+ if annotation_:
773
+ if sys.version_info >= (3, 9):
774
+ container = Name("dict", ctx=Load())
775
+ else:
776
+ container = self._get_import("typing", "Dict")
777
+
778
+ subscript_slice = Tuple(
779
+ [
780
+ Name("str", ctx=Load()),
781
+ annotation_,
782
+ ],
783
+ ctx=Load(),
784
+ )
785
+ if sys.version_info < (3, 9):
786
+ subscript_slice = Index(subscript_slice, ctx=Load())
787
+
788
+ arg_annotations[node.args.kwarg.arg] = Subscript(
789
+ container, subscript_slice, ctx=Load()
790
+ )
791
+
792
+ if arg_annotations:
793
+ self._memo.variable_annotations.update(arg_annotations)
794
+
795
+ self.generic_visit(node)
796
+
797
+ if arg_annotations:
798
+ annotations_dict = Dict(
799
+ keys=[Constant(key) for key in arg_annotations.keys()],
800
+ values=[
801
+ Tuple([Name(key, ctx=Load()), annotation], ctx=Load())
802
+ for key, annotation in arg_annotations.items()
803
+ ],
804
+ )
805
+ func_name = self._get_import(
806
+ "typeguard._functions", "check_argument_types"
807
+ )
808
+ args = [
809
+ self._memo.joined_path,
810
+ annotations_dict,
811
+ self._memo.get_memo_name(),
812
+ ]
813
+ node.body.insert(
814
+ self._memo.code_inject_index, Expr(Call(func_name, args, []))
815
+ )
816
+
817
+ # Add a checked "return None" to the end if there's no explicit return
818
+ # Skip if the return annotation is None or Any
819
+ if (
820
+ self._memo.return_annotation
821
+ and (not self._memo.is_async or not self._memo.has_yield_expressions)
822
+ and not isinstance(node.body[-1], Return)
823
+ and (
824
+ not isinstance(self._memo.return_annotation, Constant)
825
+ or self._memo.return_annotation.value is not None
826
+ )
827
+ ):
828
+ func_name = self._get_import(
829
+ "typeguard._functions", "check_return_type"
830
+ )
831
+ return_node = Return(
832
+ Call(
833
+ func_name,
834
+ [
835
+ self._memo.joined_path,
836
+ Constant(None),
837
+ self._memo.return_annotation,
838
+ self._memo.get_memo_name(),
839
+ ],
840
+ [],
841
+ )
842
+ )
843
+
844
+ # Replace a placeholder "pass" at the end
845
+ if isinstance(node.body[-1], Pass):
846
+ copy_location(return_node, node.body[-1])
847
+ del node.body[-1]
848
+
849
+ node.body.append(return_node)
850
+
851
+ # Insert code to create the call memo, if it was ever needed for this
852
+ # function
853
+ if self._memo.memo_var_name:
854
+ memo_kwargs: dict[str, Any] = {}
855
+ if self._memo.parent and isinstance(self._memo.parent.node, ClassDef):
856
+ for decorator in node.decorator_list:
857
+ if (
858
+ isinstance(decorator, Name)
859
+ and decorator.id == "staticmethod"
860
+ ):
861
+ break
862
+ elif (
863
+ isinstance(decorator, Name)
864
+ and decorator.id == "classmethod"
865
+ ):
866
+ memo_kwargs["self_type"] = Name(
867
+ id=node.args.args[0].arg, ctx=Load()
868
+ )
869
+ break
870
+ else:
871
+ if node.args.args:
872
+ if node.name == "__new__":
873
+ memo_kwargs["self_type"] = Name(
874
+ id=node.args.args[0].arg, ctx=Load()
875
+ )
876
+ else:
877
+ memo_kwargs["self_type"] = Attribute(
878
+ Name(id=node.args.args[0].arg, ctx=Load()),
879
+ "__class__",
880
+ ctx=Load(),
881
+ )
882
+
883
+ # Construct the function reference
884
+ # Nested functions get special treatment: the function name is added
885
+ # to free variables (and the closure of the resulting function)
886
+ names: list[str] = [node.name]
887
+ memo = self._memo.parent
888
+ while memo:
889
+ if isinstance(memo.node, (FunctionDef, AsyncFunctionDef)):
890
+ # This is a nested function. Use the function name as-is.
891
+ del names[:-1]
892
+ break
893
+ elif not isinstance(memo.node, ClassDef):
894
+ break
895
+
896
+ names.insert(0, memo.node.name)
897
+ memo = memo.parent
898
+
899
+ config_keywords = self._memo.get_config_keywords()
900
+ if config_keywords:
901
+ memo_kwargs["config"] = Call(
902
+ self._get_import("dataclasses", "replace"),
903
+ [self._get_import("typeguard._config", "global_config")],
904
+ config_keywords,
905
+ )
906
+
907
+ self._memo.memo_var_name.id = self._memo.get_unused_name("memo")
908
+ memo_store_name = Name(id=self._memo.memo_var_name.id, ctx=Store())
909
+ globals_call = Call(Name(id="globals", ctx=Load()), [], [])
910
+ locals_call = Call(Name(id="locals", ctx=Load()), [], [])
911
+ memo_expr = Call(
912
+ self._get_import("typeguard", "TypeCheckMemo"),
913
+ [globals_call, locals_call],
914
+ [keyword(key, value) for key, value in memo_kwargs.items()],
915
+ )
916
+ node.body.insert(
917
+ self._memo.code_inject_index,
918
+ Assign([memo_store_name], memo_expr),
919
+ )
920
+
921
+ self._memo.insert_imports(node)
922
+
923
+ # Special case the __new__() method to create a local alias from the
924
+ # class name to the first argument (usually "cls")
925
+ if (
926
+ isinstance(node, FunctionDef)
927
+ and node.args
928
+ and self._memo.parent is not None
929
+ and isinstance(self._memo.parent.node, ClassDef)
930
+ and node.name == "__new__"
931
+ ):
932
+ first_args_expr = Name(node.args.args[0].arg, ctx=Load())
933
+ cls_name = Name(self._memo.parent.node.name, ctx=Store())
934
+ node.body.insert(
935
+ self._memo.code_inject_index,
936
+ Assign([cls_name], first_args_expr),
937
+ )
938
+
939
+ # Rmove any placeholder "pass" at the end
940
+ if isinstance(node.body[-1], Pass):
941
+ del node.body[-1]
942
+
943
+ return node
944
+
945
+ def visit_AsyncFunctionDef(
946
+ self, node: AsyncFunctionDef
947
+ ) -> FunctionDef | AsyncFunctionDef | None:
948
+ return self.visit_FunctionDef(node)
949
+
950
+ def visit_Return(self, node: Return) -> Return:
951
+ """This injects type checks into "return" statements."""
952
+ self.generic_visit(node)
953
+ if (
954
+ self._memo.return_annotation
955
+ and self._memo.should_instrument
956
+ and not self._memo.is_ignored_name(self._memo.return_annotation)
957
+ ):
958
+ func_name = self._get_import("typeguard._functions", "check_return_type")
959
+ old_node = node
960
+ retval = old_node.value or Constant(None)
961
+ node = Return(
962
+ Call(
963
+ func_name,
964
+ [
965
+ self._memo.joined_path,
966
+ retval,
967
+ self._memo.return_annotation,
968
+ self._memo.get_memo_name(),
969
+ ],
970
+ [],
971
+ )
972
+ )
973
+ copy_location(node, old_node)
974
+
975
+ return node
976
+
977
+ def visit_Yield(self, node: Yield) -> Yield | Call:
978
+ """
979
+ This injects type checks into "yield" expressions, checking both the yielded
980
+ value and the value sent back to the generator, when appropriate.
981
+
982
+ """
983
+ self._memo.has_yield_expressions = True
984
+ self.generic_visit(node)
985
+
986
+ if (
987
+ self._memo.yield_annotation
988
+ and self._memo.should_instrument
989
+ and not self._memo.is_ignored_name(self._memo.yield_annotation)
990
+ ):
991
+ func_name = self._get_import("typeguard._functions", "check_yield_type")
992
+ yieldval = node.value or Constant(None)
993
+ node.value = Call(
994
+ func_name,
995
+ [
996
+ self._memo.joined_path,
997
+ yieldval,
998
+ self._memo.yield_annotation,
999
+ self._memo.get_memo_name(),
1000
+ ],
1001
+ [],
1002
+ )
1003
+
1004
+ if (
1005
+ self._memo.send_annotation
1006
+ and self._memo.should_instrument
1007
+ and not self._memo.is_ignored_name(self._memo.send_annotation)
1008
+ ):
1009
+ func_name = self._get_import("typeguard._functions", "check_send_type")
1010
+ old_node = node
1011
+ call_node = Call(
1012
+ func_name,
1013
+ [
1014
+ self._memo.joined_path,
1015
+ old_node,
1016
+ self._memo.send_annotation,
1017
+ self._memo.get_memo_name(),
1018
+ ],
1019
+ [],
1020
+ )
1021
+ copy_location(call_node, old_node)
1022
+ return call_node
1023
+
1024
+ return node
1025
+
1026
+ def visit_AnnAssign(self, node: AnnAssign) -> Any:
1027
+ """
1028
+ This injects a type check into a local variable annotation-assignment within a
1029
+ function body.
1030
+
1031
+ """
1032
+ self.generic_visit(node)
1033
+
1034
+ if (
1035
+ isinstance(self._memo.node, (FunctionDef, AsyncFunctionDef))
1036
+ and node.annotation
1037
+ and isinstance(node.target, Name)
1038
+ ):
1039
+ self._memo.ignored_names.add(node.target.id)
1040
+ annotation = self._convert_annotation(deepcopy(node.annotation))
1041
+ if annotation:
1042
+ self._memo.variable_annotations[node.target.id] = annotation
1043
+ if node.value:
1044
+ func_name = self._get_import(
1045
+ "typeguard._functions", "check_variable_assignment"
1046
+ )
1047
+ node.value = Call(
1048
+ func_name,
1049
+ [
1050
+ node.value,
1051
+ Constant(node.target.id),
1052
+ annotation,
1053
+ self._memo.get_memo_name(),
1054
+ ],
1055
+ [],
1056
+ )
1057
+
1058
+ return node
1059
+
1060
+ def visit_Assign(self, node: Assign) -> Any:
1061
+ """
1062
+ This injects a type check into a local variable assignment within a function
1063
+ body. The variable must have been annotated earlier in the function body.
1064
+
1065
+ """
1066
+ self.generic_visit(node)
1067
+
1068
+ # Only instrument function-local assignments
1069
+ if isinstance(self._memo.node, (FunctionDef, AsyncFunctionDef)):
1070
+ targets: list[dict[Constant, expr | None]] = []
1071
+ check_required = False
1072
+ for target in node.targets:
1073
+ elts: Sequence[expr]
1074
+ if isinstance(target, Name):
1075
+ elts = [target]
1076
+ elif isinstance(target, Tuple):
1077
+ elts = target.elts
1078
+ else:
1079
+ continue
1080
+
1081
+ annotations_: dict[Constant, expr | None] = {}
1082
+ for exp in elts:
1083
+ prefix = ""
1084
+ if isinstance(exp, Starred):
1085
+ exp = exp.value
1086
+ prefix = "*"
1087
+
1088
+ if isinstance(exp, Name):
1089
+ self._memo.ignored_names.add(exp.id)
1090
+ name = prefix + exp.id
1091
+ annotation = self._memo.variable_annotations.get(exp.id)
1092
+ if annotation:
1093
+ annotations_[Constant(name)] = annotation
1094
+ check_required = True
1095
+ else:
1096
+ annotations_[Constant(name)] = None
1097
+
1098
+ targets.append(annotations_)
1099
+
1100
+ if check_required:
1101
+ # Replace missing annotations with typing.Any
1102
+ for item in targets:
1103
+ for key, expression in item.items():
1104
+ if expression is None:
1105
+ item[key] = self._get_import("typing", "Any")
1106
+
1107
+ if len(targets) == 1 and len(targets[0]) == 1:
1108
+ func_name = self._get_import(
1109
+ "typeguard._functions", "check_variable_assignment"
1110
+ )
1111
+ target_varname = next(iter(targets[0]))
1112
+ node.value = Call(
1113
+ func_name,
1114
+ [
1115
+ node.value,
1116
+ target_varname,
1117
+ targets[0][target_varname],
1118
+ self._memo.get_memo_name(),
1119
+ ],
1120
+ [],
1121
+ )
1122
+ elif targets:
1123
+ func_name = self._get_import(
1124
+ "typeguard._functions", "check_multi_variable_assignment"
1125
+ )
1126
+ targets_arg = List(
1127
+ [
1128
+ Dict(keys=list(target), values=list(target.values()))
1129
+ for target in targets
1130
+ ],
1131
+ ctx=Load(),
1132
+ )
1133
+ node.value = Call(
1134
+ func_name,
1135
+ [node.value, targets_arg, self._memo.get_memo_name()],
1136
+ [],
1137
+ )
1138
+
1139
+ return node
1140
+
1141
+ def visit_NamedExpr(self, node: NamedExpr) -> Any:
1142
+ """This injects a type check into an assignment expression (a := foo())."""
1143
+ self.generic_visit(node)
1144
+
1145
+ # Only instrument function-local assignments
1146
+ if isinstance(self._memo.node, (FunctionDef, AsyncFunctionDef)) and isinstance(
1147
+ node.target, Name
1148
+ ):
1149
+ self._memo.ignored_names.add(node.target.id)
1150
+
1151
+ # Bail out if no matching annotation is found
1152
+ annotation = self._memo.variable_annotations.get(node.target.id)
1153
+ if annotation is None:
1154
+ return node
1155
+
1156
+ func_name = self._get_import(
1157
+ "typeguard._functions", "check_variable_assignment"
1158
+ )
1159
+ node.value = Call(
1160
+ func_name,
1161
+ [
1162
+ node.value,
1163
+ Constant(node.target.id),
1164
+ annotation,
1165
+ self._memo.get_memo_name(),
1166
+ ],
1167
+ [],
1168
+ )
1169
+
1170
+ return node
1171
+
1172
+ def visit_AugAssign(self, node: AugAssign) -> Any:
1173
+ """
1174
+ This injects a type check into an augmented assignment expression (a += 1).
1175
+
1176
+ """
1177
+ self.generic_visit(node)
1178
+
1179
+ # Only instrument function-local assignments
1180
+ if isinstance(self._memo.node, (FunctionDef, AsyncFunctionDef)) and isinstance(
1181
+ node.target, Name
1182
+ ):
1183
+ # Bail out if no matching annotation is found
1184
+ annotation = self._memo.variable_annotations.get(node.target.id)
1185
+ if annotation is None:
1186
+ return node
1187
+
1188
+ # Bail out if the operator is not found (newer Python version?)
1189
+ try:
1190
+ operator_func_name = aug_assign_functions[node.op.__class__]
1191
+ except KeyError:
1192
+ return node
1193
+
1194
+ operator_func = self._get_import("operator", operator_func_name)
1195
+ operator_call = Call(
1196
+ operator_func, [Name(node.target.id, ctx=Load()), node.value], []
1197
+ )
1198
+ check_call = Call(
1199
+ self._get_import("typeguard._functions", "check_variable_assignment"),
1200
+ [
1201
+ operator_call,
1202
+ Constant(node.target.id),
1203
+ annotation,
1204
+ self._memo.get_memo_name(),
1205
+ ],
1206
+ [],
1207
+ )
1208
+ return Assign(targets=[node.target], value=check_call)
1209
+
1210
+ return node
1211
+
1212
+ def visit_If(self, node: If) -> Any:
1213
+ """
1214
+ This blocks names from being collected from a module-level
1215
+ "if typing.TYPE_CHECKING:" block, so that they won't be type checked.
1216
+
1217
+ """
1218
+ self.generic_visit(node)
1219
+
1220
+ if (
1221
+ self._memo is self._module_memo
1222
+ and isinstance(node.test, Name)
1223
+ and self._memo.name_matches(node.test, "typing.TYPE_CHECKING")
1224
+ ):
1225
+ collector = NameCollector()
1226
+ collector.visit(node)
1227
+ self._memo.ignored_names.update(collector.names)
1228
+
1229
+ return node