guppylang-internals 0.21.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (98)
  1. guppylang_internals/__init__.py +3 -0
  2. guppylang_internals/ast_util.py +350 -0
  3. guppylang_internals/cfg/__init__.py +0 -0
  4. guppylang_internals/cfg/analysis.py +230 -0
  5. guppylang_internals/cfg/bb.py +221 -0
  6. guppylang_internals/cfg/builder.py +606 -0
  7. guppylang_internals/cfg/cfg.py +117 -0
  8. guppylang_internals/checker/__init__.py +0 -0
  9. guppylang_internals/checker/cfg_checker.py +388 -0
  10. guppylang_internals/checker/core.py +550 -0
  11. guppylang_internals/checker/errors/__init__.py +0 -0
  12. guppylang_internals/checker/errors/comptime_errors.py +106 -0
  13. guppylang_internals/checker/errors/generic.py +45 -0
  14. guppylang_internals/checker/errors/linearity.py +300 -0
  15. guppylang_internals/checker/errors/type_errors.py +344 -0
  16. guppylang_internals/checker/errors/wasm.py +34 -0
  17. guppylang_internals/checker/expr_checker.py +1413 -0
  18. guppylang_internals/checker/func_checker.py +269 -0
  19. guppylang_internals/checker/linearity_checker.py +821 -0
  20. guppylang_internals/checker/stmt_checker.py +447 -0
  21. guppylang_internals/compiler/__init__.py +0 -0
  22. guppylang_internals/compiler/cfg_compiler.py +233 -0
  23. guppylang_internals/compiler/core.py +613 -0
  24. guppylang_internals/compiler/expr_compiler.py +989 -0
  25. guppylang_internals/compiler/func_compiler.py +97 -0
  26. guppylang_internals/compiler/hugr_extension.py +224 -0
  27. guppylang_internals/compiler/qtm_platform_extension.py +0 -0
  28. guppylang_internals/compiler/stmt_compiler.py +212 -0
  29. guppylang_internals/decorator.py +246 -0
  30. guppylang_internals/definition/__init__.py +0 -0
  31. guppylang_internals/definition/common.py +214 -0
  32. guppylang_internals/definition/const.py +74 -0
  33. guppylang_internals/definition/custom.py +492 -0
  34. guppylang_internals/definition/declaration.py +171 -0
  35. guppylang_internals/definition/extern.py +89 -0
  36. guppylang_internals/definition/function.py +302 -0
  37. guppylang_internals/definition/overloaded.py +150 -0
  38. guppylang_internals/definition/parameter.py +82 -0
  39. guppylang_internals/definition/pytket_circuits.py +405 -0
  40. guppylang_internals/definition/struct.py +392 -0
  41. guppylang_internals/definition/traced.py +151 -0
  42. guppylang_internals/definition/ty.py +51 -0
  43. guppylang_internals/definition/value.py +115 -0
  44. guppylang_internals/definition/wasm.py +61 -0
  45. guppylang_internals/diagnostic.py +523 -0
  46. guppylang_internals/dummy_decorator.py +76 -0
  47. guppylang_internals/engine.py +295 -0
  48. guppylang_internals/error.py +107 -0
  49. guppylang_internals/experimental.py +92 -0
  50. guppylang_internals/ipython_inspect.py +28 -0
  51. guppylang_internals/nodes.py +427 -0
  52. guppylang_internals/py.typed +0 -0
  53. guppylang_internals/span.py +150 -0
  54. guppylang_internals/std/__init__.py +0 -0
  55. guppylang_internals/std/_internal/__init__.py +0 -0
  56. guppylang_internals/std/_internal/checker.py +573 -0
  57. guppylang_internals/std/_internal/compiler/__init__.py +0 -0
  58. guppylang_internals/std/_internal/compiler/arithmetic.py +136 -0
  59. guppylang_internals/std/_internal/compiler/array.py +569 -0
  60. guppylang_internals/std/_internal/compiler/either.py +131 -0
  61. guppylang_internals/std/_internal/compiler/frozenarray.py +68 -0
  62. guppylang_internals/std/_internal/compiler/futures.py +30 -0
  63. guppylang_internals/std/_internal/compiler/list.py +348 -0
  64. guppylang_internals/std/_internal/compiler/mem.py +13 -0
  65. guppylang_internals/std/_internal/compiler/option.py +78 -0
  66. guppylang_internals/std/_internal/compiler/prelude.py +271 -0
  67. guppylang_internals/std/_internal/compiler/qsystem.py +48 -0
  68. guppylang_internals/std/_internal/compiler/quantum.py +118 -0
  69. guppylang_internals/std/_internal/compiler/tket_bool.py +55 -0
  70. guppylang_internals/std/_internal/compiler/tket_exts.py +59 -0
  71. guppylang_internals/std/_internal/compiler/wasm.py +135 -0
  72. guppylang_internals/std/_internal/compiler.py +0 -0
  73. guppylang_internals/std/_internal/debug.py +95 -0
  74. guppylang_internals/std/_internal/util.py +271 -0
  75. guppylang_internals/tracing/__init__.py +0 -0
  76. guppylang_internals/tracing/builtins_mock.py +62 -0
  77. guppylang_internals/tracing/frozenlist.py +57 -0
  78. guppylang_internals/tracing/function.py +186 -0
  79. guppylang_internals/tracing/object.py +551 -0
  80. guppylang_internals/tracing/state.py +69 -0
  81. guppylang_internals/tracing/unpacking.py +194 -0
  82. guppylang_internals/tracing/util.py +86 -0
  83. guppylang_internals/tys/__init__.py +0 -0
  84. guppylang_internals/tys/arg.py +115 -0
  85. guppylang_internals/tys/builtin.py +382 -0
  86. guppylang_internals/tys/common.py +110 -0
  87. guppylang_internals/tys/const.py +114 -0
  88. guppylang_internals/tys/errors.py +178 -0
  89. guppylang_internals/tys/param.py +251 -0
  90. guppylang_internals/tys/parsing.py +425 -0
  91. guppylang_internals/tys/printing.py +174 -0
  92. guppylang_internals/tys/subst.py +112 -0
  93. guppylang_internals/tys/ty.py +876 -0
  94. guppylang_internals/tys/var.py +49 -0
  95. guppylang_internals-0.21.0.dist-info/METADATA +253 -0
  96. guppylang_internals-0.21.0.dist-info/RECORD +98 -0
  97. guppylang_internals-0.21.0.dist-info/WHEEL +4 -0
  98. guppylang_internals-0.21.0.dist-info/licenses/LICENCE +201 -0
