cinderx 2026.1.16.2__cp314-cp314-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68) hide show
  1. __static__/__init__.py +641 -0
  2. __static__/compiler_flags.py +8 -0
  3. __static__/enum.py +160 -0
  4. __static__/native_utils.py +77 -0
  5. __static__/type_code.py +48 -0
  6. __strict__/__init__.py +39 -0
  7. _cinderx.so +0 -0
  8. cinderx/__init__.py +577 -0
  9. cinderx/__pycache__/__init__.cpython-314.pyc +0 -0
  10. cinderx/_asyncio.py +156 -0
  11. cinderx/compileall.py +710 -0
  12. cinderx/compiler/__init__.py +40 -0
  13. cinderx/compiler/__main__.py +137 -0
  14. cinderx/compiler/config.py +7 -0
  15. cinderx/compiler/consts.py +72 -0
  16. cinderx/compiler/debug.py +70 -0
  17. cinderx/compiler/dis_stable.py +283 -0
  18. cinderx/compiler/errors.py +151 -0
  19. cinderx/compiler/flow_graph_optimizer.py +1287 -0
  20. cinderx/compiler/future.py +91 -0
  21. cinderx/compiler/misc.py +32 -0
  22. cinderx/compiler/opcode_cinder.py +18 -0
  23. cinderx/compiler/opcode_static.py +100 -0
  24. cinderx/compiler/opcodebase.py +158 -0
  25. cinderx/compiler/opcodes.py +991 -0
  26. cinderx/compiler/optimizer.py +547 -0
  27. cinderx/compiler/pyassem.py +3711 -0
  28. cinderx/compiler/pycodegen.py +7660 -0
  29. cinderx/compiler/pysourceloader.py +62 -0
  30. cinderx/compiler/static/__init__.py +1404 -0
  31. cinderx/compiler/static/compiler.py +629 -0
  32. cinderx/compiler/static/declaration_visitor.py +335 -0
  33. cinderx/compiler/static/definite_assignment_checker.py +280 -0
  34. cinderx/compiler/static/effects.py +160 -0
  35. cinderx/compiler/static/module_table.py +666 -0
  36. cinderx/compiler/static/type_binder.py +2176 -0
  37. cinderx/compiler/static/types.py +10580 -0
  38. cinderx/compiler/static/util.py +81 -0
  39. cinderx/compiler/static/visitor.py +91 -0
  40. cinderx/compiler/strict/__init__.py +69 -0
  41. cinderx/compiler/strict/class_conflict_checker.py +249 -0
  42. cinderx/compiler/strict/code_gen_base.py +409 -0
  43. cinderx/compiler/strict/common.py +507 -0
  44. cinderx/compiler/strict/compiler.py +352 -0
  45. cinderx/compiler/strict/feature_extractor.py +130 -0
  46. cinderx/compiler/strict/flag_extractor.py +97 -0
  47. cinderx/compiler/strict/loader.py +827 -0
  48. cinderx/compiler/strict/preprocessor.py +11 -0
  49. cinderx/compiler/strict/rewriter/__init__.py +5 -0
  50. cinderx/compiler/strict/rewriter/remove_annotations.py +84 -0
  51. cinderx/compiler/strict/rewriter/rewriter.py +975 -0
  52. cinderx/compiler/strict/runtime.py +77 -0
  53. cinderx/compiler/symbols.py +1754 -0
  54. cinderx/compiler/unparse.py +414 -0
  55. cinderx/compiler/visitor.py +194 -0
  56. cinderx/jit.py +230 -0
  57. cinderx/opcode.py +202 -0
  58. cinderx/static.py +113 -0
  59. cinderx/strictmodule.py +6 -0
  60. cinderx/test_support.py +341 -0
  61. cinderx-2026.1.16.2.dist-info/METADATA +15 -0
  62. cinderx-2026.1.16.2.dist-info/RECORD +68 -0
  63. cinderx-2026.1.16.2.dist-info/WHEEL +6 -0
  64. cinderx-2026.1.16.2.dist-info/licenses/LICENSE +21 -0
  65. cinderx-2026.1.16.2.dist-info/top_level.txt +5 -0
  66. opcodes/__init__.py +0 -0
  67. opcodes/assign_opcode_numbers.py +272 -0
  68. opcodes/cinderx_opcodes.py +121 -0
@@ -0,0 +1,2176 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ # pyre-strict
4
+
5
+ from __future__ import annotations
6
+
7
+ import ast
8
+ from ast import (
9
+ And,
10
+ AnnAssign,
11
+ Assign,
12
+ AST,
13
+ AsyncFor,
14
+ AsyncFunctionDef,
15
+ Attribute,
16
+ AugAssign,
17
+ Await,
18
+ BinOp,
19
+ BoolOp,
20
+ Call,
21
+ ClassDef,
22
+ Compare,
23
+ Constant,
24
+ DictComp,
25
+ expr,
26
+ For,
27
+ FormattedValue,
28
+ FunctionDef,
29
+ GeneratorExp,
30
+ If,
31
+ IfExp,
32
+ Import,
33
+ ImportFrom,
34
+ Is,
35
+ IsNot,
36
+ JoinedStr,
37
+ Lambda,
38
+ ListComp,
39
+ Match,
40
+ MatchAs,
41
+ MatchClass,
42
+ MatchMapping,
43
+ MatchOr,
44
+ MatchSequence,
45
+ MatchSingleton,
46
+ MatchStar,
47
+ MatchValue,
48
+ Module,
49
+ Name,
50
+ Return,
51
+ SetComp,
52
+ Slice,
53
+ Starred,
54
+ Subscript,
55
+ Try,
56
+ UnaryOp,
57
+ While,
58
+ Yield,
59
+ YieldFrom,
60
+ )
61
+ from collections.abc import Generator, Sequence
62
+ from contextlib import contextmanager
63
+ from enum import IntEnum
64
+ from typing import Optional, TYPE_CHECKING
65
+
66
+ from ..consts import SC_CELL, SC_FREE, SC_GLOBAL_EXPLICIT, SC_GLOBAL_IMPLICIT, SC_LOCAL
67
+ from ..errors import CollectingErrorSink, TypedSyntaxError
68
+ from ..symbols import NodeWithTypeParams, SymbolVisitor
69
+ from .declaration_visitor import GenericVisitor
70
+ from .effects import NarrowingEffect, NO_EFFECT, TypeState
71
+ from .module_table import ModuleFlag, ModuleTable
72
+ from .types import (
73
+ access_path,
74
+ BoolClass,
75
+ Callable,
76
+ CheckedDictInstance,
77
+ CheckedListInstance,
78
+ CInstance,
79
+ Class,
80
+ ClassVar,
81
+ CType,
82
+ Dataclass,
83
+ EnumType,
84
+ FinalClass,
85
+ Function,
86
+ FunctionContainer,
87
+ GenericClass,
88
+ InitVar,
89
+ IsInstanceEffect,
90
+ KnownBoolean,
91
+ MethodType,
92
+ ModuleInstance,
93
+ NestedFunctionClass,
94
+ Object,
95
+ OptionalInstance,
96
+ resolve_assign_error_msg,
97
+ resolve_instance_attr_by_name,
98
+ Slot,
99
+ TMP_VAR_PREFIX,
100
+ TransientDecoratedMethod,
101
+ TransparentDecoratedMethod,
102
+ TType,
103
+ TypeDescr,
104
+ TypeEnvironment,
105
+ TypeName,
106
+ UnionInstance,
107
+ Value,
108
+ )
109
+ from .util import make_qualname
110
+
111
+ if TYPE_CHECKING:
112
+ from .compiler import Compiler
113
+
114
+
115
class PreserveRefinedFields:
    """Marker type used as a node-data key.

    ``TypeBinder.visit`` skips clearing the refined-field state for any node
    that has data of this type attached.
    """
117
+
118
+
119
class UsedRefinementField:
    """Plain record holding a field name and two boolean flags."""

    def __init__(self, name: str, is_source: bool, is_used: bool) -> None:
        self.name, self.is_source, self.is_used = name, is_source, is_used

    def __repr__(self) -> str:
        # Same rendering as an f-string with bare ``{...}`` placeholders.
        return "UsedRefinementField(name={}, is_source={}, is_used={})".format(
            self.name, self.is_source, self.is_used
        )
127
+
128
+
129
+ PRESERVE_REFINED_FIELDS = PreserveRefinedFields()
130
+
131
+
132
class BindingScope:
    """Typing state tracked for one scope while binding.

    Stores the scope's AST node, a fresh ``TypeState``, the declared types by
    name, the type environment, and the scope's name and qualified name.
    """

    name: str
    qualname: str | None

    def __init__(
        self,
        node: AST,
        type_env: TypeEnvironment,
        parent_qualname: str | None,
    ) -> None:
        self.node = node
        self.type_state = TypeState()
        self.decl_types: dict[str, TypeDeclaration] = {}
        self.type_env: TypeEnvironment = type_env
        # Only class and function definitions contribute their own name.
        named_nodes = (ast.ClassDef, ast.FunctionDef, ast.AsyncFunctionDef)
        self.name = node.name if isinstance(node, named_nodes) else "<unknown>"
        self.qualname = make_qualname(parent_qualname, self.name)

    def branch(self) -> LocalsBranch:
        """Return a ``LocalsBranch`` bound to this scope."""
        return LocalsBranch(self)

    def declare(
        self, name: str, typ: Value, is_final: bool = False, is_inferred: bool = False
    ) -> TypeDeclaration:
        """Record a declaration for ``name`` and set its current local type.

        For an unannotated assignment (``is_inferred=True``) the declared type
        is dynamic; this disallows later re-declaration but allows any type to
        be assigned later, so `x = None; if flag: x = "foo"` works. The local
        type is always set to ``typ`` itself.
        """
        declared_type = self.type_env.DYNAMIC if is_inferred else typ
        declaration = TypeDeclaration(declared_type, is_final)
        self.decl_types[name] = declaration
        self.type_state.local_types[name] = typ
        return declaration
165
+
166
+
167
class EnumBindingScope(BindingScope):
    """Binding scope for an enum class body.

    Every declaration is also forwarded to the enum type via
    ``bind_enum_value``.
    """

    def __init__(
        self,
        node: AST,
        type_env: TypeEnvironment,
        parent_qualname: str | None,
        enum_type: EnumType,
    ) -> None:
        super().__init__(node, type_env, parent_qualname)
        self.enum_type = enum_type

    def declare(
        self, name: str, typ: Value, is_final: bool = False, is_inferred: bool = False
    ) -> TypeDeclaration:
        """Bind ``name`` on the enum type, then declare it as usual."""
        self.enum_type.bind_enum_value(name, typ)
        declaration = super().declare(name, typ, is_final, is_inferred)
        return declaration
183
+
184
+
185
class ModuleBindingScope(BindingScope):
    """Binding scope for a module; named ``<module>`` with no qualname."""

    def __init__(
        self,
        node: ast.Module,
        module: ModuleTable,
        type_env: TypeEnvironment,
    ) -> None:
        # ``module`` is accepted but not used by this initializer.
        super().__init__(node, type_env, None)
        self.name = "<module>"
        self.qualname: str | None = None

    def declare(
        self, name: str, typ: Value, is_final: bool = False, is_inferred: bool = False
    ) -> TypeDeclaration:
        """Declare ``name`` at module scope.

        At module scope we set a declared type even without an annotation, but
        we don't want to infer the exact type; it should still be possible to
        reassign to a subtype. An inferred type is therefore widened with
        ``nonliteral().inexact()`` and then declared as if it were explicit.
        """
        if is_inferred:
            widened = typ.nonliteral().inexact()
            return super().declare(
                name, widened, is_final=is_final, is_inferred=False
            )
        return super().declare(name, typ, is_final=is_final, is_inferred=is_inferred)
