ygg 0.1.56__py3-none-any.whl → 0.1.60__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {ygg-0.1.56.dist-info → ygg-0.1.60.dist-info}/METADATA +1 -1
- ygg-0.1.60.dist-info/RECORD +74 -0
- {ygg-0.1.56.dist-info → ygg-0.1.60.dist-info}/WHEEL +1 -1
- yggdrasil/ai/__init__.py +2 -0
- yggdrasil/ai/session.py +89 -0
- yggdrasil/ai/sql_session.py +310 -0
- yggdrasil/databricks/__init__.py +0 -3
- yggdrasil/databricks/compute/cluster.py +68 -113
- yggdrasil/databricks/compute/command_execution.py +674 -0
- yggdrasil/databricks/compute/exceptions.py +7 -2
- yggdrasil/databricks/compute/execution_context.py +465 -277
- yggdrasil/databricks/compute/remote.py +4 -14
- yggdrasil/databricks/exceptions.py +10 -0
- yggdrasil/databricks/sql/__init__.py +0 -4
- yggdrasil/databricks/sql/engine.py +161 -173
- yggdrasil/databricks/sql/exceptions.py +9 -1
- yggdrasil/databricks/sql/statement_result.py +108 -120
- yggdrasil/databricks/sql/warehouse.py +331 -92
- yggdrasil/databricks/workspaces/io.py +92 -9
- yggdrasil/databricks/workspaces/path.py +120 -74
- yggdrasil/databricks/workspaces/workspace.py +212 -68
- yggdrasil/libs/databrickslib.py +23 -18
- yggdrasil/libs/extensions/spark_extensions.py +1 -1
- yggdrasil/libs/pandaslib.py +15 -6
- yggdrasil/libs/polarslib.py +49 -13
- yggdrasil/pyutils/__init__.py +1 -0
- yggdrasil/pyutils/callable_serde.py +12 -19
- yggdrasil/pyutils/exceptions.py +16 -0
- yggdrasil/pyutils/mimetypes.py +0 -0
- yggdrasil/pyutils/python_env.py +13 -12
- yggdrasil/pyutils/waiting_config.py +171 -0
- yggdrasil/types/cast/arrow_cast.py +3 -0
- yggdrasil/types/cast/pandas_cast.py +157 -169
- yggdrasil/types/cast/polars_cast.py +11 -43
- yggdrasil/types/dummy_class.py +81 -0
- yggdrasil/version.py +1 -1
- ygg-0.1.56.dist-info/RECORD +0 -68
- yggdrasil/databricks/ai/__init__.py +0 -1
- yggdrasil/databricks/ai/loki.py +0 -374
- {ygg-0.1.56.dist-info → ygg-0.1.60.dist-info}/entry_points.txt +0 -0
- {ygg-0.1.56.dist-info → ygg-0.1.60.dist-info}/licenses/LICENSE +0 -0
- {ygg-0.1.56.dist-info → ygg-0.1.60.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,74 @@
|
|
|
1
|
+
ygg-0.1.60.dist-info/licenses/LICENSE,sha256=HrhfyXIkWY2tGFK11kg7vPCqhgh5DcxleloqdhrpyMY,11558
|
|
2
|
+
yggdrasil/__init__.py,sha256=4-ghPak2S6zfMqmnlxW2GCgPb5s79znpKa2hGEGXcE4,24
|
|
3
|
+
yggdrasil/exceptions.py,sha256=NEpbDFn-8ZRsLiEgJicCwrTHNMWAGtdrTJzosfAeVJo,82
|
|
4
|
+
yggdrasil/version.py,sha256=RzpPAn4AEKR5U8Ey0m3Oy_SvSsWT9yeRqhNeTwbK0ks,22
|
|
5
|
+
yggdrasil/ai/__init__.py,sha256=YEOVsyuvEOvPaZT8XN9xNysS_WOpHTbKgXgnA8up7x0,52
|
|
6
|
+
yggdrasil/ai/session.py,sha256=X4btr4OTPLzk1rZx0pZLMJ6Gni1DfEMghAmx9qI1qdE,2579
|
|
7
|
+
yggdrasil/ai/sql_session.py,sha256=n92tQjHUBIey6c3EJProiEEwfAtQm07Dtmei4WXzeG0,10812
|
|
8
|
+
yggdrasil/databricks/__init__.py,sha256=0GRBP930ManOvyo-Y5E7bz7F2msnvU677OH6rxzPwd8,87
|
|
9
|
+
yggdrasil/databricks/exceptions.py,sha256=-ZULt0wD5_Rxww11nk4Z46DvS5j18RdKR5ISmbQfUQA,142
|
|
10
|
+
yggdrasil/databricks/compute/__init__.py,sha256=NvdzmaJSNYY1uJthv1hHdBuNu3bD_-Z65DWnaJt9yXg,289
|
|
11
|
+
yggdrasil/databricks/compute/cluster.py,sha256=5frvxXyeJdAdba7vsGqk2dgfxZEveWsswgrCXZIOOSc,42145
|
|
12
|
+
yggdrasil/databricks/compute/command_execution.py,sha256=osLykTE8S_2gle3XCUiEz8nB4dLKywkiJU4DdEQPT5g,20888
|
|
13
|
+
yggdrasil/databricks/compute/exceptions.py,sha256=OgiAzicmSaAyMzPoeQASdHlvOHnfaNwOt-081_XxCHQ,340
|
|
14
|
+
yggdrasil/databricks/compute/execution_context.py,sha256=MoZwuIHIJJaqo--xJcYIGVpA-eOPFaAjOTbLeVnUzrg,29414
|
|
15
|
+
yggdrasil/databricks/compute/remote.py,sha256=sF99i7GXZcC0GiNgO9VO0I26rFbrtnDhK9vrC2gajuw,2623
|
|
16
|
+
yggdrasil/databricks/jobs/__init__.py,sha256=snxGSJb0M5I39v0y3IR-uEeSlZR248cQ_4DJ1sYs-h8,154
|
|
17
|
+
yggdrasil/databricks/jobs/config.py,sha256=9LGeHD04hbfy0xt8_6oobC4moKJh4_DTjZiK4Q2Tqjk,11557
|
|
18
|
+
yggdrasil/databricks/sql/__init__.py,sha256=PetgRp1jEj5K3TgN09FwNUVjVN8YYuGq0cDIOTqsbns,144
|
|
19
|
+
yggdrasil/databricks/sql/engine.py,sha256=6PVxrO78UxDD6jHBnDjKV2KuZ9JpuvwVckQ_mjsaeKw,49558
|
|
20
|
+
yggdrasil/databricks/sql/exceptions.py,sha256=srMR3Y9LQm45rkyxfyCgpgcoGtRRvGKWBEoUHf4kxsg,1762
|
|
21
|
+
yggdrasil/databricks/sql/statement_result.py,sha256=01DzFX1bGDIGHj0OW2ngfVVJ1w1KHlZEfAI934E35CU,15549
|
|
22
|
+
yggdrasil/databricks/sql/types.py,sha256=5G-BM9_eOsRKEMzeDTWUsWW5g4Idvs-czVCpOCrMhdA,6412
|
|
23
|
+
yggdrasil/databricks/sql/warehouse.py,sha256=W045PMLgZdt7f5w8aWIIX-vSLa5GE4V8yFffyrwZSOQ,18689
|
|
24
|
+
yggdrasil/databricks/workspaces/__init__.py,sha256=dv2zotoFVhNFlTCdRq6gwf5bEzeZkOZszoNZMs0k59g,114
|
|
25
|
+
yggdrasil/databricks/workspaces/filesytem.py,sha256=Z8JXU7_XUEbw9fpTQT1avRQKi-IAP2KemXBMPkUoY4w,9805
|
|
26
|
+
yggdrasil/databricks/workspaces/io.py,sha256=IHmOwX1cWksvfunwTr03BFPqhm8cWNEkCwMvM9vhM80,35162
|
|
27
|
+
yggdrasil/databricks/workspaces/path.py,sha256=R6-RuMG7fZYBWS4wRmbw0bOIxiGYpwif47jFLWmcLGs,56950
|
|
28
|
+
yggdrasil/databricks/workspaces/path_kind.py,sha256=rhWe1ky7uPD0du0bZSv2S4fK4C5zWd7zAF3UeS2iiPU,283
|
|
29
|
+
yggdrasil/databricks/workspaces/volumes_path.py,sha256=s8CA33cG3jpMVJy5MILLlkEBcFg_qInDCF2jozLj1Fg,2431
|
|
30
|
+
yggdrasil/databricks/workspaces/workspace.py,sha256=GEUp3f15SP5lUDx0_Ujzv5QtjVQg00WJRydR4rNLdXs,30216
|
|
31
|
+
yggdrasil/dataclasses/__init__.py,sha256=_RkhfF3KC1eSORby1dzvBXQ0-UGG3u6wyUQWX2jq1Pc,108
|
|
32
|
+
yggdrasil/dataclasses/dataclass.py,sha256=LxrCjwvmBnb8yRI_N-c31RHHxB4XoJPixmKg9iBIuaI,1148
|
|
33
|
+
yggdrasil/libs/__init__.py,sha256=zdC9OU0Xy36CLY9mg2drxN6S7isPR8aTLzJA6xVIeLE,91
|
|
34
|
+
yggdrasil/libs/databrickslib.py,sha256=t_0b_3iCGFPjBrJaIOvNzSEn5pjZBTbY_fOcDHp6qx8,1135
|
|
35
|
+
yggdrasil/libs/pandaslib.py,sha256=_U4sdFvLAFD16_65RG-RFmcx4c3fvVnALESFaAlT71M,887
|
|
36
|
+
yggdrasil/libs/polarslib.py,sha256=WnnERtMTl__ZPidcZkoV7mb8-c680zcAnJgzAoD3ZE8,1437
|
|
37
|
+
yggdrasil/libs/sparklib.py,sha256=FQ3W1iz2EIpQreorOiQuFt15rdhq2QhGEAWp8Zrbl9A,10177
|
|
38
|
+
yggdrasil/libs/extensions/__init__.py,sha256=mcXW5Li3Cbprbs4Ci-b5A0Ju0wmLcfvEiFusTx6xNjU,117
|
|
39
|
+
yggdrasil/libs/extensions/polars_extensions.py,sha256=RTkGi8llhPJjX7x9egix7-yXWo2X24zIAPSKXV37SSA,12397
|
|
40
|
+
yggdrasil/libs/extensions/spark_extensions.py,sha256=ESap-WP4A03bv3kvQK2S9OdplQC79YSpTYRJYJkH8EA,16749
|
|
41
|
+
yggdrasil/pyutils/__init__.py,sha256=AOCLAn9ogFYjXxA6AiUudgNQlzM4A4qQIqJ8MJsm5i0,251
|
|
42
|
+
yggdrasil/pyutils/callable_serde.py,sha256=rGCM0gSejwDDS2xbQimH1q_PggAdSpUsa65XKXMj9DI,22766
|
|
43
|
+
yggdrasil/pyutils/equality.py,sha256=Xyf8D1dLUCm3spDEir8Zyj7O4US_fBJwEylJCfJ9slI,3080
|
|
44
|
+
yggdrasil/pyutils/exceptions.py,sha256=1c0xxFvGML5gkDPGzD_Tgw1ff9bGMVygH8ASgeoII2E,3889
|
|
45
|
+
yggdrasil/pyutils/expiring_dict.py,sha256=pr2u25LGwPVbLfsLptiHGovUtYRRo0AMjaJtCtJl7nQ,8477
|
|
46
|
+
yggdrasil/pyutils/mimetypes.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
47
|
+
yggdrasil/pyutils/modules.py,sha256=B7IP99YqUMW6-DIESFzBx8-09V1d0a8qrIJUDFhhL2g,11424
|
|
48
|
+
yggdrasil/pyutils/parallel.py,sha256=ubuq2m9dJzWYUyKCga4Y_9bpaeMYUrleYxdp49CHr44,6781
|
|
49
|
+
yggdrasil/pyutils/python_env.py,sha256=OFvM0wxVzB1iHC4BGra0tLE3sZL1ZaIr5j9dtDeLaiU,51098
|
|
50
|
+
yggdrasil/pyutils/retry.py,sha256=gXBtn1DdmIYIUmGKOUr8-SUT7MOu97LykN2YR4uocgc,11917
|
|
51
|
+
yggdrasil/pyutils/waiting_config.py,sha256=WiMOiKyGR5iKr83YK4dljn7OCaDpxXMUx8cz-bUNGMg,6255
|
|
52
|
+
yggdrasil/requests/__init__.py,sha256=dMesyzq97_DmI765x0TwaDPEfsxFtgGNgchk8LvEN-o,103
|
|
53
|
+
yggdrasil/requests/msal.py,sha256=s2GCyzbgFdgdlJ1JqMrZ4qYVbmoG46-ZOTcaVQhZ-sQ,9220
|
|
54
|
+
yggdrasil/requests/session.py,sha256=SLnrgHY0Lby7ZxclRFUjHdfM8euN_8bSQEWl7TkJY2U,1461
|
|
55
|
+
yggdrasil/types/__init__.py,sha256=CrLiDeYNM9fO975sE5ufeVKcy7Ca702IsaG2Pk8T3YU,139
|
|
56
|
+
yggdrasil/types/dummy_class.py,sha256=XXM3_ljL4XfY5LeF-WTj-myqHaKAUmWZ23cPDrXAnBM,2327
|
|
57
|
+
yggdrasil/types/file_format.py,sha256=yqAadZ5z6CrctsQO0ZmEY7eGXLbhBUnvvNOwkPSk0GU,133
|
|
58
|
+
yggdrasil/types/python_arrow.py,sha256=mOhyecAxa5u8JWsyTO26OMOWimHHgwLKWlkNSAyIVas,25636
|
|
59
|
+
yggdrasil/types/python_defaults.py,sha256=GO3hZBZcwRHs9qiXes75y8l5X00kZHTfEC7el_x73uw,10184
|
|
60
|
+
yggdrasil/types/cast/__init__.py,sha256=Oft3pTs2bRM5hT7YqJAuOKTYYk-SACLaMOXUVdafy_I,311
|
|
61
|
+
yggdrasil/types/cast/arrow_cast.py,sha256=IZstOcHjLKPy62TFGgjMSW3ttPGt3hMi6RmDw-92T0E,41623
|
|
62
|
+
yggdrasil/types/cast/cast_options.py,sha256=nDaEvCCs7TBamhTWyDrYf3LVaBWzioIP2Q5_LXrChF4,15532
|
|
63
|
+
yggdrasil/types/cast/pandas_cast.py,sha256=6PaHgQyq06XM3_lebmB4PUlSHATL_0l8GFXMFvZGwvc,8890
|
|
64
|
+
yggdrasil/types/cast/polars_cast.py,sha256=7qs8QC0Kn3GgaKfrzqr-O-2_9uy9ZUrcJV4HIaGwoUM,27584
|
|
65
|
+
yggdrasil/types/cast/polars_pandas_cast.py,sha256=CS0P7teVv15IdX5g7v40RfkH1VMg6b-HM0V_gOfacm8,5071
|
|
66
|
+
yggdrasil/types/cast/registry.py,sha256=OOqIfbIjPH-a3figvu-zTvEtUDTEWhe2xIl3cCA4PRM,20941
|
|
67
|
+
yggdrasil/types/cast/spark_cast.py,sha256=_KAsl1DqmKMSfWxqhVE7gosjYdgiL1C5bDQv6eP3HtA,24926
|
|
68
|
+
yggdrasil/types/cast/spark_pandas_cast.py,sha256=BuTiWrdCANZCdD_p2MAytqm74eq-rdRXd-LGojBRrfU,5023
|
|
69
|
+
yggdrasil/types/cast/spark_polars_cast.py,sha256=btmZNHXn2NSt3fUuB4xg7coaE0RezIBdZD92H8NK0Jw,9073
|
|
70
|
+
ygg-0.1.60.dist-info/METADATA,sha256=LBZYw5kRHouaxOl7x_dskRBUpH-XXDk-XaFrSLGKrg0,18528
|
|
71
|
+
ygg-0.1.60.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
|
|
72
|
+
ygg-0.1.60.dist-info/entry_points.txt,sha256=6q-vpWG3kvw2dhctQ0LALdatoeefkN855Ev02I1dKGY,70
|
|
73
|
+
ygg-0.1.60.dist-info/top_level.txt,sha256=iBe9Kk4VIVbLpgv_p8OZUIfxgj4dgJ5wBg6vO3rigso,10
|
|
74
|
+
ygg-0.1.60.dist-info/RECORD,,
|
yggdrasil/ai/__init__.py
ADDED
yggdrasil/ai/session.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
1
|
+
# session.py
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
from abc import ABC, abstractmethod
|
|
5
|
+
from dataclasses import dataclass, field
|
|
6
|
+
from typing import Dict, List, Optional
|
|
7
|
+
|
|
8
|
+
from yggdrasil.types.dummy_class import DummyModuleClass
|
|
9
|
+
|
|
10
|
+
try:
    from openai import OpenAI
except ImportError:
    # ``openai`` is an optional dependency: fall back to a placeholder class.
    # Only ImportError is caught so KeyboardInterrupt/SystemExit and genuine
    # errors raised while importing an installed package are not masked.
    OpenAI = DummyModuleClass
|
|
14
|
+
|
|
15
|
+
__all__ = ["AISession"]
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
@dataclass
class AISession(ABC):
    """Base class for single-turn chat sessions against an OpenAI-compatible API.

    Subclasses provide the system prompt via :meth:`system_prompt`. This class
    builds the message stack, calls ``client.chat.completions.create`` and can
    strip a surrounding Markdown code fence from the reply.

    Attributes:
        api_key: API key passed to the ``OpenAI`` client.
        base_url: Base URL of the OpenAI-compatible endpoint.
        model: Model name sent with every request.
        client: ``OpenAI`` client created in ``__post_init__`` (not an init arg).
    """

    # Info strings that may follow an opening ``` fence. When the first line
    # inside the fence is one of these (compared stripped and lower-cased), it
    # is dropped together with the fence.
    _FENCE_LANGUAGE_TAGS = frozenset({
        "sql", "json", "jsonc", "json5",
        "postgresql", "postgres", "pgsql", "plpgsql", "psql",
        "sparksql", "spark-sql", "spark", "databricks", "hive", "hiveql",
        "mysql", "sqlite", "tsql",
        "mongodb", "mongo", "mongosh", "javascript", "js",
        "python", "py", "text", "txt", "plaintext",
    })

    api_key: str
    base_url: str

    # Gemini default (via OpenAI-compatible gateway)
    model: str = "gemini-2.5-flash"

    client: OpenAI = field(init=False)

    def __post_init__(self) -> None:
        self.client = OpenAI(api_key=self.api_key, base_url=self.base_url)

    @abstractmethod
    def system_prompt(self) -> str:
        """Return the system prompt sent as the first message of every request."""
        raise NotImplementedError

    def build_messages(
        self,
        user_prompt: str,
        *,
        context_system: Optional[str] = None,
        extra_instructions: Optional[str] = None,
    ) -> List[Dict[str, str]]:
        """Build the chat message list for one request.

        The stack is: system prompt, an optional second system message named
        ``"context"``, and the user message. All texts are stripped.

        Args:
            user_prompt: The user's request.
            context_system: Optional extra context sent as a system message.
            extra_instructions: Optional text appended to the user message
                under a ``Constraints:`` heading.

        Returns:
            A list of ``{"role": ..., "content": ...}`` dicts.
        """
        system = self.system_prompt().strip()
        msgs: List[Dict[str, str]] = [{"role": "system", "content": system}]

        if context_system:
            msgs.append({"role": "system", "name": "context", "content": context_system.strip()})

        msg = user_prompt.strip()
        if extra_instructions:
            msg = f"{msg}\n\nConstraints:\n{extra_instructions.strip()}"

        msgs.append({"role": "user", "content": msg})
        return msgs

    def chat(
        self,
        user_prompt: str,
        *,
        context_system: Optional[str] = None,
        extra_instructions: Optional[str] = None,
        temperature: float = 0.0,
        max_output_tokens: int = 320,
        strip_code_fences: bool = True,
    ) -> str:
        """Send one request and return the model's reply text.

        Args:
            user_prompt: The user's request.
            context_system: Optional context system message (see
                :meth:`build_messages`).
            extra_instructions: Optional constraints appended to the prompt.
            temperature: Sampling temperature passed to the API.
            max_output_tokens: Sent as ``max_tokens``.
            strip_code_fences: When true and the reply starts with ```, return
                only the fenced content, without a leading language tag.

        Returns:
            The stripped reply text (``""`` when the API returns no content).
        """
        messages = self.build_messages(
            user_prompt,
            context_system=context_system,
            extra_instructions=extra_instructions,
        )

        resp = self.client.chat.completions.create(
            model=self.model,
            messages=messages,
            temperature=temperature,
            max_tokens=max_output_tokens,
        )

        out = (resp.choices[0].message.content or "").strip()

        if strip_code_fences and out.startswith("```"):
            out = self._strip_code_fences(out)

        return out

    @classmethod
    def _strip_code_fences(cls, text: str) -> str:
        """Return the content of the first ``` fence in *text*, without its tag.

        A tag on its own first line (e.g. ```` ```postgresql ````) is dropped
        when it is in ``_FENCE_LANGUAGE_TAGS``. A ``sql``/``json`` tag on the
        same line as the content (e.g. ```` ```sql SELECT 1``` ````) is dropped
        only when it is not the start of a longer identifier.
        """
        inner = text.split("```", 2)[1].strip()
        if not inner:
            return inner

        first_line, _, remainder = inner.partition("\n")
        if first_line.strip().lower() in cls._FENCE_LANGUAGE_TAGS:
            return remainder.strip()

        lowered = inner.lower()
        for tag in ("sql", "json"):
            if lowered.startswith(tag):
                rest = inner[len(tag):]
                # "sqlite..." / "sql_table" are content, not a tag.
                if not rest or not (rest[0].isalnum() or rest[0] == "_"):
                    return rest.strip()

        return inner
|
|
@@ -0,0 +1,310 @@
|
|
|
1
|
+
# sql_session.py
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
import json
|
|
5
|
+
from dataclasses import dataclass, field
|
|
6
|
+
from enum import Enum
|
|
7
|
+
from typing import Dict, Iterable, List, Optional, Tuple
|
|
8
|
+
|
|
9
|
+
import pyarrow as pa
|
|
10
|
+
|
|
11
|
+
from .session import AISession
|
|
12
|
+
|
|
13
|
+
__all__ = ["SQLFlavor", "SQLAISession"]
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class SQLFlavor(str, Enum):
    """Query language a :class:`SQLAISession` asks the model to produce.

    Members subclass ``str``, so each compares equal to its plain value
    (for example ``SQLFlavor.POSTGRESQL == "postgresql"``).
    """

    # Databricks / Spark SQL dialect.
    DATABRICKS = "databricks"
    # PostgreSQL dialect.
    POSTGRESQL = "postgresql"
    # Not SQL: the model is asked for a MongoDB aggregation pipeline
    # serialized as a JSON array.
    MONGODB = "mongodb"
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def _b2s(x: object) -> str:
|
|
23
|
+
if x is None:
|
|
24
|
+
return ""
|
|
25
|
+
if isinstance(x, bytes):
|
|
26
|
+
return x.decode("utf-8", errors="ignore").strip()
|
|
27
|
+
return str(x).strip()
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def _try_parse_json_blob(s: str) -> Optional[dict]:
|
|
31
|
+
s = (s or "").strip()
|
|
32
|
+
if not s or not (s.startswith("{") and s.endswith("}")):
|
|
33
|
+
return None
|
|
34
|
+
try:
|
|
35
|
+
obj = json.loads(s)
|
|
36
|
+
return obj if isinstance(obj, dict) else None
|
|
37
|
+
except Exception:
|
|
38
|
+
return None
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def parse_qualified_name(name: str) -> Tuple[Optional[str], Optional[str], Optional[str]]:
    """Split a dotted name into ``(catalog, schema, table)``.

    Double quotes around the whole name or around individual parts are
    removed, and parts that are empty after stripping are ignored. Missing
    leading parts are returned as ``None``. With more than three parts only
    the last three are kept. Inputs with no usable part (``""``, ``None``,
    ``"."``) yield ``(None, None, None)``.

    Note: the split is a plain ``str.split(".")``, so dots inside quoted
    identifiers are not respected.
    """
    n = (name or "").strip().strip('"').strip()
    parts = [q for q in (p.strip().strip('"') for p in n.split(".")) if q]
    if not parts:
        # e.g. "." or '"."' - previously fell through to parts[-3] (IndexError).
        return None, None, None
    if len(parts) == 1:
        return None, None, parts[0]
    if len(parts) == 2:
        return None, parts[0], parts[1]
    return parts[-3], parts[-2], parts[-1]
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def extract_object_name_from_metadata(metadata: Optional[Dict[object, object]], flavor: SQLFlavor) -> Optional[str]:
    """Derive a table (or MongoDB collection) name from schema-level metadata.

    Keys and values are coerced with ``_b2s`` (bytes decoded) and keys are
    matched lower-cased. A value stored under one of the name-like keys that
    parses as a JSON object is expanded into the lookup table and its raw JSON
    text is discarded. For example,
    ``{"table_ref": '{"catalog": "c", "schema": "s", "name": "t"}'}``
    resolves to ``"c.s.t"``.

    Args:
        metadata: Schema-level metadata (may be ``None``).
        flavor: For ``SQLFlavor.MONGODB`` a bare collection name is returned.
            For SQL flavors the result is ``catalog.schema.table`` with
            missing parts omitted.

    Returns:
        The resolved name, or ``None`` when no table/collection is found.
    """
    md: Dict[str, str] = {}
    for k, v in (metadata or {}).items():
        ks = _b2s(k).lower()
        if ks:
            md[ks] = _b2s(v)

    # Expand JSON-object values under name-like keys.
    for k in ("table_ref", "full_table_name", "qualified_name", "table", "table_name", "name", "object", "object_name"):
        if k not in md:
            continue
        blob = _try_parse_json_blob(md[k])
        if blob is None:
            continue
        # Drop the raw JSON text so it can't later be mistaken for a name
        # (e.g. by the qualified-name fallback below). The blob may itself
        # re-define the same key, which the merge then restores.
        del md[k]
        for kk, vv in blob.items():
            if kk is None:
                continue
            md[str(kk).lower()] = _b2s(vv)

    if flavor == SQLFlavor.MONGODB:
        for ck in ("collection", "collection_name", "mongo.collection", "mongodb.collection"):
            if md.get(ck):
                return md[ck]
        for tk in ("table", "table_name", "name", "full_table_name", "table_ref"):
            if md.get(tk):
                _, _, coll = parse_qualified_name(md[tk])
                return coll or md[tk]
        return None

    catalog = md.get("catalog") or md.get("unity_catalog") or md.get("uc_catalog") or md.get("db_catalog") or ""
    database = md.get("database") or md.get("db") or ""
    if not catalog and database:
        catalog = database

    schema_name = md.get("schema") or md.get("namespace") or md.get("database_schema") or ""
    table = md.get("table") or md.get("table_name") or md.get("relation") or md.get("object") or md.get("object_name") or ""

    if table and "." in table:
        # A qualified table value (e.g. table="s.t" with schema="s") must not
        # be prefixed with the same qualifiers again. Explicit catalog/schema
        # keys win, and the value's own qualifiers only fill the gaps (same
        # precedence as the fallback below).
        c, s, t = parse_qualified_name(table)
        catalog = catalog or (c or "")
        schema_name = schema_name or (s or "")
        table = t or ""

    if not table:
        for fk in ("table_ref", "full_table_name", "qualified_name", "name"):
            if md.get(fk):
                c, s, t = parse_qualified_name(md[fk])
                catalog = catalog or (c or "")
                schema_name = schema_name or (s or "")
                table = table or (t or "")
                if table:
                    break

    if not table:
        return None

    parts = [p for p in (catalog, schema_name, table) if p]
    return ".".join(parts) if parts else None
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
def extract_column_comment(field: pa.Field) -> Optional[str]:
    """Return the human-readable comment attached to an Arrow field, if any.

    Direct metadata keys are checked first, case-insensitively and in this
    order: ``comment``, ``description``, ``doc``, ``column_comment``,
    ``spark.comment``, ``delta.comment``. Failing that, a JSON object stored
    under ``meta``/``metadata``/``attrs``/``properties`` is searched for a
    ``comment``, ``description`` or ``doc`` entry.

    Returns:
        The first non-empty, stripped comment found, otherwise ``None``.
    """
    raw = field.metadata or {}
    if not raw:
        return None

    # Lower-cased string keys -> string values (bytes decoded by _b2s).
    pairs = ((_b2s(k).lower(), _b2s(v)) for k, v in raw.items())
    norm: Dict[str, str] = {key: value for key, value in pairs if key}

    direct_keys = ("comment", "description", "doc", "column_comment", "spark.comment", "delta.comment")
    direct = next((text for text in (norm.get(key, "").strip() for key in direct_keys) if text), None)
    if direct:
        return direct

    for container in ("meta", "metadata", "attrs", "properties"):
        nested = _try_parse_json_blob(norm.get(container, ""))
        if not isinstance(nested, dict):
            continue
        for key in ("comment", "description", "doc"):
            text = _b2s(nested.get(key, "")).strip()
            if text:
                return text

    return None
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
# Token-light types for context (not for DDL)
def _arrow_type_compact(dt: pa.DataType) -> str:
    """Return a short, token-cheap label for an Arrow type (prompt context only).

    Labels: ``i8``-``i64``, ``u8``-``u64``, ``f32``/``f64`` (float16 is
    reported as ``f32``), ``bool``, ``str``, ``bin``, ``ts``, ``date``,
    ``time``, ``dur``, ``dec(p,s)``, ``arr<...>``, ``struct``, ``map``.
    Dictionary-encoded types are labelled by their value type. Anything else
    is ``"any"``.
    """
    types = pa.types

    if types.is_dictionary(dt):
        # Dictionary encoding is a storage detail; describe the values.
        return _arrow_type_compact(dt.value_type)

    if types.is_int64(dt):
        return "i64"
    if types.is_int32(dt):
        return "i32"
    if types.is_int16(dt):
        return "i16"
    if types.is_int8(dt):
        return "i8"
    if types.is_uint64(dt):
        return "u64"
    if types.is_uint32(dt):
        return "u32"
    if types.is_uint16(dt):
        return "u16"
    if types.is_uint8(dt):
        return "u8"
    if types.is_float64(dt):
        return "f64"
    if types.is_float32(dt) or types.is_float16(dt):
        return "f32"
    if types.is_boolean(dt):
        return "bool"
    if types.is_string(dt) or types.is_large_string(dt):
        return "str"
    if types.is_timestamp(dt):
        return "ts"
    if types.is_date32(dt) or types.is_date64(dt):
        return "date"
    if types.is_time32(dt) or types.is_time64(dt):
        return "time"
    if types.is_duration(dt):
        return "dur"
    if types.is_decimal(dt):
        return f"dec({dt.precision},{dt.scale})"
    if types.is_binary(dt) or types.is_large_binary(dt) or types.is_fixed_size_binary(dt):
        return "bin"
    if types.is_list(dt) or types.is_large_list(dt) or types.is_fixed_size_list(dt):
        return f"arr<{_arrow_type_compact(dt.value_type)}>"
    if types.is_struct(dt):
        return "struct"
    if types.is_map(dt):
        return "map"
    return "any"
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
def _fields_signature_compact(fields: Iterable[pa.Field], *, include_comments: bool) -> str:
    """Render fields as a comma-joined ``name:type`` list for prompt context.

    Types come from ``_arrow_type_compact``. With ``include_comments``, a
    field's comment (from ``extract_column_comment``) is appended as
    ``name:type#comment``. Newlines are flattened to spaces and the comment
    is capped at 60 characters (57 characters plus ``"..."``).
    """

    def describe(col) -> str:
        label = f"{col.name}:{_arrow_type_compact(col.type)}"
        if not include_comments:
            return label
        note = extract_column_comment(col)
        if not note:
            return label
        note = note.replace("\n", " ").strip()
        if len(note) > 60:
            note = note[:57] + "..."
        return f"{label}#{note}"

    return ",".join(describe(col) for col in fields)
|
|
182
|
+
|
|
183
|
+
|
|
184
|
+
@dataclass
class SQLAISession(AISession):
    """AI session that turns natural-language prompts into database queries.

    Schemas are registered under an alias with :meth:`register_fields`.
    :meth:`generate_query` sends only the aliases selected by
    :meth:`_pick_relevant_aliases` (or those passed explicitly) as a compact
    ``ctx:`` system message. It returns the reply with any code fence
    stripped.

    The requested output follows ``flavor``: Databricks/Spark SQL,
    PostgreSQL, or a MongoDB aggregation pipeline as a JSON array.
    """

    flavor: SQLFlavor = SQLFlavor.DATABRICKS
    # Maximum number of registered aliases. When full, the oldest
    # registration (insertion order) is evicted to make room for a new alias.
    max_context_objects: int = 25

    # Registry
    _fields: Dict[str, List[pa.Field]] = field(default_factory=dict, init=False)
    _objects: Dict[str, str] = field(default_factory=dict, init=False)  # alias -> table/collection name
    _meta: Dict[str, Dict[object, object]] = field(default_factory=dict, init=False)  # alias -> schema-level metadata

    # Token controls
    include_comments_in_context: bool = True
    max_tables_in_context: int = 16  # big token saver

    def system_prompt(self) -> str:
        """Return the flavor-specific instruction prompt."""
        if self.flavor == SQLFlavor.MONGODB:
            return (
                "Return ONLY a MongoDB aggregation pipeline as a strict JSON array. No prose.\n"
                "Use only fields in context. Prefer $match then $group/$project then $sort then $limit."
            )

        # SQL
        dialect = "Databricks/Spark SQL" if self.flavor == SQLFlavor.DATABRICKS else "PostgreSQL"
        return (
            f"Return ONLY {dialect} query text. No prose, no markdown.\n"
            "Use only columns in context. Prefer fully-qualified table names when provided.\n"
            "Keep it short: avoid CTE unless necessary; return a single statement."
        )

    def set_flavor(self, flavor: SQLFlavor) -> None:
        """Switch the output flavor for subsequent requests.

        Object names already resolved at registration time are not recomputed.
        """
        self.flavor = flavor

    def register_fields(
        self,
        alias: str,
        fields: Iterable[pa.Field],
        *,
        schema_metadata: Optional[Dict[object, object]] = None,
        object_name: Optional[str] = None,
        prefer_metadata_object_name: bool = True,
    ) -> None:
        """Register (or replace) the schema known under *alias*.

        Args:
            alias: Key used in the context block and for relevance matching.
            fields: Arrow fields of the table/collection. Copied into a new list.
            schema_metadata: Schema-level metadata, stored as a dict copy.
            object_name: Explicit table/collection name for the context.
            prefer_metadata_object_name: When ``object_name`` is ``None``,
                try to derive the name from ``schema_metadata``.
        """
        if alias not in self._fields:
            # Evict oldest registrations to make room. The emptiness check
            # avoids StopIteration when max_context_objects <= 0.
            while self._fields and len(self._fields) >= self.max_context_objects:
                oldest = next(iter(self._fields))
                self._fields.pop(oldest, None)
                self._objects.pop(oldest, None)
                self._meta.pop(oldest, None)

        # Copy so later mutation of the caller's list does not alter the
        # registry; this also accepts any iterable of fields.
        self._fields[alias] = list(fields)
        self._meta[alias] = dict(schema_metadata or {})

        resolved = object_name
        if resolved is None and prefer_metadata_object_name:
            resolved = extract_object_name_from_metadata(self._meta[alias], self.flavor)
        if resolved:
            self._objects[alias] = resolved
        else:
            # Re-registration replaces fields and metadata; don't keep an
            # object name left over from a previous registration.
            self._objects.pop(alias, None)

    def _pick_relevant_aliases(self, prompt: str) -> List[str]:
        """
        Cheap heuristic to keep context tiny:
        score aliases by presence of alias/object/column names in prompt.

        Scoring (case-insensitive, whole-word matches so a column ``id`` does
        not match ``valid``):

        - +5 when the alias is mentioned.
        - +5 when the object name, fully qualified or its last segment, is
          mentioned.
        - +1 per mentioned column, capped at 6.

        Aliases scoring 0 are dropped. Ties keep registration order, and at
        most ``max_tables_in_context`` aliases are returned. If nothing
        matches, the first registered alias (if any) is returned.
        """
        import re

        p = prompt.lower()

        def mentioned(name: str) -> bool:
            needle = (name or "").strip().lower()
            if not needle:
                return False
            return re.search(rf"(?<!\w){re.escape(needle)}(?!\w)", p) is not None

        scores: List[Tuple[int, str]] = []
        for alias, cols in self._fields.items():
            s = 5 if mentioned(alias) else 0

            obj = self._objects.get(alias, "")
            if obj and (mentioned(obj) or mentioned(obj.rsplit(".", 1)[-1])):
                s += 5

            # column hits (cap influence)
            hit = 0
            for col in cols:
                if mentioned(col.name):
                    hit += 1
                    if hit >= 6:
                        break
            s += hit
            scores.append((s, alias))

        # Sort on score only; Python's sort is stable (also with reverse=True),
        # so ties keep registration order.
        scores.sort(key=lambda item: item[0], reverse=True)
        limit = max(0, self.max_tables_in_context)
        picked = [a for s, a in scores if s > 0][:limit]

        # fallback: if nothing matched, include first table only (still minimal)
        if not picked and self._fields:
            picked = [next(iter(self._fields))]

        return picked

    def _build_schema_context(self, aliases: List[str]) -> str:
        """Render the ``ctx:`` block: one ``alias=>object|cols`` (or ``alias|cols``) line each.

        Column comments are included only when ``include_comments_in_context``
        is set and the flavor is not MongoDB.

        Raises:
            KeyError: If an alias was never registered.
        """
        unknown = [a for a in aliases if a not in self._fields]
        if unknown:
            raise KeyError(f"Unknown alias(es) {unknown!r}; registered: {sorted(self._fields)!r}")

        include_comments = self.include_comments_in_context and self.flavor != SQLFlavor.MONGODB
        lines = ["ctx:"]
        for a in aliases:
            sig = _fields_signature_compact(self._fields[a], include_comments=include_comments)
            obj = self._objects.get(a, "")
            # tiny format: alias=>obj|cols
            lines.append(f"{a}=>{obj}|{sig}" if obj else f"{a}|{sig}")
        return "\n".join(lines)

    def generate_query(
        self,
        user_prompt: str,
        *,
        temperature: float = 0.0,
        max_output_tokens: int = 4200,
        extra_instructions: Optional[str] = None,
        tables: Optional[List[str]] = None,
    ) -> str:
        """Generate a query for *user_prompt* using the registered schemas.

        Args:
            user_prompt: Natural-language request.
            temperature: Sampling temperature.
            max_output_tokens: Token budget for the reply.
            extra_instructions: Optional constraints appended to the prompt.
            tables: Registered aliases to include. When empty or ``None``,
                aliases are chosen by :meth:`_pick_relevant_aliases`.

        Returns:
            The query text (SQL, or a JSON pipeline for MongoDB), with any code
            fence stripped.

        Raises:
            KeyError: If *tables* names an unregistered alias.
        """
        # decide which tables to include (token saver)
        aliases = tables if tables else self._pick_relevant_aliases(user_prompt)
        context_system = self._build_schema_context(aliases)

        return self.chat(
            user_prompt,
            context_system=context_system,
            extra_instructions=extra_instructions,
            temperature=temperature,
            max_output_tokens=max_output_tokens,
            strip_code_fences=True,
        )
|