guppylang-internals 0.25.0__py3-none-any.whl → 0.27.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49):
  1. guppylang_internals/__init__.py +1 -1
  2. guppylang_internals/cfg/builder.py +20 -2
  3. guppylang_internals/cfg/cfg.py +3 -0
  4. guppylang_internals/checker/cfg_checker.py +6 -0
  5. guppylang_internals/checker/core.py +1 -2
  6. guppylang_internals/checker/errors/linearity.py +6 -2
  7. guppylang_internals/checker/errors/wasm.py +7 -4
  8. guppylang_internals/checker/expr_checker.py +39 -19
  9. guppylang_internals/checker/func_checker.py +17 -13
  10. guppylang_internals/checker/linearity_checker.py +2 -10
  11. guppylang_internals/checker/modifier_checker.py +6 -2
  12. guppylang_internals/checker/unitary_checker.py +132 -0
  13. guppylang_internals/compiler/cfg_compiler.py +7 -6
  14. guppylang_internals/compiler/core.py +5 -5
  15. guppylang_internals/compiler/expr_compiler.py +72 -81
  16. guppylang_internals/compiler/modifier_compiler.py +5 -0
  17. guppylang_internals/decorator.py +88 -7
  18. guppylang_internals/definition/custom.py +4 -0
  19. guppylang_internals/definition/declaration.py +6 -2
  20. guppylang_internals/definition/function.py +26 -3
  21. guppylang_internals/definition/metadata.py +87 -0
  22. guppylang_internals/definition/overloaded.py +11 -2
  23. guppylang_internals/definition/pytket_circuits.py +7 -2
  24. guppylang_internals/definition/struct.py +6 -3
  25. guppylang_internals/definition/wasm.py +42 -10
  26. guppylang_internals/diagnostic.py +72 -15
  27. guppylang_internals/engine.py +10 -13
  28. guppylang_internals/nodes.py +55 -24
  29. guppylang_internals/std/_internal/checker.py +13 -108
  30. guppylang_internals/std/_internal/compiler/array.py +37 -2
  31. guppylang_internals/std/_internal/compiler/either.py +14 -2
  32. guppylang_internals/std/_internal/compiler/list.py +1 -1
  33. guppylang_internals/std/_internal/compiler/platform.py +153 -0
  34. guppylang_internals/std/_internal/compiler/prelude.py +12 -4
  35. guppylang_internals/std/_internal/compiler/tket_bool.py +1 -6
  36. guppylang_internals/std/_internal/compiler/tket_exts.py +4 -5
  37. guppylang_internals/std/_internal/debug.py +18 -9
  38. guppylang_internals/std/_internal/util.py +1 -1
  39. guppylang_internals/tracing/object.py +14 -0
  40. guppylang_internals/tys/errors.py +23 -1
  41. guppylang_internals/tys/parsing.py +3 -3
  42. guppylang_internals/tys/printing.py +2 -8
  43. guppylang_internals/tys/qubit.py +37 -2
  44. guppylang_internals/tys/ty.py +60 -64
  45. guppylang_internals/wasm_util.py +129 -0
  46. {guppylang_internals-0.25.0.dist-info → guppylang_internals-0.27.0.dist-info}/METADATA +5 -4
  47. {guppylang_internals-0.25.0.dist-info → guppylang_internals-0.27.0.dist-info}/RECORD +49 -45
  48. {guppylang_internals-0.25.0.dist-info → guppylang_internals-0.27.0.dist-info}/WHEEL +1 -1
  49. {guppylang_internals-0.25.0.dist-info → guppylang_internals-0.27.0.dist-info}/licenses/LICENCE +0 -0
