guppylang-internals 0.21.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (98)
  1. guppylang_internals/__init__.py +3 -0
  2. guppylang_internals/ast_util.py +350 -0
  3. guppylang_internals/cfg/__init__.py +0 -0
  4. guppylang_internals/cfg/analysis.py +230 -0
  5. guppylang_internals/cfg/bb.py +221 -0
  6. guppylang_internals/cfg/builder.py +606 -0
  7. guppylang_internals/cfg/cfg.py +117 -0
  8. guppylang_internals/checker/__init__.py +0 -0
  9. guppylang_internals/checker/cfg_checker.py +388 -0
  10. guppylang_internals/checker/core.py +550 -0
  11. guppylang_internals/checker/errors/__init__.py +0 -0
  12. guppylang_internals/checker/errors/comptime_errors.py +106 -0
  13. guppylang_internals/checker/errors/generic.py +45 -0
  14. guppylang_internals/checker/errors/linearity.py +300 -0
  15. guppylang_internals/checker/errors/type_errors.py +344 -0
  16. guppylang_internals/checker/errors/wasm.py +34 -0
  17. guppylang_internals/checker/expr_checker.py +1413 -0
  18. guppylang_internals/checker/func_checker.py +269 -0
  19. guppylang_internals/checker/linearity_checker.py +821 -0
  20. guppylang_internals/checker/stmt_checker.py +447 -0
  21. guppylang_internals/compiler/__init__.py +0 -0
  22. guppylang_internals/compiler/cfg_compiler.py +233 -0
  23. guppylang_internals/compiler/core.py +613 -0
  24. guppylang_internals/compiler/expr_compiler.py +989 -0
  25. guppylang_internals/compiler/func_compiler.py +97 -0
  26. guppylang_internals/compiler/hugr_extension.py +224 -0
  27. guppylang_internals/compiler/qtm_platform_extension.py +0 -0
  28. guppylang_internals/compiler/stmt_compiler.py +212 -0
  29. guppylang_internals/decorator.py +246 -0
  30. guppylang_internals/definition/__init__.py +0 -0
  31. guppylang_internals/definition/common.py +214 -0
  32. guppylang_internals/definition/const.py +74 -0
  33. guppylang_internals/definition/custom.py +492 -0
  34. guppylang_internals/definition/declaration.py +171 -0
  35. guppylang_internals/definition/extern.py +89 -0
  36. guppylang_internals/definition/function.py +302 -0
  37. guppylang_internals/definition/overloaded.py +150 -0
  38. guppylang_internals/definition/parameter.py +82 -0
  39. guppylang_internals/definition/pytket_circuits.py +405 -0
  40. guppylang_internals/definition/struct.py +392 -0
  41. guppylang_internals/definition/traced.py +151 -0
  42. guppylang_internals/definition/ty.py +51 -0
  43. guppylang_internals/definition/value.py +115 -0
  44. guppylang_internals/definition/wasm.py +61 -0
  45. guppylang_internals/diagnostic.py +523 -0
  46. guppylang_internals/dummy_decorator.py +76 -0
  47. guppylang_internals/engine.py +295 -0
  48. guppylang_internals/error.py +107 -0
  49. guppylang_internals/experimental.py +92 -0
  50. guppylang_internals/ipython_inspect.py +28 -0
  51. guppylang_internals/nodes.py +427 -0
  52. guppylang_internals/py.typed +0 -0
  53. guppylang_internals/span.py +150 -0
  54. guppylang_internals/std/__init__.py +0 -0
  55. guppylang_internals/std/_internal/__init__.py +0 -0
  56. guppylang_internals/std/_internal/checker.py +573 -0
  57. guppylang_internals/std/_internal/compiler/__init__.py +0 -0
  58. guppylang_internals/std/_internal/compiler/arithmetic.py +136 -0
  59. guppylang_internals/std/_internal/compiler/array.py +569 -0
  60. guppylang_internals/std/_internal/compiler/either.py +131 -0
  61. guppylang_internals/std/_internal/compiler/frozenarray.py +68 -0
  62. guppylang_internals/std/_internal/compiler/futures.py +30 -0
  63. guppylang_internals/std/_internal/compiler/list.py +348 -0
  64. guppylang_internals/std/_internal/compiler/mem.py +13 -0
  65. guppylang_internals/std/_internal/compiler/option.py +78 -0
  66. guppylang_internals/std/_internal/compiler/prelude.py +271 -0
  67. guppylang_internals/std/_internal/compiler/qsystem.py +48 -0
  68. guppylang_internals/std/_internal/compiler/quantum.py +118 -0
  69. guppylang_internals/std/_internal/compiler/tket_bool.py +55 -0
  70. guppylang_internals/std/_internal/compiler/tket_exts.py +59 -0
  71. guppylang_internals/std/_internal/compiler/wasm.py +135 -0
  72. guppylang_internals/std/_internal/compiler.py +0 -0
  73. guppylang_internals/std/_internal/debug.py +95 -0
  74. guppylang_internals/std/_internal/util.py +271 -0
  75. guppylang_internals/tracing/__init__.py +0 -0
  76. guppylang_internals/tracing/builtins_mock.py +62 -0
  77. guppylang_internals/tracing/frozenlist.py +57 -0
  78. guppylang_internals/tracing/function.py +186 -0
  79. guppylang_internals/tracing/object.py +551 -0
  80. guppylang_internals/tracing/state.py +69 -0
  81. guppylang_internals/tracing/unpacking.py +194 -0
  82. guppylang_internals/tracing/util.py +86 -0
  83. guppylang_internals/tys/__init__.py +0 -0
  84. guppylang_internals/tys/arg.py +115 -0
  85. guppylang_internals/tys/builtin.py +382 -0
  86. guppylang_internals/tys/common.py +110 -0
  87. guppylang_internals/tys/const.py +114 -0
  88. guppylang_internals/tys/errors.py +178 -0
  89. guppylang_internals/tys/param.py +251 -0
  90. guppylang_internals/tys/parsing.py +425 -0
  91. guppylang_internals/tys/printing.py +174 -0
  92. guppylang_internals/tys/subst.py +112 -0
  93. guppylang_internals/tys/ty.py +876 -0
  94. guppylang_internals/tys/var.py +49 -0
  95. guppylang_internals-0.21.0.dist-info/METADATA +253 -0
  96. guppylang_internals-0.21.0.dist-info/RECORD +98 -0
  97. guppylang_internals-0.21.0.dist-info/WHEEL +4 -0
  98. guppylang_internals-0.21.0.dist-info/licenses/LICENCE +201 -0
@@ -0,0 +1,492 @@
1
+ import ast
2
+ from abc import ABC, abstractmethod
3
+ from collections.abc import Callable
4
+ from dataclasses import dataclass, field
5
+ from typing import TYPE_CHECKING, ClassVar
6
+
7
+ from hugr import Wire, ops
8
+ from hugr import tys as ht
9
+ from hugr.build.dfg import DfBase
10
+
11
+ from guppylang_internals.ast_util import (
12
+ AstNode,
13
+ get_type,
14
+ has_empty_body,
15
+ with_loc,
16
+ with_type,
17
+ )
18
+ from guppylang_internals.checker.core import Context, Globals
19
+ from guppylang_internals.checker.expr_checker import check_call, synthesize_call
20
+ from guppylang_internals.checker.func_checker import check_signature
21
+ from guppylang_internals.compiler.core import (
22
+ CompilerContext,
23
+ DFContainer,
24
+ GlobalConstId,
25
+ partially_monomorphize_args,
26
+ )
27
+ from guppylang_internals.definition.common import ParsableDef
28
+ from guppylang_internals.definition.value import CallReturnWires, CompiledCallableDef
29
+ from guppylang_internals.diagnostic import Error, Help
30
+ from guppylang_internals.error import GuppyError, InternalGuppyError
31
+ from guppylang_internals.nodes import GlobalCall
32
+ from guppylang_internals.span import SourceMap
33
+ from guppylang_internals.std._internal.compiler.tket_bool import (
34
+ OpaqueBool,
35
+ make_opaque,
36
+ read_bool,
37
+ )
38
+ from guppylang_internals.tys.subst import Inst, Subst
39
+ from guppylang_internals.tys.ty import (
40
+ FuncInput,
41
+ FunctionType,
42
+ InputFlags,
43
+ NoneType,
44
+ Type,
45
+ type_to_row,
46
+ )
47
+
48
+ if TYPE_CHECKING:
49
+ from guppylang_internals.definition.function import PyFunc
50
+
51
+
52
@dataclass(frozen=True)
class BodyNotEmptyError(Error):
    """Error raised when a custom function is defined with a non-empty body."""

    # Name of the offending custom function, interpolated into `span_label`
    name: str

    title: ClassVar[str] = "Unexpected function body"
    span_label: ClassVar[str] = (
        "Body of custom function `{name}` must be empty"
    )
57
+
58
+
59
@dataclass(frozen=True)
class NoSignatureError(Error):
    """Error raised when a custom function needs a type signature but has none.

    A `Suggestion` help message is always attached as a sub-diagnostic.
    """

    # Name of the custom function lacking a signature
    name: str

    title: ClassVar[str] = "Type signature missing"
    span_label: ClassVar[str] = (
        "Custom function `{name}` requires a type signature"
    )

    @dataclass(frozen=True)
    class Suggestion(Help):
        """Help text explaining the two ways to resolve a missing signature."""

        message: ClassVar[str] = (
            "Annotate the type signature of `{name}` or disallow the use of "
            "`{name}` as a higher-order value: "
            "`@custom_function(..., higher_order_value=False)`"
        )

    def __post_init__(self) -> None:
        # Attach the help message without a dedicated source span
        suggestion = NoSignatureError.Suggestion(None)
        self.add_sub_diagnostic(suggestion)
74
+
75
+
76
@dataclass(frozen=True)
class NotHigherOrderError(Error):
    """Error raised when a function that forbids it is used as a first-class value."""

    # Name of the function that was used as a value
    name: str

    title: ClassVar[str] = "Not higher-order"
    span_label: ClassVar[str] = (
        "Function `{name}` may not be used "
        "as a higher-order value"
    )
83
+
84
+
85
@dataclass(frozen=True)
class RawCustomFunctionDef(ParsableDef):
    """A raw custom function definition provided by the user.

    Custom functions provide their own checking and compilation logic using a
    `CustomCallChecker` and a `CustomInoutCallCompiler`.

    The raw definition stores exactly what the user has written (i.e. the AST together
    with the provided checker and compiler), without inspecting the signature.

    Args:
        id: The unique definition identifier.
        name: The name of the definition.
        defined_at: The AST node where the definition was defined.
        python_func: The Python function whose AST is parsed by `parse`.
        call_checker: The custom call checker.
        call_compiler: The custom call compiler.
        higher_order_value: Whether the function may be used as a higher-order value.
        signature: User-provided signature. If given, it takes precedence over any
            type annotations on `python_func`.
    """

    python_func: "PyFunc"
    call_checker: "CustomCallChecker"
    call_compiler: "CustomInoutCallCompiler"

    # Whether the function may be used as a higher-order value. This is only possible
    # if a static type for the function is provided.
    higher_order_value: bool

    signature: FunctionType | None

    description: str = field(default="function", init=False)

    def parse(self, globals: "Globals", sources: SourceMap) -> "CustomFunctionDef":
        """Parses and checks the signature of the custom function.

        The signature is optional if custom type checking logic is provided by the user.
        However, note that a signature must be provided by either annotation or as an
        argument, if we want to use the function as a higher-order value. If a signature
        is provided as an argument, this will override any annotation.

        If no signature is provided, we fill in the dummy signature `() -> ()`. This
        type will never be inspected, since we rely on the provided custom checking
        code. The only information we need to access is that it's a function type and
        that there are no unsolved existential vars.

        Raises:
            GuppyError: If the function body is not empty, or if a required type
                signature is missing.
        """
        from guppylang_internals.definition.function import parse_py_func

        func_ast, _docstring = parse_py_func(self.python_func, sources)
        if not has_empty_body(func_ast):
            raise GuppyError(BodyNotEmptyError(func_ast.body[0], self.name))

        # An explicitly provided signature overrides any annotations on the function.
        # Compare against `None` explicitly instead of relying on object truthiness.
        sig = (
            self.signature
            if self.signature is not None
            else self._get_signature(func_ast, globals)
        )
        ty = sig if sig is not None else FunctionType([], NoneType())
        return CustomFunctionDef(
            self.id,
            self.name,
            func_ast,
            ty,
            self.call_checker,
            self.call_compiler,
            self.higher_order_value,
            GlobalConstId.fresh(self.name),
            sig is not None,
        )

    def _get_signature(
        self, node: ast.FunctionDef, globals: Globals
    ) -> FunctionType | None:
        """Returns the type of the function, if known.

        Type annotations are needed if we rely on the default call checker or
        want to allow the usage of the function as a higher-order value.

        Some function types like python's `int()` cannot be expressed in the Guppy
        type system, so we return `None` here and rely on the specialized compiler
        to handle the call.

        Raises:
            GuppyError: If a signature is required but neither the return type nor
                any positional argument is annotated.
        """
        requires_type_annotation = (
            isinstance(self.call_checker, DefaultCallChecker) or self.higher_order_value
        )
        if not requires_type_annotation:
            return None

        has_type_annotation = node.returns is not None or any(
            arg.annotation is not None for arg in node.args.args
        )
        if not has_type_annotation:
            raise GuppyError(NoSignatureError(node, self.name))
        return check_signature(node, globals)
