guppylang-internals 0.26.0__py3-none-any.whl → 0.28.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47)
  1. guppylang_internals/__init__.py +1 -1
  2. guppylang_internals/ast_util.py +37 -18
  3. guppylang_internals/cfg/analysis.py +6 -6
  4. guppylang_internals/cfg/builder.py +44 -12
  5. guppylang_internals/cfg/cfg.py +1 -1
  6. guppylang_internals/checker/core.py +1 -1
  7. guppylang_internals/checker/errors/comptime_errors.py +0 -12
  8. guppylang_internals/checker/errors/linearity.py +6 -2
  9. guppylang_internals/checker/expr_checker.py +53 -28
  10. guppylang_internals/checker/func_checker.py +4 -3
  11. guppylang_internals/checker/stmt_checker.py +1 -1
  12. guppylang_internals/compiler/cfg_compiler.py +1 -1
  13. guppylang_internals/compiler/core.py +17 -4
  14. guppylang_internals/compiler/expr_compiler.py +36 -14
  15. guppylang_internals/compiler/modifier_compiler.py +5 -2
  16. guppylang_internals/decorator.py +5 -3
  17. guppylang_internals/definition/common.py +1 -0
  18. guppylang_internals/definition/custom.py +2 -2
  19. guppylang_internals/definition/declaration.py +3 -3
  20. guppylang_internals/definition/function.py +28 -8
  21. guppylang_internals/definition/metadata.py +87 -0
  22. guppylang_internals/definition/overloaded.py +11 -2
  23. guppylang_internals/definition/pytket_circuits.py +50 -67
  24. guppylang_internals/definition/value.py +1 -1
  25. guppylang_internals/definition/wasm.py +3 -3
  26. guppylang_internals/diagnostic.py +89 -16
  27. guppylang_internals/engine.py +84 -40
  28. guppylang_internals/error.py +1 -1
  29. guppylang_internals/nodes.py +301 -3
  30. guppylang_internals/span.py +7 -3
  31. guppylang_internals/std/_internal/checker.py +104 -2
  32. guppylang_internals/std/_internal/compiler/array.py +36 -1
  33. guppylang_internals/std/_internal/compiler/either.py +14 -2
  34. guppylang_internals/std/_internal/compiler/tket_bool.py +1 -6
  35. guppylang_internals/std/_internal/compiler/tket_exts.py +1 -1
  36. guppylang_internals/std/_internal/debug.py +5 -3
  37. guppylang_internals/tracing/builtins_mock.py +2 -2
  38. guppylang_internals/tracing/object.py +6 -2
  39. guppylang_internals/tys/parsing.py +4 -1
  40. guppylang_internals/tys/qubit.py +6 -4
  41. guppylang_internals/tys/subst.py +2 -2
  42. guppylang_internals/tys/ty.py +2 -2
  43. guppylang_internals/wasm_util.py +2 -3
  44. {guppylang_internals-0.26.0.dist-info → guppylang_internals-0.28.0.dist-info}/METADATA +5 -4
  45. {guppylang_internals-0.26.0.dist-info → guppylang_internals-0.28.0.dist-info}/RECORD +47 -46
  46. {guppylang_internals-0.26.0.dist-info → guppylang_internals-0.28.0.dist-info}/WHEEL +0 -0
  47. {guppylang_internals-0.26.0.dist-info → guppylang_internals-0.28.0.dist-info}/licenses/LICENCE +0 -0
@@ -31,6 +31,14 @@ class PlaceNode(ast.expr):
31
31
 
32
32
  _fields = ("place",)
33
33
 
34
+ def __init__(self, place: "Place") -> None:
35
+ super().__init__()
36
+ self.place = place
37
+
38
+ # See MakeIter for explanation
39
+ __reduce__ = object.__reduce__
40
+ __reduce_ex__ = object.__reduce_ex__
41
+
34
42
 
35
43
  class GlobalName(ast.Name):
36
44
  id: str
@@ -41,6 +49,15 @@ class GlobalName(ast.Name):
41
49
  "def_id",
42
50
  )
43
51
 
52
+ def __init__(self, id: str, def_id: "DefId") -> None:
53
+ super().__init__(id=id)
54
+ self.id = id
55
+ self.def_id = def_id
56
+
57
+ # See MakeIter for explanation
58
+ __reduce__ = object.__reduce__
59
+ __reduce_ex__ = object.__reduce_ex__
60
+
44
61
 
45
62
  class GenericParamValue(ast.Name):
46
63
  id: str
@@ -51,6 +68,15 @@ class GenericParamValue(ast.Name):
51
68
  "param",
52
69
  )
53
70
 
71
+ def __init__(self, id: str, param: "ConstParam") -> None:
72
+ super().__init__(id=id)
73
+ self.id = id
74
+ self.param = param
75
+
76
+ # See MakeIter for explanation
77
+ __reduce__ = object.__reduce__
78
+ __reduce_ex__ = object.__reduce_ex__
79
+
54
80
 
55
81
  class LocalCall(ast.expr):
56
82
  func: ast.expr
@@ -61,6 +87,15 @@ class LocalCall(ast.expr):
61
87
  "args",
62
88
  )
63
89
 
90
+ def __init__(self, func: ast.expr, args: list[ast.expr]) -> None:
91
+ super().__init__()
92
+ self.func = func
93
+ self.args = args
94
+
95
+ # See MakeIter for explanation
96
+ __reduce__ = object.__reduce__
97
+ __reduce_ex__ = object.__reduce_ex__
98
+
64
99
 
65
100
  class GlobalCall(ast.expr):
66
101
  def_id: "DefId"
@@ -73,6 +108,16 @@ class GlobalCall(ast.expr):
73
108
  "type_args",
74
109
  )
75
110
 
111
+ def __init__(self, def_id: "DefId", args: list[ast.expr], type_args: Inst) -> None:
112
+ super().__init__()
113
+ self.def_id = def_id
114
+ self.args = args
115
+ self.type_args = type_args
116
+
117
+ # See MakeIter for explanation
118
+ __reduce__ = object.__reduce__
119
+ __reduce_ex__ = object.__reduce_ex__
120
+
76
121
 
77
122
  class TensorCall(ast.expr):
78
123
  """A call to a tuple of functions. Behaves like a local call, but more
@@ -88,6 +133,18 @@ class TensorCall(ast.expr):
88
133
  "tensor_ty",
89
134
  )
90
135
 
136
+ def __init__(
137
+ self, func: ast.expr, args: list[ast.expr], tensor_ty: FunctionType
138
+ ) -> None:
139
+ super().__init__()
140
+ self.func = func
141
+ self.args = args
142
+ self.tensor_ty = tensor_ty
143
+
144
+ # See MakeIter for explanation
145
+ __reduce__ = object.__reduce__
146
+ __reduce_ex__ = object.__reduce_ex__
147
+
91
148
 
92
149
  class TypeApply(ast.expr):
93
150
  value: ast.expr
@@ -98,6 +155,15 @@ class TypeApply(ast.expr):
98
155
  "inst",
99
156
  )
100
157
 
158
+ def __init__(self, value: ast.expr, inst: Inst) -> None:
159
+ super().__init__()
160
+ self.value = value
161
+ self.inst = inst
162
+
163
+ # See MakeIter for explanation
164
+ __reduce__ = object.__reduce__
165
+ __reduce_ex__ = object.__reduce_ex__
166
+
101
167
 
102
168
  class PartialApply(ast.expr):
103
169
  """A partial function application.
@@ -114,6 +180,15 @@ class PartialApply(ast.expr):
114
180
  "args",
115
181
  )
116
182
 
