rakam-systems-vectorstore 0.1.1rc7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rakam_systems_vectorstore/MANIFEST.in +26 -0
- rakam_systems_vectorstore/README.md +1071 -0
- rakam_systems_vectorstore/__init__.py +93 -0
- rakam_systems_vectorstore/components/__init__.py +0 -0
- rakam_systems_vectorstore/components/chunker/__init__.py +19 -0
- rakam_systems_vectorstore/components/chunker/advanced_chunker.py +1019 -0
- rakam_systems_vectorstore/components/chunker/text_chunker.py +154 -0
- rakam_systems_vectorstore/components/embedding_model/__init__.py +0 -0
- rakam_systems_vectorstore/components/embedding_model/configurable_embeddings.py +546 -0
- rakam_systems_vectorstore/components/embedding_model/openai_embeddings.py +259 -0
- rakam_systems_vectorstore/components/loader/__init__.py +31 -0
- rakam_systems_vectorstore/components/loader/adaptive_loader.py +512 -0
- rakam_systems_vectorstore/components/loader/code_loader.py +699 -0
- rakam_systems_vectorstore/components/loader/doc_loader.py +812 -0
- rakam_systems_vectorstore/components/loader/eml_loader.py +556 -0
- rakam_systems_vectorstore/components/loader/html_loader.py +626 -0
- rakam_systems_vectorstore/components/loader/md_loader.py +622 -0
- rakam_systems_vectorstore/components/loader/odt_loader.py +750 -0
- rakam_systems_vectorstore/components/loader/pdf_loader.py +771 -0
- rakam_systems_vectorstore/components/loader/pdf_loader_light.py +723 -0
- rakam_systems_vectorstore/components/loader/tabular_loader.py +597 -0
- rakam_systems_vectorstore/components/vectorstore/__init__.py +0 -0
- rakam_systems_vectorstore/components/vectorstore/apps.py +10 -0
- rakam_systems_vectorstore/components/vectorstore/configurable_pg_vector_store.py +1661 -0
- rakam_systems_vectorstore/components/vectorstore/faiss_vector_store.py +878 -0
- rakam_systems_vectorstore/components/vectorstore/migrations/0001_initial.py +55 -0
- rakam_systems_vectorstore/components/vectorstore/migrations/__init__.py +0 -0
- rakam_systems_vectorstore/components/vectorstore/models.py +10 -0
- rakam_systems_vectorstore/components/vectorstore/pg_models.py +97 -0
- rakam_systems_vectorstore/components/vectorstore/pg_vector_store.py +827 -0
- rakam_systems_vectorstore/config.py +266 -0
- rakam_systems_vectorstore/core.py +8 -0
- rakam_systems_vectorstore/pyproject.toml +113 -0
- rakam_systems_vectorstore/server/README.md +290 -0
- rakam_systems_vectorstore/server/__init__.py +20 -0
- rakam_systems_vectorstore/server/mcp_server_vector.py +325 -0
- rakam_systems_vectorstore/setup.py +103 -0
- rakam_systems_vectorstore-0.1.1rc7.dist-info/METADATA +370 -0
- rakam_systems_vectorstore-0.1.1rc7.dist-info/RECORD +40 -0
- rakam_systems_vectorstore-0.1.1rc7.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,1661 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Configurable PostgreSQL Vector Store with enhanced features.
|
|
3
|
+
|
|
4
|
+
This module provides an enhanced, fully configurable PgVectorStore that:
|
|
5
|
+
- Supports configuration via YAML/JSON files or dictionaries
|
|
6
|
+
- Allows pluggable embedding models
|
|
7
|
+
- Provides update_vector capability
|
|
8
|
+
- Maintains clean separation from other components
|
|
9
|
+
- Supports all search configurations
|
|
10
|
+
- **Dimension-agnostic vector storage**: No need to recreate tables when switching models!
|
|
11
|
+
|
|
12
|
+
## Flexible Vector Storage
|
|
13
|
+
|
|
14
|
+
Vector columns are created WITHOUT dimension constraints, allowing you to:
|
|
15
|
+
✓ Switch between embedding models without altering database schema
|
|
16
|
+
✓ Store vectors of any dimension in the same table structure
|
|
17
|
+
✓ No automatic table recreation or data loss
|
|
18
|
+
✓ Simplified database management
|
|
19
|
+
|
|
20
|
+
## Multi-Model Support
|
|
21
|
+
|
|
22
|
+
By default (use_dimension_specific_tables=True), each embedding model automatically
|
|
23
|
+
gets its own dedicated tables based on the model name:
|
|
24
|
+
|
|
25
|
+
- 'all-MiniLM-L6-v2' → application_nodeentry_all_minilm_l6_v2
|
|
26
|
+
- 'multi-qa-mpnet-base-cos-v1' → application_nodeentry_multi_qa_mpnet_base_cos_v1
|
|
27
|
+
- 'text-embedding-ada-002' → application_nodeentry_text_embedding_ada_002
|
|
28
|
+
|
|
29
|
+
**Why model-specific tables?**
|
|
30
|
+
|
|
31
|
+
Even if two models have the same dimensions (e.g., both 384D), their vector spaces
|
|
32
|
+
are completely different! Mixing embeddings from different models would give
|
|
33
|
+
meaningless results.
|
|
34
|
+
|
|
35
|
+
Example:
|
|
36
|
+
- Model A: 'all-MiniLM-L6-v2' (384D)
|
|
37
|
+
- Model B: 'paraphrase-MiniLM-L3-v2' (384D)
|
|
38
|
+
|
|
39
|
+
These produce vectors in DIFFERENT semantic spaces. You cannot:
|
|
40
|
+
❌ Search Model A embeddings using Model B query vectors
|
|
41
|
+
❌ Store both in the same table and expect meaningful results
|
|
42
|
+
|
|
43
|
+
This allows you to:
|
|
44
|
+
✓ Use multiple models simultaneously (each in its own vector space)
|
|
45
|
+
✓ Prevent accidental mixing of incompatible vector spaces
|
|
46
|
+
✓ No manual table management needed
|
|
47
|
+
|
|
48
|
+
Example:
|
|
49
|
+
# Safe by default - each model uses its own tables
|
|
50
|
+
store_mini = ConfigurablePgVectorStore(
|
|
51
|
+
config=config_minilm # Uses all-MiniLM-L6-v2
|
|
52
|
+
)
|
|
53
|
+
|
|
54
|
+
store_mpnet = ConfigurablePgVectorStore(
|
|
55
|
+
config=config_mpnet # Uses multi-qa-mpnet-base-cos-v1
|
|
56
|
+
)
|
|
57
|
+
# Both can coexist without conflicts or vector space mixing!
|
|
58
|
+
|
|
59
|
+
For shared table behavior, set use_dimension_specific_tables=False.
|
|
60
|
+
"""
|
|
61
|
+
|
|
62
|
+
from __future__ import annotations
|
|
63
|
+
|
|
64
|
+
import time
|
|
65
|
+
from functools import lru_cache
|
|
66
|
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
|
67
|
+
|
|
68
|
+
import numpy as np
|
|
69
|
+
from django.contrib.postgres.search import SearchQuery, SearchRank, SearchVector
|
|
70
|
+
from django.db import connection, transaction
|
|
71
|
+
|
|
72
|
+
from rakam_systems_core.ai_utils import logging
|
|
73
|
+
from rakam_systems_core.ai_core.interfaces.vectorstore import VectorStore
|
|
74
|
+
from rakam_systems_vectorstore.components.embedding_model.configurable_embeddings import ConfigurableEmbeddings
|
|
75
|
+
from rakam_systems_vectorstore.components.vectorstore.pg_models import Collection, NodeEntry
|
|
76
|
+
from rakam_systems_vectorstore.config import VectorStoreConfig, load_config
|
|
77
|
+
from rakam_systems_vectorstore.core import Node, NodeMetadata, VSFile
|
|
78
|
+
|
|
79
|
+
logger = logging.getLogger(__name__)
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
class ConfigurablePgVectorStore(VectorStore):
|
|
83
|
+
"""
|
|
84
|
+
Enhanced PostgreSQL Vector Store with full configuration support.
|
|
85
|
+
|
|
86
|
+
Features:
|
|
87
|
+
- Configuration via YAML/JSON or dict
|
|
88
|
+
- Pluggable embedding models
|
|
89
|
+
- Configurable similarity metrics
|
|
90
|
+
- Hybrid search with configurable weights
|
|
91
|
+
- Update operations for vectors
|
|
92
|
+
- Comprehensive metadata filtering
|
|
93
|
+
"""
|
|
94
|
+
|
|
95
|
+
def __init__(
|
|
96
|
+
self,
|
|
97
|
+
name: str = "configurable_pg_vector_store",
|
|
98
|
+
config: Optional[Union[VectorStoreConfig, Dict, str]] = None,
|
|
99
|
+
auto_recreate_on_dimension_mismatch: bool = False,
|
|
100
|
+
use_dimension_specific_tables: bool = True
|
|
101
|
+
):
|
|
102
|
+
"""
|
|
103
|
+
Initialize configurable PostgreSQL vector store.
|
|
104
|
+
|
|
105
|
+
Args:
|
|
106
|
+
name: Component name
|
|
107
|
+
config: Configuration (VectorStoreConfig object, dict, or path to config file)
|
|
108
|
+
auto_recreate_on_dimension_mismatch: DEPRECATED - No longer used. Vector columns now
|
|
109
|
+
support any dimension without schema changes.
|
|
110
|
+
use_dimension_specific_tables: If True, each embedding model gets its own dedicated tables
|
|
111
|
+
based on the model name, preventing:
|
|
112
|
+
- Mixing incompatible vector spaces
|
|
113
|
+
- Meaningless search results from mixed embeddings
|
|
114
|
+
DEFAULT: True (STRONGLY recommended)
|
|
115
|
+
|
|
116
|
+
Important:
|
|
117
|
+
Even models with the same dimensions produce vectors in different semantic spaces!
|
|
118
|
+
For example, 'all-MiniLM-L6-v2' and 'paraphrase-MiniLM-L3-v2' are both 384D,
|
|
119
|
+
but their vectors are NOT compatible. Always use model-specific tables.
|
|
120
|
+
"""
|
|
121
|
+
# Load configuration
|
|
122
|
+
if isinstance(config, VectorStoreConfig):
|
|
123
|
+
self.vs_config = config
|
|
124
|
+
elif isinstance(config, dict):
|
|
125
|
+
self.vs_config = VectorStoreConfig.from_dict(config)
|
|
126
|
+
elif isinstance(config, str):
|
|
127
|
+
# Path to config file
|
|
128
|
+
self.vs_config = load_config(config)
|
|
129
|
+
else:
|
|
130
|
+
# Use defaults
|
|
131
|
+
self.vs_config = VectorStoreConfig()
|
|
132
|
+
|
|
133
|
+
# Validate configuration
|
|
134
|
+
self.vs_config.validate()
|
|
135
|
+
|
|
136
|
+
# Initialize base component
|
|
137
|
+
super().__init__(name=name, config=self.vs_config.to_dict())
|
|
138
|
+
|
|
139
|
+
# Setup logging
|
|
140
|
+
if self.vs_config.enable_logging:
|
|
141
|
+
logging.basicConfig(level=self.vs_config.log_level)
|
|
142
|
+
|
|
143
|
+
# Store configuration
|
|
144
|
+
self.auto_recreate_on_dimension_mismatch = auto_recreate_on_dimension_mismatch
|
|
145
|
+
self.use_dimension_specific_tables = use_dimension_specific_tables
|
|
146
|
+
|
|
147
|
+
# Table names will be set after we know the embedding dimension
|
|
148
|
+
self.table_collection = "application_collection"
|
|
149
|
+
self.table_nodeentry = "application_nodeentry"
|
|
150
|
+
|
|
151
|
+
# Initialize embedding model
|
|
152
|
+
self.embedding_model = ConfigurableEmbeddings(
|
|
153
|
+
name=f"{name}_embeddings",
|
|
154
|
+
config=self.vs_config.embedding
|
|
155
|
+
)
|
|
156
|
+
|
|
157
|
+
self.embedding_dim: Optional[int] = None
|
|
158
|
+
|
|
159
|
+
logger.info(f"Initialized {name} with config: {self.vs_config.name}")
|
|
160
|
+
|
|
161
|
+
def setup(self) -> None:
|
|
162
|
+
"""Initialize resources and connections."""
|
|
163
|
+
# Skip if already initialized
|
|
164
|
+
if self.initialized:
|
|
165
|
+
logger.debug(
|
|
166
|
+
"ConfigurablePgVectorStore already initialized, skipping setup")
|
|
167
|
+
return
|
|
168
|
+
|
|
169
|
+
logger.info("Setting up ConfigurablePgVectorStore...")
|
|
170
|
+
|
|
171
|
+
# Ensure pgvector extension
|
|
172
|
+
self._ensure_pgvector_extension()
|
|
173
|
+
|
|
174
|
+
# Setup embedding model (will skip if already initialized)
|
|
175
|
+
self.embedding_model.setup()
|
|
176
|
+
self.embedding_dim = self.embedding_model.embedding_dimension
|
|
177
|
+
|
|
178
|
+
# Set table names based on model if using model-specific tables
|
|
179
|
+
if self.use_dimension_specific_tables:
|
|
180
|
+
# Create a safe table suffix from model name
|
|
181
|
+
# Each model gets its own table because even same-dimension models
|
|
182
|
+
# have different vector spaces!
|
|
183
|
+
model_name = self.vs_config.embedding.model_name
|
|
184
|
+
safe_model_name = self._sanitize_model_name(model_name)
|
|
185
|
+
|
|
186
|
+
self.table_collection = f"application_collection_{safe_model_name}"
|
|
187
|
+
self.table_nodeentry = f"application_nodeentry_{safe_model_name}"
|
|
188
|
+
logger.info(
|
|
189
|
+
f"Using model-specific tables for '{model_name}' ({self.embedding_dim}D): "
|
|
190
|
+
f"collection={self.table_collection}, "
|
|
191
|
+
f"nodeentry={self.table_nodeentry}"
|
|
192
|
+
)
|
|
193
|
+
|
|
194
|
+
# Ensure the required tables exist
|
|
195
|
+
self._ensure_vector_dimension_compatibility()
|
|
196
|
+
|
|
197
|
+
logger.info(
|
|
198
|
+
f"Vector store ready with embedding dimension: {self.embedding_dim}")
|
|
199
|
+
super().setup()
|
|
200
|
+
|
|
201
|
+
def _sanitize_model_name(self, model_name: str) -> str:
|
|
202
|
+
"""
|
|
203
|
+
Convert model name to a safe table suffix.
|
|
204
|
+
|
|
205
|
+
Examples:
|
|
206
|
+
'all-MiniLM-L6-v2' -> 'all_minilm_l6_v2'
|
|
207
|
+
'sentence-transformers/multi-qa-mpnet-base-cos-v1' -> 'multi_qa_mpnet_base_cos_v1'
|
|
208
|
+
'text-embedding-ada-002' -> 'text_embedding_ada_002'
|
|
209
|
+
"""
|
|
210
|
+
import re
|
|
211
|
+
|
|
212
|
+
# Remove common prefixes
|
|
213
|
+
name = model_name.replace('sentence-transformers/', '')
|
|
214
|
+
name = name.replace('models/', '')
|
|
215
|
+
|
|
216
|
+
# Replace non-alphanumeric with underscore
|
|
217
|
+
name = re.sub(r'[^a-zA-Z0-9]', '_', name)
|
|
218
|
+
|
|
219
|
+
# Convert to lowercase
|
|
220
|
+
name = name.lower()
|
|
221
|
+
|
|
222
|
+
# Remove consecutive underscores
|
|
223
|
+
name = re.sub(r'_+', '_', name)
|
|
224
|
+
|
|
225
|
+
# Remove leading/trailing underscores
|
|
226
|
+
name = name.strip('_')
|
|
227
|
+
|
|
228
|
+
# Limit length (PostgreSQL identifier limit is 63 chars)
|
|
229
|
+
if len(name) > 40:
|
|
230
|
+
# Keep last 40 chars (usually has version info)
|
|
231
|
+
name = name[-40:]
|
|
232
|
+
name = name.lstrip('_')
|
|
233
|
+
|
|
234
|
+
return name
|
|
235
|
+
|
|
236
|
+
def _ensure_pgvector_extension(self) -> None:
|
|
237
|
+
"""Ensure pgvector extension is installed."""
|
|
238
|
+
with connection.cursor() as cursor:
|
|
239
|
+
try:
|
|
240
|
+
cursor.execute("CREATE EXTENSION IF NOT EXISTS vector;")
|
|
241
|
+
logger.info("Ensured pgvector extension is installed")
|
|
242
|
+
except Exception as e:
|
|
243
|
+
logger.error(f"Failed to create pgvector extension: {e}")
|
|
244
|
+
raise
|
|
245
|
+
|
|
246
|
+
def _ensure_vector_dimension_compatibility(self) -> None:
|
|
247
|
+
"""
|
|
248
|
+
Ensures that the required tables exist.
|
|
249
|
+
|
|
250
|
+
Note: Vector columns are created without dimension constraints, allowing
|
|
251
|
+
flexibility to store vectors of any dimension without needing to alter
|
|
252
|
+
the database schema when switching embedding models.
|
|
253
|
+
"""
|
|
254
|
+
with connection.cursor() as cursor:
|
|
255
|
+
try:
|
|
256
|
+
# First ensure the collection table exists
|
|
257
|
+
cursor.execute(f"""
|
|
258
|
+
CREATE TABLE IF NOT EXISTS {self.table_collection} (
|
|
259
|
+
id SERIAL PRIMARY KEY,
|
|
260
|
+
name VARCHAR(255) UNIQUE NOT NULL,
|
|
261
|
+
embedding_dim INTEGER NOT NULL DEFAULT 384,
|
|
262
|
+
created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
|
|
263
|
+
updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP
|
|
264
|
+
);
|
|
265
|
+
""")
|
|
266
|
+
|
|
267
|
+
# Check if the nodeentry table exists
|
|
268
|
+
cursor.execute(f"""
|
|
269
|
+
SELECT EXISTS (
|
|
270
|
+
SELECT FROM information_schema.tables
|
|
271
|
+
WHERE table_name = '{self.table_nodeentry}'
|
|
272
|
+
);
|
|
273
|
+
""")
|
|
274
|
+
table_exists = cursor.fetchone()[0]
|
|
275
|
+
|
|
276
|
+
if not table_exists:
|
|
277
|
+
# Table doesn't exist, create it without dimension constraint
|
|
278
|
+
logger.info(
|
|
279
|
+
f"Creating new table '{self.table_nodeentry}' (supports any vector dimension)...")
|
|
280
|
+
cursor.execute(f"""
|
|
281
|
+
CREATE TABLE {self.table_nodeentry} (
|
|
282
|
+
node_id SERIAL PRIMARY KEY,
|
|
283
|
+
collection_id INTEGER NOT NULL REFERENCES {self.table_collection}(id) ON DELETE CASCADE,
|
|
284
|
+
content TEXT NOT NULL,
|
|
285
|
+
embedding vector,
|
|
286
|
+
source_file_uuid VARCHAR(255) NOT NULL,
|
|
287
|
+
position INTEGER,
|
|
288
|
+
custom_metadata JSONB DEFAULT '{{}}'::jsonb,
|
|
289
|
+
created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
|
|
290
|
+
updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP
|
|
291
|
+
);
|
|
292
|
+
""")
|
|
293
|
+
|
|
294
|
+
# Create indexes
|
|
295
|
+
cursor.execute(f"""
|
|
296
|
+
CREATE INDEX {self.table_nodeentry}_source_idx
|
|
297
|
+
ON {self.table_nodeentry}(source_file_uuid);
|
|
298
|
+
""")
|
|
299
|
+
cursor.execute(f"""
|
|
300
|
+
CREATE INDEX {self.table_nodeentry}_collect_idx
|
|
301
|
+
ON {self.table_nodeentry}(collection_id, source_file_uuid);
|
|
302
|
+
""")
|
|
303
|
+
|
|
304
|
+
logger.info(
|
|
305
|
+
f"✓ Created table '{self.table_nodeentry}' (dimension-agnostic)")
|
|
306
|
+
else:
|
|
307
|
+
logger.info(
|
|
308
|
+
f"✓ Table '{self.table_nodeentry}' already exists")
|
|
309
|
+
|
|
310
|
+
except Exception as e:
|
|
311
|
+
logger.error(f"Failed to ensure table exists: {e}")
|
|
312
|
+
raise
|
|
313
|
+
|
|
314
|
+
def get_or_create_collection(
|
|
315
|
+
self,
|
|
316
|
+
collection_name: str,
|
|
317
|
+
embedding_dim: Optional[int] = None
|
|
318
|
+
) -> Collection:
|
|
319
|
+
"""
|
|
320
|
+
Get or create a collection.
|
|
321
|
+
|
|
322
|
+
Args:
|
|
323
|
+
collection_name: Name of the collection
|
|
324
|
+
embedding_dim: Embedding dimension (uses model dimension if not specified)
|
|
325
|
+
|
|
326
|
+
Returns:
|
|
327
|
+
Collection object
|
|
328
|
+
"""
|
|
329
|
+
if embedding_dim is None:
|
|
330
|
+
embedding_dim = self.embedding_dim
|
|
331
|
+
|
|
332
|
+
# Use raw SQL when custom table names are in use
|
|
333
|
+
if self.use_dimension_specific_tables:
|
|
334
|
+
with connection.cursor() as cursor:
|
|
335
|
+
# Try to get existing collection
|
|
336
|
+
cursor.execute(
|
|
337
|
+
f"""
|
|
338
|
+
SELECT id, name, embedding_dim, created_at, updated_at
|
|
339
|
+
FROM {self.table_collection}
|
|
340
|
+
WHERE name = %s
|
|
341
|
+
""",
|
|
342
|
+
[collection_name]
|
|
343
|
+
)
|
|
344
|
+
row = cursor.fetchone()
|
|
345
|
+
|
|
346
|
+
if row:
|
|
347
|
+
# Collection exists - create a Collection object manually
|
|
348
|
+
collection = Collection(
|
|
349
|
+
id=row[0],
|
|
350
|
+
name=row[1],
|
|
351
|
+
embedding_dim=row[2],
|
|
352
|
+
created_at=row[3],
|
|
353
|
+
updated_at=row[4]
|
|
354
|
+
)
|
|
355
|
+
# Mark it as existing in DB
|
|
356
|
+
collection._state.adding = False
|
|
357
|
+
logger.info(
|
|
358
|
+
f"Using existing collection: {collection_name}")
|
|
359
|
+
else:
|
|
360
|
+
# Create new collection
|
|
361
|
+
cursor.execute(
|
|
362
|
+
f"""
|
|
363
|
+
INSERT INTO {self.table_collection} (name, embedding_dim, created_at, updated_at)
|
|
364
|
+
VALUES (%s, %s, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
|
|
365
|
+
RETURNING id, name, embedding_dim, created_at, updated_at
|
|
366
|
+
""",
|
|
367
|
+
[collection_name, embedding_dim]
|
|
368
|
+
)
|
|
369
|
+
row = cursor.fetchone()
|
|
370
|
+
collection = Collection(
|
|
371
|
+
id=row[0],
|
|
372
|
+
name=row[1],
|
|
373
|
+
embedding_dim=row[2],
|
|
374
|
+
created_at=row[3],
|
|
375
|
+
updated_at=row[4]
|
|
376
|
+
)
|
|
377
|
+
collection._state.adding = False
|
|
378
|
+
logger.info(f"Created new collection: {collection_name}")
|
|
379
|
+
|
|
380
|
+
return collection
|
|
381
|
+
else:
|
|
382
|
+
# Use Django ORM for standard tables
|
|
383
|
+
collection, created = Collection.objects.get_or_create(
|
|
384
|
+
name=collection_name,
|
|
385
|
+
defaults={"embedding_dim": embedding_dim}
|
|
386
|
+
)
|
|
387
|
+
|
|
388
|
+
logger.info(
|
|
389
|
+
f"{'Created new' if created else 'Using existing'} collection: {collection_name}"
|
|
390
|
+
)
|
|
391
|
+
return collection
|
|
392
|
+
|
|
393
|
+
def get_collection(self, collection_name: str) -> Collection:
|
|
394
|
+
"""
|
|
395
|
+
Get an existing collection (raises ValueError if not found).
|
|
396
|
+
|
|
397
|
+
Args:
|
|
398
|
+
collection_name: Name of the collection
|
|
399
|
+
|
|
400
|
+
Returns:
|
|
401
|
+
Collection object
|
|
402
|
+
|
|
403
|
+
Raises:
|
|
404
|
+
ValueError: If collection does not exist
|
|
405
|
+
"""
|
|
406
|
+
# Use raw SQL when custom table names are in use
|
|
407
|
+
if self.use_dimension_specific_tables:
|
|
408
|
+
with connection.cursor() as cursor:
|
|
409
|
+
cursor.execute(
|
|
410
|
+
f"""
|
|
411
|
+
SELECT id, name, embedding_dim, created_at, updated_at
|
|
412
|
+
FROM {self.table_collection}
|
|
413
|
+
WHERE name = %s
|
|
414
|
+
""",
|
|
415
|
+
[collection_name]
|
|
416
|
+
)
|
|
417
|
+
row = cursor.fetchone()
|
|
418
|
+
|
|
419
|
+
if not row:
|
|
420
|
+
raise ValueError(
|
|
421
|
+
f"Collection not found: {collection_name}")
|
|
422
|
+
|
|
423
|
+
# Create Collection object from row
|
|
424
|
+
collection = Collection(
|
|
425
|
+
id=row[0],
|
|
426
|
+
name=row[1],
|
|
427
|
+
embedding_dim=row[2],
|
|
428
|
+
created_at=row[3],
|
|
429
|
+
updated_at=row[4]
|
|
430
|
+
)
|
|
431
|
+
collection._state.adding = False
|
|
432
|
+
return collection
|
|
433
|
+
else:
|
|
434
|
+
try:
|
|
435
|
+
return Collection.objects.get(name=collection_name)
|
|
436
|
+
except Collection.DoesNotExist:
|
|
437
|
+
raise ValueError(f"Collection not found: {collection_name}")
|
|
438
|
+
|
|
439
|
+
def _get_distance_operator(self, distance_type: Optional[str] = None) -> str:
|
|
440
|
+
"""Get SQL distance operator for the configured similarity metric."""
|
|
441
|
+
if distance_type is None:
|
|
442
|
+
distance_type = self.vs_config.search.similarity_metric
|
|
443
|
+
|
|
444
|
+
operators = {
|
|
445
|
+
"cosine": "<=>",
|
|
446
|
+
"l2": "<->",
|
|
447
|
+
"dot_product": "<#>",
|
|
448
|
+
"dot": "<#>"
|
|
449
|
+
}
|
|
450
|
+
|
|
451
|
+
if distance_type not in operators:
|
|
452
|
+
raise ValueError(f"Unsupported distance type: {distance_type}")
|
|
453
|
+
|
|
454
|
+
return operators[distance_type]
|
|
455
|
+
|
|
456
|
+
@lru_cache(maxsize=1000)
|
|
457
|
+
def _get_query_embedding(self, query: str) -> np.ndarray:
|
|
458
|
+
"""Get embedding for a query (with caching)."""
|
|
459
|
+
if not self.vs_config.enable_caching:
|
|
460
|
+
# Don't use cache
|
|
461
|
+
return np.array(self.embedding_model.encode_query(query), dtype=np.float32)
|
|
462
|
+
|
|
463
|
+
embedding = self.embedding_model.encode_query(query)
|
|
464
|
+
return np.array(embedding, dtype=np.float32)
|
|
465
|
+
|
|
466
|
+
def _normalize_embedding(self, embedding: np.ndarray) -> np.ndarray:
|
|
467
|
+
"""Normalize embedding vector."""
|
|
468
|
+
norm = np.linalg.norm(embedding)
|
|
469
|
+
if norm > 0:
|
|
470
|
+
return embedding / norm
|
|
471
|
+
return embedding
|
|
472
|
+
|
|
473
|
+
@transaction.atomic
|
|
474
|
+
def create_collection_from_nodes(
|
|
475
|
+
self,
|
|
476
|
+
collection_name: str,
|
|
477
|
+
nodes: List[Node]
|
|
478
|
+
) -> None:
|
|
479
|
+
"""
|
|
480
|
+
Create a collection from nodes.
|
|
481
|
+
|
|
482
|
+
Args:
|
|
483
|
+
collection_name: Name of collection
|
|
484
|
+
nodes: List of Node objects
|
|
485
|
+
"""
|
|
486
|
+
if not nodes:
|
|
487
|
+
logger.warning(
|
|
488
|
+
f"No nodes provided for collection '{collection_name}'")
|
|
489
|
+
return
|
|
490
|
+
|
|
491
|
+
# Filter out nodes with None or empty content (these would cause embedding errors)
|
|
492
|
+
original_count = len(nodes)
|
|
493
|
+
nodes = [node for node in nodes if node.content is not None and str(
|
|
494
|
+
node.content).strip()]
|
|
495
|
+
|
|
496
|
+
if len(nodes) < original_count:
|
|
497
|
+
logger.warning(
|
|
498
|
+
f"Filtered out {original_count - len(nodes)} nodes with empty/None content")
|
|
499
|
+
|
|
500
|
+
if not nodes:
|
|
501
|
+
logger.warning(
|
|
502
|
+
f"No valid nodes to add for collection '{collection_name}' after filtering")
|
|
503
|
+
return
|
|
504
|
+
|
|
505
|
+
logger.info(
|
|
506
|
+
f"Creating collection '{collection_name}' with {len(nodes)} nodes")
|
|
507
|
+
|
|
508
|
+
# Get or create collection
|
|
509
|
+
collection = self.get_or_create_collection(collection_name)
|
|
510
|
+
|
|
511
|
+
# Generate embeddings - ensure all content is string type
|
|
512
|
+
texts = [str(node.content) for node in nodes]
|
|
513
|
+
embeddings = self.embedding_model.encode_documents(texts)
|
|
514
|
+
|
|
515
|
+
# Use raw SQL when custom table names are in use
|
|
516
|
+
if self.use_dimension_specific_tables:
|
|
517
|
+
import json
|
|
518
|
+
with connection.cursor() as cursor:
|
|
519
|
+
# Clear existing nodes
|
|
520
|
+
cursor.execute(
|
|
521
|
+
f"DELETE FROM {self.table_nodeentry} WHERE collection_id = %s",
|
|
522
|
+
[collection.id]
|
|
523
|
+
)
|
|
524
|
+
|
|
525
|
+
# Insert nodes using raw SQL
|
|
526
|
+
for i, node in enumerate(nodes):
|
|
527
|
+
# Convert custom_metadata dict to JSON string
|
|
528
|
+
custom_metadata_json = json.dumps(
|
|
529
|
+
node.metadata.custom or {})
|
|
530
|
+
|
|
531
|
+
cursor.execute(
|
|
532
|
+
f"""
|
|
533
|
+
INSERT INTO {self.table_nodeentry}
|
|
534
|
+
(collection_id, content, embedding, source_file_uuid, position, custom_metadata, created_at, updated_at)
|
|
535
|
+
VALUES (%s, %s, %s, %s, %s, %s::jsonb, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
|
|
536
|
+
RETURNING node_id
|
|
537
|
+
""",
|
|
538
|
+
[
|
|
539
|
+
collection.id,
|
|
540
|
+
node.content,
|
|
541
|
+
embeddings[i],
|
|
542
|
+
node.metadata.source_file_uuid,
|
|
543
|
+
node.metadata.position,
|
|
544
|
+
custom_metadata_json
|
|
545
|
+
]
|
|
546
|
+
)
|
|
547
|
+
node_id = cursor.fetchone()[0]
|
|
548
|
+
node.metadata.node_id = node_id
|
|
549
|
+
|
|
550
|
+
logger.info(
|
|
551
|
+
f"Created collection '{collection_name}' with {len(nodes)} nodes")
|
|
552
|
+
else:
|
|
553
|
+
# Use Django ORM for standard tables
|
|
554
|
+
# Clear existing nodes
|
|
555
|
+
NodeEntry.objects.filter(collection=collection).delete()
|
|
556
|
+
|
|
557
|
+
# Create node entries
|
|
558
|
+
node_entries = [
|
|
559
|
+
NodeEntry(
|
|
560
|
+
collection=collection,
|
|
561
|
+
content=node.content,
|
|
562
|
+
embedding=embeddings[i],
|
|
563
|
+
source_file_uuid=node.metadata.source_file_uuid,
|
|
564
|
+
position=node.metadata.position,
|
|
565
|
+
custom_metadata=node.metadata.custom or {},
|
|
566
|
+
)
|
|
567
|
+
for i, node in enumerate(nodes)
|
|
568
|
+
]
|
|
569
|
+
|
|
570
|
+
# Bulk insert
|
|
571
|
+
created_entries = NodeEntry.objects.bulk_create(
|
|
572
|
+
node_entries,
|
|
573
|
+
batch_size=self.vs_config.index.batch_insert_size
|
|
574
|
+
)
|
|
575
|
+
|
|
576
|
+
# Update node IDs
|
|
577
|
+
for i, node in enumerate(nodes):
|
|
578
|
+
node.metadata.node_id = created_entries[i].node_id
|
|
579
|
+
|
|
580
|
+
logger.info(
|
|
581
|
+
f"Created collection '{collection_name}' with {len(created_entries)} nodes")
|
|
582
|
+
|
|
583
|
+
@transaction.atomic
|
|
584
|
+
def create_collection_from_files(
|
|
585
|
+
self,
|
|
586
|
+
collection_name: str,
|
|
587
|
+
files: List[VSFile]
|
|
588
|
+
) -> None:
|
|
589
|
+
"""
|
|
590
|
+
Create collection from VSFile objects.
|
|
591
|
+
|
|
592
|
+
Args:
|
|
593
|
+
collection_name: Name of collection
|
|
594
|
+
files: List of VSFile objects
|
|
595
|
+
"""
|
|
596
|
+
nodes = [node for file in files for node in file.nodes]
|
|
597
|
+
self.create_collection_from_nodes(collection_name, nodes)
|
|
598
|
+
|
|
599
|
+
def add_nodes(self, collection_name: str, nodes: List[Node]) -> None:
|
|
600
|
+
"""
|
|
601
|
+
Add nodes to existing collection.
|
|
602
|
+
|
|
603
|
+
Args:
|
|
604
|
+
collection_name: Name of collection
|
|
605
|
+
nodes: Nodes to add
|
|
606
|
+
"""
|
|
607
|
+
if not nodes:
|
|
608
|
+
logger.warning("No nodes to add")
|
|
609
|
+
return
|
|
610
|
+
|
|
611
|
+
# Filter out nodes with None or empty content (these would cause embedding errors)
|
|
612
|
+
original_count = len(nodes)
|
|
613
|
+
nodes = [node for node in nodes if node.content is not None and str(
|
|
614
|
+
node.content).strip()]
|
|
615
|
+
|
|
616
|
+
if len(nodes) < original_count:
|
|
617
|
+
logger.warning(
|
|
618
|
+
f"Filtered out {original_count - len(nodes)} nodes with empty/None content")
|
|
619
|
+
|
|
620
|
+
if not nodes:
|
|
621
|
+
logger.warning("No valid nodes to add after filtering")
|
|
622
|
+
return
|
|
623
|
+
|
|
624
|
+
logger.info(
|
|
625
|
+
f"Adding {len(nodes)} nodes to collection '{collection_name}'")
|
|
626
|
+
|
|
627
|
+
# Generate embeddings BEFORE starting the database transaction
|
|
628
|
+
# This is critical for large datasets as embedding generation can take minutes,
|
|
629
|
+
# which would otherwise cause the DB connection to timeout
|
|
630
|
+
logger.info(
|
|
631
|
+
f"Preparing {len(nodes)} nodes for embedding generation...")
|
|
632
|
+
texts = [str(node.content) for node in nodes]
|
|
633
|
+
total_texts = len(texts)
|
|
634
|
+
logger.info(
|
|
635
|
+
f"Starting embedding generation for {total_texts} texts (this may take a while)...")
|
|
636
|
+
logger.info(
|
|
637
|
+
f"Model: {self.embedding_model.model_name}, Batch size: {self.embedding_model.batch_size}")
|
|
638
|
+
logger.info(
|
|
639
|
+
f"Expected batches: {(total_texts + self.embedding_model.batch_size - 1) // self.embedding_model.batch_size}")
|
|
640
|
+
embed_start_time = time.time()
|
|
641
|
+
|
|
642
|
+
logger.info("Calling embedding model encode_documents()...")
|
|
643
|
+
embeddings = self.embedding_model.encode_documents(texts)
|
|
644
|
+
|
|
645
|
+
embed_elapsed = time.time() - embed_start_time
|
|
646
|
+
avg_rate = total_texts / embed_elapsed if embed_elapsed > 0 else 0
|
|
647
|
+
logger.info(
|
|
648
|
+
f"✓ Embedding generation completed in {embed_elapsed:.1f}s ({avg_rate:.1f} texts/s)")
|
|
649
|
+
logger.info(f"Generated {len(embeddings)} embeddings")
|
|
650
|
+
|
|
651
|
+
# MEMORY OPTIMIZATION: Clear texts list after embedding generation
|
|
652
|
+
# The texts are no longer needed as embeddings are already computed
|
|
653
|
+
del texts
|
|
654
|
+
import gc
|
|
655
|
+
gc.collect()
|
|
656
|
+
|
|
657
|
+
# Now perform the database operations within a transaction
|
|
658
|
+
self._insert_nodes_with_embeddings(collection_name, nodes, embeddings)
|
|
659
|
+
|
|
660
|
+
# MEMORY OPTIMIZATION: Clear embeddings after insertion
|
|
661
|
+
del embeddings
|
|
662
|
+
gc.collect()
|
|
663
|
+
|
|
664
|
+
    @transaction.atomic
    def _insert_nodes_with_embeddings(self, collection_name: str, nodes: List[Node], embeddings: List) -> None:
        """
        Insert nodes with pre-computed embeddings into the database.

        This is a separate method to allow embedding generation to happen
        outside the database transaction, preventing connection timeouts
        for large datasets.

        NOTE(review): the method itself is decorated with ``@transaction.atomic``,
        so the per-batch ``with transaction.atomic():`` blocks below become
        savepoints inside one outer transaction rather than independently
        committed transactions — the "each batch is committed separately"
        comments may not hold in practice. Confirm whether the outer decorator
        is intentional.

        Args:
            collection_name: Name of collection
            nodes: Nodes to insert
            embeddings: Pre-computed embeddings for the nodes (one per node,
                same order; each either a numpy-like array with ``tolist`` or a
                plain sequence of floats)
        """
        collection = self.get_collection(collection_name)

        if self.use_dimension_specific_tables:
            # Use raw SQL with batch inserts when custom tables are in use
            import json
            from django.db import transaction

            batch_size = self.vs_config.index.batch_insert_size
            total_nodes = len(nodes)
            # Ceiling division: last batch may be smaller than batch_size
            total_batches = (total_nodes + batch_size - 1) // batch_size

            logger.info(
                f"Starting batch insert: {total_nodes} nodes in {total_batches} batches (batch_size={batch_size})")
            insert_start_time = time.time()

            # Process in batches with individual transactions to prevent connection timeout
            # Each batch is committed separately to avoid long-running transactions
            # (see NOTE(review) in the docstring about the outer decorator)
            for batch_idx, batch_start in enumerate(range(0, total_nodes, batch_size)):
                batch_start_time = time.time()
                batch_end = min(batch_start + batch_size, total_nodes)
                batch_nodes = nodes[batch_start:batch_end]
                batch_embeddings = embeddings[batch_start:batch_end]

                # Use atomic transaction for each batch
                with transaction.atomic():
                    with connection.cursor() as cursor:
                        # Build one multi-row VALUES insert for the whole batch
                        values_list = []
                        params = []
                        for i, node in enumerate(batch_nodes):
                            custom_metadata_json = json.dumps(
                                node.metadata.custom or {})
                            # Convert embedding to string format for pgvector: "[1.0, 2.0, ...]"
                            embedding_values = batch_embeddings[i].tolist() if hasattr(
                                batch_embeddings[i], 'tolist') else list(batch_embeddings[i])
                            embedding_str = "[" + ",".join(str(x)
                                                           for x in embedding_values) + "]"
                            # Convert any dict values to JSON strings for psycopg2 compatibility
                            source_file_uuid = json.dumps(node.metadata.source_file_uuid) if isinstance(
                                node.metadata.source_file_uuid, dict) else node.metadata.source_file_uuid
                            position = json.dumps(node.metadata.position) if isinstance(
                                node.metadata.position, dict) else node.metadata.position
                            content = json.dumps(node.content) if isinstance(
                                node.content, dict) else node.content

                            values_list.append(
                                "(%s, %s, %s::vector, %s, %s, %s::jsonb, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)")
                            params.extend([
                                collection.id,
                                content,
                                embedding_str,
                                source_file_uuid,
                                position,
                                custom_metadata_json
                            ])

                        # Execute batch insert; RETURNING gives us the generated IDs
                        cursor.execute(
                            f"""
                            INSERT INTO {self.table_nodeentry}
                            (collection_id, content, embedding, source_file_uuid, position, custom_metadata, created_at, updated_at)
                            VALUES {", ".join(values_list)}
                            RETURNING node_id
                            """,
                            params
                        )

                        # Get returned node IDs and write them back onto the in-memory nodes
                        node_ids = cursor.fetchall()
                        for i, node in enumerate(batch_nodes):
                            node.metadata.node_id = node_ids[i][0]

                # Transaction is committed here automatically when exiting the atomic() context
                batch_elapsed = time.time() - batch_start_time

                # Log progress every 10 batches, on the final batch, or for slow batches
                current_batch = batch_idx + 1
                if current_batch % 10 == 0 or current_batch == total_batches or batch_elapsed > 1.0:
                    total_elapsed = time.time() - insert_start_time
                    nodes_per_sec = batch_end / total_elapsed if total_elapsed > 0 else 0
                    eta_seconds = (total_nodes - batch_end) / \
                        nodes_per_sec if nodes_per_sec > 0 else 0
                    logger.info(
                        f"Insert progress: {batch_end}/{total_nodes} nodes "
                        f"({batch_end * 100 // total_nodes}%) | "
                        f"Batch {current_batch}/{total_batches} took {batch_elapsed:.2f}s | "
                        f"Rate: {nodes_per_sec:.0f} nodes/s | "
                        f"ETA: {eta_seconds:.0f}s"
                    )

            total_time = time.time() - insert_start_time
            logger.info(
                f"Completed inserting {total_nodes} nodes to '{collection_name}' in {total_time:.2f}s ({total_nodes/total_time:.0f} nodes/s)")
        else:
            # Create entries using ORM; embeddings[i] pairs with nodes[i] by position
            node_entries = [
                NodeEntry(
                    collection=collection,
                    content=node.content,
                    embedding=embeddings[i],
                    source_file_uuid=node.metadata.source_file_uuid,
                    position=node.metadata.position,
                    custom_metadata=node.metadata.custom or {},
                )
                for i, node in enumerate(nodes)
            ]

            created_entries = NodeEntry.objects.bulk_create(
                node_entries,
                batch_size=self.vs_config.index.batch_insert_size
            )

            # Update node IDs from the created rows
            for i, node in enumerate(nodes):
                node.metadata.node_id = created_entries[i].node_id

            logger.info(
                f"Added {len(created_entries)} nodes to '{collection_name}'")
|
797
|
+
    @transaction.atomic
    def update_vector(
        self,
        collection_name: str,
        node_id: int,
        new_content: Optional[str] = None,
        new_embedding: Optional[List[float]] = None,
        new_metadata: Optional[Dict[str, Any]] = None
    ) -> None:
        """
        Update a vector in the collection.

        ``new_content`` takes precedence over ``new_embedding``: when content
        is provided, a fresh embedding is generated from it and
        ``new_embedding`` is ignored. Metadata updates are merged (shallow
        ``dict.update``) into the existing metadata, not replaced.

        Args:
            collection_name: Name of collection
            node_id: ID of node to update
            new_content: New content (will regenerate embedding if provided)
            new_embedding: New embedding vector (used if new_content not provided)
            new_metadata: New metadata to merge with existing

        Raises:
            ValueError: If the node does not exist in the collection.
        """
        collection = self.get_collection(collection_name)

        if self.use_dimension_specific_tables:
            # Use raw SQL when custom tables are in use
            import json
            with connection.cursor() as cursor:
                # First, get the current node so unchanged fields can be re-written
                cursor.execute(
                    f"""
                    SELECT content, embedding, custom_metadata
                    FROM {self.table_nodeentry}
                    WHERE collection_id = %s AND node_id = %s
                    """,
                    [collection.id, node_id]
                )
                row = cursor.fetchone()

                if not row:
                    raise ValueError(
                        f"Node {node_id} not found in collection '{collection_name}'")

                current_content, current_embedding, current_metadata = row

                # Parse JSON metadata if it's a string (driver may return jsonb as text)
                if isinstance(current_metadata, str):
                    current_metadata = json.loads(current_metadata)

                # Determine updates; defaults are the currently stored values
                updated_content = current_content
                updated_embedding = current_embedding
                updated_metadata = current_metadata or {}

                if new_content is not None:
                    updated_content = new_content
                    updated_embedding = self.embedding_model.encode_query(
                        new_content)
                    logger.info(
                        f"Updated content and regenerated embedding for node {node_id}")
                elif new_embedding is not None:
                    updated_embedding = new_embedding
                    logger.info(f"Updated embedding for node {node_id}")

                if new_metadata is not None:
                    updated_metadata.update(new_metadata)
                    logger.info(f"Updated metadata for node {node_id}")

                # Update the node
                updated_metadata_json = json.dumps(updated_metadata)
                # Convert embedding to string format for pgvector: "[1.0, 2.0, ...]"
                # NOTE(review): on a metadata-only update, updated_embedding is
                # whatever the driver returned for the vector column; if that is
                # a string like "[0.1,0.2]", list(updated_embedding) splits it
                # into characters and would corrupt the stored vector. Verify
                # the driver's return type for pgvector columns.
                if updated_embedding is not None:
                    embedding_values = updated_embedding.tolist() if hasattr(
                        updated_embedding, 'tolist') else list(updated_embedding)
                    embedding_str = "[" + ",".join(str(x)
                                                   for x in embedding_values) + "]"
                else:
                    embedding_str = None
                cursor.execute(
                    f"""
                    UPDATE {self.table_nodeentry}
                    SET content = %s, embedding = %s::vector, custom_metadata = %s::jsonb, updated_at = CURRENT_TIMESTAMP
                    WHERE collection_id = %s AND node_id = %s
                    """,
                    [updated_content, embedding_str,
                     updated_metadata_json, collection.id, node_id]
                )

                logger.info(
                    f"Successfully updated node {node_id} in collection '{collection_name}'")
        else:
            try:
                node_entry = NodeEntry.objects.get(
                    collection=collection, node_id=node_id)
            except NodeEntry.DoesNotExist:
                raise ValueError(
                    f"Node {node_id} not found in collection '{collection_name}'")

            # Update content and embedding (content wins over explicit embedding)
            if new_content is not None:
                node_entry.content = new_content
                # Generate new embedding
                embedding = self.embedding_model.encode_query(new_content)
                node_entry.embedding = embedding
                logger.info(
                    f"Updated content and regenerated embedding for node {node_id}")
            elif new_embedding is not None:
                node_entry.embedding = new_embedding
                logger.info(f"Updated embedding for node {node_id}")

            # Update metadata
            if new_metadata is not None:
                # Merge with existing metadata
                current_metadata = node_entry.custom_metadata or {}
                current_metadata.update(new_metadata)
                node_entry.custom_metadata = current_metadata
                logger.info(f"Updated metadata for node {node_id}")

            node_entry.save()
            logger.info(
                f"Successfully updated node {node_id} in collection '{collection_name}'")
|
916
|
+
@transaction.atomic
|
|
917
|
+
def delete_nodes(self, collection_name: str, node_ids: List[int]) -> None:
|
|
918
|
+
"""
|
|
919
|
+
Delete nodes from collection.
|
|
920
|
+
|
|
921
|
+
Args:
|
|
922
|
+
collection_name: Name of collection
|
|
923
|
+
node_ids: List of node IDs to delete
|
|
924
|
+
"""
|
|
925
|
+
if not node_ids:
|
|
926
|
+
logger.warning("No node IDs to delete")
|
|
927
|
+
return
|
|
928
|
+
|
|
929
|
+
collection = self.get_collection(collection_name)
|
|
930
|
+
|
|
931
|
+
if self.use_dimension_specific_tables:
|
|
932
|
+
# Use raw SQL when custom tables are in use
|
|
933
|
+
with connection.cursor() as cursor:
|
|
934
|
+
placeholders = ','.join(['%s'] * len(node_ids))
|
|
935
|
+
cursor.execute(
|
|
936
|
+
f"""
|
|
937
|
+
DELETE FROM {self.table_nodeentry}
|
|
938
|
+
WHERE collection_id = %s AND node_id IN ({placeholders})
|
|
939
|
+
""",
|
|
940
|
+
[collection.id] + list(node_ids)
|
|
941
|
+
)
|
|
942
|
+
deleted_count = cursor.rowcount
|
|
943
|
+
else:
|
|
944
|
+
deleted_count, _ = NodeEntry.objects.filter(
|
|
945
|
+
collection=collection,
|
|
946
|
+
node_id__in=node_ids
|
|
947
|
+
).delete()
|
|
948
|
+
|
|
949
|
+
logger.info(
|
|
950
|
+
f"Deleted {deleted_count} nodes from collection '{collection_name}'")
|
|
951
|
+
|
|
952
|
+
@transaction.atomic
|
|
953
|
+
def delete_collection(self, collection_name: str) -> None:
|
|
954
|
+
"""
|
|
955
|
+
Delete a collection and all its nodes.
|
|
956
|
+
|
|
957
|
+
Args:
|
|
958
|
+
collection_name: Name of collection to delete
|
|
959
|
+
"""
|
|
960
|
+
collection = self.get_collection(collection_name)
|
|
961
|
+
|
|
962
|
+
if self.use_dimension_specific_tables:
|
|
963
|
+
# Use raw SQL when custom tables are in use
|
|
964
|
+
with connection.cursor() as cursor:
|
|
965
|
+
# Get node count first
|
|
966
|
+
cursor.execute(
|
|
967
|
+
f"SELECT COUNT(*) FROM {self.table_nodeentry} WHERE collection_id = %s",
|
|
968
|
+
[collection.id]
|
|
969
|
+
)
|
|
970
|
+
node_count = cursor.fetchone()[0]
|
|
971
|
+
|
|
972
|
+
# Delete nodes (cascade should handle this, but be explicit)
|
|
973
|
+
cursor.execute(
|
|
974
|
+
f"DELETE FROM {self.table_nodeentry} WHERE collection_id = %s",
|
|
975
|
+
[collection.id]
|
|
976
|
+
)
|
|
977
|
+
|
|
978
|
+
# Delete collection
|
|
979
|
+
cursor.execute(
|
|
980
|
+
f"DELETE FROM {self.table_collection} WHERE id = %s",
|
|
981
|
+
[collection.id]
|
|
982
|
+
)
|
|
983
|
+
|
|
984
|
+
logger.info(
|
|
985
|
+
f"Deleted collection '{collection_name}' with {node_count} nodes")
|
|
986
|
+
else:
|
|
987
|
+
node_count = NodeEntry.objects.filter(
|
|
988
|
+
collection=collection).count()
|
|
989
|
+
collection.delete()
|
|
990
|
+
logger.info(
|
|
991
|
+
f"Deleted collection '{collection_name}' with {node_count} nodes")
|
|
992
|
+
|
|
993
|
+
def keyword_search(
|
|
994
|
+
self,
|
|
995
|
+
collection_name: str,
|
|
996
|
+
query: str,
|
|
997
|
+
number: Optional[int] = None,
|
|
998
|
+
meta_data_filters: Optional[Dict[str, Any]] = None,
|
|
999
|
+
min_rank: float = 0.0,
|
|
1000
|
+
ranking_algorithm: Optional[str] = None
|
|
1001
|
+
) -> Tuple[Dict, List[Node]]:
|
|
1002
|
+
"""
|
|
1003
|
+
Perform pure keyword-based full-text search in collection.
|
|
1004
|
+
|
|
1005
|
+
Supports multiple ranking algorithms:
|
|
1006
|
+
- BM25: Okapi BM25 ranking function (default, state-of-the-art)
|
|
1007
|
+
- ts_rank: PostgreSQL's native full-text search ranking
|
|
1008
|
+
|
|
1009
|
+
Args:
|
|
1010
|
+
collection_name: Name of collection to search
|
|
1011
|
+
query: Search query (keywords)
|
|
1012
|
+
number: Number of results to return (uses config default if None)
|
|
1013
|
+
meta_data_filters: Optional metadata filters to apply
|
|
1014
|
+
min_rank: Minimum rank threshold (0.0 to 1.0, default 0.0)
|
|
1015
|
+
ranking_algorithm: Ranking algorithm to use ('bm25' or 'ts_rank', uses config default if None)
|
|
1016
|
+
|
|
1017
|
+
Returns:
|
|
1018
|
+
Tuple of (results dict, list of Node objects)
|
|
1019
|
+
|
|
1020
|
+
Example:
|
|
1021
|
+
# Using BM25 (default)
|
|
1022
|
+
results, nodes = store.keyword_search(
|
|
1023
|
+
collection_name="my_docs",
|
|
1024
|
+
query="machine learning algorithms",
|
|
1025
|
+
number=10,
|
|
1026
|
+
min_rank=0.01
|
|
1027
|
+
)
|
|
1028
|
+
|
|
1029
|
+
# Using ts_rank
|
|
1030
|
+
results, nodes = store.keyword_search(
|
|
1031
|
+
collection_name="my_docs",
|
|
1032
|
+
query="machine learning algorithms",
|
|
1033
|
+
number=10,
|
|
1034
|
+
ranking_algorithm="ts_rank"
|
|
1035
|
+
)
|
|
1036
|
+
"""
|
|
1037
|
+
# Use config defaults
|
|
1038
|
+
if number is None:
|
|
1039
|
+
number = self.vs_config.search.default_top_k
|
|
1040
|
+
if ranking_algorithm is None:
|
|
1041
|
+
ranking_algorithm = self.vs_config.search.keyword_ranking_algorithm
|
|
1042
|
+
|
|
1043
|
+
logger.info(
|
|
1044
|
+
f"Keyword search ({ranking_algorithm}) in '{collection_name}' for: '{query}'")
|
|
1045
|
+
|
|
1046
|
+
# Get collection
|
|
1047
|
+
try:
|
|
1048
|
+
collection = self.get_collection(collection_name)
|
|
1049
|
+
except ValueError as e:
|
|
1050
|
+
logger.error(str(e))
|
|
1051
|
+
raise
|
|
1052
|
+
|
|
1053
|
+
# Use correct table name
|
|
1054
|
+
table_name = self.table_nodeentry if self.use_dimension_specific_tables else NodeEntry._meta.db_table
|
|
1055
|
+
|
|
1056
|
+
# Build WHERE clause with metadata filters
|
|
1057
|
+
where_conditions = ["collection_id = %s"]
|
|
1058
|
+
where_params = [collection.id]
|
|
1059
|
+
|
|
1060
|
+
if meta_data_filters:
|
|
1061
|
+
for key, value in meta_data_filters.items():
|
|
1062
|
+
where_conditions.append(f"custom_metadata->>'{key}' = %s")
|
|
1063
|
+
where_params.append(str(value))
|
|
1064
|
+
|
|
1065
|
+
where_clause = " AND ".join(where_conditions)
|
|
1066
|
+
|
|
1067
|
+
# Build SQL query based on ranking algorithm
|
|
1068
|
+
if ranking_algorithm == "bm25":
|
|
1069
|
+
# BM25 uses where_clause twice (doc_stats and collection_stats)
|
|
1070
|
+
# Parameter order: where_params (doc_stats), where_params (collection_stats), query, min_rank, limit
|
|
1071
|
+
sql_query = self._build_bm25_query(
|
|
1072
|
+
table_name, where_clause, query, min_rank, number)
|
|
1073
|
+
query_params = where_params + \
|
|
1074
|
+
where_params + [query, min_rank, number]
|
|
1075
|
+
else: # ts_rank
|
|
1076
|
+
# ts_rank uses where_clause once
|
|
1077
|
+
# Parameter order: query (SELECT), where_params (WHERE), query (AND), query (AND), min_rank, limit
|
|
1078
|
+
sql_query = self._build_tsrank_query(
|
|
1079
|
+
table_name, where_clause, query, min_rank, number)
|
|
1080
|
+
query_params = [query] + where_params + \
|
|
1081
|
+
[query, query, min_rank, number]
|
|
1082
|
+
|
|
1083
|
+
# Execute query
|
|
1084
|
+
with connection.cursor() as cursor:
|
|
1085
|
+
cursor.execute(sql_query, query_params)
|
|
1086
|
+
results = cursor.fetchall()
|
|
1087
|
+
columns = [col[0] for col in cursor.description]
|
|
1088
|
+
|
|
1089
|
+
# Process results
|
|
1090
|
+
valid_suggestions = {}
|
|
1091
|
+
suggested_nodes = []
|
|
1092
|
+
seen_texts = set()
|
|
1093
|
+
|
|
1094
|
+
for row in results:
|
|
1095
|
+
result_dict = dict(zip(columns, row))
|
|
1096
|
+
node_id = result_dict["node_id"]
|
|
1097
|
+
content = result_dict["content"]
|
|
1098
|
+
rank = result_dict["rank"]
|
|
1099
|
+
|
|
1100
|
+
if content not in seen_texts:
|
|
1101
|
+
seen_texts.add(content)
|
|
1102
|
+
|
|
1103
|
+
custom_metadata = result_dict["custom_metadata"] or {}
|
|
1104
|
+
|
|
1105
|
+
metadata = NodeMetadata(
|
|
1106
|
+
source_file_uuid=result_dict["source_file_uuid"],
|
|
1107
|
+
position=result_dict["position"],
|
|
1108
|
+
custom=custom_metadata,
|
|
1109
|
+
)
|
|
1110
|
+
metadata.node_id = node_id
|
|
1111
|
+
|
|
1112
|
+
node = Node(content=content, metadata=metadata)
|
|
1113
|
+
suggested_nodes.append(node)
|
|
1114
|
+
|
|
1115
|
+
valid_suggestions[str(node_id)] = (
|
|
1116
|
+
{
|
|
1117
|
+
"node_id": node_id,
|
|
1118
|
+
"source_file_uuid": result_dict["source_file_uuid"],
|
|
1119
|
+
"position": result_dict["position"],
|
|
1120
|
+
"custom": custom_metadata,
|
|
1121
|
+
},
|
|
1122
|
+
content,
|
|
1123
|
+
float(rank),
|
|
1124
|
+
)
|
|
1125
|
+
|
|
1126
|
+
logger.info(
|
|
1127
|
+
f"Keyword search returned {len(valid_suggestions)} results")
|
|
1128
|
+
return valid_suggestions, suggested_nodes
|
|
1129
|
+
|
|
1130
|
+
    def _build_bm25_query(
        self,
        table_name: str,
        where_clause: str,
        query: str,
        min_rank: float,
        limit: int
    ) -> str:
        """
        Build BM25 ranking SQL query.

        BM25 (Best Matching 25) is a ranking function used by search engines.
        It's based on the probabilistic retrieval framework and considers:
        - Term frequency (TF): How often query terms appear in the document
        - Inverse document frequency (IDF): How rare the terms are across all documents
        - Document length normalization: Adjusts for document length

        Formula: BM25(D,Q) = Σ IDF(qi) * (f(qi,D) * (k1 + 1)) / (f(qi,D) + k1 * (1 - b + b * |D| / avgdl))

        Where:
        - D: document
        - Q: query
        - qi: query term i
        - f(qi,D): frequency of qi in D
        - |D|: length of document D
        - avgdl: average document length in the collection
        - k1: term frequency saturation parameter (default: 1.5)
        - b: length normalization parameter (default: 0.75)

        Only the SQL *text* is returned; the ``query``, ``min_rank`` and
        ``limit`` arguments here are not interpolated. The caller must bind,
        in order: the where_clause params twice (where_clause appears in both
        the doc_stats and collection_stats CTEs), then query, min_rank, limit.
        ``k1`` and ``b`` are interpolated directly into the SQL — assumed to be
        numeric values from config (TODO confirm they are validated upstream).
        """
        k1 = self.vs_config.search.bm25_k1
        b = self.vs_config.search.bm25_b

        # NOTE: ts_rank(...) * 1000 below serves as a term-frequency proxy,
        # since PostgreSQL does not expose raw term counts directly.
        return f"""
        WITH doc_stats AS (
            -- Calculate document statistics
            SELECT
                node_id,
                content,
                source_file_uuid,
                position,
                custom_metadata,
                LENGTH(content) AS doc_length,
                to_tsvector('english', content) AS doc_vector
            FROM
                {table_name}
            WHERE
                {where_clause}
        ),
        collection_stats AS (
            -- Calculate collection-wide statistics
            SELECT
                AVG(LENGTH(content)) AS avg_doc_length,
                COUNT(*) AS total_docs
            FROM
                {table_name}
            WHERE
                {where_clause}
        ),
        query_terms AS (
            -- Extract query terms and calculate IDF
            SELECT
                word,
                -- IDF calculation: log((N - df + 0.5) / (df + 0.5) + 1)
                -- where N is total docs and df is document frequency
                LN(
                    (cs.total_docs - COUNT(DISTINCT ds.node_id) + 0.5) /
                    (COUNT(DISTINCT ds.node_id) + 0.5) + 1
                ) AS idf
            FROM
                unnest(string_to_array(lower(%s), ' ')) AS word,
                doc_stats ds,
                collection_stats cs
            WHERE
                ds.doc_vector @@ to_tsquery('english', word)
            GROUP BY
                word, cs.total_docs
        ),
        bm25_scores AS (
            -- Calculate BM25 score for each document
            SELECT
                ds.node_id,
                ds.content,
                ds.source_file_uuid,
                ds.position,
                ds.custom_metadata,
                SUM(
                    qt.idf *
                    (
                        -- Term frequency component
                        (ts_rank(ds.doc_vector, to_tsquery('english', qt.word)) * 1000 * ({k1} + 1)) /
                        (
                            ts_rank(ds.doc_vector, to_tsquery('english', qt.word)) * 1000 +
                            {k1} * (1 - {b} + {b} * ds.doc_length / cs.avg_doc_length)
                        )
                    )
                ) AS bm25_score
            FROM
                doc_stats ds
            CROSS JOIN
                collection_stats cs
            CROSS JOIN
                query_terms qt
            WHERE
                ds.doc_vector @@ to_tsquery('english', qt.word)
            GROUP BY
                ds.node_id,
                ds.content,
                ds.source_file_uuid,
                ds.position,
                ds.custom_metadata
        )
        SELECT
            node_id,
            content,
            source_file_uuid,
            position,
            custom_metadata,
            COALESCE(bm25_score, 0.0) AS rank
        FROM
            bm25_scores
        WHERE
            COALESCE(bm25_score, 0.0) > %s
        ORDER BY
            rank DESC
        LIMIT
            %s
        """
|
1258
|
+
    def _build_tsrank_query(
        self,
        table_name: str,
        where_clause: str,
        query: str,
        min_rank: float,
        limit: int
    ) -> str:
        """
        Build ts_rank SQL query.

        Uses PostgreSQL's native full-text search ranking function.

        Only the SQL *text* is returned — the ``query``, ``min_rank`` and
        ``limit`` arguments are not interpolated here. The returned SQL
        contains ``%s`` placeholders which the caller must bind in this order:
        query (SELECT rank expression), the where_clause params, query
        (match predicate), query (rank-threshold expression), min_rank, limit.
        """
        return f"""
        SELECT
            node_id,
            content,
            source_file_uuid,
            position,
            custom_metadata,
            ts_rank(to_tsvector('english', content), plainto_tsquery('english', %s)) AS rank
        FROM
            {table_name}
        WHERE
            {where_clause}
            AND to_tsvector('english', content) @@ plainto_tsquery('english', %s)
            AND ts_rank(to_tsvector('english', content), plainto_tsquery('english', %s)) > %s
        ORDER BY
            rank DESC
        LIMIT
            %s
        """
|
1291
|
+
def search(
|
|
1292
|
+
self,
|
|
1293
|
+
collection_name: str,
|
|
1294
|
+
query: str,
|
|
1295
|
+
distance_type: Optional[str] = None,
|
|
1296
|
+
number: Optional[int] = None,
|
|
1297
|
+
meta_data_filters: Optional[Dict[str, Any]] = None,
|
|
1298
|
+
hybrid_search: Optional[bool] = None
|
|
1299
|
+
) -> Tuple[Dict, List[Node]]:
|
|
1300
|
+
"""
|
|
1301
|
+
Search for similar vectors in collection.
|
|
1302
|
+
|
|
1303
|
+
Args:
|
|
1304
|
+
collection_name: Name of collection to search
|
|
1305
|
+
query: Search query
|
|
1306
|
+
distance_type: Distance metric (uses config default if None)
|
|
1307
|
+
number: Number of results (uses config default if None)
|
|
1308
|
+
meta_data_filters: Metadata filters
|
|
1309
|
+
hybrid_search: Enable hybrid search (uses config default if None)
|
|
1310
|
+
|
|
1311
|
+
Returns:
|
|
1312
|
+
Tuple of (results dict, list of Node objects)
|
|
1313
|
+
"""
|
|
1314
|
+
# Use config defaults
|
|
1315
|
+
if distance_type is None:
|
|
1316
|
+
distance_type = self.vs_config.search.similarity_metric
|
|
1317
|
+
if number is None:
|
|
1318
|
+
number = self.vs_config.search.default_top_k
|
|
1319
|
+
if hybrid_search is None:
|
|
1320
|
+
hybrid_search = self.vs_config.search.enable_hybrid_search
|
|
1321
|
+
|
|
1322
|
+
logger.info(f"Searching in '{collection_name}' for: '{query}'")
|
|
1323
|
+
|
|
1324
|
+
# Get collection
|
|
1325
|
+
try:
|
|
1326
|
+
collection = self.get_collection(collection_name)
|
|
1327
|
+
except ValueError as e:
|
|
1328
|
+
logger.error(str(e))
|
|
1329
|
+
raise
|
|
1330
|
+
|
|
1331
|
+
# Get query embedding
|
|
1332
|
+
query_embedding = self._get_query_embedding(query)
|
|
1333
|
+
|
|
1334
|
+
# Normalize if using cosine distance
|
|
1335
|
+
if distance_type == "cosine":
|
|
1336
|
+
query_embedding = self._normalize_embedding(query_embedding)
|
|
1337
|
+
|
|
1338
|
+
# Determine search buffer
|
|
1339
|
+
search_buffer_factor = (
|
|
1340
|
+
self.vs_config.search.search_buffer_factor if hybrid_search else 1
|
|
1341
|
+
)
|
|
1342
|
+
limit = number * search_buffer_factor
|
|
1343
|
+
|
|
1344
|
+
# Build SQL query
|
|
1345
|
+
distance_operator = self._get_distance_operator(distance_type)
|
|
1346
|
+
embedding_str = "[" + ",".join(str(x) for x in query_embedding) + "]"
|
|
1347
|
+
|
|
1348
|
+
# Use correct table name
|
|
1349
|
+
table_name = self.table_nodeentry if self.use_dimension_specific_tables else NodeEntry._meta.db_table
|
|
1350
|
+
|
|
1351
|
+
# Build WHERE clause with metadata filters
|
|
1352
|
+
where_conditions = ["collection_id = %s"]
|
|
1353
|
+
query_params = [embedding_str, collection.id]
|
|
1354
|
+
|
|
1355
|
+
if meta_data_filters:
|
|
1356
|
+
for key, value in meta_data_filters.items():
|
|
1357
|
+
where_conditions.append(f"custom_metadata->>'{key}' = %s")
|
|
1358
|
+
query_params.append(str(value))
|
|
1359
|
+
|
|
1360
|
+
where_clause = " AND ".join(where_conditions)
|
|
1361
|
+
query_params.append(limit)
|
|
1362
|
+
|
|
1363
|
+
sql_query = f"""
|
|
1364
|
+
SELECT
|
|
1365
|
+
node_id,
|
|
1366
|
+
content,
|
|
1367
|
+
source_file_uuid,
|
|
1368
|
+
position,
|
|
1369
|
+
custom_metadata,
|
|
1370
|
+
embedding {distance_operator} %s::vector AS distance
|
|
1371
|
+
FROM
|
|
1372
|
+
{table_name}
|
|
1373
|
+
WHERE
|
|
1374
|
+
{where_clause}
|
|
1375
|
+
ORDER BY
|
|
1376
|
+
distance
|
|
1377
|
+
LIMIT
|
|
1378
|
+
%s
|
|
1379
|
+
"""
|
|
1380
|
+
|
|
1381
|
+
# Execute query
|
|
1382
|
+
with connection.cursor() as cursor:
|
|
1383
|
+
cursor.execute(sql_query, query_params)
|
|
1384
|
+
results = cursor.fetchall()
|
|
1385
|
+
columns = [col[0] for col in cursor.description]
|
|
1386
|
+
|
|
1387
|
+
# Process results
|
|
1388
|
+
valid_suggestions = {}
|
|
1389
|
+
suggested_nodes = []
|
|
1390
|
+
seen_texts = set()
|
|
1391
|
+
|
|
1392
|
+
for row in results:
|
|
1393
|
+
result_dict = dict(zip(columns, row))
|
|
1394
|
+
node_id = result_dict["node_id"]
|
|
1395
|
+
content = result_dict["content"]
|
|
1396
|
+
distance = result_dict["distance"]
|
|
1397
|
+
|
|
1398
|
+
if content not in seen_texts:
|
|
1399
|
+
seen_texts.add(content)
|
|
1400
|
+
|
|
1401
|
+
custom_metadata = result_dict["custom_metadata"] or {}
|
|
1402
|
+
|
|
1403
|
+
metadata = NodeMetadata(
|
|
1404
|
+
source_file_uuid=result_dict["source_file_uuid"],
|
|
1405
|
+
position=result_dict["position"],
|
|
1406
|
+
custom=custom_metadata,
|
|
1407
|
+
)
|
|
1408
|
+
metadata.node_id = node_id
|
|
1409
|
+
|
|
1410
|
+
node = Node(content=content, metadata=metadata)
|
|
1411
|
+
suggested_nodes.append(node)
|
|
1412
|
+
|
|
1413
|
+
valid_suggestions[str(node_id)] = (
|
|
1414
|
+
{
|
|
1415
|
+
"node_id": node_id,
|
|
1416
|
+
"source_file_uuid": result_dict["source_file_uuid"],
|
|
1417
|
+
"position": result_dict["position"],
|
|
1418
|
+
"custom": custom_metadata,
|
|
1419
|
+
},
|
|
1420
|
+
content,
|
|
1421
|
+
float(distance),
|
|
1422
|
+
)
|
|
1423
|
+
|
|
1424
|
+
# Apply hybrid search and re-ranking if enabled
|
|
1425
|
+
if hybrid_search and self.vs_config.search.rerank and valid_suggestions:
|
|
1426
|
+
valid_suggestions, suggested_nodes = self._rerank_results(
|
|
1427
|
+
query, list(valid_suggestions.values()
|
|
1428
|
+
), suggested_nodes, number
|
|
1429
|
+
)
|
|
1430
|
+
|
|
1431
|
+
logger.info(f"Search returned {len(valid_suggestions)} results")
|
|
1432
|
+
return valid_suggestions, suggested_nodes
|
|
1433
|
+
|
|
1434
|
+
def _rerank_results(
|
|
1435
|
+
self,
|
|
1436
|
+
query: str,
|
|
1437
|
+
results: List[Tuple[Dict, str, float]],
|
|
1438
|
+
suggested_nodes: List[Node],
|
|
1439
|
+
top_k: int,
|
|
1440
|
+
) -> Tuple[Dict, List[Node]]:
|
|
1441
|
+
"""Re-rank results using hybrid scoring."""
|
|
1442
|
+
logger.debug(f"Re-ranking {len(results)} results")
|
|
1443
|
+
|
|
1444
|
+
# Get hybrid alpha from config
|
|
1445
|
+
alpha = self.vs_config.search.hybrid_alpha
|
|
1446
|
+
|
|
1447
|
+
# Perform full-text search
|
|
1448
|
+
node_ids = [int(res[0]["node_id"]) for res in results]
|
|
1449
|
+
|
|
1450
|
+
if self.use_dimension_specific_tables:
|
|
1451
|
+
# Use raw SQL for full-text search with custom tables
|
|
1452
|
+
with connection.cursor() as cursor:
|
|
1453
|
+
placeholders = ','.join(['%s'] * len(node_ids))
|
|
1454
|
+
cursor.execute(
|
|
1455
|
+
f"""
|
|
1456
|
+
SELECT node_id,
|
|
1457
|
+
ts_rank(to_tsvector('english', content), plainto_tsquery('english', %s)) as rank
|
|
1458
|
+
FROM {self.table_nodeentry}
|
|
1459
|
+
WHERE node_id IN ({placeholders})
|
|
1460
|
+
""",
|
|
1461
|
+
[query] + node_ids
|
|
1462
|
+
)
|
|
1463
|
+
node_id_to_rank = {row[0]: row[1] for row in cursor.fetchall()}
|
|
1464
|
+
else:
|
|
1465
|
+
search_query = SearchQuery(query, config="english")
|
|
1466
|
+
queryset = NodeEntry.objects.filter(
|
|
1467
|
+
node_id__in=node_ids
|
|
1468
|
+
).annotate(
|
|
1469
|
+
rank=SearchRank(SearchVector(
|
|
1470
|
+
"content", config="english"), search_query)
|
|
1471
|
+
)
|
|
1472
|
+
node_id_to_rank = {node.node_id: node.rank for node in queryset}
|
|
1473
|
+
|
|
1474
|
+
# Combine scores
|
|
1475
|
+
reranked_results = []
|
|
1476
|
+
|
|
1477
|
+
for metadata, content, distance in results:
|
|
1478
|
+
node_id = metadata["node_id"]
|
|
1479
|
+
keyword_score = node_id_to_rank.get(node_id, 0.0)
|
|
1480
|
+
|
|
1481
|
+
# Combined score: alpha * vector + (1-alpha) * keyword
|
|
1482
|
+
combined_score = alpha * (1 - distance) + \
|
|
1483
|
+
(1 - alpha) * keyword_score
|
|
1484
|
+
reranked_results.append((metadata, content, combined_score))
|
|
1485
|
+
|
|
1486
|
+
# Sort and take top_k
|
|
1487
|
+
reranked_results = sorted(
|
|
1488
|
+
reranked_results, key=lambda x: x[2], reverse=True)[:top_k]
|
|
1489
|
+
valid_suggestions = {
|
|
1490
|
+
str(res[0]["node_id"]): res for res in reranked_results}
|
|
1491
|
+
|
|
1492
|
+
# Update node order
|
|
1493
|
+
node_id_order = [res[0]["node_id"] for res in reranked_results]
|
|
1494
|
+
updated_nodes = sorted(
|
|
1495
|
+
suggested_nodes,
|
|
1496
|
+
key=lambda node: (
|
|
1497
|
+
node_id_order.index(node.metadata.node_id)
|
|
1498
|
+
if node.metadata.node_id in node_id_order
|
|
1499
|
+
else len(node_id_order)
|
|
1500
|
+
),
|
|
1501
|
+
)[:top_k]
|
|
1502
|
+
|
|
1503
|
+
return valid_suggestions, updated_nodes
|
|
1504
|
+
|
|
1505
|
+
def list_collections(self) -> List[str]:
    """Return the names of all collections.

    The raw-SQL path (dimension-specific tables) returns names sorted
    alphabetically; the ORM path returns them in database order.
    """
    if not self.use_dimension_specific_tables:
        return list(Collection.objects.values_list("name", flat=True))

    sql = f"SELECT name FROM {self.table_collection} ORDER BY name"
    with connection.cursor() as cursor:
        cursor.execute(sql)
        rows = cursor.fetchall()
    return [name for (name,) in rows]
|
|
1514
|
+
|
|
1515
|
+
def get_collection_info(self, collection_name: str) -> Dict[str, Any]:
    """Return summary information about one collection.

    Args:
        collection_name: Name of the collection to describe.

    Returns:
        Dict with keys "name", "embedding_dim", "node_count",
        "created_at" and "updated_at".
    """
    collection = self.get_collection(collection_name)

    if not self.use_dimension_specific_tables:
        node_count = NodeEntry.objects.filter(collection=collection).count()
    else:
        # Custom dimension-specific tables are outside the ORM models,
        # so count via raw SQL.
        count_sql = (
            f"SELECT COUNT(*) FROM {self.table_nodeentry} "
            "WHERE collection_id = %s"
        )
        with connection.cursor() as cursor:
            cursor.execute(count_sql, [collection.id])
            node_count = cursor.fetchone()[0]

    return {
        "name": collection.name,
        "embedding_dim": collection.embedding_dim,
        "node_count": node_count,
        "created_at": collection.created_at,
        "updated_at": collection.updated_at,
    }
|
|
1536
|
+
|
|
1537
|
+
# VectorStore interface methods
|
|
1538
|
+
def add(self, vectors: List[List[float]], metadatas: List[Dict[str, Any]]) -> Any:
    """Add vectors with metadata (VectorStore interface).

    Args:
        vectors: One embedding per entry.
        metadatas: One dict per vector. The keys "content",
            "source_file_uuid", "position" and "collection_name" are
            consumed directly; every other key is stored as custom
            metadata. The collection name is read from the FIRST
            metadata dict only.

    Returns:
        List of node_ids for the created entries (empty list for empty
        input).

    Raises:
        ValueError: If vectors and metadatas differ in length.
    """
    if not vectors or not metadatas:
        logger.warning("Empty vectors or metadatas")
        return []

    if len(vectors) != len(metadatas):
        raise ValueError(
            "Number of vectors must match number of metadatas")

    reserved_keys = {"content", "source_file_uuid", "position", "collection_name"}

    def extract_custom(meta):
        # Everything outside the reserved keys becomes custom metadata.
        return {k: v for k, v in meta.items() if k not in reserved_keys}

    target_name = metadatas[0].get("collection_name", "default_collection")
    collection = self.get_or_create_collection(target_name)

    if self.use_dimension_specific_tables:
        # Raw SQL path: the dimension-specific tables have no ORM model.
        import json

        insert_sql = f"""
            INSERT INTO {self.table_nodeentry}
            (collection_id, content, embedding, source_file_uuid, position, custom_metadata, created_at, updated_at)
            VALUES (%s, %s, %s, %s, %s, %s::jsonb, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
            RETURNING node_id
            """
        inserted_ids = []
        with connection.cursor() as cursor:
            for idx, (embedding, meta) in enumerate(zip(vectors, metadatas)):
                cursor.execute(
                    insert_sql,
                    [
                        collection.id,
                        meta.get("content", ""),
                        embedding,
                        meta.get("source_file_uuid", ""),
                        meta.get("position", idx),
                        json.dumps(extract_custom(meta)),
                    ],
                )
                inserted_ids.append(cursor.fetchone()[0])
        return inserted_ids

    # ORM path: build all entries, then insert them in one round trip.
    entries = [
        NodeEntry(
            collection=collection,
            content=meta.get("content", ""),
            embedding=embedding,
            source_file_uuid=meta.get("source_file_uuid", ""),
            position=meta.get("position", idx),
            custom_metadata=extract_custom(meta),
        )
        for idx, (embedding, meta) in enumerate(zip(vectors, metadatas))
    ]
    created_entries = NodeEntry.objects.bulk_create(entries)
    return [entry.node_id for entry in created_entries]
|
|
1604
|
+
|
|
1605
|
+
def query(self, vector: List[float], top_k: int = 5, **kwargs) -> List[Dict[str, Any]]:
    """Query vector store (VectorStore interface).

    Args:
        vector: Query embedding.
        top_k: Maximum number of results to return.
        **kwargs: Optional "collection_name" (default
            "default_collection") and "distance_type" (defaults to the
            configured similarity metric).

    Returns:
        A list of dicts with "node_id", "content", "metadata" and
        "distance", ordered by ascending distance. Empty list if the
        collection does not exist.
    """
    collection_name = kwargs.get("collection_name", "default_collection")
    distance_type = kwargs.get(
        "distance_type", self.vs_config.search.similarity_metric)

    try:
        collection = self.get_collection(collection_name)
    except ValueError:
        logger.warning(f"Collection '{collection_name}' not found")
        return []

    # Normalize query vector for cosine so it is comparable with the
    # stored embeddings.
    query_embedding = np.array(vector, dtype=np.float32)
    if distance_type == "cosine":
        query_embedding = self._normalize_embedding(query_embedding)

    # Build and execute query.
    distance_operator = self._get_distance_operator(distance_type)
    embedding_str = "[" + ",".join(str(x) for x in query_embedding) + "]"

    # Use the custom table when dimension-specific tables are enabled.
    table_name = self.table_nodeentry if self.use_dimension_specific_tables else NodeEntry._meta.db_table

    sql_query = f"""
    SELECT node_id, content, source_file_uuid, position, custom_metadata,
           embedding {distance_operator} %s::vector AS distance
    FROM {table_name}
    WHERE collection_id = %s
    ORDER BY distance
    LIMIT %s
    """

    with connection.cursor() as cursor:
        cursor.execute(sql_query, [embedding_str, collection.id, top_k])
        results = cursor.fetchall()
        columns = [col[0] for col in cursor.description]

    # Build the column->value mapping once per row; the previous
    # implementation rebuilt dict(zip(columns, row)) for every field.
    output = []
    for row in results:
        record = dict(zip(columns, row))
        output.append({
            "node_id": record["node_id"],
            "content": record["content"],
            "metadata": record["custom_metadata"] or {},
            "distance": float(record["distance"]),
        })
    return output
|
|
1652
|
+
|
|
1653
|
+
def shutdown(self) -> None:
    """Release resources held by this store and its embedding model."""
    logger.info("Shutting down ConfigurablePgVectorStore")
    model = self.embedding_model
    if model:
        model.shutdown()
    super().shutdown()
|
|
1659
|
+
|
|
1660
|
+
|
|
1661
|
+
# Public API of this module.
__all__ = ["ConfigurablePgVectorStore"]
|