guppylang-internals 0.26.0__py3-none-any.whl → 0.28.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47) hide show
  1. guppylang_internals/__init__.py +1 -1
  2. guppylang_internals/ast_util.py +37 -18
  3. guppylang_internals/cfg/analysis.py +6 -6
  4. guppylang_internals/cfg/builder.py +44 -12
  5. guppylang_internals/cfg/cfg.py +1 -1
  6. guppylang_internals/checker/core.py +1 -1
  7. guppylang_internals/checker/errors/comptime_errors.py +0 -12
  8. guppylang_internals/checker/errors/linearity.py +6 -2
  9. guppylang_internals/checker/expr_checker.py +53 -28
  10. guppylang_internals/checker/func_checker.py +4 -3
  11. guppylang_internals/checker/stmt_checker.py +1 -1
  12. guppylang_internals/compiler/cfg_compiler.py +1 -1
  13. guppylang_internals/compiler/core.py +17 -4
  14. guppylang_internals/compiler/expr_compiler.py +36 -14
  15. guppylang_internals/compiler/modifier_compiler.py +5 -2
  16. guppylang_internals/decorator.py +5 -3
  17. guppylang_internals/definition/common.py +1 -0
  18. guppylang_internals/definition/custom.py +2 -2
  19. guppylang_internals/definition/declaration.py +3 -3
  20. guppylang_internals/definition/function.py +28 -8
  21. guppylang_internals/definition/metadata.py +87 -0
  22. guppylang_internals/definition/overloaded.py +11 -2
  23. guppylang_internals/definition/pytket_circuits.py +50 -67
  24. guppylang_internals/definition/value.py +1 -1
  25. guppylang_internals/definition/wasm.py +3 -3
  26. guppylang_internals/diagnostic.py +89 -16
  27. guppylang_internals/engine.py +84 -40
  28. guppylang_internals/error.py +1 -1
  29. guppylang_internals/nodes.py +301 -3
  30. guppylang_internals/span.py +7 -3
  31. guppylang_internals/std/_internal/checker.py +104 -2
  32. guppylang_internals/std/_internal/compiler/array.py +36 -1
  33. guppylang_internals/std/_internal/compiler/either.py +14 -2
  34. guppylang_internals/std/_internal/compiler/tket_bool.py +1 -6
  35. guppylang_internals/std/_internal/compiler/tket_exts.py +1 -1
  36. guppylang_internals/std/_internal/debug.py +5 -3
  37. guppylang_internals/tracing/builtins_mock.py +2 -2
  38. guppylang_internals/tracing/object.py +6 -2
  39. guppylang_internals/tys/parsing.py +4 -1
  40. guppylang_internals/tys/qubit.py +6 -4
  41. guppylang_internals/tys/subst.py +2 -2
  42. guppylang_internals/tys/ty.py +2 -2
  43. guppylang_internals/wasm_util.py +2 -3
  44. {guppylang_internals-0.26.0.dist-info → guppylang_internals-0.28.0.dist-info}/METADATA +5 -4
  45. {guppylang_internals-0.26.0.dist-info → guppylang_internals-0.28.0.dist-info}/RECORD +47 -46
  46. {guppylang_internals-0.26.0.dist-info → guppylang_internals-0.28.0.dist-info}/WHEEL +0 -0
  47. {guppylang_internals-0.26.0.dist-info → guppylang_internals-0.28.0.dist-info}/licenses/LICENCE +0 -0
@@ -1,3 +1,3 @@
1
1
  # This is updated by our release-please workflow, triggered by this
2
2
  # annotation: x-release-please-version
3
- __version__ = "0.26.0"
3
+ __version__ = "0.28.0"
@@ -97,13 +97,13 @@ def find_nodes(
97
97
  def name_nodes_in_ast(node: Any) -> list[ast.Name]:
98
98
  """Returns all `Name` nodes occurring in an AST."""
99
99
  found = find_nodes(lambda n: isinstance(n, ast.Name), node)
100
- return cast(list[ast.Name], found)
100
+ return cast("list[ast.Name]", found)
101
101
 
102
102
 
103
103
  def return_nodes_in_ast(node: Any) -> list[ast.Return]:
104
104
  """Returns all `Return` nodes occurring in an AST."""
105
105
  found = find_nodes(lambda n: isinstance(n, ast.Return), node, {ast.FunctionDef})
106
- return cast(list[ast.Return], found)
106
+ return cast("list[ast.Return]", found)
107
107
 
108
108
 
109
109
  def loop_in_ast(node: Any) -> list[ast.For | ast.While]:
@@ -111,7 +111,7 @@ def loop_in_ast(node: Any) -> list[ast.For | ast.While]:
111
111
  found = find_nodes(
112
112
  lambda n: isinstance(n, ast.For | ast.While), node, {ast.FunctionDef}
113
113
  )
114
- return cast(list[ast.For | ast.While], found)
114
+ return cast("list[ast.For | ast.While]", found)
115
115
 
116
116
 
117
117
  def breaks_in_loop(node: Any) -> list[ast.Break]:
@@ -122,7 +122,7 @@ def breaks_in_loop(node: Any) -> list[ast.Break]:
122
122
  found = find_nodes(
123
123
  lambda n: isinstance(n, ast.Break), node, {ast.For, ast.While, ast.FunctionDef}
124
124
  )
125
- return cast(list[ast.Break], found)
125
+ return cast("list[ast.Break]", found)
126
126
 
127
127
 
128
128
  def loop_controls_in_loop(node: Any) -> list[ast.Break | ast.Continue]:
@@ -135,7 +135,7 @@ def loop_controls_in_loop(node: Any) -> list[ast.Break | ast.Continue]:
135
135
  node,
136
136
  {ast.For, ast.While, ast.FunctionDef},
137
137
  )
138
- return cast(list[ast.Break | ast.Continue], found)
138
+ return cast("list[ast.Break | ast.Continue]", found)
139
139
 
140
140
 
141
141
  class ContextAdjuster(ast.NodeTransformer):
@@ -147,7 +147,7 @@ class ContextAdjuster(ast.NodeTransformer):
147
147
  self.ctx = ctx
148
148
 
149
149
  def visit(self, node: ast.AST) -> ast.AST:
150
- return cast(ast.AST, super().visit(node))
150
+ return cast("ast.AST", super().visit(node))
151
151
 
152
152
  def visit_Name(self, node: ast.Name) -> ast.Name:
153
153
  return with_loc(node, ast.Name(id=node.id, ctx=self.ctx))
@@ -156,29 +156,48 @@ class ContextAdjuster(ast.NodeTransformer):
156
156
  self,
157
157
  node: ast.Starred,
158
158
  ) -> ast.Starred:
159
- return with_loc(node, ast.Starred(value=self.visit(node.value), ctx=self.ctx))
159
+ return with_loc(
160
+ node,
161
+ ast.Starred(value=self.visit(node.value), ctx=self.ctx), # type: ignore[arg-type]
162
+ )
160
163
 
161
164
  def visit_Tuple(self, node: ast.Tuple) -> ast.Tuple:
162
165
  return with_loc(
163
- node, ast.Tuple(elts=[self.visit(elt) for elt in node.elts], ctx=self.ctx)
166
+ node,
167
+ ast.Tuple(
168
+ elts=[self.visit(elt) for elt in node.elts], # type: ignore[misc]
169
+ ctx=self.ctx,
170
+ ),
164
171
  )
165
172
 
166
173
  def visit_List(self, node: ast.List) -> ast.List:
167
174
  return with_loc(
168
- node, ast.List(elts=[self.visit(elt) for elt in node.elts], ctx=self.ctx)
175
+ node,
176
+ ast.List(
177
+ elts=[self.visit(elt) for elt in node.elts], # type: ignore[misc]
178
+ ctx=self.ctx,
179
+ ),
169
180
  )
170
181
 
171
182
  def visit_Subscript(self, node: ast.Subscript) -> ast.Subscript:
172
183
  # Don't adjust the slice!
173
184
  return with_loc(
174
185
  node,
175
- ast.Subscript(value=self.visit(node.value), slice=node.slice, ctx=self.ctx),
186
+ ast.Subscript(
187
+ value=self.visit(node.value), # type: ignore[arg-type]
188
+ slice=node.slice,
189
+ ctx=self.ctx,
190
+ ),
176
191
  )
177
192
 
178
193
  def visit_Attribute(self, node: ast.Attribute) -> ast.Attribute:
179
194
  return with_loc(
180
195
  node,
181
- ast.Attribute(value=self.visit(node.value), attr=node.attr, ctx=self.ctx),
196
+ ast.Attribute(
197
+ value=self.visit(node.value), # type: ignore[arg-type]
198
+ attr=node.attr,
199
+ ctx=self.ctx,
200
+ ),
182
201
  )
183
202
 
184
203
 