183
+ def __init__(self, func: ast.expr, args: list[ast.expr]) -> None:
184
+ super().__init__()
185
+ self.func = func
186
+ self.args = args
187
+
188
+ # See MakeIter for explanation
189
+ __reduce__ = object.__reduce__
190
+ __reduce_ex__ = object.__reduce_ex__
191
+
117
192
 
118
193
  class FieldAccessAndDrop(ast.expr):
119
194
  """A field access on a struct, dropping all the remaining other fields."""
@@ -128,6 +203,18 @@ class FieldAccessAndDrop(ast.expr):
128
203
  "field",
129
204
  )
130
205
 
206
+ def __init__(
207
+ self, value: ast.expr, struct_ty: "StructType", field: "StructField"
208
+ ) -> None:
209
+ super().__init__()
210
+ self.value = value
211
+ self.struct_ty = struct_ty
212
+ self.field = field
213
+
214
+ # See MakeIter for explanation
215
+ __reduce__ = object.__reduce__
216
+ __reduce_ex__ = object.__reduce_ex__
217
+
131
218
 
132
219
  class SubscriptAccessAndDrop(ast.expr):
133
220
  """A subscript element access on an object, dropping all the remaining items."""
@@ -139,6 +226,23 @@ class SubscriptAccessAndDrop(ast.expr):
139
226
 
140
227
  _fields = ("item", "item_expr", "getitem_expr", "original_expr")
141
228
 
229
+ def __init__(
230
+ self,
231
+ item: "Variable",
232
+ item_expr: ast.expr,
233
+ getitem_expr: ast.expr,
234
+ original_expr: ast.Subscript,
235
+ ) -> None:
236
+ super().__init__()
237
+ self.item = item
238
+ self.item_expr = item_expr
239
+ self.getitem_expr = getitem_expr
240
+ self.original_expr = original_expr
241
+
242
+ # See MakeIter for explanation
243
+ __reduce__ = object.__reduce__
244
+ __reduce_ex__ = object.__reduce_ex__
245
+
142
246
 
143
247
  class TupleAccessAndDrop(ast.expr):
144
248
  """A subscript element access on a tuple, dropping all the remaining items."""
@@ -149,6 +253,16 @@ class TupleAccessAndDrop(ast.expr):
149
253
 
150
254
  _fields = ("value", "tuple_ty", "index")
151
255
 
256
+ def __init__(self, value: ast.expr, tuple_ty: TupleType, index: int) -> None:
257
+ super().__init__()
258
+ self.value = value
259
+ self.tuple_ty = tuple_ty
260
+ self.index = index
261
+
262
+ # See MakeIter for explanation
263
+ __reduce__ = object.__reduce__
264
+ __reduce_ex__ = object.__reduce_ex__
265
+
152
266
 
153
267
  class MakeIter(ast.expr):
154
268
  """Creates an iterator using the `__iter__` magic method.
@@ -168,10 +282,19 @@ class MakeIter(ast.expr):
168
282
  def __init__(
169
283
  self, value: ast.expr, origin_node: ast.AST, unwrap_size_hint: bool = True
170
284
  ) -> None:
171
- super().__init__(value)
285
+ super().__init__()
286
+ self.value = value
172
287
  self.origin_node = origin_node
173
288
  self.unwrap_size_hint = unwrap_size_hint
174
289
 
290
+ # Needed for the deepcopy to work correctly, ast.AST's deepcopy logic
291
+ # reconstructs nodes using _fields only.
292
+ # If you store extra attributes or rely on overriding __init__,
293
+ # deepcopy will crash with a constructor mismatch.
294
+ # Overriding __reduce__ forces deepcopy to copy the instance dictionary instead
295
+ __reduce_ex__ = object.__reduce_ex__
296
+ __reduce__ = object.__reduce__
297
+
175
298
 
176
299
  class IterNext(ast.expr):
177
300
  """Obtains the next element of an iterator using the `__next__` magic method.
@@ -183,6 +306,14 @@ class IterNext(ast.expr):
183
306
 
184
307
  _fields = ("value",)
185
308
 
