ovld 0.3.9__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ovld/recode.py CHANGED
@@ -1,11 +1,293 @@
1
+ import ast
1
2
  import inspect
2
- from types import FunctionType
3
+ import linecache
4
+ import textwrap
5
+ from ast import _splitlines_no_ff as splitlines
6
+ from functools import reduce
7
+ from itertools import count
8
+ from types import CodeType, FunctionType
3
9
 
4
- recurse = object()
10
+ from .dependent import DependentType
11
+ from .utils import Unusable, UsageError
12
+
13
# Placeholders that user code refers to by name. adapt_function() looks for
# globals/closure variables bound to these objects and, when found, recode()
# rewrites the calls into direct dispatch. Outside of that rewriting, the
# Unusable wrapper (see .utils) presumably reports the message given here.
recurse = Unusable(
    "recurse() can only be used from inside "
    "an @ovld-registered function."
)
call_next = Unusable(
    "call_next() can only be used from inside "
    "an @ovld-registered function."
)
19
+
20
+
21
# Source template for the generated __DISPATCH__ function (see
# generate_dispatch). {args} is the parameter list. {inits} and {body} are
# statement lists that generate_dispatch joins with "\n    " so that every
# statement lands at the function-body indentation. {call} is already
# indented by generate_dispatch, hence it starts at column 0 here.
dispatch_template = """
from ovld.utils import MISSING

def __DISPATCH__(self, {args}):
    {inits}
    {body}
{call}
"""


# Source template for one "look up the handler in self.map, then call it"
# step. {lookup} is the (possibly single-element, trailing-comma) key tuple
# content and {posargs} the forwarded arguments.
call_template = """
method = self.map[({lookup})]
return method({posargs})
"""
35
+
36
+
37
def instantiate_code(symbol, code, inject=None):
    """Compile and execute generated source code, returning one of its globals.

    The source is registered in ``linecache`` under a virtual filename so
    that tracebacks (and other linecache users) can display generated lines.

    Arguments:
        symbol: Name of the global defined by ``code`` that should be returned.
        code: Python source text to compile and execute.
        inject: Optional mapping of extra globals made available to ``code``.
            It is copied before execution and never modified.

    Returns:
        The object bound to ``symbol`` after executing ``code``.
    """
    virtual_file = f"<ovld{hash(code)}>"
    # An mtime of None makes linecache.checkcache() leave this entry alone
    # instead of trying to stat a file that does not exist on disk.
    linecache.cache[virtual_file] = (None, None, splitlines(code), virtual_file)
    code = compile(source=code, filename=virtual_file, mode="exec")
    # Fresh namespace each time: the caller's mapping is never mutated.
    glb = {**(inject or {})}
    exec(code, glb, glb)
    return glb[symbol]
