ygg 0.1.44__py3-none-any.whl → 0.1.46__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {ygg-0.1.44.dist-info → ygg-0.1.46.dist-info}/METADATA +1 -1
- {ygg-0.1.44.dist-info → ygg-0.1.46.dist-info}/RECORD +14 -13
- yggdrasil/databricks/compute/cluster.py +20 -16
- yggdrasil/databricks/compute/execution_context.py +46 -64
- yggdrasil/databricks/sql/engine.py +5 -2
- yggdrasil/databricks/sql/warehouse.py +355 -0
- yggdrasil/databricks/workspaces/workspace.py +19 -9
- yggdrasil/pyutils/callable_serde.py +296 -308
- yggdrasil/pyutils/expiring_dict.py +114 -25
- yggdrasil/version.py +1 -1
- {ygg-0.1.44.dist-info → ygg-0.1.46.dist-info}/WHEEL +0 -0
- {ygg-0.1.44.dist-info → ygg-0.1.46.dist-info}/entry_points.txt +0 -0
- {ygg-0.1.44.dist-info → ygg-0.1.46.dist-info}/licenses/LICENSE +0 -0
- {ygg-0.1.44.dist-info → ygg-0.1.46.dist-info}/top_level.txt +0 -0
|
@@ -1,43 +1,168 @@
|
|
|
1
|
-
"""Callable serialization helpers for cross-process execution.
|
|
1
|
+
"""Callable serialization helpers for cross-process execution.
|
|
2
|
+
|
|
3
|
+
Design goals:
|
|
4
|
+
- Prefer import-by-reference when possible (module + qualname), fallback to dill.
|
|
5
|
+
- Optional environment payload: selected globals and/or closure values.
|
|
6
|
+
- Cross-process bridge: generate a self-contained Python command string that:
|
|
7
|
+
1) materializes the callable
|
|
8
|
+
2) decodes args/kwargs payload
|
|
9
|
+
3) executes
|
|
10
|
+
4) emits a single tagged base64 line with a compressed result blob
|
|
11
|
+
|
|
12
|
+
Compression/framing:
|
|
13
|
+
- CS2 framing only (no CS1 logic).
|
|
14
|
+
- Frame header: MAGIC(3) + codec(u8) + orig_len(u32) + param(u8) + data
|
|
15
|
+
- Codecs:
|
|
16
|
+
0 raw (rarely used; mostly means "no frame")
|
|
17
|
+
1 zlib
|
|
18
|
+
2 lzma
|
|
19
|
+
3 zstd (optional dependency)
|
|
20
|
+
"""
|
|
2
21
|
|
|
3
22
|
from __future__ import annotations
|
|
4
23
|
|
|
5
24
|
import base64
|
|
25
|
+
import binascii
|
|
6
26
|
import dis
|
|
7
27
|
import importlib
|
|
8
28
|
import inspect
|
|
9
|
-
import
|
|
29
|
+
import io
|
|
30
|
+
import lzma
|
|
10
31
|
import os
|
|
32
|
+
import secrets
|
|
11
33
|
import struct
|
|
12
34
|
import sys
|
|
13
35
|
import zlib
|
|
14
36
|
from dataclasses import dataclass
|
|
15
37
|
from pathlib import Path
|
|
16
|
-
from typing import Any, Callable, Dict, Optional, Set, Tuple, TypeVar, Union,
|
|
38
|
+
from typing import Any, Callable, Dict, Iterable, Optional, Set, Tuple, TypeVar, Union, TYPE_CHECKING
|
|
17
39
|
|
|
18
40
|
import dill
|
|
19
41
|
|
|
42
|
+
if TYPE_CHECKING:
|
|
43
|
+
from ..databricks.workspaces import Workspace
|
|
44
|
+
|
|
20
45
|
__all__ = ["CallableSerde"]
|
|
21
46
|
|
|
22
47
|
T = TypeVar("T", bound="CallableSerde")
|
|
23
48
|
|
|
49
|
+
# ---------------------------
|
|
50
|
+
# Framing / compression (CS2)
|
|
51
|
+
# ---------------------------
|
|
24
52
|
|
|
25
|
-
|
|
53
|
+
_MAGIC = b"CS2"
|
|
26
54
|
|
|
27
|
-
|
|
28
|
-
|
|
55
|
+
_CODEC_RAW = 0
|
|
56
|
+
_CODEC_ZLIB = 1
|
|
57
|
+
_CODEC_LZMA = 2
|
|
58
|
+
_CODEC_ZSTD = 3
|
|
29
59
|
|
|
30
60
|
|
|
31
|
-
def
|
|
32
|
-
|
|
61
|
+
def _try_import_zstd():
|
|
62
|
+
try:
|
|
63
|
+
import zstandard as zstd # type: ignore
|
|
64
|
+
return zstd
|
|
65
|
+
except Exception:
|
|
66
|
+
return None
|
|
33
67
|
|
|
34
|
-
Args:
|
|
35
|
-
mod: Module to traverse.
|
|
36
|
-
qualname: Dotted qualified name.
|
|
37
68
|
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
69
|
+
def _pick_zlib_level(n: int, limit: int) -> int:
|
|
70
|
+
"""Ramp compression level 1..9 based on how far we exceed the byte_limit."""
|
|
71
|
+
ratio = n / max(1, limit)
|
|
72
|
+
x = min(1.0, max(0.0, (ratio - 1.0) / 3.0))
|
|
73
|
+
return max(1, min(9, int(round(1 + 8 * x))))
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def _frame(codec: int, orig_len: int, param: int, payload: bytes) -> bytes:
    """Build a CS2 frame: magic + codec(u8) + orig_len(u32) + param(u8) + payload.

    ``codec`` and ``param`` are masked to a single byte; ``orig_len`` is packed
    as a big-endian unsigned 32-bit integer.
    """
    codec_byte = int(codec) & 0xFF
    param_byte = int(param) & 0xFF
    header = struct.pack(">BIB", codec_byte, int(orig_len), param_byte)
    return _MAGIC + header + payload
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def _encode_with_candidates(raw: bytes, *, byte_limit: int, allow_zstd: bool) -> bytes:
    """Return ``raw`` unchanged or wrapped in the smallest CS2 frame available.

    Payloads no larger than ``byte_limit`` are not compressed. Larger payloads
    are compressed with every available codec (zstd when ``allow_zstd`` is true
    and the module imports, then lzma, then zlib) and the shortest frame wins,
    but only if it is strictly shorter than ``raw``.

    Whenever the payload would be returned uncompressed and it happens to start
    with the CS2 magic bytes, it is wrapped in a codec-0 (raw) frame instead.
    Otherwise the decoder would mistake the payload itself for a frame header.

    Args:
        raw: Bytes to encode.
        byte_limit: Size at or below which no compression is attempted.
        allow_zstd: Whether the optional zstd codec may be used.

    Returns:
        Either the original bytes or a CS2-framed blob.
    """

    def _uncompressed() -> bytes:
        # Guard against magic collision so decode(encode(x)) == x always holds.
        if raw.startswith(_MAGIC):
            size = len(raw)
            # orig_len is a u32; 0 tells the decoder to skip the length check.
            return _frame(_CODEC_RAW, size if size <= 0xFFFFFFFF else 0, 0, raw)
        return raw

    if len(raw) <= byte_limit:
        return _uncompressed()

    candidates: list[bytes] = []

    if allow_zstd:
        zstd = _try_import_zstd()
        if zstd is not None:
            for lvl in (6, 10, 15):
                try:
                    c = zstd.ZstdCompressor(level=lvl).compress(raw)
                    candidates.append(_frame(_CODEC_ZSTD, len(raw), lvl, c))
                except Exception:
                    # Best effort: a failing codec/level is simply not a candidate.
                    pass

    for preset in (6, 9):
        try:
            c = lzma.compress(raw, preset=preset)
            candidates.append(_frame(_CODEC_LZMA, len(raw), preset, c))
        except Exception:
            # Best effort: skip this preset and try the remaining codecs.
            pass

    lvl = _pick_zlib_level(len(raw), byte_limit)
    try:
        c = zlib.compress(raw, lvl)
        candidates.append(_frame(_CODEC_ZLIB, len(raw), lvl, c))
    except Exception:
        # Best effort: fall through to the uncompressed path below.
        pass

    if not candidates:
        return _uncompressed()

    best = min(candidates, key=len)
    # Only keep a frame that actually saves space.
    return best if len(best) < len(raw) else _uncompressed()
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
def _encode_result_blob(raw: bytes, byte_limit: int) -> bytes:
    """Encode a result payload, allowing zstd in addition to lzma and zlib."""
    encoded = _encode_with_candidates(raw, allow_zstd=True, byte_limit=byte_limit)
    return encoded
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
def _encode_wire_blob_stdlib(raw: bytes, byte_limit: int) -> bytes:
    """Encode an args/kwargs payload using stdlib codecs only (lzma, zlib)."""
    encoded = _encode_with_candidates(raw, allow_zstd=False, byte_limit=byte_limit)
    return encoded
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
def _decode_result_blob(blob: bytes) -> bytes:
    """Undo CS2 framing, returning the original payload bytes.

    Input that is not bytes-like, is shorter than the magic, or does not start
    with the magic is returned untouched. Framed input is decompressed
    according to the codec byte in its header.

    Raises:
        ValueError: The frame is truncated, names an unknown codec, or the
            decoded length does not match the non-zero length in the header.
        RuntimeError: The frame uses zstd but ``zstandard`` is not importable.
    """
    magic_len = 3
    header_len = 6  # size of the ">BIB" header that follows the magic

    is_binary = isinstance(blob, (bytes, bytearray))
    if not is_binary or len(blob) < magic_len or not blob.startswith(_MAGIC):
        return blob  # type: ignore[return-value]

    body_start = magic_len + header_len
    if len(blob) < body_start:
        raise ValueError("CS2 framed blob too short / truncated.")

    codec, orig_len, _param = struct.unpack_from(">BIB", blob, magic_len)
    data = blob[body_start:]

    if codec == _CODEC_ZSTD:
        zstd = _try_import_zstd()
        if zstd is None:
            raise RuntimeError("CS2 uses zstd but 'zstandard' is not installed.")
        output_cap = int(orig_len) if orig_len else 0
        raw = zstd.ZstdDecompressor().decompress(data, max_output_size=output_cap)
    elif codec == _CODEC_LZMA:
        raw = lzma.decompress(data)
    elif codec == _CODEC_ZLIB:
        raw = zlib.decompress(data)
    elif codec == _CODEC_RAW:
        raw = data
    else:
        raise ValueError(f"Unknown CS2 codec: {codec}")

    # A zero orig_len means "unknown", so the length check is skipped.
    if not orig_len or len(raw) == orig_len:
        return raw
    raise ValueError(f"Decoded length mismatch: got {len(raw)}, expected {orig_len}")
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
# ---------------------------
|
|
162
|
+
# Callable reference helpers
|
|
163
|
+
# ---------------------------
|
|
164
|
+
|
|
165
|
+
def _resolve_attr_chain(mod: Any, qualname: str) -> Any:
|
|
41
166
|
obj = mod
|
|
42
167
|
for part in qualname.split("."):
|
|
43
168
|
obj = getattr(obj, part)
|
|
@@ -45,10 +170,6 @@ def _resolve_attr_chain(mod: Any, qualname: str) -> Any:
|
|
|
45
170
|
|
|
46
171
|
|
|
47
172
|
def _find_pkg_root_from_file(file_path: Path) -> Optional[Path]:
|
|
48
|
-
"""
|
|
49
|
-
Walk up parents while __init__.py exists.
|
|
50
|
-
Return the directory that should be on sys.path (parent of top package dir).
|
|
51
|
-
"""
|
|
52
173
|
file_path = file_path.resolve()
|
|
53
174
|
d = file_path.parent
|
|
54
175
|
|
|
@@ -61,14 +182,6 @@ def _find_pkg_root_from_file(file_path: Path) -> Optional[Path]:
|
|
|
61
182
|
|
|
62
183
|
|
|
63
184
|
def _callable_file_line(fn: Callable[..., Any]) -> Tuple[Optional[str], Optional[int]]:
|
|
64
|
-
"""Return the source file path and line number for a callable.
|
|
65
|
-
|
|
66
|
-
Args:
|
|
67
|
-
fn: Callable to inspect.
|
|
68
|
-
|
|
69
|
-
Returns:
|
|
70
|
-
Tuple of (file path, line number).
|
|
71
|
-
"""
|
|
72
185
|
file = None
|
|
73
186
|
line = None
|
|
74
187
|
try:
|
|
@@ -84,17 +197,12 @@ def _callable_file_line(fn: Callable[..., Any]) -> Tuple[Optional[str], Optional
|
|
|
84
197
|
|
|
85
198
|
|
|
86
199
|
def _referenced_global_names(fn: Callable[..., Any]) -> Set[str]:
|
|
87
|
-
"""
|
|
88
|
-
Names that the function *actually* resolves from globals/namespaces at runtime.
|
|
89
|
-
Uses bytecode to avoid shipping random junk.
|
|
90
|
-
"""
|
|
91
200
|
names: Set[str] = set()
|
|
92
201
|
try:
|
|
93
202
|
for ins in dis.get_instructions(fn):
|
|
94
203
|
if ins.opname in ("LOAD_GLOBAL", "LOAD_NAME") and isinstance(ins.argval, str):
|
|
95
204
|
names.add(ins.argval)
|
|
96
205
|
except Exception:
|
|
97
|
-
# fallback: less precise
|
|
98
206
|
try:
|
|
99
207
|
names.update(getattr(fn.__code__, "co_names", ()) or ())
|
|
100
208
|
except Exception:
|
|
@@ -105,14 +213,6 @@ def _referenced_global_names(fn: Callable[..., Any]) -> Set[str]:
|
|
|
105
213
|
|
|
106
214
|
|
|
107
215
|
def _is_importable_reference(fn: Callable[..., Any]) -> bool:
|
|
108
|
-
"""Return True when a callable can be imported by module and qualname.
|
|
109
|
-
|
|
110
|
-
Args:
|
|
111
|
-
fn: Callable to inspect.
|
|
112
|
-
|
|
113
|
-
Returns:
|
|
114
|
-
True if importable by module/qualname.
|
|
115
|
-
"""
|
|
116
216
|
mod_name = getattr(fn, "__module__", None)
|
|
117
217
|
qualname = getattr(fn, "__qualname__", None)
|
|
118
218
|
if not mod_name or not qualname:
|
|
@@ -127,61 +227,9 @@ def _is_importable_reference(fn: Callable[..., Any]) -> bool:
|
|
|
127
227
|
return False
|
|
128
228
|
|
|
129
229
|
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
ratio=1 -> level=1
|
|
134
|
-
ratio=4 -> level=9
|
|
135
|
-
clamp beyond.
|
|
136
|
-
"""
|
|
137
|
-
ratio = n / max(1, limit)
|
|
138
|
-
x = min(1.0, max(0.0, (ratio - 1.0) / 3.0))
|
|
139
|
-
return max(1, min(9, int(round(1 + 8 * x))))
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
def _encode_result_blob(raw: bytes, byte_limit: int) -> bytes:
|
|
143
|
-
"""
|
|
144
|
-
For small payloads: return raw dill bytes (backwards compat).
|
|
145
|
-
For large payloads: wrap in framed header + zlib-compressed bytes (if beneficial).
|
|
146
|
-
"""
|
|
147
|
-
if len(raw) <= byte_limit:
|
|
148
|
-
return raw
|
|
149
|
-
|
|
150
|
-
level = _pick_zlib_level(len(raw), byte_limit)
|
|
151
|
-
compressed = zlib.compress(raw, level)
|
|
152
|
-
|
|
153
|
-
# If compression doesn't help, keep raw
|
|
154
|
-
if len(compressed) >= len(raw):
|
|
155
|
-
return raw
|
|
156
|
-
|
|
157
|
-
# Frame:
|
|
158
|
-
# MAGIC(3) + flags(u8) + orig_len(u32) + level(u8) + data
|
|
159
|
-
header = _MAGIC + struct.pack(">BIB", _FLAG_COMPRESSED, len(raw), level)
|
|
160
|
-
return header + compressed
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
def _decode_result_blob(blob: bytes) -> bytes:
|
|
164
|
-
"""
|
|
165
|
-
If framed, decompress if flagged, return raw dill bytes.
|
|
166
|
-
Else treat as raw dill bytes.
|
|
167
|
-
"""
|
|
168
|
-
if not blob.startswith(_MAGIC):
|
|
169
|
-
return blob
|
|
170
|
-
|
|
171
|
-
if len(blob) < 3 + 1 + 4 + 1:
|
|
172
|
-
raise ValueError("Framed result too short / corrupted.")
|
|
173
|
-
|
|
174
|
-
flags, orig_len, _level = struct.unpack(">BIB", blob[3 : 3 + 1 + 4 + 1])
|
|
175
|
-
data = blob[3 + 1 + 4 + 1 :]
|
|
176
|
-
|
|
177
|
-
if flags & _FLAG_COMPRESSED:
|
|
178
|
-
raw = zlib.decompress(data)
|
|
179
|
-
if orig_len and len(raw) != orig_len:
|
|
180
|
-
raise ValueError(f"Decompressed length mismatch: got {len(raw)}, expected {orig_len}")
|
|
181
|
-
return raw
|
|
182
|
-
|
|
183
|
-
return data
|
|
184
|
-
|
|
230
|
+
# ---------------------------
|
|
231
|
+
# Environment snapshot
|
|
232
|
+
# ---------------------------
|
|
185
233
|
|
|
186
234
|
def _dump_env(
|
|
187
235
|
fn: Callable[..., Any],
|
|
@@ -190,12 +238,6 @@ def _dump_env(
|
|
|
190
238
|
include_closure: bool,
|
|
191
239
|
filter_used_globals: bool,
|
|
192
240
|
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
|
193
|
-
"""
|
|
194
|
-
Returns (env, meta).
|
|
195
|
-
env is dill-able and contains:
|
|
196
|
-
- "globals": {name: value} (filtered to used names if enabled)
|
|
197
|
-
- "closure": {freevar: value} (capture only; injection not generally safe)
|
|
198
|
-
"""
|
|
199
241
|
env: Dict[str, Any] = {}
|
|
200
242
|
meta: Dict[str, Any] = {
|
|
201
243
|
"missing_globals": [],
|
|
@@ -242,13 +284,14 @@ def _dump_env(
|
|
|
242
284
|
return env, meta
|
|
243
285
|
|
|
244
286
|
|
|
245
|
-
# ----------
|
|
287
|
+
# ----------
|
|
288
|
+
# Main class
|
|
289
|
+
# ----------
|
|
246
290
|
|
|
247
291
|
@dataclass
|
|
248
292
|
class CallableSerde:
|
|
249
293
|
"""
|
|
250
294
|
Core field: `fn`
|
|
251
|
-
Serialized/backing fields used when fn isn't present yet.
|
|
252
295
|
|
|
253
296
|
kind:
|
|
254
297
|
- "auto": resolve import if possible else dill
|
|
@@ -258,6 +301,7 @@ class CallableSerde:
|
|
|
258
301
|
Optional env payload:
|
|
259
302
|
- env_b64: dill(base64) of {"globals": {...}, "closure": {...}}
|
|
260
303
|
"""
|
|
304
|
+
|
|
261
305
|
fn: Optional[Callable[..., Any]] = None
|
|
262
306
|
|
|
263
307
|
_kind: str = "auto" # "auto" | "import" | "dill"
|
|
@@ -273,48 +317,22 @@ class CallableSerde:
|
|
|
273
317
|
|
|
274
318
|
@classmethod
|
|
275
319
|
def from_callable(cls: type[T], x: Union[Callable[..., Any], T]) -> T:
|
|
276
|
-
"""Create a CallableSerde from a callable or existing instance.
|
|
277
|
-
|
|
278
|
-
Args:
|
|
279
|
-
x: Callable or CallableSerde instance.
|
|
280
|
-
|
|
281
|
-
Returns:
|
|
282
|
-
CallableSerde instance.
|
|
283
|
-
"""
|
|
284
320
|
if isinstance(x, cls):
|
|
285
321
|
return x
|
|
322
|
+
return cls(fn=x) # type: ignore[return-value]
|
|
286
323
|
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
return obj
|
|
290
|
-
|
|
291
|
-
# ----- lazy-ish properties (computed on access) -----
|
|
324
|
+
# ----- properties -----
|
|
292
325
|
|
|
293
326
|
@property
|
|
294
327
|
def module(self) -> Optional[str]:
|
|
295
|
-
"""Return the callable's module name if available.
|
|
296
|
-
|
|
297
|
-
Returns:
|
|
298
|
-
Module name or None.
|
|
299
|
-
"""
|
|
300
328
|
return self._module or (getattr(self.fn, "__module__", None) if self.fn else None)
|
|
301
329
|
|
|
302
330
|
@property
|
|
303
331
|
def qualname(self) -> Optional[str]:
|
|
304
|
-
"""Return the callable's qualified name if available.
|
|
305
|
-
|
|
306
|
-
Returns:
|
|
307
|
-
Qualified name or None.
|
|
308
|
-
"""
|
|
309
332
|
return self._qualname or (getattr(self.fn, "__qualname__", None) if self.fn else None)
|
|
310
333
|
|
|
311
334
|
@property
|
|
312
335
|
def file(self) -> Optional[str]:
|
|
313
|
-
"""Return the filesystem path of the callable's source file.
|
|
314
|
-
|
|
315
|
-
Returns:
|
|
316
|
-
File path or None.
|
|
317
|
-
"""
|
|
318
336
|
if not self.fn:
|
|
319
337
|
return None
|
|
320
338
|
f, _ = _callable_file_line(self.fn)
|
|
@@ -322,11 +340,6 @@ class CallableSerde:
|
|
|
322
340
|
|
|
323
341
|
@property
|
|
324
342
|
def line(self) -> Optional[int]:
|
|
325
|
-
"""Return the line number where the callable is defined.
|
|
326
|
-
|
|
327
|
-
Returns:
|
|
328
|
-
Line number or None.
|
|
329
|
-
"""
|
|
330
343
|
if not self.fn:
|
|
331
344
|
return None
|
|
332
345
|
_, ln = _callable_file_line(self.fn)
|
|
@@ -334,11 +347,6 @@ class CallableSerde:
|
|
|
334
347
|
|
|
335
348
|
@property
|
|
336
349
|
def pkg_root(self) -> Optional[str]:
|
|
337
|
-
"""Return the inferred package root for the callable, if known.
|
|
338
|
-
|
|
339
|
-
Returns:
|
|
340
|
-
Package root path or None.
|
|
341
|
-
"""
|
|
342
350
|
if self._pkg_root:
|
|
343
351
|
return self._pkg_root
|
|
344
352
|
if not self.file:
|
|
@@ -348,11 +356,6 @@ class CallableSerde:
|
|
|
348
356
|
|
|
349
357
|
@property
|
|
350
358
|
def relpath_from_pkg_root(self) -> Optional[str]:
|
|
351
|
-
"""Return the callable's path relative to the package root.
|
|
352
|
-
|
|
353
|
-
Returns:
|
|
354
|
-
Relative path or None.
|
|
355
|
-
"""
|
|
356
359
|
if not self.file or not self.pkg_root:
|
|
357
360
|
return None
|
|
358
361
|
try:
|
|
@@ -362,11 +365,6 @@ class CallableSerde:
|
|
|
362
365
|
|
|
363
366
|
@property
|
|
364
367
|
def importable(self) -> bool:
|
|
365
|
-
"""Return True when the callable can be imported by reference.
|
|
366
|
-
|
|
367
|
-
Returns:
|
|
368
|
-
True if importable by module/qualname.
|
|
369
|
-
"""
|
|
370
368
|
if self.fn is None:
|
|
371
369
|
return bool(self.module and self.qualname and "<locals>" not in (self.qualname or ""))
|
|
372
370
|
return _is_importable_reference(self.fn)
|
|
@@ -376,24 +374,12 @@ class CallableSerde:
|
|
|
376
374
|
def dump(
|
|
377
375
|
self,
|
|
378
376
|
*,
|
|
379
|
-
prefer: str = "import",
|
|
380
|
-
dump_env: str = "none",
|
|
377
|
+
prefer: str = "import", # "import" | "dill"
|
|
378
|
+
dump_env: str = "none", # "none" | "globals" | "closure" | "both"
|
|
381
379
|
filter_used_globals: bool = True,
|
|
382
380
|
env_keys: Optional[Iterable[str]] = None,
|
|
383
381
|
env_variables: Optional[Dict[str, str]] = None,
|
|
384
382
|
) -> Dict[str, Any]:
|
|
385
|
-
"""Serialize the callable into a dict for transport.
|
|
386
|
-
|
|
387
|
-
Args:
|
|
388
|
-
prefer: Preferred serialization kind.
|
|
389
|
-
dump_env: Environment payload selection.
|
|
390
|
-
filter_used_globals: Filter globals to referenced names.
|
|
391
|
-
env_keys: environment keys
|
|
392
|
-
env_variables: environment key values
|
|
393
|
-
|
|
394
|
-
Returns:
|
|
395
|
-
Serialized payload dict.
|
|
396
|
-
"""
|
|
397
383
|
kind = prefer
|
|
398
384
|
if kind == "import" and not self.importable:
|
|
399
385
|
kind = "dill"
|
|
@@ -420,7 +406,6 @@ class CallableSerde:
|
|
|
420
406
|
if env_keys:
|
|
421
407
|
for env_key in env_keys:
|
|
422
408
|
existing = os.getenv(env_key)
|
|
423
|
-
|
|
424
409
|
if existing:
|
|
425
410
|
env_variables[env_key] = existing
|
|
426
411
|
|
|
@@ -432,6 +417,7 @@ class CallableSerde:
|
|
|
432
417
|
raise ValueError("dump_env requested but fn is not present.")
|
|
433
418
|
include_globals = dump_env in ("globals", "both")
|
|
434
419
|
include_closure = dump_env in ("closure", "both")
|
|
420
|
+
|
|
435
421
|
env, meta = _dump_env(
|
|
436
422
|
self.fn,
|
|
437
423
|
include_globals=include_globals,
|
|
@@ -448,15 +434,6 @@ class CallableSerde:
|
|
|
448
434
|
|
|
449
435
|
@classmethod
|
|
450
436
|
def load(cls: type[T], d: Dict[str, Any], *, add_pkg_root_to_syspath: bool = True) -> T:
|
|
451
|
-
"""Construct a CallableSerde from a serialized dict payload.
|
|
452
|
-
|
|
453
|
-
Args:
|
|
454
|
-
d: Serialized payload dict.
|
|
455
|
-
add_pkg_root_to_syspath: Add package root to sys.path if True.
|
|
456
|
-
|
|
457
|
-
Returns:
|
|
458
|
-
CallableSerde instance.
|
|
459
|
-
"""
|
|
460
437
|
obj = cls(
|
|
461
438
|
fn=None,
|
|
462
439
|
_kind=d.get("kind", "auto"),
|
|
@@ -474,14 +451,6 @@ class CallableSerde:
|
|
|
474
451
|
return obj # type: ignore[return-value]
|
|
475
452
|
|
|
476
453
|
def materialize(self, *, add_pkg_root_to_syspath: bool = True) -> Callable[..., Any]:
|
|
477
|
-
"""Resolve and return the underlying callable.
|
|
478
|
-
|
|
479
|
-
Args:
|
|
480
|
-
add_pkg_root_to_syspath: Add package root to sys.path if True.
|
|
481
|
-
|
|
482
|
-
Returns:
|
|
483
|
-
Resolved callable.
|
|
484
|
-
"""
|
|
485
454
|
if self.fn is not None:
|
|
486
455
|
return self.fn
|
|
487
456
|
|
|
@@ -515,19 +484,12 @@ class CallableSerde:
|
|
|
515
484
|
raise ValueError(f"Unknown kind: {kind}")
|
|
516
485
|
|
|
517
486
|
def __call__(self, *args: Any, **kwargs: Any) -> Any:
|
|
518
|
-
"""Invoke the materialized callable with the provided arguments.
|
|
519
|
-
|
|
520
|
-
Args:
|
|
521
|
-
*args: Positional args for the callable.
|
|
522
|
-
**kwargs: Keyword args for the callable.
|
|
523
|
-
|
|
524
|
-
Returns:
|
|
525
|
-
Callable return value.
|
|
526
|
-
"""
|
|
527
487
|
fn = self.materialize()
|
|
528
488
|
return fn(*args, **kwargs)
|
|
529
489
|
|
|
530
|
-
#
|
|
490
|
+
# -------------------------
|
|
491
|
+
# Command execution bridge
|
|
492
|
+
# -------------------------
|
|
531
493
|
|
|
532
494
|
def to_command(
|
|
533
495
|
self,
|
|
@@ -536,24 +498,21 @@ class CallableSerde:
|
|
|
536
498
|
*,
|
|
537
499
|
result_tag: str = "__CALLABLE_SERDE_RESULT__",
|
|
538
500
|
prefer: str = "dill",
|
|
539
|
-
byte_limit: int =
|
|
540
|
-
dump_env: str = "none",
|
|
501
|
+
byte_limit: int = 64 * 1024,
|
|
502
|
+
dump_env: str = "none", # "none" | "globals" | "closure" | "both"
|
|
541
503
|
filter_used_globals: bool = True,
|
|
542
504
|
env_keys: Optional[Iterable[str]] = None,
|
|
543
505
|
env_variables: Optional[Dict[str, str]] = None,
|
|
506
|
+
file_dump_limit: int = 512 * 1024,
|
|
507
|
+
transaction_id: Optional[str] = None
|
|
544
508
|
) -> str:
|
|
545
509
|
"""
|
|
546
510
|
Returns Python code string to execute in another interpreter.
|
|
547
|
-
|
|
548
|
-
|
|
549
|
-
|
|
550
|
-
Also compresses the input call payload (args/kwargs) using the same framing
|
|
551
|
-
scheme when it exceeds byte_limit.
|
|
511
|
+
Emits exactly one line to stdout:
|
|
512
|
+
"{result_tag}:{payload_len}:{payload}\\n"
|
|
513
|
+
where payload is base64(raw dill bytes or CS2-framed blob), or "DBXPATH:<path>" when the result was written to a file.
|
|
552
514
|
"""
|
|
553
|
-
import base64
|
|
554
515
|
import json
|
|
555
|
-
import struct
|
|
556
|
-
import zlib
|
|
557
516
|
|
|
558
517
|
args = args or ()
|
|
559
518
|
kwargs = kwargs or {}
|
|
@@ -567,75 +526,31 @@ class CallableSerde:
|
|
|
567
526
|
)
|
|
568
527
|
serde_json = json.dumps(serde_dict, ensure_ascii=False)
|
|
569
528
|
|
|
570
|
-
#
|
|
571
|
-
MAGIC = b"CS1"
|
|
572
|
-
FLAG_COMPRESSED = 1
|
|
573
|
-
|
|
574
|
-
def _pick_level(n: int, limit: int) -> int:
|
|
575
|
-
ratio = n / max(1, limit)
|
|
576
|
-
x = min(1.0, max(0.0, (ratio - 1.0) / 3.0))
|
|
577
|
-
return max(1, min(9, int(round(1 + 8 * x))))
|
|
578
|
-
|
|
579
|
-
def _encode_blob(raw: bytes, limit: int) -> bytes:
|
|
580
|
-
if len(raw) <= limit:
|
|
581
|
-
return raw
|
|
582
|
-
level = _pick_level(len(raw), limit)
|
|
583
|
-
compressed = zlib.compress(raw, level)
|
|
584
|
-
if len(compressed) >= len(raw):
|
|
585
|
-
return raw
|
|
586
|
-
header = MAGIC + struct.pack(">BIB", FLAG_COMPRESSED, len(raw), level)
|
|
587
|
-
return header + compressed
|
|
588
|
-
|
|
529
|
+
# args/kwargs payload: stdlib-only compression (lzma/zlib)
|
|
589
530
|
call_raw = dill.dumps((args, kwargs), recurse=True)
|
|
590
|
-
call_blob =
|
|
531
|
+
call_blob = _encode_wire_blob_stdlib(call_raw, int(byte_limit))
|
|
591
532
|
call_payload_b64 = base64.b64encode(call_blob).decode("ascii")
|
|
533
|
+
transaction_id = transaction_id or secrets.token_urlsafe(16)
|
|
592
534
|
|
|
593
|
-
# NOTE: plain string template + replace. No f-string. No brace escaping.
|
|
594
535
|
template = r"""
|
|
595
|
-
import base64, json,
|
|
536
|
+
import base64, json, os, sys
|
|
596
537
|
import dill
|
|
538
|
+
import pandas
|
|
539
|
+
|
|
540
|
+
from yggdrasil.databricks import Workspace
|
|
541
|
+
from yggdrasil.pyutils.callable_serde import (
|
|
542
|
+
CallableSerde,
|
|
543
|
+
_decode_result_blob,
|
|
544
|
+
_encode_result_blob,
|
|
545
|
+
)
|
|
597
546
|
|
|
598
547
|
RESULT_TAG = __RESULT_TAG__
|
|
599
548
|
BYTE_LIMIT = __BYTE_LIMIT__
|
|
600
|
-
|
|
601
|
-
|
|
602
|
-
FLAG_COMPRESSED = 1
|
|
603
|
-
|
|
604
|
-
def _resolve_attr_chain(mod, qualname: str):
|
|
605
|
-
obj = mod
|
|
606
|
-
for part in qualname.split("."):
|
|
607
|
-
obj = getattr(obj, part)
|
|
608
|
-
return obj
|
|
609
|
-
|
|
610
|
-
def _pick_level(n: int, limit: int) -> int:
|
|
611
|
-
ratio = n / max(1, limit)
|
|
612
|
-
x = min(1.0, max(0.0, (ratio - 1.0) / 3.0))
|
|
613
|
-
return max(1, min(9, int(round(1 + 8 * x))))
|
|
614
|
-
|
|
615
|
-
def _encode_result(raw: bytes, byte_limit: int) -> bytes:
|
|
616
|
-
if len(raw) <= byte_limit:
|
|
617
|
-
return raw
|
|
618
|
-
level = _pick_level(len(raw), byte_limit)
|
|
619
|
-
compressed = zlib.compress(raw, level)
|
|
620
|
-
if len(compressed) >= len(raw):
|
|
621
|
-
return raw
|
|
622
|
-
header = MAGIC + struct.pack(">BIB", FLAG_COMPRESSED, len(raw), level)
|
|
623
|
-
return header + compressed
|
|
624
|
-
|
|
625
|
-
def _decode_blob(blob: bytes) -> bytes:
|
|
626
|
-
# If it's framed (MAGIC + header), decompress; else return as-is.
|
|
627
|
-
if isinstance(blob, (bytes, bytearray)) and len(blob) >= 3 and blob[:3] == MAGIC:
|
|
628
|
-
if len(blob) >= 3 + 6:
|
|
629
|
-
flag, orig_len, level = struct.unpack(">BIB", blob[3:3+6])
|
|
630
|
-
if flag & FLAG_COMPRESSED:
|
|
631
|
-
raw = zlib.decompress(blob[3+6:])
|
|
632
|
-
# best-effort sanity check; don't hard-fail on mismatch
|
|
633
|
-
if isinstance(orig_len, int) and orig_len > 0 and len(raw) != orig_len:
|
|
634
|
-
return raw
|
|
635
|
-
return raw
|
|
636
|
-
return blob
|
|
549
|
+
FILE_DUMP_LIMIT = __FILE_DUMP_LIMIT__
|
|
550
|
+
TRANSACTION_ID = __TRANSACTION_ID__
|
|
637
551
|
|
|
638
552
|
def _needed_globals(fn) -> set[str]:
|
|
553
|
+
import dis
|
|
639
554
|
names = set()
|
|
640
555
|
try:
|
|
641
556
|
for ins in dis.get_instructions(fn):
|
|
@@ -657,34 +572,22 @@ def _apply_env(fn, env: dict, filter_used: bool):
|
|
|
657
572
|
return
|
|
658
573
|
|
|
659
574
|
env_g = env.get("globals") or {}
|
|
660
|
-
if env_g:
|
|
661
|
-
|
|
662
|
-
needed = _needed_globals(fn)
|
|
663
|
-
for name in needed:
|
|
664
|
-
if name in env_g:
|
|
665
|
-
g.setdefault(name, env_g[name])
|
|
666
|
-
else:
|
|
667
|
-
for name, val in env_g.items():
|
|
668
|
-
g.setdefault(name, val)
|
|
575
|
+
if not env_g:
|
|
576
|
+
return
|
|
669
577
|
|
|
670
|
-
|
|
578
|
+
if filter_used:
|
|
579
|
+
needed = _needed_globals(fn)
|
|
580
|
+
for name in needed:
|
|
581
|
+
if name in env_g:
|
|
582
|
+
g.setdefault(name, env_g[name])
|
|
583
|
+
else:
|
|
584
|
+
for name, val in env_g.items():
|
|
585
|
+
g.setdefault(name, val)
|
|
671
586
|
|
|
672
|
-
|
|
673
|
-
if pkg_root and pkg_root not in sys.path:
|
|
674
|
-
sys.path.insert(0, pkg_root)
|
|
587
|
+
serde = json.loads(__SERDE_JSON__)
|
|
675
588
|
|
|
676
|
-
|
|
677
|
-
|
|
678
|
-
mod = importlib.import_module(serde["module"])
|
|
679
|
-
fn = _resolve_attr_chain(mod, serde["qualname"])
|
|
680
|
-
elif kind == "dill":
|
|
681
|
-
fn = dill.loads(base64.b64decode(serde["dill_b64"]))
|
|
682
|
-
else:
|
|
683
|
-
if serde.get("module") and serde.get("qualname") and "<locals>" not in serde.get("qualname", ""):
|
|
684
|
-
mod = importlib.import_module(serde["module"])
|
|
685
|
-
fn = _resolve_attr_chain(mod, serde["qualname"])
|
|
686
|
-
else:
|
|
687
|
-
fn = dill.loads(base64.b64decode(serde["dill_b64"]))
|
|
589
|
+
cs = CallableSerde.load(serde, add_pkg_root_to_syspath=True)
|
|
590
|
+
fn = cs.materialize(add_pkg_root_to_syspath=True)
|
|
688
591
|
|
|
689
592
|
osenv = serde.get("osenv")
|
|
690
593
|
if osenv:
|
|
@@ -698,42 +601,127 @@ if env_b64:
|
|
|
698
601
|
_apply_env(fn, env, bool(meta.get("filter_used_globals", True)))
|
|
699
602
|
|
|
700
603
|
call_blob = base64.b64decode(__CALL_PAYLOAD_B64__)
|
|
701
|
-
call_raw =
|
|
604
|
+
call_raw = _decode_result_blob(call_blob)
|
|
702
605
|
args, kwargs = dill.loads(call_raw)
|
|
703
606
|
|
|
704
607
|
res = fn(*args, **kwargs)
|
|
705
|
-
raw = dill.dumps(res, recurse=True)
|
|
706
|
-
blob = _encode_result(raw, BYTE_LIMIT)
|
|
707
608
|
|
|
708
|
-
|
|
709
|
-
|
|
710
|
-
|
|
609
|
+
if isinstance(res, pandas.DataFrame):
|
|
610
|
+
dump_path = Workspace().shared_cache_path("/cmd/" + TRANSACTION_ID + ".parquet")
|
|
611
|
+
|
|
612
|
+
with dump_path.open(mode="wb") as f:
|
|
613
|
+
res.to_parquet(f)
|
|
614
|
+
|
|
615
|
+
blob = "DBXPATH:" + str(dump_path)
|
|
616
|
+
else:
|
|
617
|
+
raw = dill.dumps(res)
|
|
618
|
+
blob = _encode_result_blob(raw, BYTE_LIMIT)
|
|
619
|
+
|
|
620
|
+
if len(blob) > FILE_DUMP_LIMIT:
|
|
621
|
+
dump_path = Workspace().shared_cache_path("/cmd/" + TRANSACTION_ID)
|
|
622
|
+
|
|
623
|
+
with dump_path.open(mode="wb") as f:
|
|
624
|
+
f.write_all_bytes(data=blob)
|
|
625
|
+
|
|
626
|
+
blob = "DBXPATH:" + str(dump_path)
|
|
627
|
+
else:
|
|
628
|
+
blob = base64.b64encode(blob).decode('ascii')
|
|
711
629
|
|
|
712
|
-
|
|
630
|
+
sys.stdout.write(f"{RESULT_TAG}:{len(blob)}:{blob}\n")
|
|
631
|
+
sys.stdout.flush()
|
|
632
|
+
"""
|
|
633
|
+
|
|
634
|
+
return (
|
|
713
635
|
template
|
|
714
636
|
.replace("__RESULT_TAG__", repr(result_tag))
|
|
715
637
|
.replace("__BYTE_LIMIT__", str(int(byte_limit)))
|
|
716
638
|
.replace("__SERDE_JSON__", repr(serde_json))
|
|
717
639
|
.replace("__CALL_PAYLOAD_B64__", repr(call_payload_b64))
|
|
640
|
+
.replace("__FILE_DUMP_LIMIT__", str(int(file_dump_limit)))
|
|
641
|
+
.replace("__TRANSACTION_ID__", repr(str(transaction_id)))
|
|
718
642
|
)
|
|
719
643
|
|
|
720
|
-
return code
|
|
721
|
-
|
|
722
644
|
@staticmethod
def parse_command_result(
    output: str,
    *,
    result_tag: str = "__CALLABLE_SERDE_RESULT__",
    workspace: Optional["Workspace"] = None
) -> Any:
    """Extract and decode the result emitted by a generated command.

    The last tagged segment of ``output`` is expected to look like
    ``"{result_tag}:{length}:{payload}"``, where ``length`` is the number of
    characters in ``payload``. The payload is either:

    - base64 text of a raw dill blob or a CS2-framed blob, or
    - ``"DBXPATH:<path>"``, a pointer to a file holding the blob (or a
      parquet file when the path ends with ``.parquet``). The file is
      removed after it has been read.

    ``length`` is used to cut trailing output off the payload and to detect
    truncation before anything is decoded.

    Args:
        output: Captured stdout of the command.
        result_tag: Tag that prefixes the result segment.
        workspace: Workspace used to resolve ``DBXPATH`` pointers; a default
            ``Workspace()`` is created when omitted.

    Returns:
        The deserialized result, or a pandas DataFrame for parquet dumps.

    Raises:
        ValueError: The tag is missing, the segment is malformed or truncated,
            the base64 is invalid, or the payload cannot be unpickled.
    """
    prefix = f"{result_tag}:"
    path_marker = "DBXPATH:"

    if prefix not in output:
        raise ValueError(f"Result tag not found in output: {result_tag}")

    # Grab everything after the LAST occurrence of the tag
    _, tail = output.rsplit(prefix, 1)

    # Parse "{length}:{payload}"
    try:
        nbytes_str, string_result = tail.split(":", 1)
    except ValueError as e:
        raise ValueError(
            f"Malformed result line after tag {result_tag}. "
            "Expected '{tag}:{nbytes}:{b64}'."
        ) from e

    try:
        content_length = int(nbytes_str)
    except ValueError as e:
        raise ValueError(f"Malformed byte count '{nbytes_str}' after tag {result_tag}") from e

    if content_length < 0:
        raise ValueError(f"Negative byte count {content_length} after tag {result_tag}")

    # Drop the trailing newline / any later output, then verify nothing is missing.
    string_result = string_result[:content_length]

    if len(string_result) != content_length:
        raise ValueError(
            "Got truncated result content from command, got %s bytes and expected %s bytes" % (
                len(string_result),
                content_length
            )
        )

    if string_result.startswith(path_marker):
        from ..databricks.workspaces import Workspace

        workspace = Workspace() if workspace is None else workspace
        # Strip only the leading marker; the path itself may contain the same text.
        path = workspace.dbfs_path(string_result[len(path_marker):])

        if path.name.endswith(".parquet"):
            import pandas

            with path.open(mode="rb") as f:
                buf = io.BytesIO(f.read_all_bytes())

            path.rmfile()
            buf.seek(0)
            return pandas.read_parquet(buf)

        with path.open(mode="rb") as f:
            blob = f.read_all_bytes()

        path.rmfile()
    else:
        # Strict base64 decode (rejects junk chars)
        try:
            blob = base64.b64decode(string_result.encode("ascii"), validate=True)
        except (UnicodeEncodeError, binascii.Error) as e:
            raise ValueError("Invalid base64 payload after result tag (corrupted/contaminated).") from e

    raw = _decode_result_blob(blob)
    try:
        result = dill.loads(raw)
    except Exception as e:
        raise ValueError("Failed to dill.loads decoded payload") from e

    return result
|