guppylang-internals 0.21.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- guppylang_internals/__init__.py +3 -0
- guppylang_internals/ast_util.py +350 -0
- guppylang_internals/cfg/__init__.py +0 -0
- guppylang_internals/cfg/analysis.py +230 -0
- guppylang_internals/cfg/bb.py +221 -0
- guppylang_internals/cfg/builder.py +606 -0
- guppylang_internals/cfg/cfg.py +117 -0
- guppylang_internals/checker/__init__.py +0 -0
- guppylang_internals/checker/cfg_checker.py +388 -0
- guppylang_internals/checker/core.py +550 -0
- guppylang_internals/checker/errors/__init__.py +0 -0
- guppylang_internals/checker/errors/comptime_errors.py +106 -0
- guppylang_internals/checker/errors/generic.py +45 -0
- guppylang_internals/checker/errors/linearity.py +300 -0
- guppylang_internals/checker/errors/type_errors.py +344 -0
- guppylang_internals/checker/errors/wasm.py +34 -0
- guppylang_internals/checker/expr_checker.py +1413 -0
- guppylang_internals/checker/func_checker.py +269 -0
- guppylang_internals/checker/linearity_checker.py +821 -0
- guppylang_internals/checker/stmt_checker.py +447 -0
- guppylang_internals/compiler/__init__.py +0 -0
- guppylang_internals/compiler/cfg_compiler.py +233 -0
- guppylang_internals/compiler/core.py +613 -0
- guppylang_internals/compiler/expr_compiler.py +989 -0
- guppylang_internals/compiler/func_compiler.py +97 -0
- guppylang_internals/compiler/hugr_extension.py +224 -0
- guppylang_internals/compiler/qtm_platform_extension.py +0 -0
- guppylang_internals/compiler/stmt_compiler.py +212 -0
- guppylang_internals/decorator.py +246 -0
- guppylang_internals/definition/__init__.py +0 -0
- guppylang_internals/definition/common.py +214 -0
- guppylang_internals/definition/const.py +74 -0
- guppylang_internals/definition/custom.py +492 -0
- guppylang_internals/definition/declaration.py +171 -0
- guppylang_internals/definition/extern.py +89 -0
- guppylang_internals/definition/function.py +302 -0
- guppylang_internals/definition/overloaded.py +150 -0
- guppylang_internals/definition/parameter.py +82 -0
- guppylang_internals/definition/pytket_circuits.py +405 -0
- guppylang_internals/definition/struct.py +392 -0
- guppylang_internals/definition/traced.py +151 -0
- guppylang_internals/definition/ty.py +51 -0
- guppylang_internals/definition/value.py +115 -0
- guppylang_internals/definition/wasm.py +61 -0
- guppylang_internals/diagnostic.py +523 -0
- guppylang_internals/dummy_decorator.py +76 -0
- guppylang_internals/engine.py +295 -0
- guppylang_internals/error.py +107 -0
- guppylang_internals/experimental.py +92 -0
- guppylang_internals/ipython_inspect.py +28 -0
- guppylang_internals/nodes.py +427 -0
- guppylang_internals/py.typed +0 -0
- guppylang_internals/span.py +150 -0
- guppylang_internals/std/__init__.py +0 -0
- guppylang_internals/std/_internal/__init__.py +0 -0
- guppylang_internals/std/_internal/checker.py +573 -0
- guppylang_internals/std/_internal/compiler/__init__.py +0 -0
- guppylang_internals/std/_internal/compiler/arithmetic.py +136 -0
- guppylang_internals/std/_internal/compiler/array.py +569 -0
- guppylang_internals/std/_internal/compiler/either.py +131 -0
- guppylang_internals/std/_internal/compiler/frozenarray.py +68 -0
- guppylang_internals/std/_internal/compiler/futures.py +30 -0
- guppylang_internals/std/_internal/compiler/list.py +348 -0
- guppylang_internals/std/_internal/compiler/mem.py +13 -0
- guppylang_internals/std/_internal/compiler/option.py +78 -0
- guppylang_internals/std/_internal/compiler/prelude.py +271 -0
- guppylang_internals/std/_internal/compiler/qsystem.py +48 -0
- guppylang_internals/std/_internal/compiler/quantum.py +118 -0
- guppylang_internals/std/_internal/compiler/tket_bool.py +55 -0
- guppylang_internals/std/_internal/compiler/tket_exts.py +59 -0
- guppylang_internals/std/_internal/compiler/wasm.py +135 -0
- guppylang_internals/std/_internal/compiler.py +0 -0
- guppylang_internals/std/_internal/debug.py +95 -0
- guppylang_internals/std/_internal/util.py +271 -0
- guppylang_internals/tracing/__init__.py +0 -0
- guppylang_internals/tracing/builtins_mock.py +62 -0
- guppylang_internals/tracing/frozenlist.py +57 -0
- guppylang_internals/tracing/function.py +186 -0
- guppylang_internals/tracing/object.py +551 -0
- guppylang_internals/tracing/state.py +69 -0
- guppylang_internals/tracing/unpacking.py +194 -0
- guppylang_internals/tracing/util.py +86 -0
- guppylang_internals/tys/__init__.py +0 -0
- guppylang_internals/tys/arg.py +115 -0
- guppylang_internals/tys/builtin.py +382 -0
- guppylang_internals/tys/common.py +110 -0
- guppylang_internals/tys/const.py +114 -0
- guppylang_internals/tys/errors.py +178 -0
- guppylang_internals/tys/param.py +251 -0
- guppylang_internals/tys/parsing.py +425 -0
- guppylang_internals/tys/printing.py +174 -0
- guppylang_internals/tys/subst.py +112 -0
- guppylang_internals/tys/ty.py +876 -0
- guppylang_internals/tys/var.py +49 -0
- guppylang_internals-0.21.0.dist-info/METADATA +253 -0
- guppylang_internals-0.21.0.dist-info/RECORD +98 -0
- guppylang_internals-0.21.0.dist-info/WHEEL +4 -0
- guppylang_internals-0.21.0.dist-info/licenses/LICENCE +201 -0
|
@@ -0,0 +1,392 @@
|
|
|
1
|
+
import ast
|
|
2
|
+
import inspect
|
|
3
|
+
import linecache
|
|
4
|
+
import sys
|
|
5
|
+
from collections.abc import Sequence
|
|
6
|
+
from dataclasses import dataclass
|
|
7
|
+
from types import FrameType
|
|
8
|
+
from typing import ClassVar
|
|
9
|
+
|
|
10
|
+
from hugr import Wire, ops
|
|
11
|
+
|
|
12
|
+
from guppylang_internals.ast_util import AstNode, annotate_location
|
|
13
|
+
from guppylang_internals.checker.core import Globals
|
|
14
|
+
from guppylang_internals.checker.errors.generic import (
|
|
15
|
+
ExpectedError,
|
|
16
|
+
UnexpectedError,
|
|
17
|
+
UnsupportedError,
|
|
18
|
+
)
|
|
19
|
+
from guppylang_internals.compiler.core import GlobalConstId
|
|
20
|
+
from guppylang_internals.definition.common import (
|
|
21
|
+
CheckableDef,
|
|
22
|
+
CompiledDef,
|
|
23
|
+
DefId,
|
|
24
|
+
ParsableDef,
|
|
25
|
+
UnknownSourceError,
|
|
26
|
+
)
|
|
27
|
+
from guppylang_internals.definition.custom import (
|
|
28
|
+
CustomCallCompiler,
|
|
29
|
+
CustomFunctionDef,
|
|
30
|
+
DefaultCallChecker,
|
|
31
|
+
)
|
|
32
|
+
from guppylang_internals.definition.function import parse_source
|
|
33
|
+
from guppylang_internals.definition.parameter import ParamDef
|
|
34
|
+
from guppylang_internals.definition.ty import TypeDef
|
|
35
|
+
from guppylang_internals.diagnostic import Error, Help, Note
|
|
36
|
+
from guppylang_internals.engine import DEF_STORE
|
|
37
|
+
from guppylang_internals.error import GuppyError, InternalGuppyError
|
|
38
|
+
from guppylang_internals.ipython_inspect import is_running_ipython
|
|
39
|
+
from guppylang_internals.span import SourceMap, Span, to_span
|
|
40
|
+
from guppylang_internals.tys.arg import Argument
|
|
41
|
+
from guppylang_internals.tys.param import Parameter, check_all_args
|
|
42
|
+
from guppylang_internals.tys.parsing import type_from_ast
|
|
43
|
+
from guppylang_internals.tys.ty import (
|
|
44
|
+
FuncInput,
|
|
45
|
+
FunctionType,
|
|
46
|
+
InputFlags,
|
|
47
|
+
StructType,
|
|
48
|
+
Type,
|
|
49
|
+
)
|
|
50
|
+
|
|
51
|
+
if sys.version_info >= (3, 12):
|
|
52
|
+
from guppylang_internals.tys.parsing import parse_parameter
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
@dataclass(frozen=True)
class UncheckedStructField:
    """A struct field as written in the class body, before its type is resolved.

    The type is kept as the raw annotation expression; it is turned into a proper
    `Type` later when the struct definition is checked.
    """

    # Name of the field as declared in the class body.
    name: str
    # The unparsed annotation expression giving the field's type.
    type_ast: ast.expr
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
@dataclass(frozen=True)
class StructField:
    """A struct field whose type has been resolved and checked."""

    # Name of the field as declared in the class body.
    name: str
    # The checked Guppy type of the field.
    ty: Type
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
@dataclass(frozen=True)
class RedundantParamsError(Error):
    """Error for a struct whose generic parameters are declared twice.

    Raised when a struct already declares parameters via the Python 3.12
    `class S[T]: ...` syntax and additionally inherits from `Generic[...]`.
    """

    title: ClassVar[str] = "Generic parameters already specified"
    span_label: ClassVar[str] = "Duplicate specification of generic parameters"
    # Name of the offending struct, used when rendering the attached note.
    struct_name: str

    @dataclass(frozen=True)
    class PrevSpec(Note):
        """Note pointing at the first place where the parameters were given."""

        span_label: ClassVar[str] = (
            "Parameters of `{struct_name}` are already specified here"
        )
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
@dataclass(frozen=True)
class DuplicateFieldError(Error):
    """Error for a name that is declared more than once in a struct body.

    Used both for two fields with the same name and for a method whose name
    clashes with a field.
    """

    title: ClassVar[str] = "Duplicate field"
    span_label: ClassVar[str] = (
        "Struct `{struct_name}` already contains a field named `{field_name}`"
    )
    # Name of the struct containing the clash.
    struct_name: str
    # The name that is declared more than once.
    field_name: str
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
@dataclass(frozen=True)
class NonGuppyMethodError(Error):
    """Error for a function in a struct body that is not a Guppy function.

    Each instance automatically carries a help message suggesting the `@guppy`
    decorator.
    """

    title: ClassVar[str] = "Not a Guppy method"
    span_label: ClassVar[str] = (
        "Method `{method_name}` of struct `{struct_name}` is not a Guppy function"
    )
    # Name of the struct the method belongs to.
    struct_name: str
    # Name of the offending method.
    method_name: str

    @dataclass(frozen=True)
    class Suggestion(Help):
        """Help message recommending to decorate the method with `@guppy`."""

        message: ClassVar[str] = (
            "Add a `@guppy` annotation to turn `{method_name}` into a Guppy method"
        )

    def __post_init__(self) -> None:
        # Attach the suggestion to every instance. `None` is passed as the
        # sub-diagnostic's location argument (presumably "no separate span" --
        # TODO confirm against `Help`).
        suggestion = NonGuppyMethodError.Suggestion(None)
        self.add_sub_diagnostic(suggestion)
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
@dataclass(frozen=True)
|
|
114
|
+
class RawStructDef(TypeDef, ParsableDef):
|
|
115
|
+
"""A raw struct type definition that has not been parsed yet."""
|
|
116
|
+
|
|
117
|
+
python_class: type
|
|
118
|
+
|
|
119
|
+
def parse(self, globals: Globals, sources: SourceMap) -> "ParsedStructDef":
|
|
120
|
+
"""Parses the raw class object into an AST and checks that it is well-formed."""
|
|
121
|
+
frame = DEF_STORE.frames[self.id]
|
|
122
|
+
cls_def = parse_py_class(self.python_class, frame, sources)
|
|
123
|
+
if cls_def.keywords:
|
|
124
|
+
raise GuppyError(UnexpectedError(cls_def.keywords[0], "keyword"))
|
|
125
|
+
|
|
126
|
+
# Look for generic parameters from Python 3.12 style syntax
|
|
127
|
+
params = []
|
|
128
|
+
params_span: Span | None = None
|
|
129
|
+
if sys.version_info >= (3, 12):
|
|
130
|
+
if cls_def.type_params:
|
|
131
|
+
first, last = cls_def.type_params[0], cls_def.type_params[-1]
|
|
132
|
+
params_span = Span(to_span(first).start, to_span(last).end)
|
|
133
|
+
params = [
|
|
134
|
+
parse_parameter(node, idx, globals)
|
|
135
|
+
for idx, node in enumerate(cls_def.type_params)
|
|
136
|
+
]
|
|
137
|
+
|
|
138
|
+
# The only base we allow is `Generic[...]` to specify generic parameters with
|
|
139
|
+
# the legacy syntax
|
|
140
|
+
match cls_def.bases:
|
|
141
|
+
case []:
|
|
142
|
+
pass
|
|
143
|
+
case [base] if elems := try_parse_generic_base(base):
|
|
144
|
+
# Complain if we already have Python 3.12 generic params
|
|
145
|
+
if params_span is not None:
|
|
146
|
+
err: Error = RedundantParamsError(base, self.name)
|
|
147
|
+
err.add_sub_diagnostic(RedundantParamsError.PrevSpec(params_span))
|
|
148
|
+
raise GuppyError(err)
|
|
149
|
+
params = params_from_ast(elems, globals)
|
|
150
|
+
case bases:
|
|
151
|
+
err = UnsupportedError(bases[0], "Struct inheritance", singular=True)
|
|
152
|
+
raise GuppyError(err)
|
|
153
|
+
|
|
154
|
+
fields: list[UncheckedStructField] = []
|
|
155
|
+
used_field_names: set[str] = set()
|
|
156
|
+
used_func_names: dict[str, ast.FunctionDef] = {}
|
|
157
|
+
for i, node in enumerate(cls_def.body):
|
|
158
|
+
match i, node:
|
|
159
|
+
# We allow `pass` statements to define empty structs
|
|
160
|
+
case _, ast.Pass():
|
|
161
|
+
pass
|
|
162
|
+
# Docstrings are also fine if they occur at the start
|
|
163
|
+
case 0, ast.Expr(value=ast.Constant(value=v)) if isinstance(v, str):
|
|
164
|
+
pass
|
|
165
|
+
# Ensure that all function definitions are Guppy functions
|
|
166
|
+
case _, ast.FunctionDef(name=name) as node:
|
|
167
|
+
from guppylang.defs import GuppyDefinition
|
|
168
|
+
|
|
169
|
+
v = getattr(self.python_class, name)
|
|
170
|
+
if not isinstance(v, GuppyDefinition):
|
|
171
|
+
raise GuppyError(NonGuppyMethodError(node, self.name, name))
|
|
172
|
+
used_func_names[name] = node
|
|
173
|
+
if name in used_field_names:
|
|
174
|
+
raise GuppyError(DuplicateFieldError(node, self.name, name))
|
|
175
|
+
# Struct fields are declared via annotated assignments without value
|
|
176
|
+
case _, ast.AnnAssign(target=ast.Name(id=field_name)) as node:
|
|
177
|
+
if node.value:
|
|
178
|
+
err = UnsupportedError(node.value, "Default struct values")
|
|
179
|
+
raise GuppyError(err)
|
|
180
|
+
if field_name in used_field_names:
|
|
181
|
+
err = DuplicateFieldError(node.target, self.name, field_name)
|
|
182
|
+
raise GuppyError(err)
|
|
183
|
+
fields.append(UncheckedStructField(field_name, node.annotation))
|
|
184
|
+
used_field_names.add(field_name)
|
|
185
|
+
case _, node:
|
|
186
|
+
err = UnexpectedError(
|
|
187
|
+
node, "statement", unexpected_in="struct definition"
|
|
188
|
+
)
|
|
189
|
+
raise GuppyError(err)
|
|
190
|
+
|
|
191
|
+
# Ensure that functions don't override struct fields
|
|
192
|
+
if overridden := used_field_names.intersection(used_func_names.keys()):
|
|
193
|
+
x = overridden.pop()
|
|
194
|
+
raise GuppyError(DuplicateFieldError(used_func_names[x], self.name, x))
|
|
195
|
+
|
|
196
|
+
return ParsedStructDef(self.id, self.name, cls_def, params, fields)
|
|
197
|
+
|
|
198
|
+
def check_instantiate(
|
|
199
|
+
self, args: Sequence[Argument], loc: AstNode | None = None
|
|
200
|
+
) -> Type:
|
|
201
|
+
raise InternalGuppyError("Tried to instantiate raw struct definition")
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
@dataclass(frozen=True)
class ParsedStructDef(TypeDef, CheckableDef):
    """A struct definition whose fields have not been checked yet."""

    # The class AST the struct was parsed from.
    defined_at: ast.ClassDef
    # Generic parameters of the struct, in declaration order.
    params: Sequence[Parameter]
    # Fields whose type annotations are still unresolved.
    fields: Sequence[UncheckedStructField]

    def check(self, globals: Globals) -> "CheckedStructDef":
        """Resolves the type annotation of every field.

        Raises a user error if the struct is recursive or a field type is invalid.
        """
        param_var_mapping = {param.name: param for param in self.params}

        # Detect recursive definitions before resolving the fields: resolving the
        # field types of a self-referential struct would otherwise not terminate.
        # TODO: This is not ideal (see todo in `check_instantiate`)
        check_not_recursive(self, globals, param_var_mapping)

        checked_fields: list[StructField] = []
        for unchecked in self.fields:
            field_ty = type_from_ast(unchecked.type_ast, globals, param_var_mapping)
            checked_fields.append(StructField(unchecked.name, field_ty))

        return CheckedStructDef(
            self.id, self.name, self.defined_at, self.params, checked_fields
        )

    def check_instantiate(
        self, args: Sequence[Argument], loc: AstNode | None = None
    ) -> Type:
        """Validates `args` against the struct's parameters and builds the type.

        Because `StructType` needs a checked definition, this re-checks the struct
        in the globals of the frame it was defined in.
        """
        check_all_args(self.params, args, self.name, loc)
        defining_globals = Globals(DEF_STORE.frames[self.id])
        # TODO: This is quite bad: If we have a cyclic definition this will not
        #  terminate, so we have to check for cycles in every call to `check`. The
        #  proper way to deal with this is changing `StructType` such that it only
        #  takes a `DefId` instead of a `CheckedStructDef`. But this will be a bigger
        #  refactor...
        return StructType(args, self.check(defining_globals))
