ob-metaflow 2.11.13.1__py2.py3-none-any.whl → 2.19.7.1rc0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (289) hide show
  1. metaflow/R.py +10 -7
  2. metaflow/__init__.py +40 -25
  3. metaflow/_vendor/imghdr/__init__.py +186 -0
  4. metaflow/_vendor/importlib_metadata/__init__.py +1063 -0
  5. metaflow/_vendor/importlib_metadata/_adapters.py +68 -0
  6. metaflow/_vendor/importlib_metadata/_collections.py +30 -0
  7. metaflow/_vendor/importlib_metadata/_compat.py +71 -0
  8. metaflow/_vendor/importlib_metadata/_functools.py +104 -0
  9. metaflow/_vendor/importlib_metadata/_itertools.py +73 -0
  10. metaflow/_vendor/importlib_metadata/_meta.py +48 -0
  11. metaflow/_vendor/importlib_metadata/_text.py +99 -0
  12. metaflow/_vendor/importlib_metadata/py.typed +0 -0
  13. metaflow/_vendor/typeguard/__init__.py +48 -0
  14. metaflow/_vendor/typeguard/_checkers.py +1070 -0
  15. metaflow/_vendor/typeguard/_config.py +108 -0
  16. metaflow/_vendor/typeguard/_decorators.py +233 -0
  17. metaflow/_vendor/typeguard/_exceptions.py +42 -0
  18. metaflow/_vendor/typeguard/_functions.py +308 -0
  19. metaflow/_vendor/typeguard/_importhook.py +213 -0
  20. metaflow/_vendor/typeguard/_memo.py +48 -0
  21. metaflow/_vendor/typeguard/_pytest_plugin.py +127 -0
  22. metaflow/_vendor/typeguard/_suppression.py +86 -0
  23. metaflow/_vendor/typeguard/_transformer.py +1229 -0
  24. metaflow/_vendor/typeguard/_union_transformer.py +55 -0
  25. metaflow/_vendor/typeguard/_utils.py +173 -0
  26. metaflow/_vendor/typeguard/py.typed +0 -0
  27. metaflow/_vendor/typing_extensions.py +3641 -0
  28. metaflow/_vendor/v3_7/importlib_metadata/__init__.py +1063 -0
  29. metaflow/_vendor/v3_7/importlib_metadata/_adapters.py +68 -0
  30. metaflow/_vendor/v3_7/importlib_metadata/_collections.py +30 -0
  31. metaflow/_vendor/v3_7/importlib_metadata/_compat.py +71 -0
  32. metaflow/_vendor/v3_7/importlib_metadata/_functools.py +104 -0
  33. metaflow/_vendor/v3_7/importlib_metadata/_itertools.py +73 -0
  34. metaflow/_vendor/v3_7/importlib_metadata/_meta.py +48 -0
  35. metaflow/_vendor/v3_7/importlib_metadata/_text.py +99 -0
  36. metaflow/_vendor/v3_7/importlib_metadata/py.typed +0 -0
  37. metaflow/_vendor/v3_7/typeguard/__init__.py +48 -0
  38. metaflow/_vendor/v3_7/typeguard/_checkers.py +906 -0
  39. metaflow/_vendor/v3_7/typeguard/_config.py +108 -0
  40. metaflow/_vendor/v3_7/typeguard/_decorators.py +237 -0
  41. metaflow/_vendor/v3_7/typeguard/_exceptions.py +42 -0
  42. metaflow/_vendor/v3_7/typeguard/_functions.py +310 -0
  43. metaflow/_vendor/v3_7/typeguard/_importhook.py +213 -0
  44. metaflow/_vendor/v3_7/typeguard/_memo.py +48 -0
  45. metaflow/_vendor/v3_7/typeguard/_pytest_plugin.py +100 -0
  46. metaflow/_vendor/v3_7/typeguard/_suppression.py +88 -0
  47. metaflow/_vendor/v3_7/typeguard/_transformer.py +1207 -0
  48. metaflow/_vendor/v3_7/typeguard/_union_transformer.py +54 -0
  49. metaflow/_vendor/v3_7/typeguard/_utils.py +169 -0
  50. metaflow/_vendor/v3_7/typeguard/py.typed +0 -0
  51. metaflow/_vendor/v3_7/typing_extensions.py +3072 -0
  52. metaflow/_vendor/yaml/__init__.py +427 -0
  53. metaflow/_vendor/yaml/composer.py +139 -0
  54. metaflow/_vendor/yaml/constructor.py +748 -0
  55. metaflow/_vendor/yaml/cyaml.py +101 -0
  56. metaflow/_vendor/yaml/dumper.py +62 -0
  57. metaflow/_vendor/yaml/emitter.py +1137 -0
  58. metaflow/_vendor/yaml/error.py +75 -0
  59. metaflow/_vendor/yaml/events.py +86 -0
  60. metaflow/_vendor/yaml/loader.py +63 -0
  61. metaflow/_vendor/yaml/nodes.py +49 -0
  62. metaflow/_vendor/yaml/parser.py +589 -0
  63. metaflow/_vendor/yaml/reader.py +185 -0
  64. metaflow/_vendor/yaml/representer.py +389 -0
  65. metaflow/_vendor/yaml/resolver.py +227 -0
  66. metaflow/_vendor/yaml/scanner.py +1435 -0
  67. metaflow/_vendor/yaml/serializer.py +111 -0
  68. metaflow/_vendor/yaml/tokens.py +104 -0
  69. metaflow/cards.py +5 -0
  70. metaflow/cli.py +331 -785
  71. metaflow/cli_args.py +17 -0
  72. metaflow/cli_components/__init__.py +0 -0
  73. metaflow/cli_components/dump_cmd.py +96 -0
  74. metaflow/cli_components/init_cmd.py +52 -0
  75. metaflow/cli_components/run_cmds.py +546 -0
  76. metaflow/cli_components/step_cmd.py +334 -0
  77. metaflow/cli_components/utils.py +140 -0
  78. metaflow/client/__init__.py +1 -0
  79. metaflow/client/core.py +467 -73
  80. metaflow/client/filecache.py +75 -35
  81. metaflow/clone_util.py +7 -1
  82. metaflow/cmd/code/__init__.py +231 -0
  83. metaflow/cmd/develop/stub_generator.py +756 -288
  84. metaflow/cmd/develop/stubs.py +12 -28
  85. metaflow/cmd/main_cli.py +6 -4
  86. metaflow/cmd/make_wrapper.py +78 -0
  87. metaflow/datastore/__init__.py +1 -0
  88. metaflow/datastore/content_addressed_store.py +41 -10
  89. metaflow/datastore/datastore_set.py +11 -2
  90. metaflow/datastore/flow_datastore.py +156 -10
  91. metaflow/datastore/spin_datastore.py +91 -0
  92. metaflow/datastore/task_datastore.py +154 -39
  93. metaflow/debug.py +5 -0
  94. metaflow/decorators.py +404 -78
  95. metaflow/exception.py +8 -2
  96. metaflow/extension_support/__init__.py +527 -376
  97. metaflow/extension_support/_empty_file.py +2 -2
  98. metaflow/extension_support/plugins.py +49 -31
  99. metaflow/flowspec.py +482 -33
  100. metaflow/graph.py +210 -42
  101. metaflow/includefile.py +84 -40
  102. metaflow/lint.py +141 -22
  103. metaflow/meta_files.py +13 -0
  104. metaflow/{metadata → metadata_provider}/heartbeat.py +24 -8
  105. metaflow/{metadata → metadata_provider}/metadata.py +86 -1
  106. metaflow/metaflow_config.py +175 -28
  107. metaflow/metaflow_config_funcs.py +51 -3
  108. metaflow/metaflow_current.py +4 -10
  109. metaflow/metaflow_environment.py +139 -53
  110. metaflow/metaflow_git.py +115 -0
  111. metaflow/metaflow_profile.py +18 -0
  112. metaflow/metaflow_version.py +150 -66
  113. metaflow/mflog/__init__.py +4 -3
  114. metaflow/mflog/save_logs.py +2 -2
  115. metaflow/multicore_utils.py +31 -14
  116. metaflow/package/__init__.py +673 -0
  117. metaflow/packaging_sys/__init__.py +880 -0
  118. metaflow/packaging_sys/backend.py +128 -0
  119. metaflow/packaging_sys/distribution_support.py +153 -0
  120. metaflow/packaging_sys/tar_backend.py +99 -0
  121. metaflow/packaging_sys/utils.py +54 -0
  122. metaflow/packaging_sys/v1.py +527 -0
  123. metaflow/parameters.py +149 -28
  124. metaflow/plugins/__init__.py +74 -5
  125. metaflow/plugins/airflow/airflow.py +40 -25
  126. metaflow/plugins/airflow/airflow_cli.py +22 -5
  127. metaflow/plugins/airflow/airflow_decorator.py +1 -1
  128. metaflow/plugins/airflow/airflow_utils.py +5 -3
  129. metaflow/plugins/airflow/sensors/base_sensor.py +4 -4
  130. metaflow/plugins/airflow/sensors/external_task_sensor.py +2 -2
  131. metaflow/plugins/airflow/sensors/s3_sensor.py +2 -2
  132. metaflow/plugins/argo/argo_client.py +78 -33
  133. metaflow/plugins/argo/argo_events.py +6 -6
  134. metaflow/plugins/argo/argo_workflows.py +2410 -527
  135. metaflow/plugins/argo/argo_workflows_cli.py +571 -121
  136. metaflow/plugins/argo/argo_workflows_decorator.py +43 -12
  137. metaflow/plugins/argo/argo_workflows_deployer.py +106 -0
  138. metaflow/plugins/argo/argo_workflows_deployer_objects.py +453 -0
  139. metaflow/plugins/argo/capture_error.py +73 -0
  140. metaflow/plugins/argo/conditional_input_paths.py +35 -0
  141. metaflow/plugins/argo/exit_hooks.py +209 -0
  142. metaflow/plugins/argo/jobset_input_paths.py +15 -0
  143. metaflow/plugins/argo/param_val.py +19 -0
  144. metaflow/plugins/aws/aws_client.py +10 -3
  145. metaflow/plugins/aws/aws_utils.py +55 -2
  146. metaflow/plugins/aws/batch/batch.py +72 -5
  147. metaflow/plugins/aws/batch/batch_cli.py +33 -10
  148. metaflow/plugins/aws/batch/batch_client.py +4 -3
  149. metaflow/plugins/aws/batch/batch_decorator.py +102 -35
  150. metaflow/plugins/aws/secrets_manager/aws_secrets_manager_secrets_provider.py +13 -10
  151. metaflow/plugins/aws/step_functions/dynamo_db_client.py +0 -3
  152. metaflow/plugins/aws/step_functions/production_token.py +1 -1
  153. metaflow/plugins/aws/step_functions/step_functions.py +65 -8
  154. metaflow/plugins/aws/step_functions/step_functions_cli.py +101 -7
  155. metaflow/plugins/aws/step_functions/step_functions_decorator.py +1 -2
  156. metaflow/plugins/aws/step_functions/step_functions_deployer.py +97 -0
  157. metaflow/plugins/aws/step_functions/step_functions_deployer_objects.py +264 -0
  158. metaflow/plugins/azure/azure_exceptions.py +1 -1
  159. metaflow/plugins/azure/azure_secret_manager_secrets_provider.py +240 -0
  160. metaflow/plugins/azure/azure_tail.py +1 -1
  161. metaflow/plugins/azure/includefile_support.py +2 -0
  162. metaflow/plugins/cards/card_cli.py +66 -30
  163. metaflow/plugins/cards/card_creator.py +25 -1
  164. metaflow/plugins/cards/card_datastore.py +21 -49
  165. metaflow/plugins/cards/card_decorator.py +132 -8
  166. metaflow/plugins/cards/card_modules/basic.py +112 -17
  167. metaflow/plugins/cards/card_modules/bundle.css +1 -1
  168. metaflow/plugins/cards/card_modules/card.py +16 -1
  169. metaflow/plugins/cards/card_modules/chevron/renderer.py +1 -1
  170. metaflow/plugins/cards/card_modules/components.py +665 -28
  171. metaflow/plugins/cards/card_modules/convert_to_native_type.py +36 -7
  172. metaflow/plugins/cards/card_modules/json_viewer.py +232 -0
  173. metaflow/plugins/cards/card_modules/main.css +1 -0
  174. metaflow/plugins/cards/card_modules/main.js +68 -49
  175. metaflow/plugins/cards/card_modules/renderer_tools.py +1 -0
  176. metaflow/plugins/cards/card_modules/test_cards.py +26 -12
  177. metaflow/plugins/cards/card_server.py +39 -14
  178. metaflow/plugins/cards/component_serializer.py +2 -9
  179. metaflow/plugins/cards/metadata.py +22 -0
  180. metaflow/plugins/catch_decorator.py +9 -0
  181. metaflow/plugins/datastores/azure_storage.py +10 -1
  182. metaflow/plugins/datastores/gs_storage.py +6 -2
  183. metaflow/plugins/datastores/local_storage.py +12 -6
  184. metaflow/plugins/datastores/spin_storage.py +12 -0
  185. metaflow/plugins/datatools/local.py +2 -0
  186. metaflow/plugins/datatools/s3/s3.py +126 -75
  187. metaflow/plugins/datatools/s3/s3op.py +254 -121
  188. metaflow/plugins/env_escape/__init__.py +3 -3
  189. metaflow/plugins/env_escape/client_modules.py +102 -72
  190. metaflow/plugins/env_escape/server.py +7 -0
  191. metaflow/plugins/env_escape/stub.py +24 -5
  192. metaflow/plugins/events_decorator.py +343 -185
  193. metaflow/plugins/exit_hook/__init__.py +0 -0
  194. metaflow/plugins/exit_hook/exit_hook_decorator.py +46 -0
  195. metaflow/plugins/exit_hook/exit_hook_script.py +52 -0
  196. metaflow/plugins/gcp/__init__.py +1 -1
  197. metaflow/plugins/gcp/gcp_secret_manager_secrets_provider.py +11 -6
  198. metaflow/plugins/gcp/gs_tail.py +10 -6
  199. metaflow/plugins/gcp/includefile_support.py +3 -0
  200. metaflow/plugins/kubernetes/kube_utils.py +108 -0
  201. metaflow/plugins/kubernetes/kubernetes.py +411 -130
  202. metaflow/plugins/kubernetes/kubernetes_cli.py +168 -36
  203. metaflow/plugins/kubernetes/kubernetes_client.py +104 -2
  204. metaflow/plugins/kubernetes/kubernetes_decorator.py +246 -88
  205. metaflow/plugins/kubernetes/kubernetes_job.py +253 -581
  206. metaflow/plugins/kubernetes/kubernetes_jobsets.py +1071 -0
  207. metaflow/plugins/kubernetes/spot_metadata_cli.py +69 -0
  208. metaflow/plugins/kubernetes/spot_monitor_sidecar.py +109 -0
  209. metaflow/plugins/logs_cli.py +359 -0
  210. metaflow/plugins/{metadata → metadata_providers}/local.py +144 -84
  211. metaflow/plugins/{metadata → metadata_providers}/service.py +103 -26
  212. metaflow/plugins/metadata_providers/spin.py +16 -0
  213. metaflow/plugins/package_cli.py +36 -24
  214. metaflow/plugins/parallel_decorator.py +128 -11
  215. metaflow/plugins/parsers.py +16 -0
  216. metaflow/plugins/project_decorator.py +51 -5
  217. metaflow/plugins/pypi/bootstrap.py +357 -105
  218. metaflow/plugins/pypi/conda_decorator.py +82 -81
  219. metaflow/plugins/pypi/conda_environment.py +187 -52
  220. metaflow/plugins/pypi/micromamba.py +157 -47
  221. metaflow/plugins/pypi/parsers.py +268 -0
  222. metaflow/plugins/pypi/pip.py +88 -13
  223. metaflow/plugins/pypi/pypi_decorator.py +37 -1
  224. metaflow/plugins/pypi/utils.py +48 -2
  225. metaflow/plugins/resources_decorator.py +2 -2
  226. metaflow/plugins/secrets/__init__.py +3 -0
  227. metaflow/plugins/secrets/secrets_decorator.py +26 -181
  228. metaflow/plugins/secrets/secrets_func.py +49 -0
  229. metaflow/plugins/secrets/secrets_spec.py +101 -0
  230. metaflow/plugins/secrets/utils.py +74 -0
  231. metaflow/plugins/tag_cli.py +4 -7
  232. metaflow/plugins/test_unbounded_foreach_decorator.py +41 -6
  233. metaflow/plugins/timeout_decorator.py +3 -3
  234. metaflow/plugins/uv/__init__.py +0 -0
  235. metaflow/plugins/uv/bootstrap.py +128 -0
  236. metaflow/plugins/uv/uv_environment.py +72 -0
  237. metaflow/procpoll.py +1 -1
  238. metaflow/pylint_wrapper.py +5 -1
  239. metaflow/runner/__init__.py +0 -0
  240. metaflow/runner/click_api.py +717 -0
  241. metaflow/runner/deployer.py +470 -0
  242. metaflow/runner/deployer_impl.py +201 -0
  243. metaflow/runner/metaflow_runner.py +714 -0
  244. metaflow/runner/nbdeploy.py +132 -0
  245. metaflow/runner/nbrun.py +225 -0
  246. metaflow/runner/subprocess_manager.py +650 -0
  247. metaflow/runner/utils.py +335 -0
  248. metaflow/runtime.py +1078 -260
  249. metaflow/sidecar/sidecar_worker.py +1 -1
  250. metaflow/system/__init__.py +5 -0
  251. metaflow/system/system_logger.py +85 -0
  252. metaflow/system/system_monitor.py +108 -0
  253. metaflow/system/system_utils.py +19 -0
  254. metaflow/task.py +521 -225
  255. metaflow/tracing/__init__.py +7 -7
  256. metaflow/tracing/span_exporter.py +31 -38
  257. metaflow/tracing/tracing_modules.py +38 -43
  258. metaflow/tuple_util.py +27 -0
  259. metaflow/user_configs/__init__.py +0 -0
  260. metaflow/user_configs/config_options.py +563 -0
  261. metaflow/user_configs/config_parameters.py +598 -0
  262. metaflow/user_decorators/__init__.py +0 -0
  263. metaflow/user_decorators/common.py +144 -0
  264. metaflow/user_decorators/mutable_flow.py +512 -0
  265. metaflow/user_decorators/mutable_step.py +424 -0
  266. metaflow/user_decorators/user_flow_decorator.py +264 -0
  267. metaflow/user_decorators/user_step_decorator.py +749 -0
  268. metaflow/util.py +243 -27
  269. metaflow/vendor.py +23 -7
  270. metaflow/version.py +1 -1
  271. ob_metaflow-2.19.7.1rc0.data/data/share/metaflow/devtools/Makefile +355 -0
  272. ob_metaflow-2.19.7.1rc0.data/data/share/metaflow/devtools/Tiltfile +726 -0
  273. ob_metaflow-2.19.7.1rc0.data/data/share/metaflow/devtools/pick_services.sh +105 -0
  274. ob_metaflow-2.19.7.1rc0.dist-info/METADATA +87 -0
  275. ob_metaflow-2.19.7.1rc0.dist-info/RECORD +445 -0
  276. {ob_metaflow-2.11.13.1.dist-info → ob_metaflow-2.19.7.1rc0.dist-info}/WHEEL +1 -1
  277. {ob_metaflow-2.11.13.1.dist-info → ob_metaflow-2.19.7.1rc0.dist-info}/entry_points.txt +1 -0
  278. metaflow/_vendor/v3_5/__init__.py +0 -1
  279. metaflow/_vendor/v3_5/importlib_metadata/__init__.py +0 -644
  280. metaflow/_vendor/v3_5/importlib_metadata/_compat.py +0 -152
  281. metaflow/package.py +0 -188
  282. ob_metaflow-2.11.13.1.dist-info/METADATA +0 -85
  283. ob_metaflow-2.11.13.1.dist-info/RECORD +0 -308
  284. /metaflow/_vendor/{v3_5/zipp.py → zipp.py} +0 -0
  285. /metaflow/{metadata → metadata_provider}/__init__.py +0 -0
  286. /metaflow/{metadata → metadata_provider}/util.py +0 -0
  287. /metaflow/plugins/{metadata → metadata_providers}/__init__.py +0 -0
  288. {ob_metaflow-2.11.13.1.dist-info → ob_metaflow-2.19.7.1rc0.dist-info/licenses}/LICENSE +0 -0
  289. {ob_metaflow-2.11.13.1.dist-info → ob_metaflow-2.19.7.1rc0.dist-info}/top_level.txt +0 -0
