QuerySUTRA 0.5.0.tar.gz → 0.5.2.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {querysutra-0.5.0 → querysutra-0.5.2}/PKG-INFO +1 -1
- {querysutra-0.5.0 → querysutra-0.5.2}/QuerySUTRA.egg-info/PKG-INFO +1 -1
- {querysutra-0.5.0 → querysutra-0.5.2}/pyproject.toml +1 -1
- {querysutra-0.5.0 → querysutra-0.5.2}/setup.py +1 -1
- querysutra-0.5.2/sutra/__init__.py +4 -0
- querysutra-0.5.2/sutra/sutra.py +560 -0
- querysutra-0.5.0/sutra/__init__.py +0 -4
- querysutra-0.5.0/sutra/sutra.py +0 -819
- {querysutra-0.5.0 → querysutra-0.5.2}/LICENSE +0 -0
- {querysutra-0.5.0 → querysutra-0.5.2}/MANIFEST.in +0 -0
- {querysutra-0.5.0 → querysutra-0.5.2}/QuerySUTRA.egg-info/SOURCES.txt +0 -0
- {querysutra-0.5.0 → querysutra-0.5.2}/QuerySUTRA.egg-info/dependency_links.txt +0 -0
- {querysutra-0.5.0 → querysutra-0.5.2}/QuerySUTRA.egg-info/requires.txt +0 -0
- {querysutra-0.5.0 → querysutra-0.5.2}/QuerySUTRA.egg-info/top_level.txt +0 -0
- {querysutra-0.5.0 → querysutra-0.5.2}/README.md +0 -0
- {querysutra-0.5.0 → querysutra-0.5.2}/examples/quickstart.py +0 -0
- {querysutra-0.5.0 → querysutra-0.5.2}/examples/sutra_usage_guide.ipynb +0 -0
- {querysutra-0.5.0 → querysutra-0.5.2}/examples/usage_guide.ipynb +0 -0
- {querysutra-0.5.0 → querysutra-0.5.2}/requirements.txt +0 -0
- {querysutra-0.5.0 → querysutra-0.5.2}/setup.cfg +0 -0
- {querysutra-0.5.0 → querysutra-0.5.2}/sutra/cache_manager.py +0 -0
- {querysutra-0.5.0 → querysutra-0.5.2}/sutra/clear_cache.py +0 -0
- {querysutra-0.5.0 → querysutra-0.5.2}/sutra/core.py +0 -0
- {querysutra-0.5.0 → querysutra-0.5.2}/sutra/data_loader.py +0 -0
- {querysutra-0.5.0 → querysutra-0.5.2}/sutra/database_manager.py +0 -0
- {querysutra-0.5.0 → querysutra-0.5.2}/sutra/direct_query.py +0 -0
- {querysutra-0.5.0 → querysutra-0.5.2}/sutra/feedback.py +0 -0
- {querysutra-0.5.0 → querysutra-0.5.2}/sutra/feedback_matcher.py +0 -0
- {querysutra-0.5.0 → querysutra-0.5.2}/sutra/nlp_processor.py +0 -0
- {querysutra-0.5.0 → querysutra-0.5.2}/sutra/schema_embeddings.py +0 -0
- {querysutra-0.5.0 → querysutra-0.5.2}/sutra/schema_generator.py +0 -0
- {querysutra-0.5.0 → querysutra-0.5.2}/sutra/sutra_client.py +0 -0
- {querysutra-0.5.0 → querysutra-0.5.2}/sutra/sutra_core.py +0 -0
- {querysutra-0.5.0 → querysutra-0.5.2}/sutra/sutra_simple.py +0 -0
- {querysutra-0.5.0 → querysutra-0.5.2}/sutra/visualizer.py +0 -0
- {querysutra-0.5.0 → querysutra-0.5.2}/tests/__init__.py +0 -0
- {querysutra-0.5.0 → querysutra-0.5.2}/tests/test_modules.py +0 -0
- {querysutra-0.5.0 → querysutra-0.5.2}/tests/test_sutra.py +0 -0
- {querysutra-0.5.0 → querysutra-0.5.2}/utils/__init__.py +0 -0
- {querysutra-0.5.0 → querysutra-0.5.2}/utils/file_utils.py +0 -0
- {querysutra-0.5.0 → querysutra-0.5.2}/utils/text_utils.py +0 -0
{querysutra-0.5.0 → querysutra-0.5.2}/setup.py

```diff
@@ -1,3 +1,3 @@
 from setuptools import setup,find_packages
 with open("README.md","r",encoding="utf-8") as f:d=f.read()
-setup(name="QuerySUTRA",version="0.5.
+setup(name="QuerySUTRA",version="0.5.2",author="Aditya Batta",description="SUTRA",long_description=d,long_description_content_type="text/markdown",packages=find_packages(),python_requires=">=3.8",install_requires=["pandas>=1.3.0","numpy>=1.21.0","openai>=1.0.0","plotly>=5.0.0","matplotlib>=3.3.0","PyPDF2>=3.0.0","python-docx>=0.8.11","openpyxl>=3.0.0"],extras_require={"mysql":["sqlalchemy>=1.4.0","mysql-connector-python>=8.0.0"],"postgres":["sqlalchemy>=1.4.0","psycopg2-binary>=2.9.0"],"embeddings":["sentence-transformers>=2.0.0"],"all":["sqlalchemy>=1.4.0","mysql-connector-python>=8.0.0","psycopg2-binary>=2.9.0","sentence-transformers>=2.0.0"]})
```
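The only substantive change to setup.py is the version bump; the dependency lists are otherwise identical between 0.5.0 and 0.5.2. Since the optional backends are split into `extras_require` groups, a quick way to see which groups are usable in a given environment is to probe the same imports the package itself guards at the top of sutra/sutra.py. This is an illustrative sketch, not part of the package:

```python
# Probe which optional extras groups from setup.py are importable here.
# Module names follow extras_require above; the helper itself is hypothetical.
def _importable(module: str) -> bool:
    try:
        __import__(module)
        return True
    except ImportError:
        return False

def available_extras() -> dict:
    probes = {
        "mysql": ("sqlalchemy", "mysql.connector"),
        "postgres": ("sqlalchemy", "psycopg2"),
        "embeddings": ("sentence_transformers",),
    }
    return {extra: all(_importable(m) for m in mods) for extra, mods in probes.items()}

print(available_extras())  # e.g. {'mysql': False, 'postgres': False, 'embeddings': True}
```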
querysutra-0.5.2/sutra/sutra.py (new file)

@@ -0,0 +1,560 @@

````python
"""QuerySUTRA v0.5.1 - BULLETPROOF & FIXED"""
__version__ = "0.5.1"
__author__ = "Aditya Batta"
__all__ = ["SUTRA", "QueryResult"]

import os, sqlite3, pandas as pd, numpy as np, json, hashlib, shutil, datetime, re
from typing import Optional, Union, Dict, List
from pathlib import Path
from difflib import get_close_matches

try:
    from openai import OpenAI
    HAS_OPENAI = True
except:
    HAS_OPENAI = False

try:
    import plotly.express as px
    import plotly.graph_objects as go
    HAS_PLOTLY = True
except:
    HAS_PLOTLY = False

try:
    import PyPDF2
    HAS_PYPDF2 = True
except:
    HAS_PYPDF2 = False

try:
    import docx
    HAS_DOCX = True
except:
    HAS_DOCX = False

try:
    from sentence_transformers import SentenceTransformer
    HAS_EMBEDDINGS = True
except:
    HAS_EMBEDDINGS = False


class SUTRA:
    """SUTRA - BULLETPROOF"""

    def __init__(self, api_key: Optional[str] = None, db: str = "sutra.db",
                 use_embeddings: bool = False, fuzzy_match: bool = True,
                 cache_queries: bool = True, check_relevance: bool = False):

        if api_key:
            os.environ["OPENAI_API_KEY"] = api_key

        self.api_key = os.getenv("OPENAI_API_KEY")
        self.client = OpenAI(api_key=self.api_key) if self.api_key and HAS_OPENAI else None
        self.db_path = db
        self.conn = sqlite3.connect(db, timeout=30, check_same_thread=False)
        self.cursor = self.conn.cursor()
        self.current_table = None
        self.schema_info = {}
        self.cache_queries = cache_queries
        self.cache = {} if cache_queries else None
        self.use_embeddings = use_embeddings
        self.embedding_model = None
        self.query_embeddings = {}
        self.check_relevance = check_relevance
        self.fuzzy_match = fuzzy_match

        if use_embeddings and HAS_EMBEDDINGS:
            try:
                self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
            except:
                pass

        self._refresh_schema()
        print(f"QuerySUTRA v0.5.1 Ready")
````
sutra/sutra.py (continued):

````python
    def upload(self, data: Union[str, pd.DataFrame], name: Optional[str] = None) -> 'SUTRA':
        """Upload."""
        if isinstance(data, pd.DataFrame):
            self._store(data, name or "data")
            return self

        path = Path(data)
        if not path.exists():
            raise FileNotFoundError(f"Not found: {data}")

        name = name or path.stem.replace(" ", "_").replace("-", "_")
        ext = path.suffix.lower()

        if ext == ".csv":
            self._store(pd.read_csv(path), name)
        elif ext in [".xlsx", ".xls"]:
            self._store(pd.read_excel(path), name)
        elif ext == ".json":
            self._store(pd.read_json(path), name)
        elif ext == ".pdf":
            self._pdf(path, name)
        elif ext == ".docx":
            self._docx(path, name)
        elif ext == ".txt":
            self._txt(path, name)
        else:
            raise ValueError(f"Unsupported: {ext}")

        return self

    def _pdf(self, path: Path, name: str):
        """BULLETPROOF PDF - ALWAYS creates multiple tables."""
        if not HAS_PYPDF2:
            raise ImportError("pip install PyPDF2")

        print(f"Extracting PDF: {path.name}")

        with open(path, 'rb') as f:
            text = "".join([p.extract_text() + "\n" for p in PyPDF2.PdfReader(f).pages])

        if not self.client:
            print("ERROR: No API key! Set api_key parameter")
            return

        print("AI: Extracting...")

        # TRY 3 TIMES
        entities = None
        for attempt in [1, 2, 3]:
            entities = self._extract(text, attempt)
            if entities and len(entities) > 0:
                break
            if attempt < 3:
                print(f"  Retry {attempt+1}/3...")

        # Create tables from entities
        if entities and len(entities) > 0:
            print(f"Extracted {len(entities)} entity types:")
            for etype, recs in entities.items():
                if recs and len(recs) > 0:
                    for idx, rec in enumerate(recs, 1):
                        rec['id'] = idx
                    self._store(pd.DataFrame(recs), f"{name}_{etype}")
                    print(f"  {etype}: {len(recs)} rows")
            return

        # REGEX FALLBACK - FIXED
        print("Using regex fallback...")
        people = []
        emails = re.findall(r'[\w\.-]+@[\w\.-]+\.\w+', text)

        # Extract names from common patterns
        name_patterns = [
            r'(?:Employee|Name|Mr\.|Mrs\.|Ms\.|Dr\.)\s*[:\-]?\s*([A-Z][a-z]+(?:\s+[A-Z][a-z]+)+)',
            r'([A-Z][a-z]+\s+[A-Z][a-z]+)\s+(?:lives|resides|works|is based)',
            r'\*\*([A-Z][a-z]+\s+[A-Z][a-z]+)\*\*'
        ]

        names = []
        for pattern in name_patterns:
            names.extend(re.findall(pattern, text))
            if len(names) >= len(emails):
                break

        # Match emails to names
        max_people = min(len(emails), 50)
        for i in range(max_people):
            people.append({
                'id': i + 1,
                'name': names[i] if i < len(names) else f"Person {i+1}",
                'email': emails[i] if i < len(emails) else f"person{i+1}@unknown.com"
            })

        if people:
            self._store(pd.DataFrame(people), f"{name}_people")
            print(f"  Extracted {len(people)} people via regex")
        else:
            # Absolute last resort
            lines = [l.strip() for l in text.split('\n') if l.strip()][:100]
            self._store(pd.DataFrame({'line': range(1, len(lines)+1), 'text': lines}), name)

    def _extract(self, text: str, attempt: int) -> Dict:
        """Extract with 3 different strategies."""
        if not self.client:
            return {}

        try:
            if attempt == 1:
                sys_msg = "Extract entities as JSON. Return ONLY valid JSON."
                usr_msg = f"""Extract ALL entities from text.

Text:
{text[:15000]}

Return JSON with: people, skills, technologies, projects, certifications, education, work_experience

Example:
{{"people":[{{"id":1,"name":"Sarah Johnson","email":"sarah@co.com","city":"New York","state":"NY"}},{{"id":2,"name":"Michael Chen","email":"michael@co.com","city":"SF","state":"CA"}}],"skills":[{{"id":1,"person_id":1,"skill_name":"Python","proficiency":"Expert"}}]}}

Rules: Unique IDs (1,2,3...), person_id references people.id

JSON:"""

            elif attempt == 2:
                sys_msg = "Return JSON."
                usr_msg = f"""Text: {text[:10000]}

Extract people as JSON:
{{"people":[{{"id":1,"name":"...","email":"...","city":"..."}}]}}

JSON:"""

            else:
                sys_msg = "JSON only."
                usr_msg = f"""Find names and emails in: {text[:8000]}

{{"people":[{{"id":1,"name":"John","email":"john@co.com"}}]}}"""

            r = self.client.chat.completions.create(
                model="gpt-4o-mini",
                messages=[
                    {"role": "system", "content": sys_msg},
                    {"role": "user", "content": usr_msg}
                ],
                temperature=0,
                max_tokens=12000
            )

            raw = r.choices[0].message.content.strip()
            raw = raw.replace("```json", "").replace("```", "").replace("JSON:", "").strip()

            start = raw.find('{')
            end = raw.rfind('}') + 1

            if start < 0 or end <= start:
                return {}

            result = json.loads(raw[start:end])

            if isinstance(result, dict) and len(result) > 0:
                has_data = any(isinstance(v, list) and len(v) > 0 for v in result.values())
                if has_data:
                    return result

            return {}

        except Exception as e:
            print(f"  Attempt {attempt} failed: {str(e)[:100]}")
            return {}
````
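Before parsing, `_extract` slices the model output from the first `{` to the last `}` to shed code fences and surrounding chatter. That cleanup step in isolation:

```python
import json

# The same brace-slicing cleanup _extract applies to raw model output.
raw = 'Sure! Here is the JSON: {"people": [{"id": 1, "name": "Ada"}]} Hope this helps.'
start, end = raw.find('{'), raw.rfind('}') + 1
print(json.loads(raw[start:end]))  # {'people': [{'id': 1, 'name': 'Ada'}]}
```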
sutra/sutra.py (continued):

````python
    def _docx(self, path: Path, name: str):
        """DOCX."""
        if not HAS_DOCX:
            raise ImportError("pip install python-docx")
        doc = docx.Document(path)
        if doc.tables:
            for i, t in enumerate(doc.tables):
                data = [[cell.text.strip() for cell in row.cells] for row in t.rows]
                if data and len(data) > 1:
                    self._store(pd.DataFrame(data[1:], columns=data[0]), f"{name}_t{i+1}")
        else:
            text = "\n".join([p.text for p in doc.paragraphs])
            lines = [l.strip() for l in text.split('\n') if l.strip()]
            self._store(pd.DataFrame({'line': range(1, len(lines)+1), 'text': lines}), name)

    def _txt(self, path: Path, name: str):
        """TXT."""
        with open(path, 'r', encoding='utf-8') as f:
            text = f.read()
        lines = [l.strip() for l in text.split('\n') if l.strip()]
        self._store(pd.DataFrame({'line': range(1, len(lines)+1), 'text': lines}), name)

    def _store(self, df: pd.DataFrame, name: str):
        """Store."""
        df.columns = [str(c).strip().replace(" ", "_").replace("-", "_") for c in df.columns]
        try:
            df.to_sql(name, self.conn, if_exists='replace', index=False, method='multi', chunksize=500)
        except:
            df.to_sql(name, self.conn, if_exists='replace', index=False)
        self.conn.commit()
        self.current_table = name
        self._refresh_schema()
        print(f"  {name}: {len(df)} rows")

    def ask(self, q: str, viz: Union[bool, str] = False, table: Optional[str] = None) -> 'QueryResult':
        """Query."""
        if not self.client:
            return QueryResult(False, "", pd.DataFrame(), None, "No API")

        t = table or self.current_table or (self._get_tables()[0] if self._get_tables() else None)
        if not t:
            return QueryResult(False, "", pd.DataFrame(), None, "No table")

        if self.use_embeddings and self.embedding_model:
            cached = self._check_cache(q, t)
            if cached:
                return cached

        if self.fuzzy_match:
            q = self._fuzzy(q, t)

        key = hashlib.md5(f"{q}:{t}".encode()).hexdigest()
        if self.cache_queries and self.cache and key in self.cache:
            sql = self.cache[key]
        else:
            sql = self._gen_sql(q, t)
            if self.cache_queries and self.cache:
                self.cache[key] = sql

        print(f"SQL: {sql}")

        try:
            df = pd.read_sql_query(sql, self.conn)
            print(f"Success! {len(df)} rows")
            fig = self._viz(df, q, viz if isinstance(viz, str) else "auto") if viz else None
            r = QueryResult(True, sql, df, fig)

            if self.use_embeddings and self.embedding_model:
                self._store_cache(q, t, r)

            return r
        except Exception as e:
            return QueryResult(False, sql, pd.DataFrame(), None, str(e))

    def _fuzzy(self, q: str, t: str) -> str:
        """Fuzzy."""
        try:
            cols = [c for c, d in self.schema_info.get(t, {}).items() if 'TEXT' in d]
            if not cols:
                return q
            for col in cols[:2]:
                df = pd.read_sql_query(f"SELECT DISTINCT {col} FROM {t} LIMIT 100", self.conn)
                vals = [str(v) for v in df[col].dropna()]
                words = q.split()
                for i, w in enumerate(words):
                    m = get_close_matches(w, vals, n=1, cutoff=0.6)
                    if m and w != m[0]:
                        words[i] = m[0]
                q = " ".join(words)
            return q
        except:
            return q
````
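`_fuzzy` leans on `difflib.get_close_matches` with a 0.6 cutoff, applied token by token against up to 100 distinct values from the first two TEXT columns. In isolation the call behaves like this (toy values):

```python
from difflib import get_close_matches

# One question token, matched against distinct column values as _fuzzy does.
values = ["New York", "Boston", "Chicago"]
print(get_close_matches("Bostn", values, n=1, cutoff=0.6))   # ['Boston'] -> token rewritten
print(get_close_matches("Dallas", values, n=1, cutoff=0.6))  # [] -> token left as-is
```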
sutra/sutra.py (continued):

````python
    def _check_cache(self, q: str, t: str) -> Optional['QueryResult']:
        """Cache."""
        if not self.query_embeddings:
            return None
        emb = self.embedding_model.encode([q])[0]
        best, sim = None, 0.85
        for cq, d in self.query_embeddings.items():
            if d['table'] != t:
                continue
            s = np.dot(emb, d['embedding']) / (np.linalg.norm(emb) * np.linalg.norm(d['embedding']))
            if s > sim:
                sim, best = s, cq
        return self.query_embeddings[best]['result'] if best else None

    def _store_cache(self, q: str, t: str, r: 'QueryResult'):
        """Store."""
        emb = self.embedding_model.encode([q])[0]
        self.query_embeddings[q] = {'table': t, 'embedding': emb, 'result': r}
````
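`_check_cache` reuses a stored QueryResult only when the cosine similarity between the new question's embedding and a cached one exceeds 0.85. The arithmetic, with toy vectors standing in for the all-MiniLM-L6-v2 sentence embeddings:

```python
import numpy as np

# Cosine similarity exactly as computed in _check_cache above.
emb_new = np.array([0.9, 0.1, 0.4])
emb_cached = np.array([1.0, 0.0, 0.5])
s = np.dot(emb_new, emb_cached) / (np.linalg.norm(emb_new) * np.linalg.norm(emb_cached))
print(s > 0.85)  # True here; the cached result would be reused
```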
sutra/sutra.py (continued):

````python
    def _viz(self, df: pd.DataFrame, title: str, vt: str):
        """Viz."""
        if not HAS_PLOTLY:
            return None
        try:
            n = df.select_dtypes(include=[np.number]).columns.tolist()
            c = df.select_dtypes(include=['object']).columns.tolist()
            if vt == "pie" and c and n:
                fig = px.pie(df, names=c[0], values=n[0], title=title)
            elif vt == "bar" and c and n:
                fig = px.bar(df, x=c[0], y=n[0], title=title)
            elif vt == "line" and n:
                fig = px.line(df, y=n[0], title=title)
            elif vt == "scatter" and len(n) >= 2:
                fig = px.scatter(df, x=n[0], y=n[1], title=title)
            else:
                fig = px.bar(df, y=df.columns[0], title=title)
            fig.show()
            return fig
        except:
            return None

    def tables(self) -> Dict:
        """Tables."""
        t = self._get_tables()
        print("\n" + "="*70)
        print("TABLES")
        print("="*70)
        if not t:
            print("No tables")
            return {}
        r = {}
        for i, tb in enumerate(t, 1):
            cnt = pd.read_sql_query(f"SELECT COUNT(*) FROM {tb}", self.conn).iloc[0, 0]
            cols = list(self.schema_info.get(tb, {}).keys())
            print(f"  {i}. {tb}: {cnt} rows, {len(cols)} cols")
            r[tb] = {'rows': cnt, 'columns': cols}
        print("="*70)
        return r

    def schema(self, table: Optional[str] = None) -> Dict:
        """Schema."""
        if not self.schema_info:
            self._refresh_schema()
        print("\n" + "="*70)
        print("SCHEMA")
        print("="*70)
        r = {}
        for t in ([table] if table else self.schema_info.keys()):
            if t in self.schema_info:
                cnt = pd.read_sql_query(f"SELECT COUNT(*) FROM {t}", self.conn).iloc[0, 0]
                print(f"\n{t}: {cnt} records")
                for c, d in self.schema_info[t].items():
                    print(f"  - {c:<30} {d}")
                r[t] = {'records': cnt, 'columns': self.schema_info[t]}
        print("="*70)
        return r

    def peek(self, table: Optional[str] = None, n: int = 5) -> pd.DataFrame:
        """Preview."""
        t = table or self.current_table
        if not t:
            return pd.DataFrame()
        df = pd.read_sql_query(f"SELECT * FROM {t} LIMIT {n}", self.conn)
        print(f"\nSample from '{t}':")
        print(df.to_string(index=False))
        return df

    def sql(self, query: str, viz: Union[bool, str] = False) -> 'QueryResult':
        """SQL."""
        try:
            df = pd.read_sql_query(query, self.conn)
            print(f"Success! {len(df)} rows")
            fig = self._viz(df, "Result", viz if isinstance(viz, str) else "auto") if viz else None
            return QueryResult(True, query, df, fig)
        except Exception as e:
            return QueryResult(False, query, pd.DataFrame(), None, str(e))

    def save_to_mysql(self, host: str, user: str, password: str, database: str, port: int = 3306):
        """MySQL."""
        try:
            from sqlalchemy import create_engine
            import mysql.connector
        except:
            raise ImportError("pip install QuerySUTRA[mysql]")

        print(f"Exporting to MySQL: {database}")

        try:
            tc = mysql.connector.connect(host=host, user=user, password=password, port=port)
            tc.cursor().execute(f"CREATE DATABASE IF NOT EXISTS `{database}`")
            tc.close()
        except:
            pass

        engine = create_engine(f"mysql+mysqlconnector://{user}:{password}@{host}:{port}/{database}")
        for t in self._get_tables():
            df = pd.read_sql_query(f"SELECT * FROM {t}", self.conn)
            df.to_sql(t, engine, if_exists='replace', index=False)
            print(f"  {t}: {len(df)} rows")
        print("Done!")
        return self

    def export_db(self, path: str, format: str = "sqlite"):
        """Export."""
        if format == "sqlite":
            shutil.copy2(self.db_path, path)
        elif format == "json":
            data = {t: pd.read_sql_query(f"SELECT * FROM {t}", self.conn).to_dict('records') for t in self._get_tables()}
            with open(path, 'w') as f:
                json.dump(data, f, indent=2, default=str)
        print(f"Saved: {path}")
        return self

    @classmethod
    def load_from_db(cls, db_path: str, api_key: Optional[str] = None, **kwargs):
        """Load."""
        if not Path(db_path).exists():
            raise FileNotFoundError(f"Not found: {db_path}")
        return cls(api_key=api_key, db=db_path, **kwargs)

    @classmethod
    def connect_mysql(cls, host: str, user: str, password: str, database: str, port: int = 3306, api_key: Optional[str] = None, **kwargs):
        """MySQL."""
        try:
            from sqlalchemy import create_engine
            import mysql.connector
        except:
            raise ImportError("pip install QuerySUTRA[mysql]")

        try:
            tc = mysql.connector.connect(host=host, user=user, password=password, port=port)
            tc.cursor().execute(f"CREATE DATABASE IF NOT EXISTS {database}")
            tc.close()
        except:
            pass

        engine = create_engine(f"mysql+mysqlconnector://{user}:{password}@{host}:{port}/{database}")
        temp_db = f"mysql_{database}.db"
        instance = cls(api_key=api_key, db=temp_db, **kwargs)

        tables = pd.read_sql_query("SHOW TABLES", engine).iloc[:, 0].tolist()
        for t in tables:
            pd.read_sql_query(f"SELECT * FROM {t}", engine).to_sql(t, instance.conn, if_exists='replace', index=False)

        instance._refresh_schema()
        print(f"Connected! {len(tables)} tables")
        return instance

    def _gen_sql(self, q: str, t: str) -> str:
        """SQL."""
        schema = self.schema_info.get(t, {})
        sample = pd.read_sql_query(f"SELECT * FROM {t} LIMIT 3", self.conn).to_string(index=False)
        cols = ", ".join([f"{c} ({d})" for c, d in schema.items()])

        r = self.client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[
                {"role": "system", "content": "SQL expert. Return only SQL."},
                {"role": "user", "content": f"Table: {t}\nColumns: {cols}\nSample:\n{sample}\n\nQ: {q}\n\nSQL:"}
            ],
            temperature=0
        )
        return r.choices[0].message.content.strip().replace("```sql", "").replace("```", "").strip()

    def _get_tables(self) -> List[str]:
        """Tables."""
        self.cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
        return [r[0] for r in self.cursor.fetchall()]

    def _refresh_schema(self):
        """Refresh."""
        self.schema_info = {}
        for t in self._get_tables():
            self.cursor.execute(f"PRAGMA table_info({t})")
            self.schema_info[t] = {r[1]: r[2] for r in self.cursor.fetchall()}

    def close(self):
        if self.conn:
            self.conn.close()

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.close()

    def __repr__(self):
        return f"SUTRA(tables={len(self.schema_info)})"


class QueryResult:
    """Result."""
    def __init__(self, success: bool, sql: str, data: pd.DataFrame, viz, error: str = None):
        self.success, self.sql, self.data, self.viz, self.error = success, sql, data, viz, error

    def __repr__(self):
        return f"QueryResult(rows={len(self.data)})" if self.success else f"QueryResult(error='{self.error}')"

    def show(self):
        print(self.data if self.success else f"Error: {self.error}")
        return self
````