guppylang-internals 0.21.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- guppylang_internals/__init__.py +3 -0
- guppylang_internals/ast_util.py +350 -0
- guppylang_internals/cfg/__init__.py +0 -0
- guppylang_internals/cfg/analysis.py +230 -0
- guppylang_internals/cfg/bb.py +221 -0
- guppylang_internals/cfg/builder.py +606 -0
- guppylang_internals/cfg/cfg.py +117 -0
- guppylang_internals/checker/__init__.py +0 -0
- guppylang_internals/checker/cfg_checker.py +388 -0
- guppylang_internals/checker/core.py +550 -0
- guppylang_internals/checker/errors/__init__.py +0 -0
- guppylang_internals/checker/errors/comptime_errors.py +106 -0
- guppylang_internals/checker/errors/generic.py +45 -0
- guppylang_internals/checker/errors/linearity.py +300 -0
- guppylang_internals/checker/errors/type_errors.py +344 -0
- guppylang_internals/checker/errors/wasm.py +34 -0
- guppylang_internals/checker/expr_checker.py +1413 -0
- guppylang_internals/checker/func_checker.py +269 -0
- guppylang_internals/checker/linearity_checker.py +821 -0
- guppylang_internals/checker/stmt_checker.py +447 -0
- guppylang_internals/compiler/__init__.py +0 -0
- guppylang_internals/compiler/cfg_compiler.py +233 -0
- guppylang_internals/compiler/core.py +613 -0
- guppylang_internals/compiler/expr_compiler.py +989 -0
- guppylang_internals/compiler/func_compiler.py +97 -0
- guppylang_internals/compiler/hugr_extension.py +224 -0
- guppylang_internals/compiler/qtm_platform_extension.py +0 -0
- guppylang_internals/compiler/stmt_compiler.py +212 -0
- guppylang_internals/decorator.py +246 -0
- guppylang_internals/definition/__init__.py +0 -0
- guppylang_internals/definition/common.py +214 -0
- guppylang_internals/definition/const.py +74 -0
- guppylang_internals/definition/custom.py +492 -0
- guppylang_internals/definition/declaration.py +171 -0
- guppylang_internals/definition/extern.py +89 -0
- guppylang_internals/definition/function.py +302 -0
- guppylang_internals/definition/overloaded.py +150 -0
- guppylang_internals/definition/parameter.py +82 -0
- guppylang_internals/definition/pytket_circuits.py +405 -0
- guppylang_internals/definition/struct.py +392 -0
- guppylang_internals/definition/traced.py +151 -0
- guppylang_internals/definition/ty.py +51 -0
- guppylang_internals/definition/value.py +115 -0
- guppylang_internals/definition/wasm.py +61 -0
- guppylang_internals/diagnostic.py +523 -0
- guppylang_internals/dummy_decorator.py +76 -0
- guppylang_internals/engine.py +295 -0
- guppylang_internals/error.py +107 -0
- guppylang_internals/experimental.py +92 -0
- guppylang_internals/ipython_inspect.py +28 -0
- guppylang_internals/nodes.py +427 -0
- guppylang_internals/py.typed +0 -0
- guppylang_internals/span.py +150 -0
- guppylang_internals/std/__init__.py +0 -0
- guppylang_internals/std/_internal/__init__.py +0 -0
- guppylang_internals/std/_internal/checker.py +573 -0
- guppylang_internals/std/_internal/compiler/__init__.py +0 -0
- guppylang_internals/std/_internal/compiler/arithmetic.py +136 -0
- guppylang_internals/std/_internal/compiler/array.py +569 -0
- guppylang_internals/std/_internal/compiler/either.py +131 -0
- guppylang_internals/std/_internal/compiler/frozenarray.py +68 -0
- guppylang_internals/std/_internal/compiler/futures.py +30 -0
- guppylang_internals/std/_internal/compiler/list.py +348 -0
- guppylang_internals/std/_internal/compiler/mem.py +13 -0
- guppylang_internals/std/_internal/compiler/option.py +78 -0
- guppylang_internals/std/_internal/compiler/prelude.py +271 -0
- guppylang_internals/std/_internal/compiler/qsystem.py +48 -0
- guppylang_internals/std/_internal/compiler/quantum.py +118 -0
- guppylang_internals/std/_internal/compiler/tket_bool.py +55 -0
- guppylang_internals/std/_internal/compiler/tket_exts.py +59 -0
- guppylang_internals/std/_internal/compiler/wasm.py +135 -0
- guppylang_internals/std/_internal/compiler.py +0 -0
- guppylang_internals/std/_internal/debug.py +95 -0
- guppylang_internals/std/_internal/util.py +271 -0
- guppylang_internals/tracing/__init__.py +0 -0
- guppylang_internals/tracing/builtins_mock.py +62 -0
- guppylang_internals/tracing/frozenlist.py +57 -0
- guppylang_internals/tracing/function.py +186 -0
- guppylang_internals/tracing/object.py +551 -0
- guppylang_internals/tracing/state.py +69 -0
- guppylang_internals/tracing/unpacking.py +194 -0
- guppylang_internals/tracing/util.py +86 -0
- guppylang_internals/tys/__init__.py +0 -0
- guppylang_internals/tys/arg.py +115 -0
- guppylang_internals/tys/builtin.py +382 -0
- guppylang_internals/tys/common.py +110 -0
- guppylang_internals/tys/const.py +114 -0
- guppylang_internals/tys/errors.py +178 -0
- guppylang_internals/tys/param.py +251 -0
- guppylang_internals/tys/parsing.py +425 -0
- guppylang_internals/tys/printing.py +174 -0
- guppylang_internals/tys/subst.py +112 -0
- guppylang_internals/tys/ty.py +876 -0
- guppylang_internals/tys/var.py +49 -0
- guppylang_internals-0.21.0.dist-info/METADATA +253 -0
- guppylang_internals-0.21.0.dist-info/RECORD +98 -0
- guppylang_internals-0.21.0.dist-info/WHEEL +4 -0
- guppylang_internals-0.21.0.dist-info/licenses/LICENCE +201 -0
|
@@ -0,0 +1,97 @@
|
|
|
1
|
+
from typing import TYPE_CHECKING
|
|
2
|
+
|
|
3
|
+
from hugr import Wire
|
|
4
|
+
from hugr import tys as ht
|
|
5
|
+
from hugr.build.function import Function
|
|
6
|
+
|
|
7
|
+
from guppylang_internals.compiler.cfg_compiler import compile_cfg
|
|
8
|
+
from guppylang_internals.compiler.core import CompilerContext, DFContainer
|
|
9
|
+
from guppylang_internals.compiler.hugr_extension import PartialOp
|
|
10
|
+
from guppylang_internals.nodes import CheckedNestedFunctionDef
|
|
11
|
+
|
|
12
|
+
if TYPE_CHECKING:
|
|
13
|
+
from guppylang_internals.definition.function import CheckedFunctionDef
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def compile_global_func_def(
    func: "CheckedFunctionDef",
    builder: Function,
    ctx: CompilerContext,
) -> None:
    """Compile the body of a top-level function definition into ``builder``.

    The function's CFG is compiled with the builder's input wires as its arguments,
    and the wires produced by the CFG are set as the function's outputs.
    """
    input_wires = builder.inputs()
    output_wires = compile_cfg(func.cfg, builder, input_wires, ctx)
    builder.set_outputs(*output_wires)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def compile_local_func_def(
    func: CheckedNestedFunctionDef,
    dfg: DFContainer,
    ctx: CompilerContext,
) -> Wire:
    """Compile a nested function definition to Hugr and load it as a value.

    The nested function is defined at module level with its captured variables
    prepended to its parameters. The returned wire carries the loaded function.
    If the function captures variables, that value is partially applied to the
    current values of those variables in ``dfg``.
    """
    assert func.ty.input_names is not None

    # Fix an ordering of the captured variables; they become the leading parameters.
    captured = list(func.captured.values())
    captured_tys = [var.ty.to_hugr(ctx) for var, _ in captured]

    # The function refers to itself if its own name is live on entry to its body.
    is_recursive = func.name in func.cfg.live_before[func.cfg.entry_bb]

    # Build the closure signature: captured types first, then the declared inputs.
    sig = func.ty.to_hugr(ctx)
    closure_ty = ht.FunctionType([*captured_tys, *sig.input], sig.output)
    func_builder = dfg.builder.module_root_builder().define_function(
        func.name, closure_ty.input, closure_ty.output
    )

    # Nested functions are never generic, so the only instantiation is the empty one.
    mono_args = ()

    body_args: list[Wire] = list(func_builder.inputs())
    if captured and is_recursive:
        # Recursive occurrences inside the body must see the function already
        # partially applied to the captured values. Build that partial application
        # from the function's own leading inputs and pass it as an extra argument.
        self_ref = func_builder.load_function(func_builder, closure_ty)
        self_partial = func_builder.add_op(
            PartialOp.from_closure(closure_ty, captured_tys),
            self_ref,
            *func_builder.input_node[: len(captured)],
        )
        body_args.append(self_partial)
        func.cfg.input_tys.append(func.ty)

        # The body is compiled right away in this case.
        outputs = compile_cfg(func.cfg, func_builder, body_args, ctx)
        func_builder.set_outputs(*outputs)
    else:
        # No self-reference through captures: register the function like a global
        # definition and let the worklist compile its CFG later.
        from guppylang_internals.definition.function import CompiledFunctionDef

        key = (func.def_id, mono_args)
        ctx.compiled[key] = CompiledFunctionDef(
            func.def_id,
            func.name,
            func,
            mono_args,
            func.ty,
            None,
            func.cfg,
            func_builder,
        )
        ctx.worklist[key] = None

    # Load the function into the enclosing dataflow graph, binding captured values.
    closure = dfg.builder.load_function(func_builder, closure_ty)
    if captured:
        closure = dfg.builder.add_op(
            PartialOp.from_closure(closure_ty, captured_tys),
            closure,
            *(dfg[var] for var, _ in captured),
        )
    return closure
|
|
@@ -0,0 +1,224 @@
|
|
|
1
|
+
"""A hugr extension with guppy-specific operations."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
from typing import TYPE_CHECKING
|
|
7
|
+
|
|
8
|
+
import hugr.ext as he
|
|
9
|
+
import hugr.tys as ht
|
|
10
|
+
from hugr import ops
|
|
11
|
+
|
|
12
|
+
if TYPE_CHECKING:
|
|
13
|
+
from collections.abc import Iterator, Sequence
|
|
14
|
+
|
|
15
|
+
EXTENSION: he.Extension = he.Extension("guppylang", he.Version(0, 1, 0))


# Partial application op. Its three type parameters are, in order:
#   [0] the captured input types `*c`,
#   [1] the remaining (non-captured) input types `*a`,
#   [2] the output types `*b`,
# giving the signature `(*c, *a -> *b), *c -> (*a -> *b)`.
PARTIAL_OP_DEF: he.OpDef = EXTENSION.add_op_def(
    he.OpDef(
        "partial",
        signature=he.OpDefSig(
            poly_func=ht.PolyFuncType(
                params=[
                    # Captured input types `*c`
                    ht.ListParam(ht.TypeTypeParam(ht.TypeBound.Linear)),
                    # Non-captured input types `*a`
                    ht.ListParam(ht.TypeTypeParam(ht.TypeBound.Linear)),
                    # Output types `*b`
                    ht.ListParam(ht.TypeTypeParam(ht.TypeBound.Linear)),
                ],
                body=ht.FunctionType(
                    input=[
                        ht.FunctionType(
                            input=[
                                ht.RowVariable(0, ht.TypeBound.Linear),
                                ht.RowVariable(1, ht.TypeBound.Linear),
                            ],
                            output=[ht.RowVariable(2, ht.TypeBound.Linear)],
                        ),
                        ht.RowVariable(0, ht.TypeBound.Linear),
                    ],
                    output=[
                        ht.FunctionType(
                            input=[ht.RowVariable(1, ht.TypeBound.Linear)],
                            output=[ht.RowVariable(2, ht.TypeBound.Linear)],
                        ),
                    ],
                ),
            )
        ),
        # The argument names below match the parameter order declared above
        # (captured, non-captured, outputs); previously they were listed as
        # [*a],[*b],[*c], which contradicted the quoted type.
        description="A partial application of a function."
        " Given arguments [*c],[*a],[*b], represents an operation with type"
        " `(*c, *a -> *b), *c -> (*a -> *b)`",
    )
)
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
@dataclass
|
|
59
|
+
class PartialOp(ops.AsExtOp):
|
|
60
|
+
"""An operation that partially evaluates a function.
|
|
61
|
+
|
|
62
|
+
args:
|
|
63
|
+
captured_inputs: A list of input types `c_0, ..., c_k` to partially apply.
|
|
64
|
+
other_inputs: A list of input types `a_0, ..., a_n` not partially applied.
|
|
65
|
+
outputs: The output types `b_0, ..., b_m` of the partially applied function.
|
|
66
|
+
|
|
67
|
+
returns:
|
|
68
|
+
An operation with type
|
|
69
|
+
` (c_0, ..., c_k, a_0, ..., a_n -> b_0, ..., b_m ), c_0, ..., c_k`
|
|
70
|
+
`-> (a_0, ..., a_n -> b_0, ..., b_m)`
|
|
71
|
+
"""
|
|
72
|
+
|
|
73
|
+
captured_inputs: list[ht.Type]
|
|
74
|
+
other_inputs: list[ht.Type]
|
|
75
|
+
outputs: list[ht.Type]
|
|
76
|
+
|
|
77
|
+
@classmethod
|
|
78
|
+
def from_closure(
|
|
79
|
+
cls, closure_ty: ht.FunctionType, captured_tys: Sequence[ht.Type]
|
|
80
|
+
) -> PartialOp:
|
|
81
|
+
"""An operation that partially evaluates a function.
|
|
82
|
+
|
|
83
|
+
args:
|
|
84
|
+
closure_ty: A function `(c_0, ..., c_k, a_0, ..., a_n) -> b_0, ..., b_m`
|
|
85
|
+
captured_tys: A list `c_0, ..., c_k` of types captured by the function
|
|
86
|
+
|
|
87
|
+
returns:
|
|
88
|
+
An operation with type
|
|
89
|
+
` (c_0, ..., c_k, a_0, ..., a_n -> b_0, ..., b_m ), c_0, ..., c_k`
|
|
90
|
+
`-> (a_0, ..., a_n -> b_0, ..., b_m)`
|
|
91
|
+
"""
|
|
92
|
+
assert len(closure_ty.input) >= len(captured_tys)
|
|
93
|
+
assert captured_tys == closure_ty.input[: len(captured_tys)]
|
|
94
|
+
|
|
95
|
+
other_inputs = closure_ty.input[len(captured_tys) :]
|
|
96
|
+
return cls(
|
|
97
|
+
captured_inputs=list(captured_tys),
|
|
98
|
+
other_inputs=list(other_inputs),
|
|
99
|
+
outputs=list(closure_ty.output),
|
|
100
|
+
)
|
|
101
|
+
|
|
102
|
+
def op_def(self) -> he.OpDef:
|
|
103
|
+
return PARTIAL_OP_DEF
|
|
104
|
+
|
|
105
|
+
def type_args(self) -> list[ht.TypeArg]:
|
|
106
|
+
captured_args: list[ht.TypeArg] = [
|
|
107
|
+
ht.TypeTypeArg(ty) for ty in self.captured_inputs
|
|
108
|
+
]
|
|
109
|
+
other_args: list[ht.TypeArg] = [ht.TypeTypeArg(ty) for ty in self.other_inputs]
|
|
110
|
+
output_args: list[ht.TypeArg] = [ht.TypeTypeArg(ty) for ty in self.outputs]
|
|
111
|
+
return [
|
|
112
|
+
ht.ListArg(captured_args),
|
|
113
|
+
ht.ListArg(other_args),
|
|
114
|
+
ht.ListArg(output_args),
|
|
115
|
+
]
|
|
116
|
+
|
|
117
|
+
def cached_signature(self) -> ht.FunctionType | None:
|
|
118
|
+
closure_ty = ht.FunctionType(
|
|
119
|
+
[*self.captured_inputs, *self.other_inputs],
|
|
120
|
+
self.outputs,
|
|
121
|
+
)
|
|
122
|
+
partial_fn_ty = ht.FunctionType(self.other_inputs, closure_ty.output)
|
|
123
|
+
return ht.FunctionType([closure_ty, *self.captured_inputs], [partial_fn_ty])
|
|
124
|
+
|
|
125
|
+
@classmethod
|
|
126
|
+
def from_ext(cls, custom: ops.ExtOp) -> PartialOp:
|
|
127
|
+
match custom:
|
|
128
|
+
case ops.ExtOp(
|
|
129
|
+
_op_def=op_def, args=[captured_args, other_args, output_args]
|
|
130
|
+
):
|
|
131
|
+
if op_def.qualified_name() == PARTIAL_OP_DEF.qualified_name():
|
|
132
|
+
return cls(
|
|
133
|
+
captured_inputs=[*_arg_seq_to_types(captured_args)],
|
|
134
|
+
other_inputs=[*_arg_seq_to_types(other_args)],
|
|
135
|
+
outputs=[*_arg_seq_to_types(output_args)],
|
|
136
|
+
)
|
|
137
|
+
msg = f"Invalid custom op: {custom}"
|
|
138
|
+
raise ops.AsExtOp.InvalidExtOp(msg)
|
|
139
|
+
|
|
140
|
+
@property
|
|
141
|
+
def num_out(self) -> int:
|
|
142
|
+
return 1
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
# Placeholder op for operations Guppy cannot lower. Its three type parameters are,
# in order: [0] the operation name, [1] the input row and [2] the output row.
# The body refers to [1] and [2] by index.
UNSUPPORTED_OP_DEF: he.OpDef = EXTENSION.add_op_def(
    he.OpDef(
        "unsupported",
        description="An unsupported operation stub emitted by Guppy.",
        signature=he.OpDefSig(
            poly_func=ht.PolyFuncType(
                params=[
                    ht.StringParam(),
                    *(
                        ht.ListParam(ht.TypeTypeParam(ht.TypeBound.Linear))
                        for _ in range(2)
                    ),
                ],
                body=ht.FunctionType(
                    input=[ht.RowVariable(1, ht.TypeBound.Linear)],
                    output=[ht.RowVariable(2, ht.TypeBound.Linear)],
                ),
            )
        ),
    )
)
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
@dataclass
|
|
170
|
+
class UnsupportedOp(ops.AsExtOp):
|
|
171
|
+
"""An unsupported operation stub emitted by Guppy.
|
|
172
|
+
|
|
173
|
+
args:
|
|
174
|
+
op_name: The name of the unsupported operation.
|
|
175
|
+
inputs: The input types of the operation.
|
|
176
|
+
outputs: The output types of the operation.
|
|
177
|
+
"""
|
|
178
|
+
|
|
179
|
+
op_name: str
|
|
180
|
+
inputs: list[ht.Type]
|
|
181
|
+
outputs: list[ht.Type]
|
|
182
|
+
|
|
183
|
+
def op_def(self) -> he.OpDef:
|
|
184
|
+
return UNSUPPORTED_OP_DEF
|
|
185
|
+
|
|
186
|
+
def type_args(self) -> list[ht.TypeArg]:
|
|
187
|
+
op_name = ht.StringArg(self.op_name)
|
|
188
|
+
input_args = ht.ListArg([ht.TypeTypeArg(ty) for ty in self.inputs])
|
|
189
|
+
output_args = ht.ListArg([ht.TypeTypeArg(ty) for ty in self.outputs])
|
|
190
|
+
return [op_name, input_args, output_args]
|
|
191
|
+
|
|
192
|
+
def cached_signature(self) -> ht.FunctionType | None:
|
|
193
|
+
return ht.FunctionType(self.inputs, self.outputs)
|
|
194
|
+
|
|
195
|
+
@classmethod
|
|
196
|
+
def from_ext(cls, custom: ops.ExtOp) -> UnsupportedOp:
|
|
197
|
+
match custom:
|
|
198
|
+
case ops.ExtOp(_op_def=op_def, args=args):
|
|
199
|
+
if op_def.qualified_name() == UNSUPPORTED_OP_DEF.qualified_name():
|
|
200
|
+
[op_name, input_args, output_args] = args
|
|
201
|
+
assert isinstance(op_name, ht.StringArg), (
|
|
202
|
+
"The first argument to a guppylang.unsupported op "
|
|
203
|
+
"must be the operation name"
|
|
204
|
+
)
|
|
205
|
+
op_name = op_name.value
|
|
206
|
+
return cls(
|
|
207
|
+
op_name=op_name,
|
|
208
|
+
inputs=[*_arg_seq_to_types(input_args)],
|
|
209
|
+
outputs=[*_arg_seq_to_types(output_args)],
|
|
210
|
+
)
|
|
211
|
+
msg = f"Invalid custom op: {custom}"
|
|
212
|
+
raise ops.AsExtOp.InvalidExtOp(msg)
|
|
213
|
+
|
|
214
|
+
@property
|
|
215
|
+
def num_out(self) -> int:
|
|
216
|
+
return len(self.outputs)
|
|
217
|
+
|
|
218
|
+
|
|
219
|
+
def _arg_seq_to_types(args: ht.TypeArg) -> Iterator[ht.Type]:
    """Lazily yield the type wrapped by each element of a list-of-types argument.

    ``args`` is asserted to be a ``ListArg``. Each element is asserted to be a
    ``TypeTypeArg`` as it is reached.
    """
    assert isinstance(args, ht.ListArg)
    for type_arg in args.elems:
        assert isinstance(type_arg, ht.TypeTypeArg)
        yield type_arg.ty
|
|
File without changes
|
|
@@ -0,0 +1,212 @@
|
|
|
1
|
+
import ast
|
|
2
|
+
import functools
|
|
3
|
+
from collections.abc import Sequence
|
|
4
|
+
|
|
5
|
+
import hugr.tys as ht
|
|
6
|
+
from hugr import Wire, ops
|
|
7
|
+
from hugr.build.dfg import DfBase
|
|
8
|
+
|
|
9
|
+
from guppylang_internals.ast_util import AstVisitor, get_type
|
|
10
|
+
from guppylang_internals.checker.core import Variable, contains_subscript
|
|
11
|
+
from guppylang_internals.compiler.core import (
|
|
12
|
+
CompilerBase,
|
|
13
|
+
CompilerContext,
|
|
14
|
+
DFContainer,
|
|
15
|
+
return_var,
|
|
16
|
+
)
|
|
17
|
+
from guppylang_internals.compiler.expr_compiler import ExprCompiler
|
|
18
|
+
from guppylang_internals.error import InternalGuppyError
|
|
19
|
+
from guppylang_internals.nodes import (
|
|
20
|
+
CheckedNestedFunctionDef,
|
|
21
|
+
IterableUnpack,
|
|
22
|
+
PlaceNode,
|
|
23
|
+
TupleUnpack,
|
|
24
|
+
)
|
|
25
|
+
from guppylang_internals.std._internal.compiler.array import (
|
|
26
|
+
array_discard_empty,
|
|
27
|
+
array_new,
|
|
28
|
+
array_pop,
|
|
29
|
+
)
|
|
30
|
+
from guppylang_internals.std._internal.compiler.prelude import build_unwrap
|
|
31
|
+
from guppylang_internals.tys.builtin import get_element_type
|
|
32
|
+
from guppylang_internals.tys.const import ConstValue
|
|
33
|
+
from guppylang_internals.tys.ty import TupleType, Type, type_to_row
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class StmtCompiler(CompilerBase, AstVisitor[None]):
    """Lowers checked Guppy statements into operations of a Hugr dataflow graph."""

    # Compiler used for every expression that occurs inside a statement.
    expr_compiler: ExprCompiler

    # Dataflow container that statements are currently being compiled into.
    dfg: DFContainer

    def __init__(self, ctx: CompilerContext):
        super().__init__(ctx)
        self.expr_compiler = ExprCompiler(ctx)

    def compile_stmts(
        self,
        stmts: Sequence[ast.stmt],
        dfg: DFContainer,
    ) -> DFContainer:
        """Compile a sequence of basic statements into a dataflow container.

        ``dfg`` is updated in place. Afterwards it also tracks every variable that
        ``stmts`` assigns to. The same container is returned.
        """
        self.dfg = dfg
        for stmt in stmts:
            self.visit(stmt)
        return self.dfg

    @property
    def builder(self) -> DfBase[ops.DfParentOp]:
        """Builder of the dataflow graph that statements are emitted into."""
        return self.dfg.builder

    @functools.singledispatchmethod
    def _assign(self, lhs: ast.expr, port: Wire) -> None:
        """Bind the value carried by ``port`` to the assignment target ``lhs``.

        Each supported target kind has a registered overload in this class. Any
        other node reaching the compiler is an internal error.
        """
        raise InternalGuppyError("Invalid assign pattern in compiler")

    @_assign.register
    def _assign_place(self, lhs: PlaceNode, port: Wire) -> None:
        """Assign to a place, going through ``__setitem__`` if it has a subscript."""
        subscript = contains_subscript(lhs.place)
        if not subscript:
            self.dfg[lhs.place] = port
            return

        setitem = subscript.setitem_call
        assert setitem is not None
        # Make sure the index value is available in the DFG.
        if subscript.item not in self.dfg:
            self.dfg[subscript.item] = self.expr_compiler.compile(
                subscript.item_expr, self.dfg
            )
        # A nested target such as `xs[i].y = ...` is handled in three steps: read
        # `tmp = xs[i]`, assign `tmp.y = ...`, then write back `xs[i] = tmp`. The
        # subscript place itself stands in for `tmp` in the DFG tracker.
        if subscript != lhs.place:
            getitem = subscript.getitem_call
            assert getitem is not None
            self.dfg[subscript] = self.expr_compiler.compile(getitem, self.dfg)
        self.dfg[lhs.place] = port
        # Re-read `xs[i]`, which the assignment above may have updated, then emit
        # the `__setitem__` call that performs the actual mutation.
        self.dfg[setitem.value_var] = self.dfg[subscript]
        self.expr_compiler.visit(setitem.call)

    @_assign.register
    def _assign_tuple(self, lhs: TupleUnpack, port: Wire) -> None:
        """Unpack a tuple value into a `left, *starred, right` target pattern."""
        pattern = lhs.pattern
        row_tys = [t.to_hugr(self.ctx) for t in type_to_row(get_type(lhs))]
        wires = list(self.builder.add_op(ops.UnpackTuple(row_tys), port))
        n_left = len(pattern.left)
        n_right = len(pattern.right)

        # Targets before and after the starred one take the leading/trailing wires.
        for target, wire in zip(pattern.left, wires[:n_left], strict=True):
            self._assign(target, wire)
        if n_right:
            for target, wire in zip(pattern.right, wires[-n_right:], strict=True):
                self._assign(target, wire)

        # Whatever lies in between is collected into an array of options.
        if pattern.starred:
            starred_ty = get_type(pattern.starred)
            middle = wires[n_left : len(wires) - n_right]
            elt_ty = get_element_type(starred_ty).to_hugr(self.ctx)
            somes = [self.builder.add_op(ops.Some(elt_ty), w) for w in middle]
            collected = self.builder.add_op(
                array_new(ht.Option(elt_ty), len(somes)), *somes
            )
            self._assign(pattern.starred, collected)

    @_assign.register
    def _assign_iterable(self, lhs: IterableUnpack, port: Wire) -> None:
        """Unpack an iterable value into a `left, *starred, right` target pattern.

        The iterable is first collected into an array. Elements are then popped off
        the front for ``left`` and off the back for ``right``. What remains is the
        starred array, or an empty array that is discarded.
        """
        compr = lhs.compr
        assert isinstance(compr.length, ConstValue)
        remaining = compr.length.value
        assert isinstance(remaining, int)
        opt_elt_ty = ht.Option(compr.elt_ty.to_hugr(self.ctx))
        err = "Internal error: unpacking of iterable failed"

        def pop_into(
            array: Wire, size: int, targets: list[ast.expr], from_left: bool
        ) -> tuple[Wire, int]:
            # Pop one element per target from the chosen end of the array.
            popped = []
            for offset in range(len(targets)):
                result = self.builder.add_op(
                    array_pop(opt_elt_ty, size - offset, from_left), array
                )
                [elt_opt, array] = build_unwrap(self.builder, result, err)
                [elt] = build_unwrap(self.builder, elt_opt, err)
                popped.append(elt)
            # Targets are assigned left to right, so elements popped from the
            # right end must be assigned in reverse order.
            ordered = popped if from_left else reversed(popped)
            for target, elt in zip(targets, ordered, strict=True):
                self._assign(target, elt)
            return array, size - len(targets)

        self.dfg[lhs.rhs_var.place] = port
        array = self.expr_compiler.visit_DesugaredArrayComp(compr)
        array, remaining = pop_into(array, remaining, lhs.pattern.left, True)
        array, remaining = pop_into(array, remaining, lhs.pattern.right, False)
        if lhs.pattern.starred:
            self._assign(lhs.pattern.starred, array)
        else:
            assert remaining == 0
            self.builder.add_op(array_discard_empty(opt_elt_ty), array)

    def visit_Assign(self, node: ast.Assign) -> None:
        # Exactly one assignment target is expected here.
        [target] = node.targets
        self._assign(target, self.expr_compiler.compile(node.value, self.dfg))

    def visit_AnnAssign(self, node: ast.AnnAssign) -> None:
        assert node.value is not None
        self._assign(node.target, self.expr_compiler.compile(node.value, self.dfg))

    def visit_AugAssign(self, node: ast.AugAssign) -> None:
        raise InternalGuppyError("Node should have been removed during type checking.")

    def visit_Expr(self, node: ast.Expr) -> None:
        # Compile the expression statement; its result values are not bound.
        self.expr_compiler.compile_row(node.value, self.dfg)

    def visit_Return(self, node: ast.Return) -> None:
        """Turn a return into assignments to the dummy return variables.

        For example, `return e0, e1, e2` becomes `%ret0 = e0; %ret1 = e1; %ret2 = e2`.
        A bare `return` assigns nothing.
        """
        if node.value is None:
            return
        return_ty = get_type(node.value)
        port = self.expr_compiler.compile(node.value, self.dfg)

        row: list[tuple[Wire, Type]]
        if isinstance(return_ty, TupleType):
            elt_tys = return_ty.element_types
            unpack = self.builder.add_op(
                ops.UnpackTuple([t.to_hugr(self.ctx) for t in elt_tys]), port
            )
            row = list(zip(unpack, elt_tys, strict=True))
        else:
            row = [(port, return_ty)]

        for idx, (wire, ty) in enumerate(row):
            self.dfg[Variable(return_var(idx), ty, node.value)] = wire

    def visit_CheckedNestedFunctionDef(self, node: CheckedNestedFunctionDef) -> None:
        # Local import, presumably to avoid a circular import with func_compiler.
        from guppylang_internals.compiler.func_compiler import compile_local_func_def

        var = Variable(node.name, node.ty, node)
        self.dfg[var] = compile_local_func_def(node, self.dfg, self.ctx)
|