guppylang-internals 0.27.0__py3-none-any.whl → 0.28.0__py3-none-any.whl

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between these package versions as they appear in their respective public registries.
Files changed (40)
  1. guppylang_internals/__init__.py +1 -1
  2. guppylang_internals/ast_util.py +37 -18
  3. guppylang_internals/cfg/analysis.py +6 -6
  4. guppylang_internals/cfg/builder.py +41 -12
  5. guppylang_internals/cfg/cfg.py +1 -1
  6. guppylang_internals/checker/core.py +1 -1
  7. guppylang_internals/checker/errors/comptime_errors.py +0 -12
  8. guppylang_internals/checker/expr_checker.py +27 -17
  9. guppylang_internals/checker/func_checker.py +4 -3
  10. guppylang_internals/checker/stmt_checker.py +1 -1
  11. guppylang_internals/compiler/cfg_compiler.py +1 -1
  12. guppylang_internals/compiler/core.py +17 -4
  13. guppylang_internals/compiler/expr_compiler.py +9 -9
  14. guppylang_internals/decorator.py +2 -2
  15. guppylang_internals/definition/common.py +1 -0
  16. guppylang_internals/definition/custom.py +2 -2
  17. guppylang_internals/definition/declaration.py +3 -3
  18. guppylang_internals/definition/function.py +8 -1
  19. guppylang_internals/definition/metadata.py +1 -1
  20. guppylang_internals/definition/pytket_circuits.py +44 -65
  21. guppylang_internals/definition/value.py +1 -1
  22. guppylang_internals/definition/wasm.py +3 -3
  23. guppylang_internals/diagnostic.py +17 -1
  24. guppylang_internals/engine.py +83 -30
  25. guppylang_internals/error.py +1 -1
  26. guppylang_internals/nodes.py +269 -3
  27. guppylang_internals/span.py +7 -3
  28. guppylang_internals/std/_internal/checker.py +104 -2
  29. guppylang_internals/std/_internal/debug.py +5 -3
  30. guppylang_internals/tracing/builtins_mock.py +2 -2
  31. guppylang_internals/tracing/object.py +2 -2
  32. guppylang_internals/tys/parsing.py +4 -1
  33. guppylang_internals/tys/qubit.py +6 -4
  34. guppylang_internals/tys/subst.py +2 -2
  35. guppylang_internals/tys/ty.py +2 -2
  36. guppylang_internals/wasm_util.py +1 -2
  37. {guppylang_internals-0.27.0.dist-info → guppylang_internals-0.28.0.dist-info}/METADATA +5 -4
  38. {guppylang_internals-0.27.0.dist-info → guppylang_internals-0.28.0.dist-info}/RECORD +40 -40
  39. {guppylang_internals-0.27.0.dist-info → guppylang_internals-0.28.0.dist-info}/WHEEL +0 -0
  40. {guppylang_internals-0.27.0.dist-info → guppylang_internals-0.28.0.dist-info}/licenses/LICENCE +0 -0
@@ -31,6 +31,7 @@ from guppylang_internals.definition.common import (
31
31
  MonomorphizableDef,
32
32
  MonomorphizedDef,
33
33
  ParsableDef,
34
+ RawDef,
34
35
  UnknownSourceError,
35
36
  )
36
37
  from guppylang_internals.definition.metadata import GuppyMetadata, add_metadata
@@ -177,6 +178,7 @@ class CheckedFunctionDef(ParsedFunctionDef, MonomorphizableDef):
177
178
  module: DefinitionBuilder[OpVar],
178
179
  mono_args: "PartiallyMonomorphizedArgs",
179
180
  ctx: "CompilerContext",
181
+ parent_ty: "RawDef | None" = None,
180
182
  ) -> "CompiledFunctionDef":
181
183
  """Adds a Hugr `FuncDefn` node for the (partially) monomorphized function to the
182
184
  Hugr.
@@ -185,10 +187,15 @@ class CheckedFunctionDef(ParsedFunctionDef, MonomorphizableDef):
185
187
  access to the other compiled functions yet. The body is compiled later in
186
188
  `CompiledFunctionDef.compile_inner()`.
187
189
  """
190
+ if parent_ty is None:
191
+ hugr_func_name = self.name
192
+ else:
193
+ hugr_func_name = f"{parent_ty.name}.{self.name}"
194
+
188
195
  mono_ty = self.ty.instantiate_partial(mono_args)
189
196
  hugr_ty = mono_ty.to_hugr_poly(ctx)
190
197
  func_def = module.module_root_builder().define_function(
191
- self.name, hugr_ty.body.input, hugr_ty.body.output, hugr_ty.params
198
+ hugr_func_name, hugr_ty.body.input, hugr_ty.body.output, hugr_ty.params
192
199
  )
193
200
  add_metadata(
194
201
  func_def,
@@ -35,7 +35,7 @@ class GuppyMetadata:
35
35
 
36
36
  @classmethod
37
37
  def reserved_keys(cls) -> set[str]:
38
- return {f.type.key for f in fields(GuppyMetadata)}
38
+ return {f.type.key for f in fields(GuppyMetadata)} # type: ignore[union-attr]
39
39
 
40
40
 
41
41
  @dataclass(frozen=True)
@@ -3,18 +3,17 @@ from dataclasses import dataclass, field
3
3
  from typing import Any, cast
4
4
 
5
5
  import hugr.build.function as hf
6
+ from guppylang.defs import GuppyDefinition
6
7
  from hugr import Node, Wire, envelope, ops, val
7
8
  from hugr import tys as ht
8
9
  from hugr.build.dfg import DefinitionBuilder, OpVar
9
10
  from hugr.envelope import EnvelopeConfig
10
11
  from hugr.std.float import FLOAT_T
12
+ from pytket.circuit import Circuit
11
13
 
12
14
  from guppylang_internals.ast_util import AstNode, has_empty_body, with_loc
13
15
  from guppylang_internals.checker.core import Context, Globals
14
- from guppylang_internals.checker.errors.comptime_errors import (
15
- PytketSignatureMismatch,
16
- TketNotInstalled,
17
- )
16
+ from guppylang_internals.checker.errors.comptime_errors import PytketSignatureMismatch
18
17
  from guppylang_internals.checker.expr_checker import check_call, synthesize_call
19
18
  from guppylang_internals.checker.func_checker import (
20
19
  check_signature,
@@ -231,7 +230,7 @@ class ParsedPytketDef(CallableDef, CompilableDef):
231
230
  )
232
231
  lex_params = list(unpack_result)
233
232
  param_order = cast(
234
- list[str], hugr_func.metadata["TKET1.input_parameters"]
233
+ "list[str]", hugr_func.metadata["TKET1.input_parameters"]
235
234
  )
236
235
  lex_names = sorted(param_order)
237
236
  name_to_param = dict(zip(lex_names, lex_params, strict=True))
@@ -369,69 +368,49 @@ class CompiledPytketDef(ParsedPytketDef, CompiledCallableDef, CompiledHugrNodeDe
369
368
 
370
369
 
371
370
  def _signature_from_circuit(
372
- input_circuit: Any,
371
+ input_circuit: Circuit,
373
372
  defined_at: ToSpan | None,
374
373
  use_arrays: bool = False,
375
374
  ) -> FunctionType:
376
375
  """Helper function for inferring a function signature from a pytket circuit."""
377
376
  # May want to set proper unitary flags in the future.
378
- try:
379
- import pytket
380
-
381
- if isinstance(input_circuit, pytket.circuit.Circuit):
382
- try:
383
- import tket # type: ignore[import-untyped, import-not-found, unused-ignore] # noqa: F401
384
-
385
- from guppylang.defs import GuppyDefinition
386
- from guppylang.std.angles import angle
387
- from guppylang.std.quantum import qubit
388
-
389
- assert isinstance(qubit, GuppyDefinition)
390
- qubit_ty = cast(TypeDef, qubit.wrapped).check_instantiate([])
391
-
392
- angle_defn = ENGINE.get_checked(angle.id) # type: ignore[attr-defined]
393
- assert isinstance(angle_defn, TypeDef)
394
- angle_ty = angle_defn.check_instantiate([])
395
-
396
- if use_arrays:
397
- inputs = [
398
- FuncInput(array_type(qubit_ty, q_reg.size), InputFlags.Inout)
399
- for q_reg in input_circuit.q_registers
400
- ]
401
- if len(input_circuit.free_symbols()) != 0:
402
- inputs.append(
403
- FuncInput(
404
- array_type(angle_ty, len(input_circuit.free_symbols())),
405
- InputFlags.NoFlags,
406
- )
407
- )
408
- outputs = [
409
- array_type(bool_type(), c_reg.size)
410
- for c_reg in input_circuit.c_registers
411
- ]
412
- circuit_signature = FunctionType(
413
- inputs,
414
- row_to_type(outputs),
415
- )
416
- else:
417
- param_inputs = [
418
- FuncInput(angle_ty, InputFlags.NoFlags)
419
- for _ in range(len(input_circuit.free_symbols()))
420
- ]
421
- circuit_signature = FunctionType(
422
- [FuncInput(qubit_ty, InputFlags.Inout)] * input_circuit.n_qubits
423
- + param_inputs,
424
- row_to_type([bool_type()] * input_circuit.n_bits),
425
- )
426
- except ImportError:
427
- err = TketNotInstalled(defined_at)
428
- err.add_sub_diagnostic(TketNotInstalled.InstallInstruction(None))
429
- raise GuppyError(err) from None
430
- else:
431
- pass
432
- except ImportError:
433
- raise InternalGuppyError(
434
- "Pytket error should have been caught earlier"
435
- ) from None
377
+ from guppylang.std.angles import angle # Avoid circular imports
378
+ from guppylang.std.quantum import qubit
379
+
380
+ assert isinstance(qubit, GuppyDefinition)
381
+ qubit_ty = cast("TypeDef", qubit.wrapped).check_instantiate([])
382
+
383
+ angle_defn = ENGINE.get_checked(angle.id) # type: ignore[attr-defined]
384
+ assert isinstance(angle_defn, TypeDef)
385
+ angle_ty = angle_defn.check_instantiate([])
386
+
387
+ if use_arrays:
388
+ inputs = [
389
+ FuncInput(array_type(qubit_ty, q_reg.size), InputFlags.Inout)
390
+ for q_reg in input_circuit.q_registers
391
+ ]
392
+ if len(input_circuit.free_symbols()) != 0:
393
+ inputs.append(
394
+ FuncInput(
395
+ array_type(angle_ty, len(input_circuit.free_symbols())),
396
+ InputFlags.NoFlags,
397
+ )
398
+ )
399
+ outputs = [
400
+ array_type(bool_type(), c_reg.size) for c_reg in input_circuit.c_registers
401
+ ]
402
+ circuit_signature = FunctionType(
403
+ inputs,
404
+ row_to_type(outputs),
405
+ )
436
406
  else:
437
- return circuit_signature
407
+ param_inputs = [
408
+ FuncInput(angle_ty, InputFlags.NoFlags)
409
+ for _ in range(len(input_circuit.free_symbols()))
410
+ ]
411
+ circuit_signature = FunctionType(
412
+ [FuncInput(qubit_ty, InputFlags.Inout)] * input_circuit.n_qubits
413
+ + param_inputs,
414
+ row_to_type([bool_type()] * input_circuit.n_bits),
415
+ )
416
+ return circuit_signature
@@ -55,7 +55,7 @@ class CallableDef(ValueDef):
55
55
  raise RuntimeError("Guppy functions can only be called in a Guppy context")
56
56
 
57
57
 
58
- class CompiledCallableDef(CallableDef, CompiledValueDef):
58
+ class CompiledCallableDef(CallableDef, CompiledValueDef): # type: ignore[misc, unused-ignore]
59
59
  """Abstract base class a global module-level function."""
60
60
 
61
61
  ty: FunctionType
@@ -38,9 +38,9 @@ class RawWasmFunctionDef(RawCustomFunctionDef):
38
38
  def sanitise_type(self, loc: AstNode, fun_ty: FunctionType) -> None:
39
39
  # Place to highlight in error messages
40
40
  match fun_ty.inputs:
41
- case [FuncInput(ty=ty, flags=InputFlags.Inout), *args] if wasm_module_name(
42
- ty
43
- ) is not None:
41
+ case [FuncInput(ty=ty, flags=InputFlags.Inout), *args] if (
42
+ wasm_module_name(ty) is not None
43
+ ):
44
44
  for inp in args:
45
45
  if not self.is_type_wasmable(inp.ty):
46
46
  raise GuppyError(UnWasmableType(loc, inp.ty))
@@ -384,9 +384,25 @@ class DiagnosticsRenderer:
384
384
  if is_primary or print_pad_line:
385
385
  render_line("")
386
386
 
387
- # Grab all lines we want to display and remove excessive leading whitespace
387
+ # Grab all lines we want to display
388
388
  prefix_lines = min(prefix_lines, span.start.line - 1)
389
389
  all_lines = self.source.span_lines(span, prefix_lines)
390
+
391
+ # Convert leading tab characters into four whitespaces each (see PEP 8)
392
+ for i, line in enumerate(all_lines):
393
+ line_no_tabs = line.lstrip("\t")
394
+ num_tabs = len(line) - len(line_no_tabs)
395
+ all_lines[i] = " " * (num_tabs * 4) + line_no_tabs
396
+ # Shift span locations, accounting for incorporated \t
397
+ new_start = span.start
398
+ new_end = span.end
399
+ if i == prefix_lines: # Line is the first line in the span
400
+ new_start = span.start.shift_right(num_tabs * 3)
401
+ if i == len(all_lines) - 1: # Line is the last line in the span
402
+ new_end = span.end.shift_right(num_tabs * 3)
403
+ span = Span(new_start or span.start, new_end)
404
+
405
+ # Remove excessive leading whitespace
390
406
  leading_whitespace = min(len(line) - len(line.lstrip()) for line in all_lines)
391
407
  if leading_whitespace > self.MAX_LEADING_WHITESPACE:
392
408
  remove = leading_whitespace - self.OPTIMAL_LEADING_WHITESPACE
@@ -3,14 +3,10 @@ from enum import Enum
3
3
  from types import FrameType
4
4
  from typing import TYPE_CHECKING
5
5
 
6
+ import hugr
6
7
  import hugr.build.function as hf
7
- import hugr.std.collections.array
8
- import hugr.std.float
9
- import hugr.std.int
10
- import hugr.std.logic
11
- import hugr.std.prelude
12
8
  from hugr import ops
13
- from hugr.ext import Extension
9
+ from hugr.ext import Extension, ExtensionRegistry
14
10
  from hugr.package import ModulePointer, Package
15
11
 
16
12
  import guppylang_internals
@@ -150,11 +146,52 @@ class CompilationEngine:
150
146
  types_to_check_worklist: dict[DefId, ParsedDef]
151
147
  to_check_worklist: dict[DefId, ParsedDef]
152
148
 
149
+ # Cached compilation infrastructure (lazy-initialized, program-independent)
150
+ _base_packaged_extensions: list[Extension] | None = None
151
+ _base_resolve_registry: ExtensionRegistry | None = None
152
+
153
153
  def __init__(self) -> None:
154
154
  """Resets the compilation cache."""
155
155
  self.reset()
156
156
  self.additional_extensions = []
157
157
 
158
+ @staticmethod
159
+ def _get_base_packaged_extensions() -> list[Extension]:
160
+ """Get the base list of packaged extensions (cached at class level)."""
161
+ if CompilationEngine._base_packaged_extensions is None:
162
+ from guppylang_internals.std._internal.compiler.tket_exts import (
163
+ TKET_EXTENSIONS,
164
+ )
165
+
166
+ CompilationEngine._base_packaged_extensions = [
167
+ *TKET_EXTENSIONS,
168
+ guppylang_internals.compiler.hugr_extension.EXTENSION, # type: ignore[attr-defined]
169
+ ]
170
+ return CompilationEngine._base_packaged_extensions
171
+
172
+ @staticmethod
173
+ def _get_base_resolve_registry() -> ExtensionRegistry:
174
+ """Get the base resolve registry with standard extensions.
175
+
176
+ Cached at class level.
177
+ """
178
+ if CompilationEngine._base_resolve_registry is None:
179
+ base_extensions = CompilationEngine._get_base_packaged_extensions()
180
+ registry = ExtensionRegistry()
181
+ for ext in [
182
+ *base_extensions,
183
+ hugr.std.prelude.PRELUDE_EXTENSION,
184
+ hugr.std.collections.array.EXTENSION,
185
+ hugr.std.float.FLOAT_OPS_EXTENSION,
186
+ hugr.std.float.FLOAT_TYPES_EXTENSION,
187
+ hugr.std.int.INT_OPS_EXTENSION,
188
+ hugr.std.int.INT_TYPES_EXTENSION,
189
+ hugr.std.logic.EXTENSION,
190
+ ]:
191
+ registry.register_updated(ext)
192
+ CompilationEngine._base_resolve_registry = registry
193
+ return CompilationEngine._base_resolve_registry
194
+
158
195
  def reset(self) -> None:
159
196
  """Resets the compilation cache."""
160
197
  self.parsed = {}
@@ -264,40 +301,56 @@ class CompilationEngine:
264
301
  # loosened after https://github.com/quantinuum/hugr/issues/2501 is fixed
265
302
  graph.hugr.entrypoint = compiled_def.hugr_node
266
303
 
267
- # TODO: Currently the list of extensions is manually managed by the user.
268
- # We should compute this dynamically from the imported dependencies instead.
269
- #
270
- # The hugr prelude and std_extensions are implicit.
271
- from guppylang_internals.std._internal.compiler.tket_exts import TKET_EXTENSIONS
304
+ # Use cached base extensions and registry, only add additional extensions
305
+ base_extensions = self._get_base_packaged_extensions()
306
+ packaged_extensions = [*base_extensions, *self.additional_extensions]
272
307
 
273
- extensions = [
274
- *TKET_EXTENSIONS,
275
- guppylang_internals.compiler.hugr_extension.EXTENSION,
276
- *self.additional_extensions,
277
- ]
278
- # TODO replace with computed extensions after https://github.com/quantinuum/guppylang/issues/550
279
- all_used_extensions = [
280
- *extensions,
281
- hugr.std.prelude.PRELUDE_EXTENSION,
282
- hugr.std.collections.array.EXTENSION,
283
- hugr.std.float.FLOAT_OPS_EXTENSION,
284
- hugr.std.float.FLOAT_TYPES_EXTENSION,
285
- hugr.std.int.INT_OPS_EXTENSION,
286
- hugr.std.int.INT_TYPES_EXTENSION,
287
- hugr.std.logic.EXTENSION,
288
- ]
289
- graph.hugr.module_root.metadata[CoreMetadataKeys.USED_EXTENSIONS.value] = [
308
+ # Build resolve registry: start with cached base, add any additional
309
+ if self.additional_extensions:
310
+ from copy import deepcopy
311
+
312
+ resolve_registry = deepcopy(self._get_base_resolve_registry())
313
+ for ext in self.additional_extensions:
314
+ resolve_registry.register_updated(ext)
315
+ else:
316
+ resolve_registry = self._get_base_resolve_registry()
317
+
318
+ # Compute used extensions dynamically from the HUGR.
319
+ used_extensions_result = graph.hugr.used_extensions(
320
+ resolve_from=resolve_registry
321
+ )
322
+
323
+ # Set metadata for used extensions
324
+ used_exts_meta = [
290
325
  {
291
326
  "name": ext.name,
292
327
  "version": str(ext.version),
293
328
  }
294
- for ext in all_used_extensions
329
+ for ext in used_extensions_result.used_extensions.extensions.values()
295
330
  ]
331
+ # Add unresolved extensions as well, but we only have the names
332
+ used_exts_meta.extend(
333
+ {
334
+ "name": ext,
335
+ }
336
+ for ext in used_extensions_result.unresolved_extensions
337
+ )
338
+ graph.hugr.module_root.metadata[CoreMetadataKeys.USED_EXTENSIONS.value] = (
339
+ used_exts_meta
340
+ )
296
341
  graph.hugr.module_root.metadata[CoreMetadataKeys.GENERATOR.value] = {
297
342
  "name": "guppylang",
298
343
  "version": guppylang_internals.__version__,
299
344
  }
300
- return ModulePointer(Package(modules=[graph.hugr], extensions=extensions), 0)
345
+ # only package used extensions
346
+ packaged_extensions = [
347
+ ext
348
+ for ext in packaged_extensions
349
+ if ext.name in used_extensions_result.ids()
350
+ ]
351
+ return ModulePointer(
352
+ Package(modules=[graph.hugr], extensions=packaged_extensions), 0
353
+ )
301
354
 
302
355
 
303
356
  ENGINE: CompilationEngine = CompilationEngine()
@@ -104,4 +104,4 @@ def pretty_errors(f: FuncT) -> FuncT:
104
104
  with exception_hook(hook):
105
105
  return f(*args, **kwargs)
106
106
 
107
- return cast(FuncT, pretty_errors_wrapped)
107
+ return cast("FuncT", pretty_errors_wrapped)