guppylang-internals 0.21.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (98) hide show
  1. guppylang_internals/__init__.py +3 -0
  2. guppylang_internals/ast_util.py +350 -0
  3. guppylang_internals/cfg/__init__.py +0 -0
  4. guppylang_internals/cfg/analysis.py +230 -0
  5. guppylang_internals/cfg/bb.py +221 -0
  6. guppylang_internals/cfg/builder.py +606 -0
  7. guppylang_internals/cfg/cfg.py +117 -0
  8. guppylang_internals/checker/__init__.py +0 -0
  9. guppylang_internals/checker/cfg_checker.py +388 -0
  10. guppylang_internals/checker/core.py +550 -0
  11. guppylang_internals/checker/errors/__init__.py +0 -0
  12. guppylang_internals/checker/errors/comptime_errors.py +106 -0
  13. guppylang_internals/checker/errors/generic.py +45 -0
  14. guppylang_internals/checker/errors/linearity.py +300 -0
  15. guppylang_internals/checker/errors/type_errors.py +344 -0
  16. guppylang_internals/checker/errors/wasm.py +34 -0
  17. guppylang_internals/checker/expr_checker.py +1413 -0
  18. guppylang_internals/checker/func_checker.py +269 -0
  19. guppylang_internals/checker/linearity_checker.py +821 -0
  20. guppylang_internals/checker/stmt_checker.py +447 -0
  21. guppylang_internals/compiler/__init__.py +0 -0
  22. guppylang_internals/compiler/cfg_compiler.py +233 -0
  23. guppylang_internals/compiler/core.py +613 -0
  24. guppylang_internals/compiler/expr_compiler.py +989 -0
  25. guppylang_internals/compiler/func_compiler.py +97 -0
  26. guppylang_internals/compiler/hugr_extension.py +224 -0
  27. guppylang_internals/compiler/qtm_platform_extension.py +0 -0
  28. guppylang_internals/compiler/stmt_compiler.py +212 -0
  29. guppylang_internals/decorator.py +246 -0
  30. guppylang_internals/definition/__init__.py +0 -0
  31. guppylang_internals/definition/common.py +214 -0
  32. guppylang_internals/definition/const.py +74 -0
  33. guppylang_internals/definition/custom.py +492 -0
  34. guppylang_internals/definition/declaration.py +171 -0
  35. guppylang_internals/definition/extern.py +89 -0
  36. guppylang_internals/definition/function.py +302 -0
  37. guppylang_internals/definition/overloaded.py +150 -0
  38. guppylang_internals/definition/parameter.py +82 -0
  39. guppylang_internals/definition/pytket_circuits.py +405 -0
  40. guppylang_internals/definition/struct.py +392 -0
  41. guppylang_internals/definition/traced.py +151 -0
  42. guppylang_internals/definition/ty.py +51 -0
  43. guppylang_internals/definition/value.py +115 -0
  44. guppylang_internals/definition/wasm.py +61 -0
  45. guppylang_internals/diagnostic.py +523 -0
  46. guppylang_internals/dummy_decorator.py +76 -0
  47. guppylang_internals/engine.py +295 -0
  48. guppylang_internals/error.py +107 -0
  49. guppylang_internals/experimental.py +92 -0
  50. guppylang_internals/ipython_inspect.py +28 -0
  51. guppylang_internals/nodes.py +427 -0
  52. guppylang_internals/py.typed +0 -0
  53. guppylang_internals/span.py +150 -0
  54. guppylang_internals/std/__init__.py +0 -0
  55. guppylang_internals/std/_internal/__init__.py +0 -0
  56. guppylang_internals/std/_internal/checker.py +573 -0
  57. guppylang_internals/std/_internal/compiler/__init__.py +0 -0
  58. guppylang_internals/std/_internal/compiler/arithmetic.py +136 -0
  59. guppylang_internals/std/_internal/compiler/array.py +569 -0
  60. guppylang_internals/std/_internal/compiler/either.py +131 -0
  61. guppylang_internals/std/_internal/compiler/frozenarray.py +68 -0
  62. guppylang_internals/std/_internal/compiler/futures.py +30 -0
  63. guppylang_internals/std/_internal/compiler/list.py +348 -0
  64. guppylang_internals/std/_internal/compiler/mem.py +13 -0
  65. guppylang_internals/std/_internal/compiler/option.py +78 -0
  66. guppylang_internals/std/_internal/compiler/prelude.py +271 -0
  67. guppylang_internals/std/_internal/compiler/qsystem.py +48 -0
  68. guppylang_internals/std/_internal/compiler/quantum.py +118 -0
  69. guppylang_internals/std/_internal/compiler/tket_bool.py +55 -0
  70. guppylang_internals/std/_internal/compiler/tket_exts.py +59 -0
  71. guppylang_internals/std/_internal/compiler/wasm.py +135 -0
  72. guppylang_internals/std/_internal/compiler.py +0 -0
  73. guppylang_internals/std/_internal/debug.py +95 -0
  74. guppylang_internals/std/_internal/util.py +271 -0
  75. guppylang_internals/tracing/__init__.py +0 -0
  76. guppylang_internals/tracing/builtins_mock.py +62 -0
  77. guppylang_internals/tracing/frozenlist.py +57 -0
  78. guppylang_internals/tracing/function.py +186 -0
  79. guppylang_internals/tracing/object.py +551 -0
  80. guppylang_internals/tracing/state.py +69 -0
  81. guppylang_internals/tracing/unpacking.py +194 -0
  82. guppylang_internals/tracing/util.py +86 -0
  83. guppylang_internals/tys/__init__.py +0 -0
  84. guppylang_internals/tys/arg.py +115 -0
  85. guppylang_internals/tys/builtin.py +382 -0
  86. guppylang_internals/tys/common.py +110 -0
  87. guppylang_internals/tys/const.py +114 -0
  88. guppylang_internals/tys/errors.py +178 -0
  89. guppylang_internals/tys/param.py +251 -0
  90. guppylang_internals/tys/parsing.py +425 -0
  91. guppylang_internals/tys/printing.py +174 -0
  92. guppylang_internals/tys/subst.py +112 -0
  93. guppylang_internals/tys/ty.py +876 -0
  94. guppylang_internals/tys/var.py +49 -0
  95. guppylang_internals-0.21.0.dist-info/METADATA +253 -0
  96. guppylang_internals-0.21.0.dist-info/RECORD +98 -0
  97. guppylang_internals-0.21.0.dist-info/WHEEL +4 -0
  98. guppylang_internals-0.21.0.dist-info/licenses/LICENCE +201 -0