|
|
243
|
+
|
|
244
|
+
|
|
245
|
+
@dataclass(frozen=True)
class CheckedStructDef(TypeDef, CompiledDef):
    """A struct definition whose field types have all been checked."""

    # The class AST the struct was parsed from.
    defined_at: ast.ClassDef
    # Generic parameters of the struct, in declaration order.
    params: Sequence[Parameter]
    # Checked fields, in declaration order.
    fields: Sequence[StructField]

    def check_instantiate(
        self, args: Sequence[Argument], loc: AstNode | None = None
    ) -> Type:
        """Validates `args` against the struct's parameters and returns its type."""
        check_all_args(self.params, args, self.name, loc)
        return StructType(args, self)

    def generated_methods(self) -> list[CustomFunctionDef]:
        """Returns the methods that are generated automatically for this struct.

        Currently this is only the `__new__` constructor. It takes one argument per
        field and packs them into a Hugr tuple.
        """

        class ConstructorCompiler(CustomCallCompiler):
            """Compiler for the `__new__` constructor method of a struct."""

            def compile(self, args: list[Wire]) -> list[Wire]:
                tuple_node = self.builder.add(ops.MakeTuple()(*args))
                return list(tuple_node)

        # Linear field values are taken as owned inputs; all others get no flags.
        ctor_inputs: list[FuncInput] = []
        for fld in self.fields:
            flags = InputFlags.Owned if fld.ty.linear else InputFlags.NoFlags
            ctor_inputs.append(FuncInput(fld.ty, flags))

        # The constructor returns the struct instantiated with its own parameters
        bound_args = [param.to_bound(idx) for idx, param in enumerate(self.params)]
        ctor_output = StructType(defn=self, args=bound_args)

        ctor_names = [fld.name for fld in self.fields]
        constructor_sig = FunctionType(
            inputs=ctor_inputs,
            output=ctor_output,
            input_names=ctor_names,
            params=self.params,
        )
        return [
            CustomFunctionDef(
                id=DefId.fresh(),
                name="__new__",
                defined_at=self.defined_at,
                ty=constructor_sig,
                call_checker=DefaultCallChecker(),
                call_compiler=ConstructorCompiler(),
                higher_order_value=True,
                higher_order_func_id=GlobalConstId.fresh(f"{self.name}.__new__"),
                has_signature=True,
            )
        ]
|
|
292
|
+
|
|
293
|
+
|
|
294
|
+
def parse_py_class(
    cls: type, defining_frame: FrameType, sources: SourceMap
) -> ast.ClassDef:
    """Parses a Python class object into an AST.

    The class source is read from its file, registered in `sources`, and
    annotated with location information.

    Raises:
        GuppyError: If the source of `cls` cannot be located, or if the parsed
            source is not a class definition.
    """
    module = inspect.getmodule(cls)
    if module is None:
        raise GuppyError(UnknownSourceError(None, cls))

    # `inspect.getsourcefile` does not work for classes defined inside an IPython
    # cell, so in that case we use the file name of the defining frame. See
    # - https://bugs.python.org/issue33826
    # - https://github.com/ipython/ipython/issues/11249
    # - https://github.com/wandb/weave/pull/1864
    file: str | None
    in_ipython_main = is_running_ipython() and module.__name__ == "__main__"
    if in_ipython_main:
        file = defining_frame.f_code.co_filename
    else:
        file = inspect.getsourcefile(cls)
    if file is None:
        raise GuppyError(UnknownSourceError(None, cls))

    # `inspect.getsourcelines` is unreliable for classes before Python 3.13 (see
    # https://github.com/CQCL/guppylang/issues/1107), so we mimic the Python 3.13
    # behaviour based on `__firstlineno__`
    # (https://github.com/python/cpython/blob/3.13/Lib/inspect.py#L1052). The
    # decorator makes sure `__firstlineno__` is set on older Python versions too.
    all_lines = linecache.getlines(file)
    first_line = cls.__firstlineno__  # type: ignore[attr-defined]
    class_lines = inspect.getblock(all_lines[first_line - 1 :])
    source, cls_ast, line_offset = parse_source(class_lines, first_line)

    # Remember the source file so diagnostics can render it later
    sources.add_file(file)
    annotate_location(cls_ast, source, file, line_offset)
    if isinstance(cls_ast, ast.ClassDef):
        return cls_ast
    raise GuppyError(ExpectedError(cls_ast, "a class definition"))
|
|
331
|
+
|
|
332
|
+
|
|
333
|
+
def try_parse_generic_base(node: ast.expr) -> list[ast.expr] | None:
|
|
334
|
+
"""Checks if an AST node corresponds to a `Generic[T1, ..., Tn]` base class.
|
|
335
|
+
|
|
336
|
+
Returns the generic parameters or `None` if the AST has a different shape
|
|
337
|
+
"""
|
|
338
|
+
match node:
|
|
339
|
+
case ast.Subscript(value=ast.Name(id="Generic"), slice=elem):
|
|
340
|
+
return elem.elts if isinstance(elem, ast.Tuple) else [elem]
|
|
341
|
+
case _:
|
|
342
|
+
return None
|
|
343
|
+
|
|
344
|
+
|
|
345
|
+
@dataclass(frozen=True)
class RepeatedTypeParamError(Error):
    """Error for a type parameter listed more than once in `Generic[...]`."""

    title: ClassVar[str] = "Duplicate type parameter"
    span_label: ClassVar[str] = (
        "Type parameter `{name}` cannot be used multiple times"
    )
    # Name of the repeated parameter as written in the source.
    name: str
|
|
350
|
+
|
|
351
|
+
|
|
352
|
+
def params_from_ast(nodes: Sequence[ast.expr], globals: Globals) -> list[Parameter]:
    """Turns the expressions in a `Generic[...]` base into unique type parameters.

    Each node must be a name that resolves to a `ParamDef` in `globals`.
    Parameters are numbered by their position in the result.

    Raises:
        GuppyError: If a node is not a type parameter, or if the same parameter
            appears more than once.
    """
    params: list[Parameter] = []
    seen_ids: set[DefId] = set()
    for node in nodes:
        if not (isinstance(node, ast.Name) and node.id in globals):
            raise GuppyError(ExpectedError(node, "a type parameter"))
        defn = globals[node.id]
        if not isinstance(defn, ParamDef):
            raise GuppyError(ExpectedError(node, "a type parameter"))
        if defn.id in seen_ids:
            raise GuppyError(RepeatedTypeParamError(node, node.id))
        params.append(defn.to_param(len(params)))
        seen_ids.add(defn.id)
    return params
|
|
371
|
+
|
|
372
|
+
|
|
373
|
+
def check_not_recursive(
    defn: ParsedStructDef, globals: Globals, param_var_mapping: dict[str, Parameter]
) -> None:
    """Throws a user error if the given struct definition is recursive.

    While the field types of `defn` are resolved, `defn.check_instantiate` is
    temporarily replaced by a stub that raises an "unsupported recursive structs"
    error. Any (possibly indirect) reference back to `defn` then hits the stub.

    The original method is restored in a `finally` clause. Without this, a failure
    while resolving the field types would leave `defn` permanently patched. Such
    failures include the recursion error itself, an error raised by a nested
    struct's check, or an invalid field type. Later, unrelated uses of the struct
    would then be misreported as recursive.

    Raises:
        GuppyError: If the struct is recursive, or if a field type fails to
            resolve.
    """

    # TODO: The implementation below hijacks the type parsing logic to detect recursive
    #  structs. This is not great since it repeats the work done during checking. We can
    #  get rid of this after resolving the todo in `ParsedStructDef.check_instantiate()`

    def dummy_check_instantiate(
        args: Sequence[Argument],
        loc: AstNode | None = None,
    ) -> Type:
        raise GuppyError(UnsupportedError(loc, "Recursive structs"))

    original = defn.check_instantiate
    # `defn` is a frozen dataclass, so bypass its `__setattr__` guard
    object.__setattr__(defn, "check_instantiate", dummy_check_instantiate)
    try:
        for fld in defn.fields:
            type_from_ast(fld.type_ast, globals, param_var_mapping)
    finally:
        object.__setattr__(defn, "check_instantiate", original)
|
|
@@ -0,0 +1,151 @@
|
|
|
1
|
+
import ast
|
|
2
|
+
from collections.abc import Callable
|
|
3
|
+
from dataclasses import dataclass, field
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
import hugr.build.function as hf
|
|
7
|
+
import hugr.tys as ht
|
|
8
|
+
from hugr import Node, Wire
|
|
9
|
+
from hugr.build.dfg import DefinitionBuilder, OpVar
|
|
10
|
+
|
|
11
|
+
from guppylang_internals.ast_util import AstNode, with_loc
|
|
12
|
+
from guppylang_internals.checker.core import Context, Globals
|
|
13
|
+
from guppylang_internals.checker.errors.generic import UnsupportedError
|
|
14
|
+
from guppylang_internals.checker.expr_checker import (
|
|
15
|
+
check_call,
|
|
16
|
+
synthesize_call,
|
|
17
|
+
)
|
|
18
|
+
from guppylang_internals.checker.func_checker import (
|
|
19
|
+
check_signature,
|
|
20
|
+
)
|
|
21
|
+
from guppylang_internals.compiler.core import CompilerContext, DFContainer
|
|
22
|
+
from guppylang_internals.definition.common import (
|
|
23
|
+
CompilableDef,
|
|
24
|
+
ParsableDef,
|
|
25
|
+
)
|
|
26
|
+
from guppylang_internals.definition.function import parse_py_func
|
|
27
|
+
from guppylang_internals.definition.value import (
|
|
28
|
+
CallableDef,
|
|
29
|
+
CallReturnWires,
|
|
30
|
+
CompiledCallableDef,
|
|
31
|
+
CompiledHugrNodeDef,
|
|
32
|
+
)
|
|
33
|
+
from guppylang_internals.error import GuppyError
|
|
34
|
+
from guppylang_internals.nodes import GlobalCall
|
|
35
|
+
from guppylang_internals.span import SourceMap
|
|
36
|
+
from guppylang_internals.tys.subst import Inst, Subst
|
|
37
|
+
from guppylang_internals.tys.ty import FunctionType, Type, type_to_row
|
|
38
|
+
|
|
39
|
+
# Type alias for the plain Python function object backing a traced definition
PyFunc = Callable[..., Any]
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
@dataclass(frozen=True)
class RawTracedFunctionDef(ParsableDef):
    """A comptime (traced) function whose source has not been parsed yet."""

    # The decorated Python function.
    python_func: PyFunc

    description: str = field(default="function", init=False)

    def parse(self, globals: Globals, sources: SourceMap) -> "TracedFunctionDef":
        """Parses the function and checks its user-provided signature.

        Raises:
            GuppyError: If the signature is invalid or generic; generic comptime
                functions are not supported.
        """
        func_ast, _ = parse_py_func(self.python_func, sources)
        signature = check_signature(func_ast, globals)
        if not signature.parametrized:
            return TracedFunctionDef(
                self.id, self.name, func_ast, signature, self.python_func
            )
        raise GuppyError(UnsupportedError(func_ast, "Generic comptime functions"))
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
@dataclass(frozen=True)
class TracedFunctionDef(RawTracedFunctionDef, CallableDef, CompilableDef):
    """A traced function whose signature has been parsed and checked.

    Its body is compiled by tracing the Python function; see
    `CompiledTracedFunctionDef.compile_inner`.
    """

    # The decorated Python function that is traced at compile time.
    python_func: PyFunc
    # The checked (non-generic) signature of the function.
    ty: FunctionType
    # The function's AST, used for error locations.
    defined_at: ast.FunctionDef

    def check_call(
        self, args: list[ast.expr], ty: Type, node: AstNode, ctx: Context
    ) -> tuple[ast.expr, Subst]:
        """Checks the return type of a function call against a given type."""
        # Use default implementation from the expression checker
        args, subst, inst = check_call(self.ty, args, ty, node, ctx)
        node = with_loc(node, GlobalCall(def_id=self.id, args=args, type_args=inst))
        return node, subst

    def synthesize_call(
        self, args: list[ast.expr], node: AstNode, ctx: Context
    ) -> tuple[ast.expr, Type]:
        """Synthesizes the return type of a function call."""
        # Use default implementation from the expression checker
        args, ty, inst = synthesize_call(self.ty, args, node, ctx)
        node = with_loc(node, GlobalCall(def_id=self.id, args=args, type_args=inst))
        return node, ty

    def compile_outer(
        self, module: DefinitionBuilder[OpVar], ctx: CompilerContext
    ) -> "CompiledTracedFunctionDef":
        """Adds a Hugr `FuncDefn` node for this function to the Hugr.

        The function body is not compiled here, because the other compiled
        functions are not available yet. It is traced and compiled later in
        `CompiledTracedFunctionDef.compile_inner()`.
        """
        func_type = self.ty.to_hugr_poly(ctx)
        func_def = module.module_root_builder().define_function(
            self.name, func_type.body.input, func_type.body.output, func_type.params
        )
        return CompiledTracedFunctionDef(
            self.id,
            self.name,
            self.defined_at,
            self.ty,
            self.python_func,
            func_def,
        )
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
@dataclass(frozen=True)
class CompiledTracedFunctionDef(
    TracedFunctionDef, CompiledCallableDef, CompiledHugrNodeDef
):
    """A traced function for which a Hugr function definition has been created."""

    # Builder for the Hugr function definition created in `compile_outer`.
    func_def: hf.Function

    @property
    def hugr_node(self) -> Node:
        """The Hugr node this definition was compiled into."""
        return self.func_def.parent_node

    def load_with_args(
        self,
        type_args: Inst,
        dfg: DFContainer,
        ctx: CompilerContext,
        node: AstNode,
    ) -> Wire:
        """Loads the function as a value into a local Hugr dataflow graph."""
        # Use fresh names instead of re-binding the `type_args` parameter to a
        # differently typed list of Hugr arguments.
        hugr_func_ty: ht.FunctionType = self.ty.instantiate(type_args).to_hugr(ctx)
        hugr_type_args: list[ht.TypeArg] = [arg.to_hugr(ctx) for arg in type_args]
        return dfg.builder.load_function(self.func_def, hugr_func_ty, hugr_type_args)

    def compile_call(
        self,
        args: list[Wire],
        type_args: Inst,
        dfg: DFContainer,
        ctx: CompilerContext,
        node: AstNode,
    ) -> CallReturnWires:
        """Compiles a call to the function.

        The first outputs of the call, one per entry in the function's output row,
        are the regular returns. The remaining outputs are reported as
        `inout_returns`.
        """
        hugr_func_ty: ht.FunctionType = self.ty.instantiate(type_args).to_hugr(ctx)
        hugr_type_args: list[ht.TypeArg] = [arg.to_hugr(ctx) for arg in type_args]
        num_returns = len(type_to_row(self.ty.output))
        call = dfg.builder.call(
            self.func_def, *args, instantiation=hugr_func_ty, type_args=hugr_type_args
        )
        regular = list(call[:num_returns])
        inout = list(call[num_returns:])
        return CallReturnWires(regular_returns=regular, inout_returns=inout)

    def compile_inner(self, ctx: CompilerContext) -> None:
        """Compiles the body of the function by tracing it."""
        # Imported locally rather than at module level
        from guppylang_internals.tracing.function import trace_function

        trace_function(self.python_func, self.ty, self.func_def, ctx, self.defined_at)
|
|
@@ -0,0 +1,51 @@
|
|
|
1
|
+
from abc import abstractmethod
|
|
2
|
+
from collections.abc import Callable, Sequence
|
|
3
|
+
from dataclasses import dataclass, field
|
|
4
|
+
|
|
5
|
+
from hugr import tys
|
|
6
|
+
|
|
7
|
+
from guppylang_internals.ast_util import AstNode
|
|
8
|
+
from guppylang_internals.definition.common import CompiledDef, Definition
|
|
9
|
+
from guppylang_internals.tys.arg import Argument
|
|
10
|
+
from guppylang_internals.tys.common import ToHugrContext
|
|
11
|
+
from guppylang_internals.tys.param import Parameter, check_all_args
|
|
12
|
+
from guppylang_internals.tys.ty import OpaqueType, Type
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@dataclass(frozen=True)
class TypeDef(Definition):
    """Common abstract base for all definitions that introduce a type."""

    description: str = field(default="type", init=False)

    @abstractmethod
    def check_instantiate(
        self, args: Sequence[Argument], loc: AstNode | None = None
    ) -> Type:
        """Instantiates the type definition with the given arguments.

        Implementations return the resulting concrete type. If the arguments are
        not valid for this definition, they raise a user error.
        """
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
@dataclass(frozen=True)
class OpaqueTypeDef(TypeDef, CompiledDef):
    """A type definition whose representation is given directly by a Hugr type."""

    # Generic parameters the type must be instantiated with.
    params: Sequence[Parameter]
    never_copyable: bool
    never_droppable: bool
    # Callback that builds the Hugr type for a list of type arguments.
    to_hugr: Callable[[Sequence[Argument], ToHugrContext], tys.Type]
    bound: tys.TypeBound | None = None

    def check_instantiate(
        self, args: Sequence[Argument], loc: AstNode | None = None
    ) -> OpaqueType:
        """Instantiates the opaque type with `args`.

        The arguments are validated against `params`; a user error is raised if
        they do not match.
        """
        check_all_args(self.params, args, self.name, loc)
        instance = OpaqueType(args, self)
        return instance
|