omlish 0.0.0.dev451__py3-none-any.whl → 0.0.0.dev453__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
omlish/funcs/guard.py ADDED
@@ -0,0 +1,214 @@
1
+ import abc
2
+ import functools
3
+ import operator
4
+ import typing as ta
5
+
6
+ from .. import check
7
+ from .. import collections as col
8
+ from .. import lang
9
+
10
+
11
+ T = ta.TypeVar('T')
12
+ T_co = ta.TypeVar('T_co', covariant=True)
13
+ U = ta.TypeVar('U')
14
+ P = ta.ParamSpec('P')
15
+
16
+
17
+ ##
18
+
19
+
20
+ class GuardFn(ta.Protocol[P, T_co]):
21
+ def __get__(self, instance, owner=None) -> ta.Self: ...
22
+
23
+ def __call__(self, *args: P.args, **kwargs: P.kwargs) -> ta.Callable[[], T_co] | None: ...
24
+
25
+
26
+ ##
27
+
28
+
29
+ @ta.final
30
+ class DumbGuardFn(ta.Generic[P, T]):
31
+ def __init__(self, fn: ta.Callable[P, T]) -> None:
32
+ self._fn = fn
33
+
34
+ def __get__(self, instance, owner=None):
35
+ return DumbGuardFn(self._fn.__get__(instance, owner)) # noqa
36
+
37
+ def __call__(self, *args: P.args, **kwargs: P.kwargs) -> ta.Callable[[], T]:
38
+ return functools.partial(self._fn, *args, **kwargs)
39
+
40
+
41
+ dumb = DumbGuardFn
42
+
43
+
44
+ ##
45
+
46
+
47
+ class AmbiguousGuardFnError(Exception):
48
+ pass
49
+
50
+
51
+ @ta.final
52
+ class MultiGuardFn(ta.Generic[P, T]):
53
+ def __init__(
54
+ self,
55
+ *children: GuardFn[P, T],
56
+ strict: bool = False,
57
+ ) -> None:
58
+ self._children, self._strict = children, strict
59
+
60
+ def __get__(self, instance, owner=None):
61
+ return MultiGuardFn(*map(operator.methodcaller('__get__', instance, owner), self._children), strict=self._strict) # noqa
62
+
63
+ def __call__(self, *args: P.args, **kwargs: P.kwargs) -> ta.Callable[[], T] | None:
64
+ matches = []
65
+ for c in self._children:
66
+ if (m := c(*args, **kwargs)) is not None:
67
+ if not self._strict:
68
+ return m
69
+ matches.append(m)
70
+ if not matches:
71
+ return None
72
+ elif len(matches) > 1:
73
+ raise AmbiguousGuardFnError
74
+ else:
75
+ return matches[0]
76
+
77
+
78
+ multi = MultiGuardFn
79
+
80
+
81
+ ##
82
+
83
+
84
class _BaseGuardFnMethod(lang.Abstract, ta.Generic[P, T]):
    """
    Shared machinery for guard-function methods: a descriptor declared in a class body, to which guard functions are
    attached with ``register``. On access through an instance, the registered functions are fetched by attribute name
    from the instance's class, combined into a MultiGuardFn (cached per class via ``self._cache``), and turned into the
    bound object by the subclass's ``_bind``.
    """

    def __init__(
            self,
            *,
            strict: bool = False,
            requires_override: bool = False,
            instance_cache: bool = False,
            prototype: ta.Callable[P, T] | None = None,
    ) -> None:
        super().__init__()

        self._strict = strict
        self._instance_cache = instance_cache
        self._prototype = prototype

        self._registry: col.AttrRegistry[ta.Callable, None] = col.AttrRegistry(
            requires_override=requires_override,
        )

        self._cache: col.AttrRegistryCache[ta.Callable, None, MultiGuardFn] = col.AttrRegistryCache(
            self._registry,
            self._prepare,
        )

    _owner: type | None = None
    _name: str | None = None

    def __set_name__(self, owner, name):
        # Only the first assignment is recorded, so aliasing the descriptor under another name keeps the original.
        if self._owner is None:
            self._owner = owner
        if self._name is None:
            self._name = name

    def register(self, fn: U) -> U:
        """Attach ``fn`` as a guard function of this method. Returns ``fn`` unchanged, so it works as a decorator."""
        check.callable(fn)
        self._registry.register(ta.cast(ta.Callable, fn), None)
        return fn

    def _prepare(self, instance_cls: type, collected: ta.Mapping[str, tuple[ta.Callable, None]]) -> MultiGuardFn:
        # Re-fetch each registered function by attribute name from the concrete class, so a subclass overriding one
        # of those names contributes its override instead.
        return MultiGuardFn(
            *[getattr(instance_cls, a) for a in collected],
            strict=self._strict,
        )

    @abc.abstractmethod
    def _bind(self, instance, owner):
        """Produce the object returned when this descriptor is accessed through ``instance``."""
        raise NotImplementedError

    def __get__(self, instance, owner=None):
        if instance is None:
            return self

        # Caching needs a name to store under. If ``__set_name__`` never ran, every such descriptor would share the
        # ``None`` key in the instance dict and hand out each other's bound objects, so caching is skipped instead.
        cache = self._instance_cache and self._name is not None

        if cache:
            try:
                return instance.__dict__[self._name]
            except KeyError:
                pass

        bound = self._bind(instance, owner)

        if cache:
            # Stored under the attribute's own name, this also shadows the (non-data) descriptor on later ordinary
            # attribute lookups through the instance.
            instance.__dict__[self._name] = bound

        return bound

    def _call(self, *args, **kwargs):
        # Supports calling through the class - ``Cls.meth(obj, *args)`` - where the first positional argument is the
        # instance to bind to.
        if not args:
            raise TypeError(f'{type(self).__name__} called without an instance argument')
        instance, *rest = args
        return self.__get__(instance)(*rest, **kwargs)
152
+
153
+ #
154
+
155
+
156
class GuardFnMethod(_BaseGuardFnMethod[P, T]):
    """
    Guard-function method whose bound form is the combined MultiGuardFn itself.

    Calling it returns the thunk of an accepting registered guard, or None when no guard accepts.
    """

    def __call__(self, *args: P.args, **kwargs: P.kwargs) -> ta.Callable[[], T] | None:
        # Reached when called through the class; the first positional argument is the instance.
        return self._call(*args, **kwargs)

    def _bind(self, instance, owner):
        combined = self._cache.get(type(instance))
        return combined.__get__(instance, owner)  # noqa
162
+
163
+
164
def method(
        *,
        strict: bool = False,
        requires_override: bool = False,
        instance_cache: bool = False,
) -> ta.Callable[[ta.Callable[P, T]], GuardFnMethod[P, T]]:  # noqa
    """
    Decorator factory declaring a GuardFnMethod.

    The decorated function is kept only as the method's prototype. Guard functions are attached afterwards with
    ``register`` on the returned descriptor.

    :param strict: when true, more than one accepting guard raises AmbiguousGuardFnError.
    :param requires_override: forwarded to the method's attribute registry.
    :param instance_cache: when true, the bound object is cached in each instance's ``__dict__``.
    """

    options = {
        'strict': strict,
        'requires_override': requires_override,
        'instance_cache': instance_cache,
    }

    def decorate(fn):
        return GuardFnMethod(prototype=fn, **options)

    return decorate
179
+
180
+
181
+ #
182
+
183
+
184
class DumbGuardFnMethod(_BaseGuardFnMethod[P, T]):
    """
    Guard-function method whose bound form is a plain function.

    The function runs the thunk of an accepting registered guard and returns its result. When no guard accepts, it
    falls back to calling the decorated prototype.
    """

    def _bind(self, instance, owner):
        gf = self._cache.get(type(instance)).__get__(instance, owner)  # noqa
        # ``dumb_method`` always supplies a prototype; fail clearly if constructed directly without one.
        x = check.not_none(self._prototype).__get__(instance, owner)  # type: ignore

        # Carry the prototype's name, docstring and signature so the bound method introspects like the declared one.
        @functools.wraps(x)
        def inner(*args, **kwargs):
            if (m := gf(*args, **kwargs)) is not None:
                return m()
            return x(*args, **kwargs)

        return inner

    def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T:
        # Reached when called through the class; the first positional argument is the instance.
        return self._call(*args, **kwargs)
