guppylang-internals 0.25.0__py3-none-any.whl → 0.27.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49) hide show
  1. guppylang_internals/__init__.py +1 -1
  2. guppylang_internals/cfg/builder.py +20 -2
  3. guppylang_internals/cfg/cfg.py +3 -0
  4. guppylang_internals/checker/cfg_checker.py +6 -0
  5. guppylang_internals/checker/core.py +1 -2
  6. guppylang_internals/checker/errors/linearity.py +6 -2
  7. guppylang_internals/checker/errors/wasm.py +7 -4
  8. guppylang_internals/checker/expr_checker.py +39 -19
  9. guppylang_internals/checker/func_checker.py +17 -13
  10. guppylang_internals/checker/linearity_checker.py +2 -10
  11. guppylang_internals/checker/modifier_checker.py +6 -2
  12. guppylang_internals/checker/unitary_checker.py +132 -0
  13. guppylang_internals/compiler/cfg_compiler.py +7 -6
  14. guppylang_internals/compiler/core.py +5 -5
  15. guppylang_internals/compiler/expr_compiler.py +72 -81
  16. guppylang_internals/compiler/modifier_compiler.py +5 -0
  17. guppylang_internals/decorator.py +88 -7
  18. guppylang_internals/definition/custom.py +4 -0
  19. guppylang_internals/definition/declaration.py +6 -2
  20. guppylang_internals/definition/function.py +26 -3
  21. guppylang_internals/definition/metadata.py +87 -0
  22. guppylang_internals/definition/overloaded.py +11 -2
  23. guppylang_internals/definition/pytket_circuits.py +7 -2
  24. guppylang_internals/definition/struct.py +6 -3
  25. guppylang_internals/definition/wasm.py +42 -10
  26. guppylang_internals/diagnostic.py +72 -15
  27. guppylang_internals/engine.py +10 -13
  28. guppylang_internals/nodes.py +55 -24
  29. guppylang_internals/std/_internal/checker.py +13 -108
  30. guppylang_internals/std/_internal/compiler/array.py +37 -2
  31. guppylang_internals/std/_internal/compiler/either.py +14 -2
  32. guppylang_internals/std/_internal/compiler/list.py +1 -1
  33. guppylang_internals/std/_internal/compiler/platform.py +153 -0
  34. guppylang_internals/std/_internal/compiler/prelude.py +12 -4
  35. guppylang_internals/std/_internal/compiler/tket_bool.py +1 -6
  36. guppylang_internals/std/_internal/compiler/tket_exts.py +4 -5
  37. guppylang_internals/std/_internal/debug.py +18 -9
  38. guppylang_internals/std/_internal/util.py +1 -1
  39. guppylang_internals/tracing/object.py +14 -0
  40. guppylang_internals/tys/errors.py +23 -1
  41. guppylang_internals/tys/parsing.py +3 -3
  42. guppylang_internals/tys/printing.py +2 -8
  43. guppylang_internals/tys/qubit.py +37 -2
  44. guppylang_internals/tys/ty.py +60 -64
  45. guppylang_internals/wasm_util.py +129 -0
  46. {guppylang_internals-0.25.0.dist-info → guppylang_internals-0.27.0.dist-info}/METADATA +5 -4
  47. {guppylang_internals-0.25.0.dist-info → guppylang_internals-0.27.0.dist-info}/RECORD +49 -45
  48. {guppylang_internals-0.25.0.dist-info → guppylang_internals-0.27.0.dist-info}/WHEEL +1 -1
  49. {guppylang_internals-0.25.0.dist-info → guppylang_internals-0.27.0.dist-info}/licenses/LICENCE +0 -0
@@ -4,9 +4,9 @@ from typing import ClassVar
4
4
 
5
5
  from typing_extensions import assert_never
6
6
 
7
- from guppylang_internals.ast_util import get_type, with_loc
8
- from guppylang_internals.checker.core import ComptimeVariable, Context
9
- from guppylang_internals.checker.errors.generic import ExpectedError, UnsupportedError
7
+ from guppylang_internals.ast_util import get_type, with_loc, with_type
8
+ from guppylang_internals.checker.core import Context
9
+ from guppylang_internals.checker.errors.generic import UnsupportedError
10
10
  from guppylang_internals.checker.errors.type_errors import (
11
11
  ArrayComprUnknownSizeError,
12
12
  TypeMismatchError,
@@ -33,8 +33,6 @@ from guppylang_internals.nodes import (
33
33
  GlobalCall,
34
34
  MakeIter,
35
35
  PanicExpr,
36
- PlaceNode,
37
- ResultExpr,
38
36
  )
39
37
  from guppylang_internals.tys.arg import ConstArg, TypeArg
40
38
  from guppylang_internals.tys.builtin import (
@@ -45,7 +43,6 @@ from guppylang_internals.tys.builtin import (
45
43
  get_iter_size,
46
44
  int_type,
47
45
  is_array_type,
48
- is_bool_type,
49
46
  is_sized_iter_type,
50
47
  nat_type,
51
48
  sized_iter_type,
@@ -58,7 +55,6 @@ from guppylang_internals.tys.ty import (
58
55
  FunctionType,
59
56
  InputFlags,
60
57
  NoneType,
61
- NumericType,
62
58
  Type,
63
59
  unify,
64
60
  )
@@ -302,80 +298,6 @@ class NewArrayChecker(CustomCallChecker):
302
298
  return with_loc(compr, array_compr), array_type(elt_ty, size)
303
299
 
304
300
 
305
- #: Maximum length of a tag in the `result` function.
306
- TAG_MAX_LEN = 200
307
-
308
-
309
- @dataclass(frozen=True)
310
- class TooLongError(Error):
311
- title: ClassVar[str] = "Tag too long"
312
- span_label: ClassVar[str] = "Result tag is too long"
313
-
314
- @dataclass(frozen=True)
315
- class Hint(Note):
316
- message: ClassVar[str] = f"Result tags are limited to {TAG_MAX_LEN} bytes"
317
-
318
-
319
- class ResultChecker(CustomCallChecker):
320
- """Call checker for the `result` function."""
321
-
322
- @dataclass(frozen=True)
323
- class InvalidError(Error):
324
- title: ClassVar[str] = "Invalid Result"
325
- span_label: ClassVar[str] = "Expression of type `{ty}` is not a valid result."
326
- ty: Type
327
-
328
- @dataclass(frozen=True)
329
- class Explanation(Note):
330
- message: ClassVar[str] = (
331
- "Only numeric values or arrays thereof are allowed as results"
332
- )
333
-
334
- def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, Type]:
335
- check_num_args(2, len(args), self.node)
336
- [tag, value] = args
337
- tag, _ = ExprChecker(self.ctx).check(tag, string_type())
338
- match tag:
339
- case ast.Constant(value=str(v)):
340
- tag_value = v
341
- case PlaceNode(place=ComptimeVariable(static_value=str(v))):
342
- tag_value = v
343
- case _:
344
- raise GuppyTypeError(ExpectedError(tag, "a string literal"))
345
- if len(tag_value.encode("utf-8")) > TAG_MAX_LEN:
346
- err: Error = TooLongError(tag)
347
- err.add_sub_diagnostic(TooLongError.Hint(None))
348
- raise GuppyTypeError(err)
349
- value, ty = ExprSynthesizer(self.ctx).synthesize(value)
350
- # We only allow numeric values or vectors of numeric values
351
- err = ResultChecker.InvalidError(value, ty)
352
- err.add_sub_diagnostic(ResultChecker.InvalidError.Explanation(None))
353
- if self._is_numeric_or_bool_type(ty):
354
- base_ty = ty
355
- array_len: Const | None = None
356
- elif is_array_type(ty):
357
- [ty_arg, len_arg] = ty.args
358
- assert isinstance(ty_arg, TypeArg)
359
- assert isinstance(len_arg, ConstArg)
360
- if not self._is_numeric_or_bool_type(ty_arg.ty):
361
- raise GuppyError(err)
362
- base_ty = ty_arg.ty
363
- array_len = len_arg.const
364
- else:
365
- raise GuppyError(err)
366
- node = ResultExpr(value, base_ty, array_len, tag_value)
367
- return with_loc(self.node, node), NoneType()
368
-
369
- def check(self, args: list[ast.expr], ty: Type) -> tuple[ast.expr, Subst]:
370
- expr, res_ty = self.synthesize(args)
371
- expr, subst, _ = check_type_against(res_ty, ty, expr, self.ctx)
372
- return expr, subst
373
-
374
- @staticmethod
375
- def _is_numeric_or_bool_type(ty: Type) -> bool:
376
- return isinstance(ty, NumericType) or is_bool_type(ty)
377
-
378
-
379
301
  class PanicChecker(CustomCallChecker):
380
302
  """Call checker for the `panic` function."""
381
303
 
@@ -395,18 +317,16 @@ class PanicChecker(CustomCallChecker):
395
317
  err.add_sub_diagnostic(PanicChecker.NoMessageError.Suggestion(None))
396
318
  raise GuppyTypeError(err)
397
319
  case [msg, *rest]:
320
+ # Check type of message and synthesize types for additional values.
398
321
  msg, _ = ExprChecker(self.ctx).check(msg, string_type())
399
- match msg:
400
- case ast.Constant(value=str(v)):
401
- msg_value = v
402
- case PlaceNode(place=ComptimeVariable(static_value=str(v))):
403
- msg_value = v
404
- case _:
405
- raise GuppyTypeError(ExpectedError(msg, "a string literal"))
406
322
  vals = [ExprSynthesizer(self.ctx).synthesize(val)[0] for val in rest]
407
323
  # TODO variable signals once default arguments are available
324
+ # TODO this will also allow us to remove this manual AST node hack
325
+ signal_expr = with_type(
326
+ int_type(), with_loc(self.node, ast.Constant(value=1))
327
+ )
408
328
  node = PanicExpr(
409
- kind=ExitKind.Panic, msg=msg_value, values=vals, signal=1
329
+ kind=ExitKind.Panic, msg=msg, values=vals, signal=signal_expr
410
330
  )
411
331
  return with_loc(self.node, node), NoneType()
412
332
  case args:
@@ -454,31 +374,16 @@ class ExitChecker(CustomCallChecker):
454
374
  )
455
375
  raise GuppyTypeError(signal_err)
456
376
  case [msg, signal, *rest]:
377
+ # Check types for message and signal and synthesize types for additional
378
+ # values.
457
379
  msg, _ = ExprChecker(self.ctx).check(msg, string_type())
458
- match msg:
459
- case ast.Constant(value=str(v)):
460
- msg_value = v
461
- case PlaceNode(place=ComptimeVariable(static_value=str(v))):
462
- msg_value = v
463
- case _:
464
- raise GuppyTypeError(ExpectedError(msg, "a string literal"))
465
- # TODO allow variable signals after https://github.com/CQCL/hugr/issues/1863
466
380
  signal, _ = ExprChecker(self.ctx).check(signal, int_type())
467
- match signal:
468
- case ast.Constant(value=int(s)):
469
- signal_value = s
470
- case PlaceNode(place=ComptimeVariable(static_value=int(s))):
471
- signal_value = s
472
- case _:
473
- raise GuppyTypeError(
474
- ExpectedError(signal, "an integer literal")
475
- )
476
381
  vals = [ExprSynthesizer(self.ctx).synthesize(val)[0] for val in rest]
477
382
  node = PanicExpr(
478
383
  kind=ExitKind.ExitShot,
479
- msg=msg_value,
384
+ msg=msg,
480
385
  values=vals,
481
- signal=signal_value,
386
+ signal=signal,
482
387
  )
483
388
  return with_loc(self.node, node), NoneType()
484
389
  case args:
@@ -16,6 +16,7 @@ from guppylang_internals.std._internal.compiler.arithmetic import convert_itousi
16
16
  from guppylang_internals.std._internal.compiler.prelude import (
17
17
  build_unwrap_right,
18
18
  )
19
+ from guppylang_internals.std._internal.compiler.tket_bool import make_opaque
19
20
  from guppylang_internals.tys.arg import ConstArg, TypeArg
20
21
 
21
22
  if TYPE_CHECKING:
@@ -206,6 +207,14 @@ def barray_new_all_borrowed(elem_ty: ht.Type, length: ht.TypeArg) -> ops.ExtOp:
206
207
  return _instantiate_array_op("new_all_borrowed", elem_ty, length, [], [arr_ty])
207
208
 
208
209
 
210
+ def barray_is_borrowed(elem_ty: ht.Type, length: ht.TypeArg) -> ops.ExtOp:
211
+ """Returns an array `is_borrowed` operation."""
212
+ arr_ty = array_type(elem_ty, length)
213
+ return _instantiate_array_op(
214
+ "is_borrowed", elem_ty, length, [arr_ty, ht.USize()], [arr_ty, ht.Bool]
215
+ )
216
+
217
+
209
218
  def array_clone(elem_ty: ht.Type, length: ht.TypeArg) -> ops.ExtOp:
210
219
  """Returns an array `clone` operation for arrays none of whose elements are
211
220
  borrowed."""
@@ -261,7 +270,7 @@ class NewArrayCompiler(ArrayCompiler):
261
270
 
262
271
  def build_classical_array(self, elems: list[Wire]) -> Wire:
263
272
  """Lowers a call to `array.__new__` for classical arrays."""
264
- # See https://github.com/CQCL/guppylang/issues/629
273
+ # See https://github.com/quantinuum/guppylang/issues/629
265
274
  return self.build_linear_array(elems)
266
275
 
267
276
  def build_linear_array(self, elems: list[Wire]) -> Wire:
@@ -320,7 +329,15 @@ class ArrayGetitemCompiler(ArrayCompiler):
320
329
 
321
330
 
322
331
  class ArraySetitemCompiler(ArrayCompiler):
323
- """Compiler for the `array.__setitem__` function."""
332
+ """Compiler for the `array.__setitem__` function.
333
+
334
+ Arguments:
335
+ elem_first: If `True`, then compiler will assume that the element wire comes
336
+ before the index wire. Defaults to `False`.
337
+ """
338
+
339
+ def __init__(self, elem_first: bool = False):
340
+ self.elem_first = elem_first
324
341
 
325
342
  def _build_classical_setitem(
326
343
  self, array: Wire, idx: Wire, elem: Wire
@@ -359,6 +376,8 @@ class ArraySetitemCompiler(ArrayCompiler):
359
376
 
360
377
  def compile_with_inouts(self, args: list[Wire]) -> CallReturnWires:
361
378
  [array, idx, elem] = args
379
+ if self.elem_first:
380
+ elem, idx = idx, elem
362
381
  if self.elem_ty.type_bound() == ht.TypeBound.Linear:
363
382
  return self._build_linear_setitem(array, idx, elem)
364
383
  else:
@@ -379,3 +398,19 @@ class ArrayDiscardAllUsedCompiler(ArrayCompiler):
379
398
  arr,
380
399
  )
381
400
  return []
401
+
402
+
403
+ class ArrayIsBorrowedCompiler(ArrayCompiler):
404
+ """Compiler for the `array.is_borrowed` method."""
405
+
406
+ def compile_with_inouts(self, args: list[Wire]) -> CallReturnWires:
407
+ [array, idx] = args
408
+ idx = self.builder.add_op(convert_itousize(), idx)
409
+ array, b = self.builder.add_op(
410
+ barray_is_borrowed(self.elem_ty, self.length), array, idx
411
+ )
412
+ b = self.builder.add_op(make_opaque(), b)
413
+ return CallReturnWires(regular_returns=[b], inout_returns=[array])
414
+
415
+ def compile(self, args: list[Wire]) -> list[Wire]:
416
+ raise InternalGuppyError("Call compile_with_inouts instead")
@@ -4,6 +4,8 @@ from collections.abc import Sequence
4
4
  from hugr import Wire, ops
5
5
  from hugr import tys as ht
6
6
 
7
+ from guppylang_internals.ast_util import get_type
8
+ from guppylang_internals.compiler.expr_compiler import pack_returns, unpack_wire
7
9
  from guppylang_internals.definition.custom import (
8
10
  CustomCallCompiler,
9
11
  CustomInoutCallCompiler,
@@ -69,7 +71,14 @@ class EitherConstructor(EitherCompiler, CustomCallCompiler):
69
71
  # In the `right` case, the type args are swapped around since `R` occurs
70
72
  # first in the signature :(
71
73
  ty.variant_rows = [ty.variant_rows[1], ty.variant_rows[0]]
72
- return [self.builder.add_op(ops.Tag(self.tag, ty), *args)]
74
+ # For the same reason, the type of the input corresponds to the first type
75
+ # variable
76
+ inp_arg = self.type_args[0]
77
+ assert isinstance(inp_arg, TypeArg)
78
+ [inp] = args
79
+ # Unpack the single input into a row
80
+ inp_row = unpack_wire(inp, inp_arg.ty, self.builder, self.ctx)
81
+ return [self.builder.add_op(ops.Tag(self.tag, ty), *inp_row)]
73
82
 
74
83
 
75
84
  class EitherTestCompiler(EitherCompiler):
@@ -128,4 +137,7 @@ class EitherUnwrapCompiler(EitherCompiler, CustomCallCompiler):
128
137
  out = build_unwrap_right(
129
138
  self.builder, either, "Either.unwrap_right: value is `left`"
130
139
  )
131
- return list(out)
140
+ # Pack outputs into a single wire. We're not allowed to return a row since the
141
+ # signature has a generic return type (also see `TupleType.preserve`)
142
+ return_ty = get_type(self.node)
143
+ return [pack_returns(list(out), return_ty, self.builder, self.ctx)]
@@ -328,7 +328,7 @@ def _list_new_classical(
328
328
  builder: DfBase[ops.DfParentOp], elem_type: ht.Type, args: list[Wire]
329
329
  ) -> Wire:
330
330
  # This may be simplified in the future with a `new` or `with_capacity` list op
331
- # See https://github.com/CQCL/hugr/issues/1508
331
+ # See https://github.com/quantinuum/hugr/issues/1508
332
332
  lst = builder.load(ListVal([], elem_ty=elem_type))
333
333
  push_op = list_push(elem_type)
334
334
  for elem in args:
@@ -0,0 +1,153 @@
1
+ from dataclasses import dataclass
2
+ from typing import ClassVar
3
+
4
+ import hugr
5
+ from hugr import Wire, ops, tys
6
+
7
+ from guppylang_internals.ast_util import AstNode
8
+ from guppylang_internals.compiler.core import CompilerContext
9
+ from guppylang_internals.compiler.expr_compiler import array_read_bool
10
+ from guppylang_internals.definition.custom import (
11
+ CustomCallCompiler,
12
+ CustomInoutCallCompiler,
13
+ )
14
+ from guppylang_internals.definition.value import CallReturnWires
15
+ from guppylang_internals.diagnostic import Error, Note
16
+ from guppylang_internals.error import GuppyError, InternalGuppyError
17
+ from guppylang_internals.std._internal.compiler.array import (
18
+ array_clone,
19
+ array_map,
20
+ array_to_std_array,
21
+ )
22
+ from guppylang_internals.std._internal.compiler.tket_bool import OpaqueBool, read_bool
23
+ from guppylang_internals.std._internal.compiler.tket_exts import RESULT_EXTENSION
24
+ from guppylang_internals.tys.arg import Argument, ConstArg
25
+ from guppylang_internals.tys.builtin import get_element_type, is_bool_type
26
+ from guppylang_internals.tys.const import BoundConstVar, ConstValue
27
+ from guppylang_internals.tys.ty import NumericType
28
+
29
+ #: Maximum length of a tag in the `result` function.
30
+ TAG_MAX_LEN = 200
31
+
32
+
33
+ @dataclass(frozen=True)
34
+ class TooLongError(Error):
35
+ title: ClassVar[str] = "Tag too long"
36
+ span_label: ClassVar[str] = "Result tag is too long"
37
+
38
+ @dataclass(frozen=True)
39
+ class Hint(Note):
40
+ message: ClassVar[str] = f"Result tags are limited to {TAG_MAX_LEN} bytes"
41
+
42
+ @dataclass(frozen=True)
43
+ class GenericHint(Note):
44
+ message: ClassVar[str] = "Parameter `{param}` was instantiated to `{value}`"
45
+ param: str
46
+ value: str
47
+
48
+
49
+ class ResultCompiler(CustomCallCompiler):
50
+ """Custom compiler for overloads of the `result` function.
51
+
52
+ See `ArrayResultCompiler` for the compiler that handles results involving arrays.
53
+ """
54
+
55
+ def __init__(self, op_name: str, with_int_width: bool = False):
56
+ self.op_name = op_name
57
+ self.with_int_width = with_int_width
58
+
59
+ def compile(self, args: list[Wire]) -> list[Wire]:
60
+ assert self.func is not None
61
+ [value] = args
62
+ ty = self.func.ty.inputs[1].ty
63
+ hugr_ty = ty.to_hugr(self.ctx)
64
+ args = [tag_to_hugr(self.type_args[0], self.ctx, self.node)]
65
+ if self.with_int_width:
66
+ args.append(tys.BoundedNatArg(NumericType.INT_WIDTH))
67
+ # Bool results need an extra conversion into regular hugr bools
68
+ if is_bool_type(ty):
69
+ value = self.builder.add_op(read_bool(), value)
70
+ hugr_ty = tys.Bool
71
+ op = RESULT_EXTENSION.get_op(self.op_name)
72
+ sig = tys.FunctionType(input=[hugr_ty], output=[])
73
+ self.builder.add_op(op.instantiate(args, sig), value)
74
+ return []
75
+
76
+
77
+ class ArrayResultCompiler(CustomInoutCallCompiler):
78
+ """Custom compiler for overloads of the `result` function accepting arrays.
79
+
80
+ See `ResultCompiler` for the compiler that handles basic results.
81
+ """
82
+
83
+ def __init__(self, op_name: str, with_int_width: bool = False):
84
+ self.op_name = op_name
85
+ self.with_int_width = with_int_width
86
+
87
+ def compile_with_inouts(self, args: list[Wire]) -> CallReturnWires:
88
+ assert self.func is not None
89
+ array_ty = self.func.ty.inputs[1].ty
90
+ elem_ty = get_element_type(array_ty)
91
+ [tag_arg, size_arg] = self.type_args
92
+ [arr] = args
93
+
94
+ # As `borrow_array`s used by Guppy are linear, we need to clone it (knowing
95
+ # that all elements in it are copyable) to avoid linearity violations when
96
+ # both passing it to the result operation and returning it (as an inout
97
+ # argument).
98
+ hugr_elem_ty = elem_ty.to_hugr(self.ctx)
99
+ hugr_size = size_arg.to_hugr(self.ctx)
100
+ arr, out_arr = self.builder.add_op(array_clone(hugr_elem_ty, hugr_size), arr)
101
+ # For bool arrays, we furthermore need to coerce a read on all the array
102
+ # elements
103
+ if is_bool_type(elem_ty):
104
+ array_read = array_read_bool(self.ctx)
105
+ array_read = self.builder.load_function(array_read)
106
+ map_op = array_map(OpaqueBool, hugr_size, tys.Bool)
107
+ arr = self.builder.add_op(map_op, arr, array_read).out(0)
108
+ hugr_elem_ty = tys.Bool
109
+ # Turn `borrow_array` into regular `array`
110
+ arr = self.builder.add_op(array_to_std_array(hugr_elem_ty, hugr_size), arr).out(
111
+ 0
112
+ )
113
+
114
+ hugr_ty = hugr.std.collections.array.Array(hugr_elem_ty, hugr_size)
115
+ sig = tys.FunctionType(input=[hugr_ty], output=[])
116
+ args = [tag_to_hugr(tag_arg, self.ctx, self.node), hugr_size]
117
+ if self.with_int_width:
118
+ args.append(tys.BoundedNatArg(NumericType.INT_WIDTH))
119
+ op = ops.ExtOp(RESULT_EXTENSION.get_op(self.op_name), signature=sig, args=args)
120
+ self.builder.add_op(op, arr)
121
+ return CallReturnWires([], [out_arr])
122
+
123
+
124
+ def tag_to_hugr(tag_arg: Argument, ctx: CompilerContext, loc: AstNode) -> tys.TypeArg:
125
+ """Helper function to convert the Guppy tag comptime argument into a Hugr type arg.
126
+
127
+ Takes care of reading the tag value from the current monomorphization and checks
128
+ that the tag fits into `TAG_MAX_LEN`.
129
+ """
130
+ is_generic: BoundConstVar | None = None
131
+ match tag_arg:
132
+ case ConstArg(const=ConstValue(value=str(value))):
133
+ tag = value
134
+ case ConstArg(const=BoundConstVar(idx=idx) as var):
135
+ is_generic = var
136
+ assert ctx.current_mono_args is not None
137
+ match ctx.current_mono_args[idx]:
138
+ case ConstArg(const=ConstValue(value=str(value))):
139
+ tag = value
140
+ case _:
141
+ raise InternalGuppyError("Invalid tag monomorphization")
142
+ case _:
143
+ raise InternalGuppyError("Invalid tag argument")
144
+
145
+ if len(tag.encode("utf-8")) > TAG_MAX_LEN:
146
+ err = TooLongError(loc)
147
+ err.add_sub_diagnostic(TooLongError.Hint(None))
148
+ if is_generic:
149
+ err.add_sub_diagnostic(
150
+ TooLongError.GenericHint(None, is_generic.display_name, tag)
151
+ )
152
+ raise GuppyError(err)
153
+ return tys.StringArg(tag)
@@ -73,6 +73,14 @@ def panic(
73
73
  return ops.ExtOp(op_def, sig, args)
74
74
 
75
75
 
76
+ def make_error() -> ops.ExtOp:
77
+ """Returns an operation that makes an error."""
78
+ op_def = hugr.std.PRELUDE.get_op("MakeError")
79
+ args: list[ht.TypeArg] = []
80
+ sig = ht.FunctionType([ht.USize(), hugr.std.prelude.STRING_T], [error_type()])
81
+ return ops.ExtOp(op_def, sig, args)
82
+
83
+
76
84
  # ------------------------------------------------------
77
85
  # --------- Custom compilers for non-native ops --------
78
86
  # ------------------------------------------------------
@@ -90,14 +98,14 @@ def build_panic(
90
98
  return builder.add_op(op, err, *args)
91
99
 
92
100
 
93
- def build_error(builder: DfBase[P], signal: int, msg: str) -> Wire:
101
+ def build_static_error(builder: DfBase[P], signal: int, msg: str) -> Wire:
94
102
  """Constructs and loads a static error value."""
95
103
  val = ErrorVal(signal, msg)
96
104
  return builder.load(builder.add_const(val))
97
105
 
98
106
 
99
107
  # TODO: Common up build_unwrap_right and build_unwrap_left below once
100
- # https://github.com/CQCL/hugr/issues/1596 is fixed
108
+ # https://github.com/quantinuum/hugr/issues/1596 is fixed
101
109
 
102
110
 
103
111
  def build_unwrap_right(
@@ -111,7 +119,7 @@ def build_unwrap_right(
111
119
  assert isinstance(result_ty, ht.Sum)
112
120
  [left_tys, right_tys] = result_ty.variant_rows
113
121
  with conditional.add_case(0) as case:
114
- error = build_error(case, error_signal, error_msg)
122
+ error = build_static_error(case, error_signal, error_msg)
115
123
  case.set_outputs(*build_panic(case, left_tys, right_tys, error, *case.inputs()))
116
124
  with conditional.add_case(1) as case:
117
125
  case.set_outputs(*case.inputs())
@@ -134,7 +142,7 @@ def build_unwrap_left(
134
142
  with conditional.add_case(0) as case:
135
143
  case.set_outputs(*case.inputs())
136
144
  with conditional.add_case(1) as case:
137
- error = build_error(case, error_signal, error_msg)
145
+ error = build_static_error(case, error_signal, error_msg)
138
146
  case.set_outputs(*build_panic(case, right_tys, left_tys, error, *case.inputs()))
139
147
  return conditional.to_node()
140
148
 
@@ -40,12 +40,7 @@ class OpaqueBoolVal(hv.ExtensionValue):
40
40
  def to_value(self) -> hv.Extension:
41
41
  name = "ConstBool"
42
42
  payload = self.v
43
- return hv.Extension(
44
- name,
45
- typ=OpaqueBool,
46
- val=payload,
47
- extensions=[BOOL_EXTENSION.name],
48
- )
43
+ return hv.Extension(name, typ=OpaqueBool, val=payload)
49
44
 
50
45
  def __str__(self) -> str:
51
46
  return f"{self.v}"
@@ -20,6 +20,7 @@ from tket_exts import (
20
20
  BOOL_EXTENSION = tket_exts.bool()
21
21
  DEBUG_EXTENSION = debug()
22
22
  FUTURES_EXTENSION = futures()
23
+ GLOBAL_PHASE_EXTENSION = global_phase()
23
24
  GUPPY_EXTENSION = guppy()
24
25
  MODIFIER_EXTENSION = modifier()
25
26
  QSYSTEM_EXTENSION = qsystem()
@@ -29,14 +30,14 @@ QUANTUM_EXTENSION = quantum()
29
30
  RESULT_EXTENSION = result()
30
31
  ROTATION_EXTENSION = rotation()
31
32
  WASM_EXTENSION = wasm()
32
- MODIFIER_EXTENSION = modifier()
33
- GLOBAL_PHASE_EXTENSION = global_phase()
34
33
 
35
34
  TKET_EXTENSIONS = [
36
35
  BOOL_EXTENSION,
37
36
  DEBUG_EXTENSION,
38
37
  FUTURES_EXTENSION,
38
+ GLOBAL_PHASE_EXTENSION,
39
39
  GUPPY_EXTENSION,
40
+ MODIFIER_EXTENSION,
40
41
  QSYSTEM_EXTENSION,
41
42
  QSYSTEM_RANDOM_EXTENSION,
42
43
  QSYSTEM_UTILS_EXTENSION,
@@ -44,8 +45,6 @@ TKET_EXTENSIONS = [
44
45
  RESULT_EXTENSION,
45
46
  ROTATION_EXTENSION,
46
47
  WASM_EXTENSION,
47
- MODIFIER_EXTENSION,
48
- GLOBAL_PHASE_EXTENSION,
49
48
  ]
50
49
 
51
50
 
@@ -60,7 +59,7 @@ class ConstWasmModule(val.ExtensionValue):
60
59
 
61
60
  name = "ConstWasmModule"
62
61
  payload = {"module_filename": self.wasm_file}
63
- return val.Extension(name, typ=ty, val=payload, extensions=["tket.wasm"])
62
+ return val.Extension(name, typ=ty, val=payload)
64
63
 
65
64
  def __str__(self) -> str:
66
65
  return f"tket.wasm.module(module_filename={self.wasm_file})"
@@ -3,6 +3,7 @@ from dataclasses import dataclass
3
3
  from typing import ClassVar, cast
4
4
 
5
5
  from guppylang_internals.ast_util import with_loc
6
+ from guppylang_internals.checker.core import ComptimeVariable
6
7
  from guppylang_internals.checker.errors.generic import ExpectedError
7
8
  from guppylang_internals.checker.errors.type_errors import WrongNumberOfArgsError
8
9
  from guppylang_internals.checker.expr_checker import (
@@ -14,14 +15,14 @@ from guppylang_internals.definition.custom import CustomCallChecker
14
15
  from guppylang_internals.definition.ty import TypeDef
15
16
  from guppylang_internals.diagnostic import Error
16
17
  from guppylang_internals.error import GuppyTypeError
17
- from guppylang_internals.nodes import StateResultExpr
18
- from guppylang_internals.std._internal.checker import TAG_MAX_LEN, TooLongError
18
+ from guppylang_internals.nodes import GenericParamValue, PlaceNode, StateResultExpr
19
19
  from guppylang_internals.tys.builtin import (
20
20
  get_array_length,
21
21
  get_element_type,
22
22
  is_array_type,
23
23
  string_type,
24
24
  )
25
+ from guppylang_internals.tys.const import Const, ConstValue
25
26
  from guppylang_internals.tys.ty import (
26
27
  FuncInput,
27
28
  FunctionType,
@@ -43,12 +44,16 @@ class StateResultChecker(CustomCallChecker):
43
44
 
44
45
  def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, Type]:
45
46
  tag, _ = ExprChecker(self.ctx).check(args[0], string_type())
46
- if not isinstance(tag, ast.Constant) or not isinstance(tag.value, str):
47
- raise GuppyTypeError(ExpectedError(tag, "a string literal"))
48
- if len(tag.value.encode("utf-8")) > TAG_MAX_LEN:
49
- err: Error = TooLongError(tag)
50
- err.add_sub_diagnostic(TooLongError.Hint(None))
51
- raise GuppyTypeError(err)
47
+ tag_value: Const
48
+ match tag:
49
+ case ast.Constant(value=str(v)):
50
+ tag_value = ConstValue(string_type(), v)
51
+ case PlaceNode(place=ComptimeVariable(static_value=str(v))):
52
+ tag_value = ConstValue(string_type(), v)
53
+ case GenericParamValue() as param_value:
54
+ tag_value = param_value.param.to_bound().const
55
+ case _:
56
+ raise GuppyTypeError(ExpectedError(tag, "a string literal"))
52
57
  syn_args: list[ast.expr] = [tag]
53
58
 
54
59
  if len(args) < 2:
@@ -90,6 +95,10 @@ class StateResultChecker(CustomCallChecker):
90
95
  args, ret_ty, inst = synthesize_call(func_ty, syn_args, self.node, self.ctx)
91
96
  assert len(inst) == 0, "func_ty is not generic"
92
97
  node = StateResultExpr(
93
- tag=tag.value, args=args, func_ty=func_ty, array_len=array_len
98
+ tag_value=tag_value,
99
+ tag_expr=tag,
100
+ args=args,
101
+ func_ty=func_ty,
102
+ array_len=array_len,
94
103
  )
95
104
  return with_loc(self.node, node), ret_ty
@@ -129,7 +129,7 @@ def int_op(
129
129
  # Ideally we'd be able to derive the arguments from the input/output types,
130
130
  # but the amount of variables does not correlate with the signature for the
131
131
  # integer ops in hugr :/
132
- # https://github.com/CQCL/hugr/blob/bfa13e59468feb0fc746677ea3b3a4341b2ed42e/hugr-core/src/std_extensions/arithmetic/int_ops.rs#L116
132
+ # https://github.com/quantinuum/hugr/blob/bfa13e59468feb0fc746677ea3b3a4341b2ed42e/hugr-core/src/std_extensions/arithmetic/int_ops.rs#L116
133
133
  #
134
134
  # For now, we just instantiate every type argument to a 64-bit integer.
135
135
  args: list[ht.TypeArg] = [int_arg() for _ in range(n_vars)]
@@ -342,6 +342,10 @@ class GuppyObject(DunderMixin):
342
342
  if not ty.droppable and not self._used:
343
343
  state.unused_undroppable_objs[self._id] = self
344
344
 
345
+ def __deepcopy__(self, memo: dict[int, Any]) -> "GuppyObject":
346
+ # Dummy deepcopy implementation, we do not want to actually deepcopy
347
+ return self
348
+
345
349
  @hide_trace
346
350
  def __getattr__(self, key: str) -> Any: # type: ignore[misc]
347
351
  # Guppy objects don't have fields (structs are treated separately below), so the
@@ -539,6 +543,16 @@ class TracingDefMixin(DunderMixin):
539
543
 
540
544
  def to_guppy_object(self) -> GuppyObject:
541
545
  state = get_tracing_state()
546
+ defn = ENGINE.get_checked(self.id)
547
+ # TODO: For generic functions, we need to know an instantiation for their type
548
+ # parameters. Maybe we should pass them to `to_guppy_object`? Either way, this
549
+ # will require some more plumbing of type inference information through the
550
+ # comptime logic. For now, let's just bail on generic functions.
551
+ # See https://github.com/quantinuum/guppylang/issues/1336
552
+ if isinstance(defn, CallableDef) and defn.ty.parametrized:
553
+ raise GuppyComptimeError(
554
+ f"Cannot infer type parameters of generic function `{defn.name}`"
555
+ )
542
556
  defn, [] = state.ctx.build_compiled_def(self.id, type_args=[])
543
557
  if isinstance(defn, CompiledValueDef):
544
558
  wire = defn.load(state.dfg, state.ctx, state.node)