Typhon-Language 0.1.2__py3-none-any.whl → 0.1.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- Typhon/Driver/configs.py +14 -0
- Typhon/Driver/debugging.py +148 -5
- Typhon/Driver/diagnostic.py +4 -3
- Typhon/Driver/language_server.py +25 -0
- Typhon/Driver/run.py +1 -1
- Typhon/Driver/translate.py +16 -11
- Typhon/Driver/utils.py +39 -1
- Typhon/Grammar/_typhon_parser.py +2920 -2718
- Typhon/Grammar/parser.py +80 -53
- Typhon/Grammar/parser_helper.py +68 -87
- Typhon/Grammar/syntax_errors.py +41 -20
- Typhon/Grammar/token_factory_custom.py +541 -485
- Typhon/Grammar/tokenizer_custom.py +52 -0
- Typhon/Grammar/typhon_ast.py +754 -76
- Typhon/Grammar/typhon_ast_error.py +438 -0
- Typhon/Grammar/unparse_custom.py +25 -0
- Typhon/LanguageServer/__init__.py +3 -0
- Typhon/LanguageServer/client/__init__.py +42 -0
- Typhon/LanguageServer/client/pyrefly.py +115 -0
- Typhon/LanguageServer/client/pyright.py +173 -0
- Typhon/LanguageServer/semantic_tokens.py +446 -0
- Typhon/LanguageServer/server.py +376 -0
- Typhon/LanguageServer/utils.py +65 -0
- Typhon/SourceMap/ast_match_based_map.py +199 -152
- Typhon/SourceMap/ast_matching.py +102 -87
- Typhon/SourceMap/datatype.py +275 -264
- Typhon/SourceMap/defined_name_retrieve.py +145 -0
- Typhon/Transform/comprehension_to_function.py +2 -5
- Typhon/Transform/const_member_to_final.py +12 -7
- Typhon/Transform/extended_patterns.py +139 -0
- Typhon/Transform/forbidden_statements.py +25 -0
- Typhon/Transform/if_while_let.py +122 -11
- Typhon/Transform/inline_statement_block_capture.py +22 -15
- Typhon/Transform/optional_operators_to_checked.py +14 -6
- Typhon/Transform/placeholder_to_function.py +0 -1
- Typhon/Transform/record_to_dataclass.py +22 -238
- Typhon/Transform/scope_check_rename.py +109 -29
- Typhon/Transform/transform.py +16 -12
- Typhon/Transform/type_abbrev_desugar.py +11 -15
- Typhon/Transform/type_annotation_check_expand.py +2 -2
- Typhon/Transform/utils/__init__.py +0 -0
- Typhon/Transform/utils/imports.py +83 -0
- Typhon/Transform/{utils.py → utils/jump_away.py} +2 -38
- Typhon/Transform/utils/make_class.py +135 -0
- Typhon/Transform/visitor.py +25 -0
- Typhon/Typing/pyrefly.py +145 -0
- Typhon/Typing/pyright.py +141 -144
- Typhon/Typing/result_diagnostic.py +1 -1
- Typhon/__main__.py +15 -1
- {typhon_language-0.1.2.dist-info → typhon_language-0.1.4.dist-info}/METADATA +13 -6
- typhon_language-0.1.4.dist-info/RECORD +65 -0
- {typhon_language-0.1.2.dist-info → typhon_language-0.1.4.dist-info}/WHEEL +1 -1
- typhon_language-0.1.4.dist-info/licenses/LICENSE +201 -0
- typhon_language-0.1.2.dist-info/RECORD +0 -48
- typhon_language-0.1.2.dist-info/licenses/LICENSE +0 -21
- {typhon_language-0.1.2.dist-info → typhon_language-0.1.4.dist-info}/entry_points.txt +0 -0
- {typhon_language-0.1.2.dist-info → typhon_language-0.1.4.dist-info}/top_level.txt +0 -0
|
@@ -1,152 +1,199 @@
|
|
|
1
|
-
import ast
|
|
2
|
-
from .datatype import Range, Pos, RangeIntervalTree
|
|
3
|
-
from ..Grammar.typhon_ast import get_pos_attributes_if_exists
|
|
4
|
-
from ..Driver.debugging import debug_print, debug_verbose_print
|
|
5
|
-
from ..SourceMap.ast_matching import match_ast
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
self.
|
|
19
|
-
self.
|
|
20
|
-
self.
|
|
21
|
-
self.
|
|
22
|
-
self.
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
Range.from_pos_attr_may_not_end(origin_pos)
|
|
33
|
-
)
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
Range.from_pos_attr_may_not_end(unparsed_pos)
|
|
41
|
-
)
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
)
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
)
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
self
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
1
|
+
import ast
|
|
2
|
+
from .datatype import Range, Pos, RangeIntervalTree, RangeInterval
|
|
3
|
+
from ..Grammar.typhon_ast import get_pos_attributes_if_exists
|
|
4
|
+
from ..Driver.debugging import debug_print, debug_verbose_print
|
|
5
|
+
from ..SourceMap.ast_matching import match_ast
|
|
6
|
+
from .defined_name_retrieve import defined_name_retrieve
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class MatchBasedSourceMap:
    """Bidirectional source map between an original AST and its unparsed AST.

    The map is built from a node-to-node matching (``origin_to_unparsed`` and
    its reverse ``unparsed_to_origin``). Every matched node that has position
    attributes is indexed by its source range in an interval tree for its
    side. A range on one side can then be resolved to enclosing node(s) and
    translated to the corresponding range on the other side.
    """

    def __init__(
        self,
        origin_to_unparsed: dict[ast.AST, ast.AST],
        unparsed_to_origin: dict[ast.AST, ast.AST],
        source_code: str,
        source_file: str,
        unparsed_code: str,
    ):
        """Store the node mappings and build both interval trees.

        Args:
            origin_to_unparsed: Original-AST node -> matched unparsed-AST node.
            unparsed_to_origin: Unparsed-AST node -> matched original-AST node.
            source_code: Text of the original source; sliced by
                ``unparsed_range_to_source_code``.
            source_file: Path of the original source file (stored only; not
                read by this class).
            unparsed_code: Text of the unparsed code (stored only; not read by
                this class).
        """
        self.origin_to_unparsed = origin_to_unparsed
        self.unparsed_to_origin = unparsed_to_origin
        self.origin_interval_tree = RangeIntervalTree[ast.AST]()
        self.unparsed_interval_tree = RangeIntervalTree[ast.AST]()
        self.source_code = source_code
        self.source_file = source_file
        self.unparsed_code: str = unparsed_code
        self._setup_interval_trees()

    def _setup_interval_trees(self):
        """Index every matched node that has position attributes by its range.

        Only ``origin_to_unparsed`` is iterated. Its keys populate the origin
        tree and its values populate the unparsed tree.
        """
        # NOTE(review): these f-strings (including ast.dump of each node's
        # whole subtree) are built on every iteration even when verbose
        # debugging is off; consider guarding them if the debugging module
        # exposes a flag.
        for origin_node, unparsed_node in self.origin_to_unparsed.items():
            origin_pos = get_pos_attributes_if_exists(origin_node)
            if origin_pos is not None:
                # Build the range once and reuse it for logging and indexing.
                origin_range = Range.from_pos_attr_may_not_end(origin_pos)
                debug_verbose_print(
                    f"Adding to origin interval tree:\n range={origin_range}\n {ast.dump(origin_node)}\n pos: {origin_pos}"
                )
                self.origin_interval_tree.add(origin_range, origin_node)
            unparsed_pos = get_pos_attributes_if_exists(unparsed_node)
            if unparsed_pos is not None:
                unparsed_range = Range.from_pos_attr_may_not_end(unparsed_pos)
                debug_verbose_print(
                    f" Adding to unparsed interval tree:\n range={unparsed_range}\n {ast.dump(unparsed_node)}\n pos: {unparsed_pos}"
                )
                self.unparsed_interval_tree.add(unparsed_range, unparsed_node)

    # Assume range in base_node, apply the offset of the range in base_node to result_node
    # TODO: This is valid range conversion only for very simple cases.
    def _apply_offset_range(
        self,
        range: Range,
        base_node: ast.AST,
        result_node: ast.AST,
    ) -> Range | None:
        """Translate ``range`` (assumed to lie inside ``base_node``) onto ``result_node``.

        Computes ``base_range.calc_offset(range)`` and applies that offset via
        ``apply_offset`` on the start position of ``result_node``'s range.

        Returns:
            The translated range, or None if either node has no position
            attributes.
        """
        base_pos_attr = get_pos_attributes_if_exists(base_node)
        if base_pos_attr is None:
            return None
        base_range = Range.from_pos_attr_may_not_end(base_pos_attr)
        result_pos_attr = get_pos_attributes_if_exists(result_node)
        if result_pos_attr is None:
            return None
        result_range = Range.from_pos_attr_may_not_end(result_pos_attr)
        range_offset = base_range.calc_offset(range)
        # Computed once and reused for both the log line and the return value.
        translated = result_range.start.apply_offset(range_offset)
        debug_verbose_print(
            f"Offsetting range:\n base_range={base_range}\n range={range}\n offset={range_offset}\n result_range={result_range}\n apply_offset:{translated}"
        )
        return translated

    def _range_to(
        self,
        range: Range,
        interval_tree: RangeIntervalTree[ast.AST],
        mapping: dict[ast.AST, ast.AST],
    ) -> Range | None:
        """Map ``range`` to the other side using ``interval_tree`` and ``mapping``.

        Candidates come from ``interval_tree.minimal_containers(range)``.

        * No candidate: returns None.
        * Exactly one candidate: if it is mapped, returns the mapped node's
          range when ``range`` equals the candidate's range. Otherwise returns
          an offset-translated range. Returns None if the candidate is
          unmapped.
        * Several candidates: returns the merge of the ranges of all mapped
          counterparts that have position attributes, or None if there are
          none.
        """
        nodes = interval_tree.minimal_containers(range)
        debug_verbose_print(f"Mapping range: {range} nodes: {nodes}")
        if not nodes:
            debug_verbose_print("No nodes found for the given range.")
            return None

        if len(nodes) == 1:
            # The canonical node for the range
            node_range, node = nodes[0]
            if node not in mapping:
                debug_verbose_print(
                    "Node found for the given range, but it has no mapping."
                )
                return None
            mapped_node = mapping[node]
            debug_verbose_print(
                f"node is one: {ast.dump(node)}\n mapping to {ast.dump(mapped_node)}\n range: {range}\n node_range: {node_range}"
            )
            if node_range == range:
                # If the range matches exactly, no need to apply offset
                return Range.from_ast_node(mapped_node)
            # Use offset mapping to precisely map the range inside the node
            return self._apply_offset_range(range, node, mapped_node)

        # When multiple nodes are found, merge the ranges of all mapped nodes
        debug_verbose_print(f"Found {len(nodes)} nodes for the given range.")
        ranges_to_merge = [
            Range.from_pos_attr_may_not_end(pos_attr)
            for _, node in nodes
            if node in mapping
            and (pos_attr := get_pos_attributes_if_exists(mapping[node]))
        ]
        if not ranges_to_merge:
            # Nothing positioned to merge; do not hand an empty input to merge_ranges.
            debug_verbose_print(
                "No mapped node with a position found for the given range."
            )
            return None
        return Range.merge_ranges(ranges_to_merge)

    def unparsed_range_to_origin(
        self,
        range_unparsed: Range,
    ) -> Range | None:
        """Map a range in the unparsed code to the corresponding original range.

        Returns None when no mapping is found.
        """
        return self._range_to(
            range_unparsed,
            self.unparsed_interval_tree,
            self.unparsed_to_origin,
        )

    def unparsed_range_to_source_code(
        self,
        range_unparsed: Range,
    ) -> str | None:
        """Return the original source text for a range in the unparsed code.

        The range is mapped with ``unparsed_range_to_origin`` and then sliced
        out of ``self.source_code`` with ``Range.of_string``. Returns None when
        the range cannot be mapped.
        """
        range_in_origin = self.unparsed_range_to_origin(range_unparsed)
        if range_in_origin is None:
            return None
        return range_in_origin.of_string(self.source_code)

    def unparsed_range_to_origin_node(
        self,
        range_unparsed: Range,
        filter_node_type: type[ast.AST] | None = None,
    ) -> ast.AST | None:
        """Return the original node matched to the unparsed node containing a range.

        Args:
            range_unparsed: Range in the unparsed code.
            filter_node_type: If given, only candidates whose matched original
                node is an instance of this type are considered.

        Returns:
            The matched original node when exactly one candidate remains;
            None when there are zero or several candidates.
        """
        nodes: list[RangeInterval[ast.AST]] = (
            self.unparsed_interval_tree.minimal_containers(range_unparsed)
        )
        debug_verbose_print(
            f"Mapping unparsed range to origin node: {range_unparsed} nodes: {nodes}"
        )
        if filter_node_type is not None:
            nodes = [
                (r, n)
                for r, n in nodes
                if isinstance(self.unparsed_to_origin.get(n, None), filter_node_type)
            ]
        if not nodes:
            debug_verbose_print("No nodes found for the given unparsed range.")
            return None
        if len(nodes) > 1:
            # Ambiguous: refuse to pick one rather than guess.
            debug_verbose_print(
                f"Found {len(nodes)} nodes for the given unparsed range; expected exactly one."
            )
            return None
        _, node = nodes[0]
        return self.unparsed_to_origin.get(node, None)

    def origin_range_to_unparsed(
        self,
        range_origin: Range,
    ) -> Range | None:
        """Map a range in the original source to the corresponding unparsed range.

        Returns None when no mapping is found.
        """
        debug_verbose_print(f"Mapping origin range: {range_origin}")
        return self._range_to(
            range_origin,
            self.origin_interval_tree,
            self.origin_to_unparsed,
        )

    def origin_node_to_unparsed_range(
        self,
        origin_node: ast.AST,
    ) -> Range | None:
        """Map an original-AST node's range into the unparsed code.

        Returns None when the node has no position attributes or the range
        cannot be mapped.
        """
        origin_pos = get_pos_attributes_if_exists(origin_node)
        if origin_pos is None:
            return None
        return self.origin_range_to_unparsed(
            Range.from_pos_attr_may_not_end(origin_pos)
        )
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
def map_from_translated_ast(
    origin_ast: ast.AST,
    unparsed_ast: ast.AST,
    source_code: str,
    source_file_path: str,
    unparsed_code: str,
) -> MatchBasedSourceMap | None:
    """Build a ``MatchBasedSourceMap`` by matching two ASTs node by node.

    ``defined_name_retrieve`` is first run on the unparsed tree together with
    its source text. Presumably this prepares defined-name information that
    the matcher compares (TODO confirm). The trees are then paired with
    ``match_ast``.

    Args:
        origin_ast: AST of the original source.
        unparsed_ast: AST of the unparsed/translated code.
        source_code: Text of the original source.
        source_file_path: Path of the original source file.
        unparsed_code: Text of the unparsed/translated code.

    Returns:
        The source map, or None if ``match_ast`` returns None.

    NOTE(review): the visible ``match_ast`` raises on a structural mismatch
    instead of returning None, so the None path is currently not produced.
    """
    defined_name_retrieve(unparsed_ast, unparsed_code)
    matched = match_ast(origin_ast, unparsed_ast)
    return (
        None
        if matched is None
        else MatchBasedSourceMap(
            origin_to_unparsed=matched.left_to_right,
            unparsed_to_origin=matched.right_to_left,
            source_code=source_code,
            source_file=source_file_path,
            unparsed_code=unparsed_code,
        )
    )