44
+
45
+
46
+ # # Previous version: generate a temporary file
47
+ # def instantiate_code(symbol, code, inject={}):
48
+ # tf = tempfile.NamedTemporaryFile("w")
49
+ # _tempfiles.append(tf)
50
+ # tf.write(code)
51
+ # tf.flush()
52
+ # glb = runpy.run_path(tf.name)
53
+ # rval = glb[symbol]
54
+ # rval.__globals__.update(inject)
55
+ # return rval
56
+
57
+
58
def generate_dispatch(arganal):
    """Generate the ``__DISPATCH__`` entry point for an ovld.

    Python source is produced from ``dispatch_template`` and
    ``call_template`` according to the argument layout returned by
    ``arganal.compile()``, then compiled with ``instantiate_code``. The
    generated function takes ``self``, builds a key tuple from its arguments,
    fetches the handler with ``self.map[...]`` and calls it, passing
    ``self.obj`` first when ``arganal.is_method`` is true.

    Optional arguments default to ``MISSING``. When an optional positional
    argument is ``MISSING``, the lookup and the call only use the positional
    arguments that precede it.

    Arguments:
        arganal: Argument analysis object; this function uses its
            ``compile()``, ``lookup_for()``, ``is_method`` and
            ``complex_transforms`` members.

    Returns:
        The generated ``__DISPATCH__`` function.
    """

    def join(li, sep=", ", trail=False):
        # Join the non-empty entries of li. With trail=True, a single entry
        # gets a trailing comma so that "(x)" becomes the 1-tuple "(x,)".
        li = [x for x in li if x]
        rval = sep.join(li)
        if len(li) == 1 and trail:
            rval += ","
        return rval

    # Argument names, grouped (judging from how they are used below) as:
    # strictly-positional required/optional, positional required/optional,
    # keyword-only required/optional.
    # NOTE(review): grouping inferred from usage — confirm against
    # ArgumentAnalysis.compile().
    spr, spo, pr, po, kr, ko = arganal.compile()

    # One-time statements emitted at the top of the generated body.
    inits = set()

    # Become "**KWARGS" / "*TARGS" only if there are optional keyword args.
    kwargsstar = ""
    targsstar = ""

    args = []  # parameter list of the generated function
    body = [""]  # extra body statements (empty entries are dropped by join)
    posargs = ["self.obj" if arganal.is_method else ""]  # forwarded arguments
    lookup = []  # key expressions used to index self.map

    i = 0  # positional index passed to arganal.lookup_for

    for name in spr + spo:
        if name in spr:
            args.append(name)
        else:
            args.append(f"{name}=MISSING")
        posargs.append(name)
        lookup.append(f"{arganal.lookup_for(i)}({name})")
        i += 1

    if len(po) <= 1:
        # If there is more than one non-strictly-positional optional argument,
        # then all positional arguments are made strictly positional, because
        # if e.g. x and y are optional we want x==MISSING to imply that
        # y==MISSING, but that only works if y cannot be provided as a keyword
        # argument. Otherwise the "/" goes right after the strict ones.
        args.append("/")

    for name in pr + po:
        if name in pr:
            args.append(name)
        else:
            args.append(f"{name}=MISSING")
        posargs.append(name)
        lookup.append(f"{arganal.lookup_for(i)}({name})")
        i += 1

    if len(po) > 1:
        args.append("/")

    if kr or ko:
        args.append("*")

    # Required keyword-only arguments are always part of the key, as
    # (name, type-or-transform) pairs.
    for name in kr:
        lookup_fn = (
            "self.map.transform"
            if name in arganal.complex_transforms
            else "type"
        )
        args.append(f"{name}")
        posargs.append(f"{name}={name}")
        lookup.append(f"({name!r}, {lookup_fn}({name}))")

    # Optional keyword-only arguments are collected at runtime: only those
    # actually provided go into KWARGS (forwarded) and TARGS (key entries).
    for name in ko:
        args.append(f"{name}=MISSING")
        kwargsstar = "**KWARGS"
        targsstar = "*TARGS"
        inits.add("KWARGS = {}")
        inits.add("TARGS = []")
        body.append(f"if {name} is not MISSING:")
        body.append(f"    KWARGS[{name!r}] = {name}")
        body.append(
            f"    TARGS.append(({name!r}, {arganal.lookup_for(name)}({name})))"
        )

    posargs.append(kwargsstar)
    lookup.append(targsstar)

    # Call used when every argument is present.
    fullcall = call_template.format(
        lookup=join(lookup, trail=True),
        posargs=join(posargs),
    )

    calls = []
    if spo or po:
        req = len(spr + pr)
        for i, arg in enumerate(spo + po):
            # Early exit: if this optional argument is MISSING, dispatch on
            # the arguments before it only. posargs has one extra leading
            # entry (self.obj or ""), hence the +1.
            # NOTE(review): these truncated slices also drop the keyword
            # entries (kr lookups/forwards, "*TARGS", "**KWARGS") that follow
            # the positional ones — confirm compile() never combines optional
            # positional arguments with keyword-only ones.
            call = call_template.format(
                lookup=join(lookup[: req + i], trail=True),
                posargs=join(posargs[: req + i + 1]),
            )
            call = textwrap.indent(call, "    ")
            calls.append(f"\nif {arg} is MISSING:{call}")
    calls.append(fullcall)

    code = dispatch_template.format(
        inits=join(inits, sep="\n    "),
        args=join(args),
        body=join(body, sep="\n    "),
        call=textwrap.indent("".join(calls), "    "),
    )
    return instantiate_code("__DISPATCH__", code)
160
+
161
+
162
class GenSym:
    """Produce source expressions for values used in generated code.

    Simple literals are inlined as their ``repr``. Any other value is stored
    in ``variables`` under a fresh name ``<prefix><n>``, which the caller
    injects as a global when executing the generated code.
    """

    # Exact types whose repr() is valid source evaluating to an equal value.
    # Subclasses (e.g. IntEnum members) are deliberately excluded because
    # their repr is usually not valid source code.
    _LITERAL_TYPES = (bool, int, float, str)
    # Non-finite floats print as names that are not defined in generated
    # code, so they have to be injected rather than inlined.
    _NON_LITERAL_REPRS = ("inf", "-inf", "nan")

    def __init__(self, prefix):
        self.prefix = prefix
        self.count = count()
        self.variables = {}

    def add(self, value):
        """Return a source expression that evaluates to ``value``.

        Either the literal ``repr(value)`` or a new symbol name registered in
        ``self.variables``.
        """
        if type(value) in self._LITERAL_TYPES:
            text = repr(value)
            if text not in self._NON_LITERAL_REPRS:
                return text
        symbol = f"{self.prefix}{next(self.count)}"
        self.variables[symbol] = value
        return symbol
174
+
175
+
176
def generate_dependent_dispatch(tup, handlers, next_call, slf, name, err, nerr):
    """Generate a function dispatching among handlers with dependent types.

    One of three code shapes is generated:

    * keyed: a dict maps a key computed from one argument to its handler
      (only when the keys of the handlers do not overlap);
    * exclusive: conditions are tested in order and the first match wins;
    * generic: all conditions are evaluated; exactly one match calls its
      handler, no match falls through and several matches raise ``err``.

    Arguments:
        tup: Argument layout. Plain entries are positional (indexed by
            position); ``(name, type)`` tuples are keyword arguments.
        handlers: List of ``(handler, types)`` pairs, ``types`` having the
            same layout as ``tup``.
        next_call: Falsy, or a sequence whose first element is called when
            no handler matches.
        slf: Source text prepended to the generated parameter list and to
            every forwarded call (presumably ``"self, "`` or ``""`` — TODO
            confirm against callers).
        name: Name given to the generated function.
        err: Exception raised when several handlers match.
        nerr: Exception raised when no handler matches and there is no
            usable ``next_call``.

    Returns:
        The generated dispatch function, renamed to ``name``.
    """

    def to_dict(tup):
        return dict(
            entry if isinstance(entry, tuple) else (i, entry)
            for i, entry in enumerate(tup)
        )

    def argname(x):
        return f"ARG{x}" if isinstance(x, int) else x

    def argprovide(x):
        return f"ARG{x}" if isinstance(x, int) else f"{x}={x}"

    def codegen(typ, arg):
        cg = typ.codegen()
        return cg.template.format(
            arg=arg, **{k: gen.add(v) for k, v in cg.substitutions.items()}
        )

    tup = to_dict(tup)
    handlers = [(h, to_dict(types)) for h, types in handlers]
    gen = GenSym(prefix="INJECT")
    conjs = []

    exclusive = False
    keyexpr = None
    keyed = None
    for k in tup:
        featured = set(types[k] for h, types in handlers)
        if len(featured) == len(handlers):
            possibilities = set(type(t) for t in featured)
            focus = possibilities.pop()
            # Possibilities is now empty if only one type of DependentType

            if not possibilities:
                if getattr(focus, "keyable_type", False):
                    all_keys = [
                        {key: h for key in types[k].get_keys()}
                        for h, types in handlers
                    ]
                    keyed = reduce(lambda a, b: {**a, **b}, all_keys)
                    if len(keyed) != sum(map(len, all_keys)):
                        # Some key is accepted by several handlers. A dict
                        # lookup would silently pick the last one, whereas
                        # the generic path reports the ambiguity (err).
                        keyexpr = keyed = None
                    elif len(featured) < 4:
                        # Disjoint keys: at most one handler can match, and
                        # with few handlers sequential tests are enough.
                        exclusive = True
                        keyexpr = None
                    else:
                        keyexpr = focus.keygen().format(arg=argname(k))

                else:
                    exclusive = getattr(focus, "exclusive_type", False)

    for i, (h, types) in enumerate(handlers):
        relevant = [k for k in tup if isinstance(types[k], DependentType)]
        if len(relevant) > 1:
            # The keyexpr method only works if there is only one condition to check.
            keyexpr = keyed = None
        codes = [codegen(types[k], argname(k)) for k in relevant]
        conj = " and ".join(codes)
        if not conj:  # pragma: no cover
            # Not sure if this can happen
            conj = "True"
        conjs.append(conj)

    argspec = ", ".join(argname(x) for x in tup)
    argcall = ", ".join(argprovide(x) for x in tup)

    body = []
    if keyexpr:
        body.append(f"HANDLER = {gen.add(keyed)}.get({keyexpr}, FALLTHROUGH)")
        body.append(f"return HANDLER({slf}{argcall})")

    elif exclusive:
        for i, conj in enumerate(conjs):
            body.append(f"if {conj}: return HANDLER{i}({slf}{argcall})")
        body.append(f"return FALLTHROUGH({slf}{argcall})")

    else:
        for i, conj in enumerate(conjs):
            body.append(f"MATCH{i} = {conj}")

        summation = " + ".join(f"MATCH{i}" for i in range(len(handlers)))
        body.append(f"SUMMATION = {summation}")
        body.append("if SUMMATION == 1:")
        for i, (h, types) in enumerate(handlers):
            body.append(f"    if MATCH{i}: return HANDLER{i}({slf}{argcall})")
        body.append("elif SUMMATION == 0:")
        body.append(f"    return FALLTHROUGH({slf}{argcall})")
        body.append("else:")
        body.append(f"    raise {gen.add(err)}")

    body_text = textwrap.indent("\n".join(body), "    ")
    code = f"def __DEPENDENT_DISPATCH__({slf}{argspec}):\n{body_text}"

    inject = gen.variables
    for i, (h, types) in enumerate(handlers):
        inject[f"HANDLER{i}"] = h

    def raise_error(*args, **kwargs):
        raise nerr

    # Fall through to the next dispatch layer if there is one, else raise nerr.
    inject["FALLTHROUGH"] = (next_call and next_call[0]) or raise_error

    fn = instantiate_code(
        symbol="__DEPENDENT_DISPATCH__", code=code, inject=inject
    )
    return rename_function(fn, name)
284
+
285
+
286
# Counter used by recode() to give each rewritten function's code object a
# unique global name (___CODE<n>).
_current = count()
5
287
 
6
288
 
7
289
  class Conformer:
8
- __slots__ = ("code", "orig_fn", "renamed_fn", "ovld", "code2")
290
+ __slots__ = ("code", "orig_fn", "renamed_fn", "ovld")
9
291
 
10
292
  def __init__(self, ovld, orig_fn, renamed_fn):
11
293
  self.ovld = ovld
@@ -21,21 +303,22 @@ class Conformer:
21
303
  new_fn = None
22
304
  new_code = new
23
305
 
24
- if new_code is None:
25
- self.ovld.unregister(self.orig_fn)
26
-
27
- elif new_fn is None: # pragma: no cover
28
- # Not entirely sure if this ever happens
29
- self.renamed_fn.__code__ = new_code
30
-
31
- elif inspect.signature(self.orig_fn) != inspect.signature(new_fn):
32
- self.ovld.unregister(self.orig_fn)
33
- self.ovld.register(new_fn)
306
+ self.ovld.unregister(self.orig_fn)
34
307
 
35
- else:
36
- self.renamed_fn.__code__ = rename_code(
37
- new_code, self.renamed_fn.__code__.co_name
308
+ if new_fn is None: # pragma: no cover
309
+ if new_code is None:
310
+ return
311
+ ofn = self.orig_fn
312
+ new_fn = FunctionType(
313
+ new_code,
314
+ ofn.__globals__,
315
+ ofn.__name__,
316
+ ofn.__defaults__,
317
+ ofn.__closure__,
38
318
  )
319
+ new_fn.__annotations__ = ofn.__annotations__
320
+
321
+ self.ovld.register(new_fn)
39
322
 
40
323
  from codefind import code_registry
41
324
 
@@ -76,5 +359,224 @@ def rename_function(fn, newname):
76
359
  new_fn = FunctionType(
77
360
  newcode, fn.__globals__, newname, fn.__defaults__, fn.__closure__
78
361
  )
362
+ new_fn.__kwdefaults__ = fn.__kwdefaults__
363
+ new_fn.__annotations__ = fn.__annotations__
364
+ return new_fn
365
+
366
+
367
class NameConverter(ast.NodeTransformer):
    """Rewrite references to the ``recurse`` and ``call_next`` symbols.

    A call such as ``recurse(a, k=b)`` becomes a direct lookup in the
    ovld's dispatch map, roughly::

        (__TMPM0 := OVLD.map)[(type(__TMP0_0 := a), ("k", type(__TMP0_k := b)))](
            __TMP0_0, k=__TMP0_k
        )

    where ``OVLD`` is ``ovld_mangled``. When ``anal.lookup_for(key)`` is
    ``"self.map.transform"``, ``__TMPM<n>.transform`` replaces ``type``.
    ``call_next`` calls get the ``code_mangled`` name prepended to the key.
    When ``anal.is_method`` is true, the handler is bound with
    ``__get__(self)``. Any other use of ``recurse`` becomes a plain reference
    to ``ovld_mangled``; any other use of ``call_next`` raises UsageError.

    Arguments:
        anal: Argument analysis; ``lookup_for`` and ``is_method`` are used.
        recurse_sym: Name bound to ``recurse`` in the function (a value that
            matches no name disables the rewrite).
        call_next_sym: Name bound to ``call_next`` (same convention).
        ovld_mangled: Global name under which the ovld is reachable.
        code_mangled: Global name used as first key element for call_next.
    """

    def __init__(
        self, anal, recurse_sym, call_next_sym, ovld_mangled, code_mangled
    ):
        self.analysis = anal
        self.recurse_sym = recurse_sym
        self.call_next_sym = call_next_sym
        self.ovld_mangled = ovld_mangled
        self.code_mangled = code_mangled
        # Every rewritten call gets its own temporaries, otherwise nested
        # calls such as recurse(a, recurse(b)) overwrite the outer call's
        # walrus-bound arguments before the outer call reads them.
        self._call_ids = count()

    def visit_Name(self, node):
        if node.id == self.recurse_sym:
            return ast.copy_location(
                old_node=node,
                new_node=ast.Name(self.ovld_mangled, ctx=node.ctx),
            )
        elif node.id == self.call_next_sym:
            raise UsageError("call_next should be called right away")
        else:
            return node

    def visit_Call(self, node):
        if not isinstance(node.func, ast.Name) or node.func.id not in (
            self.recurse_sym,
            self.call_next_sym,
        ):
            return self.generic_visit(node)

        # *args and **kwargs cannot be turned into a static key tuple: leave
        # such calls alone (recurse then goes through the ovld object itself).
        if any(isinstance(arg, ast.Starred) for arg in node.args) or any(
            kw.arg is None for kw in node.keywords
        ):
            return self.generic_visit(node)

        cn = node.func.id == self.call_next_sym
        call_id = next(self._call_ids)
        map_tmp = f"__TMPM{call_id}"

        def tmp_name(key):
            return f"__TMP{call_id}_{key}"

        def _make_lookup_call(key, arg):
            # type(__TMP := arg) or __TMPM.transform(__TMP := arg): computes
            # the key element while saving the argument for the actual call.
            value = ast.NamedExpr(
                target=ast.Name(id=tmp_name(key), ctx=ast.Store()),
                value=self.visit(arg),
            )
            if self.analysis.lookup_for(key) == "self.map.transform":
                func = ast.Attribute(
                    value=ast.Name(id=map_tmp, ctx=ast.Load()),
                    attr="transform",
                    ctx=ast.Load(),
                )
            else:
                func = ast.Name(id="type", ctx=ast.Load())
            return ast.Call(
                func=func,
                args=[value],
                keywords=[],
            )

        # type index for positional arguments
        type_parts = [
            _make_lookup_call(i, arg) for i, arg in enumerate(node.args)
        ]

        # type index for keyword arguments
        type_parts += [
            ast.Tuple(
                elts=[
                    ast.Constant(value=kw.arg),
                    _make_lookup_call(kw.arg, kw.value),
                ],
                ctx=ast.Load(),
            )
            for kw in node.keywords
        ]

        if cn:
            type_parts.insert(0, ast.Name(id=self.code_mangled, ctx=ast.Load()))
        # The map is bound first (subscript value is evaluated before the
        # key), so the transform lookups above can refer to it.
        method = ast.Subscript(
            value=ast.NamedExpr(
                target=ast.Name(id=map_tmp, ctx=ast.Store()),
                value=ast.Attribute(
                    value=ast.Name(id=self.ovld_mangled, ctx=ast.Load()),
                    attr="map",
                    ctx=ast.Load(),
                ),
            ),
            slice=ast.Tuple(
                elts=type_parts,
                ctx=ast.Load(),
            ),
            ctx=ast.Load(),
        )
        if self.analysis.is_method:
            method = ast.Call(
                func=ast.Attribute(
                    value=method,
                    attr="__get__",
                    ctx=ast.Load(),
                ),
                args=[ast.Name(id="self", ctx=ast.Load())],
                keywords=[],
            )

        new_node = ast.Call(
            func=method,
            args=[
                ast.Name(id=tmp_name(i), ctx=ast.Load())
                for i in range(len(node.args))
            ],
            keywords=[
                ast.keyword(
                    arg=kw.arg,
                    value=ast.Name(id=tmp_name(kw.arg), ctx=ast.Load()),
                )
                for kw in node.keywords
            ],
        )
        return ast.copy_location(old_node=node, new_node=new_node)
479
+
480
+
481
+ def _search_names(co, values, glb, closure=None):
482
+ if isinstance(co, CodeType):
483
+ if closure is not None:
484
+ for varname, cell in zip(co.co_freevars, closure):
485
+ if any(cell.cell_contents is v for v in values):
486
+ yield varname
487
+ for name in co.co_names:
488
+ if any(glb.get(name, None) is v for v in values):
489
+ yield name
490
+ else:
491
+ for ct in co.co_consts:
492
+ yield from _search_names(ct, values, glb)
493
+
494
+
495
def adapt_function(fn, ovld, newname):
    """Prepare ``fn`` for registration in ``ovld`` under the name ``newname``.

    If ``fn`` refers to ``recurse``, to ``ovld`` itself or to ``call_next``
    (through a global or a closure variable), it is recompiled by
    :func:`recode` so that those calls are rewritten into direct dispatch.
    Otherwise it is simply copied under the new name by
    :func:`rename_function`.
    """
    rec_syms = list(
        _search_names(
            fn.__code__, (recurse, ovld), fn.__globals__, fn.__closure__
        )
    )
    cn_syms = list(
        _search_names(fn.__code__, (call_next,), fn.__globals__, fn.__closure__)
    )
    if not (rec_syms or cn_syms):
        return rename_function(fn, newname)
    # Only the first matching name of each kind is rewritten; None stands
    # for "no such symbol" and never matches a name in the AST.
    return recode(
        fn,
        ovld,
        rec_syms[0] if rec_syms else None,
        cn_syms[0] if cn_syms else None,
        newname,
    )
511
+
512
+
513
def closure_wrap(tree, fname, names):
    """Nest ``tree`` inside a factory function whose parameters are ``names``.

    Once compiled, names in ``names`` that the nested definition uses become
    free variables of it, so its code object can later be paired with an
    existing closure. The factory is called ``##create_closure`` (not a valid
    identifier, so it cannot clash with user code) and returns ``fname``.

    Arguments:
        tree: Statement node (typically an ``ast.FunctionDef``) to wrap.
        fname: Name returned by the factory.
        names: Parameter names of the factory.

    Returns:
        An ``ast.Module`` whose only statement is the factory definition.
    """
    params = [ast.arg(arg=param) for param in names]
    signature = ast.arguments(
        posonlyargs=[],
        args=params,
        vararg=None,
        kwonlyargs=[],
        kw_defaults=[],
        kwarg=None,
        defaults=[],
    )
    returner = ast.Return(ast.Name(id=fname, ctx=ast.Load()))
    factory = ast.FunctionDef(
        name="##create_closure",
        args=signature,
        body=[tree, returner],
        decorator_list=[],
        returns=None,
    )
    # Give the factory the wrapped node's position, then fill in the rest.
    ast.copy_location(factory, tree)
    ast.fix_missing_locations(factory)
    return ast.Module(body=[factory], type_ignores=[])
537
+
538
+
539
def recode(fn, ovld, recurse_sym, call_next_sym, newname):
    """Recompile ``fn`` with its recurse()/call_next() calls rewritten.

    The source of ``fn`` is parsed and transformed by :class:`NameConverter`,
    with decorators removed. It is recompiled with the original filename and
    line numbers, and turned into a new function named ``newname`` that
    shares ``fn``'s globals, defaults, annotations and (relevant) closure
    cells. The ovld and the new code object are published in those globals
    under mangled names referenced by the rewritten code.

    Arguments:
        fn: Function to rewrite.
        ovld: The ovld ``fn`` is being registered into (its ``id`` and
            ``argument_analysis`` are used).
        recurse_sym: Name bound to ``recurse`` in ``fn`` (or a non-matching
            value when absent).
        call_next_sym: Name bound to ``call_next`` in ``fn`` (same).
        newname: Name of the resulting function.

    Returns:
        The new function.

    Raises:
        OSError: If the source code of ``fn`` cannot be retrieved.
    """
    ovld_mangled = f"___OVLD{ovld.id}"
    code_mangled = f"___CODE{next(_current)}"
    try:
        src = inspect.getsource(fn)
    except OSError as exc:  # pragma: no cover
        raise OSError(
            f"ovld is unable to rewrite {fn} because it cannot read its source code."
            " It may be an issue with __pycache__, so try to either change the source"
            " to force a refresh, or remove __pycache__ altogether. If that does not work,"
            " avoid calling recurse()/call_next()"
        ) from exc
    tree = ast.parse(textwrap.dedent(src))
    new = NameConverter(
        anal=ovld.argument_analysis,
        recurse_sym=recurse_sym,
        call_next_sym=call_next_sym,
        ovld_mangled=ovld_mangled,
        code_mangled=code_mangled,
    ).visit(tree)
    # Drop decorators so that compiling the module does not re-apply them.
    new.body[0].decorator_list = []
    if fn.__closure__:
        # Wrap in a factory so that the free variables stay free variables
        # of the compiled function (the factory itself is never executed).
        new = closure_wrap(new.body[0], "irrelevant", fn.__code__.co_freevars)
    ast.fix_missing_locations(new)
    # getsource() starts at the first decorator line, which is also
    # co_firstlineno: shift so tracebacks point into the original file.
    ast.increment_lineno(new, fn.__code__.co_firstlineno - 1)
    res = compile(new, mode="exec", filename=fn.__code__.co_filename)
    if fn.__closure__:
        # Descend into the factory's code object.
        res = [x for x in res.co_consts if isinstance(x, CodeType)][0]
    (*_, new_code) = [ct for ct in res.co_consts if isinstance(ct, CodeType)]
    # Reuse the original cells, in the order the new code expects them.
    new_closure = tuple(
        [
            fn.__closure__[fn.__code__.co_freevars.index(name)]
            for name in new_code.co_freevars
        ]
    )
    new_fn = FunctionType(
        new_code, fn.__globals__, newname, fn.__defaults__, new_closure
    )
    new_fn.__kwdefaults__ = fn.__kwdefaults__
    new_fn.__annotations__ = fn.__annotations__
    new_fn = rename_function(new_fn, newname)
    new_fn.__globals__[ovld_mangled] = ovld
    new_fn.__globals__[code_mangled] = new_fn.__code__
    return new_fn