ovld 0.3.9__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ovld/core.py CHANGED
@@ -2,16 +2,23 @@
2
2
 
3
3
  import inspect
4
4
  import itertools
5
- import math
6
5
  import sys
7
6
  import textwrap
8
7
  import typing
9
- from functools import partial
10
- from types import CodeType
11
-
12
- from .mro import compose_mro
13
- from .recode import Conformer, rename_function
14
- from .utils import BOOTSTRAP, MISSING, keyword_decorator
8
+ from collections import defaultdict
9
+ from dataclasses import dataclass, field, replace
10
+ from functools import cached_property, partial
11
+ from types import GenericAlias
12
+
13
+ from .recode import (
14
+ Conformer,
15
+ adapt_function,
16
+ generate_dispatch,
17
+ rename_function,
18
+ )
19
+ from .typemap import MultiTypeMap, is_type_of_type
20
+ from .types import normalize_type
21
+ from .utils import UsageError, keyword_decorator
15
22
 
16
23
  try:
17
24
  from types import UnionType
@@ -19,311 +26,265 @@ except ImportError: # pragma: no cover
19
26
  UnionType = None
20
27
 
21
28
 
22
- class GenericAliasMC(type):
23
- def __instancecheck__(cls, obj):
24
- return hasattr(obj, "__origin__")
29
+ _current_id = itertools.count()
25
30
 
26
31
 
27
- class GenericAlias(metaclass=GenericAliasMC):
28
- pass
32
+ def _fresh(t):
33
+ """Returns a new subclass of type t.
29
34
 
35
+ Each Ovld corresponds to its own class, which allows for specialization of
36
+ methods.
37
+ """
38
+ methods = {}
39
+ if not isinstance(getattr(t, "__doc__", None), str):
40
+ methods["__doc__"] = t.__doc__
41
+ return type(t.__name__, (t,), methods)
30
42
 
31
- def is_type_of_type(t):
32
- return getattr(t, "__origin__", None) is type
33
43
 
44
+ @keyword_decorator
45
+ def _setattrs(fn, **kwargs):
46
+ for k, v in kwargs.items():
47
+ setattr(fn, k, v)
48
+ return fn
34
49
 
35
- class TypeMap(dict):
36
- """Represents a mapping from types to handlers.
37
50
 
38
- The mro of a type is considered when getting the handler, so setting the
39
- [object] key creates a default for all objects.
51
+ @keyword_decorator
52
+ def _compile_first(fn, rename=None):
53
+ def first_entry(self, *args, **kwargs):
54
+ self.compile()
55
+ method = getattr(self, fn.__name__)
56
+ assert method is not first_entry
57
+ return method(*args, **kwargs)
40
58
 
41
- typemap[some_type] returns a tuple of a handler and a "level" that
42
- represents the distance from the handler to the type `object`. Essentially,
43
- the level is the index of the type for which the handler was registered
44
- in the mro of `some_type`. So for example, `object` has level 0, a class
45
- that inherits directly from `object` has level 1, and so on.
46
- """
59
+ first_entry._replace_by = fn
60
+ first_entry._rename = rename
61
+ return first_entry
47
62
 
48
- def __init__(self):
49
- self.entries = {}
50
- self.types = set()
51
63
 
52
- def register(self, obj_t, handler):
53
- """Register a handler for the given object type."""
54
- if isinstance(obj_t, str):
55
- obj_t = eval(obj_t, getattr(handler[0], "__globals__", {}))
64
+ def arg0_is_self(fn):
65
+ sgn = inspect.signature(fn)
66
+ params = list(sgn.parameters.values())
67
+ return params and params[0].name == "self"
56
68
 
57
- self.clear()
58
- if is_type_of_type(obj_t):
59
- self.types.add(obj_t.__args__[0])
60
- else:
61
- self.types.add(obj_t)
62
- s = self.entries.setdefault(obj_t, set())
63
- s.add(handler)
64
69
 
65
- def __missing__(self, obj_t):
66
- """Get the handler for the given type.
70
+ @dataclass(frozen=True)
71
+ class Arginfo:
72
+ position: typing.Optional[int]
73
+ name: typing.Optional[str]
74
+ required: bool
75
+ ann: type
67
76
 
68
- The result is cached so that the normal dict getitem will find it
69
- the next time getitem is called.
70
- """
71
- results = {}
72
- abscollect = set()
73
- if is_type_of_type(obj_t):
74
- mro = [
75
- type[t]
76
- for t in compose_mro(obj_t.__args__[0], self.types, abscollect)
77
- ]
78
- mro.append(type)
79
- mro.append(object)
80
- else:
81
- mro = compose_mro(obj_t, self.types, abscollect)
82
-
83
- lvl = -1
84
- prev_is_abstract = False
85
- for cls in reversed(mro):
86
- if cls not in abscollect or not prev_is_abstract:
87
- lvl += 1
88
- prev_is_abstract = cls in abscollect
89
- handlers = self.entries.get(cls, None)
90
- if handlers:
91
- results.update({h: lvl for h in handlers})
92
-
93
- if results:
94
- self[obj_t] = results
95
- return results
96
- else:
97
- raise KeyError(obj_t)
77
+ @cached_property
78
+ def is_complex(self):
79
+ return isinstance(self.ann, GenericAlias)
98
80
 
81
+ @cached_property
82
+ def canonical(self):
83
+ return self.name if self.position is None else self.position
99
84
 
