ovld 0.4.5__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ovld/dependent.py CHANGED
@@ -3,7 +3,6 @@ import re
3
3
  from collections.abc import Callable as _Callable
4
4
  from collections.abc import Mapping, Sequence
5
5
  from functools import partial
6
- from itertools import count
7
6
  from typing import (
8
7
  TYPE_CHECKING,
9
8
  Any,
@@ -11,57 +10,22 @@ from typing import (
11
10
  TypeVar,
12
11
  )
13
12
 
13
+ from .codegen import Code
14
14
  from .types import (
15
15
  Intersection,
16
16
  Order,
17
17
  clsstring,
18
+ get_args,
18
19
  normalize_type,
19
20
  subclasscheck,
20
21
  typeorder,
21
22
  )
22
23
 
23
- _current = count()
24
-
25
-
26
- def generate_checking_code(typ):
27
- if hasattr(typ, "codegen"):
28
- return typ.codegen()
29
- else:
30
- return CodeGen("isinstance({arg}, {this})", this=typ)
31
-
32
-
33
- class CodeGen:
34
- def __init__(self, template, substitutions={}, **substitutions_kw):
35
- self.template = template
36
- self.substitutions = {**substitutions, **substitutions_kw}
37
-
38
- def mangle(self):
39
- renamings = {
40
- k: f"{{{k}__{next(_current)}}}" for k in self.substitutions
41
- }
42
- renamings["arg"] = "{arg}"
43
- new_subs = {
44
- newk[1:-1]: self.substitutions[k]
45
- for k, newk in renamings.items()
46
- if k in self.substitutions
47
- }
48
- return CodeGen(self.template.format(**renamings), new_subs)
49
-
50
-
51
- def combine(master_template, args):
52
- fmts = []
53
- subs = {}
54
- for cg in args:
55
- mangled = cg.mangle()
56
- fmts.append(mangled.template)
57
- subs.update(mangled.substitutions)
58
- return CodeGen(master_template.format(*fmts), subs)
59
-
60
24
 
61
25
def is_dependent(t):
    """Tell whether ``t`` involves a dependent type.

    Returns True when ``t`` is itself a ``DependentType`` instance, or when any
    of its type arguments (as returned by ``get_args``) is dependent,
    checked recursively. Returns False otherwise.
    """
    return isinstance(t, DependentType) or any(map(is_dependent, get_args(t)))
67
31
 
@@ -89,7 +53,7 @@ class DependentType(type):
89
53
  raise NotImplementedError()
90
54
 
91
55
  def codegen(self):
92
- return CodeGen("{this}.check({arg})", this=self)
56
+ return Code("$this.check($arg)", this=self)
93
57
 
94
58
  def __type_order__(self, other):
95
59
  if isinstance(other, DependentType):
@@ -105,9 +69,7 @@ class DependentType(type):
105
69
  return Order.NONE
106
70
  else: # pragma: no cover
107
71
  return order
108
- elif subclasscheck(other, self.bound) or subclasscheck(
109
- self.bound, other
110
- ):
72
+ elif subclasscheck(other, self.bound) or subclasscheck(self.bound, other):
111
73
  return Order.LESS
112
74
  else:
113
75
  return Order.NONE
@@ -137,9 +99,7 @@ class DependentType(type):
137
99
 
138
100
  class ParametrizedDependentType(DependentType):
139
101
  def __init__(self, *parameters, bound=None):
140
- super().__init__(
141
- self.default_bound(*parameters) if bound is None else bound
142
- )
102
+ super().__init__(self.default_bound(*parameters) if bound is None else bound)
143
103
  self.__args__ = self.parameters = parameters
144
104
  self.__origin__ = None
145
105
  self.__post_init__()
@@ -191,12 +151,10 @@ class FuncDependentType(ParametrizedDependentType):
191
151
  if len(self.parameters) != len(other.parameters):
192
152
  return False
193
153
  p1g = sum(
194
- p1 is Any and p2 is not Any
195
- for p1, p2 in zip(self.parameters, other.parameters)
154
+ p1 is Any and p2 is not Any for p1, p2 in zip(self.parameters, other.parameters)
196
155
  )
197
156
  p2g = sum(
198
- p2 is Any and p1 is not Any
199
- for p1, p2 in zip(self.parameters, other.parameters)
157
+ p2 is Any and p1 is not Any for p1, p2 in zip(self.parameters, other.parameters)
200
158
  )
201
159
  return p2g and not p1g
202
160
 
@@ -255,16 +213,16 @@ class Equals(ParametrizedDependentType):
255
213
 
256
214
  @classmethod
257
215
  def keygen(cls):
258
- return "{arg}"
216
+ return Code("$arg")
259
217
 
260
218
  def get_keys(self):
261
219
  return [self.parameter]
262
220
 
263
221
  def codegen(self):
264
222
  if len(self.parameters) == 1:
265
- return CodeGen("({arg} == {p})", p=self.parameter)
223
+ return Code("($arg == $p)", p=self.parameter)
266
224
  else:
267
- return CodeGen("({arg} in {ps})", ps=self.parameters)
225
+ return Code("($arg in $ps)", ps=self.parameters)
268
226
 
269
227
 
270
228
  class ProductType(ParametrizedDependentType):
@@ -281,19 +239,18 @@ class ProductType(ParametrizedDependentType):
281
239
  )
282
240
 
283
241
  def codegen(self):
284
- checks = ["len({arg}) == {n}"]
242
+ checks = ["len($arg) == $n"]
285
243
  params = {"n": len(self.parameters)}
286
244
  for i, p in enumerate(self.parameters):
287
- checks.append(f"isinstance({{arg}}[{i}], {{p{i}}})")
245
+ checks.append(f"isinstance($arg[{i}], $p{i})")
288
246
  params[f"p{i}"] = p
289
- return CodeGen(" and ".join(checks), params)
247
+ return Code(" and ".join(checks), params)
290
248
 
291
249
  def __type_order__(self, other):
292
250
  if isinstance(other, ProductType):
293
251
  if len(other.parameters) == len(self.parameters):
294
252
  return Order.merge(
295
- typeorder(a, b)
296
- for a, b in zip(self.parameters, other.parameters)
253
+ typeorder(a, b) for a, b in zip(self.parameters, other.parameters)
297
254
  )
298
255
  else:
299
256
  return Order.NONE
@@ -337,8 +294,14 @@ def Callable(fn: _Callable, argt, rett):
337
294
 
338
295
 
339
296
@dependent_check
class HasKey:
    """Dependent check: the value is a Mapping containing every key in ``parameters``."""

    def check(self, value: Mapping):
        return all(k in value for k in self.parameters)

    def codegen(self):
        """Generate ``(k1 in $arg and k2 in $arg ...)`` for the required keys.

        With no keys, ``check`` returns ``all([]) == True``. Joining an empty list
        of clauses would instead produce ``()``, an empty and therefore falsy
        tuple, so that case is emitted as a literal ``True`` to keep the two
        paths consistent.
        """
        if not self.parameters:
            return Code("True")
        return Code(
            "($[ and ]checks)", checks=[Code("($k in $arg)", k=k) for k in self.parameters]
        )
342
305
 
343
306
 
344
307
  @dependent_check
@@ -360,7 +323,7 @@ class Regexp:
360
323
  return bool(self.rx.search(value))
361
324
 
362
325
  def codegen(self):
363
- return CodeGen("bool({rx}.search({arg}))", rx=self.rx)
326
+ return Code("bool($rx.search($arg))", rx=self.rx)
364
327
 
365
328
 
366
329
  class Dependent:
@@ -377,14 +340,3 @@ if TYPE_CHECKING: # pragma: no cover
377
340
  T = TypeVar("T")
378
341
  A = TypeVar("A")
379
342
  Dependent: TypeAlias = Annotated[T, A]
380
-
381
-
382
- __all__ = [
383
- "Dependent",
384
- "DependentType",
385
- "Equals",
386
- "HasKey",
387
- "StartsWith",
388
- "EndsWith",
389
- "dependent_check",
390
- ]
ovld/medley.py ADDED
@@ -0,0 +1,408 @@
1
+ import functools
2
+ import inspect
3
+ from copy import copy
4
+ from dataclasses import MISSING, dataclass, fields, make_dataclass, replace
5
+ from typing import Annotated, TypeVar, get_origin
6
+
7
+ from .core import Ovld, to_ovld
8
+ from .types import eval_annotation
9
+ from .utils import Named
10
+
11
# Sentinel returned by Combiner.get when a combiner has nothing to contribute
# (e.g. KeepLast before any value, ImplList with no implementations).
ABSENT = Named("ABSENT")
# Marker placed in Annotated[...] metadata to flag a Medley field as a
# codegen parameter (see CodegenParameter).
CODEGEN = Named("CODEGEN")
13
+
14
+
15
class Combiner:
    """Strategy for merging the definitions of one attribute across Medley classes.

    Subclasses decide how successive definitions are accumulated
    (``juxtapose``), how two combiners of the same kind are merged
    (``include_sametype``), and what value ends up on the class (``get``).
    """

    def __init__(self, field=None):
        # Name of the attribute this combiner manages (None until known).
        self.field = field

    def __set_name__(self, obj, field):
        self.field = field

    def get(self, cls):  # pragma: no cover
        raise NotImplementedError()

    def copy(self):
        """Return a fresh combiner of the same class, for the same field."""
        klass = type(self)
        return klass(self.field)

    def include(self, other):
        """Merge ``other`` into ``self``; both must be of exactly the same class."""
        if type(other) is type(self):
            self.include_sametype(other)
            return
        raise TypeError("Cannot merge different combiner classes.")

    def include_sametype(self, other):  # pragma: no cover
        pass

    def juxtapose(self, impl):  # pragma: no cover
        raise NotImplementedError()
38
+
39
+
40
class KeepLast(Combiner):
    """Combiner that keeps only the most recently supplied value."""

    def __init__(self, field=None):
        super().__init__(field)
        # Nothing supplied yet.
        self.impl = ABSENT

    def get(self, cls):
        """Return the last value given, or ABSENT if none was."""
        return self.impl

    def include_sametype(self, other):
        """Adopt ``other``'s value (ABSENT included), overriding ours."""
        self.impl = other.impl

    def juxtapose(self, impl):
        """Record ``impl`` as the current value, discarding the previous one."""
        self.impl = impl
53
+
54
+
55
class ImplList(Combiner):
    """Base for combiners that keep an ordered list of implementations.

    Subclasses implement ``wrap`` to build one function out of ``self.impls``.
    The wrappers read ``self.impls`` when they are called, not when they are
    built.
    """

    def __init__(self, field=None, impls=None):
        super().__init__(field)
        self.impls = impls or []

    def copy(self):
        """Return a combiner of the same class with its own copy of the list.

        The list must not be shared. Subclasses ``juxtapose``/``include`` into
        the copy (via ``append``/``+=``), and since wrappers read ``impls``
        lazily, a shared list would make the parent class's wrapper also run
        implementations added by its subclasses.
        """
        return type(self)(self.field, list(self.impls))

    def get(self, cls):
        """Return the wrapped function (metadata from the first impl), or ABSENT."""
        if not self.impls:
            return ABSENT
        rval = self.wrap()
        return functools.wraps(self.impls[0])(rval)

    def wrap(self):  # pragma: no cover
        raise NotImplementedError()

    def include_sametype(self, other):
        self.impls += other.impls

    def juxtapose(self, impl):
        self.impls.append(impl)
77
+
78
+
79
class RunAll(ImplList):
    """ImplList whose wrapper calls every implementation in order and returns None."""

    def wrap(_self):
        def run_all(self, *args, **kwargs):
            # Results are discarded; each impl is called for its side effects.
            for fn in _self.impls:
                fn(self, *args, **kwargs)

        return run_all
86
+
87
+
88
class ReduceAll(ImplList):
    """ImplList whose wrapper threads a value through every implementation.

    The first impl receives the original ``x``. Each later impl receives the
    previous impl's result in place of ``x``. The extra arguments are passed
    unchanged to all of them.
    """

    def wrap(_self):
        def reduce_all(self, x, *args, **kwargs):
            acc = _self.impls[0](self, x, *args, **kwargs)
            for step in _self.impls[1:]:
                acc = step(self, acc, *args, **kwargs)
            return acc

        return reduce_all
97
+
98
+
99
class ChainAll(ImplList):
    """ImplList whose wrapper chains implementations through their return value.

    Each impl is called with the object returned by the previous one in the
    ``self`` position. The final return value is the wrapper's result.
    """

    def wrap(_self):
        def chain_all(self, *args, **kwargs):
            current = _self.impls[0](self, *args, **kwargs)
            for step in _self.impls[1:]:
                current = step(current, *args, **kwargs)
            return current

        return chain_all
108
+
109
+
110
class BuildOvld(Combiner):
    """Combiner that merges every definition of an attribute into one Ovld.

    Functions and ovlds passed to ``juxtapose`` are queued in ``pending``. They
    are only registered on ``self.ovld`` when ``get`` is called.
    """

    def __init__(self, field=None, ovld=None):
        super().__init__(field)
        # Explicit None check: a caller-supplied Ovld must not be replaced
        # merely because it happens to evaluate as falsy.
        self.ovld = Ovld(linkback=True) if ovld is None else ovld
        self.pending = []
        if field is not None:
            self.__set_name__(None, field)

    def __set_name__(self, obj, field):
        # Record the field like Combiner does (copy() passes it on), then
        # rename the underlying Ovld.
        super().__set_name__(obj, field)
        self.ovld.rename(field)

    def get(self, cls):
        """Flush pending registrations and return the compiled dispatch, or ABSENT."""
        self.ovld.specialization_self = cls
        for f, arg in self.pending:
            f(arg)
        self.pending.clear()
        if not self.ovld.defns:
            return ABSENT
        self.ovld.compile()
        return self.ovld.dispatch

    def copy(self):
        return type(self)(self.field, self.ovld.copy(linkback=True))

    def include_sametype(self, other):
        self.ovld.add_mixins(other.ovld)

    def juxtapose(self, impl):
        """Queue ``impl`` (an ovld-convertible object or a plain function).

        Raises:
            TypeError: If ``impl`` is neither.
        """
        if ov := to_ovld(impl, force=False):
            self.pending.append((self.ovld.add_mixins, ov))
        elif inspect.isfunction(impl):
            self.pending.append((self.ovld.register, impl))
        else:  # pragma: no cover
            raise TypeError("Expected a function or ovld.")
144
+
145
+
146
class medley_cls_dict(dict):
    """Namespace mapping in which a Medley class body is executed.

    Most assignments are not stored directly. They are handed to a per-attribute
    Combiner kept in ``_combiners``, which is also stored in the namespace as
    ``_ovld_combiners``. Combiners inherited from the bases are copied (first
    occurrence) or merged (later occurrences) up front.
    """

    def __init__(self, bases):
        super().__init__()
        combiners = {}
        self._combiners = combiners
        self.set_direct("_ovld_combiners", combiners)
        self._basic = set()
        for base in bases:
            inherited = getattr(base, "_ovld_combiners", {})
            for attr, combiner in inherited.items():
                if attr not in combiners:
                    combiners[attr] = combiner.copy()
                else:
                    combiners[attr].include(combiner)

    def set_direct(self, attr, value):
        """Store ``value`` under ``attr`` without going through the combiners."""
        super().__setitem__(attr, value)

    def __setitem__(self, attr, value):
        if attr == "__annotations__":
            # Annotations bypass the combiners and are stored as-is.
            self.set_direct(attr, value)
        elif attr == "__init__":
            raise Exception("Do not define __init__ in a Medley, use __post_init__.")
        elif isinstance(value, Combiner):
            # An explicitly assigned combiner replaces the attribute's strategy.
            value.__set_name__(None, attr)
            self._combiners[attr] = value
        else:
            combiner = self._combiners.get(attr, None)
            if combiner is None:
                # Functions/ovlds are merged into an Ovld; anything else keeps
                # the last assigned value.
                mergeable = inspect.isfunction(value) or isinstance(value, Ovld)
                combiner = BuildOvld(attr) if mergeable else KeepLast(attr)
                self._combiners[attr] = combiner
            combiner.juxtapose(value)

    def __missing__(self, attr):
        # Lookups of names that were routed to a combiner resolve to the
        # combiner's current value, unless it is ABSENT.
        combiner = self._combiners.get(attr)
        if combiner is not None:
            value = combiner.get(None)
            if value is not ABSENT:
                return value
        raise KeyError(attr)
190
+
191
+
192
def codegen_key(*instances):
    """Collect the codegen field values of ``instances`` into one dict.

    For each instance, in order, read every attribute listed in its class's
    ``_ovld_codegen_fields``. When several instances share a field name, the
    later instance's value wins while the key keeps its first position.
    """
    merged = {}
    for inst in instances:
        for name in type(inst)._ovld_codegen_fields:
            merged[name] = getattr(inst, name)
    return merged
198
+
199
+
200
def specialize(cls, key):
    """Create a subclass of ``cls`` whose class attributes are fixed to ``key``.

    The new class has the same name as ``cls``, records ``cls`` as its
    ``_ovld_specialization_parent``, and gets one class attribute per item of
    ``key``.
    """
    new_t = MedleyMC(cls.__name__, (cls,), medley_cls_dict((cls,)))
    new_t._ovld_specialization_parent = cls
    for name, val in key.items():
        setattr(new_t, name, val)
    # NOTE(review): this updates the *parent* ``cls``, not ``new_t``. At the
    # visible call sites ``key``'s names already equal cls's codegen fields, so
    # it looks like a no-op. Confirm whether ``new_t`` was intended.
    cls._ovld_codegen_fields = list(key.keys())
    return new_t
208
+
209
+
210
class MedleyMC(type):
    """Metaclass for Medley classes.

    It runs class bodies in a ``medley_cls_dict`` so that attributes are merged
    through combiners, and it turns every class into a dataclass. It also
    implements class-level melding (``+``, ``+=``, ``-``) and the per-value
    specialization of instances that have codegen fields.
    """

    def __subclasscheck__(cls, subclass):
        # A melded class (one with ``_ovld_medleys``) stands for the
        # combination of its component medleys: ``subclass`` qualifies if it
        # subclasses all of them.
        if getattr(cls, "_ovld_medleys", None):
            return all(issubclass(subclass, m) for m in cls._ovld_medleys)
        return super().__subclasscheck__(subclass)

    @classmethod
    def __prepare__(mcls, name, bases):
        # The class body runs in this mapping, which seeds combiners from bases.
        return medley_cls_dict(bases)

    def __new__(mcls, name, bases, namespace):
        result = super().__new__(mcls, name, bases, namespace)
        # Materialize each combiner's value on the class. ABSENT means "nothing
        # to set", so the attribute is left to normal inheritance.
        for attr, combiner in result._ovld_combiners.items():
            if (value := combiner.get(result)) is not ABSENT:
                setattr(result, attr, value)
        dc = dataclass(result)
        dc._ovld_specialization_parent = None
        dc._ovld_specializations = {}
        # Codegen fields are those whose evaluated annotation is
        # ``Annotated[..., CODEGEN, ...]`` (e.g. ``CodegenParameter[T]``).
        dc._ovld_codegen_fields = [
            field.name
            for field in fields(dc)
            if (
                (t := eval_annotation(field.type, dc, {}, catch=True))
                and get_origin(t) is Annotated
                and CODEGEN in t.__metadata__
            )
        ]
        return dc

    def extend(cls, *others):
        """Add the members of ``others`` to ``cls`` in place and return ``cls``.

        The new ``__init__`` comes from ``meld_classes(..., require_defaults=True)``,
        so every field contributed by ``others`` must have a default value.
        Subclasses of ``cls`` are extended recursively with whichever of
        ``others`` they do not already subclass.
        """
        if not others:
            return cls
        melded = meld_classes((cls, *others), require_defaults=True)
        for other in others:
            for k, v in vars(other).items():
                if k in ["__module__", "__firstlineno__"]:
                    continue
                elif comb := cls._ovld_combiners.get(k):
                    # Merge through the existing combiner and refresh the attribute.
                    comb.juxtapose(v)
                    setattr(cls, k, comb.get(cls))
                elif not k.startswith("_ovld_") and not k.startswith("__"):
                    setattr(cls, k, v)
        # Adopt the melded dataclass __init__ so the added fields are accepted.
        cls.__init__ = melded.__init__
        for subcls in cls.__subclasses__():
            subothers = [o for o in others if not issubclass(subcls, o)]
            subcls.extend(*subothers)
        return cls

    def __add__(cls, other):
        # ``A + B`` builds (or fetches from cache) a melded class.
        return meld_classes((cls, other))

    def __iadd__(cls, other):
        # ``A += B`` extends ``A`` in place and rebinds the name to ``A``.
        return cls.extend(other)

    def __sub__(cls, other):
        # ``A - B`` re-melds A's bases without ``B``.
        return unmeld_classes(cls, other)

    def __call__(cls, *args, **kwargs):
        """Instantiate ``cls``, specializing the result on its codegen fields.

        If the class has codegen fields and the instance yields a non-empty
        key, the instance's ``__dict__`` is moved onto an instance of a cached
        specialized subclass. The cache lives on the non-specialized parent and
        is keyed by the sorted (name, value) pairs, so the values must be
        hashable.
        """
        made = super().__call__(*args, **kwargs)
        if cls._ovld_codegen_fields and (keyd := codegen_key(made)):
            cls = cls._ovld_specialization_parent or cls
            key = tuple(sorted(keyd.items()))
            if key in cls._ovld_specializations:
                new_t = cls._ovld_specializations[key]
            else:
                new_t = specialize(cls, keyd)
                cls._ovld_specializations[key] = new_t
            obj = object.__new__(new_t)
            obj.__dict__.update(made.__dict__)
            return obj
        else:
            return made
282
+
283
+
284
def use_combiner(combiner):
    """Decorator factory: turn a function into a ``combiner`` named after it.

    ``combiner`` is called with the function's ``__name__``. The function is
    then ``juxtapose``d into the resulting instance, which is returned in
    place of the function.
    """

    def deco(fn):
        instance = combiner(fn.__name__)
        instance.juxtapose(fn)
        return instance

    return deco
291
+
292
+
293
class Medley(metaclass=MedleyMC):
    # Combiner strategies declared before the definitions below:
    # __post_init__ from every class in the hierarchy runs in order (RunAll),
    # while __add__/__sub__ keep only the last definition (KeepLast).
    __post_init__ = RunAll()
    __add__ = KeepLast()
    __sub__ = KeepLast()

    def __add__(self, other):
        """Combine two medley instances.

        If ``self`` is an instance of ``type(other)`` and its class has no
        codegen fields, return a copy of ``self`` updated with ``other``'s
        attributes. Otherwise meld both objects into an instance of a combined
        class.
        """
        plain_override = isinstance(self, type(other)) and not type(self)._ovld_codegen_fields
        if not plain_override:
            return meld([self, other])
        return replace(self, **vars(other))

    def __sub__(self, other):
        """Remove the medley class ``other`` from ``self`` (see ``unmeld``)."""
        return unmeld(self, other)
306
+
307
+
308
def unmeld_classes(main: type, exclude: type):
    """Meld ``main``'s direct bases again, leaving out ``exclude`` (by identity)."""
    remaining = tuple(filter(lambda base: base is not exclude, main.__bases__))
    return meld_classes(remaining)
311
+
312
+
313
# Memoizes meld_classes results, keyed by (tuple of medley classes, require_defaults).
_meld_classes_cache = dict()
314
+
315
+
316
def meld_classes(classes, require_defaults=False):
    """Return a Medley class combining ``classes``.

    Component medleys are collected: melded classes contribute their
    ``_ovld_medleys``, except ``classes[0]`` when ``require_defaults`` is set,
    which is kept as one unit. Bases of a non-melded class in the list are then
    dropped. If one class remains it is returned as-is. Otherwise a kw-only
    dataclass inheriting from all remaining medleys is built with
    ``make_dataclass`` and memoized in ``_meld_classes_cache``.

    Args:
        classes: Sequence of Medley classes to combine.
        require_defaults: Used by ``MedleyMC.extend``. When true, every field
            contributed by a medley other than the first must have a plain
            ``default``.

    Raises:
        TypeError: If ``require_defaults`` is true and such a field has no
            ``default`` (a ``default_factory`` is not accepted).
    """
    medleys = {}
    for i, cls in enumerate(classes):
        if require_defaults and i == 0:
            medleys[cls] = True
        else:
            medleys.update({x: True for x in getattr(cls, "_ovld_medleys", [cls])})
    for cls in classes:
        if not hasattr(cls, "_ovld_medleys"):
            for base in cls.mro():
                if base is not cls and base in medleys:
                    del medleys[base]
    medleys = tuple(medleys)
    if len(medleys) == 1:
        return medleys[0]

    cache_key = (medleys, require_defaults)
    if cache_key in _meld_classes_cache:
        return _meld_classes_cache[cache_key]

    def remap_field(dc_field, require_default):
        # Return a kw-only copy of ``dc_field``, checking the default if required.
        if require_default:
            if dc_field.default is MISSING:
                # NOTE: we do not accept default_factory, because we need the default value to be set
                # in the class so that existing instances of classes[0] can see it.
                raise TypeError(
                    f"Dataclass field '{dc_field.name}' must have a default value (not a default_factory) in order to be melded in."
                )
        dc_field = copy(dc_field)
        dc_field.kw_only = True
        return dc_field

    def same_field(a, b):
        # Fields inherited from a common ancestor are either the same object or
        # shallow copies made by an earlier meld. Compare defaults by identity
        # so that arbitrary (even non-comparable) default values are fine.
        return a is b or (
            a.type == b.type
            and a.default is b.default
            and a.default_factory is b.default_factory
        )

    cg_fields = set()
    dc_fields = []
    seen_fields = {}  # field name -> first Field seen with that name

    for base in medleys:
        rqdef = require_defaults and base is not medleys[0]
        cg_fields.update(base._ovld_codegen_fields)
        for f in base.__dataclass_fields__.values():
            prev = seen_fields.get(f.name)
            if prev is not None and same_field(prev, f):
                # Shared via a common ancestor (e.g. melding two siblings):
                # declare it once, as make_dataclass rejects duplicate names.
                # Genuinely conflicting fields still reach make_dataclass and
                # are reported there, as before.
                continue
            seen_fields.setdefault(f.name, f)
            dc_fields.append((f.name, f.type, remap_field(f, rqdef)))

    merged = medley_cls_dict(medleys)
    merged.set_direct("_ovld_codegen_fields", tuple(cg_fields))
    merged.set_direct("_ovld_medleys", tuple(medleys))

    result = make_dataclass(
        cls_name="+".join(sorted(c.__name__ for c in medleys)),
        bases=tuple(medleys),
        fields=dc_fields,
        kw_only=True,
        namespace=merged,
    )

    _meld_classes_cache[cache_key] = result
    return result
372
+
373
+
374
@functools.cache
def meld_classes_with_key(classes, key):
    """Meld ``classes`` and, if ``key`` is non-empty, specialize the result on it.

    ``key`` is a tuple of (name, value) pairs, so that it is hashable for the
    cache. An empty key returns the plain melded class.
    """
    keyd = dict(key)
    typ = meld_classes(classes)
    return specialize(typ, keyd) if keyd else typ
382
+
383
+
384
def meld(objects):
    """Combine several medley instances into one instance of a melded class.

    The class is derived from the objects' types and their combined codegen
    key. Attributes are then copied from each object in order, so later
    objects overwrite earlier ones. ``__init__`` is not called.
    """
    key = codegen_key(*objects)
    cls = meld_classes_with_key(tuple(map(type, objects)), tuple(key.items()))
    obj = object.__new__(cls)
    for source in objects:
        for name, val in vars(source).items():
            setattr(obj, name, val)
    return obj
393
+
394
+
395
def unmeld(obj: object, exclude: type):
    """Return a new instance of ``type(obj)`` with the medley ``exclude`` removed.

    The new class is ``unmeld_classes(type(obj), exclude)``. It is constructed
    from ``obj``'s current values for every one of its own dataclass fields.

    Raises:
        TypeError: If ``obj``'s class has codegen fields.
    """
    if type(obj)._ovld_codegen_fields:  # pragma: no cover
        raise TypeError("Cannot unmeld an object with codegen fields")
    cls = unmeld_classes(type(obj), exclude)
    # Keep every field of ``cls``, including those that ``exclude`` also has
    # because both inherit them from a common ancestor. Filtering by
    # ``exclude.__dataclass_fields__`` would drop those values, making
    # ``cls(**values)`` fail or silently reset them to their defaults.
    values = {f.name: getattr(obj, f.name) for f in cls.__dataclass_fields__.values()}
    return cls(**values)
405
+
406
+
407
# Type variable parametrizing CodegenParameter.
T = TypeVar("T")
# Annotating a Medley field as CodegenParameter[SomeType] marks it with CODEGEN,
# so instances get specialized on that field's value (see MedleyMC.__call__).
CodegenParameter = Annotated[T, CODEGEN]
ovld/mro.py CHANGED
@@ -3,6 +3,8 @@ from enum import Enum
3
3
  from graphlib import TopologicalSorter
4
4
  from typing import get_args, get_origin
5
5
 
6
+ from .utils import UnionTypes
7
+
6
8
 
7
9
  class Order(Enum):
8
10
  LESS = -1
@@ -121,6 +123,9 @@ def subclasscheck(t1, t2):
121
123
  ):
122
124
  return result
123
125
 
126
+ if t2 in UnionTypes:
127
+ return isinstance(t1, t2)
128
+
124
129
  o1 = get_origin(t1)
125
130
  o2 = get_origin(t2)
126
131
 
@@ -140,9 +145,7 @@ def subclasscheck(t1, t2):
140
145
  args2 = get_args(t2)
141
146
  if len(args1) != len(args2):
142
147
  return False
143
- return all(
144
- subclasscheck(a1, a2) for a1, a2 in zip(args1, args2)
145
- )
148
+ return all(subclasscheck(a1, a2) for a1, a2 in zip(args1, args2))
146
149
  else:
147
150
  return False
148
151
  else:
ovld/py.typed ADDED
File without changes