175
+
176
+
177
@dataclass(frozen=True)
class CustomFunctionDef(CompiledCallableDef):
    """A custom function whose signature has been parsed and checked.

    Args:
        id: The unique definition identifier.
        name: The name of the definition.
        defined_at: The AST node where the definition was defined.
        ty: The type of the function. This may be a dummy value if `has_signature` is
            false.
        call_checker: The custom call checker.
        call_compiler: The custom call compiler.
        higher_order_value: Whether the function may be used as a higher-order value.
        higher_order_func_id: Global identifier of the wrapper `FunctionDef` emitted
            when the function is loaded as a value.
        has_signature: Whether the function has a declared signature.
    """

    defined_at: AstNode | None
    ty: FunctionType
    call_checker: "CustomCallChecker"
    call_compiler: "CustomInoutCallCompiler"
    higher_order_value: bool
    higher_order_func_id: GlobalConstId
    has_signature: bool

    description: str = field(default="function", init=False)

    def check_call(
        self, args: list[ast.expr], ty: Type, node: AstNode, ctx: Context
    ) -> tuple[ast.expr, Subst]:
        """Checks a call's return type against `ty` via the `CustomCallChecker`."""
        checker = self.call_checker
        checker._setup(ctx, node, self)
        checked_node, subst = checker.check(args, ty)
        located = with_loc(node, checked_node)
        return with_type(ty, located), subst

    def synthesize_call(
        self, args: list[ast.expr], node: AstNode, ctx: "Context"
    ) -> tuple[ast.expr, Type]:
        """Infers a call's return type via the `CustomCallChecker`."""
        checker = self.call_checker
        checker._setup(ctx, node, self)
        checked_node, ret_ty = checker.synthesize(args)
        located = with_loc(node, checked_node)
        return with_type(ret_ty, located), ret_ty

    def load_with_args(
        self,
        type_args: Inst,
        dfg: "DFContainer",
        ctx: CompilerContext,
        node: AstNode,
    ) -> Wire:
        """Loads the custom function as a value into a local dataflow graph.

        A wrapper `FunctionDef` that compiles a call to this function is declared
        (once per monomorphization) and then loaded with a `LoadFunc` node.

        Raises:
            GuppyError: If the function may not be used as a higher-order value.
        """
        # TODO: This should be raised during checking, not compilation!
        if not self.higher_order_value:
            raise GuppyError(NotHigherOrderError(node, self.name))
        params = self.ty.params
        assert len(params) == len(type_args)

        # Split the type arguments into those we monomorphize over and the rest
        mono_args, rem_args = partially_monomorphize_args(params, type_args, ctx)

        poly_hugr_ty = self.ty.instantiate_partial(mono_args).to_hugr_poly(ctx)
        func, already_defined = ctx.declare_global_func(
            self.higher_order_func_id, poly_hugr_ty, mono_args
        )
        if not already_defined:
            # Fill the wrapper's body: forward its inputs to a call of this function
            with ctx.set_monomorphized_args(mono_args):
                inner_dfg = DFContainer(func, ctx, dfg.locals.copy())
                inputs: list[Wire] = list(func.inputs())
                bound_args = [param.to_bound() for param in params]
                call_result = self.compile_call(
                    inputs, bound_args, inner_dfg, ctx, node
                )
                func.set_outputs(
                    *call_result.regular_returns, *call_result.inout_returns
                )

        hugr_mono_ty = self.ty.instantiate(type_args).to_hugr(ctx)
        remaining_hugr_args = [arg.to_hugr(ctx) for arg in rem_args]
        return dfg.builder.load_function(func, hugr_mono_ty, remaining_hugr_args)

    def compile_call(
        self,
        args: list[Wire],
        type_args: Inst,
        dfg: "DFContainer",
        ctx: CompilerContext,
        node: AstNode,
    ) -> CallReturnWires:
        """Compiles a call to the function using the `CustomInoutCallCompiler`."""
        if not self.has_signature:
            # No declared signature: reconstruct the concrete type from the types
            # annotated on the checked call node.
            assert isinstance(node, GlobalCall)
            inputs = [
                FuncInput(get_type(arg), InputFlags.NoFlags) for arg in node.args
            ]
            concrete_ty = FunctionType(inputs, get_type(node))
        else:
            concrete_ty = self.ty.instantiate(type_args)

        compiler = self.call_compiler
        compiler._setup(type_args, dfg, ctx, node, concrete_ty.to_hugr(ctx), self)
        return compiler.compile_with_inouts(args)
289
+
290
+
291
class CustomCallChecker(ABC):
    """Base class for user-supplied type checkers of custom function calls.

    Subclasses must implement `synthesize`. `check` has a default implementation
    that synthesizes a type and then compares it against the expected one.
    """

    # Set by `_setup` before `check`/`synthesize` is invoked
    ctx: Context
    node: AstNode
    func: CustomFunctionDef

    def _setup(self, ctx: Context, node: AstNode, func: CustomFunctionDef) -> None:
        """Records the checking context for the call about to be checked."""
        self.ctx, self.node, self.func = ctx, node, func

    def check(self, args: list[ast.expr], ty: Type) -> tuple[ast.expr, Subst]:
        """Checks the return value against a given type.

        Returns a (possibly) transformed and annotated AST node for the call together
        with the substitution produced by the check.
        """
        from guppylang_internals.checker.expr_checker import check_type_against

        synthesized, actual_ty = self.synthesize(args)
        checked, subst, _inst = check_type_against(
            actual_ty, ty, synthesized, self.ctx
        )
        return checked, subst

    @abstractmethod
    def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, Type]:
        """Synthesizes a type for the return value of a call.

        Returns a (possibly) transformed and annotated AST node for the call together
        with the synthesized return type.
        """
320
+
321
+
322
class CustomInoutCallCompiler(ABC):
    """Abstract base class for custom function call compilers with borrowed args.

    The attributes below are populated by `_setup` before `compile_with_inouts` is
    invoked (see `CustomFunctionDef.compile_call`).

    Attributes:
        dfg: The dataflow container in which the call should be compiled.
        type_args: The type arguments for the function.
        ctx: The compiler context.
        node: The AST node that triggered the compilation (the call, or the load of
            the function as a value).
        ty: The monomorphic Hugr function type of the call.
        func: The custom function being compiled, if any.
    """

    dfg: DFContainer
    type_args: Inst
    ctx: CompilerContext
    node: AstNode
    ty: ht.FunctionType
    func: CustomFunctionDef | None

    def _setup(
        self,
        type_args: Inst,
        dfg: DFContainer,
        ctx: CompilerContext,
        node: AstNode,
        hugr_ty: ht.FunctionType,
        func: CustomFunctionDef | None,
    ) -> None:
        """Records the compilation context for the call about to be compiled."""
        self.type_args = type_args
        self.dfg = dfg
        self.ctx = ctx
        self.node = node
        self.ty = hugr_ty
        self.func = func

    @abstractmethod
    def compile_with_inouts(self, args: list[Wire]) -> CallReturnWires:
        """Compiles a custom function call.

        Returns the outputs of the call together with any borrowed arguments that are
        passed through the function.
        """

    @property
    def builder(self) -> DfBase[ops.DfParentOp]:
        """The hugr dataflow builder of `self.dfg`."""
        return self.dfg.builder
