rakam-systems-vectorstore 0.1.1rc7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. rakam_systems_vectorstore/MANIFEST.in +26 -0
  2. rakam_systems_vectorstore/README.md +1071 -0
  3. rakam_systems_vectorstore/__init__.py +93 -0
  4. rakam_systems_vectorstore/components/__init__.py +0 -0
  5. rakam_systems_vectorstore/components/chunker/__init__.py +19 -0
  6. rakam_systems_vectorstore/components/chunker/advanced_chunker.py +1019 -0
  7. rakam_systems_vectorstore/components/chunker/text_chunker.py +154 -0
  8. rakam_systems_vectorstore/components/embedding_model/__init__.py +0 -0
  9. rakam_systems_vectorstore/components/embedding_model/configurable_embeddings.py +546 -0
  10. rakam_systems_vectorstore/components/embedding_model/openai_embeddings.py +259 -0
  11. rakam_systems_vectorstore/components/loader/__init__.py +31 -0
  12. rakam_systems_vectorstore/components/loader/adaptive_loader.py +512 -0
  13. rakam_systems_vectorstore/components/loader/code_loader.py +699 -0
  14. rakam_systems_vectorstore/components/loader/doc_loader.py +812 -0
  15. rakam_systems_vectorstore/components/loader/eml_loader.py +556 -0
  16. rakam_systems_vectorstore/components/loader/html_loader.py +626 -0
  17. rakam_systems_vectorstore/components/loader/md_loader.py +622 -0
  18. rakam_systems_vectorstore/components/loader/odt_loader.py +750 -0
  19. rakam_systems_vectorstore/components/loader/pdf_loader.py +771 -0
  20. rakam_systems_vectorstore/components/loader/pdf_loader_light.py +723 -0
  21. rakam_systems_vectorstore/components/loader/tabular_loader.py +597 -0
  22. rakam_systems_vectorstore/components/vectorstore/__init__.py +0 -0
  23. rakam_systems_vectorstore/components/vectorstore/apps.py +10 -0
  24. rakam_systems_vectorstore/components/vectorstore/configurable_pg_vector_store.py +1661 -0
  25. rakam_systems_vectorstore/components/vectorstore/faiss_vector_store.py +878 -0
  26. rakam_systems_vectorstore/components/vectorstore/migrations/0001_initial.py +55 -0
  27. rakam_systems_vectorstore/components/vectorstore/migrations/__init__.py +0 -0
  28. rakam_systems_vectorstore/components/vectorstore/models.py +10 -0
  29. rakam_systems_vectorstore/components/vectorstore/pg_models.py +97 -0
  30. rakam_systems_vectorstore/components/vectorstore/pg_vector_store.py +827 -0
  31. rakam_systems_vectorstore/config.py +266 -0
  32. rakam_systems_vectorstore/core.py +8 -0
  33. rakam_systems_vectorstore/pyproject.toml +113 -0
  34. rakam_systems_vectorstore/server/README.md +290 -0
  35. rakam_systems_vectorstore/server/__init__.py +20 -0
  36. rakam_systems_vectorstore/server/mcp_server_vector.py +325 -0
  37. rakam_systems_vectorstore/setup.py +103 -0
  38. rakam_systems_vectorstore-0.1.1rc7.dist-info/METADATA +370 -0
  39. rakam_systems_vectorstore-0.1.1rc7.dist-info/RECORD +40 -0
  40. rakam_systems_vectorstore-0.1.1rc7.dist-info/WHEEL +4 -0
@@ -0,0 +1,1661 @@
1
+ """
2
+ Configurable PostgreSQL Vector Store with enhanced features.
3
+
4
+ This module provides an enhanced, fully configurable PgVectorStore that:
5
+ - Supports configuration via YAML/JSON files or dictionaries
6
+ - Allows pluggable embedding models
7
+ - Provides update_vector capability
8
+ - Maintains clean separation from other components
9
+ - Supports all search configurations
10
+ - **Dimension-agnostic vector storage**: No need to recreate tables when switching models!
11
+
12
+ ## Flexible Vector Storage
13
+
14
+ Vector columns are created WITHOUT dimension constraints, allowing you to:
15
+ ✓ Switch between embedding models without altering database schema
16
+ ✓ Store vectors of any dimension in the same table structure
17
+ ✓ Avoid automatic table recreation and data loss
18
+ ✓ Simplify database management
19
+
20
+ ## Multi-Model Support
21
+
22
+ By default (use_dimension_specific_tables=True), each embedding model automatically
23
+ gets its own dedicated tables based on the model name:
24
+
25
+ - 'all-MiniLM-L6-v2' → application_nodeentry_all_minilm_l6_v2
26
+ - 'multi-qa-mpnet-base-cos-v1' → application_nodeentry_multi_qa_mpnet_base_cos_v1
27
+ - 'text-embedding-ada-002' → application_nodeentry_text_embedding_ada_002
28
+
29
+ **Why model-specific tables?**
30
+
31
+ Even if two models have the same dimensions (e.g., both 384D), their vector spaces
32
+ are completely different! Mixing embeddings from different models would give
33
+ meaningless results.
34
+
35
+ Example:
36
+ - Model A: 'all-MiniLM-L6-v2' (384D)
37
+ - Model B: 'paraphrase-MiniLM-L3-v2' (384D)
38
+
39
+ These produce vectors in DIFFERENT semantic spaces. You cannot:
40
+ ❌ Search Model A embeddings using Model B query vectors
41
+ ❌ Store both in the same table and expect meaningful results
42
+
43
+ This allows you to:
44
+ ✓ Use multiple models simultaneously (each in its own vector space)
45
+ ✓ Prevent accidental mixing of incompatible vector spaces
46
+ ✓ Avoid manual table management
47
+
48
+ Example:
49
+     # Safe by default - each model uses its own tables
50
+     store_mini = ConfigurablePgVectorStore(
51
+         config=config_minilm  # Uses all-MiniLM-L6-v2
52
+     )
53
+
54
+     store_mpnet = ConfigurablePgVectorStore(
55
+         config=config_mpnet  # Uses multi-qa-mpnet-base-cos-v1
56
+     )
57
+     # Both can coexist without conflicts or vector space mixing!
58
+
59
+ For shared table behavior, set use_dimension_specific_tables=False.
60
+ """
61
+
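A minimal end-to-end sketch of the intended workflow (illustrative only, assuming a configured Django environment; the config path and metadata values are hypothetical, while the Node/NodeMetadata fields mirror the ones referenced throughout this module):

    from rakam_systems_vectorstore.components.vectorstore.configurable_pg_vector_store import ConfigurablePgVectorStore
    from rakam_systems_vectorstore.core import Node, NodeMetadata

    store = ConfigurablePgVectorStore(config="configs/pg_vector_store.yaml")  # hypothetical path
    store.setup()  # ensures pgvector and resolves model-specific table names

    nodes = [
        Node(
            content="pgvector adds vector similarity search to PostgreSQL.",
            metadata=NodeMetadata(source_file_uuid="doc-001", position=0, custom={"lang": "en"}),
        )
    ]
    store.create_collection_from_nodes("my_docs", nodes)
    results, hits = store.search("my_docs", query="vector similarity in PostgreSQL", number=5)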
62
+ from __future__ import annotations
63
+
64
+ import time
65
+ from functools import lru_cache
66
+ from typing import Any, Dict, List, Optional, Tuple, Union
67
+
68
+ import numpy as np
69
+ from django.contrib.postgres.search import SearchQuery, SearchRank, SearchVector
70
+ from django.db import connection, transaction
71
+
72
+ from rakam_systems_core.ai_utils import logging
73
+ from rakam_systems_core.ai_core.interfaces.vectorstore import VectorStore
74
+ from rakam_systems_vectorstore.components.embedding_model.configurable_embeddings import ConfigurableEmbeddings
75
+ from rakam_systems_vectorstore.components.vectorstore.pg_models import Collection, NodeEntry
76
+ from rakam_systems_vectorstore.config import VectorStoreConfig, load_config
77
+ from rakam_systems_vectorstore.core import Node, NodeMetadata, VSFile
78
+
79
+ logger = logging.getLogger(__name__)
80
+
81
+
82
+ class ConfigurablePgVectorStore(VectorStore):
83
+ """
84
+ Enhanced PostgreSQL Vector Store with full configuration support.
85
+
86
+ Features:
87
+ - Configuration via YAML/JSON or dict
88
+ - Pluggable embedding models
89
+ - Configurable similarity metrics
90
+ - Hybrid search with configurable weights
91
+ - Update operations for vectors
92
+ - Comprehensive metadata filtering
93
+ """
94
+
95
+ def __init__(
96
+ self,
97
+ name: str = "configurable_pg_vector_store",
98
+ config: Optional[Union[VectorStoreConfig, Dict, str]] = None,
99
+ auto_recreate_on_dimension_mismatch: bool = False,
100
+ use_dimension_specific_tables: bool = True
101
+ ):
102
+ """
103
+ Initialize configurable PostgreSQL vector store.
104
+
105
+ Args:
106
+ name: Component name
107
+ config: Configuration (VectorStoreConfig object, dict, or path to config file)
108
+ auto_recreate_on_dimension_mismatch: DEPRECATED - No longer used. Vector columns now
109
+ support any dimension without schema changes.
110
+ use_dimension_specific_tables: If True, each embedding model gets its own dedicated tables
111
+ based on the model name, preventing:
112
+ - Mixing incompatible vector spaces
113
+ - Meaningless search results from mixed embeddings
114
+ DEFAULT: True (STRONGLY recommended)
115
+
116
+ Important:
117
+ Even models with the same dimensions produce vectors in different semantic spaces!
118
+ For example, 'all-MiniLM-L6-v2' and 'paraphrase-MiniLM-L3-v2' are both 384D,
119
+ but their vectors are NOT compatible. Always use model-specific tables.
120
+ """
121
+ # Load configuration
122
+ if isinstance(config, VectorStoreConfig):
123
+ self.vs_config = config
124
+ elif isinstance(config, dict):
125
+ self.vs_config = VectorStoreConfig.from_dict(config)
126
+ elif isinstance(config, str):
127
+ # Path to config file
128
+ self.vs_config = load_config(config)
129
+ else:
130
+ # Use defaults
131
+ self.vs_config = VectorStoreConfig()
132
+
133
+ # Validate configuration
134
+ self.vs_config.validate()
135
+
136
+ # Initialize base component
137
+ super().__init__(name=name, config=self.vs_config.to_dict())
138
+
139
+ # Setup logging
140
+ if self.vs_config.enable_logging:
141
+ logging.basicConfig(level=self.vs_config.log_level)
142
+
143
+ # Store configuration
144
+ self.auto_recreate_on_dimension_mismatch = auto_recreate_on_dimension_mismatch
145
+ self.use_dimension_specific_tables = use_dimension_specific_tables
146
+
147
+ # Table names will be set after we know the embedding dimension
148
+ self.table_collection = "application_collection"
149
+ self.table_nodeentry = "application_nodeentry"
150
+
151
+ # Initialize embedding model
152
+ self.embedding_model = ConfigurableEmbeddings(
153
+ name=f"{name}_embeddings",
154
+ config=self.vs_config.embedding
155
+ )
156
+
157
+ self.embedding_dim: Optional[int] = None
158
+
159
+ logger.info(f"Initialized {name} with config: {self.vs_config.name}")
160
+
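The config argument accepts three forms, as the branching above shows. A short sketch (the dict keys are illustrative; the authoritative schema lives in rakam_systems_vectorstore/config.py):

    from rakam_systems_vectorstore.config import VectorStoreConfig

    store_a = ConfigurablePgVectorStore(config=VectorStoreConfig())          # typed config object
    store_b = ConfigurablePgVectorStore(                                     # plain dict
        config={"embedding": {"model_name": "all-MiniLM-L6-v2"}}
    )
    store_c = ConfigurablePgVectorStore(config="configs/vector_store.yaml")  # path to YAML/JSON file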
161
+ def setup(self) -> None:
162
+ """Initialize resources and connections."""
163
+ # Skip if already initialized
164
+ if self.initialized:
165
+ logger.debug(
166
+ "ConfigurablePgVectorStore already initialized, skipping setup")
167
+ return
168
+
169
+ logger.info("Setting up ConfigurablePgVectorStore...")
170
+
171
+ # Ensure pgvector extension
172
+ self._ensure_pgvector_extension()
173
+
174
+ # Setup embedding model (will skip if already initialized)
175
+ self.embedding_model.setup()
176
+ self.embedding_dim = self.embedding_model.embedding_dimension
177
+
178
+ # Set table names based on model if using model-specific tables
179
+ if self.use_dimension_specific_tables:
180
+ # Create a safe table suffix from model name
181
+ # Each model gets its own table because even same-dimension models
182
+ # have different vector spaces!
183
+ model_name = self.vs_config.embedding.model_name
184
+ safe_model_name = self._sanitize_model_name(model_name)
185
+
186
+ self.table_collection = f"application_collection_{safe_model_name}"
187
+ self.table_nodeentry = f"application_nodeentry_{safe_model_name}"
188
+ logger.info(
189
+ f"Using model-specific tables for '{model_name}' ({self.embedding_dim}D): "
190
+ f"collection={self.table_collection}, "
191
+ f"nodeentry={self.table_nodeentry}"
192
+ )
193
+
194
+ # Ensure the required tables exist
195
+ self._ensure_vector_dimension_compatibility()
196
+
197
+ logger.info(
198
+ f"Vector store ready with embedding dimension: {self.embedding_dim}")
199
+ super().setup()
200
+
201
+ def _sanitize_model_name(self, model_name: str) -> str:
202
+ """
203
+ Convert model name to a safe table suffix.
204
+
205
+ Examples:
206
+ 'all-MiniLM-L6-v2' -> 'all_minilm_l6_v2'
207
+ 'sentence-transformers/multi-qa-mpnet-base-cos-v1' -> 'multi_qa_mpnet_base_cos_v1'
208
+ 'text-embedding-ada-002' -> 'text_embedding_ada_002'
209
+ """
210
+ import re
211
+
212
+ # Remove common prefixes
213
+ name = model_name.replace('sentence-transformers/', '')
214
+ name = name.replace('models/', '')
215
+
216
+ # Replace non-alphanumeric with underscore
217
+ name = re.sub(r'[^a-zA-Z0-9]', '_', name)
218
+
219
+ # Convert to lowercase
220
+ name = name.lower()
221
+
222
+ # Remove consecutive underscores
223
+ name = re.sub(r'_+', '_', name)
224
+
225
+ # Remove leading/trailing underscores
226
+ name = name.strip('_')
227
+
228
+ # Limit length (PostgreSQL identifier limit is 63 chars)
229
+ if len(name) > 40:
230
+ # Keep last 40 chars (usually has version info)
231
+ name = name[-40:]
232
+ name = name.lstrip('_')
233
+
234
+ return name
235
+
236
+ def _ensure_pgvector_extension(self) -> None:
237
+ """Ensure pgvector extension is installed."""
238
+ with connection.cursor() as cursor:
239
+ try:
240
+ cursor.execute("CREATE EXTENSION IF NOT EXISTS vector;")
241
+ logger.info("Ensured pgvector extension is installed")
242
+ except Exception as e:
243
+ logger.error(f"Failed to create pgvector extension: {e}")
244
+ raise
245
+
246
+ def _ensure_vector_dimension_compatibility(self) -> None:
247
+ """
248
+ Ensures that the required tables exist.
249
+
250
+ Note: Vector columns are created without dimension constraints, allowing
251
+ flexibility to store vectors of any dimension without needing to alter
252
+ the database schema when switching embedding models.
253
+ """
254
+ with connection.cursor() as cursor:
255
+ try:
256
+ # First ensure the collection table exists
257
+ cursor.execute(f"""
258
+ CREATE TABLE IF NOT EXISTS {self.table_collection} (
259
+ id SERIAL PRIMARY KEY,
260
+ name VARCHAR(255) UNIQUE NOT NULL,
261
+ embedding_dim INTEGER NOT NULL DEFAULT 384,
262
+ created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
263
+ updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP
264
+ );
265
+ """)
266
+
267
+ # Check if the nodeentry table exists
268
+ cursor.execute(f"""
269
+ SELECT EXISTS (
270
+ SELECT FROM information_schema.tables
271
+ WHERE table_name = '{self.table_nodeentry}'
272
+ );
273
+ """)
274
+ table_exists = cursor.fetchone()[0]
275
+
276
+ if not table_exists:
277
+ # Table doesn't exist, create it without dimension constraint
278
+ logger.info(
279
+ f"Creating new table '{self.table_nodeentry}' (supports any vector dimension)...")
280
+ cursor.execute(f"""
281
+ CREATE TABLE {self.table_nodeentry} (
282
+ node_id SERIAL PRIMARY KEY,
283
+ collection_id INTEGER NOT NULL REFERENCES {self.table_collection}(id) ON DELETE CASCADE,
284
+ content TEXT NOT NULL,
285
+ embedding vector,
286
+ source_file_uuid VARCHAR(255) NOT NULL,
287
+ position INTEGER,
288
+ custom_metadata JSONB DEFAULT '{{}}'::jsonb,
289
+ created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
290
+ updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP
291
+ );
292
+ """)
293
+
294
+ # Create indexes
295
+ cursor.execute(f"""
296
+ CREATE INDEX {self.table_nodeentry}_source_idx
297
+ ON {self.table_nodeentry}(source_file_uuid);
298
+ """)
299
+ cursor.execute(f"""
300
+ CREATE INDEX {self.table_nodeentry}_collect_idx
301
+ ON {self.table_nodeentry}(collection_id, source_file_uuid);
302
+ """)
303
+
304
+ logger.info(
305
+ f"✓ Created table '{self.table_nodeentry}' (dimension-agnostic)")
306
+ else:
307
+ logger.info(
308
+ f"✓ Table '{self.table_nodeentry}' already exists")
309
+
310
+ except Exception as e:
311
+ logger.error(f"Failed to ensure table exists: {e}")
312
+ raise
313
+
314
+ def get_or_create_collection(
315
+ self,
316
+ collection_name: str,
317
+ embedding_dim: Optional[int] = None
318
+ ) -> Collection:
319
+ """
320
+ Get or create a collection.
321
+
322
+ Args:
323
+ collection_name: Name of the collection
324
+ embedding_dim: Embedding dimension (uses model dimension if not specified)
325
+
326
+ Returns:
327
+ Collection object
328
+ """
329
+ if embedding_dim is None:
330
+ embedding_dim = self.embedding_dim
331
+
332
+ # Use raw SQL when custom table names are in use
333
+ if self.use_dimension_specific_tables:
334
+ with connection.cursor() as cursor:
335
+ # Try to get existing collection
336
+ cursor.execute(
337
+ f"""
338
+ SELECT id, name, embedding_dim, created_at, updated_at
339
+ FROM {self.table_collection}
340
+ WHERE name = %s
341
+ """,
342
+ [collection_name]
343
+ )
344
+ row = cursor.fetchone()
345
+
346
+ if row:
347
+ # Collection exists - create a Collection object manually
348
+ collection = Collection(
349
+ id=row[0],
350
+ name=row[1],
351
+ embedding_dim=row[2],
352
+ created_at=row[3],
353
+ updated_at=row[4]
354
+ )
355
+ # Mark it as existing in DB
356
+ collection._state.adding = False
357
+ logger.info(
358
+ f"Using existing collection: {collection_name}")
359
+ else:
360
+ # Create new collection
361
+ cursor.execute(
362
+ f"""
363
+ INSERT INTO {self.table_collection} (name, embedding_dim, created_at, updated_at)
364
+ VALUES (%s, %s, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
365
+ RETURNING id, name, embedding_dim, created_at, updated_at
366
+ """,
367
+ [collection_name, embedding_dim]
368
+ )
369
+ row = cursor.fetchone()
370
+ collection = Collection(
371
+ id=row[0],
372
+ name=row[1],
373
+ embedding_dim=row[2],
374
+ created_at=row[3],
375
+ updated_at=row[4]
376
+ )
377
+ collection._state.adding = False
378
+ logger.info(f"Created new collection: {collection_name}")
379
+
380
+ return collection
381
+ else:
382
+ # Use Django ORM for standard tables
383
+ collection, created = Collection.objects.get_or_create(
384
+ name=collection_name,
385
+ defaults={"embedding_dim": embedding_dim}
386
+ )
387
+
388
+ logger.info(
389
+ f"{'Created new' if created else 'Using existing'} collection: {collection_name}"
390
+ )
391
+ return collection
392
+
393
+ def get_collection(self, collection_name: str) -> Collection:
394
+ """
395
+ Get an existing collection (raises ValueError if not found).
396
+
397
+ Args:
398
+ collection_name: Name of the collection
399
+
400
+ Returns:
401
+ Collection object
402
+
403
+ Raises:
404
+ ValueError: If collection does not exist
405
+ """
406
+ # Use raw SQL when custom table names are in use
407
+ if self.use_dimension_specific_tables:
408
+ with connection.cursor() as cursor:
409
+ cursor.execute(
410
+ f"""
411
+ SELECT id, name, embedding_dim, created_at, updated_at
412
+ FROM {self.table_collection}
413
+ WHERE name = %s
414
+ """,
415
+ [collection_name]
416
+ )
417
+ row = cursor.fetchone()
418
+
419
+ if not row:
420
+ raise ValueError(
421
+ f"Collection not found: {collection_name}")
422
+
423
+ # Create Collection object from row
424
+ collection = Collection(
425
+ id=row[0],
426
+ name=row[1],
427
+ embedding_dim=row[2],
428
+ created_at=row[3],
429
+ updated_at=row[4]
430
+ )
431
+ collection._state.adding = False
432
+ return collection
433
+ else:
434
+ try:
435
+ return Collection.objects.get(name=collection_name)
436
+ except Collection.DoesNotExist:
437
+ raise ValueError(f"Collection not found: {collection_name}")
438
+
439
+ def _get_distance_operator(self, distance_type: Optional[str] = None) -> str:
440
+ """Get SQL distance operator for the configured similarity metric."""
441
+ if distance_type is None:
442
+ distance_type = self.vs_config.search.similarity_metric
443
+
444
+ operators = {
445
+ "cosine": "<=>",
446
+ "l2": "<->",
447
+ "dot_product": "<#>",
448
+ "dot": "<#>"
449
+ }
450
+
451
+ if distance_type not in operators:
452
+ raise ValueError(f"Unsupported distance type: {distance_type}")
453
+
454
+ return operators[distance_type]
455
+
456
+ @lru_cache(maxsize=1000)
457
+ def _get_query_embedding(self, query: str) -> np.ndarray:
458
+ """Get embedding for a query (with caching)."""
459
+ if not self.vs_config.enable_caching:
460
+ # Don't use cache
461
+ return np.array(self.embedding_model.encode_query(query), dtype=np.float32)
462
+
463
+ embedding = self.embedding_model.encode_query(query)
464
+ return np.array(embedding, dtype=np.float32)
465
+
466
+ def _normalize_embedding(self, embedding: np.ndarray) -> np.ndarray:
467
+ """Normalize embedding vector."""
468
+ norm = np.linalg.norm(embedding)
469
+ if norm > 0:
470
+ return embedding / norm
471
+ return embedding
472
+
473
+ @transaction.atomic
474
+ def create_collection_from_nodes(
475
+ self,
476
+ collection_name: str,
477
+ nodes: List[Node]
478
+ ) -> None:
479
+ """
480
+ Create a collection from nodes.
481
+
482
+ Args:
483
+ collection_name: Name of collection
484
+ nodes: List of Node objects
485
+ """
486
+ if not nodes:
487
+ logger.warning(
488
+ f"No nodes provided for collection '{collection_name}'")
489
+ return
490
+
491
+ # Filter out nodes with None or empty content (these would cause embedding errors)
492
+ original_count = len(nodes)
493
+ nodes = [node for node in nodes if node.content is not None and str(
494
+ node.content).strip()]
495
+
496
+ if len(nodes) < original_count:
497
+ logger.warning(
498
+ f"Filtered out {original_count - len(nodes)} nodes with empty/None content")
499
+
500
+ if not nodes:
501
+ logger.warning(
502
+ f"No valid nodes to add for collection '{collection_name}' after filtering")
503
+ return
504
+
505
+ logger.info(
506
+ f"Creating collection '{collection_name}' with {len(nodes)} nodes")
507
+
508
+ # Get or create collection
509
+ collection = self.get_or_create_collection(collection_name)
510
+
511
+ # Generate embeddings - ensure all content is string type
512
+ texts = [str(node.content) for node in nodes]
513
+ embeddings = self.embedding_model.encode_documents(texts)
514
+
515
+ # Use raw SQL when custom table names are in use
516
+ if self.use_dimension_specific_tables:
517
+ import json
518
+ with connection.cursor() as cursor:
519
+ # Clear existing nodes
520
+ cursor.execute(
521
+ f"DELETE FROM {self.table_nodeentry} WHERE collection_id = %s",
522
+ [collection.id]
523
+ )
524
+
525
+ # Insert nodes using raw SQL
526
+ for i, node in enumerate(nodes):
527
+ # Convert custom_metadata dict to JSON string
528
+ custom_metadata_json = json.dumps(
529
+ node.metadata.custom or {})
530
+
531
+ cursor.execute(
532
+ f"""
533
+ INSERT INTO {self.table_nodeentry}
534
+ (collection_id, content, embedding, source_file_uuid, position, custom_metadata, created_at, updated_at)
535
+ VALUES (%s, %s, %s, %s, %s, %s::jsonb, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
536
+ RETURNING node_id
537
+ """,
538
+ [
539
+ collection.id,
540
+ node.content,
541
+ embeddings[i],
542
+ node.metadata.source_file_uuid,
543
+ node.metadata.position,
544
+ custom_metadata_json
545
+ ]
546
+ )
547
+ node_id = cursor.fetchone()[0]
548
+ node.metadata.node_id = node_id
549
+
550
+ logger.info(
551
+ f"Created collection '{collection_name}' with {len(nodes)} nodes")
552
+ else:
553
+ # Use Django ORM for standard tables
554
+ # Clear existing nodes
555
+ NodeEntry.objects.filter(collection=collection).delete()
556
+
557
+ # Create node entries
558
+ node_entries = [
559
+ NodeEntry(
560
+ collection=collection,
561
+ content=node.content,
562
+ embedding=embeddings[i],
563
+ source_file_uuid=node.metadata.source_file_uuid,
564
+ position=node.metadata.position,
565
+ custom_metadata=node.metadata.custom or {},
566
+ )
567
+ for i, node in enumerate(nodes)
568
+ ]
569
+
570
+ # Bulk insert
571
+ created_entries = NodeEntry.objects.bulk_create(
572
+ node_entries,
573
+ batch_size=self.vs_config.index.batch_insert_size
574
+ )
575
+
576
+ # Update node IDs
577
+ for i, node in enumerate(nodes):
578
+ node.metadata.node_id = created_entries[i].node_id
579
+
580
+ logger.info(
581
+ f"Created collection '{collection_name}' with {len(created_entries)} nodes")
582
+
583
+ @transaction.atomic
584
+ def create_collection_from_files(
585
+ self,
586
+ collection_name: str,
587
+ files: List[VSFile]
588
+ ) -> None:
589
+ """
590
+ Create collection from VSFile objects.
591
+
592
+ Args:
593
+ collection_name: Name of collection
594
+ files: List of VSFile objects
595
+ """
596
+ nodes = [node for file in files for node in file.nodes]
597
+ self.create_collection_from_nodes(collection_name, nodes)
598
+
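Illustrative ingestion sketch: create_collection_from_files only relies on each VSFile exposing a .nodes list (as in the comprehension above). How those files are produced, for example by the loaders and chunkers shipped in this package, is assumed, and load_corpus below is a hypothetical helper:

    files = load_corpus("/data/docs")  # hypothetical: returns List[VSFile] with .nodes populated
    store.create_collection_from_files("my_docs", files)

    # Later, incremental additions reuse the same model-specific tables:
    store.add_nodes("my_docs", extra_nodes)  # extra_nodes: List[Node]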
599
+ def add_nodes(self, collection_name: str, nodes: List[Node]) -> None:
600
+ """
601
+ Add nodes to existing collection.
602
+
603
+ Args:
604
+ collection_name: Name of collection
605
+ nodes: Nodes to add
606
+ """
607
+ if not nodes:
608
+ logger.warning("No nodes to add")
609
+ return
610
+
611
+ # Filter out nodes with None or empty content (these would cause embedding errors)
612
+ original_count = len(nodes)
613
+ nodes = [node for node in nodes if node.content is not None and str(
614
+ node.content).strip()]
615
+
616
+ if len(nodes) < original_count:
617
+ logger.warning(
618
+ f"Filtered out {original_count - len(nodes)} nodes with empty/None content")
619
+
620
+ if not nodes:
621
+ logger.warning("No valid nodes to add after filtering")
622
+ return
623
+
624
+ logger.info(
625
+ f"Adding {len(nodes)} nodes to collection '{collection_name}'")
626
+
627
+ # Generate embeddings BEFORE starting the database transaction
628
+ # This is critical for large datasets as embedding generation can take minutes,
629
+ # which would otherwise cause the DB connection to timeout
630
+ logger.info(
631
+ f"Preparing {len(nodes)} nodes for embedding generation...")
632
+ texts = [str(node.content) for node in nodes]
633
+ total_texts = len(texts)
634
+ logger.info(
635
+ f"Starting embedding generation for {total_texts} texts (this may take a while)...")
636
+ logger.info(
637
+ f"Model: {self.embedding_model.model_name}, Batch size: {self.embedding_model.batch_size}")
638
+ logger.info(
639
+ f"Expected batches: {(total_texts + self.embedding_model.batch_size - 1) // self.embedding_model.batch_size}")
640
+ embed_start_time = time.time()
641
+
642
+ logger.info("Calling embedding model encode_documents()...")
643
+ embeddings = self.embedding_model.encode_documents(texts)
644
+
645
+ embed_elapsed = time.time() - embed_start_time
646
+ avg_rate = total_texts / embed_elapsed if embed_elapsed > 0 else 0
647
+ logger.info(
648
+ f"✓ Embedding generation completed in {embed_elapsed:.1f}s ({avg_rate:.1f} texts/s)")
649
+ logger.info(f"Generated {len(embeddings)} embeddings")
650
+
651
+ # MEMORY OPTIMIZATION: Clear texts list after embedding generation
652
+ # The texts are no longer needed as embeddings are already computed
653
+ del texts
654
+ import gc
655
+ gc.collect()
656
+
657
+ # Now perform the database operations within a transaction
658
+ self._insert_nodes_with_embeddings(collection_name, nodes, embeddings)
659
+
660
+ # MEMORY OPTIMIZATION: Clear embeddings after insertion
661
+ del embeddings
662
+ gc.collect()
663
+
664
+ @transaction.atomic
665
+ def _insert_nodes_with_embeddings(self, collection_name: str, nodes: List[Node], embeddings: List) -> None:
666
+ """
667
+ Insert nodes with pre-computed embeddings into the database.
668
+
669
+ This is a separate method to allow embedding generation to happen
670
+ outside the database transaction, preventing connection timeouts
671
+ for large datasets.
672
+
673
+ Args:
674
+ collection_name: Name of collection
675
+ nodes: Nodes to insert
676
+ embeddings: Pre-computed embeddings for the nodes
677
+ """
678
+ collection = self.get_collection(collection_name)
679
+
680
+ if self.use_dimension_specific_tables:
681
+ # Use raw SQL with batch inserts when custom tables are in use
682
+ import json
683
+ from django.db import transaction
684
+
685
+ batch_size = self.vs_config.index.batch_insert_size
686
+ total_nodes = len(nodes)
687
+ total_batches = (total_nodes + batch_size - 1) // batch_size
688
+
689
+ logger.info(
690
+ f"Starting batch insert: {total_nodes} nodes in {total_batches} batches (batch_size={batch_size})")
691
+ insert_start_time = time.time()
692
+
693
+ # Process in batches with individual transactions to prevent connection timeout
694
+ # Each batch is committed separately to avoid long-running transactions
695
+ for batch_idx, batch_start in enumerate(range(0, total_nodes, batch_size)):
696
+ batch_start_time = time.time()
697
+ batch_end = min(batch_start + batch_size, total_nodes)
698
+ batch_nodes = nodes[batch_start:batch_end]
699
+ batch_embeddings = embeddings[batch_start:batch_end]
700
+
701
+ # Use atomic transaction for each batch
702
+ with transaction.atomic():
703
+ with connection.cursor() as cursor:
704
+ # Build batch insert values
705
+ values_list = []
706
+ params = []
707
+ for i, node in enumerate(batch_nodes):
708
+ custom_metadata_json = json.dumps(
709
+ node.metadata.custom or {})
710
+ # Convert embedding to string format for pgvector: "[1.0, 2.0, ...]"
711
+ embedding_values = batch_embeddings[i].tolist() if hasattr(
712
+ batch_embeddings[i], 'tolist') else list(batch_embeddings[i])
713
+ embedding_str = "[" + ",".join(str(x)
714
+ for x in embedding_values) + "]"
715
+ # Convert any dict values to JSON strings for psycopg2 compatibility
716
+ source_file_uuid = json.dumps(node.metadata.source_file_uuid) if isinstance(
717
+ node.metadata.source_file_uuid, dict) else node.metadata.source_file_uuid
718
+ position = json.dumps(node.metadata.position) if isinstance(
719
+ node.metadata.position, dict) else node.metadata.position
720
+ content = json.dumps(node.content) if isinstance(
721
+ node.content, dict) else node.content
722
+
723
+ values_list.append(
724
+ "(%s, %s, %s::vector, %s, %s, %s::jsonb, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)")
725
+ params.extend([
726
+ collection.id,
727
+ content,
728
+ embedding_str,
729
+ source_file_uuid,
730
+ position,
731
+ custom_metadata_json
732
+ ])
733
+
734
+ # Execute batch insert
735
+ cursor.execute(
736
+ f"""
737
+ INSERT INTO {self.table_nodeentry}
738
+ (collection_id, content, embedding, source_file_uuid, position, custom_metadata, created_at, updated_at)
739
+ VALUES {", ".join(values_list)}
740
+ RETURNING node_id
741
+ """,
742
+ params
743
+ )
744
+
745
+ # Get returned node IDs and update nodes
746
+ node_ids = cursor.fetchall()
747
+ for i, node in enumerate(batch_nodes):
748
+ node.metadata.node_id = node_ids[i][0]
749
+
750
+ # Transaction is committed here automatically when exiting the atomic() context
751
+ batch_elapsed = time.time() - batch_start_time
752
+
753
+ # Log progress for every batch or at milestones
754
+ current_batch = batch_idx + 1
755
+ if current_batch % 10 == 0 or current_batch == total_batches or batch_elapsed > 1.0:
756
+ total_elapsed = time.time() - insert_start_time
757
+ nodes_per_sec = batch_end / total_elapsed if total_elapsed > 0 else 0
758
+ eta_seconds = (total_nodes - batch_end) / \
759
+ nodes_per_sec if nodes_per_sec > 0 else 0
760
+ logger.info(
761
+ f"Insert progress: {batch_end}/{total_nodes} nodes "
762
+ f"({batch_end * 100 // total_nodes}%) | "
763
+ f"Batch {current_batch}/{total_batches} took {batch_elapsed:.2f}s | "
764
+ f"Rate: {nodes_per_sec:.0f} nodes/s | "
765
+ f"ETA: {eta_seconds:.0f}s"
766
+ )
767
+
768
+ total_time = time.time() - insert_start_time
769
+ logger.info(
770
+ f"Completed inserting {total_nodes} nodes to '{collection_name}' in {total_time:.2f}s ({total_nodes/total_time:.0f} nodes/s)")
771
+ else:
772
+ # Create entries using ORM
773
+ node_entries = [
774
+ NodeEntry(
775
+ collection=collection,
776
+ content=node.content,
777
+ embedding=embeddings[i],
778
+ source_file_uuid=node.metadata.source_file_uuid,
779
+ position=node.metadata.position,
780
+ custom_metadata=node.metadata.custom or {},
781
+ )
782
+ for i, node in enumerate(nodes)
783
+ ]
784
+
785
+ created_entries = NodeEntry.objects.bulk_create(
786
+ node_entries,
787
+ batch_size=self.vs_config.index.batch_insert_size
788
+ )
789
+
790
+ # Update node IDs
791
+ for i, node in enumerate(nodes):
792
+ node.metadata.node_id = created_entries[i].node_id
793
+
794
+ logger.info(
795
+ f"Added {len(created_entries)} nodes to '{collection_name}'")
796
+
797
+ @transaction.atomic
798
+ def update_vector(
799
+ self,
800
+ collection_name: str,
801
+ node_id: int,
802
+ new_content: Optional[str] = None,
803
+ new_embedding: Optional[List[float]] = None,
804
+ new_metadata: Optional[Dict[str, Any]] = None
805
+ ) -> None:
806
+ """
807
+ Update a vector in the collection.
808
+
809
+ Args:
810
+ collection_name: Name of collection
811
+ node_id: ID of node to update
812
+ new_content: New content (will regenerate embedding if provided)
813
+ new_embedding: New embedding vector (used if new_content not provided)
814
+ new_metadata: New metadata to merge with existing
815
+ """
816
+ collection = self.get_collection(collection_name)
817
+
818
+ if self.use_dimension_specific_tables:
819
+ # Use raw SQL when custom tables are in use
820
+ import json
821
+ with connection.cursor() as cursor:
822
+ # First, get the current node
823
+ cursor.execute(
824
+ f"""
825
+ SELECT content, embedding, custom_metadata
826
+ FROM {self.table_nodeentry}
827
+ WHERE collection_id = %s AND node_id = %s
828
+ """,
829
+ [collection.id, node_id]
830
+ )
831
+ row = cursor.fetchone()
832
+
833
+ if not row:
834
+ raise ValueError(
835
+ f"Node {node_id} not found in collection '{collection_name}'")
836
+
837
+ current_content, current_embedding, current_metadata = row
838
+
839
+ # Parse JSON metadata if it's a string
840
+ if isinstance(current_metadata, str):
841
+ current_metadata = json.loads(current_metadata)
842
+
843
+ # Determine updates
844
+ updated_content = current_content
845
+ updated_embedding = current_embedding
846
+ updated_metadata = current_metadata or {}
847
+
848
+ if new_content is not None:
849
+ updated_content = new_content
850
+ updated_embedding = self.embedding_model.encode_query(
851
+ new_content)
852
+ logger.info(
853
+ f"Updated content and regenerated embedding for node {node_id}")
854
+ elif new_embedding is not None:
855
+ updated_embedding = new_embedding
856
+ logger.info(f"Updated embedding for node {node_id}")
857
+
858
+ if new_metadata is not None:
859
+ updated_metadata.update(new_metadata)
860
+ logger.info(f"Updated metadata for node {node_id}")
861
+
862
+ # Update the node
863
+ updated_metadata_json = json.dumps(updated_metadata)
864
+ # Convert embedding to string format for pgvector: "[1.0, 2.0, ...]"
865
+ if updated_embedding is not None:
866
+ embedding_values = updated_embedding.tolist() if hasattr(
867
+ updated_embedding, 'tolist') else list(updated_embedding)
868
+ embedding_str = "[" + ",".join(str(x)
869
+ for x in embedding_values) + "]"
870
+ else:
871
+ embedding_str = None
872
+ cursor.execute(
873
+ f"""
874
+ UPDATE {self.table_nodeentry}
875
+ SET content = %s, embedding = %s::vector, custom_metadata = %s::jsonb, updated_at = CURRENT_TIMESTAMP
876
+ WHERE collection_id = %s AND node_id = %s
877
+ """,
878
+ [updated_content, embedding_str,
879
+ updated_metadata_json, collection.id, node_id]
880
+ )
881
+
882
+ logger.info(
883
+ f"Successfully updated node {node_id} in collection '{collection_name}'")
884
+ else:
885
+ try:
886
+ node_entry = NodeEntry.objects.get(
887
+ collection=collection, node_id=node_id)
888
+ except NodeEntry.DoesNotExist:
889
+ raise ValueError(
890
+ f"Node {node_id} not found in collection '{collection_name}'")
891
+
892
+ # Update content and embedding
893
+ if new_content is not None:
894
+ node_entry.content = new_content
895
+ # Generate new embedding
896
+ embedding = self.embedding_model.encode_query(new_content)
897
+ node_entry.embedding = embedding
898
+ logger.info(
899
+ f"Updated content and regenerated embedding for node {node_id}")
900
+ elif new_embedding is not None:
901
+ node_entry.embedding = new_embedding
902
+ logger.info(f"Updated embedding for node {node_id}")
903
+
904
+ # Update metadata
905
+ if new_metadata is not None:
906
+ # Merge with existing metadata
907
+ current_metadata = node_entry.custom_metadata or {}
908
+ current_metadata.update(new_metadata)
909
+ node_entry.custom_metadata = current_metadata
910
+ logger.info(f"Updated metadata for node {node_id}")
911
+
912
+ node_entry.save()
913
+ logger.info(
914
+ f"Successfully updated node {node_id} in collection '{collection_name}'")
915
+
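Usage sketch for the three update modes (node_id and the 384-dimension vector are illustrative values):

    # Re-embed from new text:
    store.update_vector("my_docs", node_id=42, new_content="Updated paragraph text.")

    # Replace the stored vector directly (must match the collection's model dimension):
    store.update_vector("my_docs", node_id=42, new_embedding=[0.01] * 384)

    # Merge extra keys into custom_metadata without touching content or embedding:
    store.update_vector("my_docs", node_id=42, new_metadata={"reviewed": True})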
916
+ @transaction.atomic
917
+ def delete_nodes(self, collection_name: str, node_ids: List[int]) -> None:
918
+ """
919
+ Delete nodes from collection.
920
+
921
+ Args:
922
+ collection_name: Name of collection
923
+ node_ids: List of node IDs to delete
924
+ """
925
+ if not node_ids:
926
+ logger.warning("No node IDs to delete")
927
+ return
928
+
929
+ collection = self.get_collection(collection_name)
930
+
931
+ if self.use_dimension_specific_tables:
932
+ # Use raw SQL when custom tables are in use
933
+ with connection.cursor() as cursor:
934
+ placeholders = ','.join(['%s'] * len(node_ids))
935
+ cursor.execute(
936
+ f"""
937
+ DELETE FROM {self.table_nodeentry}
938
+ WHERE collection_id = %s AND node_id IN ({placeholders})
939
+ """,
940
+ [collection.id] + list(node_ids)
941
+ )
942
+ deleted_count = cursor.rowcount
943
+ else:
944
+ deleted_count, _ = NodeEntry.objects.filter(
945
+ collection=collection,
946
+ node_id__in=node_ids
947
+ ).delete()
948
+
949
+ logger.info(
950
+ f"Deleted {deleted_count} nodes from collection '{collection_name}'")
951
+
952
+ @transaction.atomic
953
+ def delete_collection(self, collection_name: str) -> None:
954
+ """
955
+ Delete a collection and all its nodes.
956
+
957
+ Args:
958
+ collection_name: Name of collection to delete
959
+ """
960
+ collection = self.get_collection(collection_name)
961
+
962
+ if self.use_dimension_specific_tables:
963
+ # Use raw SQL when custom tables are in use
964
+ with connection.cursor() as cursor:
965
+ # Get node count first
966
+ cursor.execute(
967
+ f"SELECT COUNT(*) FROM {self.table_nodeentry} WHERE collection_id = %s",
968
+ [collection.id]
969
+ )
970
+ node_count = cursor.fetchone()[0]
971
+
972
+ # Delete nodes (cascade should handle this, but be explicit)
973
+ cursor.execute(
974
+ f"DELETE FROM {self.table_nodeentry} WHERE collection_id = %s",
975
+ [collection.id]
976
+ )
977
+
978
+ # Delete collection
979
+ cursor.execute(
980
+ f"DELETE FROM {self.table_collection} WHERE id = %s",
981
+ [collection.id]
982
+ )
983
+
984
+ logger.info(
985
+ f"Deleted collection '{collection_name}' with {node_count} nodes")
986
+ else:
987
+ node_count = NodeEntry.objects.filter(
988
+ collection=collection).count()
989
+ collection.delete()
990
+ logger.info(
991
+ f"Deleted collection '{collection_name}' with {node_count} nodes")
992
+
993
+ def keyword_search(
994
+ self,
995
+ collection_name: str,
996
+ query: str,
997
+ number: Optional[int] = None,
998
+ meta_data_filters: Optional[Dict[str, Any]] = None,
999
+ min_rank: float = 0.0,
1000
+ ranking_algorithm: Optional[str] = None
1001
+ ) -> Tuple[Dict, List[Node]]:
1002
+ """
1003
+ Perform pure keyword-based full-text search in collection.
1004
+
1005
+ Supports multiple ranking algorithms:
1006
+ - BM25: Okapi BM25 ranking function (default, state-of-the-art)
1007
+ - ts_rank: PostgreSQL's native full-text search ranking
1008
+
1009
+ Args:
1010
+ collection_name: Name of collection to search
1011
+ query: Search query (keywords)
1012
+ number: Number of results to return (uses config default if None)
1013
+ meta_data_filters: Optional metadata filters to apply
1014
+ min_rank: Minimum rank threshold (0.0 to 1.0, default 0.0)
1015
+ ranking_algorithm: Ranking algorithm to use ('bm25' or 'ts_rank', uses config default if None)
1016
+
1017
+ Returns:
1018
+ Tuple of (results dict, list of Node objects)
1019
+
1020
+ Example:
1021
+ # Using BM25 (default)
1022
+ results, nodes = store.keyword_search(
1023
+ collection_name="my_docs",
1024
+ query="machine learning algorithms",
1025
+ number=10,
1026
+ min_rank=0.01
1027
+ )
1028
+
1029
+ # Using ts_rank
1030
+ results, nodes = store.keyword_search(
1031
+ collection_name="my_docs",
1032
+ query="machine learning algorithms",
1033
+ number=10,
1034
+ ranking_algorithm="ts_rank"
1035
+ )
1036
+ """
1037
+ # Use config defaults
1038
+ if number is None:
1039
+ number = self.vs_config.search.default_top_k
1040
+ if ranking_algorithm is None:
1041
+ ranking_algorithm = self.vs_config.search.keyword_ranking_algorithm
1042
+
1043
+ logger.info(
1044
+ f"Keyword search ({ranking_algorithm}) in '{collection_name}' for: '{query}'")
1045
+
1046
+ # Get collection
1047
+ try:
1048
+ collection = self.get_collection(collection_name)
1049
+ except ValueError as e:
1050
+ logger.error(str(e))
1051
+ raise
1052
+
1053
+ # Use correct table name
1054
+ table_name = self.table_nodeentry if self.use_dimension_specific_tables else NodeEntry._meta.db_table
1055
+
1056
+ # Build WHERE clause with metadata filters
1057
+ where_conditions = ["collection_id = %s"]
1058
+ where_params = [collection.id]
1059
+
1060
+ if meta_data_filters:
1061
+ for key, value in meta_data_filters.items():
1062
+ where_conditions.append(f"custom_metadata->>'{key}' = %s")
1063
+ where_params.append(str(value))
1064
+
1065
+ where_clause = " AND ".join(where_conditions)
1066
+
1067
+ # Build SQL query based on ranking algorithm
1068
+ if ranking_algorithm == "bm25":
1069
+ # BM25 uses where_clause twice (doc_stats and collection_stats)
1070
+ # Parameter order: where_params (doc_stats), where_params (collection_stats), query, min_rank, limit
1071
+ sql_query = self._build_bm25_query(
1072
+ table_name, where_clause, query, min_rank, number)
1073
+ query_params = where_params + \
1074
+ where_params + [query, min_rank, number]
1075
+ else: # ts_rank
1076
+ # ts_rank uses where_clause once
1077
+ # Parameter order: query (SELECT), where_params (WHERE), query (AND), query (AND), min_rank, limit
1078
+ sql_query = self._build_tsrank_query(
1079
+ table_name, where_clause, query, min_rank, number)
1080
+ query_params = [query] + where_params + \
1081
+ [query, query, min_rank, number]
1082
+
1083
+ # Execute query
1084
+ with connection.cursor() as cursor:
1085
+ cursor.execute(sql_query, query_params)
1086
+ results = cursor.fetchall()
1087
+ columns = [col[0] for col in cursor.description]
1088
+
1089
+ # Process results
1090
+ valid_suggestions = {}
1091
+ suggested_nodes = []
1092
+ seen_texts = set()
1093
+
1094
+ for row in results:
1095
+ result_dict = dict(zip(columns, row))
1096
+ node_id = result_dict["node_id"]
1097
+ content = result_dict["content"]
1098
+ rank = result_dict["rank"]
1099
+
1100
+ if content not in seen_texts:
1101
+ seen_texts.add(content)
1102
+
1103
+ custom_metadata = result_dict["custom_metadata"] or {}
1104
+
1105
+ metadata = NodeMetadata(
1106
+ source_file_uuid=result_dict["source_file_uuid"],
1107
+ position=result_dict["position"],
1108
+ custom=custom_metadata,
1109
+ )
1110
+ metadata.node_id = node_id
1111
+
1112
+ node = Node(content=content, metadata=metadata)
1113
+ suggested_nodes.append(node)
1114
+
1115
+ valid_suggestions[str(node_id)] = (
1116
+ {
1117
+ "node_id": node_id,
1118
+ "source_file_uuid": result_dict["source_file_uuid"],
1119
+ "position": result_dict["position"],
1120
+ "custom": custom_metadata,
1121
+ },
1122
+ content,
1123
+ float(rank),
1124
+ )
1125
+
1126
+ logger.info(
1127
+ f"Keyword search returned {len(valid_suggestions)} results")
1128
+ return valid_suggestions, suggested_nodes
1129
+
1130
+ def _build_bm25_query(
1131
+ self,
1132
+ table_name: str,
1133
+ where_clause: str,
1134
+ query: str,
1135
+ min_rank: float,
1136
+ limit: int
1137
+ ) -> str:
1138
+ """
1139
+ Build BM25 ranking SQL query.
1140
+
1141
+ BM25 (Best Matching 25) is a ranking function used by search engines.
1142
+ It's based on the probabilistic retrieval framework and considers:
1143
+ - Term frequency (TF): How often query terms appear in the document
1144
+ - Inverse document frequency (IDF): How rare the terms are across all documents
1145
+ - Document length normalization: Adjusts for document length
1146
+
1147
+ Formula: BM25(D,Q) = Σ IDF(qi) * (f(qi,D) * (k1 + 1)) / (f(qi,D) + k1 * (1 - b + b * |D| / avgdl))
1148
+
1149
+ Where:
1150
+ - D: document
1151
+ - Q: query
1152
+ - qi: query term i
1153
+ - f(qi,D): frequency of qi in D
1154
+ - |D|: length of document D
1155
+ - avgdl: average document length in the collection
1156
+ - k1: term frequency saturation parameter (default: 1.5)
1157
+ - b: length normalization parameter (default: 0.75)
1158
+ """
1159
+ k1 = self.vs_config.search.bm25_k1
1160
+ b = self.vs_config.search.bm25_b
1161
+
1162
+ return f"""
1163
+ WITH doc_stats AS (
1164
+ -- Calculate document statistics
1165
+ SELECT
1166
+ node_id,
1167
+ content,
1168
+ source_file_uuid,
1169
+ position,
1170
+ custom_metadata,
1171
+ LENGTH(content) AS doc_length,
1172
+ to_tsvector('english', content) AS doc_vector
1173
+ FROM
1174
+ {table_name}
1175
+ WHERE
1176
+ {where_clause}
1177
+ ),
1178
+ collection_stats AS (
1179
+ -- Calculate collection-wide statistics
1180
+ SELECT
1181
+ AVG(LENGTH(content)) AS avg_doc_length,
1182
+ COUNT(*) AS total_docs
1183
+ FROM
1184
+ {table_name}
1185
+ WHERE
1186
+ {where_clause}
1187
+ ),
1188
+ query_terms AS (
1189
+ -- Extract query terms and calculate IDF
1190
+ SELECT
1191
+ word,
1192
+ -- IDF calculation: log((N - df + 0.5) / (df + 0.5) + 1)
1193
+ -- where N is total docs and df is document frequency
1194
+ LN(
1195
+ (cs.total_docs - COUNT(DISTINCT ds.node_id) + 0.5) /
1196
+ (COUNT(DISTINCT ds.node_id) + 0.5) + 1
1197
+ ) AS idf
1198
+ FROM
1199
+ unnest(string_to_array(lower(%s), ' ')) AS word,
1200
+ doc_stats ds,
1201
+ collection_stats cs
1202
+ WHERE
1203
+ ds.doc_vector @@ to_tsquery('english', word)
1204
+ GROUP BY
1205
+ word, cs.total_docs
1206
+ ),
1207
+ bm25_scores AS (
1208
+ -- Calculate BM25 score for each document
1209
+ SELECT
1210
+ ds.node_id,
1211
+ ds.content,
1212
+ ds.source_file_uuid,
1213
+ ds.position,
1214
+ ds.custom_metadata,
1215
+ SUM(
1216
+ qt.idf *
1217
+ (
1218
+ -- Term frequency component
1219
+ (ts_rank(ds.doc_vector, to_tsquery('english', qt.word)) * 1000 * ({k1} + 1)) /
1220
+ (
1221
+ ts_rank(ds.doc_vector, to_tsquery('english', qt.word)) * 1000 +
1222
+ {k1} * (1 - {b} + {b} * ds.doc_length / cs.avg_doc_length)
1223
+ )
1224
+ )
1225
+ ) AS bm25_score
1226
+ FROM
1227
+ doc_stats ds
1228
+ CROSS JOIN
1229
+ collection_stats cs
1230
+ CROSS JOIN
1231
+ query_terms qt
1232
+ WHERE
1233
+ ds.doc_vector @@ to_tsquery('english', qt.word)
1234
+ GROUP BY
1235
+ ds.node_id,
1236
+ ds.content,
1237
+ ds.source_file_uuid,
1238
+ ds.position,
1239
+ ds.custom_metadata
1240
+ )
1241
+ SELECT
1242
+ node_id,
1243
+ content,
1244
+ source_file_uuid,
1245
+ position,
1246
+ custom_metadata,
1247
+ COALESCE(bm25_score, 0.0) AS rank
1248
+ FROM
1249
+ bm25_scores
1250
+ WHERE
1251
+ COALESCE(bm25_score, 0.0) > %s
1252
+ ORDER BY
1253
+ rank DESC
1254
+ LIMIT
1255
+ %s
1256
+ """
1257
+
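For reference, a pure-Python restatement of the single-term BM25 contribution from the docstring above; this is a sanity-check sketch, not the code path used at query time (the SQL approximates term frequency with ts_rank * 1000 rather than a raw count):

    import math

    def bm25_term(tf: float, df: int, n_docs: int, doc_len: int, avg_len: float,
                  k1: float = 1.5, b: float = 0.75) -> float:
        """Contribution of one query term; sum over terms for the document score."""
        idf = math.log((n_docs - df + 0.5) / (df + 0.5) + 1)
        return idf * (tf * (k1 + 1)) / (tf + k1 * (1 - b + b * doc_len / avg_len))

    # A term occurring 3 times in a 120-word doc, present in 5 of 1000 docs (avg length 100):
    # bm25_term(3, 5, 1000, 120, 100.0) ≈ 5.20 * 1.59 ≈ 8.3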
1258
+ def _build_tsrank_query(
1259
+ self,
1260
+ table_name: str,
1261
+ where_clause: str,
1262
+ query: str,
1263
+ min_rank: float,
1264
+ limit: int
1265
+ ) -> str:
1266
+ """
1267
+ Build ts_rank SQL query.
1268
+
1269
+ Uses PostgreSQL's native full-text search ranking function.
1270
+ """
1271
+ return f"""
1272
+ SELECT
1273
+ node_id,
1274
+ content,
1275
+ source_file_uuid,
1276
+ position,
1277
+ custom_metadata,
1278
+ ts_rank(to_tsvector('english', content), plainto_tsquery('english', %s)) AS rank
1279
+ FROM
1280
+ {table_name}
1281
+ WHERE
1282
+ {where_clause}
1283
+ AND to_tsvector('english', content) @@ plainto_tsquery('english', %s)
1284
+ AND ts_rank(to_tsvector('english', content), plainto_tsquery('english', %s)) > %s
1285
+ ORDER BY
1286
+ rank DESC
1287
+ LIMIT
1288
+ %s
1289
+ """
1290
+
1291
+ def search(
1292
+ self,
1293
+ collection_name: str,
1294
+ query: str,
1295
+ distance_type: Optional[str] = None,
1296
+ number: Optional[int] = None,
1297
+ meta_data_filters: Optional[Dict[str, Any]] = None,
1298
+ hybrid_search: Optional[bool] = None
1299
+ ) -> Tuple[Dict, List[Node]]:
1300
+ """
1301
+ Search for similar vectors in collection.
1302
+
1303
+ Args:
1304
+ collection_name: Name of collection to search
1305
+ query: Search query
1306
+ distance_type: Distance metric (uses config default if None)
1307
+ number: Number of results (uses config default if None)
1308
+ meta_data_filters: Metadata filters
1309
+ hybrid_search: Enable hybrid search (uses config default if None)
1310
+
1311
+ Returns:
1312
+ Tuple of (results dict, list of Node objects)
1313
+ """
1314
+ # Use config defaults
1315
+ if distance_type is None:
1316
+ distance_type = self.vs_config.search.similarity_metric
1317
+ if number is None:
1318
+ number = self.vs_config.search.default_top_k
1319
+ if hybrid_search is None:
1320
+ hybrid_search = self.vs_config.search.enable_hybrid_search
1321
+
1322
+ logger.info(f"Searching in '{collection_name}' for: '{query}'")
1323
+
1324
+ # Get collection
1325
+ try:
1326
+ collection = self.get_collection(collection_name)
1327
+ except ValueError as e:
1328
+ logger.error(str(e))
1329
+ raise
1330
+
1331
+ # Get query embedding
1332
+ query_embedding = self._get_query_embedding(query)
1333
+
1334
+ # Normalize if using cosine distance
1335
+ if distance_type == "cosine":
1336
+ query_embedding = self._normalize_embedding(query_embedding)
1337
+
1338
+ # Determine search buffer
1339
+ search_buffer_factor = (
1340
+ self.vs_config.search.search_buffer_factor if hybrid_search else 1
1341
+ )
1342
+ limit = number * search_buffer_factor
1343
+
1344
+ # Build SQL query
1345
+ distance_operator = self._get_distance_operator(distance_type)
1346
+ embedding_str = "[" + ",".join(str(x) for x in query_embedding) + "]"
1347
+
1348
+ # Use correct table name
1349
+ table_name = self.table_nodeentry if self.use_dimension_specific_tables else NodeEntry._meta.db_table
1350
+
1351
+ # Build WHERE clause with metadata filters
1352
+ where_conditions = ["collection_id = %s"]
1353
+ query_params = [embedding_str, collection.id]
1354
+
1355
+ if meta_data_filters:
1356
+ for key, value in meta_data_filters.items():
1357
+ where_conditions.append(f"custom_metadata->>'{key}' = %s")
1358
+ query_params.append(str(value))
1359
+
1360
+ where_clause = " AND ".join(where_conditions)
1361
+ query_params.append(limit)
1362
+
1363
+ sql_query = f"""
1364
+ SELECT
1365
+ node_id,
1366
+ content,
1367
+ source_file_uuid,
1368
+ position,
1369
+ custom_metadata,
1370
+ embedding {distance_operator} %s::vector AS distance
1371
+ FROM
1372
+ {table_name}
1373
+ WHERE
1374
+ {where_clause}
1375
+ ORDER BY
1376
+ distance
1377
+ LIMIT
1378
+ %s
1379
+ """
1380
+
1381
+ # Execute query
1382
+ with connection.cursor() as cursor:
1383
+ cursor.execute(sql_query, query_params)
1384
+ results = cursor.fetchall()
1385
+ columns = [col[0] for col in cursor.description]
1386
+
1387
+ # Process results
1388
+ valid_suggestions = {}
1389
+ suggested_nodes = []
1390
+ seen_texts = set()
1391
+
1392
+ for row in results:
1393
+ result_dict = dict(zip(columns, row))
1394
+ node_id = result_dict["node_id"]
1395
+ content = result_dict["content"]
1396
+ distance = result_dict["distance"]
1397
+
1398
+ if content not in seen_texts:
1399
+ seen_texts.add(content)
1400
+
1401
+ custom_metadata = result_dict["custom_metadata"] or {}
1402
+
1403
+ metadata = NodeMetadata(
1404
+ source_file_uuid=result_dict["source_file_uuid"],
1405
+ position=result_dict["position"],
1406
+ custom=custom_metadata,
1407
+ )
1408
+ metadata.node_id = node_id
1409
+
1410
+ node = Node(content=content, metadata=metadata)
1411
+ suggested_nodes.append(node)
1412
+
1413
+ valid_suggestions[str(node_id)] = (
1414
+ {
1415
+ "node_id": node_id,
1416
+ "source_file_uuid": result_dict["source_file_uuid"],
1417
+ "position": result_dict["position"],
1418
+ "custom": custom_metadata,
1419
+ },
1420
+ content,
1421
+ float(distance),
1422
+ )
1423
+
1424
+ # Apply hybrid search and re-ranking if enabled
1425
+ if hybrid_search and self.vs_config.search.rerank and valid_suggestions:
1426
+ valid_suggestions, suggested_nodes = self._rerank_results(
1427
+ query, list(valid_suggestions.values()
1428
+ ), suggested_nodes, number
1429
+ )
1430
+
1431
+ logger.info(f"Search returned {len(valid_suggestions)} results")
1432
+ return valid_suggestions, suggested_nodes
1433
+
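Search usage sketch: metadata filters are matched against custom_metadata keys, and hybrid_search=True triggers the over-fetch plus ts_rank re-ranking path below (the "lang" key is illustrative):

    results, nodes = store.search(
        collection_name="my_docs",
        query="how does pgvector index embeddings?",
        number=5,
        meta_data_filters={"lang": "en"},  # compared via custom_metadata->>'lang'
        hybrid_search=True,
    )
    for node in nodes:
        print(node.metadata.node_id, node.content[:80])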
1434
+ def _rerank_results(
1435
+ self,
1436
+ query: str,
1437
+ results: List[Tuple[Dict, str, float]],
1438
+ suggested_nodes: List[Node],
1439
+ top_k: int,
1440
+ ) -> Tuple[Dict, List[Node]]:
1441
+ """Re-rank results using hybrid scoring."""
1442
+ logger.debug(f"Re-ranking {len(results)} results")
1443
+
1444
+ # Get hybrid alpha from config
1445
+ alpha = self.vs_config.search.hybrid_alpha
1446
+
1447
+ # Perform full-text search
1448
+ node_ids = [int(res[0]["node_id"]) for res in results]
1449
+
1450
+ if self.use_dimension_specific_tables:
1451
+ # Use raw SQL for full-text search with custom tables
1452
+ with connection.cursor() as cursor:
1453
+ placeholders = ','.join(['%s'] * len(node_ids))
1454
+ cursor.execute(
1455
+ f"""
1456
+ SELECT node_id,
1457
+ ts_rank(to_tsvector('english', content), plainto_tsquery('english', %s)) as rank
1458
+ FROM {self.table_nodeentry}
1459
+ WHERE node_id IN ({placeholders})
1460
+ """,
1461
+ [query] + node_ids
1462
+ )
1463
+ node_id_to_rank = {row[0]: row[1] for row in cursor.fetchall()}
1464
+ else:
1465
+ search_query = SearchQuery(query, config="english")
1466
+ queryset = NodeEntry.objects.filter(
1467
+ node_id__in=node_ids
1468
+ ).annotate(
1469
+ rank=SearchRank(SearchVector(
1470
+ "content", config="english"), search_query)
1471
+ )
1472
+ node_id_to_rank = {node.node_id: node.rank for node in queryset}
1473
+
1474
+ # Combine scores
1475
+ reranked_results = []
1476
+
1477
+ for metadata, content, distance in results:
1478
+ node_id = metadata["node_id"]
1479
+ keyword_score = node_id_to_rank.get(node_id, 0.0)
1480
+
1481
+ # Combined score: alpha * vector + (1-alpha) * keyword
1482
+ combined_score = alpha * (1 - distance) + \
1483
+ (1 - alpha) * keyword_score
1484
+ reranked_results.append((metadata, content, combined_score))
1485
+
1486
+ # Sort and take top_k
1487
+ reranked_results = sorted(
1488
+ reranked_results, key=lambda x: x[2], reverse=True)[:top_k]
1489
+ valid_suggestions = {
1490
+ str(res[0]["node_id"]): res for res in reranked_results}
1491
+
1492
+ # Update node order
1493
+ node_id_order = [res[0]["node_id"] for res in reranked_results]
1494
+ updated_nodes = sorted(
1495
+ suggested_nodes,
1496
+ key=lambda node: (
1497
+ node_id_order.index(node.metadata.node_id)
1498
+ if node.metadata.node_id in node_id_order
1499
+ else len(node_id_order)
1500
+ ),
1501
+ )[:top_k]
1502
+
1503
+ return valid_suggestions, updated_nodes
1504
+
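A worked instance of the blended score computed above (the alpha value is illustrative; the real one comes from vs_config.search.hybrid_alpha):

    alpha = 0.7            # weight on the vector stage
    distance = 0.25        # cosine distance (smaller = closer)
    keyword_score = 0.12   # ts_rank from the keyword stage
    combined = alpha * (1 - distance) + (1 - alpha) * keyword_score
    # = 0.7 * 0.75 + 0.3 * 0.12 = 0.525 + 0.036 = 0.561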
1505
+ def list_collections(self) -> List[str]:
1506
+ """List all collections."""
1507
+ if self.use_dimension_specific_tables:
1508
+ with connection.cursor() as cursor:
1509
+ cursor.execute(
1510
+ f"SELECT name FROM {self.table_collection} ORDER BY name")
1511
+ return [row[0] for row in cursor.fetchall()]
1512
+ else:
1513
+ return list(Collection.objects.values_list("name", flat=True))
1514
+
1515
+ def get_collection_info(self, collection_name: str) -> Dict[str, Any]:
1516
+ """Get information about a collection."""
1517
+ collection = self.get_collection(collection_name)
1518
+
1519
+ if self.use_dimension_specific_tables:
1520
+ with connection.cursor() as cursor:
1521
+ cursor.execute(
1522
+ f"SELECT COUNT(*) FROM {self.table_nodeentry} WHERE collection_id = %s",
1523
+ [collection.id]
1524
+ )
1525
+ node_count = cursor.fetchone()[0]
1526
+ else:
1527
+ node_count = NodeEntry.objects.filter(
1528
+ collection=collection).count()
1529
+ return {
1530
+ "name": collection.name,
1531
+ "embedding_dim": collection.embedding_dim,
1532
+ "node_count": node_count,
1533
+ "created_at": collection.created_at,
1534
+ "updated_at": collection.updated_at,
1535
+ }
1536
+
1537
+ # VectorStore interface methods
1538
+ def add(self, vectors: List[List[float]], metadatas: List[Dict[str, Any]]) -> Any:
1539
+ """Add vectors with metadata (VectorStore interface)."""
1540
+ if not vectors or not metadatas:
1541
+ logger.warning("Empty vectors or metadatas")
1542
+ return []
1543
+
1544
+ if len(vectors) != len(metadatas):
1545
+ raise ValueError(
1546
+ "Number of vectors must match number of metadatas")
1547
+
1548
+ collection_name = metadatas[0].get(
1549
+ "collection_name", "default_collection")
1550
+ collection = self.get_or_create_collection(collection_name)
1551
+
1552
+ if self.use_dimension_specific_tables:
1553
+ # Use raw SQL when custom tables are in use
1554
+ import json
1555
+ node_ids = []
1556
+ with connection.cursor() as cursor:
1557
+ for i, (vector, metadata) in enumerate(zip(vectors, metadatas)):
1558
+ custom_metadata = {
1559
+ k: v
1560
+ for k, v in metadata.items()
1561
+ if k not in ["content", "source_file_uuid", "position", "collection_name"]
1562
+ }
1563
+ custom_metadata_json = json.dumps(custom_metadata)
1564
+
1565
+ cursor.execute(
1566
+ f"""
1567
+ INSERT INTO {self.table_nodeentry}
1568
+ (collection_id, content, embedding, source_file_uuid, position, custom_metadata, created_at, updated_at)
1569
+ VALUES (%s, %s, %s, %s, %s, %s::jsonb, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
1570
+ RETURNING node_id
1571
+ """,
1572
+ [
1573
+ collection.id,
1574
+ metadata.get("content", ""),
1575
+ vector,
1576
+ metadata.get("source_file_uuid", ""),
1577
+ metadata.get("position", i),
1578
+ custom_metadata_json
1579
+ ]
1580
+ )
1581
+ node_id = cursor.fetchone()[0]
1582
+ node_ids.append(node_id)
1583
+ return node_ids
1584
+ else:
1585
+ node_entries = []
1586
+ for i, (vector, metadata) in enumerate(zip(vectors, metadatas)):
1587
+ node_entries.append(
1588
+ NodeEntry(
1589
+ collection=collection,
1590
+ content=metadata.get("content", ""),
1591
+ embedding=vector,
1592
+ source_file_uuid=metadata.get("source_file_uuid", ""),
1593
+ position=metadata.get("position", i),
1594
+ custom_metadata={
1595
+ k: v
1596
+ for k, v in metadata.items()
1597
+ if k not in ["content", "source_file_uuid", "position", "collection_name"]
1598
+ },
1599
+ )
1600
+ )
1601
+
1602
+ created_entries = NodeEntry.objects.bulk_create(node_entries)
1603
+ return [entry.node_id for entry in created_entries]
1604
+
1605
+ def query(self, vector: List[float], top_k: int = 5, **kwargs) -> List[Dict[str, Any]]:
1606
+ """Query vector store (VectorStore interface)."""
1607
+ collection_name = kwargs.get("collection_name", "default_collection")
1608
+ distance_type = kwargs.get(
1609
+ "distance_type", self.vs_config.search.similarity_metric)
1610
+
1611
+ try:
1612
+ collection = self.get_collection(collection_name)
1613
+ except ValueError:
1614
+ logger.warning(f"Collection '{collection_name}' not found")
1615
+ return []
1616
+
1617
+ # Normalize query vector if needed
1618
+ query_embedding = np.array(vector, dtype=np.float32)
1619
+ if distance_type == "cosine":
1620
+ query_embedding = self._normalize_embedding(query_embedding)
1621
+
1622
+ # Build and execute query
1623
+ distance_operator = self._get_distance_operator(distance_type)
1624
+ embedding_str = "[" + ",".join(str(x) for x in query_embedding) + "]"
1625
+
1626
+ # Use correct table name
1627
+ table_name = self.table_nodeentry if self.use_dimension_specific_tables else NodeEntry._meta.db_table
1628
+
1629
+ sql_query = f"""
1630
+ SELECT node_id, content, source_file_uuid, position, custom_metadata,
1631
+ embedding {distance_operator} %s::vector AS distance
1632
+ FROM {table_name}
1633
+ WHERE collection_id = %s
1634
+ ORDER BY distance
1635
+ LIMIT %s
1636
+ """
1637
+
1638
+ with connection.cursor() as cursor:
1639
+ cursor.execute(sql_query, [embedding_str, collection.id, top_k])
1640
+ results = cursor.fetchall()
1641
+ columns = [col[0] for col in cursor.description]
1642
+
1643
+ return [
1644
+ {
1645
+ "node_id": dict(zip(columns, row))["node_id"],
1646
+ "content": dict(zip(columns, row))["content"],
1647
+ "metadata": dict(zip(columns, row))["custom_metadata"] or {},
1648
+ "distance": float(dict(zip(columns, row))["distance"]),
1649
+ }
1650
+ for row in results
1651
+ ]
1652
+
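Sketch of the generic VectorStore interface with caller-supplied vectors (dimensions and metadata values are illustrative; content, source_file_uuid, and position are the keys this implementation pulls out of each metadata dict):

    vecs = [[0.1] * 384, [0.2] * 384]
    metas = [
        {"collection_name": "raw_vectors", "content": "first chunk",
         "source_file_uuid": "f-1", "position": 0},
        {"collection_name": "raw_vectors", "content": "second chunk",
         "source_file_uuid": "f-1", "position": 1},
    ]
    node_ids = store.add(vecs, metas)
    hits = store.query([0.1] * 384, top_k=2, collection_name="raw_vectors")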
1653
+ def shutdown(self) -> None:
1654
+ """Shutdown and cleanup resources."""
1655
+ logger.info("Shutting down ConfigurablePgVectorStore")
1656
+ if self.embedding_model:
1657
+ self.embedding_model.shutdown()
1658
+ super().shutdown()
1659
+
1660
+
1661
+ __all__ = ["ConfigurablePgVectorStore"]