@@ -0,0 +1,551 @@
1
+ import functools
2
+ import itertools
3
+ from collections.abc import Callable, Iterator, Sequence
4
+ from contextlib import suppress
5
+ from dataclasses import dataclass
6
+ from pathlib import Path
7
+ from typing import Any, ClassVar, NamedTuple, TypeAlias
8
+
9
+ from hugr import Wire
10
+
11
+ import guppylang_internals.checker.expr_checker as expr_checker
12
+ from guppylang_internals.checker.errors.generic import UnsupportedError
13
+ from guppylang_internals.checker.errors.type_errors import (
14
+ BinaryOperatorNotDefinedError,
15
+ UnaryOperatorNotDefinedError,
16
+ )
17
+ from guppylang_internals.definition.common import DefId, Definition
18
+ from guppylang_internals.definition.ty import TypeDef
19
+ from guppylang_internals.definition.value import (
20
+ CallableDef,
21
+ CompiledValueDef,
22
+ )
23
+ from guppylang_internals.engine import DEF_STORE, ENGINE
24
+ from guppylang_internals.error import GuppyComptimeError, GuppyError, GuppyTypeError
25
+ from guppylang_internals.ipython_inspect import normalize_ipython_dummy_files
26
+ from guppylang_internals.tracing.state import get_tracing_state, tracing_active
27
+ from guppylang_internals.tracing.util import (
28
+ capture_guppy_errors,
29
+ get_calling_frame,
30
+ hide_trace,
31
+ )
32
+ from guppylang_internals.tys.ty import FunctionType, StructType, Type
33
+
34
# Mapping from unary dunder method (e.g. `__neg__`) to the display name of the
# operation, derived from the type checker's table of unary operators.
unary_table = {
    dunder: op_name for dunder, op_name in expr_checker.unary_table.values()
}

