ygg 0.1.33__py3-none-any.whl → 0.1.35__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ygg-0.1.35.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: ygg
-Version: 0.1.33
+Version: 0.1.35
 Summary: Type-friendly utilities for moving data between Python objects, Arrow, Polars, Pandas, Spark, and Databricks
 Author: Yggdrasil contributors
 License: Apache License
ygg-0.1.35.dist-info/RECORD
@@ -1,15 +1,15 @@
-ygg-0.1.33.dist-info/licenses/LICENSE,sha256=HrhfyXIkWY2tGFK11kg7vPCqhgh5DcxleloqdhrpyMY,11558
+ygg-0.1.35.dist-info/licenses/LICENSE,sha256=HrhfyXIkWY2tGFK11kg7vPCqhgh5DcxleloqdhrpyMY,11558
 yggdrasil/__init__.py,sha256=PfH7Xwt6uue6oqe6S5V8NhDJcVQClkKrBE1KXhdelZc,117
-yggdrasil/version.py,sha256=-w_wSS_K7mBDaL2ciGoQ17Ab9V3-ElRGvc4UASK1hFI,22
+yggdrasil/version.py,sha256=dvUFqQgabIeithBNW7swqwxg6R59T0-i289dJYgomYQ,22
 yggdrasil/databricks/__init__.py,sha256=skctY2c8W-hI81upx9F_PWRe5ishL3hrdiTuizgDjdw,152
 yggdrasil/databricks/compute/__init__.py,sha256=NvdzmaJSNYY1uJthv1hHdBuNu3bD_-Z65DWnaJt9yXg,289
-yggdrasil/databricks/compute/cluster.py,sha256=PTA3OOl2s23NGrMHOCa4ues-cx6mb5p7LtsI0bl_oa8,38088
-yggdrasil/databricks/compute/execution_context.py,sha256=_PWn-Jjb-qRtI7HGm7IOUAKoIpmivNiBZTddxpZNBaM,22142
-yggdrasil/databricks/compute/remote.py,sha256=735VSSkdGwIKXZ2F_SgiPiH-dUh7bheHBc4zaB6fDp8,2265
+yggdrasil/databricks/compute/cluster.py,sha256=yx3xEJ5Vgg-IPeyYxxdf74_9DSBPHvg2-FYs8ptJrl0,40404
+yggdrasil/databricks/compute/execution_context.py,sha256=nxrNXoarq_JAB-Cpj0udHhq2jx-DmMbRWJdAezLrPis,22347
+yggdrasil/databricks/compute/remote.py,sha256=nEN_Fr1Ouul_iKOf4B5QjEGscYAcl7nHjGsl2toRzrU,2874
 yggdrasil/databricks/jobs/__init__.py,sha256=snxGSJb0M5I39v0y3IR-uEeSlZR248cQ_4DJ1sYs-h8,154
 yggdrasil/databricks/jobs/config.py,sha256=9LGeHD04hbfy0xt8_6oobC4moKJh4_DTjZiK4Q2Tqjk,11557
 yggdrasil/databricks/sql/__init__.py,sha256=y1n5yg-drZ8QVZbEgznsRG24kdJSnFis9l2YfYCsaCM,234
-yggdrasil/databricks/sql/engine.py,sha256=BO2lweaL63etVJ9jia14JJvGVGZ6f0X6XsZ3QlA6_Po,38146
+yggdrasil/databricks/sql/engine.py,sha256=kUFBddJJQC0AgDqH0l7GFs7d_Ony5rc8fOv4inLU6Vw,41051
 yggdrasil/databricks/sql/exceptions.py,sha256=Jqd_gT_VyPL8klJEHYEzpv5eHtmdY43WiQ7HZBaEqSk,53
 yggdrasil/databricks/sql/statement_result.py,sha256=VlHXhTcvTVya_2aJ-uUfUooZF_MqQuOZ8k7g6PBDhOM,17227
 yggdrasil/databricks/sql/types.py,sha256=5G-BM9_eOsRKEMzeDTWUsWW5g4Idvs-czVCpOCrMhdA,6412
@@ -18,7 +18,7 @@ yggdrasil/databricks/workspaces/filesytem.py,sha256=Z8JXU7_XUEbw9fpTQT1avRQKi-IA
 yggdrasil/databricks/workspaces/io.py,sha256=Tdde4LaGNJNT50R11OkEYZyNacyIW9QrOXMAicAlIr4,32208
 yggdrasil/databricks/workspaces/path.py,sha256=-XnCD9p42who3DAwnITVE1KyrZUSoXDKHA8iZi-7wk4,47743
 yggdrasil/databricks/workspaces/path_kind.py,sha256=Xc319NysH8_6E9C0Q8nCxDHYG07_SnzyUVKHe0dNdDQ,305
-yggdrasil/databricks/workspaces/workspace.py,sha256=RixYZYhASeKPPPdP6JPN1XQC8M_rC0N-oBheasF--1k,23183
+yggdrasil/databricks/workspaces/workspace.py,sha256=MW-BEyldROqbX9SBbDspvlys_zehJjK5YgM3sGLfW-g,23382
 yggdrasil/dataclasses/__init__.py,sha256=6SdfIyTsoM4AuVw5TW4Q-UWXz41EyfsMcpD30cmjbSM,125
 yggdrasil/dataclasses/dataclass.py,sha256=fKokFUnqe4CmXXGMTdF4XDWbCUl_c_-se-UD48L5s1E,6594
 yggdrasil/libs/__init__.py,sha256=ulzk-ZkFUI2Pfo93YKtO8MBsEWtRZzLos7HAxN74R0w,168
@@ -30,8 +30,10 @@ yggdrasil/libs/extensions/__init__.py,sha256=mcXW5Li3Cbprbs4Ci-b5A0Ju0wmLcfvEiFu
 yggdrasil/libs/extensions/polars_extensions.py,sha256=RTkGi8llhPJjX7x9egix7-yXWo2X24zIAPSKXV37SSA,12397
 yggdrasil/libs/extensions/spark_extensions.py,sha256=E64n-3SFTDgMuXwWitX6vOYP9ln2lpGKb0htoBLEZgc,16745
 yggdrasil/pyutils/__init__.py,sha256=tl-LapAc71TV7RMgf2ftKwrzr8iiLOGHeJgA3RvO93w,293
-yggdrasil/pyutils/callable_serde.py,sha256=prxzYRrjR6-mZ9i1rIWakWL0w1JHPnTKBO_RR1NMacg,20992
+yggdrasil/pyutils/callable_serde.py,sha256=euY7Kiy04i1tpWKuB0b2qQ1FokLC3nq0cv7PObWYUBE,21809
+yggdrasil/pyutils/equality.py,sha256=Xyf8D1dLUCm3spDEir8Zyj7O4US_fBJwEylJCfJ9slI,3080
 yggdrasil/pyutils/exceptions.py,sha256=ssKNm-rjhavHUOZmGA7_1Gq9tSHDrb2EFI-cnBuWgng,3388
+yggdrasil/pyutils/expiring_dict.py,sha256=q9gb09-2EUN-jQZumUw5BXOQGYcj1wb85qKtQlciSxg,5825
 yggdrasil/pyutils/modules.py,sha256=B7IP99YqUMW6-DIESFzBx8-09V1d0a8qrIJUDFhhL2g,11424
 yggdrasil/pyutils/parallel.py,sha256=ubuq2m9dJzWYUyKCga4Y_9bpaeMYUrleYxdp49CHr44,6781
 yggdrasil/pyutils/python_env.py,sha256=tuglnjdqHQjNh18qDladVoSEOjCD0RcnMEPYJ0tArOs,50985
@@ -53,8 +55,8 @@ yggdrasil/types/cast/registry.py,sha256=_zdFGmUBB7P-e_LIcJlOxMcxAkXoA-UXB6HqLMgT
 yggdrasil/types/cast/spark_cast.py,sha256=_KAsl1DqmKMSfWxqhVE7gosjYdgiL1C5bDQv6eP3HtA,24926
 yggdrasil/types/cast/spark_pandas_cast.py,sha256=BuTiWrdCANZCdD_p2MAytqm74eq-rdRXd-LGojBRrfU,5023
 yggdrasil/types/cast/spark_polars_cast.py,sha256=btmZNHXn2NSt3fUuB4xg7coaE0RezIBdZD92H8NK0Jw,9073
-ygg-0.1.33.dist-info/METADATA,sha256=de-ntaFUDrFbT6rErjGnnZfPxN4DFBmxgokYLakRRcA,19204
-ygg-0.1.33.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-ygg-0.1.33.dist-info/entry_points.txt,sha256=6q-vpWG3kvw2dhctQ0LALdatoeefkN855Ev02I1dKGY,70
-ygg-0.1.33.dist-info/top_level.txt,sha256=iBe9Kk4VIVbLpgv_p8OZUIfxgj4dgJ5wBg6vO3rigso,10
-ygg-0.1.33.dist-info/RECORD,,
+ygg-0.1.35.dist-info/METADATA,sha256=qSh-Cd0LtjU_CV9Je0E1ns1AnuRbIUPMIpvVJoJtIUA,19204
+ygg-0.1.35.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ygg-0.1.35.dist-info/entry_points.txt,sha256=6q-vpWG3kvw2dhctQ0LALdatoeefkN855Ev02I1dKGY,70
+ygg-0.1.35.dist-info/top_level.txt,sha256=iBe9Kk4VIVbLpgv_p8OZUIfxgj4dgJ5wBg6vO3rigso,10
+ygg-0.1.35.dist-info/RECORD,,
yggdrasil/databricks/compute/cluster.py
@@ -24,6 +24,8 @@ from .execution_context import ExecutionContext
 from ..workspaces.workspace import WorkspaceService, Workspace
 from ... import retry, CallableSerde
 from ...libs.databrickslib import databricks_sdk
+from ...pyutils.equality import dicts_equal, dict_diff
+from ...pyutils.expiring_dict import ExpiringDict
 from ...pyutils.modules import PipIndexSettings
 from ...pyutils.python_env import PythonEnv
 
@@ -45,6 +47,31 @@ else: # pragma: no cover - runtime fallback when SDK is missing
 __all__ = ["Cluster"]
 
 
+NAME_ID_CACHE: dict[str, ExpiringDict] = {}
+
+
+def set_cached_cluster_name(
+    host: str,
+    cluster_name: str,
+    cluster_id: str
+) -> None:
+    existing = NAME_ID_CACHE.get(host)
+
+    if not existing:
+        existing = NAME_ID_CACHE[host] = ExpiringDict(default_ttl=60)
+
+    existing[cluster_name] = cluster_id
+
+
+def get_cached_cluster_id(
+    host: str,
+    cluster_name: str,
+) -> str:
+    existing = NAME_ID_CACHE.get(host)
+
+    return existing.get(cluster_name) if existing else None
+
+
 logger = logging.getLogger(__name__)
 
 
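
The new name→id cache relies on `yggdrasil/pyutils/expiring_dict.py`, which appears in the RECORD but whose source is not part of this diff. Below is a minimal sketch of a TTL mapping that would satisfy the usage here (`ExpiringDict(default_ttl=60)`, item assignment, `get`); the implementation details are inferred from the call sites, not the package's actual code. Note also that `get_cached_cluster_id` is annotated `-> str` but returns `None` on a cache miss.

```python
# Hypothetical sketch: the real yggdrasil/pyutils/expiring_dict.py is not
# shown in this diff, so everything below is inferred from the call sites.
import time
from typing import Any, Optional


class ExpiringDict:
    """Mapping whose entries expire default_ttl seconds after insertion."""

    def __init__(self, default_ttl: float = 60.0):
        self.default_ttl = default_ttl
        self._data: dict[Any, tuple[float, Any]] = {}  # key -> (deadline, value)

    def __setitem__(self, key: Any, value: Any) -> None:
        self._data[key] = (time.monotonic() + self.default_ttl, value)

    def get(self, key: Any, default: Optional[Any] = None) -> Any:
        entry = self._data.get(key)
        if entry is None:
            return default
        deadline, value = entry
        if time.monotonic() >= deadline:
            del self._data[key]  # lazily evict stale entries on read
            return default
        return value
```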
@@ -84,6 +111,7 @@ class Cluster(WorkspaceService):
 
     _details: Optional["ClusterDetails"] = dataclasses.field(default=None, repr=False)
     _details_refresh_time: float = dataclasses.field(default=0, repr=False)
+    _system_context: Optional[ExecutionContext] = dataclasses.field(default=None, repr=False)
 
     # host → Cluster instance
     _env_clusters: ClassVar[Dict[str, "Cluster"]] = {}
@@ -98,10 +126,11 @@ class Cluster(WorkspaceService):
         """Return the current cluster name."""
         return self.cluster_name
 
-    def __post_init__(self):
-        """Initialize cached details after dataclass construction."""
-        if self._details is not None:
-            self.details = self._details
+    @property
+    def system_context(self):
+        if self._system_context is None:
+            self._system_context = self.context(language=Language.PYTHON)
+        return self._system_context
 
     def is_in_databricks_environment(self):
         """Return True when running on a Databricks runtime."""
@@ -233,9 +262,8 @@ class Cluster(WorkspaceService):
         Returns:
             The updated PythonEnv instance.
         """
-        with self.context() as c:
-            m = c.remote_metadata
-            version_info = m.version_info
+        m = self.system_context.remote_metadata
+        version_info = m.version_info
 
         python_version = ".".join(str(_) for _ in version_info)
 
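
The memoized `system_context` property replaces the removed `__post_init__` hook: one Python execution context is created lazily and then shared by `execute()`, `execution_decorator()`, and the hunk above. The same pattern in isolation, using `functools.cached_property` (the `Context` class below is a stand-in, not the package's `ExecutionContext`):

```python
import functools


class Context:
    """Stand-in for an expensive-to-create remote execution context."""


class Cluster:
    @functools.cached_property
    def system_context(self) -> Context:
        # Created on first access, then stored on the instance, so repeated
        # calls reuse one remote context instead of opening a new one each time.
        return Context()


c = Cluster()
assert c.system_context is c.system_context
```

The explicit dataclass field used in the diff achieves the same effect while keeping the attribute out of `repr` and under the dataclass's control.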
@@ -258,7 +286,7 @@ class Cluster(WorkspaceService):
         )
 
         return target
-
+ 
     @property
     def details(self):
         """Return cached cluster details, refreshing when needed."""
@@ -300,21 +328,6 @@ class Cluster(WorkspaceService):
             return details.state
         return State.UNKNOWN
 
-    def get_state(self, max_delay: float = None):
-        """Return the cluster state with a custom refresh delay.
-
-        Args:
-            max_delay: Maximum age in seconds before refresh.
-
-        Returns:
-            The current cluster state.
-        """
-        details = self.fresh_details(max_delay=max_delay)
-
-        if details is not None:
-            return details.state
-        return State.UNKNOWN
-
     @property
     def is_running(self):
         """Return True when the cluster is running."""
@@ -323,7 +336,10 @@ class Cluster(WorkspaceService):
     @property
     def is_pending(self):
         """Return True when the cluster is starting, resizing, or terminating."""
-        return self.state in (State.PENDING, State.RESIZING, State.RESTARTING, State.TERMINATING)
+        return self.state in (
+            State.PENDING, State.RESIZING, State.RESTARTING,
+            State.TERMINATING
+        )
 
     @property
     def is_error(self):
@@ -507,45 +523,51 @@ class Cluster(WorkspaceService):
     ):
         pip_settings = PipIndexSettings.default_settings()
 
-        if kwargs:
-            details = ClusterDetails(**{
-                **details.as_shallow_dict(),
-                **kwargs
-            })
+        new_details = ClusterDetails(**{
+            **details.as_shallow_dict(),
+            **kwargs
+        })
+
+        default_tags = self.workspace.default_tags()
 
-        if details.custom_tags is None:
-            details.custom_tags = self.workspace.default_tags()
+        if new_details.custom_tags is None:
+            new_details.custom_tags = default_tags
+        elif default_tags:
+            new_tags = new_details.custom_tags.copy()
+            new_tags.update(default_tags)
 
-        if details.cluster_name is None:
-            details.cluster_name = self.workspace.current_user.user_name
+            new_details.custom_tags = new_tags
 
-        if details.spark_version is None or python_version:
-            details.spark_version = self.latest_spark_version(
+        if new_details.cluster_name is None:
+            new_details.cluster_name = self.workspace.current_user.user_name
+
+        if new_details.spark_version is None or python_version:
+            new_details.spark_version = self.latest_spark_version(
                 photon=False, python_version=python_version
             ).key
 
-        if details.single_user_name:
-            if not details.data_security_mode:
-                details.data_security_mode = DataSecurityMode.DATA_SECURITY_MODE_DEDICATED
+        if new_details.single_user_name:
+            if not new_details.data_security_mode:
+                new_details.data_security_mode = DataSecurityMode.DATA_SECURITY_MODE_DEDICATED
 
-        if not details.node_type_id:
-            details.node_type_id = "rd-fleet.xlarge"
+        if not new_details.node_type_id:
+            new_details.node_type_id = "rd-fleet.xlarge"
 
-        if getattr(details, "virtual_cluster_size", None) is None and details.num_workers is None and details.autoscale is None:
-            if details.is_single_node is None:
-                details.is_single_node = True
+        if getattr(new_details, "virtual_cluster_size", None) is None and new_details.num_workers is None and new_details.autoscale is None:
+            if new_details.is_single_node is None:
+                new_details.is_single_node = True
 
-        if details.is_single_node is not None and details.kind is None:
-            details.kind = Kind.CLASSIC_PREVIEW
+        if new_details.is_single_node is not None and new_details.kind is None:
+            new_details.kind = Kind.CLASSIC_PREVIEW
 
         if pip_settings.extra_index_urls:
-            if details.spark_env_vars is None:
-                details.spark_env_vars = {}
+            if new_details.spark_env_vars is None:
+                new_details.spark_env_vars = {}
             str_urls = " ".join(pip_settings.extra_index_urls)
-            details.spark_env_vars["UV_EXTRA_INDEX_URL"] = details.spark_env_vars.get("UV_INDEX", str_urls)
-            details.spark_env_vars["PIP_EXTRA_INDEX_URL"] = details.spark_env_vars.get("PIP_EXTRA_INDEX_URL", str_urls)
+            new_details.spark_env_vars["UV_EXTRA_INDEX_URL"] = new_details.spark_env_vars.get("UV_INDEX", str_urls)
+            new_details.spark_env_vars["PIP_EXTRA_INDEX_URL"] = new_details.spark_env_vars.get("PIP_EXTRA_INDEX_URL", str_urls)
 
-        return details
+        return new_details
 
     def create_or_update(
         self,
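
This hunk switches the normalization from mutating the caller's `ClusterDetails` in place to building a fresh `new_details` from a shallow dict merged with overrides. The same copy-with-overrides idiom on a plain dataclass (the `Config` class is illustrative, not from the package):

```python
import dataclasses
from typing import Optional


@dataclasses.dataclass
class Config:
    name: Optional[str] = None
    node_type: Optional[str] = None

    def as_shallow_dict(self) -> dict:
        # Shallow copy: nested objects are shared with the original.
        return {f.name: getattr(self, f.name) for f in dataclasses.fields(self)}


base = Config(name="etl")
# Overrides win over existing fields; `base` itself is left untouched.
new = Config(**{**base.as_shallow_dict(), "node_type": "rd-fleet.xlarge"})
assert base.node_type is None and new.node_type == "rd-fleet.xlarge"
```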
@@ -637,8 +659,6 @@ class Cluster(WorkspaceService):
         Returns:
             The updated Cluster instance.
         """
-        self.install_libraries(libraries=libraries, wait_timeout=None, raise_error=False)
-
         existing_details = {
             k: v
             for k, v in self.details.as_shallow_dict().items()
@@ -651,10 +671,23 @@ class Cluster(WorkspaceService):
             if k in _EDIT_ARG_NAMES
         }
 
-        if update_details != existing_details:
+        same = dicts_equal(
+            existing_details,
+            update_details,
+            keys=_EDIT_ARG_NAMES,
+            treat_missing_as_none=True,
+            float_tol=0.0,  # set e.g. 1e-6 if you have float-y stuff
+        )
+
+        if not same:
+            diff = {
+                k: v[1]
+                for k, v in dict_diff(existing_details, update_details, keys=_EDIT_ARG_NAMES).items()
+            }
+
             logger.debug(
                 "Updating %s with %s",
-                self, update_details
+                self, diff
             )
 
             self.wait_for_status()
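
`dicts_equal` and `dict_diff` come from the new `yggdrasil/pyutils/equality.py`, listed in the RECORD but not shown here. The sketch below is consistent with the call sites: `keys` restricts the comparison, `treat_missing_as_none` equates an absent key with `None`, `float_tol` bounds numeric drift, and `dict_diff` maps each changed key to an `(old, new)` pair (the `v[1]` above picks the new value). All of this is inferred behavior, not the package's actual code:

```python
# Inferred sketch of yggdrasil/pyutils/equality.py based solely on how the
# functions are called in cluster.py; the real implementation may differ.
from typing import Any, Dict, Iterable, Optional, Tuple


def dict_diff(
    a: dict,
    b: dict,
    keys: Optional[Iterable[str]] = None,
    treat_missing_as_none: bool = True,
    float_tol: float = 0.0,
) -> Dict[str, Tuple[Any, Any]]:
    """Return {key: (old, new)} for every compared key whose values differ."""
    compared = set(keys) if keys is not None else set(a) | set(b)
    changed: Dict[str, Tuple[Any, Any]] = {}
    for key in compared:
        old, new = a.get(key), b.get(key)
        if not treat_missing_as_none and (key in a) != (key in b):
            changed[key] = (old, new)
        elif isinstance(old, float) and isinstance(new, float):
            if abs(old - new) > float_tol:
                changed[key] = (old, new)
        elif old != new:
            changed[key] = (old, new)
    return changed


def dicts_equal(a: dict, b: dict, **kwargs) -> bool:
    return not dict_diff(a, b, **kwargs)
```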
@@ -665,6 +698,8 @@ class Cluster(WorkspaceService):
                 self
             )
 
+        self.install_libraries(libraries=libraries, wait_timeout=None, raise_error=False)
+
         return self
 
     def list_clusters(self) -> Iterator["Cluster"]:
@@ -704,6 +739,12 @@ class Cluster(WorkspaceService):
         if not cluster_name and not cluster_id:
             raise ValueError("Either name or cluster_id must be provided")
 
+        if not cluster_id:
+            cluster_id = get_cached_cluster_id(
+                host=self.workspace.safe_host,
+                cluster_name=cluster_name
+            )
+
         if cluster_id:
             try:
                 details = self.clusters_client().get(cluster_id=cluster_id)
@@ -716,10 +757,13 @@ class Cluster(WorkspaceService):
                 workspace=self.workspace, cluster_id=details.cluster_id, _details=details
             )
 
-        cluster_name_cf = cluster_name.casefold()
-
         for cluster in self.list_clusters():
-            if cluster_name_cf == cluster.details.cluster_name.casefold():
+            if cluster_name == cluster.details.cluster_name:
+                set_cached_cluster_name(
+                    host=self.workspace.safe_host,
+                    cluster_name=cluster.cluster_name,
+                    cluster_id=cluster.cluster_id
+                )
                 return cluster
 
         if raise_error:
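
Taken together with the cache helpers at the top of the module, name-based lookups now short-circuit the `list_clusters()` scan for up to 60 seconds per workspace host. Also note the behavior change: the old `casefold()` comparison was case-insensitive, while the new comparison matches the cluster name exactly. A round-trip illustration (host and ids are placeholders):

```python
from yggdrasil.databricks.compute.cluster import (
    get_cached_cluster_id,
    set_cached_cluster_name,
)

# Placeholder values for illustration only.
host = "https://example.cloud.databricks.com"

set_cached_cluster_name(host=host, cluster_name="etl-cluster", cluster_id="0101-120000-abc123")

# Served from the cache for up to default_ttl=60 seconds:
assert get_cached_cluster_id(host=host, cluster_name="etl-cluster") == "0101-120000-abc123"

# Case now matters: this misses the cache (and would also miss the
# exact-match comparison in the listing loop).
assert get_cached_cluster_id(host=host, cluster_name="ETL-Cluster") is None
```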
@@ -728,16 +772,18 @@ class Cluster(WorkspaceService):
 
     def ensure_running(
         self,
+        wait_timeout: Optional[dt.timedelta] = dt.timedelta(minutes=20)
     ) -> "Cluster":
         """Ensure the cluster is running.
 
         Returns:
             The current Cluster instance.
         """
-        return self.start()
+        return self.start(wait_timeout=wait_timeout)
 
     def start(
         self,
+        wait_timeout: Optional[dt.timedelta] = dt.timedelta(minutes=20)
     ) -> "Cluster":
         """Start the cluster if it is not already running.
 
@@ -748,8 +794,15 @@ class Cluster(WorkspaceService):
 
         if not self.is_running:
             logger.info("Starting %s", self)
-            self.details = self.clusters_client().start_and_wait(cluster_id=self.cluster_id)
-            return self.wait_installed_libraries()
+
+            if wait_timeout:
+                self.details = (
+                    self.clusters_client()
+                    .start_and_wait(cluster_id=self.cluster_id, timeout=wait_timeout)
+                )
+                self.wait_installed_libraries(timeout=wait_timeout)
+            else:
+                self.clusters_client().start(cluster_id=self.cluster_id)
 
         return self
 
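
`start()` and `ensure_running()` now take a `wait_timeout`; passing `None` turns the call into a fire-and-forget start (the reworked `databricks_remote_compute` below uses exactly that). A usage sketch, assuming `cluster` is an already-resolved `Cluster`:

```python
import datetime as dt

# Blocking (default): wait up to 20 minutes for RUNNING, then for libraries.
cluster.start()

# Blocking with a custom budget:
cluster.start(wait_timeout=dt.timedelta(minutes=5))

# Fire-and-forget: issue the start request and return immediately; the caller
# is responsible for polling cluster.is_running / is_pending before use.
cluster.start(wait_timeout=None)
```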
@@ -812,6 +865,7 @@ class Cluster(WorkspaceService):
         env_keys: Optional[List[str]] = None,
         timeout: Optional[dt.timedelta] = None,
         result_tag: Optional[str] = None,
+        context: Optional[ExecutionContext] = None,
     ):
         """Execute a command or callable on the cluster.
 
@@ -823,11 +877,14 @@ class Cluster(WorkspaceService):
             env_keys: Optional environment variable names to pass.
             timeout: Optional timeout for execution.
             result_tag: Optional result tag for parsing output.
+            context: Optional ExecutionContext to reuse; defaults to the cluster's system context.
 
         Returns:
             The decoded result from the execution context.
         """
-        return self.context(language=language).execute(
+        context = self.system_context if context is None else context
+
+        return context.execute(
             obj=obj,
             args=args,
             kwargs=kwargs,
@@ -849,6 +906,7 @@ class Cluster(WorkspaceService):
         timeout: Optional[dt.timedelta] = None,
         result_tag: Optional[str] = None,
         force_local: bool = False,
+        context: Optional[ExecutionContext] = None,
         **options
     ):
         """
@@ -875,16 +933,28 @@ class Cluster(WorkspaceService):
             timeout: Optional timeout for remote execution.
             result_tag: Optional tag for parsing remote output.
             force_local: force local execution
+            context: Optional ExecutionContext to reuse; defaults to the cluster's system context.
             **options: Additional execution options passed through.
 
         Returns:
             A decorator or wrapped function that executes remotely.
         """
+        if force_local or self.is_in_databricks_environment():
+            # Support both @ws.remote and @ws.remote(...)
+            if _func is not None and callable(_func):
+                return _func
+
+            def identity(x):
+                return x
+
+            return identity
+
+        context = self.system_context if context is None else context
+
         def decorator(func: Callable):
             if force_local or self.is_in_databricks_environment():
                 return func
 
-            context = self.context(language=language or Language.PYTHON)
             serialized = CallableSerde.from_callable(func)
 
             @functools.wraps(func)
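
The early-return block is what lets the decorator be applied both bare (`@remote`) and parameterized (`@remote(...)`). The general pattern, reduced to a self-contained sketch with no Databricks dependencies (all names here are illustrative):

```python
import functools
from typing import Callable, Optional


def remote(_func: Optional[Callable] = None, *, timeout: float = 60.0):
    """Support both @remote and @remote(timeout=...) call forms."""

    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # A real implementation would ship func to the cluster; calling it
            # locally keeps the sketch runnable.
            return func(*args, **kwargs)

        return wrapper

    # Bare usage: @remote puts the decorated function itself in _func.
    if _func is not None and callable(_func):
        return decorator(_func)

    # Parameterized usage: @remote(timeout=120) returns the decorator.
    return decorator


@remote
def f(x):
    return x + 1


@remote(timeout=120)
def g(x):
    return x * 2


assert f(1) == 2 and g(2) == 4
```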
@@ -1075,7 +1145,7 @@ class Cluster(WorkspaceService):
                     "Waiting %s to install libraries timed out" % self
                 )
 
-            time.sleep(10)
+            time.sleep(5)
             statuses = list(self.installed_library_statuses())
 
         return self
@@ -1111,7 +1181,7 @@ class Cluster(WorkspaceService):
             )
 
             with open(value, mode="rb") as f:
-                target_path.write_bytes(f.read())
+                target_path.open().write_all_bytes(f.read())
 
             value = str(target_path)
         elif "." in value and not "/" in value:
yggdrasil/databricks/compute/execution_context.py
@@ -78,8 +78,8 @@ class ExecutionContext:
     language: Optional["Language"] = None
     context_id: Optional[str] = None
 
-    _was_connected: Optional[bool] = None
-    _remote_metadata: Optional[RemoteMetadata] = None
+    _was_connected: Optional[bool] = dc.field(default=None, repr=False)
+    _remote_metadata: Optional[RemoteMetadata] = dc.field(default=None, repr=False)
 
     _lock: threading.RLock = dc.field(default_factory=threading.RLock, init=False, repr=False)
 
@@ -367,6 +367,8 @@ print(json.dumps(meta))"""
             args=args,
             kwargs=kwargs,
             result_tag=result_tag,
+            env_keys=env_keys,
+            env_variables=env_variables
         ) if not command else command
 
         raw_result = self.execute_command(
@@ -382,8 +384,9 @@ print(json.dumps(meta))"""
         module_name = module_name.group(1) if module_name else None
         module_name = module_name.split(".")[0]
 
-        if module_name:
+        if module_name and "yggdrasil" not in module_name:
             self.close()
+
             self.cluster.install_libraries(
                 libraries=[module_name],
                 raise_error=True,
@@ -442,7 +445,7 @@ print(json.dumps(meta))"""
         module_name = module_name.group(1) if module_name else None
         module_name = module_name.split(".")[0]
 
-        if module_name:
+        if module_name and "yggdrasil" not in module_name:
             self.close()
             self.cluster.install_libraries(
                 libraries=[module_name],
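
Both retry paths in `execution_context.py` now refuse to reinstall when the missing module is part of yggdrasil itself, which previously could loop on the package's own import errors. The underlying guard, as a standalone sketch assuming CPython's standard `No module named '...'` message format:

```python
import re
from typing import Optional

# CPython's ModuleNotFoundError message looks like: No module named 'pkg.sub'
_MISSING_MODULE_RE = re.compile(r"No module named '([^']+)'")


def missing_top_level_module(error_text: str) -> Optional[str]:
    """Extract the installable top-level package name, or None to give up."""
    match = _MISSING_MODULE_RE.search(error_text)
    if not match:
        return None
    top_level = match.group(1).split(".")[0]
    # Installing yggdrasil onto itself would just retry the same failure.
    if "yggdrasil" in top_level:
        return None
    return top_level


assert missing_top_level_module("No module named 'polars.io'") == "polars"
assert missing_top_level_module("No module named 'yggdrasil.libs'") is None
```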
yggdrasil/databricks/compute/remote.py
@@ -2,11 +2,12 @@
 
 import datetime as dt
 import logging
+import os
 from typing import (
     Callable,
     Optional,
     TypeVar,
-    List, TYPE_CHECKING,
+    List, TYPE_CHECKING, Union,
 )
 
 if TYPE_CHECKING:
@@ -14,15 +15,26 @@ if TYPE_CHECKING:
 
     from ..workspaces.workspace import Workspace
 
+
+__all__ = [
+    "databricks_remote_compute"
+]
+
+
 ReturnType = TypeVar("ReturnType")
 logger = logging.getLogger(__name__)
 
 
+def identity(x):
+    return x
+
+
 def databricks_remote_compute(
+    _func: Optional[Callable] = None,
     cluster_id: Optional[str] = None,
     cluster_name: Optional[str] = None,
-    workspace: Optional[Workspace] = None,
+    workspace: Optional[Union[Workspace, str]] = None,
     cluster: Optional["Cluster"] = None,
     timeout: Optional[dt.timedelta] = None,
     env_keys: Optional[List[str]] = None,
@@ -32,6 +44,7 @@ def databricks_remote_compute(
     """Return a decorator that executes functions on a remote cluster.
 
     Args:
+        _func: Function to decorate when used as a bare decorator.
         cluster_id: Optional cluster id to target.
         cluster_name: Optional cluster name to target.
         workspace: Workspace instance or host string for lookup.
@@ -45,13 +58,19 @@ def databricks_remote_compute(
         A decorator that runs functions on the resolved Databricks cluster.
     """
     if force_local or Workspace.is_in_databricks_environment():
-        def identity(x):
-            return x
+        return identity if _func is None else _func
 
-        return identity
+    if workspace is None:
+        workspace = os.getenv("DATABRICKS_HOST")
 
-    if isinstance(workspace, str):
-        workspace = Workspace(host=workspace)
+    if workspace is None:
+        return identity if _func is None else _func
+
+    if not isinstance(workspace, Workspace):
+        if isinstance(workspace, str):
+            workspace = Workspace(host=workspace).connect(clone=False)
+        else:
+            raise ValueError("Cannot initialize databricks workspace with %s" % type(workspace))
 
     if cluster is None:
         if cluster_id or cluster_name:
@@ -62,16 +81,15 @@ def databricks_remote_compute(
         else:
             cluster = workspace.clusters().replicated_current_environment(
                 workspace=workspace,
-                cluster_name=cluster_name
+                cluster_name=cluster_name,
+                single_user_name=workspace.current_user.user_name
             )
 
+    cluster.ensure_running(wait_timeout=None)
+
     return cluster.execution_decorator(
+        _func=_func,
         env_keys=env_keys,
         timeout=timeout,
         **options
     )
-
-
-__all__ = [
-    "databricks_remote_compute",
-]
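
With the new `_func` parameter and the `DATABRICKS_HOST` fallback, `databricks_remote_compute` now works both bare and parameterized; when neither a workspace nor `DATABRICKS_HOST` is available it degrades to a no-op and the function simply runs locally. A usage sketch (the host and cluster name are placeholders):

```python
import os

from yggdrasil.databricks.compute.remote import databricks_remote_compute

# Placeholder host for illustration; normally set in the environment.
os.environ.setdefault("DATABRICKS_HOST", "https://example.cloud.databricks.com")


# Bare form: workspace resolved from DATABRICKS_HOST, cluster replicated
# from the current environment, started fire-and-forget (wait_timeout=None).
@databricks_remote_compute
def add(a: int, b: int) -> int:
    return a + b


# Parameterized form: target an explicit cluster by name.
@databricks_remote_compute(cluster_name="dev-cluster", env_keys=["MY_TOKEN"])
def multiply(a: int, b: int) -> int:
    return a * b
```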