guppylang-internals 0.26.0__py3-none-any.whl → 0.28.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47)
  1. guppylang_internals/__init__.py +1 -1
  2. guppylang_internals/ast_util.py +37 -18
  3. guppylang_internals/cfg/analysis.py +6 -6
  4. guppylang_internals/cfg/builder.py +44 -12
  5. guppylang_internals/cfg/cfg.py +1 -1
  6. guppylang_internals/checker/core.py +1 -1
  7. guppylang_internals/checker/errors/comptime_errors.py +0 -12
  8. guppylang_internals/checker/errors/linearity.py +6 -2
  9. guppylang_internals/checker/expr_checker.py +53 -28
  10. guppylang_internals/checker/func_checker.py +4 -3
  11. guppylang_internals/checker/stmt_checker.py +1 -1
  12. guppylang_internals/compiler/cfg_compiler.py +1 -1
  13. guppylang_internals/compiler/core.py +17 -4
  14. guppylang_internals/compiler/expr_compiler.py +36 -14
  15. guppylang_internals/compiler/modifier_compiler.py +5 -2
  16. guppylang_internals/decorator.py +5 -3
  17. guppylang_internals/definition/common.py +1 -0
  18. guppylang_internals/definition/custom.py +2 -2
  19. guppylang_internals/definition/declaration.py +3 -3
  20. guppylang_internals/definition/function.py +28 -8
  21. guppylang_internals/definition/metadata.py +87 -0
  22. guppylang_internals/definition/overloaded.py +11 -2
  23. guppylang_internals/definition/pytket_circuits.py +50 -67
  24. guppylang_internals/definition/value.py +1 -1
  25. guppylang_internals/definition/wasm.py +3 -3
  26. guppylang_internals/diagnostic.py +89 -16
  27. guppylang_internals/engine.py +84 -40
  28. guppylang_internals/error.py +1 -1
  29. guppylang_internals/nodes.py +301 -3
  30. guppylang_internals/span.py +7 -3
  31. guppylang_internals/std/_internal/checker.py +104 -2
  32. guppylang_internals/std/_internal/compiler/array.py +36 -1
  33. guppylang_internals/std/_internal/compiler/either.py +14 -2
  34. guppylang_internals/std/_internal/compiler/tket_bool.py +1 -6
  35. guppylang_internals/std/_internal/compiler/tket_exts.py +1 -1
  36. guppylang_internals/std/_internal/debug.py +5 -3
  37. guppylang_internals/tracing/builtins_mock.py +2 -2
  38. guppylang_internals/tracing/object.py +6 -2
  39. guppylang_internals/tys/parsing.py +4 -1
  40. guppylang_internals/tys/qubit.py +6 -4
  41. guppylang_internals/tys/subst.py +2 -2
  42. guppylang_internals/tys/ty.py +2 -2
  43. guppylang_internals/wasm_util.py +2 -3
  44. {guppylang_internals-0.26.0.dist-info → guppylang_internals-0.28.0.dist-info}/METADATA +5 -4
  45. {guppylang_internals-0.26.0.dist-info → guppylang_internals-0.28.0.dist-info}/RECORD +47 -46
  46. {guppylang_internals-0.26.0.dist-info → guppylang_internals-0.28.0.dist-info}/WHEEL +0 -0
  47. {guppylang_internals-0.26.0.dist-info → guppylang_internals-0.28.0.dist-info}/licenses/LICENCE +0 -0
@@ -144,9 +144,9 @@ class ExprCompiler(CompilerBase, AstVisitor[Wire]):
144
144
  """
145
145
  old = self.dfg
146
146
  # Check that the input names are unique
147
- assert len({inp.place.id for inp in inputs}) == len(
148
- inputs
149
- ), "Inputs are not unique"
147
+ assert len({inp.place.id for inp in inputs}) == len(inputs), (
148
+ "Inputs are not unique"
149
+ )
150
150
  self.dfg = DFContainer(builder, self.ctx, self.dfg.locals.copy())
151
151
  hugr_input = builder.input_node
152
152
  for input_node, wire in zip(inputs, hugr_input, strict=True):
@@ -325,14 +325,7 @@ class ExprCompiler(CompilerBase, AstVisitor[Wire]):
325
325
 
326
326
  def _pack_returns(self, returns: Sequence[Wire], return_ty: Type) -> Wire:
327
327
  """Groups function return values into a tuple"""
328
- if isinstance(return_ty, TupleType | NoneType) and not return_ty.preserve:
329
- types = type_to_row(return_ty)
330
- assert len(returns) == len(types)
331
- return self._pack_tuple(returns, types)
332
- assert (
333
- len(returns) == 1
334
- ), f"Expected a single return value. Got {returns}. return type {return_ty}"
335
- return returns[0]
328
+ return pack_returns(returns, return_ty, self.builder, self.ctx)
336
329
 
337
330
  def _update_inout_ports(
338
331
  self,
@@ -394,9 +387,9 @@ class ExprCompiler(CompilerBase, AstVisitor[Wire]):
394
387
  func, func_ty, remaining_args
395
388
  )
396
389
  rets.extend(outs)
397
- assert (
398
- remaining_args == []
399
- ), "Not all function arguments were consumed after a tensor call"
390
+ assert remaining_args == [], (
391
+ "Not all function arguments were consumed after a tensor call"
392
+ )
400
393
  return self._pack_returns(rets, node.tensor_ty.output)
401
394
 
402
395
  def _compile_tensor_with_leftovers(
@@ -760,6 +753,35 @@ def expr_to_row(expr: ast.expr) -> list[ast.expr]:
760
753
  return expr.elts if isinstance(expr, ast.Tuple) else [expr]
761
754
 
762
755
 
756
+ def pack_returns(
757
+ returns: Sequence[Wire],
758
+ return_ty: Type,
759
+ builder: DfBase[ops.DfParentOp],
760
+ ctx: CompilerContext,
761
+ ) -> Wire:
762
+ """Groups function return values into a tuple"""
763
+ if isinstance(return_ty, TupleType | NoneType) and not return_ty.preserve:
764
+ types = type_to_row(return_ty)
765
+ assert len(returns) == len(types)
766
+ hugr_tys = [t.to_hugr(ctx) for t in types]
767
+ return builder.add_op(ops.MakeTuple(hugr_tys), *returns)
768
+ assert len(returns) == 1, (
769
+ f"Expected a single return value. Got {returns}. return type {return_ty}"
770
+ )
771
+ return returns[0]
772
+
773
+
774
+ def unpack_wire(
775
+ wire: Wire, return_ty: Type, builder: DfBase[ops.DfParentOp], ctx: CompilerContext
776
+ ) -> list[Wire]:
777
+ """The inverse of `pack_returns`"""
778
+ if isinstance(return_ty, TupleType | NoneType) and not return_ty.preserve:
779
+ types = type_to_row(return_ty)
780
+ hugr_tys = [t.to_hugr(ctx) for t in types]
781
+ return list(builder.add_op(ops.UnpackTuple(hugr_tys), wire).outputs())
782
+ return [wire]
783
+
784
+
763
785
  def instantiation_needs_unpacking(func_ty: FunctionType, inst: Inst) -> bool:
764
786
  """Checks if instantiating a polymorphic makes it return a row."""
765
787
  if isinstance(func_ty.output, BoundTypeVar):
@@ -8,7 +8,7 @@ from guppylang_internals.checker.modifier_checker import non_copyable_front_othe
8
8
  from guppylang_internals.compiler.cfg_compiler import compile_cfg
9
9
  from guppylang_internals.compiler.core import CompilerContext, DFContainer
10
10
  from guppylang_internals.compiler.expr_compiler import ExprCompiler
11
- from guppylang_internals.definition.function import add_unitarity_metadata
11
+ from guppylang_internals.definition.metadata import add_metadata
12
12
  from guppylang_internals.nodes import CheckedModifiedBlock, PlaceNode
13
13
  from guppylang_internals.std._internal.compiler.array import (
14
14
  array_new,
@@ -57,7 +57,10 @@ def compile_modified_block(
57
57
  func_builder = dfg.builder.module_root_builder().define_function(
58
58
  str(modified_block), hugr_ty.input, hugr_ty.output
59
59
  )
60
- add_unitarity_metadata(func_builder, modified_block.ty.unitary_flags)
60
+ add_metadata(
61
+ func_builder,
62
+ additional_metadata={"unitary": modified_block.ty.unitary_flags.value},
63
+ )
61
64
 
62
65
  # compile body
63
66
  cfg = compile_cfg(modified_block.cfg, func_builder, func_builder.inputs(), ctx)
@@ -4,10 +4,10 @@ import inspect
4
4
  import pathlib
5
5
  from typing import TYPE_CHECKING, ParamSpec, TypeVar, overload
6
6
 
7
+ from guppylang.defs import GuppyDefinition, GuppyFunctionDefinition
7
8
  from hugr import ops
8
9
  from hugr import tys as ht
9
10
 
10
- from guppylang.defs import GuppyDefinition, GuppyFunctionDefinition
11
11
  from guppylang_internals.compiler.core import (
12
12
  CompilerContext,
13
13
  GlobalConstId,
@@ -26,7 +26,7 @@ from guppylang_internals.definition.ty import OpaqueTypeDef, TypeDef
26
26
  from guppylang_internals.definition.wasm import RawWasmFunctionDef
27
27
  from guppylang_internals.dummy_decorator import _dummy_custom_decorator, sphinx_running
28
28
  from guppylang_internals.engine import DEF_STORE
29
- from guppylang_internals.error import GuppyError
29
+ from guppylang_internals.error import GuppyError, pretty_errors
30
30
  from guppylang_internals.std._internal.checker import WasmCallChecker
31
31
  from guppylang_internals.std._internal.compiler.wasm import (
32
32
  WasmModuleCallCompiler,
@@ -193,7 +193,7 @@ def custom_type(
193
193
  params or [],
194
194
  not copyable,
195
195
  not droppable,
196
- mk_hugr_ty,
196
+ mk_hugr_ty, # type: ignore[arg-type]
197
197
  bound,
198
198
  )
199
199
  DEF_STORE.register_def(defn, get_calling_frame())
@@ -207,6 +207,7 @@ def custom_type(
207
207
  return dec
208
208
 
209
209
 
210
+ @pretty_errors
210
211
  def wasm_module(
211
212
  filename: str,
212
213
  ) -> Callable[[builtins.type[T]], GuppyDefinition]:
@@ -252,6 +253,7 @@ def ext_module_decorator(
252
253
  def fun(
253
254
  filename: str, module: str | None
254
255
  ) -> Callable[[builtins.type[T]], GuppyDefinition]:
256
+ @pretty_errors
255
257
  def dec(cls: builtins.type[T]) -> GuppyDefinition:
256
258
  # N.B. Only one module per file and vice-versa
257
259
  ext_module = type_def(
@@ -163,6 +163,7 @@ class MonomorphizableDef(Definition):
163
163
  module: DefinitionBuilder[OpVar],
164
164
  mono_args: "PartiallyMonomorphizedArgs",
165
165
  ctx: "CompilerContext",
166
+ parent_ty: "RawDef | None" = None,
166
167
  ) -> "MonomorphizedDef":
167
168
  """Adds a Hugr node for the (partially) monomorphized definition to the provided
168
169
  Hugr module.
@@ -134,7 +134,7 @@ class RawCustomFunctionDef(ParsableDef):
134
134
  """
135
135
  from guppylang_internals.definition.function import parse_py_func
136
136
 
137
- func_ast, docstring = parse_py_func(self.python_func, sources)
137
+ func_ast, _docstring = parse_py_func(self.python_func, sources)
138
138
  if not has_empty_body(func_ast):
139
139
  raise GuppyError(BodyNotEmptyError(func_ast.body[0], self.name))
140
140
  sig = self.signature or self._get_signature(func_ast, globals)
@@ -479,7 +479,7 @@ class BoolOpCompiler(CustomInoutCallCompiler):
479
479
  for res in result
480
480
  ]
481
481
  return CallReturnWires(
482
- regular_returns=converted_result,
482
+ regular_returns=converted_result, # type: ignore[arg-type]
483
483
  inout_returns=[],
484
484
  )
485
485
 
@@ -120,9 +120,9 @@ class CheckedFunctionDecl(RawFunctionDecl, CompilableDef, CallableDef):
120
120
  self, module: DefinitionBuilder[OpVar], ctx: CompilerContext
121
121
  ) -> "CompiledFunctionDecl":
122
122
  """Adds a Hugr `FuncDecl` node for this function to the Hugr."""
123
- assert isinstance(
124
- module, hf.Module
125
- ), "Functions can only be declared in modules"
123
+ assert isinstance(module, hf.Module), (
124
+ "Functions can only be declared in modules"
125
+ )
126
126
  module: hf.Module = module
127
127
 
128
128
  node = module.declare_function(self.name, self.ty.to_hugr_poly(ctx))
@@ -31,8 +31,10 @@ from guppylang_internals.definition.common import (
31
31
  MonomorphizableDef,
32
32
  MonomorphizedDef,
33
33
  ParsableDef,
34
+ RawDef,
34
35
  UnknownSourceError,
35
36
  )
37
+ from guppylang_internals.definition.metadata import GuppyMetadata, add_metadata
36
38
  from guppylang_internals.definition.value import (
37
39
  CallableDef,
38
40
  CallReturnWires,
@@ -72,13 +74,22 @@ class RawFunctionDef(ParsableDef):
72
74
 
73
75
  unitary_flags: UnitaryFlags = field(default=UnitaryFlags.NoFlags, kw_only=True)
74
76
 
77
+ metadata: GuppyMetadata | None = field(default=None, kw_only=True)
78
+
75
79
  def parse(self, globals: Globals, sources: SourceMap) -> "ParsedFunctionDef":
76
80
  """Parses and checks the user-provided signature of the function."""
77
81
  func_ast, docstring = parse_py_func(self.python_func, sources)
78
82
  ty = check_signature(
79
83
  func_ast, globals, self.id, unitary_flags=self.unitary_flags
80
84
  )
81
- return ParsedFunctionDef(self.id, self.name, func_ast, ty, docstring)
85
+ return ParsedFunctionDef(
86
+ self.id,
87
+ self.name,
88
+ func_ast,
89
+ ty,
90
+ docstring,
91
+ metadata=self.metadata,
92
+ )
82
93
 
83
94
 
84
95
  @dataclass(frozen=True)
@@ -103,6 +114,8 @@ class ParsedFunctionDef(CheckableDef, CallableDef):
103
114
 
104
115
  description: str = field(default="function", init=False)
105
116
 
117
+ metadata: GuppyMetadata | None = field(default=None, kw_only=True)
118
+
106
119
  def check(self, globals: Globals) -> "CheckedFunctionDef":
107
120
  """Type checks the body of the function."""
108
121
  # Add python variable scope to the globals
@@ -114,6 +127,7 @@ class ParsedFunctionDef(CheckableDef, CallableDef):
114
127
  self.ty,
115
128
  self.docstring,
116
129
  cfg,
130
+ metadata=self.metadata,
117
131
  )
118
132
 
119
133
  def check_call(
@@ -164,6 +178,7 @@ class CheckedFunctionDef(ParsedFunctionDef, MonomorphizableDef):
164
178
  module: DefinitionBuilder[OpVar],
165
179
  mono_args: "PartiallyMonomorphizedArgs",
166
180
  ctx: "CompilerContext",
181
+ parent_ty: "RawDef | None" = None,
167
182
  ) -> "CompiledFunctionDef":
168
183
  """Adds a Hugr `FuncDefn` node for the (partially) monomorphized function to the
169
184
  Hugr.
@@ -172,12 +187,21 @@ class CheckedFunctionDef(ParsedFunctionDef, MonomorphizableDef):
172
187
  access to the other compiled functions yet. The body is compiled later in
173
188
  `CompiledFunctionDef.compile_inner()`.
174
189
  """
190
+ if parent_ty is None:
191
+ hugr_func_name = self.name
192
+ else:
193
+ hugr_func_name = f"{parent_ty.name}.{self.name}"
194
+
175
195
  mono_ty = self.ty.instantiate_partial(mono_args)
176
196
  hugr_ty = mono_ty.to_hugr_poly(ctx)
177
197
  func_def = module.module_root_builder().define_function(
178
- self.name, hugr_ty.body.input, hugr_ty.body.output, hugr_ty.params
198
+ hugr_func_name, hugr_ty.body.input, hugr_ty.body.output, hugr_ty.params
199
+ )
200
+ add_metadata(
201
+ func_def,
202
+ self.metadata,
203
+ additional_metadata={"unitary": self.ty.unitary_flags.value},
179
204
  )
180
- add_unitarity_metadata(func_def, self.ty.unitary_flags)
181
205
  return CompiledFunctionDef(
182
206
  self.id,
183
207
  self.name,
@@ -187,6 +211,7 @@ class CheckedFunctionDef(ParsedFunctionDef, MonomorphizableDef):
187
211
  self.docstring,
188
212
  self.cfg,
189
213
  func_def,
214
+ metadata=self.metadata,
190
215
  )
191
216
 
192
217
 
@@ -305,8 +330,3 @@ def parse_source(source_lines: list[str], line_offset: int) -> tuple[str, ast.AS
305
330
  else:
306
331
  node = ast.parse(source).body[0]
307
332
  return source, node, line_offset
308
-
309
-
310
- def add_unitarity_metadata(func: hf.Function, flags: UnitaryFlags) -> None:
311
- """Stores unitarity annotations in the metadate of a Hugr function definition."""
312
- func.metadata["unitary"] = flags.value
@@ -0,0 +1,87 @@
1
+ """Metadata attached to objects within the Guppy compiler, both for internal use and to
2
+ attach to HUGR nodes for lower-level processing."""
3
+
4
+ from abc import ABC
5
+ from dataclasses import dataclass, field, fields
6
+ from typing import Any, ClassVar, Generic, TypeVar
7
+
8
+ from hugr.hugr.node_port import ToNode
9
+
10
+ from guppylang_internals.diagnostic import Fatal
11
+ from guppylang_internals.error import GuppyError
12
+
13
+ T = TypeVar("T")
14
+
15
+
16
+ @dataclass(init=True, kw_only=True)
17
+ class GuppyMetadataValue(ABC, Generic[T]):
18
+ """A template class for a metadata value within the scope of the Guppy compiler.
19
+ Implementations should provide the `key` in reverse-URL format."""
20
+
21
+ key: ClassVar[str]
22
+ value: T | None = None
23
+
24
+
25
+ class MetadataMaxQubits(GuppyMetadataValue[int]):
26
+ key = "tket.hint.max_qubits"
27
+
28
+
29
+ @dataclass(frozen=True, init=True, kw_only=True)
30
+ class GuppyMetadata:
31
+ """DTO for metadata within the scope of the guppy compiler for attachment to HUGR
32
+ nodes. See `add_metadata`."""
33
+
34
+ max_qubits: MetadataMaxQubits = field(default_factory=MetadataMaxQubits, init=False)
35
+
36
+ @classmethod
37
+ def reserved_keys(cls) -> set[str]:
38
+ return {f.type.key for f in fields(GuppyMetadata)} # type: ignore[union-attr]
39
+
40
+
41
+ @dataclass(frozen=True)
42
+ class MetadataAlreadySetError(Fatal):
43
+ title: ClassVar[str] = "Metadata key already set"
44
+ message: ClassVar[str] = "Received two values for the metadata key `{key}`"
45
+ key: str
46
+
47
+
48
+ @dataclass(frozen=True)
49
+ class ReservedMetadataKeysError(Fatal):
50
+ title: ClassVar[str] = "Metadata key is reserved"
51
+ message: ClassVar[str] = (
52
+ "The following metadata keys are reserved by Guppy but also provided in "
53
+ "additional metadata: `{keys}`"
54
+ )
55
+ keys: set[str]
56
+
57
+
58
+ def add_metadata(
59
+ node: ToNode,
60
+ metadata: GuppyMetadata | None = None,
61
+ *,
62
+ additional_metadata: dict[str, Any] | None = None,
63
+ ) -> None:
64
+ """Adds metadata to the given node using the keys defined through inheritors of
65
+ `GuppyMetadataValue` defined in the `GuppyMetadata` class.
66
+
67
+ Additional metadata is forwarded as is, although the given dictionary may not
68
+ contain any keys already reserved by fields in `GuppyMetadata`.
69
+ """
70
+ if metadata is not None:
71
+ for f in fields(GuppyMetadata):
72
+ data: GuppyMetadataValue[Any] = getattr(metadata, f.name)
73
+ if data.key in node.metadata:
74
+ raise GuppyError(MetadataAlreadySetError(None, data.key))
75
+ if data.value is not None:
76
+ node.metadata[data.key] = data.value
77
+
78
+ if additional_metadata is not None:
79
+ reserved_keys = GuppyMetadata.reserved_keys()
80
+ used_reserved_keys = reserved_keys.intersection(additional_metadata.keys())
81
+ if len(used_reserved_keys) > 0:
82
+ raise GuppyError(ReservedMetadataKeysError(None, keys=used_reserved_keys))
83
+
84
+ for key, value in additional_metadata.items():
85
+ if key in node.metadata:
86
+ raise GuppyError(MetadataAlreadySetError(None, key))
87
+ node.metadata[key] = value
@@ -1,4 +1,5 @@
1
1
  import ast
2
+ import copy
2
3
  from contextlib import suppress
3
4
  from dataclasses import dataclass, field
4
5
  from typing import ClassVar, NoReturn
@@ -86,7 +87,11 @@ class OverloadedFunctionDef(CompiledCallableDef, CallableDef):
86
87
  assert isinstance(defn, CallableDef)
87
88
  available_sigs.append(defn.ty)
88
89
  with suppress(GuppyError):
89
- return defn.check_call(args, ty, node, ctx)
90
+ # check_call may modify args and node,
91
+ # thus we deepcopy them before passing in the function
92
+ node_copy = copy.deepcopy(node)
93
+ args_copy = copy.deepcopy(args)
94
+ return defn.check_call(args_copy, ty, node_copy, ctx)
90
95
  return self._call_error(args, node, ctx, available_sigs, ty)
91
96
 
92
97
  def synthesize_call(
@@ -98,7 +103,11 @@ class OverloadedFunctionDef(CompiledCallableDef, CallableDef):
98
103
  assert isinstance(defn, CallableDef)
99
104
  available_sigs.append(defn.ty)
100
105
  with suppress(GuppyError):
101
- return defn.synthesize_call(args, node, ctx)
106
+ # synthesize_call may modify args and node,
107
+ # thus we deepcopy them before passing in the function
108
+ node_copy = copy.deepcopy(node)
109
+ args_copy = copy.deepcopy(args)
110
+ return defn.synthesize_call(args_copy, node_copy, ctx)
102
111
  return self._call_error(args, node, ctx, available_sigs)
103
112
 
104
113
  def _call_error(
@@ -3,18 +3,17 @@ from dataclasses import dataclass, field
3
3
  from typing import Any, cast
4
4
 
5
5
  import hugr.build.function as hf
6
+ from guppylang.defs import GuppyDefinition
6
7
  from hugr import Node, Wire, envelope, ops, val
7
8
  from hugr import tys as ht
8
9
  from hugr.build.dfg import DefinitionBuilder, OpVar
9
10
  from hugr.envelope import EnvelopeConfig
10
11
  from hugr.std.float import FLOAT_T
12
+ from pytket.circuit import Circuit
11
13
 
12
14
  from guppylang_internals.ast_util import AstNode, has_empty_body, with_loc
13
15
  from guppylang_internals.checker.core import Context, Globals
14
- from guppylang_internals.checker.errors.comptime_errors import (
15
- PytketSignatureMismatch,
16
- TketNotInstalled,
17
- )
16
+ from guppylang_internals.checker.errors.comptime_errors import PytketSignatureMismatch
18
17
  from guppylang_internals.checker.expr_checker import check_call, synthesize_call
19
18
  from guppylang_internals.checker.func_checker import (
20
19
  check_signature,
@@ -46,6 +45,7 @@ from guppylang_internals.std._internal.compiler.array import (
46
45
  array_new,
47
46
  array_unpack,
48
47
  )
48
+ from guppylang_internals.std._internal.compiler.quantum import from_halfturns_unchecked
49
49
  from guppylang_internals.std._internal.compiler.tket_bool import OpaqueBool, make_opaque
50
50
  from guppylang_internals.tys.builtin import array_type, bool_type, float_type
51
51
  from guppylang_internals.tys.subst import Inst, Subst
@@ -230,17 +230,20 @@ class ParsedPytketDef(CallableDef, CompilableDef):
230
230
  )
231
231
  lex_params = list(unpack_result)
232
232
  param_order = cast(
233
- list[str], hugr_func.metadata["TKET1.input_parameters"]
233
+ "list[str]", hugr_func.metadata["TKET1.input_parameters"]
234
234
  )
235
235
  lex_names = sorted(param_order)
236
236
  name_to_param = dict(zip(lex_names, lex_params, strict=True))
237
237
  angle_wires = [name_to_param[name] for name in param_order]
238
- # Need to convert all angles to floats.
238
+ # Need to convert all angles to rotations.
239
239
  for angle in angle_wires:
240
240
  [halfturns] = outer_func.add_op(
241
241
  ops.UnpackTuple([FLOAT_T]), angle
242
242
  )
243
- param_wires.append(halfturns)
243
+ rotation = outer_func.add_op(
244
+ from_halfturns_unchecked(), halfturns
245
+ )
246
+ param_wires.append(rotation)
244
247
 
245
248
  # Pass all arguments to call node.
246
249
  call_node = outer_func.call(
@@ -365,69 +368,49 @@ class CompiledPytketDef(ParsedPytketDef, CompiledCallableDef, CompiledHugrNodeDe
365
368
 
366
369
 
367
370
  def _signature_from_circuit(
368
- input_circuit: Any,
371
+ input_circuit: Circuit,
369
372
  defined_at: ToSpan | None,
370
373
  use_arrays: bool = False,
371
374
  ) -> FunctionType:
372
375
  """Helper function for inferring a function signature from a pytket circuit."""
373
376
  # May want to set proper unitary flags in the future.
374
- try:
375
- import pytket
376
-
377
- if isinstance(input_circuit, pytket.circuit.Circuit):
378
- try:
379
- import tket # type: ignore[import-untyped, import-not-found, unused-ignore] # noqa: F401
380
-
381
- from guppylang.defs import GuppyDefinition
382
- from guppylang.std.angles import angle
383
- from guppylang.std.quantum import qubit
384
-
385
- assert isinstance(qubit, GuppyDefinition)
386
- qubit_ty = cast(TypeDef, qubit.wrapped).check_instantiate([])
387
-
388
- angle_defn = ENGINE.get_checked(angle.id) # type: ignore[attr-defined]
389
- assert isinstance(angle_defn, TypeDef)
390
- angle_ty = angle_defn.check_instantiate([])
391
-
392
- if use_arrays:
393
- inputs = [
394
- FuncInput(array_type(qubit_ty, q_reg.size), InputFlags.Inout)
395
- for q_reg in input_circuit.q_registers
396
- ]
397
- if len(input_circuit.free_symbols()) != 0:
398
- inputs.append(
399
- FuncInput(
400
- array_type(angle_ty, len(input_circuit.free_symbols())),
401
- InputFlags.NoFlags,
402
- )
403
- )
404
- outputs = [
405
- array_type(bool_type(), c_reg.size)
406
- for c_reg in input_circuit.c_registers
407
- ]
408
- circuit_signature = FunctionType(
409
- inputs,
410
- row_to_type(outputs),
411
- )
412
- else:
413
- param_inputs = [
414
- FuncInput(angle_ty, InputFlags.NoFlags)
415
- for _ in range(len(input_circuit.free_symbols()))
416
- ]
417
- circuit_signature = FunctionType(
418
- [FuncInput(qubit_ty, InputFlags.Inout)] * input_circuit.n_qubits
419
- + param_inputs,
420
- row_to_type([bool_type()] * input_circuit.n_bits),
421
- )
422
- except ImportError:
423
- err = TketNotInstalled(defined_at)
424
- err.add_sub_diagnostic(TketNotInstalled.InstallInstruction(None))
425
- raise GuppyError(err) from None
426
- else:
427
- pass
428
- except ImportError:
429
- raise InternalGuppyError(
430
- "Pytket error should have been caught earlier"
431
- ) from None
377
+ from guppylang.std.angles import angle # Avoid circular imports
378
+ from guppylang.std.quantum import qubit
379
+
380
+ assert isinstance(qubit, GuppyDefinition)
381
+ qubit_ty = cast("TypeDef", qubit.wrapped).check_instantiate([])
382
+
383
+ angle_defn = ENGINE.get_checked(angle.id) # type: ignore[attr-defined]
384
+ assert isinstance(angle_defn, TypeDef)
385
+ angle_ty = angle_defn.check_instantiate([])
386
+
387
+ if use_arrays:
388
+ inputs = [
389
+ FuncInput(array_type(qubit_ty, q_reg.size), InputFlags.Inout)
390
+ for q_reg in input_circuit.q_registers
391
+ ]
392
+ if len(input_circuit.free_symbols()) != 0:
393
+ inputs.append(
394
+ FuncInput(
395
+ array_type(angle_ty, len(input_circuit.free_symbols())),
396
+ InputFlags.NoFlags,
397
+ )
398
+ )
399
+ outputs = [
400
+ array_type(bool_type(), c_reg.size) for c_reg in input_circuit.c_registers
401
+ ]
402
+ circuit_signature = FunctionType(
403
+ inputs,
404
+ row_to_type(outputs),
405
+ )
432
406
  else:
433
- return circuit_signature
407
+ param_inputs = [
408
+ FuncInput(angle_ty, InputFlags.NoFlags)
409
+ for _ in range(len(input_circuit.free_symbols()))
410
+ ]
411
+ circuit_signature = FunctionType(
412
+ [FuncInput(qubit_ty, InputFlags.Inout)] * input_circuit.n_qubits
413
+ + param_inputs,
414
+ row_to_type([bool_type()] * input_circuit.n_bits),
415
+ )
416
+ return circuit_signature
@@ -55,7 +55,7 @@ class CallableDef(ValueDef):
55
55
  raise RuntimeError("Guppy functions can only be called in a Guppy context")
56
56
 
57
57
 
58
- class CompiledCallableDef(CallableDef, CompiledValueDef):
58
+ class CompiledCallableDef(CallableDef, CompiledValueDef): # type: ignore[misc, unused-ignore]
59
59
  """Abstract base class a global module-level function."""
60
60
 
61
61
  ty: FunctionType
@@ -38,9 +38,9 @@ class RawWasmFunctionDef(RawCustomFunctionDef):
38
38
  def sanitise_type(self, loc: AstNode, fun_ty: FunctionType) -> None:
39
39
  # Place to highlight in error messages
40
40
  match fun_ty.inputs:
41
- case [FuncInput(ty=ty, flags=InputFlags.Inout), *args] if wasm_module_name(
42
- ty
43
- ) is not None:
41
+ case [FuncInput(ty=ty, flags=InputFlags.Inout), *args] if (
42
+ wasm_module_name(ty) is not None
43
+ ):
44
44
  for inp in args:
45
45
  if not self.is_type_wasmable(inp.ty):
46
46
  raise GuppyError(UnWasmableType(loc, inp.ty))