guppylang-internals 0.26.0__py3-none-any.whl → 0.28.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- guppylang_internals/__init__.py +1 -1
- guppylang_internals/ast_util.py +37 -18
- guppylang_internals/cfg/analysis.py +6 -6
- guppylang_internals/cfg/builder.py +44 -12
- guppylang_internals/cfg/cfg.py +1 -1
- guppylang_internals/checker/core.py +1 -1
- guppylang_internals/checker/errors/comptime_errors.py +0 -12
- guppylang_internals/checker/errors/linearity.py +6 -2
- guppylang_internals/checker/expr_checker.py +53 -28
- guppylang_internals/checker/func_checker.py +4 -3
- guppylang_internals/checker/stmt_checker.py +1 -1
- guppylang_internals/compiler/cfg_compiler.py +1 -1
- guppylang_internals/compiler/core.py +17 -4
- guppylang_internals/compiler/expr_compiler.py +36 -14
- guppylang_internals/compiler/modifier_compiler.py +5 -2
- guppylang_internals/decorator.py +5 -3
- guppylang_internals/definition/common.py +1 -0
- guppylang_internals/definition/custom.py +2 -2
- guppylang_internals/definition/declaration.py +3 -3
- guppylang_internals/definition/function.py +28 -8
- guppylang_internals/definition/metadata.py +87 -0
- guppylang_internals/definition/overloaded.py +11 -2
- guppylang_internals/definition/pytket_circuits.py +50 -67
- guppylang_internals/definition/value.py +1 -1
- guppylang_internals/definition/wasm.py +3 -3
- guppylang_internals/diagnostic.py +89 -16
- guppylang_internals/engine.py +84 -40
- guppylang_internals/error.py +1 -1
- guppylang_internals/nodes.py +301 -3
- guppylang_internals/span.py +7 -3
- guppylang_internals/std/_internal/checker.py +104 -2
- guppylang_internals/std/_internal/compiler/array.py +36 -1
- guppylang_internals/std/_internal/compiler/either.py +14 -2
- guppylang_internals/std/_internal/compiler/tket_bool.py +1 -6
- guppylang_internals/std/_internal/compiler/tket_exts.py +1 -1
- guppylang_internals/std/_internal/debug.py +5 -3
- guppylang_internals/tracing/builtins_mock.py +2 -2
- guppylang_internals/tracing/object.py +6 -2
- guppylang_internals/tys/parsing.py +4 -1
- guppylang_internals/tys/qubit.py +6 -4
- guppylang_internals/tys/subst.py +2 -2
- guppylang_internals/tys/ty.py +2 -2
- guppylang_internals/wasm_util.py +2 -3
- {guppylang_internals-0.26.0.dist-info → guppylang_internals-0.28.0.dist-info}/METADATA +5 -4
- {guppylang_internals-0.26.0.dist-info → guppylang_internals-0.28.0.dist-info}/RECORD +47 -46
- {guppylang_internals-0.26.0.dist-info → guppylang_internals-0.28.0.dist-info}/WHEEL +0 -0
- {guppylang_internals-0.26.0.dist-info → guppylang_internals-0.28.0.dist-info}/licenses/LICENCE +0 -0
|
@@ -144,9 +144,9 @@ class ExprCompiler(CompilerBase, AstVisitor[Wire]):
|
|
|
144
144
|
"""
|
|
145
145
|
old = self.dfg
|
|
146
146
|
# Check that the input names are unique
|
|
147
|
-
assert len({inp.place.id for inp in inputs}) == len(
|
|
148
|
-
|
|
149
|
-
)
|
|
147
|
+
assert len({inp.place.id for inp in inputs}) == len(inputs), (
|
|
148
|
+
"Inputs are not unique"
|
|
149
|
+
)
|
|
150
150
|
self.dfg = DFContainer(builder, self.ctx, self.dfg.locals.copy())
|
|
151
151
|
hugr_input = builder.input_node
|
|
152
152
|
for input_node, wire in zip(inputs, hugr_input, strict=True):
|
|
@@ -325,14 +325,7 @@ class ExprCompiler(CompilerBase, AstVisitor[Wire]):
|
|
|
325
325
|
|
|
326
326
|
def _pack_returns(self, returns: Sequence[Wire], return_ty: Type) -> Wire:
|
|
327
327
|
"""Groups function return values into a tuple"""
|
|
328
|
-
|
|
329
|
-
types = type_to_row(return_ty)
|
|
330
|
-
assert len(returns) == len(types)
|
|
331
|
-
return self._pack_tuple(returns, types)
|
|
332
|
-
assert (
|
|
333
|
-
len(returns) == 1
|
|
334
|
-
), f"Expected a single return value. Got {returns}. return type {return_ty}"
|
|
335
|
-
return returns[0]
|
|
328
|
+
return pack_returns(returns, return_ty, self.builder, self.ctx)
|
|
336
329
|
|
|
337
330
|
def _update_inout_ports(
|
|
338
331
|
self,
|
|
@@ -394,9 +387,9 @@ class ExprCompiler(CompilerBase, AstVisitor[Wire]):
|
|
|
394
387
|
func, func_ty, remaining_args
|
|
395
388
|
)
|
|
396
389
|
rets.extend(outs)
|
|
397
|
-
assert (
|
|
398
|
-
|
|
399
|
-
)
|
|
390
|
+
assert remaining_args == [], (
|
|
391
|
+
"Not all function arguments were consumed after a tensor call"
|
|
392
|
+
)
|
|
400
393
|
return self._pack_returns(rets, node.tensor_ty.output)
|
|
401
394
|
|
|
402
395
|
def _compile_tensor_with_leftovers(
|
|
@@ -760,6 +753,35 @@ def expr_to_row(expr: ast.expr) -> list[ast.expr]:
|
|
|
760
753
|
return expr.elts if isinstance(expr, ast.Tuple) else [expr]
|
|
761
754
|
|
|
762
755
|
|
|
756
|
+
def pack_returns(
|
|
757
|
+
returns: Sequence[Wire],
|
|
758
|
+
return_ty: Type,
|
|
759
|
+
builder: DfBase[ops.DfParentOp],
|
|
760
|
+
ctx: CompilerContext,
|
|
761
|
+
) -> Wire:
|
|
762
|
+
"""Groups function return values into a tuple"""
|
|
763
|
+
if isinstance(return_ty, TupleType | NoneType) and not return_ty.preserve:
|
|
764
|
+
types = type_to_row(return_ty)
|
|
765
|
+
assert len(returns) == len(types)
|
|
766
|
+
hugr_tys = [t.to_hugr(ctx) for t in types]
|
|
767
|
+
return builder.add_op(ops.MakeTuple(hugr_tys), *returns)
|
|
768
|
+
assert len(returns) == 1, (
|
|
769
|
+
f"Expected a single return value. Got {returns}. return type {return_ty}"
|
|
770
|
+
)
|
|
771
|
+
return returns[0]
|
|
772
|
+
|
|
773
|
+
|
|
774
|
+
def unpack_wire(
|
|
775
|
+
wire: Wire, return_ty: Type, builder: DfBase[ops.DfParentOp], ctx: CompilerContext
|
|
776
|
+
) -> list[Wire]:
|
|
777
|
+
"""The inverse of `pack_returns`"""
|
|
778
|
+
if isinstance(return_ty, TupleType | NoneType) and not return_ty.preserve:
|
|
779
|
+
types = type_to_row(return_ty)
|
|
780
|
+
hugr_tys = [t.to_hugr(ctx) for t in types]
|
|
781
|
+
return list(builder.add_op(ops.UnpackTuple(hugr_tys), wire).outputs())
|
|
782
|
+
return [wire]
|
|
783
|
+
|
|
784
|
+
|
|
763
785
|
def instantiation_needs_unpacking(func_ty: FunctionType, inst: Inst) -> bool:
|
|
764
786
|
"""Checks if instantiating a polymorphic makes it return a row."""
|
|
765
787
|
if isinstance(func_ty.output, BoundTypeVar):
|
|
@@ -8,7 +8,7 @@ from guppylang_internals.checker.modifier_checker import non_copyable_front_othe
|
|
|
8
8
|
from guppylang_internals.compiler.cfg_compiler import compile_cfg
|
|
9
9
|
from guppylang_internals.compiler.core import CompilerContext, DFContainer
|
|
10
10
|
from guppylang_internals.compiler.expr_compiler import ExprCompiler
|
|
11
|
-
from guppylang_internals.definition.
|
|
11
|
+
from guppylang_internals.definition.metadata import add_metadata
|
|
12
12
|
from guppylang_internals.nodes import CheckedModifiedBlock, PlaceNode
|
|
13
13
|
from guppylang_internals.std._internal.compiler.array import (
|
|
14
14
|
array_new,
|
|
@@ -57,7 +57,10 @@ def compile_modified_block(
|
|
|
57
57
|
func_builder = dfg.builder.module_root_builder().define_function(
|
|
58
58
|
str(modified_block), hugr_ty.input, hugr_ty.output
|
|
59
59
|
)
|
|
60
|
-
|
|
60
|
+
add_metadata(
|
|
61
|
+
func_builder,
|
|
62
|
+
additional_metadata={"unitary": modified_block.ty.unitary_flags.value},
|
|
63
|
+
)
|
|
61
64
|
|
|
62
65
|
# compile body
|
|
63
66
|
cfg = compile_cfg(modified_block.cfg, func_builder, func_builder.inputs(), ctx)
|
guppylang_internals/decorator.py
CHANGED
|
@@ -4,10 +4,10 @@ import inspect
|
|
|
4
4
|
import pathlib
|
|
5
5
|
from typing import TYPE_CHECKING, ParamSpec, TypeVar, overload
|
|
6
6
|
|
|
7
|
+
from guppylang.defs import GuppyDefinition, GuppyFunctionDefinition
|
|
7
8
|
from hugr import ops
|
|
8
9
|
from hugr import tys as ht
|
|
9
10
|
|
|
10
|
-
from guppylang.defs import GuppyDefinition, GuppyFunctionDefinition
|
|
11
11
|
from guppylang_internals.compiler.core import (
|
|
12
12
|
CompilerContext,
|
|
13
13
|
GlobalConstId,
|
|
@@ -26,7 +26,7 @@ from guppylang_internals.definition.ty import OpaqueTypeDef, TypeDef
|
|
|
26
26
|
from guppylang_internals.definition.wasm import RawWasmFunctionDef
|
|
27
27
|
from guppylang_internals.dummy_decorator import _dummy_custom_decorator, sphinx_running
|
|
28
28
|
from guppylang_internals.engine import DEF_STORE
|
|
29
|
-
from guppylang_internals.error import GuppyError
|
|
29
|
+
from guppylang_internals.error import GuppyError, pretty_errors
|
|
30
30
|
from guppylang_internals.std._internal.checker import WasmCallChecker
|
|
31
31
|
from guppylang_internals.std._internal.compiler.wasm import (
|
|
32
32
|
WasmModuleCallCompiler,
|
|
@@ -193,7 +193,7 @@ def custom_type(
|
|
|
193
193
|
params or [],
|
|
194
194
|
not copyable,
|
|
195
195
|
not droppable,
|
|
196
|
-
mk_hugr_ty,
|
|
196
|
+
mk_hugr_ty, # type: ignore[arg-type]
|
|
197
197
|
bound,
|
|
198
198
|
)
|
|
199
199
|
DEF_STORE.register_def(defn, get_calling_frame())
|
|
@@ -207,6 +207,7 @@ def custom_type(
|
|
|
207
207
|
return dec
|
|
208
208
|
|
|
209
209
|
|
|
210
|
+
@pretty_errors
|
|
210
211
|
def wasm_module(
|
|
211
212
|
filename: str,
|
|
212
213
|
) -> Callable[[builtins.type[T]], GuppyDefinition]:
|
|
@@ -252,6 +253,7 @@ def ext_module_decorator(
|
|
|
252
253
|
def fun(
|
|
253
254
|
filename: str, module: str | None
|
|
254
255
|
) -> Callable[[builtins.type[T]], GuppyDefinition]:
|
|
256
|
+
@pretty_errors
|
|
255
257
|
def dec(cls: builtins.type[T]) -> GuppyDefinition:
|
|
256
258
|
# N.B. Only one module per file and vice-versa
|
|
257
259
|
ext_module = type_def(
|
|
@@ -163,6 +163,7 @@ class MonomorphizableDef(Definition):
|
|
|
163
163
|
module: DefinitionBuilder[OpVar],
|
|
164
164
|
mono_args: "PartiallyMonomorphizedArgs",
|
|
165
165
|
ctx: "CompilerContext",
|
|
166
|
+
parent_ty: "RawDef | None" = None,
|
|
166
167
|
) -> "MonomorphizedDef":
|
|
167
168
|
"""Adds a Hugr node for the (partially) monomorphized definition to the provided
|
|
168
169
|
Hugr module.
|
|
@@ -134,7 +134,7 @@ class RawCustomFunctionDef(ParsableDef):
|
|
|
134
134
|
"""
|
|
135
135
|
from guppylang_internals.definition.function import parse_py_func
|
|
136
136
|
|
|
137
|
-
func_ast,
|
|
137
|
+
func_ast, _docstring = parse_py_func(self.python_func, sources)
|
|
138
138
|
if not has_empty_body(func_ast):
|
|
139
139
|
raise GuppyError(BodyNotEmptyError(func_ast.body[0], self.name))
|
|
140
140
|
sig = self.signature or self._get_signature(func_ast, globals)
|
|
@@ -479,7 +479,7 @@ class BoolOpCompiler(CustomInoutCallCompiler):
|
|
|
479
479
|
for res in result
|
|
480
480
|
]
|
|
481
481
|
return CallReturnWires(
|
|
482
|
-
regular_returns=converted_result,
|
|
482
|
+
regular_returns=converted_result, # type: ignore[arg-type]
|
|
483
483
|
inout_returns=[],
|
|
484
484
|
)
|
|
485
485
|
|
|
@@ -120,9 +120,9 @@ class CheckedFunctionDecl(RawFunctionDecl, CompilableDef, CallableDef):
|
|
|
120
120
|
self, module: DefinitionBuilder[OpVar], ctx: CompilerContext
|
|
121
121
|
) -> "CompiledFunctionDecl":
|
|
122
122
|
"""Adds a Hugr `FuncDecl` node for this function to the Hugr."""
|
|
123
|
-
assert isinstance(
|
|
124
|
-
|
|
125
|
-
)
|
|
123
|
+
assert isinstance(module, hf.Module), (
|
|
124
|
+
"Functions can only be declared in modules"
|
|
125
|
+
)
|
|
126
126
|
module: hf.Module = module
|
|
127
127
|
|
|
128
128
|
node = module.declare_function(self.name, self.ty.to_hugr_poly(ctx))
|
|
@@ -31,8 +31,10 @@ from guppylang_internals.definition.common import (
|
|
|
31
31
|
MonomorphizableDef,
|
|
32
32
|
MonomorphizedDef,
|
|
33
33
|
ParsableDef,
|
|
34
|
+
RawDef,
|
|
34
35
|
UnknownSourceError,
|
|
35
36
|
)
|
|
37
|
+
from guppylang_internals.definition.metadata import GuppyMetadata, add_metadata
|
|
36
38
|
from guppylang_internals.definition.value import (
|
|
37
39
|
CallableDef,
|
|
38
40
|
CallReturnWires,
|
|
@@ -72,13 +74,22 @@ class RawFunctionDef(ParsableDef):
|
|
|
72
74
|
|
|
73
75
|
unitary_flags: UnitaryFlags = field(default=UnitaryFlags.NoFlags, kw_only=True)
|
|
74
76
|
|
|
77
|
+
metadata: GuppyMetadata | None = field(default=None, kw_only=True)
|
|
78
|
+
|
|
75
79
|
def parse(self, globals: Globals, sources: SourceMap) -> "ParsedFunctionDef":
|
|
76
80
|
"""Parses and checks the user-provided signature of the function."""
|
|
77
81
|
func_ast, docstring = parse_py_func(self.python_func, sources)
|
|
78
82
|
ty = check_signature(
|
|
79
83
|
func_ast, globals, self.id, unitary_flags=self.unitary_flags
|
|
80
84
|
)
|
|
81
|
-
return ParsedFunctionDef(
|
|
85
|
+
return ParsedFunctionDef(
|
|
86
|
+
self.id,
|
|
87
|
+
self.name,
|
|
88
|
+
func_ast,
|
|
89
|
+
ty,
|
|
90
|
+
docstring,
|
|
91
|
+
metadata=self.metadata,
|
|
92
|
+
)
|
|
82
93
|
|
|
83
94
|
|
|
84
95
|
@dataclass(frozen=True)
|
|
@@ -103,6 +114,8 @@ class ParsedFunctionDef(CheckableDef, CallableDef):
|
|
|
103
114
|
|
|
104
115
|
description: str = field(default="function", init=False)
|
|
105
116
|
|
|
117
|
+
metadata: GuppyMetadata | None = field(default=None, kw_only=True)
|
|
118
|
+
|
|
106
119
|
def check(self, globals: Globals) -> "CheckedFunctionDef":
|
|
107
120
|
"""Type checks the body of the function."""
|
|
108
121
|
# Add python variable scope to the globals
|
|
@@ -114,6 +127,7 @@ class ParsedFunctionDef(CheckableDef, CallableDef):
|
|
|
114
127
|
self.ty,
|
|
115
128
|
self.docstring,
|
|
116
129
|
cfg,
|
|
130
|
+
metadata=self.metadata,
|
|
117
131
|
)
|
|
118
132
|
|
|
119
133
|
def check_call(
|
|
@@ -164,6 +178,7 @@ class CheckedFunctionDef(ParsedFunctionDef, MonomorphizableDef):
|
|
|
164
178
|
module: DefinitionBuilder[OpVar],
|
|
165
179
|
mono_args: "PartiallyMonomorphizedArgs",
|
|
166
180
|
ctx: "CompilerContext",
|
|
181
|
+
parent_ty: "RawDef | None" = None,
|
|
167
182
|
) -> "CompiledFunctionDef":
|
|
168
183
|
"""Adds a Hugr `FuncDefn` node for the (partially) monomorphized function to the
|
|
169
184
|
Hugr.
|
|
@@ -172,12 +187,21 @@ class CheckedFunctionDef(ParsedFunctionDef, MonomorphizableDef):
|
|
|
172
187
|
access to the other compiled functions yet. The body is compiled later in
|
|
173
188
|
`CompiledFunctionDef.compile_inner()`.
|
|
174
189
|
"""
|
|
190
|
+
if parent_ty is None:
|
|
191
|
+
hugr_func_name = self.name
|
|
192
|
+
else:
|
|
193
|
+
hugr_func_name = f"{parent_ty.name}.{self.name}"
|
|
194
|
+
|
|
175
195
|
mono_ty = self.ty.instantiate_partial(mono_args)
|
|
176
196
|
hugr_ty = mono_ty.to_hugr_poly(ctx)
|
|
177
197
|
func_def = module.module_root_builder().define_function(
|
|
178
|
-
|
|
198
|
+
hugr_func_name, hugr_ty.body.input, hugr_ty.body.output, hugr_ty.params
|
|
199
|
+
)
|
|
200
|
+
add_metadata(
|
|
201
|
+
func_def,
|
|
202
|
+
self.metadata,
|
|
203
|
+
additional_metadata={"unitary": self.ty.unitary_flags.value},
|
|
179
204
|
)
|
|
180
|
-
add_unitarity_metadata(func_def, self.ty.unitary_flags)
|
|
181
205
|
return CompiledFunctionDef(
|
|
182
206
|
self.id,
|
|
183
207
|
self.name,
|
|
@@ -187,6 +211,7 @@ class CheckedFunctionDef(ParsedFunctionDef, MonomorphizableDef):
|
|
|
187
211
|
self.docstring,
|
|
188
212
|
self.cfg,
|
|
189
213
|
func_def,
|
|
214
|
+
metadata=self.metadata,
|
|
190
215
|
)
|
|
191
216
|
|
|
192
217
|
|
|
@@ -305,8 +330,3 @@ def parse_source(source_lines: list[str], line_offset: int) -> tuple[str, ast.AS
|
|
|
305
330
|
else:
|
|
306
331
|
node = ast.parse(source).body[0]
|
|
307
332
|
return source, node, line_offset
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
def add_unitarity_metadata(func: hf.Function, flags: UnitaryFlags) -> None:
|
|
311
|
-
"""Stores unitarity annotations in the metadate of a Hugr function definition."""
|
|
312
|
-
func.metadata["unitary"] = flags.value
|
|
@@ -0,0 +1,87 @@
|
|
|
1
|
+
"""Metadata attached to objects within the Guppy compiler, both for internal use and to
|
|
2
|
+
attach to HUGR nodes for lower-level processing."""
|
|
3
|
+
|
|
4
|
+
from abc import ABC
|
|
5
|
+
from dataclasses import dataclass, field, fields
|
|
6
|
+
from typing import Any, ClassVar, Generic, TypeVar
|
|
7
|
+
|
|
8
|
+
from hugr.hugr.node_port import ToNode
|
|
9
|
+
|
|
10
|
+
from guppylang_internals.diagnostic import Fatal
|
|
11
|
+
from guppylang_internals.error import GuppyError
|
|
12
|
+
|
|
13
|
+
T = TypeVar("T")
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@dataclass(init=True, kw_only=True)
|
|
17
|
+
class GuppyMetadataValue(ABC, Generic[T]):
|
|
18
|
+
"""A template class for a metadata value within the scope of the Guppy compiler.
|
|
19
|
+
Implementations should provide the `key` in reverse-URL format."""
|
|
20
|
+
|
|
21
|
+
key: ClassVar[str]
|
|
22
|
+
value: T | None = None
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class MetadataMaxQubits(GuppyMetadataValue[int]):
|
|
26
|
+
key = "tket.hint.max_qubits"
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
@dataclass(frozen=True, init=True, kw_only=True)
|
|
30
|
+
class GuppyMetadata:
|
|
31
|
+
"""DTO for metadata within the scope of the guppy compiler for attachment to HUGR
|
|
32
|
+
nodes. See `add_metadata`."""
|
|
33
|
+
|
|
34
|
+
max_qubits: MetadataMaxQubits = field(default_factory=MetadataMaxQubits, init=False)
|
|
35
|
+
|
|
36
|
+
@classmethod
|
|
37
|
+
def reserved_keys(cls) -> set[str]:
|
|
38
|
+
return {f.type.key for f in fields(GuppyMetadata)} # type: ignore[union-attr]
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
@dataclass(frozen=True)
|
|
42
|
+
class MetadataAlreadySetError(Fatal):
|
|
43
|
+
title: ClassVar[str] = "Metadata key already set"
|
|
44
|
+
message: ClassVar[str] = "Received two values for the metadata key `{key}`"
|
|
45
|
+
key: str
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
@dataclass(frozen=True)
|
|
49
|
+
class ReservedMetadataKeysError(Fatal):
|
|
50
|
+
title: ClassVar[str] = "Metadata key is reserved"
|
|
51
|
+
message: ClassVar[str] = (
|
|
52
|
+
"The following metadata keys are reserved by Guppy but also provided in "
|
|
53
|
+
"additional metadata: `{keys}`"
|
|
54
|
+
)
|
|
55
|
+
keys: set[str]
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def add_metadata(
|
|
59
|
+
node: ToNode,
|
|
60
|
+
metadata: GuppyMetadata | None = None,
|
|
61
|
+
*,
|
|
62
|
+
additional_metadata: dict[str, Any] | None = None,
|
|
63
|
+
) -> None:
|
|
64
|
+
"""Adds metadata to the given node using the keys defined through inheritors of
|
|
65
|
+
`GuppyMetadataValue` defined in the `GuppyMetadata` class.
|
|
66
|
+
|
|
67
|
+
Additional metadata is forwarded as is, although the given dictionary may not
|
|
68
|
+
contain any keys already reserved by fields in `GuppyMetadata`.
|
|
69
|
+
"""
|
|
70
|
+
if metadata is not None:
|
|
71
|
+
for f in fields(GuppyMetadata):
|
|
72
|
+
data: GuppyMetadataValue[Any] = getattr(metadata, f.name)
|
|
73
|
+
if data.key in node.metadata:
|
|
74
|
+
raise GuppyError(MetadataAlreadySetError(None, data.key))
|
|
75
|
+
if data.value is not None:
|
|
76
|
+
node.metadata[data.key] = data.value
|
|
77
|
+
|
|
78
|
+
if additional_metadata is not None:
|
|
79
|
+
reserved_keys = GuppyMetadata.reserved_keys()
|
|
80
|
+
used_reserved_keys = reserved_keys.intersection(additional_metadata.keys())
|
|
81
|
+
if len(used_reserved_keys) > 0:
|
|
82
|
+
raise GuppyError(ReservedMetadataKeysError(None, keys=used_reserved_keys))
|
|
83
|
+
|
|
84
|
+
for key, value in additional_metadata.items():
|
|
85
|
+
if key in node.metadata:
|
|
86
|
+
raise GuppyError(MetadataAlreadySetError(None, key))
|
|
87
|
+
node.metadata[key] = value
|
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
import ast
|
|
2
|
+
import copy
|
|
2
3
|
from contextlib import suppress
|
|
3
4
|
from dataclasses import dataclass, field
|
|
4
5
|
from typing import ClassVar, NoReturn
|
|
@@ -86,7 +87,11 @@ class OverloadedFunctionDef(CompiledCallableDef, CallableDef):
|
|
|
86
87
|
assert isinstance(defn, CallableDef)
|
|
87
88
|
available_sigs.append(defn.ty)
|
|
88
89
|
with suppress(GuppyError):
|
|
89
|
-
|
|
90
|
+
# check_call may modify args and node,
|
|
91
|
+
# thus we deepcopy them before passing in the function
|
|
92
|
+
node_copy = copy.deepcopy(node)
|
|
93
|
+
args_copy = copy.deepcopy(args)
|
|
94
|
+
return defn.check_call(args_copy, ty, node_copy, ctx)
|
|
90
95
|
return self._call_error(args, node, ctx, available_sigs, ty)
|
|
91
96
|
|
|
92
97
|
def synthesize_call(
|
|
@@ -98,7 +103,11 @@ class OverloadedFunctionDef(CompiledCallableDef, CallableDef):
|
|
|
98
103
|
assert isinstance(defn, CallableDef)
|
|
99
104
|
available_sigs.append(defn.ty)
|
|
100
105
|
with suppress(GuppyError):
|
|
101
|
-
|
|
106
|
+
# synthesize_call may modify args and node,
|
|
107
|
+
# thus we deepcopy them before passing in the function
|
|
108
|
+
node_copy = copy.deepcopy(node)
|
|
109
|
+
args_copy = copy.deepcopy(args)
|
|
110
|
+
return defn.synthesize_call(args_copy, node_copy, ctx)
|
|
102
111
|
return self._call_error(args, node, ctx, available_sigs)
|
|
103
112
|
|
|
104
113
|
def _call_error(
|
|
@@ -3,18 +3,17 @@ from dataclasses import dataclass, field
|
|
|
3
3
|
from typing import Any, cast
|
|
4
4
|
|
|
5
5
|
import hugr.build.function as hf
|
|
6
|
+
from guppylang.defs import GuppyDefinition
|
|
6
7
|
from hugr import Node, Wire, envelope, ops, val
|
|
7
8
|
from hugr import tys as ht
|
|
8
9
|
from hugr.build.dfg import DefinitionBuilder, OpVar
|
|
9
10
|
from hugr.envelope import EnvelopeConfig
|
|
10
11
|
from hugr.std.float import FLOAT_T
|
|
12
|
+
from pytket.circuit import Circuit
|
|
11
13
|
|
|
12
14
|
from guppylang_internals.ast_util import AstNode, has_empty_body, with_loc
|
|
13
15
|
from guppylang_internals.checker.core import Context, Globals
|
|
14
|
-
from guppylang_internals.checker.errors.comptime_errors import
|
|
15
|
-
PytketSignatureMismatch,
|
|
16
|
-
TketNotInstalled,
|
|
17
|
-
)
|
|
16
|
+
from guppylang_internals.checker.errors.comptime_errors import PytketSignatureMismatch
|
|
18
17
|
from guppylang_internals.checker.expr_checker import check_call, synthesize_call
|
|
19
18
|
from guppylang_internals.checker.func_checker import (
|
|
20
19
|
check_signature,
|
|
@@ -46,6 +45,7 @@ from guppylang_internals.std._internal.compiler.array import (
|
|
|
46
45
|
array_new,
|
|
47
46
|
array_unpack,
|
|
48
47
|
)
|
|
48
|
+
from guppylang_internals.std._internal.compiler.quantum import from_halfturns_unchecked
|
|
49
49
|
from guppylang_internals.std._internal.compiler.tket_bool import OpaqueBool, make_opaque
|
|
50
50
|
from guppylang_internals.tys.builtin import array_type, bool_type, float_type
|
|
51
51
|
from guppylang_internals.tys.subst import Inst, Subst
|
|
@@ -230,17 +230,20 @@ class ParsedPytketDef(CallableDef, CompilableDef):
|
|
|
230
230
|
)
|
|
231
231
|
lex_params = list(unpack_result)
|
|
232
232
|
param_order = cast(
|
|
233
|
-
list[str], hugr_func.metadata["TKET1.input_parameters"]
|
|
233
|
+
"list[str]", hugr_func.metadata["TKET1.input_parameters"]
|
|
234
234
|
)
|
|
235
235
|
lex_names = sorted(param_order)
|
|
236
236
|
name_to_param = dict(zip(lex_names, lex_params, strict=True))
|
|
237
237
|
angle_wires = [name_to_param[name] for name in param_order]
|
|
238
|
-
# Need to convert all angles to
|
|
238
|
+
# Need to convert all angles to rotations.
|
|
239
239
|
for angle in angle_wires:
|
|
240
240
|
[halfturns] = outer_func.add_op(
|
|
241
241
|
ops.UnpackTuple([FLOAT_T]), angle
|
|
242
242
|
)
|
|
243
|
-
|
|
243
|
+
rotation = outer_func.add_op(
|
|
244
|
+
from_halfturns_unchecked(), halfturns
|
|
245
|
+
)
|
|
246
|
+
param_wires.append(rotation)
|
|
244
247
|
|
|
245
248
|
# Pass all arguments to call node.
|
|
246
249
|
call_node = outer_func.call(
|
|
@@ -365,69 +368,49 @@ class CompiledPytketDef(ParsedPytketDef, CompiledCallableDef, CompiledHugrNodeDe
|
|
|
365
368
|
|
|
366
369
|
|
|
367
370
|
def _signature_from_circuit(
|
|
368
|
-
input_circuit:
|
|
371
|
+
input_circuit: Circuit,
|
|
369
372
|
defined_at: ToSpan | None,
|
|
370
373
|
use_arrays: bool = False,
|
|
371
374
|
) -> FunctionType:
|
|
372
375
|
"""Helper function for inferring a function signature from a pytket circuit."""
|
|
373
376
|
# May want to set proper unitary flags in the future.
|
|
374
|
-
|
|
375
|
-
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
|
|
382
|
-
|
|
383
|
-
|
|
384
|
-
|
|
385
|
-
|
|
386
|
-
|
|
387
|
-
|
|
388
|
-
|
|
389
|
-
|
|
390
|
-
|
|
391
|
-
|
|
392
|
-
|
|
393
|
-
|
|
394
|
-
|
|
395
|
-
|
|
396
|
-
|
|
397
|
-
|
|
398
|
-
|
|
399
|
-
|
|
400
|
-
|
|
401
|
-
|
|
402
|
-
|
|
403
|
-
)
|
|
404
|
-
outputs = [
|
|
405
|
-
array_type(bool_type(), c_reg.size)
|
|
406
|
-
for c_reg in input_circuit.c_registers
|
|
407
|
-
]
|
|
408
|
-
circuit_signature = FunctionType(
|
|
409
|
-
inputs,
|
|
410
|
-
row_to_type(outputs),
|
|
411
|
-
)
|
|
412
|
-
else:
|
|
413
|
-
param_inputs = [
|
|
414
|
-
FuncInput(angle_ty, InputFlags.NoFlags)
|
|
415
|
-
for _ in range(len(input_circuit.free_symbols()))
|
|
416
|
-
]
|
|
417
|
-
circuit_signature = FunctionType(
|
|
418
|
-
[FuncInput(qubit_ty, InputFlags.Inout)] * input_circuit.n_qubits
|
|
419
|
-
+ param_inputs,
|
|
420
|
-
row_to_type([bool_type()] * input_circuit.n_bits),
|
|
421
|
-
)
|
|
422
|
-
except ImportError:
|
|
423
|
-
err = TketNotInstalled(defined_at)
|
|
424
|
-
err.add_sub_diagnostic(TketNotInstalled.InstallInstruction(None))
|
|
425
|
-
raise GuppyError(err) from None
|
|
426
|
-
else:
|
|
427
|
-
pass
|
|
428
|
-
except ImportError:
|
|
429
|
-
raise InternalGuppyError(
|
|
430
|
-
"Pytket error should have been caught earlier"
|
|
431
|
-
) from None
|
|
377
|
+
from guppylang.std.angles import angle # Avoid circular imports
|
|
378
|
+
from guppylang.std.quantum import qubit
|
|
379
|
+
|
|
380
|
+
assert isinstance(qubit, GuppyDefinition)
|
|
381
|
+
qubit_ty = cast("TypeDef", qubit.wrapped).check_instantiate([])
|
|
382
|
+
|
|
383
|
+
angle_defn = ENGINE.get_checked(angle.id) # type: ignore[attr-defined]
|
|
384
|
+
assert isinstance(angle_defn, TypeDef)
|
|
385
|
+
angle_ty = angle_defn.check_instantiate([])
|
|
386
|
+
|
|
387
|
+
if use_arrays:
|
|
388
|
+
inputs = [
|
|
389
|
+
FuncInput(array_type(qubit_ty, q_reg.size), InputFlags.Inout)
|
|
390
|
+
for q_reg in input_circuit.q_registers
|
|
391
|
+
]
|
|
392
|
+
if len(input_circuit.free_symbols()) != 0:
|
|
393
|
+
inputs.append(
|
|
394
|
+
FuncInput(
|
|
395
|
+
array_type(angle_ty, len(input_circuit.free_symbols())),
|
|
396
|
+
InputFlags.NoFlags,
|
|
397
|
+
)
|
|
398
|
+
)
|
|
399
|
+
outputs = [
|
|
400
|
+
array_type(bool_type(), c_reg.size) for c_reg in input_circuit.c_registers
|
|
401
|
+
]
|
|
402
|
+
circuit_signature = FunctionType(
|
|
403
|
+
inputs,
|
|
404
|
+
row_to_type(outputs),
|
|
405
|
+
)
|
|
432
406
|
else:
|
|
433
|
-
|
|
407
|
+
param_inputs = [
|
|
408
|
+
FuncInput(angle_ty, InputFlags.NoFlags)
|
|
409
|
+
for _ in range(len(input_circuit.free_symbols()))
|
|
410
|
+
]
|
|
411
|
+
circuit_signature = FunctionType(
|
|
412
|
+
[FuncInput(qubit_ty, InputFlags.Inout)] * input_circuit.n_qubits
|
|
413
|
+
+ param_inputs,
|
|
414
|
+
row_to_type([bool_type()] * input_circuit.n_bits),
|
|
415
|
+
)
|
|
416
|
+
return circuit_signature
|
|
@@ -55,7 +55,7 @@ class CallableDef(ValueDef):
|
|
|
55
55
|
raise RuntimeError("Guppy functions can only be called in a Guppy context")
|
|
56
56
|
|
|
57
57
|
|
|
58
|
-
class CompiledCallableDef(CallableDef, CompiledValueDef):
|
|
58
|
+
class CompiledCallableDef(CallableDef, CompiledValueDef): # type: ignore[misc, unused-ignore]
|
|
59
59
|
"""Abstract base class a global module-level function."""
|
|
60
60
|
|
|
61
61
|
ty: FunctionType
|
|
@@ -38,9 +38,9 @@ class RawWasmFunctionDef(RawCustomFunctionDef):
|
|
|
38
38
|
def sanitise_type(self, loc: AstNode, fun_ty: FunctionType) -> None:
|
|
39
39
|
# Place to highlight in error messages
|
|
40
40
|
match fun_ty.inputs:
|
|
41
|
-
case [FuncInput(ty=ty, flags=InputFlags.Inout), *args] if
|
|
42
|
-
ty
|
|
43
|
-
)
|
|
41
|
+
case [FuncInput(ty=ty, flags=InputFlags.Inout), *args] if (
|
|
42
|
+
wasm_module_name(ty) is not None
|
|
43
|
+
):
|
|
44
44
|
for inp in args:
|
|
45
45
|
if not self.is_type_wasmable(inp.ty):
|
|
46
46
|
raise GuppyError(UnWasmableType(loc, inp.ty))
|