micropython-stubber 1.24.1__py3-none-any.whl → 1.24.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (90) hide show
  1. {micropython_stubber-1.24.1.dist-info → micropython_stubber-1.24.4.dist-info}/METADATA +9 -29
  2. micropython_stubber-1.24.4.dist-info/RECORD +107 -0
  3. {micropython_stubber-1.24.1.dist-info → micropython_stubber-1.24.4.dist-info}/WHEEL +1 -1
  4. stubber/__init__.py +1 -1
  5. stubber/board/createstubs.py +44 -38
  6. stubber/board/createstubs_db.py +17 -12
  7. stubber/board/createstubs_db_min.py +63 -63
  8. stubber/board/createstubs_db_mpy.mpy +0 -0
  9. stubber/board/createstubs_mem.py +17 -12
  10. stubber/board/createstubs_mem_min.py +99 -99
  11. stubber/board/createstubs_mem_mpy.mpy +0 -0
  12. stubber/board/createstubs_min.py +111 -112
  13. stubber/board/createstubs_mpy.mpy +0 -0
  14. stubber/board/modulelist.txt +27 -27
  15. stubber/codemod/board.py +1 -1
  16. stubber/codemod/enrich.py +13 -13
  17. stubber/codemod/merge_docstub.py +83 -53
  18. stubber/codemod/visitors/type_helpers.py +143 -41
  19. stubber/commands/enrich_folder_cmd.py +17 -17
  20. stubber/commands/get_docstubs_cmd.py +27 -9
  21. stubber/commands/get_frozen_cmd.py +1 -0
  22. stubber/commands/merge_cmd.py +2 -4
  23. stubber/merge_config.py +5 -36
  24. stubber/minify.py +3 -3
  25. stubber/modcat.py +118 -0
  26. stubber/publish/merge_docstubs.py +22 -5
  27. stubber/publish/stubpackage.py +33 -28
  28. stubber/rst/lookup.py +6 -23
  29. stubber/rst/reader.py +8 -13
  30. stubber/stubs_from_docs.py +2 -1
  31. stubber/tools/manifestfile.py +2 -1
  32. stubber/{cst_transformer.py → typing_collector.py} +36 -4
  33. micropython_stubber-1.24.1.dist-info/RECORD +0 -161
  34. mpflash/README.md +0 -220
  35. mpflash/libusb_flash.ipynb +0 -203
  36. mpflash/mpflash/__init__.py +0 -0
  37. mpflash/mpflash/add_firmware.py +0 -98
  38. mpflash/mpflash/ask_input.py +0 -236
  39. mpflash/mpflash/basicgit.py +0 -324
  40. mpflash/mpflash/bootloader/__init__.py +0 -2
  41. mpflash/mpflash/bootloader/activate.py +0 -60
  42. mpflash/mpflash/bootloader/detect.py +0 -82
  43. mpflash/mpflash/bootloader/manual.py +0 -101
  44. mpflash/mpflash/bootloader/micropython.py +0 -12
  45. mpflash/mpflash/bootloader/touch1200.py +0 -36
  46. mpflash/mpflash/cli_download.py +0 -129
  47. mpflash/mpflash/cli_flash.py +0 -224
  48. mpflash/mpflash/cli_group.py +0 -111
  49. mpflash/mpflash/cli_list.py +0 -87
  50. mpflash/mpflash/cli_main.py +0 -39
  51. mpflash/mpflash/common.py +0 -217
  52. mpflash/mpflash/config.py +0 -44
  53. mpflash/mpflash/connected.py +0 -96
  54. mpflash/mpflash/download.py +0 -364
  55. mpflash/mpflash/downloaded.py +0 -138
  56. mpflash/mpflash/errors.py +0 -9
  57. mpflash/mpflash/flash/__init__.py +0 -55
  58. mpflash/mpflash/flash/esp.py +0 -59
  59. mpflash/mpflash/flash/stm32.py +0 -19
  60. mpflash/mpflash/flash/stm32_dfu.py +0 -104
  61. mpflash/mpflash/flash/uf2/__init__.py +0 -88
  62. mpflash/mpflash/flash/uf2/boardid.py +0 -15
  63. mpflash/mpflash/flash/uf2/linux.py +0 -136
  64. mpflash/mpflash/flash/uf2/macos.py +0 -42
  65. mpflash/mpflash/flash/uf2/uf2disk.py +0 -12
  66. mpflash/mpflash/flash/uf2/windows.py +0 -43
  67. mpflash/mpflash/flash/worklist.py +0 -170
  68. mpflash/mpflash/list.py +0 -106
  69. mpflash/mpflash/logger.py +0 -41
  70. mpflash/mpflash/mpboard_id/__init__.py +0 -98
  71. mpflash/mpflash/mpboard_id/add_boards.py +0 -262
  72. mpflash/mpflash/mpboard_id/board.py +0 -37
  73. mpflash/mpflash/mpboard_id/board_id.py +0 -90
  74. mpflash/mpflash/mpboard_id/board_info.zip +0 -0
  75. mpflash/mpflash/mpboard_id/store.py +0 -48
  76. mpflash/mpflash/mpremoteboard/__init__.py +0 -271
  77. mpflash/mpflash/mpremoteboard/mpy_fw_info.py +0 -152
  78. mpflash/mpflash/mpremoteboard/runner.py +0 -140
  79. mpflash/mpflash/vendor/board_database.py +0 -185
  80. mpflash/mpflash/vendor/click_aliases.py +0 -91
  81. mpflash/mpflash/vendor/dfu.py +0 -165
  82. mpflash/mpflash/vendor/pydfu.py +0 -605
  83. mpflash/mpflash/vendor/readme.md +0 -12
  84. mpflash/mpflash/versions.py +0 -123
  85. mpflash/poetry.lock +0 -2603
  86. mpflash/pyproject.toml +0 -66
  87. mpflash/stm32_udev_rules.md +0 -63
  88. stubber/codemod/test_enrich.py +0 -87
  89. {micropython_stubber-1.24.1.dist-info → micropython_stubber-1.24.4.dist-info}/LICENSE +0 -0
  90. {micropython_stubber-1.24.1.dist-info → micropython_stubber-1.24.4.dist-info}/entry_points.txt +0 -0
@@ -1,24 +1,34 @@
1
1
  # sourcery skip: snake-case-functions
2
- """Merge documentation and type information from from the docstubs into a board stub"""
3
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ """
3
+ Merge documentation and type information
4
+ - from an doctring-rich and typed stub module
5
+ - infor a less well documented and typed stub module
6
+ """
7
+ # Copyright Jos Verlinde
4
8
  #
5
9
  # This source code is licensed under the MIT license found in the
6
10
  # LICENSE file in the root directory of this source tree.
7
11
  #
12
+
8
13
  import argparse
9
14
  from pathlib import Path
10
15
  from typing import Dict, List, Optional, Tuple, TypeVar, Union, cast
11
16
 
12
17
  import libcst as cst
18
+ import libcst.matchers as m
13
19
  from libcst.codemod import CodemodContext, VisitorBasedCodemodCommand
14
20
  from libcst.codemod.visitors import AddImportsVisitor, GatherImportsVisitor, ImportItem
15
21
  from libcst.helpers.module import insert_header_comments
16
22
 
17
23
  from mpflash.logger import log
18
- from stubber.cst_transformer import (
24
+ from stubber.typing_collector import (
19
25
  MODULE_KEY,
20
26
  AnnoValue,
21
27
  StubTypingCollector,
28
+ is_deleter,
29
+ is_getter,
30
+ is_property,
31
+ is_setter,
22
32
  update_def_docstr,
23
33
  update_module_docstr,
24
34
  )
@@ -31,6 +41,12 @@ Mod_Class_T = TypeVar("Mod_Class_T", cst.Module, cst.ClassDef)
31
41
  # # log = logging.getLogger(__name__)
32
42
  #########################################################################################
33
43
  empty_module = cst.parse_module("") # Debugging aid : empty_module.code_for_node(node)
44
+ _code = empty_module.code_for_node
45
+
46
+
47
+ def is_decorator(dec: cst.CSTNode, name: str) -> bool:
48
+ """shorthand to determin if something is a specific decorator"""
49
+ return m.matches(dec, m.Decorator(decorator=m.Name(value=name)))
34
50
 
35
51
 
36
52
  class MergeCommand(VisitorBasedCodemodCommand):
@@ -48,6 +64,8 @@ class MergeCommand(VisitorBasedCodemodCommand):
48
64
  """
49
65
 
50
66
  DESCRIPTION: str = "Merge the type-rich information from a doc-stub into a firmware stub"
67
+ copy_params: bool = True
68
+ copy_docstr: bool = True
51
69
 
52
70
  @staticmethod
53
71
  def add_args(arg_parser: argparse.ArgumentParser) -> None:
@@ -63,16 +81,23 @@ class MergeCommand(VisitorBasedCodemodCommand):
63
81
  )
64
82
 
65
83
  arg_parser.add_argument(
66
- "--params-only",
67
- dest="params_only",
84
+ "--copy-params",
85
+ dest="copy_params",
68
86
  default=False,
69
87
  )
70
88
 
89
+ arg_parser.add_argument(
90
+ "--copy-docstr",
91
+ dest="copy_docstr",
92
+ default=True,
93
+ )
94
+
71
95
  def __init__(
72
96
  self,
73
97
  context: CodemodContext,
74
98
  docstub_file: Union[Path, str],
75
- params_only: bool = False,
99
+ copy_params: bool = False,
100
+ copy_docstr: bool = True,
76
101
  ) -> None:
77
102
  """initialize the base class with context, and save our args."""
78
103
  super().__init__(context)
@@ -91,11 +116,12 @@ class MergeCommand(VisitorBasedCodemodCommand):
91
116
  ] = {}
92
117
  self.comments: List[str] = []
93
118
 
94
- self.params_only = params_only
119
+ self.copy_params = copy_params
120
+ self.copy_docstr = copy_docstr
95
121
 
96
122
  self.stub_imports: Dict[str, ImportItem] = {}
97
123
  self.all_imports: List[Union[cst.Import, cst.ImportFrom]] = []
98
- self.type_helpers = []
124
+ self.type_helpers = {}
99
125
  # parse the doc-stub file
100
126
  if self.docstub_source:
101
127
  try:
@@ -119,6 +145,7 @@ class MergeCommand(VisitorBasedCodemodCommand):
119
145
  # Get typevars, type aliasses and ParamSpecs
120
146
  stub_tree.visit(typevar_collector)
121
147
  self.type_helpers = typevar_collector.all_typehelpers
148
+ pass
122
149
 
123
150
  # ------------------------------------------------------------------------
124
151
 
@@ -142,9 +169,7 @@ class MergeCommand(VisitorBasedCodemodCommand):
142
169
  if import_item.module_name == self.context.full_module_name:
143
170
  # this is an import from the same module we should NOT add it
144
171
  continue
145
- if import_item.module_name.split(".")[
146
- 0
147
- ] == self.context.full_module_name and not self.context.filename.endswith(
172
+ if import_item.module_name.split(".")[0] == self.context.full_module_name and not self.context.filename.endswith(
148
173
  "__init__.pyi"
149
174
  ):
150
175
  # this is an import from a module child module we should NOT add it
@@ -178,24 +203,21 @@ class MergeCommand(VisitorBasedCodemodCommand):
178
203
  # --------------------------------------------------------------------
179
204
  # Add any typevars to the module
180
205
  if self.type_helpers:
181
- for tv in self.type_helpers:
182
- AddTypeHelpers.add_typevar(self.context, tv) # type: ignore
183
-
206
+ AddTypeHelpers.add_helpers(self.context, self.type_helpers)
184
207
  atv = AddTypeHelpers(self.context)
185
208
  updated_node = atv.transform_module(updated_node)
186
209
 
187
210
  # --------------------------------------------------------------------
188
211
  # update the docstring.
189
212
  if MODULE_KEY in self.annotations:
190
-
191
213
  # update/replace module docstrings
192
214
  # todo: or should we add / merge the docstrings?
193
215
  docstub_docstr = self.annotations[MODULE_KEY].docstring
194
216
  assert isinstance(docstub_docstr, str)
195
217
  src_docstr = original_node.get_docstring() or ""
196
218
  if src_docstr or docstub_docstr:
197
- if not self.params_only and (docstub_docstr.strip() != src_docstr.strip()):
198
- if src_docstr:
219
+ if self.copy_docstr and (docstub_docstr.strip() != src_docstr.strip()):
220
+ if src_docstr and self.copy_docstr:
199
221
  log.trace(f"Append module docstrings. (new --- old) ")
200
222
  new_docstr = '"""\n' + docstub_docstr + "\n\n---\n" + src_docstr + '\n"""'
201
223
  else:
@@ -251,7 +273,7 @@ class MergeCommand(VisitorBasedCodemodCommand):
251
273
  matched, i = self.locate_function_by_name(overload, updated_body)
252
274
  if matched:
253
275
  log.trace(f"Add @overload for {overload.name.value}")
254
- if self.params_only:
276
+ if self.copy_params:
255
277
  docstring_node = self.annotations[key].docstring_node or ""
256
278
  # Use the new overload - but with the existing docstring
257
279
  overload = update_def_docstr(overload, docstring_node)
@@ -264,13 +286,9 @@ class MergeCommand(VisitorBasedCodemodCommand):
264
286
  if class_name not in new_classes:
265
287
  new_classes.append(class_name)
266
288
  # create a class for it, and then add all the overload methods to that class
267
- log.trace(
268
- f"Add New class @overload for {overload.name.value} at the end of the module"
269
- )
289
+ log.trace(f"Add New class @overload for {overload.name.value} at the end of the module")
270
290
  # create a list of all overloads for this class
271
- class_overloads = [
272
- overload for overload, k in missing_overloads if k[0] == class_name
273
- ]
291
+ class_overloads = [overload for overload, k in missing_overloads if k[0] == class_name]
274
292
  class_def = cst.ClassDef(
275
293
  name=cst.Name(value=class_name),
276
294
  body=cst.IndentedBlock(body=class_overloads),
@@ -280,9 +298,7 @@ class MergeCommand(VisitorBasedCodemodCommand):
280
298
  # already processed this class method
281
299
  pass
282
300
  else:
283
- log.trace(
284
- f"Add @overload for {overload.name.value} at the end of the class"
285
- )
301
+ log.trace(f"Add @overload for {overload.name.value} at the end of the class")
286
302
  updated_body.append(overload)
287
303
 
288
304
  if isinstance(updated_node, cst.Module):
@@ -310,9 +326,7 @@ class MergeCommand(VisitorBasedCodemodCommand):
310
326
  """keep track of the the (class, method) names to the stack"""
311
327
  self.stack.append(node.name.value)
312
328
 
313
- def leave_ClassDef(
314
- self, original_node: cst.ClassDef, updated_node: cst.ClassDef
315
- ) -> cst.ClassDef:
329
+ def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef:
316
330
  stack_id = tuple(self.stack)
317
331
  self.stack.pop()
318
332
  if stack_id not in self.annotations:
@@ -344,40 +358,58 @@ class MergeCommand(VisitorBasedCodemodCommand):
344
358
  updated_node = self.add_missed_overloads(updated_node, stack_id)
345
359
  return updated_node
346
360
 
361
+ # ------------------------------------------------------------------------
362
+
347
363
  # ------------------------------------------------------------------------
348
364
  def visit_FunctionDef(self, node: cst.FunctionDef) -> Optional[bool]:
349
365
  self.stack.append(node.name.value)
350
366
  return True
351
367
 
352
- def leave_FunctionDef(
353
- self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef
354
- ) -> Union[cst.FunctionDef, cst.ClassDef]:
368
+ def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> Union[cst.FunctionDef, cst.ClassDef]:
355
369
  "Update the function Parameters and return type, decorators and docstring"
356
- stack_id = tuple(self.stack)
370
+ if is_getter(updated_node):
371
+ extra = ["getter"]
372
+ elif is_setter(updated_node):
373
+ extra = ["setter"]
374
+ elif is_deleter(updated_node):
375
+ extra = ["deleter"]
376
+ else:
377
+ extra = []
378
+ stack_id = tuple(self.stack + extra)
357
379
  self.stack.pop()
358
380
  if stack_id not in self.annotations:
359
381
  # no changes to the function in docstub
360
382
  return updated_node
361
- if updated_node.decorators and any(
362
- dec.decorator.value == "overload" for dec in updated_node.decorators # type: ignore
363
- ):
383
+ if updated_node.decorators and any(is_decorator(dec, "overload") for dec in updated_node.decorators):
364
384
  # do not overwrite existing @overload functions
365
385
  # ASSUME: they are OK as they are
366
386
  return updated_node
367
-
368
387
  # update the firmware_stub from the doc_stub information
369
388
  doc_stub = self.annotations[stack_id].type_info
389
+ if isinstance(doc_stub.def_node, cst.FunctionDef):
390
+ # avoid mismatching property decorators
391
+ # if the updated node is a property, and the doc_stub node is not a property
392
+ # then we should not update the node
393
+ if is_property(updated_node) and not is_property(doc_stub.def_node):
394
+ return updated_node
395
+ if is_setter(updated_node) and not is_setter(doc_stub.def_node):
396
+ return updated_node
397
+ if is_getter(updated_node) and not is_getter(doc_stub.def_node):
398
+ return updated_node
399
+ if is_deleter(updated_node) and not is_deleter(doc_stub.def_node):
400
+ return updated_node
401
+
370
402
  # Check if it is an @overload decorator
371
- add_overload = any(dec.decorator.value == "overload" for dec in doc_stub.decorators) and len(self.annotations[stack_id].overloads) > 1 # type: ignore
403
+ add_overload = any(is_decorator(dec, "overload") for dec in doc_stub.decorators) and len(self.annotations[stack_id].overloads) > 1
372
404
 
373
- # If there are overloads in the documentation , lets use the first one
405
+ # If there are overloads in the documentation, use the first one
374
406
  if add_overload:
375
407
  log.debug(f"Change to @overload :{updated_node.name.value}")
376
408
  # Use the new overload - but with the existing docstring
377
409
  doc_stub = self.annotations[stack_id].overloads.pop(0)
378
410
  assert doc_stub.def_node
379
411
 
380
- if not self.params_only:
412
+ if not self.copy_params:
381
413
  # we have copied over the entire function definition, no further processing should be done on this node
382
414
  doc_stub.def_node = cast(cst.FunctionDef, doc_stub.def_node)
383
415
  updated_node = doc_stub.def_node
@@ -397,9 +429,9 @@ class MergeCommand(VisitorBasedCodemodCommand):
397
429
  # assert isinstance(doc_stub, TypeInfo)
398
430
  # assert doc_stub
399
431
  # first update the docstring
400
- no_docstring = updated_node.get_docstring() is None
401
- if (not self.params_only) or no_docstring:
402
- # DO Not overwrite existing docstring
432
+ has_no_docstring = updated_node.get_docstring() is None
433
+ if has_no_docstring or self.copy_docstr:
434
+ # overwrite existing docstring if there was none , or if it is asked
403
435
  updated_node = update_def_docstr(updated_node, doc_stub.docstr_node, doc_stub.def_node)
404
436
 
405
437
  # Sometimes the MCU stubs and the doc stubs have different types : FunctionDef / ClassDef
@@ -407,12 +439,12 @@ class MergeCommand(VisitorBasedCodemodCommand):
407
439
  if doc_stub.def_type == "funcdef":
408
440
  # Same type, we can copy over the annotations
409
441
  # params that should not be overwritten by the doc-stub ?
410
- if self.params_only:
442
+ if self.copy_params:
411
443
  # we are copying rich type definitions, just assume they are better than what is currently
412
444
  # in the destination stub
413
445
  overwrite_params = True
414
446
  else:
415
- params_txt = empty_module.code_for_node(original_node.params)
447
+ params_txt = _code(original_node.params)
416
448
  overwrite_params = params_txt in [
417
449
  "",
418
450
  "...",
@@ -440,16 +472,14 @@ class MergeCommand(VisitorBasedCodemodCommand):
440
472
  new_decorators.extend(doc_stub.decorators)
441
473
 
442
474
  for decorator in updated_node.decorators:
443
- if decorator.decorator.value not in [n.decorator.value for n in new_decorators]: # type: ignore
475
+ if _code(decorator) not in [_code(d) for d in new_decorators]:
444
476
  new_decorators.append(decorator)
445
477
 
446
- # if there is both a static and a class method, we remove the class decorator to avoid inconsistencies
447
- if any(dec.decorator.value == "staticmethod" for dec in doc_stub.decorators) and any( # type: ignore
448
- dec.decorator.value == "staticmethod" for dec in doc_stub.decorators # type: ignore
478
+ # if the method is both a static and a class method, we remove the @classmethod decorator to avoid inconsistencies
479
+ if any(is_decorator(dec, "staticmethod") for dec in new_decorators) and any(
480
+ is_decorator(dec, "classmethod") for dec in new_decorators
449
481
  ):
450
- new_decorators = [
451
- dec for dec in new_decorators if dec.decorator.value != "classmethod"
452
- ]
482
+ new_decorators = [dec for dec in new_decorators if dec.decorator.value != "classmethod"]
453
483
 
454
484
  return updated_node.with_changes(
455
485
  decorators=new_decorators,
@@ -2,9 +2,9 @@
2
2
  Gather all TypeVar and TypeAlias assignments in a module.
3
3
  """
4
4
 
5
- from typing import List, Sequence, Tuple, Union
5
+ from typing import Dict, List, Optional, Sequence, Tuple, Union
6
6
 
7
- import libcst
7
+ import libcst as cst
8
8
  import libcst.matchers as m
9
9
  from libcst import SimpleStatementLine
10
10
  from libcst.codemod._context import CodemodContext
@@ -14,11 +14,11 @@ from typing_extensions import TypeAlias
14
14
 
15
15
  from mpflash.logger import log
16
16
 
17
- TypeHelper: TypeAlias = Union[libcst.Assign, libcst.AnnAssign]
17
+ TypeHelper: TypeAlias = Union[cst.Assign, cst.AnnAssign]
18
18
  TypeHelpers: TypeAlias = List[TypeHelper]
19
- Statement: TypeAlias = Union[libcst.SimpleStatementLine, libcst.BaseCompoundStatement]
19
+ Statement: TypeAlias = Union[cst.SimpleStatementLine, cst.BaseCompoundStatement]
20
20
 
21
- _mod = libcst.parse_module("") # Debugging aid : _mod.code_for_node(node)
21
+ _mod = cst.parse_module("") # Debugging aid : _mod.code_for_node(node)
22
22
  _code = _mod.code_for_node # Debugging aid : _code(node)
23
23
 
24
24
 
@@ -33,7 +33,7 @@ def is_TypeAlias(statement) -> bool:
33
33
 
34
34
 
35
35
  def is_TypeVar(statement):
36
- "Assing - Foo = Typevar(...)"
36
+ "Assign - Foo = Typevar(...)"
37
37
  if m.matches(statement, m.SimpleStatementLine()):
38
38
  statement = statement.body[0]
39
39
  return m.matches(
@@ -41,6 +41,35 @@ def is_TypeVar(statement):
41
41
  m.Assign(value=m.Call(func=m.Name(value="TypeVar"))),
42
42
  )
43
43
 
44
+ def is_CONSTANT(statement):
45
+ """
46
+ Assign - FOO = ...
47
+ AnnAssign- FOO:bool = ...
48
+ """
49
+
50
+ if m.matches(statement, m.SimpleStatementLine()):
51
+ statement = statement.body[0]
52
+ if m.matches(statement, m.Assign()):
53
+ if len(statement.targets) != 1:
54
+ return False
55
+ if statement.targets[0].target.value.isupper():
56
+ return True
57
+ return False
58
+
59
+
60
+ def is_AnnCONSTANT(statement):
61
+ """
62
+ Assign - FOO = ...
63
+ AnnAssign- FOO:bool = ...
64
+ """
65
+
66
+ if m.matches(statement, m.SimpleStatementLine()):
67
+ statement = statement.body[0]
68
+ if m.matches(statement, m.AnnAssign()):
69
+ if statement.target.value.isupper():
70
+ return True
71
+ return False
72
+
44
73
 
45
74
  def is_ParamSpec(statement):
46
75
  "Assign - Foo = ParamSpec(...)"
@@ -56,6 +85,15 @@ def is_import(statement):
56
85
  "import - import foo"
57
86
  return m.matches(statement, m.SimpleStatementLine(body=[m.ImportFrom() | m.Import()]))
58
87
 
88
+ def is_docstr(statement):
89
+ "single or triple quoted string"
90
+ if m.matches(statement, m.SimpleStatementLine()):
91
+ statement = statement.body[0]
92
+ return m.matches(
93
+ statement,
94
+ m.Expr(value=m.SimpleString()),
95
+ # | m.TripleQuotedString(),
96
+ )
59
97
 
60
98
  class GatherTypeHelpers(ContextAwareVisitor):
61
99
  """
@@ -65,25 +103,41 @@ class GatherTypeHelpers(ContextAwareVisitor):
65
103
  def __init__(self, context: CodemodContext) -> None:
66
104
  super().__init__(context)
67
105
  # Track all of the TypeVar, TypeAlias and Paramspec assignments found
68
- self.all_typehelpers: TypeHelpers = []
106
+ self.all_typehelpers: Dict[Tuple[str, ...], TypeHelpers] = {}
107
+ self.stack: List[str] = []
69
108
 
70
- def visit_Assign(self, node: libcst.Assign) -> None:
109
+ def visit_ClassDef(self, node: cst.ClassDef) -> Optional[bool]:
110
+ """keep track of the the (class, method) names to the stack"""
111
+ self.stack.append(node.name.value)
112
+
113
+ def leave_ClassDef(self, original_node: cst.ClassDef) -> None:
114
+ """remove the class name from the stack"""
115
+ self.stack.pop()
116
+
117
+ def visit_Assign(self, node: cst.Assign) -> None:
71
118
  """
72
119
  Find all TypeVar assignments in the module.
73
120
  format: T = TypeVar("T", int, float, str, bytes, Tuple)
121
+ or Constants
74
122
  """
75
123
  # is this a TypeVar assignment?
76
124
  # needs to be more robust
77
- if is_TypeVar(node) or is_ParamSpec(node):
78
- self.all_typehelpers.append(node)
125
+ if is_TypeVar(node) or is_ParamSpec(node) or is_CONSTANT(node):
126
+ key = tuple(self.stack)
127
+ if not key in self.all_typehelpers:
128
+ self.all_typehelpers[key] = []
129
+ self.all_typehelpers[key].append(node)
79
130
 
80
- def visit_AnnAssign(self, node: libcst.AnnAssign) -> None:
131
+ def visit_AnnAssign(self, node: cst.AnnAssign) -> None:
81
132
  """ "
82
133
  Find all TypeAlias assignments in the module.
83
134
  format: T: TypeAlias = str
84
135
  """
85
- if is_TypeAlias(node):
86
- self.all_typehelpers.append(node)
136
+ if is_TypeAlias(node) or is_AnnCONSTANT(node):
137
+ key = tuple(self.stack)
138
+ if not key in self.all_typehelpers:
139
+ self.all_typehelpers[key] = []
140
+ self.all_typehelpers[key].append(node)
87
141
 
88
142
 
89
143
  class AddTypeHelpers(ContextAwareTransformer):
@@ -95,48 +149,64 @@ class AddTypeHelpers(ContextAwareTransformer):
95
149
 
96
150
  def __init__(self, context: CodemodContext) -> None:
97
151
  super().__init__(context)
98
- self.new_typehelpers: TypeHelpers = context.scratch.get(self.CONTEXT_KEY, [])
99
- self.all_typehelpers: TypeHelpers = []
152
+ self.new_typehelpers: Dict[Tuple[str, ...], TypeHelpers] = context.scratch.get(self.CONTEXT_KEY, [])
153
+ self.all_typehelpers: Dict[Tuple[str, ...], TypeHelpers] = {}
154
+ self.stack: List[str] = []
100
155
 
101
156
  @classmethod
102
- def add_typevar(cls, context: CodemodContext, node: libcst.Assign):
103
- new_typehelpers = context.scratch.get(cls.CONTEXT_KEY, [])
104
- new_typehelpers.append(node)
105
- context.scratch[cls.CONTEXT_KEY] = new_typehelpers
157
+ def add_helpers(cls, context: CodemodContext, helpers: Dict[Tuple[str, ...], TypeHelpers]):
158
+ context.scratch[cls.CONTEXT_KEY] = helpers
106
159
  # add the typevar to the module
107
160
 
161
+ @staticmethod
162
+ def skip_first(body: Sequence[cst.BaseStatement]) -> bool:
163
+ # Is there a __strict__ flag or docstring at the top?
164
+ if m.matches(
165
+ body[0],
166
+ m.SimpleStatementLine(body=[m.Assign(targets=[m.AssignTarget(target=m.Name("__strict__"))])])
167
+ | m.SimpleStatementLine(body=[m.Expr(value=m.SimpleString())]),
168
+ ):
169
+ return True
170
+ return False
171
+
108
172
  def leave_Module(
109
173
  self,
110
- original_node: libcst.Module,
111
- updated_node: libcst.Module,
112
- ) -> libcst.Module:
113
- if not self.new_typehelpers:
174
+ original_node: cst.Module,
175
+ updated_node: cst.Module,
176
+ ) -> cst.Module:
177
+ stack_id = tuple(self.stack)
178
+ if stack_id not in self.new_typehelpers:
179
+ # nothing new to add @ module level
114
180
  return updated_node
115
181
 
116
- # split the module into 3 parts
182
+ # split the body of the module or classdef into 3 parts
117
183
  # before and after the insertions point , and a list of the TV and TA statements
184
+ body = self.update_body(updated_node.body, stack_id)
185
+ return updated_node.with_changes(body=body)
186
+
187
+ def update_body(self, body:Sequence[cst.BaseStatement], stack_id):
118
188
  (
119
189
  statements_before,
120
190
  helper_statements,
121
191
  statements_after,
122
- ) = self._split_module(original_node, updated_node)
192
+ ) = self._split_body(body)
123
193
 
124
194
  # simpler to compare as text than to compare the nodes -
125
195
  existing_targets = [
126
196
  helper.body[0].targets[0].target.value # type: ignore
127
197
  for helper in helper_statements
128
- if is_TypeVar(helper) or is_ParamSpec(helper)
198
+ if is_TypeVar(helper) or is_ParamSpec(helper) or is_CONSTANT(helper)
129
199
  ] + [
130
200
  helper.body[0].target.value # type: ignore
131
201
  for helper in helper_statements
132
- if is_TypeAlias(helper)
202
+ if is_TypeAlias(helper) or is_AnnCONSTANT(helper)
133
203
  ]
134
204
  statements_to_add = []
135
- for new_typehelper in self.new_typehelpers:
136
- if isinstance(new_typehelper, libcst.AnnAssign):
205
+ for new_typehelper in self.new_typehelpers[stack_id]:
206
+ if isinstance(new_typehelper, cst.AnnAssign):
137
207
  if new_typehelper.target.value not in existing_targets: # type: ignore
138
208
  statements_to_add.append(SimpleStatementLine(body=[new_typehelper]))
139
- elif isinstance(new_typehelper, libcst.Assign):
209
+ elif isinstance(new_typehelper, cst.Assign):
140
210
  if new_typehelper.targets[0].target.value not in existing_targets: # type: ignore
141
211
  statements_to_add.append(SimpleStatementLine(body=[new_typehelper]))
142
212
 
@@ -145,14 +215,38 @@ class AddTypeHelpers(ContextAwareTransformer):
145
215
  *statements_to_add,
146
216
  *statements_after,
147
217
  )
218
+ return body
148
219
 
149
- return updated_node.with_changes(body=body)
220
+ def visit_ClassDef(self, node: cst.ClassDef) -> Optional[bool]:
221
+ """keep track of the the (class, method) names to the stack"""
222
+ self.stack.append(node.name.value)
150
223
 
151
- def _split_module(
224
+ def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef:
225
+ stack_id = tuple(self.stack)
226
+ self.stack.pop()
227
+ if stack_id not in self.new_typehelpers:
228
+ # no changes to the class
229
+ return updated_node
230
+ # split the body of the module or classdef into 3 parts
231
+ if isinstance(updated_node.body, cst.IndentedBlock):
232
+ flat_body = self.update_body(updated_node.body.body, stack_id)
233
+ new_body = cst.IndentedBlock(
234
+ body=flat_body,
235
+ header=updated_node.body.header,
236
+ footer=updated_node.body.footer,
237
+ indent=updated_node.body.indent,
238
+ )
239
+ else:
240
+ flat_body = self.update_body([], stack_id)
241
+ new_body = cst.IndentedBlock(
242
+ body=flat_body,
243
+ )
244
+ return updated_node.with_changes(body=new_body)
245
+
246
+ def _split_body(
152
247
  self,
153
- orig_module: libcst.Module,
154
- updated_module: libcst.Module,
155
- ) -> Tuple[Sequence[Statement], Sequence[libcst.SimpleStatementLine], Sequence[Statement]]:
248
+ body: Sequence[cst.BaseStatement],
249
+ ) -> Tuple[Sequence[cst.BaseStatement], Sequence[cst.BaseStatement], Sequence[cst.BaseStatement]]:
156
250
  """
157
251
  Split the module into 3 parts:
158
252
  - before any TypeAlias, TypeVar or ParamSpec statements
@@ -161,14 +255,21 @@ class AddTypeHelpers(ContextAwareTransformer):
161
255
  """
162
256
  last_import = first_typehelper = last_typehelper = -1
163
257
  start = 0
164
- if _skip_first(orig_module):
258
+ if self.skip_first(body):
165
259
  start = 1
166
260
 
167
- for i, statement in enumerate(orig_module.body[start:]):
261
+ for i, statement in enumerate(body[start:]):
168
262
  if is_import(statement):
169
263
  last_import = i + start
170
264
  continue
171
- if is_TypeVar(statement) or is_TypeAlias(statement) or is_ParamSpec(statement):
265
+ if (
266
+ is_TypeVar(statement)
267
+ or is_TypeAlias(statement)
268
+ or is_ParamSpec(statement)
269
+ or is_CONSTANT(statement)
270
+ or is_AnnCONSTANT(statement)
271
+ or is_docstr(statement)
272
+ ):
172
273
  if first_typehelper == -1:
173
274
  first_typehelper = i + start
174
275
  last_typehelper = i + start
@@ -176,7 +277,8 @@ class AddTypeHelpers(ContextAwareTransformer):
176
277
  insert_after = max(start, last_import + 1, last_typehelper + 1)
177
278
  assert insert_after != -1, "insert_after must be != -1"
178
279
  #
179
- before = list(updated_module.body[:insert_after])
180
- after = list(updated_module.body[insert_after:])
181
- helper_statements: Sequence[libcst.SimpleStatementLine] = list(updated_module.body[first_typehelper : last_typehelper + 1]) # type: ignore
280
+ before = list(body[:insert_after])
281
+ after = list(body[insert_after:])
282
+ helper_statements: Sequence[cst.SimpleStatementLine] = list(body[first_typehelper : last_typehelper + 1]) # type: ignore
182
283
  return (before, helper_statements, after)
284
+
@@ -6,8 +6,8 @@ from pathlib import Path
6
6
  from typing import Union
7
7
 
8
8
  import rich_click as click
9
- from mpflash.logger import log
10
9
 
10
+ from mpflash.logger import log
11
11
  from stubber.codemod.enrich import enrich_folder
12
12
  from stubber.commands.cli import stubber_cli
13
13
  from stubber.utils.config import CONFIG
@@ -44,28 +44,28 @@ from stubber.utils.config import CONFIG
44
44
  is_flag=True,
45
45
  )
46
46
  @click.option(
47
- "--params-only",
48
- "params_only",
49
- default=False,
50
- help="Copy only the parameters, not the docstrings (unless the docstring is missing)",
47
+ "--copy-params/--no-copy-params",
48
+ "copy_params",
49
+ default=True,
50
+ help="Copy the function/method parameters",
51
+ show_default=True,
52
+ is_flag=True,
53
+ )
54
+ @click.option(
55
+ "--copy-docstr/--no-copy-docstr",
56
+ "copy_docstr",
57
+ default=True,
58
+ help="Copy the docstrings",
51
59
  show_default=True,
52
60
  is_flag=True,
53
61
  )
54
- # @click.option(
55
- # "--package-name",
56
- # "-p",
57
- # "package_name",
58
- # default="",
59
- # help="Package name to be enriched (Optional)",
60
- # show_default=True,
61
- # )
62
62
  def cli_enrich_folder(
63
63
  dest_folder: Union[str, Path],
64
64
  source_folder: Union[str, Path],
65
65
  diff: bool = False,
66
66
  dry_run: bool = False,
67
- params_only: bool = True,
68
- # package_name: str = "",
67
+ copy_params: bool = True,
68
+ copy_docstr: bool = False,
69
69
  ):
70
70
  """
71
71
  Enrich the stubs in stub_folder with the docstubs in docstubs_folder.
@@ -78,6 +78,6 @@ def cli_enrich_folder(
78
78
  show_diff=diff,
79
79
  write_back=write_back,
80
80
  require_docstub=False,
81
- # package_name=package_name,
82
- params_only=params_only,
81
+ copy_params=copy_params,
82
+ copy_docstr=copy_docstr,
83
83
  )