genshotsql 0.1.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,5 @@
1
+ Metadata-Version: 2.4
2
+ Name: genshotsql
3
+ Version: 0.1.2
4
+ Summary: A framework for interacting with SQL databases by writing prompts
5
+ Author: Pranav Verma
@@ -0,0 +1,5 @@
1
+ Metadata-Version: 2.4
2
+ Name: genshotsql
3
+ Version: 0.1.2
4
+ Summary: A framework for interacting with SQL databases by writing prompts
5
+ Author: Pranav Verma
@@ -0,0 +1,8 @@
1
+ pyproject.toml
2
+ genshotsql.egg-info/PKG-INFO
3
+ genshotsql.egg-info/SOURCES.txt
4
+ genshotsql.egg-info/dependency_links.txt
5
+ genshotsql.egg-info/entry_points.txt
6
+ genshotsql.egg-info/top_level.txt
7
+ templates/chat_with_database.py
8
+ templates/config.py
@@ -0,0 +1,2 @@
1
+ [console_scripts]
2
+ gensql = genshotsql.cli:main
@@ -0,0 +1 @@
1
+ templates
@@ -0,0 +1,12 @@
1
[build-system]
# [project] metadata (PEP 621) needs setuptools >= 61; the generated
# *.egg-info files show setuptools is the backend that built this sdist.
requires = ["setuptools>=61"]
build-backend = "setuptools.build_meta"

[project]
name = "genshotsql"
version = "0.1.2"
description = "A framework for interacting with SQL databases by writing prompts"
authors = [
    { name = "Pranav Verma" }
]
# Third-party modules imported by templates/chat_with_database.py
# (faiss, mysql.connector, numpy). Without them the tool fails at import time.
dependencies = [
    "faiss-cpu",
    "mysql-connector-python",
    "numpy",
]

[project.scripts]
# NOTE(review): SOURCES.txt lists only templates/chat_with_database.py and
# templates/config.py; no genshotsql/cli.py is packaged, so this entry point
# likely fails with ModuleNotFoundError. Confirm the intended target module.
gensql = "genshotsql.cli:main"
@@ -0,0 +1,4 @@
1
+ [egg_info]
2
+ tag_build =
3
+ tag_date = 0
4
+
@@ -0,0 +1,468 @@
1
+ import json
2
+ import re
3
+ import urllib.request
4
+ import urllib.error
5
+
6
+ import faiss
7
+ import mysql.connector
8
+ import numpy as np
9
+
10
+ # Import configuration from config.py
11
+ from config import DB_CONFIG, OLLAMA_MODEL, OLLAMA_URL, OLLAMA_EMBED_MODEL, OLLAMA_EMBED_URL, SCHEMAS
12
+
13
# In-memory list of {"title": ..., "content": ...} documents; the retrieval
# code embeds and searches these.
RAG_DOCUMENTS = []


def add_schema_to_rag(schema):
    """Render a structured schema dict as text and append it to RAG_DOCUMENTS.

    Args:
        schema: dict with ``"table_name"`` and ``"columns"`` (each column a dict
            with ``"name"`` and ``"type"``, optionally ``"primary_key"`` and
            ``"description"``), plus optional ``"description"`` and ``"usage"``.
    """
    table_name = schema["table_name"]

    lines = [f"Table: {table_name}"]
    if schema.get("description"):
        lines.append(f"Purpose: {schema['description']}")
    lines.append("Columns:")

    for col in schema["columns"]:
        entry = f"- {col['name']} {col['type']}"
        if col.get("primary_key"):
            entry += " PRIMARY KEY"
        if col.get("description"):
            entry += f". {col['description']}"
        lines.append(entry)

    # Every line is newline-terminated; the usage text is appended raw after them.
    content = "\n".join(lines) + "\n"
    if schema.get("usage"):
        content += schema["usage"]

    RAG_DOCUMENTS.append({"title": f"{table_name} table", "content": content.strip()})
44
+
45
def add_table_to_rag_documents(table_name, schema_text):
    """Append an already-formatted table description to RAG_DOCUMENTS.

    Args:
        table_name: used to build the document title ``"<table_name> table"``.
        schema_text: schema description; surrounding whitespace is stripped.

    Returns:
        Always ``True``.
    """
    document = {
        "title": f"{table_name} table",
        "content": schema_text.strip(),
    }
    RAG_DOCUMENTS.append(document)
    return True
52
+
53
def ask_ollama(prompt, timeout=30):
    """Send *prompt* to the Ollama generate endpoint and return the reply text.

    Args:
        prompt: full prompt text, sent with streaming disabled.
        timeout: seconds passed to ``urlopen`` (default 30, the previous fixed value).

    Returns:
        The ``"response"`` field of the JSON reply, with surrounding whitespace stripped.

    Raises:
        Exception: if the HTTP request fails, including a timeout while reading
            the reply; the underlying error is chained as ``__cause__``.
    """
    payload = {
        "model": OLLAMA_MODEL,
        "prompt": prompt,
        "stream": False,
    }

    request = urllib.request.Request(
        OLLAMA_URL,
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
    )

    try:
        with urllib.request.urlopen(request, timeout=timeout) as response:
            result = json.loads(response.read().decode("utf-8"))
    except OSError as e:
        # URLError and HTTPError are OSError subclasses. A timeout during read()
        # surfaces as a bare socket timeout (also OSError), not as a URLError.
        raise Exception(f"Failed to connect to Ollama: {e}") from e

    return result["response"].strip()
73
+
74
def clean_json(text):
    """Pull the JSON object out of a model reply.

    Leading ```json / ``` fences and a trailing ``` are removed. Then the span
    from the first ``{`` to the last ``}`` is returned. If no such span exists,
    the de-fenced text is returned unchanged.
    """
    cleaned = text.strip()
    fence_rules = (
        (r"^```json", re.IGNORECASE),
        (r"^```", 0),
        (r"```$", 0),
    )
    for pattern, flags in fence_rules:
        cleaned = re.sub(pattern, "", cleaned, flags=flags).strip()

    found = re.search(r"\{.*\}", cleaned, flags=re.DOTALL)
    return found.group(0) if found else cleaned
84
+
85
def _insert_column_line(schema, column_text):
    """Return *schema* with *column_text* added to its Columns section."""
    if "Columns:" not in schema:
        return f"{schema}\nColumns:\n{column_text}"

    # Insert just before whichever section follows the column list. This keeps
    # the new column inside "Columns:" even when "Use this table when" is absent,
    # and inserts only once.
    columns_start = schema.index("Columns:")
    section_starts = [
        position
        for position in (
            schema.find("Use this table when", columns_start),
            schema.find("Join rules:", columns_start),
        )
        if position != -1
    ]
    if section_starts:
        cut = min(section_starts)
        return f"{schema[:cut]}{column_text}\n{schema[cut:]}"
    return f"{schema.rstrip()}\n{column_text}"


def convert_readable_schema_to_rag_schema(readable_schema):
    """Ask the model to turn a free-form table description into a RAG schema text.

    After the model answers, every ``*_id`` identifier mentioned in
    *readable_schema* but missing from the returned schema is added back as an
    INT column. If the user wrote "<column> ... connected to <table> id", it is
    recorded as a FOREIGN KEY and a join rule is appended.

    Args:
        readable_schema: user-written description of one table.

    Returns:
        ``(table_name, schema_text)``.

    Raises:
        ValueError: if the reply is not a JSON object with string ``table_name``
            and ``schema`` fields. ``json.JSONDecodeError`` (a ValueError
            subclass) is raised for unparsable replies.
        Exception: propagated from ``ask_ollama`` when Ollama is unreachable.
    """
    prompt = f"""
Convert this user-written database schema into a RAG document.

Return only valid JSON with these keys:
- table_name
- schema

Rules:
- Do not write explanation before or after the JSON.
- Keep the schema simple and readable.
- Include every column mentioned by the user.
- The schema must start with: Table: <table_name>
- Include Purpose, Columns, Use this table when, and Join rules if relationships exist.
- If a primary key is mentioned, write PRIMARY KEY.
- If a foreign key relationship is mentioned, write FOREIGN KEY REFERENCES table(column).
- Do not invent columns that are not mentioned.

User-written schema:
{readable_schema}

Example JSON format:
{{
"table_name": "department",
"schema": "Table: department\\nPurpose: Stores department details.\\nColumns:\\n- id INT PRIMARY KEY. Unique department id.\\n- department_name VARCHAR(50). Department name.\\nUse this table when the question asks about departments."
}}
"""

    response = clean_json(ask_ollama(prompt))
    converted_schema = json.loads(response)

    # json.loads may return a list or string, for which "key in value" is not a
    # field check, so validate the shape before reading fields.
    if not isinstance(converted_schema, dict):
        raise ValueError("The model did not return table_name and schema.")
    table_name = converted_schema.get("table_name")
    schema = converted_schema.get("schema")
    if not isinstance(table_name, str) or not isinstance(schema, str):
        raise ValueError("The model did not return table_name and schema.")

    # The model sometimes drops *_id columns; restore any the user mentioned.
    # dict.fromkeys de-duplicates while keeping first-mention order.
    mentioned_id_columns = dict.fromkeys(
        re.findall(r"\b([a-zA-Z][a-zA-Z0-9_]*_id)\b", readable_schema)
    )
    for column in mentioned_id_columns:
        # Whole-identifier check: a plain substring test would treat "dept_id"
        # as present merely because "sub_dept_id" appears.
        if re.search(rf"\b{re.escape(column)}\b", schema):
            continue

        foreign_key_match = re.search(
            rf"\b{re.escape(column)}\b.*?connected to ([a-zA-Z][a-zA-Z0-9_]*) id",
            readable_schema,
            flags=re.IGNORECASE,
        )

        if foreign_key_match:
            referenced_table = foreign_key_match.group(1).lower()
            column_text = (
                f"- {column} INT FOREIGN KEY REFERENCES {referenced_table}(id). "
                f"Connected to {referenced_table}."
            )
            join_text = f"- {table_name}.{column} = {referenced_table}.id"
        else:
            column_text = f"- {column} INT. Mentioned by the user."
            join_text = None

        schema = _insert_column_line(schema, column_text)

        if join_text:
            # Join lines go at the end; this assumes "Join rules:" is the last
            # section, as the prompt's section order suggests.
            if "Join rules:" in schema:
                schema += f"\n{join_text}"
            else:
                schema += f"\nJoin rules:\n{join_text}"

    return table_name, schema
