ygg 0.1.44__py3-none-any.whl → 0.1.46__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,43 +1,168 @@
1
- """Callable serialization helpers for cross-process execution."""
1
+ """Callable serialization helpers for cross-process execution.
2
+
3
+ Design goals:
4
+ - Prefer import-by-reference when possible (module + qualname), fallback to dill.
5
+ - Optional environment payload: selected globals and/or closure values.
6
+ - Cross-process bridge: generate a self-contained Python command string that:
7
+ 1) materializes the callable
8
+ 2) decodes args/kwargs payload
9
+ 3) executes
10
+ 4) emits a single tagged base64 line with a compressed result blob
11
+
12
+ Compression/framing:
13
+ - CS2 framing only (no CS1 logic).
14
+ - Frame header: MAGIC(3) + codec(u8) + orig_len(u32) + param(u8) + data
15
+ - Codecs:
16
+ 0 raw (rarely used; mostly means "no frame")
17
+ 1 zlib
18
+ 2 lzma
19
+ 3 zstd (optional dependency)
20
+ """
2
21
 
3
22
  from __future__ import annotations
4
23
 
5
24
  import base64
25
+ import binascii
6
26
  import dis
7
27
  import importlib
8
28
  import inspect
9
- import json
29
+ import io
30
+ import lzma
10
31
  import os
32
+ import secrets
11
33
  import struct
12
34
  import sys
13
35
  import zlib
14
36
  from dataclasses import dataclass
15
37
  from pathlib import Path
16
- from typing import Any, Callable, Dict, Optional, Set, Tuple, TypeVar, Union, Iterable
38
+ from typing import Any, Callable, Dict, Iterable, Optional, Set, Tuple, TypeVar, Union, TYPE_CHECKING
17
39
 
18
40
  import dill
19
41
 
42
+ if TYPE_CHECKING:
43
+ from ..databricks.workspaces import Workspace
44
+
20
45
  __all__ = ["CallableSerde"]
21
46
 
22
47
  T = TypeVar("T", bound="CallableSerde")
23
48
 
49
+ # ---------------------------
50
+ # Framing / compression (CS2)
51
+ # ---------------------------
24
52
 
25
- # ---------- internal helpers ----------
53
+ _MAGIC = b"CS2"
26
54
 
27
- _MAGIC = b"CS1" # CallableSerde framing v1
28
- _FLAG_COMPRESSED = 1
55
+ _CODEC_RAW = 0
56
+ _CODEC_ZLIB = 1
57
+ _CODEC_LZMA = 2
58
+ _CODEC_ZSTD = 3
29
59
 
30
60
 
31
- def _resolve_attr_chain(mod: Any, qualname: str) -> Any:
32
- """Resolve a dotted attribute path from a module.
61
+ def _try_import_zstd():
62
+ try:
63
+ import zstandard as zstd # type: ignore
64
+ return zstd
65
+ except Exception:
66
+ return None
33
67
 
34
- Args:
35
- mod: Module to traverse.
36
- qualname: Dotted qualified name.
37
68
 
38
- Returns:
39
- Resolved attribute.
40
- """
69
+ def _pick_zlib_level(n: int, limit: int) -> int:
70
+ """Ramp compression level 1..9 based on how far we exceed the byte_limit."""
71
+ ratio = n / max(1, limit)
72
+ x = min(1.0, max(0.0, (ratio - 1.0) / 3.0))
73
+ return max(1, min(9, int(round(1 + 8 * x))))
74
+
75
+
76
+ def _frame(codec: int, orig_len: int, param: int, payload: bytes) -> bytes:
77
+ return _MAGIC + struct.pack(">BIB", int(codec) & 0xFF, int(orig_len), int(param) & 0xFF) + payload
78
+
79
+
80
+ def _encode_with_candidates(raw: bytes, *, byte_limit: int, allow_zstd: bool) -> bytes:
81
+ """Choose the smallest among available codecs; fall back to raw if not beneficial."""
82
+ if len(raw) <= byte_limit:
83
+ return raw
84
+
85
+ candidates: list[bytes] = []
86
+
87
+ if allow_zstd:
88
+ zstd = _try_import_zstd()
89
+ if zstd is not None:
90
+ for lvl in (6, 10, 15):
91
+ try:
92
+ c = zstd.ZstdCompressor(level=lvl).compress(raw)
93
+ candidates.append(_frame(_CODEC_ZSTD, len(raw), lvl, c))
94
+ except Exception:
95
+ pass
96
+
97
+ for preset in (6, 9):
98
+ try:
99
+ c = lzma.compress(raw, preset=preset)
100
+ candidates.append(_frame(_CODEC_LZMA, len(raw), preset, c))
101
+ except Exception:
102
+ pass
103
+
104
+ lvl = _pick_zlib_level(len(raw), byte_limit)
105
+ try:
106
+ c = zlib.compress(raw, lvl)
107
+ candidates.append(_frame(_CODEC_ZLIB, len(raw), lvl, c))
108
+ except Exception:
109
+ pass
110
+
111
+ if not candidates:
112
+ return raw
113
+
114
+ best = min(candidates, key=len)
115
+ return best if len(best) < len(raw) else raw
116
+
117
+
118
+ def _encode_result_blob(raw: bytes, byte_limit: int) -> bytes:
119
+ """Result payload: zstd (if available) -> lzma -> zlib."""
120
+ return _encode_with_candidates(raw, byte_limit=byte_limit, allow_zstd=True)
121
+
122
+
123
+ def _encode_wire_blob_stdlib(raw: bytes, byte_limit: int) -> bytes:
124
+ """Wire payload (args/kwargs): stdlib-only (lzma -> zlib)."""
125
+ return _encode_with_candidates(raw, byte_limit=byte_limit, allow_zstd=False)
126
+
127
+
128
+ def _decode_result_blob(blob: bytes) -> bytes:
129
+ """Decode raw or CS2 framed data (no CS1 support)."""
130
+ if not isinstance(blob, (bytes, bytearray)) or len(blob) < 3:
131
+ return blob # type: ignore[return-value]
132
+
133
+ if not blob.startswith(_MAGIC):
134
+ return blob
135
+
136
+ if len(blob) < 3 + 6:
137
+ raise ValueError("CS2 framed blob too short / truncated.")
138
+
139
+ codec, orig_len, _param = struct.unpack(">BIB", blob[3 : 3 + 6])
140
+ data = blob[3 + 6 :]
141
+
142
+ if codec == _CODEC_RAW:
143
+ raw = data
144
+ elif codec == _CODEC_ZLIB:
145
+ raw = zlib.decompress(data)
146
+ elif codec == _CODEC_LZMA:
147
+ raw = lzma.decompress(data)
148
+ elif codec == _CODEC_ZSTD:
149
+ zstd = _try_import_zstd()
150
+ if zstd is None:
151
+ raise RuntimeError("CS2 uses zstd but 'zstandard' is not installed.")
152
+ raw = zstd.ZstdDecompressor().decompress(data, max_output_size=int(orig_len) if orig_len else 0)
153
+ else:
154
+ raise ValueError(f"Unknown CS2 codec: {codec}")
155
+
156
+ if orig_len and len(raw) != orig_len:
157
+ raise ValueError(f"Decoded length mismatch: got {len(raw)}, expected {orig_len}")
158
+ return raw
159
+
160
+
161
+ # ---------------------------
162
+ # Callable reference helpers
163
+ # ---------------------------
164
+
165
+ def _resolve_attr_chain(mod: Any, qualname: str) -> Any:
41
166
  obj = mod
42
167
  for part in qualname.split("."):
43
168
  obj = getattr(obj, part)
@@ -45,10 +170,6 @@ def _resolve_attr_chain(mod: Any, qualname: str) -> Any:
45
170
 
46
171
 
47
172
  def _find_pkg_root_from_file(file_path: Path) -> Optional[Path]:
48
- """
49
- Walk up parents while __init__.py exists.
50
- Return the directory that should be on sys.path (parent of top package dir).
51
- """
52
173
  file_path = file_path.resolve()
53
174
  d = file_path.parent
54
175
 
@@ -61,14 +182,6 @@ def _find_pkg_root_from_file(file_path: Path) -> Optional[Path]:
61
182
 
62
183
 
63
184
  def _callable_file_line(fn: Callable[..., Any]) -> Tuple[Optional[str], Optional[int]]:
64
- """Return the source file path and line number for a callable.
65
-
66
- Args:
67
- fn: Callable to inspect.
68
-
69
- Returns:
70
- Tuple of (file path, line number).
71
- """
72
185
  file = None
73
186
  line = None
74
187
  try:
@@ -84,17 +197,12 @@ def _callable_file_line(fn: Callable[..., Any]) -> Tuple[Optional[str], Optional
84
197
 
85
198
 
86
199
  def _referenced_global_names(fn: Callable[..., Any]) -> Set[str]:
87
- """
88
- Names that the function *actually* resolves from globals/namespaces at runtime.
89
- Uses bytecode to avoid shipping random junk.
90
- """
91
200
  names: Set[str] = set()
92
201
  try:
93
202
  for ins in dis.get_instructions(fn):
94
203
  if ins.opname in ("LOAD_GLOBAL", "LOAD_NAME") and isinstance(ins.argval, str):
95
204
  names.add(ins.argval)
96
205
  except Exception:
97
- # fallback: less precise
98
206
  try:
99
207
  names.update(getattr(fn.__code__, "co_names", ()) or ())
100
208
  except Exception:
@@ -105,14 +213,6 @@ def _referenced_global_names(fn: Callable[..., Any]) -> Set[str]:
105
213
 
106
214
 
107
215
  def _is_importable_reference(fn: Callable[..., Any]) -> bool:
108
- """Return True when a callable can be imported by module and qualname.
109
-
110
- Args:
111
- fn: Callable to inspect.
112
-
113
- Returns:
114
- True if importable by module/qualname.
115
- """
116
216
  mod_name = getattr(fn, "__module__", None)
117
217
  qualname = getattr(fn, "__qualname__", None)
118
218
  if not mod_name or not qualname:
@@ -127,61 +227,9 @@ def _is_importable_reference(fn: Callable[..., Any]) -> bool:
127
227
  return False
128
228
 
129
229
 
130
- def _pick_zlib_level(n: int, limit: int) -> int:
131
- """
132
- Ramp compression level 1..9 based on how much payload exceeds byte_limit.
133
- ratio=1 -> level=1
134
- ratio=4 -> level=9
135
- clamp beyond.
136
- """
137
- ratio = n / max(1, limit)
138
- x = min(1.0, max(0.0, (ratio - 1.0) / 3.0))
139
- return max(1, min(9, int(round(1 + 8 * x))))
140
-
141
-
142
- def _encode_result_blob(raw: bytes, byte_limit: int) -> bytes:
143
- """
144
- For small payloads: return raw dill bytes (backwards compat).
145
- For large payloads: wrap in framed header + zlib-compressed bytes (if beneficial).
146
- """
147
- if len(raw) <= byte_limit:
148
- return raw
149
-
150
- level = _pick_zlib_level(len(raw), byte_limit)
151
- compressed = zlib.compress(raw, level)
152
-
153
- # If compression doesn't help, keep raw
154
- if len(compressed) >= len(raw):
155
- return raw
156
-
157
- # Frame:
158
- # MAGIC(3) + flags(u8) + orig_len(u32) + level(u8) + data
159
- header = _MAGIC + struct.pack(">BIB", _FLAG_COMPRESSED, len(raw), level)
160
- return header + compressed
161
-
162
-
163
- def _decode_result_blob(blob: bytes) -> bytes:
164
- """
165
- If framed, decompress if flagged, return raw dill bytes.
166
- Else treat as raw dill bytes.
167
- """
168
- if not blob.startswith(_MAGIC):
169
- return blob
170
-
171
- if len(blob) < 3 + 1 + 4 + 1:
172
- raise ValueError("Framed result too short / corrupted.")
173
-
174
- flags, orig_len, _level = struct.unpack(">BIB", blob[3 : 3 + 1 + 4 + 1])
175
- data = blob[3 + 1 + 4 + 1 :]
176
-
177
- if flags & _FLAG_COMPRESSED:
178
- raw = zlib.decompress(data)
179
- if orig_len and len(raw) != orig_len:
180
- raise ValueError(f"Decompressed length mismatch: got {len(raw)}, expected {orig_len}")
181
- return raw
182
-
183
- return data
184
-
230
+ # ---------------------------
231
+ # Environment snapshot
232
+ # ---------------------------
185
233
 
186
234
  def _dump_env(
187
235
  fn: Callable[..., Any],
@@ -190,12 +238,6 @@ def _dump_env(
190
238
  include_closure: bool,
191
239
  filter_used_globals: bool,
192
240
  ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
193
- """
194
- Returns (env, meta).
195
- env is dill-able and contains:
196
- - "globals": {name: value} (filtered to used names if enabled)
197
- - "closure": {freevar: value} (capture only; injection not generally safe)
198
- """
199
241
  env: Dict[str, Any] = {}
200
242
  meta: Dict[str, Any] = {
201
243
  "missing_globals": [],
@@ -242,13 +284,14 @@ def _dump_env(
242
284
  return env, meta
243
285
 
244
286
 
245
- # ---------- main class ----------
287
+ # ----------
288
+ # Main class
289
+ # ----------
246
290
 
247
291
  @dataclass
248
292
  class CallableSerde:
249
293
  """
250
294
  Core field: `fn`
251
- Serialized/backing fields used when fn isn't present yet.
252
295
 
253
296
  kind:
254
297
  - "auto": resolve import if possible else dill
@@ -258,6 +301,7 @@ class CallableSerde:
258
301
  Optional env payload:
259
302
  - env_b64: dill(base64) of {"globals": {...}, "closure": {...}}
260
303
  """
304
+
261
305
  fn: Optional[Callable[..., Any]] = None
262
306
 
263
307
  _kind: str = "auto" # "auto" | "import" | "dill"
@@ -273,48 +317,22 @@ class CallableSerde:
273
317
 
274
318
  @classmethod
275
319
  def from_callable(cls: type[T], x: Union[Callable[..., Any], T]) -> T:
276
- """Create a CallableSerde from a callable or existing instance.
277
-
278
- Args:
279
- x: Callable or CallableSerde instance.
280
-
281
- Returns:
282
- CallableSerde instance.
283
- """
284
320
  if isinstance(x, cls):
285
321
  return x
322
+ return cls(fn=x) # type: ignore[return-value]
286
323
 
287
- obj = cls(fn=x) # type: ignore[return-value]
288
-
289
- return obj
290
-
291
- # ----- lazy-ish properties (computed on access) -----
324
+ # ----- properties -----
292
325
 
293
326
  @property
294
327
  def module(self) -> Optional[str]:
295
- """Return the callable's module name if available.
296
-
297
- Returns:
298
- Module name or None.
299
- """
300
328
  return self._module or (getattr(self.fn, "__module__", None) if self.fn else None)
301
329
 
302
330
  @property
303
331
  def qualname(self) -> Optional[str]:
304
- """Return the callable's qualified name if available.
305
-
306
- Returns:
307
- Qualified name or None.
308
- """
309
332
  return self._qualname or (getattr(self.fn, "__qualname__", None) if self.fn else None)
310
333
 
311
334
  @property
312
335
  def file(self) -> Optional[str]:
313
- """Return the filesystem path of the callable's source file.
314
-
315
- Returns:
316
- File path or None.
317
- """
318
336
  if not self.fn:
319
337
  return None
320
338
  f, _ = _callable_file_line(self.fn)
@@ -322,11 +340,6 @@ class CallableSerde:
322
340
 
323
341
  @property
324
342
  def line(self) -> Optional[int]:
325
- """Return the line number where the callable is defined.
326
-
327
- Returns:
328
- Line number or None.
329
- """
330
343
  if not self.fn:
331
344
  return None
332
345
  _, ln = _callable_file_line(self.fn)
@@ -334,11 +347,6 @@ class CallableSerde:
334
347
 
335
348
  @property
336
349
  def pkg_root(self) -> Optional[str]:
337
- """Return the inferred package root for the callable, if known.
338
-
339
- Returns:
340
- Package root path or None.
341
- """
342
350
  if self._pkg_root:
343
351
  return self._pkg_root
344
352
  if not self.file:
@@ -348,11 +356,6 @@ class CallableSerde:
348
356
 
349
357
  @property
350
358
  def relpath_from_pkg_root(self) -> Optional[str]:
351
- """Return the callable's path relative to the package root.
352
-
353
- Returns:
354
- Relative path or None.
355
- """
356
359
  if not self.file or not self.pkg_root:
357
360
  return None
358
361
  try:
@@ -362,11 +365,6 @@ class CallableSerde:
362
365
 
363
366
  @property
364
367
  def importable(self) -> bool:
365
- """Return True when the callable can be imported by reference.
366
-
367
- Returns:
368
- True if importable by module/qualname.
369
- """
370
368
  if self.fn is None:
371
369
  return bool(self.module and self.qualname and "<locals>" not in (self.qualname or ""))
372
370
  return _is_importable_reference(self.fn)
@@ -376,24 +374,12 @@ class CallableSerde:
376
374
  def dump(
377
375
  self,
378
376
  *,
379
- prefer: str = "import", # "import" | "dill"
380
- dump_env: str = "none", # "none" | "globals" | "closure" | "both"
377
+ prefer: str = "import", # "import" | "dill"
378
+ dump_env: str = "none", # "none" | "globals" | "closure" | "both"
381
379
  filter_used_globals: bool = True,
382
380
  env_keys: Optional[Iterable[str]] = None,
383
381
  env_variables: Optional[Dict[str, str]] = None,
384
382
  ) -> Dict[str, Any]:
385
- """Serialize the callable into a dict for transport.
386
-
387
- Args:
388
- prefer: Preferred serialization kind.
389
- dump_env: Environment payload selection.
390
- filter_used_globals: Filter globals to referenced names.
391
- env_keys: environment keys
392
- env_variables: environment key values
393
-
394
- Returns:
395
- Serialized payload dict.
396
- """
397
383
  kind = prefer
398
384
  if kind == "import" and not self.importable:
399
385
  kind = "dill"
@@ -420,7 +406,6 @@ class CallableSerde:
420
406
  if env_keys:
421
407
  for env_key in env_keys:
422
408
  existing = os.getenv(env_key)
423
-
424
409
  if existing:
425
410
  env_variables[env_key] = existing
426
411
 
@@ -432,6 +417,7 @@ class CallableSerde:
432
417
  raise ValueError("dump_env requested but fn is not present.")
433
418
  include_globals = dump_env in ("globals", "both")
434
419
  include_closure = dump_env in ("closure", "both")
420
+
435
421
  env, meta = _dump_env(
436
422
  self.fn,
437
423
  include_globals=include_globals,
@@ -448,15 +434,6 @@ class CallableSerde:
448
434
 
449
435
  @classmethod
450
436
  def load(cls: type[T], d: Dict[str, Any], *, add_pkg_root_to_syspath: bool = True) -> T:
451
- """Construct a CallableSerde from a serialized dict payload.
452
-
453
- Args:
454
- d: Serialized payload dict.
455
- add_pkg_root_to_syspath: Add package root to sys.path if True.
456
-
457
- Returns:
458
- CallableSerde instance.
459
- """
460
437
  obj = cls(
461
438
  fn=None,
462
439
  _kind=d.get("kind", "auto"),
@@ -474,14 +451,6 @@ class CallableSerde:
474
451
  return obj # type: ignore[return-value]
475
452
 
476
453
  def materialize(self, *, add_pkg_root_to_syspath: bool = True) -> Callable[..., Any]:
477
- """Resolve and return the underlying callable.
478
-
479
- Args:
480
- add_pkg_root_to_syspath: Add package root to sys.path if True.
481
-
482
- Returns:
483
- Resolved callable.
484
- """
485
454
  if self.fn is not None:
486
455
  return self.fn
487
456
 
@@ -515,19 +484,12 @@ class CallableSerde:
515
484
  raise ValueError(f"Unknown kind: {kind}")
516
485
 
517
486
  def __call__(self, *args: Any, **kwargs: Any) -> Any:
518
- """Invoke the materialized callable with the provided arguments.
519
-
520
- Args:
521
- *args: Positional args for the callable.
522
- **kwargs: Keyword args for the callable.
523
-
524
- Returns:
525
- Callable return value.
526
- """
527
487
  fn = self.materialize()
528
488
  return fn(*args, **kwargs)
529
489
 
530
- # ----- command execution bridge -----
490
+ # -------------------------
491
+ # Command execution bridge
492
+ # -------------------------
531
493
 
532
494
  def to_command(
533
495
  self,
@@ -536,24 +498,21 @@ class CallableSerde:
536
498
  *,
537
499
  result_tag: str = "__CALLABLE_SERDE_RESULT__",
538
500
  prefer: str = "dill",
539
- byte_limit: int = 256_000,
540
- dump_env: str = "none", # "none" | "globals" | "closure" | "both"
501
+ byte_limit: int = 64 * 1024,
502
+ dump_env: str = "none", # "none" | "globals" | "closure" | "both"
541
503
  filter_used_globals: bool = True,
542
504
  env_keys: Optional[Iterable[str]] = None,
543
505
  env_variables: Optional[Dict[str, str]] = None,
506
+ file_dump_limit: int = 512 * 1024,
507
+ transaction_id: Optional[str] = None
544
508
  ) -> str:
545
509
  """
546
510
  Returns Python code string to execute in another interpreter.
547
- Prints one line: "{result_tag}:{base64(blob)}"
548
- where blob is raw dill bytes or framed+zlib.
549
-
550
- Also compresses the input call payload (args/kwargs) using the same framing
551
- scheme when it exceeds byte_limit.
511
+ Emits exactly one line to stdout:
512
+ "{result_tag}:{base64(blob)}\\n"
513
+ where blob is raw dill bytes or CS2 framed.
552
514
  """
553
- import base64
554
515
  import json
555
- import struct
556
- import zlib
557
516
 
558
517
  args = args or ()
559
518
  kwargs = kwargs or {}
@@ -567,75 +526,31 @@ class CallableSerde:
567
526
  )
568
527
  serde_json = json.dumps(serde_dict, ensure_ascii=False)
569
528
 
570
- # --- input payload compression (args/kwargs) ---
571
- MAGIC = b"CS1"
572
- FLAG_COMPRESSED = 1
573
-
574
- def _pick_level(n: int, limit: int) -> int:
575
- ratio = n / max(1, limit)
576
- x = min(1.0, max(0.0, (ratio - 1.0) / 3.0))
577
- return max(1, min(9, int(round(1 + 8 * x))))
578
-
579
- def _encode_blob(raw: bytes, limit: int) -> bytes:
580
- if len(raw) <= limit:
581
- return raw
582
- level = _pick_level(len(raw), limit)
583
- compressed = zlib.compress(raw, level)
584
- if len(compressed) >= len(raw):
585
- return raw
586
- header = MAGIC + struct.pack(">BIB", FLAG_COMPRESSED, len(raw), level)
587
- return header + compressed
588
-
529
+ # args/kwargs payload: stdlib-only compression (lzma/zlib)
589
530
  call_raw = dill.dumps((args, kwargs), recurse=True)
590
- call_blob = _encode_blob(call_raw, int(byte_limit))
531
+ call_blob = _encode_wire_blob_stdlib(call_raw, int(byte_limit))
591
532
  call_payload_b64 = base64.b64encode(call_blob).decode("ascii")
533
+ transaction_id = transaction_id or secrets.token_urlsafe(16)
592
534
 
593
- # NOTE: plain string template + replace. No f-string. No brace escaping.
594
535
  template = r"""
595
- import base64, json, sys, struct, zlib, importlib, dis, os
536
+ import base64, json, os, sys
596
537
  import dill
538
+ import pandas
539
+
540
+ from yggdrasil.databricks import Workspace
541
+ from yggdrasil.pyutils.callable_serde import (
542
+ CallableSerde,
543
+ _decode_result_blob,
544
+ _encode_result_blob,
545
+ )
597
546
 
598
547
  RESULT_TAG = __RESULT_TAG__
599
548
  BYTE_LIMIT = __BYTE_LIMIT__
600
-
601
- MAGIC = b"CS1"
602
- FLAG_COMPRESSED = 1
603
-
604
- def _resolve_attr_chain(mod, qualname: str):
605
- obj = mod
606
- for part in qualname.split("."):
607
- obj = getattr(obj, part)
608
- return obj
609
-
610
- def _pick_level(n: int, limit: int) -> int:
611
- ratio = n / max(1, limit)
612
- x = min(1.0, max(0.0, (ratio - 1.0) / 3.0))
613
- return max(1, min(9, int(round(1 + 8 * x))))
614
-
615
- def _encode_result(raw: bytes, byte_limit: int) -> bytes:
616
- if len(raw) <= byte_limit:
617
- return raw
618
- level = _pick_level(len(raw), byte_limit)
619
- compressed = zlib.compress(raw, level)
620
- if len(compressed) >= len(raw):
621
- return raw
622
- header = MAGIC + struct.pack(">BIB", FLAG_COMPRESSED, len(raw), level)
623
- return header + compressed
624
-
625
- def _decode_blob(blob: bytes) -> bytes:
626
- # If it's framed (MAGIC + header), decompress; else return as-is.
627
- if isinstance(blob, (bytes, bytearray)) and len(blob) >= 3 and blob[:3] == MAGIC:
628
- if len(blob) >= 3 + 6:
629
- flag, orig_len, level = struct.unpack(">BIB", blob[3:3+6])
630
- if flag & FLAG_COMPRESSED:
631
- raw = zlib.decompress(blob[3+6:])
632
- # best-effort sanity check; don't hard-fail on mismatch
633
- if isinstance(orig_len, int) and orig_len > 0 and len(raw) != orig_len:
634
- return raw
635
- return raw
636
- return blob
549
+ FILE_DUMP_LIMIT = __FILE_DUMP_LIMIT__
550
+ TRANSACTION_ID = __TRANSACTION_ID__
637
551
 
638
552
  def _needed_globals(fn) -> set[str]:
553
+ import dis
639
554
  names = set()
640
555
  try:
641
556
  for ins in dis.get_instructions(fn):
@@ -657,34 +572,22 @@ def _apply_env(fn, env: dict, filter_used: bool):
657
572
  return
658
573
 
659
574
  env_g = env.get("globals") or {}
660
- if env_g:
661
- if filter_used:
662
- needed = _needed_globals(fn)
663
- for name in needed:
664
- if name in env_g:
665
- g.setdefault(name, env_g[name])
666
- else:
667
- for name, val in env_g.items():
668
- g.setdefault(name, val)
575
+ if not env_g:
576
+ return
669
577
 
670
- serde = json.loads(__SERDE_JSON__)
578
+ if filter_used:
579
+ needed = _needed_globals(fn)
580
+ for name in needed:
581
+ if name in env_g:
582
+ g.setdefault(name, env_g[name])
583
+ else:
584
+ for name, val in env_g.items():
585
+ g.setdefault(name, val)
671
586
 
672
- pkg_root = serde.get("pkg_root")
673
- if pkg_root and pkg_root not in sys.path:
674
- sys.path.insert(0, pkg_root)
587
+ serde = json.loads(__SERDE_JSON__)
675
588
 
676
- kind = serde.get("kind")
677
- if kind == "import":
678
- mod = importlib.import_module(serde["module"])
679
- fn = _resolve_attr_chain(mod, serde["qualname"])
680
- elif kind == "dill":
681
- fn = dill.loads(base64.b64decode(serde["dill_b64"]))
682
- else:
683
- if serde.get("module") and serde.get("qualname") and "<locals>" not in serde.get("qualname", ""):
684
- mod = importlib.import_module(serde["module"])
685
- fn = _resolve_attr_chain(mod, serde["qualname"])
686
- else:
687
- fn = dill.loads(base64.b64decode(serde["dill_b64"]))
589
+ cs = CallableSerde.load(serde, add_pkg_root_to_syspath=True)
590
+ fn = cs.materialize(add_pkg_root_to_syspath=True)
688
591
 
689
592
  osenv = serde.get("osenv")
690
593
  if osenv:
@@ -698,42 +601,127 @@ if env_b64:
698
601
  _apply_env(fn, env, bool(meta.get("filter_used_globals", True)))
699
602
 
700
603
  call_blob = base64.b64decode(__CALL_PAYLOAD_B64__)
701
- call_raw = _decode_blob(call_blob)
604
+ call_raw = _decode_result_blob(call_blob)
702
605
  args, kwargs = dill.loads(call_raw)
703
606
 
704
607
  res = fn(*args, **kwargs)
705
- raw = dill.dumps(res, recurse=True)
706
- blob = _encode_result(raw, BYTE_LIMIT)
707
608
 
708
- # No f-string. No braces. No drama.
709
- sys.stdout.write(str(RESULT_TAG) + ":" + base64.b64encode(blob).decode("ascii") + "\n")
710
- """.strip()
609
+ if isinstance(res, pandas.DataFrame):
610
+ dump_path = Workspace().shared_cache_path("/cmd/" + TRANSACTION_ID + ".parquet")
611
+
612
+ with dump_path.open(mode="wb") as f:
613
+ res.to_parquet(f)
614
+
615
+ blob = "DBXPATH:" + str(dump_path)
616
+ else:
617
+ raw = dill.dumps(res)
618
+ blob = _encode_result_blob(raw, BYTE_LIMIT)
619
+
620
+ if len(blob) > FILE_DUMP_LIMIT:
621
+ dump_path = Workspace().shared_cache_path("/cmd/" + TRANSACTION_ID)
622
+
623
+ with dump_path.open(mode="wb") as f:
624
+ f.write_all_bytes(data=blob)
625
+
626
+ blob = "DBXPATH:" + str(dump_path)
627
+ else:
628
+ blob = base64.b64encode(blob).decode('ascii')
711
629
 
712
- code = (
630
+ sys.stdout.write(f"{RESULT_TAG}:{len(blob)}:{blob}\n")
631
+ sys.stdout.flush()
632
+ """
633
+
634
+ return (
713
635
  template
714
636
  .replace("__RESULT_TAG__", repr(result_tag))
715
637
  .replace("__BYTE_LIMIT__", str(int(byte_limit)))
716
638
  .replace("__SERDE_JSON__", repr(serde_json))
717
639
  .replace("__CALL_PAYLOAD_B64__", repr(call_payload_b64))
640
+ .replace("__FILE_DUMP_LIMIT__", str(int(file_dump_limit)))
641
+ .replace("__TRANSACTION_ID__", repr(str(transaction_id)))
718
642
  )
719
643
 
720
- return code
721
-
722
644
  @staticmethod
723
- def parse_command_result(output: str, *, result_tag: str = "__CALLABLE_SERDE_RESULT__") -> Any:
645
+ def parse_command_result(
646
+ output: str,
647
+ *,
648
+ result_tag: str = "__CALLABLE_SERDE_RESULT__",
649
+ workspace: Optional["Workspace"] = None
650
+ ) -> Any:
724
651
  """
725
- Parse stdout/stderr combined text, find last "{result_tag}:{b64}" line.
726
- Supports raw dill or framed+zlib compressed payloads.
652
+ Expect last tagged line:
653
+ "{result_tag}:{blob_nbytes}:{b64}"
654
+
655
+ We use blob_nbytes to compute expected base64 char length and detect truncation
656
+ before decoding/decompressing.
727
657
  """
728
658
  prefix = f"{result_tag}:"
729
- b64 = None
730
- for line in reversed(output.splitlines()):
731
- if line.startswith(prefix):
732
- b64 = line[len(prefix):].strip()
733
- break
734
- if not b64:
735
- raise ValueError(f"Result tag not found in output: {result_tag!r}")
736
-
737
- blob = base64.b64decode(b64.encode("ascii"))
659
+ if prefix not in output:
660
+ raise ValueError(f"Result tag not found in output: {result_tag}")
661
+
662
+ # Grab everything after the LAST occurrence of the tag
663
+ _, tail = output.rsplit(prefix, 1)
664
+
665
+ # Parse "{nbytes}:{b64}"
666
+ try:
667
+ nbytes_str, string_result = tail.split(":", 1)
668
+ except ValueError as e:
669
+ raise ValueError(
670
+ f"Malformed result line after tag {result_tag}. "
671
+ "Expected '{tag}:{nbytes}:{b64}'."
672
+ ) from e
673
+
674
+ try:
675
+ content_length = int(nbytes_str)
676
+ except ValueError as e:
677
+ raise ValueError(f"Malformed byte count '{nbytes_str}' after tag {result_tag}") from e
678
+
679
+ if content_length < 0:
680
+ raise ValueError(f"Negative byte count {content_length} after tag {result_tag}")
681
+
682
+ string_result = string_result[:content_length]
683
+
684
+ if len(string_result) != content_length:
685
+ raise ValueError(
686
+ "Got truncated result content from command, got %s bytes and expected %s bytes" % (
687
+ len(string_result),
688
+ content_length
689
+ )
690
+ )
691
+
692
+ if string_result.startswith("DBXPATH:"):
693
+ from ..databricks.workspaces import Workspace
694
+
695
+ workspace = Workspace() if workspace is None else workspace
696
+ path = workspace.dbfs_path(
697
+ string_result.replace("DBXPATH:", "")
698
+ )
699
+
700
+ if path.name.endswith(".parquet"):
701
+ import pandas
702
+
703
+ with path.open(mode="rb") as f:
704
+ buf = io.BytesIO(f.read_all_bytes())
705
+
706
+ path.rmfile()
707
+ buf.seek(0)
708
+ return pandas.read_parquet(buf)
709
+
710
+ with path.open(mode="rb") as f:
711
+ blob = f.read_all_bytes()
712
+
713
+ path.rmfile()
714
+ else:
715
+ # Strict base64 decode (rejects junk chars)
716
+ try:
717
+ blob = base64.b64decode(string_result.encode("ascii"), validate=True)
718
+ except (UnicodeEncodeError, binascii.Error) as e:
719
+ raise ValueError("Invalid base64 payload after result tag (corrupted/contaminated).") from e
720
+
738
721
  raw = _decode_result_blob(blob)
739
- return dill.loads(raw)
722
+ try:
723
+ result = dill.loads(raw)
724
+ except Exception as e:
725
+ raise ValueError("Failed to dill.loads decoded payload") from e
726
+
727
+ return result