ovld 0.4.5__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ovld/__init__.py +23 -1
- ovld/codegen.py +303 -0
- ovld/core.py +62 -349
- ovld/dependent.py +24 -72
- ovld/medley.py +408 -0
- ovld/mro.py +6 -3
- ovld/py.typed +0 -0
- ovld/recode.py +99 -165
- ovld/signatures.py +275 -0
- ovld/typemap.py +40 -38
- ovld/types.py +47 -44
- ovld/utils.py +55 -18
- ovld/version.py +1 -1
- {ovld-0.4.5.dist-info → ovld-0.5.0.dist-info}/METADATA +62 -16
- ovld-0.5.0.dist-info/RECORD +18 -0
- {ovld-0.4.5.dist-info → ovld-0.5.0.dist-info}/WHEEL +1 -1
- ovld-0.4.5.dist-info/RECORD +0 -14
- {ovld-0.4.5.dist-info → ovld-0.5.0.dist-info}/licenses/LICENSE +0 -0
ovld/recode.py
CHANGED
@@ -1,41 +1,23 @@
|
|
1
1
|
import ast
|
2
2
|
import inspect
|
3
|
-
import linecache
|
4
3
|
import textwrap
|
5
|
-
from ast import _splitlines_no_ff as splitlines
|
6
4
|
from functools import reduce
|
7
5
|
from itertools import count
|
8
6
|
from types import CodeType, FunctionType
|
9
7
|
|
10
|
-
from .
|
11
|
-
|
12
|
-
|
13
|
-
|
14
|
-
|
15
|
-
|
16
|
-
"call_next() can only be used from inside an @ovld-registered function."
|
8
|
+
from .codegen import (
|
9
|
+
Code,
|
10
|
+
instantiate_code,
|
11
|
+
rename_code,
|
12
|
+
rename_function,
|
13
|
+
transfer_function,
|
17
14
|
)
|
15
|
+
from .utils import MISSING, NameDatabase, SpecialForm, UsageError, subtler_type
|
18
16
|
|
19
|
-
|
20
|
-
|
21
|
-
|
22
|
-
|
23
|
-
code = compile(source=code, filename=virtual_file, mode="exec")
|
24
|
-
glb = {**inject}
|
25
|
-
exec(code, glb, glb)
|
26
|
-
return glb[symbol]
|
27
|
-
|
28
|
-
|
29
|
-
# # Previous version: generate a temporary file
|
30
|
-
# def instantiate_code(symbol, code, inject={}):
|
31
|
-
# tf = tempfile.NamedTemporaryFile("w")
|
32
|
-
# _tempfiles.append(tf)
|
33
|
-
# tf.write(code)
|
34
|
-
# tf.flush()
|
35
|
-
# glb = runpy.run_path(tf.name)
|
36
|
-
# rval = glb[symbol]
|
37
|
-
# rval.__globals__.update(inject)
|
38
|
-
# return rval
|
17
|
+
recurse = SpecialForm("recurse")
|
18
|
+
call_next = SpecialForm("call_next")
|
19
|
+
resolve = SpecialForm("resolve")
|
20
|
+
current_code = SpecialForm("current_code")
|
39
21
|
|
40
22
|
|
41
23
|
dispatch_template = """
|
@@ -53,6 +35,13 @@ return {mvar}({posargs})
|
|
53
35
|
"""
|
54
36
|
|
55
37
|
|
38
|
+
def generate_checking_code(typ):
|
39
|
+
if hasattr(typ, "codegen"):
|
40
|
+
return typ.codegen()
|
41
|
+
else:
|
42
|
+
return Code("isinstance($arg, $this)", this=typ)
|
43
|
+
|
44
|
+
|
56
45
|
def generate_dispatch(ov, arganal):
|
57
46
|
def join(li, sep=", ", trail=False):
|
58
47
|
li = [x for x in li if x]
|
@@ -171,12 +160,11 @@ def generate_dispatch(ov, arganal):
|
|
171
160
|
|
172
161
|
|
173
162
|
def generate_dependent_dispatch(tup, handlers, next_call, slf, name, err, nerr):
|
174
|
-
from .dependent import
|
163
|
+
from .dependent import is_dependent
|
175
164
|
|
176
165
|
def to_dict(tup):
|
177
166
|
return dict(
|
178
|
-
entry if isinstance(entry, tuple) else (i, entry)
|
179
|
-
for i, entry in enumerate(tup)
|
167
|
+
entry if isinstance(entry, tuple) else (i, entry) for i, entry in enumerate(tup)
|
180
168
|
)
|
181
169
|
|
182
170
|
def argname(x):
|
@@ -185,12 +173,6 @@ def generate_dependent_dispatch(tup, handlers, next_call, slf, name, err, nerr):
|
|
185
173
|
def argprovide(x):
|
186
174
|
return f"ARG{x}" if isinstance(x, int) else f"{x}={x}"
|
187
175
|
|
188
|
-
def codegen(typ, arg):
|
189
|
-
cg = generate_checking_code(typ)
|
190
|
-
return cg.template.format(
|
191
|
-
arg=arg, **{k: ndb[v] for k, v in cg.substitutions.items()}
|
192
|
-
)
|
193
|
-
|
194
176
|
tup = to_dict(tup)
|
195
177
|
handlers = [(h, to_dict(types)) for h, types in handlers]
|
196
178
|
ndb = NameDatabase(default_name="INJECT")
|
@@ -209,18 +191,14 @@ def generate_dependent_dispatch(tup, handlers, next_call, slf, name, err, nerr):
|
|
209
191
|
if not possibilities:
|
210
192
|
if getattr(focus, "keyable_type", False):
|
211
193
|
all_keys = [
|
212
|
-
{key: h for key in types[k].get_keys()}
|
213
|
-
for h, types in handlers
|
194
|
+
{key: h for key in types[k].get_keys()} for h, types in handlers
|
214
195
|
]
|
215
196
|
keyed = reduce(lambda a, b: {**a, **b}, all_keys)
|
216
|
-
if (
|
217
|
-
len(keyed) == sum(map(len, all_keys))
|
218
|
-
and len(featured) < 4
|
219
|
-
):
|
197
|
+
if len(keyed) == sum(map(len, all_keys)) and len(featured) < 4:
|
220
198
|
exclusive = True
|
221
199
|
keyexpr = None
|
222
200
|
else:
|
223
|
-
keyexpr = focus.keygen().
|
201
|
+
keyexpr = focus.keygen().sub(arg=Code(argname(k))).fill(ndb)
|
224
202
|
|
225
203
|
else:
|
226
204
|
exclusive = getattr(focus, "exclusive_type", False)
|
@@ -230,7 +208,10 @@ def generate_dependent_dispatch(tup, handlers, next_call, slf, name, err, nerr):
|
|
230
208
|
if len(relevant) > 1:
|
231
209
|
# The keyexpr method only works if there is only one condition to check.
|
232
210
|
keyexpr = keyed = None
|
233
|
-
codes = [
|
211
|
+
codes = [
|
212
|
+
generate_checking_code(types[k]).sub(arg=Code(argname(k))).fill(ndb)
|
213
|
+
for k in relevant
|
214
|
+
]
|
234
215
|
conj = " and ".join(codes)
|
235
216
|
if not conj: # pragma: no cover
|
236
217
|
# Not sure if this can happen
|
@@ -279,9 +260,7 @@ def generate_dependent_dispatch(tup, handlers, next_call, slf, name, err, nerr):
|
|
279
260
|
|
280
261
|
inject["FALLTHROUGH"] = (next_call and next_call[0]) or raise_error
|
281
262
|
|
282
|
-
fn = instantiate_code(
|
283
|
-
symbol="__DEPENDENT_DISPATCH__", code=code, inject=inject
|
284
|
-
)
|
263
|
+
fn = instantiate_code(symbol="__DEPENDENT_DISPATCH__", code=code, inject=inject)
|
285
264
|
return rename_function(fn, name)
|
286
265
|
|
287
266
|
|
@@ -311,14 +290,10 @@ class Conformer:
|
|
311
290
|
if new_code is None:
|
312
291
|
return
|
313
292
|
ofn = self.orig_fn
|
314
|
-
new_fn =
|
315
|
-
|
316
|
-
|
317
|
-
ofn.__name__,
|
318
|
-
ofn.__defaults__,
|
319
|
-
ofn.__closure__,
|
293
|
+
new_fn = transfer_function(
|
294
|
+
func=ofn,
|
295
|
+
code=new_code,
|
320
296
|
)
|
321
|
-
new_fn.__annotations__ = ofn.__annotations__
|
322
297
|
|
323
298
|
self.ovld.register(new_fn)
|
324
299
|
|
@@ -329,90 +304,45 @@ class Conformer:
|
|
329
304
|
self.code = new_code
|
330
305
|
|
331
306
|
|
332
|
-
def rename_code(co, newname): # pragma: no cover
|
333
|
-
if hasattr(co, "replace"):
|
334
|
-
if hasattr(co, "co_qualname"):
|
335
|
-
return co.replace(co_name=newname, co_qualname=newname)
|
336
|
-
else:
|
337
|
-
return co.replace(co_name=newname)
|
338
|
-
else:
|
339
|
-
return type(co)(
|
340
|
-
co.co_argcount,
|
341
|
-
co.co_kwonlyargcount,
|
342
|
-
co.co_nlocals,
|
343
|
-
co.co_stacksize,
|
344
|
-
co.co_flags,
|
345
|
-
co.co_code,
|
346
|
-
co.co_consts,
|
347
|
-
co.co_names,
|
348
|
-
co.co_varnames,
|
349
|
-
co.co_filename,
|
350
|
-
newname,
|
351
|
-
co.co_firstlineno,
|
352
|
-
co.co_lnotab,
|
353
|
-
co.co_freevars,
|
354
|
-
co.co_cellvars,
|
355
|
-
)
|
356
|
-
|
357
|
-
|
358
|
-
def rename_function(fn, newname):
|
359
|
-
"""Create a copy of the function with a different name."""
|
360
|
-
newcode = rename_code(fn.__code__, newname)
|
361
|
-
new_fn = FunctionType(
|
362
|
-
newcode, fn.__globals__, newname, fn.__defaults__, fn.__closure__
|
363
|
-
)
|
364
|
-
new_fn.__kwdefaults__ = fn.__kwdefaults__
|
365
|
-
new_fn.__annotations__ = fn.__annotations__
|
366
|
-
return new_fn
|
367
|
-
|
368
|
-
|
369
307
|
class NameConverter(ast.NodeTransformer):
|
370
|
-
def __init__(
|
371
|
-
self,
|
372
|
-
anal,
|
373
|
-
recurse_sym,
|
374
|
-
call_next_sym,
|
375
|
-
ovld_mangled,
|
376
|
-
map_mangled,
|
377
|
-
code_mangled,
|
378
|
-
):
|
308
|
+
def __init__(self, anal, special_syms, mapping):
|
379
309
|
self.analysis = anal
|
380
|
-
self.
|
381
|
-
self.
|
382
|
-
self.ovld_mangled =
|
383
|
-
self.map_mangled =
|
384
|
-
self.code_mangled =
|
310
|
+
self.syms = special_syms
|
311
|
+
self.mapping = mapping
|
312
|
+
self.ovld_mangled = mapping[recurse]
|
313
|
+
self.map_mangled = mapping[resolve]
|
314
|
+
self.code_mangled = mapping[current_code]
|
385
315
|
self.count = count()
|
386
316
|
|
317
|
+
def is_special(self, name, *kinds):
|
318
|
+
return any(name in self.syms[kind] for kind in kinds)
|
319
|
+
|
387
320
|
def visit_Name(self, node):
|
388
|
-
if node.id
|
321
|
+
if node.id in self.mapping:
|
389
322
|
return ast.copy_location(
|
390
323
|
old_node=node,
|
391
|
-
new_node=ast.Name(self.
|
324
|
+
new_node=ast.Name(self.mapping[node.id], ctx=node.ctx),
|
392
325
|
)
|
393
|
-
elif node.id
|
326
|
+
elif self.is_special(node.id, call_next):
|
394
327
|
raise UsageError("call_next should be called right away")
|
395
328
|
else:
|
396
329
|
return node
|
397
330
|
|
398
331
|
def visit_Call(self, node):
|
399
|
-
if not isinstance(node.func, ast.Name) or
|
400
|
-
|
401
|
-
self.call_next_sym,
|
332
|
+
if not isinstance(node.func, ast.Name) or not self.is_special(
|
333
|
+
node.func.id, recurse, call_next
|
402
334
|
):
|
403
335
|
return self.generic_visit(node)
|
404
336
|
|
405
337
|
if any(isinstance(arg, ast.Starred) for arg in node.args):
|
406
338
|
return self.generic_visit(node)
|
407
339
|
|
408
|
-
cn = node.func.id
|
340
|
+
cn = self.is_special(node.func.id, call_next)
|
409
341
|
tmp = f"__TMP{next(self.count)}_"
|
410
342
|
|
411
343
|
def _make_lookup_call(key, arg):
|
412
344
|
name = (
|
413
|
-
"__SUBTLER_TYPE"
|
414
|
-
if self.analysis.lookup_for(key) is subtler_type
|
415
|
-
else "type"
|
345
|
+
"__SUBTLER_TYPE" if self.analysis.lookup_for(key) is subtler_type else "type"
|
416
346
|
)
|
417
347
|
value = ast.NamedExpr(
|
418
348
|
target=ast.Name(id=f"{tmp}{key}", ctx=ast.Store()),
|
@@ -426,9 +356,7 @@ class NameConverter(ast.NodeTransformer):
|
|
426
356
|
)
|
427
357
|
|
428
358
|
# type index for positional arguments
|
429
|
-
type_parts = [
|
430
|
-
_make_lookup_call(i, arg) for i, arg in enumerate(node.args)
|
431
|
-
]
|
359
|
+
type_parts = [_make_lookup_call(i, arg) for i, arg in enumerate(node.args)]
|
432
360
|
|
433
361
|
# type index for keyword arguments
|
434
362
|
type_parts += [
|
@@ -460,10 +388,7 @@ class NameConverter(ast.NodeTransformer):
|
|
460
388
|
new_node = ast.Call(
|
461
389
|
func=method,
|
462
390
|
args=selfarg
|
463
|
-
+ [
|
464
|
-
ast.Name(id=f"{tmp}{i}", ctx=ast.Load())
|
465
|
-
for i, arg in enumerate(node.args)
|
466
|
-
],
|
391
|
+
+ [ast.Name(id=f"{tmp}{i}", ctx=ast.Load()) for i, arg in enumerate(node.args)],
|
467
392
|
keywords=[
|
468
393
|
ast.keyword(
|
469
394
|
arg=kw.arg,
|
@@ -475,37 +400,41 @@ class NameConverter(ast.NodeTransformer):
|
|
475
400
|
return ast.copy_location(old_node=node, new_node=new_node)
|
476
401
|
|
477
402
|
|
478
|
-
def _search_names(co,
|
479
|
-
|
480
|
-
if
|
481
|
-
|
482
|
-
|
483
|
-
|
484
|
-
|
485
|
-
|
486
|
-
|
487
|
-
|
488
|
-
for
|
489
|
-
|
403
|
+
def _search_names(co, specials, glb, closure=None):
|
404
|
+
def _search(co, values, glb, closure=None):
|
405
|
+
if isinstance(co, CodeType):
|
406
|
+
if closure is not None:
|
407
|
+
for varname, cell in zip(co.co_freevars, closure):
|
408
|
+
try:
|
409
|
+
if any(cell.cell_contents is v for v in values):
|
410
|
+
yield varname
|
411
|
+
except ValueError: # cell is empty
|
412
|
+
pass
|
413
|
+
for name in co.co_names:
|
414
|
+
if any(glb.get(name, None) is v for v in values):
|
415
|
+
yield name
|
416
|
+
else:
|
417
|
+
for ct in co.co_consts:
|
418
|
+
yield from _search(ct, values, glb)
|
419
|
+
|
420
|
+
return {k: list(_search(co, v, glb, closure)) for k, v in specials.items()}
|
490
421
|
|
491
422
|
|
492
423
|
def adapt_function(fn, ovld, newname):
|
493
424
|
"""Create a copy of the function with a different name."""
|
494
|
-
|
495
|
-
|
496
|
-
|
497
|
-
(recurse, ovld, ovld.dispatch),
|
498
|
-
|
499
|
-
|
500
|
-
|
501
|
-
|
502
|
-
|
503
|
-
|
425
|
+
syms = _search_names(
|
426
|
+
fn.__code__,
|
427
|
+
{
|
428
|
+
recurse: (recurse, ovld, ovld.dispatch),
|
429
|
+
call_next: (call_next,),
|
430
|
+
resolve: (resolve,),
|
431
|
+
current_code: (current_code,),
|
432
|
+
},
|
433
|
+
fn.__globals__,
|
434
|
+
fn.__closure__,
|
504
435
|
)
|
505
|
-
if
|
506
|
-
return recode(
|
507
|
-
fn, ovld, rec_syms and rec_syms[0], cn_syms and cn_syms[0], newname
|
508
|
-
)
|
436
|
+
if any(syms.values()):
|
437
|
+
return recode(fn, ovld, syms, newname)
|
509
438
|
else:
|
510
439
|
return rename_function(fn, newname)
|
511
440
|
|
@@ -536,7 +465,7 @@ def closure_wrap(tree, fname, names):
|
|
536
465
|
return ast.Module(body=[wrap], type_ignores=[])
|
537
466
|
|
538
467
|
|
539
|
-
def recode(fn, ovld,
|
468
|
+
def recode(fn, ovld, syms, newname):
|
540
469
|
ovld_mangled = f"___OVLD{ovld.id}"
|
541
470
|
map_mangled = f"___MAP{ovld.id}"
|
542
471
|
code_mangled = f"___CODE{next(_current)}"
|
@@ -550,14 +479,22 @@ def recode(fn, ovld, recurse_sym, call_next_sym, newname):
|
|
550
479
|
" avoid calling recurse()/call_next()"
|
551
480
|
)
|
552
481
|
tree = ast.parse(textwrap.dedent(src))
|
482
|
+
|
483
|
+
mapping = {
|
484
|
+
recurse: ovld_mangled,
|
485
|
+
resolve: map_mangled,
|
486
|
+
current_code: code_mangled,
|
487
|
+
}
|
488
|
+
for special, symbols in syms.items():
|
489
|
+
for sym in symbols:
|
490
|
+
if special in mapping:
|
491
|
+
mapping[sym] = mapping[special]
|
553
492
|
new = NameConverter(
|
554
493
|
anal=ovld.argument_analysis,
|
555
|
-
|
556
|
-
|
557
|
-
ovld_mangled=ovld_mangled,
|
558
|
-
map_mangled=map_mangled,
|
559
|
-
code_mangled=code_mangled,
|
494
|
+
special_syms=syms,
|
495
|
+
mapping=mapping,
|
560
496
|
).visit(tree)
|
497
|
+
|
561
498
|
new.body[0].decorator_list = []
|
562
499
|
if fn.__closure__:
|
563
500
|
new = closure_wrap(new.body[0], "irrelevant", fn.__code__.co_freevars)
|
@@ -568,17 +505,14 @@ def recode(fn, ovld, recurse_sym, call_next_sym, newname):
|
|
568
505
|
res = [x for x in res.co_consts if isinstance(x, CodeType)][0]
|
569
506
|
(*_, new_code) = [ct for ct in res.co_consts if isinstance(ct, CodeType)]
|
570
507
|
new_closure = tuple(
|
571
|
-
[
|
572
|
-
fn.__closure__[fn.__code__.co_freevars.index(name)]
|
573
|
-
for name in new_code.co_freevars
|
574
|
-
]
|
508
|
+
[fn.__closure__[fn.__code__.co_freevars.index(name)] for name in new_code.co_freevars]
|
575
509
|
)
|
576
|
-
new_fn =
|
577
|
-
|
510
|
+
new_fn = transfer_function(
|
511
|
+
func=fn,
|
512
|
+
code=rename_code(new_code, newname),
|
513
|
+
name=newname,
|
514
|
+
closure=new_closure,
|
578
515
|
)
|
579
|
-
new_fn.__kwdefaults__ = fn.__kwdefaults__
|
580
|
-
new_fn.__annotations__ = fn.__annotations__
|
581
|
-
new_fn = rename_function(new_fn, newname)
|
582
516
|
new_fn.__globals__["__SUBTLER_TYPE"] = subtler_type
|
583
517
|
new_fn.__globals__[ovld_mangled] = ovld.dispatch
|
584
518
|
new_fn.__globals__[map_mangled] = ovld.map
|
ovld/signatures.py
ADDED
@@ -0,0 +1,275 @@
|
|
1
|
+
"""Utilities to deal with function signatures."""
|
2
|
+
|
3
|
+
import inspect
|
4
|
+
import itertools
|
5
|
+
import typing
|
6
|
+
from collections import OrderedDict, defaultdict
|
7
|
+
from dataclasses import dataclass, field
|
8
|
+
from functools import cached_property
|
9
|
+
from types import GenericAlias
|
10
|
+
|
11
|
+
from .types import normalize_type
|
12
|
+
from .utils import MISSING, subtler_type
|
13
|
+
|
14
|
+
|
15
|
+
class LazySignature(inspect.Signature):
|
16
|
+
def __init__(self, ovld):
|
17
|
+
super().__init__([])
|
18
|
+
self.ovld = ovld
|
19
|
+
|
20
|
+
def replace(
|
21
|
+
self, *, parameters=inspect._void, return_annotation=inspect._void
|
22
|
+
): # pragma: no cover
|
23
|
+
if parameters is inspect._void:
|
24
|
+
parameters = self.parameters.values()
|
25
|
+
|
26
|
+
if return_annotation is inspect._void:
|
27
|
+
return_annotation = self._return_annotation
|
28
|
+
|
29
|
+
return inspect.Signature(parameters, return_annotation=return_annotation)
|
30
|
+
|
31
|
+
@property
|
32
|
+
def parameters(self):
|
33
|
+
anal = self.ovld.analyze_arguments()
|
34
|
+
parameters = []
|
35
|
+
if anal.is_method:
|
36
|
+
parameters.append(
|
37
|
+
inspect.Parameter(
|
38
|
+
name="self",
|
39
|
+
kind=inspect._POSITIONAL_ONLY,
|
40
|
+
)
|
41
|
+
)
|
42
|
+
parameters += [
|
43
|
+
inspect.Parameter(
|
44
|
+
name=p,
|
45
|
+
kind=inspect._POSITIONAL_ONLY,
|
46
|
+
)
|
47
|
+
for p in anal.strict_positional_required
|
48
|
+
]
|
49
|
+
parameters += [
|
50
|
+
inspect.Parameter(
|
51
|
+
name=p,
|
52
|
+
kind=inspect._POSITIONAL_ONLY,
|
53
|
+
default=MISSING,
|
54
|
+
)
|
55
|
+
for p in anal.strict_positional_optional
|
56
|
+
]
|
57
|
+
parameters += [
|
58
|
+
inspect.Parameter(
|
59
|
+
name=p,
|
60
|
+
kind=inspect._POSITIONAL_OR_KEYWORD,
|
61
|
+
)
|
62
|
+
for p in anal.positional_required
|
63
|
+
]
|
64
|
+
parameters += [
|
65
|
+
inspect.Parameter(
|
66
|
+
name=p,
|
67
|
+
kind=inspect._POSITIONAL_OR_KEYWORD,
|
68
|
+
default=MISSING,
|
69
|
+
)
|
70
|
+
for p in anal.positional_optional
|
71
|
+
]
|
72
|
+
parameters += [
|
73
|
+
inspect.Parameter(
|
74
|
+
name=p,
|
75
|
+
kind=inspect._KEYWORD_ONLY,
|
76
|
+
)
|
77
|
+
for p in anal.keyword_required
|
78
|
+
]
|
79
|
+
parameters += [
|
80
|
+
inspect.Parameter(
|
81
|
+
name=p,
|
82
|
+
kind=inspect._KEYWORD_ONLY,
|
83
|
+
default=MISSING,
|
84
|
+
)
|
85
|
+
for p in anal.keyword_optional
|
86
|
+
]
|
87
|
+
return OrderedDict({p.name: p for p in parameters})
|
88
|
+
|
89
|
+
|
90
|
+
@dataclass(frozen=True)
|
91
|
+
class Arginfo:
|
92
|
+
position: typing.Optional[int]
|
93
|
+
name: typing.Optional[str]
|
94
|
+
required: bool
|
95
|
+
ann: type
|
96
|
+
|
97
|
+
@cached_property
|
98
|
+
def is_complex(self):
|
99
|
+
return isinstance(self.ann, GenericAlias)
|
100
|
+
|
101
|
+
@cached_property
|
102
|
+
def canonical(self):
|
103
|
+
return self.name if self.position is None else self.position
|
104
|
+
|
105
|
+
|
106
|
+
@dataclass(frozen=True)
|
107
|
+
class Signature:
|
108
|
+
types: tuple
|
109
|
+
return_type: type
|
110
|
+
req_pos: int
|
111
|
+
max_pos: int
|
112
|
+
req_names: frozenset
|
113
|
+
vararg: bool
|
114
|
+
priority: float
|
115
|
+
tiebreak: int = 0
|
116
|
+
is_method: bool = False
|
117
|
+
arginfo: list[Arginfo] = field(default_factory=list, hash=False, compare=False)
|
118
|
+
|
119
|
+
@classmethod
|
120
|
+
def extract(cls, fn):
|
121
|
+
typelist = []
|
122
|
+
sig = inspect.signature(fn)
|
123
|
+
max_pos = 0
|
124
|
+
req_pos = 0
|
125
|
+
req_names = set()
|
126
|
+
is_method = False
|
127
|
+
|
128
|
+
arginfo = []
|
129
|
+
for i, (name, param) in enumerate(sig.parameters.items()):
|
130
|
+
if name == "self" or (name == "cls" and getattr(fn, "specializer", False)):
|
131
|
+
if i != 0: # pragma: no cover
|
132
|
+
raise Exception(
|
133
|
+
f"Argument name '{name}' marks a method and must always be in the first position."
|
134
|
+
)
|
135
|
+
is_method = True
|
136
|
+
continue
|
137
|
+
pos = nm = None
|
138
|
+
ann = normalize_type(param.annotation, fn)
|
139
|
+
if param.kind is inspect._POSITIONAL_ONLY:
|
140
|
+
pos = i - is_method
|
141
|
+
typelist.append(ann)
|
142
|
+
req_pos += param.default is inspect._empty
|
143
|
+
max_pos += 1
|
144
|
+
elif param.kind is inspect._POSITIONAL_OR_KEYWORD:
|
145
|
+
pos = i - is_method
|
146
|
+
nm = param.name
|
147
|
+
typelist.append(ann)
|
148
|
+
req_pos += param.default is inspect._empty
|
149
|
+
max_pos += 1
|
150
|
+
elif param.kind is inspect._KEYWORD_ONLY:
|
151
|
+
nm = param.name
|
152
|
+
typelist.append((param.name, ann))
|
153
|
+
if param.default is inspect._empty:
|
154
|
+
req_names.add(param.name)
|
155
|
+
elif param.kind is inspect._VAR_POSITIONAL:
|
156
|
+
raise TypeError("ovld does not support *args")
|
157
|
+
elif param.kind is inspect._VAR_KEYWORD:
|
158
|
+
raise TypeError("ovld does not support **kwargs")
|
159
|
+
arginfo.append(
|
160
|
+
Arginfo(
|
161
|
+
position=pos,
|
162
|
+
name=nm,
|
163
|
+
required=param.default is inspect._empty,
|
164
|
+
ann=ann,
|
165
|
+
)
|
166
|
+
)
|
167
|
+
|
168
|
+
return cls(
|
169
|
+
types=tuple(typelist),
|
170
|
+
return_type=normalize_type(sig.return_annotation, fn),
|
171
|
+
req_pos=req_pos,
|
172
|
+
max_pos=max_pos,
|
173
|
+
req_names=frozenset(req_names),
|
174
|
+
vararg=False,
|
175
|
+
is_method=is_method,
|
176
|
+
priority=None,
|
177
|
+
arginfo=arginfo,
|
178
|
+
)
|
179
|
+
|
180
|
+
|
181
|
+
class ArgumentAnalyzer:
|
182
|
+
def __init__(self):
|
183
|
+
self.name_to_positions = defaultdict(set)
|
184
|
+
self.position_to_names = defaultdict(set)
|
185
|
+
self.counts = defaultdict(lambda: [0, 0])
|
186
|
+
self.complex_transforms = set()
|
187
|
+
self.total = 0
|
188
|
+
self.is_method = None
|
189
|
+
self.done = False
|
190
|
+
|
191
|
+
def add(self, fn):
|
192
|
+
self.done = False
|
193
|
+
sig = Signature.extract(fn)
|
194
|
+
self.complex_transforms.update(arg.canonical for arg in sig.arginfo if arg.is_complex)
|
195
|
+
for arg in sig.arginfo:
|
196
|
+
if arg.position is not None:
|
197
|
+
self.position_to_names[arg.position].add(arg.name)
|
198
|
+
if arg.name is not None:
|
199
|
+
self.name_to_positions[arg.name].add(arg.canonical)
|
200
|
+
|
201
|
+
cnt = self.counts[arg.canonical]
|
202
|
+
cnt[0] += arg.required
|
203
|
+
cnt[1] += 1
|
204
|
+
|
205
|
+
self.total += 1
|
206
|
+
|
207
|
+
if self.is_method is None:
|
208
|
+
self.is_method = sig.is_method
|
209
|
+
elif self.is_method != sig.is_method: # pragma: no cover
|
210
|
+
raise TypeError(
|
211
|
+
"Some, but not all registered methods define `self`. It should be all or none."
|
212
|
+
)
|
213
|
+
|
214
|
+
def compile(self):
|
215
|
+
if self.done:
|
216
|
+
return
|
217
|
+
for name, pos in self.name_to_positions.items():
|
218
|
+
if len(pos) != 1:
|
219
|
+
if all(isinstance(p, int) for p in pos):
|
220
|
+
raise TypeError(
|
221
|
+
f"Argument '{name}' is declared in different positions by different methods. The same argument name should always be in the same position unless it is strictly positional."
|
222
|
+
)
|
223
|
+
else:
|
224
|
+
raise TypeError(
|
225
|
+
f"Argument '{name}' is declared in a positional and keyword setting by different methods. It should be either."
|
226
|
+
)
|
227
|
+
|
228
|
+
p_to_n = [list(names) for _, names in sorted(self.position_to_names.items())]
|
229
|
+
|
230
|
+
positional = list(
|
231
|
+
itertools.takewhile(
|
232
|
+
lambda names: len(names) == 1 and isinstance(names[0], str),
|
233
|
+
reversed(p_to_n),
|
234
|
+
)
|
235
|
+
)
|
236
|
+
positional.reverse()
|
237
|
+
strict_positional = p_to_n[: len(p_to_n) - len(positional)]
|
238
|
+
|
239
|
+
assert strict_positional + positional == p_to_n
|
240
|
+
|
241
|
+
self.strict_positional_required = [
|
242
|
+
f"ARG{pos + 1}"
|
243
|
+
for pos, _ in enumerate(strict_positional)
|
244
|
+
if self.counts[pos][0] == self.total
|
245
|
+
]
|
246
|
+
self.strict_positional_optional = [
|
247
|
+
f"ARG{pos + 1}"
|
248
|
+
for pos, _ in enumerate(strict_positional)
|
249
|
+
if self.counts[pos][0] != self.total
|
250
|
+
]
|
251
|
+
|
252
|
+
self.positional_required = [
|
253
|
+
names[0]
|
254
|
+
for pos, names in enumerate(positional)
|
255
|
+
if self.counts[pos + len(strict_positional)][0] == self.total
|
256
|
+
]
|
257
|
+
self.positional_optional = [
|
258
|
+
names[0]
|
259
|
+
for pos, names in enumerate(positional)
|
260
|
+
if self.counts[pos + len(strict_positional)][0] != self.total
|
261
|
+
]
|
262
|
+
|
263
|
+
keywords = [
|
264
|
+
name for _, (name,) in self.name_to_positions.items() if not isinstance(name, int)
|
265
|
+
]
|
266
|
+
self.keyword_required = [
|
267
|
+
name for name in keywords if self.counts[name][0] == self.total
|
268
|
+
]
|
269
|
+
self.keyword_optional = [
|
270
|
+
name for name in keywords if self.counts[name][0] != self.total
|
271
|
+
]
|
272
|
+
self.done = True
|
273
|
+
|
274
|
+
def lookup_for(self, key):
|
275
|
+
return subtler_type if key in self.complex_transforms else type
|