guppylang-internals 0.21.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (98)
  1. guppylang_internals/__init__.py +3 -0
  2. guppylang_internals/ast_util.py +350 -0
  3. guppylang_internals/cfg/__init__.py +0 -0
  4. guppylang_internals/cfg/analysis.py +230 -0
  5. guppylang_internals/cfg/bb.py +221 -0
  6. guppylang_internals/cfg/builder.py +606 -0
  7. guppylang_internals/cfg/cfg.py +117 -0
  8. guppylang_internals/checker/__init__.py +0 -0
  9. guppylang_internals/checker/cfg_checker.py +388 -0
  10. guppylang_internals/checker/core.py +550 -0
  11. guppylang_internals/checker/errors/__init__.py +0 -0
  12. guppylang_internals/checker/errors/comptime_errors.py +106 -0
  13. guppylang_internals/checker/errors/generic.py +45 -0
  14. guppylang_internals/checker/errors/linearity.py +300 -0
  15. guppylang_internals/checker/errors/type_errors.py +344 -0
  16. guppylang_internals/checker/errors/wasm.py +34 -0
  17. guppylang_internals/checker/expr_checker.py +1413 -0
  18. guppylang_internals/checker/func_checker.py +269 -0
  19. guppylang_internals/checker/linearity_checker.py +821 -0
  20. guppylang_internals/checker/stmt_checker.py +447 -0
  21. guppylang_internals/compiler/__init__.py +0 -0
  22. guppylang_internals/compiler/cfg_compiler.py +233 -0
  23. guppylang_internals/compiler/core.py +613 -0
  24. guppylang_internals/compiler/expr_compiler.py +989 -0
  25. guppylang_internals/compiler/func_compiler.py +97 -0
  26. guppylang_internals/compiler/hugr_extension.py +224 -0
  27. guppylang_internals/compiler/qtm_platform_extension.py +0 -0
  28. guppylang_internals/compiler/stmt_compiler.py +212 -0
  29. guppylang_internals/decorator.py +246 -0
  30. guppylang_internals/definition/__init__.py +0 -0
  31. guppylang_internals/definition/common.py +214 -0
  32. guppylang_internals/definition/const.py +74 -0
  33. guppylang_internals/definition/custom.py +492 -0
  34. guppylang_internals/definition/declaration.py +171 -0
  35. guppylang_internals/definition/extern.py +89 -0
  36. guppylang_internals/definition/function.py +302 -0
  37. guppylang_internals/definition/overloaded.py +150 -0
  38. guppylang_internals/definition/parameter.py +82 -0
  39. guppylang_internals/definition/pytket_circuits.py +405 -0
  40. guppylang_internals/definition/struct.py +392 -0
  41. guppylang_internals/definition/traced.py +151 -0
  42. guppylang_internals/definition/ty.py +51 -0
  43. guppylang_internals/definition/value.py +115 -0
  44. guppylang_internals/definition/wasm.py +61 -0
  45. guppylang_internals/diagnostic.py +523 -0
  46. guppylang_internals/dummy_decorator.py +76 -0
  47. guppylang_internals/engine.py +295 -0
  48. guppylang_internals/error.py +107 -0
  49. guppylang_internals/experimental.py +92 -0
  50. guppylang_internals/ipython_inspect.py +28 -0
  51. guppylang_internals/nodes.py +427 -0
  52. guppylang_internals/py.typed +0 -0
  53. guppylang_internals/span.py +150 -0
  54. guppylang_internals/std/__init__.py +0 -0
  55. guppylang_internals/std/_internal/__init__.py +0 -0
  56. guppylang_internals/std/_internal/checker.py +573 -0
  57. guppylang_internals/std/_internal/compiler/__init__.py +0 -0
  58. guppylang_internals/std/_internal/compiler/arithmetic.py +136 -0
  59. guppylang_internals/std/_internal/compiler/array.py +569 -0
  60. guppylang_internals/std/_internal/compiler/either.py +131 -0
  61. guppylang_internals/std/_internal/compiler/frozenarray.py +68 -0
  62. guppylang_internals/std/_internal/compiler/futures.py +30 -0
  63. guppylang_internals/std/_internal/compiler/list.py +348 -0
  64. guppylang_internals/std/_internal/compiler/mem.py +13 -0
  65. guppylang_internals/std/_internal/compiler/option.py +78 -0
  66. guppylang_internals/std/_internal/compiler/prelude.py +271 -0
  67. guppylang_internals/std/_internal/compiler/qsystem.py +48 -0
  68. guppylang_internals/std/_internal/compiler/quantum.py +118 -0
  69. guppylang_internals/std/_internal/compiler/tket_bool.py +55 -0
  70. guppylang_internals/std/_internal/compiler/tket_exts.py +59 -0
  71. guppylang_internals/std/_internal/compiler/wasm.py +135 -0
  72. guppylang_internals/std/_internal/compiler.py +0 -0
  73. guppylang_internals/std/_internal/debug.py +95 -0
  74. guppylang_internals/std/_internal/util.py +271 -0
  75. guppylang_internals/tracing/__init__.py +0 -0
  76. guppylang_internals/tracing/builtins_mock.py +62 -0
  77. guppylang_internals/tracing/frozenlist.py +57 -0
  78. guppylang_internals/tracing/function.py +186 -0
  79. guppylang_internals/tracing/object.py +551 -0
  80. guppylang_internals/tracing/state.py +69 -0
  81. guppylang_internals/tracing/unpacking.py +194 -0
  82. guppylang_internals/tracing/util.py +86 -0
  83. guppylang_internals/tys/__init__.py +0 -0
  84. guppylang_internals/tys/arg.py +115 -0
  85. guppylang_internals/tys/builtin.py +382 -0
  86. guppylang_internals/tys/common.py +110 -0
  87. guppylang_internals/tys/const.py +114 -0
  88. guppylang_internals/tys/errors.py +178 -0
  89. guppylang_internals/tys/param.py +251 -0
  90. guppylang_internals/tys/parsing.py +425 -0
  91. guppylang_internals/tys/printing.py +174 -0
  92. guppylang_internals/tys/subst.py +112 -0
  93. guppylang_internals/tys/ty.py +876 -0
  94. guppylang_internals/tys/var.py +49 -0
  95. guppylang_internals-0.21.0.dist-info/METADATA +253 -0
  96. guppylang_internals-0.21.0.dist-info/RECORD +98 -0
  97. guppylang_internals-0.21.0.dist-info/WHEEL +4 -0
  98. guppylang_internals-0.21.0.dist-info/licenses/LICENCE +201 -0
