ygg 0.1.55__py3-none-any.whl → 0.1.56__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ygg
3
- Version: 0.1.55
3
+ Version: 0.1.56
4
4
  Summary: Type-friendly utilities for moving data between Python objects, Arrow, Polars, Pandas, Spark, and Databricks
5
5
  Author: Yggdrasil contributors
6
6
  License: Apache License
@@ -1,12 +1,14 @@
1
- ygg-0.1.55.dist-info/licenses/LICENSE,sha256=HrhfyXIkWY2tGFK11kg7vPCqhgh5DcxleloqdhrpyMY,11558
1
+ ygg-0.1.56.dist-info/licenses/LICENSE,sha256=HrhfyXIkWY2tGFK11kg7vPCqhgh5DcxleloqdhrpyMY,11558
2
2
  yggdrasil/__init__.py,sha256=4-ghPak2S6zfMqmnlxW2GCgPb5s79znpKa2hGEGXcE4,24
3
- yggdrasil/version.py,sha256=RJcC5yver_ugtAU_HgwZ-2HWyszBXpsLZyyNma6G4Dc,22
3
+ yggdrasil/exceptions.py,sha256=NEpbDFn-8ZRsLiEgJicCwrTHNMWAGtdrTJzosfAeVJo,82
4
+ yggdrasil/version.py,sha256=c0ITmemMU7anWgDBUD3t_BAhA3li2gt3XgswCbHv1oU,22
4
5
  yggdrasil/databricks/__init__.py,sha256=skctY2c8W-hI81upx9F_PWRe5ishL3hrdiTuizgDjdw,152
5
6
  yggdrasil/databricks/ai/__init__.py,sha256=Mkp70UOVBzDQvdPNsqncHcyzxe5PnSGYE_bHnYxA1eA,21
6
- yggdrasil/databricks/ai/loki.py,sha256=iyekxctP6393LBN0PJOZgHBxQs9vDyQeRXPrru_krF0,3661
7
+ yggdrasil/databricks/ai/loki.py,sha256=HyVWxzJgfW03YO6TMOTJ1oNvrBJovnqYDn_MeNV6Ni0,11989
7
8
  yggdrasil/databricks/compute/__init__.py,sha256=NvdzmaJSNYY1uJthv1hHdBuNu3bD_-Z65DWnaJt9yXg,289
8
9
  yggdrasil/databricks/compute/cluster.py,sha256=YomLfvB0oxbgl6WDgBRxI1UXsxwlEbR6gq3FUbPHscY,44199
9
- yggdrasil/databricks/compute/execution_context.py,sha256=jIV6uru2NeX3O5lg-3KEqmXtLxxq45CFgkBQgQIIOHQ,23327
10
+ yggdrasil/databricks/compute/exceptions.py,sha256=Ug0ioxu5m2atdTX4OLH0s4R4dylHNxEdn7VhQI66b5M,209
11
+ yggdrasil/databricks/compute/execution_context.py,sha256=mhcwSvKTxgcUHdb7huSEjCVU_feiXSGq0JLyLXldjQM,23952
10
12
  yggdrasil/databricks/compute/remote.py,sha256=yicEhyQypssRa2ByscO36s3cBkEgORFsRME9aaq91Pc,3045
11
13
  yggdrasil/databricks/jobs/__init__.py,sha256=snxGSJb0M5I39v0y3IR-uEeSlZR248cQ_4DJ1sYs-h8,154
12
14
  yggdrasil/databricks/jobs/config.py,sha256=9LGeHD04hbfy0xt8_6oobC4moKJh4_DTjZiK4Q2Tqjk,11557
@@ -19,7 +21,7 @@ yggdrasil/databricks/sql/warehouse.py,sha256=1J0dyQLJb-OS1_1xU1eAVZ4CoL2-FhFeowK
19
21
  yggdrasil/databricks/workspaces/__init__.py,sha256=dv2zotoFVhNFlTCdRq6gwf5bEzeZkOZszoNZMs0k59g,114
20
22
  yggdrasil/databricks/workspaces/filesytem.py,sha256=Z8JXU7_XUEbw9fpTQT1avRQKi-IAP2KemXBMPkUoY4w,9805
21
23
  yggdrasil/databricks/workspaces/io.py,sha256=PAoxIxYvTC162Dx2qL2hk8oAdt8BnYrQ3jJHcJm4VkA,33116
22
- yggdrasil/databricks/workspaces/path.py,sha256=KkvLFHrps3UFr4ogYdESbJHEMfQBcWfWfXjlrv_7rTU,55180
24
+ yggdrasil/databricks/workspaces/path.py,sha256=h1j3bvjwKcDhJvlU_kAaLcLVz4jrdaWgjqPQMarZRHU,55233
23
25
  yggdrasil/databricks/workspaces/path_kind.py,sha256=rhWe1ky7uPD0du0bZSv2S4fK4C5zWd7zAF3UeS2iiPU,283
24
26
  yggdrasil/databricks/workspaces/volumes_path.py,sha256=s8CA33cG3jpMVJy5MILLlkEBcFg_qInDCF2jozLj1Fg,2431
25
27
  yggdrasil/databricks/workspaces/workspace.py,sha256=5DCPz5io_rmrpGNi5I6RChmyZ8kjlNUFGQl8mzQJThg,25511
@@ -59,8 +61,8 @@ yggdrasil/types/cast/registry.py,sha256=OOqIfbIjPH-a3figvu-zTvEtUDTEWhe2xIl3cCA4
59
61
  yggdrasil/types/cast/spark_cast.py,sha256=_KAsl1DqmKMSfWxqhVE7gosjYdgiL1C5bDQv6eP3HtA,24926
60
62
  yggdrasil/types/cast/spark_pandas_cast.py,sha256=BuTiWrdCANZCdD_p2MAytqm74eq-rdRXd-LGojBRrfU,5023
61
63
  yggdrasil/types/cast/spark_polars_cast.py,sha256=btmZNHXn2NSt3fUuB4xg7coaE0RezIBdZD92H8NK0Jw,9073
62
- ygg-0.1.55.dist-info/METADATA,sha256=F84590C1dKd4ZllEbe1uzAsMuUw9we134qVVQmOcGNI,18528
63
- ygg-0.1.55.dist-info/WHEEL,sha256=qELbo2s1Yzl39ZmrAibXA2jjPLUYfnVhUNTlyF1rq0Y,92
64
- ygg-0.1.55.dist-info/entry_points.txt,sha256=6q-vpWG3kvw2dhctQ0LALdatoeefkN855Ev02I1dKGY,70
65
- ygg-0.1.55.dist-info/top_level.txt,sha256=iBe9Kk4VIVbLpgv_p8OZUIfxgj4dgJ5wBg6vO3rigso,10
66
- ygg-0.1.55.dist-info/RECORD,,
64
+ ygg-0.1.56.dist-info/METADATA,sha256=Rr4DBB8q39XEEzBTQtYB_pja448GR74cEh3gc8-BKvY,18528
65
+ ygg-0.1.56.dist-info/WHEEL,sha256=qELbo2s1Yzl39ZmrAibXA2jjPLUYfnVhUNTlyF1rq0Y,92
66
+ ygg-0.1.56.dist-info/entry_points.txt,sha256=6q-vpWG3kvw2dhctQ0LALdatoeefkN855Ev02I1dKGY,70
67
+ ygg-0.1.56.dist-info/top_level.txt,sha256=iBe9Kk4VIVbLpgv_p8OZUIfxgj4dgJ5wBg6vO3rigso,10
68
+ ygg-0.1.56.dist-info/RECORD,,
@@ -1,44 +1,79 @@
1
- from typing import Optional, Dict, Any, List
2
- from dataclasses import field, dataclass
1
+ """
2
+ loki.py
3
+
4
+ Databricks Model Serving (OpenAI-compatible) wrapper with:
5
+ - Loki.ask(): stateless call
6
+ - TradingChatSession: stateful commodity trading analytics chat
7
+ - SqlChatSession: stateful Databricks SQL generator chat
8
+
9
+ Important constraint:
10
+ - Gemini models only support ONE system prompt.
11
+ => We must NOT send multiple system messages.
12
+ => We fold summary + context blocks into a single system string.
13
+ """
14
+
15
+ from __future__ import annotations
16
+
17
+ import json
18
+ from dataclasses import dataclass, field
19
+ from typing import Any, Dict, List, Optional, Union
3
20
 