198
+
199
+
200
def dumb_method(
        *,
        strict: bool = False,
        requires_override: bool = False,
        instance_cache: bool = False,
) -> ta.Callable[[ta.Callable[P, T]], DumbGuardFnMethod[P, T]]:  # noqa
    """
    Decorator factory declaring a DumbGuardFnMethod.

    The decorated function becomes the fallback that the bound method calls when none of the guard functions attached
    with ``register`` accepts the arguments.

    :param strict: when true, more than one accepting guard raises AmbiguousGuardFnError.
    :param requires_override: forwarded to the method's attribute registry.
    :param instance_cache: when true, the bound function is cached in each instance's ``__dict__``.
    """

    def decorate(fn):
        return DumbGuardFnMethod(
            prototype=fn,
            instance_cache=instance_cache,
            requires_override=requires_override,
            strict=strict,
        )

    return decorate
omlish/funcs/match.py CHANGED
@@ -5,6 +5,8 @@ TODO:
5
5
  - unify MatchFnClass with dispatch.method?
6
6
  - __call__ = mfs.method(); @__call__.register(lambda: ...) def _call_... ?
7
7
  - not really the same thing, dispatch is unordered this is necessarily ordered
8
+ - !! well.. unify interface at least?
9
+ - guard(*a, **k) -> bool + fn(*a, **k) -> T becomes dispatch(*a, **k) -> (Callable -> T) | None
8
10
  """
9
11
  import abc
10
12
  import dataclasses as dc
@@ -29,7 +29,7 @@ class Renderer:
29
29
 
30
30
  self._out = out
31
31
 
32
- @dispatch.method
32
+ @dispatch.method(instance_cache=True)
33
33
  def render(self, item: Item) -> None:
34
34
  raise TypeError(item)
35
35
 
omlish/lite/maybes.py CHANGED
@@ -1,6 +1,7 @@
1
1
  # ruff: noqa: UP007 UP045
2
2
  import abc
3
3
  import functools
4
+ import operator
4
5
  import typing as ta
5
6
 
6
7
  from .abstract import Abstract
@@ -208,3 +209,10 @@ class _EmptyMaybe(_Maybe[T]):
208
209
 
209
210
 
210
211
  Maybe._empty = _EmptyMaybe() # noqa
212
+
213
+
214
+ ##
215
+
216
+
217
+ setattr(Maybe, 'just', _JustMaybe) # noqa
218
+ setattr(Maybe, 'empty', functools.partial(operator.attrgetter('_empty'), Maybe))
@@ -113,6 +113,17 @@ with _lang.auto_proxy_init(globals()):
113
113
  OptionalUnmarshaler,
114
114
  )
115
115
 
116
+ from .composite.unions import ( # noqa
117
+ MatchUnionMarshaler,
118
+ MatchUnionUnmarshaler,
119
+
120
+ PRIMITIVE_UNION_TYPES,
121
+ PrimitiveUnionMarshaler,
122
+ PrimitiveUnionMarshalerFactory,
123
+ PrimitiveUnionUnmarshaler,
124
+ PrimitiveUnionUnmarshalerFactory,
125
+ )
126
+
116
127
  from .composite.wrapped import ( # noqa
117
128
  WrappedMarshaler,
118
129
  WrappedUnmarshaler,
@@ -197,6 +208,7 @@ with _lang.auto_proxy_init(globals()):
197
208
  )
198
209
 
199
210
  from .polymorphism.metadata import ( # noqa
211
+ AutoStripSuffix,
200
212
  FieldTypeTagging,
201
213
  Impl,
202
214
  Impls,
@@ -212,13 +224,8 @@ with _lang.auto_proxy_init(globals()):
212
224
  )
213
225
 
214
226
  from .polymorphism.unions import ( # noqa
215
- PRIMITIVE_UNION_TYPES,
216
227
  PolymorphismUnionMarshalerFactory,
217
228
  PolymorphismUnionUnmarshalerFactory,
218
- PrimitiveUnionMarshaler,
219
- PrimitiveUnionMarshalerFactory,
220
- PrimitiveUnionUnmarshaler,
221
- PrimitiveUnionUnmarshalerFactory,
222
229
  )
223
230
 
224
231
  from .polymorphism.unmarshal import ( # noqa
@@ -47,8 +47,9 @@ class MarshalContext(BaseContext, lang.Final):
47
47
 
48
48
  def make(self, o: ta.Any) -> 'Marshaler':
49
49
  rty = self._reflect(o)
50
+ fac = check.not_none(self.factory)
50
51
  try:
51
- return check.not_none(self.factory).make_marshaler(self, rty)
52
+ return fac.make_marshaler(self, rty)
52
53
  except mfs.MatchGuardError:
53
54
  raise UnhandledTypeError(rty) # noqa
54
55
 
@@ -62,8 +63,9 @@ class UnmarshalContext(BaseContext, lang.Final):
62
63
 
63
64
  def make(self, o: ta.Any) -> 'Unmarshaler':
64
65
  rty = self._reflect(o)
66
+ fac = check.not_none(self.factory)
65
67
  try:
66
- return check.not_none(self.factory).make_unmarshaler(self, rty)
68
+ return fac.make_unmarshaler(self, rty)
67
69
  except mfs.MatchGuardError:
68
70
  raise UnhandledTypeError(rty) # noqa
69
71
 
@@ -7,7 +7,6 @@ import dataclasses as dc
7
7
  import threading
8
8
  import typing as ta
9
9
 
10
- from ... import collections as col
11
10
  from ... import lang
12
11
 
13
12
 
@@ -22,18 +21,6 @@ RegistryItemT = ta.TypeVar('RegistryItemT', bound=RegistryItem)
22
21
  RegistryItemU = ta.TypeVar('RegistryItemU', bound=RegistryItem)
23
22
 
24
23
 
25
- @dc.dataclass(frozen=True)
26
- class _KeyRegistryItems(ta.Generic[RegistryItemT]):
27
- key: ta.Any
28
- items: list[RegistryItemT] = dc.field(default_factory=list)
29
- item_lists_by_ty: dict[type[RegistryItemT], list[RegistryItemT]] = dc.field(default_factory=dict)
30
-
31
- def add(self, *items: RegistryItemT) -> None:
32
- for i in items:
33
- self.items.append(i)
34
- self.item_lists_by_ty.setdefault(type(i), []).append(i)
35
-
36
-
37
24
  class RegistrySealedError(Exception):
38
25
  pass
39
26
 
@@ -50,14 +37,123 @@ class Registry(ta.Generic[RegistryItemT]):
50
37
  lock = threading.RLock()
51
38
  self._lock = lock
52
39
 
53
- self._dct: dict[ta.Any, _KeyRegistryItems[RegistryItemT]] = {}
54
- self._id_dct: ta.MutableMapping[ta.Any, _KeyRegistryItems[RegistryItemT]] = col.IdentityKeyDict()
40
+ self._state: Registry._State[RegistryItemT] = Registry._State(
41
+ dct={},
42
+ id_dct={},
43
+ version=0,
44
+ )
55
45
 
56
- self._version = 0
57
46
  self._sealed = False
58
47
 
59
48
  #
60
49
 
50
@dc.dataclass(frozen=True)
class _KeyItems(ta.Generic[RegistryItemU]):
    """
    Immutable record of everything registered under one key.

    It holds all items in registration order, plus the same items grouped by their exact type.
    """

    key: ta.Any
    items: ta.Sequence[RegistryItemU] = ()
    item_lists_by_ty: ta.Mapping[type[RegistryItemU], ta.Sequence[RegistryItemU]] = dc.field(default_factory=dict)

    def add(self, *items: RegistryItemU) -> 'Registry._KeyItems[RegistryItemU]':
        """
        Return a new record with ``items`` appended; ``self`` is never modified.

        Per-type lists touched by ``items`` are copied before being extended; untouched ones are shared with ``self``.
        """

        grown: dict[type[RegistryItemU], list[RegistryItemU]] = {}
        for item in items:
            ty = type(item)
            if ty not in grown:
                grown[ty] = list(self.item_lists_by_ty.get(ty, ()))
            grown[ty].append(item)

        merged = dict(self.item_lists_by_ty)
        merged.update(grown)

        return Registry._KeyItems(
            self.key,
            (*self.items, *items),
            merged,
        )
71
+
72
@dc.dataclass(frozen=True, kw_only=True)
class _State(ta.Generic[RegistryItemU]):
    """
    Immutable snapshot of a registry's contents.

    Each effective registration returns a new ``_State`` rather than mutating this one. The lookup caches below
    therefore only ever hold results computed from this snapshot.
    """

    # Records registered by key equality (``identity=False``).
    dct: ta.Mapping[ta.Any, 'Registry._KeyItems[RegistryItemU]']

    # Records registered by key identity (``identity=True``), keyed by ``id(key)``. This restores identity semantics
    # for identity registrations, which a plain equality-keyed dict loses (and which also rejected unhashable keys).
    # Each record holds a strong reference to its key, so its id cannot be recycled while the entry exists.
    id_dct: ta.Mapping[int, 'Registry._KeyItems[RegistryItemU]']

    # Incremented once per effective (non-empty) registration.
    version: int

    # Combined (identity + equality) lookup results for keys with no identity registration.
    _get_cache: dict[ta.Any, ta.Sequence[RegistryItem]] = dc.field(default_factory=dict)
    _get_of_cache: dict[ta.Any, dict[type, ta.Sequence[RegistryItem]]] = dc.field(default_factory=dict)

    #

    def _lookup(self, key: ta.Any, identity: bool) -> 'Registry._KeyItems[RegistryItemU] | None':
        """Return the record for ``key`` from the identity- or equality-keyed table, or None if absent."""
        if identity:
            sr = self.id_dct.get(id(key))
            # The record pins its key alive, so a live ``key`` with the same id is the same object; checked anyway.
            return sr if sr is not None and sr.key is key else None
        return self.dct.get(key)

    def register(
            self,
            key: ta.Any,
            *items: RegistryItemT,
            identity: bool = False,
    ) -> 'Registry._State[RegistryItemU]':
        """Return a new state with ``items`` appended under ``key``; returns ``self`` when ``items`` is empty."""
        if not items:
            return self

        if (sr := self._lookup(key, identity)) is None:
            sr = Registry._KeyItems(key)
        new = {(id(key) if identity else key): sr.add(*items)}

        return Registry._State(
            dct={**self.dct, **new} if not identity else self.dct,
            id_dct={**self.id_dct, **new} if identity else self.id_dct,
            version=self.version + 1,
        )

    #

    def get(
            self,
            key: ta.Any,
            *,
            identity: bool | None = None,
    ) -> ta.Sequence[RegistryItem]:
        """
        Return the items registered under ``key``.

        With ``identity`` None, identity-registered items come first, followed by equality-registered ones.
        """

        if identity is not None:
            sr = self._lookup(key, identity)
            return sr.items if sr is not None else ()

        if (id_sr := self._lookup(key, True)) is not None:
            # Identity matches belong to this exact object; caching them under an equality-based key would leak them
            # to other, merely equal, keys.
            return (*id_sr.items, *self.get(key, identity=False))

        try:
            return self._get_cache[key]
        except KeyError:
            pass

        ret = self._get_cache[key] = tuple(self.get(key, identity=False))
        return ret

    def get_of(
            self,
            key: ta.Any,
            item_ty: type[RegistryItem],
            *,
            identity: bool | None = None,
    ) -> ta.Sequence[RegistryItem]:
        """
        Return the items of exact type ``item_ty`` registered under ``key``.

        Results are ordered as in ``get``.
        """

        if identity is not None:
            sr = self._lookup(key, identity)
            return sr.item_lists_by_ty.get(item_ty, ()) if sr is not None else ()

        if (id_sr := self._lookup(key, True)) is not None:
            # Not cached, for the same reason as in ``get``.
            return (
                *id_sr.item_lists_by_ty.get(item_ty, ()),
                *self.get_of(key, item_ty, identity=False),
            )

        try:
            return self._get_of_cache[key][item_ty]
        except KeyError:
            pass

        ret = self._get_of_cache.setdefault(key, {})[item_ty] = tuple(self.get_of(key, item_ty, identity=False))
        return ret
156
+
61
157
  def is_sealed(self) -> bool:
62
158
  if self._sealed:
63
159
  return True
@@ -92,12 +188,11 @@ class Registry(ta.Generic[RegistryItemT]):
92
188
  if self._sealed:
93
189
  raise RegistrySealedError(self)
94
190
 
95
- dct: ta.Any = self._id_dct if identity else self._dct
96
- if (sr := dct.get(key)) is None:
97
- sr = dct[key] = _KeyRegistryItems(key)
98
- sr.add(*items)
99
-
100
- self._version += 1
191
+ self._state = self._state.register(
192
+ key,
193
+ *items,
194
+ identity=identity,
195
+ )
101
196
 
102
197
  return self
103
198
 
@@ -109,17 +204,7 @@ class Registry(ta.Generic[RegistryItemT]):
109
204
  *,
110
205
  identity: bool | None = None,
111
206
  ) -> ta.Sequence[RegistryItem]:
112
- if identity is None:
113
- return (
114
- *self.get(key, identity=True),
115
- *self.get(key, identity=False),
116
- )
117
-
118
- dct: ta.Any = self._id_dct if identity else self._dct
119
- try:
120
- return dct[key].items
121
- except KeyError:
122
- return ()
207
+ return self._state.get(key, identity=identity)
123
208
 
124
209
  def get_of(
125
210
  self,
@@ -128,15 +213,4 @@ class Registry(ta.Generic[RegistryItemT]):
128
213
  *,
129
214
  identity: bool | None = None,
130
215
  ) -> ta.Sequence[RegistryItemU]:
131
- if identity is None:
132
- return (
133
- *self.get_of(key, item_ty, identity=True),
134
- *self.get_of(key, item_ty, identity=False),
135
- )
136
-
137
- dct: ta.Any = self._id_dct if identity else self._dct
138
- try:
139
- sr = dct[key]
140
- except KeyError:
141
- return ()
142
- return sr.item_lists_by_ty.get(item_ty, ())
216
+ return self._state.get_of(key, item_ty, identity=identity) # type: ignore[return-value]
@@ -1,3 +1,7 @@
1
+ """
2
+ TODO:
3
+ - squash literal unions - typing machinery doesn't
4
+ """
1
5
  import dataclasses as dc
2
6
  import typing as ta
3
7
 
@@ -26,7 +30,7 @@ class LiteralMarshaler(Marshaler):
26
30
 
27
31
  class LiteralMarshalerFactory(SimpleMarshalerFactory):
28
32
  def guard(self, ctx: MarshalContext, rty: rfl.Type) -> bool:
29
- return isinstance(rty, rfl.Literal)
33
+ return isinstance(rty, rfl.Literal) and len(set(map(type, rty.args))) == 1
30
34
 
31
35
  def fn(self, ctx: MarshalContext, rty: rfl.Type) -> Marshaler:
32
36
  lty = check.isinstance(rty, rfl.Literal)
@@ -45,7 +49,7 @@ class LiteralUnmarshaler(Unmarshaler):
45
49
 
46
50
  class LiteralUnmarshalerFactory(SimpleUnmarshalerFactory):
47
51
  def guard(self, ctx: UnmarshalContext, rty: rfl.Type) -> bool:
48
- return isinstance(rty, rfl.Literal)
52
+ return isinstance(rty, rfl.Literal) and len(set(map(type, rty.args))) == 1
49
53
 
50
54
  def fn(self, ctx: UnmarshalContext, rty: rfl.Type) -> Unmarshaler:
51
55
  lty = check.isinstance(rty, rfl.Literal)