guppylang-internals 0.21.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- guppylang_internals/__init__.py +3 -0
- guppylang_internals/ast_util.py +350 -0
- guppylang_internals/cfg/__init__.py +0 -0
- guppylang_internals/cfg/analysis.py +230 -0
- guppylang_internals/cfg/bb.py +221 -0
- guppylang_internals/cfg/builder.py +606 -0
- guppylang_internals/cfg/cfg.py +117 -0
- guppylang_internals/checker/__init__.py +0 -0
- guppylang_internals/checker/cfg_checker.py +388 -0
- guppylang_internals/checker/core.py +550 -0
- guppylang_internals/checker/errors/__init__.py +0 -0
- guppylang_internals/checker/errors/comptime_errors.py +106 -0
- guppylang_internals/checker/errors/generic.py +45 -0
- guppylang_internals/checker/errors/linearity.py +300 -0
- guppylang_internals/checker/errors/type_errors.py +344 -0
- guppylang_internals/checker/errors/wasm.py +34 -0
- guppylang_internals/checker/expr_checker.py +1413 -0
- guppylang_internals/checker/func_checker.py +269 -0
- guppylang_internals/checker/linearity_checker.py +821 -0
- guppylang_internals/checker/stmt_checker.py +447 -0
- guppylang_internals/compiler/__init__.py +0 -0
- guppylang_internals/compiler/cfg_compiler.py +233 -0
- guppylang_internals/compiler/core.py +613 -0
- guppylang_internals/compiler/expr_compiler.py +989 -0
- guppylang_internals/compiler/func_compiler.py +97 -0
- guppylang_internals/compiler/hugr_extension.py +224 -0
- guppylang_internals/compiler/qtm_platform_extension.py +0 -0
- guppylang_internals/compiler/stmt_compiler.py +212 -0
- guppylang_internals/decorator.py +246 -0
- guppylang_internals/definition/__init__.py +0 -0
- guppylang_internals/definition/common.py +214 -0
- guppylang_internals/definition/const.py +74 -0
- guppylang_internals/definition/custom.py +492 -0
- guppylang_internals/definition/declaration.py +171 -0
- guppylang_internals/definition/extern.py +89 -0
- guppylang_internals/definition/function.py +302 -0
- guppylang_internals/definition/overloaded.py +150 -0
- guppylang_internals/definition/parameter.py +82 -0
- guppylang_internals/definition/pytket_circuits.py +405 -0
- guppylang_internals/definition/struct.py +392 -0
- guppylang_internals/definition/traced.py +151 -0
- guppylang_internals/definition/ty.py +51 -0
- guppylang_internals/definition/value.py +115 -0
- guppylang_internals/definition/wasm.py +61 -0
- guppylang_internals/diagnostic.py +523 -0
- guppylang_internals/dummy_decorator.py +76 -0
- guppylang_internals/engine.py +295 -0
- guppylang_internals/error.py +107 -0
- guppylang_internals/experimental.py +92 -0
- guppylang_internals/ipython_inspect.py +28 -0
- guppylang_internals/nodes.py +427 -0
- guppylang_internals/py.typed +0 -0
- guppylang_internals/span.py +150 -0
- guppylang_internals/std/__init__.py +0 -0
- guppylang_internals/std/_internal/__init__.py +0 -0
- guppylang_internals/std/_internal/checker.py +573 -0
- guppylang_internals/std/_internal/compiler/__init__.py +0 -0
- guppylang_internals/std/_internal/compiler/arithmetic.py +136 -0
- guppylang_internals/std/_internal/compiler/array.py +569 -0
- guppylang_internals/std/_internal/compiler/either.py +131 -0
- guppylang_internals/std/_internal/compiler/frozenarray.py +68 -0
- guppylang_internals/std/_internal/compiler/futures.py +30 -0
- guppylang_internals/std/_internal/compiler/list.py +348 -0
- guppylang_internals/std/_internal/compiler/mem.py +13 -0
- guppylang_internals/std/_internal/compiler/option.py +78 -0
- guppylang_internals/std/_internal/compiler/prelude.py +271 -0
- guppylang_internals/std/_internal/compiler/qsystem.py +48 -0
- guppylang_internals/std/_internal/compiler/quantum.py +118 -0
- guppylang_internals/std/_internal/compiler/tket_bool.py +55 -0
- guppylang_internals/std/_internal/compiler/tket_exts.py +59 -0
- guppylang_internals/std/_internal/compiler/wasm.py +135 -0
- guppylang_internals/std/_internal/compiler.py +0 -0
- guppylang_internals/std/_internal/debug.py +95 -0
- guppylang_internals/std/_internal/util.py +271 -0
- guppylang_internals/tracing/__init__.py +0 -0
- guppylang_internals/tracing/builtins_mock.py +62 -0
- guppylang_internals/tracing/frozenlist.py +57 -0
- guppylang_internals/tracing/function.py +186 -0
- guppylang_internals/tracing/object.py +551 -0
- guppylang_internals/tracing/state.py +69 -0
- guppylang_internals/tracing/unpacking.py +194 -0
- guppylang_internals/tracing/util.py +86 -0
- guppylang_internals/tys/__init__.py +0 -0
- guppylang_internals/tys/arg.py +115 -0
- guppylang_internals/tys/builtin.py +382 -0
- guppylang_internals/tys/common.py +110 -0
- guppylang_internals/tys/const.py +114 -0
- guppylang_internals/tys/errors.py +178 -0
- guppylang_internals/tys/param.py +251 -0
- guppylang_internals/tys/parsing.py +425 -0
- guppylang_internals/tys/printing.py +174 -0
- guppylang_internals/tys/subst.py +112 -0
- guppylang_internals/tys/ty.py +876 -0
- guppylang_internals/tys/var.py +49 -0
- guppylang_internals-0.21.0.dist-info/METADATA +253 -0
- guppylang_internals-0.21.0.dist-info/RECORD +98 -0
- guppylang_internals-0.21.0.dist-info/WHEEL +4 -0
- guppylang_internals-0.21.0.dist-info/licenses/LICENCE +201 -0
|
@@ -0,0 +1,89 @@
|
|
|
1
|
+
import ast
|
|
2
|
+
from dataclasses import dataclass, field
|
|
3
|
+
|
|
4
|
+
from hugr import Node, Wire, val
|
|
5
|
+
from hugr.build.dfg import DefinitionBuilder, OpVar
|
|
6
|
+
|
|
7
|
+
from guppylang_internals.ast_util import AstNode
|
|
8
|
+
from guppylang_internals.checker.core import Globals
|
|
9
|
+
from guppylang_internals.compiler.core import CompilerContext, DFContainer
|
|
10
|
+
from guppylang_internals.definition.common import CompilableDef, ParsableDef
|
|
11
|
+
from guppylang_internals.definition.value import (
|
|
12
|
+
CompiledHugrNodeDef,
|
|
13
|
+
CompiledValueDef,
|
|
14
|
+
ValueDef,
|
|
15
|
+
)
|
|
16
|
+
from guppylang_internals.span import SourceMap
|
|
17
|
+
from guppylang_internals.tys.parsing import type_from_ast
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
@dataclass(frozen=True)
class RawExternDef(ParsableDef):
    """A raw extern symbol definition provided by the user.

    The type of the symbol is kept as an unparsed AST expression and is only turned
    into a Guppy type when `parse` is called.
    """

    # Name of the external symbol.
    symbol: str
    # Flag that is passed through unchanged to the parsed/compiled definition.
    constant: bool
    # Unparsed type annotation of the extern symbol.
    type_ast: ast.expr

    description: str = field(default="extern", init=False)

    def parse(self, globals: Globals, sources: SourceMap) -> "ExternDef":
        """Parses the user-provided type of the extern symbol.

        Args:
            globals: The global checking context used to resolve names in the type.
            sources: Source map (unused here).

        Returns:
            An `ExternDef` carrying the parsed type alongside the raw fields.
        """
        ty = type_from_ast(self.type_ast, globals, {})
        return ExternDef(
            self.id,
            self.name,
            self.defined_at,
            ty,
            self.symbol,
            self.constant,
            self.type_ast,
        )
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
@dataclass(frozen=True)
class ExternDef(RawExternDef, ValueDef, CompilableDef):
    """An extern symbol definition whose type has been parsed."""

    def compile_outer(
        self, graph: DefinitionBuilder[OpVar], ctx: CompilerContext
    ) -> "CompiledExternDef":
        """Adds a Hugr constant node for the extern definition to the provided graph.

        The symbol is encoded as a `ConstExternalSymbol` extension value. Its payload
        records the symbol name, the serialized Hugr type and the `constant` flag.

        Args:
            graph: The builder the constant node is added to.
            ctx: The compiler context used to lower the Guppy type to Hugr.

        Returns:
            A `CompiledExternDef` that remembers the created constant node.
        """
        # Lower the Guppy type once and reuse it for both the payload and the value.
        hugr_ty = self.ty.to_hugr(ctx)
        # The `typ` field must be serialized at this point, to ensure that the
        # `Extension` is serializable.
        custom_const = {
            "symbol": self.symbol,
            "typ": hugr_ty._to_serial_root(),
            "constant": self.constant,
        }
        value = val.Extension(
            name="ConstExternalSymbol",
            typ=hugr_ty,
            val=custom_const,
        )
        const_node = graph.add_const(value)
        return CompiledExternDef(
            self.id,
            self.name,
            self.defined_at,
            self.ty,
            self.symbol,
            self.constant,
            self.type_ast,
            const_node,
        )
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
@dataclass(frozen=True)
class CompiledExternDef(ExternDef, CompiledValueDef, CompiledHugrNodeDef):
    """Extern symbol definition that has been lowered to a Hugr constant node.

    Args:
        const_node: The Hugr constant node holding the extern symbol.
    """

    # Created by `ExternDef.compile_outer` via `graph.add_const`.
    const_node: Node

    @property
    def hugr_node(self) -> Node:
        """The Hugr node this definition was compiled into (its constant node)."""
        node = self.const_node
        return node

    def load(self, dfg: DFContainer, ctx: CompilerContext, node: AstNode) -> Wire:
        """Emits a load of the extern constant into a local Hugr dataflow graph.

        Returns:
            The wire carrying the loaded extern value.
        """
        builder = dfg.builder
        return builder.load(self.const_node)