@@ -240,15 +259,15 @@ def template_replace(
240
259
 
241
260
  def line_col(node: ast.AST) -> tuple[int, int]:
242
261
  """Returns the line and column of an ast node."""
243
- return node.lineno, node.col_offset
262
+ return node.lineno, node.col_offset # type: ignore[attr-defined]
244
263
 
245
264
 
246
265
  def set_location_from(node: ast.AST, loc: ast.AST) -> None:
247
266
  """Copy source location from one AST node to the other."""
248
- node.lineno = loc.lineno
249
- node.col_offset = loc.col_offset
250
- node.end_lineno = loc.end_lineno
251
- node.end_col_offset = loc.end_col_offset
267
+ node.lineno = loc.lineno # type: ignore[attr-defined]
268
+ node.col_offset = loc.col_offset # type: ignore[attr-defined]
269
+ node.end_lineno = loc.end_lineno # type: ignore[attr-defined]
270
+ node.end_col_offset = loc.end_col_offset # type: ignore[attr-defined]
252
271
 
253
272
  source, file, line_offset = get_source(loc), get_file(loc), get_line_offset(loc)
254
273
  assert source is not None
@@ -341,11 +360,11 @@ def with_type(ty: "Type", node: A) -> A:
341
360
 
342
361
  def get_type_opt(node: AstNode) -> Optional["Type"]:
343
362
  """Tries to retrieve a type annotation from an AST node."""
344
- from guppylang_internals.tys.ty import Type, TypeBase
363
+ from guppylang_internals.tys.ty import TypeBase
345
364
 
346
365
  try:
347
366
  ty = node.type # type: ignore[union-attr]
348
- return cast(Type, ty) if isinstance(ty, TypeBase) else None
367
+ return cast("Type", ty) if isinstance(ty, TypeBase) else None
349
368
  except AttributeError:
350
369
  return None
351
370
 
@@ -11,7 +11,7 @@ T = TypeVar("T")
11
11
  Result = dict[BB, T]
12
12
 
13
13
 
14
- class Analysis(Generic[T], ABC):
14
+ class Analysis(ABC, Generic[T]):
15
15
  """Abstract base class for a program analysis pass over the lattice `T`"""
16
16
 
17
17
  def eq(self, t1: T, t2: T, /) -> bool:
@@ -39,7 +39,7 @@ class Analysis(Generic[T], ABC):
39
39
  """
40
40
 
41
41
 
42
- class ForwardAnalysis(Generic[T], Analysis[T], ABC):
42
+ class ForwardAnalysis(Analysis[T], ABC, Generic[T]):
43
43
  """Abstract base class for a program analysis pass running in forward direction."""
44
44
 
45
45
  @abstractmethod
@@ -71,7 +71,7 @@ class ForwardAnalysis(Generic[T], Analysis[T], ABC):
71
71
  return vals_before
72
72
 
73
73
 
74
- class BackwardAnalysis(Generic[T], Analysis[T], ABC):
74
+ class BackwardAnalysis(Analysis[T], ABC, Generic[T]):
75
75
  """Abstract base class for a program analysis pass running in backward direction."""
76
76
 
77
77
  @abstractmethod
@@ -105,7 +105,7 @@ class BackwardAnalysis(Generic[T], Analysis[T], ABC):
105
105
  LivenessDomain = dict[VId, BB]
106
106
 
107
107
 
108
- class LivenessAnalysis(Generic[VId], BackwardAnalysis[LivenessDomain[VId]]):
108
+ class LivenessAnalysis(BackwardAnalysis[LivenessDomain[VId]], Generic[VId]):
109
109
  """Live variable analysis pass.
110
110
 
111
111
  Computes the variables that are live before the execution of each BB. The analysis
@@ -143,7 +143,7 @@ class LivenessAnalysis(Generic[VId], BackwardAnalysis[LivenessDomain[VId]]):
143
143
 
144
144
  def apply_bb(self, live_after: LivenessDomain[VId], bb: BB) -> LivenessDomain[VId]:
145
145
  stats = self.stats[bb]
146
- return {x: bb for x in stats.used} | {
146
+ return dict.fromkeys(stats.used, bb) | {
147
147
  x: b for x, b in live_after.items() if x not in stats.assigned
148
148
  }
149
149
 
@@ -159,7 +159,7 @@ MaybeAssignmentDomain = set[VId]
159
159
  AssignmentDomain = tuple[DefAssignmentDomain[VId], MaybeAssignmentDomain[VId]]
160
160
 
161
161
 
162
- class AssignmentAnalysis(Generic[VId], ForwardAnalysis[AssignmentDomain[VId]]):
162
+ class AssignmentAnalysis(ForwardAnalysis[AssignmentDomain[VId]], Generic[VId]):
163
163
  """Assigned variable analysis pass.
164
164
 
165
165
  Computes the set of variables (i.e. `V`s) that are definitely assigned at the start
@@ -154,27 +154,52 @@ class CFGBuilder(AstVisitor[BB | None]):
154
154
  return bb_opt
155
155
 
156
156
  def _build_node_value(self, node: BBStatement, bb: BB) -> BB:
157
- """Utility method for building a node containing a `value` expression.
157
+ """Utility method for building a nodes `value` expression, if available.
158
158
 
159
159
  Builds the expression and mutates `node.value` to point to the built expression.
160
- Returns the BB in which the expression is available and adds the node to it.
160
+ Returns the BB in which the expression is available.
161
161
  """
162
162
  if (
163
163
  not isinstance(node, NestedFunctionDef | ModifiedBlock)
164
164
  and node.value is not None
165
165
  ):
166
166
  node.value, bb = ExprBuilder.build(node.value, self.cfg, bb)
167
- bb.statements.append(node)
167
+ return bb
168
+
169
+ def _build_node_targets(self, node: BBStatement, bb: BB) -> BB:
170
+ """Utility method for building a nodes `target` or `targets` expressions,
171
+ depending on the node type.
172
+
173
+ Builds the expressions and mutates the elements of `node.targets` to point to
174
+ the built expressions. Returns the BB in which the expressions are available.
175
+ """
176
+ if isinstance(node, ast.Assign):
177
+ for i, target in enumerate(node.targets):
178
+ node.targets[i], bb = ExprBuilder.build(target, self.cfg, bb)
179
+ elif isinstance(node, ast.AugAssign | ast.AnnAssign):
180
+ new_target, bb = ExprBuilder.build(node.target, self.cfg, bb)
181
+ if not isinstance(new_target, ast.Name | ast.Attribute | ast.Subscript):
182
+ raise InternalGuppyError("Unexpected type for built expression.")
183
+ node.target = new_target
168
184
  return bb
169
185
 
170
186
  def visit_Assign(self, node: ast.Assign, bb: BB, jumps: Jumps) -> BB | None:
171
- return self._build_node_value(node, bb)
187
+ bb = self._build_node_value(node, bb)
188
+ bb = self._build_node_targets(node, bb)
189
+ bb.statements.append(node)
190
+ return bb
172
191
 
173
192
  def visit_AugAssign(self, node: ast.AugAssign, bb: BB, jumps: Jumps) -> BB | None:
174
- return self._build_node_value(node, bb)
193
+ bb = self._build_node_value(node, bb)
194
+ bb = self._build_node_targets(node, bb)
195
+ bb.statements.append(node)
196
+ return bb
175
197
 
176
198
  def visit_AnnAssign(self, node: ast.AnnAssign, bb: BB, jumps: Jumps) -> BB | None:
177
- return self._build_node_value(node, bb)
199
+ bb = self._build_node_value(node, bb)
200
+ bb = self._build_node_targets(node, bb)
201
+ bb.statements.append(node)
202
+ return bb
178
203
 
179
204
  def visit_Expr(self, node: ast.Expr, bb: BB, jumps: Jumps) -> BB | None:
180
205
  # This is an expression statement where the value is discarded
@@ -262,6 +287,7 @@ class CFGBuilder(AstVisitor[BB | None]):
262
287
 
263
288
  def visit_Return(self, node: ast.Return, bb: BB, jumps: Jumps) -> BB | None:
264
289
  bb = self._build_node_value(node, bb)
290
+ bb.statements.append(node)
265
291
  self.cfg.link(bb, jumps.return_bb)
266
292
  return None
267
293
 
@@ -572,7 +598,7 @@ class BranchBuilder(AstVisitor[None]):
572
598
  comparators[:-1], node.ops, comparators[1:], strict=True
573
599
  )
574
600
  ]