309
+ def __init__(self, value: ast.expr) -> None:
310
+ super().__init__()
311
+ self.value = value
312
+
313
+ # See MakeIter for explanation
314
+ __reduce__ = object.__reduce__
315
+ __reduce_ex__ = object.__reduce_ex__
316
+
186
317
 
187
318
  class DesugaredGenerator(ast.expr):
188
319
  """A single desugared generator in a list comprehension.
@@ -207,6 +338,27 @@ class DesugaredGenerator(ast.expr):
207
338
  "ifs",
208
339
  )
209
340
 
341
+ def __init__(
342
+ self,
343
+ iter_assign: ast.Assign,
344
+ next_call: ast.expr,
345
+ iter: ast.expr,
346
+ target: ast.expr,
347
+ ifs: list[ast.expr],
348
+ used_outer_places: "list[Place]",
349
+ ) -> None:
350
+ super().__init__()
351
+ self.iter_assign = iter_assign
352
+ self.next_call = next_call
353
+ self.iter = iter
354
+ self.target = target
355
+ self.ifs = ifs
356
+ self.used_outer_places = used_outer_places
357
+
358
+ # See MakeIter for explanation
359
+ __reduce__ = object.__reduce__
360
+ __reduce_ex__ = object.__reduce_ex__
361
+
210
362
 
211
363
  class DesugaredGeneratorExpr(ast.expr):
212
364
  """A desugared generator expression."""
@@ -219,6 +371,15 @@ class DesugaredGeneratorExpr(ast.expr):
219
371
  "generators",
220
372
  )
221
373
 
374
+ def __init__(self, elt: ast.expr, generators: list[DesugaredGenerator]) -> None:
375
+ super().__init__()
376
+ self.elt = elt
377
+ self.generators = generators
378
+
379
+ # See MakeIter for explanation
380
+ __reduce__ = object.__reduce__
381
+ __reduce_ex__ = object.__reduce_ex__
382
+
222
383
 
223
384
  class DesugaredListComp(ast.expr):
224
385
  """A desugared list comprehension."""
@@ -231,6 +392,15 @@ class DesugaredListComp(ast.expr):
231
392
  "generators",
232
393
  )
233
394
 
395
+ def __init__(self, elt: ast.expr, generators: list[DesugaredGenerator]) -> None:
396
+ super().__init__()
397
+ self.elt = elt
398
+ self.generators = generators
399
+
400
+ # See MakeIter for explanation
401
+ __reduce__ = object.__reduce__
402
+ __reduce_ex__ = object.__reduce_ex__
403
+
234
404
 
235
405
  class DesugaredArrayComp(ast.expr):
236
406
  """A desugared array comprehension."""
@@ -247,6 +417,19 @@ class DesugaredArrayComp(ast.expr):
247
417
  "elt_ty",
248
418
  )
249
419
 
420
+ def __init__(
421
+ self, elt: ast.expr, generator: DesugaredGenerator, length: Const, elt_ty: Type
422
+ ) -> None:
423
+ super().__init__()
424
+ self.elt = elt
425
+ self.generator = generator
426
+ self.length = length
427
+ self.elt_ty = elt_ty
428
+
429
+ # See MakeIter for explanation
430
+ __reduce__ = object.__reduce__
431
+ __reduce_ex__ = object.__reduce_ex__
432
+
250
433
 
251
434
  class ComptimeExpr(ast.expr):
252
435
  """A compile-time evaluated `py(...)` expression."""
@@ -255,6 +438,14 @@ class ComptimeExpr(ast.expr):
255
438
 
256
439
  _fields = ("value",)
257
440
 
441
+ def __init__(self, value: ast.expr) -> None:
442
+ super().__init__()
443
+ self.value = value
444
+
445
+ # See MakeIter for explanation
446
+ __reduce__ = object.__reduce__
447
+ __reduce_ex__ = object.__reduce_ex__
448
+
258
449
 
259
450
  class ExitKind(Enum):
260
451
  ExitShot = 0 # Exit the current shot
@@ -271,6 +462,19 @@ class PanicExpr(ast.expr):
271
462
 
272
463
  _fields = ("kind", "signal", "msg", "values")
273
464
 
465
+ def __init__(
466
+ self, kind: ExitKind, signal: ast.expr, msg: ast.expr, values: list[ast.expr]
467
+ ) -> None:
468
+ super().__init__()
469
+ self.kind = kind
470
+ self.signal = signal
471
+ self.msg = msg
472
+ self.values = values
473
+
474
+ # See MakeIter for explanation
475
+ __reduce__ = object.__reduce__
476
+ __reduce_ex__ = object.__reduce_ex__
477
+
274
478
 
275
479
  class BarrierExpr(ast.expr):
276
480
  """A `barrier(*args)` expression."""
@@ -279,6 +483,15 @@ class BarrierExpr(ast.expr):
279
483
  func_ty: FunctionType
280
484
  _fields = ("args", "func_ty")
281
485
 
486
+ def __init__(self, args: list[ast.expr], func_ty: FunctionType) -> None:
487
+ super().__init__()
488
+ self.args = args
489
+ self.func_ty = func_ty
490
+
491
+ # See MakeIter for explanation
492
+ __reduce__ = object.__reduce__
493
+ __reduce_ex__ = object.__reduce_ex__
494
+
282
495
 
283
496
  class StateResultExpr(ast.expr):
284
497
  """A `state_result(tag, *args)` expression."""
@@ -291,6 +504,25 @@ class StateResultExpr(ast.expr):
291
504
  array_len: Const | None
292
505
  _fields = ("tag_value", "tag_expr", "args", "func_ty", "has_array_input")
293
506
 
507
+ def __init__(
508
+ self,
509
+ tag_value: Const,
510
+ tag_expr: ast.expr,
511
+ args: list[ast.expr],
512
+ func_ty: FunctionType,
513
+ array_len: Const | None,
514
+ ) -> None:
515
+ super().__init__()
516
+ self.tag_value = tag_value
517
+ self.tag_expr = tag_expr
518
+ self.args = args
519
+ self.func_ty = func_ty
520
+ self.array_len = array_len
521
+
522
+ # See MakeIter for explanation
523
+ __reduce__ = object.__reduce__
524
+ __reduce_ex__ = object.__reduce_ex__
525
+
294
526
 
295
527
  AnyCall = LocalCall | GlobalCall | TensorCall | BarrierExpr | StateResultExpr
296
528
 
@@ -303,6 +535,14 @@ class InoutReturnSentinel(ast.expr):
303
535
 
304
536
  _fields = ("var",)
305
537
 
538
+ def __init__(self, var: "Place | str") -> None:
539
+ super().__init__()
540
+ self.var = var
541
+
542
+ # See MakeIter for explanation
543
+ __reduce__ = object.__reduce__
544
+ __reduce_ex__ = object.__reduce_ex__
545
+
306
546
 
307
547
  class UnpackPattern(ast.expr):
308
548
  """The LHS of an unpacking assignment like `a, *bs, c = ...` or
