QuerySUTRA 0.5.1-py3-none-any.whl → 0.5.3-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {querysutra-0.5.1.dist-info → querysutra-0.5.3.dist-info}/METADATA +1 -1
- {querysutra-0.5.1.dist-info → querysutra-0.5.3.dist-info}/RECORD +7 -7
- sutra/__init__.py +2 -2
- sutra/sutra.py +141 -140
- {querysutra-0.5.1.dist-info → querysutra-0.5.3.dist-info}/WHEEL +0 -0
- {querysutra-0.5.1.dist-info → querysutra-0.5.3.dist-info}/licenses/LICENSE +0 -0
- {querysutra-0.5.1.dist-info → querysutra-0.5.3.dist-info}/top_level.txt +0 -0
{querysutra-0.5.1.dist-info → querysutra-0.5.3.dist-info}/RECORD
@@ -1,5 +1,5 @@
-querysutra-0.5.
-sutra/__init__.py,sha256=
+querysutra-0.5.3.dist-info/licenses/LICENSE,sha256=F-4b93u0OVrVwGXgMwBRq6MlGyUT9zmre1oh5Gft5Ts,1066
+sutra/__init__.py,sha256=25HUMETpmA1tlMl5j-ajdo9MRXljSZBrirSTH7w7jIc,118
 sutra/cache_manager.py,sha256=e0AAeUqoR-aiqzZ3fB-IDvpJ4JA6-YBFyRJxusEnIrA,3082
 sutra/clear_cache.py,sha256=rVIz29p7V11Uh6oHXeaWpFtYXXv-2OED91cHMAWWxtQ,187
 sutra/core.py,sha256=R_JbOlZTukegP92Dr-WLsdr632_otFN7o9qSvcxyBtw,10497
@@ -11,7 +11,7 @@ sutra/feedback_matcher.py,sha256=WXYpGtFJnOyYQOzy-z8uBiUWH5vyJJOMS1NwEYzNfic,286
 sutra/nlp_processor.py,sha256=wMS1hz1aGWjSwPUD7lSNBbQapFtLgF2l65j0QKXQOd0,5461
 sutra/schema_embeddings.py,sha256=bVPzpJOdYTyUdG2k3ZdgYJLrX2opHBx68RIjJcMlueo,9732
 sutra/schema_generator.py,sha256=BX_vXmnvSGc6nCBx40WLSoNL3WIYPDahd1cEYloyY4M,1925
-sutra/sutra.py,sha256=
+sutra/sutra.py,sha256=61juV3zlMau4UZJ-5IxjaN-Bc1XBP8w2vkYfum-aXlY,21979
 sutra/sutra_client.py,sha256=PYYDGqVbA9pB-Zcsm52i9KarwijCIGVZOThgONZP6Vs,14203
 sutra/sutra_core.py,sha256=diaWOXUHn1wrqCQrBhLKL612tMQioaqx-ILc3y9-CqM,11708
 sutra/sutra_simple.py,sha256=rnqzG7OAt4p64XtO0peMqHS1pG5tdA8U3EYTMVsq7BE,23201
@@ -22,7 +22,7 @@ tests/test_sutra.py,sha256=6Z4SoIuBzza101304I7plkyPVkUBbjIxR8uPs9z5ntg,2383
 utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 utils/file_utils.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 utils/text_utils.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-querysutra-0.5.
-querysutra-0.5.
-querysutra-0.5.
-querysutra-0.5.
+querysutra-0.5.3.dist-info/METADATA,sha256=yFffBSYGfbLrYnXA7OFGHk1mO37fpUV-0iglmHXbAVQ,7258
+querysutra-0.5.3.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+querysutra-0.5.3.dist-info/top_level.txt,sha256=9v0buw21eo5LaUU_3Cf9b9MqRyEvtM9cHaOuEXUKVqM,18
+querysutra-0.5.3.dist-info/RECORD,,
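Each RECORD row has the form path,sha256=<digest>,size, where the digest is the urlsafe-base64-encoded SHA-256 of the file with the trailing = padding stripped. A minimal sketch for recomputing one of these entries from an unpacked wheel (the path argument is a placeholder):

    import base64
    import hashlib
    from pathlib import Path

    def record_entry(path: str) -> str:
        # Rebuild a wheel RECORD row: path,sha256=<urlsafe b64 digest, unpadded>,size
        data = Path(path).read_bytes()
        digest = base64.urlsafe_b64encode(hashlib.sha256(data).digest()).rstrip(b"=").decode()
        return f"{path},sha256={digest},{len(data)}"

    # record_entry("sutra/__init__.py") should reproduce the matching row above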
sutra/__init__.py CHANGED
@@ -1,10 +1,5 @@
-"""
-
-GUARANTEED to create multiple tables with proper keys
-NEVER falls back to single table
-"""
-
-__version__ = "0.5.0"
+"""QuerySUTRA v0.5.2 - FIXED: Smart table selection"""
+__version__ = "0.5.2"
 __author__ = "Aditya Batta"
 __all__ = ["SUTRA", "QueryResult"]
 
sutra/sutra.py CHANGED
@@ -46,7 +41,7 @@ except:
 
 
 class SUTRA:
-    """SUTRA -
+    """SUTRA - FIXED: Considers ALL tables"""
 
     def __init__(self, api_key: Optional[str] = None, db: str = "sutra.db",
                  use_embeddings: bool = False, fuzzy_match: bool = True,
@@ -77,10 +72,10 @@ class SUTRA:
             pass
 
         self._refresh_schema()
-        print(f"QuerySUTRA v0.5.
+        print(f"QuerySUTRA v0.5.2 Ready")
 
     def upload(self, data: Union[str, pd.DataFrame], name: Optional[str] = None) -> 'SUTRA':
-        """Upload
+        """Upload."""
         if isinstance(data, pd.DataFrame):
             self._store(data, name or "data")
             return self
@@ -110,7 +105,7 @@ class SUTRA:
         return self
 
     def _pdf(self, path: Path, name: str):
-        """
+        """PDF extraction."""
         if not HAS_PYPDF2:
             raise ImportError("pip install PyPDF2")
 
@@ -120,164 +115,119 @@ class SUTRA:
         text = "".join([p.extract_text() + "\n" for p in PyPDF2.PdfReader(f).pages])
 
         if not self.client:
-            print("No API key
-            self._store(pd.DataFrame({'line': range(1, len(text.split('\n'))), 'text': text.split('\n')}), name)
+            print("ERROR: No API key!")
             return
 
-        print("AI: Extracting
+        print("AI: Extracting...")
 
-        # TRY 3 TIMES with progressively simpler prompts
         entities = None
+        for attempt in [1, 2, 3]:
+            entities = self._extract(text, attempt)
+            if entities and len(entities) > 0:
+                break
+            if attempt < 3:
+                print(f"  Retry {attempt+1}/3...")
 
-        # ATTEMPT 1: Full extraction
-        entities = self._extract(text, attempt=1)
-
-        # ATTEMPT 2: Simpler prompt
-        if not entities or len(entities) == 0:
-            print("  Retry with simpler prompt...")
-            entities = self._extract(text, attempt=2)
-
-        # ATTEMPT 3: Basic extraction
-        if not entities or len(entities) == 0:
-            print("  Final retry with basic prompt...")
-            entities = self._extract(text, attempt=3)
-
-        # SUCCESS - Create tables
         if entities and len(entities) > 0:
-            print(f"
+            print(f"Extracted {len(entities)} entity types:")
             for etype, recs in entities.items():
                 if recs and len(recs) > 0:
-                    # Renumber IDs
                     for idx, rec in enumerate(recs, 1):
                         rec['id'] = idx
-
-
-                    self._store(df, f"{name}_{etype}")
-                    print(f"  {etype}: {len(df)} rows")
+                    self._store(pd.DataFrame(recs), f"{name}_{etype}")
+                    print(f"  {etype}: {len(recs)} rows")
             return
 
-
-        print("WARNING: AI extraction failed 3 times - using text analysis...")
-
-        # Try to extract at least names/emails with regex
+        print("Using regex fallback...")
         people = []
         emails = re.findall(r'[\w\.-]+@[\w\.-]+\.\w+', text)
-
-
-
-
+        name_patterns = [
+            r'(?:Employee|Name|Mr\.|Mrs\.|Ms\.|Dr\.)\s*[:\-]?\s*([A-Z][a-z]+(?:\s+[A-Z][a-z]+)+)',
+            r'([A-Z][a-z]+\s+[A-Z][a-z]+)\s+(?:lives|resides|works|is based)',
+        ]
+
+        names = []
+        for pattern in name_patterns:
+            names.extend(re.findall(pattern, text))
+            if len(names) >= len(emails):
+                break
+
+        max_people = min(len(emails), 50)
+        for i in range(max_people):
+            people.append({
+                'id': i + 1,
+                'name': names[i] if i < len(names) else f"Person {i+1}",
+                'email': emails[i] if i < len(emails) else f"person{i+1}@unknown.com"
+            })
 
         if people:
             self._store(pd.DataFrame(people), f"{name}_people")
-            print(f"  Extracted {len(people)} people
+            print(f"  Extracted {len(people)} people")
         else:
-
-            self._store(pd.DataFrame({'line': range(1,
+            lines = [l.strip() for l in text.split('\n') if l.strip()][:100]
+            self._store(pd.DataFrame({'line': range(1, len(lines)+1), 'text': lines}), name)
 
     def _extract(self, text: str, attempt: int) -> Dict:
-        """Extract
+        """Extract."""
         if not self.client:
             return {}
 
         try:
             if attempt == 1:
-
-
-                usr_msg = f"""Extract ALL structured entities from this text into a JSON object.
+                sys_msg = "Extract entities as JSON. Return ONLY valid JSON."
+                usr_msg = f"""Extract ALL entities.
 
-Text
+Text:
 {text[:15000]}
 
-
-- people: id (int), name (str), email (str), phone (str), address (str), city (str), state (str), zip (str)
-- skills: id (int), person_id (int), skill_name (str), proficiency (str), years (int)
-- technologies: id (int), person_id (int), technology (str), category (str), proficiency (str)
-- projects: id (int), person_id (int), project_name (str), description (str), start_date (str), end_date (str)
-- certifications: id (int), person_id (int), cert_name (str), issuer (str), date_obtained (str)
-- education: id (int), person_id (int), degree (str), institution (str), graduation_year (str)
-- work_experience: id (int), person_id (int), company (str), title (str), start_date (str), end_date (str)
+JSON with: people, skills, technologies, projects, certifications, education, work_experience
 
-
-1.
-2. person_id in related tables MUST reference valid people.id values
-3. Extract EVERY person, skill, technology, project you find
-4. Return ONLY the JSON object, no markdown, no explanations
+Example:
+{{"people":[{{"id":1,"name":"Sarah","email":"s@co.com","city":"NYC","state":"NY"}}],"skills":[{{"id":1,"person_id":1,"skill_name":"Python"}}]}}
 
-
-{{
-"people": [
-{{"id": 1, "name": "Sarah Johnson", "email": "sarah@company.com", "phone": "(212) 555-0147", "city": "New York", "state": "NY"}},
-{{"id": 2, "name": "Michael Chen", "email": "michael@company.com", "phone": "(415) 555-0283", "city": "San Francisco", "state": "CA"}}
-],
-"skills": [
-{{"id": 1, "person_id": 1, "skill_name": "Python", "proficiency": "Expert", "years": 5}},
-{{"id": 2, "person_id": 1, "skill_name": "SQL", "proficiency": "Advanced", "years": 3}},
-{{"id": 3, "person_id": 2, "skill_name": "Product Management", "proficiency": "Expert", "years": 7}}
-]
-}}
+Unique IDs (1,2,3...), person_id links to people.id
 
-
-
+JSON:"""
             elif attempt == 2:
-
-                sys_msg = "Extract entities as JSON. Return only JSON."
+                sys_msg = "Return JSON."
                 usr_msg = f"""Text: {text[:10000]}
 
-Extract people
-{{"people":[{{"id":1,"name":"...","email":"..."
-
-Rules: Unique IDs (1,2,3...), person_id links to people.id
+Extract people:
+{{"people":[{{"id":1,"name":"...","email":"..."}}]}}
 
-JSON
-
+JSON:"""
             else:
-
-
-                usr_msg = f"""Text: {text[:8000]}
+                sys_msg = "JSON."
+                usr_msg = f"""Names/emails from: {text[:8000]}
 
-
-{{"people":[{{"id":1,"name":"John","email":"john@co.com","city":"NYC"}}]}}
-
-JSON:"""
+{{"people":[{{"id":1,"name":"John","email":"j@co.com"}}]}}"""
 
-
+            r = self.client.chat.completions.create(
                 model="gpt-4o-mini",
-                messages=[
-                    {"role": "system", "content": sys_msg},
-                    {"role": "user", "content": usr_msg}
-                ],
+                messages=[{"role": "system", "content": sys_msg}, {"role": "user", "content": usr_msg}],
                 temperature=0,
                 max_tokens=12000
             )
 
-            raw =
-
-            # AGGRESSIVE JSON extraction
-            raw = raw.replace("```json", "").replace("```", "").replace("JSON:", "").replace("json", "").strip()
+            raw = r.choices[0].message.content.strip()
+            raw = raw.replace("```json", "").replace("```", "").replace("JSON:", "").strip()
 
-            # Find JSON object
             start = raw.find('{')
             end = raw.rfind('}') + 1
 
             if start < 0 or end <= start:
                 return {}
 
-
+            result = json.loads(raw[start:end])
 
-            # Parse
-            result = json.loads(json_str)
-
-            # Validate
             if isinstance(result, dict) and len(result) > 0:
-                # Check if at least one entity type has data
                 has_data = any(isinstance(v, list) and len(v) > 0 for v in result.values())
                 if has_data:
                     return result
-
             return {}
 
         except Exception as e:
-            print(f"  Attempt {attempt} failed: {e}")
+            print(f"  Attempt {attempt} failed: {str(e)[:100]}")
             return {}
 
     def _docx(self, path: Path, name: str):
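The cleanup that _extract performs on the model output above (strip code fences, slice from the first '{' to the last '}', then json.loads) is a reusable defensive-parsing technique. A standalone sketch of the same logic with an explicit JSONDecodeError guard added; the helper name is mine, not part of the package:

    import json
    from typing import Dict

    def parse_json_block(raw: str) -> Dict:
        # Drop markdown fences and any "JSON:" prefix, then keep only the
        # outermost {...} span before parsing, as _extract does.
        raw = raw.replace("```json", "").replace("```", "").replace("JSON:", "").strip()
        start, end = raw.find("{"), raw.rfind("}") + 1
        if start < 0 or end <= start:
            return {}
        try:
            return json.loads(raw[start:end])
        except json.JSONDecodeError:
            return {}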
@@ -292,13 +242,15 @@ JSON:"""
             self._store(pd.DataFrame(data[1:], columns=data[0]), f"{name}_t{i+1}")
         else:
             text = "\n".join([p.text for p in doc.paragraphs])
-
+            lines = [l.strip() for l in text.split('\n') if l.strip()]
+            self._store(pd.DataFrame({'line': range(1, len(lines)+1), 'text': lines}), name)
 
     def _txt(self, path: Path, name: str):
         """TXT."""
         with open(path, 'r', encoding='utf-8') as f:
             text = f.read()
-
+        lines = [l.strip() for l in text.split('\n') if l.strip()]
+        self._store(pd.DataFrame({'line': range(1, len(lines)+1), 'text': lines}), name)
 
     def _store(self, df: pd.DataFrame, name: str):
         """Store."""
@@ -313,29 +265,32 @@ JSON:"""
         print(f"  {name}: {len(df)} rows")
 
     def ask(self, q: str, viz: Union[bool, str] = False, table: Optional[str] = None) -> 'QueryResult':
-        """
+        """
+        Query - FIXED: Considers ALL tables, picks best one or joins multiple.
+        """
         if not self.client:
             return QueryResult(False, "", pd.DataFrame(), None, "No API")
 
-
-
-
-
-
-
-
-
+        print(f"\nQuestion: {q}")
+
+        # FIXED: If no table specified, let AI pick the right one(s)
+        if not table:
+            # Get ALL table schemas
+            all_schemas = {}
+            for tbl in self._get_tables():
+                all_schemas[tbl] = {
+                    'columns': list(self.schema_info.get(tbl, {}).keys()),
+                    'row_count': pd.read_sql_query(f"SELECT COUNT(*) FROM {tbl}", self.conn).iloc[0, 0]
+                }
+
+            # Let AI decide which table(s) to use
+            sql = self._gen_sql_smart(q, all_schemas)
+        else:
+            # Use specified table
+            sql = self._gen_sql(q, table)
 
         if self.fuzzy_match:
-            q = self._fuzzy(q,
-
-            key = hashlib.md5(f"{q}:{t}".encode()).hexdigest()
-            if self.cache_queries and self.cache and key in self.cache:
-                sql = self.cache[key]
-            else:
-                sql = self._gen_sql(q, t)
-                if self.cache_queries and self.cache:
-                    self.cache[key] = sql
+            q = self._fuzzy(q, table or self._get_tables()[0])
 
         print(f"SQL: {sql}")
 
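The new branch above iterates self._get_tables(), whose implementation is not part of this diff. For a SQLite connection it is presumably close to the following sketch (the function name and query are assumptions):

    import sqlite3

    def get_tables(conn: sqlite3.Connection) -> list:
        # List user tables from sqlite_master, skipping SQLite's internal tables.
        rows = conn.execute(
            "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%'"
        ).fetchall()
        return [r[0] for r in rows]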
@@ -346,14 +301,60 @@ JSON:"""
             r = QueryResult(True, sql, df, fig)
 
             if self.use_embeddings and self.embedding_model:
-                self._store_cache(q,
+                self._store_cache(q, table or "all", r)
 
             return r
         except Exception as e:
+            print(f"Error: {e}")
             return QueryResult(False, sql, pd.DataFrame(), None, str(e))
 
+    def _gen_sql_smart(self, q: str, all_schemas: Dict) -> str:
+        """
+        FIXED: Generate SQL considering ALL tables and their relationships.
+        """
+        # Build context with ALL tables
+        schema_context = "Database has these tables:\n"
+        for tbl, info in all_schemas.items():
+            schema_context += f"\n{tbl} ({info['row_count']} rows):\n"
+            schema_context += f"  Columns: {', '.join(info['columns'])}\n"
+
+        # Add sample data from key tables
+        samples = ""
+        for tbl in list(all_schemas.keys())[:3]:  # First 3 tables
+            try:
+                sample_df = pd.read_sql_query(f"SELECT * FROM {tbl} LIMIT 2", self.conn)
+                samples += f"\nSample from {tbl}:\n{sample_df.to_string(index=False)}\n"
+            except:
+                pass
+
+        prompt = f"""You are an SQL expert. Generate a query for this question.
+
+{schema_context}
+
+{samples}
+
+Question: {q}
+
+Rules:
+1. Use JOIN if question needs data from multiple tables
+2. If asking about "employee" or "person" info, always include employee_data_people table
+3. Use proper foreign key relationships (person_id references people.id)
+4. Return employee names/info when asked "which employee" or "who"
+
+Return ONLY the SQL query, no explanations:"""
+
+        r = self.client.chat.completions.create(
+            model="gpt-4o-mini",
+            messages=[
+                {"role": "system", "content": "SQL expert. Generate queries using proper JOINs. Return only SQL."},
+                {"role": "user", "content": prompt}
+            ],
+            temperature=0
+        )
+        return r.choices[0].message.content.strip().replace("```sql", "").replace("```", "").strip()
+
     def _fuzzy(self, q: str, t: str) -> str:
-        """Fuzzy
+        """Fuzzy."""
         try:
             cols = [c for c, d in self.schema_info.get(t, {}).items() if 'TEXT' in d]
             if not cols:
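From a caller's perspective, the two code paths above amount to something like this hypothetical session (the API key, file name, table name, and comments are placeholders, not taken from this diff):

    from sutra import SUTRA

    s = SUTRA(api_key="sk-...", db="demo.db")    # prints "QuerySUTRA v0.5.2 Ready"
    s.upload("employees.pdf")                    # AI extraction; regex fallback on failure

    r1 = s.ask("Which employees know Python?")   # no table given: _gen_sql_smart sees all tables
    r2 = s.ask("List every skill", table="employees_skills")  # explicit table: single-table _gen_sql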
@@ -372,7 +373,7 @@ JSON:"""
             return q
 
     def _check_cache(self, q: str, t: str) -> Optional['QueryResult']:
-        """
+        """Cache."""
         if not self.query_embeddings:
             return None
         emb = self.embedding_model.encode([q])[0]
@@ -386,7 +387,7 @@ JSON:"""
         return self.query_embeddings[best]['result'] if best else None
 
     def _store_cache(self, q: str, t: str, r: 'QueryResult'):
-        """Store
+        """Store."""
         emb = self.embedding_model.encode([q])[0]
         self.query_embeddings[q] = {'table': t, 'embedding': emb, 'result': r}
 
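The similarity scoring that produces best in _check_cache falls in the gap between the two hunks above. A plausible sketch of that lookup, assuming cosine similarity against every cached embedding with a fixed threshold (both the metric and the threshold are assumptions):

    import numpy as np

    def best_cached(query_emb, cache: dict, threshold: float = 0.9):
        # Compare the new query's embedding to each cached query's embedding and
        # return the stored result of the closest match above the threshold.
        best, best_sim = None, threshold
        for q, entry in cache.items():
            e = entry['embedding']
            sim = float(np.dot(query_emb, e) / (np.linalg.norm(query_emb) * np.linalg.norm(e)))
            if sim > best_sim:
                best, best_sim = q, sim
        return cache[best]['result'] if best is not None else None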
@@ -413,7 +414,7 @@ JSON:"""
         return None
 
     def tables(self) -> Dict:
-        """
+        """Tables."""
         t = self._get_tables()
         print("\n" + "="*70)
         print("TABLES")
@@ -469,7 +470,7 @@ JSON:"""
             return QueryResult(False, query, pd.DataFrame(), None, str(e))
 
     def save_to_mysql(self, host: str, user: str, password: str, database: str, port: int = 3306):
-        """MySQL
+        """MySQL."""
         try:
             from sqlalchemy import create_engine
             import mysql.connector
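The body of save_to_mysql is cut off right after its imports. The usual pandas/SQLAlchemy pattern for this kind of copy looks like the sketch below; it is a guess at the shape, not the package's actual code:

    import pandas as pd
    from sqlalchemy import create_engine

    def copy_sqlite_to_mysql(conn, tables, host, user, password, database, port=3306):
        # Read each SQLite table into a DataFrame and push it to MySQL via SQLAlchemy.
        engine = create_engine(f"mysql+mysqlconnector://{user}:{password}@{host}:{port}/{database}")
        for t in tables:
            df = pd.read_sql_query(f"SELECT * FROM {t}", conn)
            df.to_sql(t, engine, if_exists="replace", index=False)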
@@ -506,14 +507,14 @@ JSON:"""
 
     @classmethod
     def load_from_db(cls, db_path: str, api_key: Optional[str] = None, **kwargs):
-        """Load
+        """Load."""
         if not Path(db_path).exists():
             raise FileNotFoundError(f"Not found: {db_path}")
         return cls(api_key=api_key, db=db_path, **kwargs)
 
     @classmethod
     def connect_mysql(cls, host: str, user: str, password: str, database: str, port: int = 3306, api_key: Optional[str] = None, **kwargs):
-        """
+        """MySQL."""
         try:
             from sqlalchemy import create_engine
             import mysql.connector
@@ -540,7 +541,7 @@ JSON:"""
         return instance
 
     def _gen_sql(self, q: str, t: str) -> str:
-        """
+        """SQL for single table."""
         schema = self.schema_info.get(t, {})
         sample = pd.read_sql_query(f"SELECT * FROM {t} LIMIT 3", self.conn).to_string(index=False)
         cols = ", ".join([f"{c} ({d})" for c, d in schema.items()])
{querysutra-0.5.1.dist-info → querysutra-0.5.3.dist-info}/WHEEL: file without changes
{querysutra-0.5.1.dist-info → querysutra-0.5.3.dist-info}/licenses/LICENSE: file without changes
{querysutra-0.5.1.dist-info → querysutra-0.5.3.dist-info}/top_level.txt: file without changes