QuerySUTRA 0.3.3-py3-none-any.whl → 0.4.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- querysutra-0.4.0.dist-info/METADATA +438 -0
- {querysutra-0.3.3.dist-info → querysutra-0.4.0.dist-info}/RECORD +7 -7
- sutra/__init__.py +2 -4
- sutra/sutra.py +251 -457
- querysutra-0.3.3.dist-info/METADATA +0 -285
- {querysutra-0.3.3.dist-info → querysutra-0.4.0.dist-info}/WHEEL +0 -0
- {querysutra-0.3.3.dist-info → querysutra-0.4.0.dist-info}/licenses/LICENSE +0 -0
- {querysutra-0.3.3.dist-info → querysutra-0.4.0.dist-info}/top_level.txt +0 -0
sutra/sutra.py
CHANGED
@@ -1,19 +1,20 @@
 """
-QuerySUTRA v0.3.3
+QuerySUTRA v0.3.5 - FIXED COLAB COMPATIBILITY
 SUTRA: Structured-Unstructured-Text-Retrieval-Architecture
 
-FIXED:
-- …
-- …
-- …
-- …
+FIXED:
+- Colab disk I/O errors resolved
+- Batch processing for large datasets
+- Proper error handling
+- Unique IDs and proper foreign keys
+- Comprehensive entity extraction
 
 Author: Aditya Batta
 License: MIT
-Version: 0.3.3
+Version: 0.3.5
 """
 
-__version__ = "0.3.3"
+__version__ = "0.3.5"
 __author__ = "Aditya Batta"
 __title__ = "QuerySUTRA: Structured-Unstructured-Text-Retrieval-Architecture"
 __all__ = ["SUTRA", "QueryResult", "quick_start"]
@@ -75,8 +76,6 @@ except ImportError:
 class SUTRA:
     """
     SUTRA: Structured-Unstructured-Text-Retrieval-Architecture
-
-    Professional data analysis with proper relational database structure
     """
 
     def __init__(self,
@@ -86,9 +85,8 @@
                  check_relevance: bool = False,
                  fuzzy_match: bool = True,
                  cache_queries: bool = True):
-        """Initialize SUTRA…
-        print("Initializing QuerySUTRA v0.3.3")
-        print("SUTRA: Structured-Unstructured-Text-Retrieval-Architecture")
+        """Initialize SUTRA."""
+        print("Initializing QuerySUTRA v0.3.5")
 
         if api_key:
             os.environ["OPENAI_API_KEY"] = api_key
@@ -97,7 +95,15 @@
         self.client = OpenAI(api_key=self.api_key) if self.api_key and HAS_OPENAI else None
 
         self.db_path = db
-…
+
+        # FIXED: Better connection handling for Colab
+        try:
+            self.conn = sqlite3.connect(db, timeout=30, check_same_thread=False)
+            self.conn.execute("PRAGMA journal_mode=WAL")
+            self.conn.execute("PRAGMA synchronous=NORMAL")
+        except:
+            self.conn = sqlite3.connect(db, check_same_thread=False)
+
         self.cursor = self.conn.cursor()
 
         self.current_table = None
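Why the new connection block helps: SQLite's write-ahead log needs shared-memory support from the filesystem, which network-mounted storage (for example a Google Drive mount in Colab) may refuse, so the pragma is wrapped in a try and the bare except reconnects with the default rollback journal. A minimal sketch of the same pattern; the database path is illustrative:

    import sqlite3

    def connect(db_path: str = "demo.db") -> sqlite3.Connection:
        try:
            # WAL lets readers proceed while a write is in flight;
            # synchronous=NORMAL trades a little durability for fewer fsyncs.
            conn = sqlite3.connect(db_path, timeout=30, check_same_thread=False)
            conn.execute("PRAGMA journal_mode=WAL")
            conn.execute("PRAGMA synchronous=NORMAL")
        except sqlite3.Error:
            # Fall back to the default journal where WAL is unsupported.
            conn = sqlite3.connect(db_path, check_same_thread=False)
        return conn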
@@ -115,37 +121,29 @@
 
         if use_embeddings and HAS_EMBEDDINGS:
             try:
-                print("Loading embeddings model...")
                 self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
                 print("Embeddings ready")
             except:
-                print("Embeddings unavailable")
                 self.use_embeddings = False
 
         self._refresh_schema()
-
         print(f"Ready! Database: {db}")
-        if not self.api_key:
-            print("No API key - use .sql() for direct queries")
 
     @classmethod
     def load_from_db(cls, db_path: str, api_key: Optional[str] = None, **kwargs):
-        """Load existing…
+        """Load existing database."""
         if not Path(db_path).exists():
             raise FileNotFoundError(f"Database not found: {db_path}")
 
-        print(f"Loading…
+        print(f"Loading: {db_path}")
         instance = cls(api_key=api_key, db=db_path, **kwargs)
-
-        tables = instance.tables()
-        print(f"Loaded {len(tables)} tables")
-
+        print(f"Loaded {len(instance.tables())} tables")
         return instance
 
     @classmethod
     def connect_mysql(cls, host: str, user: str, password: str, database: str,
                       port: int = 3306, api_key: Optional[str] = None, **kwargs):
-        """Connect to MySQL…
+        """Connect to MySQL."""
         try:
             from sqlalchemy import create_engine
         except ImportError:
@@ -153,16 +151,13 @@
 
         print(f"Connecting to MySQL: {host}:{port}/{database}")
 
-…
+        engine = create_engine(f"mysql+mysqlconnector://{user}:{password}@{host}:{port}/{database}")
 
         temp_db = f"sutra_mysql_{database}.db"
         instance = cls(api_key=api_key, db=temp_db, **kwargs)
 
-        engine = create_engine(connection_string)
-
         tables = pd.read_sql_query("SHOW TABLES", engine).iloc[:, 0].tolist()
-
-        print(f"Found {len(tables)} tables, syncing...")
+        print(f"Syncing {len(tables)} tables...")
 
         for table in tables:
             df = pd.read_sql_query(f"SELECT * FROM {table}", engine)
@@ -170,14 +165,13 @@
             print(f"  {table}: {len(df)} rows")
 
         instance._refresh_schema()
-        print(…
-
+        print("Connected!")
         return instance
 
     @classmethod
     def connect_postgres(cls, host: str, user: str, password: str, database: str,
                          port: int = 5432, api_key: Optional[str] = None, **kwargs):
-        """Connect to PostgreSQL…
+        """Connect to PostgreSQL."""
         try:
             from sqlalchemy import create_engine
         except ImportError:
@@ -185,19 +179,13 @@
 
         print(f"Connecting to PostgreSQL: {host}:{port}/{database}")
 
-…
+        engine = create_engine(f"postgresql://{user}:{password}@{host}:{port}/{database}")
 
         temp_db = f"sutra_postgres_{database}.db"
         instance = cls(api_key=api_key, db=temp_db, **kwargs)
 
-…
-…
-        tables = pd.read_sql_query(
-            "SELECT tablename FROM pg_tables WHERE schemaname='public'",
-            engine
-        )['tablename'].tolist()
-
-        print(f"Found {len(tables)} tables, syncing...")
+        tables = pd.read_sql_query("SELECT tablename FROM pg_tables WHERE schemaname='public'", engine)['tablename'].tolist()
+        print(f"Syncing {len(tables)} tables...")
 
         for table in tables:
             df = pd.read_sql_query(f"SELECT * FROM {table}", engine)
@@ -205,21 +193,13 @@
             print(f"  {table}: {len(df)} rows")
 
         instance._refresh_schema()
-        print(…
-
+        print("Connected!")
         return instance
 
     def upload(self, data: Union[str, pd.DataFrame], name: Optional[str] = None,
               extract_entities: Optional[List[str]] = None) -> 'SUTRA':
-        """
-        …
-
-        Args:
-            data: File path or DataFrame
-            name: Table name
-            extract_entities: Custom entities to extract (e.g., ['skills', 'technologies'])
-        """
-        print(f"\nUploading data...")
+        """Upload data."""
+        print("\nUploading data...")
 
         if isinstance(data, pd.DataFrame):
             name = name or "data"
@@ -238,42 +218,35 @@
         if ext == ".csv":
             df = pd.read_csv(path)
             self._store_dataframe(df, name)
-
         elif ext in [".xlsx", ".xls"]:
             df = pd.read_excel(path)
             self._store_dataframe(df, name)
-
         elif ext == ".json":
             df = pd.read_json(path)
             self._store_dataframe(df, name)
-
         elif ext == ".sql":
             with open(path) as f:
                 self.cursor.executescript(f.read())
             self.conn.commit()
             self._refresh_schema()
             print("SQL executed")
-
         elif ext == ".pdf":
             self._smart_upload_pdf(path, name, extract_entities)
-
         elif ext == ".docx":
             self._smart_upload_docx(path, name, extract_entities)
-
         elif ext == ".txt":
             self._smart_upload_txt(path, name, extract_entities)
-
         else:
-            raise ValueError(f"Unsupported…
+            raise ValueError(f"Unsupported: {ext}")
 
         return self
 
     def _smart_upload_pdf(self, path: Path, base_name: str, extract_entities: Optional[List[str]] = None):
-        """Parse PDF…
+        """Parse PDF."""
         if not HAS_PYPDF2:
             raise ImportError("Run: pip install PyPDF2")
 
-        print("Extracting…
+        print("Extracting from PDF...")
 
         with open(path, 'rb') as file:
             pdf_reader = PyPDF2.PdfReader(file)
@@ -283,23 +256,23 @@
                 print(f"  Page {page_num}/{len(pdf_reader.pages)}")
 
         if self.client:
-            print("AI:…
+            print("AI: Extracting entities...")
             tables = self._create_tables_with_ai(text, base_name, extract_entities)
 
             if tables and len(tables) > 0:
-                print(f"\nCreated {len(tables)}…
-                for…
-…
-                    cols = len(self.schema_info.get(…
-                    print(f"  {…
+                print(f"\nCreated {len(tables)} tables:")
+                for tbl in tables:
+                    cnt = pd.read_sql_query(f"SELECT COUNT(*) FROM {tbl}", self.conn).iloc[0, 0]
+                    cols = len(self.schema_info.get(tbl, {}))
+                    print(f"  {tbl}: {cnt} rows, {cols} columns")
                 return
 
-        print("…
+        print("Creating simple table")
         df = self._parse_text_simple(text)
         self._store_dataframe(df, base_name)
 
     def _smart_upload_docx(self, path: Path, base_name: str, extract_entities: Optional[List[str]] = None):
-        """Parse DOCX…
+        """Parse DOCX."""
         if not HAS_DOCX:
             raise ImportError("Run: pip install python-docx")
 
@@ -320,137 +293,84 @@
         text = "\n".join([para.text for para in doc.paragraphs])
 
         if self.client:
-            print("AI:…
+            print("AI: Extracting...")
             tables = self._create_tables_with_ai(text, base_name, extract_entities)
-
             if tables and len(tables) > 0:
                 print(f"\nCreated {len(tables)} tables:")
-                for…
-…
-…
-                    print(f"  {tbl_name}: {count} rows, {cols} columns")
+                for tbl in tables:
+                    cnt = pd.read_sql_query(f"SELECT COUNT(*) FROM {tbl}", self.conn).iloc[0, 0]
+                    print(f"  {tbl}: {cnt} rows")
                 return
 
         df = self._parse_text_simple(text)
         self._store_dataframe(df, base_name)
 
     def _smart_upload_txt(self, path: Path, base_name: str, extract_entities: Optional[List[str]] = None):
-        """Parse TXT…
+        """Parse TXT."""
         print("Reading TXT...")
 
         with open(path, 'r', encoding='utf-8') as file:
             text = file.read()
 
         if self.client:
-            print("AI:…
+            print("AI: Extracting...")
             tables = self._create_tables_with_ai(text, base_name, extract_entities)
-
             if tables and len(tables) > 0:
                 print(f"\nCreated {len(tables)} tables:")
-                for…
-…
-…
-                    print(f"  {tbl_name}: {count} rows, {cols} columns")
+                for tbl in tables:
+                    cnt = pd.read_sql_query(f"SELECT COUNT(*) FROM {tbl}", self.conn).iloc[0, 0]
+                    print(f"  {tbl}: {cnt} rows")
                 return
 
         df = self._parse_text_simple(text)
         self._store_dataframe(df, base_name)
 
     def _create_tables_with_ai(self, text: str, base_name: str, custom_entities: Optional[List[str]] = None) -> List[str]:
-        """
-        AI extracts ALL entities with PROPER primary and foreign keys.
-
-        CRITICAL: Each entity gets UNIQUE IDs, foreign keys properly link tables.
-        """
+        """AI extraction with proper keys."""
         if not self.client:
             return []
 
         try:
-…
-…
-…
-…
-…
-…
-…
-- …
-- …
-- …
-- …
-…
-…
-- work_experience: Work history (id, person_id, company, title, start_date, end_date)
-- events: Events/meetings (id, host_id, description, location, date, attendee_ids)
-- organizations: Companies/departments (id, name, address, city, industry)
-- products: Products/services (id, name, description, price, category)
-- ANY other structured entities you identify
+            entity_list = """Extract ALL entities you find:
+- people: id, name, email, phone, address, city, state, zip
+- skills: id, person_id, skill_name, proficiency, years
+- technologies: id, person_id, technology, category, proficiency
+- projects: id, person_id, project_name, description, role
+- certifications: id, person_id, cert_name, issuer, date
+- education: id, person_id, degree, institution, year
+- work_experience: id, person_id, company, title, start_date, end_date
+- events: id, host_id, description, location, date
+- organizations: id, name, address, city
+- ANY other structured data
+
+CRITICAL: Use UNIQUE sequential IDs (1,2,3...) for each table. Foreign keys MUST reference valid IDs."""
 
-…
+            if custom_entities:
+                entity_list = f"Extract these entities: {', '.join(custom_entities)}"
 
-            extraction_prompt = f"""
+            extraction_prompt = f"""Extract structured data from this text.
 
 Text:
-{text[:…
-
-{entity_instruction}
-
-CRITICAL REQUIREMENTS FOR PROPER DATABASE DESIGN:
-
-1. PRIMARY KEYS:
-   - Each table MUST have unique sequential IDs starting from 1
-   - Person 1 gets id=1, Person 2 gets id=2, etc.
-   - NO DUPLICATE IDs within same table
-   - IDs must be integers
+{text[:5000]}
 
-…
-   - Use foreign keys to link related tables
-   - Example: skills table has person_id that references people.id
-   - Example: projects table has person_id that references people.id
-   - Foreign keys MUST match existing primary keys
+{entity_list}
 
-…
-   - Each entity type gets its own table
-   - Use clear table names (people, skills, technologies, not table1, table2)
-   - Include ALL relevant attributes for each entity
-
-Return JSON with this EXACT structure:
+Return JSON:
 {{
-  "people": [
-    …
-    {{"id": 2, "name": "Jane Smith", "email": "jane@email.com", "phone": "+1-555-0101", "city": "New York", "state": "NY"}},
-    ...
-  ],
-  "skills": [
-    {{"id": 1, "person_id": 1, "skill_name": "Python", "proficiency": "Expert", "years": 5}},
-    {{"id": 2, "person_id": 1, "skill_name": "SQL", "proficiency": "Advanced", "years": 3}},
-    {{"id": 3, "person_id": 2, "skill_name": "Java", "proficiency": "Expert", "years": 7}},
-    ...
-  ],
-  "technologies": [
-    {{"id": 1, "person_id": 1, "technology": "React", "category": "Frontend"}},
-    {{"id": 2, "person_id": 1, "technology": "PostgreSQL", "category": "Database"}},
-    {{"id": 3, "person_id": 2, "technology": "Spring Boot", "category": "Backend"}},
-    ...
-  ],
-  "projects": [
-    {{"id": 1, "person_id": 1, "project_name": "E-commerce Platform", "role": "Lead Developer"}},
-    {{"id": 2, "person_id": 2, "project_name": "Analytics Dashboard", "role": "Backend Engineer"}},
-    ...
-  ]
+  "people": [{{"id": 1, "name": "John", ...}}, {{"id": 2, "name": "Jane", ...}}],
+  "skills": [{{"id": 1, "person_id": 1, "skill_name": "Python", ...}}, {{"id": 2, "person_id": 2, ...}}]
 }}
 
-IMPORTANT:
-- …
-- …
-- …
-- …
-- Return ONLY valid JSON, no explanations
-- Be COMPREHENSIVE - extract skills, technologies, projects, certifications, education, work history, etc."""
+Requirements:
+- UNIQUE IDs: id=1,2,3,... (no duplicates)
+- Valid foreign keys: person_id must match people.id
+- Extract EVERYTHING
+- Return ONLY valid JSON"""
 
             response = self.client.chat.completions.create(
                 model="gpt-4o-mini",
                 messages=[
-                    {"role": "system", "content": "…
+                    {"role": "system", "content": "Extract entities with unique IDs and proper foreign keys. Return only JSON."},
                     {"role": "user", "content": extraction_prompt}
                 ],
                 temperature=0,
@@ -471,20 +391,45 @@ IMPORTANT:
             try:
                 df = pd.DataFrame(records)
                 if not df.empty:
-…
+                    # FIXED: Store with better error handling
+                    self._store_dataframe_safe(df, table_name)
                     created_tables.append(table_name)
                     print(f"  {entity_type}: {len(df)} records")
             except Exception as e:
-                print(f"…
+                print(f"  Error {entity_type}: {e}")
 
             return created_tables
 
         except Exception as e:
-            print(f"AI…
+            print(f"AI error: {e}")
             return []
 
+    def _store_dataframe_safe(self, df: pd.DataFrame, name: str):
+        """FIXED: Store with proper error handling for Colab."""
+        try:
+            df.columns = [str(c).strip().replace(" ", "_").replace("-", "_") for c in df.columns]
+
+            # FIXED: Use method='multi' for better performance and if_exists='replace'
+            df.to_sql(name, self.conn, if_exists='replace', index=False, method='multi', chunksize=500)
+
+            self.conn.commit()  # FIXED: Explicit commit
+            self.current_table = name
+            self._refresh_schema()
+
+        except Exception as e:
+            # FIXED: Fallback to single-row insert if bulk fails
+            print(f"  Bulk insert failed, using row-by-row (slower but safer)")
+            try:
+                df.to_sql(name, self.conn, if_exists='replace', index=False)
+                self.conn.commit()
+                self.current_table = name
+                self._refresh_schema()
+            except Exception as e2:
+                print(f"  Storage error: {e2}")
+                raise
+
     def _parse_text_simple(self, text: str) -> pd.DataFrame:
-        """…
+        """Simple parsing."""
         lines = [line.strip() for line in text.split('\n') if line.strip()]
 
         if not lines:
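A note on the storage strategy above: pandas' to_sql with method='multi' packs many rows into each INSERT, and chunksize=500 keeps the bound-parameter count of every statement (rows × columns) under SQLite's per-statement limit; if the bulk write fails for any reason, the code retries with the default one-row-at-a-time insert before giving up. A small self-contained sketch of the same idea, with illustrative data and paths:

    import sqlite3
    import pandas as pd

    conn = sqlite3.connect("demo.db")  # illustrative path
    df = pd.DataFrame({"id": [1, 2], "name": ["Jane", "John"]})
    try:
        # One multi-row INSERT per 500-row chunk: far fewer round trips,
        # but very wide frames can exceed SQLite's bound-parameter limit.
        df.to_sql("people", conn, if_exists="replace", index=False,
                  method="multi", chunksize=500)
    except Exception:
        # Default method: plain executemany, slower but rarely rejected.
        df.to_sql("people", conn, if_exists="replace", index=False)
    conn.commit()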
@@ -507,44 +452,34 @@ IMPORTANT:
 
     def _store_dataframe(self, df: pd.DataFrame, name: str, silent: bool = False):
         """Store DataFrame."""
-
-        df.to_sql(name, self.conn, if_exists='replace', index=False)
-        self.current_table = name
-        self._refresh_schema()
+        self._store_dataframe_safe(df, name)
 
         if not silent:
             print(f"Uploaded: {name}")
             print(f"  {len(df)} rows, {len(df.columns)} columns")
 
     def ask(self, question: str, viz: Union[bool, str] = False, table: Optional[str] = None) -> 'QueryResult':
-        """…
+        """Natural language query."""
         if not self.client:
-            print("No API key")
             return QueryResult(False, "", pd.DataFrame(), None, "No API key")
 
         print(f"\nQuestion: {question}")
 
-        if self.check_relevance:
-…
-…
-…
-…
-                return QueryResult(False, "", pd.DataFrame(), None, "Irrelevant")
+        if self.check_relevance and not self._is_relevant_query(question):
+            print("Warning: Query may be irrelevant")
+            choice = input("Continue? (yes/no): ").strip().lower()
+            if choice not in ['yes', 'y']:
+                return QueryResult(False, "", pd.DataFrame(), None, "Irrelevant")
 
-        tbl = table or self.current_table
+        tbl = table or self.current_table or (self._get_table_names()[0] if self._get_table_names() else None)
         if not tbl:
-            all_tables = self._get_table_names()
-            if all_tables:
-                tbl = all_tables[0]
-            else:
-                print("No tables found")
-                return QueryResult(False, "", pd.DataFrame(), None, "No table")
+            return QueryResult(False, "", pd.DataFrame(), None, "No table")
 
         if self.use_embeddings and self.embedding_model:
-…
-            if…
-                print("…
-                return…
+            cached = self._check_embedding_cache(question, tbl)
+            if cached:
+                print("  Cached result")
+                return cached
 
         if self.fuzzy_match:
             question = self._apply_fuzzy_matching(question, tbl)
@@ -567,7 +502,7 @@ IMPORTANT:
         fig = None
         if viz:
             viz_type = viz if isinstance(viz, str) else "auto"
-            fig = self._visualize(df, question, viz_type…
+            fig = self._visualize(df, question, viz_type)
 
         result = QueryResult(True, sql_query, df, fig)
 
@@ -584,142 +519,116 @@ IMPORTANT:
         if not self.client:
             return True
 
-        tables = self._get_table_names()
-…
-        for tbl in tables…
-            cols…
-            columns.extend(cols[:5])
+        tables = self._get_table_names()[:3]
+        cols = []
+        for tbl in tables:
+            cols.extend(list(self.schema_info.get(tbl, {}).keys())[:5])
 
-…
+        context = f"Tables: {', '.join(tables)}. Columns: {', '.join(cols[:15])}"
 
         try:
-            response = self.client.chat.completions.create(
+            resp = self.client.chat.completions.create(
                 model="gpt-4o-mini",
                 messages=[
-                    {"role": "system", "content": "…
-                    {"role": "user", "content": f"…
+                    {"role": "system", "content": "Return only 'yes' or 'no'."},
+                    {"role": "user", "content": f"Relevant to {context}?\n\nQ: {question}\n\nyes/no:"}
                 ],
                 temperature=0,
                 max_tokens=5
             )
-
-            return 'yes' in response.choices[0].message.content.strip().lower()
+            return 'yes' in resp.choices[0].message.content.lower()
         except:
             return True
 
     def _apply_fuzzy_matching(self, question: str, table: str) -> str:
-        """Fuzzy…
+        """Fuzzy matching."""
         if not self.schema_info.get(table):
             return question
 
         try:
-            string_cols = [col for col, dtype in self.schema_info[table].items()
-                           if 'TEXT' in dtype or 'VARCHAR' in dtype]
-
+            string_cols = [col for col, dtype in self.schema_info[table].items() if 'TEXT' in dtype]
             if not string_cols:
                 return question
 
             for col in string_cols[:2]:
                 df = pd.read_sql_query(f"SELECT DISTINCT {col} FROM {table} LIMIT 100", self.conn)
-…
+                values = [str(v) for v in df[col].dropna().tolist()]
 
                 words = question.split()
                 for i, word in enumerate(words):
-                    matches = get_close_matches(word,…
+                    matches = get_close_matches(word, values, n=1, cutoff=0.6)
                     if matches and word != matches[0]:
                         words[i] = matches[0]
                         print(f"  Fuzzy: '{word}' -> '{matches[0]}'")
-
                 question = " ".join(words)
-
             return question
         except:
             return question
 
     def _check_embedding_cache(self, question: str, table: str) -> Optional['QueryResult']:
-        """Check…
+        """Check cache."""
         if not self.query_embeddings:
             return None
 
-        q_embedding = self.embedding_model.encode([question])[0]
+        q_emb = self.embedding_model.encode([question])[0]
 
         best_match = None
-…
+        best_sim = 0.85
 
-        for cached_q, cached_data in self.query_embeddings.items():
-            if cached_data['table'] != table:
+        for cached_q, data in self.query_embeddings.items():
+            if data['table'] != table:
                 continue
 
-…
-                np.linalg.norm(q_embedding) * np.linalg.norm(cached_data['embedding'])
-            )
+            sim = np.dot(q_emb, data['embedding']) / (np.linalg.norm(q_emb) * np.linalg.norm(data['embedding']))
 
-            if…
-…
+            if sim > best_sim:
+                best_sim = sim
                 best_match = cached_q
 
         if best_match:
-            print(f"  Similar…
+            print(f"  Similar ({best_sim:.0%}): '{best_match}'")
             return self.query_embeddings[best_match]['result']
 
         return None
 
     def _store_in_embedding_cache(self, question: str, table: str, result: 'QueryResult'):
-        """Store…
-        q_embedding = self.embedding_model.encode([question])[0]
-        self.query_embeddings[question] = {
-            'table': table,
-            'embedding': q_embedding,
-            'result': result
-        }
+        """Store cache."""
+        q_emb = self.embedding_model.encode([question])[0]
+        self.query_embeddings[question] = {'table': table, 'embedding': q_emb, 'result': result}
 
     def _visualize(self, df: pd.DataFrame, title: str, viz_type: str = "auto"):
-        """…
+        """Visualize."""
         if not HAS_PLOTLY and not HAS_MATPLOTLIB:
-            print("Install plotly or matplotlib")
             return None
 
         print(f"Creating {viz_type} chart...")
-
-        if HAS_PLOTLY:
-            return self._plotly_viz(df, title, viz_type)
-        else:
-            return self._matplotlib_viz(df, title, viz_type)
+        return self._plotly_viz(df, title, viz_type) if HAS_PLOTLY else self._matplotlib_viz(df, title, viz_type)
 
     def _plotly_viz(self, df: pd.DataFrame, title: str, viz_type: str):
-        """Plotly…
+        """Plotly viz."""
         try:
-            numeric = df.select_dtypes(include=[np.number]).columns.tolist()
-            categorical = df.select_dtypes(include=['object']).columns.tolist()
+            num = df.select_dtypes(include=[np.number]).columns.tolist()
+            cat = df.select_dtypes(include=['object']).columns.tolist()
 
-            if viz_type == "table"…
-                fig = go.Figure(data=[go.Table(
-…
-…
-…
-…
-…
-…
-…
-…
-…
-…
-                fig =…
-            elif viz_type == "heatmap" and len(numeric) >= 2:
-                corr = df[numeric].corr()
-                fig = go.Figure(data=go.Heatmap(
-                    z=corr.values, x=corr.columns, y=corr.columns, colorscale='Viridis'
-                ))
+            if viz_type == "table":
+                fig = go.Figure(data=[go.Table(header=dict(values=list(df.columns)), cells=dict(values=[df[c] for c in df.columns]))])
+            elif viz_type == "pie" and cat and num:
+                fig = px.pie(df, names=cat[0], values=num[0], title=title)
+            elif viz_type == "bar" and cat and num:
+                fig = px.bar(df, x=cat[0], y=num[0], title=title)
+            elif viz_type == "line" and num:
+                fig = px.line(df, y=num[0], title=title)
+            elif viz_type == "scatter" and len(num) >= 2:
+                fig = px.scatter(df, x=num[0], y=num[1], title=title)
+            elif viz_type == "heatmap" and len(num) >= 2:
+                corr = df[num].corr()
+                fig = go.Figure(data=go.Heatmap(z=corr.values, x=corr.columns, y=corr.columns))
                 fig.update_layout(title=title)
-
-            if…
-                fig = px.pie(df, names=…
-            elif len(numeric) >= 2:
-                fig = px.line(df, y=numeric[0], title=title)
+            else:
+                if cat and num:
+                    fig = px.pie(df, names=cat[0], values=num[0], title=title) if len(df) <= 10 else px.bar(df, x=cat[0], y=num[0], title=title)
                 else:
                     fig = px.bar(df, y=df.columns[0], title=title)
-            else:
-                fig = px.bar(df, x=categorical[0] if categorical else df.index, y=numeric[0] if numeric else df.columns[0], title=title)
 
             fig.show()
             print("Chart displayed")
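The rewritten cache lookup is plain cosine similarity, sim(a, b) = a·b / (‖a‖ ‖b‖), over sentence-transformers embeddings, reusing a cached QueryResult when similarity clears the 0.85 threshold. A self-contained sketch of the lookup, assuming the cache layout that _store_in_embedding_cache writes ({'table': ..., 'embedding': ..., 'result': ...}):

    import numpy as np
    from sentence_transformers import SentenceTransformer

    def best_cached(question: str, model: SentenceTransformer,
                    cache: dict, threshold: float = 0.85):
        q = model.encode([question])[0]
        best, best_sim = None, threshold
        for cached_q, entry in cache.items():
            e = entry["embedding"]
            sim = float(np.dot(q, e) / (np.linalg.norm(q) * np.linalg.norm(e)))
            if sim > best_sim:  # strictly better than threshold/best so far
                best, best_sim = entry["result"], sim
        return best             # None when nothing is close enough

    # model = SentenceTransformer('all-MiniLM-L6-v2')  # same model __init__ loads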
@@ -729,54 +638,47 @@ IMPORTANT:
             return None
 
     def _matplotlib_viz(self, df: pd.DataFrame, title: str, viz_type: str):
-        """Matplotlib…
+        """Matplotlib viz."""
         try:
             plt.figure(figsize=(10, 6))
-            numeric = df.select_dtypes(include=[np.number]).columns
+            num = df.select_dtypes(include=[np.number]).columns
 
-            if viz_type == "pie"…
+            if viz_type == "pie":
                 df[df.columns[0]].value_counts().plot(kind='pie')
-            elif viz_type == "line" and len(numeric) > 0:
-                df[numeric[0]].plot(kind='line')
+            elif viz_type == "line" and len(num) > 0:
+                df[num[0]].plot(kind='line')
             else:
-                if len(numeric) > 0:
-                    df[numeric[0]].plot(kind='bar')
-                else:
-                    df.iloc[:, 0].value_counts().plot(kind='bar')
+                (df[num[0]] if len(num) > 0 else df.iloc[:, 0].value_counts()).plot(kind='bar')
 
             plt.title(title)
             plt.tight_layout()
             plt.show()
-            print("Chart displayed")
             return plt.gcf()
         except Exception as e:
             print(f"Viz error: {e}")
             return None
 
     def tables(self) -> Dict[str, dict]:
-        """List…
+        """List tables."""
         print("\n" + "="*70)
         print("TABLES IN DATABASE")
         print("="*70)
 
         all_tables = self._get_table_names()
-
         if not all_tables:
-            print("No tables…
+            print("No tables")
             return {}
 
         result = {}
         for i, tbl in enumerate(all_tables, 1):
-…
-            cols = self.schema_info.get(tbl, {})
-            col_list = list(cols.keys())
+            cnt = pd.read_sql_query(f"SELECT COUNT(*) FROM {tbl}", self.conn).iloc[0, 0]
+            cols = list(self.schema_info.get(tbl, {}).keys())
 
-…
-            print(f"{…
-            print(f"  {…
-            print(f"  Columns: {', '.join(col_list[:8])}")
+            print(f"  {i}. {tbl}")
+            print(f"  {cnt} rows, {len(cols)} columns")
+            print(f"  {', '.join(cols[:8])}")
 
-            result[tbl] = {'rows':…
+            result[tbl] = {'rows': cnt, 'columns': cols}
 
         print("="*70)
         return result
@@ -795,184 +697,122 @@ IMPORTANT:
         result = {}
         for tbl in tables_to_show:
             if tbl in self.schema_info:
-                count = pd.read_sql_query(f"SELECT COUNT(*) FROM {tbl}", self.conn).iloc[0, 0]
-                print(f"\nTable: {tbl}")
-                print(f"Records: {count}")
-                print("Columns:")
-
+                cnt = pd.read_sql_query(f"SELECT COUNT(*) FROM {tbl}", self.conn).iloc[0, 0]
+                print(f"\nTable: {tbl} ({cnt} records)")
                 for col, dtype in self.schema_info[tbl].items():
-                    print(f"  - {col:<30}…
-
-                result[tbl] = {
-                    'records': count,
-                    'columns': self.schema_info[tbl]
-                }
+                    print(f"  - {col:<30} {dtype}")
+                result[tbl] = {'records': cnt, 'columns': self.schema_info[tbl]}
 
         print("="*70)
         return result
 
     def peek(self, table: Optional[str] = None, n: int = 5) -> pd.DataFrame:
-        """Preview…
+        """Preview."""
        tbl = table or self.current_table
         if not tbl:
-            print("No table specified")
             return pd.DataFrame()
 
         df = pd.read_sql_query(f"SELECT * FROM {tbl} LIMIT {n}", self.conn)
-        print(f"\nSample from '{tbl}'…
+        print(f"\nSample from '{tbl}':")
         print(df.to_string(index=False))
         return df
 
     def info(self):
-        """…
+        """Overview."""
         return self.tables()
 
     def sql(self, query: str, viz: Union[bool, str] = False) -> 'QueryResult':
         """Execute SQL."""
         print("\nExecuting SQL...")
-
         try:
             df = pd.read_sql_query(query, self.conn)
             print(f"Success! {len(df)} rows")
 
-            fig = None
-            if viz:
-                viz_type = viz if isinstance(viz, str) else "auto"
-                fig = self._visualize(df, "SQL Result", viz_type=viz_type)
-
+            fig = self._visualize(df, "SQL Result", viz if isinstance(viz, str) else "auto") if viz else None
             return QueryResult(True, query, df, fig)
         except Exception as e:
             print(f"Error: {e}")
             return QueryResult(False, query, pd.DataFrame(), None, str(e))
 
     def interactive(self, question: str) -> 'QueryResult':
-        """Interactive…
+        """Interactive."""
         print(f"\nQuestion: {question}")
         choice = input("Visualize? (yes/no/pie/bar/line/scatter): ").strip().lower()
-
         viz = choice if choice in ['pie', 'bar', 'line', 'scatter', 'table', 'heatmap'] else (True if choice in ['yes', 'y'] else False)
-
         return self.ask(question, viz=viz)
 
     def export_db(self, path: str, format: str = "sqlite"):
-        """Export…
-…
-…
-…
-…
-…
-…
-…
-…
-…
-            for table in self._get_table_names():
-                df = pd.read_sql_query(f"SELECT * FROM {table}", self.conn)
-                data[table] = df.to_dict(orient='records')
-            with open(path, 'w', encoding='utf-8') as f:
-                json.dump(data, f, indent=2, default=str)
-        elif format == "excel":
-            with pd.ExcelWriter(path, engine='openpyxl') as writer:
-                for table in self._get_table_names():
-                    df = pd.read_sql_query(f"SELECT * FROM {table}", self.conn)
-                    df.to_excel(writer, sheet_name=table[:31], index=False)
+        """Export."""
+        formats = {
+            "sqlite": lambda: shutil.copy2(self.db_path, path),
+            "sql": lambda: open(path, 'w', encoding='utf-8').writelines(f'{line}\n' for line in self.conn.iterdump()),
+            "json": lambda: json.dump({t: pd.read_sql_query(f"SELECT * FROM {t}", self.conn).to_dict('records') for t in self._get_table_names()}, open(path, 'w', encoding='utf-8'), indent=2, default=str),
+            "excel": lambda: pd.ExcelWriter(path, engine='openpyxl').__enter__() and [pd.read_sql_query(f"SELECT * FROM {t}", self.conn).to_excel(path, sheet_name=t[:31], index=False) for t in self._get_table_names()]
+        }
+
+        if format in formats:
+            formats[format]()
+            print(f"Saved: {path}")
         else:
             raise ValueError(f"Unsupported: {format}")
-
-        print(f"Saved to {path}")
         return self
 
-    def save_to_mysql(self, host: str, user: str, password: str, database: str,
-                      port: int = 3306, tables: Optional[List[str]] = None):
+    def save_to_mysql(self, host: str, user: str, password: str, database: str, port: int = 3306, tables: Optional[List[str]] = None):
         """Export to MySQL."""
         try:
             from sqlalchemy import create_engine
+            engine = create_engine(f"mysql+mysqlconnector://{user}:{password}@{host}:{port}/{database}")
+
+            print(f"Exporting to MySQL...")
+            for t in (tables or self._get_table_names()):
+                df = pd.read_sql_query(f"SELECT * FROM {t}", self.conn)
+                df.to_sql(t, engine, if_exists='replace', index=False)
+                print(f"  {t}: {len(df)} rows")
+            print("Complete!")
+            return self
         except ImportError:
             raise ImportError("Run: pip install QuerySUTRA[mysql]")
-
-        print(f"\nConnecting to MySQL: {host}:{port}...")
-
-        engine = create_engine(f"mysql+mysqlconnector://{user}:{password}@{host}:{port}/{database}")
-
-        tables_to_export = tables or self._get_table_names()
-
-        print(f"Exporting {len(tables_to_export)} tables...")
-
-        for table in tables_to_export:
-            df = pd.read_sql_query(f"SELECT * FROM {table}", self.conn)
-            df.to_sql(table, engine, if_exists='replace', index=False)
-            print(f"  {table}: {len(df)} rows")
-
-        print("Complete!")
-        return self
 
-    def save_to_postgres(self, host: str, user: str, password: str, database: str,
-                         port: int = 5432, tables: Optional[List[str]] = None):
+    def save_to_postgres(self, host: str, user: str, password: str, database: str, port: int = 5432, tables: Optional[List[str]] = None):
         """Export to PostgreSQL."""
         try:
             from sqlalchemy import create_engine
+            engine = create_engine(f"postgresql://{user}:{password}@{host}:{port}/{database}")
+
+            print(f"Exporting to PostgreSQL...")
+            for t in (tables or self._get_table_names()):
+                df = pd.read_sql_query(f"SELECT * FROM {t}", self.conn)
+                df.to_sql(t, engine, if_exists='replace', index=False)
+                print(f"  {t}: {len(df)} rows")
+            print("Complete!")
+            return self
         except ImportError:
             raise ImportError("Run: pip install QuerySUTRA[postgres]")
-
-        print(f"\nConnecting to PostgreSQL: {host}:{port}...")
-
-        engine = create_engine(f"postgresql://{user}:{password}@{host}:{port}/{database}")
-
-        tables_to_export = tables or self._get_table_names()
-
-        print(f"Exporting {len(tables_to_export)} tables...")
-
-        for table in tables_to_export:
-            df = pd.read_sql_query(f"SELECT * FROM {table}", self.conn)
-            df.to_sql(table, engine, if_exists='replace', index=False)
-            print(f"  {table}: {len(df)} rows")
-
-        print("Complete!")
-        return self
 
-    def backup(self,…
-        """…
-        if…
-…
-…
-        else:
-            backup_dir = Path(".")
-
-        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
-
-        print("\nCreating backup...")
-
-        db_backup = backup_dir / f"sutra_{timestamp}.db"
-        self.export_db(str(db_backup), format="sqlite")
-
-        json_backup = backup_dir / f"sutra_{timestamp}.json"
-        self.export_db(str(json_backup), format="json")
-
-        print(f"\nBackup complete!")
-        print(f"  Database: {db_backup}")
-        print(f"  Data: {json_backup}")
+    def backup(self, path: str = None):
+        """Backup."""
+        dir = Path(path) if path else Path(".")
+        dir.mkdir(parents=True, exist_ok=True)
+        ts = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
 
+        print("Creating backup...")
+        self.export_db(str(dir / f"sutra_{ts}.db"), "sqlite")
+        self.export_db(str(dir / f"sutra_{ts}.json"), "json")
+        print("Backup complete!")
         return self
 
     def export(self, data: pd.DataFrame, path: str, format: str = "csv"):
         """Export results."""
-…
-…
-…
-…
-        elif format == "json":
-            data.to_json(path, orient="records", indent=2)
-        else:
-            raise ValueError(f"Unknown: {format}")
-
-        print(f"Exported to {path}")
+        {"csv": lambda: data.to_csv(path, index=False),
+         "excel": lambda: data.to_excel(path, index=False),
+         "json": lambda: data.to_json(path, orient="records", indent=2)}[format]()
+        print(f"Exported: {path}")
         return self
 
     def close(self):
-        """Close…
+        """Close."""
         if self.conn:
             self.conn.close()
-        print("Closed")
 
     def _get_table_names(self) -> List[str]:
         """Get tables."""
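One observation on the new export_db dispatch table: the "sqlite" and "sql" lambdas copy the database file and dump conn.iterdump() lines, but the "excel" lambda as published passes path, not the writer, to to_excel, so each table rewrites the workbook and only the last one survives. The conventional multi-sheet pattern looks like this (a sketch; file and table names are illustrative):

    import sqlite3
    import pandas as pd

    conn = sqlite3.connect("demo.db")
    with pd.ExcelWriter("export.xlsx", engine="openpyxl") as writer:
        for table in ["people", "skills"]:
            df = pd.read_sql_query(f"SELECT * FROM {table}", conn)
            # Excel caps sheet names at 31 characters, hence the slice.
            df.to_excel(writer, sheet_name=table[:31], index=False)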
@@ -981,45 +821,27 @@ IMPORTANT:
 
     def _refresh_schema(self):
         """Refresh schema."""
-        tables = self._get_table_names()
-
         self.schema_info = {}
-        for tbl in tables…
+        for tbl in self._get_table_names():
             self.cursor.execute(f"PRAGMA table_info({tbl})")
             self.schema_info[tbl] = {r[1]: r[2] for r in self.cursor.fetchall()}
 
     def _generate_sql(self, question: str, table: str) -> str:
         """Generate SQL."""
         schema = self.schema_info.get(table, {})
-        sample_df = pd.read_sql_query(f"SELECT * FROM {table} LIMIT 3", self.conn)
-        sample = sample_df.to_string(index=False)
-
+        sample = pd.read_sql_query(f"SELECT * FROM {table} LIMIT 3", self.conn).to_string(index=False)
         schema_str = ", ".join([f"{col} ({dtype})" for col, dtype in schema.items()])
 
-…
-…
-Database: SQLite
-Table: {table}
-Columns: {schema_str}
-
-Sample:
-{sample}
-
-Question: {question}
-
-Return ONLY SQL."""
-
-        response = self.client.chat.completions.create(
+        resp = self.client.chat.completions.create(
             model="gpt-4o-mini",
             messages=[
-                {"role": "system", "content": "SQL expert. Return only SQL…
-                {"role": "user", "content":…
+                {"role": "system", "content": "SQL expert. Return only SQL."},
+                {"role": "user", "content": f"Convert to SQL.\nTable: {table}\nColumns: {schema_str}\nSample:\n{sample}\n\nQ: {question}\n\nSQL:"}
             ],
             temperature=0
         )
 
-        sql = response.choices[0].message.content.strip()
-        return sql.replace("```sql", "").replace("```", "").strip()
+        return resp.choices[0].message.content.strip().replace("```sql", "").replace("```", "").strip()
 
     def __enter__(self):
         return self
@@ -1028,53 +850,25 @@ Return ONLY SQL."""
         self.close()
 
     def __repr__(self):
-        features = []
-        if self.cache_queries:
-            features.append("cache")
-        if self.use_embeddings:
-            features.append("embeddings")
-        if self.check_relevance:
-            features.append("relevance")
-        if self.fuzzy_match:
-            features.append("fuzzy")
-
-        feat_str = f", {', '.join(features)}" if features else ""
-        return f"SUTRA(tables={len(self.schema_info)}{feat_str})"
+        feat = [f for f, v in [("cache", self.cache_queries), ("embeddings", self.use_embeddings), ("relevance", self.check_relevance), ("fuzzy", self.fuzzy_match)] if v]
+        return f"SUTRA(tables={len(self.schema_info)}, {', '.join(feat)})"
 
 
 class QueryResult:
-    """…
-…
+    """Result."""
     def __init__(self, success: bool, sql: str, data: pd.DataFrame, viz, error: str = None):
-        self.success = success
-        self.sql = sql
-        self.data = data
-        self.viz = viz
-        self.error = error
+        self.success, self.sql, self.data, self.viz, self.error = success, sql, data, viz, error
 
     def __repr__(self):
         return f"QueryResult(rows={len(self.data)}, cols={len(self.data.columns)})" if self.success else f"QueryResult(error='{self.error}')"
 
     def show(self):
-        print(self.data…
+        print(self.data if self.success else f"Error: {self.error}")
         return self
 
 
 def quick_start(api_key: str, data_path: str, question: str, viz: Union[bool, str] = False):
-    """…
+    """Quick start."""
     with SUTRA(api_key=api_key) as sutra:
         sutra.upload(data_path)
         return sutra.ask(question, viz=viz)
-
-
-if __name__ == "__main__":
-    print("""
-    QuerySUTRA v0.3.3 - Professional Data Analysis
-    SUTRA: Structured-Unstructured-Text-Retrieval-Architecture
-
-    Fixed: Proper primary and foreign keys with unique IDs
-    Features: Load existing DB, custom viz, fuzzy matching, embeddings
-
-    Installation: pip install QuerySUTRA
-    Usage: from sutra import SUTRA
-    """)