jaclang 0.8.6__py3-none-any.whl → 0.8.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of jaclang might be problematic. Click here for more details.
- jaclang/cli/cli.md +3 -3
- jaclang/cli/cli.py +37 -37
- jaclang/cli/cmdreg.py +45 -140
- jaclang/compiler/constant.py +0 -1
- jaclang/compiler/jac.lark +3 -6
- jaclang/compiler/larkparse/jac_parser.py +2 -2
- jaclang/compiler/parser.py +213 -34
- jaclang/compiler/passes/main/__init__.py +2 -4
- jaclang/compiler/passes/main/def_use_pass.py +0 -4
- jaclang/compiler/passes/main/predynamo_pass.py +221 -0
- jaclang/compiler/passes/main/pyast_gen_pass.py +83 -55
- jaclang/compiler/passes/main/pyast_load_pass.py +66 -40
- jaclang/compiler/passes/main/sym_tab_build_pass.py +1 -1
- jaclang/compiler/passes/main/tests/fixtures/checker/import_sym.jac +2 -0
- jaclang/compiler/passes/main/tests/fixtures/checker/import_sym_test.jac +6 -0
- jaclang/compiler/passes/main/tests/fixtures/checker/imported_sym.jac +5 -0
- jaclang/compiler/passes/main/tests/fixtures/checker_arg_param_match.jac +37 -0
- jaclang/compiler/passes/main/tests/fixtures/checker_arity.jac +18 -0
- jaclang/compiler/passes/main/tests/fixtures/checker_binary_op.jac +21 -0
- jaclang/compiler/passes/main/tests/fixtures/checker_call_expr_class.jac +12 -0
- jaclang/compiler/passes/main/tests/fixtures/checker_cat_is_animal.jac +18 -0
- jaclang/compiler/passes/main/tests/fixtures/checker_cyclic_symbol.jac +4 -0
- jaclang/compiler/passes/main/tests/fixtures/checker_expr_call.jac +9 -0
- jaclang/compiler/passes/main/tests/fixtures/checker_float.jac +7 -0
- jaclang/compiler/passes/main/tests/fixtures/checker_import_missing_module.jac +13 -0
- jaclang/compiler/passes/main/tests/fixtures/checker_magic_call.jac +17 -0
- jaclang/compiler/passes/main/tests/fixtures/checker_mod_path.jac +8 -0
- jaclang/compiler/passes/main/tests/fixtures/checker_param_types.jac +11 -0
- jaclang/compiler/passes/main/tests/fixtures/checker_self_type.jac +9 -0
- jaclang/compiler/passes/main/tests/fixtures/checker_sym_inherit.jac +42 -0
- jaclang/compiler/passes/main/tests/fixtures/predynamo_fix3.jac +43 -0
- jaclang/compiler/passes/main/tests/fixtures/predynamo_where_assign.jac +13 -0
- jaclang/compiler/passes/main/tests/fixtures/predynamo_where_return.jac +11 -0
- jaclang/compiler/passes/main/tests/test_checker_pass.py +265 -0
- jaclang/compiler/passes/main/tests/test_predynamo_pass.py +57 -0
- jaclang/compiler/passes/main/type_checker_pass.py +36 -61
- jaclang/compiler/passes/tool/doc_ir_gen_pass.py +204 -44
- jaclang/compiler/passes/tool/jac_formatter_pass.py +119 -69
- jaclang/compiler/passes/tool/tests/fixtures/corelib_fmt.jac +3 -3
- jaclang/compiler/passes/tool/tests/fixtures/general_format_checks/triple_quoted_string.jac +4 -5
- jaclang/compiler/passes/tool/tests/fixtures/tagbreak.jac +171 -11
- jaclang/compiler/passes/transform.py +12 -8
- jaclang/compiler/program.py +14 -6
- jaclang/compiler/tests/fixtures/jac_import_py_files.py +4 -0
- jaclang/compiler/tests/fixtures/jac_module.jac +3 -0
- jaclang/compiler/tests/fixtures/multiple_syntax_errors.jac +10 -0
- jaclang/compiler/tests/fixtures/python_module.py +1 -0
- jaclang/compiler/tests/test_importer.py +39 -0
- jaclang/compiler/tests/test_parser.py +49 -0
- jaclang/compiler/type_system/operations.py +104 -0
- jaclang/compiler/type_system/type_evaluator.py +470 -47
- jaclang/compiler/type_system/type_utils.py +246 -0
- jaclang/compiler/type_system/types.py +58 -2
- jaclang/compiler/unitree.py +79 -94
- jaclang/langserve/engine.jac +253 -230
- jaclang/langserve/server.jac +46 -15
- jaclang/langserve/tests/fixtures/circle.jac +3 -3
- jaclang/langserve/tests/fixtures/circle_err.jac +3 -3
- jaclang/langserve/tests/fixtures/circle_pure.test.jac +3 -3
- jaclang/langserve/tests/fixtures/completion_test_err.jac +10 -0
- jaclang/langserve/tests/server_test/circle_template.jac +80 -0
- jaclang/langserve/tests/server_test/glob_template.jac +4 -0
- jaclang/langserve/tests/server_test/test_lang_serve.py +154 -312
- jaclang/langserve/tests/server_test/utils.py +153 -116
- jaclang/langserve/tests/test_dev_server.py +1 -1
- jaclang/langserve/tests/test_server.py +30 -86
- jaclang/langserve/utils.jac +56 -63
- jaclang/runtimelib/machine.py +7 -0
- jaclang/runtimelib/meta_importer.py +27 -1
- jaclang/runtimelib/tests/fixtures/custom_access_validation.jac +1 -1
- jaclang/runtimelib/tests/fixtures/savable_object.jac +2 -2
- jaclang/settings.py +18 -14
- jaclang/tests/fixtures/abc_check.jac +3 -3
- jaclang/tests/fixtures/arch_rel_import_creation.jac +12 -12
- jaclang/tests/fixtures/chandra_bugs2.jac +3 -3
- jaclang/tests/fixtures/create_dynamic_archetype.jac +13 -13
- jaclang/tests/fixtures/jac_run_py_bugs.py +18 -0
- jaclang/tests/fixtures/jac_run_py_import.py +13 -0
- jaclang/tests/fixtures/lambda_arg_annotation.jac +15 -0
- jaclang/tests/fixtures/lambda_self.jac +18 -0
- jaclang/tests/fixtures/maxfail_run_test.jac +4 -4
- jaclang/tests/fixtures/params/param_syntax_err.jac +9 -0
- jaclang/tests/fixtures/params/test_complex_params.jac +42 -0
- jaclang/tests/fixtures/params/test_failing_kwonly.jac +207 -0
- jaclang/tests/fixtures/params/test_failing_posonly.jac +116 -0
- jaclang/tests/fixtures/params/test_failing_varargs.jac +300 -0
- jaclang/tests/fixtures/params/test_kwonly_params.jac +29 -0
- jaclang/tests/fixtures/py2jac_params.py +8 -0
- jaclang/tests/fixtures/run_test.jac +4 -4
- jaclang/tests/test_cli.py +103 -18
- jaclang/tests/test_language.py +74 -16
- jaclang/utils/helpers.py +47 -2
- jaclang/utils/module_resolver.py +11 -1
- jaclang/utils/test.py +8 -0
- jaclang/utils/treeprinter.py +0 -18
- {jaclang-0.8.6.dist-info → jaclang-0.8.8.dist-info}/METADATA +3 -3
- {jaclang-0.8.6.dist-info → jaclang-0.8.8.dist-info}/RECORD +99 -62
- {jaclang-0.8.6.dist-info → jaclang-0.8.8.dist-info}/WHEEL +1 -1
- jaclang/compiler/passes/main/inheritance_pass.py +0 -131
- jaclang/langserve/dev_engine.jac +0 -645
- jaclang/langserve/dev_server.jac +0 -201
- jaclang/langserve/tests/server_test/code_test.py +0 -0
- {jaclang-0.8.6.dist-info → jaclang-0.8.8.dist-info}/entry_points.txt +0 -0
|
@@ -6,19 +6,30 @@ PyrightReference:
|
|
|
6
6
|
packages/pyright-internal/src/analyzer/typeEvaluatorTypes.ts
|
|
7
7
|
"""
|
|
8
8
|
|
|
9
|
+
import ast as py_ast
|
|
10
|
+
import os
|
|
9
11
|
from dataclasses import dataclass
|
|
10
12
|
from pathlib import Path
|
|
11
|
-
from typing import TYPE_CHECKING, cast
|
|
13
|
+
from typing import Callable, TYPE_CHECKING, cast
|
|
12
14
|
|
|
13
15
|
import jaclang.compiler.unitree as uni
|
|
16
|
+
from jaclang.compiler import TOKEN_MAP
|
|
17
|
+
from jaclang.compiler.constant import Tokens as Tok
|
|
18
|
+
from jaclang.compiler.passes.main.pyast_load_pass import PyastBuildPass
|
|
19
|
+
from jaclang.compiler.passes.main.sym_tab_build_pass import SymTabBuildPass
|
|
14
20
|
from jaclang.compiler.type_system import types
|
|
21
|
+
from jaclang.runtimelib.utils import read_file_with_encoding
|
|
15
22
|
|
|
16
23
|
if TYPE_CHECKING:
|
|
17
24
|
from jaclang.compiler.program import JacProgram
|
|
18
25
|
|
|
19
|
-
from .
|
|
26
|
+
from . import operations
|
|
27
|
+
from . import type_utils
|
|
20
28
|
from .types import TypeBase
|
|
21
29
|
|
|
30
|
+
# The callback type definition for the diagnostic messages.
|
|
31
|
+
DiagnosticCallback = Callable[[uni.UniNode, str, bool], None]
|
|
32
|
+
|
|
22
33
|
|
|
23
34
|
@dataclass
|
|
24
35
|
class PrefetchedTypes:
|
|
@@ -34,6 +45,7 @@ class PrefetchedTypes:
|
|
|
34
45
|
tuple_class: TypeBase | None = None
|
|
35
46
|
bool_class: TypeBase | None = None
|
|
36
47
|
int_class: TypeBase | None = None
|
|
48
|
+
float_class: TypeBase | None = None
|
|
37
49
|
str_class: TypeBase | None = None
|
|
38
50
|
dict_class: TypeBase | None = None
|
|
39
51
|
module_type_class: TypeBase | None = None
|
|
@@ -44,10 +56,50 @@ class PrefetchedTypes:
|
|
|
44
56
|
template_class: TypeBase | None = None
|
|
45
57
|
|
|
46
58
|
|
|
59
|
+
@dataclass
|
|
60
|
+
class SymbolResolutionStackEntry:
|
|
61
|
+
"""Represents a single entry in the symbol resolution stack."""
|
|
62
|
+
|
|
63
|
+
symbol: uni.Symbol
|
|
64
|
+
|
|
65
|
+
# Initially true, it's set to false if a recursion
|
|
66
|
+
# is detected.
|
|
67
|
+
is_result_valid: bool = True
|
|
68
|
+
|
|
69
|
+
# Some limited forms of recursion are allowed. In these
|
|
70
|
+
# cases, a partially-constructed type can be registered.
|
|
71
|
+
partial_type: TypeBase | None = None
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
@dataclass
|
|
75
|
+
class MatchArgsToParamsResult:
|
|
76
|
+
"""Result of matching arguments to parameters."""
|
|
77
|
+
|
|
78
|
+
# FIXME: This class implementation is modified from pyright to make it
|
|
79
|
+
# simple and work for now, however this needs to be revisited and
|
|
80
|
+
# implemented properly.
|
|
81
|
+
arg_params: dict[uni.Expr | uni.KWPair, types.Parameter | None]
|
|
82
|
+
|
|
83
|
+
overload: types.FunctionType | None = None
|
|
84
|
+
argument_errors: bool = False
|
|
85
|
+
|
|
86
|
+
|
|
47
87
|
class TypeEvaluator:
|
|
48
88
|
"""Type evaluator for JacLang."""
|
|
49
89
|
|
|
50
|
-
|
|
90
|
+
# NOTE: This is done in the binder pass of pyright, however I'm doing this
|
|
91
|
+
# here, cause this will be the entry point of the type checker and we're not
|
|
92
|
+
# relying on the binder pass at the moment and we can go back to binder pass
|
|
93
|
+
# in the future if we needed it.
|
|
94
|
+
_BUILTINS_STUB_FILE_PATH = os.path.join(
|
|
95
|
+
os.path.dirname(__file__),
|
|
96
|
+
"../../vendor/typeshed/stdlib/builtins.pyi",
|
|
97
|
+
)
|
|
98
|
+
|
|
99
|
+
def __init__(
|
|
100
|
+
self,
|
|
101
|
+
program: "JacProgram",
|
|
102
|
+
) -> None:
|
|
51
103
|
"""Initialize the type evaluator with prefetched types.
|
|
52
104
|
|
|
53
105
|
Implementation Note:
|
|
@@ -58,14 +110,113 @@ class TypeEvaluator:
|
|
|
58
110
|
in some place then it will not be available in the evaluator, So we
|
|
59
111
|
are prefetching the builtins at the constructor level once.
|
|
60
112
|
"""
|
|
61
|
-
self.builtins_module = builtins_module
|
|
62
113
|
self.program = program
|
|
114
|
+
self.symbol_resolution_stack: list[SymbolResolutionStackEntry] = []
|
|
115
|
+
self.builtins_module = self._load_builtins_stub_module()
|
|
63
116
|
self.prefetch = self._prefetch_types()
|
|
117
|
+
self.diagnostic_callback: DiagnosticCallback | None = None
|
|
118
|
+
|
|
119
|
+
def _load_builtins_stub_module(self) -> uni.Module:
|
|
120
|
+
"""Load and return the builtins stub module."""
|
|
121
|
+
if not os.path.exists(TypeEvaluator._BUILTINS_STUB_FILE_PATH):
|
|
122
|
+
raise FileNotFoundError(
|
|
123
|
+
f"Builtins stub file not found at {TypeEvaluator._BUILTINS_STUB_FILE_PATH}"
|
|
124
|
+
)
|
|
125
|
+
file_content = read_file_with_encoding(TypeEvaluator._BUILTINS_STUB_FILE_PATH)
|
|
126
|
+
uni_source = uni.Source(file_content, TypeEvaluator._BUILTINS_STUB_FILE_PATH)
|
|
127
|
+
mod = PyastBuildPass(
|
|
128
|
+
ir_in=uni.PythonModuleAst(
|
|
129
|
+
py_ast.parse(file_content),
|
|
130
|
+
orig_src=uni_source,
|
|
131
|
+
),
|
|
132
|
+
prog=self.program,
|
|
133
|
+
).ir_out
|
|
134
|
+
SymTabBuildPass(ir_in=mod, prog=self.program)
|
|
135
|
+
return mod
|
|
136
|
+
|
|
137
|
+
def _get_builtin_type(self, name: str) -> TypeBase:
|
|
138
|
+
"""Return the built-in type with the given name."""
|
|
139
|
+
if (symbol := self.builtins_module.lookup(name)) is not None:
|
|
140
|
+
return self.get_type_of_symbol(symbol)
|
|
141
|
+
return types.UnknownType()
|
|
142
|
+
|
|
143
|
+
def _prefetch_types(self) -> "PrefetchedTypes":
|
|
144
|
+
"""Return the prefetched types for the type evaluator."""
|
|
145
|
+
return PrefetchedTypes(
|
|
146
|
+
# TODO: Pyright first try load NoneType from typeshed and if it cannot
|
|
147
|
+
# then it set to unknown type.
|
|
148
|
+
none_type_class=types.UnknownType(),
|
|
149
|
+
object_class=self._get_builtin_type("object"),
|
|
150
|
+
type_class=self._get_builtin_type("type"),
|
|
151
|
+
# union_type_class=
|
|
152
|
+
# awaitable_class=
|
|
153
|
+
# function_class=
|
|
154
|
+
# method_class=
|
|
155
|
+
tuple_class=self._get_builtin_type("tuple"),
|
|
156
|
+
bool_class=self._get_builtin_type("bool"),
|
|
157
|
+
int_class=self._get_builtin_type("int"),
|
|
158
|
+
float_class=self._get_builtin_type("float"),
|
|
159
|
+
str_class=self._get_builtin_type("str"),
|
|
160
|
+
dict_class=self._get_builtin_type("dict"),
|
|
161
|
+
# module_type_class=
|
|
162
|
+
# typed_dict_class=
|
|
163
|
+
# typed_dict_private_class=
|
|
164
|
+
# supports_keys_and_get_item_class=
|
|
165
|
+
# mapping_class=
|
|
166
|
+
# template_class=
|
|
167
|
+
)
|
|
168
|
+
|
|
169
|
+
def add_diagnostic(
|
|
170
|
+
self, node: uni.UniNode, message: str, warning: bool = False
|
|
171
|
+
) -> None:
|
|
172
|
+
"""Add a diagnostic message to the program."""
|
|
173
|
+
if self.diagnostic_callback:
|
|
174
|
+
self.diagnostic_callback(node, message, warning)
|
|
175
|
+
|
|
176
|
+
# -------------------------------------------------------------------------
|
|
177
|
+
# Symbol resolution stack
|
|
178
|
+
# -------------------------------------------------------------------------
|
|
179
|
+
|
|
180
|
+
def get_index_of_symbol_resolution(self, symbol: uni.Symbol) -> int | None:
|
|
181
|
+
"""Get the index of a symbol in the resolution stack."""
|
|
182
|
+
for i, entry in enumerate(self.symbol_resolution_stack):
|
|
183
|
+
if entry.symbol == symbol:
|
|
184
|
+
return i
|
|
185
|
+
return None
|
|
186
|
+
|
|
187
|
+
def push_symbol_resolution(self, symbol: uni.Symbol) -> bool:
|
|
188
|
+
"""
|
|
189
|
+
Push a symbol onto the resolution stack.
|
|
190
|
+
|
|
191
|
+
Return False if recursion detected and in that case it won't push the symbol.
|
|
192
|
+
"""
|
|
193
|
+
idx = self.get_index_of_symbol_resolution(symbol)
|
|
194
|
+
if idx is not None:
|
|
195
|
+
# Mark all of the entries between these two as invalid.
|
|
196
|
+
for i in range(idx, len(self.symbol_resolution_stack)):
|
|
197
|
+
entry = self.symbol_resolution_stack[i]
|
|
198
|
+
entry.is_result_valid = False
|
|
199
|
+
return False
|
|
200
|
+
self.symbol_resolution_stack.append(SymbolResolutionStackEntry(symbol=symbol))
|
|
201
|
+
return True
|
|
202
|
+
|
|
203
|
+
def pop_symbol_resolution(self, symbol: uni.Symbol) -> bool:
|
|
204
|
+
"""Pop a symbol from the resolution stack."""
|
|
205
|
+
popped_entry = self.symbol_resolution_stack.pop()
|
|
206
|
+
assert popped_entry.symbol == symbol
|
|
207
|
+
return popped_entry.is_result_valid
|
|
64
208
|
|
|
65
209
|
# Pyright equivalent function name = getEffectiveTypeOfSymbol.
|
|
66
210
|
def get_type_of_symbol(self, symbol: uni.Symbol) -> TypeBase:
|
|
67
211
|
"""Return the effective type of the symbol."""
|
|
68
|
-
|
|
212
|
+
if self.push_symbol_resolution(symbol):
|
|
213
|
+
try:
|
|
214
|
+
return self._get_type_of_symbol(symbol)
|
|
215
|
+
finally:
|
|
216
|
+
self.pop_symbol_resolution(symbol)
|
|
217
|
+
|
|
218
|
+
# If we reached here that means we have a cyclic symbolic reference.
|
|
219
|
+
return types.UnknownType()
|
|
69
220
|
|
|
70
221
|
# NOTE: This function doesn't exists in pyright, however it exists as a helper function
|
|
71
222
|
# for the following functions.
|
|
@@ -92,10 +243,13 @@ class TypeEvaluator:
|
|
|
92
243
|
mod.parent_scope = self.builtins_module
|
|
93
244
|
return mod
|
|
94
245
|
|
|
95
|
-
def get_type_of_module(self, node: uni.ModulePath) -> types.
|
|
246
|
+
def get_type_of_module(self, node: uni.ModulePath) -> types.TypeBase:
|
|
96
247
|
"""Return the effective type of the module."""
|
|
97
248
|
if node.name_spec.type is not None:
|
|
98
249
|
return cast(types.ModuleType, node.name_spec.type)
|
|
250
|
+
if not Path(node.resolve_relative_path()).exists():
|
|
251
|
+
node.name_spec.type = types.UnknownType()
|
|
252
|
+
return node.name_spec.type
|
|
99
253
|
|
|
100
254
|
mod: uni.Module = self._import_module_from_path(node.resolve_relative_path())
|
|
101
255
|
mod_type = types.ModuleType(
|
|
@@ -149,6 +303,11 @@ class TypeEvaluator:
|
|
|
149
303
|
# import from mod { item }
|
|
150
304
|
else:
|
|
151
305
|
mod_type = self.get_type_of_module(import_node.from_loc)
|
|
306
|
+
if not isinstance(mod_type, types.ModuleType):
|
|
307
|
+
node.name_spec.type = types.UnknownType()
|
|
308
|
+
# TODO: Add diagnostic that from_loc is not accessible.
|
|
309
|
+
# Eg: 'Import "scipy" could not be resolved'
|
|
310
|
+
return node.name_spec.type
|
|
152
311
|
if sym := mod_type.symbol_table.lookup(node.name.value, deep=True):
|
|
153
312
|
node.name.sym = sym
|
|
154
313
|
if node.alias:
|
|
@@ -164,20 +323,99 @@ class TypeEvaluator:
|
|
|
164
323
|
if node.name_spec.type is not None:
|
|
165
324
|
return cast(types.ClassType, node.name_spec.type)
|
|
166
325
|
|
|
326
|
+
base_classes: list[TypeBase] = []
|
|
327
|
+
for base_class in node.base_classes or []:
|
|
328
|
+
base_class_type = self.get_type_of_expression(base_class)
|
|
329
|
+
base_classes.append(base_class_type)
|
|
330
|
+
is_builtin_class = node.find_parent_of_type(uni.Module) == self.builtins_module
|
|
331
|
+
|
|
167
332
|
cls_type = types.ClassType(
|
|
168
333
|
types.ClassType.ClassDetailsShared(
|
|
169
334
|
class_name=node.name_spec.sym_name,
|
|
170
335
|
symbol_table=node,
|
|
171
|
-
|
|
336
|
+
base_classes=base_classes,
|
|
337
|
+
is_builtin_class=is_builtin_class,
|
|
172
338
|
),
|
|
173
339
|
flags=types.TypeFlags.Instantiable,
|
|
174
340
|
)
|
|
175
341
|
|
|
342
|
+
# Compute the MRO for the class.
|
|
343
|
+
type_utils.compute_mro_linearization(cls_type)
|
|
344
|
+
|
|
176
345
|
# Cache the type, pyright is doing invalidateTypeCacheIfCanceled()
|
|
177
346
|
# we're not doing that any time sooner.
|
|
178
347
|
node.name_spec.type = cls_type
|
|
179
348
|
return cls_type
|
|
180
349
|
|
|
350
|
+
def get_type_of_ability(self, node: uni.Ability) -> TypeBase:
|
|
351
|
+
"""Return the effective type of an ability."""
|
|
352
|
+
if node.name_spec.type is not None:
|
|
353
|
+
return node.name_spec.type
|
|
354
|
+
|
|
355
|
+
if not isinstance(node.signature, uni.FuncSignature):
|
|
356
|
+
node.name_spec.type = types.UnknownType()
|
|
357
|
+
return node.name_spec.type
|
|
358
|
+
|
|
359
|
+
return_type: TypeBase
|
|
360
|
+
if isinstance(node.signature.return_type, uni.Expr):
|
|
361
|
+
return_type = self._convert_to_instance(
|
|
362
|
+
self.get_type_of_expression(node.signature.return_type)
|
|
363
|
+
)
|
|
364
|
+
else:
|
|
365
|
+
return_type = types.UnknownType()
|
|
366
|
+
|
|
367
|
+
# Define helper function for parameter conversion.
|
|
368
|
+
def _get_param_category(param: uni.ParamVar) -> types.ParameterCategory:
|
|
369
|
+
if param.is_vararg:
|
|
370
|
+
return types.ParameterCategory.ArgsList
|
|
371
|
+
if param.is_kwargs:
|
|
372
|
+
return types.ParameterCategory.KwargsDict
|
|
373
|
+
return types.ParameterCategory.Positional
|
|
374
|
+
|
|
375
|
+
# Define helper function for parameter kind conversion.
|
|
376
|
+
def _convert_param_kind(kind: uni.ParamKind) -> types.ParamKind:
|
|
377
|
+
match kind:
|
|
378
|
+
case uni.ParamKind.POSONLY:
|
|
379
|
+
return types.ParamKind.POSONLY
|
|
380
|
+
case uni.ParamKind.NORMAL:
|
|
381
|
+
return types.ParamKind.NORMAL
|
|
382
|
+
case uni.ParamKind.VARARG:
|
|
383
|
+
return types.ParamKind.VARARG
|
|
384
|
+
case uni.ParamKind.KWONLY:
|
|
385
|
+
return types.ParamKind.KWONLY
|
|
386
|
+
case uni.ParamKind.KWARG:
|
|
387
|
+
return types.ParamKind.KWARG
|
|
388
|
+
return types.ParamKind.NORMAL
|
|
389
|
+
|
|
390
|
+
parameters: list[types.Parameter] = []
|
|
391
|
+
for idx, param in enumerate(node.signature.get_parameters()):
|
|
392
|
+
# TODO: Set parameter category for *args, and **kwargs
|
|
393
|
+
param_type: TypeBase | None = None
|
|
394
|
+
|
|
395
|
+
if param.type_tag:
|
|
396
|
+
param_type_cls = self.get_type_of_expression(param.type_tag.tag)
|
|
397
|
+
param_type = self._convert_to_instance(param_type_cls)
|
|
398
|
+
|
|
399
|
+
parameters.append(
|
|
400
|
+
types.Parameter(
|
|
401
|
+
name=param.name.value,
|
|
402
|
+
category=_get_param_category(param),
|
|
403
|
+
param_type=param_type,
|
|
404
|
+
default_value=param.value,
|
|
405
|
+
is_self=(idx == 0 and self._is_expr_self(param.name)),
|
|
406
|
+
param_kind=_convert_param_kind(param.param_kind),
|
|
407
|
+
)
|
|
408
|
+
)
|
|
409
|
+
|
|
410
|
+
func_type = types.FunctionType(
|
|
411
|
+
func_name=node.name_spec.sym_name,
|
|
412
|
+
return_type=return_type,
|
|
413
|
+
parameters=parameters,
|
|
414
|
+
)
|
|
415
|
+
|
|
416
|
+
node.name_spec.type = func_type
|
|
417
|
+
return func_type
|
|
418
|
+
|
|
181
419
|
def get_type_of_string(self, node: uni.String | uni.MultiString) -> TypeBase:
|
|
182
420
|
"""Return the effective type of the string."""
|
|
183
421
|
# FIXME: Strings are a type of LiteralString type:
|
|
@@ -195,6 +433,11 @@ class TypeEvaluator:
|
|
|
195
433
|
assert self.prefetch.int_class is not None
|
|
196
434
|
return self.prefetch.int_class
|
|
197
435
|
|
|
436
|
+
def get_type_of_float(self, node: uni.Float) -> TypeBase:
|
|
437
|
+
"""Return the effective type of the float."""
|
|
438
|
+
assert self.prefetch.float_class is not None
|
|
439
|
+
return self.prefetch.float_class
|
|
440
|
+
|
|
198
441
|
# Pyright equivalent function name = getTypeOfExpression();
|
|
199
442
|
def get_type_of_expression(self, node: uni.Expr) -> TypeBase:
|
|
200
443
|
"""Return the effective type of the expression."""
|
|
@@ -221,9 +464,6 @@ class TypeEvaluator:
|
|
|
221
464
|
# NOTE: For now if we don't have the type info, we assume it's compatible.
|
|
222
465
|
# For strict mode we should disallow usage of unknown unless explicitly ignored.
|
|
223
466
|
return True
|
|
224
|
-
# FIXME: This logic is not valid, just here as a stub.
|
|
225
|
-
if types.TypeCategory.Unknown in (src_type.category, dest_type.category):
|
|
226
|
-
return True
|
|
227
467
|
|
|
228
468
|
if src_type == dest_type:
|
|
229
469
|
return True
|
|
@@ -233,8 +473,30 @@ class TypeEvaluator:
|
|
|
233
473
|
assert isinstance(src_type, types.ClassType)
|
|
234
474
|
return self._assign_class(src_type, dest_type)
|
|
235
475
|
|
|
236
|
-
|
|
237
|
-
|
|
476
|
+
return False
|
|
477
|
+
|
|
478
|
+
# TODO: This should take an argument list as parameter.
|
|
479
|
+
def get_type_of_magic_method_call(
|
|
480
|
+
self, obj_type: TypeBase, method_name: str
|
|
481
|
+
) -> TypeBase | None:
|
|
482
|
+
"""Return the effective return type of a magic method call."""
|
|
483
|
+
if obj_type.category == types.TypeCategory.Class:
|
|
484
|
+
# TODO: getTypeOfBoundMember() <-- Implement this if needed, for the simple case
|
|
485
|
+
# we'll directly call member lookup.
|
|
486
|
+
#
|
|
487
|
+
# WE'RE DAVIATING FROM PYRIGHT FOR THIS METHOD HEAVILY HOWEVER THIS CAN BE RE-WRITTEN IF NEEDED.
|
|
488
|
+
#
|
|
489
|
+
assert isinstance(obj_type, types.ClassType) # <-- To make typecheck happy.
|
|
490
|
+
if member := self._lookup_class_member(obj_type, method_name):
|
|
491
|
+
member_ty = self.get_type_of_symbol(member.symbol)
|
|
492
|
+
if isinstance(member_ty, types.FunctionType):
|
|
493
|
+
return member_ty.return_type
|
|
494
|
+
# If we reached here, magic method is not a function.
|
|
495
|
+
# 1. recursively check __call__() on the type, TODO
|
|
496
|
+
# 2. if any or unknown, return getUnknownTypeForCallable() TODO
|
|
497
|
+
# 3. return undefined.
|
|
498
|
+
return None
|
|
499
|
+
return None
|
|
238
500
|
|
|
239
501
|
def _assign_class(
|
|
240
502
|
self, src_type: types.ClassType, dest_type: types.ClassType
|
|
@@ -243,39 +505,23 @@ class TypeEvaluator:
|
|
|
243
505
|
if src_type.shared == dest_type.shared:
|
|
244
506
|
return True
|
|
245
507
|
|
|
246
|
-
#
|
|
247
|
-
|
|
508
|
+
# Check if src class is a subclass of dest class.
|
|
509
|
+
for base_cls in src_type.shared.mro:
|
|
510
|
+
if base_cls.shared == dest_type.shared:
|
|
511
|
+
return True
|
|
248
512
|
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
#
|
|
253
|
-
|
|
254
|
-
none_type_class=types.UnknownType(),
|
|
255
|
-
object_class=self._get_builtin_type("object"),
|
|
256
|
-
type_class=self._get_builtin_type("type"),
|
|
257
|
-
# union_type_class=
|
|
258
|
-
# awaitable_class=
|
|
259
|
-
# function_class=
|
|
260
|
-
# method_class=
|
|
261
|
-
tuple_class=self._get_builtin_type("tuple"),
|
|
262
|
-
bool_class=self._get_builtin_type("bool"),
|
|
263
|
-
int_class=self._get_builtin_type("int"),
|
|
264
|
-
str_class=self._get_builtin_type("str"),
|
|
265
|
-
dict_class=self._get_builtin_type("dict"),
|
|
266
|
-
# module_type_class=
|
|
267
|
-
# typed_dict_class=
|
|
268
|
-
# typed_dict_private_class=
|
|
269
|
-
# supports_keys_and_get_item_class=
|
|
270
|
-
# mapping_class=
|
|
271
|
-
# template_class=
|
|
272
|
-
)
|
|
513
|
+
# Everything is assignable to an object.
|
|
514
|
+
if dest_type.is_builtin("object"):
|
|
515
|
+
# TODO: Invariance not handled yet
|
|
516
|
+
# invariant contexts to avoid list[int] <: list[object] errors.
|
|
517
|
+
return True
|
|
273
518
|
|
|
274
|
-
|
|
275
|
-
""
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
519
|
+
# Integers can be used where floats are expected.
|
|
520
|
+
if src_type.is_builtin("int") and dest_type.is_builtin("float"):
|
|
521
|
+
return True
|
|
522
|
+
|
|
523
|
+
# TODO: Search base classes and everything else pyright is doing.
|
|
524
|
+
return False
|
|
279
525
|
|
|
280
526
|
# This function is a combination of the bellow pyright functions.
|
|
281
527
|
# - getDeclaredTypeOfSymbol
|
|
@@ -298,6 +544,14 @@ class TypeEvaluator:
|
|
|
298
544
|
case uni.Archetype():
|
|
299
545
|
return self.get_type_of_class(node)
|
|
300
546
|
|
|
547
|
+
case uni.Ability():
|
|
548
|
+
return self.get_type_of_ability(node)
|
|
549
|
+
|
|
550
|
+
case uni.ParamVar():
|
|
551
|
+
if node.type_tag:
|
|
552
|
+
annotation_type = self.get_type_of_expression(node.type_tag.tag)
|
|
553
|
+
return self._convert_to_instance(annotation_type)
|
|
554
|
+
|
|
301
555
|
# This actually defined in the function getTypeForDeclaration();
|
|
302
556
|
# Pyright has DeclarationType.Variable.
|
|
303
557
|
case uni.Name():
|
|
@@ -335,6 +589,9 @@ class TypeEvaluator:
|
|
|
335
589
|
case uni.Int():
|
|
336
590
|
return self._convert_to_instance(self.get_type_of_int(expr))
|
|
337
591
|
|
|
592
|
+
case uni.Float():
|
|
593
|
+
return self._convert_to_instance(self.get_type_of_float(expr))
|
|
594
|
+
|
|
338
595
|
case uni.AtomTrailer():
|
|
339
596
|
# NOTE: Pyright is using CFG to figure out the member type by narrowing the base
|
|
340
597
|
# type and filtering the members. We're not doing that anytime sooner.
|
|
@@ -376,13 +633,71 @@ class TypeEvaluator:
|
|
|
376
633
|
else: # <expr>[<expr>]
|
|
377
634
|
pass # TODO:
|
|
378
635
|
|
|
636
|
+
case uni.AtomUnit():
|
|
637
|
+
return self.get_type_of_expression(expr.value)
|
|
638
|
+
|
|
639
|
+
case uni.FuncCall():
|
|
640
|
+
return self.validate_call_args(expr)
|
|
641
|
+
|
|
642
|
+
case uni.BinaryExpr():
|
|
643
|
+
return operations.get_type_of_binary_operation(self, expr)
|
|
644
|
+
|
|
379
645
|
case uni.Name():
|
|
646
|
+
# NOTE: For self's type pyright is getting the first parameter of a method and
|
|
647
|
+
# the name can be anything not just self, however we don't have the first parameter
|
|
648
|
+
# and self is a keyword, we need to do it in this way.
|
|
649
|
+
if self._is_expr_self(expr):
|
|
650
|
+
return self._get_type_of_self(expr)
|
|
651
|
+
|
|
380
652
|
if symbol := expr.sym_tab.lookup(expr.value, deep=True):
|
|
653
|
+
expr.sym = symbol
|
|
381
654
|
return self.get_type_of_symbol(symbol)
|
|
382
655
|
|
|
383
656
|
# TODO: More expressions.
|
|
384
657
|
return types.UnknownType()
|
|
385
658
|
|
|
659
|
+
# -----------------------------------------------------------------------------
|
|
660
|
+
# Helper functions
|
|
661
|
+
# -----------------------------------------------------------------------------
|
|
662
|
+
|
|
663
|
+
def _is_expr_self(self, expr: uni.Expr) -> bool:
|
|
664
|
+
"""Check if the expression is Name that is 'self' and in the method context."""
|
|
665
|
+
if (
|
|
666
|
+
isinstance(expr, uni.Name)
|
|
667
|
+
and (expr.value == TOKEN_MAP[Tok.KW_SELF])
|
|
668
|
+
and (fn := self._get_enclosing_method(expr))
|
|
669
|
+
and (not fn.is_static)
|
|
670
|
+
):
|
|
671
|
+
return True
|
|
672
|
+
return False
|
|
673
|
+
|
|
674
|
+
def _get_enclosing_function(self, node: uni.UniNode) -> uni.Ability | None:
|
|
675
|
+
"""Get the enclosing function (ability) of the given node."""
|
|
676
|
+
if (impl := node.find_parent_of_type(uni.ImplDef)) and (
|
|
677
|
+
isinstance(impl.decl_link, uni.Ability)
|
|
678
|
+
):
|
|
679
|
+
return impl.decl_link
|
|
680
|
+
return node.find_parent_of_type(uni.Ability)
|
|
681
|
+
|
|
682
|
+
def _get_enclosing_method(self, node: uni.UniNode) -> uni.Ability | None:
|
|
683
|
+
"""Get the enclosing method (ability) of the given node."""
|
|
684
|
+
enclosing_fn = self._get_enclosing_function(node)
|
|
685
|
+
while enclosing_fn and (not enclosing_fn.is_method):
|
|
686
|
+
enclosing_fn = self._get_enclosing_function(enclosing_fn)
|
|
687
|
+
if enclosing_fn and enclosing_fn.is_method:
|
|
688
|
+
return enclosing_fn
|
|
689
|
+
return None
|
|
690
|
+
|
|
691
|
+
def _get_type_of_self(self, node: uni.Name) -> TypeBase:
|
|
692
|
+
"""Return the effective type of self."""
|
|
693
|
+
if method := self._get_enclosing_method(node):
|
|
694
|
+
cls = method.method_owner
|
|
695
|
+
if isinstance(cls, uni.Archetype):
|
|
696
|
+
return self.get_type_of_class(cls).clone_as_instance()
|
|
697
|
+
if isinstance(cls, uni.Enum):
|
|
698
|
+
pass # TODO: Implement type from enum.
|
|
699
|
+
return types.UnknownType()
|
|
700
|
+
|
|
386
701
|
def _convert_to_instance(self, jtype: TypeBase) -> TypeBase:
|
|
387
702
|
"""Convert a class type to an instance type."""
|
|
388
703
|
# TODO: Grep pyright "Handle type[x] as a special case." They handle `type[x]` as a special case:
|
|
@@ -397,7 +712,7 @@ class TypeEvaluator:
|
|
|
397
712
|
|
|
398
713
|
def _lookup_class_member(
|
|
399
714
|
self, base_type: types.ClassType, member: str
|
|
400
|
-
) -> ClassMember | None:
|
|
715
|
+
) -> type_utils.ClassMember | None:
|
|
401
716
|
"""Lookup the class member type."""
|
|
402
717
|
assert self.prefetch.int_class is not None
|
|
403
718
|
# FIXME: Pyright's way: Implement class member iterator (based on mro and the multiple inheritance)
|
|
@@ -405,13 +720,14 @@ class TypeEvaluator:
|
|
|
405
720
|
|
|
406
721
|
# NOTE: This is a simple implementation to make it work and more robust implementation will
|
|
407
722
|
# be done in a future PR.
|
|
408
|
-
|
|
409
|
-
|
|
723
|
+
for cls in base_type.shared.mro:
|
|
724
|
+
if sym := cls.lookup_member_symbol(member):
|
|
725
|
+
return type_utils.ClassMember(sym, cls)
|
|
410
726
|
return None
|
|
411
727
|
|
|
412
728
|
def _lookup_object_member(
|
|
413
729
|
self, base_type: types.ClassType, member: str
|
|
414
|
-
) -> ClassMember | None:
|
|
730
|
+
) -> type_utils.ClassMember | None:
|
|
415
731
|
"""Lookup the object member type."""
|
|
416
732
|
assert self.prefetch.int_class is not None
|
|
417
733
|
if base_type.is_class_instance():
|
|
@@ -419,3 +735,110 @@ class TypeEvaluator:
|
|
|
419
735
|
# TODO: We need to implement Member lookup flags and set SkipInstanceMember to 0.
|
|
420
736
|
return self._lookup_class_member(base_type, member)
|
|
421
737
|
return None
|
|
738
|
+
|
|
739
|
+
def match_args_to_params(
    self, expr: uni.FuncCall, func_type: types.FunctionType
) -> MatchArgsToParamsResult:
    """
    Bind every argument of a call expression to a parameter of ``func_type``.

    Keyword arguments (``uni.KWPair``) are bound by parameter name and all
    other arguments by position, following PEP 3102:
    https://www.python.org/dev/peps/pep-3102/

    Only the binding is computed here. Evaluating and checking the argument
    types is left to the caller.

    Diagnostics:
    - Any exception raised by the parameter tracker while binding an
      argument is reported as a diagnostic on that argument.
    - Required parameters that end up with no argument are reported as a
      single diagnostic on the call expression.

    In either case the result's ``argument_errors`` is True. An argument that
    failed to bind gets no entry in ``arg_params``.
    """
    bindings: dict[uni.Expr | uni.KWPair, types.Parameter | None] = {}
    failed = False

    # A leading `self` parameter is not supplied explicitly at a method call
    # site, so remove it before handing the parameters to the tracker.
    candidates = func_type.parameters.copy()
    if candidates and candidates[0].is_self:
        del candidates[0]

    tracker = type_utils.ParamAssignmentTracker(candidates)

    # Arguments are consumed left to right and the tracker remembers which
    # parameters are already bound:
    # - positional arguments are bound by their tracked parameter index
    #   (extra ones flow into `*args`, per the original design notes);
    # - keyword arguments are bound by parameter-name lookup (unknown names
    #   flow into `**kwargs`, per the same notes).
    for arg in expr.params:
        try:
            if isinstance(arg, uni.KWPair):
                bound = tracker.match_named_argument(arg)
            else:
                bound = tracker.match_positional_argument(arg)
            bindings[arg] = bound
        except Exception as exc:
            # Tracker failures (e.g. no slot left for this argument) become
            # user-facing diagnostics rather than aborting the check.
            self.add_diagnostic(arg, str(exc))
            failed = True

    missing = tracker.get_unmatched_required_params()
    if missing:
        names = ", ".join(f"'{p.name}'" for p in missing)
        failed = True
        self.add_diagnostic(
            expr,
            f"Not all required parameters were provided in the function call: {names}",
        )

    return MatchArgsToParamsResult(arg_params=bindings, argument_errors=failed)
|
|
796
|
+
|
|
797
|
+
def validate_call_args(self, expr: uni.FuncCall) -> TypeBase:
    """
    Check the arguments of a call expression and return the call's result type.

    The behaviour depends on the type of the callee (``expr.target``):

    - Function: arguments are bound with ``match_args_to_params``. Argument
      types are checked with ``validate_arg_types`` only when binding
      produced no errors. The declared return type is returned, or
      ``UnknownType`` if it is missing.
    - Instantiable class: the instance type of the class is returned.
      Constructor arguments are not validated yet.
    - Class instance: the type produced by its ``__call__`` magic method is
      returned when one is found. Arguments are not validated yet.
    - Anything else: ``UnknownType``.

    Argument problems are reported as diagnostics by the helpers above. This
    method does not return an error flag.
    """
    callee = self.get_type_of_expression(expr.target)

    if isinstance(callee, types.FunctionType):
        binding = self.match_args_to_params(expr, callee)
        # Type-checking a partially bound call would only produce noise.
        if not binding.argument_errors:
            self.validate_arg_types(binding)
        return callee.return_type or types.UnknownType()

    if isinstance(callee, types.ClassType) and callee.is_instantiable_class():
        # TODO: validate args for __init__()
        return callee.clone_as_instance()

    if not callee.is_class_instance():
        return types.UnknownType()

    # TODO: validate args.
    call_result = self.get_type_of_magic_method_call(callee, "__call__")
    return call_result or types.UnknownType()
|
|
826
|
+
|
|
827
|
+
def validate_arg_types(
    self,
    args: MatchArgsToParamsResult,
) -> None:
    """
    Report each bound argument whose type cannot be assigned to its parameter.

    Entries whose parameter is ``None`` (unbound) or has no declared
    ``param_type`` are skipped. For keyword arguments (``uni.KWPair``) the
    type of the value expression is checked, not the pair itself.
    Assignability is decided by ``assign_type``, and every failure adds a
    diagnostic on the argument node.
    """
    for node, param in args.arg_params.items():
        # Nothing to compare against without a bound, typed parameter.
        if param is None or param.param_type is None:
            continue

        value_node = node.value if isinstance(node, uni.KWPair) else node
        arg_type = self.get_type_of_expression(value_node)

        if self.assign_type(arg_type, param.param_type):
            continue

        self.add_diagnostic(
            node,
            f"Cannot assign {arg_type} to parameter '{param.name}' of type {param.param_type}",
        )
|