|
|
@@ -0,0 +1,302 @@
|
|
|
1
|
+
import ast
|
|
2
|
+
import inspect
|
|
3
|
+
from collections.abc import Callable, Sequence
|
|
4
|
+
from dataclasses import dataclass, field
|
|
5
|
+
from typing import TYPE_CHECKING, Any
|
|
6
|
+
|
|
7
|
+
import hugr.build.function as hf
|
|
8
|
+
import hugr.tys as ht
|
|
9
|
+
from hugr import Node, Wire
|
|
10
|
+
from hugr.build.dfg import DefinitionBuilder, OpVar
|
|
11
|
+
from hugr.hugr.node_port import ToNode
|
|
12
|
+
|
|
13
|
+
from guppylang_internals.ast_util import AstNode, annotate_location, with_loc, with_type
|
|
14
|
+
from guppylang_internals.checker.cfg_checker import CheckedCFG
|
|
15
|
+
from guppylang_internals.checker.core import Context, Globals, Place
|
|
16
|
+
from guppylang_internals.checker.errors.generic import ExpectedError
|
|
17
|
+
from guppylang_internals.checker.expr_checker import check_call, synthesize_call
|
|
18
|
+
from guppylang_internals.checker.func_checker import (
|
|
19
|
+
check_global_func_def,
|
|
20
|
+
check_signature,
|
|
21
|
+
parse_function_with_docstring,
|
|
22
|
+
)
|
|
23
|
+
from guppylang_internals.compiler.core import (
|
|
24
|
+
CompilerContext,
|
|
25
|
+
DFContainer,
|
|
26
|
+
PartiallyMonomorphizedArgs,
|
|
27
|
+
)
|
|
28
|
+
from guppylang_internals.compiler.func_compiler import compile_global_func_def
|
|
29
|
+
from guppylang_internals.definition.common import (
|
|
30
|
+
CheckableDef,
|
|
31
|
+
MonomorphizableDef,
|
|
32
|
+
MonomorphizedDef,
|
|
33
|
+
ParsableDef,
|
|
34
|
+
UnknownSourceError,
|
|
35
|
+
)
|
|
36
|
+
from guppylang_internals.definition.value import (
|
|
37
|
+
CallableDef,
|
|
38
|
+
CallReturnWires,
|
|
39
|
+
CompiledCallableDef,
|
|
40
|
+
CompiledHugrNodeDef,
|
|
41
|
+
)
|
|
42
|
+
from guppylang_internals.error import GuppyError
|
|
43
|
+
from guppylang_internals.nodes import GlobalCall
|
|
44
|
+
from guppylang_internals.span import SourceMap
|
|
45
|
+
from guppylang_internals.tys.subst import Inst, Subst
|
|
46
|
+
from guppylang_internals.tys.ty import FunctionType, Type, type_to_row
|
|
47
|
+
|
|
48
|
+
if TYPE_CHECKING:
|
|
49
|
+
from guppylang_internals.tys.param import Parameter
|
|
50
|
+
|
|
51
|
+
# Any Python callable; used as the type of the user function stored by
# `RawFunctionDef` and consumed by `parse_py_func`.
PyFunc = Callable[..., Any]
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
@dataclass(frozen=True)
class RawFunctionDef(ParsableDef):
    """A user-provided function definition before any parsing or checking.

    Only the Python function object is kept. Its source is recovered and parsed into
    an AST when `parse` is called.

    Args:
        id: The unique definition identifier.
        name: The name of the function.
        defined_at: The AST node where the function was defined.
        python_func: The Python function to be defined.
    """

    python_func: PyFunc

    description: str = field(default="function", init=False)

    def parse(self, globals: Globals, sources: SourceMap) -> "ParsedFunctionDef":
        """Recovers the function's AST and checks its user-provided signature.

        Args:
            globals: The global checking context passed to `check_signature`.
            sources: Source map to which the function's source file is added.

        Returns:
            A `ParsedFunctionDef` holding the function AST, its type and docstring.
        """
        func_ast, docstring = parse_py_func(self.python_func, sources)
        signature = check_signature(func_ast, globals)
        return ParsedFunctionDef(self.id, self.name, func_ast, signature, docstring)
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
@dataclass(frozen=True)
class ParsedFunctionDef(CheckableDef, CallableDef):
    """A function definition with parsed and checked signature.

    In particular, this means that we have determined a type for the function and are
    ready to check the function body.

    Args:
        id: The unique definition identifier.
        name: The name of the function.
        defined_at: The AST node where the function was defined.
        ty: The type of the function.
        docstring: The docstring of the function.
    """

    defined_at: ast.FunctionDef
    ty: FunctionType
    docstring: str | None

    description: str = field(default="function", init=False)

    def check(self, globals: Globals) -> "CheckedFunctionDef":
        """Type checks the body of the function.

        Args:
            globals: The global checking context.

        Returns:
            A `CheckedFunctionDef` carrying the checked CFG of the function body.
        """
        cfg = check_global_func_def(self.defined_at, self.ty, globals)
        return CheckedFunctionDef(
            self.id,
            self.name,
            self.defined_at,
            self.ty,
            self.docstring,
            cfg,
        )

    def check_call(
        self, args: list[ast.expr], ty: Type, node: AstNode, ctx: Context
    ) -> tuple[ast.expr, Subst]:
        """Checks the return type of a function call against a given type.

        Returns:
            The call rewritten as a `GlobalCall` node (keeping the location of `node`)
            together with the substitution found during checking.
        """
        # Use default implementation from the expression checker
        args, subst, inst = check_call(self.ty, args, ty, node, ctx)
        node = with_loc(node, GlobalCall(def_id=self.id, args=args, type_args=inst))
        return node, subst

    def synthesize_call(
        self, args: list[ast.expr], node: AstNode, ctx: Context
    ) -> tuple[ast.expr, Type]:
        """Synthesizes the return type of a function call.

        Returns:
            The call rewritten as a typed `GlobalCall` node (keeping the location of
            `node`) together with the synthesized return type.
        """
        # Use default implementation from the expression checker
        args, ty, inst = synthesize_call(self.ty, args, node, ctx)
        node = with_loc(node, GlobalCall(def_id=self.id, args=args, type_args=inst))
        return with_type(ty, node), ty
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
@dataclass(frozen=True)
class CheckedFunctionDef(ParsedFunctionDef, MonomorphizableDef):
    """Type checked version of a user-defined function that is ready to be compiled.

    In particular, this means that we have constructed and type checked a control-flow
    graph for the function body.

    Args:
        id: The unique definition identifier.
        name: The name of the function.
        defined_at: The AST node where the function was defined.
        ty: The type of the function.
        docstring: The docstring of the function.
        cfg: The type- and linearity-checked CFG for the function body.
    """

    cfg: CheckedCFG[Place]

    @property
    def params(self) -> "Sequence[Parameter]":
        """Generic parameters of this function."""
        return self.ty.params

    def monomorphize(
        self,
        module: DefinitionBuilder[OpVar],
        mono_args: "PartiallyMonomorphizedArgs",
        ctx: "CompilerContext",
    ) -> "CompiledFunctionDef":
        """Adds a Hugr `FuncDefn` node for the (partially) monomorphized function to the
        Hugr.

        Note that we don't compile the function body at this point since we don't have
        access to the other compiled functions yet. The body is compiled later in
        `CompiledFunctionDef.compile_inner()`.

        Args:
            module: Builder whose module root receives the new function definition.
            mono_args: Partial monomorphization of the generic type parameters.
            ctx: The compiler context used to lower types to Hugr.

        Returns:
            A `CompiledFunctionDef` with the partially instantiated type.
        """
        mono_ty = self.ty.instantiate_partial(mono_args)
        # The polymorphic Hugr signature also supplies the Hugr type parameters.
        hugr_ty = mono_ty.to_hugr_poly(ctx)
        func_def = module.module_root_builder().define_function(
            self.name, hugr_ty.body.input, hugr_ty.body.output, hugr_ty.params
        )
        return CompiledFunctionDef(
            self.id,
            self.name,
            self.defined_at,
            mono_args,
            mono_ty,
            self.docstring,
            self.cfg,
            func_def,
        )
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
@dataclass(frozen=True)
class CompiledFunctionDef(
    CheckedFunctionDef, CompiledCallableDef, MonomorphizedDef, CompiledHugrNodeDef
):
    """A function definition with a corresponding Hugr node.

    Args:
        id: The unique definition identifier.
        name: The name of the function.
        defined_at: The AST node where the function was defined.
        mono_args: Partial monomorphization of the generic type parameters.
        ty: The type of the function after partial monomorphization.
        docstring: The docstring of the function.
        cfg: The type- and linearity-checked CFG for the function body.
        func_def: The Hugr function definition.
    """

    func_def: hf.Function

    @property
    def hugr_node(self) -> Node:
        """The Hugr node this definition was compiled into."""
        return self.func_def.parent_node

    def load_with_args(
        self,
        type_args: Inst,
        dfg: DFContainer,
        ctx: CompilerContext,
        node: AstNode,
    ) -> Wire:
        """Loads the function as a value into a local Hugr dataflow graph.

        Delegates to the module-level `load_with_args` helper.
        """
        return load_with_args(type_args, dfg, self.ty, self.func_def)

    def compile_call(
        self,
        args: list[Wire],
        type_args: Inst,
        dfg: DFContainer,
        ctx: CompilerContext,
        node: AstNode,
    ) -> CallReturnWires:
        """Compiles a call to the function.

        Delegates to the module-level `compile_call` helper.
        """
        return compile_call(args, type_args, dfg, self.ty, self.func_def)

    def compile_inner(self, globals: CompilerContext) -> None:
        """Compiles the body of the function into `func_def`.

        Args:
            globals: The compiler context (despite its name, not a `Globals`).
        """
        compile_global_func_def(self, self.func_def, globals)
|
|
237
|
+
|
|
238
|
+
|
|
239
|
+
def load_with_args(
    type_args: Inst,
    dfg: DFContainer,
    ty: FunctionType,
    func: ToNode,
) -> Wire:
    """Loads a (possibly generic) function as a value into a local dataflow graph.

    Args:
        type_args: Instantiation for the generic parameters of `ty`.
        dfg: The dataflow container the load is added to.
        ty: The Guppy type of the function.
        func: The node of the function to load.

    Returns:
        The wire carrying the loaded function value.
    """
    instantiated: ht.FunctionType = ty.instantiate(type_args).to_hugr(dfg.ctx)
    hugr_type_args = [arg.to_hugr(dfg.ctx) for arg in type_args]
    return dfg.builder.load_function(func, instantiated, hugr_type_args)
|
|
249
|
+
|
|
250
|
+
|
|
251
|
+
def compile_call(
    args: list[Wire],
    type_args: Inst,  # Non-monomorphized type args only
    dfg: DFContainer,
    ty: FunctionType,
    func: ToNode,
) -> CallReturnWires:
    """Emits a Hugr `Call` to `func` and splits its outputs.

    Args:
        args: Input wires of the call.
        type_args: Instantiation for the remaining generic parameters of `ty`.
        dfg: The dataflow container the call is added to.
        ty: The Guppy type of the called function.
        func: The node of the called function.

    Returns:
        The first outputs (as many as the rows in `ty.output`) as regular returns,
        and all remaining outputs as inout returns.
    """
    instantiated: ht.FunctionType = ty.instantiate(type_args).to_hugr(dfg.ctx)
    hugr_type_args = [arg.to_hugr(dfg.ctx) for arg in type_args]
    n_regular = len(type_to_row(ty.output))
    outputs = dfg.builder.call(
        func, *args, instantiation=instantiated, type_args=hugr_type_args
    )
    regular, inout = outputs[:n_regular], outputs[n_regular:]
    return CallReturnWires(regular_returns=list(regular), inout_returns=list(inout))
|
|
267
|
+
|
|
268
|
+
|
|
269
|
+
def parse_py_func(f: PyFunc, sources: SourceMap) -> tuple[ast.FunctionDef, str | None]:
    """Recovers the source of a Python function and parses it into an annotated AST.

    The function's source file is registered with `sources`, and the AST nodes are
    annotated with their locations in that file.

    Args:
        f: The Python function to parse.
        sources: Source map to which the function's source file is added.

    Returns:
        The function AST and its docstring, as produced by
        `parse_function_with_docstring`.

    Raises:
        GuppyError: If the source file of `f` is unknown, or if the parsed source is
            not a function definition.
    """
    # `inspect.getsourcelines` itself raises `OSError` if no source is available.
    lines, first_line = inspect.getsourcelines(f)
    source, parsed, first_line = parse_source(lines, first_line)
    filename = inspect.getsourcefile(f)
    if filename is None:
        raise GuppyError(UnknownSourceError(None, f))
    sources.add_file(filename)
    annotate_location(parsed, source, filename, first_line)
    if isinstance(parsed, ast.FunctionDef):
        return parse_function_with_docstring(parsed)
    raise GuppyError(ExpectedError(parsed, "a function definition"))
|
|
280
|
+
|
|
281
|
+
|
|
282
|
+
def parse_source(source_lines: list[str], line_offset: int) -> tuple[str, ast.AST, int]:
    """Parses a list of source lines into an AST object.

    Indented source (e.g. a function defined inside another scope) is supported
    without touching column numbers.

    Args:
        source_lines: Source lines, each including its trailing newline.
        line_offset: Line number offset of the first source line.

    Returns:
        A triple of the full joined source, the first parsed top-level statement, and
        the line offset to use for it (one less when the source was indented).
    """
    # The lines already carry their trailing newlines.
    source = "".join(source_lines)
    if not source_lines[0][0].isspace():
        return source, ast.parse(source).body[0], line_offset

    # `textwrap.dedent` would shift the column numbers in spans. Instead, nest the
    # indented source inside a dummy class so that the indentation becomes valid.
    wrapper = ast.parse("class _:\n" + source).body[0]
    assert isinstance(wrapper, ast.ClassDef)
    # The synthetic `class _:` header pushes every line down by one.
    return source, wrapper.body[0], line_offset - 1