@@ -320,6 +560,18 @@ class UnpackPattern(ast.expr):
320
560
 
321
561
  _fields = ("left", "starred", "right")
322
562
 
563
+ def __init__(
564
+ self, left: list[ast.expr], starred: ast.expr | None, right: list[ast.expr]
565
+ ) -> None:
566
+ super().__init__()
567
+ self.left = left
568
+ self.starred = starred
569
+ self.right = right
570
+
571
+ # See MakeIter for explanation
572
+ __reduce__ = object.__reduce__
573
+ __reduce_ex__ = object.__reduce_ex__
574
+
323
575
 
324
576
  class TupleUnpack(ast.expr):
325
577
  """The LHS of an unpacking assignment of a tuple."""
@@ -329,6 +581,14 @@ class TupleUnpack(ast.expr):
329
581
 
330
582
  _fields = ("pattern",)
331
583
 
584
+ def __init__(self, pattern: UnpackPattern) -> None:
585
+ super().__init__()
586
+ self.pattern = pattern
587
+
588
+ # See MakeIter for explanation
589
+ __reduce__ = object.__reduce__
590
+ __reduce_ex__ = object.__reduce_ex__
591
+
332
592
 
333
593
  class ArrayUnpack(ast.expr):
334
594
  """The LHS of an unpacking assignment of an array."""