|
|
185
|
+
|
|
186
|
+
|
|
187
|
+
def map_from_translated(
    origin_ast: ast.AST,
    source_code: str,
    source_file_path: str,
    translated_code: str,
) -> MatchBasedSourceMap | None:
    """Parse ``translated_code`` and build a source map against ``origin_ast``.

    The translated text is parsed with ``ast.parse``, which raises
    ``SyntaxError`` if it is not valid Python. The resulting tree is passed
    together with the text to ``map_from_translated_ast``.
    """
    translated_ast = ast.parse(translated_code)
    return map_from_translated_ast(
        origin_ast=origin_ast,
        unparsed_ast=translated_ast,
        source_code=source_code,
        source_file_path=source_file_path,
        unparsed_code=translated_code,
    )
|
Typhon/SourceMap/ast_matching.py
CHANGED
|
@@ -1,87 +1,102 @@
|
|
|
1
|
-
from dataclasses import dataclass
|
|
2
|
-
import ast
|
|
3
|
-
from contextlib import contextmanager
|
|
4
|
-
from typing import Any, cast
|
|
5
|
-
from ..Driver.debugging import debug_verbose_print
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
self.
|
|
18
|
-
self.
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
self.
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
)
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
import ast
|
|
3
|
+
from contextlib import contextmanager
|
|
4
|
+
from typing import Any, cast
|
|
5
|
+
from ..Driver.debugging import debug_verbose_print
|
|
6
|
+
from ..Grammar.typhon_ast import DefinesName, get_defined_name, get_import_from_names
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
# Match the AST node to right module recursively
class MatchingVisitor(ast.NodeVisitor):
    """Walk a left AST in lockstep with a right AST and record node pairs.

    ``self.right`` holds the right-side node expected to correspond to the
    left node currently being visited. Each successful pair is written into
    both ``left_to_right`` and ``right_to_left``, which are mutated in place.
    The visible structural checks raise ``ValueError`` on disagreement.
    """

    def __init__(
        self,
        right: ast.AST,
        left_to_right: dict[ast.AST, ast.AST],
        right_to_left: dict[ast.AST, ast.AST],
    ):
        # The mapping dicts are filled in place; callers keep references to them.
        self.left_to_right = left_to_right
        self.right_to_left = right_to_left
        self.right = right

    @contextmanager
    def _with_right(self, right: ast.AST):
        """Temporarily set the right-side node expected by ``visit``."""
        old_right = self.right
        self.right = right
        try:
            yield
        finally:
            # Restore even when the nested visit raises, so the visitor is
            # never left pointing at a stale child node.
            self.right = old_right

    @staticmethod
    def _describe(value: object) -> str:
        """Render a value for error messages.

        ``ast.dump`` raises ``TypeError`` for non-AST input. The right-hand
        value may be ``None`` (or a list) when a field is absent or shaped
        differently on the right side.
        """
        return ast.dump(value) if isinstance(value, ast.AST) else repr(value)

    def _commit(self, left: ast.AST, right: ast.AST):
        """Record ``left`` <-> ``right`` in both mapping dicts."""
        debug_verbose_print(f"Matched: {ast.dump(left)} <-> {ast.dump(right)}")
        self.left_to_right[left] = right
        self.right_to_left[right] = left

    def _visit_list(self, lefts: list[Any], rights: list[Any]):
        """Match two lists element-wise.

        AST elements are visited recursively. Any pair where at least one side
        is not an AST must compare equal.

        Raises:
            ValueError: On length mismatch or unequal non-AST elements.
        """
        if len(lefts) != len(rights):
            # TODO: Error rescue: List length mismatch
            raise ValueError(f"List length mismatch: {len(lefts)} vs {len(rights)}")
        for left, right in zip(lefts, rights):
            if not isinstance(left, ast.AST) or not isinstance(right, ast.AST):
                if left != right:
                    # TODO: Error rescue: List value mismatch
                    raise ValueError(f"List value mismatch: {left} vs {right}")
                continue
            with self._with_right(right):
                self.visit(left)

    def visit(self, node: ast.AST):
        """Match ``node`` against ``self.right`` and recurse into its fields.

        The pair is committed before recursing. Defined names (for
        ``DefinesName`` nodes) and ``ImportFrom`` module names are matched
        additionally when both sides provide them.

        Raises:
            ValueError: On node-type, list, or scalar-field mismatch.
        """
        right = self.right
        if type(node) is not type(right):
            # TODO: Error rescue: Type mismatch
            raise ValueError(
                f"Type mismatch: {self._describe(node)} vs {self._describe(right)}"
            )
        # Commit the match
        self._commit(node, right)
        # Check defined name
        if isinstance(node, DefinesName):
            left_name = get_defined_name(node)
            right_name = get_defined_name(cast(DefinesName, right))
            if left_name is not None and right_name is not None:
                with self._with_right(right_name):
                    self.visit(left_name)
            # Allow defined name not matching
        # Check import from module names
        if isinstance(node, ast.ImportFrom):
            modules = get_import_from_names(node)
            right_modules = get_import_from_names(cast(ast.ImportFrom, right))
            if modules and right_modules:
                self._visit_list(modules, right_modules)
        # Recursively visit fields
        for field, value in ast.iter_fields(node):
            right_value = getattr(right, field, None)
            if isinstance(value, list):
                if isinstance(right_value, list):
                    self._visit_list(
                        cast(list[Any], value), cast(list[Any], right_value)
                    )
                else:
                    # TODO: Error rescue: List length mismatch
                    raise ValueError(
                        f"List mismatch in field {field}: {self._describe(node)} vs {self._describe(right)}"
                    )
            elif isinstance(value, ast.AST):
                # right_value may be None/non-AST here; visit() reports it as
                # a type mismatch.
                with self._with_right(cast(ast.AST, right_value)):
                    self.visit(value)
            else:
                if value != right_value:
                    # TODO: Error rescue: Value mismatch
                    raise ValueError(
                        f"Value mismatch in field {field}: {value} vs {right_value}"
                    )
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
@dataclass
class MatchResult:
    """Node correspondence produced by ``match_ast``, kept in both directions."""

    # Node of the left tree -> matched node of the right tree.
    left_to_right: dict[ast.AST, ast.AST]
    # Node of the right tree -> matched node of the left tree.
    right_to_left: dict[ast.AST, ast.AST]
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def match_ast[T: ast.AST](left: T, right: T) -> MatchResult | None:
|
|
98
|
+
left_to_right: dict[ast.AST, ast.AST] = {}
|
|
99
|
+
right_to_left: dict[ast.AST, ast.AST] = {}
|
|
100
|
+
visitor = MatchingVisitor(right, left_to_right, right_to_left)
|
|
101
|
+
visitor.visit(left)
|
|
102
|
+
return MatchResult(left_to_right, right_to_left)
|