guppylang-internals 0.21.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (98) hide show
  1. guppylang_internals/__init__.py +3 -0
  2. guppylang_internals/ast_util.py +350 -0
  3. guppylang_internals/cfg/__init__.py +0 -0
  4. guppylang_internals/cfg/analysis.py +230 -0
  5. guppylang_internals/cfg/bb.py +221 -0
  6. guppylang_internals/cfg/builder.py +606 -0
  7. guppylang_internals/cfg/cfg.py +117 -0
  8. guppylang_internals/checker/__init__.py +0 -0
  9. guppylang_internals/checker/cfg_checker.py +388 -0
  10. guppylang_internals/checker/core.py +550 -0
  11. guppylang_internals/checker/errors/__init__.py +0 -0
  12. guppylang_internals/checker/errors/comptime_errors.py +106 -0
  13. guppylang_internals/checker/errors/generic.py +45 -0
  14. guppylang_internals/checker/errors/linearity.py +300 -0
  15. guppylang_internals/checker/errors/type_errors.py +344 -0
  16. guppylang_internals/checker/errors/wasm.py +34 -0
  17. guppylang_internals/checker/expr_checker.py +1413 -0
  18. guppylang_internals/checker/func_checker.py +269 -0
  19. guppylang_internals/checker/linearity_checker.py +821 -0
  20. guppylang_internals/checker/stmt_checker.py +447 -0
  21. guppylang_internals/compiler/__init__.py +0 -0
  22. guppylang_internals/compiler/cfg_compiler.py +233 -0
  23. guppylang_internals/compiler/core.py +613 -0
  24. guppylang_internals/compiler/expr_compiler.py +989 -0
  25. guppylang_internals/compiler/func_compiler.py +97 -0
  26. guppylang_internals/compiler/hugr_extension.py +224 -0
  27. guppylang_internals/compiler/qtm_platform_extension.py +0 -0
  28. guppylang_internals/compiler/stmt_compiler.py +212 -0
  29. guppylang_internals/decorator.py +246 -0
  30. guppylang_internals/definition/__init__.py +0 -0
  31. guppylang_internals/definition/common.py +214 -0
  32. guppylang_internals/definition/const.py +74 -0
  33. guppylang_internals/definition/custom.py +492 -0
  34. guppylang_internals/definition/declaration.py +171 -0
  35. guppylang_internals/definition/extern.py +89 -0
  36. guppylang_internals/definition/function.py +302 -0
  37. guppylang_internals/definition/overloaded.py +150 -0
  38. guppylang_internals/definition/parameter.py +82 -0
  39. guppylang_internals/definition/pytket_circuits.py +405 -0
  40. guppylang_internals/definition/struct.py +392 -0
  41. guppylang_internals/definition/traced.py +151 -0
  42. guppylang_internals/definition/ty.py +51 -0
  43. guppylang_internals/definition/value.py +115 -0
  44. guppylang_internals/definition/wasm.py +61 -0
  45. guppylang_internals/diagnostic.py +523 -0
  46. guppylang_internals/dummy_decorator.py +76 -0
  47. guppylang_internals/engine.py +295 -0
  48. guppylang_internals/error.py +107 -0
  49. guppylang_internals/experimental.py +92 -0
  50. guppylang_internals/ipython_inspect.py +28 -0
  51. guppylang_internals/nodes.py +427 -0
  52. guppylang_internals/py.typed +0 -0
  53. guppylang_internals/span.py +150 -0
  54. guppylang_internals/std/__init__.py +0 -0
  55. guppylang_internals/std/_internal/__init__.py +0 -0
  56. guppylang_internals/std/_internal/checker.py +573 -0
  57. guppylang_internals/std/_internal/compiler/__init__.py +0 -0
  58. guppylang_internals/std/_internal/compiler/arithmetic.py +136 -0
  59. guppylang_internals/std/_internal/compiler/array.py +569 -0
  60. guppylang_internals/std/_internal/compiler/either.py +131 -0
  61. guppylang_internals/std/_internal/compiler/frozenarray.py +68 -0
  62. guppylang_internals/std/_internal/compiler/futures.py +30 -0
  63. guppylang_internals/std/_internal/compiler/list.py +348 -0
  64. guppylang_internals/std/_internal/compiler/mem.py +13 -0
  65. guppylang_internals/std/_internal/compiler/option.py +78 -0
  66. guppylang_internals/std/_internal/compiler/prelude.py +271 -0
  67. guppylang_internals/std/_internal/compiler/qsystem.py +48 -0
  68. guppylang_internals/std/_internal/compiler/quantum.py +118 -0
  69. guppylang_internals/std/_internal/compiler/tket_bool.py +55 -0
  70. guppylang_internals/std/_internal/compiler/tket_exts.py +59 -0
  71. guppylang_internals/std/_internal/compiler/wasm.py +135 -0
  72. guppylang_internals/std/_internal/compiler.py +0 -0
  73. guppylang_internals/std/_internal/debug.py +95 -0
  74. guppylang_internals/std/_internal/util.py +271 -0
  75. guppylang_internals/tracing/__init__.py +0 -0
  76. guppylang_internals/tracing/builtins_mock.py +62 -0
  77. guppylang_internals/tracing/frozenlist.py +57 -0
  78. guppylang_internals/tracing/function.py +186 -0
  79. guppylang_internals/tracing/object.py +551 -0
  80. guppylang_internals/tracing/state.py +69 -0
  81. guppylang_internals/tracing/unpacking.py +194 -0
  82. guppylang_internals/tracing/util.py +86 -0
  83. guppylang_internals/tys/__init__.py +0 -0
  84. guppylang_internals/tys/arg.py +115 -0
  85. guppylang_internals/tys/builtin.py +382 -0
  86. guppylang_internals/tys/common.py +110 -0
  87. guppylang_internals/tys/const.py +114 -0
  88. guppylang_internals/tys/errors.py +178 -0
  89. guppylang_internals/tys/param.py +251 -0
  90. guppylang_internals/tys/parsing.py +425 -0
  91. guppylang_internals/tys/printing.py +174 -0
  92. guppylang_internals/tys/subst.py +112 -0
  93. guppylang_internals/tys/ty.py +876 -0
  94. guppylang_internals/tys/var.py +49 -0
  95. guppylang_internals-0.21.0.dist-info/METADATA +253 -0
  96. guppylang_internals-0.21.0.dist-info/RECORD +98 -0
  97. guppylang_internals-0.21.0.dist-info/WHEEL +4 -0
  98. guppylang_internals-0.21.0.dist-info/licenses/LICENCE +201 -0
@@ -0,0 +1,392 @@
1
+ import ast
2
+ import inspect
3
+ import linecache
4
+ import sys
5
+ from collections.abc import Sequence
6
+ from dataclasses import dataclass
7
+ from types import FrameType
8
+ from typing import ClassVar
9
+
10
+ from hugr import Wire, ops
11
+
12
+ from guppylang_internals.ast_util import AstNode, annotate_location
13
+ from guppylang_internals.checker.core import Globals
14
+ from guppylang_internals.checker.errors.generic import (
15
+ ExpectedError,
16
+ UnexpectedError,
17
+ UnsupportedError,
18
+ )
19
+ from guppylang_internals.compiler.core import GlobalConstId
20
+ from guppylang_internals.definition.common import (
21
+ CheckableDef,
22
+ CompiledDef,
23
+ DefId,
24
+ ParsableDef,
25
+ UnknownSourceError,
26
+ )
27
+ from guppylang_internals.definition.custom import (
28
+ CustomCallCompiler,
29
+ CustomFunctionDef,
30
+ DefaultCallChecker,
31
+ )
32
+ from guppylang_internals.definition.function import parse_source
33
+ from guppylang_internals.definition.parameter import ParamDef
34
+ from guppylang_internals.definition.ty import TypeDef
35
+ from guppylang_internals.diagnostic import Error, Help, Note
36
+ from guppylang_internals.engine import DEF_STORE
37
+ from guppylang_internals.error import GuppyError, InternalGuppyError
38
+ from guppylang_internals.ipython_inspect import is_running_ipython
39
+ from guppylang_internals.span import SourceMap, Span, to_span
40
+ from guppylang_internals.tys.arg import Argument
41
+ from guppylang_internals.tys.param import Parameter, check_all_args
42
+ from guppylang_internals.tys.parsing import type_from_ast
43
+ from guppylang_internals.tys.ty import (
44
+ FuncInput,
45
+ FunctionType,
46
+ InputFlags,
47
+ StructType,
48
+ Type,
49
+ )
50
+
51
+ if sys.version_info >= (3, 12):
52
+ from guppylang_internals.tys.parsing import parse_parameter
53
+
54
+
55
@dataclass(frozen=True, eq=True)
class UncheckedStructField:
    """A struct field as written in the class body, before its type is checked.

    The annotation is kept as a raw AST expression; `ParsedStructDef.check` later
    turns it into a `Type`.
    """

    # Name of the field, taken from the annotated assignment target
    name: str
    # Unparsed type annotation of the field
    type_ast: ast.expr
