guppylang-internals 0.25.0__py3-none-any.whl → 0.26.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- guppylang_internals/__init__.py +1 -1
- guppylang_internals/cfg/builder.py +17 -2
- guppylang_internals/cfg/cfg.py +3 -0
- guppylang_internals/checker/cfg_checker.py +6 -0
- guppylang_internals/checker/core.py +1 -2
- guppylang_internals/checker/errors/wasm.py +7 -4
- guppylang_internals/checker/expr_checker.py +13 -8
- guppylang_internals/checker/func_checker.py +17 -13
- guppylang_internals/checker/linearity_checker.py +2 -10
- guppylang_internals/checker/modifier_checker.py +6 -2
- guppylang_internals/checker/unitary_checker.py +132 -0
- guppylang_internals/compiler/cfg_compiler.py +7 -6
- guppylang_internals/compiler/core.py +5 -5
- guppylang_internals/compiler/expr_compiler.py +42 -73
- guppylang_internals/compiler/modifier_compiler.py +2 -0
- guppylang_internals/decorator.py +86 -7
- guppylang_internals/definition/custom.py +4 -0
- guppylang_internals/definition/declaration.py +6 -2
- guppylang_internals/definition/function.py +12 -2
- guppylang_internals/definition/pytket_circuits.py +1 -0
- guppylang_internals/definition/struct.py +6 -3
- guppylang_internals/definition/wasm.py +42 -10
- guppylang_internals/engine.py +9 -3
- guppylang_internals/nodes.py +23 -24
- guppylang_internals/std/_internal/checker.py +13 -108
- guppylang_internals/std/_internal/compiler/array.py +1 -1
- guppylang_internals/std/_internal/compiler/list.py +1 -1
- guppylang_internals/std/_internal/compiler/platform.py +153 -0
- guppylang_internals/std/_internal/compiler/prelude.py +12 -4
- guppylang_internals/std/_internal/compiler/tket_exts.py +3 -4
- guppylang_internals/std/_internal/debug.py +18 -9
- guppylang_internals/std/_internal/util.py +1 -1
- guppylang_internals/tracing/object.py +10 -0
- guppylang_internals/tys/errors.py +23 -1
- guppylang_internals/tys/parsing.py +3 -3
- guppylang_internals/tys/printing.py +2 -8
- guppylang_internals/tys/qubit.py +37 -2
- guppylang_internals/tys/ty.py +60 -64
- guppylang_internals/wasm_util.py +129 -0
- {guppylang_internals-0.25.0.dist-info → guppylang_internals-0.26.0.dist-info}/METADATA +4 -3
- {guppylang_internals-0.25.0.dist-info → guppylang_internals-0.26.0.dist-info}/RECORD +43 -40
- {guppylang_internals-0.25.0.dist-info → guppylang_internals-0.26.0.dist-info}/WHEEL +1 -1
- {guppylang_internals-0.25.0.dist-info → guppylang_internals-0.26.0.dist-info}/licenses/LICENCE +0 -0
guppylang_internals/nodes.py
CHANGED
|
@@ -9,7 +9,13 @@ from guppylang_internals.ast_util import AstNode
|
|
|
9
9
|
from guppylang_internals.span import Span, to_span
|
|
10
10
|
from guppylang_internals.tys.const import Const
|
|
11
11
|
from guppylang_internals.tys.subst import Inst
|
|
12
|
-
from guppylang_internals.tys.ty import
|
|
12
|
+
from guppylang_internals.tys.ty import (
|
|
13
|
+
FunctionType,
|
|
14
|
+
StructType,
|
|
15
|
+
TupleType,
|
|
16
|
+
Type,
|
|
17
|
+
UnitaryFlags,
|
|
18
|
+
)
|
|
13
19
|
|
|
14
20
|
if TYPE_CHECKING:
|
|
15
21
|
from guppylang_internals.cfg.cfg import CFG
|
|
@@ -250,22 +256,6 @@ class ComptimeExpr(ast.expr):
|
|
|
250
256
|
_fields = ("value",)
|
|
251
257
|
|
|
252
258
|
|
|
253
|
-
class ResultExpr(ast.expr):
|
|
254
|
-
"""A `result(tag, value)` expression."""
|
|
255
|
-
|
|
256
|
-
value: ast.expr
|
|
257
|
-
base_ty: Type
|
|
258
|
-
#: Array length in case this is an array result, otherwise `None`
|
|
259
|
-
array_len: Const | None
|
|
260
|
-
tag: str
|
|
261
|
-
|
|
262
|
-
_fields = ("value", "base_ty", "array_len", "tag")
|
|
263
|
-
|
|
264
|
-
@property
|
|
265
|
-
def args(self) -> list[ast.expr]:
|
|
266
|
-
return [self.value]
|
|
267
|
-
|
|
268
|
-
|
|
269
259
|
class ExitKind(Enum):
|
|
270
260
|
ExitShot = 0 # Exit the current shot
|
|
271
261
|
Panic = 1 # Panic the program ending all shots
|
|
@@ -275,8 +265,8 @@ class PanicExpr(ast.expr):
|
|
|
275
265
|
"""A `panic(msg, *args)` or `exit(msg, *args)` expression ."""
|
|
276
266
|
|
|
277
267
|
kind: ExitKind
|
|
278
|
-
signal:
|
|
279
|
-
msg:
|
|
268
|
+
signal: ast.expr
|
|
269
|
+
msg: ast.expr
|
|
280
270
|
values: list[ast.expr]
|
|
281
271
|
|
|
282
272
|
_fields = ("kind", "signal", "msg", "values")
|
|
@@ -293,17 +283,16 @@ class BarrierExpr(ast.expr):
|
|
|
293
283
|
class StateResultExpr(ast.expr):
|
|
294
284
|
"""A `state_result(tag, *args)` expression."""
|
|
295
285
|
|
|
296
|
-
|
|
286
|
+
tag_value: Const
|
|
287
|
+
tag_expr: ast.expr
|
|
297
288
|
args: list[ast.expr]
|
|
298
289
|
func_ty: FunctionType
|
|
299
290
|
#: Array length in case this is an array result, otherwise `None`
|
|
300
291
|
array_len: Const | None
|
|
301
|
-
_fields = ("
|
|
292
|
+
_fields = ("tag_value", "tag_expr", "args", "func_ty", "has_array_input")
|
|
302
293
|
|
|
303
294
|
|
|
304
|
-
AnyCall =
|
|
305
|
-
LocalCall | GlobalCall | TensorCall | BarrierExpr | ResultExpr | StateResultExpr
|
|
306
|
-
)
|
|
295
|
+
AnyCall = LocalCall | GlobalCall | TensorCall | BarrierExpr | StateResultExpr
|
|
307
296
|
|
|
308
297
|
|
|
309
298
|
class InoutReturnSentinel(ast.expr):
|
|
@@ -500,6 +489,16 @@ class ModifiedBlock(ast.With):
|
|
|
500
489
|
else:
|
|
501
490
|
raise TypeError(f"Unknown modifier: {modifier}")
|
|
502
491
|
|
|
492
|
+
def flags(self) -> UnitaryFlags:
|
|
493
|
+
flags = UnitaryFlags.NoFlags
|
|
494
|
+
if self.is_dagger():
|
|
495
|
+
flags |= UnitaryFlags.Dagger
|
|
496
|
+
if self.is_control():
|
|
497
|
+
flags |= UnitaryFlags.Control
|
|
498
|
+
if self.is_power():
|
|
499
|
+
flags |= UnitaryFlags.Power
|
|
500
|
+
return flags
|
|
501
|
+
|
|
503
502
|
|
|
504
503
|
class CheckedModifiedBlock(ast.With):
|
|
505
504
|
def_id: "DefId"
|
|
@@ -4,9 +4,9 @@ from typing import ClassVar
|
|
|
4
4
|
|
|
5
5
|
from typing_extensions import assert_never
|
|
6
6
|
|
|
7
|
-
from guppylang_internals.ast_util import get_type, with_loc
|
|
8
|
-
from guppylang_internals.checker.core import
|
|
9
|
-
from guppylang_internals.checker.errors.generic import
|
|
7
|
+
from guppylang_internals.ast_util import get_type, with_loc, with_type
|
|
8
|
+
from guppylang_internals.checker.core import Context
|
|
9
|
+
from guppylang_internals.checker.errors.generic import UnsupportedError
|
|
10
10
|
from guppylang_internals.checker.errors.type_errors import (
|
|
11
11
|
ArrayComprUnknownSizeError,
|
|
12
12
|
TypeMismatchError,
|
|
@@ -33,8 +33,6 @@ from guppylang_internals.nodes import (
|
|
|
33
33
|
GlobalCall,
|
|
34
34
|
MakeIter,
|
|
35
35
|
PanicExpr,
|
|
36
|
-
PlaceNode,
|
|
37
|
-
ResultExpr,
|
|
38
36
|
)
|
|
39
37
|
from guppylang_internals.tys.arg import ConstArg, TypeArg
|
|
40
38
|
from guppylang_internals.tys.builtin import (
|
|
@@ -45,7 +43,6 @@ from guppylang_internals.tys.builtin import (
|
|
|
45
43
|
get_iter_size,
|
|
46
44
|
int_type,
|
|
47
45
|
is_array_type,
|
|
48
|
-
is_bool_type,
|
|
49
46
|
is_sized_iter_type,
|
|
50
47
|
nat_type,
|
|
51
48
|
sized_iter_type,
|
|
@@ -58,7 +55,6 @@ from guppylang_internals.tys.ty import (
|
|
|
58
55
|
FunctionType,
|
|
59
56
|
InputFlags,
|
|
60
57
|
NoneType,
|
|
61
|
-
NumericType,
|
|
62
58
|
Type,
|
|
63
59
|
unify,
|
|
64
60
|
)
|
|
@@ -302,80 +298,6 @@ class NewArrayChecker(CustomCallChecker):
|
|
|
302
298
|
return with_loc(compr, array_compr), array_type(elt_ty, size)
|
|
303
299
|
|
|
304
300
|
|
|
305
|
-
#: Maximum length of a tag in the `result` function.
|
|
306
|
-
TAG_MAX_LEN = 200
|
|
307
|
-
|
|
308
|
-
|
|
309
|
-
@dataclass(frozen=True)
|
|
310
|
-
class TooLongError(Error):
|
|
311
|
-
title: ClassVar[str] = "Tag too long"
|
|
312
|
-
span_label: ClassVar[str] = "Result tag is too long"
|
|
313
|
-
|
|
314
|
-
@dataclass(frozen=True)
|
|
315
|
-
class Hint(Note):
|
|
316
|
-
message: ClassVar[str] = f"Result tags are limited to {TAG_MAX_LEN} bytes"
|
|
317
|
-
|
|
318
|
-
|
|
319
|
-
class ResultChecker(CustomCallChecker):
|
|
320
|
-
"""Call checker for the `result` function."""
|
|
321
|
-
|
|
322
|
-
@dataclass(frozen=True)
|
|
323
|
-
class InvalidError(Error):
|
|
324
|
-
title: ClassVar[str] = "Invalid Result"
|
|
325
|
-
span_label: ClassVar[str] = "Expression of type `{ty}` is not a valid result."
|
|
326
|
-
ty: Type
|
|
327
|
-
|
|
328
|
-
@dataclass(frozen=True)
|
|
329
|
-
class Explanation(Note):
|
|
330
|
-
message: ClassVar[str] = (
|
|
331
|
-
"Only numeric values or arrays thereof are allowed as results"
|
|
332
|
-
)
|
|
333
|
-
|
|
334
|
-
def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, Type]:
|
|
335
|
-
check_num_args(2, len(args), self.node)
|
|
336
|
-
[tag, value] = args
|
|
337
|
-
tag, _ = ExprChecker(self.ctx).check(tag, string_type())
|
|
338
|
-
match tag:
|
|
339
|
-
case ast.Constant(value=str(v)):
|
|
340
|
-
tag_value = v
|
|
341
|
-
case PlaceNode(place=ComptimeVariable(static_value=str(v))):
|
|
342
|
-
tag_value = v
|
|
343
|
-
case _:
|
|
344
|
-
raise GuppyTypeError(ExpectedError(tag, "a string literal"))
|
|
345
|
-
if len(tag_value.encode("utf-8")) > TAG_MAX_LEN:
|
|
346
|
-
err: Error = TooLongError(tag)
|
|
347
|
-
err.add_sub_diagnostic(TooLongError.Hint(None))
|
|
348
|
-
raise GuppyTypeError(err)
|
|
349
|
-
value, ty = ExprSynthesizer(self.ctx).synthesize(value)
|
|
350
|
-
# We only allow numeric values or vectors of numeric values
|
|
351
|
-
err = ResultChecker.InvalidError(value, ty)
|
|
352
|
-
err.add_sub_diagnostic(ResultChecker.InvalidError.Explanation(None))
|
|
353
|
-
if self._is_numeric_or_bool_type(ty):
|
|
354
|
-
base_ty = ty
|
|
355
|
-
array_len: Const | None = None
|
|
356
|
-
elif is_array_type(ty):
|
|
357
|
-
[ty_arg, len_arg] = ty.args
|
|
358
|
-
assert isinstance(ty_arg, TypeArg)
|
|
359
|
-
assert isinstance(len_arg, ConstArg)
|
|
360
|
-
if not self._is_numeric_or_bool_type(ty_arg.ty):
|
|
361
|
-
raise GuppyError(err)
|
|
362
|
-
base_ty = ty_arg.ty
|
|
363
|
-
array_len = len_arg.const
|
|
364
|
-
else:
|
|
365
|
-
raise GuppyError(err)
|
|
366
|
-
node = ResultExpr(value, base_ty, array_len, tag_value)
|
|
367
|
-
return with_loc(self.node, node), NoneType()
|
|
368
|
-
|
|
369
|
-
def check(self, args: list[ast.expr], ty: Type) -> tuple[ast.expr, Subst]:
|
|
370
|
-
expr, res_ty = self.synthesize(args)
|
|
371
|
-
expr, subst, _ = check_type_against(res_ty, ty, expr, self.ctx)
|
|
372
|
-
return expr, subst
|
|
373
|
-
|
|
374
|
-
@staticmethod
|
|
375
|
-
def _is_numeric_or_bool_type(ty: Type) -> bool:
|
|
376
|
-
return isinstance(ty, NumericType) or is_bool_type(ty)
|
|
377
|
-
|
|
378
|
-
|
|
379
301
|
class PanicChecker(CustomCallChecker):
|
|
380
302
|
"""Call checker for the `panic` function."""
|
|
381
303
|
|
|
@@ -395,18 +317,16 @@ class PanicChecker(CustomCallChecker):
|
|
|
395
317
|
err.add_sub_diagnostic(PanicChecker.NoMessageError.Suggestion(None))
|
|
396
318
|
raise GuppyTypeError(err)
|
|
397
319
|
case [msg, *rest]:
|
|
320
|
+
# Check type of message and synthesize types for additional values.
|
|
398
321
|
msg, _ = ExprChecker(self.ctx).check(msg, string_type())
|
|
399
|
-
match msg:
|
|
400
|
-
case ast.Constant(value=str(v)):
|
|
401
|
-
msg_value = v
|
|
402
|
-
case PlaceNode(place=ComptimeVariable(static_value=str(v))):
|
|
403
|
-
msg_value = v
|
|
404
|
-
case _:
|
|
405
|
-
raise GuppyTypeError(ExpectedError(msg, "a string literal"))
|
|
406
322
|
vals = [ExprSynthesizer(self.ctx).synthesize(val)[0] for val in rest]
|
|
407
323
|
# TODO variable signals once default arguments are available
|
|
324
|
+
# TODO this will also allow us to remove this manual AST node hack
|
|
325
|
+
signal_expr = with_type(
|
|
326
|
+
int_type(), with_loc(self.node, ast.Constant(value=1))
|
|
327
|
+
)
|
|
408
328
|
node = PanicExpr(
|
|
409
|
-
kind=ExitKind.Panic, msg=
|
|
329
|
+
kind=ExitKind.Panic, msg=msg, values=vals, signal=signal_expr
|
|
410
330
|
)
|
|
411
331
|
return with_loc(self.node, node), NoneType()
|
|
412
332
|
case args:
|
|
@@ -454,31 +374,16 @@ class ExitChecker(CustomCallChecker):
|
|
|
454
374
|
)
|
|
455
375
|
raise GuppyTypeError(signal_err)
|
|
456
376
|
case [msg, signal, *rest]:
|
|
377
|
+
# Check types for message and signal and synthesize types for additional
|
|
378
|
+
# values.
|
|
457
379
|
msg, _ = ExprChecker(self.ctx).check(msg, string_type())
|
|
458
|
-
match msg:
|
|
459
|
-
case ast.Constant(value=str(v)):
|
|
460
|
-
msg_value = v
|
|
461
|
-
case PlaceNode(place=ComptimeVariable(static_value=str(v))):
|
|
462
|
-
msg_value = v
|
|
463
|
-
case _:
|
|
464
|
-
raise GuppyTypeError(ExpectedError(msg, "a string literal"))
|
|
465
|
-
# TODO allow variable signals after https://github.com/CQCL/hugr/issues/1863
|
|
466
380
|
signal, _ = ExprChecker(self.ctx).check(signal, int_type())
|
|
467
|
-
match signal:
|
|
468
|
-
case ast.Constant(value=int(s)):
|
|
469
|
-
signal_value = s
|
|
470
|
-
case PlaceNode(place=ComptimeVariable(static_value=int(s))):
|
|
471
|
-
signal_value = s
|
|
472
|
-
case _:
|
|
473
|
-
raise GuppyTypeError(
|
|
474
|
-
ExpectedError(signal, "an integer literal")
|
|
475
|
-
)
|
|
476
381
|
vals = [ExprSynthesizer(self.ctx).synthesize(val)[0] for val in rest]
|
|
477
382
|
node = PanicExpr(
|
|
478
383
|
kind=ExitKind.ExitShot,
|
|
479
|
-
msg=
|
|
384
|
+
msg=msg,
|
|
480
385
|
values=vals,
|
|
481
|
-
signal=
|
|
386
|
+
signal=signal,
|
|
482
387
|
)
|
|
483
388
|
return with_loc(self.node, node), NoneType()
|
|
484
389
|
case args:
|
|
@@ -261,7 +261,7 @@ class NewArrayCompiler(ArrayCompiler):
|
|
|
261
261
|
|
|
262
262
|
def build_classical_array(self, elems: list[Wire]) -> Wire:
|
|
263
263
|
"""Lowers a call to `array.__new__` for classical arrays."""
|
|
264
|
-
# See https://github.com/
|
|
264
|
+
# See https://github.com/quantinuum/guppylang/issues/629
|
|
265
265
|
return self.build_linear_array(elems)
|
|
266
266
|
|
|
267
267
|
def build_linear_array(self, elems: list[Wire]) -> Wire:
|
|
@@ -328,7 +328,7 @@ def _list_new_classical(
|
|
|
328
328
|
builder: DfBase[ops.DfParentOp], elem_type: ht.Type, args: list[Wire]
|
|
329
329
|
) -> Wire:
|
|
330
330
|
# This may be simplified in the future with a `new` or `with_capacity` list op
|
|
331
|
-
# See https://github.com/
|
|
331
|
+
# See https://github.com/quantinuum/hugr/issues/1508
|
|
332
332
|
lst = builder.load(ListVal([], elem_ty=elem_type))
|
|
333
333
|
push_op = list_push(elem_type)
|
|
334
334
|
for elem in args:
|
|
@@ -0,0 +1,153 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
from typing import ClassVar
|
|
3
|
+
|
|
4
|
+
import hugr
|
|
5
|
+
from hugr import Wire, ops, tys
|
|
6
|
+
|
|
7
|
+
from guppylang_internals.ast_util import AstNode
|
|
8
|
+
from guppylang_internals.compiler.core import CompilerContext
|
|
9
|
+
from guppylang_internals.compiler.expr_compiler import array_read_bool
|
|
10
|
+
from guppylang_internals.definition.custom import (
|
|
11
|
+
CustomCallCompiler,
|
|
12
|
+
CustomInoutCallCompiler,
|
|
13
|
+
)
|
|
14
|
+
from guppylang_internals.definition.value import CallReturnWires
|
|
15
|
+
from guppylang_internals.diagnostic import Error, Note
|
|
16
|
+
from guppylang_internals.error import GuppyError, InternalGuppyError
|
|
17
|
+
from guppylang_internals.std._internal.compiler.array import (
|
|
18
|
+
array_clone,
|
|
19
|
+
array_map,
|
|
20
|
+
array_to_std_array,
|
|
21
|
+
)
|
|
22
|
+
from guppylang_internals.std._internal.compiler.tket_bool import OpaqueBool, read_bool
|
|
23
|
+
from guppylang_internals.std._internal.compiler.tket_exts import RESULT_EXTENSION
|
|
24
|
+
from guppylang_internals.tys.arg import Argument, ConstArg
|
|
25
|
+
from guppylang_internals.tys.builtin import get_element_type, is_bool_type
|
|
26
|
+
from guppylang_internals.tys.const import BoundConstVar, ConstValue
|
|
27
|
+
from guppylang_internals.tys.ty import NumericType
|
|
28
|
+
|
|
29
|
+
#: Maximum length of a tag in the `result` function.
|
|
30
|
+
TAG_MAX_LEN = 200
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
@dataclass(frozen=True)
|
|
34
|
+
class TooLongError(Error):
|
|
35
|
+
title: ClassVar[str] = "Tag too long"
|
|
36
|
+
span_label: ClassVar[str] = "Result tag is too long"
|
|
37
|
+
|
|
38
|
+
@dataclass(frozen=True)
|
|
39
|
+
class Hint(Note):
|
|
40
|
+
message: ClassVar[str] = f"Result tags are limited to {TAG_MAX_LEN} bytes"
|
|
41
|
+
|
|
42
|
+
@dataclass(frozen=True)
|
|
43
|
+
class GenericHint(Note):
|
|
44
|
+
message: ClassVar[str] = "Parameter `{param}` was instantiated to `{value}`"
|
|
45
|
+
param: str
|
|
46
|
+
value: str
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
class ResultCompiler(CustomCallCompiler):
|
|
50
|
+
"""Custom compiler for overloads of the `result` function.
|
|
51
|
+
|
|
52
|
+
See `ArrayResultCompiler` for the compiler that handles results involving arrays.
|
|
53
|
+
"""
|
|
54
|
+
|
|
55
|
+
def __init__(self, op_name: str, with_int_width: bool = False):
|
|
56
|
+
self.op_name = op_name
|
|
57
|
+
self.with_int_width = with_int_width
|
|
58
|
+
|
|
59
|
+
def compile(self, args: list[Wire]) -> list[Wire]:
|
|
60
|
+
assert self.func is not None
|
|
61
|
+
[value] = args
|
|
62
|
+
ty = self.func.ty.inputs[1].ty
|
|
63
|
+
hugr_ty = ty.to_hugr(self.ctx)
|
|
64
|
+
args = [tag_to_hugr(self.type_args[0], self.ctx, self.node)]
|
|
65
|
+
if self.with_int_width:
|
|
66
|
+
args.append(tys.BoundedNatArg(NumericType.INT_WIDTH))
|
|
67
|
+
# Bool results need an extra conversion into regular hugr bools
|
|
68
|
+
if is_bool_type(ty):
|
|
69
|
+
value = self.builder.add_op(read_bool(), value)
|
|
70
|
+
hugr_ty = tys.Bool
|
|
71
|
+
op = RESULT_EXTENSION.get_op(self.op_name)
|
|
72
|
+
sig = tys.FunctionType(input=[hugr_ty], output=[])
|
|
73
|
+
self.builder.add_op(op.instantiate(args, sig), value)
|
|
74
|
+
return []
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
class ArrayResultCompiler(CustomInoutCallCompiler):
|
|
78
|
+
"""Custom compiler for overloads of the `result` function accepting arrays.
|
|
79
|
+
|
|
80
|
+
See `ResultCompiler` for the compiler that handles basic results.
|
|
81
|
+
"""
|
|
82
|
+
|
|
83
|
+
def __init__(self, op_name: str, with_int_width: bool = False):
|
|
84
|
+
self.op_name = op_name
|
|
85
|
+
self.with_int_width = with_int_width
|
|
86
|
+
|
|
87
|
+
def compile_with_inouts(self, args: list[Wire]) -> CallReturnWires:
|
|
88
|
+
assert self.func is not None
|
|
89
|
+
array_ty = self.func.ty.inputs[1].ty
|
|
90
|
+
elem_ty = get_element_type(array_ty)
|
|
91
|
+
[tag_arg, size_arg] = self.type_args
|
|
92
|
+
[arr] = args
|
|
93
|
+
|
|
94
|
+
# As `borrow_array`s used by Guppy are linear, we need to clone it (knowing
|
|
95
|
+
# that all elements in it are copyable) to avoid linearity violations when
|
|
96
|
+
# both passing it to the result operation and returning it (as an inout
|
|
97
|
+
# argument).
|
|
98
|
+
hugr_elem_ty = elem_ty.to_hugr(self.ctx)
|
|
99
|
+
hugr_size = size_arg.to_hugr(self.ctx)
|
|
100
|
+
arr, out_arr = self.builder.add_op(array_clone(hugr_elem_ty, hugr_size), arr)
|
|
101
|
+
# For bool arrays, we furthermore need to coerce a read on all the array
|
|
102
|
+
# elements
|
|
103
|
+
if is_bool_type(elem_ty):
|
|
104
|
+
array_read = array_read_bool(self.ctx)
|
|
105
|
+
array_read = self.builder.load_function(array_read)
|
|
106
|
+
map_op = array_map(OpaqueBool, hugr_size, tys.Bool)
|
|
107
|
+
arr = self.builder.add_op(map_op, arr, array_read).out(0)
|
|
108
|
+
hugr_elem_ty = tys.Bool
|
|
109
|
+
# Turn `borrow_array` into regular `array`
|
|
110
|
+
arr = self.builder.add_op(array_to_std_array(hugr_elem_ty, hugr_size), arr).out(
|
|
111
|
+
0
|
|
112
|
+
)
|
|
113
|
+
|
|
114
|
+
hugr_ty = hugr.std.collections.array.Array(hugr_elem_ty, hugr_size)
|
|
115
|
+
sig = tys.FunctionType(input=[hugr_ty], output=[])
|
|
116
|
+
args = [tag_to_hugr(tag_arg, self.ctx, self.node), hugr_size]
|
|
117
|
+
if self.with_int_width:
|
|
118
|
+
args.append(tys.BoundedNatArg(NumericType.INT_WIDTH))
|
|
119
|
+
op = ops.ExtOp(RESULT_EXTENSION.get_op(self.op_name), signature=sig, args=args)
|
|
120
|
+
self.builder.add_op(op, arr)
|
|
121
|
+
return CallReturnWires([], [out_arr])
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
def tag_to_hugr(tag_arg: Argument, ctx: CompilerContext, loc: AstNode) -> tys.TypeArg:
|
|
125
|
+
"""Helper function to convert the Guppy tag comptime argument into a Hugr type arg.
|
|
126
|
+
|
|
127
|
+
Takes care of reading the tag value from the current monomorphization and checks
|
|
128
|
+
that the tag fits into `TAG_MAX_LEN`.
|
|
129
|
+
"""
|
|
130
|
+
is_generic: BoundConstVar | None = None
|
|
131
|
+
match tag_arg:
|
|
132
|
+
case ConstArg(const=ConstValue(value=str(value))):
|
|
133
|
+
tag = value
|
|
134
|
+
case ConstArg(const=BoundConstVar(idx=idx) as var):
|
|
135
|
+
is_generic = var
|
|
136
|
+
assert ctx.current_mono_args is not None
|
|
137
|
+
match ctx.current_mono_args[idx]:
|
|
138
|
+
case ConstArg(const=ConstValue(value=str(value))):
|
|
139
|
+
tag = value
|
|
140
|
+
case _:
|
|
141
|
+
raise InternalGuppyError("Invalid tag monomorphization")
|
|
142
|
+
case _:
|
|
143
|
+
raise InternalGuppyError("Invalid tag argument")
|
|
144
|
+
|
|
145
|
+
if len(tag.encode("utf-8")) > TAG_MAX_LEN:
|
|
146
|
+
err = TooLongError(loc)
|
|
147
|
+
err.add_sub_diagnostic(TooLongError.Hint(None))
|
|
148
|
+
if is_generic:
|
|
149
|
+
err.add_sub_diagnostic(
|
|
150
|
+
TooLongError.GenericHint(None, is_generic.display_name, tag)
|
|
151
|
+
)
|
|
152
|
+
raise GuppyError(err)
|
|
153
|
+
return tys.StringArg(tag)
|
|
@@ -73,6 +73,14 @@ def panic(
|
|
|
73
73
|
return ops.ExtOp(op_def, sig, args)
|
|
74
74
|
|
|
75
75
|
|
|
76
|
+
def make_error() -> ops.ExtOp:
|
|
77
|
+
"""Returns an operation that makes an error."""
|
|
78
|
+
op_def = hugr.std.PRELUDE.get_op("MakeError")
|
|
79
|
+
args: list[ht.TypeArg] = []
|
|
80
|
+
sig = ht.FunctionType([ht.USize(), hugr.std.prelude.STRING_T], [error_type()])
|
|
81
|
+
return ops.ExtOp(op_def, sig, args)
|
|
82
|
+
|
|
83
|
+
|
|
76
84
|
# ------------------------------------------------------
|
|
77
85
|
# --------- Custom compilers for non-native ops --------
|
|
78
86
|
# ------------------------------------------------------
|
|
@@ -90,14 +98,14 @@ def build_panic(
|
|
|
90
98
|
return builder.add_op(op, err, *args)
|
|
91
99
|
|
|
92
100
|
|
|
93
|
-
def
|
|
101
|
+
def build_static_error(builder: DfBase[P], signal: int, msg: str) -> Wire:
|
|
94
102
|
"""Constructs and loads a static error value."""
|
|
95
103
|
val = ErrorVal(signal, msg)
|
|
96
104
|
return builder.load(builder.add_const(val))
|
|
97
105
|
|
|
98
106
|
|
|
99
107
|
# TODO: Common up build_unwrap_right and build_unwrap_left below once
|
|
100
|
-
# https://github.com/
|
|
108
|
+
# https://github.com/quantinuum/hugr/issues/1596 is fixed
|
|
101
109
|
|
|
102
110
|
|
|
103
111
|
def build_unwrap_right(
|
|
@@ -111,7 +119,7 @@ def build_unwrap_right(
|
|
|
111
119
|
assert isinstance(result_ty, ht.Sum)
|
|
112
120
|
[left_tys, right_tys] = result_ty.variant_rows
|
|
113
121
|
with conditional.add_case(0) as case:
|
|
114
|
-
error =
|
|
122
|
+
error = build_static_error(case, error_signal, error_msg)
|
|
115
123
|
case.set_outputs(*build_panic(case, left_tys, right_tys, error, *case.inputs()))
|
|
116
124
|
with conditional.add_case(1) as case:
|
|
117
125
|
case.set_outputs(*case.inputs())
|
|
@@ -134,7 +142,7 @@ def build_unwrap_left(
|
|
|
134
142
|
with conditional.add_case(0) as case:
|
|
135
143
|
case.set_outputs(*case.inputs())
|
|
136
144
|
with conditional.add_case(1) as case:
|
|
137
|
-
error =
|
|
145
|
+
error = build_static_error(case, error_signal, error_msg)
|
|
138
146
|
case.set_outputs(*build_panic(case, right_tys, left_tys, error, *case.inputs()))
|
|
139
147
|
return conditional.to_node()
|
|
140
148
|
|
|
@@ -20,6 +20,7 @@ from tket_exts import (
|
|
|
20
20
|
BOOL_EXTENSION = tket_exts.bool()
|
|
21
21
|
DEBUG_EXTENSION = debug()
|
|
22
22
|
FUTURES_EXTENSION = futures()
|
|
23
|
+
GLOBAL_PHASE_EXTENSION = global_phase()
|
|
23
24
|
GUPPY_EXTENSION = guppy()
|
|
24
25
|
MODIFIER_EXTENSION = modifier()
|
|
25
26
|
QSYSTEM_EXTENSION = qsystem()
|
|
@@ -29,14 +30,14 @@ QUANTUM_EXTENSION = quantum()
|
|
|
29
30
|
RESULT_EXTENSION = result()
|
|
30
31
|
ROTATION_EXTENSION = rotation()
|
|
31
32
|
WASM_EXTENSION = wasm()
|
|
32
|
-
MODIFIER_EXTENSION = modifier()
|
|
33
|
-
GLOBAL_PHASE_EXTENSION = global_phase()
|
|
34
33
|
|
|
35
34
|
TKET_EXTENSIONS = [
|
|
36
35
|
BOOL_EXTENSION,
|
|
37
36
|
DEBUG_EXTENSION,
|
|
38
37
|
FUTURES_EXTENSION,
|
|
38
|
+
GLOBAL_PHASE_EXTENSION,
|
|
39
39
|
GUPPY_EXTENSION,
|
|
40
|
+
MODIFIER_EXTENSION,
|
|
40
41
|
QSYSTEM_EXTENSION,
|
|
41
42
|
QSYSTEM_RANDOM_EXTENSION,
|
|
42
43
|
QSYSTEM_UTILS_EXTENSION,
|
|
@@ -44,8 +45,6 @@ TKET_EXTENSIONS = [
|
|
|
44
45
|
RESULT_EXTENSION,
|
|
45
46
|
ROTATION_EXTENSION,
|
|
46
47
|
WASM_EXTENSION,
|
|
47
|
-
MODIFIER_EXTENSION,
|
|
48
|
-
GLOBAL_PHASE_EXTENSION,
|
|
49
48
|
]
|
|
50
49
|
|
|
51
50
|
|
|
@@ -3,6 +3,7 @@ from dataclasses import dataclass
|
|
|
3
3
|
from typing import ClassVar, cast
|
|
4
4
|
|
|
5
5
|
from guppylang_internals.ast_util import with_loc
|
|
6
|
+
from guppylang_internals.checker.core import ComptimeVariable
|
|
6
7
|
from guppylang_internals.checker.errors.generic import ExpectedError
|
|
7
8
|
from guppylang_internals.checker.errors.type_errors import WrongNumberOfArgsError
|
|
8
9
|
from guppylang_internals.checker.expr_checker import (
|
|
@@ -14,14 +15,14 @@ from guppylang_internals.definition.custom import CustomCallChecker
|
|
|
14
15
|
from guppylang_internals.definition.ty import TypeDef
|
|
15
16
|
from guppylang_internals.diagnostic import Error
|
|
16
17
|
from guppylang_internals.error import GuppyTypeError
|
|
17
|
-
from guppylang_internals.nodes import StateResultExpr
|
|
18
|
-
from guppylang_internals.std._internal.checker import TAG_MAX_LEN, TooLongError
|
|
18
|
+
from guppylang_internals.nodes import GenericParamValue, PlaceNode, StateResultExpr
|
|
19
19
|
from guppylang_internals.tys.builtin import (
|
|
20
20
|
get_array_length,
|
|
21
21
|
get_element_type,
|
|
22
22
|
is_array_type,
|
|
23
23
|
string_type,
|
|
24
24
|
)
|
|
25
|
+
from guppylang_internals.tys.const import Const, ConstValue
|
|
25
26
|
from guppylang_internals.tys.ty import (
|
|
26
27
|
FuncInput,
|
|
27
28
|
FunctionType,
|
|
@@ -43,12 +44,16 @@ class StateResultChecker(CustomCallChecker):
|
|
|
43
44
|
|
|
44
45
|
def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, Type]:
|
|
45
46
|
tag, _ = ExprChecker(self.ctx).check(args[0], string_type())
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
47
|
+
tag_value: Const
|
|
48
|
+
match tag:
|
|
49
|
+
case ast.Constant(value=str(v)):
|
|
50
|
+
tag_value = ConstValue(string_type(), v)
|
|
51
|
+
case PlaceNode(place=ComptimeVariable(static_value=str(v))):
|
|
52
|
+
tag_value = ConstValue(string_type(), v)
|
|
53
|
+
case GenericParamValue() as param_value:
|
|
54
|
+
tag_value = param_value.param.to_bound().const
|
|
55
|
+
case _:
|
|
56
|
+
raise GuppyTypeError(ExpectedError(tag, "a string literal"))
|
|
52
57
|
syn_args: list[ast.expr] = [tag]
|
|
53
58
|
|
|
54
59
|
if len(args) < 2:
|
|
@@ -90,6 +95,10 @@ class StateResultChecker(CustomCallChecker):
|
|
|
90
95
|
args, ret_ty, inst = synthesize_call(func_ty, syn_args, self.node, self.ctx)
|
|
91
96
|
assert len(inst) == 0, "func_ty is not generic"
|
|
92
97
|
node = StateResultExpr(
|
|
93
|
-
|
|
98
|
+
tag_value=tag_value,
|
|
99
|
+
tag_expr=tag,
|
|
100
|
+
args=args,
|
|
101
|
+
func_ty=func_ty,
|
|
102
|
+
array_len=array_len,
|
|
94
103
|
)
|
|
95
104
|
return with_loc(self.node, node), ret_ty
|
|
@@ -129,7 +129,7 @@ def int_op(
|
|
|
129
129
|
# Ideally we'd be able to derive the arguments from the input/output types,
|
|
130
130
|
# but the amount of variables does not correlate with the signature for the
|
|
131
131
|
# integer ops in hugr :/
|
|
132
|
-
# https://github.com/
|
|
132
|
+
# https://github.com/quantinuum/hugr/blob/bfa13e59468feb0fc746677ea3b3a4341b2ed42e/hugr-core/src/std_extensions/arithmetic/int_ops.rs#L116
|
|
133
133
|
#
|
|
134
134
|
# For now, we just instantiate every type argument to a 64-bit integer.
|
|
135
135
|
args: list[ht.TypeArg] = [int_arg() for _ in range(n_vars)]
|
|
@@ -539,6 +539,16 @@ class TracingDefMixin(DunderMixin):
|
|
|
539
539
|
|
|
540
540
|
def to_guppy_object(self) -> GuppyObject:
|
|
541
541
|
state = get_tracing_state()
|
|
542
|
+
defn = ENGINE.get_checked(self.id)
|
|
543
|
+
# TODO: For generic functions, we need to know an instantiation for their type
|
|
544
|
+
# parameters. Maybe we should pass them to `to_guppy_object`? Either way, this
|
|
545
|
+
# will require some more plumbing of type inference information through the
|
|
546
|
+
# comptime logic. For now, let's just bail on generic functions.
|
|
547
|
+
# See https://github.com/quantinuum/guppylang/issues/1336
|
|
548
|
+
if isinstance(defn, CallableDef) and defn.ty.parametrized:
|
|
549
|
+
raise GuppyComptimeError(
|
|
550
|
+
f"Cannot infer type parameters of generic function `{defn.name}`"
|
|
551
|
+
)
|
|
542
552
|
defn, [] = state.ctx.build_compiled_def(self.id, type_args=[])
|
|
543
553
|
if isinstance(defn, CompiledValueDef):
|
|
544
554
|
wire = defn.load(state.dfg, state.ctx, state.node)
|
|
@@ -5,7 +5,7 @@ from guppylang_internals.diagnostic import Error, Help, Note
|
|
|
5
5
|
|
|
6
6
|
if TYPE_CHECKING:
|
|
7
7
|
from guppylang_internals.definition.parameter import ParamDef
|
|
8
|
-
from guppylang_internals.tys.ty import Type
|
|
8
|
+
from guppylang_internals.tys.ty import Type, UnitaryFlags
|
|
9
9
|
|
|
10
10
|
|
|
11
11
|
@dataclass(frozen=True)
|
|
@@ -182,3 +182,25 @@ class InvalidFlagError(Error):
|
|
|
182
182
|
class FlagNotAllowedError(Error):
|
|
183
183
|
title: ClassVar[str] = "Invalid annotation"
|
|
184
184
|
span_label: ClassVar[str] = "`@` type annotations are not allowed in this position"
|
|
185
|
+
|
|
186
|
+
|
|
187
|
+
@dataclass(frozen=True)
|
|
188
|
+
class UnitaryCallError(Error):
|
|
189
|
+
title: ClassVar[str] = "Unitary constraint violation"
|
|
190
|
+
span_label: ClassVar[str] = (
|
|
191
|
+
"This function cannot be called in a {render_flags} context"
|
|
192
|
+
)
|
|
193
|
+
flags: "UnitaryFlags"
|
|
194
|
+
|
|
195
|
+
@property
|
|
196
|
+
def render_flags(self) -> str:
|
|
197
|
+
from guppylang_internals.tys.ty import UnitaryFlags
|
|
198
|
+
|
|
199
|
+
if self.flags == UnitaryFlags.Dagger:
|
|
200
|
+
return "dagger"
|
|
201
|
+
elif self.flags == UnitaryFlags.Control:
|
|
202
|
+
return "control"
|
|
203
|
+
elif self.flags == UnitaryFlags.Power:
|
|
204
|
+
return "power"
|
|
205
|
+
else:
|
|
206
|
+
return "unitary"
|
|
@@ -107,7 +107,7 @@ def arg_from_ast(node: AstNode, ctx: TypeParsingCtx) -> Argument:
|
|
|
107
107
|
return ConstArg(ConstValue(bool_type(), v))
|
|
108
108
|
# Integer literals are turned into nat args.
|
|
109
109
|
# TODO: To support int args, we need proper inference logic here
|
|
110
|
-
# See https://github.com/
|
|
110
|
+
# See https://github.com/quantinuum/guppylang/issues/1030
|
|
111
111
|
case int(v) if v >= 0:
|
|
112
112
|
nat_ty = NumericType(NumericType.Kind.Nat)
|
|
113
113
|
return ConstArg(ConstValue(nat_ty, v))
|
|
@@ -117,7 +117,7 @@ def arg_from_ast(node: AstNode, ctx: TypeParsingCtx) -> Argument:
|
|
|
117
117
|
# String literals are ignored for now since they could also be stringified
|
|
118
118
|
# types.
|
|
119
119
|
# TODO: To support string args, we need proper inference logic here
|
|
120
|
-
# See https://github.com/
|
|
120
|
+
# See https://github.com/quantinuum/guppylang/issues/1030
|
|
121
121
|
case str(_):
|
|
122
122
|
pass
|
|
123
123
|
|
|
@@ -289,7 +289,7 @@ def check_function_arg(
|
|
|
289
289
|
ctx.param_var_mapping[name] = ConstParam(
|
|
290
290
|
len(ctx.param_var_mapping), name, ty, from_comptime_arg=True
|
|
291
291
|
)
|
|
292
|
-
return FuncInput(ty, flags)
|
|
292
|
+
return FuncInput(ty, flags, name)
|
|
293
293
|
|
|
294
294
|
|
|
295
295
|
if sys.version_info >= (3, 12):
|