ygg 0.1.20__py3-none-any.whl → 0.1.23__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,563 @@
1
+ from __future__ import annotations
2
+
3
+ import base64
4
+ import dis
5
+ import importlib
6
+ import inspect
7
+ import json
8
+ import struct
9
+ import sys
10
+ import textwrap
11
+ import zlib
12
+ from dataclasses import dataclass
13
+ from pathlib import Path
14
+ from typing import Any, Callable, Dict, Optional, Set, Tuple, TypeVar, Union
15
+
16
+ import dill
17
+
18
# Only the serde class is part of the public API; everything prefixed with
# an underscore below is an internal helper.
__all__ = ["CallableSerde"]

# Type variable bound to CallableSerde; used in the annotations of the
# alternate constructors (`from_callable`, `load`) so they are typed as
# returning the class they were called on.
T = TypeVar("T", bound="CallableSerde")


# ---------- internal helpers ----------

# 3-byte prefix that marks a framed result blob. `_encode_result_blob`
# prepends it and `_decode_result_blob` checks for it with `startswith`;
# blobs without this prefix are treated as raw dill bytes.
_MAGIC = b"CS1"  # CallableSerde framing v1
# Bit in the frame's flags byte: set when the data after the header is
# zlib-compressed.
_FLAG_COMPRESSED = 1
27
+
28
+
29
+ def _resolve_attr_chain(mod: Any, qualname: str) -> Any:
30
+ obj = mod
31
+ for part in qualname.split("."):
32
+ obj = getattr(obj, part)
33
+ return obj
34
+
35
+
36
+ def _find_pkg_root_from_file(file_path: Path) -> Optional[Path]:
37
+ """
38
+ Walk up parents while __init__.py exists.
39
+ Return the directory that should be on sys.path (parent of top package dir).
40
+ """
41
+ file_path = file_path.resolve()
42
+ d = file_path.parent
43
+
44
+ top_pkg_dir = None
45
+ while (d / "__init__.py").is_file():
46
+ top_pkg_dir = d
47
+ d = d.parent
48
+
49
+ return top_pkg_dir.parent if top_pkg_dir else None
50
+
51
+
52
+ def _callable_file_line(fn: Callable[..., Any]) -> Tuple[Optional[str], Optional[int]]:
53
+ file = None
54
+ line = None
55
+ try:
56
+ file = inspect.getsourcefile(fn) or inspect.getfile(fn)
57
+ except Exception:
58
+ file = None
59
+ if file:
60
+ try:
61
+ _, line = inspect.getsourcelines(fn)
62
+ except Exception:
63
+ line = None
64
+ return file, line
65
+
66
+
67
+ def _referenced_global_names(fn: Callable[..., Any]) -> Set[str]:
68
+ """
69
+ Names that the function *actually* resolves from globals/namespaces at runtime.
70
+ Uses bytecode to avoid shipping random junk.
71
+ """
72
+ names: Set[str] = set()
73
+ try:
74
+ for ins in dis.get_instructions(fn):
75
+ if ins.opname in ("LOAD_GLOBAL", "LOAD_NAME") and isinstance(ins.argval, str):
76
+ names.add(ins.argval)
77
+ except Exception:
78
+ # fallback: less precise
79
+ try:
80
+ names.update(getattr(fn.__code__, "co_names", ()) or ())
81
+ except Exception:
82
+ pass
83
+
84
+ names.discard("__builtins__")
85
+ return names
86
+
87
+
88
+ def _is_importable_reference(fn: Callable[..., Any]) -> bool:
89
+ mod_name = getattr(fn, "__module__", None)
90
+ qualname = getattr(fn, "__qualname__", None)
91
+ if not mod_name or not qualname:
92
+ return False
93
+ if "<locals>" in qualname:
94
+ return False
95
+ try:
96
+ mod = importlib.import_module(mod_name)
97
+ obj = _resolve_attr_chain(mod, qualname)
98
+ return callable(obj)
99
+ except Exception:
100
+ return False
101
+
102
+
103
+ def _pick_zlib_level(n: int, limit: int) -> int:
104
+ """
105
+ Ramp compression level 1..9 based on how much payload exceeds byte_limit.
106
+ ratio=1 -> level=1
107
+ ratio=4 -> level=9
108
+ clamp beyond.
109
+ """
110
+ ratio = n / max(1, limit)
111
+ x = min(1.0, max(0.0, (ratio - 1.0) / 3.0))
112
+ return max(1, min(9, int(round(1 + 8 * x))))
113
+
114
+
115
+ def _encode_result_blob(raw: bytes, byte_limit: int) -> bytes:
116
+ """
117
+ For small payloads: return raw dill bytes (backwards compat).
118
+ For large payloads: wrap in framed header + zlib-compressed bytes (if beneficial).
119
+ """
120
+ if len(raw) <= byte_limit:
121
+ return raw
122
+
123
+ level = _pick_zlib_level(len(raw), byte_limit)
124
+ compressed = zlib.compress(raw, level)
125
+
126
+ # If compression doesn't help, keep raw
127
+ if len(compressed) >= len(raw):
128
+ return raw
129
+
130
+ # Frame:
131
+ # MAGIC(3) + flags(u8) + orig_len(u32) + level(u8) + data
132
+ header = _MAGIC + struct.pack(">BIB", _FLAG_COMPRESSED, len(raw), level)
133
+ return header + compressed
134
+
135
+
136
def _decode_result_blob(blob: bytes) -> bytes:
    """
    Undo the framing applied to a result blob.

    A blob that does not start with the magic prefix is returned unchanged
    (it is taken to be raw dill bytes). A framed blob has its header parsed;
    the data is zlib-decompressed when the compressed flag is set, otherwise
    returned as-is.

    Raises:
        ValueError: the framed blob is shorter than its header, or the
            decompressed size disagrees with a non-zero recorded length.
    """
    if not blob.startswith(_MAGIC):
        return blob

    # MAGIC(3) + flags(u8) + orig_len(u32) + level(u8)
    header_end = 3 + 1 + 4 + 1
    if len(blob) < header_end:
        raise ValueError("Framed result too short / corrupted.")

    flags, orig_len, _level = struct.unpack(">BIB", blob[3:header_end])
    payload = blob[header_end:]

    if not (flags & _FLAG_COMPRESSED):
        return payload

    inflated = zlib.decompress(payload)
    # A recorded length of 0 means "unknown"; only verify non-zero lengths.
    if orig_len and len(inflated) != orig_len:
        raise ValueError(f"Decompressed length mismatch: got {len(inflated)}, expected {orig_len}")
    return inflated
157
+
158
+
159
+ def _dump_env(
160
+ fn: Callable[..., Any],
161
+ *,
162
+ include_globals: bool,
163
+ include_closure: bool,
164
+ filter_used_globals: bool,
165
+ ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
166
+ """
167
+ Returns (env, meta).
168
+ env is dill-able and contains:
169
+ - "globals": {name: value} (filtered to used names if enabled)
170
+ - "closure": {freevar: value} (capture only; injection not generally safe)
171
+ """
172
+ env: Dict[str, Any] = {}
173
+ meta: Dict[str, Any] = {
174
+ "missing_globals": [],
175
+ "skipped_globals": [],
176
+ "skipped_closure": [],
177
+ "filter_used_globals": bool(filter_used_globals),
178
+ }
179
+
180
+ if include_globals:
181
+ g = getattr(fn, "__globals__", None) or {}
182
+ names = sorted(_referenced_global_names(fn)) if filter_used_globals else sorted(set(g.keys()))
183
+
184
+ env_g: Dict[str, Any] = {}
185
+ for name in names:
186
+ if name not in g:
187
+ meta["missing_globals"].append(name)
188
+ continue
189
+ try:
190
+ dill.dumps(g[name], recurse=True)
191
+ env_g[name] = g[name]
192
+ except Exception:
193
+ meta["skipped_globals"].append(name)
194
+
195
+ if env_g:
196
+ env["globals"] = env_g
197
+
198
+ if include_closure:
199
+ freevars = getattr(getattr(fn, "__code__", None), "co_freevars", ()) or ()
200
+ closure = getattr(fn, "__closure__", None) or ()
201
+
202
+ clo: Dict[str, Any] = {}
203
+ if freevars and closure and len(freevars) == len(closure):
204
+ for name, cell in zip(freevars, closure):
205
+ try:
206
+ val = cell.cell_contents
207
+ dill.dumps(val, recurse=True)
208
+ clo[name] = val
209
+ except Exception:
210
+ meta["skipped_closure"].append(name)
211
+
212
+ if clo:
213
+ env["closure"] = clo
214
+
215
+ return env, meta
216
+
217
+
218
+ # ---------- main class ----------
219
+
220
@dataclass
class CallableSerde:
    """
    Serializable handle around a callable.

    Core field: `fn`
    Serialized/backing fields used when fn isn't present yet.

    kind:
    - "auto": resolve import if possible else dill
    - "import": module + qualname
    - "dill": dill_b64

    Optional env payload:
    - env_b64: dill(base64) of {"globals": {...}, "closure": {...}}

    Typical flow: ``from_callable(fn).dump()`` produces a JSON-friendly dict,
    ``load(d)`` rebuilds a handle from it, and ``materialize()`` (or calling
    the handle) resolves the actual callable. ``to_command`` renders a
    standalone Python script that performs the call in another interpreter,
    and ``parse_command_result`` decodes that script's output.

    NOTE: payloads are decoded with ``dill.loads``, which can execute
    arbitrary code. Only load data that comes from a trusted source.
    """

    fn: Optional[Callable[..., Any]] = None

    _kind: str = "auto"  # "auto" | "import" | "dill"
    _module: Optional[str] = None
    _qualname: Optional[str] = None
    _pkg_root: Optional[str] = None
    _dill_b64: Optional[str] = None

    _env_b64: Optional[str] = None
    _env_meta: Optional[Dict[str, Any]] = None

    # ----- construction -----

    @classmethod
    def from_callable(cls: type[T], x: Union[Callable[..., Any], T]) -> T:
        """Wrap ``x`` in a handle; an existing handle is returned unchanged."""
        if isinstance(x, cls):
            return x
        return cls(fn=x)  # type: ignore[return-value]

    # ----- lazy-ish properties (computed on access) -----

    @property
    def module(self) -> Optional[str]:
        """Stored module name, else ``fn.__module__`` when ``fn`` is present."""
        return self._module or (getattr(self.fn, "__module__", None) if self.fn else None)

    @property
    def qualname(self) -> Optional[str]:
        """Stored qualname, else ``fn.__qualname__`` when ``fn`` is present."""
        return self._qualname or (getattr(self.fn, "__qualname__", None) if self.fn else None)

    @property
    def file(self) -> Optional[str]:
        """Source file of ``fn``, or ``None`` when unknown or ``fn`` is absent."""
        if not self.fn:
            return None
        f, _ = _callable_file_line(self.fn)
        return f

    @property
    def line(self) -> Optional[int]:
        """Starting source line of ``fn``, or ``None`` when unknown or ``fn`` is absent."""
        if not self.fn:
            return None
        _, ln = _callable_file_line(self.fn)
        return ln

    @property
    def pkg_root(self) -> Optional[str]:
        """Stored package root, else the one derived from ``file`` (may be ``None``)."""
        if self._pkg_root:
            return self._pkg_root
        if not self.file:
            return None
        root = _find_pkg_root_from_file(Path(self.file))
        return str(root) if root else None

    @property
    def relpath_from_pkg_root(self) -> Optional[str]:
        """
        ``file`` relative to ``pkg_root``.

        ``None`` when either is unknown; falls back to ``file`` itself when
        the relative path cannot be computed.
        """
        if not self.file or not self.pkg_root:
            return None
        try:
            return str(Path(self.file).resolve().relative_to(Path(self.pkg_root).resolve()))
        except Exception:
            return self.file

    @property
    def importable(self) -> bool:
        """
        Whether the callable can be referenced by module + qualname.

        Without ``fn`` this is a purely syntactic check on the stored names;
        with ``fn`` the reference is actually resolved.
        """
        if self.fn is None:
            return bool(self.module and self.qualname and "<locals>" not in (self.qualname or ""))
        return _is_importable_reference(self.fn)

    # ----- serde API -----

    def dump(
        self,
        *,
        prefer: str = "import",  # "import" | "dill"
        dump_env: str = "none",  # "none" | "globals" | "closure" | "both"
        filter_used_globals: bool = True,
    ) -> Dict[str, Any]:
        """
        Serialize this handle to a plain dict.

        Args:
            prefer: "import" to reference by module/qualname when possible
                (falls back to "dill" when not importable), or "dill" to
                always embed a dill payload.
            dump_env: which parts of the function environment to attach.
            filter_used_globals: limit dumped globals to names the function
                references.

        Returns:
            Dict with "kind", "module", "qualname", "pkg_root", "file",
            "line", "relpath_from_pkg_root", plus "dill_b64" for kind "dill"
            and "env_b64"/"env_meta" when an environment was requested.

        Raises:
            ValueError: a dill payload or environment is needed but ``fn`` is
                not present.

        Side effects: caches the dill payload and the environment on the
        instance (``_dill_b64``, ``_env_b64``, ``_env_meta``).
        """
        kind = prefer
        if kind == "import" and not self.importable:
            kind = "dill"

        out: Dict[str, Any] = {
            "kind": kind,
            "module": self.module,
            "qualname": self.qualname,
            "pkg_root": self.pkg_root,
            "file": self.file,
            "line": self.line,
            "relpath_from_pkg_root": self.relpath_from_pkg_root,
        }

        if kind == "dill":
            if self._dill_b64 is None:
                if self.fn is None:
                    raise ValueError("No callable available to dill-dump.")
                payload = dill.dumps(self.fn, recurse=True)
                self._dill_b64 = base64.b64encode(payload).decode("ascii")
            out["dill_b64"] = self._dill_b64

        if dump_env != "none":
            if self.fn is None:
                raise ValueError("dump_env requested but fn is not present.")
            include_globals = dump_env in ("globals", "both")
            include_closure = dump_env in ("closure", "both")
            env, meta = _dump_env(
                self.fn,
                include_globals=include_globals,
                include_closure=include_closure,
                filter_used_globals=filter_used_globals,
            )
            self._env_meta = meta
            if env:
                self._env_b64 = base64.b64encode(dill.dumps(env, recurse=True)).decode("ascii")
                out["env_b64"] = self._env_b64
            out["env_meta"] = meta

        return out

    @classmethod
    def load(cls: type[T], d: Dict[str, Any], *, add_pkg_root_to_syspath: bool = True) -> T:
        """
        Rebuild a handle from a dict produced by ``dump``.

        The callable itself is not resolved here; call ``materialize`` for
        that. When ``add_pkg_root_to_syspath`` is true and the dict carries a
        "pkg_root" not yet on ``sys.path``, it is inserted at the front.
        """
        obj = cls(
            fn=None,
            _kind=d.get("kind", "auto"),
            _module=d.get("module"),
            _qualname=d.get("qualname"),
            _pkg_root=d.get("pkg_root"),
            _dill_b64=d.get("dill_b64"),
        )
        obj._env_b64 = d.get("env_b64")
        obj._env_meta = d.get("env_meta")

        if add_pkg_root_to_syspath and obj._pkg_root and obj._pkg_root not in sys.path:
            sys.path.insert(0, obj._pkg_root)

        return obj  # type: ignore[return-value]

    def materialize(self, *, add_pkg_root_to_syspath: bool = True) -> Callable[..., Any]:
        """
        Resolve and cache the actual callable.

        Returns ``fn`` directly when already present. Otherwise resolves it by
        import (module + qualname) or from the dill payload, according to the
        stored kind; "auto" picks import when module and a non-local qualname
        are available, else dill.

        Raises:
            ValueError: required fields are missing or the kind is unknown.
            TypeError: the resolved object is not callable.
        """
        if self.fn is not None:
            return self.fn

        if add_pkg_root_to_syspath and self.pkg_root and self.pkg_root not in sys.path:
            sys.path.insert(0, self.pkg_root)

        kind = self._kind
        if kind == "auto":
            kind = "import" if (self.module and self.qualname and "<locals>" not in (self.qualname or "")) else "dill"

        if kind == "import":
            if not self.module or not self.qualname:
                raise ValueError("Missing module/qualname for import load.")
            mod = importlib.import_module(self.module)
            fn = _resolve_attr_chain(mod, self.qualname)
            if not callable(fn):
                raise TypeError("Imported object is not callable.")
            self.fn = fn
            return fn

        if kind == "dill":
            if not self._dill_b64:
                raise ValueError("Missing dill_b64 for dill load.")
            payload = base64.b64decode(self._dill_b64.encode("ascii"))
            # SECURITY: dill.loads executes arbitrary code; trusted input only.
            fn = dill.loads(payload)
            if not callable(fn):
                raise TypeError("Dill payload did not decode to a callable.")
            self.fn = fn
            return fn

        raise ValueError(f"Unknown kind: {kind}")

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Materialize the callable (if needed) and invoke it."""
        fn = self.materialize()
        return fn(*args, **kwargs)

    # ----- command execution bridge -----

    @staticmethod
    def _render_command(result_tag: str, byte_limit: int, serde_json: str, call_payload_b64: str) -> str:
        """Fill the remote-execution script template with the given values."""
        import re  # local import: only needed for rendering

        # NOTE: plain string template + placeholder substitution. No f-string.
        # No brace escaping. The script must run on its own in the remote
        # interpreter, so it avoids annotations that are evaluated at
        # definition time and unsupported on older Pythons (e.g. set[str]).
        template = r"""
import base64, json, sys, struct, zlib, importlib, dis
import dill

RESULT_TAG = __RESULT_TAG__
BYTE_LIMIT = __BYTE_LIMIT__

MAGIC = b"CS1"
FLAG_COMPRESSED = 1

def _resolve_attr_chain(mod, qualname: str):
    obj = mod
    for part in qualname.split("."):
        obj = getattr(obj, part)
    return obj

def _pick_level(n: int, limit: int) -> int:
    ratio = n / max(1, limit)
    x = min(1.0, max(0.0, (ratio - 1.0) / 3.0))
    return max(1, min(9, int(round(1 + 8 * x))))

def _encode_result(raw: bytes, byte_limit: int) -> bytes:
    if len(raw) <= byte_limit:
        return raw
    level = _pick_level(len(raw), byte_limit)
    compressed = zlib.compress(raw, level)
    if len(compressed) >= len(raw):
        return raw
    orig_len = len(raw) if len(raw) <= 0xFFFFFFFF else 0
    header = MAGIC + struct.pack(">BIB", FLAG_COMPRESSED, orig_len, level)
    return header + compressed

def _needed_globals(fn):
    names = set()
    try:
        pending = [fn]
        while pending:
            for ins in dis.get_instructions(pending.pop()):
                if ins.opname in ("LOAD_GLOBAL", "LOAD_NAME") and isinstance(ins.argval, str):
                    names.add(ins.argval)
                elif hasattr(ins.argval, "co_code"):
                    pending.append(ins.argval)
    except Exception:
        try:
            names.update(getattr(fn.__code__, "co_names", ()) or ())
        except Exception:
            pass
    names.discard("__builtins__")
    return names

def _apply_env(fn, env: dict, filter_used: bool):
    if not env:
        return
    g = getattr(fn, "__globals__", None)
    if not isinstance(g, dict):
        return

    env_g = env.get("globals") or {}
    if env_g:
        if filter_used:
            needed = _needed_globals(fn)
            for name in needed:
                if name in env_g:
                    g.setdefault(name, env_g[name])
        else:
            for name, val in env_g.items():
                g.setdefault(name, val)

serde = json.loads(__SERDE_JSON__)

pkg_root = serde.get("pkg_root")
if pkg_root and pkg_root not in sys.path:
    sys.path.insert(0, pkg_root)

kind = serde.get("kind")
if kind == "import":
    mod = importlib.import_module(serde["module"])
    fn = _resolve_attr_chain(mod, serde["qualname"])
elif kind == "dill":
    fn = dill.loads(base64.b64decode(serde["dill_b64"]))
else:
    if serde.get("module") and serde.get("qualname") and "<locals>" not in serde.get("qualname", ""):
        mod = importlib.import_module(serde["module"])
        fn = _resolve_attr_chain(mod, serde["qualname"])
    else:
        fn = dill.loads(base64.b64decode(serde["dill_b64"]))

env_b64 = serde.get("env_b64")
if env_b64:
    env = dill.loads(base64.b64decode(env_b64))
    meta = serde.get("env_meta") or {}
    _apply_env(fn, env, bool(meta.get("filter_used_globals", True)))

args, kwargs = dill.loads(base64.b64decode(__CALL_PAYLOAD_B64__))

res = fn(*args, **kwargs)
raw = dill.dumps(res, recurse=True)
blob = _encode_result(raw, BYTE_LIMIT)

# No f-string. No braces. No drama.
sys.stdout.write(str(RESULT_TAG) + ":" + base64.b64encode(blob).decode("ascii") + "\n")
""".strip()

        substitutions = {
            "__RESULT_TAG__": repr(result_tag),
            "__BYTE_LIMIT__": str(int(byte_limit)),
            "__SERDE_JSON__": repr(serde_json),
            "__CALL_PAYLOAD_B64__": repr(call_payload_b64),
        }
        # Single pass: an already substituted value is never re-scanned, so a
        # value that happens to contain another placeholder stays intact.
        placeholder = re.compile("|".join(re.escape(key) for key in substitutions))
        return placeholder.sub(lambda match: substitutions[match.group(0)], template)

    def to_command(
        self,
        args: Tuple[Any, ...] = (),
        kwargs: Optional[Dict[str, Any]] = None,
        *,
        result_tag: str = "__CALLABLE_SERDE_RESULT__",
        prefer: str = "dill",
        byte_limit: int = 256_000,
        dump_env: str = "none",  # "none" | "globals" | "closure" | "both"
        filter_used_globals: bool = True,
    ) -> str:
        """
        Returns Python code string to execute in another interpreter.
        Prints one line: "{result_tag}:{base64(blob)}"
        where blob is raw dill bytes or framed+zlib.

        Args:
            args, kwargs: call arguments; shipped as a dill payload.
            result_tag: marker that prefixes the result line.
            prefer, dump_env, filter_used_globals: forwarded to ``dump``.
            byte_limit: results larger than this are zlib-compressed by the
                generated script when that makes them smaller.
        """
        kwargs = kwargs or {}

        serde_dict = self.dump(
            prefer=prefer,
            dump_env=dump_env,
            filter_used_globals=filter_used_globals,
        )
        serde_json = json.dumps(serde_dict, ensure_ascii=False)

        call_payload_b64 = base64.b64encode(
            dill.dumps((args, kwargs), recurse=True)
        ).decode("ascii")

        return self._render_command(result_tag, byte_limit, serde_json, call_payload_b64)

    @staticmethod
    def _extract_result_b64(output: str, result_tag: str) -> Optional[str]:
        """
        Return the base64 text following the last result tag in ``output``.

        Lines that start with the tag win (searched from the end). If there is
        none, the tag is also accepted in the middle of a line, which happens
        when earlier output did not end with a newline. ``None`` if absent.
        """
        prefix = f"{result_tag}:"
        lines = output.splitlines()
        for line in reversed(lines):
            if line.startswith(prefix):
                return line[len(prefix):].strip()
        for line in reversed(lines):
            idx = line.rfind(prefix)
            if idx != -1:
                return line[idx + len(prefix):].strip()
        return None

    @staticmethod
    def parse_command_result(output: str, *, result_tag: str = "__CALLABLE_SERDE_RESULT__") -> Any:
        """
        Parse stdout/stderr combined text, find last "{result_tag}:{b64}" line.
        Supports raw dill or framed+zlib compressed payloads.

        Raises:
            ValueError: no tagged result (or an empty one) was found.
        """
        b64 = CallableSerde._extract_result_b64(output, result_tag)
        if not b64:
            raise ValueError(f"Result tag not found in output: {result_tag!r}")

        blob = base64.b64decode(b64.encode("ascii"))
        raw = _decode_result_blob(blob)
        # SECURITY: dill.loads executes arbitrary code; trusted output only.
        return dill.loads(raw)