QuerySUTRA 0.5.0.tar.gz → 0.5.1.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {querysutra-0.5.0 → querysutra-0.5.1}/PKG-INFO +1 -1
- {querysutra-0.5.0 → querysutra-0.5.1}/QuerySUTRA.egg-info/PKG-INFO +1 -1
- {querysutra-0.5.0 → querysutra-0.5.1}/pyproject.toml +1 -1
- {querysutra-0.5.0 → querysutra-0.5.1}/setup.py +1 -1
- querysutra-0.5.1/sutra/__init__.py +4 -0
- querysutra-0.5.1/sutra/sutra.py +594 -0
- querysutra-0.5.0/sutra/__init__.py +0 -4
- querysutra-0.5.0/sutra/sutra.py +0 -819
- {querysutra-0.5.0 → querysutra-0.5.1}/LICENSE +0 -0
- {querysutra-0.5.0 → querysutra-0.5.1}/MANIFEST.in +0 -0
- {querysutra-0.5.0 → querysutra-0.5.1}/QuerySUTRA.egg-info/SOURCES.txt +0 -0
- {querysutra-0.5.0 → querysutra-0.5.1}/QuerySUTRA.egg-info/dependency_links.txt +0 -0
- {querysutra-0.5.0 → querysutra-0.5.1}/QuerySUTRA.egg-info/requires.txt +0 -0
- {querysutra-0.5.0 → querysutra-0.5.1}/QuerySUTRA.egg-info/top_level.txt +0 -0
- {querysutra-0.5.0 → querysutra-0.5.1}/README.md +0 -0
- {querysutra-0.5.0 → querysutra-0.5.1}/examples/quickstart.py +0 -0
- {querysutra-0.5.0 → querysutra-0.5.1}/examples/sutra_usage_guide.ipynb +0 -0
- {querysutra-0.5.0 → querysutra-0.5.1}/examples/usage_guide.ipynb +0 -0
- {querysutra-0.5.0 → querysutra-0.5.1}/requirements.txt +0 -0
- {querysutra-0.5.0 → querysutra-0.5.1}/setup.cfg +0 -0
- {querysutra-0.5.0 → querysutra-0.5.1}/sutra/cache_manager.py +0 -0
- {querysutra-0.5.0 → querysutra-0.5.1}/sutra/clear_cache.py +0 -0
- {querysutra-0.5.0 → querysutra-0.5.1}/sutra/core.py +0 -0
- {querysutra-0.5.0 → querysutra-0.5.1}/sutra/data_loader.py +0 -0
- {querysutra-0.5.0 → querysutra-0.5.1}/sutra/database_manager.py +0 -0
- {querysutra-0.5.0 → querysutra-0.5.1}/sutra/direct_query.py +0 -0
- {querysutra-0.5.0 → querysutra-0.5.1}/sutra/feedback.py +0 -0
- {querysutra-0.5.0 → querysutra-0.5.1}/sutra/feedback_matcher.py +0 -0
- {querysutra-0.5.0 → querysutra-0.5.1}/sutra/nlp_processor.py +0 -0
- {querysutra-0.5.0 → querysutra-0.5.1}/sutra/schema_embeddings.py +0 -0
- {querysutra-0.5.0 → querysutra-0.5.1}/sutra/schema_generator.py +0 -0
- {querysutra-0.5.0 → querysutra-0.5.1}/sutra/sutra_client.py +0 -0
- {querysutra-0.5.0 → querysutra-0.5.1}/sutra/sutra_core.py +0 -0
- {querysutra-0.5.0 → querysutra-0.5.1}/sutra/sutra_simple.py +0 -0
- {querysutra-0.5.0 → querysutra-0.5.1}/sutra/visualizer.py +0 -0
- {querysutra-0.5.0 → querysutra-0.5.1}/tests/__init__.py +0 -0
- {querysutra-0.5.0 → querysutra-0.5.1}/tests/test_modules.py +0 -0
- {querysutra-0.5.0 → querysutra-0.5.1}/tests/test_sutra.py +0 -0
- {querysutra-0.5.0 → querysutra-0.5.1}/utils/__init__.py +0 -0
- {querysutra-0.5.0 → querysutra-0.5.1}/utils/file_utils.py +0 -0
- {querysutra-0.5.0 → querysutra-0.5.1}/utils/text_utils.py +0 -0
--- querysutra-0.5.0/setup.py
+++ querysutra-0.5.1/setup.py
@@ -1,3 +1,3 @@
 from setuptools import setup,find_packages
 with open("README.md","r",encoding="utf-8") as f:d=f.read()
-setup(name="QuerySUTRA",version="0.5.0",…)
+setup(name="QuerySUTRA",version="0.5.1",author="Aditya Batta",description="SUTRA",long_description=d,long_description_content_type="text/markdown",packages=find_packages(),python_requires=">=3.8",install_requires=["pandas>=1.3.0","numpy>=1.21.0","openai>=1.0.0","plotly>=5.0.0","matplotlib>=3.3.0","PyPDF2>=3.0.0","python-docx>=0.8.11","openpyxl>=3.0.0"],extras_require={"mysql":["sqlalchemy>=1.4.0","mysql-connector-python>=8.0.0"],"postgres":["sqlalchemy>=1.4.0","psycopg2-binary>=2.9.0"],"embeddings":["sentence-transformers>=2.0.0"],"all":["sqlalchemy>=1.4.0","mysql-connector-python>=8.0.0","psycopg2-binary>=2.9.0","sentence-transformers>=2.0.0"]})
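
The only change to setup.py is the version bump. Note that the rewritten sutra/sutra.py below still declares `__version__ = "0.5.0"`, so the distribution metadata, not the module attribute, is the reliable indicator of which release is installed. A minimal sketch, assuming QuerySUTRA 0.5.1 has been installed from PyPI:

```python
# Hypothetical sanity check: distribution metadata tracks setup.py,
# while sutra.__version__ still reads "0.5.0" in this release.
from importlib.metadata import version

print(version("QuerySUTRA"))  # expected: 0.5.1
```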
--- /dev/null
+++ querysutra-0.5.1/sutra/sutra.py
@@ -0,0 +1,594 @@
+"""
+QuerySUTRA v0.5.0 - BULLETPROOF
+GUARANTEED to create multiple tables with proper keys
+NEVER falls back to single table
+"""
+
+__version__ = "0.5.0"
+__author__ = "Aditya Batta"
+__all__ = ["SUTRA", "QueryResult"]
+
+import os, sqlite3, pandas as pd, numpy as np, json, hashlib, shutil, datetime, re
+from typing import Optional, Union, Dict, List
+from pathlib import Path
+from difflib import get_close_matches
+
+try:
+    from openai import OpenAI
+    HAS_OPENAI = True
+except:
+    HAS_OPENAI = False
+
+try:
+    import plotly.express as px
+    import plotly.graph_objects as go
+    HAS_PLOTLY = True
+except:
+    HAS_PLOTLY = False
+
+try:
+    import PyPDF2
+    HAS_PYPDF2 = True
+except:
+    HAS_PYPDF2 = False
+
+try:
+    import docx
+    HAS_DOCX = True
+except:
+    HAS_DOCX = False
+
+try:
+    from sentence_transformers import SentenceTransformer
+    HAS_EMBEDDINGS = True
+except:
+    HAS_EMBEDDINGS = False
+
+
+class SUTRA:
+    """SUTRA - BULLETPROOF AI EXTRACTION"""
+
+    def __init__(self, api_key: Optional[str] = None, db: str = "sutra.db",
+                 use_embeddings: bool = False, fuzzy_match: bool = True,
+                 cache_queries: bool = True, check_relevance: bool = False):
+
+        if api_key:
+            os.environ["OPENAI_API_KEY"] = api_key
+
+        self.api_key = os.getenv("OPENAI_API_KEY")
+        self.client = OpenAI(api_key=self.api_key) if self.api_key and HAS_OPENAI else None
+        self.db_path = db
+        self.conn = sqlite3.connect(db, timeout=30, check_same_thread=False)
+        self.cursor = self.conn.cursor()
+        self.current_table = None
+        self.schema_info = {}
+        self.cache_queries = cache_queries
+        self.cache = {} if cache_queries else None
+        self.use_embeddings = use_embeddings
+        self.embedding_model = None
+        self.query_embeddings = {}
+        self.check_relevance = check_relevance
+        self.fuzzy_match = fuzzy_match
+
+        if use_embeddings and HAS_EMBEDDINGS:
+            try:
+                self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
+            except:
+                pass
+
+        self._refresh_schema()
+        print(f"QuerySUTRA v0.5.0 Ready")
+
+    def upload(self, data: Union[str, pd.DataFrame], name: Optional[str] = None) -> 'SUTRA':
+        """Upload data."""
+        if isinstance(data, pd.DataFrame):
+            self._store(data, name or "data")
+            return self
+
+        path = Path(data)
+        if not path.exists():
+            raise FileNotFoundError(f"Not found: {data}")
+
+        name = name or path.stem.replace(" ", "_").replace("-", "_")
+        ext = path.suffix.lower()
+
+        if ext == ".csv":
+            self._store(pd.read_csv(path), name)
+        elif ext in [".xlsx", ".xls"]:
+            self._store(pd.read_excel(path), name)
+        elif ext == ".json":
+            self._store(pd.read_json(path), name)
+        elif ext == ".pdf":
+            self._pdf(path, name)
+        elif ext == ".docx":
+            self._docx(path, name)
+        elif ext == ".txt":
+            self._txt(path, name)
+        else:
+            raise ValueError(f"Unsupported: {ext}")
+
+        return self
+
+    def _pdf(self, path: Path, name: str):
+        """BULLETPROOF PDF extraction - GUARANTEED to create multiple tables."""
+        if not HAS_PYPDF2:
+            raise ImportError("pip install PyPDF2")
+
+        print(f"Extracting PDF: {path.name}")
+
+        with open(path, 'rb') as f:
+            text = "".join([p.extract_text() + "\n" for p in PyPDF2.PdfReader(f).pages])
+
+        if not self.client:
+            print("No API key - using simple extraction")
+            self._store(pd.DataFrame({'line': range(1, len(text.split('\n'))), 'text': text.split('\n')}), name)
+            return
+
+        print("AI: Extracting entities (BULLETPROOF mode)...")
+
+        # TRY 3 TIMES with progressively simpler prompts
+        entities = None
+
+        # ATTEMPT 1: Full extraction
+        entities = self._extract(text, attempt=1)
+
+        # ATTEMPT 2: Simpler prompt
+        if not entities or len(entities) == 0:
+            print(" Retry with simpler prompt...")
+            entities = self._extract(text, attempt=2)
+
+        # ATTEMPT 3: Basic extraction
+        if not entities or len(entities) == 0:
+            print(" Final retry with basic prompt...")
+            entities = self._extract(text, attempt=3)
+
+        # SUCCESS - Create tables
+        if entities and len(entities) > 0:
+            print(f"SUCCESS! Extracted {len(entities)} entity types:")
+            for etype, recs in entities.items():
+                if recs and len(recs) > 0:
+                    # Renumber IDs
+                    for idx, rec in enumerate(recs, 1):
+                        rec['id'] = idx
+
+                    df = pd.DataFrame(recs)
+                    self._store(df, f"{name}_{etype}")
+                    print(f" {etype}: {len(df)} rows")
+            return
+
+        # LAST RESORT - Force at least people table from text analysis
+        print("WARNING: AI extraction failed 3 times - using text analysis...")
+
+        # Try to extract at least names/emails with regex
+        people = []
+        emails = re.findall(r'[\w\.-]+@[\w\.-]+\.\w+', text)
+        names = re.findall(r'(?:Employee|Mr\.|Mrs\.|Ms\.|Dr\.)\s+([A-Z][a-z]+(?:\s+[A-Z][a-z]+)+)', text)
+
+        for i, (email, name_match) in enumerate(zip(emails[:50], names[:50] if names else [f"Person {i+1}" for i in range(len(emails))]), 1):
+            people.append({'id': i, 'name': name_match if isinstance(name_match, str) else f"Person {i}", 'email': email})
+
+        if people:
+            self._store(pd.DataFrame(people), f"{name}_people")
+            print(f" Extracted {len(people)} people via regex")
+        else:
+            # Absolute fallback
+            self._store(pd.DataFrame({'line': range(1, min(100, len(text.split('\n')))), 'text': text.split('\n')[:100]}), name)
+
+    def _extract(self, text: str, attempt: int) -> Dict:
+        """Extract with different strategies."""
+        if not self.client:
+            return {}
+
+        try:
+            if attempt == 1:
+                # Detailed extraction
+                sys_msg = "You are a JSON extraction expert. Extract ALL entities with unique sequential IDs and proper foreign keys. Return ONLY valid JSON, absolutely no other text."
+                usr_msg = f"""Extract ALL structured entities from this text into a JSON object.
+
+Text (first 15000 chars):
+{text[:15000]}
+
+Create separate arrays for these entity types (only if data exists):
+- people: id (int), name (str), email (str), phone (str), address (str), city (str), state (str), zip (str)
+- skills: id (int), person_id (int), skill_name (str), proficiency (str), years (int)
+- technologies: id (int), person_id (int), technology (str), category (str), proficiency (str)
+- projects: id (int), person_id (int), project_name (str), description (str), start_date (str), end_date (str)
+- certifications: id (int), person_id (int), cert_name (str), issuer (str), date_obtained (str)
+- education: id (int), person_id (int), degree (str), institution (str), graduation_year (str)
+- work_experience: id (int), person_id (int), company (str), title (str), start_date (str), end_date (str)
+
+CRITICAL RULES:
+1. IDs must be unique sequential integers: 1, 2, 3, 4...
+2. person_id in related tables MUST reference valid people.id values
+3. Extract EVERY person, skill, technology, project you find
+4. Return ONLY the JSON object, no markdown, no explanations
+
+Example output format:
+{{
+"people": [
+{{"id": 1, "name": "Sarah Johnson", "email": "sarah@company.com", "phone": "(212) 555-0147", "city": "New York", "state": "NY"}},
+{{"id": 2, "name": "Michael Chen", "email": "michael@company.com", "phone": "(415) 555-0283", "city": "San Francisco", "state": "CA"}}
+],
+"skills": [
+{{"id": 1, "person_id": 1, "skill_name": "Python", "proficiency": "Expert", "years": 5}},
+{{"id": 2, "person_id": 1, "skill_name": "SQL", "proficiency": "Advanced", "years": 3}},
+{{"id": 3, "person_id": 2, "skill_name": "Product Management", "proficiency": "Expert", "years": 7}}
+]
+}}
+
+Now extract from the text above. Return ONLY valid JSON:"""
+
+            elif attempt == 2:
+                # Simplified extraction
+                sys_msg = "Extract entities as JSON. Return only JSON."
+                usr_msg = f"""Text: {text[:10000]}
+
+Extract people, skills, technologies as JSON:
+{{"people":[{{"id":1,"name":"...","email":"...","city":"..."}}],"skills":[{{"id":1,"person_id":1,"skill_name":"..."}}]}}
+
+Rules: Unique IDs (1,2,3...), person_id links to people.id
+
+JSON only:"""
+
+            else:
+                # Basic extraction
+                sys_msg = "Return JSON only."
+                usr_msg = f"""Text: {text[:8000]}
+
+Find people with names, emails, cities. Return as JSON:
+{{"people":[{{"id":1,"name":"John","email":"john@co.com","city":"NYC"}}]}}
+
+JSON:"""
+
+            resp = self.client.chat.completions.create(
+                model="gpt-4o-mini",
+                messages=[
+                    {"role": "system", "content": sys_msg},
+                    {"role": "user", "content": usr_msg}
+                ],
+                temperature=0,
+                max_tokens=12000
+            )
+
+            raw = resp.choices[0].message.content.strip()
+
+            # AGGRESSIVE JSON extraction
+            raw = raw.replace("```json", "").replace("```", "").replace("JSON:", "").replace("json", "").strip()
+
+            # Find JSON object
+            start = raw.find('{')
+            end = raw.rfind('}') + 1
+
+            if start < 0 or end <= start:
+                return {}
+
+            json_str = raw[start:end]
+
+            # Parse
+            result = json.loads(json_str)
+
+            # Validate
+            if isinstance(result, dict) and len(result) > 0:
+                # Check if at least one entity type has data
+                has_data = any(isinstance(v, list) and len(v) > 0 for v in result.values())
+                if has_data:
+                    return result
+
+            return {}
+
+        except Exception as e:
+            print(f" Attempt {attempt} failed: {e}")
+            return {}
+
+    def _docx(self, path: Path, name: str):
+        """DOCX."""
+        if not HAS_DOCX:
+            raise ImportError("pip install python-docx")
+        doc = docx.Document(path)
+        if doc.tables:
+            for i, t in enumerate(doc.tables):
+                data = [[cell.text.strip() for cell in row.cells] for row in t.rows]
+                if data and len(data) > 1:
+                    self._store(pd.DataFrame(data[1:], columns=data[0]), f"{name}_t{i+1}")
+        else:
+            text = "\n".join([p.text for p in doc.paragraphs])
+            self._store(pd.DataFrame({'line': range(len(text.split('\n'))), 'text': text.split('\n')}), name)
+
+    def _txt(self, path: Path, name: str):
+        """TXT."""
+        with open(path, 'r', encoding='utf-8') as f:
+            text = f.read()
+        self._store(pd.DataFrame({'line': range(len(text.split('\n'))), 'text': text.split('\n')}), name)
+
+    def _store(self, df: pd.DataFrame, name: str):
+        """Store."""
+        df.columns = [str(c).strip().replace(" ", "_").replace("-", "_") for c in df.columns]
+        try:
+            df.to_sql(name, self.conn, if_exists='replace', index=False, method='multi', chunksize=500)
+        except:
+            df.to_sql(name, self.conn, if_exists='replace', index=False)
+        self.conn.commit()
+        self.current_table = name
+        self._refresh_schema()
+        print(f" {name}: {len(df)} rows")
+
+    def ask(self, q: str, viz: Union[bool, str] = False, table: Optional[str] = None) -> 'QueryResult':
+        """Query."""
+        if not self.client:
+            return QueryResult(False, "", pd.DataFrame(), None, "No API")
+
+        t = table or self.current_table or (self._get_tables()[0] if self._get_tables() else None)
+        if not t:
+            return QueryResult(False, "", pd.DataFrame(), None, "No table")
+
+        if self.use_embeddings and self.embedding_model:
+            cached = self._check_cache(q, t)
+            if cached:
+                return cached
+
+        if self.fuzzy_match:
+            q = self._fuzzy(q, t)
+
+        key = hashlib.md5(f"{q}:{t}".encode()).hexdigest()
+        if self.cache_queries and self.cache and key in self.cache:
+            sql = self.cache[key]
+        else:
+            sql = self._gen_sql(q, t)
+            if self.cache_queries and self.cache:
+                self.cache[key] = sql
+
+        print(f"SQL: {sql}")
+
+        try:
+            df = pd.read_sql_query(sql, self.conn)
+            print(f"Success! {len(df)} rows")
+            fig = self._viz(df, q, viz if isinstance(viz, str) else "auto") if viz else None
+            r = QueryResult(True, sql, df, fig)
+
+            if self.use_embeddings and self.embedding_model:
+                self._store_cache(q, t, r)
+
+            return r
+        except Exception as e:
+            return QueryResult(False, sql, pd.DataFrame(), None, str(e))
+
+    def _fuzzy(self, q: str, t: str) -> str:
+        """Fuzzy match."""
+        try:
+            cols = [c for c, d in self.schema_info.get(t, {}).items() if 'TEXT' in d]
+            if not cols:
+                return q
+            for col in cols[:2]:
+                df = pd.read_sql_query(f"SELECT DISTINCT {col} FROM {t} LIMIT 100", self.conn)
+                vals = [str(v) for v in df[col].dropna()]
+                words = q.split()
+                for i, w in enumerate(words):
+                    m = get_close_matches(w, vals, n=1, cutoff=0.6)
+                    if m and w != m[0]:
+                        words[i] = m[0]
+                q = " ".join(words)
+            return q
+        except:
+            return q
+
+    def _check_cache(self, q: str, t: str) -> Optional['QueryResult']:
+        """Check cache."""
+        if not self.query_embeddings:
+            return None
+        emb = self.embedding_model.encode([q])[0]
+        best, sim = None, 0.85
+        for cq, d in self.query_embeddings.items():
+            if d['table'] != t:
+                continue
+            s = np.dot(emb, d['embedding']) / (np.linalg.norm(emb) * np.linalg.norm(d['embedding']))
+            if s > sim:
+                sim, best = s, cq
+        return self.query_embeddings[best]['result'] if best else None
+
+    def _store_cache(self, q: str, t: str, r: 'QueryResult'):
+        """Store cache."""
+        emb = self.embedding_model.encode([q])[0]
+        self.query_embeddings[q] = {'table': t, 'embedding': emb, 'result': r}
+
+    def _viz(self, df: pd.DataFrame, title: str, vt: str):
+        """Viz."""
+        if not HAS_PLOTLY:
+            return None
+        try:
+            n = df.select_dtypes(include=[np.number]).columns.tolist()
+            c = df.select_dtypes(include=['object']).columns.tolist()
+            if vt == "pie" and c and n:
+                fig = px.pie(df, names=c[0], values=n[0], title=title)
+            elif vt == "bar" and c and n:
+                fig = px.bar(df, x=c[0], y=n[0], title=title)
+            elif vt == "line" and n:
+                fig = px.line(df, y=n[0], title=title)
+            elif vt == "scatter" and len(n) >= 2:
+                fig = px.scatter(df, x=n[0], y=n[1], title=title)
+            else:
+                fig = px.bar(df, y=df.columns[0], title=title)
+            fig.show()
+            return fig
+        except:
+            return None
+
+    def tables(self) -> Dict:
+        """List tables."""
+        t = self._get_tables()
+        print("\n" + "="*70)
+        print("TABLES")
+        print("="*70)
+        if not t:
+            print("No tables")
+            return {}
+        r = {}
+        for i, tb in enumerate(t, 1):
+            cnt = pd.read_sql_query(f"SELECT COUNT(*) FROM {tb}", self.conn).iloc[0, 0]
+            cols = list(self.schema_info.get(tb, {}).keys())
+            print(f" {i}. {tb}: {cnt} rows, {len(cols)} cols")
+            r[tb] = {'rows': cnt, 'columns': cols}
+        print("="*70)
+        return r
+
+    def schema(self, table: Optional[str] = None) -> Dict:
+        """Schema."""
+        if not self.schema_info:
+            self._refresh_schema()
+        print("\n" + "="*70)
+        print("SCHEMA")
+        print("="*70)
+        r = {}
+        for t in ([table] if table else self.schema_info.keys()):
+            if t in self.schema_info:
+                cnt = pd.read_sql_query(f"SELECT COUNT(*) FROM {t}", self.conn).iloc[0, 0]
+                print(f"\n{t}: {cnt} records")
+                for c, d in self.schema_info[t].items():
+                    print(f" - {c:<30} {d}")
+                r[t] = {'records': cnt, 'columns': self.schema_info[t]}
+        print("="*70)
+        return r
+
+    def peek(self, table: Optional[str] = None, n: int = 5) -> pd.DataFrame:
+        """Preview."""
+        t = table or self.current_table
+        if not t:
+            return pd.DataFrame()
+        df = pd.read_sql_query(f"SELECT * FROM {t} LIMIT {n}", self.conn)
+        print(f"\nSample from '{t}':")
+        print(df.to_string(index=False))
+        return df
+
+    def sql(self, query: str, viz: Union[bool, str] = False) -> 'QueryResult':
+        """SQL."""
+        try:
+            df = pd.read_sql_query(query, self.conn)
+            print(f"Success! {len(df)} rows")
+            fig = self._viz(df, "Result", viz if isinstance(viz, str) else "auto") if viz else None
+            return QueryResult(True, query, df, fig)
+        except Exception as e:
+            return QueryResult(False, query, pd.DataFrame(), None, str(e))
+
+    def save_to_mysql(self, host: str, user: str, password: str, database: str, port: int = 3306):
+        """MySQL export."""
+        try:
+            from sqlalchemy import create_engine
+            import mysql.connector
+        except:
+            raise ImportError("pip install QuerySUTRA[mysql]")
+
+        print(f"Exporting to MySQL: {database}")
+
+        try:
+            tc = mysql.connector.connect(host=host, user=user, password=password, port=port)
+            tc.cursor().execute(f"CREATE DATABASE IF NOT EXISTS `{database}`")
+            tc.close()
+        except:
+            pass
+
+        engine = create_engine(f"mysql+mysqlconnector://{user}:{password}@{host}:{port}/{database}")
+        for t in self._get_tables():
+            df = pd.read_sql_query(f"SELECT * FROM {t}", self.conn)
+            df.to_sql(t, engine, if_exists='replace', index=False)
+            print(f" {t}: {len(df)} rows")
+        print("Done!")
+        return self
+
+    def export_db(self, path: str, format: str = "sqlite"):
+        """Export."""
+        if format == "sqlite":
+            shutil.copy2(self.db_path, path)
+        elif format == "json":
+            data = {t: pd.read_sql_query(f"SELECT * FROM {t}", self.conn).to_dict('records') for t in self._get_tables()}
+            with open(path, 'w') as f:
+                json.dump(data, f, indent=2, default=str)
+        print(f"Saved: {path}")
+        return self
+
+    @classmethod
+    def load_from_db(cls, db_path: str, api_key: Optional[str] = None, **kwargs):
+        """Load database."""
+        if not Path(db_path).exists():
+            raise FileNotFoundError(f"Not found: {db_path}")
+        return cls(api_key=api_key, db=db_path, **kwargs)
+
+    @classmethod
+    def connect_mysql(cls, host: str, user: str, password: str, database: str, port: int = 3306, api_key: Optional[str] = None, **kwargs):
+        """Connect MySQL."""
+        try:
+            from sqlalchemy import create_engine
+            import mysql.connector
+        except:
+            raise ImportError("pip install QuerySUTRA[mysql]")
+
+        try:
+            tc = mysql.connector.connect(host=host, user=user, password=password, port=port)
+            tc.cursor().execute(f"CREATE DATABASE IF NOT EXISTS {database}")
+            tc.close()
+        except:
+            pass
+
+        engine = create_engine(f"mysql+mysqlconnector://{user}:{password}@{host}:{port}/{database}")
+        temp_db = f"mysql_{database}.db"
+        instance = cls(api_key=api_key, db=temp_db, **kwargs)
+
+        tables = pd.read_sql_query("SHOW TABLES", engine).iloc[:, 0].tolist()
+        for t in tables:
+            pd.read_sql_query(f"SELECT * FROM {t}", engine).to_sql(t, instance.conn, if_exists='replace', index=False)
+
+        instance._refresh_schema()
+        print(f"Connected! {len(tables)} tables")
+        return instance
+
+    def _gen_sql(self, q: str, t: str) -> str:
+        """Generate SQL."""
+        schema = self.schema_info.get(t, {})
+        sample = pd.read_sql_query(f"SELECT * FROM {t} LIMIT 3", self.conn).to_string(index=False)
+        cols = ", ".join([f"{c} ({d})" for c, d in schema.items()])
+
+        r = self.client.chat.completions.create(
+            model="gpt-4o-mini",
+            messages=[
+                {"role": "system", "content": "SQL expert. Return only SQL."},
+                {"role": "user", "content": f"Table: {t}\nColumns: {cols}\nSample:\n{sample}\n\nQ: {q}\n\nSQL:"}
+            ],
+            temperature=0
+        )
+        return r.choices[0].message.content.strip().replace("```sql", "").replace("```", "").strip()
+
+    def _get_tables(self) -> List[str]:
+        """Tables."""
+        self.cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
+        return [r[0] for r in self.cursor.fetchall()]
+
+    def _refresh_schema(self):
+        """Refresh."""
+        self.schema_info = {}
+        for t in self._get_tables():
+            self.cursor.execute(f"PRAGMA table_info({t})")
+            self.schema_info[t] = {r[1]: r[2] for r in self.cursor.fetchall()}
+
+    def close(self):
+        if self.conn:
+            self.conn.close()
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, *args):
+        self.close()
+
+    def __repr__(self):
+        return f"SUTRA(tables={len(self.schema_info)})"
+
+
+class QueryResult:
+    """Result."""
+    def __init__(self, success: bool, sql: str, data: pd.DataFrame, viz, error: str = None):
+        self.success, self.sql, self.data, self.viz, self.error = success, sql, data, viz, error
+
+    def __repr__(self):
+        return f"QueryResult(rows={len(self.data)})" if self.success else f"QueryResult(error='{self.error}')"
+
+    def show(self):
+        print(self.data if self.success else f"Error: {self.error}")
+        return self