ovld 0.3.8__py3-none-any.whl → 0.3.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ovld/core.py CHANGED
@@ -3,9 +3,15 @@
3
3
  import inspect
4
4
  import itertools
5
5
  import math
6
+ import sys
6
7
  import textwrap
7
8
  import typing
8
- from types import FunctionType
9
+ from functools import partial
10
+ from types import CodeType
11
+
12
+ from .mro import compose_mro
13
+ from .recode import Conformer, rename_function
14
+ from .utils import BOOTSTRAP, MISSING, keyword_decorator
9
15
 
10
16
  try:
11
17
  from types import UnionType
@@ -13,20 +19,13 @@ except ImportError: # pragma: no cover
13
19
  UnionType = None
14
20
 
15
21
 
16
- try:
17
- from types import GenericAlias
18
- except ImportError: # pragma: no cover
22
+ class GenericAliasMC(type):
23
+ def __instancecheck__(cls, obj):
24
+ return hasattr(obj, "__origin__")
19
25
 
20
- class GenericAliasMC(type):
21
- def __instancecheck__(cls, obj):
22
- return hasattr(obj, "__origin__")
23
26
 
24
- class GenericAlias(metaclass=GenericAliasMC):
25
- pass
26
-
27
-
28
- from .mro import compose_mro
29
- from .utils import BOOTSTRAP, MISSING, keyword_decorator
27
+ class GenericAlias(metaclass=GenericAliasMC):
28
+ pass
30
29
 
31
30
 
32
31
  def is_type_of_type(t):
@@ -70,13 +69,23 @@ class TypeMap(dict):
70
69
  the next time getitem is called.
71
70
  """
72
71
  results = {}
72
+ abscollect = set()
73
73
  if is_type_of_type(obj_t):
74
- mro = [type[t] for t in compose_mro(obj_t.__args__[0], self.types)]
74
+ mro = [
75
+ type[t]
76
+ for t in compose_mro(obj_t.__args__[0], self.types, abscollect)
77
+ ]
75
78
  mro.append(type)
76
79
  mro.append(object)
77
80
  else:
78
- mro = compose_mro(obj_t, self.types)
79
- for lvl, cls in enumerate(reversed(mro)):
81
+ mro = compose_mro(obj_t, self.types, abscollect)
82
+
83
+ lvl = -1
84
+ prev_is_abstract = False
85
+ for cls in reversed(mro):
86
+ if cls not in abscollect or not prev_is_abstract:
87
+ lvl += 1
88
+ prev_is_abstract = cls in abscollect
80
89
  handlers = self.entries.get(cls, None)
81
90
  if handlers:
82
91
  results.update({h: lvl for h in handlers})
@@ -109,12 +118,16 @@ class MultiTypeMap(dict):
109
118
 
110
119
  def __init__(self, key_error=KeyError):
111
120
  self.maps = {}
121
+ self.priorities = {}
112
122
  self.empty = MISSING
113
123
  self.key_error = key_error
124
+ self.transform = type
125
+ self.all = {}
126
+ self.errors = {}
114
127
 
115
- def transform(self, obj):
128
+ def _transform(self, obj):
116
129
  if isinstance(obj, GenericAlias):
117
- return type[obj.__origin__]
130
+ return type[obj]
118
131
  elif obj is typing.Any:
119
132
  return type[object]
120
133
  elif isinstance(obj, type):
@@ -137,30 +150,31 @@ class MultiTypeMap(dict):
137
150
  """
138
151
  self.clear()
139
152
 
140
- amin, amax, vararg = nargs
153
+ if any(isinstance(x, GenericAlias) for x in obj_t_tup):
154
+ self.transform = self._transform
155
+
156
+ amin, amax, vararg, priority = nargs
141
157
 
142
158
  entry = (handler, amin, amax, vararg)
143
159
  if not obj_t_tup:
144
160
  self.empty = entry
145
161
 
162
+ self.priorities[handler] = priority
163
+
146
164
  for i, cls in enumerate(obj_t_tup):
147
- tm = self.maps.setdefault(i, TypeMap())
148
- tm.register(cls, entry)
165
+ if i not in self.maps:
166
+ self.maps[i] = TypeMap()
167
+ self.maps[i].register(cls, entry)
149
168
  if vararg:
150
- tm = self.maps.setdefault(-1, TypeMap())
151
- tm.register(object, entry)
169
+ if -1 not in self.maps:
170
+ self.maps[-1] = TypeMap()
171
+ self.maps[-1].register(object, entry)
152
172
 
153
- def __missing__(self, obj_t_tup):
173
+ def resolve(self, obj_t_tup):
154
174
  specificities = {}
155
175
  candidates = None
156
176
  nargs = len(obj_t_tup)
157
177
 
158
- if not obj_t_tup:
159
- if self.empty is MISSING:
160
- raise self.key_error(obj_t_tup, ())
161
- else:
162
- return self.empty[0]
163
-
164
178
  for i, cls in enumerate(obj_t_tup):
165
179
  try:
166
180
  results = self.maps[i][cls]
@@ -196,36 +210,85 @@ class MultiTypeMap(dict):
196
210
  if not candidates:
197
211
  raise self.key_error(obj_t_tup, ())
198
212
 
199
- candidates = [(c, tuple(specificities[c])) for c in candidates]
213
+ candidates = [
214
+ (c, self.priorities.get(c, 0), tuple(specificities[c]))
215
+ for c in candidates
216
+ ]
200
217
 
201
218
  # The sort ensures that if candidate A dominates candidate B, A will
202
219
  # appear before B in the list. That's because it must dominate all
203
220
  # other possibilities on all arguments, so the sum of all specificities
204
221
  # has to be greater.
205
- candidates.sort(key=lambda cspc: sum(cspc[1]), reverse=True)
206
-
207
- results = [candidates[0]]
208
- for c2, spc2 in candidates[1:]:
209
- # Evaluate candidate 2
210
- for c1, spc1 in results:
211
- if spc1 != spc2 and all(s1 >= s2 for s1, s2 in zip(spc1, spc2)):
212
- # Candidate 1 dominates candidate 2, therefore we can
213
- # reject candidate 2.
222
+ # Note: priority is always more important than specificity
223
+ candidates.sort(key=lambda cspc: (cspc[1], sum(cspc[2])), reverse=True)
224
+
225
+ self.all[obj_t_tup] = {
226
+ getattr(c[0], "__code__", None) for c in candidates
227
+ }
228
+
229
+ def _pull(candidates):
230
+ if not candidates:
231
+ return
232
+ rval = [candidates[0]]
233
+ c1, p1, spc1 = candidates[0]
234
+ for c2, p2, spc2 in candidates[1:]:
235
+ if p1 > p2 or (
236
+ spc1 != spc2 and all(s1 >= s2 for s1, s2 in zip(spc1, spc2))
237
+ ):
238
+ # Candidate 1 dominates candidate 2
239
+ continue
240
+ else:
241
+ # Candidate 1 does not dominate candidate 2, so we add it
242
+ # to the list.
243
+ rval.append((c2, p2, spc2))
244
+ yield rval
245
+ if len(rval) == 1:
246
+ # Only groups of length 1 are correct and reachable, so we don't
247
+ # care about the rest.
248
+ yield from _pull(candidates[1:])
249
+
250
+ results = list(_pull(candidates))
251
+ parent = None
252
+ for group in results:
253
+ tup = obj_t_tup if parent is None else (parent, *obj_t_tup)
254
+ if len(group) != 1:
255
+ self.errors[tup] = self.key_error(obj_t_tup, group)
256
+ break
257
+ else:
258
+ ((fn, _, _),) = group
259
+ self[tup] = fn
260
+ if hasattr(fn, "__code__"):
261
+ parent = fn.__code__
262
+ else:
214
263
  break
264
+
265
+ return True
266
+
267
+ def __missing__(self, obj_t_tup):
268
+ if obj_t_tup and isinstance(obj_t_tup[0], CodeType):
269
+ real_tup = obj_t_tup[1:]
270
+ self[real_tup]
271
+ if obj_t_tup[0] not in self.all[real_tup]:
272
+ return self[real_tup]
273
+ elif obj_t_tup in self.errors:
274
+ raise self.errors[obj_t_tup]
275
+ elif obj_t_tup in self: # pragma: no cover
276
+ # PROBABLY not reachable
277
+ return self[obj_t_tup]
215
278
  else:
216
- # Candidate 2 cannot be dominated, so we move it to the results
217
- # list
218
- results.append((c2, spc2))
219
-
220
- if len(results) != 1:
221
- # No candidate dominates all the others => key_error
222
- # As second argument, we provide the minimal set of candidates
223
- # that no other candidate can dominate
224
- raise self.key_error(obj_t_tup, [result for result, _ in results])
279
+ raise self.key_error(real_tup, ())
280
+
281
+ if not obj_t_tup:
282
+ if self.empty is MISSING:
283
+ raise self.key_error(obj_t_tup, ())
284
+ else:
285
+ return self.empty[0]
286
+
287
+ self.resolve(obj_t_tup)
288
+ if obj_t_tup in self.errors:
289
+ raise self.errors[obj_t_tup]
225
290
  else:
226
- ((result, _),) = results
227
- self[obj_t_tup] = result
228
- return result
291
+ return self[obj_t_tup]
229
292
 
230
293
 
231
294
  def _fresh(t):
@@ -273,8 +336,6 @@ class _Ovld:
273
336
  Arguments:
274
337
  dispatch: A function to use as the entry point. It must find the
275
338
  function to dispatch to and call it.
276
- initial_state: A function returning the initial state, or None if
277
- there is no state.
278
339
  postprocess: A function to call on the return value. It is not called
279
340
  after recursive calls.
280
341
  mixins: A list of Ovld instances that contribute functions to this
@@ -296,7 +357,6 @@ class _Ovld:
296
357
  self,
297
358
  *,
298
359
  dispatch=None,
299
- initial_state=None,
300
360
  postprocess=None,
301
361
  type_error=TypeError,
302
362
  mixins=[],
@@ -314,11 +374,10 @@ class _Ovld:
314
374
  self.linkback = linkback
315
375
  self.children = []
316
376
  self.type_error = type_error
317
- self.initial_state = initial_state
318
377
  self.postprocess = postprocess
319
378
  self.allow_replacement = allow_replacement
320
379
  self.bootstrap_class = OvldCall
321
- if self.initial_state or self.postprocess:
380
+ if self.postprocess:
322
381
  assert bootstrap is not False
323
382
  self.bootstrap = True
324
383
  elif isinstance(bootstrap, type):
@@ -381,17 +440,18 @@ class _Ovld:
381
440
  def _key_error(self, key, possibilities=None):
382
441
  typenames = self._sig_string(key)
383
442
  if not possibilities:
384
- raise self.type_error(
443
+ return self.type_error(
385
444
  f"No method in {self} for argument types [{typenames}]"
386
445
  )
387
446
  else:
388
447
  hlp = ""
389
- for p in possibilities:
390
- hlp += f"* {p.__name__}\n"
391
- raise self.type_error(
448
+ for p, prio, spc in possibilities:
449
+ hlp += f"* {p.__name__} (priority: {prio}, specificity: {list(spc)})\n"
450
+ return self.type_error(
392
451
  f"Ambiguous resolution in {self} for"
393
452
  f" argument types [{typenames}]\n"
394
- "Candidates are:\n" + hlp
453
+ f"Candidates are:\n{hlp}"
454
+ "Note: you can use @ovld(priority=X) to give higher priority to an overload."
395
455
  )
396
456
 
397
457
  def rename(self, name):
@@ -492,15 +552,12 @@ class _Ovld:
492
552
  repl = self._maybe_rename(repl)
493
553
  setattr(cls, method, repl)
494
554
 
555
+ target = self.ocls if self.bootstrap else cls
495
556
  if self._dispatch:
496
- if self.bootstrap:
497
- self.ocls.__call__ = self._dispatch
498
- else:
499
- cls.__call__ = self._dispatch
557
+ target.__call__ = self._dispatch
500
558
 
501
559
  # Rename the dispatch
502
- if self._dispatch:
503
- self._dispatch = rename_function(self._dispatch, f"{name}.dispatch")
560
+ target.__call__ = rename_function(target.__call__, f"{name}.dispatch")
504
561
 
505
562
  for key, fn in list(self.defns.items()):
506
563
  self.register_signature(key, fn)
@@ -519,17 +576,23 @@ class _Ovld:
519
576
 
520
577
  def register_signature(self, key, orig_fn):
521
578
  """Register a function for the given signature."""
522
- sig, min, max, vararg = key
579
+ sig, min, max, vararg, priority = key
523
580
  fn = rename_function(
524
581
  orig_fn, f"{self.__name__}[{self._sig_string(sig)}]"
525
582
  )
526
583
  # We just need to keep the Conformer pointer alive for jurigged
527
584
  # to find it, if jurigged is used with ovld
528
585
  fn._conformer = Conformer(self, orig_fn, fn)
529
- self.map.register(sig, (min, max, vararg), fn)
586
+ self.map.register(sig, (min, max, vararg, priority), fn)
530
587
  return self
531
588
 
532
- def register(self, fn):
589
+ def register(self, fn=None, priority=0):
590
+ """Register a function."""
591
+ if fn is None:
592
+ return partial(self._register, priority=priority)
593
+ return self._register(fn, priority)
594
+
595
+ def _register(self, fn, priority):
533
596
  """Register a function."""
534
597
 
535
598
  def _normalize_type(t, force_tuple=False):
@@ -542,7 +605,7 @@ class _Ovld:
542
605
  return _normalize_type(t.__args__)
543
606
  elif origin is not None:
544
607
  raise TypeError(
545
- "ovld does not accept generic types except type, Union or Optional"
608
+ f"ovld does not accept generic types except type, Union or Optional, not {t}"
546
609
  )
547
610
  elif isinstance(t, tuple):
548
611
  return tuple(_normalize_type(t2) for t2 in t)
@@ -580,7 +643,7 @@ class _Ovld:
580
643
  _normalize_type(t, force_tuple=True) for t in typelist
581
644
  )
582
645
  for tl in itertools.product(*typelist_tups):
583
- sig = (tuple(tl), req_pos, max_pos, bool(argspec.varargs))
646
+ sig = (tuple(tl), req_pos, max_pos, bool(argspec.varargs), priority)
584
647
  if not self.allow_replacement and sig in self._defns:
585
648
  raise TypeError(f"There is already a method for {tl}")
586
649
  self._defns[(*sig,)] = fn
@@ -604,7 +667,6 @@ class _Ovld:
604
667
  def copy(
605
668
  self,
606
669
  dispatch=MISSING,
607
- initial_state=None,
608
670
  postprocess=None,
609
671
  mixins=[],
610
672
  linkback=False,
@@ -618,7 +680,6 @@ class _Ovld:
618
680
  bootstrap=self.bootstrap,
619
681
  dispatch=self._dispatch if dispatch is MISSING else dispatch,
620
682
  mixins=[self, *mixins],
621
- initial_state=initial_state or self.initial_state,
622
683
  postprocess=postprocess or self.postprocess,
623
684
  linkback=linkback,
624
685
  )
@@ -640,21 +701,12 @@ class _Ovld:
640
701
  def get_map(self):
641
702
  return self.map
642
703
 
643
- @_compile_first
644
- def instantiate(self, **state):
645
- return self.ocls(map=self.map, state=state, bind_to=BOOTSTRAP)
646
-
647
704
  @_compile_first
648
705
  def __get__(self, obj, cls):
649
706
  if obj is None:
650
707
  return self
651
- if self.initial_state is None or isinstance(self.initial_state, dict):
652
- state = self.initial_state
653
- else:
654
- state = self.initial_state()
655
708
  return self.ocls(
656
709
  map=self.map,
657
- state=state,
658
710
  bind_to=obj,
659
711
  super=self.mixins[0] if len(self.mixins) == 1 else None,
660
712
  )
@@ -680,14 +732,22 @@ class _Ovld:
680
732
  method = self.map[key]
681
733
  return method(*args, **kwargs)
682
734
 
735
+ @_compile_first
736
+ @_setattrs(rename="next")
737
+ def next(self, *args, **kwargs):
738
+ """Call the next matching method after the caller, in terms of priority or specificity."""
739
+ fr = sys._getframe(1)
740
+ key = (fr.f_code, *map(self.map.transform, args))
741
+ method = self.map[key]
742
+ return method(*args, **kwargs)
743
+
683
744
  @__call__.setalt
684
745
  @_setattrs(rename="entry")
685
746
  def __ovldcall__(self, *args, **kwargs):
686
747
  """Call the overloaded function.
687
748
 
688
- This version of __call__ is used when bootstrap is True. It creates an
689
- OvldCall instance to contain the state. This function is only called
690
- once at the entry point: recursive calls will will be to
749
+ This version of __call__ is used when bootstrap is True. This function is
750
+ only called once at the entry point: recursive calls will will be to
691
751
  OvldCall.__call__.
692
752
  """
693
753
  ovc = self.__get__(BOOTSTRAP, None)
@@ -708,12 +768,10 @@ def is_ovld(x):
708
768
  class OvldCall:
709
769
  """Context for an Ovld call."""
710
770
 
711
- def __init__(self, map, state, bind_to, super=None):
771
+ def __init__(self, map, bind_to, super=None):
712
772
  """Initialize an OvldCall."""
713
773
  self.map = map
714
774
  self._parent = super
715
- if state is not None:
716
- self.__dict__.update(state)
717
775
  self.obj = self if bind_to is BOOTSTRAP else bind_to
718
776
 
719
777
  def __getitem__(self, t):
@@ -722,6 +780,13 @@ class OvldCall:
722
780
  t = (t,)
723
781
  return self.map[t].__get__(self.obj)
724
782
 
783
+ def next(self, *args, **kwargs):
784
+ """Call the next matching method after the caller, in terms of priority or specificity."""
785
+ fr = sys._getframe(1)
786
+ key = (fr.f_code, *map(self.map.transform, args))
787
+ method = self.map[key]
788
+ return method(self.obj, *args, **kwargs)
789
+
725
790
  def super(self, *args, **kwargs):
726
791
  """Use the parent ovld's method for this call."""
727
792
  pmap = self._parent.get_map()
@@ -736,10 +801,6 @@ class OvldCall:
736
801
  """Call the right method for the given arguments."""
737
802
  return self[tuple(map(self.map.transform, args))](*args)
738
803
 
739
- def with_state(self, **state):
740
- """Return a new OvldCall using the given state."""
741
- return type(self)(self.map, state, BOOTSTRAP)
742
-
743
804
  def __call__(self, *args, **kwargs):
744
805
  """Call this overloaded function.
745
806
 
@@ -856,7 +917,7 @@ def _find_overload(fn, **kwargs):
856
917
 
857
918
 
858
919
  @keyword_decorator
859
- def ovld(fn, **kwargs):
920
+ def ovld(fn, priority=0, **kwargs):
860
921
  """Overload a function.
861
922
 
862
923
  Overloading is based on the function name.
@@ -872,8 +933,6 @@ def ovld(fn, **kwargs):
872
933
  fn: The function to register.
873
934
  dispatch: A function to use as the entry point. It must find the
874
935
  function to dispatch to and call it.
875
- initial_state: A function returning the initial state, or None if
876
- there is no state.
877
936
  postprocess: A function to call on the return value. It is not called
878
937
  after recursive calls.
879
938
  mixins: A list of Ovld instances that contribute functions to this
@@ -889,7 +948,7 @@ def ovld(fn, **kwargs):
889
948
  ovld so that updates can be propagated. (default: False)
890
949
  """
891
950
  dispatch = _find_overload(fn, **kwargs)
892
- return dispatch.register(fn)
951
+ return dispatch.register(fn, priority=priority)
893
952
 
894
953
 
895
954
  @keyword_decorator
@@ -907,8 +966,6 @@ def ovld_dispatch(dispatch, **kwargs):
907
966
  Arguments:
908
967
  dispatch: The function to use as the entry point. It must find the
909
968
  function to dispatch to and call it.
910
- initial_state: A function returning the initial state, or None if
911
- there is no state.
912
969
  postprocess: A function to call on the return value. It is not called
913
970
  after recursive calls.
914
971
  mixins: A list of Ovld instances that contribute functions to this
@@ -930,82 +987,6 @@ def ovld_dispatch(dispatch, **kwargs):
930
987
  ovld.dispatch = ovld_dispatch
931
988
 
932
989
 
933
- class Conformer:
934
- __slots__ = ("code", "orig_fn", "renamed_fn", "ovld", "code2")
935
-
936
- def __init__(self, ovld, orig_fn, renamed_fn):
937
- self.ovld = ovld
938
- self.orig_fn = orig_fn
939
- self.renamed_fn = renamed_fn
940
- self.code = orig_fn.__code__
941
-
942
- def __conform__(self, new):
943
- if isinstance(new, FunctionType):
944
- new_fn = new
945
- new_code = new.__code__
946
- else:
947
- new_fn = None
948
- new_code = new
949
-
950
- if new_code is None:
951
- self.ovld.unregister(self.orig_fn)
952
-
953
- elif new_fn is None: # pragma: no cover
954
- # Not entirely sure if this ever happens
955
- self.renamed_fn.__code__ = new_code
956
-
957
- elif inspect.signature(self.orig_fn) != inspect.signature(new_fn):
958
- self.ovld.unregister(self.orig_fn)
959
- self.ovld.register(new_fn)
960
-
961
- else:
962
- self.renamed_fn.__code__ = rename_code(
963
- new_code, self.renamed_fn.__code__.co_name
964
- )
965
-
966
- from codefind import code_registry
967
-
968
- code_registry.update_cache_entry(self, self.code, new_code)
969
-
970
- self.code = new_code
971
-
972
-
973
- def rename_code(co, newname): # pragma: no cover
974
- if hasattr(co, "replace"):
975
- if hasattr(co, "co_qualname"):
976
- return co.replace(co_name=newname, co_qualname=newname)
977
- else:
978
- return co.replace(co_name=newname)
979
- else:
980
- return type(co)(
981
- co.co_argcount,
982
- co.co_kwonlyargcount,
983
- co.co_nlocals,
984
- co.co_stacksize,
985
- co.co_flags,
986
- co.co_code,
987
- co.co_consts,
988
- co.co_names,
989
- co.co_varnames,
990
- co.co_filename,
991
- newname,
992
- co.co_firstlineno,
993
- co.co_lnotab,
994
- co.co_freevars,
995
- co.co_cellvars,
996
- )
997
-
998
-
999
- def rename_function(fn, newname):
1000
- """Create a copy of the function with a different name."""
1001
- newcode = rename_code(fn.__code__, newname)
1002
- new_fn = FunctionType(
1003
- newcode, fn.__globals__, newname, fn.__defaults__, fn.__closure__
1004
- )
1005
- new_fn.__annotations__ = fn.__annotations__
1006
- return new_fn
1007
-
1008
-
1009
990
  __all__ = [
1010
991
  "MultiTypeMap",
1011
992
  "Ovld",
ovld/mro.py CHANGED
@@ -6,6 +6,41 @@
6
6
  # assume they have already been properly tested.
7
7
 
8
8
 
9
+ def _issubclass(c1, c2):
10
+ c1 = getattr(c1, "__proxy_for__", c1)
11
+ c2 = getattr(c2, "__proxy_for__", c2)
12
+ if hasattr(c1, "__origin__") or hasattr(c2, "__origin__"):
13
+ o1 = getattr(c1, "__origin__", c1)
14
+ o2 = getattr(c2, "__origin__", c2)
15
+ if issubclass(o1, o2):
16
+ if o2 is c2: # pragma: no cover
17
+ return True
18
+ else:
19
+ args1 = getattr(c1, "__args__", ())
20
+ args2 = getattr(c2, "__args__", ())
21
+ if len(args1) != len(args2):
22
+ return False
23
+ return all(_issubclass(a1, a2) for a1, a2 in zip(args1, args2))
24
+ else:
25
+ return False
26
+ else:
27
+ return issubclass(c1, c2)
28
+
29
+
30
+ def _mro(cls):
31
+ if hasattr(cls, "__mro__"):
32
+ return set(cls.__mro__)
33
+ elif hasattr(cls, "__origin__"):
34
+ return set(cls.__origin__.__mro__)
35
+ else: # pragma: no cover
36
+ return {cls}
37
+
38
+
39
+ def _subclasses(cls):
40
+ scf = getattr(cls, "__subclasses__", ())
41
+ return scf and scf()
42
+
43
+
9
44
  def _c3_merge(sequences): # pragma: no cover
10
45
  """Merges MROs in *sequences* to a single MRO using the C3 algorithm.
11
46
  Adapted from http://www.python.org/download/releases/2.3/mro/.
@@ -32,7 +67,7 @@ def _c3_merge(sequences): # pragma: no cover
32
67
  del seq[0]
33
68
 
34
69
 
35
- def _c3_mro(cls, abcs=None): # pragma: no cover
70
+ def _c3_mro(cls, abcs=None, abscollect=None): # pragma: no cover
36
71
  """Computes the method resolution order using extended C3 linearization.
37
72
  If no *abcs* are given, the algorithm works exactly like the built-in C3
38
73
  linearization used for method resolution.
@@ -46,28 +81,40 @@ def _c3_mro(cls, abcs=None): # pragma: no cover
46
81
  MRO of said class. If two implicit ABCs end up next to each other in the
47
82
  resulting MRO, their ordering depends on the order of types in *abcs*.
48
83
  """
49
- for i, base in enumerate(reversed(cls.__bases__)):
84
+ bases = getattr(cls, "__bases__", ())
85
+ if hasattr(cls, "__origin__"):
86
+ bases = [cls.__origin__, *bases]
87
+ for i, base in enumerate(reversed(bases)):
50
88
  if hasattr(base, "__abstractmethods__"):
51
- boundary = len(cls.__bases__) - i
89
+ boundary = len(bases) - i
52
90
  break # Bases up to the last explicit ABC are considered first.
53
91
  else:
54
92
  boundary = 0
55
93
  abcs = list(abcs) if abcs else []
56
- explicit_bases = list(cls.__bases__[:boundary])
94
+ explicit_bases = list(bases[:boundary])
57
95
  abstract_bases = []
58
- other_bases = list(cls.__bases__[boundary:])
96
+ other_bases = list(bases[boundary:])
59
97
  for base in abcs:
60
- if issubclass(cls, base) and not any(
61
- issubclass(b, base) for b in cls.__bases__
98
+ if _issubclass(cls, base) and not any(
99
+ _issubclass(b, base) for b in bases
62
100
  ):
63
101
  # If *cls* is the class that introduces behaviour described by
64
102
  # an ABC *base*, insert said ABC to its MRO.
65
103
  abstract_bases.append(base)
66
104
  for base in abstract_bases:
67
105
  abcs.remove(base)
68
- explicit_c3_mros = [_c3_mro(base, abcs=abcs) for base in explicit_bases]
69
- abstract_c3_mros = [_c3_mro(base, abcs=abcs) for base in abstract_bases]
70
- other_c3_mros = [_c3_mro(base, abcs=abcs) for base in other_bases]
106
+ abscollect.update(abstract_bases)
107
+ explicit_c3_mros = [
108
+ _c3_mro(base, abcs=abcs, abscollect=abscollect)
109
+ for base in explicit_bases
110
+ ]
111
+ abstract_c3_mros = [
112
+ _c3_mro(base, abcs=abcs, abscollect=abscollect)
113
+ for base in abstract_bases
114
+ ]
115
+ other_c3_mros = [
116
+ _c3_mro(base, abcs=abcs, abscollect=abscollect) for base in other_bases
117
+ ]
71
118
  return _c3_merge(
72
119
  [[cls]]
73
120
  + explicit_c3_mros
@@ -79,20 +126,16 @@ def _c3_mro(cls, abcs=None): # pragma: no cover
79
126
  )
80
127
 
81
128
 
82
- def compose_mro(cls, types): # pragma: no cover
129
+ def compose_mro(cls, types, abscollect): # pragma: no cover
83
130
  """Calculates the method resolution order for a given class *cls*.
84
131
  Includes relevant abstract base classes (with their respective bases) from
85
132
  the *types* iterable. Uses a modified C3 linearization algorithm.
86
133
  """
87
- bases = set(cls.__mro__)
134
+ bases = _mro(cls)
88
135
 
89
136
  # Remove entries which are already present in the __mro__ or unrelated.
90
137
  def is_related(typ):
91
- return (
92
- typ not in bases
93
- and hasattr(typ, "__mro__")
94
- and issubclass(cls, typ)
95
- )
138
+ return typ not in bases and _issubclass(cls, typ)
96
139
 
97
140
  types = [n for n in types if is_related(n)]
98
141
 
@@ -100,7 +143,7 @@ def compose_mro(cls, types): # pragma: no cover
100
143
  # in the MRO anyway.
101
144
  def is_strict_base(typ):
102
145
  for other in types:
103
- if typ != other and typ in other.__mro__:
146
+ if typ != other and typ in _mro(other):
104
147
  return True
105
148
  return False
106
149
 
@@ -110,9 +153,11 @@ def compose_mro(cls, types): # pragma: no cover
110
153
  type_set = set(types)
111
154
  mro = []
112
155
  for typ in types:
156
+ if typ is object:
157
+ continue
113
158
  found = []
114
- for sub in typ.__subclasses__():
115
- if sub not in bases and issubclass(cls, sub):
159
+ for sub in _subclasses(typ):
160
+ if sub not in bases and _issubclass(cls, sub):
116
161
  found.append([s for s in sub.__mro__ if s in type_set])
117
162
  if not found:
118
163
  mro.append(typ)
@@ -123,4 +168,4 @@ def compose_mro(cls, types): # pragma: no cover
123
168
  for subcls in sub:
124
169
  if subcls not in mro:
125
170
  mro.append(subcls)
126
- return _c3_mro(cls, abcs=mro)
171
+ return _c3_mro(cls, abcs=mro, abscollect=abscollect)
ovld/recode.py ADDED
@@ -0,0 +1,80 @@
1
+ import inspect
2
+ from types import FunctionType
3
+
4
+ recurse = object()
5
+
6
+
7
+ class Conformer:
8
+ __slots__ = ("code", "orig_fn", "renamed_fn", "ovld", "code2")
9
+
10
+ def __init__(self, ovld, orig_fn, renamed_fn):
11
+ self.ovld = ovld
12
+ self.orig_fn = orig_fn
13
+ self.renamed_fn = renamed_fn
14
+ self.code = orig_fn.__code__
15
+
16
+ def __conform__(self, new):
17
+ if isinstance(new, FunctionType):
18
+ new_fn = new
19
+ new_code = new.__code__
20
+ else:
21
+ new_fn = None
22
+ new_code = new
23
+
24
+ if new_code is None:
25
+ self.ovld.unregister(self.orig_fn)
26
+
27
+ elif new_fn is None: # pragma: no cover
28
+ # Not entirely sure if this ever happens
29
+ self.renamed_fn.__code__ = new_code
30
+
31
+ elif inspect.signature(self.orig_fn) != inspect.signature(new_fn):
32
+ self.ovld.unregister(self.orig_fn)
33
+ self.ovld.register(new_fn)
34
+
35
+ else:
36
+ self.renamed_fn.__code__ = rename_code(
37
+ new_code, self.renamed_fn.__code__.co_name
38
+ )
39
+
40
+ from codefind import code_registry
41
+
42
+ code_registry.update_cache_entry(self, self.code, new_code)
43
+
44
+ self.code = new_code
45
+
46
+
47
+ def rename_code(co, newname): # pragma: no cover
48
+ if hasattr(co, "replace"):
49
+ if hasattr(co, "co_qualname"):
50
+ return co.replace(co_name=newname, co_qualname=newname)
51
+ else:
52
+ return co.replace(co_name=newname)
53
+ else:
54
+ return type(co)(
55
+ co.co_argcount,
56
+ co.co_kwonlyargcount,
57
+ co.co_nlocals,
58
+ co.co_stacksize,
59
+ co.co_flags,
60
+ co.co_code,
61
+ co.co_consts,
62
+ co.co_names,
63
+ co.co_varnames,
64
+ co.co_filename,
65
+ newname,
66
+ co.co_firstlineno,
67
+ co.co_lnotab,
68
+ co.co_freevars,
69
+ co.co_cellvars,
70
+ )
71
+
72
+
73
+ def rename_function(fn, newname):
74
+ """Create a copy of the function with a different name."""
75
+ newcode = rename_code(fn.__code__, newname)
76
+ new_fn = FunctionType(
77
+ newcode, fn.__globals__, newname, fn.__defaults__, fn.__closure__
78
+ )
79
+ new_fn.__annotations__ = fn.__annotations__
80
+ return new_fn
ovld/version.py CHANGED
@@ -1 +1 @@
1
- version = "0.3.8"
1
+ version = "0.3.9"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: ovld
3
- Version: 0.3.8
3
+ Version: 0.3.9
4
4
  Summary: Overloading Python functions
5
5
  Author-email: Olivier Breuleux <breuleux@gmail.com>
6
6
  License-Expression: MIT
@@ -19,7 +19,7 @@ Other features of `ovld`:
19
19
 
20
20
  * Multiple dispatch for methods (with `metaclass=ovld.OvldMC`)
21
21
  * Create variants of functions
22
- * Built-in support for extensible, stateful recursion
22
+ * Built-in support for extensible recursion
23
23
  * Function wrappers
24
24
  * Function postprocessors
25
25
  * Nice stack traces
@@ -84,22 +84,6 @@ assert mul([1, 2], [3, 4]) == [3, 8]
84
84
 
85
85
  A `variant` of a function is a copy which inherits all of the original's implementations but may define new ones. And because `self` is bound to the function that's called at the top level, the implementations for `list`, `tuple` and `dict` will bind `self` to `add` or `mul` depending on which one was called. You may also call `self.super(*args)` to invoke the parent implementation for that type.
86
86
 
87
- ## State
88
-
89
- You can pass `initial_state` to `@ovld` or `variant`. The initial state must be a function that takes no arguments. Its return value will be available in `self.state`. The state is initialized at the top level call, but recursive calls to `self` will preserve it.
90
-
91
- In other words, you can do something like this:
92
-
93
- ```python
94
- @add.variant(initial_state=lambda: 0)
95
- def count(self, x, y):
96
- self.state += 1
97
- return (f"#{self.state}", x + y)
98
-
99
- assert count([1, 2, 3], [4, 5, 6]) == [("#1", 5), ("#2", 7), ("#3", 9)]
100
- ```
101
-
102
- The initial_state function can return any object and you can use the state to any purpose (e.g. cache or memoization).
103
87
 
104
88
  ## Custom dispatch
105
89
 
@@ -134,7 +118,7 @@ assert add_default([1, 2, "alouette"]) == [2, 4, "alouettealouette"]
134
118
 
135
119
  There are other uses for this feature, e.g. memoization.
136
120
 
137
- The normal functions may also have a `self`, which works the same as bootstrapping, and you can give an `initial_state` to `@ovld.dispatch` as well.
121
+ The normal functions may also have a `self`, which works the same as bootstrapping.
138
122
 
139
123
  ## Postprocess
140
124
 
@@ -0,0 +1,10 @@
1
+ ovld/__init__.py,sha256=jqq7r0RhfvkqmQa8Wc7wqX3Qh5NbwL_6Nfr1sp88Kh0,734
2
+ ovld/core.py,sha256=3JmnfLF6puVtWkc01Hnxuss4gVdWzX9mrLVMfysQqJY,32895
3
+ ovld/mro.py,sha256=W6T5IPec6YQIvv8oSwg-xGHTzZP1zOuTZuaRCegbl-I,6296
4
+ ovld/recode.py,sha256=45VmTUmk2GmWRfTLANMHhyzVcgBit2RvwULzyWpOMhI,2234
5
+ ovld/utils.py,sha256=8kINtHrWdC0bNVccxYIkHBaGFk_YFpjtnjAa9GyXNmA,3861
6
+ ovld/version.py,sha256=8WhSz7JLKtk-44zKbF-sqJmexEzvpM-d0cPHqr6cE5k,18
7
+ ovld-0.3.9.dist-info/METADATA,sha256=wbN2H6NflCIkt-OZJvzrVXWNIz8bH0FTow83Ir0C8NM,8282
8
+ ovld-0.3.9.dist-info/WHEEL,sha256=1yFddiXMmvYK7QYTqtRNtX66WJ0Mz8PYEiEUoOUUxRY,87
9
+ ovld-0.3.9.dist-info/licenses/LICENSE,sha256=cSwNTIzd1cbI89xt3PeZZYJP2y3j8Zus4bXgo4svpX8,1066
10
+ ovld-0.3.9.dist-info/RECORD,,
@@ -1,9 +0,0 @@
1
- ovld/__init__.py,sha256=jqq7r0RhfvkqmQa8Wc7wqX3Qh5NbwL_6Nfr1sp88Kh0,734
2
- ovld/core.py,sha256=B5j-k5SRPFd1l_zM6XkIK-jWN9FCtQVhtj4URI4oepQ,33210
3
- ovld/mro.py,sha256=Dof3wj-BLcV-T-Jm9tK6zqptp_JGo5lxgyRq99AvY1k,5006
4
- ovld/utils.py,sha256=8kINtHrWdC0bNVccxYIkHBaGFk_YFpjtnjAa9GyXNmA,3861
5
- ovld/version.py,sha256=C1bt0pFrmukDC9wIVB6iyGiXY-vasUa0lCdJBB8wwjk,18
6
- ovld-0.3.8.dist-info/METADATA,sha256=zdZJPKFRdMBgfHuUDzquDhsVDQyadJT3PuXhMokjIHk,9006
7
- ovld-0.3.8.dist-info/WHEEL,sha256=1yFddiXMmvYK7QYTqtRNtX66WJ0Mz8PYEiEUoOUUxRY,87
8
- ovld-0.3.8.dist-info/licenses/LICENSE,sha256=cSwNTIzd1cbI89xt3PeZZYJP2y3j8Zus4bXgo4svpX8,1066
9
- ovld-0.3.8.dist-info/RECORD,,
File without changes