guppylang-internals 0.27.0__py3-none-any.whl → 0.28.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- guppylang_internals/__init__.py +1 -1
- guppylang_internals/ast_util.py +37 -18
- guppylang_internals/cfg/analysis.py +6 -6
- guppylang_internals/cfg/builder.py +41 -12
- guppylang_internals/cfg/cfg.py +1 -1
- guppylang_internals/checker/core.py +1 -1
- guppylang_internals/checker/errors/comptime_errors.py +0 -12
- guppylang_internals/checker/expr_checker.py +27 -17
- guppylang_internals/checker/func_checker.py +4 -3
- guppylang_internals/checker/stmt_checker.py +1 -1
- guppylang_internals/compiler/cfg_compiler.py +1 -1
- guppylang_internals/compiler/core.py +17 -4
- guppylang_internals/compiler/expr_compiler.py +9 -9
- guppylang_internals/decorator.py +2 -2
- guppylang_internals/definition/common.py +1 -0
- guppylang_internals/definition/custom.py +2 -2
- guppylang_internals/definition/declaration.py +3 -3
- guppylang_internals/definition/function.py +8 -1
- guppylang_internals/definition/metadata.py +1 -1
- guppylang_internals/definition/pytket_circuits.py +44 -65
- guppylang_internals/definition/value.py +1 -1
- guppylang_internals/definition/wasm.py +3 -3
- guppylang_internals/diagnostic.py +17 -1
- guppylang_internals/engine.py +83 -30
- guppylang_internals/error.py +1 -1
- guppylang_internals/nodes.py +269 -3
- guppylang_internals/span.py +7 -3
- guppylang_internals/std/_internal/checker.py +104 -2
- guppylang_internals/std/_internal/debug.py +5 -3
- guppylang_internals/tracing/builtins_mock.py +2 -2
- guppylang_internals/tracing/object.py +2 -2
- guppylang_internals/tys/parsing.py +4 -1
- guppylang_internals/tys/qubit.py +6 -4
- guppylang_internals/tys/subst.py +2 -2
- guppylang_internals/tys/ty.py +2 -2
- guppylang_internals/wasm_util.py +1 -2
- {guppylang_internals-0.27.0.dist-info → guppylang_internals-0.28.0.dist-info}/METADATA +5 -4
- {guppylang_internals-0.27.0.dist-info → guppylang_internals-0.28.0.dist-info}/RECORD +40 -40
- {guppylang_internals-0.27.0.dist-info → guppylang_internals-0.28.0.dist-info}/WHEEL +0 -0
- {guppylang_internals-0.27.0.dist-info → guppylang_internals-0.28.0.dist-info}/licenses/LICENCE +0 -0
|
@@ -31,6 +31,7 @@ from guppylang_internals.definition.common import (
|
|
|
31
31
|
MonomorphizableDef,
|
|
32
32
|
MonomorphizedDef,
|
|
33
33
|
ParsableDef,
|
|
34
|
+
RawDef,
|
|
34
35
|
UnknownSourceError,
|
|
35
36
|
)
|
|
36
37
|
from guppylang_internals.definition.metadata import GuppyMetadata, add_metadata
|
|
@@ -177,6 +178,7 @@ class CheckedFunctionDef(ParsedFunctionDef, MonomorphizableDef):
|
|
|
177
178
|
module: DefinitionBuilder[OpVar],
|
|
178
179
|
mono_args: "PartiallyMonomorphizedArgs",
|
|
179
180
|
ctx: "CompilerContext",
|
|
181
|
+
parent_ty: "RawDef | None" = None,
|
|
180
182
|
) -> "CompiledFunctionDef":
|
|
181
183
|
"""Adds a Hugr `FuncDefn` node for the (partially) monomorphized function to the
|
|
182
184
|
Hugr.
|
|
@@ -185,10 +187,15 @@ class CheckedFunctionDef(ParsedFunctionDef, MonomorphizableDef):
|
|
|
185
187
|
access to the other compiled functions yet. The body is compiled later in
|
|
186
188
|
`CompiledFunctionDef.compile_inner()`.
|
|
187
189
|
"""
|
|
190
|
+
if parent_ty is None:
|
|
191
|
+
hugr_func_name = self.name
|
|
192
|
+
else:
|
|
193
|
+
hugr_func_name = f"{parent_ty.name}.{self.name}"
|
|
194
|
+
|
|
188
195
|
mono_ty = self.ty.instantiate_partial(mono_args)
|
|
189
196
|
hugr_ty = mono_ty.to_hugr_poly(ctx)
|
|
190
197
|
func_def = module.module_root_builder().define_function(
|
|
191
|
-
|
|
198
|
+
hugr_func_name, hugr_ty.body.input, hugr_ty.body.output, hugr_ty.params
|
|
192
199
|
)
|
|
193
200
|
add_metadata(
|
|
194
201
|
func_def,
|
|
@@ -3,18 +3,17 @@ from dataclasses import dataclass, field
|
|
|
3
3
|
from typing import Any, cast
|
|
4
4
|
|
|
5
5
|
import hugr.build.function as hf
|
|
6
|
+
from guppylang.defs import GuppyDefinition
|
|
6
7
|
from hugr import Node, Wire, envelope, ops, val
|
|
7
8
|
from hugr import tys as ht
|
|
8
9
|
from hugr.build.dfg import DefinitionBuilder, OpVar
|
|
9
10
|
from hugr.envelope import EnvelopeConfig
|
|
10
11
|
from hugr.std.float import FLOAT_T
|
|
12
|
+
from pytket.circuit import Circuit
|
|
11
13
|
|
|
12
14
|
from guppylang_internals.ast_util import AstNode, has_empty_body, with_loc
|
|
13
15
|
from guppylang_internals.checker.core import Context, Globals
|
|
14
|
-
from guppylang_internals.checker.errors.comptime_errors import (
|
|
15
|
-
PytketSignatureMismatch,
|
|
16
|
-
TketNotInstalled,
|
|
17
|
-
)
|
|
16
|
+
from guppylang_internals.checker.errors.comptime_errors import PytketSignatureMismatch
|
|
18
17
|
from guppylang_internals.checker.expr_checker import check_call, synthesize_call
|
|
19
18
|
from guppylang_internals.checker.func_checker import (
|
|
20
19
|
check_signature,
|
|
@@ -231,7 +230,7 @@ class ParsedPytketDef(CallableDef, CompilableDef):
|
|
|
231
230
|
)
|
|
232
231
|
lex_params = list(unpack_result)
|
|
233
232
|
param_order = cast(
|
|
234
|
-
list[str], hugr_func.metadata["TKET1.input_parameters"]
|
|
233
|
+
"list[str]", hugr_func.metadata["TKET1.input_parameters"]
|
|
235
234
|
)
|
|
236
235
|
lex_names = sorted(param_order)
|
|
237
236
|
name_to_param = dict(zip(lex_names, lex_params, strict=True))
|
|
@@ -369,69 +368,49 @@ class CompiledPytketDef(ParsedPytketDef, CompiledCallableDef, CompiledHugrNodeDe
|
|
|
369
368
|
|
|
370
369
|
|
|
371
370
|
def _signature_from_circuit(
|
|
372
|
-
input_circuit:
|
|
371
|
+
input_circuit: Circuit,
|
|
373
372
|
defined_at: ToSpan | None,
|
|
374
373
|
use_arrays: bool = False,
|
|
375
374
|
) -> FunctionType:
|
|
376
375
|
"""Helper function for inferring a function signature from a pytket circuit."""
|
|
377
376
|
# May want to set proper unitary flags in the future.
|
|
378
|
-
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
|
|
382
|
-
|
|
383
|
-
|
|
384
|
-
|
|
385
|
-
|
|
386
|
-
|
|
387
|
-
|
|
388
|
-
|
|
389
|
-
|
|
390
|
-
|
|
391
|
-
|
|
392
|
-
|
|
393
|
-
|
|
394
|
-
|
|
395
|
-
|
|
396
|
-
|
|
397
|
-
|
|
398
|
-
|
|
399
|
-
|
|
400
|
-
|
|
401
|
-
|
|
402
|
-
|
|
403
|
-
|
|
404
|
-
|
|
405
|
-
|
|
406
|
-
|
|
407
|
-
)
|
|
408
|
-
outputs = [
|
|
409
|
-
array_type(bool_type(), c_reg.size)
|
|
410
|
-
for c_reg in input_circuit.c_registers
|
|
411
|
-
]
|
|
412
|
-
circuit_signature = FunctionType(
|
|
413
|
-
inputs,
|
|
414
|
-
row_to_type(outputs),
|
|
415
|
-
)
|
|
416
|
-
else:
|
|
417
|
-
param_inputs = [
|
|
418
|
-
FuncInput(angle_ty, InputFlags.NoFlags)
|
|
419
|
-
for _ in range(len(input_circuit.free_symbols()))
|
|
420
|
-
]
|
|
421
|
-
circuit_signature = FunctionType(
|
|
422
|
-
[FuncInput(qubit_ty, InputFlags.Inout)] * input_circuit.n_qubits
|
|
423
|
-
+ param_inputs,
|
|
424
|
-
row_to_type([bool_type()] * input_circuit.n_bits),
|
|
425
|
-
)
|
|
426
|
-
except ImportError:
|
|
427
|
-
err = TketNotInstalled(defined_at)
|
|
428
|
-
err.add_sub_diagnostic(TketNotInstalled.InstallInstruction(None))
|
|
429
|
-
raise GuppyError(err) from None
|
|
430
|
-
else:
|
|
431
|
-
pass
|
|
432
|
-
except ImportError:
|
|
433
|
-
raise InternalGuppyError(
|
|
434
|
-
"Pytket error should have been caught earlier"
|
|
435
|
-
) from None
|
|
377
|
+
from guppylang.std.angles import angle # Avoid circular imports
|
|
378
|
+
from guppylang.std.quantum import qubit
|
|
379
|
+
|
|
380
|
+
assert isinstance(qubit, GuppyDefinition)
|
|
381
|
+
qubit_ty = cast("TypeDef", qubit.wrapped).check_instantiate([])
|
|
382
|
+
|
|
383
|
+
angle_defn = ENGINE.get_checked(angle.id) # type: ignore[attr-defined]
|
|
384
|
+
assert isinstance(angle_defn, TypeDef)
|
|
385
|
+
angle_ty = angle_defn.check_instantiate([])
|
|
386
|
+
|
|
387
|
+
if use_arrays:
|
|
388
|
+
inputs = [
|
|
389
|
+
FuncInput(array_type(qubit_ty, q_reg.size), InputFlags.Inout)
|
|
390
|
+
for q_reg in input_circuit.q_registers
|
|
391
|
+
]
|
|
392
|
+
if len(input_circuit.free_symbols()) != 0:
|
|
393
|
+
inputs.append(
|
|
394
|
+
FuncInput(
|
|
395
|
+
array_type(angle_ty, len(input_circuit.free_symbols())),
|
|
396
|
+
InputFlags.NoFlags,
|
|
397
|
+
)
|
|
398
|
+
)
|
|
399
|
+
outputs = [
|
|
400
|
+
array_type(bool_type(), c_reg.size) for c_reg in input_circuit.c_registers
|
|
401
|
+
]
|
|
402
|
+
circuit_signature = FunctionType(
|
|
403
|
+
inputs,
|
|
404
|
+
row_to_type(outputs),
|
|
405
|
+
)
|
|
436
406
|
else:
|
|
437
|
-
|
|
407
|
+
param_inputs = [
|
|
408
|
+
FuncInput(angle_ty, InputFlags.NoFlags)
|
|
409
|
+
for _ in range(len(input_circuit.free_symbols()))
|
|
410
|
+
]
|
|
411
|
+
circuit_signature = FunctionType(
|
|
412
|
+
[FuncInput(qubit_ty, InputFlags.Inout)] * input_circuit.n_qubits
|
|
413
|
+
+ param_inputs,
|
|
414
|
+
row_to_type([bool_type()] * input_circuit.n_bits),
|
|
415
|
+
)
|
|
416
|
+
return circuit_signature
|
|
@@ -55,7 +55,7 @@ class CallableDef(ValueDef):
|
|
|
55
55
|
raise RuntimeError("Guppy functions can only be called in a Guppy context")
|
|
56
56
|
|
|
57
57
|
|
|
58
|
-
class CompiledCallableDef(CallableDef, CompiledValueDef):
|
|
58
|
+
class CompiledCallableDef(CallableDef, CompiledValueDef): # type: ignore[misc, unused-ignore]
|
|
59
59
|
"""Abstract base class a global module-level function."""
|
|
60
60
|
|
|
61
61
|
ty: FunctionType
|
|
@@ -38,9 +38,9 @@ class RawWasmFunctionDef(RawCustomFunctionDef):
|
|
|
38
38
|
def sanitise_type(self, loc: AstNode, fun_ty: FunctionType) -> None:
|
|
39
39
|
# Place to highlight in error messages
|
|
40
40
|
match fun_ty.inputs:
|
|
41
|
-
case [FuncInput(ty=ty, flags=InputFlags.Inout), *args] if wasm_module_name(
|
|
42
|
-
ty
|
|
43
|
-
) is not None:
|
|
41
|
+
case [FuncInput(ty=ty, flags=InputFlags.Inout), *args] if (
|
|
42
|
+
wasm_module_name(ty) is not None
|
|
43
|
+
):
|
|
44
44
|
for inp in args:
|
|
45
45
|
if not self.is_type_wasmable(inp.ty):
|
|
46
46
|
raise GuppyError(UnWasmableType(loc, inp.ty))
|
|
@@ -384,9 +384,25 @@ class DiagnosticsRenderer:
|
|
|
384
384
|
if is_primary or print_pad_line:
|
|
385
385
|
render_line("")
|
|
386
386
|
|
|
387
|
-
# Grab all lines we want to display
|
|
387
|
+
# Grab all lines we want to display
|
|
388
388
|
prefix_lines = min(prefix_lines, span.start.line - 1)
|
|
389
389
|
all_lines = self.source.span_lines(span, prefix_lines)
|
|
390
|
+
|
|
391
|
+
# Convert leading tab characters into four whitespaces each (see PEP 8)
|
|
392
|
+
for i, line in enumerate(all_lines):
|
|
393
|
+
line_no_tabs = line.lstrip("\t")
|
|
394
|
+
num_tabs = len(line) - len(line_no_tabs)
|
|
395
|
+
all_lines[i] = " " * (num_tabs * 4) + line_no_tabs
|
|
396
|
+
# Shift span locations, accounting for incorporated \t
|
|
397
|
+
new_start = span.start
|
|
398
|
+
new_end = span.end
|
|
399
|
+
if i == prefix_lines: # Line is the first line in the span
|
|
400
|
+
new_start = span.start.shift_right(num_tabs * 3)
|
|
401
|
+
if i == len(all_lines) - 1: # Line is the last line in the span
|
|
402
|
+
new_end = span.end.shift_right(num_tabs * 3)
|
|
403
|
+
span = Span(new_start or span.start, new_end)
|
|
404
|
+
|
|
405
|
+
# Remove excessive leading whitespace
|
|
390
406
|
leading_whitespace = min(len(line) - len(line.lstrip()) for line in all_lines)
|
|
391
407
|
if leading_whitespace > self.MAX_LEADING_WHITESPACE:
|
|
392
408
|
remove = leading_whitespace - self.OPTIMAL_LEADING_WHITESPACE
|
guppylang_internals/engine.py
CHANGED
|
@@ -3,14 +3,10 @@ from enum import Enum
|
|
|
3
3
|
from types import FrameType
|
|
4
4
|
from typing import TYPE_CHECKING
|
|
5
5
|
|
|
6
|
+
import hugr
|
|
6
7
|
import hugr.build.function as hf
|
|
7
|
-
import hugr.std.collections.array
|
|
8
|
-
import hugr.std.float
|
|
9
|
-
import hugr.std.int
|
|
10
|
-
import hugr.std.logic
|
|
11
|
-
import hugr.std.prelude
|
|
12
8
|
from hugr import ops
|
|
13
|
-
from hugr.ext import Extension
|
|
9
|
+
from hugr.ext import Extension, ExtensionRegistry
|
|
14
10
|
from hugr.package import ModulePointer, Package
|
|
15
11
|
|
|
16
12
|
import guppylang_internals
|
|
@@ -150,11 +146,52 @@ class CompilationEngine:
|
|
|
150
146
|
types_to_check_worklist: dict[DefId, ParsedDef]
|
|
151
147
|
to_check_worklist: dict[DefId, ParsedDef]
|
|
152
148
|
|
|
149
|
+
# Cached compilation infrastructure (lazy-initialized, program-independent)
|
|
150
|
+
_base_packaged_extensions: list[Extension] | None = None
|
|
151
|
+
_base_resolve_registry: ExtensionRegistry | None = None
|
|
152
|
+
|
|
153
153
|
def __init__(self) -> None:
|
|
154
154
|
"""Resets the compilation cache."""
|
|
155
155
|
self.reset()
|
|
156
156
|
self.additional_extensions = []
|
|
157
157
|
|
|
158
|
+
@staticmethod
|
|
159
|
+
def _get_base_packaged_extensions() -> list[Extension]:
|
|
160
|
+
"""Get the base list of packaged extensions (cached at class level)."""
|
|
161
|
+
if CompilationEngine._base_packaged_extensions is None:
|
|
162
|
+
from guppylang_internals.std._internal.compiler.tket_exts import (
|
|
163
|
+
TKET_EXTENSIONS,
|
|
164
|
+
)
|
|
165
|
+
|
|
166
|
+
CompilationEngine._base_packaged_extensions = [
|
|
167
|
+
*TKET_EXTENSIONS,
|
|
168
|
+
guppylang_internals.compiler.hugr_extension.EXTENSION, # type: ignore[attr-defined]
|
|
169
|
+
]
|
|
170
|
+
return CompilationEngine._base_packaged_extensions
|
|
171
|
+
|
|
172
|
+
@staticmethod
|
|
173
|
+
def _get_base_resolve_registry() -> ExtensionRegistry:
|
|
174
|
+
"""Get the base resolve registry with standard extensions.
|
|
175
|
+
|
|
176
|
+
Cached at class level.
|
|
177
|
+
"""
|
|
178
|
+
if CompilationEngine._base_resolve_registry is None:
|
|
179
|
+
base_extensions = CompilationEngine._get_base_packaged_extensions()
|
|
180
|
+
registry = ExtensionRegistry()
|
|
181
|
+
for ext in [
|
|
182
|
+
*base_extensions,
|
|
183
|
+
hugr.std.prelude.PRELUDE_EXTENSION,
|
|
184
|
+
hugr.std.collections.array.EXTENSION,
|
|
185
|
+
hugr.std.float.FLOAT_OPS_EXTENSION,
|
|
186
|
+
hugr.std.float.FLOAT_TYPES_EXTENSION,
|
|
187
|
+
hugr.std.int.INT_OPS_EXTENSION,
|
|
188
|
+
hugr.std.int.INT_TYPES_EXTENSION,
|
|
189
|
+
hugr.std.logic.EXTENSION,
|
|
190
|
+
]:
|
|
191
|
+
registry.register_updated(ext)
|
|
192
|
+
CompilationEngine._base_resolve_registry = registry
|
|
193
|
+
return CompilationEngine._base_resolve_registry
|
|
194
|
+
|
|
158
195
|
def reset(self) -> None:
|
|
159
196
|
"""Resets the compilation cache."""
|
|
160
197
|
self.parsed = {}
|
|
@@ -264,40 +301,56 @@ class CompilationEngine:
|
|
|
264
301
|
# loosened after https://github.com/quantinuum/hugr/issues/2501 is fixed
|
|
265
302
|
graph.hugr.entrypoint = compiled_def.hugr_node
|
|
266
303
|
|
|
267
|
-
#
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
# The hugr prelude and std_extensions are implicit.
|
|
271
|
-
from guppylang_internals.std._internal.compiler.tket_exts import TKET_EXTENSIONS
|
|
304
|
+
# Use cached base extensions and registry, only add additional extensions
|
|
305
|
+
base_extensions = self._get_base_packaged_extensions()
|
|
306
|
+
packaged_extensions = [*base_extensions, *self.additional_extensions]
|
|
272
307
|
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
308
|
+
# Build resolve registry: start with cached base, add any additional
|
|
309
|
+
if self.additional_extensions:
|
|
310
|
+
from copy import deepcopy
|
|
311
|
+
|
|
312
|
+
resolve_registry = deepcopy(self._get_base_resolve_registry())
|
|
313
|
+
for ext in self.additional_extensions:
|
|
314
|
+
resolve_registry.register_updated(ext)
|
|
315
|
+
else:
|
|
316
|
+
resolve_registry = self._get_base_resolve_registry()
|
|
317
|
+
|
|
318
|
+
# Compute used extensions dynamically from the HUGR.
|
|
319
|
+
used_extensions_result = graph.hugr.used_extensions(
|
|
320
|
+
resolve_from=resolve_registry
|
|
321
|
+
)
|
|
322
|
+
|
|
323
|
+
# Set metadata for used extensions
|
|
324
|
+
used_exts_meta = [
|
|
290
325
|
{
|
|
291
326
|
"name": ext.name,
|
|
292
327
|
"version": str(ext.version),
|
|
293
328
|
}
|
|
294
|
-
for ext in
|
|
329
|
+
for ext in used_extensions_result.used_extensions.extensions.values()
|
|
295
330
|
]
|
|
331
|
+
# Add unresolved extensions as well, but we only have the names
|
|
332
|
+
used_exts_meta.extend(
|
|
333
|
+
{
|
|
334
|
+
"name": ext,
|
|
335
|
+
}
|
|
336
|
+
for ext in used_extensions_result.unresolved_extensions
|
|
337
|
+
)
|
|
338
|
+
graph.hugr.module_root.metadata[CoreMetadataKeys.USED_EXTENSIONS.value] = (
|
|
339
|
+
used_exts_meta
|
|
340
|
+
)
|
|
296
341
|
graph.hugr.module_root.metadata[CoreMetadataKeys.GENERATOR.value] = {
|
|
297
342
|
"name": "guppylang",
|
|
298
343
|
"version": guppylang_internals.__version__,
|
|
299
344
|
}
|
|
300
|
-
|
|
345
|
+
# only package used extensions
|
|
346
|
+
packaged_extensions = [
|
|
347
|
+
ext
|
|
348
|
+
for ext in packaged_extensions
|
|
349
|
+
if ext.name in used_extensions_result.ids()
|
|
350
|
+
]
|
|
351
|
+
return ModulePointer(
|
|
352
|
+
Package(modules=[graph.hugr], extensions=packaged_extensions), 0
|
|
353
|
+
)
|
|
301
354
|
|
|
302
355
|
|
|
303
356
|
ENGINE: CompilationEngine = CompilationEngine()
|
guppylang_internals/error.py
CHANGED