ovld 0.4.6__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ovld/core.py CHANGED
@@ -4,11 +4,9 @@ import inspect
4
4
  import itertools
5
5
  import sys
6
6
  import textwrap
7
- import typing
8
- from collections import OrderedDict, defaultdict
9
- from dataclasses import dataclass, field, replace
10
- from functools import cached_property, partial
11
- from types import FunctionType, GenericAlias
7
+ from dataclasses import replace
8
+ from functools import partial
9
+ from types import FunctionType
12
10
 
13
11
  from .recode import (
14
12
  Conformer,
@@ -16,97 +14,20 @@ from .recode import (
16
14
  generate_dispatch,
17
15
  rename_code,
18
16
  )
17
+ from .signatures import ArgumentAnalyzer, LazySignature, Signature
19
18
  from .typemap import MultiTypeMap
20
- from .types import clsstring, normalize_type
21
- from .utils import MISSING, UsageError, keyword_decorator, subtler_type
19
+ from .utils import (
20
+ MISSING,
21
+ ResolutionError,
22
+ UsageError,
23
+ keyword_decorator,
24
+ sigstring,
25
+ subtler_type,
26
+ )
22
27
 
23
28
  _current_id = itertools.count()
24
29
 
25
30
 
26
- @keyword_decorator
27
- def _setattrs(fn, **kwargs):
28
- for k, v in kwargs.items():
29
- setattr(fn, k, v)
30
- return fn
31
-
32
-
33
- class LazySignature(inspect.Signature):
34
- def __init__(self, ovld):
35
- super().__init__([])
36
- self.ovld = ovld
37
-
38
- def replace(
39
- self, *, parameters=inspect._void, return_annotation=inspect._void
40
- ): # pragma: no cover
41
- if parameters is inspect._void:
42
- parameters = self.parameters.values()
43
-
44
- if return_annotation is inspect._void:
45
- return_annotation = self._return_annotation
46
-
47
- return inspect.Signature(
48
- parameters, return_annotation=return_annotation
49
- )
50
-
51
- @property
52
- def parameters(self):
53
- anal = self.ovld.analyze_arguments()
54
- parameters = []
55
- if anal.is_method:
56
- parameters.append(
57
- inspect.Parameter(
58
- name="self",
59
- kind=inspect._POSITIONAL_ONLY,
60
- )
61
- )
62
- parameters += [
63
- inspect.Parameter(
64
- name=p,
65
- kind=inspect._POSITIONAL_ONLY,
66
- )
67
- for p in anal.strict_positional_required
68
- ]
69
- parameters += [
70
- inspect.Parameter(
71
- name=p,
72
- kind=inspect._POSITIONAL_ONLY,
73
- default=MISSING,
74
- )
75
- for p in anal.strict_positional_optional
76
- ]
77
- parameters += [
78
- inspect.Parameter(
79
- name=p,
80
- kind=inspect._POSITIONAL_OR_KEYWORD,
81
- )
82
- for p in anal.positional_required
83
- ]
84
- parameters += [
85
- inspect.Parameter(
86
- name=p,
87
- kind=inspect._POSITIONAL_OR_KEYWORD,
88
- default=MISSING,
89
- )
90
- for p in anal.positional_optional
91
- ]
92
- parameters += [
93
- inspect.Parameter(
94
- name=p,
95
- kind=inspect._KEYWORD_ONLY,
96
- )
97
- for p in anal.keyword_required
98
- ]
99
- parameters += [
100
- inspect.Parameter(
101
- name=p,
102
- kind=inspect._KEYWORD_ONLY,
103
- default=MISSING,
104
- )
105
- for p in anal.keyword_optional
106
- ]
107
- return OrderedDict({p.name: p for p in parameters})
108
-
109
-
110
31
  def bootstrap_dispatch(ov, name):
111
32
  def first_entry(*args, **kwargs):
112
33
  ov.compile()
@@ -122,6 +43,7 @@ def bootstrap_dispatch(ov, name):
122
43
  dispatch.__signature__ = LazySignature(ov)
123
44
  dispatch.__ovld__ = ov
124
45
  dispatch.register = ov.register
46
+ dispatch.resolve_for_values = ov.resolve_for_values
125
47
  dispatch.resolve = ov.resolve
126
48
  dispatch.copy = ov.copy
127
49
  dispatch.variant = ov.variant
@@ -130,214 +52,10 @@ def bootstrap_dispatch(ov, name):
130
52
  dispatch.add_mixins = ov.add_mixins
131
53
  dispatch.unregister = ov.unregister
132
54
  dispatch.next = ov.next
55
+ dispatch.first_entry = first_entry
133
56
  return dispatch
134
57
 
135
58
 
136
- @dataclass(frozen=True)
137
- class Arginfo:
138
- position: typing.Optional[int]
139
- name: typing.Optional[str]
140
- required: bool
141
- ann: type
142
-
143
- @cached_property
144
- def is_complex(self):
145
- return isinstance(self.ann, GenericAlias)
146
-
147
- @cached_property
148
- def canonical(self):
149
- return self.name if self.position is None else self.position
150
-
151
-
152
- @dataclass(frozen=True)
153
- class Signature:
154
- types: tuple
155
- return_type: type
156
- req_pos: int
157
- max_pos: int
158
- req_names: frozenset
159
- vararg: bool
160
- priority: float
161
- tiebreak: int = 0
162
- is_method: bool = False
163
- arginfo: list[Arginfo] = field(
164
- default_factory=list, hash=False, compare=False
165
- )
166
-
167
- @classmethod
168
- def extract(cls, fn):
169
- typelist = []
170
- sig = inspect.signature(fn)
171
- max_pos = 0
172
- req_pos = 0
173
- req_names = set()
174
- is_method = False
175
-
176
- arginfo = []
177
- for i, (name, param) in enumerate(sig.parameters.items()):
178
- if name == "self":
179
- assert i == 0
180
- is_method = True
181
- continue
182
- pos = nm = None
183
- ann = normalize_type(param.annotation, fn)
184
- if param.kind is inspect._POSITIONAL_ONLY:
185
- pos = i - is_method
186
- typelist.append(ann)
187
- req_pos += param.default is inspect._empty
188
- max_pos += 1
189
- elif param.kind is inspect._POSITIONAL_OR_KEYWORD:
190
- pos = i - is_method
191
- nm = param.name
192
- typelist.append(ann)
193
- req_pos += param.default is inspect._empty
194
- max_pos += 1
195
- elif param.kind is inspect._KEYWORD_ONLY:
196
- nm = param.name
197
- typelist.append((param.name, ann))
198
- if param.default is inspect._empty:
199
- req_names.add(param.name)
200
- elif param.kind is inspect._VAR_POSITIONAL:
201
- raise TypeError("ovld does not support *args")
202
- elif param.kind is inspect._VAR_KEYWORD:
203
- raise TypeError("ovld does not support **kwargs")
204
- arginfo.append(
205
- Arginfo(
206
- position=pos,
207
- name=nm,
208
- required=param.default is inspect._empty,
209
- ann=normalize_type(param.annotation, fn),
210
- )
211
- )
212
-
213
- return cls(
214
- types=tuple(typelist),
215
- return_type=normalize_type(sig.return_annotation, fn),
216
- req_pos=req_pos,
217
- max_pos=max_pos,
218
- req_names=frozenset(req_names),
219
- vararg=False,
220
- is_method=is_method,
221
- priority=None,
222
- arginfo=arginfo,
223
- )
224
-
225
-
226
- def typemap_entry_string(cls):
227
- if isinstance(cls, tuple):
228
- key, typ = cls
229
- return f"{key}: {clsstring(typ)}"
230
- else:
231
- return clsstring(cls)
232
-
233
-
234
- def sigstring(types):
235
- return ", ".join(map(typemap_entry_string, types))
236
-
237
-
238
- class ArgumentAnalyzer:
239
- def __init__(self):
240
- self.name_to_positions = defaultdict(set)
241
- self.position_to_names = defaultdict(set)
242
- self.counts = defaultdict(lambda: [0, 0])
243
- self.complex_transforms = set()
244
- self.total = 0
245
- self.is_method = None
246
- self.done = False
247
-
248
- def add(self, fn):
249
- self.done = False
250
- sig = Signature.extract(fn)
251
- self.complex_transforms.update(
252
- arg.canonical for arg in sig.arginfo if arg.is_complex
253
- )
254
- for arg in sig.arginfo:
255
- if arg.position is not None:
256
- self.position_to_names[arg.position].add(arg.name)
257
- if arg.name is not None:
258
- self.name_to_positions[arg.name].add(arg.canonical)
259
-
260
- cnt = self.counts[arg.canonical]
261
- cnt[0] += arg.required
262
- cnt[1] += 1
263
-
264
- self.total += 1
265
-
266
- if self.is_method is None:
267
- self.is_method = sig.is_method
268
- elif self.is_method != sig.is_method: # pragma: no cover
269
- raise TypeError(
270
- "Some, but not all registered methods define `self`. It should be all or none."
271
- )
272
-
273
- def compile(self):
274
- if self.done:
275
- return
276
- for name, pos in self.name_to_positions.items():
277
- if len(pos) != 1:
278
- if all(isinstance(p, int) for p in pos):
279
- raise TypeError(
280
- f"Argument '{name}' is declared in different positions by different methods. The same argument name should always be in the same position unless it is strictly positional."
281
- )
282
- else:
283
- raise TypeError(
284
- f"Argument '{name}' is declared in a positional and keyword setting by different methods. It should be either."
285
- )
286
-
287
- p_to_n = [
288
- list(names) for _, names in sorted(self.position_to_names.items())
289
- ]
290
-
291
- positional = list(
292
- itertools.takewhile(
293
- lambda names: len(names) == 1 and isinstance(names[0], str),
294
- reversed(p_to_n),
295
- )
296
- )
297
- positional.reverse()
298
- strict_positional = p_to_n[: len(p_to_n) - len(positional)]
299
-
300
- assert strict_positional + positional == p_to_n
301
-
302
- self.strict_positional_required = [
303
- f"ARG{pos + 1}"
304
- for pos, _ in enumerate(strict_positional)
305
- if self.counts[pos][0] == self.total
306
- ]
307
- self.strict_positional_optional = [
308
- f"ARG{pos + 1}"
309
- for pos, _ in enumerate(strict_positional)
310
- if self.counts[pos][0] != self.total
311
- ]
312
-
313
- self.positional_required = [
314
- names[0]
315
- for pos, names in enumerate(positional)
316
- if self.counts[pos + len(strict_positional)][0] == self.total
317
- ]
318
- self.positional_optional = [
319
- names[0]
320
- for pos, names in enumerate(positional)
321
- if self.counts[pos + len(strict_positional)][0] != self.total
322
- ]
323
-
324
- keywords = [
325
- name
326
- for _, (name,) in self.name_to_positions.items()
327
- if not isinstance(name, int)
328
- ]
329
- self.keyword_required = [
330
- name for name in keywords if self.counts[name][0] == self.total
331
- ]
332
- self.keyword_optional = [
333
- name for name in keywords if self.counts[name][0] != self.total
334
- ]
335
- self.done = True
336
-
337
- def lookup_for(self, key):
338
- return subtler_type if key in self.complex_transforms else type
339
-
340
-
341
59
  class Ovld:
342
60
  """Overloaded function.
343
61
 
@@ -367,6 +85,7 @@ class Ovld:
367
85
  ):
368
86
  """Initialize an Ovld."""
369
87
  self.id = next(_current_id)
88
+ self.specialization_self = MISSING
370
89
  self._compiled = False
371
90
  self.linkback = linkback
372
91
  self.children = []
@@ -378,6 +97,7 @@ class Ovld:
378
97
  self._locked = False
379
98
  self.mixins = []
380
99
  self.argument_analysis = ArgumentAnalyzer()
100
+ self.dispatch = bootstrap_dispatch(self, name=self.shortname)
381
101
  self.add_mixins(*mixins)
382
102
 
383
103
  @property
@@ -396,23 +116,26 @@ class Ovld:
396
116
  return self.argument_analysis
397
117
 
398
118
  def mkdoc(self):
399
- docs = [fn.__doc__ for fn in self.defns.values() if fn.__doc__]
400
- if len(docs) == 1:
401
- maindoc = docs[0]
402
- else:
403
- maindoc = f"Ovld with {len(self.defns)} methods."
404
-
405
- doc = f"{maindoc}\n\n"
406
- for fn in self.defns.values():
407
- fndef = inspect.signature(fn)
408
- fdoc = fn.__doc__
409
- if not fdoc or fdoc == maindoc:
410
- doc += f"{self.__name__}{fndef}\n\n"
119
+ try:
120
+ docs = [fn.__doc__ for fn in self.defns.values() if fn.__doc__]
121
+ if len(docs) == 1:
122
+ maindoc = docs[0]
411
123
  else:
412
- if not fdoc.strip(" ").endswith("\n"):
413
- fdoc += "\n"
414
- fdoc = textwrap.indent(fdoc, " " * 4)
415
- doc += f"{self.__name__}{fndef}\n{fdoc}\n"
124
+ maindoc = f"Ovld with {len(self.defns)} methods."
125
+
126
+ doc = f"{maindoc}\n\n"
127
+ for fn in self.defns.values():
128
+ fndef = inspect.signature(fn)
129
+ fdoc = fn.__doc__
130
+ if not fdoc or fdoc == maindoc:
131
+ doc += f"{self.__name__}{fndef}\n\n"
132
+ else:
133
+ if not fdoc.strip(" ").endswith("\n"):
134
+ fdoc += "\n"
135
+ fdoc = textwrap.indent(fdoc, " " * 4)
136
+ doc += f"{self.__name__}{fndef}\n{fdoc}\n"
137
+ except Exception as exc: # pragma: no cover
138
+ doc = f"An exception occurred when calculating the docstring: {exc}"
416
139
  return doc
417
140
 
418
141
  @property
@@ -442,14 +165,12 @@ class Ovld:
442
165
  def _key_error(self, key, possibilities=None):
443
166
  typenames = sigstring(key)
444
167
  if not possibilities:
445
- return TypeError(
446
- f"No method in {self} for argument types [{typenames}]"
447
- )
168
+ return ResolutionError(f"No method in {self} for argument types [{typenames}]")
448
169
  else:
449
170
  hlp = ""
450
171
  for c in possibilities:
451
172
  hlp += f"* {c.handler.__name__} (priority: {c.priority}, specificity: {list(c.specificity)})\n"
452
- return TypeError(
173
+ return ResolutionError(
453
174
  f"Ambiguous resolution in {self} for"
454
175
  f" argument types [{typenames}]\n"
455
176
  f"Candidates are:\n{hlp}"
@@ -458,10 +179,14 @@ class Ovld:
458
179
 
459
180
  def rename(self, name, shortname=None):
460
181
  """Rename this Ovld."""
461
- self.name = name
462
- self.shortname = shortname or name
463
- self.__name__ = shortname
464
- self.dispatch = bootstrap_dispatch(self, name=self.shortname)
182
+ if name != self.name:
183
+ self.name = name
184
+ self.shortname = shortname or name
185
+ self.__name__ = self.shortname
186
+ self.dispatch = bootstrap_dispatch(self, name=self.shortname)
187
+
188
+ def __set_name__(self, inst, name):
189
+ self.rename(name)
465
190
 
466
191
  def _set_attrs_from(self, fn):
467
192
  """Inherit relevant attributes from the function."""
@@ -492,12 +217,10 @@ class Ovld:
492
217
  self.name = self.__name__ = f"ovld{self.id}"
493
218
 
494
219
  name = self.__name__
495
- self.map = MultiTypeMap(name=name, key_error=self._key_error)
220
+ self.map = MultiTypeMap(name=name, key_error=self._key_error, ovld=self)
496
221
 
497
222
  self.analyze_arguments()
498
223
  dispatch = generate_dispatch(self, self.argument_analysis)
499
- if not hasattr(self, "dispatch"):
500
- self.dispatch = bootstrap_dispatch(self, name=self.shortname)
501
224
  self.dispatch.__code__ = rename_code(dispatch.__code__, self.shortname)
502
225
  self.dispatch.__kwdefaults__ = dispatch.__kwdefaults__
503
226
  self.dispatch.__annotations__ = dispatch.__annotations__
@@ -511,16 +234,22 @@ class Ovld:
511
234
 
512
235
  self._compiled = True
513
236
 
514
- def resolve(self, *args):
237
+ def resolve_for_values(self, *args):
515
238
  """Find the correct method to call for the given arguments."""
516
239
  self.ensure_compiled()
517
240
  return self.map[tuple(map(subtler_type, args))]
518
241
 
242
+ def resolve(self, *args, after=None):
243
+ """Find the correct method to call for the given argument types."""
244
+ self.ensure_compiled()
245
+ if after:
246
+ return self.map[(getattr(after, "__code__", after), *args)]
247
+ else:
248
+ return self.map[args]
249
+
519
250
  def register_signature(self, sig, orig_fn):
520
251
  """Register a function for the given signature."""
521
- fn = adapt_function(
522
- orig_fn, self, f"{self.__name__}[{sigstring(sig.types)}]"
523
- )
252
+ fn = adapt_function(orig_fn, self, f"{self.__name__}[{sigstring(sig.types)}]")
524
253
  # We just need to keep the Conformer pointer alive for jurigged
525
254
  # to find it, if jurigged is used with ovld
526
255
  fn._conformer = Conformer(self, orig_fn, fn)
@@ -531,6 +260,7 @@ class Ovld:
531
260
  """Register a function."""
532
261
  if fn is None:
533
262
  return partial(self._register, priority=priority)
263
+ priority = getattr(fn, "priority", priority)
534
264
  return self._register(fn, priority)
535
265
 
536
266
  def _register(self, fn, priority):
@@ -542,9 +272,7 @@ class Ovld:
542
272
 
543
273
  sig = replace(Signature.extract(fn), priority=priority)
544
274
  if not self.allow_replacement and sig in self._defns:
545
- raise TypeError(
546
- f"There is already a method for {sigstring(sig.types)}"
547
- )
275
+ raise TypeError(f"There is already a method for {sigstring(sig.types)}")
548
276
 
549
277
  def _set(sig, fn):
550
278
  if sig in self._defns:
@@ -565,13 +293,16 @@ class Ovld:
565
293
  self._update()
566
294
 
567
295
  def _update(self):
568
- if self._compiled:
569
- self.compile()
296
+ self.reset()
570
297
  for child in self.children:
571
298
  child._update()
572
299
  if hasattr(self, "dispatch"):
573
300
  self.dispatch.__doc__ = self.mkdoc()
574
301
 
302
+ def reset(self):
303
+ self._compiled = False
304
+ self.dispatch.__code__ = self.dispatch.first_entry.__code__
305
+
575
306
  def copy(self, mixins=[], linkback=False):
576
307
  """Create a copy of this Ovld.
577
308
 
@@ -594,21 +325,15 @@ class Ovld:
594
325
  return ov
595
326
 
596
327
  def __get__(self, obj, cls):
597
- if not self._compiled:
598
- self.compile()
599
328
  return self.dispatch.__get__(obj, cls)
600
329
 
601
- @_setattrs(rename="dispatch")
602
330
  def __call__(self, *args, **kwargs): # pragma: no cover
603
331
  """Call the overloaded function.
604
332
 
605
333
  This should be replaced by an auto-generated function.
606
334
  """
607
- if not self._compiled:
608
- self.compile()
609
335
  return self.dispatch(*args, **kwargs)
610
336
 
611
- @_setattrs(rename="next")
612
337
  def next(self, *args):
613
338
  """Call the next matching method after the caller, in terms of priority or specificity."""
614
339
  fr = sys._getframe(1)
@@ -630,16 +355,14 @@ class Ovld:
630
355
 
631
356
  def is_ovld(x):
632
357
  """Return whether the argument is an ovld function/method."""
633
- return isinstance(x, Ovld) or isinstance(
634
- getattr(x, "__ovld__", False), Ovld
635
- )
358
+ return isinstance(x, Ovld) or isinstance(getattr(x, "__ovld__", False), Ovld)
636
359
 
637
360
 
638
- def to_ovld(x):
639
- """Return whether the argument is an ovld function/method."""
361
+ def to_ovld(x, force=True):
362
+ """Return the argument as an Ovld."""
640
363
  x = getattr(x, "__ovld__", x)
641
364
  if inspect.isfunction(x):
642
- return ovld(x, fresh=True)
365
+ return (ovld(x, fresh=True).__ovld__) if force else None
643
366
  else:
644
367
  return x if isinstance(x, Ovld) else None
645
368
 
@@ -694,9 +417,7 @@ class ovld_cls_dict(dict):
694
417
  prev.register(value)
695
418
  value = prev
696
419
 
697
- super().__setitem__(
698
- attr, value.dispatch if isinstance(value, Ovld) else value
699
- )
420
+ super().__setitem__(attr, value.dispatch if isinstance(value, Ovld) else value)
700
421
 
701
422
 
702
423
  class OvldMC(type):
@@ -713,7 +434,7 @@ class OvldMC(type):
713
434
  return type(cls)(name, bases, cls.__prepare__(name, bases))
714
435
 
715
436
  @classmethod
716
- def __prepare__(cls, name, bases):
437
+ def __prepare__(metacls, name, bases):
717
438
  d = ovld_cls_dict(bases)
718
439
 
719
440
  names = set()
@@ -723,9 +444,7 @@ class OvldMC(type):
723
444
  for name in names:
724
445
  values = [getattr(base, name, None) for base in bases]
725
446
  ovlds = [v for v in values if is_ovld(v)]
726
- mixins = [
727
- v for v in ovlds[1:] if getattr(v, "_extend_super", False)
728
- ]
447
+ mixins = [v for v in ovlds[1:] if getattr(v, "_extend_super", False)]
729
448
  if mixins:
730
449
  o = ovlds[0].copy(mixins=mixins)
731
450
  others = [v for v in values if v is not None and not is_ovld(v)]
@@ -736,6 +455,12 @@ class OvldMC(type):
736
455
 
737
456
  return d
738
457
 
458
+ def __init__(cls, name, bases, d):
459
+ for val in d.values():
460
+ if o := to_ovld(val, force=False):
461
+ o.specialization_self = cls
462
+ super().__init__(name, bases, d)
463
+
739
464
 
740
465
  class OvldBase(metaclass=OvldMC):
741
466
  """Base class that allows overloading of methods."""
@@ -796,13 +521,3 @@ def ovld(fn, priority=0, fresh=False, **kwargs):
796
521
  dispatch = _find_overload(fn, **kwargs)
797
522
  dispatch.register(fn, priority=priority)
798
523
  return dispatch.dispatch
799
-
800
-
801
- __all__ = [
802
- "Ovld",
803
- "OvldBase",
804
- "OvldMC",
805
- "extend_super",
806
- "is_ovld",
807
- "ovld",
808
- ]