ovld 0.3.9__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ovld/typemap.py ADDED
@@ -0,0 +1,383 @@
1
+ import inspect
2
+ import math
3
+ import typing
4
+ from itertools import count
5
+ from types import CodeType
6
+
7
+ from .dependent import DependentType
8
+ from .mro import sort_types
9
+ from .recode import generate_dependent_dispatch
10
+ from .utils import MISSING
11
+
12
+
13
class GenericAliasMC(type):
    """Metaclass whose instances accept any object exposing ``__origin__``."""

    def __instancecheck__(cls, obj):
        # Parametrized generics (list[int], type[int], typing.List[int], ...)
        # expose __origin__; plain values and classes do not.
        try:
            obj.__origin__
        except AttributeError:
            return False
        return True
16
+
17
+
18
class GenericAlias(metaclass=GenericAliasMC):
    """Marker class: ``isinstance(x, GenericAlias)`` is true for any ``x``
    that has an ``__origin__`` attribute, such as ``list[int]`` or
    ``type[int]``."""
20
+
21
+
22
def is_type_of_type(t):
    """Return True when ``t`` is a parametrized ``type[...]`` annotation."""
    origin = getattr(t, "__origin__", None)
    return origin is type
24
+
25
+
26
class TypeMap(dict):
    """Map a single type to the handlers registered for it or its ancestors.

    Handlers are registered per type with ``register``. Looking up a type
    that is not cached yet goes through ``__missing__``. That method asks
    ``sort_types`` to order the registered types relevant to the looked-up
    type into groups. It then returns a dict mapping every matching handler
    to a "level": the index of its group counted from the end of the
    ``sort_types`` output. As described by the original design, ``object``
    ends up at level 0 and a class inheriting directly from ``object`` at
    level 1, so setting the ``object`` key creates a default for all objects.

    Results are cached as ordinary dict entries. Registering any handler
    clears the cache.
    """

    def __init__(self):
        self.entries = {}  # registered type -> set of handlers
        self.types = set()  # every type that has at least one handler

    def register(self, obj_t, handler):
        """Register a handler for the given object type."""
        # Cached lookups may be stale once a new handler exists.
        self.clear()
        self.types.add(obj_t)
        self.entries.setdefault(obj_t, set()).add(handler)

    def __missing__(self, obj_t):
        """Compute, cache and return ``{handler: level}`` for ``obj_t``.

        Raises:
            KeyError: if no registered type matches ``obj_t``.
        """
        ordered_groups = list(sort_types(obj_t, self.types))
        found = {}
        for level, group in enumerate(reversed(ordered_groups)):
            for registered in group:
                # Later groups overwrite earlier ones, so a handler keeps the
                # highest level at which it matches.
                for handler in self.entries.get(registered) or ():
                    found[handler] = level
        if not found:
            raise KeyError(obj_t)
        self[obj_t] = found
        return found
70
+
71
+
72
class MultiTypeMap(dict):
    """Represents a mapping from tuples of types to handlers.

    The mro is taken into account to find a match. If multiple registered
    handlers match the tuple of types that's given, and one of the handlers
    is more specific than every other handler, that handler is returned.
    Otherwise, the resolution is considered ambiguous and an error is raised.

    Handler A, registered for types (A1, A2, ..., An), is more specific than
    handler B, registered for types (B1, B2, ..., Bn), if there exists i such
    that Ai is more specific than Bi, and for all i, either Ai == Bi or Ai is
    more specific than Bi. Ai is more specific than Bi if Ai is a direct or
    indirect subclass of Bi.

    In other words, [int, object] is more specific than [object, object] and
    less specific than [int, int], but it is neither less specific nor more
    specific than [object, int] (which means there is an ambiguity).

    Keys of this dict are tuples of argument types (positional entries are
    bare types, keyword entries are ``(name, type)`` pairs). A key may be
    prefixed by a code object to ask for the handler that comes after that
    code's handler. Values are the callables to dispatch to. Entries are
    filled lazily by ``__missing__`` and cleared on every ``register``.
    """

    def __init__(self, name="_ovld", key_error=KeyError):
        self.maps = {}  # position / keyword name / -1 (*args) -> TypeMap
        self.priorities = {}  # handler -> priority
        self.dependent = {}  # handler -> True if any registered type is a DependentType
        self.type_tuples = {}  # handler -> types it was registered for
        self.empty = MISSING  # (handler, sig) registered with no types, if any
        self.key_error = key_error  # factory: key_error(types, group) -> exception
        self.name = name  # prefix for generated dispatch function names
        self.dispatch_id = count()
        self.all = {}  # type tuple -> code objects of all candidate handlers
        self.errors = {}  # type tuple -> exception to raise for that key

    def transform(self, obj):
        """Map an argument value to the type key used for dispatch.

        Classes and parametrized generics map to ``type[obj]``,
        ``typing.Any`` maps to ``type[object]``, and any other value maps to
        ``type(obj)``.
        """
        if isinstance(obj, GenericAlias):
            return type[obj]
        elif obj is typing.Any:
            return type[object]
        elif isinstance(obj, type):
            return type[obj]
        else:
            return type(obj)

    def mro(self, obj_t_tup):
        """Return the candidate handlers for ``obj_t_tup``, grouped by precedence.

        Each group is a list of ``(handler, priority, specificities)``
        triples. The first handler of a group dominates every handler that
        appears in a later group. The other handlers of a group are those
        the first one does not dominate, i.e. they are ambiguous with it.
        """
        specificities = {}
        candidates = None
        # Positional entries are bare types; keyword entries are (name, type).
        nargs = len([t for t in obj_t_tup if not isinstance(t, tuple)])
        names = {t[0] for t in obj_t_tup if isinstance(t, tuple)}

        for i, cls in enumerate(obj_t_tup):
            if isinstance(cls, tuple):
                i, cls = cls

            try:
                results = self.maps[i][cls]
            except KeyError:
                results = {}

            # Keep handlers that accept this many positional arguments and
            # whose required keyword arguments are all provided.
            results = {
                handler: spc
                for (handler, sig), spc in results.items()
                if sig.req_pos
                <= nargs
                <= (math.inf if sig.vararg else sig.max_pos)
                and not (sig.req_names - names)
            }

            try:
                vararg_results = self.maps[-1][cls]
            except KeyError:
                vararg_results = {}

            # The *args map only applies to extra positional arguments. For
            # keyword entries `i` is the argument name (a str), which must not
            # be compared with an int position.
            vararg_results = {
                handler: spc
                for (handler, sig), spc in vararg_results.items()
                if sig.req_pos <= nargs
                and isinstance(i, int)
                and i >= sig.max_pos
            }

            results.update(vararg_results)

            # A handler stays a candidate only if it matches every argument.
            if candidates is None:
                candidates = set(results.keys())
            else:
                candidates &= results.keys()
            for c in candidates:
                specificities.setdefault(c, []).append(results[c])

        candidates = [
            (c, self.priorities.get(c, 0), tuple(specificities[c]))
            for c in candidates
        ]

        # The sort ensures that if candidate A dominates candidate B, A will
        # appear before B in the list. That's because it must dominate all
        # other possibilities on all arguments, so the sum of all specificities
        # has to be greater.
        # Note: priority is always more important than specificity
        candidates.sort(key=lambda cspc: (cspc[1], sum(cspc[2])), reverse=True)

        self.all[obj_t_tup] = {
            getattr(c[0], "__code__", None) for c in candidates
        }

        # Split the sorted candidates into precedence groups. This is done
        # iteratively (rather than with recursive generators) so that a large
        # number of overloads cannot exhaust the recursion limit.
        processed = set()
        groups = []
        remaining = candidates
        while True:
            remaining = [cand for cand in remaining if cand[0] not in processed]
            if not remaining:
                break
            head, *rest = remaining
            _, p1, spc1 = head
            group = [head]
            for c2, p2, spc2 in rest:
                if p1 > p2 or (
                    spc1 != spc2 and all(s1 >= s2 for s1, s2 in zip(spc1, spc2))
                ):
                    # Candidate 1 dominates candidate 2
                    continue
                # Candidate 1 does not dominate candidate 2, so they are
                # ambiguous and belong to the same group.
                processed.add(c2)
                group.append((c2, p2, spc2))
            groups.append(group)
            remaining = rest
        return groups

    def register(self, sig, handler):
        """Register a handler for a tuple of argument types.

        Arguments:
            sig: A Signature object.
            handler: A function to handle the tuple.
        """
        # Any cached resolution may change once a new handler exists.
        self.clear()

        obj_t_tup = sig.types
        entry = (handler, sig)
        if not obj_t_tup:
            self.empty = entry

        self.priorities[handler] = sig.priority
        self.type_tuples[handler] = obj_t_tup
        self.dependent[handler] = any(
            isinstance(t[1] if isinstance(t, tuple) else t, DependentType)
            for t in obj_t_tup
        )

        for i, cls in enumerate(obj_t_tup):
            if isinstance(cls, tuple):
                i, cls = cls
            if i not in self.maps:
                self.maps[i] = TypeMap()
            self.maps[i].register(cls, entry)

        if sig.vararg:  # pragma: no cover
            # TODO: either add this back in, or remove it
            if -1 not in self.maps:
                self.maps[-1] = TypeMap()
            self.maps[-1].register(object, entry)

    def display_methods(self):
        """Print every registered handler with its priority and location."""
        for h, prio in sorted(self.priorities.items(), key=lambda kv: -kv[1]):
            prio = f"[{prio}]"
            width = 6
            print(f"{prio:{width}} \033[1m{h.__name__}\033[0m")
            co = h.__code__
            print(f"{'':{width-2}} @ {co.co_filename}:{co.co_firstlineno}")

    def display_resolution(self, *args, **kwargs):
        """Print which handlers would be called for ``args``/``kwargs``, in order."""
        no_arg = object()

        def dependent_match(tup):
            # Positional types consume positional args in order; keyword
            # entries (name, type) are matched by name against kwargs.
            positional = iter(args)
            for t in tup:
                if isinstance(t, tuple):
                    name, t = t
                    if name not in kwargs:
                        continue
                    a = kwargs[name]
                else:
                    a = next(positional, no_arg)
                    if a is no_arg:
                        continue
                if isinstance(t, DependentType) and not t.check(a):
                    return False
            return True

        message = "No method will be called."
        argt = [
            *map(self.transform, args),
            *[(k, self.transform(v)) for k, v in kwargs.items()],
        ]
        finished = False
        rank = 1
        for grp in self.mro(tuple(argt)):
            grp.sort(key=lambda x: x[0].__name__)
            match = [
                dependent_match(self.type_tuples[handler])
                for handler, _, _ in grp
            ]
            ambiguous = len([m for m in match if m]) > 1
            for m, (handler, prio, spec) in zip(match, grp):
                color = "\033[0m"
                if finished:
                    bullet = "--"
                    color = "\033[1;90m"
                elif not m:
                    bullet = "!="
                    color = "\033[1;90m"
                elif ambiguous:
                    bullet = "=="
                    color = "\033[1;31m"
                else:
                    bullet = f"#{rank}"
                    if rank == 1:
                        message = f"{handler.__name__} will be called first."
                        color = "\033[1;32m"
                    rank += 1
                spec = ".".join(map(str, spec))
                lvl = f"[{prio}:{spec}]"
                width = 2 * len(args) + 6
                print(f"{color}{bullet} {lvl:{width}} {handler.__name__}")
                co = handler.__code__
                print(
                    f"   {'':{width-1}}@ {co.co_filename}:{co.co_firstlineno}\033[0m"
                )
            if ambiguous:
                message += " There is ambiguity between multiple matching methods, marked '=='."
                finished = True
        print("Resolution:", message)

    def wrap_dependent(self, tup, handlers, group, next_call):
        """Generate a dispatcher that checks dependent types among ``handlers``.

        ``next_call`` is the dispatcher for the following precedence group
        (or None). The generated code forwards ``self`` when the handlers are
        methods.
        """
        handlers = list(handlers)
        htup = [(h, self.type_tuples[h]) for h in handlers]
        # A handler may have no positional parameters at all (e.g. only
        # keyword-only ones), in which case `.args` is empty.
        argnames = inspect.getfullargspec(handlers[0]).args
        slf = "self, " if argnames and argnames[0] == "self" else ""
        return generate_dependent_dispatch(
            tup,
            htup,
            next_call,
            slf,
            name=f"{self.name}.specialized_dispatch_{next(self.dispatch_id)}",
            err=self.key_error(tup, group),
            nerr=self.key_error(tup, ()),
        )

    def resolve(self, obj_t_tup):
        """Populate the dispatch cache for ``obj_t_tup``.

        One callable is built per precedence group returned by ``mro``,
        starting from the last group so that a dependent dispatcher can fall
        back on the dispatcher of the group after it. The first group's
        callable is stored under ``obj_t_tup``. Each following group's
        callable is stored under ``(code, *obj_t_tup)`` for every code object
        of the previous group's handlers. An ambiguous group (several
        handlers, none dependent) stores an error instead and ends the chain.

        Raises:
            The exception built by ``key_error`` if no handler matches.
        """
        results = self.mro(obj_t_tup)
        if not results:
            raise self.key_error(obj_t_tup, ())

        funcs = []
        for group in reversed(results):
            handlers = [fn for (fn, _, _) in group]
            dependent = any(self.dependent[fn] for (fn, _, _) in group)
            if dependent:
                nxt = self.wrap_dependent(
                    obj_t_tup, handlers, group, funcs[-1] if funcs else None
                )
            elif len(group) != 1:
                nxt = None
            else:
                nxt = handlers[0]
            codes = [h.__code__ for h in handlers if hasattr(h, "__code__")]
            funcs.append((nxt, codes))

        funcs.reverse()

        parents = []
        for group, (func, codes) in zip(results, funcs):
            tups = (
                [obj_t_tup]
                if not parents
                else [(parent, *obj_t_tup) for parent in parents]
            )
            if func is None:
                for tup in tups:
                    self.errors[tup] = self.key_error(obj_t_tup, group)
                break
            else:
                for tup in tups:
                    self[tup] = func
                if not codes:
                    break
                parents = codes

        return True

    def __missing__(self, obj_t_tup):
        """Resolve and cache the dispatch target for an uncached key.

        For ``(code, *types)`` keys, the regular target for ``types`` is
        returned when ``code`` is not one of its candidates. The empty tuple
        maps to the handler registered with no types.
        """
        if obj_t_tup and isinstance(obj_t_tup[0], CodeType):
            real_tup = obj_t_tup[1:]
            self[real_tup]
            if obj_t_tup[0] not in self.all[real_tup]:
                return self[real_tup]
            elif obj_t_tup in self.errors:
                raise self.errors[obj_t_tup]
            elif obj_t_tup in self:  # pragma: no cover
                # PROBABLY not reachable
                return self[obj_t_tup]
            else:
                raise self.key_error(real_tup, ())

        if not obj_t_tup:
            if self.empty is MISSING:  # pragma: no cover
                # Might not be reachable because of codegen
                raise self.key_error(obj_t_tup, ())
            else:
                return self.empty[0]

        self.resolve(obj_t_tup)
        if obj_t_tup in self.errors:
            raise self.errors[obj_t_tup]
        else:
            return self[obj_t_tup]
ovld/types.py ADDED
@@ -0,0 +1,219 @@
1
+ import inspect
2
+ import sys
3
+ import typing
4
+ from dataclasses import dataclass
5
+ from typing import Protocol, runtime_checkable
6
+
7
+ from ovld.utils import UsageError
8
+
9
+ from .mro import Order, TypeRelationship, subclasscheck, typeorder
10
+
11
+ try:
12
+ from types import UnionType
13
+ except ImportError: # pragma: no cover
14
+ UnionType = None
15
+
16
+
17
def normalize_type(t, fn):
    """Convert an annotation into a form that ovld can dispatch on.

    Arguments:
        t: The annotation. A string is evaluated in ``fn.__globals__``.
        fn: The function the annotation belongs to (only used to resolve
            string annotations).

    Returns:
        ``type[object]`` for a bare ``type``; ``object`` for ``typing.Any``
        or a missing annotation; a ``typing.Union`` of normalized members for
        unions and for tuples of types; ``Equals(...)`` for ``Literal[...]``;
        otherwise the (possibly ``Annotated``-stripped) annotation itself.

    Raises:
        TypeError: for parametrized generics other than type, Union and
            Literal.
        UsageError: for a DependentType that has no bound.
    """
    from .dependent import DependentType, Equals

    if isinstance(t, str):
        # NOTE: this evaluates annotation source code in fn's globals.
        t = eval(t, getattr(fn, "__globals__", {}))

    # Canonicalize a few special forms before inspecting __origin__.
    if t is type:
        t = type[object]
    elif t is typing.Any or t is inspect._empty:
        t = object
    elif isinstance(t, typing._AnnotatedAlias):
        t = t.__origin__

    origin = getattr(t, "__origin__", None)

    # `X | Y` and `Union[X, Y]` are handled the same way.
    if (UnionType and isinstance(t, UnionType)) or origin is typing.Union:
        return normalize_type(t.__args__, fn)
    if origin is type:
        return t
    if origin is typing.Literal:
        return Equals(*t.__args__)
    if origin and not getattr(t, "__args__", None):
        # A generic alias without arguments is accepted unchanged.
        return t
    if origin is not None:
        raise TypeError(
            f"ovld does not accept generic types except type, Union, Optional, Literal, but not: {t}"
        )

    if isinstance(t, tuple):
        return typing.Union[tuple(normalize_type(member, fn) for member in t)]
    if isinstance(t, DependentType) and not t.bound:
        raise UsageError(
            f"Dependent type {t} has not been given a type bound. Please use Dependent[<bound>, {t}] instead."
        )
    return t
55
+
56
+
57
class MetaMC(type):
    """Metaclass for synthetic classes whose subclass relation is computed.

    ``MetaMC(name, order)`` creates a base-less class named ``name`` whose
    ``order`` attribute is the given callable. ``order(sub)`` returns either
    a bool or a ``TypeRelationship``.
    """

    def __new__(mcls, name, order):
        namespace = {"order": order}
        return super().__new__(mcls, name, (), namespace)

    def __init__(cls, name, order):
        # type.__init__ expects (name, bases, dict); accept the two-argument
        # form used by __new__ and do nothing else.
        pass

    def __typeorder__(cls, other):
        relation = cls.order(other)
        # A plain bool carries no ordering information.
        return NotImplemented if isinstance(relation, bool) else relation

    def __subclasscheck__(cls, sub):
        outcome = cls.order(sub)
        if not isinstance(outcome, TypeRelationship):
            return outcome
        return outcome.order in (Order.MORE, Order.SAME)
77
+
78
+
79
def class_check(condition):
    """Build a class whose subclasses are exactly the classes satisfying ``condition``.

    For example, a dataclass is a subclass of
    ``class_check(dataclasses.is_dataclass)``, and a class whose name starts
    with "X" is a subclass of
    ``class_check(lambda cls: cls.__name__.startswith("X"))``.

    Arguments:
        condition: A callable taking a class and returning True or False
            depending on whether it matches. Its ``__name__`` becomes the
            name of the generated class.
    """
    name = condition.__name__
    return MetaMC(name, condition)
91
+
92
+
93
def parametrized_class_check(fn):
    """Return a parametrized class checker.

    ``parametrized_class_check(fn)[X]`` builds a class whose subclass check
    calls ``fn(cls, X)``. A tuple subscript ``[X, Y]`` calls
    ``fn(cls, X, Y)``.

    Arguments:
        fn: A callable taking a class and one or more extra arguments, and
            returning True/False (or a TypeRelationship) depending on
            whether the class matches.
    """

    def _describe(param):
        # Classes are shown by name, anything else by repr.
        return param.__name__ if isinstance(param, type) else repr(param)

    class _C:
        def __class_getitem__(_, arg):
            params = arg if isinstance(arg, tuple) else (arg,)
            label = ", ".join(_describe(p) for p in params)
            return MetaMC(f"{fn.__name__}[{label}]", lambda sub: fn(sub, *params))

    _C.__name__ = fn.__name__
    _C.__qualname__ = fn.__qualname__
    return _C
121
+
122
+
123
+ def _getcls(ref):
124
+ module, *parts = ref.split(".")
125
+ curr = __import__(module)
126
+ for part in parts:
127
+ curr = getattr(curr, part)
128
+ return curr
129
+
130
+
131
class Deferred:
    """Refer to a class from an external module without importing it.

    For instance, ``Deferred["numpy.ndarray"]`` matches instances of
    numpy.ndarray, but it does not import numpy. When tested against a
    class whose ``__module__`` starts with ``numpy``, the referenced class
    is fetched and a normal issubclass check is performed. Any other class
    does not match.

    If the top-level module is already loaded, ``Deferred`` returns the
    class directly.

    Arguments:
        ref: A dotted path, starting with a module name, to the class.
    """

    def __class_getitem__(cls, ref):
        module, _rest = ref.split(".", 1)
        if module in sys.modules:
            return _getcls(ref)

        def check(candidate):
            qualified = getattr(candidate, "__module__", None)
            top = qualified.split(".", 1)[0] if qualified else None
            if top != module:
                return False
            return issubclass(candidate, _getcls(ref))

        return MetaMC(f"Deferred[{ref}]", check)
