dissect.cstruct 4.4.dev3__py3-none-any.whl → 4.5.dev2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,226 @@
1
+ from __future__ import annotations
2
+
3
+ import argparse
4
+ import importlib
5
+ import importlib.util
6
+ import logging
7
+ import textwrap
8
+ from pathlib import Path
9
+ from typing import TYPE_CHECKING
10
+
11
+ from dissect.cstruct import types
12
+ from dissect.cstruct.cstruct import cstruct
13
+
14
+ if TYPE_CHECKING:
15
+ from types import ModuleType
16
+
17
+ log = logging.getLogger(__name__)
18
+
19
+
20
def load_module(path: Path, base: Path) -> ModuleType | None:
    """Import a Python source file as a standalone module.

    The module name is derived from the location of ``path`` relative to ``base``:
    the relative parent directories and the file stem are joined with dots.

    Args:
        path: The Python file to import.
        base: The directory that ``path`` is made relative to for naming purposes.

    Returns:
        The executed module, or ``None`` if anything went wrong while importing it.
        Failures are logged (warning with the path, debug with the traceback) and never raised.
    """
    try:
        relative_path = path.relative_to(base)
        module_name = ".".join((*relative_path.parent.parts, relative_path.stem))

        spec = importlib.util.spec_from_file_location(module_name, path)
        if spec is None or spec.loader is None:
            # No loader could be determined for this file, treat it like any other import failure
            raise ImportError(f"Unable to determine a loader for {path}")

        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
    except Exception as e:
        log.warning("Unable to import %s", path)
        log.debug("Error while trying to import module %s", path, exc_info=e)
        # Never hand out a module whose execution was interrupted halfway through
        return None

    return module
33
+
34
+
35
def generate_file_stub(path: Path, base: Path) -> str:
    """Build the stub source for every ``cstruct`` instance found in a Python file.

    Args:
        path: The Python file to import and inspect.
        base: The directory ``path`` is made relative to when it is imported.

    Returns:
        The stub source as a string. An empty string is returned when the file could not be
        imported, when the imported module has no ``cstruct`` attribute, or when none of its
        attributes is a ``cstruct`` instance.
    """
    loaded = load_module(path, base)
    if loaded is None:
        return ""
    if not hasattr(loaded, "cstruct"):
        return ""

    sections = []
    for attr_name, attr in vars(loaded).items():
        if not isinstance(attr, cstruct):
            continue

        # Each instance is described by a private class named after the module attribute
        class_stub = generate_cstruct_stub(attr, module_prefix="__cs__.", cls_name=f"_{attr_name}")
        sections.append(class_stub)

        # Keep a blank line between the class and its alias unless the stub already ends with a newline
        if class_stub[-1] != "\n":
            sections.append("")

        sections.extend(
            (
                f"# Technically `{attr_name}` is an instance of `_{attr_name}`, but then we can't use it in type hints",
                f"{attr_name}: TypeAlias = _{attr_name}",
                "",
            )
        )

    if not sections:
        return ""

    preamble = [
        "# Generated by cstruct-stubgen",
        "from typing import BinaryIO, Literal, overload",
        "",
        "import dissect.cstruct as __cs__",
        "from typing_extensions import TypeAlias",
    ]
    return "\n".join([*preamble, "", "\n".join(sections)])
65
+
66
+
67
def generate_cstruct_stub(cs: cstruct, module_prefix: str = "", cls_name: str = "cstruct") -> str:
    """Generate the stub of a class describing the user-defined contents of a ``cstruct`` instance.

    Constants and typedefs that are also present on a freshly created ``cstruct`` are skipped,
    so only the definitions added to ``cs`` end up in the stub.

    Args:
        cs: The ``cstruct`` instance to describe.
        module_prefix: Prefix for names that are looked up on the ``dissect.cstruct`` module in the stub.
        cls_name: Name of the generated class.

    Returns:
        The class stub as a string. The class body is ``...`` when there is nothing to describe.

    Raises:
        TypeError: If a typedef is neither a string alias nor a known type class.
    """
    # A pristine instance tells us which constants and typedefs are built in
    empty_cs = cstruct()

    cs_prefix = f"{cls_name}."
    header = [f"class {cls_name}({module_prefix}cstruct):"]
    body = []
    indent = " " * 4

    # Constants first
    for name, value in cs.consts.items():
        if name in empty_cs.consts:
            continue
        body.append(textwrap.indent(f"{name}: Literal[{value!r}] = ...", prefix=indent))

    defined_names = set()

    # Then typedefs
    for name, typedef in cs.typedefs.items():
        if name in empty_cs.typedefs:
            continue

        if isinstance(typedef, str):
            # String aliases have no ``__name__`` and can't be passed to ``issubclass``,
            # so they must be handled before any of the class based checks below.
            body.append(textwrap.indent(f"{name}: TypeAlias = {typedef}", prefix=indent))
            continue

        if typedef.__name__ in empty_cs.typedefs:
            stub = f"{name}: TypeAlias = {cs_prefix}{typedef.__name__}"
        elif typedef.__name__ in defined_names:
            # Create an alias to the type if we have already seen it before.
            stub = f"{name}: TypeAlias = {typedef.__name__}"
        elif issubclass(typedef, (types.Enum, types.Flag)):
            stub = generate_enum_stub(typedef, cs_prefix=cs_prefix, module_prefix=module_prefix)
        elif issubclass(typedef, types.Structure):
            stub = generate_structure_stub(typedef, cs_prefix=cs_prefix, module_prefix=module_prefix)
        elif issubclass(typedef, types.BaseType):
            stub = generate_generic_stub(typedef, cs_prefix=cs_prefix, module_prefix=module_prefix)
        else:
            raise TypeError(f"Unknown typedef: {typedef}")

        defined_names.add(typedef.__name__)

        body.append(textwrap.indent(stub, prefix=indent))

    if not body:
        body.append(textwrap.indent("...", prefix=indent))

    return "\n".join(header + body)
112
+
113
+
114
def generate_typehint(
    type_: type[types.BaseType],
    prefix: str = "",
    module_prefix: str = "",
) -> str:
    """Return the type hint to use in a stub for the given type.

    Character arrays, wide character arrays, pointers and arrays are expressed with the
    corresponding class behind ``module_prefix``; pointers and arrays are parametrized with
    the hint of the type they wrap. Any other type is referenced by name behind ``prefix``.

    Args:
        type_: The type to create a hint for.
        prefix: Prefix for the name of a type that is not one of the special cases.
        module_prefix: Prefix for the special-cased container classes.

    Returns:
        The type hint as a string.
    """
    # The order of these checks is significant and mirrors the class hierarchy checks one by one
    if issubclass(type_, types.CharArray):
        hint = "CharArray"
    elif issubclass(type_, types.WcharArray):
        hint = "WcharArray"
    elif issubclass(type_, types.Pointer):
        inner = generate_typehint(type_.type, prefix, module_prefix)
        hint = f"Pointer[{inner}]"
    elif issubclass(type_, types.Array):
        inner = generate_typehint(type_.type, prefix, module_prefix)
        hint = f"Array[{inner}]"
    else:
        return f"{prefix}{type_.__name__}"

    return f"{module_prefix}{hint}"
128
+
129
+
130
def generate_generic_stub(
    type_: type[types.BaseType],
    name_prefix: str = "",
    cs_prefix: str = "",
    module_prefix: str = "",
) -> str:
    """Generate a single-line class stub for a type that needs no body.

    Args:
        type_: The type to create a stub for.
        name_prefix: Prefix for the name of the generated class.
        cs_prefix: Not used by this function.
        module_prefix: Prefix for the name of the base class.

    Returns:
        A ``class Name(Base): ...`` line terminated by a newline.
    """
    class_name = f"{name_prefix}{type_.__name__}"
    base_name = f"{module_prefix}{type_.__base__.__name__}"
    return f"class {class_name}({base_name}): ...\n"
137
+
138
+
139
def generate_enum_stub(
    enum: type[types.Enum | types.Flag],
    name_prefix: str = "",
    cs_prefix: str = "",
    module_prefix: str = "",
) -> str:
    """Generate a class stub for an enum or flag type.

    Every member is listed as ``NAME = ...`` in the class body, using the same four space
    indentation as the other stub generators. An enum without members gets a ``...`` body
    so that the generated class remains syntactically valid.

    Args:
        enum: The enum or flag type to create a stub for.
        name_prefix: Prefix for the name of the generated class.
        cs_prefix: Not used by this function.
        module_prefix: Prefix for the name of the base class.

    Returns:
        The class stub as a string, terminated by a newline.
    """
    indent = " " * 4

    result = [f"class {name_prefix}{enum.__name__}({module_prefix}{enum.__base__.__name__}):"]

    members = [f"{indent}{key} = ..." for key in enum.__members__]
    # A class statement needs at least one statement in its body
    result.extend(members or [f"{indent}..."])
    result.append("")

    return "\n".join(result)
150
+
151
+
152
def generate_structure_stub(
    structure: type[types.Structure],
    name_prefix: str = "",
    cs_prefix: str = "",
    module_prefix: str = "",
) -> str:
    """Generate a class stub for a structure type.

    The class lists every field with its type hint, followed by two ``__init__`` overloads:
    one taking every field as an optional argument and one taking a single positional-only
    buffer or file-like object. Structures that are used by a field (directly or as array
    element) but are not registered as a typedef are stubbed inline as nested classes.

    Args:
        structure: The structure type to create a stub for.
        name_prefix: Prefix for the name of the generated class.
        cs_prefix: Prefix for field types that are not inlined.
        module_prefix: Prefix for names that are looked up on the ``dissect.cstruct`` module in the stub.

    Returns:
        The class stub as a string, terminated by a newline.
    """
    result = [f"class {name_prefix}{structure.__name__}({module_prefix}{structure.__base__.__name__}):"]

    indent = " " * 4

    args = ["self"]
    for field_name, field in structure.fields.items():
        # Unwrap (possibly nested) arrays to get to the element type
        nested_type = field.type
        while issubclass(nested_type, types.BaseArray):
            nested_type = nested_type.type

        # If it's a structure and not globally defined, add an inline stub for it.
        # The lookup uses the name of the unwrapped type, so that an array of a globally
        # defined structure references that definition instead of duplicating it inline.
        inlined = issubclass(nested_type, types.Structure) and nested_type.__name__ not in structure.cs.typedefs
        if inlined:
            inline_stub = generate_structure_stub(nested_type, cs_prefix=cs_prefix, module_prefix=module_prefix)
            result.append(textwrap.indent(inline_stub, prefix=indent))

        # Inlined structures live in this class body, so they are referenced without the cstruct prefix
        type_hint = generate_typehint(field.type, "" if inlined else cs_prefix, module_prefix)
        # Use the same indentation as the methods below, mixed indentation makes the class body invalid
        result.append(f"{indent}{field_name}: {type_hint}")

        args.append(f"{field_name}: {type_hint} | None = ...")

    result.append(textwrap.indent("@overload", prefix=indent))
    result.append(textwrap.indent(f"def __init__({', '.join(args)}): ...", prefix=indent))
    result.append(textwrap.indent("@overload", prefix=indent))
    result.append(
        textwrap.indent("def __init__(self, fh: bytes | memoryview | bytearray | BinaryIO, /): ...", prefix=indent)
    )
    result.append("")
    return "\n".join(result)
191
+
192
+
193
def setup_logger(verbosity: int) -> None:
    """Configure logging through ``logging.basicConfig``.

    Args:
        verbosity: The verbosity count. A value of 1 or higher selects ``DEBUG``,
            anything lower selects ``INFO``.
    """
    logging.basicConfig(level=logging.DEBUG if verbosity >= 1 else logging.INFO)
199
+
200
+
201
def main() -> None:
    """Command line entry point: write ``.pyi`` stubs next to Python files with cstruct definitions.

    The ``path`` argument is either a single Python file or a directory that is searched
    recursively for ``*.py`` files. Files for which no stub could be generated are skipped.
    """
    parser = argparse.ArgumentParser(
        # Keep the program name in line with the name written in the header of generated stubs
        "cstruct-stubgen",
        description="Create .pyi stub files for cstruct definitions",
        epilog="NOTE: This tool will only generate stubs for the cstruct definitions in a file, not any other Python code. Manual fixups may be required.",  # noqa: E501
    )
    parser.add_argument("path", type=Path, help="path to the file or directory to create stubs for")
    parser.add_argument("-v", "--verbose", action="count", default=0)
    args = parser.parse_args()

    setup_logger(args.verbose)

    path: Path = args.path
    if path.is_dir():
        base = path
        candidates = path.rglob("*.py")
    else:
        # Module names are derived relative to the base. Using the file itself as base
        # would result in an empty module name, so use its directory instead.
        base = path.parent
        candidates = [path]

    for file in candidates:
        if not (file.is_file() and file.suffix == ".py"):
            continue

        stub = generate_file_stub(file, base)
        if not stub:
            continue

        stub_file = file.with_suffix(".pyi")
        log.info("Writing stub of file %s to %s", file, stub_file.name)
        # Be explicit about the encoding so the result doesn't depend on the locale
        stub_file.write_text(stub, encoding="utf-8")
224
+
225
+ if __name__ == "__main__":
226
+ main()
@@ -1,4 +1,4 @@
1
- from dissect.cstruct.types.base import Array, ArrayMetaType, BaseType, MetaType
1
+ from dissect.cstruct.types.base import Array, BaseArray, BaseType, MetaType
2
2
  from dissect.cstruct.types.char import Char, CharArray
3
3
  from dissect.cstruct.types.enum import Enum
4
4
  from dissect.cstruct.types.flag import Flag
@@ -13,7 +13,7 @@ from dissect.cstruct.types.wchar import Wchar, WcharArray
13
13
  __all__ = [
14
14
  "LEB128",
15
15
  "Array",
16
- "ArrayMetaType",
16
+ "BaseArray",
17
17
  "BaseType",
18
18
  "Char",
19
19
  "CharArray",
@@ -2,12 +2,14 @@ from __future__ import annotations
2
2
 
3
3
  import functools
4
4
  from io import BytesIO
5
- from typing import TYPE_CHECKING, Any, BinaryIO, Callable
5
+ from typing import TYPE_CHECKING, Any, BinaryIO, Callable, ClassVar, TypeVar
6
6
 
7
7
  from dissect.cstruct.exceptions import ArraySizeError
8
8
  from dissect.cstruct.expression import Expression
9
9
 
10
10
  if TYPE_CHECKING:
11
+ from typing_extensions import Self
12
+
11
13
  from dissect.cstruct.cstruct import cstruct
12
14
 
13
15
 
@@ -27,10 +29,10 @@ class MetaType(type):
27
29
  """The alignment of the type in bytes. A value of ``None`` will be treated as 1-byte aligned."""
28
30
 
29
31
  # This must be the actual type, but since Array is a subclass of BaseType, we correct this at the bottom of the file
30
- ArrayType: type[Array] = "Array"
32
+ ArrayType: type[BaseArray] = "Array"
31
33
  """The array type for this type class."""
32
34
 
33
- def __call__(cls, *args, **kwargs) -> MetaType | BaseType:
35
+ def __call__(cls, *args, **kwargs) -> Self: # type: ignore
34
36
  """Adds support for ``TypeClass(bytes | file-like object)`` parsing syntax."""
35
37
  # TODO: add support for Type(cs) API to create new bounded type classes, similar to the old API?
36
38
  if len(args) == 1 and not isinstance(args[0], cls):
@@ -48,22 +50,26 @@ class MetaType(type):
48
50
 
49
51
  return type.__call__(cls, *args, **kwargs)
50
52
 
51
- def __getitem__(cls, num_entries: int | Expression | None) -> ArrayMetaType:
53
+ def __getitem__(cls, num_entries: int | Expression | None) -> type[BaseArray]:
52
54
  """Create a new array with the given number of entries."""
53
55
  return cls.cs._make_array(cls, num_entries)
54
56
 
55
57
  def __len__(cls) -> int:
56
58
  """Return the byte size of the type."""
59
+ # Python 3.9 compat thing for bound type vars
60
+ if cls is BaseType:
61
+ return 0
62
+
57
63
  if cls.size is None:
58
64
  raise TypeError("Dynamic size")
59
65
 
60
66
  return cls.size
61
67
 
62
- def __default__(cls) -> BaseType:
68
+ def __default__(cls) -> Self: # type: ignore
63
69
  """Return the default value of this type."""
64
70
  return cls()
65
71
 
66
- def reads(cls, data: bytes) -> BaseType:
72
+ def reads(cls, data: bytes | memoryview | bytearray) -> Self: # type: ignore
67
73
  """Parse the given data from a bytes-like object.
68
74
 
69
75
  Args:
@@ -74,7 +80,7 @@ class MetaType(type):
74
80
  """
75
81
  return cls._read(BytesIO(data))
76
82
 
77
- def read(cls, obj: BinaryIO | bytes) -> BaseType:
83
+ def read(cls, obj: BinaryIO | bytes | memoryview | bytearray) -> Self: # type: ignore
78
84
  """Parse the given data.
79
85
 
80
86
  Args:
@@ -86,6 +92,9 @@ class MetaType(type):
86
92
  if _is_buffer_type(obj):
87
93
  return cls.reads(obj)
88
94
 
95
+ if not _is_readable_type(obj):
96
+ raise TypeError("Invalid object type")
97
+
89
98
  return cls._read(obj)
90
99
 
91
100
  def write(cls, stream: BinaryIO, value: Any) -> int:
@@ -113,7 +122,7 @@ class MetaType(type):
113
122
  cls._write(out, value)
114
123
  return out.getvalue()
115
124
 
116
- def _read(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> BaseType:
125
+ def _read(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> Self: # type: ignore
117
126
  """Internal function for reading value.
118
127
 
119
128
  Must be implemented per type.
@@ -124,7 +133,7 @@ class MetaType(type):
124
133
  """
125
134
  raise NotImplementedError
126
135
 
127
- def _read_array(cls, stream: BinaryIO, count: int, context: dict[str, Any] | None = None) -> list[BaseType]:
136
+ def _read_array(cls, stream: BinaryIO, count: int, context: dict[str, Any] | None = None) -> list[Self]: # type: ignore
128
137
  """Internal function for reading array values.
129
138
 
130
139
  Allows type implementations to do optimized reading for their type.
@@ -142,7 +151,7 @@ class MetaType(type):
142
151
 
143
152
  return [cls._read(stream, context) for _ in range(count)]
144
153
 
145
- def _read_0(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> list[BaseType]:
154
+ def _read_0(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> list[Self]:
146
155
  """Internal function for reading null-terminated data.
147
156
 
148
157
  "Null" is type specific, so must be implemented per type.
@@ -156,7 +165,7 @@ class MetaType(type):
156
165
  def _write(cls, stream: BinaryIO, data: Any) -> int:
157
166
  raise NotImplementedError
158
167
 
159
- def _write_array(cls, stream: BinaryIO, array: list[BaseType]) -> int:
168
+ def _write_array(cls, stream: BinaryIO, array: list[Self]) -> int: # type: ignore
160
169
  """Internal function for writing arrays.
161
170
 
162
171
  Allows type implementations to do optimized writing for their type.
@@ -167,7 +176,7 @@ class MetaType(type):
167
176
  """
168
177
  return sum(cls._write(stream, entry) for entry in array)
169
178
 
170
- def _write_0(cls, stream: BinaryIO, array: list[BaseType]) -> int:
179
+ def _write_0(cls, stream: BinaryIO, array: list[Self]) -> int: # type: ignore
171
180
  """Internal function for writing null-terminated arrays.
172
181
 
173
182
  Allows type implementations to do optimized writing for their type.
@@ -191,10 +200,10 @@ class _overload:
191
200
  b'\\x7b\\x00\\x00\\x00'
192
201
  """
193
202
 
194
- def __init__(self, func: Callable[[Any], Any]) -> None:
203
+ def __init__(self, func: Callable[..., Any]) -> None:
195
204
  self.func = func
196
205
 
197
- def __get__(self, instance: BaseType | None, owner: MetaType) -> Callable[[Any], bytes]:
206
+ def __get__(self, instance: BaseType | None, owner: type[BaseType]) -> Callable[[], bytes]:
198
207
  if instance is None:
199
208
  return functools.partial(self.func, owner)
200
209
  return functools.partial(self.func, instance.__class__, value=instance)
@@ -214,19 +223,32 @@ class BaseType(metaclass=MetaType):
214
223
  return self.__class__.size
215
224
 
216
225
 
217
- class ArrayMetaType(MetaType):
218
- """Base metaclass for array-like types."""
226
+ T = TypeVar("T", bound=BaseType)
227
+
219
228
 
220
- type: MetaType
221
- num_entries: int | Expression | None
222
- null_terminated: bool
229
+ class BaseArray(BaseType):
230
+ """Implements a fixed or dynamically sized array type.
231
+
232
+ Example:
233
+ When using the default C-style parser, the following syntax is supported:
223
234
 
235
+ x[3] -> 3 -> static length.
236
+ x[] -> None -> null-terminated.
237
+ x[expr] -> expr -> dynamic length.
238
+ """
239
+
240
+ type: ClassVar[type[BaseType]]
241
+ num_entries: ClassVar[int | Expression | None]
242
+ null_terminated: ClassVar[bool]
243
+
244
+ @classmethod
224
245
  def __default__(cls) -> BaseType:
225
246
  return type.__call__(
226
247
  cls, [cls.type.__default__()] * (cls.num_entries if isinstance(cls.num_entries, int) else 0)
227
248
  )
228
249
 
229
- def _read(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> Array:
250
+ @classmethod
251
+ def _read(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> list[BaseType]:
230
252
  if cls.null_terminated:
231
253
  return cls.type._read_0(stream, context)
232
254
 
@@ -244,22 +266,6 @@ class ArrayMetaType(MetaType):
244
266
 
245
267
  return cls.type._read_array(stream, num, context)
246
268
 
247
-
248
- class Array(list, BaseType, metaclass=ArrayMetaType):
249
- """Implements a fixed or dynamically sized array type.
250
-
251
- Example:
252
- When using the default C-style parser, the following syntax is supported:
253
-
254
- x[3] -> 3 -> static length.
255
- x[] -> None -> null-terminated.
256
- x[expr] -> expr -> dynamic length.
257
- """
258
-
259
- @classmethod
260
- def _read(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> Array:
261
- return cls(ArrayMetaType._read(cls, stream, context))
262
-
263
269
  @classmethod
264
270
  def _write(cls, stream: BinaryIO, data: list[Any]) -> int:
265
271
  if cls.null_terminated:
@@ -271,11 +277,17 @@ class Array(list, BaseType, metaclass=ArrayMetaType):
271
277
  return cls.type._write_array(stream, data)
272
278
 
273
279
 
274
- def _is_readable_type(value: Any) -> bool:
280
+ class Array(list[T], BaseArray):
281
+ @classmethod
282
+ def _read(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> list[T]:
283
+ return cls(super()._read(stream, context))
284
+
285
+
286
+ def _is_readable_type(value: object) -> bool:
275
287
  return hasattr(value, "read")
276
288
 
277
289
 
278
- def _is_buffer_type(value: Any) -> bool:
290
+ def _is_buffer_type(value: object) -> bool:
279
291
  return isinstance(value, (bytes, memoryview, bytearray))
280
292
 
281
293
 
@@ -1,16 +1,23 @@
1
1
  from __future__ import annotations
2
2
 
3
- from typing import Any, BinaryIO
3
+ from typing import TYPE_CHECKING, Any, BinaryIO
4
4
 
5
- from dissect.cstruct.types.base import EOF, ArrayMetaType, BaseType
5
+ from dissect.cstruct.types.base import EOF, BaseArray, BaseType
6
6
 
7
+ if TYPE_CHECKING:
8
+ from typing_extensions import Self
7
9
 
8
- class CharArray(bytes, BaseType, metaclass=ArrayMetaType):
10
+
11
+ class CharArray(bytes, BaseArray):
9
12
  """Character array type for reading and writing byte strings."""
10
13
 
11
14
  @classmethod
12
- def _read(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> CharArray:
13
- return type.__call__(cls, ArrayMetaType._read(cls, stream, context))
15
+ def __default__(cls) -> Self:
16
+ return type.__call__(cls, b"\x00" * (0 if cls.dynamic or cls.null_terminated else cls.num_entries))
17
+
18
+ @classmethod
19
+ def _read(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> Self:
20
+ return type.__call__(cls, super()._read(stream, context))
14
21
 
15
22
  @classmethod
16
23
  def _write(cls, stream: BinaryIO, data: bytes) -> int:
@@ -24,10 +31,6 @@ class CharArray(bytes, BaseType, metaclass=ArrayMetaType):
24
31
  return stream.write(data + b"\x00")
25
32
  return stream.write(data)
26
33
 
27
- @classmethod
28
- def __default__(cls) -> CharArray:
29
- return type.__call__(cls, b"\x00" * (0 if cls.dynamic or cls.null_terminated else cls.num_entries))
30
-
31
34
 
32
35
  class Char(bytes, BaseType):
33
36
  """Character type for reading and writing bytes."""
@@ -35,11 +38,15 @@ class Char(bytes, BaseType):
35
38
  ArrayType = CharArray
36
39
 
37
40
  @classmethod
38
- def _read(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> Char:
41
+ def __default__(cls) -> Self:
42
+ return type.__call__(cls, b"\x00")
43
+
44
+ @classmethod
45
+ def _read(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> Self:
39
46
  return cls._read_array(stream, 1, context)
40
47
 
41
48
  @classmethod
42
- def _read_array(cls, stream: BinaryIO, count: int, context: dict[str, Any] | None = None) -> Char:
49
+ def _read_array(cls, stream: BinaryIO, count: int, context: dict[str, Any] | None = None) -> Self:
43
50
  if count == 0:
44
51
  return type.__call__(cls, b"")
45
52
 
@@ -50,7 +57,7 @@ class Char(bytes, BaseType):
50
57
  return type.__call__(cls, data)
51
58
 
52
59
  @classmethod
53
- def _read_0(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> Char:
60
+ def _read_0(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> Self:
54
61
  buf = []
55
62
  while True:
56
63
  byte = stream.read(1)
@@ -73,7 +80,3 @@ class Char(bytes, BaseType):
73
80
  data = data.encode("latin-1")
74
81
 
75
82
  return stream.write(data)
76
-
77
- @classmethod
78
- def __default__(cls) -> Char:
79
- return type.__call__(cls, b"\x00")
@@ -1,30 +1,41 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import sys
4
+ from enum import Enum as _Enum
4
5
  from enum import EnumMeta, IntEnum, IntFlag
5
- from typing import TYPE_CHECKING, Any, BinaryIO
6
+ from typing import TYPE_CHECKING, Any, BinaryIO, TypeVar, overload
6
7
 
7
8
  from dissect.cstruct.types.base import Array, BaseType, MetaType
8
9
 
9
10
  if TYPE_CHECKING:
11
+ from typing_extensions import Self
12
+
10
13
  from dissect.cstruct.cstruct import cstruct
11
14
 
12
15
 
13
16
  PY_311 = sys.version_info >= (3, 11, 0)
14
17
  PY_312 = sys.version_info >= (3, 12, 0)
15
18
 
19
+ _S = TypeVar("_S")
20
+
16
21
 
17
22
  class EnumMetaType(EnumMeta, MetaType):
18
- type: MetaType
23
+ type: type[BaseType]
24
+
25
+ @overload
26
+ def __call__(cls, value: cstruct, name: str, type_: type[BaseType], *args, **kwargs) -> type[Enum]: ...
27
+
28
+ @overload
29
+ def __call__(cls: type[_S], value: int | BinaryIO | bytes) -> _S: ...
19
30
 
20
31
  def __call__(
21
32
  cls,
22
- value: cstruct | int | BinaryIO | bytes = None,
33
+ value: cstruct | int | BinaryIO | bytes | None = None,
23
34
  name: str | None = None,
24
- type_: MetaType | None = None,
35
+ type_: type[BaseType] | None = None,
25
36
  *args,
26
37
  **kwargs,
27
- ) -> EnumMetaType:
38
+ ) -> Enum | type[Enum]:
28
39
  if name is None:
29
40
  if value is None:
30
41
  value = cls.type.__default__()
@@ -35,6 +46,8 @@ class EnumMetaType(EnumMeta, MetaType):
35
46
 
36
47
  return super().__call__(value)
37
48
 
49
+ # We are constructing a new Enum class
50
+ # cs is the cstruct instance, but we can't isinstance check it due to circular imports
38
51
  cs = value
39
52
  if not issubclass(type_, int):
40
53
  raise TypeError("Enum can only be created from int type")
@@ -50,7 +63,13 @@ class EnumMetaType(EnumMeta, MetaType):
50
63
 
51
64
  return enum_cls
52
65
 
53
- def __getitem__(cls, name: str | int) -> Enum | Array:
66
+ @overload
67
+ def __getitem__(cls: type[_S], name: str) -> _S: ...
68
+
69
+ @overload
70
+ def __getitem__(cls: type[_S], name: int) -> Array: ...
71
+
72
+ def __getitem__(cls: type[_S], name: str | int) -> _S | Array:
54
73
  if isinstance(name, str):
55
74
  return super().__getitem__(name)
56
75
  return MetaType.__getitem__(cls, name)
@@ -64,24 +83,24 @@ class EnumMetaType(EnumMeta, MetaType):
64
83
  return True
65
84
  return value in cls._value2member_map_
66
85
 
67
- def _read(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> Enum:
86
+ def _read(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> Self:
68
87
  return cls(cls.type._read(stream, context))
69
88
 
70
- def _read_array(cls, stream: BinaryIO, count: int, context: dict[str, Any] | None = None) -> list[Enum]:
89
+ def _read_array(cls, stream: BinaryIO, count: int, context: dict[str, Any] | None = None) -> list[Self]:
71
90
  return list(map(cls, cls.type._read_array(stream, count, context)))
72
91
 
73
- def _read_0(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> list[Enum]:
92
+ def _read_0(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> list[Self]:
74
93
  return list(map(cls, cls.type._read_0(stream, context)))
75
94
 
76
95
  def _write(cls, stream: BinaryIO, data: Enum) -> int:
77
96
  return cls.type._write(stream, data.value)
78
97
 
79
- def _write_array(cls, stream: BinaryIO, array: list[Enum]) -> int:
80
- data = [entry.value if isinstance(entry, Enum) else entry for entry in array]
98
+ def _write_array(cls, stream: BinaryIO, array: list[BaseType | int]) -> int:
99
+ data = [entry.value if isinstance(entry, _Enum) else entry for entry in array]
81
100
  return cls.type._write_array(stream, data)
82
101
 
83
- def _write_0(cls, stream: BinaryIO, array: list[BaseType]) -> int:
84
- data = [entry.value if isinstance(entry, Enum) else entry for entry in array]
102
+ def _write_0(cls, stream: BinaryIO, array: list[BaseType | int]) -> int:
103
+ data = [entry.value if isinstance(entry, _Enum) else entry for entry in array]
85
104
  return cls._write_array(stream, [*data, cls.type.__default__()])
86
105
 
87
106
 
@@ -180,7 +199,7 @@ class Enum(BaseType, IntEnum, metaclass=EnumMetaType):
180
199
  return hash((self.__class__, self.name, self.value))
181
200
 
182
201
  @classmethod
183
- def _missing_(cls, value: int) -> Enum:
202
+ def _missing_(cls, value: int) -> Self:
184
203
  # Emulate FlagBoundary.KEEP for enum (allow values other than the defined members)
185
204
  new_member = int.__new__(cls, value)
186
205
  new_member._name_ = None
@@ -1,10 +1,13 @@
1
1
  from __future__ import annotations
2
2
 
3
- from typing import Any, BinaryIO
3
+ from typing import TYPE_CHECKING, Any, BinaryIO
4
4
 
5
5
  from dissect.cstruct.types.base import BaseType
6
6
  from dissect.cstruct.utils import ENDIANNESS_MAP
7
7
 
8
+ if TYPE_CHECKING:
9
+ from typing_extensions import Self
10
+
8
11
 
9
12
  class Int(int, BaseType):
10
13
  """Integer type that can span an arbitrary amount of bytes."""
@@ -12,7 +15,7 @@ class Int(int, BaseType):
12
15
  signed: bool
13
16
 
14
17
  @classmethod
15
- def _read(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> Int:
18
+ def _read(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> Self:
16
19
  data = stream.read(cls.size)
17
20
 
18
21
  if len(data) != cls.size:
@@ -21,7 +24,7 @@ class Int(int, BaseType):
21
24
  return cls.from_bytes(data, ENDIANNESS_MAP[cls.cs.endian], signed=cls.signed)
22
25
 
23
26
  @classmethod
24
- def _read_0(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> Int:
27
+ def _read_0(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> Self:
25
28
  result = []
26
29
 
27
30
  while True: