mcp-sqlite-memory-bank 1.2.4__py3-none-any.whl → 1.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,923 @@
1
+ """
2
+ Database abstraction layer for SQLite Memory Bank using SQLAlchemy Core.
3
+
4
+ This module provides a clean, type-safe abstraction over SQLite operations,
5
+ dramatically simplifying the complex database logic in server.py.
6
+
7
+ Author: Robert Meisner
8
+ """
9
+
10
+ import os
11
+ import json
12
+ import logging
13
+ from functools import wraps
14
+ from typing import Dict, List, Any, Optional, Callable, cast
15
+ from sqlalchemy import create_engine, MetaData, Table, select, insert, update, delete, text, inspect, and_, or_
16
+ from sqlalchemy.engine import Engine
17
+ from sqlalchemy.exc import SQLAlchemyError
18
+ from contextlib import contextmanager
19
+
20
+ from .types import ValidationError, DatabaseError, SchemaError, ToolResponse, EmbeddingColumnResponse, GenerateEmbeddingsResponse, SemanticSearchResponse, RelatedContentResponse, HybridSearchResponse, EmbeddingStatsResponse
21
+ from .semantic import get_semantic_engine, is_semantic_search_available
22
+
23
+
24
+ class SQLiteMemoryDatabase:
25
+ """
26
+ SQLAlchemy Core-based database abstraction for SQLite Memory Bank.
27
+
28
+ This class handles all database operations with automatic:
29
+ - Connection management
30
+ - Error handling and translation
31
+ - SQL injection protection
32
+ - Type conversion
33
+ - Transaction management
34
+ """
35
+
36
def __init__(self, db_path: str):
    """Bind this instance to the SQLite file located at ``db_path``.

    The path is stored in absolute form, an engine and an empty
    ``MetaData`` registry are created, the parent directory of the
    database file is created if missing, and finally the existing schema
    is reflected into the registry.

    Args:
        db_path: Location of the SQLite database file (relative or absolute).
    """
    absolute_path = os.path.abspath(db_path)
    self.db_path = absolute_path

    engine_url = f"sqlite:///{absolute_path}"
    self.engine: Engine = create_engine(engine_url, echo=False)
    self.metadata = MetaData()

    # The directory that will hold the database file must exist beforehand.
    parent_directory = os.path.dirname(absolute_path)
    os.makedirs(parent_directory, exist_ok=True)

    # Load whatever schema is already present in the file.
    self._refresh_metadata()
47
+
48
def close(self) -> None:
    """Dispose of the engine, if this object has one.

    Any exception raised while disposing is logged as a warning and
    swallowed, so this is safe to call during teardown.
    """
    absent = object()
    try:
        engine = getattr(self, "engine", absent)
        if engine is not absent:
            engine.dispose()
    except Exception as e:
        logging.warning(f"Error closing database: {e}")
55
+
56
+ def __del__(self):
57
+ """Ensure cleanup when object is garbage collected."""
58
+ self.close()
59
+
60
def _refresh_metadata(self) -> None:
    """Discard the cached schema and reflect it again from the database.

    A ``SQLAlchemyError`` raised while clearing or reflecting is logged
    as a warning and not propagated.
    """
    registry = self.metadata
    try:
        registry.clear()
        registry.reflect(bind=self.engine)
    except SQLAlchemyError as e:
        logging.warning(f"Failed to refresh metadata: {e}")
67
+
68
@contextmanager
def get_connection(self):
    """Context manager yielding a connection opened from the engine.

    The connection is closed when the ``with`` block exits, whether it
    finishes normally or raises.
    """
    connection = self.engine.connect()
    try:
        yield connection
    finally:
        # Always close the connection, even if the caller's block failed.
        connection.close()
76
+
77
def _ensure_table_exists(self, table_name: str) -> Table:
    """Look up the reflected ``Table`` object for ``table_name``.

    If the table is not in the cached metadata, the metadata is refreshed
    once before giving up.

    Raises:
        ValidationError: If the table is still unknown after the refresh.
    """
    known_tables = self.metadata.tables
    if table_name not in known_tables:
        # Possibly stale cache: reflect again and re-check.
        self._refresh_metadata()
        known_tables = self.metadata.tables
        if table_name not in known_tables:
            raise ValidationError(f"Table '{table_name}' does not exist")

    return known_tables[table_name]
88
+
89
def _validate_columns(self, table: Table, column_names: List[str], context: str = "operation") -> None:
    """Check that each name in ``column_names`` is a column of ``table``.

    Args:
        table: Table whose columns define the allowed names.
        column_names: Names to check, in order.
        context: Short description of the caller, used in the error text.

    Raises:
        ValidationError: For the first name that is not a column of ``table``.
    """
    known_names = {column.name for column in table.columns}
    for candidate in column_names:
        if candidate in known_names:
            continue
        raise ValidationError(f"Invalid column '{candidate}' for table '{table.name}' in {context}")
95
+
96
+ def _build_where_conditions(self, table: Table, where: Dict[str, Any]) -> List:
97
+ """Build SQLAlchemy WHERE conditions from a dictionary."""
98
+ if not where:
99
+ return []
100
+
101
+ self._validate_columns(table, list(where.keys()), "WHERE clause")
102
+ return [table.c[col_name] == value for col_name, value in where.items()]
103
+
104
def _execute_with_commit(self, stmt) -> Any:
    """Run ``stmt`` on a fresh connection, commit, and return the result.

    The connection is obtained from ``get_connection()`` and is therefore
    closed before the result reaches the caller.
    """
    with self.get_connection() as connection:
        outcome = connection.execute(stmt)
        connection.commit()
        return outcome
110
+
111
+ def _database_operation(self, operation_name: str):
112
+ """Decorator for database operations with standardized error handling."""
113
+
114
+ def decorator(func: Callable) -> Callable:
115
+ @wraps(func)
116
+ def wrapper(*args, **kwargs) -> ToolResponse:
117
+ try:
118
+ return func(*args, **kwargs)
119
+ except (ValidationError, SchemaError) as e:
120
+ raise e
121
+ except SQLAlchemyError as e:
122
+ raise DatabaseError(f"Failed to {operation_name}: {str(e)}")
123
+
124
+ return wrapper
125
+
126
+ return decorator
127
+
128
def create_table(self, table_name: str, columns: List[Dict[str, str]]) -> ToolResponse:
    """Create a table (if it does not already exist) with the given columns.

    Args:
        table_name: Name of the table; must be a string for which
            ``str.isidentifier()`` is true.
        columns: One dict per column, each with a ``"name"`` (identifier)
            and a ``"type"`` (SQLite type/constraint text, passed through
            verbatim so the full SQLite type syntax stays available).

    Returns:
        ``{"success": True}``.

    Raises:
        ValidationError: On an invalid table name, an empty column list,
            a column missing ``name``/``type``, or an invalid column name.
        DatabaseError: If SQLite rejects the statement.
    """
    # isinstance guards keep non-string input a ValidationError rather
    # than an AttributeError from calling .isidentifier() on it.
    if not isinstance(table_name, str) or not table_name.isidentifier():
        raise ValidationError(f"Invalid table name: {table_name}")
    if not columns:
        raise ValidationError("Must provide at least one column")

    column_clauses = []
    for col_def in columns:
        if "name" not in col_def or "type" not in col_def:
            raise ValidationError(f"Column must have 'name' and 'type': {col_def}")
        col_name = col_def["name"]
        if not isinstance(col_name, str) or not col_name.isidentifier():
            raise ValidationError(f"Invalid column name: {col_name}")
        # Double-quote the identifier so names that are SQL keywords
        # (e.g. "order", "group") are accepted. An identifier cannot
        # contain a quote character, so no escaping is needed.
        column_clauses.append(f'"{col_name}" {col_def["type"]}')

    try:
        # Raw SQL is used for full SQLite type support.
        col_defs = ", ".join(column_clauses)
        sql = f'CREATE TABLE IF NOT EXISTS "{table_name}" ({col_defs})'

        with self.get_connection() as conn:
            conn.execute(text(sql))
            conn.commit()

        self._refresh_metadata()
        return {"success": True}

    except SQLAlchemyError as e:
        raise DatabaseError(f"Failed to create table {table_name}: {str(e)}") from e
157
+
158
def list_tables(self) -> ToolResponse:
    """Return the names of all tables except SQLite's internal ones.

    Returns:
        ``{"success": True, "tables": [...]}`` where names starting with
        ``sqlite_`` have been filtered out.

    Raises:
        DatabaseError: If SQLAlchemy fails while inspecting the database.
    """
    try:
        with self.get_connection() as connection:
            every_name = inspect(connection).get_table_names()

        user_tables = []
        for candidate in every_name:
            if candidate.startswith("sqlite_"):
                continue
            user_tables.append(candidate)
        return {"success": True, "tables": user_tables}
    except SQLAlchemyError as e:
        raise DatabaseError(f"Failed to list tables: {str(e)}")
167
+
168
def describe_table(self, table_name: str) -> ToolResponse:
    """Report the columns of ``table_name`` as seen by reflected metadata.

    Returns:
        ``{"success": True, "columns": [...]}`` with one dict per column
        holding ``name``, ``type`` (as a string), ``nullable``,
        ``default`` and ``primary_key``.

    Raises:
        ValidationError: If the table does not exist.
        DatabaseError: On any SQLAlchemy failure.
    """
    try:
        table = self._ensure_table_exists(table_name)

        described = []
        for column in table.columns:
            described.append(
                {
                    "name": column.name,
                    "type": str(column.type),
                    "nullable": column.nullable,
                    "default": column.default,
                    "primary_key": column.primary_key,
                }
            )
        return {"success": True, "columns": described}
    except ValidationError:
        raise
    except SQLAlchemyError as e:
        raise DatabaseError(f"Failed to describe table {table_name}: {str(e)}")
187
+
188
def drop_table(self, table_name: str) -> ToolResponse:
    """Drop an existing table.

    Args:
        table_name: Name of the table to drop; it must exist.

    Returns:
        ``{"success": True}``.

    Raises:
        ValidationError: If the table does not exist.
        DatabaseError: If SQLite rejects the statement.
    """
    try:
        self._ensure_table_exists(table_name)  # Validates existence

        # Quote the identifier (doubling embedded quotes) so tables whose
        # names are SQL keywords or contain unusual characters can still
        # be dropped, and so the name cannot alter the statement.
        quoted_name = '"' + table_name.replace('"', '""') + '"'
        self._execute_with_commit(text(f"DROP TABLE {quoted_name}"))

        self._refresh_metadata()
        return {"success": True}
    except ValidationError:
        raise
    except SQLAlchemyError as e:
        raise DatabaseError(f"Failed to drop table {table_name}: {str(e)}") from e
199
+
200
def rename_table(self, old_name: str, new_name: str) -> ToolResponse:
    """Rename ``old_name`` to ``new_name``.

    Args:
        old_name: Existing table name.
        new_name: Desired name; must be a valid identifier (the same rule
            ``create_table`` applies) and must not already be in use.

    Returns:
        ``{"success": True}``.

    Raises:
        ValidationError: If the names are identical, ``new_name`` is not
            a valid identifier, ``old_name`` does not exist, or
            ``new_name`` is already taken.
        DatabaseError: If SQLite rejects the statement.
    """
    if old_name == new_name:
        raise ValidationError("Old and new table names are identical")
    # new_name ends up in raw SQL, so it must be validated like any other
    # caller-supplied table name.
    if not isinstance(new_name, str) or not new_name.isidentifier():
        raise ValidationError(f"Invalid table name: {new_name}")

    try:
        self._ensure_table_exists(old_name)  # Validates old table exists

        # old_name may predate this module's naming rules, so escape it.
        quoted_old = '"' + old_name.replace('"', '""') + '"'
        quoted_new = f'"{new_name}"'

        with self.get_connection() as conn:
            # Check if new name already exists
            if new_name in inspect(conn).get_table_names():
                raise ValidationError(f"Table '{new_name}' already exists")

            conn.execute(text(f"ALTER TABLE {quoted_old} RENAME TO {quoted_new}"))
            conn.commit()

        self._refresh_metadata()
        return {"success": True}
    except ValidationError:
        raise
    except SQLAlchemyError as e:
        raise DatabaseError(f"Failed to rename table from {old_name} to {new_name}: {str(e)}") from e
223
+
224
def insert_row(self, table_name: str, data: Dict[str, Any]) -> ToolResponse:
    """Insert one row built from ``data`` into ``table_name``.

    Returns:
        ``{"success": True, "id": <lastrowid of the insert>}``.

    Raises:
        ValidationError: If ``data`` is empty, the table does not exist,
            or a key is not a column of the table.
        DatabaseError: On any SQLAlchemy failure.
    """
    if not data:
        raise ValidationError("Data cannot be empty")

    try:
        target = self._ensure_table_exists(table_name)
        self._validate_columns(target, list(data), "insert operation")

        statement = insert(target).values(**data)
        outcome = self._execute_with_commit(statement)
        return {"success": True, "id": outcome.lastrowid}
    except ValidationError:
        raise
    except SQLAlchemyError as e:
        raise DatabaseError(f"Failed to insert into table {table_name}: {str(e)}")
239
+
240
def read_rows(self, table_name: str, where: Optional[Dict[str, Any]] = None, limit: Optional[int] = None) -> ToolResponse:
    """Fetch rows from ``table_name``, optionally filtered and capped.

    Args:
        table_name: Table to read from.
        where: Optional ``{column: value}`` equality filters, AND-ed.
        limit: Optional maximum row count; a falsy value (``None`` or 0)
            means no LIMIT is applied.

    Returns:
        ``{"success": True, "rows": [dict, ...]}``.

    Raises:
        ValidationError: If the table or a filter column does not exist.
        DatabaseError: On any SQLAlchemy failure.
    """
    try:
        source = self._ensure_table_exists(table_name)
        query = select(source)

        filters = self._build_where_conditions(source, where or {})
        if filters:
            query = query.where(and_(*filters))

        if limit:
            query = query.limit(limit)

        with self.get_connection() as connection:
            fetched = connection.execute(query).fetchall()
            rows = []
            for record in fetched:
                rows.append(dict(record._mapping))

        return {"success": True, "rows": rows}
    except ValidationError:
        raise
    except SQLAlchemyError as e:
        raise DatabaseError(f"Failed to read from table {table_name}: {str(e)}")
264
+
265
def update_rows(self, table_name: str, data: Dict[str, Any], where: Optional[Dict[str, Any]] = None) -> ToolResponse:
    """Update rows of ``table_name`` with the values in ``data``.

    Args:
        table_name: Table to update.
        data: ``{column: new_value}`` assignments; must not be empty.
        where: Optional ``{column: value}`` equality filters, AND-ed.
            When omitted or empty, every row is updated and a warning is
            logged (the same safeguard ``delete_rows`` applies).

    Returns:
        ``{"success": True, "rows_affected": <rowcount>}``.

    Raises:
        ValidationError: If ``data`` is empty, or the table or a
            referenced column does not exist.
        DatabaseError: On any SQLAlchemy failure.
    """
    if not data:
        raise ValidationError("Update data cannot be empty")

    try:
        table = self._ensure_table_exists(table_name)
        self._validate_columns(table, list(data.keys()), "update operation")

        stmt = update(table).values(**data)

        # Apply WHERE conditions
        conditions = self._build_where_conditions(table, where or {})
        if conditions:
            stmt = stmt.where(and_(*conditions))
        else:
            # An unfiltered UPDATE rewrites the whole table; make it visible.
            logging.warning(f"update_rows called without WHERE clause on table {table_name}")

        result = self._execute_with_commit(stmt)
        return {"success": True, "rows_affected": result.rowcount}
    except ValidationError:
        raise
    except SQLAlchemyError as e:
        raise DatabaseError(f"Failed to update table {table_name}: {str(e)}") from e
287
+
288
def delete_rows(self, table_name: str, where: Optional[Dict[str, Any]] = None) -> ToolResponse:
    """Delete rows from ``table_name``.

    Args:
        table_name: Table to delete from.
        where: Optional ``{column: value}`` equality filters, AND-ed.
            Without filters every row is deleted and a warning is logged.

    Returns:
        ``{"success": True, "rows_affected": <rowcount>}``.

    Raises:
        ValidationError: If the table or a filter column does not exist.
        DatabaseError: On any SQLAlchemy failure.
    """
    try:
        target = self._ensure_table_exists(table_name)
        statement = delete(target)

        filters = self._build_where_conditions(target, where or {})
        if not filters:
            logging.warning(f"delete_rows called without WHERE clause on table {table_name}")
        else:
            statement = statement.where(and_(*filters))

        outcome = self._execute_with_commit(statement)
        return {"success": True, "rows_affected": outcome.rowcount}
    except ValidationError:
        raise
    except SQLAlchemyError as e:
        raise DatabaseError(f"Failed to delete from table {table_name}: {str(e)}")
307
+
308
def select_query(
    self, table_name: str, columns: Optional[List[str]] = None, where: Optional[Dict[str, Any]] = None, limit: int = 100
) -> ToolResponse:
    """Select specific columns (or all of them) from ``table_name``.

    Args:
        table_name: Table to query.
        columns: Column names to return; a falsy value selects every column.
        where: Optional ``{column: value}`` equality filters, AND-ed.
        limit: Maximum number of rows; must be at least 1.

    Returns:
        ``{"success": True, "rows": [dict, ...]}``.

    Raises:
        ValidationError: If ``limit`` is below 1, or the table or a
            referenced column does not exist.
        DatabaseError: On any SQLAlchemy failure.
    """
    if limit < 1:
        raise ValidationError("Limit must be a positive integer")

    try:
        source = self._ensure_table_exists(table_name)

        if not columns:
            query = select(source)
        else:
            self._validate_columns(source, columns, "SELECT operation")
            projection = [source.c[name] for name in columns]
            query = select(*projection)

        filters = self._build_where_conditions(source, where or {})
        if filters:
            query = query.where(and_(*filters))

        query = query.limit(limit)

        with self.get_connection() as connection:
            fetched = connection.execute(query).fetchall()
            rows = []
            for record in fetched:
                rows.append(dict(record._mapping))

        return {"success": True, "rows": rows}
    except ValidationError:
        raise
    except SQLAlchemyError as e:
        raise DatabaseError(f"Failed to query table {table_name}: {str(e)}")
342
+
343
def list_all_columns(self) -> ToolResponse:
    """Map every reflected table name to the list of its column names.

    The metadata is refreshed first so the answer reflects the current
    schema.

    Returns:
        ``{"success": True, "schemas": {table_name: [column, ...]}}``.

    Raises:
        DatabaseError: On any SQLAlchemy failure.
    """
    try:
        self._refresh_metadata()

        schemas = {}
        for name, table in self.metadata.tables.items():
            schemas[name] = [column.name for column in table.columns]
        return {"success": True, "schemas": schemas}
    except SQLAlchemyError as e:
        raise DatabaseError(f"Failed to list all columns: {str(e)}")
351
+
352
def search_content(self, query: str, tables: Optional[List[str]] = None, limit: int = 50) -> ToolResponse:
    """Substring search over the text columns of one or more tables.

    A column counts as text when its type name contains ``TEXT`` or
    ``VARCHAR``. Each matching row gets a relevance score: the sum, over
    matching columns, of (occurrences of the query / length of the
    content), compared case-insensitively.

    Args:
        query: Text to look for; must not be blank. ``%`` and ``_`` are
            matched literally, not as LIKE wildcards.
        tables: Tables to search; defaults to every reflected table.
            Unknown names are skipped.
        limit: Maximum rows fetched per table and maximum results
            returned overall; must be at least 1.

    Returns:
        Dict with ``success``, ``results`` (sorted by descending
        relevance), ``query``, ``tables_searched`` and ``total_results``.

    Raises:
        ValidationError: If ``query`` is blank or ``limit`` is below 1.
        DatabaseError: On any SQLAlchemy failure.
    """
    if not query or not query.strip():
        raise ValidationError("Search query cannot be empty")
    if limit < 1:
        raise ValidationError("Limit must be a positive integer")

    # Loop-invariant: the needle used for relevance scoring.
    query_lower = query.lower()

    try:
        self._refresh_metadata()
        search_tables = tables or list(self.metadata.tables.keys())
        results = []

        with self.get_connection() as conn:
            for table_name in search_tables:
                if table_name not in self.metadata.tables:
                    continue

                table = self.metadata.tables[table_name]
                text_columns = [
                    col for col in table.columns if "TEXT" in str(col.type).upper() or "VARCHAR" in str(col.type).upper()
                ]

                if not text_columns:
                    continue

                # autoescape makes '%' and '_' in the query literal, so the
                # SQL filter agrees with the literal substring scoring
                # below and wildcard-only hits cannot use up the LIMIT.
                conditions = [col.contains(query, autoescape=True) for col in text_columns]
                stmt = select(table).where(or_(*conditions)).limit(limit)

                for row in conn.execute(stmt).fetchall():
                    row_dict = dict(row._mapping)

                    # Calculate relevance and matched content
                    relevance = 0.0
                    matched_content = []

                    for col in text_columns:
                        value = row_dict.get(col.name)
                        if not value:
                            continue
                        content = str(value).lower()
                        if query_lower in content:
                            frequency = content.count(query_lower)
                            relevance += frequency / len(content)
                            matched_content.append(f"{col.name}: {value}")

                    if relevance > 0:
                        results.append(
                            {
                                "table": table_name,
                                "row_id": row_dict.get("id"),
                                "row_data": row_dict,
                                "matched_content": matched_content,
                                "relevance": round(relevance, 3),
                            }
                        )

        # Sort by relevance and limit results
        results.sort(key=lambda x: x["relevance"], reverse=True)
        results = results[:limit]

        return {
            "success": True,
            "results": results,
            "query": query,
            "tables_searched": search_tables,
            "total_results": len(results),
        }
    except ValidationError:
        raise
    except SQLAlchemyError as e:
        raise DatabaseError(f"Failed to search content: {str(e)}") from e
423
+
424
def explore_tables(self, pattern: Optional[str] = None, include_row_counts: bool = True) -> ToolResponse:
    """Summarize the structure and a sample of the content of each table.

    Args:
        pattern: Optional name filter; ``%`` characters are stripped and
            the remainder must occur as a substring of the table name.
        include_row_counts: When true, each table gets a ``row_count``
            and the counts are summed into ``total_rows``.

    Returns:
        ``{"success": True, "exploration": {...}}`` where ``exploration``
        holds ``tables`` (per-table column info, text column names, and,
        when available, ``sample_data`` of up to 3 rows and a
        ``content_preview`` of up to 5 distinct non-empty values for each
        of the first 3 text columns), ``total_tables`` and ``total_rows``.

    Raises:
        DatabaseError: On any SQLAlchemy failure.
    """
    try:
        self._refresh_metadata()
        table_names = list(self.metadata.tables.keys())

        if pattern:
            fragment = pattern.replace("%", "")
            table_names = [name for name in table_names if fragment in name]

        exploration = {"tables": [], "total_tables": len(table_names), "total_rows": 0}

        with self.get_connection() as conn:
            for table_name in table_names:
                table = self.metadata.tables[table_name]

                # Describe every column and collect the text-like ones.
                columns = [
                    {
                        "name": column.name,
                        "type": str(column.type),
                        "nullable": column.nullable,
                        "default": column.default,
                        "primary_key": column.primary_key,
                    }
                    for column in table.columns
                ]
                text_columns = [
                    column.name
                    for column in table.columns
                    if "TEXT" in str(column.type).upper() or "VARCHAR" in str(column.type).upper()
                ]

                table_info = {"name": table_name, "columns": columns, "text_columns": text_columns}

                if include_row_counts:
                    counted = conn.execute(select(text("COUNT(*)")).select_from(table)).scalar()
                    table_info["row_count"] = counted
                    exploration["total_rows"] += counted

                # A few full rows as sample data.
                sample_rows = [dict(record._mapping) for record in conn.execute(select(table).limit(3)).fetchall()]
                if sample_rows:
                    table_info["sample_data"] = sample_rows

                # Distinct values for at most the first three text columns.
                if text_columns:
                    content_preview = {}
                    for col_name in text_columns[:3]:
                        column = table.c[col_name]
                        preview_stmt = select(column).distinct().where(column.isnot(None)).limit(5)
                        unique_values = []
                        for record in conn.execute(preview_stmt).fetchall():
                            if record[0]:
                                unique_values.append(record[0])
                        if unique_values:
                            content_preview[col_name] = unique_values

                    if content_preview:
                        table_info["content_preview"] = content_preview

                exploration["tables"].append(table_info)

        return {"success": True, "exploration": exploration}
    except SQLAlchemyError as e:
        raise DatabaseError(f"Failed to explore tables: {str(e)}")
489
+
490
+ # --- Semantic Search Methods ---
491
+
492
def add_embedding_column(self, table_name: str, embedding_column: str = "embedding") -> EmbeddingColumnResponse:
    """Add a TEXT column to ``table_name`` for JSON-encoded embeddings.

    Args:
        table_name: Existing table to alter.
        embedding_column: Name of the column to add. If a column with
            this name already exists, nothing is changed. A new column
            name must be a valid identifier.

    Returns:
        ``{"success": True, "message": ...}`` saying whether the column
        was added or already existed.

    Raises:
        ValidationError: If the table does not exist or the new column
            name is not a valid identifier.
        DatabaseError: If SQLite rejects the statement.
    """
    try:
        table = self._ensure_table_exists(table_name)

        # Check if embedding column already exists
        if embedding_column in [col.name for col in table.columns]:
            return {"success": True, "message": f"Embedding column '{embedding_column}' already exists"}

        # The name is placed into raw SQL below, so apply the same
        # identifier rule create_table uses for column names.
        if not isinstance(embedding_column, str) or not embedding_column.isidentifier():
            raise ValidationError(f"Invalid column name: {embedding_column}")

        quoted_table = '"' + table_name.replace('"', '""') + '"'

        # Add embedding column as TEXT (JSON storage)
        with self.get_connection() as conn:
            conn.execute(text(f'ALTER TABLE {quoted_table} ADD COLUMN "{embedding_column}" TEXT'))
            conn.commit()

        self._refresh_metadata()
        return {"success": True, "message": f"Added embedding column '{embedding_column}' to table '{table_name}'"}

    except ValidationError:
        raise
    except SQLAlchemyError as e:
        raise DatabaseError(f"Failed to add embedding column: {str(e)}") from e
513
+
514
def generate_embeddings(self, table_name: str, text_columns: List[str],
                        embedding_column: str = "embedding",
                        model_name: str = "all-MiniLM-L6-v2",
                        batch_size: int = 50) -> GenerateEmbeddingsResponse:
    """Generate and store embeddings for rows that do not have one yet.

    For each row whose embedding column is NULL, ``""`` or ``"null"``,
    the non-empty values of ``text_columns`` are joined with spaces,
    embedded, and written back as JSON, keyed on the row's ``id``. Rows
    with no text in any of the columns are skipped.

    Args:
        table_name: Table to process; must have an ``id`` column.
        text_columns: Columns whose content is embedded.
        embedding_column: Column receiving the JSON embedding; created
            if missing.
        model_name: Model name passed to ``get_semantic_engine``.
        batch_size: Number of rows per commit; must be at least 1.

    Returns:
        Dict with ``success``, ``message``, ``processed``, ``model`` and
        ``embedding_dimension``.

    Raises:
        ValidationError: If semantic search is unavailable, ``batch_size``
            is below 1, the table or a text column does not exist, or
            the table has no ``id`` column.
        DatabaseError: On any SQLAlchemy failure.
    """
    if not is_semantic_search_available():
        raise ValidationError("Semantic search is not available. Please install sentence-transformers.")
    # A zero step would crash range() and a negative one would silently
    # process nothing, so reject both up front.
    if batch_size < 1:
        raise ValidationError("Batch size must be a positive integer")

    try:
        table = self._ensure_table_exists(table_name)

        # Validate text columns exist
        table_columns = [col.name for col in table.columns]
        for col in text_columns:
            if col not in table_columns:
                raise ValidationError(f"Column '{col}' not found in table '{table_name}'")

        # Rows are written back by id; fail clearly instead of with a
        # KeyError from table.c["id"] halfway through.
        if "id" not in table_columns:
            raise ValidationError(f"Table '{table_name}' must have an 'id' column to generate embeddings")

        # Obtain the engine only after the cheap validation has passed.
        semantic_engine = get_semantic_engine(model_name)

        # Add embedding column if it doesn't exist
        if embedding_column not in table_columns:
            self.add_embedding_column(table_name, embedding_column)
            table = self._ensure_table_exists(table_name)  # Refresh

        with self.get_connection() as conn:
            # Select rows without embeddings or with null embeddings
            stmt = select(table).where(
                or_(table.c[embedding_column].is_(None),
                    table.c[embedding_column] == "",
                    table.c[embedding_column] == "null")
            )
            rows = conn.execute(stmt).fetchall()

            if not rows:
                embedding_dim = semantic_engine.get_embedding_dimensions() or 0
                return {"success": True, "message": "All rows already have embeddings", "processed": 0, "model": model_name, "embedding_dimension": embedding_dim}

            processed = 0
            for i in range(0, len(rows), batch_size):
                batch = rows[i:i + batch_size]

                for row in batch:
                    row_dict = dict(row._mapping)

                    # Combine text from specified columns
                    text_parts = [str(row_dict[col]) for col in text_columns if row_dict.get(col)]
                    if not text_parts:
                        continue

                    combined_text = " ".join(text_parts)

                    # Generate embedding
                    embedding = semantic_engine.generate_embedding(combined_text)
                    embedding_json = json.dumps(embedding)

                    # Update row with embedding
                    update_stmt = update(table).where(
                        table.c["id"] == row_dict["id"]
                    ).values({embedding_column: embedding_json})

                    conn.execute(update_stmt)
                    processed += 1

                # One commit per batch.
                conn.commit()
                logging.info(f"Generated embeddings for batch {i//batch_size + 1}, processed {processed} rows")

        return {
            "success": True,
            "message": f"Generated embeddings for {processed} rows",
            "processed": processed,
            "model": model_name,
            "embedding_dimension": semantic_engine.get_embedding_dimensions() or 0
        }

    except ValidationError:
        raise
    except SQLAlchemyError as e:
        raise DatabaseError(f"Failed to generate embeddings: {str(e)}") from e
594
+
595
def semantic_search(self, query: str, tables: Optional[List[str]] = None,
                    embedding_column: str = "embedding",
                    text_columns: Optional[List[str]] = None,
                    similarity_threshold: float = 0.5,
                    limit: int = 10,
                    model_name: str = "all-MiniLM-L6-v2") -> SemanticSearchResponse:
    """Rank rows across tables by similarity of their stored embeddings.

    Tables without ``embedding_column`` are skipped with a warning. Each
    remaining table is searched through the semantic engine for up to
    ``limit * 2`` candidates; the candidates from all tables are merged,
    sorted by ``similarity_score`` and cut to ``limit``.

    Args:
        query: Search text; must not be blank.
        tables: Tables to search; defaults to every reflected table.
        embedding_column: Column holding the JSON embeddings.
        text_columns: Columns passed to the engine for highlighting;
            defaults to each table's TEXT/VARCHAR columns.
        similarity_threshold: Minimum similarity passed to the engine.
        limit: Maximum number of results; must be at least 1.
        model_name: Model name passed to ``get_semantic_engine``.

    Returns:
        Dict with ``success``, ``results``, ``query``, ``tables_searched``,
        ``total_results``, ``model`` and ``similarity_threshold``.

    Raises:
        ValidationError: If semantic search is unavailable, ``query`` is
            blank, or ``limit`` is below 1.
        DatabaseError: On any SQLAlchemy failure.
    """
    if not is_semantic_search_available():
        raise ValidationError("Semantic search is not available. Please install sentence-transformers.")

    if not query or not query.strip():
        raise ValidationError("Search query cannot be empty")

    # A negative limit would make the final slice drop results from the
    # end instead of capping them; reject it like the other search methods.
    if limit < 1:
        raise ValidationError("Limit must be a positive integer")

    try:
        self._refresh_metadata()
        search_tables = tables or list(self.metadata.tables.keys())
        semantic_engine = get_semantic_engine(model_name)

        all_results = []

        with self.get_connection() as conn:
            for table_name in search_tables:
                if table_name not in self.metadata.tables:
                    continue

                table = self.metadata.tables[table_name]

                # Check if table has embedding column
                if embedding_column not in [col.name for col in table.columns]:
                    logging.warning(f"Table '{table_name}' does not have embedding column '{embedding_column}'")
                    continue

                # Get all rows with embeddings
                stmt = select(table).where(
                    and_(table.c[embedding_column].isnot(None),
                         table.c[embedding_column] != "",
                         table.c[embedding_column] != "null")
                )
                rows = conn.execute(stmt).fetchall()

                if not rows:
                    continue

                # Convert to list of dicts for semantic search
                content_data = [dict(row._mapping) for row in rows]

                # Determine text columns for highlighting
                if text_columns is None:
                    text_cols = [col.name for col in table.columns
                                 if "TEXT" in str(col.type).upper() or "VARCHAR" in str(col.type).upper()]
                else:
                    text_cols = text_columns

                # Over-fetch per table so the global ranking has enough candidates.
                table_results = semantic_engine.semantic_search(
                    query, content_data, embedding_column, text_cols,
                    similarity_threshold, limit * 2
                )

                # Add table name to results
                for result in table_results:
                    result["table_name"] = table_name

                all_results.extend(table_results)

        # Sort all results by similarity score and limit
        all_results.sort(key=lambda x: x.get("similarity_score", 0), reverse=True)
        final_results = all_results[:limit]

        return {
            "success": True,
            "results": final_results,
            "query": query,
            "tables_searched": search_tables,
            "total_results": len(final_results),
            "model": model_name,
            "similarity_threshold": similarity_threshold
        }

    except ValidationError:
        raise
    except SQLAlchemyError as e:
        raise DatabaseError(f"Semantic search failed: {str(e)}") from e
678
+
679
def find_related_content(self, table_name: str, row_id: int,
                         embedding_column: str = "embedding",
                         similarity_threshold: float = 0.5,
                         limit: int = 5,
                         model_name: str = "all-MiniLM-L6-v2") -> RelatedContentResponse:
    """Find rows of ``table_name`` whose embeddings resemble that of ``row_id``.

    Args:
        table_name: Table to search; must have an ``id`` column.
        row_id: ``id`` of the row to compare against.
        embedding_column: Column holding the JSON embeddings.
        similarity_threshold: Minimum similarity passed to the engine.
        limit: Maximum number of matches passed to the engine.
        model_name: Model name passed to ``get_semantic_engine``.

    Returns:
        Dict with ``success``, ``results`` (row dicts plus a rounded
        ``similarity_score``), ``target_row``, ``total_results``,
        ``similarity_threshold``, ``model`` and ``message``.

    Raises:
        ValidationError: If semantic search is unavailable, the table
            does not exist or has no ``id`` column, the row is missing,
            or the row's embedding is absent or not valid JSON.
        DatabaseError: On any SQLAlchemy failure.
    """
    if not is_semantic_search_available():
        raise ValidationError("Semantic search is not available. Please install sentence-transformers.")

    try:
        table = self._ensure_table_exists(table_name)

        # Rows are addressed by id; fail clearly instead of with a
        # KeyError from table.c["id"].
        if "id" not in [col.name for col in table.columns]:
            raise ValidationError(f"Table '{table_name}' must have an 'id' column to find related content")

        semantic_engine = get_semantic_engine(model_name)

        with self.get_connection() as conn:
            # Get the target row
            target_stmt = select(table).where(table.c["id"] == row_id)
            target_row = conn.execute(target_stmt).fetchone()

            if not target_row:
                raise ValidationError(f"Row with id {row_id} not found in table '{table_name}'")

            target_dict = dict(target_row._mapping)

            # Check if target has embedding
            if (embedding_column not in target_dict or
                    not target_dict[embedding_column] or
                    target_dict[embedding_column] in ["", "null"]):
                raise ValidationError(f"Row {row_id} does not have an embedding")

            # A corrupt target embedding is reported as a validation
            # problem, matching how corrupt candidates are tolerated below.
            try:
                target_embedding = json.loads(target_dict[embedding_column])
            except (json.JSONDecodeError, TypeError) as e:
                raise ValidationError(f"Row {row_id} has an invalid embedding") from e

            # Get all other rows with embeddings
            stmt = select(table).where(
                and_(table.c["id"] != row_id,
                     table.c[embedding_column].isnot(None),
                     table.c[embedding_column] != "",
                     table.c[embedding_column] != "null")
            )
            rows = conn.execute(stmt).fetchall()

            if not rows:
                return {
                    "success": True,
                    "results": [],
                    "target_row": target_dict,
                    "total_results": 0,
                    "similarity_threshold": similarity_threshold,
                    "model": model_name,
                    "message": "No other rows with embeddings found"
                }

            # Find similar rows
            content_data = [dict(row._mapping) for row in rows]
            candidate_embeddings = []
            # valid_indices[k] is the content_data index of candidate k, so
            # engine results can be mapped back after bad rows are skipped.
            valid_indices = []

            for idx, row_dict in enumerate(content_data):
                try:
                    embedding = json.loads(row_dict[embedding_column])
                except (json.JSONDecodeError, TypeError):
                    continue
                candidate_embeddings.append(embedding)
                valid_indices.append(idx)

            if not candidate_embeddings:
                return {
                    "success": True,
                    "results": [],
                    "target_row": target_dict,
                    "total_results": 0,
                    "similarity_threshold": similarity_threshold,
                    "model": model_name,
                    "message": "No valid embeddings found for comparison"
                }

            # Calculate similarities
            similar_indices = semantic_engine.find_similar_embeddings(
                target_embedding, candidate_embeddings,
                similarity_threshold, limit
            )

            # Build results
            results = []
            for candidate_idx, similarity_score in similar_indices:
                original_idx = valid_indices[candidate_idx]
                row_dict = content_data[original_idx].copy()
                row_dict["similarity_score"] = round(similarity_score, 3)
                results.append(row_dict)

            return {
                "success": True,
                "results": results,
                "target_row": target_dict,
                "total_results": len(results),
                "similarity_threshold": similarity_threshold,
                "model": model_name,
                "message": f"Found {len(results)} related items"
            }

    except ValidationError:
        raise
    except SQLAlchemyError as e:
        raise DatabaseError(f"Failed to find related content: {str(e)}") from e