61
+
62
+
63
@dataclass(frozen=True, eq=True)
class StructField:
    """A struct field whose type annotation has already been checked."""

    # Name of the field
    name: str
    # Checked Guppy type of the field
    ty: Type
69
+
70
+
71
@dataclass(frozen=True)
class RedundantParamsError(Error):
    """Error for a struct whose generic parameters are specified twice.

    Raised when a struct declares parameters both with the Python 3.12 syntax and
    via a `Generic[...]` base class.
    """

    span_label: ClassVar[str] = "Duplicate specification of generic parameters"
    title: ClassVar[str] = "Generic parameters already specified"
    # Name of the offending struct
    struct_name: str

    @dataclass(frozen=True)
    class PrevSpec(Note):
        """Note pointing at the first specification of the generic parameters."""

        span_label: ClassVar[str] = (
            "Parameters of `{struct_name}` are already specified here"
        )
82
+
83
+
84
@dataclass(frozen=True)
class DuplicateFieldError(Error):
    """Error for a struct member name that is already taken by another field."""

    span_label: ClassVar[str] = (
        "Struct `{struct_name}` already contains a field named `{field_name}`"
    )
    title: ClassVar[str] = "Duplicate field"
    # Name of the struct containing the clash
    struct_name: str
    # The member name that is used more than once
    field_name: str
92
+
93
+
94
@dataclass(frozen=True)
class NonGuppyMethodError(Error):
    """Error for a struct method that is not a Guppy function.

    A help sub-diagnostic suggesting the `@guppy` annotation is attached
    automatically on construction.
    """

    span_label: ClassVar[str] = (
        "Method `{method_name}` of struct `{struct_name}` is not a Guppy function"
    )
    title: ClassVar[str] = "Not a Guppy method"
    # Name of the struct that owns the method
    struct_name: str
    # Name of the offending method
    method_name: str

    @dataclass(frozen=True)
    class Suggestion(Help):
        """Help message suggesting how to turn the method into a Guppy method."""

        message: ClassVar[str] = (
            "Add a `@guppy` annotation to turn `{method_name}` into a Guppy method"
        )

    def __post_init__(self) -> None:
        # Always attach the fix-it suggestion; it has no span of its own
        suggestion = NonGuppyMethodError.Suggestion(None)
        self.add_sub_diagnostic(suggestion)