@@ -0,0 +1,87 @@
1
+ """Metadata attached to objects within the Guppy compiler, both for internal use and to
2
+ attach to HUGR nodes for lower-level processing."""
3
+
4
+ from abc import ABC
5
+ from dataclasses import dataclass, field, fields
6
+ from typing import Any, ClassVar, Generic, TypeVar
7
+
8
+ from hugr.hugr.node_port import ToNode
9
+
10
+ from guppylang_internals.diagnostic import Fatal
11
+ from guppylang_internals.error import GuppyError
12
+
13
+ T = TypeVar("T")
14
+
15
+
16
+ @dataclass(init=True, kw_only=True)
17
+ class GuppyMetadataValue(ABC, Generic[T]):
18
+ """A template class for a metadata value within the scope of the Guppy compiler.
19
+ Implementations should provide the `key` in reverse-URL format."""
20
+
21
+ key: ClassVar[str]
22
+ value: T | None = None
23
+
24
+
25
+ class MetadataMaxQubits(GuppyMetadataValue[int]):
26
+ key = "tket.hint.max_qubits"
27
+
28
+
29
+ @dataclass(frozen=True, init=True, kw_only=True)
30
+ class GuppyMetadata:
31
+ """DTO for metadata within the scope of the guppy compiler for attachment to HUGR
32
+ nodes. See `add_metadata`."""
33
+
34
+ max_qubits: MetadataMaxQubits = field(default_factory=MetadataMaxQubits, init=False)
35
+
36
+ @classmethod
37
+ def reserved_keys(cls) -> set[str]:
38
+ return {f.type.key for f in fields(GuppyMetadata)}
39
+
40
+
41
+ @dataclass(frozen=True)
42
+ class MetadataAlreadySetError(Fatal):
43
+ title: ClassVar[str] = "Metadata key already set"
44
+ message: ClassVar[str] = "Received two values for the metadata key `{key}`"
45
+ key: str
46
+
47
+
48
+ @dataclass(frozen=True)
49
+ class ReservedMetadataKeysError(Fatal):
50
+ title: ClassVar[str] = "Metadata key is reserved"
51
+ message: ClassVar[str] = (
52
+ "The following metadata keys are reserved by Guppy but also provided in "
53
+ "additional metadata: `{keys}`"
54
+ )
55
+ keys: set[str]
56
+
57
+
58
+ def add_metadata(
59
+ node: ToNode,
60
+ metadata: GuppyMetadata | None = None,
61
+ *,
62
+ additional_metadata: dict[str, Any] | None = None,
63
+ ) -> None:
64
+ """Adds metadata to the given node using the keys defined through inheritors of
65
+ `GuppyMetadataValue` defined in the `GuppyMetadata` class.
66
+
67
+ Additional metadata is forwarded as is, although the given dictionary may not
68
+ contain any keys already reserved by fields in `GuppyMetadata`.
69
+ """
70
+ if metadata is not None:
71
+ for f in fields(GuppyMetadata):
72
+ data: GuppyMetadataValue[Any] = getattr(metadata, f.name)
73
+ if data.key in node.metadata:
74
+ raise GuppyError(MetadataAlreadySetError(None, data.key))
75
+ if data.value is not None:
76
+ node.metadata[data.key] = data.value
77
+
78
+ if additional_metadata is not None:
79
+ reserved_keys = GuppyMetadata.reserved_keys()
80
+ used_reserved_keys = reserved_keys.intersection(additional_metadata.keys())
81
+ if len(used_reserved_keys) > 0:
82
+ raise GuppyError(ReservedMetadataKeysError(None, keys=used_reserved_keys))
83
+
84
+ for key, value in additional_metadata.items():
85
+ if key in node.metadata:
86
+ raise GuppyError(MetadataAlreadySetError(None, key))
87
+ node.metadata[key] = value
@@ -1,4 +1,5 @@
1
1
  import ast
2
+ import copy
2
3
  from contextlib import suppress
3
4
  from dataclasses import dataclass, field
4
5
  from typing import ClassVar, NoReturn
@@ -86,7 +87,11 @@ class OverloadedFunctionDef(CompiledCallableDef, CallableDef):
86
87
  assert isinstance(defn, CallableDef)
87
88
  available_sigs.append(defn.ty)
88
89
  with suppress(GuppyError):
89
- return defn.check_call(args, ty, node, ctx)
90
+ # check_call may modify args and node,
91
+ # thus we deepcopy them before passing in the function
92
+ node_copy = copy.deepcopy(node)
93
+ args_copy = copy.deepcopy(args)
94
+ return defn.check_call(args_copy, ty, node_copy, ctx)
90
95
  return self._call_error(args, node, ctx, available_sigs, ty)
91
96
 
92
97
  def synthesize_call(
@@ -98,7 +103,11 @@ class OverloadedFunctionDef(CompiledCallableDef, CallableDef):
98
103
  assert isinstance(defn, CallableDef)
99
104
  available_sigs.append(defn.ty)
100
105
  with suppress(GuppyError):
101
- return defn.synthesize_call(args, node, ctx)
106
+ # synthesize_call may modify args and node,
107
+ # thus we deepcopy them before passing in the function
108
+ node_copy = copy.deepcopy(node)
109
+ args_copy = copy.deepcopy(args)
110
+ return defn.synthesize_call(args_copy, node_copy, ctx)
102
111
  return self._call_error(args, node, ctx, available_sigs)
103
112
 
104
113
  def _call_error(
@@ -46,6 +46,7 @@ from guppylang_internals.std._internal.compiler.array import (
46
46
  array_new,
47
47
  array_unpack,
48
48
  )
49
+ from guppylang_internals.std._internal.compiler.quantum import from_halfturns_unchecked
49
50
  from guppylang_internals.std._internal.compiler.tket_bool import OpaqueBool, make_opaque
50
51
  from guppylang_internals.tys.builtin import array_type, bool_type, float_type
51
52
  from guppylang_internals.tys.subst import Inst, Subst
@@ -235,12 +236,15 @@ class ParsedPytketDef(CallableDef, CompilableDef):
235
236
  lex_names = sorted(param_order)
236
237
  name_to_param = dict(zip(lex_names, lex_params, strict=True))
237
238
  angle_wires = [name_to_param[name] for name in param_order]
238
- # Need to convert all angles to floats.
239
+ # Need to convert all angles to rotations.
239
240
  for angle in angle_wires:
240
241
  [halfturns] = outer_func.add_op(
241
242
  ops.UnpackTuple([FLOAT_T]), angle
242
243
  )
243
- param_wires.append(halfturns)
244
+ rotation = outer_func.add_op(
245
+ from_halfturns_unchecked(), halfturns
246
+ )
247
+ param_wires.append(rotation)
244
248
 
245
249
  # Pass all arguments to call node.
246
250
  call_node = outer_func.call(
@@ -370,6 +374,7 @@ def _signature_from_circuit(
370
374
  use_arrays: bool = False,
371
375
  ) -> FunctionType:
372
376
  """Helper function for inferring a function signature from a pytket circuit."""
377
+ # May want to set proper unitary flags in the future.
373
378
  try:
374
379
  import pytket
375
380
 
@@ -273,13 +273,16 @@ class CheckedStructDef(TypeDef, CompiledDef):
273
273
 
274
274
  constructor_sig = FunctionType(
275
275
  inputs=[
276
- FuncInput(f.ty, InputFlags.Owned if f.ty.linear else InputFlags.NoFlags)
276
+ FuncInput(
277
+ f.ty,
278
+ InputFlags.Owned if f.ty.linear else InputFlags.NoFlags,
279
+ f.name,
280
+ )
277
281
  for f in self.fields
278
282
  ],
279
283
  output=StructType(
280
284
  defn=self, args=[p.to_bound(i) for i, p in enumerate(self.params)]
281
285
  ),
282
- input_names=[f.name for f in self.fields],
283
286
  params=self.params,
284
287
  )
285
288
  constructor_def = CustomFunctionDef(
@@ -317,7 +320,7 @@ def parse_py_class(
317
320
  raise GuppyError(UnknownSourceError(None, cls))
318
321
 
319
322
  # We can't rely on `inspect.getsourcelines` since it doesn't work properly for
320
- # classes prior to Python 3.13. See https://github.com/CQCL/guppylang/issues/1107.
323
+ # classes prior to Python 3.13. See https://github.com/quantinuum/guppylang/issues/1107.
321
324
  # Instead, we reproduce the behaviour of Python >= 3.13 using the `__firstlineno__`
322
325
  # attribute. See https://github.com/python/cpython/blob/3.13/Lib/inspect.py#L1052.
323
326
  # In the decorator, we make sure that `__firstlineno__` is set, even if we're not
@@ -1,3 +1,4 @@
1
+ from dataclasses import dataclass, field
1
2
  from typing import TYPE_CHECKING
2
3
 
3
4
  from guppylang_internals.ast_util import AstNode
@@ -9,7 +10,8 @@ from guppylang_internals.definition.custom import (
9
10
  CustomFunctionDef,
10
11
  RawCustomFunctionDef,
11
12
  )
12
- from guppylang_internals.error import GuppyError
13
+ from guppylang_internals.engine import DEF_STORE
14
+ from guppylang_internals.error import GuppyError, GuppyTypeError
13
15
  from guppylang_internals.span import SourceMap
14
16
  from guppylang_internals.tys.builtin import wasm_module_name
15
17
  from guppylang_internals.tys.ty import (
@@ -21,24 +23,35 @@ from guppylang_internals.tys.ty import (
21
23
  TupleType,
22
24
  Type,
23
25
  )
26
+ from guppylang_internals.wasm_util import WasmSigMismatchError
24
27
 
25
28
  if TYPE_CHECKING:
26
29
  from guppylang_internals.checker.core import Globals
27
30
 
28
31
 
32
+ @dataclass(frozen=True)
29
33
  class RawWasmFunctionDef(RawCustomFunctionDef):
30
- def sanitise_type(self, loc: AstNode | None, fun_ty: FunctionType) -> None:
34
+ # If a function is specified in the @wasm decorator by its index in the wasm
35
+ # file, record what the index was.
36
+ wasm_index: int | None = field(default=None)
37
+
38
+ def sanitise_type(self, loc: AstNode, fun_ty: FunctionType) -> None:
31
39
  # Place to highlight in error messages
32
- match fun_ty.inputs[0]:
33
- case FuncInput(ty=ty, flags=InputFlags.Inout) if wasm_module_name(
40
+ match fun_ty.inputs:
41
+ case [FuncInput(ty=ty, flags=InputFlags.Inout), *args] if wasm_module_name(
34
42
  ty
35
43
  ) is not None:
36
- pass
37
- case FuncInput(ty=ty):
38
- raise GuppyError(FirstArgNotModule(loc, ty))
39
- for inp in fun_ty.inputs[1:]:
40
- if not self.is_type_wasmable(inp.ty):
41
- raise GuppyError(UnWasmableType(loc, inp.ty))
44
+ for inp in args:
45
+ if not self.is_type_wasmable(inp.ty):
46
+ raise GuppyError(UnWasmableType(loc, inp.ty))
47
+ case [FuncInput(ty=ty), *_]:
48
+ raise GuppyError(
49
+ FirstArgNotModule(loc).add_sub_diagnostic(
50
+ FirstArgNotModule.GotOtherType(loc, ty)
51
+ )
52
+ )
53
+ case []:
54
+ raise GuppyError(FirstArgNotModule(loc))
42
55
  if not self.is_type_wasmable(fun_ty.output):
43
56
  match fun_ty.output:
44
57
  case NoneType():
@@ -46,6 +59,23 @@ class RawWasmFunctionDef(RawCustomFunctionDef):
46
59
  case _:
47
60
  raise GuppyError(UnWasmableType(loc, fun_ty.output))
48
61
 
62
+ def validate_type(self, loc: AstNode, fun_ty: FunctionType) -> None:
63
+ type_in_wasm: FunctionType = DEF_STORE.wasm_functions[self.id]
64
+ assert type_in_wasm is not None
65
+ # Drop the first arg because it should be "self"
66
+ expected_type = FunctionType(fun_ty.inputs[1:], fun_ty.output)
67
+
68
+ if expected_type != type_in_wasm:
69
+ raise GuppyTypeError(
70
+ WasmSigMismatchError(loc)
71
+ .add_sub_diagnostic(
72
+ WasmSigMismatchError.Declaration(None, declared=str(expected_type))
73
+ )
74
+ .add_sub_diagnostic(
75
+ WasmSigMismatchError.Actual(None, actual=str(type_in_wasm))
76
+ )
77
+ )
78
+
49
79
  def is_type_wasmable(self, ty: Type) -> bool:
50
80
  match ty:
51
81
  case NumericType():
@@ -57,5 +87,7 @@ class RawWasmFunctionDef(RawCustomFunctionDef):
57
87
 
58
88
  def parse(self, globals: "Globals", sources: SourceMap) -> "CustomFunctionDef":
59
89
  parsed = super().parse(globals, sources)
90
+ assert parsed.defined_at is not None
60
91
  self.sanitise_type(parsed.defined_at, parsed.ty)
92
+ self.validate_type(parsed.defined_at, parsed.ty)
61
93
  return parsed
@@ -208,7 +208,8 @@ class DiagnosticsRenderer:
208
208
  MAX_MESSAGE_LINE_LEN: Final[int] = 80
209
209
 
210
210
  #: Number of preceding source lines we show to give additional context
211
- PREFIX_CONTEXT_LINES: Final[int] = 2
211
+ PREFIX_ERROR_CONTEXT_LINES: Final[int] = 2
212
+ PREFIX_NOTE_CONTEXT_LINES: Final[int] = 1
212
213
 
213
214
  def __init__(self, source: SourceMap) -> None:
214
215
  self.buffer = []
@@ -243,31 +244,84 @@ class DiagnosticsRenderer:
243
244
  else:
244
245
  span = to_span(diag.span)
245
246
  level = self.level_str(diag.level)
246
- all_spans = [span] + [
247
- to_span(child.span) for child in diag.children if child.span
247
+
248
+ children_with_span = [
249
+ (child, to_span(child.span)) for child in diag.children if child.span
248
250
  ]
251
+ all_spans = [span] + [span for _, span in children_with_span]
249
252
  max_lineno = max(s.end.line for s in all_spans)
253
+
250
254
  self.buffer.append(f"{level}: {diag.rendered_title} (at {span.start})")
255
+
256
+ # Render main error span first
251
257
  self.render_snippet(
252
258
  span,
253
259
  diag.rendered_span_label,
254
260
  max_lineno,
255
261
  is_primary=True,
256
- prefix_lines=self.PREFIX_CONTEXT_LINES,
262
+ prefix_lines=self.PREFIX_ERROR_CONTEXT_LINES,
257
263
  )
258
- # First render all sub-diagnostics that come with a span
259
- for sub_diag in diag.children:
260
- if sub_diag.span:
264
+
265
+ match children_with_span:
266
+ case []:
267
+ pass
268
+ case [(only_child, span)]:
269
+ self.buffer.append("\nNote:")
261
270
  self.render_snippet(
262
- to_span(sub_diag.span),
263
- sub_diag.rendered_span_label,
271
+ span,
272
+ only_child.rendered_span_label,
264
273
  max_lineno,
265
- is_primary=False,
274
+ prefix_lines=self.PREFIX_NOTE_CONTEXT_LINES,
275
+ print_pad_line=True,
266
276
  )
277
+ case [(first_child, first_span), *children_with_span]:
278
+ self.buffer.append("\nNotes:")
279
+ self.render_snippet(
280
+ first_span,
281
+ first_child.rendered_span_label,
282
+ max_lineno,
283
+ prefix_lines=self.PREFIX_NOTE_CONTEXT_LINES,
284
+ print_pad_line=True,
285
+ )
286
+
287
+ prev_span_end_lineno = first_span.end.line
288
+
289
+ for sub_diag, span in children_with_span:
290
+ span_start_lineno = span.start.line
291
+ span_end_lineno = span.end.line
292
+
293
+ # If notes are on the same line, render them together
294
+ if span_start_lineno == prev_span_end_lineno:
295
+ prefix_lines = 0
296
+ print_pad_line = True
297
+ # if notes are close enough, render them adjacently
298
+ elif (
299
+ span_start_lineno - self.PREFIX_NOTE_CONTEXT_LINES
300
+ <= prev_span_end_lineno + 1
301
+ ):
302
+ prefix_lines = span_start_lineno - prev_span_end_lineno - 1
303
+ print_pad_line = False
304
+ # otherwise we render a separator between notes
305
+ else:
306
+ self.buffer.append("")
307
+ prefix_lines = self.PREFIX_NOTE_CONTEXT_LINES
308
+ print_pad_line = False
309
+
310
+ self.render_snippet(
311
+ span,
312
+ sub_diag.rendered_span_label,
313
+ max_lineno,
314
+ prefix_lines=prefix_lines,
315
+ print_pad_line=print_pad_line,
316
+ )
317
+ prev_span_end_lineno = span_end_lineno
318
+
319
+ # Render the main diagnostic message if present
267
320
  if diag.rendered_message:
268
321
  self.buffer.append("")
269
322
  self.buffer += wrap(diag.rendered_message, self.MAX_MESSAGE_LINE_LEN)
270
- # Finally, render all sub-diagnostics that have a non-span message
323
+
324
+ # Render all sub-diagnostics that have a non-span message
271
325
  for sub_diag in diag.children:
272
326
  if sub_diag.rendered_message:
273
327
  self.buffer.append("")
@@ -281,8 +335,9 @@ class DiagnosticsRenderer:
281
335
  span: Span,
282
336
  label: str | None,
283
337
  max_lineno: int,
284
- is_primary: bool,
338
+ is_primary: bool = False,
285
339
  prefix_lines: int = 0,
340
+ print_pad_line: bool = False,
286
341
  ) -> None:
287
342
  """Renders the source associated with a span together with an optional label.
288
343
 
@@ -315,7 +370,8 @@ class DiagnosticsRenderer:
315
370
  Optionally includes up to `prefix_lines` preceding source lines to give
316
371
  additional context.
317
372
  """
318
- # Check how much space we need to reserve for the leading line numbers
373
+ # Check how much horizontal space we need to reserve for the leading
374
+ # line numbers
319
375
  ll_length = len(str(max_lineno))
320
376
  highlight_char = "^" if is_primary else "-"
321
377
 
@@ -324,8 +380,9 @@ class DiagnosticsRenderer:
324
380
  ll = "" if line_number is None else str(line_number)
325
381
  self.buffer.append(" " * (ll_length - len(ll)) + ll + " | " + line)
326
382
 
327
- # One line of padding
328
- render_line("")
383
+ # One line of padding (primary span, first note or between same line notes)
384
+ if is_primary or print_pad_line:
385
+ render_line("")
329
386
 
330
387
  # Grab all lines we want to display and remove excessive leading whitespace
331
388
  prefix_lines = min(prefix_lines, span.start.line - 1)
@@ -46,6 +46,7 @@ from guppylang_internals.tys.builtin import (
46
46
  string_type_def,
47
47
  tuple_type_def,
48
48
  )
49
+ from guppylang_internals.tys.ty import FunctionType
49
50
 
50
51
  if TYPE_CHECKING:
51
52
  from guppylang_internals.compiler.core import MonoDefId
@@ -87,6 +88,7 @@ class DefinitionStore:
87
88
  raw_defs: dict[DefId, RawDef]
88
89
  impls: defaultdict[DefId, dict[str, DefId]]
89
90
  impl_parents: dict[DefId, DefId]
91
+ wasm_functions: dict[DefId, FunctionType]
90
92
  frames: dict[DefId, FrameType]
91
93
  sources: SourceMap
92
94
 
@@ -96,6 +98,7 @@ class DefinitionStore:
96
98
  self.impl_parents = {}
97
99
  self.frames = {}
98
100
  self.sources = SourceMap()
101
+ self.wasm_functions = {}
99
102
 
100
103
  def register_def(self, defn: RawDef, frame: FrameType | None) -> None:
101
104
  self.raw_defs[defn.id] = defn
@@ -123,6 +126,9 @@ class DefinitionStore:
123
126
  assert frame is not None
124
127
  self.frames[impl_id] = frame
125
128
 
129
+ def register_wasm_function(self, fn_id: DefId, sig: FunctionType) -> None:
130
+ self.wasm_functions[fn_id] = sig
131
+
126
132
 
127
133
  DEF_STORE: DefinitionStore = DefinitionStore()
128
134
 
@@ -214,21 +220,12 @@ class CompilationEngine:
214
220
 
215
221
  This is the main driver behind `guppy.check()`.
216
222
  """
217
- from guppylang_internals.checker.core import Globals
218
-
219
223
  # Clear previous compilation cache.
220
224
  # TODO: In order to maintain results from the previous `check` call we would
221
225
  # need to store and check if any dependencies have changed.
222
226
  self.reset()
223
227
 
224
- defn = DEF_STORE.raw_defs[id]
225
- self.to_check_worklist = {
226
- defn.id: (
227
- defn.parse(Globals(DEF_STORE.frames[defn.id]), DEF_STORE.sources)
228
- if isinstance(defn, ParsableDef)
229
- else defn
230
- )
231
- }
228
+ self.to_check_worklist[id] = self.get_parsed(id)
232
229
  while self.types_to_check_worklist or self.to_check_worklist:
233
230
  # Types need to be checked first. This is because parsing e.g. a function
234
231
  # definition requires instantiating the types in its signature which can
@@ -263,8 +260,8 @@ class CompilationEngine:
263
260
  and isinstance(compiled_def, CompiledCallableDef)
264
261
  and not isinstance(graph.hugr[compiled_def.hugr_node].op, ops.FuncDecl)
265
262
  ):
266
- # if compiling a region set it as the HUGR entrypoint
267
- # can be loosened after https://github.com/CQCL/hugr/issues/2501 is fixed
263
+ # if compiling a region set it as the HUGR entrypoint can be
264
+ # loosened after https://github.com/quantinuum/hugr/issues/2501 is fixed
268
265
  graph.hugr.entrypoint = compiled_def.hugr_node
269
266
 
270
267
  # TODO: Currently the list of extensions is manually managed by the user.
@@ -278,7 +275,7 @@ class CompilationEngine:
278
275
  guppylang_internals.compiler.hugr_extension.EXTENSION,
279
276
  *self.additional_extensions,
280
277
  ]
281
- # TODO replace with computed extensions after https://github.com/CQCL/guppylang/issues/550
278
+ # TODO replace with computed extensions after https://github.com/quantinuum/guppylang/issues/550
282
279
  all_used_extensions = [
283
280
  *extensions,
284
281
  hugr.std.prelude.PRELUDE_EXTENSION,
@@ -9,7 +9,13 @@ from guppylang_internals.ast_util import AstNode
9
9
  from guppylang_internals.span import Span, to_span
10
10
  from guppylang_internals.tys.const import Const
11
11
  from guppylang_internals.tys.subst import Inst
12
- from guppylang_internals.tys.ty import FunctionType, StructType, TupleType, Type
12
+ from guppylang_internals.tys.ty import (
13
+ FunctionType,
14
+ StructType,
15
+ TupleType,
16
+ Type,
17
+ UnitaryFlags,
18
+ )
13
19
 
14
20
  if TYPE_CHECKING:
15
21
  from guppylang_internals.cfg.cfg import CFG
@@ -166,6 +172,14 @@ class MakeIter(ast.expr):
166
172
  self.origin_node = origin_node
167
173
  self.unwrap_size_hint = unwrap_size_hint
168
174
 
175
+ # Needed for the deepcopy to work correctly, ast.AST's deepcopy logic
176
+ # reconstructs nodes using _fields only.
177
+ # If you store extra attributes or rely overwriting the __init__,
178
+ # deepcopy will crash with a constructor mismatch.
179
+ # Overriding __reduce__ forces deepcopy to copy the instance dictionary instead
180
+ __reduce_ex__ = object.__reduce_ex__
181
+ __reduce__ = object.__reduce__
182
+
169
183
 
170
184
  class IterNext(ast.expr):
171
185
  """Obtains the next element of an iterator using the `__next__` magic method.
@@ -250,22 +264,6 @@ class ComptimeExpr(ast.expr):
250
264
  _fields = ("value",)
251
265
 
252
266
 
253
- class ResultExpr(ast.expr):
254
- """A `result(tag, value)` expression."""
255
-
256
- value: ast.expr
257
- base_ty: Type
258
- #: Array length in case this is an array result, otherwise `None`
259
- array_len: Const | None
260
- tag: str
261
-
262
- _fields = ("value", "base_ty", "array_len", "tag")
263
-
264
- @property
265
- def args(self) -> list[ast.expr]:
266
- return [self.value]
267
-
268
-
269
267
  class ExitKind(Enum):
270
268
  ExitShot = 0 # Exit the current shot
271
269
  Panic = 1 # Panic the program ending all shots
@@ -275,8 +273,8 @@ class PanicExpr(ast.expr):
275
273
  """A `panic(msg, *args)` or `exit(msg, *args)` expression ."""
276
274
 
277
275
  kind: ExitKind
278
- signal: int
279
- msg: str
276
+ signal: ast.expr
277
+ msg: ast.expr
280
278
  values: list[ast.expr]
281
279
 
282
280
  _fields = ("kind", "signal", "msg", "values")
@@ -293,17 +291,16 @@ class BarrierExpr(ast.expr):
293
291
  class StateResultExpr(ast.expr):
294
292
  """A `state_result(tag, *args)` expression."""
295
293
 
296
- tag: str
294
+ tag_value: Const
295
+ tag_expr: ast.expr
297
296
  args: list[ast.expr]
298
297
  func_ty: FunctionType
299
298
  #: Array length in case this is an array result, otherwise `None`
300
299
  array_len: Const | None
301
- _fields = ("tag", "args", "func_ty", "has_array_input")
300
+ _fields = ("tag_value", "tag_expr", "args", "func_ty", "has_array_input")
302
301
 
303
302
 
304
- AnyCall = (
305
- LocalCall | GlobalCall | TensorCall | BarrierExpr | ResultExpr | StateResultExpr
306
- )
303
+ AnyCall = LocalCall | GlobalCall | TensorCall | BarrierExpr | StateResultExpr
307
304
 
308
305
 
309
306
  class InoutReturnSentinel(ast.expr):
@@ -360,6 +357,10 @@ class ArrayUnpack(ast.expr):
360
357
  self.length = length
361
358
  self.elt_type = elt_type
362
359
 
360
+ # See MakeIter for explanation
361
+ __reduce__ = object.__reduce__
362
+ __reduce_ex__ = object.__reduce_ex__
363
+
363
364
 
364
365
  class IterableUnpack(ast.expr):
365
366
  """The LHS of an unpacking assignment of an iterable type."""
@@ -384,6 +385,10 @@ class IterableUnpack(ast.expr):
384
385
  self.compr = compr
385
386
  self.rhs_var = rhs_var
386
387
 
388
+ # See MakeIter for explanation
389
+ __reduce__ = object.__reduce__
390
+ __reduce_ex__ = object.__reduce_ex__
391
+
387
392
 
388
393
  #: Any unpacking operation.
389
394
  AnyUnpack = TupleUnpack | ArrayUnpack | IterableUnpack
@@ -431,6 +436,10 @@ class Dagger(ast.expr):
431
436
  def __init__(self, node: ast.expr) -> None:
432
437
  super().__init__(**node.__dict__)
433
438
 
439
+ # See MakeIter for explanation
440
+ __reduce__ = object.__reduce__
441
+ __reduce_ex__ = object.__reduce_ex__
442
+
434
443
 
435
444
  class Control(ast.Call):
436
445
  """The control modifier"""
@@ -445,6 +454,10 @@ class Control(ast.Call):
445
454
  self.ctrl = ctrl
446
455
  self.qubit_num = None
447
456
 
457
+ # See MakeIter for explanation
458
+ __reduce__ = object.__reduce__
459
+ __reduce_ex__ = object.__reduce_ex__
460
+
448
461
 
449
462
  class Power(ast.expr):
450
463
  """The power modifier"""
@@ -457,6 +470,10 @@ class Power(ast.expr):
457
470
  super().__init__(**node.__dict__)
458
471
  self.iter = iter
459
472
 
473
+ # See MakeIter for explanation
474
+ __reduce__ = object.__reduce__
475
+ __reduce_ex__ = object.__reduce_ex__
476
+
460
477
 
461
478
  Modifier = Dagger | Control | Power
462
479
 
@@ -500,6 +517,16 @@ class ModifiedBlock(ast.With):
500
517
  else:
501
518
  raise TypeError(f"Unknown modifier: {modifier}")
502
519
 
520
+ def flags(self) -> UnitaryFlags:
521
+ flags = UnitaryFlags.NoFlags
522
+ if self.is_dagger():
523
+ flags |= UnitaryFlags.Dagger
524
+ if self.is_control():
525
+ flags |= UnitaryFlags.Control
526
+ if self.is_power():
527
+ flags |= UnitaryFlags.Power
528
+ return flags
529
+
503
530
 
504
531
  class CheckedModifiedBlock(ast.With):
505
532
  def_id: "DefId"
@@ -534,6 +561,10 @@ class CheckedModifiedBlock(ast.With):
534
561
  self.control = control
535
562
  self.power = power
536
563
 
564
+ # See MakeIter for explanation
565
+ __reduce__ = object.__reduce__
566
+ __reduce_ex__ = object.__reduce_ex__
567
+
537
568
  def __str__(self) -> str:
538
569
  # generate a function name from the def_id
539
570
  return f"__WithBlock__({self.def_id})"