ygg 0.1.20__py3-none-any.whl → 0.1.23__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,661 +0,0 @@
1
- # yggdrasil/ser.py
2
- from __future__ import annotations
3
-
4
- import ast
5
- import base64
6
- import builtins
7
- import inspect
8
- import json
9
- import os
10
- import sys
11
- import textwrap
12
- import zlib
13
- from dataclasses import dataclass, field
14
- from pathlib import Path
15
- from typing import Any, Callable, Dict, Iterable, Optional, Tuple
16
-
17
- import dill
18
-
19
- PyVer = Tuple[int, int]
20
-
21
-
22
- def _pyver() -> PyVer:
23
- v = sys.version_info
24
- return v.major, v.minor
25
-
26
-
27
- def _b64e(b: bytes) -> str:
28
- return base64.b64encode(b).decode("ascii")
29
-
30
-
31
- def _b64d(s: str) -> bytes:
32
- return base64.b64decode(s.encode("ascii"))
33
-
34
-
35
- def _is_python_function(obj: Any) -> bool:
36
- return inspect.isfunction(obj) or inspect.ismethod(obj) or isinstance(obj, type(lambda: 0))
37
-
38
-
39
def _safe_dill_dumps(obj: Any) -> Optional[bytes]:
    """
    Best-effort dill serialization of `obj`.

    Returns the serialized bytes, or None when dill raises any Exception,
    so callers can treat "not serializable" as a soft failure.
    """
    result: Optional[bytes]
    try:
        result = dill.dumps(obj, recurse=True)
    except Exception:
        result = None
    return result
44
-
45
-
46
- def _safe_dill_loads(b: Optional[bytes]) -> Any:
47
- if b is None:
48
- raise ValueError("No dill bytes to load")
49
- return dill.loads(b)
50
-
51
-
52
- def _infer_package_root(fn: Callable[..., Any]) -> Tuple[str, ...]:
53
- """
54
- Return the *topmost* package directory for the module defining `fn`.
55
-
56
- Walk upwards from the module file directory while an __init__.py exists.
57
- Example:
58
- /repo/my_pkg/sub/mod.py -> returns /repo/my_pkg
59
- If no __init__.py is found, falls back to the module's parent directory.
60
- """
61
- try:
62
- mod = inspect.getmodule(fn)
63
- file = getattr(mod, "__file__", None)
64
- if not file:
65
- return tuple()
66
-
67
- p = Path(file).resolve()
68
- if p.suffix.lower() != ".py":
69
- return tuple()
70
-
71
- cur = p.parent # start at directory containing the module file
72
-
73
- # If this directory isn't a package at all, just return it
74
- if not (cur / "__init__.py").exists():
75
- return (str(cur),)
76
-
77
- # Climb while parent is still a package (has __init__.py)
78
- # Stop at the highest directory that is still a package.
79
- top = cur
80
- while True:
81
- parent = top.parent
82
- if parent == top:
83
- break
84
- if (parent / "__init__.py").exists():
85
- top = parent
86
- continue
87
- break
88
-
89
- return (str(top),)
90
- except Exception:
91
- return tuple()
92
-
93
-
94
def _capture_exec_env(fn: Callable[..., Any]) -> Dict[str, Dict[str, str]]:
    """
    Snapshot the names `fn` needs so its source can later be re-exec'd.

    Two buckets are produced:

    - "globals": every name in ``fn.__code__.co_names`` that is present in
      ``fn.__globals__`` and is not a builtin name
    - "freevars": the closure cell contents for ``fn.__code__.co_freevars``

    Each value is stored as base64(dill.dumps(value)); values that cannot be
    serialized (and empty closure cells) are silently left out.
    """
    code = fn.__code__
    fn_globals = getattr(fn, "__globals__", {}) or {}

    def _encode(value: Any) -> Optional[str]:
        # None signals "could not be dill-serialized".
        blob = _safe_dill_dumps(value)
        return None if blob is None else _b64e(blob)

    referenced = set(code.co_names or ())
    builtin_names = set(dir(builtins))

    g_payload: Dict[str, str] = {}
    for name in referenced:
        if name in builtin_names or name not in fn_globals:
            continue
        encoded = _encode(fn_globals[name])
        if encoded is not None:
            g_payload[name] = encoded

    fv_payload: Dict[str, str] = {}
    names = code.co_freevars or ()
    cells = fn.__closure__ or ()
    # Only pair names with cells when both exist and line up one-to-one.
    if names and cells and len(names) == len(cells):
        for name, cell in zip(names, cells):
            try:
                contents = cell.cell_contents
            except ValueError:
                # Empty cell (variable not yet bound): nothing to capture.
                continue
            encoded = _encode(contents)
            if encoded is not None:
                fv_payload[name] = encoded

    return {"globals": g_payload, "freevars": fv_payload}