@@ -0,0 +1,89 @@
1
+ import ast
2
+ from dataclasses import dataclass, field
3
+
4
+ from hugr import Node, Wire, val
5
+ from hugr.build.dfg import DefinitionBuilder, OpVar
6
+
7
+ from guppylang_internals.ast_util import AstNode
8
+ from guppylang_internals.checker.core import Globals
9
+ from guppylang_internals.compiler.core import CompilerContext, DFContainer
10
+ from guppylang_internals.definition.common import CompilableDef, ParsableDef
11
+ from guppylang_internals.definition.value import (
12
+ CompiledHugrNodeDef,
13
+ CompiledValueDef,
14
+ ValueDef,
15
+ )
16
+ from guppylang_internals.span import SourceMap
17
+ from guppylang_internals.tys.parsing import type_from_ast
18
+
19
+
20
@dataclass(frozen=True)
class RawExternDef(ParsableDef):
    """A raw extern symbol definition provided by the user.

    The declared type of the symbol is kept as an unparsed AST until `parse` is
    called.

    Args:
        id: The unique definition identifier.
        name: The name of the definition.
        defined_at: The AST node where the extern was defined.
        symbol: The name of the external symbol.
        constant: Flag that is forwarded as the `constant` entry of the compiled
            Hugr constant.
        type_ast: The AST of the user-provided type annotation.
    """

    symbol: str
    constant: bool
    type_ast: ast.expr

    description: str = field(default="extern", init=False)

    def parse(self, globals: Globals, sources: SourceMap) -> "ExternDef":
        """Parses the user-provided type annotation of the extern symbol.

        Args:
            globals: The globals used to resolve names in the type annotation.
            sources: The source map (not used by this definition).

        Returns:
            An `ExternDef` that carries the parsed type alongside the raw data.
        """
        ty = type_from_ast(self.type_ast, globals, {})
        return ExternDef(
            self.id,
            self.name,
            self.defined_at,
            ty,
            self.symbol,
            self.constant,
            self.type_ast,
        )
