guppylang-internals 0.26.0__py3-none-any.whl → 0.28.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47)
  1. guppylang_internals/__init__.py +1 -1
  2. guppylang_internals/ast_util.py +37 -18
  3. guppylang_internals/cfg/analysis.py +6 -6
  4. guppylang_internals/cfg/builder.py +44 -12
  5. guppylang_internals/cfg/cfg.py +1 -1
  6. guppylang_internals/checker/core.py +1 -1
  7. guppylang_internals/checker/errors/comptime_errors.py +0 -12
  8. guppylang_internals/checker/errors/linearity.py +6 -2
  9. guppylang_internals/checker/expr_checker.py +53 -28
  10. guppylang_internals/checker/func_checker.py +4 -3
  11. guppylang_internals/checker/stmt_checker.py +1 -1
  12. guppylang_internals/compiler/cfg_compiler.py +1 -1
  13. guppylang_internals/compiler/core.py +17 -4
  14. guppylang_internals/compiler/expr_compiler.py +36 -14
  15. guppylang_internals/compiler/modifier_compiler.py +5 -2
  16. guppylang_internals/decorator.py +5 -3
  17. guppylang_internals/definition/common.py +1 -0
  18. guppylang_internals/definition/custom.py +2 -2
  19. guppylang_internals/definition/declaration.py +3 -3
  20. guppylang_internals/definition/function.py +28 -8
  21. guppylang_internals/definition/metadata.py +87 -0
  22. guppylang_internals/definition/overloaded.py +11 -2
  23. guppylang_internals/definition/pytket_circuits.py +50 -67
  24. guppylang_internals/definition/value.py +1 -1
  25. guppylang_internals/definition/wasm.py +3 -3
  26. guppylang_internals/diagnostic.py +89 -16
  27. guppylang_internals/engine.py +84 -40
  28. guppylang_internals/error.py +1 -1
  29. guppylang_internals/nodes.py +301 -3
  30. guppylang_internals/span.py +7 -3
  31. guppylang_internals/std/_internal/checker.py +104 -2
  32. guppylang_internals/std/_internal/compiler/array.py +36 -1
  33. guppylang_internals/std/_internal/compiler/either.py +14 -2
  34. guppylang_internals/std/_internal/compiler/tket_bool.py +1 -6
  35. guppylang_internals/std/_internal/compiler/tket_exts.py +1 -1
  36. guppylang_internals/std/_internal/debug.py +5 -3
  37. guppylang_internals/tracing/builtins_mock.py +2 -2
  38. guppylang_internals/tracing/object.py +6 -2
  39. guppylang_internals/tys/parsing.py +4 -1
  40. guppylang_internals/tys/qubit.py +6 -4
  41. guppylang_internals/tys/subst.py +2 -2
  42. guppylang_internals/tys/ty.py +2 -2
  43. guppylang_internals/wasm_util.py +2 -3
  44. {guppylang_internals-0.26.0.dist-info → guppylang_internals-0.28.0.dist-info}/METADATA +5 -4
  45. {guppylang_internals-0.26.0.dist-info → guppylang_internals-0.28.0.dist-info}/RECORD +47 -46
  46. {guppylang_internals-0.26.0.dist-info → guppylang_internals-0.28.0.dist-info}/WHEEL +0 -0
  47. {guppylang_internals-0.26.0.dist-info → guppylang_internals-0.28.0.dist-info}/licenses/LICENCE +0 -0
@@ -208,7 +208,8 @@ class DiagnosticsRenderer:
208
208
  MAX_MESSAGE_LINE_LEN: Final[int] = 80
209
209
 
210
210
  #: Number of preceding source lines we show to give additional context
211
- PREFIX_CONTEXT_LINES: Final[int] = 2
211
+ PREFIX_ERROR_CONTEXT_LINES: Final[int] = 2
212
+ PREFIX_NOTE_CONTEXT_LINES: Final[int] = 1
212
213
 
213
214
  def __init__(self, source: SourceMap) -> None:
214
215
  self.buffer = []
@@ -243,31 +244,84 @@ class DiagnosticsRenderer:
243
244
  else:
244
245
  span = to_span(diag.span)
245
246
  level = self.level_str(diag.level)
246
- all_spans = [span] + [
247
- to_span(child.span) for child in diag.children if child.span
247
+
248
+ children_with_span = [
249
+ (child, to_span(child.span)) for child in diag.children if child.span
248
250
  ]
251
+ all_spans = [span] + [span for _, span in children_with_span]
249
252
  max_lineno = max(s.end.line for s in all_spans)
253
+
250
254
  self.buffer.append(f"{level}: {diag.rendered_title} (at {span.start})")
255
+
256
+ # Render main error span first
251
257
  self.render_snippet(
252
258
  span,
253
259
  diag.rendered_span_label,
254
260
  max_lineno,
255
261
  is_primary=True,
256
- prefix_lines=self.PREFIX_CONTEXT_LINES,
262
+ prefix_lines=self.PREFIX_ERROR_CONTEXT_LINES,
257
263
  )
258
- # First render all sub-diagnostics that come with a span
259
- for sub_diag in diag.children:
260
- if sub_diag.span:
264
+
265
+ match children_with_span:
266
+ case []:
267
+ pass
268
+ case [(only_child, span)]:
269
+ self.buffer.append("\nNote:")
270
+ self.render_snippet(
271
+ span,
272
+ only_child.rendered_span_label,
273
+ max_lineno,
274
+ prefix_lines=self.PREFIX_NOTE_CONTEXT_LINES,
275
+ print_pad_line=True,
276
+ )
277
+ case [(first_child, first_span), *children_with_span]:
278
+ self.buffer.append("\nNotes:")
261
279
  self.render_snippet(
262
- to_span(sub_diag.span),
263
- sub_diag.rendered_span_label,
280
+ first_span,
281
+ first_child.rendered_span_label,
264
282
  max_lineno,
265
- is_primary=False,
283
+ prefix_lines=self.PREFIX_NOTE_CONTEXT_LINES,
284
+ print_pad_line=True,
266
285
  )
286
+
287
+ prev_span_end_lineno = first_span.end.line
288
+
289
+ for sub_diag, span in children_with_span:
290
+ span_start_lineno = span.start.line
291
+ span_end_lineno = span.end.line
292
+
293
+ # If notes are on the same line, render them together
294
+ if span_start_lineno == prev_span_end_lineno:
295
+ prefix_lines = 0
296
+ print_pad_line = True
297
+ # if notes are close enough, render them adjacently
298
+ elif (
299
+ span_start_lineno - self.PREFIX_NOTE_CONTEXT_LINES
300
+ <= prev_span_end_lineno + 1
301
+ ):
302
+ prefix_lines = span_start_lineno - prev_span_end_lineno - 1
303
+ print_pad_line = False
304
+ # otherwise we render a separator between notes
305
+ else:
306
+ self.buffer.append("")
307
+ prefix_lines = self.PREFIX_NOTE_CONTEXT_LINES
308
+ print_pad_line = False
309
+
310
+ self.render_snippet(
311
+ span,
312
+ sub_diag.rendered_span_label,
313
+ max_lineno,
314
+ prefix_lines=prefix_lines,
315
+ print_pad_line=print_pad_line,
316
+ )
317
+ prev_span_end_lineno = span_end_lineno
318
+
319
+ # Render the main diagnostic message if present
267
320
  if diag.rendered_message:
268
321
  self.buffer.append("")
269
322
  self.buffer += wrap(diag.rendered_message, self.MAX_MESSAGE_LINE_LEN)
270
- # Finally, render all sub-diagnostics that have a non-span message
323
+
324
+ # Render all sub-diagnostics that have a non-span message
271
325
  for sub_diag in diag.children:
272
326
  if sub_diag.rendered_message:
273
327
  self.buffer.append("")
@@ -281,8 +335,9 @@ class DiagnosticsRenderer:
281
335
  span: Span,
282
336
  label: str | None,
283
337
  max_lineno: int,
284
- is_primary: bool,
338
+ is_primary: bool = False,
285
339
  prefix_lines: int = 0,
340
+ print_pad_line: bool = False,
286
341
  ) -> None:
287
342
  """Renders the source associated with a span together with an optional label.
288
343
 
@@ -315,7 +370,8 @@ class DiagnosticsRenderer:
315
370
  Optionally includes up to `prefix_lines` preceding source lines to give
316
371
  additional context.
317
372
  """
318
- # Check how much space we need to reserve for the leading line numbers
373
+ # Check how much horizontal space we need to reserve for the leading
374
+ # line numbers
319
375
  ll_length = len(str(max_lineno))
320
376
  highlight_char = "^" if is_primary else "-"
321
377
 
@@ -324,12 +380,29 @@ class DiagnosticsRenderer:
324
380
  ll = "" if line_number is None else str(line_number)
325
381
  self.buffer.append(" " * (ll_length - len(ll)) + ll + " | " + line)
326
382
 
327
- # One line of padding
328
- render_line("")
383
+ # One line of padding (primary span, first note or between same line notes)
384
+ if is_primary or print_pad_line:
385
+ render_line("")
329
386
 
330
- # Grab all lines we want to display and remove excessive leading whitespace
387
+ # Grab all lines we want to display
331
388
  prefix_lines = min(prefix_lines, span.start.line - 1)
332
389
  all_lines = self.source.span_lines(span, prefix_lines)
390
+
391
+ # Convert leading tab characters into four whitespaces each (see PEP 8)
392
+ for i, line in enumerate(all_lines):
393
+ line_no_tabs = line.lstrip("\t")
394
+ num_tabs = len(line) - len(line_no_tabs)
395
+ all_lines[i] = " " * (num_tabs * 4) + line_no_tabs
396
+ # Shift span locations, accounting for incorporated \t
397
+ new_start = span.start
398
+ new_end = span.end
399
+ if i == prefix_lines: # Line is the first line in the span
400
+ new_start = span.start.shift_right(num_tabs * 3)
401
+ if i == len(all_lines) - 1: # Line is the last line in the span
402
+ new_end = span.end.shift_right(num_tabs * 3)
403
+ span = Span(new_start or span.start, new_end)
404
+
405
+ # Remove excessive leading whitespace
333
406
  leading_whitespace = min(len(line) - len(line.lstrip()) for line in all_lines)
334
407
  if leading_whitespace > self.MAX_LEADING_WHITESPACE:
335
408
  remove = leading_whitespace - self.OPTIMAL_LEADING_WHITESPACE
@@ -3,14 +3,10 @@ from enum import Enum
3
3
  from types import FrameType
4
4
  from typing import TYPE_CHECKING
5
5
 
6
+ import hugr
6
7
  import hugr.build.function as hf
7
- import hugr.std.collections.array
8
- import hugr.std.float
9
- import hugr.std.int
10
- import hugr.std.logic
11
- import hugr.std.prelude
12
8
  from hugr import ops
13
- from hugr.ext import Extension
9
+ from hugr.ext import Extension, ExtensionRegistry
14
10
  from hugr.package import ModulePointer, Package
15
11
 
16
12
  import guppylang_internals
@@ -150,11 +146,52 @@ class CompilationEngine:
150
146
  types_to_check_worklist: dict[DefId, ParsedDef]
151
147
  to_check_worklist: dict[DefId, ParsedDef]
152
148
 
149
+ # Cached compilation infrastructure (lazy-initialized, program-independent)
150
+ _base_packaged_extensions: list[Extension] | None = None
151
+ _base_resolve_registry: ExtensionRegistry | None = None
152
+
153
153
  def __init__(self) -> None:
154
154
  """Resets the compilation cache."""
155
155
  self.reset()
156
156
  self.additional_extensions = []
157
157
 
158
+ @staticmethod
159
+ def _get_base_packaged_extensions() -> list[Extension]:
160
+ """Get the base list of packaged extensions (cached at class level)."""
161
+ if CompilationEngine._base_packaged_extensions is None:
162
+ from guppylang_internals.std._internal.compiler.tket_exts import (
163
+ TKET_EXTENSIONS,
164
+ )
165
+
166
+ CompilationEngine._base_packaged_extensions = [
167
+ *TKET_EXTENSIONS,
168
+ guppylang_internals.compiler.hugr_extension.EXTENSION, # type: ignore[attr-defined]
169
+ ]
170
+ return CompilationEngine._base_packaged_extensions
171
+
172
+ @staticmethod
173
+ def _get_base_resolve_registry() -> ExtensionRegistry:
174
+ """Get the base resolve registry with standard extensions.
175
+
176
+ Cached at class level.
177
+ """
178
+ if CompilationEngine._base_resolve_registry is None:
179
+ base_extensions = CompilationEngine._get_base_packaged_extensions()
180
+ registry = ExtensionRegistry()
181
+ for ext in [
182
+ *base_extensions,
183
+ hugr.std.prelude.PRELUDE_EXTENSION,
184
+ hugr.std.collections.array.EXTENSION,
185
+ hugr.std.float.FLOAT_OPS_EXTENSION,
186
+ hugr.std.float.FLOAT_TYPES_EXTENSION,
187
+ hugr.std.int.INT_OPS_EXTENSION,
188
+ hugr.std.int.INT_TYPES_EXTENSION,
189
+ hugr.std.logic.EXTENSION,
190
+ ]:
191
+ registry.register_updated(ext)
192
+ CompilationEngine._base_resolve_registry = registry
193
+ return CompilationEngine._base_resolve_registry
194
+
158
195
  def reset(self) -> None:
159
196
  """Resets the compilation cache."""
160
197
  self.parsed = {}
@@ -220,21 +257,12 @@ class CompilationEngine:
220
257
 
221
258
  This is the main driver behind `guppy.check()`.
222
259
  """
223
- from guppylang_internals.checker.core import Globals
224
-
225
260
  # Clear previous compilation cache.
226
261
  # TODO: In order to maintain results from the previous `check` call we would
227
262
  # need to store and check if any dependencies have changed.
228
263
  self.reset()
229
264
 
230
- defn = DEF_STORE.raw_defs[id]
231
- self.to_check_worklist = {
232
- defn.id: (
233
- defn.parse(Globals(DEF_STORE.frames[defn.id]), DEF_STORE.sources)
234
- if isinstance(defn, ParsableDef)
235
- else defn
236
- )
237
- }
265
+ self.to_check_worklist[id] = self.get_parsed(id)
238
266
  while self.types_to_check_worklist or self.to_check_worklist:
239
267
  # Types need to be checked first. This is because parsing e.g. a function
240
268
  # definition requires instantiating the types in its signature which can
@@ -273,40 +301,56 @@ class CompilationEngine:
273
301
  # loosened after https://github.com/quantinuum/hugr/issues/2501 is fixed
274
302
  graph.hugr.entrypoint = compiled_def.hugr_node
275
303
 
276
- # TODO: Currently the list of extensions is manually managed by the user.
277
- # We should compute this dynamically from the imported dependencies instead.
278
- #
279
- # The hugr prelude and std_extensions are implicit.
280
- from guppylang_internals.std._internal.compiler.tket_exts import TKET_EXTENSIONS
304
+ # Use cached base extensions and registry, only add additional extensions
305
+ base_extensions = self._get_base_packaged_extensions()
306
+ packaged_extensions = [*base_extensions, *self.additional_extensions]
281
307
 
282
- extensions = [
283
- *TKET_EXTENSIONS,
284
- guppylang_internals.compiler.hugr_extension.EXTENSION,
285
- *self.additional_extensions,
286
- ]
287
- # TODO replace with computed extensions after https://github.com/quantinuum/guppylang/issues/550
288
- all_used_extensions = [
289
- *extensions,
290
- hugr.std.prelude.PRELUDE_EXTENSION,
291
- hugr.std.collections.array.EXTENSION,
292
- hugr.std.float.FLOAT_OPS_EXTENSION,
293
- hugr.std.float.FLOAT_TYPES_EXTENSION,
294
- hugr.std.int.INT_OPS_EXTENSION,
295
- hugr.std.int.INT_TYPES_EXTENSION,
296
- hugr.std.logic.EXTENSION,
297
- ]
298
- graph.hugr.module_root.metadata[CoreMetadataKeys.USED_EXTENSIONS.value] = [
308
+ # Build resolve registry: start with cached base, add any additional
309
+ if self.additional_extensions:
310
+ from copy import deepcopy
311
+
312
+ resolve_registry = deepcopy(self._get_base_resolve_registry())
313
+ for ext in self.additional_extensions:
314
+ resolve_registry.register_updated(ext)
315
+ else:
316
+ resolve_registry = self._get_base_resolve_registry()
317
+
318
+ # Compute used extensions dynamically from the HUGR.
319
+ used_extensions_result = graph.hugr.used_extensions(
320
+ resolve_from=resolve_registry
321
+ )
322
+
323
+ # Set metadata for used extensions
324
+ used_exts_meta = [
299
325
  {
300
326
  "name": ext.name,
301
327
  "version": str(ext.version),
302
328
  }
303
- for ext in all_used_extensions
329
+ for ext in used_extensions_result.used_extensions.extensions.values()
304
330
  ]
331
+ # Add unresolved extensions as well, but we only have the names
332
+ used_exts_meta.extend(
333
+ {
334
+ "name": ext,
335
+ }
336
+ for ext in used_extensions_result.unresolved_extensions
337
+ )
338
+ graph.hugr.module_root.metadata[CoreMetadataKeys.USED_EXTENSIONS.value] = (
339
+ used_exts_meta
340
+ )
305
341
  graph.hugr.module_root.metadata[CoreMetadataKeys.GENERATOR.value] = {
306
342
  "name": "guppylang",
307
343
  "version": guppylang_internals.__version__,
308
344
  }
309
- return ModulePointer(Package(modules=[graph.hugr], extensions=extensions), 0)
345
+ # only package used extensions
346
+ packaged_extensions = [
347
+ ext
348
+ for ext in packaged_extensions
349
+ if ext.name in used_extensions_result.ids()
350
+ ]
351
+ return ModulePointer(
352
+ Package(modules=[graph.hugr], extensions=packaged_extensions), 0
353
+ )
310
354
 
311
355
 
312
356
  ENGINE: CompilationEngine = CompilationEngine()
@@ -104,4 +104,4 @@ def pretty_errors(f: FuncT) -> FuncT:
104
104
  with exception_hook(hook):
105
105
  return f(*args, **kwargs)
106
106
 
107
- return cast(FuncT, pretty_errors_wrapped)
107
+ return cast("FuncT", pretty_errors_wrapped)