@@ -7,7 +7,6 @@ import pathlib
7
7
  import re
8
8
  import time
9
9
  import typing
10
-
11
10
  from datetime import datetime
12
11
  from io import StringIO
13
12
  from types import ModuleType
@@ -31,13 +30,17 @@ from metaflow import FlowSpec, step
31
30
  from metaflow.debug import debug
32
31
  from metaflow.decorators import Decorator, FlowDecorator
33
32
  from metaflow.extension_support import get_aliased_modules
34
- from metaflow.graph import deindent_docstring
33
+ from metaflow.metaflow_current import Current
35
34
  from metaflow.metaflow_version import get_version
35
+ from metaflow.runner.deployer import DeployedFlow, Deployer, TriggeredRun
36
+ from metaflow.runner.deployer_impl import DeployerImpl
36
37
 
37
38
  TAB = " "
38
39
  METAFLOW_CURRENT_MODULE_NAME = "metaflow.metaflow_current"
40
+ METAFLOW_DEPLOYER_MODULE_NAME = "metaflow.runner.deployer"
39
41
 
40
42
  param_section_header = re.compile(r"Parameters\s*\n----------\s*\n", flags=re.M)
43
+ return_section_header = re.compile(r"Returns\s*\n-------\s*\n", flags=re.M)
41
44
  add_to_current_header = re.compile(
42
45
  r"MF Add To Current\s*\n-----------------\s*\n", flags=re.M
43
46
  )