41
+
42
+
43
@dataclass(frozen=True)
class ExternDef(RawExternDef, ValueDef, CompilableDef):
    """An extern symbol definition whose type has been parsed."""

    def compile_outer(
        self, graph: DefinitionBuilder[OpVar], ctx: CompilerContext
    ) -> "CompiledExternDef":
        """Emits a Hugr constant node describing the extern symbol into `graph`.

        Args:
            graph: The builder that receives the new constant node.
            ctx: The compiler context used to lower the type to Hugr.

        Returns:
            A `CompiledExternDef` that remembers the created constant node.
        """
        # `typ` has to be stored in serialised form here so that the resulting
        # `Extension` value can itself be serialised.
        serialised_ty = self.ty.to_hugr(ctx)._to_serial_root()
        payload = {
            "symbol": self.symbol,
            "typ": serialised_ty,
            "constant": self.constant,
        }
        extension_val = val.Extension(
            name="ConstExternalSymbol",
            typ=self.ty.to_hugr(ctx),
            val=payload,
        )
        new_node = graph.add_const(extension_val)
        return CompiledExternDef(
            self.id,
            self.name,
            self.defined_at,
            self.ty,
            self.symbol,
            self.constant,
            self.type_ast,
            new_node,
        )
74
+
75
+
76
@dataclass(frozen=True)
class CompiledExternDef(ExternDef, CompiledValueDef, CompiledHugrNodeDef):
    """An extern symbol definition that has been lowered to a Hugr constant node."""

    # The constant node that `ExternDef.compile_outer` added to the graph.
    const_node: Node

    @property
    def hugr_node(self) -> Node:
        """The Hugr constant node that backs this definition."""
        node = self.const_node
        return node

    def load(self, dfg: DFContainer, ctx: CompilerContext, node: AstNode) -> Wire:
        """Loads the extern constant into the dataflow graph `dfg`.

        Returns:
            The wire produced by loading the constant node.
        """
        builder = dfg.builder
        return builder.load(self.const_node)
@@ -0,0 +1,302 @@
1
+ import ast
2
+ import inspect
3
+ from collections.abc import Callable, Sequence
4
+ from dataclasses import dataclass, field
5
+ from typing import TYPE_CHECKING, Any
6
+
7
+ import hugr.build.function as hf
8
+ import hugr.tys as ht
9
+ from hugr import Node, Wire
10
+ from hugr.build.dfg import DefinitionBuilder, OpVar
11
+ from hugr.hugr.node_port import ToNode
12
+
13
+ from guppylang_internals.ast_util import AstNode, annotate_location, with_loc, with_type
14
+ from guppylang_internals.checker.cfg_checker import CheckedCFG
15
+ from guppylang_internals.checker.core import Context, Globals, Place
16
+ from guppylang_internals.checker.errors.generic import ExpectedError
17
+ from guppylang_internals.checker.expr_checker import check_call, synthesize_call
18
+ from guppylang_internals.checker.func_checker import (
19
+ check_global_func_def,
20
+ check_signature,
21
+ parse_function_with_docstring,
22
+ )
23
+ from guppylang_internals.compiler.core import (
24
+ CompilerContext,
25
+ DFContainer,
26
+ PartiallyMonomorphizedArgs,
27
+ )
28
+ from guppylang_internals.compiler.func_compiler import compile_global_func_def
29
+ from guppylang_internals.definition.common import (
30
+ CheckableDef,
31
+ MonomorphizableDef,
32
+ MonomorphizedDef,
33
+ ParsableDef,
34
+ UnknownSourceError,
35
+ )
36
+ from guppylang_internals.definition.value import (
37
+ CallableDef,
38
+ CallReturnWires,
39
+ CompiledCallableDef,
40
+ CompiledHugrNodeDef,
41
+ )
42
+ from guppylang_internals.error import GuppyError
43
+ from guppylang_internals.nodes import GlobalCall
44
+ from guppylang_internals.span import SourceMap
45
+ from guppylang_internals.tys.subst import Inst, Subst
46
+ from guppylang_internals.tys.ty import FunctionType, Type, type_to_row
47
+
48
+ if TYPE_CHECKING:
49
+ from guppylang_internals.tys.param import Parameter
50
+
51
# Type alias for the Python callables that back user-defined Guppy functions.
PyFunc = Callable[..., Any]


