guppylang-internals 0.23.0__py3-none-any.whl → 0.25.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. guppylang_internals/__init__.py +1 -1
  2. guppylang_internals/ast_util.py +21 -0
  3. guppylang_internals/cfg/bb.py +20 -0
  4. guppylang_internals/cfg/builder.py +101 -3
  5. guppylang_internals/checker/core.py +12 -0
  6. guppylang_internals/checker/errors/generic.py +32 -1
  7. guppylang_internals/checker/errors/type_errors.py +14 -0
  8. guppylang_internals/checker/expr_checker.py +55 -29
  9. guppylang_internals/checker/func_checker.py +171 -22
  10. guppylang_internals/checker/linearity_checker.py +65 -0
  11. guppylang_internals/checker/modifier_checker.py +116 -0
  12. guppylang_internals/checker/stmt_checker.py +49 -2
  13. guppylang_internals/compiler/core.py +90 -53
  14. guppylang_internals/compiler/expr_compiler.py +49 -114
  15. guppylang_internals/compiler/modifier_compiler.py +174 -0
  16. guppylang_internals/compiler/stmt_compiler.py +15 -8
  17. guppylang_internals/decorator.py +124 -58
  18. guppylang_internals/definition/const.py +2 -2
  19. guppylang_internals/definition/custom.py +36 -2
  20. guppylang_internals/definition/declaration.py +4 -5
  21. guppylang_internals/definition/extern.py +2 -2
  22. guppylang_internals/definition/function.py +1 -1
  23. guppylang_internals/definition/parameter.py +10 -5
  24. guppylang_internals/definition/pytket_circuits.py +14 -42
  25. guppylang_internals/definition/struct.py +17 -14
  26. guppylang_internals/definition/traced.py +1 -1
  27. guppylang_internals/definition/ty.py +9 -3
  28. guppylang_internals/definition/wasm.py +2 -2
  29. guppylang_internals/engine.py +13 -2
  30. guppylang_internals/experimental.py +5 -0
  31. guppylang_internals/nodes.py +124 -23
  32. guppylang_internals/std/_internal/compiler/array.py +94 -282
  33. guppylang_internals/std/_internal/compiler/tket_exts.py +12 -8
  34. guppylang_internals/std/_internal/compiler/wasm.py +37 -26
  35. guppylang_internals/tracing/function.py +13 -2
  36. guppylang_internals/tracing/unpacking.py +33 -28
  37. guppylang_internals/tys/arg.py +18 -3
  38. guppylang_internals/tys/builtin.py +32 -16
  39. guppylang_internals/tys/const.py +33 -4
  40. guppylang_internals/tys/errors.py +6 -0
  41. guppylang_internals/tys/param.py +31 -16
  42. guppylang_internals/tys/parsing.py +118 -145
  43. guppylang_internals/tys/qubit.py +27 -0
  44. guppylang_internals/tys/subst.py +8 -26
  45. guppylang_internals/tys/ty.py +31 -21
  46. {guppylang_internals-0.23.0.dist-info → guppylang_internals-0.25.0.dist-info}/METADATA +4 -4
  47. {guppylang_internals-0.23.0.dist-info → guppylang_internals-0.25.0.dist-info}/RECORD +49 -46
  48. {guppylang_internals-0.23.0.dist-info → guppylang_internals-0.25.0.dist-info}/WHEEL +0 -0
  49. {guppylang_internals-0.23.0.dist-info → guppylang_internals-0.25.0.dist-info}/licenses/LICENCE +0 -0
@@ -1,6 +1,7 @@
1
1
  import ast
2
2
  import sys
3
3
  from collections.abc import Sequence
4
+ from dataclasses import dataclass, field
4
5
  from types import ModuleType
5
6
 
6
7
  from guppylang_internals.ast_util import (
@@ -17,7 +18,7 @@ from guppylang_internals.definition.ty import TypeDef
17
18
  from guppylang_internals.engine import ENGINE
18
19
  from guppylang_internals.error import GuppyError
19
20
  from guppylang_internals.tys.arg import Argument, ConstArg, TypeArg
20
- from guppylang_internals.tys.builtin import CallableTypeDef, bool_type
21
+ from guppylang_internals.tys.builtin import CallableTypeDef, SelfTypeDef, bool_type
21
22
  from guppylang_internals.tys.const import ConstValue
22
23
  from guppylang_internals.tys.errors import (
23
24
  CallableComptimeError,
@@ -34,9 +35,10 @@ from guppylang_internals.tys.errors import (
34
35
  LinearConstParamError,
35
36
  ModuleMemberNotFoundError,
36
37
  NonLinearOwnedError,
38
+ SelfTyNotInMethodError,
39
+ WrongNumberOfTypeArgsError,
37
40
  )
38
41
  from guppylang_internals.tys.param import ConstParam, Parameter, TypeParam
39
- from guppylang_internals.tys.subst import BoundVarFinder
40
42
  from guppylang_internals.tys.ty import (
41
43
  FuncInput,
42
44
  FunctionType,
@@ -48,46 +50,51 @@ from guppylang_internals.tys.ty import (
48
50
  )
49
51
 
50
52
 
51
- def arg_from_ast(
52
- node: AstNode,
53
- globals: Globals,
54
- param_var_mapping: dict[str, Parameter],
55
- allow_free_vars: bool = False,
56
- ) -> Argument:
53
+ @dataclass(frozen=True)
54
+ class TypeParsingCtx:
55
+ """Context for parsing types from AST nodes."""
56
+
57
+ #: The globals variable context
58
+ globals: Globals
59
+
60
+ #: The available type parameters indexed by name
61
+ param_var_mapping: dict[str, Parameter] = field(default_factory=dict)
62
+
63
+ #: Whether a previously unseen type parameter is allowed to be bound (i.e. is
64
+ #: allowed to be added to `param_var_mapping`)
65
+ allow_free_vars: bool = False
66
+
67
+ #: When parsing types in the signature or body of a method, we also need access to
68
+ #: the type this method belongs to in order to resolve `Self` annotations.
69
+ self_ty: Type | None = None
70
+
71
+
72
+ def arg_from_ast(node: AstNode, ctx: TypeParsingCtx) -> Argument:
57
73
  """Turns an AST expression into an argument."""
58
74
  from guppylang_internals.checker.cfg_checker import VarNotDefinedError
59
75
 
60
76
  # A single (possibly qualified) identifier
61
- if defn := _try_parse_defn(node, globals):
62
- return _arg_from_instantiated_defn(
63
- defn, [], globals, node, param_var_mapping, allow_free_vars
64
- )
77
+ if defn := _try_parse_defn(node, ctx.globals):
78
+ return _arg_from_instantiated_defn(defn, [], node, ctx)
65
79
 
66
80
  # An identifier referring to a quantified variable
67
81
  if isinstance(node, ast.Name):
68
- if node.id in param_var_mapping:
69
- return param_var_mapping[node.id].to_bound()
82
+ if node.id in ctx.param_var_mapping:
83
+ return ctx.param_var_mapping[node.id].to_bound()
70
84
  raise GuppyError(VarNotDefinedError(node, node.id))
71
85
 
72
86
  # A parametrised type, e.g. `list[??]`
73
87
  if isinstance(node, ast.Subscript) and (
74
- defn := _try_parse_defn(node.value, globals)
88
+ defn := _try_parse_defn(node.value, ctx.globals)
75
89
  ):
76
90
  arg_nodes = (
77
91
  node.slice.elts if isinstance(node.slice, ast.Tuple) else [node.slice]
78
92
  )
79
- return _arg_from_instantiated_defn(
80
- defn, arg_nodes, globals, node, param_var_mapping, allow_free_vars
81
- )
93
+ return _arg_from_instantiated_defn(defn, arg_nodes, node, ctx)
82
94
 
83
95
  # We allow tuple types to be written as `(int, bool)`
84
96
  if isinstance(node, ast.Tuple):
85
- ty = TupleType(
86
- [
87
- type_from_ast(el, globals, param_var_mapping, allow_free_vars)
88
- for el in node.elts
89
- ]
90
- )
97
+ ty = TupleType([type_from_ast(el, ctx) for el in node.elts])
91
98
  return TypeArg(ty)
92
99
 
93
100
  # Literals
@@ -118,7 +125,7 @@ def arg_from_ast(
118
125
  if comptime_expr := is_comptime_expression(node):
119
126
  from guppylang_internals.checker.expr_checker import eval_comptime_expr
120
127
 
121
- v = eval_comptime_expr(comptime_expr, Context(globals, Locals({}), {}))
128
+ v = eval_comptime_expr(comptime_expr, Context(ctx.globals, Locals({}), {}))
122
129
  if isinstance(v, int):
123
130
  nat_ty = NumericType(NumericType.Kind.Nat)
124
131
  return ConstArg(ConstValue(nat_ty, v))
@@ -128,7 +135,7 @@ def arg_from_ast(
128
135
  # Finally, we also support delayed annotations in strings
129
136
  if isinstance(node, ast.Constant) and isinstance(node.value, str):
130
137
  node = _parse_delayed_annotation(node.value, node)
131
- return arg_from_ast(node, globals, param_var_mapping, allow_free_vars)
138
+ return arg_from_ast(node, ctx)
132
139
 
133
140
  raise GuppyError(InvalidTypeArgError(node))
134
141
 
@@ -165,28 +172,19 @@ def _try_parse_defn(node: AstNode, globals: Globals) -> Definition | None:
165
172
 
166
173
 
167
174
  def _arg_from_instantiated_defn(
168
- defn: Definition,
169
- arg_nodes: list[ast.expr],
170
- globals: Globals,
171
- node: AstNode,
172
- param_var_mapping: dict[str, Parameter],
173
- allow_free_vars: bool = False,
175
+ defn: Definition, arg_nodes: list[ast.expr], node: AstNode, ctx: TypeParsingCtx
174
176
  ) -> Argument:
175
177
  """Parses a globals definition with type args into an argument."""
176
178
  match defn:
177
179
  # Special case for the `Callable` type
178
180
  case CallableTypeDef():
179
- return TypeArg(
180
- _parse_callable_type(
181
- arg_nodes, node, globals, param_var_mapping, allow_free_vars
182
- )
183
- )
181
+ return TypeArg(_parse_callable_type(arg_nodes, node, ctx))
182
+ # Special case for the `Self` type
183
+ case SelfTypeDef():
184
+ return TypeArg(_parse_self_type(arg_nodes, node, ctx))
184
185
  # Either a defined type (e.g. `int`, `bool`, ...)
185
186
  case TypeDef() as defn:
186
- args = [
187
- arg_from_ast(arg_node, globals, param_var_mapping, allow_free_vars)
188
- for arg_node in arg_nodes
189
- ]
187
+ args = [arg_from_ast(arg_node, ctx) for arg_node in arg_nodes]
190
188
  ty = defn.check_instantiate(args, node)
191
189
  return TypeArg(ty)
192
190
  # Or a parameter (e.g. `T`, `n`, ...)
@@ -194,12 +192,14 @@ def _arg_from_instantiated_defn(
194
192
  # We don't allow parametrised variables like `T[int]`
195
193
  if arg_nodes:
196
194
  raise GuppyError(HigherKindedTypeVarError(node, defn))
197
- if defn.name not in param_var_mapping:
198
- if allow_free_vars:
199
- param_var_mapping[defn.name] = defn.to_param(len(param_var_mapping))
195
+ if defn.name not in ctx.param_var_mapping:
196
+ if ctx.allow_free_vars:
197
+ ctx.param_var_mapping[defn.name] = defn.to_param(
198
+ len(ctx.param_var_mapping)
199
+ )
200
200
  else:
201
201
  raise GuppyError(FreeTypeVarError(node, defn))
202
- return param_var_mapping[defn.name].to_bound()
202
+ return ctx.param_var_mapping[defn.name].to_bound()
203
203
  case defn:
204
204
  err = ExpectedError(node, "a type", got=f"{defn.description} `{defn.name}`")
205
205
  raise GuppyError(err)
@@ -224,11 +224,7 @@ def _parse_delayed_annotation(ast_str: str, node: ast.Constant) -> ast.expr:
224
224
 
225
225
 
226
226
  def _parse_callable_type(
227
- args: list[ast.expr],
228
- loc: AstNode,
229
- globals: Globals,
230
- param_var_mapping: dict[str, Parameter],
231
- allow_free_vars: bool = False,
227
+ args: list[ast.expr], loc: AstNode, ctx: TypeParsingCtx
232
228
  ) -> FunctionType:
233
229
  """Helper function to parse a `Callable[[<arguments>], <return type>]` type."""
234
230
  err = InvalidCallableTypeError(loc)
@@ -237,64 +233,74 @@ def _parse_callable_type(
237
233
  [inputs, output] = args
238
234
  if not isinstance(inputs, ast.List):
239
235
  raise GuppyError(err)
240
- inouts, output = parse_function_io_types(
241
- inputs.elts, output, None, loc, globals, param_var_mapping, allow_free_vars
242
- )
243
- return FunctionType(inouts, output)
244
-
245
-
246
- def parse_function_io_types(
247
- input_nodes: list[ast.expr],
248
- output_node: ast.expr,
249
- input_names: list[str] | None,
250
- loc: AstNode,
251
- globals: Globals,
252
- param_var_mapping: dict[str, Parameter],
253
- allow_free_vars: bool = False,
254
- ) -> tuple[list[FuncInput], Type]:
255
- """Parses the inputs and output types of a function type.
256
-
257
- This function takes care of parsing annotations and any related checks.
258
-
259
- Returns the parsed input and output types.
236
+ inputs = [parse_function_arg_annotation(inp, None, ctx) for inp in inputs.elts]
237
+ output = type_from_ast(output, ctx)
238
+ return FunctionType(inputs, output)
239
+
240
+
241
+ def _parse_self_type(args: list[ast.expr], loc: AstNode, ctx: TypeParsingCtx) -> Type:
242
+ """Helper function to parse a `Self` type.
243
+
244
+ Returns the actual type `Self` refers to or emits a user error if we're not inside
245
+ a method.
260
246
  """
261
- inputs = []
262
- for i, inp in enumerate(input_nodes):
263
- ty, flags = type_with_flags_from_ast(
264
- inp, globals, param_var_mapping, allow_free_vars
247
+ if ctx.self_ty is None:
248
+ raise GuppyError(SelfTyNotInMethodError(loc))
249
+
250
+ # We don't allow specifying generic arguments of `Self`. This matches the behaviour
251
+ # of Python.
252
+ if args:
253
+ raise GuppyError(WrongNumberOfTypeArgsError(loc, 0, len(args), "Self"))
254
+ return ctx.self_ty
255
+
256
+
257
+ def parse_function_arg_annotation(
258
+ annotation: ast.expr, name: str | None, ctx: TypeParsingCtx
259
+ ) -> FuncInput:
260
+ """Parses an annotation in the input of a function type."""
261
+ ty, flags = type_with_flags_from_ast(annotation, ctx)
262
+ return check_function_arg(ty, flags, annotation, name, ctx)
263
+
264
+
265
+ def check_function_arg(
266
+ ty: Type, flags: InputFlags, loc: AstNode, name: str | None, ctx: TypeParsingCtx
267
+ ) -> FuncInput:
268
+ """Given a function input type and its user-provided flags, checks if the flags
269
+ are valid and inserts implicit flags."""
270
+ if InputFlags.Owned in flags and ty.copyable:
271
+ raise GuppyError(NonLinearOwnedError(loc, ty))
272
+ if not ty.copyable and InputFlags.Owned not in flags:
273
+ flags |= InputFlags.Inout
274
+ if InputFlags.Comptime in flags:
275
+ if name is None:
276
+ raise GuppyError(CallableComptimeError(loc))
277
+
278
+ # Make sure we're not shadowing a type variable with the same name that was
279
+ # already used on the left. E.g.
280
+ #
281
+ # n = guppy.type_var("n")
282
+ # def foo(xs: array[int, n], n: nat @comptime)
283
+ #
284
+ # TODO: In principle we could lift this restriction by tracking multiple
285
+ # params referring to the same name in `param_var_mapping`, but not sure if
286
+ # this would be worth it...
287
+ if name in ctx.param_var_mapping:
288
+ raise GuppyError(ComptimeArgShadowError(loc, name))
289
+ ctx.param_var_mapping[name] = ConstParam(
290
+ len(ctx.param_var_mapping), name, ty, from_comptime_arg=True
265
291
  )
266
- if InputFlags.Owned in flags and ty.copyable:
267
- raise GuppyError(NonLinearOwnedError(loc, ty))
268
- if not ty.copyable and InputFlags.Owned not in flags:
269
- flags |= InputFlags.Inout
270
- if InputFlags.Comptime in flags:
271
- if input_names is None:
272
- raise GuppyError(CallableComptimeError(inp))
273
- name = input_names[i]
274
-
275
- # Make sure we're not shadowing a type variable with the same name that was
276
- # already used on the left. E.g
277
- #
278
- # n = guppy.type_var("n")
279
- # def foo(xs: array[int, n], n: nat @comptime)
280
- #
281
- # TODO: In principle we could lift this restriction by tracking multiple
282
- # params referring to the same name in `param_var_mapping`, but not sure if
283
- # this would be worth it...
284
- if name in param_var_mapping:
285
- raise GuppyError(ComptimeArgShadowError(inp, name))
286
- param_var_mapping[name] = ConstParam(
287
- len(param_var_mapping), name, ty, from_comptime_arg=True
288
- )
289
-
290
- inputs.append(FuncInput(ty, flags))
291
- output = type_from_ast(output_node, globals, param_var_mapping, allow_free_vars)
292
- return inputs, output
292
+ return FuncInput(ty, flags)
293
293
 
294
294
 
295
295
  if sys.version_info >= (3, 12):
296
296
 
297
- def parse_parameter(node: ast.type_param, idx: int, globals: Globals) -> Parameter:
297
+ def parse_parameter(
298
+ node: ast.type_param,
299
+ idx: int,
300
+ globals: Globals,
301
+ param_var_mapping: dict[str, Parameter],
302
+ allow_free_vars: bool = False,
303
+ ) -> Parameter:
298
304
  """Parses a `Variable: Bound` generic type parameter declaration."""
299
305
  if isinstance(node, ast.TypeVarTuple | ast.ParamSpec):
300
306
  raise GuppyError(UnsupportedError(node, "Variadic generic parameters"))
@@ -330,17 +336,10 @@ if sys.version_info >= (3, 12):
330
336
  # parameters, so we pass an empty dict as the `param_var_mapping`.
331
337
  # TODO: In the future we might want to allow stuff like
332
338
  # `def foo[T, XS: array[T, 42]]` and so on
333
- ty = type_from_ast(bound, globals, {}, allow_free_vars=False)
339
+ ctx = TypeParsingCtx(globals, param_var_mapping, allow_free_vars)
340
+ ty = type_from_ast(bound, ctx)
334
341
  if not ty.copyable or not ty.droppable:
335
342
  raise GuppyError(LinearConstParamError(bound, ty))
336
-
337
- # TODO: For now we can only do `nat` const args since they lower to
338
- # Hugr bounded nats. Extend to arbitrary types via monomorphization.
339
- # See https://github.com/CQCL/guppylang/issues/1008
340
- if ty != NumericType(NumericType.Kind.Nat):
341
- raise GuppyError(
342
- UnsupportedError(bound, f"`{ty}` generic parameters")
343
- )
344
343
  return ConstParam(idx, node.name, ty)
345
344
 
346
345
 
@@ -348,15 +347,10 @@ _type_param = TypeParam(0, "T", False, False)
348
347
 
349
348
 
350
349
  def type_with_flags_from_ast(
351
- node: AstNode,
352
- globals: Globals,
353
- param_var_mapping: dict[str, Parameter],
354
- allow_free_vars: bool = False,
350
+ node: AstNode, ctx: TypeParsingCtx
355
351
  ) -> tuple[Type, InputFlags]:
356
352
  if isinstance(node, ast.BinOp) and isinstance(node.op, ast.MatMult):
357
- ty, flags = type_with_flags_from_ast(
358
- node.left, globals, param_var_mapping, allow_free_vars
359
- )
353
+ ty, flags = type_with_flags_from_ast(node.left, ctx)
360
354
  match node.right:
361
355
  case ast.Name(id="owned"):
362
356
  if ty.copyable:
@@ -366,51 +360,30 @@ def type_with_flags_from_ast(
366
360
  flags |= InputFlags.Comptime
367
361
  if not ty.copyable or not ty.droppable:
368
362
  raise GuppyError(LinearComptimeError(node.right, ty))
369
- # For now, we don't allow comptime annotations on generic inputs
370
- # TODO: In the future we might want to allow stuff like
371
- # `def foo[T: (Copy, Discard](x: T @comptime)`.
372
- # Also see the todo in `parse_parameter`.
373
- var_finder = BoundVarFinder()
374
- ty.visit(var_finder)
375
- if var_finder.bound_vars:
376
- raise GuppyError(
377
- UnsupportedError(node.left, "Generic comptime arguments")
378
- )
379
363
  case _:
380
364
  raise GuppyError(InvalidFlagError(node.right))
381
365
  return ty, flags
382
366
  # We also need to handle the case that this could be a delayed string annotation
383
367
  elif isinstance(node, ast.Constant) and isinstance(node.value, str):
384
368
  node = _parse_delayed_annotation(node.value, node)
385
- return type_with_flags_from_ast(
386
- node, globals, param_var_mapping, allow_free_vars
387
- )
369
+ return type_with_flags_from_ast(node, ctx)
388
370
  else:
389
371
  # Parse an argument and check that it's valid for a `TypeParam`
390
- arg = arg_from_ast(node, globals, param_var_mapping, allow_free_vars)
372
+ arg = arg_from_ast(node, ctx)
391
373
  tyarg = _type_param.check_arg(arg, node)
392
374
  return tyarg.ty, InputFlags.NoFlags
393
375
 
394
376
 
395
- def type_from_ast(
396
- node: AstNode,
397
- globals: Globals,
398
- param_var_mapping: dict[str, Parameter],
399
- allow_free_vars: bool = False,
400
- ) -> Type:
377
+ def type_from_ast(node: AstNode, ctx: TypeParsingCtx) -> Type:
401
378
  """Turns an AST expression into a Guppy type."""
402
- ty, flags = type_with_flags_from_ast(
403
- node, globals, param_var_mapping, allow_free_vars
404
- )
379
+ ty, flags = type_with_flags_from_ast(node, ctx)
405
380
  if flags != InputFlags.NoFlags:
406
381
  assert InputFlags.Inout not in flags # Users shouldn't be able to set this
407
382
  raise GuppyError(FlagNotAllowedError(node))
408
383
  return ty
409
384
 
410
385
 
411
- def type_row_from_ast(
412
- node: ast.expr, globals: "Globals", allow_free_vars: bool = False
413
- ) -> Sequence[Type]:
386
+ def type_row_from_ast(node: ast.expr, ctx: TypeParsingCtx) -> Sequence[Type]:
414
387
  """Turns an AST expression into a Guppy type row.
415
388
 
416
389
  This is needed to interpret the return type annotation of functions.
@@ -418,7 +391,7 @@ def type_row_from_ast(
418
391
  # The return type `-> None` is represented in the ast as `ast.Constant(value=None)`
419
392
  if isinstance(node, ast.Constant) and node.value is None:
420
393
  return []
421
- ty = type_from_ast(node, globals, {}, allow_free_vars)
394
+ ty = type_from_ast(node, ctx)
422
395
  if isinstance(ty, TupleType):
423
396
  return ty.element_types
424
397
  else:
@@ -0,0 +1,27 @@
1
+ import functools
2
+ from typing import cast
3
+
4
+ from guppylang_internals.definition.ty import TypeDef
5
+ from guppylang_internals.tys.ty import Type
6
+
7
+
8
+ @functools.cache
9
+ def qubit_ty() -> Type:
10
+ """Returns the qubit type. Beware that this function imports guppylang definitions,
11
+ so, if called before the definitions are registered,
12
+ it might result in circular imports.
13
+ """
14
+ from guppylang.defs import GuppyDefinition
15
+ from guppylang.std.quantum import qubit
16
+
17
+ assert isinstance(qubit, GuppyDefinition)
18
+ qubit_ty = cast(TypeDef, qubit.wrapped).check_instantiate([])
19
+ return qubit_ty
20
+
21
+
22
+ def is_qubit_ty(ty: Type) -> bool:
23
+ """Checks if the given type is the qubit type.
24
+ This function results in circular imports if called
25
+ before qubit types are registered.
26
+ """
27
+ return ty == qubit_ty()
@@ -4,7 +4,7 @@ from typing import Any
4
4
 
5
5
  from guppylang_internals.error import InternalGuppyError
6
6
  from guppylang_internals.tys.arg import Argument, ConstArg, TypeArg
7
- from guppylang_internals.tys.common import Transformer, Visitor
7
+ from guppylang_internals.tys.common import Transformer
8
8
  from guppylang_internals.tys.const import (
9
9
  BoundConstVar,
10
10
  Const,
@@ -18,7 +18,7 @@ from guppylang_internals.tys.ty import (
18
18
  Type,
19
19
  TypeBase,
20
20
  )
21
- from guppylang_internals.tys.var import BoundVar, ExistentialVar
21
+ from guppylang_internals.tys.var import ExistentialVar
22
22
 
23
23
  Subst = dict[ExistentialVar, Type | Const]
24
24
  Inst = Sequence[Argument]
@@ -51,7 +51,8 @@ class Substituter(Transformer):
51
51
  class Instantiator(Transformer):
52
52
  """Type transformer that instantiates bound variables."""
53
53
 
54
- def __init__(self, inst: Inst) -> None:
54
+ def __init__(self, inst: PartialInst, allow_partial: bool = False) -> None:
55
+ self.allow_partial = allow_partial
55
56
  self.inst = inst
56
57
 
57
58
  @functools.singledispatchmethod
@@ -63,6 +64,8 @@ class Instantiator(Transformer):
63
64
  # Instantiate if type for the index is available
64
65
  if ty.idx < len(self.inst):
65
66
  arg = self.inst[ty.idx]
67
+ if arg is None and self.allow_partial:
68
+ return None
66
69
  assert isinstance(arg, TypeArg)
67
70
  return arg.ty
68
71
 
@@ -76,6 +79,8 @@ class Instantiator(Transformer):
76
79
  # Instantiate if const value for the index is available
77
80
  if c.idx < len(self.inst):
78
81
  arg = self.inst[c.idx]
82
+ if arg is None and self.allow_partial:
83
+ return None
79
84
  assert isinstance(arg, ConstArg)
80
85
  return arg.const
81
86
 
@@ -87,26 +92,3 @@ class Instantiator(Transformer):
87
92
  if ty.parametrized:
88
93
  raise InternalGuppyError("Tried to instantiate under binder")
89
94
  return None
90
-
91
-
92
- class BoundVarFinder(Visitor):
93
- """Type visitor that looks for occurrences of bound variables."""
94
-
95
- bound_vars: set[BoundVar]
96
-
97
- def __init__(self) -> None:
98
- self.bound_vars = set()
99
-
100
- @functools.singledispatchmethod
101
- def visit(self, ty: Any) -> bool: # type: ignore[override]
102
- return False
103
-
104
- @visit.register
105
- def _transform_BoundTypeVar(self, ty: BoundTypeVar) -> bool:
106
- self.bound_vars.add(ty)
107
- return False
108
-
109
- @visit.register
110
- def _transform_BoundConstVar(self, c: BoundConstVar) -> bool:
111
- self.bound_vars.add(c)
112
- return False
@@ -57,14 +57,11 @@ class TypeBase(ToHugr[ht.Type], Transformable["Type"], ABC):
57
57
  return not self.copyable and self.droppable
58
58
 
59
59
  @cached_property
60
- @abstractmethod
61
60
  def hugr_bound(self) -> ht.TypeBound:
62
- """The Hugr bound of this type, i.e. `Any`, `Copyable`, or `Equatable`.
63
-
64
- This needs to be specified explicitly, since opaque nonlinear types in a Hugr
65
- extension could be either declared as copyable or equatable. If we don't get the
66
- bound exactly right during serialisation, the Hugr validator will complain.
67
- """
61
+ """The Hugr bound of this type, i.e. `Any` or `Copyable`."""
62
+ if self.linear or self.affine:
63
+ return ht.TypeBound.Linear
64
+ return ht.TypeBound.Copyable
68
65
 
69
66
  @abstractmethod
70
67
  def cast(self) -> "Type":
@@ -79,6 +76,11 @@ class TypeBase(ToHugr[ht.Type], Transformable["Type"], ABC):
79
76
  """The existential type variables contained in this type."""
80
77
  return set()
81
78
 
79
+ @cached_property
80
+ def bound_vars(self) -> set[BoundVar]:
81
+ """The bound type variables contained in this type."""
82
+ return set()
83
+
82
84
  def substitute(self, subst: "Subst") -> "Type":
83
85
  """Substitutes existential variables in this type."""
84
86
  from guppylang_internals.tys.subst import Substituter
@@ -158,13 +160,17 @@ class ParametrizedTypeBase(TypeBase, ABC):
158
160
  """The existential type variables contained in this type."""
159
161
  return set().union(*(arg.unsolved_vars for arg in self.args))
160
162
 
163
+ @cached_property
164
+ def bound_vars(self) -> set[BoundVar]:
165
+ """The bound type variables contained in this type."""
166
+ return set().union(*(arg.bound_vars for arg in self.args))
167
+
161
168
  @cached_property
162
169
  def hugr_bound(self) -> ht.TypeBound:
163
- """The Hugr bound of this type, i.e. `Any`, `Copyable`, or `Equatable`."""
164
- if self.linear:
165
- return ht.TypeBound.Linear
170
+ """The Hugr bound of this type, i.e. `Any` or `Copyable`."""
166
171
  return ht.TypeBound.join(
167
- *(arg.ty.hugr_bound for arg in self.args if isinstance(arg, TypeArg))
172
+ super().hugr_bound,
173
+ *(arg.ty.hugr_bound for arg in self.args if isinstance(arg, TypeArg)),
168
174
  )
169
175
 
170
176
  def visit(self, visitor: Visitor) -> None:
@@ -187,14 +193,10 @@ class BoundTypeVar(TypeBase, BoundVar):
187
193
  copyable: bool
188
194
  droppable: bool
189
195
 
190
- @cached_property
191
- def hugr_bound(self) -> ht.TypeBound:
192
- """The Hugr bound of this type, i.e. `Any`, `Copyable`, or `Equatable`."""
193
- if self.linear:
194
- return ht.TypeBound.Linear
195
- # We're conservative and don't require equatability for non-linear variables.
196
- # This is fine since Guppy doesn't use the equatable feature anyways.
197
- return ht.TypeBound.Copyable
196
+ @property
197
+ def bound_vars(self) -> set[BoundVar]:
198
+ """The bound type variables contained in this type."""
199
+ return {self}
198
200
 
199
201
  def cast(self) -> "Type":
200
202
  """Casts an implementor of `TypeBase` into a `Type`."""
@@ -426,6 +428,14 @@ class FunctionType(ParametrizedTypeBase):
426
428
  """Whether the function is parametrized."""
427
429
  return len(self.params) > 0
428
430
 
431
+ @cached_property
432
+ def bound_vars(self) -> set[BoundVar]:
433
+ """The bound type variables contained in this type."""
434
+ if self.parametrized:
435
+ # Ensures that we don't look inside quantifiers
436
+ return set()
437
+ return super().bound_vars
438
+
429
439
  def cast(self) -> "Type":
430
440
  """Casts an implementor of `TypeBase` into a `Type`."""
431
441
  return self
@@ -506,7 +516,7 @@ class FunctionType(ParametrizedTypeBase):
506
516
  # However, we have to down-shift the de Bruijn index.
507
517
  if arg is None:
508
518
  param = param.with_idx(len(remaining_params))
509
- remaining_params.append(param)
519
+ remaining_params.append(param.instantiate_bounds(full_inst))
510
520
  arg = param.to_bound()
511
521
 
512
522
  # Set the `preserve` flag for instantiated tuples and None
@@ -651,7 +661,7 @@ class OpaqueType(ParametrizedTypeBase):
651
661
 
652
662
  @property
653
663
  def hugr_bound(self) -> ht.TypeBound:
654
- """The Hugr bound of this type, i.e. `Any`, `Copyable`, or `Equatable`."""
664
+ """The Hugr bound of this type, i.e. `Any` or `Copyable`."""
655
665
  if self.defn.bound is not None:
656
666
  return self.defn.bound
657
667
  return super().hugr_bound
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: guppylang-internals
3
- Version: 0.23.0
3
+ Version: 0.25.0
4
4
  Summary: Compiler internals for `guppylang` package.
5
5
  Author-email: Mark Koch <mark.koch@quantinuum.com>, TKET development team <tket-support@quantinuum.com>
6
6
  Maintainer-email: Mark Koch <mark.koch@quantinuum.com>, TKET development team <tket-support@quantinuum.com>
@@ -219,8 +219,8 @@ Classifier: Programming Language :: Python :: 3.13
219
219
  Classifier: Programming Language :: Python :: 3.14
220
220
  Classifier: Topic :: Software Development :: Compilers
221
221
  Requires-Python: <4,>=3.10
222
- Requires-Dist: hugr~=0.13.1
223
- Requires-Dist: tket-exts~=0.10.0
222
+ Requires-Dist: hugr~=0.14.1
223
+ Requires-Dist: tket-exts~=0.12.0
224
224
  Requires-Dist: typing-extensions<5,>=4.9.0
225
225
  Provides-Extra: pytket
226
226
  Requires-Dist: pytket>=1.34; extra == 'pytket'
@@ -228,7 +228,7 @@ Description-Content-Type: text/markdown
228
228
 
229
229
  # guppylang-internals
230
230
 
231
- This packages contains the internals of the Guppy compiler.
231
+ This package contains the internals of the Guppy compiler.
232
232
 
233
233
  See `guppylang` for the package providing the user-facing language frontend.
234
234