160
+
161
+
162
@parametrized_class_check
def Exactly(cls, base_cls):
    """Match the class but not its subclasses.

    The relationship matches only when ``cls`` is ``base_cls``. Its order is
    ``Order.LESS`` in that case and ``typeorder(base_cls, cls)`` otherwise.
    """
    same = cls is base_cls
    order = Order.LESS if same else typeorder(base_cls, cls)
    return TypeRelationship(order=order, matches=same)
169
+
170
+
171
@parametrized_class_check
def StrictSubclass(cls, base_cls):
    """Match subclasses but not the base class.

    Non-class arguments never match.
    """
    if not isinstance(cls, type):
        return False
    return issubclass(cls, base_cls) and cls is not base_cls
179
+
180
+
181
@parametrized_class_check
def Intersection(cls, *classes):
    """Match all classes.

    ``matches`` is true when ``cls`` is a subclass of every member. The order
    is derived from the members' ``typeorder`` with ``cls``:
    - ``Order.NONE`` if every member is unrelated to ``cls``;
    - ``Order.LESS`` if any member is LESS or SAME;
    - ``Order.MORE`` otherwise.
    """
    matches = all(subclasscheck(cls, t) for t in classes)
    orders = [typeorder(t, cls) for t in classes]
    related = [o for o in orders if o is not Order.NONE]
    if not related:
        verdict = Order.NONE
    elif any(o is Order.LESS or o is Order.SAME for o in related):
        verdict = Order.LESS
    else:
        verdict = Order.MORE
    return TypeRelationship(verdict, matches=matches)
192
+
193
+
194
@parametrized_class_check
def HasMethod(cls, method_name):
    """Match classes that have a specific method (any attribute of that name)."""
    sentinel = object()
    return getattr(cls, method_name, sentinel) is not sentinel
198
+
199
+
200
@runtime_checkable
@dataclass
class Dataclass(Protocol):
    """Structural type matching dataclasses.

    A class is considered a subclass when it defines both
    ``__dataclass_fields__`` and ``__dataclass_params__``, the attributes
    that ``@dataclass`` sets.
    """

    @classmethod
    def __subclasshook__(cls, subclass):
        required = ("__dataclass_fields__", "__dataclass_params__")
        return all(hasattr(subclass, attr) for attr in required)
208
+
209
+
210
# Public names exported by `from ovld.types import *`.
__all__ = [
    "Dataclass",
    "Deferred",
    "Exactly",
    "HasMethod",
    "Intersection",
    "StrictSubclass",
    "class_check",
    "parametrized_class_check",
]