rakam-systems-vectorstore 0.1.1rc7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. rakam_systems_vectorstore/MANIFEST.in +26 -0
  2. rakam_systems_vectorstore/README.md +1071 -0
  3. rakam_systems_vectorstore/__init__.py +93 -0
  4. rakam_systems_vectorstore/components/__init__.py +0 -0
  5. rakam_systems_vectorstore/components/chunker/__init__.py +19 -0
  6. rakam_systems_vectorstore/components/chunker/advanced_chunker.py +1019 -0
  7. rakam_systems_vectorstore/components/chunker/text_chunker.py +154 -0
  8. rakam_systems_vectorstore/components/embedding_model/__init__.py +0 -0
  9. rakam_systems_vectorstore/components/embedding_model/configurable_embeddings.py +546 -0
  10. rakam_systems_vectorstore/components/embedding_model/openai_embeddings.py +259 -0
  11. rakam_systems_vectorstore/components/loader/__init__.py +31 -0
  12. rakam_systems_vectorstore/components/loader/adaptive_loader.py +512 -0
  13. rakam_systems_vectorstore/components/loader/code_loader.py +699 -0
  14. rakam_systems_vectorstore/components/loader/doc_loader.py +812 -0
  15. rakam_systems_vectorstore/components/loader/eml_loader.py +556 -0
  16. rakam_systems_vectorstore/components/loader/html_loader.py +626 -0
  17. rakam_systems_vectorstore/components/loader/md_loader.py +622 -0
  18. rakam_systems_vectorstore/components/loader/odt_loader.py +750 -0
  19. rakam_systems_vectorstore/components/loader/pdf_loader.py +771 -0
  20. rakam_systems_vectorstore/components/loader/pdf_loader_light.py +723 -0
  21. rakam_systems_vectorstore/components/loader/tabular_loader.py +597 -0
  22. rakam_systems_vectorstore/components/vectorstore/__init__.py +0 -0
  23. rakam_systems_vectorstore/components/vectorstore/apps.py +10 -0
  24. rakam_systems_vectorstore/components/vectorstore/configurable_pg_vector_store.py +1661 -0
  25. rakam_systems_vectorstore/components/vectorstore/faiss_vector_store.py +878 -0
  26. rakam_systems_vectorstore/components/vectorstore/migrations/0001_initial.py +55 -0
  27. rakam_systems_vectorstore/components/vectorstore/migrations/__init__.py +0 -0
  28. rakam_systems_vectorstore/components/vectorstore/models.py +10 -0
  29. rakam_systems_vectorstore/components/vectorstore/pg_models.py +97 -0
  30. rakam_systems_vectorstore/components/vectorstore/pg_vector_store.py +827 -0
  31. rakam_systems_vectorstore/config.py +266 -0
  32. rakam_systems_vectorstore/core.py +8 -0
  33. rakam_systems_vectorstore/pyproject.toml +113 -0
  34. rakam_systems_vectorstore/server/README.md +290 -0
  35. rakam_systems_vectorstore/server/__init__.py +20 -0
  36. rakam_systems_vectorstore/server/mcp_server_vector.py +325 -0
  37. rakam_systems_vectorstore/setup.py +103 -0
  38. rakam_systems_vectorstore-0.1.1rc7.dist-info/METADATA +370 -0
  39. rakam_systems_vectorstore-0.1.1rc7.dist-info/RECORD +40 -0
  40. rakam_systems_vectorstore-0.1.1rc7.dist-info/WHEEL +4 -0
rakam_systems_vectorstore/components/vectorstore/pg_vector_store.py
@@ -0,0 +1,827 @@
+ import os
+ import time
+ from functools import lru_cache
+ from typing import Any
+ from typing import Dict
+ from typing import List
+ from typing import Optional
+ from typing import Tuple
+
+ import dotenv
+ import numpy as np
+ from django.contrib.postgres.search import SearchQuery
+ from django.contrib.postgres.search import SearchRank
+ from django.contrib.postgres.search import SearchVector
+ from django.db import connection
+ from django.db import transaction
+ from openai import OpenAI
+ from sentence_transformers import SentenceTransformer
+
+ from rakam_systems_core.ai_utils import logging
+ from rakam_systems_core.ai_core.interfaces.vectorstore import VectorStore
+ from rakam_systems_vectorstore.components.vectorstore.pg_models import Collection
+ from rakam_systems_vectorstore.components.vectorstore.pg_models import NodeEntry
+ from rakam_systems_vectorstore.core import Node
+ from rakam_systems_vectorstore.core import NodeMetadata
+ from rakam_systems_vectorstore.core import VSFile
+
+ logger = logging.getLogger(__name__)
+
+ # Load environment variables
+ dotenv.load_dotenv()
+ api_key = os.getenv("OPENAI_API_KEY")
+
+
+ class PgVectorStore(VectorStore):
+     """
+     A class for managing collection-based vector stores using pgvector and Django ORM.
+     Enhanced for better semantic search performance with hybrid search, re-ranking, and caching.
+
+     Note: Vector columns are created without dimension constraints, allowing flexibility
+     to use different embedding models without needing to alter the database schema.
+     """
+
+     def __init__(
+         self,
+         name: str = "pg_vector_store",
+         config=None,
+         embedding_model: str = "Snowflake/snowflake-arctic-embed-m",
+         use_embedding_api: bool = False,
+         api_model: str = "text-embedding-3-small",
+     ) -> None:
+         """
+         Initializes the PgVectorStore with the specified embedding model.
+
+         :param name: Name of the vector store component.
+         :param config: Configuration object.
+         :param embedding_model: Pre-trained SentenceTransformer model name.
+         :param use_embedding_api: Whether to use OpenAI's embedding API instead of a local model.
+         :param api_model: OpenAI API model to use for embeddings if use_embedding_api is True.
+         """
+         super().__init__(name=name, config=config)
+         self._ensure_pgvector_extension()
+         self.use_embedding_api = use_embedding_api
+
+         if self.use_embedding_api:
+             self.client = OpenAI(api_key=api_key)
+             self.api_model = api_model
+             sample_embedding = self._get_api_embedding("Sample text")
+             self.embedding_dim = len(sample_embedding)
+         else:
+             self.embedding_model = SentenceTransformer(
+                 embedding_model, trust_remote_code=True
+             )
+             self.embedding_dim = self.embedding_model.get_sentence_embedding_dimension()
+
+         logger.info(
+             f"Initialized PgVectorStore with embedding dimension: {self.embedding_dim}"
+         )
+
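For orientation, an illustrative instantiation sketch (not part of the wheel); it assumes Django is already configured in the host project and, for the API path, that OPENAI_API_KEY is set:

    # Local SentenceTransformer embeddings (the constructor's default model)
    store = PgVectorStore(name="docs_store")

    # OpenAI-hosted embeddings instead of a local model
    api_store = PgVectorStore(
        name="docs_store_api",
        use_embedding_api=True,
        api_model="text-embedding-3-small",
    )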
+     def _ensure_pgvector_extension(self) -> None:
+         """
+         Ensures that the pgvector extension is installed in the PostgreSQL database.
+         """
+         with connection.cursor() as cursor:
+             try:
+                 cursor.execute("CREATE EXTENSION IF NOT EXISTS vector;")
+                 logger.info("Ensured pgvector extension is installed")
+             except Exception as e:
+                 logger.error(f"Failed to create pgvector extension: {e}")
+                 raise
+
+     def _get_api_embedding(self, text: str) -> List[float]:
+         """
+         Gets an embedding from the OpenAI API.
+
+         :param text: Text to embed.
+         :return: Embedding vector.
+         """
+         try:
+             response = self.client.embeddings.create(
+                 input=[text], model=self.api_model)
+             return response.data[0].embedding
+         except Exception as e:
+             logger.error(f"Failed to get API embedding: {e}")
+             raise
+
+     @lru_cache(maxsize=1000)
+     def predict_embeddings(self, query: str) -> np.ndarray:
+         """
+         Predicts embeddings for a given query using the embedding model.
+         Caches results to reduce redundant computations.
+
+         :param query: Query string to encode.
+         :return: Normalized embedding vector for the query.
+         """
+         logger.debug(f"Predicting embeddings for query: {query}")
+         start_time = time.time()
+
+         if self.use_embedding_api:
+             query_embedding = self._get_api_embedding(query)
+             query_embedding = np.array(query_embedding, dtype="float32")
+         else:
+             query_embedding = self.embedding_model.encode(query)
+             query_embedding = np.array(query_embedding, dtype="float32")
+
+         # Normalize embedding for cosine similarity
+         norm = np.linalg.norm(query_embedding)
+         if norm > 0:
+             query_embedding = query_embedding / norm
+         else:
+             logger.warning(f"Zero norm encountered for query: {query}")
+
+         logger.debug(
+             f"Embedding generation took {time.time() - start_time:.2f} seconds"
+         )
+         return query_embedding
+
+     def get_embeddings(
+         self, sentences: List[str], parallel: bool = True, batch_size: int = 8
+     ) -> np.ndarray:
+         """
+         Generates embeddings for a list of sentences with normalization.
+
+         :param sentences: List of sentences to encode.
+         :param parallel: Whether to use parallel processing (default is True).
+         :param batch_size: Batch size for processing (default is 8).
+         :return: Normalized embedding vectors for the sentences.
+         """
+         logger.info(f"Generating embeddings for {len(sentences)} sentences")
+         start = time.time()
+
+         if self.use_embedding_api:
+             all_embeddings = []
+             for i in range(0, len(sentences), batch_size):
+                 batch = sentences[i: i + batch_size]
+                 response = self.client.embeddings.create(
+                     input=batch, model=self.api_model
+                 )
+                 batch_embeddings = [data.embedding for data in response.data]
+                 all_embeddings.extend(batch_embeddings)
+             embeddings = np.array(all_embeddings, dtype="float32")
+         else:
+             if parallel:
+                 os.environ["TOKENIZERS_PARALLELISM"] = "false"
+                 pool = self.embedding_model.start_multi_process_pool(
+                     target_devices=["cpu"] * 5
+                 )
+                 embeddings = self.embedding_model.encode_multi_process(
+                     sentences, pool, batch_size=batch_size
+                 )
+                 self.embedding_model.stop_multi_process_pool(pool)
+             else:
+                 os.environ["TOKENIZERS_PARALLELISM"] = "true"
+                 embeddings = self.embedding_model.encode(
+                     sentences,
+                     batch_size=batch_size,
+                     show_progress_bar=True,
+                     convert_to_tensor=True,
+                 )
+                 embeddings = embeddings.cpu().detach().numpy()
+
+         # Normalize embeddings
+         norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
+         norms[norms == 0] = 1  # Avoid division by zero
+         embeddings = embeddings / norms
+
+         logger.info(
+             f"Time taken to encode {len(sentences)} items: {time.time() - start:.2f} seconds"
+         )
+         return embeddings
+
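Both embedding paths above return L2-normalized vectors, so cosine similarity reduces to a plain dot product. A small illustrative check (not from the package), reusing the `store` instance sketched earlier:

    import numpy as np

    q = store.predict_embeddings("how do I reset my password?")
    docs = store.get_embeddings(["Reset your password from the login page."], parallel=False)
    cosine_sim = float(np.dot(docs[0], q))  # equals 1 - cosine distance used by pgvector's <=> operator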
+     def get_or_create_collection(self, collection_name: str) -> Collection:
+         """
+         Gets or creates a collection with the specified name.
+
+         :param collection_name: Name of the collection.
+         :return: Collection object.
+         """
+         collection, created = Collection.objects.get_or_create(
+             name=collection_name, defaults={"embedding_dim": self.embedding_dim}
+         )
+         logger.info(
+             f"{'Created new' if created else 'Using existing'} collection: {collection_name}"
+         )
+         return collection
+
+     def _rerank_results(
+         self,
+         query: str,
+         results: List[Tuple[Dict, str, float]],
+         suggested_nodes: List[Node],
+         top_k: int,
+         collection_name: str = "document_collection",
+     ) -> Tuple[Dict, List[Node]]:
+         """
+         Re-ranks search results using a combination of vector similarity and keyword relevance.
+
+         :param query: The search query.
+         :param results: Initial search results (metadata, content, distance).
+         :param suggested_nodes: List of Node objects.
+         :param top_k: Number of results to return after re-ranking.
+         :param collection_name: Collection whose entries are used for keyword ranking.
+         :return: Tuple of re-ranked results dictionary and updated suggested_nodes.
+         """
+         logger.debug(f"Re-ranking {len(results)} results for query: {query}")
+
+         # Perform full-text search to get keyword relevance scores
+         search_query = SearchQuery(query, config="english")
+         queryset = NodeEntry.objects.filter(
+             collection__name=collection_name,
+             node_id__in=[int(res[0]["node_id"]) for res in results],
+         ).annotate(
+             rank=SearchRank(SearchVector("content", config="english"), search_query)
+         )
+
+         # Combine vector distance and keyword rank
+         reranked_results = []
+         node_id_to_rank = {node.node_id: node.rank for node in queryset}
+         for metadata, content, distance in results:
+             node_id = metadata["node_id"]
+             keyword_score = node_id_to_rank.get(node_id, 0.0)
+             # Combine scores (adjust weights as needed)
+             combined_score = 0.7 * (1 - distance) + 0.3 * keyword_score
+             reranked_results.append((metadata, content, combined_score))
+
+         # Sort by combined score and take top_k
+         reranked_results = sorted(
+             reranked_results, key=lambda x: x[2], reverse=True
+         )[:top_k]
+         valid_suggestions = {str(res[0]["node_id"]): res for res in reranked_results}
+
+         # Update suggested_nodes to match re-ranked order
+         node_id_order = [res[0]["node_id"] for res in reranked_results]
+         updated_nodes = sorted(
+             suggested_nodes,
+             key=lambda node: node_id_order.index(node.metadata.node_id)
+             if node.metadata.node_id in node_id_order
+             else len(node_id_order),
+         )[:top_k]
+
+         logger.debug(f"Re-ranked to {len(valid_suggestions)} results")
+         return valid_suggestions, updated_nodes
+
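To make the 0.7 / 0.3 blend above concrete, a worked example with illustrative numbers only:

    vector_distance = 0.20   # cosine distance reported by pgvector
    keyword_rank = 0.05      # SearchRank score from PostgreSQL full-text search
    combined = 0.7 * (1 - vector_distance) + 0.3 * keyword_rank   # = 0.575
    # A candidate with distance 0.35 but keyword rank 0.40 also scores
    # 0.7 * 0.65 + 0.3 * 0.40 = 0.575: keyword relevance can offset vector distance.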
+     def search(
+         self,
+         collection_name: str,
+         query: str,
+         distance_type: str = "cosine",
+         number: int = 5,
+         meta_data_filters: Optional[Dict[str, Any]] = None,
+         hybrid_search: bool = True,
+     ) -> Tuple[Dict, List[Node]]:
+         """
+         Retrieves relevant documents from the vector store using hybrid search and re-ranking.
+
+         :param collection_name: Name of the collection to search.
+         :param query: Search query.
+         :param distance_type: Distance metric ("cosine", "l2", "dot").
+         :param number: Number of results to return.
+         :param meta_data_filters: Dictionary of metadata filters (e.g., {"is_validated": True}).
+         :param hybrid_search: Whether to use hybrid search combining vector and keyword search.
+         :return: Tuple of search results (dictionary) and suggested nodes.
+         """
+         logger.info(
+             f"Searching in collection: {collection_name} for query: '{query}'")
+
+         try:
+             collection = Collection.objects.get(name=collection_name)
+         except Collection.DoesNotExist:
+             logger.error(f"No collection found with name: {collection_name}")
+             raise ValueError(
+                 f"No collection found with name: {collection_name}")
+
+         # Generate query embedding
+         query_embedding = self.predict_embeddings(query)
+
+         # Build base queryset
+         queryset = NodeEntry.objects.filter(collection=collection)
+
+         # Apply metadata filters and restrict the SQL search to the matching node ids
+         allowed_node_ids = None
+         if meta_data_filters:
+             for key, value in meta_data_filters.items():
+                 queryset = queryset.filter(
+                     **{f"custom_metadata__{key}": value})
+             allowed_node_ids = list(queryset.values_list("node_id", flat=True))
+             if not allowed_node_ids:
+                 logger.info("No nodes match the given metadata filters")
+                 return {}, []
+
+         # Construct SQL query for vector search
+         if distance_type == "cosine":
+             distance_operator = "<=>"
+         elif distance_type == "l2":
+             distance_operator = "<->"
+         elif distance_type == "dot":
+             distance_operator = "<#>"
+         else:
+             logger.error(f"Unsupported distance type: {distance_type}")
+             raise ValueError(f"Unsupported distance type: {distance_type}")
+
+         # Request more results for hybrid search and re-ranking
+         search_buffer_factor = 2 if hybrid_search else 1
+         limit = number * search_buffer_factor
+         embedding_str = "[" + ",".join(str(x) for x in query_embedding) + "]"
+
+         id_filter_clause = "AND node_id = ANY(%s)" if allowed_node_ids is not None else ""
+         sql_query = f"""
+             SELECT
+                 node_id,
+                 content,
+                 source_file_uuid,
+                 position,
+                 custom_metadata,
+                 embedding {distance_operator} %s::vector AS distance
+             FROM
+                 {NodeEntry._meta.db_table}
+             WHERE
+                 collection_id = %s
+                 {id_filter_clause}
+             ORDER BY
+                 distance
+             LIMIT
+                 %s
+         """
+         params: List[Any] = [embedding_str, collection.id]
+         if allowed_node_ids is not None:
+             params.append(allowed_node_ids)
+         params.append(limit)
+
+         # Execute vector search
+         with connection.cursor() as cursor:
+             cursor.execute(sql_query, params)
+             results = cursor.fetchall()
+             columns = [col[0] for col in cursor.description]
+
+         # Process vector search results
+         valid_suggestions = {}
+         suggested_nodes = []
+         seen_texts = set()
+
+         for row in results:
+             result_dict = dict(zip(columns, row))
+             node_id = result_dict["node_id"]
+             content = result_dict["content"]
+             distance = result_dict["distance"]
+
+             if content not in seen_texts:
+                 seen_texts.add(content)
+                 custom_metadata = result_dict["custom_metadata"] or {}
+                 if isinstance(custom_metadata, str):
+                     try:
+                         import json
+
+                         custom_metadata = json.loads(custom_metadata)
+                     except (json.JSONDecodeError, TypeError):
+                         custom_metadata = {}
+
+                 metadata = NodeMetadata(
+                     source_file_uuid=result_dict["source_file_uuid"],
+                     position=result_dict["position"],
+                     custom=custom_metadata,
+                 )
+                 metadata.node_id = node_id
+                 node = Node(content=content, metadata=metadata)
+                 node.embedding = result_dict.get("embedding")
+                 suggested_nodes.append(node)
+
+                 valid_suggestions[str(node_id)] = (
+                     {
+                         "node_id": node_id,
+                         "source_file_uuid": result_dict["source_file_uuid"],
+                         "position": result_dict["position"],
+                         "custom": custom_metadata,
+                     },
+                     content,
+                     float(distance),
+                 )
+
+         # Perform hybrid search and re-ranking if enabled
+         if hybrid_search:
+             valid_suggestions, suggested_nodes = self._rerank_results(
+                 query,
+                 list(valid_suggestions.values()),
+                 suggested_nodes,
+                 number,
+                 collection_name=collection_name,
+             )
+
+         logger.info(f"Search returned {len(valid_suggestions)} results")
+         return valid_suggestions, suggested_nodes
+
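An illustrative call to search() (not from the package), assuming a collection named "docs" has already been created:

    results, nodes = store.search(
        collection_name="docs",
        query="how to rotate API keys",
        number=3,
        meta_data_filters={"is_validated": True},
        hybrid_search=True,
    )
    # With hybrid_search=True the third tuple element is the combined score; otherwise it is the raw distance.
    for node_id, (meta, content, score) in results.items():
        print(node_id, round(score, 3), content[:80])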
+     @transaction.atomic
+     def create_collection_from_files(
+         self, collection_name: str, files: List[VSFile]
+     ) -> None:
+         """
+         Creates a collection from a list of VSFile objects.
+
+         :param collection_name: Name of the collection to create.
+         :param files: List of VSFile objects containing nodes.
+         """
+         logger.info(f"Creating collection: {collection_name} from files")
+         nodes = [node for file in files for node in file.nodes]
+         self.create_collection_from_nodes(collection_name, nodes)
+
+     @transaction.atomic
+     def create_collection_from_nodes(
+         self, collection_name: str, nodes: List[Node]
+     ) -> None:
+         """
+         Creates a collection from a list of nodes.
+
+         :param collection_name: Name of the collection to create.
+         :param nodes: List of Node objects.
+         """
+         if not nodes:
+             logger.warning(
+                 f"Cannot create collection '{collection_name}' because nodes list is empty"
+             )
+             return
+
+         # Filter out nodes with None or empty content (these would cause embedding errors)
+         original_count = len(nodes)
+         nodes = [
+             node for node in nodes
+             if node.content is not None and str(node.content).strip()
+         ]
+
+         if len(nodes) < original_count:
+             logger.warning(
+                 f"Filtered out {original_count - len(nodes)} nodes with empty/None content")
+
+         if not nodes:
+             logger.warning(
+                 f"No valid nodes for collection '{collection_name}' after filtering")
+             return
+
+         total_nodes = len(nodes)
+         logger.info(
+             f"Creating collection: {collection_name} with {total_nodes} nodes")
+
+         start_time = time.time()
+         collection = self.get_or_create_collection(collection_name)
+         NodeEntry.objects.filter(collection=collection).delete()
+
+         # Generate embeddings
+         embed_start = time.time()
+         text_chunks = [str(node.content) for node in nodes]
+         embeddings = self.get_embeddings(text_chunks, parallel=False)
+         embed_time = time.time() - embed_start
+         logger.info(
+             f"Embeddings generated in {embed_time:.2f}s ({total_nodes/embed_time:.0f} nodes/s)")
+
+         # Prepare node entries
+         prep_start = time.time()
+         node_entries = [
+             NodeEntry(
+                 collection=collection,
+                 content=node.content,
+                 embedding=embeddings[i].tolist(),
+                 source_file_uuid=node.metadata.source_file_uuid,
+                 position=node.metadata.position,
+                 custom_metadata=node.metadata.custom or {},
+             )
+             for i, node in enumerate(nodes)
+         ]
+         prep_time = time.time() - prep_start
+         logger.info(f"Node entries prepared in {prep_time:.2f}s")
+
+         # Bulk insert
+         insert_start = time.time()
+         created_entries = NodeEntry.objects.bulk_create(node_entries)
+         insert_time = time.time() - insert_start
+         logger.info(
+             f"Bulk insert completed in {insert_time:.2f}s ({total_nodes/insert_time:.0f} nodes/s)")
+
+         for i, node in enumerate(nodes):
+             node.metadata.node_id = created_entries[i].node_id
+
+         total_time = time.time() - start_time
+         logger.info(
+             f"Created collection '{collection_name}' with {len(created_entries)} nodes in {total_time:.2f}s"
+         )
+
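A minimal indexing sketch (illustrative, not from the package); Node and NodeMetadata are constructed the same way this file constructs them in search():

    docs_nodes = [
        Node(
            content="pgvector stores embeddings inside PostgreSQL.",
            metadata=NodeMetadata(source_file_uuid="doc-001", position=0, custom={"lang": "en"}),
        ),
        Node(
            content="Hybrid search combines vector distance with keyword rank.",
            metadata=NodeMetadata(source_file_uuid="doc-001", position=1, custom={"lang": "en"}),
        ),
    ]
    store.create_collection_from_nodes("docs", docs_nodes)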
+     @transaction.atomic
+     def add_nodes(self, collection_name: str, nodes: List[Node]) -> None:
+         """
+         Adds nodes to an existing collection.
+
+         :param collection_name: Name of the collection to update.
+         :param nodes: List of Node objects to be added.
+         """
+         if not nodes:
+             logger.warning("No nodes to add")
+             return
+
+         # Filter out nodes with None or empty content (these would cause embedding errors)
+         original_count = len(nodes)
+         nodes = [
+             node for node in nodes
+             if node.content is not None and str(node.content).strip()
+         ]
+
+         if len(nodes) < original_count:
+             logger.warning(
+                 f"Filtered out {original_count - len(nodes)} nodes with empty/None content")
+
+         if not nodes:
+             logger.warning("No valid nodes to add after filtering")
+             return
+
+         total_nodes = len(nodes)
+         logger.info(
+             f"Adding {total_nodes} nodes to collection: {collection_name}")
+
+         start_time = time.time()
+         try:
+             collection = Collection.objects.get(name=collection_name)
+         except Collection.DoesNotExist:
+             raise ValueError(
+                 f"No collection found with name: {collection_name}")
+
+         # Generate embeddings
+         embed_start = time.time()
+         text_chunks = [str(node.content) for node in nodes]
+         embeddings = self.get_embeddings(text_chunks, parallel=False)
+         embed_time = time.time() - embed_start
+         logger.info(
+             f"Embeddings generated in {embed_time:.2f}s ({total_nodes/embed_time:.0f} nodes/s)")
+
+         # Prepare node entries
+         prep_start = time.time()
+         node_entries = [
+             NodeEntry(
+                 collection=collection,
+                 content=node.content,
+                 embedding=embeddings[i].tolist(),
+                 source_file_uuid=node.metadata.source_file_uuid,
+                 position=node.metadata.position,
+                 custom_metadata=node.metadata.custom or {},
+             )
+             for i, node in enumerate(nodes)
+         ]
+         prep_time = time.time() - prep_start
+         logger.info(f"Node entries prepared in {prep_time:.2f}s")
+
+         # Bulk insert
+         insert_start = time.time()
+         created_entries = NodeEntry.objects.bulk_create(node_entries)
+         insert_time = time.time() - insert_start
+         logger.info(
+             f"Bulk insert completed in {insert_time:.2f}s ({total_nodes/insert_time:.0f} nodes/s)")
+
+         for i, node in enumerate(nodes):
+             node.metadata.node_id = created_entries[i].node_id
+
+         total_time = time.time() - start_time
+         logger.info(
+             f"Added {len(created_entries)} nodes to collection '{collection_name}' in {total_time:.2f}s"
+         )
+
+     @transaction.atomic
+     def delete_nodes(self, collection_name: str, node_ids: List[int]) -> None:
+         """
+         Deletes nodes from an existing collection.
+
+         :param collection_name: Name of the collection to update.
+         :param node_ids: List of node IDs to be deleted.
+         """
+         if not node_ids:
+             logger.warning("No node IDs to delete")
+             return
+
+         logger.info(
+             f"Deleting {len(node_ids)} nodes from collection: {collection_name}"
+         )
+         try:
+             collection = Collection.objects.get(name=collection_name)
+         except Collection.DoesNotExist:
+             raise ValueError(
+                 f"No collection found with name: {collection_name}")
+
+         existing_ids = set(
+             NodeEntry.objects.filter(
+                 collection=collection, node_id__in=node_ids
+             ).values_list("node_id", flat=True)
+         )
+         missing_ids = set(node_ids) - existing_ids
+         if missing_ids:
+             logger.warning(
+                 f"Node ID(s) {missing_ids} not found in collection {collection_name}"
+             )
+
+         deleted_count, _ = NodeEntry.objects.filter(
+             collection=collection, node_id__in=existing_ids
+         ).delete()
+         logger.info(
+             f"Deleted {deleted_count} nodes from collection '{collection_name}'"
+         )
+
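Illustrative incremental updates against an existing collection (not from the package), assuming `store` and a `new_nodes` list shaped like the earlier sketch:

    store.add_nodes("docs", new_nodes)               # embeds and bulk-inserts the new entries
    ids = [n.metadata.node_id for n in new_nodes]    # node_id is populated after bulk_create
    store.delete_nodes("docs", ids)                  # removes exactly those entries again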
+     @transaction.atomic
+     def add_files(self, collection_name: str, files: List[VSFile]) -> None:
+         """
+         Adds file nodes to the specified collection.
+
+         :param collection_name: Name of the collection to update.
+         :param files: List of VSFile objects whose nodes are to be added.
+         """
+         logger.info(f"Adding files to collection: {collection_name}")
+         all_nodes = [node for file in files for node in file.nodes]
+         self.add_nodes(collection_name, all_nodes)
+
+     @transaction.atomic
+     def delete_files(self, collection_name: str, files: List[VSFile]) -> None:
+         """
+         Deletes file nodes from the specified collection.
+
+         :param collection_name: Name of the collection to update.
+         :param files: List of VSFile objects whose nodes are to be deleted.
+         """
+         logger.info(f"Deleting files from collection: {collection_name}")
+         node_ids_to_delete = [
+             node.metadata.node_id
+             for file in files
+             for node in file.nodes
+             if node.metadata.node_id
+         ]
+         if node_ids_to_delete:
+             self.delete_nodes(collection_name, node_ids_to_delete)
+         else:
+             logger.warning("No node IDs found in provided files")
+
+     def list_collections(self) -> List[str]:
+         """
+         Lists all available collections.
+
+         :return: List of collection names.
+         """
+         return list(Collection.objects.values_list("name", flat=True))
+
+     def get_collection_info(self, collection_name: str) -> Dict[str, Any]:
+         """
+         Gets information about a collection.
+
+         :param collection_name: Name of the collection.
+         :return: Dictionary containing collection information.
+         """
+         try:
+             collection = Collection.objects.get(name=collection_name)
+         except Collection.DoesNotExist:
+             raise ValueError(
+                 f"No collection found with name: {collection_name}")
+
+         node_count = NodeEntry.objects.filter(collection=collection).count()
+         return {
+             "name": collection.name,
+             "embedding_dim": collection.embedding_dim,
+             "node_count": node_count,
+             "created_at": collection.created_at,
+             "updated_at": collection.updated_at,
+         }
+
+     @transaction.atomic
+     def delete_collection(self, collection_name: str) -> None:
+         """
+         Deletes a collection and all its nodes.
+
+         :param collection_name: Name of the collection to delete.
+         """
+         try:
+             collection = Collection.objects.get(name=collection_name)
+         except Collection.DoesNotExist:
+             raise ValueError(
+                 f"No collection found with name: {collection_name}")
+
+         node_count = NodeEntry.objects.filter(collection=collection).count()
+         collection.delete()
+         logger.info(
+             f"Deleted collection '{collection_name}' with {node_count} nodes")
+
+     # VectorStore interface methods
+     def add(self, vectors: List[List[float]], metadatas: List[Dict[str, Any]]) -> Any:
+         """
+         Adds vectors with metadata to the default collection.
+         This method implements the VectorStore interface.
+
+         :param vectors: List of embedding vectors to add.
+         :param metadatas: List of metadata dictionaries for each vector.
+         :return: List of node IDs that were created.
+         """
+         if not vectors or not metadatas:
+             logger.warning("Empty vectors or metadatas provided to add()")
+             return []
+
+         if len(vectors) != len(metadatas):
+             raise ValueError(
+                 "Number of vectors must match number of metadatas")
+
+         # Get or create default collection
+         collection_name = metadatas[0].get(
+             "collection_name", "default_collection")
+         collection = self.get_or_create_collection(collection_name)
+
+         # Create nodes from vectors and metadatas
+         node_entries = []
+         for i, (vector, metadata) in enumerate(zip(vectors, metadatas)):
+             content = metadata.get("content", "")
+             source_file_uuid = metadata.get("source_file_uuid", "")
+             position = metadata.get("position", i)
+             custom_metadata = {
+                 k: v
+                 for k, v in metadata.items()
+                 if k not in ["content", "source_file_uuid", "position", "collection_name"]
+             }
+
+             node_entries.append(
+                 NodeEntry(
+                     collection=collection,
+                     content=content,
+                     embedding=vector,
+                     source_file_uuid=source_file_uuid,
+                     position=position,
+                     custom_metadata=custom_metadata,
+                 )
+             )
+
+         created_entries = NodeEntry.objects.bulk_create(node_entries)
+         node_ids = [entry.node_id for entry in created_entries]
+         logger.info(
+             f"Added {len(node_ids)} vectors to collection '{collection_name}'")
+         return node_ids
+
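An illustrative call to the interface-level add() (not from the package); the 3-element vectors are stand-ins for real embedding_dim-sized vectors:

    ids = store.add(
        vectors=[[0.1, 0.2, 0.3], [0.0, 0.5, 0.5]],
        metadatas=[
            {"collection_name": "raw_vectors", "content": "first chunk", "source_file_uuid": "f-1", "position": 0},
            {"collection_name": "raw_vectors", "content": "second chunk", "source_file_uuid": "f-1", "position": 1},
        ],
    )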
+     def query(
+         self, vector: List[float], top_k: int = 5, **kwargs
+     ) -> List[Dict[str, Any]]:
+         """
+         Queries the vector store for similar vectors.
+         This method implements the VectorStore interface.
+
+         :param vector: Query vector.
+         :param top_k: Number of results to return.
+         :param kwargs: Additional parameters (collection_name, distance_type, meta_data_filters).
+         :return: List of dictionaries containing search results.
+         """
+         collection_name = kwargs.get("collection_name", "default_collection")
+         distance_type = kwargs.get("distance_type", "cosine")
+         meta_data_filters = kwargs.get("meta_data_filters")
+
+         try:
+             collection = Collection.objects.get(name=collection_name)
+         except Collection.DoesNotExist:
+             logger.warning(f"Collection '{collection_name}' not found")
+             return []
+
+         # Convert vector to numpy array
+         query_embedding = np.array(vector, dtype="float32")
+
+         # Normalize if using cosine distance
+         if distance_type == "cosine":
+             norm = np.linalg.norm(query_embedding)
+             if norm > 0:
+                 query_embedding = query_embedding / norm
+
+         # Build queryset
+         queryset = NodeEntry.objects.filter(collection=collection)
+
+         # Apply metadata filters and restrict the SQL search to the matching node ids
+         allowed_node_ids = None
+         if meta_data_filters:
+             for key, value in meta_data_filters.items():
+                 queryset = queryset.filter(
+                     **{f"custom_metadata__{key}": value})
+             allowed_node_ids = list(queryset.values_list("node_id", flat=True))
+             if not allowed_node_ids:
+                 logger.info("No nodes match the given metadata filters")
+                 return []
+
+         # Determine distance operator
+         if distance_type == "cosine":
+             distance_operator = "<=>"
+         elif distance_type == "l2":
+             distance_operator = "<->"
+         elif distance_type == "dot":
+             distance_operator = "<#>"
+         else:
+             raise ValueError(f"Unsupported distance type: {distance_type}")
+
+         embedding_str = "[" + ",".join(str(x) for x in query_embedding) + "]"
+
+         id_filter_clause = "AND node_id = ANY(%s)" if allowed_node_ids is not None else ""
+         sql_query = f"""
+             SELECT
+                 node_id,
+                 content,
+                 source_file_uuid,
+                 position,
+                 custom_metadata,
+                 embedding {distance_operator} %s::vector AS distance
+             FROM
+                 {NodeEntry._meta.db_table}
+             WHERE
+                 collection_id = %s
+                 {id_filter_clause}
+             ORDER BY
+                 distance
+             LIMIT
+                 %s
+         """
+         params: List[Any] = [embedding_str, collection.id]
+         if allowed_node_ids is not None:
+             params.append(allowed_node_ids)
+         params.append(top_k)
+
+         # Execute query
+         with connection.cursor() as cursor:
+             cursor.execute(sql_query, params)
+             results = cursor.fetchall()
+             columns = [col[0] for col in cursor.description]
+
+         # Format results
+         formatted_results = []
+         for row in results:
+             result_dict = dict(zip(columns, row))
+             formatted_results.append({
+                 "node_id": result_dict["node_id"],
+                 "content": result_dict["content"],
+                 "source_file_uuid": result_dict["source_file_uuid"],
+                 "position": result_dict["position"],
+                 "metadata": result_dict["custom_metadata"] or {},
+                 "distance": float(result_dict["distance"]),
+             })
+
+         logger.info(f"Query returned {len(formatted_results)} results")
+         return formatted_results
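And an illustrative call to query() with a raw vector (not from the package), reusing predict_embeddings() to produce the input:

    hits = store.query(
        store.predict_embeddings("rotate API keys").tolist(),
        top_k=3,
        collection_name="docs",
        distance_type="cosine",
    )
    for hit in hits:
        print(hit["node_id"], round(hit["distance"], 4), hit["content"][:80])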