guppylang-internals 0.21.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- guppylang_internals/__init__.py +3 -0
- guppylang_internals/ast_util.py +350 -0
- guppylang_internals/cfg/__init__.py +0 -0
- guppylang_internals/cfg/analysis.py +230 -0
- guppylang_internals/cfg/bb.py +221 -0
- guppylang_internals/cfg/builder.py +606 -0
- guppylang_internals/cfg/cfg.py +117 -0
- guppylang_internals/checker/__init__.py +0 -0
- guppylang_internals/checker/cfg_checker.py +388 -0
- guppylang_internals/checker/core.py +550 -0
- guppylang_internals/checker/errors/__init__.py +0 -0
- guppylang_internals/checker/errors/comptime_errors.py +106 -0
- guppylang_internals/checker/errors/generic.py +45 -0
- guppylang_internals/checker/errors/linearity.py +300 -0
- guppylang_internals/checker/errors/type_errors.py +344 -0
- guppylang_internals/checker/errors/wasm.py +34 -0
- guppylang_internals/checker/expr_checker.py +1413 -0
- guppylang_internals/checker/func_checker.py +269 -0
- guppylang_internals/checker/linearity_checker.py +821 -0
- guppylang_internals/checker/stmt_checker.py +447 -0
- guppylang_internals/compiler/__init__.py +0 -0
- guppylang_internals/compiler/cfg_compiler.py +233 -0
- guppylang_internals/compiler/core.py +613 -0
- guppylang_internals/compiler/expr_compiler.py +989 -0
- guppylang_internals/compiler/func_compiler.py +97 -0
- guppylang_internals/compiler/hugr_extension.py +224 -0
- guppylang_internals/compiler/qtm_platform_extension.py +0 -0
- guppylang_internals/compiler/stmt_compiler.py +212 -0
- guppylang_internals/decorator.py +246 -0
- guppylang_internals/definition/__init__.py +0 -0
- guppylang_internals/definition/common.py +214 -0
- guppylang_internals/definition/const.py +74 -0
- guppylang_internals/definition/custom.py +492 -0
- guppylang_internals/definition/declaration.py +171 -0
- guppylang_internals/definition/extern.py +89 -0
- guppylang_internals/definition/function.py +302 -0
- guppylang_internals/definition/overloaded.py +150 -0
- guppylang_internals/definition/parameter.py +82 -0
- guppylang_internals/definition/pytket_circuits.py +405 -0
- guppylang_internals/definition/struct.py +392 -0
- guppylang_internals/definition/traced.py +151 -0
- guppylang_internals/definition/ty.py +51 -0
- guppylang_internals/definition/value.py +115 -0
- guppylang_internals/definition/wasm.py +61 -0
- guppylang_internals/diagnostic.py +523 -0
- guppylang_internals/dummy_decorator.py +76 -0
- guppylang_internals/engine.py +295 -0
- guppylang_internals/error.py +107 -0
- guppylang_internals/experimental.py +92 -0
- guppylang_internals/ipython_inspect.py +28 -0
- guppylang_internals/nodes.py +427 -0
- guppylang_internals/py.typed +0 -0
- guppylang_internals/span.py +150 -0
- guppylang_internals/std/__init__.py +0 -0
- guppylang_internals/std/_internal/__init__.py +0 -0
- guppylang_internals/std/_internal/checker.py +573 -0
- guppylang_internals/std/_internal/compiler/__init__.py +0 -0
- guppylang_internals/std/_internal/compiler/arithmetic.py +136 -0
- guppylang_internals/std/_internal/compiler/array.py +569 -0
- guppylang_internals/std/_internal/compiler/either.py +131 -0
- guppylang_internals/std/_internal/compiler/frozenarray.py +68 -0
- guppylang_internals/std/_internal/compiler/futures.py +30 -0
- guppylang_internals/std/_internal/compiler/list.py +348 -0
- guppylang_internals/std/_internal/compiler/mem.py +13 -0
- guppylang_internals/std/_internal/compiler/option.py +78 -0
- guppylang_internals/std/_internal/compiler/prelude.py +271 -0
- guppylang_internals/std/_internal/compiler/qsystem.py +48 -0
- guppylang_internals/std/_internal/compiler/quantum.py +118 -0
- guppylang_internals/std/_internal/compiler/tket_bool.py +55 -0
- guppylang_internals/std/_internal/compiler/tket_exts.py +59 -0
- guppylang_internals/std/_internal/compiler/wasm.py +135 -0
- guppylang_internals/std/_internal/compiler.py +0 -0
- guppylang_internals/std/_internal/debug.py +95 -0
- guppylang_internals/std/_internal/util.py +271 -0
- guppylang_internals/tracing/__init__.py +0 -0
- guppylang_internals/tracing/builtins_mock.py +62 -0
- guppylang_internals/tracing/frozenlist.py +57 -0
- guppylang_internals/tracing/function.py +186 -0
- guppylang_internals/tracing/object.py +551 -0
- guppylang_internals/tracing/state.py +69 -0
- guppylang_internals/tracing/unpacking.py +194 -0
- guppylang_internals/tracing/util.py +86 -0
- guppylang_internals/tys/__init__.py +0 -0
- guppylang_internals/tys/arg.py +115 -0
- guppylang_internals/tys/builtin.py +382 -0
- guppylang_internals/tys/common.py +110 -0
- guppylang_internals/tys/const.py +114 -0
- guppylang_internals/tys/errors.py +178 -0
- guppylang_internals/tys/param.py +251 -0
- guppylang_internals/tys/parsing.py +425 -0
- guppylang_internals/tys/printing.py +174 -0
- guppylang_internals/tys/subst.py +112 -0
- guppylang_internals/tys/ty.py +876 -0
- guppylang_internals/tys/var.py +49 -0
- guppylang_internals-0.21.0.dist-info/METADATA +253 -0
- guppylang_internals-0.21.0.dist-info/RECORD +98 -0
- guppylang_internals-0.21.0.dist-info/WHEEL +4 -0
- guppylang_internals-0.21.0.dist-info/licenses/LICENCE +201 -0
|
@@ -0,0 +1,989 @@
|
|
|
1
|
+
import ast
|
|
2
|
+
from collections.abc import Iterator, Sequence
|
|
3
|
+
from contextlib import AbstractContextManager, ExitStack, contextmanager
|
|
4
|
+
from typing import Any, Final, TypeGuard, TypeVar
|
|
5
|
+
|
|
6
|
+
import hugr
|
|
7
|
+
import hugr.std.collections.array
|
|
8
|
+
import hugr.std.float
|
|
9
|
+
import hugr.std.int
|
|
10
|
+
import hugr.std.logic
|
|
11
|
+
import hugr.std.prelude
|
|
12
|
+
from hugr import Wire, ops
|
|
13
|
+
from hugr import tys as ht
|
|
14
|
+
from hugr import val as hv
|
|
15
|
+
from hugr.build import function as hf
|
|
16
|
+
from hugr.build.cond_loop import Conditional
|
|
17
|
+
from hugr.build.dfg import DP, DfBase
|
|
18
|
+
from typing_extensions import assert_never
|
|
19
|
+
|
|
20
|
+
from guppylang_internals.ast_util import AstNode, AstVisitor, get_type
|
|
21
|
+
from guppylang_internals.cfg.builder import tmp_vars
|
|
22
|
+
from guppylang_internals.checker.core import Variable, contains_subscript
|
|
23
|
+
from guppylang_internals.checker.errors.generic import UnsupportedError
|
|
24
|
+
from guppylang_internals.compiler.core import (
|
|
25
|
+
DEBUG_EXTENSION,
|
|
26
|
+
RESULT_EXTENSION,
|
|
27
|
+
CompilerBase,
|
|
28
|
+
CompilerContext,
|
|
29
|
+
DFContainer,
|
|
30
|
+
GlobalConstId,
|
|
31
|
+
)
|
|
32
|
+
from guppylang_internals.compiler.hugr_extension import PartialOp
|
|
33
|
+
from guppylang_internals.definition.custom import CustomFunctionDef
|
|
34
|
+
from guppylang_internals.definition.value import (
|
|
35
|
+
CallableDef,
|
|
36
|
+
CallReturnWires,
|
|
37
|
+
CompiledCallableDef,
|
|
38
|
+
CompiledValueDef,
|
|
39
|
+
)
|
|
40
|
+
from guppylang_internals.engine import ENGINE
|
|
41
|
+
from guppylang_internals.error import GuppyError, InternalGuppyError
|
|
42
|
+
from guppylang_internals.nodes import (
|
|
43
|
+
BarrierExpr,
|
|
44
|
+
DesugaredArrayComp,
|
|
45
|
+
DesugaredGenerator,
|
|
46
|
+
DesugaredListComp,
|
|
47
|
+
ExitKind,
|
|
48
|
+
FieldAccessAndDrop,
|
|
49
|
+
GenericParamValue,
|
|
50
|
+
GlobalCall,
|
|
51
|
+
GlobalName,
|
|
52
|
+
LocalCall,
|
|
53
|
+
PanicExpr,
|
|
54
|
+
PartialApply,
|
|
55
|
+
PlaceNode,
|
|
56
|
+
ResultExpr,
|
|
57
|
+
StateResultExpr,
|
|
58
|
+
SubscriptAccessAndDrop,
|
|
59
|
+
TensorCall,
|
|
60
|
+
TupleAccessAndDrop,
|
|
61
|
+
TypeApply,
|
|
62
|
+
)
|
|
63
|
+
from guppylang_internals.std._internal.compiler.arithmetic import (
|
|
64
|
+
UnsignedIntVal,
|
|
65
|
+
convert_ifromusize,
|
|
66
|
+
)
|
|
67
|
+
from guppylang_internals.std._internal.compiler.array import (
|
|
68
|
+
array_convert_from_std_array,
|
|
69
|
+
array_convert_to_std_array,
|
|
70
|
+
array_map,
|
|
71
|
+
array_new,
|
|
72
|
+
array_repeat,
|
|
73
|
+
standard_array_type,
|
|
74
|
+
unpack_array,
|
|
75
|
+
)
|
|
76
|
+
from guppylang_internals.std._internal.compiler.list import (
|
|
77
|
+
list_new,
|
|
78
|
+
)
|
|
79
|
+
from guppylang_internals.std._internal.compiler.prelude import (
|
|
80
|
+
build_error,
|
|
81
|
+
build_panic,
|
|
82
|
+
build_unwrap,
|
|
83
|
+
panic,
|
|
84
|
+
)
|
|
85
|
+
from guppylang_internals.std._internal.compiler.tket_bool import (
|
|
86
|
+
OpaqueBool,
|
|
87
|
+
OpaqueBoolVal,
|
|
88
|
+
make_opaque,
|
|
89
|
+
not_op,
|
|
90
|
+
read_bool,
|
|
91
|
+
)
|
|
92
|
+
from guppylang_internals.tys.arg import ConstArg
|
|
93
|
+
from guppylang_internals.tys.builtin import (
|
|
94
|
+
bool_type,
|
|
95
|
+
get_element_type,
|
|
96
|
+
int_type,
|
|
97
|
+
is_bool_type,
|
|
98
|
+
is_frozenarray_type,
|
|
99
|
+
)
|
|
100
|
+
from guppylang_internals.tys.const import ConstValue
|
|
101
|
+
from guppylang_internals.tys.subst import Inst
|
|
102
|
+
from guppylang_internals.tys.ty import (
|
|
103
|
+
BoundTypeVar,
|
|
104
|
+
FuncInput,
|
|
105
|
+
FunctionType,
|
|
106
|
+
InputFlags,
|
|
107
|
+
NoneType,
|
|
108
|
+
NumericType,
|
|
109
|
+
OpaqueType,
|
|
110
|
+
TupleType,
|
|
111
|
+
Type,
|
|
112
|
+
type_to_row,
|
|
113
|
+
)
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
class ExprCompiler(CompilerBase, AstVisitor[Wire]):
|
|
117
|
+
"""A compiler from guppylang expressions to Hugr."""
|
|
118
|
+
|
|
119
|
+
dfg: DFContainer
|
|
120
|
+
|
|
121
|
+
def compile(self, expr: ast.expr, dfg: DFContainer) -> Wire:
|
|
122
|
+
"""Compiles an expression and returns a single wire holding the output value."""
|
|
123
|
+
self.dfg = dfg
|
|
124
|
+
return self.visit(expr)
|
|
125
|
+
|
|
126
|
+
def compile_row(self, expr: ast.expr, dfg: DFContainer) -> list[Wire]:
|
|
127
|
+
"""Compiles a row expression and returns a list of wires, one for each value in
|
|
128
|
+
the row.
|
|
129
|
+
|
|
130
|
+
On Python-level, we treat tuples like rows on top-level. However, nested tuples
|
|
131
|
+
are treated like regular Guppy tuples.
|
|
132
|
+
"""
|
|
133
|
+
return [self.compile(e, dfg) for e in expr_to_row(expr)]
|
|
134
|
+
|
|
135
|
+
@property
|
|
136
|
+
def builder(self) -> DfBase[ops.DfParentOp]:
|
|
137
|
+
"""The current Hugr dataflow graph builder."""
|
|
138
|
+
return self.dfg.builder
|
|
139
|
+
|
|
140
|
+
@contextmanager
|
|
141
|
+
def _new_dfcontainer(
|
|
142
|
+
self, inputs: list[PlaceNode], builder: DfBase[DP]
|
|
143
|
+
) -> Iterator[None]:
|
|
144
|
+
"""Context manager to build a graph inside a new `DFContainer`.
|
|
145
|
+
|
|
146
|
+
Automatically updates `self.dfg` and makes the inputs available.
|
|
147
|
+
"""
|
|
148
|
+
old = self.dfg
|
|
149
|
+
# Check that the input names are unique
|
|
150
|
+
assert len({inp.place.id for inp in inputs}) == len(
|
|
151
|
+
inputs
|
|
152
|
+
), "Inputs are not unique"
|
|
153
|
+
self.dfg = DFContainer(builder, self.ctx, self.dfg.locals.copy())
|
|
154
|
+
hugr_input = builder.input_node
|
|
155
|
+
for input_node, wire in zip(inputs, hugr_input, strict=True):
|
|
156
|
+
self.dfg[input_node.place] = wire
|
|
157
|
+
|
|
158
|
+
yield
|
|
159
|
+
|
|
160
|
+
self.dfg = old
|
|
161
|
+
|
|
162
|
+
@contextmanager
|
|
163
|
+
def _new_loop(
|
|
164
|
+
self,
|
|
165
|
+
just_inputs_vars: list[PlaceNode],
|
|
166
|
+
loop_vars: list[PlaceNode],
|
|
167
|
+
break_predicate: PlaceNode,
|
|
168
|
+
) -> Iterator[None]:
|
|
169
|
+
"""Context manager to build a graph inside a new `TailLoop` node.
|
|
170
|
+
|
|
171
|
+
Automatically adds the `Output` node to the loop body once the context manager
|
|
172
|
+
exits.
|
|
173
|
+
"""
|
|
174
|
+
just_inputs = [self.visit(name) for name in just_inputs_vars]
|
|
175
|
+
loop_inputs = [self.visit(name) for name in loop_vars]
|
|
176
|
+
loop = self.builder.add_tail_loop(just_inputs, loop_inputs)
|
|
177
|
+
with self._new_dfcontainer(just_inputs_vars + loop_vars, loop):
|
|
178
|
+
yield
|
|
179
|
+
# Output the branch predicate and the inputs for the next iteration. Note
|
|
180
|
+
# that we have to do fresh calls to `self.visit` here since we're in a new
|
|
181
|
+
# context
|
|
182
|
+
do_break = self.visit(break_predicate)
|
|
183
|
+
loop.set_loop_outputs(do_break, *(self.visit(name) for name in loop_vars))
|
|
184
|
+
# Update the DFG with the outputs from the loop
|
|
185
|
+
for node, wire in zip(loop_vars, loop, strict=True):
|
|
186
|
+
self.dfg[node.place] = wire
|
|
187
|
+
|
|
188
|
+
@contextmanager
|
|
189
|
+
def _new_case(
|
|
190
|
+
self,
|
|
191
|
+
inputs: list[PlaceNode],
|
|
192
|
+
outputs: list[PlaceNode],
|
|
193
|
+
conditional: Conditional,
|
|
194
|
+
case_id: int,
|
|
195
|
+
) -> Iterator[None]:
|
|
196
|
+
"""Context manager to build a graph inside a new `Case` node.
|
|
197
|
+
|
|
198
|
+
Automatically adds the `Output` node once the context manager exits.
|
|
199
|
+
"""
|
|
200
|
+
# TODO: `Case` is `_DfgBase`, but not `Dfg`?
|
|
201
|
+
case = conditional.add_case(case_id)
|
|
202
|
+
with self._new_dfcontainer(inputs, case):
|
|
203
|
+
yield
|
|
204
|
+
case.set_outputs(*(self.visit(name) for name in outputs))
|
|
205
|
+
|
|
206
|
+
def _if_else(
|
|
207
|
+
self,
|
|
208
|
+
cond: ast.expr,
|
|
209
|
+
inputs: list[PlaceNode],
|
|
210
|
+
outputs: list[PlaceNode],
|
|
211
|
+
only_true_inputs: list[PlaceNode] | None = None,
|
|
212
|
+
only_false_inputs: list[PlaceNode] | None = None,
|
|
213
|
+
) -> tuple[AbstractContextManager[None], AbstractContextManager[None]]:
|
|
214
|
+
"""Builds a `Conditional`, returning context managers to build the `True` and
|
|
215
|
+
`False` branch.
|
|
216
|
+
"""
|
|
217
|
+
cond_wire = self.visit(cond)
|
|
218
|
+
cond_ty = self.builder.hugr.port_type(cond_wire.out_port())
|
|
219
|
+
if cond_ty == OpaqueBool:
|
|
220
|
+
cond_wire = self.builder.add_op(read_bool(), cond_wire)
|
|
221
|
+
conditional = self.builder.add_conditional(
|
|
222
|
+
cond_wire, *(self.visit(inp) for inp in inputs)
|
|
223
|
+
)
|
|
224
|
+
only_true_inputs_ = only_true_inputs or []
|
|
225
|
+
only_false_inputs_ = only_false_inputs or []
|
|
226
|
+
|
|
227
|
+
@contextmanager
|
|
228
|
+
def true_case() -> Iterator[None]:
|
|
229
|
+
with self._new_case(only_true_inputs_ + inputs, outputs, conditional, 1):
|
|
230
|
+
yield
|
|
231
|
+
|
|
232
|
+
@contextmanager
|
|
233
|
+
def false_case() -> Iterator[None]:
|
|
234
|
+
with self._new_case(only_false_inputs_ + inputs, outputs, conditional, 0):
|
|
235
|
+
yield
|
|
236
|
+
# Update the DFG with the outputs from the Conditional node
|
|
237
|
+
for node, wire in zip(outputs, conditional, strict=True):
|
|
238
|
+
self.dfg[node.place] = wire
|
|
239
|
+
|
|
240
|
+
return true_case(), false_case()
|
|
241
|
+
|
|
242
|
+
@contextmanager
|
|
243
|
+
def _if_true(self, cond: ast.expr, inputs: list[PlaceNode]) -> Iterator[None]:
|
|
244
|
+
"""Context manager to build a graph inside the `true` case of a `Conditional`
|
|
245
|
+
|
|
246
|
+
In the `false` case, the inputs are outputted as is.
|
|
247
|
+
"""
|
|
248
|
+
true_case, false_case = self._if_else(cond, inputs, inputs)
|
|
249
|
+
with false_case:
|
|
250
|
+
# If the condition is false, output the inputs as is
|
|
251
|
+
pass
|
|
252
|
+
with true_case:
|
|
253
|
+
# If the condition is true, we enter the `with` block
|
|
254
|
+
yield
|
|
255
|
+
|
|
256
|
+
def visit_Constant(self, node: ast.Constant) -> Wire:
|
|
257
|
+
if value := python_value_to_hugr(node.value, get_type(node), self.ctx):
|
|
258
|
+
return self.builder.load(value)
|
|
259
|
+
raise InternalGuppyError("Unsupported constant expression in compiler")
|
|
260
|
+
|
|
261
|
+
def visit_PlaceNode(self, node: PlaceNode) -> Wire:
|
|
262
|
+
if subscript := contains_subscript(node.place):
|
|
263
|
+
if subscript.item not in self.dfg:
|
|
264
|
+
self.dfg[subscript.item] = self.visit(subscript.item_expr)
|
|
265
|
+
self.dfg[subscript] = self.visit(subscript.getitem_call)
|
|
266
|
+
return self.dfg[node.place]
|
|
267
|
+
|
|
268
|
+
def visit_GlobalName(self, node: GlobalName) -> Wire:
|
|
269
|
+
defn = ENGINE.get_checked(node.def_id)
|
|
270
|
+
if isinstance(defn, CallableDef) and defn.ty.parametrized:
|
|
271
|
+
# TODO: This should be caught during checking
|
|
272
|
+
err = UnsupportedError(
|
|
273
|
+
node, "Polymorphic functions as dynamic higher-order values"
|
|
274
|
+
)
|
|
275
|
+
raise GuppyError(err)
|
|
276
|
+
|
|
277
|
+
defn, [] = self.ctx.build_compiled_def(node.def_id, type_args=[])
|
|
278
|
+
assert isinstance(defn, CompiledValueDef)
|
|
279
|
+
return defn.load(self.dfg, self.ctx, node)
|
|
280
|
+
|
|
281
|
+
def visit_GenericParamValue(self, node: GenericParamValue) -> Wire:
|
|
282
|
+
match node.param.ty:
|
|
283
|
+
case NumericType(NumericType.Kind.Nat):
|
|
284
|
+
# Generic nat parameters are encoded using Hugr bounded nat parameters,
|
|
285
|
+
# so they are not monomorphized when compiling to Hugr
|
|
286
|
+
arg = node.param.to_bound().to_hugr(self.ctx)
|
|
287
|
+
load_nat = hugr.std.PRELUDE.get_op("load_nat").instantiate(
|
|
288
|
+
[arg], ht.FunctionType([], [ht.USize()])
|
|
289
|
+
)
|
|
290
|
+
usize = self.builder.add_op(load_nat)
|
|
291
|
+
return self.builder.add_op(convert_ifromusize(), usize)
|
|
292
|
+
case ty:
|
|
293
|
+
# Look up monomorphization
|
|
294
|
+
assert self.ctx.current_mono_args is not None
|
|
295
|
+
match self.ctx.current_mono_args[node.param.idx]:
|
|
296
|
+
case ConstArg(const=ConstValue(value=v)):
|
|
297
|
+
val = python_value_to_hugr(v, ty, self.ctx)
|
|
298
|
+
assert val is not None
|
|
299
|
+
return self.builder.load(val)
|
|
300
|
+
case _:
|
|
301
|
+
raise InternalGuppyError("Monomorphized const is not a value")
|
|
302
|
+
|
|
303
|
+
def visit_Name(self, node: ast.Name) -> Wire:
|
|
304
|
+
raise InternalGuppyError("Node should have been removed during type checking.")
|
|
305
|
+
|
|
306
|
+
def visit_Tuple(self, node: ast.Tuple) -> Wire:
|
|
307
|
+
elems = [self.visit(e) for e in node.elts]
|
|
308
|
+
types = [get_type(e) for e in node.elts]
|
|
309
|
+
return self._pack_tuple(elems, types)
|
|
310
|
+
|
|
311
|
+
def visit_List(self, node: ast.List) -> Wire:
|
|
312
|
+
# Note that this is a list literal (i.e. `[e1, e2, ...]`), not a comprehension
|
|
313
|
+
inputs = [self.visit(e) for e in node.elts]
|
|
314
|
+
list_ty = get_type(node)
|
|
315
|
+
elem_ty = get_element_type(list_ty)
|
|
316
|
+
return list_new(self.builder, elem_ty.to_hugr(self.ctx), inputs)
|
|
317
|
+
|
|
318
|
+
def _unpack_tuple(self, wire: Wire, types: Sequence[Type]) -> Sequence[Wire]:
|
|
319
|
+
"""Add a tuple unpack operation to the graph"""
|
|
320
|
+
types = [t.to_hugr(self.ctx) for t in types]
|
|
321
|
+
return list(self.builder.add_op(ops.UnpackTuple(types), wire))
|
|
322
|
+
|
|
323
|
+
def _pack_tuple(self, wires: Sequence[Wire], types: Sequence[Type]) -> Wire:
|
|
324
|
+
"""Add a tuple pack operation to the graph"""
|
|
325
|
+
types = [t.to_hugr(self.ctx) for t in types]
|
|
326
|
+
return self.builder.add_op(ops.MakeTuple(types), *wires)
|
|
327
|
+
|
|
328
|
+
def _pack_returns(self, returns: Sequence[Wire], return_ty: Type) -> Wire:
|
|
329
|
+
"""Groups function return values into a tuple"""
|
|
330
|
+
if isinstance(return_ty, TupleType | NoneType) and not return_ty.preserve:
|
|
331
|
+
types = type_to_row(return_ty)
|
|
332
|
+
assert len(returns) == len(types)
|
|
333
|
+
return self._pack_tuple(returns, types)
|
|
334
|
+
assert (
|
|
335
|
+
len(returns) == 1
|
|
336
|
+
), f"Expected a single return value. Got {returns}. return type {return_ty}"
|
|
337
|
+
return returns[0]
|
|
338
|
+
|
|
339
|
+
def _update_inout_ports(
|
|
340
|
+
self,
|
|
341
|
+
args: list[ast.expr],
|
|
342
|
+
inout_ports: Iterator[Wire],
|
|
343
|
+
func_ty: FunctionType,
|
|
344
|
+
) -> None:
|
|
345
|
+
"""Helper method that updates the ports for borrowed arguments after a call."""
|
|
346
|
+
for inp, arg in zip(func_ty.inputs, args, strict=True):
|
|
347
|
+
if InputFlags.Inout in inp.flags:
|
|
348
|
+
# Linearity checker ensures that borrowed arguments that are not places
|
|
349
|
+
# can be safely dropped after the call returns
|
|
350
|
+
if not isinstance(arg, PlaceNode):
|
|
351
|
+
next(inout_ports)
|
|
352
|
+
continue
|
|
353
|
+
self.dfg[arg.place] = next(inout_ports)
|
|
354
|
+
# Places involving subscripts need to generate code for the appropriate
|
|
355
|
+
# `__setitem__` call. Nested subscripts are handled automatically since
|
|
356
|
+
# `arg.place.parent` occurs as an arg of this call, so will also
|
|
357
|
+
# be recursively reassigned.
|
|
358
|
+
if subscript := contains_subscript(arg.place):
|
|
359
|
+
assert subscript.setitem_call is not None
|
|
360
|
+
# Need to assign __setitem__ value before compiling call.
|
|
361
|
+
# Note that the assignment to `self.dfg[arg.place]` also updated
|
|
362
|
+
# `self.dfg[subscript]` so that it now contains the value we want
|
|
363
|
+
# to write back into the subscript.
|
|
364
|
+
self.dfg[subscript.setitem_call.value_var] = self.dfg[subscript]
|
|
365
|
+
self.visit(subscript.setitem_call.call)
|
|
366
|
+
assert next(inout_ports, None) is None, "Too many inout return ports"
|
|
367
|
+
|
|
368
|
+
def visit_LocalCall(self, node: LocalCall) -> Wire:
|
|
369
|
+
func = self.visit(node.func)
|
|
370
|
+
func_ty = get_type(node.func)
|
|
371
|
+
assert isinstance(func_ty, FunctionType)
|
|
372
|
+
num_returns = len(type_to_row(func_ty.output))
|
|
373
|
+
|
|
374
|
+
args = self._compile_call_args(node.args, func_ty)
|
|
375
|
+
call = self.builder.add_op(
|
|
376
|
+
ops.CallIndirect(func_ty.to_hugr(self.ctx)), func, *args
|
|
377
|
+
)
|
|
378
|
+
regular_returns = list(call[:num_returns])
|
|
379
|
+
inout_returns = call[num_returns:]
|
|
380
|
+
self._update_inout_ports(node.args, inout_returns, func_ty)
|
|
381
|
+
return self._pack_returns(regular_returns, func_ty.output)
|
|
382
|
+
|
|
383
|
+
def visit_TensorCall(self, node: TensorCall) -> Wire:
|
|
384
|
+
functions: Wire = self.visit(node.func)
|
|
385
|
+
function_types = get_type(node.func)
|
|
386
|
+
assert isinstance(function_types, TupleType)
|
|
387
|
+
|
|
388
|
+
rets: list[Wire] = []
|
|
389
|
+
remaining_args = node.args
|
|
390
|
+
for func, func_ty in zip(
|
|
391
|
+
self._unpack_tuple(functions, function_types.element_types),
|
|
392
|
+
function_types.element_types,
|
|
393
|
+
strict=True,
|
|
394
|
+
):
|
|
395
|
+
outs, remaining_args = self._compile_tensor_with_leftovers(
|
|
396
|
+
func, func_ty, remaining_args
|
|
397
|
+
)
|
|
398
|
+
rets.extend(outs)
|
|
399
|
+
assert (
|
|
400
|
+
remaining_args == []
|
|
401
|
+
), "Not all function arguments were consumed after a tensor call"
|
|
402
|
+
return self._pack_returns(rets, node.tensor_ty.output)
|
|
403
|
+
|
|
404
|
+
def _compile_tensor_with_leftovers(
|
|
405
|
+
self, func: Wire, func_ty: Type, args: list[ast.expr]
|
|
406
|
+
) -> tuple[
|
|
407
|
+
list[Wire], # Compiled outputs
|
|
408
|
+
list[ast.expr], # Leftover args
|
|
409
|
+
]:
|
|
410
|
+
"""Compiles a function call, consuming as many arguments as needed, and
|
|
411
|
+
returning the unused ones.
|
|
412
|
+
"""
|
|
413
|
+
if isinstance(func_ty, TupleType):
|
|
414
|
+
remaining_args = args
|
|
415
|
+
all_outs = []
|
|
416
|
+
for elem, ty in zip(
|
|
417
|
+
self._unpack_tuple(func, func_ty.element_types),
|
|
418
|
+
func_ty.element_types,
|
|
419
|
+
strict=True,
|
|
420
|
+
):
|
|
421
|
+
outs, remaining_args = self._compile_tensor_with_leftovers(
|
|
422
|
+
elem, ty, remaining_args
|
|
423
|
+
)
|
|
424
|
+
all_outs.extend(outs)
|
|
425
|
+
return all_outs, remaining_args
|
|
426
|
+
|
|
427
|
+
elif isinstance(func_ty, FunctionType):
|
|
428
|
+
input_len = len(func_ty.inputs)
|
|
429
|
+
num_returns = len(type_to_row(func_ty.output))
|
|
430
|
+
consumed_args, other_args = args[0:input_len], args[input_len:]
|
|
431
|
+
consumed_wires = self._compile_call_args(consumed_args, func_ty)
|
|
432
|
+
call = self.builder.add_op(
|
|
433
|
+
ops.CallIndirect(func_ty.to_hugr(self.ctx)), func, *consumed_wires
|
|
434
|
+
)
|
|
435
|
+
regular_returns: list[Wire] = list(call[:num_returns])
|
|
436
|
+
inout_returns = call[num_returns:]
|
|
437
|
+
self._update_inout_ports(consumed_args, inout_returns, func_ty)
|
|
438
|
+
return regular_returns, other_args
|
|
439
|
+
else:
|
|
440
|
+
raise InternalGuppyError("Tensor element wasn't function or tuple")
|
|
441
|
+
|
|
442
|
+
def visit_GlobalCall(self, node: GlobalCall) -> Wire:
|
|
443
|
+
func, rem_args = self.ctx.build_compiled_def(node.def_id, node.type_args)
|
|
444
|
+
assert isinstance(func, CompiledCallableDef)
|
|
445
|
+
|
|
446
|
+
if isinstance(func, CustomFunctionDef) and not func.has_signature:
|
|
447
|
+
func_ty = FunctionType(
|
|
448
|
+
[FuncInput(get_type(arg), InputFlags.NoFlags) for arg in node.args],
|
|
449
|
+
get_type(node),
|
|
450
|
+
)
|
|
451
|
+
else:
|
|
452
|
+
func_ty = func.ty.instantiate(rem_args)
|
|
453
|
+
|
|
454
|
+
args = self._compile_call_args(node.args, func_ty)
|
|
455
|
+
rets = func.compile_call(args, rem_args, self.dfg, self.ctx, node)
|
|
456
|
+
self._update_inout_ports(node.args, iter(rets.inout_returns), func_ty)
|
|
457
|
+
return self._pack_returns(rets.regular_returns, func_ty.output)
|
|
458
|
+
|
|
459
|
+
def _compile_call_args(
|
|
460
|
+
self, args: list[ast.expr], func_ty: FunctionType
|
|
461
|
+
) -> list[Wire]:
|
|
462
|
+
"""Helper function to compile arguments for function calls.
|
|
463
|
+
|
|
464
|
+
Takes care of filtering out comptime arguments that are provided via generic
|
|
465
|
+
args or monomorphization instead of wires.
|
|
466
|
+
"""
|
|
467
|
+
return [
|
|
468
|
+
self.visit(arg)
|
|
469
|
+
for arg, inp in zip(args, func_ty.inputs, strict=True)
|
|
470
|
+
# Don't compile comptime args since we already include them as type args
|
|
471
|
+
if InputFlags.Comptime not in inp.flags
|
|
472
|
+
]
|
|
473
|
+
|
|
474
|
+
def visit_Call(self, node: ast.Call) -> Wire:
|
|
475
|
+
raise InternalGuppyError("Node should have been removed during type checking.")
|
|
476
|
+
|
|
477
|
+
def visit_PartialApply(self, node: PartialApply) -> Wire:
|
|
478
|
+
func_ty = get_type(node.func)
|
|
479
|
+
assert isinstance(func_ty, FunctionType)
|
|
480
|
+
op = PartialOp.from_closure(
|
|
481
|
+
func_ty.to_hugr(self.ctx),
|
|
482
|
+
[get_type(arg).to_hugr(self.ctx) for arg in node.args],
|
|
483
|
+
)
|
|
484
|
+
return self.builder.add_op(
|
|
485
|
+
op, self.visit(node.func), *(self.visit(arg) for arg in node.args)
|
|
486
|
+
)
|
|
487
|
+
|
|
488
|
+
def visit_TypeApply(self, node: TypeApply) -> Wire:
|
|
489
|
+
# For now, we can only TypeApply global FunctionDefs/Decls.
|
|
490
|
+
if not isinstance(node.value, GlobalName):
|
|
491
|
+
raise InternalGuppyError("Dynamic TypeApply not supported yet!")
|
|
492
|
+
defn, rem_args = self.ctx.build_compiled_def(node.value.def_id, node.inst)
|
|
493
|
+
assert isinstance(defn, CompiledCallableDef)
|
|
494
|
+
|
|
495
|
+
# We have to be very careful here: If we instantiate `foo: forall T. T -> T`
|
|
496
|
+
# with a tuple type `tuple[A, B]`, we get the type `tuple[A, B] -> tuple[A, B]`.
|
|
497
|
+
# Normally, this would be represented in Hugr as a function with two output
|
|
498
|
+
# ports types A and B. However, when TypeApplying `foo`, we actually get a
|
|
499
|
+
# function with a single output port typed `tuple[A, B]`.
|
|
500
|
+
# TODO: We would need to do manual monomorphisation in that case to obtain a
|
|
501
|
+
# function that returns two ports as expected
|
|
502
|
+
if instantiation_needs_unpacking(defn.ty, node.inst):
|
|
503
|
+
err = UnsupportedError(
|
|
504
|
+
node, "Generic function instantiations returning rows"
|
|
505
|
+
)
|
|
506
|
+
raise GuppyError(err)
|
|
507
|
+
|
|
508
|
+
return defn.load_with_args(rem_args, self.dfg, self.ctx, node)
|
|
509
|
+
|
|
510
|
+
def visit_UnaryOp(self, node: ast.UnaryOp) -> Wire:
|
|
511
|
+
# The only case that is not desugared by the type checker is the `not` operation
|
|
512
|
+
# since it is not implemented via a dunder method
|
|
513
|
+
if isinstance(node.op, ast.Not):
|
|
514
|
+
arg = self.visit(node.operand)
|
|
515
|
+
return self.builder.add_op(not_op(), arg)
|
|
516
|
+
|
|
517
|
+
raise InternalGuppyError("Node should have been removed during type checking.")
|
|
518
|
+
|
|
519
|
+
def visit_FieldAccessAndDrop(self, node: FieldAccessAndDrop) -> Wire:
|
|
520
|
+
struct_port = self.visit(node.value)
|
|
521
|
+
field_idx = node.struct_ty.fields.index(node.field)
|
|
522
|
+
return self._unpack_tuple(struct_port, [f.ty for f in node.struct_ty.fields])[
|
|
523
|
+
field_idx
|
|
524
|
+
]
|
|
525
|
+
|
|
526
|
+
def visit_SubscriptAccessAndDrop(self, node: SubscriptAccessAndDrop) -> Wire:
|
|
527
|
+
self.dfg[node.item] = self.visit(node.item_expr)
|
|
528
|
+
return self.visit(node.getitem_expr)
|
|
529
|
+
|
|
530
|
+
def visit_TupleAccessAndDrop(self, node: TupleAccessAndDrop) -> Wire:
|
|
531
|
+
tuple_port = self.visit(node.value)
|
|
532
|
+
return self._unpack_tuple(tuple_port, node.tuple_ty.element_types)[node.index]
|
|
533
|
+
|
|
534
|
+
def visit_ResultExpr(self, node: ResultExpr) -> Wire:
|
|
535
|
+
value_wire = self.visit(node.value)
|
|
536
|
+
base_ty = node.base_ty.to_hugr(self.ctx)
|
|
537
|
+
extra_args: list[ht.TypeArg] = []
|
|
538
|
+
if isinstance(node.base_ty, NumericType):
|
|
539
|
+
match node.base_ty.kind:
|
|
540
|
+
case NumericType.Kind.Nat:
|
|
541
|
+
base_name = "uint"
|
|
542
|
+
extra_args = [ht.BoundedNatArg(n=NumericType.INT_WIDTH)]
|
|
543
|
+
case NumericType.Kind.Int:
|
|
544
|
+
base_name = "int"
|
|
545
|
+
extra_args = [ht.BoundedNatArg(n=NumericType.INT_WIDTH)]
|
|
546
|
+
case NumericType.Kind.Float:
|
|
547
|
+
base_name = "f64"
|
|
548
|
+
case kind:
|
|
549
|
+
assert_never(kind)
|
|
550
|
+
else:
|
|
551
|
+
# The only other valid base type is bool
|
|
552
|
+
assert is_bool_type(node.base_ty)
|
|
553
|
+
base_name = "bool"
|
|
554
|
+
if node.array_len is not None:
|
|
555
|
+
op_name = f"result_array_{base_name}"
|
|
556
|
+
size_arg = node.array_len.to_arg().to_hugr(self.ctx)
|
|
557
|
+
extra_args = [size_arg, *extra_args]
|
|
558
|
+
# Remove the option wrapping in the array
|
|
559
|
+
unwrap = array_unwrap_elem(self.ctx)
|
|
560
|
+
unwrap = self.builder.load_function(
|
|
561
|
+
unwrap,
|
|
562
|
+
instantiation=ht.FunctionType([ht.Option(base_ty)], [base_ty]),
|
|
563
|
+
type_args=[ht.TypeTypeArg(base_ty)],
|
|
564
|
+
)
|
|
565
|
+
map_op = array_map(ht.Option(base_ty), size_arg, base_ty)
|
|
566
|
+
value_wire = self.builder.add_op(map_op, value_wire, unwrap)
|
|
567
|
+
if is_bool_type(node.base_ty):
|
|
568
|
+
# We need to coerce a read on all the array elements if they are bools.
|
|
569
|
+
array_read = array_read_bool(self.ctx)
|
|
570
|
+
array_read = self.builder.load_function(array_read)
|
|
571
|
+
map_op = array_map(OpaqueBool, size_arg, ht.Bool)
|
|
572
|
+
value_wire = self.builder.add_op(map_op, value_wire, array_read)
|
|
573
|
+
base_ty = ht.Bool
|
|
574
|
+
# Turn `value_array` into regular linear `array`
|
|
575
|
+
value_wire = self.builder.add_op(
|
|
576
|
+
array_convert_to_std_array(base_ty, size_arg), value_wire
|
|
577
|
+
)
|
|
578
|
+
hugr_ty: ht.Type = hugr.std.collections.array.Array(base_ty, size_arg)
|
|
579
|
+
else:
|
|
580
|
+
if is_bool_type(node.base_ty):
|
|
581
|
+
base_ty = ht.Bool
|
|
582
|
+
value_wire = self.builder.add_op(read_bool(), value_wire)
|
|
583
|
+
op_name = f"result_{base_name}"
|
|
584
|
+
hugr_ty = base_ty
|
|
585
|
+
|
|
586
|
+
sig = ht.FunctionType(input=[hugr_ty], output=[])
|
|
587
|
+
args = [ht.StringArg(node.tag), *extra_args]
|
|
588
|
+
op = ops.ExtOp(RESULT_EXTENSION.get_op(op_name), signature=sig, args=args)
|
|
589
|
+
|
|
590
|
+
self.builder.add_op(op, value_wire)
|
|
591
|
+
return self._pack_returns([], NoneType())
|
|
592
|
+
|
|
593
|
+
def visit_PanicExpr(self, node: PanicExpr) -> Wire:
|
|
594
|
+
err = build_error(self.builder, node.signal, node.msg)
|
|
595
|
+
in_tys = [get_type(e).to_hugr(self.ctx) for e in node.values]
|
|
596
|
+
out_tys = [ty.to_hugr(self.ctx) for ty in type_to_row(get_type(node))]
|
|
597
|
+
args = [self.visit(e) for e in node.values]
|
|
598
|
+
match node.kind:
|
|
599
|
+
case ExitKind.Panic:
|
|
600
|
+
h_node = build_panic(self.builder, in_tys, out_tys, err, *args)
|
|
601
|
+
case ExitKind.ExitShot:
|
|
602
|
+
op = panic(in_tys, out_tys, ExitKind.ExitShot)
|
|
603
|
+
h_node = self.builder.add_op(op, err, *args)
|
|
604
|
+
return self._pack_returns(list(h_node.outputs()), get_type(node))
|
|
605
|
+
|
|
606
|
+
def visit_BarrierExpr(self, node: BarrierExpr) -> Wire:
    """Compiles a barrier over the given arguments into the prelude `Barrier` op.

    The op is instantiated with the Hugr types of all arguments and has the same
    input and output types (`FunctionType.endo`). Its outputs are written back to the
    argument places via `_update_inout_ports`; the expression itself returns `None`.
    """
    arg_tys = [get_type(arg).to_hugr(self.ctx) for arg in node.args]
    barrier_def = hugr.std.prelude.PRELUDE_EXTENSION.get_op("Barrier")
    ty_list_arg = ht.ListArg([ht.TypeTypeArg(t) for t in arg_tys])
    barrier_op = barrier_def.instantiate([ty_list_arg], ht.FunctionType.endo(arg_tys))

    arg_wires = [self.visit(arg) for arg in node.args]
    barrier_node = self.builder.add_op(barrier_op, *arg_wires)

    self._update_inout_ports(node.args, iter(barrier_node), node.func_ty)
    return self._pack_returns([], NoneType())
|
|
617
|
+
|
|
618
|
+
def visit_StateResultExpr(self, node: StateResultExpr) -> Wire:
    """Compiles a state result expression into the debug `StateResult` op.

    The qubits are taken from `node.args[1:]`:

    * If `node.array_len` is `None`, each of those args is a single qubit. They are
      packed into an array of `len(node.args) - 1` elements, converted to a standard
      Hugr array for the op, and unpacked again afterwards.
    * Otherwise `node.args[1]` is a qubit array of length `node.array_len`. It is
      converted to and from the standard Hugr array the op works on via
      `apply_array_op_with_conversions`.

    The resulting qubit wires are written back to the argument places, and the
    expression itself returns `None`.
    """
    num_qubits_arg: ht.TypeArg
    # Use an explicit `None` check (as `visit_ResultExpr` does) rather than relying
    # on the truthiness of `node.array_len`.
    if node.array_len is not None:
        num_qubits_arg = node.array_len.to_arg().to_hugr(self.ctx)
    else:
        num_qubits_arg = ht.BoundedNatArg(len(node.args) - 1)
    args = [ht.StringArg(node.tag), num_qubits_arg]
    sig = ht.FunctionType(
        [standard_array_type(ht.Qubit, num_qubits_arg)],
        [standard_array_type(ht.Qubit, num_qubits_arg)],
    )

    op = ops.ExtOp(DEBUG_EXTENSION.get_op("StateResult"), signature=sig, args=args)

    if node.array_len is None:
        # Individual qubits: pack them into a value array and turn that into the
        # standard array the op expects.
        num_qubits = len(node.args) - 1
        qubits_in = [self.visit(e) for e in node.args[1:]]
        qubit_arr_in = self.builder.add_op(
            array_new(ht.Qubit, num_qubits), *qubits_in
        )
        qubit_arr_in = self.builder.add_op(
            array_convert_to_std_array(ht.Qubit, num_qubits_arg), qubit_arr_in
        )

        qubit_arr_out = self.builder.add_op(op, qubit_arr_in)

        # Convert back to a value array and split it into individual qubit wires.
        qubit_arr_out = self.builder.add_op(
            array_convert_from_std_array(ht.Qubit, num_qubits_arg), qubit_arr_out
        )
        qubits_out = unpack_array(self.builder, qubit_arr_out)
    else:
        # A single qubit array: unwrap the option elements and convert to a
        # standard array for the op, then undo both conversions on the output.
        qubit_arr = self.visit(node.args[1])
        qubits_out = [
            apply_array_op_with_conversions(
                self.ctx, self.builder, op, ht.Qubit, num_qubits_arg, qubit_arr
            )
        ]

    self._update_inout_ports(node.args, iter(qubits_out), node.func_ty)
    return self._pack_returns([], NoneType())