575
- conj = ast.BoolOp(op=ast.And(), values=values)
601
+ conj = ast.BoolOp(op=ast.And(), values=values) # type: ignore[arg-type]
576
602
  set_location_from(conj, node)
577
603
  self.visit_BoolOp(conj, bb, true_bb, false_bb)
578
604
  else:
@@ -668,6 +694,9 @@ def is_comptime_expression(node: ast.AST) -> ComptimeExpr | None:
668
694
 
669
695
  Otherwise, returns `None`.
670
696
  """
697
+ if isinstance(node, ComptimeExpr):
698
+ return node
699
+
671
700
  if (
672
701
  isinstance(node, ast.Call)
673
702
  and isinstance(node.func, ast.Name)
@@ -679,8 +708,8 @@ def is_comptime_expression(node: ast.AST) -> ComptimeExpr | None:
679
708
  case [arg]:
680
709
  pass
681
710
  case args:
682
- arg = with_loc(node, ast.Tuple(elts=args, ctx=ast.Load))
683
- return with_loc(node, ComptimeExpr(value=arg))
711
+ arg = with_loc(node, ast.Tuple(elts=args, ctx=ast.Load)) # type: ignore[arg-type]
712
+ return with_loc(node, ComptimeExpr(arg))
684
713
  return None
685
714
 
686
715
 
@@ -701,7 +730,7 @@ def is_illegal_in_list_comp(node: ast.AST) -> bool:
701
730
 
702
731
  def make_var(name: str, loc: ast.AST | None = None) -> ast.Name:
703
732
  """Creates an `ast.Name` node."""
704
- node = ast.Name(id=name, ctx=ast.Load)
733
+ node = ast.Name(id=name, ctx=ast.Load) # type: ignore[arg-type]
705
734
  if loc is not None:
706
735
  set_location_from(node, loc)
707
736
  return node
@@ -715,5 +744,8 @@ def make_assign(lhs: list[ast.AST], value: ast.expr) -> ast.Assign:
715
744
  if len(lhs) == 1:
716
745
  target = lhs[0]
717
746
  else:
718
- target = with_loc(value, ast.Tuple(elts=lhs, ctx=ast.Store()))
719
- return with_loc(value, ast.Assign(targets=[target], value=value))
747
+ target = with_loc(
748
+ value,
749
+ ast.Tuple(elts=lhs, ctx=ast.Store()), # type: ignore[arg-type]
750
+ )
751
+ return with_loc(value, ast.Assign(targets=[target], value=value)) # type: ignore[list-item]
@@ -118,7 +118,7 @@ class CFG(BaseCFG[BB]):
118
118
  # initial value in the liveness analysis. This solves the edge case that
119
119
  # borrowed variables should be considered live, even if the exit is actually
120
120
  # unreachable (to avoid linearity violations later).
121
- inout_live = {x: self.exit_bb for x in inout_vars}
121
+ inout_live = dict.fromkeys(inout_vars, self.exit_bb)
122
122
  self.live_before = LivenessAnalysis(
123
123
  stats, initial=inout_live, include_unreachable=True
124
124
  ).run(self.bbs)
@@ -384,7 +384,7 @@ class Globals:
384
384
  case _:
385
385
  return assert_never(ty)
386
386
 
387
- type_defn = cast(TypeDef, ENGINE.get_checked(type_defn.id))
387
+ type_defn = cast("TypeDef", ENGINE.get_checked(type_defn.id))
388
388
  if type_defn.id in DEF_STORE.impls and name in DEF_STORE.impls[type_defn.id]:
389
389
  def_id = DEF_STORE.impls[type_defn.id][name]
390
390
  defn = ENGINE.get_parsed(def_id)
@@ -44,18 +44,6 @@ class ComptimeExprIncoherentListError(Error):
44
44
  span_label: ClassVar[str] = "List contains elements with different types"
45
45
 
46
46
 
47
- @dataclass(frozen=True)
48
- class TketNotInstalled(Error):
49
- title: ClassVar[str] = "Tket not installed"
50
- span_label: ClassVar[str] = (
51
- "Experimental pytket compatibility requires `tket` to be installed"
52
- )
53
-
54
- @dataclass(frozen=True)
55
- class InstallInstruction(Help):
56
- message: ClassVar[str] = "Install tket: `pip install tket`"
57
-
58
-
59
47
  @dataclass(frozen=True)
60
48
  class PytketSignatureMismatch(Error):
61
49
  title: ClassVar[str] = "Signature mismatch"
@@ -33,7 +33,9 @@ class AlreadyUsedError(Error):
33
33
 
34
34
  @dataclass(frozen=True)
35
35
  class PrevUse(Note):
36
- span_label: ClassVar[str] = "since it was already {prev_kind.subjunctive} here"
36
+ span_label: ClassVar[str] = (
37
+ "{place.describe} already {prev_kind.subjunctive} here"
38
+ )
37
39
  prev_kind: UseKind
38
40
 
39
41
  @dataclass(frozen=True)
@@ -55,7 +57,9 @@ class ComprAlreadyUsedError(Error):
55
57
 
56
58
  @dataclass(frozen=True)
57
59
  class PrevUse(Note):
58
- span_label: ClassVar[str] = "since it was already {prev_kind.subjunctive} here"
60
+ span_label: ClassVar[str] = (
61
+ "{place.describe} already {prev_kind.subjunctive} here"
62
+ )
59
63
  prev_kind: UseKind
60
64
 
61
65
 
@@ -21,6 +21,7 @@ can be used to infer a type for an expression.
21
21
  """
