omdev 0.0.0.dev438__py3-none-any.whl → 0.0.0.dev440__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1689 @@
1
+ #!/usr/bin/env python3
2
+ # noinspection DuplicatedCode
3
+ # @omlish-lite
4
+ # @omlish-script
5
+ # @omlish-generated
6
+ # @omlish-amalg-output ../../../omlish/lite/marshal.py
7
+ # @omlish-git-diff-omit
8
+ # ruff: noqa: UP006 UP007 UP036 UP045
9
+ """
10
+ TODO:
11
+ - pickle stdlib objs? have to pin to 3.8 pickle protocol, will be cross-version
12
+ - Options.sequence_cls = list, mapping_cls = dict, ... - def with_mutable_containers() -> Options
13
+ """
14
+ import abc
15
+ import base64
16
+ import collections
17
+ import collections.abc
18
+ import dataclasses as dc # noqa
19
+ import datetime
20
+ import decimal
21
+ import enum
22
+ import fractions
23
+ import functools
24
+ import inspect
25
+ import sys
26
+ import threading
27
+ import types
28
+ import typing as ta
29
+ import uuid
30
+ import weakref
31
+
32
+
33
+ ########################################
34
+
35
+
36
# Fail fast when running on an interpreter older than the minimum this script supports.
if not (sys.version_info >= (3, 8)):
    raise OSError(f'Requires python (3, 8), got {sys.version_info} from {sys.executable}')  # noqa
38
+
39
+
40
+ ########################################
41
+
42
+
43
+ # check.py
44
# Generic type variables shared by the check helpers.
T = ta.TypeVar('T')
SizedT = ta.TypeVar('SizedT', bound=ta.Sized)

# A check failure message: a literal string, a callable returning an optional string, or None for the default.
CheckMessage = ta.Union[str, ta.Callable[..., ta.Optional[str]], None]  # ta.TypeAlias

# Hook signatures used to configure a Checks instance.
CheckLateConfigureFn = ta.Callable[['Checks'], None]  # ta.TypeAlias
CheckOnRaiseFn = ta.Callable[[Exception], None]  # ta.TypeAlias
CheckExceptionFactory = ta.Callable[..., Exception]  # ta.TypeAlias
CheckArgsRenderer = ta.Callable[..., ta.Optional[str]]  # ta.TypeAlias
51
+
52
+
53
+ ########################################
54
+ # ../abstract.py
55
+
56
+
57
+ ##
58
+
59
+
60
# Attribute names used by the abc machinery to track abstract methods.
_ABSTRACT_METHODS_ATTR = '__abstractmethods__'
_IS_ABSTRACT_METHOD_ATTR = '__isabstractmethod__'


def is_abstract_method(obj: ta.Any) -> bool:
    """Return True if ``obj`` carries a truthy ``__isabstractmethod__`` attribute."""
    flag = getattr(obj, _IS_ABSTRACT_METHOD_ATTR, False)
    return bool(flag)
66
+
67
+
68
def update_abstracts(cls, *, force=False):
    """
    Recompute ``cls.__abstractmethods__`` after class creation, mirroring the stdlib's ``abc.update_abstractmethods``.

    Unless ``force`` is set, a class without an ``__abstractmethods__`` attribute is returned untouched. The class is
    always returned, so this can be used as a decorator.
    """

    # Per stdlib: We check for __abstractmethods__ here because cls might be a C implementation or a python
    # implementation (especially during testing), and we want to handle both cases.
    if not (force or hasattr(cls, _ABSTRACT_METHODS_ATTR)):
        return cls

    def _flagged(value: ta.Any) -> bool:
        return bool(getattr(value, _IS_ABSTRACT_METHOD_ATTR, False))

    # Names a direct base declares abstract which are still abstract as resolved on cls.
    inherited = {
        name
        for base in cls.__bases__
        for name in getattr(base, _ABSTRACT_METHODS_ATTR, ())
        if _flagged(getattr(cls, name, None))
    }

    # Names newly declared abstract in cls's own namespace.
    declared = {name for name, value in cls.__dict__.items() if _flagged(value)}

    setattr(cls, _ABSTRACT_METHODS_ATTR, frozenset(inherited | declared))
    return cls
88
+
89
+
90
+ #
91
+
92
+
93
class AbstractTypeError(TypeError):
    """Raised when a class is defined that leaves inherited or declared abstract methods unimplemented."""
95
+
96
+
97
# Name of the marker method that keeps classes with Abstract as a direct base abstract.
_FORCE_ABSTRACT_ATTR = '__forceabstract__'


class Abstract:
    """
    Different from, but interoperable with, abc.ABC / abc.ABCMeta:

    - This raises AbstractTypeError during class creation, not instance instantiation - unless Abstract or abc.ABC are
      explicitly present in the class's direct bases.
    - This will forbid instantiation of classes with Abstract in their direct bases even if there are no
      abstractmethods left on the class.
    - This is a mixin, not a metaclass.
    - As it is not an ABCMeta, this does not support virtual base classes. As a result, operations like `isinstance`
      and `issubclass` are ~7x faster.
    - It additionally enforces a base class order of (Abstract, abc.ABC) to preemptively prevent common mro conflicts.

    If not mixed-in with an ABCMeta, it will update __abstractmethods__ itself.
    """

    __slots__ = ()

    __abstractmethods__: ta.ClassVar[ta.FrozenSet[str]] = frozenset()

    #

    def __forceabstract__(self):
        raise TypeError

    # This is done manually, rather than through @abc.abstractmethod, to mask it from static analysis.
    setattr(__forceabstract__, _IS_ABSTRACT_METHOD_ATTR, True)

    #

    def __init_subclass__(cls, **kwargs: ta.Any) -> None:
        # Direct subclasses of Abstract receive the abstract-flagged __forceabstract__ marker. Every other subclass has
        # it overwritten with False, which no longer counts as abstract.
        setattr(
            cls,
            _FORCE_ABSTRACT_ATTR,
            getattr(Abstract, _FORCE_ABSTRACT_ATTR) if Abstract in cls.__bases__ else False,
        )

        super().__init_subclass__(**kwargs)

        # Only classes that do not explicitly opt into abstractness are checked for leftover abstract methods.
        if not (Abstract in cls.__bases__ or abc.ABC in cls.__bases__):
            # Abstract methods declared directly on cls, mapped to the owning class for error reporting.
            ams = {a: cls for a, o in cls.__dict__.items() if is_abstract_method(o)}

            # Walk the direct bases in order. Collect their still-abstract names that are not already defined by cls
            # or by an earlier base.
            seen = set(cls.__dict__)
            for b in cls.__bases__:
                ams.update({a: b for a in set(getattr(b, _ABSTRACT_METHODS_ATTR, [])) - seen})  # noqa
                seen.update(dir(b))

            if ams:
                raise AbstractTypeError(
                    f'Cannot subclass abstract class {cls.__name__} with abstract methods: ' +
                    ', '.join(sorted([
                        '.'.join([
                            *([m] if (m := getattr(c, '__module__')) else []),
                            getattr(c, '__qualname__', getattr(c, '__name__')),
                            a,
                        ])
                        for a, c in ams.items()
                    ])),
                )

        # When both appear in the direct bases, Abstract must come before abc.ABC.
        xbi = (Abstract, abc.ABC)  # , ta.Generic ?
        bis = [(cls.__bases__.index(b), b) for b in xbi if b in cls.__bases__]
        if bis != sorted(bis):
            raise TypeError(
                f'Abstract subclass {cls.__name__} must have proper base class order of '
                f'({", ".join(getattr(b, "__name__") for b in xbi)}), got: '
                f'({", ".join(getattr(b, "__name__") for _, b in sorted(bis))})',
            )

        # Without ABCMeta nothing else maintains __abstractmethods__, so recompute it here.
        if not isinstance(cls, abc.ABCMeta):
            update_abstracts(cls, force=True)
