rakam-systems-vectorstore 0.1.1rc7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rakam_systems_vectorstore/MANIFEST.in +26 -0
- rakam_systems_vectorstore/README.md +1071 -0
- rakam_systems_vectorstore/__init__.py +93 -0
- rakam_systems_vectorstore/components/__init__.py +0 -0
- rakam_systems_vectorstore/components/chunker/__init__.py +19 -0
- rakam_systems_vectorstore/components/chunker/advanced_chunker.py +1019 -0
- rakam_systems_vectorstore/components/chunker/text_chunker.py +154 -0
- rakam_systems_vectorstore/components/embedding_model/__init__.py +0 -0
- rakam_systems_vectorstore/components/embedding_model/configurable_embeddings.py +546 -0
- rakam_systems_vectorstore/components/embedding_model/openai_embeddings.py +259 -0
- rakam_systems_vectorstore/components/loader/__init__.py +31 -0
- rakam_systems_vectorstore/components/loader/adaptive_loader.py +512 -0
- rakam_systems_vectorstore/components/loader/code_loader.py +699 -0
- rakam_systems_vectorstore/components/loader/doc_loader.py +812 -0
- rakam_systems_vectorstore/components/loader/eml_loader.py +556 -0
- rakam_systems_vectorstore/components/loader/html_loader.py +626 -0
- rakam_systems_vectorstore/components/loader/md_loader.py +622 -0
- rakam_systems_vectorstore/components/loader/odt_loader.py +750 -0
- rakam_systems_vectorstore/components/loader/pdf_loader.py +771 -0
- rakam_systems_vectorstore/components/loader/pdf_loader_light.py +723 -0
- rakam_systems_vectorstore/components/loader/tabular_loader.py +597 -0
- rakam_systems_vectorstore/components/vectorstore/__init__.py +0 -0
- rakam_systems_vectorstore/components/vectorstore/apps.py +10 -0
- rakam_systems_vectorstore/components/vectorstore/configurable_pg_vector_store.py +1661 -0
- rakam_systems_vectorstore/components/vectorstore/faiss_vector_store.py +878 -0
- rakam_systems_vectorstore/components/vectorstore/migrations/0001_initial.py +55 -0
- rakam_systems_vectorstore/components/vectorstore/migrations/__init__.py +0 -0
- rakam_systems_vectorstore/components/vectorstore/models.py +10 -0
- rakam_systems_vectorstore/components/vectorstore/pg_models.py +97 -0
- rakam_systems_vectorstore/components/vectorstore/pg_vector_store.py +827 -0
- rakam_systems_vectorstore/config.py +266 -0
- rakam_systems_vectorstore/core.py +8 -0
- rakam_systems_vectorstore/pyproject.toml +113 -0
- rakam_systems_vectorstore/server/README.md +290 -0
- rakam_systems_vectorstore/server/__init__.py +20 -0
- rakam_systems_vectorstore/server/mcp_server_vector.py +325 -0
- rakam_systems_vectorstore/setup.py +103 -0
- rakam_systems_vectorstore-0.1.1rc7.dist-info/METADATA +370 -0
- rakam_systems_vectorstore-0.1.1rc7.dist-info/RECORD +40 -0
- rakam_systems_vectorstore-0.1.1rc7.dist-info/WHEEL +4 -0
rakam_systems_vectorstore/components/vectorstore/pg_vector_store.py
@@ -0,0 +1,827 @@
import os
import time
from functools import lru_cache
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple

import dotenv
import numpy as np
from django.contrib.postgres.search import SearchQuery
from django.contrib.postgres.search import SearchRank
from django.contrib.postgres.search import SearchVector
from django.db import connection
from django.db import transaction
from openai import OpenAI
from sentence_transformers import SentenceTransformer

from rakam_systems_core.ai_utils import logging
from rakam_systems_core.ai_core.interfaces.vectorstore import VectorStore
from rakam_systems_vectorstore.components.vectorstore.pg_models import Collection
from rakam_systems_vectorstore.components.vectorstore.pg_models import NodeEntry
from rakam_systems_vectorstore.core import Node
from rakam_systems_vectorstore.core import NodeMetadata
from rakam_systems_vectorstore.core import VSFile

logger = logging.getLogger(__name__)

# Load environment variables
dotenv.load_dotenv()
api_key = os.getenv("OPENAI_API_KEY")

class PgVectorStore(VectorStore):
    """
    A class for managing collection-based vector stores using pgvector and Django ORM.
    Enhanced for better semantic search performance with hybrid search, re-ranking, and caching.

    Note: Vector columns are created without dimension constraints, allowing flexibility
    to use different embedding models without needing to alter the database schema.
    """

    def __init__(
        self,
        name: str = "pg_vector_store",
        config=None,
        embedding_model: str = "Snowflake/snowflake-arctic-embed-m",
        use_embedding_api: bool = False,
        api_model: str = "text-embedding-3-small",
    ) -> None:
        """
        Initializes the PgVectorStore with the specified embedding model.

        :param name: Name of the vector store component.
        :param config: Configuration object.
        :param embedding_model: Pre-trained SentenceTransformer model name.
        :param use_embedding_api: Whether to use OpenAI's embedding API instead of a local model.
        :param api_model: OpenAI API model to use for embeddings if use_embedding_api is True.
        """
        super().__init__(name=name, config=config)
        self._ensure_pgvector_extension()
        self.use_embedding_api = use_embedding_api

        if self.use_embedding_api:
            self.client = OpenAI(api_key=api_key)
            self.api_model = api_model
            sample_embedding = self._get_api_embedding("Sample text")
            self.embedding_dim = len(sample_embedding)
        else:
            self.embedding_model = SentenceTransformer(
                embedding_model, trust_remote_code=True
            )
            self.embedding_dim = self.embedding_model.get_sentence_embedding_dimension()

        logger.info(
            f"Initialized PgVectorStore with embedding dimension: {self.embedding_dim}"
        )

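    # Example (editor's sketch, not part of the package). Constructing a store
    # inside a configured Django project; OPENAI_API_KEY only needs to be set
    # when use_embedding_api=True, since the module-level api_key is only used
    # on that path:
    #
    #     store = PgVectorStore(name="docs_store")            # local model
    #     api_store = PgVectorStore(use_embedding_api=True)   # OpenAI API
    #     print(store.embedding_dim)
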
    def _ensure_pgvector_extension(self) -> None:
        """
        Ensures that the pgvector extension is installed in the PostgreSQL database.
        """
        with connection.cursor() as cursor:
            try:
                cursor.execute("CREATE EXTENSION IF NOT EXISTS vector;")
                logger.info("Ensured pgvector extension is installed")
            except Exception as e:
                logger.error(f"Failed to create pgvector extension: {e}")
                raise

    def _get_api_embedding(self, text: str) -> List[float]:
        """
        Gets embedding from OpenAI API.

        :param text: Text to embed
        :return: Embedding vector
        """
        try:
            response = self.client.embeddings.create(
                input=[text], model=self.api_model
            )
            return response.data[0].embedding
        except Exception as e:
            logger.error(f"Failed to get API embedding: {e}")
            raise

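    # Editor's note (not in the package source): functools.lru_cache on a bound
    # method keys the cache on (self, query), so each PgVectorStore instance gets
    # its own entries, and the cache keeps a reference to the instance alive for
    # as long as an entry survives.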
    @lru_cache(maxsize=1000)
    def predict_embeddings(self, query: str) -> np.ndarray:
        """
        Predicts embeddings for a given query using the embedding model.
        Caches results to reduce redundant computations.

        :param query: Query string to encode.
        :return: Normalized embedding vector for the query.
        """
        logger.debug(f"Predicting embeddings for query: {query}")
        start_time = time.time()

        if self.use_embedding_api:
            query_embedding = self._get_api_embedding(query)
            query_embedding = np.array(query_embedding, dtype="float32")
        else:
            query_embedding = self.embedding_model.encode(query)
            query_embedding = np.array(query_embedding, dtype="float32")

        # Normalize embedding for cosine similarity
        norm = np.linalg.norm(query_embedding)
        if norm > 0:
            query_embedding = query_embedding / norm
        else:
            logger.warning(f"Zero norm encountered for query: {query}")

        logger.debug(
            f"Embedding generation took {time.time() - start_time:.2f} seconds"
        )
        return query_embedding

    def get_embeddings(
        self, sentences: List[str], parallel: bool = True, batch_size: int = 8
    ) -> np.ndarray:
        """
        Generates embeddings for a list of sentences with normalization.

        :param sentences: List of sentences to encode.
        :param parallel: Whether to use parallel processing (default is True).
        :param batch_size: Batch size for processing (default is 8).
        :return: Normalized embedding vectors for the sentences.
        """
        logger.info(f"Generating embeddings for {len(sentences)} sentences")
        start = time.time()

        if self.use_embedding_api:
            all_embeddings = []
            for i in range(0, len(sentences), batch_size):
                batch = sentences[i : i + batch_size]
                response = self.client.embeddings.create(
                    input=batch, model=self.api_model
                )
                batch_embeddings = [data.embedding for data in response.data]
                all_embeddings.extend(batch_embeddings)
            embeddings = np.array(all_embeddings, dtype="float32")
        else:
            if parallel:
                os.environ["TOKENIZERS_PARALLELISM"] = "false"
                pool = self.embedding_model.start_multi_process_pool(
                    target_devices=["cpu"] * 5
                )
                embeddings = self.embedding_model.encode_multi_process(
                    sentences, pool, batch_size=batch_size
                )
                self.embedding_model.stop_multi_process_pool(pool)
            else:
                os.environ["TOKENIZERS_PARALLELISM"] = "true"
                embeddings = self.embedding_model.encode(
                    sentences,
                    batch_size=batch_size,
                    show_progress_bar=True,
                    convert_to_tensor=True,
                )
                embeddings = embeddings.cpu().detach().numpy()

        # Normalize embeddings
        norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
        norms[norms == 0] = 1  # Avoid division by zero
        embeddings = embeddings / norms

        logger.info(
            f"Time taken to encode {len(sentences)} items: {time.time() - start:.2f} seconds"
        )
        return embeddings

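    # Example (editor's sketch, not part of the package). Batch-embedding a few
    # strings through the local SentenceTransformer path; the shape assertion
    # assumes whatever dimension the configured model reports:
    #
    #     store = PgVectorStore(use_embedding_api=False)
    #     vecs = store.get_embeddings(["first text", "second text"], parallel=False)
    #     assert vecs.shape == (2, store.embedding_dim)  # rows are unit-normalized
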
    def get_or_create_collection(self, collection_name: str) -> Collection:
        """
        Gets or creates a collection with the specified name.

        :param collection_name: Name of the collection.
        :return: Collection object.
        """
        collection, created = Collection.objects.get_or_create(
            name=collection_name,
            defaults={"embedding_dim": self.embedding_dim},
        )
        logger.info(
            f"{'Created new' if created else 'Using existing'} collection: {collection_name}"
        )
        return collection

    def _rerank_results(
        self,
        query: str,
        results: List[Tuple[Dict, str, float]],
        suggested_nodes: List[Node],
        top_k: int,
    ) -> Tuple[Dict, List[Node]]:
        """
        Re-ranks search results using a combination of vector similarity and keyword relevance.

        :param query: The search query.
        :param results: Initial search results (metadata, content, distance).
        :param suggested_nodes: List of Node objects.
        :param top_k: Number of results to return after re-ranking.
        :return: Tuple of re-ranked results dictionary and updated suggested_nodes.
        """
        logger.debug(f"Re-ranking {len(results)} results for query: {query}")

        # Perform full-text search to get keyword relevance scores
        search_query = SearchQuery(query, config="english")
        queryset = NodeEntry.objects.filter(
            collection__name="document_collection",
            node_id__in=[int(res[0]["node_id"]) for res in results],
        ).annotate(
            rank=SearchRank(SearchVector("content", config="english"), search_query)
        )

        # Combine vector distance and keyword rank
        reranked_results = []
        node_id_to_rank = {node.node_id: node.rank for node in queryset}
        for metadata, content, distance in results:
            node_id = metadata["node_id"]
            keyword_score = node_id_to_rank.get(node_id, 0.0)
            # Combine scores (adjust weights as needed)
            combined_score = 0.7 * (1 - distance) + 0.3 * keyword_score
            reranked_results.append((metadata, content, combined_score))

        # Sort by combined score and take top_k
        reranked_results = sorted(
            reranked_results, key=lambda x: x[2], reverse=True
        )[:top_k]
        valid_suggestions = {
            str(res[0]["node_id"]): res for res in reranked_results
        }

        # Update suggested_nodes to match re-ranked order
        node_id_order = [res[0]["node_id"] for res in reranked_results]
        updated_nodes = sorted(
            suggested_nodes,
            key=lambda node: node_id_order.index(node.metadata.node_id)
            if node.metadata.node_id in node_id_order
            else len(node_id_order),
        )[:top_k]

        logger.debug(f"Re-ranked to {len(valid_suggestions)} results")
        return valid_suggestions, updated_nodes

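    # Worked example (editor's sketch, not part of the package). With the weights
    # above, a node at cosine distance 0.2 with keyword rank 0.1 scores
    # 0.7 * (1 - 0.2) + 0.3 * 0.1 = 0.56 + 0.03 = 0.59, so strong vector matches
    # dominate unless keyword ranks differ sharply. Note the keyword ranks are
    # looked up against the hardcoded "document_collection" name, so nodes from
    # other collections fall back to a keyword score of 0.0.
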
    def search(
        self,
        collection_name: str,
        query: str,
        distance_type: str = "cosine",
        number: int = 5,
        meta_data_filters: Optional[Dict[str, Any]] = None,
        hybrid_search: bool = True,
    ) -> Tuple[Dict, List[Node]]:
        """
        Retrieve relevant documents from the vector store using hybrid search and re-ranking.

        :param collection_name: Name of the collection to search.
        :param query: Search query.
        :param distance_type: Distance metric ("cosine", "l2", "dot").
        :param number: Number of results to return.
        :param meta_data_filters: Dictionary of metadata filters (e.g., {"is_validated": True}).
        :param hybrid_search: Whether to use hybrid search combining vector and keyword search.
        :return: Tuple of search results (dictionary) and suggested nodes.
        """
        logger.info(
            f"Searching in collection: {collection_name} for query: '{query}'"
        )

        try:
            collection = Collection.objects.get(name=collection_name)
        except Collection.DoesNotExist:
            logger.error(f"No collection found with name: {collection_name}")
            raise ValueError(f"No collection found with name: {collection_name}")

        # Generate query embedding
        query_embedding = self.predict_embeddings(query)

        # Build base queryset
        queryset = NodeEntry.objects.filter(collection=collection)

        # Apply metadata filters
        if meta_data_filters:
            for key, value in meta_data_filters.items():
                queryset = queryset.filter(**{f"custom_metadata__{key}": value})

        # Construct SQL query for vector search
        if distance_type == "cosine":
            distance_operator = "<=>"
        elif distance_type == "l2":
            distance_operator = "<->"
        elif distance_type == "dot":
            distance_operator = "<#>"
        else:
            logger.error(f"Unsupported distance type: {distance_type}")
            raise ValueError(f"Unsupported distance type: {distance_type}")

        # Request more results for hybrid search and re-ranking
        search_buffer_factor = 2 if hybrid_search else 1
        limit = number * search_buffer_factor
        embedding_str = "[" + ",".join(str(x) for x in query_embedding) + "]"

        sql_query = f"""
            SELECT
                node_id,
                content,
                source_file_uuid,
                position,
                custom_metadata,
                embedding {distance_operator} %s::vector AS distance
            FROM
                {NodeEntry._meta.db_table}
            WHERE
                collection_id = %s
            ORDER BY
                distance
            LIMIT
                %s
        """

        # Execute vector search
        with connection.cursor() as cursor:
            cursor.execute(sql_query, [embedding_str, collection.id, limit])
            results = cursor.fetchall()
            columns = [col[0] for col in cursor.description]

        # Process vector search results
        valid_suggestions = {}
        suggested_nodes = []
        seen_texts = set()

        for row in results:
            result_dict = dict(zip(columns, row))
            node_id = result_dict["node_id"]
            content = result_dict["content"]
            distance = result_dict["distance"]

            if content not in seen_texts:
                seen_texts.add(content)
                custom_metadata = result_dict["custom_metadata"] or {}
                if isinstance(custom_metadata, str):
                    try:
                        import json

                        custom_metadata = json.loads(custom_metadata)
                    except (json.JSONDecodeError, TypeError):
                        custom_metadata = {}

                metadata = NodeMetadata(
                    source_file_uuid=result_dict["source_file_uuid"],
                    position=result_dict["position"],
                    custom=custom_metadata,
                )
                metadata.node_id = node_id
                node = Node(content=content, metadata=metadata)
                node.embedding = result_dict.get("embedding")
                suggested_nodes.append(node)

                valid_suggestions[str(node_id)] = (
                    {
                        "node_id": node_id,
                        "source_file_uuid": result_dict["source_file_uuid"],
                        "position": result_dict["position"],
                        "custom": custom_metadata,
                    },
                    content,
                    float(distance),
                )

        # Perform hybrid search and re-ranking if enabled
        if hybrid_search:
            valid_suggestions, suggested_nodes = self._rerank_results(
                query, list(valid_suggestions.values()), suggested_nodes, number
            )

        logger.info(f"Search returned {len(valid_suggestions)} results")
        return valid_suggestions, suggested_nodes

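    # Example (editor's sketch, not part of the package). A hybrid search over an
    # existing collection, assuming Django settings are configured and the
    # collection was previously populated:
    #
    #     suggestions, nodes = store.search(
    #         collection_name="articles",
    #         query="how does pgvector index embeddings?",
    #         number=3,
    #         meta_data_filters={"is_validated": True},
    #     )
    #     for node_id, (meta, content, score) in suggestions.items():
    #         print(node_id, round(score, 3), content[:80])
    #
    # With hybrid_search=True the third tuple element is the combined re-rank
    # score (higher is better); with hybrid_search=False it is the raw pgvector
    # distance (lower is better).
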
    @transaction.atomic
    def create_collection_from_files(
        self, collection_name: str, files: List[VSFile]
    ) -> None:
        """
        Creates a collection from a list of VSFile objects.

        :param collection_name: Name of the collection to create.
        :param files: List of VSFile objects containing nodes.
        """
        logger.info(f"Creating collection: {collection_name} from files")
        nodes = [node for file in files for node in file.nodes]
        self.create_collection_from_nodes(collection_name, nodes)

    @transaction.atomic
    def create_collection_from_nodes(
        self, collection_name: str, nodes: List[Node]
    ) -> None:
        """
        Creates a collection from a list of nodes.

        :param collection_name: Name of the collection to create.
        :param nodes: List of Node objects.
        """
        if not nodes:
            logger.warning(
                f"Cannot create collection '{collection_name}' because nodes list is empty"
            )
            return

        # Filter out nodes with None or empty content (these would cause embedding errors)
        original_count = len(nodes)
        nodes = [
            node
            for node in nodes
            if node.content is not None and str(node.content).strip()
        ]

        if len(nodes) < original_count:
            logger.warning(
                f"Filtered out {original_count - len(nodes)} nodes with empty/None content"
            )

        if not nodes:
            logger.warning(
                f"No valid nodes for collection '{collection_name}' after filtering"
            )
            return

        total_nodes = len(nodes)
        logger.info(
            f"Creating collection: {collection_name} with {total_nodes} nodes"
        )

        start_time = time.time()
        collection = self.get_or_create_collection(collection_name)
        NodeEntry.objects.filter(collection=collection).delete()

        # Generate embeddings
        embed_start = time.time()
        text_chunks = [str(node.content) for node in nodes]
        embeddings = self.get_embeddings(text_chunks, parallel=False)
        embed_time = time.time() - embed_start
        logger.info(
            f"Embeddings generated in {embed_time:.2f}s ({total_nodes/embed_time:.0f} nodes/s)"
        )

        # Prepare node entries
        prep_start = time.time()
        node_entries = [
            NodeEntry(
                collection=collection,
                content=node.content,
                embedding=embeddings[i].tolist(),
                source_file_uuid=node.metadata.source_file_uuid,
                position=node.metadata.position,
                custom_metadata=node.metadata.custom or {},
            )
            for i, node in enumerate(nodes)
        ]
        prep_time = time.time() - prep_start
        logger.info(f"Node entries prepared in {prep_time:.2f}s")

        # Bulk insert
        insert_start = time.time()
        created_entries = NodeEntry.objects.bulk_create(node_entries)
        insert_time = time.time() - insert_start
        logger.info(
            f"Bulk insert completed in {insert_time:.2f}s ({total_nodes/insert_time:.0f} nodes/s)"
        )

        for i, node in enumerate(nodes):
            node.metadata.node_id = created_entries[i].node_id

        total_time = time.time() - start_time
        logger.info(
            f"Created collection '{collection_name}' with {len(created_entries)} nodes in {total_time:.2f}s"
        )

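    # Example (editor's sketch, not part of the package). Building Node objects
    # by hand and creating a collection from them; the NodeMetadata fields follow
    # the constructor used in search() above:
    #
    #     nodes = [
    #         Node(
    #             content="pgvector stores embeddings in PostgreSQL.",
    #             metadata=NodeMetadata(
    #                 source_file_uuid="doc-001", position=0, custom={"lang": "en"}
    #             ),
    #         ),
    #     ]
    #     store.create_collection_from_nodes("articles", nodes)
    #
    # Note that this call wipes any existing entries in the collection before
    # inserting; use add_nodes() below to append instead.
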
    @transaction.atomic
    def add_nodes(self, collection_name: str, nodes: List[Node]) -> None:
        """
        Adds nodes to an existing collection.

        :param collection_name: Name of the collection to update.
        :param nodes: List of Node objects to be added.
        """
        if not nodes:
            logger.warning("No nodes to add")
            return

        # Filter out nodes with None or empty content (these would cause embedding errors)
        original_count = len(nodes)
        nodes = [
            node
            for node in nodes
            if node.content is not None and str(node.content).strip()
        ]

        if len(nodes) < original_count:
            logger.warning(
                f"Filtered out {original_count - len(nodes)} nodes with empty/None content"
            )

        if not nodes:
            logger.warning("No valid nodes to add after filtering")
            return

        total_nodes = len(nodes)
        logger.info(
            f"Adding {total_nodes} nodes to collection: {collection_name}"
        )

        start_time = time.time()
        try:
            collection = Collection.objects.get(name=collection_name)
        except Collection.DoesNotExist:
            raise ValueError(f"No collection found with name: {collection_name}")

        # Generate embeddings
        embed_start = time.time()
        text_chunks = [str(node.content) for node in nodes]
        embeddings = self.get_embeddings(text_chunks, parallel=False)
        embed_time = time.time() - embed_start
        logger.info(
            f"Embeddings generated in {embed_time:.2f}s ({total_nodes/embed_time:.0f} nodes/s)"
        )

        # Prepare node entries
        prep_start = time.time()
        node_entries = [
            NodeEntry(
                collection=collection,
                content=node.content,
                embedding=embeddings[i].tolist(),
                source_file_uuid=node.metadata.source_file_uuid,
                position=node.metadata.position,
                custom_metadata=node.metadata.custom or {},
            )
            for i, node in enumerate(nodes)
        ]
        prep_time = time.time() - prep_start
        logger.info(f"Node entries prepared in {prep_time:.2f}s")

        # Bulk insert
        insert_start = time.time()
        created_entries = NodeEntry.objects.bulk_create(node_entries)
        insert_time = time.time() - insert_start
        logger.info(
            f"Bulk insert completed in {insert_time:.2f}s ({total_nodes/insert_time:.0f} nodes/s)"
        )

        for i, node in enumerate(nodes):
            node.metadata.node_id = created_entries[i].node_id

        total_time = time.time() - start_time
        logger.info(
            f"Added {len(created_entries)} nodes to collection '{collection_name}' in {total_time:.2f}s"
        )

    @transaction.atomic
    def delete_nodes(self, collection_name: str, node_ids: List[int]) -> None:
        """
        Deletes nodes from an existing collection.

        :param collection_name: Name of the collection to update.
        :param node_ids: List of node IDs to be deleted.
        """
        if not node_ids:
            logger.warning("No node IDs to delete")
            return

        logger.info(
            f"Deleting {len(node_ids)} nodes from collection: {collection_name}"
        )
        try:
            collection = Collection.objects.get(name=collection_name)
        except Collection.DoesNotExist:
            raise ValueError(f"No collection found with name: {collection_name}")

        existing_ids = set(
            NodeEntry.objects.filter(
                collection=collection, node_id__in=node_ids
            ).values_list("node_id", flat=True)
        )
        missing_ids = set(node_ids) - existing_ids
        if missing_ids:
            logger.warning(
                f"Node ID(s) {missing_ids} not found in collection {collection_name}"
            )

        deleted_count, _ = NodeEntry.objects.filter(
            collection=collection, node_id__in=existing_ids
        ).delete()
        logger.info(
            f"Deleted {deleted_count} nodes from collection '{collection_name}'"
        )

    @transaction.atomic
    def add_files(self, collection_name: str, files: List[VSFile]) -> None:
        """
        Adds file nodes to the specified collection.

        :param collection_name: Name of the collection to update.
        :param files: List of VSFile objects whose nodes are to be added.
        """
        logger.info(f"Adding files to collection: {collection_name}")
        all_nodes = [node for file in files for node in file.nodes]
        self.add_nodes(collection_name, all_nodes)

    @transaction.atomic
    def delete_files(self, collection_name: str, files: List[VSFile]) -> None:
        """
        Deletes file nodes from the specified collection.

        :param collection_name: Name of the collection to update.
        :param files: List of VSFile objects whose nodes are to be deleted.
        """
        logger.info(f"Deleting files from collection: {collection_name}")
        node_ids_to_delete = [
            node.metadata.node_id
            for file in files
            for node in file.nodes
            if node.metadata.node_id
        ]
        if node_ids_to_delete:
            self.delete_nodes(collection_name, node_ids_to_delete)
        else:
            logger.warning("No node IDs found in provided files")

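    # Example (editor's sketch, not part of the package). The file-level helpers
    # rely on node_id being populated: add_nodes() assigns node_id back onto each
    # node's metadata after bulk_create, so the same VSFile objects can later be
    # passed to delete_files():
    #
    #     store.add_files("articles", files)     # populates node.metadata.node_id
    #     store.delete_files("articles", files)  # deletes by those same IDs
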
    def list_collections(self) -> List[str]:
        """
        Lists all available collections.

        :return: List of collection names.
        """
        return list(Collection.objects.values_list("name", flat=True))

    def get_collection_info(self, collection_name: str) -> Dict[str, Any]:
        """
        Gets information about a collection.

        :param collection_name: Name of the collection.
        :return: Dictionary containing collection information.
        """
        try:
            collection = Collection.objects.get(name=collection_name)
        except Collection.DoesNotExist:
            raise ValueError(f"No collection found with name: {collection_name}")

        node_count = NodeEntry.objects.filter(collection=collection).count()
        return {
            "name": collection.name,
            "embedding_dim": collection.embedding_dim,
            "node_count": node_count,
            "created_at": collection.created_at,
            "updated_at": collection.updated_at,
        }

    @transaction.atomic
    def delete_collection(self, collection_name: str) -> None:
        """
        Deletes a collection and all its nodes.

        :param collection_name: Name of the collection to delete.
        """
        try:
            collection = Collection.objects.get(name=collection_name)
        except Collection.DoesNotExist:
            raise ValueError(f"No collection found with name: {collection_name}")

        node_count = NodeEntry.objects.filter(collection=collection).count()
        collection.delete()
        logger.info(
            f"Deleted collection '{collection_name}' with {node_count} nodes"
        )

    # VectorStore interface methods
    def add(self, vectors: List[List[float]], metadatas: List[Dict[str, Any]]) -> Any:
        """
        Adds vectors with metadata to the default collection.
        This method implements the VectorStore interface.

        :param vectors: List of embedding vectors to add.
        :param metadatas: List of metadata dictionaries for each vector.
        :return: List of node IDs that were created.
        """
        if not vectors or not metadatas:
            logger.warning("Empty vectors or metadatas provided to add()")
            return []

        if len(vectors) != len(metadatas):
            raise ValueError("Number of vectors must match number of metadatas")

        # Get or create default collection
        collection_name = metadatas[0].get("collection_name", "default_collection")
        collection = self.get_or_create_collection(collection_name)

        # Create nodes from vectors and metadatas
        node_entries = []
        for i, (vector, metadata) in enumerate(zip(vectors, metadatas)):
            content = metadata.get("content", "")
            source_file_uuid = metadata.get("source_file_uuid", "")
            position = metadata.get("position", i)
            custom_metadata = {
                k: v
                for k, v in metadata.items()
                if k not in ["content", "source_file_uuid", "position", "collection_name"]
            }

            node_entries.append(
                NodeEntry(
                    collection=collection,
                    content=content,
                    embedding=vector,
                    source_file_uuid=source_file_uuid,
                    position=position,
                    custom_metadata=custom_metadata,
                )
            )

        created_entries = NodeEntry.objects.bulk_create(node_entries)
        node_ids = [entry.node_id for entry in created_entries]
        logger.info(
            f"Added {len(node_ids)} vectors to collection '{collection_name}'"
        )
        return node_ids

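    # Example (editor's sketch, not part of the package). Using the raw
    # VectorStore interface with pre-computed vectors; any metadata key other
    # than content/source_file_uuid/position/collection_name lands in
    # custom_metadata:
    #
    #     ids = store.add(
    #         vectors=store.get_embeddings(["some text"], parallel=False).tolist(),
    #         metadatas=[{"content": "some text", "collection_name": "scratch",
    #                     "topic": "demo"}],
    #     )
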
    def query(
        self, vector: List[float], top_k: int = 5, **kwargs
    ) -> List[Dict[str, Any]]:
        """
        Queries the vector store for similar vectors.
        This method implements the VectorStore interface.

        :param vector: Query vector.
        :param top_k: Number of results to return.
        :param kwargs: Additional parameters (collection_name, distance_type, meta_data_filters).
        :return: List of dictionaries containing search results.
        """
        collection_name = kwargs.get("collection_name", "default_collection")
        distance_type = kwargs.get("distance_type", "cosine")
        meta_data_filters = kwargs.get("meta_data_filters")

        try:
            collection = Collection.objects.get(name=collection_name)
        except Collection.DoesNotExist:
            logger.warning(f"Collection '{collection_name}' not found")
            return []

        # Convert vector to numpy array
        query_embedding = np.array(vector, dtype="float32")

        # Normalize if using cosine distance
        if distance_type == "cosine":
            norm = np.linalg.norm(query_embedding)
            if norm > 0:
                query_embedding = query_embedding / norm

        # Build queryset
        queryset = NodeEntry.objects.filter(collection=collection)

        # Apply metadata filters
        if meta_data_filters:
            for key, value in meta_data_filters.items():
                queryset = queryset.filter(**{f"custom_metadata__{key}": value})

        # Determine distance operator
        if distance_type == "cosine":
            distance_operator = "<=>"
        elif distance_type == "l2":
            distance_operator = "<->"
        elif distance_type == "dot":
            distance_operator = "<#>"
        else:
            raise ValueError(f"Unsupported distance type: {distance_type}")

        embedding_str = "[" + ",".join(str(x) for x in query_embedding) + "]"

        sql_query = f"""
            SELECT
                node_id,
                content,
                source_file_uuid,
                position,
                custom_metadata,
                embedding {distance_operator} %s::vector AS distance
            FROM
                {NodeEntry._meta.db_table}
            WHERE
                collection_id = %s
            ORDER BY
                distance
            LIMIT
                %s
        """

        # Execute query
        with connection.cursor() as cursor:
            cursor.execute(sql_query, [embedding_str, collection.id, top_k])
            results = cursor.fetchall()
            columns = [col[0] for col in cursor.description]

        # Format results
        formatted_results = []
        for row in results:
            result_dict = dict(zip(columns, row))
            formatted_results.append({
                "node_id": result_dict["node_id"],
                "content": result_dict["content"],
                "source_file_uuid": result_dict["source_file_uuid"],
                "position": result_dict["position"],
                "metadata": result_dict["custom_metadata"] or {},
                "distance": float(result_dict["distance"]),
            })

        logger.info(f"Query returned {len(formatted_results)} results")
        return formatted_results
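    # Example (editor's sketch, not part of the package). Round-tripping through
    # the interface methods; distances come straight from the pgvector operators,
    # so for "cosine" a smaller value means a closer match:
    #
    #     results = store.query(
    #         vector=store.predict_embeddings("demo").tolist(),
    #         top_k=3,
    #         collection_name="scratch",
    #     )
    #     for r in results:
    #         print(r["node_id"], round(r["distance"], 4), r["content"][:60])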