# Mapping from binary dunder method (e.g. `__add__`) to its reflected counterpart
# (e.g. `__radd__`) and the display name of the operation.
binary_table = {
    fwd: (rev, op_name) for fwd, rev, op_name in expr_checker.binary_table.values()
}

# Inverse direction of `binary_table`: maps a reflected dunder method back to the
# original method together with the display name of the operation.
reverse_binary_table = {
    rev: (fwd, op_name) for fwd, rev, op_name in expr_checker.binary_table.values()
}

#: Signature of an undecorated unary dunder method defined on `DunderMixin`
UnaryDunderMethod: TypeAlias = Callable[["DunderMixin"], Any]
#: Signature of an undecorated binary dunder method defined on `DunderMixin`
BinaryDunderMethod: TypeAlias = Callable[["DunderMixin", Any], Any]
52
+
53
+
54
def unary_operation(f: UnaryDunderMethod) -> UnaryDunderMethod:
    """Wrap a unary dunder method such as `__neg__` for use during tracing.

    The operand is first converted into a Guppy object. If the wrapped method then
    raises any `Exception`, the failure is reported to the user as a
    `UnaryOperatorNotDefinedError` for the operand's type.
    """

    @functools.wraps(f)
    @capture_guppy_errors
    def wrapped(self: "DunderMixin") -> Any:
        # Imported locally, presumably to avoid an import cycle with the unpacking
        # module
        from guppylang_internals.tracing.unpacking import guppy_object_from_py

        state = get_tracing_state()
        operand = guppy_object_from_py(self, state.dfg.builder, state.node, state.ctx)

        try:
            return f(operand)
        except Exception:
            # Any failure is interpreted as "operation not defined for this type" and
            # reported below, outside of the handler.
            pass

        op_name = unary_table[f.__name__]
        raise GuppyTypeError(
            UnaryOperatorNotDefinedError(state.node, operand._ty, op_name)
        )

    return wrapped
77
+
78
+
79
def binary_operation(f: BinaryDunderMethod) -> BinaryDunderMethod:
    """Wrap a binary dunder method such as `__add__` for use during tracing.

    Both operands are first converted into Guppy objects. If the wrapped method
    raises, the call falls back to the mirrored method on the other operand (e.g.
    `__add__` on `self` falls back to `other.__radd__(self)`). If that raises as
    well, a `BinaryOperatorNotDefinedError` is reported to the user, listing the
    operand types in left-to-right source order.
    """

    @functools.wraps(f)
    @capture_guppy_errors
    def wrapped(self: "DunderMixin", other: Any) -> Any:
        # Imported locally, presumably to avoid an import cycle with the unpacking
        # module
        from guppylang_internals.tracing.unpacking import guppy_object_from_py

        state = get_tracing_state()
        lhs = guppy_object_from_py(self, state.dfg.builder, state.node, state.ctx)
        rhs = guppy_object_from_py(other, state.dfg.builder, state.node, state.ctx)

        # Attempt 1: the method itself, invoked on `self`
        try:
            return f(lhs, rhs)
        except Exception:
            pass

        # Attempt 2: the mirrored method, invoked on `other`.
        # NB: `f.__name__` is guaranteed to be in one of the tables since this
        # decorator is only put on the matching dunder methods of `DunderMixin`.
        name = f.__name__
        if name in binary_table:
            # `self` is the left operand of the original expression
            mirrored, display_name = binary_table[name]
            left_ty, right_ty = lhs._ty, rhs._ty
        else:
            # `f` is itself a reflected method, so `self` was the right operand
            mirrored, display_name = reverse_binary_table[name]
            left_ty, right_ty = rhs._ty, lhs._ty

        try:
            return rhs.__getattr__(mirrored)(lhs)
        except Exception:
            pass

        raise GuppyTypeError(
            BinaryOperatorNotDefinedError(state.node, left_ty, right_ty, display_name)
        )

    return wrapped
118
+
119
+
120
class DunderMixin:
    """Mixin class to allow `GuppyObject`s and `GuppyDefinition`s to be used in
    arithmetic expressions etc. via providing the corresponding dunder methods
    delegating to the objects impls.

    Each dunder method converts `self` into a Guppy object and forwards the call to
    the method of the same name on that object (see `_get_method`). Methods
    decorated with `unary_operation` / `binary_operation` additionally report
    failures as "operator not defined" type errors; the binary ones first fall back
    to the mirrored method on the other operand. The undecorated methods
    (`__abs__`, `__bool__`, `__ceil__`, `__divmod__`, `__float__`, `__floor__`,
    `__int__`, `__trunc__`) let any failure propagate unchanged.
    """

    def _get_method(self, name: str) -> Any:
        """Return the method `name` of the Guppy object corresponding to `self`.

        `self` is converted with `guppy_object_from_py` using the current tracing
        state, so tracing must be active (`get_tracing_state` raises otherwise).
        """
        # Imported locally, presumably to avoid an import cycle with the unpacking
        # module
        from guppylang_internals.tracing.unpacking import guppy_object_from_py

        state = get_tracing_state()
        obj = guppy_object_from_py(self, state.dfg.builder, state.node, state.ctx)
        return obj.__getattr__(name)

    def __abs__(self) -> Any:
        return self._get_method("__abs__")()

    @binary_operation
    def __add__(self, other: Any) -> Any:
        return self._get_method("__add__")(other)

    @binary_operation
    def __and__(self, other: Any) -> Any:
        return self._get_method("__and__")(other)

    def __bool__(self: Any) -> Any:
        return self._get_method("__bool__")()

    def __ceil__(self: Any) -> Any:
        return self._get_method("__ceil__")()

    def __divmod__(self, other: Any) -> Any:
        return self._get_method("__divmod__")(other)

    @binary_operation
    def __eq__(self, other: object) -> Any:
        return self._get_method("__eq__")(other)

    def __float__(self) -> Any:
        return self._get_method("__float__")()

    def __floor__(self) -> Any:
        return self._get_method("__floor__")()

    @binary_operation
    def __floordiv__(self, other: Any) -> Any:
        return self._get_method("__floordiv__")(other)

    @binary_operation
    def __ge__(self, other: Any) -> Any:
        return self._get_method("__ge__")(other)

    @binary_operation
    def __gt__(self, other: Any) -> Any:
        return self._get_method("__gt__")(other)

    def __int__(self) -> Any:
        return self._get_method("__int__")()

    @unary_operation
    def __invert__(self) -> Any:
        return self._get_method("__invert__")()

    @binary_operation
    def __le__(self, other: Any) -> Any:
        return self._get_method("__le__")(other)

    @binary_operation
    def __lshift__(self, other: Any) -> Any:
        return self._get_method("__lshift__")(other)

    @binary_operation
    def __lt__(self, other: Any) -> Any:
        return self._get_method("__lt__")(other)

    @binary_operation
    def __mod__(self, other: Any) -> Any:
        return self._get_method("__mod__")(other)

    @binary_operation
    def __mul__(self, other: Any) -> Any:
        return self._get_method("__mul__")(other)

    @binary_operation
    def __ne__(self, other: object) -> Any:
        return self._get_method("__ne__")(other)

    @unary_operation
    def __neg__(self) -> Any:
        return self._get_method("__neg__")()

    @binary_operation
    def __or__(self, other: Any) -> Any:
        return self._get_method("__or__")(other)

    @unary_operation
    def __pos__(self) -> Any:
        return self._get_method("__pos__")()

    @binary_operation
    def __pow__(self, other: Any) -> Any:
        return self._get_method("__pow__")(other)

    @binary_operation
    def __radd__(self, other: Any) -> Any:
        return self._get_method("__radd__")(other)

    @binary_operation
    def __rand__(self, other: Any) -> Any:
        return self._get_method("__rand__")(other)

    @binary_operation
    def __rfloordiv__(self, other: Any) -> Any:
        return self._get_method("__rfloordiv__")(other)

    @binary_operation
    def __rlshift__(self, other: Any) -> Any:
        return self._get_method("__rlshift__")(other)

    @binary_operation
    def __rmod__(self, other: Any) -> Any:
        return self._get_method("__rmod__")(other)

    @binary_operation
    def __rmul__(self, other: Any) -> Any:
        return self._get_method("__rmul__")(other)

    @binary_operation
    def __ror__(self, other: Any) -> Any:
        return self._get_method("__ror__")(other)

    @binary_operation
    def __rpow__(self, other: Any) -> Any:
        return self._get_method("__rpow__")(other)

    @binary_operation
    def __rrshift__(self, other: Any) -> Any:
        # Must forward to the reflected right shift (this previously forwarded to
        # `__pow__` by mistake, so `a >> x` silently computed `x ** a`).
        return self._get_method("__rrshift__")(other)

    @binary_operation
    def __rshift__(self, other: Any) -> Any:
        return self._get_method("__rshift__")(other)

    @binary_operation
    def __rsub__(self, other: Any) -> Any:
        return self._get_method("__rsub__")(other)

    @binary_operation
    def __rtruediv__(self, other: Any) -> Any:
        return self._get_method("__rtruediv__")(other)

    @binary_operation
    def __rxor__(self, other: Any) -> Any:
        return self._get_method("__rxor__")(other)

    @binary_operation
    def __sub__(self, other: Any) -> Any:
        return self._get_method("__sub__")(other)

    @binary_operation
    def __truediv__(self, other: Any) -> Any:
        return self._get_method("__truediv__")(other)

    def __trunc__(self) -> Any:
        return self._get_method("__trunc__")()

    @binary_operation
    def __xor__(self, other: Any) -> Any:
        return self._get_method("__xor__")(other)
289
+
290
+
291
class ObjectUse(NamedTuple):
    """Records a use of a non-copyable `GuppyObject`.

    Attributes:
        module: Path of the Python file in which the use occurred.
        lineno: Line number of the use within `module`.
        called_func: The Guppy function that was called if the object was used as
            an argument to one, otherwise `None`.
    """

    module: str
    lineno: int
    called_func: CallableDef | None
303
+
304
+
305
@dataclass(frozen=True)
class GuppyObjectId:
    """Unique id for abstract GuppyObjects allocated during tracing.

    New ids are handed out by `fresh` from a counter shared by the whole class, so
    each call yields a distinct id. Instances are immutable and hashable, which
    allows them to be used as dictionary keys.
    """

    id: int

    # Counter backing `fresh`. Declared as a `ClassVar` so that the dataclass
    # machinery does not turn it into an instance field.
    _fresh_ids: ClassVar[Iterator[int]] = itertools.count()

    @classmethod
    def fresh(cls) -> "GuppyObjectId":
        """Allocate and return a new, previously unused id."""
        next_id = next(cls._fresh_ids)
        return GuppyObjectId(next_id)
316
+
317
+
318
class GuppyObject(DunderMixin):
    """The runtime representation of abstract Guppy objects during tracing.

    They correspond to a single Hugr wire within the current dataflow graph.

    Objects whose type is not droppable are registered in the tracing state's
    `unused_undroppable_objs` until their first use; that registry is used to detect
    linearity violations.
    """

    #: The type of this object
    _ty: Type

    #: The Hugr wire holding this object
    _wire: Wire

    #: Record of the most recent use of this object, or `None` if it is still unused
    _used: ObjectUse | None

    #: Unique id for this object
    _id: GuppyObjectId

    def __init__(self, ty: Type, wire: Wire, used: ObjectUse | None = None) -> None:
        """Create a new abstract object of type `ty` that is held by `wire`.

        Must be called while tracing is active. Undroppable objects that are not
        already marked as used are registered as unused in the tracing state.
        """
        self._ty = ty
        self._wire = wire
        self._used = used
        self._id = GuppyObjectId.fresh()
        state = get_tracing_state()
        if not ty.droppable and not self._used:
            state.unused_undroppable_objs[self._id] = self

    @hide_trace
    def __getattr__(self, key: str) -> Any:  # type: ignore[misc]
        """Resolve `key` as an instance method of this object's type.

        Returns a callable that traces a call to that method with this object as
        the first argument.

        Raises:
            GuppyComptimeError: If the type has no instance method called `key`.
        """
        # Guppy objects don't have fields (structs are treated separately by
        # `GuppyStructObject`), so the only attributes we have to worry about are
        # methods.
        func = get_tracing_state().globals.get_instance_func(self._ty, key)
        if func is None:
            raise GuppyComptimeError(
                f"Expression of type `{self._ty}` has no attribute `{key}`"
            )
        return lambda *xs: TracingDefMixin(func)(self, *xs)

    @hide_trace
    def __bool__(self) -> Any:
        """Always raises: the truth value of a dynamic value is unknown at comptime."""
        err = (
            "Can't branch on a dynamic Guppy value since its concrete value is not "
            "known at comptime. Consider defining a regular Guppy function to perform "
            "dynamic branching."
        )
        raise GuppyComptimeError(err)

    @hide_trace
    @capture_guppy_errors
    def __call__(self, *args: Any) -> Any:
        """Always raises: calling function-typed values is not supported yet."""
        if not isinstance(self._ty, FunctionType):
            err = f"Value of type `{self._ty}` is not callable"
            raise GuppyComptimeError(err)

        # TODO: Support higher-order functions
        state = get_tracing_state()
        raise GuppyError(
            UnsupportedError(state.node, "Higher-order comptime functions")
        )

    @hide_trace
    def __iter__(self) -> Any:
        """Always raises: abstract Guppy objects can't be iterated from Python."""
        # Abstract Guppy objects are not iterable from Python since our iterator
        # protocol doesn't work during tracing.
        raise GuppyComptimeError(
            f"Expression of type `{self._ty}` is not iterable at comptime"
        )

    def _use_wire(self, called_func: CallableDef | None) -> Wire:
        """Mark this object as used and return the Hugr wire holding it.

        The use site is taken from the calling frame and recorded, so that a later
        illegal reuse can point back to it.

        Args:
            called_func: The Guppy function this object is passed to as an
                argument, or `None` if the use is not a function argument.

        Raises:
            GuppyComptimeError: If the object has a non-copyable type and was
                already used before.
        """
        # Panic if a non-copyable value has already been used
        if self._used and not self._ty.copyable:
            use = self._used
            # TODO: Should we print the full path to the file or only the name as is
            # done here? Note that the former will lead to challenges with golden
            # tests
            filename = Path(normalize_ipython_dummy_files(use.module)).name
            err = (
                f"Value with non-copyable type `{self._ty}` was already used\n\n"
                f"Previous use occurred in {filename}:{use.lineno}"
            )
            if use.called_func:
                err += f" as an argument to `{use.called_func.name}`"
            raise GuppyComptimeError(err)

        # Otherwise, mark it as used
        frame = get_calling_frame()
        assert frame is not None
        module_name = frame.f_code.co_filename
        self._used = ObjectUse(module_name, frame.f_lineno, called_func)
        if not self._ty.droppable:
            # The id may legitimately be missing: a copyable-but-undroppable object
            # reaches this point on every use but is only registered once, and an
            # object created with `used` already set is never registered at all.
            # A plain `pop(self._id)` would raise `KeyError` in those cases.
            state = get_tracing_state()
            state.unused_undroppable_objs.pop(self._id, None)
        return self._wire
411
+
412
+
413
class GuppyStructObject(DunderMixin):
    """The runtime representation of Guppy struct objects during tracing.

    Unlike `GuppyObject`, a struct object is not backed by a single Hugr wire.
    Instead it keeps one Python value per struct field, so it can contain several
    wires.

    Field mutation during tracing is deliberately unchecked: users may write any
    value into a field, which makes a struct object behave much like a Python
    dataclass. Struct values therefore have to be validated at function call
    boundaries, which is done in
    `guppylang_internals.tracing.unpacking.guppy_object_from_py`.

    Like dataclasses, struct objects may be `frozen`, i.e. immutable. This preserves
    Python semantics for structs passed as non-borrowed (owned) function arguments:
    a mutation inside the callee could not be observed by the caller, so it is
    rejected instead to avoid confusion.
    """

    #: The type of this struct object
    _ty: StructType

    #: Current value of every field, keyed by field name (any Python object)
    _field_values: dict[str, Any]

    #: Whether this struct object is frozen, i.e. immutable
    _frozen: bool

    def __init__(
        self, ty: StructType, field_values: Sequence[Any], frozen: bool
    ) -> None:
        """Create a struct object of type `ty`.

        `field_values` must hold exactly one value per field of `ty`, in field
        order; a length mismatch raises `ValueError` (via `zip(..., strict=True)`).
        """
        field_names = (f.name for f in ty.fields)
        values_by_name = dict(zip(field_names, field_values, strict=True))
        # Our own `__setattr__` only accepts struct field names, so the internal
        # attributes must be written through `object.__setattr__` directly.
        for attr, value in (
            ("_ty", ty),
            ("_field_values", values_by_name),
            ("_frozen", frozen),
        ):
            object.__setattr__(self, attr, value)

    @hide_trace
    def __getattr__(self, key: str) -> Any:  # type: ignore[misc]
        """Look up `key` as a struct field first, then as an instance method.

        Raises:
            AttributeError: If `key` is neither a field nor a method of the type.
        """
        fields = self._field_values
        if key in fields:
            return fields[key]
        method = get_tracing_state().globals.get_instance_func(self._ty, key)
        if method is not None:
            return lambda *xs: TracingDefMixin(method)(self, *xs)
        err = f"Expression of type `{self._ty}` has no attribute `{key}`"
        raise AttributeError(err)

    @hide_trace
    def __setattr__(self, key: str, value: Any) -> None:
        """Overwrite the struct field `key` with an arbitrary Python value.

        Raises:
            AttributeError: If `key` is not a field of the struct.
            GuppyComptimeError: If the struct object is frozen.
        """
        if key not in self._field_values:
            err = f"Expression of type `{self._ty}` has no attribute `{key}`"
            raise AttributeError(err)
        if self._frozen:
            err = (
                f"Object of type `{self._ty}` is an owned function argument. "
                "Therefore, this mutation won't be visible to the caller."
            )
            raise GuppyComptimeError(err)
        self._field_values[key] = value

    @hide_trace
    def __iter__(self) -> Any:
        """Always raises: struct objects can't be iterated from Python."""
        # Our iterator protocol doesn't work during tracing, so abstract Guppy
        # objects are not iterable from Python.
        raise GuppyComptimeError(f"Expression of type `{self._ty}` is not iterable")