|
|
@@ -0,0 +1,150 @@
|
|
|
1
|
+
import ast
|
|
2
|
+
from contextlib import suppress
|
|
3
|
+
from dataclasses import dataclass, field
|
|
4
|
+
from typing import ClassVar, NoReturn
|
|
5
|
+
|
|
6
|
+
from hugr import Wire
|
|
7
|
+
|
|
8
|
+
from guppylang_internals.ast_util import AstNode
|
|
9
|
+
from guppylang_internals.checker.core import Context
|
|
10
|
+
from guppylang_internals.checker.expr_checker import ExprSynthesizer
|
|
11
|
+
from guppylang_internals.compiler.core import CompilerContext, DFContainer
|
|
12
|
+
from guppylang_internals.definition.common import (
|
|
13
|
+
DefId,
|
|
14
|
+
)
|
|
15
|
+
from guppylang_internals.definition.value import (
|
|
16
|
+
CallableDef,
|
|
17
|
+
CallReturnWires,
|
|
18
|
+
CompiledCallableDef,
|
|
19
|
+
)
|
|
20
|
+
from guppylang_internals.diagnostic import Error, Note
|
|
21
|
+
from guppylang_internals.error import GuppyError, InternalGuppyError
|
|
22
|
+
from guppylang_internals.span import Span, to_span
|
|
23
|
+
from guppylang_internals.tys.printing import signature_to_str
|
|
24
|
+
from guppylang_internals.tys.subst import Inst, Subst
|
|
25
|
+
from guppylang_internals.tys.ty import FunctionType, Type
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
@dataclass(frozen=True)
|
|
29
|
+
class OverloadNoMatchError(Error):
|
|
30
|
+
title: ClassVar[str] = "Invalid call of overloaded function"
|
|
31
|
+
func: str
|
|
32
|
+
arg_tys: list[Type]
|
|
33
|
+
return_ty: Type | None
|
|
34
|
+
|
|
35
|
+
@property
|
|
36
|
+
def rendered_span_label(self) -> str:
|
|
37
|
+
stem = f"No variant of overloaded function `{self.func}` "
|
|
38
|
+
match self.arg_tys:
|
|
39
|
+
case []:
|
|
40
|
+
stem += "takes 0 arguments"
|
|
41
|
+
case [ty]:
|
|
42
|
+
stem += f"takes a `{ty}` argument"
|
|
43
|
+
case tys:
|
|
44
|
+
args = ", ".join(f"`{ty}`" for ty in tys)
|
|
45
|
+
stem += f"takes arguments {args}"
|
|
46
|
+
if self.return_ty:
|
|
47
|
+
stem += f" and returns `{self.return_ty}`"
|
|
48
|
+
return stem
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
@dataclass(frozen=True)
class AvailableOverloadsHint(Note):
    """Note listing the signature of every variant of an overloaded function.

    Args:
        func_name: Name of the overloaded function.
        variants: Types of the individual variants.
    """

    func_name: str
    variants: list[FunctionType]

    @property
    def rendered_message(self) -> str:
        """A header line followed by one signature line per variant."""
        signatures = "\n".join(
            f" {signature_to_str(self.func_name, ty)}" for ty in self.variants
        )
        return f"Available overloads are:\n{signatures}"
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
@dataclass(frozen=True)
class OverloadHigherOrderError(Error):
    """Diagnostic reported when an overloaded function is used as a value.

    Args:
        func: Name of the overloaded function.
    """

    title: ClassVar[str] = "Higher-order overloaded function"
    span_label: ClassVar[str] = "Overloaded function `{func}` may not be used as a higher-order value"  # noqa: E501

    # Interpolated into `span_label`.
    func: str
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
@dataclass(frozen=True)
class OverloadedFunctionDef(CompiledCallableDef, CallableDef):
    """A function that dispatches to one of several variant definitions.

    Calls are resolved during checking. The variants in `func_ids` are tried in order,
    and the first one whose check succeeds is used. Overloaded functions may not be
    used as higher-order values.

    Args:
        func_ids: Definition ids of the variants, in the order they are tried.
    """

    func_ids: list[DefId]
    description: str = field(default="overloaded function", init=False)

    def load(self, dfg: DFContainer, ctx: CompilerContext, node: AstNode) -> Wire:
        """Rejects loading the overloaded function as a value.

        Raises:
            GuppyError: Always, with an `OverloadHigherOrderError`.
        """
        raise GuppyError(OverloadHigherOrderError(node, self.name))

    def check_call(
        self, args: list[ast.expr], ty: Type, node: AstNode, ctx: Context
    ) -> tuple[ast.expr, Subst]:
        """Checks a call against an expected return type by trying each variant.

        Returns:
            The result of the first variant whose `check_call` succeeds.

        Raises:
            GuppyError: With an `OverloadNoMatchError` if no variant accepts the call.
        """
        available_sigs: list[FunctionType] = []
        for def_id in self.func_ids:
            defn = ctx.globals[def_id]
            assert isinstance(defn, CallableDef)
            available_sigs.append(defn.ty)
            # A failing variant is not an error by itself; try the next one.
            with suppress(GuppyError):
                return defn.check_call(args, ty, node, ctx)
        return self._call_error(args, node, ctx, available_sigs, ty)

    def synthesize_call(
        self, args: list[ast.expr], node: AstNode, ctx: "Context"
    ) -> tuple[ast.expr, Type]:
        """Synthesizes the return type of a call by trying each variant.

        Returns:
            The result of the first variant whose `synthesize_call` succeeds.

        Raises:
            GuppyError: With an `OverloadNoMatchError` if no variant accepts the call.
        """
        available_sigs: list[FunctionType] = []
        for def_id in self.func_ids:
            defn = ctx.globals[def_id]
            assert isinstance(defn, CallableDef)
            available_sigs.append(defn.ty)
            # A failing variant is not an error by itself; try the next one.
            with suppress(GuppyError):
                return defn.synthesize_call(args, node, ctx)
        return self._call_error(args, node, ctx, available_sigs)

    def _call_error(
        self,
        args: list[ast.expr],
        node: AstNode,
        ctx: "Context",
        available_sigs: list[FunctionType],
        return_ty: Type | None = None,
    ) -> NoReturn:
        """Raises an `OverloadNoMatchError` listing all available variants.

        If there are arguments and no return type was expected, only the arguments
        are highlighted. Otherwise the whole call node is highlighted.
        """
        # Test against `None` explicitly instead of relying on `Type` truthiness.
        if args and return_ty is None:
            start = to_span(args[0]).start
            end = to_span(args[-1]).end
            span = Span(start, end)
        else:
            span = to_span(node)

        synth = ExprSynthesizer(ctx)
        arg_tys = [synth.synthesize(arg)[1] for arg in args]
        err = OverloadNoMatchError(span, self.name, arg_tys, return_ty)
        err.add_sub_diagnostic(AvailableOverloadsHint(None, self.name, available_sigs))
        raise GuppyError(err)

    def compile_call(
        self,
        args: list[Wire],
        type_args: Inst,
        dfg: "DFContainer",
        ctx: "CompilerContext",
        node: AstNode,
    ) -> "CallReturnWires":
        """Never valid: checking replaces the call with a concrete variant.

        Raises:
            InternalGuppyError: Always.
        """
        # This should never be called: Checking the call replaces it with the concrete
        # implementation
        raise InternalGuppyError(
            "OverloadedFunctionDef.compile_call shouldn't be invoked"
        )

    def load_with_args(
        self,
        type_args: Inst,
        dfg: "DFContainer",
        ctx: "CompilerContext",
        node: AstNode,
    ) -> Wire:
        """Never valid: higher-order use is already rejected during checking.

        Raises:
            InternalGuppyError: Always.
        """
        # This should never be called: During checking we should have already ruled out
        # that overloaded functions are used as higher-order values.
        raise InternalGuppyError(
            "OverloadedFunctionDef.load_with_args shouldn't be invoked"
        )
