mrd-python 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
mrd/_ndjson.py ADDED
@@ -0,0 +1,1194 @@
1
+ # Copyright (c) Microsoft Corporation.
2
+ # Licensed under the MIT License.
3
+
4
+ # pyright: reportUnnecessaryIsInstance=false
5
+ # pyright: reportUnknownArgumentType=false
6
+ # pyright: reportUnknownVariableType=false
7
+
8
+ from abc import ABC, abstractmethod
9
+ import datetime
10
+ from enum import IntFlag
11
+ import io
12
+ import json
13
+ from typing import Any, Generic, Optional, TextIO, TypeVar, Union, cast
14
+
15
+ import numpy as np
16
+ import numpy.typing as npt
17
+ from numpy.lib import recfunctions
18
+
19
+ from .yardl_types import *
20
+
21
# Version of the NDJSON framing recorded in, and required of, the header line.
CURRENT_NDJSON_FORMAT_VERSION: int = 1

# Inclusive bounds of the fixed-width integer types, taken from numpy so they
# agree exactly with the numpy dtypes used by the integer converters.
INT8_MIN: int = int(np.iinfo("int8").min)
INT8_MAX: int = int(np.iinfo("int8").max)

UINT8_MAX: int = int(np.iinfo("uint8").max)

INT16_MIN: int = int(np.iinfo("int16").min)
INT16_MAX: int = int(np.iinfo("int16").max)

UINT16_MAX: int = int(np.iinfo("uint16").max)

INT32_MIN: int = int(np.iinfo("int32").min)
INT32_MAX: int = int(np.iinfo("int32").max)

UINT32_MAX: int = int(np.iinfo("uint32").max)

INT64_MIN: int = int(np.iinfo("int64").min)
INT64_MAX: int = int(np.iinfo("int64").max)

UINT64_MAX: int = int(np.iinfo("uint64").max)

# Unique marker meaning "no value for this protocol step", distinct from a
# JSON null (None) that may legitimately be present.
MISSING_SENTINEL = object()
44
+
45
+
46
class NDJsonProtocolWriter(ABC):
    """Base class for protocol writers that emit Yardl NDJSON.

    The first line written is a header object of the form
    ``{"yardl": {"version": ..., "schema": ...}}``. Every later value is
    written by ``_write_json_line`` as one compact JSON document per line.
    """

    def __init__(self, stream: Union[TextIO, str], schema: str) -> None:
        """Set up the output and write the header line.

        Args:
            stream: A writable text stream, or a file path. A path is opened
                for writing as UTF-8, and this writer owns and closes it.
            schema: The protocol schema as a JSON string. It is embedded in
                the header as a parsed JSON value.

        Raises:
            json.JSONDecodeError: If ``schema`` is not valid JSON.
        """
        # Parse the schema before touching the filesystem. A bad schema then
        # neither truncates the target file nor leaks an open handle.
        schema_json = json.loads(schema)

        if isinstance(stream, str):
            self._stream = open(stream, "w", encoding="utf-8")
            self._owns_stream = True
        else:
            self._stream = stream
            self._owns_stream = False

        try:
            self._write_json_line(
                {
                    "yardl": {
                        "version": CURRENT_NDJSON_FORMAT_VERSION,
                        "schema": schema_json,
                    },
                },
            )
        except BaseException:
            # Do not leak a file we opened ourselves if the header write fails.
            self._close()
            raise

    def _close(self) -> None:
        """Close the underlying stream, but only if this writer opened it."""
        if self._owns_stream:
            self._stream.close()

    def _end_stream(self) -> None:
        """Hook for end-of-stream handling; NDJSON needs no trailer."""
        pass

    def _write_json_line(self, value: object) -> None:
        """Write ``value`` as compact JSON followed by a newline."""
        json.dump(
            value,
            self._stream,
            ensure_ascii=False,
            separators=(",", ":"),
            check_circular=False,
        )
        self._stream.write("\n")
80
+
81
+
82
class NDJsonProtocolReader:
    """Base class for protocol readers that consume Yardl NDJSON.

    The constructor reads and validates the header line. Later lines are
    consumed one protocol step at a time by ``_read_json_line``.
    """

    def __init__(
        self, stream: Union[io.BufferedReader, TextIO, str], schema: str
    ) -> None:
        """Set up the input and validate its header line.

        Args:
            stream: A readable text or binary stream, or a file path. A path
                is opened for reading as UTF-8, and this reader owns and
                closes it.
            schema: The expected protocol schema as a JSON string. It must
                equal the schema recorded in the header.

        Raises:
            ValueError: If the header line is missing or malformed, if its
                version is not ``CURRENT_NDJSON_FORMAT_VERSION``, or if its
                schema differs from ``schema``.
        """
        if isinstance(stream, str):
            self._stream = open(stream, "r", encoding="utf-8")
            self._owns_stream = True
        else:
            self._stream = stream
            self._owns_stream = False

        # Holds a parsed line that was read while probing for an optional step
        # that turned out to be absent, so that a later call can consume it.
        self._unused_value: Optional[dict[str, object]] = None

        try:
            line = self._stream.readline()
            try:
                header_json = json.loads(line)
            except ValueError as e:
                # Covers JSONDecodeError and UnicodeDecodeError (binary input).
                raise ValueError(
                    "Data in the stream is not in the expected Yardl NDJSON format."
                ) from e

            if not isinstance(header_json, dict) or "yardl" not in header_json:
                raise ValueError(
                    "Data in the stream is not in the expected Yardl NDJSON format."
                )

            header_json = header_json["yardl"]
            if not isinstance(header_json, dict):
                raise ValueError(
                    "Data in the stream is not in the expected Yardl NDJSON format."
                )

            if (
                header_json.get("version")  # pyright: ignore [reportUnknownMemberType]
                != CURRENT_NDJSON_FORMAT_VERSION
            ):
                raise ValueError("Unsupported yardl version.")

            if header_json.get(  # pyright: ignore [reportUnknownMemberType]
                "schema"
            ) != json.loads(schema):
                raise ValueError(
                    "The schema of the data to be read is not compatible with the current protocol."
                )
        except BaseException:
            # Do not leak a file we opened ourselves when validation fails.
            self._close()
            raise

    def _close(self) -> None:
        """Close the underlying stream, but only if this reader opened it."""
        if self._owns_stream:
            self._stream.close()

    def _read_json_line(self, stepName: str, required: bool) -> object:
        """Return the JSON value stored under ``stepName`` in the next line.

        Each data line is a JSON object keyed by step name. Suppose the next
        line has no ``stepName`` key and ``required`` is False. That parsed
        line is kept for a later call, and ``MISSING_SENTINEL`` is returned.

        Raises:
            ValueError: If ``required`` is True and the step is not found or
                EOF is reached, or if a data line is not a JSON object.
        """
        missing = MISSING_SENTINEL
        if self._unused_value is not None:
            value = self._unused_value.get(stepName, missing)
            if value is not missing:
                self._unused_value = None
                return value
            if required:
                raise ValueError(f"Expected protocol step '{stepName}' not found.")
            return missing

        line = self._stream.readline()
        # Test falsiness rather than `== ""`: binary streams return b"" at EOF.
        if not line:
            if not required:
                return missing
            raise ValueError(
                f"Encountered EOF but expected to find protocol step '{stepName}'."
            )

        json_object = json.loads(line)
        if not isinstance(json_object, dict):
            raise ValueError(
                f"Expected a JSON object while reading protocol step '{stepName}'."
            )

        value = json_object.get(stepName, missing)
        if value is not missing:
            return value

        if not required:
            self._unused_value = json_object
            return missing

        raise ValueError(f"Expected protocol step '{stepName}' not found.")
158
+
159
+
160
T = TypeVar("T")
T_NP = TypeVar("T_NP", bound=np.generic)


class JsonConverter(Generic[T, T_NP], ABC):
    """Abstract converter between a Python value, its numpy form, and JSON.

    ``T`` is the Python-side type. ``T_NP`` is the numpy-side type used when
    the value lives inside numpy data whose dtype is ``overall_dtype()``.
    """

    def __init__(self, dtype: npt.DTypeLike) -> None:
        normalized: np.dtype[Any] = np.dtype(dtype)
        self._dtype: np.dtype[Any] = normalized

    def overall_dtype(self) -> np.dtype[Any]:
        """Return the numpy dtype this converter stores values as."""
        return self._dtype

    @abstractmethod
    def to_json(self, value: T) -> object:
        """Convert a Python-side value into a JSON-serializable object."""
        raise NotImplementedError()

    @abstractmethod
    def numpy_to_json(self, value: T_NP) -> object:
        """Convert a numpy-side value into a JSON-serializable object."""
        raise NotImplementedError()

    @abstractmethod
    def from_json(self, json_object: object) -> T:
        """Convert a decoded JSON object into a Python-side value."""
        raise NotImplementedError()

    @abstractmethod
    def from_json_to_numpy(self, json_object: object) -> T_NP:
        """Convert a decoded JSON object into a numpy-side value."""
        raise NotImplementedError()

    def supports_none(self) -> bool:
        """Return whether ``None`` is an acceptable value (False by default)."""
        return False
189
+
190
+
191
class BoolConverter(JsonConverter[bool, np.bool_]):
    """Converter between ``bool`` / ``np.bool_`` and JSON booleans."""

    def __init__(self) -> None:
        super().__init__(np.bool_)

    def to_json(self, value: bool) -> object:
        """Return ``value`` unchanged once it is confirmed to be a ``bool``.

        Raises:
            TypeError: If ``value`` is not a ``bool``.
        """
        if isinstance(value, bool):
            return value
        raise TypeError(f"Expected a bool but got {type(value)}")

    def numpy_to_json(self, value: np.bool_) -> object:
        # Unwrap the numpy scalar to a plain Python bool before json.dump.
        return bool(value)

    def from_json(self, json_object: object) -> bool:
        # The decoded value is coerced by truthiness; no validation is done.
        return bool(json_object)

    def from_json_to_numpy(self, json_object: object) -> np.bool_:
        return np.bool_(json_object)


bool_converter = BoolConverter()
212
+
213
+
214
class Int8Converter(JsonConverter[int, np.int8]):
    """Converter for signed 8-bit integers (numpy form ``np.int8``)."""

    def __init__(self) -> None:
        super().__init__(np.int8)

    def to_json(self, value: int) -> object:
        """Return ``value`` after checking it is an int in [INT8_MIN, INT8_MAX].

        Raises:
            ValueError: If ``value`` is not an int or is out of range.
        """
        if isinstance(value, int):
            if INT8_MIN <= value <= INT8_MAX:
                return value
            raise ValueError(
                f"Value {value} is outside the range of a signed 8-bit integer"
            )
        raise ValueError(f"Value in not a signed 8-bit integer: {value}")

    def numpy_to_json(self, value: np.int8) -> object:
        # Unwrap the numpy scalar to a plain Python int for json.dump.
        return int(value)

    def from_json(self, json_object: object) -> int:
        # The decoded JSON value is returned as-is (no validation).
        return cast(int, json_object)

    def from_json_to_numpy(self, json_object: object) -> np.int8:
        return np.int8(self.from_json(json_object))


int8_converter = Int8Converter()
239
+
240
+
241
class UInt8Converter(JsonConverter[int, np.uint8]):
    """Converter for unsigned 8-bit integers (numpy form ``np.uint8``)."""

    def __init__(self) -> None:
        super().__init__(np.uint8)

    def to_json(self, value: int) -> object:
        """Return ``value`` after checking it is an int in [0, UINT8_MAX].

        Raises:
            ValueError: If ``value`` is not an int or is out of range.
        """
        if isinstance(value, int):
            if 0 <= value <= UINT8_MAX:
                return value
            raise ValueError(
                f"Value {value} is outside the range of an unsigned 8-bit integer"
            )
        raise ValueError(f"Value in not an unsigned 8-bit integer: {value}")

    def numpy_to_json(self, value: np.uint8) -> object:
        # Unwrap the numpy scalar to a plain Python int for json.dump.
        return int(value)

    def from_json(self, json_object: object) -> int:
        # The decoded JSON value is returned as-is (no validation).
        return cast(int, json_object)

    def from_json_to_numpy(self, json_object: object) -> np.uint8:
        return np.uint8(self.from_json(json_object))


uint8_converter = UInt8Converter()
266
+
267
+
268
class Int16Converter(JsonConverter[int, np.int16]):
    """Converter for signed 16-bit integers (numpy form ``np.int16``)."""

    def __init__(self) -> None:
        super().__init__(np.int16)

    def to_json(self, value: int) -> object:
        """Return ``value`` after checking it is an int in [INT16_MIN, INT16_MAX].

        Raises:
            ValueError: If ``value`` is not an int or is out of range.
        """
        if isinstance(value, int):
            if INT16_MIN <= value <= INT16_MAX:
                return value
            raise ValueError(
                f"Value {value} is outside the range of a signed 16-bit integer"
            )
        raise ValueError(f"Value in not a signed 16-bit integer: {value}")

    def numpy_to_json(self, value: np.int16) -> object:
        # Unwrap the numpy scalar to a plain Python int for json.dump.
        return int(value)

    def from_json(self, json_object: object) -> int:
        # The decoded JSON value is returned as-is (no validation).
        return cast(int, json_object)

    def from_json_to_numpy(self, json_object: object) -> np.int16:
        return np.int16(self.from_json(json_object))


int16_converter = Int16Converter()
293
+
294
+
295
class UInt16Converter(JsonConverter[int, np.uint16]):
    """Converter for unsigned 16-bit integers (numpy form ``np.uint16``)."""

    def __init__(self) -> None:
        super().__init__(np.uint16)

    def to_json(self, value: int) -> object:
        """Return ``value`` after checking it is an int in [0, UINT16_MAX].

        Raises:
            ValueError: If ``value`` is not an int or is out of range.
        """
        if isinstance(value, int):
            if 0 <= value <= UINT16_MAX:
                return value
            raise ValueError(
                f"Value {value} is outside the range of an unsigned 16-bit integer"
            )
        raise ValueError(f"Value in not an unsigned 16-bit integer: {value}")

    def numpy_to_json(self, value: np.uint16) -> object:
        # Unwrap the numpy scalar to a plain Python int for json.dump.
        return int(value)

    def from_json(self, json_object: object) -> int:
        # The decoded JSON value is returned as-is (no validation).
        return cast(int, json_object)

    def from_json_to_numpy(self, json_object: object) -> np.uint16:
        return np.uint16(self.from_json(json_object))


uint16_converter = UInt16Converter()
320
+
321
+
322
class Int32Converter(JsonConverter[int, np.int32]):
    """Converter for signed 32-bit integers (numpy form ``np.int32``)."""

    def __init__(self) -> None:
        super().__init__(np.int32)

    def to_json(self, value: int) -> object:
        """Return ``value`` after checking it is an int in [INT32_MIN, INT32_MAX].

        Raises:
            ValueError: If ``value`` is not an int or is out of range.
        """
        if isinstance(value, int):
            if INT32_MIN <= value <= INT32_MAX:
                return value
            raise ValueError(
                f"Value {value} is outside the range of a signed 32-bit integer"
            )
        raise ValueError(f"Value in not a signed 32-bit integer: {value}")

    def numpy_to_json(self, value: np.int32) -> object:
        # Unwrap the numpy scalar to a plain Python int for json.dump.
        return int(value)

    def from_json(self, json_object: object) -> int:
        # The decoded JSON value is returned as-is (no validation).
        return cast(int, json_object)

    def from_json_to_numpy(self, json_object: object) -> np.int32:
        return np.int32(self.from_json(json_object))


int32_converter = Int32Converter()
347
+
348
+
349
class UInt32Converter(JsonConverter[int, np.uint32]):
    """Converter for unsigned 32-bit integers (numpy form ``np.uint32``)."""

    def __init__(self) -> None:
        super().__init__(np.uint32)

    def to_json(self, value: int) -> object:
        """Return ``value`` after checking it is an int in [0, UINT32_MAX].

        Raises:
            ValueError: If ``value`` is not an int or is out of range.
        """
        if isinstance(value, int):
            if 0 <= value <= UINT32_MAX:
                return value
            raise ValueError(
                f"Value {value} is outside the range of an unsigned 32-bit integer"
            )
        raise ValueError(f"Value in not an unsigned 32-bit integer: {value}")

    def numpy_to_json(self, value: np.uint32) -> object:
        # Unwrap the numpy scalar to a plain Python int for json.dump.
        return int(value)

    def from_json(self, json_object: object) -> int:
        # The decoded JSON value is returned as-is (no validation).
        return cast(int, json_object)

    def from_json_to_numpy(self, json_object: object) -> np.uint32:
        return np.uint32(self.from_json(json_object))


uint32_converter = UInt32Converter()
374
+
375
+
376
class Int64Converter(JsonConverter[int, np.int64]):
    """Converter for signed 64-bit integers (numpy form ``np.int64``)."""

    def __init__(self) -> None:
        super().__init__(np.int64)

    def to_json(self, value: int) -> object:
        """Return ``value`` after checking it is an int in [INT64_MIN, INT64_MAX].

        Raises:
            ValueError: If ``value`` is not an int or is out of range.
        """
        if isinstance(value, int):
            if INT64_MIN <= value <= INT64_MAX:
                return value
            raise ValueError(
                f"Value {value} is outside the range of a signed 64-bit integer"
            )
        raise ValueError(f"Value in not a signed 64-bit integer: {value}")

    def numpy_to_json(self, value: np.int64) -> object:
        # Unwrap the numpy scalar to a plain Python int for json.dump.
        return int(value)

    def from_json(self, json_object: object) -> int:
        # The decoded JSON value is returned as-is (no validation).
        return cast(int, json_object)

    def from_json_to_numpy(self, json_object: object) -> np.int64:
        return np.int64(self.from_json(json_object))


int64_converter = Int64Converter()
401
+
402
+
403
class UInt64Converter(JsonConverter[int, np.uint64]):
    """Converter for unsigned 64-bit integers (numpy form ``np.uint64``)."""

    def __init__(self) -> None:
        super().__init__(np.uint64)

    def to_json(self, value: int) -> object:
        """Return ``value`` after checking it is an int in [0, UINT64_MAX].

        Raises:
            ValueError: If ``value`` is not an int or is out of range.
        """
        if isinstance(value, int):
            if 0 <= value <= UINT64_MAX:
                return value
            raise ValueError(
                f"Value {value} is outside the range of an unsigned 64-bit integer"
            )
        raise ValueError(f"Value in not an unsigned 64-bit integer: {value}")

    def numpy_to_json(self, value: np.uint64) -> object:
        # Unwrap the numpy scalar to a plain Python int for json.dump.
        return int(value)

    def from_json(self, json_object: object) -> int:
        # The decoded JSON value is returned as-is (no validation).
        return cast(int, json_object)

    def from_json_to_numpy(self, json_object: object) -> np.uint64:
        return np.uint64(self.from_json(json_object))


uint64_converter = UInt64Converter()
428
+
429
+
430
class SizeConverter(JsonConverter[int, np.uint64]):
    """Converter for sizes, which are represented as unsigned 64-bit integers."""

    def __init__(self) -> None:
        super().__init__(np.uint64)

    def to_json(self, value: int) -> object:
        """Return ``value`` after checking it is an int in [0, UINT64_MAX].

        Raises:
            ValueError: If ``value`` is not an int or is out of range.
        """
        if isinstance(value, int):
            if 0 <= value <= UINT64_MAX:
                return value
            raise ValueError(
                f"Value {value} is outside the range of an unsigned 64-bit integer"
            )
        raise ValueError(f"Value in not an unsigned 64-bit integer: {value}")

    def numpy_to_json(self, value: np.uint64) -> object:
        # Unwrap the numpy scalar to a plain Python int for json.dump.
        return int(value)

    def from_json(self, json_object: object) -> int:
        # The decoded JSON value is returned as-is (no validation).
        return cast(int, json_object)

    def from_json_to_numpy(self, json_object: object) -> np.uint64:
        return np.uint64(self.from_json(json_object))


size_converter = SizeConverter()
455
+
456
+
457
class Float32Converter(JsonConverter[float, np.float32]):
    """Converter for 32-bit floating-point values (numpy form ``np.float32``)."""

    def __init__(self) -> None:
        super().__init__(np.float32)

    def to_json(self, value: float) -> object:
        """Return ``value`` as a Python float.

        A plain ``int`` (but not a ``bool``) is accepted and widened to
        ``float``, because ``int`` is acceptable wherever ``float`` is
        annotated (PEP 484 numeric tower).

        Raises:
            ValueError: If ``value`` is neither a float nor a non-bool int.
        """
        if isinstance(value, float):
            return value
        if isinstance(value, int) and not isinstance(value, bool):
            return float(value)
        raise ValueError(f"Value in not a 32-bit float: {value}")

    def numpy_to_json(self, value: np.float32) -> object:
        # Unwrap the numpy scalar to a plain Python float for json.dump.
        return float(value)

    def from_json(self, json_object: object) -> float:
        return cast(float, json_object)

    def from_json_to_numpy(self, json_object: object) -> np.float32:
        return np.float32(cast(float, json_object))


float32_converter = Float32Converter()
478
+
479
+
480
class Float64Converter(JsonConverter[float, np.float64]):
    """Converter for 64-bit floating-point values (numpy form ``np.float64``)."""

    def __init__(self) -> None:
        super().__init__(np.float64)

    def to_json(self, value: float) -> object:
        """Return ``value`` as a Python float.

        A plain ``int`` (but not a ``bool``) is accepted and widened to
        ``float``, because ``int`` is acceptable wherever ``float`` is
        annotated (PEP 484 numeric tower).

        Raises:
            ValueError: If ``value`` is neither a float nor a non-bool int.
        """
        if isinstance(value, float):
            return value
        if isinstance(value, int) and not isinstance(value, bool):
            return float(value)
        raise ValueError(f"Value in not a 64-bit float: {value}")

    def numpy_to_json(self, value: np.float64) -> object:
        # Unwrap the numpy scalar to a plain Python float for json.dump.
        return float(value)

    def from_json(self, json_object: object) -> float:
        return cast(float, json_object)

    def from_json_to_numpy(self, json_object: object) -> np.float64:
        return np.float64(cast(float, json_object))


float64_converter = Float64Converter()
501
+
502
+
503
class Complex32Converter(JsonConverter[complex, np.complex64]):
    """Converter for complex values with float32 parts (numpy ``np.complex64``).

    The JSON form is a two-element list ``[real, imag]``.
    """

    def __init__(self) -> None:
        super().__init__(np.complex64)

    def to_json(self, value: complex) -> object:
        """Return ``[value.real, value.imag]``.

        Real numbers (``int`` but not ``bool``, and ``float``) are accepted and
        widened to ``complex``, following the PEP 484 numeric tower.

        Raises:
            ValueError: If ``value`` is not a complex or real number.
        """
        if isinstance(value, (int, float)) and not isinstance(value, bool):
            value = complex(value)
        if not isinstance(value, complex):
            raise ValueError(f"Value in not a 32-bit complex value: {value}")

        return [value.real, value.imag]

    def numpy_to_json(self, value: np.complex64) -> object:
        return [float(value.real), float(value.imag)]

    def from_json(self, json_object: object) -> complex:
        """Build a complex from a ``[real, imag]`` JSON list.

        Raises:
            ValueError: If ``json_object`` is not a list of length 2.
        """
        if not isinstance(json_object, list) or len(json_object) != 2:
            raise ValueError("Expected a list of two floating-point numbers.")

        return complex(json_object[0], json_object[1])

    def from_json_to_numpy(self, json_object: object) -> np.complex64:
        return np.complex64(self.from_json(json_object))


complexfloat32_converter = Complex32Converter()
527
+
528
+
529
class Complex64Converter(JsonConverter[complex, np.complex128]):
    """Converter for complex values with float64 parts (numpy ``np.complex128``).

    The JSON form is a two-element list ``[real, imag]``.
    """

    def __init__(self) -> None:
        super().__init__(np.complex128)

    def to_json(self, value: complex) -> object:
        """Return ``[value.real, value.imag]``.

        Real numbers (``int`` but not ``bool``, and ``float``) are accepted and
        widened to ``complex``, following the PEP 484 numeric tower.

        Raises:
            ValueError: If ``value`` is not a complex or real number.
        """
        if isinstance(value, (int, float)) and not isinstance(value, bool):
            value = complex(value)
        if not isinstance(value, complex):
            raise ValueError(f"Value in not a 64-bit complex value: {value}")

        return [value.real, value.imag]

    def numpy_to_json(self, value: np.complex128) -> object:
        return [float(value.real), float(value.imag)]

    def from_json(self, json_object: object) -> complex:
        """Build a complex from a ``[real, imag]`` JSON list.

        Raises:
            ValueError: If ``json_object`` is not a list of length 2.
        """
        if not isinstance(json_object, list) or len(json_object) != 2:
            raise ValueError("Expected a list of two floating-point numbers.")

        return complex(json_object[0], json_object[1])

    def from_json_to_numpy(self, json_object: object) -> np.complex128:
        return np.complex128(self.from_json(json_object))


complexfloat64_converter = Complex64Converter()
553
+
554
+
555
class StringConverter(JsonConverter[str, np.object_]):
    """Converter for strings, stored with ``np.object_`` dtype in numpy data."""

    def __init__(self) -> None:
        super().__init__(np.object_)

    def to_json(self, value: str) -> object:
        """Return ``value`` unchanged once it is confirmed to be a ``str``.

        Raises:
            ValueError: If ``value`` is not a ``str``.
        """
        if isinstance(value, str):
            return value
        raise ValueError(f"Value in not a string: {value}")

    def numpy_to_json(self, value: np.object_) -> object:
        # The object cell is presumably the Python str itself; validate it the
        # same way as a Python-side value.
        as_text = cast(str, value)
        return self.to_json(as_text)

    def from_json(self, json_object: object) -> str:
        return cast(str, json_object)

    def from_json_to_numpy(self, json_object: object) -> np.object_:
        return np.object_(json_object)


string_converter = StringConverter()
575
+
576
+
577
class DateConverter(JsonConverter[datetime.date, np.datetime64]):
    """Converter for calendar dates, using ISO-8601 ``YYYY-MM-DD`` JSON strings."""

    def __init__(self) -> None:
        super().__init__(np.datetime64)

    def to_json(self, value: datetime.date) -> object:
        """Return ``value`` as an ISO-8601 date string.

        ``datetime.datetime`` subclasses ``datetime.date``. Its ``isoformat()``
        includes a time of day, which ``from_json`` (``date.fromisoformat``)
        cannot parse back. Such values are therefore serialized by their
        calendar date only.

        Raises:
            ValueError: If ``value`` is not a ``datetime.date``.
        """
        if not isinstance(value, datetime.date):
            raise ValueError(f"Value in not a date: {value}")
        if isinstance(value, datetime.datetime):
            value = value.date()
        return value.isoformat()

    def numpy_to_json(self, value: np.datetime64) -> object:
        # Truncate to day precision so the string is a plain date.
        return str(value.astype("datetime64[D]"))

    def from_json(self, json_object: object) -> datetime.date:
        return datetime.date.fromisoformat(cast(str, json_object))

    def from_json_to_numpy(self, json_object: object) -> np.datetime64:
        return np.datetime64(cast(str, json_object), "D")


date_converter = DateConverter()
597
+
598
+
599
class TimeConverter(JsonConverter[Time, np.timedelta64]):
    """Converter for yardl ``Time`` values (numpy form ``np.timedelta64``)."""

    def __init__(self) -> None:
        super().__init__(np.timedelta64)

    def to_json(self, value: Time) -> object:
        """Serialize a ``Time`` via ``str`` or a ``datetime.time`` via isoformat.

        Raises:
            ValueError: If ``value`` is neither a ``Time`` nor a ``datetime.time``.
        """
        if isinstance(value, Time):
            return str(value)
        if isinstance(value, datetime.time):
            return value.isoformat()
        raise ValueError(f"Value in not a time: {value}")

    def numpy_to_json(self, value: np.timedelta64) -> object:
        wrapped = Time(value)
        return str(wrapped)

    def from_json(self, json_object: object) -> Time:
        return Time.parse(cast(str, json_object))

    def from_json_to_numpy(self, json_object: object) -> np.timedelta64:
        parsed = self.from_json(json_object)
        return parsed.numpy_value


time_converter = TimeConverter()
622
+
623
+
624
class DateTimeConverter(JsonConverter[DateTime, np.datetime64]):
    """Converter for yardl ``DateTime`` values (numpy form ``np.datetime64``)."""

    def __init__(self) -> None:
        super().__init__(np.datetime64)

    def to_json(self, value: DateTime) -> object:
        """Serialize a ``DateTime`` via ``str`` or a ``datetime.datetime`` via isoformat.

        Raises:
            ValueError: If ``value`` is neither a ``DateTime`` nor a
                ``datetime.datetime``.
        """
        if isinstance(value, DateTime):
            return str(value)
        if isinstance(value, datetime.datetime):
            return value.isoformat()
        raise ValueError(f"Value in not a datetime: {value}")

    def numpy_to_json(self, value: np.datetime64) -> object:
        # numpy's own string rendering of the datetime64 value.
        return str(value)

    def from_json(self, json_object: object) -> DateTime:
        return DateTime.parse(cast(str, json_object))

    def from_json_to_numpy(self, json_object: object) -> np.datetime64:
        parsed = self.from_json(json_object)
        return parsed.numpy_value


datetime_converter = DateTimeConverter()
647
+
648
TEnum = TypeVar("TEnum", bound=OutOfRangeEnum)


class EnumConverter(Generic[TEnum, T_NP], JsonConverter[TEnum, T_NP]):
    """Converter for yardl enums: named members as strings, others as integers."""

    def __init__(
        self,
        enum_type: type[TEnum],
        numpy_type: type,
        name_to_value: dict[str, TEnum],
        value_to_name: dict[TEnum, str],
    ) -> None:
        super().__init__(numpy_type)
        self._enum_type = enum_type
        self._name_to_value = name_to_value
        self._value_to_name = value_to_name

    def to_json(self, value: TEnum) -> object:
        """Return the member's name, or its raw value when the name is empty.

        Raises:
            ValueError: If ``value`` is not an instance of the enum type.
        """
        if not isinstance(value, self._enum_type):
            raise ValueError(f"Value in not an enum or not the right type: {value}")
        # An empty name presumably marks an out-of-range value (see
        # OutOfRangeEnum); such values are written as their number.
        return value.value if value.name == "" else self._value_to_name[value]

    def numpy_to_json(self, value: T_NP) -> object:
        as_enum = self._enum_type(value)
        return self.to_json(as_enum)

    def from_json(self, json_object: object) -> TEnum:
        """Accept either a raw integer or a member name."""
        if isinstance(json_object, int):
            return self._enum_type(json_object)
        else:
            return self._name_to_value[cast(str, json_object)]

    def from_json_to_numpy(self, json_object: object) -> T_NP:
        as_enum = self.from_json(json_object)
        return as_enum.value
683
+
684
+
685
TFlag = TypeVar("TFlag", bound=IntFlag)


class FlagsConverter(Generic[TFlag, T_NP], JsonConverter[TFlag, T_NP]):
    """Converter for ``IntFlag`` types.

    A value that decomposes fully into named members is written as a JSON list
    of member names. Otherwise it is written as its raw integer. Zero is
    written as ``[<zero member name>]`` if such a member exists, else ``[]``.
    """

    def __init__(
        self,
        enum_type: type[TFlag],
        numpy_type: type,
        name_to_value: dict[str, TFlag],
        value_to_name: dict[TFlag, str],
    ) -> None:
        super().__init__(numpy_type)
        self._enum_type = enum_type
        self._name_to_value = name_to_value
        self._value_to_name = value_to_name
        self._zero_enum = enum_type(0)
        self._zero_json: list[str] = (
            [value_to_name[self._zero_enum]] if self._zero_enum in value_to_name else []
        )

    def to_json(self, value: TFlag) -> object:
        """Return a list of member names, or the raw int if bits are unnamed.

        Raises:
            ValueError: If ``value`` is not an instance of the flag type.
        """
        if not isinstance(value, self._enum_type):
            raise ValueError(f"Value in not an enum or not the right type: {value}")
        if value.value == 0:
            # Return a copy so callers cannot mutate the shared cached list.
            return list(self._zero_json)

        remaining_int_value = value.value
        result: list[str] = []
        for enum_value, name in self._value_to_name.items():
            if enum_value.value == 0:
                continue
            if (enum_value.value & remaining_int_value) == enum_value.value:
                result.append(name)
                remaining_int_value &= ~enum_value.value
                if remaining_int_value == 0:
                    break

        if remaining_int_value == 0:
            return result

        # Some set bits have no member name; fall back to the raw integer.
        return value.value

    def numpy_to_json(self, value: T_NP) -> object:
        return self.to_json(self._enum_type(int(value)))  # type: ignore

    def from_json(self, json_object: object) -> TFlag:
        """Accept either a raw integer or a list of member names.

        Raises:
            ValueError: If ``json_object`` is neither an int nor a list.
            KeyError: If the list contains an unknown member name.
        """
        if isinstance(json_object, int):
            return self._enum_type(json_object)

        # Explicit check instead of `assert`, which is stripped under -O.
        if not isinstance(json_object, list):
            raise ValueError(
                f"Expected an integer or a list of flag names: {json_object}"
            )

        res = self._zero_enum
        for name in json_object:
            res |= self._name_to_value[name]

        return res

    def from_json_to_numpy(self, json_object: object) -> T_NP:
        return self.from_json(json_object).value  # type: ignore
744
+
745
+
746
class OptionalConverter(Generic[T, T_NP], JsonConverter[Optional[T], np.void]):
    """Converter for optional values, where ``None`` maps to JSON ``null``.

    In numpy form the value is a ``(has_value, value)`` structured record.
    """

    def __init__(self, element_converter: JsonConverter[T, T_NP]) -> None:
        record_dtype = np.dtype(
            [("has_value", np.bool_), ("value", element_converter.overall_dtype())]
        )
        super().__init__(record_dtype)
        self._element_converter = element_converter
        # A zero-filled record (has_value == False), returned for every None.
        self._none = cast(np.void, np.zeros((), dtype=self.overall_dtype())[()])

    def to_json(self, value: Optional[T]) -> object:
        return None if value is None else self._element_converter.to_json(value)

    def numpy_to_json(self, value: np.void) -> object:
        if not value["has_value"]:
            return None
        return self._element_converter.numpy_to_json(value["value"])

    def from_json(self, json_object: object) -> Optional[T]:
        return (
            None
            if json_object is None
            else self._element_converter.from_json(json_object)
        )

    def from_json_to_numpy(self, json_object: object) -> np.void:
        if json_object is not None:
            return (True, self._element_converter.from_json_to_numpy(json_object))  # type: ignore
        return self._none

    def supports_none(self) -> bool:
        """Optional values always accept ``None``."""
        return True
778
+
779
+
780
class UnionConverter(JsonConverter[T, np.object_]):
    """Converter for yardl unions.

    ``cases`` holds one entry per union case. Each entry is a tuple of
    (case type, case converter, JSON types of that case), and a leading
    ``None`` entry means the union also admits ``None``. In the "simple"
    form a case's JSON value is written bare and identified on read by its
    JSON type. Otherwise it is wrapped as ``{tag: value}``.
    """

    def __init__(
        self,
        union_type: type,
        cases: list[Optional[tuple[type, JsonConverter[Any, Any], list[type]]]],
        simple: bool,
    ) -> None:
        super().__init__(np.object_)
        self._union_type = union_type
        self._cases = cases
        self._simple = simple
        # A leading None case shifts every case index by one.
        self._offset = 1 if cases[0] is None else 0
        if self._simple:
            self._json_type_to_case_index = {
                json_type: case_index
                for (case_index, case) in enumerate(cases)
                if case is not None
                for json_type in case[2]
            }
        else:
            self.tag_to_case_index: dict[str, int] = {
                case[0].tag: case_index  # type: ignore
                for (case_index, case) in enumerate(cases)
                if case is not None
            }

    def to_json(self, value: T) -> object:
        """Serialize a union value (or ``None`` when the union allows it).

        Raises:
            ValueError: If ``value`` is ``None`` but not allowed, or is not an
                instance of the union type.
        """
        if value is None:
            if self._cases[0] is None:
                return None
            else:
                raise ValueError("None is not a valid for this union type")

        if not isinstance(value, self._union_type):
            raise ValueError(f"Value in not a union or not the right type: {value}")

        tag_index = value.index + self._offset  # type: ignore
        inner_json_value = self._cases[tag_index][1].to_json(value.value)  # type: ignore

        if self._simple:
            return inner_json_value
        else:
            return {value.tag: inner_json_value}  # type: ignore

    def numpy_to_json(self, value: np.object_) -> object:
        return self.to_json(cast(T, value))

    def from_json(self, json_object: object) -> T:
        """Deserialize a union value.

        Raises:
            ValueError: If ``json_object`` is ``None`` but not allowed, or if
                the tagged form is not a non-empty JSON object.
            KeyError: If the JSON type (simple form) or tag is unknown.
        """
        if json_object is None:
            if self._cases[0] is None:
                return None  # type: ignore
            else:
                raise ValueError("None is not a valid for this union type")

        if self._simple:
            idx = self._json_type_to_case_index[type(json_object)]
            case = self._cases[idx]
            return case[0](case[1].from_json(json_object))  # type: ignore

        # Validate explicitly. `assert` is stripped under -O, and next() on an
        # empty object would leak StopIteration.
        if not isinstance(json_object, dict) or not json_object:
            raise ValueError(
                f"Expected a JSON object with a union tag: {json_object}"
            )
        tag, inner_json_object = next(iter(json_object.items()))
        case = self._cases[self.tag_to_case_index[tag]]
        return case[0](case[1].from_json(inner_json_object))  # type: ignore

    def from_json_to_numpy(self, json_object: object) -> np.object_:
        return self.from_json(json_object)  # type: ignore

    def supports_none(self) -> bool:
        """Return True when the union's first case is the ``None`` case."""
        return self._cases[0] is None
849
+
850
+
851
class VectorConverter(Generic[T, T_NP], JsonConverter[list[T], np.object_]):
    """Converter for variable-length vectors, written as JSON arrays."""

    def __init__(self, element_converter: JsonConverter[T, T_NP]) -> None:
        super().__init__(np.object_)
        self._element_converter = element_converter

    def to_json(self, value: list[T]) -> object:
        """Convert each element of a Python list.

        Raises:
            ValueError: If ``value`` is not a list.
        """
        if not isinstance(value, list):
            raise ValueError(f"Value in not a list: {value}")
        convert = self._element_converter.to_json
        return [convert(item) for item in value]

    def numpy_to_json(self, value: object) -> object:
        """Convert a list, or a 1-D ndarray, of elements.

        Raises:
            ValueError: If ``value`` is neither a list nor a 1-D ndarray.
        """
        if isinstance(value, list):
            # A list holds Python-side elements, so use the plain to_json path.
            convert_py = self._element_converter.to_json
            return [convert_py(item) for item in value]

        if not isinstance(value, np.ndarray):
            raise ValueError(f"Value in not a list or ndarray: {value}")
        if value.ndim != 1:
            raise ValueError(f"Value in not a 1-dimensional ndarray: {value}")

        convert_np = self._element_converter.numpy_to_json
        return [convert_np(item) for item in value]

    def from_json(self, json_object: object) -> list[T]:
        """Convert a JSON array into a Python list.

        Raises:
            ValueError: If ``json_object`` is not a list.
        """
        if not isinstance(json_object, list):
            raise ValueError(f"Value in not a list: {json_object}")
        convert = self._element_converter.from_json
        return [convert(item) for item in json_object]

    def from_json_to_numpy(self, json_object: object) -> np.object_:
        # Stored as an object cell holding the Python list.
        return cast(np.object_, self.from_json(json_object))
880
+
881
+
882
class FixedVectorConverter(Generic[T, T_NP], JsonConverter[list[T], np.object_]):
    """Converter for fixed-length vectors.

    In numpy form the value is a subarray of ``length`` elements of the
    element converter's dtype.
    """

    def __init__(self, element_converter: JsonConverter[T, T_NP], length: int) -> None:
        super().__init__(np.dtype((element_converter.overall_dtype(), length)))
        self._element_converter = element_converter
        self._length = length

    def _check_json_list(self, json_object: object) -> list[object]:
        """Validate that ``json_object`` is a list of exactly ``_length`` items."""
        if not isinstance(json_object, list):
            raise ValueError(f"Value in not a list: {json_object}")
        if len(json_object) != self._length:
            raise ValueError(
                f"Value in not a list of length {self._length}: {json_object}"
            )
        return json_object

    def to_json(self, value: list[T]) -> object:
        """Convert a Python list of exactly ``length`` elements.

        Raises:
            ValueError: If ``value`` is not a list of the expected length.
        """
        if not isinstance(value, list):
            raise ValueError(f"Value in not a list: {value}")
        if len(value) != self._length:
            raise ValueError(f"Value in not a list of length {self._length}: {value}")
        convert = self._element_converter.to_json
        return [convert(item) for item in value]

    def numpy_to_json(self, value: np.object_) -> object:
        """Convert an ndarray of shape ``(length,)``.

        Raises:
            ValueError: If ``value`` is not an ndarray of the expected shape.
        """
        if not isinstance(value, np.ndarray):
            raise ValueError(f"Value in not an ndarray: {value}")
        if value.shape != (self._length,):
            raise ValueError(f"Value does not have expected shape of {self._length}")

        convert = self._element_converter.numpy_to_json
        return [convert(item) for item in value]

    def from_json(self, json_object: object) -> list[T]:
        items = self._check_json_list(json_object)
        convert = self._element_converter.from_json
        return [convert(item) for item in items]

    def from_json_to_numpy(self, json_object: object) -> np.object_:
        items = self._check_json_list(json_object)
        convert = self._element_converter.from_json_to_numpy
        return cast(np.object_, [convert(item) for item in items])
923
+
924
+
925
TKey = TypeVar("TKey")
TKey_NP = TypeVar("TKey_NP", bound=np.generic)
TValue = TypeVar("TValue")
TValue_NP = TypeVar("TValue_NP", bound=np.generic)


class MapConverter(
    Generic[TKey, TKey_NP, TValue, TValue_NP],
    JsonConverter[dict[TKey, TValue], np.object_],
):
    """Converter for maps.

    With string keys the JSON form is a native JSON object. For any other key
    type it is a JSON array of ``[key, value]`` pairs.
    """

    def __init__(
        self,
        key_converter: JsonConverter[TKey, TKey_NP],
        value_converter: JsonConverter[TValue, TValue_NP],
    ) -> None:
        super().__init__(np.object_)
        self._key_converter = key_converter
        self._value_converter = value_converter

    def to_json(self, value: dict[TKey, TValue]) -> object:
        """Serialize a dict.

        Raises:
            ValueError: If ``value`` is not a dict, or a key/value fails its
                converter's validation.
        """
        if not isinstance(value, dict):
            raise ValueError(f"Value in not a dict: {value}")

        if isinstance(self._key_converter, StringConverter):
            # Route keys through the key converter so that non-str keys are
            # rejected. Otherwise json.dump would silently coerce them to
            # strings, and they would not round-trip.
            return {
                cast(str, self._key_converter.to_json(k)): self._value_converter.to_json(v)
                for k, v in value.items()
            }

        return [
            [self._key_converter.to_json(k), self._value_converter.to_json(v)]
            for k, v in value.items()
        ]

    def numpy_to_json(self, value: np.object_) -> object:
        return self.to_json(cast(dict[TKey, TValue], value))

    def from_json(self, json_object: object) -> dict[TKey, TValue]:
        """Deserialize a JSON object (string keys) or a list of key/value pairs.

        Raises:
            ValueError: If ``json_object`` does not have the expected form.
        """
        if isinstance(self._key_converter, StringConverter):
            if not isinstance(json_object, dict):
                raise ValueError(f"Value in not a dict: {json_object}")

            return {
                cast(TKey, k): self._value_converter.from_json(v)
                for k, v in json_object.items()
            }

        if not isinstance(json_object, list):
            raise ValueError(f"Value in not a list: {json_object}")

        return {
            self._key_converter.from_json(k): self._value_converter.from_json(v)
            for [k, v] in json_object
        }

    def from_json_to_numpy(self, json_object: object) -> np.object_:
        # Stored as an object cell holding the Python dict.
        return cast(np.object_, self.from_json(json_object))
981
+
982
+
983
class NDArrayConverterBase(
    Generic[T, T_NP], JsonConverter[npt.NDArray[Any], np.object_]
):
    """Shared logic for the NDArray JSON converters.

    The element dtype is split into a base dtype (``self._array_dtype``) and
    the shape of any subarray dimensions it carries (``self._subarray_shape``,
    which is ``None`` when there are none). Subclasses validate input arrays
    with ``check_dtype`` and rebuild arrays from a flat JSON element list
    with ``_read``.
    """

    def __init__(
        self,
        overall_dtype: npt.DTypeLike,
        element_converter: JsonConverter[T, T_NP],
        dtype: npt.DTypeLike,
    ) -> None:
        super().__init__(overall_dtype)
        self._element_converter = element_converter

        (
            self._array_dtype,
            self._subarray_shape,
        ) = NDArrayConverterBase._get_dtype_and_subarray_shape(
            dtype if isinstance(dtype, np.dtype) else np.dtype(dtype)
        )
        # An empty subarray shape means the elements are not subarrays.
        # Subclasses test for that case with ``is None``.
        if self._subarray_shape == ():
            self._subarray_shape = None

    @staticmethod
    def _get_dtype_and_subarray_shape(
        dtype: np.dtype[Any],
    ) -> tuple[np.dtype[Any], tuple[int, ...]]:
        """Return ``(base_dtype, subarray_shape)`` for ``dtype``.

        Subarray dtypes are unwrapped recursively, concatenating their shapes
        outermost-first. A dtype without a subarray yields ``()``.
        """
        if dtype.subdtype is None:
            return dtype, ()
        subres = NDArrayConverterBase._get_dtype_and_subarray_shape(dtype.subdtype[0])
        return (subres[0], dtype.subdtype[1] + subres[1])

    def check_dtype(self, input_dtype: npt.DTypeLike) -> None:
        """Raise ValueError unless ``input_dtype`` matches the expected dtype.

        The unaligned (packed) variant of the expected dtype is also accepted.
        """
        if input_dtype != self._array_dtype:
            # see if it's the same dtype but packed, not aligned
            packed_dtype = recfunctions.repack_fields(self._array_dtype, align=False, recurse=True)  # type: ignore
            if packed_dtype != input_dtype:
                if packed_dtype == self._array_dtype:
                    message = f"Expected dtype {self._array_dtype}, got {input_dtype}"
                else:
                    message = f"Expected dtype {self._array_dtype} or {packed_dtype}, got {input_dtype}"

                raise ValueError(message)

    def _read(
        self, shape: tuple[int, ...], json_object: list[object]
    ) -> npt.NDArray[Any]:
        """Build an array of ``shape`` (plus any subarray dimensions) from a
        flat JSON list with one entry per element of ``shape``.

        Raises:
            ValueError: if ``json_object`` is not a list with exactly
                ``prod(shape)`` entries.
        """
        subarray_shape_not_none = (
            () if self._subarray_shape is None else self._subarray_shape
        )

        # np.prod(()) is the float 1.0, which is not a valid array dimension
        # or range() bound. Converting to int lets zero-dimensional arrays
        # (shape ()) be read back.
        element_count = int(np.prod(shape))

        # Reject truncated or over-long data, rather than failing with an
        # IndexError or silently ignoring the extra entries.
        if not isinstance(json_object, list) or len(json_object) != element_count:
            raise ValueError(
                f"Expected a list of {element_count} elements for shape {shape}: {json_object}"
            )

        result = np.empty(
            (element_count,) + subarray_shape_not_none, dtype=self._array_dtype
        )
        for i, element in enumerate(json_object):
            result[i] = self._element_converter.from_json_to_numpy(element)

        return result.reshape(shape + subarray_shape_not_none)
1038
+
1039
+
1040
class FixedNDArrayConverter(Generic[T, T_NP], NDArrayConverterBase[T, T_NP]):
    """JSON converter for arrays whose shape is fixed by the schema.

    Arrays are written as a flat JSON list of elements in C (row-major)
    order, with no shape header.
    """

    def __init__(
        self,
        element_converter: JsonConverter[T, T_NP],
        shape: tuple[int, ...],
    ) -> None:
        element_dtype = element_converter.overall_dtype()
        overall = np.dtype((element_dtype, shape))
        super().__init__(overall, element_converter, element_dtype)
        self._shape = shape

    def to_json(self, value: npt.NDArray[Any]) -> object:
        """Encode ``value`` as a flat list of elements.

        Raises:
            ValueError: if ``value`` is not an ndarray, or its dtype or shape
                does not match the fixed layout.
        """
        if not isinstance(value, np.ndarray):
            raise ValueError(f"Value in not an ndarray: {value}")

        self.check_dtype(value.dtype)

        trailing = self._subarray_shape
        # The subarray dimensions of the element dtype (if any) trail the
        # array's own fixed shape.
        expected_shape = self._shape + (trailing or ())
        if value.shape != expected_shape:
            raise ValueError(f"Expected shape {expected_shape}, got {value.shape}")

        encode = self._element_converter.numpy_to_json
        elements = value.flat if trailing is None else value.reshape((-1,) + trailing)
        return [encode(element) for element in elements]

    def numpy_to_json(self, value: np.object_) -> object:
        return self.to_json(cast(npt.NDArray[Any], value))

    def from_json(self, json_object: object) -> npt.NDArray[Any]:
        """Decode a flat JSON element list into an array of the fixed shape.

        Raises:
            ValueError: if ``json_object`` is not a list.
        """
        if isinstance(json_object, list):
            return self._read(self._shape, json_object)
        raise ValueError(f"Value in not a list: {json_object}")

    def from_json_to_numpy(self, json_object: object) -> np.object_:
        return cast(np.object_, self.from_json(json_object))
1082
+
1083
+
1084
class DynamicNDArrayConverter(NDArrayConverterBase[T, T_NP]):
    """JSON converter for arrays of arbitrary rank and extents.

    Arrays are written as ``{"shape": [...], "data": [...]}``. ``data`` is the
    flattened element list in C order, and ``shape`` excludes any trailing
    dimensions that belong to the element type itself.
    """

    def __init__(
        self,
        element_serializer: JsonConverter[T, T_NP],
    ) -> None:
        element_dtype = element_serializer.overall_dtype()
        super().__init__(np.object_, element_serializer, element_dtype)

    def to_json(self, value: npt.NDArray[Any]) -> object:
        """Encode ``value`` as a shape/data dict.

        Raises:
            ValueError: if ``value`` is not an ndarray, has the wrong dtype, or
                its trailing dimensions do not match the element's subarray
                shape.
        """
        if not isinstance(value, np.ndarray):
            raise ValueError(f"Value in not an ndarray: {value}")

        self.check_dtype(value.dtype)

        trailing = self._subarray_shape
        encode = self._element_converter.numpy_to_json

        if trailing is None:
            return {
                "shape": value.shape,
                "data": [encode(element) for element in value.flat],
            }

        n_trailing = len(trailing)
        if value.ndim < n_trailing or value.shape[-n_trailing:] != trailing:
            dims = ", ".join(str(d) for d in trailing)
            raise ValueError(f"The array is required to have shape (..., {dims})")

        # Each row of the flattened view is one subarray element.
        flattened = value.reshape((-1,) + trailing)
        return {
            "shape": value.shape[:-n_trailing],
            "data": [encode(element) for element in flattened],
        }

    def numpy_to_json(self, value: np.object_) -> object:
        return self.to_json(cast(npt.NDArray[Any], value))

    def from_json(self, json_object: object) -> npt.NDArray[Any]:
        """Decode a ``{"shape": ..., "data": ...}`` dict into an ndarray.

        Raises:
            ValueError: if ``json_object`` is not a dict containing both
                ``shape`` and ``data``.
        """
        if not isinstance(json_object, dict):
            raise ValueError(f"Value in not a dict: {json_object}")

        if not ("shape" in json_object and "data" in json_object):
            raise ValueError(f"Value in not a dict with shape and data: {json_object}")

        shape = tuple(json_object["shape"])
        return self._read(shape, json_object["data"])

    def from_json_to_numpy(self, json_object: object) -> np.object_:
        return cast(np.object_, self.from_json(json_object))
1135
+
1136
+
1137
class NDArrayConverter(Generic[T, T_NP], NDArrayConverterBase[T, T_NP]):
    """JSON converter for arrays whose rank (``ndims``) is fixed but whose
    extents vary.

    Arrays are written as ``{"shape": [...], "data": [...]}``. ``data`` is the
    flattened element list in C order, and ``shape`` has exactly ``ndims``
    entries, excluding any trailing dimensions that belong to the element
    type itself.
    """

    def __init__(
        self,
        element_converter: JsonConverter[T, T_NP],
        ndims: int,
    ) -> None:
        super().__init__(
            np.object_, element_converter, element_converter.overall_dtype()
        )
        self._ndims = ndims

    def to_json(self, value: npt.NDArray[Any]) -> object:
        """Encode ``value`` as a shape/data dict.

        Raises:
            ValueError: if ``value`` is not an ndarray, has the wrong dtype or
                number of dimensions, or its trailing dimensions do not match
                the element's subarray shape.
        """
        if not isinstance(value, np.ndarray):
            raise ValueError(f"Value in not an ndarray: {value}")

        self.check_dtype(value.dtype)

        if self._subarray_shape is None:
            if value.ndim != self._ndims:
                raise ValueError(f"Expected {self._ndims} dimensions, got {value.ndim}")

            return {
                "shape": value.shape,
                "data": [self._element_converter.numpy_to_json(v) for v in value.flat],
            }

        # With subarray elements, the array carries ndims outer dimensions
        # plus the element's own trailing dimensions.
        total_dims = len(self._subarray_shape) + self._ndims
        if value.ndim != total_dims:
            raise ValueError(f"Expected {total_dims} dimensions, got {value.ndim}")

        if value.shape[-len(self._subarray_shape) :] != self._subarray_shape:
            raise ValueError(
                f"The array is required to have shape (..., {(', '.join((str(i) for i in self._subarray_shape)))})"
            )

        reshaped = value.reshape((-1,) + self._subarray_shape)
        return {
            "shape": value.shape[: -len(self._subarray_shape)],
            "data": [self._element_converter.numpy_to_json(v) for v in reshaped],
        }

    def numpy_to_json(self, value: np.object_) -> object:
        return self.to_json(cast(npt.NDArray[Any], value))

    def from_json(self, json_object: object) -> npt.NDArray[Any]:
        """Decode a ``{"shape": ..., "data": ...}`` dict into an ndarray.

        Raises:
            ValueError: if ``json_object`` is not a dict containing both
                ``shape`` and ``data``, or if ``shape`` does not have exactly
                ``ndims`` entries.
        """
        if not isinstance(json_object, dict):
            raise ValueError(f"Value in not a dict: {json_object}")

        if "shape" not in json_object or "data" not in json_object:
            raise ValueError(f"Value in not a dict with shape and data: {json_object}")

        shape = tuple(json_object["shape"])
        # The rank is fixed by the schema and to_json enforces it on write.
        # Enforce it on read too, so malformed input cannot silently produce
        # an array of the wrong rank.
        if len(shape) != self._ndims:
            raise ValueError(f"Expected {self._ndims} dimensions, got {len(shape)}")

        data = json_object["data"]
        return self._read(shape, data)

    def from_json_to_numpy(self, json_object: object) -> np.object_:
        return cast(np.object_, self.from_json(json_object))