mcp-sqlite-memory-bank 1.2.4__py3-none-any.whl → 1.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1010 @@
1
+ """
2
+ Database abstraction layer for SQLite Memory Bank using SQLAlchemy Core.
3
+
4
+ This module provides a clean, type-safe abstraction over SQLite operations,
5
+ dramatically simplifying the complex database logic in server.py.
6
+
7
+ Author: Robert Meisner
8
+ """
9
+
10
+ import os
11
+ import json
12
+ import logging
13
+ from functools import wraps
14
+ from typing import Dict, List, Any, Optional, Callable, cast
15
+ from sqlalchemy import create_engine, MetaData, Table, select, insert, update, delete, text, inspect, and_, or_
16
+ from sqlalchemy.engine import Engine
17
+ from sqlalchemy.exc import SQLAlchemyError
18
+ from contextlib import contextmanager
19
+
20
+ from .types import (
21
+ ValidationError,
22
+ DatabaseError,
23
+ SchemaError,
24
+ ToolResponse,
25
+ EmbeddingColumnResponse,
26
+ GenerateEmbeddingsResponse,
27
+ SemanticSearchResponse,
28
+ RelatedContentResponse,
29
+ HybridSearchResponse,
30
+ EmbeddingStatsResponse,
31
+ )
32
+ from .semantic import get_semantic_engine, is_semantic_search_available
33
+
34
+
35
class SQLiteMemoryDatabase:
    """
    SQLAlchemy Core-based database abstraction for SQLite Memory Bank.

    This class handles all database operations with automatic:
    - Connection management (a short-lived connection per operation)
    - Error handling and translation (SQLAlchemyError -> DatabaseError)
    - SQL injection protection (bound parameters for values; validated and
      quoted identifiers wherever a name must be spliced into DDL)
    - Type conversion
    - Transaction management (explicit commit on every write)
    """

    def __init__(self, db_path: str):
        """Initialize the engine and reflect the current schema.

        Args:
            db_path: Path to the SQLite file. It is stored as an absolute path,
                and its parent directory is created if missing.
        """
        self.db_path = os.path.abspath(db_path)

        # Ensure the database directory exists before anything connects.
        os.makedirs(os.path.dirname(self.db_path), exist_ok=True)

        self.engine: Engine = create_engine(f"sqlite:///{self.db_path}", echo=False)
        self.metadata = MetaData()

        # Load the tables that already exist in the file.
        self._refresh_metadata()

    def close(self) -> None:
        """Dispose of the engine and its pooled connections (best effort; errors are logged)."""
        try:
            # __init__ may have failed before the engine was created.
            if hasattr(self, "engine"):
                self.engine.dispose()
        except Exception as e:
            logging.warning(f"Error closing database: {e}")

    def __del__(self):
        """Ensure cleanup when object is garbage collected."""
        self.close()

    # --- Internal helpers ---

    @staticmethod
    def _quote_identifier(name: str) -> str:
        """Return *name* as a double-quoted SQLite identifier.

        Embedded double quotes are doubled, so the result is always exactly one
        identifier token. Quoting also lets names that are SQL keywords (e.g.
        ``order``) work: they pass ``str.isidentifier()`` but fail in SQLite
        when left unquoted.
        """
        return '"' + name.replace('"', '""') + '"'

    @staticmethod
    def _is_text_column(col: Any) -> bool:
        """Return True when a column's declared type contains TEXT or VARCHAR."""
        type_name = str(col.type).upper()
        return "TEXT" in type_name or "VARCHAR" in type_name

    @staticmethod
    def _column_info(col: Any) -> Dict[str, Any]:
        """Describe a reflected column as a JSON-friendly dict.

        Reflection stores a column's DEFAULT on ``server_default`` (the
        Python-side ``default`` is always None for reflected tables), so the
        default is read from there and rendered as its SQL text.
        """
        server_default = col.server_default
        default = None
        if server_default is not None:
            default = str(getattr(server_default, "arg", server_default))
        return {
            "name": col.name,
            "type": str(col.type),
            "nullable": col.nullable,
            "default": default,
            "primary_key": col.primary_key,
        }

    def _refresh_metadata(self) -> None:
        """Re-reflect the schema into ``self.metadata``; failures are logged, not raised."""
        try:
            self.metadata.clear()
            self.metadata.reflect(bind=self.engine)
        except SQLAlchemyError as e:
            logging.warning(f"Failed to refresh metadata: {e}")

    @contextmanager
    def get_connection(self):
        """Yield a connection from the engine and always close it afterwards."""
        conn = self.engine.connect()
        try:
            yield conn
        finally:
            conn.close()

    def _ensure_table_exists(self, table_name: str) -> Table:
        """Return the reflected Table, re-reflecting once if it is not cached.

        Raises:
            ValidationError: if the table still does not exist.
        """
        if table_name not in self.metadata.tables:
            self._refresh_metadata()

        if table_name not in self.metadata.tables:
            raise ValidationError(f"Table '{table_name}' does not exist")

        return self.metadata.tables[table_name]

    def _validate_columns(self, table: Table, column_names: List[str], context: str = "operation") -> None:
        """Raise ValidationError for the first name that is not a column of *table*."""
        valid_columns = {col.name for col in table.columns}
        for col_name in column_names:
            if col_name not in valid_columns:
                raise ValidationError(f"Invalid column '{col_name}' for table " f"'{table.name}' in {context}")

    def _build_where_conditions(self, table: Table, where: Dict[str, Any]) -> List:
        """Turn ``{column: value}`` into a list of equality conditions (None becomes IS NULL)."""
        if not where:
            return []

        self._validate_columns(table, list(where.keys()), "WHERE clause")
        return [table.c[col_name] == value for col_name, value in where.items()]

    def _execute_with_commit(self, stmt) -> Any:
        """Execute *stmt* on a fresh connection and commit.

        The returned result's connection is already closed. Read any cursor
        metadata you need (``rowcount``, ``lastrowid``) inside your own
        ``get_connection()`` block instead of relying on this return value.
        """
        with self.get_connection() as conn:
            result = conn.execute(stmt)
            conn.commit()
            return result

    def _database_operation(self, operation_name: str):
        """Return a decorator that wraps SQLAlchemyError into DatabaseError.

        ValidationError and SchemaError pass through unchanged.
        """

        def decorator(func: Callable) -> Callable:
            @wraps(func)
            def wrapper(*args, **kwargs) -> ToolResponse:
                try:
                    return func(*args, **kwargs)
                except (ValidationError, SchemaError):
                    raise
                except SQLAlchemyError as e:
                    raise DatabaseError(f"Failed to {operation_name}: {str(e)}") from e

            return wrapper

        return decorator

    # --- Table management ---

    def create_table(self, table_name: str, columns: List[Dict[str, str]]) -> ToolResponse:
        """Create a new table; this is a no-op if it already exists.

        Args:
            table_name: Table name; must be a valid identifier.
            columns: Column definitions. Each is a dict with ``name`` (a valid
                identifier) and ``type`` (raw SQLite type/constraint text such
                as ``"INTEGER PRIMARY KEY AUTOINCREMENT"``).

        Returns:
            ``{"success": True}``

        Raises:
            ValidationError: invalid table name, no columns, or a malformed
                column definition.
            DatabaseError: SQLite rejected the statement.
        """
        if not table_name or not table_name.isidentifier():
            raise ValidationError(f"Invalid table name: {table_name}")
        if not columns:
            raise ValidationError("Must provide at least one column")

        for col_def in columns:
            if "name" not in col_def or "type" not in col_def:
                raise ValidationError(f"Column must have 'name' and 'type': {col_def}")
            if not isinstance(col_def["name"], str) or not col_def["name"].isidentifier():
                raise ValidationError(f"Invalid column name: {col_def['name']}")

        # Raw SQL keeps the full SQLite type/constraint syntax available. Names
        # are quoted so SQL keywords work as identifiers. The type text is
        # inserted verbatim because DDL cannot take bound parameters.
        col_defs = ", ".join(f"{self._quote_identifier(col['name'])} {col['type']}" for col in columns)
        sql = f"CREATE TABLE IF NOT EXISTS {self._quote_identifier(table_name)} ({col_defs})"

        try:
            with self.get_connection() as conn:
                conn.execute(text(sql))
                conn.commit()
        except SQLAlchemyError as e:
            raise DatabaseError(f"Failed to create table {table_name}: {str(e)}") from e

        self._refresh_metadata()
        return {"success": True}

    def list_tables(self) -> ToolResponse:
        """List all user-created tables (names starting with ``sqlite_`` are excluded)."""
        try:
            with self.get_connection() as conn:
                inspector = inspect(conn)
                tables = [name for name in inspector.get_table_names() if not name.startswith("sqlite_")]
            return {"success": True, "tables": tables}
        except SQLAlchemyError as e:
            raise DatabaseError(f"Failed to list tables: {str(e)}") from e

    def describe_table(self, table_name: str) -> ToolResponse:
        """Return ``{"success": True, "columns": [...]}`` describing each column of *table_name*.

        Each entry has ``name``, ``type``, ``nullable``, ``default`` (the SQL
        DEFAULT text, or None) and ``primary_key``.

        Raises:
            ValidationError: the table does not exist.
            DatabaseError: on any SQLAlchemy failure.
        """
        try:
            table = self._ensure_table_exists(table_name)
            columns = [self._column_info(col) for col in table.columns]
            return {"success": True, "columns": columns}
        except SQLAlchemyError as e:
            raise DatabaseError(f"Failed to describe table {table_name}: {str(e)}") from e

    def drop_table(self, table_name: str) -> ToolResponse:
        """Drop an existing table.

        Raises:
            ValidationError: the table does not exist.
            DatabaseError: on any SQLAlchemy failure.
        """
        try:
            self._ensure_table_exists(table_name)  # Validates existence
            self._execute_with_commit(text(f"DROP TABLE {self._quote_identifier(table_name)}"))
            self._refresh_metadata()
            return {"success": True}
        except SQLAlchemyError as e:
            raise DatabaseError(f"Failed to drop table {table_name}: {str(e)}") from e

    def rename_table(self, old_name: str, new_name: str) -> ToolResponse:
        """Rename *old_name* to *new_name*.

        Raises:
            ValidationError: the names are identical, *new_name* is not a
                valid identifier, *old_name* does not exist, or *new_name*
                already exists.
            DatabaseError: on any SQLAlchemy failure.
        """
        if old_name == new_name:
            raise ValidationError("Old and new table names are identical")
        # new_name is spliced into ALTER TABLE, so it must be a plain identifier.
        if not isinstance(new_name, str) or not new_name.isidentifier():
            raise ValidationError(f"Invalid table name: {new_name}")

        try:
            self._ensure_table_exists(old_name)  # Validates old table exists

            with self.get_connection() as conn:
                inspector = inspect(conn)
                if new_name in inspector.get_table_names():
                    raise ValidationError(f"Table '{new_name}' already exists")

                conn.execute(
                    text(
                        f"ALTER TABLE {self._quote_identifier(old_name)} "
                        f"RENAME TO {self._quote_identifier(new_name)}"
                    )
                )
                conn.commit()

            self._refresh_metadata()
            return {"success": True}
        except SQLAlchemyError as e:
            raise DatabaseError(f"Failed to rename table from {old_name} to {new_name}: {str(e)}") from e

    # --- Row operations ---

    def insert_row(self, table_name: str, data: Dict[str, Any]) -> ToolResponse:
        """Insert one row and return ``{"success": True, "id": <lastrowid>}``.

        Raises:
            ValidationError: empty *data*, unknown table, or unknown column.
            DatabaseError: on any SQLAlchemy failure.
        """
        if not data:
            raise ValidationError("Data cannot be empty")

        try:
            table = self._ensure_table_exists(table_name)
            self._validate_columns(table, list(data.keys()), "insert operation")

            with self.get_connection() as conn:
                result = conn.execute(insert(table).values(**data))
                conn.commit()
                # Read cursor metadata while the connection is still open.
                row_id = result.lastrowid
            return {"success": True, "id": row_id}
        except SQLAlchemyError as e:
            raise DatabaseError(f"Failed to insert into table {table_name}: {str(e)}") from e

    def read_rows(
        self, table_name: str, where: Optional[Dict[str, Any]] = None, limit: Optional[int] = None
    ) -> ToolResponse:
        """Read rows as dicts, filtered by equality *where* and an optional *limit*.

        A falsy *limit* (None or 0) means no limit.
        """
        try:
            table = self._ensure_table_exists(table_name)
            stmt = select(table)

            conditions = self._build_where_conditions(table, where or {})
            if conditions:
                stmt = stmt.where(and_(*conditions))

            if limit:
                stmt = stmt.limit(limit)

            with self.get_connection() as conn:
                result = conn.execute(stmt)
                rows = [dict(row._mapping) for row in result.fetchall()]

            return {"success": True, "rows": rows}
        except SQLAlchemyError as e:
            raise DatabaseError(f"Failed to read from table {table_name}: {str(e)}") from e

    def update_rows(
        self, table_name: str, data: Dict[str, Any], where: Optional[Dict[str, Any]] = None
    ) -> ToolResponse:
        """Update matching rows; returns ``{"success": True, "rows_affected": n}``.

        Without *where*, every row is updated and a warning is logged.
        """
        if not data:
            raise ValidationError("Update data cannot be empty")

        try:
            table = self._ensure_table_exists(table_name)
            self._validate_columns(table, list(data.keys()), "update operation")

            stmt = update(table).values(**data)

            conditions = self._build_where_conditions(table, where or {})
            if conditions:
                stmt = stmt.where(and_(*conditions))
            else:
                # Same safety net as delete_rows: an unfiltered UPDATE touches every row.
                logging.warning(f"update_rows called without WHERE clause on table {table_name}")

            with self.get_connection() as conn:
                result = conn.execute(stmt)
                conn.commit()
                rows_affected = result.rowcount
            return {"success": True, "rows_affected": rows_affected}
        except SQLAlchemyError as e:
            raise DatabaseError(f"Failed to update table {table_name}: {str(e)}") from e

    def delete_rows(self, table_name: str, where: Optional[Dict[str, Any]] = None) -> ToolResponse:
        """Delete matching rows; returns ``{"success": True, "rows_affected": n}``.

        Without *where*, every row is deleted and a warning is logged.
        """
        try:
            table = self._ensure_table_exists(table_name)
            stmt = delete(table)

            conditions = self._build_where_conditions(table, where or {})
            if conditions:
                stmt = stmt.where(and_(*conditions))
            else:
                logging.warning(f"delete_rows called without WHERE clause on table {table_name}")

            with self.get_connection() as conn:
                result = conn.execute(stmt)
                conn.commit()
                rows_affected = result.rowcount
            return {"success": True, "rows_affected": rows_affected}
        except SQLAlchemyError as e:
            raise DatabaseError(f"Failed to delete from table {table_name}: {str(e)}") from e

    def select_query(
        self,
        table_name: str,
        columns: Optional[List[str]] = None,
        where: Optional[Dict[str, Any]] = None,
        limit: int = 100,
    ) -> ToolResponse:
        """Run a SELECT of *columns* (all if None) with equality *where* and a positive *limit*.

        Raises:
            ValidationError: non-positive limit, unknown table, or unknown column.
            DatabaseError: on any SQLAlchemy failure.
        """
        if limit < 1:
            raise ValidationError("Limit must be a positive integer")

        try:
            table = self._ensure_table_exists(table_name)

            if columns:
                self._validate_columns(table, columns, "SELECT operation")
                stmt = select(*[table.c[col_name] for col_name in columns])
            else:
                stmt = select(table)

            conditions = self._build_where_conditions(table, where or {})
            if conditions:
                stmt = stmt.where(and_(*conditions))

            stmt = stmt.limit(limit)

            with self.get_connection() as conn:
                result = conn.execute(stmt)
                rows = [dict(row._mapping) for row in result.fetchall()]

            return {"success": True, "rows": rows}
        except SQLAlchemyError as e:
            raise DatabaseError(f"Failed to query table {table_name}: {str(e)}") from e

    def list_all_columns(self) -> ToolResponse:
        """Return ``{"success": True, "schemas": {table: [column names]}}`` for every reflected table."""
        try:
            self._refresh_metadata()
            schemas = {
                table_name: [col.name for col in table.columns] for table_name, table in self.metadata.tables.items()
            }
            return {"success": True, "schemas": schemas}
        except SQLAlchemyError as e:
            raise DatabaseError(f"Failed to list all columns: {str(e)}") from e

    def search_content(self, query: str, tables: Optional[List[str]] = None, limit: int = 50) -> ToolResponse:
        """Substring-search TEXT/VARCHAR columns and rank hits by match density.

        *query* is matched literally: ``%`` and ``_`` are escaped rather than
        treated as LIKE wildcards. A row's relevance is the sum, over its
        matching text columns, of (case-insensitive occurrences / content
        length). Results are sorted by relevance and truncated to *limit*.

        Args:
            query: Non-empty search string.
            tables: Tables to search (default: all). Unknown names are skipped.
            limit: Positive maximum number of results (also applied per table).

        Raises:
            ValidationError: empty query or non-positive limit.
            DatabaseError: on any SQLAlchemy failure.
        """
        if not query or not query.strip():
            raise ValidationError("Search query cannot be empty")
        if limit < 1:
            raise ValidationError("Limit must be a positive integer")

        query_lower = query.lower()

        try:
            self._refresh_metadata()
            search_tables = tables or list(self.metadata.tables.keys())
            results = []

            with self.get_connection() as conn:
                for table_name in search_tables:
                    if table_name not in self.metadata.tables:
                        continue

                    table = self.metadata.tables[table_name]
                    text_columns = [col for col in table.columns if self._is_text_column(col)]
                    if not text_columns:
                        continue

                    # autoescape makes '%' and '_' in the user's query match literally.
                    conditions = [col.contains(query, autoescape=True) for col in text_columns]
                    stmt = select(table).where(or_(*conditions)).limit(limit)

                    for row in conn.execute(stmt).fetchall():
                        row_dict = dict(row._mapping)

                        relevance = 0.0
                        matched_content = []
                        for col in text_columns:
                            value = row_dict.get(col.name)
                            if not value:
                                continue
                            content = str(value).lower()
                            if query_lower in content:
                                relevance += content.count(query_lower) / len(content)
                                matched_content.append(f"{col.name}: {value}")

                        if relevance > 0:
                            results.append(
                                {
                                    "table": table_name,
                                    "row_id": row_dict.get("id"),
                                    "row_data": row_dict,
                                    "matched_content": matched_content,
                                    "relevance": round(relevance, 3),
                                }
                            )

            results.sort(key=lambda x: x["relevance"], reverse=True)
            results = results[:limit]

            return {
                "success": True,
                "results": results,
                "query": query,
                "tables_searched": search_tables,
                "total_results": len(results),
            }
        except SQLAlchemyError as e:
            raise DatabaseError(f"Failed to search content: {str(e)}") from e

    def explore_tables(self, pattern: Optional[str] = None, include_row_counts: bool = True) -> ToolResponse:
        """Summarize tables: column info, optional row counts, sample rows, and text previews.

        Args:
            pattern: Optional name filter. All ``%`` characters are removed and
                the remainder must appear as a (case-sensitive) substring of the
                table name; ``_`` is matched literally.
            include_row_counts: When True, add ``row_count`` to each table and
                accumulate ``total_rows``.

        Returns:
            ``{"success": True, "exploration": {"tables": [...],
            "total_tables": n, "total_rows": n}}``. Each table entry may carry
            ``sample_data`` (up to 3 rows) and ``content_preview`` (up to 5
            distinct non-empty values for each of the first 3 text columns).

        Raises:
            DatabaseError: on any SQLAlchemy failure.
        """
        try:
            self._refresh_metadata()
            table_names = list(self.metadata.tables.keys())

            if pattern:
                needle = pattern.replace("%", "")
                table_names = [name for name in table_names if needle in name]

            exploration = {"tables": [], "total_tables": len(table_names), "total_rows": 0}

            with self.get_connection() as conn:
                for table_name in table_names:
                    table = self.metadata.tables[table_name]

                    columns = [self._column_info(col) for col in table.columns]
                    text_columns = [col.name for col in table.columns if self._is_text_column(col)]

                    table_info = {"name": table_name, "columns": columns, "text_columns": text_columns}

                    if include_row_counts:
                        row_count = conn.execute(select(text("COUNT(*)")).select_from(table)).scalar()
                        table_info["row_count"] = row_count
                        exploration["total_rows"] += row_count

                    sample_result = conn.execute(select(table).limit(3))
                    sample_rows = [dict(row._mapping) for row in sample_result.fetchall()]
                    if sample_rows:
                        table_info["sample_data"] = sample_rows

                    content_preview = {}
                    for col_name in text_columns[:3]:  # Limit to first 3 text columns
                        col = table.c[col_name]
                        preview_result = conn.execute(select(col).distinct().where(col.isnot(None)).limit(5))
                        unique_values = [row[0] for row in preview_result.fetchall() if row[0]]
                        if unique_values:
                            content_preview[col_name] = unique_values

                    if content_preview:
                        table_info["content_preview"] = content_preview

                    exploration["tables"].append(table_info)

            return {"success": True, "exploration": exploration}
        except SQLAlchemyError as e:
            raise DatabaseError(f"Failed to explore tables: {str(e)}") from e

    # --- Semantic Search Methods ---

    def add_embedding_column(self, table_name: str, embedding_column: str = "embedding") -> EmbeddingColumnResponse:
        """Add a TEXT column that stores JSON-encoded embeddings (no-op if present).

        Raises:
            ValidationError: unknown table, or *embedding_column* is not a
                valid identifier (it is spliced into ALTER TABLE).
            DatabaseError: on any SQLAlchemy failure.
        """
        try:
            table = self._ensure_table_exists(table_name)

            if embedding_column in [col.name for col in table.columns]:
                return {"success": True, "message": f"Embedding column '{embedding_column}' already exists"}

            if not isinstance(embedding_column, str) or not embedding_column.isidentifier():
                raise ValidationError(f"Invalid column name: {embedding_column}")

            with self.get_connection() as conn:
                conn.execute(
                    text(
                        f"ALTER TABLE {self._quote_identifier(table_name)} "
                        f"ADD COLUMN {self._quote_identifier(embedding_column)} TEXT"
                    )
                )
                conn.commit()

            self._refresh_metadata()
            return {"success": True, "message": f"Added embedding column '{embedding_column}' to table '{table_name}'"}

        except SQLAlchemyError as e:
            raise DatabaseError(f"Failed to add embedding column: {str(e)}") from e

    def generate_embeddings(
        self,
        table_name: str,
        text_columns: List[str],
        embedding_column: str = "embedding",
        model_name: str = "all-MiniLM-L6-v2",
        batch_size: int = 50,
    ) -> GenerateEmbeddingsResponse:
        """Fill *embedding_column* for rows whose embedding is NULL, "" or "null".

        For each such row, the non-empty values of *text_columns* are joined with
        spaces, embedded, and written back as JSON by ``id``. Rows with no text
        are left untouched. One commit is made per *batch_size* rows.

        Raises:
            ValidationError: semantic search unavailable, non-positive
                *batch_size*, unknown table/column, or no ``id`` column.
            DatabaseError: on any SQLAlchemy failure.
        """
        if not is_semantic_search_available():
            raise ValidationError("Semantic search is not available. Please install sentence-transformers.")
        if batch_size < 1:
            # range(0, n, 0) would raise a bare ValueError.
            raise ValidationError("Batch size must be a positive integer")

        try:
            table = self._ensure_table_exists(table_name)
            semantic_engine = get_semantic_engine(model_name)

            table_columns = [col.name for col in table.columns]
            for col in text_columns:
                if col not in table_columns:
                    raise ValidationError(f"Column '{col}' not found in table '{table_name}'")
            # Embeddings are written back with WHERE id = ..., so check before
            # altering the table instead of failing later with a bare KeyError.
            if "id" not in table_columns:
                raise ValidationError(f"Table '{table_name}' must have an 'id' column to store embeddings")

            if embedding_column not in table_columns:
                self.add_embedding_column(table_name, embedding_column)
                table = self._ensure_table_exists(table_name)  # Refresh

            with self.get_connection() as conn:
                emb_col = table.c[embedding_column]
                stmt = select(table).where(or_(emb_col.is_(None), emb_col == "", emb_col == "null"))
                rows = conn.execute(stmt).fetchall()

                if not rows:
                    embedding_dim = semantic_engine.get_embedding_dimensions() or 0
                    return {
                        "success": True,
                        "message": "All rows already have embeddings",
                        "processed": 0,
                        "model": model_name,
                        "embedding_dimension": embedding_dim,
                    }

                processed = 0
                for i in range(0, len(rows), batch_size):
                    for row in rows[i : i + batch_size]:
                        row_dict = dict(row._mapping)

                        text_parts = [str(row_dict[col]) for col in text_columns if row_dict.get(col)]
                        if not text_parts:
                            continue

                        embedding = semantic_engine.generate_embedding(" ".join(text_parts))
                        update_stmt = (
                            update(table)
                            .where(table.c["id"] == row_dict["id"])
                            .values({embedding_column: json.dumps(embedding)})
                        )
                        conn.execute(update_stmt)
                        processed += 1

                    conn.commit()
                    logging.info(f"Generated embeddings for batch {i//batch_size + 1}, processed {processed} rows")

            return {
                "success": True,
                "message": f"Generated embeddings for {processed} rows",
                "processed": processed,
                "model": model_name,
                "embedding_dimension": semantic_engine.get_embedding_dimensions() or 0,
            }

        except SQLAlchemyError as e:
            raise DatabaseError(f"Failed to generate embeddings: {str(e)}") from e

    def semantic_search(
        self,
        query: str,
        tables: Optional[List[str]] = None,
        embedding_column: str = "embedding",
        text_columns: Optional[List[str]] = None,
        similarity_threshold: float = 0.5,
        limit: int = 10,
        model_name: str = "all-MiniLM-L6-v2",
    ) -> SemanticSearchResponse:
        """Rank rows across *tables* by embedding similarity to *query*.

        Tables without *embedding_column* are skipped with a warning. Each
        table contributes up to ``limit * 2`` candidates from the semantic
        engine. These are merged, sorted by ``similarity_score`` and truncated
        to *limit*; each result is tagged with ``table_name``.

        Raises:
            ValidationError: semantic search unavailable or empty query.
            DatabaseError: on any SQLAlchemy failure.
        """
        if not is_semantic_search_available():
            raise ValidationError("Semantic search is not available. Please install sentence-transformers.")

        if not query or not query.strip():
            raise ValidationError("Search query cannot be empty")

        try:
            self._refresh_metadata()
            search_tables = tables or list(self.metadata.tables.keys())
            semantic_engine = get_semantic_engine(model_name)

            all_results = []

            with self.get_connection() as conn:
                for table_name in search_tables:
                    if table_name not in self.metadata.tables:
                        continue

                    table = self.metadata.tables[table_name]

                    if embedding_column not in [col.name for col in table.columns]:
                        logging.warning(f"Table '{table_name}' does not have embedding column '{embedding_column}'")
                        continue

                    emb_col = table.c[embedding_column]
                    stmt = select(table).where(and_(emb_col.isnot(None), emb_col != "", emb_col != "null"))
                    rows = conn.execute(stmt).fetchall()
                    if not rows:
                        continue

                    content_data = [dict(row._mapping) for row in rows]

                    # Columns used by the engine for highlighting.
                    if text_columns is None:
                        text_cols = [col.name for col in table.columns if self._is_text_column(col)]
                    else:
                        text_cols = text_columns

                    table_results = semantic_engine.semantic_search(
                        query,
                        content_data,
                        embedding_column,
                        text_cols,
                        similarity_threshold,
                        limit * 2,  # Get more for global ranking
                    )

                    for result in table_results:
                        result["table_name"] = table_name

                    all_results.extend(table_results)

            all_results.sort(key=lambda x: x.get("similarity_score", 0), reverse=True)
            final_results = all_results[:limit]

            return {
                "success": True,
                "results": final_results,
                "query": query,
                "tables_searched": search_tables,
                "total_results": len(final_results),
                "model": model_name,
                "similarity_threshold": similarity_threshold,
            }

        except SQLAlchemyError as e:
            raise DatabaseError(f"Semantic search failed: {str(e)}") from e

    def find_related_content(
        self,
        table_name: str,
        row_id: int,
        embedding_column: str = "embedding",
        similarity_threshold: float = 0.5,
        limit: int = 5,
        model_name: str = "all-MiniLM-L6-v2",
    ) -> RelatedContentResponse:
        """Find rows in *table_name* whose embeddings are similar to row *row_id*'s.

        Candidate rows with unparseable embeddings are skipped.

        Raises:
            ValidationError: semantic search unavailable, unknown table, no
                ``id`` column, row not found, or the row has a missing or
                invalid embedding.
            DatabaseError: on any SQLAlchemy failure.
        """
        if not is_semantic_search_available():
            raise ValidationError("Semantic search is not available. Please install sentence-transformers.")

        try:
            table = self._ensure_table_exists(table_name)
            # Rows are addressed by id; report a clear error instead of a KeyError.
            self._validate_columns(table, ["id"], "related content lookup")
            semantic_engine = get_semantic_engine(model_name)

            with self.get_connection() as conn:
                target_row = conn.execute(select(table).where(table.c["id"] == row_id)).fetchone()
                if not target_row:
                    raise ValidationError(f"Row with id {row_id} not found in table '{table_name}'")

                target_dict = dict(target_row._mapping)

                if (
                    embedding_column not in target_dict
                    or not target_dict[embedding_column]
                    or target_dict[embedding_column] in ["", "null"]
                ):
                    raise ValidationError(f"Row {row_id} does not have an embedding")

                try:
                    target_embedding = json.loads(target_dict[embedding_column])
                except (json.JSONDecodeError, TypeError) as e:
                    raise ValidationError(f"Row {row_id} has an invalid embedding") from e

                emb_col = table.c[embedding_column]
                stmt = select(table).where(
                    and_(
                        table.c["id"] != row_id,
                        emb_col.isnot(None),
                        emb_col != "",
                        emb_col != "null",
                    )
                )
                rows = conn.execute(stmt).fetchall()

            if not rows:
                return {
                    "success": True,
                    "results": [],
                    "target_row": target_dict,
                    "total_results": 0,
                    "similarity_threshold": similarity_threshold,
                    "model": model_name,
                    "message": "No other rows with embeddings found",
                }

            content_data = [dict(row._mapping) for row in rows]
            candidate_embeddings = []
            valid_indices = []  # maps candidate position -> index in content_data

            for idx, row_dict in enumerate(content_data):
                try:
                    candidate_embeddings.append(json.loads(row_dict[embedding_column]))
                except (json.JSONDecodeError, TypeError):
                    continue
                valid_indices.append(idx)

            if not candidate_embeddings:
                return {
                    "success": True,
                    "results": [],
                    "target_row": target_dict,
                    "total_results": 0,
                    "similarity_threshold": similarity_threshold,
                    "model": model_name,
                    "message": "No valid embeddings found for comparison",
                }

            similar_indices = semantic_engine.find_similar_embeddings(
                target_embedding, candidate_embeddings, similarity_threshold, limit
            )

            results = []
            for candidate_idx, similarity_score in similar_indices:
                row_dict = content_data[valid_indices[candidate_idx]].copy()
                row_dict["similarity_score"] = round(similarity_score, 3)
                results.append(row_dict)

            return {
                "success": True,
                "results": results,
                "target_row": target_dict,
                "total_results": len(results),
                "similarity_threshold": similarity_threshold,
                "model": model_name,
                "message": f"Found {len(results)} related items",
            }

        except SQLAlchemyError as e:
            raise DatabaseError(f"Failed to find related content: {str(e)}") from e

    def hybrid_search(
        self,
        query: str,
        tables: Optional[List[str]] = None,
        text_columns: Optional[List[str]] = None,
        embedding_column: str = "embedding",
        semantic_weight: float = 0.7,
        text_weight: float = 0.3,
        limit: int = 10,
        model_name: str = "all-MiniLM-L6-v2",
    ) -> HybridSearchResponse:
        """Combine semantic similarity with keyword matching.

        Falls back to :meth:`search_content` in two cases: when semantic search
        is unavailable (``search_type="text_only"``) and when the semantic pass
        (threshold 0.3) finds nothing (``search_type="text_fallback"``).
        Otherwise the semantic engine re-ranks the semantic hits with
        *semantic_weight* and *text_weight* (``search_type="hybrid"``).

        Raises:
            ValidationError: propagated from the underlying searches.
            DatabaseError: on any SQLAlchemy failure.
        """
        if not is_semantic_search_available():
            fallback_result = self.search_content(query, tables, limit)
            return cast(
                HybridSearchResponse,
                {
                    **fallback_result,
                    "search_type": "text_only",
                    "semantic_weight": 0.0,
                    "text_weight": 1.0,
                    "model": "none",
                },
            )

        try:
            semantic_response = self.semantic_search(
                query,
                tables,
                embedding_column,
                text_columns,
                similarity_threshold=0.3,
                limit=limit * 2,
                model_name=model_name,
            )

            if not semantic_response.get("success"):
                return cast(
                    HybridSearchResponse,
                    {
                        **semantic_response,
                        "search_type": "semantic_failed",
                        "semantic_weight": semantic_weight,
                        "text_weight": text_weight,
                        "model": model_name,
                    },
                )

            semantic_results = semantic_response.get("results", [])

            if not semantic_results:
                fallback_result = self.search_content(query, tables, limit)
                return cast(
                    HybridSearchResponse,
                    {
                        **fallback_result,
                        "search_type": "text_fallback",
                        "semantic_weight": semantic_weight,
                        "text_weight": text_weight,
                        "model": model_name,
                    },
                )

            semantic_engine = get_semantic_engine(model_name)
            enhanced_results = semantic_engine.hybrid_search(
                query, semantic_results, text_columns or [], embedding_column, semantic_weight, text_weight, limit
            )

            return {
                "success": True,
                "results": enhanced_results,
                "query": query,
                "search_type": "hybrid",
                "semantic_weight": semantic_weight,
                "text_weight": text_weight,
                "total_results": len(enhanced_results),
                "model": model_name,
            }

        except SQLAlchemyError as e:
            raise DatabaseError(f"Hybrid search failed: {str(e)}") from e

    def get_embedding_stats(self, table_name: str, embedding_column: str = "embedding") -> EmbeddingStatsResponse:
        """Report how many rows have embeddings and the dimension of one sample.

        A row counts as embedded when its value is not NULL, "" or "null".
        ``embedding_dimensions`` is None if no sample parses as a JSON array.

        Raises:
            ValidationError: unknown table or unknown *embedding_column*.
            DatabaseError: on any SQLAlchemy failure.
        """
        try:
            table = self._ensure_table_exists(table_name)
            # Report a clear error instead of a bare KeyError for a missing column.
            self._validate_columns(table, [embedding_column], "embedding stats")
            emb_col = table.c[embedding_column]
            has_embedding = and_(emb_col.isnot(None), emb_col != "", emb_col != "null")

            with self.get_connection() as conn:
                total_count = conn.execute(select(text("COUNT(*)")).select_from(table)).scalar() or 0

                embedded_count = (
                    conn.execute(select(text("COUNT(*)")).select_from(table).where(has_embedding)).scalar() or 0
                )

                sample_result = conn.execute(select(emb_col).where(has_embedding).limit(1)).fetchone()

            dimensions = None
            if sample_result and sample_result[0]:
                try:
                    dimensions = len(json.loads(sample_result[0]))
                except (json.JSONDecodeError, TypeError):
                    pass  # Unparseable sample: dimensions stay unknown.

            coverage_percent = (embedded_count / total_count * 100) if total_count > 0 else 0.0

            return {
                "success": True,
                "table_name": table_name,
                "total_rows": total_count,
                "embedded_rows": embedded_count,
                "coverage_percent": round(coverage_percent, 1),
                "embedding_dimensions": dimensions,
                "embedding_column": embedding_column,
            }

        except SQLAlchemyError as e:
            raise DatabaseError(f"Failed to get embedding stats: {str(e)}") from e
993
+
994
+
995
# Global database instance: the shared SQLiteMemoryDatabase handed out (and
# replaced when the path changes) by get_database(). None until first created.
_db_instance: Optional[SQLiteMemoryDatabase] = None
997
+
998
+
999
def get_database(db_path: Optional[str] = None) -> SQLiteMemoryDatabase:
    """Get or create the global database instance.

    Args:
        db_path: Explicit database path. If given and it resolves to a different
            file than the current instance, the current instance is closed and
            replaced. If omitted, an existing instance is reused. Otherwise a new
            one is opened at ``$DB_PATH``, or ``./test.db`` when unset.

    Returns:
        The shared SQLiteMemoryDatabase instance.
    """
    global _db_instance

    # Instances store an absolute path, so compare normalized paths. Comparing
    # the raw argument would never match for relative paths and would rebuild
    # the instance on every call.
    if _db_instance is not None and (not db_path or os.path.abspath(db_path) == _db_instance.db_path):
        return _db_instance

    # Close the previous instance before replacing it.
    if _db_instance is not None:
        _db_instance.close()

    actual_path = db_path or os.environ.get("DB_PATH", "./test.db")
    _db_instance = SQLiteMemoryDatabase(actual_path)
    return _db_instance