guppylang-internals 0.24.0__py3-none-any.whl → 0.26.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59) hide show
  1. guppylang_internals/__init__.py +1 -1
  2. guppylang_internals/ast_util.py +21 -0
  3. guppylang_internals/cfg/bb.py +20 -0
  4. guppylang_internals/cfg/builder.py +118 -5
  5. guppylang_internals/cfg/cfg.py +3 -0
  6. guppylang_internals/checker/cfg_checker.py +6 -0
  7. guppylang_internals/checker/core.py +5 -2
  8. guppylang_internals/checker/errors/generic.py +32 -1
  9. guppylang_internals/checker/errors/type_errors.py +14 -0
  10. guppylang_internals/checker/errors/wasm.py +7 -4
  11. guppylang_internals/checker/expr_checker.py +58 -17
  12. guppylang_internals/checker/func_checker.py +18 -14
  13. guppylang_internals/checker/linearity_checker.py +67 -10
  14. guppylang_internals/checker/modifier_checker.py +120 -0
  15. guppylang_internals/checker/stmt_checker.py +48 -1
  16. guppylang_internals/checker/unitary_checker.py +132 -0
  17. guppylang_internals/compiler/cfg_compiler.py +7 -6
  18. guppylang_internals/compiler/core.py +93 -56
  19. guppylang_internals/compiler/expr_compiler.py +72 -168
  20. guppylang_internals/compiler/modifier_compiler.py +176 -0
  21. guppylang_internals/compiler/stmt_compiler.py +15 -8
  22. guppylang_internals/decorator.py +86 -7
  23. guppylang_internals/definition/custom.py +39 -1
  24. guppylang_internals/definition/declaration.py +9 -6
  25. guppylang_internals/definition/function.py +12 -2
  26. guppylang_internals/definition/parameter.py +8 -3
  27. guppylang_internals/definition/pytket_circuits.py +14 -41
  28. guppylang_internals/definition/struct.py +13 -7
  29. guppylang_internals/definition/ty.py +3 -3
  30. guppylang_internals/definition/wasm.py +42 -10
  31. guppylang_internals/engine.py +9 -3
  32. guppylang_internals/experimental.py +5 -0
  33. guppylang_internals/nodes.py +147 -24
  34. guppylang_internals/std/_internal/checker.py +13 -108
  35. guppylang_internals/std/_internal/compiler/array.py +95 -283
  36. guppylang_internals/std/_internal/compiler/list.py +1 -1
  37. guppylang_internals/std/_internal/compiler/platform.py +153 -0
  38. guppylang_internals/std/_internal/compiler/prelude.py +12 -4
  39. guppylang_internals/std/_internal/compiler/tket_exts.py +8 -2
  40. guppylang_internals/std/_internal/debug.py +18 -9
  41. guppylang_internals/std/_internal/util.py +1 -1
  42. guppylang_internals/tracing/object.py +10 -0
  43. guppylang_internals/tracing/unpacking.py +19 -20
  44. guppylang_internals/tys/arg.py +18 -3
  45. guppylang_internals/tys/builtin.py +2 -5
  46. guppylang_internals/tys/const.py +33 -4
  47. guppylang_internals/tys/errors.py +23 -1
  48. guppylang_internals/tys/param.py +31 -16
  49. guppylang_internals/tys/parsing.py +11 -24
  50. guppylang_internals/tys/printing.py +2 -8
  51. guppylang_internals/tys/qubit.py +62 -0
  52. guppylang_internals/tys/subst.py +8 -26
  53. guppylang_internals/tys/ty.py +91 -85
  54. guppylang_internals/wasm_util.py +129 -0
  55. {guppylang_internals-0.24.0.dist-info → guppylang_internals-0.26.0.dist-info}/METADATA +6 -5
  56. guppylang_internals-0.26.0.dist-info/RECORD +104 -0
  57. {guppylang_internals-0.24.0.dist-info → guppylang_internals-0.26.0.dist-info}/WHEEL +1 -1
  58. guppylang_internals-0.24.0.dist-info/RECORD +0 -98
  59. {guppylang_internals-0.24.0.dist-info → guppylang_internals-0.26.0.dist-info}/licenses/LICENCE +0 -0
@@ -1,3 +1,3 @@
1
1
  # This is updated by our release-please workflow, triggered by this
2
2
  # annotation: x-release-please-version
3
- __version__ = "0.24.0"
3
+ __version__ = "0.26.0"
@@ -106,6 +106,14 @@ def return_nodes_in_ast(node: Any) -> list[ast.Return]:
106
106
  return cast(list[ast.Return], found)
107
107
 
108
108
 
109
+ def loop_in_ast(node: Any) -> list[ast.For | ast.While]:
110
+ """Returns all `For` and `While` nodes occurring in an AST."""
111
+ found = find_nodes(
112
+ lambda n: isinstance(n, ast.For | ast.While), node, {ast.FunctionDef}
113
+ )
114
+ return cast(list[ast.For | ast.While], found)
115
+
116
+
109
117
  def breaks_in_loop(node: Any) -> list[ast.Break]:
110
118
  """Returns all `Break` nodes occurring in a loop.
111
119
 
@@ -117,6 +125,19 @@ def breaks_in_loop(node: Any) -> list[ast.Break]:
117
125
  return cast(list[ast.Break], found)
118
126
 
119
127
 
128
+ def loop_controls_in_loop(node: Any) -> list[ast.Break | ast.Continue]:
129
+ """Returns all `Break` and `Continue` nodes occurring in a loop.
130
+
131
+ Note that breaks and continues in nested loops are excluded.
132
+ """
133
+ found = find_nodes(
134
+ lambda n: isinstance(n, ast.Break | ast.Continue),
135
+ node,
136
+ {ast.For, ast.While, ast.FunctionDef},
137
+ )
138
+ return cast(list[ast.Break | ast.Continue], found)
139
+
140
+
120
141
  class ContextAdjuster(ast.NodeTransformer):
121
142
  """Updates the `ast.Context` indicating if expressions occur on the LHS or RHS."""
122
143
 
@@ -13,6 +13,7 @@ from guppylang_internals.nodes import (
13
13
  DesugaredGenerator,
14
14
  DesugaredGeneratorExpr,
15
15
  DesugaredListComp,
16
+ ModifiedBlock,
16
17
  NestedFunctionDef,
17
18
  )
18
19
 
@@ -44,6 +45,7 @@ BBStatement = (
44
45
  | ast.Expr
45
46
  | ast.Return
46
47
  | NestedFunctionDef
48
+ | ModifiedBlock
47
49
  )
48
50
 
49
51
 
@@ -219,3 +221,21 @@ class VariableVisitor(ast.NodeVisitor):
219
221
 
220
222
  # The name of the function is now assigned
221
223
  self.stats.assigned[node.name] = node
224
+
225
+ def visit_ModifiedBlock(self, node: ModifiedBlock) -> None:
226
+ for item in node.control:
227
+ self.visit(item)
228
+ for item in node.power:
229
+ self.visit(item)
230
+
231
+ # Similarly to nested functions
232
+ from guppylang_internals.cfg.analysis import LivenessAnalysis
233
+
234
+ stats = {bb: bb.compute_variable_stats() for bb in node.cfg.bbs}
235
+ live = LivenessAnalysis(stats).run(node.cfg.bbs)
236
+ assigned_before_in_bb = self.stats.assigned.keys()
237
+ self.stats.used |= {
238
+ x: using_bb.vars.used[x]
239
+ for x, using_bb in live[node.cfg.entry_bb].items()
240
+ if x not in assigned_before_in_bb
241
+ }
@@ -9,6 +9,8 @@ from guppylang_internals.ast_util import (
9
9
  AstVisitor,
10
10
  ContextAdjuster,
11
11
  find_nodes,
12
+ loop_controls_in_loop,
13
+ return_nodes_in_ast,
12
14
  set_location_from,
13
15
  template_replace,
14
16
  with_loc,
@@ -16,20 +18,35 @@ from guppylang_internals.ast_util import (
16
18
  from guppylang_internals.cfg.bb import BB, BBStatement
17
19
  from guppylang_internals.cfg.cfg import CFG
18
20
  from guppylang_internals.checker.core import Globals
19
- from guppylang_internals.checker.errors.generic import ExpectedError, UnsupportedError
21
+ from guppylang_internals.checker.errors.generic import (
22
+ ExpectedError,
23
+ UnexpectedInWithBlockError,
24
+ UnknownModifierError,
25
+ UnsupportedError,
26
+ )
27
+ from guppylang_internals.checker.errors.type_errors import WrongNumberOfArgsError
20
28
  from guppylang_internals.diagnostic import Error
21
29
  from guppylang_internals.error import GuppyError, InternalGuppyError
22
- from guppylang_internals.experimental import check_lists_enabled
30
+ from guppylang_internals.experimental import (
31
+ check_lists_enabled,
32
+ check_modifiers_enabled,
33
+ )
23
34
  from guppylang_internals.nodes import (
24
35
  ComptimeExpr,
36
+ Control,
37
+ Dagger,
25
38
  DesugaredGenerator,
26
39
  DesugaredGeneratorExpr,
27
40
  DesugaredListComp,
28
41
  IterNext,
29
42
  MakeIter,
43
+ ModifiedBlock,
44
+ Modifier,
30
45
  NestedFunctionDef,
46
+ Power,
31
47
  )
32
- from guppylang_internals.tys.ty import NoneType
48
+ from guppylang_internals.span import Span, to_span
49
+ from guppylang_internals.tys.ty import NoneType, UnitaryFlags
33
50
 
34
51
  # In order to build expressions, need an endless stream of unique temporary variables
35
52
  # to store intermediate results
@@ -61,7 +78,13 @@ class CFGBuilder(AstVisitor[BB | None]):
61
78
  cfg: CFG
62
79
  globals: Globals
63
80
 
64
- def build(self, nodes: list[ast.stmt], returns_none: bool, globals: Globals) -> CFG:
81
+ def build(
82
+ self,
83
+ nodes: list[ast.stmt],
84
+ returns_none: bool,
85
+ globals: Globals,
86
+ unitary_flags: UnitaryFlags = UnitaryFlags.NoFlags,
87
+ ) -> CFG:
65
88
  """Builds a CFG from a list of ast nodes.
66
89
 
67
90
  We also require the expected number of return ports for the whole CFG. This is
@@ -69,6 +92,7 @@ class CFGBuilder(AstVisitor[BB | None]):
69
92
  variables.
70
93
  """
71
94
  self.cfg = CFG()
95
+ self.cfg.unitary_flags = unitary_flags
72
96
  self.globals = globals
73
97
 
74
98
  final_bb = self.visit_stmts(
@@ -135,7 +159,10 @@ class CFGBuilder(AstVisitor[BB | None]):
135
159
  Builds the expression and mutates `node.value` to point to the built expression.
136
160
  Returns the BB in which the expression is available and adds the node to it.
137
161
  """
138
- if not isinstance(node, NestedFunctionDef) and node.value is not None:
162
+ if (
163
+ not isinstance(node, NestedFunctionDef | ModifiedBlock)
164
+ and node.value is not None
165
+ ):
139
166
  node.value, bb = ExprBuilder.build(node.value, self.cfg, bb)
140
167
  bb.statements.append(node)
141
168
  return bb
@@ -253,6 +280,7 @@ class CFGBuilder(AstVisitor[BB | None]):
253
280
 
254
281
  func_ty = check_signature(node, self.globals)
255
282
  returns_none = isinstance(func_ty.output, NoneType)
283
+ # No UnitaryFlags are assigned to nested functions
256
284
  cfg = CFGBuilder().build(node.body, returns_none, self.globals)
257
285
 
258
286
  new_node = NestedFunctionDef(
@@ -265,6 +293,91 @@ class CFGBuilder(AstVisitor[BB | None]):
265
293
  bb.statements.append(new_node)
266
294
  return bb
267
295
 
296
+ def visit_With(self, node: ast.With, bb: BB, jumps: Jumps) -> BB | None:
297
+ check_modifiers_enabled(node)
298
+ self._validate_modified_block(node)
299
+
300
+ cfg = CFGBuilder().build(node.body, True, self.globals)
301
+ new_node = ModifiedBlock(
302
+ cfg=cfg,
303
+ **dict(ast.iter_fields(node)),
304
+ )
305
+
306
+ for item in node.items:
307
+ item.context_expr, bb = ExprBuilder.build(item.context_expr, self.cfg, bb)
308
+ modifier = self._handle_withitem(item)
309
+ new_node.push_modifier(modifier)
310
+
311
+ # FIXME: Currently, the unitary flags are not set correctly if there are nested
312
+ # `with` blocks. This is because the outer block's unitary flags are not
313
+ # propagated to the inner block. The following line should calculate the sum
314
+ # of the unitary flags of the outer block and modifiers applied in this
315
+ # `with` block.
316
+ cfg.unitary_flags = new_node.flags()
317
+
318
+ set_location_from(new_node, node)
319
+ bb.statements.append(new_node)
320
+ return bb
321
+
322
+ def _handle_withitem(self, node: ast.withitem) -> Modifier:
323
+ # Check that `as` notation is not used
324
+ if node.optional_vars is not None:
325
+ span = Span(
326
+ to_span(node.context_expr).start, to_span(node.optional_vars).end
327
+ )
328
+ raise GuppyError(UnsupportedError(span, "`as` expression", singular=True))
329
+
330
+ e = node.context_expr
331
+ modifier: Modifier
332
+ match e:
333
+ case ast.Name(id="dagger"):
334
+ modifier = Dagger(e)
335
+ case ast.Call(func=ast.Name(id="dagger")):
336
+ if len(e.args) != 0:
337
+ span = Span(to_span(e.args[0]).start, to_span(e.args[-1]).end)
338
+ raise GuppyError(WrongNumberOfArgsError(span, 0, len(e.args)))
339
+ modifier = Dagger(e)
340
+ case ast.Call(func=ast.Name(id="control")):
341
+ if len(e.args) == 0:
342
+ span = Span(to_span(e.func).end, to_span(e).end)
343
+ raise GuppyError(WrongNumberOfArgsError(span, 1, len(e.args)))
344
+ modifier = Control(e, e.args)
345
+ case ast.Call(func=ast.Name(id="power")):
346
+ if len(e.args) == 0:
347
+ span = Span(to_span(e.func).end, to_span(e).end)
348
+ raise GuppyError(WrongNumberOfArgsError(span, 1, len(e.args)))
349
+ elif len(e.args) != 1:
350
+ span = Span(to_span(e.args[1]).start, to_span(e.args[-1]).end)
351
+ raise GuppyError(WrongNumberOfArgsError(span, 1, len(e.args)))
352
+ modifier = Power(e, e.args[0])
353
+ case _:
354
+ raise GuppyError(UnknownModifierError(e))
355
+ return modifier
356
+
357
+ def _validate_modified_block(self, node: ast.With) -> None:
358
+ # Check if the body contains a return statement.
359
+ return_in_body = return_nodes_in_ast(node)
360
+ if len(return_in_body) != 0:
361
+ err = UnexpectedInWithBlockError(return_in_body[0], "return", "Return")
362
+ span = Span(
363
+ to_span(node.items[0].context_expr).start,
364
+ to_span(node.items[-1].context_expr).end,
365
+ )
366
+ err.add_sub_diagnostic(UnexpectedInWithBlockError.Modifier(span))
367
+ raise GuppyError(err)
368
+
369
+ loop_controls_in_body = loop_controls_in_loop(node)
370
+ if len(loop_controls_in_body) != 0:
371
+ lc = loop_controls_in_body[0]
372
+ kind = lc.__class__.__name__
373
+ err = UnexpectedInWithBlockError(lc, "loop control", kind)
374
+ span = Span(
375
+ to_span(node.items[0].context_expr).start,
376
+ to_span(node.items[-1].context_expr).end,
377
+ )
378
+ err.add_sub_diagnostic(UnexpectedInWithBlockError.Modifier(span))
379
+ raise GuppyError(err)
380
+
268
381
  def generic_visit(self, node: ast.AST, bb: BB, jumps: Jumps) -> BB | None:
269
382
  # When adding support for new statements, we have to remember to use the
270
383
  # ExprBuilder to transform all included expressions!
@@ -12,6 +12,7 @@ from guppylang_internals.cfg.analysis import (
12
12
  )
13
13
  from guppylang_internals.cfg.bb import BB, BBStatement, VariableStats
14
14
  from guppylang_internals.nodes import InoutReturnSentinel
15
+ from guppylang_internals.tys.ty import UnitaryFlags
15
16
 
16
17
  T = TypeVar("T", bound=BB)
17
18
 
@@ -29,6 +30,7 @@ class BaseCFG(Generic[T]):
29
30
 
30
31
  #: Set of variables defined in this CFG
31
32
  assigned_somewhere: set[str]
33
+ unitary_flags: UnitaryFlags
32
34
 
33
35
  def __init__(
34
36
  self, bbs: list[T], entry_bb: T | None = None, exit_bb: T | None = None
@@ -42,6 +44,7 @@ class BaseCFG(Generic[T]):
42
44
  self.ass_before = {}
43
45
  self.maybe_ass_before = {}
44
46
  self.assigned_somewhere = set()
47
+ self.unitary_flags = UnitaryFlags.NoFlags
45
48
 
46
49
  def ancestors(self, *bbs: T) -> Iterator[T]:
47
50
  """Returns an iterator over all ancestors of the given BBs in BFS order."""
@@ -149,11 +149,17 @@ def check_cfg(
149
149
  checked_cfg.maybe_ass_before = {
150
150
  compiled[bb]: cfg.maybe_ass_before[bb] for bb in required_bbs
151
151
  }
152
+ checked_cfg.unitary_flags = cfg.unitary_flags
152
153
 
153
154
  # Finally, run the linearity check
154
155
  from guppylang_internals.checker.linearity_checker import check_cfg_linearity
155
156
 
156
157
  linearity_checked_cfg = check_cfg_linearity(checked_cfg, func_name, globals)
158
+
159
+ from guppylang_internals.checker.unitary_checker import check_cfg_unitary
160
+
161
+ check_cfg_unitary(linearity_checked_cfg, cfg.unitary_flags)
162
+
157
163
  return linearity_checked_cfg
158
164
 
159
165
 
@@ -47,7 +47,6 @@ from guppylang_internals.tys.ty import (
47
47
  NumericType,
48
48
  OpaqueType,
49
49
  StructType,
50
- SumType,
51
50
  TupleType,
52
51
  Type,
53
52
  )
@@ -117,6 +116,10 @@ class Variable:
117
116
  """Returns a new `Variable` instance with an updated definition location."""
118
117
  return replace(self, defined_at=node)
119
118
 
119
+ def add_flags(self, flags: InputFlags) -> "Variable":
120
+ """Returns a new `Variable` instance with updated flags."""
121
+ return replace(self, flags=self.flags | flags)
122
+
120
123
 
121
124
  @dataclass(frozen=True, kw_only=True)
122
125
  class ComptimeVariable(Variable):
@@ -356,7 +359,7 @@ class Globals:
356
359
  match ty:
357
360
  case TypeDef() as type_defn:
358
361
  pass
359
- case BoundTypeVar() | ExistentialTypeVar() | SumType():
362
+ case BoundTypeVar() | ExistentialTypeVar():
360
363
  return None
361
364
  case NumericType(kind):
362
365
  match kind:
@@ -1,7 +1,7 @@
1
1
  from dataclasses import dataclass
2
2
  from typing import ClassVar
3
3
 
4
- from guppylang_internals.diagnostic import Error
4
+ from guppylang_internals.diagnostic import Error, Note
5
5
 
6
6
 
7
7
  @dataclass(frozen=True)
@@ -43,3 +43,34 @@ class ExpectedError(Error):
43
43
  @property
44
44
  def extra(self) -> str:
45
45
  return f", got {self.got}" if self.got else ""
46
+
47
+
48
+ @dataclass(frozen=True)
49
+ class UnknownModifierError(Error):
50
+ title: ClassVar[str] = "Unknown modifier"
51
+ span_label: ClassVar[str] = (
52
+ "Expected one of {{dagger, control(...), or power(...)}}"
53
+ )
54
+
55
+
56
+ @dataclass(frozen=True)
57
+ class UnexpectedInWithBlockError(Error):
58
+ title: ClassVar[str] = "Unexpected {kind}"
59
+ span_label: ClassVar[str] = "{things} found in a `With` block"
60
+ kind: str
61
+ things: str
62
+
63
+ @dataclass(frozen=True)
64
+ class Modifier(Note):
65
+ span_label: ClassVar[str] = "modifier is used here"
66
+
67
+
68
+ @dataclass(frozen=True)
69
+ class InvalidUnderDagger(Error):
70
+ title: ClassVar[str] = "Invalid expression in dagger"
71
+ span_label: ClassVar[str] = "{things} found in a dagger context"
72
+ things: str
73
+
74
+ @dataclass(frozen=True)
75
+ class Dagger(Note):
76
+ span_label: ClassVar[str] = "dagger modifier is used here"
@@ -95,6 +95,20 @@ class TypeInferenceError(Error):
95
95
  unsolved_ty: Type
96
96
 
97
97
 
98
+ @dataclass(frozen=True)
99
+ class ParameterInferenceError(Error):
100
+ title: ClassVar[str] = "Cannot infer generic parameter"
101
+ span_label: ClassVar[str] = (
102
+ "Cannot infer generic parameter `{param}` of this function"
103
+ )
104
+ param: str
105
+
106
+ @dataclass(frozen=True)
107
+ class SignatureHint(Note):
108
+ message: ClassVar[str] = "Function signature is `{sig}`"
109
+ sig: FunctionType
110
+
111
+
98
112
  @dataclass(frozen=True)
99
113
  class IllegalConstant(Error):
100
114
  title: ClassVar[str] = "Unsupported constant"
@@ -1,7 +1,7 @@
1
1
  from dataclasses import dataclass
2
2
  from typing import ClassVar
3
3
 
4
- from guppylang_internals.diagnostic import Error
4
+ from guppylang_internals.diagnostic import Error, Note
5
5
  from guppylang_internals.tys.ty import Type
6
6
 
7
7
 
@@ -13,10 +13,13 @@ class WasmError(Error):
13
13
  @dataclass(frozen=True)
14
14
  class FirstArgNotModule(WasmError):
15
15
  span_label: ClassVar[str] = (
16
- "First argument to WASM function should be a reference to a WASM module."
17
- " Found `{ty}` instead"
16
+ "First argument to WASM function should be a WASM module."
18
17
  )
19
- ty: Type
18
+
19
+ @dataclass(frozen=True)
20
+ class GotOtherType(Note):
21
+ span_label: ClassVar[str] = "Found `{ty}` instead."
22
+ ty: Type
20
23
 
21
24
 
22
25
  @dataclass(frozen=True)
@@ -23,6 +23,7 @@ can be used to infer a type for an expression.
23
23
  import ast
24
24
  import sys
25
25
  import traceback
26
+ from collections.abc import Sequence
26
27
  from contextlib import suppress
27
28
  from dataclasses import replace
28
29
  from types import ModuleType
@@ -42,6 +43,7 @@ from guppylang_internals.ast_util import (
42
43
  )
43
44
  from guppylang_internals.cfg.builder import is_tmp_var, tmp_vars
44
45
  from guppylang_internals.checker.core import (
46
+ ComptimeVariable,
45
47
  Context,
46
48
  DummyEvalDict,
47
49
  FieldAccess,
@@ -74,6 +76,7 @@ from guppylang_internals.checker.errors.type_errors import (
74
76
  ModuleMemberNotFoundError,
75
77
  NonLinearInstantiateError,
76
78
  NotCallableError,
79
+ ParameterInferenceError,
77
80
  TupleIndexOutOfBoundsError,
78
81
  TypeApplyNotGenericError,
79
82
  TypeInferenceError,
@@ -130,7 +133,7 @@ from guppylang_internals.tys.builtin import (
130
133
  string_type,
131
134
  )
132
135
  from guppylang_internals.tys.const import Const, ConstValue
133
- from guppylang_internals.tys.param import ConstParam, TypeParam
136
+ from guppylang_internals.tys.param import ConstParam, TypeParam, check_all_args
134
137
  from guppylang_internals.tys.parsing import arg_from_ast
135
138
  from guppylang_internals.tys.subst import Inst, Subst
136
139
  from guppylang_internals.tys.ty import (
@@ -149,6 +152,7 @@ from guppylang_internals.tys.ty import (
149
152
  parse_function_tensor,
150
153
  unify,
151
154
  )
155
+ from guppylang_internals.tys.var import ExistentialVar
152
156
 
153
157
  if TYPE_CHECKING:
154
158
  from guppylang_internals.diagnostic import SubDiagnostic
@@ -462,8 +466,15 @@ class ExprSynthesizer(AstVisitor[tuple[ast.expr, Type]]):
462
466
  # A `value.attr` attribute access. Unfortunately, the `attr` is just a string,
463
467
  # not an AST node, so we have to compute its span by hand. This is fine since
464
468
  # linebreaks are not allowed in the identifier following the `.`
469
+ # The only exceptions are attribute accesses that are generated during
470
+ # desugaring (for example for iterators in `for` loops). Since those just
471
+ # inherit the span of the sugared code, we could have line breaks there.
472
+ # See https://github.com/quantinuum/guppylang/issues/1301
465
473
  span = to_span(node)
466
- attr_span = Span(span.end.shift_left(len(node.attr)), span.end)
474
+ if span.start.line == span.end.line:
475
+ attr_span = Span(span.end.shift_left(len(node.attr)), span.end)
476
+ else:
477
+ attr_span = span
467
478
  if module := self._is_python_module(node.value):
468
479
  if node.attr in module.__dict__:
469
480
  val = module.__dict__[node.attr]
@@ -493,12 +504,7 @@ class ExprSynthesizer(AstVisitor[tuple[ast.expr, Type]]):
493
504
  )
494
505
  # Make a closure by partially applying the `self` argument
495
506
  # TODO: Try to infer some type args based on `self`
496
- result_ty = FunctionType(
497
- func.ty.inputs[1:],
498
- func.ty.output,
499
- func.ty.input_names[1:] if func.ty.input_names else None,
500
- func.ty.params,
501
- )
507
+ result_ty = FunctionType(func.ty.inputs[1:], func.ty.output, func.ty.params)
502
508
  return with_loc(node, PartialApply(func=name, args=[node.value])), result_ty
503
509
  raise GuppyTypeError(AttributeNotFoundError(attr_span, ty, node.attr))
504
510
 
@@ -928,10 +934,9 @@ def check_type_apply(ty: FunctionType, node: ast.Subscript, ctx: Context) -> Ins
928
934
  err.add_sub_diagnostic(WrongNumberOfArgsError.SignatureHint(None, ty))
929
935
  raise GuppyError(err)
930
936
 
931
- return [
932
- param.check_arg(arg_from_ast(arg_expr, ctx.parsing_ctx), arg_expr)
933
- for arg_expr, param in zip(arg_exprs, ty.params, strict=True)
934
- ]
937
+ inst = [arg_from_ast(node, ctx.parsing_ctx) for node in arg_exprs]
938
+ check_all_args(ty.params, inst, "", node, arg_exprs)
939
+ return inst
935
940
 
936
941
 
937
942
  def check_num_args(
@@ -975,15 +980,17 @@ def type_check_args(
975
980
  comptime_args = iter(func_ty.comptime_args)
976
981
  for inp, func_inp in zip(inputs, func_ty.inputs, strict=True):
977
982
  a, s = ExprChecker(ctx).check(inp, func_inp.ty.substitute(subst), "argument")
983
+ subst |= s
978
984
  if InputFlags.Inout in func_inp.flags and isinstance(a, PlaceNode):
979
985
  a.place = check_place_assignable(
980
986
  a.place, ctx, a, "able to borrow subscripted elements"
981
987
  )
982
988
  if InputFlags.Comptime in func_inp.flags:
983
989
  comptime_arg = next(comptime_args)
984
- s = check_comptime_arg(a, comptime_arg.const, func_inp.ty, s)
990
+ const = comptime_arg.const.substitute(subst)
991
+ s = check_comptime_arg(a, const, func_inp.ty.substitute(subst), subst)
992
+ subst |= s
985
993
  new_args.append(a)
986
- subst |= s
987
994
  assert next(comptime_args, None) is None
988
995
 
989
996
  # If the argument check succeeded, this means that we must have found instantiations
@@ -1024,7 +1031,14 @@ def check_place_assignable(
1024
1031
  exp_sig = FunctionType(
1025
1032
  [
1026
1033
  FuncInput(parent.ty, InputFlags.Inout),
1027
- FuncInput(item.ty, InputFlags.NoFlags),
1034
+ FuncInput(
1035
+ # Due to potential coercions that were applied during the
1036
+ # `__getitem__` call (e.g. coercing a nat index to int), we're
1037
+ # not allowed to rely on `item.ty` here.
1038
+ # See https://github.com/CQCL/guppylang/issues/1356
1039
+ ExistentialTypeVar.fresh("T", True, True),
1040
+ InputFlags.NoFlags,
1041
+ ),
1028
1042
  FuncInput(ty, InputFlags.Owned),
1029
1043
  ],
1030
1044
  NoneType(),
@@ -1061,6 +1075,8 @@ def check_comptime_arg(
1061
1075
  match arg:
1062
1076
  case ast.Constant(value=v):
1063
1077
  const = ConstValue(ty, v)
1078
+ case PlaceNode(place=ComptimeVariable(ty=ty, static_value=v)):
1079
+ const = ConstValue(ty, v)
1064
1080
  case GenericParamValue(param=const_param):
1065
1081
  const = const_param.to_bound().const
1066
1082
  case arg:
@@ -1103,7 +1119,7 @@ def synthesize_call(
1103
1119
 
1104
1120
  # Success implies that the substitution is closed
1105
1121
  assert all(not t.unsolved_vars for t in subst.values())
1106
- inst = [subst[v].to_arg() for v in free_vars]
1122
+ inst = check_all_solved(subst, free_vars, func_ty, node)
1107
1123
 
1108
1124
  # Finally, check that the instantiation respects the linearity requirements
1109
1125
  check_inst(func_ty, inst, node)
@@ -1182,7 +1198,7 @@ def check_call(
1182
1198
 
1183
1199
  # Success implies that the substitution is closed
1184
1200
  assert all(not t.unsolved_vars for t in subst.values())
1185
- inst = [subst[v].to_arg() for v in free_vars]
1201
+ inst = check_all_solved(subst, free_vars, func_ty, node)
1186
1202
  subst = {v: t for v, t in subst.items() if v in ty.unsolved_vars}
1187
1203
 
1188
1204
  # Finally, check that the instantiation respects the linearity requirements
@@ -1191,12 +1207,37 @@ def check_call(
1191
1207
  return inputs, subst, inst
1192
1208
 
1193
1209
 
1210
+ def check_all_solved(
1211
+ subst: Subst,
1212
+ free_vars: Sequence[ExistentialVar],
1213
+ func_ty: FunctionType,
1214
+ loc: AstNode,
1215
+ ) -> Inst:
1216
+ """Checks that a substitution solves all parameters of a function.
1217
+
1218
+ Using 3.12 generic syntax, users can declare parameters that don't occur in the
1219
+ signature. Those will remain unsolved, even after unifying all function arguments,
1220
+ so we have to perform this extra check.
1221
+
1222
+ Returns an instantiation of all free variables, or emits a user error if some are
1223
+ not solved.
1224
+ """
1225
+ for v in free_vars:
1226
+ if v not in subst:
1227
+ err = ParameterInferenceError(loc, v.display_name)
1228
+ err.add_sub_diagnostic(ParameterInferenceError.SignatureHint(None, func_ty))
1229
+ raise GuppyTypeInferenceError(err)
1230
+ return [subst[v].to_arg() for v in free_vars]
1231
+
1232
+
1194
1233
  def check_inst(func_ty: FunctionType, inst: Inst, node: AstNode) -> None:
1195
1234
  """Checks if an instantiation is valid.
1196
1235
 
1197
1236
  Makes sure that the linearity requirements are satisfied.
1198
1237
  """
1199
1238
  for param, arg in zip(func_ty.params, inst, strict=True):
1239
+ param = param.instantiate_bounds(inst)
1240
+
1200
1241
  # Give a more informative error message for linearity issues
1201
1242
  if isinstance(param, TypeParam) and isinstance(arg, TypeArg):
1202
1243
  if param.must_be_copyable and not arg.ty.copyable: