micropython-stubber 1.24.0__py3-none-any.whl → 1.24.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39) hide show
  1. {micropython_stubber-1.24.0.dist-info → micropython_stubber-1.24.2.dist-info}/METADATA +3 -2
  2. {micropython_stubber-1.24.0.dist-info → micropython_stubber-1.24.2.dist-info}/RECORD +38 -36
  3. mpflash/mpflash/bootloader/activate.py +1 -1
  4. mpflash/mpflash/flash/esp.py +1 -1
  5. mpflash/mpflash/flash/uf2/__init__.py +18 -2
  6. mpflash/mpflash/mpboard_id/add_boards.py +5 -2
  7. mpflash/mpflash/mpboard_id/board_id.py +7 -5
  8. mpflash/mpflash/mpboard_id/board_info.zip +0 -0
  9. mpflash/mpflash/vendor/pico-universal-flash-nuke/LICENSE.txt +21 -0
  10. mpflash/mpflash/vendor/pico-universal-flash-nuke/universal_flash_nuke.uf2 +0 -0
  11. mpflash/mpflash/vendor/readme.md +2 -0
  12. mpflash/poetry.lock +754 -488
  13. mpflash/pyproject.toml +1 -1
  14. stubber/__init__.py +1 -1
  15. stubber/board/createstubs.py +44 -38
  16. stubber/board/createstubs_db.py +17 -12
  17. stubber/board/createstubs_db_min.py +63 -63
  18. stubber/board/createstubs_db_mpy.mpy +0 -0
  19. stubber/board/createstubs_mem.py +17 -12
  20. stubber/board/createstubs_mem_min.py +99 -99
  21. stubber/board/createstubs_mem_mpy.mpy +0 -0
  22. stubber/board/createstubs_min.py +111 -112
  23. stubber/board/createstubs_mpy.mpy +0 -0
  24. stubber/board/modulelist.txt +27 -27
  25. stubber/codemod/enrich.py +4 -6
  26. stubber/codemod/merge_docstub.py +10 -10
  27. stubber/codemod/visitors/type_helpers.py +182 -0
  28. stubber/commands/get_docstubs_cmd.py +5 -6
  29. stubber/cst_transformer.py +2 -1
  30. stubber/merge_config.py +3 -0
  31. stubber/publish/merge_docstubs.py +1 -2
  32. stubber/publish/stubpackage.py +36 -14
  33. stubber/rst/lookup.py +3 -0
  34. stubber/rst/reader.py +8 -13
  35. stubber/tools/manifestfile.py +2 -1
  36. stubber/codemod/visitors/typevars.py +0 -200
  37. {micropython_stubber-1.24.0.dist-info → micropython_stubber-1.24.2.dist-info}/LICENSE +0 -0
  38. {micropython_stubber-1.24.0.dist-info → micropython_stubber-1.24.2.dist-info}/WHEEL +0 -0
  39. {micropython_stubber-1.24.0.dist-info → micropython_stubber-1.24.2.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,182 @@
1
+ """
2
+ Gather all TypeVar and TypeAlias assignments in a module.
3
+ """
4
+
5
+ from typing import List, Sequence, Tuple, Union
6
+
7
+ import libcst
8
+ import libcst.matchers as m
9
+ from libcst import SimpleStatementLine
10
+ from libcst.codemod._context import CodemodContext
11
+ from libcst.codemod._visitor import ContextAwareTransformer, ContextAwareVisitor
12
+ from libcst.codemod.visitors._add_imports import _skip_first
13
+ from typing_extensions import TypeAlias
14
+
15
+ from mpflash.logger import log
16
+
17
# A type helper node: either a plain assignment or an annotated assignment.
TypeHelper: TypeAlias = Union[libcst.Assign, libcst.AnnAssign]
TypeHelpers: TypeAlias = List[TypeHelper]
# A statement: either a simple statement line or a compound statement.
Statement: TypeAlias = Union[libcst.SimpleStatementLine, libcst.BaseCompoundStatement]

# Debugging aids: render a node back to source text.
_mod = libcst.parse_module("")  # usage: _mod.code_for_node(node)
_code = _mod.code_for_node  # usage: _code(node)
23
+
24
+
25
def is_TypeAlias(statement) -> bool:
    """Return True for an annotated assignment of the form ``Foo: TypeAlias = ...``.

    Accepts either the bare ``AnnAssign`` node or the ``SimpleStatementLine``
    that wraps it; for a statement line only the first small statement is
    checked. Only the bare name ``TypeAlias`` is recognised as annotation.
    """
    # Unwrap a statement line. An empty body is tolerated rather than raising
    # IndexError: such a line simply is not a TypeAlias assignment.
    if m.matches(statement, m.SimpleStatementLine()) and statement.body:
        statement = statement.body[0]
    return m.matches(
        statement,
        m.AnnAssign(annotation=m.Annotation(annotation=m.Name(value="TypeAlias"))),
    )
33
+
34
+
35
def is_TypeVar(statement) -> bool:
    """Return True for an assignment of the form ``Foo = TypeVar(...)``.

    Accepts either the bare ``Assign`` node or the ``SimpleStatementLine``
    that wraps it; for a statement line only the first small statement is
    checked. Only a direct call to the bare name ``TypeVar`` is recognised.
    """
    # Unwrap a statement line. An empty body is tolerated rather than raising
    # IndexError: such a line simply is not a TypeVar assignment.
    if m.matches(statement, m.SimpleStatementLine()) and statement.body:
        statement = statement.body[0]
    return m.matches(
        statement,
        m.Assign(value=m.Call(func=m.Name(value="TypeVar"))),
    )
43
+
44
+
45
def is_ParamSpec(statement) -> bool:
    """Return True for an assignment of the form ``Foo = ParamSpec(...)``.

    Accepts either the bare ``Assign`` node or the ``SimpleStatementLine``
    that wraps it; for a statement line only the first small statement is
    checked. Only a direct call to the bare name ``ParamSpec`` is recognised.
    """
    # Unwrap a statement line. An empty body is tolerated rather than raising
    # IndexError: such a line simply is not a ParamSpec assignment.
    if m.matches(statement, m.SimpleStatementLine()) and statement.body:
        statement = statement.body[0]
    return m.matches(
        statement,
        m.Assign(value=m.Call(func=m.Name(value="ParamSpec"))),
    )
53
+
54
+
55
def is_import(statement) -> bool:
    """Return True for a statement line whose body is a single ``import foo`` or ``from foo import bar``."""
    import_line = m.SimpleStatementLine(body=[m.ImportFrom() | m.Import()])
    return m.matches(statement, import_line)
58
+
59
+
60
class GatherTypeHelpers(ContextAwareVisitor):
    """
    Visitor that collects the TypeVar, ParamSpec and TypeAlias assignments it visits.
    """

    def __init__(self, context: CodemodContext) -> None:
        super().__init__(context)
        # All TypeVar, TypeAlias and ParamSpec assignment nodes found so far
        self.all_typehelpers: TypeHelpers = []

    def visit_Assign(self, node: libcst.Assign) -> None:
        """
        Record TypeVar and ParamSpec assignments.
        format: T = TypeVar("T", int, float, str, bytes, Tuple)
        """
        # detection needs to be more robust:
        # only calls to the bare names TypeVar / ParamSpec are recognised
        if not (is_TypeVar(node) or is_ParamSpec(node)):
            return
        self.all_typehelpers.append(node)

    def visit_AnnAssign(self, node: libcst.AnnAssign) -> None:
        """
        Record TypeAlias assignments.
        format: T: TypeAlias = str
        """
        if not is_TypeAlias(node):
            return
        self.all_typehelpers.append(node)
87
+
88
+
89
class AddTypeHelpers(ContextAwareTransformer):
    """
    Transformer loosely based on AddImportsVisitor.

    Type helpers (TypeVar, ParamSpec and TypeAlias assignments) are queued in
    the codemod context with `add_typevar` and inserted into the module when
    the module node is left. A helper is skipped when its name is already
    defined in the module or was already inserted by this transformer.
    """

    CONTEXT_KEY = "AddTypeHelpers"

    def __init__(self, context: CodemodContext) -> None:
        super().__init__(context)
        # helpers queued in the context scratch under CONTEXT_KEY
        self.new_typehelpers: TypeHelpers = context.scratch.get(self.CONTEXT_KEY, [])
        self.all_typehelpers: TypeHelpers = []

    @classmethod
    def add_typevar(cls, context: CodemodContext, node: TypeHelper) -> None:
        """
        Queue a type helper assignment for insertion into the module.

        Despite the name this takes any type helper: a TypeVar / ParamSpec
        `Assign` as well as a TypeAlias `AnnAssign`.
        """
        context.scratch.setdefault(cls.CONTEXT_KEY, []).append(node)

    @staticmethod
    def _target_name(target) -> Union[str, None]:
        """Return the assigned name as text, or None when the target is not a plain name."""
        return target.value if isinstance(target, libcst.Name) else None

    def leave_Module(
        self,
        original_node: libcst.Module,
        updated_node: libcst.Module,
    ) -> libcst.Module:
        """
        Insert the queued type helpers after the imports and existing helpers.

        Returns `updated_node` unchanged when nothing is queued, otherwise a
        copy of it with the missing helpers added to the body.
        """
        if not self.new_typehelpers:
            return updated_node

        # split the module into 3 parts
        # before and after the insertion point, and a list of the TV and TA statements
        (
            statements_before,
            helper_statements,
            statements_after,
        ) = self._split_module(original_node, updated_node)

        # simpler to compare as text than to compare the nodes
        known_names = set()
        for helper in helper_statements:
            if is_TypeVar(helper) or is_ParamSpec(helper):
                name = self._target_name(helper.body[0].targets[0].target)  # type: ignore
            elif is_TypeAlias(helper):
                name = self._target_name(helper.body[0].target)  # type: ignore
            else:
                # unrelated statement that sits between two helpers
                continue
            if name is not None:
                known_names.add(name)

        statements_to_add = []
        for new_typehelper in self.new_typehelpers:
            if isinstance(new_typehelper, libcst.AnnAssign):
                name = self._target_name(new_typehelper.target)
            elif isinstance(new_typehelper, libcst.Assign):
                name = self._target_name(new_typehelper.targets[0].target)
            else:
                continue
            # Helpers without a plain name cannot be compared and are always added.
            if name is not None:
                if name in known_names:
                    continue
                # remember the name, so a helper that was queued more than once
                # is only inserted once
                known_names.add(name)
            statements_to_add.append(SimpleStatementLine(body=[new_typehelper]))

        body = (
            *statements_before,
            *statements_to_add,
            *statements_after,
        )

        return updated_node.with_changes(body=body)

    def _split_module(
        self,
        orig_module: libcst.Module,
        updated_module: libcst.Module,
    ) -> Tuple[Sequence[Statement], Sequence[libcst.SimpleStatementLine], Sequence[Statement]]:
        """
        Split the module into 3 parts:
        - the statements before the insertion point
        - the span from the first to the last TypeAlias, TypeVar or ParamSpec
          statement (may contain unrelated statements in between; empty when
          the module has no such statements)
        - the statements after the insertion point

        The positions are determined on `orig_module`, the statements are
        taken from `updated_module`.
        """
        last_import = first_typehelper = last_typehelper = -1
        # NOTE(review): _skip_first presumably detects a leading module docstring - confirm
        start = 1 if _skip_first(orig_module) else 0

        for i, statement in enumerate(orig_module.body[start:], start=start):
            if is_import(statement):
                last_import = i
                continue
            if is_TypeVar(statement) or is_TypeAlias(statement) or is_ParamSpec(statement):
                if first_typehelper == -1:
                    first_typehelper = i
                last_typehelper = i

        # insert as high as possible, but after the last import and last TypeVar/TypeAlias statement
        insert_after = max(start, last_import + 1, last_typehelper + 1)

        before = list(updated_module.body[:insert_after])
        after = list(updated_module.body[insert_after:])
        if first_typehelper == -1:
            helper_statements: Sequence[libcst.SimpleStatementLine] = []
        else:
            helper_statements = list(updated_module.body[first_typehelper : last_typehelper + 1])  # type: ignore
        return (before, helper_statements, after)
@@ -5,14 +5,14 @@ get-docstubs
5
5
 
6
6
  from pathlib import Path
7
7
  from typing import Optional
8
+
9
+ import rich_click as click
8
10
  from packaging.version import Version
9
11
 
10
12
  import mpflash.basicgit as git
11
- import rich_click as click
13
+ import stubber.utils as utils
12
14
  from mpflash.logger import log
13
-
14
15
  from stubber.codemod.enrich import enrich_folder
15
- import stubber.utils as utils
16
16
  from stubber.commands.cli import stubber_cli
17
17
  from stubber.stubs_from_docs import generate_from_rst
18
18
  from stubber.utils.config import CONFIG
@@ -65,7 +65,7 @@ from stubber.utils.repos import fetch_repos
65
65
  "--enrich",
66
66
  is_flag=True,
67
67
  default=False,
68
- help="Enrich with type information from micropython-reference",
68
+ help="Enrich with type information from reference/micropython",
69
69
  show_default=True,
70
70
  )
71
71
  @click.pass_context
@@ -129,8 +129,7 @@ def cli_docstubs(
129
129
  log.warning(f"Enriching is not supported for version {version}")
130
130
  else:
131
131
  # !stubber enrich --params-only --source {reference} --dest {docstubs}
132
- reference_path = CONFIG.stub_path.parent / "micropython-reference"
133
- log.info(f"Enriching from {reference_path}")
132
+ reference_path = CONFIG.stub_path.parent / "reference/micropython"
134
133
  _ = enrich_folder(
135
134
  reference_path,
136
135
  dst_path,
@@ -5,6 +5,7 @@ from dataclasses import dataclass, field
5
5
  from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
6
6
 
7
7
  import libcst as cst
8
+ from libcst import matchers as m
8
9
 
9
10
 
10
11
  @dataclass
@@ -124,7 +125,7 @@ class StubTypingCollector(cst.CSTVisitor):
124
125
  # store the first function/method signature
125
126
  self.annotations[key] = AnnoValue(type_info=ti)
126
127
 
127
- if any(dec.decorator.value == "overload" for dec in node.decorators): # type: ignore
128
+ if any(m.matches(dec , m.Decorator(decorator=m.Name("overload"))) for dec in node.decorators):
128
129
  # and store the overloads
129
130
  self.annotations[key].overloads.append(ti)
130
131
 
stubber/merge_config.py CHANGED
@@ -40,6 +40,9 @@ RM_MERGED: Final = (
40
40
  "_rp2", # Leave out for now , to avoid conflicts with the rp2 module
41
41
  "pycopy_imphook", # pycopy only: not needed in the merged stubs
42
42
  # "os",
43
+ "types", # defined in webassembly pyscript - shadows stdlib
44
+ "abc", # defined in webassembly pyscript - shadows stdlib
45
+ # "uos", # ???? problems with mypy & webassembly stubs
43
46
  ]
44
47
  + STDLIB_MODULES
45
48
  + [f"u{mod}" for mod in U_MODULES]
@@ -7,7 +7,6 @@ from pathlib import Path
7
7
  from typing import List, Optional, Union
8
8
 
9
9
  from mpflash.logger import log
10
-
11
10
  from stubber.codemod.enrich import enrich_folder
12
11
  from stubber.merge_config import RM_MERGED, recreate_umodules, remove_modules
13
12
  from stubber.publish.candidates import board_candidates, filter_list
@@ -147,7 +146,7 @@ def refactor_rp2_module(dest_path: Path):
147
146
  rp2_file.rename(rp2_folder / "__init__.pyi")
148
147
  # copy the asm_pio.pyi file from the reference folder
149
148
  for submod in ["rp2/asm_pio.pyi"]:
150
- file = CONFIG.mpy_stubs_path / "micropython-reference" / submod
149
+ file = CONFIG.mpy_stubs_path / "reference/micropython" / submod
151
150
  if file.exists():
152
151
  shutil.copy(file, rp2_folder / file.name)
153
152
  log.info(f" - add rp2/{ file.name}")
@@ -1,4 +1,12 @@
1
- """Create a stub-only package for a specific version of micropython"""
1
+ """
2
+ Create a stub-only package for a specific
3
+ - version
4
+ - port
5
+ - board
6
+ of micropython
7
+
8
+
9
+ """
2
10
 
3
11
  import hashlib
4
12
  import json
@@ -10,8 +18,8 @@ from pathlib import Path
10
18
  from typing import Any, Dict, List, Optional, Tuple, Union
11
19
 
12
20
  import tenacity
13
- from mpflash.basicgit import get_git_describe
14
21
 
22
+ from mpflash.basicgit import get_git_describe
15
23
  from stubber.publish.helpers import get_module_docstring
16
24
 
17
25
  if sys.version_info >= (3, 11):
@@ -22,10 +30,10 @@ else:
22
30
  from typing import NewType
23
31
 
24
32
  import tomli_w
25
- from mpflash.logger import log
26
- from mpflash.versions import SET_PREVIEW, V_PREVIEW, clean_version
27
33
  from packaging.version import Version, parse
28
34
 
35
+ from mpflash.logger import log
36
+ from mpflash.versions import SET_PREVIEW, V_PREVIEW, clean_version
29
37
  from stubber.publish.bump import bump_version
30
38
  from stubber.publish.defaults import GENERIC_U, default_board
31
39
  from stubber.publish.enums import StubSource
@@ -35,10 +43,15 @@ from stubber.utils.config import CONFIG
35
43
  Status = NewType("Status", Dict[str, Union[str, None]])
36
44
  StubSources = List[Tuple[StubSource, Path]]
37
45
 
38
- # indicates which stubs will be skipped when copying for these stub sources
46
+ # indicates which stubs will not be copied from the stub sources
39
47
  STUBS_COPY_FILTER = {
40
48
  StubSource.FROZEN: [
41
- "espnow", # merged stubs + documentation of the espnow module is better than the info in the forzen stubs
49
+ "espnow", # merged stubs + documentation of the espnow module is better than the info in the frozen stubs
50
+ "collections", # must be in stdlib
51
+ "types", # must be in stdlib
52
+ "abc", # must be in stdlib
53
+ "time", # must be in stdlib
54
+ "io", # must be in stdlib
42
55
  ],
43
56
  StubSource.FIRMWARE: [
44
57
  "builtins",
@@ -359,10 +372,14 @@ class Builder(VersionedPackage):
359
372
 
360
373
  def copy_folder(self, stub_type: StubSource, src_path: Path):
361
374
  Path(self.package_path).mkdir(parents=True, exist_ok=True)
362
- for item in (CONFIG.stub_path / src_path).rglob("*"):
375
+ for item in (CONFIG.stub_path / src_path).rglob("*.pyi"):
363
376
  if item.is_file():
364
377
  # filter the 'poorly' decorated files
365
- if stub_type in STUBS_COPY_FILTER and item.stem in STUBS_COPY_FILTER[stub_type]:
378
+ if stub_type in STUBS_COPY_FILTER and (
379
+ item.stem in STUBS_COPY_FILTER[stub_type] or
380
+ item.parent.name in STUBS_COPY_FILTER[stub_type]
381
+ ):
382
+ log.trace(f"Skipping {item.name}")
366
383
  continue
367
384
 
368
385
  target = Path(self.package_path) / item.relative_to(CONFIG.stub_path / src_path)
@@ -631,8 +648,10 @@ class PoetryBuilder(Builder):
631
648
  try:
632
649
  with open(_toml, "rb") as f:
633
650
  pyproject = tomllib.load(f)
634
- # pyproject["tool"]["poetry"]["version"] = version
635
- pyproject["project"]["version"] = version
651
+ if "project" in pyproject:
652
+ pyproject["project"]["version"] = version
653
+ else:
654
+ pyproject["tool"]["poetry"]["version"] = version
636
655
  # update the version in the toml file
637
656
  with open(_toml, "wb") as output:
638
657
  tomli_w.dump(pyproject, output)
@@ -742,10 +761,13 @@ class PoetryBuilder(Builder):
742
761
  raise (e)
743
762
 
744
763
  # update the name , version and description of the package
745
- # _pyproject["tool"]["poetry"]["name"] = self.package_name
746
- # _pyproject["tool"]["poetry"]["description"] = self.description
747
- _pyproject["project"]["name"] = self.package_name
748
- _pyproject["project"]["description"] = self.description
764
+ if 'project' in _pyproject:
765
+ _pyproject["project"]["name"] = self.package_name
766
+ _pyproject["project"]["description"] = self.description
767
+ else:
768
+ _pyproject["tool"]["poetry"]["name"] = self.package_name
769
+ _pyproject["tool"]["poetry"]["description"] = self.description
770
+
749
771
  # write out the pyproject.toml file
750
772
  self.pyproject = _pyproject
751
773
 
stubber/rst/lookup.py CHANGED
@@ -101,6 +101,9 @@ RST_DOC_FIXES: List[Tuple[str, str]] = [
101
101
  ".. method:: AIOESPNow._aiter__() / async AIOESPNow.__anext__()",
102
102
  ".. method:: AIOESPNow._aiter__()\n async AIOESPNow.__anext__()",
103
103
  ),
104
+ # appended to in ssl.constant name - ssl.PROTOCOL_DTLS_CLIENT(when DTLS support is enabled)
105
+ # Ugly hack to fix the documentation
106
+ ( '(when DTLS support is enabled)', " : Incomplete # (when DTLS support is enabled)")
104
107
  ]
105
108
 
106
109
 
stubber/rst/reader.py CHANGED
@@ -69,19 +69,10 @@ from typing import List, Optional, Tuple
69
69
 
70
70
  from mpflash.logger import log
71
71
  from mpflash.versions import V_PREVIEW
72
-
73
- from stubber.rst import (
74
- CHILD_PARENT_CLASS,
75
- MODULE_GLUE,
76
- PARAM_FIXES,
77
- PARAM_RE_FIXES,
78
- RST_DOC_FIXES,
79
- TYPING_IMPORT,
80
- ClassSourceDict,
81
- FunctionSourceDict,
82
- ModuleSourceDict,
83
- return_type_from_context,
84
- )
72
+ from stubber.rst import (CHILD_PARENT_CLASS, MODULE_GLUE, PARAM_FIXES,
73
+ PARAM_RE_FIXES, RST_DOC_FIXES, TYPING_IMPORT,
74
+ ClassSourceDict, FunctionSourceDict, ModuleSourceDict,
75
+ return_type_from_context)
85
76
  from stubber.rst.lookup import Fix
86
77
  from stubber.utils.config import CONFIG
87
78
 
@@ -782,6 +773,10 @@ class RSTParser(RSTReader):
782
773
  self.line_no += counter - 1
783
774
  # clean up before returning
784
775
  names = [n.strip() for n in names if n.strip() != "etc."] # remove etc.
776
+ # Ugly one-off hack
777
+ # to remove the '(when DTLS support is enabled)' from the ssl constants
778
+ # names = [n.replace('(when DTLS support is enabled)', '') for n in names]
779
+
785
780
  return names
786
781
 
787
782
  def parse_data(self):
@@ -29,6 +29,7 @@ from __future__ import print_function
29
29
  import contextlib
30
30
  import os
31
31
  import sys
32
+ import glob
32
33
  import tempfile
33
34
  from collections import namedtuple
34
35
 
@@ -400,7 +401,7 @@ class ManifestFile:
400
401
  self._metadata.pop()
401
402
 
402
403
  def _require_from_path(self, library_path, name, version, extra_kwargs):
403
- for root, dirnames, filenames in os.walk(library_path): # type: ignore
404
+ for root, dirnames, filenames in os.walk(library_path):
404
405
  if os.path.basename(root) == name and "manifest.py" in filenames:
405
406
  self.include(root, is_require=True, **extra_kwargs)
406
407
  return True
@@ -1,200 +0,0 @@
1
- """
2
- Gather all TypeVar and TypeAlias assignments in a module.
3
- """
4
-
5
- from typing import List, Tuple, Union
6
-
7
- import libcst
8
- import libcst.matchers as m
9
- from libcst import SimpleStatementLine
10
- from libcst.codemod._context import CodemodContext
11
- from libcst.codemod._visitor import ContextAwareTransformer, ContextAwareVisitor
12
- from libcst.codemod.visitors._add_imports import _skip_first
13
- from mpflash.logger import log
14
-
15
- _mod = libcst.parse_module("") # Debugging aid : _mod.code_for_node(node)
16
- _code = _mod.code_for_node # Debugging aid : _code(node)
17
-
18
-
19
- class GatherTypeVarsVisitor(ContextAwareVisitor):
20
- """
21
- A class for tracking visited TypeVars and TypeAliases.
22
- """
23
-
24
- def __init__(self, context: CodemodContext) -> None:
25
- super().__init__(context)
26
- # Track all of the TypeVar assignments found in this transform
27
- self.all_typealias_or_vars: List[Union[libcst.Assign, libcst.AnnAssign]] = []
28
-
29
- def visit_Assign(self, node: libcst.Assign) -> None:
30
- """
31
- Find all TypeVar assignments in the module.
32
- format: T = TypeVar("T", int, float, str, bytes, Tuple)
33
- """
34
- # is this a TypeVar assignment?
35
- # needs to be more robust
36
- if isinstance(node.value, libcst.Call) and node.value.func.value == "TypeVar": # type: ignore
37
- self.all_typealias_or_vars.append(node)
38
-
39
- def visit_AnnAssign(self, node: libcst.AnnAssign) -> None:
40
- """ "
41
- Find all TypeAlias assignments in the module.
42
- format: T: TypeAlias = str
43
- """
44
- if (
45
- isinstance(node.annotation.annotation, libcst.Name)
46
- and node.annotation.annotation.value == "TypeAlias"
47
- ):
48
- self.all_typealias_or_vars.append(node)
49
-
50
-
51
- def is_TypeAlias(statement) -> bool:
52
- "Just the body of a simple statement" ""
53
- return m.matches(
54
- statement,
55
- m.AnnAssign(annotation=m.Annotation(annotation=m.Name(value="TypeAlias"))),
56
- )
57
-
58
-
59
- def is_TypeVar(statement):
60
- "Just the body of a simple statement" ""
61
- r = m.matches(
62
- statement,
63
- m.Assign(value=m.Call(func=m.Name(value="TypeVar"))),
64
- )
65
- # m.SimpleStatementLine(body=[m.Assign(value=m.Call(func=m.Name(value="TypeVar")))]),
66
- return r
67
-
68
-
69
- class AddTypeVarsVisitor(ContextAwareTransformer):
70
- """
71
- Visitor loosly based on AddImportsVisitor
72
- """
73
-
74
- CONTEXT_KEY = "AddTypeVarsVisitor"
75
-
76
- def __init__(self, context: CodemodContext) -> None:
77
- super().__init__(context)
78
- self.new_typealias_or_vars: List[Union[libcst.Assign, libcst.AnnAssign]] = (
79
- context.scratch.get(self.CONTEXT_KEY, [])
80
- )
81
-
82
- self.all_typealias_or_vars: List[Union[libcst.Assign, libcst.AnnAssign]] = []
83
-
84
- @classmethod
85
- def add_typevar(cls, context: CodemodContext, node: libcst.Assign):
86
- new_typealias_or_vars = context.scratch.get(cls.CONTEXT_KEY, [])
87
- new_typealias_or_vars.append(node)
88
- context.scratch[cls.CONTEXT_KEY] = new_typealias_or_vars
89
- # add the typevar to the module
90
-
91
- def leave_Module(
92
- self,
93
- original_node: libcst.Module,
94
- updated_node: libcst.Module,
95
- ) -> libcst.Module:
96
-
97
- if not self.new_typealias_or_vars:
98
- return updated_node
99
-
100
- # split the module into 3 parts
101
- # before and after the insertions point , and a list of the TV and TA statements
102
- (
103
- statements_before,
104
- statements_after,
105
- tv_ta_statements,
106
- ) = self._split_module(original_node, updated_node)
107
-
108
- # TODO: avoid duplication of TypeVars and TypeAliases
109
- statements_to_add = []
110
- for new_tvta in self.new_typealias_or_vars:
111
- existing = False
112
- for existing_line in tv_ta_statements:
113
- try:
114
- existing_tv = existing_line.body[0] # type: ignore
115
- except (TypeError, IndexError):
116
- # catch 'SimpleStatementLine' object is not subscriptable when the statement is not a simple statement
117
- log.error("TypeVar or TypeAlias statement is not a simple statement")
118
- continue
119
-
120
- # same type and same target?
121
- if (
122
- is_TypeAlias(new_tvta)
123
- and is_TypeAlias(existing_tv)
124
- and new_tvta.target.value == existing_tv.target.value # type: ignore
125
- ):
126
- existing = True
127
- break
128
-
129
- # same type and same targets?
130
- if (
131
- is_TypeVar(new_tvta)
132
- and is_TypeVar(existing_tv)
133
- and new_tvta.targets[0].children[0].value == existing_tv.targets[0].children[0].value # type: ignore
134
- ):
135
- existing = True
136
- break
137
-
138
- if not existing:
139
- statements_to_add.append(SimpleStatementLine(body=[new_tvta]))
140
-
141
- body = (
142
- *statements_before,
143
- *statements_to_add,
144
- *statements_after,
145
- )
146
-
147
- return updated_node.with_changes(body=body)
148
-
149
- def _split_module(
150
- self,
151
- orig_module: libcst.Module,
152
- updated_module: libcst.Module,
153
- ) -> Tuple[
154
- List[Union[libcst.SimpleStatementLine, libcst.BaseCompoundStatement]],
155
- List[Union[libcst.SimpleStatementLine, libcst.BaseCompoundStatement]],
156
- List[Union[libcst.SimpleStatementLine, libcst.BaseCompoundStatement]],
157
- ]:
158
- """
159
- Split the module into 3 parts:
160
- - before any TypeAlias or TypeVar statements
161
- - the TypeAlias and TypeVar statements
162
- - the rest of the module after the TypeAlias and TypeVar statements
163
- """
164
- last_import = first_tv_ta = last_tv_ta = -1
165
- start = 0
166
- if _skip_first(orig_module):
167
- start = 1
168
-
169
- for i, statement in enumerate(orig_module.body[start:]):
170
- # todo: optimize to avoid multiple parses
171
- is_imp = m.matches(
172
- statement, m.SimpleStatementLine(body=[m.ImportFrom() | m.Import()])
173
- )
174
- if is_imp:
175
- last_import = i + start
176
- is_ta = m.matches(
177
- statement,
178
- m.SimpleStatementLine(
179
- body=[
180
- m.AnnAssign(annotation=m.Annotation(annotation=m.Name(value="TypeAlias")))
181
- ]
182
- ),
183
- )
184
- is_tv = m.matches(
185
- statement,
186
- m.SimpleStatementLine(body=[m.Assign(value=m.Call(func=m.Name(value="TypeVar")))]),
187
- )
188
- if is_tv or is_ta:
189
- if first_tv_ta == -1:
190
- first_tv_ta = i + start
191
- last_tv_ta = i + start
192
- # insert as high as possible, but after the last import and last TypeVar/TypeAlias statement
193
- insert_after = max(start, last_import + 1, last_tv_ta + 1)
194
- assert insert_after != -1, "insert_after must be != -1"
195
- #
196
- first = list(updated_module.body[:insert_after])
197
- last = list(updated_module.body[insert_after:])
198
- ta_statements = list(updated_module.body[first_tv_ta : last_tv_ta + 1])
199
-
200
- return (first, last, ta_statements)