ygg 0.1.19__py3-none-any.whl → 0.1.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,645 +0,0 @@
1
- # yggdrasil/ser.py
2
- from __future__ import annotations
3
-
4
- import ast
5
- import base64
6
- import builtins
7
- import inspect
8
- import json
9
- import os
10
- import sys
11
- import textwrap
12
- import zlib
13
- from dataclasses import dataclass, field
14
- from pathlib import Path
15
- from typing import Any, Callable, Dict, Iterable, Optional, Tuple
16
-
17
- import dill
18
-
19
# (major, minor) pair identifying a Python interpreter version.
PyVer = Tuple[int, int]


def _pyver() -> PyVer:
    """Return the running interpreter's ``(major, minor)`` version pair."""
    major, minor = sys.version_info[:2]
    return major, minor
25
-
26
-
27
def _b64e(b: bytes) -> str:
    """Encode raw bytes as an ASCII base64 string."""
    encoded = base64.b64encode(b)
    return encoded.decode("ascii")
29
-
30
-
31
def _b64d(s: str) -> bytes:
    """Decode an ASCII base64 string back into raw bytes."""
    ascii_form = s.encode("ascii")
    return base64.b64decode(ascii_form)
33
-
34
-
35
def _is_python_function(obj: Any) -> bool:
    """Return True when ``obj`` is a Python-level function, lambda or bound method."""
    # inspect.isfunction() already covers lambdas: they share the function type.
    if inspect.isfunction(obj):
        return True
    return inspect.ismethod(obj)
37
-
38
-
39
def _safe_dill_dumps(obj: Any) -> Optional[bytes]:
    """
    Best-effort ``dill.dumps(obj, recurse=True)``.

    Returns the serialized bytes, or None when dill raises any Exception.
    """
    try:
        blob = dill.dumps(obj, recurse=True)
    except Exception:
        # Deliberately swallowed: callers treat None as "not serializable".
        return None
    return blob
44
-
45
-
46
def _safe_dill_loads(b: Optional[bytes]) -> Any:
    """
    Deserialize ``b`` with dill.

    Raises ValueError when ``b`` is None; any error raised by dill propagates.
    """
    if b is not None:
        return dill.loads(b)
    raise ValueError("No dill bytes to load")
50
-
51
-
52
def _infer_package_root(fn: Callable[..., Any]) -> Tuple[str, ...]:
    """
    Locate the outermost package directory of the module that defines `fn`.

    Starting from the directory holding the module file, keep moving to the
    parent for as long as that parent contains an __init__.py.
    Example:
        /repo/my_pkg/sub/mod.py -> ("/repo/my_pkg",)
    When the module's own directory has no __init__.py, that directory is
    returned as-is. An empty tuple is returned when the module has no file,
    the file is not a ``.py`` file, or anything raises along the way.
    """

    def _is_package_dir(directory: Path) -> bool:
        return (directory / "__init__.py").exists()

    try:
        module = inspect.getmodule(fn)
        module_file = getattr(module, "__file__", None)
        if not module_file:
            return ()

        module_path = Path(module_file).resolve()
        if module_path.suffix.lower() != ".py":
            return ()

        root = module_path.parent
        if _is_package_dir(root):
            # Climb until the parent stops being a package or we hit the
            # filesystem root (where a path is its own parent).
            while root.parent != root and _is_package_dir(root.parent):
                root = root.parent

        return (str(root),)
    except Exception:
        return ()
92
-
93
-
94
def _capture_exec_env(fn: Callable[..., Any]) -> Dict[str, Dict[str, str]]:
    """
    Snapshot the minimal environment needed to rebuild `fn` via exec(source).

    Returns a dict with two buckets:

    - "globals": every name in fn.__code__.co_names that is present in
      fn.__globals__ and is not also a name in the builtins module
    - "freevars": the closure cell contents, keyed by fn.__code__.co_freevars

    Each value is base64(dill.dumps(value)); values that cannot be dumped
    (and empty closure cells) are silently left out.
    """
    code = fn.__code__
    module_globals = getattr(fn, "__globals__", {}) or {}
    builtin_names = set(dir(builtins))

    def _encode(value: Any) -> Optional[str]:
        blob = _safe_dill_dumps(value)
        return None if blob is None else _b64e(blob)

    captured_globals: Dict[str, str] = {}
    for name in set(code.co_names or ()):
        if name in builtin_names or name not in module_globals:
            continue
        encoded = _encode(module_globals[name])
        if encoded is not None:
            captured_globals[name] = encoded

    captured_freevars: Dict[str, str] = {}
    free_names = code.co_freevars or ()
    cells = fn.__closure__ or ()
    if free_names and cells and len(free_names) == len(cells):
        for name, cell in zip(free_names, cells):
            try:
                value = cell.cell_contents
            except ValueError:
                # Empty cell: the variable was never bound.
                continue
            encoded = _encode(value)
            if encoded is not None:
                captured_freevars[name] = encoded

    return {"globals": captured_globals, "freevars": captured_freevars}
132
-
133
-
134
def _capture_module_imports(fn: Callable[..., Any]) -> str:
    """
    Collect the top-level import statements of the file that defines `fn`.

    Only direct children of the module body are considered, i.e.:
      - `import x` / `import x as y`
      - `from a.b import c` / `from a.b import c as d`

    Duplicates are dropped (first occurrence wins, order preserved). The
    statements are joined with newlines and terminated with a newline; the
    result is "" when there are none, when the source file is missing or is
    not a ``.py`` file, or when anything raises.
    """
    try:
        location = inspect.getsourcefile(fn) or inspect.getfile(fn)
        if not location:
            return ""

        source_path = Path(location)
        if not source_path.exists() or source_path.suffix.lower() != ".py":
            return ""

        text = source_path.read_text(encoding="utf-8")

        # A dict keeps insertion order, so it doubles as an ordered set.
        statements: Dict[str, None] = {}
        for node in ast.parse(text).body:
            if not isinstance(node, (ast.Import, ast.ImportFrom)):
                continue
            segment = ast.get_source_segment(text, node)
            if segment:
                statements.setdefault(segment.strip(), None)

        ordered = list(statements)
        return "\n".join(ordered).strip() + ("\n" if ordered else "")
    except Exception:
        return ""
175
-
176
-
177
def parse_tagged_result(stdout_text: str, result_tag: str) -> Dict[str, Any]:
    """
    Return the JSON payload of the last line in `stdout_text` that starts
    with `result_tag`.

    Expected line shape: <<<RESULT>>>{...json...}

    Raises ValueError when no line carries the tag or when the text after
    the tag is not valid JSON.
    """
    tagged = [ln for ln in stdout_text.splitlines() if ln.startswith(result_tag)]
    if not tagged:
        raise ValueError(f"Result tag {result_tag!r} not found in output")

    body = tagged[-1][len(result_tag) :]
    try:
        return json.loads(body)
    except Exception as e:
        raise ValueError("Tagged result was not valid JSON") from e
192
-
193
-
194
class CommandError(RuntimeError):
    """
    Error reported by the remote Databricks command (ok=false).

    Attributes
    ----------
    error: str
        Short error message from remote.
    traceback: str
        Remote traceback text, "" when none was provided.
    raw: dict
        Full parsed JSON payload from the tagged output line ({} when absent).
    """

    def __init__(self, error: str, traceback: str = "", raw: Optional[Dict[str, Any]] = None):
        self.error = error
        self.traceback = traceback
        self.raw = raw or {}

        # The exception text is the message, followed by the traceback if any.
        if traceback:
            message = f"{error}\n{traceback}"
        else:
            message = error
        super().__init__(message.rstrip())
214
-
215
-
216
class CommandResultParseError(ValueError):
    """
    Signals that command output could not be parsed or decoded.

    The full stdout (when supplied) is kept on ``stdout_text`` for debugging.
    """

    def __init__(self, message: str, stdout_text: Optional[str] = None):
        super().__init__(message)
        self.stdout_text = stdout_text
