ovld 0.4.5__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ovld/__init__.py +23 -1
- ovld/codegen.py +303 -0
- ovld/core.py +62 -349
- ovld/dependent.py +24 -72
- ovld/medley.py +408 -0
- ovld/mro.py +6 -3
- ovld/py.typed +0 -0
- ovld/recode.py +99 -165
- ovld/signatures.py +275 -0
- ovld/typemap.py +40 -38
- ovld/types.py +47 -44
- ovld/utils.py +55 -18
- ovld/version.py +1 -1
- {ovld-0.4.5.dist-info → ovld-0.5.0.dist-info}/METADATA +62 -16
- ovld-0.5.0.dist-info/RECORD +18 -0
- {ovld-0.4.5.dist-info → ovld-0.5.0.dist-info}/WHEEL +1 -1
- ovld-0.4.5.dist-info/RECORD +0 -14
- {ovld-0.4.5.dist-info → ovld-0.5.0.dist-info}/licenses/LICENSE +0 -0
ovld/dependent.py
CHANGED
@@ -3,7 +3,6 @@ import re
|
|
3
3
|
from collections.abc import Callable as _Callable
|
4
4
|
from collections.abc import Mapping, Sequence
|
5
5
|
from functools import partial
|
6
|
-
from itertools import count
|
7
6
|
from typing import (
|
8
7
|
TYPE_CHECKING,
|
9
8
|
Any,
|
@@ -11,57 +10,22 @@ from typing import (
|
|
11
10
|
TypeVar,
|
12
11
|
)
|
13
12
|
|
13
|
+
from .codegen import Code
|
14
14
|
from .types import (
|
15
15
|
Intersection,
|
16
16
|
Order,
|
17
17
|
clsstring,
|
18
|
+
get_args,
|
18
19
|
normalize_type,
|
19
20
|
subclasscheck,
|
20
21
|
typeorder,
|
21
22
|
)
|
22
23
|
|
23
|
-
_current = count()
|
24
|
-
|
25
|
-
|
26
|
-
def generate_checking_code(typ):
|
27
|
-
if hasattr(typ, "codegen"):
|
28
|
-
return typ.codegen()
|
29
|
-
else:
|
30
|
-
return CodeGen("isinstance({arg}, {this})", this=typ)
|
31
|
-
|
32
|
-
|
33
|
-
class CodeGen:
|
34
|
-
def __init__(self, template, substitutions={}, **substitutions_kw):
|
35
|
-
self.template = template
|
36
|
-
self.substitutions = {**substitutions, **substitutions_kw}
|
37
|
-
|
38
|
-
def mangle(self):
|
39
|
-
renamings = {
|
40
|
-
k: f"{{{k}__{next(_current)}}}" for k in self.substitutions
|
41
|
-
}
|
42
|
-
renamings["arg"] = "{arg}"
|
43
|
-
new_subs = {
|
44
|
-
newk[1:-1]: self.substitutions[k]
|
45
|
-
for k, newk in renamings.items()
|
46
|
-
if k in self.substitutions
|
47
|
-
}
|
48
|
-
return CodeGen(self.template.format(**renamings), new_subs)
|
49
|
-
|
50
|
-
|
51
|
-
def combine(master_template, args):
|
52
|
-
fmts = []
|
53
|
-
subs = {}
|
54
|
-
for cg in args:
|
55
|
-
mangled = cg.mangle()
|
56
|
-
fmts.append(mangled.template)
|
57
|
-
subs.update(mangled.substitutions)
|
58
|
-
return CodeGen(master_template.format(*fmts), subs)
|
59
|
-
|
60
24
|
|
61
25
|
def is_dependent(t):
|
62
26
|
if isinstance(t, DependentType):
|
63
27
|
return True
|
64
|
-
elif any(is_dependent(subt) for subt in
|
28
|
+
elif any(is_dependent(subt) for subt in get_args(t)):
|
65
29
|
return True
|
66
30
|
return False
|
67
31
|
|
@@ -89,7 +53,7 @@ class DependentType(type):
|
|
89
53
|
raise NotImplementedError()
|
90
54
|
|
91
55
|
def codegen(self):
|
92
|
-
return
|
56
|
+
return Code("$this.check($arg)", this=self)
|
93
57
|
|
94
58
|
def __type_order__(self, other):
|
95
59
|
if isinstance(other, DependentType):
|
@@ -105,9 +69,7 @@ class DependentType(type):
|
|
105
69
|
return Order.NONE
|
106
70
|
else: # pragma: no cover
|
107
71
|
return order
|
108
|
-
elif subclasscheck(other, self.bound) or subclasscheck(
|
109
|
-
self.bound, other
|
110
|
-
):
|
72
|
+
elif subclasscheck(other, self.bound) or subclasscheck(self.bound, other):
|
111
73
|
return Order.LESS
|
112
74
|
else:
|
113
75
|
return Order.NONE
|
@@ -137,9 +99,7 @@ class DependentType(type):
|
|
137
99
|
|
138
100
|
class ParametrizedDependentType(DependentType):
|
139
101
|
def __init__(self, *parameters, bound=None):
|
140
|
-
super().__init__(
|
141
|
-
self.default_bound(*parameters) if bound is None else bound
|
142
|
-
)
|
102
|
+
super().__init__(self.default_bound(*parameters) if bound is None else bound)
|
143
103
|
self.__args__ = self.parameters = parameters
|
144
104
|
self.__origin__ = None
|
145
105
|
self.__post_init__()
|
@@ -191,12 +151,10 @@ class FuncDependentType(ParametrizedDependentType):
|
|
191
151
|
if len(self.parameters) != len(other.parameters):
|
192
152
|
return False
|
193
153
|
p1g = sum(
|
194
|
-
p1 is Any and p2 is not Any
|
195
|
-
for p1, p2 in zip(self.parameters, other.parameters)
|
154
|
+
p1 is Any and p2 is not Any for p1, p2 in zip(self.parameters, other.parameters)
|
196
155
|
)
|
197
156
|
p2g = sum(
|
198
|
-
p2 is Any and p1 is not Any
|
199
|
-
for p1, p2 in zip(self.parameters, other.parameters)
|
157
|
+
p2 is Any and p1 is not Any for p1, p2 in zip(self.parameters, other.parameters)
|
200
158
|
)
|
201
159
|
return p2g and not p1g
|
202
160
|
|
@@ -255,16 +213,16 @@ class Equals(ParametrizedDependentType):
|
|
255
213
|
|
256
214
|
@classmethod
|
257
215
|
def keygen(cls):
|
258
|
-
return "
|
216
|
+
return Code("$arg")
|
259
217
|
|
260
218
|
def get_keys(self):
|
261
219
|
return [self.parameter]
|
262
220
|
|
263
221
|
def codegen(self):
|
264
222
|
if len(self.parameters) == 1:
|
265
|
-
return
|
223
|
+
return Code("($arg == $p)", p=self.parameter)
|
266
224
|
else:
|
267
|
-
return
|
225
|
+
return Code("($arg in $ps)", ps=self.parameters)
|
268
226
|
|
269
227
|
|
270
228
|
class ProductType(ParametrizedDependentType):
|
@@ -281,19 +239,18 @@ class ProductType(ParametrizedDependentType):
|
|
281
239
|
)
|
282
240
|
|
283
241
|
def codegen(self):
|
284
|
-
checks = ["len(
|
242
|
+
checks = ["len($arg) == $n"]
|
285
243
|
params = {"n": len(self.parameters)}
|
286
244
|
for i, p in enumerate(self.parameters):
|
287
|
-
checks.append(f"isinstance(
|
245
|
+
checks.append(f"isinstance($arg[{i}], $p{i})")
|
288
246
|
params[f"p{i}"] = p
|
289
|
-
return
|
247
|
+
return Code(" and ".join(checks), params)
|
290
248
|
|
291
249
|
def __type_order__(self, other):
|
292
250
|
if isinstance(other, ProductType):
|
293
251
|
if len(other.parameters) == len(self.parameters):
|
294
252
|
return Order.merge(
|
295
|
-
typeorder(a, b)
|
296
|
-
for a, b in zip(self.parameters, other.parameters)
|
253
|
+
typeorder(a, b) for a, b in zip(self.parameters, other.parameters)
|
297
254
|
)
|
298
255
|
else:
|
299
256
|
return Order.NONE
|
@@ -337,8 +294,14 @@ def Callable(fn: _Callable, argt, rett):
|
|
337
294
|
|
338
295
|
|
339
296
|
@dependent_check
|
340
|
-
|
341
|
-
|
297
|
+
class HasKey:
|
298
|
+
def check(self, value: Mapping):
|
299
|
+
return all(k in value for k in self.parameters)
|
300
|
+
|
301
|
+
def codegen(self):
|
302
|
+
return Code(
|
303
|
+
"($[ and ]checks)", checks=[Code("($k in $arg)", k=k) for k in self.parameters]
|
304
|
+
)
|
342
305
|
|
343
306
|
|
344
307
|
@dependent_check
|
@@ -360,7 +323,7 @@ class Regexp:
|
|
360
323
|
return bool(self.rx.search(value))
|
361
324
|
|
362
325
|
def codegen(self):
|
363
|
-
return
|
326
|
+
return Code("bool($rx.search($arg))", rx=self.rx)
|
364
327
|
|
365
328
|
|
366
329
|
class Dependent:
|
@@ -377,14 +340,3 @@ if TYPE_CHECKING: # pragma: no cover
|
|
377
340
|
T = TypeVar("T")
|
378
341
|
A = TypeVar("A")
|
379
342
|
Dependent: TypeAlias = Annotated[T, A]
|
380
|
-
|
381
|
-
|
382
|
-
__all__ = [
|
383
|
-
"Dependent",
|
384
|
-
"DependentType",
|
385
|
-
"Equals",
|
386
|
-
"HasKey",
|
387
|
-
"StartsWith",
|
388
|
-
"EndsWith",
|
389
|
-
"dependent_check",
|
390
|
-
]
|
ovld/medley.py
ADDED
@@ -0,0 +1,408 @@
|
|
1
|
+
import functools
|
2
|
+
import inspect
|
3
|
+
from copy import copy
|
4
|
+
from dataclasses import MISSING, dataclass, fields, make_dataclass, replace
|
5
|
+
from typing import Annotated, TypeVar, get_origin
|
6
|
+
|
7
|
+
from .core import Ovld, to_ovld
|
8
|
+
from .types import eval_annotation
|
9
|
+
from .utils import Named
|
10
|
+
|
11
|
+
# Sentinels shared by this module:
#   ABSENT  - returned by Combiner.get() when a combiner has nothing to contribute.
#   CODEGEN - annotation metadata marking a dataclass field as a codegen parameter.
ABSENT, CODEGEN = Named("ABSENT"), Named("CODEGEN")
|
13
|
+
|
14
|
+
|
15
|
+
class Combiner:
    """Strategy deciding how successive definitions of one Medley attribute merge.

    ``field`` is the attribute name the combiner is responsible for. Subclasses
    implement ``get`` (produce the final attribute value for a class),
    ``juxtapose`` (absorb one more definition) and ``include_sametype``
    (absorb another combiner of the same class, e.g. from a base class).
    """

    def __init__(self, field=None):
        self.field = field

    def __set_name__(self, obj, field):
        # Record the attribute name this combiner is bound to.
        self.field = field

    def get(self, cls):  # pragma: no cover
        raise NotImplementedError()

    def copy(self):
        """Return a fresh combiner of the same class bound to the same field."""
        return type(self)(self.field)

    def include(self, other):
        """Merge ``other`` into this combiner; both must be of the exact same class.

        Raises:
            TypeError: if ``other`` is of a different combiner class.
        """
        if type(self) is type(other):
            self.include_sametype(other)
            return
        raise TypeError("Cannot merge different combiner classes.")

    def include_sametype(self, other):  # pragma: no cover
        pass

    def juxtapose(self, impl):  # pragma: no cover
        raise NotImplementedError()
|
38
|
+
|
39
|
+
|
40
|
+
class KeepLast(Combiner):
    """Combiner whose value is simply the most recently supplied definition.

    Starts out as ``ABSENT``; each ``juxtapose`` or ``include_sametype``
    overwrites the stored value.
    """

    def __init__(self, field=None):
        super().__init__(field)
        # Nothing defined yet.
        self.impl = ABSENT

    def get(self, cls):
        """Return the last stored definition (``ABSENT`` if none)."""
        return self.impl

    def include_sametype(self, other):
        # The other combiner's value replaces ours wholesale.
        self.impl = other.impl

    def juxtapose(self, impl):
        # A new definition replaces the previous one.
        self.impl = impl
|
53
|
+
|
54
|
+
|
55
|
+
class ImplList(Combiner):
    """Base for combiners that accumulate an ordered list of implementations.

    Subclasses provide ``wrap()``, which returns a single function driving all
    of ``self.impls``. ``get`` returns that function decorated with
    ``functools.wraps`` of the first implementation, or ``ABSENT`` when no
    implementation has been collected.
    """

    def __init__(self, field=None, impls=None):
        super().__init__(field)
        self.impls = impls or []

    def copy(self):
        """Return an independent combiner with the same field and implementations.

        The list is copied because ``juxtapose`` and ``include_sametype``
        mutate ``self.impls`` in place. A shared list would let
        implementations added to a subclass's combiner leak back into the
        parent's combiner and into the functions it already returned.
        """
        return type(self)(self.field, list(self.impls))

    def get(self, cls):
        if not self.impls:
            return ABSENT
        rval = self.wrap()
        return functools.wraps(self.impls[0])(rval)

    def wrap(self):  # pragma: no cover
        raise NotImplementedError()

    def include_sametype(self, other):
        # Append the other combiner's implementations after ours.
        self.impls += other.impls

    def juxtapose(self, impl):
        self.impls.append(impl)
|
77
|
+
|
78
|
+
|
79
|
+
class RunAll(ImplList):
    """Combine implementations by calling each of them in order; results are discarded."""

    def wrap(_self):
        # ``_self`` is the combiner; ``self`` inside the closure is the Medley instance.
        def run_all(self, *args, **kwargs):
            # Read the list at call time so later additions are honoured.
            for step in _self.impls:
                step(self, *args, **kwargs)

        return run_all
|
86
|
+
|
87
|
+
|
88
|
+
class ReduceAll(ImplList):
    """Thread a value through every implementation in order.

    The first implementation receives the original ``x``; each later one
    receives the previous result in its place. The remaining arguments are
    passed unchanged to all of them, and the last result is returned.
    """

    def wrap(_self):
        def reduce_all(self, x, *args, **kwargs):
            acc = _self.impls[0](self, x, *args, **kwargs)
            for step in _self.impls[1:]:
                acc = step(self, acc, *args, **kwargs)
            return acc

        return reduce_all
|
97
|
+
|
98
|
+
|
99
|
+
class ChainAll(ImplList):
    """Chain implementations: each one is called on the object returned by the previous one.

    The first implementation is called on the original instance. The value
    returned by the last implementation is the overall result.
    """

    def wrap(_self):
        def chain_all(self, *args, **kwargs):
            current = _self.impls[0](self, *args, **kwargs)
            for step in _self.impls[1:]:
                current = step(current, *args, **kwargs)
            return current

        return chain_all
|
108
|
+
|
109
|
+
|
110
|
+
class BuildOvld(Combiner):
    """Combiner that gathers functions and ovlds into a single ``Ovld``.

    Definitions passed to ``juxtapose`` are queued in ``self.pending`` and
    only applied when ``get`` is called. At that point the ovld is compiled,
    and its ``dispatch`` is returned (or ``ABSENT`` if it has no definitions).
    """

    def __init__(self, field=None, ovld=None):
        super().__init__(field)
        # Explicit None check so a caller-supplied Ovld (e.g. the copy made in
        # ``copy()``) is never replaced because of its truthiness.
        self.ovld = ovld if ovld is not None else Ovld(linkback=True)
        self.pending = []
        if field is not None:
            self.__set_name__(None, field)

    def __set_name__(self, obj, field):
        # Record the field like every Combiner does, so copy() propagates it,
        # and keep the underlying ovld's name in sync.
        super().__set_name__(obj, field)
        self.ovld.rename(field)

    def get(self, cls):
        self.ovld.specialization_self = cls
        # Apply the registrations queued by juxtapose().
        for f, arg in self.pending:
            f(arg)
        self.pending.clear()
        if not self.ovld.defns:
            return ABSENT
        self.ovld.compile()
        return self.ovld.dispatch

    def copy(self):
        return type(self)(self.field, self.ovld.copy(linkback=True))

    def include_sametype(self, other):
        self.ovld.add_mixins(other.ovld)

    def juxtapose(self, impl):
        """Queue ``impl`` (an ovld or a plain function) for inclusion.

        Raises:
            TypeError: if ``impl`` is neither a function nor an ovld.
        """
        if ov := to_ovld(impl, force=False):
            self.pending.append((self.ovld.add_mixins, ov))
        elif inspect.isfunction(impl):
            self.pending.append((self.ovld.register, impl))
        else:  # pragma: no cover
            raise TypeError("Expected a function or ovld.")
|
144
|
+
|
145
|
+
|
146
|
+
class medley_cls_dict(dict):
    """Class-body namespace used for Medley classes.

    Most assignments are not stored directly. Each value is handed to a
    per-attribute Combiner, either inherited (copied or merged) from the
    bases' ``_ovld_combiners`` or created on first assignment, so repeated
    definitions are combined instead of overwritten. The combiner table
    itself is stored in the namespace as ``_ovld_combiners``.
    """

    def __init__(self, bases):
        super().__init__()
        self._combiners = {}
        self.set_direct("_ovld_combiners", self._combiners)
        self._basic = set()
        for base in bases:
            inherited = getattr(base, "_ovld_combiners", {})
            for attr, combiner in inherited.items():
                existing = self._combiners.get(attr)
                if existing is None:
                    # First base providing this attribute: take a private copy.
                    self._combiners[attr] = combiner.copy()
                else:
                    existing.include(combiner)

    def set_direct(self, attr, value):
        """Store ``value`` in the namespace, bypassing the combiner logic."""
        super().__setitem__(attr, value)

    def __setitem__(self, attr, value):
        # Annotations must reach the real namespace so dataclass() can see them.
        if attr == "__annotations__":
            self.set_direct(attr, value)
            return

        if attr == "__init__":
            raise Exception("Do not define __init__ in a Medley, use __post_init__.")

        # An explicit Combiner declares how this attribute is to be combined.
        if isinstance(value, Combiner):
            value.__set_name__(None, attr)
            self._combiners[attr] = value
            return

        try:
            combiner = self._combiners[attr]
        except KeyError:
            # Functions/ovlds are merged by dispatch; anything else keeps the last value.
            if inspect.isfunction(value) or isinstance(value, Ovld):
                factory = BuildOvld
            else:
                factory = KeepLast
            combiner = self._combiners[attr] = factory(attr)

        combiner.juxtapose(value)

    def __missing__(self, attr):
        # Lets the class body read back combined values for names it (or a base) defined.
        combiner = self._combiners.get(attr)
        if combiner is not None:
            value = combiner.get(None)
            if value is not ABSENT:
                return value
        raise KeyError(attr)
|
190
|
+
|
191
|
+
|
192
|
+
def codegen_key(*instances):
    """Collect the codegen-field values of ``instances`` into one dict.

    For each instance, the attributes named in its class's
    ``_ovld_codegen_fields`` are read. When several instances share a name,
    the later instance's value wins.
    """
    merged = {}
    for inst in instances:
        merged.update(
            (name, getattr(inst, name)) for name in type(inst)._ovld_codegen_fields
        )
    return merged
|
198
|
+
|
199
|
+
|
200
|
+
def specialize(cls, key):
    """Create a subclass of ``cls`` specialized on the codegen values in ``key``.

    Each ``key`` entry is set as a class attribute of the new subclass, and
    ``cls`` is recorded as its ``_ovld_specialization_parent``.
    """
    new_t = MedleyMC(cls.__name__, (cls,), medley_cls_dict((cls,)))
    new_t._ovld_specialization_parent = cls
    for field_name, field_value in key.items():
        setattr(new_t, field_name, field_value)
    # NOTE(review): this updates ``_ovld_codegen_fields`` on the parent ``cls``,
    # not on ``new_t``. Confirm this is intended.
    cls._ovld_codegen_fields = list(key)
    return new_t
|
208
|
+
|
209
|
+
|
210
|
+
class MedleyMC(type):
    """Metaclass for Medley classes.

    Class bodies run in a ``medley_cls_dict``, so attribute definitions are
    accumulated by Combiners. Every resulting class is turned into a
    dataclass. Instantiating a class with non-empty codegen fields returns an
    instance of a cached subclass specialized on those field values.
    """

    def __subclasscheck__(cls, subclass):
        # A melded class lists its components in _ovld_medleys; a class that
        # subclasses every component counts as a subclass of the meld.
        components = getattr(cls, "_ovld_medleys", None)
        if not components:
            return super().__subclasscheck__(subclass)
        return all(issubclass(subclass, component) for component in components)

    @classmethod
    def __prepare__(mcls, name, bases):
        return medley_cls_dict(bases)

    def __new__(mcls, name, bases, namespace):
        new_cls = super().__new__(mcls, name, bases, namespace)
        # Materialize each combiner's value as a real class attribute.
        for attr, combiner in new_cls._ovld_combiners.items():
            built = combiner.get(new_cls)
            if built is not ABSENT:
                setattr(new_cls, attr, built)
        dc = dataclass(new_cls)
        dc._ovld_specialization_parent = None
        dc._ovld_specializations = {}
        # Fields whose annotation is Annotated[..., CODEGEN].
        codegen_fields = []
        for dc_field in fields(dc):
            annotation = eval_annotation(dc_field.type, dc, {}, catch=True)
            if (
                annotation
                and get_origin(annotation) is Annotated
                and CODEGEN in annotation.__metadata__
            ):
                codegen_fields.append(dc_field.name)
        dc._ovld_codegen_fields = codegen_fields
        return dc

    def extend(cls, *others):
        """Extend ``cls`` in place with the definitions of ``others`` and return it.

        The extension is also propagated to existing subclasses that do not
        already derive from the added classes.
        """
        if not others:
            return cls
        melded = meld_classes((cls, *others), require_defaults=True)
        for other in others:
            for key, value in vars(other).items():
                if key in ("__module__", "__firstlineno__"):
                    continue
                combiner = cls._ovld_combiners.get(key)
                if combiner:
                    combiner.juxtapose(value)
                    setattr(cls, key, combiner.get(cls))
                elif not key.startswith(("_ovld_", "__")):
                    setattr(cls, key, value)
        cls.__init__ = melded.__init__
        for subcls in cls.__subclasses__():
            subcls.extend(*(o for o in others if not issubclass(subcls, o)))
        return cls

    def __add__(cls, other):
        return meld_classes((cls, other))

    def __iadd__(cls, other):
        return cls.extend(other)

    def __sub__(cls, other):
        return unmeld_classes(cls, other)

    def __call__(cls, *args, **kwargs):
        made = super().__call__(*args, **kwargs)
        if not cls._ovld_codegen_fields:
            return made
        keyd = codegen_key(made)
        if not keyd:
            return made
        # Specializations are cached on the unspecialized ancestor.
        base = cls._ovld_specialization_parent or cls
        key = tuple(sorted(keyd.items()))
        if key in base._ovld_specializations:
            new_t = base._ovld_specializations[key]
        else:
            new_t = base._ovld_specializations[key] = specialize(base, keyd)
        # Re-type the freshly built instance without re-running __init__.
        obj = object.__new__(new_t)
        obj.__dict__.update(made.__dict__)
        return obj
|
282
|
+
|
283
|
+
|
284
|
+
def use_combiner(combiner):
    """Decorator factory that wraps a function in a ``combiner`` instance.

    The combiner is constructed with the function's name and seeded with the
    function via ``juxtapose``. The decorator returns the combiner, not the
    function.
    """

    def deco(fn):
        instance = combiner(fn.__name__)
        instance.juxtapose(fn)
        return instance

    return deco
|
291
|
+
|
292
|
+
|
293
|
+
class Medley(metaclass=MedleyMC):
    # NOTE: deliberately no class docstring. Every name bound in this body
    # (``__doc__`` included) goes through medley_cls_dict.__setitem__, so a
    # docstring would register an extra combiner.
    #
    # The combiners below must be declared before the matching ``def``s: the
    # namespace then feeds each function to its declared combiner instead of
    # creating a BuildOvld for it.
    __post_init__ = RunAll()
    __add__ = KeepLast()
    __sub__ = KeepLast()

    def __add__(self, other):
        """Combine two medley instances.

        If ``self`` is an instance of ``type(other)`` and its type has no
        codegen fields, return ``dataclasses.replace(self, **vars(other))``.
        Otherwise meld both objects into an instance of a combined class.
        """
        overlay = isinstance(self, type(other)) and not type(self)._ovld_codegen_fields
        if not overlay:
            return meld([self, other])
        return replace(self, **vars(other))

    def __sub__(self, other):
        """Rebuild ``self`` with the medley class ``other`` removed (see ``unmeld``)."""
        return unmeld(self, other)
|
306
|
+
|
307
|
+
|
308
|
+
def unmeld_classes(main: type, exclude: type):
    """Return the meld of ``main``'s direct bases, leaving out ``exclude``."""
    remaining = tuple(filter(lambda base: base is not exclude, main.__bases__))
    return meld_classes(remaining)
|
311
|
+
|
312
|
+
|
313
|
+
# Cache for meld_classes: (tuple of component classes, require_defaults) -> melded class.
_meld_classes_cache = dict()
|
314
|
+
|
315
|
+
|
316
|
+
def meld_classes(classes, require_defaults=False):
    """Return a Medley class combining ``classes`` (results are cached).

    Melded inputs are expanded into their ``_ovld_medleys`` components, and
    components that are strict ancestors of another input are dropped. If a
    single component remains, it is returned as-is. Otherwise a new
    kw-only dataclass is built with all components as bases.

    Args:
        classes: The classes to meld.
        require_defaults: When true, ``classes[0]`` is kept whole, and every
            field contributed by another class must have a plain ``default``.
            Existing instances of ``classes[0]`` have to see those defaults
            as class attributes.

    Raises:
        TypeError: if ``require_defaults`` is set and a contributed field lacks
            a plain default.
    """
    medleys = {}
    for i, cls in enumerate(classes):
        if require_defaults and i == 0:
            # The class being extended is kept whole, not split into components.
            medleys[cls] = True
        else:
            medleys.update({x: True for x in getattr(cls, "_ovld_medleys", [cls])})
    for cls in classes:
        if not hasattr(cls, "_ovld_medleys"):
            # Strict ancestors of cls are already covered by cls itself.
            for base in cls.mro():
                if base is not cls and base in medleys:
                    del medleys[base]
    medleys = tuple(medleys)
    if len(medleys) == 1:
        return medleys[0]

    cache_key = (medleys, require_defaults)
    if cache_key in _meld_classes_cache:
        return _meld_classes_cache[cache_key]

    def remap_field(dc_field, require_default):
        if require_default:
            if dc_field.default is MISSING:
                # NOTE: we do not accept default_factory, because we need the default value to be set
                # in the class so that existing instances of classes[0] can see it.
                raise TypeError(
                    f"Dataclass field '{dc_field.name}' must have a default value (not a default_factory) in order to be melded in."
                )
        dc_field = copy(dc_field)
        dc_field.kw_only = True
        return dc_field

    cg_fields = set()
    dc_fields = []
    seen_names = set()

    for base in medleys:
        rqdef = require_defaults and base is not medleys[0]
        cg_fields.update(base._ovld_codegen_fields)
        for f in base.__dataclass_fields__.values():
            if f.name in seen_names:
                # The same field can be inherited by several components (e.g.
                # from a shared base). make_dataclass rejects duplicate names,
                # so keep the first occurrence, which matches MRO precedence.
                continue
            seen_names.add(f.name)
            dc_fields.append((f.name, f.type, remap_field(f, rqdef)))

    merged = medley_cls_dict(medleys)
    merged.set_direct("_ovld_codegen_fields", tuple(cg_fields))
    merged.set_direct("_ovld_medleys", tuple(medleys))

    result = make_dataclass(
        cls_name="+".join(sorted(c.__name__ for c in medleys)),
        bases=tuple(medleys),
        fields=dc_fields,
        kw_only=True,
        namespace=merged,
    )

    _meld_classes_cache[cache_key] = result
    return result
|
372
|
+
|
373
|
+
|
374
|
+
@functools.cache
def meld_classes_with_key(classes, key):
    """Meld ``classes`` and, if ``key`` is non-empty, specialize the result on it.

    ``key`` is a tuple of (field name, value) pairs, kept as a tuple so the
    call is hashable for ``functools.cache``.
    """
    key = dict(key)
    melded = meld_classes(classes)
    return specialize(melded, key) if key else melded
|
382
|
+
|
383
|
+
|
384
|
+
def meld(objects):
    """Combine medley instances into one instance of their melded class.

    The class is chosen from the objects' types and their codegen values.
    Attributes are then copied over with ``setattr`` in order, so later
    objects override earlier ones.
    """
    key = codegen_key(*objects)
    cls = meld_classes_with_key(tuple(map(type, objects)), tuple(key.items()))
    obj = object.__new__(cls)
    for source in objects:
        for name, value in vars(source).items():
            setattr(obj, name, value)
    return obj
|
393
|
+
|
394
|
+
|
395
|
+
def unmeld(obj: object, exclude: type):
    """Rebuild ``obj`` as an instance of its class with ``exclude`` removed.

    Fields declared by ``exclude`` are dropped; the remaining fields are read
    from ``obj`` and passed to the new class's constructor.

    Raises:
        TypeError: if ``obj``'s class has codegen fields.
    """
    if type(obj)._ovld_codegen_fields:  # pragma: no cover
        raise TypeError("Cannot unmeld an object with codegen fields")
    cls = unmeld_classes(type(obj), exclude)
    excluded = exclude.__dataclass_fields__
    values = {
        f.name: getattr(obj, f.name)
        for f in cls.__dataclass_fields__.values()
        if f.name not in excluded
    }
    return cls(**values)
|
405
|
+
|
406
|
+
|
407
|
+
T = TypeVar("T")

# Annotate a Medley dataclass field as ``CodegenParameter[SomeType]`` to make it
# a codegen field. MedleyMC detects the CODEGEN marker in the Annotated metadata.
CodegenParameter = Annotated[T, CODEGEN]
|
ovld/mro.py
CHANGED
@@ -3,6 +3,8 @@ from enum import Enum
|
|
3
3
|
from graphlib import TopologicalSorter
|
4
4
|
from typing import get_args, get_origin
|
5
5
|
|
6
|
+
from .utils import UnionTypes
|
7
|
+
|
6
8
|
|
7
9
|
class Order(Enum):
|
8
10
|
LESS = -1
|
@@ -121,6 +123,9 @@ def subclasscheck(t1, t2):
|
|
121
123
|
):
|
122
124
|
return result
|
123
125
|
|
126
|
+
if t2 in UnionTypes:
|
127
|
+
return isinstance(t1, t2)
|
128
|
+
|
124
129
|
o1 = get_origin(t1)
|
125
130
|
o2 = get_origin(t2)
|
126
131
|
|
@@ -140,9 +145,7 @@ def subclasscheck(t1, t2):
|
|
140
145
|
args2 = get_args(t2)
|
141
146
|
if len(args1) != len(args2):
|
142
147
|
return False
|
143
|
-
return all(
|
144
|
-
subclasscheck(a1, a2) for a1, a2 in zip(args1, args2)
|
145
|
-
)
|
148
|
+
return all(subclasscheck(a1, a2) for a1, a2 in zip(args1, args2))
|
146
149
|
else:
|
147
150
|
return False
|
148
151
|
else:
|
ovld/py.typed
ADDED
File without changes
|