4
21
  from ..workspaces.workspace import WorkspaceService
5
22
 
6
23
  try:
24
+ from openai import OpenAI as _OpenAI # noqa: F401
25
+ except ImportError:
26
+ _OpenAI = None # type: ignore
27
+
28
+
29
+ def make_ai_client(api_key: str, base_url: str):
30
+ """Late import so module can load even if openai isn't installed."""
7
31
  from openai import OpenAI
32
+ return OpenAI(api_key=api_key, base_url=base_url)
33
+
34
+
35
+ __all__ = ["Loki", "TradingChatSession", "SqlChatSession"]
8
36
 
9
- def make_ai_client(
10
- api_key: str,
11
- base_url: str
12
- ):
13
- return OpenAI(
14
- api_key=api_key,
15
- base_url=base_url
16
- )
17
- except ImportError:
18
- class OpenAI:
19
- pass
20
-
21
- def make_ai_client(
22
- api_key: str,
23
- base_url: str
24
- ):
25
- from openai import OpenAI
26
-
27
- return OpenAI(
28
- api_key=api_key,
29
- base_url=base_url
30
- )
31
37
 
32
- __all__ = [
33
- "Loki"
34
- ]
38
+ DEFAULT_TRADING_SYSTEM_PROMPT = """You are Loki: a conversational commodity trading analytics copilot.
39
+
40
+ Scope:
41
+ - Commodity trading analytics: curves, forwards, spreads, basis, hedging, risk, PnL explain, inventory, scheduling.
42
+ - Databricks-first workflows: Spark/Delta/Unity Catalog, Databricks SQL, performant Python.
43
+
44
+ Rules:
45
+ - Do NOT invent prices, positions, PnL, risk, or market facts not provided.
46
+ - State assumptions explicitly (units, time conventions, delivery months, calendars).
47
+ - Prefer actionable output (SQL + efficient Python). Avoid slow patterns.
48
+ - If data is missing, list exactly what you need and proceed with a reasonable template.
49
+
50
+ Style:
51
+ - Be concise, practical, performance-focused.
52
+ """
53
+
54
+ DEFAULT_SQL_SYSTEM_PROMPT = """You are LokiSQL: a Databricks SQL generator for commodity trading analytics.
55
+
56
+ Hard rules:
57
+ - Output ONLY SQL unless the user explicitly asks for explanation.
58
+ - Use Databricks SQL / Spark SQL dialect.
59
+ - Prefer readable CTEs, explicit column lists, deterministic joins.
60
+ - Do NOT invent table/column names. If missing, use placeholders like <table>, <col>.
61
+ - If ambiguous, output best-effort SQL template with SQL comments (-- TODO ...) and placeholders.
62
+ - Performance: push filters early, avoid exploding joins, avoid SELECT *.
63
+
64
+ Default assumptions:
65
+ - Dates UTC unless specified.
66
+ """
35
67
 
36
68
 
37
69
  @dataclass
38
70
  class Loki(WorkspaceService):
39
- model: str = "databricks-gemini-2-5-flash"
71
+ """
72
+ Loki wraps an OpenAI-compatible client pointing at Databricks Model Serving endpoints.
73
+ """
40
74
 
41
- _ai_client: Optional[OpenAI] = field(repr=False, hash=False, default=None)
75
+ model: str = "databricks-gemini-2-5-flash"
76
+ _ai_client: Optional[Any] = field(repr=False, hash=False, default=None)
42
77
 
43
78
  @property
44
79
  def ai_client(self):
