ovld 0.4.6__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ovld/dependent.py CHANGED
@@ -3,7 +3,6 @@ import re
3
3
  from collections.abc import Callable as _Callable
4
4
  from collections.abc import Mapping, Sequence
5
5
  from functools import partial
6
- from itertools import count
7
6
  from typing import (
8
7
  TYPE_CHECKING,
9
8
  Any,
@@ -11,6 +10,7 @@ from typing import (
11
10
  TypeVar,
12
11
  )
13
12
 
13
+ from .codegen import Code
14
14
  from .types import (
15
15
  Intersection,
16
16
  Order,
@@ -21,43 +21,6 @@ from .types import (
21
21
  typeorder,
22
22
  )
23
23
 
24
- _current = count()
25
-
26
-
27
- def generate_checking_code(typ):
28
- if hasattr(typ, "codegen"):
29
- return typ.codegen()
30
- else:
31
- return CodeGen("isinstance({arg}, {this})", this=typ)
32
-
33
-
34
- class CodeGen:
35
- def __init__(self, template, substitutions={}, **substitutions_kw):
36
- self.template = template
37
- self.substitutions = {**substitutions, **substitutions_kw}
38
-
39
- def mangle(self):
40
- renamings = {
41
- k: f"{{{k}__{next(_current)}}}" for k in self.substitutions
42
- }
43
- renamings["arg"] = "{arg}"
44
- new_subs = {
45
- newk[1:-1]: self.substitutions[k]
46
- for k, newk in renamings.items()
47
- if k in self.substitutions
48
- }
49
- return CodeGen(self.template.format(**renamings), new_subs)
50
-
51
-
52
- def combine(master_template, args):
53
- fmts = []
54
- subs = {}
55
- for cg in args:
56
- mangled = cg.mangle()
57
- fmts.append(mangled.template)
58
- subs.update(mangled.substitutions)
59
- return CodeGen(master_template.format(*fmts), subs)
60
-
61
24
 
62
25
  def is_dependent(t):
63
26
  if isinstance(t, DependentType):
@@ -90,7 +53,7 @@ class DependentType(type):
90
53
  raise NotImplementedError()
91
54
 
92
55
  def codegen(self):
93
- return CodeGen("{this}.check({arg})", this=self)
56
+ return Code("$this.check($arg)", this=self)
94
57
 
95
58
  def __type_order__(self, other):
96
59
  if isinstance(other, DependentType):
@@ -106,9 +69,7 @@ class DependentType(type):
106
69
  return Order.NONE
107
70
  else: # pragma: no cover
108
71
  return order
109
- elif subclasscheck(other, self.bound) or subclasscheck(
110
- self.bound, other
111
- ):
72
+ elif subclasscheck(other, self.bound) or subclasscheck(self.bound, other):
112
73
  return Order.LESS
113
74
  else:
114
75
  return Order.NONE
@@ -138,9 +99,7 @@ class DependentType(type):
138
99
 
139
100
  class ParametrizedDependentType(DependentType):
140
101
  def __init__(self, *parameters, bound=None):
141
- super().__init__(
142
- self.default_bound(*parameters) if bound is None else bound
143
- )
102
+ super().__init__(self.default_bound(*parameters) if bound is None else bound)
144
103
  self.__args__ = self.parameters = parameters
145
104
  self.__origin__ = None
146
105
  self.__post_init__()
@@ -192,12 +151,10 @@ class FuncDependentType(ParametrizedDependentType):
192
151
  if len(self.parameters) != len(other.parameters):
193
152
  return False
194
153
  p1g = sum(
195
- p1 is Any and p2 is not Any
196
- for p1, p2 in zip(self.parameters, other.parameters)
154
+ p1 is Any and p2 is not Any for p1, p2 in zip(self.parameters, other.parameters)
197
155
  )
198
156
  p2g = sum(
199
- p2 is Any and p1 is not Any
200
- for p1, p2 in zip(self.parameters, other.parameters)
157
+ p2 is Any and p1 is not Any for p1, p2 in zip(self.parameters, other.parameters)
201
158
  )
202
159
  return p2g and not p1g
203
160
 
@@ -256,16 +213,16 @@ class Equals(ParametrizedDependentType):
256
213
 
257
214
  @classmethod
258
215
  def keygen(cls):
259
- return "{arg}"
216
+ return Code("$arg")
260
217
 
261
218
  def get_keys(self):
262
219
  return [self.parameter]
263
220
 
264
221
  def codegen(self):
265
222
  if len(self.parameters) == 1:
266
- return CodeGen("({arg} == {p})", p=self.parameter)
223
+ return Code("($arg == $p)", p=self.parameter)
267
224
  else:
268
- return CodeGen("({arg} in {ps})", ps=self.parameters)
225
+ return Code("($arg in $ps)", ps=self.parameters)
269
226
 
270
227
 
271
228
  class ProductType(ParametrizedDependentType):
@@ -282,19 +239,18 @@ class ProductType(ParametrizedDependentType):
282
239
  )
283
240
 
284
241
  def codegen(self):
285
- checks = ["len({arg}) == {n}"]
242
+ checks = ["len($arg) == $n"]
286
243
  params = {"n": len(self.parameters)}
287
244
  for i, p in enumerate(self.parameters):
288
- checks.append(f"isinstance({{arg}}[{i}], {{p{i}}})")
245
+ checks.append(f"isinstance($arg[{i}], $p{i})")
289
246
  params[f"p{i}"] = p
290
- return CodeGen(" and ".join(checks), params)
247
+ return Code(" and ".join(checks), params)
291
248
 
292
249
  def __type_order__(self, other):
293
250
  if isinstance(other, ProductType):
294
251
  if len(other.parameters) == len(self.parameters):
295
252
  return Order.merge(
296
- typeorder(a, b)
297
- for a, b in zip(self.parameters, other.parameters)
253
+ typeorder(a, b) for a, b in zip(self.parameters, other.parameters)
298
254
  )
299
255
  else:
300
256
  return Order.NONE
@@ -338,8 +294,14 @@ def Callable(fn: _Callable, argt, rett):
338
294
 
339
295
 
340
296
  @dependent_check
341
- def HasKey(value: Mapping, *keys):
342
- return all(k in value for k in keys)
297
+ class HasKey:
298
+ def check(self, value: Mapping):
299
+ return all(k in value for k in self.parameters)
300
+
301
+ def codegen(self):
302
+ return Code(
303
+ "($[ and ]checks)", checks=[Code("($k in $arg)", k=k) for k in self.parameters]
304
+ )
343
305
 
344
306
 
345
307
  @dependent_check
@@ -361,7 +323,7 @@ class Regexp:
361
323
  return bool(self.rx.search(value))
362
324
 
363
325
  def codegen(self):
364
- return CodeGen("bool({rx}.search({arg}))", rx=self.rx)
326
+ return Code("bool($rx.search($arg))", rx=self.rx)
365
327
 
366
328
 
367
329
  class Dependent:
@@ -378,14 +340,3 @@ if TYPE_CHECKING: # pragma: no cover
378
340
  T = TypeVar("T")
379
341
  A = TypeVar("A")
380
342
  Dependent: TypeAlias = Annotated[T, A]
381
-
382
-
383
- __all__ = [
384
- "Dependent",
385
- "DependentType",
386
- "Equals",
387
- "HasKey",
388
- "StartsWith",
389
- "EndsWith",
390
- "dependent_check",
391
- ]
ovld/medley.py ADDED
@@ -0,0 +1,416 @@
1
+ import functools
2
+ import inspect
3
+ from copy import copy
4
+ from dataclasses import MISSING, dataclass, fields, make_dataclass, replace
5
+ from typing import Annotated, TypeVar, get_origin
6
+
7
+ from .core import Ovld, to_ovld
8
+ from .types import eval_annotation
9
+ from .utils import Named
10
+
11
# Sentinel a combiner returns from ``get`` when it has nothing to contribute;
# callers then leave the attribute alone.
ABSENT = Named("ABSENT")
# Marker stored in ``Annotated[...]`` metadata to flag a dataclass field as a
# codegen parameter.
CODEGEN = Named("CODEGEN")
13
+
14
+
15
class Combiner:
    """Base strategy deciding how definitions of one Medley attribute merge.

    A combiner manages a single attribute name (``field``). Values supplied
    for that attribute are absorbed with ``juxtapose``; combiners of the same
    kind coming from other bases are absorbed with ``include``; ``get``
    produces the final value to install on a class.
    """

    def __init__(self, field=None):
        # Attribute name managed by this combiner; may be assigned later.
        self.field = field

    def __set_name__(self, obj, field):
        # Remember the attribute name once it is known.
        self.field = field

    def get(self, cls):  # pragma: no cover
        """Return the combined value for ``cls``. Subclasses must override."""
        raise NotImplementedError()

    def copy(self):
        """Return a new combiner of the same type for the same field.

        Only the field name is carried over by this base implementation.
        """
        return type(self)(self.field)

    def include(self, other):
        """Absorb ``other``, which must be of exactly the same class.

        Raises:
            TypeError: if ``other`` is a different combiner class.
        """
        if type(other) is type(self):
            self.include_sametype(other)
            return
        raise TypeError("Cannot merge different combiner classes.")

    def include_sametype(self, other):  # pragma: no cover
        """Absorb a combiner of the same class; no-op by default."""

    def juxtapose(self, impl):  # pragma: no cover
        """Absorb one new implementation. Subclasses must override."""
        raise NotImplementedError()
38
+
39
+
40
class KeepLast(Combiner):
    """Combiner for which the most recently supplied value wins.

    ``medley_cls_dict`` falls back to it for attributes that are neither
    functions nor ovlds. ``copy`` is inherited from ``Combiner``, so a copy
    starts without a value.
    """

    def __init__(self, field=None):
        super().__init__(field)
        # Nothing supplied yet.
        self.impl = ABSENT

    def get(self, cls):
        """Return the last supplied value, or ``ABSENT`` if there is none."""
        return self.impl

    def include_sametype(self, other):
        """Adopt ``other``'s value (even when that value is ``ABSENT``)."""
        self.impl = other.impl

    def juxtapose(self, impl):
        """Replace the current value with ``impl``."""
        self.impl = impl
53
+
54
+
55
class ImplList(Combiner):
    """Combiner that accumulates every supplied implementation, in order.

    Subclasses implement ``wrap()`` to build the single callable that runs
    the accumulated implementations.
    """

    def __init__(self, field=None, impls=None):
        super().__init__(field)
        # Always own a private list. Sharing the caller's list (as ``copy``
        # used to do) let ``juxtapose``/``include`` on a subclass's or a
        # meld's combiner mutate the parent's implementations.
        self.impls = list(impls) if impls else []

    def copy(self):
        """Return an independent combiner holding the same implementations."""
        # ``__init__`` copies the list, so the two combiners do not alias.
        return type(self)(self.field, self.impls)

    def get(self, cls):
        """Return the wrapper running all implementations, or ``ABSENT``."""
        if not self.impls:
            return ABSENT
        rval = self.wrap()
        # Give the wrapper the metadata (name, docstring, ...) of the first
        # implementation.
        return functools.wraps(self.impls[0])(rval)

    def wrap(self):  # pragma: no cover
        """Build the combined callable. Subclasses must override."""
        raise NotImplementedError()

    def include_sametype(self, other):
        """Append ``other``'s implementations after this combiner's own."""
        self.impls += other.impls

    def juxtapose(self, impl):
        """Append one implementation."""
        self.impls.append(impl)
77
+
78
+
79
class RunAll(ImplList):
    """Call every accumulated implementation in order; the result is ``None``."""

    def wrap(_self):
        # ``_self`` is the combiner; ``self`` is the instance the wrapper is
        # bound to once installed on a class.
        def run_all(self, *args, **kwargs):
            # The list is looked up at call time.
            for step in _self.impls:
                step(self, *args, **kwargs)

        return run_all
86
+
87
+
88
class ReduceAll(ImplList):
    """Thread a value through every accumulated implementation.

    The first implementation receives the caller's ``x``; each later one
    receives the previous result in its place. The last result is returned.
    """

    def wrap(_self):
        # ``_self`` is the combiner; ``self`` is the bound instance.
        def reduce_all(self, x, *args, **kwargs):
            impls = _self.impls
            acc = impls[0](self, x, *args, **kwargs)
            for step in impls[1:]:
                acc = step(self, acc, *args, **kwargs)
            return acc

        return reduce_all
97
+
98
+
99
class ChainAll(ImplList):
    """Chain implementations: each one's return value becomes the next ``self``.

    The value returned by the last implementation is returned.
    """

    def wrap(_self):
        # ``_self`` is the combiner; ``self`` is the instance on the first call.
        def chain_all(self, *args, **kwargs):
            impls = _self.impls
            current = impls[0](self, *args, **kwargs)
            for step in impls[1:]:
                current = step(current, *args, **kwargs)
            return current

        return chain_all
108
+
109
+
110
class BuildOvld(Combiner):
    """Combiner that gathers functions and ovlds into a single ``Ovld``.

    ``juxtapose`` only queues registrations; they are applied to the
    underlying ovld the next time ``get`` is called.
    """

    def __init__(self, field=None, ovld=None):
        super().__init__(field)
        self.ovld = ovld or Ovld(name=field, linkback=True)
        # Deferred (callable, argument) pairs applied by ``get``.
        self.pending = []
        if field is not None:
            self.__set_name__(None, field)

    def __set_name__(self, obj, field):
        # Record the field like ``Combiner`` does, so ``copy`` propagates the
        # name even for combiners declared in a class body; then rename the
        # managed ovld.
        super().__set_name__(obj, field)
        self.ovld.rename(field)

    def get(self, cls):
        """Flush queued registrations and return the ovld's dispatcher.

        Returns:
            ``self.ovld.dispatch``, or ``ABSENT`` if the ovld has no
            definitions.
        """
        self.ovld.specialization_self = cls
        for register, arg in self.pending:
            register(arg)
        self.pending.clear()
        if not self.ovld.defns:
            return ABSENT
        self.ovld.reset()
        return self.ovld.dispatch

    def copy(self):
        """Return a combiner wrapping a copy of the ovld (pending not copied)."""
        return type(self)(self.field, self.ovld.copy(linkback=True))

    def include_sametype(self, other):
        """Mix ``other``'s ovld into this one."""
        self.ovld.add_mixins(other.ovld)

    def juxtapose(self, impl):
        """Queue ``impl`` (an ovld-convertible object or a plain function).

        Raises:
            TypeError: if ``impl`` is neither.
        """
        if ov := to_ovld(impl, force=False):
            self.pending.append((self.ovld.add_mixins, ov))
        elif inspect.isfunction(impl):
            self.pending.append((self.ovld.register, impl))
        else:  # pragma: no cover
            raise TypeError("Expected a function or ovld.")
144
+
145
+
146
class medley_cls_dict(dict):
    """Class-body namespace that routes assignments to per-attribute combiners.

    Combiners inherited from ``bases`` are copied (first base that has the
    attribute) or merged in with ``include`` (later bases). The combiner
    mapping and the default combiner class are exposed to the created class
    as ``_ovld_combiners`` and ``_ovld_default_combiner``.
    """

    def __init__(self, bases, default_combiner=None):
        if default_combiner is None:
            # All bases must agree on one default combiner; otherwise this
            # unpacking raises ValueError.
            (default_combiner,) = {base._ovld_default_combiner for base in bases}
        super().__init__()
        combiners = {}
        self._combiners = combiners
        self._default_combiner = default_combiner
        self.set_direct("_ovld_combiners", combiners)
        self.set_direct("_ovld_default_combiner", default_combiner)
        self._basic = set()
        for base in bases:
            inherited = getattr(base, "_ovld_combiners", {})
            for attr, combiner in inherited.items():
                if attr not in combiners:
                    combiners[attr] = combiner.copy()
                else:
                    combiners[attr].include(combiner)

    def set_direct(self, attr, value):
        """Store ``value`` under ``attr`` without involving any combiner."""
        super().__setitem__(attr, value)

    def __setitem__(self, attr, value):
        if attr == "__annotations__":
            # Kept as a real entry so dataclass processing can read it.
            self.set_direct(attr, value)
        elif attr == "__init__":
            raise Exception("Do not define __init__ in a Medley, use __post_init__.")
        elif isinstance(value, Combiner):
            # An explicitly declared combiner takes over this attribute.
            value.__set_name__(None, attr)
            self._combiners[attr] = value
        else:
            combiner = self._combiners.get(attr)
            if combiner is None:
                is_impl = inspect.isfunction(value) or isinstance(value, Ovld)
                factory = self._default_combiner if is_impl else KeepLast
                combiner = factory(attr)
                self._combiners[attr] = combiner
            combiner.juxtapose(value)

    def __missing__(self, attr):
        # Lets a class body read the current combined value of an attribute.
        combiner = self._combiners.get(attr)
        if combiner is not None:
            value = combiner.get(None)
            if value is not ABSENT:
                return value
        raise KeyError(attr)
194
+
195
+
196
def codegen_key(*instances):
    """Collect the codegen-field values of ``instances`` into one dict.

    For each instance, read the attributes named in its class's
    ``_ovld_codegen_fields``. On name clashes later instances win, while the
    key keeps the position where it first appeared.
    """
    key = {}
    for inst in instances:
        for name in type(inst)._ovld_codegen_fields:
            key[name] = getattr(inst, name)
    return key
202
+
203
+
204
def specialize(cls, key):
    """Create a new subclass of ``cls`` with the ``key`` values baked in.

    Every ``key`` item becomes a class attribute of the new type, which
    records ``cls`` as its specialization parent. No caching happens here;
    callers cache the result.
    """
    namespace = medley_cls_dict((cls,))
    new_t = MedleyMC(cls.__name__, (cls,), namespace)
    new_t._ovld_specialization_parent = cls
    for attr, value in key.items():
        setattr(new_t, attr, value)
    # NOTE(review): this rewrites the *parent's* codegen field list, not
    # ``new_t``'s -- confirm that is intended.
    cls._ovld_codegen_fields = list(key)
    return new_t
212
+
213
+
214
+ def remap_field(dc_field, require_default=False):
215
+ if require_default:
216
+ if dc_field.default is MISSING:
217
+ # NOTE: we do not accept default_factory, because we need the default value to be set
218
+ # in the class so that existing instances of classes[0] can see it.
219
+ raise TypeError(
220
+ f"Dataclass field '{dc_field.name}' must have a default value (not a default_factory) in order to be melded in."
221
+ )
222
+ dc_field = copy(dc_field)
223
+ dc_field.kw_only = True
224
+ return dc_field
225
+
226
+
227
class MedleyMC(type):
    """Metaclass of ``Medley``: builds combiner-driven dataclasses.

    Class bodies are collected in a ``medley_cls_dict``. After the class is
    created, each combiner's value is installed on it and the class is turned
    into a dataclass. The metaclass also provides class-level ``+`` (meld),
    ``-`` (unmeld) and ``+=`` (in-place extend), and re-homes new instances on
    subclasses specialized on their codegen field values.
    """

    def __subclasscheck__(cls, subclass):
        # A melded class is a superclass of anything that subclasses all of
        # its component medleys.
        medleys = getattr(cls, "_ovld_medleys", None)
        if medleys:
            return all(issubclass(subclass, m) for m in medleys)
        return super().__subclasscheck__(subclass)

    @classmethod
    def __prepare__(mcls, name, bases, default_combiner=None):
        return medley_cls_dict(bases, default_combiner=default_combiner)

    def __new__(mcls, name, bases, namespace, default_combiner=None):
        result = super().__new__(mcls, name, bases, namespace)
        # Install the value of every combiner that produced one.
        for attr, combiner in result._ovld_combiners.items():
            if (value := combiner.get(result)) is not ABSENT:
                setattr(result, attr, value)
        dc = dataclass(result)
        dc._ovld_specialization_parent = None
        dc._ovld_specializations = {}
        # Fields whose annotation evaluates to ``Annotated[..., CODEGEN, ...]``.
        dc._ovld_codegen_fields = [
            field.name
            for field in fields(dc)
            if (
                (t := eval_annotation(field.type, dc, {}, catch=True))
                and get_origin(t) is Annotated
                and CODEGEN in t.__metadata__
            )
        ]
        return dc

    def extend(cls, *others, extend_subclasses=True):
        """Merge the fields and attributes of ``others`` into ``cls`` in place.

        Fields of ``others`` become keyword-only and must have plain defaults
        (see ``remap_field``). Fields whose name is already present (e.g.
        inherited from a common base) are not added twice. Attributes managed
        by one of ``cls``'s combiners are juxtaposed into that combiner; other
        public, non-dunder, non-``_ovld_`` attributes are copied over. Unless
        ``extend_subclasses`` is false, direct subclasses are extended with
        the ``others`` they do not already inherit from.

        Returns:
            ``cls`` itself.

        Raises:
            TypeError: from ``remap_field`` when an added field lacks a
                default.
        """
        if not others:  # pragma: no cover
            return cls
        all_fields = [(f.name, f.type, f) for f in fields(cls)]
        seen = {name for name, _, _ in all_fields}
        for other in others:
            for f in fields(other):
                # A duplicate name would make make_dataclass raise.
                if f.name in seen:
                    continue
                seen.add(f.name)
                all_fields.append((f.name, f.type, remap_field(f, True)))
        for other in others:
            for k, v in vars(other).items():
                if k in ["__module__", "__firstlineno__", "__static_attributes__"]:
                    continue
                elif comb := cls._ovld_combiners.get(k):
                    comb.juxtapose(v)
                    setattr(cls, k, comb.get(cls))
                elif not k.startswith("_ovld_") and not k.startswith("__"):
                    setattr(cls, k, v)
        # Build the replacement __init__ only now, once we know whether ``cls``
        # has a ``__post_init__`` (possibly contributed by ``others``). The
        # dataclass-generated __init__ calls ``self.__post_init__()`` only if
        # the class it is generated for has one, and ``_`` has no bases, so
        # without this the post-init hooks would silently stop running.
        namespace = None
        if hasattr(cls, "__post_init__"):
            # Only its presence matters: the call resolves on ``cls`` at runtime.
            namespace = {"__post_init__": cls.__post_init__}
        melded = make_dataclass("_", fields=all_fields, namespace=namespace)
        cls.__init__ = melded.__init__
        # NOTE(review): ``__dataclass_fields__`` is not updated, so
        # ``fields(cls)`` does not list the added fields.
        if extend_subclasses:
            # NOTE(review): only direct subclasses are visited (each with
            # extend_subclasses=False); confirm deeper descendants are meant
            # to be left alone.
            for subcls in cls.__subclasses__():
                subothers = [o for o in others if not issubclass(subcls, o)]
                subcls.extend(*subothers, extend_subclasses=False)
        return cls

    def __add__(cls, other):
        return meld_classes((cls, other))

    def __iadd__(cls, other):
        return cls.extend(other)

    def __sub__(cls, other):
        return unmeld_classes(cls, other)

    def __call__(cls, *args, **kwargs):
        made = super().__call__(*args, **kwargs)
        if cls._ovld_codegen_fields and (keyd := codegen_key(made)):
            # Move the instance onto a cached subclass specialized on its
            # codegen values (cache lives on the unspecialized parent).
            cls = cls._ovld_specialization_parent or cls
            key = tuple(sorted(keyd.items()))
            if key in cls._ovld_specializations:
                new_t = cls._ovld_specializations[key]
            else:
                new_t = specialize(cls, keyd)
                cls._ovld_specializations[key] = new_t
            obj = object.__new__(new_t)
            obj.__dict__.update(made.__dict__)
            return obj
        else:
            return made
303
+
304
+
305
def use_combiner(combiner):
    """Decorator factory: wrap a function into a fresh ``combiner`` instance.

    The combiner is built from the function's ``__name__`` and then given the
    function via ``juxtapose``. The decorator returns the combiner, so the
    class body binds the combiner instead of the function.
    """

    def deco(fn):
        instance = combiner(fn.__name__)
        instance.juxtapose(fn)
        return instance

    return deco
312
+
313
+
314
class Medley(metaclass=MedleyMC, default_combiner=BuildOvld):
    # Root of all medleys. Deliberately no docstring: class-body assignments
    # (``__doc__`` included) are routed to combiners by ``medley_cls_dict``,
    # so a docstring would add a ``__doc__`` combiner inherited everywhere.
    #
    # Functions default to ``BuildOvld``; every contributed ``__post_init__``
    # runs (``RunAll``); ``+`` and ``-`` keep the last definition.
    __post_init__ = RunAll()
    __add__ = KeepLast()
    __sub__ = KeepLast()

    def __add__(self, other):
        """Combine two medley instances.

        If ``self`` is an instance of ``other``'s type and has no codegen
        fields, return a copy of ``self`` updated with ``other``'s instance
        attributes; otherwise meld both into the combined class.
        """
        same_family = isinstance(self, type(other))
        if not same_family or type(self)._ovld_codegen_fields:
            return meld([self, other])
        return replace(self, **vars(other))

    def __sub__(self, other):
        """Return a copy of this melded instance without medley ``other``."""
        return unmeld(self, other)
327
+
328
+
329
def unmeld_classes(main: type, exclude: type):
    """Re-meld the direct bases of ``main``, leaving ``exclude`` out."""
    remaining = tuple(filter(lambda base: base is not exclude, main.__bases__))
    return meld_classes(remaining)
332
+
333
+
334
+ _meld_classes_cache = {}
335
+
336
+
337
def meld_classes(classes):
    """Return a class combining the medley classes in ``classes``.

    Already-melded classes contribute their component medleys. A component
    that is a base of another listed (non-melded) class is dropped. If one
    component remains it is returned as-is; otherwise a keyword-only
    dataclass deriving from all components is built and cached per tuple of
    components.
    """
    # dict used as an insertion-ordered set of component medleys.
    medleys = {}
    for cls in classes:
        medleys.update({x: True for x in getattr(cls, "_ovld_medleys", [cls])})
    for cls in classes:
        if not hasattr(cls, "_ovld_medleys"):
            for base in cls.mro():
                if base is not cls and base in medleys:
                    del medleys[base]
    medleys = tuple(medleys)
    if len(medleys) == 1:
        return medleys[0]

    cache_key = medleys
    if cache_key in _meld_classes_cache:
        return _meld_classes_cache[cache_key]

    cg_fields = set()
    dc_fields = []
    seen_fields = set()

    for base in medleys:
        cg_fields.update(base._ovld_codegen_fields)
        for f in base.__dataclass_fields__.values():
            # Components sharing a base list the same field; make_dataclass
            # rejects duplicate names, so keep the first occurrence (from the
            # component that comes first in the result's bases).
            if f.name in seen_fields:
                continue
            seen_fields.add(f.name)
            dc_fields.append((f.name, f.type, remap_field(f)))

    merged = medley_cls_dict(medleys)
    merged.set_direct("_ovld_codegen_fields", tuple(cg_fields))
    merged.set_direct("_ovld_medleys", tuple(medleys))

    # Drop any inherited ``__qualname__`` combiner (presumably so the melded
    # class keeps its own qualified name).
    merged._combiners.pop("__qualname__", None)

    result = make_dataclass(
        cls_name="+".join(sorted(c.__name__ for c in medleys)),
        bases=tuple(medleys),
        fields=dc_fields,
        kw_only=True,
        namespace=merged,
    )

    _meld_classes_cache[cache_key] = result
    return result
380
+
381
+
382
@functools.cache
def meld_classes_with_key(classes, key):
    """Return the melded class for ``classes``, specialized on ``key``.

    ``key`` is a hashable tuple of ``(name, value)`` pairs; an empty key
    yields the plain melded class. Results are memoized per arguments.
    """
    key_dict = dict(key)
    typ = meld_classes(classes)
    return specialize(typ, key_dict) if key_dict else typ
390
+
391
+
392
def meld(objects):
    """Combine medley instances into one instance of their melded class.

    Instance attributes are copied from each object in order, so later
    objects win on clashes. The result's ``__init__`` is not called.
    """
    key = codegen_key(*objects)
    melded_cls = meld_classes_with_key(tuple(map(type, objects)), tuple(key.items()))
    result = object.__new__(melded_cls)
    for source in objects:
        for name, value in vars(source).items():
            setattr(result, name, value)
    return result
401
+
402
+
403
def unmeld(obj: object, exclude: type):
    """Return an instance of ``obj``'s melded class with ``exclude`` removed.

    Fields declared by ``exclude`` are dropped; the remaining field values
    are passed to the reduced class's constructor.

    Raises:
        TypeError: if ``obj``'s class has codegen fields.
    """
    if type(obj)._ovld_codegen_fields:  # pragma: no cover
        raise TypeError("Cannot unmeld an object with codegen fields")
    cls = unmeld_classes(type(obj), exclude)
    excluded = exclude.__dataclass_fields__
    values = {
        f.name: getattr(obj, f.name)
        for f in cls.__dataclass_fields__.values()
        if f.name not in excluded
    }
    return cls(**values)
413
+
414
+
415
T = TypeVar("T")
# ``CodegenParameter[X]`` annotates a Medley field of type X as a codegen
# parameter; ``MedleyMC`` collects such fields into ``_ovld_codegen_fields``.
CodegenParameter = Annotated[T, CODEGEN]
ovld/mro.py CHANGED
@@ -145,9 +145,7 @@ def subclasscheck(t1, t2):
145
145
  args2 = get_args(t2)
146
146
  if len(args1) != len(args2):
147
147
  return False
148
- return all(
149
- subclasscheck(a1, a2) for a1, a2 in zip(args1, args2)
150
- )
148
+ return all(subclasscheck(a1, a2) for a1, a2 in zip(args1, args2))
151
149
  else:
152
150
  return False
153
151
  else: