guppylang-internals 0.22.0__py3-none-any.whl → 0.24.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (32)
  1. guppylang_internals/__init__.py +1 -1
  2. guppylang_internals/cfg/cfg.py +8 -0
  3. guppylang_internals/checker/cfg_checker.py +26 -65
  4. guppylang_internals/checker/core.py +8 -0
  5. guppylang_internals/checker/expr_checker.py +11 -25
  6. guppylang_internals/checker/func_checker.py +170 -21
  7. guppylang_internals/checker/stmt_checker.py +1 -1
  8. guppylang_internals/decorator.py +124 -58
  9. guppylang_internals/definition/const.py +2 -2
  10. guppylang_internals/definition/custom.py +1 -1
  11. guppylang_internals/definition/declaration.py +1 -1
  12. guppylang_internals/definition/extern.py +2 -2
  13. guppylang_internals/definition/function.py +1 -1
  14. guppylang_internals/definition/parameter.py +2 -2
  15. guppylang_internals/definition/pytket_circuits.py +1 -1
  16. guppylang_internals/definition/struct.py +10 -10
  17. guppylang_internals/definition/traced.py +1 -1
  18. guppylang_internals/definition/ty.py +6 -0
  19. guppylang_internals/definition/wasm.py +2 -2
  20. guppylang_internals/engine.py +13 -2
  21. guppylang_internals/nodes.py +0 -23
  22. guppylang_internals/std/_internal/compiler/tket_exts.py +3 -6
  23. guppylang_internals/std/_internal/compiler/wasm.py +37 -26
  24. guppylang_internals/tracing/function.py +13 -2
  25. guppylang_internals/tracing/unpacking.py +18 -12
  26. guppylang_internals/tys/builtin.py +30 -11
  27. guppylang_internals/tys/errors.py +6 -0
  28. guppylang_internals/tys/parsing.py +111 -125
  29. {guppylang_internals-0.22.0.dist-info → guppylang_internals-0.24.0.dist-info}/METADATA +5 -5
  30. {guppylang_internals-0.22.0.dist-info → guppylang_internals-0.24.0.dist-info}/RECORD +32 -32
  31. {guppylang_internals-0.22.0.dist-info → guppylang_internals-0.24.0.dist-info}/WHEEL +0 -0
  32. {guppylang_internals-0.22.0.dist-info → guppylang_internals-0.24.0.dist-info}/licenses/LICENCE +0 -0
@@ -1,3 +1,3 @@
1
1
  # This is updated by our release-please workflow, triggered by this
2
2
  # annotation: x-release-please-version
3
- __version__ = "0.22.0"
3
+ __version__ = "0.24.0"
@@ -27,6 +27,9 @@ class BaseCFG(Generic[T]):
27
27
  ass_before: Result[DefAssignmentDomain[str]]
28
28
  maybe_ass_before: Result[MaybeAssignmentDomain[str]]
29
29
 
30
+ #: Set of variables defined in this CFG
31
+ assigned_somewhere: set[str]
32
+
30
33
  def __init__(
31
34
  self, bbs: list[T], entry_bb: T | None = None, exit_bb: T | None = None
32
35
  ):
@@ -38,6 +41,7 @@ class BaseCFG(Generic[T]):
38
41
  self.live_before = {}
39
42
  self.ass_before = {}
40
43
  self.maybe_ass_before = {}
44
+ self.assigned_somewhere = set()
41
45
 
42
46
  def ancestors(self, *bbs: T) -> Iterator[T]:
43
47
  """Returns an iterator over all ancestors of the given BBs in BFS order."""
@@ -101,6 +105,10 @@ class CFG(BaseCFG[BB]):
101
105
  inout_vars: list[str],
102
106
  ) -> dict[BB, VariableStats[str]]:
103
107
  stats = {bb: bb.compute_variable_stats() for bb in self.bbs}
108
+ # Locals are variables that are assigned somewhere inside the function
109
+ self.assigned_somewhere = def_ass_before.union(
110
+ maybe_ass_before, (x for bb in self.bbs for x in stats[bb].assigned)
111
+ )
104
112
  # Mark all borrowed variables as implicitly used in the exit BB
105
113
  stats[self.exit_bb].used |= {x: InoutReturnSentinel(var=x) for x in inout_vars}
106
114
  # This also means borrowed variables are always live, so we can use them as the
@@ -8,7 +8,7 @@ import ast
8
8
  import collections
9
9
  from collections.abc import Iterator, Sequence
10
10
  from dataclasses import dataclass, field
11
- from typing import ClassVar, Generic, TypeVar, cast
11
+ from typing import ClassVar, Generic, TypeVar
12
12
 
13
13
  from guppylang_internals.ast_util import line_col
14
14
  from guppylang_internals.cfg.bb import BB
@@ -23,7 +23,6 @@ from guppylang_internals.checker.core import (
23
23
  )
24
24
  from guppylang_internals.checker.expr_checker import ExprSynthesizer, to_bool
25
25
  from guppylang_internals.checker.stmt_checker import StmtChecker
26
- from guppylang_internals.definition.value import ValueDef
27
26
  from guppylang_internals.diagnostic import Error, Note
28
27
  from guppylang_internals.error import GuppyError
29
28
  from guppylang_internals.tys.param import Parameter
@@ -115,7 +114,7 @@ def check_cfg(
115
114
  if bb in compiled:
116
115
  # If the BB was already compiled, we just have to check that the signatures
117
116
  # match.
118
- check_rows_match(input_row, compiled[bb].sig.input_row, bb, globals)
117
+ check_rows_match(input_row, compiled[bb].sig.input_row, bb)
119
118
  else:
120
119
  # Otherwise, check the BB and enqueue its successors
121
120
  checked_bb = check_bb(
@@ -195,21 +194,6 @@ class BranchTypeError(Error):
195
194
  span_label: ClassVar[str] = "This is of type `{ty}`"
196
195
  ty: Type
197
196
 
198
- @dataclass(frozen=True)
199
- class GlobalHint(Note):
200
- message: ClassVar[str] = (
201
- "{ident} may be shadowing a global {defn.description} definition of type "
202
- "`{defn.ty}` on some branches"
203
- )
204
- defn: ValueDef
205
-
206
-
207
- @dataclass(frozen=True)
208
- class GlobalShadowError(Error):
209
- title: ClassVar[str] = "Global variable conditionally shadowed"
210
- span_label: ClassVar[str] = "{ident} may be shadowing a global variable"
211
- ident: str
212
-
213
197
 
214
198
  def check_bb(
215
199
  bb: BB,
@@ -245,23 +229,27 @@ def check_bb(
245
229
 
246
230
  for succ in bb.successors + bb.dummy_successors:
247
231
  for x, use_bb in cfg.live_before[succ].items():
248
- # Check that the variables requested by the successor are defined
249
- if (
250
- x not in ctx.locals
251
- and x not in ctx.globals
252
- and x not in ctx.generic_params
253
- ):
254
- # If the variable is defined on *some* paths, we can give a more
255
- # informative error message
256
- if x in cfg.maybe_ass_before[use_bb]:
257
- err = VarMaybeNotDefinedError(use_bb.vars.used[x], x)
258
- if bad_branch := diagnose_maybe_undefined(use_bb, x, cfg):
259
- branch_expr, truth_value = bad_branch
260
- note = VarMaybeNotDefinedError.BadBranch(
261
- branch_expr, x, truth_value
262
- )
263
- err.add_sub_diagnostic(note)
232
+ # Check that the variables requested by the successor are defined. If `x` is
233
+ # a local variable, then we must be able to find it in the context.
234
+ # Following Python, locals are exactly those variables that are defined
235
+ # somewhere in the function body.
236
+ if x in cfg.assigned_somewhere:
237
+ if x not in ctx.locals:
238
+ # If the variable is defined on *some* paths, we can give a more
239
+ # informative error message
240
+ if x in cfg.maybe_ass_before[use_bb]:
241
+ err: Error = VarMaybeNotDefinedError(use_bb.vars.used[x], x)
242
+ if bad_branch := diagnose_maybe_undefined(use_bb, x, cfg):
243
+ branch_expr, truth_value = bad_branch
244
+ note = VarMaybeNotDefinedError.BadBranch(
245
+ branch_expr, x, truth_value
246
+ )
247
+ err.add_sub_diagnostic(note)
248
+ else:
249
+ err = VarNotDefinedError(use_bb.vars.used[x], x)
264
250
  raise GuppyError(err)
251
+ # If x is not a local, then it must be a global or generic param
252
+ elif x not in ctx.globals and x not in ctx.generic_params:
265
253
  raise GuppyError(VarNotDefinedError(use_bb.vars.used[x], x))
266
254
 
267
255
  # Finally, we need to compute the signature of the basic block
@@ -287,9 +275,7 @@ def check_bb(
287
275
  return checked_bb
288
276
 
289
277
 
290
- def check_rows_match(
291
- row1: Row[Variable], row2: Row[Variable], bb: BB, globals: Globals
292
- ) -> None:
278
+ def check_rows_match(row1: Row[Variable], row2: Row[Variable], bb: BB) -> None:
293
279
  """Checks that the types of two rows match up.
294
280
 
295
281
  Otherwise, an error is thrown, alerting the user that a variable has different
@@ -299,10 +285,7 @@ def check_rows_match(
299
285
  for x in map1.keys() | map2.keys():
300
286
  # If block signature lengths don't match but no undefined error was thrown, some
301
287
  # variables may be shadowing global variables.
302
- v1 = map1.get(x) or cast(ValueDef, globals[x])
303
- assert isinstance(v1, Variable | ValueDef)
304
- v2 = map2.get(x) or cast(ValueDef, globals[x])
305
- assert isinstance(v2, Variable | ValueDef)
288
+ v1, v2 = map1[x], map2[x]
306
289
  if v1.ty != v2.ty:
307
290
  # In the error message, we want to mention the variable that was first
308
291
  # defined at the start.
@@ -320,31 +303,9 @@ def check_rows_match(
320
303
  # We don't add a location to the type hint for the global variable,
321
304
  # since it could lead to cross-file diagnostics (which are not
322
305
  # supported) or refer to long function definitions.
323
- sub1 = (
324
- BranchTypeError.TypeHint(v1.defined_at, v1.ty)
325
- if isinstance(v1, Variable)
326
- else BranchTypeError.GlobalHint(None, v1)
327
- )
328
- sub2 = (
329
- BranchTypeError.TypeHint(v2.defined_at, v2.ty)
330
- if isinstance(v2, Variable)
331
- else BranchTypeError.GlobalHint(None, v2)
332
- )
333
- err.add_sub_diagnostic(sub1)
334
- err.add_sub_diagnostic(sub2)
306
+ err.add_sub_diagnostic(BranchTypeError.TypeHint(v1.defined_at, v1.ty))
307
+ err.add_sub_diagnostic(BranchTypeError.TypeHint(v2.defined_at, v2.ty))
335
308
  raise GuppyError(err)
336
- else:
337
- # TODO: Remove once https://github.com/CQCL/guppylang/issues/827 is done.
338
- # If either is a global variable, don't allow shadowing even if types match.
339
- if not (isinstance(v1, Variable) and isinstance(v2, Variable)):
340
- local_var = v1 if isinstance(v1, Variable) else v2
341
- ident = (
342
- "Expression"
343
- if local_var.name.startswith("%")
344
- else f"Variable `{local_var.name}`"
345
- )
346
- glob_err = GlobalShadowError(local_var.defined_at, ident)
347
- raise GuppyError(glob_err)
348
309
 
349
310
 
350
311
  def diagnose_maybe_undefined(
@@ -54,6 +54,7 @@ from guppylang_internals.tys.ty import (
54
54
 
55
55
  if TYPE_CHECKING:
56
56
  from guppylang_internals.definition.struct import StructField
57
+ from guppylang_internals.tys.parsing import TypeParsingCtx
57
58
 
58
59
 
59
60
  #: A "place" is a description for a storage location of a local value that users
@@ -507,6 +508,13 @@ class Context(NamedTuple):
507
508
  locals: Locals[str, Variable]
508
509
  generic_params: dict[str, Parameter]
509
510
 
511
+ @property
512
+ def parsing_ctx(self) -> "TypeParsingCtx":
513
+ """A type parsing context derived from this checking context."""
514
+ from guppylang_internals.tys.parsing import TypeParsingCtx
515
+
516
+ return TypeParsingCtx(self.globals, self.generic_params)
517
+
510
518
 
511
519
  class DummyEvalDict(dict[str, Any]):
512
520
  """A custom dict that can be passed to `eval` to give better error messages.
@@ -34,6 +34,7 @@ from guppylang_internals.ast_util import (
34
34
  AstNode,
35
35
  AstVisitor,
36
36
  breaks_in_loop,
37
+ get_type,
37
38
  get_type_opt,
38
39
  return_nodes_in_ast,
39
40
  with_loc,
@@ -101,8 +102,6 @@ from guppylang_internals.nodes import (
101
102
  FieldAccessAndDrop,
102
103
  GenericParamValue,
103
104
  GlobalName,
104
- IterEnd,
105
- IterHasNext,
106
105
  IterNext,
107
106
  LocalCall,
108
107
  MakeIter,
@@ -784,14 +783,6 @@ class ExprSynthesizer(AstVisitor[tuple[ast.expr, Type]]):
784
783
  raise GuppyTypeError(err)
785
784
  return expr, ty
786
785
 
787
- def visit_IterHasNext(self, node: IterHasNext) -> tuple[ast.expr, Type]:
788
- node.value, ty = self.synthesize(node.value)
789
- flags = InputFlags.Owned if not ty.copyable else InputFlags.NoFlags
790
- exp_sig = FunctionType([FuncInput(ty, flags)], TupleType([bool_type(), ty]))
791
- return self.synthesize_instance_func(
792
- node.value, [], "__hasnext__", "an iterator", exp_sig, True
793
- )
794
-
795
786
  def visit_IterNext(self, node: IterNext) -> tuple[ast.expr, Type]:
796
787
  node.value, ty = self.synthesize(node.value)
797
788
  flags = InputFlags.Owned if not ty.copyable else InputFlags.NoFlags
@@ -803,14 +794,6 @@ class ExprSynthesizer(AstVisitor[tuple[ast.expr, Type]]):
803
794
  node.value, [], "__next__", "an iterator", exp_sig, True
804
795
  )
805
796
 
806
- def visit_IterEnd(self, node: IterEnd) -> tuple[ast.expr, Type]:
807
- node.value, ty = self.synthesize(node.value)
808
- flags = InputFlags.Owned if not ty.copyable else InputFlags.NoFlags
809
- exp_sig = FunctionType([FuncInput(ty, flags)], NoneType())
810
- return self.synthesize_instance_func(
811
- node.value, [], "__end__", "an iterator", exp_sig, True
812
- )
813
-
814
797
  def visit_ListComp(self, node: ast.ListComp) -> tuple[ast.expr, Type]:
815
798
  raise InternalGuppyError(
816
799
  "BB contains `ListComp`. Should have been removed during CFG"
@@ -946,7 +929,7 @@ def check_type_apply(ty: FunctionType, node: ast.Subscript, ctx: Context) -> Ins
946
929
  raise GuppyError(err)
947
930
 
948
931
  return [
949
- param.check_arg(arg_from_ast(arg_expr, globals, ctx.generic_params), arg_expr)
932
+ param.check_arg(arg_from_ast(arg_expr, ctx.parsing_ctx), arg_expr)
950
933
  for arg_expr, param in zip(arg_exprs, ty.params, strict=True)
951
934
  ]
952
935
 
@@ -1232,7 +1215,14 @@ def instantiate_poly(node: ast.expr, ty: FunctionType, inst: Inst) -> ast.expr:
1232
1215
  """Instantiates quantified type arguments in a function."""
1233
1216
  assert len(ty.params) == len(inst)
1234
1217
  if len(inst) > 0:
1235
- node = with_loc(node, TypeApply(value=with_type(ty, node), inst=inst))
1218
+ # Partial applications need to be instantiated on the inside
1219
+ if isinstance(node, PartialApply):
1220
+ full_ty = get_type(node.func)
1221
+ assert isinstance(full_ty, FunctionType)
1222
+ assert full_ty.params == ty.params
1223
+ node.func = instantiate_poly(node.func, full_ty, inst)
1224
+ else:
1225
+ node = with_loc(node, TypeApply(value=with_type(ty, node), inst=inst))
1236
1226
  return with_type(ty.instantiate(inst), node)
1237
1227
  return with_type(ty, node)
1238
1228
 
@@ -1309,11 +1299,7 @@ def eval_comptime_expr(node: ComptimeExpr, ctx: Context) -> Any:
1309
1299
  raise GuppyError(ComptimeExprNotCPythonError(node))
1310
1300
 
1311
1301
  try:
1312
- python_val = eval( # noqa: S307
1313
- ast.unparse(node.value),
1314
- None,
1315
- DummyEvalDict(ctx, node.value),
1316
- )
1302
+ python_val = eval(ast.unparse(node.value), DummyEvalDict(ctx, node.value)) # noqa: S307
1317
1303
  except DummyEvalDict.GuppyVarUsedError as e:
1318
1304
  raise GuppyError(ComptimeExprNotStaticError(e.node or node, e.var)) from None
1319
1305
  except Exception as e:
@@ -7,8 +7,8 @@ node straight from the Python AST. We build a CFG, check it, and return a
7
7
 
8
8
  import ast
9
9
  import sys
10
- from dataclasses import dataclass
11
- from typing import TYPE_CHECKING, ClassVar
10
+ from dataclasses import dataclass, replace
11
+ from typing import TYPE_CHECKING, ClassVar, cast
12
12
 
13
13
  from guppylang_internals.ast_util import return_nodes_in_ast, with_loc
14
14
  from guppylang_internals.cfg.bb import BB
@@ -17,13 +17,28 @@ from guppylang_internals.checker.cfg_checker import CheckedCFG, check_cfg
17
17
  from guppylang_internals.checker.core import Context, Globals, Place, Variable
18
18
  from guppylang_internals.checker.errors.generic import UnsupportedError
19
19
  from guppylang_internals.definition.common import DefId
20
+ from guppylang_internals.definition.ty import TypeDef
20
21
  from guppylang_internals.diagnostic import Error, Help, Note
21
22
  from guppylang_internals.engine import DEF_STORE, ENGINE
22
23
  from guppylang_internals.error import GuppyError
23
24
  from guppylang_internals.experimental import check_capturing_closures_enabled
24
25
  from guppylang_internals.nodes import CheckedNestedFunctionDef, NestedFunctionDef
25
- from guppylang_internals.tys.parsing import parse_function_io_types
26
- from guppylang_internals.tys.ty import FunctionType, InputFlags, NoneType
26
+ from guppylang_internals.tys.parsing import (
27
+ TypeParsingCtx,
28
+ check_function_arg,
29
+ parse_function_arg_annotation,
30
+ type_from_ast,
31
+ type_with_flags_from_ast,
32
+ )
33
+ from guppylang_internals.tys.ty import (
34
+ ExistentialTypeVar,
35
+ FuncInput,
36
+ FunctionType,
37
+ InputFlags,
38
+ NoneType,
39
+ Type,
40
+ unify,
41
+ )
27
42
 
28
43
  if sys.version_info >= (3, 12):
29
44
  from guppylang_internals.tys.parsing import parse_parameter
@@ -53,6 +68,15 @@ class MissingArgAnnotationError(Error):
53
68
  span_label: ClassVar[str] = "Argument requires a type annotation"
54
69
 
55
70
 
71
+ @dataclass(frozen=True)
72
+ class RecursiveSelfError(Error):
73
+ title: ClassVar[str] = "Recursive self annotation"
74
+ span_label: ClassVar[str] = (
75
+ "Type of `{self_arg}` cannot recursively refer to `Self`"
76
+ )
77
+ self_arg: str
78
+
79
+
56
80
  @dataclass(frozen=True)
57
81
  class MissingReturnAnnotationError(Error):
58
82
  title: ClassVar[str] = "Missing type annotation"
@@ -67,6 +91,43 @@ class MissingReturnAnnotationError(Error):
67
91
  func: str
68
92
 
69
93
 
94
+ @dataclass(frozen=True)
95
+ class InvalidSelfError(Error):
96
+ title: ClassVar[str] = "Invalid self annotation"
97
+ span_label: ClassVar[str] = "`{self_arg}` must be of type `{self_ty}`"
98
+ self_arg: str
99
+ self_ty: Type
100
+
101
+
102
+ @dataclass(frozen=True)
103
+ class SelfParamsShadowedError(Error):
104
+ title: ClassVar[str] = "Shadowed generic parameters"
105
+ span_label: ClassVar[str] = (
106
+ "Cannot infer type for `{self_arg}` since parameter `{param}` of "
107
+ "`{ty_defn.name}` is shadowed"
108
+ )
109
+ param: str
110
+ ty_defn: "TypeDef"
111
+ self_arg: str
112
+
113
+ @dataclass(frozen=True)
114
+ class ExplicitHelp(Help):
115
+ span_label: ClassVar[str] = (
116
+ "Consider specifying the type explicitly: `{suggestion}`"
117
+ )
118
+
119
+ @property
120
+ def suggestion(self) -> str:
121
+ parent = self._parent
122
+ assert isinstance(parent, SelfParamsShadowedError)
123
+ params = (
124
+ f"[{', '.join(f'?{p.name}' for p in parent.ty_defn.params)}]"
125
+ if parent.ty_defn.params
126
+ else ""
127
+ )
128
+ return f'{parent.self_arg}: "{parent.ty_defn.name}{params}"'
129
+
130
+
70
131
  def check_global_func_def(
71
132
  func_def: ast.FunctionDef, ty: FunctionType, globals: Globals
72
133
  ) -> CheckedCFG[Place]:
@@ -176,9 +237,16 @@ def check_nested_func_def(
176
237
  return with_loc(func_def, checked_def)
177
238
 
178
239
 
179
- def check_signature(func_def: ast.FunctionDef, globals: Globals) -> FunctionType:
240
+ def check_signature(
241
+ func_def: ast.FunctionDef, globals: Globals, def_id: DefId | None = None
242
+ ) -> FunctionType:
180
243
  """Checks the signature of a function definition and returns the corresponding
181
- Guppy type."""
244
+ Guppy type.
245
+
246
+ If this is a method, then the `DefId` of the associated parent type should also be
247
+ passed. This will be used to check or infer the type annotation for the `self`
248
+ argument.
249
+ """
182
250
  if len(func_def.args.posonlyargs) != 0:
183
251
  raise GuppyError(
184
252
  UnsupportedError(func_def.args.posonlyargs[0], "Positional-only parameters")
@@ -211,23 +279,29 @@ def check_signature(func_def: ast.FunctionDef, globals: Globals) -> FunctionType
211
279
  param = parse_parameter(param_node, i, globals)
212
280
  param_var_mapping[param.name] = param
213
281
 
214
- input_nodes = []
282
+ # Figure out if this is a method
283
+ self_defn: TypeDef | None = None
284
+ if def_id is not None and def_id in DEF_STORE.impl_parents:
285
+ self_defn = cast(TypeDef, ENGINE.get_checked(DEF_STORE.impl_parents[def_id]))
286
+ assert isinstance(self_defn, TypeDef)
287
+
288
+ inputs = []
215
289
  input_names = []
216
- for inp in func_def.args.args:
217
- ty_ast = inp.annotation
218
- if ty_ast is None:
219
- raise GuppyError(MissingArgAnnotationError(inp))
220
- input_nodes.append(ty_ast)
290
+ ctx = TypeParsingCtx(globals, param_var_mapping, allow_free_vars=True)
291
+ for i, inp in enumerate(func_def.args.args):
292
+ # Special handling for `self` arguments. Note that `__new__` is excluded here
293
+ # since it's not a method so doesn't take `self`.
294
+ if self_defn and i == 0 and func_def.name != "__new__":
295
+ input = parse_self_arg(inp, self_defn, ctx)
296
+ ctx = replace(ctx, self_ty=input.ty)
297
+ else:
298
+ ty_ast = inp.annotation
299
+ if ty_ast is None:
300
+ raise GuppyError(MissingArgAnnotationError(inp))
301
+ input = parse_function_arg_annotation(ty_ast, inp.arg, ctx)
302
+ inputs.append(input)
221
303
  input_names.append(inp.arg)
222
- inputs, output = parse_function_io_types(
223
- input_nodes,
224
- func_def.returns,
225
- input_names,
226
- func_def,
227
- globals,
228
- param_var_mapping,
229
- True,
230
- )
304
+ output = type_from_ast(func_def.returns, ctx)
231
305
  return FunctionType(
232
306
  inputs,
233
307
  output,
@@ -236,6 +310,81 @@ def check_signature(func_def: ast.FunctionDef, globals: Globals) -> FunctionType
236
310
  )
237
311
 
238
312
 
313
+ def parse_self_arg(arg: ast.arg, self_defn: TypeDef, ctx: TypeParsingCtx) -> FuncInput:
314
+ """Handles parsing of the `self` argument on methods.
315
+
316
+ This argument is special since its type annotation may be omitted. Furthermore, if a
317
+ type is provided then it must match the parent type.
318
+ """
319
+ assert self_defn.params is not None
320
+ if arg.annotation is None:
321
+ return handle_implicit_self_arg(arg, self_defn, ctx)
322
+
323
+ # If the user has provided an annotation for `self`, then we go ahead and parse it.
324
+ # However, in the annotation the user is also allowed to use `Self`, so we have to
325
+ # specify a `self_ty` in the context.
326
+ self_ty_head = self_defn.check_instantiate(
327
+ [param.to_existential()[0] for param in self_defn.params]
328
+ )
329
+ self_ty_placeholder = ExistentialTypeVar.fresh(
330
+ "Self", copyable=self_ty_head.copyable, droppable=self_ty_head.droppable
331
+ )
332
+ assert ctx.self_ty is None
333
+ ctx = replace(ctx, self_ty=self_ty_placeholder)
334
+ user_ty, user_flags = type_with_flags_from_ast(arg.annotation, ctx)
335
+
336
+ # If the user just annotates `self: Self` then we can fall back to the case where
337
+ # no annotation is provided at all
338
+ if user_ty == self_ty_placeholder:
339
+ return handle_implicit_self_arg(arg, self_defn, ctx, user_flags)
340
+
341
+ # Annotations like `self: Foo[Self]` are not allowed (would be an infinite type)
342
+ if self_ty_placeholder in user_ty.unsolved_vars:
343
+ raise GuppyError(RecursiveSelfError(arg.annotation, arg.arg))
344
+
345
+ # Check that the annotation matches the parent type. We can do this by unifying with
346
+ # the expected self type where all params are instantiated with unification vars
347
+ subst = unify(user_ty, self_ty_head, {})
348
+ if subst is None:
349
+ raise GuppyError(InvalidSelfError(arg.annotation, arg.arg, self_ty_head))
350
+
351
+ return check_function_arg(user_ty, user_flags, arg, arg.arg, ctx)
352
+
353
+
354
+ def handle_implicit_self_arg(
355
+ arg: ast.arg,
356
+ self_defn: TypeDef,
357
+ ctx: TypeParsingCtx,
358
+ flags: InputFlags = InputFlags.NoFlags,
359
+ ) -> FuncInput:
360
+ """Handles the case where no annotation for `self` is provided.
361
+
362
+ Generates the most generic annotation that is possible by making the function as
363
+ generic as the parent type.
364
+ """
365
+ # Check that the user hasn't shadowed some of the parent type parameters using a
366
+ # Python 3.12 style parameter declaration
367
+ assert self_defn.params is not None
368
+ shadowed_params = [
369
+ param for param in self_defn.params if param.name in ctx.param_var_mapping
370
+ ]
371
+ if shadowed_params:
372
+ param = shadowed_params.pop()
373
+ err = SelfParamsShadowedError(arg, param.name, self_defn, arg.arg)
374
+ err.add_sub_diagnostic(SelfParamsShadowedError.ExplicitHelp(arg))
375
+ raise GuppyError(err)
376
+
377
+ # The generic params inherited from the parent type should appear first in the
378
+ # parameter list, so we have to shift the existing ones
379
+ for name, param in ctx.param_var_mapping.items():
380
+ ctx.param_var_mapping[name] = param.with_idx(param.idx + len(self_defn.params))
381
+
382
+ ctx.param_var_mapping.update({param.name: param for param in self_defn.params})
383
+ self_args = [param.to_bound() for param in self_defn.params]
384
+ self_ty = self_defn.check_instantiate(self_args, loc=arg)
385
+ return check_function_arg(self_ty, flags, arg, arg.arg, ctx)
386
+
387
+
239
388
  def parse_function_with_docstring(
240
389
  func_ast: ast.FunctionDef,
241
390
  ) -> tuple[ast.FunctionDef, str | None]:
@@ -356,7 +356,7 @@ class StmtChecker(AstVisitor[BBStatement]):
356
356
  def visit_AnnAssign(self, node: ast.AnnAssign) -> ast.stmt:
357
357
  if node.value is None:
358
358
  raise GuppyError(UnsupportedError(node, "Variable declarations"))
359
- ty = type_from_ast(node.annotation, self.ctx.globals, self.ctx.generic_params)
359
+ ty = type_from_ast(node.annotation, self.ctx.parsing_ctx)
360
360
  node.value, subst = self._check_expr(node.value, ty)
361
361
  assert not ty.unsolved_vars # `ty` must be closed!
362
362
  assert len(subst) == 0