guppylang-internals 0.24.0__py3-none-any.whl → 0.26.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- guppylang_internals/__init__.py +1 -1
- guppylang_internals/ast_util.py +21 -0
- guppylang_internals/cfg/bb.py +20 -0
- guppylang_internals/cfg/builder.py +118 -5
- guppylang_internals/cfg/cfg.py +3 -0
- guppylang_internals/checker/cfg_checker.py +6 -0
- guppylang_internals/checker/core.py +5 -2
- guppylang_internals/checker/errors/generic.py +32 -1
- guppylang_internals/checker/errors/type_errors.py +14 -0
- guppylang_internals/checker/errors/wasm.py +7 -4
- guppylang_internals/checker/expr_checker.py +58 -17
- guppylang_internals/checker/func_checker.py +18 -14
- guppylang_internals/checker/linearity_checker.py +67 -10
- guppylang_internals/checker/modifier_checker.py +120 -0
- guppylang_internals/checker/stmt_checker.py +48 -1
- guppylang_internals/checker/unitary_checker.py +132 -0
- guppylang_internals/compiler/cfg_compiler.py +7 -6
- guppylang_internals/compiler/core.py +93 -56
- guppylang_internals/compiler/expr_compiler.py +72 -168
- guppylang_internals/compiler/modifier_compiler.py +176 -0
- guppylang_internals/compiler/stmt_compiler.py +15 -8
- guppylang_internals/decorator.py +86 -7
- guppylang_internals/definition/custom.py +39 -1
- guppylang_internals/definition/declaration.py +9 -6
- guppylang_internals/definition/function.py +12 -2
- guppylang_internals/definition/parameter.py +8 -3
- guppylang_internals/definition/pytket_circuits.py +14 -41
- guppylang_internals/definition/struct.py +13 -7
- guppylang_internals/definition/ty.py +3 -3
- guppylang_internals/definition/wasm.py +42 -10
- guppylang_internals/engine.py +9 -3
- guppylang_internals/experimental.py +5 -0
- guppylang_internals/nodes.py +147 -24
- guppylang_internals/std/_internal/checker.py +13 -108
- guppylang_internals/std/_internal/compiler/array.py +95 -283
- guppylang_internals/std/_internal/compiler/list.py +1 -1
- guppylang_internals/std/_internal/compiler/platform.py +153 -0
- guppylang_internals/std/_internal/compiler/prelude.py +12 -4
- guppylang_internals/std/_internal/compiler/tket_exts.py +8 -2
- guppylang_internals/std/_internal/debug.py +18 -9
- guppylang_internals/std/_internal/util.py +1 -1
- guppylang_internals/tracing/object.py +10 -0
- guppylang_internals/tracing/unpacking.py +19 -20
- guppylang_internals/tys/arg.py +18 -3
- guppylang_internals/tys/builtin.py +2 -5
- guppylang_internals/tys/const.py +33 -4
- guppylang_internals/tys/errors.py +23 -1
- guppylang_internals/tys/param.py +31 -16
- guppylang_internals/tys/parsing.py +11 -24
- guppylang_internals/tys/printing.py +2 -8
- guppylang_internals/tys/qubit.py +62 -0
- guppylang_internals/tys/subst.py +8 -26
- guppylang_internals/tys/ty.py +91 -85
- guppylang_internals/wasm_util.py +129 -0
- {guppylang_internals-0.24.0.dist-info → guppylang_internals-0.26.0.dist-info}/METADATA +6 -5
- guppylang_internals-0.26.0.dist-info/RECORD +104 -0
- {guppylang_internals-0.24.0.dist-info → guppylang_internals-0.26.0.dist-info}/WHEEL +1 -1
- guppylang_internals-0.24.0.dist-info/RECORD +0 -98
- {guppylang_internals-0.24.0.dist-info → guppylang_internals-0.26.0.dist-info}/licenses/LICENCE +0 -0
|
@@ -328,7 +328,7 @@ def _list_new_classical(
|
|
|
328
328
|
builder: DfBase[ops.DfParentOp], elem_type: ht.Type, args: list[Wire]
|
|
329
329
|
) -> Wire:
|
|
330
330
|
# This may be simplified in the future with a `new` or `with_capacity` list op
|
|
331
|
-
# See https://github.com/
|
|
331
|
+
# See https://github.com/quantinuum/hugr/issues/1508
|
|
332
332
|
lst = builder.load(ListVal([], elem_ty=elem_type))
|
|
333
333
|
push_op = list_push(elem_type)
|
|
334
334
|
for elem in args:
|
|
@@ -0,0 +1,153 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
from typing import ClassVar
|
|
3
|
+
|
|
4
|
+
import hugr
|
|
5
|
+
from hugr import Wire, ops, tys
|
|
6
|
+
|
|
7
|
+
from guppylang_internals.ast_util import AstNode
|
|
8
|
+
from guppylang_internals.compiler.core import CompilerContext
|
|
9
|
+
from guppylang_internals.compiler.expr_compiler import array_read_bool
|
|
10
|
+
from guppylang_internals.definition.custom import (
|
|
11
|
+
CustomCallCompiler,
|
|
12
|
+
CustomInoutCallCompiler,
|
|
13
|
+
)
|
|
14
|
+
from guppylang_internals.definition.value import CallReturnWires
|
|
15
|
+
from guppylang_internals.diagnostic import Error, Note
|
|
16
|
+
from guppylang_internals.error import GuppyError, InternalGuppyError
|
|
17
|
+
from guppylang_internals.std._internal.compiler.array import (
|
|
18
|
+
array_clone,
|
|
19
|
+
array_map,
|
|
20
|
+
array_to_std_array,
|
|
21
|
+
)
|
|
22
|
+
from guppylang_internals.std._internal.compiler.tket_bool import OpaqueBool, read_bool
|
|
23
|
+
from guppylang_internals.std._internal.compiler.tket_exts import RESULT_EXTENSION
|
|
24
|
+
from guppylang_internals.tys.arg import Argument, ConstArg
|
|
25
|
+
from guppylang_internals.tys.builtin import get_element_type, is_bool_type
|
|
26
|
+
from guppylang_internals.tys.const import BoundConstVar, ConstValue
|
|
27
|
+
from guppylang_internals.tys.ty import NumericType
|
|
28
|
+
|
|
29
|
+
#: Maximum length of a tag in the `result` function.
|
|
30
|
+
TAG_MAX_LEN = 200
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
@dataclass(frozen=True)
|
|
34
|
+
class TooLongError(Error):
|
|
35
|
+
title: ClassVar[str] = "Tag too long"
|
|
36
|
+
span_label: ClassVar[str] = "Result tag is too long"
|
|
37
|
+
|
|
38
|
+
@dataclass(frozen=True)
|
|
39
|
+
class Hint(Note):
|
|
40
|
+
message: ClassVar[str] = f"Result tags are limited to {TAG_MAX_LEN} bytes"
|
|
41
|
+
|
|
42
|
+
@dataclass(frozen=True)
|
|
43
|
+
class GenericHint(Note):
|
|
44
|
+
message: ClassVar[str] = "Parameter `{param}` was instantiated to `{value}`"
|
|
45
|
+
param: str
|
|
46
|
+
value: str
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
class ResultCompiler(CustomCallCompiler):
|
|
50
|
+
"""Custom compiler for overloads of the `result` function.
|
|
51
|
+
|
|
52
|
+
See `ArrayResultCompiler` for the compiler that handles results involving arrays.
|
|
53
|
+
"""
|
|
54
|
+
|
|
55
|
+
def __init__(self, op_name: str, with_int_width: bool = False):
|
|
56
|
+
self.op_name = op_name
|
|
57
|
+
self.with_int_width = with_int_width
|
|
58
|
+
|
|
59
|
+
def compile(self, args: list[Wire]) -> list[Wire]:
|
|
60
|
+
assert self.func is not None
|
|
61
|
+
[value] = args
|
|
62
|
+
ty = self.func.ty.inputs[1].ty
|
|
63
|
+
hugr_ty = ty.to_hugr(self.ctx)
|
|
64
|
+
args = [tag_to_hugr(self.type_args[0], self.ctx, self.node)]
|
|
65
|
+
if self.with_int_width:
|
|
66
|
+
args.append(tys.BoundedNatArg(NumericType.INT_WIDTH))
|
|
67
|
+
# Bool results need an extra conversion into regular hugr bools
|
|
68
|
+
if is_bool_type(ty):
|
|
69
|
+
value = self.builder.add_op(read_bool(), value)
|
|
70
|
+
hugr_ty = tys.Bool
|
|
71
|
+
op = RESULT_EXTENSION.get_op(self.op_name)
|
|
72
|
+
sig = tys.FunctionType(input=[hugr_ty], output=[])
|
|
73
|
+
self.builder.add_op(op.instantiate(args, sig), value)
|
|
74
|
+
return []
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
class ArrayResultCompiler(CustomInoutCallCompiler):
|
|
78
|
+
"""Custom compiler for overloads of the `result` function accepting arrays.
|
|
79
|
+
|
|
80
|
+
See `ResultCompiler` for the compiler that handles basic results.
|
|
81
|
+
"""
|
|
82
|
+
|
|
83
|
+
def __init__(self, op_name: str, with_int_width: bool = False):
|
|
84
|
+
self.op_name = op_name
|
|
85
|
+
self.with_int_width = with_int_width
|
|
86
|
+
|
|
87
|
+
def compile_with_inouts(self, args: list[Wire]) -> CallReturnWires:
|
|
88
|
+
assert self.func is not None
|
|
89
|
+
array_ty = self.func.ty.inputs[1].ty
|
|
90
|
+
elem_ty = get_element_type(array_ty)
|
|
91
|
+
[tag_arg, size_arg] = self.type_args
|
|
92
|
+
[arr] = args
|
|
93
|
+
|
|
94
|
+
# As `borrow_array`s used by Guppy are linear, we need to clone it (knowing
|
|
95
|
+
# that all elements in it are copyable) to avoid linearity violations when
|
|
96
|
+
# both passing it to the result operation and returning it (as an inout
|
|
97
|
+
# argument).
|
|
98
|
+
hugr_elem_ty = elem_ty.to_hugr(self.ctx)
|
|
99
|
+
hugr_size = size_arg.to_hugr(self.ctx)
|
|
100
|
+
arr, out_arr = self.builder.add_op(array_clone(hugr_elem_ty, hugr_size), arr)
|
|
101
|
+
# For bool arrays, we furthermore need to coerce a read on all the array
|
|
102
|
+
# elements
|
|
103
|
+
if is_bool_type(elem_ty):
|
|
104
|
+
array_read = array_read_bool(self.ctx)
|
|
105
|
+
array_read = self.builder.load_function(array_read)
|
|
106
|
+
map_op = array_map(OpaqueBool, hugr_size, tys.Bool)
|
|
107
|
+
arr = self.builder.add_op(map_op, arr, array_read).out(0)
|
|
108
|
+
hugr_elem_ty = tys.Bool
|
|
109
|
+
# Turn `borrow_array` into regular `array`
|
|
110
|
+
arr = self.builder.add_op(array_to_std_array(hugr_elem_ty, hugr_size), arr).out(
|
|
111
|
+
0
|
|
112
|
+
)
|
|
113
|
+
|
|
114
|
+
hugr_ty = hugr.std.collections.array.Array(hugr_elem_ty, hugr_size)
|
|
115
|
+
sig = tys.FunctionType(input=[hugr_ty], output=[])
|
|
116
|
+
args = [tag_to_hugr(tag_arg, self.ctx, self.node), hugr_size]
|
|
117
|
+
if self.with_int_width:
|
|
118
|
+
args.append(tys.BoundedNatArg(NumericType.INT_WIDTH))
|
|
119
|
+
op = ops.ExtOp(RESULT_EXTENSION.get_op(self.op_name), signature=sig, args=args)
|
|
120
|
+
self.builder.add_op(op, arr)
|
|
121
|
+
return CallReturnWires([], [out_arr])
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
def tag_to_hugr(tag_arg: Argument, ctx: CompilerContext, loc: AstNode) -> tys.TypeArg:
|
|
125
|
+
"""Helper function to convert the Guppy tag comptime argument into a Hugr type arg.
|
|
126
|
+
|
|
127
|
+
Takes care of reading the tag value from the current monomorphization and checks
|
|
128
|
+
that the tag fits into `TAG_MAX_LEN`.
|
|
129
|
+
"""
|
|
130
|
+
is_generic: BoundConstVar | None = None
|
|
131
|
+
match tag_arg:
|
|
132
|
+
case ConstArg(const=ConstValue(value=str(value))):
|
|
133
|
+
tag = value
|
|
134
|
+
case ConstArg(const=BoundConstVar(idx=idx) as var):
|
|
135
|
+
is_generic = var
|
|
136
|
+
assert ctx.current_mono_args is not None
|
|
137
|
+
match ctx.current_mono_args[idx]:
|
|
138
|
+
case ConstArg(const=ConstValue(value=str(value))):
|
|
139
|
+
tag = value
|
|
140
|
+
case _:
|
|
141
|
+
raise InternalGuppyError("Invalid tag monomorphization")
|
|
142
|
+
case _:
|
|
143
|
+
raise InternalGuppyError("Invalid tag argument")
|
|
144
|
+
|
|
145
|
+
if len(tag.encode("utf-8")) > TAG_MAX_LEN:
|
|
146
|
+
err = TooLongError(loc)
|
|
147
|
+
err.add_sub_diagnostic(TooLongError.Hint(None))
|
|
148
|
+
if is_generic:
|
|
149
|
+
err.add_sub_diagnostic(
|
|
150
|
+
TooLongError.GenericHint(None, is_generic.display_name, tag)
|
|
151
|
+
)
|
|
152
|
+
raise GuppyError(err)
|
|
153
|
+
return tys.StringArg(tag)
|
|
@@ -73,6 +73,14 @@ def panic(
|
|
|
73
73
|
return ops.ExtOp(op_def, sig, args)
|
|
74
74
|
|
|
75
75
|
|
|
76
|
+
def make_error() -> ops.ExtOp:
|
|
77
|
+
"""Returns an operation that makes an error."""
|
|
78
|
+
op_def = hugr.std.PRELUDE.get_op("MakeError")
|
|
79
|
+
args: list[ht.TypeArg] = []
|
|
80
|
+
sig = ht.FunctionType([ht.USize(), hugr.std.prelude.STRING_T], [error_type()])
|
|
81
|
+
return ops.ExtOp(op_def, sig, args)
|
|
82
|
+
|
|
83
|
+
|
|
76
84
|
# ------------------------------------------------------
|
|
77
85
|
# --------- Custom compilers for non-native ops --------
|
|
78
86
|
# ------------------------------------------------------
|
|
@@ -90,14 +98,14 @@ def build_panic(
|
|
|
90
98
|
return builder.add_op(op, err, *args)
|
|
91
99
|
|
|
92
100
|
|
|
93
|
-
def
|
|
101
|
+
def build_static_error(builder: DfBase[P], signal: int, msg: str) -> Wire:
|
|
94
102
|
"""Constructs and loads a static error value."""
|
|
95
103
|
val = ErrorVal(signal, msg)
|
|
96
104
|
return builder.load(builder.add_const(val))
|
|
97
105
|
|
|
98
106
|
|
|
99
107
|
# TODO: Common up build_unwrap_right and build_unwrap_left below once
|
|
100
|
-
# https://github.com/
|
|
108
|
+
# https://github.com/quantinuum/hugr/issues/1596 is fixed
|
|
101
109
|
|
|
102
110
|
|
|
103
111
|
def build_unwrap_right(
|
|
@@ -111,7 +119,7 @@ def build_unwrap_right(
|
|
|
111
119
|
assert isinstance(result_ty, ht.Sum)
|
|
112
120
|
[left_tys, right_tys] = result_ty.variant_rows
|
|
113
121
|
with conditional.add_case(0) as case:
|
|
114
|
-
error =
|
|
122
|
+
error = build_static_error(case, error_signal, error_msg)
|
|
115
123
|
case.set_outputs(*build_panic(case, left_tys, right_tys, error, *case.inputs()))
|
|
116
124
|
with conditional.add_case(1) as case:
|
|
117
125
|
case.set_outputs(*case.inputs())
|
|
@@ -134,7 +142,7 @@ def build_unwrap_left(
|
|
|
134
142
|
with conditional.add_case(0) as case:
|
|
135
143
|
case.set_outputs(*case.inputs())
|
|
136
144
|
with conditional.add_case(1) as case:
|
|
137
|
-
error =
|
|
145
|
+
error = build_static_error(case, error_signal, error_msg)
|
|
138
146
|
case.set_outputs(*build_panic(case, right_tys, left_tys, error, *case.inputs()))
|
|
139
147
|
return conditional.to_node()
|
|
140
148
|
|
|
@@ -1,11 +1,13 @@
|
|
|
1
1
|
from dataclasses import dataclass
|
|
2
2
|
|
|
3
|
+
import tket_exts
|
|
3
4
|
from hugr import val
|
|
4
5
|
from tket_exts import (
|
|
5
6
|
debug,
|
|
6
7
|
futures,
|
|
8
|
+
global_phase,
|
|
7
9
|
guppy,
|
|
8
|
-
|
|
10
|
+
modifier,
|
|
9
11
|
qsystem,
|
|
10
12
|
qsystem_random,
|
|
11
13
|
qsystem_utils,
|
|
@@ -15,10 +17,12 @@ from tket_exts import (
|
|
|
15
17
|
wasm,
|
|
16
18
|
)
|
|
17
19
|
|
|
18
|
-
BOOL_EXTENSION =
|
|
20
|
+
BOOL_EXTENSION = tket_exts.bool()
|
|
19
21
|
DEBUG_EXTENSION = debug()
|
|
20
22
|
FUTURES_EXTENSION = futures()
|
|
23
|
+
GLOBAL_PHASE_EXTENSION = global_phase()
|
|
21
24
|
GUPPY_EXTENSION = guppy()
|
|
25
|
+
MODIFIER_EXTENSION = modifier()
|
|
22
26
|
QSYSTEM_EXTENSION = qsystem()
|
|
23
27
|
QSYSTEM_RANDOM_EXTENSION = qsystem_random()
|
|
24
28
|
QSYSTEM_UTILS_EXTENSION = qsystem_utils()
|
|
@@ -31,7 +35,9 @@ TKET_EXTENSIONS = [
|
|
|
31
35
|
BOOL_EXTENSION,
|
|
32
36
|
DEBUG_EXTENSION,
|
|
33
37
|
FUTURES_EXTENSION,
|
|
38
|
+
GLOBAL_PHASE_EXTENSION,
|
|
34
39
|
GUPPY_EXTENSION,
|
|
40
|
+
MODIFIER_EXTENSION,
|
|
35
41
|
QSYSTEM_EXTENSION,
|
|
36
42
|
QSYSTEM_RANDOM_EXTENSION,
|
|
37
43
|
QSYSTEM_UTILS_EXTENSION,
|
|
@@ -3,6 +3,7 @@ from dataclasses import dataclass
|
|
|
3
3
|
from typing import ClassVar, cast
|
|
4
4
|
|
|
5
5
|
from guppylang_internals.ast_util import with_loc
|
|
6
|
+
from guppylang_internals.checker.core import ComptimeVariable
|
|
6
7
|
from guppylang_internals.checker.errors.generic import ExpectedError
|
|
7
8
|
from guppylang_internals.checker.errors.type_errors import WrongNumberOfArgsError
|
|
8
9
|
from guppylang_internals.checker.expr_checker import (
|
|
@@ -14,14 +15,14 @@ from guppylang_internals.definition.custom import CustomCallChecker
|
|
|
14
15
|
from guppylang_internals.definition.ty import TypeDef
|
|
15
16
|
from guppylang_internals.diagnostic import Error
|
|
16
17
|
from guppylang_internals.error import GuppyTypeError
|
|
17
|
-
from guppylang_internals.nodes import StateResultExpr
|
|
18
|
-
from guppylang_internals.std._internal.checker import TAG_MAX_LEN, TooLongError
|
|
18
|
+
from guppylang_internals.nodes import GenericParamValue, PlaceNode, StateResultExpr
|
|
19
19
|
from guppylang_internals.tys.builtin import (
|
|
20
20
|
get_array_length,
|
|
21
21
|
get_element_type,
|
|
22
22
|
is_array_type,
|
|
23
23
|
string_type,
|
|
24
24
|
)
|
|
25
|
+
from guppylang_internals.tys.const import Const, ConstValue
|
|
25
26
|
from guppylang_internals.tys.ty import (
|
|
26
27
|
FuncInput,
|
|
27
28
|
FunctionType,
|
|
@@ -43,12 +44,16 @@ class StateResultChecker(CustomCallChecker):
|
|
|
43
44
|
|
|
44
45
|
def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, Type]:
|
|
45
46
|
tag, _ = ExprChecker(self.ctx).check(args[0], string_type())
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
47
|
+
tag_value: Const
|
|
48
|
+
match tag:
|
|
49
|
+
case ast.Constant(value=str(v)):
|
|
50
|
+
tag_value = ConstValue(string_type(), v)
|
|
51
|
+
case PlaceNode(place=ComptimeVariable(static_value=str(v))):
|
|
52
|
+
tag_value = ConstValue(string_type(), v)
|
|
53
|
+
case GenericParamValue() as param_value:
|
|
54
|
+
tag_value = param_value.param.to_bound().const
|
|
55
|
+
case _:
|
|
56
|
+
raise GuppyTypeError(ExpectedError(tag, "a string literal"))
|
|
52
57
|
syn_args: list[ast.expr] = [tag]
|
|
53
58
|
|
|
54
59
|
if len(args) < 2:
|
|
@@ -90,6 +95,10 @@ class StateResultChecker(CustomCallChecker):
|
|
|
90
95
|
args, ret_ty, inst = synthesize_call(func_ty, syn_args, self.node, self.ctx)
|
|
91
96
|
assert len(inst) == 0, "func_ty is not generic"
|
|
92
97
|
node = StateResultExpr(
|
|
93
|
-
|
|
98
|
+
tag_value=tag_value,
|
|
99
|
+
tag_expr=tag,
|
|
100
|
+
args=args,
|
|
101
|
+
func_ty=func_ty,
|
|
102
|
+
array_len=array_len,
|
|
94
103
|
)
|
|
95
104
|
return with_loc(self.node, node), ret_ty
|
|
@@ -129,7 +129,7 @@ def int_op(
|
|
|
129
129
|
# Ideally we'd be able to derive the arguments from the input/output types,
|
|
130
130
|
# but the amount of variables does not correlate with the signature for the
|
|
131
131
|
# integer ops in hugr :/
|
|
132
|
-
# https://github.com/
|
|
132
|
+
# https://github.com/quantinuum/hugr/blob/bfa13e59468feb0fc746677ea3b3a4341b2ed42e/hugr-core/src/std_extensions/arithmetic/int_ops.rs#L116
|
|
133
133
|
#
|
|
134
134
|
# For now, we just instantiate every type argument to a 64-bit integer.
|
|
135
135
|
args: list[ht.TypeArg] = [int_arg() for _ in range(n_vars)]
|
|
@@ -539,6 +539,16 @@ class TracingDefMixin(DunderMixin):
|
|
|
539
539
|
|
|
540
540
|
def to_guppy_object(self) -> GuppyObject:
|
|
541
541
|
state = get_tracing_state()
|
|
542
|
+
defn = ENGINE.get_checked(self.id)
|
|
543
|
+
# TODO: For generic functions, we need to know an instantiation for their type
|
|
544
|
+
# parameters. Maybe we should pass them to `to_guppy_object`? Either way, this
|
|
545
|
+
# will require some more plumbing of type inference information through the
|
|
546
|
+
# comptime logic. For now, let's just bail on generic functions.
|
|
547
|
+
# See https://github.com/quantinuum/guppylang/issues/1336
|
|
548
|
+
if isinstance(defn, CallableDef) and defn.ty.parametrized:
|
|
549
|
+
raise GuppyComptimeError(
|
|
550
|
+
f"Cannot infer type parameters of generic function `{defn.name}`"
|
|
551
|
+
)
|
|
542
552
|
defn, [] = state.ctx.build_compiled_def(self.id, type_args=[])
|
|
543
553
|
if isinstance(defn, CompiledValueDef):
|
|
544
554
|
wire = defn.load(state.dfg, state.ctx, state.node)
|
|
@@ -1,7 +1,6 @@
|
|
|
1
1
|
from typing import Any, TypeVar
|
|
2
2
|
|
|
3
3
|
from hugr import ops
|
|
4
|
-
from hugr import tys as ht
|
|
5
4
|
from hugr.build.dfg import DfBase
|
|
6
5
|
|
|
7
6
|
from guppylang_internals.ast_util import AstNode
|
|
@@ -13,7 +12,6 @@ from guppylang_internals.compiler.core import CompilerContext
|
|
|
13
12
|
from guppylang_internals.compiler.expr_compiler import python_value_to_hugr
|
|
14
13
|
from guppylang_internals.error import GuppyComptimeError, GuppyError
|
|
15
14
|
from guppylang_internals.std._internal.compiler.array import array_new, unpack_array
|
|
16
|
-
from guppylang_internals.std._internal.compiler.prelude import build_unwrap
|
|
17
15
|
from guppylang_internals.tracing.frozenlist import frozenlist
|
|
18
16
|
from guppylang_internals.tracing.object import (
|
|
19
17
|
GuppyObject,
|
|
@@ -71,9 +69,7 @@ def unpack_guppy_object(
|
|
|
71
69
|
# them as Guppy objects here
|
|
72
70
|
return obj
|
|
73
71
|
elem_ty = get_element_type(ty)
|
|
74
|
-
|
|
75
|
-
err = "Non-copyable array element has already been used"
|
|
76
|
-
elems = [build_unwrap(builder, opt_elem, err) for opt_elem in opt_elems]
|
|
72
|
+
elems = unpack_array(builder, obj._use_wire(None))
|
|
77
73
|
obj_list = [
|
|
78
74
|
unpack_guppy_object(GuppyObject(elem_ty, wire), builder, frozen)
|
|
79
75
|
for wire in elems
|
|
@@ -128,11 +124,8 @@ def guppy_object_from_py(
|
|
|
128
124
|
f"Element at index {i + 1} does not match the type of "
|
|
129
125
|
f"previous elements. Expected `{elem_ty}`, got `{obj._ty}`."
|
|
130
126
|
)
|
|
131
|
-
hugr_elem_ty =
|
|
132
|
-
wires = [
|
|
133
|
-
builder.add_op(ops.Tag(1, hugr_elem_ty), obj._use_wire(None))
|
|
134
|
-
for obj in objs
|
|
135
|
-
]
|
|
127
|
+
hugr_elem_ty = elem_ty.to_hugr(ctx)
|
|
128
|
+
wires = [obj._use_wire(None) for obj in objs]
|
|
136
129
|
return GuppyObject(
|
|
137
130
|
array_type(elem_ty, len(vs)),
|
|
138
131
|
builder.add_op(array_new(hugr_elem_ty, len(vs)), *wires),
|
|
@@ -172,26 +165,32 @@ def update_packed_value(v: Any, obj: "GuppyObject", builder: DfBase[P]) -> bool:
|
|
|
172
165
|
assert isinstance(obj._ty, NoneType)
|
|
173
166
|
case tuple(vs):
|
|
174
167
|
assert isinstance(obj._ty, TupleType)
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
168
|
+
wire_iterator = builder.add_op(
|
|
169
|
+
ops.UnpackTuple(), obj._use_wire(None)
|
|
170
|
+
).outputs()
|
|
171
|
+
for v, ty, out_wire in zip(
|
|
172
|
+
vs, obj._ty.element_types, wire_iterator, strict=True
|
|
173
|
+
):
|
|
174
|
+
success = update_packed_value(v, GuppyObject(ty, out_wire), builder)
|
|
178
175
|
if not success:
|
|
179
176
|
return False
|
|
180
177
|
case GuppyStructObject(_ty=ty, _field_values=values):
|
|
181
178
|
assert obj._ty == ty
|
|
182
|
-
|
|
183
|
-
|
|
179
|
+
wire_iterator = builder.add_op(
|
|
180
|
+
ops.UnpackTuple(), obj._use_wire(None)
|
|
181
|
+
).outputs()
|
|
182
|
+
for field, out_wire in zip(ty.fields, wire_iterator, strict=True):
|
|
184
183
|
v = values[field.name]
|
|
185
|
-
success = update_packed_value(
|
|
184
|
+
success = update_packed_value(
|
|
185
|
+
v, GuppyObject(field.ty, out_wire), builder
|
|
186
|
+
)
|
|
186
187
|
if not success:
|
|
187
188
|
values[field.name] = obj
|
|
188
189
|
case list(vs) if len(vs) > 0:
|
|
189
190
|
assert is_array_type(obj._ty)
|
|
190
191
|
elem_ty = get_element_type(obj._ty)
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
for i, (v, opt_wire) in enumerate(zip(vs, opt_wires, strict=True)):
|
|
194
|
-
(wire,) = build_unwrap(builder, opt_wire, err).outputs()
|
|
192
|
+
wires = unpack_array(builder, obj._use_wire(None))
|
|
193
|
+
for i, (v, wire) in enumerate(zip(vs, wires, strict=True)):
|
|
195
194
|
success = update_packed_value(v, GuppyObject(elem_ty, wire), builder)
|
|
196
195
|
if not success:
|
|
197
196
|
vs[i] = obj
|
guppylang_internals/tys/arg.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
from abc import ABC, abstractmethod
|
|
2
|
-
from dataclasses import dataclass
|
|
2
|
+
from dataclasses import dataclass, field
|
|
3
3
|
from typing import TYPE_CHECKING, TypeAlias
|
|
4
4
|
|
|
5
5
|
from hugr import tys as ht
|
|
@@ -18,7 +18,7 @@ from guppylang_internals.tys.const import (
|
|
|
18
18
|
ConstValue,
|
|
19
19
|
ExistentialConstVar,
|
|
20
20
|
)
|
|
21
|
-
from guppylang_internals.tys.var import ExistentialVar
|
|
21
|
+
from guppylang_internals.tys.var import BoundVar, ExistentialVar
|
|
22
22
|
|
|
23
23
|
if TYPE_CHECKING:
|
|
24
24
|
from guppylang_internals.tys.ty import Type
|
|
@@ -45,19 +45,29 @@ class ArgumentBase(ToHugr[ht.TypeArg], Transformable["Argument"], ABC):
|
|
|
45
45
|
def unsolved_vars(self) -> set[ExistentialVar]:
|
|
46
46
|
"""The existential type variables contained in this argument."""
|
|
47
47
|
|
|
48
|
+
@property
|
|
49
|
+
@abstractmethod
|
|
50
|
+
def bound_vars(self) -> set[BoundVar]:
|
|
51
|
+
"""The bound type variables contained in this argument."""
|
|
52
|
+
|
|
48
53
|
|
|
49
54
|
@dataclass(frozen=True)
|
|
50
55
|
class TypeArg(ArgumentBase):
|
|
51
56
|
"""Argument that can be instantiated for a `TypeParameter`."""
|
|
52
57
|
|
|
53
58
|
# The type to instantiate
|
|
54
|
-
ty: "Type"
|
|
59
|
+
ty: "Type" = field(hash=False) # Types are not hashable
|
|
55
60
|
|
|
56
61
|
@property
|
|
57
62
|
def unsolved_vars(self) -> set[ExistentialVar]:
|
|
58
63
|
"""The existential type variables contained in this argument."""
|
|
59
64
|
return self.ty.unsolved_vars
|
|
60
65
|
|
|
66
|
+
@property
|
|
67
|
+
def bound_vars(self) -> set[BoundVar]:
|
|
68
|
+
"""The bound type variables contained in this type."""
|
|
69
|
+
return self.ty.bound_vars
|
|
70
|
+
|
|
61
71
|
def to_hugr(self, ctx: ToHugrContext) -> ht.TypeTypeArg:
|
|
62
72
|
"""Computes the Hugr representation of the argument."""
|
|
63
73
|
ty: ht.Type = self.ty.to_hugr(ctx)
|
|
@@ -84,6 +94,11 @@ class ConstArg(ArgumentBase):
|
|
|
84
94
|
"""The existential const variables contained in this argument."""
|
|
85
95
|
return self.const.unsolved_vars
|
|
86
96
|
|
|
97
|
+
@property
|
|
98
|
+
def bound_vars(self) -> set[BoundVar]:
|
|
99
|
+
"""The bound type variables contained in this argument."""
|
|
100
|
+
return self.const.bound_vars
|
|
101
|
+
|
|
87
102
|
def to_hugr(self, ctx: ToHugrContext) -> ht.TypeArg:
|
|
88
103
|
"""Computes the Hugr representation of this argument."""
|
|
89
104
|
from guppylang_internals.tys.ty import NumericType
|
|
@@ -177,13 +177,10 @@ def _array_to_hugr(args: Sequence[Argument], ctx: ToHugrContext) -> ht.Type:
|
|
|
177
177
|
assert isinstance(ty_arg, TypeArg)
|
|
178
178
|
assert isinstance(len_arg, ConstArg)
|
|
179
179
|
|
|
180
|
-
|
|
181
|
-
# See `ArrayGetitemCompiler` for details.
|
|
182
|
-
# Same also for classical arrays, see https://github.com/CQCL/guppylang/issues/629
|
|
183
|
-
elem_ty = ht.Option(ty_arg.ty.to_hugr(ctx))
|
|
180
|
+
elem_ty = ty_arg.ty.to_hugr(ctx)
|
|
184
181
|
hugr_arg = len_arg.to_hugr(ctx)
|
|
185
182
|
|
|
186
|
-
return hugr.std.collections.
|
|
183
|
+
return hugr.std.collections.borrow_array.BorrowArray(elem_ty, hugr_arg)
|
|
187
184
|
|
|
188
185
|
|
|
189
186
|
def _frozenarray_to_hugr(args: Sequence[Argument], ctx: ToHugrContext) -> ht.Type:
|
guppylang_internals/tys/const.py
CHANGED
|
@@ -8,6 +8,7 @@ from guppylang_internals.tys.var import BoundVar, ExistentialVar
|
|
|
8
8
|
|
|
9
9
|
if TYPE_CHECKING:
|
|
10
10
|
from guppylang_internals.tys.arg import ConstArg
|
|
11
|
+
from guppylang_internals.tys.subst import Subst
|
|
11
12
|
from guppylang_internals.tys.ty import Type
|
|
12
13
|
|
|
13
14
|
|
|
@@ -39,6 +40,11 @@ class ConstBase(Transformable["Const"], ABC):
|
|
|
39
40
|
"""The existential type variables contained in this constant."""
|
|
40
41
|
return set()
|
|
41
42
|
|
|
43
|
+
@property
|
|
44
|
+
def bound_vars(self) -> set[BoundVar]:
|
|
45
|
+
"""The bound type variables contained in this constant."""
|
|
46
|
+
return self.ty.bound_vars
|
|
47
|
+
|
|
42
48
|
def __str__(self) -> str:
|
|
43
49
|
from guppylang_internals.tys.printing import TypePrinter
|
|
44
50
|
|
|
@@ -48,16 +54,18 @@ class ConstBase(Transformable["Const"], ABC):
|
|
|
48
54
|
"""Accepts a visitor on this constant."""
|
|
49
55
|
visitor.visit(self)
|
|
50
56
|
|
|
51
|
-
def transform(self, transformer: Transformer, /) -> "Const":
|
|
52
|
-
"""Accepts a transformer on this constant."""
|
|
53
|
-
return transformer.transform(self) or self.cast()
|
|
54
|
-
|
|
55
57
|
def to_arg(self) -> "ConstArg":
|
|
56
58
|
"""Wraps this constant into a type argument."""
|
|
57
59
|
from guppylang_internals.tys.arg import ConstArg
|
|
58
60
|
|
|
59
61
|
return ConstArg(self.cast())
|
|
60
62
|
|
|
63
|
+
def substitute(self, subst: "Subst") -> "Const":
|
|
64
|
+
"""Substitutes existential variables in this constant."""
|
|
65
|
+
from guppylang_internals.tys.subst import Substituter
|
|
66
|
+
|
|
67
|
+
return self.transform(Substituter(subst))
|
|
68
|
+
|
|
61
69
|
|
|
62
70
|
@dataclass(frozen=True)
|
|
63
71
|
class ConstValue(ConstBase):
|
|
@@ -74,6 +82,10 @@ class ConstValue(ConstBase):
|
|
|
74
82
|
"""Casts an implementor of `ConstBase` into a `Const`."""
|
|
75
83
|
return self
|
|
76
84
|
|
|
85
|
+
def transform(self, transformer: Transformer, /) -> "Const":
|
|
86
|
+
"""Accepts a transformer on this constant."""
|
|
87
|
+
return transformer.transform(self) or self
|
|
88
|
+
|
|
77
89
|
|
|
78
90
|
@dataclass(frozen=True)
|
|
79
91
|
class BoundConstVar(BoundVar, ConstBase):
|
|
@@ -84,10 +96,21 @@ class BoundConstVar(BoundVar, ConstBase):
|
|
|
84
96
|
`BoundConstVar(idx=0)`.
|
|
85
97
|
"""
|
|
86
98
|
|
|
99
|
+
@property
|
|
100
|
+
def bound_vars(self) -> set[BoundVar]:
|
|
101
|
+
"""The bound type variables contained in this constant."""
|
|
102
|
+
return {self} | self.ty.bound_vars
|
|
103
|
+
|
|
87
104
|
def cast(self) -> "Const":
|
|
88
105
|
"""Casts an implementor of `ConstBase` into a `Const`."""
|
|
89
106
|
return self
|
|
90
107
|
|
|
108
|
+
def transform(self, transformer: Transformer, /) -> "Const":
|
|
109
|
+
"""Accepts a transformer on this constant."""
|
|
110
|
+
return transformer.transform(self) or BoundConstVar(
|
|
111
|
+
transformer.transform(self.ty) or self.ty, self.display_name, self.idx
|
|
112
|
+
)
|
|
113
|
+
|
|
91
114
|
|
|
92
115
|
@dataclass(frozen=True)
|
|
93
116
|
class ExistentialConstVar(ExistentialVar, ConstBase):
|
|
@@ -110,5 +133,11 @@ class ExistentialConstVar(ExistentialVar, ConstBase):
|
|
|
110
133
|
"""Casts an implementor of `ConstBase` into a `Const`."""
|
|
111
134
|
return self
|
|
112
135
|
|
|
136
|
+
def transform(self, transformer: Transformer, /) -> "Const":
|
|
137
|
+
"""Accepts a transformer on this constant."""
|
|
138
|
+
return transformer.transform(self) or ExistentialConstVar(
|
|
139
|
+
transformer.transform(self.ty) or self.ty, self.display_name, self.id
|
|
140
|
+
)
|
|
141
|
+
|
|
113
142
|
|
|
114
143
|
Const: TypeAlias = ConstValue | BoundConstVar | ExistentialConstVar
|
|
@@ -5,7 +5,7 @@ from guppylang_internals.diagnostic import Error, Help, Note
|
|
|
5
5
|
|
|
6
6
|
if TYPE_CHECKING:
|
|
7
7
|
from guppylang_internals.definition.parameter import ParamDef
|
|
8
|
-
from guppylang_internals.tys.ty import Type
|
|
8
|
+
from guppylang_internals.tys.ty import Type, UnitaryFlags
|
|
9
9
|
|
|
10
10
|
|
|
11
11
|
@dataclass(frozen=True)
|
|
@@ -182,3 +182,25 @@ class InvalidFlagError(Error):
|
|
|
182
182
|
class FlagNotAllowedError(Error):
|
|
183
183
|
title: ClassVar[str] = "Invalid annotation"
|
|
184
184
|
span_label: ClassVar[str] = "`@` type annotations are not allowed in this position"
|
|
185
|
+
|
|
186
|
+
|
|
187
|
+
@dataclass(frozen=True)
|
|
188
|
+
class UnitaryCallError(Error):
|
|
189
|
+
title: ClassVar[str] = "Unitary constraint violation"
|
|
190
|
+
span_label: ClassVar[str] = (
|
|
191
|
+
"This function cannot be called in a {render_flags} context"
|
|
192
|
+
)
|
|
193
|
+
flags: "UnitaryFlags"
|
|
194
|
+
|
|
195
|
+
@property
|
|
196
|
+
def render_flags(self) -> str:
|
|
197
|
+
from guppylang_internals.tys.ty import UnitaryFlags
|
|
198
|
+
|
|
199
|
+
if self.flags == UnitaryFlags.Dagger:
|
|
200
|
+
return "dagger"
|
|
201
|
+
elif self.flags == UnitaryFlags.Control:
|
|
202
|
+
return "control"
|
|
203
|
+
elif self.flags == UnitaryFlags.Power:
|
|
204
|
+
return "power"
|
|
205
|
+
else:
|
|
206
|
+
return "unitary"
|