guppylang-internals 0.25.0__py3-none-any.whl → 0.26.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43) hide show
  1. guppylang_internals/__init__.py +1 -1
  2. guppylang_internals/cfg/builder.py +17 -2
  3. guppylang_internals/cfg/cfg.py +3 -0
  4. guppylang_internals/checker/cfg_checker.py +6 -0
  5. guppylang_internals/checker/core.py +1 -2
  6. guppylang_internals/checker/errors/wasm.py +7 -4
  7. guppylang_internals/checker/expr_checker.py +13 -8
  8. guppylang_internals/checker/func_checker.py +17 -13
  9. guppylang_internals/checker/linearity_checker.py +2 -10
  10. guppylang_internals/checker/modifier_checker.py +6 -2
  11. guppylang_internals/checker/unitary_checker.py +132 -0
  12. guppylang_internals/compiler/cfg_compiler.py +7 -6
  13. guppylang_internals/compiler/core.py +5 -5
  14. guppylang_internals/compiler/expr_compiler.py +42 -73
  15. guppylang_internals/compiler/modifier_compiler.py +2 -0
  16. guppylang_internals/decorator.py +86 -7
  17. guppylang_internals/definition/custom.py +4 -0
  18. guppylang_internals/definition/declaration.py +6 -2
  19. guppylang_internals/definition/function.py +12 -2
  20. guppylang_internals/definition/pytket_circuits.py +1 -0
  21. guppylang_internals/definition/struct.py +6 -3
  22. guppylang_internals/definition/wasm.py +42 -10
  23. guppylang_internals/engine.py +9 -3
  24. guppylang_internals/nodes.py +23 -24
  25. guppylang_internals/std/_internal/checker.py +13 -108
  26. guppylang_internals/std/_internal/compiler/array.py +1 -1
  27. guppylang_internals/std/_internal/compiler/list.py +1 -1
  28. guppylang_internals/std/_internal/compiler/platform.py +153 -0
  29. guppylang_internals/std/_internal/compiler/prelude.py +12 -4
  30. guppylang_internals/std/_internal/compiler/tket_exts.py +3 -4
  31. guppylang_internals/std/_internal/debug.py +18 -9
  32. guppylang_internals/std/_internal/util.py +1 -1
  33. guppylang_internals/tracing/object.py +10 -0
  34. guppylang_internals/tys/errors.py +23 -1
  35. guppylang_internals/tys/parsing.py +3 -3
  36. guppylang_internals/tys/printing.py +2 -8
  37. guppylang_internals/tys/qubit.py +37 -2
  38. guppylang_internals/tys/ty.py +60 -64
  39. guppylang_internals/wasm_util.py +129 -0
  40. {guppylang_internals-0.25.0.dist-info → guppylang_internals-0.26.0.dist-info}/METADATA +4 -3
  41. {guppylang_internals-0.25.0.dist-info → guppylang_internals-0.26.0.dist-info}/RECORD +43 -40
  42. {guppylang_internals-0.25.0.dist-info → guppylang_internals-0.26.0.dist-info}/WHEEL +1 -1
  43. {guppylang_internals-0.25.0.dist-info → guppylang_internals-0.26.0.dist-info}/licenses/LICENCE +0 -0
@@ -9,7 +9,13 @@ from guppylang_internals.ast_util import AstNode
9
9
  from guppylang_internals.span import Span, to_span
10
10
  from guppylang_internals.tys.const import Const
11
11
  from guppylang_internals.tys.subst import Inst
12
- from guppylang_internals.tys.ty import FunctionType, StructType, TupleType, Type
12
+ from guppylang_internals.tys.ty import (
13
+ FunctionType,
14
+ StructType,
15
+ TupleType,
16
+ Type,
17
+ UnitaryFlags,
18
+ )
13
19
 
14
20
  if TYPE_CHECKING:
15
21
  from guppylang_internals.cfg.cfg import CFG
@@ -250,22 +256,6 @@ class ComptimeExpr(ast.expr):
250
256
  _fields = ("value",)
251
257
 
252
258
 
253
- class ResultExpr(ast.expr):
254
- """A `result(tag, value)` expression."""
255
-
256
- value: ast.expr
257
- base_ty: Type
258
- #: Array length in case this is an array result, otherwise `None`
259
- array_len: Const | None
260
- tag: str
261
-
262
- _fields = ("value", "base_ty", "array_len", "tag")
263
-
264
- @property
265
- def args(self) -> list[ast.expr]:
266
- return [self.value]
267
-
268
-
269
259
  class ExitKind(Enum):
