mcp-sqlite-memory-bank 1.2.4 → 1.3.0 (py3-none-any.whl)
- mcp_sqlite_memory_bank/__init__.py +6 -9
- mcp_sqlite_memory_bank/database.py +923 -0
- mcp_sqlite_memory_bank/semantic.py +386 -0
- mcp_sqlite_memory_bank/server.py +838 -1140
- mcp_sqlite_memory_bank/types.py +105 -9
- mcp_sqlite_memory_bank/utils.py +22 -54
- {mcp_sqlite_memory_bank-1.2.4.dist-info → mcp_sqlite_memory_bank-1.3.0.dist-info}/METADATA +59 -8
- mcp_sqlite_memory_bank-1.3.0.dist-info/RECORD +13 -0
- mcp_sqlite_memory_bank-1.2.4.dist-info/RECORD +0 -11
- {mcp_sqlite_memory_bank-1.2.4.dist-info → mcp_sqlite_memory_bank-1.3.0.dist-info}/WHEEL +0 -0
- {mcp_sqlite_memory_bank-1.2.4.dist-info → mcp_sqlite_memory_bank-1.3.0.dist-info}/entry_points.txt +0 -0
- {mcp_sqlite_memory_bank-1.2.4.dist-info → mcp_sqlite_memory_bank-1.3.0.dist-info}/licenses/LICENSE +0 -0
- {mcp_sqlite_memory_bank-1.2.4.dist-info → mcp_sqlite_memory_bank-1.3.0.dist-info}/top_level.txt +0 -0
--- /dev/null
+++ mcp_sqlite_memory_bank/database.py
@@ -0,0 +1,923 @@
+"""
+Database abstraction layer for SQLite Memory Bank using SQLAlchemy Core.
+
+This module provides a clean, type-safe abstraction over SQLite operations,
+dramatically simplifying the complex database logic in server.py.
+
+Author: Robert Meisner
+"""
+
+import os
+import json
+import logging
+from functools import wraps
+from typing import Dict, List, Any, Optional, Callable, cast
+from sqlalchemy import create_engine, MetaData, Table, select, insert, update, delete, text, inspect, and_, or_
+from sqlalchemy.engine import Engine
+from sqlalchemy.exc import SQLAlchemyError
+from contextlib import contextmanager
+
+from .types import ValidationError, DatabaseError, SchemaError, ToolResponse, EmbeddingColumnResponse, GenerateEmbeddingsResponse, SemanticSearchResponse, RelatedContentResponse, HybridSearchResponse, EmbeddingStatsResponse
+from .semantic import get_semantic_engine, is_semantic_search_available
+
+
+class SQLiteMemoryDatabase:
+    """
+    SQLAlchemy Core-based database abstraction for SQLite Memory Bank.
+
+    This class handles all database operations with automatic:
+    - Connection management
+    - Error handling and translation
+    - SQL injection protection
+    - Type conversion
+    - Transaction management
+    """
+
+    def __init__(self, db_path: str):
+        """Initialize database connection and metadata."""
+        self.db_path = os.path.abspath(db_path)
+        self.engine: Engine = create_engine(f"sqlite:///{self.db_path}", echo=False)
+        self.metadata = MetaData()
+
+        # Ensure database directory exists
+        os.makedirs(os.path.dirname(self.db_path), exist_ok=True)
+
+        # Initialize connection
+        self._refresh_metadata()
+
+    def close(self) -> None:
+        """Close all database connections and dispose of the engine."""
+        try:
+            if hasattr(self, "engine"):
+                self.engine.dispose()
+        except Exception as e:
+            logging.warning(f"Error closing database: {e}")
+
+    def __del__(self):
+        """Ensure cleanup when object is garbage collected."""
+        self.close()
+
+    def _refresh_metadata(self) -> None:
+        """Refresh metadata to reflect current database schema."""
+        try:
+            self.metadata.clear()
+            self.metadata.reflect(bind=self.engine)
+        except SQLAlchemyError as e:
+            logging.warning(f"Failed to refresh metadata: {e}")
+
+    @contextmanager
+    def get_connection(self):
+        """Get a database connection with automatic cleanup."""
+        conn = self.engine.connect()
+        try:
+            yield conn
+        finally:
+            conn.close()
+
+    def _ensure_table_exists(self, table_name: str) -> Table:
+        """Get table metadata, refreshing if needed.
+        Raises ValidationError if not found.
+        """
+        if table_name not in self.metadata.tables:
+            self._refresh_metadata()
+
+        if table_name not in self.metadata.tables:
+            raise ValidationError(f"Table '{table_name}' does not exist")
+
+        return self.metadata.tables[table_name]
+
+    def _validate_columns(self, table: Table, column_names: List[str], context: str = "operation") -> None:
+        """Validate that all column names exist in the table."""
+        valid_columns = set(col.name for col in table.columns)
+        for col_name in column_names:
+            if col_name not in valid_columns:
+                raise ValidationError(f"Invalid column '{col_name}' for table " f"'{table.name}' in {context}")
+
+    def _build_where_conditions(self, table: Table, where: Dict[str, Any]) -> List:
+        """Build SQLAlchemy WHERE conditions from a dictionary."""
+        if not where:
+            return []
+
+        self._validate_columns(table, list(where.keys()), "WHERE clause")
+        return [table.c[col_name] == value for col_name, value in where.items()]
+
+    def _execute_with_commit(self, stmt) -> Any:
+        """Execute a statement with automatic connection mgmt and commit."""
+        with self.get_connection() as conn:
+            result = conn.execute(stmt)
+            conn.commit()
+            return result
+
+    def _database_operation(self, operation_name: str):
+        """Decorator for database operations with standardized error handling."""
+
+        def decorator(func: Callable) -> Callable:
+            @wraps(func)
+            def wrapper(*args, **kwargs) -> ToolResponse:
+                try:
+                    return func(*args, **kwargs)
+                except (ValidationError, SchemaError) as e:
+                    raise e
+                except SQLAlchemyError as e:
+                    raise DatabaseError(f"Failed to {operation_name}: {str(e)}")
+
+            return wrapper
+
+        return decorator
+
+    def create_table(self, table_name: str, columns: List[Dict[str, str]]) -> ToolResponse:
+        """Create a new table with the specified columns."""
+        # Input validation
+        if not table_name or not table_name.isidentifier():
+            raise ValidationError(f"Invalid table name: {table_name}")
+        if not columns:
+            raise ValidationError("Must provide at least one column")
+
+        # Validate column definitions
+        for col_def in columns:
+            if "name" not in col_def or "type" not in col_def:
+                raise ValidationError(f"Column must have 'name' and 'type': {col_def}")
+            if not col_def["name"].isidentifier():
+                raise ValidationError(f"Invalid column name: {col_def['name']}")
+
+        try:
+            # Use raw SQL for full SQLite type support
+            col_defs = ", ".join([f"{col['name']} {col['type']}" for col in columns])
+            sql = f"CREATE TABLE IF NOT EXISTS {table_name} ({col_defs})"
+
+            with self.get_connection() as conn:
+                conn.execute(text(sql))
+                conn.commit()
+
+            self._refresh_metadata()
+            return {"success": True}
+
+        except SQLAlchemyError as e:
+            raise DatabaseError(f"Failed to create table {table_name}: {str(e)}")
+
+    def list_tables(self) -> ToolResponse:
+        """List all user-created tables."""
+        try:
+            with self.get_connection() as conn:
+                inspector = inspect(conn)
+                tables = [name for name in inspector.get_table_names() if not name.startswith("sqlite_")]
+                return {"success": True, "tables": tables}
+        except SQLAlchemyError as e:
+            raise DatabaseError(f"Failed to list tables: {str(e)}")
+
+    def describe_table(self, table_name: str) -> ToolResponse:
+        """Get detailed schema information for a table."""
+        try:
+            table = self._ensure_table_exists(table_name)
+            columns = [
+                {
+                    "name": col.name,
+                    "type": str(col.type),
+                    "nullable": col.nullable,
+                    "default": col.default,
+                    "primary_key": col.primary_key,
+                }
+                for col in table.columns
+            ]
+            return {"success": True, "columns": columns}
+        except (ValidationError, SQLAlchemyError) as e:
+            if isinstance(e, ValidationError):
+                raise e
+            raise DatabaseError(f"Failed to describe table {table_name}: {str(e)}")
+
+    def drop_table(self, table_name: str) -> ToolResponse:
+        """Drop a table."""
+        try:
+            self._ensure_table_exists(table_name)  # Validates existence
+            self._execute_with_commit(text(f"DROP TABLE {table_name}"))
+            self._refresh_metadata()
+            return {"success": True}
+        except (ValidationError, SQLAlchemyError) as e:
+            if isinstance(e, ValidationError):
+                raise e
+            raise DatabaseError(f"Failed to drop table {table_name}: {str(e)}")
+
+    def rename_table(self, old_name: str, new_name: str) -> ToolResponse:
+        """Rename a table."""
+        if old_name == new_name:
+            raise ValidationError("Old and new table names are identical")
+
+        try:
+            self._ensure_table_exists(old_name)  # Validates old table exists
+
+            # Check if new name already exists
+            with self.get_connection() as conn:
+                inspector = inspect(conn)
+                if new_name in inspector.get_table_names():
+                    raise ValidationError(f"Table '{new_name}' already exists")
+
+                conn.execute(text(f"ALTER TABLE {old_name} RENAME TO {new_name}"))
+                conn.commit()
+
+            self._refresh_metadata()
+            return {"success": True}
+        except (ValidationError, SQLAlchemyError) as e:
+            if isinstance(e, ValidationError):
+                raise e
+            raise DatabaseError(f"Failed to rename table from {old_name} to {new_name}: {str(e)}")
+
+    def insert_row(self, table_name: str, data: Dict[str, Any]) -> ToolResponse:
+        """Insert a row into a table."""
+        if not data:
+            raise ValidationError("Data cannot be empty")
+
+        try:
+            table = self._ensure_table_exists(table_name)
+            self._validate_columns(table, list(data.keys()), "insert operation")
+
+            result = self._execute_with_commit(insert(table).values(**data))
+            return {"success": True, "id": result.lastrowid}
+        except (ValidationError, SQLAlchemyError) as e:
+            if isinstance(e, ValidationError):
+                raise e
+            raise DatabaseError(f"Failed to insert into table {table_name}: {str(e)}")
+
+    def read_rows(self, table_name: str, where: Optional[Dict[str, Any]] = None, limit: Optional[int] = None) -> ToolResponse:
+        """Read rows from a table with optional filtering."""
+        try:
+            table = self._ensure_table_exists(table_name)
+            stmt = select(table)
+
+            # Apply WHERE conditions
+            conditions = self._build_where_conditions(table, where or {})
+            if conditions:
+                stmt = stmt.where(and_(*conditions))
+
+            # Apply LIMIT
+            if limit:
+                stmt = stmt.limit(limit)
+
+            with self.get_connection() as conn:
+                result = conn.execute(stmt)
+                rows = [dict(row._mapping) for row in result.fetchall()]
+
+            return {"success": True, "rows": rows}
+        except (ValidationError, SQLAlchemyError) as e:
+            if isinstance(e, ValidationError):
+                raise e
+            raise DatabaseError(f"Failed to read from table {table_name}: {str(e)}")
+
+    def update_rows(self, table_name: str, data: Dict[str, Any], where: Optional[Dict[str, Any]] = None) -> ToolResponse:
+        """Update rows in a table."""
+        if not data:
+            raise ValidationError("Update data cannot be empty")
+
+        try:
+            table = self._ensure_table_exists(table_name)
+            self._validate_columns(table, list(data.keys()), "update operation")
+
+            stmt = update(table).values(**data)
+
+            # Apply WHERE conditions
+            conditions = self._build_where_conditions(table, where or {})
+            if conditions:
+                stmt = stmt.where(and_(*conditions))
+
+            result = self._execute_with_commit(stmt)
+            return {"success": True, "rows_affected": result.rowcount}
+        except (ValidationError, SQLAlchemyError) as e:
+            if isinstance(e, ValidationError):
+                raise e
+            raise DatabaseError(f"Failed to update table {table_name}: {str(e)}")
+
+    def delete_rows(self, table_name: str, where: Optional[Dict[str, Any]] = None) -> ToolResponse:
+        """Delete rows from a table."""
+        try:
+            table = self._ensure_table_exists(table_name)
+            stmt = delete(table)
+
+            # Apply WHERE conditions
+            conditions = self._build_where_conditions(table, where or {})
+            if conditions:
+                stmt = stmt.where(and_(*conditions))
+            else:
+                logging.warning(f"delete_rows called without WHERE clause on table {table_name}")
+
+            result = self._execute_with_commit(stmt)
+            return {"success": True, "rows_affected": result.rowcount}
+        except (ValidationError, SQLAlchemyError) as e:
+            if isinstance(e, ValidationError):
+                raise e
+            raise DatabaseError(f"Failed to delete from table {table_name}: {str(e)}")
+
+    def select_query(
+        self, table_name: str, columns: Optional[List[str]] = None, where: Optional[Dict[str, Any]] = None, limit: int = 100
+    ) -> ToolResponse:
+        """Run a SELECT query with specified columns and conditions."""
+        if limit < 1:
+            raise ValidationError("Limit must be a positive integer")
+
+        try:
+            table = self._ensure_table_exists(table_name)
+
+            # Build SELECT columns
+            if columns:
+                self._validate_columns(table, columns, "SELECT operation")
+                select_columns = [table.c[col_name] for col_name in columns]
+                stmt = select(*select_columns)
+            else:
+                stmt = select(table)
+
+            # Apply WHERE conditions
+            conditions = self._build_where_conditions(table, where or {})
+            if conditions:
+                stmt = stmt.where(and_(*conditions))
+
+            stmt = stmt.limit(limit)
+
+            with self.get_connection() as conn:
+                result = conn.execute(stmt)
+                rows = [dict(row._mapping) for row in result.fetchall()]
+
+            return {"success": True, "rows": rows}
+        except (ValidationError, SQLAlchemyError) as e:
+            if isinstance(e, ValidationError):
+                raise e
+            raise DatabaseError(f"Failed to query table {table_name}: {str(e)}")
+
+    def list_all_columns(self) -> ToolResponse:
+        """List all columns for all tables."""
+        try:
+            self._refresh_metadata()
+            schemas = {table_name: [col.name for col in table.columns] for table_name, table in self.metadata.tables.items()}
+            return {"success": True, "schemas": schemas}
+        except SQLAlchemyError as e:
+            raise DatabaseError(f"Failed to list all columns: {str(e)}")
+
+    def search_content(self, query: str, tables: Optional[List[str]] = None, limit: int = 50) -> ToolResponse:
+        """Perform full-text search across table content."""
+        if not query or not query.strip():
+            raise ValidationError("Search query cannot be empty")
+        if limit < 1:
+            raise ValidationError("Limit must be a positive integer")
+
+        try:
+            self._refresh_metadata()
+            search_tables = tables or list(self.metadata.tables.keys())
+            results = []
+
+            with self.get_connection() as conn:
+                for table_name in search_tables:
+                    if table_name not in self.metadata.tables:
+                        continue
+
+                    table = self.metadata.tables[table_name]
+                    text_columns = [
+                        col for col in table.columns if "TEXT" in str(col.type).upper() or "VARCHAR" in str(col.type).upper()
+                    ]
+
+                    if not text_columns:
+                        continue
+
+                    # Build search conditions and execute
+                    conditions = [col.like(f"%{query}%") for col in text_columns]
+                    stmt = select(table).where(or_(*conditions)).limit(limit)
+
+                    for row in conn.execute(stmt).fetchall():
+                        row_dict = dict(row._mapping)
+
+                        # Calculate relevance and matched content
+                        relevance = 0.0
+                        matched_content = []
+                        query_lower = query.lower()
+
+                        for col in text_columns:
+                            if col.name in row_dict and row_dict[col.name]:
+                                content = str(row_dict[col.name]).lower()
+                                if query_lower in content:
+                                    frequency = content.count(query_lower)
+                                    relevance += frequency / len(content)
+                                    matched_content.append(f"{col.name}: {row_dict[col.name]}")
+
+                        if relevance > 0:
+                            results.append(
+                                {
+                                    "table": table_name,
+                                    "row_id": row_dict.get("id"),
+                                    "row_data": row_dict,
+                                    "matched_content": matched_content,
+                                    "relevance": round(relevance, 3),
+                                }
+                            )
+
+            # Sort by relevance and limit results
+            results.sort(key=lambda x: x["relevance"], reverse=True)
+            results = results[:limit]
+
+            return {
+                "success": True,
+                "results": results,
+                "query": query,
+                "tables_searched": search_tables,
+                "total_results": len(results),
+            }
+        except (ValidationError, SQLAlchemyError) as e:
+            if isinstance(e, ValidationError):
+                raise e
+            raise DatabaseError(f"Failed to search content: {str(e)}")
+
+    def explore_tables(self, pattern: Optional[str] = None, include_row_counts: bool = True) -> ToolResponse:
+        """Explore table structures and content."""
+        try:
+            self._refresh_metadata()
+            table_names = list(self.metadata.tables.keys())
+
+            if pattern:
+                table_names = [name for name in table_names if pattern.replace("%", "") in name]
+
+            exploration = {"tables": [], "total_tables": len(table_names), "total_rows": 0}
+
+            with self.get_connection() as conn:
+                for table_name in table_names:
+                    table = self.metadata.tables[table_name]
+
+                    # Build column info and identify text columns
+                    columns = []
+                    text_columns = []
+
+                    for col in table.columns:
+                        col_data = {
+                            "name": col.name,
+                            "type": str(col.type),
+                            "nullable": col.nullable,
+                            "default": col.default,
+                            "primary_key": col.primary_key,
+                        }
+                        columns.append(col_data)
+
+                        if "TEXT" in str(col.type).upper() or "VARCHAR" in str(col.type).upper():
+                            text_columns.append(col.name)
+
+                    table_info = {"name": table_name, "columns": columns, "text_columns": text_columns}
+
+                    # Add row count if requested
+                    if include_row_counts:
+                        count_result = conn.execute(select(text("COUNT(*)")).select_from(table))
+                        row_count = count_result.scalar()
+                        table_info["row_count"] = row_count
+                        exploration["total_rows"] += row_count
+
+                    # Add sample data
+                    sample_result = conn.execute(select(table).limit(3))
+                    sample_rows = [dict(row._mapping) for row in sample_result.fetchall()]
+                    if sample_rows:
+                        table_info["sample_data"] = sample_rows
+
+                    # Add content preview for text columns
+                    if text_columns:
+                        content_preview = {}
+                        for col_name in text_columns[:3]:  # Limit to first 3 text columns
+                            col = table.c[col_name]
+                            preview_result = conn.execute(select(col).distinct().where(col.isnot(None)).limit(5))
+                            unique_values = [row[0] for row in preview_result.fetchall() if row[0]]
+                            if unique_values:
+                                content_preview[col_name] = unique_values
+
+                        if content_preview:
+                            table_info["content_preview"] = content_preview
+
+                    exploration["tables"].append(table_info)
+
+            return {"success": True, "exploration": exploration}
+        except SQLAlchemyError as e:
+            raise DatabaseError(f"Failed to explore tables: {str(e)}")
+
+    # --- Semantic Search Methods ---
+
+    def add_embedding_column(self, table_name: str, embedding_column: str = "embedding") -> EmbeddingColumnResponse:
+        """Add an embedding column to a table for semantic search."""
+        try:
+            table = self._ensure_table_exists(table_name)
+
+            # Check if embedding column already exists
+            if embedding_column in [col.name for col in table.columns]:
+                return {"success": True, "message": f"Embedding column '{embedding_column}' already exists"}
+
+            # Add embedding column as TEXT (JSON storage)
+            with self.get_connection() as conn:
+                conn.execute(text(f"ALTER TABLE {table_name} ADD COLUMN {embedding_column} TEXT"))
+                conn.commit()
+
+            self._refresh_metadata()
+            return {"success": True, "message": f"Added embedding column '{embedding_column}' to table '{table_name}'"}
+
+        except (ValidationError, SQLAlchemyError) as e:
+            if isinstance(e, ValidationError):
+                raise e
+            raise DatabaseError(f"Failed to add embedding column: {str(e)}")
+
+    def generate_embeddings(self, table_name: str, text_columns: List[str],
+                            embedding_column: str = "embedding",
+                            model_name: str = "all-MiniLM-L6-v2",
+                            batch_size: int = 50) -> GenerateEmbeddingsResponse:
+        """Generate embeddings for text content in a table."""
+        if not is_semantic_search_available():
+            raise ValidationError("Semantic search is not available. Please install sentence-transformers.")
+
+        try:
+            table = self._ensure_table_exists(table_name)
+            semantic_engine = get_semantic_engine(model_name)
+
+            # Validate text columns exist
+            table_columns = [col.name for col in table.columns]
+            for col in text_columns:
+                if col not in table_columns:
+                    raise ValidationError(f"Column '{col}' not found in table '{table_name}'")
+
+            # Add embedding column if it doesn't exist
+            if embedding_column not in table_columns:
+                self.add_embedding_column(table_name, embedding_column)
+                table = self._ensure_table_exists(table_name)  # Refresh
+
+            # Get all rows that need embeddings
+            with self.get_connection() as conn:
+                # Select rows without embeddings or with null embeddings
+                stmt = select(table).where(
+                    or_(table.c[embedding_column].is_(None),
+                        table.c[embedding_column] == "",
+                        table.c[embedding_column] == "null")
+                )
+                rows = conn.execute(stmt).fetchall()
+
+                if not rows:
+                    embedding_dim = semantic_engine.get_embedding_dimensions() or 0
+                    return {"success": True, "message": "All rows already have embeddings", "processed": 0, "model": model_name, "embedding_dimension": embedding_dim}
+
+                processed = 0
+                for i in range(0, len(rows), batch_size):
+                    batch = rows[i:i + batch_size]
+
+                    for row in batch:
+                        row_dict = dict(row._mapping)
+
+                        # Combine text from specified columns
+                        text_parts = []
+                        for col in text_columns:
+                            if col in row_dict and row_dict[col]:
+                                text_parts.append(str(row_dict[col]))
+
+                        if text_parts:
+                            combined_text = " ".join(text_parts)
+
+                            # Generate embedding
+                            embedding = semantic_engine.generate_embedding(combined_text)
+                            embedding_json = json.dumps(embedding)
+
+                            # Update row with embedding
+                            update_stmt = update(table).where(
+                                table.c["id"] == row_dict["id"]
+                            ).values({embedding_column: embedding_json})
+
+                            conn.execute(update_stmt)
+                            processed += 1
+
+                    conn.commit()
+                    logging.info(f"Generated embeddings for batch {i//batch_size + 1}, processed {processed} rows")
+
+            return {
+                "success": True,
+                "message": f"Generated embeddings for {processed} rows",
+                "processed": processed,
+                "model": model_name,
+                "embedding_dimension": semantic_engine.get_embedding_dimensions() or 0
+            }
+
+        except (ValidationError, SQLAlchemyError) as e:
+            if isinstance(e, ValidationError):
+                raise e
+            raise DatabaseError(f"Failed to generate embeddings: {str(e)}")
+
+    def semantic_search(self, query: str, tables: Optional[List[str]] = None,
+                        embedding_column: str = "embedding",
+                        text_columns: Optional[List[str]] = None,
+                        similarity_threshold: float = 0.5,
+                        limit: int = 10,
+                        model_name: str = "all-MiniLM-L6-v2") -> SemanticSearchResponse:
+        """Perform semantic search across tables using vector embeddings."""
+        if not is_semantic_search_available():
+            raise ValidationError("Semantic search is not available. Please install sentence-transformers.")
+
+        if not query or not query.strip():
+            raise ValidationError("Search query cannot be empty")
+
+        try:
+            self._refresh_metadata()
+            search_tables = tables or list(self.metadata.tables.keys())
+            semantic_engine = get_semantic_engine(model_name)
+
+            all_results = []
+
+            with self.get_connection() as conn:
+                for table_name in search_tables:
+                    if table_name not in self.metadata.tables:
+                        continue
+
+                    table = self.metadata.tables[table_name]
+
+                    # Check if table has embedding column
+                    if embedding_column not in [col.name for col in table.columns]:
+                        logging.warning(f"Table '{table_name}' does not have embedding column '{embedding_column}'")
+                        continue
+
+                    # Get all rows with embeddings
+                    stmt = select(table).where(
+                        and_(table.c[embedding_column].isnot(None),
+                             table.c[embedding_column] != "",
+                             table.c[embedding_column] != "null")
+                    )
+                    rows = conn.execute(stmt).fetchall()
+
+                    if not rows:
+                        continue
+
+                    # Convert to list of dicts for semantic search
+                    content_data = [dict(row._mapping) for row in rows]
+
+                    # Determine text columns for highlighting
+                    if text_columns is None:
+                        text_cols = [col.name for col in table.columns
+                                     if "TEXT" in str(col.type).upper() or "VARCHAR" in str(col.type).upper()]
+                    else:
+                        text_cols = text_columns
+
+                    # Perform semantic search on this table
+                    table_results = semantic_engine.semantic_search(
+                        query, content_data, embedding_column, text_cols,
+                        similarity_threshold, limit * 2  # Get more for global ranking
+                    )
+
+                    # Add table name to results
+                    for result in table_results:
+                        result["table_name"] = table_name
+
+                    all_results.extend(table_results)
+
+            # Sort all results by similarity score and limit
+            all_results.sort(key=lambda x: x.get("similarity_score", 0), reverse=True)
+            final_results = all_results[:limit]
+
+            return {
+                "success": True,
+                "results": final_results,
+                "query": query,
+                "tables_searched": search_tables,
+                "total_results": len(final_results),
+                "model": model_name,
+                "similarity_threshold": similarity_threshold
+            }
+
+        except (ValidationError, SQLAlchemyError) as e:
+            if isinstance(e, ValidationError):
+                raise e
+            raise DatabaseError(f"Semantic search failed: {str(e)}")
+
+    def find_related_content(self, table_name: str, row_id: int,
+                             embedding_column: str = "embedding",
+                             similarity_threshold: float = 0.5,
+                             limit: int = 5,
+                             model_name: str = "all-MiniLM-L6-v2") -> RelatedContentResponse:
+        """Find content related to a specific row by semantic similarity."""
+        if not is_semantic_search_available():
+            raise ValidationError("Semantic search is not available. Please install sentence-transformers.")
+
+        try:
+            table = self._ensure_table_exists(table_name)
+            semantic_engine = get_semantic_engine(model_name)
+
+            with self.get_connection() as conn:
+                # Get the target row
+                target_stmt = select(table).where(table.c["id"] == row_id)
+                target_row = conn.execute(target_stmt).fetchone()
+
+                if not target_row:
+                    raise ValidationError(f"Row with id {row_id} not found in table '{table_name}'")
+
+                target_dict = dict(target_row._mapping)
+
+                # Check if target has embedding
+                if (embedding_column not in target_dict or
+                        not target_dict[embedding_column] or
+                        target_dict[embedding_column] in ["", "null"]):
+                    raise ValidationError(f"Row {row_id} does not have an embedding")
+
+                # Get target embedding
+                target_embedding = json.loads(target_dict[embedding_column])
+
+                # Get all other rows with embeddings
+                stmt = select(table).where(
+                    and_(table.c["id"] != row_id,
+                         table.c[embedding_column].isnot(None),
+                         table.c[embedding_column] != "",
+                         table.c[embedding_column] != "null")
+                )
+                rows = conn.execute(stmt).fetchall()
+
+                if not rows:
+                    return {
+                        "success": True,
+                        "results": [],
+                        "target_row": target_dict,
+                        "total_results": 0,
+                        "similarity_threshold": similarity_threshold,
+                        "model": model_name,
+                        "message": "No other rows with embeddings found"
+                    }
+
+                # Find similar rows
+                content_data = [dict(row._mapping) for row in rows]
+                candidate_embeddings = []
+                valid_indices = []
+
+                for idx, row_dict in enumerate(content_data):
+                    try:
+                        embedding = json.loads(row_dict[embedding_column])
+                        candidate_embeddings.append(embedding)
+                        valid_indices.append(idx)
+                    except json.JSONDecodeError:
+                        continue
+
+                if not candidate_embeddings:
+                    return {
+                        "success": True,
+                        "results": [],
+                        "target_row": target_dict,
+                        "total_results": 0,
+                        "similarity_threshold": similarity_threshold,
+                        "model": model_name,
+                        "message": "No valid embeddings found for comparison"
+                    }
+
+                # Calculate similarities
+                similar_indices = semantic_engine.find_similar_embeddings(
+                    target_embedding, candidate_embeddings,
+                    similarity_threshold, limit
+                )
+
+                # Build results
+                results = []
+                for candidate_idx, similarity_score in similar_indices:
+                    original_idx = valid_indices[candidate_idx]
+                    row_dict = content_data[original_idx].copy()
+                    row_dict["similarity_score"] = round(similarity_score, 3)
+                    results.append(row_dict)
+
+                return {
+                    "success": True,
+                    "results": results,
+                    "target_row": target_dict,
+                    "total_results": len(results),
+                    "similarity_threshold": similarity_threshold,
+                    "model": model_name,
+                    "message": f"Found {len(results)} related items"
+                }
+
+        except (ValidationError, SQLAlchemyError) as e:
+            if isinstance(e, ValidationError):
+                raise e
+            raise DatabaseError(f"Failed to find related content: {str(e)}")
+
+    def hybrid_search(self, query: str, tables: Optional[List[str]] = None,
+                      text_columns: Optional[List[str]] = None,
+                      embedding_column: str = "embedding",
+                      semantic_weight: float = 0.7,
+                      text_weight: float = 0.3,
+                      limit: int = 10,
+                      model_name: str = "all-MiniLM-L6-v2") -> HybridSearchResponse:
+        """Combine semantic search with keyword matching for optimal results."""
+        if not is_semantic_search_available():
+            # Fallback to text search only
+            fallback_result = self.search_content(query, tables, limit)
+            # Convert to HybridSearchResponse format
+            return cast(HybridSearchResponse, {
+                **fallback_result,
+                "search_type": "text_only",
+                "semantic_weight": 0.0,
+                "text_weight": 1.0,
+                "model": "none"
+            })
+
+        try:
+            # Get semantic search results
+            semantic_response = self.semantic_search(
+                query, tables, embedding_column, text_columns,
+                similarity_threshold=0.3, limit=limit * 2, model_name=model_name
+            )
+
+            if not semantic_response.get("success"):
+                return cast(HybridSearchResponse, {
+                    **semantic_response,
+                    "search_type": "semantic_failed",
+                    "semantic_weight": semantic_weight,
+                    "text_weight": text_weight,
+                    "model": model_name
+                })
+
+            semantic_results = semantic_response.get("results", [])
+
+            if not semantic_results:
+                # Fallback to text search
+                fallback_result = self.search_content(query, tables, limit)
+                return cast(HybridSearchResponse, {
+                    **fallback_result,
+                    "search_type": "text_fallback",
+                    "semantic_weight": semantic_weight,
+                    "text_weight": text_weight,
+                    "model": model_name
+                })
+
+            # Enhance with text matching scores
+            semantic_engine = get_semantic_engine(model_name)
+            enhanced_results = semantic_engine.hybrid_search(
+                query, semantic_results, text_columns or [],
+                embedding_column, semantic_weight, text_weight, limit
+            )
+
+            return {
+                "success": True,
+                "results": enhanced_results,
+                "query": query,
+                "search_type": "hybrid",
+                "semantic_weight": semantic_weight,
+                "text_weight": text_weight,
+                "total_results": len(enhanced_results),
+                "model": model_name
+            }
+
+        except (ValidationError, SQLAlchemyError) as e:
+            if isinstance(e, ValidationError):
+                raise e
+            raise DatabaseError(f"Hybrid search failed: {str(e)}")
+
+    def get_embedding_stats(self, table_name: str, embedding_column: str = "embedding") -> EmbeddingStatsResponse:
+        """Get statistics about embeddings in a table."""
+        try:
+            table = self._ensure_table_exists(table_name)
+
+            with self.get_connection() as conn:
+                # Count total rows
+                total_count = conn.execute(select(text("COUNT(*)")).select_from(table)).scalar() or 0
+
+                # Count rows with embeddings
+                embedded_count = conn.execute(
+                    select(text("COUNT(*)")).select_from(table).where(
+                        and_(table.c[embedding_column].isnot(None),
+                             table.c[embedding_column] != "",
+                             table.c[embedding_column] != "null")
+                    )
+                ).scalar() or 0
+
+                # Get sample embedding to check dimensions
+                sample_stmt = select(table.c[embedding_column]).where(
+                    and_(table.c[embedding_column].isnot(None),
+                         table.c[embedding_column] != "",
+                         table.c[embedding_column] != "null")
+                ).limit(1)
+
+                sample_result = conn.execute(sample_stmt).fetchone()
+                dimensions = None
+                if sample_result and sample_result[0]:
+                    try:
+                        sample_embedding = json.loads(sample_result[0])
+                        dimensions = len(sample_embedding)
+                    except json.JSONDecodeError:
+                        pass
+
+                coverage_percent = (embedded_count / total_count * 100) if total_count > 0 else 0.0
+
+                return {
+                    "success": True,
+                    "table_name": table_name,
+                    "total_rows": total_count,
+                    "embedded_rows": embedded_count,
+                    "coverage_percent": round(coverage_percent, 1),
+                    "embedding_dimensions": dimensions,
+                    "embedding_column": embedding_column
+                }
+
+        except (ValidationError, SQLAlchemyError) as e:
+            if isinstance(e, ValidationError):
+                raise e
+            raise DatabaseError(f"Failed to get embedding stats: {str(e)}")
+
+
+# Global database instance
+_db_instance: Optional[SQLiteMemoryDatabase] = None
+
+
+def get_database(db_path: Optional[str] = None) -> SQLiteMemoryDatabase:
+    """Get or create the global database instance."""
+    global _db_instance
+
+    actual_path = db_path or os.environ.get("DB_PATH", "./test.db")
+    if _db_instance is None or (db_path and db_path != _db_instance.db_path):
+        # Close previous instance if it exists
+        if _db_instance is not None:
+            _db_instance.close()
+        _db_instance = SQLiteMemoryDatabase(actual_path)

+    return _db_instance
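
For reference, the new `SQLiteMemoryDatabase` API added in this file can be exercised roughly as follows. This is a minimal sketch against the 1.3.0 module as shown above, assuming the package is importable as `mcp_sqlite_memory_bank` and, for the semantic calls, that the optional `sentence-transformers` dependency is installed; the `notes` table and its columns are hypothetical.

    from mcp_sqlite_memory_bank.database import get_database
    from mcp_sqlite_memory_bank.semantic import is_semantic_search_available

    # Obtain the module-level singleton, pointing it at a local SQLite file.
    db = get_database("./memory.db")

    # Create a table; column types are passed through as raw SQLite types.
    db.create_table("notes", [
        {"name": "id", "type": "INTEGER PRIMARY KEY AUTOINCREMENT"},
        {"name": "title", "type": "TEXT"},
        {"name": "body", "type": "TEXT"},
    ])

    # Basic CRUD: every method returns a dict with a "success" flag.
    row = db.insert_row("notes", {"title": "greeting", "body": "hello world"})
    print(db.read_rows("notes", where={"id": row["id"]}))

    # LIKE-based full-text search across all TEXT/VARCHAR columns.
    print(db.search_content("hello", tables=["notes"], limit=10))

    # Semantic search requires embeddings to be generated first; note that
    # generate_embeddings() assumes each row has an integer "id" column.
    if is_semantic_search_available():
        db.generate_embeddings("notes", text_columns=["title", "body"])
        print(db.semantic_search("a friendly greeting", tables=["notes"]))

    db.close()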