jaclang 0.8.6__py3-none-any.whl → 0.8.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of jaclang might be problematic. Click here for more details.

Files changed (103)
  1. jaclang/cli/cli.md +3 -3
  2. jaclang/cli/cli.py +37 -37
  3. jaclang/cli/cmdreg.py +45 -140
  4. jaclang/compiler/constant.py +0 -1
  5. jaclang/compiler/jac.lark +3 -6
  6. jaclang/compiler/larkparse/jac_parser.py +2 -2
  7. jaclang/compiler/parser.py +213 -34
  8. jaclang/compiler/passes/main/__init__.py +2 -4
  9. jaclang/compiler/passes/main/def_use_pass.py +0 -4
  10. jaclang/compiler/passes/main/predynamo_pass.py +221 -0
  11. jaclang/compiler/passes/main/pyast_gen_pass.py +83 -55
  12. jaclang/compiler/passes/main/pyast_load_pass.py +66 -40
  13. jaclang/compiler/passes/main/sym_tab_build_pass.py +1 -1
  14. jaclang/compiler/passes/main/tests/fixtures/checker/import_sym.jac +2 -0
  15. jaclang/compiler/passes/main/tests/fixtures/checker/import_sym_test.jac +6 -0
  16. jaclang/compiler/passes/main/tests/fixtures/checker/imported_sym.jac +5 -0
  17. jaclang/compiler/passes/main/tests/fixtures/checker_arg_param_match.jac +37 -0
  18. jaclang/compiler/passes/main/tests/fixtures/checker_arity.jac +18 -0
  19. jaclang/compiler/passes/main/tests/fixtures/checker_binary_op.jac +21 -0
  20. jaclang/compiler/passes/main/tests/fixtures/checker_call_expr_class.jac +12 -0
  21. jaclang/compiler/passes/main/tests/fixtures/checker_cat_is_animal.jac +18 -0
  22. jaclang/compiler/passes/main/tests/fixtures/checker_cyclic_symbol.jac +4 -0
  23. jaclang/compiler/passes/main/tests/fixtures/checker_expr_call.jac +9 -0
  24. jaclang/compiler/passes/main/tests/fixtures/checker_float.jac +7 -0
  25. jaclang/compiler/passes/main/tests/fixtures/checker_import_missing_module.jac +13 -0
  26. jaclang/compiler/passes/main/tests/fixtures/checker_magic_call.jac +17 -0
  27. jaclang/compiler/passes/main/tests/fixtures/checker_mod_path.jac +8 -0
  28. jaclang/compiler/passes/main/tests/fixtures/checker_param_types.jac +11 -0
  29. jaclang/compiler/passes/main/tests/fixtures/checker_self_type.jac +9 -0
  30. jaclang/compiler/passes/main/tests/fixtures/checker_sym_inherit.jac +42 -0
  31. jaclang/compiler/passes/main/tests/fixtures/predynamo_fix3.jac +43 -0
  32. jaclang/compiler/passes/main/tests/fixtures/predynamo_where_assign.jac +13 -0
  33. jaclang/compiler/passes/main/tests/fixtures/predynamo_where_return.jac +11 -0
  34. jaclang/compiler/passes/main/tests/test_checker_pass.py +265 -0
  35. jaclang/compiler/passes/main/tests/test_predynamo_pass.py +57 -0
  36. jaclang/compiler/passes/main/type_checker_pass.py +36 -61
  37. jaclang/compiler/passes/tool/doc_ir_gen_pass.py +204 -44
  38. jaclang/compiler/passes/tool/jac_formatter_pass.py +119 -69
  39. jaclang/compiler/passes/tool/tests/fixtures/corelib_fmt.jac +3 -3
  40. jaclang/compiler/passes/tool/tests/fixtures/general_format_checks/triple_quoted_string.jac +4 -5
  41. jaclang/compiler/passes/tool/tests/fixtures/tagbreak.jac +171 -11
  42. jaclang/compiler/passes/transform.py +12 -8
  43. jaclang/compiler/program.py +14 -6
  44. jaclang/compiler/tests/fixtures/jac_import_py_files.py +4 -0
  45. jaclang/compiler/tests/fixtures/jac_module.jac +3 -0
  46. jaclang/compiler/tests/fixtures/multiple_syntax_errors.jac +10 -0
  47. jaclang/compiler/tests/fixtures/python_module.py +1 -0
  48. jaclang/compiler/tests/test_importer.py +39 -0
  49. jaclang/compiler/tests/test_parser.py +49 -0
  50. jaclang/compiler/type_system/operations.py +104 -0
  51. jaclang/compiler/type_system/type_evaluator.py +470 -47
  52. jaclang/compiler/type_system/type_utils.py +246 -0
  53. jaclang/compiler/type_system/types.py +58 -2
  54. jaclang/compiler/unitree.py +79 -94
  55. jaclang/langserve/engine.jac +253 -230
  56. jaclang/langserve/server.jac +46 -15
  57. jaclang/langserve/tests/fixtures/circle.jac +3 -3
  58. jaclang/langserve/tests/fixtures/circle_err.jac +3 -3
  59. jaclang/langserve/tests/fixtures/circle_pure.test.jac +3 -3
  60. jaclang/langserve/tests/fixtures/completion_test_err.jac +10 -0
  61. jaclang/langserve/tests/server_test/circle_template.jac +80 -0
  62. jaclang/langserve/tests/server_test/glob_template.jac +4 -0
  63. jaclang/langserve/tests/server_test/test_lang_serve.py +154 -312
  64. jaclang/langserve/tests/server_test/utils.py +153 -116
  65. jaclang/langserve/tests/test_dev_server.py +1 -1
  66. jaclang/langserve/tests/test_server.py +30 -86
  67. jaclang/langserve/utils.jac +56 -63
  68. jaclang/runtimelib/machine.py +7 -0
  69. jaclang/runtimelib/meta_importer.py +27 -1
  70. jaclang/runtimelib/tests/fixtures/custom_access_validation.jac +1 -1
  71. jaclang/runtimelib/tests/fixtures/savable_object.jac +2 -2
  72. jaclang/settings.py +18 -14
  73. jaclang/tests/fixtures/abc_check.jac +3 -3
  74. jaclang/tests/fixtures/arch_rel_import_creation.jac +12 -12
  75. jaclang/tests/fixtures/chandra_bugs2.jac +3 -3
  76. jaclang/tests/fixtures/create_dynamic_archetype.jac +13 -13
  77. jaclang/tests/fixtures/jac_run_py_bugs.py +18 -0
  78. jaclang/tests/fixtures/jac_run_py_import.py +13 -0
  79. jaclang/tests/fixtures/lambda_arg_annotation.jac +15 -0
  80. jaclang/tests/fixtures/lambda_self.jac +18 -0
  81. jaclang/tests/fixtures/maxfail_run_test.jac +4 -4
  82. jaclang/tests/fixtures/params/param_syntax_err.jac +9 -0
  83. jaclang/tests/fixtures/params/test_complex_params.jac +42 -0
  84. jaclang/tests/fixtures/params/test_failing_kwonly.jac +207 -0
  85. jaclang/tests/fixtures/params/test_failing_posonly.jac +116 -0
  86. jaclang/tests/fixtures/params/test_failing_varargs.jac +300 -0
  87. jaclang/tests/fixtures/params/test_kwonly_params.jac +29 -0
  88. jaclang/tests/fixtures/py2jac_params.py +8 -0
  89. jaclang/tests/fixtures/run_test.jac +4 -4
  90. jaclang/tests/test_cli.py +103 -18
  91. jaclang/tests/test_language.py +74 -16
  92. jaclang/utils/helpers.py +47 -2
  93. jaclang/utils/module_resolver.py +11 -1
  94. jaclang/utils/test.py +8 -0
  95. jaclang/utils/treeprinter.py +0 -18
  96. {jaclang-0.8.6.dist-info → jaclang-0.8.8.dist-info}/METADATA +3 -3
  97. {jaclang-0.8.6.dist-info → jaclang-0.8.8.dist-info}/RECORD +99 -62
  98. {jaclang-0.8.6.dist-info → jaclang-0.8.8.dist-info}/WHEEL +1 -1
  99. jaclang/compiler/passes/main/inheritance_pass.py +0 -131
  100. jaclang/langserve/dev_engine.jac +0 -645
  101. jaclang/langserve/dev_server.jac +0 -201
  102. jaclang/langserve/tests/server_test/code_test.py +0 -0
  103. {jaclang-0.8.6.dist-info → jaclang-0.8.8.dist-info}/entry_points.txt +0 -0
@@ -6,19 +6,30 @@ PyrightReference:
6
6
  packages/pyright-internal/src/analyzer/typeEvaluatorTypes.ts
7
7
  """
8
8
 
9
+ import ast as py_ast
10
+ import os
9
11
  from dataclasses import dataclass
10
12
  from pathlib import Path
11
- from typing import TYPE_CHECKING, cast
13
+ from typing import Callable, TYPE_CHECKING, cast
12
14
 
13
15
  import jaclang.compiler.unitree as uni
16
+ from jaclang.compiler import TOKEN_MAP
17
+ from jaclang.compiler.constant import Tokens as Tok
18
+ from jaclang.compiler.passes.main.pyast_load_pass import PyastBuildPass
19
+ from jaclang.compiler.passes.main.sym_tab_build_pass import SymTabBuildPass
14
20
  from jaclang.compiler.type_system import types
21
+ from jaclang.runtimelib.utils import read_file_with_encoding
15
22
 
16
23
  if TYPE_CHECKING:
17
24
  from jaclang.compiler.program import JacProgram
18
25
 
19
- from .type_utils import ClassMember
26
+ from . import operations
27
+ from . import type_utils
20
28
  from .types import TypeBase
21
29
 
30
+ # The callback type definition for the diagnostic messages.
31
+ DiagnosticCallback = Callable[[uni.UniNode, str, bool], None]
32
+
22
33
 
23
34
  @dataclass
24
35
  class PrefetchedTypes:
@@ -34,6 +45,7 @@ class PrefetchedTypes:
34
45
  tuple_class: TypeBase | None = None
35
46
  bool_class: TypeBase | None = None
36
47
  int_class: TypeBase | None = None
48
+ float_class: TypeBase | None = None
37
49
  str_class: TypeBase | None = None
38
50
  dict_class: TypeBase | None = None
39
51
  module_type_class: TypeBase | None = None
@@ -44,10 +56,50 @@ class PrefetchedTypes:
44
56
  template_class: TypeBase | None = None
45
57
 
46
58
 
59
+ @dataclass
60
+ class SymbolResolutionStackEntry:
61
+ """Represents a single entry in the symbol resolution stack."""
62
+
63
+ symbol: uni.Symbol
64
+
65
+ # Initially true, it's set to false if a recursion
66
+ # is detected.
67
+ is_result_valid: bool = True
68
+
69
+ # Some limited forms of recursion are allowed. In these
70
+ # cases, a partially-constructed type can be registered.
71
+ partial_type: TypeBase | None = None
72
+
73
+
74
+ @dataclass
75
+ class MatchArgsToParamsResult:
76
+ """Result of matching arguments to parameters."""
77
+
78
+ # FIXME: This class implementation is modified from pyright to make it
79
+ # simple and work for now, however this needs to be revisited and
80
+ # implemented properly.
81
+ arg_params: dict[uni.Expr | uni.KWPair, types.Parameter | None]
82
+
83
+ overload: types.FunctionType | None = None
84
+ argument_errors: bool = False
85
+
86
+
47
87
  class TypeEvaluator:
48
88
  """Type evaluator for JacLang."""
49
89
 
50
- def __init__(self, builtins_module: uni.Module, program: "JacProgram") -> None:
90
+ # NOTE: This is done in the binder pass of pyright, however I'm doing this
91
+ # here, because this will be the entry point of the type checker and we're not
92
+ # relying on the binder pass at the moment and we can go back to binder pass
93
+ # in the future if we need it.
94
+ _BUILTINS_STUB_FILE_PATH = os.path.join(
95
+ os.path.dirname(__file__),
96
+ "../../vendor/typeshed/stdlib/builtins.pyi",
97
+ )
98
+
99
+ def __init__(
100
+ self,
101
+ program: "JacProgram",
102
+ ) -> None:
51
103
  """Initialize the type evaluator with prefetched types.
