ygg 0.1.44__py3-none-any.whl → 0.1.45__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,355 @@
1
+ import dataclasses as dc
2
+ import inspect
3
+ import logging
4
+ from typing import Optional, Sequence
5
+
6
+ from ..workspaces import WorkspaceService, Workspace
7
+ from ...pyutils.equality import dicts_equal, dict_diff
8
+ from ...pyutils.expiring_dict import ExpiringDict
9
+
10
try:
    from databricks.sdk import WarehousesAPI
    from databricks.sdk.service.sql import (
        State, EndpointInfo, EndpointTags, EndpointTagPair, EndpointInfoWarehouseType
    )

    # Parameter names accepted by the SDK create/edit calls; they are used
    # further down to filter keyword arguments before calling the SDK.
    _CREATE_ARG_NAMES = set(inspect.signature(WarehousesAPI.create).parameters)
    _EDIT_ARG_NAMES = set(inspect.signature(WarehousesAPI.edit).parameters)
except ImportError:
    # The Databricks SDK is optional: define empty placeholders so this module
    # can still be imported (and its annotations evaluated) without it.
    class WarehousesAPI:
        pass

    class State:
        pass

    class EndpointInfo:
        pass

    class EndpointTags:
        pass

    class EndpointTagPair:
        pass

    class EndpointInfoWarehouseType:
        pass

    # Keep the module-level names defined on the fallback path as well, so code
    # referencing them does not fail with a NameError when the SDK is missing.
    _CREATE_ARG_NAMES = set()
    _EDIT_ARG_NAMES = set()
36
+
37
+
38
+ __all__ = [
39
+ "SQLWarehouse"
40
+ ]
41
+
42
+
43
+ LOGGER = logging.getLogger(__name__)
44
+ NAME_ID_CACHE: dict[str, ExpiringDict] = {}
45
+
46
+
47
def set_cached_warehouse_name(
    host: str,
    warehouse_name: str,
    warehouse_id: str
) -> None:
    """Record the ``warehouse_name`` -> ``warehouse_id`` mapping for ``host``.

    Mappings are stored per host in the module-level ``NAME_ID_CACHE``. A new
    ``ExpiringDict`` is created when the host has no usable (truthy) cache yet.

    Args:
        host: Workspace host used as the outer cache key.
        warehouse_name: Warehouse name used as the inner cache key.
        warehouse_id: Warehouse id stored for that name.
    """
    per_host = NAME_ID_CACHE.get(host)

    if not per_host:
        # NOTE(review): default_ttl=60 is presumably in seconds - confirm
        # against ExpiringDict.
        per_host = ExpiringDict(default_ttl=60)
        NAME_ID_CACHE[host] = per_host

    per_host[warehouse_name] = warehouse_id
58
+
59
+
60
def get_cached_warehouse_id(
    host: str,
    warehouse_name: str,
) -> Optional[str]:
    """Return the cached warehouse id for ``warehouse_name`` on ``host``.

    Args:
        host: Workspace host used as the outer cache key.
        warehouse_name: Warehouse name to look up.

    Returns:
        The cached warehouse id, or ``None`` when the host has no (non-empty)
        cache or the name is not present in it.
    """
    existing = NAME_ID_CACHE.get(host)

    # A missing or empty per-host cache means there is nothing to look up.
    return existing.get(warehouse_name) if existing else None
67
+
68
+
69
@dc.dataclass
class SQLWarehouse(WorkspaceService):
    """Workspace-bound helper to find, create, edit and query a SQL warehouse.

    The warehouse is addressed by ``warehouse_id`` and/or ``warehouse_name``.
    ``_details`` holds the last description obtained from the SDK and is
    excluded from ``repr``.
    """

    # Identifier of the targeted warehouse (may be unset).
    warehouse_id: Optional[str] = None
    # Name of the targeted warehouse; used for name-based lookups.
    warehouse_name: Optional[str] = None

    # Cached SDK description; populated lazily through the ``details`` property.
    _details: Optional[EndpointInfo] = dc.field(default=None, repr=False)

    def warehouse_client(self):
        """Return the ``warehouses`` API of the workspace SDK client."""
        return self.workspace.sdk().warehouses

    def default(
        self,
        name: str = "YGG-DEFAULT",
        **kwargs
    ):
        """Create or update the warehouse called ``name`` and return it.

        Args:
            name: Warehouse name, ``"YGG-DEFAULT"`` unless overridden.
            **kwargs: Warehouse specs forwarded to ``create_or_update``.
        """
        return self.create_or_update(
            name=name,
            **kwargs
        )

    @property
    def details(self) -> EndpointInfo:
        """Cached warehouse description, fetched on first access."""
        if self._details is None:
            self.refresh()
        return self._details

    def latest_details(self):
        """Fetch the current description of ``warehouse_id`` from the SDK."""
        return self.warehouse_client().get(id=self.warehouse_id)

    def refresh(self):
        """Reload the cached description from the SDK and return ``self``."""
        self.details = self.latest_details()
        return self

    @details.setter
    def details(self, value: EndpointInfo):
        """Store ``value`` and align ``warehouse_id``/``warehouse_name`` with it."""
        self._details = value

        self.warehouse_id = value.id
        self.warehouse_name = value.name

    @property
    def state(self):
        """Current state, read from a fresh SDK call on every access."""
        return self.latest_details().state

    @property
    def running(self):
        """Whether the current state is ``State.RUNNING``."""
        return self.state in {State.RUNNING}

    @property
    def pending(self):
        """Whether the warehouse is deleting, starting or stopping."""
        return self.state in {State.DELETING, State.STARTING, State.STOPPING}

    def start(self):
        """Request a start unless the warehouse is already running; return ``self``."""
        if not self.running:
            self.warehouse_client().start(id=self.warehouse_id)
        return self

    def stop(self):
        """Request a stop if the warehouse is running; return ``self``.

        Always returns ``self`` (like ``start``) instead of returning the SDK
        wait object only on the branch where a stop was actually requested.
        """
        if self.running:
            self.warehouse_client().stop(id=self.warehouse_id)
        return self

    def find_warehouse(
        self,
        warehouse_id: Optional[str] = None,
        warehouse_name: Optional[str] = None,
        raise_error: bool = True
    ):
        """Resolve a warehouse helper by id or by name.

        Resolution order: an explicit ``warehouse_id``; this instance when it
        already has an id; the name -> id cache; finally a scan of the
        workspace's warehouses (a match is written to the cache).

        Args:
            warehouse_id: Explicit warehouse id; returned without any lookup.
            warehouse_name: Name to search for; defaults to ``self.warehouse_name``.
            raise_error: Raise when nothing matches instead of returning ``None``.

        Returns:
            A ``SQLWarehouse``, or ``None`` when not found and ``raise_error``
            is false.

        Raises:
            ValueError: When no warehouse matches and ``raise_error`` is true.
        """
        if warehouse_id:
            return SQLWarehouse(
                workspace=self.workspace,
                warehouse_id=warehouse_id,
                warehouse_name=warehouse_name
            )

        if self.warehouse_id:
            return self

        warehouse_name = warehouse_name or self.warehouse_name

        # Without a name there is nothing to match against, so skip both the
        # cache lookup and the remote listing.
        if warehouse_name:
            host = self.workspace.host
            cached_id = get_cached_warehouse_id(host=host, warehouse_name=warehouse_name)

            if cached_id:
                return SQLWarehouse(
                    workspace=self.workspace,
                    warehouse_id=cached_id,
                    warehouse_name=warehouse_name
                )

            for warehouse in self.list_warehouses():
                if warehouse.warehouse_name == warehouse_name:
                    set_cached_warehouse_name(
                        host=host,
                        warehouse_name=warehouse_name,
                        warehouse_id=warehouse.warehouse_id
                    )
                    return warehouse

        if raise_error:
            raise ValueError(
                f"SQL Warehouse {warehouse_name!r} not found"
            )
        return None

    def list_warehouses(self):
        """Yield a ``SQLWarehouse`` (with details attached) per listed warehouse."""
        for info in self.warehouse_client().list():
            yield SQLWarehouse(
                workspace=self.workspace,
                warehouse_id=info.id,
                warehouse_name=info.name,
                _details=info
            )

    def _check_details(
        self,
        keys: Sequence[str],
        details: Optional[EndpointInfo] = None,
        **warehouse_specs
    ):
        """Build an ``EndpointInfo`` from specs and fill in defaults.

        Args:
            keys: Names allowed through to ``EndpointInfo`` (callers pass the
                SDK create/edit parameter names).
            details: Optional existing description used as the base; explicit
                specs override its values.
            **warehouse_specs: Requested warehouse settings.

        Returns:
            A new ``EndpointInfo``; the ``details`` argument is not modified.
        """
        if details is None:
            candidate = warehouse_specs
        else:
            candidate = {
                **details.as_shallow_dict(),
                **warehouse_specs
            }

        details = EndpointInfo(**{
            k: v
            for k, v in candidate.items()
            if k in keys
        })

        if details.cluster_size is None:
            details.cluster_size = "Small"

        if details.name is None:
            details.name = "YGG-%s" % details.cluster_size.upper()

        # Merge requested/existing tags with the workspace defaults (defaults
        # win on key clashes). A new EndpointTags is always assigned because
        # the shallow dict above may share the tags object with the caller's
        # details, which must not be mutated.
        merged_tags = {}
        current_tags = details.tags

        if isinstance(current_tags, dict):
            merged_tags.update(current_tags)
        elif current_tags is not None:
            # custom_tags may be absent or None on the tags object.
            for pair in getattr(current_tags, "custom_tags", None) or ():
                merged_tags[pair.key] = pair.value

        merged_tags.update(self.workspace.default_tags())

        details.tags = EndpointTags(custom_tags=[
            EndpointTagPair(key=k, value=v)
            for k, v in merged_tags.items()
        ])

        if not details.max_num_clusters:
            details.max_num_clusters = 4

        if details.warehouse_type is None:
            details.warehouse_type = EndpointInfoWarehouseType.CLASSIC

        return details

    def create_or_update(
        self,
        warehouse_id: Optional[str] = None,
        name: Optional[str] = None,
        **warehouse_specs
    ):
        """Update the matching warehouse if one is found, otherwise create it.

        Args:
            warehouse_id: Optional id of the warehouse to update.
            name: Warehouse name; defaults to ``self.warehouse_name``.
            **warehouse_specs: Settings forwarded to ``update`` or ``create``.
        """
        name = name or self.warehouse_name
        found = self.find_warehouse(warehouse_id=warehouse_id, warehouse_name=name, raise_error=False)

        if found is not None:
            return found.update(name=name, **warehouse_specs)
        return self.create(name=name, **warehouse_specs)

    def create(
        self,
        name: Optional[str] = None,
        **warehouse_specs
    ):
        """Create a warehouse through the SDK and return a helper for it.

        Args:
            name: Warehouse name; defaults to ``self.warehouse_name``.
            **warehouse_specs: Settings; only names accepted by the SDK create
                call are forwarded.
        """
        name = name or self.warehouse_name

        details = self._check_details(
            keys=_CREATE_ARG_NAMES,
            name=name,
            **warehouse_specs
        )

        info = self.warehouse_client().create_and_wait(**{
            k: v
            for k, v in details.as_shallow_dict().items()
            if k in _CREATE_ARG_NAMES
        })

        return SQLWarehouse(
            workspace=self.workspace,
            warehouse_id=info.id,
            warehouse_name=info.name,
            _details=info
        )

    def update(
        self,
        **warehouse_specs
    ):
        """Edit the warehouse when the requested specs differ from its details.

        Args:
            **warehouse_specs: Settings to apply; nothing happens when empty.

        Returns:
            ``self``.
        """
        if not warehouse_specs:
            return self

        existing_details = {
            k: v
            for k, v in self.details.as_shallow_dict().items()
            if k in _EDIT_ARG_NAMES
        }

        checked = self._check_details(details=self.details, keys=_EDIT_ARG_NAMES, **warehouse_specs)

        update_details = {
            k: v
            for k, v in checked.as_shallow_dict().items()
            if k in _EDIT_ARG_NAMES
        }

        same = dicts_equal(
            existing_details,
            update_details,
            keys=_EDIT_ARG_NAMES,
            treat_missing_as_none=True,
            float_tol=0.0,  # set e.g. 1e-6 if you have float-y stuff
        )

        if not same:
            diff = {
                k: v[1]
                for k, v in dict_diff(existing_details, update_details, keys=_EDIT_ARG_NAMES).items()
            }

            LOGGER.debug(
                "Updating %s with %s",
                self, diff
            )

            self.warehouse_client().edit_and_wait(**update_details)

            LOGGER.info(
                "Updated %s",
                self
            )

            # The cached description predates the edit: drop it so the next
            # ``details`` access refetches, and keep the local name in sync.
            self._details = None

            if update_details.get("name"):
                self.warehouse_name = update_details["name"]

        return self

    def sql(
        self,
        workspace: Optional[Workspace] = None,
        warehouse_id: Optional[str] = None,
        catalog_name: Optional[str] = None,
        schema_name: Optional[str] = None,
    ):
        """Return a SQLEngine configured for this workspace.

        Args:
            workspace: Optional workspace override.
            warehouse_id: Optional SQL warehouse id; defaults to this
                warehouse's id.
            catalog_name: Optional catalog name.
            schema_name: Optional schema name.

        Returns:
            A SQLEngine instance.
        """

        return self.workspace.sql(
            workspace=workspace,
            warehouse_id=warehouse_id or self.warehouse_id,
            catalog_name=catalog_name,
            schema_name=schema_name
        )
@@ -72,7 +72,7 @@ class Workspace:
72
72
  """Configuration wrapper for connecting to a Databricks workspace."""
73
73
  # Databricks / generic
74
74
  host: Optional[str] = None
75
- account_id: Optional[str] = None
75
+ account_id: Optional[str] = dataclasses.field(default=None, repr=False)
76
76
  token: Optional[str] = dataclasses.field(default=None, repr=False)
77
77
  client_id: Optional[str] = dataclasses.field(default=None, repr=False)
78
78
  client_secret: Optional[str] = dataclasses.field(default=None, repr=False)
@@ -570,6 +570,7 @@ class Workspace:
570
570
  ("Product", self.product),
571
571
  ("ProductVersion", self.product_version),
572
572
  ("ProductTag", self.product_tag),
573
+ ("ProductUser", self.current_user.user_name)
573
574
  )
574
575
  if v
575
576
  }
@@ -589,17 +590,17 @@ class Workspace:
589
590
  def sql(
590
591
  self,
591
592
  workspace: Optional["Workspace"] = None,
593
+ warehouse_id: Optional[str] = None,
592
594
  catalog_name: Optional[str] = None,
593
595
  schema_name: Optional[str] = None,
594
- **kwargs
595
596
  ):
596
597
  """Return a SQLEngine configured for this workspace.
597
598
 
598
599
  Args:
599
600
  workspace: Optional workspace override.
601
+ warehouse_id: Optional SQL warehouse id.
600
602
  catalog_name: Optional catalog name.
601
603
  schema_name: Optional schema name.
602
- **kwargs: Additional SQLEngine parameters.
603
604
 
604
605
  Returns:
605
606
  A SQLEngine instance.
@@ -608,16 +609,29 @@ class Workspace:
608
609
 
609
610
  return SQLEngine(
610
611
  workspace=self if workspace is None else workspace,
612
+ warehouse_id=warehouse_id,
611
613
  catalog_name=catalog_name,
612
614
  schema_name=schema_name,
613
- **kwargs
615
+ )
616
+
617
def warehouses(
    self,
    workspace: Optional["Workspace"] = None,
    warehouse_id: Optional[str] = None,
    warehouse_name: Optional[str] = None,
):
    """Return a SQLWarehouse helper bound to a workspace.

    Args:
        workspace: Optional workspace override; this workspace is used when omitted.
        warehouse_id: Optional SQL warehouse id.
        warehouse_name: Optional SQL warehouse name.

    Returns:
        A SQLWarehouse instance.
    """
    # Imported at call time; presumably to avoid a circular import with the
    # warehouse module - TODO confirm.
    from ..sql.warehouse import SQLWarehouse

    target = workspace if workspace is not None else self

    return SQLWarehouse(
        workspace=target,
        warehouse_id=warehouse_id,
        warehouse_name=warehouse_name,
    )
615
630
 
616
631
  def clusters(
617
632
  self,
618
633
  cluster_id: Optional[str] = None,
619
634
  cluster_name: Optional[str] = None,
620
- **kwargs
621
635
  ) -> "Cluster":
622
636
  """Return a Cluster helper bound to this workspace.
623
637
 
@@ -635,7 +649,6 @@ class Workspace:
635
649
  workspace=self,
636
650
  cluster_id=cluster_id,
637
651
  cluster_name=cluster_name,
638
- **kwargs
639
652
  )
640
653
 
641
654
  # ---------------------------------------------------------------------------