@@ -57,6 +60,20 @@ MetaflowStepFunction = Union[
57
60
  ]
58
61
 
59
62
 
63
+ # Object that has start() and end() like a Match object to make the code simpler when
64
+ # we are parsing different sections of doc
65
+ class StartEnd:
66
+ def __init__(self, start: int, end: int):
67
+ self._start = start
68
+ self._end = end
69
+
70
+ def start(self):
71
+ return self._start
72
+
73
+ def end(self):
74
+ return self._end
75
+
76
+
60
77
  def type_var_to_str(t: TypeVar) -> str:
61
78
  bound_name = None
62
79
  if t.__bound__ is not None:
@@ -92,6 +109,131 @@ def descend_object(object: str, options: Iterable[str]):
92
109
  return False
93
110
 
94
111
 
112
+ def parse_params_from_doc(doc: str) -> Tuple[List[inspect.Parameter], bool]:
113
+ parameters = []
114
+ no_arg_version = True
115
+ for line in doc.splitlines():
116
+ if non_indented_line.match(line):
117
+ match = param_name_type.match(line)
118
+ arg_name = type_name = is_optional = default = None
119
+ default_set = False
120
+ if match is not None:
121
+ arg_name = match.group("name")
122
+ type_name = match.group("type")
123
+ if type_name is not None:
124
+ type_detail = type_annotations.match(type_name)
125
+ if type_detail is not None:
126
+ type_name = type_detail.group("type")
127
+ is_optional = type_detail.group("optional") is not None
128
+ default = type_detail.group("default")
129
+ if default:
130
+ default_set = True
131
+ try:
132
+ default = eval(default)
133
+ except:
134
+ pass
135
+ try:
136
+ type_name = eval(type_name)
137
+ except:
138
+ pass
139
+ parameters.append(
140
+ inspect.Parameter(
141
+ name=arg_name,
142
+ kind=inspect.Parameter.KEYWORD_ONLY,
143
+ default=(
144
+ default
145
+ if default_set
146
+ else None if is_optional else inspect.Parameter.empty
147
+ ),
148
+ annotation=(Optional[type_name] if is_optional else type_name),
149
+ )
150
+ )
151
+ if not default_set:
152
+ # If we don't have a default set for any parameter, we can't
153
+ # have a no-arg version since the function would be incomplete
154
+ no_arg_version = False
155
+ return parameters, no_arg_version
156
+
157
+
158
+ def split_docs(
159
+ raw_doc: str, boundaries: List[Tuple[str, Union[StartEnd, re.Match]]]
160
+ ) -> Dict[str, str]:
161
+ docs = dict()
162
+ boundaries.sort(key=lambda x: x[1].start())
163
+
164
+ section_start = 0
165
+ for idx in range(1, len(boundaries)):
166
+ docs[boundaries[idx - 1][0]] = raw_doc[
167
+ section_start : boundaries[idx][1].start()
168
+ ]
169
+ section_start = boundaries[idx][1].end()
170
+ docs[boundaries[-1][0]] = raw_doc[section_start:]
171
+ return docs
172
+
173
+
174
+ def parse_add_to_docs(
175
+ raw_doc: str,
176
+ ) -> Dict[str, Union[Tuple[inspect.Signature, str], str]]:
177
+ prop = None
178
+ return_type = None
179
+ property_indent = None
180
+ doc = []
181
+ add_to_docs = dict() # type: Dict[str, Union[str, Tuple[inspect.Signature, str]]]
182
+
183
+ def _add():
184
+ if prop:
185
+ add_to_docs[prop] = (
186
+ inspect.Signature(
187
+ [
188
+ inspect.Parameter(
189
+ "self", inspect.Parameter.POSITIONAL_OR_KEYWORD
190
+ )
191
+ ],
192
+ return_annotation=return_type,
193
+ ),
194
+ "\n".join(doc),
195
+ )
196
+
197
+ for line in raw_doc.splitlines():
198
+ # Parse stanzas that look like the following:
199
+ # <property-name> -> type
200
+ # indented doc string
201
+ if property_indent is not None and (
202
+ line.startswith(property_indent + " ") or line.strip() == ""
203
+ ):
204
+ offset = len(property_indent)
205
+ if line.lstrip().startswith("@@ "):
206
+ line = line.replace("@@ ", "")
207
+ doc.append(line[offset:].rstrip())
208
+ else:
209
+ if line.strip() == 0:
210
+ continue
211
+ if prop:
212
+ # Ends a property stanza
213
+ _add()
214
+ # Now start a new one
215
+ line = line.rstrip()
216
+ property_indent = line[: len(line) - len(line.lstrip())]
217
+ # Either this has a -> to denote a property or it is a pure name
218
+ # to denote a reference to a function (starting with #)
219
+ line = line.lstrip()
220
+ if line.startswith("#"):
221
+ # The name of the function is the last part like metaflow.deployer.run
222
+ add_to_docs[line.split(".")[-1]] = line[1:]
223
+ continue
224
+ # This is a line so we split it using "->"
225
+ prop, return_type = line.split("->")
226
+ prop = prop.strip()
227
+ return_type = return_type.strip()
228
+ doc = []
229
+ _add()
230
+ return add_to_docs
231
+
232
+
233
+ def add_indent(indentation: str, text: str) -> str:
234
+ return "\n".join([indentation + line for line in text.splitlines()])
235
+
236
+
95
237
  class StubGenerator:
96
238
  """
97
239
  This class takes the name of a library as input and a directory as output.
@@ -115,12 +257,24 @@ class StubGenerator:
115
257
  :type members_from_other_modules: List[str]
116
258
  """
117
259
 
260
+ # Let metaflow know we are in stubgen mode. This is sometimes useful to skip
261
+ # some processing like loading libraries, etc. It is used in Metaflow extensions
262
+ # so do not remove even if you do not see a use for it directly in the code.
263
+ os.environ["METAFLOW_STUBGEN"] = "1"
264
+
118
265
  self._write_generated_for = include_generated_for
119
- self._pending_modules = ["metaflow"] # type: List[str]
120
- self._pending_modules.extend(get_aliased_modules())
266
+ # First element is the name it should be installed in (alias) and second is the
267
+ # actual module name
268
+ self._pending_modules = [
269
+ ("metaflow", "metaflow")
270
+ ] # type: List[Tuple[str, str]]
121
271
  self._root_module = "metaflow."
122
272
  self._safe_modules = ["metaflow.", "metaflow_extensions."]
123
273
 
274
+ self._pending_modules.extend(
275
+ (self._get_module_name_alias(x), x) for x in get_aliased_modules()
276
+ )
277
+
124
278
  # We exclude some modules to not create a bunch of random non-user facing
125
279
  # .pyi files.
126
280
  self._exclude_modules = set(
@@ -146,7 +300,7 @@ class StubGenerator:
146
300
  "metaflow.package",
147
301
  "metaflow.plugins.datastores",
148
302
  "metaflow.plugins.env_escape",
149
- "metaflow.plugins.metadata",
303
+ "metaflow.plugins.metadata_providers",
150
304
  "metaflow.procpoll.py",
151
305
  "metaflow.R",
152
306
  "metaflow.runtime",
@@ -158,9 +312,16 @@ class StubGenerator:
158
312
  "metaflow._vendor",
159
313
  ]
160
314
  )
315
+
161
316
  self._done_modules = set() # type: Set[str]
162
317
  self._output_dir = output_dir
163
318
  self._mf_version = get_version()
319
+
320
+ # Contains the names of the methods that are injected in Deployer
321
+ self._deployer_injected_methods = (
322
+ {}
323
+ ) # type: Dict[str, Dict[str, Union[Tuple[str, str], str]]]
324
+ # Contains information to add to the Current object (injected by decorators)
164
325
  self._addl_current = (
165
326
  dict()
166
327
  ) # type: Dict[str, Dict[str, Tuple[inspect.Signature, str]]]
@@ -173,12 +334,15 @@ class StubGenerator:
173
334
 
174
335
  # Imports that are needed at the top of the file
175
336
  self._imports = set() # type: Set[str]
337
+
338
+ self._sub_module_imports = set() # type: Set[Tuple[str, str]]``
176
339
  # Typing imports (behind if TYPE_CHECKING) that are needed at the top of the file
177
340
  self._typing_imports = set() # type: Set[str]
178
341
  # Typevars that are defined
179
342
  self._typevars = dict() # type: Dict[str, Union[TypeVar, type]]
180
343
  # Current objects in the file being processed
181
344
  self._current_objects = {} # type: Dict[str, Any]
345
+ self._current_references = [] # type: List[str]
182
346
  # Current stubs in the file being processed
183
347
  self._stubs = [] # type: List[str]
184
348
 
@@ -187,26 +351,78 @@ class StubGenerator:
187
351
  # the "globals()"
188
352
  self._current_parent_module = None # type: Optional[ModuleType]
189
353
 
190
- def _get_module(self, name):
191
- debug.stubgen_exec("Analyzing module %s ..." % name)
354
+ def _get_module_name_alias(self, module_name):
355
+ if any(
356
+ module_name.startswith(x) for x in self._safe_modules
357
+ ) and not module_name.startswith(self._root_module):
358
+ return self._root_module + ".".join(
359
+ ["mf_extensions", *module_name.split(".")[1:]]
360
+ )
361
+ return module_name
362
+
363
+ def _get_relative_import(
364
+ self, new_module_name, cur_module_name, is_init_module=False
365
+ ):
366
+ new_components = new_module_name.split(".")
367
+ cur_components = cur_module_name.split(".")
368
+ init_module_count = 1 if is_init_module else 0
369
+ common_idx = 0
370
+ max_idx = min(len(new_components), len(cur_components))
371
+ while (
372
+ common_idx < max_idx
373
+ and new_components[common_idx] == cur_components[common_idx]
374
+ ):
375
+ common_idx += 1
376
+ # current: a.b and parent: a.b.e.d -> from .e.d import <name>
377
+ # current: a.b.c.d and parent: a.b.e.f -> from ...e.f import <name>
378
+ return "." * (len(cur_components) - common_idx + init_module_count) + ".".join(
379
+ new_components[common_idx:]
380
+ )
381
+
382
+ def _get_module(self, alias, name):
383
+ debug.stubgen_exec("Analyzing module %s (aliased at %s)..." % (name, alias))
192
384
  self._current_module = importlib.import_module(name)
193
- self._current_module_name = name
385
+ self._current_module_name = alias
194
386
  for objname, obj in self._current_module.__dict__.items():
387
+ if objname == "_addl_stubgen_modules":
388
+ debug.stubgen_exec(
389
+ "Adding modules %s from _addl_stubgen_modules" % str(obj)
390
+ )
391
+ self._pending_modules.extend(
392
+ (self._get_module_name_alias(m), m) for m in obj
393
+ )
394
+ continue
195
395
  if objname.startswith("_"):
196
396
  debug.stubgen_exec(
197
397
  "Skipping object because it starts with _ %s" % objname
198
398
  )
199
399
  continue
200
400
  if inspect.ismodule(obj):
201
- # Only consider modules that are part of the root module
401
+ # Only consider modules that are safe modules
202
402
  if (
203
- obj.__name__.startswith(self._root_module)
403
+ any(obj.__name__.startswith(m) for m in self._safe_modules)
204
404
  and not obj.__name__ in self._exclude_modules
205
405
  ):
206
406
  debug.stubgen_exec(
207
407
  "Adding child module %s to process" % obj.__name__
208
408
  )
209
- self._pending_modules.append(obj.__name__)
409
+
410
+ new_module_alias = self._get_module_name_alias(obj.__name__)
411
+ self._pending_modules.append((new_module_alias, obj.__name__))
412
+
413
+ new_parent, new_name = new_module_alias.rsplit(".", 1)
414
+ self._current_references.append(
415
+ "from %s import %s as %s"
416
+ % (
417
+ self._get_relative_import(
418
+ new_parent,
419
+ alias,
420
+ hasattr(self._current_module, "__path__"),
421
+ ),
422
+ new_name,
423
+ objname,
424
+ )
425
+ )
210
426
  else:
211
427
  debug.stubgen_exec("Skipping child module %s" % obj.__name__)
212
428
  else:
@@ -216,8 +432,10 @@ class StubGenerator:
216
432
  # we could be more specific but good enough for now) for root module.
217
433
  # We also include the step decorator (it's from metaflow.decorators
218
434
  # which is typically excluded)
219
- # - otherwise, anything that is in safe_modules. Note this may include
220
- # a bit much (all the imports)
435
+ # - Stuff that is defined in this module itself
436
+ # - a reference to anything in the modules we will process later
437
+ # (so we don't duplicate a ton of times)
438
+
221
439
  if (
222
440
  parent_module is None
223
441
  or (
@@ -227,52 +445,50 @@ class StubGenerator:
227
445
  or obj == step
228
446
  )
229
447
  )
230
- or (
231
- not any(
232
- [
233
- parent_module.__name__.startswith(p)
234
- for p in self._exclude_modules
235
- ]
236
- )
237
- and any(
238
- [
239
- parent_module.__name__.startswith(p)
240
- for p in self._safe_modules
241
- ]
242
- )
243
- )
448
+ or parent_module.__name__ == name
244
449
  ):
245
450
  debug.stubgen_exec("Adding object %s to process" % objname)
246
451
  self._current_objects[objname] = obj