132
-
133
-
134
- def _capture_module_imports(fn: Callable[..., Any]) -> str:
135
- """
136
- Capture top-level imports from the module file where `fn` is defined.
137
-
138
- We extract only:
139
- - `import x`
140
- - `import x as y`
141
- - `from a.b import c`
142
- - `from a.b import c as d`
143
-
144
- Returned as a single string block (may be empty).
145
- """
146
- try:
147
- src_file = inspect.getsourcefile(fn) or inspect.getfile(fn)
148
- if not src_file:
149
- return ""
150
- p = Path(src_file)
151
- if not p.exists() or p.suffix.lower() != ".py":
152
- return ""
153
-
154
- text = p.read_text(encoding="utf-8")
155
- tree = ast.parse(text)
156
-
157
- imports: list[str] = []
158
- for node in tree.body:
159
- if isinstance(node, (ast.Import, ast.ImportFrom)):
160
- seg = ast.get_source_segment(text, node)
161
- if seg:
162
- imports.append(seg.strip())
163
-
164
- # de-dupe while preserving order
165
- seen = set()
166
- uniq: list[str] = []
167
- for line in imports:
168
- if line not in seen:
169
- seen.add(line)
170
- uniq.append(line)
171
-
172
- return "\n".join(uniq).strip() + ("\n" if uniq else "")
173
- except Exception:
174
- return ""
175
-
176
-
177
def _pack_payload(payload: Dict[str, Any], *, level: int = 6) -> str:
    """
    Serialize `payload` as compact JSON, zlib-compress it and base64-encode it.

    `level` is coerced to int and clamped into zlib's 1..9 range.
    """
    clamped = min(9, max(1, int(level)))
    text = json.dumps(payload, ensure_ascii=False, separators=(",", ":"))
    compressed = zlib.compress(text.encode("utf-8"), level=clamped)
    return _b64e(compressed)
186
-
187
-
188
def parse_tagged_result(stdout_text: str, result_tag: str) -> Dict[str, Any]:
    """
    Return the JSON payload of the last line in `stdout_text` that starts
    with `result_tag`.

    Lines are expected to look like: <<<RESULT>>>{...json...}

    Raises ValueError when no line carries the tag or when the text after
    the tag is not valid JSON.
    """
    tagged = [ln for ln in stdout_text.splitlines() if ln.startswith(result_tag)]
    if not tagged:
        raise ValueError(f"Result tag {result_tag!r} not found in output")

    # Last tagged line wins; strip the tag prefix to get the JSON text.
    body = tagged[-1][len(result_tag):]
    try:
        return json.loads(body)
    except Exception as e:
        raise ValueError("Tagged result was not valid JSON") from e
203
-
204
-
205
class CommandError(RuntimeError):
    """
    Raised when the remote Databricks command returns ok=false.

    The exception message is the error text, followed by the traceback on a
    new line when one is given, with trailing whitespace removed.

    Attributes
    ----------
    error: str
        Short error message from remote.
    traceback: str
        Remote traceback (string) if provided.
    raw: dict
        Full parsed JSON payload from the tagged output line
        (an empty dict when none was supplied).
    """

    def __init__(self, error: str, traceback: str = "", raw: Optional[Dict[str, Any]] = None):
        message = f"{error}\n{traceback}" if traceback else error
        super().__init__(message.rstrip())
        self.error = error
        self.traceback = traceback
        self.raw = raw if raw else {}
225
-
226
-
227
class CommandResultParseError(ValueError):
    """
    Raised when the command output cannot be parsed or decoded.

    Attributes
    ----------
    stdout_text: Optional[str]
        The command output that failed to parse, when the raiser supplied it.
    """

    def __init__(self, message: str, stdout_text: Optional[str] = None):
        super().__init__(message)
        self.stdout_text = stdout_text
