ygg 0.1.44__py3-none-any.whl → 0.1.46__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,355 @@
1
+ import dataclasses as dc
2
+ import inspect
3
+ import logging
4
+ from typing import Optional, Sequence
5
+
6
+ from ..workspaces import WorkspaceService, Workspace
7
+ from ...pyutils.equality import dicts_equal, dict_diff
8
+ from ...pyutils.expiring_dict import ExpiringDict
9
+
10
# The Databricks SDK is optional: when it is missing, fall back to empty
# placeholder classes so this module can still be imported and annotated.
try:
    from databricks.sdk import WarehousesAPI
    from databricks.sdk.service.sql import (
        State,
        EndpointInfo,
        EndpointTags,
        EndpointTagPair,
        EndpointInfoWarehouseType,
    )

    # Parameter names accepted by the SDK create/edit calls; used further down
    # to filter warehouse specs before they are forwarded to the SDK.
    _CREATE_ARG_NAMES = set(inspect.signature(WarehousesAPI.create).parameters)
    _EDIT_ARG_NAMES = set(inspect.signature(WarehousesAPI.edit).parameters)
except ImportError:
    class WarehousesAPI:
        pass

    class State:
        pass

    class EndpointInfo:
        pass

    class EndpointTags:
        pass

    class EndpointTagPair:
        pass

    class EndpointInfoWarehouseType:
        pass

    # Keep these names defined without the SDK so that code referencing them
    # does not fail with a NameError before reaching the SDK itself.
    _CREATE_ARG_NAMES = set()
    _EDIT_ARG_NAMES = set()
36
+
37
+
38
+ __all__ = [
39
+ "SQLWarehouse"
40
+ ]
41
+
42
+
43
+ LOGGER = logging.getLogger(__name__)
44
+ NAME_ID_CACHE: dict[str, ExpiringDict] = {}
45
+
46
+
47
def set_cached_warehouse_name(
    host: str,
    warehouse_name: str,
    warehouse_id: str
) -> None:
    """Remember the id of a named warehouse for a given workspace host.

    Args:
        host: Workspace host used as the first-level key of ``NAME_ID_CACHE``.
        warehouse_name: Warehouse name used as the key inside the host cache.
        warehouse_id: Warehouse id stored under that name.
    """
    host_cache = NAME_ID_CACHE.get(host)

    if host_cache:
        host_cache[warehouse_name] = warehouse_id
        return

    # No usable per-host cache yet (missing or falsy): register a new one.
    # default_ttl=60 is presumably in seconds -- confirm against ExpiringDict.
    fresh_cache = ExpiringDict(default_ttl=60)
    NAME_ID_CACHE[host] = fresh_cache
    fresh_cache[warehouse_name] = warehouse_id
58
+
59
+
60
def get_cached_warehouse_id(
    host: str,
    warehouse_name: str,
) -> Optional[str]:
    """Look up a cached warehouse id by workspace host and warehouse name.

    Args:
        host: Workspace host used as the first-level key of ``NAME_ID_CACHE``.
        warehouse_name: Warehouse name to look up inside the host cache.

    Returns:
        The cached warehouse id, or ``None`` when the host has no (non-empty)
        cache or the name is not present in it.
    """
    host_cache = NAME_ID_CACHE.get(host)

    if not host_cache:
        return None
    return host_cache.get(warehouse_name)
67
+
68
+
69
@dc.dataclass
class SQLWarehouse(WorkspaceService):
    """Workspace-bound helper for finding, creating, updating and running a SQL warehouse.

    An instance identifies a warehouse by ``warehouse_id`` and/or
    ``warehouse_name``. The warehouse description is fetched through the
    workspace SDK client on first access to ``details`` and cached in
    ``_details``.
    """

    warehouse_id: Optional[str] = None
    warehouse_name: Optional[str] = None

    # Cached warehouse description; populated by refresh() / the details setter.
    _details: Optional[EndpointInfo] = dc.field(default=None, repr=False)

    def warehouse_client(self):
        """Return the warehouses API of the bound workspace's SDK client."""
        return self.workspace.sdk().warehouses

    def default(
        self,
        name: str = "YGG-DEFAULT",
        **kwargs
    ):
        """Create or update the warehouse named ``name`` (default ``"YGG-DEFAULT"``).

        Args:
            name: Warehouse name forwarded to ``create_or_update``.
            **kwargs: Warehouse specs forwarded to ``create_or_update``.

        Returns:
            The ``SQLWarehouse`` returned by ``create_or_update``.
        """
        return self.create_or_update(
            name=name,
            **kwargs
        )

    @property
    def details(self) -> EndpointInfo:
        """Cached warehouse description, fetched on first access."""
        if self._details is None:
            self.refresh()
        return self._details

    @details.setter
    def details(self, value: EndpointInfo):
        # Keep the identifying fields in sync with the stored description.
        self._details = value

        self.warehouse_id = value.id
        self.warehouse_name = value.name

    def latest_details(self):
        """Fetch the current description of ``warehouse_id`` from the API (uncached)."""
        return self.warehouse_client().get(id=self.warehouse_id)

    def refresh(self):
        """Re-fetch the description, update the cached fields and return ``self``."""
        self.details = self.latest_details()
        return self

    @property
    def state(self):
        """Current warehouse state; every access performs a fresh API call."""
        return self.latest_details().state

    @property
    def running(self):
        """``True`` when the current state is ``State.RUNNING``."""
        return self.state in {State.RUNNING}

    @property
    def pending(self):
        """``True`` while the warehouse is deleting, starting or stopping."""
        return self.state in {State.DELETING, State.STARTING, State.STOPPING}

    def start(self):
        """Request a start unless the warehouse is already running; return ``self``."""
        if not self.running:
            self.warehouse_client().start(id=self.warehouse_id)
        return self

    def stop(self):
        """Request a stop when the warehouse is running.

        Returns:
            The result of the SDK ``stop`` call when the warehouse was running,
            otherwise ``self``.
        """
        # NOTE(review): unlike start(), the running branch returns the SDK call
        # result instead of self. Kept as-is because callers may rely on it.
        if self.running:
            return self.warehouse_client().stop(id=self.warehouse_id)
        return self

    def find_warehouse(
        self,
        warehouse_id: Optional[str] = None,
        warehouse_name: Optional[str] = None,
        raise_error: bool = True
    ):
        """Resolve a warehouse by id or by name.

        Resolution order: an explicit ``warehouse_id``; this instance when it
        already has an id; the name -> id cache; finally a scan of
        ``list_warehouses()``.

        Args:
            warehouse_id: Explicit id; when given, a new helper is returned
                without any lookup.
            warehouse_name: Name to search for; defaults to ``self.warehouse_name``.
            raise_error: Raise instead of returning ``None`` when nothing matches.

        Returns:
            A ``SQLWarehouse``, or ``None`` when not found and ``raise_error``
            is false.

        Raises:
            ValueError: When no warehouse matches and ``raise_error`` is true.
        """
        if warehouse_id:
            return SQLWarehouse(
                workspace=self.workspace,
                warehouse_id=warehouse_id,
                warehouse_name=warehouse_name
            )

        if self.warehouse_id:
            return self

        warehouse_name = warehouse_name or self.warehouse_name

        # Without a name there is nothing to match, so skip the cache lookup
        # and the full listing and go straight to the not-found handling.
        if warehouse_name:
            host = self.workspace.host
            cached_id = get_cached_warehouse_id(host=host, warehouse_name=warehouse_name)

            if cached_id:
                return SQLWarehouse(
                    workspace=self.workspace,
                    warehouse_id=cached_id,
                    warehouse_name=warehouse_name
                )

            for warehouse in self.list_warehouses():
                if warehouse.warehouse_name == warehouse_name:
                    set_cached_warehouse_name(
                        host=host,
                        warehouse_name=warehouse_name,
                        warehouse_id=warehouse.warehouse_id
                    )
                    return warehouse

        if raise_error:
            raise ValueError(
                f"SQL Warehouse {warehouse_name!r} not found"
            )
        return None

    def list_warehouses(self):
        """Yield a ``SQLWarehouse`` (with details pre-filled) for each listed warehouse."""
        for info in self.warehouse_client().list():
            yield SQLWarehouse(
                workspace=self.workspace,
                warehouse_id=info.id,
                warehouse_name=info.name,
                _details=info
            )

    @staticmethod
    def _merge_default_tags(tags, default_tags) -> EndpointTags:
        """Build an ``EndpointTags`` from ``tags`` with ``default_tags`` applied on top."""
        if isinstance(tags, dict):
            merged = dict(tags)
        else:
            # EndpointTags-like object or None; custom_tags itself may be None.
            pairs = getattr(tags, "custom_tags", None) or ()
            merged = {pair.key: pair.value for pair in pairs}

        # Default tags win over same-named existing tags.
        merged.update(default_tags)

        return EndpointTags(custom_tags=[
            EndpointTagPair(key=k, value=v)
            for k, v in merged.items()
        ])

    def _check_details(
        self,
        keys: Sequence[str],
        details: Optional[EndpointInfo] = None,
        **warehouse_specs
    ):
        """Build a complete ``EndpointInfo`` from specs, filling in defaults.

        Args:
            keys: Names allowed through to ``EndpointInfo`` (the create or edit
                argument names).
            details: Existing description to start from, if any.
            **warehouse_specs: Overrides. When ``details`` is given, ``None``
                values are treated as "not specified" and never erase an
                existing value.

        Returns:
            An ``EndpointInfo`` with cluster size, name, tags, max cluster
            count and warehouse type populated.
        """
        if details is None:
            merged = warehouse_specs
        else:
            overrides = {
                k: v
                for k, v in warehouse_specs.items()
                if v is not None
            }
            merged = {
                **details.as_shallow_dict(),
                **overrides
            }

        details = EndpointInfo(**{
            k: v
            for k, v in merged.items()
            if k in keys
        })

        if details.cluster_size is None:
            details.cluster_size = "Small"

        if details.name is None:
            details.name = f"YGG-{details.cluster_size.upper()}"

        # Always end up with an EndpointTags that contains the workspace default
        # tags, whether tags were absent, a plain dict or an EndpointTags.
        details.tags = self._merge_default_tags(
            details.tags,
            self.workspace.default_tags()
        )

        if not details.max_num_clusters:
            details.max_num_clusters = 4

        if details.warehouse_type is None:
            details.warehouse_type = EndpointInfoWarehouseType.CLASSIC

        return details

    def create_or_update(
        self,
        warehouse_id: Optional[str] = None,
        name: Optional[str] = None,
        **warehouse_specs
    ):
        """Update the matching warehouse when one is found, otherwise create it.

        Args:
            warehouse_id: Optional id of the warehouse to update.
            name: Warehouse name; defaults to ``self.warehouse_name``.
            **warehouse_specs: Specs forwarded to ``update`` or ``create``.

        Returns:
            The updated or newly created ``SQLWarehouse``.
        """
        name = name or self.warehouse_name
        found = self.find_warehouse(warehouse_id=warehouse_id, warehouse_name=name, raise_error=False)

        if found is not None:
            return found.update(name=name, **warehouse_specs)
        return self.create(name=name, **warehouse_specs)

    def create(
        self,
        name: Optional[str] = None,
        **warehouse_specs
    ):
        """Create a warehouse, wait for the SDK call to finish and return a helper for it.

        Args:
            name: Warehouse name; defaults to ``self.warehouse_name``.
            **warehouse_specs: Specs; only names accepted by the SDK create
                call are forwarded.

        Returns:
            A new ``SQLWarehouse`` carrying the description returned by the SDK.
        """
        name = name or self.warehouse_name

        details = self._check_details(
            keys=_CREATE_ARG_NAMES,
            name=name,
            **warehouse_specs
        )

        info = self.warehouse_client().create_and_wait(**{
            k: v
            for k, v in details.as_shallow_dict().items()
            if k in _CREATE_ARG_NAMES
        })

        return SQLWarehouse(
            workspace=self.workspace,
            warehouse_id=info.id,
            warehouse_name=info.name,
            _details=info
        )

    def update(
        self,
        **warehouse_specs
    ):
        """Apply ``warehouse_specs`` to this warehouse when they change anything.

        Args:
            **warehouse_specs: Specs to apply; an empty call is a no-op.

        Returns:
            ``self``.
        """
        if not warehouse_specs:
            return self

        current = self.details

        existing_details = {
            k: v
            for k, v in current.as_shallow_dict().items()
            if k in _EDIT_ARG_NAMES
        }

        checked = self._check_details(details=current, keys=_EDIT_ARG_NAMES, **warehouse_specs)
        update_details = {
            k: v
            for k, v in checked.as_shallow_dict().items()
            if k in _EDIT_ARG_NAMES
        }

        same = dicts_equal(
            existing_details,
            update_details,
            keys=_EDIT_ARG_NAMES,
            treat_missing_as_none=True,
            float_tol=0.0,  # set e.g. 1e-6 if you have float-y stuff
        )

        if not same:
            diff = {
                k: v[1]
                for k, v in dict_diff(existing_details, update_details, keys=_EDIT_ARG_NAMES).items()
            }

            LOGGER.debug(
                "Updating %s with %s",
                self, diff
            )

            self.warehouse_client().edit_and_wait(**update_details)

            # The cached description is stale after the edit: drop it so the
            # next `details` access re-fetches, and keep the name in sync.
            self._details = None
            new_name = update_details.get("name")
            if new_name:
                self.warehouse_name = new_name

            LOGGER.info(
                "Updated %s",
                self
            )

        return self

    def sql(
        self,
        workspace: Optional[Workspace] = None,
        warehouse_id: Optional[str] = None,
        catalog_name: Optional[str] = None,
        schema_name: Optional[str] = None,
    ):
        """Return a SQLEngine configured for this workspace.

        Args:
            workspace: Optional workspace override.
            warehouse_id: Optional SQL warehouse id; defaults to this
                warehouse's id.
            catalog_name: Optional catalog name.
            schema_name: Optional schema name.

        Returns:
            A SQLEngine instance.
        """

        return self.workspace.sql(
            workspace=workspace,
            warehouse_id=warehouse_id or self.warehouse_id,
            catalog_name=catalog_name,
            schema_name=schema_name
        )
@@ -72,7 +72,7 @@ class Workspace:
72
72
  """Configuration wrapper for connecting to a Databricks workspace."""
73
73
  # Databricks / generic
74
74
  host: Optional[str] = None
75
- account_id: Optional[str] = None
75
+ account_id: Optional[str] = dataclasses.field(default=None, repr=False)
76
76
  token: Optional[str] = dataclasses.field(default=None, repr=False)
77
77
  client_id: Optional[str] = dataclasses.field(default=None, repr=False)
78
78
  client_secret: Optional[str] = dataclasses.field(default=None, repr=False)
@@ -220,7 +220,6 @@ class Workspace:
220
220
  instance = self.clone_instance() if clone else self
221
221
 
222
222
  require_databricks_sdk()
223
- logger.debug("Connecting %s", self)
224
223
 
225
224
  # Build Config from config_dict if available, else from fields.
226
225
  kwargs = {
@@ -291,8 +290,6 @@ class Workspace:
291
290
  if v is not None:
292
291
  setattr(instance, key, v)
293
292
 
294
- logger.info("Connected %s", instance)
295
-
296
293
  return instance
297
294
 
298
295
  # ------------------------------------------------------------------ #
@@ -570,6 +567,7 @@ class Workspace:
570
567
  ("Product", self.product),
571
568
  ("ProductVersion", self.product_version),
572
569
  ("ProductTag", self.product_tag),
570
+ ("ProductUser", self.current_user.user_name)
573
571
  )
574
572
  if v
575
573
  }
@@ -589,17 +587,17 @@ class Workspace:
589
587
  def sql(
590
588
  self,
591
589
  workspace: Optional["Workspace"] = None,
590
+ warehouse_id: Optional[str] = None,
592
591
  catalog_name: Optional[str] = None,
593
592
  schema_name: Optional[str] = None,
594
- **kwargs
595
593
  ):
596
594
  """Return a SQLEngine configured for this workspace.
597
595
 
598
596
  Args:
599
597
  workspace: Optional workspace override.
598
+ warehouse_id: Optional SQL warehouse id.
600
599
  catalog_name: Optional catalog name.
601
600
  schema_name: Optional schema name.
602
- **kwargs: Additional SQLEngine parameters.
603
601
 
604
602
  Returns:
605
603
  A SQLEngine instance.
@@ -608,16 +606,29 @@ class Workspace:
608
606
 
609
607
  return SQLEngine(
610
608
  workspace=self if workspace is None else workspace,
609
+ warehouse_id=warehouse_id,
611
610
  catalog_name=catalog_name,
612
611
  schema_name=schema_name,
613
- **kwargs
612
+ )
613
+
614
def warehouses(
    self,
    workspace: Optional["Workspace"] = None,
    warehouse_id: Optional[str] = None,
    warehouse_name: Optional[str] = None,
):
    """Return a SQLWarehouse helper bound to a workspace.

    Args:
        workspace: Optional workspace override; this workspace is used when omitted.
        warehouse_id: Optional SQL warehouse id.
        warehouse_name: Optional SQL warehouse name.

    Returns:
        A SQLWarehouse instance.
    """
    # Imported at call time rather than at module level
    # (presumably to avoid a circular import -- confirm).
    from ..sql.warehouse import SQLWarehouse

    target = workspace if workspace is not None else self

    return SQLWarehouse(
        workspace=target,
        warehouse_id=warehouse_id,
        warehouse_name=warehouse_name,
    )
615
627
 
616
628
  def clusters(
617
629
  self,
618
630
  cluster_id: Optional[str] = None,
619
631
  cluster_name: Optional[str] = None,
620
- **kwargs
621
632
  ) -> "Cluster":
622
633
  """Return a Cluster helper bound to this workspace.
623
634
 
@@ -635,7 +646,6 @@ class Workspace:
635
646
  workspace=self,
636
647
  cluster_id=cluster_id,
637
648
  cluster_name=cluster_name,
638
- **kwargs
639
649
  )
640
650
 
641
651
  # ---------------------------------------------------------------------------