|
|
661
|
+
|
|
662
|
+
def visit_DesugaredListComp(self, node: DesugaredListComp) -> Wire:
    """Compiles a list comprehension by repeatedly appending to an accumulator.

    A fresh temporary variable is bound to an empty list and threaded through the
    generator loops built by `_build_generators`. Every produced element is added to
    it with the list type's `append` method. Returns the wire holding the final list.
    """
    list_ty = get_type(node)
    assert isinstance(list_ty, OpaqueType)
    elem_ty = get_element_type(list_ty)
    acc = Variable(next(tmp_vars), list_ty, node)
    empty_list = list_new(self.builder, elem_ty.to_hugr(self.ctx), [])
    self.dfg[acc] = empty_list
    with self._build_generators(node.generators, [acc]):
        new_elem = self.visit(node.elt)
        current_list = self.dfg[acc]
        # `append` has no regular returns and one inout output: the updated list.
        [], [self.dfg[acc]] = self._build_method_call(
            list_ty, "append", node, [current_list, new_elem], list_ty.args
        )
    return self.dfg[acc]
|
|
676
|
+
|
|
677
|
+
def visit_DesugaredArrayComp(self, node: DesugaredArrayComp) -> Wire:
    """Compiles an array comprehension into a loop filling a pre-allocated array.

    The array is created with `node.length` slots that all hold `None`. Elements are
    stored as options; see https://github.com/CQCL/guppylang/issues/629. Inside the
    generator loop, each element is stored via `__setitem__` at the index held by a
    counter variable, which is then incremented. Returns the wire of the final array.
    """
    array_ty = get_type(node)
    assert isinstance(array_ty, OpaqueType)
    array_var = Variable(next(tmp_vars), array_ty, node)
    count_var = Variable(next(tmp_vars), int_type(), node)
    # Convert the element type once: it is needed both to build the option type and
    # as the type argument of the init function.
    elt_hugr_ty = node.elt_ty.to_hugr(self.ctx)
    # See https://github.com/CQCL/guppylang/issues/629
    hugr_elt_ty = ht.Option(elt_hugr_ty)
    # Initialise array with `None`s
    make_none = self.builder.load_function(
        array_comprehension_init_func(self.ctx),
        instantiation=ht.FunctionType([], [hugr_elt_ty]),
        type_args=[ht.TypeTypeArg(elt_hugr_ty)],
    )
    self.dfg[array_var] = self.builder.add_op(
        array_repeat(hugr_elt_ty, node.length.to_arg().to_hugr(self.ctx)), make_none
    )
    self.dfg[count_var] = self.builder.load(
        hugr.std.int.IntVal(0, width=NumericType.INT_WIDTH)
    )
    with self._build_generators([node.generator], [array_var, count_var]):
        elt = self.visit(node.elt)
        array, count = self.dfg[array_var], self.dfg[count_var]
        [], [self.dfg[array_var]] = self._build_method_call(
            array_ty, "__setitem__", node, [array, count, elt], array_ty.args
        )
        # Update `count += 1`
        one = self.builder.load(hugr.std.int.IntVal(1, width=NumericType.INT_WIDTH))
        [self.dfg[count_var]], [] = self._build_method_call(
            int_type(), "__add__", node, [count, one], []
        )
    return self.dfg[array_var]
