mrd-python 2.0.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mrd/__init__.py +96 -0
- mrd/_binary.py +1339 -0
- mrd/_dtypes.py +89 -0
- mrd/_ndjson.py +1194 -0
- mrd/binary.py +680 -0
- mrd/ndjson.py +2716 -0
- mrd/protocols.py +180 -0
- mrd/tools/export_png_images.py +39 -0
- mrd/tools/phantom.py +161 -0
- mrd/tools/simulation.py +173 -0
- mrd/tools/stream_recon.py +184 -0
- mrd/tools/transform.py +37 -0
- mrd/types.py +1714 -0
- mrd/yardl_types.py +303 -0
- mrd_python-2.0.0rc1.dist-info/LICENSE +8 -0
- mrd_python-2.0.0rc1.dist-info/METADATA +28 -0
- mrd_python-2.0.0rc1.dist-info/RECORD +19 -0
- mrd_python-2.0.0rc1.dist-info/WHEEL +5 -0
- mrd_python-2.0.0rc1.dist-info/top_level.txt +1 -0
mrd/_binary.py
ADDED
|
@@ -0,0 +1,1339 @@
|
|
|
1
|
+
# Copyright (c) Microsoft Corporation.
|
|
2
|
+
# Licensed under the MIT License.
|
|
3
|
+
|
|
4
|
+
# pyright: reportUnnecessaryIsInstance=false
|
|
5
|
+
|
|
6
|
+
import datetime
|
|
7
|
+
from enum import Enum
|
|
8
|
+
from io import BufferedIOBase, BufferedReader, BytesIO
|
|
9
|
+
from typing import (
|
|
10
|
+
BinaryIO,
|
|
11
|
+
Iterable,
|
|
12
|
+
Protocol,
|
|
13
|
+
TypeVar,
|
|
14
|
+
Generic,
|
|
15
|
+
Any,
|
|
16
|
+
Optional,
|
|
17
|
+
Tuple,
|
|
18
|
+
cast,
|
|
19
|
+
)
|
|
20
|
+
from abc import ABC, abstractmethod
|
|
21
|
+
import struct
|
|
22
|
+
import sys
|
|
23
|
+
import numpy as np
|
|
24
|
+
|
|
25
|
+
from numpy.lib import recfunctions
|
|
26
|
+
import numpy.typing as npt
|
|
27
|
+
|
|
28
|
+
from .yardl_types import *
|
|
29
|
+
|
|
30
|
+
# This module refuses to load on anything but a little-endian host.
if sys.byteorder != "little":
    raise RuntimeError("Only little-endian systems are currently supported")

# Header values written by BinaryProtocolWriter and checked by BinaryProtocolReader.
MAGIC_BYTES: bytes = b"yardl"
CURRENT_BINARY_FORMAT_VERSION: int = 1

# Inclusive bounds of the fixed-width integer types, as plain Python ints.
INT8_MIN: int = -(1 << 7)
INT8_MAX: int = (1 << 7) - 1

UINT8_MAX: int = (1 << 8) - 1

INT16_MIN: int = -(1 << 15)
INT16_MAX: int = (1 << 15) - 1

UINT16_MAX: int = (1 << 16) - 1

INT32_MIN: int = -(1 << 31)
INT32_MAX: int = (1 << 31) - 1

UINT32_MAX: int = (1 << 32) - 1

INT64_MIN: int = -(1 << 63)
INT64_MAX: int = (1 << 63) - 1

UINT64_MAX: int = (1 << 64) - 1
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
class BinaryProtocolWriter(ABC):
    """Base class for binary protocol writers.

    Construction wraps the target in a ``CodedOutputStream`` and immediately
    emits the stream header.
    """

    def __init__(self, stream: Union[BinaryIO, str], schema: str) -> None:
        """Open the output and write the header.

        Args:
            stream: A writable binary stream, or a path that will be opened.
            schema: Schema string stored in the header.
        """
        out = CodedOutputStream(stream)
        self._stream = out
        # Header layout: magic bytes, fixed-width format version, schema string.
        out.write_bytes(MAGIC_BYTES)
        write_fixed_int32(out, CURRENT_BINARY_FORMAT_VERSION)
        string_serializer.write(out, schema)

    def _close(self) -> None:
        """Close the underlying coded output stream."""
        self._stream.close()

    def _end_stream(self) -> None:
        """Append a single zero byte to the output."""
        out = self._stream
        out.ensure_capacity(1)
        out.write_byte_no_check(0)
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
class BinaryProtocolReader(ABC):
    """Base class for binary protocol readers.

    Construction wraps the source in a ``CodedInputStream`` and validates the
    stream header (magic bytes, format version and, optionally, the schema).
    """

    def __init__(
        self,
        stream: Union[BufferedReader, BytesIO, BinaryIO, str],
        expected_schema: Optional[str],
    ) -> None:
        """Open the input and validate its header.

        Args:
            stream: A readable binary stream, or a path that will be opened.
            expected_schema: When truthy, the schema read from the header must
                equal this string.

        Raises:
            RuntimeError: If the magic bytes, the format version or the schema
                do not match.
        """
        self._stream = CodedInputStream(stream)
        try:
            magic_bytes = self._stream.read_view(len(MAGIC_BYTES))
            if magic_bytes != MAGIC_BYTES:  # pyright: ignore [reportUnnecessaryComparison]
                raise RuntimeError("Invalid magic bytes")

            version = read_fixed_int32(self._stream)
            if version != CURRENT_BINARY_FORMAT_VERSION:
                raise RuntimeError("Invalid binary format version")

            self._schema = string_serializer.read(self._stream)
            if expected_schema and self._schema != expected_schema:
                raise RuntimeError("Invalid schema")
        except BaseException:
            # The caller never receives the instance when __init__ raises, so
            # release a file we opened ourselves (close() is a no-op for
            # caller-supplied streams) before propagating the error.
            self._stream.close()
            raise

    def _close(self) -> None:
        """Close the underlying coded input stream."""
        self._stream.close()
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
class CodedOutputStream:
    """Buffered binary writer with struct, raw-byte and varint helpers.

    Output is staged in a fixed-size ``bytearray`` and pushed to the
    underlying stream by ``flush``. When constructed from a path, the file is
    opened here and closed by ``close``; a caller-supplied stream is only
    flushed, never closed.
    """

    def __init__(
        self, stream: Union[BinaryIO, str], *, buffer_size: int = 65536
    ) -> None:
        """Wrap ``stream`` (or open the file at that path for writing).

        Args:
            stream: A writable binary stream or a file path.
            buffer_size: Size in bytes of the staging buffer.
        """
        if isinstance(stream, str):
            self._stream = cast(BinaryIO, open(stream, "wb"))
            self._owns_stream = True
        else:
            self._stream = stream
            self._owns_stream = False

        self._buffer = bytearray(buffer_size)
        self._offset = 0

    def close(self) -> None:
        """Flush pending bytes and close the stream if this object opened it."""
        try:
            self.flush()
        finally:
            # Close even when the flush fails, otherwise the file handle we
            # opened in __init__ would leak.
            if self._owns_stream:
                self._stream.close()

    def ensure_capacity(self, size: int) -> None:
        """Flush the buffer if fewer than ``size`` bytes of room remain."""
        if (len(self._buffer) - self._offset) < size:
            self.flush()

    def flush(self) -> None:
        """Write buffered bytes to the underlying stream and flush it."""
        if self._offset > 0:
            self._stream.write(self._buffer[: self._offset])
            self._stream.flush()
            self._offset = 0

    def write(self, formatter: struct.Struct, *args: Any) -> None:
        """Pack ``args`` with ``formatter`` into the buffer, flushing first if needed."""
        size = formatter.size
        if (len(self._buffer) - self._offset) < size:
            self.flush()

        formatter.pack_into(self._buffer, self._offset, *args)
        self._offset += size

    def write_bytes(self, data: Union[bytes, bytearray]) -> None:
        """Write ``data``; payloads that do not fit in the remaining buffer
        space bypass the buffer after it has been flushed."""
        if len(data) > (len(self._buffer) - self._offset):
            self.flush()
            self._stream.write(data)
        else:
            self._buffer[self._offset : self._offset + len(data)] = data
            self._offset += len(data)

    def write_bytes_directly(self, data: Union[bytes, bytearray, memoryview]) -> None:
        """Flush the buffer, then write ``data`` straight to the underlying stream."""
        self.flush()
        self._stream.write(data)

    def write_byte_no_check(self, value: int) -> None:
        """Store one byte without checking capacity.

        The caller must have reserved room beforehand (e.g. via
        ``ensure_capacity``).
        """
        assert 0 <= value <= UINT8_MAX
        self._buffer[self._offset] = value
        self._offset += 1

    def write_unsigned_varint(
        self,
        value: Union[int, np.uint8, np.uint16, np.uint32, np.uint64],
    ) -> None:
        """Write ``value`` as a base-128 varint, least-significant group first."""
        # Reserve 10 bytes up front: enough for any value below 2**70, which
        # covers the full 64-bit range.
        if (len(self._buffer) - self._offset) < 10:
            self.flush()

        int_val = int(value)  # bitwise ops not supported on numpy types

        while True:
            if int_val < 0x80:
                self.write_byte_no_check(int_val)
                return

            # Low 7 bits with the continuation bit set.
            self.write_byte_no_check((int_val & 0x7F) | 0x80)
            int_val >>= 7

    def zigzag_encode(
        self,
        value: Union[int, np.int8, np.int16, np.int32, np.int64],
    ) -> int:
        """Return the zigzag encoding of ``value`` (0, -1, 1, -2 -> 0, 1, 2, 3).

        The shift by 63 assumes ``value`` fits in a signed 64-bit integer.
        """
        int_val = int(value)
        return (int_val << 1) ^ (int_val >> 63)

    def write_signed_varint(
        self,
        value: Union[int, np.int8, np.int16, np.int32, np.int64],
    ) -> None:
        """Write ``value`` as a zigzag-encoded varint."""
        self.write_unsigned_varint(self.zigzag_encode(value))
|
|
178
|
+
|
|
179
|
+
|
|
180
|
+
class CodedInputStream:
    """Buffered binary reader with struct, raw-byte and varint helpers.

    Data is pulled from the underlying stream into a fixed-size ``bytearray``
    and consumed from there. When constructed from a path, the file is opened
    here and closed by ``close``; a caller-supplied stream is never closed.
    """

    def __init__(
        self,
        stream: Union[BufferedReader, BytesIO, BinaryIO, str],
        *,
        buffer_size: int = 65536,
    ) -> None:
        """Wrap ``stream`` (or open the file at that path for reading).

        Args:
            stream: A readable binary stream or a file path. Streams that are
                not ``BufferedIOBase`` instances are wrapped in a
                ``BufferedReader``.
            buffer_size: Size in bytes of the internal read buffer.
        """
        if isinstance(stream, str):
            self._stream = open(stream, "rb")
            self._owns_stream = True
        else:
            if not isinstance(stream, BufferedIOBase):
                self._stream = BufferedReader(stream)  # type: ignore
            else:
                self._stream = stream
            self._owns_stream = False

        self._last_read_count = 0  # number of valid bytes currently in the buffer
        self._buffer = bytearray(buffer_size)
        self._view = memoryview(self._buffer)
        self._offset = 0  # index of the next unread byte
        self._at_end = False

    def close(self) -> None:
        """Close the underlying stream if this object opened it."""
        if self._owns_stream:
            self._stream.close()

    def read(self, formatter: struct.Struct) -> tuple[Any, ...]:
        """Unpack one ``formatter`` record from the stream.

        Raises:
            EOFError: If the stream ends before the record is complete.
        """
        if self._last_read_count - self._offset < formatter.size:
            self._fill_buffer(formatter.size)

        result = formatter.unpack_from(self._buffer, self._offset)
        self._offset += formatter.size
        return result

    def read_byte(self) -> int:
        """Read a single byte and return it as an int."""
        if self._last_read_count - self._offset < 1:
            self._fill_buffer(1)

        result = self._buffer[self._offset]
        self._offset += 1
        return result

    def read_unsigned_varint(self) -> int:
        """Read a base-128 varint (least-significant group first)."""
        result = 0
        shift = 0
        while True:
            if self._last_read_count - self._offset < 1:
                self._fill_buffer(1)

            byte = self._buffer[self._offset]
            self._offset += 1
            result |= (byte & 0x7F) << shift
            if byte < 0x80:  # continuation bit clear: last group
                return result
            shift += 7

    def zigzag_decode(self, value: int) -> int:
        """Invert zigzag encoding (0, 1, 2, 3 -> 0, -1, 1, -2)."""
        return (value >> 1) ^ -(value & 1)

    def read_signed_varint(self) -> int:
        """Read a zigzag-encoded varint."""
        return self.zigzag_decode(self.read_unsigned_varint())

    def read_view(self, count: int) -> memoryview:
        """Return a view of the next ``count`` bytes.

        For requests that fit in the internal buffer the view aliases that
        buffer, so its contents can be overwritten by a later refill.

        Raises:
            EOFError: If fewer than ``count`` bytes are available.
        """
        if count <= (self._last_read_count - self._offset):
            res = self._view[self._offset : self._offset + count]
            self._offset += count
            return res

        if count > len(self._buffer):
            return memoryview(self._read_oversized(count))

        self._fill_buffer(count)
        result = self._view[self._offset : self._offset + count]
        self._offset += count
        return result

    def read_bytearray(self, count: int) -> bytearray:
        """Return the next ``count`` bytes as a new ``bytearray``.

        Raises:
            EOFError: If fewer than ``count`` bytes are available.
        """
        if count <= (self._last_read_count - self._offset):
            res = bytearray(self._view[self._offset : self._offset + count])
            self._offset += count
            return res

        if count > len(self._buffer):
            return self._read_oversized(count)

        self._fill_buffer(count)
        result = self._view[self._offset : self._offset + count]
        self._offset += count
        return bytearray(result)

    def _read_oversized(self, count: int) -> bytearray:
        """Read ``count`` bytes, more than the internal buffer can hold, into a
        fresh bytearray: buffered leftovers first, the rest straight from the
        stream."""
        local_buf = bytearray(count)
        local_view = memoryview(local_buf)
        remaining = self._last_read_count - self._offset
        local_view[:remaining] = self._view[self._offset : self._last_read_count]
        self._offset = self._last_read_count
        if self._stream.readinto(local_view[remaining:]) < count - remaining:
            raise EOFError("Unexpected EOF")
        return local_buf

    def _fill_buffer(self, min_count: int = 0) -> None:
        """Move unread bytes to the front of the buffer and refill the rest.

        Args:
            min_count: When positive, the minimum number of bytes that must be
                available after the refill.

        Raises:
            EOFError: If fewer than ``min_count`` bytes are available.
        """
        remaining = self._last_read_count - self._offset
        if remaining > 0:
            # Copy exactly the unread bytes. Source and destination have the
            # same length, so the buffer is never resized (it cannot be while
            # memoryviews of it exist). Overlapping ranges are safe here.
            self._view[:remaining] = self._view[self._offset : self._last_read_count]

        tail = self._view[remaining:]
        self._last_read_count = self._stream.readinto(tail) + remaining
        self._offset = 0
        if self._last_read_count == 0:
            self._at_end = True
        if min_count > 0 and (self._last_read_count) < min_count:
            raise EOFError("Unexpected EOF")
|
|
300
|
+
|
|
301
|
+
|
|
302
|
+
T = TypeVar("T")
T_NP = TypeVar("T_NP", bound=np.generic)


class TypeSerializer(Generic[T, T_NP], ABC):
    """Abstract serializer for one value type.

    ``T`` is the Python-side value type and ``T_NP`` the NumPy-side one. Each
    serializer also carries the NumPy dtype it was constructed with.
    """

    def __init__(self, dtype: npt.DTypeLike) -> None:
        """Store ``dtype`` normalized to a ``numpy.dtype`` instance."""
        self._dtype: np.dtype[Any] = np.dtype(dtype)

    def overall_dtype(self) -> np.dtype[Any]:
        """Return the NumPy dtype given at construction."""
        return self._dtype

    def struct_format_str(self) -> Optional[str]:
        """Return a ``struct`` format string, or ``None`` (the default here)."""
        return None

    @abstractmethod
    def write(self, stream: CodedOutputStream, value: T) -> None:
        """Write a Python-side ``value`` to ``stream``."""
        raise NotImplementedError

    @abstractmethod
    def write_numpy(self, stream: CodedOutputStream, value: T_NP) -> None:
        """Write a NumPy-side ``value`` to ``stream``."""
        raise NotImplementedError

    @abstractmethod
    def read(self, stream: CodedInputStream) -> T:
        """Read one value from ``stream`` as the Python-side type."""
        raise NotImplementedError

    @abstractmethod
    def read_numpy(self, stream: CodedInputStream) -> T_NP:
        """Read one value from ``stream`` as the NumPy-side type."""
        raise NotImplementedError

    def is_trivially_serializable(self) -> bool:
        """Return ``False``; subclasses override this to opt in."""
        return False
|
|
334
|
+
|
|
335
|
+
|
|
336
|
+
class StructSerializer(TypeSerializer[T, T_NP]):
    """Serializer backed by a precompiled ``struct.Struct`` format."""

    def __init__(self, numpy_type: type, format_string: str) -> None:
        """Compile ``format_string`` and remember the NumPy scalar type.

        Args:
            numpy_type: NumPy type used as the dtype and to build the values
                returned by ``read_numpy``.
            format_string: ``struct`` format used for packing and unpacking.
        """
        super().__init__(numpy_type)
        self._numpy_type = numpy_type
        self._struct = struct.Struct(format_string)

    def write(self, stream: CodedOutputStream, value: T) -> None:
        """Pack ``value`` as the single struct field and write it."""
        stream.write(self._struct, value)

    def write_numpy(self, stream: CodedOutputStream, value: T_NP) -> None:
        """Pack the NumPy scalar ``value`` as the single struct field and write it."""
        stream.write(self._struct, value)

    def read(self, stream: CodedInputStream) -> T:
        """Unpack one record and return its first field."""
        fields = stream.read(self._struct)
        return cast(T, fields[0])

    def read_numpy(self, stream: CodedInputStream) -> T_NP:
        """Unpack one record and return its first field converted to the NumPy type."""
        fields = stream.read(self._struct)
        return cast(T_NP, self._numpy_type(fields[0]))

    def struct_format_str(self) -> str:
        """Return the format string of the compiled struct."""
        return self._struct.format
|
|
356
|
+
|
|
357
|
+
|
|
358
|
+
class BoolSerializer(StructSerializer[bool, np.bool_]):
    """Serializer for booleans using the ``"<?"`` struct format."""

    def __init__(self) -> None:
        super().__init__(np.bool_, "<?")

    def read(self, stream: CodedInputStream) -> bool:
        """Read one value as a Python ``bool``."""
        value = super().read(stream)
        return value

    def read_numpy(self, stream: CodedInputStream) -> np.bool_:
        """Read one value as a ``numpy.bool_``."""
        value = super().read_numpy(stream)
        return value


bool_serializer = BoolSerializer()
|
|
370
|
+
|
|
371
|
+
|
|
372
|
+
class Int8Serializer(StructSerializer[Int8, np.int8]):
    """Serializer for signed 8-bit integers using the ``"<b"`` struct format."""

    def __init__(self) -> None:
        super().__init__(np.int8, "<b")

    def read(self, stream: CodedInputStream) -> Int8:
        """Read one signed 8-bit value."""
        value = super().read(stream)
        return value

    def is_trivially_serializable(self) -> bool:
        """Return ``True`` for this type."""
        return True


int8_serializer = Int8Serializer()
|
|
384
|
+
|
|
385
|
+
|
|
386
|
+
class UInt8Serializer(StructSerializer[UInt8, np.uint8]):
    """Serializer for unsigned 8-bit integers using the ``"<B"`` struct format."""

    def __init__(self) -> None:
        super().__init__(np.uint8, "<B")

    def read(self, stream: CodedInputStream) -> UInt8:
        """Read one unsigned 8-bit value."""
        value = super().read(stream)
        return value

    def is_trivially_serializable(self) -> bool:
        """Return ``True`` for this type."""
        return True


uint8_serializer = UInt8Serializer()
|
|
398
|
+
|
|
399
|
+
|
|
400
|
+
class Int16Serializer(TypeSerializer[Int16, np.int16]):
    """Serializer for signed 16-bit integers, stored as zigzag varints."""

    def __init__(self) -> None:
        super().__init__(np.int16)

    def write(self, stream: CodedOutputStream, value: Int16) -> None:
        """Validate ``value`` and write it as a signed varint.

        Raises:
            ValueError: If a Python int is out of range, or the value is
                neither an ``int`` nor a ``numpy.int16``.
        """
        if not isinstance(value, int):
            if not isinstance(value, cast(type, np.int16)):
                raise ValueError(f"Value is not a signed 16-bit integer: {value}")
        elif value < INT16_MIN or INT16_MAX < value:
            raise ValueError(
                f"Value {value} is outside the range of a signed 16-bit integer"
            )

        stream.write_signed_varint(value)

    def write_numpy(self, stream: CodedOutputStream, value: np.int16) -> None:
        """Write ``value`` as a signed varint without validation."""
        stream.write_signed_varint(value)

    def read(self, stream: CodedInputStream) -> Int16:
        """Read a signed varint as a Python int."""
        return stream.read_signed_varint()

    def read_numpy(self, stream: CodedInputStream) -> np.int16:
        """Read a signed varint and convert it to ``numpy.int16``."""
        decoded = stream.read_signed_varint()
        return np.int16(decoded)


int16_serializer = Int16Serializer()
|
|
426
|
+
|
|
427
|
+
|
|
428
|
+
class UInt16Serializer(TypeSerializer[UInt16, np.uint16]):
    """Serializer for unsigned 16-bit integers, stored as varints."""

    def __init__(self) -> None:
        super().__init__(np.uint16)

    def write(self, stream: CodedOutputStream, value: UInt16) -> None:
        """Validate ``value`` and write it as an unsigned varint.

        Raises:
            ValueError: If a Python int is out of range, or the value is
                neither an ``int`` nor a ``numpy.uint16``.
        """
        if not isinstance(value, int):
            if not isinstance(value, cast(type, np.uint16)):
                raise ValueError(f"Value is not an unsigned 16-bit integer: {value}")
        elif value < 0 or UINT16_MAX < value:
            raise ValueError(
                f"Value {value} is outside the range of an unsigned 16-bit integer"
            )

        stream.write_unsigned_varint(value)

    def write_numpy(self, stream: CodedOutputStream, value: np.uint16) -> None:
        """Write ``value`` as an unsigned varint without validation."""
        stream.write_unsigned_varint(value)

    def read(self, stream: CodedInputStream) -> UInt16:
        """Read an unsigned varint as a Python int."""
        return stream.read_unsigned_varint()

    def read_numpy(self, stream: CodedInputStream) -> np.uint16:
        """Read an unsigned varint and convert it to ``numpy.uint16``."""
        decoded = stream.read_unsigned_varint()
        return np.uint16(decoded)


uint16_serializer = UInt16Serializer()
|
|
454
|
+
|
|
455
|
+
|
|
456
|
+
class Int32Serializer(TypeSerializer[Int32, np.int32]):
    """Serializer for signed 32-bit integers, stored as zigzag varints."""

    def __init__(self) -> None:
        super().__init__(np.int32)

    def write(self, stream: CodedOutputStream, value: Int32) -> None:
        """Validate ``value`` and write it as a signed varint.

        Raises:
            ValueError: If a Python int is out of range, or the value is
                neither an ``int`` nor a ``numpy.int32``.
        """
        if not isinstance(value, int):
            if not isinstance(value, cast(type, np.int32)):
                raise ValueError(f"Value is not a signed 32-bit integer: {value}")
        elif value < INT32_MIN or INT32_MAX < value:
            raise ValueError(
                f"Value {value} is outside the range of a signed 32-bit integer"
            )

        stream.write_signed_varint(value)

    def write_numpy(self, stream: CodedOutputStream, value: np.int32) -> None:
        """Write ``value`` as a signed varint without validation."""
        stream.write_signed_varint(value)

    def read(self, stream: CodedInputStream) -> Int32:
        """Read a signed varint as a Python int."""
        return stream.read_signed_varint()

    def read_numpy(self, stream: CodedInputStream) -> np.int32:
        """Read a signed varint and convert it to ``numpy.int32``."""
        decoded = stream.read_signed_varint()
        return np.int32(decoded)


int32_serializer = Int32Serializer()
|
|
482
|
+
|
|
483
|
+
|
|
484
|
+
class UInt32Serializer(TypeSerializer[UInt32, np.uint32]):
    """Serializer for unsigned 32-bit integers, stored as varints."""

    def __init__(self) -> None:
        super().__init__(np.uint32)

    def write(self, stream: CodedOutputStream, value: UInt32) -> None:
        """Validate ``value`` and write it as an unsigned varint.

        Raises:
            ValueError: If a Python int is out of range, or the value is
                neither an ``int`` nor a ``numpy.uint32``.
        """
        if not isinstance(value, int):
            if not isinstance(value, cast(type, np.uint32)):
                raise ValueError(f"Value is not an unsigned 32-bit integer: {value}")
        elif value < 0 or UINT32_MAX < value:
            raise ValueError(
                f"Value {value} is outside the range of an unsigned 32-bit integer"
            )

        stream.write_unsigned_varint(value)

    def write_numpy(self, stream: CodedOutputStream, value: np.uint32) -> None:
        """Write ``value`` as an unsigned varint without validation."""
        stream.write_unsigned_varint(value)

    def read(self, stream: CodedInputStream) -> UInt32:
        """Read an unsigned varint as a Python int."""
        return stream.read_unsigned_varint()

    def read_numpy(self, stream: CodedInputStream) -> np.uint32:
        """Read an unsigned varint and convert it to ``numpy.uint32``."""
        decoded = stream.read_unsigned_varint()
        return np.uint32(decoded)


uint32_serializer = UInt32Serializer()
|
|
510
|
+
|
|
511
|
+
|
|
512
|
+
class Int64Serializer(TypeSerializer[Int64, np.int64]):
    """Serializer for signed 64-bit integers, stored as zigzag varints."""

    def __init__(self) -> None:
        super().__init__(np.int64)

    def write(self, stream: CodedOutputStream, value: Int64) -> None:
        """Validate ``value`` and write it as a signed varint.

        Raises:
            ValueError: If a Python int is out of range, or the value is
                neither an ``int`` nor a ``numpy.int64``.
        """
        if not isinstance(value, int):
            if not isinstance(value, cast(type, np.int64)):
                raise ValueError(f"Value is not a signed 64-bit integer: {value}")
        elif value < INT64_MIN or INT64_MAX < value:
            raise ValueError(
                f"Value {value} is outside the range of a signed 64-bit integer"
            )

        stream.write_signed_varint(value)

    def write_numpy(self, stream: CodedOutputStream, value: np.int64) -> None:
        """Write ``value`` as a signed varint without validation."""
        stream.write_signed_varint(value)

    def read(self, stream: CodedInputStream) -> Int64:
        """Read a signed varint as a Python int."""
        return stream.read_signed_varint()

    def read_numpy(self, stream: CodedInputStream) -> np.int64:
        """Read a signed varint and convert it to ``numpy.int64``."""
        decoded = stream.read_signed_varint()
        return np.int64(decoded)


int64_serializer = Int64Serializer()
|
|
538
|
+
|
|
539
|
+
|
|
540
|
+
class UInt64Serializer(TypeSerializer[UInt64, np.uint64]):
    """Serializer for unsigned 64-bit integers, stored as varints."""

    def __init__(self) -> None:
        super().__init__(np.uint64)

    def write(self, stream: CodedOutputStream, value: UInt64) -> None:
        """Validate ``value`` and write it as an unsigned varint.

        Raises:
            ValueError: If a Python int is out of range, or the value is
                neither an ``int`` nor a ``numpy.uint64``.
        """
        if not isinstance(value, int):
            if not isinstance(value, cast(type, np.uint64)):
                raise ValueError(f"Value is not an unsigned 64-bit integer: {value}")
        elif value < 0 or UINT64_MAX < value:
            raise ValueError(
                f"Value {value} is outside the range of an unsigned 64-bit integer"
            )

        stream.write_unsigned_varint(value)

    def write_numpy(self, stream: CodedOutputStream, value: np.uint64) -> None:
        """Write ``value`` as an unsigned varint without validation."""
        stream.write_unsigned_varint(value)

    def read(self, stream: CodedInputStream) -> UInt64:
        """Read an unsigned varint as a Python int."""
        return stream.read_unsigned_varint()

    def read_numpy(self, stream: CodedInputStream) -> np.uint64:
        """Read an unsigned varint and convert it to ``numpy.uint64``."""
        decoded = stream.read_unsigned_varint()
        return np.uint64(decoded)


uint64_serializer = UInt64Serializer()
|
|
566
|
+
|
|
567
|
+
|
|
568
|
+
class SizeSerializer(TypeSerializer[Size, np.uint64]):
    """Serializer for ``Size`` values, stored as unsigned varints with the
    same range check as unsigned 64-bit integers."""

    def __init__(self) -> None:
        super().__init__(np.uint64)

    def write(self, stream: CodedOutputStream, value: Size) -> None:
        """Validate ``value`` and write it as an unsigned varint.

        Raises:
            ValueError: If a Python int is out of range, or the value is
                neither an ``int`` nor a ``numpy.uint64``.
        """
        if not isinstance(value, int):
            if not isinstance(value, cast(type, np.uint64)):
                raise ValueError(f"Value is not an unsigned 64-bit integer: {value}")
        elif value < 0 or UINT64_MAX < value:
            raise ValueError(
                f"Value {value} is outside the range of an unsigned 64-bit integer"
            )

        stream.write_unsigned_varint(value)

    def write_numpy(self, stream: CodedOutputStream, value: np.uint64) -> None:
        """Write ``value`` as an unsigned varint without validation."""
        stream.write_unsigned_varint(value)

    def read(self, stream: CodedInputStream) -> Size:
        """Read an unsigned varint as a Python int."""
        return stream.read_unsigned_varint()

    def read_numpy(self, stream: CodedInputStream) -> np.uint64:
        """Read an unsigned varint and convert it to ``numpy.uint64``."""
        decoded = stream.read_unsigned_varint()
        return np.uint64(decoded)


size_serializer = SizeSerializer()
|
|
594
|
+
|
|
595
|
+
|
|
596
|
+
class Float32Serializer(StructSerializer[Float32, np.float32]):
    """Serializer for 32-bit floats using the ``"<f"`` struct format."""

    def __init__(self) -> None:
        super().__init__(np.float32, "<f")

    def read(self, stream: CodedInputStream) -> Float32:
        """Read one 32-bit float."""
        value = super().read(stream)
        return value

    def is_trivially_serializable(self) -> bool:
        """Return ``True`` for this type."""
        return True


float32_serializer = Float32Serializer()
|
|
608
|
+
|
|
609
|
+
|
|
610
|
+
class Float64Serializer(StructSerializer[Float64, np.float64]):
    """Serializer for 64-bit floats using the ``"<d"`` struct format."""

    def __init__(self) -> None:
        super().__init__(np.float64, "<d")

    def read(self, stream: CodedInputStream) -> Float64:
        """Read one 64-bit float."""
        value = super().read(stream)
        return value

    def is_trivially_serializable(self) -> bool:
        """Return ``True`` for this type."""
        return True


float64_serializer = Float64Serializer()
|
|
622
|
+
|
|
623
|
+
|
|
624
|
+
class Complex32Serializer(StructSerializer[ComplexFloat, np.complex64]):
    """Serializer for complex numbers made of two 32-bit floats (``"<ff"``),
    real part first."""

    def __init__(self) -> None:
        super().__init__(np.complex64, "<ff")

    def write(self, stream: CodedOutputStream, value: ComplexFloat) -> None:
        """Write the real and imaginary parts of ``value``."""
        stream.write(self._struct, value.real, value.imag)

    def write_numpy(self, stream: CodedOutputStream, value: np.complex64) -> None:
        """Write the real and imaginary parts of the NumPy scalar ``value``.

        The inherited implementation passes ``value`` as a single field, which
        cannot be packed with the two-field format used here.
        """
        stream.write(self._struct, value.real, value.imag)

    def read(self, stream: CodedInputStream) -> ComplexFloat:
        """Read two floats and build a ``ComplexFloat`` from them."""
        return ComplexFloat(*stream.read(self._struct))

    def read_numpy(self, stream: CodedInputStream) -> np.complex64:
        """Read two floats and build a ``numpy.complex64`` from them."""
        real, imag = stream.read(self._struct)
        return np.complex64(complex(real, imag))

    def is_trivially_serializable(self) -> bool:
        """Return ``True`` for this type."""
        return True


complexfloat32_serializer = Complex32Serializer()
|
|
643
|
+
|
|
644
|
+
|
|
645
|
+
class Complex64Serializer(StructSerializer[ComplexDouble, np.complex128]):
    """Serializer for complex numbers made of two 64-bit floats (``"<dd"``),
    real part first."""

    def __init__(self) -> None:
        super().__init__(np.complex128, "<dd")

    def write(self, stream: CodedOutputStream, value: ComplexDouble) -> None:
        """Write the real and imaginary parts of ``value``."""
        stream.write(self._struct, value.real, value.imag)

    def write_numpy(self, stream: CodedOutputStream, value: np.complex128) -> None:
        """Write the real and imaginary parts of the NumPy scalar ``value``.

        The inherited implementation passes ``value`` as a single field, which
        cannot be packed with the two-field format used here.
        """
        stream.write(self._struct, value.real, value.imag)

    def read(self, stream: CodedInputStream) -> ComplexDouble:
        """Read two doubles and build a ``ComplexDouble`` from them."""
        return ComplexDouble(*stream.read(self._struct))

    def read_numpy(self, stream: CodedInputStream) -> np.complex128:
        """Read two doubles and build a ``numpy.complex128`` from them."""
        real, imag = stream.read(self._struct)
        return np.complex128(complex(real, imag))

    def is_trivially_serializable(self) -> bool:
        """Return ``True`` for this type."""
        return True


complexfloat64_serializer = Complex64Serializer()
|
|
664
|
+
|
|
665
|
+
|
|
666
|
+
class StringSerializer(TypeSerializer[str, np.object_]):
    """Serializer for strings: a varint byte length followed by UTF-8 bytes."""

    def __init__(self) -> None:
        super().__init__(np.object_)

    def write(self, stream: CodedOutputStream, value: str) -> None:
        """Encode ``value`` as UTF-8 and write it length-prefixed."""
        encoded = value.encode("utf-8")
        stream.write_unsigned_varint(len(encoded))
        stream.write_bytes(encoded)

    def write_numpy(self, stream: CodedOutputStream, value: np.object_) -> None:
        """Write ``value`` through ``write``, treating it as a ``str``."""
        self.write(stream, cast(str, value))

    def read(self, stream: CodedInputStream) -> str:
        """Read the length prefix, then decode that many bytes as UTF-8."""
        byte_count = stream.read_unsigned_varint()
        payload = stream.read_view(byte_count)
        return str(payload, "utf-8")

    def read_numpy(self, stream: CodedInputStream) -> np.object_:
        """Read a string and pass it through ``numpy.object_``."""
        text = self.read(stream)
        return np.object_(text)


string_serializer = StringSerializer()
|
|
688
|
+
|
|
689
|
+
# Proleptic Gregorian ordinal of 1970-01-01; dates are stored relative to it.
EPOCH_ORDINAL_DAYS = datetime.date(1970, 1, 1).toordinal()
DATETIME_DAYS_DTYPE = np.dtype("datetime64[D]")


class DateSerializer(TypeSerializer[datetime.date, np.datetime64]):
    """Serializer for dates, stored as a signed varint of days since 1970-01-01."""

    def __init__(self) -> None:
        super().__init__(DATETIME_DAYS_DTYPE)

    def write(self, stream: CodedOutputStream, value: datetime.date) -> None:
        """Write a ``datetime.date`` or ``numpy.datetime64`` value.

        Raises:
            ValueError: If ``value`` is neither of the accepted types.
        """
        if isinstance(value, datetime.date):
            stream.write_signed_varint(value.toordinal() - EPOCH_ORDINAL_DAYS)
            return

        if not isinstance(value, np.datetime64):
            raise ValueError(
                f"Expected datetime.date or numpy.datetime64, got {type(value)}"
            )

        self.write_numpy(stream, value)

    def write_numpy(self, stream: CodedOutputStream, value: np.datetime64) -> None:
        """Write ``value`` as a day count, converting to day resolution first
        when it uses another unit."""
        in_days = (
            value
            if value.dtype == DATETIME_DAYS_DTYPE
            else value.astype(DATETIME_DAYS_DTYPE)
        )
        stream.write_signed_varint(in_days.astype(np.int32))

    def read(self, stream: CodedInputStream) -> datetime.date:
        """Read a day count and return the corresponding ``datetime.date``."""
        day_offset = stream.read_signed_varint()
        return datetime.date.fromordinal(day_offset + EPOCH_ORDINAL_DAYS)

    def read_numpy(self, stream: CodedInputStream) -> np.datetime64:
        """Read a day count and return it as a day-resolution ``numpy.datetime64``."""
        day_offset = stream.read_signed_varint()
        return np.datetime64(day_offset, "D")


date_serializer = DateSerializer()
|
|
726
|
+
|
|
727
|
+
TIMEDELTA_NANOSECONDS_DTYPE = np.dtype("timedelta64[ns]")


class TimeSerializer(TypeSerializer[Time, np.timedelta64]):
    """Serializer for times of day, stored as a signed varint of nanoseconds."""

    def __init__(self) -> None:
        super().__init__(TIMEDELTA_NANOSECONDS_DTYPE)

    def write(self, stream: CodedOutputStream, value: Time) -> None:
        """Write a ``Time``, ``datetime.time`` or ``numpy.timedelta64`` value.

        Raises:
            ValueError: If ``value`` is none of the accepted types.
        """
        if isinstance(value, Time):
            self.write_numpy(stream, value.numpy_value)
        elif isinstance(value, datetime.time):
            self.write_numpy(stream, Time.from_time(value).numpy_value)
        else:
            if not isinstance(value, np.timedelta64):
                raise ValueError(
                    f"Expected a Time, datetime.time or np.timedelta64, got {type(value)}"
                )

            self.write_numpy(stream, value)

    def write_numpy(self, stream: CodedOutputStream, value: np.timedelta64) -> None:
        """Write ``value`` as a nanosecond count, converting to nanosecond
        resolution first when it uses another unit."""
        if value.dtype == TIMEDELTA_NANOSECONDS_DTYPE:
            stream.write_signed_varint(value.astype(np.int64))
        else:
            # Convert within the timedelta64 family; the value is a duration,
            # not a point in time, so a datetime64 dtype is the wrong target.
            stream.write_signed_varint(
                value.astype(TIMEDELTA_NANOSECONDS_DTYPE).astype(np.int64)
            )

    def read(self, stream: CodedInputStream) -> Time:
        """Read a nanosecond count and wrap it in a ``Time``."""
        nanoseconds_since_midnight = stream.read_signed_varint()
        return Time(nanoseconds_since_midnight)

    def read_numpy(self, stream: CodedInputStream) -> np.timedelta64:
        """Read a nanosecond count as a nanosecond-resolution ``numpy.timedelta64``."""
        nanoseconds_since_midnight = stream.read_signed_varint()
        return np.timedelta64(nanoseconds_since_midnight, "ns")


time_serializer = TimeSerializer()
|
|
765
|
+
|
|
766
|
+
DATETIME_NANOSECONDS_DTYPE = np.dtype("datetime64[ns]")
|
|
767
|
+
EPOCH_DATETIME = datetime.datetime.utcfromtimestamp(0)
|
|
768
|
+
|
|
769
|
+
|
|
770
|
+
class DateTimeSerializer(TypeSerializer[DateTime, np.datetime64]):
    """Serializes date-times as a signed varint of nanoseconds since the epoch."""

    def __init__(self) -> None:
        super().__init__(DATETIME_NANOSECONDS_DTYPE)

    def write(self, stream: CodedOutputStream, value: DateTime) -> None:
        """Write a `DateTime`, `datetime.datetime` or `np.datetime64`.

        Raises:
            ValueError: If `value` is none of the accepted types.
        """
        if isinstance(value, DateTime):
            numpy_value = value.numpy_value
        elif isinstance(value, datetime.datetime):
            numpy_value = DateTime.from_datetime(value).numpy_value
        elif isinstance(value, np.datetime64):
            numpy_value = value
        else:
            raise ValueError(
                f"Expected datetime.datetime or numpy.datetime64, got {type(value)}"
            )

        self.write_numpy(stream, numpy_value)

    def write_numpy(self, stream: CodedOutputStream, value: np.datetime64) -> None:
        """Write a numpy datetime, converting it to nanosecond resolution if needed."""
        in_nanoseconds = (
            value
            if value.dtype == DATETIME_NANOSECONDS_DTYPE
            else value.astype(DATETIME_NANOSECONDS_DTYPE)
        )
        stream.write_signed_varint(in_nanoseconds.astype(np.int64))

    def read(self, stream: CodedInputStream) -> DateTime:
        """Read a signed varint and wrap it in a `DateTime`."""
        return DateTime(stream.read_signed_varint())

    def read_numpy(self, stream: CodedInputStream) -> np.datetime64:
        """Read a signed varint and return it as a nanosecond `np.datetime64`."""
        return np.datetime64(stream.read_signed_varint(), "ns")
|
|
802
|
+
|
|
803
|
+
|
|
804
|
+
# Module-level DateTimeSerializer instance.
datetime_serializer = DateTimeSerializer()
|
|
805
|
+
|
|
806
|
+
|
|
807
|
+
class NoneSerializer(TypeSerializer[None, Any]):
    """Serializer for `None`: writes nothing and reads nothing from the stream."""

    def __init__(self) -> None:
        super().__init__(np.object_)

    def write(self, stream: CodedOutputStream, value: None) -> None:
        # Intentionally a no-op: nothing is emitted for None.
        return None

    def write_numpy(self, stream: CodedOutputStream, value: Any) -> None:
        # Intentionally a no-op, mirroring write().
        return None

    def read(self, stream: CodedInputStream) -> None:
        # The stream is not touched.
        return None

    def read_numpy(self, stream: CodedInputStream) -> Any:
        # The stream is not touched; the result is whatever np.object_() yields.
        return np.object_()
|
|
822
|
+
|
|
823
|
+
|
|
824
|
+
# Module-level NoneSerializer instance.
none_serializer = NoneSerializer()
|
|
825
|
+
|
|
826
|
+
# Type variable for the Enum subclass handled by EnumSerializer.
TEnum = TypeVar("TEnum", bound=Enum)
|
|
827
|
+
|
|
828
|
+
|
|
829
|
+
class EnumSerializer(Generic[TEnum, T, T_NP], TypeSerializer[TEnum, T_NP]):
    """Serializes enum members by delegating to a serializer for their integer value."""

    def __init__(
        self, integer_serializer: TypeSerializer[T, T_NP], enum_type: type
    ) -> None:
        # The enum shares the dtype of its underlying integer representation.
        super().__init__(integer_serializer.overall_dtype())
        self._integer_serializer = integer_serializer
        self._enum_type = enum_type

    def write(self, stream: CodedOutputStream, value: TEnum) -> None:
        """Write the member's `.value` with the integer serializer."""
        underlying = value.value
        self._integer_serializer.write(stream, underlying)

    def write_numpy(self, stream: CodedOutputStream, value: T_NP) -> None:
        """Write an already-numeric numpy value with the integer serializer."""
        return self._integer_serializer.write_numpy(stream, value)

    def read(self, stream: CodedInputStream) -> TEnum:
        """Read the integer value and convert it to a member of the enum type."""
        return self._enum_type(self._integer_serializer.read(stream))

    def read_numpy(self, stream: CodedInputStream) -> T_NP:
        """Read the raw numpy integer value without converting it to an enum."""
        return self._integer_serializer.read_numpy(stream)

    def is_trivially_serializable(self) -> bool:
        """Return whatever the underlying integer serializer reports."""
        return self._integer_serializer.is_trivially_serializable()
|
|
852
|
+
|
|
853
|
+
|
|
854
|
+
class OptionalSerializer(Generic[T, T_NP], TypeSerializer[Optional[T], np.void]):
    """Serializes optional values as a one-byte presence flag followed by the value.

    The numpy representation is a structured scalar with the fields
    ``has_value`` (bool) and ``value`` (the element dtype).
    """

    def __init__(self, element_serializer: TypeSerializer[T, T_NP]) -> None:
        optional_dtype = np.dtype(
            [("has_value", np.bool_), ("value", element_serializer.overall_dtype())]
        )
        super().__init__(optional_dtype)
        self._element_serializer = element_serializer
        # Zero-initialized structured scalar, returned when no value is present.
        self._none = cast(np.void, np.zeros((), dtype=self.overall_dtype())[()])

    def write(self, stream: CodedOutputStream, value: Optional[T]) -> None:
        """Write 0 for `None`, otherwise 1 followed by the serialized value."""
        stream.ensure_capacity(1)
        if value is not None:
            stream.write_byte_no_check(1)
            self._element_serializer.write(stream, value)
            return

        stream.write_byte_no_check(0)

    def write_numpy(self, stream: CodedOutputStream, value: np.void) -> None:
        """Write a structured scalar according to its ``has_value`` field."""
        stream.ensure_capacity(1)
        if value["has_value"]:
            stream.write_byte_no_check(1)
            self._element_serializer.write_numpy(stream, value["value"])
            return

        stream.write_byte_no_check(0)

    def read(self, stream: CodedInputStream) -> Optional[T]:
        """Read the presence flag; return `None` when it is 0, else the value."""
        if stream.read_byte() == 0:
            return None
        return self._element_serializer.read(stream)

    def read_numpy(self, stream: CodedInputStream) -> np.void:
        """Read the presence flag; return the shared empty scalar or a (True, value) pair."""
        if stream.read_byte() == 0:
            return self._none
        return cast(np.void, (True, self._element_serializer.read_numpy(stream)))

    def is_trivially_serializable(self) -> bool:
        # Explicitly defer to the base-class answer.
        return super().is_trivially_serializable()
|
|
896
|
+
|
|
897
|
+
|
|
898
|
+
class UnionCaseProtocol(Protocol):
    """Structural type for a union case as consumed by `UnionSerializer.write`."""

    # Position of the case; UnionSerializer adds its offset to obtain the tag.
    index: int
    # Payload handed to the serializer registered for this case.
    value: Any
|
|
901
|
+
|
|
902
|
+
|
|
903
|
+
class UnionSerializer(TypeSerializer[T, np.object_]):
    """Serializes union values as a one-byte case tag followed by the case payload.

    `cases` lists, per tag, the wrapper type and the serializer for that case.
    A leading `None` entry means the union also admits `None`, encoded as tag 0;
    in that case the tags of all real cases are shifted by one.
    """

    def __init__(
        self,
        union_type: type,
        cases: list[Optional[tuple[type, TypeSerializer[Any, Any]]]],
    ) -> None:
        super().__init__(np.object_)
        self._union_type = union_type
        self._cases = cases
        # 1 when tag 0 is reserved for None, otherwise 0.
        self._offset = 1 if cases[0] is None else 0

    def write(self, stream: CodedOutputStream, value: T) -> None:
        """Write the tag byte and the payload of `value`.

        Raises:
            ValueError: If `value` is `None` but the union does not admit
                `None`, or if `value` is not an instance of the union type.
        """
        if value is None:
            if self._cases[0] is None:
                # Reserve room before the unchecked write, as every other
                # write_byte_no_check call site does.
                stream.ensure_capacity(1)
                stream.write_byte_no_check(0)
                return
            raise ValueError("None is not a valid value for this union type")

        if not isinstance(value, self._union_type):
            raise ValueError(
                f"Expected union value of type {self._union_type} but got {type(value)}"
            )

        union_value = cast(UnionCaseProtocol, value)

        tag_index = union_value.index + self._offset
        stream.ensure_capacity(1)
        stream.write_byte_no_check(tag_index)
        type_case = self._cases[tag_index]
        assert type_case is not None
        type_case[1].write(stream, union_value.value)

    def write_numpy(self, stream: CodedOutputStream, value: np.object_) -> None:
        """Write a union stored in an object array; same encoding as `write`."""
        self.write(stream, cast(T, value))

    def read(self, stream: CodedInputStream) -> T:
        """Read the tag byte, then the payload, and wrap it in the case type."""
        case_index = stream.read_byte()
        if case_index == 0 and self._offset == 1:
            return None  # type: ignore
        case_type, case_serializer = self._cases[case_index]  # type: ignore
        return case_type(case_serializer.read(stream))  # type: ignore

    def read_numpy(self, stream: CodedInputStream) -> np.object_:
        """Read a union for storage in an object array; same decoding as `read`."""
        return self.read(stream)  # type: ignore
|
|
948
|
+
|
|
949
|
+
|
|
950
|
+
class StreamSerializer(TypeSerializer[Iterable[T], Any]):
    """Serializes a stream of elements as length-prefixed blocks.

    Each block is an unsigned varint element count followed by that many
    elements; a count of 0 ends the stream.
    """

    def __init__(self, element_serializer: TypeSerializer[T, T_NP]) -> None:
        super().__init__(np.object_)
        self._element_serializer = element_serializer

    def write(self, stream: CodedOutputStream, value: Iterable[T]) -> None:
        """Write the elements of `value` as one or more blocks.

        A non-empty list is written as a single block. Any other iterable is
        written as a sequence of one-element blocks.
        """
        # Note that the final 0 is missing and will be added before the next protocol step
        # or the protocol is closed.
        if isinstance(value, list) and len(value) > 0:
            stream.write_unsigned_varint(len(value))
            for element in value:
                self._element_serializer.write(stream, element)
        else:
            for element in value:
                # Reserve room before the unchecked write, as the other
                # write_byte_no_check call sites in this module do.
                stream.ensure_capacity(1)
                stream.write_byte_no_check(1)
                self._element_serializer.write(stream, element)

    def write_numpy(self, stream: CodedOutputStream, value: Any) -> None:
        raise NotImplementedError()

    def read(self, stream: CodedInputStream) -> Iterable[T]:
        """Lazily yield elements block by block until a block count of 0 is read."""
        while (i := stream.read_unsigned_varint()) > 0:
            for _ in range(i):
                yield self._element_serializer.read(stream)

    def read_numpy(self, stream: CodedInputStream) -> np.object_:
        raise NotImplementedError()
|
|
977
|
+
|
|
978
|
+
|
|
979
|
+
class FixedVectorSerializer(Generic[T, T_NP], TypeSerializer[list[T], np.object_]):
    """Serializes lists of a fixed length; no length prefix is written."""

    def __init__(
        self, element_serializer: TypeSerializer[T, T_NP], length: int
    ) -> None:
        # The numpy representation is a subarray dtype of `length` elements.
        subarray_dtype = np.dtype((element_serializer.overall_dtype(), length))
        super().__init__(subarray_dtype)
        self.element_serializer = element_serializer
        self._length = length

    def write(self, stream: CodedOutputStream, value: list[T]) -> None:
        """Write every element in order.

        Raises:
            ValueError: If `value` does not have exactly the configured length.
        """
        if len(value) != self._length:
            raise ValueError(
                f"Expected a list of length {self._length}, got {len(value)}"
            )

        write_element = self.element_serializer.write
        for item in value:
            write_element(stream, item)

    def write_numpy(self, stream: CodedOutputStream, value: np.object_) -> None:
        raise NotImplementedError("Internal error: expected this to be a subarray")

    def read(self, stream: CodedInputStream) -> list[T]:
        """Read exactly the configured number of elements."""
        items: list[T] = []
        for _ in range(self._length):
            items.append(self.element_serializer.read(stream))
        return items

    def read_numpy(self, stream: CodedInputStream) -> np.object_:
        raise NotImplementedError("Internal error: expected this to be a subarray")

    def is_trivially_serializable(self) -> bool:
        """Return whatever the element serializer reports."""
        return self.element_serializer.is_trivially_serializable()
|
|
1006
|
+
|
|
1007
|
+
|
|
1008
|
+
class VectorSerializer(Generic[T, T_NP], TypeSerializer[list[T], np.object_]):
    """Serializes variable-length lists as an unsigned varint count plus the elements."""

    def __init__(self, element_serializer: TypeSerializer[T, T_NP]) -> None:
        super().__init__(np.object_)
        self._element_serializer = element_serializer

    def write(self, stream: CodedOutputStream, value: list[T]) -> None:
        """Write the length of `value` followed by each element."""
        stream.write_unsigned_varint(len(value))
        write_element = self._element_serializer.write
        for item in value:
            write_element(stream, item)

    def write_numpy(self, stream: CodedOutputStream, value: np.object_) -> None:
        """Write a list stored in an object array.

        Raises:
            ValueError: If `value` is not a list.
        """
        if not isinstance(value, list):
            raise ValueError(f"Expected a list, got {type(value)}")

        items = cast(list[T], value)
        stream.write_unsigned_varint(len(value))
        write_element = self._element_serializer.write
        for item in items:
            write_element(stream, item)

    def read(self, stream: CodedInputStream) -> list[T]:
        """Read the count and then that many elements."""
        count = stream.read_unsigned_varint()
        items: list[T] = []
        for _ in range(count):
            items.append(self._element_serializer.read(stream))
        return items

    def read_numpy(self, stream: CodedInputStream) -> np.object_:
        """Read a list and pass it through `np.object_` for the numpy representation."""
        return np.object_(self.read(stream))
|
|
1032
|
+
|
|
1033
|
+
|
|
1034
|
+
# Type variables for MapSerializer: the Python key/value types and their
# numpy scalar counterparts.
TKey = TypeVar("TKey")
TKey_NP = TypeVar("TKey_NP", bound=np.generic)
TValue = TypeVar("TValue")
TValue_NP = TypeVar("TValue_NP", bound=np.generic)
|
|
1038
|
+
|
|
1039
|
+
|
|
1040
|
+
class MapSerializer(
    Generic[TKey, TKey_NP, TValue, TValue_NP],
    TypeSerializer[dict[TKey, TValue], np.object_],
):
    """Serializes dicts as an unsigned varint entry count plus key/value pairs."""

    def __init__(
        self,
        key_serializer: TypeSerializer[TKey, TKey_NP],
        value_serializer: TypeSerializer[TValue, TValue_NP],
    ) -> None:
        super().__init__(np.object_)
        self._key_serializer = key_serializer
        self._value_serializer = value_serializer

    def write(self, stream: CodedOutputStream, value: dict[TKey, TValue]) -> None:
        """Write the entry count, then each key followed by its value."""
        stream.write_unsigned_varint(len(value))
        for key, item in value.items():
            self._key_serializer.write(stream, key)
            self._value_serializer.write(stream, item)

    def write_numpy(self, stream: CodedOutputStream, value: np.object_) -> None:
        """Write a dict stored in an object array; same encoding as `write`."""
        self.write(stream, cast(dict[TKey, TValue], value))

    def read(self, stream: CodedInputStream) -> dict[TKey, TValue]:
        """Read the entry count, then that many key/value pairs."""
        count = stream.read_unsigned_varint()
        entries: dict[TKey, TValue] = {}
        for _ in range(count):
            # The key precedes its value on the wire.
            key = self._key_serializer.read(stream)
            entries[key] = self._value_serializer.read(stream)
        return entries

    def read_numpy(self, stream: CodedInputStream) -> np.object_:
        """Read a dict and pass it through `np.object_` for the numpy representation."""
        return np.object_(self.read(stream))
|
|
1071
|
+
|
|
1072
|
+
|
|
1073
|
+
class NDArraySerializerBase(
    Generic[T, T_NP], TypeSerializer[npt.NDArray[Any], np.object_]
):
    """Shared element-data reading and writing for the NDArray serializers.

    Subclasses handle how the array shape is encoded; this class resolves the
    element dtype (unwrapping subarray dtypes) and reads or writes the flat
    element data.
    """

    def __init__(
        self,
        overall_dtype: npt.DTypeLike,
        element_serializer: TypeSerializer[T, T_NP],
        dtype: npt.DTypeLike,
    ) -> None:
        super().__init__(overall_dtype)
        self.element_serializer = element_serializer

        resolved_dtype = (
            dtype
            if isinstance(dtype, np.dtype)
            else np.dtype(dtype)  # pyright: ignore [reportUnknownArgumentType]
        )
        base_dtype, subarray_shape = (
            NDArraySerializerBase._get_dtype_and_subarray_shape(resolved_dtype)
        )
        self._array_dtype = base_dtype

        if subarray_shape == ():
            # Plain element dtype: there are no trailing subarray dimensions.
            self._subarray_shape = None
            return

        self._subarray_shape = subarray_shape
        if isinstance(
            element_serializer, (FixedNDArraySerializer, FixedVectorSerializer)
        ):
            # The fixed container contributes only trailing dimensions, so
            # elements are handled by its inner serializer.
            self.element_serializer = cast(
                TypeSerializer[T, T_NP],
                element_serializer.element_serializer,  # pyright: ignore [reportUnknownMemberType]
            )

    @staticmethod
    def _get_dtype_and_subarray_shape(
        dtype: np.dtype[Any],
    ) -> tuple[np.dtype[Any], tuple[int, ...]]:
        """Unwrap nested subarray dtypes.

        Returns:
            The innermost dtype and the concatenated subarray shape, outermost
            dimensions first (``()`` when `dtype` is not a subarray dtype).
        """
        current = dtype
        shape: tuple[int, ...] = ()
        while current.subdtype is not None:
            inner_dtype, inner_shape = current.subdtype
            shape = shape + inner_shape
            current = inner_dtype
        return current, shape

    def _write_data(self, stream: CodedOutputStream, value: npt.NDArray[Any]) -> None:
        """Write the elements of `value`, in bulk when possible.

        Raises:
            ValueError: If the dtype of `value` is neither the expected dtype
                nor its packed (unaligned) equivalent.
        """
        if value.dtype != self._array_dtype:
            # see if it's the same dtype but packed, not aligned
            packed_dtype = recfunctions.repack_fields(self._array_dtype, align=False, recurse=True)  # type: ignore
            if packed_dtype != value.dtype:
                message = (
                    f"Expected dtype {self._array_dtype}, got {value.dtype}"
                    if packed_dtype == self._array_dtype
                    else f"Expected dtype {self._array_dtype} or {packed_dtype}, got {value.dtype}"
                )
                raise ValueError(message)

        if not self._is_current_array_trivially_serializable(value):
            # Element-by-element path.
            for item in value.flat:
                self.element_serializer.write_numpy(stream, item)
            return

        # Bulk path: hand the array's buffer to the stream as is.
        stream.write_bytes_directly(value.data)

    def _read_data(
        self, stream: CodedInputStream, shape: tuple[int, ...]
    ) -> npt.NDArray[Any]:
        """Read the elements for an array of the given shape and return it reshaped."""
        element_count = int(np.prod(shape))  # type: ignore

        if self.element_serializer.is_trivially_serializable():
            # Bulk path: read all bytes at once and reinterpret them.
            byte_count = element_count * self._array_dtype.itemsize
            raw = stream.read_bytearray(byte_count)
            return np.frombuffer(raw, dtype=self._array_dtype).reshape(shape)

        flat: npt.NDArray[T_NP] = np.ndarray((element_count,), dtype=self._array_dtype)
        for position in range(element_count):
            flat[position] = self.element_serializer.read_numpy(stream)

        return flat.reshape(shape)

    def _is_current_array_trivially_serializable(self, value: npt.NDArray[Any]) -> bool:
        """Tell whether `value` can be written as one contiguous byte block.

        That requires a trivially serializable element serializer, a
        C-contiguous array, and a dtype whose field names (if any) are all
        non-empty.
        """
        dtype_fields = self._array_dtype.fields
        all_fields_named = dtype_fields is None or all(
            name != "" for name in dtype_fields
        )
        return (
            self.element_serializer.is_trivially_serializable()
            and value.flags.c_contiguous
            and all_fields_named
        )
|
|
1156
|
+
|
|
1157
|
+
|
|
1158
|
+
class DynamicNDArraySerializer(NDArraySerializerBase[T, T_NP]):
    """Serializes arrays whose number of dimensions is stored in the stream.

    The encoding is: dimension count, each dimension, then the element data.
    Trailing dimensions that come from a subarray element dtype are not
    written.
    """

    def __init__(
        self,
        element_serializer: TypeSerializer[T, T_NP],
    ) -> None:
        super().__init__(
            np.object_, element_serializer, element_serializer.overall_dtype()
        )

    def write(self, stream: CodedOutputStream, value: npt.NDArray[Any]) -> None:
        """Write the dimension count, the dimensions and the data of `value`.

        Raises:
            ValueError: If the element dtype has a subarray shape and the
                trailing dimensions of `value` do not match it.
        """
        if self._subarray_shape is None:
            written_ndim = value.ndim
            written_dims = value.shape
        else:
            trailing = len(self._subarray_shape)
            if (
                len(value.shape) < trailing
                or value.shape[-trailing:] != self._subarray_shape
            ):
                raise ValueError(
                    f"The array is required to have shape (..., {(', '.join((str(i) for i in self._subarray_shape)))})"
                )
            written_ndim = value.ndim - trailing
            written_dims = value.shape[:-trailing]

        stream.write_unsigned_varint(written_ndim)
        for dim in written_dims:
            stream.write_unsigned_varint(dim)

        self._write_data(stream, value)

    def write_numpy(self, stream: CodedOutputStream, value: np.object_) -> None:
        """Write an array stored in an object array; same encoding as `write`."""
        self.write(stream, cast(npt.NDArray[Any], value))

    def read(self, stream: CodedInputStream) -> npt.NDArray[Any]:
        """Read the dimension count, the dimensions and then the data."""
        ndims = stream.read_unsigned_varint()
        shape = tuple(stream.read_unsigned_varint() for _ in range(ndims))
        if self._subarray_shape is not None:
            # Re-attach the trailing dimensions that were not written.
            shape = shape + self._subarray_shape
        return self._read_data(stream, shape)

    def read_numpy(self, stream: CodedInputStream) -> np.object_:
        """Read an array for storage in an object array; same decoding as `read`."""
        return cast(np.object_, self.read(stream))
|
|
1202
|
+
|
|
1203
|
+
|
|
1204
|
+
class NDArraySerializer(Generic[T, T_NP], NDArraySerializerBase[T, T_NP]):
    """Serializes arrays with a fixed number of dimensions but variable extents.

    The encoding is: each of the `ndims` dimensions, then the element data.
    Trailing dimensions that come from a subarray element dtype are not
    written.
    """

    def __init__(
        self,
        element_serializer: TypeSerializer[T, T_NP],
        ndims: int,
    ) -> None:
        super().__init__(
            np.object_, element_serializer, element_serializer.overall_dtype()
        )
        self._ndims = ndims

    def write(self, stream: CodedOutputStream, value: npt.NDArray[Any]) -> None:
        """Validate the dimensionality of `value`, then write its dimensions and data.

        Raises:
            ValueError: If `value` has the wrong number of dimensions, or its
                trailing dimensions do not match the subarray shape.
        """
        if self._subarray_shape is None:
            if value.ndim != self._ndims:
                raise ValueError(f"Expected {self._ndims} dimensions, got {value.ndim}")
            written_dims = value.shape
        else:
            trailing = len(self._subarray_shape)
            total_dims = trailing + self._ndims
            if value.ndim != total_dims:
                raise ValueError(f"Expected {total_dims} dimensions, got {value.ndim}")

            if value.shape[-trailing:] != self._subarray_shape:
                raise ValueError(
                    f"The array is required to have shape (..., {(', '.join((str(i) for i in self._subarray_shape)))})"
                )
            written_dims = value.shape[:-trailing]

        for dim in written_dims:
            stream.write_unsigned_varint(dim)

        self._write_data(stream, value)

    def write_numpy(self, stream: CodedOutputStream, value: np.object_) -> None:
        """Write an array stored in an object array; same encoding as `write`."""
        self.write(stream, cast(npt.NDArray[Any], value))

    def read(self, stream: CodedInputStream) -> npt.NDArray[Any]:
        """Read `ndims` dimensions and then the data."""
        leading = tuple(stream.read_unsigned_varint() for _ in range(self._ndims))
        shape = (
            leading
            if self._subarray_shape is None
            else leading + self._subarray_shape
        )
        return self._read_data(stream, shape)

    def read_numpy(self, stream: CodedInputStream) -> np.object_:
        """Read an array for storage in an object array; same decoding as `read`."""
        return cast(np.object_, self.read(stream))
|
|
1249
|
+
|
|
1250
|
+
|
|
1251
|
+
class FixedNDArraySerializer(Generic[T, T_NP], NDArraySerializerBase[T, T_NP]):
    """Serializes arrays of a fixed shape; only the element data is written."""

    def __init__(
        self,
        element_serializer: TypeSerializer[T, T_NP],
        shape: tuple[int, ...],
    ) -> None:
        element_dtype = element_serializer.overall_dtype()
        # The overall dtype is a subarray dtype of the fixed shape.
        super().__init__(
            np.dtype((element_dtype, shape)), element_serializer, element_dtype
        )
        self._shape = shape

    def write(self, stream: CodedOutputStream, value: npt.NDArray[Any]) -> None:
        """Write the data of `value` after checking its shape.

        Raises:
            ValueError: If the shape of `value` differs from the fixed shape
                (extended by the subarray shape, when there is one).
        """
        if self._subarray_shape is None:
            required_shape = self._shape
        else:
            required_shape = self._shape + self._subarray_shape

        if value.shape != required_shape:
            raise ValueError(f"Expected shape {required_shape}, got {value.shape}")

        self._write_data(stream, value)

    def write_numpy(self, stream: CodedOutputStream, value: np.object_) -> None:
        """Write an array stored in an object array; same encoding as `write`."""
        self.write(stream, cast(npt.NDArray[Any], value))

    def read(self, stream: CodedInputStream) -> npt.NDArray[Any]:
        """Read the data for an array of the fixed shape."""
        if self._subarray_shape is None:
            full_shape = self._shape
        else:
            full_shape = self._shape + self._subarray_shape
        return self._read_data(stream, full_shape)

    def read_numpy(self, stream: CodedInputStream) -> np.object_:
        """Read an array for storage in an object array; same decoding as `read`."""
        return cast(np.object_, self.read(stream))

    def is_trivially_serializable(self) -> bool:
        """Return whatever the element serializer reports."""
        return self.element_serializer.is_trivially_serializable()
|
|
1288
|
+
|
|
1289
|
+
|
|
1290
|
+
class RecordSerializer(TypeSerializer[T, np.void]):
    """Base serializer for records: fields are written and read in declaration order.

    The numpy representation is an aligned structured dtype with one field per
    (name, serializer) pair.
    """

    def __init__(
        self, field_serializers: list[Tuple[str, TypeSerializer[Any, Any]]]
    ) -> None:
        dtype_fields = [
            (field_name, field_serializer.overall_dtype())
            for field_name, field_serializer in field_serializers
        ]
        super().__init__(np.dtype(dtype_fields, align=True))

        self._field_serializers = field_serializers

    def is_trivially_serializable(self) -> bool:
        """Return True only when every field serializer is trivially serializable."""
        for _, field_serializer in self._field_serializers:
            if not field_serializer.is_trivially_serializable():
                return False
        return True

    def _write(self, stream: CodedOutputStream, *values: Any) -> None:
        """Write `values` positionally, one per field serializer."""
        for position, field in enumerate(self._field_serializers):
            field_serializer = field[1]
            field_serializer.write(stream, values[position])

    def _read(self, stream: CodedInputStream) -> tuple[Any, ...]:
        """Read one value per field serializer and return them as a tuple."""
        return tuple(
            [field[1].read(stream) for field in self._field_serializers]
        )

    def read_numpy(self, stream: CodedInputStream) -> np.void:
        """Read the field values as a tuple for the structured numpy representation."""
        return cast(np.void, self._read(stream))
|
|
1323
|
+
|
|
1324
|
+
|
|
1325
|
+
# Only used in the header.
# Little-endian signed 32-bit integer layout.
int32_struct = struct.Struct("<i")
assert int32_struct.size == 4
|
|
1328
|
+
|
|
1329
|
+
|
|
1330
|
+
def write_fixed_int32(stream: CodedOutputStream, value: int) -> None:
    """Write `value` to `stream` using the fixed-width `int32_struct` layout.

    Raises:
        ValueError: If `value` is below INT32_MIN or above INT32_MAX.
    """
    below_minimum = value < INT32_MIN
    if below_minimum or value > INT32_MAX:
        raise ValueError(
            f"Value {value} is outside the range of a signed 32-bit integer"
        )

    stream.write(int32_struct, value)
|
|
1336
|
+
|
|
1337
|
+
|
|
1338
|
+
def read_fixed_int32(stream: CodedInputStream) -> int:
    """Read one value from `stream` using the fixed-width `int32_struct` layout."""
    unpacked = stream.read(int32_struct)
    return unpacked[0]
|