171
+
172
+
173
+ ########################################
174
+ # ../check.py
175
+ """
176
+ TODO:
177
+ - def maybe(v: lang.Maybe[T])
178
+ - def not_ ?
179
+ - ** class @dataclass Raise - user message should be able to be an exception type or instance or factory
180
+ """
181
+
182
+
183
+ ##
184
+
185
+
186
class Checks:
    """
    A configurable set of runtime checks.

    On failure, each check builds an exception through the configured exception factory. It then passes that
    exception to every registered on-raise hook and raises it. Most checks return the checked value so they can be
    used inline, e.g. ``x = check.not_none(maybe_x)``.
    """

    def __init__(self) -> None:
        super().__init__()

        # Guards hook registration. The hook lists are replaced wholesale rather than mutated, so _raise iterates them
        # without taking the lock.
        self._config_lock = threading.RLock()
        self._on_raise_fns: ta.Sequence[CheckOnRaiseFn] = []
        self._exception_factory: CheckExceptionFactory = Checks.default_exception_factory
        self._args_renderer: ta.Optional[CheckArgsRenderer] = None
        self._late_configure_fns: ta.Sequence[CheckLateConfigureFn] = []

    @staticmethod
    def default_exception_factory(exc_cls: ta.Type[Exception], *args, **kwargs) -> Exception:
        """Instantiate ``exc_cls`` with the given arguments."""
        return exc_cls(*args, **kwargs)  # noqa

    #

    def register_on_raise(self, fn: CheckOnRaiseFn) -> None:
        """Register ``fn`` to be called with each exception right before a failed check raises it."""
        with self._config_lock:
            self._on_raise_fns = [*self._on_raise_fns, fn]

    def unregister_on_raise(self, fn: CheckOnRaiseFn) -> None:
        """Remove every registration of ``fn`` (compared by equality)."""
        with self._config_lock:
            self._on_raise_fns = [e for e in self._on_raise_fns if e != fn]

    #

    def register_on_raise_breakpoint_if_env_var_set(self, key: str) -> None:
        """Register a hook that calls ``breakpoint()`` on check failure whenever env var ``key`` is set."""
        import os

        def on_raise(exc: Exception) -> None:  # noqa
            # The environment is consulted at raise time, not at registration time.
            if key in os.environ:
                breakpoint()  # noqa

        self.register_on_raise(on_raise)

    #

    def set_exception_factory(self, factory: CheckExceptionFactory) -> None:
        """Replace the factory used to construct exceptions for failed checks."""
        self._exception_factory = factory

    def set_args_renderer(self, renderer: ta.Optional[CheckArgsRenderer]) -> None:
        """Set the renderer that appends checked arguments to failure messages, or clear it with None."""
        self._args_renderer = renderer

    #

    def register_late_configure(self, fn: CheckLateConfigureFn) -> None:
        """Register ``fn`` to run once, with this instance, just before the next failing check raises."""
        with self._config_lock:
            self._late_configure_fns = [*self._late_configure_fns, fn]

    def _late_configure(self) -> None:
        """Run and then clear any pending late-configure hooks."""
        # Cheap unlocked fast path; re-checked under the lock below.
        if not self._late_configure_fns:
            return

        with self._config_lock:
            if not (lc := self._late_configure_fns):
                return

            for fn in lc:
                fn(self)

            self._late_configure_fns = []

    #

    class _ArgsKwargs:
        """Holds the arguments forwarded to message callables, the args renderer, and the exception factory."""

        def __init__(self, *args, **kwargs):
            self.args = args
            self.kwargs = kwargs

    def _raise(
        self,
        exception_type: ta.Type[Exception],
        default_message: str,
        message: CheckMessage,
        ak: _ArgsKwargs = _ArgsKwargs(),
        *,
        render_fmt: ta.Optional[str] = None,
    ) -> ta.NoReturn:
        """
        Build, report, and raise the exception for a failed check.

        ``message`` may be a string, None (meaning use ``default_message``), or a callable. A callable is invoked with
        the check's arguments and returns either of those or a ``(message, *extra_exc_args)`` tuple.
        """
        exc_args = ()
        if callable(message):
            message = ta.cast(ta.Callable, message)(*ak.args, **ak.kwargs)
            if isinstance(message, tuple):
                message, *exc_args = message  # type: ignore

        if message is None:
            message = default_message

        self._late_configure()

        # Optionally append a rendering of the checked (positional) arguments using render_fmt.
        if render_fmt is not None and (af := self._args_renderer) is not None:
            rendered_args = af(render_fmt, *ak.args)
            if rendered_args is not None:
                message = f'{message} : {rendered_args}'

        exc = self._exception_factory(
            exception_type,
            message,
            *exc_args,
            *ak.args,
            **ak.kwargs,
        )

        for fn in self._on_raise_fns:
            fn(exc)

        raise exc

    #

    def _unpack_isinstance_spec(self, spec: ta.Any) -> tuple:
        """Normalize ``spec`` to an ``isinstance``-compatible tuple, mapping None to NoneType and ta.Any to object."""
        if isinstance(spec, type):
            return (spec,)
        if not isinstance(spec, tuple):
            spec = (spec,)
        if None in spec:
            spec = tuple(filter(None, spec)) + (None.__class__,)  # noqa
        if ta.Any in spec:
            spec = (object,)
        return spec

    @ta.overload
    def isinstance(self, v: ta.Any, spec: ta.Type[T], msg: CheckMessage = None) -> T:
        ...

    @ta.overload
    def isinstance(self, v: ta.Any, spec: ta.Any, msg: CheckMessage = None) -> ta.Any:
        ...

    def isinstance(self, v, spec, msg=None):
        """Check that ``v`` is an instance of ``spec``; raises TypeError otherwise."""
        if not isinstance(v, self._unpack_isinstance_spec(spec)):
            self._raise(
                TypeError,
                'Must be instance',
                msg,
                Checks._ArgsKwargs(v, spec),
                render_fmt='not isinstance(%s, %s)',
            )

        return v

    @ta.overload
    def of_isinstance(self, spec: ta.Type[T], msg: CheckMessage = None) -> ta.Callable[[ta.Any], T]:
        ...

    @ta.overload
    def of_isinstance(self, spec: ta.Any, msg: CheckMessage = None) -> ta.Callable[[ta.Any], ta.Any]:
        ...

    def of_isinstance(self, spec, msg=None):
        """Return a one-argument function applying ``isinstance`` with a fixed ``spec``."""
        def inner(v):
            return self.isinstance(v, self._unpack_isinstance_spec(spec), msg)

        return inner

    def cast(self, v: ta.Any, cls: ta.Type[T], msg: CheckMessage = None) -> T:
        """Check that ``v`` is an instance of ``cls`` (no spec normalization); raises TypeError otherwise."""
        if not isinstance(v, cls):
            self._raise(
                TypeError,
                'Must be instance',
                msg,
                Checks._ArgsKwargs(v, cls),
            )

        return v

    def of_cast(self, cls: ta.Type[T], msg: CheckMessage = None) -> ta.Callable[[T], T]:
        """Return a one-argument function applying ``cast`` with a fixed ``cls``."""
        def inner(v):
            return self.cast(v, cls, msg)

        return inner

    def not_isinstance(self, v: T, spec: ta.Any, msg: CheckMessage = None) -> T:  # noqa
        """Check that ``v`` is not an instance of ``spec``; raises TypeError otherwise."""
        if isinstance(v, self._unpack_isinstance_spec(spec)):
            self._raise(
                TypeError,
                'Must not be instance',
                msg,
                Checks._ArgsKwargs(v, spec),
                render_fmt='isinstance(%s, %s)',
            )

        return v

    def of_not_isinstance(self, spec: ta.Any, msg: CheckMessage = None) -> ta.Callable[[T], T]:
        """Return a one-argument function applying ``not_isinstance`` with a fixed ``spec``."""
        def inner(v):
            return self.not_isinstance(v, self._unpack_isinstance_spec(spec), msg)

        return inner

    ##

    def issubclass(self, v: ta.Type[T], spec: ta.Any, msg: CheckMessage = None) -> ta.Type[T]:  # noqa
        """Check that class ``v`` is a subclass of ``spec``; raises TypeError otherwise."""
        if not issubclass(v, spec):
            self._raise(
                TypeError,
                'Must be subclass',
                msg,
                Checks._ArgsKwargs(v, spec),
                render_fmt='not issubclass(%s, %s)',
            )

        return v

    def not_issubclass(self, v: ta.Type[T], spec: ta.Any, msg: CheckMessage = None) -> ta.Type[T]:
        """Check that class ``v`` is not a subclass of ``spec``; raises TypeError otherwise."""
        if issubclass(v, spec):
            self._raise(
                TypeError,
                'Must not be subclass',
                msg,
                Checks._ArgsKwargs(v, spec),
                render_fmt='issubclass(%s, %s)',
            )

        return v

    #

    def in_(self, v: T, c: ta.Container[T], msg: CheckMessage = None) -> T:
        """Check that ``v in c``; raises ValueError otherwise."""
        if v not in c:
            self._raise(
                ValueError,
                'Must be in',
                msg,
                Checks._ArgsKwargs(v, c),
                render_fmt='%s not in %s',
            )

        return v

    def not_in(self, v: T, c: ta.Container[T], msg: CheckMessage = None) -> T:
        """Check that ``v not in c``; raises ValueError otherwise."""
        if v in c:
            self._raise(
                ValueError,
                'Must not be in',
                msg,
                Checks._ArgsKwargs(v, c),
                render_fmt='%s in %s',
            )

        return v

    def empty(self, v: SizedT, msg: CheckMessage = None) -> SizedT:
        """Check that sized ``v`` has length zero; raises ValueError otherwise."""
        if len(v) != 0:
            self._raise(
                ValueError,
                'Must be empty',
                msg,
                Checks._ArgsKwargs(v),
                render_fmt='%s',
            )

        return v

    def iterempty(self, v: ta.Iterable[T], msg: CheckMessage = None) -> ta.Iterable[T]:
        """Check that iterable ``v`` yields nothing; raises ValueError otherwise."""
        it = iter(v)
        try:
            next(it)
        except StopIteration:
            pass
        else:
            self._raise(
                ValueError,
                'Must be empty',
                msg,
                Checks._ArgsKwargs(v),
                render_fmt='%s',
            )

        return v

    def not_empty(self, v: SizedT, msg: CheckMessage = None) -> SizedT:
        """Check that sized ``v`` has nonzero length; raises ValueError otherwise."""
        if len(v) == 0:
            self._raise(
                ValueError,
                'Must not be empty',
                msg,
                Checks._ArgsKwargs(v),
                render_fmt='%s',
            )

        return v

    def unique(self, it: ta.Iterable[T], msg: CheckMessage = None) -> ta.Iterable[T]:
        """
        Check that ``it`` contains no duplicate elements; raises ValueError (with the duplicates) otherwise.

        Returns the checked iterable. A one-shot iterator is first materialized into a list, and that list is returned
        instead, because counting the elements would otherwise hand the caller back an exhausted iterator.
        """
        # An iterator's __iter__ returns itself; re-iterable containers return a fresh iterator.
        if iter(it) is it:
            it = list(it)

        dupes = [e for e, c in collections.Counter(it).items() if c > 1]
        if dupes:
            self._raise(
                ValueError,
                'Must be unique',
                msg,
                Checks._ArgsKwargs(it, dupes),
            )

        return it

    def single(self, obj: ta.Iterable[T], msg: CheckMessage = None) -> T:
        """Return the only element of ``obj``; raises ValueError if it has zero or several elements."""
        try:
            [value] = obj
        except ValueError:
            self._raise(
                ValueError,
                'Must be single',
                msg,
                Checks._ArgsKwargs(obj),
                render_fmt='%s',
            )

        return value

    def opt_single(self, obj: ta.Iterable[T], msg: CheckMessage = None) -> ta.Optional[T]:
        """Return the only element of ``obj``, or None if empty; raises ValueError if it has several elements."""
        it = iter(obj)
        try:
            value = next(it)
        except StopIteration:
            return None

        try:
            next(it)
        except StopIteration:
            return value  # noqa

        self._raise(
            ValueError,
            'Must be empty or single',
            msg,
            Checks._ArgsKwargs(obj),
            render_fmt='%s',
        )

        raise RuntimeError  # noqa

    #

    def none(self, v: ta.Any, msg: CheckMessage = None) -> None:
        """Check that ``v is None``; raises ValueError otherwise."""
        if v is not None:
            self._raise(
                ValueError,
                'Must be None',
                msg,
                Checks._ArgsKwargs(v),
                render_fmt='%s',
            )

    def not_none(self, v: ta.Optional[T], msg: CheckMessage = None) -> T:
        """Check that ``v is not None``; raises ValueError otherwise."""
        if v is None:
            self._raise(
                ValueError,
                'Must not be None',
                msg,
                Checks._ArgsKwargs(v),
                render_fmt='%s',
            )

        return v

    #

    def equal(self, v: T, o: ta.Any, msg: CheckMessage = None) -> T:
        """Check that ``v == o``; raises ValueError otherwise."""
        if o != v:
            self._raise(
                ValueError,
                'Must be equal',
                msg,
                Checks._ArgsKwargs(v, o),
                render_fmt='%s != %s',
            )

        return v

    def not_equal(self, v: T, o: ta.Any, msg: CheckMessage = None) -> T:
        """Check that ``v != o``; raises ValueError otherwise."""
        if o == v:
            self._raise(
                ValueError,
                'Must not be equal',
                msg,
                Checks._ArgsKwargs(v, o),
                render_fmt='%s == %s',
            )

        return v

    def is_(self, v: T, o: ta.Any, msg: CheckMessage = None) -> T:
        """Check that ``v is o``; raises ValueError otherwise."""
        if o is not v:
            self._raise(
                ValueError,
                'Must be the same',
                msg,
                Checks._ArgsKwargs(v, o),
                render_fmt='%s is not %s',
            )

        return v

    def is_not(self, v: T, o: ta.Any, msg: CheckMessage = None) -> T:
        """Check that ``v is not o``; raises ValueError otherwise."""
        if o is v:
            self._raise(
                ValueError,
                'Must not be the same',
                msg,
                Checks._ArgsKwargs(v, o),
                render_fmt='%s is %s',
            )

        return v

    def callable(self, v: T, msg: CheckMessage = None) -> T:  # noqa
        """Check that ``v`` is callable; raises TypeError otherwise."""
        if not callable(v):
            self._raise(
                TypeError,
                'Must be callable',
                msg,
                Checks._ArgsKwargs(v),
                render_fmt='%s',
            )

        return v

    def non_empty_str(self, v: ta.Optional[str], msg: CheckMessage = None) -> str:
        """Check that ``v`` is a non-empty str; raises ValueError otherwise."""
        if not isinstance(v, str) or not v:
            self._raise(
                ValueError,
                'Must be non-empty str',
                msg,
                Checks._ArgsKwargs(v),
                render_fmt='%s',
            )

        return v

    def replacing(self, expected: ta.Any, old: ta.Any, new: T, msg: CheckMessage = None) -> T:
        """Return ``new`` after checking that ``old == expected``; raises ValueError otherwise."""
        if old != expected:
            self._raise(
                ValueError,
                'Must be replacing',
                msg,
                Checks._ArgsKwargs(expected, old, new),
                render_fmt='%s -> %s -> %s',
            )

        return new

    def replacing_none(self, old: ta.Any, new: T, msg: CheckMessage = None) -> T:
        """Return ``new`` after checking that ``old is None``; raises ValueError otherwise."""
        if old is not None:
            self._raise(
                ValueError,
                'Must be replacing None',
                msg,
                Checks._ArgsKwargs(old, new),
                render_fmt='%s -> %s',
            )

        return new

    #

    def arg(self, v: bool, msg: CheckMessage = None) -> None:
        """Check an argument condition; raises RuntimeError if ``v`` is falsy."""
        if not v:
            self._raise(
                RuntimeError,
                'Argument condition not met',
                msg,
                Checks._ArgsKwargs(v),
                render_fmt='%s',
            )

    def state(self, v: bool, msg: CheckMessage = None) -> None:
        """Check a state condition; raises RuntimeError if ``v`` is falsy."""
        if not v:
            self._raise(
                RuntimeError,
                'State condition not met',
                msg,
                Checks._ArgsKwargs(v),
                render_fmt='%s',
            )


# Module-wide default Checks instance.
check = Checks()
662
+
663
+
664
+ ########################################
665
+ # ../objects.py
666
+
667
+
668
+ ##
669
+
670
+
671
def deep_subclasses(cls: ta.Type[T]) -> ta.Iterator[ta.Type[T]]:
    """
    Lazily yield every transitive subclass of ``cls`` exactly once, depth-first, in ``__subclasses__()`` order.

    ``cls`` itself is not yielded. A class reachable through several bases is yielded only the first time it is met.
    """
    visited = set()
    # LIFO stack; children are pushed reversed so the first-declared subclass is popped first.
    stack = cls.__subclasses__()[::-1]
    while stack:
        sub = stack.pop()
        if sub not in visited:
            visited.add(sub)
            yield sub
            stack += sub.__subclasses__()[::-1]
681
+
682
+
683
+ ##
684
+
685
+
686
def mro_owner_dict(
    instance_cls: type,
    owner_cls: ta.Optional[type] = None,
    *,
    bottom_up_key_order: bool = False,
    sort_keys: bool = False,
) -> ta.Mapping[str, ta.Tuple[type, ta.Any]]:
    """
    Map every attribute name visible from ``owner_cls`` to ``(defining_class, raw_value)``.

    The classes considered run from the root of ``instance_cls``'s mro (``object`` excluded) down to and including
    ``owner_cls``, which defaults to ``instance_cls``. The most-derived definition of a name wins.

    Key order follows the owner class first, unless ``bottom_up_key_order`` is set, in which case keys appear in the
    order they are first defined from the base-most class downward. ``sort_keys`` sorts the result by name instead.

    Raises TypeError if ``owner_cls`` is not in ``instance_cls``'s mro.
    """
    if owner_cls is None:
        owner_cls = instance_cls

    # Base-most class first, with the trailing `object` dropped.
    lineage = instance_cls.__mro__[-2::-1]
    try:
        upto = lineage.index(owner_cls) + 1
    except ValueError:
        raise TypeError(f'Owner class {owner_cls} not in mro of instance class {instance_cls}') from None
    relevant = lineage[:upto]

    out: ta.Dict[str, ta.Tuple[type, ta.Any]] = {}
    if bottom_up_key_order:
        # Base-most first; more-derived classes overwrite values in place, keeping the base's key position.
        for klass in relevant:
            for name, value in klass.__dict__.items():
                out[name] = (klass, value)
    else:
        # Most-derived first; the first definition seen is kept.
        for klass in reversed(relevant):
            for name, value in klass.__dict__.items():
                out.setdefault(name, (klass, value))

    if sort_keys:
        out = {name: out[name] for name in sorted(out)}

    return out
717
+
718
+
719
def mro_dict(
    instance_cls: type,
    owner_cls: ta.Optional[type] = None,
    *,
    bottom_up_key_order: bool = False,
    sort_keys: bool = False,
) -> ta.Mapping[str, ta.Any]:
    """Like ``mro_owner_dict``, but maps each attribute name to its raw value only, dropping the defining class."""
    owned = mro_owner_dict(
        instance_cls,
        owner_cls,
        bottom_up_key_order=bottom_up_key_order,
        sort_keys=sort_keys,
    )
    return {name: value for name, (_, value) in owned.items()}
735
+
736
+
737
def dir_dict(o: ta.Any) -> ta.Dict[str, ta.Any]:
    """Return ``{name: getattr(o, name)}`` for every name reported by ``dir(o)``."""
    result: ta.Dict[str, ta.Any] = {}
    for name in dir(o):
        result[name] = getattr(o, name)
    return result
742
+
743
+
744
+ ########################################
745
+ # ../reflect.py
746
+
747
+
748
+ ##
749
+
750
+
751
# Runtime classes of typing's subscripted and special generic aliases. _SpecialGenericAlias may not exist on older
# Pythons, hence the hasattr guard.
_GENERIC_ALIAS_TYPES = (
    ta._GenericAlias,  # type: ignore  # noqa
    *([ta._SpecialGenericAlias] if hasattr(ta, '_SpecialGenericAlias') else []),  # noqa
)


def is_generic_alias(obj: ta.Any, *, origin: ta.Any = None) -> bool:
    """
    Return True if ``obj`` is a ``typing`` generic alias (e.g. ``ta.List[int]``).

    If ``origin`` is given, additionally require ``ta.get_origin(obj) is origin``. Note that get_origin reports the
    runtime class (``list``, ``collections.abc.Callable``, ...), not the ``typing`` object.
    """
    return (
        isinstance(obj, _GENERIC_ALIAS_TYPES) and
        (origin is None or ta.get_origin(obj) is origin)
    )


# ta.get_origin() returns collections.abc.Callable for both ta.Callable and ta.Callable[...], never the ta.Callable
# object itself, so the origin must be compared against collections.abc.Callable or no alias would ever match.
is_callable_alias = functools.partial(is_generic_alias, origin=collections.abc.Callable)
765
+
766
+
767
+ ##
768
+
769
+
770
# Every origin a union alias may report: ta.Union's origin, plus on 3.10+ whatever PEP 604 `X | None` unions report.
# A TypeVar-based union is probed as well, in case it reports a different origin.
_UNION_ALIAS_ORIGINS = frozenset(
    [ta.get_origin(ta.Optional[int])] +
    (
        [
            ta.get_origin(int | None),
            ta.get_origin(getattr(ta, 'TypeVar')('_T') | None),
        ]
        if sys.version_info >= (3, 10) else []
    ),
)


def is_union_alias(obj: ta.Any) -> bool:
    """Return True if ``obj`` is a union type such as ``ta.Union[...]``, ``ta.Optional[...]`` or, on 3.10+, ``X | Y``."""
    origin = ta.get_origin(obj)
    return origin in _UNION_ALIAS_ORIGINS


#


def is_optional_alias(spec: ta.Any) -> bool:
    """Return True if ``spec`` is a two-member union one of whose members is None, i.e. ``Optional[X]``."""
    if not is_union_alias(spec):
        return False
    args = ta.get_args(spec)
    if len(args) != 2:
        return False
    return any(a in (None, type(None)) for a in args)


def get_optional_alias_arg(spec: ta.Any) -> ta.Any:
    """Return the sole non-None member of an optional alias; raises ValueError unless there is exactly one."""
    non_none = [a for a in ta.get_args(spec) if a not in (None, type(None))]
    (arg,) = non_none
    return arg
799
+
800
+
801
+ ##
802
+
803
+
804
def is_new_type(spec: ta.Any) -> bool:
    """Return True if ``spec`` was created by ``ta.NewType``, under both its older and newer implementations."""
    if not isinstance(ta.NewType, type):
        # Before https://github.com/python/cpython/commit/c2f33dfc83ab270412bf243fb21f724037effa1a NewType was a
        # function producing plain functions; recognize those by the nested code object stored in NewType's constants.
        return isinstance(spec, types.FunctionType) and spec.__code__ is ta.NewType.__code__.co_consts[1]  # type: ignore  # noqa
    return isinstance(spec, ta.NewType)


def get_new_type_supertype(spec: ta.Any) -> ta.Any:
    """Return the type a ``ta.NewType`` was declared over."""
    supertype = spec.__supertype__
    return supertype
814
+
815
+
816
+ ##
817
+
818
+
819
def is_literal_type(spec: ta.Any) -> bool:
    """Return True if ``spec`` is a ``ta.Literal[...]`` alias."""
    if not hasattr(ta, '_LiteralGenericAlias'):
        # Older typing versions represent Literal[...] as a plain _GenericAlias whose origin is ta.Literal.
        return isinstance(spec, ta._GenericAlias) and spec.__origin__ is ta.Literal  # type: ignore  # noqa
    return isinstance(spec, ta._LiteralGenericAlias)  # noqa


def get_literal_type_args(spec: ta.Any) -> ta.Iterable[ta.Any]:
    """Return the allowed values of a ``ta.Literal[...]`` alias."""
    values = spec.__args__
    return values
831
+
832
+
833
+ ########################################
834
+ # ../strings.py
835
+
836
+
837
+ ##
838
+
839
+
840
def camel_case(name: str, *, lower: bool = False) -> str:
    """
    Convert snake_case ``name`` to CamelCase, or to lowerCamelCase when ``lower`` is set.

    Each underscore-separated part goes through ``str.capitalize``, so the rest of every part is lowercased
    (``'foo_BAR'`` -> ``'FooBar'``). Names made only of underscores produce ``''``.
    """
    if not name:
        return ''
    s = ''.join(map(str.capitalize, name.split('_')))  # noqa
    if lower:
        # Slice rather than index: an all-underscore name leaves s empty here, and s[0] would raise IndexError.
        s = s[:1].lower() + s[1:]
    return s
847
+
848
+
849
def snake_case(name: str) -> str:
    """
    Convert CamelCase ``name`` to snake_case by splitting before every uppercase character.

    Consecutive capitals are split individually (``'HTTPServer'`` -> ``'h_t_t_p_server'``), and leading/trailing
    underscores are stripped from the result.
    """
    # ta.* generics rather than `list[int | None]`, consistent with the rest of this 3.8-compatible file.
    uppers: ta.List[ta.Optional[int]] = [i for i, c in enumerate(name) if c.isupper()]
    return '_'.join([name[lo:hi].lower() for lo, hi in zip([None, *uppers], [*uppers, None])]).strip('_')
852
+
853
+
854
+ ##
855
+
856
+
857
def is_dunder(name: str) -> bool:
    """Return True for ``__x__``-style names: exactly two underscores on each side around a non-empty middle."""
    if len(name) <= 4:
        return False
    if not (name.startswith('__') and name.endswith('__')):
        return False
    return name[2] != '_' and name[-3] != '_'
864
+
865
+
866
def is_sunder(name: str) -> bool:
    """
    Return True for ``_x_``-style names: exactly one underscore on each side around a non-empty middle.

    Returns False for the empty string instead of raising IndexError.
    """
    return (
        # Slices rather than name[0] / name[-1], which raise IndexError on ''.
        name[:1] == name[-1:] == '_' and
        name[1:2] != '_' and
        name[-2:-1] != '_' and
        len(name) > 2
    )
873
+
874
+
875
+ ##
876
+
877
+
878
def strip_with_newline(s: str) -> str:
    """Strip surrounding whitespace from ``s`` and append one newline; a falsy ``s`` yields ``''``."""
    return f'{s.strip()}\n' if s else ''
882
+
883
+
884
@ta.overload
def split_keep_delimiter(s: str, d: str) -> ta.List[str]:
    ...


@ta.overload
def split_keep_delimiter(s: bytes, d: bytes) -> ta.List[bytes]:
    ...


def split_keep_delimiter(s, d):
    """
    Split ``s`` after each occurrence of ``d``, keeping the whole delimiter at the end of each piece.

    Works for both ``str`` and ``bytes``. A trailing empty piece is not produced, so ``'a,b,'`` split on ``','`` gives
    ``['a,', 'b,']``. An empty ``d`` splits ``s`` into single-element pieces.
    """
    # Advance past the entire delimiter (not just one element), so multi-character delimiters like '\r\n' stay intact.
    # An empty delimiter matches at every position; step one element so the loop still makes progress.
    step = len(d) or 1
    ps = []
    i = 0
    while i < len(s):
        if (n := s.find(d, i)) < i:
            ps.append(s[i:])
            break
        ps.append(s[i:n + step])
        i = n + step
    return ps
904
+
905
+
906
+ ##
907
+
908
+
909
FORMAT_NUM_BYTES_SUFFIXES: ta.Sequence[str] = ['B', 'kB', 'MB', 'GB', 'TB', 'PB', 'EB']


def format_num_bytes(num_bytes: int) -> str:
    """
    Render a byte count using 1024-based units, e.g. ``1536`` -> ``'1.50kB'``.

    Whole values are printed without decimals (``'1kB'``); others use two decimal places. Counts beyond the largest
    suffix are expressed in that suffix.
    """
    for exp, suffix in enumerate(FORMAT_NUM_BYTES_SUFFIXES):
        if not num_bytes < 1024 ** (exp + 1):
            continue
        scaled = num_bytes / 1024 ** exp
        return f'{int(scaled)}{suffix}' if scaled.is_integer() else f'{scaled:.2f}{suffix}'

    top = len(FORMAT_NUM_BYTES_SUFFIXES) - 1
    return f'{num_bytes / 1024 ** top:.2f}{FORMAT_NUM_BYTES_SUFFIXES[-1]}'
922
+
923
+
924
+ ########################################
925
+ # marshal.py
926
+
927
+
928
+ ##
929
+
930
+
931
@dc.dataclass(frozen=True)
class ObjMarshalOptions:
    """Immutable options consulted by marshalers through their context."""

    # When set, BytesSwitchedObjMarshaler passes values through unchanged instead of using its wrapped marshaler.
    raw_bytes: bool = dc.field(default=False)
    # When set, FieldsObjMarshaler skips unknown keys while unmarshaling instead of raising KeyError.
    non_strict_fields: bool = dc.field(default=False)
935
+
936
+
937
class ObjMarshaler(Abstract):
    """Interface for converting objects to and from their marshaled representation."""

    @abc.abstractmethod
    def marshal(self, o: ta.Any, ctx: 'ObjMarshalContext') -> ta.Any:
        """Convert ``o`` into its marshaled form."""
        raise NotImplementedError

    @abc.abstractmethod
    def unmarshal(self, o: ta.Any, ctx: 'ObjMarshalContext') -> ta.Any:
        """Rebuild an object from its marshaled form ``o``."""
        raise NotImplementedError
945
+
946
+
947
class NopObjMarshaler(ObjMarshaler):
    """Identity marshaler: values pass through unchanged in both directions."""

    def marshal(self, o: ta.Any, ctx: 'ObjMarshalContext') -> ta.Any:
        return o

    # Unmarshaling is the same identity operation.
    def unmarshal(self, o: ta.Any, ctx: 'ObjMarshalContext') -> ta.Any:
        return o
953
+
954
+
955
class ProxyObjMarshaler(ObjMarshaler):
    """
    Delegates to the marshaler stored in ``_m``.

    ``_m`` may be None at construction and set later, presumably so that recursive types can reference a marshaler
    before it is built. Using the proxy while it is still unset fails through ``check.not_none``.
    """

    def __init__(self, m: ta.Optional[ObjMarshaler] = None) -> None:
        super().__init__()

        self._m = m

    def _target(self) -> ObjMarshaler:
        # Resolved on every call so a target assigned after construction is picked up.
        return check.not_none(self._m)

    def marshal(self, o: ta.Any, ctx: 'ObjMarshalContext') -> ta.Any:
        return self._target().marshal(o, ctx)

    def unmarshal(self, o: ta.Any, ctx: 'ObjMarshalContext') -> ta.Any:
        return self._target().unmarshal(o, ctx)
966
+
967
+
968
class CastObjMarshaler(ObjMarshaler):
    """Marshals values unchanged; unmarshals by calling the target type on the raw value (e.g. ``int(o)``)."""

    def __init__(self, ty: type) -> None:
        super().__init__()

        self._ty = ty

    def marshal(self, o: ta.Any, ctx: 'ObjMarshalContext') -> ta.Any:
        return o

    def unmarshal(self, o: ta.Any, ctx: 'ObjMarshalContext') -> ta.Any:
        target = self._ty
        return target(o)
979
+
980
+
981
class DynamicObjMarshaler(ObjMarshaler):
    """
    Marshals by delegating to ``ctx.manager.marshal_obj`` with the context's options (presumably resolving a marshaler
    from the value itself); unmarshaling returns the value unchanged.
    """

    def marshal(self, o: ta.Any, ctx: 'ObjMarshalContext') -> ta.Any:
        manager = ctx.manager
        return manager.marshal_obj(o, opts=ctx.options)

    def unmarshal(self, o: ta.Any, ctx: 'ObjMarshalContext') -> ta.Any:
        return o
987
+
988
+
989
class Base64ObjMarshaler(ObjMarshaler):
    """Marshals bytes-like values as ASCII base64 text; unmarshals by decoding and wrapping in the target type."""

    def __init__(self, ty: type) -> None:
        super().__init__()

        self._ty = ty

    def marshal(self, o: ta.Any, ctx: 'ObjMarshalContext') -> ta.Any:
        encoded = base64.b64encode(o)
        return encoded.decode('ascii')

    def unmarshal(self, o: ta.Any, ctx: 'ObjMarshalContext') -> ta.Any:
        target = self._ty
        return target(base64.b64decode(o))
1000
+
1001
+
1002
class BytesSwitchedObjMarshaler(ObjMarshaler):
    """Wraps another marshaler but bypasses it entirely, in both directions, when ``ctx.options.raw_bytes`` is set."""

    def __init__(self, m: ObjMarshaler) -> None:
        super().__init__()

        self._m = m

    def marshal(self, o: ta.Any, ctx: 'ObjMarshalContext') -> ta.Any:
        return o if ctx.options.raw_bytes else self._m.marshal(o, ctx)

    def unmarshal(self, o: ta.Any, ctx: 'ObjMarshalContext') -> ta.Any:
        return o if ctx.options.raw_bytes else self._m.unmarshal(o, ctx)
1017
+
1018
+
1019
class EnumObjMarshaler(ObjMarshaler):
    """Marshals enum members by ``name``; unmarshals by looking that name up in the enum's ``__members__``."""

    def __init__(self, ty: type) -> None:
        super().__init__()

        self._ty = ty

    def marshal(self, o: ta.Any, ctx: 'ObjMarshalContext') -> ta.Any:
        member_name = o.name
        return member_name

    def unmarshal(self, o: ta.Any, ctx: 'ObjMarshalContext') -> ta.Any:
        members = self._ty.__members__  # type: ignore
        return members[o]
1030
+
1031
+
1032
class OptionalObjMarshaler(ObjMarshaler):
    """Passes None through untouched in both directions and delegates every other value to the item marshaler."""

    def __init__(self, item: ObjMarshaler) -> None:
        super().__init__()

        self._item = item

    def marshal(self, o: ta.Any, ctx: 'ObjMarshalContext') -> ta.Any:
        return None if o is None else self._item.marshal(o, ctx)

    def unmarshal(self, o: ta.Any, ctx: 'ObjMarshalContext') -> ta.Any:
        return None if o is None else self._item.unmarshal(o, ctx)
1047
+
1048
+
1049
class PrimitiveUnionObjMarshaler(ObjMarshaler):
    """
    Marshaler for unions of primitive types.

    Values that are instances of one of ``pt`` pass through unchanged, and anything else goes to the optional fallback
    marshaler ``x``. Without a fallback, a non-matching value raises TypeError.
    """

    def __init__(
        self,
        pt: ta.Tuple[type, ...],
        x: ta.Optional[ObjMarshaler] = None,
    ) -> None:
        super().__init__()

        self._pt = pt
        self._x = x

    def marshal(self, o: ta.Any, ctx: 'ObjMarshalContext') -> ta.Any:
        if not isinstance(o, self._pt):
            if self._x is None:
                raise TypeError(o)
            return self._x.marshal(o, ctx)
        return o

    def unmarshal(self, o: ta.Any, ctx: 'ObjMarshalContext') -> ta.Any:
        if not isinstance(o, self._pt):
            if self._x is None:
                raise TypeError(o)
            return self._x.unmarshal(o, ctx)
        return o
1075
+
1076
+
1077
class LiteralObjMarshaler(ObjMarshaler):
    """Delegates to ``item``, checking via ``check.in_`` that values belong to the allowed literal set ``vs``."""

    def __init__(
        self,
        item: ObjMarshaler,
        vs: frozenset,
    ) -> None:
        super().__init__()

        self._item = item
        self._vs = vs

    def marshal(self, o: ta.Any, ctx: 'ObjMarshalContext') -> ta.Any:
        # Validate the Python-side value before marshaling it.
        allowed = check.in_(o, self._vs)
        return self._item.marshal(allowed, ctx)

    def unmarshal(self, o: ta.Any, ctx: 'ObjMarshalContext') -> ta.Any:
        # Validate the value only after it has been unmarshaled.
        value = self._item.unmarshal(o, ctx)
        return check.in_(value, self._vs)
1093
+
1094
+
1095
class MappingObjMarshaler(ObjMarshaler):
    """Marshals mappings to plain dicts via key/value marshalers ``km``/``vm``; unmarshals into mapping type ``ty``."""

    def __init__(
        self,
        ty: type,
        km: ObjMarshaler,
        vm: ObjMarshaler,
    ) -> None:
        super().__init__()

        self._ty = ty
        self._km = km
        self._vm = vm

    def marshal(self, o: ta.Any, ctx: 'ObjMarshalContext') -> ta.Any:
        out = {}
        for key, value in o.items():
            # Key before value, matching dict-comprehension evaluation order.
            marshaled_key = self._km.marshal(key, ctx)
            out[marshaled_key] = self._vm.marshal(value, ctx)
        return out

    def unmarshal(self, o: ta.Any, ctx: 'ObjMarshalContext') -> ta.Any:
        km, vm = self._km, self._vm
        return self._ty((km.unmarshal(key, ctx), vm.unmarshal(value, ctx)) for key, value in o.items())
1113
+
1114
+
1115
class IterableObjMarshaler(ObjMarshaler):
    """Marshals iterables to lists of item-marshaled elements; unmarshals by building ``ty`` from the elements."""

    def __init__(
        self,
        ty: type,
        item: ObjMarshaler,
    ) -> None:
        super().__init__()

        self._ty = ty
        self._item = item

    def marshal(self, o: ta.Any, ctx: 'ObjMarshalContext') -> ta.Any:
        elem_m = self._item
        return [elem_m.marshal(elem, ctx) for elem in o]

    def unmarshal(self, o: ta.Any, ctx: 'ObjMarshalContext') -> ta.Any:
        elem_m = self._item
        return self._ty(elem_m.unmarshal(elem, ctx) for elem in o)
1131
+
1132
+
1133
class FieldsObjMarshaler(ObjMarshaler):
    """
    Marshals objects attribute-by-attribute into a dict keyed by each field's ``key``.

    Unmarshals such dicts by calling ``ty`` with the corresponding attribute names as keyword arguments. Unknown keys
    raise KeyError unless ``non_strict`` or ``ctx.options.non_strict_fields`` is set.
    """

    @dc.dataclass(frozen=True)
    class Field:
        # Attribute read when marshaling; also the keyword passed to ``ty`` when unmarshaling.
        att: str
        # Key under which the field appears in the marshaled dict.
        key: str
        # Marshaler applied to the field's value.
        m: ObjMarshaler

        # When set, the key is omitted from marshaled output if the marshaled value is None.
        omit_if_none: bool = False

    def __init__(
        self,
        ty: type,
        fs: ta.Sequence[Field],
        *,
        non_strict: bool = False,
    ) -> None:
        super().__init__()

        self._ty = ty
        self._fs = fs
        self._non_strict = non_strict

        # Index fields both by attribute and by key, rejecting empty or duplicate names in either.
        by_att: dict = {}
        by_key: dict = {}
        for fld in self._fs:
            check.not_in(check.non_empty_str(fld.att), by_att)
            check.not_in(check.non_empty_str(fld.key), by_key)
            by_att[fld.att] = fld
            by_key[fld.key] = fld

        self._fs_by_att: ta.Mapping[str, FieldsObjMarshaler.Field] = by_att
        self._fs_by_key: ta.Mapping[str, FieldsObjMarshaler.Field] = by_key

    @property
    def ty(self) -> type:
        return self._ty

    @property
    def fs(self) -> ta.Sequence[Field]:
        return self._fs

    #

    def marshal(self, o: ta.Any, ctx: 'ObjMarshalContext') -> ta.Any:
        out = {}
        for fld in self._fs:
            value = fld.m.marshal(getattr(o, fld.att), ctx)
            if fld.omit_if_none and value is None:
                continue
            out[fld.key] = value
        return out

    def unmarshal(self, o: ta.Any, ctx: 'ObjMarshalContext') -> ta.Any:
        kwargs = {}
        for key, raw in o.items():
            fld = self._fs_by_key.get(key)
            if fld is not None:
                kwargs[fld.att] = fld.m.unmarshal(raw, ctx)
            elif not (self._non_strict or ctx.options.non_strict_fields):
                # ctx.options is only consulted once an unknown key is actually encountered.
                raise KeyError(key)
        return self._ty(**kwargs)
1194
+
1195
+
1196
+ class SingleFieldObjMarshaler(ObjMarshaler):
1197
+ def __init__(
1198
+ self,
1199
+ ty: type,
1200
+ fld: str,
1201
+ ) -> None:
1202
+ super().__init__()
1203
+
1204
+ self._ty = ty
1205
+ self._fld = fld
1206
+
1207
+ def marshal(self, o: ta.Any, ctx: 'ObjMarshalContext') -> ta.Any:
1208
+ return getattr(o, self._fld)
1209
+
1210
+ def unmarshal(self, o: ta.Any, ctx: 'ObjMarshalContext') -> ta.Any:
1211
+ return self._ty(**{self._fld: o})
1212
+
1213
+
1214
+ class PolymorphicObjMarshaler(ObjMarshaler):
1215
+ class Impl(ta.NamedTuple):
1216
+ ty: type
1217
+ tag: str
1218
+ m: ObjMarshaler
1219
+
1220
+ def __init__(
1221
+ self,
1222
+ impls_by_ty: ta.Mapping[type, Impl],
1223
+ impls_by_tag: ta.Mapping[str, Impl],
1224
+ ) -> None:
1225
+ super().__init__()
1226
+
1227
+ self._impls_by_ty = impls_by_ty
1228
+ self._impls_by_tag = impls_by_tag
1229
+
1230
+ @classmethod
1231
+ def of(cls, impls: ta.Iterable[Impl]) -> 'PolymorphicObjMarshaler':
1232
+ return cls(
1233
+ {i.ty: i for i in impls},
1234
+ {i.tag: i for i in impls},
1235
+ )
1236
+
1237
+ def marshal(self, o: ta.Any, ctx: 'ObjMarshalContext') -> ta.Any:
1238
+ impl = self._impls_by_ty[type(o)]
1239
+ return {impl.tag: impl.m.marshal(o, ctx)}
1240
+
1241
+ def unmarshal(self, o: ta.Any, ctx: 'ObjMarshalContext') -> ta.Any:
1242
+ [(t, v)] = o.items()
1243
+ impl = self._impls_by_tag[t]
1244
+ return impl.m.unmarshal(v, ctx)
1245
+
1246
+
1247
+ class DatetimeObjMarshaler(ObjMarshaler):
1248
+ def __init__(
1249
+ self,
1250
+ ty: type,
1251
+ ) -> None:
1252
+ super().__init__()
1253
+
1254
+ self._ty = ty
1255
+
1256
+ def marshal(self, o: ta.Any, ctx: 'ObjMarshalContext') -> ta.Any:
1257
+ return o.isoformat()
1258
+
1259
+ def unmarshal(self, o: ta.Any, ctx: 'ObjMarshalContext') -> ta.Any:
1260
+ return self._ty.fromisoformat(o) # type: ignore
1261
+
1262
+
1263
+ class DecimalObjMarshaler(ObjMarshaler):
1264
+ def marshal(self, o: ta.Any, ctx: 'ObjMarshalContext') -> ta.Any:
1265
+ return str(check.isinstance(o, decimal.Decimal))
1266
+
1267
+ def unmarshal(self, v: ta.Any, ctx: 'ObjMarshalContext') -> ta.Any:
1268
+ return decimal.Decimal(check.isinstance(v, str))
1269
+
1270
+
1271
+ class FractionObjMarshaler(ObjMarshaler):
1272
+ def marshal(self, o: ta.Any, ctx: 'ObjMarshalContext') -> ta.Any:
1273
+ fr = check.isinstance(o, fractions.Fraction)
1274
+ return [fr.numerator, fr.denominator]
1275
+
1276
+ def unmarshal(self, v: ta.Any, ctx: 'ObjMarshalContext') -> ta.Any:
1277
+ num, denom = check.isinstance(v, list)
1278
+ return fractions.Fraction(num, denom)
1279
+
1280
+
1281
+ class UuidObjMarshaler(ObjMarshaler):
1282
+ def marshal(self, o: ta.Any, ctx: 'ObjMarshalContext') -> ta.Any:
1283
+ return str(o)
1284
+
1285
+ def unmarshal(self, o: ta.Any, ctx: 'ObjMarshalContext') -> ta.Any:
1286
+ return uuid.UUID(o)
1287
+
1288
+
1289
+ ##
1290
+
1291
+
1292
+ _DEFAULT_OBJ_MARSHALERS: ta.Dict[ta.Any, ObjMarshaler] = {
1293
+ **{t: NopObjMarshaler() for t in (type(None),)},
1294
+ **{t: CastObjMarshaler(t) for t in (int, float, str, bool)},
1295
+ **{t: BytesSwitchedObjMarshaler(Base64ObjMarshaler(t)) for t in (bytes, bytearray)},
1296
+ **{t: IterableObjMarshaler(t, DynamicObjMarshaler()) for t in (list, tuple, set, frozenset)},
1297
+ **{t: MappingObjMarshaler(t, DynamicObjMarshaler(), DynamicObjMarshaler()) for t in (dict,)},
1298
+
1299
+ **{t: DynamicObjMarshaler() for t in (ta.Any, object)},
1300
+
1301
+ **{t: DatetimeObjMarshaler(t) for t in (datetime.date, datetime.time, datetime.datetime)},
1302
+ decimal.Decimal: DecimalObjMarshaler(),
1303
+ fractions.Fraction: FractionObjMarshaler(),
1304
+ uuid.UUID: UuidObjMarshaler(),
1305
+ }
1306
+
1307
+ _OBJ_MARSHALER_GENERIC_MAPPING_TYPES: ta.Dict[ta.Any, type] = {
1308
+ **{t: t for t in (dict,)},
1309
+ **{t: dict for t in (collections.abc.Mapping, collections.abc.MutableMapping)}, # noqa
1310
+ }
1311
+
1312
+ _OBJ_MARSHALER_GENERIC_ITERABLE_TYPES: ta.Dict[ta.Any, type] = {
1313
+ **{t: t for t in (list, tuple, set, frozenset)},
1314
+ collections.abc.Set: frozenset,
1315
+ collections.abc.MutableSet: set,
1316
+ collections.abc.Sequence: tuple,
1317
+ collections.abc.MutableSequence: list,
1318
+ }
1319
+
1320
+ _OBJ_MARSHALER_PRIMITIVE_TYPES: ta.Set[type] = {
1321
+ int,
1322
+ float,
1323
+ bool,
1324
+ str,
1325
+ }
1326
+
1327
+
1328
+ ##
1329
+
1330
+
1331
+ _REGISTERED_OBJ_MARSHALERS_BY_TYPE: ta.MutableMapping[type, ObjMarshaler] = weakref.WeakKeyDictionary()
1332
+
1333
+
1334
+ def register_type_obj_marshaler(ty: type, om: ObjMarshaler) -> None:
1335
+ _REGISTERED_OBJ_MARSHALERS_BY_TYPE[ty] = om
1336
+
1337
+
1338
+ def register_single_field_type_obj_marshaler(fld, ty=None):
1339
+ def inner(ty): # noqa
1340
+ register_type_obj_marshaler(ty, SingleFieldObjMarshaler(ty, fld))
1341
+ return ty
1342
+
1343
+ if ty is not None:
1344
+ return inner(ty)
1345
+ else:
1346
+ return inner
1347
+
1348
+
1349
+ ##
1350
+
1351
+
1352
+ class ObjMarshalerFieldMetadata:
1353
+ def __new__(cls, *args, **kwargs): # noqa
1354
+ raise TypeError
1355
+
1356
+
1357
+ class OBJ_MARSHALER_FIELD_KEY(ObjMarshalerFieldMetadata): # noqa
1358
+ pass
1359
+
1360
+
1361
+ class OBJ_MARSHALER_OMIT_IF_NONE(ObjMarshalerFieldMetadata): # noqa
1362
+ pass
1363
+
1364
+
1365
+ ##
1366
+
1367
+
1368
+ class ObjMarshalerManager(Abstract):
1369
+ @abc.abstractmethod
1370
+ def make_obj_marshaler(
1371
+ self,
1372
+ ty: ta.Any,
1373
+ rec: ta.Callable[[ta.Any], ObjMarshaler],
1374
+ *,
1375
+ non_strict_fields: bool = False,
1376
+ ) -> ObjMarshaler:
1377
+ raise NotImplementedError
1378
+
1379
+ @abc.abstractmethod
1380
+ def set_obj_marshaler(
1381
+ self,
1382
+ ty: ta.Any,
1383
+ m: ObjMarshaler,
1384
+ *,
1385
+ override: bool = False,
1386
+ ) -> None:
1387
+ raise NotImplementedError
1388
+
1389
+ @abc.abstractmethod
1390
+ def get_obj_marshaler(
1391
+ self,
1392
+ ty: ta.Any,
1393
+ *,
1394
+ no_cache: bool = False,
1395
+ **kwargs: ta.Any,
1396
+ ) -> ObjMarshaler:
1397
+ raise NotImplementedError
1398
+
1399
+ @abc.abstractmethod
1400
+ def make_context(self, opts: ta.Optional[ObjMarshalOptions]) -> 'ObjMarshalContext':
1401
+ raise NotImplementedError
1402
+
1403
+ #
1404
+
1405
+ def marshal_obj(
1406
+ self,
1407
+ o: ta.Any,
1408
+ ty: ta.Any = None,
1409
+ opts: ta.Optional[ObjMarshalOptions] = None,
1410
+ ) -> ta.Any:
1411
+ m = self.get_obj_marshaler(ty if ty is not None else type(o))
1412
+ return m.marshal(o, self.make_context(opts))
1413
+
1414
+ def unmarshal_obj(
1415
+ self,
1416
+ o: ta.Any,
1417
+ ty: ta.Union[ta.Type[T], ta.Any],
1418
+ opts: ta.Optional[ObjMarshalOptions] = None,
1419
+ ) -> T:
1420
+ m = self.get_obj_marshaler(ty)
1421
+ return m.unmarshal(o, self.make_context(opts))
1422
+
1423
+ def roundtrip_obj(
1424
+ self,
1425
+ o: ta.Any,
1426
+ ty: ta.Any = None,
1427
+ opts: ta.Optional[ObjMarshalOptions] = None,
1428
+ ) -> ta.Any:
1429
+ if ty is None:
1430
+ ty = type(o)
1431
+ m: ta.Any = self.marshal_obj(o, ty, opts)
1432
+ u: ta.Any = self.unmarshal_obj(m, ty, opts)
1433
+ return u
1434
+
1435
+
1436
+ #
1437
+
1438
+
1439
+ class ObjMarshalerManagerImpl(ObjMarshalerManager):
1440
+ def __init__(
1441
+ self,
1442
+ *,
1443
+ default_options: ObjMarshalOptions = ObjMarshalOptions(),
1444
+
1445
+ default_obj_marshalers: ta.Dict[ta.Any, ObjMarshaler] = _DEFAULT_OBJ_MARSHALERS, # noqa
1446
+ generic_mapping_types: ta.Dict[ta.Any, type] = _OBJ_MARSHALER_GENERIC_MAPPING_TYPES, # noqa
1447
+ generic_iterable_types: ta.Dict[ta.Any, type] = _OBJ_MARSHALER_GENERIC_ITERABLE_TYPES, # noqa
1448
+
1449
+ registered_obj_marshalers: ta.Mapping[type, ObjMarshaler] = _REGISTERED_OBJ_MARSHALERS_BY_TYPE,
1450
+ ) -> None:
1451
+ super().__init__()
1452
+
1453
+ self._default_options = default_options
1454
+
1455
+ self._obj_marshalers = dict(default_obj_marshalers)
1456
+ self._generic_mapping_types = generic_mapping_types
1457
+ self._generic_iterable_types = generic_iterable_types
1458
+ self._registered_obj_marshalers = registered_obj_marshalers
1459
+
1460
+ self._lock = threading.RLock()
1461
+ self._marshalers: ta.Dict[ta.Any, ObjMarshaler] = dict(_DEFAULT_OBJ_MARSHALERS)
1462
+ self._proxies: ta.Dict[ta.Any, ProxyObjMarshaler] = {}
1463
+
1464
+ #
1465
+
1466
+ @classmethod
1467
+ def _is_abstract(cls, ty: type) -> bool:
1468
+ return abc.ABC in ty.__bases__ or Abstract in ty.__bases__
1469
+
1470
+ #
1471
+
1472
+ def make_obj_marshaler(
1473
+ self,
1474
+ ty: ta.Any,
1475
+ rec: ta.Callable[[ta.Any], ObjMarshaler],
1476
+ *,
1477
+ non_strict_fields: bool = False,
1478
+ ) -> ObjMarshaler:
1479
+ if isinstance(ty, type):
1480
+ if (reg := self._registered_obj_marshalers.get(ty)) is not None:
1481
+ return reg
1482
+
1483
+ if self._is_abstract(ty):
1484
+ tn = ty.__name__
1485
+ impls: ta.List[ta.Tuple[type, str]] = [ # type: ignore[var-annotated]
1486
+ (ity, ity.__name__)
1487
+ for ity in deep_subclasses(ty)
1488
+ if not self._is_abstract(ity)
1489
+ ]
1490
+
1491
+ if all(itn.endswith(tn) for _, itn in impls):
1492
+ impls = [
1493
+ (ity, snake_case(itn[:-len(tn)]))
1494
+ for ity, itn in impls
1495
+ ]
1496
+
1497
+ dupe_tns = sorted(
1498
+ dn
1499
+ for dn, dc in collections.Counter(itn for _, itn in impls).items()
1500
+ if dc > 1
1501
+ )
1502
+ if dupe_tns:
1503
+ raise KeyError(f'Duplicate impl names for {ty}: {dupe_tns}')
1504
+
1505
+ return PolymorphicObjMarshaler.of([
1506
+ PolymorphicObjMarshaler.Impl(
1507
+ ity,
1508
+ itn,
1509
+ rec(ity),
1510
+ )
1511
+ for ity, itn in impls
1512
+ ])
1513
+
1514
+ if issubclass(ty, enum.Enum):
1515
+ return EnumObjMarshaler(ty)
1516
+
1517
+ if dc.is_dataclass(ty):
1518
+ return FieldsObjMarshaler(
1519
+ ty,
1520
+ [
1521
+ FieldsObjMarshaler.Field(
1522
+ att=f.name,
1523
+ key=check.non_empty_str(fk),
1524
+ m=rec(f.type),
1525
+ omit_if_none=check.isinstance(f.metadata.get(OBJ_MARSHALER_OMIT_IF_NONE, False), bool),
1526
+ )
1527
+ for f in dc.fields(ty)
1528
+ if (fk := f.metadata.get(OBJ_MARSHALER_FIELD_KEY, f.name)) is not None
1529
+ ],
1530
+ non_strict=non_strict_fields,
1531
+ )
1532
+
1533
+ if issubclass(ty, tuple) and hasattr(ty, '_fields'):
1534
+ return FieldsObjMarshaler(
1535
+ ty,
1536
+ [
1537
+ FieldsObjMarshaler.Field(
1538
+ att=p.name,
1539
+ key=p.name,
1540
+ m=rec(p.annotation),
1541
+ )
1542
+ for p in inspect.signature(ty).parameters.values()
1543
+ ],
1544
+ non_strict=non_strict_fields,
1545
+ )
1546
+
1547
+ if is_new_type(ty):
1548
+ return rec(get_new_type_supertype(ty))
1549
+
1550
+ if is_literal_type(ty):
1551
+ lvs = frozenset(get_literal_type_args(ty))
1552
+ if None in lvs:
1553
+ is_opt = True
1554
+ lvs -= frozenset([None])
1555
+ else:
1556
+ is_opt = False
1557
+ lty = check.single(set(map(type, lvs)))
1558
+ lm: ObjMarshaler = LiteralObjMarshaler(rec(lty), lvs)
1559
+ if is_opt:
1560
+ lm = OptionalObjMarshaler(lm)
1561
+ return lm
1562
+
1563
+ if is_generic_alias(ty):
1564
+ try:
1565
+ mt = self._generic_mapping_types[ta.get_origin(ty)]
1566
+ except KeyError:
1567
+ pass
1568
+ else:
1569
+ k, v = ta.get_args(ty)
1570
+ return MappingObjMarshaler(mt, rec(k), rec(v))
1571
+
1572
+ try:
1573
+ st = self._generic_iterable_types[ta.get_origin(ty)]
1574
+ except KeyError:
1575
+ pass
1576
+ else:
1577
+ [e] = ta.get_args(ty)
1578
+ return IterableObjMarshaler(st, rec(e))
1579
+
1580
+ if is_union_alias(ty):
1581
+ uts = frozenset(ta.get_args(ty))
1582
+ if None in uts or type(None) in uts:
1583
+ is_opt = True
1584
+ uts = frozenset(ut for ut in uts if ut not in (None, type(None)))
1585
+ else:
1586
+ is_opt = False
1587
+
1588
+ um: ObjMarshaler
1589
+ if not uts:
1590
+ raise TypeError(ty)
1591
+ elif len(uts) == 1:
1592
+ um = rec(check.single(uts))
1593
+ else:
1594
+ pt = tuple({ut for ut in uts if ut in _OBJ_MARSHALER_PRIMITIVE_TYPES})
1595
+ np_uts = {ut for ut in uts if ut not in _OBJ_MARSHALER_PRIMITIVE_TYPES}
1596
+ if not np_uts:
1597
+ um = PrimitiveUnionObjMarshaler(pt)
1598
+ elif len(np_uts) == 1:
1599
+ um = PrimitiveUnionObjMarshaler(pt, x=rec(check.single(np_uts)))
1600
+ else:
1601
+ raise TypeError(ty)
1602
+
1603
+ if is_opt:
1604
+ um = OptionalObjMarshaler(um)
1605
+ return um
1606
+
1607
+ raise TypeError(ty)
1608
+
1609
+ #
1610
+
1611
+ def set_obj_marshaler(
1612
+ self,
1613
+ ty: ta.Any,
1614
+ m: ObjMarshaler,
1615
+ *,
1616
+ override: bool = False,
1617
+ ) -> None:
1618
+ with self._lock:
1619
+ if not override and ty in self._obj_marshalers:
1620
+ raise KeyError(ty)
1621
+ self._obj_marshalers[ty] = m
1622
+
1623
+ def get_obj_marshaler(
1624
+ self,
1625
+ ty: ta.Any,
1626
+ *,
1627
+ no_cache: bool = False,
1628
+ **kwargs: ta.Any,
1629
+ ) -> ObjMarshaler:
1630
+ with self._lock:
1631
+ if not no_cache:
1632
+ try:
1633
+ return self._obj_marshalers[ty]
1634
+ except KeyError:
1635
+ pass
1636
+
1637
+ try:
1638
+ return self._proxies[ty]
1639
+ except KeyError:
1640
+ pass
1641
+
1642
+ rec = functools.partial(
1643
+ self.get_obj_marshaler,
1644
+ no_cache=no_cache,
1645
+ **kwargs,
1646
+ )
1647
+
1648
+ p = ProxyObjMarshaler()
1649
+ self._proxies[ty] = p
1650
+ try:
1651
+ m = self.make_obj_marshaler(ty, rec, **kwargs)
1652
+ finally:
1653
+ del self._proxies[ty]
1654
+ p._m = m # noqa
1655
+
1656
+ if not no_cache:
1657
+ self._obj_marshalers[ty] = m
1658
+ return m
1659
+
1660
+ def make_context(self, opts: ta.Optional[ObjMarshalOptions]) -> 'ObjMarshalContext':
1661
+ return ObjMarshalContext(
1662
+ options=opts or self._default_options,
1663
+ manager=self,
1664
+ )
1665
+
1666
+
1667
+ def new_obj_marshaler_manager(**kwargs: ta.Any) -> ObjMarshalerManager:
1668
+ return ObjMarshalerManagerImpl(**kwargs)
1669
+
1670
+
1671
+ ##
1672
+
1673
+
1674
+ @dc.dataclass(frozen=True)
1675
+ class ObjMarshalContext:
1676
+ options: ObjMarshalOptions
1677
+ manager: ObjMarshalerManager
1678
+
1679
+
1680
+ ##
1681
+
1682
+
1683
+ OBJ_MARSHALER_MANAGER = new_obj_marshaler_manager()
1684
+
1685
+ set_obj_marshaler = OBJ_MARSHALER_MANAGER.set_obj_marshaler
1686
+ get_obj_marshaler = OBJ_MARSHALER_MANAGER.get_obj_marshaler
1687
+
1688
+ marshal_obj = OBJ_MARSHALER_MANAGER.marshal_obj
1689
+ unmarshal_obj = OBJ_MARSHALER_MANAGER.unmarshal_obj