metaflow 2.12.28__py2.py3-none-any.whl → 2.12.29__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60) hide show
  1. metaflow/__init__.py +2 -3
  2. metaflow/cli.py +23 -13
  3. metaflow/client/core.py +2 -2
  4. metaflow/clone_util.py +1 -1
  5. metaflow/cmd/develop/stub_generator.py +623 -233
  6. metaflow/datastore/task_datastore.py +1 -1
  7. metaflow/extension_support/plugins.py +1 -0
  8. metaflow/flowspec.py +2 -2
  9. metaflow/includefile.py +8 -14
  10. metaflow/metaflow_config.py +4 -0
  11. metaflow/metaflow_current.py +1 -1
  12. metaflow/parameters.py +3 -0
  13. metaflow/plugins/__init__.py +12 -3
  14. metaflow/plugins/airflow/airflow_cli.py +5 -0
  15. metaflow/plugins/airflow/airflow_decorator.py +1 -1
  16. metaflow/plugins/argo/argo_workflows_decorator.py +1 -1
  17. metaflow/plugins/argo/argo_workflows_deployer.py +77 -263
  18. metaflow/plugins/argo/argo_workflows_deployer_objects.py +381 -0
  19. metaflow/plugins/aws/batch/batch_cli.py +1 -1
  20. metaflow/plugins/aws/batch/batch_decorator.py +2 -2
  21. metaflow/plugins/aws/step_functions/step_functions_cli.py +7 -0
  22. metaflow/plugins/aws/step_functions/step_functions_decorator.py +1 -1
  23. metaflow/plugins/aws/step_functions/step_functions_deployer.py +65 -224
  24. metaflow/plugins/aws/step_functions/step_functions_deployer_objects.py +236 -0
  25. metaflow/plugins/azure/includefile_support.py +2 -0
  26. metaflow/plugins/cards/card_cli.py +3 -2
  27. metaflow/plugins/cards/card_modules/components.py +9 -9
  28. metaflow/plugins/cards/card_server.py +39 -14
  29. metaflow/plugins/datatools/local.py +2 -0
  30. metaflow/plugins/datatools/s3/s3.py +2 -0
  31. metaflow/plugins/env_escape/__init__.py +3 -3
  32. metaflow/plugins/gcp/includefile_support.py +3 -0
  33. metaflow/plugins/kubernetes/kubernetes_cli.py +1 -1
  34. metaflow/plugins/kubernetes/kubernetes_decorator.py +5 -4
  35. metaflow/plugins/{metadata → metadata_providers}/local.py +2 -2
  36. metaflow/plugins/{metadata → metadata_providers}/service.py +2 -2
  37. metaflow/plugins/parallel_decorator.py +1 -1
  38. metaflow/plugins/pypi/conda_decorator.py +1 -1
  39. metaflow/plugins/test_unbounded_foreach_decorator.py +1 -1
  40. metaflow/runner/click_api.py +4 -0
  41. metaflow/runner/deployer.py +139 -269
  42. metaflow/runner/deployer_impl.py +167 -0
  43. metaflow/runner/metaflow_runner.py +10 -9
  44. metaflow/runner/nbdeploy.py +12 -13
  45. metaflow/runner/nbrun.py +3 -3
  46. metaflow/runner/utils.py +55 -8
  47. metaflow/runtime.py +1 -1
  48. metaflow/task.py +1 -1
  49. metaflow/version.py +1 -1
  50. {metaflow-2.12.28.dist-info → metaflow-2.12.29.dist-info}/METADATA +2 -2
  51. {metaflow-2.12.28.dist-info → metaflow-2.12.29.dist-info}/RECORD +60 -57
  52. /metaflow/{metadata → metadata_provider}/__init__.py +0 -0
  53. /metaflow/{metadata → metadata_provider}/heartbeat.py +0 -0
  54. /metaflow/{metadata → metadata_provider}/metadata.py +0 -0
  55. /metaflow/{metadata → metadata_provider}/util.py +0 -0
  56. /metaflow/plugins/{metadata → metadata_providers}/__init__.py +0 -0
  57. {metaflow-2.12.28.dist-info → metaflow-2.12.29.dist-info}/LICENSE +0 -0
  58. {metaflow-2.12.28.dist-info → metaflow-2.12.29.dist-info}/WHEEL +0 -0
  59. {metaflow-2.12.28.dist-info → metaflow-2.12.29.dist-info}/entry_points.txt +0 -0
  60. {metaflow-2.12.28.dist-info → metaflow-2.12.29.dist-info}/top_level.txt +0 -0
@@ -31,13 +31,17 @@ from metaflow import FlowSpec, step
31
31
  from metaflow.debug import debug
32
32
  from metaflow.decorators import Decorator, FlowDecorator
33
33
  from metaflow.extension_support import get_aliased_modules
34
- from metaflow.graph import deindent_docstring
34
+ from metaflow.metaflow_current import Current
35
35
  from metaflow.metaflow_version import get_version
36
+ from metaflow.runner.deployer import DeployedFlow, Deployer, TriggeredRun
37
+ from metaflow.runner.deployer_impl import DeployerImpl
36
38
 
37
39
  TAB = " "
38
40
  METAFLOW_CURRENT_MODULE_NAME = "metaflow.metaflow_current"
41
+ METAFLOW_DEPLOYER_MODULE_NAME = "metaflow.runner.deployer"
39
42
 
40
43
  param_section_header = re.compile(r"Parameters\s*\n----------\s*\n", flags=re.M)
44
+ return_section_header = re.compile(r"Returns\s*\n-------\s*\n", flags=re.M)
41
45
  add_to_current_header = re.compile(
42
46
  r"MF Add To Current\s*\n-----------------\s*\n", flags=re.M
43
47
  )
@@ -57,6 +61,20 @@ MetaflowStepFunction = Union[
57
61
  ]
58
62
 
59
63
 
64
+ # Object that has start() and end() like a Match object to make the code simpler when
65
+ # we are parsing different sections of doc
66
+ class StartEnd:
67
+ def __init__(self, start: int, end: int):
68
+ self._start = start
69
+ self._end = end
70
+
71
+ def start(self):
72
+ return self._start
73
+
74
+ def end(self):
75
+ return self._end
76
+
77
+
60
78
  def type_var_to_str(t: TypeVar) -> str:
61
79
  bound_name = None
62
80
  if t.__bound__ is not None:
@@ -92,6 +110,131 @@ def descend_object(object: str, options: Iterable[str]):
92
110
  return False
93
111
 
94
112
 
113
+ def parse_params_from_doc(doc: str) -> Tuple[List[inspect.Parameter], bool]:
114
+ parameters = []
115
+ no_arg_version = True
116
+ for line in doc.splitlines():
117
+ if non_indented_line.match(line):
118
+ match = param_name_type.match(line)
119
+ arg_name = type_name = is_optional = default = None
120
+ default_set = False
121
+ if match is not None:
122
+ arg_name = match.group("name")
123
+ type_name = match.group("type")
124
+ if type_name is not None:
125
+ type_detail = type_annotations.match(type_name)
126
+ if type_detail is not None:
127
+ type_name = type_detail.group("type")
128
+ is_optional = type_detail.group("optional") is not None
129
+ default = type_detail.group("default")
130
+ if default:
131
+ default_set = True
132
+ try:
133
+ default = eval(default)
134
+ except:
135
+ pass
136
+ try:
137
+ type_name = eval(type_name)
138
+ except:
139
+ pass
140
+ parameters.append(
141
+ inspect.Parameter(
142
+ name=arg_name,
143
+ kind=inspect.Parameter.KEYWORD_ONLY,
144
+ default=(
145
+ default
146
+ if default_set
147
+ else None if is_optional else inspect.Parameter.empty
148
+ ),
149
+ annotation=(Optional[type_name] if is_optional else type_name),
150
+ )
151
+ )
152
+ if not default_set:
153
+ # If we don't have a default set for any parameter, we can't
154
+ # have a no-arg version since the function would be incomplete
155
+ no_arg_version = False
156
+ return parameters, no_arg_version
157
+
158
+
159
+ def split_docs(
160
+ raw_doc: str, boundaries: List[Tuple[str, Union[StartEnd, re.Match]]]
161
+ ) -> Dict[str, str]:
162
+ docs = dict()
163
+ boundaries.sort(key=lambda x: x[1].start())
164
+
165
+ section_start = 0
166
+ for idx in range(1, len(boundaries)):
167
+ docs[boundaries[idx - 1][0]] = raw_doc[
168
+ section_start : boundaries[idx][1].start()
169
+ ]
170
+ section_start = boundaries[idx][1].end()
171
+ docs[boundaries[-1][0]] = raw_doc[section_start:]
172
+ return docs
173
+
174
+
175
+ def parse_add_to_docs(
176
+ raw_doc: str,
177
+ ) -> Dict[str, Union[Tuple[inspect.Signature, str], str]]:
178
+ prop = None
179
+ return_type = None
180
+ property_indent = None
181
+ doc = []
182
+ add_to_docs = dict() # type: Dict[str, Union[str, Tuple[inspect.Signature, str]]]
183
+
184
+ def _add():
185
+ if prop:
186
+ add_to_docs[prop] = (
187
+ inspect.Signature(
188
+ [
189
+ inspect.Parameter(
190
+ "self", inspect.Parameter.POSITIONAL_OR_KEYWORD
191
+ )
192
+ ],
193
+ return_annotation=return_type,
194
+ ),
195
+ "\n".join(doc),
196
+ )
197
+
198
+ for line in raw_doc.splitlines():
199
+ # Parse stanzas that look like the following:
200
+ # <property-name> -> type
201
+ # indented doc string
202
+ if property_indent is not None and (
203
+ line.startswith(property_indent + " ") or line.strip() == ""
204
+ ):
205
+ offset = len(property_indent)
206
+ if line.lstrip().startswith("@@ "):
207
+ line = line.replace("@@ ", "")
208
+ doc.append(line[offset:].rstrip())
209
+ else:
210
+ if line.strip() == 0:
211
+ continue
212
+ if prop:
213
+ # Ends a property stanza
214
+ _add()
215
+ # Now start a new one
216
+ line = line.rstrip()
217
+ property_indent = line[: len(line) - len(line.lstrip())]
218
+ # Either this has a -> to denote a property or it is a pure name
219
+ # to denote a reference to a function (starting with #)
220
+ line = line.lstrip()
221
+ if line.startswith("#"):
222
+ # The name of the function is the last part like metaflow.deployer.run
223
+ add_to_docs[line.split(".")[-1]] = line[1:]
224
+ continue
225
+ # This is a line so we split it using "->"
226
+ prop, return_type = line.split("->")
227
+ prop = prop.strip()
228
+ return_type = return_type.strip()
229
+ doc = []
230
+ _add()
231
+ return add_to_docs
232
+
233
+
234
+ def add_indent(indentation: str, text: str) -> str:
235
+ return "\n".join([indentation + line for line in text.splitlines()])
236
+
237
+
95
238
  class StubGenerator:
96
239
  """
97
240
  This class takes the name of a library as input and a directory as output.
@@ -121,11 +264,18 @@ class StubGenerator:
121
264
  os.environ["METAFLOW_STUBGEN"] = "1"
122
265
 
123
266
  self._write_generated_for = include_generated_for
124
- self._pending_modules = ["metaflow"] # type: List[str]
125
- self._pending_modules.extend(get_aliased_modules())
267
+ # First element is the name it should be installed in (alias) and second is the
268
+ # actual module name
269
+ self._pending_modules = [
270
+ ("metaflow", "metaflow")
271
+ ] # type: List[Tuple[str, str]]
126
272
  self._root_module = "metaflow."
127
273
  self._safe_modules = ["metaflow.", "metaflow_extensions."]
128
274
 
275
+ self._pending_modules.extend(
276
+ (self._get_module_name_alias(x), x) for x in get_aliased_modules()
277
+ )
278
+
129
279
  # We exclude some modules to not create a bunch of random non-user facing
130
280
  # .pyi files.
131
281
  self._exclude_modules = set(
@@ -151,7 +301,7 @@ class StubGenerator:
151
301
  "metaflow.package",
152
302
  "metaflow.plugins.datastores",
153
303
  "metaflow.plugins.env_escape",
154
- "metaflow.plugins.metadata",
304
+ "metaflow.plugins.metadata_providers",
155
305
  "metaflow.procpoll.py",
156
306
  "metaflow.R",
157
307
  "metaflow.runtime",
@@ -163,9 +313,16 @@ class StubGenerator:
163
313
  "metaflow._vendor",
164
314
  ]
165
315
  )
316
+
166
317
  self._done_modules = set() # type: Set[str]
167
318
  self._output_dir = output_dir
168
319
  self._mf_version = get_version()
320
+
321
+ # Contains the names of the methods that are injected in Deployer
322
+ self._deployer_injected_methods = (
323
+ {}
324
+ ) # type: Dict[str, Dict[str, Union[Tuple[str, str], str]]]
325
+ # Contains information to add to the Current object (injected by decorators)
169
326
  self._addl_current = (
170
327
  dict()
171
328
  ) # type: Dict[str, Dict[str, Tuple[inspect.Signature, str]]]
@@ -184,6 +341,7 @@ class StubGenerator:
184
341
  self._typevars = dict() # type: Dict[str, Union[TypeVar, type]]
185
342
  # Current objects in the file being processed
186
343
  self._current_objects = {} # type: Dict[str, Any]
344
+ self._current_references = [] # type: List[str]
187
345
  # Current stubs in the file being processed
188
346
  self._stubs = [] # type: List[str]
189
347
 
@@ -192,26 +350,78 @@ class StubGenerator:
192
350
  # the "globals()"
193
351
  self._current_parent_module = None # type: Optional[ModuleType]
194
352
 
195
- def _get_module(self, name):
196
- debug.stubgen_exec("Analyzing module %s ..." % name)
353
+ def _get_module_name_alias(self, module_name):
354
+ if any(
355
+ module_name.startswith(x) for x in self._safe_modules
356
+ ) and not module_name.startswith(self._root_module):
357
+ return self._root_module + ".".join(
358
+ ["mf_extensions", *module_name.split(".")[1:]]
359
+ )
360
+ return module_name
361
+
362
+ def _get_relative_import(
363
+ self, new_module_name, cur_module_name, is_init_module=False
364
+ ):
365
+ new_components = new_module_name.split(".")
366
+ cur_components = cur_module_name.split(".")
367
+ init_module_count = 1 if is_init_module else 0
368
+ common_idx = 0
369
+ max_idx = min(len(new_components), len(cur_components))
370
+ while (
371
+ common_idx < max_idx
372
+ and new_components[common_idx] == cur_components[common_idx]
373
+ ):
374
+ common_idx += 1
375
+ # current: a.b and parent: a.b.e.d -> from .e.d import <name>
376
+ # current: a.b.c.d and parent: a.b.e.f -> from ...e.f import <name>
377
+ return "." * (len(cur_components) - common_idx + init_module_count) + ".".join(
378
+ new_components[common_idx:]
379
+ )
380
+
381
+ def _get_module(self, alias, name):
382
+ debug.stubgen_exec("Analyzing module %s (aliased at %s)..." % (name, alias))
197
383
  self._current_module = importlib.import_module(name)
198
- self._current_module_name = name
384
+ self._current_module_name = alias
199
385
  for objname, obj in self._current_module.__dict__.items():
386
+ if objname == "_addl_stubgen_modules":
387
+ debug.stubgen_exec(
388
+ "Adding modules %s from _addl_stubgen_modules" % str(obj)
389
+ )
390
+ self._pending_modules.extend(
391
+ (self._get_module_name_alias(m), m) for m in obj
392
+ )
393
+ continue
200
394
  if objname.startswith("_"):
201
395
  debug.stubgen_exec(
202
396
  "Skipping object because it starts with _ %s" % objname
203
397
  )
204
398
  continue
205
399
  if inspect.ismodule(obj):
206
- # Only consider modules that are part of the root module
400
+ # Only consider modules that are safe modules
207
401
  if (
208
- obj.__name__.startswith(self._root_module)
402
+ any(obj.__name__.startswith(m) for m in self._safe_modules)
209
403
  and not obj.__name__ in self._exclude_modules
210
404
  ):
211
405
  debug.stubgen_exec(
212
406
  "Adding child module %s to process" % obj.__name__
213
407
  )
214
- self._pending_modules.append(obj.__name__)
408
+
409
+ new_module_alias = self._get_module_name_alias(obj.__name__)
410
+ self._pending_modules.append((new_module_alias, obj.__name__))
411
+
412
+ new_parent, new_name = new_module_alias.rsplit(".", 1)
413
+ self._current_references.append(
414
+ "from %s import %s as %s"
415
+ % (
416
+ self._get_relative_import(
417
+ new_parent,
418
+ alias,
419
+ hasattr(self._current_module, "__path__"),
420
+ ),
421
+ new_name,
422
+ objname,
423
+ )
424
+ )
215
425
  else:
216
426
  debug.stubgen_exec("Skipping child module %s" % obj.__name__)
217
427
  else:
@@ -221,8 +431,10 @@ class StubGenerator:
221
431
  # we could be more specific but good enough for now) for root module.
222
432
  # We also include the step decorator (it's from metaflow.decorators
223
433
  # which is typically excluded)
224
- # - otherwise, anything that is in safe_modules. Note this may include
225
- # a bit much (all the imports)
434
+ # - Stuff that is defined in this module itself
435
+ # - a reference to anything in the modules we will process later
436
+ # (so we don't duplicate a ton of times)
437
+
226
438
  if (
227
439
  parent_module is None
228
440
  or (
@@ -232,43 +444,44 @@ class StubGenerator:
232
444
  or obj == step
233
445
  )
234
446
  )
235
- or (
236
- not any(
237
- [
238
- parent_module.__name__.startswith(p)
239
- for p in self._exclude_modules
240
- ]
241
- )
242
- and any(
243
- [
244
- parent_module.__name__.startswith(p)
245
- for p in self._safe_modules
246
- ]
247
- )
248
- )
447
+ or parent_module.__name__ == name
249
448
  ):
250
449
  debug.stubgen_exec("Adding object %s to process" % objname)
251
450
  self._current_objects[objname] = obj
252
- else:
253
- debug.stubgen_exec("Skipping object %s" % objname)
254
- # We also include the module to process if it is part of root_module
255
- if (
256
- parent_module is not None
257
- and not any(
258
- [
259
- parent_module.__name__.startswith(d)
260
- for d in self._exclude_modules
261
- ]
262
- )
263
- and parent_module.__name__.startswith(self._root_module)
451
+
452
+ elif not any(
453
+ [
454
+ parent_module.__name__.startswith(p)
455
+ for p in self._exclude_modules
456
+ ]
457
+ ) and any(
458
+ [parent_module.__name__.startswith(p) for p in self._safe_modules]
264
459
  ):
460
+ parent_alias = self._get_module_name_alias(parent_module.__name__)
461
+
462
+ relative_import = self._get_relative_import(
463
+ parent_alias, alias, hasattr(self._current_module, "__path__")
464
+ )
465
+
265
466
  debug.stubgen_exec(
266
- "Adding module of child object %s to process"
267
- % parent_module.__name__,
467
+ "Adding reference %s and adding module %s as %s"
468
+ % (objname, parent_module.__name__, parent_alias)
469
+ )
470
+ obj_import_name = getattr(obj, "__name__", objname)
471
+ if obj_import_name == "<lambda>":
472
+ # We have one case of this
473
+ obj_import_name = objname
474
+ self._current_references.append(
475
+ "from %s import %s as %s"
476
+ % (relative_import, obj_import_name, objname)
268
477
  )
269
- self._pending_modules.append(parent_module.__name__)
478
+ self._pending_modules.append((parent_alias, parent_module.__name__))
479
+ else:
480
+ debug.stubgen_exec("Skipping object %s" % objname)
270
481
 
271
- def _get_element_name_with_module(self, element: Union[TypeVar, type, Any]) -> str:
482
+ def _get_element_name_with_module(
483
+ self, element: Union[TypeVar, type, Any], force_import=False
484
+ ) -> str:
272
485
  # The element can be a string, for example "def f() -> 'SameClass':..."
273
486
  def _add_to_import(name):
274
487
  if name != self._current_module_name:
@@ -292,6 +505,9 @@ class StubGenerator:
292
505
  self._typing_imports.add(splits[0])
293
506
 
294
507
  if isinstance(element, str):
508
+ # Special case for self referential things (particularly in a class)
509
+ if element == self._current_name:
510
+ return '"%s"' % element
295
511
  # We first try to eval the annotation because with the annotations future
296
512
  # it is always a string
297
513
  try:
@@ -309,6 +525,9 @@ class StubGenerator:
309
525
  pass
310
526
 
311
527
  if isinstance(element, str):
528
+ # If we are in our "safe" modules, make sure we alias properly
529
+ if any(element.startswith(x) for x in self._safe_modules):
530
+ element = self._get_module_name_alias(element)
312
531
  _add_to_typing_check(element)
313
532
  return '"%s"' % element
314
533
  # 3.10+ has NewType as a class but not before so hack around to check for NewType
@@ -328,9 +547,12 @@ class StubGenerator:
328
547
  return "None"
329
548
  return element.__name__
330
549
 
331
- _add_to_typing_check(module.__name__, is_module=True)
332
- if module.__name__ != self._current_module_name:
333
- return "{0}.{1}".format(module.__name__, element.__name__)
550
+ module_name = self._get_module_name_alias(module.__name__)
551
+ if force_import:
552
+ _add_to_import(module_name.split(".")[0])
553
+ _add_to_typing_check(module_name, is_module=True)
554
+ if module_name != self._current_module_name:
555
+ return "{0}.{1}".format(module_name, element.__name__)
334
556
  else:
335
557
  return element.__name__
336
558
  elif isinstance(element, type(Ellipsis)):
@@ -364,7 +586,7 @@ class StubGenerator:
364
586
  else:
365
587
  return "%s[%s]" % (element.__origin__, ", ".join(args_str))
366
588
  elif isinstance(element, ForwardRef):
367
- f_arg = element.__forward_arg__
589
+ f_arg = self._get_module_name_alias(element.__forward_arg__)
368
590
  # if f_arg in ("Run", "Task"): # HACK -- forward references in current.py
369
591
  # _add_to_import("metaflow")
370
592
  # f_arg = "metaflow.%s" % f_arg
@@ -377,9 +599,17 @@ class StubGenerator:
377
599
  return "typing.NamedTuple"
378
600
  return str(element)
379
601
  else:
380
- raise RuntimeError(
381
- "Does not handle element %s of type %s" % (str(element), type(element))
382
- )
602
+ if hasattr(element, "__module__"):
603
+ elem_module = self._get_module_name_alias(element.__module__)
604
+ if elem_module == "builtins":
605
+ return getattr(element, "__name__", str(element))
606
+ _add_to_typing_check(elem_module, is_module=True)
607
+ return "{0}.{1}".format(
608
+ elem_module, getattr(element, "__name__", element)
609
+ )
610
+ else:
611
+ # A constant
612
+ return str(element)
383
613
 
384
614
  def _exploit_annotation(self, annotation: Any, starting: str = ": ") -> str:
385
615
  annotation_string = ""
@@ -390,23 +620,34 @@ class StubGenerator:
390
620
  return annotation_string
391
621
 
392
622
  def _generate_class_stub(self, name: str, clazz: type) -> str:
623
+ debug.stubgen_exec("Generating class stub for %s" % name)
624
+ skip_init = issubclass(clazz, (TriggeredRun, DeployedFlow))
625
+ if issubclass(clazz, DeployerImpl):
626
+ if clazz.TYPE is not None:
627
+ clazz_type = clazz.TYPE.replace("-", "_")
628
+ self._deployer_injected_methods.setdefault(clazz_type, {})[
629
+ "deployer"
630
+ ] = (self._current_module_name + "." + name)
631
+
393
632
  buff = StringIO()
394
633
  # Class prototype
395
634
  buff.write("class " + name.split(".")[-1] + "(")
396
635
 
397
636
  # Add super classes
398
637
  for c in clazz.__bases__:
399
- name_with_module = self._get_element_name_with_module(c)
638
+ name_with_module = self._get_element_name_with_module(c, force_import=True)
400
639
  buff.write(name_with_module + ", ")
401
640
 
402
641
  # Add metaclass
403
- name_with_module = self._get_element_name_with_module(clazz.__class__)
642
+ name_with_module = self._get_element_name_with_module(
643
+ clazz.__class__, force_import=True
644
+ )
404
645
  buff.write("metaclass=" + name_with_module + "):\n")
405
646
 
406
647
  # Add class docstring
407
648
  if clazz.__doc__:
408
649
  buff.write('%s"""\n' % TAB)
409
- my_doc = cast(str, deindent_docstring(clazz.__doc__))
650
+ my_doc = inspect.cleandoc(clazz.__doc__)
410
651
  init_blank = True
411
652
  for line in my_doc.split("\n"):
412
653
  if init_blank and len(line.strip()) == 0:
@@ -429,6 +670,8 @@ class StubGenerator:
429
670
  func_deco = "@classmethod"
430
671
  element = element.__func__
431
672
  if key == "__init__":
673
+ if skip_init:
674
+ continue
432
675
  init_func = element
433
676
  elif key == "__annotations__":
434
677
  annotation_dict = element
@@ -436,11 +679,201 @@ class StubGenerator:
436
679
  if not element.__name__.startswith("_") or element.__name__.startswith(
437
680
  "__"
438
681
  ):
439
- buff.write(
440
- self._generate_function_stub(
441
- key, element, indentation=TAB, deco=func_deco
682
+ if (
683
+ clazz == Deployer
684
+ and element.__name__ in self._deployer_injected_methods
685
+ ):
686
+ # This is a method that was injected. It has docs but we need
687
+ # to parse it to generate the proper signature
688
+ func_doc = inspect.cleandoc(element.__doc__)
689
+ docs = split_docs(
690
+ func_doc,
691
+ [
692
+ ("func_doc", StartEnd(0, 0)),
693
+ (
694
+ "param_doc",
695
+ param_section_header.search(func_doc)
696
+ or StartEnd(len(func_doc), len(func_doc)),
697
+ ),
698
+ (
699
+ "return_doc",
700
+ return_section_header.search(func_doc)
701
+ or StartEnd(len(func_doc), len(func_doc)),
702
+ ),
703
+ ],
442
704
  )
443
- )
705
+
706
+ parameters, _ = parse_params_from_doc(docs["param_doc"])
707
+ return_type = self._deployer_injected_methods[element.__name__][
708
+ "deployer"
709
+ ]
710
+
711
+ buff.write(
712
+ self._generate_function_stub(
713
+ key,
714
+ element,
715
+ sign=[
716
+ inspect.Signature(
717
+ parameters=[
718
+ inspect.Parameter(
719
+ "self",
720
+ inspect.Parameter.POSITIONAL_OR_KEYWORD,
721
+ )
722
+ ]
723
+ + parameters,
724
+ return_annotation=return_type,
725
+ )
726
+ ],
727
+ indentation=TAB,
728
+ deco=func_deco,
729
+ )
730
+ )
731
+ elif (
732
+ clazz == DeployedFlow and element.__name__ == "from_deployment"
733
+ ):
734
+ # We simply update the signature to list the return
735
+ # type as a union of all possible deployers
736
+ func_doc = inspect.cleandoc(element.__doc__)
737
+ docs = split_docs(
738
+ func_doc,
739
+ [
740
+ ("func_doc", StartEnd(0, 0)),
741
+ (
742
+ "param_doc",
743
+ param_section_header.search(func_doc)
744
+ or StartEnd(len(func_doc), len(func_doc)),
745
+ ),
746
+ (
747
+ "return_doc",
748
+ return_section_header.search(func_doc)
749
+ or StartEnd(len(func_doc), len(func_doc)),
750
+ ),
751
+ ],
752
+ )
753
+
754
+ parameters, _ = parse_params_from_doc(docs["param_doc"])
755
+
756
+ def _create_multi_type(*l):
757
+ return typing.Union[l]
758
+
759
+ all_types = [
760
+ v["from_deployment"][0]
761
+ for v in self._deployer_injected_methods.values()
762
+ ]
763
+
764
+ if len(all_types) > 1:
765
+ return_type = _create_multi_type(*all_types)
766
+ else:
767
+ return_type = all_types[0] if len(all_types) else None
768
+
769
+ buff.write(
770
+ self._generate_function_stub(
771
+ key,
772
+ element,
773
+ sign=[
774
+ inspect.Signature(
775
+ parameters=[
776
+ inspect.Parameter(
777
+ "cls",
778
+ inspect.Parameter.POSITIONAL_OR_KEYWORD,
779
+ )
780
+ ]
781
+ + parameters,
782
+ return_annotation=return_type,
783
+ )
784
+ ],
785
+ indentation=TAB,
786
+ doc=docs["func_doc"]
787
+ + "\n\nParameters\n----------\n"
788
+ + docs["param_doc"]
789
+ + "\n\nReturns\n-------\n"
790
+ + "%s\nA `DeployedFlow` object" % str(return_type),
791
+ deco=func_deco,
792
+ )
793
+ )
794
+ elif (
795
+ clazz == DeployedFlow
796
+ and element.__name__.startswith("from_")
797
+ and element.__name__[5:] in self._deployer_injected_methods
798
+ ):
799
+ # Get the doc from the from_deployment method stored in
800
+ # self._deployer_injected_methods
801
+ func_doc = inspect.cleandoc(
802
+ self._deployer_injected_methods[element.__name__[5:]][
803
+ "from_deployment"
804
+ ][1]
805
+ or ""
806
+ )
807
+ docs = split_docs(
808
+ func_doc,
809
+ [
810
+ ("func_doc", StartEnd(0, 0)),
811
+ (
812
+ "param_doc",
813
+ param_section_header.search(func_doc)
814
+ or StartEnd(len(func_doc), len(func_doc)),
815
+ ),
816
+ (
817
+ "return_doc",
818
+ return_section_header.search(func_doc)
819
+ or StartEnd(len(func_doc), len(func_doc)),
820
+ ),
821
+ ],
822
+ )
823
+
824
+ parameters, _ = parse_params_from_doc(docs["param_doc"])
825
+ return_type = self._deployer_injected_methods[
826
+ element.__name__[5:]
827
+ ]["from_deployment"][0]
828
+
829
+ buff.write(
830
+ self._generate_function_stub(
831
+ key,
832
+ element,
833
+ sign=[
834
+ inspect.Signature(
835
+ parameters=[
836
+ inspect.Parameter(
837
+ "cls",
838
+ inspect.Parameter.POSITIONAL_OR_KEYWORD,
839
+ )
840
+ ]
841
+ + parameters,
842
+ return_annotation=return_type,
843
+ )
844
+ ],
845
+ indentation=TAB,
846
+ doc=docs["func_doc"]
847
+ + "\n\nParameters\n----------\n"
848
+ + docs["param_doc"]
849
+ + "\n\nReturns\n-------\n"
850
+ + docs["return_doc"],
851
+ deco=func_deco,
852
+ )
853
+ )
854
+ else:
855
+ if (
856
+ issubclass(clazz, DeployedFlow)
857
+ and clazz.TYPE is not None
858
+ and key == "from_deployment"
859
+ ):
860
+ clazz_type = clazz.TYPE.replace("-", "_")
861
+ # Record docstring for this function
862
+ self._deployer_injected_methods.setdefault(clazz_type, {})[
863
+ "from_deployment"
864
+ ] = (
865
+ self._current_module_name + "." + name,
866
+ element.__doc__,
867
+ )
868
+ buff.write(
869
+ self._generate_function_stub(
870
+ key,
871
+ element,
872
+ indentation=TAB,
873
+ deco=func_deco,
874
+ )
875
+ )
876
+
444
877
  elif isinstance(element, property):
445
878
  if element.fget:
446
879
  buff.write(
@@ -455,20 +888,17 @@ class StubGenerator:
455
888
  )
456
889
  )
457
890
 
458
- # Special handling for the current module
459
- if (
460
- self._current_module_name == METAFLOW_CURRENT_MODULE_NAME
461
- and name == "Current"
462
- ):
891
+ # Special handling of classes that have injected methods
892
+ if clazz == Current:
463
893
  # Multiple decorators can add the same object (trigger and trigger_on_finish)
464
894
  # as examples so we sort it out.
465
895
  resulting_dict = (
466
896
  dict()
467
897
  ) # type Dict[str, List[inspect.Signature, str, List[str]]]
468
- for project_name, addl_current in self._addl_current.items():
898
+ for deco_name, addl_current in self._addl_current.items():
469
899
  for name, (sign, doc) in addl_current.items():
470
900
  r = resulting_dict.setdefault(name, [sign, doc, []])
471
- r[2].append("@%s" % project_name)
901
+ r[2].append("@%s" % deco_name)
472
902
  for name, (sign, doc, decos) in resulting_dict.items():
473
903
  buff.write(
474
904
  self._generate_function_stub(
@@ -481,7 +911,8 @@ class StubGenerator:
481
911
  deco="@property",
482
912
  )
483
913
  )
484
- if init_func is None and annotation_dict:
914
+
915
+ if not skip_init and init_func is None and annotation_dict:
485
916
  buff.write(
486
917
  self._generate_function_stub(
487
918
  "__init__",
@@ -527,121 +958,31 @@ class StubGenerator:
527
958
  self._typevars["StepFlag"] = StepFlag
528
959
 
529
960
  raw_doc = inspect.cleandoc(raw_doc)
530
- has_parameters = param_section_header.search(raw_doc)
531
- has_add_to_current = add_to_current_header.search(raw_doc)
532
-
533
- if has_parameters and has_add_to_current:
534
- doc = raw_doc[has_parameters.end() : has_add_to_current.start()]
535
- add_to_current_doc = raw_doc[has_add_to_current.end() :]
536
- raw_doc = raw_doc[: has_add_to_current.start()]
537
- elif has_parameters:
538
- doc = raw_doc[has_parameters.end() :]
539
- add_to_current_doc = None
540
- elif has_add_to_current:
541
- add_to_current_doc = raw_doc[has_add_to_current.end() :]
542
- raw_doc = raw_doc[: has_add_to_current.start()]
543
- doc = ""
544
- else:
545
- doc = ""
546
- add_to_current_doc = None
547
- parameters = []
548
- no_arg_version = True
549
- for line in doc.splitlines():
550
- if non_indented_line.match(line):
551
- match = param_name_type.match(line)
552
- arg_name = type_name = is_optional = default = None
553
- default_set = False
554
- if match is not None:
555
- arg_name = match.group("name")
556
- type_name = match.group("type")
557
- if type_name is not None:
558
- type_detail = type_annotations.match(type_name)
559
- if type_detail is not None:
560
- type_name = type_detail.group("type")
561
- is_optional = type_detail.group("optional") is not None
562
- default = type_detail.group("default")
563
- if default:
564
- default_set = True
565
- try:
566
- default = eval(default)
567
- except:
568
- pass
569
- try:
570
- type_name = eval(type_name)
571
- except:
572
- pass
573
- parameters.append(
574
- inspect.Parameter(
575
- name=arg_name,
576
- kind=inspect.Parameter.KEYWORD_ONLY,
577
- default=(
578
- default
579
- if default_set
580
- else None if is_optional else inspect.Parameter.empty
581
- ),
582
- annotation=(
583
- Optional[type_name] if is_optional else type_name
584
- ),
585
- )
586
- )
587
- if not default_set:
588
- # If we don't have a default set for any parameter, we can't
589
- # have a no-arg version since the decorator would be incomplete
590
- no_arg_version = False
591
- if add_to_current_doc:
592
- current_property = None
593
- current_return_type = None
594
- current_property_indent = None
595
- current_doc = []
596
- add_to_current = dict() # type: Dict[str, Tuple[inspect.Signature, str]]
597
-
598
- def _add():
599
- if current_property:
600
- add_to_current[current_property] = (
601
- inspect.Signature(
602
- [
603
- inspect.Parameter(
604
- "self", inspect.Parameter.POSITIONAL_OR_KEYWORD
605
- )
606
- ],
607
- return_annotation=current_return_type,
608
- ),
609
- "\n".join(current_doc),
610
- )
611
-
612
- for line in add_to_current_doc.splitlines():
613
- # Parse stanzas that look like the following:
614
- # <property-name> -> type
615
- # indented doc string
616
- if current_property_indent is not None and (
617
- line.startswith(current_property_indent + " ") or line.strip() == ""
618
- ):
619
- offset = len(current_property_indent)
620
- if line.lstrip().startswith("@@ "):
621
- line = line.replace("@@ ", "")
622
- current_doc.append(line[offset:].rstrip())
623
- else:
624
- if line.strip() == 0:
625
- continue
626
- if current_property:
627
- # Ends a property stanza
628
- _add()
629
- # Now start a new one
630
- line = line.rstrip()
631
- current_property_indent = line[: len(line) - len(line.lstrip())]
632
- # This is a line so we split it using "->"
633
- current_property, current_return_type = line.split("->")
634
- current_property = current_property.strip()
635
- current_return_type = current_return_type.strip()
636
- current_doc = []
637
- _add()
638
-
639
- self._addl_current[name] = add_to_current
961
+ section_boundaries = [
962
+ ("func_doc", StartEnd(0, 0)),
963
+ (
964
+ "param_doc",
965
+ param_section_header.search(raw_doc)
966
+ or StartEnd(len(raw_doc), len(raw_doc)),
967
+ ),
968
+ (
969
+ "add_to_current_doc",
970
+ add_to_current_header.search(raw_doc)
971
+ or StartEnd(len(raw_doc), len(raw_doc)),
972
+ ),
973
+ ]
974
+
975
+ docs = split_docs(raw_doc, section_boundaries)
976
+
977
+ parameters, no_arg_version = parse_params_from_doc(docs["param_doc"])
978
+
979
+ if docs["add_to_current_doc"]:
980
+ self._addl_current[name] = parse_add_to_docs(docs["add_to_current_doc"])
640
981
 
641
982
  result = []
642
983
  if no_arg_version:
643
984
  if is_flow_decorator:
644
- if has_parameters:
985
+ if docs["param_doc"]:
645
986
  result.append(
646
987
  (
647
988
  inspect.Signature(
@@ -670,7 +1011,7 @@ class StubGenerator:
670
1011
  ),
671
1012
  )
672
1013
  else:
673
- if has_parameters:
1014
+ if docs["param_doc"]:
674
1015
  result.append(
675
1016
  (
676
1017
  inspect.Signature(
@@ -792,8 +1133,8 @@ class StubGenerator:
792
1133
  result = result[1:]
793
1134
  # Add doc to first and last overloads. Jedi uses the last one and pycharm
794
1135
  # the first one. Go figure.
795
- result[0] = (result[0][0], raw_doc)
796
- result[-1] = (result[-1][0], raw_doc)
1136
+ result[0] = (result[0][0], docs["func_doc"])
1137
+ result[-1] = (result[-1][0], docs["func_doc"])
797
1138
  return result
798
1139
 
799
1140
  def _generate_function_stub(
@@ -805,11 +1146,12 @@ class StubGenerator:
805
1146
  doc: Optional[str] = None,
806
1147
  deco: Optional[str] = None,
807
1148
  ) -> str:
1149
+ debug.stubgen_exec("Generating function stub for %s" % name)
1150
+
808
1151
  def exploit_default(default_value: Any) -> Optional[str]:
809
- if (
810
- default_value != inspect.Parameter.empty
811
- and type(default_value).__module__ == "builtins"
812
- ):
1152
+ if default_value == inspect.Parameter.empty:
1153
+ return None
1154
+ if type(default_value).__module__ == "builtins":
813
1155
  if isinstance(default_value, list):
814
1156
  return (
815
1157
  "["
@@ -839,22 +1181,23 @@ class StubGenerator:
839
1181
  )
840
1182
  + "}"
841
1183
  )
842
- elif str(default_value).startswith("<"):
843
- if default_value.__module__ == "builtins":
844
- return default_value.__name__
845
- else:
846
- self._typing_imports.add(default_value.__module__)
847
- return ".".join(
848
- [default_value.__module__, default_value.__name__]
849
- )
1184
+ elif isinstance(default_value, str):
1185
+ return "'" + default_value + "'"
850
1186
  else:
851
- return (
852
- str(default_value)
853
- if not isinstance(default_value, str)
854
- else '"' + default_value + '"'
855
- )
1187
+ return self._get_element_name_with_module(default_value)
1188
+
1189
+ elif str(default_value).startswith("<"):
1190
+ if default_value.__module__ == "builtins":
1191
+ return default_value.__name__
1192
+ else:
1193
+ self._typing_imports.add(default_value.__module__)
1194
+ return ".".join([default_value.__module__, default_value.__name__])
856
1195
  else:
857
- return None
1196
+ return (
1197
+ str(default_value)
1198
+ if not isinstance(default_value, str)
1199
+ else '"' + default_value + '"'
1200
+ )
858
1201
 
859
1202
  buff = StringIO()
860
1203
  if sign is None and func is None:
@@ -870,6 +1213,10 @@ class StubGenerator:
870
1213
  # value
871
1214
  return ""
872
1215
  doc = doc or func.__doc__
1216
+ if doc == "STUBGEN_IGNORE":
1217
+ # Ignore methods that have STUBGEN_IGNORE. Used to ignore certain
1218
+ # methods for the Deployer
1219
+ return ""
873
1220
  indentation = indentation or ""
874
1221
 
875
1222
  # Deal with overload annotations -- the last one will be non overloaded and
@@ -883,6 +1230,9 @@ class StubGenerator:
883
1230
  buff.write("\n")
884
1231
 
885
1232
  if do_overload and count < len(sign) - 1:
1233
+ # According to mypy, we should have this on all variants but
1234
+ # some IDEs seem to prefer if there is one non-overloaded
1235
+ # This also changes our checks so if changing, modify tests
886
1236
  buff.write(indentation + "@typing.overload\n")
887
1237
  if deco:
888
1238
  buff.write(indentation + deco + "\n")
@@ -890,6 +1240,7 @@ class StubGenerator:
890
1240
  kw_only_param = False
891
1241
  for i, (par_name, parameter) in enumerate(my_sign.parameters.items()):
892
1242
  annotation = self._exploit_annotation(parameter.annotation)
1243
+
893
1244
  default = exploit_default(parameter.default)
894
1245
 
895
1246
  if kw_only_param and parameter.kind != inspect.Parameter.KEYWORD_ONLY:
@@ -922,7 +1273,7 @@ class StubGenerator:
922
1273
 
923
1274
  if (count == 0 or count == len(sign) - 1) and doc is not None:
924
1275
  buff.write('%s%s"""\n' % (indentation, TAB))
925
- my_doc = cast(str, deindent_docstring(doc))
1276
+ my_doc = inspect.cleandoc(doc)
926
1277
  init_blank = True
927
1278
  for line in my_doc.split("\n"):
928
1279
  if init_blank and len(line.strip()) == 0:
@@ -941,6 +1292,7 @@ class StubGenerator:
941
1292
  def _generate_stubs(self):
942
1293
  for name, attr in self._current_objects.items():
943
1294
  self._current_parent_module = inspect.getmodule(attr)
1295
+ self._current_name = name
944
1296
  if inspect.isclass(attr):
945
1297
  self._stubs.append(self._generate_class_stub(name, attr))
946
1298
  elif inspect.isfunction(attr):
@@ -1023,6 +1375,29 @@ class StubGenerator:
1023
1375
  elif not inspect.ismodule(attr):
1024
1376
  self._stubs.append(self._generate_generic_stub(name, attr))
1025
1377
 
1378
+ def _write_header(self, f, width):
1379
+ title_line = "Auto-generated Metaflow stub file"
1380
+ title_white_space = (width - len(title_line)) / 2
1381
+ title_line = "#%s%s%s#\n" % (
1382
+ " " * math.floor(title_white_space),
1383
+ title_line,
1384
+ " " * math.ceil(title_white_space),
1385
+ )
1386
+ f.write(
1387
+ "#" * (width + 2)
1388
+ + "\n"
1389
+ + title_line
1390
+ + "# MF version: %s%s#\n"
1391
+ % (self._mf_version, " " * (width - 13 - len(self._mf_version)))
1392
+ + "# Generated on %s%s#\n"
1393
+ % (
1394
+ datetime.fromtimestamp(time.time()).isoformat(),
1395
+ " " * (width - 14 - 26),
1396
+ )
1397
+ + "#" * (width + 2)
1398
+ + "\n\n"
1399
+ )
1400
+
1026
1401
  def write_out(self):
1027
1402
  out_dir = self._output_dir
1028
1403
  os.makedirs(out_dir, exist_ok=True)
@@ -1036,66 +1411,75 @@ class StubGenerator:
1036
1411
  "%s %s"
1037
1412
  % (self._mf_version, datetime.fromtimestamp(time.time()).isoformat())
1038
1413
  )
1039
- while len(self._pending_modules) != 0:
1040
- module_name = self._pending_modules.pop(0)
1414
+ post_process_modules = []
1415
+ is_post_processing = False
1416
+ while len(self._pending_modules) != 0 or len(post_process_modules) != 0:
1417
+ if is_post_processing or len(self._pending_modules) == 0:
1418
+ is_post_processing = True
1419
+ module_alias, module_name = post_process_modules.pop(0)
1420
+ else:
1421
+ module_alias, module_name = self._pending_modules.pop(0)
1041
1422
  # Skip vendored stuff
1042
- if module_name.startswith("metaflow._vendor"):
1423
+ if module_alias.startswith("metaflow._vendor") or module_name.startswith(
1424
+ "metaflow._vendor"
1425
+ ):
1043
1426
  continue
1044
- # We delay current module
1427
+ # We delay current module and deployer module to the end since they
1428
+ # depend on info we gather elsewhere
1045
1429
  if (
1046
- module_name == METAFLOW_CURRENT_MODULE_NAME
1047
- and len(set(self._pending_modules)) > 1
1430
+ module_alias
1431
+ in (
1432
+ METAFLOW_CURRENT_MODULE_NAME,
1433
+ METAFLOW_DEPLOYER_MODULE_NAME,
1434
+ )
1435
+ and len(self._pending_modules) != 0
1048
1436
  ):
1049
- self._pending_modules.append(module_name)
1437
+ post_process_modules.append((module_alias, module_name))
1050
1438
  continue
1051
- if module_name in self._done_modules:
1439
+ if module_alias in self._done_modules:
1052
1440
  continue
1053
- self._done_modules.add(module_name)
1441
+ self._done_modules.add(module_alias)
1054
1442
  # If not, we process the module
1055
1443
  self._reset()
1056
- self._get_module(module_name)
1444
+ self._get_module(module_alias, module_name)
1445
+ if module_name == "metaflow" and not is_post_processing:
1446
+ # We will want to regenerate this at the end to take into account
1447
+ # any changes to the Deployer
1448
+ post_process_modules.append((module_name, module_name))
1449
+ self._done_modules.remove(module_name)
1450
+ continue
1057
1451
  self._generate_stubs()
1058
1452
 
1059
1453
  if hasattr(self._current_module, "__path__"):
1060
1454
  # This is a package (so a directory) and we are dealing with
1061
1455
  # a __init__.pyi type of case
1062
- dir_path = os.path.join(
1063
- self._output_dir, *self._current_module.__name__.split(".")[1:]
1064
- )
1456
+ dir_path = os.path.join(self._output_dir, *module_alias.split(".")[1:])
1065
1457
  else:
1066
1458
  # This is NOT a package so the original source file is not a __init__.py
1067
1459
  dir_path = os.path.join(
1068
- self._output_dir, *self._current_module.__name__.split(".")[1:-1]
1460
+ self._output_dir, *module_alias.split(".")[1:-1]
1069
1461
  )
1070
1462
  out_file = os.path.join(
1071
1463
  dir_path, os.path.basename(self._current_module.__file__) + "i"
1072
1464
  )
1073
1465
 
1466
+ width = 100
1467
+
1074
1468
  os.makedirs(os.path.dirname(out_file), exist_ok=True)
1469
+ # We want to make sure we always have a __init__.pyi in the directories
1470
+ # we are creating
1471
+ parts = dir_path.split(os.sep)[len(self._output_dir.split(os.sep)) :]
1472
+ for i in range(1, len(parts) + 1):
1473
+ init_file_path = os.path.join(
1474
+ self._output_dir, *parts[:i], "__init__.pyi"
1475
+ )
1476
+ if not os.path.exists(init_file_path):
1477
+ with open(init_file_path, mode="w", encoding="utf-8") as f:
1478
+ self._write_header(f, width)
1075
1479
 
1076
- width = 80
1077
- title_line = "Auto-generated Metaflow stub file"
1078
- title_white_space = (width - len(title_line)) / 2
1079
- title_line = "#%s%s%s#\n" % (
1080
- " " * math.floor(title_white_space),
1081
- title_line,
1082
- " " * math.ceil(title_white_space),
1083
- )
1084
1480
  with open(out_file, mode="w", encoding="utf-8") as f:
1085
- f.write(
1086
- "#" * (width + 2)
1087
- + "\n"
1088
- + title_line
1089
- + "# MF version: %s%s#\n"
1090
- % (self._mf_version, " " * (width - 13 - len(self._mf_version)))
1091
- + "# Generated on %s%s#\n"
1092
- % (
1093
- datetime.fromtimestamp(time.time()).isoformat(),
1094
- " " * (width - 14 - 26),
1095
- )
1096
- + "#" * (width + 2)
1097
- + "\n\n"
1098
- )
1481
+ self._write_header(f, width)
1482
+
1099
1483
  f.write("from __future__ import annotations\n\n")
1100
1484
  imported_typing = False
1101
1485
  for module in self._imports:
@@ -1123,8 +1507,14 @@ class StubGenerator:
1123
1507
  "%s = %s\n" % (type_name, new_type_to_str(type_var))
1124
1508
  )
1125
1509
  f.write("\n")
1510
+ for import_line in self._current_references:
1511
+ f.write(import_line + "\n")
1512
+ f.write("\n")
1126
1513
  for stub in self._stubs:
1127
1514
  f.write(stub + "\n")
1515
+ if is_post_processing:
1516
+ # Don't consider any pending modules if we are post processing
1517
+ self._pending_modules.clear()
1128
1518
 
1129
1519
 
1130
1520
  if __name__ == "__main__":