206
+
207
+
208
class LocalsBranch:
    """Handles branching and merging local variable types"""

    def __init__(self, scope: BindingScope) -> None:
        self.scope = scope
        self.type_env: TypeEnvironment = scope.type_env
        # Snapshot of the scope's type state taken when the branch is created.
        self.entry_type_state: TypeState = scope.type_state.copy()

    def copy(self) -> TypeState:
        """Return a copy of the scope's current type state."""
        current = self.scope.type_state
        return current.copy()

    def restore(self, state: TypeState | None = None) -> None:
        """Replace the scope's type state with ``state``.

        When ``state`` is falsy, a copy of the entry snapshot is used instead.
        """
        self.scope.type_state = state if state else self.entry_type_state.copy()

    def merge(self, entry_type_state: TypeState | None = None) -> None:
        """Merge the entry type state, or a specific copy, into the current type state"""
        # TODO: What about del's?
        entry = self.entry_type_state if entry_type_state is None else entry_type_state
        current = self.scope.type_state
        local_types = current.local_types
        refined_fields = current.refined_fields

        # Locals: join types that differ, drop names unknown to the entry state.
        entry_locals = entry.local_types
        dropped_locals = []
        for name, typ in local_types.items():
            if name not in entry_locals:
                dropped_locals.append(name)
                continue
            entry_typ = entry_locals[name]
            if typ != entry_typ:
                local_types[name] = self._join(typ, entry_typ)
        for name in dropped_locals:
            del local_types[name]

        # Unlike local types, we can't simply join the types here, since absence of
        # a refined field indicates that we should use the coarser type. Instead,
        # remove the refinement if something isn't refined at the entry.
        entry_refined = entry.refined_fields
        for name in [n for n in refined_fields if n not in entry_refined]:
            del refined_fields[name]
        for name, fields in refined_fields.items():
            entry_fields = entry_refined[name]
            for attr in [a for a in fields if a not in entry_fields]:
                del fields[attr]
            # Every remaining attribute is also present in the entry dict.
            for attr, (typ, idx, nodes) in fields.items():
                entry_typ, _, entry_nodes = entry_fields[attr]
                fields[attr] = (
                    self._join(entry_typ, typ),
                    idx,
                    entry_nodes | nodes,
                )

    def changed(self) -> bool:
        """Report whether the scope's type state differs from the entry snapshot.

        Refined fields not in the entry state aren't considered.
        """
        entry = self.entry_type_state
        current = self.scope.type_state
        refinement_changed = any(
            name in current.refined_fields and current.refined_fields[name] != fields
            for name, fields in entry.refined_fields.items()
        )
        if refinement_changed:
            return True
        return entry.local_types != current.local_types

    def _join(self, *types: Value) -> Value:
        """Return the single type given, or the union instance of several."""
        if len(types) == 1:
            (only,) = types
            return only

        members = tuple(t.klass.inexact_type() for t in types)
        return self.type_env.get_union(members).instance
290
+
291
+
292
class TypeDeclaration:
    """Declared type of a name together with its ``is_final`` flag."""

    def __init__(self, typ: Value, is_final: bool = False) -> None:
        self.type, self.is_final = typ, is_final
296
+
297
+
298
class TerminalKind(IntEnum):
    # Values stored per AST node in ``TypeBinder.terminals``.
    # NOTE(review): presumably describes how a statement ends control flow
    # (falls through / break-or-continue / raise-or-return) - confirm at the
    # use sites, which are outside this view.
    NonTerminal = 0
    BreakOrContinue = 1
    RaiseOrReturn = 2
302
+
303
+
304
+ class TypeBinder(GenericVisitor[Optional[NarrowingEffect]]):
305
+ """Walks an AST and produces an optionally strongly typed AST, reporting errors when
306
+ operations are occurring that are not sound. Strong types are based upon places where
307
+ annotations occur which opt-in the strong typing"""
308
+
309
def __init__(
    self,
    symbols: SymbolVisitor,
    filename: str,
    compiler: Compiler,
    module_name: str,
    optimize: int,
    enable_patching: bool = False,
) -> None:
    """Set up binder state for the module named ``module_name``.

    The module table is looked up with ``compiler[module_name]`` and handed to
    the base visitor. ``filename`` is accepted but not stored here.
    """
    super().__init__(compiler[module_name])

    # Inputs supplied by the caller.
    self.symbols = symbols
    self.optimize = optimize
    self.enable_patching = enable_patching

    # State shared with the compiler.
    self.modules: dict[str, ModuleTable] = compiler.modules
    self.type_env: TypeEnvironment = compiler.type_env

    # Bookkeeping that starts empty and is filled while visiting.
    self.scopes: list[BindingScope] = []
    self.terminals: dict[AST, TerminalKind] = {}
    self.inline_depth = 0
    self.inline_calls = 0
    self.current_loop: AST | None = None
    self.loop_may_break: set[AST] = set()
    self.visiting_assignment_target = False
    self._refined_tmpvar_indices: dict[str, int] = {}
333
+
334
@property
def nodes_default_dynamic(self) -> bool:
    """True when the error sink is non-throwing."""
    # If we have a non-throwing ErrorSink, then we may miss typing some
    # nodes on error, so default them to dynamic silently.
    return not self.error_sink.throwing
339
+
340
@property
def type_state(self) -> TypeState:
    """Type state of the innermost binding scope."""
    return self.binding_scope.type_state
343
+
344
@property
def decl_types(self) -> dict[str, TypeDeclaration]:
    """Declared types, keyed by name, of the innermost binding scope."""
    return self.binding_scope.decl_types
347
+
348
@property
def binding_scope(self) -> BindingScope:
    """Innermost binding scope: the last entry of ``self.scopes``."""
    return self.scopes[-1]
351
+
352
@property
def scope(self) -> AST:
    """AST node of the innermost binding scope."""
    return self.binding_scope.node
355
+
356
@property
def context_qualname(self) -> str:
    """Qualified name of the innermost scope, or ``""`` when it has none."""
    qualname = self.binding_scope.qualname
    return qualname if qualname else ""
359
+
360
def maybe_set_local_type(self, name: str, local_type: Value) -> Value:
    """Set the local type of ``name`` and return the type actually stored.

    ``local_type`` is kept only when it is not the dynamic type and the
    declared type's class can be narrowed; otherwise the declared type is
    stored instead. ``name`` must already have a declaration.
    """
    decl = self.get_target_decl(name)
    assert decl is not None
    declared = decl.type
    keep_narrowed = (
        local_type is not self.type_env.DYNAMIC and declared.klass.can_be_narrowed
    )
    stored = local_type if keep_narrowed else declared
    self.type_state.local_types[name] = stored
    return stored
368
+
369
def maybe_get_current_class(self) -> Class | None:
    """Return the ``Class`` bound to the current scope, or None.

    None is returned when the innermost scope's node is not a class
    definition.
    """
    node = self.scope
    if not isinstance(node, ClassDef):
        return None
    res = self.get_type(node)
    assert isinstance(res, Class)
    return res
375
+
376
def maybe_get_current_enclosing_class(self) -> Class | None:
    """Return the ``Class`` of the nearest enclosing class scope, or None.

    Scopes are searched from innermost to outermost. Only the first class
    definition found is considered; if its bound type is not a ``Class`` the
    result is None.
    """
    enclosing = next(
        (s.node for s in reversed(self.scopes) if isinstance(s.node, ClassDef)),
        None,
    )
    if enclosing is None:
        return None
    res = self.get_type(enclosing)
    if isinstance(res, Class):
        return res
    return None
382
+
383
def visit(self, node: AST, *args: object) -> NarrowingEffect | None:
    """Visit ``node`` through the base visitor and return its narrowing effect.

    After the visit, all refined fields of the current type state are cleared
    unless no scope is active, ``node`` is not an AST node, or the node carries
    ``PreserveRefinedFields`` data. The override also gives Pyre the return
    type information.
    """
    ret = super().visit(node, *args)

    if self.scopes and isinstance(node, AST):
        preserve = self.get_opt_node_data(node, PreserveRefinedFields)
        if not preserve:
            self.type_state.refined_fields.clear()

    if ret is None:
        return None
    assert isinstance(ret, NarrowingEffect)
    return ret
397
+
398
def get_final_literal(self, node: AST) -> ast.Constant | None:
    """Ask the module table for the final literal of ``node``.

    The lookup is given the symbol scope that belongs to the current scope
    node.
    """
    symbol_scope = self.symbols.scopes[self.scope]
    return self.module.get_final_literal(node, symbol_scope)
400
+
401
def declare_local(
    self,
    name: str,
    typ: Value,
    is_final: bool = False,
    is_inferred: bool = False,
) -> None:
    """Declare ``name`` in the current binding scope.

    Raises:
        TypedSyntaxError: if ``name`` is already declared and either the new
            type or the previously declared type is a nominal type.

    For ``CInstance`` types, ``check_primitive_scope`` is run first.
    """
    if name in self.decl_types:
        previous = self.decl_types[name].type
        if typ.is_nominal_type or previous.is_nominal_type:
            raise TypedSyntaxError(f"Cannot redefine local variable {name}")

    if isinstance(typ, CInstance):
        self.check_primitive_scope(name)

    scope = self.binding_scope
    scope.declare(name, typ, is_final=is_final, is_inferred=is_inferred)
418
+
419
def check_static_import_flags(self, node: Module) -> None:
    """Collect module flags from ``__static__.compiler_flags`` imports.

    Walks the module body in order. The first string-constant expression is
    treated as the docstring; any other expression statement, including a
    second string constant, stops the scan. ``from __static__.compiler_flags
    import ...`` adds ``CHECKED_DICTS`` / ``CHECKED_LISTS`` for the matching
    names. Every other statement is passed over.
    """
    saw_doc_str = False
    for stmt in node.body:
        if isinstance(stmt, ast.Expr):
            val = stmt.value
            is_string = isinstance(val, ast.Constant) and isinstance(val.value, str)
            if not is_string or saw_doc_str:
                break
            saw_doc_str = True
            continue
        if isinstance(stmt, ast.Import):
            continue
        if (
            isinstance(stmt, ast.ImportFrom)
            and stmt.module == "__static__.compiler_flags"
        ):
            flags_by_name = {
                "checked_dicts": ModuleFlag.CHECKED_DICTS,
                "checked_lists": ModuleFlag.CHECKED_LISTS,
            }
            for alias in stmt.names:
                flag = flags_by_name.get(alias.name)
                if flag is not None:
                    self.module.flags.add(flag)
439
+
440
def visitModule(self, node: Module) -> None:
    """Bind a module: push a module scope, read flags, visit the body, pop."""
    module_scope = ModuleBindingScope(node, self.module, type_env=self.type_env)
    self.scopes.append(module_scope)

    self.check_static_import_flags(node)
    for stmt in node.body:
        self.visit(stmt)

    self.scopes.pop()
455
+
456
def set_param(
    self,
    arg: ast.arg,
    arg_type: Value,
    scope: BindingScope,
) -> None:
    """Declare parameter ``arg`` in ``scope`` and record its type on the node."""
    param_name = arg.arg
    scope.declare(param_name, arg_type)
    self.set_type(arg, arg_type)
464
+
465
def _visitParameters(self, args: ast.arguments, scope: BindingScope) -> None:
    """Bind every parameter of ``args`` into ``scope``.

    For positional-only, regular and keyword-only parameters: visit the
    annotation (if any), resolve it to a type (falling back to dynamic),
    visit and check the matching default value, then declare the parameter
    via ``set_param``. Parameters without an annotation produce a perf
    warning and are typed dynamic, except unannotated positional parameters
    already present in ``scope.decl_types``, which are skipped.

    ``*args`` is declared as an exact tuple instance and ``**kwargs`` as an
    exact dict instance; their annotations are visited but not used for the
    declared type.
    """
    # Index into ``args.defaults`` for the current positional parameter.
    # It is negative while the parameter has no default, and is advanced once
    # per positional parameter (including skipped ones).
    default_index = len(args.defaults or []) - (
        len(args.posonlyargs) + len(args.args)
    )
    qualname = scope.qualname or ""
    for arg in args.posonlyargs:
        ann = arg.annotation
        if ann:
            self.visitExpectedType(
                ann,
                self.type_env.DYNAMIC,
                "argument annotation cannot be a primitive",
            )
            arg_type = (
                self.module.resolve_annotation(ann, qualname)
                or self.type_env.dynamic
            )
        elif arg.arg in scope.decl_types:
            # Already handled self
            default_index += 1
            continue
        else:
            self.perf_warning(
                "Missing type annotation for positional-only argument "
                f"'{arg.arg}' prevents type specialization in Static Python",
                arg,
            )
            arg_type = self.type_env.dynamic
        if default_index >= 0:
            # Visit the default with the parameter's instance type as context,
            # then verify it is assignable to the parameter type.
            self.visit(args.defaults[default_index], arg_type.instance)
            self.check_can_assign_from(
                arg_type,
                self.get_type(args.defaults[default_index]).klass,
                args.defaults[default_index],
            )
        default_index += 1
        self.set_param(arg, arg_type.instance, scope)

    # Same handling as above for regular positional-or-keyword parameters;
    # only the perf-warning text differs.
    for arg in args.args:
        ann = arg.annotation
        if ann:
            self.visitExpectedType(
                ann,
                self.type_env.DYNAMIC,
                "argument annotation cannot be a primitive",
            )
            arg_type = (
                self.module.resolve_annotation(ann, qualname)
                or self.type_env.dynamic
            )
        elif arg.arg in scope.decl_types:
            # Already handled self
            default_index += 1
            continue
        else:
            self.perf_warning(
                f"Missing type annotation for argument '{arg.arg}' "
                "prevents type specialization in Static Python",
                arg,
            )
            arg_type = self.type_env.dynamic
        if default_index >= 0:
            self.visit(args.defaults[default_index], arg_type.instance)
            self.check_can_assign_from(
                arg_type,
                self.get_type(args.defaults[default_index]).klass,
                args.defaults[default_index],
            )
        default_index += 1
        self.set_param(arg, arg_type.instance, scope)

    vararg = args.vararg
    if vararg:
        ann = vararg.annotation
        if ann:
            self.visitExpectedType(
                ann,
                self.type_env.DYNAMIC,
                "argument annotation cannot be a primitive",
            )

        self.set_param(vararg, self.type_env.tuple.exact_type().instance, scope)

    # Restart the index for keyword-only parameters, now against
    # ``args.kw_defaults``, whose entries may be None (checked below).
    default_index = len(args.kw_defaults or []) - len(args.kwonlyargs)
    for arg in args.kwonlyargs:
        ann = arg.annotation
        if ann:
            self.visitExpectedType(
                ann,
                self.type_env.DYNAMIC,
                "argument annotation cannot be a primitive",
            )
            arg_type = (
                self.module.resolve_annotation(ann, qualname)
                or self.type_env.dynamic
            )
        else:
            self.perf_warning(
                "Missing type annotation for keyword-only argument "
                f"'{arg.arg}' prevents type specialization in Static Python",
                arg,
            )
            arg_type = self.type_env.dynamic

        if default_index >= 0:
            default = args.kw_defaults[default_index]
            if default is not None:
                self.visit(default, arg_type.instance)
                self.check_can_assign_from(
                    arg_type,
                    self.get_type(default).klass,
                    default,
                )
        default_index += 1
        self.set_param(arg, arg_type.instance, scope)

    kwarg = args.kwarg
    if kwarg:
        ann = kwarg.annotation
        if ann:
            self.visitExpectedType(
                ann,
                self.type_env.DYNAMIC,
                "argument annotation cannot be a primitive",
            )
        self.set_param(kwarg, self.type_env.dict.exact_type().instance, scope)
591
+
592
def new_scope(self, node: AST) -> BindingScope:
    """Create a ``BindingScope`` for ``node`` nested under the current scope."""
    parent_qualname = self.binding_scope.qualname
    return BindingScope(
        node, type_env=self.type_env, parent_qualname=parent_qualname
    )
598
+
599
def get_func_container(
    self, node: ast.FunctionDef | ast.AsyncFunctionDef
) -> FunctionContainer:
    """Return the ``FunctionContainer`` recorded as the type of ``node``.

    Raises:
        RuntimeError: if the recorded type is not a ``FunctionContainer``.
    """
    function = self.get_type(node)
    if isinstance(function, FunctionContainer):
        return function
    raise RuntimeError("bad value for function")
607
+
608
def _visitFunc(self, node: FunctionDef | AsyncFunctionDef) -> None:
    """Bind a (possibly async) function definition and declare its name."""
    self._visitTypeParams(node)
    func = self.get_func_container(node)
    func.bind_function(node, self)
    typ = self.get_type(node)

    # avoid declaring unknown-decorateds as locals in order to support
    # @overload and @property.setter
    if isinstance(typ, TransientDecoratedMethod):
        return

    if isinstance(self.scope, (FunctionDef, AsyncFunctionDef)):
        # nested functions can't be invoked against; to ensure we
        # don't, declare them as a special NestedFunctionClass
        # which doesn't support invoking. If there were decorators
        # we couldn't understand then we'll just declare it as dynamic.
        if not isinstance(func, Function):
            typ = self.type_env.DYNAMIC
        else:
            nested_name = TypeName(self.context_qualname, node.name)
            typ = NestedFunctionClass(nested_name, self.type_env, func).instance

    self.declare_local(node.name, typ)
630
+
631
def visitFunctionDef(self, node: FunctionDef) -> None:
    """Bind a function definition; delegates to ``_visitFunc``."""
    self._visitFunc(node)
633
+
634
def visitAsyncFunctionDef(self, node: AsyncFunctionDef) -> None:
    """Bind an async function definition; delegates to ``_visitFunc``."""
    self._visitFunc(node)
636
+
637
def visitClassDef(self, node: ClassDef) -> None:
    """Bind a class definition and declare its name in the current scope.

    Type parameters are declared first. Decorators, class keyword values and
    bases are each visited through ``visitExpectedType`` with ``DYNAMIC`` as
    the expected type. Unless the class has the protocol type among its
    bases, the body is visited inside a new binding scope (an
    ``EnumBindingScope`` when the class type is an ``EnumType``).
    """
    self._visitTypeParams(node)

    for decorator in node.decorator_list:
        self.visitExpectedType(
            decorator, self.type_env.DYNAMIC, "decorator cannot be a primitive"
        )

    for kwarg in node.keywords:
        self.visitExpectedType(
            kwarg.value, self.type_env.DYNAMIC, "class kwarg cannot be a primitive"
        )

    is_protocol = False
    for base in node.bases:
        self.visitExpectedType(
            base, self.type_env.DYNAMIC, "class base cannot be a primitive"
        )
        base_type = self.get_type(base)
        # Identity check against the environment's protocol type.
        is_protocol |= base_type is self.type_env.protocol

    # skip type-binding protocols; they can't be instantiated and their
    # "methods" commonly won't type check anyway since they typically would
    # have no body
    res = self.get_type(node)
    if is_protocol:
        self.module.compile_non_static.add(node)
    else:
        if isinstance(res, EnumType):
            scope = EnumBindingScope(
                node, self.type_env, self.context_qualname, res
            )
        else:
            scope = self.new_scope(node)
        self.scopes.append(scope)

        for stmt in node.body:
            self.visit(stmt)

        self.scopes.pop()

    # The class name is declared in the enclosing scope in both cases.
    self.declare_local(node.name, res)
679
+
680
+ # pyre-ignore[11]: Annotation `NodeWithTypeParams` is not defined as a type
681
+ def _visitTypeParams(self, node: NodeWithTypeParams) -> None:
682
+ if hasattr(node, "type_params"):
683
+ for t in node.type_params:
684
+ self.declare_local(t.name, self.type_env.DYNAMIC)
685
+
686
# pyre-ignore[11]: Annotation `ast.TypeAlias` is not defined as a type
def visitTypeAlias(self, node: ast.TypeAlias) -> None:
    """Bind a ``type X = ...`` statement.

    Declares the alias's type parameters, visits the aliased value, and
    assigns the value's type to the alias name.
    """
    self._visitTypeParams(node)
    aliased = node.value
    self.visit(aliased)
    alias_name = node.name
    self.assign_name(alias_name, alias_name.id, self.get_type(aliased))
692
+
693
def set_type(
    self,
    node: AST,
    type: Value,
) -> None:
    """Record ``type`` as the type of ``node`` in the module's type table."""
    types_by_node = self.module.types
    types_by_node[node] = type
699
+
700
def get_type(self, node: AST) -> Value:
    """Return the type recorded for ``node``.

    When ``nodes_default_dynamic`` is true, a node with no recorded type
    yields the dynamic type. Otherwise the node must already be present in
    the module's type table.

    Raises:
        AssertionError: if the node has no recorded type and defaults are
            not allowed (assertions enabled).
    """
    if self.nodes_default_dynamic:
        return self.module.types.get(node, self.type_env.DYNAMIC)
    # Not every AST node carries position info (e.g. ``ast.Module``), so read
    # ``lineno`` defensively; otherwise building the message would raise
    # AttributeError and hide the real assertion failure.
    assert node in self.module.types, (
        f"node not found: {node}, {getattr(node, 'lineno', '<no lineno>')}"
    )
    return self.module.types[node]
706
+
707
def get_node_data(self, key: AST, data_type: type[TType]) -> TType:
    """Return the ``data_type`` data stored for ``key``; delegates to the module table."""
    return self.module.get_node_data(key, data_type)
709
+
710
def get_opt_node_data(self, key: AST, data_type: type[TType]) -> TType | None:
    """Optional variant of ``get_node_data``; delegates to the module table."""
    return self.module.get_opt_node_data(key, data_type)
712
+
713
def set_node_data(self, key: AST, data_type: type[TType], value: TType) -> None:
    """Store ``value`` as the ``data_type`` data for ``key``; delegates to the module table."""
    self.module.set_node_data(key, data_type, value)
715
+
716
def check_primitive_scope(self, name: str) -> None:
    """Ensure ``name`` may hold a primitive in the current scope.

    Raises:
        TypedSyntaxError: if the symbol table does not classify ``name`` as a
            plain local, or if the current scope is the module.
    """
    symbol_scope = self.symbols.scopes[self.scope]
    is_local = symbol_scope.check_name(name) == SC_LOCAL
    if is_local and not isinstance(self.scope, Module):
        return
    raise TypedSyntaxError("cannot use primitives in global or closure scope")
721
+
722
def get_var_scope(self, var_id: str) -> int | None:
    """Return what the current symbol scope's ``check_name`` reports for ``var_id``."""
    return self.symbols.scopes[self.scope].check_name(var_id)
726
+
727
def _check_final_attribute_reassigned(
    self,
    target: AST,
    assignment: AST | None,
) -> None:
    """Report a syntax error when ``target`` assigns to a final class member.

    The class and member are looked up for two target shapes: a bare name
    inside a class body, or an attribute access. Other targets are ignored.
    A final ``Slot`` may be assigned only by the statement recorded as its
    own ``assignment``; final functions (plain or wrapped in a
    ``TransparentDecoratedMethod``) may never be assigned.
    """
    member = None
    klass = None
    member_name = None

    # Try to look up the Class and associated Slot
    scope = self.scope
    if isinstance(target, ast.Name) and isinstance(scope, ast.ClassDef):
        klass = self.maybe_get_current_class()
        assert isinstance(klass, Class)
        member_name = target.id
        member = klass.get_member(member_name)
    elif isinstance(target, ast.Attribute):
        val = self.get_type(target.value)
        member_name = target.attr
        # TODO this logic will be inadequate if we support metaclasses
        if isinstance(val, Class):
            # Attribute accessed on a class object itself.
            klass = val
        else:
            # Attribute accessed on an instance: use the instance's class.
            klass = val.klass
        member = klass.get_member(member_name)

    # Ensure we don't reassign to Finals
    if (
        klass is not None
        and member is not None
        and (
            (
                isinstance(member, Slot)
                and member.is_final
                and member.assignment != assignment
            )
            or (isinstance(member, Function) and member.is_final)
            or (
                isinstance(member, TransparentDecoratedMethod)
                and isinstance(member.function, Function)
                and member.function.is_final
            )
        )
    ):
        self.syntax_error(
            f"Cannot assign to a Final attribute of {klass.instance.name}:{member_name}",
            target,
        )
775
+
776
def visitAnnAssign(self, node: AnnAssign) -> None:
    """Bind an annotated assignment (``target: annotation [= value]``).

    Resolves the annotation to a declared type (dynamic if unresolved),
    unwrapping ``ClassVar`` / ``InitVar`` / ``Final`` wrappers. Name targets
    are declared as locals. When a value is present it is visited against
    the declared type and, for name targets, the local type may be narrowed
    to the value's type.
    """
    # NOTE(review): block nesting here was reconstructed from a flattened
    # listing; in particular the position of the trailing
    # `_check_final_attribute_reassigned` call should be confirmed against
    # the upstream file.
    self.visitExpectedType(
        node.annotation,
        self.type_env.DYNAMIC,
        "annotation can not be a primitive value",
    )

    target = node.target
    comp_type = (
        self.module.resolve_annotation(
            node.annotation, self.context_qualname, is_declaration=True
        )
        or self.type_env.dynamic
    )
    is_final = False
    # Keep the wrapper's type for the checks below, and continue with the
    # unwrapped type.
    comp_type, wrapper = comp_type.unwrap(), type(comp_type)
    if wrapper in (ClassVar, InitVar) and not isinstance(self.scope, ClassDef):
        self.syntax_error(
            f"{wrapper.__name__} is allowed only in class attribute annotations.",
            node,
        )
    if wrapper is FinalClass:
        is_final = True

    declared_type = comp_type.instance
    is_dynamic_final = is_final and declared_type is self.type_env.DYNAMIC
    if isinstance(target, Name):
        # We special case x: Final[dynamic] = value to treat `x`'s inferred type as the
        # declared type instead of the comp_type - this allows us to support aliasing of
        # functions declared as protocols.
        if is_dynamic_final:
            value = node.value
            if value:
                self.visit(value)
                declared_type = self.get_type(value)

        self.declare_local(target.id, declared_type, is_final)
        self.set_type(target, declared_type)

    with self.in_target():
        self.visit(target)
    value = node.value
    if isinstance(self.scope, ClassDef):
        scope_type = self.get_type(self.scope)
        if isinstance(scope_type, Dataclass) and isinstance(target, Name):
            # Dataclass fields may substitute the value to bind.
            value = scope_type.bind_field(target.id, value, self)
    if value:
        self.visitExpectedType(value, declared_type)
        if not is_dynamic_final:
            if isinstance(target, Name):
                # We could be narrowing the type after the assignment, so we update it here
                # even though we assigned it above (but we never narrow primitives)
                new_type = self.get_type(value)
                local_type = self.maybe_set_local_type(target.id, new_type)
                self.set_type(target, local_type)

        self._check_final_attribute_reassigned(target, node)
833
+
834
def visitAugAssign(self, node: AugAssign) -> None:
    """Bind an augmented assignment (``target op= value``).

    The value is visited with the target's inexact type as context, and that
    same type is recorded for the statement node.
    """
    target = node.target
    self.visit(target)
    inexact_target_type = self.get_type(target).inexact()
    self.visit(node.value, inexact_target_type)
    self.set_type(node, inexact_target_type)
839
+
840
@contextmanager
def in_target(self) -> Generator[None, None, None]:
    """Context manager that sets ``visiting_assignment_target`` to True.

    The previous value is put back on exit, including when the body raises.
    """
    saved = self.visiting_assignment_target
    self.visiting_assignment_target = True
    try:
        yield
    finally:
        self.visiting_assignment_target = saved
848
+
849
def visitNamedExpr(
    self, node: ast.NamedExpr, type_ctx: Class | None = None
) -> NarrowingEffect | None:
    """Bind an assignment expression (``target := value``).

    The target is visited in assignment-target mode. The value is visited
    with the target's type as context and then assigned to the target. The
    expression takes the target's resulting type, and the returned effect is
    the truthiness refinement of the target. ``type_ctx`` is not used here.
    """
    target = node.target
    with self.in_target():
        self.visit(target)
    target_type = self.get_type(target)
    self.visit(node.value, target_type)
    value_type = self.get_type(node.value)
    self.assign_value(target, value_type)
    # Re-read the target's type: assign_value may have updated it.
    self.set_type(node, self.get_type(target))
    return self.refine_truthy(node.target)
861
+
862
def visitAssign(self, node: Assign) -> None:
    """Bind a plain assignment with one or more targets.

    Computes a type context from the targets, visits the value with it,
    assigns the value's type to every target (last target first), and records
    the value's type on the statement node. For a single refinable target the
    statement is tagged to preserve refined fields, and an attribute target
    may record a new field refinement.
    """
    # Sometimes, we need to propagate types from the target to the value to allow primitives to be handled
    # correctly. So we compute the narrowest target type. (Other checks do happen later).
    # e.g: `x: int8 = 1` means we need `1` to be of type `int8`
    narrowest_target_type = None
    for target in reversed(node.targets):
        cur_type = None
        if isinstance(target, ast.Name):
            # This is a name, it could be unassigned still
            decl_type = self.get_target_decl(target.id)
            if decl_type is not None:
                cur_type = decl_type.type
        elif isinstance(target, (ast.Tuple, ast.List)):
            # TODO: We should walk into the tuple/list and use it to infer
            # types down on the RHS if we can
            with self.in_target():
                self.visit(target)
        else:
            # This is an attribute or subscript, the assignment can't change the type
            self.visit(target)
            cur_type = self.get_type(target)

        # Replace the running candidate when there is none yet, or when the
        # candidate's class can be assigned from this target's class.
        if cur_type is not None and (
            narrowest_target_type is None
            or narrowest_target_type.klass.can_assign_from(cur_type.klass)
        ):
            narrowest_target_type = cur_type

    self.visit(node.value, narrowest_target_type)
    value_type = self.get_type(node.value)
    for target in reversed(node.targets):
        self.assign_value(target, value_type, src=node.value, assignment=node)

    if len(node.targets) == 1 and self.is_refinable(node.targets[0]):
        # In simple cases (i.e. a = expr), we know that the act of assignment won't execute
        # arbitrary code. There are some more complex cases we can handle (self.x = y where we
        # know self.x is a slot), but ignore these for now for safety.
        self.set_node_data(node, PreserveRefinedFields, PRESERVE_REFINED_FIELDS)
        target = node.targets[0]
        if (
            isinstance(target, ast.Attribute)
            and narrowest_target_type != value_type
            and self.type_env.dynamic.can_assign_from(value_type.klass)
        ):
            assert isinstance(target.value, ast.Name)
            # Refinement entry: (refined type, tmp-var index, nodes involved),
            # keyed by the base name and then the attribute name.
            # pyre-fixme[16]: `expr` has no attribute `id`.
            self.type_state.refined_fields.setdefault(target.value.id, {})[
                target.attr
            ] = (
                value_type,
                self.refined_field_index(access_path(target)),
                {target},
            )

    self.set_type(node, value_type)
917
+
918
def visitPass(self, node: ast.Pass) -> None:
    """Tag a ``pass`` statement with the preserve-refined-fields marker."""
    self.set_node_data(
        node,
        PreserveRefinedFields,
        PRESERVE_REFINED_FIELDS,
    )
920
+
921
def check_can_assign_from(
    self,
    dest: Class,
    src: Class,
    node: AST,
    reason: str = "type mismatch: {} cannot be assigned to {}",
) -> None:
    """Report a syntax error on ``node`` when ``src`` cannot be assigned to ``dest``.

    A ``src`` that is the dynamic type is tolerated even when ``dest`` does
    not accept it, unless ``dest`` is a ``CType``. ``reason`` is passed
    through ``resolve_assign_error_msg`` before being reported.
    """
    if dest.can_assign_from(src):
        return

    # Dynamic sources are only rejected when the destination is a CType.
    if src is self.type_env.dynamic and not isinstance(dest, CType):
        return

    self.syntax_error(resolve_assign_error_msg(dest, src, reason), node)
933
+
934
def clear_refinements_for_nonbool_test(self, test_node: ast.AST) -> None:
    """Drop all field refinements unless ``test_node`` is typed exactly as ``bool``.

    Testing a non-bool value in a condition such as ``if x: ...`` invokes
    ``x.__bool__()``, which can run arbitrary code and therefore invalidate
    any field refinements we are tracking. To stay safe, every refinement is
    cleared whenever the tested expression's class is not the bool class.
    """
    tested_class = self.get_type(test_node).klass
    if tested_class is self.type_env.bool:
        return
    self.type_state.refined_fields.clear()
942
+
943
def visitAssert(self, node: ast.Assert) -> None:
    """Bind an ``assert`` statement.

    The narrowing effect of the test is applied to the current type state
    (and left applied), refinements are cleared if the test is not a bool,
    and the effect is stored on the node. The optional message is visited
    expecting a dynamic (non-primitive) value.
    """
    test_effect = self.visit(node.test) or NO_EFFECT
    test_effect.apply(self.type_state)
    self.clear_refinements_for_nonbool_test(node.test)

    self.set_node_data(node, NarrowingEffect, test_effect)
    self.set_node_data(node, PreserveRefinedFields, PRESERVE_REFINED_FIELDS)

    msg_expr = node.msg
    if not msg_expr:
        return
    self.visitExpectedType(
        msg_expr, self.type_env.DYNAMIC, "assert message cannot be a primitive"
    )
954
+
955
def visitBoolOp(
    self, node: BoolOp, type_ctx: Class | None = None
) -> NarrowingEffect:
    """Bind an ``and`` / ``or`` expression.

    Each operand is visited with the narrowing implied by the operands
    before it, the operand types are widened into the type of the whole
    expression (dynamic if no operand contributed a type), and the combined
    narrowing effect is returned for use by an enclosing conditional. The
    type state is restored before returning.
    """
    effect = NO_EFFECT
    final_type = None
    if isinstance(node.op, And):
        for value in node.values:
            new_effect = self.visit(value) or NO_EFFECT
            effect = effect.and_(new_effect)
            final_type = self.widen(final_type, self.get_type(value))

            # apply the new effect as short circuiting would
            # eliminate it.
            new_effect.apply(self.type_state)

        # we undo the effect as we have no clue what context we're in
        # but then we return the combined effect in case we're being used
        # in a conditional context
        effect.undo(self.type_state)
    elif isinstance(node.op, ast.Or):
        for value in node.values[:-1]:
            new_effect = self.visit(value) or NO_EFFECT
            effect = effect.or_(new_effect)

            # Remember the un-narrowed type; it is restored on the node below.
            old_type = self.get_type(value)
            # The or expression will only return the `value` we're visiting if its
            # effect holds, so we visit it assuming that the narrowing effects apply.
            new_effect.apply(self.type_state)
            self.visit(value)
            new_effect.undo(self.type_state)

            final_type = self.widen(
                final_type,
                # For Optional[T], we use `T` when widening, to handle `my_optional or something`
                (
                    old_type.klass.opt_type.instance
                    if isinstance(old_type, OptionalInstance)
                    else old_type
                ),
            )
            self.set_type(value, old_type)

            # Later operands are only evaluated when this one was falsy, so
            # they are visited with the reversed effect in place.
            new_effect.reverse(self.type_state)
        # We know nothing about the last node of an or, so we simply widen with its type.
        new_effect = self.visit(node.values[-1]) or NO_EFFECT
        final_type = self.widen(final_type, self.get_type(node.values[-1]))

        # Restore the type state before folding in the last operand's effect.
        effect.undo(self.type_state)
        effect = effect.or_(new_effect)
    else:
        # Fallback for any operator other than And/Or: visit and widen only.
        for value in node.values:
            self.visit(value)
            final_type = self.widen(final_type, self.get_type(value))

    self.set_type(node, final_type or self.type_env.DYNAMIC)
    return effect
1011
+
1012
def visitBinOp(self, node: BinOp, type_ctx: Class | None = None) -> NarrowingEffect:
    """Bind a binary operator expression.

    Both operands are visited with the incoming type context (dropped for
    ``**``). Binding is then attempted in this order: the right operand's
    reverse binop first when the left type's exact class appears in the right
    class's MRO past its first entry, then the left operand's binop, then the
    right operand's reverse binop if it has not been tried yet.
    """
    # If we're taking pow, the output type is always double, regardless of
    # the input types, so the type context must not be used for coercion.
    if isinstance(node.op, ast.Pow):
        type_ctx = None

    self.visit(node.left, type_ctx)
    self.visit(node.right, type_ctx)
    left_type = self.get_type(node.left)
    right_type = self.get_type(node.right)

    right_goes_first = left_type.klass.exact_type() in right_type.klass.mro[1:]
    if right_goes_first and right_type.bind_reverse_binop(node, self, type_ctx):
        return NO_EFFECT

    if left_type.bind_binop(node, self, type_ctx):
        return NO_EFFECT

    if not right_goes_first:
        right_type.bind_reverse_binop(node, self, type_ctx)
    return NO_EFFECT
1037
+
1038
def visitUnaryOp(
    self, node: UnaryOp, type_ctx: Class | None = None
) -> NarrowingEffect:
    """Bind a unary operator expression.

    The operand's type binds the operation. For ``not`` applied to an
    operand that produced a real narrowing effect, the negated effect is
    returned; in every other case ``NO_EFFECT`` is returned.
    """
    operand_effect = self.visit(node.operand, type_ctx)
    self.get_type(node.operand).bind_unaryop(node, self, type_ctx)

    if operand_effect is None or operand_effect is NO_EFFECT:
        return NO_EFFECT
    if not isinstance(node.op, ast.Not):
        return NO_EFFECT
    return operand_effect.not_()
1050
+
1051
def visitLambda(
    self, node: Lambda, type_ctx: Class | None = None
) -> NarrowingEffect:
    """Bind a lambda expression.

    Parameters are bound into a fresh scope, the body is visited inside that
    scope expecting a dynamic (non-primitive) value, and the lambda itself is
    typed as dynamic.
    """
    lambda_scope = self.new_scope(node)
    self._visitParameters(node.args, lambda_scope)

    # The scope is only active while the body is being visited.
    self.scopes.append(lambda_scope)
    self.visitExpectedType(
        node.body, self.type_env.DYNAMIC, "lambda cannot return primitive value"
    )
    self.scopes.pop()

    self.set_type(node, self.type_env.DYNAMIC)
    return NO_EFFECT
1065
+
1066
def visitIfExp(self, node: IfExp, type_ctx: Class | None = None) -> NarrowingEffect:
    """Bind a conditional expression (``body if test else orelse``).

    The body is visited with the test's narrowing effect applied, the
    ``orelse`` with the effect reversed, and the effect is undone afterwards.
    The expression's type is the union of both branch types.
    """
    test_effect = self.visit(node.test) or NO_EFFECT
    test_effect.apply(self.type_state)
    self.clear_refinements_for_nonbool_test(node.test)

    self.visit(node.body, type_ctx)
    test_effect.reverse(self.type_state)
    self.visit(node.orelse, type_ctx)
    test_effect.undo(self.type_state)

    # Combine both branches into a union of their classes.
    branch_classes = (
        self.get_type(node.body).klass,
        self.get_type(node.orelse).klass,
    )
    self.set_type(node, self.type_env.get_union(branch_classes).instance)
    return NO_EFFECT
1085
+
1086
def visitSlice(self, node: Slice, type_ctx: Class | None = None) -> NarrowingEffect:
    """Bind a slice; each present bound must be a dynamic (non-primitive) value."""
    # Bounds are checked in source order: lower, upper, step.
    for bound in (node.lower, node.upper, node.step):
        if bound:
            self.visitExpectedType(
                bound, self.type_env.DYNAMIC, "slice indices cannot be primitives"
            )

    self.set_type(node, self.type_env.slice.instance)
    return NO_EFFECT
1104
+
1105
def widen(self, existing: Value | None, new: Value) -> Value:
    """Return a type wide enough to hold both ``existing`` and ``new``.

    ``new`` wins when there is no existing type or when it can be assigned
    from ``existing``; ``existing`` wins when it can be assigned from
    ``new``; otherwise the union of the two classes is returned.
    """
    if existing is None:
        return new

    if new.klass.can_assign_from(existing.klass):
        return new

    if existing.klass.can_assign_from(new.klass):
        return existing

    combined = self.type_env.get_union((existing.klass, new.klass))
    return combined.instance
1112
+
1113
def visitDict(
    self, node: ast.Dict, type_ctx: Class | None = None
) -> NarrowingEffect:
    """Bind a dict display, inferring widened key and value types.

    Regular ``key: value`` entries contribute their own types. A ``**splat``
    entry (key is falsy) contributes the type arguments of a checked dict,
    or forces both key and value to dynamic for a plain dict / exact dict /
    dynamic operand; any other splat type leaves the inferred types as they
    are.
    """
    inferred_key: Value | None = None
    inferred_value: Value | None = None

    for key_expr, value_expr in zip(node.keys, node.values):
        if not key_expr:
            # `**value_expr` splat entry.
            self.visitExpectedType(
                value_expr, self.type_env.DYNAMIC, "dict splat cannot be a primitive"
            )
            splat_class = self.get_type(value_expr).klass
            if splat_class.generic_type_def is self.type_env.checked_dict:
                assert isinstance(splat_class, GenericClass)
                inferred_key = self.widen(
                    inferred_key, splat_class.type_args[0].instance
                )
                inferred_value = self.widen(
                    inferred_value, splat_class.type_args[1].instance
                )
            elif splat_class in (
                self.type_env.dict,
                self.type_env.dict.exact_type(),
                self.type_env.dynamic,
            ):
                inferred_key = self.type_env.DYNAMIC
                inferred_value = self.type_env.DYNAMIC
            continue

        self.visitExpectedType(
            key_expr, self.type_env.DYNAMIC, "dict keys cannot be primitives"
        )
        inferred_key = self.widen(inferred_key, self.get_type(key_expr))
        self.visitExpectedType(
            value_expr, self.type_env.DYNAMIC, "dict values cannot be primitives"
        )
        inferred_value = self.widen(inferred_value, self.get_type(value_expr))

    self.set_dict_type(node, inferred_key, inferred_value, type_ctx)
    return NO_EFFECT
1147
+
1148
def set_dict_type(
    self,
    node: ast.expr,
    key_type: Value | None,
    value_type: Value | None,
    type_ctx: Class | None,
) -> Value:
    """Set and return the type of a dict-producing expression.

    Without a checked-dict type context the type is inferred: a checked dict
    of the inexact key/value classes when the module enables checked dicts
    and both types are known, otherwise an exact ``dict``.

    With a checked-dict type context, the context type is used, and an error
    is reported when the inferred key or value class does not fit the
    context's type arguments.
    """
    env = self.type_env

    if not isinstance(type_ctx, CheckedDictInstance):
        use_checked_dict = (
            ModuleFlag.CHECKED_DICTS in self.module.flags
            and key_type is not None
            and value_type is not None
        )
        if use_checked_dict:
            inferred = env.get_generic_type(
                env.checked_dict,
                (key_type.klass.inexact_type(), value_type.klass.inexact_type()),
            ).instance
        else:
            inferred = env.dict.exact_type().instance
        self.set_type(node, inferred)
        return inferred

    # Calculate the type that is inferred by the keys and values.
    assert type_ctx is not None
    ctx_class = type_ctx.klass
    assert ctx_class.generic_type_def is env.checked_dict, ctx_class
    assert isinstance(ctx_class, GenericClass)

    # Unknown key/value types default to the context's own type arguments.
    if key_type is None:
        key_type = ctx_class.type_args[0].instance
    if value_type is None:
        value_type = ctx_class.type_args[1].instance

    inferred_generic = env.get_generic_type(
        env.checked_dict,
        (key_type.klass, value_type.klass),
    )

    self.set_type(node, type_ctx)
    # The type context may be wider than the inferred types, but the
    # keys/values must still be compatible with it; if they are not, report
    # that the inferred type isn't compatible.
    keys_fit = ctx_class.type_args[0].can_assign_from(key_type.klass)
    if not keys_fit or not ctx_class.type_args[1].can_assign_from(value_type.klass):
        self.check_can_assign_from(ctx_class, inferred_generic, node)
    return type_ctx
1196
+
1197
def set_list_type(
    self,
    node: ast.expr,
    item_type: Value | None,
    type_ctx: Class | None,
) -> Value:
    """Set and return the type of a list-producing expression.

    Without a checked-list type context the type is inferred: a checked list
    of the non-literal, inexact item class when the module enables checked
    lists and the item type is known, otherwise an exact ``list``.

    With a checked-list type context, the context type is used, and an error
    is reported when the inferred item class does not fit the context's type
    argument.
    """
    env = self.type_env

    if not isinstance(type_ctx, CheckedListInstance):
        if ModuleFlag.CHECKED_LISTS in self.module.flags and item_type is not None:
            inferred = env.get_generic_type(
                env.checked_list,
                (item_type.nonliteral().klass.inexact_type(),),
            ).instance
        else:
            inferred = env.list.exact_type().instance
        self.set_type(node, inferred)
        return inferred

    # Calculate the type that is inferred by the item.
    assert type_ctx is not None
    ctx_class = type_ctx.klass
    assert ctx_class.generic_type_def is env.checked_list, ctx_class
    assert isinstance(ctx_class, GenericClass)

    # An unknown item type defaults to the context's own type argument.
    if item_type is None:
        item_type = ctx_class.type_args[0].instance

    inferred_generic = env.get_generic_type(
        env.checked_list,
        (item_type.nonliteral().klass.inexact_type(),),
    )

    self.set_type(node, type_ctx)
    # The type context may be wider than the inferred type, but the items
    # must still be compatible with it; if they are not, report that the
    # inferred type isn't compatible.
    if not ctx_class.type_args[0].can_assign_from(item_type.klass):
        self.check_can_assign_from(ctx_class, inferred_generic, node)
    return type_ctx
1236
+
1237
def visitSet(self, node: ast.Set, type_ctx: Class | None = None) -> NarrowingEffect:
    """Bind a set display: members must be non-primitive; result is an exact set."""
    for member in node.elts:
        self.visitExpectedType(
            member, self.type_env.DYNAMIC, "set members cannot be primitives"
        )

    exact_set = self.type_env.set.exact_type()
    self.set_type(node, exact_set.instance)
    return NO_EFFECT
1244
+
1245
def visitGeneratorExp(
    self, node: GeneratorExp, type_ctx: Class | None = None
) -> NarrowingEffect:
    """Bind a generator expression; the result is typed as dynamic."""
    generators, element = node.generators, node.elt
    self.visit_comprehension(node, generators, element)
    self.set_type(node, self.type_env.DYNAMIC)
    return NO_EFFECT
1251
+
1252
def visitListComp(
    self, node: ListComp, type_ctx: Class | None = None
) -> NarrowingEffect:
    """Bind a list comprehension; the list type is derived from the element type."""
    self.visit_comprehension(node, node.generators, node.elt)
    # The element's type (together with the type context) decides the list type.
    self.set_list_type(node, self.get_type(node.elt), type_ctx)
    return NO_EFFECT
1259
+
1260
def visitSetComp(
    self, node: SetComp, type_ctx: Class | None = None
) -> NarrowingEffect:
    """Bind a set comprehension; the result is typed as an exact set."""
    self.visit_comprehension(node, node.generators, node.elt)
    exact_set = self.type_env.set.exact_type()
    self.set_type(node, exact_set.instance)
    return NO_EFFECT
1266
+
1267
def get_target_decl(self, name: str) -> TypeDeclaration | None:
    """Look up the declaration for ``name``.

    The current ``decl_types`` mapping is consulted first. If nothing is
    found there and the name has (explicit or implicit) global scope, the
    outermost scope's declarations are consulted instead.
    """
    local_decl = self.decl_types.get(name)
    if local_decl is not None:
        return local_decl

    if self.get_var_scope(name) in (SC_GLOBAL_EXPLICIT, SC_GLOBAL_IMPLICIT):
        return self.scopes[0].decl_types.get(name)
    return None
1274
+
1275
def assign_name(
    self,
    target: AST,
    name: str,
    value: Value,
) -> None:
    """Record the assignment of ``value`` to the variable ``name``.

    An undeclared name is declared as an inferred local. A declared name is
    checked for ``Final`` reassignment and for assignability of the value to
    the declared type. In both cases the local type is then updated via
    ``maybe_set_local_type`` and stored on ``target``.
    """
    declaration = self.get_target_decl(name)
    if declaration is not None:
        if declaration.is_final:
            self.syntax_error("Cannot assign to a Final variable", target)
        self.check_can_assign_from(declaration.type.klass, value.klass, target)
    else:
        self.declare_local(name, value, is_inferred=True)

    self.set_type(target, self.maybe_set_local_type(name, value))
1291
+
1292
def assign_value(
    self,
    target: expr,
    value: Value,
    src: expr | None = None,
    assignment: AST | None = None,
) -> None:
    """Assign ``value`` to ``target``, destructuring tuple/list targets.

    * A name target is delegated to ``assign_name``.
    * A tuple/list target is matched element-wise against ``src`` when
      ``src`` is a tuple/list display or a constant tuple of the same
      length; otherwise every element is assigned the dynamic type.
    * Any other target is checked for assignability against its own type
      and for reassignment of a final attribute.
    """
    if isinstance(target, Name):
        self.assign_name(target, target.id, value)
        return

    if not isinstance(target, (ast.Tuple, ast.List)):
        self.check_can_assign_from(self.get_type(target).klass, value.klass, target)
        self._check_final_attribute_reassigned(target, assignment)
        return

    # Element-wise assignment from a literal tuple/list of matching length.
    if isinstance(src, (ast.Tuple, ast.List)) and len(target.elts) == len(src.elts):
        for sub_target, sub_src in zip(target.elts, src.elts):
            self.assign_value(sub_target, self.get_type(sub_src), src=sub_src)
        return

    # Element-wise assignment from a constant tuple of matching length.
    if isinstance(src, ast.Constant):
        const_value = src.value
        if isinstance(const_value, tuple) and len(const_value) == len(target.elts):
            for sub_target, item in zip(target.elts, const_value):
                self.assign_value(
                    sub_target,
                    self.type_env.constant_types[type(item)],
                )
            return

    # Nothing is known about the individual elements.
    for sub_target in target.elts:
        self.assign_value(sub_target, self.type_env.DYNAMIC)
1326
+
1327
def visitDictComp(
    self, node: DictComp, type_ctx: Class | None = None
) -> NarrowingEffect:
    """Bind a dict comprehension.

    The first generator's iterable is visited in the enclosing scope; all
    remaining binding happens in a fresh scope for the comprehension. For
    every generator the target is visited in assignment-target mode, given
    the iterable's element type, and its ``if`` clauses are visited. Key and
    value must be non-primitive, and the resulting dict type is computed by
    ``set_dict_type``.

    Previously only the first generator had its target and ``if`` clauses
    visited; later generators were merely assigned. They are now handled the
    same way as in ``visit_comprehension``.
    """
    self.visit(node.generators[0].iter)

    scope = self.new_scope(node)
    self.scopes.append(scope)

    for index, gen in enumerate(node.generators):
        # The first iterable was already visited before the new scope was
        # pushed; every later iterable is visited inside it.
        if index > 0:
            self.visit(gen.iter)
        iter_type = self.get_type(gen.iter).get_iter_type(gen.iter, self)

        with self.in_target():
            self.visit(gen.target)
        self.assign_value(gen.target, iter_type)
        for if_ in gen.ifs:
            self.visit(if_)

    self.visitExpectedType(
        node.key,
        self.type_env.DYNAMIC,
        "dictionary comprehension key cannot be a primitive",
    )
    self.visitExpectedType(
        node.value,
        self.type_env.DYNAMIC,
        "dictionary comprehension value cannot be a primitive",
    )

    self.scopes.pop()

    key_type = self.get_type(node.key)
    value_type = self.get_type(node.value)
    self.set_dict_type(node, key_type, value_type, type_ctx)

    return NO_EFFECT
1368
+
1369
def visit_comprehension(
    self, node: ast.expr, generators: list[ast.comprehension], *elts: ast.expr
) -> None:
    """Bind the generators and element expressions of a comprehension.

    The first iterable is visited in the enclosing scope; a fresh scope is
    then pushed for the rest of the work and popped at the end. Each
    generator's target is visited in assignment-target mode and assigned the
    iterable's element type, then its ``if`` clauses are visited. Every
    expression in ``elts`` must be non-primitive.
    """
    self.visit(generators[0].iter)
    self.scopes.append(self.new_scope(node))

    for position, gen in enumerate(generators):
        # Only the first iterable is evaluated outside the new scope.
        if position > 0:
            self.visit(gen.iter)
        element_type = self.get_type(gen.iter).get_iter_type(gen.iter, self)

        with self.in_target():
            self.visit(gen.target)
        self.assign_value(gen.target, element_type)

        for condition in gen.ifs:
            self.visit(condition)

    for produced in elts:
        self.visitExpectedType(
            produced, self.type_env.DYNAMIC, "generator element cannot be a primitive"
        )

    self.scopes.pop()
1402
+
1403
def visitAwait(self, node: Await, type_ctx: Class | None = None) -> NarrowingEffect:
    """Bind an ``await``: the operand must be non-primitive and binds the await."""
    awaited = node.value
    self.visitExpectedType(
        awaited, self.type_env.DYNAMIC, "cannot await a primitive value"
    )
    self.get_type(awaited).bind_await(node, self, type_ctx)
    return NO_EFFECT
1409
+
1410
def visitYield(self, node: Yield, type_ctx: Class | None = None) -> NarrowingEffect:
    """Bind a ``yield``: any yielded value must be non-primitive; result is dynamic."""
    yielded = node.value
    if yielded is not None:
        self.visitExpectedType(
            yielded, self.type_env.DYNAMIC, "cannot yield a primitive value"
        )

    self.set_type(node, self.type_env.DYNAMIC)
    return NO_EFFECT
1418
+
1419
def visitYieldFrom(
    self, node: YieldFrom, type_ctx: Class | None = None
) -> NarrowingEffect:
    """Bind a ``yield from``: the operand must be non-primitive; result is dynamic."""
    dynamic = self.type_env.DYNAMIC
    self.visitExpectedType(
        node.value, dynamic, "cannot yield from a primitive value"
    )
    self.set_type(node, self.type_env.DYNAMIC)
    return NO_EFFECT
1427
+
1428
def refine_truthy(self, node: ast.expr | None) -> NarrowingEffect | None:
    """Build the narrowing effect for ``node`` being truthy, if there is one.

    Returns ``None`` when there is no node, the node is not refinable, or
    its type is not a (non-generic-definition) union instance. Otherwise
    returns the negation of an "is instance of None" effect on the node.
    """
    if node is None:
        return None
    if not self.is_refinable(node):
        return None

    current_type = self.get_type(node)
    if not isinstance(current_type, UnionInstance):
        return None
    if current_type.klass.is_generic_type_definition:
        return None

    assert isinstance(node, (ast.Name, ast.NamedExpr, ast.Attribute))
    is_none = IsInstanceEffect(
        node,
        current_type,
        self.type_env.none.instance,
        self,
    )
    return is_none.not_()
1447
+
1448
def visitCompare(
    self, node: Compare, type_ctx: Class | None = None
) -> NarrowingEffect:
    """Bind a comparison expression, including chained comparisons.

    A single ``is`` / ``is not`` comparison against the ``None`` constant
    may yield a narrowing effect on the other operand, which is returned
    immediately. Otherwise each ``left op right`` link of the chain is bound
    in turn and ``NO_EFFECT`` is returned.

    For each link, binding is attempted in this order:
      1. the right type's reverse compare, when the left type's exact class
         appears in the right class's MRO past its first entry (mirrors
         ``visitBinOp`` and Python's rule that a subclass's reflected method
         takes priority);
      2. the left type's compare;
      3. the right type's reverse compare, if it was not already tried.

    Fixes over the previous version: step 1 used the *left* type's reverse
    compare; and the left operand/type were not reliably advanced between
    links of a chain (``a < b < c`` must bind ``b < c`` second), because the
    left node was never updated and a successful bind skipped the type
    update.
    """
    if len(node.ops) == 1 and isinstance(node.ops[0], (Is, IsNot)):
        self.set_node_data(node, PreserveRefinedFields, PRESERVE_REFINED_FIELDS)
        left = node.left
        right = node.comparators[0]
        other = None

        self.set_type(node, self.type_env.bool.instance)
        self.set_type(node.ops[0], self.type_env.bool.instance)

        self.visit(left)
        self.visit(right)

        # Only `x is None` / `None is x` style checks can narrow.
        if isinstance(left, Constant) and left.value is None:
            other = right
        elif isinstance(right, Constant) and right.value is None:
            other = left

        if (effect := self.refine_truthy(other)) is not None:
            # refine_truthy() describes "is not None"; flip it for `is`.
            if isinstance(node.ops[0], Is):
                return effect.not_()
            else:
                return effect

    self.visit(node.left)
    left = node.left
    ltype = self.get_type(node.left)
    # Give every operator position its own node instance.
    # NOTE(review): presumably so per-operator types set during binding do
    # not collide on shared operator nodes — confirm.
    node.ops = [type(op)() for op in node.ops]
    for comparator, op in zip(node.comparators, node.ops):
        self.visit(comparator)
        rtype = self.get_type(comparator)

        bound = False
        tried_right = False
        if ltype.klass.exact_type() in rtype.klass.mro[1:]:
            # The right operand's type is a strict subclass of the left's,
            # so its reflected comparison gets the first chance.
            bound = bool(
                rtype.bind_reverse_compare(
                    node, left, op, comparator, self, type_ctx
                )
            )
            tried_right = True

        if not bound:
            bound = bool(
                ltype.bind_compare(node, left, op, comparator, self, type_ctx)
            )

        if not bound and not tried_right:
            rtype.bind_reverse_compare(node, left, op, comparator, self, type_ctx)

        # The right operand of this link is the left operand of the next one.
        left = comparator
        ltype = rtype
    return NO_EFFECT
1499
+
1500
def visitCall(self, node: Call, type_ctx: Class | None = None) -> NarrowingEffect:
    """Bind a call by delegating to the callee type's ``bind_call``."""
    callee = node.func
    self.visit(callee)
    callee_type = self.get_type(callee)
    return callee_type.bind_call(node, self, type_ctx)
1503
+
1504
def visitFormattedValue(
    self, node: FormattedValue, type_ctx: Class | None = None
) -> NarrowingEffect:
    """Bind one ``{...}`` field of an f-string; the field is typed as dynamic."""
    self.visitExpectedType(
        node.value, self.type_env.DYNAMIC, "cannot use primitive in formatted value"
    )

    format_spec = node.format_spec
    if format_spec:
        self.visit(format_spec)

    self.set_type(node, self.type_env.DYNAMIC)
    return NO_EFFECT
1514
+
1515
def visitJoinedStr(
    self, node: JoinedStr, type_ctx: Class | None = None
) -> NarrowingEffect:
    """Bind an f-string: visit every part; the result is an exact ``str``."""
    for part in node.values:
        self.visit(part)

    exact_str = self.type_env.str.exact_type()
    self.set_type(node, exact_str.instance)
    return NO_EFFECT
1523
+
1524
def visitConstant(
    self, node: Constant, type_ctx: Class | None = None
) -> NarrowingEffect:
    """Bind a constant using the type context, or the dynamic type without one."""
    self.set_node_data(node, PreserveRefinedFields, PRESERVE_REFINED_FIELDS)

    if type_ctx is None:
        self.type_env.DYNAMIC.bind_constant(node, self)
    else:
        type_ctx.bind_constant(node, self)
    return NO_EFFECT
1533
+
1534
def visitAttribute(
    self, node: Attribute, type_ctx: Class | None = None
) -> NarrowingEffect:
    """Bind an attribute access and maintain refined-field bookkeeping.

    After the base type binds the attribute:

    * If ``<name>.<attr>`` currently has a recorded refinement, a load takes
      the refined type and marks the recorded source nodes (and this node,
      if it is not one of them) with ``UsedRefinementField`` data; any other
      context (store/delete) discards the refinement.
    * Attributes of module instances get a ``TypeDescr``.
    * Refinable attributes are marked as preserving refined fields, and a
      store additionally gets ``UsedRefinementField`` data so codegen can
      hoist reads and the store at this point.
    """
    base_expr = node.value
    self.visit(base_expr)
    base_type = self.get_type(base_expr)
    base_type.bind_attr(node, self, type_ctx)

    refined = self.type_state.refined_fields
    has_refinement = (
        isinstance(base_expr, ast.Name)
        and base_expr.id in refined
        and node.attr in refined[base_expr.id]
    )
    if has_refinement:
        attr_refinements = refined[base_expr.id]
        if isinstance(node.ctx, ast.Load):
            refined_type, field_index, source_nodes = attr_refinements[node.attr]
            self.set_type(node, refined_type)
            temp_name = self._refined_field_name(field_index)
            for source_node in source_nodes:
                self.set_node_data(
                    source_node,
                    UsedRefinementField,
                    UsedRefinementField(temp_name, True, node != source_node),
                )
            if node not in source_nodes:
                self.set_node_data(
                    node,
                    UsedRefinementField,
                    UsedRefinementField(temp_name, False, True),
                )
        else:
            # Ensure we don't keep stale refinement information around when
            # setting/deleting an attr. (Presence was established above.)
            del attr_refinements[node.attr]

    if isinstance(base_type, ModuleInstance):
        self.set_node_data(node, TypeDescr, ((base_type.module_name,), node.attr))

    if self.is_refinable(node):
        self.set_node_data(node, PreserveRefinedFields, PRESERVE_REFINED_FIELDS)
        # If we're storing a field at a refinable position, mark it so that
        # codegen can hoist reads and the store at this point.
        if isinstance(node.ctx, ast.Store):
            store_temp = self._refined_field_name(
                self.refined_field_index(access_path(node))
            )
            self.set_node_data(
                node,
                UsedRefinementField,
                UsedRefinementField(store_temp, True, False),
            )

    return NO_EFFECT
1588
+
1589
def visitSubscript(
    self, node: Subscript, type_ctx: Class | None = None
) -> NarrowingEffect:
    """Bind a subscript by delegating to the container type's ``bind_subscr``."""
    self.visit(node.value)
    self.visit(node.slice)

    container_type = self.get_type(node.value)
    index_type = self.get_type(node.slice)
    container_type.bind_subscr(node, index_type, self, type_ctx)
    return NO_EFFECT
1597
+
1598
def visitStarred(
    self, node: Starred, type_ctx: Class | None = None
) -> NarrowingEffect:
    """Bind ``*value``: the operand must be non-primitive; result is dynamic."""
    dynamic = self.type_env.DYNAMIC
    self.visitExpectedType(
        node.value,
        dynamic,
        "cannot use primitive in starred expression",
    )
    self.set_type(node, self.type_env.DYNAMIC)
    return NO_EFFECT
1608
+
1609
def visitName(self, node: Name, type_ctx: Class | None = None) -> NarrowingEffect:
    """Bind a name reference and give it a type.

    Function-local names take their type from the current local types
    (dynamic when unknown). Other names are resolved through the module,
    falling back to the outermost scope's declarations. A name that cannot
    be found anywhere (and is not an assignment target) raises
    ``TypedSyntaxError``.

    Returns the truthiness narrowing effect for the name when there is one,
    otherwise ``NO_EFFECT``.
    """
    self.set_node_data(node, PreserveRefinedFields, PRESERVE_REFINED_FIELDS)

    # Assignment targets are allowed to be undefined so far.
    is_known = self.visiting_assignment_target
    name_scope = self.symbols.scopes[self.scope].check_name(node.id)

    if name_scope == SC_LOCAL and not isinstance(self.scope, Module):
        local_types = self.type_state.local_types
        if node.id in local_types:
            is_known = True
        self.set_type(node, local_types.get(node.id, self.type_env.DYNAMIC))
    else:
        typ, descr = self.module.resolve_name_with_descr(
            node.id, self.context_qualname
        )
        if typ is None and len(self.scopes) > 0:
            # We might be dealing with a context decorated method, in which
            # case we mint temporary decorator names. These won't be exposed
            # in the module table, but will be declared at the module scope
            # for internal use. Search it for this name.
            module_decl = self.scopes[0].decl_types.get(node.id)
            if module_decl is not None:
                typ = module_decl.type
        self.set_type(node, typ or self.type_env.DYNAMIC)
        if descr is not None:
            self.set_node_data(node, TypeDescr, descr)
        if typ is not None:
            is_known = True

    if not is_known:
        if name_scope == SC_FREE:
            # Search for the name in outer scopes - resolve_name_with_descr
            # simply checks the module table for globals. We currently set
            # the type to dynamic still for safety, as the inner scope might
            # override the outer type.
            is_known = any(
                node.id in outer.type_state.local_types
                for outer in reversed(self.scopes)
            )
        elif name_scope == SC_CELL:
            # If a name is a cell var, we currently don't resolve it to a
            # type due to soundness holes with nested types. Do this search
            # to avoid erroring, still.
            is_known = node.id in self.scopes[-1].type_state.local_types

    if not is_known:
        raise TypedSyntaxError(f"Name `{node.id}` is not defined.")

    truthy_effect = self.refine_truthy(node)
    return NO_EFFECT if truthy_effect is None else truthy_effect
1657
+
1658
def visitExpectedType(
    self,
    node: AST,
    expected: Value,
    reason: str = "type mismatch: {} cannot be assigned to {}",
    blame: AST | None = None,
) -> NarrowingEffect | None:
    """Visit ``node`` with ``expected`` as context and check the resulting type.

    After the visit, the node's class is checked for assignability to
    ``expected``'s class; a failure is reported on ``blame`` when given,
    otherwise on ``node``. Returns whatever the visit returned.
    """
    visit_result = self.visit(node, expected)

    actual_class = self.get_type(node).klass
    error_node = blame or node
    self.check_can_assign_from(expected.klass, actual_class, error_node, reason)
    return visit_result
