ygg 0.1.20__py3-none-any.whl → 0.1.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -15,7 +15,6 @@ import functools
15
15
  import inspect
16
16
  import logging
17
17
  import os
18
- import sys
19
18
  import time
20
19
  from dataclasses import dataclass
21
20
  from types import ModuleType
@@ -23,10 +22,10 @@ from typing import Any, Iterator, Optional, Union, List, Callable, Dict, ClassVa
23
22
 
24
23
  from .execution_context import ExecutionContext
25
24
  from ..workspaces.workspace import WorkspaceService, Workspace
26
- from ... import retry
25
+ from ... import retry, CallableSerde
27
26
  from ...libs.databrickslib import databricks_sdk
28
27
  from ...pyutils.modules import PipIndexSettings
29
- from ...ser import CallableSerdeMixin
28
+ from ...pyutils.python_env import PythonEnv
30
29
 
31
30
  if databricks_sdk is None: # pragma: no cover - import guard
32
31
  ResourceDoesNotExist = Exception # type: ignore
@@ -127,7 +126,7 @@ class Cluster(WorkspaceService):
127
126
  # 🔥 first time for this host → create
128
127
  inst = cls._env_clusters[host] = (
129
128
  cls(workspace=workspace, cluster_id=cluster_id, cluster_name=cluster_name)
130
- .replicate_current_environment(
129
+ .push_python_environment(
131
130
  single_user_name=single_user_name,
132
131
  runtime_engine=runtime_engine,
133
132
  libraries=libraries,
@@ -137,8 +136,9 @@ class Cluster(WorkspaceService):
137
136
 
138
137
  return inst
139
138
 
140
- def replicate_current_environment(
139
+ def push_python_environment(
141
140
  self,
141
+ source: Optional[PythonEnv] = None,
142
142
  cluster_id: Optional[str] = None,
143
143
  cluster_name: Optional[str] = None,
144
144
  single_user_name: Optional[str] = None,
@@ -146,15 +146,19 @@ class Cluster(WorkspaceService):
146
146
  libraries: Optional[list[str]] = None,
147
147
  **kwargs
148
148
  ) -> "Cluster":
149
+ if source is None:
150
+ source = PythonEnv.get_current()
151
+
149
152
  libraries = list(libraries) if libraries is not None else []
150
153
  libraries.extend([
151
154
  _ for _ in [
152
155
  "ygg",
153
- "dill"
156
+ "dill",
157
+ "uv",
154
158
  ] if _ not in libraries
155
159
  ])
156
160
 
157
- python_version = sys.version_info
161
+ python_version = source.version_info
158
162
 
159
163
  if python_version[0] < 3:
160
164
  python_version = None
@@ -172,6 +176,28 @@ class Cluster(WorkspaceService):
172
176
  )
173
177
 
174
178
  return inst
179
+
180
+ def pull_python_environment(
181
+ self,
182
+ target: Optional[PythonEnv] = None,
183
+ ):
184
+ with self.context() as c:
185
+ m = c.remote_metadata
186
+ requirements = m.requirements
187
+ version_info = m.version_info
188
+
189
+ if target is None:
190
+ target = PythonEnv.create(
191
+ name=f"dbx-{self.name}",
192
+ python=".".join(str(_) for _ in version_info)
193
+ )
194
+ else:
195
+ target.update(
196
+ requirements=requirements,
197
+ python=".".join(str(_) for _ in version_info)
198
+ )
199
+
200
+ return target
175
201
 
176
202
  @property
177
203
  def details(self):
@@ -569,7 +595,7 @@ class Cluster(WorkspaceService):
569
595
  logger.info("Deleting %s", self)
570
596
  return self.clusters_client().delete(cluster_id=self.cluster_id)
571
597
 
572
- def execution_context(
598
+ def context(
573
599
  self,
574
600
  language: Optional["Language"] = None,
575
601
  context_id: Optional[str] = None
@@ -591,7 +617,7 @@ class Cluster(WorkspaceService):
591
617
  timeout: Optional[dt.timedelta] = None,
592
618
  result_tag: Optional[str] = None,
593
619
  ):
594
- return self.execution_context(language=language).execute(
620
+ return self.context(language=language).execute(
595
621
  obj=obj,
596
622
  args=args,
597
623
  kwargs=kwargs,
@@ -607,8 +633,10 @@ class Cluster(WorkspaceService):
607
633
  self,
608
634
  _func: Optional[Callable] = None,
609
635
  *,
636
+ before: Optional[Callable] = None,
610
637
  language: Optional["Language"] = None,
611
638
  env_keys: Optional[List[str]] = None,
639
+ env_variables: Optional[Dict[str, str]] = None,
612
640
  timeout: Optional[dt.timedelta] = None,
613
641
  result_tag: Optional[str] = None,
614
642
  **options
@@ -630,19 +658,23 @@ class Cluster(WorkspaceService):
630
658
  def h(z): ...
631
659
  """
632
660
  def decorator(func: Callable):
633
- context = self.execution_context(language=language or Language.PYTHON)
634
- serialized = CallableSerdeMixin.from_callable(func)
661
+ context = self.context(language=language or Language.PYTHON)
662
+ serialized = CallableSerde.from_callable(func)
663
+ do_before = CallableSerde.from_callable(before)
635
664
 
636
665
  @functools.wraps(func)
637
666
  def wrapper(*args, **kwargs):
638
667
  if os.getenv("DATABRICKS_RUNTIME_VERSION") is not None:
639
668
  return func(*args, **kwargs)
640
669
 
670
+ do_before()
671
+
641
672
  return context.execute(
642
673
  obj=serialized,
643
674
  args=list(args),
644
675
  kwargs=kwargs,
645
676
  env_keys=env_keys,
677
+ env_variables=env_variables,
646
678
  timeout=timeout,
647
679
  result_tag=result_tag,
648
680
  **options
@@ -685,7 +717,9 @@ class Cluster(WorkspaceService):
685
717
 
686
718
  if wait_timeout is not None:
687
719
  self.wait_installed_libraries(
688
- timeout=wait_timeout, pip_settings=pip_settings, raise_error=raise_error
720
+ timeout=wait_timeout,
721
+ pip_settings=pip_settings,
722
+ raise_error=raise_error
689
723
  )
690
724
 
691
725
  return self
@@ -790,7 +824,7 @@ class Cluster(WorkspaceService):
790
824
  self,
791
825
  libraries: str | ModuleType | List[str | ModuleType],
792
826
  ):
793
- return self.execution_context().install_temporary_libraries(libraries=libraries)
827
+ return self.context().install_temporary_libraries(libraries=libraries)
794
828
 
795
829
  def _check_library(
796
830
  self,
@@ -823,7 +857,11 @@ class Cluster(WorkspaceService):
823
857
  repo = None
824
858
 
825
859
  if pip_settings.extra_index_url:
826
- if value.startswith("datamanagement") or value.startswith("TSSecrets") or value.startswith("tgp_"):
860
+ if (
861
+ value.startswith("datamanagement")
862
+ or value.startswith("TSSecrets")
863
+ or value.startswith("tgp_")
864
+ ):
827
865
  repo = pip_settings.extra_index_url
828
866
 
829
867
  return Library(
@@ -11,12 +11,12 @@ import sys
11
11
  import threading
12
12
  import zipfile
13
13
  from types import ModuleType
14
- from typing import TYPE_CHECKING, Optional, Any, Callable, List, Dict, Union, Iterable
14
+ from typing import TYPE_CHECKING, Optional, Any, Callable, List, Dict, Union, Iterable, Tuple
15
15
 
16
16
  from ...libs.databrickslib import databricks_sdk
17
17
  from ...pyutils.exceptions import raise_parsed_traceback
18
18
  from ...pyutils.modules import resolve_local_lib_path
19
- from ...ser import CallableSerdeMixin
19
+ from ...pyutils.callable_serde import CallableSerde
20
20
 
21
21
  if TYPE_CHECKING:
22
22
  from .cluster import Cluster
@@ -35,6 +35,8 @@ logger = logging.getLogger(__name__)
35
35
  class RemoteMetadata:
36
36
  site_packages_path: Optional[str] = dc.field(default=None)
37
37
  os_env: Dict[str, str] = dc.field(default_factory=dict)
38
+ requirements: Optional[str] = dc.field(default=None)
39
+ version_info: Tuple[int, int, int] = dc.field(default=(0, 0, 0))
38
40
 
39
41
  def os_env_diff(
40
42
  self,
@@ -49,6 +51,7 @@ class RemoteMetadata:
49
51
  if k not in self.os_env.keys()
50
52
  }
51
53
 
54
+
52
55
  @dc.dataclass
53
56
  class ExecutionContext:
54
57
  """
@@ -74,23 +77,21 @@ class ExecutionContext:
74
77
  _was_connected: Optional[bool] = None
75
78
  _remote_metadata: Optional[RemoteMetadata] = None
76
79
 
77
- __lock: threading.Lock = dc.field(default_factory=threading.Lock, init=False, repr=False)
80
+ _lock: threading.RLock = dc.field(default_factory=threading.RLock, init=False, repr=False)
78
81
 
79
82
  # --- Pickle / cloudpickle support (don’t serialize locks or cached remote metadata) ---
80
83
  def __getstate__(self):
81
84
  state = self.__dict__.copy()
82
85
 
83
- # name-mangled field for __lock in instance dict:
84
- state.pop("_ExecutionContext__lock", None)
86
+ # name-mangled field for _lock in instance dict:
87
+ state.pop("_lock", None)
85
88
 
86
89
  return state
87
90
 
88
91
  def __setstate__(self, state):
89
- self.__dict__.update(state)
92
+ state["_lock"] = state.get("_lock", threading.RLock())
90
93
 
91
- # recreate lock + reset cache on unpickle
92
- self.__dict__["_ExecutionContext__lock"] = threading.Lock()
93
- self._remote_metadata = None
94
+ self.__dict__.update(state)
94
95
 
95
96
  def __enter__(self) -> "ExecutionContext":
96
97
  self.cluster.__enter__()
@@ -106,20 +107,22 @@ class ExecutionContext:
106
107
  self.close()
107
108
 
108
109
  @property
109
- def remote_metadata(self) -> "RemoteMetadata":
110
+ def remote_metadata(self) -> RemoteMetadata:
110
111
  # fast path (no lock)
111
112
  rm = self._remote_metadata
112
113
  if rm is not None:
113
114
  return rm
114
115
 
115
116
  # slow path guarded
116
- with self.__lock:
117
+ with self._lock:
117
118
  # double-check after acquiring lock
118
119
  if self._remote_metadata is None:
119
120
  cmd = r"""import glob
120
121
  import json
121
122
  import os
123
+ from yggdrasil.pyutils import PythonEnv
122
124
 
125
+ current_env = PythonEnv.get_current()
123
126
  meta = {}
124
127
 
125
128
  for path in glob.glob('/local_**/.ephemeral_nfs/cluster_libraries/python/lib/python*/site-*', recursive=False):
@@ -130,9 +133,11 @@ for path in glob.glob('/local_**/.ephemeral_nfs/cluster_libraries/python/lib/pyt
130
133
  os_env = meta["os_env"] = {}
131
134
  for k, v in os.environ.items():
132
135
  os_env[k] = v
136
+
137
+ meta["requirements"] = current_env.export_requirements_matrix()
138
+ meta["version_info"] = current_env.version_info
133
139
 
134
- print(json.dumps(meta))
135
- """
140
+ print(json.dumps(meta))"""
136
141
 
137
142
  content = self.execute_command(
138
143
  command=cmd,
@@ -226,7 +231,7 @@ print(json.dumps(meta))
226
231
  # ------------ public API ------------
227
232
  def execute(
228
233
  self,
229
- obj: Union[str, Callable, CallableSerdeMixin],
234
+ obj: Union[str, Callable],
230
235
  *,
231
236
  args: List[Any] = None,
232
237
  kwargs: Dict[str, Any] = None,
@@ -260,7 +265,7 @@ print(json.dumps(meta))
260
265
 
261
266
  def execute_callable(
262
267
  self,
263
- func: Callable | CallableSerdeMixin,
268
+ func: Callable | CallableSerde,
264
269
  args: List[Any] = None,
265
270
  kwargs: Dict[str, Any] = None,
266
271
  env_keys: Optional[Iterable[str]] = None,
@@ -269,7 +274,7 @@ print(json.dumps(meta))
269
274
  timeout: Optional[dt.timedelta] = None,
270
275
  command: Optional[str] = None,
271
276
  use_dill: Optional[bool] = None
272
- ) -> str:
277
+ ) -> Any:
273
278
  if self.is_in_databricks_environment():
274
279
  args = args or []
275
280
  kwargs = kwargs or {}
@@ -284,7 +289,7 @@ print(json.dumps(meta))
284
289
  self,
285
290
  )
286
291
 
287
- serialized = CallableSerdeMixin.from_callable(func)
292
+ serialized = CallableSerde.from_callable(func)
288
293
 
289
294
  self.install_temporary_libraries(libraries=serialized.package_root)
290
295
 
@@ -302,9 +307,6 @@ print(json.dumps(meta))
302
307
  command = serialized.to_command(
303
308
  args=args,
304
309
  kwargs=kwargs,
305
- env_keys=env_keys or [],
306
- env_variables=env_variables,
307
- use_dill=use_dill,
308
310
  result_tag=result_tag,
309
311
  ) if not command else command
310
312
 
@@ -9,8 +9,6 @@ from typing import (
9
9
 
10
10
  from ..workspaces.workspace import Workspace
11
11
 
12
-
13
-
14
12
  ReturnType = TypeVar("ReturnType")
15
13
 
16
14
  logger = logging.getLogger(__name__)
@@ -1,2 +1,4 @@
1
1
  from .retry import retry
2
2
  from .parallel import parallelize
3
+ from .python_env import PythonEnv
4
+ from .callable_serde import CallableSerde