@dataclass(frozen=True)
class RawFunctionDef(ParsableDef):
    """A raw function definition provided by the user.

    The raw definition only stores the Python function object. No source has been
    read, parsed or checked yet; that happens in `parse`.

    Args:
        id: The unique definition identifier.
        name: The name of the function.
        defined_at: The AST node where the function was defined.
        python_func: The Python function to be defined.
    """

    python_func: PyFunc

    description: str = field(default="function", init=False)

    def parse(self, globals: Globals, sources: SourceMap) -> "ParsedFunctionDef":
        """Reads the function's source, then parses and checks its signature.

        As a side effect, the function's source file is registered in `sources`.

        Args:
            globals: The globals used to resolve names in the signature.
            sources: The source map to which the function's file is added.

        Returns:
            A `ParsedFunctionDef` carrying the function AST, its type and docstring.
        """
        func_ast, docstring = parse_py_func(self.python_func, sources)
        ty = check_signature(func_ast, globals)
        return ParsedFunctionDef(self.id, self.name, func_ast, ty, docstring)
78
+
79
+
80
@dataclass(frozen=True)
class ParsedFunctionDef(CheckableDef, CallableDef):
    """A function definition with parsed and checked signature.

    In particular, this means that we have determined a type for the function and are
    ready to check the function body.

    Args:
        id: The unique definition identifier.
        name: The name of the function.
        defined_at: The AST node where the function was defined.
        ty: The type of the function.
        docstring: The docstring of the function.
    """

    defined_at: ast.FunctionDef
    ty: FunctionType
    docstring: str | None

    description: str = field(default="function", init=False)

    def check(self, globals: Globals) -> "CheckedFunctionDef":
        """Type checks the body of the function.

        Returns:
            A `CheckedFunctionDef` that additionally holds the checked CFG of the body.
        """
        cfg = check_global_func_def(self.defined_at, self.ty, globals)
        return CheckedFunctionDef(
            self.id,
            self.name,
            self.defined_at,
            self.ty,
            self.docstring,
            cfg,
        )

    def check_call(
        self, args: list[ast.expr], ty: Type, node: AstNode, ctx: Context
    ) -> tuple[ast.expr, Subst]:
        """Checks the return type of a function call against a given type.

        Returns:
            The `GlobalCall` node for this call (built with `with_loc(node, ...)`) and
            the substitution inferred while checking.
        """
        # Use default implementation from the expression checker
        checked_args, subst, inst = check_call(self.ty, args, ty, node, ctx)
        call_node = with_loc(
            node, GlobalCall(def_id=self.id, args=checked_args, type_args=inst)
        )
        return call_node, subst

    def synthesize_call(
        self, args: list[ast.expr], node: AstNode, ctx: Context
    ) -> tuple[ast.expr, Type]:
        """Synthesizes the return type of a function call.

        Returns:
            The typed `GlobalCall` node for this call and the synthesized return type.
        """
        # Use default implementation from the expression checker
        checked_args, ret_ty, inst = synthesize_call(self.ty, args, node, ctx)
        call_node = with_loc(
            node, GlobalCall(def_id=self.id, args=checked_args, type_args=inst)
        )
        return with_type(ret_ty, call_node), ret_ty
132
+
133
+
134
@dataclass(frozen=True)
class CheckedFunctionDef(ParsedFunctionDef, MonomorphizableDef):
    """Type checked version of a user-defined function that is ready to be compiled.

    In particular, this means that we have constructed and type checked a control-flow
    graph for the function body.

    Args:
        id: The unique definition identifier.
        name: The name of the function.
        defined_at: The AST node where the function was defined.
        ty: The type of the function.
        docstring: The docstring of the function.
        cfg: The type- and linearity-checked CFG for the function body.
    """

    cfg: CheckedCFG[Place]

    @property
    def params(self) -> "Sequence[Parameter]":
        """Generic parameters of this function, taken from its type."""
        return self.ty.params

    def monomorphize(
        self,
        module: DefinitionBuilder[OpVar],
        mono_args: "PartiallyMonomorphizedArgs",
        ctx: "CompilerContext",
    ) -> "CompiledFunctionDef":
        """Adds a Hugr `FuncDefn` node for the (partially) monomorphized function to the
        Hugr.

        Note that we don't compile the function body at this point since we don't have
        access to the other compiled functions yet. The body is compiled later in
        `CompiledFunctionDef.compile_inner()`.

        Args:
            module: Builder whose module root receives the new function definition.
            mono_args: The partial monomorphization applied to the function type.
            ctx: The compiler context used to lower the type to Hugr.

        Returns:
            A `CompiledFunctionDef` whose `ty` is the partially monomorphized type.
        """
        mono_ty = self.ty.instantiate_partial(mono_args)
        hugr_ty = mono_ty.to_hugr_poly(ctx)
        func_def = module.module_root_builder().define_function(
            self.name, hugr_ty.body.input, hugr_ty.body.output, hugr_ty.params
        )
        return CompiledFunctionDef(
            self.id,
            self.name,
            self.defined_at,
            mono_args,
            mono_ty,
            self.docstring,
            self.cfg,
            func_def,
        )
186
+
187
+
188
@dataclass(frozen=True)
class CompiledFunctionDef(
    CheckedFunctionDef, CompiledCallableDef, MonomorphizedDef, CompiledHugrNodeDef
):
    """A function definition with a corresponding Hugr node.

    Args:
        id: The unique definition identifier.
        name: The name of the function.
        defined_at: The AST node where the function was defined.
        mono_args: Partial monomorphization of the generic type parameters.
        ty: The type of the function after partial monomorphization.
        docstring: The docstring of the function.
        cfg: The type- and linearity-checked CFG for the function body.
        func_def: The Hugr function definition.
    """

    func_def: hf.Function

    @property
    def hugr_node(self) -> Node:
        """The Hugr node this definition was compiled into."""
        return self.func_def.parent_node

    def load_with_args(
        self,
        type_args: Inst,
        dfg: DFContainer,
        ctx: CompilerContext,
        node: AstNode,
    ) -> Wire:
        """Loads the function as a value into a local Hugr dataflow graph.

        Delegates to the module-level `load_with_args` helper; `ctx` and `node` are
        not used.
        """
        return load_with_args(type_args, dfg, self.ty, self.func_def)

    def compile_call(
        self,
        args: list[Wire],
        type_args: Inst,
        dfg: DFContainer,
        ctx: CompilerContext,
        node: AstNode,
    ) -> CallReturnWires:
        """Compiles a call to the function.

        Delegates to the module-level `compile_call` helper; `ctx` and `node` are not
        used.
        """
        return compile_call(args, type_args, dfg, self.ty, self.func_def)

    def compile_inner(self, globals: CompilerContext) -> None:
        """Compiles the body of the function into `func_def`.

        Args:
            globals: The compiler context. The parameter name is kept as is for
                compatibility with keyword callers.
        """
        compile_global_func_def(self, self.func_def, globals)
237
+
238
+
239
def load_with_args(
    type_args: Inst,
    dfg: DFContainer,
    ty: FunctionType,
    func: ToNode,
) -> Wire:
    """Loads `func` as a function value into a local Hugr dataflow graph.

    Args:
        type_args: Guppy type arguments used to instantiate `ty`.
        dfg: The dataflow container to emit the load into.
        ty: The Guppy type of the function being loaded.
        func: The function definition (anything convertible to a Hugr node).

    Returns:
        The wire that carries the loaded function value.
    """
    instantiated: ht.FunctionType = ty.instantiate(type_args).to_hugr(dfg.ctx)
    hugr_type_args = [arg.to_hugr(dfg.ctx) for arg in type_args]
    return dfg.builder.load_function(func, instantiated, hugr_type_args)
249
+
250
+
251
def compile_call(
    args: list[Wire],
    type_args: Inst,  # Non-monomorphized type args only
    dfg: DFContainer,
    ty: FunctionType,
    func: ToNode,
) -> CallReturnWires:
    """Emits a Hugr call of `func` and splits its outputs.

    Args:
        args: Input wires passed to the call.
        type_args: Guppy type arguments used to instantiate `ty`.
        dfg: The dataflow container to emit the call into.
        ty: The Guppy type of the called function.
        func: The function definition (anything convertible to a Hugr node).

    Returns:
        The first `len(type_to_row(ty.output))` call outputs as the regular returns,
        and all remaining outputs as the inout returns.
    """
    instantiated: ht.FunctionType = ty.instantiate(type_args).to_hugr(dfg.ctx)
    hugr_type_args = [ta.to_hugr(dfg.ctx) for ta in type_args]
    n_regular = len(type_to_row(ty.output))
    call_node = dfg.builder.call(
        func, *args, instantiation=instantiated, type_args=hugr_type_args
    )
    regular_outs = list(call_node[:n_regular])
    inout_outs = list(call_node[n_regular:])
    return CallReturnWires(regular_returns=regular_outs, inout_returns=inout_outs)
267
+
268
+
269
def parse_py_func(f: PyFunc, sources: SourceMap) -> tuple[ast.FunctionDef, str | None]:
    """Parses the source of the Python function `f` into a located AST.

    The source lines of `f` are read and parsed. Its source file is registered in
    `sources` and the AST is annotated with location information. The result is then
    passed to `parse_function_with_docstring`.

    Raises:
        GuppyError: If the source file of `f` cannot be determined, or if the parsed
            source is not a function definition.
    """
    lines, first_line = inspect.getsourcelines(f)
    full_source, parsed, first_line = parse_source(lines, first_line)
    path = inspect.getsourcefile(f)
    if path is None:
        raise GuppyError(UnknownSourceError(None, f))
    sources.add_file(path)
    annotate_location(parsed, full_source, path, first_line)
    if isinstance(parsed, ast.FunctionDef):
        return parse_function_with_docstring(parsed)
    raise GuppyError(ExpectedError(parsed, "a function definition"))
280
+
281
+
282
def parse_source(source_lines: list[str], line_offset: int) -> tuple[str, ast.AST, int]:
    """Parses a list of source lines into an AST object.

    Indented source (e.g. a method or nested function) cannot be passed to `ast.parse`
    directly. Dedenting it would shift the column numbers recorded in spans, so the
    source is instead wrapped in a dummy class body.

    Args:
        source_lines: Lines of source, each still carrying its trailing newline.
        line_offset: Line offset of the first line, as returned by
            `inspect.getsourcelines`.

    Returns:
        The joined source, the first top-level AST node parsed from it, and the line
        offset, which is reduced by one if a wrapper line was prepended.
    """
    source = "".join(source_lines)  # each line already ends with "\n"
    if not source_lines[0][0].isspace():
        return source, ast.parse(source).body[0], line_offset
    # Prepending a dummy class header turns the indented code into a valid class body.
    wrapper = ast.parse("class _:\n" + source).body[0]
    assert isinstance(wrapper, ast.ClassDef)
    # The wrapper header takes up one extra line, so shift the offset back by one.
    return source, wrapper.body[0], line_offset - 1
@@ -0,0 +1,150 @@
1
+ import ast
2
+ from contextlib import suppress
3
+ from dataclasses import dataclass, field
4
+ from typing import ClassVar, NoReturn
5
+
6
+ from hugr import Wire
7
+
8
+ from guppylang_internals.ast_util import AstNode
9
+ from guppylang_internals.checker.core import Context
10
+ from guppylang_internals.checker.expr_checker import ExprSynthesizer
11
+ from guppylang_internals.compiler.core import CompilerContext, DFContainer
12
+ from guppylang_internals.definition.common import (
13
+ DefId,
14
+ )
15
+ from guppylang_internals.definition.value import (
16
+ CallableDef,
17
+ CallReturnWires,
18
+ CompiledCallableDef,
19
+ )
20
+ from guppylang_internals.diagnostic import Error, Note
21
+ from guppylang_internals.error import GuppyError, InternalGuppyError
22
+ from guppylang_internals.span import Span, to_span
23
+ from guppylang_internals.tys.printing import signature_to_str
24
+ from guppylang_internals.tys.subst import Inst, Subst
25
+ from guppylang_internals.tys.ty import FunctionType, Type
26
+
27
+
28
+ @dataclass(frozen=True)
29
+ class OverloadNoMatchError(Error):
30
+ title: ClassVar[str] = "Invalid call of overloaded function"
31
+ func: str
32
+ arg_tys: list[Type]
33
+ return_ty: Type | None
34
+
35
+ @property
36
+ def rendered_span_label(self) -> str:
37
+ stem = f"No variant of overloaded function `{self.func}` "
38
+ match self.arg_tys:
39
+ case []:
40
+ stem += "takes 0 arguments"
41
+ case [ty]:
42
+ stem += f"takes a `{ty}` argument"
43
+ case tys:
44
+ args = ", ".join(f"`{ty}`" for ty in tys)
45
+ stem += f"takes arguments {args}"
46
+ if self.return_ty:
47
+ stem += f" and returns `{self.return_ty}`"
48
+ return stem
49
+
50
+
51
@dataclass(frozen=True)
class AvailableOverloadsHint(Note):
    """Note that lists the signature of every variant of an overloaded function."""

    func_name: str
    variants: list[FunctionType]

    @property
    def rendered_message(self) -> str:
        """A header line followed by one signature line per variant."""
        header = "Available overloads are:\n"
        signatures = [
            f" {signature_to_str(self.func_name, variant)}" for variant in self.variants
        ]
        return header + "\n".join(signatures)
