yuho 5.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- yuho/__init__.py +16 -0
- yuho/ast/__init__.py +196 -0
- yuho/ast/builder.py +926 -0
- yuho/ast/constant_folder.py +280 -0
- yuho/ast/dead_code.py +199 -0
- yuho/ast/exhaustiveness.py +503 -0
- yuho/ast/nodes.py +907 -0
- yuho/ast/overlap.py +291 -0
- yuho/ast/reachability.py +293 -0
- yuho/ast/scope_analysis.py +490 -0
- yuho/ast/transformer.py +490 -0
- yuho/ast/type_check.py +471 -0
- yuho/ast/type_inference.py +425 -0
- yuho/ast/visitor.py +239 -0
- yuho/cli/__init__.py +14 -0
- yuho/cli/commands/__init__.py +1 -0
- yuho/cli/commands/api.py +431 -0
- yuho/cli/commands/ast_viz.py +334 -0
- yuho/cli/commands/check.py +218 -0
- yuho/cli/commands/config.py +311 -0
- yuho/cli/commands/contribute.py +122 -0
- yuho/cli/commands/diff.py +487 -0
- yuho/cli/commands/explain.py +240 -0
- yuho/cli/commands/fmt.py +253 -0
- yuho/cli/commands/generate.py +316 -0
- yuho/cli/commands/graph.py +410 -0
- yuho/cli/commands/init.py +120 -0
- yuho/cli/commands/library.py +656 -0
- yuho/cli/commands/lint.py +503 -0
- yuho/cli/commands/lsp.py +36 -0
- yuho/cli/commands/preview.py +377 -0
- yuho/cli/commands/repl.py +444 -0
- yuho/cli/commands/serve.py +44 -0
- yuho/cli/commands/test.py +528 -0
- yuho/cli/commands/transpile.py +121 -0
- yuho/cli/commands/wizard.py +370 -0
- yuho/cli/completions.py +182 -0
- yuho/cli/error_formatter.py +193 -0
- yuho/cli/main.py +1064 -0
- yuho/config/__init__.py +46 -0
- yuho/config/loader.py +235 -0
- yuho/config/mask.py +194 -0
- yuho/config/schema.py +147 -0
- yuho/library/__init__.py +84 -0
- yuho/library/index.py +328 -0
- yuho/library/install.py +699 -0
- yuho/library/lockfile.py +330 -0
- yuho/library/package.py +421 -0
- yuho/library/resolver.py +791 -0
- yuho/library/signature.py +335 -0
- yuho/llm/__init__.py +45 -0
- yuho/llm/config.py +75 -0
- yuho/llm/factory.py +123 -0
- yuho/llm/prompts.py +146 -0
- yuho/llm/providers.py +383 -0
- yuho/llm/utils.py +470 -0
- yuho/lsp/__init__.py +14 -0
- yuho/lsp/code_action_handler.py +518 -0
- yuho/lsp/completion_handler.py +85 -0
- yuho/lsp/diagnostics.py +100 -0
- yuho/lsp/hover_handler.py +130 -0
- yuho/lsp/server.py +1425 -0
- yuho/mcp/__init__.py +10 -0
- yuho/mcp/server.py +1452 -0
- yuho/parser/__init__.py +8 -0
- yuho/parser/source_location.py +108 -0
- yuho/parser/wrapper.py +311 -0
- yuho/testing/__init__.py +48 -0
- yuho/testing/coverage.py +274 -0
- yuho/testing/fixtures.py +263 -0
- yuho/transpile/__init__.py +52 -0
- yuho/transpile/alloy_transpiler.py +546 -0
- yuho/transpile/base.py +100 -0
- yuho/transpile/blocks_transpiler.py +338 -0
- yuho/transpile/english_transpiler.py +470 -0
- yuho/transpile/graphql_transpiler.py +404 -0
- yuho/transpile/json_transpiler.py +217 -0
- yuho/transpile/jsonld_transpiler.py +250 -0
- yuho/transpile/latex_preamble.py +161 -0
- yuho/transpile/latex_transpiler.py +406 -0
- yuho/transpile/latex_utils.py +206 -0
- yuho/transpile/mermaid_transpiler.py +357 -0
- yuho/transpile/registry.py +275 -0
- yuho/verify/__init__.py +43 -0
- yuho/verify/alloy.py +352 -0
- yuho/verify/combined.py +218 -0
- yuho/verify/z3_solver.py +1155 -0
- yuho-5.0.0.dist-info/METADATA +186 -0
- yuho-5.0.0.dist-info/RECORD +91 -0
- yuho-5.0.0.dist-info/WHEEL +4 -0
- yuho-5.0.0.dist-info/entry_points.txt +2 -0
yuho/ast/type_check.py
ADDED
|
@@ -0,0 +1,471 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Type checking visitor for Yuho AST.
|
|
3
|
+
|
|
4
|
+
Validates type consistency and reports type errors:
|
|
5
|
+
- Assignment type compatibility
|
|
6
|
+
- Binary operator type compatibility
|
|
7
|
+
- Function argument type matching
|
|
8
|
+
- Return type matching
|
|
9
|
+
- Match arm type consistency
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
from typing import Any, Dict, List, Optional
|
|
13
|
+
from dataclasses import dataclass, field
|
|
14
|
+
|
|
15
|
+
from yuho.ast import nodes
|
|
16
|
+
from yuho.ast.visitor import Visitor
|
|
17
|
+
from yuho.ast.type_inference import (
|
|
18
|
+
TypeAnnotation,
|
|
19
|
+
TypeInferenceVisitor,
|
|
20
|
+
TypeInferenceResult,
|
|
21
|
+
INT_TYPE,
|
|
22
|
+
FLOAT_TYPE,
|
|
23
|
+
BOOL_TYPE,
|
|
24
|
+
STRING_TYPE,
|
|
25
|
+
MONEY_TYPE,
|
|
26
|
+
PERCENT_TYPE,
|
|
27
|
+
DATE_TYPE,
|
|
28
|
+
DURATION_TYPE,
|
|
29
|
+
VOID_TYPE,
|
|
30
|
+
PASS_TYPE,
|
|
31
|
+
UNKNOWN_TYPE,
|
|
32
|
+
)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
@dataclass
class TypeErrorInfo:
    """A single type-checking diagnostic together with where it was found.

    A ``line`` of 0 means "no source position known"; ``__str__`` then omits
    the ``line:column`` part.
    """

    message: str  # human-readable description of the problem
    line: int = 0  # source line of the offending node; 0 when unknown
    column: int = 0  # source column of the offending node
    node_type: str = ""  # class name of the offending AST node, if any
    severity: str = "error"  # "error" or "warning"

    def __str__(self) -> str:
        """Render as ``[severity] line:column message``."""
        if not self.line:
            position = ""
        else:
            position = "{}:{}".format(self.line, self.column)
        return "[{}] {} {}".format(self.severity, position, self.message)
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
@dataclass
class TypeCheckResult:
    """Result of type checking including all errors and warnings.

    Diagnostics recorded with severity ``"error"`` are stored in ``errors``;
    diagnostics with any other severity are stored in ``warnings``.
    """

    errors: List[TypeErrorInfo] = field(default_factory=list)
    warnings: List[TypeErrorInfo] = field(default_factory=list)

    @property
    def has_errors(self) -> bool:
        """True if at least one error (not warning) has been recorded."""
        return bool(self.errors)

    @property
    def is_valid(self) -> bool:
        """True when no errors were recorded; warnings do not affect validity."""
        return not self.has_errors

    def add_error(
        self,
        message: str,
        node: Optional[nodes.ASTNode] = None,
        severity: str = "error",
    ) -> None:
        """Add a type error (or warning).

        Args:
            message: Human-readable description of the problem.
            node: AST node the diagnostic refers to. Its class name becomes
                ``node_type``; its ``source_location`` (when set) supplies the
                line and column. When ``node`` is None, line/column stay 0 and
                ``node_type`` stays empty.
            severity: ``"error"`` appends to ``errors``; any other value
                appends to ``warnings``.
        """
        line = 0
        column = 0
        node_type = ""

        # Compare against None explicitly: an AST node's truthiness is not
        # the question here, only whether one was supplied.
        if node is not None:
            node_type = type(node).__name__
            location = node.source_location
            if location:
                line = location.start_line
                column = location.start_column

        diagnostic = TypeErrorInfo(
            message=message,
            line=line,
            column=column,
            node_type=node_type,
            severity=severity,
        )

        target = self.errors if severity == "error" else self.warnings
        target.append(diagnostic)
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
class TypeCheckVisitor(Visitor):
    """
    Visitor that validates type consistency and reports errors.

    Runs after TypeInferenceVisitor to validate inferred types are consistent.

    Traversal: each ``visit_*`` method below visits the children it needs
    exactly once and then performs its own check. None of them delegates to
    ``generic_visit`` afterwards. NOTE(review): this assumes
    ``Visitor.generic_visit`` walks all children (as ``ast.NodeVisitor``
    does); calling it after the manual visits would re-walk every subtree and
    duplicate nested diagnostics.

    Unknown types: operands whose inferred type is UNKNOWN_TYPE (inference
    failed) are never reported by the operator/guard checks. This matches the
    policy in ``_types_compatible`` and avoids cascading errors from a single
    inference failure.

    Usage:
        # First run type inference
        infer_visitor = TypeInferenceVisitor()
        module.accept(infer_visitor)

        # Then run type checking
        check_visitor = TypeCheckVisitor(infer_visitor.result)
        module.accept(check_visitor)
        result = check_visitor.result

        if result.has_errors:
            for error in result.errors:
                print(error)
    """

    # Types that are compatible with each other for numeric operations
    NUMERIC_TYPES = frozenset({"int", "float", "money", "percent"})

    # Types that can be compared with equality
    COMPARABLE_TYPES = frozenset({
        "int", "float", "bool", "string", "money", "percent", "date", "duration"
    })

    # Types that can be ordered (< > <= >=)
    ORDERABLE_TYPES = frozenset({
        "int", "float", "money", "percent", "date", "duration"
    })

    # Operator spellings recognised by _check_binary_types.
    LOGICAL_OPERATORS = ("&&", "||", "and", "or")
    EQUALITY_OPERATORS = ("==", "!=")
    ORDERING_OPERATORS = ("<", ">", "<=", ">=")
    ARITHMETIC_OPERATORS = ("+", "-", "*", "/", "%")

    def __init__(self, type_info: TypeInferenceResult) -> None:
        """
        Args:
            type_info: Result of a completed type-inference pass. It is
                queried for per-node types (``get_type``), variable types,
                function signatures and struct definitions.
        """
        self.type_info = type_info
        self.result = TypeCheckResult()
        # Expected return type of the function whose body is being checked,
        # or None when not inside a function definition.
        self._current_function_return: Optional[TypeAnnotation] = None

    # =========================================================================
    # Helpers
    # =========================================================================

    def _get_type(self, node: nodes.ASTNode) -> TypeAnnotation:
        """Get the inferred type for a node."""
        return self.type_info.get_type(node)

    @staticmethod
    def _is_unknown(type_: TypeAnnotation) -> bool:
        """Return True if inference could not determine ``type_``."""
        return type_ == UNKNOWN_TYPE

    def _types_compatible(
        self,
        expected: TypeAnnotation,
        actual: TypeAnnotation,
        allow_coercion: bool = True,
    ) -> bool:
        """Check whether a value of type ``actual`` may be used where ``expected`` is required.

        Returns True when:
        * either type is unknown (inference failed),
        * ``actual`` is the ``pass`` placeholder,
        * the base type names match (optionality is not compared), or
        * ``allow_coercion`` is set and this is int -> float widening.
        """
        if self._is_unknown(expected) or self._is_unknown(actual):
            return True

        if actual == PASS_TYPE:
            return True

        if expected.type_name == actual.type_name:
            return True

        # Numeric widening int -> float. Only base names are compared, so
        # this also applies to an optional float target (``float?``).
        return (
            allow_coercion
            and expected.type_name == "float"
            and actual.type_name == "int"
        )

    def _mutually_compatible(
        self,
        left: TypeAnnotation,
        right: TypeAnnotation,
    ) -> bool:
        """Return True if two operand types may be compared with each other.

        Comparison is symmetric, so compatibility in either direction is
        accepted (``1 == 1.5`` and ``1.5 == 1`` are treated alike).
        """
        return self._types_compatible(left, right) or self._types_compatible(right, left)

    def _check_binary_types(
        self,
        node: nodes.BinaryExprNode,
        left_type: TypeAnnotation,
        right_type: TypeAnnotation,
    ) -> None:
        """Check type compatibility for binary operators.

        Errors are recorded on ``self.result``. Operands of unknown type are
        never reported (see class docstring).
        """
        op = node.operator
        left_unknown = self._is_unknown(left_type)
        right_unknown = self._is_unknown(right_type)

        # Logical operators require bool
        if op in self.LOGICAL_OPERATORS:
            if not left_unknown and left_type.type_name != "bool":
                self.result.add_error(
                    f"Left operand of '{op}' must be bool, got {left_type}",
                    node,
                )
            if not right_unknown and right_type.type_name != "bool":
                self.result.add_error(
                    f"Right operand of '{op}' must be bool, got {right_type}",
                    node,
                )

        # Comparison operators
        elif op in self.EQUALITY_OPERATORS:
            if not self._mutually_compatible(left_type, right_type):
                self.result.add_error(
                    f"Cannot compare {left_type} with {right_type}",
                    node,
                )

        # Ordering operators
        elif op in self.ORDERING_OPERATORS:
            if not left_unknown and left_type.type_name not in self.ORDERABLE_TYPES:
                self.result.add_error(
                    f"Type {left_type} is not orderable",
                    node,
                )
            if not right_unknown and right_type.type_name not in self.ORDERABLE_TYPES:
                self.result.add_error(
                    f"Type {right_type} is not orderable",
                    node,
                )
            # Same compatibility rule as equality, so int/float mix works
            # for ordering exactly as it does for ==.
            if not self._mutually_compatible(left_type, right_type):
                self.result.add_error(
                    f"Cannot compare {left_type} with {right_type}",
                    node,
                )

        # Arithmetic operators
        elif op in self.ARITHMETIC_OPERATORS:
            if left_unknown or right_unknown:
                return

            # String concatenation with +
            if op == "+" and left_type.type_name == "string":
                if right_type.type_name != "string":
                    self.result.add_error(
                        f"Cannot concatenate string with {right_type}",
                        node,
                    )
            # Duration arithmetic: only + and - are meaningful.
            elif left_type.type_name == "duration" or right_type.type_name == "duration":
                if op not in ("+", "-"):
                    self.result.add_error(
                        f"Invalid operation '{op}' on duration",
                        node,
                    )
            # Numeric arithmetic
            elif left_type.type_name not in self.NUMERIC_TYPES:
                self.result.add_error(
                    f"Type {left_type} does not support arithmetic",
                    node,
                )
            elif right_type.type_name not in self.NUMERIC_TYPES:
                self.result.add_error(
                    f"Type {right_type} does not support arithmetic",
                    node,
                )

    # =========================================================================
    # Expression nodes
    # =========================================================================

    def visit_binary_expr(self, node: nodes.BinaryExprNode) -> Any:
        """Check both operands, then the operator's type compatibility."""
        self.visit(node.left)
        self.visit(node.right)

        left_type = self._get_type(node.left)
        right_type = self._get_type(node.right)

        self._check_binary_types(node, left_type, right_type)
        return None

    def visit_unary_expr(self, node: nodes.UnaryExprNode) -> Any:
        """Check the operand, then that the unary operator applies to its type."""
        self.visit(node.operand)
        operand_type = self._get_type(node.operand)

        if self._is_unknown(operand_type):
            return None

        op = node.operator
        if op in ("!", "not"):
            if operand_type.type_name != "bool":
                self.result.add_error(
                    f"Operand of '{op}' must be bool, got {operand_type}",
                    node,
                )
        elif op == "-":
            if operand_type.type_name not in self.NUMERIC_TYPES:
                self.result.add_error(
                    f"Cannot negate non-numeric type {operand_type}",
                    node,
                )

        return None

    # =========================================================================
    # Variable and assignment
    # =========================================================================

    def visit_variable_decl(self, node: nodes.VariableDecl) -> Any:
        """Check variable declaration type matches initializer, then visit the initializer."""
        if node.type_annotation and node.value:
            declared_type = self.type_info.variable_types.get(node.name, UNKNOWN_TYPE)
            value_type = self._get_type(node.value)

            if not self._types_compatible(declared_type, value_type):
                self.result.add_error(
                    f"Cannot assign {value_type} to variable of type {declared_type}",
                    node,
                )

        if node.value:
            self.visit(node.value)

        return None

    def visit_assignment_stmt(self, node: nodes.AssignmentStmt) -> Any:
        """Check target and value, then that the value fits the target's type."""
        self.visit(node.target)
        self.visit(node.value)

        target_type = self._get_type(node.target)
        value_type = self._get_type(node.value)

        if not self._types_compatible(target_type, value_type):
            self.result.add_error(
                f"Cannot assign {value_type} to {target_type}",
                node,
            )

        return None

    # =========================================================================
    # Function calls and returns
    # =========================================================================

    def visit_function_call(self, node: nodes.FunctionCallNode) -> Any:
        """Check argument count and argument types against the callee's signature.

        Calls to functions without a known signature only have their
        arguments visited.
        """
        callee = node.callee
        func_name = callee if isinstance(callee, str) else getattr(callee, "name", "")

        # The callee may itself be an expression node worth checking.
        if isinstance(callee, nodes.ASTNode):
            self.visit(callee)

        if func_name not in self.type_info.function_sigs:
            for arg in node.arguments:
                self.visit(arg)
            return None

        param_types, _ = self.type_info.function_sigs[func_name]

        # Check argument count
        if len(node.arguments) != len(param_types):
            self.result.add_error(
                f"Function '{func_name}' expects {len(param_types)} arguments, got {len(node.arguments)}",
                node,
            )
            # Still check inside the arguments themselves.
            for arg in node.arguments:
                self.visit(arg)
            return None

        # Check each argument type
        for position, (arg, expected_type) in enumerate(zip(node.arguments, param_types), start=1):
            self.visit(arg)
            arg_type = self._get_type(arg)
            if not self._types_compatible(expected_type, arg_type):
                self.result.add_error(
                    f"Argument {position} of '{func_name}' expected {expected_type}, got {arg_type}",
                    arg,
                )

        return None

    def _declared_return_type(self, node: nodes.FunctionDefNode) -> TypeAnnotation:
        """Return the function's declared return type, or VOID_TYPE if none is known."""
        if node.return_type and node.name in self.type_info.function_sigs:
            _, return_type = self.type_info.function_sigs[node.name]
            return return_type
        return VOID_TYPE

    def visit_function_def(self, node: nodes.FunctionDefNode) -> Any:
        """Check the function body with its declared return type in context."""
        enclosing_return = self._current_function_return
        self._current_function_return = self._declared_return_type(node)
        try:
            if node.body:
                self.visit(node.body)
        finally:
            # Restore (rather than clear) so an enclosing function's context
            # survives a nested definition.
            self._current_function_return = enclosing_return
        return None

    def visit_return_stmt(self, node: nodes.ReturnStmt) -> Any:
        """Check return statement type matches the enclosing function's return type."""
        if node.value is not None:
            self.visit(node.value)

        expected = self._current_function_return
        if expected is None:
            # Not inside a function definition: nothing to match against.
            return None

        if node.value is not None:
            return_value_type = self._get_type(node.value)
            if not self._types_compatible(expected, return_value_type):
                self.result.add_error(
                    f"Return type {return_value_type} does not match expected {expected}",
                    node,
                )
        elif expected != VOID_TYPE:
            self.result.add_error(
                f"Missing return value, expected {expected}",
                node,
            )

        return None

    # =========================================================================
    # Match expression
    # =========================================================================

    def visit_match_expr(self, node: nodes.MatchExprNode) -> Any:
        """Check all match arms have consistent types (reported as warnings).

        Arms typed as ``pass`` or unknown are ignored. Every other arm is
        compared with the first such arm.
        """
        if node.scrutinee:
            self.visit(node.scrutinee)

        # (position in node.arms, arm node, arm type) for every typed arm.
        # Keeping the original position makes the warning point at, and
        # number, the right arm even when some arms were skipped.
        typed_arms = []
        for position, arm in enumerate(node.arms):
            self.visit(arm)
            arm_type = self._get_type(arm)
            if arm_type != PASS_TYPE and arm_type != UNKNOWN_TYPE:
                typed_arms.append((position, arm, arm_type))

        # Check all non-pass arms have the same type
        if len(typed_arms) > 1:
            first_type = typed_arms[0][2]
            for position, arm, arm_type in typed_arms[1:]:
                if not self._types_compatible(first_type, arm_type):
                    self.result.add_error(
                        f"Match arm {position + 1} has type {arm_type}, expected {first_type}",
                        arm,
                        severity="warning",
                    )

        return None

    def visit_match_arm(self, node: nodes.MatchArm) -> Any:
        """Check pattern, guard and body; the guard must be boolean."""
        self.visit(node.pattern)

        if node.guard:
            self.visit(node.guard)
            guard_type = self._get_type(node.guard)
            if not self._is_unknown(guard_type) and guard_type.type_name != "bool":
                self.result.add_error(
                    f"Match arm guard must be bool, got {guard_type}",
                    node.guard,
                )

        self.visit(node.body)
        return None

    # =========================================================================
    # Struct literal
    # =========================================================================

    def visit_struct_literal(self, node: nodes.StructLiteralNode) -> Any:
        """Check struct field assignments match the struct's declared field types.

        Fields not present in the struct definition, and literals whose struct
        type is unknown, are visited but not type-checked here.
        """
        struct_fields = None
        if node.type_name and node.type_name in self.type_info.struct_defs:
            struct_fields = self.type_info.struct_defs[node.type_name]

        for field_assign in node.field_assignments:
            self.visit(field_assign)

            if struct_fields is None or field_assign.name not in struct_fields:
                continue

            expected_type = struct_fields[field_assign.name]
            actual_type = self._get_type(field_assign.value)

            if not self._types_compatible(expected_type, actual_type):
                self.result.add_error(
                    f"Field '{field_assign.name}' expected {expected_type}, got {actual_type}",
                    field_assign,
                )

        return None

    # =========================================================================
    # Module entry point
    # =========================================================================

    def visit_module(self, node: nodes.ModuleNode) -> Any:
        """Entry point: check all declarations and return the accumulated result."""
        for decl in node.declarations:
            self.visit(decl)
        return self.result
|