literalenum 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,333 @@
1
+ """
2
+ mypy plugin for LiteralEnum.
3
+
4
+ Makes LiteralEnum subclasses behave as Literal unions in type-annotation
5
+ context while keeping them as normal classes for attribute access.
6
+
7
+ class HttpMethod(LiteralEnum):
8
+ GET = "GET"
9
+ POST = "POST"
10
+
11
+ # In type context: HttpMethod → Literal["GET", "POST"]
12
+ # In value context: HttpMethod.GET → Literal["GET"] (unchanged class)
13
+
14
+ Enable in mypy.ini / pyproject.toml:
15
+
16
+ [mypy]
17
+ plugins = literalenum.mypy_plugin
18
+ """
19
+
20
from __future__ import annotations

from typing import Any, Callable

from mypy.nodes import (
    ARG_POS,
    Argument,
    AssignmentStmt,
    BytesExpr,
    IntExpr,
    NameExpr,
    StrExpr,
    TypeInfo,
    UnaryExpr,
    Var,
)
from mypy.plugin import (
    AnalyzeTypeContext,
    ClassDefContext,
    FunctionContext,
    Plugin,
)
from mypy.plugins.common import add_method_to_class
from mypy.types import (
    Instance,
    LiteralType,
    NoneType,
    Type,
    UnionType,
    get_proper_type,
)
50
+
51
+
52
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------

# Key under which per-class member data is stored in ``TypeInfo.metadata``.
METADATA_KEY = "literalenum"

# Fully-qualified names that identify the LiteralEnum base class itself.
_BASE_FULLNAMES: frozenset[str] = frozenset(("literalenum.LiteralEnum",))

# Member type tag -> fully-qualified builtin used as the Literal fallback.
# Insertion order (str, int, bool, bytes) is kept deliberately.
_TAG_TO_BUILTIN: dict[str, str] = dict(
    str="builtins.str",
    int="builtins.int",
    bool="builtins.bool",
    bytes="builtins.bytes",
)

# Member storage maps a member name to its (value, type_tag) pair, e.g.
# {"GET": ("GET", "str"), "OK": (200, "int")}.
# type_tag is one of "str", "int", "bool", "bytes" or "none".
Members = dict[str, tuple[Any, str]]
76
+
77
+
78
+ # ---------------------------------------------------------------------------
79
+ # AST helpers
80
+ # ---------------------------------------------------------------------------
81
+
82
+
83
def _extract_literal(expr: Any) -> tuple[Any, str] | None:
    """Return (value, type_tag) for a literal AST node, or None.

    Recognised forms:
    - string, int and bytes literals;
    - a unary ``-``/``+`` applied to an int literal (e.g. ``-1``), which
      mypy parses as ``UnaryExpr`` rather than ``IntExpr``;
    - the names ``True``, ``False`` and ``None``.

    Any other expression yields None, and the caller skips the member.
    """
    if isinstance(expr, StrExpr):
        return (expr.value, "str")
    if isinstance(expr, IntExpr):
        return (expr.value, "int")
    if (
        isinstance(expr, UnaryExpr)
        and expr.op in ("-", "+")
        and isinstance(expr.expr, IntExpr)
    ):
        # Without this, members such as ``ERR = -1`` silently vanish from
        # the Literal union and from the generated __init__ signature.
        magnitude = expr.expr.value
        return (-magnitude if expr.op == "-" else magnitude, "int")
    if isinstance(expr, BytesExpr):
        return (expr.value, "bytes")
    if isinstance(expr, NameExpr):
        if expr.name == "True":
            return (True, "bool")
        if expr.name == "False":
            return (False, "bool")
        if expr.name == "None":
            return (None, "none")
    return None
99
+
100
+
101
def _make_literal_type(
    value: Any,
    type_tag: str,
    named_type: Callable[..., Instance],
) -> Type:
    """Return the mypy type for one member value.

    The "none" tag maps to ``NoneType``. Every other tag becomes a
    ``LiteralType`` whose fallback is the matching builtin instance,
    obtained through *named_type*.
    """
    if type_tag != "none":
        builtin_fullname = _TAG_TO_BUILTIN[type_tag]
        fallback_instance = named_type(builtin_fullname, [])
        return LiteralType(value, fallback_instance)
    return NoneType()
111
+
112
+
113
def _make_union(members: Members, named_type: Callable[..., Instance]) -> Type:
    """Build the union of Literal types for all distinct member values.

    Members that share both value and type tag (aliases) contribute a single
    item, in first-seen order. ``UnionType.make_union`` returns the lone item
    for a one-member enum. For an empty LiteralEnum it returns mypy's bottom
    type (``UninhabitedType``, i.e. Never), so nothing is assignable. This
    replaces the previous hand-built empty ``UnionType([])``, which is not
    mypy's canonical Never.
    """
    # dict.fromkeys de-duplicates (value, tag) pairs while keeping order.
    distinct_pairs = dict.fromkeys(members.values())
    items: list[Type] = [
        _make_literal_type(value, type_tag, named_type)
        for value, type_tag in distinct_pairs
    ]
    return UnionType.make_union(items)
127
+
128
+
129
+ # ---------------------------------------------------------------------------
130
+ # Plugin
131
+ # ---------------------------------------------------------------------------
132
+
133
+
134
class LiteralEnumPlugin(Plugin):
    """
    mypy plugin that treats LiteralEnum subclasses as unions of Literals.

    Three-hook architecture:

    1. get_base_class_hook – processes class definitions of direct and
       indirect LiteralEnum subclasses, types members as
       Final[Literal[...]], and stores metadata.

    2. get_type_analyze_hook – intercepts type references in annotations
       and expands HttpMethod → Literal["GET", "POST", ...]

    3. get_function_hook – refines the return type of constructor calls
       HttpMethod("GET") → Literal["GET"]

    Member data is cached per class fullname in ``self._classes``. It is
    also written to ``TypeInfo.metadata[METADATA_KEY]`` so it can be
    rebuilt when a class is loaded from the incremental cache.
    """

    def __init__(self, options: Any) -> None:
        super().__init__(options)
        # class fullname -> {member name: (value, type_tag)}
        self._classes: dict[str, Members] = {}

    # ── hook registration ─────────────────────────────────────────────

    def get_base_class_hook(
        self,
        fullname: str,
    ) -> Callable[[ClassDefContext], None] | None:
        """Return the class-definition callback for LiteralEnum subclasses.

        mypy calls this with the fullname of each *declared* base. For
        ``class B(A)`` with ``class A(LiteralEnum)`` it receives A's
        fullname, not "literalenum.LiteralEnum". Checking that base's MRO
        ensures indirect subclasses are processed, and inherit members, too.
        """
        if fullname in _BASE_FULLNAMES or self._derives_from_base(fullname):
            return self._on_class_def
        return None

    def _derives_from_base(self, fullname: str) -> bool:
        """Return True if *fullname* is a class whose MRO includes LiteralEnum."""
        sym = self.lookup_fully_qualified(fullname)
        if sym is None or not isinstance(sym.node, TypeInfo):
            return False
        return any(base.fullname in _BASE_FULLNAMES for base in sym.node.mro)

    def get_type_analyze_hook(
        self,
        fullname: str,
    ) -> Callable[[AnalyzeTypeContext], Type] | None:
        """Expand references to a known LiteralEnum class into its Literal union."""
        members = self._resolve(fullname)
        if members is not None:

            def callback(ctx: AnalyzeTypeContext) -> Type:
                return _make_union(members, ctx.api.named_type)

            return callback
        return None

    def get_function_hook(
        self,
        fullname: str,
    ) -> Callable[[FunctionContext], Type] | None:
        """Refine the return type of calls to a known LiteralEnum class."""
        # Constructor calls: mypy calls with the class fullname
        members = self._resolve(fullname)
        if members is not None:

            def callback(ctx: FunctionContext) -> Type:
                return self._on_constructor(fullname, members, ctx)

            return callback
        return None

    # ── member resolution (cache + metadata fallback) ─────────────────

    def _resolve(self, fullname: str) -> Members | None:
        """Return the members of LiteralEnum class *fullname*, or None.

        The in-memory cache is checked first. Otherwise the members are
        rebuilt from the metadata persisted on the class's TypeInfo, which
        covers classes loaded from the incremental cache.
        """
        if fullname in self._classes:
            return self._classes[fullname]
        # Incremental mode: reconstruct from persisted TypeInfo.metadata
        sym = self.lookup_fully_qualified(fullname)
        if sym and sym.node and isinstance(sym.node, TypeInfo):
            meta = sym.node.metadata.get(METADATA_KEY)
            if meta and "members" in meta:
                members: Members = {
                    k: tuple(v) for k, v in meta["members"].items()
                }
                self._classes[fullname] = members
                return members
        return None

    # ── hook 1: base class — process class definition ─────────────────

    def _on_class_def(self, ctx: ClassDefContext) -> None:
        """Collect members, persist metadata, type members, add __init__."""
        info = ctx.cls.info

        # Inherit parent members. Walk the MRO from the most distant base to
        # the nearest, so a nearer base's value for a name overrides a more
        # distant one (matching Python's attribute resolution order).
        members: Members = {}
        for base in reversed(info.mro[1:]):
            parent_meta = base.metadata.get(METADATA_KEY)
            if parent_meta and "members" in parent_meta:
                for name, pair in parent_meta["members"].items():
                    members[name] = tuple(pair)

        # Collect own members from the class body: single-target assignments
        # to upper-case, non-underscore names whose value is a literal.
        for stmt in ctx.cls.defs.body:
            if not isinstance(stmt, AssignmentStmt) or len(stmt.lvalues) != 1:
                continue
            lvalue = stmt.lvalues[0]
            if not isinstance(lvalue, NameExpr):
                continue
            name = lvalue.name
            if not name.isupper() or name.startswith("_"):
                continue
            result = _extract_literal(stmt.rvalue)
            if result is None:
                continue
            members[name] = result

        # Persist (JSON-safe: tuples become lists) and cache
        info.metadata[METADATA_KEY] = {
            "members": {k: list(v) for k, v in members.items()},
        }
        self._classes[info.fullname] = members

        # Type each member defined in this class as Final[Literal[<value>]]
        for name, (value, type_tag) in members.items():
            sym = info.names.get(name)
            if sym is None or not isinstance(sym.node, Var):
                continue
            var = sym.node
            var.is_final = True
            var.type = _make_literal_type(value, type_tag, ctx.api.named_type)

        # Add __init__(self, value: <base_type>) -> None
        # so that HttpMethod("GET") is syntactically valid.
        # The function hook refines the return type.
        if members:
            param_types: list[Type] = []
            for tag in sorted({tag for _, tag in members.values()}):
                if tag == "none":
                    param_types.append(NoneType())
                else:
                    param_types.append(
                        ctx.api.named_type(_TAG_TO_BUILTIN[tag], [])
                    )
            # make_union returns the sole item unchanged when there is one.
            param_type = UnionType.make_union(param_types)
            add_method_to_class(
                ctx.api,
                ctx.cls,
                "__init__",
                args=[
                    Argument(
                        Var("value", param_type), param_type, None, ARG_POS
                    )
                ],
                return_type=NoneType(),
            )

    # ── hook 3: constructor return type ───────────────────────────────

    def _on_constructor(
        self,
        fullname: str,
        members: Members,
        ctx: FunctionContext,
    ) -> Type:
        """Narrow ``Cls(<literal>)`` to that Literal, or report a non-member.

        Returns the literal type for a member literal. For a non-member
        literal it reports an error and returns the default type. For a
        ``None`` argument it returns ``NoneType`` when None is a member.
        Otherwise it returns the full member union.
        """
        if not ctx.arg_types or not ctx.arg_types[0]:
            return ctx.default_return_type

        arg_type = get_proper_type(ctx.arg_types[0][0])

        # Literal argument: validate membership and narrow return type
        if isinstance(arg_type, LiteralType):
            # Determine the type tag of this argument from its fallback
            arg_tag: str | None = None
            if isinstance(arg_type.fallback, Instance):
                fb_name = arg_type.fallback.type.fullname
                arg_tag = next(
                    (
                        tag
                        for tag, builtin in _TAG_TO_BUILTIN.items()
                        if builtin == fb_name
                    ),
                    None,
                )

            member_keys = set(members.values())
            if arg_tag and (arg_type.value, arg_tag) in member_keys:
                return arg_type  # narrow: HttpMethod("GET") → Literal["GET"]

            # Not a member — report error. Aliases share a value, so list
            # each distinct value only once (first-seen order).
            class_name = fullname.rsplit(".", 1)[-1]
            valid = ", ".join(
                dict.fromkeys(repr(v) for v, _ in members.values())
            )
            ctx.api.fail(
                f'Value {arg_type.value!r} is not a member of '
                f'"{class_name}"; expected one of {valid}',
                ctx.context,
            )
            return ctx.default_return_type

        # NoneType argument
        if isinstance(arg_type, NoneType):
            if any(t == "none" for _, t in members.values()):
                return NoneType()

        # Non-literal (bare str, variable, etc.) — return the full union.
        # This is the best we can do without knowing the runtime value.
        return _make_union(members, ctx.api.named_generic_type)
325
+
326
+
327
+ # ---------------------------------------------------------------------------
328
+ # Entry point
329
+ # ---------------------------------------------------------------------------
330
+
331
+
332
def plugin(version: str) -> type[Plugin]:
    """mypy entry point: return the plugin class to instantiate.

    The *version* string is not inspected; the same class is returned for
    every mypy version.
    """
    plugin_class = LiteralEnumPlugin
    return plugin_class
literalenum/py.typed ADDED
File without changes