152
+
153
def add_readable_schema_to_rag_documents(readable_schema):
    """Convert a user-written table description with the model and register it.

    Returns:
        Whatever ``add_table_to_rag_documents`` returns.
    """
    converted = convert_readable_schema_to_rag_schema(readable_schema)
    return add_table_to_rag_documents(*converted)
157
+
158
def get_embedding(text, timeout=30):
    """Request an embedding for *text* from the Ollama embeddings endpoint.

    Args:
        text: text to embed.
        timeout: seconds passed to ``urlopen`` (default 30, the previous fixed value).

    Returns:
        The ``"embedding"`` field of the JSON reply (presumably a list of
        floats; callers convert it with ``np.array``).

    Raises:
        Exception: if the HTTP request fails, including a timeout while reading
            the reply; the underlying error is chained as ``__cause__``.
    """
    payload = {
        "model": OLLAMA_EMBED_MODEL,
        "prompt": text,
    }

    request = urllib.request.Request(
        OLLAMA_EMBED_URL,
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
    )

    try:
        with urllib.request.urlopen(request, timeout=timeout) as response:
            result = json.loads(response.read().decode("utf-8"))
    except OSError as e:
        # URLError/HTTPError subclass OSError; a read() timeout is a bare
        # socket timeout (also OSError) that URLError alone would miss.
        raise Exception(f"Failed to get embedding from Ollama: {e}") from e

    return result["embedding"]
177
+
178
def normalize_vectors(vectors):
    """Scale each row of the 2-D array *vectors* to unit L2 length.

    Inner products of unit vectors are cosine similarities. Rows with (near)
    zero length are divided by 1e-12 instead of 0, so they stay finite.
    """
    row_lengths = np.linalg.norm(vectors, axis=1, keepdims=True)
    safe_lengths = np.maximum(row_lengths, 1e-12)
    return vectors / safe_lengths
182
+
183
def build_faiss_index():
    """Embed every RAG document and load the vectors into a FAISS inner-product index.

    Each document is embedded as ``title + "\\n" + content``. The vectors are
    L2-normalised, so inner-product search ranks by cosine similarity.

    Returns:
        A ``faiss.IndexFlatIP`` whose row *i* corresponds to ``RAG_DOCUMENTS[i]``,
        or ``None`` when there are no documents.
    """
    if not RAG_DOCUMENTS:
        return None

    vectors = np.array(
        [get_embedding("\n".join((doc["title"], doc["content"]))) for doc in RAG_DOCUMENTS],
        dtype="float32",
    )
    vectors = normalize_vectors(vectors)

    dimension = vectors.shape[1]
    index = faiss.IndexFlatIP(dimension)
    index.add(vectors)
    return index
201
+
202
# Lazily built FAISS index over RAG_DOCUMENTS; rebuilt by retrieve_context()
# whenever the number of documents no longer matches the index size.
FAISS_INDEX = None


def retrieve_context(question, top_k=3):
    """Return the schema texts most similar to *question*, joined by blank lines.

    Args:
        question: natural-language question to embed and search with.
        top_k: maximum number of documents to return (capped at the document count).

    Returns:
        The ``content`` of up to *top_k* documents separated by ``"\\n\\n"``, or
        ``"No schema information available."`` when nothing is indexed.
    """
    global FAISS_INDEX

    if not RAG_DOCUMENTS:
        return "No schema information available."

    top_k = min(top_k, len(RAG_DOCUMENTS))

    # Documents can be appended after the first search (e.g. through
    # add_readable_schema_to_rag_documents). Without a rebuild, those tables
    # would never be retrievable.
    if FAISS_INDEX is None or FAISS_INDEX.ntotal != len(RAG_DOCUMENTS):
        FAISS_INDEX = build_faiss_index()

    if FAISS_INDEX is None:
        return "No schema information available."

    query_embedding = np.array([get_embedding(question)], dtype="float32")
    query_embedding = normalize_vectors(query_embedding)

    _, indexes = FAISS_INDEX.search(query_embedding, top_k)
    # FAISS pads missing results with label -1; skip anything out of range.
    selected_documents = [
        RAG_DOCUMENTS[position]
        for position in indexes[0]
        if 0 <= position < len(RAG_DOCUMENTS)
    ]

    return "\n\n".join(document["content"] for document in selected_documents)
226
+
227
def clean_sql(text):
    """Extract the SQL statement from a model reply.

    If the reply contains a complete fenced block (an opening ``` line with any
    language tag, and a closing ```), only that block's body is used. This drops
    preambles such as "Here is the query:" that would otherwise fail the SELECT
    check. Otherwise a leading ```sql / ``` fence and a trailing ``` are
    stripped. In both cases trailing semicolons and whitespace are removed.
    """
    text = text.strip()

    fenced = re.search(r"```[^\n]*\n(.*?)```", text, flags=re.DOTALL)
    if fenced:
        text = fenced.group(1).strip()
    else:
        text = re.sub(r"^```sql", "", text, flags=re.IGNORECASE).strip()
        text = re.sub(r"^```", "", text).strip()
        text = re.sub(r"```$", "", text).strip()

    return text.rstrip(";").rstrip()
234
+
235
def question_to_sql(question):
    """Ask the model to translate *question* into a single MySQL SELECT query.

    Relevant schema text is retrieved with ``retrieve_context`` and included in
    the prompt.

    Returns:
        The cleaned SQL string.

    Raises:
        ValueError: if the reply does not start with ``select`` or contains a
            blocked data-modifying keyword.
    """
    context = retrieve_context(question)

    prompt = f"""
You are a MySQL assistant.
Convert the user's question into one MySQL SELECT query.

Rules:
- Return only SQL.
- Only use SELECT queries.
- Do not use INSERT, UPDATE, DELETE, DROP, ALTER, or CREATE.
- Use only the retrieved database context below.
- Always use proper JOIN syntax.
- Use parameterized values (strings in quotes).

Retrieved database context:
{context}

User question: {question}
"""

    sql = clean_sql(ask_ollama(prompt))

    if not sql.lower().startswith("select"):
        raise ValueError("Only SELECT queries are allowed.")

    # Match whole keywords only. A plain substring test also rejected ordinary
    # identifiers such as created_at, last_updated or is_deleted.
    # NOTE(review): this is a heuristic; a read-only database account is the
    # reliable safeguard.
    blocked_words = ["insert", "update", "delete", "drop", "alter", "create", "truncate"]
    blocked_pattern = r"\b(?:" + "|".join(blocked_words) + r")\b"
    if re.search(blocked_pattern, sql, flags=re.IGNORECASE):
        raise ValueError("This query is not allowed.")

    return sql
