mrd-python 2.0.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
mrd/_binary.py ADDED
@@ -0,0 +1,1339 @@
1
+ # Copyright (c) Microsoft Corporation.
2
+ # Licensed under the MIT License.
3
+
4
+ # pyright: reportUnnecessaryIsInstance=false
5
+
6
+ import datetime
7
+ from enum import Enum
8
+ from io import BufferedIOBase, BufferedReader, BytesIO
9
+ from typing import (
10
+ BinaryIO,
11
+ Iterable,
12
+ Protocol,
13
+ TypeVar,
14
+ Generic,
15
+ Any,
16
+ Optional,
17
+ Tuple,
18
+ cast,
19
+ )
20
+ from abc import ABC, abstractmethod
21
+ import struct
22
+ import sys
23
+ import numpy as np
24
+
25
+ from numpy.lib import recfunctions
26
+ import numpy.typing as npt
27
+
28
+ from .yardl_types import *
29
+
30
# Every struct format in this module is declared little-endian ("<"), and the
# module refuses to load on hosts with a different native byte order.
if sys.byteorder != "little":
    raise RuntimeError("Only little-endian systems are currently supported")

# Stream header values: magic prefix and binary format version.
MAGIC_BYTES: bytes = b"yardl"
CURRENT_BINARY_FORMAT_VERSION: int = 1

# Inclusive bounds of the fixed-width integer types, used for range checks.
_INT8_INFO = np.iinfo(np.int8)
_INT16_INFO = np.iinfo(np.int16)
_INT32_INFO = np.iinfo(np.int32)
_INT64_INFO = np.iinfo(np.int64)

INT8_MIN: int = _INT8_INFO.min
INT8_MAX: int = _INT8_INFO.max

UINT8_MAX: int = np.iinfo(np.uint8).max

INT16_MIN: int = _INT16_INFO.min
INT16_MAX: int = _INT16_INFO.max

UINT16_MAX: int = np.iinfo(np.uint16).max

INT32_MIN: int = _INT32_INFO.min
INT32_MAX: int = _INT32_INFO.max

UINT32_MAX: int = np.iinfo(np.uint32).max

INT64_MIN: int = _INT64_INFO.min
INT64_MAX: int = _INT64_INFO.max

UINT64_MAX: int = np.iinfo(np.uint64).max
55
+
56
+
57
class BinaryProtocolWriter(ABC):
    """Base class for writers of the binary protocol format.

    Construction immediately emits the stream header: the magic bytes, the
    fixed-width format version, and the schema string.
    """

    def __init__(self, stream: Union[BinaryIO, str], schema: str) -> None:
        self._stream = output = CodedOutputStream(stream)
        output.write_bytes(MAGIC_BYTES)
        write_fixed_int32(output, CURRENT_BINARY_FORMAT_VERSION)
        string_serializer.write(output, schema)

    def _close(self) -> None:
        """Close the underlying coded output stream."""
        self._stream.close()

    def _end_stream(self) -> None:
        """Write the single zero byte that terminates a stream of elements."""
        out = self._stream
        out.ensure_capacity(1)
        out.write_byte_no_check(0)
70
+
71
+
72
class BinaryProtocolReader(ABC):
    """Base class for readers of the binary protocol format.

    Construction consumes and validates the stream header: magic bytes,
    format version, and schema string.

    Raises:
        RuntimeError: if the magic bytes or version are wrong, or if
            ``expected_schema`` is truthy and differs from the stored schema.
    """

    def __init__(
        self,
        stream: Union[BufferedReader, BytesIO, BinaryIO, str],
        expected_schema: Optional[str],
    ) -> None:
        self._stream = CodedInputStream(stream)
        try:
            magic_bytes = self._stream.read_view(len(MAGIC_BYTES))
            if magic_bytes != MAGIC_BYTES:  # pyright: ignore [reportUnnecessaryComparison]
                raise RuntimeError("Invalid magic bytes")

            version = read_fixed_int32(self._stream)
            if version != CURRENT_BINARY_FORMAT_VERSION:
                raise RuntimeError("Invalid binary format version")

            self._schema = string_serializer.read(self._stream)
            if expected_schema and self._schema != expected_schema:
                raise RuntimeError("Invalid schema")
        except BaseException:
            # The caller never receives a reader it could close, so release
            # the stream here. CodedInputStream.close() only closes files it
            # opened itself; caller-supplied streams are left untouched.
            self._stream.close()
            raise

    def _close(self) -> None:
        """Close the underlying coded input stream."""
        self._stream.close()
93
+
94
+
95
class CodedOutputStream:
    """Buffered binary writer with struct, raw-byte and varint helpers.

    Args:
        stream: a writable binary stream, or a file path. A path is opened in
            ``"wb"`` mode and the resulting file is owned (closed by
            ``close()``); a caller-supplied stream is never closed here.
        buffer_size: size in bytes of the internal staging buffer.
    """

    def __init__(
        self, stream: Union[BinaryIO, str], *, buffer_size: int = 65536
    ) -> None:
        if isinstance(stream, str):
            self._stream = cast(BinaryIO, open(stream, "wb"))
            self._owns_stream = True
        else:
            self._stream = stream
            self._owns_stream = False

        self._buffer = bytearray(buffer_size)
        self._offset = 0

    def close(self) -> None:
        """Flush buffered data and close the stream if this object opened it."""
        try:
            self.flush()
        finally:
            # Release an owned file handle even when the final flush fails.
            if self._owns_stream:
                self._stream.close()

    def ensure_capacity(self, size: int) -> None:
        """Flush if fewer than ``size`` bytes are free in the buffer."""
        if (len(self._buffer) - self._offset) < size:
            self.flush()

    def flush(self) -> None:
        """Write any buffered bytes to the underlying stream and flush it."""
        if self._offset > 0:
            self._stream.write(self._buffer[: self._offset])
            self._stream.flush()
            self._offset = 0

    def write(self, formatter: struct.Struct, *args: Any) -> None:
        """Pack ``args`` with ``formatter`` into the buffer."""
        size = formatter.size
        if (len(self._buffer) - self._offset) < size:
            self.flush()

        formatter.pack_into(self._buffer, self._offset, *args)
        self._offset += size

    def write_bytes(self, data: Union[bytes, bytearray]) -> None:
        """Write raw bytes, bypassing the buffer when they do not fit in it."""
        if len(data) > (len(self._buffer) - self._offset):
            # Preserve ordering: emit what is buffered before the direct write.
            self.flush()
            self._stream.write(data)
        else:
            self._buffer[self._offset : self._offset + len(data)] = data
            self._offset += len(data)

    def write_bytes_directly(self, data: Union[bytes, bytearray, memoryview]) -> None:
        """Flush the buffer, then write ``data`` straight to the stream."""
        self.flush()
        self._stream.write(data)

    def write_byte_no_check(self, value: int) -> None:
        """Append one byte without checking capacity.

        Callers must have reserved space first (e.g. via ``ensure_capacity``).
        """
        assert 0 <= value <= UINT8_MAX
        self._buffer[self._offset] = value
        self._offset += 1

    def write_unsigned_varint(
        self,
        value: Union[int, np.uint8, np.uint16, np.uint32, np.uint64],
    ) -> None:
        """Write ``value`` as a base-128 varint, least-significant group first."""
        # Reserve 10 bytes up front so the loop can use the unchecked writer.
        if (len(self._buffer) - self._offset) < 10:
            self.flush()

        int_val = int(value)  # bitwise ops not supported on numpy types

        while True:
            if int_val < 0x80:
                self.write_byte_no_check(int_val)
                return

            # Low 7 bits plus the continuation flag.
            self.write_byte_no_check((int_val & 0x7F) | 0x80)
            int_val >>= 7

    def zigzag_encode(
        self,
        value: Union[int, np.int8, np.int16, np.int32, np.int64],
    ) -> int:
        """Map a signed integer to an unsigned one (0, -1, 1, -2 -> 0, 1, 2, 3)."""
        int_val = int(value)
        return (int_val << 1) ^ (int_val >> 63)

    def write_signed_varint(
        self,
        value: Union[int, np.int8, np.int16, np.int32, np.int64],
    ) -> None:
        """Write a signed integer as a zigzag-encoded varint."""
        self.write_unsigned_varint(self.zigzag_encode(value))
178
+
179
+
180
class CodedInputStream:
    """Buffered binary reader with struct, raw-byte and varint helpers.

    Args:
        stream: a readable binary stream, or a file path. A path is opened in
            ``"rb"`` mode and the resulting file is owned (closed by
            ``close()``). A stream that is not a ``BufferedIOBase`` is wrapped
            in a ``BufferedReader``; caller-supplied streams are never closed.
        buffer_size: size in bytes of the internal read buffer.
    """

    def __init__(
        self,
        stream: Union[BufferedReader, BytesIO, BinaryIO, str],
        *,
        buffer_size: int = 65536,
    ) -> None:
        if isinstance(stream, str):
            self._stream = open(stream, "rb")
            self._owns_stream = True
        else:
            if not isinstance(stream, BufferedIOBase):
                self._stream = BufferedReader(stream)  # type: ignore
            else:
                self._stream = stream
            self._owns_stream = False

        # Number of valid bytes currently held in the buffer.
        self._last_read_count = 0
        self._buffer = bytearray(buffer_size)
        self._view = memoryview(self._buffer)
        # Position of the next unread byte within the buffer.
        self._offset = 0
        self._at_end = False

    def close(self) -> None:
        """Close the underlying stream if this object opened it."""
        if self._owns_stream:
            self._stream.close()

    def read(self, formatter: struct.Struct) -> tuple[Any, ...]:
        """Unpack one ``formatter`` record; raises ``EOFError`` on truncation."""
        if self._last_read_count - self._offset < formatter.size:
            self._fill_buffer(formatter.size)

        result = formatter.unpack_from(self._buffer, self._offset)
        self._offset += formatter.size
        return result

    def read_byte(self) -> int:
        """Read a single byte as an int; raises ``EOFError`` at end of stream."""
        if self._last_read_count - self._offset < 1:
            self._fill_buffer(1)

        result = self._buffer[self._offset]
        self._offset += 1
        return result

    def read_unsigned_varint(self) -> int:
        """Read a base-128 varint (least-significant group first)."""
        result = 0
        shift = 0
        while True:
            if self._last_read_count - self._offset < 1:
                self._fill_buffer(1)

            byte = self._buffer[self._offset]
            self._offset += 1
            result |= (byte & 0x7F) << shift
            if byte < 0x80:  # continuation flag clear: last group
                return result
            shift += 7

    def zigzag_decode(self, value: int) -> int:
        """Inverse of zigzag encoding (0, 1, 2, 3 -> 0, -1, 1, -2)."""
        return (value >> 1) ^ -(value & 1)

    def read_signed_varint(self) -> int:
        """Read a zigzag-encoded signed varint."""
        return self.zigzag_decode(self.read_unsigned_varint())

    def read_view(self, count: int) -> memoryview:
        """Return a view over the next ``count`` bytes.

        For reads that fit in the internal buffer the view aliases that
        buffer, so its contents are only stable until the next read call.

        Raises:
            EOFError: if the stream ends before ``count`` bytes are available.
        """
        if count <= (self._last_read_count - self._offset):
            res = self._view[self._offset : self._offset + count]
            self._offset += count
            return res

        if count > len(self._buffer):
            return memoryview(self._read_large(count))

        self._fill_buffer(count)
        result = self._view[self._offset : self._offset + count]
        self._offset += count
        return result

    def read_bytearray(self, count: int) -> bytearray:
        """Return the next ``count`` bytes as a new ``bytearray``.

        Raises:
            EOFError: if the stream ends before ``count`` bytes are available.
        """
        if count <= (self._last_read_count - self._offset):
            res = bytearray(self._view[self._offset : self._offset + count])
            self._offset += count
            return res

        if count > len(self._buffer):
            return self._read_large(count)

        self._fill_buffer(count)
        result = self._view[self._offset : self._offset + count]
        self._offset += count
        return bytearray(result)

    def _read_large(self, count: int) -> bytearray:
        """Read ``count`` bytes (more than the buffer holds) into a new array.

        Buffered bytes are consumed first; the rest is read from the stream,
        repeating ``readinto`` until the array is full or the stream ends.
        """
        local_buf = bytearray(count)
        local_view = memoryview(local_buf)
        filled = self._last_read_count - self._offset
        local_view[:filled] = self._view[self._offset : self._last_read_count]
        self._offset = self._last_read_count
        while filled < count:
            n = self._stream.readinto(local_view[filled:])
            if not n:
                raise EOFError("Unexpected EOF")
            filled += n
        return local_buf

    def _fill_buffer(self, min_count: int = 0) -> None:
        """Move unread bytes to the front of the buffer and read more.

        Args:
            min_count: minimum number of bytes that must be available after
                the call; ``0`` means a single best-effort read.

        Raises:
            EOFError: if the stream ends before ``min_count`` bytes are
                available.
        """
        remaining = self._last_read_count - self._offset
        if remaining > 0:
            # Copy exactly the unread bytes. Source and destination must be
            # the same length: a longer source would try to resize the
            # bytearray, which fails while memoryviews of it exist. The
            # intermediate bytes() copy makes the overlapping move safe.
            self._buffer[:remaining] = bytes(
                self._view[self._offset : self._last_read_count]
            )

        filled = remaining
        while True:
            n = self._stream.readinto(self._view[filled:])
            if not n:
                break  # end of stream (or nothing available)
            filled += n
            # A short read is not EOF; keep reading until the request is met.
            if filled >= min_count:
                break

        self._last_read_count = filled
        self._offset = 0
        if self._last_read_count == 0:
            self._at_end = True
        if min_count > 0 and (self._last_read_count) < min_count:
            raise EOFError("Unexpected EOF")
300
+
301
+
302
# T is the Python-side value type; T_NP is the NumPy scalar counterpart.
T = TypeVar("T")
T_NP = TypeVar("T_NP", bound=np.generic)


class TypeSerializer(Generic[T, T_NP], ABC):
    """Abstract interface for reading and writing one kind of value.

    Each serializer handles a Python representation (``write``/``read``) and a
    NumPy representation (``write_numpy``/``read_numpy``) of the same type.
    """

    def __init__(self, dtype: npt.DTypeLike) -> None:
        self._dtype: np.dtype[Any] = np.dtype(dtype)

    def overall_dtype(self) -> np.dtype[Any]:
        """Return the NumPy dtype this serializer was constructed with."""
        return self._dtype

    def struct_format_str(self) -> Optional[str]:
        """Return a ``struct`` format string, or ``None`` when there is none."""
        return None

    def is_trivially_serializable(self) -> bool:
        """Report whether the type is trivially serializable (default: no)."""
        return False

    @abstractmethod
    def write(self, stream: CodedOutputStream, value: T) -> None:
        raise NotImplementedError

    @abstractmethod
    def write_numpy(self, stream: CodedOutputStream, value: T_NP) -> None:
        raise NotImplementedError

    @abstractmethod
    def read(self, stream: CodedInputStream) -> T:
        raise NotImplementedError

    @abstractmethod
    def read_numpy(self, stream: CodedInputStream) -> T_NP:
        raise NotImplementedError
334
+
335
+
336
class StructSerializer(TypeSerializer[T, T_NP]):
    """Serializer for fixed-width values described by a ``struct`` format."""

    def __init__(self, numpy_type: type, format_string: str) -> None:
        super().__init__(numpy_type)
        self._numpy_type = numpy_type
        self._struct = struct.Struct(format_string)

    def struct_format_str(self) -> str:
        """Return the format string of the compiled struct."""
        return self._struct.format

    def write(self, stream: CodedOutputStream, value: T) -> None:
        """Pack ``value`` as a single struct field."""
        stream.write(self._struct, value)

    def write_numpy(self, stream: CodedOutputStream, value: T_NP) -> None:
        """Pack a NumPy scalar as a single struct field."""
        stream.write(self._struct, value)

    def read(self, stream: CodedInputStream) -> T:
        """Unpack a record and return its first field."""
        fields = stream.read(self._struct)
        return cast(T, fields[0])

    def read_numpy(self, stream: CodedInputStream) -> T_NP:
        """Unpack a record and wrap its first field in the NumPy type."""
        fields = stream.read(self._struct)
        return cast(T_NP, self._numpy_type(fields[0]))
356
+
357
+
358
class BoolSerializer(StructSerializer[bool, np.bool_]):
    """Booleans, stored with the struct format ``<?``."""

    def __init__(self) -> None:
        super().__init__(np.bool_, "<?")

    def read(self, stream: CodedInputStream) -> bool:
        # Delegates to the base class; overridden to narrow the return type.
        return super().read(stream)

    def read_numpy(self, stream: CodedInputStream) -> np.bool_:
        # Delegates to the base class; overridden to narrow the return type.
        return super().read_numpy(stream)


# Shared module-level instance.
bool_serializer = BoolSerializer()
370
+
371
+
372
class Int8Serializer(StructSerializer[Int8, np.int8]):
    """Signed 8-bit integers, stored with the struct format ``<b``."""

    def __init__(self) -> None:
        super().__init__(np.int8, "<b")

    def is_trivially_serializable(self) -> bool:
        return True

    def read(self, stream: CodedInputStream) -> Int8:
        # Delegates to the base class; overridden to narrow the return type.
        return super().read(stream)


# Shared module-level instance.
int8_serializer = Int8Serializer()
384
+
385
+
386
class UInt8Serializer(StructSerializer[UInt8, np.uint8]):
    """Unsigned 8-bit integers, stored with the struct format ``<B``."""

    def __init__(self) -> None:
        super().__init__(np.uint8, "<B")

    def is_trivially_serializable(self) -> bool:
        return True

    def read(self, stream: CodedInputStream) -> UInt8:
        # Delegates to the base class; overridden to narrow the return type.
        return super().read(stream)


# Shared module-level instance.
uint8_serializer = UInt8Serializer()
398
+
399
+
400
class Int16Serializer(TypeSerializer[Int16, np.int16]):
    """Signed 16-bit integers, stored as zigzag varints."""

    def __init__(self) -> None:
        super().__init__(np.int16)

    def write(self, stream: CodedOutputStream, value: Int16) -> None:
        """Validate and write ``value``.

        Raises:
            ValueError: for a Python int outside the int16 range, or for a
                value that is neither a Python int nor an ``np.int16``.
        """
        if isinstance(value, int):
            fits = INT16_MIN <= value <= INT16_MAX
            if not fits:
                raise ValueError(
                    f"Value {value} is outside the range of a signed 16-bit integer"
                )
        elif not isinstance(value, cast(type, np.int16)):
            raise ValueError(f"Value is not a signed 16-bit integer: {value}")

        stream.write_signed_varint(value)

    def write_numpy(self, stream: CodedOutputStream, value: np.int16) -> None:
        """Write a NumPy scalar without validation."""
        stream.write_signed_varint(value)

    def read(self, stream: CodedInputStream) -> Int16:
        decoded = stream.read_signed_varint()
        return decoded

    def read_numpy(self, stream: CodedInputStream) -> np.int16:
        decoded = stream.read_signed_varint()
        return np.int16(decoded)


# Shared module-level instance.
int16_serializer = Int16Serializer()
426
+
427
+
428
class UInt16Serializer(TypeSerializer[UInt16, np.uint16]):
    """Unsigned 16-bit integers, stored as varints."""

    def __init__(self) -> None:
        super().__init__(np.uint16)

    def write(self, stream: CodedOutputStream, value: UInt16) -> None:
        """Validate and write ``value``.

        Raises:
            ValueError: for a Python int outside the uint16 range, or for a
                value that is neither a Python int nor an ``np.uint16``.
        """
        if isinstance(value, int):
            fits = 0 <= value <= UINT16_MAX
            if not fits:
                raise ValueError(
                    f"Value {value} is outside the range of an unsigned 16-bit integer"
                )
        elif not isinstance(value, cast(type, np.uint16)):
            raise ValueError(f"Value is not an unsigned 16-bit integer: {value}")

        stream.write_unsigned_varint(value)

    def write_numpy(self, stream: CodedOutputStream, value: np.uint16) -> None:
        """Write a NumPy scalar without validation."""
        stream.write_unsigned_varint(value)

    def read(self, stream: CodedInputStream) -> UInt16:
        decoded = stream.read_unsigned_varint()
        return decoded

    def read_numpy(self, stream: CodedInputStream) -> np.uint16:
        decoded = stream.read_unsigned_varint()
        return np.uint16(decoded)


# Shared module-level instance.
uint16_serializer = UInt16Serializer()
454
+
455
+
456
class Int32Serializer(TypeSerializer[Int32, np.int32]):
    """Signed 32-bit integers, stored as zigzag varints."""

    def __init__(self) -> None:
        super().__init__(np.int32)

    def write(self, stream: CodedOutputStream, value: Int32) -> None:
        """Validate and write ``value``.

        Raises:
            ValueError: for a Python int outside the int32 range, or for a
                value that is neither a Python int nor an ``np.int32``.
        """
        if isinstance(value, int):
            fits = INT32_MIN <= value <= INT32_MAX
            if not fits:
                raise ValueError(
                    f"Value {value} is outside the range of a signed 32-bit integer"
                )
        elif not isinstance(value, cast(type, np.int32)):
            raise ValueError(f"Value is not a signed 32-bit integer: {value}")

        stream.write_signed_varint(value)

    def write_numpy(self, stream: CodedOutputStream, value: np.int32) -> None:
        """Write a NumPy scalar without validation."""
        stream.write_signed_varint(value)

    def read(self, stream: CodedInputStream) -> Int32:
        decoded = stream.read_signed_varint()
        return decoded

    def read_numpy(self, stream: CodedInputStream) -> np.int32:
        decoded = stream.read_signed_varint()
        return np.int32(decoded)


# Shared module-level instance.
int32_serializer = Int32Serializer()
482
+
483
+
484
class UInt32Serializer(TypeSerializer[UInt32, np.uint32]):
    """Unsigned 32-bit integers, stored as varints."""

    def __init__(self) -> None:
        super().__init__(np.uint32)

    def write(self, stream: CodedOutputStream, value: UInt32) -> None:
        """Validate and write ``value``.

        Raises:
            ValueError: for a Python int outside the uint32 range, or for a
                value that is neither a Python int nor an ``np.uint32``.
        """
        if isinstance(value, int):
            fits = 0 <= value <= UINT32_MAX
            if not fits:
                raise ValueError(
                    f"Value {value} is outside the range of an unsigned 32-bit integer"
                )
        elif not isinstance(value, cast(type, np.uint32)):
            raise ValueError(f"Value is not an unsigned 32-bit integer: {value}")

        stream.write_unsigned_varint(value)

    def write_numpy(self, stream: CodedOutputStream, value: np.uint32) -> None:
        """Write a NumPy scalar without validation."""
        stream.write_unsigned_varint(value)

    def read(self, stream: CodedInputStream) -> UInt32:
        decoded = stream.read_unsigned_varint()
        return decoded

    def read_numpy(self, stream: CodedInputStream) -> np.uint32:
        decoded = stream.read_unsigned_varint()
        return np.uint32(decoded)


# Shared module-level instance.
uint32_serializer = UInt32Serializer()
510
+
511
+
512
class Int64Serializer(TypeSerializer[Int64, np.int64]):
    """Signed 64-bit integers, stored as zigzag varints."""

    def __init__(self) -> None:
        super().__init__(np.int64)

    def write(self, stream: CodedOutputStream, value: Int64) -> None:
        """Validate and write ``value``.

        Raises:
            ValueError: for a Python int outside the int64 range, or for a
                value that is neither a Python int nor an ``np.int64``.
        """
        if isinstance(value, int):
            fits = INT64_MIN <= value <= INT64_MAX
            if not fits:
                raise ValueError(
                    f"Value {value} is outside the range of a signed 64-bit integer"
                )
        elif not isinstance(value, cast(type, np.int64)):
            raise ValueError(f"Value is not a signed 64-bit integer: {value}")

        stream.write_signed_varint(value)

    def write_numpy(self, stream: CodedOutputStream, value: np.int64) -> None:
        """Write a NumPy scalar without validation."""
        stream.write_signed_varint(value)

    def read(self, stream: CodedInputStream) -> Int64:
        decoded = stream.read_signed_varint()
        return decoded

    def read_numpy(self, stream: CodedInputStream) -> np.int64:
        decoded = stream.read_signed_varint()
        return np.int64(decoded)


# Shared module-level instance.
int64_serializer = Int64Serializer()
538
+
539
+
540
class UInt64Serializer(TypeSerializer[UInt64, np.uint64]):
    """Unsigned 64-bit integers, stored as varints."""

    def __init__(self) -> None:
        super().__init__(np.uint64)

    def write(self, stream: CodedOutputStream, value: UInt64) -> None:
        """Validate and write ``value``.

        Raises:
            ValueError: for a Python int outside the uint64 range, or for a
                value that is neither a Python int nor an ``np.uint64``.
        """
        if isinstance(value, int):
            fits = 0 <= value <= UINT64_MAX
            if not fits:
                raise ValueError(
                    f"Value {value} is outside the range of an unsigned 64-bit integer"
                )
        elif not isinstance(value, cast(type, np.uint64)):
            raise ValueError(f"Value is not an unsigned 64-bit integer: {value}")

        stream.write_unsigned_varint(value)

    def write_numpy(self, stream: CodedOutputStream, value: np.uint64) -> None:
        """Write a NumPy scalar without validation."""
        stream.write_unsigned_varint(value)

    def read(self, stream: CodedInputStream) -> UInt64:
        decoded = stream.read_unsigned_varint()
        return decoded

    def read_numpy(self, stream: CodedInputStream) -> np.uint64:
        decoded = stream.read_unsigned_varint()
        return np.uint64(decoded)


# Shared module-level instance.
uint64_serializer = UInt64Serializer()
566
+
567
+
568
class SizeSerializer(TypeSerializer[Size, np.uint64]):
    """Size values, validated and stored as unsigned 64-bit varints."""

    def __init__(self) -> None:
        super().__init__(np.uint64)

    def write(self, stream: CodedOutputStream, value: Size) -> None:
        """Validate and write ``value``.

        Raises:
            ValueError: for a Python int outside the uint64 range, or for a
                value that is neither a Python int nor an ``np.uint64``.
        """
        if isinstance(value, int):
            fits = 0 <= value <= UINT64_MAX
            if not fits:
                raise ValueError(
                    f"Value {value} is outside the range of an unsigned 64-bit integer"
                )
        elif not isinstance(value, cast(type, np.uint64)):
            raise ValueError(f"Value is not an unsigned 64-bit integer: {value}")

        stream.write_unsigned_varint(value)

    def write_numpy(self, stream: CodedOutputStream, value: np.uint64) -> None:
        """Write a NumPy scalar without validation."""
        stream.write_unsigned_varint(value)

    def read(self, stream: CodedInputStream) -> Size:
        decoded = stream.read_unsigned_varint()
        return decoded

    def read_numpy(self, stream: CodedInputStream) -> np.uint64:
        decoded = stream.read_unsigned_varint()
        return np.uint64(decoded)


# Shared module-level instance.
size_serializer = SizeSerializer()
594
+
595
+
596
class Float32Serializer(StructSerializer[Float32, np.float32]):
    """32-bit floats, stored with the struct format ``<f``."""

    def __init__(self) -> None:
        super().__init__(np.float32, "<f")

    def is_trivially_serializable(self) -> bool:
        return True

    def read(self, stream: CodedInputStream) -> Float32:
        # Delegates to the base class; overridden to narrow the return type.
        return super().read(stream)


# Shared module-level instance.
float32_serializer = Float32Serializer()
608
+
609
+
610
class Float64Serializer(StructSerializer[Float64, np.float64]):
    """64-bit floats, stored with the struct format ``<d``."""

    def __init__(self) -> None:
        super().__init__(np.float64, "<d")

    def is_trivially_serializable(self) -> bool:
        return True

    def read(self, stream: CodedInputStream) -> Float64:
        # Delegates to the base class; overridden to narrow the return type.
        return super().read(stream)


# Shared module-level instance.
float64_serializer = Float64Serializer()
622
+
623
+
624
class Complex32Serializer(StructSerializer[ComplexFloat, np.complex64]):
    """Complex numbers stored as two 32-bit floats (real, imaginary)."""

    def __init__(self) -> None:
        super().__init__(np.complex64, "<ff")

    def write(self, stream: CodedOutputStream, value: ComplexFloat) -> None:
        """Write the real and imaginary parts as consecutive floats."""
        stream.write(self._struct, value.real, value.imag)

    def write_numpy(self, stream: CodedOutputStream, value: np.complex64) -> None:
        """Write a NumPy complex scalar as consecutive floats.

        The inherited implementation passes the scalar as a single argument,
        which cannot be packed into the two-field ``<ff`` struct.
        """
        stream.write(self._struct, value.real, value.imag)

    def read(self, stream: CodedInputStream) -> ComplexFloat:
        """Read two floats and build a ``ComplexFloat`` from them."""
        return ComplexFloat(*stream.read(self._struct))

    def read_numpy(self, stream: CodedInputStream) -> np.complex64:
        """Read two floats and build an ``np.complex64`` from them."""
        real, imag = stream.read(self._struct)
        return np.complex64(complex(real, imag))

    def is_trivially_serializable(self) -> bool:
        return True


# Shared module-level instance.
complexfloat32_serializer = Complex32Serializer()
643
+
644
+
645
class Complex64Serializer(StructSerializer[ComplexDouble, np.complex128]):
    """Complex numbers stored as two 64-bit floats (real, imaginary)."""

    def __init__(self) -> None:
        super().__init__(np.complex128, "<dd")

    def write(self, stream: CodedOutputStream, value: ComplexDouble) -> None:
        """Write the real and imaginary parts as consecutive doubles."""
        stream.write(self._struct, value.real, value.imag)

    def write_numpy(self, stream: CodedOutputStream, value: np.complex128) -> None:
        """Write a NumPy complex scalar as consecutive doubles.

        The inherited implementation passes the scalar as a single argument,
        which cannot be packed into the two-field ``<dd`` struct.
        """
        stream.write(self._struct, value.real, value.imag)

    def read(self, stream: CodedInputStream) -> ComplexDouble:
        """Read two doubles and build a ``ComplexDouble`` from them."""
        return ComplexDouble(*stream.read(self._struct))

    def read_numpy(self, stream: CodedInputStream) -> np.complex128:
        """Read two doubles and build an ``np.complex128`` from them."""
        real, imag = stream.read(self._struct)
        return np.complex128(complex(real, imag))

    def is_trivially_serializable(self) -> bool:
        return True


# Shared module-level instance.
complexfloat64_serializer = Complex64Serializer()
664
+
665
+
666
class StringSerializer(TypeSerializer[str, np.object_]):
    """Strings, stored as a varint byte length followed by UTF-8 bytes."""

    def __init__(self) -> None:
        super().__init__(np.object_)

    def write(self, stream: CodedOutputStream, value: str) -> None:
        encoded = value.encode("utf-8")
        # The length prefix counts encoded bytes, not characters.
        stream.write_unsigned_varint(len(encoded))
        stream.write_bytes(encoded)

    def write_numpy(self, stream: CodedOutputStream, value: np.object_) -> None:
        text = cast(str, value)
        self.write(stream, text)

    def read(self, stream: CodedInputStream) -> str:
        byte_count = stream.read_unsigned_varint()
        return str(stream.read_view(byte_count), "utf-8")

    def read_numpy(self, stream: CodedInputStream) -> np.object_:
        text = self.read(stream)
        return np.object_(text)


# Shared module-level instance.
string_serializer = StringSerializer()
688
+
689
# Proleptic Gregorian ordinal of 1970-01-01, used to convert dates to day offsets.
EPOCH_ORDINAL_DAYS = datetime.date(1970, 1, 1).toordinal()
DATETIME_DAYS_DTYPE = np.dtype("datetime64[D]")


class DateSerializer(TypeSerializer[datetime.date, np.datetime64]):
    """Dates, stored as a signed varint count of days since 1970-01-01."""

    def __init__(self) -> None:
        super().__init__(DATETIME_DAYS_DTYPE)

    def write(self, stream: CodedOutputStream, value: datetime.date) -> None:
        """Write a ``datetime.date`` or an ``np.datetime64``.

        Raises:
            ValueError: if ``value`` is neither of those types.
        """
        if isinstance(value, datetime.date):
            stream.write_signed_varint(value.toordinal() - EPOCH_ORDINAL_DAYS)
            return

        if not isinstance(value, np.datetime64):
            raise ValueError(
                f"Expected datetime.date or numpy.datetime64, got {type(value)}"
            )

        self.write_numpy(stream, value)

    def write_numpy(self, stream: CodedOutputStream, value: np.datetime64) -> None:
        """Write an ``np.datetime64``, converting to day resolution if needed."""
        days = (
            value
            if value.dtype == DATETIME_DAYS_DTYPE
            else value.astype(DATETIME_DAYS_DTYPE)
        )
        stream.write_signed_varint(days.astype(np.int32))

    def read(self, stream: CodedInputStream) -> datetime.date:
        day_offset = stream.read_signed_varint()
        return datetime.date.fromordinal(day_offset + EPOCH_ORDINAL_DAYS)

    def read_numpy(self, stream: CodedInputStream) -> np.datetime64:
        day_offset = stream.read_signed_varint()
        return np.datetime64(day_offset, "D")


# Shared module-level instance.
date_serializer = DateSerializer()
726
+
727
TIMEDELTA_NANOSECONDS_DTYPE = np.dtype("timedelta64[ns]")


class TimeSerializer(TypeSerializer[Time, np.timedelta64]):
    """Times of day, stored as a signed varint count of nanoseconds."""

    def __init__(self) -> None:
        super().__init__(TIMEDELTA_NANOSECONDS_DTYPE)

    def write(self, stream: CodedOutputStream, value: Time) -> None:
        """Write a ``Time``, ``datetime.time`` or ``np.timedelta64``.

        Raises:
            ValueError: if ``value`` is none of those types.
        """
        if isinstance(value, Time):
            self.write_numpy(stream, value.numpy_value)
        elif isinstance(value, datetime.time):
            self.write_numpy(stream, Time.from_time(value).numpy_value)
        else:
            if not isinstance(value, np.timedelta64):
                raise ValueError(
                    f"Expected a Time, datetime.time or np.timedelta64, got {type(value)}"
                )

            self.write_numpy(stream, value)

    def write_numpy(self, stream: CodedOutputStream, value: np.timedelta64) -> None:
        """Write an ``np.timedelta64``, converting to nanoseconds if needed."""
        if value.dtype == TIMEDELTA_NANOSECONDS_DTYPE:
            stream.write_signed_varint(value.astype(np.int64))
        else:
            # Rescale with the timedelta nanosecond dtype. The value is a
            # duration, so the datetime dtype is the wrong target here.
            stream.write_signed_varint(
                value.astype(TIMEDELTA_NANOSECONDS_DTYPE).astype(np.int64)
            )

    def read(self, stream: CodedInputStream) -> Time:
        nanoseconds_since_midnight = stream.read_signed_varint()
        return Time(nanoseconds_since_midnight)

    def read_numpy(self, stream: CodedInputStream) -> np.timedelta64:
        nanoseconds_since_midnight = stream.read_signed_varint()
        return np.timedelta64(nanoseconds_since_midnight, "ns")


# Shared module-level instance.
time_serializer = TimeSerializer()
765
+
766
DATETIME_NANOSECONDS_DTYPE = np.dtype("datetime64[ns]")
# Naive datetime for 1970-01-01T00:00:00. Written as a literal because
# datetime.utcfromtimestamp() is deprecated since Python 3.12; the value is
# the same naive datetime that utcfromtimestamp(0) returns.
EPOCH_DATETIME = datetime.datetime(1970, 1, 1)
768
+
769
+
770
class DateTimeSerializer(TypeSerializer[DateTime, np.datetime64]):
    """Datetimes, stored as a signed varint count of nanoseconds."""

    def __init__(self) -> None:
        super().__init__(DATETIME_NANOSECONDS_DTYPE)

    def write(self, stream: CodedOutputStream, value: DateTime) -> None:
        """Write a ``DateTime``, ``datetime.datetime`` or ``np.datetime64``.

        Raises:
            ValueError: if ``value`` is none of those types.
        """
        if isinstance(value, DateTime):
            self.write_numpy(stream, value.numpy_value)
            return

        if isinstance(value, datetime.datetime):
            self.write_numpy(stream, DateTime.from_datetime(value).numpy_value)
            return

        if not isinstance(value, np.datetime64):
            raise ValueError(
                f"Expected datetime.datetime or numpy.datetime64, got {type(value)}"
            )

        self.write_numpy(stream, value)

    def write_numpy(self, stream: CodedOutputStream, value: np.datetime64) -> None:
        """Write an ``np.datetime64``, converting to nanoseconds if needed."""
        nanos = (
            value
            if value.dtype == DATETIME_NANOSECONDS_DTYPE
            else value.astype(DATETIME_NANOSECONDS_DTYPE)
        )
        stream.write_signed_varint(nanos.astype(np.int64))

    def read(self, stream: CodedInputStream) -> DateTime:
        epoch_nanos = stream.read_signed_varint()
        return DateTime(epoch_nanos)

    def read_numpy(self, stream: CodedInputStream) -> np.datetime64:
        epoch_nanos = stream.read_signed_varint()
        return np.datetime64(epoch_nanos, "ns")


# Shared module-level instance.
datetime_serializer = DateTimeSerializer()
805
+
806
+
807
class NoneSerializer(TypeSerializer[None, Any]):
    """Serializer for ``None``: writes nothing and consumes nothing."""

    def __init__(self) -> None:
        super().__init__(np.object_)

    def write(self, stream: CodedOutputStream, value: None) -> None:
        # Intentionally empty: no bytes are written for None.
        return None

    def write_numpy(self, stream: CodedOutputStream, value: Any) -> None:
        # Intentionally empty: no bytes are written for None.
        return None

    def read(self, stream: CodedInputStream) -> None:
        # No bytes are read from the stream.
        return None

    def read_numpy(self, stream: CodedInputStream) -> Any:
        # No bytes are read from the stream.
        return np.object_()


# Shared module-level instance.
none_serializer = NoneSerializer()
825
+
826
TEnum = TypeVar("TEnum", bound=Enum)


class EnumSerializer(Generic[TEnum, T, T_NP], TypeSerializer[TEnum, T_NP]):
    """Enums, stored through the serializer of their underlying integer."""

    def __init__(
        self, integer_serializer: TypeSerializer[T, T_NP], enum_type: type
    ) -> None:
        super().__init__(integer_serializer.overall_dtype())
        self._enum_type = enum_type
        self._integer_serializer = integer_serializer

    def is_trivially_serializable(self) -> bool:
        """Mirror the underlying integer serializer."""
        return self._integer_serializer.is_trivially_serializable()

    def write(self, stream: CodedOutputStream, value: TEnum) -> None:
        """Write the enum member's ``.value`` with the integer serializer."""
        raw = value.value
        self._integer_serializer.write(stream, raw)

    def write_numpy(self, stream: CodedOutputStream, value: T_NP) -> None:
        """Write a NumPy integer scalar with the integer serializer."""
        return self._integer_serializer.write_numpy(stream, value)

    def read(self, stream: CodedInputStream) -> TEnum:
        """Read an integer and convert it to the enum type."""
        raw = self._integer_serializer.read(stream)
        return self._enum_type(raw)

    def read_numpy(self, stream: CodedInputStream) -> T_NP:
        """Read the raw NumPy integer without converting to the enum type."""
        return self._integer_serializer.read_numpy(stream)
852
+
853
+
854
class OptionalSerializer(Generic[T, T_NP], TypeSerializer[Optional[T], np.void]):
    """Optional values, stored as a presence byte followed by the value.

    The NumPy representation is a structured record with the fields
    ``has_value`` and ``value``.
    """

    def __init__(self, element_serializer: TypeSerializer[T, T_NP]) -> None:
        record_dtype = np.dtype(
            [("has_value", np.bool_), ("value", element_serializer.overall_dtype())]
        )
        super().__init__(record_dtype)
        self._element_serializer = element_serializer
        # Zero-initialised record returned by read_numpy for absent values.
        self._none = cast(np.void, np.zeros((), dtype=self.overall_dtype())[()])

    def write(self, stream: CodedOutputStream, value: Optional[T]) -> None:
        stream.ensure_capacity(1)
        if value is None:
            stream.write_byte_no_check(0)
            return

        stream.write_byte_no_check(1)
        self._element_serializer.write(stream, value)

    def write_numpy(self, stream: CodedOutputStream, value: np.void) -> None:
        stream.ensure_capacity(1)
        if not value["has_value"]:
            stream.write_byte_no_check(0)
            return

        stream.write_byte_no_check(1)
        self._element_serializer.write_numpy(stream, value["value"])

    def read(self, stream: CodedInputStream) -> Optional[T]:
        present = stream.read_byte()
        if present == 0:
            return None
        return self._element_serializer.read(stream)

    def read_numpy(self, stream: CodedInputStream) -> np.void:
        present = stream.read_byte()
        if present == 0:
            return self._none
        # Returned as a (has_value, value) tuple matching the record fields.
        return cast(np.void, (True, self._element_serializer.read_numpy(stream)))

    def is_trivially_serializable(self) -> bool:
        return super().is_trivially_serializable()
896
+
897
+
898
class UnionCaseProtocol(Protocol):
    """Structural type of a union case: a case index and its payload."""

    # Position of the case; UnionSerializer uses it to compute the tag byte.
    index: int
    # Payload handed to the serializer registered for this case.
    value: Any
901
+
902
+
903
class UnionSerializer(TypeSerializer[T, np.object_]):
    """Tagged unions, stored as a one-byte case tag followed by the payload.

    Args:
        union_type: class that every non-``None`` union value must be an
            instance of.
        cases: one ``(case_type, serializer)`` pair per case. A leading
            ``None`` entry means the union admits ``None``, which is stored
            as tag ``0`` with no payload; the other tags are then shifted
            by one.
    """

    def __init__(
        self,
        union_type: type,
        cases: list[Optional[tuple[type, TypeSerializer[Any, Any]]]],
    ) -> None:
        super().__init__(np.object_)
        self._union_type = union_type
        self._cases = cases
        self._offset = 1 if cases[0] is None else 0

    def write(self, stream: CodedOutputStream, value: T) -> None:
        """Write the case tag and payload of ``value``.

        Raises:
            ValueError: if ``value`` is ``None`` but the union has no ``None``
                case, or if ``value`` is not an instance of the union type.
        """
        if value is None:
            if self._cases[0] is None:
                # Reserve room first: the unchecked byte write fails when the
                # buffer happens to be exactly full.
                stream.ensure_capacity(1)
                stream.write_byte_no_check(0)
                return
            else:
                raise ValueError("None is not a valid for this union type")

        if not isinstance(value, self._union_type):
            raise ValueError(
                f"Expected union value of type {self._union_type} but got {type(value)}"
            )

        union_value = cast(UnionCaseProtocol, value)

        tag_index = union_value.index + self._offset
        # Resolve the case before emitting anything, so an out-of-range index
        # does not leave a stray tag byte in the output.
        type_case = self._cases[tag_index]
        assert type_case is not None
        stream.ensure_capacity(1)
        stream.write_byte_no_check(tag_index)
        type_case[1].write(stream, union_value.value)

    def write_numpy(self, stream: CodedOutputStream, value: np.object_) -> None:
        """Write ``value`` exactly as ``write`` does."""
        self.write(stream, cast(T, value))

    def read(self, stream: CodedInputStream) -> T:
        """Read a tag byte and construct the matching case from its payload."""
        case_index = stream.read_byte()
        if case_index == 0 and self._offset == 1:
            return None  # type: ignore
        case_type, case_serializer = self._cases[case_index]  # type: ignore
        return case_type(case_serializer.read(stream))  # type: ignore

    def read_numpy(self, stream: CodedInputStream) -> np.object_:
        """Read a value exactly as ``read`` does."""
        return self.read(stream)  # type: ignore
948
+
949
+
950
class StreamSerializer(TypeSerializer[Iterable[T], Any]):
    """Serializes a stream of elements as a sequence of counted batches.

    Each batch is an unsigned-varint element count followed by that many
    elements. ``read`` keeps consuming batches until it meets a zero count;
    ``write`` never emits that terminating zero itself.
    """

    def __init__(self, element_serializer: TypeSerializer[T, T_NP]) -> None:
        super().__init__(np.object_)
        self._element_serializer = element_serializer

    def write(self, stream: CodedOutputStream, value: Iterable[T]) -> None:
        """Write ``value`` as one batch (non-empty list) or one batch per element."""
        # Note that the final 0 is missing and will be added before the next protocol step
        # or the protocol is closed.
        if isinstance(value, list) and len(value) > 0:
            stream.write_unsigned_varint(len(value))
            for element in value:
                self._element_serializer.write(stream, element)
        else:
            for element in value:
                # Each element is prefixed with a count byte of 1. Reserve room
                # first, as UnionSerializer.write does before write_byte_no_check,
                # so the byte cannot land past the end of a full buffer.
                stream.ensure_capacity(1)
                stream.write_byte_no_check(1)
                self._element_serializer.write(stream, element)

    def write_numpy(self, stream: CodedOutputStream, value: Any) -> None:
        """Not supported for streams."""
        raise NotImplementedError()

    def read(self, stream: CodedInputStream) -> Iterable[T]:
        """Lazily yield elements batch by batch until a zero count is read."""
        while (i := stream.read_unsigned_varint()) > 0:
            for _ in range(i):
                yield self._element_serializer.read(stream)

    def read_numpy(self, stream: CodedInputStream) -> np.object_:
        """Not supported for streams."""
        raise NotImplementedError()
977
+
978
+
979
class FixedVectorSerializer(Generic[T, T_NP], TypeSerializer[list[T], np.object_]):
    """Serializes a list of exactly ``length`` elements with no length prefix."""

    def __init__(
        self, element_serializer: TypeSerializer[T, T_NP], length: int
    ) -> None:
        # The overall dtype is a subarray of ``length`` elements.
        subarray_dtype = np.dtype((element_serializer.overall_dtype(), length))
        super().__init__(subarray_dtype)
        self.element_serializer = element_serializer
        self._length = length

    def write(self, stream: CodedOutputStream, value: list[T]) -> None:
        """Write every element in order.

        Raises:
            ValueError: if ``value`` does not have the configured length.
        """
        actual_length = len(value)
        if actual_length != self._length:
            raise ValueError(
                f"Expected a list of length {self._length}, got {actual_length}"
            )
        write_element = self.element_serializer.write
        for item in value:
            write_element(stream, item)

    def write_numpy(self, stream: CodedOutputStream, value: np.object_) -> None:
        """Never valid: fixed vectors are expected to appear as subarrays."""
        raise NotImplementedError("Internal error: expected this to be a subarray")

    def read(self, stream: CodedInputStream) -> list[T]:
        """Read exactly the configured number of elements into a list."""
        items: list[T] = []
        for _ in range(self._length):
            items.append(self.element_serializer.read(stream))
        return items

    def read_numpy(self, stream: CodedInputStream) -> np.object_:
        """Never valid: fixed vectors are expected to appear as subarrays."""
        raise NotImplementedError("Internal error: expected this to be a subarray")

    def is_trivially_serializable(self) -> bool:
        """Trivially serializable exactly when the element serializer is."""
        return self.element_serializer.is_trivially_serializable()
1006
+
1007
+
1008
class VectorSerializer(Generic[T, T_NP], TypeSerializer[list[T], np.object_]):
    """Serializes a variable-length list as a varint count followed by the elements."""

    def __init__(self, element_serializer: TypeSerializer[T, T_NP]) -> None:
        super().__init__(np.object_)
        self._element_serializer = element_serializer

    def write(self, stream: CodedOutputStream, value: list[T]) -> None:
        """Write the element count, then each element in order."""
        stream.write_unsigned_varint(len(value))
        write_element = self._element_serializer.write
        for item in value:
            write_element(stream, item)

    def write_numpy(self, stream: CodedOutputStream, value: np.object_) -> None:
        """Write a list held in an object array slot.

        Raises:
            ValueError: if ``value`` is not a ``list``.
        """
        if not isinstance(value, list):
            raise ValueError(f"Expected a list, got {type(value)}")

        # Same encoding as ``write`` once the runtime type is confirmed.
        self.write(stream, cast(list[T], value))

    def read(self, stream: CodedInputStream) -> list[T]:
        """Read the element count, then that many elements."""
        count = stream.read_unsigned_varint()
        items: list[T] = []
        for _ in range(count):
            items.append(self._element_serializer.read(stream))
        return items

    def read_numpy(self, stream: CodedInputStream) -> np.object_:
        """Read a list and pass it through ``np.object_`` for object-array storage."""
        decoded = self.read(stream)
        return np.object_(decoded)
1032
+
1033
+
1034
# Type parameters used by MapSerializer: the Python-side key/value types and
# their NumPy-side counterparts (the latter constrained to np.generic).
TKey = TypeVar("TKey")
TKey_NP = TypeVar("TKey_NP", bound=np.generic)
TValue = TypeVar("TValue")
TValue_NP = TypeVar("TValue_NP", bound=np.generic)
1038
+
1039
+
1040
class MapSerializer(
    Generic[TKey, TKey_NP, TValue, TValue_NP],
    TypeSerializer[dict[TKey, TValue], np.object_],
):
    """Serializes a dict as a varint entry count followed by key/value pairs."""

    def __init__(
        self,
        key_serializer: TypeSerializer[TKey, TKey_NP],
        value_serializer: TypeSerializer[TValue, TValue_NP],
    ) -> None:
        super().__init__(np.object_)
        self._key_serializer = key_serializer
        self._value_serializer = value_serializer

    def write(self, stream: CodedOutputStream, value: dict[TKey, TValue]) -> None:
        """Write the entry count, then each key immediately followed by its value."""
        stream.write_unsigned_varint(len(value))
        for key, item in value.items():
            self._key_serializer.write(stream, key)
            self._value_serializer.write(stream, item)

    def write_numpy(self, stream: CodedOutputStream, value: np.object_) -> None:
        """Write a dict held in an object array slot; same encoding as ``write``."""
        as_dict = cast(dict[TKey, TValue], value)
        self.write(stream, as_dict)

    def read(self, stream: CodedInputStream) -> dict[TKey, TValue]:
        """Read the entry count, then that many key/value pairs."""
        entry_count = stream.read_unsigned_varint()
        entries: dict[TKey, TValue] = {}
        for _ in range(entry_count):
            # The key precedes its value on the wire, so it must be read first.
            key = self._key_serializer.read(stream)
            entries[key] = self._value_serializer.read(stream)
        return entries

    def read_numpy(self, stream: CodedInputStream) -> np.object_:
        """Read a dict and pass it through ``np.object_`` for object-array storage."""
        decoded = self.read(stream)
        return np.object_(decoded)
1071
+
1072
+
1073
class NDArraySerializerBase(
    Generic[T, T_NP], TypeSerializer[npt.NDArray[Any], np.object_]
):
    """Shared element-data reading and writing for the NDArray serializers.

    Subclasses handle the shape; this base validates the dtype and moves the
    element data, either as one raw byte block or element by element.
    """

    def __init__(
        self,
        overall_dtype: npt.DTypeLike,
        element_serializer: TypeSerializer[T, T_NP],
        dtype: npt.DTypeLike,
    ) -> None:
        super().__init__(overall_dtype)
        self.element_serializer = element_serializer

        resolved_dtype = (
            dtype
            if isinstance(dtype, np.dtype)
            else np.dtype(dtype)  # pyright: ignore [reportUnknownArgumentType]
        )
        base_dtype, trailing_shape = (
            NDArraySerializerBase._get_dtype_and_subarray_shape(resolved_dtype)
        )
        self._array_dtype = base_dtype
        self._subarray_shape = trailing_shape

        if trailing_shape == ():
            # No subarray dimensions: represent that as None.
            self._subarray_shape = None
        elif isinstance(
            element_serializer, (FixedNDArraySerializer, FixedVectorSerializer)
        ):
            # The fixed-size wrapper's dimensions are now part of the array
            # shape, so serialize with the wrapped element serializer instead.
            self.element_serializer = cast(
                TypeSerializer[T, T_NP],
                element_serializer.element_serializer,  # pyright: ignore [reportUnknownMemberType]
            )

    @staticmethod
    def _get_dtype_and_subarray_shape(
        dtype: np.dtype[Any],
    ) -> tuple[np.dtype[Any], tuple[int, ...]]:
        """Peel nested subarray dtypes, returning the base dtype and the combined shape."""
        combined_shape: tuple[int, ...] = ()
        current = dtype
        while current.subdtype is not None:
            inner_dtype, inner_shape = current.subdtype
            # Outer dimensions come first, inner ones are appended after them.
            combined_shape = combined_shape + inner_shape
            current = inner_dtype
        return (current, combined_shape)

    def _write_data(self, stream: CodedOutputStream, value: npt.NDArray[Any]) -> None:
        """Validate the dtype of ``value`` and write its element data.

        Raises:
            ValueError: if the dtype is neither the expected dtype nor its
                packed (unaligned) equivalent.
        """
        if value.dtype != self._array_dtype:
            # see if it's the same dtype but packed, not aligned
            packed_dtype = recfunctions.repack_fields(self._array_dtype, align=False, recurse=True)  # type: ignore
            if packed_dtype != value.dtype:
                if packed_dtype == self._array_dtype:
                    accepted = f"{self._array_dtype}"
                else:
                    accepted = f"{self._array_dtype} or {packed_dtype}"
                raise ValueError(f"Expected dtype {accepted}, got {value.dtype}")

        if not self._is_current_array_trivially_serializable(value):
            write_element = self.element_serializer.write_numpy
            for item in value.flat:
                write_element(stream, item)
            return

        stream.write_bytes_directly(value.data)

    def _read_data(
        self, stream: CodedInputStream, shape: tuple[int, ...]
    ) -> npt.NDArray[Any]:
        """Read the elements for an array of ``shape`` and return it reshaped."""
        element_count = int(np.prod(shape))  # type: ignore

        if self.element_serializer.is_trivially_serializable():
            # Raw path: one byte block reinterpreted with the array dtype.
            byte_count = element_count * self._array_dtype.itemsize
            raw = stream.read_bytearray(byte_count)
            return np.frombuffer(raw, dtype=self._array_dtype).reshape(shape)

        flat: npt.NDArray[T_NP] = np.empty((element_count,), dtype=self._array_dtype)
        for position in range(element_count):
            flat[position] = self.element_serializer.read_numpy(stream)

        return flat.reshape(shape)

    def _is_current_array_trivially_serializable(self, value: npt.NDArray[Any]) -> bool:
        """Decide whether ``value`` can be written as one raw byte block."""
        if not self.element_serializer.is_trivially_serializable():
            return False
        if not value.flags.c_contiguous:
            return False
        field_map = self._array_dtype.fields
        return field_map is None or all(name != "" for name in field_map)
1156
+
1157
+
1158
class DynamicNDArraySerializer(NDArraySerializerBase[T, T_NP]):
    """Serializes arrays whose number of dimensions is stored in the stream.

    The encoding is the dimension count, each dimension, then the element
    data. Trailing dimensions that come from a subarray element dtype are
    implied and are not written.
    """

    def __init__(
        self,
        element_serializer: TypeSerializer[T, T_NP],
    ) -> None:
        element_dtype = element_serializer.overall_dtype()
        super().__init__(np.object_, element_serializer, element_dtype)

    def write(self, stream: CodedOutputStream, value: npt.NDArray[Any]) -> None:
        """Write the dimension count, the explicit dimensions and the data.

        Raises:
            ValueError: if the trailing dimensions of ``value`` do not match
                the subarray shape of the element dtype.
        """
        trailing = self._subarray_shape
        if trailing is None:
            explicit_dims = value.shape
        else:
            trailing_count = len(trailing)
            if (
                len(value.shape) < trailing_count
                or value.shape[-trailing_count:] != trailing
            ):
                required = ", ".join(str(i) for i in trailing)
                raise ValueError(
                    f"The array is required to have shape (..., {required})"
                )
            explicit_dims = value.shape[:-trailing_count]

        stream.write_unsigned_varint(len(explicit_dims))
        for extent in explicit_dims:
            stream.write_unsigned_varint(extent)

        self._write_data(stream, value)

    def write_numpy(self, stream: CodedOutputStream, value: np.object_) -> None:
        """Write an array held in an object array slot; same encoding as ``write``."""
        as_array = cast(npt.NDArray[Any], value)
        self.write(stream, as_array)

    def read(self, stream: CodedInputStream) -> npt.NDArray[Any]:
        """Read the dimension count and dimensions, then the element data."""
        dim_count = stream.read_unsigned_varint()
        full_shape = tuple(stream.read_unsigned_varint() for _ in range(dim_count))
        if self._subarray_shape is not None:
            # Implied subarray dimensions follow the explicit ones.
            full_shape = full_shape + self._subarray_shape
        return self._read_data(stream, full_shape)

    def read_numpy(self, stream: CodedInputStream) -> np.object_:
        """Read an array for storage in an object array slot."""
        decoded = self.read(stream)
        return cast(np.object_, decoded)
1202
+
1203
+
1204
class NDArraySerializer(Generic[T, T_NP], NDArraySerializerBase[T, T_NP]):
    """Serializes arrays with a fixed number of dimensions but variable extents.

    Only the ``ndims`` explicit dimensions are written before the element
    data. Trailing dimensions that come from a subarray element dtype are
    implied and are not written.
    """

    def __init__(
        self,
        element_serializer: TypeSerializer[T, T_NP],
        ndims: int,
    ) -> None:
        element_dtype = element_serializer.overall_dtype()
        super().__init__(np.object_, element_serializer, element_dtype)
        self._ndims = ndims

    def write(self, stream: CodedOutputStream, value: npt.NDArray[Any]) -> None:
        """Write the explicit dimensions followed by the element data.

        Raises:
            ValueError: if ``value`` has the wrong number of dimensions or its
                trailing dimensions do not match the subarray shape.
        """
        trailing = self._subarray_shape
        if trailing is None:
            if value.ndim != self._ndims:
                raise ValueError(f"Expected {self._ndims} dimensions, got {value.ndim}")
            explicit_dims = value.shape
        else:
            trailing_count = len(trailing)
            total_dims = trailing_count + self._ndims
            if value.ndim != total_dims:
                raise ValueError(f"Expected {total_dims} dimensions, got {value.ndim}")

            if value.shape[-trailing_count:] != trailing:
                required = ", ".join(str(i) for i in trailing)
                raise ValueError(
                    f"The array is required to have shape (..., {required})"
                )
            explicit_dims = value.shape[:-trailing_count]

        for extent in explicit_dims:
            stream.write_unsigned_varint(extent)

        self._write_data(stream, value)

    def write_numpy(self, stream: CodedOutputStream, value: np.object_) -> None:
        """Write an array held in an object array slot; same encoding as ``write``."""
        as_array = cast(npt.NDArray[Any], value)
        self.write(stream, as_array)

    def read(self, stream: CodedInputStream) -> npt.NDArray[Any]:
        """Read ``ndims`` dimensions, then the element data."""
        full_shape = tuple(
            stream.read_unsigned_varint() for _ in range(self._ndims)
        )
        trailing = self._subarray_shape
        if trailing is not None:
            # Implied subarray dimensions follow the explicit ones.
            full_shape = full_shape + trailing

        return self._read_data(stream, full_shape)

    def read_numpy(self, stream: CodedInputStream) -> np.object_:
        """Read an array for storage in an object array slot."""
        decoded = self.read(stream)
        return cast(np.object_, decoded)
1249
+
1250
+
1251
class FixedNDArraySerializer(Generic[T, T_NP], NDArraySerializerBase[T, T_NP]):
    """Serializes arrays of a fixed shape; no dimensions are written to the stream."""

    def __init__(
        self,
        element_serializer: TypeSerializer[T, T_NP],
        shape: tuple[int, ...],
    ) -> None:
        element_dtype = element_serializer.overall_dtype()
        # The overall dtype is a subarray of the element dtype with ``shape``.
        overall = np.dtype((element_dtype, shape))
        super().__init__(overall, element_serializer, element_dtype)
        self._shape = shape

    def write(self, stream: CodedOutputStream, value: npt.NDArray[Any]) -> None:
        """Write the element data of ``value``.

        Raises:
            ValueError: if ``value`` does not have exactly the required shape.
        """
        required_shape = self._shape
        if self._subarray_shape is not None:
            # Element subarray dimensions trail the fixed dimensions.
            required_shape = self._shape + self._subarray_shape

        if value.shape != required_shape:
            raise ValueError(f"Expected shape {required_shape}, got {value.shape}")

        self._write_data(stream, value)

    def write_numpy(self, stream: CodedOutputStream, value: np.object_) -> None:
        """Write an array held in an object array slot; same encoding as ``write``."""
        as_array = cast(npt.NDArray[Any], value)
        self.write(stream, as_array)

    def read(self, stream: CodedInputStream) -> npt.NDArray[Any]:
        """Read the element data for the fixed shape plus any subarray dimensions."""
        full_shape = self._shape
        if self._subarray_shape is not None:
            full_shape = self._shape + self._subarray_shape
        return self._read_data(stream, full_shape)

    def read_numpy(self, stream: CodedInputStream) -> np.object_:
        """Read an array for storage in an object array slot."""
        decoded = self.read(stream)
        return cast(np.object_, decoded)

    def is_trivially_serializable(self) -> bool:
        """Trivially serializable exactly when the element serializer is."""
        return self.element_serializer.is_trivially_serializable()
1288
+
1289
+
1290
class RecordSerializer(TypeSerializer[T, np.void]):
    """Base serializer for records made of named fields.

    The overall dtype is an aligned structured dtype with one entry per field.
    The ``_write`` and ``_read`` helpers process the fields in declaration order.
    """

    def __init__(
        self, field_serializers: list[Tuple[str, TypeSerializer[Any, Any]]]
    ) -> None:
        field_dtypes = [
            (field_name, field_serializer.overall_dtype())
            for field_name, field_serializer in field_serializers
        ]
        super().__init__(np.dtype(field_dtypes, align=True))

        self._field_serializers = field_serializers

    def is_trivially_serializable(self) -> bool:
        """Trivially serializable only when every field serializer is."""
        for _, field_serializer in self._field_serializers:
            if not field_serializer.is_trivially_serializable():
                return False
        return True

    def _write(self, stream: CodedOutputStream, *values: Any) -> None:
        """Write ``values`` positionally, one per field, in declaration order."""
        for position, field in enumerate(self._field_serializers):
            field_serializer = field[1]
            field_serializer.write(stream, values[position])

    def _read(self, stream: CodedInputStream) -> tuple[Any, ...]:
        """Read every field in declaration order and return the values as a tuple."""
        field_values = []
        for _, field_serializer in self._field_serializers:
            field_values.append(field_serializer.read(stream))
        return tuple(field_values)

    def read_numpy(self, stream: CodedInputStream) -> np.void:
        """Read the record as a plain tuple of field values, typed as ``np.void``."""
        field_values = self._read(stream)
        return cast(np.void, field_values)
1323
+
1324
+
1325
# Only used in the header
# "<i" is a little-endian signed int with the standard 4-byte size; the
# assertion below guards that size.
int32_struct = struct.Struct("<i")
assert int32_struct.size == 4
1328
+
1329
+
1330
def write_fixed_int32(stream: CodedOutputStream, value: int) -> None:
    """Write ``value`` to ``stream`` using the fixed-width ``int32_struct`` layout.

    Raises:
        ValueError: if ``value`` is outside the signed 32-bit range.
    """
    below_range = value < INT32_MIN
    if below_range or value > INT32_MAX:
        raise ValueError(
            f"Value {value} is outside the range of a signed 32-bit integer"
        )
    stream.write(int32_struct, value)
1336
+
1337
+
1338
def read_fixed_int32(stream: CodedInputStream) -> int:
    """Read one value laid out as ``int32_struct`` and return it."""
    unpacked = stream.read(int32_struct)
    return unpacked[0]