guppylang-internals 0.25.0__py3-none-any.whl → 0.27.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- guppylang_internals/__init__.py +1 -1
- guppylang_internals/cfg/builder.py +20 -2
- guppylang_internals/cfg/cfg.py +3 -0
- guppylang_internals/checker/cfg_checker.py +6 -0
- guppylang_internals/checker/core.py +1 -2
- guppylang_internals/checker/errors/linearity.py +6 -2
- guppylang_internals/checker/errors/wasm.py +7 -4
- guppylang_internals/checker/expr_checker.py +39 -19
- guppylang_internals/checker/func_checker.py +17 -13
- guppylang_internals/checker/linearity_checker.py +2 -10
- guppylang_internals/checker/modifier_checker.py +6 -2
- guppylang_internals/checker/unitary_checker.py +132 -0
- guppylang_internals/compiler/cfg_compiler.py +7 -6
- guppylang_internals/compiler/core.py +5 -5
- guppylang_internals/compiler/expr_compiler.py +72 -81
- guppylang_internals/compiler/modifier_compiler.py +5 -0
- guppylang_internals/decorator.py +88 -7
- guppylang_internals/definition/custom.py +4 -0
- guppylang_internals/definition/declaration.py +6 -2
- guppylang_internals/definition/function.py +26 -3
- guppylang_internals/definition/metadata.py +87 -0
- guppylang_internals/definition/overloaded.py +11 -2
- guppylang_internals/definition/pytket_circuits.py +7 -2
- guppylang_internals/definition/struct.py +6 -3
- guppylang_internals/definition/wasm.py +42 -10
- guppylang_internals/diagnostic.py +72 -15
- guppylang_internals/engine.py +10 -13
- guppylang_internals/nodes.py +55 -24
- guppylang_internals/std/_internal/checker.py +13 -108
- guppylang_internals/std/_internal/compiler/array.py +37 -2
- guppylang_internals/std/_internal/compiler/either.py +14 -2
- guppylang_internals/std/_internal/compiler/list.py +1 -1
- guppylang_internals/std/_internal/compiler/platform.py +153 -0
- guppylang_internals/std/_internal/compiler/prelude.py +12 -4
- guppylang_internals/std/_internal/compiler/tket_bool.py +1 -6
- guppylang_internals/std/_internal/compiler/tket_exts.py +4 -5
- guppylang_internals/std/_internal/debug.py +18 -9
- guppylang_internals/std/_internal/util.py +1 -1
- guppylang_internals/tracing/object.py +14 -0
- guppylang_internals/tys/errors.py +23 -1
- guppylang_internals/tys/parsing.py +3 -3
- guppylang_internals/tys/printing.py +2 -8
- guppylang_internals/tys/qubit.py +37 -2
- guppylang_internals/tys/ty.py +60 -64
- guppylang_internals/wasm_util.py +129 -0
- {guppylang_internals-0.25.0.dist-info → guppylang_internals-0.27.0.dist-info}/METADATA +5 -4
- {guppylang_internals-0.25.0.dist-info → guppylang_internals-0.27.0.dist-info}/RECORD +49 -45
- {guppylang_internals-0.25.0.dist-info → guppylang_internals-0.27.0.dist-info}/WHEEL +1 -1
- {guppylang_internals-0.25.0.dist-info → guppylang_internals-0.27.0.dist-info}/licenses/LICENCE +0 -0
|
@@ -0,0 +1,87 @@
|
|
|
1
|
+
"""Metadata attached to objects within the Guppy compiler, both for internal use and to
|
|
2
|
+
attach to HUGR nodes for lower-level processing."""
|
|
3
|
+
|
|
4
|
+
from abc import ABC
|
|
5
|
+
from dataclasses import dataclass, field, fields
|
|
6
|
+
from typing import Any, ClassVar, Generic, TypeVar
|
|
7
|
+
|
|
8
|
+
from hugr.hugr.node_port import ToNode
|
|
9
|
+
|
|
10
|
+
from guppylang_internals.diagnostic import Fatal
|
|
11
|
+
from guppylang_internals.error import GuppyError
|
|
12
|
+
|
|
13
|
+
T = TypeVar("T")
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@dataclass(init=True, kw_only=True)
|
|
17
|
+
class GuppyMetadataValue(ABC, Generic[T]):
|
|
18
|
+
"""A template class for a metadata value within the scope of the Guppy compiler.
|
|
19
|
+
Implementations should provide the `key` in reverse-URL format."""
|
|
20
|
+
|
|
21
|
+
key: ClassVar[str]
|
|
22
|
+
value: T | None = None
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class MetadataMaxQubits(GuppyMetadataValue[int]):
|
|
26
|
+
key = "tket.hint.max_qubits"
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
@dataclass(frozen=True, init=True, kw_only=True)
|
|
30
|
+
class GuppyMetadata:
|
|
31
|
+
"""DTO for metadata within the scope of the guppy compiler for attachment to HUGR
|
|
32
|
+
nodes. See `add_metadata`."""
|
|
33
|
+
|
|
34
|
+
max_qubits: MetadataMaxQubits = field(default_factory=MetadataMaxQubits, init=False)
|
|
35
|
+
|
|
36
|
+
@classmethod
|
|
37
|
+
def reserved_keys(cls) -> set[str]:
|
|
38
|
+
return {f.type.key for f in fields(GuppyMetadata)}
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
@dataclass(frozen=True)
class MetadataAlreadySetError(Fatal):
    """Diagnostic for a metadata key that received more than one value."""

    message: ClassVar[str] = "Received two values for the metadata key `{key}`"
    title: ClassVar[str] = "Metadata key already set"

    #: The metadata key that received two values.
    key: str
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
@dataclass(frozen=True)
class ReservedMetadataKeysError(Fatal):
    """Diagnostic for additional metadata that uses keys reserved by Guppy."""

    message: ClassVar[str] = (
        "The following metadata keys are reserved by Guppy but also provided in "
        "additional metadata: `{keys}`"
    )
    title: ClassVar[str] = "Metadata key is reserved"

    #: The reserved keys that were supplied as additional metadata.
    keys: set[str]
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def add_metadata(
    node: ToNode,
    metadata: GuppyMetadata | None = None,
    *,
    additional_metadata: dict[str, Any] | None = None,
) -> None:
    """Adds metadata to the given node using the keys defined through inheritors of
    `GuppyMetadataValue` defined in the `GuppyMetadata` class.

    Additional metadata is forwarded as is, although the given dictionary may not
    contain any keys already reserved by fields in `GuppyMetadata`.

    All checks run before anything is written, so the node's metadata is left
    untouched whenever an error is raised.

    Args:
        node: The node whose `metadata` mapping is extended.
        metadata: Guppy-defined metadata values. Fields whose value is `None` are
            unset: they are not written and do not conflict with existing entries.
        additional_metadata: Extra key/value pairs forwarded verbatim.

    Raises:
        GuppyError: Wrapping a `MetadataAlreadySetError` if a key to be written is
            already present on the node, or a `ReservedMetadataKeysError` if
            `additional_metadata` uses a key reserved by `GuppyMetadata`.
    """
    # Collect all entries first so that a failing check (whose `GuppyError` a caller
    # may catch and recover from) cannot leave the node partially updated.
    updates: dict[str, Any] = {}

    if metadata is not None:
        for f in fields(GuppyMetadata):
            data: GuppyMetadataValue[Any] = getattr(metadata, f.name)
            # Unset values are never written, so they cannot clash with the node.
            if data.value is None:
                continue
            if data.key in node.metadata:
                raise GuppyError(MetadataAlreadySetError(None, data.key))
            updates[data.key] = data.value

    if additional_metadata is not None:
        used_reserved_keys = GuppyMetadata.reserved_keys().intersection(
            additional_metadata
        )
        if used_reserved_keys:
            raise GuppyError(ReservedMetadataKeysError(None, keys=used_reserved_keys))

        # Reserved keys were rejected above, so these cannot collide with `updates`.
        for key, value in additional_metadata.items():
            if key in node.metadata:
                raise GuppyError(MetadataAlreadySetError(None, key))
            updates[key] = value

    for key, value in updates.items():
        node.metadata[key] = value
|
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
import ast
|
|
2
|
+
import copy
|
|
2
3
|
from contextlib import suppress
|
|
3
4
|
from dataclasses import dataclass, field
|
|
4
5
|
from typing import ClassVar, NoReturn
|
|
@@ -86,7 +87,11 @@ class OverloadedFunctionDef(CompiledCallableDef, CallableDef):
|
|
|
86
87
|
assert isinstance(defn, CallableDef)
|
|
87
88
|
available_sigs.append(defn.ty)
|
|
88
89
|
with suppress(GuppyError):
|
|
89
|
-
|
|
90
|
+
# check_call may modify args and node,
|
|
91
|
+
# thus we deepcopy them before passing in the function
|
|
92
|
+
node_copy = copy.deepcopy(node)
|
|
93
|
+
args_copy = copy.deepcopy(args)
|
|
94
|
+
return defn.check_call(args_copy, ty, node_copy, ctx)
|
|
90
95
|
return self._call_error(args, node, ctx, available_sigs, ty)
|
|
91
96
|
|
|
92
97
|
def synthesize_call(
|
|
@@ -98,7 +103,11 @@ class OverloadedFunctionDef(CompiledCallableDef, CallableDef):
|
|
|
98
103
|
assert isinstance(defn, CallableDef)
|
|
99
104
|
available_sigs.append(defn.ty)
|
|
100
105
|
with suppress(GuppyError):
|
|
101
|
-
|
|
106
|
+
# synthesize_call may modify args and node,
|
|
107
|
+
# thus we deepcopy them before passing in the function
|
|
108
|
+
node_copy = copy.deepcopy(node)
|
|
109
|
+
args_copy = copy.deepcopy(args)
|
|
110
|
+
return defn.synthesize_call(args_copy, node_copy, ctx)
|
|
102
111
|
return self._call_error(args, node, ctx, available_sigs)
|
|
103
112
|
|
|
104
113
|
def _call_error(
|
|
@@ -46,6 +46,7 @@ from guppylang_internals.std._internal.compiler.array import (
|
|
|
46
46
|
array_new,
|
|
47
47
|
array_unpack,
|
|
48
48
|
)
|
|
49
|
+
from guppylang_internals.std._internal.compiler.quantum import from_halfturns_unchecked
|
|
49
50
|
from guppylang_internals.std._internal.compiler.tket_bool import OpaqueBool, make_opaque
|
|
50
51
|
from guppylang_internals.tys.builtin import array_type, bool_type, float_type
|
|
51
52
|
from guppylang_internals.tys.subst import Inst, Subst
|
|
@@ -235,12 +236,15 @@ class ParsedPytketDef(CallableDef, CompilableDef):
|
|
|
235
236
|
lex_names = sorted(param_order)
|
|
236
237
|
name_to_param = dict(zip(lex_names, lex_params, strict=True))
|
|
237
238
|
angle_wires = [name_to_param[name] for name in param_order]
|
|
238
|
-
# Need to convert all angles to
|
|
239
|
+
# Need to convert all angles to rotations.
|
|
239
240
|
for angle in angle_wires:
|
|
240
241
|
[halfturns] = outer_func.add_op(
|
|
241
242
|
ops.UnpackTuple([FLOAT_T]), angle
|
|
242
243
|
)
|
|
243
|
-
|
|
244
|
+
rotation = outer_func.add_op(
|
|
245
|
+
from_halfturns_unchecked(), halfturns
|
|
246
|
+
)
|
|
247
|
+
param_wires.append(rotation)
|
|
244
248
|
|
|
245
249
|
# Pass all arguments to call node.
|
|
246
250
|
call_node = outer_func.call(
|
|
@@ -370,6 +374,7 @@ def _signature_from_circuit(
|
|
|
370
374
|
use_arrays: bool = False,
|
|
371
375
|
) -> FunctionType:
|
|
372
376
|
"""Helper function for inferring a function signature from a pytket circuit."""
|
|
377
|
+
# May want to set proper unitary flags in the future.
|
|
373
378
|
try:
|
|
374
379
|
import pytket
|
|
375
380
|
|
|
@@ -273,13 +273,16 @@ class CheckedStructDef(TypeDef, CompiledDef):
|
|
|
273
273
|
|
|
274
274
|
constructor_sig = FunctionType(
|
|
275
275
|
inputs=[
|
|
276
|
-
FuncInput(
|
|
276
|
+
FuncInput(
|
|
277
|
+
f.ty,
|
|
278
|
+
InputFlags.Owned if f.ty.linear else InputFlags.NoFlags,
|
|
279
|
+
f.name,
|
|
280
|
+
)
|
|
277
281
|
for f in self.fields
|
|
278
282
|
],
|
|
279
283
|
output=StructType(
|
|
280
284
|
defn=self, args=[p.to_bound(i) for i, p in enumerate(self.params)]
|
|
281
285
|
),
|
|
282
|
-
input_names=[f.name for f in self.fields],
|
|
283
286
|
params=self.params,
|
|
284
287
|
)
|
|
285
288
|
constructor_def = CustomFunctionDef(
|
|
@@ -317,7 +320,7 @@ def parse_py_class(
|
|
|
317
320
|
raise GuppyError(UnknownSourceError(None, cls))
|
|
318
321
|
|
|
319
322
|
# We can't rely on `inspect.getsourcelines` since it doesn't work properly for
|
|
320
|
-
# classes prior to Python 3.13. See https://github.com/
|
|
323
|
+
# classes prior to Python 3.13. See https://github.com/quantinuum/guppylang/issues/1107.
|
|
321
324
|
# Instead, we reproduce the behaviour of Python >= 3.13 using the `__firstlineno__`
|
|
322
325
|
# attribute. See https://github.com/python/cpython/blob/3.13/Lib/inspect.py#L1052.
|
|
323
326
|
# In the decorator, we make sure that `__firstlineno__` is set, even if we're not
|
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
from dataclasses import dataclass, field
|
|
1
2
|
from typing import TYPE_CHECKING
|
|
2
3
|
|
|
3
4
|
from guppylang_internals.ast_util import AstNode
|
|
@@ -9,7 +10,8 @@ from guppylang_internals.definition.custom import (
|
|
|
9
10
|
CustomFunctionDef,
|
|
10
11
|
RawCustomFunctionDef,
|
|
11
12
|
)
|
|
12
|
-
from guppylang_internals.
|
|
13
|
+
from guppylang_internals.engine import DEF_STORE
|
|
14
|
+
from guppylang_internals.error import GuppyError, GuppyTypeError
|
|
13
15
|
from guppylang_internals.span import SourceMap
|
|
14
16
|
from guppylang_internals.tys.builtin import wasm_module_name
|
|
15
17
|
from guppylang_internals.tys.ty import (
|
|
@@ -21,24 +23,35 @@ from guppylang_internals.tys.ty import (
|
|
|
21
23
|
TupleType,
|
|
22
24
|
Type,
|
|
23
25
|
)
|
|
26
|
+
from guppylang_internals.wasm_util import WasmSigMismatchError
|
|
24
27
|
|
|
25
28
|
if TYPE_CHECKING:
|
|
26
29
|
from guppylang_internals.checker.core import Globals
|
|
27
30
|
|
|
28
31
|
|
|
32
|
+
@dataclass(frozen=True)
|
|
29
33
|
class RawWasmFunctionDef(RawCustomFunctionDef):
|
|
30
|
-
|
|
34
|
+
# If a function is specified in the @wasm decorator by its index in the wasm
|
|
35
|
+
# file, record what the index was.
|
|
36
|
+
wasm_index: int | None = field(default=None)
|
|
37
|
+
|
|
38
|
+
def sanitise_type(self, loc: AstNode, fun_ty: FunctionType) -> None:
|
|
31
39
|
# Place to highlight in error messages
|
|
32
|
-
match fun_ty.inputs
|
|
33
|
-
case FuncInput(ty=ty, flags=InputFlags.Inout) if wasm_module_name(
|
|
40
|
+
match fun_ty.inputs:
|
|
41
|
+
case [FuncInput(ty=ty, flags=InputFlags.Inout), *args] if wasm_module_name(
|
|
34
42
|
ty
|
|
35
43
|
) is not None:
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
44
|
+
for inp in args:
|
|
45
|
+
if not self.is_type_wasmable(inp.ty):
|
|
46
|
+
raise GuppyError(UnWasmableType(loc, inp.ty))
|
|
47
|
+
case [FuncInput(ty=ty), *_]:
|
|
48
|
+
raise GuppyError(
|
|
49
|
+
FirstArgNotModule(loc).add_sub_diagnostic(
|
|
50
|
+
FirstArgNotModule.GotOtherType(loc, ty)
|
|
51
|
+
)
|
|
52
|
+
)
|
|
53
|
+
case []:
|
|
54
|
+
raise GuppyError(FirstArgNotModule(loc))
|
|
42
55
|
if not self.is_type_wasmable(fun_ty.output):
|
|
43
56
|
match fun_ty.output:
|
|
44
57
|
case NoneType():
|
|
@@ -46,6 +59,23 @@ class RawWasmFunctionDef(RawCustomFunctionDef):
|
|
|
46
59
|
case _:
|
|
47
60
|
raise GuppyError(UnWasmableType(loc, fun_ty.output))
|
|
48
61
|
|
|
62
|
+
def validate_type(self, loc: AstNode, fun_ty: FunctionType) -> None:
    """Checks the declared signature against the one registered for this function.

    The registered signature is looked up in `DEF_STORE.wasm_functions` under this
    definition's id.

    Raises:
        GuppyTypeError: Wrapping a `WasmSigMismatchError` if the declared signature,
            without its first argument, differs from the registered one.
    """
    # Use `.get` so a missing registration trips the assertion below with a clear
    # message; plain indexing would raise a bare `KeyError` first, which made the
    # assertion dead code.
    type_in_wasm: FunctionType | None = DEF_STORE.wasm_functions.get(self.id)
    assert type_in_wasm is not None, (
        f"No WASM signature registered for definition {self.id}"
    )
    # Drop the first arg because it should be "self" (the WASM module instance).
    expected_type = FunctionType(fun_ty.inputs[1:], fun_ty.output)

    if expected_type != type_in_wasm:
        raise GuppyTypeError(
            WasmSigMismatchError(loc)
            .add_sub_diagnostic(
                WasmSigMismatchError.Declaration(None, declared=str(expected_type))
            )
            .add_sub_diagnostic(
                WasmSigMismatchError.Actual(None, actual=str(type_in_wasm))
            )
        )
|
|
78
|
+
|
|
49
79
|
def is_type_wasmable(self, ty: Type) -> bool:
|
|
50
80
|
match ty:
|
|
51
81
|
case NumericType():
|
|
@@ -57,5 +87,7 @@ class RawWasmFunctionDef(RawCustomFunctionDef):
|
|
|
57
87
|
|
|
58
88
|
def parse(self, globals: "Globals", sources: SourceMap) -> "CustomFunctionDef":
    """Parses the declaration, then checks that its signature is usable for WASM.

    The parsed signature is first sanitised and then validated against the
    signature registered for this function; errors are reported at the
    definition's location.
    """
    parsed_def = super().parse(globals, sources)
    location = parsed_def.defined_at
    assert location is not None
    self.sanitise_type(location, parsed_def.ty)
    self.validate_type(location, parsed_def.ty)
    return parsed_def
|
|
@@ -208,7 +208,8 @@ class DiagnosticsRenderer:
|
|
|
208
208
|
MAX_MESSAGE_LINE_LEN: Final[int] = 80
|
|
209
209
|
|
|
210
210
|
#: Number of preceding source lines we show to give additional context
|
|
211
|
-
|
|
211
|
+
PREFIX_ERROR_CONTEXT_LINES: Final[int] = 2
|
|
212
|
+
PREFIX_NOTE_CONTEXT_LINES: Final[int] = 1
|
|
212
213
|
|
|
213
214
|
def __init__(self, source: SourceMap) -> None:
|
|
214
215
|
self.buffer = []
|
|
@@ -243,31 +244,84 @@ class DiagnosticsRenderer:
|
|
|
243
244
|
else:
|
|
244
245
|
span = to_span(diag.span)
|
|
245
246
|
level = self.level_str(diag.level)
|
|
246
|
-
|
|
247
|
-
|
|
247
|
+
|
|
248
|
+
children_with_span = [
|
|
249
|
+
(child, to_span(child.span)) for child in diag.children if child.span
|
|
248
250
|
]
|
|
251
|
+
all_spans = [span] + [span for _, span in children_with_span]
|
|
249
252
|
max_lineno = max(s.end.line for s in all_spans)
|
|
253
|
+
|
|
250
254
|
self.buffer.append(f"{level}: {diag.rendered_title} (at {span.start})")
|
|
255
|
+
|
|
256
|
+
# Render main error span first
|
|
251
257
|
self.render_snippet(
|
|
252
258
|
span,
|
|
253
259
|
diag.rendered_span_label,
|
|
254
260
|
max_lineno,
|
|
255
261
|
is_primary=True,
|
|
256
|
-
prefix_lines=self.
|
|
262
|
+
prefix_lines=self.PREFIX_ERROR_CONTEXT_LINES,
|
|
257
263
|
)
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
264
|
+
|
|
265
|
+
match children_with_span:
|
|
266
|
+
case []:
|
|
267
|
+
pass
|
|
268
|
+
case [(only_child, span)]:
|
|
269
|
+
self.buffer.append("\nNote:")
|
|
261
270
|
self.render_snippet(
|
|
262
|
-
|
|
263
|
-
|
|
271
|
+
span,
|
|
272
|
+
only_child.rendered_span_label,
|
|
264
273
|
max_lineno,
|
|
265
|
-
|
|
274
|
+
prefix_lines=self.PREFIX_NOTE_CONTEXT_LINES,
|
|
275
|
+
print_pad_line=True,
|
|
266
276
|
)
|
|
277
|
+
case [(first_child, first_span), *children_with_span]:
|
|
278
|
+
self.buffer.append("\nNotes:")
|
|
279
|
+
self.render_snippet(
|
|
280
|
+
first_span,
|
|
281
|
+
first_child.rendered_span_label,
|
|
282
|
+
max_lineno,
|
|
283
|
+
prefix_lines=self.PREFIX_NOTE_CONTEXT_LINES,
|
|
284
|
+
print_pad_line=True,
|
|
285
|
+
)
|
|
286
|
+
|
|
287
|
+
prev_span_end_lineno = first_span.end.line
|
|
288
|
+
|
|
289
|
+
for sub_diag, span in children_with_span:
|
|
290
|
+
span_start_lineno = span.start.line
|
|
291
|
+
span_end_lineno = span.end.line
|
|
292
|
+
|
|
293
|
+
# If notes are on the same line, render them together
|
|
294
|
+
if span_start_lineno == prev_span_end_lineno:
|
|
295
|
+
prefix_lines = 0
|
|
296
|
+
print_pad_line = True
|
|
297
|
+
# if notes are close enough, render them adjacently
|
|
298
|
+
elif (
|
|
299
|
+
span_start_lineno - self.PREFIX_NOTE_CONTEXT_LINES
|
|
300
|
+
<= prev_span_end_lineno + 1
|
|
301
|
+
):
|
|
302
|
+
prefix_lines = span_start_lineno - prev_span_end_lineno - 1
|
|
303
|
+
print_pad_line = False
|
|
304
|
+
# otherwise we render a separator between notes
|
|
305
|
+
else:
|
|
306
|
+
self.buffer.append("")
|
|
307
|
+
prefix_lines = self.PREFIX_NOTE_CONTEXT_LINES
|
|
308
|
+
print_pad_line = False
|
|
309
|
+
|
|
310
|
+
self.render_snippet(
|
|
311
|
+
span,
|
|
312
|
+
sub_diag.rendered_span_label,
|
|
313
|
+
max_lineno,
|
|
314
|
+
prefix_lines=prefix_lines,
|
|
315
|
+
print_pad_line=print_pad_line,
|
|
316
|
+
)
|
|
317
|
+
prev_span_end_lineno = span_end_lineno
|
|
318
|
+
|
|
319
|
+
# Render the main diagnostic message if present
|
|
267
320
|
if diag.rendered_message:
|
|
268
321
|
self.buffer.append("")
|
|
269
322
|
self.buffer += wrap(diag.rendered_message, self.MAX_MESSAGE_LINE_LEN)
|
|
270
|
-
|
|
323
|
+
|
|
324
|
+
# Render all sub-diagnostics that have a non-span message
|
|
271
325
|
for sub_diag in diag.children:
|
|
272
326
|
if sub_diag.rendered_message:
|
|
273
327
|
self.buffer.append("")
|
|
@@ -281,8 +335,9 @@ class DiagnosticsRenderer:
|
|
|
281
335
|
span: Span,
|
|
282
336
|
label: str | None,
|
|
283
337
|
max_lineno: int,
|
|
284
|
-
is_primary: bool,
|
|
338
|
+
is_primary: bool = False,
|
|
285
339
|
prefix_lines: int = 0,
|
|
340
|
+
print_pad_line: bool = False,
|
|
286
341
|
) -> None:
|
|
287
342
|
"""Renders the source associated with a span together with an optional label.
|
|
288
343
|
|
|
@@ -315,7 +370,8 @@ class DiagnosticsRenderer:
|
|
|
315
370
|
Optionally includes up to `prefix_lines` preceding source lines to give
|
|
316
371
|
additional context.
|
|
317
372
|
"""
|
|
318
|
-
# Check how much space we need to reserve for the leading
|
|
373
|
+
# Check how much horizontal space we need to reserve for the leading
|
|
374
|
+
# line numbers
|
|
319
375
|
ll_length = len(str(max_lineno))
|
|
320
376
|
highlight_char = "^" if is_primary else "-"
|
|
321
377
|
|
|
@@ -324,8 +380,9 @@ class DiagnosticsRenderer:
|
|
|
324
380
|
ll = "" if line_number is None else str(line_number)
|
|
325
381
|
self.buffer.append(" " * (ll_length - len(ll)) + ll + " | " + line)
|
|
326
382
|
|
|
327
|
-
# One line of padding
|
|
328
|
-
|
|
383
|
+
# One line of padding (primary span, first note or between same line notes)
|
|
384
|
+
if is_primary or print_pad_line:
|
|
385
|
+
render_line("")
|
|
329
386
|
|
|
330
387
|
# Grab all lines we want to display and remove excessive leading whitespace
|
|
331
388
|
prefix_lines = min(prefix_lines, span.start.line - 1)
|
guppylang_internals/engine.py
CHANGED
|
@@ -46,6 +46,7 @@ from guppylang_internals.tys.builtin import (
|
|
|
46
46
|
string_type_def,
|
|
47
47
|
tuple_type_def,
|
|
48
48
|
)
|
|
49
|
+
from guppylang_internals.tys.ty import FunctionType
|
|
49
50
|
|
|
50
51
|
if TYPE_CHECKING:
|
|
51
52
|
from guppylang_internals.compiler.core import MonoDefId
|
|
@@ -87,6 +88,7 @@ class DefinitionStore:
|
|
|
87
88
|
raw_defs: dict[DefId, RawDef]
|
|
88
89
|
impls: defaultdict[DefId, dict[str, DefId]]
|
|
89
90
|
impl_parents: dict[DefId, DefId]
|
|
91
|
+
wasm_functions: dict[DefId, FunctionType]
|
|
90
92
|
frames: dict[DefId, FrameType]
|
|
91
93
|
sources: SourceMap
|
|
92
94
|
|
|
@@ -96,6 +98,7 @@ class DefinitionStore:
|
|
|
96
98
|
self.impl_parents = {}
|
|
97
99
|
self.frames = {}
|
|
98
100
|
self.sources = SourceMap()
|
|
101
|
+
self.wasm_functions = {}
|
|
99
102
|
|
|
100
103
|
def register_def(self, defn: RawDef, frame: FrameType | None) -> None:
|
|
101
104
|
self.raw_defs[defn.id] = defn
|
|
@@ -123,6 +126,9 @@ class DefinitionStore:
|
|
|
123
126
|
assert frame is not None
|
|
124
127
|
self.frames[impl_id] = frame
|
|
125
128
|
|
|
129
|
+
def register_wasm_function(self, fn_id: DefId, sig: FunctionType) -> None:
    """Records `sig` as the signature of the WASM function with id `fn_id`.

    Any previously recorded signature for `fn_id` is replaced. The stored entry is
    what `RawWasmFunctionDef.validate_type` compares declarations against.
    """
    self.wasm_functions.update({fn_id: sig})
|
|
131
|
+
|
|
126
132
|
|
|
127
133
|
#: Module-level `DefinitionStore` instance through which raw definitions, frames,
#: sources and WASM signatures are registered and looked up.
DEF_STORE: DefinitionStore = DefinitionStore()
|
|
128
134
|
|
|
@@ -214,21 +220,12 @@ class CompilationEngine:
|
|
|
214
220
|
|
|
215
221
|
This is the main driver behind `guppy.check()`.
|
|
216
222
|
"""
|
|
217
|
-
from guppylang_internals.checker.core import Globals
|
|
218
|
-
|
|
219
223
|
# Clear previous compilation cache.
|
|
220
224
|
# TODO: In order to maintain results from the previous `check` call we would
|
|
221
225
|
# need to store and check if any dependencies have changed.
|
|
222
226
|
self.reset()
|
|
223
227
|
|
|
224
|
-
|
|
225
|
-
self.to_check_worklist = {
|
|
226
|
-
defn.id: (
|
|
227
|
-
defn.parse(Globals(DEF_STORE.frames[defn.id]), DEF_STORE.sources)
|
|
228
|
-
if isinstance(defn, ParsableDef)
|
|
229
|
-
else defn
|
|
230
|
-
)
|
|
231
|
-
}
|
|
228
|
+
self.to_check_worklist[id] = self.get_parsed(id)
|
|
232
229
|
while self.types_to_check_worklist or self.to_check_worklist:
|
|
233
230
|
# Types need to be checked first. This is because parsing e.g. a function
|
|
234
231
|
# definition requires instantiating the types in its signature which can
|
|
@@ -263,8 +260,8 @@ class CompilationEngine:
|
|
|
263
260
|
and isinstance(compiled_def, CompiledCallableDef)
|
|
264
261
|
and not isinstance(graph.hugr[compiled_def.hugr_node].op, ops.FuncDecl)
|
|
265
262
|
):
|
|
266
|
-
# if compiling a region set it as the HUGR entrypoint
|
|
267
|
-
#
|
|
263
|
+
# if compiling a region set it as the HUGR entrypoint can be
|
|
264
|
+
# loosened after https://github.com/quantinuum/hugr/issues/2501 is fixed
|
|
268
265
|
graph.hugr.entrypoint = compiled_def.hugr_node
|
|
269
266
|
|
|
270
267
|
# TODO: Currently the list of extensions is manually managed by the user.
|
|
@@ -278,7 +275,7 @@ class CompilationEngine:
|
|
|
278
275
|
guppylang_internals.compiler.hugr_extension.EXTENSION,
|
|
279
276
|
*self.additional_extensions,
|
|
280
277
|
]
|
|
281
|
-
# TODO replace with computed extensions after https://github.com/
|
|
278
|
+
# TODO replace with computed extensions after https://github.com/quantinuum/guppylang/issues/550
|
|
282
279
|
all_used_extensions = [
|
|
283
280
|
*extensions,
|
|
284
281
|
hugr.std.prelude.PRELUDE_EXTENSION,
|
guppylang_internals/nodes.py
CHANGED
|
@@ -9,7 +9,13 @@ from guppylang_internals.ast_util import AstNode
|
|
|
9
9
|
from guppylang_internals.span import Span, to_span
|
|
10
10
|
from guppylang_internals.tys.const import Const
|
|
11
11
|
from guppylang_internals.tys.subst import Inst
|
|
12
|
-
from guppylang_internals.tys.ty import
|
|
12
|
+
from guppylang_internals.tys.ty import (
|
|
13
|
+
FunctionType,
|
|
14
|
+
StructType,
|
|
15
|
+
TupleType,
|
|
16
|
+
Type,
|
|
17
|
+
UnitaryFlags,
|
|
18
|
+
)
|
|
13
19
|
|
|
14
20
|
if TYPE_CHECKING:
|
|
15
21
|
from guppylang_internals.cfg.cfg import CFG
|
|
@@ -166,6 +172,14 @@ class MakeIter(ast.expr):
|
|
|
166
172
|
self.origin_node = origin_node
|
|
167
173
|
self.unwrap_size_hint = unwrap_size_hint
|
|
168
174
|
|
|
175
|
+
# Needed for the deepcopy to work correctly, ast.AST's deepcopy logic
|
|
176
|
+
# reconstructs nodes using _fields only.
|
|
177
|
+
# If you store extra attributes or rely overwriting the __init__,
|
|
178
|
+
# deepcopy will crash with a constructor mismatch.
|
|
179
|
+
# Overriding __reduce__ forces deepcopy to copy the instance dictionary instead
|
|
180
|
+
__reduce_ex__ = object.__reduce_ex__
|
|
181
|
+
__reduce__ = object.__reduce__
|
|
182
|
+
|
|
169
183
|
|
|
170
184
|
class IterNext(ast.expr):
|
|
171
185
|
"""Obtains the next element of an iterator using the `__next__` magic method.
|
|
@@ -250,22 +264,6 @@ class ComptimeExpr(ast.expr):
|
|
|
250
264
|
_fields = ("value",)
|
|
251
265
|
|
|
252
266
|
|
|
253
|
-
class ResultExpr(ast.expr):
|
|
254
|
-
"""A `result(tag, value)` expression."""
|
|
255
|
-
|
|
256
|
-
value: ast.expr
|
|
257
|
-
base_ty: Type
|
|
258
|
-
#: Array length in case this is an array result, otherwise `None`
|
|
259
|
-
array_len: Const | None
|
|
260
|
-
tag: str
|
|
261
|
-
|
|
262
|
-
_fields = ("value", "base_ty", "array_len", "tag")
|
|
263
|
-
|
|
264
|
-
@property
|
|
265
|
-
def args(self) -> list[ast.expr]:
|
|
266
|
-
return [self.value]
|
|
267
|
-
|
|
268
|
-
|
|
269
267
|
class ExitKind(Enum):
|
|
270
268
|
ExitShot = 0 # Exit the current shot
|
|
271
269
|
Panic = 1 # Panic the program ending all shots
|
|
@@ -275,8 +273,8 @@ class PanicExpr(ast.expr):
|
|
|
275
273
|
"""A `panic(msg, *args)` or `exit(msg, *args)` expression ."""
|
|
276
274
|
|
|
277
275
|
kind: ExitKind
|
|
278
|
-
signal:
|
|
279
|
-
msg:
|
|
276
|
+
signal: ast.expr
|
|
277
|
+
msg: ast.expr
|
|
280
278
|
values: list[ast.expr]
|
|
281
279
|
|
|
282
280
|
_fields = ("kind", "signal", "msg", "values")
|
|
@@ -293,17 +291,16 @@ class BarrierExpr(ast.expr):
|
|
|
293
291
|
class StateResultExpr(ast.expr):
    """A `state_result(tag, *args)` expression."""

    tag_value: Const
    tag_expr: ast.expr
    args: list[ast.expr]
    func_ty: FunctionType
    #: Array length in case this is an array result, otherwise `None`
    array_len: Const | None

    # `_fields` must name the attributes declared above. `ast` uses it for positional
    # constructor arguments and for `ast.iter_fields` / `ast.dump`. Recent Python
    # versions also deprecate constructor keywords that are not listed here.
    _fields = ("tag_value", "tag_expr", "args", "func_ty", "array_len")


#: Any of the call-like expression nodes.
AnyCall = LocalCall | GlobalCall | TensorCall | BarrierExpr | StateResultExpr
|
|
307
304
|
|
|
308
305
|
|
|
309
306
|
class InoutReturnSentinel(ast.expr):
|
|
@@ -360,6 +357,10 @@ class ArrayUnpack(ast.expr):
|
|
|
360
357
|
self.length = length
|
|
361
358
|
self.elt_type = elt_type
|
|
362
359
|
|
|
360
|
+
# See MakeIter for explanation
|
|
361
|
+
__reduce__ = object.__reduce__
|
|
362
|
+
__reduce_ex__ = object.__reduce_ex__
|
|
363
|
+
|
|
363
364
|
|
|
364
365
|
class IterableUnpack(ast.expr):
|
|
365
366
|
"""The LHS of an unpacking assignment of an iterable type."""
|
|
@@ -384,6 +385,10 @@ class IterableUnpack(ast.expr):
|
|
|
384
385
|
self.compr = compr
|
|
385
386
|
self.rhs_var = rhs_var
|
|
386
387
|
|
|
388
|
+
# See MakeIter for explanation
|
|
389
|
+
__reduce__ = object.__reduce__
|
|
390
|
+
__reduce_ex__ = object.__reduce_ex__
|
|
391
|
+
|
|
387
392
|
|
|
388
393
|
#: Any unpacking operation.
|
|
389
394
|
AnyUnpack = TupleUnpack | ArrayUnpack | IterableUnpack
|
|
@@ -431,6 +436,10 @@ class Dagger(ast.expr):
|
|
|
431
436
|
def __init__(self, node: ast.expr) -> None:
|
|
432
437
|
super().__init__(**node.__dict__)
|
|
433
438
|
|
|
439
|
+
# See MakeIter for explanation
|
|
440
|
+
__reduce__ = object.__reduce__
|
|
441
|
+
__reduce_ex__ = object.__reduce_ex__
|
|
442
|
+
|
|
434
443
|
|
|
435
444
|
class Control(ast.Call):
|
|
436
445
|
"""The control modifier"""
|
|
@@ -445,6 +454,10 @@ class Control(ast.Call):
|
|
|
445
454
|
self.ctrl = ctrl
|
|
446
455
|
self.qubit_num = None
|
|
447
456
|
|
|
457
|
+
# See MakeIter for explanation
|
|
458
|
+
__reduce__ = object.__reduce__
|
|
459
|
+
__reduce_ex__ = object.__reduce_ex__
|
|
460
|
+
|
|
448
461
|
|
|
449
462
|
class Power(ast.expr):
|
|
450
463
|
"""The power modifier"""
|
|
@@ -457,6 +470,10 @@ class Power(ast.expr):
|
|
|
457
470
|
super().__init__(**node.__dict__)
|
|
458
471
|
self.iter = iter
|
|
459
472
|
|
|
473
|
+
# See MakeIter for explanation
|
|
474
|
+
__reduce__ = object.__reduce__
|
|
475
|
+
__reduce_ex__ = object.__reduce_ex__
|
|
476
|
+
|
|
460
477
|
|
|
461
478
|
#: Union of the modifier node types: dagger, control and power.
Modifier = Dagger | Control | Power
|
|
462
479
|
|
|
@@ -500,6 +517,16 @@ class ModifiedBlock(ast.With):
|
|
|
500
517
|
else:
|
|
501
518
|
raise TypeError(f"Unknown modifier: {modifier}")
|
|
502
519
|
|
|
520
|
+
def flags(self) -> UnitaryFlags:
    """Returns the unitary flags implied by the modifiers on this block.

    Dagger, control and power each contribute their corresponding `UnitaryFlags`
    member; a block with none of them yields `UnitaryFlags.NoFlags`.
    """
    result = UnitaryFlags.NoFlags
    for present, flag in (
        (self.is_dagger(), UnitaryFlags.Dagger),
        (self.is_control(), UnitaryFlags.Control),
        (self.is_power(), UnitaryFlags.Power),
    ):
        if present:
            result |= flag
    return result
|
|
529
|
+
|
|
503
530
|
|
|
504
531
|
class CheckedModifiedBlock(ast.With):
|
|
505
532
|
def_id: "DefId"
|
|
@@ -534,6 +561,10 @@ class CheckedModifiedBlock(ast.With):
|
|
|
534
561
|
self.control = control
|
|
535
562
|
self.power = power
|
|
536
563
|
|
|
564
|
+
# See MakeIter for explanation
|
|
565
|
+
__reduce__ = object.__reduce__
|
|
566
|
+
__reduce_ex__ = object.__reduce_ex__
|
|
567
|
+
|
|
537
568
|
def __str__(self) -> str:
|
|
538
569
|
# generate a function name from the def_id
|
|
539
570
|
return f"__WithBlock__({self.def_id})"
|