ygg 0.1.56__py3-none-any.whl → 0.1.60__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (42) hide show
  1. {ygg-0.1.56.dist-info → ygg-0.1.60.dist-info}/METADATA +1 -1
  2. ygg-0.1.60.dist-info/RECORD +74 -0
  3. {ygg-0.1.56.dist-info → ygg-0.1.60.dist-info}/WHEEL +1 -1
  4. yggdrasil/ai/__init__.py +2 -0
  5. yggdrasil/ai/session.py +89 -0
  6. yggdrasil/ai/sql_session.py +310 -0
  7. yggdrasil/databricks/__init__.py +0 -3
  8. yggdrasil/databricks/compute/cluster.py +68 -113
  9. yggdrasil/databricks/compute/command_execution.py +674 -0
  10. yggdrasil/databricks/compute/exceptions.py +7 -2
  11. yggdrasil/databricks/compute/execution_context.py +465 -277
  12. yggdrasil/databricks/compute/remote.py +4 -14
  13. yggdrasil/databricks/exceptions.py +10 -0
  14. yggdrasil/databricks/sql/__init__.py +0 -4
  15. yggdrasil/databricks/sql/engine.py +161 -173
  16. yggdrasil/databricks/sql/exceptions.py +9 -1
  17. yggdrasil/databricks/sql/statement_result.py +108 -120
  18. yggdrasil/databricks/sql/warehouse.py +331 -92
  19. yggdrasil/databricks/workspaces/io.py +92 -9
  20. yggdrasil/databricks/workspaces/path.py +120 -74
  21. yggdrasil/databricks/workspaces/workspace.py +212 -68
  22. yggdrasil/libs/databrickslib.py +23 -18
  23. yggdrasil/libs/extensions/spark_extensions.py +1 -1
  24. yggdrasil/libs/pandaslib.py +15 -6
  25. yggdrasil/libs/polarslib.py +49 -13
  26. yggdrasil/pyutils/__init__.py +1 -0
  27. yggdrasil/pyutils/callable_serde.py +12 -19
  28. yggdrasil/pyutils/exceptions.py +16 -0
  29. yggdrasil/pyutils/mimetypes.py +0 -0
  30. yggdrasil/pyutils/python_env.py +13 -12
  31. yggdrasil/pyutils/waiting_config.py +171 -0
  32. yggdrasil/types/cast/arrow_cast.py +3 -0
  33. yggdrasil/types/cast/pandas_cast.py +157 -169
  34. yggdrasil/types/cast/polars_cast.py +11 -43
  35. yggdrasil/types/dummy_class.py +81 -0
  36. yggdrasil/version.py +1 -1
  37. ygg-0.1.56.dist-info/RECORD +0 -68
  38. yggdrasil/databricks/ai/__init__.py +0 -1
  39. yggdrasil/databricks/ai/loki.py +0 -374
  40. {ygg-0.1.56.dist-info → ygg-0.1.60.dist-info}/entry_points.txt +0 -0
  41. {ygg-0.1.56.dist-info → ygg-0.1.60.dist-info}/licenses/LICENSE +0 -0
  42. {ygg-0.1.56.dist-info → ygg-0.1.60.dist-info}/top_level.txt +0 -0
@@ -2,7 +2,6 @@
2
2
 
3
3
  import base64
4
4
  import dataclasses as dc
5
- import datetime as dt
6
5
  import io
7
6
  import json
8
7
  import logging
@@ -12,29 +11,57 @@ import re
12
11
  import sys
13
12
  import threading
14
13
  import zipfile
14
+ from concurrent.futures import ThreadPoolExecutor
15
15
  from threading import Thread
16
16
  from types import ModuleType
17
17
  from typing import TYPE_CHECKING, Optional, Any, Callable, List, Dict, Union, Iterable, Tuple
18
18
 
19
- from .exceptions import CommandAborted
20
- from ...libs.databrickslib import databricks_sdk
19
+ from .command_execution import CommandExecution
20
+ from .exceptions import ClientTerminatedSession
21
+ from ...libs.databrickslib import databricks_sdk, DatabricksDummyClass
22
+ from ...pyutils.callable_serde import CallableSerde
21
23
  from ...pyutils.exceptions import raise_parsed_traceback
22
24
  from ...pyutils.expiring_dict import ExpiringDict
23
25
  from ...pyutils.modules import resolve_local_lib_path
24
- from ...pyutils.callable_serde import CallableSerde
26
+ from ...pyutils.waiting_config import WaitingConfig, WaitingConfigArg
25
27
 
26
28
  if TYPE_CHECKING:
27
29
  from .cluster import Cluster
28
30
 
29
31
  if databricks_sdk is not None:
30
- from databricks.sdk.service.compute import Language, ResultType
32
+ from databricks.sdk.service.compute import Language, ResultType, CommandStatusResponse
33
+ else:
34
+ Language = DatabricksDummyClass
35
+ ResultType = DatabricksDummyClass
36
+ CommandStatusResponse = DatabricksDummyClass
37
+
31
38
 
32
39
  __all__ = [
33
40
  "ExecutionContext"
34
41
  ]
35
42
 
36
43
  LOGGER = logging.getLogger(__name__)
44
+ UPLOADED_PACKAGE_ROOTS: Dict[str, ExpiringDict] = {}
45
+ BytesLike = Union[bytes, bytearray, memoryview]
46
+
47
+ @dc.dataclass(frozen=True)
48
+ class BytesSource:
49
+ """
50
+ Hashable wrapper for in-memory content so it can be used as a dict key.
37
51
 
52
+ name: only used for debugging / metadata (not required to match remote basename)
53
+ data: bytes-like payload
54
+ """
55
+ name: str
56
+ data: bytes
57
+
58
+ LocalSpec = Union[
59
+ str,
60
+ os.PathLike,
61
+ bytes, # raw bytes as key (works, but no name)
62
+ BytesSource, # recommended for buffers
63
+ Tuple[str, BytesLike], # (name, data) helper
64
+ ]
38
65
 
39
66
  @dc.dataclass
40
67
  class RemoteMetadata:
@@ -42,6 +69,7 @@ class RemoteMetadata:
42
69
  site_packages_path: Optional[str] = dc.field(default=None)
43
70
  os_env: Dict[str, str] = dc.field(default_factory=dict)
44
71
  version_info: Tuple[int, int, int] = dc.field(default=(0, 0, 0))
72
+ temp_path: str = ""
45
73
 
46
74
  def os_env_diff(
47
75
  self,
@@ -77,14 +105,14 @@ class ExecutionContext:
77
105
  ctx.execute("print(x + 1)")
78
106
  """
79
107
  cluster: "Cluster"
80
- language: Optional["Language"] = None
81
108
  context_id: Optional[str] = None
82
109
 
83
- _was_connected: Optional[bool] = dc.field(default=None, repr=False)
84
- _remote_metadata: Optional[RemoteMetadata] = dc.field(default=None, repr=False)
85
- _uploaded_package_roots: Optional[ExpiringDict] = dc.field(default_factory=ExpiringDict, repr=False)
110
+ language: Optional[Language] = dc.field(default=None, repr=False, compare=False, hash=False)
86
111
 
87
- _lock: threading.RLock = dc.field(default_factory=threading.RLock, init=False, repr=False)
112
+ _was_connected: Optional[bool] = dc.field(default=None, repr=False, compare=False, hash=False)
113
+ _remote_metadata: Optional[RemoteMetadata] = dc.field(default=None, repr=False, compare=False, hash=False)
114
+ _uploaded_package_roots: Optional[ExpiringDict] = dc.field(default_factory=ExpiringDict, repr=False, compare=False, hash=False)
115
+ _lock: threading.RLock = dc.field(default_factory=threading.RLock, init=False, repr=False, compare=False, hash=False)
88
116
 
89
117
  # --- Pickle / cloudpickle support (don’t serialize locks or cached remote metadata) ---
90
118
  def __getstate__(self):
@@ -114,9 +142,28 @@ class ExecutionContext:
114
142
  self.close(wait=False)
115
143
  self.cluster.__exit__(exc_type, exc_val=exc_val, exc_tb=exc_tb)
116
144
 
117
- def __del__(self):
118
- """Best-effort cleanup for the remote execution context."""
119
- self.close(wait=False)
145
+ def __repr__(self):
146
+ return "%s(url=%s)" % (
147
+ self.__class__.__name__,
148
+ self.url()
149
+ )
150
+
151
+ def __str__(self):
152
+ return self.url()
153
+
154
+ def url(self) -> str:
155
+ return "%s/context/%s" % (
156
+ self.cluster.url(),
157
+ self.context_id or "unknown"
158
+ )
159
+
160
+ @property
161
+ def workspace(self):
162
+ return self.cluster.workspace
163
+
164
+ @property
165
+ def cluster_id(self):
166
+ return self.cluster.cluster_id
120
167
 
121
168
  @property
122
169
  def remote_metadata(self) -> RemoteMetadata:
@@ -130,46 +177,43 @@ class ExecutionContext:
130
177
  with self._lock:
131
178
  # double-check after acquiring lock
132
179
  if self._remote_metadata is None:
133
- cmd = r"""import glob, json, os
180
+ cmd = r"""import glob, json, os, tempfile
134
181
  from yggdrasil.pyutils.python_env import PythonEnv
135
182
 
136
183
  current_env = PythonEnv.get_current()
137
184
  meta = {}
138
185
 
186
+ # temp dir (explicit + stable for downstream code)
187
+ tmp_dir = tempfile.mkdtemp(prefix="tmp_")
188
+ meta["temp_path"] = tmp_dir
189
+ os.environ["TMPDIR"] = tmp_dir # many libs respect this
190
+
191
+ # find site-packages
139
192
  for path in glob.glob('/local_**/.ephemeral_nfs/cluster_libraries/python/lib/python*/site-*', recursive=False):
140
193
  if path.endswith('site-packages'):
141
194
  meta["site_packages_path"] = path
142
195
  break
143
196
 
197
+ # env vars snapshot
144
198
  os_env = meta["os_env"] = {}
145
199
  for k, v in os.environ.items():
146
200
  os_env[k] = v
147
-
201
+
148
202
  meta["version_info"] = current_env.version_info
149
203
 
150
204
  print(json.dumps(meta))"""
151
205
 
152
- try:
153
- content = self.execute_command(
154
- command=cmd,
155
- result_tag="<<RESULT>>",
156
- print_stdout=False,
157
- )
158
- except ImportError:
159
- self.cluster.ensure_running()
160
-
161
- content = self.execute_command(
162
- command=cmd,
163
- result_tag="<<RESULT>>",
164
- print_stdout=False,
165
- )
166
-
167
- self._remote_metadata = RemoteMetadata(**json.loads(content))
206
+ content = self.command(
207
+ command=cmd,
208
+ language=Language.PYTHON,
209
+ ).start().wait().result(unpickle=True)
210
+
211
+ self._remote_metadata = RemoteMetadata(**content)
168
212
 
169
213
  return self._remote_metadata
170
214
 
171
215
  # ------------ internal helpers ------------
172
- def _workspace_client(self):
216
+ def workspace_client(self):
173
217
  """Return the Databricks SDK client for command execution.
174
218
 
175
219
  Returns:
@@ -177,14 +221,26 @@ print(json.dumps(meta))"""
177
221
  """
178
222
  return self.cluster.workspace.sdk()
179
223
 
224
+ def shared_cache_path(
225
+ self,
226
+ suffix: str
227
+ ):
228
+ assert suffix, "Missing suffix arg"
229
+
230
+ return self.cluster.shared_cache_path(
231
+ suffix="/context/%s" % suffix.lstrip("/")
232
+ )
233
+
180
234
  def create(
181
235
  self,
182
236
  language: "Language",
183
- ) -> any:
237
+ wait: Optional[WaitingConfigArg] = True,
238
+ ) -> "ExecutionContext":
184
239
  """Create a command execution context, retrying if needed.
185
240
 
186
241
  Args:
187
242
  language: The Databricks command language to use.
243
+ wait: Waiting config to update
188
244
 
189
245
  Returns:
190
246
  The created command execution context response.
@@ -194,41 +250,55 @@ print(json.dumps(meta))"""
194
250
  self.cluster
195
251
  )
196
252
 
197
- client = self._workspace_client().command_execution
253
+ client = self.workspace_client().command_execution
198
254
 
199
255
  try:
200
- created = client.create_and_wait(
201
- cluster_id=self.cluster.cluster_id,
202
- language=language,
203
- )
204
- except:
205
- self.cluster.ensure_running()
256
+ with ThreadPoolExecutor(max_workers=1) as ex:
257
+ fut = ex.submit(
258
+ client.create,
259
+ cluster_id=self.cluster_id,
260
+ language=language,
261
+ )
262
+
263
+ try:
264
+ created = fut.result(timeout=10).response
265
+ except TimeoutError:
266
+ self.cluster.ensure_running(wait=True)
206
267
 
207
- created = client.create_and_wait(
208
- cluster_id=self.cluster.cluster_id,
268
+ created = client.create(
269
+ cluster_id=self.cluster_id,
270
+ language=language,
271
+ ).response
272
+ except Exception as e:
273
+ LOGGER.warning(e)
274
+
275
+ self.cluster.ensure_running(wait=True)
276
+
277
+ created = client.create(
278
+ cluster_id=self.cluster_id,
209
279
  language=language,
210
- )
280
+ ).response
211
281
 
212
282
  LOGGER.info(
213
- "Created Databricks command execution context %s",
283
+ "Created %s",
214
284
  self
215
285
  )
216
286
 
217
- created = getattr(created, "response", created)
218
-
219
287
  self.context_id = created.id
220
288
 
221
289
  return self
222
290
 
223
291
  def connect(
224
292
  self,
225
- language: Optional["Language"] = None,
226
- reset: bool = False
293
+ language: Optional[Language] = None,
294
+ wait: Optional[WaitingConfigArg] = True,
295
+ reset: bool = False,
227
296
  ) -> "ExecutionContext":
228
297
  """Create a remote command execution context if not already open.
229
298
 
230
299
  Args:
231
300
  language: Optional language override for the context.
301
+ wait: Wait config
232
302
  reset: Reset existing if connected
233
303
 
234
304
  Returns:
@@ -238,6 +308,11 @@ print(json.dumps(meta))"""
238
308
  if not reset:
239
309
  return self
240
310
 
311
+ LOGGER.info(
312
+ "%s reset connection",
313
+ self
314
+ )
315
+
241
316
  self.close(wait=False)
242
317
 
243
318
  language = language or self.language
@@ -245,7 +320,10 @@ print(json.dumps(meta))"""
245
320
  if language is None:
246
321
  language = Language.PYTHON
247
322
 
248
- return self.create(language=language)
323
+ return self.create(
324
+ language=language,
325
+ wait=wait
326
+ )
249
327
 
250
328
  def close(self, wait: bool = True) -> None:
251
329
  """Destroy the remote command execution context if it exists.
@@ -256,7 +334,7 @@ print(json.dumps(meta))"""
256
334
  if not self.context_id:
257
335
  return
258
336
 
259
- client = self._workspace_client()
337
+ client = self.workspace_client()
260
338
 
261
339
  try:
262
340
  if wait:
@@ -279,6 +357,68 @@ print(json.dumps(meta))"""
279
357
  self.context_id = None
280
358
 
281
359
  # ------------ public API ------------
360
+ def command(
361
+ self,
362
+ context: Optional["ExecutionContext"] = None,
363
+ func: Optional[Callable] = None,
364
+ command_id: Optional[str] = None,
365
+ command: Optional[str] = None,
366
+ language: Optional[Language] = None,
367
+ environ: Optional[Dict[str, str]] = None,
368
+ ):
369
+ context = self if context is None else context
370
+
371
+ return CommandExecution(
372
+ context=context,
373
+ command_id=command_id,
374
+ language=language,
375
+ command=command,
376
+ ).create(
377
+ context=context,
378
+ language=language,
379
+ command=command,
380
+ func=func,
381
+ environ=environ
382
+ )
383
+
384
+ def decorate(
385
+ self,
386
+ func: Optional[Callable] = None,
387
+ command: Optional[str] = None,
388
+ language: Optional[Language] = None,
389
+ command_id: Optional[str] = None,
390
+ environ: Optional[Union[Iterable[str], Dict[str, str]]] = None,
391
+ ) -> Callable:
392
+ language = Language.PYTHON if language is None else language
393
+
394
+ def decorator(
395
+ f: Callable,
396
+ c: ExecutionContext = self,
397
+ cmd: Optional[str] = command,
398
+ l: Optional[Language] = language,
399
+ cid: Optional[str] = command_id,
400
+ env: Optional[Union[Iterable[str], Dict[str, str]]] = environ,
401
+ ):
402
+ if c.is_in_databricks_environment():
403
+ return func
404
+
405
+ c.cluster.ensure_running(
406
+ wait=False
407
+ )
408
+
409
+ return c.command(
410
+ context=c,
411
+ func=f,
412
+ command_id=cid,
413
+ command=cmd,
414
+ language=l,
415
+ environ=env
416
+ )
417
+
418
+ if func is not None and callable(func):
419
+ return decorator(f=func)
420
+ return decorator
421
+
282
422
  def execute(
283
423
  self,
284
424
  obj: Union[str, Callable],
@@ -287,9 +427,7 @@ print(json.dumps(meta))"""
287
427
  kwargs: Dict[str, Any] = None,
288
428
  env_keys: Optional[List[str]] = None,
289
429
  env_variables: Optional[dict[str, str]] = None,
290
- timeout: Optional[dt.timedelta] = None,
291
- result_tag: Optional[str] = None,
292
- **options
430
+ timeout: Optional[WaitingConfigArg] = True,
293
431
  ):
294
432
  """Execute a string command or a callable in the remote context.
295
433
 
@@ -300,8 +438,6 @@ print(json.dumps(meta))"""
300
438
  env_keys: Environment variable names to forward.
301
439
  env_variables: Environment variables to inject remotely.
302
440
  timeout: Optional timeout for execution.
303
- result_tag: Optional result tag for parsing output.
304
- **options: Additional execution options.
305
441
 
306
442
  Returns:
307
443
  The decoded execution result.
@@ -309,9 +445,7 @@ print(json.dumps(meta))"""
309
445
  if isinstance(obj, str):
310
446
  return self.execute_command(
311
447
  command=obj,
312
- timeout=timeout,
313
- result_tag=result_tag,
314
- **options
448
+ wait=timeout,
315
449
  )
316
450
  elif callable(obj):
317
451
  return self.execute_callable(
@@ -321,9 +455,9 @@ print(json.dumps(meta))"""
321
455
  env_keys=env_keys,
322
456
  env_variables=env_variables,
323
457
  timeout=timeout,
324
- **options
325
458
  )
326
- raise ValueError(f"Cannot execute {type(obj)}")
459
+ else:
460
+ raise ValueError(f"Cannot execute {type(obj)}")
327
461
 
328
462
  def is_in_databricks_environment(self):
329
463
  """Return True when running on a Databricks runtime."""
@@ -336,8 +470,7 @@ print(json.dumps(meta))"""
336
470
  kwargs: Dict[str, Any] = None,
337
471
  env_keys: Optional[Iterable[str]] = None,
338
472
  env_variables: Optional[Dict[str, str]] = None,
339
- print_stdout: Optional[bool] = True,
340
- timeout: Optional[dt.timedelta] = None,
473
+ timeout: Optional[WaitingConfigArg] = True,
341
474
  command: Optional[str] = None,
342
475
  ) -> Any:
343
476
  """Execute a Python callable remotely and return the decoded result.
@@ -348,7 +481,6 @@ print(json.dumps(meta))"""
348
481
  kwargs: Keyword arguments for the callable.
349
482
  env_keys: Environment variable names to forward.
350
483
  env_variables: Environment variables to inject remotely.
351
- print_stdout: Whether to print stdout from the command output.
352
484
  timeout: Optional timeout for execution.
353
485
  command: Optional prebuilt command string override.
354
486
 
@@ -395,48 +527,14 @@ print(json.dumps(meta))"""
395
527
 
396
528
  raw_result = self.execute_command(
397
529
  command,
398
- timeout=timeout, result_tag=result_tag, print_stdout=print_stdout
530
+ wait=timeout
399
531
  )
400
532
 
401
- try:
402
- result = serialized.parse_command_result(
403
- raw_result,
404
- result_tag=result_tag,
405
- workspace=self.cluster.workspace
406
- )
407
- except ModuleNotFoundError as remote_module_error:
408
- _MOD_NOT_FOUND_RE = re.compile(r"No module named ['\"]([^'\"]+)['\"]")
409
- module_name = _MOD_NOT_FOUND_RE.search(str(remote_module_error))
410
- module_name = module_name.group(1) if module_name else None
411
- module_name = module_name.split(".")[0]
412
-
413
- if module_name and "yggdrasil" not in module_name:
414
- LOGGER.debug(
415
- "Installing missing module %s from local environment",
416
- module_name,
417
- )
418
-
419
- self.install_temporary_libraries(
420
- libraries=[module_name],
421
- )
422
-
423
- LOGGER.warning(
424
- "Installed missing module %s from local environment",
425
- module_name,
426
- )
427
-
428
- return self.execute_callable(
429
- func=func,
430
- args=args,
431
- kwargs=kwargs,
432
- env_keys=env_keys,
433
- env_variables=env_variables,
434
- print_stdout=print_stdout,
435
- timeout=timeout,
436
- command=command,
437
- )
438
-
439
- raise remote_module_error
533
+ result = serialized.parse_command_result(
534
+ raw_result,
535
+ result_tag=result_tag,
536
+ workspace=self.cluster.workspace
537
+ )
440
538
 
441
539
  return result
442
540
 
@@ -444,163 +542,256 @@ print(json.dumps(meta))"""
444
542
  self,
445
543
  command: str,
446
544
  *,
447
- timeout: Optional[dt.timedelta] = dt.timedelta(minutes=20),
448
- result_tag: Optional[str] = None,
449
- print_stdout: Optional[bool] = True,
545
+ language: Optional[Language] = None,
546
+ wait: Optional[WaitingConfigArg] = True,
547
+ raise_error: bool = True
450
548
  ) -> str:
451
549
  """Execute a command in this context and return decoded output.
452
550
 
453
551
  Args:
454
552
  command: The command string to execute.
455
- timeout: Optional timeout for execution.
456
- result_tag: Optional tag to extract a specific result segment.
457
- print_stdout: Whether to print stdout for tagged output.
553
+ language: Language to execute
554
+ wait: Optional timeout for execution.
555
+ raise_error: Raises error if failed
458
556
 
459
557
  Returns:
460
558
  The decoded command output string.
461
559
  """
462
- self.connect()
463
-
464
- client = self._workspace_client()
465
- result = client.command_execution.execute_and_wait(
466
- cluster_id=self.cluster.cluster_id,
467
- context_id=self.context_id,
468
- language=self.language,
560
+ return self.command(
469
561
  command=command,
470
- timeout=timeout or dt.timedelta(minutes=20)
471
- )
472
-
473
- try:
474
- return self._decode_result(
475
- result,
476
- result_tag=result_tag,
477
- print_stdout=print_stdout
478
- )
479
- except CommandAborted:
480
- return self.connect(language=self.language, reset=True).execute_command(
481
- command=command,
482
- timeout=timeout,
483
- result_tag=result_tag,
484
- print_stdout=print_stdout
485
- )
486
- except ModuleNotFoundError as remote_module_error:
487
- _MOD_NOT_FOUND_RE = re.compile(r"No module named ['\"]([^'\"]+)['\"]")
488
- module_name = _MOD_NOT_FOUND_RE.search(str(remote_module_error))
489
- module_name = module_name.group(1) if module_name else None
490
- module_name = module_name.split(".")[0]
491
-
492
- if module_name and "yggdrasil" not in module_name:
493
- LOGGER.debug(
494
- "Installing missing module %s from local environment",
495
- module_name,
496
- )
497
-
498
- self.install_temporary_libraries(
499
- libraries=[module_name],
500
- )
501
-
502
- LOGGER.warning(
503
- "Installed missing module %s from local environment",
504
- module_name,
505
- )
506
-
507
- return self.execute_command(
508
- command=command,
509
- timeout=timeout,
510
- result_tag=result_tag,
511
- print_stdout=print_stdout
512
- )
513
-
514
- raise remote_module_error
562
+ language=language,
563
+ ).wait(wait=wait, raise_error=raise_error)
515
564
 
516
565
  # ------------------------------------------------------------------
517
566
  # generic local → remote uploader, via remote python
518
567
  # ------------------------------------------------------------------
519
- def upload_local_path(self, local_path: str, remote_path: str) -> None:
520
- """
521
- Generic uploader.
522
-
523
- - If local_path is a file:
524
- remote_path is the *file* path on remote.
525
- - If local_path is a directory:
526
- remote_path is the *directory root* on remote; the directory
527
- contents are mirrored under it.
528
- Args:
529
- local_path: Local file or directory to upload.
530
- remote_path: Target path on the remote cluster.
531
-
532
- Returns:
533
- None.
568
+ def upload_local_path(
569
+ self,
570
+ paths: Union[Iterable[Tuple[LocalSpec, str]], Dict[LocalSpec, str]],
571
+ byte_limit: int = 64 * 1024
572
+ ) -> None:
534
573
  """
535
- local_path = os.path.abspath(local_path)
536
- if not os.path.exists(local_path):
537
- raise FileNotFoundError(f"Local path not found: {local_path}")
538
-
539
- # normalize to POSIX for remote (Linux)
540
- remote_path = remote_path.replace("\\", "/")
574
+ One-shot uploader. Sends exactly ONE remote command.
541
575
 
542
- if os.path.isfile(local_path):
543
- # ---------- single file ----------
544
- with open(local_path, "rb") as f:
545
- data_b64 = base64.b64encode(f.read()).decode("ascii")
576
+ paths: dict[local_spec -> remote_target]
546
577
 
547
- cmd = f"""import base64, os
578
+ local_spec can be:
579
+ - str | PathLike: local file or directory
580
+ - bytes/bytearray/memoryview: raw content (remote_target must be a file path)
581
+ - BytesSource(name, data): raw content with a name
582
+ - (name, bytes-like): raw content with a name
548
583
 
549
- remote_file = {remote_path!r}
550
- data_b64 = {data_b64!r}
551
-
552
- os.makedirs(os.path.dirname(remote_file), exist_ok=True)
553
- with open(remote_file, "wb") as f:
554
- f.write(base64.b64decode(data_b64))
555
- """
556
-
557
- self.execute_command(command=cmd, print_stdout=False)
558
- return
559
-
560
- # ---------- directory ----------
584
+ remote_target:
585
+ - if local_spec is file: full remote file path
586
+ - if local_spec is dir: remote directory root
587
+ - if local_spec is bytes: full remote file path
588
+ """
589
+ if isinstance(paths, dict):
590
+ paths = paths.items()
591
+
592
+ def _to_bytes(x: BytesLike) -> bytes:
593
+ if isinstance(x, bytes):
594
+ return x
595
+ if isinstance(x, bytearray):
596
+ return bytes(x)
597
+ if isinstance(x, memoryview):
598
+ return x.tobytes()
599
+ elif isinstance(x, io.BytesIO):
600
+ return x.getvalue()
601
+ raise TypeError(f"Unsupported bytes-like: {type(x)!r}")
602
+
603
+ # normalize + validate + build a unified "work list"
604
+ work: list[dict[str, Any]] = []
605
+ for local_spec, remote in paths:
606
+ if not isinstance(remote, str) or not remote:
607
+ raise TypeError("remote_target must be a non-empty string")
608
+
609
+ remote_posix = remote.replace("\\", "/")
610
+
611
+ # --- bytes payloads ---
612
+ if isinstance(local_spec, BytesSource):
613
+ work.append({
614
+ "kind": "bytes",
615
+ "name": local_spec.name,
616
+ "data": local_spec.data,
617
+ "remote": remote_posix,
618
+ })
619
+ continue
620
+
621
+ if isinstance(local_spec, tuple) and len(local_spec) == 2 and isinstance(local_spec[0], str):
622
+ name, data = local_spec
623
+ work.append({
624
+ "kind": "bytes",
625
+ "name": name,
626
+ "data": _to_bytes(data),
627
+ "remote": remote_posix,
628
+ })
629
+ continue
630
+
631
+ if isinstance(local_spec, (bytes, bytearray, memoryview, io.BytesIO)):
632
+ work.append({
633
+ "kind": "bytes",
634
+ "name": "blob",
635
+ "data": _to_bytes(local_spec),
636
+ "remote": remote_posix,
637
+ })
638
+ continue
639
+
640
+ # --- filesystem payloads ---
641
+ if isinstance(local_spec, os.PathLike):
642
+ local_spec = os.fspath(local_spec)
643
+
644
+ if isinstance(local_spec, str):
645
+ local_abs = os.path.abspath(local_spec)
646
+ if not os.path.exists(local_abs):
647
+ raise FileNotFoundError(f"Local path not found: {local_spec}")
648
+
649
+ if os.path.isfile(local_abs):
650
+ work.append({
651
+ "kind": "file",
652
+ "local": local_abs,
653
+ "remote": remote_posix,
654
+ })
655
+ else:
656
+ work.append({
657
+ "kind": "dir",
658
+ "local": local_abs,
659
+ "remote_root": remote_posix.rstrip("/"),
660
+ "top": os.path.basename(local_abs.rstrip(os.sep)) or "dir",
661
+ })
662
+ continue
663
+
664
+ raise TypeError(f"Unsupported local_spec type: {type(local_spec)!r}")
665
+
666
+ # build one zip containing all content
667
+ manifest: list[dict[str, Any]] = []
561
668
  buf = io.BytesIO()
562
- local_root = local_path
669
+ with zipfile.ZipFile(buf, "w", compression=zipfile.ZIP_DEFLATED) as zf:
670
+ for idx, item in enumerate(work):
671
+ kind = item["kind"]
672
+
673
+ if kind == "bytes":
674
+ zip_name = f"BYTES/{idx}"
675
+ zf.writestr(zip_name, item["data"])
676
+ manifest.append({
677
+ "kind": "bytes",
678
+ "zip": zip_name,
679
+ "remote": item["remote"],
680
+ })
681
+
682
+ elif kind == "file":
683
+ zip_name = f"FILE/{idx}"
684
+ zf.write(item["local"], arcname=zip_name)
685
+ manifest.append({
686
+ "kind": "file",
687
+ "zip": zip_name,
688
+ "remote": item["remote"],
689
+ })
690
+
691
+ elif kind == "dir":
692
+ local_root = item["local"]
693
+ top = item["top"]
694
+ prefix = f"DIR/{idx}/{top}"
695
+
696
+ for root, dirs, files in os.walk(local_root):
697
+ dirs[:] = [d for d in dirs if d != "__pycache__"]
698
+
699
+ rel_root = os.path.relpath(root, local_root)
700
+ rel_root = "" if rel_root == "." else rel_root
701
+
702
+ for name in files:
703
+ if name.endswith((".pyc", ".pyo")):
704
+ continue
705
+ full = os.path.join(root, name)
706
+ rel_path = os.path.join(rel_root, name) if rel_root else name
707
+ zip_name = f"{prefix}/{rel_path}".replace("\\", "/")
708
+ zf.write(full, arcname=zip_name)
709
+
710
+ manifest.append({
711
+ "kind": "dir",
712
+ "zip_prefix": f"{prefix}/",
713
+ "remote_root": item["remote_root"],
714
+ })
715
+
716
+ else:
717
+ raise ValueError(f"Unknown kind in work list: {kind}")
718
+
719
+ raw = buf.getvalue()
720
+
721
+ # optional zlib on top of zip
722
+ algo = "none"
723
+ payload = raw
724
+ if len(raw) > byte_limit:
725
+ import zlib
726
+ compressed = zlib.compress(raw, level=9)
727
+ if len(compressed) < int(len(raw) * 0.95):
728
+ algo = "zlib"
729
+ payload = compressed
730
+
731
+ packed = b"ALG:" + algo.encode("ascii") + b"\n" + payload
732
+ data_b64 = base64.b64encode(packed).decode("ascii")
733
+
734
+ cmd = f"""import base64, io, os, zipfile, zlib
735
+
736
+ packed_b64 = {data_b64!r}
737
+ manifest = {manifest!r}
738
+
739
+ packed = base64.b64decode(packed_b64)
740
+ nl = packed.find(b"\\n")
741
+ if nl == -1 or not packed.startswith(b"ALG:"):
742
+ raise ValueError("Bad payload header")
743
+
744
+ algo = packed[4:nl].decode("ascii")
745
+ payload = packed[nl+1:]
746
+
747
+ if algo == "none":
748
+ raw = payload
749
+ elif algo == "zlib":
750
+ raw = zlib.decompress(payload)
751
+ else:
752
+ raise ValueError(f"Unknown compression algo: {{algo}}")
753
+
754
+ buf = io.BytesIO(raw)
755
+ with zipfile.ZipFile(buf, "r") as zf:
756
+ names = set(zf.namelist())
563
757
 
564
- # zip local folder into memory
565
- with zipfile.ZipFile(buf, "w", zipfile.ZIP_DEFLATED) as zf:
566
- for root, dirs, files in os.walk(local_root):
567
- # skip __pycache__
568
- dirs[:] = [d for d in dirs if d != "__pycache__"]
758
+ for item in manifest:
759
+ kind = item["kind"]
569
760
 
570
- rel_root = os.path.relpath(root, local_root)
571
- if rel_root == ".":
572
- rel_root = ""
573
- for name in files:
574
- if name.endswith((".pyc", ".pyo")):
575
- continue
576
- full = os.path.join(root, name)
577
- arcname = os.path.join(rel_root, name) if rel_root else name
578
- zf.write(full, arcname=arcname)
761
+ if kind in ("file", "bytes"):
762
+ zip_name = item["zip"]
763
+ remote_file = item["remote"]
764
+ if zip_name not in names:
765
+ raise FileNotFoundError(f"Missing in zip: {{zip_name}}")
579
766
 
580
- data_b64 = base64.b64encode(buf.getvalue()).decode("ascii")
767
+ parent = os.path.dirname(remote_file)
768
+ if parent:
769
+ os.makedirs(parent, exist_ok=True)
581
770
 
582
- cmd = f"""import base64, io, os, zipfile
771
+ with zf.open(zip_name, "r") as src, open(remote_file, "wb") as dst:
772
+ dst.write(src.read())
583
773
 
584
- remote_root = {remote_path!r}
585
- data_b64 = {data_b64!r}
774
+ elif kind == "dir":
775
+ prefix = item["zip_prefix"]
776
+ remote_root = item["remote_root"]
777
+ os.makedirs(remote_root, exist_ok=True)
586
778
 
587
- os.makedirs(remote_root, exist_ok=True)
779
+ for n in names:
780
+ if not n.startswith(prefix):
781
+ continue
782
+ rel = n[len(prefix):]
783
+ if not rel or rel.endswith("/"):
784
+ continue
588
785
 
589
- buf = io.BytesIO(base64.b64decode(data_b64))
590
- with zipfile.ZipFile(buf, "r") as zf:
591
- for member in zf.infolist():
592
- rel_name = member.filename
593
- target_path = os.path.join(remote_root, rel_name)
786
+ target = os.path.join(remote_root, rel)
787
+ os.makedirs(os.path.dirname(target), exist_ok=True)
788
+ with zf.open(n, "r") as src, open(target, "wb") as dst:
789
+ dst.write(src.read())
594
790
 
595
- if member.is_dir() or rel_name.endswith("/"):
596
- os.makedirs(target_path, exist_ok=True)
597
791
  else:
598
- os.makedirs(os.path.dirname(target_path), exist_ok=True)
599
- with zf.open(member, "r") as src, open(target_path, "wb") as dst:
600
- dst.write(src.read())
792
+ raise ValueError(f"Unknown manifest kind: {{kind}}")
601
793
  """
602
-
603
- self.execute_command(command=cmd, print_stdout=False)
794
+ self.execute_command(command=cmd)
604
795
 
605
796
  # ------------------------------------------------------------------
606
797
  # upload local lib into remote site-packages
@@ -608,7 +799,6 @@ with zipfile.ZipFile(buf, "r") as zf:
608
799
  def install_temporary_libraries(
609
800
  self,
610
801
  libraries: str | ModuleType | List[str | ModuleType],
611
- with_dependencies: bool = True
612
802
  ) -> Union[str, ModuleType, List[str | ModuleType]]:
613
803
  """
614
804
  Upload a local Python lib/module into the remote cluster's
@@ -621,7 +811,6 @@ with zipfile.ZipFile(buf, "r") as zf:
621
811
  - module object (e.g. import ygg; workspace.upload_local_lib(ygg))
622
812
  Args:
623
813
  libraries: Library path, name, module, or iterable of these.
624
- with_dependencies: Whether to include dependencies (unused).
625
814
 
626
815
  Returns:
627
816
  The resolved library or list of libraries uploaded.
@@ -632,6 +821,13 @@ with zipfile.ZipFile(buf, "r") as zf:
632
821
  ]
633
822
 
634
823
  resolved = resolve_local_lib_path(libraries)
824
+
825
+ LOGGER.debug(
826
+ "Installing temporary lib '%s' in %s",
827
+ resolved,
828
+ self
829
+ )
830
+
635
831
  str_resolved = str(resolved)
636
832
  existing = self._uploaded_package_roots.get(str_resolved)
637
833
 
@@ -645,61 +841,53 @@ with zipfile.ZipFile(buf, "r") as zf:
645
841
  # site-packages/<module_file>
646
842
  remote_target = posixpath.join(remote_site_packages_path, resolved.name)
647
843
 
648
- self.upload_local_path(resolved, remote_target)
844
+ self.upload_local_path({
845
+ str_resolved: remote_target
846
+ })
649
847
 
650
848
  self._uploaded_package_roots[str_resolved] = remote_target
651
849
 
652
- return libraries
653
-
654
- def _decode_result(
655
- self,
656
- result: Any,
657
- *,
658
- result_tag: Optional[str],
659
- print_stdout: Optional[bool] = True
660
- ) -> str:
661
- """Mirror the old Cluster.execute_command result handling.
850
+ LOGGER.info(
851
+ "Installed temporary lib '%s' in %s",
852
+ resolved,
853
+ self
854
+ )
662
855
 
663
- Args:
664
- result: Raw command execution response.
665
- result_tag: Optional tag to extract a segment from output.
666
- print_stdout: Whether to print stdout when using tags.
856
+ return libraries
667
857
 
668
- Returns:
669
- The decoded output string.
670
- """
671
- if not getattr(result, "results", None):
672
- raise RuntimeError("Command execution returned no results")
673
858
 
674
- res = result.results
859
+ def _decode_result(
860
+ result: CommandStatusResponse,
861
+ language: Language
862
+ ) -> str:
863
+ """Mirror the old Cluster.execute_command result handling.
675
864
 
676
- # error handling
677
- if res.result_type == ResultType.ERROR:
678
- message = res.cause or "Command execution failed"
865
+ Args:
866
+ result: Raw command execution response.
679
867
 
680
- if "client terminated the session" in message:
681
- raise CommandAborted(message)
868
+ Returns:
869
+ The decoded output string.
870
+ """
871
+ res = result.results
682
872
 
683
- if self.language == Language.PYTHON:
684
- raise_parsed_traceback(message)
873
+ # error handling
874
+ if res.result_type == ResultType.ERROR:
875
+ message = res.cause or "Command execution failed"
685
876
 
686
- remote_tb = (
687
- getattr(res, "data", None)
688
- or getattr(res, "stack_trace", None)
689
- or getattr(res, "traceback", None)
690
- )
877
+ if "client terminated the session" in message:
878
+ raise ClientTerminatedSession(message)
691
879
 
692
- if remote_tb:
693
- message = f"{message}\n{remote_tb}"
880
+ if language == Language.PYTHON:
881
+ raise_parsed_traceback(message)
694
882
 
695
- raise RuntimeError(message)
883
+ raise RuntimeError(message)
696
884
 
697
- # normal output
698
- if res.result_type == ResultType.TEXT:
699
- output = getattr(res, "data", "") or ""
700
- elif getattr(res, "data", None) is not None:
701
- output = str(res.data)
702
- else:
703
- output = ""
885
+ # normal output
886
+ if res.result_type == ResultType.TEXT:
887
+ output = res.data or ""
888
+ elif res.data is not None:
889
+ output = str(res.data)
890
+ else:
891
+ output = ""
704
892
 
705
- return output
893
+ return output