52
104
 
53
105
  Implementation Note:
@@ -58,14 +110,113 @@ class TypeEvaluator:
58
110
  in some place then it will not be available in the evaluator, So we
59
111
  are prefetching the builtins at the constructor level once.
60
112
  """
61
- self.builtins_module = builtins_module
62
113
  self.program = program
114
+ self.symbol_resolution_stack: list[SymbolResolutionStackEntry] = []
115
+ self.builtins_module = self._load_builtins_stub_module()
63
116
  self.prefetch = self._prefetch_types()
117
+ self.diagnostic_callback: DiagnosticCallback | None = None
118
+
119
+ def _load_builtins_stub_module(self) -> uni.Module:
120
+ """Load and return the builtins stub module."""
121
+ if not os.path.exists(TypeEvaluator._BUILTINS_STUB_FILE_PATH):
122
+ raise FileNotFoundError(
123
+ f"Builtins stub file not found at {TypeEvaluator._BUILTINS_STUB_FILE_PATH}"
124
+ )
125
+ file_content = read_file_with_encoding(TypeEvaluator._BUILTINS_STUB_FILE_PATH)
126
+ uni_source = uni.Source(file_content, TypeEvaluator._BUILTINS_STUB_FILE_PATH)
127
+ mod = PyastBuildPass(
128
+ ir_in=uni.PythonModuleAst(
129
+ py_ast.parse(file_content),
130
+ orig_src=uni_source,
131
+ ),
132
+ prog=self.program,
133
+ ).ir_out
134
+ SymTabBuildPass(ir_in=mod, prog=self.program)
135
+ return mod
136
+
137
+ def _get_builtin_type(self, name: str) -> TypeBase:
138
+ """Return the built-in type with the given name."""
139
+ if (symbol := self.builtins_module.lookup(name)) is not None:
140
+ return self.get_type_of_symbol(symbol)
141
+ return types.UnknownType()
142
+
143
+ def _prefetch_types(self) -> "PrefetchedTypes":
144
+ """Return the prefetched types for the type evaluator."""
145
+ return PrefetchedTypes(
146
+ # TODO: Pyright first try load NoneType from typeshed and if it cannot
147
+ # then it set to unknown type.
148
+ none_type_class=types.UnknownType(),
149
+ object_class=self._get_builtin_type("object"),
150
+ type_class=self._get_builtin_type("type"),
151
+ # union_type_class=
152
+ # awaitable_class=
153
+ # function_class=
154
+ # method_class=
155
+ tuple_class=self._get_builtin_type("tuple"),
156
+ bool_class=self._get_builtin_type("bool"),
157
+ int_class=self._get_builtin_type("int"),
158
+ float_class=self._get_builtin_type("float"),
159
+ str_class=self._get_builtin_type("str"),
160
+ dict_class=self._get_builtin_type("dict"),
161
+ # module_type_class=
162
+ # typed_dict_class=
163
+ # typed_dict_private_class=
164
+ # supports_keys_and_get_item_class=
165
+ # mapping_class=
166
+ # template_class=
167
+ )
168
+
169
+ def add_diagnostic(
170
+ self, node: uni.UniNode, message: str, warning: bool = False
171
+ ) -> None:
172
+ """Add a diagnostic message to the program."""
173
+ if self.diagnostic_callback:
174
+ self.diagnostic_callback(node, message, warning)
175
+
176
+ # -------------------------------------------------------------------------
177
+ # Symbol resolution stack
178
+ # -------------------------------------------------------------------------
179
+
180
+ def get_index_of_symbol_resolution(self, symbol: uni.Symbol) -> int | None:
181
+ """Get the index of a symbol in the resolution stack."""
182
+ for i, entry in enumerate(self.symbol_resolution_stack):
183
+ if entry.symbol == symbol:
184
+ return i
185
+ return None
186
+
187
+ def push_symbol_resolution(self, symbol: uni.Symbol) -> bool:
188
+ """
189
+ Push a symbol onto the resolution stack.
190
+
191
+ Return False if recursion is detected; in that case the symbol is not pushed.
192
+ """
193
+ idx = self.get_index_of_symbol_resolution(symbol)
194
+ if idx is not None:
195
+ # Mark all of the entries between these two as invalid.
196
+ for i in range(idx, len(self.symbol_resolution_stack)):
197
+ entry = self.symbol_resolution_stack[i]
198
+ entry.is_result_valid = False
199
+ return False
200
+ self.symbol_resolution_stack.append(SymbolResolutionStackEntry(symbol=symbol))
201
+ return True
202
+
203
+ def pop_symbol_resolution(self, symbol: uni.Symbol) -> bool:
204
+ """Pop a symbol from the resolution stack."""
205
+ popped_entry = self.symbol_resolution_stack.pop()
206
+ assert popped_entry.symbol == symbol
207
+ return popped_entry.is_result_valid
64
208
 
65
209
  # Pyright equivalent function name = getEffectiveTypeOfSymbol.
66
210
  def get_type_of_symbol(self, symbol: uni.Symbol) -> TypeBase:
67
211
  """Return the effective type of the symbol."""
68
- return self._get_type_of_symbol(symbol)
212
+ if self.push_symbol_resolution(symbol):
213
+ try:
214
+ return self._get_type_of_symbol(symbol)
215
+ finally:
216
+ self.pop_symbol_resolution(symbol)
217
+
218
+ # If we reached here that means we have a cyclic symbolic reference.
219
+ return types.UnknownType()
69
220
 
70
221
  # NOTE: This function doesn't exist in pyright, however it exists as a helper function
71
222
  # for the following functions.
@@ -92,10 +243,13 @@ class TypeEvaluator:
92
243
  mod.parent_scope = self.builtins_module
93
244
  return mod
94
245
 
95
- def get_type_of_module(self, node: uni.ModulePath) -> types.ModuleType:
246
+ def get_type_of_module(self, node: uni.ModulePath) -> types.TypeBase:
96
247
  """Return the effective type of the module."""
97
248
  if node.name_spec.type is not None:
98
249
  return cast(types.ModuleType, node.name_spec.type)
250
+ if not Path(node.resolve_relative_path()).exists():
251
+ node.name_spec.type = types.UnknownType()
252
+ return node.name_spec.type
99
253
 
100
254
  mod: uni.Module = self._import_module_from_path(node.resolve_relative_path())
101
255
  mod_type = types.ModuleType(
@@ -149,6 +303,11 @@ class TypeEvaluator:
149
303
  # import from mod { item }
150
304
  else:
151
305
  mod_type = self.get_type_of_module(import_node.from_loc)
306
+ if not isinstance(mod_type, types.ModuleType):
307
+ node.name_spec.type = types.UnknownType()
308
+ # TODO: Add diagnostic that from_loc is not accessible.
309
+ # Eg: 'Import "scipy" could not be resolved'
310
+ return node.name_spec.type
152
311
  if sym := mod_type.symbol_table.lookup(node.name.value, deep=True):
153
312
  node.name.sym = sym
154
313
  if node.alias:
@@ -164,20 +323,99 @@ class TypeEvaluator:
164
323
  if node.name_spec.type is not None:
165
324
  return cast(types.ClassType, node.name_spec.type)
166
325
 
326
+ base_classes: list[TypeBase] = []
327
+ for base_class in node.base_classes or []:
328
+ base_class_type = self.get_type_of_expression(base_class)
329
+ base_classes.append(base_class_type)
330
+ is_builtin_class = node.find_parent_of_type(uni.Module) == self.builtins_module
331
+
167
332
  cls_type = types.ClassType(
168
333
  types.ClassType.ClassDetailsShared(
169
334
  class_name=node.name_spec.sym_name,
170
335
  symbol_table=node,
171
- # TODO: Resolve the base class expression and pass them here.
336
+ base_classes=base_classes,
337
+ is_builtin_class=is_builtin_class,
172
338
  ),
173
339
  flags=types.TypeFlags.Instantiable,
174
340
  )
175
341
 
342
+ # Compute the MRO for the class.
343
+ type_utils.compute_mro_linearization(cls_type)
344
+
176
345
  # Cache the type, pyright is doing invalidateTypeCacheIfCanceled()
177
346
  # we're not doing that any time soon.
178
347
  node.name_spec.type = cls_type
179
348
  return cls_type
180
349
 
350
+ def get_type_of_ability(self, node: uni.Ability) -> TypeBase:
351
+ """Return the effective type of an ability."""
352
+ if node.name_spec.type is not None:
353
+ return node.name_spec.type
354
+
355
+ if not isinstance(node.signature, uni.FuncSignature):
356
+ node.name_spec.type = types.UnknownType()
357
+ return node.name_spec.type
358
+
359
+ return_type: TypeBase
360
+ if isinstance(node.signature.return_type, uni.Expr):
361
+ return_type = self._convert_to_instance(
362
+ self.get_type_of_expression(node.signature.return_type)
363
+ )
364
+ else:
365
+ return_type = types.UnknownType()
366
+
367
+ # Define helper function for parameter conversion.
368
+ def _get_param_category(param: uni.ParamVar) -> types.ParameterCategory:
369
+ if param.is_vararg:
370
+ return types.ParameterCategory.ArgsList
371
+ if param.is_kwargs:
372
+ return types.ParameterCategory.KwargsDict
373
+ return types.ParameterCategory.Positional
374
+
375
+ # Define helper function for parameter kind conversion.
376
+ def _convert_param_kind(kind: uni.ParamKind) -> types.ParamKind:
377
+ match kind:
378
+ case uni.ParamKind.POSONLY:
379
+ return types.ParamKind.POSONLY
380
+ case uni.ParamKind.NORMAL:
381
+ return types.ParamKind.NORMAL
382
+ case uni.ParamKind.VARARG:
383
+ return types.ParamKind.VARARG
384
+ case uni.ParamKind.KWONLY:
385
+ return types.ParamKind.KWONLY
386
+ case uni.ParamKind.KWARG:
387
+ return types.ParamKind.KWARG
388
+ return types.ParamKind.NORMAL
389
+
390
+ parameters: list[types.Parameter] = []
391
+ for idx, param in enumerate(node.signature.get_parameters()):
392
+ # TODO: Set parameter category for *args, and **kwargs
393
+ param_type: TypeBase | None = None
394
+
395
+ if param.type_tag:
396
+ param_type_cls = self.get_type_of_expression(param.type_tag.tag)
397
+ param_type = self._convert_to_instance(param_type_cls)
398
+
399
+ parameters.append(
400
+ types.Parameter(
401
+ name=param.name.value,
402
+ category=_get_param_category(param),
403
+ param_type=param_type,
404
+ default_value=param.value,
405
+ is_self=(idx == 0 and self._is_expr_self(param.name)),
406
+ param_kind=_convert_param_kind(param.param_kind),
407
+ )
408
+ )
409
+
410
+ func_type = types.FunctionType(
411
+ func_name=node.name_spec.sym_name,
412
+ return_type=return_type,
413
+ parameters=parameters,
414
+ )
415
+
416
+ node.name_spec.type = func_type
417
+ return func_type
418
+
181
419
  def get_type_of_string(self, node: uni.String | uni.MultiString) -> TypeBase:
182
420
  """Return the effective type of the string."""
183
421
  # FIXME: Strings are a type of LiteralString type:
@@ -195,6 +433,11 @@ class TypeEvaluator:
195
433
  assert self.prefetch.int_class is not None
196
434
  return self.prefetch.int_class
197
435
 
436
+ def get_type_of_float(self, node: uni.Float) -> TypeBase:
437
+ """Return the effective type of the float."""
438
+ assert self.prefetch.float_class is not None
439
+ return self.prefetch.float_class
440
+
198
441
  # Pyright equivalent function name = getTypeOfExpression();
199
442
  def get_type_of_expression(self, node: uni.Expr) -> TypeBase:
200
443
  """Return the effective type of the expression."""
@@ -221,9 +464,6 @@ class TypeEvaluator:
221
464
  # NOTE: For now if we don't have the type info, we assume it's compatible.
222
465
  # For strict mode we should disallow usage of unknown unless explicitly ignored.
223
466
  return True
224
- # FIXME: This logic is not valid, just here as a stub.
225
- if types.TypeCategory.Unknown in (src_type.category, dest_type.category):
226
- return True
227
467
 
228
468
  if src_type == dest_type:
229
469
  return True
@@ -233,8 +473,30 @@ class TypeEvaluator:
233
473
  assert isinstance(src_type, types.ClassType)
234
474
  return self._assign_class(src_type, dest_type)
235
475
 
236
- # FIXME: This is temporary.
237
- return src_type == dest_type
476
+ return False
477
+
478
+ # TODO: This should take an argument list as parameter.
479
+ def get_type_of_magic_method_call(
480
+ self, obj_type: TypeBase, method_name: str
481
+ ) -> TypeBase | None:
482
+ """Return the effective return type of a magic method call."""
483
+ if obj_type.category == types.TypeCategory.Class:
484
+ # TODO: getTypeOfBoundMember() <-- Implement this if needed, for the simple case
485
+ # we'll directly call member lookup.
486
+ #
487
+ # WE'RE DEVIATING FROM PYRIGHT FOR THIS METHOD HEAVILY; HOWEVER, THIS CAN BE RE-WRITTEN IF NEEDED.
488
+ #
489
+ assert isinstance(obj_type, types.ClassType) # <-- To make typecheck happy.
490
+ if member := self._lookup_class_member(obj_type, method_name):
491
+ member_ty = self.get_type_of_symbol(member.symbol)
492
+ if isinstance(member_ty, types.FunctionType):
493
+ return member_ty.return_type
494
+ # If we reached here, magic method is not a function.
495
+ # 1. recursively check __call__() on the type, TODO
496
+ # 2. if any or unknown, return getUnknownTypeForCallable() TODO
497
+ # 3. return undefined.
498
+ return None
499
+ return None
238
500
 
239
501
  def _assign_class(
240
502
  self, src_type: types.ClassType, dest_type: types.ClassType
@@ -243,39 +505,23 @@ class TypeEvaluator:
243
505
  if src_type.shared == dest_type.shared:
244
506
  return True
245
507
 
246
- # TODO: Search base classes and everything else pyright is doing.
247
- return False
508
+ # Check if src class is a subclass of dest class.
509
+ for base_cls in src_type.shared.mro:
510
+ if base_cls.shared == dest_type.shared:
511
+ return True
248
512
 
249
- def _prefetch_types(self) -> "PrefetchedTypes":
250
- """Return the prefetched types for the type evaluator."""
251
- return PrefetchedTypes(
252
- # TODO: Pyright first try load NoneType from typeshed and if it cannot
253
- # then it set to unknown type.
254
- none_type_class=types.UnknownType(),
255
- object_class=self._get_builtin_type("object"),
256
- type_class=self._get_builtin_type("type"),
257
- # union_type_class=
258
- # awaitable_class=
259
- # function_class=
260
- # method_class=
261
- tuple_class=self._get_builtin_type("tuple"),
262
- bool_class=self._get_builtin_type("bool"),
263
- int_class=self._get_builtin_type("int"),
264
- str_class=self._get_builtin_type("str"),
265
- dict_class=self._get_builtin_type("dict"),
266
- # module_type_class=
267
- # typed_dict_class=
268
- # typed_dict_private_class=
269
- # supports_keys_and_get_item_class=
270
- # mapping_class=
271
- # template_class=
272
- )
513
+ # Everything is assignable to an object.
514
+ if dest_type.is_builtin("object"):
515
+ # TODO: Invariance not handled yet
516
+ # invariant contexts to avoid list[int] <: list[object] errors.
517
+ return True
273
518
 
274
- def _get_builtin_type(self, name: str) -> TypeBase:
275
- """Return the built-in type with the given name."""
276
- if (symbol := self.builtins_module.lookup(name)) is not None:
277
- return self.get_type_of_symbol(symbol)
278
- return types.UnknownType()
519
+ # Integers can be used where floats are expected.
520
+ if src_type.is_builtin("int") and dest_type.is_builtin("float"):
521
+ return True
522
+
523
+ # TODO: Search base classes and everything else pyright is doing.
524
+ return False
279
525
 
280
526
  # This function is a combination of the below pyright functions.
281
527
  # - getDeclaredTypeOfSymbol
@@ -298,6 +544,14 @@ class TypeEvaluator:
298
544
  case uni.Archetype():
299
545
  return self.get_type_of_class(node)
300
546
 
547
+ case uni.Ability():
548
+ return self.get_type_of_ability(node)
549
+
550
+ case uni.ParamVar():
551
+ if node.type_tag:
552
+ annotation_type = self.get_type_of_expression(node.type_tag.tag)
553
+ return self._convert_to_instance(annotation_type)
554
+
301
555
  # This actually defined in the function getTypeForDeclaration();
302
556
  # Pyright has DeclarationType.Variable.
303
557
  case uni.Name():
@@ -335,6 +589,9 @@ class TypeEvaluator:
335
589
  case uni.Int():
336
590
  return self._convert_to_instance(self.get_type_of_int(expr))
337
591
 
592
+ case uni.Float():
593
+ return self._convert_to_instance(self.get_type_of_float(expr))
594
+
338
595
  case uni.AtomTrailer():
339
596
  # NOTE: Pyright is using CFG to figure out the member type by narrowing the base
340
597
  # type and filtering the members. We're not doing that anytime soon.
@@ -376,13 +633,71 @@ class TypeEvaluator:
376
633
  else: # <expr>[<expr>]
377
634
  pass # TODO:
378
635
 
636
+ case uni.AtomUnit():
637
+ return self.get_type_of_expression(expr.value)
638
+
639
+ case uni.FuncCall():
640
+ return self.validate_call_args(expr)
641
+
642
+ case uni.BinaryExpr():
643
+ return operations.get_type_of_binary_operation(self, expr)
644
+
379
645
  case uni.Name():
646
+ # NOTE: For self's type pyright is getting the first parameter of a method and
647
+ # the name can be anything not just self, however we don't have the first parameter
648
+ # and self is a keyword, we need to do it in this way.
649
+ if self._is_expr_self(expr):
650
+ return self._get_type_of_self(expr)
651
+
380
652
  if symbol := expr.sym_tab.lookup(expr.value, deep=True):
653
+ expr.sym = symbol
381
654
  return self.get_type_of_symbol(symbol)
382
655
 
383
656
  # TODO: More expressions.
384
657
  return types.UnknownType()
385
658
 
659
+ # -----------------------------------------------------------------------------
660
+ # Helper functions
661
+ # -----------------------------------------------------------------------------
662
+
663
+ def _is_expr_self(self, expr: uni.Expr) -> bool:
664
+ """Check if the expression is Name that is 'self' and in the method context."""
665
+ if (
666
+ isinstance(expr, uni.Name)
667
+ and (expr.value == TOKEN_MAP[Tok.KW_SELF])
668
+ and (fn := self._get_enclosing_method(expr))
669
+ and (not fn.is_static)
670
+ ):
671
+ return True
672
+ return False
673
+
674
+ def _get_enclosing_function(self, node: uni.UniNode) -> uni.Ability | None:
675
+ """Get the enclosing function (ability) of the given node."""
676
+ if (impl := node.find_parent_of_type(uni.ImplDef)) and (
677
+ isinstance(impl.decl_link, uni.Ability)
678
+ ):
679
+ return impl.decl_link
680
+ return node.find_parent_of_type(uni.Ability)
681
+
682
+ def _get_enclosing_method(self, node: uni.UniNode) -> uni.Ability | None:
683
+ """Get the enclosing method (ability) of the given node."""
684
+ enclosing_fn = self._get_enclosing_function(node)
685
+ while enclosing_fn and (not enclosing_fn.is_method):
686
+ enclosing_fn = self._get_enclosing_function(enclosing_fn)
687
+ if enclosing_fn and enclosing_fn.is_method:
688
+ return enclosing_fn
689
+ return None
690
+
691
+ def _get_type_of_self(self, node: uni.Name) -> TypeBase:
692
+ """Return the effective type of self."""
693
+ if method := self._get_enclosing_method(node):
694
+ cls = method.method_owner
695
+ if isinstance(cls, uni.Archetype):
696
+ return self.get_type_of_class(cls).clone_as_instance()
697
+ if isinstance(cls, uni.Enum):
698
+ pass # TODO: Implement type from enum.
699
+ return types.UnknownType()
700
+
386
701
  def _convert_to_instance(self, jtype: TypeBase) -> TypeBase:
387
702
  """Convert a class type to an instance type."""
388
703
  # TODO: Grep pyright "Handle type[x] as a special case." They handle `type[x]` as a special case:
@@ -397,7 +712,7 @@ class TypeEvaluator:
397
712
 
398
713
  def _lookup_class_member(
399
714
  self, base_type: types.ClassType, member: str
400
- ) -> ClassMember | None:
715
+ ) -> type_utils.ClassMember | None:
401
716
  """Lookup the class member type."""
402
717
  assert self.prefetch.int_class is not None
403
718
  # FIXME: Pyright's way: Implement class member iterator (based on mro and the multiple inheritance)
@@ -405,13 +720,14 @@ class TypeEvaluator:
405
720
 
406
721
  # NOTE: This is a simple implementation to make it work and more robust implementation will
407
722
  # be done in a future PR.
408
- if sym := base_type.lookup_member_symbol(member):
409
- return ClassMember(sym, base_type)
723
+ for cls in base_type.shared.mro:
724
+ if sym := cls.lookup_member_symbol(member):
725
+ return type_utils.ClassMember(sym, cls)
410
726
  return None
411
727
 
412
728
  def _lookup_object_member(
413
729
  self, base_type: types.ClassType, member: str
414
- ) -> ClassMember | None:
730
+ ) -> type_utils.ClassMember | None:
415
731
  """Lookup the object member type."""
416
732
  assert self.prefetch.int_class is not None
417
733
  if base_type.is_class_instance():
@@ -419,3 +735,110 @@ class TypeEvaluator:
419
735
  # TODO: We need to implement Member lookup flags and set SkipInstanceMember to 0.
420
736
  return self._lookup_class_member(base_type, member)
421
737
  return None
738
+
739
+ def match_args_to_params(
740
+ self, expr: uni.FuncCall, func_type: types.FunctionType
741
+ ) -> MatchArgsToParamsResult:
742
+ """
743
+ Match arguments passed to a function to the corresponding parameters in that function.
744
+
745
+ This matching is done based on positions and keywords. Type evaluation and
746
+ validation is left to the caller.
747
+ This logic is based on PEP 3102: https://www.python.org/dev/peps/pep-3102/
748
+ """
749
+ arg_params: dict[uni.Expr | uni.KWPair, types.Parameter | None] = {}
750
+ argument_errors = False
751
+
752
+ params_to_match = func_type.parameters.copy()
753
+
754
+ # Skip `self` for method calls.
755
+ if len(func_type.parameters) >= 1 and func_type.parameters[0].is_self:
756
+ params_to_match.pop(0)
757
+
758
+ # Create a tracker for parameter assignment.
759
+ param_tracker = type_utils.ParamAssignmentTracker(params_to_match)
760
+
761
+ # We iterate over the arguments and match with the parameter, the param_tracker will
762
+ # keep track of the matched parameters and unmatched required parameters.
763
+ #
764
+ # Tracker: p1, p2, /, p3, p4, *args, | p6, **kwargs
765
+ # ^ ^ ^ ^ ^^ ^ | ^ ^^
766
+ # | | | | | \__ \__ | | | \________
767
+ # Args: a1, a2, a3, p4=a4, a5, a6, *a7, | p6=a8, p_kw1=a9, p_kw2=a10
768
+ # '--------------------------------' | '------------------------'
769
+ # We match positional with | We match named arguments with
770
+ # tracked parameter index. | param name lookup.
771
+ #
772
+ for arg in expr.params:
773
+ try:
774
+ if isinstance(arg, uni.KWPair):
775
+ # Match parameter based on name lookup.
776
+ matching_param = param_tracker.match_named_argument(arg)
777
+ arg_params[arg] = matching_param
778
+ else: # Match parameter based on the position of the argument.
779
+ matching_param = param_tracker.match_positional_argument(arg)
780
+ arg_params[arg] = matching_param
781
+ except Exception as e:
782
+ self.add_diagnostic(arg, str(e))
783
+ argument_errors = True
784
+
785
+ if unmatched_params := param_tracker.get_unmatched_required_params():
786
+ names = ", ".join(f"'{p.name}'" for p in unmatched_params)
787
+ argument_errors = True
788
+ self.add_diagnostic(
789
+ expr,
790
+ f"Not all required parameters were provided in the function call: {names}",
791
+ )
792
+
793
+ return MatchArgsToParamsResult(
794
+ arg_params=arg_params, argument_errors=argument_errors
795
+ )
796
+
797
+ def validate_call_args(self, expr: uni.FuncCall) -> TypeBase:
798
+ """
799
+ Validate that the arguments can be assigned to the call's parameter list.
800
+
801
+ Specializes the call based on arg types, and returns the specialized
802
+ type of the return value. If it detects an error along the way, it emits
803
+ a diagnostic and sets argumentErrors to true.
804
+ """
805
+ caller_type = self.get_type_of_expression(expr.target)
806
+ if isinstance(caller_type, types.FunctionType):
807
+ arg_param_match = self.match_args_to_params(expr, caller_type)
808
+ if not arg_param_match.argument_errors:
809
+ self.validate_arg_types(arg_param_match)
810
+ return caller_type.return_type or types.UnknownType()
811
+
812
+ if (
813
+ isinstance(caller_type, types.ClassType)
814
+ and caller_type.is_instantiable_class()
815
+ ):
816
+ # TODO: validate args for __init__()
817
+ return caller_type.clone_as_instance()
818
+
819
+ if caller_type.is_class_instance():
820
+ # TODO: validate args.
821
+ magic_call_ret = self.get_type_of_magic_method_call(caller_type, "__call__")
822
+ if magic_call_ret:
823
+ return magic_call_ret
824
+
825
+ return types.UnknownType()
826
+
827
+ def validate_arg_types(
828
+ self,
829
+ args: MatchArgsToParamsResult,
830
+ ) -> None:
831
+ """Validate that the argument types can be assigned to the parameter types."""
832
+ for arg, param in args.arg_params.items():
833
+ if param is None or param.param_type is None:
834
+ continue
835
+ if isinstance(arg, uni.KWPair):
836
+ arg_type = self.get_type_of_expression(arg.value)
837
+ else:
838
+ arg_type = self.get_type_of_expression(arg)
839
+
840
+ if not self.assign_type(arg_type, param.param_type):
841
+ self.add_diagnostic(
842
+ arg,
843
+ f"Cannot assign {arg_type} to parameter '{param.name}' of type {param.param_type}",
844
+ )