@@ -47,9 +82,10 @@ class Loki(WorkspaceService):
47
82
  return self._ai_client
48
83
 
49
84
  def make_aiclient(self):
85
+ host = self.workspace.host.rstrip("/")
50
86
  return make_ai_client(
51
87
  api_key=self.workspace.current_token(),
52
- base_url=self.workspace.host + "/serving-endpoints"
88
+ base_url=f"{host}/serving-endpoints",
53
89
  )
54
90
 
55
91
  def ask(
@@ -60,55 +96,25 @@ class Loki(WorkspaceService):
60
96
  max_tokens: int = 5000,
61
97
  temperature: Optional[float] = None,
62
98
  extra_messages: Optional[List[Dict[str, str]]] = None,
63
- **kwargs,
99
+ **kwargs: Any,
64
100
  ) -> str:
65
101
  """
66
- Send a chat prompt to a Databricks Model Serving endpoint using the
67
- OpenAI-compatible Chat Completions API and return the assistant's text.
68
-
69
- This is a thin convenience wrapper around:
70
- self.ai_client.chat.completions.create(...)
71
-
72
- Parameters
73
- ----------
74
- command:
75
- The user prompt text to send.
76
- system:
77
- Optional system instruction (prepended as the first message).
78
- max_tokens:
79
- Upper bound on the number of tokens generated in the response.
80
- temperature:
81
- Optional sampling temperature. If None, the client/model default is used.
82
- extra_messages:
83
- Optional list of additional chat messages to insert before the user prompt.
84
- Each item should be a dict like {"role": "...", "content": "..."}.
85
- Useful for few-shot examples or carrying prior context.
86
- **kwargs:
87
- Any additional parameters forwarded directly to
88
- `chat.completions.create(...)` (e.g. top_p, presence_penalty, etc.).
89
-
90
- Returns
91
- -------
92
- str
93
- The assistant message content (empty string if the API returns no content).
94
-
95
- Raises
96
- ------
97
- Exception
98
- Propagates any exceptions raised by the OpenAI client (HTTP errors,
99
- auth errors, invalid request errors, timeouts, etc.).
102
+ Stateless single call to the model.
103
+
104
+ NOTE (Gemini constraint):
105
+ - Provide at most ONE system prompt (i.e., a single system message).
106
+ - Do not pass additional messages with role="system".
100
107
  """
101
108
  messages: List[Dict[str, str]] = []
102
-
103
109
  if system:
104
110
  messages.append({"role": "system", "content": system})
105
111
 
106
112
  if extra_messages:
113
+ # IMPORTANT: caller must not include any "system" roles here for Gemini models
107
114
  messages.extend(extra_messages)
108
115
 
109
116
  messages.append({"role": "user", "content": command})
110
117
 
