omlish 0.0.0.dev451__py3-none-any.whl → 0.0.0.dev453__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- omlish/__about__.py +2 -2
- omlish/collections/__init__.py +7 -1
- omlish/collections/attrregistry.py +182 -0
- omlish/collections/mappings.py +25 -0
- omlish/diag/pycharm.py +16 -2
- omlish/dispatch/methods.py +50 -140
- omlish/dom/rendering.py +1 -1
- omlish/funcs/guard.py +214 -0
- omlish/funcs/match.py +2 -0
- omlish/graphs/dot/rendering.py +1 -1
- omlish/lite/maybes.py +8 -0
- omlish/marshal/__init__.py +12 -5
- omlish/marshal/base/contexts.py +4 -2
- omlish/marshal/base/registries.py +119 -45
- omlish/marshal/composite/literals.py +6 -2
- omlish/marshal/composite/unions.py +213 -0
- omlish/marshal/factories/moduleimport/factories.py +8 -10
- omlish/marshal/polymorphism/metadata.py +16 -5
- omlish/marshal/polymorphism/standard.py +17 -3
- omlish/marshal/polymorphism/unions.py +0 -129
- omlish/marshal/standard.py +6 -2
- omlish/sql/queries/_marshal.py +1 -1
- omlish/sql/queries/rendering.py +1 -1
- omlish/text/parts.py +2 -2
- omlish/typedvalues/marshal.py +1 -1
- {omlish-0.0.0.dev451.dist-info → omlish-0.0.0.dev453.dist-info}/METADATA +1 -1
- {omlish-0.0.0.dev451.dist-info → omlish-0.0.0.dev453.dist-info}/RECORD +31 -28
- {omlish-0.0.0.dev451.dist-info → omlish-0.0.0.dev453.dist-info}/WHEEL +0 -0
- {omlish-0.0.0.dev451.dist-info → omlish-0.0.0.dev453.dist-info}/entry_points.txt +0 -0
- {omlish-0.0.0.dev451.dist-info → omlish-0.0.0.dev453.dist-info}/licenses/LICENSE +0 -0
- {omlish-0.0.0.dev451.dist-info → omlish-0.0.0.dev453.dist-info}/top_level.txt +0 -0
omlish/funcs/guard.py
ADDED
@@ -0,0 +1,214 @@
|
|
1
|
+
import abc
|
2
|
+
import functools
|
3
|
+
import operator
|
4
|
+
import typing as ta
|
5
|
+
|
6
|
+
from .. import check
|
7
|
+
from .. import collections as col
|
8
|
+
from .. import lang
|
9
|
+
|
10
|
+
|
11
|
+
T = ta.TypeVar('T')
|
12
|
+
T_co = ta.TypeVar('T_co', covariant=True)
|
13
|
+
U = ta.TypeVar('U')
|
14
|
+
P = ta.ParamSpec('P')
|
15
|
+
|
16
|
+
|
17
|
+
##
|
18
|
+
|
19
|
+
|
20
|
+
class GuardFn(ta.Protocol[P, T_co]):
    """Structural type of a guard function.

    Calling a guard function with candidate arguments returns ``None`` when it does not apply, or a zero-argument
    callable that produces the result when it does. Guard functions must also support descriptor binding through
    ``__get__`` so they can be used as class attributes.
    """

    def __get__(self, instance, owner=None) -> ta.Self:
        ...

    def __call__(self, *args: P.args, **kwargs: P.kwargs) -> ta.Callable[[], T_co] | None:
        ...
|
24
|
+
|
25
|
+
|
26
|
+
##
|
27
|
+
|
28
|
+
|
29
|
+
@ta.final
class DumbGuardFn(ta.Generic[P, T]):
    """Adapt an ordinary function into a guard function that always matches.

    Calling an instance never yields ``None``: it returns a zero-argument thunk which, when invoked, calls the wrapped
    function with the arguments originally supplied.
    """

    def __init__(self, fn: ta.Callable[P, T]) -> None:
        self._fn = fn

    def __get__(self, instance, owner=None):
        # Bind the underlying function (e.g. producing a bound method) and re-wrap the bound result.
        bound_fn = self._fn.__get__(instance, owner)  # noqa
        return DumbGuardFn(bound_fn)

    def __call__(self, *args: P.args, **kwargs: P.kwargs) -> ta.Callable[[], T]:
        # Defer the real call: whoever consumes the guard result decides when to run the thunk.
        thunk = functools.partial(self._fn, *args, **kwargs)
        return thunk


dumb = DumbGuardFn
|
42
|
+
|
43
|
+
|
44
|
+
##
|
45
|
+
|
46
|
+
|
47
|
+
class AmbiguousGuardFnError(Exception):
    """Raised by a strict ``MultiGuardFn`` when more than one of its child guard functions matches."""
|
49
|
+
|
50
|
+
|
51
|
+
@ta.final
class MultiGuardFn(ta.Generic[P, T]):
    """Combine several guard functions into a single guard function.

    Non-strict mode (the default): children are tried in order and the first non-``None`` result is returned.
    Strict mode: every child is consulted. Zero matches yields ``None``, one match yields that match, and more than
    one raises ``AmbiguousGuardFnError``, whose args carry the list of competing results.
    """

    def __init__(
            self,
            *children: GuardFn[P, T],
            strict: bool = False,
    ) -> None:
        self._children, self._strict = children, strict

    def __get__(self, instance, owner=None):
        # Bind every child guard and return a new combinator with the same strictness.
        bound_children = map(operator.methodcaller('__get__', instance, owner), self._children)
        return MultiGuardFn(*bound_children, strict=self._strict)  # noqa

    def __call__(self, *args: P.args, **kwargs: P.kwargs) -> ta.Callable[[], T] | None:
        matches: list[ta.Callable[[], T]] = []
        for c in self._children:
            if (m := c(*args, **kwargs)) is None:
                continue
            if not self._strict:
                # First match wins in non-strict mode.
                return m
            matches.append(m)

        if not matches:
            return None
        if len(matches) > 1:
            # Attach the competing results so the ambiguity can be diagnosed from the exception alone.
            raise AmbiguousGuardFnError(matches)
        return matches[0]


multi = MultiGuardFn
|
79
|
+
|
80
|
+
|
81
|
+
##
|
82
|
+
|
83
|
+
|
84
|
+
class _BaseGuardFnMethod(lang.Abstract, ta.Generic[P, T]):
    """Shared machinery for guard-function method descriptors.

    Implementations are attached with ``register``. For each receiving class, ``col.AttrRegistryCache`` collects the
    registered attributes, and ``_prepare`` combines them into a ``MultiGuardFn``. Subclasses implement ``_bind`` to
    turn that combinator into the object returned on instance attribute access.
    """

    def __init__(
            self,
            *,
            strict: bool = False,
            requires_override: bool = False,
            instance_cache: bool = False,
            prototype: ta.Callable[P, T] | None = None,
    ) -> None:
        super().__init__()

        self._strict = strict
        self._instance_cache = instance_cache
        self._prototype = prototype

        self._registry: col.AttrRegistry[ta.Callable, None] = col.AttrRegistry(
            requires_override=requires_override,
        )

        self._cache: col.AttrRegistryCache[ta.Callable, None, MultiGuardFn] = col.AttrRegistryCache(
            self._registry,
            self._prepare,
        )

    _owner: type | None = None
    _name: str | None = None

    def __set_name__(self, owner, name):
        # Only the first assignment is recorded; later re-assignments under other names keep the original values.
        if self._owner is None:
            self._owner = owner
        if self._name is None:
            self._name = name

    def register(self, fn: U) -> U:
        """Register ``fn`` as an implementation and return it unchanged, so this can be used as a decorator."""

        check.callable(fn)
        self._registry.register(ta.cast(ta.Callable, fn), None)
        return fn

    def _prepare(self, instance_cls: type, collected: ta.Mapping[str, tuple[ta.Callable, None]]) -> MultiGuardFn:
        # Each collected attribute name is resolved on the concrete class (through its MRO), in collection order.
        return MultiGuardFn(
            *[getattr(instance_cls, a) for a in collected],
            strict=self._strict,
        )

    @abc.abstractmethod
    def _bind(self, instance, owner):
        raise NotImplementedError

    def __get__(self, instance, owner=None):
        if instance is None:
            return self

        # The per-instance cache stores the bound object in the instance ``__dict__`` under the descriptor's attribute
        # name. If no name was recorded (``__set_name__`` never ran), there is no safe key, so caching is skipped
        # rather than writing to ``instance.__dict__[None]``, a slot every such descriptor would share.
        cache_key = self._name if self._instance_cache else None

        if cache_key is not None:
            try:
                return instance.__dict__[cache_key]
            except KeyError:
                pass

        bound = self._bind(instance, owner)

        if cache_key is not None:
            instance.__dict__[cache_key] = bound

        return bound

    def _call(self, *args, **kwargs):
        # Used when the descriptor is called through the class (``__get__`` returned ``self``): the first positional
        # argument is the instance to bind to.
        if not args:
            raise TypeError(f'{type(self).__name__} called without an instance argument')
        instance, *rest = args
        return self.__get__(instance)(*rest, **kwargs)
|
152
|
+
|
153
|
+
#
|
154
|
+
|
155
|
+
|
156
|
+
class GuardFnMethod(_BaseGuardFnMethod[P, T]):
    """Guard-function method descriptor.

    Instance access yields the bound ``MultiGuardFn`` for the instance's class. Calling that returns a zero-argument
    thunk for the matching implementation, or ``None`` if nothing matched.
    """

    def _bind(self, instance, owner):
        combined = self._cache.get(type(instance))
        return combined.__get__(instance, owner)  # noqa

    def __call__(self, *args: P.args, **kwargs: P.kwargs) -> ta.Callable[[], T] | None:
        # Class-level call, e.g. ``Cls.meth(obj, ...)``; delegates to the shared binding logic.
        return self._call(*args, **kwargs)
|
162
|
+
|
163
|
+
|
164
|
+
def method(
        *,
        strict: bool = False,
        requires_override: bool = False,
        instance_cache: bool = False,
) -> ta.Callable[[ta.Callable[P, T]], GuardFnMethod[P, T]]:  # noqa
    """Decorator factory turning the decorated function into a ``GuardFnMethod``.

    The decorated function is stored as the descriptor's prototype; ``GuardFnMethod`` itself does not call it.
    Implementations are attached with ``@<name>.register``. The keyword options are forwarded unchanged.
    """

    def decorate(prototype):
        return GuardFnMethod(
            prototype=prototype,
            strict=strict,
            requires_override=requires_override,
            instance_cache=instance_cache,
        )

    return decorate
|
179
|
+
|
180
|
+
|
181
|
+
#
|
182
|
+
|
183
|
+
|
184
|
+
class DumbGuardFnMethod(_BaseGuardFnMethod[P, T]):
    """Guard-function method descriptor with a fallback.

    Instance access yields a plain callable. It runs the matching registered implementation (via the class's
    ``MultiGuardFn``, so strict mode may raise ``AmbiguousGuardFnError``). If nothing matches, it calls the prototype
    (the decorated function) instead.
    """

    def _bind(self, instance, owner):
        gf = self._cache.get(type(instance)).__get__(instance, owner)  # noqa

        # The prototype is the fallback implementation. Fail fast via ``check.not_none`` if this descriptor was built
        # without one, instead of an opaque ``AttributeError`` on ``None.__get__``.
        prototype = check.not_none(self._prototype)
        fallback = prototype.__get__(instance, owner)  # type: ignore

        def inner(*args, **kwargs):
            if (m := gf(*args, **kwargs)) is not None:
                return m()
            return fallback(*args, **kwargs)

        return inner

    def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T:
        # Class-level call, e.g. ``Cls.meth(obj, ...)``; delegates to the shared binding logic.
        return self._call(*args, **kwargs)
|
198
|
+
|
199
|
+
|
200
|
+
def dumb_method(
        *,
        strict: bool = False,
        requires_override: bool = False,
        instance_cache: bool = False,
) -> ta.Callable[[ta.Callable[P, T]], DumbGuardFnMethod[P, T]]:  # noqa
    """Decorator factory turning the decorated function into a ``DumbGuardFnMethod``.

    The decorated function becomes the prototype, which is called as the fallback when no registered implementation
    matches. The keyword options are forwarded unchanged.
    """

    def decorate(prototype):
        return DumbGuardFnMethod(
            prototype=prototype,
            strict=strict,
            requires_override=requires_override,
            instance_cache=instance_cache,
        )

    return decorate
|
omlish/funcs/match.py
CHANGED
@@ -5,6 +5,8 @@ TODO:
|
|
5
5
|
- unify MatchFnClass with dispatch.method?
|
6
6
|
- __call__ = mfs.method(); @__call__.register(lambda: ...) def _call_... ?
|
7
7
|
- not really the same thing, dispatch is unordered this is necessarily ordered
|
8
|
+
- !! well... at least unify the interface?
|
9
|
+
- guard(*a, **k) -> bool + fn(*a, **k) -> T becomes dispatch(*a, **k) -> (Callable -> T) | None
|
8
10
|
"""
|
9
11
|
import abc
|
10
12
|
import dataclasses as dc
|
omlish/graphs/dot/rendering.py
CHANGED
omlish/lite/maybes.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1
1
|
# ruff: noqa: UP007 UP045
|
2
2
|
import abc
|
3
3
|
import functools
|
4
|
+
import operator
|
4
5
|
import typing as ta
|
5
6
|
|
6
7
|
from .abstract import Abstract
|
@@ -208,3 +209,10 @@ class _EmptyMaybe(_Maybe[T]):
|
|
208
209
|
|
209
210
|
|
210
211
|
Maybe._empty = _EmptyMaybe()  # noqa


##


# Expose the concrete constructors on the public ``Maybe`` type.
setattr(Maybe, 'just', _JustMaybe)  # noqa

# Wrapped in ``staticmethod`` because ``functools.partial`` is a method descriptor as of Python 3.14 (3.13 emits a
# FutureWarning). Without the wrapper, ``some_maybe.empty()`` would bind the instance as an extra argument. Class
# access (``Maybe.empty()``) behaves the same either way and returns ``Maybe._empty``.
setattr(Maybe, 'empty', staticmethod(functools.partial(operator.attrgetter('_empty'), Maybe)))
|
omlish/marshal/__init__.py
CHANGED
@@ -113,6 +113,17 @@ with _lang.auto_proxy_init(globals()):
|
|
113
113
|
OptionalUnmarshaler,
|
114
114
|
)
|
115
115
|
|
116
|
+
from .composite.unions import ( # noqa
|
117
|
+
MatchUnionMarshaler,
|
118
|
+
MatchUnionUnmarshaler,
|
119
|
+
|
120
|
+
PRIMITIVE_UNION_TYPES,
|
121
|
+
PrimitiveUnionMarshaler,
|
122
|
+
PrimitiveUnionMarshalerFactory,
|
123
|
+
PrimitiveUnionUnmarshaler,
|
124
|
+
PrimitiveUnionUnmarshalerFactory,
|
125
|
+
)
|
126
|
+
|
116
127
|
from .composite.wrapped import ( # noqa
|
117
128
|
WrappedMarshaler,
|
118
129
|
WrappedUnmarshaler,
|
@@ -197,6 +208,7 @@ with _lang.auto_proxy_init(globals()):
|
|
197
208
|
)
|
198
209
|
|
199
210
|
from .polymorphism.metadata import ( # noqa
|
211
|
+
AutoStripSuffix,
|
200
212
|
FieldTypeTagging,
|
201
213
|
Impl,
|
202
214
|
Impls,
|
@@ -212,13 +224,8 @@ with _lang.auto_proxy_init(globals()):
|
|
212
224
|
)
|
213
225
|
|
214
226
|
from .polymorphism.unions import ( # noqa
|
215
|
-
PRIMITIVE_UNION_TYPES,
|
216
227
|
PolymorphismUnionMarshalerFactory,
|
217
228
|
PolymorphismUnionUnmarshalerFactory,
|
218
|
-
PrimitiveUnionMarshaler,
|
219
|
-
PrimitiveUnionMarshalerFactory,
|
220
|
-
PrimitiveUnionUnmarshaler,
|
221
|
-
PrimitiveUnionUnmarshalerFactory,
|
222
229
|
)
|
223
230
|
|
224
231
|
from .polymorphism.unmarshal import ( # noqa
|
omlish/marshal/base/contexts.py
CHANGED
@@ -47,8 +47,9 @@ class MarshalContext(BaseContext, lang.Final):
|
|
47
47
|
|
48
48
|
def make(self, o: ta.Any) -> 'Marshaler':
    """Build a marshaler for ``o`` with this context's factory.

    ``o`` is first passed through ``self._reflect``. The context must have a factory (enforced by ``check.not_none``).
    A ``mfs.MatchGuardError`` from the factory is reported as ``UnhandledTypeError`` for the reflected type.
    """

    reflected_ty = self._reflect(o)
    factory = check.not_none(self.factory)
    try:
        return factory.make_marshaler(self, reflected_ty)
    except mfs.MatchGuardError:
        raise UnhandledTypeError(reflected_ty)  # noqa
|
54
55
|
|
@@ -62,8 +63,9 @@ class UnmarshalContext(BaseContext, lang.Final):
|
|
62
63
|
|
63
64
|
def make(self, o: ta.Any) -> 'Unmarshaler':
    """Build an unmarshaler for ``o`` with this context's factory.

    ``o`` is first passed through ``self._reflect``. The context must have a factory (enforced by ``check.not_none``).
    A ``mfs.MatchGuardError`` from the factory is reported as ``UnhandledTypeError`` for the reflected type.
    """

    reflected_ty = self._reflect(o)
    factory = check.not_none(self.factory)
    try:
        return factory.make_unmarshaler(self, reflected_ty)
    except mfs.MatchGuardError:
        raise UnhandledTypeError(reflected_ty)  # noqa
|
69
71
|
|
@@ -7,7 +7,6 @@ import dataclasses as dc
|
|
7
7
|
import threading
|
8
8
|
import typing as ta
|
9
9
|
|
10
|
-
from ... import collections as col
|
11
10
|
from ... import lang
|
12
11
|
|
13
12
|
|
@@ -22,18 +21,6 @@ RegistryItemT = ta.TypeVar('RegistryItemT', bound=RegistryItem)
|
|
22
21
|
RegistryItemU = ta.TypeVar('RegistryItemU', bound=RegistryItem)
|
23
22
|
|
24
23
|
|
25
|
-
@dc.dataclass(frozen=True)
|
26
|
-
class _KeyRegistryItems(ta.Generic[RegistryItemT]):
|
27
|
-
key: ta.Any
|
28
|
-
items: list[RegistryItemT] = dc.field(default_factory=list)
|
29
|
-
item_lists_by_ty: dict[type[RegistryItemT], list[RegistryItemT]] = dc.field(default_factory=dict)
|
30
|
-
|
31
|
-
def add(self, *items: RegistryItemT) -> None:
|
32
|
-
for i in items:
|
33
|
-
self.items.append(i)
|
34
|
-
self.item_lists_by_ty.setdefault(type(i), []).append(i)
|
35
|
-
|
36
|
-
|
37
24
|
class RegistrySealedError(Exception):
|
38
25
|
pass
|
39
26
|
|
@@ -50,14 +37,123 @@ class Registry(ta.Generic[RegistryItemT]):
|
|
50
37
|
lock = threading.RLock()
|
51
38
|
self._lock = lock
|
52
39
|
|
53
|
-
self.
|
54
|
-
|
40
|
+
self._state: Registry._State[RegistryItemT] = Registry._State(
|
41
|
+
dct={},
|
42
|
+
id_dct={},
|
43
|
+
version=0,
|
44
|
+
)
|
55
45
|
|
56
|
-
self._version = 0
|
57
46
|
self._sealed = False
|
58
47
|
|
59
48
|
#
|
60
49
|
|
50
|
+
@dc.dataclass(frozen=True)
class _KeyItems(ta.Generic[RegistryItemU]):
    """Immutable record of the items registered under one key, plus an index of those items by their exact type."""

    key: ta.Any
    items: ta.Sequence[RegistryItemU] = ()
    item_lists_by_ty: ta.Mapping[type[RegistryItemU], ta.Sequence[RegistryItemU]] = dc.field(default_factory=dict)

    def add(self, *items: RegistryItemU) -> 'Registry._KeyItems[RegistryItemU]':
        """Return a new record with ``items`` appended; ``self`` is left untouched."""

        # Only the per-type lists touched by the new items are copied; all others are shared with ``self``.
        touched: dict[type[RegistryItemU], list[RegistryItemU]] = {}
        for item in items:
            item_ty = type(item)
            if item_ty not in touched:
                touched[item_ty] = list(self.item_lists_by_ty.get(item_ty, ()))
            touched[item_ty].append(item)

        merged_index = {**self.item_lists_by_ty, **touched}
        return Registry._KeyItems(self.key, (*self.items, *items), merged_index)
|
71
|
+
|
72
|
+
@dc.dataclass(frozen=True, kw_only=True)
class _State(ta.Generic[RegistryItemU]):
    """Immutable snapshot of a registry's contents.

    ``register`` never mutates a state; it returns a new one with ``version`` incremented. Lookups made with
    ``identity=None`` are memoized in per-instance caches. Nothing in this class mutates ``dct`` or ``id_dct``.
    """

    dct: ta.Mapping[ta.Any, 'Registry._KeyItems[RegistryItemU]']
    # NOTE(review): ``register`` rebuilds ``id_dct`` as a plain dict, so ``identity=True`` entries are matched by
    # ==/hash just like ``dct``, not by object identity - confirm this is intended.
    id_dct: ta.Mapping[ta.Any, 'Registry._KeyItems[RegistryItemU]']
    version: int

    #

    def register(
            self,
            key: ta.Any,
            *items: RegistryItemT,
            identity: bool = False,
    ) -> 'Registry._State[RegistryItemU]':
        """Return a new state with ``items`` appended under ``key``; ``self`` is unchanged.

        With no items, ``self`` is returned as-is and ``version`` is not bumped.
        """

        if not items:
            return self

        sr_dct: ta.Any = self.dct if not identity else self.id_dct
        if (sr := sr_dct.get(key)) is None:
            sr = Registry._KeyItems(key)
        sr = sr.add(*items)
        new = {key: sr}

        return Registry._State(
            dct={**self.dct, **new} if not identity else self.dct,
            id_dct={**self.id_dct, **new} if identity else self.id_dct,
            version=self.version + 1,
        )

    #

    # Memoization for ``get`` / ``get_of`` with ``identity=None``. These are derived data, so they are excluded from
    # ``__init__``, ``repr`` and comparison. Otherwise equality would depend on which lookups happened to run, and
    # ``dc.replace`` would copy a stale cache into the new state.
    _get_cache: dict[ta.Any, ta.Sequence[RegistryItem]] = dc.field(
        default_factory=dict,
        init=False,
        repr=False,
        compare=False,
    )

    def get(
            self,
            key: ta.Any,
            *,
            identity: bool | None = None,
    ) -> ta.Sequence[RegistryItem]:
        """Return the items registered under ``key``.

        ``identity=True`` / ``False`` reads only ``id_dct`` / ``dct``. ``identity=None`` returns the ``id_dct`` items
        followed by the ``dct`` items and memoizes the combined result.
        """

        if identity is None:
            try:
                return self._get_cache[key]
            except KeyError:
                pass

            ret = self._get_cache[key] = (
                *self.get(key, identity=True),
                *self.get(key, identity=False),
            )
            return ret

        dct: ta.Any = self.dct if not identity else self.id_dct
        try:
            return dct[key].items
        except KeyError:
            return ()

    _get_of_cache: dict[ta.Any, dict[type, ta.Sequence[RegistryItem]]] = dc.field(
        default_factory=dict,
        init=False,
        repr=False,
        compare=False,
    )

    def get_of(
            self,
            key: ta.Any,
            item_ty: type[RegistryItem],
            *,
            identity: bool | None = None,
    ) -> ta.Sequence[RegistryItem]:
        """Like ``get``, but restricted to items whose exact type is ``item_ty``."""

        if identity is None:
            try:
                return self._get_of_cache[key][item_ty]
            except KeyError:
                pass

            ret = self._get_of_cache.setdefault(key, {})[item_ty] = (
                *self.get_of(key, item_ty, identity=True),
                *self.get_of(key, item_ty, identity=False),
            )
            return ret

        dct: ta.Any = self.dct if not identity else self.id_dct
        try:
            sr = dct[key]
        except KeyError:
            return ()
        return sr.item_lists_by_ty.get(item_ty, ())
|
156
|
+
|
61
157
|
def is_sealed(self) -> bool:
|
62
158
|
if self._sealed:
|
63
159
|
return True
|
@@ -92,12 +188,11 @@ class Registry(ta.Generic[RegistryItemT]):
|
|
92
188
|
if self._sealed:
|
93
189
|
raise RegistrySealedError(self)
|
94
190
|
|
95
|
-
|
96
|
-
|
97
|
-
|
98
|
-
|
99
|
-
|
100
|
-
self._version += 1
|
191
|
+
self._state = self._state.register(
|
192
|
+
key,
|
193
|
+
*items,
|
194
|
+
identity=identity,
|
195
|
+
)
|
101
196
|
|
102
197
|
return self
|
103
198
|
|
@@ -109,17 +204,7 @@ class Registry(ta.Generic[RegistryItemT]):
|
|
109
204
|
*,
|
110
205
|
identity: bool | None = None,
|
111
206
|
) -> ta.Sequence[RegistryItem]:
|
112
|
-
|
113
|
-
return (
|
114
|
-
*self.get(key, identity=True),
|
115
|
-
*self.get(key, identity=False),
|
116
|
-
)
|
117
|
-
|
118
|
-
dct: ta.Any = self._id_dct if identity else self._dct
|
119
|
-
try:
|
120
|
-
return dct[key].items
|
121
|
-
except KeyError:
|
122
|
-
return ()
|
207
|
+
return self._state.get(key, identity=identity)
|
123
208
|
|
124
209
|
def get_of(
|
125
210
|
self,
|
@@ -128,15 +213,4 @@ class Registry(ta.Generic[RegistryItemT]):
|
|
128
213
|
*,
|
129
214
|
identity: bool | None = None,
|
130
215
|
) -> ta.Sequence[RegistryItemU]:
|
131
|
-
|
132
|
-
return (
|
133
|
-
*self.get_of(key, item_ty, identity=True),
|
134
|
-
*self.get_of(key, item_ty, identity=False),
|
135
|
-
)
|
136
|
-
|
137
|
-
dct: ta.Any = self._id_dct if identity else self._dct
|
138
|
-
try:
|
139
|
-
sr = dct[key]
|
140
|
-
except KeyError:
|
141
|
-
return ()
|
142
|
-
return sr.item_lists_by_ty.get(item_ty, ())
|
216
|
+
return self._state.get_of(key, item_ty, identity=identity) # type: ignore[return-value]
|
@@ -1,3 +1,7 @@
|
|
1
|
+
"""
|
2
|
+
TODO:
|
3
|
+
- squash unions of literals into a single literal - the typing machinery doesn't do this itself
|
4
|
+
"""
|
1
5
|
import dataclasses as dc
|
2
6
|
import typing as ta
|
3
7
|
|
@@ -26,7 +30,7 @@ class LiteralMarshaler(Marshaler):
|
|
26
30
|
|
27
31
|
class LiteralMarshalerFactory(SimpleMarshalerFactory):
|
28
32
|
def guard(self, ctx: MarshalContext, rty: rfl.Type) -> bool:
    # Accept only literals whose values all share a single Python type; mixed-type literals are not accepted here.
    if not isinstance(rty, rfl.Literal):
        return False
    return len({type(a) for a in rty.args}) == 1
|
30
34
|
|
31
35
|
def fn(self, ctx: MarshalContext, rty: rfl.Type) -> Marshaler:
|
32
36
|
lty = check.isinstance(rty, rfl.Literal)
|
@@ -45,7 +49,7 @@ class LiteralUnmarshaler(Unmarshaler):
|
|
45
49
|
|
46
50
|
class LiteralUnmarshalerFactory(SimpleUnmarshalerFactory):
|
47
51
|
def guard(self, ctx: UnmarshalContext, rty: rfl.Type) -> bool:
    # Accept only literals whose values all share a single Python type; mixed-type literals are not accepted here.
    if not isinstance(rty, rfl.Literal):
        return False
    return len({type(a) for a in rty.args}) == 1
|
49
53
|
|
50
54
|
def fn(self, ctx: UnmarshalContext, rty: rfl.Type) -> Unmarshaler:
|
51
55
|
lty = check.isinstance(rty, rfl.Literal)
|