@@ -0,0 +1,573 @@
1
+ import ast
2
+ from dataclasses import dataclass
3
+ from typing import ClassVar, cast
4
+
5
+ from typing_extensions import assert_never
6
+
7
+ from guppylang_internals.ast_util import get_type, with_loc, with_type
8
+ from guppylang_internals.checker.core import ComptimeVariable, Context
9
+ from guppylang_internals.checker.errors.generic import ExpectedError, UnsupportedError
10
+ from guppylang_internals.checker.errors.type_errors import (
11
+ ArrayComprUnknownSizeError,
12
+ TypeMismatchError,
13
+ )
14
+ from guppylang_internals.checker.expr_checker import (
15
+ ExprChecker,
16
+ ExprSynthesizer,
17
+ check_call,
18
+ check_num_args,
19
+ check_type_against,
20
+ synthesize_call,
21
+ synthesize_comprehension,
22
+ )
23
+ from guppylang_internals.definition.custom import (
24
+ CustomCallChecker,
25
+ )
26
+ from guppylang_internals.definition.struct import CheckedStructDef, RawStructDef
27
+ from guppylang_internals.diagnostic import Error, Note
28
+ from guppylang_internals.error import GuppyError, GuppyTypeError, InternalGuppyError
29
+ from guppylang_internals.nodes import (
30
+ BarrierExpr,
31
+ DesugaredArrayComp,
32
+ DesugaredGeneratorExpr,
33
+ ExitKind,
34
+ GenericParamValue,
35
+ GlobalCall,
36
+ MakeIter,
37
+ PanicExpr,
38
+ PlaceNode,
39
+ ResultExpr,
40
+ )
41
+ from guppylang_internals.tys.arg import ConstArg, TypeArg
42
+ from guppylang_internals.tys.builtin import (
43
+ array_type,
44
+ array_type_def,
45
+ bool_type,
46
+ get_element_type,
47
+ get_iter_size,
48
+ int_type,
49
+ is_array_type,
50
+ is_bool_type,
51
+ is_sized_iter_type,
52
+ nat_type,
53
+ sized_iter_type,
54
+ string_type,
55
+ )
56
+ from guppylang_internals.tys.const import Const, ConstValue
57
+ from guppylang_internals.tys.subst import Subst
58
+ from guppylang_internals.tys.ty import (
59
+ FuncInput,
60
+ FunctionType,
61
+ InputFlags,
62
+ NoneType,
63
+ NumericType,
64
+ StructType,
65
+ Type,
66
+ unify,
67
+ )
68
+
69
+
70
class ReversingChecker(CustomCallChecker):
    """Call checker for reflected ("reverse") arithmetic dunder methods.

    For example, a call to `__radd__` is turned into a call to `__add__` with the
    two arguments swapped.
    """

    def parse_name(self) -> str:
        """Returns the name of the forward counterpart of the checked function.

        The function name is required to have the shape `__r<op>__`; the result is
        `__<op>__`.
        """
        reflected = self.func.name
        # Only dunder methods can be reversed
        assert reflected.startswith("__")
        assert reflected.endswith("__")
        stem = reflected[2:-2]
        # Drop the leading `r` that marks the reflected variant
        assert stem.startswith("r")
        return "__" + stem[1:] + "__"

    def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, Type]:
        """Synthesizes the call by delegating to the forward method.

        The forward method is looked up on the synthesized type of the first
        argument and is then called with both arguments in swapped order.
        """
        receiver, operand = args
        receiver, receiver_ty = ExprSynthesizer(self.ctx).synthesize(receiver)
        forward = self.ctx.globals.get_instance_func(receiver_ty, self.parse_name())
        assert forward is not None
        return forward.synthesize_call([operand, receiver], self.node, self.ctx)
