guppylang-internals 0.21.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- guppylang_internals/__init__.py +3 -0
- guppylang_internals/ast_util.py +350 -0
- guppylang_internals/cfg/__init__.py +0 -0
- guppylang_internals/cfg/analysis.py +230 -0
- guppylang_internals/cfg/bb.py +221 -0
- guppylang_internals/cfg/builder.py +606 -0
- guppylang_internals/cfg/cfg.py +117 -0
- guppylang_internals/checker/__init__.py +0 -0
- guppylang_internals/checker/cfg_checker.py +388 -0
- guppylang_internals/checker/core.py +550 -0
- guppylang_internals/checker/errors/__init__.py +0 -0
- guppylang_internals/checker/errors/comptime_errors.py +106 -0
- guppylang_internals/checker/errors/generic.py +45 -0
- guppylang_internals/checker/errors/linearity.py +300 -0
- guppylang_internals/checker/errors/type_errors.py +344 -0
- guppylang_internals/checker/errors/wasm.py +34 -0
- guppylang_internals/checker/expr_checker.py +1413 -0
- guppylang_internals/checker/func_checker.py +269 -0
- guppylang_internals/checker/linearity_checker.py +821 -0
- guppylang_internals/checker/stmt_checker.py +447 -0
- guppylang_internals/compiler/__init__.py +0 -0
- guppylang_internals/compiler/cfg_compiler.py +233 -0
- guppylang_internals/compiler/core.py +613 -0
- guppylang_internals/compiler/expr_compiler.py +989 -0
- guppylang_internals/compiler/func_compiler.py +97 -0
- guppylang_internals/compiler/hugr_extension.py +224 -0
- guppylang_internals/compiler/qtm_platform_extension.py +0 -0
- guppylang_internals/compiler/stmt_compiler.py +212 -0
- guppylang_internals/decorator.py +246 -0
- guppylang_internals/definition/__init__.py +0 -0
- guppylang_internals/definition/common.py +214 -0
- guppylang_internals/definition/const.py +74 -0
- guppylang_internals/definition/custom.py +492 -0
- guppylang_internals/definition/declaration.py +171 -0
- guppylang_internals/definition/extern.py +89 -0
- guppylang_internals/definition/function.py +302 -0
- guppylang_internals/definition/overloaded.py +150 -0
- guppylang_internals/definition/parameter.py +82 -0
- guppylang_internals/definition/pytket_circuits.py +405 -0
- guppylang_internals/definition/struct.py +392 -0
- guppylang_internals/definition/traced.py +151 -0
- guppylang_internals/definition/ty.py +51 -0
- guppylang_internals/definition/value.py +115 -0
- guppylang_internals/definition/wasm.py +61 -0
- guppylang_internals/diagnostic.py +523 -0
- guppylang_internals/dummy_decorator.py +76 -0
- guppylang_internals/engine.py +295 -0
- guppylang_internals/error.py +107 -0
- guppylang_internals/experimental.py +92 -0
- guppylang_internals/ipython_inspect.py +28 -0
- guppylang_internals/nodes.py +427 -0
- guppylang_internals/py.typed +0 -0
- guppylang_internals/span.py +150 -0
- guppylang_internals/std/__init__.py +0 -0
- guppylang_internals/std/_internal/__init__.py +0 -0
- guppylang_internals/std/_internal/checker.py +573 -0
- guppylang_internals/std/_internal/compiler/__init__.py +0 -0
- guppylang_internals/std/_internal/compiler/arithmetic.py +136 -0
- guppylang_internals/std/_internal/compiler/array.py +569 -0
- guppylang_internals/std/_internal/compiler/either.py +131 -0
- guppylang_internals/std/_internal/compiler/frozenarray.py +68 -0
- guppylang_internals/std/_internal/compiler/futures.py +30 -0
- guppylang_internals/std/_internal/compiler/list.py +348 -0
- guppylang_internals/std/_internal/compiler/mem.py +13 -0
- guppylang_internals/std/_internal/compiler/option.py +78 -0
- guppylang_internals/std/_internal/compiler/prelude.py +271 -0
- guppylang_internals/std/_internal/compiler/qsystem.py +48 -0
- guppylang_internals/std/_internal/compiler/quantum.py +118 -0
- guppylang_internals/std/_internal/compiler/tket_bool.py +55 -0
- guppylang_internals/std/_internal/compiler/tket_exts.py +59 -0
- guppylang_internals/std/_internal/compiler/wasm.py +135 -0
- guppylang_internals/std/_internal/compiler.py +0 -0
- guppylang_internals/std/_internal/debug.py +95 -0
- guppylang_internals/std/_internal/util.py +271 -0
- guppylang_internals/tracing/__init__.py +0 -0
- guppylang_internals/tracing/builtins_mock.py +62 -0
- guppylang_internals/tracing/frozenlist.py +57 -0
- guppylang_internals/tracing/function.py +186 -0
- guppylang_internals/tracing/object.py +551 -0
- guppylang_internals/tracing/state.py +69 -0
- guppylang_internals/tracing/unpacking.py +194 -0
- guppylang_internals/tracing/util.py +86 -0
- guppylang_internals/tys/__init__.py +0 -0
- guppylang_internals/tys/arg.py +115 -0
- guppylang_internals/tys/builtin.py +382 -0
- guppylang_internals/tys/common.py +110 -0
- guppylang_internals/tys/const.py +114 -0
- guppylang_internals/tys/errors.py +178 -0
- guppylang_internals/tys/param.py +251 -0
- guppylang_internals/tys/parsing.py +425 -0
- guppylang_internals/tys/printing.py +174 -0
- guppylang_internals/tys/subst.py +112 -0
- guppylang_internals/tys/ty.py +876 -0
- guppylang_internals/tys/var.py +49 -0
- guppylang_internals-0.21.0.dist-info/METADATA +253 -0
- guppylang_internals-0.21.0.dist-info/RECORD +98 -0
- guppylang_internals-0.21.0.dist-info/WHEEL +4 -0
- guppylang_internals-0.21.0.dist-info/licenses/LICENCE +201 -0
|
@@ -0,0 +1,492 @@
|
|
|
1
|
+
import ast
|
|
2
|
+
from abc import ABC, abstractmethod
|
|
3
|
+
from collections.abc import Callable
|
|
4
|
+
from dataclasses import dataclass, field
|
|
5
|
+
from typing import TYPE_CHECKING, ClassVar
|
|
6
|
+
|
|
7
|
+
from hugr import Wire, ops
|
|
8
|
+
from hugr import tys as ht
|
|
9
|
+
from hugr.build.dfg import DfBase
|
|
10
|
+
|
|
11
|
+
from guppylang_internals.ast_util import (
|
|
12
|
+
AstNode,
|
|
13
|
+
get_type,
|
|
14
|
+
has_empty_body,
|
|
15
|
+
with_loc,
|
|
16
|
+
with_type,
|
|
17
|
+
)
|
|
18
|
+
from guppylang_internals.checker.core import Context, Globals
|
|
19
|
+
from guppylang_internals.checker.expr_checker import check_call, synthesize_call
|
|
20
|
+
from guppylang_internals.checker.func_checker import check_signature
|
|
21
|
+
from guppylang_internals.compiler.core import (
|
|
22
|
+
CompilerContext,
|
|
23
|
+
DFContainer,
|
|
24
|
+
GlobalConstId,
|
|
25
|
+
partially_monomorphize_args,
|
|
26
|
+
)
|
|
27
|
+
from guppylang_internals.definition.common import ParsableDef
|
|
28
|
+
from guppylang_internals.definition.value import CallReturnWires, CompiledCallableDef
|
|
29
|
+
from guppylang_internals.diagnostic import Error, Help
|
|
30
|
+
from guppylang_internals.error import GuppyError, InternalGuppyError
|
|
31
|
+
from guppylang_internals.nodes import GlobalCall
|
|
32
|
+
from guppylang_internals.span import SourceMap
|
|
33
|
+
from guppylang_internals.std._internal.compiler.tket_bool import (
|
|
34
|
+
OpaqueBool,
|
|
35
|
+
make_opaque,
|
|
36
|
+
read_bool,
|
|
37
|
+
)
|
|
38
|
+
from guppylang_internals.tys.subst import Inst, Subst
|
|
39
|
+
from guppylang_internals.tys.ty import (
|
|
40
|
+
FuncInput,
|
|
41
|
+
FunctionType,
|
|
42
|
+
InputFlags,
|
|
43
|
+
NoneType,
|
|
44
|
+
Type,
|
|
45
|
+
type_to_row,
|
|
46
|
+
)
|
|
47
|
+
|
|
48
|
+
if TYPE_CHECKING:
|
|
49
|
+
from guppylang_internals.definition.function import PyFunc
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
@dataclass(frozen=True)
class BodyNotEmptyError(Error):
    """Diagnostic emitted when a custom function is written with a non-empty body."""

    title: ClassVar[str] = "Unexpected function body"
    span_label: ClassVar[str] = (
        "Body of custom function `{name}` must be empty"
    )
    # Name of the offending custom function (substituted into `span_label`).
    name: str
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
@dataclass(frozen=True)
class NoSignatureError(Error):
    """Diagnostic emitted when a custom function needs a type signature but has none.

    `RawCustomFunctionDef._get_signature` raises it when the default call checker is
    used, or when the function may be used as a higher-order value, and no
    annotations are present.
    """

    title: ClassVar[str] = "Type signature missing"
    span_label: ClassVar[str] = (
        "Custom function `{name}` requires a type signature"
    )
    # Name of the custom function lacking a signature.
    name: str

    @dataclass(frozen=True)
    class Suggestion(Help):
        """Help note describing the two ways of resolving the error."""

        message: ClassVar[str] = (
            "Annotate the type signature of `{name}` or disallow the use of "
            "`{name}` as a higher-order value: "
            "`@custom_function(..., higher_order_value=False)`"
        )

    def __post_init__(self) -> None:
        # Every instance of this error carries the fix-it suggestion.
        suggestion = NoSignatureError.Suggestion(None)
        self.add_sub_diagnostic(suggestion)
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
@dataclass(frozen=True)
class NotHigherOrderError(Error):
    """Diagnostic emitted when a function that is not higher-order is used as a value."""

    title: ClassVar[str] = "Not higher-order"
    span_label: ClassVar[str] = (
        "Function `{name}` may not be used "
        "as a higher-order value"
    )
    # Name of the function that was used as a value.
    name: str
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
@dataclass(frozen=True)
class RawCustomFunctionDef(ParsableDef):
    """A raw custom function definition provided by the user.

    Custom functions provide their own checking and compilation logic using a
    `CustomCallChecker` and a `CustomInoutCallCompiler`.

    The raw definition stores exactly what the user has written (i.e. the AST together
    with the provided checker and compiler), without inspecting the signature.

    Args:
        id: The unique definition identifier.
        name: The name of the definition.
        defined_at: The AST node where the definition was defined.
        python_func: The Python function whose AST is parsed by `parse`.
        call_checker: The custom call checker.
        call_compiler: The custom call compiler.
        higher_order_value: Whether the function may be used as a higher-order value.
        signature: User-provided signature. If given, it overrides any annotations on
            `python_func`.
    """

    python_func: "PyFunc"
    call_checker: "CustomCallChecker"
    call_compiler: "CustomInoutCallCompiler"

    # Whether the function may be used as a higher-order value. This is only possible
    # if a static type for the function is provided.
    higher_order_value: bool

    # Explicitly provided signature; takes precedence over type annotations.
    signature: FunctionType | None

    description: str = field(default="function", init=False)

    def parse(self, globals: "Globals", sources: SourceMap) -> "CustomFunctionDef":
        """Parses and checks the signature of the custom function.

        The signature is optional if custom type checking logic is provided by the user.
        However, note that a signature must be provided by either annotation or as an
        argument, if we want to use the function as a higher-order value. If a signature
        is provided as an argument, this will override any annotation.

        If no signature is provided, we fill in the dummy signature `() -> None`. This
        type will never be inspected, since we rely on the provided custom checking
        code. The only information we need to access is that it's a function type and
        that there are no unsolved existential vars.

        Raises:
            GuppyError: If the function body is not empty, or if a signature is
                required but missing or invalid.
        """
        # Imported lazily; at module level `definition.function` is only imported
        # under `TYPE_CHECKING` (presumably to avoid an import cycle).
        from guppylang_internals.definition.function import parse_py_func

        func_ast, _docstring = parse_py_func(self.python_func, sources)
        if not has_empty_body(func_ast):
            raise GuppyError(BodyNotEmptyError(func_ast.body[0], self.name))

        # An explicitly provided signature always wins. Compare against `None`
        # explicitly rather than relying on the truthiness of the signature object.
        sig = (
            self.signature
            if self.signature is not None
            else self._get_signature(func_ast, globals)
        )
        ty = sig if sig is not None else FunctionType([], NoneType())
        return CustomFunctionDef(
            self.id,
            self.name,
            func_ast,
            ty,
            self.call_checker,
            self.call_compiler,
            self.higher_order_value,
            GlobalConstId.fresh(self.name),
            sig is not None,
        )

    def _get_signature(
        self, node: ast.FunctionDef, globals: Globals
    ) -> FunctionType | None:
        """Returns the type of the function, if known.

        Type annotations are needed if we rely on the default call checker or
        want to allow the usage of the function as a higher-order value.

        Some function types like python's `int()` cannot be expressed in the Guppy
        type system, so we return `None` here and rely on the specialized compiler
        to handle the call.

        Raises:
            GuppyError: If annotations are required but the function has none, or if
                the annotated signature is invalid.
        """
        requires_type_annotation = (
            isinstance(self.call_checker, DefaultCallChecker) or self.higher_order_value
        )
        if not requires_type_annotation:
            return None

        # Only the return annotation and the regular positional parameters are
        # inspected to decide whether the user attempted to annotate the function.
        has_type_annotation = node.returns is not None or any(
            arg.annotation is not None for arg in node.args.args
        )
        if not has_type_annotation:
            raise GuppyError(NoSignatureError(node, self.name))
        return check_signature(node, globals)
|
|
175
|
+
|
|
176
|
+
|
|
177
|
+
@dataclass(frozen=True)
class CustomFunctionDef(CompiledCallableDef):
    """A custom function whose signature has been parsed and checked.

    Args:
        id: The unique definition identifier.
        name: The name of the definition.
        defined_at: The AST node where the definition was defined.
        ty: The type of the function. Only a dummy placeholder when `has_signature`
            is false.
        call_checker: The custom call checker.
        call_compiler: The custom call compiler.
        higher_order_value: Whether the function may be used as a higher-order value.
        higher_order_func_id: Global identifier under which the wrapper `FunctionDef`
            used when loading this function as a value is declared.
        has_signature: Whether the function has a declared signature.
    """

    defined_at: AstNode | None
    ty: FunctionType
    call_checker: "CustomCallChecker"
    call_compiler: "CustomInoutCallCompiler"
    higher_order_value: bool
    higher_order_func_id: GlobalConstId
    has_signature: bool

    description: str = field(default="function", init=False)

    def check_call(
        self, args: list[ast.expr], ty: Type, node: AstNode, ctx: Context
    ) -> tuple[ast.expr, Subst]:
        """Checks a call of this function against the expected return type `ty`.

        Type checking is delegated to the user-provided `CustomCallChecker`; the
        resulting node is given the location of `node` and annotated with `ty`.
        """
        checker = self.call_checker
        checker._setup(ctx, node, self)
        checked, subst = checker.check(args, ty)
        located = with_loc(node, checked)
        return with_type(ty, located), subst

    def synthesize_call(
        self, args: list[ast.expr], node: AstNode, ctx: "Context"
    ) -> tuple[ast.expr, Type]:
        """Infers the return type of a call of this function.

        Type synthesis is delegated to the user-provided `CustomCallChecker`; the
        resulting node is given the location of `node` and annotated with the
        synthesized type.
        """
        checker = self.call_checker
        checker._setup(ctx, node, self)
        synthesized, result_ty = checker.synthesize(args)
        located = with_loc(node, synthesized)
        return with_type(result_ty, located), result_ty

    def load_with_args(
        self,
        type_args: Inst,
        dfg: "DFContainer",
        ctx: CompilerContext,
        node: AstNode,
    ) -> Wire:
        """Loads the custom function as a value into a local dataflow graph.

        A wrapper `FunctionDef` that compiles a call of this function is declared
        globally (and built on first use), then loaded with a `LoadFunc` node.

        Raises:
            GuppyError: If the function may not be used as a higher-order value.
        """
        # TODO: This should be raised during checking, not compilation!
        if not self.higher_order_value:
            raise GuppyError(NotHigherOrderError(node, self.name))
        params = self.ty.params
        assert len(params) == len(type_args)

        # `mono_args` are baked into the wrapper definition, while `rem_args` are
        # passed as Hugr type arguments when the wrapper is loaded below.
        mono_args, rem_args = partially_monomorphize_args(params, type_args, ctx)

        # The wrapper takes the function inputs, compiles a call and returns the
        # results. `already_defined` signals that its body has been built before.
        poly_hugr_ty = self.ty.instantiate_partial(mono_args).to_hugr_poly(ctx)
        func, already_defined = ctx.declare_global_func(
            self.higher_order_func_id, poly_hugr_ty, mono_args
        )
        if not already_defined:
            with ctx.set_monomorphized_args(mono_args):
                func_dfg = DFContainer(func, ctx, dfg.locals.copy())
                inputs: list[Wire] = list(func.inputs())
                bound_args = [param.to_bound() for param in params]
                returns = self.compile_call(inputs, bound_args, func_dfg, ctx, node)
                func.set_outputs(*returns.regular_returns, *returns.inout_returns)

        # Load the wrapper into the caller's dataflow graph
        mono_ty = self.ty.instantiate(type_args).to_hugr(ctx)
        hugr_ty_args = [arg.to_hugr(ctx) for arg in rem_args]
        return dfg.builder.load_function(func, mono_ty, hugr_ty_args)

    def compile_call(
        self,
        args: list[Wire],
        type_args: Inst,
        dfg: "DFContainer",
        ctx: CompilerContext,
        node: AstNode,
    ) -> CallReturnWires:
        """Compiles a call of this function via the user-provided call compiler."""
        if not self.has_signature:
            # Without a declared signature `self.ty` is only a dummy, so rebuild a
            # concrete type from the types recorded on the call node.
            assert isinstance(node, GlobalCall)
            arg_inputs = [
                FuncInput(get_type(arg), InputFlags.NoFlags) for arg in node.args
            ]
            concrete_ty = FunctionType(arg_inputs, get_type(node))
        else:
            concrete_ty = self.ty.instantiate(type_args)
        hugr_ty = concrete_ty.to_hugr(ctx)

        compiler = self.call_compiler
        compiler._setup(type_args, dfg, ctx, node, hugr_ty, self)
        return compiler.compile_with_inouts(args)
|
|
289
|
+
|
|
290
|
+
|
|
291
|
+
class CustomCallChecker(ABC):
    """Abstract base class for custom function call type checkers.

    `_setup` stores the checking context, the call node, and the called function on
    the instance before `check` or `synthesize` is invoked.
    """

    ctx: Context
    node: AstNode
    func: CustomFunctionDef

    def _setup(self, ctx: Context, node: AstNode, func: CustomFunctionDef) -> None:
        """Records the state that `check` and `synthesize` operate on."""
        self.ctx, self.node, self.func = ctx, node, func

    def check(self, args: list[ast.expr], ty: Type) -> tuple[ast.expr, Subst]:
        """Checks the call's return value against the expected type `ty`.

        By default, a type is synthesized via `synthesize` and then checked against
        `ty`. Returns a (possibly) transformed and annotated AST node for the call
        together with the resulting substitution.
        """
        from guppylang_internals.checker.expr_checker import check_type_against

        synthesized, synthesized_ty = self.synthesize(args)
        checked, subst, _inst = check_type_against(
            synthesized_ty, ty, synthesized, self.ctx
        )
        return checked, subst

    @abstractmethod
    def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, Type]:
        """Synthesizes a type for the return value of a call.

        Returns a (possibly) transformed and annotated AST node for the call together
        with the synthesized return type.
        """
|
|
320
|
+
|
|
321
|
+
|
|
322
|
+
class CustomInoutCallCompiler(ABC):
    """Abstract base class for custom function call compilers with borrowed args.

    The attributes below are filled in by `_setup` before `compile_with_inouts` is
    called.

    Attributes:
        dfg: The dataflow container the call is compiled into.
        type_args: The type arguments of the call.
        ctx: The compiler context.
        node: The AST node being compiled.
        ty: The Hugr function type of the call.
        func: The custom function being called, if any.
    """

    dfg: DFContainer
    type_args: Inst
    ctx: CompilerContext
    node: AstNode
    ty: ht.FunctionType
    func: CustomFunctionDef | None

    def _setup(
        self,
        type_args: Inst,
        dfg: DFContainer,
        ctx: CompilerContext,
        node: AstNode,
        hugr_ty: ht.FunctionType,
        func: CustomFunctionDef | None,
    ) -> None:
        """Records the state needed by `compile_with_inouts`."""
        self.type_args, self.dfg, self.ctx = type_args, dfg, ctx
        self.node, self.ty, self.func = node, hugr_ty, func

    @abstractmethod
    def compile_with_inouts(self, args: list[Wire]) -> CallReturnWires:
        """Compiles a custom function call.

        Returns the outputs of the call together with any borrowed arguments that are
        passed through the function.
        """

    @property
    def builder(self) -> DfBase[ops.DfParentOp]:
        """The Hugr dataflow builder of `self.dfg`."""
        return self.dfg.builder
|
|
368
|
+
|
|
369
|
+
|
|
370
|
+
class CustomCallCompiler(CustomInoutCallCompiler, ABC):
    """Abstract base class for custom call compilers that only take owned args.

    Subclasses implement `compile`; because nothing is borrowed, no inout wires are
    ever reported.
    """

    @abstractmethod
    def compile(self, args: list[Wire]) -> list[Wire]:
        """Compiles a custom function call and returns the resulting ports.

        Use the provided `self.builder` to add nodes to the Hugr graph.
        """

    def compile_with_inouts(self, args: list[Wire]) -> CallReturnWires:
        """Runs `compile` and reports its outputs with no borrowed returns."""
        outputs = self.compile(args)
        return CallReturnWires(regular_returns=outputs, inout_returns=[])
|
|
382
|
+
|
|
383
|
+
|
|
384
|
+
class DefaultCallChecker(CustomCallChecker):
    """Checks custom function calls against the function's declared signature.

    Both directions delegate to the generic call-checking routines of the expression
    checker and produce a `GlobalCall` node referring to the called function.
    """

    def check(self, args: list[ast.expr], ty: Type) -> tuple[ast.expr, Subst]:
        checked_args, subst, inst = check_call(
            self.func.ty, args, ty, self.node, self.ctx
        )
        call = GlobalCall(def_id=self.func.id, args=checked_args, type_args=inst)
        return call, subst

    def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, Type]:
        checked_args, ret_ty, inst = synthesize_call(
            self.func.ty, args, self.node, self.ctx
        )
        call = GlobalCall(def_id=self.func.id, args=checked_args, type_args=inst)
        return call, ret_ty