@@ -345,10 +605,15 @@ class ArrayUnpack(ast.expr):
345
605
  _fields = ("pattern",)
346
606
 
347
607
  def __init__(self, pattern: UnpackPattern, length: int, elt_type: Type) -> None:
348
- super().__init__(pattern)
608
+ super().__init__()
609
+ self.pattern = pattern
349
610
  self.length = length
350
611
  self.elt_type = elt_type
351
612
 
613
+ # See MakeIter for explanation
614
+ __reduce__ = object.__reduce__
615
+ __reduce_ex__ = object.__reduce_ex__
616
+
352
617
 
353
618
  class IterableUnpack(ast.expr):
354
619
  """The LHS of an unpacking assignment of an iterable type."""
@@ -369,10 +634,15 @@ class IterableUnpack(ast.expr):
369
634
  def __init__(
370
635
  self, pattern: UnpackPattern, compr: DesugaredArrayComp, rhs_var: PlaceNode
371
636
  ) -> None:
372
- super().__init__(pattern)
637
+ super().__init__()
638
+ self.pattern = pattern
373
639
  self.compr = compr
374
640
  self.rhs_var = rhs_var
375
641
 
642
+ # See MakeIter for explanation
643
+ __reduce__ = object.__reduce__
644
+ __reduce_ex__ = object.__reduce_ex__
645
+
376
646
 
377
647
  #: Any unpacking operation.
378
648
  AnyUnpack = TupleUnpack | ArrayUnpack | IterableUnpack
@@ -388,6 +658,10 @@ class NestedFunctionDef(ast.FunctionDef):
388
658
  self.cfg = cfg
389
659
  self.ty = ty
390
660
 
661
+ # See MakeIter for explanation
662
+ __reduce__ = object.__reduce__
663
+ __reduce_ex__ = object.__reduce_ex__
664
+
391
665
 
392
666
  class CheckedNestedFunctionDef(ast.FunctionDef):
393
667
  def_id: "DefId"
@@ -413,6 +687,10 @@ class CheckedNestedFunctionDef(ast.FunctionDef):
413
687
  self.ty = ty
414
688
  self.captured = captured
415
689
 
690
+ # See MakeIter for explanation
691
+ __reduce__ = object.__reduce__
692
+ __reduce_ex__ = object.__reduce_ex__
693
+
416
694
 
417
695
  class Dagger(ast.expr):
418
696
  """The dagger modifier"""
@@ -420,6 +698,10 @@ class Dagger(ast.expr):
420
698
  def __init__(self, node: ast.expr) -> None:
421
699
  super().__init__(**node.__dict__)
422
700
 
701
+ # See MakeIter for explanation
702
+ __reduce__ = object.__reduce__
703
+ __reduce_ex__ = object.__reduce_ex__
704
+
423
705
 
424
706
  class Control(ast.Call):
425
707
  """The control modifier"""
@@ -434,6 +716,10 @@ class Control(ast.Call):
434
716
  self.ctrl = ctrl
435
717
  self.qubit_num = None
436
718
 
719
+ # See MakeIter for explanation
720
+ __reduce__ = object.__reduce__
721
+ __reduce_ex__ = object.__reduce_ex__
722
+
437
723
 
438
724
  class Power(ast.expr):
439
725
  """The power modifier"""
@@ -446,6 +732,10 @@ class Power(ast.expr):
446
732
  super().__init__(**node.__dict__)