92
+
93
+
94
class UnsupportedChecker(CustomCallChecker):
    """Call checker for Python builtin functions that are not available in Guppy.

    Gives users a nicer error message when they try to use an unsupported feature:
    both synthesis and checking always fail with an `UnsupportedError`.
    """

    def _unsupported_error(self) -> GuppyError:
        """Builds the error that is raised for every use of the builtin."""
        description = f"Builtin method `{self.func.name}`"
        return GuppyError(UnsupportedError(self.node, description, singular=True))

    def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, Type]:
        """Always raises, the arguments are never inspected."""
        raise self._unsupported_error()

    def check(self, args: list[ast.expr], ty: Type) -> tuple[ast.expr, Subst]:
        """Always raises, neither the arguments nor the expected type are used."""
        raise self._unsupported_error()
111
+
112
+
113
class DunderChecker(CustomCallChecker):
    """Call checker for builtins that forward to a dunder instance method.

    The first call argument acts as the receiver of the dunder method, all
    remaining arguments are passed along to it.
    """

    #: Name of the dunder method that the builtin forwards to
    dunder_name: str
    #: Exact number of arguments the builtin expects (receiver included)
    num_args: int

    def __init__(self, dunder_name: str, num_args: int = 1):
        # At least the receiver of the dunder method is required
        assert num_args > 0
        self.num_args = num_args
        self.dunder_name = dunder_name

    def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, Type]:
        """Synthesizes the call as an instance-method call on the first argument."""
        check_num_args(self.num_args, len(args), self.node)
        receiver, *others = args
        synthesizer = ExprSynthesizer(self.ctx)
        # Wording used in the diagnostic if the receiver lacks the dunder method
        description = f"a valid argument to `{self.func.name}`"
        return synthesizer.synthesize_instance_func(
            receiver, others, self.dunder_name, description, give_reason=True
        )
134
+
135
+
136
class CallableChecker(CustomCallChecker):
    """Call checker for the builtin `callable` function.

    The answer is decided during checking, so the call is replaced by a boolean
    constant.
    """

    def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, Type]:
        """Synthesizes `callable(x)` into a constant `True` or `False`."""
        check_num_args(1, len(args), self.node)
        (candidate,) = args
        _, candidate_ty = ExprSynthesizer(self.ctx).synthesize(candidate)
        if isinstance(candidate_ty, FunctionType):
            is_callable = True
        else:
            # Values of other types are callable iff their type has a `__call__`
            call_method = self.ctx.globals.get_instance_func(candidate_ty, "__call__")
            is_callable = call_method is not None
        result = ast.Constant(value=is_callable)
        return with_loc(self.node, result), bool_type()