247
- else:
248
- debug.stubgen_exec("Skipping object %s" % objname)
249
- # We also include the module to process if it is part of root_module
250
- if (
251
- parent_module is not None
252
- and not any(
253
- [
254
- parent_module.__name__.startswith(d)
255
- for d in self._exclude_modules
256
- ]
257
- )
258
- and parent_module.__name__.startswith(self._root_module)
452
+
453
+ elif not any(
454
+ [
455
+ parent_module.__name__.startswith(p)
456
+ for p in self._exclude_modules
457
+ ]
458
+ ) and any(
459
+ [parent_module.__name__.startswith(p) for p in self._safe_modules]
259
460
  ):
461
+ parent_alias = self._get_module_name_alias(parent_module.__name__)
462
+
463
+ relative_import = self._get_relative_import(
464
+ parent_alias, alias, hasattr(self._current_module, "__path__")
465
+ )
466
+
260
467
  debug.stubgen_exec(
261
- "Adding module of child object %s to process"
262
- % parent_module.__name__,
468
+ "Adding reference %s and adding module %s as %s"
469
+ % (objname, parent_module.__name__, parent_alias)
263
470
  )
264
- self._pending_modules.append(parent_module.__name__)
471
+ obj_import_name = getattr(obj, "__name__", objname)
472
+ if obj_import_name == "<lambda>":
473
+ # We have one case of this
474
+ obj_import_name = objname
475
+ self._current_references.append(
476
+ "from %s import %s as %s"
477
+ % (relative_import, obj_import_name, objname)
478
+ )
479
+ self._pending_modules.append((parent_alias, parent_module.__name__))
480
+ else:
481
+ debug.stubgen_exec("Skipping object %s" % objname)
265
482
 
266
- def _get_element_name_with_module(self, element: Union[TypeVar, type, Any]) -> str:
483
+ def _get_element_name_with_module(
484
+ self, element: Union[TypeVar, type, Any], force_import=False
485
+ ) -> str:
267
486
  # The element can be a string, for example "def f() -> 'SameClass':..."
268
487
  def _add_to_import(name):
269
488
  if name != self._current_module_name:
270
489
  self._imports.add(name)
271
490
 
272
491
  def _add_to_typing_check(name, is_module=False):
273
- # if name != self._current_module_name:
274
- # self._typing_imports.add(name)
275
- #
276
492
  if name == "None":
277
493
  return
278
494
  if is_module:
@@ -286,15 +502,38 @@ class StubGenerator:
286
502
  # the current file
287
503
  self._typing_imports.add(splits[0])
288
504
 
505
+ def _format_qualified_class_name(cls: type) -> str:
506
+ """Helper to format a class with its qualified module name"""
507
+ # Special case for NoneType - return None
508
+ if cls.__name__ == "NoneType":
509
+ return "None"
510
+
511
+ module = inspect.getmodule(cls)
512
+ if (
513
+ module
514
+ and module.__name__ != "builtins"
515
+ and module.__name__ != "__main__"
516
+ ):
517
+ module_name = self._get_module_name_alias(module.__name__)
518
+ _add_to_typing_check(module_name, is_module=True)
519
+ return f"{module_name}.{cls.__name__}"
520
+ else:
521
+ return cls.__name__
522
+
289
523
  if isinstance(element, str):
524
+ # Special case for self referential things (particularly in a class)
525
+ if element == self._current_name:
526
+ return '"%s"' % element
290
527
  # We first try to eval the annotation because with the annotations future
291
528
  # it is always a string
292
529
  try:
293
530
  potential_element = eval(
294
531
  element,
295
- self._current_parent_module.__dict__
296
- if self._current_parent_module
297
- else None,
532
+ (
533
+ self._current_parent_module.__dict__
534
+ if self._current_parent_module
535
+ else None
536
+ ),
298
537
  )
299
538
  if potential_element:
300
539
  element = potential_element
@@ -302,6 +541,9 @@ class StubGenerator:
302
541
  pass
303
542
 
304
543
  if isinstance(element, str):
544
+ # If we are in our "safe" modules, make sure we alias properly
545
+ if any(element.startswith(x) for x in self._safe_modules):
546
+ element = self._get_module_name_alias(element)
305
547
  _add_to_typing_check(element)
306
548
  return '"%s"' % element
307
549
  # 3.10+ has NewType as a class but not before so hack around to check for NewType
@@ -321,26 +563,25 @@ class StubGenerator:
321
563
  return "None"
322
564
  return element.__name__
323
565
 
324
- _add_to_typing_check(module.__name__, is_module=True)
325
- if module.__name__ != self._current_module_name:
326
- return "{0}.{1}".format(module.__name__, element.__name__)
566
+ module_name = self._get_module_name_alias(module.__name__)
567
+ if force_import:
568
+ _add_to_import(module_name.split(".")[0])
569
+ _add_to_typing_check(module_name, is_module=True)
570
+ if module_name != self._current_module_name:
571
+ return "{0}.{1}".format(module_name, element.__name__)
327
572
  else:
328
573
  return element.__name__
329
574
  elif isinstance(element, type(Ellipsis)):
330
575
  return "..."
331
- # elif (
332
- # isinstance(element, typing._GenericAlias)
333
- # and hasattr(element, "_name")
334
- # and element._name in ("List", "Tuple", "Dict", "Set")
335
- # ):
336
- # # 3.7 has these as _GenericAlias but they don't behave like the ones in 3.10
337
- # _add_to_import("typing")
338
- # return str(element)
339
576
  elif isinstance(element, typing._GenericAlias):
340
577
  # We need to check things recursively in __args__ if it exists
341
578
  args_str = []
342
579
  for arg in getattr(element, "__args__", []):
343
- args_str.append(self._get_element_name_with_module(arg))
580
+ # Special handling for class objects in type arguments
581
+ if isinstance(arg, type):
582
+ args_str.append(_format_qualified_class_name(arg))
583
+ else:
584
+ args_str.append(self._get_element_name_with_module(arg))
344
585
 
345
586
  _add_to_import("typing")
346
587
  if element._name:
@@ -355,12 +596,15 @@ class StubGenerator:
355
596
  args_str = [call_args, args_str[-1]]
356
597
  return "typing.%s[%s]" % (element._name, ", ".join(args_str))
357
598
  else:
358
- return "%s[%s]" % (element.__origin__, ", ".join(args_str))
599
+ # Handle the case where we have a generic type without a _name
600
+ origin = element.__origin__
601
+ if isinstance(origin, type):
602
+ origin_str = _format_qualified_class_name(origin)
603
+ else:
604
+ origin_str = str(origin)
605
+ return "%s[%s]" % (origin_str, ", ".join(args_str))
359
606
  elif isinstance(element, ForwardRef):
360
- f_arg = element.__forward_arg__
361
- # if f_arg in ("Run", "Task"): # HACK -- forward references in current.py
362
- # _add_to_import("metaflow")
363
- # f_arg = "metaflow.%s" % f_arg
607
+ f_arg = self._get_module_name_alias(element.__forward_arg__)
364
608
  _add_to_typing_check(f_arg)
365
609
  return '"%s"' % f_arg
366
610
  elif inspect.getmodule(element) == inspect.getmodule(typing):
@@ -370,9 +614,17 @@ class StubGenerator:
370
614
  return "typing.NamedTuple"
371
615
  return str(element)
372
616
  else:
373
- raise RuntimeError(
374
- "Does not handle element %s of type %s" % (str(element), type(element))
375
- )
617
+ if hasattr(element, "__module__"):
618
+ elem_module = self._get_module_name_alias(element.__module__)
619
+ if elem_module == "builtins":
620
+ return getattr(element, "__name__", str(element))
621
+ _add_to_typing_check(elem_module, is_module=True)
622
+ return "{0}.{1}".format(
623
+ elem_module, getattr(element, "__name__", element)
624
+ )
625
+ else:
626
+ # A constant
627
+ return str(element)
376
628
 
377
629
  def _exploit_annotation(self, annotation: Any, starting: str = ": ") -> str:
378
630
  annotation_string = ""
@@ -383,19 +635,57 @@ class StubGenerator:
383
635
  return annotation_string
384
636
 
385
637
  def _generate_class_stub(self, name: str, clazz: type) -> str:
638
+ debug.stubgen_exec("Generating class stub for %s" % name)
639
+ skip_init = issubclass(clazz, (TriggeredRun, DeployedFlow))
640
+ if issubclass(clazz, DeployerImpl):
641
+ if clazz.TYPE is not None:
642
+ clazz_type = clazz.TYPE.replace("-", "_")
643
+ self._deployer_injected_methods.setdefault(clazz_type, {})[
644
+ "deployer"
645
+ ] = (self._current_module_name + "." + name)
646
+
647
+ # Handle TypedDict gracefully for Python 3.7 compatibility
648
+ # _TypedDictMeta is not available in Python 3.7
649
+ typed_dict_meta = getattr(typing, "_TypedDictMeta", None)
650
+ if typed_dict_meta is not None and isinstance(clazz, typed_dict_meta):
651
+ self._sub_module_imports.add(("typing", "TypedDict"))
652
+ total_flag = getattr(clazz, "__total__", False)
653
+ buff = StringIO()
654
+ # Emit the TypedDict base and total flag
655
+ buff.write(f"class {name}(TypedDict, total={total_flag}):\n")
656
+ # Write out each field from __annotations__
657
+ for field_name, field_type in clazz.__annotations__.items():
658
+ ann = self._get_element_name_with_module(field_type)
659
+ buff.write(f"{TAB}{field_name}: {ann}\n")
660
+ return buff.getvalue()
661
+
386
662
  buff = StringIO()
387
663
  # Class prototype
388
664
  buff.write("class " + name.split(".")[-1] + "(")
389
665
 
390
666
  # Add super classes
391
667
  for c in clazz.__bases__:
392
- name_with_module = self._get_element_name_with_module(c)
668
+ name_with_module = self._get_element_name_with_module(c, force_import=True)
393
669
  buff.write(name_with_module + ", ")
394
670
 
395
671
  # Add metaclass
396
- name_with_module = self._get_element_name_with_module(clazz.__class__)
672
+ name_with_module = self._get_element_name_with_module(
673
+ clazz.__class__, force_import=True
674
+ )
397
675
  buff.write("metaclass=" + name_with_module + "):\n")
398
676
 
677
+ # Add class docstring
678
+ if clazz.__doc__:
679
+ buff.write('%s"""\n' % TAB)
680
+ my_doc = inspect.cleandoc(clazz.__doc__)
681
+ init_blank = True
682
+ for line in my_doc.split("\n"):
683
+ if init_blank and len(line.strip()) == 0:
684
+ continue
685
+ init_blank = False
686
+ buff.write("%s%s\n" % (TAB, line.rstrip()))
687
+ buff.write('%s"""\n' % TAB)
688
+
399
689
  # For NamedTuple, we have __annotations__ but no __init__. In that case,
400
690
  # we are going to "create" a __init__ function with the annotations
401
691
  # to show what the class takes.
@@ -410,6 +700,8 @@ class StubGenerator:
410
700
  func_deco = "@classmethod"
411
701
  element = element.__func__
412
702
  if key == "__init__":
703
+ if skip_init:
704
+ continue
413
705
  init_func = element
414
706
  elif key == "__annotations__":
415
707
  annotation_dict = element
@@ -417,11 +709,201 @@ class StubGenerator:
417
709
  if not element.__name__.startswith("_") or element.__name__.startswith(
418
710
  "__"
419
711
  ):
420
- buff.write(
421
- self._generate_function_stub(
422
- key, element, indentation=TAB, deco=func_deco
712
+ if (
713
+ clazz == Deployer
714
+ and element.__name__ in self._deployer_injected_methods
715
+ ):
716
+ # This is a method that was injected. It has docs but we need
717
+ # to parse it to generate the proper signature
718
+ func_doc = inspect.cleandoc(element.__doc__)
719
+ docs = split_docs(
720
+ func_doc,
721
+ [
722
+ ("func_doc", StartEnd(0, 0)),
723
+ (
724
+ "param_doc",
725
+ param_section_header.search(func_doc)
726
+ or StartEnd(len(func_doc), len(func_doc)),
727
+ ),
728
+ (
729
+ "return_doc",
730
+ return_section_header.search(func_doc)
731
+ or StartEnd(len(func_doc), len(func_doc)),
732
+ ),
733
+ ],
423
734
  )
424
- )
735
+
736
+ parameters, _ = parse_params_from_doc(docs["param_doc"])
737
+ return_type = self._deployer_injected_methods[element.__name__][
738
+ "deployer"
739
+ ]
740
+
741
+ buff.write(
742
+ self._generate_function_stub(
743
+ key,
744
+ element,
745
+ sign=[
746
+ inspect.Signature(
747
+ parameters=[
748
+ inspect.Parameter(
749
+ "self",
750
+ inspect.Parameter.POSITIONAL_OR_KEYWORD,
751
+ )
752
+ ]
753
+ + parameters,
754
+ return_annotation=return_type,
755
+ )
756
+ ],
757
+ indentation=TAB,
758
+ deco=func_deco,
759
+ )
760
+ )
761
+ elif (
762
+ clazz == DeployedFlow and element.__name__ == "from_deployment"
763
+ ):
764
+ # We simply update the signature to list the return
765
+ # type as a union of all possible deployers
766
+ func_doc = inspect.cleandoc(element.__doc__)
767
+ docs = split_docs(
768
+ func_doc,
769
+ [
770
+ ("func_doc", StartEnd(0, 0)),
771
+ (
772
+ "param_doc",
773
+ param_section_header.search(func_doc)
774
+ or StartEnd(len(func_doc), len(func_doc)),
775
+ ),
776
+ (
777
+ "return_doc",
778
+ return_section_header.search(func_doc)
779
+ or StartEnd(len(func_doc), len(func_doc)),
780
+ ),
781
+ ],
782
+ )
783
+
784
+ parameters, _ = parse_params_from_doc(docs["param_doc"])
785
+
786
+ def _create_multi_type(*l):
787
+ return typing.Union[l]
788
+
789
+ all_types = [
790
+ v["from_deployment"][0]
791
+ for v in self._deployer_injected_methods.values()
792
+ ]
793
+
794
+ if len(all_types) > 1:
795
+ return_type = _create_multi_type(*all_types)
796
+ else:
797
+ return_type = all_types[0] if len(all_types) else None
798
+
799
+ buff.write(
800
+ self._generate_function_stub(
801
+ key,
802
+ element,
803
+ sign=[
804
+ inspect.Signature(
805
+ parameters=[
806
+ inspect.Parameter(
807
+ "cls",
808
+ inspect.Parameter.POSITIONAL_OR_KEYWORD,
809
+ )
810
+ ]
811
+ + parameters,
812
+ return_annotation=return_type,
813
+ )
814
+ ],
815
+ indentation=TAB,
816
+ doc=docs["func_doc"]
817
+ + "\n\nParameters\n----------\n"
818
+ + docs["param_doc"]
819
+ + "\n\nReturns\n-------\n"
820
+ + "%s\nA `DeployedFlow` object" % str(return_type),
821
+ deco=func_deco,
822
+ )
823
+ )
824
+ elif (
825
+ clazz == DeployedFlow
826
+ and element.__name__.startswith("from_")
827
+ and element.__name__[5:] in self._deployer_injected_methods
828
+ ):
829
+ # Get the doc from the from_deployment method stored in
830
+ # self._deployer_injected_methods
831
+ func_doc = inspect.cleandoc(
832
+ self._deployer_injected_methods[element.__name__[5:]][
833
+ "from_deployment"
834
+ ][1]
835
+ or ""
836
+ )
837
+ docs = split_docs(
838
+ func_doc,
839
+ [
840
+ ("func_doc", StartEnd(0, 0)),
841
+ (
842
+ "param_doc",
843
+ param_section_header.search(func_doc)
844
+ or StartEnd(len(func_doc), len(func_doc)),
845
+ ),
846
+ (
847
+ "return_doc",
848
+ return_section_header.search(func_doc)
849
+ or StartEnd(len(func_doc), len(func_doc)),
850
+ ),
851
+ ],
852
+ )
853
+
854
+ parameters, _ = parse_params_from_doc(docs["param_doc"])
855
+ return_type = self._deployer_injected_methods[
856
+ element.__name__[5:]
857
+ ]["from_deployment"][0]
858
+
859
+ buff.write(
860
+ self._generate_function_stub(
861
+ key,
862
+ element,
863
+ sign=[
864
+ inspect.Signature(
865
+ parameters=[
866
+ inspect.Parameter(
867
+ "cls",
868
+ inspect.Parameter.POSITIONAL_OR_KEYWORD,
869
+ )
870
+ ]
871
+ + parameters,
872
+ return_annotation=return_type,
873
+ )
874
+ ],
875
+ indentation=TAB,
876
+ doc=docs["func_doc"]
877
+ + "\n\nParameters\n----------\n"
878
+ + docs["param_doc"]
879
+ + "\n\nReturns\n-------\n"
880
+ + docs["return_doc"],
881
+ deco=func_deco,
882
+ )
883
+ )
884
+ else:
885
+ if (
886
+ issubclass(clazz, DeployedFlow)
887
+ and clazz.TYPE is not None
888
+ and key == "from_deployment"
889
+ ):
890
+ clazz_type = clazz.TYPE.replace("-", "_")
891
+ # Record docstring for this function
892
+ self._deployer_injected_methods.setdefault(clazz_type, {})[
893
+ "from_deployment"
894
+ ] = (
895
+ self._current_module_name + "." + name,
896
+ element.__doc__,
897
+ )
898
+ buff.write(
899
+ self._generate_function_stub(
900
+ key,
901
+ element,
902
+ indentation=TAB,
903
+ deco=func_deco,
904
+ )
905
+ )
906
+
425
907
  elif isinstance(element, property):
426
908
  if element.fget:
427
909
  buff.write(
@@ -436,20 +918,17 @@ class StubGenerator:
436
918
  )
437
919
  )
438
920
 
439
- # Special handling for the current module
440
- if (
441
- self._current_module_name == METAFLOW_CURRENT_MODULE_NAME
442
- and name == "Current"
443
- ):
921
+ # Special handling of classes that have injected methods
922
+ if clazz == Current:
444
923
  # Multiple decorators can add the same object (trigger and trigger_on_finish)
445
924
  # as examples so we sort it out.
446
925
  resulting_dict = (
447
926
  dict()
448
927
  ) # type Dict[str, List[inspect.Signature, str, List[str]]]
449
- for project_name, addl_current in self._addl_current.items():
928
+ for deco_name, addl_current in self._addl_current.items():
450
929
  for name, (sign, doc) in addl_current.items():
451
930
  r = resulting_dict.setdefault(name, [sign, doc, []])
452
- r[2].append("@%s" % project_name)
931
+ r[2].append("@%s" % deco_name)
453
932
  for name, (sign, doc, decos) in resulting_dict.items():
454
933
  buff.write(
455
934
  self._generate_function_stub(
@@ -462,7 +941,8 @@ class StubGenerator:
462
941
  deco="@property",
463
942
  )
464
943
  )
465
- if init_func is None and annotation_dict:
944
+
945
+ if not skip_init and init_func is None and annotation_dict:
466
946
  buff.write(
467
947
  self._generate_function_stub(
468
948
  "__init__",
@@ -508,121 +988,30 @@ class StubGenerator:
508
988
  self._typevars["StepFlag"] = StepFlag
509
989
 
510
990
  raw_doc = inspect.cleandoc(raw_doc)
511
- has_parameters = param_section_header.search(raw_doc)
512
- has_add_to_current = add_to_current_header.search(raw_doc)
513
-
514
- if has_parameters and has_add_to_current:
515
- doc = raw_doc[has_parameters.end() : has_add_to_current.start()]
516
- add_to_current_doc = raw_doc[has_add_to_current.end() :]
517
- raw_doc = raw_doc[: has_add_to_current.start()]
518
- elif has_parameters:
519
- doc = raw_doc[has_parameters.end() :]
520
- add_to_current_doc = None
521
- elif has_add_to_current:
522
- add_to_current_doc = raw_doc[has_add_to_current.end() :]
523
- raw_doc = raw_doc[: has_add_to_current.start()]
524
- doc = ""
525
- else:
526
- doc = ""
527
- add_to_current_doc = None
528
- parameters = []
529
- no_arg_version = True
530
- for line in doc.splitlines():
531
- if non_indented_line.match(line):
532
- match = param_name_type.match(line)
533
- arg_name = type_name = is_optional = default = None
534
- default_set = False
535
- if match is not None:
536
- arg_name = match.group("name")
537
- type_name = match.group("type")
538
- if type_name is not None:
539
- type_detail = type_annotations.match(type_name)
540
- if type_detail is not None:
541
- type_name = type_detail.group("type")
542
- is_optional = type_detail.group("optional") is not None
543
- default = type_detail.group("default")
544
- if default:
545
- default_set = True
546
- try:
547
- default = eval(default)
548
- except:
549
- pass
550
- try:
551
- type_name = eval(type_name)
552
- except:
553
- pass
554
- parameters.append(
555
- inspect.Parameter(
556
- name=arg_name,
557
- kind=inspect.Parameter.KEYWORD_ONLY,
558
- default=default
559
- if default_set
560
- else None
561
- if is_optional
562
- else inspect.Parameter.empty,
563
- annotation=Optional[type_name]
564
- if is_optional
565
- else type_name,
566
- )
567
- )
568
- if not default_set:
569
- # If we don't have a default set for any parameter, we can't
570
- # have a no-arg version since the decorator would be incomplete
571
- no_arg_version = False
572
- if add_to_current_doc:
573
- current_property = None
574
- current_return_type = None
575
- current_property_indent = None
576
- current_doc = []
577
- add_to_current = dict() # type: Dict[str, Tuple[inspect.Signature, str]]
578
-
579
- def _add():
580
- if current_property:
581
- add_to_current[current_property] = (
582
- inspect.Signature(
583
- [
584
- inspect.Parameter(
585
- "self", inspect.Parameter.POSITIONAL_OR_KEYWORD
586
- )
587
- ],
588
- return_annotation=current_return_type,
589
- ),
590
- "\n".join(current_doc),
591
- )
991
+ section_boundaries = [
992
+ ("func_doc", StartEnd(0, 0)),
993
+ (
994
+ "param_doc",
995
+ param_section_header.search(raw_doc)
996
+ or StartEnd(len(raw_doc), len(raw_doc)),
997
+ ),
998
+ (
999
+ "add_to_current_doc",
1000
+ add_to_current_header.search(raw_doc)
1001
+ or StartEnd(len(raw_doc), len(raw_doc)),
1002
+ ),
1003
+ ]
592
1004
 
593
- for line in add_to_current_doc.splitlines():
594
- # Parse stanzas that look like the following:
595
- # <property-name> -> type
596
- # indented doc string
597
- if current_property_indent is not None and (
598
- line.startswith(current_property_indent + " ") or line.strip() == ""
599
- ):
600
- offset = len(current_property_indent)
601
- if line.lstrip().startswith("@@ "):
602
- line = line.replace("@@ ", "")
603
- current_doc.append(line[offset:].rstrip())
604
- else:
605
- if line.strip() == 0:
606
- continue
607
- if current_property:
608
- # Ends a property stanza
609
- _add()
610
- # Now start a new one
611
- line = line.rstrip()
612
- current_property_indent = line[: len(line) - len(line.lstrip())]
613
- # This is a line so we split it using "->"
614
- current_property, current_return_type = line.split("->")
615
- current_property = current_property.strip()
616
- current_return_type = current_return_type.strip()
617
- current_doc = []
618
- _add()
619
-
620
- self._addl_current[name] = add_to_current
1005
+ docs = split_docs(raw_doc, section_boundaries)
1006
+ parameters, no_arg_version = parse_params_from_doc(docs["param_doc"])
1007
+
1008
+ if docs["add_to_current_doc"]:
1009
+ self._addl_current[name] = parse_add_to_docs(docs["add_to_current_doc"])
621
1010
 
622
1011
  result = []
623
1012
  if no_arg_version:
624
1013
  if is_flow_decorator:
625
- if has_parameters:
1014
+ if docs["param_doc"]:
626
1015
  result.append(
627
1016
  (
628
1017
  inspect.Signature(
@@ -651,7 +1040,7 @@ class StubGenerator:
651
1040
  ),
652
1041
  )
653
1042
  else:
654
- if has_parameters:
1043
+ if docs["param_doc"]:
655
1044
  result.append(
656
1045
  (
657
1046
  inspect.Signature(
@@ -706,24 +1095,31 @@ class StubGenerator:
706
1095
  result = result + [
707
1096
  (
708
1097
  inspect.Signature(
709
- parameters=[
710
- inspect.Parameter(
711
- name="f",
712
- kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
713
- annotation=Optional[typing.Type[FlowSpecDerived]],
714
- default=None
715
- if no_arg_version
716
- else inspect.Parameter.empty,
717
- )
718
- ]
719
- + parameters
720
- if no_arg_version
721
- else [] + parameters,
722
- return_annotation=inspect.Signature.empty
723
- if no_arg_version
724
- else Callable[
725
- [typing.Type[FlowSpecDerived]], typing.Type[FlowSpecDerived]
726
- ],
1098
+ parameters=(
1099
+ [
1100
+ inspect.Parameter(
1101
+ name="f",
1102
+ kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
1103
+ annotation=Optional[typing.Type[FlowSpecDerived]],
1104
+ default=(
1105
+ None
1106
+ if no_arg_version
1107
+ else inspect.Parameter.empty
1108
+ ),
1109
+ )
1110
+ ]
1111
+ + parameters
1112
+ if no_arg_version
1113
+ else [] + parameters
1114
+ ),
1115
+ return_annotation=(
1116
+ inspect.Signature.empty
1117
+ if no_arg_version
1118
+ else Callable[
1119
+ [typing.Type[FlowSpecDerived]],
1120
+ typing.Type[FlowSpecDerived],
1121
+ ]
1122
+ ),
727
1123
  ),
728
1124
  "",
729
1125
  ),
@@ -732,24 +1128,30 @@ class StubGenerator:
732
1128
  result = result + [
733
1129
  (
734
1130
  inspect.Signature(
735
- parameters=[
736
- inspect.Parameter(
737
- name="f",
738
- kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
739
- annotation=Optional[MetaflowStepFunction],
740
- default=None
741
- if no_arg_version
742
- else inspect.Parameter.empty,
743
- )
744
- ]
745
- + parameters
746
- if no_arg_version
747
- else [] + parameters,
748
- return_annotation=inspect.Signature.empty
749
- if no_arg_version
750
- else typing.Callable[
751
- [MetaflowStepFunction], MetaflowStepFunction
752
- ],
1131
+ parameters=(
1132
+ [
1133
+ inspect.Parameter(
1134
+ name="f",
1135
+ kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
1136
+ annotation=Optional[MetaflowStepFunction],
1137
+ default=(
1138
+ None
1139
+ if no_arg_version
1140
+ else inspect.Parameter.empty
1141
+ ),
1142
+ )
1143
+ ]
1144
+ + parameters
1145
+ if no_arg_version
1146
+ else [] + parameters
1147
+ ),
1148
+ return_annotation=(
1149
+ inspect.Signature.empty
1150
+ if no_arg_version
1151
+ else typing.Callable[
1152
+ [MetaflowStepFunction], MetaflowStepFunction
1153
+ ]
1154
+ ),
753
1155
  ),
754
1156
  "",
755
1157
  ),
@@ -760,8 +1162,17 @@ class StubGenerator:
760
1162
  result = result[1:]
761
1163
  # Add doc to first and last overloads. Jedi uses the last one and pycharm
762
1164
  # the first one. Go figure.
763
- result[0] = (result[0][0], raw_doc)
764
- result[-1] = (result[-1][0], raw_doc)
1165
+ result_docstring = docs["func_doc"]
1166
+ if docs["param_doc"]:
1167
+ result_docstring += "\nParameters\n----------\n" + docs["param_doc"]
1168
+ result[0] = (
1169
+ result[0][0],
1170
+ result_docstring,
1171
+ )
1172
+ result[-1] = (
1173
+ result[-1][0],
1174
+ result_docstring,
1175
+ )
765
1176
  return result
766
1177
 
767
1178
  def _generate_function_stub(
@@ -773,11 +1184,12 @@ class StubGenerator:
773
1184
  doc: Optional[str] = None,
774
1185
  deco: Optional[str] = None,
775
1186
  ) -> str:
1187
+ debug.stubgen_exec("Generating function stub for %s" % name)
1188
+
776
1189
  def exploit_default(default_value: Any) -> Optional[str]:
777
- if (
778
- default_value != inspect.Parameter.empty
779
- and type(default_value).__module__ == "builtins"
780
- ):
1190
+ if default_value == inspect.Parameter.empty:
1191
+ return None
1192
+ if type(default_value).__module__ == "builtins":
781
1193
  if isinstance(default_value, list):
782
1194
  return (
783
1195
  "["
@@ -807,22 +1219,22 @@ class StubGenerator:
807
1219
  )
808
1220
  + "}"
809
1221
  )
810
- elif str(default_value).startswith("<"):
811
- if default_value.__module__ == "builtins":
812
- return default_value.__name__
813
- else:
814
- self._typing_imports.add(default_value.__module__)
815
- return ".".join(
816
- [default_value.__module__, default_value.__name__]
817
- )
1222
+ elif isinstance(default_value, str):
1223
+ return repr(default_value) # Use repr() for proper escaping
1224
+ elif isinstance(default_value, (int, float, bool)):
1225
+ return str(default_value)
1226
+ elif default_value is None:
1227
+ return "None"
818
1228
  else:
819
- return (
820
- str(default_value)
821
- if not isinstance(default_value, str)
822
- else '"' + default_value + '"'
823
- )
1229
+ return "..." # For other built-in types not explicitly handled
1230
+ elif inspect.isclass(default_value) or inspect.isfunction(default_value):
1231
+ if default_value.__module__ == "builtins":
1232
+ return default_value.__name__
1233
+ else:
1234
+ self._typing_imports.add(default_value.__module__)
1235
+ return ".".join([default_value.__module__, default_value.__name__])
824
1236
  else:
825
- return None
1237
+ return "..." # For complex objects like class instances
826
1238
 
827
1239
  buff = StringIO()
828
1240
  if sign is None and func is None:
@@ -838,6 +1250,10 @@ class StubGenerator:
838
1250
  # value
839
1251
  return ""
840
1252
  doc = doc or func.__doc__
1253
+ if doc == "STUBGEN_IGNORE":
1254
+ # Ignore methods that have STUBGEN_IGNORE. Used to ignore certain
1255
+ # methods for the Deployer
1256
+ return ""
841
1257
  indentation = indentation or ""
842
1258
 
843
1259
  # Deal with overload annotations -- the last one will be non overloaded and
@@ -851,29 +1267,40 @@ class StubGenerator:
851
1267
  buff.write("\n")
852
1268
 
853
1269
  if do_overload and count < len(sign) - 1:
1270
+ # According to mypy, we should have this on all variants but
1271
+ # some IDEs seem to prefer if there is one non-overloaded
1272
+ # This also changes our checks so if changing, modify tests
854
1273
  buff.write(indentation + "@typing.overload\n")
855
1274
  if deco:
856
1275
  buff.write(indentation + deco + "\n")
857
1276
  buff.write(indentation + "def " + name + "(")
858
1277
  kw_only_param = False
1278
+ has_var_args = False
859
1279
  for i, (par_name, parameter) in enumerate(my_sign.parameters.items()):
860
1280
  annotation = self._exploit_annotation(parameter.annotation)
861
1281
  default = exploit_default(parameter.default)
862
1282
 
863
- if kw_only_param and parameter.kind != inspect.Parameter.KEYWORD_ONLY:
1283
+ if (
1284
+ kw_only_param
1285
+ and not has_var_args
1286
+ and parameter.kind != inspect.Parameter.KEYWORD_ONLY
1287
+ ):
864
1288
  raise RuntimeError(
865
1289
  "In function '%s': cannot have a positional parameter after a "
866
1290
  "keyword only parameter" % name
867
1291
  )
1292
+
868
1293
  if (
869
1294
  parameter.kind == inspect.Parameter.KEYWORD_ONLY
870
1295
  and not kw_only_param
1296
+ and not has_var_args
871
1297
  ):
872
1298
  kw_only_param = True
873
1299
  buff.write("*, ")
874
1300
  if parameter.kind == inspect.Parameter.VAR_KEYWORD:
875
1301
  par_name = "**%s" % par_name
876
1302
  elif parameter.kind == inspect.Parameter.VAR_POSITIONAL:
1303
+ has_var_args = True
877
1304
  par_name = "*%s" % par_name
878
1305
 
879
1306
  if default:
@@ -890,7 +1317,7 @@ class StubGenerator:
890
1317
 
891
1318
  if (count == 0 or count == len(sign) - 1) and doc is not None:
892
1319
  buff.write('%s%s"""\n' % (indentation, TAB))
893
- my_doc = cast(str, deindent_docstring(doc))
1320
+ my_doc = inspect.cleandoc(doc)
894
1321
  init_blank = True
895
1322
  for line in my_doc.split("\n"):
896
1323
  if init_blank and len(line.strip()) == 0:
@@ -909,6 +1336,7 @@ class StubGenerator:
909
1336
  def _generate_stubs(self):
910
1337
  for name, attr in self._current_objects.items():
911
1338
  self._current_parent_module = inspect.getmodule(attr)
1339
+ self._current_name = name
912
1340
  if inspect.isclass(attr):
913
1341
  self._stubs.append(self._generate_class_stub(name, attr))
914
1342
  elif inspect.isfunction(attr):
@@ -991,6 +1419,29 @@ class StubGenerator:
991
1419
  elif not inspect.ismodule(attr):
992
1420
  self._stubs.append(self._generate_generic_stub(name, attr))
993
1421
 
1422
+ def _write_header(self, f, width):
1423
+ title_line = "Auto-generated Metaflow stub file"
1424
+ title_white_space = (width - len(title_line)) / 2
1425
+ title_line = "#%s%s%s#\n" % (
1426
+ " " * math.floor(title_white_space),
1427
+ title_line,
1428
+ " " * math.ceil(title_white_space),
1429
+ )
1430
+ f.write(
1431
+ "#" * (width + 2)
1432
+ + "\n"
1433
+ + title_line
1434
+ + "# MF version: %s%s#\n"
1435
+ % (self._mf_version, " " * (width - 13 - len(self._mf_version)))
1436
+ + "# Generated on %s%s#\n"
1437
+ % (
1438
+ datetime.fromtimestamp(time.time()).isoformat(),
1439
+ " " * (width - 14 - 26),
1440
+ )
1441
+ + "#" * (width + 2)
1442
+ + "\n\n"
1443
+ )
1444
+
994
1445
  def write_out(self):
995
1446
  out_dir = self._output_dir
996
1447
  os.makedirs(out_dir, exist_ok=True)
@@ -1004,72 +1455,83 @@ class StubGenerator:
1004
1455
  "%s %s"
1005
1456
  % (self._mf_version, datetime.fromtimestamp(time.time()).isoformat())
1006
1457
  )
1007
- while len(self._pending_modules) != 0:
1008
- module_name = self._pending_modules.pop(0)
1458
+ post_process_modules = []
1459
+ is_post_processing = False
1460
+ while len(self._pending_modules) != 0 or len(post_process_modules) != 0:
1461
+ if is_post_processing or len(self._pending_modules) == 0:
1462
+ is_post_processing = True
1463
+ module_alias, module_name = post_process_modules.pop(0)
1464
+ else:
1465
+ module_alias, module_name = self._pending_modules.pop(0)
1009
1466
  # Skip vendored stuff
1010
- if module_name.startswith("metaflow._vendor"):
1467
+ if module_alias.startswith("metaflow._vendor") or module_name.startswith(
1468
+ "metaflow._vendor"
1469
+ ):
1011
1470
  continue
1012
- # We delay current module
1471
+ # We delay current module and deployer module to the end since they
1472
+ # depend on info we gather elsewhere
1013
1473
  if (
1014
- module_name == METAFLOW_CURRENT_MODULE_NAME
1015
- and len(set(self._pending_modules)) > 1
1474
+ module_alias
1475
+ in (
1476
+ METAFLOW_CURRENT_MODULE_NAME,
1477
+ METAFLOW_DEPLOYER_MODULE_NAME,
1478
+ )
1479
+ and len(self._pending_modules) != 0
1016
1480
  ):
1017
- self._pending_modules.append(module_name)
1481
+ post_process_modules.append((module_alias, module_name))
1018
1482
  continue
1019
- if module_name in self._done_modules:
1483
+ if module_alias in self._done_modules:
1020
1484
  continue
1021
- self._done_modules.add(module_name)
1485
+ self._done_modules.add(module_alias)
1022
1486
  # If not, we process the module
1023
1487
  self._reset()
1024
- self._get_module(module_name)
1488
+ self._get_module(module_alias, module_name)
1489
+ if module_name == "metaflow" and not is_post_processing:
1490
+ # We will want to regenerate this at the end to take into account
1491
+ # any changes to the Deployer
1492
+ post_process_modules.append((module_name, module_name))
1493
+ self._done_modules.remove(module_name)
1494
+ continue
1025
1495
  self._generate_stubs()
1026
1496
 
1027
1497
  if hasattr(self._current_module, "__path__"):
1028
1498
  # This is a package (so a directory) and we are dealing with
1029
1499
  # a __init__.pyi type of case
1030
- dir_path = os.path.join(
1031
- self._output_dir, *self._current_module.__name__.split(".")[1:]
1032
- )
1500
+ dir_path = os.path.join(self._output_dir, *module_alias.split(".")[1:])
1033
1501
  else:
1034
1502
  # This is NOT a package so the original source file is not a __init__.py
1035
1503
  dir_path = os.path.join(
1036
- self._output_dir, *self._current_module.__name__.split(".")[1:-1]
1504
+ self._output_dir, *module_alias.split(".")[1:-1]
1037
1505
  )
1038
1506
  out_file = os.path.join(
1039
1507
  dir_path, os.path.basename(self._current_module.__file__) + "i"
1040
1508
  )
1041
1509
 
1510
+ width = 100
1511
+
1042
1512
  os.makedirs(os.path.dirname(out_file), exist_ok=True)
1513
+ # We want to make sure we always have a __init__.pyi in the directories
1514
+ # we are creating
1515
+ parts = dir_path.split(os.sep)[len(self._output_dir.split(os.sep)) :]
1516
+ for i in range(1, len(parts) + 1):
1517
+ init_file_path = os.path.join(
1518
+ self._output_dir, *parts[:i], "__init__.pyi"
1519
+ )
1520
+ if not os.path.exists(init_file_path):
1521
+ with open(init_file_path, mode="w", encoding="utf-8") as f:
1522
+ self._write_header(f, width)
1043
1523
 
1044
- width = 80
1045
- title_line = "Auto-generated Metaflow stub file"
1046
- title_white_space = (width - len(title_line)) / 2
1047
- title_line = "#%s%s%s#\n" % (
1048
- " " * math.floor(title_white_space),
1049
- title_line,
1050
- " " * math.ceil(title_white_space),
1051
- )
1052
1524
  with open(out_file, mode="w", encoding="utf-8") as f:
1053
- f.write(
1054
- "#" * (width + 2)
1055
- + "\n"
1056
- + title_line
1057
- + "# MF version: %s%s#\n"
1058
- % (self._mf_version, " " * (width - 13 - len(self._mf_version)))
1059
- + "# Generated on %s%s#\n"
1060
- % (
1061
- datetime.fromtimestamp(time.time()).isoformat(),
1062
- " " * (width - 14 - 26),
1063
- )
1064
- + "#" * (width + 2)
1065
- + "\n\n"
1066
- )
1525
+ self._write_header(f, width)
1526
+
1067
1527
  f.write("from __future__ import annotations\n\n")
1068
1528
  imported_typing = False
1069
1529
  for module in self._imports:
1070
1530
  f.write("import " + module + "\n")
1071
1531
  if module == "typing":
1072
1532
  imported_typing = True
1533
+ for module, sub_module in self._sub_module_imports:
1534
+ f.write(f"from {module} import {sub_module}\n")
1073
1535
  if self._typing_imports:
1074
1536
  if not imported_typing:
1075
1537
  f.write("import typing\n")
@@ -1091,8 +1553,14 @@ class StubGenerator:
1091
1553
  "%s = %s\n" % (type_name, new_type_to_str(type_var))
1092
1554
  )
1093
1555
  f.write("\n")
1556
+ for import_line in self._current_references:
1557
+ f.write(import_line + "\n")
1558
+ f.write("\n")
1094
1559
  for stub in self._stubs:
1095
1560
  f.write(stub + "\n")
1561
+ if is_post_processing:
1562
+ # Don't consider any pending modules if we are post processing
1563
+ self._pending_modules.clear()
1096
1564
 
1097
1565
 
1098
1566
  if __name__ == "__main__":