ygg 0.1.20__py3-none-any.whl → 0.1.23__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -15,7 +15,6 @@ import functools
15
15
  import inspect
16
16
  import logging
17
17
  import os
18
- import sys
19
18
  import time
20
19
  from dataclasses import dataclass
21
20
  from types import ModuleType
@@ -23,10 +22,10 @@ from typing import Any, Iterator, Optional, Union, List, Callable, Dict, ClassVa
23
22
 
24
23
  from .execution_context import ExecutionContext
25
24
  from ..workspaces.workspace import WorkspaceService, Workspace
26
- from ... import retry
25
+ from ... import retry, CallableSerde
27
26
  from ...libs.databrickslib import databricks_sdk
28
27
  from ...pyutils.modules import PipIndexSettings
29
- from ...ser import CallableSerdeMixin
28
+ from ...pyutils.python_env import PythonEnv
30
29
 
31
30
  if databricks_sdk is None: # pragma: no cover - import guard
32
31
  ResourceDoesNotExist = Exception # type: ignore
@@ -127,7 +126,7 @@ class Cluster(WorkspaceService):
127
126
  # 🔥 first time for this host → create
128
127
  inst = cls._env_clusters[host] = (
129
128
  cls(workspace=workspace, cluster_id=cluster_id, cluster_name=cluster_name)
130
- .replicate_current_environment(
129
+ .push_python_environment(
131
130
  single_user_name=single_user_name,
132
131
  runtime_engine=runtime_engine,
133
132
  libraries=libraries,
@@ -137,8 +136,9 @@ class Cluster(WorkspaceService):
137
136
 
138
137
  return inst
139
138
 
140
- def replicate_current_environment(
139
+ def push_python_environment(
141
140
  self,
141
+ source: Optional[PythonEnv] = None,
142
142
  cluster_id: Optional[str] = None,
143
143
  cluster_name: Optional[str] = None,
144
144
  single_user_name: Optional[str] = None,
@@ -146,15 +146,19 @@ class Cluster(WorkspaceService):
146
146
  libraries: Optional[list[str]] = None,
147
147
  **kwargs
148
148
  ) -> "Cluster":
149
+ if source is None:
150
+ source = PythonEnv.get_current()
151
+
149
152
  libraries = list(libraries) if libraries is not None else []
150
153
  libraries.extend([
151
154
  _ for _ in [
152
155
  "ygg",
153
- "dill"
156
+ "dill",
157
+ "uv",
154
158
  ] if _ not in libraries
155
159
  ])
156
160
 
157
- python_version = sys.version_info
161
+ python_version = source.version_info
158
162
 
159
163
  if python_version[0] < 3:
160
164
  python_version = None
@@ -172,6 +176,28 @@ class Cluster(WorkspaceService):
172
176
  )
173
177
 
174
178
  return inst
179
+
180
+ def pull_python_environment(
181
+ self,
182
+ target: Optional[PythonEnv] = None,
183
+ ):
184
+ with self.context() as c:
185
+ m = c.remote_metadata
186
+ requirements = m.requirements
187
+ version_info = m.version_info
188
+
189
+ if target is None:
190
+ target = PythonEnv.create(
191
+ name=f"dbx-{self.name}",
192
+ python=".".join(str(_) for _ in version_info)
193
+ )
194
+ else:
195
+ target.update(
196
+ requirements=requirements,
197
+ python=".".join(str(_) for _ in version_info)
198
+ )
199
+
200
+ return target
175
201
 
176
202
  @property
177
203
  def details(self):
@@ -364,6 +390,8 @@ class Cluster(WorkspaceService):
364
390
  python_version: Optional[Union[str, tuple[int, ...]]] = None,
365
391
  **kwargs
366
392
  ):
393
+ pip_settings = PipIndexSettings.default_settings()
394
+
367
395
  if kwargs:
368
396
  details = ClusterDetails(**{
369
397
  **details.as_shallow_dict(),
@@ -395,6 +423,13 @@ class Cluster(WorkspaceService):
395
423
  if details.is_single_node is not None and details.kind is None:
396
424
  details.kind = Kind.CLASSIC_PREVIEW
397
425
 
426
+ if pip_settings.extra_index_urls:
427
+ if details.spark_env_vars is None:
428
+ details.spark_env_vars = {}
429
+ str_urls = " ".join(pip_settings.extra_index_urls)
430
+ details.spark_env_vars["UV_EXTRA_INDEX_URL"] = details.spark_env_vars.get("UV_INDEX", str_urls)
431
+ details.spark_env_vars["PIP_EXTRA_INDEX_URL"] = details.spark_env_vars.get("PIP_EXTRA_INDEX_URL", str_urls)
432
+
398
433
  return details
399
434
 
400
435
  def create_or_update(
@@ -569,7 +604,7 @@ class Cluster(WorkspaceService):
569
604
  logger.info("Deleting %s", self)
570
605
  return self.clusters_client().delete(cluster_id=self.cluster_id)
571
606
 
572
- def execution_context(
607
+ def context(
573
608
  self,
574
609
  language: Optional["Language"] = None,
575
610
  context_id: Optional[str] = None
@@ -591,7 +626,7 @@ class Cluster(WorkspaceService):
591
626
  timeout: Optional[dt.timedelta] = None,
592
627
  result_tag: Optional[str] = None,
593
628
  ):
594
- return self.execution_context(language=language).execute(
629
+ return self.context(language=language).execute(
595
630
  obj=obj,
596
631
  args=args,
597
632
  kwargs=kwargs,
@@ -607,8 +642,10 @@ class Cluster(WorkspaceService):
607
642
  self,
608
643
  _func: Optional[Callable] = None,
609
644
  *,
645
+ before: Optional[Callable] = None,
610
646
  language: Optional["Language"] = None,
611
647
  env_keys: Optional[List[str]] = None,
648
+ env_variables: Optional[Dict[str, str]] = None,
612
649
  timeout: Optional[dt.timedelta] = None,
613
650
  result_tag: Optional[str] = None,
614
651
  **options
@@ -630,19 +667,23 @@ class Cluster(WorkspaceService):
630
667
  def h(z): ...
631
668
  """
632
669
  def decorator(func: Callable):
633
- context = self.execution_context(language=language or Language.PYTHON)
634
- serialized = CallableSerdeMixin.from_callable(func)
670
+ context = self.context(language=language or Language.PYTHON)
671
+ serialized = CallableSerde.from_callable(func)
672
+ do_before = CallableSerde.from_callable(before)
635
673
 
636
674
  @functools.wraps(func)
637
675
  def wrapper(*args, **kwargs):
638
676
  if os.getenv("DATABRICKS_RUNTIME_VERSION") is not None:
639
677
  return func(*args, **kwargs)
640
678
 
679
+ do_before()
680
+
641
681
  return context.execute(
642
682
  obj=serialized,
643
683
  args=list(args),
644
684
  kwargs=kwargs,
645
685
  env_keys=env_keys,
686
+ env_variables=env_variables,
646
687
  timeout=timeout,
647
688
  result_tag=result_tag,
648
689
  **options
@@ -685,7 +726,9 @@ class Cluster(WorkspaceService):
685
726
 
686
727
  if wait_timeout is not None:
687
728
  self.wait_installed_libraries(
688
- timeout=wait_timeout, pip_settings=pip_settings, raise_error=raise_error
729
+ timeout=wait_timeout,
730
+ pip_settings=pip_settings,
731
+ raise_error=raise_error
689
732
  )
690
733
 
691
734
  return self
@@ -790,7 +833,7 @@ class Cluster(WorkspaceService):
790
833
  self,
791
834
  libraries: str | ModuleType | List[str | ModuleType],
792
835
  ):
793
- return self.execution_context().install_temporary_libraries(libraries=libraries)
836
+ return self.context().install_temporary_libraries(libraries=libraries)
794
837
 
795
838
  def _check_library(
796
839
  self,
@@ -823,7 +866,11 @@ class Cluster(WorkspaceService):
823
866
  repo = None
824
867
 
825
868
  if pip_settings.extra_index_url:
826
- if value.startswith("datamanagement") or value.startswith("TSSecrets") or value.startswith("tgp_"):
869
+ if (
870
+ value.startswith("datamanagement")
871
+ or value.startswith("TSSecrets")
872
+ or value.startswith("tgp_")
873
+ ):
827
874
  repo = pip_settings.extra_index_url
828
875
 
829
876
  return Library(
@@ -11,12 +11,12 @@ import sys
11
11
  import threading
12
12
  import zipfile
13
13
  from types import ModuleType
14
- from typing import TYPE_CHECKING, Optional, Any, Callable, List, Dict, Union, Iterable
14
+ from typing import TYPE_CHECKING, Optional, Any, Callable, List, Dict, Union, Iterable, Tuple
15
15
 
16
16
  from ...libs.databrickslib import databricks_sdk
17
17
  from ...pyutils.exceptions import raise_parsed_traceback
18
18
  from ...pyutils.modules import resolve_local_lib_path
19
- from ...ser import CallableSerdeMixin
19
+ from ...pyutils.callable_serde import CallableSerde
20
20
 
21
21
  if TYPE_CHECKING:
22
22
  from .cluster import Cluster
@@ -35,6 +35,8 @@ logger = logging.getLogger(__name__)
35
35
  class RemoteMetadata:
36
36
  site_packages_path: Optional[str] = dc.field(default=None)
37
37
  os_env: Dict[str, str] = dc.field(default_factory=dict)
38
+ requirements: Optional[str] = dc.field(default=None)
39
+ version_info: Tuple[int, int, int] = dc.field(default=(0, 0, 0))
38
40
 
39
41
  def os_env_diff(
40
42
  self,
@@ -49,6 +51,7 @@ class RemoteMetadata:
49
51
  if k not in self.os_env.keys()
50
52
  }
51
53
 
54
+
52
55
  @dc.dataclass
53
56
  class ExecutionContext:
54
57
  """
@@ -74,23 +77,21 @@ class ExecutionContext:
74
77
  _was_connected: Optional[bool] = None
75
78
  _remote_metadata: Optional[RemoteMetadata] = None
76
79
 
77
- __lock: threading.Lock = dc.field(default_factory=threading.Lock, init=False, repr=False)
80
+ _lock: threading.RLock = dc.field(default_factory=threading.RLock, init=False, repr=False)
78
81
 
79
82
  # --- Pickle / cloudpickle support (don’t serialize locks or cached remote metadata) ---
80
83
  def __getstate__(self):
81
84
  state = self.__dict__.copy()
82
85
 
83
- # name-mangled field for __lock in instance dict:
84
- state.pop("_ExecutionContext__lock", None)
86
+ # name-mangled field for _lock in instance dict:
87
+ state.pop("_lock", None)
85
88
 
86
89
  return state
87
90
 
88
91
  def __setstate__(self, state):
89
- self.__dict__.update(state)
92
+ state["_lock"] = state.get("_lock", threading.RLock())
90
93
 
91
- # recreate lock + reset cache on unpickle
92
- self.__dict__["_ExecutionContext__lock"] = threading.Lock()
93
- self._remote_metadata = None
94
+ self.__dict__.update(state)
94
95
 
95
96
  def __enter__(self) -> "ExecutionContext":
96
97
  self.cluster.__enter__()
@@ -106,20 +107,22 @@ class ExecutionContext:
106
107
  self.close()
107
108
 
108
109
  @property
109
- def remote_metadata(self) -> "RemoteMetadata":
110
+ def remote_metadata(self) -> RemoteMetadata:
110
111
  # fast path (no lock)
111
112
  rm = self._remote_metadata
112
113
  if rm is not None:
113
114
  return rm
114
115
 
115
116
  # slow path guarded
116
- with self.__lock:
117
+ with self._lock:
117
118
  # double-check after acquiring lock
118
119
  if self._remote_metadata is None:
119
120
  cmd = r"""import glob
120
121
  import json
121
122
  import os
123
+ from yggdrasil.pyutils.python_env import PythonEnv
122
124
 
125
+ current_env = PythonEnv.get_current()
123
126
  meta = {}
124
127
 
125
128
  for path in glob.glob('/local_**/.ephemeral_nfs/cluster_libraries/python/lib/python*/site-*', recursive=False):
@@ -130,9 +133,11 @@ for path in glob.glob('/local_**/.ephemeral_nfs/cluster_libraries/python/lib/pyt
130
133
  os_env = meta["os_env"] = {}
131
134
  for k, v in os.environ.items():
132
135
  os_env[k] = v
136
+
137
+ meta["requirements"] = current_env.export_requirements_matrix()
138
+ meta["version_info"] = current_env.version_info
133
139
 
134
- print(json.dumps(meta))
135
- """
140
+ print(json.dumps(meta))"""
136
141
 
137
142
  content = self.execute_command(
138
143
  command=cmd,
@@ -226,7 +231,7 @@ print(json.dumps(meta))
226
231
  # ------------ public API ------------
227
232
  def execute(
228
233
  self,
229
- obj: Union[str, Callable, CallableSerdeMixin],
234
+ obj: Union[str, Callable],
230
235
  *,
231
236
  args: List[Any] = None,
232
237
  kwargs: Dict[str, Any] = None,
@@ -260,7 +265,7 @@ print(json.dumps(meta))
260
265
 
261
266
  def execute_callable(
262
267
  self,
263
- func: Callable | CallableSerdeMixin,
268
+ func: Callable | CallableSerde,
264
269
  args: List[Any] = None,
265
270
  kwargs: Dict[str, Any] = None,
266
271
  env_keys: Optional[Iterable[str]] = None,
@@ -269,7 +274,7 @@ print(json.dumps(meta))
269
274
  timeout: Optional[dt.timedelta] = None,
270
275
  command: Optional[str] = None,
271
276
  use_dill: Optional[bool] = None
272
- ) -> str:
277
+ ) -> Any:
273
278
  if self.is_in_databricks_environment():
274
279
  args = args or []
275
280
  kwargs = kwargs or {}
@@ -284,7 +289,7 @@ print(json.dumps(meta))
284
289
  self,
285
290
  )
286
291
 
287
- serialized = CallableSerdeMixin.from_callable(func)
292
+ serialized = CallableSerde.from_callable(func)
288
293
 
289
294
  self.install_temporary_libraries(libraries=serialized.package_root)
290
295
 
@@ -302,9 +307,6 @@ print(json.dumps(meta))
302
307
  command = serialized.to_command(
303
308
  args=args,
304
309
  kwargs=kwargs,
305
- env_keys=env_keys or [],
306
- env_variables=env_variables,
307
- use_dill=use_dill,
308
310
  result_tag=result_tag,
309
311
  ) if not command else command
310
312
 
@@ -9,8 +9,6 @@ from typing import (
9
9
 
10
10
  from ..workspaces.workspace import Workspace
11
11
 
12
-
13
-
14
12
  ReturnType = TypeVar("ReturnType")
15
13
 
16
14
  logger = logging.getLogger(__name__)
@@ -1,2 +1,4 @@
1
1
  from .retry import retry
2
2
  from .parallel import parallelize
3
+ from .python_env import PythonEnv
4
+ from .callable_serde import CallableSerde