ovld 0.4.6__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff shows the changes between two publicly available versions of a package that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
ovld/core.py CHANGED
@@ -4,11 +4,9 @@ import inspect
4
4
  import itertools
5
5
  import sys
6
6
  import textwrap
7
- import typing
8
- from collections import OrderedDict, defaultdict
9
- from dataclasses import dataclass, field, replace
10
- from functools import cached_property, partial
11
- from types import FunctionType, GenericAlias
7
+ from dataclasses import replace
8
+ from functools import partial
9
+ from types import FunctionType
12
10
 
13
11
  from .recode import (
14
12
  Conformer,
@@ -16,97 +14,19 @@ from .recode import (
16
14
  generate_dispatch,
17
15
  rename_code,
18
16
  )
17
+ from .signatures import ArgumentAnalyzer, LazySignature, Signature
19
18
  from .typemap import MultiTypeMap
20
- from .types import clsstring, normalize_type
21
- from .utils import MISSING, UsageError, keyword_decorator, subtler_type
19
+ from .utils import (
20
+ MISSING,
21
+ UsageError,
22
+ keyword_decorator,
23
+ sigstring,
24
+ subtler_type,
25
+ )
22
26
 
23
27
  _current_id = itertools.count()
24
28
 
25
29
 
26
- @keyword_decorator
27
- def _setattrs(fn, **kwargs):
28
- for k, v in kwargs.items():
29
- setattr(fn, k, v)
30
- return fn
31
-
32
-
33
- class LazySignature(inspect.Signature):
34
- def __init__(self, ovld):
35
- super().__init__([])
36
- self.ovld = ovld
37
-
38
- def replace(
39
- self, *, parameters=inspect._void, return_annotation=inspect._void
40
- ): # pragma: no cover
41
- if parameters is inspect._void:
42
- parameters = self.parameters.values()
43
-
44
- if return_annotation is inspect._void:
45
- return_annotation = self._return_annotation
46
-
47
- return inspect.Signature(
48
- parameters, return_annotation=return_annotation
49
- )
50
-
51
- @property
52
- def parameters(self):
53
- anal = self.ovld.analyze_arguments()
54
- parameters = []
55
- if anal.is_method:
56
- parameters.append(
57
- inspect.Parameter(
58
- name="self",
59
- kind=inspect._POSITIONAL_ONLY,
60
- )
61
- )
62
- parameters += [
63
- inspect.Parameter(
64
- name=p,
65
- kind=inspect._POSITIONAL_ONLY,
66
- )
67
- for p in anal.strict_positional_required
68
- ]
69
- parameters += [
70
- inspect.Parameter(
71
- name=p,
72
- kind=inspect._POSITIONAL_ONLY,
73
- default=MISSING,
74
- )
75
- for p in anal.strict_positional_optional
76
- ]
77
- parameters += [
78
- inspect.Parameter(
79
- name=p,
80
- kind=inspect._POSITIONAL_OR_KEYWORD,
81
- )
82
- for p in anal.positional_required
83
- ]
84
- parameters += [
85
- inspect.Parameter(
86
- name=p,
87
- kind=inspect._POSITIONAL_OR_KEYWORD,
88
- default=MISSING,
89
- )
90
- for p in anal.positional_optional
91
- ]
92
- parameters += [
93
- inspect.Parameter(
94
- name=p,
95
- kind=inspect._KEYWORD_ONLY,
96
- )
97
- for p in anal.keyword_required
98
- ]
99
- parameters += [
100
- inspect.Parameter(
101
- name=p,
102
- kind=inspect._KEYWORD_ONLY,
103
- default=MISSING,
104
- )
105
- for p in anal.keyword_optional
106
- ]
107
- return OrderedDict({p.name: p for p in parameters})
108
-
109
-
110
30
  def bootstrap_dispatch(ov, name):
111
31
  def first_entry(*args, **kwargs):
112
32
  ov.compile()
@@ -122,6 +42,7 @@ def bootstrap_dispatch(ov, name):
122
42
  dispatch.__signature__ = LazySignature(ov)
123
43
  dispatch.__ovld__ = ov
124
44
  dispatch.register = ov.register
45
+ dispatch.resolve_for_values = ov.resolve_for_values
125
46
  dispatch.resolve = ov.resolve
126
47
  dispatch.copy = ov.copy
127
48
  dispatch.variant = ov.variant
@@ -133,211 +54,6 @@ def bootstrap_dispatch(ov, name):
133
54
  return dispatch
134
55
 
135
56
 
136
- @dataclass(frozen=True)
137
- class Arginfo:
138
- position: typing.Optional[int]
139
- name: typing.Optional[str]
140
- required: bool
141
- ann: type
142
-
143
- @cached_property
144
- def is_complex(self):
145
- return isinstance(self.ann, GenericAlias)
146
-
147
- @cached_property
148
- def canonical(self):
149
- return self.name if self.position is None else self.position
150
-
151
-
152
- @dataclass(frozen=True)
153
- class Signature:
154
- types: tuple
155
- return_type: type
156
- req_pos: int
157
- max_pos: int
158
- req_names: frozenset
159
- vararg: bool
160
- priority: float
161
- tiebreak: int = 0
162
- is_method: bool = False
163
- arginfo: list[Arginfo] = field(
164
- default_factory=list, hash=False, compare=False
165
- )
166
-
167
- @classmethod
168
- def extract(cls, fn):
169
- typelist = []
170
- sig = inspect.signature(fn)
171
- max_pos = 0
172
- req_pos = 0
173
- req_names = set()
174
- is_method = False
175
-
176
- arginfo = []
177
- for i, (name, param) in enumerate(sig.parameters.items()):
178
- if name == "self":
179
- assert i == 0
180
- is_method = True
181
- continue
182
- pos = nm = None
183
- ann = normalize_type(param.annotation, fn)
184
- if param.kind is inspect._POSITIONAL_ONLY:
185
- pos = i - is_method
186
- typelist.append(ann)
187
- req_pos += param.default is inspect._empty
188
- max_pos += 1
189
- elif param.kind is inspect._POSITIONAL_OR_KEYWORD:
190
- pos = i - is_method
191
- nm = param.name
192
- typelist.append(ann)
193
- req_pos += param.default is inspect._empty
194
- max_pos += 1
195
- elif param.kind is inspect._KEYWORD_ONLY:
196
- nm = param.name
197
- typelist.append((param.name, ann))
198
- if param.default is inspect._empty:
199
- req_names.add(param.name)
200
- elif param.kind is inspect._VAR_POSITIONAL:
201
- raise TypeError("ovld does not support *args")
202
- elif param.kind is inspect._VAR_KEYWORD:
203
- raise TypeError("ovld does not support **kwargs")
204
- arginfo.append(
205
- Arginfo(
206
- position=pos,
207
- name=nm,
208
- required=param.default is inspect._empty,
209
- ann=normalize_type(param.annotation, fn),
210
- )
211
- )
212
-
213
- return cls(
214
- types=tuple(typelist),
215
- return_type=normalize_type(sig.return_annotation, fn),
216
- req_pos=req_pos,
217
- max_pos=max_pos,
218
- req_names=frozenset(req_names),
219
- vararg=False,
220
- is_method=is_method,
221
- priority=None,
222
- arginfo=arginfo,
223
- )
224
-
225
-
226
- def typemap_entry_string(cls):
227
- if isinstance(cls, tuple):
228
- key, typ = cls
229
- return f"{key}: {clsstring(typ)}"
230
- else:
231
- return clsstring(cls)
232
-
233
-
234
- def sigstring(types):
235
- return ", ".join(map(typemap_entry_string, types))
236
-
237
-
238
- class ArgumentAnalyzer:
239
- def __init__(self):
240
- self.name_to_positions = defaultdict(set)
241
- self.position_to_names = defaultdict(set)
242
- self.counts = defaultdict(lambda: [0, 0])
243
- self.complex_transforms = set()
244
- self.total = 0
245
- self.is_method = None
246
- self.done = False
247
-
248
- def add(self, fn):
249
- self.done = False
250
- sig = Signature.extract(fn)
251
- self.complex_transforms.update(
252
- arg.canonical for arg in sig.arginfo if arg.is_complex
253
- )
254
- for arg in sig.arginfo:
255
- if arg.position is not None:
256
- self.position_to_names[arg.position].add(arg.name)
257
- if arg.name is not None:
258
- self.name_to_positions[arg.name].add(arg.canonical)
259
-
260
- cnt = self.counts[arg.canonical]
261
- cnt[0] += arg.required
262
- cnt[1] += 1
263
-
264
- self.total += 1
265
-
266
- if self.is_method is None:
267
- self.is_method = sig.is_method
268
- elif self.is_method != sig.is_method: # pragma: no cover
269
- raise TypeError(
270
- "Some, but not all registered methods define `self`. It should be all or none."
271
- )
272
-
273
- def compile(self):
274
- if self.done:
275
- return
276
- for name, pos in self.name_to_positions.items():
277
- if len(pos) != 1:
278
- if all(isinstance(p, int) for p in pos):
279
- raise TypeError(
280
- f"Argument '{name}' is declared in different positions by different methods. The same argument name should always be in the same position unless it is strictly positional."
281
- )
282
- else:
283
- raise TypeError(
284
- f"Argument '{name}' is declared in a positional and keyword setting by different methods. It should be either."
285
- )
286
-
287
- p_to_n = [
288
- list(names) for _, names in sorted(self.position_to_names.items())
289
- ]
290
-
291
- positional = list(
292
- itertools.takewhile(
293
- lambda names: len(names) == 1 and isinstance(names[0], str),
294
- reversed(p_to_n),
295
- )
296
- )
297
- positional.reverse()
298
- strict_positional = p_to_n[: len(p_to_n) - len(positional)]
299
-
300
- assert strict_positional + positional == p_to_n
301
-
302
- self.strict_positional_required = [
303
- f"ARG{pos + 1}"
304
- for pos, _ in enumerate(strict_positional)
305
- if self.counts[pos][0] == self.total
306
- ]
307
- self.strict_positional_optional = [
308
- f"ARG{pos + 1}"
309
- for pos, _ in enumerate(strict_positional)
310
- if self.counts[pos][0] != self.total
311
- ]
312
-
313
- self.positional_required = [
314
- names[0]
315
- for pos, names in enumerate(positional)
316
- if self.counts[pos + len(strict_positional)][0] == self.total
317
- ]
318
- self.positional_optional = [
319
- names[0]
320
- for pos, names in enumerate(positional)
321
- if self.counts[pos + len(strict_positional)][0] != self.total
322
- ]
323
-
324
- keywords = [
325
- name
326
- for _, (name,) in self.name_to_positions.items()
327
- if not isinstance(name, int)
328
- ]
329
- self.keyword_required = [
330
- name for name in keywords if self.counts[name][0] == self.total
331
- ]
332
- self.keyword_optional = [
333
- name for name in keywords if self.counts[name][0] != self.total
334
- ]
335
- self.done = True
336
-
337
- def lookup_for(self, key):
338
- return subtler_type if key in self.complex_transforms else type
339
-
340
-
341
57
  class Ovld:
342
58
  """Overloaded function.
343
59
 
@@ -367,6 +83,7 @@ class Ovld:
367
83
  ):
368
84
  """Initialize an Ovld."""
369
85
  self.id = next(_current_id)
86
+ self.specialization_self = MISSING
370
87
  self._compiled = False
371
88
  self.linkback = linkback
372
89
  self.children = []
@@ -396,23 +113,26 @@ class Ovld:
396
113
  return self.argument_analysis
397
114
 
398
115
  def mkdoc(self):
399
- docs = [fn.__doc__ for fn in self.defns.values() if fn.__doc__]
400
- if len(docs) == 1:
401
- maindoc = docs[0]
402
- else:
403
- maindoc = f"Ovld with {len(self.defns)} methods."
404
-
405
- doc = f"{maindoc}\n\n"
406
- for fn in self.defns.values():
407
- fndef = inspect.signature(fn)
408
- fdoc = fn.__doc__
409
- if not fdoc or fdoc == maindoc:
410
- doc += f"{self.__name__}{fndef}\n\n"
116
+ try:
117
+ docs = [fn.__doc__ for fn in self.defns.values() if fn.__doc__]
118
+ if len(docs) == 1:
119
+ maindoc = docs[0]
411
120
  else:
412
- if not fdoc.strip(" ").endswith("\n"):
413
- fdoc += "\n"
414
- fdoc = textwrap.indent(fdoc, " " * 4)
415
- doc += f"{self.__name__}{fndef}\n{fdoc}\n"
121
+ maindoc = f"Ovld with {len(self.defns)} methods."
122
+
123
+ doc = f"{maindoc}\n\n"
124
+ for fn in self.defns.values():
125
+ fndef = inspect.signature(fn)
126
+ fdoc = fn.__doc__
127
+ if not fdoc or fdoc == maindoc:
128
+ doc += f"{self.__name__}{fndef}\n\n"
129
+ else:
130
+ if not fdoc.strip(" ").endswith("\n"):
131
+ fdoc += "\n"
132
+ fdoc = textwrap.indent(fdoc, " " * 4)
133
+ doc += f"{self.__name__}{fndef}\n{fdoc}\n"
134
+ except Exception as exc: # pragma: no cover
135
+ doc = f"An exception occurred when calculating the docstring: {exc}"
416
136
  return doc
417
137
 
418
138
  @property
@@ -442,9 +162,7 @@ class Ovld:
442
162
  def _key_error(self, key, possibilities=None):
443
163
  typenames = sigstring(key)
444
164
  if not possibilities:
445
- return TypeError(
446
- f"No method in {self} for argument types [{typenames}]"
447
- )
165
+ return TypeError(f"No method in {self} for argument types [{typenames}]")
448
166
  else:
449
167
  hlp = ""
450
168
  for c in possibilities:
@@ -460,9 +178,12 @@ class Ovld:
460
178
  """Rename this Ovld."""
461
179
  self.name = name
462
180
  self.shortname = shortname or name
463
- self.__name__ = shortname
181
+ self.__name__ = self.shortname
464
182
  self.dispatch = bootstrap_dispatch(self, name=self.shortname)
465
183
 
184
+ def __set_name__(self, inst, name):
185
+ self.rename(name)
186
+
466
187
  def _set_attrs_from(self, fn):
467
188
  """Inherit relevant attributes from the function."""
468
189
  if self.name is None:
@@ -492,7 +213,7 @@ class Ovld:
492
213
  self.name = self.__name__ = f"ovld{self.id}"
493
214
 
494
215
  name = self.__name__
495
- self.map = MultiTypeMap(name=name, key_error=self._key_error)
216
+ self.map = MultiTypeMap(name=name, key_error=self._key_error, ovld=self)
496
217
 
497
218
  self.analyze_arguments()
498
219
  dispatch = generate_dispatch(self, self.argument_analysis)
@@ -511,16 +232,22 @@ class Ovld:
511
232
 
512
233
  self._compiled = True
513
234
 
514
- def resolve(self, *args):
235
+ def resolve_for_values(self, *args):
515
236
  """Find the correct method to call for the given arguments."""
516
237
  self.ensure_compiled()
517
238
  return self.map[tuple(map(subtler_type, args))]
518
239
 
240
+ def resolve(self, *args, after=None):
241
+ """Find the correct method to call for the given argument types."""
242
+ self.ensure_compiled()
243
+ if after:
244
+ return self.map[(getattr(after, "__code__", after), *args)]
245
+ else:
246
+ return self.map[args]
247
+
519
248
  def register_signature(self, sig, orig_fn):
520
249
  """Register a function for the given signature."""
521
- fn = adapt_function(
522
- orig_fn, self, f"{self.__name__}[{sigstring(sig.types)}]"
523
- )
250
+ fn = adapt_function(orig_fn, self, f"{self.__name__}[{sigstring(sig.types)}]")
524
251
  # We just need to keep the Conformer pointer alive for jurigged
525
252
  # to find it, if jurigged is used with ovld
526
253
  fn._conformer = Conformer(self, orig_fn, fn)
@@ -542,9 +269,7 @@ class Ovld:
542
269
 
543
270
  sig = replace(Signature.extract(fn), priority=priority)
544
271
  if not self.allow_replacement and sig in self._defns:
545
- raise TypeError(
546
- f"There is already a method for {sigstring(sig.types)}"
547
- )
272
+ raise TypeError(f"There is already a method for {sigstring(sig.types)}")
548
273
 
549
274
  def _set(sig, fn):
550
275
  if sig in self._defns:
@@ -598,7 +323,6 @@ class Ovld:
598
323
  self.compile()
599
324
  return self.dispatch.__get__(obj, cls)
600
325
 
601
- @_setattrs(rename="dispatch")
602
326
  def __call__(self, *args, **kwargs): # pragma: no cover
603
327
  """Call the overloaded function.
604
328
 
@@ -608,7 +332,6 @@ class Ovld:
608
332
  self.compile()
609
333
  return self.dispatch(*args, **kwargs)
610
334
 
611
- @_setattrs(rename="next")
612
335
  def next(self, *args):
613
336
  """Call the next matching method after the caller, in terms of priority or specificity."""
614
337
  fr = sys._getframe(1)
@@ -630,16 +353,14 @@ class Ovld:
630
353
 
631
354
  def is_ovld(x):
632
355
  """Return whether the argument is an ovld function/method."""
633
- return isinstance(x, Ovld) or isinstance(
634
- getattr(x, "__ovld__", False), Ovld
635
- )
356
+ return isinstance(x, Ovld) or isinstance(getattr(x, "__ovld__", False), Ovld)
636
357
 
637
358
 
638
- def to_ovld(x):
639
- """Return whether the argument is an ovld function/method."""
359
+ def to_ovld(x, force=True):
360
+ """Return the argument as an Ovld."""
640
361
  x = getattr(x, "__ovld__", x)
641
362
  if inspect.isfunction(x):
642
- return ovld(x, fresh=True)
363
+ return (ovld(x, fresh=True).__ovld__) if force else None
643
364
  else:
644
365
  return x if isinstance(x, Ovld) else None
645
366
 
@@ -694,9 +415,7 @@ class ovld_cls_dict(dict):
694
415
  prev.register(value)
695
416
  value = prev
696
417
 
697
- super().__setitem__(
698
- attr, value.dispatch if isinstance(value, Ovld) else value
699
- )
418
+ super().__setitem__(attr, value.dispatch if isinstance(value, Ovld) else value)
700
419
 
701
420
 
702
421
  class OvldMC(type):
@@ -713,7 +432,7 @@ class OvldMC(type):
713
432
  return type(cls)(name, bases, cls.__prepare__(name, bases))
714
433
 
715
434
  @classmethod
716
- def __prepare__(cls, name, bases):
435
+ def __prepare__(metacls, name, bases):
717
436
  d = ovld_cls_dict(bases)
718
437
 
719
438
  names = set()
@@ -723,9 +442,7 @@ class OvldMC(type):
723
442
  for name in names:
724
443
  values = [getattr(base, name, None) for base in bases]
725
444
  ovlds = [v for v in values if is_ovld(v)]
726
- mixins = [
727
- v for v in ovlds[1:] if getattr(v, "_extend_super", False)
728
- ]
445
+ mixins = [v for v in ovlds[1:] if getattr(v, "_extend_super", False)]
729
446
  if mixins:
730
447
  o = ovlds[0].copy(mixins=mixins)
731
448
  others = [v for v in values if v is not None and not is_ovld(v)]
@@ -736,6 +453,12 @@ class OvldMC(type):
736
453
 
737
454
  return d
738
455
 
456
+ def __init__(cls, name, bases, d):
457
+ for val in d.values():
458
+ if o := to_ovld(val, force=False):
459
+ o.specialization_self = cls
460
+ super().__init__(name, bases, d)
461
+
739
462
 
740
463
  class OvldBase(metaclass=OvldMC):
741
464
  """Base class that allows overloading of methods."""
@@ -796,13 +519,3 @@ def ovld(fn, priority=0, fresh=False, **kwargs):
796
519
  dispatch = _find_overload(fn, **kwargs)
797
520
  dispatch.register(fn, priority=priority)
798
521
  return dispatch.dispatch
799
-
800
-
801
- __all__ = [
802
- "Ovld",
803
- "OvldBase",
804
- "OvldMC",
805
- "extend_super",
806
- "is_ovld",
807
- "ovld",
808
- ]