111
+
112
+
113
+ @dataclass(frozen=True)
114
+ class RawStructDef(TypeDef, ParsableDef):
115
+ """A raw struct type definition that has not been parsed yet."""
116
+
117
+ python_class: type
118
+
119
+ def parse(self, globals: Globals, sources: SourceMap) -> "ParsedStructDef":
120
+ """Parses the raw class object into an AST and checks that it is well-formed."""
121
+ frame = DEF_STORE.frames[self.id]
122
+ cls_def = parse_py_class(self.python_class, frame, sources)
123
+ if cls_def.keywords:
124
+ raise GuppyError(UnexpectedError(cls_def.keywords[0], "keyword"))
125
+
126
+ # Look for generic parameters from Python 3.12 style syntax
127
+ params = []
128
+ params_span: Span | None = None
129
+ if sys.version_info >= (3, 12):
130
+ if cls_def.type_params:
131
+ first, last = cls_def.type_params[0], cls_def.type_params[-1]
132
+ params_span = Span(to_span(first).start, to_span(last).end)
133
+ params = [
134
+ parse_parameter(node, idx, globals)
135
+ for idx, node in enumerate(cls_def.type_params)
136
+ ]
137
+
138
+ # The only base we allow is `Generic[...]` to specify generic parameters with
139
+ # the legacy syntax
140
+ match cls_def.bases:
141
+ case []:
142
+ pass
143
+ case [base] if elems := try_parse_generic_base(base):
144
+ # Complain if we already have Python 3.12 generic params
145
+ if params_span is not None:
146
+ err: Error = RedundantParamsError(base, self.name)
147
+ err.add_sub_diagnostic(RedundantParamsError.PrevSpec(params_span))
148
+ raise GuppyError(err)
149
+ params = params_from_ast(elems, globals)
150
+ case bases:
151
+ err = UnsupportedError(bases[0], "Struct inheritance", singular=True)
152
+ raise GuppyError(err)
153
+
154
+ fields: list[UncheckedStructField] = []
155
+ used_field_names: set[str] = set()
156
+ used_func_names: dict[str, ast.FunctionDef] = {}
157
+ for i, node in enumerate(cls_def.body):
158
+ match i, node:
159
+ # We allow `pass` statements to define empty structs
160
+ case _, ast.Pass():
161
+ pass
162
+ # Docstrings are also fine if they occur at the start
163
+ case 0, ast.Expr(value=ast.Constant(value=v)) if isinstance(v, str):
164
+ pass
165
+ # Ensure that all function definitions are Guppy functions
166
+ case _, ast.FunctionDef(name=name) as node:
167
+ from guppylang.defs import GuppyDefinition
168
+
169
+ v = getattr(self.python_class, name)
170
+ if not isinstance(v, GuppyDefinition):
171
+ raise GuppyError(NonGuppyMethodError(node, self.name, name))
172
+ used_func_names[name] = node
173
+ if name in used_field_names:
174
+ raise GuppyError(DuplicateFieldError(node, self.name, name))
175
+ # Struct fields are declared via annotated assignments without value
176
+ case _, ast.AnnAssign(target=ast.Name(id=field_name)) as node:
177
+ if node.value:
178
+ err = UnsupportedError(node.value, "Default struct values")
179
+ raise GuppyError(err)
180
+ if field_name in used_field_names:
181
+ err = DuplicateFieldError(node.target, self.name, field_name)
182
+ raise GuppyError(err)
183
+ fields.append(UncheckedStructField(field_name, node.annotation))
184
+ used_field_names.add(field_name)
185
+ case _, node:
186
+ err = UnexpectedError(
187
+ node, "statement", unexpected_in="struct definition"
188
+ )
189
+ raise GuppyError(err)
190
+
191
+ # Ensure that functions don't override struct fields
192
+ if overridden := used_field_names.intersection(used_func_names.keys()):
193
+ x = overridden.pop()
194
+ raise GuppyError(DuplicateFieldError(used_func_names[x], self.name, x))
195
+
196
+ return ParsedStructDef(self.id, self.name, cls_def, params, fields)
197
+
198
+ def check_instantiate(
199
+ self, args: Sequence[Argument], loc: AstNode | None = None
200
+ ) -> Type:
201
+ raise InternalGuppyError("Tried to instantiate raw struct definition")
202
+
203
+
204
@dataclass(frozen=True)
class ParsedStructDef(TypeDef, CheckableDef):
    """A struct definition whose field types have not been checked yet."""

    # Class definition AST the struct was parsed from
    defined_at: ast.ClassDef
    # Generic parameters of the struct
    params: Sequence[Parameter]
    # Fields with their still-unparsed type annotations
    fields: Sequence[UncheckedStructField]

    def check(self, globals: Globals) -> "CheckedStructDef":
        """Checks that all struct fields have valid types.

        Returns:
            A `CheckedStructDef` whose fields carry parsed Guppy types.

        Raises:
            GuppyError: If the struct is recursive. Errors raised while parsing
                the field annotations propagate unchanged.
        """
        # Field annotations may refer to the struct's generic parameters by name
        params_by_name = {param.name: param for param in self.params}

        # Recursion has to be ruled out before the field types are parsed below,
        # otherwise that parsing would not terminate.
        # TODO: This is not ideal (see todo in `check_instantiate`)
        check_not_recursive(self, globals, params_by_name)

        checked_fields = []
        for unchecked in self.fields:
            field_ty = type_from_ast(unchecked.type_ast, globals, params_by_name)
            checked_fields.append(StructField(unchecked.name, field_ty))

        return CheckedStructDef(
            self.id, self.name, self.defined_at, self.params, checked_fields
        )

    def check_instantiate(
        self, args: Sequence[Argument], loc: AstNode | None = None
    ) -> Type:
        """Checks if the struct can be instantiated with the given arguments.

        Returns a `StructType` backed by a freshly checked version of this
        definition.
        """
        check_all_args(self.params, args, self.name, loc)
        # TODO: This is quite bad: If we have a cyclic definition this will not
        #  terminate, so we have to check for cycles in every call to `check`. The
        #  proper way to deal with this is changing `StructType` such that it only
        #  takes a `DefId` instead of a `CheckedStructDef`. But this will be a bigger
        #  refactor...
        defining_globals = Globals(DEF_STORE.frames[self.id])
        return StructType(args, self.check(defining_globals))