270
260
  ExitShot = 0 # Exit the current shot
271
261
  Panic = 1 # Panic the program ending all shots
@@ -275,8 +265,8 @@ class PanicExpr(ast.expr):
275
265
  """A `panic(msg, *args)` or `exit(msg, *args)` expression ."""
276
266
 
277
267
  kind: ExitKind
278
- signal: int
279
- msg: str
268
+ signal: ast.expr
269
+ msg: ast.expr
280
270
  values: list[ast.expr]
281
271
 
282
272
  _fields = ("kind", "signal", "msg", "values")
@@ -293,17 +283,16 @@ class BarrierExpr(ast.expr):
293
283
  class StateResultExpr(ast.expr):
294
284
  """A `state_result(tag, *args)` expression."""
295
285
 
296
- tag: str
286
+ tag_value: Const
287
+ tag_expr: ast.expr
297
288
  args: list[ast.expr]
298
289
  func_ty: FunctionType
299
290
  #: Array length in case this is an array result, otherwise `None`
300
291
  array_len: Const | None
301
- _fields = ("tag", "args", "func_ty", "has_array_input")
292
+ _fields = ("tag_value", "tag_expr", "args", "func_ty", "has_array_input")
302
293
 
303
294
 
304
- AnyCall = (
305
- LocalCall | GlobalCall | TensorCall | BarrierExpr | ResultExpr | StateResultExpr
306
- )
295
+ AnyCall = LocalCall | GlobalCall | TensorCall | BarrierExpr | StateResultExpr
307
296
 
308
297
 
309
298
  class InoutReturnSentinel(ast.expr):
@@ -500,6 +489,16 @@ class ModifiedBlock(ast.With):
500
489
  else:
501
490
  raise TypeError(f"Unknown modifier: {modifier}")
502
491
 
492
+ def flags(self) -> UnitaryFlags:
493
+ flags = UnitaryFlags.NoFlags
494
+ if self.is_dagger():
495
+ flags |= UnitaryFlags.Dagger
496
+ if self.is_control():
497
+ flags |= UnitaryFlags.Control
498
+ if self.is_power():
499
+ flags |= UnitaryFlags.Power
500
+ return flags
501
+
503
502
 
504
503
  class CheckedModifiedBlock(ast.With):
505
504
  def_id: "DefId"
@@ -4,9 +4,9 @@ from typing import ClassVar
4
4
 
5
5
  from typing_extensions import assert_never
6
6
 
7
- from guppylang_internals.ast_util import get_type, with_loc
8
- from guppylang_internals.checker.core import ComptimeVariable, Context
9
- from guppylang_internals.checker.errors.generic import ExpectedError, UnsupportedError
7
+ from guppylang_internals.ast_util import get_type, with_loc, with_type
8
+ from guppylang_internals.checker.core import Context
9
+ from guppylang_internals.checker.errors.generic import UnsupportedError
10
10
  from guppylang_internals.checker.errors.type_errors import (
11
11
  ArrayComprUnknownSizeError,
12
12
  TypeMismatchError,
@@ -33,8 +33,6 @@ from guppylang_internals.nodes import (
33
33
  GlobalCall,
34
34
  MakeIter,
35
35
  PanicExpr,
36
- PlaceNode,
37
- ResultExpr,
38
36
  )
39
37
  from guppylang_internals.tys.arg import ConstArg, TypeArg
40
38
  from guppylang_internals.tys.builtin import (
@@ -45,7 +43,6 @@ from guppylang_internals.tys.builtin import (
45
43
  get_iter_size,
46
44
  int_type,
47
45
  is_array_type,
48
- is_bool_type,
49
46
  is_sized_iter_type,
50
47
  nat_type,
51
48
  sized_iter_type,
@@ -58,7 +55,6 @@ from guppylang_internals.tys.ty import (
58
55
  FunctionType,
59
56
  InputFlags,
60
57
  NoneType,
61
- NumericType,
62
58
  Type,
63
59
  unify,
64
60
  )
@@ -302,80 +298,6 @@ class NewArrayChecker(CustomCallChecker):
302
298
  return with_loc(compr, array_compr), array_type(elt_ty, size)
303
299
 
304
300
 
305
- #: Maximum length of a tag in the `result` function.
306
- TAG_MAX_LEN = 200
307
-
308
-
309
- @dataclass(frozen=True)
310
- class TooLongError(Error):
311
- title: ClassVar[str] = "Tag too long"
312
- span_label: ClassVar[str] = "Result tag is too long"
313
-
314
- @dataclass(frozen=True)
315
- class Hint(Note):
316
- message: ClassVar[str] = f"Result tags are limited to {TAG_MAX_LEN} bytes"
317
-
318
-
319
- class ResultChecker(CustomCallChecker):
320
- """Call checker for the `result` function."""
321
-
322
- @dataclass(frozen=True)
323
- class InvalidError(Error):
324
- title: ClassVar[str] = "Invalid Result"
325
- span_label: ClassVar[str] = "Expression of type `{ty}` is not a valid result."
326
- ty: Type
327
-
328
- @dataclass(frozen=True)
329
- class Explanation(Note):
330
- message: ClassVar[str] = (
331
- "Only numeric values or arrays thereof are allowed as results"
332
- )
333
-
334
- def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, Type]:
335
- check_num_args(2, len(args), self.node)
336
- [tag, value] = args
337
- tag, _ = ExprChecker(self.ctx).check(tag, string_type())
338
- match tag:
339
- case ast.Constant(value=str(v)):
340
- tag_value = v
341
- case PlaceNode(place=ComptimeVariable(static_value=str(v))):
342
- tag_value = v
343
- case _:
344
- raise GuppyTypeError(ExpectedError(tag, "a string literal"))
345
- if len(tag_value.encode("utf-8")) > TAG_MAX_LEN:
346
- err: Error = TooLongError(tag)
347
- err.add_sub_diagnostic(TooLongError.Hint(None))
348
- raise GuppyTypeError(err)
349
- value, ty = ExprSynthesizer(self.ctx).synthesize(value)
350
- # We only allow numeric values or vectors of numeric values
351
- err = ResultChecker.InvalidError(value, ty)
352
- err.add_sub_diagnostic(ResultChecker.InvalidError.Explanation(None))
353
- if self._is_numeric_or_bool_type(ty):
354
- base_ty = ty
355
- array_len: Const | None = None
356
- elif is_array_type(ty):
357
- [ty_arg, len_arg] = ty.args
358
- assert isinstance(ty_arg, TypeArg)
359
- assert isinstance(len_arg, ConstArg)
360
- if not self._is_numeric_or_bool_type(ty_arg.ty):
361
- raise GuppyError(err)
362
- base_ty = ty_arg.ty
363
- array_len = len_arg.const
364
- else:
365
- raise GuppyError(err)
366
- node = ResultExpr(value, base_ty, array_len, tag_value)
367
- return with_loc(self.node, node), NoneType()
368
-
369
- def check(self, args: list[ast.expr], ty: Type) -> tuple[ast.expr, Subst]:
370
- expr, res_ty = self.synthesize(args)
371
- expr, subst, _ = check_type_against(res_ty, ty, expr, self.ctx)
372
- return expr, subst
373
-
374
- @staticmethod
375
- def _is_numeric_or_bool_type(ty: Type) -> bool:
376
- return isinstance(ty, NumericType) or is_bool_type(ty)
377
-
378
-
379
301
  class PanicChecker(CustomCallChecker):
380
302
  """Call checker for the `panic` function."""
381
303
 
@@ -395,18 +317,16 @@ class PanicChecker(CustomCallChecker):
395
317
  err.add_sub_diagnostic(PanicChecker.NoMessageError.Suggestion(None))
396
318
  raise GuppyTypeError(err)
397
319
  case [msg, *rest]:
320
+ # Check type of message and synthesize types for additional values.
398
321
  msg, _ = ExprChecker(self.ctx).check(msg, string_type())
399
- match msg:
400
- case ast.Constant(value=str(v)):
401
- msg_value = v
402
- case PlaceNode(place=ComptimeVariable(static_value=str(v))):
403
- msg_value = v
404
- case _:
405
- raise GuppyTypeError(ExpectedError(msg, "a string literal"))
406
322
  vals = [ExprSynthesizer(self.ctx).synthesize(val)[0] for val in rest]
407
323
  # TODO variable signals once default arguments are available
324
+ # TODO this will also allow us to remove this manual AST node hack
325
+ signal_expr = with_type(
326
+ int_type(), with_loc(self.node, ast.Constant(value=1))
327
+ )
408
328
  node = PanicExpr(
409
- kind=ExitKind.Panic, msg=msg_value, values=vals, signal=1
329
+ kind=ExitKind.Panic, msg=msg, values=vals, signal=signal_expr
410
330
  )
411
331
  return with_loc(self.node, node), NoneType()
412
332
  case args:
@@ -454,31 +374,16 @@ class ExitChecker(CustomCallChecker):
454
374
  )
455
375
  raise GuppyTypeError(signal_err)
456
376
  case [msg, signal, *rest]:
377
+ # Check types for message and signal and synthesize types for additional
378
+ # values.
457
379
  msg, _ = ExprChecker(self.ctx).check(msg, string_type())
458
- match msg:
459
- case ast.Constant(value=str(v)):
460
- msg_value = v
461
- case PlaceNode(place=ComptimeVariable(static_value=str(v))):
462
- msg_value = v
463
- case _:
464
- raise GuppyTypeError(ExpectedError(msg, "a string literal"))
465
- # TODO allow variable signals after https://github.com/CQCL/hugr/issues/1863
466
380
  signal, _ = ExprChecker(self.ctx).check(signal, int_type())
467
- match signal:
468
- case ast.Constant(value=int(s)):
469
- signal_value = s
470
- case PlaceNode(place=ComptimeVariable(static_value=int(s))):
471
- signal_value = s
472
- case _:
473
- raise GuppyTypeError(
474
- ExpectedError(signal, "an integer literal")
475
- )
476
381
  vals = [ExprSynthesizer(self.ctx).synthesize(val)[0] for val in rest]
477
382
  node = PanicExpr(
478
383
  kind=ExitKind.ExitShot,
479
- msg=msg_value,
384
+ msg=msg,
480
385
  values=vals,
481
- signal=signal_value,
386
+ signal=signal,
482
387
  )
483
388
  return with_loc(self.node, node), NoneType()
484
389
  case args:
@@ -261,7 +261,7 @@ class NewArrayCompiler(ArrayCompiler):
261
261
 
262
262
  def build_classical_array(self, elems: list[Wire]) -> Wire:
263
263
  """Lowers a call to `array.__new__` for classical arrays."""
264
- # See https://github.com/CQCL/guppylang/issues/629
264
+ # See https://github.com/quantinuum/guppylang/issues/629
265
265
  return self.build_linear_array(elems)
266
266
 
267
267
  def build_linear_array(self, elems: list[Wire]) -> Wire:
@@ -328,7 +328,7 @@ def _list_new_classical(
328
328
  builder: DfBase[ops.DfParentOp], elem_type: ht.Type, args: list[Wire]
329
329
  ) -> Wire:
330
330
  # This may be simplified in the future with a `new` or `with_capacity` list op
331
- # See https://github.com/CQCL/hugr/issues/1508
331
+ # See https://github.com/quantinuum/hugr/issues/1508
332
332
  lst = builder.load(ListVal([], elem_ty=elem_type))
333
333
  push_op = list_push(elem_type)
334
334
  for elem in args:
@@ -0,0 +1,153 @@
1
+ from dataclasses import dataclass
2
+ from typing import ClassVar
3
+
4
+ import hugr
5
+ from hugr import Wire, ops, tys
6
+
7
+ from guppylang_internals.ast_util import AstNode
8
+ from guppylang_internals.compiler.core import CompilerContext
9
+ from guppylang_internals.compiler.expr_compiler import array_read_bool
10
+ from guppylang_internals.definition.custom import (
11
+ CustomCallCompiler,
12
+ CustomInoutCallCompiler,
13
+ )
14
+ from guppylang_internals.definition.value import CallReturnWires
15
+ from guppylang_internals.diagnostic import Error, Note
16
+ from guppylang_internals.error import GuppyError, InternalGuppyError
17
+ from guppylang_internals.std._internal.compiler.array import (
18
+ array_clone,
19
+ array_map,
20
+ array_to_std_array,
21
+ )
22
+ from guppylang_internals.std._internal.compiler.tket_bool import OpaqueBool, read_bool
23
+ from guppylang_internals.std._internal.compiler.tket_exts import RESULT_EXTENSION
24
+ from guppylang_internals.tys.arg import Argument, ConstArg
25
+ from guppylang_internals.tys.builtin import get_element_type, is_bool_type
26
+ from guppylang_internals.tys.const import BoundConstVar, ConstValue
27
+ from guppylang_internals.tys.ty import NumericType
28
+
29
+ #: Maximum length of a tag in the `result` function.
30
+ TAG_MAX_LEN = 200
31
+
32
+
33
+ @dataclass(frozen=True)
34
+ class TooLongError(Error):
35
+ title: ClassVar[str] = "Tag too long"
36
+ span_label: ClassVar[str] = "Result tag is too long"
37
+
38
+ @dataclass(frozen=True)
39
+ class Hint(Note):
40
+ message: ClassVar[str] = f"Result tags are limited to {TAG_MAX_LEN} bytes"
41
+
42
+ @dataclass(frozen=True)
43
+ class GenericHint(Note):
44
+ message: ClassVar[str] = "Parameter `{param}` was instantiated to `{value}`"
45
+ param: str
46
+ value: str
47
+
48
+
49
+ class ResultCompiler(CustomCallCompiler):
50
+ """Custom compiler for overloads of the `result` function.
51
+
52
+ See `ArrayResultCompiler` for the compiler that handles results involving arrays.
53
+ """
54
+
55
+ def __init__(self, op_name: str, with_int_width: bool = False):
56
+ self.op_name = op_name
57
+ self.with_int_width = with_int_width
58
+
59
+ def compile(self, args: list[Wire]) -> list[Wire]:
60
+ assert self.func is not None
61
+ [value] = args
62
+ ty = self.func.ty.inputs[1].ty
63
+ hugr_ty = ty.to_hugr(self.ctx)
64
+ args = [tag_to_hugr(self.type_args[0], self.ctx, self.node)]
65
+ if self.with_int_width:
66
+ args.append(tys.BoundedNatArg(NumericType.INT_WIDTH))
67
+ # Bool results need an extra conversion into regular hugr bools
68
+ if is_bool_type(ty):
69
+ value = self.builder.add_op(read_bool(), value)
70
+ hugr_ty = tys.Bool
71
+ op = RESULT_EXTENSION.get_op(self.op_name)
72
+ sig = tys.FunctionType(input=[hugr_ty], output=[])
73
+ self.builder.add_op(op.instantiate(args, sig), value)
74
+ return []
75
+
76
+
77
+ class ArrayResultCompiler(CustomInoutCallCompiler):
78
+ """Custom compiler for overloads of the `result` function accepting arrays.
79
+
80
+ See `ResultCompiler` for the compiler that handles basic results.
81
+ """
82
+
83
+ def __init__(self, op_name: str, with_int_width: bool = False):
84
+ self.op_name = op_name
85
+ self.with_int_width = with_int_width
86
+
87
+ def compile_with_inouts(self, args: list[Wire]) -> CallReturnWires:
88
+ assert self.func is not None
89
+ array_ty = self.func.ty.inputs[1].ty
90
+ elem_ty = get_element_type(array_ty)
91
+ [tag_arg, size_arg] = self.type_args
92
+ [arr] = args
93
+
94
+ # As `borrow_array`s used by Guppy are linear, we need to clone it (knowing
95
+ # that all elements in it are copyable) to avoid linearity violations when
96
+ # both passing it to the result operation and returning it (as an inout
97
+ # argument).
98
+ hugr_elem_ty = elem_ty.to_hugr(self.ctx)
99
+ hugr_size = size_arg.to_hugr(self.ctx)
100
+ arr, out_arr = self.builder.add_op(array_clone(hugr_elem_ty, hugr_size), arr)
101
+ # For bool arrays, we furthermore need to coerce a read on all the array
102
+ # elements
103
+ if is_bool_type(elem_ty):
104
+ array_read = array_read_bool(self.ctx)
105
+ array_read = self.builder.load_function(array_read)
106
+ map_op = array_map(OpaqueBool, hugr_size, tys.Bool)
107
+ arr = self.builder.add_op(map_op, arr, array_read).out(0)
108
+ hugr_elem_ty = tys.Bool
109
+ # Turn `borrow_array` into regular `array`
110
+ arr = self.builder.add_op(array_to_std_array(hugr_elem_ty, hugr_size), arr).out(
111
+ 0
112
+ )
113
+
114
+ hugr_ty = hugr.std.collections.array.Array(hugr_elem_ty, hugr_size)
115
+ sig = tys.FunctionType(input=[hugr_ty], output=[])
116
+ args = [tag_to_hugr(tag_arg, self.ctx, self.node), hugr_size]
117
+ if self.with_int_width:
118
+ args.append(tys.BoundedNatArg(NumericType.INT_WIDTH))
119
+ op = ops.ExtOp(RESULT_EXTENSION.get_op(self.op_name), signature=sig, args=args)
120
+ self.builder.add_op(op, arr)
121
+ return CallReturnWires([], [out_arr])
122
+
123
+
124
+ def tag_to_hugr(tag_arg: Argument, ctx: CompilerContext, loc: AstNode) -> tys.TypeArg:
125
+ """Helper function to convert the Guppy tag comptime argument into a Hugr type arg.
126
+
127
+ Takes care of reading the tag value from the current monomorphization and checks
128
+ that the tag fits into `TAG_MAX_LEN`.
129
+ """
130
+ is_generic: BoundConstVar | None = None
131
+ match tag_arg:
132
+ case ConstArg(const=ConstValue(value=str(value))):
133
+ tag = value
134
+ case ConstArg(const=BoundConstVar(idx=idx) as var):
135
+ is_generic = var
136
+ assert ctx.current_mono_args is not None
137
+ match ctx.current_mono_args[idx]:
138
+ case ConstArg(const=ConstValue(value=str(value))):
139
+ tag = value
140
+ case _:
141
+ raise InternalGuppyError("Invalid tag monomorphization")
142
+ case _:
143
+ raise InternalGuppyError("Invalid tag argument")
144
+
145
+ if len(tag.encode("utf-8")) > TAG_MAX_LEN:
146
+ err = TooLongError(loc)
147
+ err.add_sub_diagnostic(TooLongError.Hint(None))
148
+ if is_generic:
149
+ err.add_sub_diagnostic(
150
+ TooLongError.GenericHint(None, is_generic.display_name, tag)
151
+ )
152
+ raise GuppyError(err)
153
+ return tys.StringArg(tag)
@@ -73,6 +73,14 @@ def panic(
73
73
  return ops.ExtOp(op_def, sig, args)
74
74
 
75
75
 
76
+ def make_error() -> ops.ExtOp:
77
+ """Returns an operation that makes an error."""
78
+ op_def = hugr.std.PRELUDE.get_op("MakeError")
79
+ args: list[ht.TypeArg] = []
80
+ sig = ht.FunctionType([ht.USize(), hugr.std.prelude.STRING_T], [error_type()])
81
+ return ops.ExtOp(op_def, sig, args)
82
+
83
+
76
84
  # ------------------------------------------------------
77
85
  # --------- Custom compilers for non-native ops --------
78
86
  # ------------------------------------------------------
@@ -90,14 +98,14 @@ def build_panic(
90
98
  return builder.add_op(op, err, *args)
91
99
 
92
100
 
93
- def build_error(builder: DfBase[P], signal: int, msg: str) -> Wire:
101
+ def build_static_error(builder: DfBase[P], signal: int, msg: str) -> Wire:
94
102
  """Constructs and loads a static error value."""
95
103
  val = ErrorVal(signal, msg)
96
104
  return builder.load(builder.add_const(val))
97
105
 
98
106
 
99
107
  # TODO: Common up build_unwrap_right and build_unwrap_left below once
100
- # https://github.com/CQCL/hugr/issues/1596 is fixed
108
+ # https://github.com/quantinuum/hugr/issues/1596 is fixed
101
109
 
102
110
 
103
111
  def build_unwrap_right(
@@ -111,7 +119,7 @@ def build_unwrap_right(
111
119
  assert isinstance(result_ty, ht.Sum)
112
120
  [left_tys, right_tys] = result_ty.variant_rows
113
121
  with conditional.add_case(0) as case:
114
- error = build_error(case, error_signal, error_msg)
122
+ error = build_static_error(case, error_signal, error_msg)
115
123
  case.set_outputs(*build_panic(case, left_tys, right_tys, error, *case.inputs()))
116
124
  with conditional.add_case(1) as case:
117
125
  case.set_outputs(*case.inputs())
@@ -134,7 +142,7 @@ def build_unwrap_left(
134
142
  with conditional.add_case(0) as case:
135
143
  case.set_outputs(*case.inputs())
136
144
  with conditional.add_case(1) as case:
137
- error = build_error(case, error_signal, error_msg)
145
+ error = build_static_error(case, error_signal, error_msg)
138
146
  case.set_outputs(*build_panic(case, right_tys, left_tys, error, *case.inputs()))
139
147
  return conditional.to_node()
140
148
 
@@ -20,6 +20,7 @@ from tket_exts import (
20
20
  BOOL_EXTENSION = tket_exts.bool()
21
21
  DEBUG_EXTENSION = debug()
22
22
  FUTURES_EXTENSION = futures()
23
+ GLOBAL_PHASE_EXTENSION = global_phase()
23
24
  GUPPY_EXTENSION = guppy()
24
25
  MODIFIER_EXTENSION = modifier()
25
26
  QSYSTEM_EXTENSION = qsystem()
@@ -29,14 +30,14 @@ QUANTUM_EXTENSION = quantum()
29
30
  RESULT_EXTENSION = result()
30
31
  ROTATION_EXTENSION = rotation()
31
32
  WASM_EXTENSION = wasm()
32
- MODIFIER_EXTENSION = modifier()
33
- GLOBAL_PHASE_EXTENSION = global_phase()
34
33
 
35
34
  TKET_EXTENSIONS = [
36
35
  BOOL_EXTENSION,
37
36
  DEBUG_EXTENSION,
38
37
  FUTURES_EXTENSION,
38
+ GLOBAL_PHASE_EXTENSION,
39
39
  GUPPY_EXTENSION,
40
+ MODIFIER_EXTENSION,
40
41
  QSYSTEM_EXTENSION,
41
42
  QSYSTEM_RANDOM_EXTENSION,
42
43
  QSYSTEM_UTILS_EXTENSION,
@@ -44,8 +45,6 @@ TKET_EXTENSIONS = [
44
45
  RESULT_EXTENSION,
45
46
  ROTATION_EXTENSION,
46
47
  WASM_EXTENSION,
47
- MODIFIER_EXTENSION,
48
- GLOBAL_PHASE_EXTENSION,
49
48
  ]
50
49
 
51
50
 
@@ -3,6 +3,7 @@ from dataclasses import dataclass
3
3
  from typing import ClassVar, cast
4
4
 
5
5
  from guppylang_internals.ast_util import with_loc
6
+ from guppylang_internals.checker.core import ComptimeVariable
6
7
  from guppylang_internals.checker.errors.generic import ExpectedError
7
8
  from guppylang_internals.checker.errors.type_errors import WrongNumberOfArgsError
8
9
  from guppylang_internals.checker.expr_checker import (
@@ -14,14 +15,14 @@ from guppylang_internals.definition.custom import CustomCallChecker
14
15
  from guppylang_internals.definition.ty import TypeDef
15
16
  from guppylang_internals.diagnostic import Error
16
17
  from guppylang_internals.error import GuppyTypeError
17
- from guppylang_internals.nodes import StateResultExpr
18
- from guppylang_internals.std._internal.checker import TAG_MAX_LEN, TooLongError
18
+ from guppylang_internals.nodes import GenericParamValue, PlaceNode, StateResultExpr
19
19
  from guppylang_internals.tys.builtin import (
20
20
  get_array_length,
21
21
  get_element_type,
22
22
  is_array_type,
23
23
  string_type,
24
24
  )
25
+ from guppylang_internals.tys.const import Const, ConstValue
25
26
  from guppylang_internals.tys.ty import (
26
27
  FuncInput,
27
28
  FunctionType,
@@ -43,12 +44,16 @@ class StateResultChecker(CustomCallChecker):
43
44
 
44
45
  def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, Type]:
45
46
  tag, _ = ExprChecker(self.ctx).check(args[0], string_type())
46
- if not isinstance(tag, ast.Constant) or not isinstance(tag.value, str):
47
- raise GuppyTypeError(ExpectedError(tag, "a string literal"))
48
- if len(tag.value.encode("utf-8")) > TAG_MAX_LEN:
49
- err: Error = TooLongError(tag)
50
- err.add_sub_diagnostic(TooLongError.Hint(None))
51
- raise GuppyTypeError(err)
47
+ tag_value: Const
48
+ match tag:
49
+ case ast.Constant(value=str(v)):
50
+ tag_value = ConstValue(string_type(), v)
51
+ case PlaceNode(place=ComptimeVariable(static_value=str(v))):
52
+ tag_value = ConstValue(string_type(), v)
53
+ case GenericParamValue() as param_value:
54
+ tag_value = param_value.param.to_bound().const
55
+ case _:
56
+ raise GuppyTypeError(ExpectedError(tag, "a string literal"))
52
57
  syn_args: list[ast.expr] = [tag]
53
58
 
54
59
  if len(args) < 2:
@@ -90,6 +95,10 @@ class StateResultChecker(CustomCallChecker):
90
95
  args, ret_ty, inst = synthesize_call(func_ty, syn_args, self.node, self.ctx)
91
96
  assert len(inst) == 0, "func_ty is not generic"
92
97
  node = StateResultExpr(
93
- tag=tag.value, args=args, func_ty=func_ty, array_len=array_len
98
+ tag_value=tag_value,
99
+ tag_expr=tag,
100
+ args=args,
101
+ func_ty=func_ty,
102
+ array_len=array_len,
94
103
  )
95
104
  return with_loc(self.node, node), ret_ty
@@ -129,7 +129,7 @@ def int_op(
129
129
  # Ideally we'd be able to derive the arguments from the input/output types,
130
130
  # but the amount of variables does not correlate with the signature for the
131
131
  # integer ops in hugr :/
132
- # https://github.com/CQCL/hugr/blob/bfa13e59468feb0fc746677ea3b3a4341b2ed42e/hugr-core/src/std_extensions/arithmetic/int_ops.rs#L116
132
+ # https://github.com/quantinuum/hugr/blob/bfa13e59468feb0fc746677ea3b3a4341b2ed42e/hugr-core/src/std_extensions/arithmetic/int_ops.rs#L116
133
133
  #
134
134
  # For now, we just instantiate every type argument to a 64-bit integer.
135
135
  args: list[ht.TypeArg] = [int_arg() for _ in range(n_vars)]
@@ -539,6 +539,16 @@ class TracingDefMixin(DunderMixin):
539
539
 
540
540
  def to_guppy_object(self) -> GuppyObject:
541
541
  state = get_tracing_state()
542
+ defn = ENGINE.get_checked(self.id)
543
+ # TODO: For generic functions, we need to know an instantiation for their type
544
+ # parameters. Maybe we should pass them to `to_guppy_object`? Either way, this
545
+ # will require some more plumbing of type inference information through the
546
+ # comptime logic. For now, let's just bail on generic functions.
547
+ # See https://github.com/quantinuum/guppylang/issues/1336
548
+ if isinstance(defn, CallableDef) and defn.ty.parametrized:
549
+ raise GuppyComptimeError(
550
+ f"Cannot infer type parameters of generic function `{defn.name}`"
551
+ )
542
552
  defn, [] = state.ctx.build_compiled_def(self.id, type_args=[])
543
553
  if isinstance(defn, CompiledValueDef):
544
554
  wire = defn.load(state.dfg, state.ctx, state.node)
@@ -5,7 +5,7 @@ from guppylang_internals.diagnostic import Error, Help, Note
5
5
 
6
6
  if TYPE_CHECKING:
7
7
  from guppylang_internals.definition.parameter import ParamDef
8
- from guppylang_internals.tys.ty import Type
8
+ from guppylang_internals.tys.ty import Type, UnitaryFlags
9
9
 
10
10
 
11
11
  @dataclass(frozen=True)
@@ -182,3 +182,25 @@ class InvalidFlagError(Error):
182
182
  class FlagNotAllowedError(Error):
183
183
  title: ClassVar[str] = "Invalid annotation"
184
184
  span_label: ClassVar[str] = "`@` type annotations are not allowed in this position"
185
+
186
+
187
+ @dataclass(frozen=True)
188
+ class UnitaryCallError(Error):
189
+ title: ClassVar[str] = "Unitary constraint violation"
190
+ span_label: ClassVar[str] = (
191
+ "This function cannot be called in a {render_flags} context"
192
+ )
193
+ flags: "UnitaryFlags"
194
+
195
+ @property
196
+ def render_flags(self) -> str:
197
+ from guppylang_internals.tys.ty import UnitaryFlags
198
+
199
+ if self.flags == UnitaryFlags.Dagger:
200
+ return "dagger"
201
+ elif self.flags == UnitaryFlags.Control:
202
+ return "control"
203
+ elif self.flags == UnitaryFlags.Power:
204
+ return "power"
205
+ else:
206
+ return "unitary"
@@ -107,7 +107,7 @@ def arg_from_ast(node: AstNode, ctx: TypeParsingCtx) -> Argument:
107
107
  return ConstArg(ConstValue(bool_type(), v))
108
108
  # Integer literals are turned into nat args.
109
109
  # TODO: To support int args, we need proper inference logic here
110
- # See https://github.com/CQCL/guppylang/issues/1030
110
+ # See https://github.com/quantinuum/guppylang/issues/1030
111
111
  case int(v) if v >= 0:
112
112
  nat_ty = NumericType(NumericType.Kind.Nat)
113
113
  return ConstArg(ConstValue(nat_ty, v))
@@ -117,7 +117,7 @@ def arg_from_ast(node: AstNode, ctx: TypeParsingCtx) -> Argument:
117
117
  # String literals are ignored for now since they could also be stringified
118
118
  # types.
119
119
  # TODO: To support string args, we need proper inference logic here
120
- # See https://github.com/CQCL/guppylang/issues/1030
120
+ # See https://github.com/quantinuum/guppylang/issues/1030
121
121
  case str(_):
122
122
  pass
123
123
 
@@ -289,7 +289,7 @@ def check_function_arg(
289
289
  ctx.param_var_mapping[name] = ConstParam(
290
290
  len(ctx.param_var_mapping), name, ty, from_comptime_arg=True
291
291
  )
292
- return FuncInput(ty, flags)
292
+ return FuncInput(ty, flags, name)
293
293
 
294
294
 
295
295
  if sys.version_info >= (3, 12):