111
- # Build params cleanly (only include temperature if provided)
112
118
  params: Dict[str, Any] = dict(
113
119
  model=self.model,
114
120
  messages=messages,
@@ -120,3 +126,249 @@ class Loki(WorkspaceService):
120
126
 
121
127
  resp = self.ai_client.chat.completions.create(**params)
122
128
  return resp.choices[0].message.content or ""
129
+
130
+ def new_trading_chat(
131
+ self,
132
+ *,
133
+ system_prompt: str = DEFAULT_TRADING_SYSTEM_PROMPT,
134
+ max_context_turns: int = 20,
135
+ max_context_chars: int = 120_000,
136
+ ) -> "TradingChatSession":
137
+ return TradingChatSession(
138
+ loki=self,
139
+ system_prompt=system_prompt,
140
+ max_context_turns=max_context_turns,
141
+ max_context_chars=max_context_chars,
142
+ )
143
+
144
+ def new_sql_chat(
145
+ self,
146
+ *,
147
+ system_prompt: str = DEFAULT_SQL_SYSTEM_PROMPT,
148
+ max_context_turns: int = 20,
149
+ max_context_chars: int = 120_000,
150
+ ) -> "SqlChatSession":
151
+ return SqlChatSession(
152
+ loki=self,
153
+ system_prompt=system_prompt,
154
+ max_context_turns=max_context_turns,
155
+ max_context_chars=max_context_chars,
156
+ )
157
+
158
+
159
+ @dataclass
160
+ class _BaseChatSession:
161
+ """
162
+ Stateful session that maintains history + injected context blocks.
163
+
164
+ Gemini constraint:
165
+ - We must fold ALL system content into one system string.
166
+ - Therefore summary/context_blocks are concatenated into the system prompt.
167
+ """
168
+ loki: Loki
169
+ system_prompt: str
170
+
171
+ history: List[Dict[str, str]] = field(default_factory=list)
172
+ summary: Optional[str] = None
173
+ context_blocks: List[str] = field(default_factory=list)
174
+
175
+ max_context_turns: int = 20
176
+ max_context_chars: int = 120_000
177
+
178
+ def reset(self) -> None:
179
+ self.history.clear()
180
+ self.summary = None
181
+ self.context_blocks.clear()
182
+
183
+ def add_context(self, title: str, payload: Union[str, Dict[str, Any], List[Any]]) -> None:
184
+ if isinstance(payload, str):
185
+ payload_str = payload
186
+ else:
187
+ payload_str = json.dumps(payload, ensure_ascii=False, indent=2)
188
+
189
+ self.context_blocks.append(f"[Context: {title}]\n{payload_str}".strip())
190
+ self._trim()
191
+
192
+ def _estimate_chars(self, msgs: List[Dict[str, str]]) -> int:
193
+ return sum(len(m.get("content", "")) for m in msgs)
194
+
195
+ def _build_system(self, extra_system: Optional[str] = None) -> str:
196
+ parts: List[str] = [self.system_prompt.strip()]
197
+ if extra_system:
198
+ parts.append(extra_system.strip())
199
+ if self.summary:
200
+ parts.append(f"[ConversationSummary]\n{self.summary}".strip())
201
+ if self.context_blocks:
202
+ parts.append("\n\n".join(self.context_blocks).strip())
203
+ return "\n\n".join(p for p in parts if p)
204
+
205
+ def _trim(self) -> None:
206
+ # Turn trim (keep last N turns => N*2 messages)
207
+ if self.max_context_turns > 0:
208
+ max_msgs = self.max_context_turns * 2
209
+ if len(self.history) > max_msgs:
210
+ self.history = self.history[-max_msgs:]
211
+
212
+ # Char trim: shrink history first, then context blocks if needed
213
+ def total_chars() -> int:
214
+ sys_len = len(self._build_system())
215
+ return sys_len + self._estimate_chars(self.history)
216
+
217
+ while total_chars() > self.max_context_chars and self.history:
218
+ self.history = self.history[1:]
219
+
220
+ while total_chars() > self.max_context_chars and self.context_blocks:
221
+ self.context_blocks = self.context_blocks[1:]
222
+
223
+
224
+ @dataclass
225
+ class TradingChatSession(_BaseChatSession):
226
+ """
227
+ Commodity trading analytics chat session.
228
+ Optionally returns structured JSON for downstream automation.
229
+ """
230
+
231
+ def chat(
232
+ self,
233
+ user_text: str,
234
+ *,
235
+ structured: bool = True,
236
+ max_tokens: int = 12000,
237
+ temperature: Optional[float] = None,
238
+ **kwargs: Any,
239
+ ) -> Union[str, Dict[str, Any]]:
240
+ self._trim()
241
+
242
+ extra_system = None
243
+ if structured:
244
+ extra_system = (
245
+ "Respond ONLY as valid JSON with keys: "
246
+ "final_answer (string), assumptions (array of strings), data_needed (array of strings), "
247
+ "sql (string or null), python (string or null). "
248
+ "No markdown. No extra keys."
249
+ )
250
+
251
+ system = self._build_system(extra_system=extra_system)
252
+
253
+ assistant_text = self.loki.ask(
254
+ user_text,
255
+ system=system,
256
+ extra_messages=self.history, # NOTE: history must contain no system roles
257
+ max_tokens=max_tokens,
258
+ temperature=temperature,
259
+ **kwargs,
260
+ )
261
+
262
+ self.history.append({"role": "user", "content": user_text})
263
+ self.history.append({"role": "assistant", "content": assistant_text})
264
+ self._trim()
265
+
266
+ if structured:
267
+ parsed = _try_parse_json_object(assistant_text)
268
+ if parsed is not None:
269
+ return parsed
270
+
271
+ return assistant_text
272
+
273
+
274
+ @dataclass
275
+ class SqlChatSession(_BaseChatSession):
276
+ """
277
+ SQL-only conversational session that generates Databricks SQL.
278
+
279
+ Uses a single system message with strict instructions to output SQL only.
280
+ """
281
+
282
+ def generate_sql(
283
+ self,
284
+ request: str,
285
+ *,
286
+ max_tokens: int = 12000,
287
+ temperature: Optional[float] = None,
288
+ sql_only: bool = True,
289
+ **kwargs: Any,
290
+ ) -> str:
291
+ self._trim()
292
+
293
+ extra_system = None
294
+ if sql_only:
295
+ extra_system = (
296
+ "Reminder: Output ONLY SQL. "
297
+ "If ambiguity exists, use SQL comments (-- TODO ...) and placeholders, but still output SQL only."
298
+ )
299
+
300
+ system = self._build_system(extra_system=extra_system)
301
+
302
+ sql = self.loki.ask(
303
+ request,
304
+ system=system,
305
+ extra_messages=self.history, # history must contain no system roles
306
+ max_tokens=max_tokens,
307
+ temperature=temperature,
308
+ **kwargs,
309
+ ).strip()
310
+
311
+ self.history.append({"role": "user", "content": request})
312
+ self.history.append({"role": "assistant", "content": sql})
313
+ self._trim()
314
+
315
+ return _strip_sql_fences(sql)
316
+
317
+
318
+ def _strip_sql_fences(text: str) -> str:
319
+ t = text.strip()
320
+ if t.startswith("```"):
321
+ lines = t.splitlines()
322
+ lines = lines[1:] # drop ``` or ```sql
323
+ if lines and lines[-1].strip().startswith("```"):
324
+ lines = lines[:-1]
325
+ return "\n".join(lines).strip()
326
+ return t
327
+
328
+
329
+ def _strip_markdown_fences(text: str) -> str:
330
+ """
331
+ Remove ```lang ... ``` fences if present.
332
+ Keeps inner content unchanged.
333
+ """
334
+ t = text.strip()
335
+ if not t.startswith("```"):
336
+ return t
337
+
338
+ lines = t.splitlines()
339
+ if not lines:
340
+ return t
341
+
342
+ # Drop first line: ``` or ```json
343
+ lines = lines[1:]
344
+
345
+ # Drop last line if it's ```
346
+ if lines and lines[-1].strip().startswith("```"):
347
+ lines = lines[:-1]
348
+
349
+ return "\n".join(lines).strip()
350
+
351
+
352
+ def _try_parse_json_object(text: str) -> Optional[Dict[str, Any]]:
353
+ t = _strip_markdown_fences(text).strip()
354
+
355
+ # Best effort extraction if there's extra junk around JSON
356
+ if not t.startswith("{"):
357
+ start = t.find("{")
358
+ end = t.rfind("}")
359
+ if start != -1 and end != -1 and end > start:
360
+ t = t[start : end + 1]
361
+
362
+ try:
363
+ obj = json.loads(t)
364
+ except Exception:
365
+ return None
366
+
367
+ if not isinstance(obj, dict):
368
+ return None
369
+
370
+ required = {"final_answer", "assumptions", "data_needed", "sql", "python"}
371
+ if not required.issubset(set(obj.keys())):
372
+ return None
373
+
374
+ return obj
@@ -0,0 +1,14 @@
1
+ from ...exceptions import YGGException
2
+
3
+ __all__ = [
4
+ "ComputeException",
5
+ "CommandAborted"
6
+ ]
7
+
8
+
9
+ class ComputeException(YGGException):
10
+ pass
11
+
12
+
13
+ class CommandAborted(YGGException):
14
+ pass
@@ -16,6 +16,7 @@ from threading import Thread
16
16
  from types import ModuleType
17
17
  from typing import TYPE_CHECKING, Optional, Any, Callable, List, Dict, Union, Iterable, Tuple
18
18
 
19
+ from .exceptions import CommandAborted
19
20
  from ...libs.databrickslib import databricks_sdk
20
21
  from ...pyutils.exceptions import raise_parsed_traceback
21
22
  from ...pyutils.expiring_dict import ExpiringDict
@@ -110,16 +111,12 @@ class ExecutionContext:
110
111
  def __exit__(self, exc_type, exc_val, exc_tb):
111
112
  """Exit the context manager and close the remote context if created."""
112
113
  if not self._was_connected:
113
- self.close()
114
+ self.close(wait=False)
114
115
  self.cluster.__exit__(exc_type, exc_val=exc_val, exc_tb=exc_tb)
115
116
 
116
117
  def __del__(self):
117
118
  """Best-effort cleanup for the remote execution context."""
118
- if self.context_id:
119
- try:
120
- Thread(target=self.close).start()
121
- except BaseException:
122
- pass
119
+ self.close(wait=False)
123
120
 
124
121
  @property
125
122
  def remote_metadata(self) -> RemoteMetadata:
@@ -180,7 +177,7 @@ print(json.dumps(meta))"""
180
177
  """
181
178
  return self.cluster.workspace.sdk()
182
179
 
183
- def create_command(
180
+ def create(
184
181
  self,
185
182
  language: "Language",
186
183
  ) -> any:
@@ -197,15 +194,17 @@ print(json.dumps(meta))"""
197
194
  self.cluster
198
195
  )
199
196
 
197
+ client = self._workspace_client().command_execution
198
+
200
199
  try:
201
- created = self._workspace_client().command_execution.create_and_wait(
200
+ created = client.create_and_wait(
202
201
  cluster_id=self.cluster.cluster_id,
203
202
  language=language,
204
203
  )
205
204
  except:
206
205
  self.cluster.ensure_running()
207
206
 
208
- created = self._workspace_client().command_execution.create_and_wait(
207
+ created = client.create_and_wait(
209
208
  cluster_id=self.cluster.cluster_id,
210
209
  language=language,
211
210
  )
@@ -217,42 +216,38 @@ print(json.dumps(meta))"""
217
216
 
218
217
  created = getattr(created, "response", created)
219
218
 
220
- return created
219
+ self.context_id = created.id
220
+
221
+ return self
221
222
 
222
223
  def connect(
223
224
  self,
224
- language: Optional["Language"] = None
225
+ language: Optional["Language"] = None,
226
+ reset: bool = False
225
227
  ) -> "ExecutionContext":
226
228
  """Create a remote command execution context if not already open.
227
229
 
228
230
  Args:
229
231
  language: Optional language override for the context.
232
+ reset: Reset existing if connected
230
233
 
231
234
  Returns:
232
235
  The connected ExecutionContext instance.
233
236
  """
234
237
  if self.context_id is not None:
235
- return self
238
+ if not reset:
239
+ return self
236
240
 
237
- self.language = language or self.language
241
+ self.close(wait=False)
238
242
 
239
- if self.language is None:
240
- self.language = Language.PYTHON
243
+ language = language or self.language
241
244
 
242
- ctx = self.create_command(language=self.language)
245
+ if language is None:
246
+ language = Language.PYTHON
243
247
 
244
- context_id = ctx.id
245
- if not context_id:
246
- raise RuntimeError("Failed to create command execution context")
248
+ return self.create(language=language)
247
249
 
248
- self.context_id = context_id
249
- LOGGER.info(
250
- "Opened execution context for %s",
251
- self
252
- )
253
- return self
254
-
255
- def close(self) -> None:
250
+ def close(self, wait: bool = True) -> None:
256
251
  """Destroy the remote command execution context if it exists.
257
252
 
258
253
  Returns:
@@ -261,12 +256,23 @@ print(json.dumps(meta))"""
261
256
  if not self.context_id:
262
257
  return
263
258
 
259
+ client = self._workspace_client()
260
+
264
261
  try:
265
- self._workspace_client().command_execution.destroy(
266
- cluster_id=self.cluster.cluster_id,
267
- context_id=self.context_id,
268
- )
269
- except Exception:
262
+ if wait:
263
+ client.command_execution.destroy(
264
+ cluster_id=self.cluster.cluster_id,
265
+ context_id=self.context_id,
266
+ )
267
+ else:
268
+ Thread(
269
+ target=client.command_execution.destroy,
270
+ kwargs={
271
+ "cluster_id": self.cluster.cluster_id,
272
+ "context_id": self.context_id,
273
+ }
274
+ ).start()
275
+ except BaseException:
270
276
  # non-fatal: context cleanup best-effort
271
277
  pass
272
278
  finally:
@@ -465,7 +471,18 @@ print(json.dumps(meta))"""
465
471
  )
466
472
 
467
473
  try:
468
- return self._decode_result(result, result_tag=result_tag, print_stdout=print_stdout)
474
+ return self._decode_result(
475
+ result,
476
+ result_tag=result_tag,
477
+ print_stdout=print_stdout
478
+ )
479
+ except CommandAborted:
480
+ return self.connect(language=self.language, reset=True).execute_command(
481
+ command=command,
482
+ timeout=timeout,
483
+ result_tag=result_tag,
484
+ print_stdout=print_stdout
485
+ )
469
486
  except ModuleNotFoundError as remote_module_error:
470
487
  _MOD_NOT_FOUND_RE = re.compile(r"No module named ['\"]([^'\"]+)['\"]")
471
488
  module_name = _MOD_NOT_FOUND_RE.search(str(remote_module_error))
@@ -660,6 +677,9 @@ with zipfile.ZipFile(buf, "r") as zf:
660
677
  if res.result_type == ResultType.ERROR:
661
678
  message = res.cause or "Command execution failed"
662
679
 
680
+ if "client terminated the session" in message:
681
+ raise CommandAborted(message)
682
+
663
683
  if self.language == Language.PYTHON:
664
684
  raise_parsed_traceback(message)
665
685
 
@@ -668,6 +688,7 @@ with zipfile.ZipFile(buf, "r") as zf:
668
688
  or getattr(res, "stack_trace", None)
669
689
  or getattr(res, "traceback", None)
670
690
  )
691
+
671
692
  if remote_tb:
672
693
  message = f"{message}\n{remote_tb}"
673
694
 
@@ -30,7 +30,7 @@ from ...types.cast.registry import convert, register_converter
30
30
  from ...types.file_format import ExcelFileFormat
31
31
 
32
32
  if databricks is not None:
33
- from databricks.sdk.service.catalog import VolumeType, PathOperation, VolumeInfo
33
+ from databricks.sdk.service.catalog import VolumeType, VolumeInfo
34
34
  from databricks.sdk.service.workspace import ObjectType
35
35
  from databricks.sdk.errors.platform import (
36
36
  NotFound,
@@ -1236,6 +1236,8 @@ class DatabricksPath:
1236
1236
  self,
1237
1237
  operation: Optional["PathOperation"] = None
1238
1238
  ):
1239
+ from databricks.sdk.service.catalog import PathOperation
1240
+
1239
1241
  if self.kind != DatabricksPathKind.VOLUME:
1240
1242
  raise ValueError(f"Cannot generate temporary credentials for {repr(self)}")
1241
1243
 
@@ -0,0 +1,7 @@
1
+ __all__ = [
2
+ "YGGException"
3
+ ]
4
+
5
+
6
+ class YGGException(Exception):
7
+ pass
yggdrasil/version.py CHANGED
@@ -1 +1 @@
1
- __version__ = "0.1.55"
1
+ __version__ = "0.1.56"
File without changes