guppylang-internals 0.21.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (98)
  1. guppylang_internals/__init__.py +3 -0
  2. guppylang_internals/ast_util.py +350 -0
  3. guppylang_internals/cfg/__init__.py +0 -0
  4. guppylang_internals/cfg/analysis.py +230 -0
  5. guppylang_internals/cfg/bb.py +221 -0
  6. guppylang_internals/cfg/builder.py +606 -0
  7. guppylang_internals/cfg/cfg.py +117 -0
  8. guppylang_internals/checker/__init__.py +0 -0
  9. guppylang_internals/checker/cfg_checker.py +388 -0
  10. guppylang_internals/checker/core.py +550 -0
  11. guppylang_internals/checker/errors/__init__.py +0 -0
  12. guppylang_internals/checker/errors/comptime_errors.py +106 -0
  13. guppylang_internals/checker/errors/generic.py +45 -0
  14. guppylang_internals/checker/errors/linearity.py +300 -0
  15. guppylang_internals/checker/errors/type_errors.py +344 -0
  16. guppylang_internals/checker/errors/wasm.py +34 -0
  17. guppylang_internals/checker/expr_checker.py +1413 -0
  18. guppylang_internals/checker/func_checker.py +269 -0
  19. guppylang_internals/checker/linearity_checker.py +821 -0
  20. guppylang_internals/checker/stmt_checker.py +447 -0
  21. guppylang_internals/compiler/__init__.py +0 -0
  22. guppylang_internals/compiler/cfg_compiler.py +233 -0
  23. guppylang_internals/compiler/core.py +613 -0
  24. guppylang_internals/compiler/expr_compiler.py +989 -0
  25. guppylang_internals/compiler/func_compiler.py +97 -0
  26. guppylang_internals/compiler/hugr_extension.py +224 -0
  27. guppylang_internals/compiler/qtm_platform_extension.py +0 -0
  28. guppylang_internals/compiler/stmt_compiler.py +212 -0
  29. guppylang_internals/decorator.py +246 -0
  30. guppylang_internals/definition/__init__.py +0 -0
  31. guppylang_internals/definition/common.py +214 -0
  32. guppylang_internals/definition/const.py +74 -0
  33. guppylang_internals/definition/custom.py +492 -0
  34. guppylang_internals/definition/declaration.py +171 -0
  35. guppylang_internals/definition/extern.py +89 -0
  36. guppylang_internals/definition/function.py +302 -0
  37. guppylang_internals/definition/overloaded.py +150 -0
  38. guppylang_internals/definition/parameter.py +82 -0
  39. guppylang_internals/definition/pytket_circuits.py +405 -0
  40. guppylang_internals/definition/struct.py +392 -0
  41. guppylang_internals/definition/traced.py +151 -0
  42. guppylang_internals/definition/ty.py +51 -0
  43. guppylang_internals/definition/value.py +115 -0
  44. guppylang_internals/definition/wasm.py +61 -0
  45. guppylang_internals/diagnostic.py +523 -0
  46. guppylang_internals/dummy_decorator.py +76 -0
  47. guppylang_internals/engine.py +295 -0
  48. guppylang_internals/error.py +107 -0
  49. guppylang_internals/experimental.py +92 -0
  50. guppylang_internals/ipython_inspect.py +28 -0
  51. guppylang_internals/nodes.py +427 -0
  52. guppylang_internals/py.typed +0 -0
  53. guppylang_internals/span.py +150 -0
  54. guppylang_internals/std/__init__.py +0 -0
  55. guppylang_internals/std/_internal/__init__.py +0 -0
  56. guppylang_internals/std/_internal/checker.py +573 -0
  57. guppylang_internals/std/_internal/compiler/__init__.py +0 -0
  58. guppylang_internals/std/_internal/compiler/arithmetic.py +136 -0
  59. guppylang_internals/std/_internal/compiler/array.py +569 -0
  60. guppylang_internals/std/_internal/compiler/either.py +131 -0
  61. guppylang_internals/std/_internal/compiler/frozenarray.py +68 -0
  62. guppylang_internals/std/_internal/compiler/futures.py +30 -0
  63. guppylang_internals/std/_internal/compiler/list.py +348 -0
  64. guppylang_internals/std/_internal/compiler/mem.py +13 -0
  65. guppylang_internals/std/_internal/compiler/option.py +78 -0
  66. guppylang_internals/std/_internal/compiler/prelude.py +271 -0
  67. guppylang_internals/std/_internal/compiler/qsystem.py +48 -0
  68. guppylang_internals/std/_internal/compiler/quantum.py +118 -0
  69. guppylang_internals/std/_internal/compiler/tket_bool.py +55 -0
  70. guppylang_internals/std/_internal/compiler/tket_exts.py +59 -0
  71. guppylang_internals/std/_internal/compiler/wasm.py +135 -0
  72. guppylang_internals/std/_internal/compiler.py +0 -0
  73. guppylang_internals/std/_internal/debug.py +95 -0
  74. guppylang_internals/std/_internal/util.py +271 -0
  75. guppylang_internals/tracing/__init__.py +0 -0
  76. guppylang_internals/tracing/builtins_mock.py +62 -0
  77. guppylang_internals/tracing/frozenlist.py +57 -0
  78. guppylang_internals/tracing/function.py +186 -0
  79. guppylang_internals/tracing/object.py +551 -0
  80. guppylang_internals/tracing/state.py +69 -0
  81. guppylang_internals/tracing/unpacking.py +194 -0
  82. guppylang_internals/tracing/util.py +86 -0
  83. guppylang_internals/tys/__init__.py +0 -0
  84. guppylang_internals/tys/arg.py +115 -0
  85. guppylang_internals/tys/builtin.py +382 -0
  86. guppylang_internals/tys/common.py +110 -0
  87. guppylang_internals/tys/const.py +114 -0
  88. guppylang_internals/tys/errors.py +178 -0
  89. guppylang_internals/tys/param.py +251 -0
  90. guppylang_internals/tys/parsing.py +425 -0
  91. guppylang_internals/tys/printing.py +174 -0
  92. guppylang_internals/tys/subst.py +112 -0
  93. guppylang_internals/tys/ty.py +876 -0
  94. guppylang_internals/tys/var.py +49 -0
  95. guppylang_internals-0.21.0.dist-info/METADATA +253 -0
  96. guppylang_internals-0.21.0.dist-info/RECORD +98 -0
  97. guppylang_internals-0.21.0.dist-info/WHEEL +4 -0
  98. guppylang_internals-0.21.0.dist-info/licenses/LICENCE +201 -0
@@ -0,0 +1,405 @@
1
+ import ast
2
+ from dataclasses import dataclass, field
3
+ from typing import Any, cast
4
+
5
+ import hugr.build.function as hf
6
+ from hugr import Node, Wire, envelope, ops, val
7
+ from hugr import tys as ht
8
+ from hugr.build.dfg import DefinitionBuilder, OpVar
9
+ from hugr.envelope import EnvelopeConfig
10
+
11
+ from guppylang_internals.ast_util import AstNode, has_empty_body, with_loc
12
+ from guppylang_internals.checker.core import Context, Globals
13
+ from guppylang_internals.checker.errors.comptime_errors import (
14
+ PytketSignatureMismatch,
15
+ TketNotInstalled,
16
+ )
17
+ from guppylang_internals.checker.expr_checker import check_call, synthesize_call
18
+ from guppylang_internals.checker.func_checker import (
19
+ check_signature,
20
+ )
21
+ from guppylang_internals.compiler.core import CompilerContext, DFContainer
22
+ from guppylang_internals.definition.common import (
23
+ CompilableDef,
24
+ ParsableDef,
25
+ )
26
+ from guppylang_internals.definition.declaration import BodyNotEmptyError
27
+ from guppylang_internals.definition.function import (
28
+ PyFunc,
29
+ compile_call,
30
+ load_with_args,
31
+ parse_py_func,
32
+ )
33
+ from guppylang_internals.definition.ty import TypeDef
34
+ from guppylang_internals.definition.value import (
35
+ CallableDef,
36
+ CallReturnWires,
37
+ CompiledCallableDef,
38
+ CompiledHugrNodeDef,
39
+ )
40
+ from guppylang_internals.error import GuppyError, InternalGuppyError
41
+ from guppylang_internals.nodes import GlobalCall
42
+ from guppylang_internals.span import SourceMap, Span, ToSpan
43
+ from guppylang_internals.std._internal.compiler.array import (
44
+ array_discard_empty,
45
+ array_new,
46
+ array_pop,
47
+ )
48
+ from guppylang_internals.std._internal.compiler.prelude import build_unwrap
49
+ from guppylang_internals.std._internal.compiler.tket_bool import OpaqueBool, make_opaque
50
+ from guppylang_internals.tys.builtin import array_type, bool_type
51
+ from guppylang_internals.tys.subst import Inst, Subst
52
+ from guppylang_internals.tys.ty import (
53
+ FuncInput,
54
+ FunctionType,
55
+ InputFlags,
56
+ Type,
57
+ row_to_type,
58
+ )
59
+
60
+
61
@dataclass(frozen=True)
class RawPytketDef(ParsableDef):
    """Unparsed pytket circuit definition that comes with a Python function stub.

    The stub declares the Guppy signature the user expects the circuit to have.
    Parsing checks it against the signature inferred from the circuit itself.

    Args:
        id: The unique definition identifier.
        name: The name of the function stub.
        defined_at: The AST node where the stub was defined.
        python_func: The Python function stub.
        input_circuit: The user-provided pytket circuit.
    """

    python_func: PyFunc
    input_circuit: Any

    description: str = field(default="pytket circuit", init=False)

    def parse(self, globals: Globals, sources: SourceMap) -> "ParsedPytketDef":
        """Parses the function stub and checks that its signature agrees with the
        signature inferred from the user-provided circuit.

        Raises:
            GuppyError: If the stub body is not empty, or if the stub signature does
                not match the signature derived from the circuit.
        """
        stub_ast, _ = parse_py_func(self.python_func, sources)
        # The stub only declares a signature, so it must not contain any code.
        if not has_empty_body(stub_ast):
            raise GuppyError(BodyNotEmptyError(stub_ast.body[0], self.name))
        declared_sig = check_signature(stub_ast, globals)

        inferred_sig = _signature_from_circuit(
            self.input_circuit, globals, self.defined_at
        )
        if (
            inferred_sig.inputs == declared_sig.inputs
            and inferred_sig.output == declared_sig.output
        ):
            # Stub-based definitions always use flat qubit / bool inputs.
            return ParsedPytketDef(
                self.id,
                self.name,
                stub_ast,
                declared_sig,
                self.input_circuit,
                False,
            )

        mismatch = PytketSignatureMismatch(stub_ast, self.name)
        mismatch.add_sub_diagnostic(
            PytketSignatureMismatch.TypeHint(None, circ_sig=inferred_sig)
        )
        raise GuppyError(mismatch)
105
+
106
+
107
@dataclass(frozen=True)
class RawLoadPytketDef(ParsableDef):
    """Unparsed pytket circuit definition loaded without a function stub.

    With no stub available, the function signature is inferred entirely from the
    circuit.

    Args:
        id: The unique definition identifier.
        name: The name of the circuit function.
        defined_at: The AST node of the definition (here always None).
        source_span: The source span where the circuit was loaded.
        input_circuit: The user-provided pytket circuit.
        use_arrays: Whether the circuit function should use arrays as input types.
    """

    source_span: Span | None
    input_circuit: Any
    use_arrays: bool

    description: str = field(default="pytket circuit", init=False)

    def parse(self, globals: Globals, sources: SourceMap) -> "ParsedPytketDef":
        """Builds the parsed definition, inferring its signature from the circuit."""
        # Errors are reported at the load site, since there is no stub to point at.
        inferred_sig = _signature_from_circuit(
            self.input_circuit, globals, self.source_span, self.use_arrays
        )
        parsed = ParsedPytketDef(
            self.id,
            self.name,
            self.defined_at,
            inferred_sig,
            self.input_circuit,
            self.use_arrays,
        )
        return parsed
140
+
141
+
142
@dataclass(frozen=True)
class ParsedPytketDef(CallableDef, CompilableDef):
    """A circuit definition with signature.

    Args:
        id: The unique definition identifier.
        name: The name of the function.
        defined_at: The AST node of the function stub, if there is one.
        ty: The type of the function.
        input_circuit: The user-provided pytket circuit.
        use_arrays: Whether the circuit function should use arrays as input types.
    """

    ty: FunctionType
    input_circuit: Any
    use_arrays: bool

    description: str = field(default="pytket circuit", init=False)

    def compile_outer(
        self, module: DefinitionBuilder[OpVar], ctx: CompilerContext
    ) -> "CompiledPytketDef":
        """Adds a Hugr `FuncDefn` node for this function to the Hugr.

        The pytket circuit is converted to a Hugr through tket's `Tk2Circuit` and
        inserted into `module`. A wrapper function with this definition's Guppy
        signature is then defined. The wrapper:

        * calls the inserted circuit function,
        * initialises every classical bit of the circuit to `false`, and
        * if `use_arrays` is set, unpacks the register arrays before the call and
          re-packs them afterwards.

        Raises:
            InternalGuppyError: If pytket or tket cannot be imported, or if the stored
                circuit is not a pytket `Circuit`. Neither should happen for a
                definition produced by a successful parse.
        """
        # Previously an ImportError here was swallowed, which left `outer_func`
        # unbound and crashed later with an UnboundLocalError. Fail explicitly.
        try:
            import pytket
        except ImportError:
            raise InternalGuppyError(
                "Pytket error should have been caught earlier"
            ) from None
        if not isinstance(self.input_circuit, pytket.circuit.Circuit):
            raise InternalGuppyError(
                "Expected a pytket circuit, got "
                f"{type(self.input_circuit).__name__}"
            )
        try:
            from tket.circuit import (  # type: ignore[import-untyped, import-not-found, unused-ignore]
                Tk2Circuit,
            )
        except ImportError:
            raise InternalGuppyError(
                "Tket error should have been caught earlier"
            ) from None

        # TODO extract the correct entry point from the module
        circ = envelope.read_envelope(
            Tk2Circuit(self.input_circuit).to_bytes(EnvelopeConfig.TEXT)
        ).modules[0]
        mapping = module.hugr.insert_hugr(circ)
        hugr_func = mapping[circ.entrypoint]

        func_type = self.ty.to_hugr_poly(ctx)
        outer_func = module.module_root_builder().define_function(
            self.name, func_type.body.input, func_type.body.output
        )

        # Initialise every input bit in the circuit as false.
        # TODO: Provide the option for the user to pass this input as well.
        bool_wires = [
            outer_func.load(val.FALSE) for _ in range(self.input_circuit.n_bits)
        ]

        input_list: list[Wire]
        if self.use_arrays:
            # If the input is given as arrays, we need to unpack each element in
            # them into separate wires.
            # TODO: Replace with actual unpack HUGR op once
            # https://github.com/CQCL/hugr/issues/1947 is done.
            def unpack(array: Wire, elem_ty: ht.Type, length: int) -> list[Wire]:
                err = "Internal error: unpacking of array failed"
                elts: list[Wire] = []
                for i in range(length):
                    res = outer_func.add_op(
                        array_pop(elem_ty, length - i, True), array
                    )
                    [elt_opt, array] = build_unwrap(outer_func, res, err)
                    [elt] = build_unwrap(outer_func, elt_opt, err)
                    elts.append(elt)
                outer_func.add_op(array_discard_empty(elem_ty), array)
                return elts

            input_list = []
            # Must be same length due to earlier signature computation /
            # comparison.
            for q_reg, wire in zip(
                self.input_circuit.q_registers,
                list(outer_func.inputs()),
                strict=True,
            ):
                input_list.extend(unpack(wire, ht.Option(ht.Qubit), q_reg.size))
        else:
            # Otherwise pass inputs directly.
            input_list = list(outer_func.inputs())

        call_node = outer_func.call(hugr_func, *(input_list + bool_wires))

        # Pytket circuit hugr has qubit and bool wires in the opposite
        # order to Guppy output wires.
        output_list: list[Wire] = list(call_node.outputs())
        n_qubits = self.input_circuit.n_qubits
        wires = output_list[n_qubits:] + output_list[:n_qubits]
        # Convert hugr sum bools into the opaque bools that Guppy uses.
        wires = [
            outer_func.add_op(make_opaque(), wire)
            if outer_func.hugr.port_type(wire.out_port()) == ht.Bool
            else wire
            for wire in wires
        ]

        if self.use_arrays:

            def pack(elems: list[Wire], elem_ty: ht.Type, length: int) -> Wire:
                elem_opts = [
                    outer_func.add_op(ops.Some(elem_ty), elem) for elem in elems
                ]
                return outer_func.add_op(
                    array_new(ht.Option(elem_ty), length), *elem_opts
                )

            array_wires: list[Wire] = []
            wire_idx = 0
            # First pack bool results into arrays, one per classical register.
            for c_reg in self.input_circuit.c_registers:
                array_wires.append(
                    pack(
                        wires[wire_idx : wire_idx + c_reg.size],
                        OpaqueBool,
                        c_reg.size,
                    )
                )
                wire_idx += c_reg.size
            # Then the borrowed qubits also need to be put back into arrays.
            for q_reg in self.input_circuit.q_registers:
                array_wires.append(
                    pack(
                        wires[wire_idx : wire_idx + q_reg.size],
                        ht.Qubit,
                        q_reg.size,
                    )
                )
                wire_idx += q_reg.size
            wires = array_wires

        outer_func.set_outputs(*wires)

        return CompiledPytketDef(
            self.id,
            self.name,
            self.defined_at,
            self.ty,
            self.input_circuit,
            self.use_arrays,
            outer_func,
        )

    def check_call(
        self, args: list[ast.expr], ty: Type, node: AstNode, ctx: Context
    ) -> tuple[ast.expr, Subst]:
        """Checks the return type of a function call against a given type."""
        # Use default implementation from the expression checker
        args, subst, inst = check_call(self.ty, args, ty, node, ctx)
        node = with_loc(node, GlobalCall(def_id=self.id, args=args, type_args=inst))
        return node, subst

    def synthesize_call(
        self, args: list[ast.expr], node: AstNode, ctx: Context
    ) -> tuple[ast.expr, Type]:
        """Synthesizes the return type of a function call."""
        # Use default implementation from the expression checker
        args, ty, inst = synthesize_call(self.ty, args, node, ctx)
        node = with_loc(node, GlobalCall(def_id=self.id, args=args, type_args=inst))
        return node, ty
308
+
309
+
310
@dataclass(frozen=True)
class CompiledPytketDef(ParsedPytketDef, CompiledCallableDef, CompiledHugrNodeDef):
    """A pytket circuit definition with a corresponding Hugr node.

    Args:
        id: The unique definition identifier.
        name: The name of the function.
        defined_at: The AST node of the function stub, if there is one.
        ty: The type of the function.
        input_circuit: The user-provided pytket circuit.
        use_arrays: Whether the circuit function uses arrays as input types.
        func_def: The Hugr function definition that wraps the call to the circuit.
    """

    func_def: hf.Function

    @property
    def hugr_node(self) -> Node:
        """The Hugr node this definition was compiled into."""
        return self.func_def.parent_node

    def load_with_args(
        self,
        type_args: Inst,
        dfg: DFContainer,
        ctx: CompilerContext,
        node: AstNode,
    ) -> Wire:
        """Loads the function as a value into a local Hugr dataflow graph."""
        # Use implementation from function definition.
        return load_with_args(type_args, dfg, self.ty, self.func_def)

    def compile_call(
        self,
        args: list[Wire],
        type_args: Inst,
        dfg: DFContainer,
        ctx: CompilerContext,
        node: AstNode,
    ) -> CallReturnWires:
        """Compiles a call to the function."""
        # Use implementation from function definition.
        return compile_call(args, type_args, dfg, self.ty, self.func_def)
