guppylang-internals 0.21.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- guppylang_internals/__init__.py +3 -0
- guppylang_internals/ast_util.py +350 -0
- guppylang_internals/cfg/__init__.py +0 -0
- guppylang_internals/cfg/analysis.py +230 -0
- guppylang_internals/cfg/bb.py +221 -0
- guppylang_internals/cfg/builder.py +606 -0
- guppylang_internals/cfg/cfg.py +117 -0
- guppylang_internals/checker/__init__.py +0 -0
- guppylang_internals/checker/cfg_checker.py +388 -0
- guppylang_internals/checker/core.py +550 -0
- guppylang_internals/checker/errors/__init__.py +0 -0
- guppylang_internals/checker/errors/comptime_errors.py +106 -0
- guppylang_internals/checker/errors/generic.py +45 -0
- guppylang_internals/checker/errors/linearity.py +300 -0
- guppylang_internals/checker/errors/type_errors.py +344 -0
- guppylang_internals/checker/errors/wasm.py +34 -0
- guppylang_internals/checker/expr_checker.py +1413 -0
- guppylang_internals/checker/func_checker.py +269 -0
- guppylang_internals/checker/linearity_checker.py +821 -0
- guppylang_internals/checker/stmt_checker.py +447 -0
- guppylang_internals/compiler/__init__.py +0 -0
- guppylang_internals/compiler/cfg_compiler.py +233 -0
- guppylang_internals/compiler/core.py +613 -0
- guppylang_internals/compiler/expr_compiler.py +989 -0
- guppylang_internals/compiler/func_compiler.py +97 -0
- guppylang_internals/compiler/hugr_extension.py +224 -0
- guppylang_internals/compiler/qtm_platform_extension.py +0 -0
- guppylang_internals/compiler/stmt_compiler.py +212 -0
- guppylang_internals/decorator.py +246 -0
- guppylang_internals/definition/__init__.py +0 -0
- guppylang_internals/definition/common.py +214 -0
- guppylang_internals/definition/const.py +74 -0
- guppylang_internals/definition/custom.py +492 -0
- guppylang_internals/definition/declaration.py +171 -0
- guppylang_internals/definition/extern.py +89 -0
- guppylang_internals/definition/function.py +302 -0
- guppylang_internals/definition/overloaded.py +150 -0
- guppylang_internals/definition/parameter.py +82 -0
- guppylang_internals/definition/pytket_circuits.py +405 -0
- guppylang_internals/definition/struct.py +392 -0
- guppylang_internals/definition/traced.py +151 -0
- guppylang_internals/definition/ty.py +51 -0
- guppylang_internals/definition/value.py +115 -0
- guppylang_internals/definition/wasm.py +61 -0
- guppylang_internals/diagnostic.py +523 -0
- guppylang_internals/dummy_decorator.py +76 -0
- guppylang_internals/engine.py +295 -0
- guppylang_internals/error.py +107 -0
- guppylang_internals/experimental.py +92 -0
- guppylang_internals/ipython_inspect.py +28 -0
- guppylang_internals/nodes.py +427 -0
- guppylang_internals/py.typed +0 -0
- guppylang_internals/span.py +150 -0
- guppylang_internals/std/__init__.py +0 -0
- guppylang_internals/std/_internal/__init__.py +0 -0
- guppylang_internals/std/_internal/checker.py +573 -0
- guppylang_internals/std/_internal/compiler/__init__.py +0 -0
- guppylang_internals/std/_internal/compiler/arithmetic.py +136 -0
- guppylang_internals/std/_internal/compiler/array.py +569 -0
- guppylang_internals/std/_internal/compiler/either.py +131 -0
- guppylang_internals/std/_internal/compiler/frozenarray.py +68 -0
- guppylang_internals/std/_internal/compiler/futures.py +30 -0
- guppylang_internals/std/_internal/compiler/list.py +348 -0
- guppylang_internals/std/_internal/compiler/mem.py +13 -0
- guppylang_internals/std/_internal/compiler/option.py +78 -0
- guppylang_internals/std/_internal/compiler/prelude.py +271 -0
- guppylang_internals/std/_internal/compiler/qsystem.py +48 -0
- guppylang_internals/std/_internal/compiler/quantum.py +118 -0
- guppylang_internals/std/_internal/compiler/tket_bool.py +55 -0
- guppylang_internals/std/_internal/compiler/tket_exts.py +59 -0
- guppylang_internals/std/_internal/compiler/wasm.py +135 -0
- guppylang_internals/std/_internal/compiler.py +0 -0
- guppylang_internals/std/_internal/debug.py +95 -0
- guppylang_internals/std/_internal/util.py +271 -0
- guppylang_internals/tracing/__init__.py +0 -0
- guppylang_internals/tracing/builtins_mock.py +62 -0
- guppylang_internals/tracing/frozenlist.py +57 -0
- guppylang_internals/tracing/function.py +186 -0
- guppylang_internals/tracing/object.py +551 -0
- guppylang_internals/tracing/state.py +69 -0
- guppylang_internals/tracing/unpacking.py +194 -0
- guppylang_internals/tracing/util.py +86 -0
- guppylang_internals/tys/__init__.py +0 -0
- guppylang_internals/tys/arg.py +115 -0
- guppylang_internals/tys/builtin.py +382 -0
- guppylang_internals/tys/common.py +110 -0
- guppylang_internals/tys/const.py +114 -0
- guppylang_internals/tys/errors.py +178 -0
- guppylang_internals/tys/param.py +251 -0
- guppylang_internals/tys/parsing.py +425 -0
- guppylang_internals/tys/printing.py +174 -0
- guppylang_internals/tys/subst.py +112 -0
- guppylang_internals/tys/ty.py +876 -0
- guppylang_internals/tys/var.py +49 -0
- guppylang_internals-0.21.0.dist-info/METADATA +253 -0
- guppylang_internals-0.21.0.dist-info/RECORD +98 -0
- guppylang_internals-0.21.0.dist-info/WHEEL +4 -0
- guppylang_internals-0.21.0.dist-info/licenses/LICENCE +201 -0
|
@@ -0,0 +1,573 @@
|
|
|
1
|
+
import ast
|
|
2
|
+
from dataclasses import dataclass
|
|
3
|
+
from typing import ClassVar, cast
|
|
4
|
+
|
|
5
|
+
from typing_extensions import assert_never
|
|
6
|
+
|
|
7
|
+
from guppylang_internals.ast_util import get_type, with_loc, with_type
|
|
8
|
+
from guppylang_internals.checker.core import ComptimeVariable, Context
|
|
9
|
+
from guppylang_internals.checker.errors.generic import ExpectedError, UnsupportedError
|
|
10
|
+
from guppylang_internals.checker.errors.type_errors import (
|
|
11
|
+
ArrayComprUnknownSizeError,
|
|
12
|
+
TypeMismatchError,
|
|
13
|
+
)
|
|
14
|
+
from guppylang_internals.checker.expr_checker import (
|
|
15
|
+
ExprChecker,
|
|
16
|
+
ExprSynthesizer,
|
|
17
|
+
check_call,
|
|
18
|
+
check_num_args,
|
|
19
|
+
check_type_against,
|
|
20
|
+
synthesize_call,
|
|
21
|
+
synthesize_comprehension,
|
|
22
|
+
)
|
|
23
|
+
from guppylang_internals.definition.custom import (
|
|
24
|
+
CustomCallChecker,
|
|
25
|
+
)
|
|
26
|
+
from guppylang_internals.definition.struct import CheckedStructDef, RawStructDef
|
|
27
|
+
from guppylang_internals.diagnostic import Error, Note
|
|
28
|
+
from guppylang_internals.error import GuppyError, GuppyTypeError, InternalGuppyError
|
|
29
|
+
from guppylang_internals.nodes import (
|
|
30
|
+
BarrierExpr,
|
|
31
|
+
DesugaredArrayComp,
|
|
32
|
+
DesugaredGeneratorExpr,
|
|
33
|
+
ExitKind,
|
|
34
|
+
GenericParamValue,
|
|
35
|
+
GlobalCall,
|
|
36
|
+
MakeIter,
|
|
37
|
+
PanicExpr,
|
|
38
|
+
PlaceNode,
|
|
39
|
+
ResultExpr,
|
|
40
|
+
)
|
|
41
|
+
from guppylang_internals.tys.arg import ConstArg, TypeArg
|
|
42
|
+
from guppylang_internals.tys.builtin import (
|
|
43
|
+
array_type,
|
|
44
|
+
array_type_def,
|
|
45
|
+
bool_type,
|
|
46
|
+
get_element_type,
|
|
47
|
+
get_iter_size,
|
|
48
|
+
int_type,
|
|
49
|
+
is_array_type,
|
|
50
|
+
is_bool_type,
|
|
51
|
+
is_sized_iter_type,
|
|
52
|
+
nat_type,
|
|
53
|
+
sized_iter_type,
|
|
54
|
+
string_type,
|
|
55
|
+
)
|
|
56
|
+
from guppylang_internals.tys.const import Const, ConstValue
|
|
57
|
+
from guppylang_internals.tys.subst import Subst
|
|
58
|
+
from guppylang_internals.tys.ty import (
|
|
59
|
+
FuncInput,
|
|
60
|
+
FunctionType,
|
|
61
|
+
InputFlags,
|
|
62
|
+
NoneType,
|
|
63
|
+
NumericType,
|
|
64
|
+
StructType,
|
|
65
|
+
Type,
|
|
66
|
+
unify,
|
|
67
|
+
)
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
class ReversingChecker(CustomCallChecker):
    """Call checker for reverse arithmetic methods.

    For example, turns a call to `__radd__` into a call to `__add__` with reversed
    arguments.
    """

    def parse_name(self) -> str:
        """Returns the forward dunder name for this reverse method.

        For example, `__radd__` yields `__add__`.
        """
        # Must be a dunder method
        assert self.func.name.startswith("__")
        assert self.func.name.endswith("__")
        name = self.func.name[2:-2]
        # Remove the `r`
        assert name.startswith("r")
        return f"__{name[1:]}__"

    def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, Type]:
        """Synthesizes the call by delegating to the forward method.

        `args` is `[self, other]`. The forward method is looked up on the type of
        `self` and called with the arguments swapped, i.e. with `[other, self]`.
        """
        # Report a proper Guppy error (instead of crashing on the unpacking below)
        # when the reverse method is called directly with the wrong arity.
        check_num_args(2, len(args), self.node)
        [self_arg, other_arg] = args
        self_arg, self_ty = ExprSynthesizer(self.ctx).synthesize(self_arg)
        f = self.ctx.globals.get_instance_func(self_ty, self.parse_name())
        # Assumes every reverse method is registered alongside its forward method
        assert f is not None
        return f.synthesize_call([other_arg, self_arg], self.node, self.ctx)
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
class UnsupportedChecker(CustomCallChecker):
    """Call checker for Python builtin functions that are not available in Guppy.

    Gives the user a nicer error message when they try to use an unsupported feature.
    Both synthesis and checking mode always raise.
    """

    def _unsupported_error(self) -> GuppyError:
        """Builds the error reported for any use of the unsupported builtin."""
        err = UnsupportedError(
            self.node, f"Builtin method `{self.func.name}`", singular=True
        )
        return GuppyError(err)

    def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, Type]:
        raise self._unsupported_error()

    def check(self, args: list[ast.expr], ty: Type) -> tuple[ast.expr, Subst]:
        raise self._unsupported_error()
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
class DunderChecker(CustomCallChecker):
    """Call checker for builtin functions that call out to dunder instance methods.

    The first argument is the receiver: the call is synthesized as an instance call
    of `dunder_name` on it, with the remaining arguments passed along.
    """

    dunder_name: str
    num_args: int

    def __init__(self, dunder_name: str, num_args: int = 1):
        # The receiver itself counts as an argument, so at least one is required
        assert num_args > 0
        self.num_args = num_args
        self.dunder_name = dunder_name

    def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, Type]:
        check_num_args(self.num_args, len(args), self.node)
        receiver, *remaining = args
        synthesizer = ExprSynthesizer(self.ctx)
        description = f"a valid argument to `{self.func.name}`"
        return synthesizer.synthesize_instance_func(
            receiver, remaining, self.dunder_name, description, give_reason=True
        )
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
class CallableChecker(CustomCallChecker):
    """Call checker for the builtin `callable` function.

    The answer is decided statically from the argument's type and emitted as a
    boolean constant.
    """

    def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, Type]:
        check_num_args(1, len(args), self.node)
        (arg,) = args
        _, arg_ty = ExprSynthesizer(self.ctx).synthesize(arg)
        # Function values are callable, as is anything providing `__call__`
        if isinstance(arg_ty, FunctionType):
            is_callable = True
        else:
            call_method = self.ctx.globals.get_instance_func(arg_ty, "__call__")
            is_callable = call_method is not None
        result = ast.Constant(value=is_callable)
        return with_loc(self.node, result), bool_type()
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
class ArrayCopyChecker(CustomCallChecker):
    """Function call checker for the `array.copy` function."""

    @dataclass(frozen=True)
    class NonCopyableElementsError(Error):
        title: ClassVar[str] = "Non-copyable elements"
        span_label: ClassVar[str] = "Elements of type `{ty}` cannot be copied."
        ty: Type

        @dataclass(frozen=True)
        class Explanation(Note):
            message: ClassVar[str] = "Only arrays with copyable elements can be copied"

    def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, Type]:
        # Pre-check the element type so copying an array of non-copyable elements
        # gets a dedicated error; the full `synthesize_call` type check follows.
        if len(args) == 1:
            self._reject_non_copyable_elements(args)
        checked_args, _, inst = synthesize_call(
            self.func.ty, args, self.node, self.ctx
        )
        [array_arg] = checked_args
        call = GlobalCall(def_id=self.func.id, args=[array_arg], type_args=inst)
        return with_loc(self.node, call), get_type(array_arg)

    def _reject_non_copyable_elements(self, args: list[ast.expr]) -> None:
        """Synthesizes `args[0]` in place and raises if it is an array whose element
        type is not copyable."""
        args[0], array_ty = ExprSynthesizer(self.ctx).synthesize(args[0])
        if not is_array_type(array_ty):
            return
        elem_ty = get_element_type(array_ty)
        if elem_ty.copyable:
            return
        err = ArrayCopyChecker.NonCopyableElementsError(self.node, elem_ty)
        err.add_sub_diagnostic(
            ArrayCopyChecker.NonCopyableElementsError.Explanation(None)
        )
        raise GuppyTypeError(err)
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
class NewArrayChecker(CustomCallChecker):
|
|
183
|
+
"""Function call checker for the `array.__new__` function."""
|
|
184
|
+
|
|
185
|
+
@dataclass(frozen=True)
|
|
186
|
+
class InferenceError(Error):
|
|
187
|
+
title: ClassVar[str] = "Cannot infer type"
|
|
188
|
+
span_label: ClassVar[str] = "Cannot infer the type of this array"
|
|
189
|
+
|
|
190
|
+
@dataclass(frozen=True)
|
|
191
|
+
class Suggestion(Note):
|
|
192
|
+
message: ClassVar[str] = (
|
|
193
|
+
"Consider adding a type annotation: `x: array[???] = ...`"
|
|
194
|
+
)
|
|
195
|
+
|
|
196
|
+
def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, Type]:
|
|
197
|
+
match args:
|
|
198
|
+
case []:
|
|
199
|
+
err = NewArrayChecker.InferenceError(self.node)
|
|
200
|
+
err.add_sub_diagnostic(NewArrayChecker.InferenceError.Suggestion(None))
|
|
201
|
+
raise GuppyTypeError(err)
|
|
202
|
+
# Either an array comprehension
|
|
203
|
+
case [DesugaredGeneratorExpr() as compr]:
|
|
204
|
+
return self.synthesize_array_comprehension(compr)
|
|
205
|
+
# Or a list of array elements
|
|
206
|
+
case [fst, *rest]:
|
|
207
|
+
fst, ty = ExprSynthesizer(self.ctx).synthesize(fst)
|
|
208
|
+
checker = ExprChecker(self.ctx)
|
|
209
|
+
for i in range(len(rest)):
|
|
210
|
+
rest[i], subst = checker.check(rest[i], ty)
|
|
211
|
+
assert len(subst) == 0, "Array element type is closed"
|
|
212
|
+
result_ty = array_type(ty, len(args))
|
|
213
|
+
call = GlobalCall(
|
|
214
|
+
def_id=self.func.id, args=[fst, *rest], type_args=result_ty.args
|
|
215
|
+
)
|
|
216
|
+
return with_loc(self.node, call), result_ty
|
|
217
|
+
case args:
|
|
218
|
+
return assert_never(args) # type: ignore[arg-type]
|
|
219
|
+
|
|
220
|
+
def check(self, args: list[ast.expr], ty: Type) -> tuple[ast.expr, Subst]:
|
|
221
|
+
if not is_array_type(ty):
|
|
222
|
+
dummy_array_ty = array_type_def.check_instantiate(
|
|
223
|
+
[p.to_existential()[0] for p in array_type_def.params], self.node
|
|
224
|
+
)
|
|
225
|
+
raise GuppyTypeError(TypeMismatchError(self.node, ty, dummy_array_ty))
|
|
226
|
+
subst: Subst = {}
|
|
227
|
+
match ty.args:
|
|
228
|
+
case [TypeArg(ty=elem_ty), ConstArg(length)]:
|
|
229
|
+
match args:
|
|
230
|
+
# Either an array comprehension
|
|
231
|
+
case [DesugaredGeneratorExpr() as compr]:
|
|
232
|
+
# TODO: We could use the type information to infer some stuff
|
|
233
|
+
# in the comprehension
|
|
234
|
+
arr_compr, res_ty = self.synthesize_array_comprehension(compr)
|
|
235
|
+
arr_compr = with_loc(self.node, arr_compr)
|
|
236
|
+
arr_compr, subst, _ = check_type_against(
|
|
237
|
+
res_ty, ty, arr_compr, self.ctx
|
|
238
|
+
)
|
|
239
|
+
return arr_compr, subst
|
|
240
|
+
# Or a list of array elements
|
|
241
|
+
case args:
|
|
242
|
+
checker = ExprChecker(self.ctx)
|
|
243
|
+
for i in range(len(args)):
|
|
244
|
+
args[i], s = checker.check(
|
|
245
|
+
args[i], elem_ty.substitute(subst)
|
|
246
|
+
)
|
|
247
|
+
subst |= s
|
|
248
|
+
ls = unify(length, ConstValue(nat_type(), len(args)), {})
|
|
249
|
+
if ls is None:
|
|
250
|
+
raise GuppyTypeError(
|
|
251
|
+
TypeMismatchError(
|
|
252
|
+
self.node, ty, array_type(elem_ty, len(args))
|
|
253
|
+
)
|
|
254
|
+
)
|
|
255
|
+
subst |= ls
|
|
256
|
+
type_args = [
|
|
257
|
+
TypeArg(elem_ty.substitute(subst)),
|
|
258
|
+
ConstValue(nat_type(), len(args)),
|
|
259
|
+
]
|
|
260
|
+
call = GlobalCall(
|
|
261
|
+
def_id=self.func.id, args=args, type_args=type_args
|
|
262
|
+
)
|
|
263
|
+
return with_loc(self.node, call), subst
|
|
264
|
+
case type_args:
|
|
265
|
+
raise InternalGuppyError(f"Invalid array type args: {type_args}")
|
|
266
|
+
|
|
267
|
+
def synthesize_array_comprehension(
|
|
268
|
+
self, compr: DesugaredGeneratorExpr
|
|
269
|
+
) -> tuple[DesugaredArrayComp, Type]:
|
|
270
|
+
# Array comprehensions require a static size. To keep things simple, we'll only
|
|
271
|
+
# allow a single generator for now, so we don't have to reason about products
|
|
272
|
+
# of iterator sizes.
|
|
273
|
+
if len(compr.generators) > 1:
|
|
274
|
+
# Individual generator objects unfortunately don't have a span in Python's
|
|
275
|
+
# AST, so we have to use the whole expression span
|
|
276
|
+
raise GuppyError(UnsupportedError(compr, "Nested array comprehensions"))
|
|
277
|
+
[gen] = compr.generators
|
|
278
|
+
# Similarly, dynamic if guards are not allowed
|
|
279
|
+
if gen.ifs:
|
|
280
|
+
err = ArrayComprUnknownSizeError(compr)
|
|
281
|
+
err.add_sub_diagnostic(ArrayComprUnknownSizeError.IfGuard(gen.ifs[0]))
|
|
282
|
+
raise GuppyError(err)
|
|
283
|
+
# Extract the iterator size
|
|
284
|
+
match gen.iter_assign:
|
|
285
|
+
case ast.Assign(value=MakeIter() as make_iter):
|
|
286
|
+
sized_make_iter = MakeIter(
|
|
287
|
+
make_iter.value, make_iter.origin_node, unwrap_size_hint=False
|
|
288
|
+
)
|
|
289
|
+
_, iter_ty = ExprSynthesizer(self.ctx).synthesize(sized_make_iter)
|
|
290
|
+
# The iterator must have a static size hint
|
|
291
|
+
if not is_sized_iter_type(iter_ty):
|
|
292
|
+
err = ArrayComprUnknownSizeError(compr)
|
|
293
|
+
err.add_sub_diagnostic(
|
|
294
|
+
ArrayComprUnknownSizeError.DynamicIterator(make_iter)
|
|
295
|
+
)
|
|
296
|
+
raise GuppyError(err)
|
|
297
|
+
size = get_iter_size(iter_ty)
|
|
298
|
+
case _:
|
|
299
|
+
raise InternalGuppyError("Invalid iterator assign statement")
|
|
300
|
+
# Finally, type check the comprehension
|
|
301
|
+
[gen], elt, elt_ty = synthesize_comprehension(compr, [gen], compr.elt, self.ctx)
|
|
302
|
+
array_compr = DesugaredArrayComp(
|
|
303
|
+
elt=elt, generator=gen, length=size, elt_ty=elt_ty
|
|
304
|
+
)
|
|
305
|
+
return with_loc(compr, array_compr), array_type(elt_ty, size)
|
|
306
|
+
|
|
307
|
+
|
|
308
|
+
#: Maximum length of a tag in the `result` function.
#: The limit is compared against the UTF-8 encoded byte length of the tag.
TAG_MAX_LEN = 200


@dataclass(frozen=True)
class TooLongError(Error):
    """Error raised by `ResultChecker` when a `result` tag exceeds `TAG_MAX_LEN`
    bytes."""

    title: ClassVar[str] = "Tag too long"
    span_label: ClassVar[str] = "Result tag is too long"

    @dataclass(frozen=True)
    class Hint(Note):
        """Note telling the user the maximum allowed tag length."""

        message: ClassVar[str] = f"Result tags are limited to {TAG_MAX_LEN} bytes"
|
|
320
|
+
|
|
321
|
+
|
|
322
|
+
class ResultChecker(CustomCallChecker):
|
|
323
|
+
"""Call checker for the `result` function."""
|
|
324
|
+
|
|
325
|
+
@dataclass(frozen=True)
|
|
326
|
+
class InvalidError(Error):
|
|
327
|
+
title: ClassVar[str] = "Invalid Result"
|
|
328
|
+
span_label: ClassVar[str] = "Expression of type `{ty}` is not a valid result."
|
|
329
|
+
ty: Type
|
|
330
|
+
|
|
331
|
+
@dataclass(frozen=True)
|
|
332
|
+
class Explanation(Note):
|
|
333
|
+
message: ClassVar[str] = (
|
|
334
|
+
"Only numeric values or arrays thereof are allowed as results"
|
|
335
|
+
)
|
|
336
|
+
|
|
337
|
+
def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, Type]:
|
|
338
|
+
check_num_args(2, len(args), self.node)
|
|
339
|
+
[tag, value] = args
|
|
340
|
+
tag, _ = ExprChecker(self.ctx).check(tag, string_type())
|
|
341
|
+
match tag:
|
|
342
|
+
case ast.Constant(value=str(v)):
|
|
343
|
+
tag_value = v
|
|
344
|
+
case PlaceNode(place=ComptimeVariable(static_value=str(v))):
|
|
345
|
+
tag_value = v
|
|
346
|
+
case _:
|
|
347
|
+
raise GuppyTypeError(ExpectedError(tag, "a string literal"))
|
|
348
|
+
if len(tag_value.encode("utf-8")) > TAG_MAX_LEN:
|
|
349
|
+
err: Error = TooLongError(tag)
|
|
350
|
+
err.add_sub_diagnostic(TooLongError.Hint(None))
|
|
351
|
+
raise GuppyTypeError(err)
|
|
352
|
+
value, ty = ExprSynthesizer(self.ctx).synthesize(value)
|
|
353
|
+
# We only allow numeric values or vectors of numeric values
|
|
354
|
+
err = ResultChecker.InvalidError(value, ty)
|
|
355
|
+
err.add_sub_diagnostic(ResultChecker.InvalidError.Explanation(None))
|
|
356
|
+
if self._is_numeric_or_bool_type(ty):
|
|
357
|
+
base_ty = ty
|
|
358
|
+
array_len: Const | None = None
|
|
359
|
+
elif is_array_type(ty):
|
|
360
|
+
[ty_arg, len_arg] = ty.args
|
|
361
|
+
assert isinstance(ty_arg, TypeArg)
|
|
362
|
+
assert isinstance(len_arg, ConstArg)
|
|
363
|
+
if not self._is_numeric_or_bool_type(ty_arg.ty):
|
|
364
|
+
raise GuppyError(err)
|
|
365
|
+
base_ty = ty_arg.ty
|
|
366
|
+
array_len = len_arg.const
|
|
367
|
+
else:
|
|
368
|
+
raise GuppyError(err)
|
|
369
|
+
node = ResultExpr(value, base_ty, array_len, tag_value)
|
|
370
|
+
return with_loc(self.node, node), NoneType()
|
|
371
|
+
|
|
372
|
+
def check(self, args: list[ast.expr], ty: Type) -> tuple[ast.expr, Subst]:
|
|
373
|
+
expr, res_ty = self.synthesize(args)
|
|
374
|
+
expr, subst, _ = check_type_against(res_ty, ty, expr, self.ctx)
|
|
375
|
+
return expr, subst
|
|
376
|
+
|
|
377
|
+
@staticmethod
|
|
378
|
+
def _is_numeric_or_bool_type(ty: Type) -> bool:
|
|
379
|
+
return isinstance(ty, NumericType) or is_bool_type(ty)
|
|
380
|
+
|
|
381
|
+
|
|
382
|
+
class PanicChecker(CustomCallChecker):
|
|
383
|
+
"""Call checker for the `panic` function."""
|
|
384
|
+
|
|
385
|
+
@dataclass(frozen=True)
|
|
386
|
+
class NoMessageError(Error):
|
|
387
|
+
title: ClassVar[str] = "No panic message"
|
|
388
|
+
span_label: ClassVar[str] = "Missing message argument to panic call"
|
|
389
|
+
|
|
390
|
+
@dataclass(frozen=True)
|
|
391
|
+
class Suggestion(Note):
|
|
392
|
+
message: ClassVar[str] = 'Add a message: `panic("message")`'
|
|
393
|
+
|
|
394
|
+
def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, Type]:
|
|
395
|
+
match args:
|
|
396
|
+
case []:
|
|
397
|
+
err = PanicChecker.NoMessageError(self.node)
|
|
398
|
+
err.add_sub_diagnostic(PanicChecker.NoMessageError.Suggestion(None))
|
|
399
|
+
raise GuppyTypeError(err)
|
|
400
|
+
case [msg, *rest]:
|
|
401
|
+
msg, _ = ExprChecker(self.ctx).check(msg, string_type())
|
|
402
|
+
match msg:
|
|
403
|
+
case ast.Constant(value=str(v)):
|
|
404
|
+
msg_value = v
|
|
405
|
+
case PlaceNode(place=ComptimeVariable(static_value=str(v))):
|
|
406
|
+
msg_value = v
|
|
407
|
+
case _:
|
|
408
|
+
raise GuppyTypeError(ExpectedError(msg, "a string literal"))
|
|
409
|
+
vals = [ExprSynthesizer(self.ctx).synthesize(val)[0] for val in rest]
|
|
410
|
+
# TODO variable signals once default arguments are available
|
|
411
|
+
node = PanicExpr(
|
|
412
|
+
kind=ExitKind.Panic, msg=msg_value, values=vals, signal=1
|
|
413
|
+
)
|
|
414
|
+
return with_loc(self.node, node), NoneType()
|
|
415
|
+
case args:
|
|
416
|
+
return assert_never(args) # type: ignore[arg-type]
|
|
417
|
+
|
|
418
|
+
def check(self, args: list[ast.expr], ty: Type) -> tuple[ast.expr, Subst]:
|
|
419
|
+
# Panic may return any type, so we don't have to check anything. Consequently
|
|
420
|
+
# we also can't infer anything in the expected type, so we always return an
|
|
421
|
+
# empty substitution
|
|
422
|
+
expr, _ = self.synthesize(args)
|
|
423
|
+
return expr, {}
|
|
424
|
+
|
|
425
|
+
|
|
426
|
+
class ExitChecker(CustomCallChecker):
|
|
427
|
+
"""Call checker for the ``exit` functions."""
|
|
428
|
+
|
|
429
|
+
@dataclass(frozen=True)
|
|
430
|
+
class NoMessageError(Error):
|
|
431
|
+
title: ClassVar[str] = "No exit message"
|
|
432
|
+
span_label: ClassVar[str] = "Missing message argument to exit call"
|
|
433
|
+
|
|
434
|
+
@dataclass(frozen=True)
|
|
435
|
+
class Suggestion(Note):
|
|
436
|
+
message: ClassVar[str] = 'Add a message: `exit("message", 0)`'
|
|
437
|
+
|
|
438
|
+
@dataclass(frozen=True)
|
|
439
|
+
class NoSignalError(Error):
|
|
440
|
+
title: ClassVar[str] = "No exit signal"
|
|
441
|
+
span_label: ClassVar[str] = "Missing signal argument to exit call"
|
|
442
|
+
|
|
443
|
+
@dataclass(frozen=True)
|
|
444
|
+
class Suggestion(Note):
|
|
445
|
+
message: ClassVar[str] = 'Add a signal: `exit("message", 0)`'
|
|
446
|
+
|
|
447
|
+
def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, Type]:
|
|
448
|
+
match args:
|
|
449
|
+
case []:
|
|
450
|
+
msg_err = ExitChecker.NoMessageError(self.node)
|
|
451
|
+
msg_err.add_sub_diagnostic(ExitChecker.NoMessageError.Suggestion(None))
|
|
452
|
+
raise GuppyTypeError(msg_err)
|
|
453
|
+
case [_msg]:
|
|
454
|
+
signal_err = ExitChecker.NoSignalError(self.node)
|
|
455
|
+
signal_err.add_sub_diagnostic(
|
|
456
|
+
ExitChecker.NoSignalError.Suggestion(None)
|
|
457
|
+
)
|
|
458
|
+
raise GuppyTypeError(signal_err)
|
|
459
|
+
case [msg, signal, *rest]:
|
|
460
|
+
msg, _ = ExprChecker(self.ctx).check(msg, string_type())
|
|
461
|
+
match msg:
|
|
462
|
+
case ast.Constant(value=str(v)):
|
|
463
|
+
msg_value = v
|
|
464
|
+
case PlaceNode(place=ComptimeVariable(static_value=str(v))):
|
|
465
|
+
msg_value = v
|
|
466
|
+
case _:
|
|
467
|
+
raise GuppyTypeError(ExpectedError(msg, "a string literal"))
|
|
468
|
+
# TODO allow variable signals after https://github.com/CQCL/hugr/issues/1863
|
|
469
|
+
signal, _ = ExprChecker(self.ctx).check(signal, int_type())
|
|
470
|
+
match signal:
|
|
471
|
+
case ast.Constant(value=int(s)):
|
|
472
|
+
signal_value = s
|
|
473
|
+
case PlaceNode(place=ComptimeVariable(static_value=int(s))):
|
|
474
|
+
signal_value = s
|
|
475
|
+
case _:
|
|
476
|
+
raise GuppyTypeError(
|
|
477
|
+
ExpectedError(signal, "an integer literal")
|
|
478
|
+
)
|
|
479
|
+
vals = [ExprSynthesizer(self.ctx).synthesize(val)[0] for val in rest]
|
|
480
|
+
node = PanicExpr(
|
|
481
|
+
kind=ExitKind.ExitShot,
|
|
482
|
+
msg=msg_value,
|
|
483
|
+
values=vals,
|
|
484
|
+
signal=signal_value,
|
|
485
|
+
)
|
|
486
|
+
return with_loc(self.node, node), NoneType()
|
|
487
|
+
case args:
|
|
488
|
+
return assert_never(args) # type: ignore[arg-type]
|
|
489
|
+
|
|
490
|
+
def check(self, args: list[ast.expr], ty: Type) -> tuple[ast.expr, Subst]:
|
|
491
|
+
# Exit may return any type, so we don't have to check anything. Consequently
|
|
492
|
+
# we also can't infer anything in the expected type, so we always return an
|
|
493
|
+
# empty substitution
|
|
494
|
+
expr, _ = self.synthesize(args)
|
|
495
|
+
return expr, {}
|
|
496
|
+
|
|
497
|
+
|
|
498
|
+
class RangeChecker(CustomCallChecker):
|
|
499
|
+
"""Call checker for the `range` function."""
|
|
500
|
+
|
|
501
|
+
def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, Type]:
|
|
502
|
+
check_num_args(1, len(args), self.node)
|
|
503
|
+
[stop] = args
|
|
504
|
+
stop_checked, _ = ExprChecker(self.ctx).check(stop, int_type(), "argument")
|
|
505
|
+
range_iter, range_ty = self.make_range(stop_checked)
|
|
506
|
+
# Check if `stop` is a statically known value. Note that we need to do this on
|
|
507
|
+
# the original `stop` instead of `stop_checked` to avoid any previously inserted
|
|
508
|
+
# `int` coercions.
|
|
509
|
+
if (static_stop := self.check_static(stop)) is not None:
|
|
510
|
+
return to_sized_iter(range_iter, range_ty, static_stop, self.ctx)
|
|
511
|
+
return range_iter, range_ty
|
|
512
|
+
|
|
513
|
+
def check_static(self, stop: ast.expr) -> "int | Const | None":
|
|
514
|
+
stop, _ = ExprSynthesizer(self.ctx).synthesize(stop, allow_free_vars=True)
|
|
515
|
+
if isinstance(stop, ast.Constant) and isinstance(stop.value, int):
|
|
516
|
+
return stop.value
|
|
517
|
+
if isinstance(stop, GenericParamValue) and stop.param.ty == nat_type():
|
|
518
|
+
return stop.param.to_bound().const
|
|
519
|
+
return None
|
|
520
|
+
|
|
521
|
+
def range_ty(self) -> StructType:
|
|
522
|
+
from guppylang.std.builtins import Range
|
|
523
|
+
from guppylang_internals.engine import ENGINE
|
|
524
|
+
|
|
525
|
+
def_id = cast(RawStructDef, Range).id
|
|
526
|
+
range_type_def = ENGINE.get_checked(def_id)
|
|
527
|
+
assert isinstance(range_type_def, CheckedStructDef)
|
|
528
|
+
return StructType([], range_type_def)
|
|
529
|
+
|
|
530
|
+
def make_range(self, stop: ast.expr) -> tuple[ast.expr, Type]:
|
|
531
|
+
make_range = self.ctx.globals.get_instance_func(self.range_ty(), "__new__")
|
|
532
|
+
assert make_range is not None
|
|
533
|
+
start = with_type(int_type(), with_loc(self.node, ast.Constant(value=0)))
|
|
534
|
+
return make_range.synthesize_call([start, stop], self.node, self.ctx)
|
|
535
|
+
|
|
536
|
+
|
|
537
|
+
def to_sized_iter(
    iterator: ast.expr, range_ty: Type, size: "int | Const", ctx: Context
) -> tuple[ast.expr, Type]:
    """Adds a static size annotation to an iterator.

    Checks a call of the sized iterator type's `__new__` function on `iterator`,
    where `range_ty` is the iterator's type. Returns the checked expression together
    with the sized iterator type.
    """
    target_ty = sized_iter_type(range_ty, size)
    constructor = ctx.globals.get_instance_func(target_ty, "__new__")
    assert constructor is not None
    wrapped, _ = constructor.check_call([iterator], target_ty, iterator, ctx)
    return wrapped, target_ty
|
|
546
|
+
|
|
547
|
+
|
|
548
|
+
class BarrierChecker(CustomCallChecker):
    """Call checker for the `barrier` function.

    Builds a non-generic function type taking every argument's synthesized type as
    an `InputFlags.Inout` input and returning `None`, then checks the call against it.
    """

    def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, Type]:
        arg_tys = []
        for val in args:
            _, val_ty = ExprSynthesizer(self.ctx).synthesize(val)
            arg_tys.append(val_ty)
        inputs = [FuncInput(t, InputFlags.Inout) for t in arg_tys]
        func_ty = FunctionType(inputs, NoneType())
        checked_args, ret_ty, inst = synthesize_call(
            func_ty, args, self.node, self.ctx
        )
        assert len(inst) == 0, "func_ty is not generic"
        barrier = BarrierExpr(args=checked_args, func_ty=func_ty)
        return with_loc(self.node, barrier), ret_ty
|
|
561
|
+
|
|
562
|
+
|
|
563
|
+
class WasmCallChecker(CustomCallChecker):
    """Call checker for WASM function calls.

    Uses the default call checking/synthesis from the expression checker and emits
    a `GlobalCall` to the function.
    """

    def check(self, args: list[ast.expr], ty: Type) -> tuple[ast.expr, Subst]:
        # Use default implementation from the expression checker
        args, subst, inst = check_call(self.func.ty, args, ty, self.node, self.ctx)
        # Attach the call-site location, as every other checker in this module does
        call = GlobalCall(def_id=self.func.id, args=args, type_args=inst)
        return with_loc(self.node, call), subst

    def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, Type]:
        # Use default implementation from the expression checker
        args, ty, inst = synthesize_call(self.func.ty, args, self.node, self.ctx)
        call = GlobalCall(def_id=self.func.id, args=args, type_args=inst)
        return with_loc(self.node, call), ty
|
|
File without changes
|