|
|
396
|
+
|
|
397
|
+
|
|
398
|
+
class NotImplementedCallCompiler(CustomCallCompiler):
    """Call compiler for custom functions that are already lowered during checking.

    For example, the custom checker could replace the call with a series of calls to
    other functions. In that case, the original function will no longer be present and
    thus doesn't need to be compiled.
    """

    def compile(self, args: list[Wire]) -> list[Wire]:
        """Unconditionally fails, since such calls must not reach compilation.

        Raises:
            InternalGuppyError: Always.
        """
        message = "Function should have been removed during checking"
        raise InternalGuppyError(message)
|
|
408
|
+
|
|
409
|
+
|
|
410
|
+
class OpCompiler(CustomInoutCallCompiler):
    """Call compiler for functions that are directly implemented via Hugr ops.

    Args:
        op: A function that takes the Hugr function type of the call, the type
            arguments, and the compiler context, and returns a concrete Hugr op.
    """

    op: Callable[[ht.FunctionType, Inst, CompilerContext], ops.DataflowOp]

    def __init__(
        self, op: Callable[[ht.FunctionType, Inst, CompilerContext], ops.DataflowOp]
    ) -> None:
        self.op = op

    def compile_with_inouts(self, args: list[Wire]) -> CallReturnWires:
        """Adds the op to the graph and splits its outputs.

        The op's outputs are the regular returns followed by one output per borrowed
        (inout) argument.
        """
        op = self.op(self.ty, self.type_args, self.ctx)
        node = self.builder.add_op(op, *args)

        # Count borrowed arguments from the declared signature. A custom function
        # without a signature only carries a dummy `ty` (and its call type is built
        # with `InputFlags.NoFlags`), so it has no borrowed arguments. The number of
        # regular returns is derived from the actual Hugr output row, which also stays
        # correct when a generic output is instantiated with a row-expanding type.
        num_inouts = 0
        if self.func is not None and self.func.has_signature:
            num_inouts = sum(
                1 for inp in self.func.ty.inputs if InputFlags.Inout in inp.flags
            )
        num_returns = len(self.ty.output) - num_inouts
        return CallReturnWires(
            regular_returns=list(node[:num_returns]),
            inout_returns=list(node[num_returns:]),
        )
|
|
435
|
+
|
|
436
|
+
|
|
437
|
+
class BoolOpCompiler(CustomInoutCallCompiler):
    """Call compiler for functions implemented via Hugr ops that use sum bools.

    Guppy represents booleans as `OpaqueBool`, while the wrapped op works with Hugr
    sum bools (`ht.Bool`). Inputs are converted with `read_bool` before the op and
    outputs are converted back with `make_opaque` afterwards.

    Args:
        op: A function that takes the Hugr function type of the op, the type
            arguments, and the compiler context, and returns a concrete Hugr op.
    """

    op: Callable[[ht.FunctionType, Inst, CompilerContext], ops.DataflowOp]

    def __init__(
        self, op: Callable[[ht.FunctionType, Inst, CompilerContext], ops.DataflowOp]
    ) -> None:
        self.op = op

    def compile_with_inouts(self, args: list[Wire]) -> CallReturnWires:
        # Signature of the underlying op: every opaque bool becomes a sum bool.
        hugr_op_ty = ht.FunctionType(
            [ht.Bool if t == OpaqueBool else t for t in self.ty.input],
            [ht.Bool if t == OpaqueBool else t for t in self.ty.output],
        )
        op = self.op(hugr_op_ty, self.type_args, self.ctx)
        builder = self.builder

        op_inputs: list[Wire] = []
        for arg in args:
            if builder.hugr.port_type(arg.out_port()) == OpaqueBool:
                op_inputs.append(builder.add_op(read_bool(), arg))
            else:
                op_inputs.append(arg)
        node = builder.add_op(op, *op_inputs)

        results: list[Wire] = []
        for out in list(node.outputs()):
            if builder.hugr.port_type(out.out_port()) == ht.Bool:
                results.append(builder.add_op(make_opaque(), out))
            else:
                results.append(out)
        return CallReturnWires(regular_returns=results, inout_returns=[])
|
|
479
|
+
|
|
480
|
+
|
|
481
|
+
class NoopCompiler(CustomCallCompiler):
    """Call compiler for functions that are noops.

    No nodes are added; the argument wires are returned as the call's results.
    """

    def compile(self, args: list[Wire]) -> list[Wire]:
        # The very same list object is handed back unchanged.
        return args
|
|
486
|
+
|
|
487
|
+
|
|
488
|
+
class CopyInoutCompiler(CustomInoutCallCompiler):
    """Call compiler for noop functions that only borrow their arguments.

    Every argument wire is reported both as a regular result and as a borrowed
    (inout) return.
    """

    def compile_with_inouts(self, args: list[Wire]) -> CallReturnWires:
        # NOTE(review): both fields alias the same list object; copy it if a caller
        # could ever mutate one of them.
        return CallReturnWires(inout_returns=args, regular_returns=args)
|
|
@@ -0,0 +1,171 @@
|
|
|
1
|
+
import ast
|
|
2
|
+
from dataclasses import dataclass, field
|
|
3
|
+
from typing import ClassVar
|
|
4
|
+
|
|
5
|
+
from hugr import Node, Wire
|
|
6
|
+
from hugr.build import function as hf
|
|
7
|
+
from hugr.build.dfg import DefinitionBuilder, OpVar
|
|
8
|
+
|
|
9
|
+
from guppylang_internals.ast_util import AstNode, has_empty_body, with_loc, with_type
|
|
10
|
+
from guppylang_internals.checker.core import Context, Globals
|
|
11
|
+
from guppylang_internals.checker.expr_checker import check_call, synthesize_call
|
|
12
|
+
from guppylang_internals.checker.func_checker import check_signature
|
|
13
|
+
from guppylang_internals.compiler.core import (
|
|
14
|
+
CompilerContext,
|
|
15
|
+
DFContainer,
|
|
16
|
+
requires_monomorphization,
|
|
17
|
+
)
|
|
18
|
+
from guppylang_internals.definition.common import CompilableDef, ParsableDef
|
|
19
|
+
from guppylang_internals.definition.function import (
|
|
20
|
+
PyFunc,
|
|
21
|
+
compile_call,
|
|
22
|
+
load_with_args,
|
|
23
|
+
parse_py_func,
|
|
24
|
+
)
|
|
25
|
+
from guppylang_internals.definition.value import (
|
|
26
|
+
CallableDef,
|
|
27
|
+
CallReturnWires,
|
|
28
|
+
CompiledCallableDef,
|
|
29
|
+
CompiledHugrNodeDef,
|
|
30
|
+
)
|
|
31
|
+
from guppylang_internals.diagnostic import Error
|
|
32
|
+
from guppylang_internals.error import GuppyError
|
|
33
|
+
from guppylang_internals.nodes import GlobalCall
|
|
34
|
+
from guppylang_internals.span import SourceMap
|
|
35
|
+
from guppylang_internals.tys.param import Parameter
|
|
36
|
+
from guppylang_internals.tys.subst import Inst, Subst
|
|
37
|
+
from guppylang_internals.tys.ty import Type
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
@dataclass(frozen=True)
class BodyNotEmptyError(Error):
    """Diagnostic emitted when a function declaration is written with a body."""

    title: ClassVar[str] = "Unexpected function body"
    span_label: ClassVar[str] = (
        "Body of declared function `{name}` must be empty"
    )
    # Name of the offending declaration (substituted into `span_label`).
    name: str
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
@dataclass(frozen=True)
class MonomorphizeError(Error):
    """Diagnostic emitted when a declaration is generic over a parameter that would
    require monomorphization to compile.
    """

    title: ClassVar[str] = "Invalid function declaration"
    span_label: ClassVar[str] = (
        "Function declaration `{name}` is not allowed to be generic "
        "over `{param}`"
    )
    # Name of the offending declaration.
    name: str
    # The type parameter that would require monomorphization.
    param: Parameter
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
@dataclass(frozen=True)
class RawFunctionDecl(ParsableDef):
    """A raw function declaration provided by the user.

    It records only what the user wrote (the Python function) and performs no
    parsing or checking until `parse` is called.
    """

    python_func: PyFunc
    description: str = field(default="function", init=False)

    def parse(self, globals: Globals, sources: SourceMap) -> "CheckedFunctionDecl":
        """Parses and checks the user-provided signature of the function.

        Raises:
            GuppyError: If the signature is invalid, the body is not empty, or a type
                parameter would require monomorphization.
        """
        func_ast, docstring = parse_py_func(self.python_func, sources)
        ty = check_signature(func_ast, globals)
        if not has_empty_body(func_ast):
            raise GuppyError(BodyNotEmptyError(func_ast.body[0], self.name))

        # Declarations are emitted as a single Hugr declaration (see `compile_outer`
        # on the checked declaration), so no parameter may need monomorphization.
        offending = next(
            (param for param in ty.params if requires_monomorphization(param)), None
        )
        if offending is not None:
            raise GuppyError(MonomorphizeError(func_ast, self.name, offending))

        return CheckedFunctionDecl(
            self.id, self.name, func_ast, ty, self.python_func, docstring
        )
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
@dataclass(frozen=True)
class CheckedFunctionDecl(RawFunctionDecl, CompilableDef, CallableDef):
    """A function declaration whose signature has been parsed and checked.

    In particular, a type has been determined for the function.
    """

    defined_at: ast.FunctionDef
    docstring: str | None

    def check_call(
        self, args: list[ast.expr], ty: Type, node: AstNode, ctx: Context
    ) -> tuple[ast.expr, Subst]:
        """Checks the return type of a function call against a given type."""
        # Delegate to the generic implementation of the expression checker.
        checked_args, subst, inst = check_call(self.ty, args, ty, node, ctx)
        call = GlobalCall(def_id=self.id, args=checked_args, type_args=inst)
        return with_loc(node, call), subst

    def synthesize_call(
        self, args: list[ast.expr], node: AstNode, ctx: Context
    ) -> tuple[GlobalCall, Type]:
        """Synthesizes the return type of a function call."""
        # Delegate to the generic implementation of the expression checker.
        checked_args, ret_ty, inst = synthesize_call(self.ty, args, node, ctx)
        call = GlobalCall(def_id=self.id, args=checked_args, type_args=inst)
        located = with_loc(node, call)
        return with_type(ret_ty, located), ret_ty

    def compile_outer(
        self, module: DefinitionBuilder[OpVar], ctx: CompilerContext
    ) -> "CompiledFunctionDecl":
        """Adds a Hugr `FuncDecl` node for this function to the Hugr.

        Only Hugr modules may contain declarations; other builders trip an assertion.
        """
        assert isinstance(
            module, hf.Module
        ), "Functions can only be declared in modules"

        poly_hugr_ty = self.ty.to_hugr_poly(ctx)
        decl_node = module.declare_function(self.name, poly_hugr_ty)
        return CompiledFunctionDecl(
            self.id,
            self.name,
            self.defined_at,
            self.ty,
            self.python_func,
            self.docstring,
            decl_node,
        )
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
@dataclass(frozen=True)
class CompiledFunctionDecl(
    CheckedFunctionDecl, CompiledCallableDef, CompiledHugrNodeDef
):
    """A function declaration together with its Hugr declaration node."""

    # Node returned by `module.declare_function` in `compile_outer`.
    declaration: Node

    @property
    def hugr_node(self) -> Node:
        """The Hugr node this definition was compiled into."""
        return self.declaration

    def load_with_args(
        self,
        type_args: Inst,
        dfg: DFContainer,
        ctx: CompilerContext,
        node: AstNode,
    ) -> Wire:
        """Loads the function as a value into a local Hugr dataflow graph."""
        # Reuse the module-level helper shared with function definitions.
        decl = self.declaration
        return load_with_args(type_args, dfg, self.ty, decl)

    def compile_call(
        self,
        args: list[Wire],
        type_args: Inst,
        dfg: DFContainer,
        ctx: CompilerContext,
        node: AstNode,
    ) -> CallReturnWires:
        """Compiles a call to the function."""
        # Reuse the module-level helper shared with function definitions.
        decl = self.declaration
        return compile_call(args, type_args, dfg, self.ty, decl)
|