|
|
@@ -0,0 +1,82 @@
|
|
|
1
|
+
import ast
|
|
2
|
+
from abc import abstractmethod
|
|
3
|
+
from dataclasses import dataclass, field
|
|
4
|
+
from typing import ClassVar
|
|
5
|
+
|
|
6
|
+
from guppylang_internals.checker.core import Globals
|
|
7
|
+
from guppylang_internals.definition.common import CompiledDef, Definition, ParsableDef
|
|
8
|
+
from guppylang_internals.diagnostic import Error
|
|
9
|
+
from guppylang_internals.error import GuppyError, InternalGuppyError
|
|
10
|
+
from guppylang_internals.span import SourceMap
|
|
11
|
+
from guppylang_internals.tys.param import ConstParam, Parameter, TypeParam
|
|
12
|
+
from guppylang_internals.tys.ty import Type
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@dataclass(frozen=True)
class LinearConstVarError(Error):
    """Diagnostic for a const variable whose type is not copyable or not droppable.

    Args:
        name: Name of the const variable.
        ty: The offending type.
    """

    title: ClassVar[str] = "Invalid const variable"
    span_label: ClassVar[str] = (
        "Const variable `{name}` is not allowed to have {thing} type `{ty}`"
    )
    name: str
    ty: Type

    @property
    def thing(self) -> str:
        """Which property is missing; non-copyability is reported first."""
        return "non-copyable" if not self.ty.copyable else "non-droppable"
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class ParamDef(Definition):
    """Abstract base for definitions that introduce a generic parameter."""

    @abstractmethod
    def to_param(self, idx: int) -> Parameter:
        """Builds the `Parameter` described by this definition.

        Args:
            idx: Index given to the created parameter.
        """
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
@dataclass(frozen=True)
class TypeVarDef(ParamDef, CompiledDef):
    """Definition of a user-declared type variable.

    Args:
        must_be_copyable: Copyability requirement forwarded to the `TypeParam`.
        must_be_droppable: Droppability requirement forwarded to the `TypeParam`.
    """

    must_be_copyable: bool
    must_be_droppable: bool

    description: str = field(default="type variable", init=False)

    def to_param(self, idx: int) -> TypeParam:
        """Builds the `TypeParam` corresponding to this type variable."""
        copyable, droppable = self.must_be_copyable, self.must_be_droppable
        return TypeParam(idx, self.name, copyable, droppable)
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
@dataclass(frozen=True)
class RawConstVarDef(ParamDef, ParsableDef):
    """A const variable definition whose type annotation has not been parsed yet.

    Args:
        type_ast: The unparsed type annotation of the const variable.
    """

    type_ast: ast.expr
    description: str = field(default="const variable", init=False)

    def parse(self, globals: Globals, sources: SourceMap) -> "ConstVarDef":
        """Parses the type annotation of the const variable.

        Returns:
            A `ConstVarDef` with the parsed type.

        Raises:
            GuppyError: With a `LinearConstVarError` if the parsed type is not both
                copyable and droppable.
        """
        # Local import; presumably avoids an import cycle with the type parser.
        from guppylang_internals.tys.parsing import type_from_ast

        ty = type_from_ast(self.type_ast, globals, {})
        if ty.copyable and ty.droppable:
            return ConstVarDef(self.id, self.name, self.defined_at, ty)
        raise GuppyError(LinearConstVarError(self.type_ast, self.name, ty))

    def to_param(self, idx: int) -> Parameter:
        """Unsupported before parsing; use the `ConstVarDef` returned by `parse`.

        Raises:
            InternalGuppyError: Always.
        """
        raise InternalGuppyError(
            "RawConstVarDef.to_param: Parameter conversion only possible after parsing"
        )
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
@dataclass(frozen=True)
class ConstVarDef(ParamDef, CompiledDef):
    """A const variable definition with a parsed type.

    Args:
        ty: The parsed type of the const variable. When created by
            `RawConstVarDef.parse`, it is both copyable and droppable.
    """

    ty: Type

    description: str = field(default="const variable", init=False)

    def to_param(self, idx: int) -> ConstParam:
        """Builds the `ConstParam` corresponding to this const variable."""
        param = ConstParam(idx, self.name, self.ty)
        return param
|