483
+
484
+
485
@dataclass(frozen=True)
class TracingDefMixin(DunderMixin):
    """Mixin to provide tracing semantics for definitions.

    Wraps a `Definition` so that it can be called, subscripted, and converted into
    a `GuppyObject` from within a comptime function that is being traced.
    """

    #: The definition whose tracing behaviour is provided
    wrapped: Definition

    @property
    def id(self) -> DefId:
        """Id of the wrapped definition."""
        return self.wrapped.id

    @staticmethod
    def _lookup_constructor(type_def: TypeDef) -> Definition | None:
        """Return the raw `__new__` definition registered for `type_def`, if any."""
        impls = DEF_STORE.impls
        if type_def.id in impls and "__new__" in impls[type_def.id]:
            return DEF_STORE.raw_defs[impls[type_def.id]["__new__"]]
        return None

    @hide_trace
    def __call__(self, *args: Any) -> Any:
        """Trace a call of the wrapped definition with the given arguments.

        Callable definitions are traced directly; calling a type definition invokes
        its registered `__new__` constructor.

        Raises:
            GuppyComptimeError: If tracing is not active, or if the definition is
                neither callable nor a type with a `__new__` constructor.
        """
        from guppylang_internals.tracing.function import trace_call

        if not tracing_active():
            raise GuppyComptimeError(
                f"{self.wrapped.description.capitalize()} `{self.wrapped.name}` may "
                "only be called in a Guppy context"
            )

        checked = ENGINE.get_checked(self.wrapped.id)
        if isinstance(checked, CallableDef):
            return trace_call(checked, *args)
        if isinstance(checked, TypeDef):
            constructor = self._lookup_constructor(checked)
            if constructor is not None:
                return TracingDefMixin(constructor)(*args)
        err = f"{checked.description.capitalize()} `{checked.name}` is not callable"
        raise GuppyComptimeError(err)

    def __getitem__(self, item: Any) -> Any:
        """Handle subscripts like `Def[...]` on the wrapped definition.

        For type definitions this is a specification of generic arguments and simply
        returns `self`. Everything else is rejected.

        Raises:
            GuppyComptimeError: If the wrapped definition is not a type definition.
        """
        if isinstance(self.wrapped, TypeDef):
            # Types aren't supported as comptime values yet, so the exact return
            # value doesn't matter; just give back the definition.
            return self
        # TODO: Alternatively, it could be a type application on a generic function.
        # Supporting those requires a comptime representation of types as values
        if tracing_active():
            defn = get_tracing_state().globals[self.wrapped.id]
            if isinstance(defn, CallableDef) and defn.ty.parametrized:
                raise GuppyComptimeError(
                    "Explicitly specifying type arguments of generic functions in a "
                    "comptime context is not supported yet"
                )
        raise GuppyComptimeError(
            f"{self.wrapped.description.capitalize()} `{self.wrapped.name}` is not "
            "subscriptable"
        )

    def to_guppy_object(self) -> GuppyObject:
        """Load the wrapped definition as a value into the current dataflow graph.

        Type definitions are represented by their `__new__` constructor.

        Raises:
            GuppyComptimeError: If the definition does not denote a value.
        """
        state = get_tracing_state()
        # The `[]` target requires the second component to be empty (unpacking
        # raises `ValueError` otherwise).
        defn, [] = state.ctx.build_compiled_def(self.id, type_args=[])
        if isinstance(defn, CompiledValueDef):
            wire = defn.load(state.dfg, state.ctx, state.node)
            return GuppyObject(defn.ty, wire, None)
        if isinstance(defn, TypeDef):
            constructor = self._lookup_constructor(defn)
            if constructor is not None:
                return TracingDefMixin(constructor).to_guppy_object()
        err = f"{defn.description.capitalize()} `{defn.name}` is not a value"
        raise GuppyComptimeError(err)
@@ -0,0 +1,69 @@
1
+ from collections.abc import Iterator
2
+ from contextlib import contextmanager
3
+ from contextvars import ContextVar
4
+ from dataclasses import dataclass, field
5
+ from typing import TYPE_CHECKING
6
+
7
+ from guppylang_internals.ast_util import AstNode
8
+ from guppylang_internals.checker.core import Globals
9
+ from guppylang_internals.compiler.core import CompilerContext, DFContainer
10
+ from guppylang_internals.error import InternalGuppyError
11
+
12
+ if TYPE_CHECKING:
13
+ from guppylang_internals.tracing.object import GuppyObject, GuppyObjectId
14
+
15
+
16
@dataclass
class TracingState:
    """Internal state that is used during the tracing phase of comptime functions.

    Attributes:
        ctx: Reference to the global compilation context.
        dfg: The dataflow graph currently under construction.
        node: AST node capturing the code block that is currently being traced.
        unused_undroppable_objs: Every allocated undroppable `GuppyObject` whose
            `used` flag is not set, keyed by its id. Used to detect linearity
            violations.
    """

    ctx: CompilerContext
    dfg: DFContainer
    node: AstNode
    unused_undroppable_objs: "dict[GuppyObjectId, GuppyObject]" = field(
        default_factory=dict
    )

    @property
    def globals(self) -> Globals:
        """The checked globals of the compilation context `ctx`."""
        return self.ctx.checked_globals
38
+
39
+
40
# Tracing state of the current execution context; `None` means that tracing is not
# active. Being a `ContextVar`, its value is local to the current context (e.g. a
# thread or an asyncio task).
_STATE: ContextVar[TracingState | None]
_STATE = ContextVar("_STATE", default=None)
41
+
42
+
43
def reset_state() -> None:
    """Clear the tracing state, so that tracing is considered inactive afterwards."""
    # Storing `None` (rather than resetting a token) unconditionally discards any
    # state installed in the current context.
    _STATE.set(None)
46
+
47
+
48
def tracing_active() -> bool:
    """Return whether a tracing state is currently installed."""
    current = _STATE.get()
    return current is not None
51
+
52
+
53
def get_tracing_state() -> TracingState:
    """Return the currently installed tracing state.

    Raises:
        InternalGuppyError: If tracing mode is not active.
    """
    if (state := _STATE.get()) is None:
        raise InternalGuppyError("Guppy tracing mode is not active")
    return state
62
+
63
+
64
@contextmanager
def set_tracing_state(state: TracingState) -> Iterator[None]:
    """Context manager to update tracing state for the duration of a code block.

    The previously installed state (possibly none) is restored when the block
    exits, including when it exits by raising an exception.
    """
    token = _STATE.set(state)
    try:
        yield
    finally:
        # `contextmanager` re-raises exceptions from the `with` body at the `yield`.
        # Without `finally`, a failed trace would skip the reset and leave a stale
        # state installed, so `tracing_active()` would keep returning True.
        _STATE.reset(token)