267
+
268
def repair_sql(question, bad_sql, error_message):
    """Ask the model to fix *bad_sql* given the database error it produced.

    Args:
        question: the original natural-language question.
        bad_sql: the query that failed.
        error_message: the database error text.

    Returns:
        The cleaned, corrected SQL string.

    Raises:
        ValueError: if the reply does not start with ``select`` or contains a
            blocked data-modifying keyword.
    """
    context = retrieve_context(question)

    prompt = f"""
You are a MySQL assistant.
The SQL query below failed. Fix it.

Rules:
- Return only corrected SQL.
- Only use SELECT queries.
- Do not use INSERT, UPDATE, DELETE, DROP, ALTER, or CREATE.
- Use only the retrieved database context below.
- Always use proper JOIN syntax.
- Fix any syntax errors or column name issues.

Retrieved database context:
{context}

User question:
{question}

Bad SQL:
{bad_sql}

Database error:
{error_message}
"""

    sql = clean_sql(ask_ollama(prompt))

    if not sql.lower().startswith("select"):
        raise ValueError("Only SELECT queries are allowed.")

    # Whole-keyword match, so columns like created_at / is_deleted are not
    # rejected as if they were CREATE / DELETE statements.
    blocked_words = ["insert", "update", "delete", "drop", "alter", "create", "truncate"]
    blocked_pattern = r"\b(?:" + "|".join(blocked_words) + r")\b"
    if re.search(blocked_pattern, sql, flags=re.IGNORECASE):
        raise ValueError("This query is not allowed.")

    return sql
307
+
308
def run_query(sql):
    """Execute *sql* on the configured MySQL database and fetch every row.

    The cursor and connection are always closed, including when execution fails.

    Returns:
        ``(columns, rows)``: column names from ``cursor.description`` and the
        result of ``fetchall()``.

    Raises:
        Exception: wrapping any ``mysql.connector.Error`` (chained as ``__cause__``).
    """
    conn = None
    try:
        conn = mysql.connector.connect(**DB_CONFIG)
        cursor = conn.cursor()
        try:
            cursor.execute(sql)
            rows = cursor.fetchall()
            columns = [column[0] for column in cursor.description]
        finally:
            cursor.close()
        return columns, rows
    except mysql.connector.Error as e:
        raise Exception(f"Database error: {e}") from e
    finally:
        if conn is not None:
            conn.close()
321
+
322
def explain_result(question, sql, columns, rows):
    """Ask the model to describe a query result in plain English.

    Only the first 10 rows go into the prompt, to keep it short. The prompt
    still states the total row count.

    Returns:
        The model's reply text from ``ask_ollama``.
    """
    preview = rows if len(rows) <= 10 else rows[:10]

    explanation_prompt = f"""
The user asked: {question}

SQL used:
{sql}

Columns:
{columns}

Rows (showing first {len(preview)} of {len(rows)}):
{preview}

Explain the result in simple English.
"""
    return ask_ollama(explanation_prompt)
342
+
343
def initialize_schemas():
    """Register every entry of SCHEMAS through add_schema_to_rag, then print the count."""
    schema_count = len(SCHEMAS)
    for table_schema in SCHEMAS:
        add_schema_to_rag(table_schema)

    print(f"Loaded {schema_count} schemas into RAG")
349
+
350
def test_connection():
    """Check database connectivity by running ``SELECT VERSION()``.

    Prints a success line with the server version, or a failure line with the
    error. The cursor and connection are closed even when the query fails.

    Returns:
        ``True`` on success, ``False`` if any exception occurred.
    """
    conn = None
    try:
        conn = mysql.connector.connect(**DB_CONFIG)
        cursor = conn.cursor()
        try:
            cursor.execute("SELECT VERSION()")
            version = cursor.fetchone()
        finally:
            cursor.close()
        print(f"✓ Database connected successfully (MySQL Version: {version[0]})")
        return True
    except Exception as e:
        # Deliberate catch-all: this is a diagnostic that reports and returns False.
        print(f"✗ Database connection failed: {e}")
        return False
    finally:
        if conn is not None:
            conn.close()
364
+
365
def _run_with_repair(question, sql):
    """Run *sql*; on failure ask the model once for a repaired query and retry.

    Returns:
        ``(sql, columns, rows)``, where *sql* is the query that succeeded.

    Raises:
        The database error from the second attempt, after printing it.
    """
    for attempt in range(2):
        try:
            columns, rows = run_query(sql)
            return sql, columns, rows
        except Exception as db_error:
            if attempt == 1:
                print(f"\n❌ Database error after retry: {db_error}")
                raise

            print("\n🔧 SQL error, attempting repair...")
            sql = repair_sql(question, sql, str(db_error))
            print(f"📝 Repaired SQL: {sql}")


def _print_results(columns, rows):
    """Print up to 20 rows as a pipe-separated table, followed by the total count."""
    print("\n" + "=" * 60)
    print("📊 RESULTS:")
    print("=" * 60)

    if not rows:
        print("No rows found")
        return

    header = " | ".join(str(name) for name in columns)
    print(header)
    print("-" * len(header))
    for row in rows[:20]:
        print(" | ".join(str(value) for value in row))

    if len(rows) > 20:
        print(f"\n... and {len(rows) - 20} more rows")

    print(f"\n📈 Total rows returned: {len(rows)}")


def main():
    """Run the interactive loop: question -> SQL -> results -> explanation.

    The special inputs ``exit``/``quit``, ``test`` and ``schemas`` are handled
    before a question is sent to the model. Ctrl-C or end of input ends the loop.
    Any other error is printed with troubleshooting tips and the loop continues.
    """
    print("=" * 60)
    print("RAG Chat with Database")
    print("=" * 60)
    print("\nType 'exit' to stop the program")
    print("Type 'test' to test database connection")
    print("Type 'schemas' to show loaded schemas\n")

    if not test_connection():
        print("\nPlease check your database configuration in config.py")
        return

    try:
        # initialize_schemas() already prints how many schemas were loaded;
        # a second summary line here only duplicated it.
        initialize_schemas()
    except Exception as e:
        print(f"Error initializing schemas: {e}")
        return

    while True:
        try:
            question = input("\n❓ Ask: ").strip()
            command = question.lower()

            if command in ["exit", "quit"]:
                print("\nGoodbye! 👋")
                break

            if command == "test":
                test_connection()
                continue

            if command == "schemas":
                print("\nLoaded Schemas:")
                for schema in SCHEMAS:
                    print(f" - {schema['table_name']}: {schema.get('description', 'No description')}")
                continue

            if not question:
                continue

            print("\n🤔 Processing your question...")

            sql = question_to_sql(question)
            print(f"\n📝 Generated SQL: {sql}")

            sql, columns, rows = _run_with_repair(question, sql)
            _print_results(columns, rows)

            print("\n" + "=" * 60)
            print("💡 EXPLANATION:")
            print("=" * 60)
            answer = explain_result(question, sql, columns, rows)
            print(f"\n{answer}\n")

        except (KeyboardInterrupt, EOFError):
            # EOFError: stdin closed (Ctrl-D or exhausted piped input). If the
            # generic handler below caught it, input() would fail on every
            # iteration and the loop would never end.
            print("\n\nGoodbye! 👋")
            break
        except Exception as error:
            print(f"\n❌ Error: {error}\n")
            print("💡 Tips:")
            print(" - Make sure Ollama is running: 'ollama serve'")
            print(" - Pull required models: 'ollama pull llama3.2:3b && ollama pull nomic-embed-text'")
            print(" - Check database connection in config.py")


if __name__ == "__main__":
    main()
@@ -0,0 +1,143 @@
1
+ # config.example.py
2
+ # Copy this file to config.py and update with your credentials
3
+
4
import os

# Database configuration.
# Each value can be supplied through an environment variable (DB_HOST, DB_USER,
# DB_PASSWORD, DB_NAME), so real credentials need not be written into or
# committed with this file. The literals are placeholders used when a variable
# is unset.
DB_CONFIG = {
    "host": os.environ.get("DB_HOST", "host_name"),
    "user": os.environ.get("DB_USER", "user_name"),
    "password": os.environ.get("DB_PASSWORD", "password"),
    "database": os.environ.get("DB_NAME", "database_name"),
}

# Ollama configuration
OLLAMA_MODEL = "llama3.2:3b"
OLLAMA_URL = "http://127.0.0.1:11434/api/generate"
OLLAMA_EMBED_MODEL = "nomic-embed-text"
OLLAMA_EMBED_URL = "http://127.0.0.1:11434/api/embeddings"

# Database schemas: a list of dicts that chat_with_database.py passes to
# add_schema_to_rag(). The commented example below shows the expected shape.
SCHEMAS = [
    ## Add your database schemas here
]
22
+
23
+ # SCHEMAS = [
24
+ # {
25
+ # "table_name": "employee",
26
+ # "description": "Stores employee details including personal information, department assignment, and salary.",
27
+ # "columns": [
28
+ # {
29
+ # "name": "id",
30
+ # "type": "INT",
31
+ # "primary_key": True,
32
+ # "description": "Unique employee identifier."
33
+ # },
34
+ # {
35
+ # "name": "name",
36
+ # "type": "VARCHAR(100)",
37
+ # "primary_key": False,
38
+ # "description": "Full name of the employee."
39
+ # },
40
+ # {
41
+ # "name": "department_id",
42
+ # "type": "INT",
43
+ # "primary_key": False,
44
+ # "description": "Foreign key referencing the department where the employee works."
45
+ # },
46
+ # {
47
+ # "name": "salary",
48
+ # "type": "INT",
49
+ # "primary_key": False,
50
+ # "description": "Employee's salary in numeric format."
51
+ # }
52
+ # ],
53
+ # "usage": "Use this table for employee information, employee names, salaries, and department assignments."
54
+ # },
55
+ # {
56
+ # "table_name": "department",
57
+ # "description": "Stores department information and office location mapping.",
58
+ # "columns": [
59
+ # {
60
+ # "name": "id",
61
+ # "type": "INT",
62
+ # "primary_key": True,
63
+ # "description": "Unique department identifier."
64
+ # },
65
+ # {
66
+ # "name": "name",
67
+ # "type": "VARCHAR(100)",
68
+ # "primary_key": False,
69
+ # "description": "Department name (e.g., IT, HR, Finance, Sales, Marketing, Operations, Support, Research, Legal, Engineering)."
70
+ # },
71
+ # {
72
+ # "name": "office_id",
73
+ # "type": "INT",
74
+ # "primary_key": False,
75
+ # "description": "Foreign key referencing the office location where the department is situated."
76
+ # }
77
+ # ],
78
+ # "usage": "Use this table when the question asks about departments, department names, or office locations."
79
+ # },
80
+ # {
81
+ # "table_name": "project",
82
+ # "description": "Stores project information available in the company.",
83
+ # "columns": [
84
+ # {
85
+ # "name": "id",
86
+ # "type": "INT",
87
+ # "primary_key": True,
88
+ # "description": "Unique project identifier."
89
+ # },
90
+ # {
91
+ # "name": "project_name",
92
+ # "type": "VARCHAR(100)",
93
+ # "primary_key": False,
94
+ # "description": "Name of the project (e.g., Payroll System, Inventory App, Customer Portal, CRM Upgrade, Mobile Banking, Data Warehouse, Analytics Dashboard, AI Assistant, Cloud Migration, E-Commerce Platform)."
95
+ # }
96
+ # ],
97
+ # "usage": "Use this table when the question asks about projects, project names, or project-related information."
98
+ # },
99
+ # {
100
+ # "table_name": "employee_project",
101
+ # "description": "Junction table mapping employees to projects they are working on (many-to-many relationship).",
102
+ # "columns": [
103
+ # {
104
+ # "name": "id",
105
+ # "type": "INT",
106
+ # "primary_key": True,
107
+ # "description": "Unique mapping identifier."
108
+ # },
109
+ # {
110
+ # "name": "employee_id",
111
+ # "type": "INT",
112
+ # "primary_key": False,
113
+ # "description": "Foreign key referencing the employee assigned to a project."
114
+ # },
115
+ # {
116
+ # "name": "project_id",
117
+ # "type": "INT",
118
+ # "primary_key": False,
119
+ # "description": "Foreign key referencing the project assigned to an employee."
120
+ # }
121
+ # ],
122
+ # "usage": "Use this table to find relationships between employees and projects. This table does NOT contain name columns. To get employee or project names, join with employee and project tables."
123
+ # },
124
+ # {
125
+ # "table_name": "office",
126
+ # "description": "Stores office location information.",
127
+ # "columns": [
128
+ # {
129
+ # "name": "id",
130
+ # "type": "INT",
131
+ # "primary_key": True,
132
+ # "description": "Unique office identifier."
133
+ # },
134
+ # {
135
+ # "name": "city",
136
+ # "type": "VARCHAR(100)",
137
+ # "primary_key": False,
138
+ # "description": "City where the office is located (e.g., Bangalore, Mumbai, Delhi, Pune, Hyderabad, Chennai, Kolkata, Ahmedabad, Noida, Jaipur)."
139
+ # }
140
+ # ],
141
+ # "usage": "Use this table when the question asks about office locations, cities, or where departments are situated."
142
+ # }
143
+ # ]