149
+
150
+
151
class ArrayCopyChecker(CustomCallChecker):
    """Function call checker for the `array.copy` function."""

    @dataclass(frozen=True)
    class NonCopyableElementsError(Error):
        title: ClassVar[str] = "Non-copyable elements"
        span_label: ClassVar[str] = "Elements of type `{ty}` cannot be copied."
        ty: Type

        @dataclass(frozen=True)
        class Explanation(Note):
            message: ClassVar[str] = "Only arrays with copyable elements can be copied"

    def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, Type]:
        """Synthesizes a call to `array.copy`.

        Raises a `GuppyTypeError` carrying a `NonCopyableElementsError` if the single
        argument is an array whose element type is not copyable. Every other
        problem is left to the regular `synthesize_call` check.
        """
        # Pre-check the element type so that the common mistake gets a nicer error
        # message than the one the generic call check would produce
        if len(args) == 1:
            synthesized, array_ty = ExprSynthesizer(self.ctx).synthesize(args[0])
            args[0] = synthesized
            if is_array_type(array_ty):
                elem_ty = get_element_type(array_ty)
                if not elem_ty.copyable:
                    error_cls = ArrayCopyChecker.NonCopyableElementsError
                    err = error_cls(self.node, elem_ty)
                    err.add_sub_diagnostic(error_cls.Explanation(None))
                    raise GuppyTypeError(err)
        # Now run the full type check of the call
        checked_args, _, inst = synthesize_call(
            self.func.ty, args, self.node, self.ctx
        )
        (array_arg,) = checked_args
        call = GlobalCall(def_id=self.func.id, args=[array_arg], type_args=inst)
        return with_loc(self.node, call), get_type(array_arg)