61
+
62
+
63
@dataclass(frozen=True)
class OverloadHigherOrderError(Error):
    """Error raised when an overloaded function is loaded as a value.

    Args:
        func: Name of the overloaded function.
    """

    func: str

    title: ClassVar[str] = "Higher-order overloaded function"
    span_label: ClassVar[str] = (
        "Overloaded function `{func}` may not be used as a higher-order value"
    )
+
71
+
72
+ @dataclass(frozen=True)
73
+ class OverloadedFunctionDef(CompiledCallableDef, CallableDef):
74
+ func_ids: list[DefId]
75
+ description: str = field(default="overloaded function", init=False)
76
+
77
+ def load(self, dfg: DFContainer, ctx: CompilerContext, node: AstNode) -> Wire:
78
+ raise GuppyError(OverloadHigherOrderError(node, self.name))
79
+
80
+ def check_call(
81
+ self, args: list[ast.expr], ty: Type, node: AstNode, ctx: Context
82
+ ) -> tuple[ast.expr, Subst]:
83
+ available_sigs: list[FunctionType] = []
84
+ for def_id in self.func_ids:
85
+ defn = ctx.globals[def_id]
86
+ assert isinstance(defn, CallableDef)
87
+ available_sigs.append(defn.ty)
88
+ with suppress(GuppyError):
89
+ return defn.check_call(args, ty, node, ctx)
90
+ return self._call_error(args, node, ctx, available_sigs, ty)
91
+
92
+ def synthesize_call(
93
+ self, args: list[ast.expr], node: AstNode, ctx: "Context"
94
+ ) -> tuple[ast.expr, Type]:
95
+ available_sigs: list[FunctionType] = []
96
+ for def_id in self.func_ids:
97
+ defn = ctx.globals[def_id]
98
+ assert isinstance(defn, CallableDef)
99
+ available_sigs.append(defn.ty)
100
+ with suppress(GuppyError):
101
+ return defn.synthesize_call(args, node, ctx)
102
+ return self._call_error(args, node, ctx, available_sigs)
103
+
104
+ def _call_error(
105
+ self,
106
+ args: list[ast.expr],
107
+ node: AstNode,
108
+ ctx: "Context",
109
+ available_sigs: list[FunctionType],
110
+ return_ty: Type | None = None,
111
+ ) -> NoReturn:
112
+ if args and not return_ty:
113
+ start = to_span(args[0]).start
114
+ end = to_span(args[-1]).end
115
+ span = Span(start, end)
116
+ else:
117
+ span = to_span(node)
118
+
119
+ synth = ExprSynthesizer(ctx)
120
+ arg_tys = [synth.synthesize(arg)[1] for arg in args]
121
+ err = OverloadNoMatchError(span, self.name, arg_tys, return_ty)
122
+ err.add_sub_diagnostic(AvailableOverloadsHint(None, self.name, available_sigs))
123
+ raise GuppyError(err)
124
+
125
+ def compile_call(
126
+ self,
127
+ args: list[Wire],
128
+ type_args: Inst,
129
+ dfg: "DFContainer",
130
+ ctx: "CompilerContext",
131
+ node: AstNode,
132
+ ) -> "CallReturnWires":
133
+ # This should never be called: Checking the call replaces it with the concrete
134
+ # implementation
135
+ raise InternalGuppyError(
136
+ "OverloadedFunctionDef.compile_call shouldn't be invoked"
137
+ )
138
+
139
+ def load_with_args(
140
+ self,
141
+ type_args: Inst,
142
+ dfg: "DFContainer",
143
+ ctx: "CompilerContext",
144
+ node: AstNode,
145
+ ) -> Wire:
146
+ # This should never be called: During checking we should have already ruled out
147
+ # that overloaded functions are used as higher-order values.
148
+ raise InternalGuppyError(
149
+ "OverloadedFunctionDef.load_with_args shouldn't be invoked"
150
+ )
@@ -0,0 +1,82 @@
1
+ import ast
2
+ from abc import abstractmethod
3
+ from dataclasses import dataclass, field
4
+ from typing import ClassVar
5
+
6
+ from guppylang_internals.checker.core import Globals
7
+ from guppylang_internals.definition.common import CompiledDef, Definition, ParsableDef
8
+ from guppylang_internals.diagnostic import Error
9
+ from guppylang_internals.error import GuppyError, InternalGuppyError
10
+ from guppylang_internals.span import SourceMap
11
+ from guppylang_internals.tys.param import ConstParam, Parameter, TypeParam
12
+ from guppylang_internals.tys.ty import Type
13
+
14
+
15
@dataclass(frozen=True)
class LinearConstVarError(Error):
    """Error for a const variable whose type is not both copyable and droppable.

    Args:
        name: Name of the const variable.
        ty: The offending type.
    """

    title: ClassVar[str] = "Invalid const variable"
    span_label: ClassVar[str] = (
        "Const variable `{name}` is not allowed to have {thing} type `{ty}`"
    )
    name: str
    ty: Type

    @property
    def thing(self) -> str:
        """Which capability the type lacks. A missing copyability is reported first."""
        return "non-copyable" if not self.ty.copyable else "non-droppable"
27
+
28
+
29
class ParamDef(Definition):
    """Abstract base class for definitions of generic parameters.

    The subclasses in this module define type variables and const variables.
    """

    @abstractmethod
    def to_param(self, idx: int) -> Parameter:
        """Creates a parameter from this definition.

        Args:
            idx: Index passed through to the constructed parameter.
        """
35
+
36
+
37
@dataclass(frozen=True)
class TypeVarDef(ParamDef, CompiledDef):
    """Definition of a type variable."""

    # Linearity requirements that are forwarded to the resulting `TypeParam`.
    must_be_copyable: bool
    must_be_droppable: bool

    description: str = field(default="type variable", init=False)

    def to_param(self, idx: int) -> TypeParam:
        """Turns this type variable into a `TypeParam` with index `idx`."""
        copyable, droppable = self.must_be_copyable, self.must_be_droppable
        return TypeParam(idx, self.name, copyable, droppable)
49
+
50
+
51
@dataclass(frozen=True)
class RawConstVarDef(ParamDef, ParsableDef):
    """A const variable definition whose type annotation is still an unparsed AST."""

    type_ast: ast.expr
    description: str = field(default="const variable", init=False)

    def parse(self, globals: Globals, sources: SourceMap) -> "ConstVarDef":
        """Parses the type of the const variable.

        Raises:
            GuppyError: If the parsed type is not both copyable and droppable.
        """
        # Local import, presumably to avoid an import cycle at module load time.
        from guppylang_internals.tys.parsing import type_from_ast

        parsed_ty = type_from_ast(self.type_ast, globals, {})
        if parsed_ty.copyable and parsed_ty.droppable:
            return ConstVarDef(self.id, self.name, self.defined_at, parsed_ty)
        raise GuppyError(LinearConstVarError(self.type_ast, self.name, parsed_ty))

    def to_param(self, idx: int) -> Parameter:
        """Not available before parsing. Always raises `InternalGuppyError`."""
        raise InternalGuppyError(
            "RawConstVarDef.to_param: Parameter conversion only possible after parsing"
        )
70
+
71
+
72
@dataclass(frozen=True)
class ConstVarDef(ParamDef, CompiledDef):
    """Definition of a const variable whose type has been parsed."""

    # Parsed type of the const variable.
    ty: Type

    description: str = field(default="const variable", init=False)

    def to_param(self, idx: int) -> ConstParam:
        """Turns this const variable into a `ConstParam` with index `idx`."""
        param = ConstParam(idx, self.name, self.ty)
        return param