447
733
  self.iter = iter
448
734
 
735
+ # See MakeIter for explanation
736
+ __reduce__ = object.__reduce__
737
+ __reduce_ex__ = object.__reduce_ex__
738
+
449
739
 
450
740
  Modifier = Dagger | Control | Power
451
741
 
@@ -463,6 +753,10 @@ class ModifiedBlock(ast.With):
463
753
  self.control = []
464
754
  self.power = []
465
755
 
756
+ # See MakeIter for explanation
757
+ __reduce__ = object.__reduce__
758
+ __reduce_ex__ = object.__reduce_ex__
759
+
466
760
  def is_dagger(self) -> bool:
467
761
  return len(self.dagger) % 2 == 1
468
762
 
@@ -533,6 +827,10 @@ class CheckedModifiedBlock(ast.With):
533
827
  self.control = control
534
828
  self.power = power
535
829
 
830
+ # See MakeIter for explanation
831
+ __reduce__ = object.__reduce__
832
+ __reduce_ex__ = object.__reduce_ex__
833
+
536
834
  def __str__(self) -> str:
537
835
  # generate a function name from the def_id
538
836
  return f"__WithBlock__({self.def_id})"
@@ -113,11 +113,15 @@ def to_span(x: ToSpan) -> Span:
113
113
  assert file is not None
114
114
  assert line_offset is not None
115
115
  # x.lineno and line_offset both start at 1, so we have to subtract 1
116
- start = Loc(file, x.lineno + line_offset - 1, x.col_offset)
116
+ start = Loc(
117
+ file,
118
+ x.lineno + line_offset - 1, # type: ignore[attr-defined]
119
+ x.col_offset, # type: ignore[attr-defined]
120
+ )
117
121
  end = Loc(
118
122
  file,
119
- (x.end_lineno or x.lineno) + line_offset - 1,
120
- x.end_col_offset or x.col_offset,
123
+ (x.end_lineno or x.lineno) + line_offset - 1, # type: ignore[attr-defined]
124
+ x.end_col_offset or x.col_offset, # type: ignore[attr-defined]
121
125
  )
122
126
  return Span(start, end)
123
127
 
@@ -5,7 +5,7 @@ from typing import ClassVar
5
5
  from typing_extensions import assert_never
6
6
 
7
7
  from guppylang_internals.ast_util import get_type, with_loc, with_type
8
- from guppylang_internals.checker.core import Context
8
+ from guppylang_internals.checker.core import Context, Variable
9
9
  from guppylang_internals.checker.errors.generic import UnsupportedError
