guppylang-internals 0.25.0__py3-none-any.whl → 0.27.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- guppylang_internals/__init__.py +1 -1
- guppylang_internals/cfg/builder.py +20 -2
- guppylang_internals/cfg/cfg.py +3 -0
- guppylang_internals/checker/cfg_checker.py +6 -0
- guppylang_internals/checker/core.py +1 -2
- guppylang_internals/checker/errors/linearity.py +6 -2
- guppylang_internals/checker/errors/wasm.py +7 -4
- guppylang_internals/checker/expr_checker.py +39 -19
- guppylang_internals/checker/func_checker.py +17 -13
- guppylang_internals/checker/linearity_checker.py +2 -10
- guppylang_internals/checker/modifier_checker.py +6 -2
- guppylang_internals/checker/unitary_checker.py +132 -0
- guppylang_internals/compiler/cfg_compiler.py +7 -6
- guppylang_internals/compiler/core.py +5 -5
- guppylang_internals/compiler/expr_compiler.py +72 -81
- guppylang_internals/compiler/modifier_compiler.py +5 -0
- guppylang_internals/decorator.py +88 -7
- guppylang_internals/definition/custom.py +4 -0
- guppylang_internals/definition/declaration.py +6 -2
- guppylang_internals/definition/function.py +26 -3
- guppylang_internals/definition/metadata.py +87 -0
- guppylang_internals/definition/overloaded.py +11 -2
- guppylang_internals/definition/pytket_circuits.py +7 -2
- guppylang_internals/definition/struct.py +6 -3
- guppylang_internals/definition/wasm.py +42 -10
- guppylang_internals/diagnostic.py +72 -15
- guppylang_internals/engine.py +10 -13
- guppylang_internals/nodes.py +55 -24
- guppylang_internals/std/_internal/checker.py +13 -108
- guppylang_internals/std/_internal/compiler/array.py +37 -2
- guppylang_internals/std/_internal/compiler/either.py +14 -2
- guppylang_internals/std/_internal/compiler/list.py +1 -1
- guppylang_internals/std/_internal/compiler/platform.py +153 -0
- guppylang_internals/std/_internal/compiler/prelude.py +12 -4
- guppylang_internals/std/_internal/compiler/tket_bool.py +1 -6
- guppylang_internals/std/_internal/compiler/tket_exts.py +4 -5
- guppylang_internals/std/_internal/debug.py +18 -9
- guppylang_internals/std/_internal/util.py +1 -1
- guppylang_internals/tracing/object.py +14 -0
- guppylang_internals/tys/errors.py +23 -1
- guppylang_internals/tys/parsing.py +3 -3
- guppylang_internals/tys/printing.py +2 -8
- guppylang_internals/tys/qubit.py +37 -2
- guppylang_internals/tys/ty.py +60 -64
- guppylang_internals/wasm_util.py +129 -0
- {guppylang_internals-0.25.0.dist-info → guppylang_internals-0.27.0.dist-info}/METADATA +5 -4
- {guppylang_internals-0.25.0.dist-info → guppylang_internals-0.27.0.dist-info}/RECORD +49 -45
- {guppylang_internals-0.25.0.dist-info → guppylang_internals-0.27.0.dist-info}/WHEEL +1 -1
- {guppylang_internals-0.25.0.dist-info → guppylang_internals-0.27.0.dist-info}/licenses/LICENCE +0 -0
|
@@ -4,9 +4,9 @@ from typing import ClassVar
|
|
|
4
4
|
|
|
5
5
|
from typing_extensions import assert_never
|
|
6
6
|
|
|
7
|
-
from guppylang_internals.ast_util import get_type, with_loc
|
|
8
|
-
from guppylang_internals.checker.core import
|
|
9
|
-
from guppylang_internals.checker.errors.generic import
|
|
7
|
+
from guppylang_internals.ast_util import get_type, with_loc, with_type
|
|
8
|
+
from guppylang_internals.checker.core import Context
|
|
9
|
+
from guppylang_internals.checker.errors.generic import UnsupportedError
|
|
10
10
|
from guppylang_internals.checker.errors.type_errors import (
|
|
11
11
|
ArrayComprUnknownSizeError,
|
|
12
12
|
TypeMismatchError,
|
|
@@ -33,8 +33,6 @@ from guppylang_internals.nodes import (
|
|
|
33
33
|
GlobalCall,
|
|
34
34
|
MakeIter,
|
|
35
35
|
PanicExpr,
|
|
36
|
-
PlaceNode,
|
|
37
|
-
ResultExpr,
|
|
38
36
|
)
|
|
39
37
|
from guppylang_internals.tys.arg import ConstArg, TypeArg
|
|
40
38
|
from guppylang_internals.tys.builtin import (
|
|
@@ -45,7 +43,6 @@ from guppylang_internals.tys.builtin import (
|
|
|
45
43
|
get_iter_size,
|
|
46
44
|
int_type,
|
|
47
45
|
is_array_type,
|
|
48
|
-
is_bool_type,
|
|
49
46
|
is_sized_iter_type,
|
|
50
47
|
nat_type,
|
|
51
48
|
sized_iter_type,
|
|
@@ -58,7 +55,6 @@ from guppylang_internals.tys.ty import (
|
|
|
58
55
|
FunctionType,
|
|
59
56
|
InputFlags,
|
|
60
57
|
NoneType,
|
|
61
|
-
NumericType,
|
|
62
58
|
Type,
|
|
63
59
|
unify,
|
|
64
60
|
)
|
|
@@ -302,80 +298,6 @@ class NewArrayChecker(CustomCallChecker):
|
|
|
302
298
|
return with_loc(compr, array_compr), array_type(elt_ty, size)
|
|
303
299
|
|
|
304
300
|
|
|
305
|
-
#: Maximum length of a tag in the `result` function.
|
|
306
|
-
TAG_MAX_LEN = 200
|
|
307
|
-
|
|
308
|
-
|
|
309
|
-
@dataclass(frozen=True)
|
|
310
|
-
class TooLongError(Error):
|
|
311
|
-
title: ClassVar[str] = "Tag too long"
|
|
312
|
-
span_label: ClassVar[str] = "Result tag is too long"
|
|
313
|
-
|
|
314
|
-
@dataclass(frozen=True)
|
|
315
|
-
class Hint(Note):
|
|
316
|
-
message: ClassVar[str] = f"Result tags are limited to {TAG_MAX_LEN} bytes"
|
|
317
|
-
|
|
318
|
-
|
|
319
|
-
class ResultChecker(CustomCallChecker):
|
|
320
|
-
"""Call checker for the `result` function."""
|
|
321
|
-
|
|
322
|
-
@dataclass(frozen=True)
|
|
323
|
-
class InvalidError(Error):
|
|
324
|
-
title: ClassVar[str] = "Invalid Result"
|
|
325
|
-
span_label: ClassVar[str] = "Expression of type `{ty}` is not a valid result."
|
|
326
|
-
ty: Type
|
|
327
|
-
|
|
328
|
-
@dataclass(frozen=True)
|
|
329
|
-
class Explanation(Note):
|
|
330
|
-
message: ClassVar[str] = (
|
|
331
|
-
"Only numeric values or arrays thereof are allowed as results"
|
|
332
|
-
)
|
|
333
|
-
|
|
334
|
-
def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, Type]:
|
|
335
|
-
check_num_args(2, len(args), self.node)
|
|
336
|
-
[tag, value] = args
|
|
337
|
-
tag, _ = ExprChecker(self.ctx).check(tag, string_type())
|
|
338
|
-
match tag:
|
|
339
|
-
case ast.Constant(value=str(v)):
|
|
340
|
-
tag_value = v
|
|
341
|
-
case PlaceNode(place=ComptimeVariable(static_value=str(v))):
|
|
342
|
-
tag_value = v
|
|
343
|
-
case _:
|
|
344
|
-
raise GuppyTypeError(ExpectedError(tag, "a string literal"))
|
|
345
|
-
if len(tag_value.encode("utf-8")) > TAG_MAX_LEN:
|
|
346
|
-
err: Error = TooLongError(tag)
|
|
347
|
-
err.add_sub_diagnostic(TooLongError.Hint(None))
|
|
348
|
-
raise GuppyTypeError(err)
|
|
349
|
-
value, ty = ExprSynthesizer(self.ctx).synthesize(value)
|
|
350
|
-
# We only allow numeric values or vectors of numeric values
|
|
351
|
-
err = ResultChecker.InvalidError(value, ty)
|
|
352
|
-
err.add_sub_diagnostic(ResultChecker.InvalidError.Explanation(None))
|
|
353
|
-
if self._is_numeric_or_bool_type(ty):
|
|
354
|
-
base_ty = ty
|
|
355
|
-
array_len: Const | None = None
|
|
356
|
-
elif is_array_type(ty):
|
|
357
|
-
[ty_arg, len_arg] = ty.args
|
|
358
|
-
assert isinstance(ty_arg, TypeArg)
|
|
359
|
-
assert isinstance(len_arg, ConstArg)
|
|
360
|
-
if not self._is_numeric_or_bool_type(ty_arg.ty):
|
|
361
|
-
raise GuppyError(err)
|
|
362
|
-
base_ty = ty_arg.ty
|
|
363
|
-
array_len = len_arg.const
|
|
364
|
-
else:
|
|
365
|
-
raise GuppyError(err)
|
|
366
|
-
node = ResultExpr(value, base_ty, array_len, tag_value)
|
|
367
|
-
return with_loc(self.node, node), NoneType()
|
|
368
|
-
|
|
369
|
-
def check(self, args: list[ast.expr], ty: Type) -> tuple[ast.expr, Subst]:
|
|
370
|
-
expr, res_ty = self.synthesize(args)
|
|
371
|
-
expr, subst, _ = check_type_against(res_ty, ty, expr, self.ctx)
|
|
372
|
-
return expr, subst
|
|
373
|
-
|
|
374
|
-
@staticmethod
|
|
375
|
-
def _is_numeric_or_bool_type(ty: Type) -> bool:
|
|
376
|
-
return isinstance(ty, NumericType) or is_bool_type(ty)
|
|
377
|
-
|
|
378
|
-
|
|
379
301
|
class PanicChecker(CustomCallChecker):
|
|
380
302
|
"""Call checker for the `panic` function."""
|
|
381
303
|
|
|
@@ -395,18 +317,16 @@ class PanicChecker(CustomCallChecker):
|
|
|
395
317
|
err.add_sub_diagnostic(PanicChecker.NoMessageError.Suggestion(None))
|
|
396
318
|
raise GuppyTypeError(err)
|
|
397
319
|
case [msg, *rest]:
|
|
320
|
+
# Check type of message and synthesize types for additional values.
|
|
398
321
|
msg, _ = ExprChecker(self.ctx).check(msg, string_type())
|
|
399
|
-
match msg:
|
|
400
|
-
case ast.Constant(value=str(v)):
|
|
401
|
-
msg_value = v
|
|
402
|
-
case PlaceNode(place=ComptimeVariable(static_value=str(v))):
|
|
403
|
-
msg_value = v
|
|
404
|
-
case _:
|
|
405
|
-
raise GuppyTypeError(ExpectedError(msg, "a string literal"))
|
|
406
322
|
vals = [ExprSynthesizer(self.ctx).synthesize(val)[0] for val in rest]
|
|
407
323
|
# TODO variable signals once default arguments are available
|
|
324
|
+
# TODO this will also allow us to remove this manual AST node hack
|
|
325
|
+
signal_expr = with_type(
|
|
326
|
+
int_type(), with_loc(self.node, ast.Constant(value=1))
|
|
327
|
+
)
|
|
408
328
|
node = PanicExpr(
|
|
409
|
-
kind=ExitKind.Panic, msg=
|
|
329
|
+
kind=ExitKind.Panic, msg=msg, values=vals, signal=signal_expr
|
|
410
330
|
)
|
|
411
331
|
return with_loc(self.node, node), NoneType()
|
|
412
332
|
case args:
|
|
@@ -454,31 +374,16 @@ class ExitChecker(CustomCallChecker):
|
|
|
454
374
|
)
|
|
455
375
|
raise GuppyTypeError(signal_err)
|
|
456
376
|
case [msg, signal, *rest]:
|
|
377
|
+
# Check types for message and signal and synthesize types for additional
|
|
378
|
+
# values.
|
|
457
379
|
msg, _ = ExprChecker(self.ctx).check(msg, string_type())
|
|
458
|
-
match msg:
|
|
459
|
-
case ast.Constant(value=str(v)):
|
|
460
|
-
msg_value = v
|
|
461
|
-
case PlaceNode(place=ComptimeVariable(static_value=str(v))):
|
|
462
|
-
msg_value = v
|
|
463
|
-
case _:
|
|
464
|
-
raise GuppyTypeError(ExpectedError(msg, "a string literal"))
|
|
465
|
-
# TODO allow variable signals after https://github.com/CQCL/hugr/issues/1863
|
|
466
380
|
signal, _ = ExprChecker(self.ctx).check(signal, int_type())
|
|
467
|
-
match signal:
|
|
468
|
-
case ast.Constant(value=int(s)):
|
|
469
|
-
signal_value = s
|
|
470
|
-
case PlaceNode(place=ComptimeVariable(static_value=int(s))):
|
|
471
|
-
signal_value = s
|
|
472
|
-
case _:
|
|
473
|
-
raise GuppyTypeError(
|
|
474
|
-
ExpectedError(signal, "an integer literal")
|
|
475
|
-
)
|
|
476
381
|
vals = [ExprSynthesizer(self.ctx).synthesize(val)[0] for val in rest]
|
|
477
382
|
node = PanicExpr(
|
|
478
383
|
kind=ExitKind.ExitShot,
|
|
479
|
-
msg=
|
|
384
|
+
msg=msg,
|
|
480
385
|
values=vals,
|
|
481
|
-
signal=
|
|
386
|
+
signal=signal,
|
|
482
387
|
)
|
|
483
388
|
return with_loc(self.node, node), NoneType()
|
|
484
389
|
case args:
|
|
@@ -16,6 +16,7 @@ from guppylang_internals.std._internal.compiler.arithmetic import convert_itousi
|
|
|
16
16
|
from guppylang_internals.std._internal.compiler.prelude import (
|
|
17
17
|
build_unwrap_right,
|
|
18
18
|
)
|
|
19
|
+
from guppylang_internals.std._internal.compiler.tket_bool import make_opaque
|
|
19
20
|
from guppylang_internals.tys.arg import ConstArg, TypeArg
|
|
20
21
|
|
|
21
22
|
if TYPE_CHECKING:
|
|
@@ -206,6 +207,14 @@ def barray_new_all_borrowed(elem_ty: ht.Type, length: ht.TypeArg) -> ops.ExtOp:
|
|
|
206
207
|
return _instantiate_array_op("new_all_borrowed", elem_ty, length, [], [arr_ty])
|
|
207
208
|
|
|
208
209
|
|
|
210
|
+
def barray_is_borrowed(elem_ty: ht.Type, length: ht.TypeArg) -> ops.ExtOp:
|
|
211
|
+
"""Returns an array `is_borrowed` operation."""
|
|
212
|
+
arr_ty = array_type(elem_ty, length)
|
|
213
|
+
return _instantiate_array_op(
|
|
214
|
+
"is_borrowed", elem_ty, length, [arr_ty, ht.USize()], [arr_ty, ht.Bool]
|
|
215
|
+
)
|
|
216
|
+
|
|
217
|
+
|
|
209
218
|
def array_clone(elem_ty: ht.Type, length: ht.TypeArg) -> ops.ExtOp:
|
|
210
219
|
"""Returns an array `clone` operation for arrays none of whose elements are
|
|
211
220
|
borrowed."""
|
|
@@ -261,7 +270,7 @@ class NewArrayCompiler(ArrayCompiler):
|
|
|
261
270
|
|
|
262
271
|
def build_classical_array(self, elems: list[Wire]) -> Wire:
|
|
263
272
|
"""Lowers a call to `array.__new__` for classical arrays."""
|
|
264
|
-
# See https://github.com/
|
|
273
|
+
# See https://github.com/quantinuum/guppylang/issues/629
|
|
265
274
|
return self.build_linear_array(elems)
|
|
266
275
|
|
|
267
276
|
def build_linear_array(self, elems: list[Wire]) -> Wire:
|
|
@@ -320,7 +329,15 @@ class ArrayGetitemCompiler(ArrayCompiler):
|
|
|
320
329
|
|
|
321
330
|
|
|
322
331
|
class ArraySetitemCompiler(ArrayCompiler):
|
|
323
|
-
"""Compiler for the `array.__setitem__` function.
|
|
332
|
+
"""Compiler for the `array.__setitem__` function.
|
|
333
|
+
|
|
334
|
+
Arguments:
|
|
335
|
+
elem_first: If `True`, then compiler will assume that the element wire comes
|
|
336
|
+
before the index wire. Defaults to `False`.
|
|
337
|
+
"""
|
|
338
|
+
|
|
339
|
+
def __init__(self, elem_first: bool = False):
|
|
340
|
+
self.elem_first = elem_first
|
|
324
341
|
|
|
325
342
|
def _build_classical_setitem(
|
|
326
343
|
self, array: Wire, idx: Wire, elem: Wire
|
|
@@ -359,6 +376,8 @@ class ArraySetitemCompiler(ArrayCompiler):
|
|
|
359
376
|
|
|
360
377
|
def compile_with_inouts(self, args: list[Wire]) -> CallReturnWires:
|
|
361
378
|
[array, idx, elem] = args
|
|
379
|
+
if self.elem_first:
|
|
380
|
+
elem, idx = idx, elem
|
|
362
381
|
if self.elem_ty.type_bound() == ht.TypeBound.Linear:
|
|
363
382
|
return self._build_linear_setitem(array, idx, elem)
|
|
364
383
|
else:
|
|
@@ -379,3 +398,19 @@ class ArrayDiscardAllUsedCompiler(ArrayCompiler):
|
|
|
379
398
|
arr,
|
|
380
399
|
)
|
|
381
400
|
return []
|
|
401
|
+
|
|
402
|
+
|
|
403
|
+
class ArrayIsBorrowedCompiler(ArrayCompiler):
|
|
404
|
+
"""Compiler for the `array.is_borrowed` method."""
|
|
405
|
+
|
|
406
|
+
def compile_with_inouts(self, args: list[Wire]) -> CallReturnWires:
|
|
407
|
+
[array, idx] = args
|
|
408
|
+
idx = self.builder.add_op(convert_itousize(), idx)
|
|
409
|
+
array, b = self.builder.add_op(
|
|
410
|
+
barray_is_borrowed(self.elem_ty, self.length), array, idx
|
|
411
|
+
)
|
|
412
|
+
b = self.builder.add_op(make_opaque(), b)
|
|
413
|
+
return CallReturnWires(regular_returns=[b], inout_returns=[array])
|
|
414
|
+
|
|
415
|
+
def compile(self, args: list[Wire]) -> list[Wire]:
|
|
416
|
+
raise InternalGuppyError("Call compile_with_inouts instead")
|
|
@@ -4,6 +4,8 @@ from collections.abc import Sequence
|
|
|
4
4
|
from hugr import Wire, ops
|
|
5
5
|
from hugr import tys as ht
|
|
6
6
|
|
|
7
|
+
from guppylang_internals.ast_util import get_type
|
|
8
|
+
from guppylang_internals.compiler.expr_compiler import pack_returns, unpack_wire
|
|
7
9
|
from guppylang_internals.definition.custom import (
|
|
8
10
|
CustomCallCompiler,
|
|
9
11
|
CustomInoutCallCompiler,
|
|
@@ -69,7 +71,14 @@ class EitherConstructor(EitherCompiler, CustomCallCompiler):
|
|
|
69
71
|
# In the `right` case, the type args are swapped around since `R` occurs
|
|
70
72
|
# first in the signature :(
|
|
71
73
|
ty.variant_rows = [ty.variant_rows[1], ty.variant_rows[0]]
|
|
72
|
-
|
|
74
|
+
# For the same reason, the type of the input corresponds to the first type
|
|
75
|
+
# variable
|
|
76
|
+
inp_arg = self.type_args[0]
|
|
77
|
+
assert isinstance(inp_arg, TypeArg)
|
|
78
|
+
[inp] = args
|
|
79
|
+
# Unpack the single input into a row
|
|
80
|
+
inp_row = unpack_wire(inp, inp_arg.ty, self.builder, self.ctx)
|
|
81
|
+
return [self.builder.add_op(ops.Tag(self.tag, ty), *inp_row)]
|
|
73
82
|
|
|
74
83
|
|
|
75
84
|
class EitherTestCompiler(EitherCompiler):
|
|
@@ -128,4 +137,7 @@ class EitherUnwrapCompiler(EitherCompiler, CustomCallCompiler):
|
|
|
128
137
|
out = build_unwrap_right(
|
|
129
138
|
self.builder, either, "Either.unwrap_right: value is `left`"
|
|
130
139
|
)
|
|
131
|
-
return
|
|
140
|
+
# Pack outputs into a single wire. We're not allowed to return a row since the
|
|
141
|
+
# signature has a generic return type (also see `TupleType.preserve`)
|
|
142
|
+
return_ty = get_type(self.node)
|
|
143
|
+
return [pack_returns(list(out), return_ty, self.builder, self.ctx)]
|
|
@@ -328,7 +328,7 @@ def _list_new_classical(
|
|
|
328
328
|
builder: DfBase[ops.DfParentOp], elem_type: ht.Type, args: list[Wire]
|
|
329
329
|
) -> Wire:
|
|
330
330
|
# This may be simplified in the future with a `new` or `with_capacity` list op
|
|
331
|
-
# See https://github.com/
|
|
331
|
+
# See https://github.com/quantinuum/hugr/issues/1508
|
|
332
332
|
lst = builder.load(ListVal([], elem_ty=elem_type))
|
|
333
333
|
push_op = list_push(elem_type)
|
|
334
334
|
for elem in args:
|
|
@@ -0,0 +1,153 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
from typing import ClassVar
|
|
3
|
+
|
|
4
|
+
import hugr
|
|
5
|
+
from hugr import Wire, ops, tys
|
|
6
|
+
|
|
7
|
+
from guppylang_internals.ast_util import AstNode
|
|
8
|
+
from guppylang_internals.compiler.core import CompilerContext
|
|
9
|
+
from guppylang_internals.compiler.expr_compiler import array_read_bool
|
|
10
|
+
from guppylang_internals.definition.custom import (
|
|
11
|
+
CustomCallCompiler,
|
|
12
|
+
CustomInoutCallCompiler,
|
|
13
|
+
)
|
|
14
|
+
from guppylang_internals.definition.value import CallReturnWires
|
|
15
|
+
from guppylang_internals.diagnostic import Error, Note
|
|
16
|
+
from guppylang_internals.error import GuppyError, InternalGuppyError
|
|
17
|
+
from guppylang_internals.std._internal.compiler.array import (
|
|
18
|
+
array_clone,
|
|
19
|
+
array_map,
|
|
20
|
+
array_to_std_array,
|
|
21
|
+
)
|
|
22
|
+
from guppylang_internals.std._internal.compiler.tket_bool import OpaqueBool, read_bool
|
|
23
|
+
from guppylang_internals.std._internal.compiler.tket_exts import RESULT_EXTENSION
|
|
24
|
+
from guppylang_internals.tys.arg import Argument, ConstArg
|
|
25
|
+
from guppylang_internals.tys.builtin import get_element_type, is_bool_type
|
|
26
|
+
from guppylang_internals.tys.const import BoundConstVar, ConstValue
|
|
27
|
+
from guppylang_internals.tys.ty import NumericType
|
|
28
|
+
|
|
29
|
+
#: Maximum length of a tag in the `result` function.
|
|
30
|
+
TAG_MAX_LEN = 200
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
@dataclass(frozen=True)
|
|
34
|
+
class TooLongError(Error):
|
|
35
|
+
title: ClassVar[str] = "Tag too long"
|
|
36
|
+
span_label: ClassVar[str] = "Result tag is too long"
|
|
37
|
+
|
|
38
|
+
@dataclass(frozen=True)
|
|
39
|
+
class Hint(Note):
|
|
40
|
+
message: ClassVar[str] = f"Result tags are limited to {TAG_MAX_LEN} bytes"
|
|
41
|
+
|
|
42
|
+
@dataclass(frozen=True)
|
|
43
|
+
class GenericHint(Note):
|
|
44
|
+
message: ClassVar[str] = "Parameter `{param}` was instantiated to `{value}`"
|
|
45
|
+
param: str
|
|
46
|
+
value: str
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
class ResultCompiler(CustomCallCompiler):
|
|
50
|
+
"""Custom compiler for overloads of the `result` function.
|
|
51
|
+
|
|
52
|
+
See `ArrayResultCompiler` for the compiler that handles results involving arrays.
|
|
53
|
+
"""
|
|
54
|
+
|
|
55
|
+
def __init__(self, op_name: str, with_int_width: bool = False):
|
|
56
|
+
self.op_name = op_name
|
|
57
|
+
self.with_int_width = with_int_width
|
|
58
|
+
|
|
59
|
+
def compile(self, args: list[Wire]) -> list[Wire]:
|
|
60
|
+
assert self.func is not None
|
|
61
|
+
[value] = args
|
|
62
|
+
ty = self.func.ty.inputs[1].ty
|
|
63
|
+
hugr_ty = ty.to_hugr(self.ctx)
|
|
64
|
+
args = [tag_to_hugr(self.type_args[0], self.ctx, self.node)]
|
|
65
|
+
if self.with_int_width:
|
|
66
|
+
args.append(tys.BoundedNatArg(NumericType.INT_WIDTH))
|
|
67
|
+
# Bool results need an extra conversion into regular hugr bools
|
|
68
|
+
if is_bool_type(ty):
|
|
69
|
+
value = self.builder.add_op(read_bool(), value)
|
|
70
|
+
hugr_ty = tys.Bool
|
|
71
|
+
op = RESULT_EXTENSION.get_op(self.op_name)
|
|
72
|
+
sig = tys.FunctionType(input=[hugr_ty], output=[])
|
|
73
|
+
self.builder.add_op(op.instantiate(args, sig), value)
|
|
74
|
+
return []
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
class ArrayResultCompiler(CustomInoutCallCompiler):
|
|
78
|
+
"""Custom compiler for overloads of the `result` function accepting arrays.
|
|
79
|
+
|
|
80
|
+
See `ResultCompiler` for the compiler that handles basic results.
|
|
81
|
+
"""
|
|
82
|
+
|
|
83
|
+
def __init__(self, op_name: str, with_int_width: bool = False):
|
|
84
|
+
self.op_name = op_name
|
|
85
|
+
self.with_int_width = with_int_width
|
|
86
|
+
|
|
87
|
+
def compile_with_inouts(self, args: list[Wire]) -> CallReturnWires:
|
|
88
|
+
assert self.func is not None
|
|
89
|
+
array_ty = self.func.ty.inputs[1].ty
|
|
90
|
+
elem_ty = get_element_type(array_ty)
|
|
91
|
+
[tag_arg, size_arg] = self.type_args
|
|
92
|
+
[arr] = args
|
|
93
|
+
|
|
94
|
+
# As `borrow_array`s used by Guppy are linear, we need to clone it (knowing
|
|
95
|
+
# that all elements in it are copyable) to avoid linearity violations when
|
|
96
|
+
# both passing it to the result operation and returning it (as an inout
|
|
97
|
+
# argument).
|
|
98
|
+
hugr_elem_ty = elem_ty.to_hugr(self.ctx)
|
|
99
|
+
hugr_size = size_arg.to_hugr(self.ctx)
|
|
100
|
+
arr, out_arr = self.builder.add_op(array_clone(hugr_elem_ty, hugr_size), arr)
|
|
101
|
+
# For bool arrays, we furthermore need to coerce a read on all the array
|
|
102
|
+
# elements
|
|
103
|
+
if is_bool_type(elem_ty):
|
|
104
|
+
array_read = array_read_bool(self.ctx)
|
|
105
|
+
array_read = self.builder.load_function(array_read)
|
|
106
|
+
map_op = array_map(OpaqueBool, hugr_size, tys.Bool)
|
|
107
|
+
arr = self.builder.add_op(map_op, arr, array_read).out(0)
|
|
108
|
+
hugr_elem_ty = tys.Bool
|
|
109
|
+
# Turn `borrow_array` into regular `array`
|
|
110
|
+
arr = self.builder.add_op(array_to_std_array(hugr_elem_ty, hugr_size), arr).out(
|
|
111
|
+
0
|
|
112
|
+
)
|
|
113
|
+
|
|
114
|
+
hugr_ty = hugr.std.collections.array.Array(hugr_elem_ty, hugr_size)
|
|
115
|
+
sig = tys.FunctionType(input=[hugr_ty], output=[])
|
|
116
|
+
args = [tag_to_hugr(tag_arg, self.ctx, self.node), hugr_size]
|
|
117
|
+
if self.with_int_width:
|
|
118
|
+
args.append(tys.BoundedNatArg(NumericType.INT_WIDTH))
|
|
119
|
+
op = ops.ExtOp(RESULT_EXTENSION.get_op(self.op_name), signature=sig, args=args)
|
|
120
|
+
self.builder.add_op(op, arr)
|
|
121
|
+
return CallReturnWires([], [out_arr])
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
def tag_to_hugr(tag_arg: Argument, ctx: CompilerContext, loc: AstNode) -> tys.TypeArg:
|
|
125
|
+
"""Helper function to convert the Guppy tag comptime argument into a Hugr type arg.
|
|
126
|
+
|
|
127
|
+
Takes care of reading the tag value from the current monomorphization and checks
|
|
128
|
+
that the tag fits into `TAG_MAX_LEN`.
|
|
129
|
+
"""
|
|
130
|
+
is_generic: BoundConstVar | None = None
|
|
131
|
+
match tag_arg:
|
|
132
|
+
case ConstArg(const=ConstValue(value=str(value))):
|
|
133
|
+
tag = value
|
|
134
|
+
case ConstArg(const=BoundConstVar(idx=idx) as var):
|
|
135
|
+
is_generic = var
|
|
136
|
+
assert ctx.current_mono_args is not None
|
|
137
|
+
match ctx.current_mono_args[idx]:
|
|
138
|
+
case ConstArg(const=ConstValue(value=str(value))):
|
|
139
|
+
tag = value
|
|
140
|
+
case _:
|
|
141
|
+
raise InternalGuppyError("Invalid tag monomorphization")
|
|
142
|
+
case _:
|
|
143
|
+
raise InternalGuppyError("Invalid tag argument")
|
|
144
|
+
|
|
145
|
+
if len(tag.encode("utf-8")) > TAG_MAX_LEN:
|
|
146
|
+
err = TooLongError(loc)
|
|
147
|
+
err.add_sub_diagnostic(TooLongError.Hint(None))
|
|
148
|
+
if is_generic:
|
|
149
|
+
err.add_sub_diagnostic(
|
|
150
|
+
TooLongError.GenericHint(None, is_generic.display_name, tag)
|
|
151
|
+
)
|
|
152
|
+
raise GuppyError(err)
|
|
153
|
+
return tys.StringArg(tag)
|
|
@@ -73,6 +73,14 @@ def panic(
|
|
|
73
73
|
return ops.ExtOp(op_def, sig, args)
|
|
74
74
|
|
|
75
75
|
|
|
76
|
+
def make_error() -> ops.ExtOp:
|
|
77
|
+
"""Returns an operation that makes an error."""
|
|
78
|
+
op_def = hugr.std.PRELUDE.get_op("MakeError")
|
|
79
|
+
args: list[ht.TypeArg] = []
|
|
80
|
+
sig = ht.FunctionType([ht.USize(), hugr.std.prelude.STRING_T], [error_type()])
|
|
81
|
+
return ops.ExtOp(op_def, sig, args)
|
|
82
|
+
|
|
83
|
+
|
|
76
84
|
# ------------------------------------------------------
|
|
77
85
|
# --------- Custom compilers for non-native ops --------
|
|
78
86
|
# ------------------------------------------------------
|
|
@@ -90,14 +98,14 @@ def build_panic(
|
|
|
90
98
|
return builder.add_op(op, err, *args)
|
|
91
99
|
|
|
92
100
|
|
|
93
|
-
def
|
|
101
|
+
def build_static_error(builder: DfBase[P], signal: int, msg: str) -> Wire:
|
|
94
102
|
"""Constructs and loads a static error value."""
|
|
95
103
|
val = ErrorVal(signal, msg)
|
|
96
104
|
return builder.load(builder.add_const(val))
|
|
97
105
|
|
|
98
106
|
|
|
99
107
|
# TODO: Common up build_unwrap_right and build_unwrap_left below once
|
|
100
|
-
# https://github.com/
|
|
108
|
+
# https://github.com/quantinuum/hugr/issues/1596 is fixed
|
|
101
109
|
|
|
102
110
|
|
|
103
111
|
def build_unwrap_right(
|
|
@@ -111,7 +119,7 @@ def build_unwrap_right(
|
|
|
111
119
|
assert isinstance(result_ty, ht.Sum)
|
|
112
120
|
[left_tys, right_tys] = result_ty.variant_rows
|
|
113
121
|
with conditional.add_case(0) as case:
|
|
114
|
-
error =
|
|
122
|
+
error = build_static_error(case, error_signal, error_msg)
|
|
115
123
|
case.set_outputs(*build_panic(case, left_tys, right_tys, error, *case.inputs()))
|
|
116
124
|
with conditional.add_case(1) as case:
|
|
117
125
|
case.set_outputs(*case.inputs())
|
|
@@ -134,7 +142,7 @@ def build_unwrap_left(
|
|
|
134
142
|
with conditional.add_case(0) as case:
|
|
135
143
|
case.set_outputs(*case.inputs())
|
|
136
144
|
with conditional.add_case(1) as case:
|
|
137
|
-
error =
|
|
145
|
+
error = build_static_error(case, error_signal, error_msg)
|
|
138
146
|
case.set_outputs(*build_panic(case, right_tys, left_tys, error, *case.inputs()))
|
|
139
147
|
return conditional.to_node()
|
|
140
148
|
|
|
@@ -40,12 +40,7 @@ class OpaqueBoolVal(hv.ExtensionValue):
|
|
|
40
40
|
def to_value(self) -> hv.Extension:
|
|
41
41
|
name = "ConstBool"
|
|
42
42
|
payload = self.v
|
|
43
|
-
return hv.Extension(
|
|
44
|
-
name,
|
|
45
|
-
typ=OpaqueBool,
|
|
46
|
-
val=payload,
|
|
47
|
-
extensions=[BOOL_EXTENSION.name],
|
|
48
|
-
)
|
|
43
|
+
return hv.Extension(name, typ=OpaqueBool, val=payload)
|
|
49
44
|
|
|
50
45
|
def __str__(self) -> str:
|
|
51
46
|
return f"{self.v}"
|
|
@@ -20,6 +20,7 @@ from tket_exts import (
|
|
|
20
20
|
BOOL_EXTENSION = tket_exts.bool()
|
|
21
21
|
DEBUG_EXTENSION = debug()
|
|
22
22
|
FUTURES_EXTENSION = futures()
|
|
23
|
+
GLOBAL_PHASE_EXTENSION = global_phase()
|
|
23
24
|
GUPPY_EXTENSION = guppy()
|
|
24
25
|
MODIFIER_EXTENSION = modifier()
|
|
25
26
|
QSYSTEM_EXTENSION = qsystem()
|
|
@@ -29,14 +30,14 @@ QUANTUM_EXTENSION = quantum()
|
|
|
29
30
|
RESULT_EXTENSION = result()
|
|
30
31
|
ROTATION_EXTENSION = rotation()
|
|
31
32
|
WASM_EXTENSION = wasm()
|
|
32
|
-
MODIFIER_EXTENSION = modifier()
|
|
33
|
-
GLOBAL_PHASE_EXTENSION = global_phase()
|
|
34
33
|
|
|
35
34
|
TKET_EXTENSIONS = [
|
|
36
35
|
BOOL_EXTENSION,
|
|
37
36
|
DEBUG_EXTENSION,
|
|
38
37
|
FUTURES_EXTENSION,
|
|
38
|
+
GLOBAL_PHASE_EXTENSION,
|
|
39
39
|
GUPPY_EXTENSION,
|
|
40
|
+
MODIFIER_EXTENSION,
|
|
40
41
|
QSYSTEM_EXTENSION,
|
|
41
42
|
QSYSTEM_RANDOM_EXTENSION,
|
|
42
43
|
QSYSTEM_UTILS_EXTENSION,
|
|
@@ -44,8 +45,6 @@ TKET_EXTENSIONS = [
|
|
|
44
45
|
RESULT_EXTENSION,
|
|
45
46
|
ROTATION_EXTENSION,
|
|
46
47
|
WASM_EXTENSION,
|
|
47
|
-
MODIFIER_EXTENSION,
|
|
48
|
-
GLOBAL_PHASE_EXTENSION,
|
|
49
48
|
]
|
|
50
49
|
|
|
51
50
|
|
|
@@ -60,7 +59,7 @@ class ConstWasmModule(val.ExtensionValue):
|
|
|
60
59
|
|
|
61
60
|
name = "ConstWasmModule"
|
|
62
61
|
payload = {"module_filename": self.wasm_file}
|
|
63
|
-
return val.Extension(name, typ=ty, val=payload
|
|
62
|
+
return val.Extension(name, typ=ty, val=payload)
|
|
64
63
|
|
|
65
64
|
def __str__(self) -> str:
|
|
66
65
|
return f"tket.wasm.module(module_filename={self.wasm_file})"
|
|
@@ -3,6 +3,7 @@ from dataclasses import dataclass
|
|
|
3
3
|
from typing import ClassVar, cast
|
|
4
4
|
|
|
5
5
|
from guppylang_internals.ast_util import with_loc
|
|
6
|
+
from guppylang_internals.checker.core import ComptimeVariable
|
|
6
7
|
from guppylang_internals.checker.errors.generic import ExpectedError
|
|
7
8
|
from guppylang_internals.checker.errors.type_errors import WrongNumberOfArgsError
|
|
8
9
|
from guppylang_internals.checker.expr_checker import (
|
|
@@ -14,14 +15,14 @@ from guppylang_internals.definition.custom import CustomCallChecker
|
|
|
14
15
|
from guppylang_internals.definition.ty import TypeDef
|
|
15
16
|
from guppylang_internals.diagnostic import Error
|
|
16
17
|
from guppylang_internals.error import GuppyTypeError
|
|
17
|
-
from guppylang_internals.nodes import StateResultExpr
|
|
18
|
-
from guppylang_internals.std._internal.checker import TAG_MAX_LEN, TooLongError
|
|
18
|
+
from guppylang_internals.nodes import GenericParamValue, PlaceNode, StateResultExpr
|
|
19
19
|
from guppylang_internals.tys.builtin import (
|
|
20
20
|
get_array_length,
|
|
21
21
|
get_element_type,
|
|
22
22
|
is_array_type,
|
|
23
23
|
string_type,
|
|
24
24
|
)
|
|
25
|
+
from guppylang_internals.tys.const import Const, ConstValue
|
|
25
26
|
from guppylang_internals.tys.ty import (
|
|
26
27
|
FuncInput,
|
|
27
28
|
FunctionType,
|
|
@@ -43,12 +44,16 @@ class StateResultChecker(CustomCallChecker):
|
|
|
43
44
|
|
|
44
45
|
def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, Type]:
|
|
45
46
|
tag, _ = ExprChecker(self.ctx).check(args[0], string_type())
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
47
|
+
tag_value: Const
|
|
48
|
+
match tag:
|
|
49
|
+
case ast.Constant(value=str(v)):
|
|
50
|
+
tag_value = ConstValue(string_type(), v)
|
|
51
|
+
case PlaceNode(place=ComptimeVariable(static_value=str(v))):
|
|
52
|
+
tag_value = ConstValue(string_type(), v)
|
|
53
|
+
case GenericParamValue() as param_value:
|
|
54
|
+
tag_value = param_value.param.to_bound().const
|
|
55
|
+
case _:
|
|
56
|
+
raise GuppyTypeError(ExpectedError(tag, "a string literal"))
|
|
52
57
|
syn_args: list[ast.expr] = [tag]
|
|
53
58
|
|
|
54
59
|
if len(args) < 2:
|
|
@@ -90,6 +95,10 @@ class StateResultChecker(CustomCallChecker):
|
|
|
90
95
|
args, ret_ty, inst = synthesize_call(func_ty, syn_args, self.node, self.ctx)
|
|
91
96
|
assert len(inst) == 0, "func_ty is not generic"
|
|
92
97
|
node = StateResultExpr(
|
|
93
|
-
|
|
98
|
+
tag_value=tag_value,
|
|
99
|
+
tag_expr=tag,
|
|
100
|
+
args=args,
|
|
101
|
+
func_ty=func_ty,
|
|
102
|
+
array_len=array_len,
|
|
94
103
|
)
|
|
95
104
|
return with_loc(self.node, node), ret_ty
|
|
@@ -129,7 +129,7 @@ def int_op(
|
|
|
129
129
|
# Ideally we'd be able to derive the arguments from the input/output types,
|
|
130
130
|
# but the amount of variables does not correlate with the signature for the
|
|
131
131
|
# integer ops in hugr :/
|
|
132
|
-
# https://github.com/
|
|
132
|
+
# https://github.com/quantinuum/hugr/blob/bfa13e59468feb0fc746677ea3b3a4341b2ed42e/hugr-core/src/std_extensions/arithmetic/int_ops.rs#L116
|
|
133
133
|
#
|
|
134
134
|
# For now, we just instantiate every type argument to a 64-bit integer.
|
|
135
135
|
args: list[ht.TypeArg] = [int_arg() for _ in range(n_vars)]
|
|
@@ -342,6 +342,10 @@ class GuppyObject(DunderMixin):
|
|
|
342
342
|
if not ty.droppable and not self._used:
|
|
343
343
|
state.unused_undroppable_objs[self._id] = self
|
|
344
344
|
|
|
345
|
+
def __deepcopy__(self, memo: dict[int, Any]) -> "GuppyObject":
|
|
346
|
+
# Dummy deepcopy implementation, we do not want to actually deepcopy
|
|
347
|
+
return self
|
|
348
|
+
|
|
345
349
|
@hide_trace
|
|
346
350
|
def __getattr__(self, key: str) -> Any: # type: ignore[misc]
|
|
347
351
|
# Guppy objects don't have fields (structs are treated separately below), so the
|
|
@@ -539,6 +543,16 @@ class TracingDefMixin(DunderMixin):
|
|
|
539
543
|
|
|
540
544
|
def to_guppy_object(self) -> GuppyObject:
|
|
541
545
|
state = get_tracing_state()
|
|
546
|
+
defn = ENGINE.get_checked(self.id)
|
|
547
|
+
# TODO: For generic functions, we need to know an instantiation for their type
|
|
548
|
+
# parameters. Maybe we should pass them to `to_guppy_object`? Either way, this
|
|
549
|
+
# will require some more plumbing of type inference information through the
|
|
550
|
+
# comptime logic. For now, let's just bail on generic functions.
|
|
551
|
+
# See https://github.com/quantinuum/guppylang/issues/1336
|
|
552
|
+
if isinstance(defn, CallableDef) and defn.ty.parametrized:
|
|
553
|
+
raise GuppyComptimeError(
|
|
554
|
+
f"Cannot infer type parameters of generic function `{defn.name}`"
|
|
555
|
+
)
|
|
542
556
|
defn, [] = state.ctx.build_compiled_def(self.id, type_args=[])
|
|
543
557
|
if isinstance(defn, CompiledValueDef):
|
|
544
558
|
wire = defn.load(state.dfg, state.ctx, state.node)
|