224
-
225
-
226
@dataclass
class CallableSerdeMixin:
    """
    Bundle a callable with a serialization strategy and Databricks command generation.

    Restore strategy (see ``__setstate__``):

    - Same Python (major, minor): prefer the dill blob.
    - Different Python (major, minor), or dill failure: exec(imports + source)
      in a namespace preloaded with the captured globals/freevars.

    SECURITY NOTE(review): restoring a state, and running a generated command,
    calls ``dill.loads`` and ``exec`` on the serialized payload. Only restore
    or run payloads that come from a trusted source.
    """

    fn: Callable[..., Any]
    # Directory tuple produced by _infer_package_root(); may be empty.
    package_root: Tuple[str, ...] = field(default_factory=tuple)
    # When False, __setstate__ refuses the exec(source) fallback.
    ALLOW_EXEC_SOURCE: bool = True

    @classmethod
    def from_callable(cls, fn: Callable[..., Any]) -> "CallableSerdeMixin":
        """Wrap ``fn``; an object that is already a wrapper is returned unchanged."""
        if isinstance(fn, CallableSerdeMixin):
            return fn
        return cls(fn=fn, package_root=_infer_package_root(fn))

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Invoke the wrapped callable."""
        return self.fn(*args, **kwargs)

    # ---------- shared payload construction ----------
    def _callable_payload(self) -> Dict[str, Any]:
        """
        Describe ``self.fn`` (a Python function/method/lambda) as a plain dict.

        Carries both representations so the receiver can choose: the dill blob
        ("dill_b64", None when dill fails) and the dedented source plus module
        imports and captured environment (for exec-based restore).
        """
        try:
            src = textwrap.dedent(inspect.getsource(self.fn))
        except Exception:
            # Source is optional (e.g. functions defined interactively).
            src = None

        payload: Dict[str, Any] = {
            "__callable__": True,
            "pyver": list(_pyver()),
            "name": getattr(self.fn, "__name__", None),
            "qualname": getattr(self.fn, "__qualname__", None),
            "module": getattr(self.fn, "__module__", None),
            "imports": _capture_module_imports(self.fn),
            "source": src,
            "dill_b64": None,
            "env": _capture_exec_env(self.fn),
        }

        dumped = _safe_dill_dumps(self.fn)
        if dumped is not None:
            payload["dill_b64"] = _b64e(dumped)
        return payload

    # ---------- pickle protocol ----------
    def __getstate__(self) -> Dict[str, Any]:
        """
        Build the pickle state.

        Python functions get the full payload (source + env + optional dill
        blob). Any other callable must be dill-serializable.

        Raises
        ------
        ValueError
            If a non-function callable cannot be dill-serialized.
        """
        if _is_python_function(self.fn):
            fn_state = self._callable_payload()
        else:
            dumped = _safe_dill_dumps(self.fn)
            if dumped is None:
                raise ValueError("Callable object could not be dill-serialized")
            fn_state = {"__callable__": False, "dill_b64": _b64e(dumped)}

        return {
            "fn": fn_state,
            "package_root": tuple(self.package_root),
            "ALLOW_EXEC_SOURCE": bool(self.ALLOW_EXEC_SOURCE),
        }

    def __setstate__(self, state: Dict[str, Any]) -> None:
        """
        Restore from a state produced by ``__getstate__``.

        Order of attempts for function payloads:
        1. dill, only when the payload's (major, minor) matches this interpreter;
        2. exec(imports) + exec(source) in a namespace preloaded with the
           captured globals/freevars, unless ALLOW_EXEC_SOURCE is False.

        Raises
        ------
        ValueError
            On a malformed payload, when exec restore is disabled or has no
            source, or when no callable can be recovered after exec.
        """
        self.package_root = tuple(state.get("package_root") or ())
        self.ALLOW_EXEC_SOURCE = bool(state.get("ALLOW_EXEC_SOURCE", True))

        fn_payload = state["fn"]
        is_mapping = isinstance(fn_payload, dict)

        # Non-function callables were stored as a bare dill blob.
        if is_mapping and fn_payload.get("__callable__") is False:
            self.fn = _safe_dill_loads(_b64d(fn_payload["dill_b64"]))
            return

        if not is_mapping or fn_payload.get("__callable__") is not True:
            raise ValueError("Invalid callable payload")

        src_pyver = tuple(fn_payload.get("pyver") or ())
        if src_pyver == _pyver() and fn_payload.get("dill_b64"):
            try:
                self.fn = _safe_dill_loads(_b64d(fn_payload["dill_b64"]))
                if callable(self.fn):
                    return
            except Exception:
                # Best-effort: fall through to the source-based restore.
                pass

        if not self.ALLOW_EXEC_SOURCE:
            raise ValueError("Exec-based restore disabled and dill path unavailable/failed")

        imports = fn_payload.get("imports") or ""
        source = fn_payload.get("source")
        name = fn_payload.get("name")
        env = fn_payload.get("env") or {}

        if not source:
            raise ValueError("No source available for exec-based restore")

        ns: Dict[str, Any] = {}

        # Preload captured names; entries that fail to load are skipped.
        for bucket in ("globals", "freevars"):
            for key, b64 in (env.get(bucket) or {}).items():
                try:
                    ns[key] = _safe_dill_loads(_b64d(b64))
                except Exception:
                    pass

        # Module imports run before the function source so its names resolve.
        if imports.strip():
            exec(imports, ns, ns)

        exec(source, ns, ns)

        if name and name in ns and callable(ns[name]):
            self.fn = ns[name]
            return

        # e.g. lambdas: "<lambda>" is never a namespace key, so fall back to
        # the most recently bound callable.
        cands = [v for v in ns.values() if callable(v)]
        if cands:
            self.fn = cands[-1]
            return

        raise ValueError("exec(source) succeeded but no callable could be recovered")

    # ---------- Databricks command generation ----------
    def to_command(
        self,
        *,
        args: Tuple[Any, ...] = (),
        kwargs: Optional[Dict[str, Any]] = None,
        env_keys: Iterable[str] = (),
        env_variables: Optional[Dict[str, str]] = None,
        use_dill: bool,
        byte_limit: int = 0,
        result_tag: str = "<<<RESULT>>>",
    ) -> str:
        """
        Generate Python source that rebuilds ``self.fn``, calls it and prints
        the outcome as ``result_tag`` + JSON on a single line.

        Parameters
        ----------
        args, kwargs
            Call arguments; transported as dill blobs.
        env_keys
            Names looked up in this process's ``os.environ``; those that are
            set are exported in the generated command.
        env_variables
            Explicit variables; they override values picked up via ``env_keys``.
        use_dill
            Whether the generated command should try the dill blob before
            falling back to exec(source).
        byte_limit
            Size of the dill-serialized result above which the generated
            command zlib-compresses it. A falsy value means 512 KiB.
        result_tag
            Prefix of the result line.

        Raises
        ------
        ValueError
            If ``self.fn`` is not a Python function/method/lambda, or if
            ``args``/``kwargs`` cannot be dill-serialized.
        """
        if kwargs is None:
            kwargs = {}
        if env_variables is None:
            env_variables = {}

        # capture env vars from client process
        client_env: Dict[str, str] = {}
        for key in env_keys:
            value = os.environ.get(key)
            if value is not None:
                client_env[key] = value
        client_env.update(env_variables)

        if not _is_python_function(self.fn):
            raise ValueError("to_command supports Python functions/methods/lambdas only")

        callable_payload = self._callable_payload()

        # args/kwargs transport
        dumped_args = _safe_dill_dumps(args)
        dumped_kwargs = _safe_dill_dumps(kwargs)
        if dumped_args is None or dumped_kwargs is None:
            raise ValueError("Failed to dill-serialize args/kwargs")

        args_pack: Dict[str, Any] = {"kind": "dill", "b64": _b64e(dumped_args)}
        kwargs_pack: Dict[str, Any] = {"kind": "dill", "b64": _b64e(dumped_kwargs)}

        if not byte_limit:
            byte_limit = 512 * 1024

        payload = {
            "callable": callable_payload,
            "use_dill": bool(use_dill),
            "args": args_pack,
            "kwargs": kwargs_pack,
            "env": client_env,
            "result_tag": result_tag,
            "byte_limit": byte_limit,
        }

        # The payload is embedded via repr(); it only holds str/int/bool/None
        # and nested dicts/lists of those, so the repr is a valid literal.
        # Template lines stay at column 0: only the first line is lstrip()'d.
        return f"""
# --- generated by yggdrasil.ser.CallableSerdeMixin.to_command ---
import base64, json, os, traceback, zlib
import dill
from yggdrasil.databricks import *

_payload = {payload!r}

def _b64d(s: str) -> bytes:
    return base64.b64decode(s.encode("ascii"))

def _try_dill_load(b64: str):
    return dill.loads(_b64d(b64))

def _hydrate_env(env_dict):
    ns = {{}}
    for bucket in ("globals", "freevars"):
        items = (env_dict or {{}}).get(bucket, {{}}) or {{}}
        for name, b64 in items.items():
            try:
                ns[name] = _try_dill_load(b64)
            except Exception:
                pass
    return ns

def _load_args(pack):
    if pack["kind"] == "dill":
        return dill.loads(_b64d(pack["b64"]))
    return tuple(json.loads(pack["text"]))

def _load_kwargs(pack):
    if pack["kind"] == "dill":
        return dill.loads(_b64d(pack["b64"]))
    return dict(json.loads(pack["text"]))

def _load_callable(cpack, use_dill: bool):
    # Prefer dill when requested/available
    if use_dill and cpack.get("dill_b64"):
        fn = _try_dill_load(cpack["dill_b64"])
        if callable(fn):
            return fn

    src = cpack.get("source")
    name = cpack.get("name")
    imports = cpack.get("imports") or ""
    if not src:
        raise ValueError("No source available for exec-based restore")

    ns = {{}}
    ns.update(_hydrate_env(cpack.get("env") or {{}}))

    # NEW: exec module imports first
    if imports.strip():
        exec(imports, ns, ns)

    exec(src, ns, ns)

    if name and name in ns and callable(ns[name]):
        return ns[name]

    cands = [v for v in ns.values() if callable(v)]
    if cands:
        return cands[-1]

    raise ValueError("exec(source) ran but no callable was recovered")

def _emit(tag: str, obj: dict):
    print(tag + json.dumps(obj, ensure_ascii=False, separators=(",", ":")))


def _zlib_level(n: int, limit: int) -> int:
    # ratio of size to limit
    r = n / max(1, limit)

    # fast for huge payloads, stronger only when slightly over
    if r >= 32:
        return 1
    if r >= 16:
        return 2
    if r >= 8:
        return 3
    if r >= 4:
        return 4
    if r >= 2:
        return 5
    if r >= 1.25:
        return 6
    # barely over: squeeze a bit more
    return 7

# apply env vars
for k, v in (_payload.get("env") or {{}}).items():
    if v is not None:
        os.environ[str(k)] = str(v)

tag = _payload.get("result_tag", "<<<RESULT>>>")
byte_limit = int((_payload.get("byte_limit") or 2_000_000))  # ~2MB serialized bytes

try:
    use_dill = bool(_payload.get("use_dill", False))
    fn = _load_callable(_payload["callable"], use_dill=use_dill)
    args = _load_args(_payload["args"])
    kwargs = _load_kwargs(_payload["kwargs"])

    out = fn(*args, **kwargs)
    out_raw = dill.dumps(out, recurse=True)

    if len(out_raw) > byte_limit:
        lvl = _zlib_level(len(out_raw), byte_limit)
        out_comp = zlib.compress(out_raw, level=lvl)
        out_b64 = base64.b64encode(out_comp).decode("ascii")
        _emit(tag, {{"ok": True, "encoding": f"dill+zlib{{lvl}}+b64", "payload": out_b64}})
    else:
        out_b64 = base64.b64encode(out_raw).decode("ascii")
        _emit(tag, {{"ok": True, "encoding": "dill+b64", "payload": out_b64}})

except Exception as e:
    tb = traceback.format_exc()
    err = {{"cls": e.__class__.__name__, "msg": str(e), "tb": tb}}
    _emit(tag, {{"ok": False, "err": err}})

# --- end generated command ---
""".lstrip()

    @staticmethod
    def parse_command_result(
        stdout_text: str,
        *,
        result_tag: str = "<<<RESULT>>>",
        decode: bool = True,
    ) -> Any:
        """
        Parse Databricks command stdout and return the decoded result or raise.

        Protocol:
          the command prints lines like: <<<RESULT>>>{...json...}
          the last tagged line wins; every untagged line is echoed with print()

        If decode=False: returns the raw parsed JSON value.
        If ok=false -> raise CommandError. A remote ModuleNotFoundError is
          re-raised locally as ModuleNotFoundError, chained from the CommandError.
        If ok=true:
          - encoding == "dill+b64": dill.loads(base64(payload))
          - encoding == "dill+zlib<level>+b64": dill.loads(zlib(base64(payload)))
          - encoding == "repr": returns payload as-is

        Raises
        ------
        CommandResultParseError
            Missing tag, invalid JSON, non-object JSON, undecodable payload or
            unknown/missing encoding.
        CommandError, ModuleNotFoundError
            When the remote side reported a failure.
        """
        last: Optional[str] = None
        for line in stdout_text.splitlines():
            if line.startswith(result_tag):
                last = line[len(result_tag):]
            else:
                # Surface ordinary remote output to the local stdout.
                print(line)

        if last is None:
            raise CommandResultParseError(
                f"Result tag {result_tag!r} not found in command output",
                stdout_text=stdout_text,
            )

        try:
            msg: Dict[str, Any] = json.loads(last)
        except Exception as e:
            raise CommandResultParseError(
                "Tagged result is not valid JSON",
                stdout_text=stdout_text,
            ) from e

        if not decode:
            return msg

        if not isinstance(msg, dict):
            # Valid JSON but not the expected object; previously an AttributeError.
            raise CommandResultParseError(
                "Tagged result is not a JSON object",
                stdout_text=stdout_text,
            )

        if not bool(msg.get("ok", False)):
            error = msg.get("err")
            if not isinstance(error, dict):
                # Tolerate a missing/null or plain-string error description.
                error = {} if error is None else {"msg": str(error)}

            error_class = str(error.get("cls", "RuntimeError"))
            error_message = str(error.get("msg", "Remote execution failed"))
            error_traceback = str(error.get("tb", ""))

            base = CommandError(
                error=error_message,
                traceback=error_traceback,
                raw=msg,
            )

            # The generated command reports e.__class__.__name__, i.e.
            # "ModuleNotFoundError"; the legacy spelling is kept for compatibility.
            if error_class in ("ModuleNotFoundError", "ModuleNotFound"):
                raise ModuleNotFoundError(error_message) from base

            raise base

        encoding = msg.get("encoding")
        payload = msg.get("payload")

        if encoding == "dill+b64":
            if not isinstance(payload, str):
                raise CommandResultParseError(
                    "Expected base64 string payload for dill+b64 encoding",
                    stdout_text=stdout_text,
                )
            try:
                data = base64.b64decode(payload.encode("ascii"))
                # NOTE(review): dill.loads executes arbitrary code; the command
                # output must come from a trusted cluster.
                return dill.loads(data)
            except Exception as e:
                raise CommandResultParseError(
                    "Failed to decode dill+b64 payload",
                    stdout_text=stdout_text,
                ) from e

        if encoding == "repr":
            return payload

        # Guard the str methods: a missing encoding is None, not a string.
        if (
            isinstance(encoding, str)
            and encoding.startswith("dill+zlib")
            and encoding.endswith("+b64")
        ):
            if not isinstance(payload, str):
                raise CommandResultParseError(
                    "Expected base64 string payload for zlib encoding",
                    stdout_text=stdout_text,
                )
            try:
                data = base64.b64decode(payload.encode("ascii"))
                raw = zlib.decompress(data)
                return dill.loads(raw)
            except Exception as e:
                raise CommandResultParseError(
                    "Failed to decode dill+zlib+b64 payload",
                    stdout_text=stdout_text,
                ) from e

        raise CommandResultParseError(
            f"Unknown result encoding: {encoding!r}",
            stdout_text=stdout_text,
        )
File without changes