guppylang-internals 0.22.0__py3-none-any.whl → 0.24.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (32) hide show
  1. guppylang_internals/__init__.py +1 -1
  2. guppylang_internals/cfg/cfg.py +8 -0
  3. guppylang_internals/checker/cfg_checker.py +26 -65
  4. guppylang_internals/checker/core.py +8 -0
  5. guppylang_internals/checker/expr_checker.py +11 -25
  6. guppylang_internals/checker/func_checker.py +170 -21
  7. guppylang_internals/checker/stmt_checker.py +1 -1
  8. guppylang_internals/decorator.py +124 -58
  9. guppylang_internals/definition/const.py +2 -2
  10. guppylang_internals/definition/custom.py +1 -1
  11. guppylang_internals/definition/declaration.py +1 -1
  12. guppylang_internals/definition/extern.py +2 -2
  13. guppylang_internals/definition/function.py +1 -1
  14. guppylang_internals/definition/parameter.py +2 -2
  15. guppylang_internals/definition/pytket_circuits.py +1 -1
  16. guppylang_internals/definition/struct.py +10 -10
  17. guppylang_internals/definition/traced.py +1 -1
  18. guppylang_internals/definition/ty.py +6 -0
  19. guppylang_internals/definition/wasm.py +2 -2
  20. guppylang_internals/engine.py +13 -2
  21. guppylang_internals/nodes.py +0 -23
  22. guppylang_internals/std/_internal/compiler/tket_exts.py +3 -6
  23. guppylang_internals/std/_internal/compiler/wasm.py +37 -26
  24. guppylang_internals/tracing/function.py +13 -2
  25. guppylang_internals/tracing/unpacking.py +18 -12
  26. guppylang_internals/tys/builtin.py +30 -11
  27. guppylang_internals/tys/errors.py +6 -0
  28. guppylang_internals/tys/parsing.py +111 -125
  29. {guppylang_internals-0.22.0.dist-info → guppylang_internals-0.24.0.dist-info}/METADATA +5 -5
  30. {guppylang_internals-0.22.0.dist-info → guppylang_internals-0.24.0.dist-info}/RECORD +32 -32
  31. {guppylang_internals-0.22.0.dist-info → guppylang_internals-0.24.0.dist-info}/WHEEL +0 -0
  32. {guppylang_internals-0.22.0.dist-info → guppylang_internals-0.24.0.dist-info}/licenses/LICENCE +0 -0
@@ -176,10 +176,21 @@ def trace_call(func: CallableDef, *args: Any) -> Any:
176
176
  if len(func.ty.inputs) != 0:
177
177
  for inp, arg, var in zip(func.ty.inputs, args, arg_vars, strict=True):
178
178
  if InputFlags.Inout in inp.flags:
179
+ # Note that `inp.ty` could refer to bound variables in the function
180
+ # signature. Instead, make sure to use `var.ty` which will always be a
181
+ # concrete type and type checking has ensured that they unify.
182
+ ty = var.ty
179
183
  inout_wire = state.dfg[var]
180
- update_packed_value(
181
- arg, GuppyObject(inp.ty, inout_wire), state.dfg.builder
184
+ success = update_packed_value(
185
+ arg, GuppyObject(ty, inout_wire), state.dfg.builder
182
186
  )
187
+ if not success:
188
+ # This means the user has passed an object that we cannot update,
189
+ # e.g. calling `mem_swap(x, y)` where the inputs are plain Python
190
+ # objects
191
+ raise GuppyComptimeError(
192
+ f"Cannot borrow Python object of type `{ty}` at comptime"
193
+ )
183
194
 
184
195
  ret_obj = GuppyObject(ret_ty, ret_wire)
185
196
  return unpack_guppy_object(ret_obj, state.dfg.builder)
@@ -150,13 +150,15 @@ def guppy_object_from_py(
150
150
  return GuppyObject(ty, builder.load(hugr_val))
151
151
 
152
152
 
153
- def update_packed_value(v: Any, obj: "GuppyObject", builder: DfBase[P]) -> None:
153
+ def update_packed_value(v: Any, obj: "GuppyObject", builder: DfBase[P]) -> bool:
154
154
  """Given a Python value `v` and a `GuppyObject` `obj` that was constructed from `v`
155
- using `guppy_object_from_py`, updates the wires of any `GuppyObjects` contained in
156
- `v` to the new wires specified by `obj`.
155
+ using `guppy_object_from_py`, tries to update the wires of any `GuppyObjects`
156
+ contained in `v` to the new wires specified by `obj`.
157
157
 
158
158
  Also resets the used flag on any of those updated wires. This corresponds to making
159
159
  the object available again since it now corresponds to a fresh wire.
160
+
161
+ Returns `True` if all wires could be updated, otherwise `False`.
160
162
  """
161
163
  match v:
162
164
  case GuppyObject() as v_obj:
@@ -172,23 +174,27 @@ def update_packed_value(v: Any, obj: "GuppyObject", builder: DfBase[P]) -> None:
172
174
  assert isinstance(obj._ty, TupleType)
173
175
  wires = builder.add_op(ops.UnpackTuple(), obj._use_wire(None)).outputs()
174
176
  for v, ty, wire in zip(vs, obj._ty.element_types, wires, strict=True):
175
- update_packed_value(v, GuppyObject(ty, wire), builder)
177
+ success = update_packed_value(v, GuppyObject(ty, wire), builder)
178
+ if not success:
179
+ return False
176
180
  case GuppyStructObject(_ty=ty, _field_values=values):
177
181
  assert obj._ty == ty
178
182
  wires = builder.add_op(ops.UnpackTuple(), obj._use_wire(None)).outputs()
179
- for (
180
- field,
181
- wire,
182
- ) in zip(ty.fields, wires, strict=True):
183
+ for field, wire in zip(ty.fields, wires, strict=True):
183
184
  v = values[field.name]
184
- update_packed_value(v, GuppyObject(field.ty, wire), builder)
185
+ success = update_packed_value(v, GuppyObject(field.ty, wire), builder)
186
+ if not success:
187
+ values[field.name] = obj
185
188
  case list(vs) if len(vs) > 0:
186
189
  assert is_array_type(obj._ty)
187
190
  elem_ty = get_element_type(obj._ty)
188
191
  opt_wires = unpack_array(builder, obj._use_wire(None))
189
192
  err = "Non-droppable array element has already been used"
190
- for v, opt_wire in zip(vs, opt_wires, strict=True):
193
+ for i, (v, opt_wire) in enumerate(zip(vs, opt_wires, strict=True)):
191
194
  (wire,) = build_unwrap(builder, opt_wire, err).outputs()
192
- update_packed_value(v, GuppyObject(elem_ty, wire), builder)
195
+ success = update_packed_value(v, GuppyObject(elem_ty, wire), builder)
196
+ if not success:
197
+ vs[i] = obj
193
198
  case _:
194
- pass
199
+ return False
200
+ return True
@@ -46,6 +46,27 @@ class CallableTypeDef(TypeDef, CompiledDef):
46
46
  raise InternalGuppyError("Tried to `Callable` type via `check_instantiate`")
47
47
 
48
48
 
49
+ @dataclass(frozen=True)
50
+ class SelfTypeDef(TypeDef, CompiledDef):
51
+ """Type definition associated with the `Self` type on methods.
52
+
53
+ During type parsing, we make sure that this type is replaced with the concrete type
54
+ the method is attached to. Thus, we should never have instances of this type around.
55
+
56
+ In other words, this definition is only a marker so that type parsing doesn't have
57
+ to rely on matching against the string "Self". By making `Self` a definition, we can
58
+ use the existing identifier tracking system and also handle users shadowing the
59
+ `Self` binder or assigning `Self` to some other name.
60
+ """
61
+
62
+ name: Literal["Self"] = field(default="Self", init=False)
63
+
64
+ def check_instantiate(
65
+ self, args: Sequence[Argument], loc: AstNode | None = None
66
+ ) -> FunctionType:
67
+ raise InternalGuppyError("Tried to instantiate abstract `Self` type`")
68
+
69
+
49
70
  @dataclass(frozen=True)
50
71
  class _TupleTypeDef(TypeDef, CompiledDef):
51
72
  """Type definition associated with the builtin `tuple` type.
@@ -106,7 +127,6 @@ class _NumericTypeDef(TypeDef, CompiledDef):
106
127
 
107
128
  class WasmModuleTypeDef(OpaqueTypeDef):
108
129
  wasm_file: str
109
- wasm_hash: int
110
130
 
111
131
  def __init__(
112
132
  self,
@@ -114,11 +134,9 @@ class WasmModuleTypeDef(OpaqueTypeDef):
114
134
  name: str,
115
135
  defined_at: ast.AST | None,
116
136
  wasm_file: str,
117
- wasm_hash: int,
118
137
  ) -> None:
119
138
  super().__init__(id, name, defined_at, [], True, True, self.to_hugr)
120
139
  self.wasm_file = wasm_file
121
- self.wasm_hash = wasm_hash
122
140
 
123
141
  def to_hugr(
124
142
  self, args: Sequence[TypeArg | ConstArg], ctx: ToHugrContext
@@ -189,9 +207,10 @@ def _option_to_hugr(args: Sequence[Argument], ctx: ToHugrContext) -> ht.Type:
189
207
  return ht.Option(arg.ty.to_hugr(ctx))
190
208
 
191
209
 
192
- callable_type_def = CallableTypeDef(DefId.fresh(), None)
193
- tuple_type_def = _TupleTypeDef(DefId.fresh(), None)
194
- none_type_def = _NoneTypeDef(DefId.fresh(), None)
210
+ callable_type_def = CallableTypeDef(DefId.fresh(), None, None)
211
+ self_type_def = SelfTypeDef(DefId.fresh(), None, [])
212
+ tuple_type_def = _TupleTypeDef(DefId.fresh(), None, None)
213
+ none_type_def = _NoneTypeDef(DefId.fresh(), None, [])
195
214
  bool_type_def = OpaqueTypeDef(
196
215
  id=DefId.fresh(),
197
216
  name="bool",
@@ -202,13 +221,13 @@ bool_type_def = OpaqueTypeDef(
202
221
  to_hugr=lambda args, ctx: OpaqueBool,
203
222
  )
204
223
  nat_type_def = _NumericTypeDef(
205
- DefId.fresh(), "nat", None, NumericType(NumericType.Kind.Nat)
224
+ DefId.fresh(), "nat", None, [], NumericType(NumericType.Kind.Nat)
206
225
  )
207
226
  int_type_def = _NumericTypeDef(
208
- DefId.fresh(), "int", None, NumericType(NumericType.Kind.Int)
227
+ DefId.fresh(), "int", None, [], NumericType(NumericType.Kind.Int)
209
228
  )
210
229
  float_type_def = _NumericTypeDef(
211
- DefId.fresh(), "float", None, NumericType(NumericType.Kind.Float)
230
+ DefId.fresh(), "float", None, [], NumericType(NumericType.Kind.Float)
212
231
  )
213
232
  string_type_def = OpaqueTypeDef(
214
233
  id=DefId.fresh(),
@@ -345,9 +364,9 @@ def is_sized_iter_type(ty: Type) -> TypeGuard[OpaqueType]:
345
364
  return isinstance(ty, OpaqueType) and ty.defn == sized_iter_type_def
346
365
 
347
366
 
348
- def wasm_module_info(ty: Type) -> tuple[str, int] | None:
367
+ def wasm_module_name(ty: Type) -> str | None:
349
368
  if isinstance(ty, OpaqueType) and isinstance(ty.defn, WasmModuleTypeDef):
350
- return ty.defn.wasm_file, ty.defn.wasm_hash
369
+ return ty.defn.wasm_file
351
370
  return None
352
371
 
353
372
 
@@ -116,6 +116,12 @@ class InvalidCallableTypeError(Error):
116
116
  self.add_sub_diagnostic(InvalidCallableTypeError.Explain(None))
117
117
 
118
118
 
119
+ @dataclass(frozen=True)
120
+ class SelfTyNotInMethodError(Error):
121
+ title: ClassVar[str] = "Invalid type"
122
+ span_label: ClassVar[str] = "`Self` type annotations are only allowed in methods"
123
+
124
+
119
125
  @dataclass(frozen=True)
120
126
  class NonLinearOwnedError(Error):
121
127
  title: ClassVar[str] = "Invalid annotation"
@@ -1,6 +1,7 @@
1
1
  import ast
2
2
  import sys
3
3
  from collections.abc import Sequence
4
+ from dataclasses import dataclass, field
4
5
  from types import ModuleType
5
6
 
6
7
  from guppylang_internals.ast_util import (
@@ -17,7 +18,7 @@ from guppylang_internals.definition.ty import TypeDef
17
18
  from guppylang_internals.engine import ENGINE
18
19
  from guppylang_internals.error import GuppyError
19
20
  from guppylang_internals.tys.arg import Argument, ConstArg, TypeArg
20
- from guppylang_internals.tys.builtin import CallableTypeDef, bool_type
21
+ from guppylang_internals.tys.builtin import CallableTypeDef, SelfTypeDef, bool_type
21
22
  from guppylang_internals.tys.const import ConstValue
22
23
  from guppylang_internals.tys.errors import (
23
24
  CallableComptimeError,
@@ -34,6 +35,8 @@ from guppylang_internals.tys.errors import (
34
35
  LinearConstParamError,
35
36
  ModuleMemberNotFoundError,
36
37
  NonLinearOwnedError,
38
+ SelfTyNotInMethodError,
39
+ WrongNumberOfTypeArgsError,
37
40
  )
38
41
  from guppylang_internals.tys.param import ConstParam, Parameter, TypeParam
39
42
  from guppylang_internals.tys.subst import BoundVarFinder
@@ -48,46 +51,51 @@ from guppylang_internals.tys.ty import (
48
51
  )
49
52
 
50
53
 
51
- def arg_from_ast(
52
- node: AstNode,
53
- globals: Globals,
54
- param_var_mapping: dict[str, Parameter],
55
- allow_free_vars: bool = False,
56
- ) -> Argument:
54
+ @dataclass(frozen=True)
55
+ class TypeParsingCtx:
56
+ """Context for parsing types from AST nodes."""
57
+
58
+ #: The globals variable context
59
+ globals: Globals
60
+
61
+ #: The available type parameters indexed by name
62
+ param_var_mapping: dict[str, Parameter] = field(default_factory=dict)
63
+
64
+ #: Whether a previously unseen type parameter is allowed to be bound (i.e. is
65
+ #: allowed to be added to `param_var_mapping`)
66
+ allow_free_vars: bool = False
67
+
68
+ #: When parsing types in the signature or body of a method, we also need access to
69
+ #: the type this method belongs to in order to resolve `Self` annotations.
70
+ self_ty: Type | None = None
71
+
72
+
73
+ def arg_from_ast(node: AstNode, ctx: TypeParsingCtx) -> Argument:
57
74
  """Turns an AST expression into an argument."""
58
75
  from guppylang_internals.checker.cfg_checker import VarNotDefinedError
59
76
 
60
77
  # A single (possibly qualified) identifier
61
- if defn := _try_parse_defn(node, globals):
62
- return _arg_from_instantiated_defn(
63
- defn, [], globals, node, param_var_mapping, allow_free_vars
64
- )
78
+ if defn := _try_parse_defn(node, ctx.globals):
79
+ return _arg_from_instantiated_defn(defn, [], node, ctx)
65
80
 
66
81
  # An identifier referring to a quantified variable
67
82
  if isinstance(node, ast.Name):
68
- if node.id in param_var_mapping:
69
- return param_var_mapping[node.id].to_bound()
83
+ if node.id in ctx.param_var_mapping:
84
+ return ctx.param_var_mapping[node.id].to_bound()
70
85
  raise GuppyError(VarNotDefinedError(node, node.id))
71
86
 
72
87
  # A parametrised type, e.g. `list[??]`
73
88
  if isinstance(node, ast.Subscript) and (
74
- defn := _try_parse_defn(node.value, globals)
89
+ defn := _try_parse_defn(node.value, ctx.globals)
75
90
  ):
76
91
  arg_nodes = (
77
92
  node.slice.elts if isinstance(node.slice, ast.Tuple) else [node.slice]
78
93
  )
79
- return _arg_from_instantiated_defn(
80
- defn, arg_nodes, globals, node, param_var_mapping, allow_free_vars
81
- )
94
+ return _arg_from_instantiated_defn(defn, arg_nodes, node, ctx)
82
95
 
83
96
  # We allow tuple types to be written as `(int, bool)`
84
97
  if isinstance(node, ast.Tuple):
85
- ty = TupleType(
86
- [
87
- type_from_ast(el, globals, param_var_mapping, allow_free_vars)
88
- for el in node.elts
89
- ]
90
- )
98
+ ty = TupleType([type_from_ast(el, ctx) for el in node.elts])
91
99
  return TypeArg(ty)
92
100
 
93
101
  # Literals
@@ -118,7 +126,7 @@ def arg_from_ast(
118
126
  if comptime_expr := is_comptime_expression(node):
119
127
  from guppylang_internals.checker.expr_checker import eval_comptime_expr
120
128
 
121
- v = eval_comptime_expr(comptime_expr, Context(globals, Locals({}), {}))
129
+ v = eval_comptime_expr(comptime_expr, Context(ctx.globals, Locals({}), {}))
122
130
  if isinstance(v, int):
123
131
  nat_ty = NumericType(NumericType.Kind.Nat)
124
132
  return ConstArg(ConstValue(nat_ty, v))
@@ -128,7 +136,7 @@ def arg_from_ast(
128
136
  # Finally, we also support delayed annotations in strings
129
137
  if isinstance(node, ast.Constant) and isinstance(node.value, str):
130
138
  node = _parse_delayed_annotation(node.value, node)
131
- return arg_from_ast(node, globals, param_var_mapping, allow_free_vars)
139
+ return arg_from_ast(node, ctx)
132
140
 
133
141
  raise GuppyError(InvalidTypeArgError(node))
134
142
 
@@ -165,28 +173,19 @@ def _try_parse_defn(node: AstNode, globals: Globals) -> Definition | None:
165
173
 
166
174
 
167
175
  def _arg_from_instantiated_defn(
168
- defn: Definition,
169
- arg_nodes: list[ast.expr],
170
- globals: Globals,
171
- node: AstNode,
172
- param_var_mapping: dict[str, Parameter],
173
- allow_free_vars: bool = False,
176
+ defn: Definition, arg_nodes: list[ast.expr], node: AstNode, ctx: TypeParsingCtx
174
177
  ) -> Argument:
175
178
  """Parses a globals definition with type args into an argument."""
176
179
  match defn:
177
180
  # Special case for the `Callable` type
178
181
  case CallableTypeDef():
179
- return TypeArg(
180
- _parse_callable_type(
181
- arg_nodes, node, globals, param_var_mapping, allow_free_vars
182
- )
183
- )
182
+ return TypeArg(_parse_callable_type(arg_nodes, node, ctx))
183
+ # Special case for the `Self` type
184
+ case SelfTypeDef():
185
+ return TypeArg(_parse_self_type(arg_nodes, node, ctx))
184
186
  # Either a defined type (e.g. `int`, `bool`, ...)
185
187
  case TypeDef() as defn:
186
- args = [
187
- arg_from_ast(arg_node, globals, param_var_mapping, allow_free_vars)
188
- for arg_node in arg_nodes
189
- ]
188
+ args = [arg_from_ast(arg_node, ctx) for arg_node in arg_nodes]
190
189
  ty = defn.check_instantiate(args, node)
191
190
  return TypeArg(ty)
192
191
  # Or a parameter (e.g. `T`, `n`, ...)
@@ -194,12 +193,14 @@ def _arg_from_instantiated_defn(
194
193
  # We don't allow parametrised variables like `T[int]`
195
194
  if arg_nodes:
196
195
  raise GuppyError(HigherKindedTypeVarError(node, defn))
197
- if defn.name not in param_var_mapping:
198
- if allow_free_vars:
199
- param_var_mapping[defn.name] = defn.to_param(len(param_var_mapping))
196
+ if defn.name not in ctx.param_var_mapping:
197
+ if ctx.allow_free_vars:
198
+ ctx.param_var_mapping[defn.name] = defn.to_param(
199
+ len(ctx.param_var_mapping)
200
+ )
200
201
  else:
201
202
  raise GuppyError(FreeTypeVarError(node, defn))
202
- return param_var_mapping[defn.name].to_bound()
203
+ return ctx.param_var_mapping[defn.name].to_bound()
203
204
  case defn:
204
205
  err = ExpectedError(node, "a type", got=f"{defn.description} `{defn.name}`")
205
206
  raise GuppyError(err)
@@ -224,11 +225,7 @@ def _parse_delayed_annotation(ast_str: str, node: ast.Constant) -> ast.expr:
224
225
 
225
226
 
226
227
  def _parse_callable_type(
227
- args: list[ast.expr],
228
- loc: AstNode,
229
- globals: Globals,
230
- param_var_mapping: dict[str, Parameter],
231
- allow_free_vars: bool = False,
228
+ args: list[ast.expr], loc: AstNode, ctx: TypeParsingCtx
232
229
  ) -> FunctionType:
233
230
  """Helper function to parse a `Callable[[<arguments>], <return type>]` type."""
234
231
  err = InvalidCallableTypeError(loc)
@@ -237,59 +234,63 @@ def _parse_callable_type(
237
234
  [inputs, output] = args
238
235
  if not isinstance(inputs, ast.List):
239
236
  raise GuppyError(err)
240
- inouts, output = parse_function_io_types(
241
- inputs.elts, output, None, loc, globals, param_var_mapping, allow_free_vars
242
- )
243
- return FunctionType(inouts, output)
244
-
245
-
246
- def parse_function_io_types(
247
- input_nodes: list[ast.expr],
248
- output_node: ast.expr,
249
- input_names: list[str] | None,
250
- loc: AstNode,
251
- globals: Globals,
252
- param_var_mapping: dict[str, Parameter],
253
- allow_free_vars: bool = False,
254
- ) -> tuple[list[FuncInput], Type]:
255
- """Parses the inputs and output types of a function type.
256
-
257
- This function takes care of parsing annotations and any related checks.
258
-
259
- Returns the parsed input and output types.
237
+ inputs = [parse_function_arg_annotation(inp, None, ctx) for inp in inputs.elts]
238
+ output = type_from_ast(output, ctx)
239
+ return FunctionType(inputs, output)
240
+
241
+
242
+ def _parse_self_type(args: list[ast.expr], loc: AstNode, ctx: TypeParsingCtx) -> Type:
243
+ """Helper function to parse a `Self` type.
244
+
245
+ Returns the actual type `Self` refers to or emits a user error if we're not inside
246
+ a method.
260
247
  """
261
- inputs = []
262
- for i, inp in enumerate(input_nodes):
263
- ty, flags = type_with_flags_from_ast(
264
- inp, globals, param_var_mapping, allow_free_vars
248
+ if ctx.self_ty is None:
249
+ raise GuppyError(SelfTyNotInMethodError(loc))
250
+
251
+ # We don't allow specifying generic arguments of `Self`. This matches the behaviour
252
+ # of Python.
253
+ if args:
254
+ raise GuppyError(WrongNumberOfTypeArgsError(loc, 0, len(args), "Self"))
255
+ return ctx.self_ty
256
+
257
+
258
+ def parse_function_arg_annotation(
259
+ annotation: ast.expr, name: str | None, ctx: TypeParsingCtx
260
+ ) -> FuncInput:
261
+ """Parses an annotation in the input of a function type."""
262
+ ty, flags = type_with_flags_from_ast(annotation, ctx)
263
+ return check_function_arg(ty, flags, annotation, name, ctx)
264
+
265
+
266
+ def check_function_arg(
267
+ ty: Type, flags: InputFlags, loc: AstNode, name: str | None, ctx: TypeParsingCtx
268
+ ) -> FuncInput:
269
+ """Given a function input type and its user-provided flags, checks if the flags
270
+ are valid and inserts implicit flags."""
271
+ if InputFlags.Owned in flags and ty.copyable:
272
+ raise GuppyError(NonLinearOwnedError(loc, ty))
273
+ if not ty.copyable and InputFlags.Owned not in flags:
274
+ flags |= InputFlags.Inout
275
+ if InputFlags.Comptime in flags:
276
+ if name is None:
277
+ raise GuppyError(CallableComptimeError(loc))
278
+
279
+ # Make sure we're not shadowing a type variable with the same name that was
280
+ # already used on the left. E.g
281
+ #
282
+ # n = guppy.type_var("n")
283
+ # def foo(xs: array[int, n], n: nat @comptime)
284
+ #
285
+ # TODO: In principle we could lift this restriction by tracking multiple
286
+ # params referring to the same name in `param_var_mapping`, but not sure if
287
+ # this would be worth it...
288
+ if name in ctx.param_var_mapping:
289
+ raise GuppyError(ComptimeArgShadowError(loc, name))
290
+ ctx.param_var_mapping[name] = ConstParam(
291
+ len(ctx.param_var_mapping), name, ty, from_comptime_arg=True
265
292
  )
266
- if InputFlags.Owned in flags and ty.copyable:
267
- raise GuppyError(NonLinearOwnedError(loc, ty))
268
- if not ty.copyable and InputFlags.Owned not in flags:
269
- flags |= InputFlags.Inout
270
- if InputFlags.Comptime in flags:
271
- if input_names is None:
272
- raise GuppyError(CallableComptimeError(inp))
273
- name = input_names[i]
274
-
275
- # Make sure we're not shadowing a type variable with the same name that was
276
- # already used on the left. E.g
277
- #
278
- # n = guppy.type_var("n")
279
- # def foo(xs: array[int, n], n: nat @comptime)
280
- #
281
- # TODO: In principle we could lift this restriction by tracking multiple
282
- # params referring to the same name in `param_var_mapping`, but not sure if
283
- # this would be worth it...
284
- if name in param_var_mapping:
285
- raise GuppyError(ComptimeArgShadowError(inp, name))
286
- param_var_mapping[name] = ConstParam(
287
- len(param_var_mapping), name, ty, from_comptime_arg=True
288
- )
289
-
290
- inputs.append(FuncInput(ty, flags))
291
- output = type_from_ast(output_node, globals, param_var_mapping, allow_free_vars)
292
- return inputs, output
293
+ return FuncInput(ty, flags)
293
294
 
294
295
 
295
296
  if sys.version_info >= (3, 12):
@@ -330,7 +331,8 @@ if sys.version_info >= (3, 12):
330
331
  # parameters, so we pass an empty dict as the `param_var_mapping`.
331
332
  # TODO: In the future we might want to allow stuff like
332
333
  # `def foo[T, XS: array[T, 42]]` and so on
333
- ty = type_from_ast(bound, globals, {}, allow_free_vars=False)
334
+ ctx = TypeParsingCtx(globals, param_var_mapping={})
335
+ ty = type_from_ast(bound, ctx)
334
336
  if not ty.copyable or not ty.droppable:
335
337
  raise GuppyError(LinearConstParamError(bound, ty))
336
338
 
@@ -348,15 +350,10 @@ _type_param = TypeParam(0, "T", False, False)
348
350
 
349
351
 
350
352
  def type_with_flags_from_ast(
351
- node: AstNode,
352
- globals: Globals,
353
- param_var_mapping: dict[str, Parameter],
354
- allow_free_vars: bool = False,
353
+ node: AstNode, ctx: TypeParsingCtx
355
354
  ) -> tuple[Type, InputFlags]:
356
355
  if isinstance(node, ast.BinOp) and isinstance(node.op, ast.MatMult):
357
- ty, flags = type_with_flags_from_ast(
358
- node.left, globals, param_var_mapping, allow_free_vars
359
- )
356
+ ty, flags = type_with_flags_from_ast(node.left, ctx)
360
357
  match node.right:
361
358
  case ast.Name(id="owned"):
362
359
  if ty.copyable:
@@ -382,35 +379,24 @@ def type_with_flags_from_ast(
382
379
  # We also need to handle the case that this could be a delayed string annotation
383
380
  elif isinstance(node, ast.Constant) and isinstance(node.value, str):
384
381
  node = _parse_delayed_annotation(node.value, node)
385
- return type_with_flags_from_ast(
386
- node, globals, param_var_mapping, allow_free_vars
387
- )
382
+ return type_with_flags_from_ast(node, ctx)
388
383
  else:
389
384
  # Parse an argument and check that it's valid for a `TypeParam`
390
- arg = arg_from_ast(node, globals, param_var_mapping, allow_free_vars)
385
+ arg = arg_from_ast(node, ctx)
391
386
  tyarg = _type_param.check_arg(arg, node)
392
387
  return tyarg.ty, InputFlags.NoFlags
393
388
 
394
389
 
395
- def type_from_ast(
396
- node: AstNode,
397
- globals: Globals,
398
- param_var_mapping: dict[str, Parameter],
399
- allow_free_vars: bool = False,
400
- ) -> Type:
390
+ def type_from_ast(node: AstNode, ctx: TypeParsingCtx) -> Type:
401
391
  """Turns an AST expression into a Guppy type."""
402
- ty, flags = type_with_flags_from_ast(
403
- node, globals, param_var_mapping, allow_free_vars
404
- )
392
+ ty, flags = type_with_flags_from_ast(node, ctx)
405
393
  if flags != InputFlags.NoFlags:
406
394
  assert InputFlags.Inout not in flags # Users shouldn't be able to set this
407
395
  raise GuppyError(FlagNotAllowedError(node))
408
396
  return ty
409
397
 
410
398
 
411
- def type_row_from_ast(
412
- node: ast.expr, globals: "Globals", allow_free_vars: bool = False
413
- ) -> Sequence[Type]:
399
+ def type_row_from_ast(node: ast.expr, ctx: TypeParsingCtx) -> Sequence[Type]:
414
400
  """Turns an AST expression into a Guppy type row.
415
401
 
416
402
  This is needed to interpret the return type annotation of functions.
@@ -418,7 +404,7 @@ def type_row_from_ast(
418
404
  # The return type `-> None` is represented in the ast as `ast.Constant(value=None)`
419
405
  if isinstance(node, ast.Constant) and node.value is None:
420
406
  return []
421
- ty = type_from_ast(node, globals, {}, allow_free_vars)
407
+ ty = type_from_ast(node, ctx)
422
408
  if isinstance(ty, TupleType):
423
409
  return ty.element_types
424
410
  else:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: guppylang-internals
3
- Version: 0.22.0
3
+ Version: 0.24.0
4
4
  Summary: Compiler internals for `guppylang` package.
5
5
  Author-email: Mark Koch <mark.koch@quantinuum.com>, TKET development team <tket-support@quantinuum.com>
6
6
  Maintainer-email: Mark Koch <mark.koch@quantinuum.com>, TKET development team <tket-support@quantinuum.com>
@@ -219,8 +219,8 @@ Classifier: Programming Language :: Python :: 3.13
219
219
  Classifier: Programming Language :: Python :: 3.14
220
220
  Classifier: Topic :: Software Development :: Compilers
221
221
  Requires-Python: <4,>=3.10
222
- Requires-Dist: hugr<0.14,>=0.13.0rc1
223
- Requires-Dist: tket-exts~=0.10.0
222
+ Requires-Dist: hugr~=0.13.1
223
+ Requires-Dist: tket-exts~=0.11.0
224
224
  Requires-Dist: typing-extensions<5,>=4.9.0
225
225
  Provides-Extra: pytket
226
226
  Requires-Dist: pytket>=1.34; extra == 'pytket'
@@ -228,7 +228,7 @@ Description-Content-Type: text/markdown
228
228
 
229
229
  # guppylang-internals
230
230
 
231
- This packages contains the internals of the Guppy compiler.
231
+ This package contains the internals of the Guppy compiler.
232
232
 
233
233
  See `guppylang` for the package providing the user-facing language frontend.
234
234
 
@@ -250,4 +250,4 @@ See [DEVELOPMENT.md] information on how to develop and contribute to this packag
250
250
 
251
251
  This project is licensed under Apache License, Version 2.0 ([LICENCE][] or http://www.apache.org/licenses/LICENSE-2.0).
252
252
 
253
- [LICENCE]: https://github.com/CQCL/guppylang/blob/main/LICENCE
253
+ [LICENCE]: https://github.com/CQCL/guppylang/blob/main/LICENCE