1670
+
1671
def visitList(
    self, node: ast.List, type_ctx: Class | None = None
) -> NarrowingEffect:
    """Bind a list display, widening the item type across all elements.

    A starred element contributes the type argument of a checked list when
    its operand is one, and the dynamic type otherwise.
    """
    widened: Value | None = None
    for elt in node.elts:
        self.visitExpectedType(elt, self.type_env.DYNAMIC)

        if not isinstance(elt, ast.Starred):
            contributed = self.get_type(elt)
        else:
            unpacked = self.get_type(elt.value)
            if isinstance(unpacked, CheckedListInstance):
                contributed = unpacked.klass.type_args[0].instance
            else:
                contributed = self.type_env.DYNAMIC

        widened = self.widen(widened, contributed)

    self.set_list_type(node, widened, type_ctx)
    return NO_EFFECT
1688
+
1689
def visitTuple(
    self, node: ast.Tuple, type_ctx: Class | None = None
) -> NarrowingEffect:
    """Bind a tuple display: elements must be dynamic-compatible; result is an exact tuple."""
    for member in node.elts:
        self.visitExpectedType(member, self.type_env.DYNAMIC)

    exact_tuple = self.type_env.tuple.exact_type()
    self.set_type(node, exact_tuple.instance)
    return NO_EFFECT
1696
+
1697
def set_terminal_kind(self, node: AST, level: TerminalKind) -> None:
    """Raise the recorded terminal kind of ``node`` to ``level``; never lower it."""
    recorded = self.terminals.get(node, TerminalKind.NonTerminal)
    if not recorded < level:
        return
    self.terminals[node] = level
1701
+
1702
def visitContinue(self, node: ast.Continue) -> None:
    """Attach the current loop to a ``continue`` and mark the statement terminal."""
    enclosing_loop = self.current_loop
    self.set_node_data(node, AST, enclosing_loop)
    self.set_terminal_kind(node, TerminalKind.BreakOrContinue)
1705
+
1706
def visitBreak(self, node: ast.Break) -> None:
    """Mark a ``break`` as terminal and note that the current loop may break."""
    self.set_terminal_kind(node, TerminalKind.BreakOrContinue)
    if self.current_loop is None:
        return
    self.loop_may_break.add(self.current_loop)
1710
+
1711
def visitRaise(self, node: ast.Raise) -> None:
    """Mark a ``raise`` as terminal, then visit its children generically."""
    terminal_kind = TerminalKind.RaiseOrReturn
    self.set_terminal_kind(node, terminal_kind)
    self.generic_visit(node)
1714
+
1715
def visitReturn(self, node: Return) -> None:
    """Mark a ``return`` as terminal and check the returned value's type.

    Inside a (possibly async) function the value is visited against that
    function's expected return type; elsewhere against the dynamic type. A
    non-dynamic returned class that does not fit the expected class is
    reported as a syntax error.
    """
    self.set_terminal_kind(node, TerminalKind.RaiseOrReturn)

    returned_expr = node.value
    if returned_expr is None:
        return

    enclosing = self.binding_scope.node
    expected = self.type_env.DYNAMIC
    if isinstance(enclosing, (ast.FunctionDef, ast.AsyncFunctionDef)):
        expected = self.get_func_container(enclosing).get_expected_return()

    self.visit(returned_expr, expected)
    returned = self.get_type(returned_expr).klass

    # Dynamic values are always accepted here.
    if returned is self.type_env.dynamic:
        return
    if expected.klass.can_assign_from(returned):
        return

    message = resolve_assign_error_msg(
        expected.klass,
        returned,
        "mismatched types: expected {1} because of return type, found {0} instead",
    )
    self.syntax_error(message, node)
1739
+
1740
def visitImport(self, node: Import) -> None:
    """Declare function-level ``import`` names as locals to retain type info.

    Imports outside a (possibly async) function are left alone. For
    ``import a.b`` the top-level package ``a`` is imported and declared; for
    ``import a.b as c`` the full module ``a.b`` is imported and declared as
    ``c``. A module the compiler cannot load is declared as dynamic.
    """
    if not isinstance(self.scope, (FunctionDef, AsyncFunctionDef)):
        return

    for alias in node.names:
        if alias.asname is None:
            import_name = alias.name.split(".")[0]
        else:
            import_name = alias.name
        declaration_name = alias.asname or import_name.split(".")[0]

        if import_name not in self.compiler.modules:
            self.compiler.import_module(import_name, optimize=self.optimize)

        # The import attempt above may or may not have registered the module.
        typ = (
            ModuleInstance(import_name, self.compiler)
            if import_name in self.compiler.modules
            else self.type_env.DYNAMIC
        )
        self.declare_local(declaration_name, typ)
1756
+
1757
def visitImportFrom(self, node: ImportFrom) -> None:
    """Bind a ``from module import names`` statement.

    Relative imports raise ``NotImplementedError``. Imports from
    ``__static__`` are validated (no ``*``, only known intrinsics). Names
    from a module the compiler cannot load are declared as dynamic locals.
    Inside a function, names from a known module are declared as locals
    with the imported child's type (dynamic if unknown) and a dependency is
    recorded.
    """
    mod_name = node.module
    if node.level or not mod_name:
        raise NotImplementedError("relative imports aren't supported")

    if mod_name == "__static__":
        for alias in node.names:
            imported = alias.name
            if imported == "*":
                self.syntax_error("from __static__ import * is disallowed", node)
            # no need to track dependencies to statics
            elif self.compiler.statics.get_child_intrinsic(imported) is None:
                self.syntax_error(f"unsupported static import {imported}", node)

    if mod_name not in self.compiler.modules:
        self.compiler.import_module(mod_name, optimize=self.optimize)

    if mod_name not in self.compiler.modules:
        # Unknown module, let's add a local dynamic type to ensure we don't
        # try to infer too much.
        for alias in node.names:
            local_name: str = alias.name if alias.asname is None else alias.asname
            self.declare_local(local_name, self.type_env.DYNAMIC)
        return

    if not isinstance(self.scope, (FunctionDef, AsyncFunctionDef)):
        return

    # If we're doing an import within a function, we need to declare the
    # import to retain type information.
    for alias in node.names:
        local_name: str = alias.name if alias.asname is None else alias.asname
        context_qualname = self.context_qualname
        # it's only None at module scope, we know we are at function scope here
        assert context_qualname is not None
        self.module.record_dependency(context_qualname, (mod_name, alias.name))
        child = self.compiler.modules[mod_name].get_child(
            alias.name, context_qualname
        )
        self.declare_local(
            local_name,
            (child or self.type_env.DYNAMIC),
        )
1795
+
1796
def visit_check_terminal(self, nodes: Sequence[ast.stmt]) -> TerminalKind:
    """Visit every statement in ``nodes`` and report how the sequence ends.

    Returns the terminal kind recorded in ``self.terminals`` for the first
    statement that has a non-``NonTerminal`` entry, or
    ``TerminalKind.NonTerminal`` when no statement has one.
    """
    result = TerminalKind.NonTerminal
    for stmt in nodes:
        self.visit(stmt)
        if result != TerminalKind.NonTerminal or stmt not in self.terminals:
            continue
        # We have concluded the remainder of `nodes` are unreachable,
        # but we type-check them anyway for UX consistency
        result = self.terminals[stmt]

    return result
1806
+
1807
def get_bool_const(self, node: ast.expr) -> bool | None:
    """Return the known constant truth value of ``node``, if one was recorded.

    Returns ``None`` when no ``KnownBoolean`` data is attached to the node.
    """
    known = self.get_opt_node_data(node, KnownBoolean)
    if known is None:
        return None
    return known == KnownBoolean.TRUE
1811
+
1812
def visitIf(self, node: If) -> None:
    """Bind an ``if`` statement and reconcile the type state of its branches.

    The effect returned by visiting the test is applied before ``node.body``
    is visited and reversed before ``node.orelse`` is visited. A branch is
    skipped when ``get_bool_const`` reports a constant test that rules it out.
    """
    self.set_node_data(node, PreserveRefinedFields, PRESERVE_REFINED_FIELDS)

    test_effect = self.visit(node.test) or NO_EFFECT

    # Visit body/orelse blocks depending on whether the condition
    # is determined to be a constant bool
    test_const = self.get_bool_const(node.test)
    visit_body = test_const is not False
    visit_orelse = test_const is not True

    self.clear_refinements_for_nonbool_test(node.test)
    branch = self.binding_scope.branch()
    test_effect.apply(self.type_state)

    body_terminal = (
        self.visit_check_terminal(node.body)
        if visit_body
        else TerminalKind.NonTerminal
    )

    if visit_orelse and node.orelse:
        body_end = branch.copy()
        branch.restore()

        test_effect.reverse(self.type_state)
        orelse_terminal = self.visit_check_terminal(node.orelse)
        if orelse_terminal and body_terminal:
            # We're the least severe terminal of our two children
            self.terminals[node] = min(body_terminal, orelse_terminal)
        elif orelse_terminal:
            # Only the body falls through, so its end state is what survives.
            branch.restore(body_end)
        elif not body_terminal:
            # Merge end of orelse with end of if
            branch.merge(body_end)
    elif body_terminal == TerminalKind.NonTerminal:
        # Merge end of if w/ opening (with test effect reversed)
        test_effect.reverse(branch.entry_type_state)
        branch.merge()
    elif body_terminal == TerminalKind.BreakOrContinue:
        branch.merge()
        test_effect.reverse(self.type_state)
    elif body_terminal == TerminalKind.RaiseOrReturn:
        branch.restore()
        test_effect.reverse(self.type_state)
1859
+
1860
def visitTry(self, node: Try) -> None:
    """Bind a ``try`` statement and merge the type state of all its paths.

    Each handler starts from a state in which any statement of the ``try``
    body may or may not have executed. The state after the statement joins
    the "no exception" path (body, then ``else``) with every handler that
    does not terminate.
    """
    # There's a bit of a subtlety here: when executing each exception
    # handler we must account for any statement in the try body having been
    # executed or not; that is captured in `partial_body_state`. When later
    # merging, the end-of-try state is merged with the handlers, because a
    # handler not running means the try body ran to completion.
    branch = self.binding_scope.branch()
    body_terminal = self.visit_check_terminal(node.body)
    no_exception_state = branch.copy()

    branch.merge()
    partial_body_state = branch.copy()

    else_terminal = TerminalKind.NonTerminal
    if node.orelse:
        branch.restore(no_exception_state)
        else_terminal = self.visit_check_terminal(node.orelse)
        no_exception_state = branch.copy()

    terminals = [max(body_terminal, else_terminal)]
    fallthrough_states = []
    for handler in node.handlers:
        branch.restore(partial_body_state.copy())
        self.visit(handler)
        handler_terminal = self.terminals.get(handler, TerminalKind.NonTerminal)
        terminals.append(handler_terminal)
        # The types in the terminal branches should not influence the later
        # type inference.
        if handler_terminal == TerminalKind.NonTerminal:
            fallthrough_states.append(branch.copy())

    branch.restore(no_exception_state)
    for state in fallthrough_states:
        branch.merge(state)

    overall_terminal = min(terminals)

    if node.finalbody:
        finally_terminal = self.visit_check_terminal(node.finalbody)
        # A terminal ``finally`` overrides whatever the other paths concluded.
        if finally_terminal:
            overall_terminal = finally_terminal

    if overall_terminal:
        self.set_terminal_kind(node, overall_terminal)
1906
+
1907
def visitExceptHandler(self, node: ast.ExceptHandler) -> None:
    """Bind an ``except`` clause, scoping its ``as`` name to the handler body.

    When the clause names its exception, the name is declared for the
    duration of the body and removed from the declared and local types
    afterwards.
    """
    exc_expr = node.type
    bound_name = None
    if exc_expr:
        self.visit(exc_expr)
        caught = self.get_type(exc_expr)
        bound_name = node.name
        if bound_name:
            is_known_class = caught is not self.type_env.DYNAMIC and isinstance(
                caught, Class
            )
            if not is_known_class:
                caught = self.type_env.dynamic

            caught = caught.inexact_type()
            declared = self.decl_types.get(bound_name)
            if declared and declared.is_final:
                self.syntax_error("Cannot assign to a Final variable", node)

            self.binding_scope.declare(bound_name, caught.instance)

    body_terminal = self.visit_check_terminal(node.body)
    if body_terminal:
        self.set_terminal_kind(node, body_terminal)

    if bound_name is not None:
        # Drop the handler-scoped name again now that the body is bound.
        del self.decl_types[bound_name]
        del self.type_state.local_types[bound_name]
