ygg 0.1.55__py3-none-any.whl → 0.1.56__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {ygg-0.1.55.dist-info → ygg-0.1.56.dist-info}/METADATA +1 -1
- {ygg-0.1.55.dist-info → ygg-0.1.56.dist-info}/RECORD +12 -10
- yggdrasil/databricks/ai/loki.py +319 -67
- yggdrasil/databricks/compute/exceptions.py +14 -0
- yggdrasil/databricks/compute/execution_context.py +54 -33
- yggdrasil/databricks/workspaces/path.py +3 -1
- yggdrasil/exceptions.py +7 -0
- yggdrasil/version.py +1 -1
- {ygg-0.1.55.dist-info → ygg-0.1.56.dist-info}/WHEEL +0 -0
- {ygg-0.1.55.dist-info → ygg-0.1.56.dist-info}/entry_points.txt +0 -0
- {ygg-0.1.55.dist-info → ygg-0.1.56.dist-info}/licenses/LICENSE +0 -0
- {ygg-0.1.55.dist-info → ygg-0.1.56.dist-info}/top_level.txt +0 -0
|
@@ -1,12 +1,14 @@
|
|
|
1
|
-
ygg-0.1.
|
|
1
|
+
ygg-0.1.56.dist-info/licenses/LICENSE,sha256=HrhfyXIkWY2tGFK11kg7vPCqhgh5DcxleloqdhrpyMY,11558
|
|
2
2
|
yggdrasil/__init__.py,sha256=4-ghPak2S6zfMqmnlxW2GCgPb5s79znpKa2hGEGXcE4,24
|
|
3
|
-
yggdrasil/
|
|
3
|
+
yggdrasil/exceptions.py,sha256=NEpbDFn-8ZRsLiEgJicCwrTHNMWAGtdrTJzosfAeVJo,82
|
|
4
|
+
yggdrasil/version.py,sha256=c0ITmemMU7anWgDBUD3t_BAhA3li2gt3XgswCbHv1oU,22
|
|
4
5
|
yggdrasil/databricks/__init__.py,sha256=skctY2c8W-hI81upx9F_PWRe5ishL3hrdiTuizgDjdw,152
|
|
5
6
|
yggdrasil/databricks/ai/__init__.py,sha256=Mkp70UOVBzDQvdPNsqncHcyzxe5PnSGYE_bHnYxA1eA,21
|
|
6
|
-
yggdrasil/databricks/ai/loki.py,sha256=
|
|
7
|
+
yggdrasil/databricks/ai/loki.py,sha256=HyVWxzJgfW03YO6TMOTJ1oNvrBJovnqYDn_MeNV6Ni0,11989
|
|
7
8
|
yggdrasil/databricks/compute/__init__.py,sha256=NvdzmaJSNYY1uJthv1hHdBuNu3bD_-Z65DWnaJt9yXg,289
|
|
8
9
|
yggdrasil/databricks/compute/cluster.py,sha256=YomLfvB0oxbgl6WDgBRxI1UXsxwlEbR6gq3FUbPHscY,44199
|
|
9
|
-
yggdrasil/databricks/compute/
|
|
10
|
+
yggdrasil/databricks/compute/exceptions.py,sha256=Ug0ioxu5m2atdTX4OLH0s4R4dylHNxEdn7VhQI66b5M,209
|
|
11
|
+
yggdrasil/databricks/compute/execution_context.py,sha256=mhcwSvKTxgcUHdb7huSEjCVU_feiXSGq0JLyLXldjQM,23952
|
|
10
12
|
yggdrasil/databricks/compute/remote.py,sha256=yicEhyQypssRa2ByscO36s3cBkEgORFsRME9aaq91Pc,3045
|
|
11
13
|
yggdrasil/databricks/jobs/__init__.py,sha256=snxGSJb0M5I39v0y3IR-uEeSlZR248cQ_4DJ1sYs-h8,154
|
|
12
14
|
yggdrasil/databricks/jobs/config.py,sha256=9LGeHD04hbfy0xt8_6oobC4moKJh4_DTjZiK4Q2Tqjk,11557
|
|
@@ -19,7 +21,7 @@ yggdrasil/databricks/sql/warehouse.py,sha256=1J0dyQLJb-OS1_1xU1eAVZ4CoL2-FhFeowK
|
|
|
19
21
|
yggdrasil/databricks/workspaces/__init__.py,sha256=dv2zotoFVhNFlTCdRq6gwf5bEzeZkOZszoNZMs0k59g,114
|
|
20
22
|
yggdrasil/databricks/workspaces/filesytem.py,sha256=Z8JXU7_XUEbw9fpTQT1avRQKi-IAP2KemXBMPkUoY4w,9805
|
|
21
23
|
yggdrasil/databricks/workspaces/io.py,sha256=PAoxIxYvTC162Dx2qL2hk8oAdt8BnYrQ3jJHcJm4VkA,33116
|
|
22
|
-
yggdrasil/databricks/workspaces/path.py,sha256=
|
|
24
|
+
yggdrasil/databricks/workspaces/path.py,sha256=h1j3bvjwKcDhJvlU_kAaLcLVz4jrdaWgjqPQMarZRHU,55233
|
|
23
25
|
yggdrasil/databricks/workspaces/path_kind.py,sha256=rhWe1ky7uPD0du0bZSv2S4fK4C5zWd7zAF3UeS2iiPU,283
|
|
24
26
|
yggdrasil/databricks/workspaces/volumes_path.py,sha256=s8CA33cG3jpMVJy5MILLlkEBcFg_qInDCF2jozLj1Fg,2431
|
|
25
27
|
yggdrasil/databricks/workspaces/workspace.py,sha256=5DCPz5io_rmrpGNi5I6RChmyZ8kjlNUFGQl8mzQJThg,25511
|
|
@@ -59,8 +61,8 @@ yggdrasil/types/cast/registry.py,sha256=OOqIfbIjPH-a3figvu-zTvEtUDTEWhe2xIl3cCA4
|
|
|
59
61
|
yggdrasil/types/cast/spark_cast.py,sha256=_KAsl1DqmKMSfWxqhVE7gosjYdgiL1C5bDQv6eP3HtA,24926
|
|
60
62
|
yggdrasil/types/cast/spark_pandas_cast.py,sha256=BuTiWrdCANZCdD_p2MAytqm74eq-rdRXd-LGojBRrfU,5023
|
|
61
63
|
yggdrasil/types/cast/spark_polars_cast.py,sha256=btmZNHXn2NSt3fUuB4xg7coaE0RezIBdZD92H8NK0Jw,9073
|
|
62
|
-
ygg-0.1.
|
|
63
|
-
ygg-0.1.
|
|
64
|
-
ygg-0.1.
|
|
65
|
-
ygg-0.1.
|
|
66
|
-
ygg-0.1.
|
|
64
|
+
ygg-0.1.56.dist-info/METADATA,sha256=Rr4DBB8q39XEEzBTQtYB_pja448GR74cEh3gc8-BKvY,18528
|
|
65
|
+
ygg-0.1.56.dist-info/WHEEL,sha256=qELbo2s1Yzl39ZmrAibXA2jjPLUYfnVhUNTlyF1rq0Y,92
|
|
66
|
+
ygg-0.1.56.dist-info/entry_points.txt,sha256=6q-vpWG3kvw2dhctQ0LALdatoeefkN855Ev02I1dKGY,70
|
|
67
|
+
ygg-0.1.56.dist-info/top_level.txt,sha256=iBe9Kk4VIVbLpgv_p8OZUIfxgj4dgJ5wBg6vO3rigso,10
|
|
68
|
+
ygg-0.1.56.dist-info/RECORD,,
|
yggdrasil/databricks/ai/loki.py
CHANGED
|
@@ -1,44 +1,79 @@
|
|
|
1
|
-
|
|
2
|
-
|
|
1
|
+
"""
|
|
2
|
+
loki.py
|
|
3
|
+
|
|
4
|
+
Databricks Model Serving (OpenAI-compatible) wrapper with:
|
|
5
|
+
- Loki.ask(): stateless call
|
|
6
|
+
- TradingChatSession: stateful commodity trading analytics chat
|
|
7
|
+
- SqlChatSession: stateful Databricks SQL generator chat
|
|
8
|
+
|
|
9
|
+
Important constraint:
|
|
10
|
+
- Gemini models only support ONE system prompt.
|
|
11
|
+
=> We must NOT send multiple system messages.
|
|
12
|
+
=> We fold summary + context blocks into a single system string.
|
|
13
|
+
"""
|
|
14
|
+
|
|
15
|
+
from __future__ import annotations
|
|
16
|
+
|
|
17
|
+
import json
|
|
18
|
+
from dataclasses import dataclass, field
|
|
19
|
+
from typing import Any, Dict, List, Optional, Union
|
|
3
20
|
|
|
4
21
|
from ..workspaces.workspace import WorkspaceService
|
|
5
22
|
|
|
6
23
|
try:
|
|
24
|
+
from openai import OpenAI as _OpenAI # noqa: F401
|
|
25
|
+
except ImportError:
|
|
26
|
+
_OpenAI = None # type: ignore
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def make_ai_client(api_key: str, base_url: str):
    """
    Build an ``openai.OpenAI`` client for an OpenAI-compatible endpoint.

    ``openai`` is imported here rather than at module load so that this module
    stays importable without the optional dependency; an ImportError surfaces
    only when a client is actually requested.
    """
    from openai import OpenAI

    settings = {"api_key": api_key, "base_url": base_url}
    return OpenAI(**settings)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
__all__ = ["Loki", "TradingChatSession", "SqlChatSession"]
|
|
8
36
|
|
|
9
|
-
def make_ai_client(
|
|
10
|
-
api_key: str,
|
|
11
|
-
base_url: str
|
|
12
|
-
):
|
|
13
|
-
return OpenAI(
|
|
14
|
-
api_key=api_key,
|
|
15
|
-
base_url=base_url
|
|
16
|
-
)
|
|
17
|
-
except ImportError:
|
|
18
|
-
class OpenAI:
|
|
19
|
-
pass
|
|
20
|
-
|
|
21
|
-
def make_ai_client(
|
|
22
|
-
api_key: str,
|
|
23
|
-
base_url: str
|
|
24
|
-
):
|
|
25
|
-
from openai import OpenAI
|
|
26
|
-
|
|
27
|
-
return OpenAI(
|
|
28
|
-
api_key=api_key,
|
|
29
|
-
base_url=base_url
|
|
30
|
-
)
|
|
31
37
|
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
38
|
+
# Base system instructions for TradingChatSession: scope, guard-rails and tone.
DEFAULT_TRADING_SYSTEM_PROMPT = (
    "You are Loki: a conversational commodity trading analytics copilot.\n"
    "\n"
    "Scope:\n"
    "- Commodity trading analytics: curves, forwards, spreads, basis, hedging, risk, PnL explain, inventory, scheduling.\n"
    "- Databricks-first workflows: Spark/Delta/Unity Catalog, Databricks SQL, performant Python.\n"
    "\n"
    "Rules:\n"
    "- Do NOT invent prices, positions, PnL, risk, or market facts not provided.\n"
    "- State assumptions explicitly (units, time conventions, delivery months, calendars).\n"
    "- Prefer actionable output (SQL + efficient Python). Avoid slow patterns.\n"
    "- If data is missing, list exactly what you need and proceed with a reasonable template.\n"
    "\n"
    "Style:\n"
    "- Be concise, practical, performance-focused.\n"
)
|
|
53
|
+
|
|
54
|
+
# Base system instructions for SqlChatSession: SQL-only output rules and defaults.
DEFAULT_SQL_SYSTEM_PROMPT = (
    "You are LokiSQL: a Databricks SQL generator for commodity trading analytics.\n"
    "\n"
    "Hard rules:\n"
    "- Output ONLY SQL unless the user explicitly asks for explanation.\n"
    "- Use Databricks SQL / Spark SQL dialect.\n"
    "- Prefer readable CTEs, explicit column lists, deterministic joins.\n"
    "- Do NOT invent table/column names. If missing, use placeholders like <table>, <col>.\n"
    "- If ambiguous, output best-effort SQL template with SQL comments (-- TODO ...) and placeholders.\n"
    "- Performance: push filters early, avoid exploding joins, avoid SELECT *.\n"
    "\n"
    "Default assumptions:\n"
    "- Dates UTC unless specified.\n"
)
|
|
35
67
|
|
|
36
68
|
|
|
37
69
|
@dataclass
|
|
38
70
|
class Loki(WorkspaceService):
|
|
39
|
-
|
|
71
|
+
"""
|
|
72
|
+
Loki wraps an OpenAI-compatible client pointing at Databricks Model Serving endpoints.
|
|
73
|
+
"""
|
|
40
74
|
|
|
41
|
-
|
|
75
|
+
model: str = "databricks-gemini-2-5-flash"
|
|
76
|
+
_ai_client: Optional[Any] = field(repr=False, hash=False, default=None)
|
|
42
77
|
|
|
43
78
|
@property
|
|
44
79
|
def ai_client(self):
|
|
@@ -47,9 +82,10 @@ class Loki(WorkspaceService):
|
|
|
47
82
|
return self._ai_client
|
|
48
83
|
|
|
49
84
|
def make_aiclient(self):
    """
    Create an OpenAI-compatible client aimed at this workspace's Model Serving API.

    The base URL is the workspace host (trailing slashes removed) followed by
    ``/serving-endpoints``; the API key is the workspace's current token.
    """
    workspace = self.workspace
    endpoint_root = workspace.host.rstrip("/") + "/serving-endpoints"
    return make_ai_client(
        api_key=workspace.current_token(),
        base_url=endpoint_root,
    )
|
|
54
90
|
|
|
55
91
|
def ask(
|
|
@@ -60,55 +96,25 @@ class Loki(WorkspaceService):
|
|
|
60
96
|
max_tokens: int = 5000,
|
|
61
97
|
temperature: Optional[float] = None,
|
|
62
98
|
extra_messages: Optional[List[Dict[str, str]]] = None,
|
|
63
|
-
**kwargs,
|
|
99
|
+
**kwargs: Any,
|
|
64
100
|
) -> str:
|
|
65
101
|
"""
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
Parameters
|
|
73
|
-
----------
|
|
74
|
-
command:
|
|
75
|
-
The user prompt text to send.
|
|
76
|
-
system:
|
|
77
|
-
Optional system instruction (prepended as the first message).
|
|
78
|
-
max_tokens:
|
|
79
|
-
Upper bound on the number of tokens generated in the response.
|
|
80
|
-
temperature:
|
|
81
|
-
Optional sampling temperature. If None, the client/model default is used.
|
|
82
|
-
extra_messages:
|
|
83
|
-
Optional list of additional chat messages to insert before the user prompt.
|
|
84
|
-
Each item should be a dict like {"role": "...", "content": "..."}.
|
|
85
|
-
Useful for few-shot examples or carrying prior context.
|
|
86
|
-
**kwargs:
|
|
87
|
-
Any additional parameters forwarded directly to
|
|
88
|
-
`chat.completions.create(...)` (e.g. top_p, presence_penalty, etc.).
|
|
89
|
-
|
|
90
|
-
Returns
|
|
91
|
-
-------
|
|
92
|
-
str
|
|
93
|
-
The assistant message content (empty string if the API returns no content).
|
|
94
|
-
|
|
95
|
-
Raises
|
|
96
|
-
------
|
|
97
|
-
Exception
|
|
98
|
-
Propagates any exceptions raised by the OpenAI client (HTTP errors,
|
|
99
|
-
auth errors, invalid request errors, timeouts, etc.).
|
|
102
|
+
Stateless single call to the model.
|
|
103
|
+
|
|
104
|
+
NOTE (Gemini constraint):
|
|
105
|
+
- Provide at most ONE system prompt (i.e., a single system message).
|
|
106
|
+
- Do not pass additional messages with role="system".
|
|
100
107
|
"""
|
|
101
108
|
messages: List[Dict[str, str]] = []
|
|
102
|
-
|
|
103
109
|
if system:
|
|
104
110
|
messages.append({"role": "system", "content": system})
|
|
105
111
|
|
|
106
112
|
if extra_messages:
|
|
113
|
+
# IMPORTANT: caller must not include any "system" roles here for Gemini models
|
|
107
114
|
messages.extend(extra_messages)
|
|
108
115
|
|
|
109
116
|
messages.append({"role": "user", "content": command})
|
|
110
117
|
|
|
111
|
-
# Build params cleanly (only include temperature if provided)
|
|
112
118
|
params: Dict[str, Any] = dict(
|
|
113
119
|
model=self.model,
|
|
114
120
|
messages=messages,
|
|
@@ -120,3 +126,249 @@ class Loki(WorkspaceService):
|
|
|
120
126
|
|
|
121
127
|
resp = self.ai_client.chat.completions.create(**params)
|
|
122
128
|
return resp.choices[0].message.content or ""
|
|
129
|
+
|
|
130
|
+
def new_trading_chat(
    self,
    *,
    system_prompt: str = DEFAULT_TRADING_SYSTEM_PROMPT,
    max_context_turns: int = 20,
    max_context_chars: int = 120_000,
) -> "TradingChatSession":
    """
    Start a fresh stateful trading-analytics chat that calls back into this Loki.

    Args:
        system_prompt: Base system instructions used for every turn of the session.
        max_context_turns: Turn budget forwarded to the session.
        max_context_chars: Character budget forwarded to the session.

    Returns:
        A new ``TradingChatSession`` bound to ``self``.
    """
    budgets = {
        "max_context_turns": max_context_turns,
        "max_context_chars": max_context_chars,
    }
    return TradingChatSession(loki=self, system_prompt=system_prompt, **budgets)
|
|
143
|
+
|
|
144
|
+
def new_sql_chat(
    self,
    *,
    system_prompt: str = DEFAULT_SQL_SYSTEM_PROMPT,
    max_context_turns: int = 20,
    max_context_chars: int = 120_000,
) -> "SqlChatSession":
    """
    Start a fresh stateful SQL-generation chat that calls back into this Loki.

    Args:
        system_prompt: Base system instructions used for every request of the session.
        max_context_turns: Turn budget forwarded to the session.
        max_context_chars: Character budget forwarded to the session.

    Returns:
        A new ``SqlChatSession`` bound to ``self``.
    """
    budgets = {
        "max_context_turns": max_context_turns,
        "max_context_chars": max_context_chars,
    }
    return SqlChatSession(loki=self, system_prompt=system_prompt, **budgets)
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
@dataclass
|
|
160
|
+
class _BaseChatSession:
|
|
161
|
+
"""
|
|
162
|
+
Stateful session that maintains history + injected context blocks.
|
|
163
|
+
|
|
164
|
+
Gemini constraint:
|
|
165
|
+
- We must fold ALL system content into one system string.
|
|
166
|
+
- Therefore summary/context_blocks are concatenated into the system prompt.
|
|
167
|
+
"""
|
|
168
|
+
loki: Loki
|
|
169
|
+
system_prompt: str
|
|
170
|
+
|
|
171
|
+
history: List[Dict[str, str]] = field(default_factory=list)
|
|
172
|
+
summary: Optional[str] = None
|
|
173
|
+
context_blocks: List[str] = field(default_factory=list)
|
|
174
|
+
|
|
175
|
+
max_context_turns: int = 20
|
|
176
|
+
max_context_chars: int = 120_000
|
|
177
|
+
|
|
178
|
+
def reset(self) -> None:
|
|
179
|
+
self.history.clear()
|
|
180
|
+
self.summary = None
|
|
181
|
+
self.context_blocks.clear()
|
|
182
|
+
|
|
183
|
+
def add_context(self, title: str, payload: Union[str, Dict[str, Any], List[Any]]) -> None:
|
|
184
|
+
if isinstance(payload, str):
|
|
185
|
+
payload_str = payload
|
|
186
|
+
else:
|
|
187
|
+
payload_str = json.dumps(payload, ensure_ascii=False, indent=2)
|
|
188
|
+
|
|
189
|
+
self.context_blocks.append(f"[Context: {title}]\n{payload_str}".strip())
|
|
190
|
+
self._trim()
|
|
191
|
+
|
|
192
|
+
def _estimate_chars(self, msgs: List[Dict[str, str]]) -> int:
|
|
193
|
+
return sum(len(m.get("content", "")) for m in msgs)
|
|
194
|
+
|
|
195
|
+
def _build_system(self, extra_system: Optional[str] = None) -> str:
|
|
196
|
+
parts: List[str] = [self.system_prompt.strip()]
|
|
197
|
+
if extra_system:
|
|
198
|
+
parts.append(extra_system.strip())
|
|
199
|
+
if self.summary:
|
|
200
|
+
parts.append(f"[ConversationSummary]\n{self.summary}".strip())
|
|
201
|
+
if self.context_blocks:
|
|
202
|
+
parts.append("\n\n".join(self.context_blocks).strip())
|
|
203
|
+
return "\n\n".join(p for p in parts if p)
|
|
204
|
+
|
|
205
|
+
def _trim(self) -> None:
|
|
206
|
+
# Turn trim (keep last N turns => N*2 messages)
|
|
207
|
+
if self.max_context_turns > 0:
|
|
208
|
+
max_msgs = self.max_context_turns * 2
|
|
209
|
+
if len(self.history) > max_msgs:
|
|
210
|
+
self.history = self.history[-max_msgs:]
|
|
211
|
+
|
|
212
|
+
# Char trim: shrink history first, then context blocks if needed
|
|
213
|
+
def total_chars() -> int:
|
|
214
|
+
sys_len = len(self._build_system())
|
|
215
|
+
return sys_len + self._estimate_chars(self.history)
|
|
216
|
+
|
|
217
|
+
while total_chars() > self.max_context_chars and self.history:
|
|
218
|
+
self.history = self.history[1:]
|
|
219
|
+
|
|
220
|
+
while total_chars() > self.max_context_chars and self.context_blocks:
|
|
221
|
+
self.context_blocks = self.context_blocks[1:]
|
|
222
|
+
|
|
223
|
+
|
|
224
|
+
@dataclass
class TradingChatSession(_BaseChatSession):
    """
    Stateful commodity-trading analytics chat.

    With ``structured=True`` the model is asked for a fixed JSON contract and the
    parsed dict is returned when the reply honours it.
    """

    def chat(
        self,
        user_text: str,
        *,
        structured: bool = True,
        max_tokens: int = 12000,
        temperature: Optional[float] = None,
        **kwargs: Any,
    ) -> Union[str, Dict[str, Any]]:
        """
        Send one user turn, record the exchange in history and return the reply.

        Returns the parsed JSON dict when ``structured`` is set and the reply parses
        as the contract; otherwise the raw assistant text.
        """
        self._trim()

        json_contract = (
            "Respond ONLY as valid JSON with keys: "
            "final_answer (string), assumptions (array of strings), data_needed (array of strings), "
            "sql (string or null), python (string or null). "
            "No markdown. No extra keys."
        ) if structured else None

        reply = self.loki.ask(
            user_text,
            system=self._build_system(extra_system=json_contract),
            extra_messages=self.history,  # user/assistant turns only, never role="system"
            max_tokens=max_tokens,
            temperature=temperature,
            **kwargs,
        )

        self.history.extend((
            {"role": "user", "content": user_text},
            {"role": "assistant", "content": reply},
        ))
        self._trim()

        if not structured:
            return reply
        parsed = _try_parse_json_object(reply)
        return reply if parsed is None else parsed
|
|
272
|
+
|
|
273
|
+
|
|
274
|
+
@dataclass
class SqlChatSession(_BaseChatSession):
    """
    Conversational session that turns natural-language requests into Databricks SQL.

    All instructions travel in a single system message; with ``sql_only`` an extra
    SQL-only reminder is folded into that same message.
    """

    def generate_sql(
        self,
        request: str,
        *,
        max_tokens: int = 12000,
        temperature: Optional[float] = None,
        sql_only: bool = True,
        **kwargs: Any,
    ) -> str:
        """
        Ask the model for SQL answering ``request`` and record the exchange.

        The whitespace-trimmed reply is stored in history; the returned value
        additionally has a surrounding markdown fence removed.
        """
        self._trim()

        reminder = (
            "Reminder: Output ONLY SQL. "
            "If ambiguity exists, use SQL comments (-- TODO ...) and placeholders, but still output SQL only."
        ) if sql_only else None

        reply = self.loki.ask(
            request,
            system=self._build_system(extra_system=reminder),
            extra_messages=self.history,  # user/assistant turns only, never role="system"
            max_tokens=max_tokens,
            temperature=temperature,
            **kwargs,
        ).strip()

        self.history.extend((
            {"role": "user", "content": request},
            {"role": "assistant", "content": reply},
        ))
        self._trim()

        return _strip_sql_fences(reply)
|
|
316
|
+
|
|
317
|
+
|
|
318
|
+
def _strip_sql_fences(text: str) -> str:
|
|
319
|
+
t = text.strip()
|
|
320
|
+
if t.startswith("```"):
|
|
321
|
+
lines = t.splitlines()
|
|
322
|
+
lines = lines[1:] # drop ``` or ```sql
|
|
323
|
+
if lines and lines[-1].strip().startswith("```"):
|
|
324
|
+
lines = lines[:-1]
|
|
325
|
+
return "\n".join(lines).strip()
|
|
326
|
+
return t
|
|
327
|
+
|
|
328
|
+
|
|
329
|
+
def _strip_markdown_fences(text: str) -> str:
|
|
330
|
+
"""
|
|
331
|
+
Remove ```lang ... ``` fences if present.
|
|
332
|
+
Keeps inner content unchanged.
|
|
333
|
+
"""
|
|
334
|
+
t = text.strip()
|
|
335
|
+
if not t.startswith("```"):
|
|
336
|
+
return t
|
|
337
|
+
|
|
338
|
+
lines = t.splitlines()
|
|
339
|
+
if not lines:
|
|
340
|
+
return t
|
|
341
|
+
|
|
342
|
+
# Drop first line: ``` or ```json
|
|
343
|
+
lines = lines[1:]
|
|
344
|
+
|
|
345
|
+
# Drop last line if it's ```
|
|
346
|
+
if lines and lines[-1].strip().startswith("```"):
|
|
347
|
+
lines = lines[:-1]
|
|
348
|
+
|
|
349
|
+
return "\n".join(lines).strip()
|
|
350
|
+
|
|
351
|
+
|
|
352
|
+
def _try_parse_json_object(text: str) -> Optional[Dict[str, Any]]:
|
|
353
|
+
t = _strip_markdown_fences(text).strip()
|
|
354
|
+
|
|
355
|
+
# Best effort extraction if there's extra junk around JSON
|
|
356
|
+
if not t.startswith("{"):
|
|
357
|
+
start = t.find("{")
|
|
358
|
+
end = t.rfind("}")
|
|
359
|
+
if start != -1 and end != -1 and end > start:
|
|
360
|
+
t = t[start : end + 1]
|
|
361
|
+
|
|
362
|
+
try:
|
|
363
|
+
obj = json.loads(t)
|
|
364
|
+
except Exception:
|
|
365
|
+
return None
|
|
366
|
+
|
|
367
|
+
if not isinstance(obj, dict):
|
|
368
|
+
return None
|
|
369
|
+
|
|
370
|
+
required = {"final_answer", "assumptions", "data_needed", "sql", "python"}
|
|
371
|
+
if not required.issubset(set(obj.keys())):
|
|
372
|
+
return None
|
|
373
|
+
|
|
374
|
+
return obj
|
|
@@ -16,6 +16,7 @@ from threading import Thread
|
|
|
16
16
|
from types import ModuleType
|
|
17
17
|
from typing import TYPE_CHECKING, Optional, Any, Callable, List, Dict, Union, Iterable, Tuple
|
|
18
18
|
|
|
19
|
+
from .exceptions import CommandAborted
|
|
19
20
|
from ...libs.databrickslib import databricks_sdk
|
|
20
21
|
from ...pyutils.exceptions import raise_parsed_traceback
|
|
21
22
|
from ...pyutils.expiring_dict import ExpiringDict
|
|
@@ -110,16 +111,12 @@ class ExecutionContext:
|
|
|
110
111
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
|
111
112
|
"""Exit the context manager and close the remote context if created."""
|
|
112
113
|
if not self._was_connected:
|
|
113
|
-
self.close()
|
|
114
|
+
self.close(wait=False)
|
|
114
115
|
self.cluster.__exit__(exc_type, exc_val=exc_val, exc_tb=exc_tb)
|
|
115
116
|
|
|
116
117
|
def __del__(self):
|
|
117
118
|
"""Best-effort cleanup for the remote execution context."""
|
|
118
|
-
|
|
119
|
-
try:
|
|
120
|
-
Thread(target=self.close).start()
|
|
121
|
-
except BaseException:
|
|
122
|
-
pass
|
|
119
|
+
self.close(wait=False)
|
|
123
120
|
|
|
124
121
|
@property
|
|
125
122
|
def remote_metadata(self) -> RemoteMetadata:
|
|
@@ -180,7 +177,7 @@ print(json.dumps(meta))"""
|
|
|
180
177
|
"""
|
|
181
178
|
return self.cluster.workspace.sdk()
|
|
182
179
|
|
|
183
|
-
def
|
|
180
|
+
def create(
|
|
184
181
|
self,
|
|
185
182
|
language: "Language",
|
|
186
183
|
) -> any:
|
|
@@ -197,15 +194,17 @@ print(json.dumps(meta))"""
|
|
|
197
194
|
self.cluster
|
|
198
195
|
)
|
|
199
196
|
|
|
197
|
+
client = self._workspace_client().command_execution
|
|
198
|
+
|
|
200
199
|
try:
|
|
201
|
-
created =
|
|
200
|
+
created = client.create_and_wait(
|
|
202
201
|
cluster_id=self.cluster.cluster_id,
|
|
203
202
|
language=language,
|
|
204
203
|
)
|
|
205
204
|
except:
|
|
206
205
|
self.cluster.ensure_running()
|
|
207
206
|
|
|
208
|
-
created =
|
|
207
|
+
created = client.create_and_wait(
|
|
209
208
|
cluster_id=self.cluster.cluster_id,
|
|
210
209
|
language=language,
|
|
211
210
|
)
|
|
@@ -217,42 +216,38 @@ print(json.dumps(meta))"""
|
|
|
217
216
|
|
|
218
217
|
created = getattr(created, "response", created)
|
|
219
218
|
|
|
220
|
-
|
|
219
|
+
self.context_id = created.id
|
|
220
|
+
|
|
221
|
+
return self
|
|
221
222
|
|
|
222
223
|
def connect(
    self,
    language: Optional["Language"] = None,
    reset: bool = False
) -> "ExecutionContext":
    """Open, reuse, or recreate the remote command execution context.

    Args:
        language: Language for a newly created context. Falls back to
            ``self.language`` and then to ``Language.PYTHON``.
        reset: When a context is already open, tear it down with a
            non-blocking ``close(wait=False)`` and create a fresh one instead
            of reusing it.

    Returns:
        ``self`` when the open context is reused; otherwise the result of
        ``self.create(language=...)``.
    """
    already_open = self.context_id is not None
    if already_open and not reset:
        return self
    if already_open:
        self.close(wait=False)

    chosen = language or self.language
    if chosen is None:
        chosen = Language.PYTHON
    return self.create(language=chosen)
|
|
247
249
|
|
|
248
|
-
|
|
249
|
-
LOGGER.info(
|
|
250
|
-
"Opened execution context for %s",
|
|
251
|
-
self
|
|
252
|
-
)
|
|
253
|
-
return self
|
|
254
|
-
|
|
255
|
-
def close(self) -> None:
|
|
250
|
+
def close(self, wait: bool = True) -> None:
|
|
256
251
|
"""Destroy the remote command execution context if it exists.
|
|
257
252
|
|
|
258
253
|
Returns:
|
|
@@ -261,12 +256,23 @@ print(json.dumps(meta))"""
|
|
|
261
256
|
if not self.context_id:
|
|
262
257
|
return
|
|
263
258
|
|
|
259
|
+
client = self._workspace_client()
|
|
260
|
+
|
|
264
261
|
try:
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
|
|
262
|
+
if wait:
|
|
263
|
+
client.command_execution.destroy(
|
|
264
|
+
cluster_id=self.cluster.cluster_id,
|
|
265
|
+
context_id=self.context_id,
|
|
266
|
+
)
|
|
267
|
+
else:
|
|
268
|
+
Thread(
|
|
269
|
+
target=client.command_execution.destroy,
|
|
270
|
+
kwargs={
|
|
271
|
+
"cluster_id": self.cluster.cluster_id,
|
|
272
|
+
"context_id": self.context_id,
|
|
273
|
+
}
|
|
274
|
+
).start()
|
|
275
|
+
except BaseException:
|
|
270
276
|
# non-fatal: context cleanup best-effort
|
|
271
277
|
pass
|
|
272
278
|
finally:
|
|
@@ -465,7 +471,18 @@ print(json.dumps(meta))"""
|
|
|
465
471
|
)
|
|
466
472
|
|
|
467
473
|
try:
|
|
468
|
-
return self._decode_result(
|
|
474
|
+
return self._decode_result(
|
|
475
|
+
result,
|
|
476
|
+
result_tag=result_tag,
|
|
477
|
+
print_stdout=print_stdout
|
|
478
|
+
)
|
|
479
|
+
except CommandAborted:
|
|
480
|
+
return self.connect(language=self.language, reset=True).execute_command(
|
|
481
|
+
command=command,
|
|
482
|
+
timeout=timeout,
|
|
483
|
+
result_tag=result_tag,
|
|
484
|
+
print_stdout=print_stdout
|
|
485
|
+
)
|
|
469
486
|
except ModuleNotFoundError as remote_module_error:
|
|
470
487
|
_MOD_NOT_FOUND_RE = re.compile(r"No module named ['\"]([^'\"]+)['\"]")
|
|
471
488
|
module_name = _MOD_NOT_FOUND_RE.search(str(remote_module_error))
|
|
@@ -660,6 +677,9 @@ with zipfile.ZipFile(buf, "r") as zf:
|
|
|
660
677
|
if res.result_type == ResultType.ERROR:
|
|
661
678
|
message = res.cause or "Command execution failed"
|
|
662
679
|
|
|
680
|
+
if "client terminated the session" in message:
|
|
681
|
+
raise CommandAborted(message)
|
|
682
|
+
|
|
663
683
|
if self.language == Language.PYTHON:
|
|
664
684
|
raise_parsed_traceback(message)
|
|
665
685
|
|
|
@@ -668,6 +688,7 @@ with zipfile.ZipFile(buf, "r") as zf:
|
|
|
668
688
|
or getattr(res, "stack_trace", None)
|
|
669
689
|
or getattr(res, "traceback", None)
|
|
670
690
|
)
|
|
691
|
+
|
|
671
692
|
if remote_tb:
|
|
672
693
|
message = f"{message}\n{remote_tb}"
|
|
673
694
|
|
|
@@ -30,7 +30,7 @@ from ...types.cast.registry import convert, register_converter
|
|
|
30
30
|
from ...types.file_format import ExcelFileFormat
|
|
31
31
|
|
|
32
32
|
if databricks is not None:
|
|
33
|
-
from databricks.sdk.service.catalog import VolumeType,
|
|
33
|
+
from databricks.sdk.service.catalog import VolumeType, VolumeInfo
|
|
34
34
|
from databricks.sdk.service.workspace import ObjectType
|
|
35
35
|
from databricks.sdk.errors.platform import (
|
|
36
36
|
NotFound,
|
|
@@ -1236,6 +1236,8 @@ class DatabricksPath:
|
|
|
1236
1236
|
self,
|
|
1237
1237
|
operation: Optional["PathOperation"] = None
|
|
1238
1238
|
):
|
|
1239
|
+
from databricks.sdk.service.catalog import PathOperation
|
|
1240
|
+
|
|
1239
1241
|
if self.kind != DatabricksPathKind.VOLUME:
|
|
1240
1242
|
raise ValueError(f"Cannot generate temporary credentials for {repr(self)}")
|
|
1241
1243
|
|
yggdrasil/exceptions.py
ADDED
yggdrasil/version.py
CHANGED
|
@@ -1 +1 @@
|
|
|
1
|
-
__version__ = "0.1.
|
|
1
|
+
__version__ = "0.1.56"
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|