243
+
244
+
245
@dataclass(frozen=True)
class CheckedStructDef(TypeDef, CompiledDef):
    """A struct definition whose field types have all been checked."""

    # Class definition AST the struct was parsed from
    defined_at: ast.ClassDef
    # Generic parameters of the struct
    params: Sequence[Parameter]
    # Fields with checked types, in declaration order
    fields: Sequence[StructField]

    def check_instantiate(
        self, args: Sequence[Argument], loc: AstNode | None = None
    ) -> Type:
        """Checks if the struct can be instantiated with the given arguments.

        Returns a `StructType` referring to this definition.
        """
        check_all_args(self.params, args, self.name, loc)
        return StructType(args, self)

    def generated_methods(self) -> list[CustomFunctionDef]:
        """Auto-generated methods for this struct.

        Currently this is only the `__new__` constructor, which takes one argument
        per field (in declaration order) and packs them with a Hugr `MakeTuple`.
        """

        class ConstructorCompiler(CustomCallCompiler):
            """Compiler for the `__new__` constructor method of a struct."""

            def compile(self, args: list[Wire]) -> list[Wire]:
                make_tuple = ops.MakeTuple()(*args)
                return list(self.builder.add(make_tuple))

        ctor_inputs: list[FuncInput] = []
        for fld in self.fields:
            # Linear field values are taken with ownership by the constructor
            flags = InputFlags.Owned if fld.ty.linear else InputFlags.NoFlags
            ctor_inputs.append(FuncInput(fld.ty, flags))
        ctor_input_names = [fld.name for fld in self.fields]

        # The constructor returns the struct instantiated with its own parameters
        bound_args = [param.to_bound(idx) for idx, param in enumerate(self.params)]
        ctor_output = StructType(defn=self, args=bound_args)

        constructor_sig = FunctionType(
            inputs=ctor_inputs,
            output=ctor_output,
            input_names=ctor_input_names,
            params=self.params,
        )
        return [
            CustomFunctionDef(
                id=DefId.fresh(),
                name="__new__",
                defined_at=self.defined_at,
                ty=constructor_sig,
                call_checker=DefaultCallChecker(),
                call_compiler=ConstructorCompiler(),
                higher_order_value=True,
                higher_order_func_id=GlobalConstId.fresh(f"{self.name}.__new__"),
                has_signature=True,
            )
        ]
292
+
293
+
294
def parse_py_class(
    cls: type, defining_frame: FrameType, sources: SourceMap
) -> ast.ClassDef:
    """Parses a Python class object into an AST.

    Args:
        cls: The class to parse. Must have `__firstlineno__` set.
        defining_frame: Frame the class was defined in. Used to locate the source
            file when running inside IPython.
        sources: Source map that the defining file is added to.

    Returns:
        The `ast.ClassDef` for `cls`, annotated with source locations.

    Raises:
        GuppyError: If the source of the class cannot be located, or if it does
            not parse to a class definition.
    """
    module = inspect.getmodule(cls)
    if module is None:
        raise GuppyError(UnknownSourceError(None, cls))

    # If we are running IPython, `inspect.getsourcefile` won't work if the class was
    # defined inside a cell. See
    # - https://bugs.python.org/issue33826
    # - https://github.com/ipython/ipython/issues/11249
    # - https://github.com/wandb/weave/pull/1864
    if is_running_ipython() and module.__name__ == "__main__":
        file: str | None = defining_frame.f_code.co_filename
    else:
        try:
            file = inspect.getsourcefile(cls)
        except (OSError, TypeError):
            # `inspect.getsourcefile` raises instead of returning `None` when the
            # class has no source file at all (e.g. defined in a plain REPL).
            # Report that as a user error rather than crashing.
            file = None
    if file is None:
        raise GuppyError(UnknownSourceError(None, cls))

    # We can't rely on `inspect.getsourcelines` since it doesn't work properly for
    # classes prior to Python 3.13. See https://github.com/CQCL/guppylang/issues/1107.
    # Instead, we reproduce the behaviour of Python >= 3.13 using the `__firstlineno__`
    # attribute. See https://github.com/python/cpython/blob/3.13/Lib/inspect.py#L1052.
    # In the decorator, we make sure that `__firstlineno__` is set, even if we're not
    # on Python 3.13.
    file_lines = linecache.getlines(file)
    line_offset = cls.__firstlineno__  # type: ignore[attr-defined]
    source_lines = inspect.getblock(file_lines[line_offset - 1 :])
    if not source_lines:
        # `linecache` could not read the file, or it no longer extends to the line
        # the class was defined on (e.g. the file changed on disk).
        raise GuppyError(UnknownSourceError(None, cls))
    source, cls_ast, line_offset = parse_source(source_lines, line_offset)

    # Store the source file in our cache
    sources.add_file(file)
    annotate_location(cls_ast, source, file, line_offset)
    if not isinstance(cls_ast, ast.ClassDef):
        raise GuppyError(ExpectedError(cls_ast, "a class definition"))
    return cls_ast
331
+
332
+
333
+ def try_parse_generic_base(node: ast.expr) -> list[ast.expr] | None:
334
+ """Checks if an AST node corresponds to a `Generic[T1, ..., Tn]` base class.
335
+
336
+ Returns the generic parameters or `None` if the AST has a different shape
337
+ """
338
+ match node:
339
+ case ast.Subscript(value=ast.Name(id="Generic"), slice=elem):
340
+ return elem.elts if isinstance(elem, ast.Tuple) else [elem]
341
+ case _:
342
+ return None
343
+
344
+
345
@dataclass(frozen=True)
class RepeatedTypeParamError(Error):
    """Error for a type parameter that is listed more than once in `Generic[...]`."""

    span_label: ClassVar[str] = "Type parameter `{name}` cannot be used multiple times"
    title: ClassVar[str] = "Duplicate type parameter"
    # Name of the repeated parameter
    name: str
350
+
351
+
352
def params_from_ast(nodes: Sequence[ast.expr], globals: Globals) -> list[Parameter]:
    """Parses a list of AST nodes into unique type parameters.

    Each node must be a name that resolves to a `ParamDef` in `globals`. The
    resulting parameters are numbered in the order they appear.

    Raises:
        GuppyError: If a node does not name a type parameter, or if the same
            parameter occurs multiple times.
    """

    def resolve(node: ast.expr) -> tuple[str, ParamDef]:
        # Resolve a single node to its parameter definition, or fail
        if isinstance(node, ast.Name) and node.id in globals:
            candidate = globals[node.id]
            if isinstance(candidate, ParamDef):
                return node.id, candidate
        raise GuppyError(ExpectedError(node, "a type parameter"))

    result: list[Parameter] = []
    seen_ids: set[DefId] = set()
    for node in nodes:
        param_name, param_def = resolve(node)
        if param_def.id in seen_ids:
            raise GuppyError(RepeatedTypeParamError(node, param_name))
        result.append(param_def.to_param(len(result)))
        seen_ids.add(param_def.id)
    return result
371
+
372
+
373
def check_not_recursive(
    defn: ParsedStructDef, globals: Globals, param_var_mapping: dict[str, Parameter]
) -> None:
    """Throws a user error if the given struct definition is recursive.

    While the field annotations are parsed, `defn.check_instantiate` is temporarily
    replaced by a stub that raises. Any field that (directly or indirectly) refers
    back to `defn` therefore triggers the error. The original method is restored
    afterwards, even if parsing raises.

    Raises:
        GuppyError: With an `UnsupportedError` ("Recursive structs") if the struct
            is recursive. Errors from parsing the field types propagate unchanged.
    """

    # TODO: The implementation below hijacks the type parsing logic to detect recursive
    #  structs. This is not great since it repeats the work done during checking. We can
    #  get rid of this after resolving the todo in `ParsedStructDef.check_instantiate()`

    def dummy_check_instantiate(
        args: Sequence[Argument],
        loc: AstNode | None = None,
    ) -> Type:
        raise GuppyError(UnsupportedError(loc, "Recursive structs"))

    original = defn.check_instantiate
    # `defn` is a frozen dataclass, so bypass its `__setattr__`
    object.__setattr__(defn, "check_instantiate", dummy_check_instantiate)
    try:
        for fld in defn.fields:
            type_from_ast(fld.type_ast, globals, param_var_mapping)
    finally:
        # Restore unconditionally: otherwise an error here (including the recursion
        # error itself) would leave `defn` permanently reporting "Recursive structs"
        object.__setattr__(defn, "check_instantiate", original)
@@ -0,0 +1,151 @@
1
+ import ast
2
+ from collections.abc import Callable
3
+ from dataclasses import dataclass, field
4
+ from typing import Any
5
+
6
+ import hugr.build.function as hf
7
+ import hugr.tys as ht
8
+ from hugr import Node, Wire
9
+ from hugr.build.dfg import DefinitionBuilder, OpVar
10
+
11
+ from guppylang_internals.ast_util import AstNode, with_loc
12
+ from guppylang_internals.checker.core import Context, Globals
13
+ from guppylang_internals.checker.errors.generic import UnsupportedError
14
+ from guppylang_internals.checker.expr_checker import (
15
+ check_call,
16
+ synthesize_call,
17
+ )
18
+ from guppylang_internals.checker.func_checker import (
19
+ check_signature,
20
+ )
21
+ from guppylang_internals.compiler.core import CompilerContext, DFContainer
22
+ from guppylang_internals.definition.common import (
23
+ CompilableDef,
24
+ ParsableDef,
25
+ )
26
+ from guppylang_internals.definition.function import parse_py_func
27
+ from guppylang_internals.definition.value import (
28
+ CallableDef,
29
+ CallReturnWires,
30
+ CompiledCallableDef,
31
+ CompiledHugrNodeDef,
32
+ )
33
+ from guppylang_internals.error import GuppyError
34
+ from guppylang_internals.nodes import GlobalCall
35
+ from guppylang_internals.span import SourceMap
36
+ from guppylang_internals.tys.subst import Inst, Subst
37
+ from guppylang_internals.tys.ty import FunctionType, Type, type_to_row
38
+
39
# Type alias for an arbitrary Python callable; used for the user-provided Python
# function backing a traced function definition.
PyFunc = Callable[..., Any]
40
+
41
+
42
@dataclass(frozen=True)
class RawTracedFunctionDef(ParsableDef):
    """A traced (comptime) function definition whose signature is not parsed yet."""

    # The user-provided Python function that will be traced
    python_func: PyFunc

    description: str = field(default="function", init=False)

    def parse(self, globals: Globals, sources: SourceMap) -> "TracedFunctionDef":
        """Parses and checks the user-provided signature of the function.

        Raises:
            GuppyError: If the signature is generic, since generic comptime
                functions are not supported.
        """
        func_ast, _ = parse_py_func(self.python_func, sources)
        signature = check_signature(func_ast, globals)
        if not signature.parametrized:
            return TracedFunctionDef(
                self.id, self.name, func_ast, signature, self.python_func
            )
        raise GuppyError(UnsupportedError(func_ast, "Generic comptime functions"))
55
+
56
+
57
@dataclass(frozen=True)
class TracedFunctionDef(RawTracedFunctionDef, CallableDef, CompilableDef):
    """A traced (comptime) function whose signature has been parsed and checked.

    Calls to it are checked with the default call-checking logic. Its body is only
    produced during compilation, by tracing `python_func`.
    """

    # The user-provided Python function that will be traced
    python_func: PyFunc
    # Checked (non-generic) signature of the function
    ty: FunctionType
    # AST of the Python function definition
    defined_at: ast.FunctionDef

    def check_call(
        self, args: list[ast.expr], ty: Type, node: AstNode, ctx: Context
    ) -> tuple[ast.expr, Subst]:
        """Checks the return type of a function call against a given type.

        Returns the call rewritten into a `GlobalCall` node, together with the
        substitution inferred while checking.
        """
        # Use default implementation from the expression checker
        args, subst, inst = check_call(self.ty, args, ty, node, ctx)
        node = with_loc(node, GlobalCall(def_id=self.id, args=args, type_args=inst))
        return node, subst

    def synthesize_call(
        self, args: list[ast.expr], node: AstNode, ctx: Context
    ) -> tuple[ast.expr, Type]:
        """Synthesizes the return type of a function call.

        Returns the call rewritten into a `GlobalCall` node, together with the
        synthesized return type.
        """
        # Use default implementation from the expression checker
        args, ty, inst = synthesize_call(self.ty, args, node, ctx)
        node = with_loc(node, GlobalCall(def_id=self.id, args=args, type_args=inst))
        return node, ty

    def compile_outer(
        self, module: DefinitionBuilder[OpVar], ctx: CompilerContext
    ) -> "CompiledTracedFunctionDef":
        """Adds a Hugr `FuncDefn` node for this function to the Hugr.

        Note that we don't compile the function body at this point since we don't have
        access to the other compiled functions yet. The body is compiled later in
        `CompiledTracedFunctionDef.compile_inner()`.
        """
        func_type = self.ty.to_hugr_poly(ctx)
        func_def = module.module_root_builder().define_function(
            self.name, func_type.body.input, func_type.body.output, func_type.params
        )
        return CompiledTracedFunctionDef(
            self.id,
            self.name,
            self.defined_at,
            self.ty,
            self.python_func,
            func_def,
        )
102
+
103
+
104
@dataclass(frozen=True)
class CompiledTracedFunctionDef(
    TracedFunctionDef, CompiledCallableDef, CompiledHugrNodeDef
):
    """A traced function definition that has been given a Hugr `FuncDefn` node.

    The node's body is filled in by `compile_inner`, which traces `python_func`.
    """

    # Builder for the Hugr function definition created in `compile_outer`
    func_def: hf.Function

    @property
    def hugr_node(self) -> Node:
        """The Hugr node this definition was compiled into."""
        return self.func_def.parent_node

    def load_with_args(
        self,
        type_args: Inst,
        dfg: DFContainer,
        ctx: CompilerContext,
        node: AstNode,
    ) -> Wire:
        """Loads the function as a value into a local Hugr dataflow graph."""
        func_ty: ht.FunctionType = self.ty.instantiate(type_args).to_hugr(ctx)
        # Separate name so the `type_args: Inst` parameter is not shadowed by a
        # differently-typed list of Hugr type arguments
        hugr_type_args: list[ht.TypeArg] = [arg.to_hugr(ctx) for arg in type_args]
        return dfg.builder.load_function(self.func_def, func_ty, hugr_type_args)

    def compile_call(
        self,
        args: list[Wire],
        type_args: Inst,
        dfg: DFContainer,
        ctx: CompilerContext,
        node: AstNode,
    ) -> CallReturnWires:
        """Compiles a call to the function.

        The first outputs of the Hugr call (as many as the declared return row) are
        the regular returns; any remaining outputs are the inout returns.
        """
        func_ty: ht.FunctionType = self.ty.instantiate(type_args).to_hugr(ctx)
        hugr_type_args: list[ht.TypeArg] = [arg.to_hugr(ctx) for arg in type_args]
        num_returns = len(type_to_row(self.ty.output))
        call = dfg.builder.call(
            self.func_def, *args, instantiation=func_ty, type_args=hugr_type_args
        )
        return CallReturnWires(
            regular_returns=list(call[:num_returns]),
            inout_returns=list(call[num_returns:]),
        )

    def compile_inner(self, ctx: CompilerContext) -> None:
        """Compiles the body of the function by tracing it."""
        from guppylang_internals.tracing.function import trace_function

        trace_function(self.python_func, self.ty, self.func_def, ctx, self.defined_at)
@@ -0,0 +1,51 @@
1
+ from abc import abstractmethod
2
+ from collections.abc import Callable, Sequence
3
+ from dataclasses import dataclass, field
4
+
5
+ from hugr import tys
6
+
7
+ from guppylang_internals.ast_util import AstNode
8
+ from guppylang_internals.definition.common import CompiledDef, Definition
9
+ from guppylang_internals.tys.arg import Argument
10
+ from guppylang_internals.tys.common import ToHugrContext
11
+ from guppylang_internals.tys.param import Parameter, check_all_args
12
+ from guppylang_internals.tys.ty import OpaqueType, Type
13
+
14
+
15
@dataclass(frozen=True)
class TypeDef(Definition):
    """Abstract base class for type definitions.

    Concrete subclasses implement `check_instantiate`, which turns a list of type
    arguments into a concrete `Type`.
    """

    # Kind of this definition; fixed to "type" and not settable via `__init__`
    description: str = field(init=False, default="type")

    @abstractmethod
    def check_instantiate(
        self, args: Sequence[Argument], loc: AstNode | None = None
    ) -> Type:
        """Checks if the type definition can be instantiated with the given arguments.

        Args:
            args: Type arguments to instantiate the definition with.
            loc: Optional AST node used to locate error messages.

        Returns the resulting concrete type or raises a user error if the arguments are
        invalid.
        """
30
+
31
+
32
@dataclass(frozen=True)
class OpaqueTypeDef(TypeDef, CompiledDef):
    """An opaque type definition that is backed by some Hugr type."""

    # Parameters that instantiations must provide arguments for
    params: Sequence[Parameter]
    never_copyable: bool
    never_droppable: bool
    # Callback producing the backing Hugr type from the type arguments
    to_hugr: Callable[[Sequence[Argument], ToHugrContext], tys.Type]
    # Optional Hugr type bound; `None` unless given explicitly
    bound: tys.TypeBound | None = None

    def check_instantiate(
        self, args: Sequence[Argument], loc: AstNode | None = None
    ) -> OpaqueType:
        """Checks if the type definition can be instantiated with the given arguments.

        Returns an `OpaqueType` referring to this definition. A user error is raised
        if the arguments do not match the parameters.
        """
        check_all_args(self.params, args, self.name, loc)
        instance = OpaqueType(args, self)
        return instance