100
- class MultiTypeMap(dict):
101
- """Represents a mapping from tuples of types to handlers.
102
85
 
103
- The mro is taken into account to find a match. If multiple registered
104
- handlers match the tuple of types that's given, if one of the handlers is
105
- more specific than every other handler, that handler is returned.
106
- Otherwise, the resolution is considered ambiguous and an error is raised.
86
+ @dataclass(frozen=True)
87
+ class Signature:
88
+ types: tuple
89
+ req_pos: int
90
+ max_pos: int
91
+ req_names: frozenset
92
+ vararg: bool
93
+ priority: float
94
+ is_method: bool = False
95
+ arginfo: list[Arginfo] = field(
96
+ default_factory=list, hash=False, compare=False
97
+ )
107
98
 
108
- Handler A, registered for types (A1, A2, ..., An), is more specific than
109
- handler B, registered for types (B1, B2, ..., Bn), if there exists n such
110
- that An is more specific than Bn, and for all n, either An == Bn or An is
111
- more specific than Bn. An is more specific than Bn if An is a direct or
112
- indirect subclass of Bn.
99
+ @classmethod
100
+ def extract(cls, fn):
101
+ typelist = []
102
+ sig = inspect.signature(fn)
103
+ max_pos = 0
104
+ req_pos = 0
105
+ req_names = set()
106
+ is_method = False
107
+
108
+ arginfo = []
109
+ for i, (name, param) in enumerate(sig.parameters.items()):
110
+ if name == "self":
111
+ assert i == 0
112
+ is_method = True
113
+ continue
114
+ pos = nm = None
115
+ ann = normalize_type(param.annotation, fn)
116
+ if param.kind is inspect._POSITIONAL_ONLY:
117
+ pos = i - is_method
118
+ typelist.append(ann)
119
+ req_pos += param.default is inspect._empty
120
+ max_pos += 1
121
+ elif param.kind is inspect._POSITIONAL_OR_KEYWORD:
122
+ pos = i - is_method
123
+ nm = param.name
124
+ typelist.append(ann)
125
+ req_pos += param.default is inspect._empty
126
+ max_pos += 1
127
+ elif param.kind is inspect._KEYWORD_ONLY:
128
+ nm = param.name
129
+ typelist.append((param.name, ann))
130
+ if param.default is inspect._empty:
131
+ req_names.add(param.name)
132
+ elif param.kind is inspect._VAR_POSITIONAL:
133
+ raise TypeError("ovld does not support *args")
134
+ elif param.kind is inspect._VAR_KEYWORD:
135
+ raise TypeError("ovld does not support **kwargs")
136
+ arginfo.append(
137
+ Arginfo(
138
+ position=pos,
139
+ name=nm,
140
+ required=param.default is inspect._empty,
141
+ ann=normalize_type(param.annotation, fn),
142
+ )
143
+ )
144
+
145
+ return cls(
146
+ types=tuple(typelist),
147
+ req_pos=req_pos,
148
+ max_pos=max_pos,
149
+ req_names=frozenset(req_names),
150
+ vararg=False,
151
+ is_method=is_method,
152
+ priority=None,
153
+ arginfo=arginfo,
154
+ )
113
155
 
114
- In other words, [int, object] is more specific than [object, object] and
115
- less specific than [int, int], but it is neither less specific nor more
116
- specific than [object, int] (which means there is an ambiguity).
117
- """
118
156
 
119
- def __init__(self, key_error=KeyError):
120
- self.maps = {}
121
- self.priorities = {}
122
- self.empty = MISSING
123
- self.key_error = key_error
124
- self.transform = type
125
- self.all = {}
126
- self.errors = {}
127
-
128
- def _transform(self, obj):
129
- if isinstance(obj, GenericAlias):
130
- return type[obj]
131
- elif obj is typing.Any:
132
- return type[object]
133
- elif isinstance(obj, type):
134
- return type[obj]
157
+ def clsstring(cls):
158
+ if cls is object:
159
+ return "*"
160
+ elif isinstance(cls, tuple):
161
+ key, typ = cls
162
+ return f"{key}: {clsstring(typ)}"
163
+ elif is_type_of_type(cls):
164
+ arg = clsstring(cls.__args__[0])
165
+ return f"type[{arg}]"
166
+ elif hasattr(cls, "__origin__"):
167
+ if cls.__origin__ is typing.Union:
168
+ return "|".join(map(clsstring, cls.__args__))
135
169
  else:
136
- return type(obj)
137
-
138
- def register(self, obj_t_tup, nargs, handler):
139
- """Register a handler for a tuple of argument types.
140
-
141
- Arguments:
142
- obj_t_tup: A tuple of argument types.
143
- nargs: A (amin, amax, varargs) tuple where amin is the minimum
144
- number of arguments needed to match this tuple (if there are
145
- default arguments, it is possible that amin < len(obj_t_tup)),
146
- amax is the maximum number of arguments, and varargs is a
147
- boolean indicating whether there can be an arbitrary number
148
- of arguments.
149
- handler: A function to handle the tuple.
150
- """
151
- self.clear()
152
-
153
- if any(isinstance(x, GenericAlias) for x in obj_t_tup):
154
- self.transform = self._transform
155
-
156
- amin, amax, vararg, priority = nargs
157
-
158
- entry = (handler, amin, amax, vararg)
159
- if not obj_t_tup:
160
- self.empty = entry
161
-
162
- self.priorities[handler] = priority
163
-
164
- for i, cls in enumerate(obj_t_tup):
165
- if i not in self.maps:
166
- self.maps[i] = TypeMap()
167
- self.maps[i].register(cls, entry)
168
- if vararg:
169
- if -1 not in self.maps:
170
- self.maps[-1] = TypeMap()
171
- self.maps[-1].register(object, entry)
172
-
173
- def resolve(self, obj_t_tup):
174
- specificities = {}
175
- candidates = None
176
- nargs = len(obj_t_tup)
177
-
178
- for i, cls in enumerate(obj_t_tup):
179
- try:
180
- results = self.maps[i][cls]
181
- except KeyError:
182
- results = {}
183
-
184
- results = {
185
- handler: spc
186
- for (handler, min, max, va), spc in results.items()
187
- if min <= nargs <= (math.inf if va else max)
188
- }
189
-
190
- try:
191
- vararg_results = self.maps[-1][cls]
192
- except KeyError:
193
- vararg_results = {}
194
-
195
- vararg_results = {
196
- handler: spc
197
- for (handler, min, max, va), spc in vararg_results.items()
198
- if min <= nargs and i >= max
199
- }
200
-
201
- results.update(vararg_results)
202
-
203
- if candidates is None:
204
- candidates = set(results.keys())
205
- else:
206
- candidates &= results.keys()
207
- for c in candidates:
208
- specificities.setdefault(c, []).append(results[c])
170
+ return repr(cls)
171
+ elif hasattr(cls, "__name__"):
172
+ return cls.__name__
173
+ else:
174
+ return repr(cls)
209
175
 
210
- if not candidates:
211
- raise self.key_error(obj_t_tup, ())
212
176
 
213
- candidates = [
214
- (c, self.priorities.get(c, 0), tuple(specificities[c]))
215
- for c in candidates
216
- ]
177
+ def sigstring(types):
178
+ return ", ".join(map(clsstring, types))
217
179
 
218
- # The sort ensures that if candidate A dominates candidate B, A will
219
- # appear before B in the list. That's because it must dominate all
220
- # other possibilities on all arguments, so the sum of all specificities
221
- # has to be greater.
222
- # Note: priority is always more important than specificity
223
- candidates.sort(key=lambda cspc: (cspc[1], sum(cspc[2])), reverse=True)
224
-
225
- self.all[obj_t_tup] = {
226
- getattr(c[0], "__code__", None) for c in candidates
227
- }
228
-
229
- def _pull(candidates):
230
- if not candidates:
231
- return
232
- rval = [candidates[0]]
233
- c1, p1, spc1 = candidates[0]
234
- for c2, p2, spc2 in candidates[1:]:
235
- if p1 > p2 or (
236
- spc1 != spc2 and all(s1 >= s2 for s1, s2 in zip(spc1, spc2))
237
- ):
238
- # Candidate 1 dominates candidate 2
239
- continue
240
- else:
241
- # Candidate 1 does not dominate candidate 2, so we add it
242
- # to the list.
243
- rval.append((c2, p2, spc2))
244
- yield rval
245
- if len(rval) == 1:
246
- # Only groups of length 1 are correct and reachable, so we don't
247
- # care about the rest.
248
- yield from _pull(candidates[1:])
249
-
250
- results = list(_pull(candidates))
251
- parent = None
252
- for group in results:
253
- tup = obj_t_tup if parent is None else (parent, *obj_t_tup)
254
- if len(group) != 1:
255
- self.errors[tup] = self.key_error(obj_t_tup, group)
256
- break
257
- else:
258
- ((fn, _, _),) = group
259
- self[tup] = fn
260
- if hasattr(fn, "__code__"):
261
- parent = fn.__code__
262
- else:
263
- break
264
-
265
- return True
266
-
267
- def __missing__(self, obj_t_tup):
268
- if obj_t_tup and isinstance(obj_t_tup[0], CodeType):
269
- real_tup = obj_t_tup[1:]
270
- self[real_tup]
271
- if obj_t_tup[0] not in self.all[real_tup]:
272
- return self[real_tup]
273
- elif obj_t_tup in self.errors:
274
- raise self.errors[obj_t_tup]
275
- elif obj_t_tup in self: # pragma: no cover
276
- # PROBABLY not reachable
277
- return self[obj_t_tup]
278
- else:
279
- raise self.key_error(real_tup, ())
280
180
 
281
- if not obj_t_tup:
282
- if self.empty is MISSING:
283
- raise self.key_error(obj_t_tup, ())
284
- else:
285
- return self.empty[0]
286
-
287
- self.resolve(obj_t_tup)
288
- if obj_t_tup in self.errors:
289
- raise self.errors[obj_t_tup]
290
- else:
291
- return self[obj_t_tup]
181
+ class ArgumentAnalyzer:
182
+ def __init__(self):
183
+ self.name_to_positions = defaultdict(set)
184
+ self.position_to_names = defaultdict(set)
185
+ self.counts = defaultdict(lambda: [0, 0])
186
+ self.complex_transforms = set()
187
+ self.total = 0
188
+ self.is_method = None
189
+
190
+ def add(self, fn):
191
+ sig = Signature.extract(fn)
192
+ self.complex_transforms.update(
193
+ arg.canonical for arg in sig.arginfo if arg.is_complex
194
+ )
195
+ for arg in sig.arginfo:
196
+ if arg.position is not None:
197
+ self.position_to_names[arg.position].add(arg.name)
198
+ if arg.name is not None:
199
+ self.name_to_positions[arg.name].add(arg.canonical)
200
+
201
+ cnt = self.counts[arg.canonical]
202
+ cnt[0] += arg.required
203
+ cnt[1] += 1
204
+
205
+ self.total += 1
206
+
207
+ if self.is_method is None:
208
+ self.is_method = sig.is_method
209
+ elif self.is_method != sig.is_method: # pragma: no cover
210
+ raise TypeError(
211
+ "Some, but not all registered methods define `self`. It should be all or none."
212
+ )
292
213
 
214
+ def compile(self):
215
+ for name, pos in self.name_to_positions.items():
216
+ if len(pos) != 1:
217
+ if all(isinstance(p, int) for p in pos):
218
+ raise TypeError(
219
+ f"Argument '{name}' is declared in different positions by different methods. The same argument name should always be in the same position unless it is strictly positional."
220
+ )
221
+ else:
222
+ raise TypeError(
223
+ f"Argument '{name}' is declared in a positional and keyword setting by different methods. It should be either."
224
+ )
293
225
 
294
- def _fresh(t):
295
- """Returns a new subclass of type t.
226
+ p_to_n = [
227
+ list(names) for _, names in sorted(self.position_to_names.items())
228
+ ]
296
229
 
297
- Each Ovld corresponds to its own class, which allows for specialization of
298
- methods.
299
- """
300
- return type(t.__name__, (t,), {})
230
+ positional = list(
231
+ itertools.takewhile(
232
+ lambda names: len(names) == 1 and isinstance(names[0], str),
233
+ reversed(p_to_n),
234
+ )
235
+ )
236
+ positional.reverse()
237
+ strict_positional = p_to_n[: len(p_to_n) - len(positional)]
301
238
 
239
+ assert strict_positional + positional == p_to_n
302
240
 
303
- @keyword_decorator
304
- def _setattrs(fn, **kwargs):
305
- for k, v in kwargs.items():
306
- setattr(fn, k, v)
307
- return fn
241
+ strict_positional_required = [
242
+ f"ARG{pos + 1}"
243
+ for pos, _ in enumerate(strict_positional)
244
+ if self.counts[pos][0] == self.total
245
+ ]
246
+ strict_positional_optional = [
247
+ f"ARG{pos + 1}"
248
+ for pos, _ in enumerate(strict_positional)
249
+ if self.counts[pos][0] != self.total
250
+ ]
308
251
 
252
+ positional_required = [
253
+ names[0]
254
+ for pos, names in enumerate(positional)
255
+ if self.counts[pos + len(strict_positional)][0] == self.total
256
+ ]
257
+ positional_optional = [
258
+ names[0]
259
+ for pos, names in enumerate(positional)
260
+ if self.counts[pos + len(strict_positional)][0] != self.total
261
+ ]
309
262
 
310
- @keyword_decorator
311
- def _compile_first(fn, rename=None):
312
- def deco(self, *args, **kwargs):
313
- self.compile()
314
- method = getattr(self, fn.__name__)
315
- assert method is not deco
316
- return method(*args, **kwargs)
263
+ keywords = [
264
+ name
265
+ for _, (name,) in self.name_to_positions.items()
266
+ if not isinstance(name, int)
267
+ ]
268
+ keyword_required = [
269
+ name for name in keywords if self.counts[name][0] == self.total
270
+ ]
271
+ keyword_optional = [
272
+ name for name in keywords if self.counts[name][0] != self.total
273
+ ]
317
274
 
318
- def setalt(alt):
319
- deco._alt = alt
320
- return None
275
+ return (
276
+ strict_positional_required,
277
+ strict_positional_optional,
278
+ positional_required,
279
+ positional_optional,
280
+ keyword_required,
281
+ keyword_optional,
282
+ )
321
283
 
322
- deco.setalt = setalt
323
- deco._replace_by = fn
324
- deco._alt = None
325
- deco._rename = rename
326
- return deco
284
+ def lookup_for(self, key):
285
+ return (
286
+ "self.map.transform" if key in self.complex_transforms else "type"
287
+ )
327
288
 
328
289
 
329
290
  class _Ovld:
@@ -334,19 +295,11 @@ class _Ovld:
334
295
  function should annotate the same parameter.
335
296
 
336
297
  Arguments:
337
- dispatch: A function to use as the entry point. It must find the
338
- function to dispatch to and call it.
339
- postprocess: A function to call on the return value. It is not called
340
- after recursive calls.
341
298
  mixins: A list of Ovld instances that contribute functions to this
342
299
  Ovld.
343
- type_error: The error type to raise when no function can be found to
344
- dispatch to (default: TypeError).
345
300
  name: Optional name for the Ovld. If not provided, it will be
346
301
  gotten automatically from the first registered function or
347
302
  dispatch.
348
- mapper: Class implementing a mapping interface from a tuple of
349
- types to a handler (default: MultiTypeMap).
350
303
  linkback: Whether to keep a pointer in the parent mixins to this
351
304
  ovld so that updates can be propagated. (default: False)
352
305
  allow_replacement: Allow replacing a method by another with the
@@ -356,43 +309,27 @@ class _Ovld:
356
309
  def __init__(
357
310
  self,
358
311
  *,
359
- dispatch=None,
360
- postprocess=None,
361
- type_error=TypeError,
362
312
  mixins=[],
363
313
  bootstrap=None,
364
314
  name=None,
365
- mapper=MultiTypeMap,
366
315
  linkback=False,
367
316
  allow_replacement=True,
368
317
  ):
369
318
  """Initialize an Ovld."""
319
+ self.id = next(_current_id)
370
320
  self._compiled = False
371
- self._dispatch = dispatch
372
- self.maindoc = None
373
- self.mapper = mapper
374
321
  self.linkback = linkback
375
322
  self.children = []
376
- self.type_error = type_error
377
- self.postprocess = postprocess
378
323
  self.allow_replacement = allow_replacement
379
- self.bootstrap_class = OvldCall
380
- if self.postprocess:
381
- assert bootstrap is not False
382
- self.bootstrap = True
383
- elif isinstance(bootstrap, type):
384
- self.bootstrap_class = bootstrap
385
- self.bootstrap = True
386
- else:
387
- self.bootstrap = bootstrap
324
+ self.bootstrap = bootstrap
388
325
  self.name = name
326
+ self.shortname = name or f"__OVLD{self.id}"
389
327
  self.__name__ = name
390
328
  self._defns = {}
391
329
  self._locked = False
392
330
  self.mixins = []
393
331
  self.add_mixins(*mixins)
394
- self.ocls = _fresh(self.bootstrap_class)
395
- self._make_signature()
332
+ self.ocls = _fresh(OvldCall)
396
333
 
397
334
  @property
398
335
  def defns(self):
@@ -402,6 +339,42 @@ class _Ovld:
402
339
  defns.update(self._defns)
403
340
  return defns
404
341
 
342
+ @property
343
+ def __doc__(self):
344
+ if not self._compiled:
345
+ self.compile()
346
+
347
+ docs = [fn.__doc__ for fn in self.defns.values() if fn.__doc__]
348
+ if len(docs) == 1:
349
+ maindoc = docs[0]
350
+ else:
351
+ maindoc = f"Ovld with {len(self.defns)} methods."
352
+
353
+ doc = f"{maindoc}\n\n"
354
+ for fn in self.defns.values():
355
+ fndef = inspect.signature(fn)
356
+ fdoc = fn.__doc__
357
+ if not fdoc or fdoc == maindoc:
358
+ doc += f"{self.__name__}{fndef}\n\n"
359
+ else:
360
+ if not fdoc.strip(" ").endswith("\n"):
361
+ fdoc += "\n"
362
+ fdoc = textwrap.indent(fdoc, " " * 4)
363
+ doc += f"{self.__name__}{fndef}\n{fdoc}\n"
364
+ return doc
365
+
366
+ @property
367
+ def __signature__(self):
368
+ if not self._compiled:
369
+ self.compile()
370
+
371
+ sig = inspect.signature(self._dispatch)
372
+ if not self.argument_analysis.is_method:
373
+ sig = inspect.Signature(
374
+ [v for k, v in sig.parameters.items() if k != "self"]
375
+ )
376
+ return sig
377
+
405
378
  def lock(self):
406
379
  self._locked = True
407
380
 
@@ -414,40 +387,21 @@ class _Ovld:
414
387
  for mixin in mixins:
415
388
  if self.linkback:
416
389
  mixin.children.append(self)
417
- if mixin._defns:
418
- assert mixin.bootstrap is not None
419
- if self.bootstrap is None:
420
- self.bootstrap = mixin.bootstrap
421
- assert mixin.bootstrap is self.bootstrap
422
- if mixin._dispatch:
423
- self._dispatch = mixin._dispatch
390
+ if mixin._defns and self.bootstrap is None:
391
+ self.bootstrap = mixin.bootstrap
424
392
  self.mixins += mixins
425
393
 
426
- def _sig_string(self, type_tuple):
427
- def clsname(cls):
428
- if cls is object:
429
- return "*"
430
- elif is_type_of_type(cls):
431
- arg = clsname(cls.__args__[0])
432
- return f"type[{arg}]"
433
- elif hasattr(cls, "__name__"):
434
- return cls.__name__
435
- else:
436
- return repr(cls)
437
-
438
- return ", ".join(map(clsname, type_tuple))
439
-
440
394
  def _key_error(self, key, possibilities=None):
441
- typenames = self._sig_string(key)
395
+ typenames = sigstring(key)
442
396
  if not possibilities:
443
- return self.type_error(
397
+ return TypeError(
444
398
  f"No method in {self} for argument types [{typenames}]"
445
399
  )
446
400
  else:
447
401
  hlp = ""
448
402
  for p, prio, spc in possibilities:
449
403
  hlp += f"* {p.__name__} (priority: {prio}, specificity: {list(spc)})\n"
450
- return self.type_error(
404
+ return TypeError(
451
405
  f"Ambiguous resolution in {self} for"
452
406
  f" argument types [{typenames}]\n"
453
407
  f"Candidates are:\n{hlp}"
@@ -458,58 +412,15 @@ class _Ovld:
458
412
  """Rename this Ovld."""
459
413
  self.name = name
460
414
  self.__name__ = name
461
- self._make_signature()
462
-
463
- def _make_signature(self):
464
- """Make the __doc__ and __signature__."""
465
415
 
466
- def modelA(*args, **kwargs): # pragma: no cover
467
- pass
468
-
469
- def modelB(self, *args, **kwargs): # pragma: no cover
470
- pass
471
-
472
- seen = set()
473
- doc = (
474
- f"{self.maindoc}\n"
475
- if self.maindoc
476
- else f"Ovld with {len(self.defns)} methods.\n\n"
477
- )
478
- for key, fn in self.defns.items():
479
- if fn in seen:
480
- continue
481
- seen.add(fn)
482
- fndef = inspect.signature(fn)
483
- fdoc = fn.__doc__
484
- if not fdoc or fdoc == self.maindoc:
485
- doc += f" ``{self.__name__}{fndef}``\n\n"
486
- else:
487
- if not fdoc.strip(" ").endswith("\n"):
488
- fdoc += "\n"
489
- fdoc = textwrap.indent(fdoc, " " * 8)
490
- doc += f" ``{self.__name__}{fndef}``\n{fdoc}\n"
491
- self.__doc__ = doc
492
- if self.bootstrap:
493
- self.__signature__ = inspect.signature(modelB)
494
- else:
495
- self.__signature__ = inspect.signature(modelA)
496
-
497
- def _set_attrs_from(self, fn, dispatch=False):
416
+ def _set_attrs_from(self, fn):
498
417
  """Inherit relevant attributes from the function."""
499
418
  if self.bootstrap is None:
500
- sign = inspect.signature(fn)
501
- params = list(sign.parameters.values())
502
- if not dispatch:
503
- if params and params[0].name == "self":
504
- self.bootstrap = True
505
- else:
506
- self.bootstrap = False
419
+ self.bootstrap = arg0_is_self(fn)
507
420
 
508
421
  if self.name is None:
509
422
  self.name = f"{fn.__module__}.{fn.__qualname__}"
510
- self.maindoc = fn.__doc__
511
- if self.maindoc and not self.maindoc.strip(" ").endswith("\n"):
512
- self.maindoc += "\n"
423
+ self.shortname = fn.__name__
513
424
  self.__name__ = fn.__name__
514
425
  self.__qualname__ = fn.__qualname__
515
426
  self.__module__ = fn.__module__
@@ -533,57 +444,51 @@ class _Ovld:
533
444
  for mixin in self.mixins:
534
445
  if self not in mixin.children:
535
446
  mixin.lock()
536
- self._compiled = True
537
- self.map = self.mapper(key_error=self._key_error)
538
447
 
539
448
  cls = type(self)
540
449
  if self.name is None:
541
- self.name = self.__name__ = f"ovld{id(self)}"
450
+ self.name = self.__name__ = f"ovld{self.id}"
542
451
 
543
452
  name = self.__name__
453
+ self.map = MultiTypeMap(name=name, key_error=self._key_error)
544
454
 
545
455
  # Replace the appropriate functions by their final behavior
546
456
  for method in dir(cls):
547
457
  value = getattr(cls, method)
548
458
  repl = getattr(value, "_replace_by", None)
549
459
  if repl:
550
- if self.bootstrap and value._alt:
551
- repl = value._alt
552
460
  repl = self._maybe_rename(repl)
553
461
  setattr(cls, method, repl)
554
462
 
555
463
  target = self.ocls if self.bootstrap else cls
556
- if self._dispatch:
557
- target.__call__ = self._dispatch
558
464
 
559
- # Rename the dispatch
560
- target.__call__ = rename_function(target.__call__, f"{name}.dispatch")
465
+ anal = ArgumentAnalyzer()
466
+ for key, fn in list(self.defns.items()):
467
+ anal.add(fn)
468
+ self.argument_analysis = anal
469
+ dispatch = generate_dispatch(anal)
470
+ self._dispatch = dispatch
471
+ target.__call__ = rename_function(dispatch, f"{name}.dispatch")
561
472
 
562
473
  for key, fn in list(self.defns.items()):
563
474
  self.register_signature(key, fn)
564
475
 
565
- def dispatch(self, dispatch):
566
- """Set a dispatch function."""
567
- if self._dispatch is not None:
568
- raise TypeError(f"dispatch for {self} is already set")
569
- self._dispatch = dispatch
570
- self._set_attrs_from(dispatch, dispatch=True)
571
- return self
476
+ self._compiled = True
572
477
 
478
+ @_compile_first
573
479
  def resolve(self, *args):
574
480
  """Find the correct method to call for the given arguments."""
575
481
  return self.map[tuple(map(self.map.transform, args))]
576
482
 
577
- def register_signature(self, key, orig_fn):
483
+ def register_signature(self, sig, orig_fn):
578
484
  """Register a function for the given signature."""
579
- sig, min, max, vararg, priority = key
580
- fn = rename_function(
581
- orig_fn, f"{self.__name__}[{self._sig_string(sig)}]"
485
+ fn = adapt_function(
486
+ orig_fn, self, f"{self.__name__}[{sigstring(sig.types)}]"
582
487
  )
583
488
  # We just need to keep the Conformer pointer alive for jurigged
584
489
  # to find it, if jurigged is used with ovld
585
490
  fn._conformer = Conformer(self, orig_fn, fn)
586
- self.map.register(sig, (min, max, vararg, priority), fn)
491
+ self.map.register(sig, fn)
587
492
  return self
588
493
 
589
494
  def register(self, fn=None, priority=0):
@@ -595,60 +500,17 @@ class _Ovld:
595
500
  def _register(self, fn, priority):
596
501
  """Register a function."""
597
502
 
598
- def _normalize_type(t, force_tuple=False):
599
- origin = getattr(t, "__origin__", None)
600
- if UnionType and isinstance(t, UnionType):
601
- return _normalize_type(t.__args__)
602
- elif origin is type:
603
- return (t,) if force_tuple else t
604
- elif origin is typing.Union:
605
- return _normalize_type(t.__args__)
606
- elif origin is not None:
607
- raise TypeError(
608
- f"ovld does not accept generic types except type, Union or Optional, not {t}"
609
- )
610
- elif isinstance(t, tuple):
611
- return tuple(_normalize_type(t2) for t2 in t)
612
- elif force_tuple:
613
- return (t,)
614
- else:
615
- return t
616
-
617
503
  self._attempt_modify()
618
504
 
619
505
  self._set_attrs_from(fn)
620
506
 
621
- ann = fn.__annotations__
622
- argspec = inspect.getfullargspec(fn)
623
- argnames = argspec.args
624
- if self.bootstrap:
625
- if argnames[0] != "self":
626
- raise TypeError(
627
- "The first argument of the function must be named `self`"
628
- )
629
- argnames = argnames[1:]
630
-
631
- typelist = []
632
- for i, name in enumerate(argnames):
633
- t = ann.get(name, None)
634
- if t is None:
635
- typelist.append(object)
636
- else:
637
- typelist.append(t)
638
-
639
- max_pos = len(argnames)
640
- req_pos = max_pos - len(argspec.defaults or ())
641
-
642
- typelist_tups = tuple(
643
- _normalize_type(t, force_tuple=True) for t in typelist
644
- )
645
- for tl in itertools.product(*typelist_tups):
646
- sig = (tuple(tl), req_pos, max_pos, bool(argspec.varargs), priority)
647
- if not self.allow_replacement and sig in self._defns:
648
- raise TypeError(f"There is already a method for {tl}")
649
- self._defns[(*sig,)] = fn
507
+ sig = replace(Signature.extract(fn), priority=priority)
508
+ if not self.allow_replacement and sig in self._defns:
509
+ raise TypeError(
510
+ f"There is already a method for {sigstring(sig.types)}"
511
+ )
512
+ self._defns[sig] = fn
650
513
 
651
- self._make_signature()
652
514
  self._update()
653
515
  return self
654
516
 
@@ -664,13 +526,7 @@ class _Ovld:
664
526
  for child in self.children:
665
527
  child._update()
666
528
 
667
- def copy(
668
- self,
669
- dispatch=MISSING,
670
- postprocess=None,
671
- mixins=[],
672
- linkback=False,
673
- ):
529
+ def copy(self, mixins=[], linkback=False):
674
530
  """Create a copy of this Ovld.
675
531
 
676
532
  New functions can be registered to the copy without affecting the
@@ -678,13 +534,11 @@ class _Ovld:
678
534
  """
679
535
  return _fresh(_Ovld)(
680
536
  bootstrap=self.bootstrap,
681
- dispatch=self._dispatch if dispatch is MISSING else dispatch,
682
537
  mixins=[self, *mixins],
683
- postprocess=postprocess or self.postprocess,
684
538
  linkback=linkback,
685
539
  )
686
540
 
687
- def variant(self, fn=None, **kwargs):
541
+ def variant(self, fn=None, priority=0, **kwargs):
688
542
  """Decorator to create a variant of this Ovld.
689
543
 
690
544
  New functions can be registered to the variant without affecting the
@@ -692,73 +546,58 @@ class _Ovld:
692
546
  """
693
547
  ov = self.copy(**kwargs)
694
548
  if fn is None:
695
- return ov.register
549
+ return partial(ov.register, priority=priority)
696
550
  else:
697
- ov.register(fn)
551
+ ov.register(fn, priority=priority)
698
552
  return ov
699
553
 
700
- @_compile_first
701
- def get_map(self):
702
- return self.map
703
-
704
554
  @_compile_first
705
555
  def __get__(self, obj, cls):
706
556
  if obj is None:
707
557
  return self
708
- return self.ocls(
709
- map=self.map,
710
- bind_to=obj,
711
- super=self.mixins[0] if len(self.mixins) == 1 else None,
712
- )
558
+ key = self.shortname
559
+ rval = obj.__dict__.get(key, None)
560
+ if rval is None:
561
+ obj.__dict__[key] = rval = self.ocls(self, obj)
562
+ return rval
713
563
 
714
564
  @_compile_first
715
565
  def __getitem__(self, t):
716
566
  if not isinstance(t, tuple):
717
567
  t = (t,)
718
- assert not self.bootstrap
719
568
  return self.map[t]
720
569
 
721
570
  @_compile_first
722
571
  @_setattrs(rename="dispatch")
723
- def __call__(self, *args, **kwargs):
572
+ def __call__(self, *args): # pragma: no cover
724
573
  """Call the overloaded function.
725
574
 
726
- This version of __call__ is used when bootstrap is False.
727
-
728
- If bootstrap is False and a dispatch function is provided, it
729
- replaces this function.
575
+ This should be replaced by an auto-generated function.
730
576
  """
731
577
  key = tuple(map(self.map.transform, args))
732
578
  method = self.map[key]
733
- return method(*args, **kwargs)
579
+ return method(*args)
734
580
 
735
581
  @_compile_first
736
582
  @_setattrs(rename="next")
737
- def next(self, *args, **kwargs):
583
+ def next(self, *args):
738
584
  """Call the next matching method after the caller, in terms of priority or specificity."""
739
585
  fr = sys._getframe(1)
740
586
  key = (fr.f_code, *map(self.map.transform, args))
741
587
  method = self.map[key]
742
- return method(*args, **kwargs)
743
-
744
- @__call__.setalt
745
- @_setattrs(rename="entry")
746
- def __ovldcall__(self, *args, **kwargs):
747
- """Call the overloaded function.
748
-
749
- This version of __call__ is used when bootstrap is True. This function is
750
- only called once at the entry point: recursive calls will will be to
751
- OvldCall.__call__.
752
- """
753
- ovc = self.__get__(BOOTSTRAP, None)
754
- res = ovc(*args, **kwargs)
755
- if self.postprocess:
756
- res = self.postprocess(self, res)
757
- return res
588
+ return method(*args)
758
589
 
759
590
  def __repr__(self):
760
591
  return f"<Ovld {self.name or hex(id(self))}>"
761
592
 
593
+ @_compile_first
594
+ def display_methods(self):
595
+ self.map.display_methods()
596
+
597
+ @_compile_first
598
+ def display_resolution(self, *args, **kwargs):
599
+ self.map.display_resolution(*args, **kwargs)
600
+
762
601
 
763
602
  def is_ovld(x):
764
603
  """Return whether the argument is an ovld function/method."""
@@ -768,47 +607,43 @@ def is_ovld(x):
768
607
  class OvldCall:
769
608
  """Context for an Ovld call."""
770
609
 
771
- def __init__(self, map, bind_to, super=None):
610
+ def __init__(self, ovld, bind_to):
772
611
  """Initialize an OvldCall."""
773
- self.map = map
774
- self._parent = super
775
- self.obj = self if bind_to is BOOTSTRAP else bind_to
612
+ self.ovld = ovld
613
+ self.map = ovld.map
614
+ self.obj = bind_to
776
615
 
777
- def __getitem__(self, t):
778
- """Find the right method to call given a tuple of types."""
779
- if not isinstance(t, tuple):
780
- t = (t,)
781
- return self.map[t].__get__(self.obj)
616
+ @property
617
+ def __name__(self):
618
+ return self.ovld.__name__
782
619
 
783
- def next(self, *args, **kwargs):
620
+ @property
621
+ def __doc__(self):
622
+ return self.ovld.__doc__
623
+
624
+ @property
625
+ def __signature__(self):
626
+ return self.ovld.__signature__
627
+
628
+ def next(self, *args):
784
629
  """Call the next matching method after the caller, in terms of priority or specificity."""
785
630
  fr = sys._getframe(1)
786
631
  key = (fr.f_code, *map(self.map.transform, args))
787
632
  method = self.map[key]
788
- return method(self.obj, *args, **kwargs)
789
-
790
- def super(self, *args, **kwargs):
791
- """Use the parent ovld's method for this call."""
792
- pmap = self._parent.get_map()
793
- method = pmap[tuple(map(pmap.transform, args))]
794
- return method.__get__(self.obj)(*args, **kwargs)
633
+ return method(self.obj, *args)
795
634
 
796
635
  def resolve(self, *args):
797
636
  """Find the right method to call for the given arguments."""
798
- return self[tuple(map(self.map.transform, args))]
799
-
800
- def call(self, *args):
801
- """Call the right method for the given arguments."""
802
- return self[tuple(map(self.map.transform, args))](*args)
637
+ return self.map[tuple(map(self.map.transform, args))].__get__(self.obj)
803
638
 
804
- def __call__(self, *args, **kwargs):
639
+ def __call__(self, *args): # pragma: no cover
805
640
  """Call this overloaded function.
806
641
 
807
- If a dispatch function is provided, it replaces this function.
642
+ This should be replaced by an auto-generated function.
808
643
  """
809
644
  key = tuple(map(self.map.transform, args))
810
645
  method = self.map[key]
811
- return method(self.obj, *args, **kwargs)
646
+ return method(self.obj, *args)
812
647
 
813
648
 
814
649
  def Ovld(*args, **kwargs):
@@ -823,7 +658,7 @@ def extend_super(fn):
823
658
  plus this definition and others with the same name.
824
659
  """
825
660
  if not is_ovld(fn):
826
- fn = ovld(fn)
661
+ fn = ovld(fn, fresh=True)
827
662
  fn._extend_super = True
828
663
  return fn
829
664
 
@@ -835,29 +670,45 @@ class ovld_cls_dict(dict):
835
670
  """
836
671
 
837
672
  def __init__(self, bases):
838
- self._mock = type("MockSuper", bases, {})
673
+ self._bases = bases
839
674
 
840
675
  def __setitem__(self, attr, value):
676
+ prev = None
841
677
  if attr in self:
842
678
  prev = self[attr]
679
+ if inspect.isfunction(prev):
680
+ prev = ovld(prev, fresh=True)
681
+ elif not is_ovld(prev): # pragma: no cover
682
+ prev = None
843
683
  elif is_ovld(value) and getattr(value, "_extend_super", False):
844
- prev = getattr(self._mock, attr, None)
845
- if is_ovld(prev):
846
- prev = prev.copy()
684
+ mixins = []
685
+ for base in self._bases:
686
+ if (candidate := getattr(base, attr, None)) is not None:
687
+ if is_ovld(candidate) or inspect.isfunction(candidate):
688
+ mixins.append(candidate)
689
+ if mixins:
690
+ prev, *others = mixins
691
+ if is_ovld(prev):
692
+ prev = prev.copy()
693
+ else:
694
+ prev = ovld(prev, fresh=True)
695
+ for other in others:
696
+ if is_ovld(other):
697
+ prev.add_mixins(other)
698
+ else:
699
+ prev.register(other)
847
700
  else:
848
701
  prev = None
849
702
 
850
703
  if prev is not None:
851
- if inspect.isfunction(prev):
852
- prev = ovld(prev)
853
-
854
- if is_ovld(prev):
855
- if is_ovld(value):
856
- prev.add_mixins(value)
857
- value = prev
858
- elif inspect.isfunction(value):
859
- prev.register(value)
860
- value = prev
704
+ if is_ovld(value) and prev is not value:
705
+ if prev.name is None:
706
+ prev.rename(value.name)
707
+ prev.add_mixins(value)
708
+ value = prev
709
+ elif inspect.isfunction(value):
710
+ prev.register(value)
711
+ value = prev
861
712
 
862
713
  super().__setitem__(attr, value)
863
714
 
@@ -905,19 +756,31 @@ class OvldBase(metaclass=OvldMC):
905
756
 
906
757
 
907
758
  def _find_overload(fn, **kwargs):
908
- mod = __import__(fn.__module__, fromlist="_")
909
- dispatch = getattr(mod, fn.__qualname__, None)
759
+ fr = sys._getframe(1) # We typically expect to get to frame 3.
760
+ while fr and fn.__code__ not in fr.f_code.co_consts:
761
+ # We are basically searching for the function's code object in the stack.
762
+ # When a class/function A is nested in a class/function B, the former's
763
+ # code object is in the latter's co_consts. If ovld is used as a decorator,
764
+ # on A, then necessarily we are inside the execution of B, so B should be
765
+ # on the stack and we should be able to find A's code object in there.
766
+ fr = fr.f_back
767
+
768
+ if not fr:
769
+ raise UsageError("@ovld only works as a decorator.")
770
+
771
+ dispatch = fr.f_locals.get(fn.__name__, None)
772
+
910
773
  if dispatch is None:
911
774
  dispatch = _fresh(_Ovld)(**kwargs)
775
+ elif not is_ovld(dispatch): # pragma: no cover
776
+ raise TypeError("@ovld requires Ovld instance")
912
777
  elif kwargs: # pragma: no cover
913
778
  raise TypeError("Cannot configure an overload that already exists")
914
- if not is_ovld(dispatch): # pragma: no cover
915
- raise TypeError("@ovld requires Ovld instance")
916
779
  return dispatch
917
780
 
918
781
 
919
782
  @keyword_decorator
920
- def ovld(fn, priority=0, **kwargs):
783
+ def ovld(fn, priority=0, fresh=False, **kwargs):
921
784
  """Overload a function.
922
785
 
923
786
  Overloading is based on the function name.
@@ -931,71 +794,29 @@ def ovld(fn, priority=0, **kwargs):
931
794
 
932
795
  Arguments:
933
796
  fn: The function to register.
934
- dispatch: A function to use as the entry point. It must find the
935
- function to dispatch to and call it.
936
- postprocess: A function to call on the return value. It is not called
937
- after recursive calls.
797
+ priority: The priority of the function in the resolution order.
798
+ fresh: Whether to create a new ovld or try to reuse an existing one.
938
799
  mixins: A list of Ovld instances that contribute functions to this
939
800
  Ovld.
940
- type_error: The error type to raise when no function can be found to
941
- dispatch to (default: TypeError).
942
801
  name: Optional name for the Ovld. If not provided, it will be
943
802
  gotten automatically from the first registered function or
944
803
  dispatch.
945
- mapper: Class implementing a mapping interface from a tuple of
946
- types to a handler (default: MultiTypeMap).
947
804
  linkback: Whether to keep a pointer in the parent mixins to this
948
805
  ovld so that updates can be propagated. (default: False)
949
806
  """
950
- dispatch = _find_overload(fn, **kwargs)
807
+ if fresh:
808
+ dispatch = _fresh(_Ovld)(**kwargs)
809
+ else:
810
+ dispatch = _find_overload(fn, **kwargs)
951
811
  return dispatch.register(fn, priority=priority)
952
812
 
953
813
 
954
- @keyword_decorator
955
- def ovld_dispatch(dispatch, **kwargs):
956
- """Overload a function using the decorated function as a dispatcher.
957
-
958
- The dispatch is the entry point for the function and receives a `self`
959
- which is an Ovld or OvldCall instance, and the rest of the arguments.
960
- It may call `self.resolve(arg1, arg2, ...)` to get the right method to
961
- call.
962
-
963
- The decorator optionally takes keyword arguments, *only* on the first
964
- use.
965
-
966
- Arguments:
967
- dispatch: The function to use as the entry point. It must find the
968
- function to dispatch to and call it.
969
- postprocess: A function to call on the return value. It is not called
970
- after recursive calls.
971
- mixins: A list of Ovld instances that contribute functions to this
972
- Ovld.
973
- type_error: The error type to raise when no function can be found to
974
- dispatch to (default: TypeError).
975
- name: Optional name for the Ovld. If not provided, it will be
976
- gotten automatically from the first registered function or
977
- dispatch.
978
- mapper: Class implementing a mapping interface from a tuple of
979
- types to a handler (default: MultiTypeMap).
980
- linkback: Whether to keep a pointer in the parent mixins to this
981
- ovld so that updates can be propagated. (default: False)
982
- """
983
- ov = _find_overload(dispatch, **kwargs)
984
- return ov.dispatch(dispatch)
985
-
986
-
987
- ovld.dispatch = ovld_dispatch
988
-
989
-
990
814
  __all__ = [
991
- "MultiTypeMap",
992
815
  "Ovld",
993
816
  "OvldBase",
994
817
  "OvldCall",
995
818
  "OvldMC",
996
- "TypeMap",
997
819
  "extend_super",
998
820
  "is_ovld",
999
821
  "ovld",
1000
- "ovld_dispatch",
1001
822
  ]