235
-
236
-
237
@dataclass
class CallableSerdeMixin:
    """
    Encapsulates a callable + serialization strategy + Databricks command generation.

    - Same Python (major, minor): prefer dill
    - Different Python (major, minor): exec(imports + source) with captured globals/freevars

    Attributes
    ----------
    fn:
        The wrapped callable.
    package_root:
        Directory tuple produced by `_infer_package_root` (may be empty).
    ALLOW_EXEC_SOURCE:
        When False, `__setstate__` refuses the exec(source) fallback and
        raises if the dill path is unavailable or fails.

    SECURITY: restoring state runs `dill.loads` and `exec` on the payload.
    Only unpickle state, or parse command output, that comes from a trusted
    source.
    """

    fn: Callable[..., Any]
    package_root: Tuple[str, ...] = field(default_factory=tuple)
    ALLOW_EXEC_SOURCE: bool = True

    @classmethod
    def from_callable(cls, fn: Callable[..., Any]) -> "CallableSerdeMixin":
        """Wrap `fn`, returning it unchanged when it is already a wrapper."""
        if isinstance(fn, CallableSerdeMixin):
            return fn
        return cls(fn=fn, package_root=_infer_package_root(fn))

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Invoke the wrapped callable with the given arguments."""
        return self.fn(*args, **kwargs)

    def _function_payload(self) -> Dict[str, Any]:
        """
        Describe `self.fn` (a Python function/method/lambda) as a JSON-safe dict.

        Shared by `__getstate__` and `to_command` so both transports carry the
        same fields. "source" is None when the source cannot be retrieved and
        "dill_b64" is None when dill cannot serialize the function.
        """
        try:
            src: Optional[str] = textwrap.dedent(inspect.getsource(self.fn))
        except Exception:
            # Source is unavailable for e.g. REPL-defined or compiled functions.
            src = None

        payload: Dict[str, Any] = {
            "__callable__": True,
            "pyver": list(_pyver()),
            "name": getattr(self.fn, "__name__", None),
            "qualname": getattr(self.fn, "__qualname__", None),
            "module": getattr(self.fn, "__module__", None),
            "imports": _capture_module_imports(self.fn),
            "source": src,
            "dill_b64": None,
            "env": _capture_exec_env(self.fn),
        }

        dumped = _safe_dill_dumps(self.fn)
        if dumped is not None:
            payload["dill_b64"] = _b64e(dumped)
        return payload

    # ---------- pickle protocol ----------
    def __getstate__(self) -> Dict[str, Any]:
        """
        Build the picklable state.

        Python functions get the full payload (source, imports, env, dill);
        any other callable must be dill-serializable.

        Raises ValueError when a non-function callable cannot be dilled.
        """
        if _is_python_function(self.fn):
            fn_payload = self._function_payload()
        else:
            dumped = _safe_dill_dumps(self.fn)
            if dumped is None:
                raise ValueError("Callable object could not be dill-serialized")
            fn_payload = {"__callable__": False, "dill_b64": _b64e(dumped)}

        return {
            "fn": fn_payload,
            "package_root": tuple(self.package_root),
            "ALLOW_EXEC_SOURCE": bool(self.ALLOW_EXEC_SOURCE),
        }

    def __setstate__(self, state: Dict[str, Any]) -> None:
        """
        Restore the wrapper from `state` as produced by `__getstate__`.

        Strategy: use dill when the payload was produced by the same Python
        (major, minor); otherwise, or if dill fails, re-exec the captured
        imports + source in a namespace preloaded with the captured
        globals/freevars.

        Raises ValueError for a malformed payload, when the exec fallback is
        needed but disabled, when no source is available, or when exec does
        not yield a callable.
        """
        self.package_root = tuple(state.get("package_root") or ())
        self.ALLOW_EXEC_SOURCE = bool(state.get("ALLOW_EXEC_SOURCE", True))

        fn_payload = state["fn"]

        # Non-function callables only ever travel as dill bytes.
        if isinstance(fn_payload, dict) and fn_payload.get("__callable__") is False:
            self.fn = _safe_dill_loads(_b64d(fn_payload["dill_b64"]))
            return

        if not isinstance(fn_payload, dict) or fn_payload.get("__callable__") is not True:
            raise ValueError("Invalid callable payload")

        src_pyver = tuple(fn_payload.get("pyver") or ())
        cur_pyver = _pyver()

        if src_pyver == cur_pyver and fn_payload.get("dill_b64"):
            try:
                self.fn = _safe_dill_loads(_b64d(fn_payload["dill_b64"]))
                if callable(self.fn):
                    return
            except Exception:
                # Deliberate best-effort: fall through to the exec-based path.
                pass

        if not self.ALLOW_EXEC_SOURCE:
            raise ValueError("Exec-based restore disabled and dill path unavailable/failed")

        imports = fn_payload.get("imports") or ""
        source = fn_payload.get("source")
        name = fn_payload.get("name")
        env = fn_payload.get("env") or {}

        if not source:
            raise ValueError("No source available for exec-based restore")

        ns: Dict[str, Any] = {}

        # Preload captured names; entries that fail to load are skipped so a
        # single bad value does not block the restore.
        for bucket in ("globals", "freevars"):
            items = env.get(bucket) or {}
            for key, b64 in items.items():
                try:
                    ns[key] = _safe_dill_loads(_b64d(b64))
                except Exception:
                    pass

        # Run module imports first so the source sees the names it expects.
        if imports.strip():
            exec(imports, ns, ns)

        exec(source, ns, ns)

        if name and name in ns and callable(ns[name]):
            self.fn = ns[name]
            return

        # Fallback when the name is missing (e.g. lambdas): last callable in ns.
        cands = [v for v in ns.values() if callable(v)]
        if cands:
            self.fn = cands[-1]
            return

        raise ValueError("exec(source) succeeded but no callable could be recovered")

    # ---------- Databricks command generation ----------
    def to_command(
        self,
        *,
        args: Tuple[Any, ...] = (),
        kwargs: Optional[Dict[str, Any]] = None,
        env_keys: Iterable[str] = (),
        env_variables: Optional[Dict[str, str]] = None,
        use_dill: bool,
        byte_limit: int = 0,
        result_tag: str = "<<<RESULT>>>",
        compress_input_payload: bool = True,
        payload_compression_level: int = 6,
    ) -> str:
        """
        Generate a self-contained Python script that calls `self.fn` remotely.

        The script restores the callable and its arguments, runs it, and
        prints one line `result_tag + json` describing the outcome, which
        `parse_command_result` understands.

        Parameters
        ----------
        args, kwargs:
            Call arguments; transported as dill bytes.
        env_keys:
            Names of client environment variables to copy to the remote side
            (unset ones are skipped).
        env_variables:
            Extra environment variables; these override values from `env_keys`.
        use_dill:
            When True the remote side tries the dill bytes of the function
            before falling back to exec(imports + source).
        byte_limit:
            Results whose dill size exceeds this are zlib-compressed remotely.
            A falsy value means 512 KiB.
        result_tag:
            Prefix marking the result line in the command's stdout.
        compress_input_payload:
            Embed the input payload as zlib+base64 JSON instead of a repr.
        payload_compression_level:
            zlib level for the embedded payload (clamped to 1..9).

        Raises
        ------
        ValueError
            If `self.fn` is not a Python function/method/lambda, or the
            args/kwargs cannot be dill-serialized.
        """
        if kwargs is None:
            kwargs = {}
        if env_variables is None:
            env_variables = {}

        # Capture env vars from the client process; explicit values win.
        client_env: Dict[str, str] = {}
        for key in env_keys:
            value = os.environ.get(key)
            if value is not None:
                client_env[key] = value
        client_env.update(env_variables)

        if not _is_python_function(self.fn):
            raise ValueError("to_command supports Python functions/methods/lambdas only")

        callable_payload = self._function_payload()

        # args/kwargs transport
        dumped_args = _safe_dill_dumps(args)
        dumped_kwargs = _safe_dill_dumps(kwargs)
        if dumped_args is None or dumped_kwargs is None:
            raise ValueError("Failed to dill-serialize args/kwargs")

        args_pack: Dict[str, Any] = {"kind": "dill", "b64": _b64e(dumped_args)}
        kwargs_pack: Dict[str, Any] = {"kind": "dill", "b64": _b64e(dumped_kwargs)}

        if not byte_limit:
            byte_limit = 512 * 1024

        payload = {
            "callable": callable_payload,
            "use_dill": bool(use_dill),
            "args": args_pack,
            "kwargs": kwargs_pack,
            "env": client_env,
            "result_tag": result_tag,
            "byte_limit": byte_limit,
        }

        # Compress the embedded input payload to keep the command size small.
        # NOTE: the templates below are emitted verbatim (only leading
        # whitespace of the whole script is stripped), so their lines must
        # stay at column 0 to be valid top-level Python on the remote side.
        if compress_input_payload:
            payload_b64 = _pack_payload(payload, level=payload_compression_level)
            payload_bootstrap = f"""
_payload_b64 = {payload_b64!r}

def _load_payload(b64: str) -> dict:
    raw = base64.b64decode(b64.encode("ascii"))
    raw = zlib.decompress(raw)
    return json.loads(raw.decode("utf-8"))

_payload = _load_payload(_payload_b64)
""".rstrip()
        else:
            payload_bootstrap = f"_payload = {payload!r}"

        return f"""
# --- generated by yggdrasil.ser.CallableSerdeMixin.to_command ---
import base64, json, os, traceback, zlib
import dill
from yggdrasil.databricks import *

{payload_bootstrap}

def _b64d(s: str) -> bytes:
    return base64.b64decode(s.encode("ascii"))

def _try_dill_load(b64: str):
    return dill.loads(_b64d(b64))

def _hydrate_env(env_dict):
    ns = {{}}
    for bucket in ("globals", "freevars"):
        items = (env_dict or {{}}).get(bucket, {{}}) or {{}}
        for name, b64 in items.items():
            try:
                ns[name] = _try_dill_load(b64)
            except Exception:
                pass
    return ns

def _load_args(pack):
    if pack["kind"] == "dill":
        return dill.loads(_b64d(pack["b64"]))
    return tuple(json.loads(pack["text"]))

def _load_kwargs(pack):
    if pack["kind"] == "dill":
        return dill.loads(_b64d(pack["b64"]))
    return dict(json.loads(pack["text"]))

def _load_callable(cpack, use_dill: bool):
    # Prefer dill when requested/available
    if use_dill and cpack.get("dill_b64"):
        fn = _try_dill_load(cpack["dill_b64"])
        if callable(fn):
            return fn

    src = cpack.get("source")
    name = cpack.get("name")
    imports = cpack.get("imports") or ""
    if not src:
        raise ValueError("No source available for exec-based restore")

    ns = {{}}
    ns.update(_hydrate_env(cpack.get("env") or {{}}))

    if imports.strip():
        exec(imports, ns, ns)

    exec(src, ns, ns)

    if name and name in ns and callable(ns[name]):
        return ns[name]

    cands = [v for v in ns.values() if callable(v)]
    if cands:
        return cands[-1]

    raise ValueError("exec(source) ran but no callable was recovered")

def _emit(tag: str, obj: dict):
    print(tag + json.dumps(obj, ensure_ascii=False, separators=(",", ":")))

def _zlib_level(n: int, limit: int) -> int:
    r = n / max(1, limit)
    if r >= 32: return 1
    if r >= 16: return 2
    if r >= 8: return 3
    if r >= 4: return 4
    if r >= 2: return 5
    if r >= 1.25: return 6
    return 7

for k, v in (_payload.get("env") or {{}}).items():
    if v is not None:
        os.environ[str(k)] = str(v)

tag = _payload.get("result_tag", "<<<RESULT>>>")
byte_limit = int((_payload.get("byte_limit") or 2_000_000))

try:
    use_dill = bool(_payload.get("use_dill", False))
    fn = _load_callable(_payload["callable"], use_dill=use_dill)
    args = _load_args(_payload["args"])
    kwargs = _load_kwargs(_payload["kwargs"])

    out = fn(*args, **kwargs)
    out_raw = dill.dumps(out, recurse=True)

    if len(out_raw) > byte_limit:
        lvl = _zlib_level(len(out_raw), byte_limit)
        out_comp = zlib.compress(out_raw, level=lvl)
        out_b64 = base64.b64encode(out_comp).decode("ascii")
        _emit(tag, {{"ok": True, "encoding": f"dill+zlib{{lvl}}+b64", "payload": out_b64}})
    else:
        out_b64 = base64.b64encode(out_raw).decode("ascii")
        _emit(tag, {{"ok": True, "encoding": "dill+b64", "payload": out_b64}})

except Exception as e:
    tb = traceback.format_exc()
    err = {{"cls": e.__class__.__name__, "msg": str(e), "tb": tb}}
    _emit(tag, {{"ok": False, "err": err}})

# --- end generated command ---
""".lstrip()

    @staticmethod
    def parse_command_result(
        stdout_text: str,
        *,
        result_tag: str = "<<<RESULT>>>",
        decode: bool = True,
    ) -> Any:
        """
        Parse Databricks command stdout and return decoded result or raise.

        Protocol:
            prints lines like: <<<RESULT>>>{...json...}
            last tagged line wins; every untagged line is echoed to stdout

        If decode=False: returns the raw parsed JSON value.
        If ok=false -> raise CommandError (or ModuleNotFoundError chained
        from a CommandError when the remote error class is
        ModuleNotFoundError).
        If ok=true and decode=True:
            - encoding == "dill+b64": dill.loads(base64(payload))
            - encoding == "dill+zlib<level>+b64": dill.loads(zlib(base64(payload)))
            - encoding == "repr": returns payload as-is

        Raises CommandResultParseError when the tag is missing, the tagged
        text is not valid JSON (or, with decode=True, not a JSON object), the
        payload cannot be decoded, or the encoding is unknown.

        SECURITY: decoding runs dill.loads on the payload; only parse output
        from a trusted cluster.
        """
        last: Optional[str] = None
        for line in stdout_text.splitlines():
            if line.startswith(result_tag):
                last = line[len(result_tag):]
            else:
                # Forward the remote command's regular output to the caller.
                print(line)

        if last is None:
            raise CommandResultParseError(
                f"Result tag {result_tag!r} not found in command output",
                stdout_text=stdout_text,
            )

        try:
            msg = json.loads(last)
        except Exception as e:
            raise CommandResultParseError(
                "Tagged result is not valid JSON",
                stdout_text=stdout_text,
            ) from e

        if not decode:
            return msg

        # Valid JSON that is not an object (list, number, ...) cannot follow
        # the protocol; report it as a parse error rather than crash on .get().
        if not isinstance(msg, dict):
            raise CommandResultParseError(
                "Tagged result is not a JSON object",
                stdout_text=stdout_text,
            )

        if not msg.get("ok", False):
            error = msg.get("err") or {}
            if not isinstance(error, dict):
                # Tolerate a bare error value instead of the expected mapping.
                error = {"msg": str(error)}
            error_class = str(error.get("cls", "RuntimeError"))
            error_message = str(error.get("msg", "Remote execution failed"))
            error_traceback = str(error.get("tb", ""))

            base = CommandError(
                error=error_message,
                traceback=error_traceback,
                raw=msg,
            )

            # The generated command reports e.__class__.__name__, i.e.
            # "ModuleNotFoundError"; the legacy spelling is kept for safety.
            if error_class in ("ModuleNotFoundError", "ModuleNotFound"):
                raise ModuleNotFoundError(error_message) from base

            raise base

        encoding = msg.get("encoding")
        payload = msg.get("payload")

        if encoding == "dill+b64":
            if not isinstance(payload, str):
                raise CommandResultParseError(
                    "Expected base64 string payload for dill+b64 encoding",
                    stdout_text=stdout_text,
                )
            try:
                data = base64.b64decode(payload.encode("ascii"))
                return dill.loads(data)
            except Exception as e:
                raise CommandResultParseError(
                    "Failed to decode dill+b64 payload",
                    stdout_text=stdout_text,
                ) from e

        if encoding == "repr":
            return payload

        # Guard the str methods: a missing or non-string encoding must end up
        # as "Unknown result encoding", not as an AttributeError.
        if (
            isinstance(encoding, str)
            and encoding.startswith("dill+zlib")
            and encoding.endswith("+b64")
        ):
            if not isinstance(payload, str):
                raise CommandResultParseError(
                    "Expected base64 string payload for zlib encoding",
                    stdout_text=stdout_text,
                )
            try:
                data = base64.b64decode(payload.encode("ascii"))
                raw = zlib.decompress(data)
                return dill.loads(raw)
            except Exception as e:
                raise CommandResultParseError(
                    "Failed to decode dill+zlib+b64 payload",
                    stdout_text=stdout_text,
                ) from e

        raise CommandResultParseError(
            f"Unknown result encoding: {encoding!r}",
            stdout_text=stdout_text,
        )
File without changes