180
+
181
+
182
+ class NewArrayChecker(CustomCallChecker):
183
+ """Function call checker for the `array.__new__` function."""
184
+
185
+ @dataclass(frozen=True)
186
+ class InferenceError(Error):
187
+ title: ClassVar[str] = "Cannot infer type"
188
+ span_label: ClassVar[str] = "Cannot infer the type of this array"
189
+
190
+ @dataclass(frozen=True)
191
+ class Suggestion(Note):
192
+ message: ClassVar[str] = (
193
+ "Consider adding a type annotation: `x: array[???] = ...`"
194
+ )
195
+
196
+ def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, Type]:
197
+ match args:
198
+ case []:
199
+ err = NewArrayChecker.InferenceError(self.node)
200
+ err.add_sub_diagnostic(NewArrayChecker.InferenceError.Suggestion(None))
201
+ raise GuppyTypeError(err)
202
+ # Either an array comprehension
203
+ case [DesugaredGeneratorExpr() as compr]:
204
+ return self.synthesize_array_comprehension(compr)
205
+ # Or a list of array elements
206
+ case [fst, *rest]:
207
+ fst, ty = ExprSynthesizer(self.ctx).synthesize(fst)
208
+ checker = ExprChecker(self.ctx)
209
+ for i in range(len(rest)):
210
+ rest[i], subst = checker.check(rest[i], ty)
211
+ assert len(subst) == 0, "Array element type is closed"
212
+ result_ty = array_type(ty, len(args))
213
+ call = GlobalCall(
214
+ def_id=self.func.id, args=[fst, *rest], type_args=result_ty.args
215
+ )
216
+ return with_loc(self.node, call), result_ty
217
+ case args:
218
+ return assert_never(args) # type: ignore[arg-type]
219
+
220
+ def check(self, args: list[ast.expr], ty: Type) -> tuple[ast.expr, Subst]:
221
+ if not is_array_type(ty):
222
+ dummy_array_ty = array_type_def.check_instantiate(
223
+ [p.to_existential()[0] for p in array_type_def.params], self.node
224
+ )
225
+ raise GuppyTypeError(TypeMismatchError(self.node, ty, dummy_array_ty))
226
+ subst: Subst = {}
227
+ match ty.args:
228
+ case [TypeArg(ty=elem_ty), ConstArg(length)]:
229
+ match args:
230
+ # Either an array comprehension
231
+ case [DesugaredGeneratorExpr() as compr]:
232
+ # TODO: We could use the type information to infer some stuff
233
+ # in the comprehension
234
+ arr_compr, res_ty = self.synthesize_array_comprehension(compr)
235
+ arr_compr = with_loc(self.node, arr_compr)
236
+ arr_compr, subst, _ = check_type_against(
237
+ res_ty, ty, arr_compr, self.ctx
238
+ )
239
+ return arr_compr, subst
240
+ # Or a list of array elements
241
+ case args:
242
+ checker = ExprChecker(self.ctx)
243
+ for i in range(len(args)):
244
+ args[i], s = checker.check(
245
+ args[i], elem_ty.substitute(subst)
246
+ )
247
+ subst |= s
248
+ ls = unify(length, ConstValue(nat_type(), len(args)), {})
249
+ if ls is None:
250
+ raise GuppyTypeError(
251
+ TypeMismatchError(
252
+ self.node, ty, array_type(elem_ty, len(args))
253
+ )
254
+ )
255
+ subst |= ls
256
+ type_args = [
257
+ TypeArg(elem_ty.substitute(subst)),
258
+ ConstValue(nat_type(), len(args)),
259
+ ]
260
+ call = GlobalCall(
261
+ def_id=self.func.id, args=args, type_args=type_args
262
+ )
263
+ return with_loc(self.node, call), subst
264
+ case type_args:
265
+ raise InternalGuppyError(f"Invalid array type args: {type_args}")
266
+
267
+ def synthesize_array_comprehension(
268
+ self, compr: DesugaredGeneratorExpr
269
+ ) -> tuple[DesugaredArrayComp, Type]:
270
+ # Array comprehensions require a static size. To keep things simple, we'll only
271
+ # allow a single generator for now, so we don't have to reason about products
272
+ # of iterator sizes.
273
+ if len(compr.generators) > 1:
274
+ # Individual generator objects unfortunately don't have a span in Python's
275
+ # AST, so we have to use the whole expression span
276
+ raise GuppyError(UnsupportedError(compr, "Nested array comprehensions"))
277
+ [gen] = compr.generators
278
+ # Similarly, dynamic if guards are not allowed
279
+ if gen.ifs:
280
+ err = ArrayComprUnknownSizeError(compr)
281
+ err.add_sub_diagnostic(ArrayComprUnknownSizeError.IfGuard(gen.ifs[0]))
282
+ raise GuppyError(err)
283
+ # Extract the iterator size
284
+ match gen.iter_assign:
285
+ case ast.Assign(value=MakeIter() as make_iter):
286
+ sized_make_iter = MakeIter(
287
+ make_iter.value, make_iter.origin_node, unwrap_size_hint=False
288
+ )
289
+ _, iter_ty = ExprSynthesizer(self.ctx).synthesize(sized_make_iter)
290
+ # The iterator must have a static size hint
291
+ if not is_sized_iter_type(iter_ty):
292
+ err = ArrayComprUnknownSizeError(compr)
293
+ err.add_sub_diagnostic(
294
+ ArrayComprUnknownSizeError.DynamicIterator(make_iter)
295
+ )
296
+ raise GuppyError(err)
297
+ size = get_iter_size(iter_ty)
298
+ case _:
299
+ raise InternalGuppyError("Invalid iterator assign statement")
300
+ # Finally, type check the comprehension
301
+ [gen], elt, elt_ty = synthesize_comprehension(compr, [gen], compr.elt, self.ctx)
302
+ array_compr = DesugaredArrayComp(
303
+ elt=elt, generator=gen, length=size, elt_ty=elt_ty
304
+ )
305
+ return with_loc(compr, array_compr), array_type(elt_ty, size)
306
+
307
+
308
#: Maximum length of a tag in the `result` function, measured in bytes of the
#: tag's UTF-8 encoding.
TAG_MAX_LEN = 200


# Diagnostic that is reported when a result tag exceeds `TAG_MAX_LEN`.
@dataclass(frozen=True)
class TooLongError(Error):
    title: ClassVar[str] = "Tag too long"
    span_label: ClassVar[str] = "Result tag is too long"

    # Sub-diagnostic stating the limit. The value of `TAG_MAX_LEN` is interpolated
    # once, when this class body is evaluated.
    @dataclass(frozen=True)
    class Hint(Note):
        message: ClassVar[str] = f"Result tags are limited to {TAG_MAX_LEN} bytes"