368
+
369
+
370
class CustomCallCompiler(CustomInoutCallCompiler, ABC):
    """Base class for custom call compilers whose arguments are all owned.

    Subclasses implement `compile`; there are never any borrowed returns.
    """

    @abstractmethod
    def compile(self, args: list[Wire]) -> list[Wire]:
        """Compiles a custom function call and returns the resulting ports.

        Use the provided `self.builder` to add nodes to the Hugr graph.
        """

    def compile_with_inouts(self, args: list[Wire]) -> CallReturnWires:
        """Adapts `compile` to the inout interface with no borrowed returns."""
        regular = self.compile(args)
        return CallReturnWires(regular_returns=regular, inout_returns=[])
382
+
383
+
384
class DefaultCallChecker(CustomCallChecker):
    """Checks calls against the function's declared type signature.

    Both methods delegate to the generic call checking of the expression checker
    and produce a `GlobalCall` node.
    """

    def check(self, args: list[ast.expr], ty: Type) -> tuple[ast.expr, Subst]:
        checked_args, subst, inst = check_call(
            self.func.ty, args, ty, self.node, self.ctx
        )
        call = GlobalCall(def_id=self.func.id, args=checked_args, type_args=inst)
        return call, subst

    def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, Type]:
        checked_args, ret_ty, inst = synthesize_call(
            self.func.ty, args, self.node, self.ctx
        )
        call = GlobalCall(def_id=self.func.id, args=checked_args, type_args=inst)
        return call, ret_ty
396
+
397
+
398
class NotImplementedCallCompiler(CustomCallCompiler):
    """Call compiler for custom functions that are fully lowered during checking.

    A custom checker may, for instance, rewrite the call into calls to other
    functions, so the original function never reaches compilation. Reaching
    `compile` therefore indicates an internal error.
    """

    def compile(self, args: list[Wire]) -> list[Wire]:
        """Always raises `InternalGuppyError`."""
        message = "Function should have been removed during checking"
        raise InternalGuppyError(message)
408
+
409
+
410
class OpCompiler(CustomInoutCallCompiler):
    """Call compiler for functions that are directly implemented via Hugr ops.

    Args:
        op: A function that takes the monomorphic Hugr function type, an
            instantiation of the type arguments, and the compiler context, and
            returns a concrete HUGR op.
    """

    op: Callable[[ht.FunctionType, Inst, CompilerContext], ops.DataflowOp]

    def __init__(
        self, op: Callable[[ht.FunctionType, Inst, CompilerContext], ops.DataflowOp]
    ) -> None:
        self.op = op

    def compile_with_inouts(self, args: list[Wire]) -> CallReturnWires:
        """Adds the op and splits its outputs into regular and borrowed returns."""
        op = self.op(self.ty, self.type_args, self.ctx)
        node = self.builder.add_op(op, *args)
        if self.func is not None and self.func.has_signature:
            # The op may have extra outputs for borrowed arguments. These are split
            # off below, so count only the outputs of the declared Guppy return type.
            num_returns = len(type_to_row(self.func.ty.output))
        else:
            # Without a declared signature, `self.func.ty` is only a dummy `() -> ()`
            # placeholder. `self.ty` was then derived from the call node with no
            # borrowed inputs, so every output is a regular return.
            num_returns = len(self.ty.output)
        return CallReturnWires(
            regular_returns=list(node[:num_returns]),
            inout_returns=list(node[num_returns:]),
        )
435
+
436
+
437
class BoolOpCompiler(CustomInoutCallCompiler):
    """Call compiler for Hugr-op-backed functions that need bool conversions.

    The op works with Hugr sum bools, while Guppy uses opaque bools. Opaque-bool
    inputs are therefore read into sum bools before the op, and sum-bool outputs
    are made opaque afterwards.

    Args:
        op: A function that takes the monomorphic Hugr function type, an
            instantiation of the type arguments, and the compiler context, and
            returns a concrete HUGR op.
    """

    op: Callable[[ht.FunctionType, Inst, CompilerContext], ops.DataflowOp]

    def __init__(
        self, op: Callable[[ht.FunctionType, Inst, CompilerContext], ops.DataflowOp]
    ) -> None:
        self.op = op

    def compile_with_inouts(self, args: list[Wire]) -> CallReturnWires:
        """Adds the op with bool conversions around it; there are no inout returns."""

        def as_sum_bool(t: ht.Type) -> ht.Type:
            # Swap the opaque bool for the Hugr sum bool in the op's signature
            return ht.Bool if t == OpaqueBool else t

        hugr_op_ty = ht.FunctionType(
            [as_sum_bool(t) for t in self.ty.input],
            [as_sum_bool(t) for t in self.ty.output],
        )
        op = self.op(hugr_op_ty, self.type_args, self.ctx)

        op_inputs = []
        for arg in args:
            if self.builder.hugr.port_type(arg.out_port()) == OpaqueBool:
                arg = self.builder.add_op(read_bool(), arg)
            op_inputs.append(arg)
        op_node = self.builder.add_op(op, *op_inputs)

        outputs = []
        for out in list(op_node.outputs()):
            if self.builder.hugr.port_type(out.out_port()) == ht.Bool:
                out = self.builder.add_op(make_opaque(), out)
            outputs.append(out)
        return CallReturnWires(regular_returns=outputs, inout_returns=[])
479
+
480
+
481
class NoopCompiler(CustomCallCompiler):
    """Call compiler for identity functions: inputs are returned unchanged."""

    def compile(self, args: list[Wire]) -> list[Wire]:
        """Returns the argument wires themselves (the same list object)."""
        outputs = args
        return outputs
486
+
487
+
488
class CopyInoutCompiler(CustomInoutCallCompiler):
    """Call compiler for identity functions that only borrow their arguments.

    The argument wires are reported both as regular returns and as the returned
    borrowed arguments.
    """

    def compile_with_inouts(self, args: list[Wire]) -> CallReturnWires:
        passthrough = args
        return CallReturnWires(regular_returns=passthrough, inout_returns=passthrough)
@@ -0,0 +1,171 @@
1
+ import ast
2
+ from dataclasses import dataclass, field
3
+ from typing import ClassVar
4
+
5
+ from hugr import Node, Wire
6
+ from hugr.build import function as hf
7
+ from hugr.build.dfg import DefinitionBuilder, OpVar
8
+
9
+ from guppylang_internals.ast_util import AstNode, has_empty_body, with_loc, with_type
10
+ from guppylang_internals.checker.core import Context, Globals
11
+ from guppylang_internals.checker.expr_checker import check_call, synthesize_call
12
+ from guppylang_internals.checker.func_checker import check_signature
13
+ from guppylang_internals.compiler.core import (
14
+ CompilerContext,
15
+ DFContainer,
16
+ requires_monomorphization,
17
+ )
18
+ from guppylang_internals.definition.common import CompilableDef, ParsableDef
19
+ from guppylang_internals.definition.function import (
20
+ PyFunc,
21
+ compile_call,
22
+ load_with_args,
23
+ parse_py_func,
24
+ )
25
+ from guppylang_internals.definition.value import (
26
+ CallableDef,
27
+ CallReturnWires,
28
+ CompiledCallableDef,
29
+ CompiledHugrNodeDef,
30
+ )
31
+ from guppylang_internals.diagnostic import Error
32
+ from guppylang_internals.error import GuppyError
33
+ from guppylang_internals.nodes import GlobalCall
34
+ from guppylang_internals.span import SourceMap
35
+ from guppylang_internals.tys.param import Parameter
36
+ from guppylang_internals.tys.subst import Inst, Subst
37
+ from guppylang_internals.tys.ty import Type
38
+
39
+
40
@dataclass(frozen=True)
class BodyNotEmptyError(Error):
    """Error raised when a function declaration has a non-empty body."""

    # Name of the offending declared function
    name: str

    title: ClassVar[str] = "Unexpected function body"
    span_label: ClassVar[str] = (
        "Body of declared function `{name}` must be empty"
    )
45
+
46
+
47
@dataclass(frozen=True)
class MonomorphizeError(Error):
    """Error raised when a declaration is generic over a parameter that would
    require monomorphization."""

    # Name of the declared function
    name: str
    # The generic parameter that is not allowed
    param: Parameter

    title: ClassVar[str] = "Invalid function declaration"
    span_label: ClassVar[str] = (
        "Function declaration `{name}` is not allowed "
        "to be generic over `{param}`"
    )
55
+
56
+
57
@dataclass(frozen=True)
class RawFunctionDecl(ParsableDef):
    """A function declaration exactly as written by the user.

    Only the Python function is stored; its AST is neither parsed nor checked
    until `parse` is called.
    """

    python_func: PyFunc
    description: str = field(default="function", init=False)

    def parse(self, globals: Globals, sources: SourceMap) -> "CheckedFunctionDecl":
        """Parses the function and checks its user-provided signature.

        Raises:
            GuppyError: If the body is not empty, or if the signature is generic over
                a parameter that would require monomorphization.
        """
        func_ast, docstring = parse_py_func(self.python_func, sources)
        ty = check_signature(func_ast, globals)
        if not has_empty_body(func_ast):
            raise GuppyError(BodyNotEmptyError(func_ast.body[0], self.name))

        # Declarations must be compilable without monomorphization, so reject the
        # first parameter that would require it
        offending = next(
            (param for param in ty.params if requires_monomorphization(param)), None
        )
        if offending is not None:
            raise GuppyError(MonomorphizeError(func_ast, self.name, offending))

        return CheckedFunctionDecl(
            self.id, self.name, func_ast, ty, self.python_func, docstring
        )
86
+
87
+
88
@dataclass(frozen=True)
class CheckedFunctionDecl(RawFunctionDecl, CompilableDef, CallableDef):
    """A function declaration whose signature has been parsed and checked.

    In particular, a type has been determined for the function.
    """

    defined_at: ast.FunctionDef
    docstring: str | None

    def check_call(
        self, args: list[ast.expr], ty: Type, node: AstNode, ctx: Context
    ) -> tuple[ast.expr, Subst]:
        """Checks the return type of a call against `ty` using the default checker."""
        checked_args, subst, inst = check_call(self.ty, args, ty, node, ctx)
        call = GlobalCall(def_id=self.id, args=checked_args, type_args=inst)
        return with_loc(node, call), subst

    def synthesize_call(
        self, args: list[ast.expr], node: AstNode, ctx: Context
    ) -> tuple[GlobalCall, Type]:
        """Infers the return type of a call using the default checker."""
        checked_args, ret_ty, inst = synthesize_call(self.ty, args, node, ctx)
        call = GlobalCall(def_id=self.id, args=checked_args, type_args=inst)
        located = with_loc(node, call)
        return with_type(ret_ty, located), ret_ty

    def compile_outer(
        self, module: DefinitionBuilder[OpVar], ctx: CompilerContext
    ) -> "CompiledFunctionDecl":
        """Adds a Hugr `FuncDecl` node for this function to the module."""
        assert isinstance(
            module, hf.Module
        ), "Functions can only be declared in modules"

        decl_node = module.declare_function(self.name, self.ty.to_hugr_poly(ctx))
        return CompiledFunctionDecl(
            self.id,
            self.name,
            self.defined_at,
            self.ty,
            self.python_func,
            self.docstring,
            decl_node,
        )
135
+
136
+
137
@dataclass(frozen=True)
class CompiledFunctionDecl(
    CheckedFunctionDecl, CompiledCallableDef, CompiledHugrNodeDef
):
    """A function declaration backed by a Hugr `FuncDecl` node."""

    # The Hugr node created by `CheckedFunctionDecl.compile_outer`
    declaration: Node

    @property
    def hugr_node(self) -> Node:
        """The Hugr node this declaration was compiled into."""
        return self.declaration

    def load_with_args(
        self,
        type_args: Inst,
        dfg: DFContainer,
        ctx: CompilerContext,
        node: AstNode,
    ) -> Wire:
        """Loads the declared function as a value into a local dataflow graph.

        Delegates to the shared helper also used by function definitions.
        """
        decl = self.declaration
        return load_with_args(type_args, dfg, self.ty, decl)

    def compile_call(
        self,
        args: list[Wire],
        type_args: Inst,
        dfg: DFContainer,
        ctx: CompilerContext,
        node: AstNode,
    ) -> CallReturnWires:
        """Compiles a call to the declared function.

        Delegates to the shared helper also used by function definitions.
        """
        decl = self.declaration
        return compile_call(args, type_args, dfg, self.ty, decl)