ygg 0.1.33__tar.gz → 0.1.37__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. {ygg-0.1.33 → ygg-0.1.37}/PKG-INFO +1 -1
  2. {ygg-0.1.33 → ygg-0.1.37}/pyproject.toml +1 -1
  3. {ygg-0.1.33 → ygg-0.1.37}/src/ygg.egg-info/PKG-INFO +1 -1
  4. {ygg-0.1.33 → ygg-0.1.37}/src/ygg.egg-info/SOURCES.txt +2 -0
  5. {ygg-0.1.33 → ygg-0.1.37}/src/yggdrasil/databricks/compute/cluster.py +150 -70
  6. {ygg-0.1.33 → ygg-0.1.37}/src/yggdrasil/databricks/compute/execution_context.py +7 -4
  7. {ygg-0.1.33 → ygg-0.1.37}/src/yggdrasil/databricks/compute/remote.py +31 -13
  8. {ygg-0.1.33 → ygg-0.1.37}/src/yggdrasil/databricks/sql/engine.py +314 -324
  9. {ygg-0.1.33 → ygg-0.1.37}/src/yggdrasil/databricks/sql/statement_result.py +36 -44
  10. {ygg-0.1.33 → ygg-0.1.37}/src/yggdrasil/databricks/workspaces/workspace.py +12 -1
  11. {ygg-0.1.33 → ygg-0.1.37}/src/yggdrasil/pyutils/callable_serde.py +27 -2
  12. ygg-0.1.37/src/yggdrasil/pyutils/equality.py +107 -0
  13. ygg-0.1.37/src/yggdrasil/pyutils/expiring_dict.py +176 -0
  14. ygg-0.1.37/src/yggdrasil/version.py +1 -0
  15. ygg-0.1.33/src/yggdrasil/version.py +0 -1
  16. {ygg-0.1.33 → ygg-0.1.37}/LICENSE +0 -0
  17. {ygg-0.1.33 → ygg-0.1.37}/README.md +0 -0
  18. {ygg-0.1.33 → ygg-0.1.37}/setup.cfg +0 -0
  19. {ygg-0.1.33 → ygg-0.1.37}/src/ygg.egg-info/dependency_links.txt +0 -0
  20. {ygg-0.1.33 → ygg-0.1.37}/src/ygg.egg-info/entry_points.txt +0 -0
  21. {ygg-0.1.33 → ygg-0.1.37}/src/ygg.egg-info/requires.txt +0 -0
  22. {ygg-0.1.33 → ygg-0.1.37}/src/ygg.egg-info/top_level.txt +0 -0
  23. {ygg-0.1.33 → ygg-0.1.37}/src/yggdrasil/__init__.py +0 -0
  24. {ygg-0.1.33 → ygg-0.1.37}/src/yggdrasil/databricks/__init__.py +0 -0
  25. {ygg-0.1.33 → ygg-0.1.37}/src/yggdrasil/databricks/compute/__init__.py +0 -0
  26. {ygg-0.1.33 → ygg-0.1.37}/src/yggdrasil/databricks/jobs/__init__.py +0 -0
  27. {ygg-0.1.33 → ygg-0.1.37}/src/yggdrasil/databricks/jobs/config.py +0 -0
  28. {ygg-0.1.33 → ygg-0.1.37}/src/yggdrasil/databricks/sql/__init__.py +0 -0
  29. {ygg-0.1.33 → ygg-0.1.37}/src/yggdrasil/databricks/sql/exceptions.py +0 -0
  30. {ygg-0.1.33 → ygg-0.1.37}/src/yggdrasil/databricks/sql/types.py +0 -0
  31. {ygg-0.1.33 → ygg-0.1.37}/src/yggdrasil/databricks/workspaces/__init__.py +0 -0
  32. {ygg-0.1.33 → ygg-0.1.37}/src/yggdrasil/databricks/workspaces/filesytem.py +0 -0
  33. {ygg-0.1.33 → ygg-0.1.37}/src/yggdrasil/databricks/workspaces/io.py +0 -0
  34. {ygg-0.1.33 → ygg-0.1.37}/src/yggdrasil/databricks/workspaces/path.py +0 -0
  35. {ygg-0.1.33 → ygg-0.1.37}/src/yggdrasil/databricks/workspaces/path_kind.py +0 -0
  36. {ygg-0.1.33 → ygg-0.1.37}/src/yggdrasil/dataclasses/__init__.py +0 -0
  37. {ygg-0.1.33 → ygg-0.1.37}/src/yggdrasil/dataclasses/dataclass.py +0 -0
  38. {ygg-0.1.33 → ygg-0.1.37}/src/yggdrasil/libs/__init__.py +0 -0
  39. {ygg-0.1.33 → ygg-0.1.37}/src/yggdrasil/libs/databrickslib.py +0 -0
  40. {ygg-0.1.33 → ygg-0.1.37}/src/yggdrasil/libs/extensions/__init__.py +0 -0
  41. {ygg-0.1.33 → ygg-0.1.37}/src/yggdrasil/libs/extensions/polars_extensions.py +0 -0
  42. {ygg-0.1.33 → ygg-0.1.37}/src/yggdrasil/libs/extensions/spark_extensions.py +0 -0
  43. {ygg-0.1.33 → ygg-0.1.37}/src/yggdrasil/libs/pandaslib.py +0 -0
  44. {ygg-0.1.33 → ygg-0.1.37}/src/yggdrasil/libs/polarslib.py +0 -0
  45. {ygg-0.1.33 → ygg-0.1.37}/src/yggdrasil/libs/sparklib.py +0 -0
  46. {ygg-0.1.33 → ygg-0.1.37}/src/yggdrasil/pyutils/__init__.py +0 -0
  47. {ygg-0.1.33 → ygg-0.1.37}/src/yggdrasil/pyutils/exceptions.py +0 -0
  48. {ygg-0.1.33 → ygg-0.1.37}/src/yggdrasil/pyutils/modules.py +0 -0
  49. {ygg-0.1.33 → ygg-0.1.37}/src/yggdrasil/pyutils/parallel.py +0 -0
  50. {ygg-0.1.33 → ygg-0.1.37}/src/yggdrasil/pyutils/python_env.py +0 -0
  51. {ygg-0.1.33 → ygg-0.1.37}/src/yggdrasil/pyutils/retry.py +0 -0
  52. {ygg-0.1.33 → ygg-0.1.37}/src/yggdrasil/requests/__init__.py +0 -0
  53. {ygg-0.1.33 → ygg-0.1.37}/src/yggdrasil/requests/msal.py +0 -0
  54. {ygg-0.1.33 → ygg-0.1.37}/src/yggdrasil/requests/session.py +0 -0
  55. {ygg-0.1.33 → ygg-0.1.37}/src/yggdrasil/types/__init__.py +0 -0
  56. {ygg-0.1.33 → ygg-0.1.37}/src/yggdrasil/types/cast/__init__.py +0 -0
  57. {ygg-0.1.33 → ygg-0.1.37}/src/yggdrasil/types/cast/arrow_cast.py +0 -0
  58. {ygg-0.1.33 → ygg-0.1.37}/src/yggdrasil/types/cast/cast_options.py +0 -0
  59. {ygg-0.1.33 → ygg-0.1.37}/src/yggdrasil/types/cast/pandas_cast.py +0 -0
  60. {ygg-0.1.33 → ygg-0.1.37}/src/yggdrasil/types/cast/polars_cast.py +0 -0
  61. {ygg-0.1.33 → ygg-0.1.37}/src/yggdrasil/types/cast/polars_pandas_cast.py +0 -0
  62. {ygg-0.1.33 → ygg-0.1.37}/src/yggdrasil/types/cast/registry.py +0 -0
  63. {ygg-0.1.33 → ygg-0.1.37}/src/yggdrasil/types/cast/spark_cast.py +0 -0
  64. {ygg-0.1.33 → ygg-0.1.37}/src/yggdrasil/types/cast/spark_pandas_cast.py +0 -0
  65. {ygg-0.1.33 → ygg-0.1.37}/src/yggdrasil/types/cast/spark_polars_cast.py +0 -0
  66. {ygg-0.1.33 → ygg-0.1.37}/src/yggdrasil/types/libs.py +0 -0
  67. {ygg-0.1.33 → ygg-0.1.37}/src/yggdrasil/types/python_arrow.py +0 -0
  68. {ygg-0.1.33 → ygg-0.1.37}/src/yggdrasil/types/python_defaults.py +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: ygg
- Version: 0.1.33
+ Version: 0.1.37
  Summary: Type-friendly utilities for moving data between Python objects, Arrow, Polars, Pandas, Spark, and Databricks
  Author: Yggdrasil contributors
  License: Apache License
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

  [project]
  name = "ygg"
- version = "0.1.33"
+ version = "0.1.37"
  description = "Type-friendly utilities for moving data between Python objects, Arrow, Polars, Pandas, Spark, and Databricks"
  readme = { file = "README.md", content-type = "text/markdown" }
  license = { file = "LICENSE" }
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: ygg
- Version: 0.1.33
+ Version: 0.1.37
  Summary: Type-friendly utilities for moving data between Python objects, Arrow, Polars, Pandas, Spark, and Databricks
  Author: Yggdrasil contributors
  License: Apache License
@@ -39,7 +39,9 @@ src/yggdrasil/libs/extensions/polars_extensions.py
  src/yggdrasil/libs/extensions/spark_extensions.py
  src/yggdrasil/pyutils/__init__.py
  src/yggdrasil/pyutils/callable_serde.py
+ src/yggdrasil/pyutils/equality.py
  src/yggdrasil/pyutils/exceptions.py
+ src/yggdrasil/pyutils/expiring_dict.py
  src/yggdrasil/pyutils/modules.py
  src/yggdrasil/pyutils/parallel.py
  src/yggdrasil/pyutils/python_env.py
@@ -24,6 +24,8 @@ from .execution_context import ExecutionContext
  from ..workspaces.workspace import WorkspaceService, Workspace
  from ... import retry, CallableSerde
  from ...libs.databrickslib import databricks_sdk
+ from ...pyutils.equality import dicts_equal, dict_diff
+ from ...pyutils.expiring_dict import ExpiringDict
  from ...pyutils.modules import PipIndexSettings
  from ...pyutils.python_env import PythonEnv

@@ -45,6 +47,31 @@ else: # pragma: no cover - runtime fallback when SDK is missing
  __all__ = ["Cluster"]


+ NAME_ID_CACHE: dict[str, ExpiringDict] = {}
+
+
+ def set_cached_cluster_name(
+     host: str,
+     cluster_name: str,
+     cluster_id: str
+ ) -> None:
+     existing = NAME_ID_CACHE.get(host)
+
+     if not existing:
+         existing = NAME_ID_CACHE[host] = ExpiringDict(default_ttl=60)
+
+     existing[cluster_name] = cluster_id
+
+
+ def get_cached_cluster_id(
+     host: str,
+     cluster_name: str,
+ ) -> str:
+     existing = NAME_ID_CACHE.get(host)
+
+     return existing.get(cluster_name) if existing else None
+
+
  logger = logging.getLogger(__name__)


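Reviewer note: expiring_dict.py itself is not shown in this diff, so the sketch below only illustrates the intended behavior of the two helpers added above, assuming ExpiringDict acts like a dict whose entries disappear default_ttl seconds after being set. The host string and cluster values are placeholders.

    import time

    # Populate the per-host cache; entries should expire after 60 seconds.
    set_cached_cluster_name("adb.example.net", "shared-dev", "0101-123456-abcdef")

    # Within the TTL the lookup is served from the cache:
    assert get_cached_cluster_id("adb.example.net", "shared-dev") == "0101-123456-abcdef"

    # After the TTL the entry should be gone and the caller falls back
    # to listing clusters again:
    time.sleep(61)
    assert get_cached_cluster_id("adb.example.net", "shared-dev") is None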
@@ -84,6 +111,7 @@ class Cluster(WorkspaceService):

      _details: Optional["ClusterDetails"] = dataclasses.field(default=None, repr=False)
      _details_refresh_time: float = dataclasses.field(default=0, repr=False)
+     _system_context: Optional[ExecutionContext] = dataclasses.field(default=None, repr=False)

      # host → Cluster instance
      _env_clusters: ClassVar[Dict[str, "Cluster"]] = {}
@@ -98,10 +126,11 @@
          """Return the current cluster name."""
          return self.cluster_name

-     def __post_init__(self):
-         """Initialize cached details after dataclass construction."""
-         if self._details is not None:
-             self.details = self._details
+     @property
+     def system_context(self):
+         if self._system_context is None:
+             self._system_context = self.context(language=Language.PYTHON)
+         return self._system_context

      def is_in_databricks_environment(self):
          """Return True when running on a Databricks runtime."""
@@ -233,9 +262,8 @@
          Returns:
              The updated PythonEnv instance.
          """
-         with self.context() as c:
-             m = c.remote_metadata
-             version_info = m.version_info
+         m = self.system_context.remote_metadata
+         version_info = m.version_info

          python_version = ".".join(str(_) for _ in version_info)

@@ -258,7 +286,7 @@
          )

          return target
-
+
      @property
      def details(self):
          """Return cached cluster details, refreshing when needed."""
@@ -282,6 +310,11 @@
          self.details = self.clusters_client().get(cluster_id=self.cluster_id)
          return self._details

+     def refresh(self, max_delay: float | None = None):
+         self.details = self.fresh_details(max_delay=max_delay)
+
+         return self
+
      @details.setter
      def details(self, value: "ClusterDetails"):
          """Cache cluster details and update identifiers."""
@@ -294,25 +327,10 @@
      @property
      def state(self):
          """Return the current cluster state."""
-         details = self.fresh_details(max_delay=10)
-
-         if details is not None:
-             return details.state
-         return State.UNKNOWN
-
-     def get_state(self, max_delay: float = None):
-         """Return the cluster state with a custom refresh delay.
-
-         Args:
-             max_delay: Maximum age in seconds before refresh.
+         self.refresh()

-         Returns:
-             The current cluster state.
-         """
-         details = self.fresh_details(max_delay=max_delay)
-
-         if details is not None:
-             return details.state
+         if self._details is not None:
+             return self._details.state
          return State.UNKNOWN

      @property
@@ -323,7 +341,10 @@
      @property
      def is_pending(self):
          """Return True when the cluster is starting, resizing, or terminating."""
-         return self.state in (State.PENDING, State.RESIZING, State.RESTARTING, State.TERMINATING)
+         return self.state in (
+             State.PENDING, State.RESIZING, State.RESTARTING,
+             State.TERMINATING
+         )

      @property
      def is_error(self):
@@ -340,7 +361,7 @@
      def wait_for_status(
          self,
          tick: float = 0.5,
-         timeout: float = 600,
+         timeout: Union[float, dt.timedelta] = 600,
          backoff: int = 2,
          max_sleep_time: float = 15
      ):
@@ -358,6 +379,9 @@
          start = time.time()
          sleep_time = tick

+         if isinstance(timeout, dt.timedelta):
+             timeout = timeout.total_seconds()
+
          while self.is_pending:
              time.sleep(sleep_time)

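With the conversion above, wait_for_status now accepts either plain seconds or a datetime.timedelta; assuming a Cluster instance named cluster, the two calls below are equivalent:

    import datetime as dt

    cluster.wait_for_status(timeout=600)                       # float seconds
    cluster.wait_for_status(timeout=dt.timedelta(minutes=10))  # same deadline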
@@ -507,45 +531,51 @@
      ):
          pip_settings = PipIndexSettings.default_settings()

-         if kwargs:
-             details = ClusterDetails(**{
-                 **details.as_shallow_dict(),
-                 **kwargs
-             })
+         new_details = ClusterDetails(**{
+             **details.as_shallow_dict(),
+             **kwargs
+         })
+
+         default_tags = self.workspace.default_tags()

-         if details.custom_tags is None:
-             details.custom_tags = self.workspace.default_tags()
+         if new_details.custom_tags is None:
+             new_details.custom_tags = default_tags
+         elif default_tags:
+             new_tags = new_details.custom_tags.copy()
+             new_tags.update(default_tags)

-         if details.cluster_name is None:
-             details.cluster_name = self.workspace.current_user.user_name
+             new_details.custom_tags = new_tags

-         if details.spark_version is None or python_version:
-             details.spark_version = self.latest_spark_version(
+         if new_details.cluster_name is None:
+             new_details.cluster_name = self.workspace.current_user.user_name
+
+         if new_details.spark_version is None or python_version:
+             new_details.spark_version = self.latest_spark_version(
                  photon=False, python_version=python_version
              ).key

-         if details.single_user_name:
-             if not details.data_security_mode:
-                 details.data_security_mode = DataSecurityMode.DATA_SECURITY_MODE_DEDICATED
+         if new_details.single_user_name:
+             if not new_details.data_security_mode:
+                 new_details.data_security_mode = DataSecurityMode.DATA_SECURITY_MODE_DEDICATED

-         if not details.node_type_id:
-             details.node_type_id = "rd-fleet.xlarge"
+         if not new_details.node_type_id:
+             new_details.node_type_id = "rd-fleet.xlarge"

-         if getattr(details, "virtual_cluster_size", None) is None and details.num_workers is None and details.autoscale is None:
-             if details.is_single_node is None:
-                 details.is_single_node = True
+         if getattr(new_details, "virtual_cluster_size", None) is None and new_details.num_workers is None and new_details.autoscale is None:
+             if new_details.is_single_node is None:
+                 new_details.is_single_node = True

-         if details.is_single_node is not None and details.kind is None:
-             details.kind = Kind.CLASSIC_PREVIEW
+         if new_details.is_single_node is not None and new_details.kind is None:
+             new_details.kind = Kind.CLASSIC_PREVIEW

          if pip_settings.extra_index_urls:
-             if details.spark_env_vars is None:
-                 details.spark_env_vars = {}
+             if new_details.spark_env_vars is None:
+                 new_details.spark_env_vars = {}
              str_urls = " ".join(pip_settings.extra_index_urls)
-             details.spark_env_vars["UV_EXTRA_INDEX_URL"] = details.spark_env_vars.get("UV_INDEX", str_urls)
-             details.spark_env_vars["PIP_EXTRA_INDEX_URL"] = details.spark_env_vars.get("PIP_EXTRA_INDEX_URL", str_urls)
+             new_details.spark_env_vars["UV_EXTRA_INDEX_URL"] = new_details.spark_env_vars.get("UV_INDEX", str_urls)
+             new_details.spark_env_vars["PIP_EXTRA_INDEX_URL"] = new_details.spark_env_vars.get("PIP_EXTRA_INDEX_URL", str_urls)

-         return details
+         return new_details

      def create_or_update(
          self,
@@ -637,8 +667,6 @@
          Returns:
              The updated Cluster instance.
          """
-         self.install_libraries(libraries=libraries, wait_timeout=None, raise_error=False)
-
          existing_details = {
              k: v
              for k, v in self.details.as_shallow_dict().items()
@@ -651,20 +679,36 @@
              if k in _EDIT_ARG_NAMES
          }

-         if update_details != existing_details:
+         same = dicts_equal(
+             existing_details,
+             update_details,
+             keys=_EDIT_ARG_NAMES,
+             treat_missing_as_none=True,
+             float_tol=0.0, # set e.g. 1e-6 if you have float-y stuff
+         )
+
+         if not same:
+             diff = {
+                 k: v[1]
+                 for k, v in dict_diff(existing_details, update_details, keys=_EDIT_ARG_NAMES).items()
+             }
+
              logger.debug(
                  "Updating %s with %s",
-                 self, update_details
+                 self, diff
              )

              self.wait_for_status()
-             self.details = self.clusters_client().edit_and_wait(**update_details)
+             self.details = self.clusters_client().edit(**update_details)
+             self.wait_for_status()

              logger.info(
                  "Updated %s",
                  self
              )

+         self.install_libraries(libraries=libraries, wait_timeout=None, raise_error=False)
+
          return self

      def list_clusters(self) -> Iterator["Cluster"]:
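Reviewer note: equality.py is new in 0.1.37 but its body is not part of this diff. From the call site above one can infer that dict_diff returns a mapping of differing keys to (old, new) pairs (the code keeps v[1]) and that dicts_equal compares only the selected keys. A minimal sketch consistent with that usage, with the float_tol handling omitted:

    from typing import Any, Dict, Iterable, Optional, Tuple

    def dict_diff(
        a: Dict[str, Any],
        b: Dict[str, Any],
        keys: Optional[Iterable[str]] = None,
    ) -> Dict[str, Tuple[Any, Any]]:
        # Map each differing key to its (old, new) value pair.
        selected = set(keys) if keys is not None else set(a) | set(b)
        return {
            k: (a.get(k), b.get(k))
            for k in selected
            if a.get(k) != b.get(k)
        }

    def dicts_equal(
        a: Dict[str, Any],
        b: Dict[str, Any],
        keys: Optional[Iterable[str]] = None,
        treat_missing_as_none: bool = True,
        float_tol: float = 0.0,
    ) -> bool:
        # dict.get() already yields None for missing keys, which matches
        # treat_missing_as_none=True; a float_tol comparison would go here.
        return not dict_diff(a, b, keys=keys)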
@@ -704,6 +748,12 @@
          if not cluster_name and not cluster_id:
              raise ValueError("Either name or cluster_id must be provided")

+         if not cluster_id:
+             cluster_id = get_cached_cluster_id(
+                 host=self.workspace.safe_host,
+                 cluster_name=cluster_name
+             )
+
          if cluster_id:
              try:
                  details = self.clusters_client().get(cluster_id=cluster_id)
@@ -713,13 +763,19 @@
                  return None

              return Cluster(
-                 workspace=self.workspace, cluster_id=details.cluster_id, _details=details
+                 workspace=self.workspace,
+                 cluster_id=details.cluster_id,
+                 cluster_name=details.cluster_name,
+                 _details=details
              )

-         cluster_name_cf = cluster_name.casefold()
-
          for cluster in self.list_clusters():
-             if cluster_name_cf == cluster.details.cluster_name.casefold():
+             if cluster_name == cluster.details.cluster_name:
+                 set_cached_cluster_name(
+                     host=self.workspace.safe_host,
+                     cluster_name=cluster.cluster_name,
+                     cluster_id=cluster.cluster_id
+                 )
                  return cluster

          if raise_error:
@@ -728,16 +784,18 @@

      def ensure_running(
          self,
+         wait_timeout: Optional[dt.timedelta] = dt.timedelta(minutes=20)
      ) -> "Cluster":
          """Ensure the cluster is running.

          Returns:
              The current Cluster instance.
          """
-         return self.start()
+         return self.start(wait_timeout=wait_timeout)

      def start(
          self,
+         wait_timeout: Optional[dt.timedelta] = dt.timedelta(minutes=20)
      ) -> "Cluster":
          """Start the cluster if it is not already running.

@@ -748,8 +806,13 @@

          if not self.is_running:
              logger.info("Starting %s", self)
-             self.details = self.clusters_client().start_and_wait(cluster_id=self.cluster_id)
-             return self.wait_installed_libraries()
+
+             if wait_timeout:
+                 self.clusters_client().start(cluster_id=self.cluster_id)
+                 self.wait_for_status(timeout=wait_timeout.total_seconds())
+                 self.wait_installed_libraries(timeout=wait_timeout)
+             else:
+                 self.clusters_client().start(cluster_id=self.cluster_id)
          return self

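The start/ensure_running rework splits blocking starts from fire-and-forget ones; assuming a Cluster instance named cluster:

    import datetime as dt

    # Default: issue the start call, then block on cluster state and
    # library installation for up to 20 minutes.
    cluster.start()

    # Custom deadline:
    cluster.start(wait_timeout=dt.timedelta(minutes=5))

    # Fire-and-forget: request the start and return immediately.
    cluster.start(wait_timeout=None)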
@@ -812,6 +875,7 @@
          env_keys: Optional[List[str]] = None,
          timeout: Optional[dt.timedelta] = None,
          result_tag: Optional[str] = None,
+         context: Optional[ExecutionContext] = None,
      ):
          """Execute a command or callable on the cluster.

@@ -823,11 +887,14 @@
              env_keys: Optional environment variable names to pass.
              timeout: Optional timeout for execution.
              result_tag: Optional result tag for parsing output.
+             context: ExecutionContext to run or create new one

          Returns:
              The decoded result from the execution context.
          """
-         return self.context(language=language).execute(
+         context = self.system_context if context is None else context
+
+         return context.execute(
              obj=obj,
              args=args,
              kwargs=kwargs,
@@ -849,6 +916,7 @@
          timeout: Optional[dt.timedelta] = None,
          result_tag: Optional[str] = None,
          force_local: bool = False,
+         context: Optional[ExecutionContext] = None,
          **options
      ):
          """
@@ -875,16 +943,28 @@
              timeout: Optional timeout for remote execution.
              result_tag: Optional tag for parsing remote output.
              force_local: force local execution
+             context: ExecutionContext to run or create new one
              **options: Additional execution options passed through.

          Returns:
              A decorator or wrapped function that executes remotely.
          """
+         if force_local or self.is_in_databricks_environment():
+             # Support both @ws.remote and @ws.remote(...)
+             if _func is not None and callable(_func):
+                 return _func
+
+             def identity(x):
+                 return x
+
+             return identity
+
+         context = self.system_context if context is None else context
+
          def decorator(func: Callable):
              if force_local or self.is_in_databricks_environment():
                  return func

-             context = self.context(language=language or Language.PYTHON)
              serialized = CallableSerde.from_callable(func)

              @functools.wraps(func)
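The early-exit block above makes execution_decorator usable both bare and parameterized. A generic sketch of that dual-form decorator protocol (illustrative only, not the package's wrapper):

    import functools

    def remote(_func=None, **options):
        def decorator(func):
            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                # A real implementation would serialize func and run it
                # on the cluster; this sketch just calls it locally.
                return func(*args, **kwargs)
            return wrapper

        if _func is not None and callable(_func):
            return decorator(_func)  # bare form: @remote
        return decorator             # parameterized form: @remote(...)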
@@ -1075,7 +1155,7 @@
                  "Waiting %s to install libraries timed out" % self
              )

-             time.sleep(10)
+             time.sleep(5)
              statuses = list(self.installed_library_statuses())

          return self
@@ -1111,7 +1191,7 @@
              )

              with open(value, mode="rb") as f:
-                 target_path.write_bytes(f.read())
+                 target_path.open().write_all_bytes(f.read())

              value = str(target_path)
          elif "." in value and not "/" in value:
@@ -78,8 +78,8 @@ class ExecutionContext:
      language: Optional["Language"] = None
      context_id: Optional[str] = None

-     _was_connected: Optional[bool] = None
-     _remote_metadata: Optional[RemoteMetadata] = None
+     _was_connected: Optional[bool] = dc.field(default=None, repr=False)
+     _remote_metadata: Optional[RemoteMetadata] = dc.field(default=None, repr=False)

      _lock: threading.RLock = dc.field(default_factory=threading.RLock, init=False, repr=False)

@@ -367,6 +367,8 @@ print(json.dumps(meta))"""
              args=args,
              kwargs=kwargs,
              result_tag=result_tag,
+             env_keys=env_keys,
+             env_variables=env_variables
          ) if not command else command

          raw_result = self.execute_command(
@@ -382,8 +384,9 @@ print(json.dumps(meta))"""
          module_name = module_name.group(1) if module_name else None
          module_name = module_name.split(".")[0]

-         if module_name:
+         if module_name and "yggdrasil" not in module_name:
              self.close()
+
              self.cluster.install_libraries(
                  libraries=[module_name],
                  raise_error=True,
@@ -442,7 +445,7 @@ print(json.dumps(meta))"""
          module_name = module_name.group(1) if module_name else None
          module_name = module_name.split(".")[0]

-         if module_name:
+         if module_name and "yggdrasil" not in module_name:
              self.close()
              self.cluster.install_libraries(
                  libraries=[module_name],
@@ -2,11 +2,12 @@

  import datetime as dt
  import logging
+ import os
  from typing import (
      Callable,
      Optional,
      TypeVar,
-     List, TYPE_CHECKING,
+     List, TYPE_CHECKING, Union,
  )

  if TYPE_CHECKING:
@@ -14,15 +15,26 @@ if TYPE_CHECKING:

  from ..workspaces.workspace import Workspace

+
+ __all__ = [
+     "databricks_remote_compute"
+ ]
+
+
  ReturnType = TypeVar("ReturnType")

  logger = logging.getLogger(__name__)


+ def identity(x):
+     return x
+
+
  def databricks_remote_compute(
+     _func: Optional[Callable] = None,
      cluster_id: Optional[str] = None,
      cluster_name: Optional[str] = None,
-     workspace: Optional[Workspace] = None,
+     workspace: Optional[Union[Workspace, str]] = None,
      cluster: Optional["Cluster"] = None,
      timeout: Optional[dt.timedelta] = None,
      env_keys: Optional[List[str]] = None,
@@ -32,6 +44,7 @@ def databricks_remote_compute(
      """Return a decorator that executes functions on a remote cluster.

      Args:
+         _func: function to decorate
          cluster_id: Optional cluster id to target.
          cluster_name: Optional cluster name to target.
          workspace: Workspace instance or host string for lookup.
@@ -45,13 +58,19 @@
          A decorator that runs functions on the resolved Databricks cluster.
      """
      if force_local or Workspace.is_in_databricks_environment():
-         def identity(x):
-             return x
+         return identity if _func is None else _func

-         return identity
+     if workspace is None:
+         workspace = os.getenv("DATABRICKS_HOST")

-     if isinstance(workspace, str):
-         workspace = Workspace(host=workspace)
+     if workspace is None:
+         return identity if _func is None else _func
+
+     if not isinstance(workspace, Workspace):
+         if isinstance(workspace, str):
+             workspace = Workspace(host=workspace).connect(clone=False)
+         else:
+             raise ValueError("Cannot initialize databricks workspace with %s" % type(workspace))

      if cluster is None:
          if cluster_id or cluster_name:
@@ -62,16 +81,15 @@
          else:
              cluster = workspace.clusters().replicated_current_environment(
                  workspace=workspace,
-                 cluster_name=cluster_name
+                 cluster_name=cluster_name,
+                 single_user_name=workspace.current_user.user_name
              )

+     cluster.ensure_running(wait_timeout=None)
+
      return cluster.execution_decorator(
+         _func=_func,
          env_keys=env_keys,
          timeout=timeout,
          **options
      )
-
-
- __all__ = [
-     "databricks_remote_compute",
- ]
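Taken together, databricks_remote_compute now falls back to the DATABRICKS_HOST environment variable when no workspace is given, degrades to a no-op passthrough when neither is available (or when already running on Databricks), and supports both decorator forms. A usage sketch with placeholder host and cluster values:

    import datetime as dt
    from yggdrasil.databricks.compute.remote import databricks_remote_compute

    # Bare form: workspace resolved from DATABRICKS_HOST if set.
    @databricks_remote_compute
    def add(a: int, b: int) -> int:
        return a + b

    # Parameterized form targeting a named cluster.
    @databricks_remote_compute(
        workspace="https://adb-1234567890.azuredatabricks.net",
        cluster_name="shared-dev",
        timeout=dt.timedelta(minutes=5),
    )
    def mul(a: int, b: int) -> int:
        return a * b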