320
+
321
+
322
+ class ResultChecker(CustomCallChecker):
323
+ """Call checker for the `result` function."""
324
+
325
+ @dataclass(frozen=True)
326
+ class InvalidError(Error):
327
+ title: ClassVar[str] = "Invalid Result"
328
+ span_label: ClassVar[str] = "Expression of type `{ty}` is not a valid result."
329
+ ty: Type
330
+
331
+ @dataclass(frozen=True)
332
+ class Explanation(Note):
333
+ message: ClassVar[str] = (
334
+ "Only numeric values or arrays thereof are allowed as results"
335
+ )
336
+
337
+ def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, Type]:
338
+ check_num_args(2, len(args), self.node)
339
+ [tag, value] = args
340
+ tag, _ = ExprChecker(self.ctx).check(tag, string_type())
341
+ match tag:
342
+ case ast.Constant(value=str(v)):
343
+ tag_value = v
344
+ case PlaceNode(place=ComptimeVariable(static_value=str(v))):
345
+ tag_value = v
346
+ case _:
347
+ raise GuppyTypeError(ExpectedError(tag, "a string literal"))
348
+ if len(tag_value.encode("utf-8")) > TAG_MAX_LEN:
349
+ err: Error = TooLongError(tag)
350
+ err.add_sub_diagnostic(TooLongError.Hint(None))
351
+ raise GuppyTypeError(err)
352
+ value, ty = ExprSynthesizer(self.ctx).synthesize(value)
353
+ # We only allow numeric values or vectors of numeric values
354
+ err = ResultChecker.InvalidError(value, ty)
355
+ err.add_sub_diagnostic(ResultChecker.InvalidError.Explanation(None))
356
+ if self._is_numeric_or_bool_type(ty):
357
+ base_ty = ty
358
+ array_len: Const | None = None
359
+ elif is_array_type(ty):
360
+ [ty_arg, len_arg] = ty.args
361
+ assert isinstance(ty_arg, TypeArg)
362
+ assert isinstance(len_arg, ConstArg)
363
+ if not self._is_numeric_or_bool_type(ty_arg.ty):
364
+ raise GuppyError(err)
365
+ base_ty = ty_arg.ty
366
+ array_len = len_arg.const
367
+ else:
368
+ raise GuppyError(err)
369
+ node = ResultExpr(value, base_ty, array_len, tag_value)
370
+ return with_loc(self.node, node), NoneType()
371
+
372
+ def check(self, args: list[ast.expr], ty: Type) -> tuple[ast.expr, Subst]:
373
+ expr, res_ty = self.synthesize(args)
374
+ expr, subst, _ = check_type_against(res_ty, ty, expr, self.ctx)
375
+ return expr, subst
376
+
377
+ @staticmethod
378
+ def _is_numeric_or_bool_type(ty: Type) -> bool:
379
+ return isinstance(ty, NumericType) or is_bool_type(ty)
380
+
381
+
382
+ class PanicChecker(CustomCallChecker):
383
+ """Call checker for the `panic` function."""
384
+
385
+ @dataclass(frozen=True)
386
+ class NoMessageError(Error):
387
+ title: ClassVar[str] = "No panic message"
388
+ span_label: ClassVar[str] = "Missing message argument to panic call"
389
+
390
+ @dataclass(frozen=True)
391
+ class Suggestion(Note):
392
+ message: ClassVar[str] = 'Add a message: `panic("message")`'
393
+
394
+ def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, Type]:
395
+ match args:
396
+ case []:
397
+ err = PanicChecker.NoMessageError(self.node)
398
+ err.add_sub_diagnostic(PanicChecker.NoMessageError.Suggestion(None))
399
+ raise GuppyTypeError(err)
400
+ case [msg, *rest]:
401
+ msg, _ = ExprChecker(self.ctx).check(msg, string_type())
402
+ match msg:
403
+ case ast.Constant(value=str(v)):
404
+ msg_value = v
405
+ case PlaceNode(place=ComptimeVariable(static_value=str(v))):
406
+ msg_value = v
407
+ case _:
408
+ raise GuppyTypeError(ExpectedError(msg, "a string literal"))
409
+ vals = [ExprSynthesizer(self.ctx).synthesize(val)[0] for val in rest]
410
+ # TODO variable signals once default arguments are available
411
+ node = PanicExpr(
412
+ kind=ExitKind.Panic, msg=msg_value, values=vals, signal=1
413
+ )
414
+ return with_loc(self.node, node), NoneType()
415
+ case args:
416
+ return assert_never(args) # type: ignore[arg-type]
417
+
418
+ def check(self, args: list[ast.expr], ty: Type) -> tuple[ast.expr, Subst]:
419
+ # Panic may return any type, so we don't have to check anything. Consequently
420
+ # we also can't infer anything in the expected type, so we always return an
421
+ # empty substitution
422
+ expr, _ = self.synthesize(args)
423
+ return expr, {}
424
+
425
+
426
+ class ExitChecker(CustomCallChecker):
427
+ """Call checker for the ``exit` functions."""
428
+
429
+ @dataclass(frozen=True)
430
+ class NoMessageError(Error):
431
+ title: ClassVar[str] = "No exit message"
432
+ span_label: ClassVar[str] = "Missing message argument to exit call"
433
+
434
+ @dataclass(frozen=True)
435
+ class Suggestion(Note):
436
+ message: ClassVar[str] = 'Add a message: `exit("message", 0)`'
437
+
438
+ @dataclass(frozen=True)
439
+ class NoSignalError(Error):
440
+ title: ClassVar[str] = "No exit signal"
441
+ span_label: ClassVar[str] = "Missing signal argument to exit call"
442
+
443
+ @dataclass(frozen=True)
444
+ class Suggestion(Note):
445
+ message: ClassVar[str] = 'Add a signal: `exit("message", 0)`'
446
+
447
+ def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, Type]:
448
+ match args:
449
+ case []:
450
+ msg_err = ExitChecker.NoMessageError(self.node)
451
+ msg_err.add_sub_diagnostic(ExitChecker.NoMessageError.Suggestion(None))
452
+ raise GuppyTypeError(msg_err)
453
+ case [_msg]:
454
+ signal_err = ExitChecker.NoSignalError(self.node)
455
+ signal_err.add_sub_diagnostic(
456
+ ExitChecker.NoSignalError.Suggestion(None)
457
+ )
458
+ raise GuppyTypeError(signal_err)
459
+ case [msg, signal, *rest]:
460
+ msg, _ = ExprChecker(self.ctx).check(msg, string_type())
461
+ match msg:
462
+ case ast.Constant(value=str(v)):
463
+ msg_value = v
464
+ case PlaceNode(place=ComptimeVariable(static_value=str(v))):
465
+ msg_value = v
466
+ case _:
467
+ raise GuppyTypeError(ExpectedError(msg, "a string literal"))
468
+ # TODO allow variable signals after https://github.com/CQCL/hugr/issues/1863
469
+ signal, _ = ExprChecker(self.ctx).check(signal, int_type())
470
+ match signal:
471
+ case ast.Constant(value=int(s)):
472
+ signal_value = s
473
+ case PlaceNode(place=ComptimeVariable(static_value=int(s))):
474
+ signal_value = s
475
+ case _:
476
+ raise GuppyTypeError(
477
+ ExpectedError(signal, "an integer literal")
478
+ )
479
+ vals = [ExprSynthesizer(self.ctx).synthesize(val)[0] for val in rest]
480
+ node = PanicExpr(
481
+ kind=ExitKind.ExitShot,
482
+ msg=msg_value,
483
+ values=vals,
484
+ signal=signal_value,
485
+ )
486
+ return with_loc(self.node, node), NoneType()
487
+ case args:
488
+ return assert_never(args) # type: ignore[arg-type]
489
+
490
+ def check(self, args: list[ast.expr], ty: Type) -> tuple[ast.expr, Subst]:
491
+ # Exit may return any type, so we don't have to check anything. Consequently
492
+ # we also can't infer anything in the expected type, so we always return an
493
+ # empty substitution
494
+ expr, _ = self.synthesize(args)
495
+ return expr, {}
496
+
497
+
498
+ class RangeChecker(CustomCallChecker):
499
+ """Call checker for the `range` function."""
500
+
501
+ def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, Type]:
502
+ check_num_args(1, len(args), self.node)
503
+ [stop] = args
504
+ stop_checked, _ = ExprChecker(self.ctx).check(stop, int_type(), "argument")
505
+ range_iter, range_ty = self.make_range(stop_checked)
506
+ # Check if `stop` is a statically known value. Note that we need to do this on
507
+ # the original `stop` instead of `stop_checked` to avoid any previously inserted
508
+ # `int` coercions.
509
+ if (static_stop := self.check_static(stop)) is not None:
510
+ return to_sized_iter(range_iter, range_ty, static_stop, self.ctx)
511
+ return range_iter, range_ty
512
+
513
+ def check_static(self, stop: ast.expr) -> "int | Const | None":
514
+ stop, _ = ExprSynthesizer(self.ctx).synthesize(stop, allow_free_vars=True)
515
+ if isinstance(stop, ast.Constant) and isinstance(stop.value, int):
516
+ return stop.value
517
+ if isinstance(stop, GenericParamValue) and stop.param.ty == nat_type():
518
+ return stop.param.to_bound().const
519
+ return None
520
+
521
+ def range_ty(self) -> StructType:
522
+ from guppylang.std.builtins import Range
523
+ from guppylang_internals.engine import ENGINE
524
+
525
+ def_id = cast(RawStructDef, Range).id
526
+ range_type_def = ENGINE.get_checked(def_id)
527
+ assert isinstance(range_type_def, CheckedStructDef)
528
+ return StructType([], range_type_def)
529
+
530
+ def make_range(self, stop: ast.expr) -> tuple[ast.expr, Type]:
531
+ make_range = self.ctx.globals.get_instance_func(self.range_ty(), "__new__")
532
+ assert make_range is not None
533
+ start = with_type(int_type(), with_loc(self.node, ast.Constant(value=0)))
534
+ return make_range.synthesize_call([start, stop], self.node, self.ctx)
535
+
536
+
537
def to_sized_iter(
    iterator: ast.expr, range_ty: Type, size: "int | Const", ctx: Context
) -> tuple[ast.expr, Type]:
    """Adds a static size annotation to an iterator.

    Wraps `iterator` into a call to the `__new__` function of the sized iterator
    type built from `range_ty` and `size`. Returns the wrapped expression together
    with that sized iterator type.
    """
    target_ty = sized_iter_type(range_ty, size)
    constructor = ctx.globals.get_instance_func(target_ty, "__new__")
    assert constructor is not None
    wrapped, _ = constructor.check_call([iterator], target_ty, iterator, ctx)
    return wrapped, target_ty