|
|
710
|
+
|
|
711
|
+
def _build_method_call(
    self, ty: Type, method: str, node: AstNode, args: list[Wire], type_args: Inst
) -> CallReturnWires:
    """Emits a call to the compiled instance method `method` of type `ty`.

    The method is looked up via `self.ctx.build_compiled_instance_func` with the given
    `type_args`. The lookup is asserted to succeed. The call is then compiled with the
    argument wires `args`, the type arguments left over by the lookup, and `node` as
    the call site. Returns whatever the function's `compile_call` returns.
    """
    found = self.ctx.build_compiled_instance_func(ty, method, type_args)
    assert found is not None
    compiled_func, leftover_inst = found
    return compiled_func.compile_call(args, leftover_inst, self.dfg, self.ctx, node)
|
|
718
|
+
|
|
719
|
+
@contextmanager
def _build_generators(
    self, gens: list[DesugaredGenerator], loop_vars: list[Variable]
) -> Iterator[None]:
    """Context manager that builds and enters one `TailLoop` per generator.

    On entry, the following have been built and entered, nested in the order of
    `gens`:

    * the loop for every generator,
    * a conditional for its "has next element" check, and
    * one conditional for each of its `if` guards.

    Code emitted inside the `with` block therefore lands in the innermost body. The
    places in `loop_vars` are threaded through all of the loops and are available
    inside them. Every nested context is exited when the `with` block ends.
    """
    from guppylang_internals.compiler.stmt_compiler import StmtCompiler

    stmt_compiler = StmtCompiler(self.ctx)
    with ExitStack() as nested:
        for generator in gens:
            # Compile the assignment that creates the iterator
            stmt_compiler.compile_stmts([generator.iter_assign], self.dfg)
            assert isinstance(generator.iter, PlaceNode)
            iter_ty = get_type(generator.iter)
            threaded = [PlaceNode(place=var) for var in loop_vars] + [
                PlaceNode(place=place) for place in generator.used_outer_places
            ]
            # Open a tail loop. The iterator is passed as a `just_input`, so it is
            # not among the loop's outputs.
            break_flag = PlaceNode(
                Variable(next(tmp_vars), bool_type(), generator.iter)
            )
            nested.enter_context(
                self._new_loop([generator.iter], threaded, break_flag)
            )
            # Branch on whether the iterator yields another element
            next_ty = TupleType([get_type(generator.target), iter_ty])
            next_place = PlaceNode(Variable(next(tmp_vars), next_ty, generator.iter))
            has_next, exhausted = self._if_else(
                generator.next_call,
                threaded,
                only_true_inputs=[next_place],
                outputs=[break_flag, *threaded],
            )
            # Hugr type of the break flag: tag 0 carries the iterator into the next
            # iteration (continue), tag 1 carries nothing (break).
            flag_hugr_ty = ht.Either([iter_ty.to_hugr(self.ctx)], [])
            # Iterator exhausted: set the break flag to true
            with exhausted:
                self.dfg[break_flag.place] = self.dfg.builder.add_op(
                    ops.Tag(1, flag_hugr_ty)
                )
            # Another element: stay in this branch, bind the element to the
            # target, and continue with the advanced iterator.
            nested.enter_context(has_next)
            next_wire = self.dfg[next_place.place]
            elt, it = self.dfg.builder.add_op(ops.UnpackTuple(), next_wire)
            stmt_compiler.dfg = self.dfg
            stmt_compiler._assign(generator.target, elt)
            self.dfg[break_flag.place] = self.dfg.builder.add_op(
                ops.Tag(0, flag_hugr_ty), it
            )
            # Each `if` guard opens another conditional around the body
            for guard in generator.ifs:
                nested.enter_context(self._if_true(guard, [break_flag, *threaded]))
        # Hand control to the caller, who builds the innermost body
        yield
|
|
773
|
+
|
|
774
|
+
def visit_BinOp(self, node: ast.BinOp) -> Wire:
    """Rejects binary operator nodes, which should not reach the compiler.

    Raises:
        InternalGuppyError: Always; the node should have been removed during type
            checking.
    """
    msg = "Node should have been removed during type checking."
    raise InternalGuppyError(msg)
|
|
776
|
+
|
|
777
|
+
def visit_Compare(self, node: ast.Compare) -> Wire:
    """Rejects comparison nodes, which should not reach the compiler.

    Raises:
        InternalGuppyError: Always; the node should have been removed during type
            checking.
    """
    msg = "Node should have been removed during type checking."
    raise InternalGuppyError(msg)
|
|
779
|
+
|
|
780
|
+
|
|
781
|
+
def expr_to_row(expr: ast.expr) -> list[ast.expr]:
    """Turns an expression into a row of expressions by unpacking top-level tuples.

    A top-level `ast.Tuple` yields its `elts` list. This is the node's own list, not
    a copy. Any other expression yields a singleton list containing just `expr`.
    """
    return expr.elts if isinstance(expr, ast.Tuple) else [expr]
|
|
784
|
+
|
|
785
|
+
|
|
786
|
+
def instantiation_needs_unpacking(func_ty: FunctionType, inst: Inst) -> bool:
    """Checks if instantiating a polymorphic function makes it return a row.

    This is the case exactly when the declared output of `func_ty` is a bound type
    variable and `inst` substitutes a `TupleType` or `NoneType` for it. Any other
    output type yields False.
    """
    if isinstance(func_ty.output, BoundTypeVar):
        return_ty = inst[func_ty.output.idx]
        return isinstance(return_ty, TupleType | NoneType)
    return False
|
|
792
|
+
|
|
793
|
+
|
|
794
|
+
def python_value_to_hugr(v: Any, exp_ty: Type, ctx: CompilerContext) -> hv.Value | None:
    """Converts a Python value into a Hugr constant value.

    Handled values:

    * `bool` values become opaque bools.
    * `str` values become strings.
    * `int` values become unsigned or signed ints depending on `exp_ty.kind`.
    * `float` values become floats.
    * `tuple` values are converted element-wise against `exp_ty.element_types`.
    * `list` values become a static array of the element type of `exp_ty`.

    Returns None if the Python value (or any value nested inside it) cannot be
    represented in Guppy.
    """
    # `bool` is a subclass of `int`, so it has to be tested first.
    if isinstance(v, bool):
        return OpaqueBoolVal(v)
    if isinstance(v, str):
        return hugr.std.prelude.StringVal(v)
    if isinstance(v, int):
        assert isinstance(exp_ty, NumericType)
        kind = exp_ty.kind
        if kind == NumericType.Kind.Nat:
            return UnsignedIntVal(v, width=NumericType.INT_WIDTH)
        if kind == NumericType.Kind.Int:
            return hugr.std.int.IntVal(v, width=NumericType.INT_WIDTH)
        raise InternalGuppyError("Unexpected numeric type")
    if isinstance(v, float):
        return hugr.std.float.FloatVal(v)
    if isinstance(v, tuple):
        assert isinstance(exp_ty, TupleType)
        field_vals = [
            python_value_to_hugr(field, field_ty, ctx)
            for field, field_ty in zip(v, exp_ty.element_types, strict=True)
        ]
        return hv.Tuple(*field_vals) if doesnt_contain_none(field_vals) else None
    if isinstance(v, list):
        assert is_frozenarray_type(exp_ty)
        elem_ty = get_element_type(exp_ty)
        item_vals = [python_value_to_hugr(item, elem_ty, ctx) for item in v]
        if not doesnt_contain_none(item_vals):
            return None
        return hugr.std.collections.static_array.StaticArrayVal(
            item_vals, elem_ty.to_hugr(ctx), name=f"static_pyarray.{next(tmp_vars)}"
        )
    return None
|
|
834
|
+
|
|
835
|
+
|
|
836
|
+
# Global ids under which the array helper functions below are declared via
# `ctx.declare_global_func`. NOTE(review): the creation order is kept as is, in case
# `GlobalConstId.fresh` hands out sequential ids.

# Option-array initialisation for comprehensions
ARRAY_COMPREHENSION_INIT: Final[GlobalConstId] = GlobalConstId.fresh(
    "array.__comprehension.init"
)

# Unwrapping / wrapping of option array elements
ARRAY_UNWRAP_ELEM: Final[GlobalConstId] = GlobalConstId.fresh("array.__unwrap_elem")
ARRAY_WRAP_ELEM: Final[GlobalConstId] = GlobalConstId.fresh("array.__wrap_elem")

# Conversion between opaque bools and Hugr bools
ARRAY_READ_BOOL: Final[GlobalConstId] = GlobalConstId.fresh("array.__read_bool")
ARRAY_MAKE_OPAQUE_BOOL: Final[GlobalConstId] = GlobalConstId.fresh(
    "array.__make_opaque_bool"
)
|
|
847
|
+
|
|
848
|
+
|
|
849
|
+
def array_comprehension_init_func(ctx: CompilerContext) -> hf.Function:
    """Returns the Hugr function that is used to initialise array elements before a
    comprehension.

    Just returns the `None` variant (tag 0) of the optional element type. The
    function is polymorphic over a linear element type `T` and has signature
    `() -> Option[T]`. If `ctx` has already declared it, the existing declaration is
    returned without adding a body again.

    See https://github.com/CQCL/guppylang/issues/629
    """
    v = ht.Variable(0, ht.TypeBound(ht.TypeBound.Linear))
    sig = ht.PolyFuncType(
        params=[ht.TypeTypeParam(ht.TypeBound.Linear)],
        body=ht.FunctionType([], [ht.Option(v)]),
    )
    func, already_defined = ctx.declare_global_func(ARRAY_COMPREHENSION_INIT, sig)
    if not already_defined:
        func.set_outputs(func.add_op(ops.Tag(0, ht.Option(v))))
    return func
|
|
866
|
+
|
|
867
|
+
|
|
868
|
+
def array_unwrap_elem(ctx: CompilerContext) -> hf.Function:
    """Returns the Hugr function that unwraps a single option array element.

    The function is polymorphic over a linear element type `T` and has signature
    `Option[T] -> T`. Mapping it over an option array turns it into a regular array.
    Its body is built with `build_unwrap` and the error message below. If `ctx` has
    already declared the function, the existing declaration is returned as is.
    """
    elem = ht.Variable(0, ht.TypeBound(ht.TypeBound.Linear))
    signature = ht.PolyFuncType(
        params=[ht.TypeTypeParam(ht.TypeBound.Linear)],
        body=ht.FunctionType([ht.Option(elem)], [elem]),
    )
    func, already_defined = ctx.declare_global_func(ARRAY_UNWRAP_ELEM, signature)
    if already_defined:
        return func
    err_msg = "Linear array element has already been used"
    unwrapped = build_unwrap(func, func.inputs()[0], err_msg)
    func.set_outputs(unwrapped)
    return func
|
|
881
|
+
|
|
882
|
+
|
|
883
|
+
def array_wrap_elem(ctx: CompilerContext) -> hf.Function:
    """Returns the Hugr function that is used to wrap the elements in a regular array
    to turn it into an option array.

    The function is polymorphic over a linear element type `T` and has signature
    `T -> Option[T]`. It tags its input as variant 1 of the option; variant 0 is the
    `None` variant. If `ctx` has already declared it, the existing declaration is
    returned without adding a body again.
    """
    v = ht.Variable(0, ht.TypeBound(ht.TypeBound.Linear))
    sig = ht.PolyFuncType(
        params=[ht.TypeTypeParam(ht.TypeBound.Linear)],
        body=ht.FunctionType([v], [ht.Option(v)]),
    )
    func, already_defined = ctx.declare_global_func(ARRAY_WRAP_ELEM, sig)
    if not already_defined:
        func.set_outputs(func.add_op(ops.Tag(1, ht.Option(v)), func.inputs()[0]))
    return func
|
|
895
|
+
|
|
896
|
+
|
|
897
|
+
def array_read_bool(ctx: CompilerContext) -> hf.Function:
    """Returns the Hugr function that converts an opaque bool into a Hugr `Bool`.

    The function has the monomorphic signature `OpaqueBool -> Bool` and applies the
    `read_bool` op to its input. Mapping it over an array turns an array of opaque
    bools into an array of Hugr bools. If `ctx` has already declared it, the existing
    declaration is returned without adding a body again.
    """
    sig = ht.PolyFuncType(
        params=[],
        body=ht.FunctionType([OpaqueBool], [ht.Bool]),
    )
    func, already_defined = ctx.declare_global_func(ARRAY_READ_BOOL, sig)
    if not already_defined:
        func.set_outputs(func.add_op(read_bool(), func.inputs()[0]))
    return func
|
|
908
|
+
|
|
909
|
+
|
|
910
|
+
def array_make_opaque_bool(ctx: CompilerContext) -> hf.Function:
    """Returns the Hugr function that converts a Hugr `Bool` into an opaque bool.

    The function has the monomorphic signature `Bool -> OpaqueBool` and applies the
    `make_opaque` op to its input. Mapping it over an array turns an array of Hugr
    bools into an array of opaque bools. If `ctx` has already declared it, the
    existing declaration is returned without adding a body again.
    """
    sig = ht.PolyFuncType(
        params=[],
        body=ht.FunctionType([ht.Bool], [OpaqueBool]),
    )
    func, already_defined = ctx.declare_global_func(ARRAY_MAKE_OPAQUE_BOOL, sig)
    if not already_defined:
        func.set_outputs(func.add_op(make_opaque(), func.inputs()[0]))
    return func
|
|
921
|
+
|
|
922
|
+
|
|
923
|
+
T = TypeVar("T")
|
|
924
|
+
|
|
925
|
+
|
|
926
|
+
def doesnt_contain_none(xs: list[T | None]) -> TypeGuard[list[T]]:
|
|
927
|
+
"""Checks if a list contains `None`."""
|
|
928
|
+
return all(x is not None for x in xs)
|
|
929
|
+
|
|
930
|
+
|
|
931
|
+
def apply_array_op_with_conversions(
    ctx: CompilerContext,
    builder: DfBase[ops.DfParentOp],
    op: ops.DataflowOp,
    elem_ty: ht.Type,
    size_arg: ht.TypeArg,
    input_array: Wire,
    convert_bool: bool = False,
) -> Wire:
    """Applies common transformations to a Guppy array input before it can be passed to
    a Hugr op operating on a standard Hugr array, and then reverses them again on the
    output array.

    Transformations, applied in this order on the input and in reverse order on the
    output:
        1. Unwraps / wraps elements in options.
        2. (Optional) Converts from / to opaque bool to / from Hugr bool.
        3. Converts from / to value array to / from standard Hugr array.

    Args:
        ctx: Compiler context used to declare the element helper functions.
        builder: Dataflow builder the conversion ops and `op` are added to.
        op: Op applied to the converted array; it is expected to take a single
            standard Hugr array of `size_arg` elements and return one of the same
            type.
        elem_ty: Hugr element type of the Guppy array, without the option wrapper.
            When `convert_bool` is set, this is expected to be `OpaqueBool`.
        size_arg: Type argument giving the length of the array.
        input_array: Wire carrying the Guppy array of option-wrapped elements.
        convert_bool: Whether to read the elements into Hugr `Bool`s before `op`
            and turn them back into opaque bools afterwards.

    Returns:
        Wire carrying the output of `op`, converted back into an array of
        option-wrapped elements of the original element type.
    """
    unwrap = array_unwrap_elem(ctx)
    unwrap = builder.load_function(
        unwrap,
        instantiation=ht.FunctionType([ht.Option(elem_ty)], [elem_ty]),
        type_args=[ht.TypeTypeArg(elem_ty)],
    )
    map_op = array_map(ht.Option(elem_ty), size_arg, elem_ty)
    unwrapped_array = builder.add_op(map_op, input_array, unwrap)

    if convert_bool:
        array_read = array_read_bool(ctx)
        array_read = builder.load_function(array_read)
        map_op = array_map(OpaqueBool, size_arg, ht.Bool)
        unwrapped_array = builder.add_op(map_op, unwrapped_array, array_read)
        elem_ty = ht.Bool

    unwrapped_array = builder.add_op(
        array_convert_to_std_array(elem_ty, size_arg), unwrapped_array
    )

    result_array = builder.add_op(op, unwrapped_array)

    result_array = builder.add_op(
        array_convert_from_std_array(elem_ty, size_arg), result_array
    )

    if convert_bool:
        array_make_opaque = array_make_opaque_bool(ctx)
        array_make_opaque = builder.load_function(array_make_opaque)
        map_op = array_map(ht.Bool, size_arg, OpaqueBool)
        result_array = builder.add_op(map_op, result_array, array_make_opaque)
        elem_ty = OpaqueBool

    wrap = array_wrap_elem(ctx)
    wrap = builder.load_function(
        wrap,
        instantiation=ht.FunctionType([elem_ty], [ht.Option(elem_ty)]),
        type_args=[ht.TypeTypeArg(elem_ty)],
    )
    map_op = array_map(elem_ty, size_arg, ht.Option(elem_ty))
    return builder.add_op(map_op, result_array, wrap)
|