783
+
784
def hybrid_search(self, query: str, tables: Optional[List[str]] = None,
                  text_columns: Optional[List[str]] = None,
                  embedding_column: str = "embedding",
                  semantic_weight: float = 0.7,
                  text_weight: float = 0.3,
                  limit: int = 10,
                  model_name: str = "all-MiniLM-L6-v2") -> HybridSearchResponse:
    """Blend semantic similarity with keyword matching.

    Outcomes, identified by ``search_type`` in the response:
      * ``text_only`` - semantic search is unavailable; plain
        ``search_content`` results with weights 0.0 / 1.0.
      * ``semantic_failed`` - the semantic response was not successful.
      * ``text_fallback`` - semantic search found nothing; plain
        ``search_content`` results.
      * ``hybrid`` - semantic results re-scored by the engine.

    Raises:
        ValidationError: Propagated from the underlying searches.
        DatabaseError: On any SQLAlchemy failure.
    """

    def tagged(base, search_type, sem_weight, txt_weight, model):
        # Overlay the hybrid-specific keys on top of an existing response.
        return cast(HybridSearchResponse, {
            **base,
            "search_type": search_type,
            "semantic_weight": sem_weight,
            "text_weight": txt_weight,
            "model": model
        })

    if not is_semantic_search_available():
        # Keyword search is all that can be offered.
        text_only = self.search_content(query, tables, limit)
        return tagged(text_only, "text_only", 0.0, 1.0, "none")

    try:
        # Looser threshold and doubled limit: re-ranking needs spare candidates.
        semantic_response = self.semantic_search(
            query, tables, embedding_column, text_columns,
            similarity_threshold=0.3, limit=limit * 2, model_name=model_name
        )

        if not semantic_response.get("success"):
            return tagged(semantic_response, "semantic_failed", semantic_weight, text_weight, model_name)

        semantic_results = semantic_response.get("results", [])

        if not semantic_results:
            keyword_response = self.search_content(query, tables, limit)
            return tagged(keyword_response, "text_fallback", semantic_weight, text_weight, model_name)

        # Let the engine mix in text-matching scores.
        engine = get_semantic_engine(model_name)
        enhanced_results = engine.hybrid_search(
            query, semantic_results, text_columns or [],
            embedding_column, semantic_weight, text_weight, limit
        )

        return {
            "success": True,
            "results": enhanced_results,
            "query": query,
            "search_type": "hybrid",
            "semantic_weight": semantic_weight,
            "text_weight": text_weight,
            "total_results": len(enhanced_results),
            "model": model_name
        }

    except ValidationError:
        raise
    except SQLAlchemyError as e:
        raise DatabaseError(f"Hybrid search failed: {str(e)}")
855
+
856
def get_embedding_stats(self, table_name: str, embedding_column: str = "embedding") -> EmbeddingStatsResponse:
    """Report how many rows of ``table_name`` carry an embedding.

    A row counts as embedded when ``embedding_column`` is not NULL,
    ``""`` or ``"null"``. If the table has no such column, the result
    reports zero embedded rows rather than failing.

    Args:
        table_name: Table to inspect.
        embedding_column: Column holding the JSON embeddings.

    Returns:
        Dict with ``success``, ``table_name``, ``total_rows``,
        ``embedded_rows``, ``coverage_percent`` (one decimal),
        ``embedding_dimensions`` (length of one sample embedding, or
        ``None`` when unavailable) and ``embedding_column``.

    Raises:
        ValidationError: If the table does not exist.
        DatabaseError: On any SQLAlchemy failure.
    """
    try:
        table = self._ensure_table_exists(table_name)

        with self.get_connection() as conn:
            # Count total rows
            total_count = conn.execute(select(text("COUNT(*)")).select_from(table)).scalar() or 0

            embedded_count = 0
            dimensions = None

            # Without the column nothing is embedded; skipping the queries
            # avoids a KeyError from table.c[embedding_column].
            if embedding_column in [col.name for col in table.columns]:
                has_embedding = and_(table.c[embedding_column].isnot(None),
                                     table.c[embedding_column] != "",
                                     table.c[embedding_column] != "null")

                # Count rows with embeddings
                embedded_count = conn.execute(
                    select(text("COUNT(*)")).select_from(table).where(has_embedding)
                ).scalar() or 0

                # Get sample embedding to check dimensions
                sample_stmt = select(table.c[embedding_column]).where(has_embedding).limit(1)
                sample_result = conn.execute(sample_stmt).fetchone()
                if sample_result and sample_result[0]:
                    try:
                        sample_embedding = json.loads(sample_result[0])
                    except (json.JSONDecodeError, TypeError):
                        sample_embedding = None
                    # Only a JSON array has a meaningful dimension.
                    if isinstance(sample_embedding, list):
                        dimensions = len(sample_embedding)

            coverage_percent = (embedded_count / total_count * 100) if total_count > 0 else 0.0

            return {
                "success": True,
                "table_name": table_name,
                "total_rows": total_count,
                "embedded_rows": embedded_count,
                "coverage_percent": round(coverage_percent, 1),
                "embedding_dimensions": dimensions,
                "embedding_column": embedding_column
            }

    except ValidationError:
        raise
    except SQLAlchemyError as e:
        raise DatabaseError(f"Failed to get embedding stats: {str(e)}") from e
906
+
907
+
908
# Global database instance
_db_instance: Optional[SQLiteMemoryDatabase] = None


def get_database(db_path: Optional[str] = None) -> SQLiteMemoryDatabase:
    """Get or create the global database instance.

    Args:
        db_path: Path of the database to use. When omitted, the existing
            instance is returned if there is one; otherwise a new one is
            created from the ``DB_PATH`` environment variable, falling
            back to ``./test.db``.

    Returns:
        The shared ``SQLiteMemoryDatabase``. A new instance replaces (and
        closes) the current one only when ``db_path`` refers to a
        different file.
    """
    global _db_instance

    actual_path = db_path or os.environ.get("DB_PATH", "./test.db")

    # The instance stores its path in absolute form, so normalize the
    # request the same way before comparing; otherwise every call with a
    # relative path would look like a different database and rebuild
    # the instance.
    path_changed = bool(db_path) and os.path.abspath(db_path) != _db_instance.db_path if _db_instance is not None else False

    if _db_instance is None or path_changed:
        # Close previous instance if it exists
        if _db_instance is not None:
            _db_instance.close()
        _db_instance = SQLiteMemoryDatabase(actual_path)

    return _db_instance