1933
+
1934
def iterate_to_fixed_point(
    self, body: Sequence[ast.stmt], test: ast.expr | None = None
) -> None:
    """Iterate given loop body until local types reach a fixed point.

    Errors reported while iterating go to a temporary ``CollectingErrorSink``
    rather than the active sink.

    Raises:
        AssertionError: if more than 50 passes are needed.
    """
    declarations_at_entry = self.decl_types.copy()
    branch: LocalsBranch | None = None
    passes = 0
    while (not branch) or branch.changed():
        branch = self.binding_scope.branch()
        passes += 1
        if passes > 50:
            # TODO today it should not be possible to hit this case, but in
            # the future with more complex types or more accurate tracking
            # of literal types (think `x += 1` in a loop) it could become
            # possible, and we'll need a smarter approach here: union all
            # types seen? fall back to declared type?
            raise AssertionError("Too many loops in fixed-point iteration.")
        with self.temporary_error_sink(CollectingErrorSink()):
            if test is not None:
                test_effect = self.visit(test) or NO_EFFECT
                test_effect.apply(self.type_state)
                self.clear_refinements_for_nonbool_test(test)

            self.visit_check_terminal(body)
            # reset any declarations from the loop body to avoid redeclaration errors
            self.binding_scope.decl_types = declarations_at_entry.copy()
            branch.merge()
1961
+
1962
@contextmanager
def in_loop(self, node: AST) -> Generator[None, None, None]:
    """Make ``node`` the current loop for the duration of the ``with`` block.

    The previous ``current_loop`` is restored on exit, including when the
    block raises.
    """
    previous_loop, self.current_loop = self.current_loop, node
    try:
        yield
    finally:
        self.current_loop = previous_loop
1970
+
1971
# Bind a ``while`` loop.
#
# The loop is first run through ``iterate_to_fixed_point`` (with the test),
# then the test and body are visited once more to obtain the test's effect
# and the body's terminal kind. The statement itself is only marked terminal
# when the test is a truthy literal and the loop is not in
# ``self.loop_may_break``.
#
# NOTE(review): this rendering has lost the original indentation, so the
# extent of the ``with self.in_loop(node)`` block and of the ``if node.orelse``
# block (in particular whether the final ``branch.merge()`` belongs to it)
# cannot be read off here -- confirm against the original file before
# restructuring this method.
+ def visitWhile(self, node: While) -> None:
1972
+ self.set_node_data(node, PreserveRefinedFields, PRESERVE_REFINED_FIELDS)
1973
+
1974
+ branch = self.scopes[-1].branch()
1975
+
1976
+ with self.in_loop(node):
1977
+ self.iterate_to_fixed_point(node.body, node.test)
1978
+
# Visit test and body once more now that the fixed-point pass has run; this
# pass yields the effect and terminal kind used below.
1979
+ effect = self.visit(node.test) or NO_EFFECT
1980
+ condition_always_true = self.get_type(node.test).is_truthy_literal()
1981
+ effect.apply(self.type_state)
1982
+ terminal_level = self.visit_check_terminal(node.body)
1983
+
1984
+ self.clear_refinements_for_nonbool_test(node.test)
1985
+
# A loop absent from ``loop_may_break`` is treated as having no ``break``.
1986
+ does_not_break = node not in self.loop_may_break
1987
+
# If the body always raises/returns and never breaks, execution after the
# loop can only come from the test failing on entry: restore the entry state
# and reverse the test's effect. Otherwise merge with the entry state that
# has the reversed effect applied.
1988
+ if terminal_level == TerminalKind.RaiseOrReturn and does_not_break:
1989
+ branch.restore()
1990
+ effect.reverse(self.type_state)
1991
+ else:
1992
+ branch.merge(effect.reverse(branch.entry_type_state))
1993
+
1994
+ if condition_always_true and does_not_break:
1995
+ self.set_terminal_kind(node, terminal_level)
1996
+ if node.orelse:
1997
+ # The or-else can happen after the while body, or without executing
1998
+ # it, but it can only happen after the while condition evaluates to
1999
+ # False.
2000
+ effect.reverse(self.type_state)
2001
+ self.visit_list(node.orelse)
2002
+
2003
+ branch.merge()
2004
+
2005
# Bind a ``for`` loop.
#
# Visits ``node.iter``, asks the iterable's type for the element type via
# ``get_iter_type``, lets that type bind the loop target, assigns the element
# type to the target, then runs the body to a fixed point and visits it once
# more before visiting ``node.orelse`` and merging the branch.
#
# NOTE(review): this rendering has lost the original indentation, so how far
# the first ``with self.in_loop(node)`` and the ``with self.in_target()``
# blocks extend cannot be read off here -- confirm against the original file
# before restructuring this method.
+ def visitFor(self, node: For) -> None:
2006
+ with self.in_loop(node):
2007
+ self.visit(node.iter)
2008
+ container_type = self.get_type(node.iter)
2009
+ target_type = container_type.get_iter_type(node.iter, self)
2010
+ with self.in_target():
2011
+ container_type.bind_forloop_target(node.target, self)
2012
+ self.assign_value(node.target, target_type)
# Branch off the innermost scope so the loop body's state can be merged back.
2013
+ branch = self.scopes[-1].branch()
2014
+ with self.in_loop(node):
2015
+ self.iterate_to_fixed_point(node.body)
2016
+ self.visit_list(node.body)
2017
+ self.visit_list(node.orelse)
2018
+ branch.merge()
2019
+
2020
# Bind an ``async for`` loop.
#
# The iterable is visited with the dynamic type as its expected type, and the
# loop target is always assigned the dynamic type. The body is run to a fixed
# point and visited once more, then ``node.orelse`` is visited and the branch
# is merged.
#
# NOTE(review): this rendering has lost the original indentation, so whether
# ``assign_value`` runs inside the ``with self.in_target()`` block cannot be
# read off here -- confirm against the original file before restructuring.
+ def visitAsyncFor(self, node: AsyncFor) -> None:
2021
+ self.visitExpectedType(
2022
+ node.iter, self.type_env.DYNAMIC, "cannot await a primitive value"
2023
+ )
2024
+ target_type = self.type_env.DYNAMIC
2025
+ with self.in_target():
2026
+ self.visit(node.target)
2027
+ self.assign_value(node.target, target_type)
# Branch off the innermost scope so the loop body's state can be merged back.
2028
+ branch = self.scopes[-1].branch()
2029
+ with self.in_loop(node):
2030
+ self.iterate_to_fixed_point(node.body)
2031
+ self.visit_list(node.body)
2032
+ self.visit_list(node.orelse)
2033
+ branch.merge()
2034
+
2035
def visitWith(self, node: ast.With) -> None:
    """Bind a ``with`` statement and decide whether it can be terminal.

    The statement inherits the body's terminal kind only when every context
    manager's ``__exit__`` resolves to a callable whose return type is the
    literal ``False``; any other manager is assumed to possibly suppress
    exceptions.
    """
    self.visit_list(node.items)
    may_suppress_exceptions = False
    for item in node.items:
        ctx_expr = item.context_expr
        ctx_type = self.get_type(ctx_expr)
        if isinstance(ctx_type, Object):
            exit_type = resolve_instance_attr_by_name(
                ctx_expr, "__exit__", ctx_type, self
            )
            # TODO probably MethodType should itself be a Callable
            if isinstance(exit_type, MethodType):
                exit_type = exit_type.function
            if isinstance(exit_type, Callable):
                exit_return = exit_type.return_type.resolved()
                returns_literal_false = (
                    isinstance(exit_return, BoolClass)
                    and exit_return.literal_value is False
                )
                if returns_literal_false:
                    continue
        may_suppress_exceptions = True

    terminates = self.visit_check_terminal(node.body)
    if may_suppress_exceptions:
        return
    self.set_terminal_kind(node, terminates)
2059
+
2060
def visitAsyncWith(self, node: ast.AsyncWith) -> None:
    """Bind an ``async with`` statement.

    Visits the with-items and then each statement of the body in order.
    Unlike ``visitWith``, no terminal kind is recorded for the statement.

    Args:
        node: the ``async with`` statement. The parameter was previously
            annotated as ``ast.With``, which is the node type of the
            synchronous statement.
    """
    self.visit_list(node.items)
    for stmt in node.body:
        self.visit(stmt)
2064
+
2065
def visitMatch(self, node: Match) -> None:
    """Bind a ``match`` statement and merge the state of its cases.

    Every case starts from the state at entry to the statement. Cases whose
    body does not raise or return contribute their end state to the merge.
    """
    self.set_node_data(node, PreserveRefinedFields, PRESERVE_REFINED_FIELDS)
    self.visit(node.subject)

    branch = self.binding_scope.branch()

    fallthrough_states = []
    for match_case in node.cases:
        self.visit(match_case.pattern)

        guard = match_case.guard
        guard_state = None
        if guard:
            # TODO: Apply narrowing effect of the case guard
            self.visit(guard)
            guard_state = branch.copy()

        case_terminal = self.visit_check_terminal(match_case.body)
        # Raising/returning cases never reach the code after the match.
        if case_terminal in (
            TerminalKind.BreakOrContinue,
            TerminalKind.NonTerminal,
        ):
            fallthrough_states.append(branch.copy())

        branch.restore()

        if guard_state is not None:
            branch.merge(guard_state)

    for state in fallthrough_states:
        branch.merge(state)
2096
+
2097
+ def visitMatchValue(self, node: MatchValue) -> None:
2098
+ self.visit(node.value)
2099
+
2100
+ def visitMatchSingleton(self, node: MatchSingleton) -> None:
2101
+ pass
2102
+
2103
+ def visitMatchSequence(self, node: MatchSequence) -> None:
2104
+ for pattern in node.patterns:
2105
+ self.visit(pattern)
2106
+
2107
+ def visitMatchStar(self, node: MatchStar) -> None:
2108
+ name = node.name
2109
+ if name:
2110
+ self.assign_name(node, name, self.type_env.DYNAMIC)
2111
+
2112
+ def visitMatchMapping(self, node: MatchMapping) -> None:
2113
+ for key in node.keys:
2114
+ self.visit(key)
2115
+
2116
+ for pattern in node.patterns:
2117
+ self.visit(pattern)
2118
+
2119
+ rest = node.rest
2120
+ if rest:
2121
+ self.assign_name(node, rest, self.type_env.DYNAMIC)
2122
+
2123
+ def visitMatchClass(self, node: MatchClass) -> None:
2124
+ self.visit(node.cls)
2125
+
2126
+ for pattern in node.patterns:
2127
+ self.visit(pattern)
2128
+
2129
+ for kwd_pattern in node.kwd_patterns:
2130
+ self.visit(kwd_pattern)
2131
+
2132
+ def visitMatchAs(self, node: MatchAs) -> None:
2133
+ # If name is None, pattern must also be None and the node represents the wildcard pattern.
2134
+ name = node.name
2135
+ if name is None:
2136
+ return
2137
+
2138
+ # If the pattern is None, the node represents a capture pattern (i.e a bare name) and will always succeed.
2139
+ if node.pattern is not None:
2140
+ self.visit(node.pattern)
2141
+
2142
+ self.assign_name(node, name, self.type_env.DYNAMIC)
2143
+
2144
+ def visitMatchOr(self, node: MatchOr) -> None:
2145
+ self.binding_scope.branch()
2146
+ for pattern in node.patterns:
2147
+ self.visit(pattern)
2148
+
2149
# Bind a single ``with`` item.
#
# Visits the context expression; when the item has an ``as`` target, the
# target is visited and assigned the dynamic type.
#
# NOTE(review): this rendering has lost the original indentation, so whether
# ``assign_value`` runs inside the ``with self.in_target()`` block cannot be
# read off here -- confirm against the original file before restructuring.
+ def visitwithitem(self, node: ast.withitem) -> None:
2150
+ self.visit(node.context_expr)
2151
+ optional_vars = node.optional_vars
2152
+ if optional_vars:
2153
+ with self.in_target():
2154
+ self.visit(optional_vars)
2155
+ self.assign_value(optional_vars, self.type_env.DYNAMIC)
2156
+
2157
def is_refinable(self, node: ast.AST) -> bool:
    """Return whether ``node`` is an expression whose type can be refined.

    Names and named expressions always qualify. An attribute access
    qualifies only when its receiver is a plain name and the receiver's
    class reports a slot for the attribute.
    """
    if isinstance(node, (Name, ast.NamedExpr)):
        return True
    if not (isinstance(node, ast.Attribute) and isinstance(node.value, Name)):
        return False
    receiver_type = self.get_type(node.value)
    return bool(receiver_type.klass.find_slot(node))
2166
+
2167
def refined_field_index(self, access_path: list[str]) -> int:
    """Return the index assigned to ``access_path``, allocating one if needed.

    Paths are keyed by their dot-joined form; a path seen for the first time
    receives the next unused index (the current number of known paths).
    """
    indices = self._refined_tmpvar_indices
    return indices.setdefault(".".join(access_path), len(indices))
2174
+
2175
def _refined_field_name(self, idx: int) -> str:
    """Build the temporary-variable name used for refined field ``idx``."""
    tmp_name = f"{TMP_VAR_PREFIX}.__refined_field__.{idx}"
    return tmp_name