22
22
 
23
23
  import ast
24
+ import copy
24
25
  import sys
25
26
  import traceback
26
27
  from collections.abc import Sequence
@@ -85,6 +86,7 @@ from guppylang_internals.checker.errors.type_errors import (
85
86
  WrongNumberOfArgsError,
86
87
  )
87
88
  from guppylang_internals.definition.common import Definition
89
+ from guppylang_internals.definition.parameter import ParamDef
88
90
  from guppylang_internals.definition.ty import TypeDef
89
91
  from guppylang_internals.definition.value import CallableDef, ValueDef
90
92
  from guppylang_internals.error import (
@@ -234,7 +236,7 @@ class ExprChecker(AstVisitor[tuple[ast.expr, Subst]]):
234
236
  if actual := get_type_opt(expr):
235
237
  expr, subst, inst = check_type_against(actual, ty, expr, self.ctx, kind)
236
238
  if inst:
237
- expr = with_loc(expr, TypeApply(value=expr, tys=inst))
239
+ expr = with_loc(expr, TypeApply(expr, inst))
238
240
  return with_type(ty.substitute(subst), expr), subst
239
241
 
240
242
  # When checking against a variable, we have to synthesize
@@ -370,7 +372,7 @@ class ExprChecker(AstVisitor[tuple[ast.expr, Subst]]):
370
372
 
371
373
  # Apply instantiation of quantified type variables
372
374
  if inst:
373
- node = with_loc(node, TypeApply(value=node, inst=inst))
375
+ node = with_loc(node, TypeApply(node, inst))
374
376
 
375
377
  return node, subst
376
378
 
@@ -407,23 +409,27 @@ class ExprSynthesizer(AstVisitor[tuple[ast.expr, Type]]):
407
409
  raise GuppyError(IllegalConstant(node, type(node.value)))
408
410
  return node, ty
409
411
 
412
+ def _check_generic_param(self, name: str, node: ast.expr) -> tuple[ast.expr, Type]:
413
+ """Helper method to check a generic parameter (ConstParam or TypeParam)."""
414
+ param = self.ctx.generic_params[name]
415
+ match param:
416
+ case ConstParam() as param:
417
+ ast_node = with_loc(node, GenericParamValue(id=name, param=param))
418
+ return ast_node, param.ty
419
+ case TypeParam() as param:
420
+ raise GuppyError(
421
+ ExpectedError(node, "a value", got=f"type `{param.name}`")
422
+ )
423
+ case _:
424
+ return assert_never(param)
425
+
410
426
  def visit_Name(self, node: ast.Name) -> tuple[ast.expr, Type]:
411
427
  x = node.id
412
428
  if x in self.ctx.locals:
413
429
  var = self.ctx.locals[x]
414
430
  return with_loc(node, PlaceNode(place=var)), var.ty
415
431
  elif x in self.ctx.generic_params:
416
- param = self.ctx.generic_params[x]
417
- match param:
418
- case ConstParam() as param:
419
- ast_node = with_loc(node, GenericParamValue(id=x, param=param))
420
- return ast_node, param.ty
421
- case TypeParam() as param:
422
- raise GuppyError(
423
- ExpectedError(node, "a value", got=f"type `{param.name}`")
424
- )
425
- case _:
426
- return assert_never(param)
432
+ return self._check_generic_param(x, node)
427
433
  elif x in self.ctx.globals:
428
434
  match self.ctx.globals[x]:
429
435
  case Definition() as defn:
@@ -454,6 +460,16 @@ class ExprSynthesizer(AstVisitor[tuple[ast.expr, Type]]):
454
460
  defn, "__new__"
455
461
  ):
456
462
  return with_loc(node, GlobalName(id=name, def_id=constr.id)), constr.ty
463
+ # Handle parameter definitions (e.g., nat_var) that may be imported
464
+ case ParamDef():
465
+ # Check if this parameter is in our generic_params
466
+ # (e.g., used in type signature)
467
+ if name in self.ctx.generic_params:
468
+ return self._check_generic_param(name, node)
469
+ # If not in generic_params, it's being used outside its scope
470
+ raise GuppyError(
471
+ ExpectedError(node, "a value", got=f"{defn.description} `{name}`")
472
+ )
457
473
  case defn:
458
474
  raise GuppyError(
459
475
  ExpectedError(node, "a value", got=f"{defn.description} `{name}`")
@@ -461,6 +477,7 @@ class ExprSynthesizer(AstVisitor[tuple[ast.expr, Type]]):
461
477
 
462
478
  def visit_Attribute(self, node: ast.Attribute) -> tuple[ast.expr, Type]:
463
479
  from guppylang.defs import GuppyDefinition
480
+
464
481
  from guppylang_internals.engine import ENGINE
465
482
 
466
483
  # A `value.attr` attribute access. Unfortunately, the `attr` is just a string,
@@ -993,11 +1010,13 @@ def type_check_args(
993
1010
  new_args.append(a)
994
1011
  assert next(comptime_args, None) is None
995
1012
 
996
- # If the argument check succeeded, this means that we must have found instantiations
997
- # for all unification variables occurring in the input types
998
- assert all(
999
- set.issubset(inp.ty.unsolved_vars, subst.keys()) for inp in func_ty.inputs
1000
- )
1013
+ # Check whether we have found instantiations for all unification variables occurring
1014
+ # in the input types
1015
+ for inp in func_ty.inputs:
1016
+ if not set.issubset(inp.ty.unsolved_vars, subst.keys()):
1017
+ raise GuppyTypeInferenceError(
1018
+ TypeInferenceError(node, inp.ty.substitute(subst))
1019
+ )
1001
1020
 
1002
1021
  # We also have to check that we found instantiations for all vars in the return type
1003
1022
  if not set.issubset(func_ty.output.unsolved_vars, subst.keys()):
@@ -1162,19 +1181,25 @@ def check_call(
1162
1181
  # However the bad case, e.g. `x: int = foo(foo(...foo(?)...))`, shouldn't be common
1163
1182
  # in practice. Can we do better than that?
1164
1183
 
1165
- # First, try to synthesize
1166
- res: tuple[Type, Inst] | None = None
1184
+ # synthesize_call may modify args and node in place,
1185
+ # hence we deepcopy them before passing in the function
1186
+ node_copy = copy.deepcopy(node)
1187
+ inputs_copy = copy.deepcopy(inputs)
1188
+
1167
1189
  try:
1168
1190
  inputs, synth, inst = synthesize_call(func_ty, inputs, node, ctx)
1169
- res = synth, inst
1170
- except GuppyTypeInferenceError:
1171
- pass
1172
- if res is not None:
1173
- synth, inst = res
1174
1191
  subst = unify(ty, synth, {})
1175
1192
  if subst is None:
1176
1193
  raise GuppyTypeError(TypeMismatchError(node, ty, synth, kind))
1177
- return inputs, subst, inst
1194
+ else:
1195
+ return inputs, subst, inst
1196
+ except GuppyTypeInferenceError:
1197
+ pass
1198
+
1199
+ # Restore the state of these values from before they were potentially
1200
+ # modified by `synthesize_call`.
1201
+ inputs = inputs_copy
1202
+ node = node_copy
1178
1203
 
1179
1204
  # If synthesis fails, we try again, this time also using information from the
1180
1205
  # expected return type
@@ -1263,7 +1288,7 @@ def instantiate_poly(node: ast.expr, ty: FunctionType, inst: Inst) -> ast.expr:
1263
1288
  assert full_ty.params == ty.params
1264
1289
  node.func = instantiate_poly(node.func, full_ty, inst)
1265
1290
  else:
1266
- node = with_loc(node, TypeApply(value=with_type(ty, node), inst=inst))
1291
+ node = with_loc(node, TypeApply(with_type(ty, node), inst))
1267
1292
  return with_type(ty.instantiate(inst), node)
1268
1293
  return with_type(ty, node)
1269
1294
 
@@ -1389,7 +1414,7 @@ def python_value_to_guppy_type(
1389
1414
  ]
1390
1415
  if any(ty is None for ty in tys):
1391
1416
  return None
1392
- return TupleType(cast(list[Type], tys))
1417
+ return TupleType(cast("list[Type]", tys))
1393
1418
  case list():
1394
1419
  return _python_list_to_guppy_type(v, node, globals, type_hint)
1395
1420
  case None:
@@ -141,7 +141,7 @@ def check_global_func_def(
141
141
  check_invalid_under_dagger(func_def, ty.unitary_flags)
142
142
  cfg = CFGBuilder().build(func_def.body, returns_none, globals, ty.unitary_flags)
143
143
  inputs = [
144
- Variable(cast(str, inp.name), inp.ty, loc, inp.flags, is_func_input=True)
144
+ Variable(cast("str", inp.name), inp.ty, loc, inp.flags, is_func_input=True)
145
145
  for inp, loc in zip(ty.inputs, args, strict=True)
146
146
  # Comptime inputs are turned into generic args, so are not included here
147
147
  if InputFlags.Comptime not in inp.flags
@@ -199,7 +199,7 @@ def check_nested_func_def(
199
199
 
200
200
  # Construct inputs for checking the body CFG
201
201
  inputs = [v for v, _ in captured.values()] + [
202
- Variable(cast(str, inp.name), inp.ty, arg, inp.flags, is_func_input=True)
202
+ Variable(cast("str", inp.name), inp.ty, arg, inp.flags, is_func_input=True)
203
203
  for arg, inp in zip(func_def.args.args, func_ty.inputs, strict=True)
204
204
  # Comptime inputs are turned into generic args, so are not included here
205
205
  if InputFlags.Comptime not in inp.flags
@@ -214,6 +214,7 @@ def check_nested_func_def(
214
214
  if not captured:
215
215
  # If there are no captured vars, we treat the function like a global name
216
216
  from guppylang.defs import GuppyDefinition
217
+
217
218
  from guppylang_internals.definition.function import ParsedFunctionDef
218
219
 
219
220
  func = ParsedFunctionDef(def_id, func_def.name, func_def, func_ty, None)
@@ -288,7 +289,7 @@ def check_signature(
288
289
  # Figure out if this is a method
289
290
  self_defn: TypeDef | None = None
290
291
  if def_id is not None and def_id in DEF_STORE.impl_parents:
291
- self_defn = cast(TypeDef, ENGINE.get_checked(DEF_STORE.impl_parents[def_id]))
292
+ self_defn = cast("TypeDef", ENGINE.get_checked(DEF_STORE.impl_parents[def_id]))
292
293
  assert isinstance(self_defn, TypeDef)
293
294
 
294
295
  inputs = []
@@ -478,7 +478,7 @@ def parse_unpack_pattern(lhs: ast.Tuple | ast.List) -> UnpackPattern:
478
478
  # that there is at most one starred expression)
479
479
  left = list(takewhile(lambda e: not isinstance(e, ast.Starred), lhs.elts))
480
480
  starred = (
481
- cast(ast.Starred, lhs.elts[len(left)]).value
481
+ cast("ast.Starred", lhs.elts[len(left)]).value
482
482
  if len(left) < len(lhs.elts)
483
483
  else None
484
484
  )
@@ -111,7 +111,7 @@ def compile_bb(
111
111
  pred_ty = builder.hugr.port_type(branch_port.out_port())
112
112
  assert pred_ty == OpaqueBool
113
113
  branch_port = dfg.builder.add_op(read_bool(), branch_port)
114
- branch_port = cast(Wire, branch_port)
114
+ branch_port = cast("Wire", branch_port)
115
115
  else:
116
116
  # Even if we don't branch, we still have to add a `Sum(())` predicates
117
117
  branch_port = dfg.builder.add_op(ops.Tag(0, ht.UnitSum(1)))
@@ -31,12 +31,14 @@ from guppylang_internals.definition.common import (
31
31
  CompilableDef,
32
32
  CompiledDef,
33
33
  DefId,
34
+ Definition,
34
35
  MonomorphizableDef,
36
+ RawDef,
35
37
  )
36
38
  from guppylang_internals.definition.ty import TypeDef
37
39
  from guppylang_internals.definition.value import CompiledCallableDef
38
40
  from guppylang_internals.diagnostic import Error
39
- from guppylang_internals.engine import ENGINE
41
+ from guppylang_internals.engine import DEF_STORE, ENGINE
40
42
  from guppylang_internals.error import GuppyError, InternalGuppyError
41
43
  from guppylang_internals.std._internal.compiler.tket_exts import GUPPY_EXTENSION
42
44
  from guppylang_internals.tys.arg import ConstArg, TypeArg
@@ -199,7 +201,7 @@ class CompilerContext(ToHugrContext):
199
201
  params, type_args, self
200
202
  )
201
203
  compile_outer = lambda: monomorphizable.monomorphize( # noqa: E731 (assign-lambda)
202
- self.module, mono_args, self
204
+ self.module, mono_args, self, get_parent_type(monomorphizable)
203
205
  )
204
206
  case CompilableDef() as compilable:
205
207
  compile_outer = lambda: compilable.compile_outer(self.module, self) # noqa: E731
@@ -227,7 +229,9 @@ class CompilerContext(ToHugrContext):
227
229
  raise GuppyError(err)
228
230
  # Thus, the partial monomorphization for the entry point is always empty
229
231
  entry_mono_args = tuple(None for _ in params)
230
- entry_compiled = defn.monomorphize(self.module, entry_mono_args, self)
232
+ entry_compiled = defn.monomorphize(
233
+ self.module, entry_mono_args, self, get_parent_type(defn)
234
+ )
231
235
  case CompilableDef() as defn:
232
236
  entry_compiled = defn.compile_outer(self.module, self)
233
237
  case CompiledDef() as defn:
@@ -371,7 +375,7 @@ class DFContainer:
371
375
  ctx: CompilerContext,
372
376
  locals: CompiledLocals | None = None,
373
377
  ) -> None:
374
- generic_builder = cast(DfBase[ops.DfParentOp], builder)
378
+ generic_builder = cast("DfBase[ops.DfParentOp]", builder)
375
379
  if locals is None:
376
380
  locals = {}
377
381
  self.builder = generic_builder
@@ -467,6 +471,15 @@ def is_return_var(x: str) -> bool:
467
471
  return x.startswith("%ret")
468
472
 
469
473
 
474
+ def get_parent_type(defn: Definition) -> "RawDef | None":
475
+ """Returns the RawDef registered as the parent of `child` in the DEF_STORE,
476
+ or None if it has no parent."""
477
+ if parent_ty_id := DEF_STORE.impl_parents.get(defn.id):
478
+ return DEF_STORE.raw_defs[parent_ty_id]
479
+ else:
480
+ return None
481
+
482
+
470
483
  def require_monomorphization(params: Sequence[Parameter]) -> set[Parameter]:
471
484
  """Returns the subset of type parameters that must be monomorphized before compiling
472
485
  to Hugr.