ygg 0.1.44__py3-none-any.whl → 0.1.45__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {ygg-0.1.44.dist-info → ygg-0.1.45.dist-info}/METADATA +1 -1
- {ygg-0.1.44.dist-info → ygg-0.1.45.dist-info}/RECORD +14 -13
- yggdrasil/databricks/compute/cluster.py +20 -16
- yggdrasil/databricks/compute/execution_context.py +35 -50
- yggdrasil/databricks/sql/engine.py +5 -2
- yggdrasil/databricks/sql/warehouse.py +355 -0
- yggdrasil/databricks/workspaces/workspace.py +19 -6
- yggdrasil/pyutils/callable_serde.py +183 -281
- yggdrasil/pyutils/expiring_dict.py +114 -25
- yggdrasil/version.py +1 -1
- {ygg-0.1.44.dist-info → ygg-0.1.45.dist-info}/WHEEL +0 -0
- {ygg-0.1.44.dist-info → ygg-0.1.45.dist-info}/entry_points.txt +0 -0
- {ygg-0.1.44.dist-info → ygg-0.1.45.dist-info}/licenses/LICENSE +0 -0
- {ygg-0.1.44.dist-info → ygg-0.1.45.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,355 @@
|
|
|
1
|
+
import dataclasses as dc
|
|
2
|
+
import inspect
|
|
3
|
+
import logging
|
|
4
|
+
from typing import Optional, Sequence
|
|
5
|
+
|
|
6
|
+
from ..workspaces import WorkspaceService, Workspace
|
|
7
|
+
from ...pyutils.equality import dicts_equal, dict_diff
|
|
8
|
+
from ...pyutils.expiring_dict import ExpiringDict
|
|
9
|
+
|
|
10
|
+
try:
    from databricks.sdk import WarehousesAPI
    from databricks.sdk.service.sql import (
        State, EndpointInfo, EndpointTags, EndpointTagPair, EndpointInfoWarehouseType
    )

    # Keyword names accepted by WarehousesAPI.create / WarehousesAPI.edit; used to
    # filter warehouse specs down to what each SDK call accepts. inspect.signature
    # on the unbound method also reports "self", which is never a valid keyword.
    _CREATE_ARG_NAMES = set(inspect.signature(WarehousesAPI.create).parameters) - {"self"}
    _EDIT_ARG_NAMES = set(inspect.signature(WarehousesAPI.edit).parameters) - {"self"}
except ImportError:
    # databricks-sdk is not installed: define placeholder names so this module
    # can still be imported. The stubs carry no attributes, so SQLWarehouse
    # operations are not expected to work in this mode.
    class WarehousesAPI:
        pass

    class State:
        pass

    class EndpointInfo:
        pass

    class EndpointTags:
        pass

    class EndpointTagPair:
        pass

    class EndpointInfoWarehouseType:
        pass

    # Defined in both branches so later references never raise NameError.
    _CREATE_ARG_NAMES = set()
    _EDIT_ARG_NAMES = set()
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
__all__ = ["SQLWarehouse"]

LOGGER = logging.getLogger(__name__)

# Warehouse name -> id lookups, one ExpiringDict per workspace host.
# Entries are created lazily by set_cached_warehouse_name.
NAME_ID_CACHE: dict[str, ExpiringDict] = {}
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def set_cached_warehouse_name(
    host: str,
    warehouse_name: str,
    warehouse_id: str
) -> None:
    """Record that *warehouse_name* maps to *warehouse_id* on *host*.

    The per-host ExpiringDict is created on demand with ``default_ttl=60``
    (presumably seconds; see ExpiringDict). A missing or empty per-host cache
    is replaced by a fresh one.
    """
    host_cache = NAME_ID_CACHE.get(host)

    if not host_cache:
        host_cache = ExpiringDict(default_ttl=60)
        NAME_ID_CACHE[host] = host_cache

    host_cache[warehouse_name] = warehouse_id
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def get_cached_warehouse_id(
    host: str,
    warehouse_name: str,
) -> Optional[str]:
    """Return the cached warehouse id for *warehouse_name* on *host*.

    Returns:
        The id recorded by set_cached_warehouse_name, or None when the host has
        no (or an empty) cache or the name is not present.
    """
    existing = NAME_ID_CACHE.get(host)

    return existing.get(warehouse_name) if existing else None
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
@dc.dataclass
class SQLWarehouse(WorkspaceService):
    """Helper around a Databricks SQL warehouse bound to a workspace.

    A warehouse is identified by ``warehouse_id`` and/or ``warehouse_name``.
    ``_details`` caches the last fetched warehouse description.
    """

    warehouse_id: Optional[str] = None
    warehouse_name: Optional[str] = None

    # Last fetched warehouse description (from list(), get() or
    # create_and_wait()); loaded lazily by the `details` property.
    _details: Optional[EndpointInfo] = dc.field(default=None, repr=False)

    def warehouse_client(self):
        """Return the SDK ``warehouses`` API of the bound workspace."""
        return self.workspace.sdk().warehouses

    def default(
        self,
        name: str = "YGG-DEFAULT",
        **kwargs
    ):
        """Create or update the warehouse called *name* with the given specs."""
        return self.create_or_update(
            name=name,
            **kwargs
        )

    @property
    def details(self) -> EndpointInfo:
        """Cached warehouse description; fetched via refresh() on first access."""
        if self._details is None:
            self.refresh()
        return self._details

    @details.setter
    def details(self, value: EndpointInfo):
        # Keep the identifying fields in sync with the stored description.
        self._details = value

        self.warehouse_id = value.id
        self.warehouse_name = value.name

    def latest_details(self):
        """Fetch the current warehouse description from the API (uncached)."""
        return self.warehouse_client().get(id=self.warehouse_id)

    def refresh(self):
        """Re-fetch and cache the warehouse description. Returns self."""
        self.details = self.latest_details()
        return self

    @property
    def state(self):
        """Current warehouse state, fetched from the API on every access."""
        return self.latest_details().state

    @property
    def running(self):
        """True when the warehouse state is RUNNING."""
        return self.state in {State.RUNNING}

    @property
    def pending(self):
        """True while the warehouse is DELETING, STARTING or STOPPING."""
        return self.state in {State.DELETING, State.STARTING, State.STOPPING}

    def start(self):
        """Issue a start request unless the warehouse is running. Returns self."""
        if not self.running:
            self.warehouse_client().start(id=self.warehouse_id)
        return self

    def stop(self):
        """Issue a stop request if the warehouse is running. Returns self."""
        # Always return self (previously the SDK result leaked out when running),
        # consistent with start().
        if self.running:
            self.warehouse_client().stop(id=self.warehouse_id)
        return self

    def find_warehouse(
        self,
        warehouse_id: Optional[str] = None,
        warehouse_name: Optional[str] = None,
        raise_error: bool = True
    ):
        """Resolve a warehouse by id or name.

        Resolution order:
          1. an explicit *warehouse_id*;
          2. this instance, when it has an id and the requested name is absent,
             matches, or this instance's name is unknown;
          3. the per-host name -> id cache;
          4. a scan of list_warehouses(), which populates the cache on a hit.

        Args:
            warehouse_id: Id to wrap directly (not validated against the API).
            warehouse_name: Name to look up; defaults to this instance's name.
            raise_error: Raise instead of returning None when nothing matches.

        Returns:
            A SQLWarehouse, or None if not found and *raise_error* is false.

        Raises:
            ValueError: if no warehouse matches and *raise_error* is true.
        """
        if warehouse_id:
            return SQLWarehouse(
                workspace=self.workspace,
                warehouse_id=warehouse_id,
                warehouse_name=warehouse_name
            )

        # Reuse self only when it can be the requested warehouse: a known,
        # different name must not resolve to this instance.
        if self.warehouse_id and (
            not warehouse_name
            or self.warehouse_name is None
            or self.warehouse_name == warehouse_name
        ):
            return self

        warehouse_name = warehouse_name or self.warehouse_name

        warehouse_id = get_cached_warehouse_id(host=self.workspace.host, warehouse_name=warehouse_name)

        if warehouse_id:
            return SQLWarehouse(
                workspace=self.workspace,
                warehouse_id=warehouse_id,
                warehouse_name=warehouse_name
            )

        for warehouse in self.list_warehouses():
            if warehouse.warehouse_name == warehouse_name:
                set_cached_warehouse_name(host=self.workspace.host, warehouse_name=warehouse_name, warehouse_id=warehouse.warehouse_id)
                return warehouse

        if raise_error:
            raise ValueError(
                f"SQL Warehouse {warehouse_name!r} not found"
            )
        return None

    def list_warehouses(self):
        """Yield a SQLWarehouse (details pre-populated) for each listed warehouse."""
        for info in self.warehouse_client().list():
            warehouse = SQLWarehouse(
                workspace=self.workspace,
                warehouse_id=info.id,
                warehouse_name=info.name,
                _details=info
            )

            yield warehouse

    def _merge_default_tags(self, tags):
        """Return *tags* merged with the workspace default tags as EndpointTags.

        *tags* may be None, an EndpointTags (``custom_tags`` may be None), a
        dict of key -> value, or an iterable of EndpointTagPair. Workspace
        default tags take precedence over caller-provided ones. Existing key
        order is preserved, and new default keys are appended.
        """
        if tags is None:
            merged = {}
        elif isinstance(tags, EndpointTags):
            merged = {pair.key: pair.value for pair in (tags.custom_tags or [])}
        elif isinstance(tags, dict):
            merged = dict(tags)
        else:
            merged = {pair.key: pair.value for pair in tags}

        merged.update(self.workspace.default_tags())

        return EndpointTags(custom_tags=[
            EndpointTagPair(key=k, value=v)
            for k, v in merged.items()
        ])

    def _check_details(
        self,
        keys: Sequence[str],
        details: Optional[EndpointInfo] = None,
        **warehouse_specs
    ):
        """Build a new EndpointInfo from *details* overlaid with *warehouse_specs*.

        Only fields named in *keys* are kept. Missing values get defaults:
        cluster_size "Small", name "YGG-<CLUSTER_SIZE>", max_num_clusters 4 (also
        when falsy), and warehouse_type CLASSIC. Tags are merged with the
        workspace default tags. *details* itself is not modified; fields are
        reassigned on a fresh EndpointInfo.
        """
        base = {} if details is None else details.as_shallow_dict()
        merged_specs = {**base, **warehouse_specs}

        result = EndpointInfo(**{
            k: v
            for k, v in merged_specs.items()
            if k in keys
        })

        if result.cluster_size is None:
            result.cluster_size = "Small"

        if result.name is None:
            result.name = "YGG-%s" % result.cluster_size.upper()

        # Previously the merged tags were computed and then discarded, and
        # non-EndpointTags values crashed before being handled.
        result.tags = self._merge_default_tags(result.tags)

        if not result.max_num_clusters:
            result.max_num_clusters = 4

        if result.warehouse_type is None:
            result.warehouse_type = EndpointInfoWarehouseType.CLASSIC

        return result

    def create_or_update(
        self,
        warehouse_id: Optional[str] = None,
        name: Optional[str] = None,
        **warehouse_specs
    ):
        """Update the warehouse resolved by id/name if it exists, else create it.

        Args:
            warehouse_id: Optional id of an existing warehouse.
            name: Warehouse name; defaults to this instance's name.
            **warehouse_specs: Fields forwarded to update() / create().
        """
        name = name or self.warehouse_name
        found = self.find_warehouse(warehouse_id=warehouse_id, warehouse_name=name, raise_error=False)

        if found is not None:
            # Forward the name only when known: update(name=None) would clear
            # the existing name, and _check_details would then replace it with
            # a generated "YGG-<SIZE>" name.
            if name:
                warehouse_specs["name"] = name
            return found.update(**warehouse_specs)
        return self.create(name=name, **warehouse_specs)

    def create(
        self,
        name: Optional[str] = None,
        **warehouse_specs
    ):
        """Create a warehouse via create_and_wait() and return a helper for it."""
        name = name or self.warehouse_name

        details = self._check_details(
            keys=_CREATE_ARG_NAMES,
            name=name,
            **warehouse_specs
        )

        info = self.warehouse_client().create_and_wait(**{
            k: v
            for k, v in details.as_shallow_dict().items()
            if k in _CREATE_ARG_NAMES
        })

        return SQLWarehouse(
            workspace=self.workspace,
            warehouse_id=info.id,
            warehouse_name=info.name,
            _details=info
        )

    def update(
        self,
        **warehouse_specs
    ):
        """Edit this warehouse when *warehouse_specs* change its configuration.

        The target configuration is the cached details overlaid with the specs
        (see _check_details). The edit is skipped when it equals the current
        configuration, compared on WarehousesAPI.edit arguments only. After an
        edit, the cached details are re-fetched. Returns self.
        """
        if not warehouse_specs:
            return self

        current = self.details

        existing_details = {
            k: v
            for k, v in current.as_shallow_dict().items()
            if k in _EDIT_ARG_NAMES
        }

        target = self._check_details(details=current, keys=_EDIT_ARG_NAMES, **warehouse_specs)
        update_details = {
            k: v
            for k, v in target.as_shallow_dict().items()
            if k in _EDIT_ARG_NAMES
        }

        same = dicts_equal(
            existing_details,
            update_details,
            keys=_EDIT_ARG_NAMES,
            treat_missing_as_none=True,
            float_tol=0.0,  # set e.g. 1e-6 if you have float-y stuff
        )

        if not same:
            diff = {
                k: v[1]
                for k, v in dict_diff(existing_details, update_details, keys=_EDIT_ARG_NAMES).items()
            }

            LOGGER.debug(
                "Updating %s with %s",
                self, diff
            )

            self.warehouse_client().edit_and_wait(**update_details)

            # The cached description is now stale (and the name may have
            # changed); re-fetch so details / warehouse_name reflect the edit.
            self.refresh()

            LOGGER.info(
                "Updated %s",
                self
            )

        return self

    def sql(
        self,
        workspace: Optional[Workspace] = None,
        warehouse_id: Optional[str] = None,
        catalog_name: Optional[str] = None,
        schema_name: Optional[str] = None,
    ):
        """Return a SQLEngine configured for this workspace.

        Args:
            workspace: Optional workspace override.
            warehouse_id: Optional SQL warehouse id; defaults to this warehouse's id.
            catalog_name: Optional catalog name.
            schema_name: Optional schema name.

        Returns:
            A SQLEngine instance.
        """

        return self.workspace.sql(
            workspace=workspace,
            warehouse_id=warehouse_id or self.warehouse_id,
            catalog_name=catalog_name,
            schema_name=schema_name
        )
|
|
@@ -72,7 +72,7 @@ class Workspace:
|
|
|
72
72
|
"""Configuration wrapper for connecting to a Databricks workspace."""
|
|
73
73
|
# Databricks / generic
|
|
74
74
|
host: Optional[str] = None
|
|
75
|
-
account_id: Optional[str] = None
|
|
75
|
+
account_id: Optional[str] = dataclasses.field(default=None, repr=False)
|
|
76
76
|
token: Optional[str] = dataclasses.field(default=None, repr=False)
|
|
77
77
|
client_id: Optional[str] = dataclasses.field(default=None, repr=False)
|
|
78
78
|
client_secret: Optional[str] = dataclasses.field(default=None, repr=False)
|
|
@@ -570,6 +570,7 @@ class Workspace:
|
|
|
570
570
|
("Product", self.product),
|
|
571
571
|
("ProductVersion", self.product_version),
|
|
572
572
|
("ProductTag", self.product_tag),
|
|
573
|
+
("ProductUser", self.current_user.user_name)
|
|
573
574
|
)
|
|
574
575
|
if v
|
|
575
576
|
}
|
|
@@ -589,17 +590,17 @@ class Workspace:
|
|
|
589
590
|
def sql(
|
|
590
591
|
self,
|
|
591
592
|
workspace: Optional["Workspace"] = None,
|
|
593
|
+
warehouse_id: Optional[str] = None,
|
|
592
594
|
catalog_name: Optional[str] = None,
|
|
593
595
|
schema_name: Optional[str] = None,
|
|
594
|
-
**kwargs
|
|
595
596
|
):
|
|
596
597
|
"""Return a SQLEngine configured for this workspace.
|
|
597
598
|
|
|
598
599
|
Args:
|
|
599
600
|
workspace: Optional workspace override.
|
|
601
|
+
warehouse_id: Optional SQL warehouse id.
|
|
600
602
|
catalog_name: Optional catalog name.
|
|
601
603
|
schema_name: Optional schema name.
|
|
602
|
-
**kwargs: Additional SQLEngine parameters.
|
|
603
604
|
|
|
604
605
|
Returns:
|
|
605
606
|
A SQLEngine instance.
|
|
@@ -608,16 +609,29 @@ class Workspace:
|
|
|
608
609
|
|
|
609
610
|
return SQLEngine(
|
|
610
611
|
workspace=self if workspace is None else workspace,
|
|
612
|
+
warehouse_id=warehouse_id,
|
|
611
613
|
catalog_name=catalog_name,
|
|
612
614
|
schema_name=schema_name,
|
|
613
|
-
|
|
615
|
+
)
|
|
616
|
+
|
|
617
|
+
def warehouses(
    self,
    workspace: Optional["Workspace"] = None,
    warehouse_id: Optional[str] = None,
    warehouse_name: Optional[str] = None,
) -> "SQLWarehouse":
    """Return a SQLWarehouse helper bound to this workspace.

    Args:
        workspace: Optional workspace override.
        warehouse_id: Optional SQL warehouse id.
        warehouse_name: Optional SQL warehouse name.

    Returns:
        A SQLWarehouse instance.
    """
    # Function-scope import (presumably to avoid an import cycle between the
    # workspaces and sql packages — confirm).
    from ..sql.warehouse import SQLWarehouse

    return SQLWarehouse(
        workspace=self if workspace is None else workspace,
        warehouse_id=warehouse_id,
        warehouse_name=warehouse_name
    )
|
|
615
630
|
|
|
616
631
|
def clusters(
|
|
617
632
|
self,
|
|
618
633
|
cluster_id: Optional[str] = None,
|
|
619
634
|
cluster_name: Optional[str] = None,
|
|
620
|
-
**kwargs
|
|
621
635
|
) -> "Cluster":
|
|
622
636
|
"""Return a Cluster helper bound to this workspace.
|
|
623
637
|
|
|
@@ -635,7 +649,6 @@ class Workspace:
|
|
|
635
649
|
workspace=self,
|
|
636
650
|
cluster_id=cluster_id,
|
|
637
651
|
cluster_name=cluster_name,
|
|
638
|
-
**kwargs
|
|
639
652
|
)
|
|
640
653
|
|
|
641
654
|
# ---------------------------------------------------------------------------
|