353
+
354
+
355
def _signature_from_circuit(
    input_circuit: Any,
    globals: Globals,
    defined_at: ToSpan | None,
    use_arrays: bool = False,
) -> FunctionType:
    """Helper function for inferring a function signature from a pytket circuit.

    Without arrays:

    * every qubit of the circuit becomes a borrowed (`InputFlags.Inout`) qubit
      input, and
    * every classical bit becomes a `bool` output.

    With `use_arrays` set:

    * each quantum register becomes one borrowed qubit-array input, and
    * each classical register becomes one bool-array output.

    Args:
        input_circuit: The user-provided circuit. Must be a pytket `Circuit`.
        globals: The global definitions (currently unused by this helper).
        defined_at: Location at which a missing-tket error is reported.
        use_arrays: Whether registers are mapped to arrays instead of individual
            qubits / bools.

    Raises:
        GuppyError: If tket is not installed.
        InternalGuppyError: If pytket is not installed or `input_circuit` is not a
            pytket `Circuit`; callers are expected to have ruled both out already.
    """
    try:
        import pytket
    except ImportError:
        raise InternalGuppyError(
            "Pytket error should have been caught earlier"
        ) from None

    # Previously a non-circuit input fell through to an unbound
    # `circuit_signature`, raising an UnboundLocalError. Fail explicitly instead.
    if not isinstance(input_circuit, pytket.circuit.Circuit):
        raise InternalGuppyError(
            f"Expected a pytket circuit, got {type(input_circuit).__name__}"
        )

    # Only a failing tket import should be reported as "tket not installed".
    try:
        import tket  # type: ignore[import-untyped, import-not-found, unused-ignore] # noqa: F401
    except ImportError:
        err = TketNotInstalled(defined_at)
        err.add_sub_diagnostic(TketNotInstalled.InstallInstruction(None))
        raise GuppyError(err) from None

    from guppylang.defs import GuppyDefinition
    from guppylang.std.quantum import qubit

    assert isinstance(qubit, GuppyDefinition)
    qubit_ty = cast(TypeDef, qubit.wrapped).check_instantiate([])

    if use_arrays:
        inputs = [
            FuncInput(array_type(qubit_ty, q_reg.size), InputFlags.Inout)
            for q_reg in input_circuit.q_registers
        ]
        outputs = [
            array_type(bool_type(), c_reg.size) for c_reg in input_circuit.c_registers
        ]
    else:
        inputs = [FuncInput(qubit_ty, InputFlags.Inout)] * input_circuit.n_qubits
        outputs = [bool_type()] * input_circuit.n_bits
    return FunctionType(inputs, row_to_type(outputs))