10
10
  from guppylang_internals.checker.errors.type_errors import (
11
11
  ArrayComprUnknownSizeError,
@@ -33,6 +33,7 @@ from guppylang_internals.nodes import (
33
33
  GlobalCall,
34
34
  MakeIter,
35
35
  PanicExpr,
36
+ PlaceNode,
36
37
  )
37
38
  from guppylang_internals.tys.arg import ConstArg, TypeArg
38
39
  from guppylang_internals.tys.builtin import (
@@ -172,6 +173,105 @@ class ArrayCopyChecker(CustomCallChecker):
172
173
  return with_loc(self.node, node), get_type(array_arg)
173
174
 
174
175
 
176
+ class ArrayIndexChecker(CustomCallChecker):
177
+ """Performs compile-time bounds checking for array indexing.
178
+
179
+ When the array size is statically known and the index is a literal constant,
180
+ this checker validates that the index is within bounds and raises an error
181
+ at compile time if it's not.
182
+ """
183
+
184
+ @dataclass(frozen=True)
185
+ class IndexOutOfBoundsError(Error):
186
+ title: ClassVar[str] = "Index out of bounds"
187
+ span_label: ClassVar[str] = (
188
+ "Array index {index} is out of bounds for array of size {size}."
189
+ )
190
+ index: int
191
+ size: int
192
+
193
+ def __init__(self, *, expr_index: int = 1):
194
+ """
195
+ Args:
196
+ expr_index: Position of the expression index argument (0 based)
197
+ """
198
+ self.expr_index: int = expr_index
199
+
200
+ def _extract_constant_index(self, index_expr: ast.expr) -> int | None:
201
+ """Extract a constant integer value from an index expression if possible.
202
+
203
+ Handles both AST constants and PlaceNode structures.
204
+ """
205
+ # Case 1: Simple AST constant (e.g., arr.take(0))
206
+ if isinstance(index_expr, ast.Constant) and isinstance(index_expr.value, int):
207
+ return index_expr.value
208
+
209
+ # Case 2: Subscript accesses (e.g., arr[0])
210
+ if isinstance(index_expr, PlaceNode):
211
+ place = index_expr.place
212
+ if isinstance(place, Variable):
213
+ defined_at = place.defined_at
214
+ if isinstance(defined_at, ast.Constant) and isinstance(
215
+ defined_at.value, int
216
+ ):
217
+ return defined_at.value
218
+
219
+ return None
220
+
221
+ def _check_constant_index_bounds(
222
+ self, index_expr: ast.expr, length_arg: TypeArg | ConstArg
223
+ ) -> None:
224
+ """Perform compile-time bounds checking if size and index are constant."""
225
+
226
+ # Check if array size is statically known
227
+ if not (
228
+ isinstance(length_arg, ConstArg)
229
+ and isinstance(length_arg.const, ConstValue)
230
+ ):
231
+ return
232
+
233
+ array_length = length_arg.const.value
234
+
235
+ index_value = self._extract_constant_index(index_expr)
236
+ if index_value is None:
237
+ return
238
+
239
+ if index_value < 0 or index_value >= array_length:
240
+ raise GuppyError(
241
+ ArrayIndexChecker.IndexOutOfBoundsError(
242
+ index_expr,
243
+ index=index_value,
244
+ size=array_length,
245
+ )
246
+ )
247
+
248
+ def check(self, args: list[ast.expr], ty: Type) -> tuple[ast.expr, Subst]:
249
+ """Check-mode: verify arguments against
250
+ expected type and perform bounds check."""
251
+
252
+ # Run regular type checking for the arguments
253
+ args, subs, type_args = check_call(self.func.ty, args, ty, self.node, self.ctx)
254
+
255
+ # Check the index bounds (first:index expression, second: length_arg)
256
+ self._check_constant_index_bounds(args[self.expr_index], type_args[1])
257
+
258
+ # Return the synthesized node and type
259
+ node = GlobalCall(def_id=self.func.id, args=args, type_args=type_args)
260
+ return with_loc(self.node, node), subs
261
+
262
+ def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, Type]:
263
+ """Synthesize-mode: infer return type and perform bounds check."""
264
+ # Run regular type synthesis for the arguments
265
+ args, ty, type_args = synthesize_call(self.func.ty, args, self.node, self.ctx)
266
+
267
+ # Check the index bounds (first:index expression, second: length_arg)
268
+ self._check_constant_index_bounds(args[self.expr_index], type_args[1])
269
+
270
+ # Return the synthesized node and type
271
+ node = GlobalCall(def_id=self.func.id, args=args, type_args=type_args)
272
+ return with_loc(self.node, node), ty
273
+
274
+
175
275
  class NewArrayChecker(CustomCallChecker):
176
276
  """Function call checker for the `array.__new__` function."""
177
277
 
@@ -251,7 +351,9 @@ class NewArrayChecker(CustomCallChecker):
251
351
  ConstValue(nat_type(), len(args)),
252
352
  ]
253
353
  call = GlobalCall(
254
- def_id=self.func.id, args=args, type_args=type_args
354
+ self.func.id,
355
+ args,
356
+ type_args, # type: ignore[arg-type]
255
357
  )
256
358
  return with_loc(self.node, call), subst
257
359
  case type_args: