ygg 0.1.44__py3-none-any.whl → 0.1.46__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {ygg-0.1.44.dist-info → ygg-0.1.46.dist-info}/METADATA +1 -1
- {ygg-0.1.44.dist-info → ygg-0.1.46.dist-info}/RECORD +14 -13
- yggdrasil/databricks/compute/cluster.py +20 -16
- yggdrasil/databricks/compute/execution_context.py +46 -64
- yggdrasil/databricks/sql/engine.py +5 -2
- yggdrasil/databricks/sql/warehouse.py +355 -0
- yggdrasil/databricks/workspaces/workspace.py +19 -9
- yggdrasil/pyutils/callable_serde.py +296 -308
- yggdrasil/pyutils/expiring_dict.py +114 -25
- yggdrasil/version.py +1 -1
- {ygg-0.1.44.dist-info → ygg-0.1.46.dist-info}/WHEEL +0 -0
- {ygg-0.1.44.dist-info → ygg-0.1.46.dist-info}/entry_points.txt +0 -0
- {ygg-0.1.44.dist-info → ygg-0.1.46.dist-info}/licenses/LICENSE +0 -0
- {ygg-0.1.44.dist-info → ygg-0.1.46.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,355 @@
|
|
|
1
|
+
import dataclasses as dc
|
|
2
|
+
import inspect
|
|
3
|
+
import logging
|
|
4
|
+
from typing import Optional, Sequence
|
|
5
|
+
|
|
6
|
+
from ..workspaces import WorkspaceService, Workspace
|
|
7
|
+
from ...pyutils.equality import dicts_equal, dict_diff
|
|
8
|
+
from ...pyutils.expiring_dict import ExpiringDict
|
|
9
|
+
|
|
10
|
+
# The Databricks SDK is an optional dependency. When it is importable, the
# keyword names accepted by ``WarehousesAPI.create`` / ``WarehousesAPI.edit``
# are captured so warehouse specs can be filtered before being sent.
try:
    from databricks.sdk import WarehousesAPI
    from databricks.sdk.service.sql import (
        State, EndpointInfo, EndpointTags, EndpointTagPair, EndpointInfoWarehouseType
    )

    _CREATE_ARG_NAMES = set(inspect.signature(WarehousesAPI.create).parameters)
    _EDIT_ARG_NAMES = set(inspect.signature(WarehousesAPI.edit).parameters)
except ImportError:
    # Placeholder types so this module (and its annotations) can still be
    # imported when the SDK is not installed.
    class WarehousesAPI:
        pass

    class State:
        pass

    class EndpointInfo:
        pass

    class EndpointTags:
        pass

    class EndpointTagPair:
        pass

    class EndpointInfoWarehouseType:
        pass

    # Keep the module-level names defined in both branches so code that
    # references them does not fail with a NameError when the SDK is missing.
    _CREATE_ARG_NAMES = set()
    _EDIT_ARG_NAMES = set()
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
# Names exported by ``from ... import *``.
__all__ = [
    "SQLWarehouse"
]


# Module-level logger named after this module.
LOGGER = logging.getLogger(__name__)
# Per-host cache: workspace host -> ExpiringDict mapping warehouse name -> warehouse id.
NAME_ID_CACHE: dict[str, ExpiringDict] = {}
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def set_cached_warehouse_name(
    host: str,
    warehouse_name: str,
    warehouse_id: str
) -> None:
    """Remember the id of a warehouse name for a given workspace host.

    Args:
        host: Workspace host used as the first-level cache key.
        warehouse_name: Warehouse name used as the second-level cache key.
        warehouse_id: Warehouse id stored under that name.
    """
    per_host = NAME_ID_CACHE.get(host)

    if not per_host:
        # No usable cache for this host yet: start a fresh one. The TTL value
        # of 60 is presumably in seconds (see ExpiringDict) — TODO confirm.
        per_host = ExpiringDict(default_ttl=60)
        NAME_ID_CACHE[host] = per_host

    per_host[warehouse_name] = warehouse_id
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def get_cached_warehouse_id(
    host: str,
    warehouse_name: str,
) -> Optional[str]:
    """Look up a cached warehouse id by workspace host and warehouse name.

    Args:
        host: Workspace host whose cache should be consulted.
        warehouse_name: Warehouse name to look up.

    Returns:
        The cached warehouse id, or ``None`` when the host has no (non-empty)
        cache or the name is not present in it.
    """
    existing = NAME_ID_CACHE.get(host)

    return existing.get(warehouse_name) if existing else None
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
@dc.dataclass
class SQLWarehouse(WorkspaceService):
    """Workspace-bound helper to find, create, update, start and stop a
    Databricks SQL warehouse.

    An instance may be unbound (no id, no name), bound by name only, or bound
    by id. Warehouse details are fetched lazily through ``details``.
    """

    # Id of the warehouse this helper targets, when known.
    warehouse_id: Optional[str] = None
    # Name of the warehouse this helper targets, when known.
    warehouse_name: Optional[str] = None

    # Cached warehouse description; populated by ``refresh`` / the ``details`` setter.
    _details: Optional[EndpointInfo] = dc.field(default=None, repr=False)

    def warehouse_client(self):
        """Return the ``warehouses`` client of the bound workspace SDK."""
        return self.workspace.sdk().warehouses

    def default(
        self,
        name: str = "YGG-DEFAULT",
        **kwargs
    ):
        """Create or update the warehouse called ``name``.

        Args:
            name: Warehouse name to create or update.
            **kwargs: Warehouse specs forwarded to ``create_or_update``.

        Returns:
            The result of ``create_or_update``.
        """
        return self.create_or_update(
            name=name,
            **kwargs
        )

    @property
    def details(self) -> EndpointInfo:
        """Cached warehouse details, fetched on first access."""
        if self._details is None:
            self.refresh()
        return self._details

    @details.setter
    def details(self, value: EndpointInfo):
        """Store new details and align ``warehouse_id`` / ``warehouse_name`` with them."""
        self._details = value

        self.warehouse_id = value.id
        self.warehouse_name = value.name

    def latest_details(self):
        """Fetch the current warehouse description by id (no caching)."""
        return self.warehouse_client().get(id=self.warehouse_id)

    def refresh(self):
        """Re-fetch the warehouse details and cache them; returns ``self``."""
        self.details = self.latest_details()
        return self

    @property
    def state(self):
        """Current warehouse state; fetched fresh on every access."""
        return self.latest_details().state

    @property
    def running(self):
        """Whether the warehouse state is ``State.RUNNING``."""
        return self.state in {State.RUNNING}

    @property
    def pending(self):
        """Whether the warehouse is in a transitional state (deleting, starting or stopping)."""
        return self.state in {State.DELETING, State.STARTING, State.STOPPING}

    def start(self):
        """Request a start unless the warehouse is already running; returns ``self``."""
        if not self.running:
            self.warehouse_client().start(id=self.warehouse_id)
        return self

    def stop(self):
        """Request a stop when the warehouse is running; returns ``self``.

        Always returns ``self`` (like ``start``) so the return type does not
        depend on the warehouse state.
        """
        if self.running:
            self.warehouse_client().stop(id=self.warehouse_id)
        return self

    def find_warehouse(
        self,
        warehouse_id: Optional[str] = None,
        warehouse_name: Optional[str] = None,
        raise_error: bool = True
    ):
        """Resolve a warehouse helper by id or by name.

        Resolution order: explicit ``warehouse_id``; ``self`` when it already
        has an id and no conflicting name was requested; the name -> id cache;
        finally a scan of ``list_warehouses``.

        Args:
            warehouse_id: Id to bind to directly (no lookup is performed).
            warehouse_name: Name to look up; defaults to ``self.warehouse_name``.
            raise_error: Raise when nothing is found instead of returning ``None``.

        Returns:
            A ``SQLWarehouse``, or ``None`` when not found and ``raise_error`` is false.

        Raises:
            ValueError: When no warehouse matches and ``raise_error`` is true.
        """
        if warehouse_id:
            return SQLWarehouse(
                workspace=self.workspace,
                warehouse_id=warehouse_id,
                warehouse_name=warehouse_name
            )

        # ``self`` is only a valid answer when the caller did not ask for a
        # warehouse whose name is known to differ from this one.
        name_conflict = bool(
            warehouse_name
            and self.warehouse_name
            and warehouse_name != self.warehouse_name
        )

        if self.warehouse_id and not name_conflict:
            return self

        warehouse_name = warehouse_name or self.warehouse_name

        # Without a name there is nothing to look up: skip the cache and the listing.
        if warehouse_name:
            host = self.workspace.host
            cached_id = get_cached_warehouse_id(host=host, warehouse_name=warehouse_name)

            if cached_id:
                return SQLWarehouse(
                    workspace=self.workspace,
                    warehouse_id=cached_id,
                    warehouse_name=warehouse_name
                )

            for warehouse in self.list_warehouses():
                if warehouse.warehouse_name == warehouse_name:
                    set_cached_warehouse_name(
                        host=host,
                        warehouse_name=warehouse_name,
                        warehouse_id=warehouse.warehouse_id
                    )
                    return warehouse

        if raise_error:
            raise ValueError(
                f"SQL Warehouse {warehouse_name!r} not found"
            )
        return None

    def list_warehouses(self):
        """Yield one ``SQLWarehouse`` (with details pre-filled) per warehouse listed by the client."""
        for info in self.warehouse_client().list():
            yield SQLWarehouse(
                workspace=self.workspace,
                warehouse_id=info.id,
                warehouse_name=info.name,
                _details=info
            )

    @staticmethod
    def _to_endpoint_tags(tags) -> EndpointTags:
        """Build an ``EndpointTags`` from a key -> value mapping."""
        return EndpointTags(custom_tags=[
            EndpointTagPair(key=k, value=v)
            for k, v in tags.items()
        ])

    def _check_details(
        self,
        keys: Sequence[str],
        details: Optional[EndpointInfo] = None,
        **warehouse_specs
    ):
        """Build an ``EndpointInfo`` from specs, filling in defaults.

        Args:
            keys: Spec names to keep; anything else is dropped.
            details: Existing details to start from; specs are laid over them.
            **warehouse_specs: Requested field values.

        Returns:
            A new ``EndpointInfo`` with defaults applied for cluster size,
            name, tags, max clusters and warehouse type.
        """
        if details is None:
            merged = warehouse_specs
        else:
            # A ``None`` spec means "not provided": it must not erase a value
            # the existing warehouse already has (e.g. its name).
            merged = {
                **details.as_shallow_dict(),
                **{k: v for k, v in warehouse_specs.items() if v is not None},
            }

        details = EndpointInfo(**{
            k: v
            for k, v in merged.items()
            if k in keys
        })

        if details.cluster_size is None:
            details.cluster_size = "Small"

        if details.name is None:
            details.name = "YGG-%s" % details.cluster_size.upper()

        default_tags = self.workspace.default_tags()

        if details.tags is None:
            details.tags = self._to_endpoint_tags(default_tags)
        elif not isinstance(details.tags, EndpointTags):
            # Tags given as a plain mapping: convert them, with the workspace
            # default tags taking precedence on key clashes.
            combined = dict(details.tags)
            combined.update(default_tags)
            details.tags = self._to_endpoint_tags(combined)
        # NOTE(review): an existing EndpointTags value is kept untouched, as
        # before. The previous code computed a merge with the default tags
        # here but never used it — confirm whether that merge should apply.

        if not details.max_num_clusters:
            details.max_num_clusters = 4

        if details.warehouse_type is None:
            details.warehouse_type = EndpointInfoWarehouseType.CLASSIC

        return details

    def create_or_update(
        self,
        warehouse_id: Optional[str] = None,
        name: Optional[str] = None,
        **warehouse_specs
    ):
        """Update the matching warehouse if one is found, otherwise create it.

        Args:
            warehouse_id: Id of the warehouse to update, when known.
            name: Warehouse name; defaults to ``self.warehouse_name``.
            **warehouse_specs: Warehouse specs forwarded to ``update`` / ``create``.

        Returns:
            The updated or newly created ``SQLWarehouse``.
        """
        name = name or self.warehouse_name
        found = self.find_warehouse(warehouse_id=warehouse_id, warehouse_name=name, raise_error=False)

        if found is not None:
            return found.update(name=name, **warehouse_specs)
        return self.create(name=name, **warehouse_specs)

    def create(
        self,
        name: Optional[str] = None,
        **warehouse_specs
    ):
        """Create a warehouse and wait for the SDK call to complete.

        Args:
            name: Warehouse name; defaults to ``self.warehouse_name`` and then
                to the name derived in ``_check_details``.
            **warehouse_specs: Specs; only names accepted by ``WarehousesAPI.create`` are sent.

        Returns:
            A new ``SQLWarehouse`` bound to the created warehouse.
        """
        name = name or self.warehouse_name

        details = self._check_details(
            keys=_CREATE_ARG_NAMES,
            name=name,
            **warehouse_specs
        )

        info = self.warehouse_client().create_and_wait(**{
            k: v
            for k, v in details.as_shallow_dict().items()
            if k in _CREATE_ARG_NAMES
        })

        return SQLWarehouse(
            workspace=self.workspace,
            warehouse_id=info.id,
            warehouse_name=info.name,
            _details=info
        )

    def update(
        self,
        **warehouse_specs
    ):
        """Edit this warehouse when the requested specs differ from its current details.

        Args:
            **warehouse_specs: Specs; only names accepted by ``WarehousesAPI.edit`` are compared and sent.

        Returns:
            ``self``.
        """
        if not warehouse_specs:
            return self

        current = self.details

        existing_details = {
            k: v
            for k, v in current.as_shallow_dict().items()
            if k in _EDIT_ARG_NAMES
        }

        checked = self._check_details(details=current, keys=_EDIT_ARG_NAMES, **warehouse_specs)

        update_details = {
            k: v
            for k, v in checked.as_shallow_dict().items()
            if k in _EDIT_ARG_NAMES
        }

        same = dicts_equal(
            existing_details,
            update_details,
            keys=_EDIT_ARG_NAMES,
            treat_missing_as_none=True,
            float_tol=0.0,  # set e.g. 1e-6 if you have float-y stuff
        )

        if not same:
            diff = {
                k: v[1]
                for k, v in dict_diff(existing_details, update_details, keys=_EDIT_ARG_NAMES).items()
            }

            LOGGER.debug(
                "Updating %s with %s",
                self, diff
            )

            self.warehouse_client().edit_and_wait(**update_details)

            # The cached details no longer describe the warehouse: drop them so
            # the next ``details`` access re-fetches, and track a rename.
            self._details = None
            self.warehouse_name = update_details.get("name", self.warehouse_name)

            LOGGER.info(
                "Updated %s",
                self
            )

        return self

    def sql(
        self,
        workspace: Optional[Workspace] = None,
        warehouse_id: Optional[str] = None,
        catalog_name: Optional[str] = None,
        schema_name: Optional[str] = None,
    ):
        """Return a SQLEngine configured for this workspace.

        Args:
            workspace: Optional workspace override.
            warehouse_id: Optional SQL warehouse id; defaults to this warehouse's id.
            catalog_name: Optional catalog name.
            schema_name: Optional schema name.

        Returns:
            A SQLEngine instance.
        """

        return self.workspace.sql(
            workspace=workspace,
            warehouse_id=warehouse_id or self.warehouse_id,
            catalog_name=catalog_name,
            schema_name=schema_name
        )
|
|
@@ -72,7 +72,7 @@ class Workspace:
|
|
|
72
72
|
"""Configuration wrapper for connecting to a Databricks workspace."""
|
|
73
73
|
# Databricks / generic
|
|
74
74
|
host: Optional[str] = None
|
|
75
|
-
account_id: Optional[str] = None
|
|
75
|
+
account_id: Optional[str] = dataclasses.field(default=None, repr=False)
|
|
76
76
|
token: Optional[str] = dataclasses.field(default=None, repr=False)
|
|
77
77
|
client_id: Optional[str] = dataclasses.field(default=None, repr=False)
|
|
78
78
|
client_secret: Optional[str] = dataclasses.field(default=None, repr=False)
|
|
@@ -220,7 +220,6 @@ class Workspace:
|
|
|
220
220
|
instance = self.clone_instance() if clone else self
|
|
221
221
|
|
|
222
222
|
require_databricks_sdk()
|
|
223
|
-
logger.debug("Connecting %s", self)
|
|
224
223
|
|
|
225
224
|
# Build Config from config_dict if available, else from fields.
|
|
226
225
|
kwargs = {
|
|
@@ -291,8 +290,6 @@ class Workspace:
|
|
|
291
290
|
if v is not None:
|
|
292
291
|
setattr(instance, key, v)
|
|
293
292
|
|
|
294
|
-
logger.info("Connected %s", instance)
|
|
295
|
-
|
|
296
293
|
return instance
|
|
297
294
|
|
|
298
295
|
# ------------------------------------------------------------------ #
|
|
@@ -570,6 +567,7 @@ class Workspace:
|
|
|
570
567
|
("Product", self.product),
|
|
571
568
|
("ProductVersion", self.product_version),
|
|
572
569
|
("ProductTag", self.product_tag),
|
|
570
|
+
("ProductUser", self.current_user.user_name)
|
|
573
571
|
)
|
|
574
572
|
if v
|
|
575
573
|
}
|
|
@@ -589,17 +587,17 @@ class Workspace:
|
|
|
589
587
|
def sql(
|
|
590
588
|
self,
|
|
591
589
|
workspace: Optional["Workspace"] = None,
|
|
590
|
+
warehouse_id: Optional[str] = None,
|
|
592
591
|
catalog_name: Optional[str] = None,
|
|
593
592
|
schema_name: Optional[str] = None,
|
|
594
|
-
**kwargs
|
|
595
593
|
):
|
|
596
594
|
"""Return a SQLEngine configured for this workspace.
|
|
597
595
|
|
|
598
596
|
Args:
|
|
599
597
|
workspace: Optional workspace override.
|
|
598
|
+
warehouse_id: Optional SQL warehouse id.
|
|
600
599
|
catalog_name: Optional catalog name.
|
|
601
600
|
schema_name: Optional schema name.
|
|
602
|
-
**kwargs: Additional SQLEngine parameters.
|
|
603
601
|
|
|
604
602
|
Returns:
|
|
605
603
|
A SQLEngine instance.
|
|
@@ -608,16 +606,29 @@ class Workspace:
|
|
|
608
606
|
|
|
609
607
|
return SQLEngine(
|
|
610
608
|
workspace=self if workspace is None else workspace,
|
|
609
|
+
warehouse_id=warehouse_id,
|
|
611
610
|
catalog_name=catalog_name,
|
|
612
611
|
schema_name=schema_name,
|
|
613
|
-
|
|
612
|
+
)
|
|
613
|
+
|
|
614
|
+
def warehouses(
    self,
    workspace: Optional["Workspace"] = None,
    warehouse_id: Optional[str] = None,
    warehouse_name: Optional[str] = None,
):
    """Return a SQLWarehouse helper bound to a workspace.

    Args:
        workspace: Optional workspace override; this workspace is used when omitted.
        warehouse_id: Optional SQL warehouse id.
        warehouse_name: Optional SQL warehouse name.

    Returns:
        A SQLWarehouse instance.
    """
    # Imported inside the method rather than at module top — presumably to
    # avoid a circular import with the warehouse module; TODO confirm.
    from ..sql.warehouse import SQLWarehouse

    return SQLWarehouse(
        workspace=self if workspace is None else workspace,
        warehouse_id=warehouse_id,
        warehouse_name=warehouse_name
    )
|
|
615
627
|
|
|
616
628
|
def clusters(
|
|
617
629
|
self,
|
|
618
630
|
cluster_id: Optional[str] = None,
|
|
619
631
|
cluster_name: Optional[str] = None,
|
|
620
|
-
**kwargs
|
|
621
632
|
) -> "Cluster":
|
|
622
633
|
"""Return a Cluster helper bound to this workspace.
|
|
623
634
|
|
|
@@ -635,7 +646,6 @@ class Workspace:
|
|
|
635
646
|
workspace=self,
|
|
636
647
|
cluster_id=cluster_id,
|
|
637
648
|
cluster_name=cluster_name,
|
|
638
|
-
**kwargs
|
|
639
649
|
)
|
|
640
650
|
|
|
641
651
|
# ---------------------------------------------------------------------------
|