546
+
547
+
548
class BarrierChecker(CustomCallChecker):
    """Call checker for the `barrier` function.

    The function accepts any number of arguments of any type.
    """

    def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, Type]:
        """Synthesizes a `BarrierExpr` node for the call.

        The signature is derived from the arguments themselves: every argument
        becomes an input flagged as `InputFlags.Inout` and the output is `None`.
        """
        # First pass: only collect the argument types to build the signature
        arg_tys = []
        for arg in args:
            _, arg_ty = ExprSynthesizer(self.ctx).synthesize(arg)
            arg_tys.append(arg_ty)
        inputs = [FuncInput(arg_ty, InputFlags.Inout) for arg_ty in arg_tys]
        func_ty = FunctionType(inputs, NoneType())
        # Second pass: check the arguments against the derived signature
        checked_args, ret_ty, inst = synthesize_call(
            func_ty, args, self.node, self.ctx
        )
        assert len(inst) == 0, "func_ty is not generic"
        barrier = BarrierExpr(args=checked_args, func_ty=func_ty)
        return with_loc(self.node, barrier), ret_ty
561
+
562
+
563
class WasmCallChecker(CustomCallChecker):
    """Call checker that type checks calls against the function's declared type.

    Both directions delegate to the generic call checking of the expression
    checker and wrap the checked arguments into a `GlobalCall` node.
    """

    def check(self, args: list[ast.expr], ty: Type) -> tuple[ast.expr, Subst]:
        """Checks the call against the expected return type `ty`."""
        # Use default implementation from the expression checker
        checked_args, subst, inst = check_call(
            self.func.ty, args, ty, self.node, self.ctx
        )
        call = GlobalCall(def_id=self.func.id, args=checked_args, type_args=inst)
        return call, subst

    def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, Type]:
        """Synthesizes the return type of the call."""
        # Use default implementation from the expression checker
        checked_args, ret_ty, inst = synthesize_call(
            self.func.ty, args, self.node, self.ctx
        )
        call = GlobalCall(def_id=self.func.id, args=checked_args, type_args=inst)
        return call, ret_ty
File without changes