agentrun-mem0ai 0.0.11 (agentrun_mem0ai-0.0.11-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (150)
  1. agentrun_mem0/__init__.py +6 -0
  2. agentrun_mem0/client/__init__.py +0 -0
  3. agentrun_mem0/client/main.py +1747 -0
  4. agentrun_mem0/client/project.py +931 -0
  5. agentrun_mem0/client/utils.py +115 -0
  6. agentrun_mem0/configs/__init__.py +0 -0
  7. agentrun_mem0/configs/base.py +90 -0
  8. agentrun_mem0/configs/embeddings/__init__.py +0 -0
  9. agentrun_mem0/configs/embeddings/base.py +110 -0
  10. agentrun_mem0/configs/enums.py +7 -0
  11. agentrun_mem0/configs/llms/__init__.py +0 -0
  12. agentrun_mem0/configs/llms/anthropic.py +56 -0
  13. agentrun_mem0/configs/llms/aws_bedrock.py +192 -0
  14. agentrun_mem0/configs/llms/azure.py +57 -0
  15. agentrun_mem0/configs/llms/base.py +62 -0
  16. agentrun_mem0/configs/llms/deepseek.py +56 -0
  17. agentrun_mem0/configs/llms/lmstudio.py +59 -0
  18. agentrun_mem0/configs/llms/ollama.py +56 -0
  19. agentrun_mem0/configs/llms/openai.py +79 -0
  20. agentrun_mem0/configs/llms/vllm.py +56 -0
  21. agentrun_mem0/configs/prompts.py +459 -0
  22. agentrun_mem0/configs/rerankers/__init__.py +0 -0
  23. agentrun_mem0/configs/rerankers/base.py +17 -0
  24. agentrun_mem0/configs/rerankers/cohere.py +15 -0
  25. agentrun_mem0/configs/rerankers/config.py +12 -0
  26. agentrun_mem0/configs/rerankers/huggingface.py +17 -0
  27. agentrun_mem0/configs/rerankers/llm.py +48 -0
  28. agentrun_mem0/configs/rerankers/sentence_transformer.py +16 -0
  29. agentrun_mem0/configs/rerankers/zero_entropy.py +28 -0
  30. agentrun_mem0/configs/vector_stores/__init__.py +0 -0
  31. agentrun_mem0/configs/vector_stores/alibabacloud_mysql.py +64 -0
  32. agentrun_mem0/configs/vector_stores/aliyun_tablestore.py +32 -0
  33. agentrun_mem0/configs/vector_stores/azure_ai_search.py +57 -0
  34. agentrun_mem0/configs/vector_stores/azure_mysql.py +84 -0
  35. agentrun_mem0/configs/vector_stores/baidu.py +27 -0
  36. agentrun_mem0/configs/vector_stores/chroma.py +58 -0
  37. agentrun_mem0/configs/vector_stores/databricks.py +61 -0
  38. agentrun_mem0/configs/vector_stores/elasticsearch.py +65 -0
  39. agentrun_mem0/configs/vector_stores/faiss.py +37 -0
  40. agentrun_mem0/configs/vector_stores/langchain.py +30 -0
  41. agentrun_mem0/configs/vector_stores/milvus.py +42 -0
  42. agentrun_mem0/configs/vector_stores/mongodb.py +25 -0
  43. agentrun_mem0/configs/vector_stores/neptune.py +27 -0
  44. agentrun_mem0/configs/vector_stores/opensearch.py +41 -0
  45. agentrun_mem0/configs/vector_stores/pgvector.py +52 -0
  46. agentrun_mem0/configs/vector_stores/pinecone.py +55 -0
  47. agentrun_mem0/configs/vector_stores/qdrant.py +47 -0
  48. agentrun_mem0/configs/vector_stores/redis.py +24 -0
  49. agentrun_mem0/configs/vector_stores/s3_vectors.py +28 -0
  50. agentrun_mem0/configs/vector_stores/supabase.py +44 -0
  51. agentrun_mem0/configs/vector_stores/upstash_vector.py +34 -0
  52. agentrun_mem0/configs/vector_stores/valkey.py +15 -0
  53. agentrun_mem0/configs/vector_stores/vertex_ai_vector_search.py +28 -0
  54. agentrun_mem0/configs/vector_stores/weaviate.py +41 -0
  55. agentrun_mem0/embeddings/__init__.py +0 -0
  56. agentrun_mem0/embeddings/aws_bedrock.py +100 -0
  57. agentrun_mem0/embeddings/azure_openai.py +55 -0
  58. agentrun_mem0/embeddings/base.py +31 -0
  59. agentrun_mem0/embeddings/configs.py +30 -0
  60. agentrun_mem0/embeddings/gemini.py +39 -0
  61. agentrun_mem0/embeddings/huggingface.py +44 -0
  62. agentrun_mem0/embeddings/langchain.py +35 -0
  63. agentrun_mem0/embeddings/lmstudio.py +29 -0
  64. agentrun_mem0/embeddings/mock.py +11 -0
  65. agentrun_mem0/embeddings/ollama.py +53 -0
  66. agentrun_mem0/embeddings/openai.py +49 -0
  67. agentrun_mem0/embeddings/together.py +31 -0
  68. agentrun_mem0/embeddings/vertexai.py +64 -0
  69. agentrun_mem0/exceptions.py +503 -0
  70. agentrun_mem0/graphs/__init__.py +0 -0
  71. agentrun_mem0/graphs/configs.py +105 -0
  72. agentrun_mem0/graphs/neptune/__init__.py +0 -0
  73. agentrun_mem0/graphs/neptune/base.py +497 -0
  74. agentrun_mem0/graphs/neptune/neptunedb.py +511 -0
  75. agentrun_mem0/graphs/neptune/neptunegraph.py +474 -0
  76. agentrun_mem0/graphs/tools.py +371 -0
  77. agentrun_mem0/graphs/utils.py +97 -0
  78. agentrun_mem0/llms/__init__.py +0 -0
  79. agentrun_mem0/llms/anthropic.py +87 -0
  80. agentrun_mem0/llms/aws_bedrock.py +665 -0
  81. agentrun_mem0/llms/azure_openai.py +141 -0
  82. agentrun_mem0/llms/azure_openai_structured.py +91 -0
  83. agentrun_mem0/llms/base.py +131 -0
  84. agentrun_mem0/llms/configs.py +34 -0
  85. agentrun_mem0/llms/deepseek.py +107 -0
  86. agentrun_mem0/llms/gemini.py +201 -0
  87. agentrun_mem0/llms/groq.py +88 -0
  88. agentrun_mem0/llms/langchain.py +94 -0
  89. agentrun_mem0/llms/litellm.py +87 -0
  90. agentrun_mem0/llms/lmstudio.py +114 -0
  91. agentrun_mem0/llms/ollama.py +117 -0
  92. agentrun_mem0/llms/openai.py +147 -0
  93. agentrun_mem0/llms/openai_structured.py +52 -0
  94. agentrun_mem0/llms/sarvam.py +89 -0
  95. agentrun_mem0/llms/together.py +88 -0
  96. agentrun_mem0/llms/vllm.py +107 -0
  97. agentrun_mem0/llms/xai.py +52 -0
  98. agentrun_mem0/memory/__init__.py +0 -0
  99. agentrun_mem0/memory/base.py +63 -0
  100. agentrun_mem0/memory/graph_memory.py +698 -0
  101. agentrun_mem0/memory/kuzu_memory.py +713 -0
  102. agentrun_mem0/memory/main.py +2229 -0
  103. agentrun_mem0/memory/memgraph_memory.py +689 -0
  104. agentrun_mem0/memory/setup.py +56 -0
  105. agentrun_mem0/memory/storage.py +218 -0
  106. agentrun_mem0/memory/telemetry.py +90 -0
  107. agentrun_mem0/memory/utils.py +208 -0
  108. agentrun_mem0/proxy/__init__.py +0 -0
  109. agentrun_mem0/proxy/main.py +189 -0
  110. agentrun_mem0/reranker/__init__.py +9 -0
  111. agentrun_mem0/reranker/base.py +20 -0
  112. agentrun_mem0/reranker/cohere_reranker.py +85 -0
  113. agentrun_mem0/reranker/huggingface_reranker.py +147 -0
  114. agentrun_mem0/reranker/llm_reranker.py +142 -0
  115. agentrun_mem0/reranker/sentence_transformer_reranker.py +107 -0
  116. agentrun_mem0/reranker/zero_entropy_reranker.py +96 -0
  117. agentrun_mem0/utils/factory.py +283 -0
  118. agentrun_mem0/utils/gcp_auth.py +167 -0
  119. agentrun_mem0/vector_stores/__init__.py +0 -0
  120. agentrun_mem0/vector_stores/alibabacloud_mysql.py +547 -0
  121. agentrun_mem0/vector_stores/aliyun_tablestore.py +252 -0
  122. agentrun_mem0/vector_stores/azure_ai_search.py +396 -0
  123. agentrun_mem0/vector_stores/azure_mysql.py +463 -0
  124. agentrun_mem0/vector_stores/baidu.py +368 -0
  125. agentrun_mem0/vector_stores/base.py +58 -0
  126. agentrun_mem0/vector_stores/chroma.py +332 -0
  127. agentrun_mem0/vector_stores/configs.py +67 -0
  128. agentrun_mem0/vector_stores/databricks.py +761 -0
  129. agentrun_mem0/vector_stores/elasticsearch.py +237 -0
  130. agentrun_mem0/vector_stores/faiss.py +479 -0
  131. agentrun_mem0/vector_stores/langchain.py +180 -0
  132. agentrun_mem0/vector_stores/milvus.py +250 -0
  133. agentrun_mem0/vector_stores/mongodb.py +310 -0
  134. agentrun_mem0/vector_stores/neptune_analytics.py +467 -0
  135. agentrun_mem0/vector_stores/opensearch.py +292 -0
  136. agentrun_mem0/vector_stores/pgvector.py +404 -0
  137. agentrun_mem0/vector_stores/pinecone.py +382 -0
  138. agentrun_mem0/vector_stores/qdrant.py +270 -0
  139. agentrun_mem0/vector_stores/redis.py +295 -0
  140. agentrun_mem0/vector_stores/s3_vectors.py +176 -0
  141. agentrun_mem0/vector_stores/supabase.py +237 -0
  142. agentrun_mem0/vector_stores/upstash_vector.py +293 -0
  143. agentrun_mem0/vector_stores/valkey.py +824 -0
  144. agentrun_mem0/vector_stores/vertex_ai_vector_search.py +635 -0
  145. agentrun_mem0/vector_stores/weaviate.py +343 -0
  146. agentrun_mem0ai-0.0.11.data/data/README.md +205 -0
  147. agentrun_mem0ai-0.0.11.dist-info/METADATA +277 -0
  148. agentrun_mem0ai-0.0.11.dist-info/RECORD +150 -0
  149. agentrun_mem0ai-0.0.11.dist-info/WHEEL +4 -0
  150. agentrun_mem0ai-0.0.11.dist-info/licenses/LICENSE +201 -0
--- /dev/null
+++ b/agentrun_mem0/vector_stores/opensearch.py
@@ -0,0 +1,292 @@
+import logging
+import time
+from typing import Any, Dict, List, Optional
+
+try:
+    from opensearchpy import OpenSearch, RequestsHttpConnection
+except ImportError:
+    raise ImportError("OpenSearch requires extra dependencies. Install with `pip install opensearch-py`") from None
+
+from pydantic import BaseModel
+
+from agentrun_mem0.configs.vector_stores.opensearch import OpenSearchConfig
+from agentrun_mem0.vector_stores.base import VectorStoreBase
+
+logger = logging.getLogger(__name__)
+
+
+class OutputData(BaseModel):
+    id: str
+    score: float
+    payload: Dict
+
+
+class OpenSearchDB(VectorStoreBase):
+    def __init__(self, **kwargs):
+        config = OpenSearchConfig(**kwargs)
+
+        # Initialize the OpenSearch client. An explicit http_auth object wins;
+        # otherwise fall back to basic auth when both user and password are set.
+        self.client = OpenSearch(
+            hosts=[{"host": config.host, "port": config.port or 9200}],
+            http_auth=config.http_auth
+            if config.http_auth
+            else ((config.user, config.password) if (config.user and config.password) else None),
+            use_ssl=config.use_ssl,
+            verify_certs=config.verify_certs,
+            connection_class=RequestsHttpConnection,
+            pool_maxsize=20,
+        )
+
+        self.collection_name = config.collection_name
+        self.embedding_model_dims = config.embedding_model_dims
+        self.create_col(self.collection_name, self.embedding_model_dims)
+
+    def create_index(self) -> None:
+        """Create an OpenSearch index with the proper mappings if it doesn't exist."""
+        index_settings = {
+            "settings": {
+                "index": {"number_of_replicas": 1, "number_of_shards": 5, "refresh_interval": "10s", "knn": True}
+            },
+            "mappings": {
+                "properties": {
+                    "text": {"type": "text"},
+                    "vector_field": {
+                        "type": "knn_vector",
+                        "dimension": self.embedding_model_dims,
+                        "method": {"engine": "nmslib", "name": "hnsw", "space_type": "cosinesimil"},
+                    },
+                    "metadata": {"type": "object", "properties": {"user_id": {"type": "keyword"}}},
+                }
+            },
+        }
+
+        if not self.client.indices.exists(index=self.collection_name):
+            self.client.indices.create(index=self.collection_name, body=index_settings)
+            logger.info(f"Created index {self.collection_name}")
+        else:
+            logger.info(f"Index {self.collection_name} already exists")
+
+    def create_col(self, name: str, vector_size: int) -> None:
+        """Create a new collection (an index in OpenSearch)."""
+        index_settings = {
+            "settings": {"index.knn": True},
+            "mappings": {
+                "properties": {
+                    "vector_field": {
+                        "type": "knn_vector",
+                        "dimension": vector_size,
+                        "method": {"engine": "nmslib", "name": "hnsw", "space_type": "cosinesimil"},
+                    },
+                    "payload": {"type": "object"},
+                    "id": {"type": "keyword"},
+                }
+            },
+        }
+
+        if not self.client.indices.exists(index=name):
+            logger.warning(f"Creating index {name}, it might take 1-2 minutes...")
+            self.client.indices.create(index=name, body=index_settings)
+
+        # Wait for the index to be ready: poll with a simple search,
+        # up to 180 attempts spaced 0.5 s apart (about 90 seconds).
+        max_retries = 180
+        retry_count = 0
+        while retry_count < max_retries:
+            try:
+                self.client.search(index=name, body={"query": {"match_all": {}}})
+                time.sleep(1)
+                logger.info(f"Index {name} is ready")
+                return
+            except Exception:
+                retry_count += 1
+                if retry_count == max_retries:
+                    raise TimeoutError(f"Index {name} creation timed out after {max_retries} attempts")
+                time.sleep(0.5)
+
+    def insert(
+        self, vectors: List[List[float]], payloads: Optional[List[Dict]] = None, ids: Optional[List[str]] = None
+    ) -> List[OutputData]:
+        """Insert vectors into the index."""
+        if not ids:
+            ids = [str(i) for i in range(len(vectors))]
+
+        if payloads is None:
+            payloads = [{} for _ in range(len(vectors))]
+
+        results = []
+        for i, (vec, id_) in enumerate(zip(vectors, ids)):
+            body = {
+                "vector_field": vec,
+                "payload": payloads[i],
+                "id": id_,
+            }
+            try:
+                self.client.index(index=self.collection_name, body=body)
+                # Force a refresh so documents are immediately searchable (useful for tests)
+                self.client.indices.refresh(index=self.collection_name)
+
+                results.append(
+                    OutputData(
+                        id=id_,
+                        score=1.0,  # Inserts have no relevance score; use a constant
+                        payload=payloads[i],
+                    )
+                )
+            except Exception as e:
+                logger.error(f"Error inserting vector {id_}: {e}")
+                raise
+
+        return results
+
+    def search(
+        self, query: str, vectors: List[float], limit: int = 5, filters: Optional[Dict] = None
+    ) -> List[OutputData]:
+        """Search for similar vectors using OpenSearch k-NN search with optional filters."""
+        # Base k-NN query; over-fetch (limit * 2) so post-filtering still returns enough hits
+        knn_query = {
+            "knn": {
+                "vector_field": {
+                    "vector": vectors,
+                    "k": limit * 2,
+                }
+            }
+        }
+
+        query_body = {"size": limit * 2, "query": None}
+
+        # Prepare filter conditions if applicable
+        filter_clauses = []
+        if filters:
+            for key in ["user_id", "run_id", "agent_id"]:
+                value = filters.get(key)
+                if value:
+                    filter_clauses.append({"term": {f"payload.{key}.keyword": value}})
+
+        # Combine the knn query with filters if needed
+        if filter_clauses:
+            query_body["query"] = {"bool": {"must": knn_query, "filter": filter_clauses}}
+        else:
+            query_body["query"] = knn_query
+
+        try:
+            response = self.client.search(index=self.collection_name, body=query_body)
+            hits = response["hits"]["hits"]
+            results = [
+                OutputData(id=hit["_source"].get("id"), score=hit["_score"], payload=hit["_source"].get("payload", {}))
+                for hit in hits[:limit]  # Ensure we don't exceed limit
+            ]
+            return results
+        except Exception as e:
+            logger.error(f"Error during search: {e}")
+            return []
+
+    def delete(self, vector_id: str) -> None:
+        """Delete a vector by custom ID."""
+        # First, find the document by its custom ID field
+        search_query = {"query": {"term": {"id": vector_id}}}
+        response = self.client.search(index=self.collection_name, body=search_query)
+        hits = response.get("hits", {}).get("hits", [])
+
+        if not hits:
+            return
+
+        opensearch_id = hits[0]["_id"]
+
+        # Delete using the actual OpenSearch document ID
+        self.client.delete(index=self.collection_name, id=opensearch_id)
+
+    def update(self, vector_id: str, vector: Optional[List[float]] = None, payload: Optional[Dict] = None) -> None:
+        """Update a vector and its payload using the custom 'id' field."""
+        # First, find the document by its custom ID field
+        search_query = {"query": {"term": {"id": vector_id}}}
+        response = self.client.search(index=self.collection_name, body=search_query)
+        hits = response.get("hits", {}).get("hits", [])
+
+        if not hits:
+            return
+
+        opensearch_id = hits[0]["_id"]  # The actual document ID in OpenSearch
+
+        # Prepare updated fields
+        doc = {}
+        if vector is not None:
+            doc["vector_field"] = vector
+        if payload is not None:
+            doc["payload"] = payload
+
+        if doc:
+            try:
+                self.client.update(index=self.collection_name, id=opensearch_id, body={"doc": doc})
+            except Exception as e:
+                logger.error(f"Error updating vector {vector_id}: {e}")
+
+    def get(self, vector_id: str) -> Optional[OutputData]:
+        """Retrieve a vector by ID."""
+        try:
+            search_query = {"query": {"term": {"id": vector_id}}}
+            response = self.client.search(index=self.collection_name, body=search_query)
+            hits = response["hits"]["hits"]
+
+            if not hits:
+                return None
+
+            return OutputData(id=hits[0]["_source"].get("id"), score=1.0, payload=hits[0]["_source"].get("payload", {}))
+        except Exception as e:
+            logger.error(f"Error retrieving vector {vector_id}: {e}")
+            return None
+
+    def list_cols(self) -> List[str]:
+        """List all collections (indices)."""
+        return list(self.client.indices.get_alias().keys())
+
+    def delete_col(self) -> None:
+        """Delete a collection (index)."""
+        self.client.indices.delete(index=self.collection_name)
+
+    def col_info(self, name: str) -> Any:
+        """Get information about a collection (index)."""
+        return self.client.indices.get(index=name)
+
+    def list(self, filters: Optional[Dict] = None, limit: Optional[int] = None) -> List[List[OutputData]]:
+        """List all memories with optional filters."""
+        try:
+            query: Dict = {"query": {"match_all": {}}}
+
+            filter_clauses = []
+            if filters:
+                for key in ["user_id", "run_id", "agent_id"]:
+                    value = filters.get(key)
+                    if value:
+                        filter_clauses.append({"term": {f"payload.{key}.keyword": value}})
+
+            if filter_clauses:
+                query["query"] = {"bool": {"filter": filter_clauses}}
+
+            if limit:
+                query["size"] = limit
+
+            response = self.client.search(index=self.collection_name, body=query)
+            hits = response["hits"]["hits"]
+
+            results = [
+                OutputData(id=hit["_source"].get("id"), score=1.0, payload=hit["_source"].get("payload", {}))
+                for hit in hits
+            ]
+            # Wrap in a single-element list: VectorStoreBase expects a list of result lists
+            return [results]
+        except Exception as e:
+            logger.error(f"Error listing vectors: {e}")
+            return []
+
+    def reset(self):
+        """Reset the index by deleting and recreating it."""
+        logger.warning(f"Resetting index {self.collection_name}...")
+        self.delete_col()
+        self.create_col(self.collection_name, self.embedding_model_dims)
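
For orientation, a minimal usage sketch of the OpenSearchDB store added above. This is not part of the diff: the constructor keywords (host, port, user, password, use_ssl, verify_certs, collection_name, embedding_model_dims) are assumed from the config attributes referenced in __init__, and all values are placeholders.

    # Hedged sketch, not part of the package. Assumes OpenSearchConfig accepts
    # the fields read in OpenSearchDB.__init__; host and credentials are placeholders.
    from agentrun_mem0.vector_stores.opensearch import OpenSearchDB

    store = OpenSearchDB(
        host="localhost",
        port=9200,
        user="admin",
        password="admin",
        use_ssl=False,
        verify_certs=False,
        collection_name="mem0",
        embedding_model_dims=1536,
    )

    # Insert two vectors with payloads, then run a filtered k-NN search
    store.insert(
        vectors=[[0.1] * 1536, [0.2] * 1536],
        payloads=[{"user_id": "alice", "data": "first"}, {"user_id": "alice", "data": "second"}],
        ids=["m1", "m2"],
    )
    for hit in store.search(query="", vectors=[0.1] * 1536, limit=5, filters={"user_id": "alice"}):
        print(hit.id, hit.score, hit.payload)

Note the design choice visible in search(): it requests limit * 2 candidates from the k-NN engine so that enough hits survive the term filters, then truncates to limit.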
--- /dev/null
+++ b/agentrun_mem0/vector_stores/pgvector.py
@@ -0,0 +1,404 @@
+import json
+import logging
+import re
+from contextlib import contextmanager
+from typing import Any, List, Optional
+
+from pydantic import BaseModel
+
+# Try to import psycopg (psycopg3) first, then fall back to psycopg2
+try:
+    from psycopg.types.json import Json
+    from psycopg_pool import ConnectionPool
+    PSYCOPG_VERSION = 3
+    logger = logging.getLogger(__name__)
+    logger.info("Using psycopg (psycopg3) with ConnectionPool for PostgreSQL connections")
+except ImportError:
+    try:
+        from psycopg2.extras import Json, execute_values
+        from psycopg2.pool import ThreadedConnectionPool as ConnectionPool
+        PSYCOPG_VERSION = 2
+        logger = logging.getLogger(__name__)
+        logger.info("Using psycopg2 with ThreadedConnectionPool for PostgreSQL connections")
+    except ImportError:
+        raise ImportError(
+            "Neither 'psycopg' nor 'psycopg2' library is available. "
+            "Please install one of them using 'pip install psycopg[pool]' or 'pip install psycopg2'"
+        )
+
+from agentrun_mem0.vector_stores.base import VectorStoreBase
+
+logger = logging.getLogger(__name__)
+
+
+class OutputData(BaseModel):
+    id: Optional[str]
+    score: Optional[float]
+    payload: Optional[dict]
+
+
+class PGVector(VectorStoreBase):
+    def __init__(
+        self,
+        dbname,
+        collection_name,
+        embedding_model_dims,
+        user,
+        password,
+        host,
+        port,
+        diskann,
+        hnsw,
+        minconn=1,
+        maxconn=5,
+        sslmode=None,
+        connection_string=None,
+        connection_pool=None,
+    ):
+        """
+        Initialize the PGVector database.
+
+        Args:
+            dbname (str): Database name
+            collection_name (str): Collection name
+            embedding_model_dims (int): Dimension of the embedding vector
+            user (str): Database user
+            password (str): Database password
+            host (str, optional): Database host
+            port (int, optional): Database port
+            diskann (bool, optional): Use DiskANN for faster search
+            hnsw (bool, optional): Use HNSW for faster search
+            minconn (int): Minimum number of connections to keep in the connection pool
+            maxconn (int): Maximum number of connections allowed in the connection pool
+            sslmode (str, optional): SSL mode for the PostgreSQL connection (e.g., 'require', 'prefer', 'disable')
+            connection_string (str, optional): PostgreSQL connection string (overrides individual connection parameters)
+            connection_pool (Any, optional): Existing psycopg/psycopg2 connection pool (overrides connection string and individual parameters)
+        """
+        self.collection_name = collection_name
+        self.use_diskann = diskann
+        self.use_hnsw = hnsw
+        self.embedding_model_dims = embedding_model_dims
+        self.connection_pool = None
+
+        # Connection setup priority: connection_pool > connection_string > individual parameters
+        if connection_pool is not None:
+            # Use the provided connection pool
+            self.connection_pool = connection_pool
+        elif connection_string:
+            if sslmode:
+                if 'sslmode=' in connection_string:
+                    # Replace the existing sslmode in a key-value DSN
+                    connection_string = re.sub(r'sslmode=[^ ]*', f'sslmode={sslmode}', connection_string)
+                else:
+                    # Append sslmode to a key-value DSN
+                    connection_string = f"{connection_string} sslmode={sslmode}"
+        else:
+            connection_string = f"postgresql://{user}:{password}@{host}:{port}/{dbname}"
+            if sslmode:
+                # URI-style connection strings take sslmode as a query parameter
+                connection_string = f"{connection_string}?sslmode={sslmode}"
+
+        if self.connection_pool is None:
+            if PSYCOPG_VERSION == 3:
+                # psycopg3 ConnectionPool
+                self.connection_pool = ConnectionPool(
+                    conninfo=connection_string, min_size=minconn, max_size=maxconn, open=True
+                )
+            else:
+                # psycopg2 ThreadedConnectionPool
+                self.connection_pool = ConnectionPool(minconn=minconn, maxconn=maxconn, dsn=connection_string)
+
+        collections = self.list_cols()
+        if collection_name not in collections:
+            self.create_col()
+
+    @contextmanager
+    def _get_cursor(self, commit: bool = False):
+        """
+        Unified context manager to get a cursor from the appropriate pool.
+        Commits or rolls back depending on whether an exception occurred,
+        and returns the connection to the pool.
+        """
+        if PSYCOPG_VERSION == 3:
+            # psycopg3 manages pool checkout/return via a context manager
+            with self.connection_pool.connection() as conn:
+                with conn.cursor() as cur:
+                    try:
+                        yield cur
+                        if commit:
+                            conn.commit()
+                    except Exception:
+                        conn.rollback()
+                        logger.error("Error in cursor context (psycopg3)", exc_info=True)
+                        raise
+        else:
+            # psycopg2 requires manual getconn/putconn
+            conn = self.connection_pool.getconn()
+            cur = conn.cursor()
+            try:
+                yield cur
+                if commit:
+                    conn.commit()
+            except Exception as exc:
+                conn.rollback()
+                logger.error(f"Error occurred: {exc}")
+                raise exc
+            finally:
+                cur.close()
+                self.connection_pool.putconn(conn)
+
+    def create_col(self) -> None:
+        """
+        Create a new collection (a table in PostgreSQL).
+        Also initializes a vector search index if one was requested.
+        """
+        with self._get_cursor(commit=True) as cur:
+            cur.execute("CREATE EXTENSION IF NOT EXISTS vector")
+            cur.execute(
+                f"""
+                CREATE TABLE IF NOT EXISTS {self.collection_name} (
+                    id UUID PRIMARY KEY,
+                    vector vector({self.embedding_model_dims}),
+                    payload JSONB
+                );
+                """
+            )
+            if self.use_diskann and self.embedding_model_dims < 2000:
+                cur.execute("SELECT * FROM pg_extension WHERE extname = 'vectorscale'")
+                if cur.fetchone():
+                    # Create a DiskANN index if the vectorscale extension is installed
+                    cur.execute(
+                        f"""
+                        CREATE INDEX IF NOT EXISTS {self.collection_name}_diskann_idx
+                        ON {self.collection_name}
+                        USING diskann (vector);
+                        """
+                    )
+            elif self.use_hnsw:
+                cur.execute(
+                    f"""
+                    CREATE INDEX IF NOT EXISTS {self.collection_name}_hnsw_idx
+                    ON {self.collection_name}
+                    USING hnsw (vector vector_cosine_ops)
+                    """
+                )
+
+    def insert(self, vectors: list[list[float]], payloads=None, ids=None) -> None:
+        """Insert vectors and their JSON payloads into the collection."""
+        logger.info(f"Inserting {len(vectors)} vectors into collection {self.collection_name}")
+        json_payloads = [json.dumps(payload) for payload in payloads]
+
+        data = [(id, vector, payload) for id, vector, payload in zip(ids, vectors, json_payloads)]
+        if PSYCOPG_VERSION == 3:
+            with self._get_cursor(commit=True) as cur:
+                cur.executemany(
+                    f"INSERT INTO {self.collection_name} (id, vector, payload) VALUES (%s, %s, %s)",
+                    data,
+                )
+        else:
+            with self._get_cursor(commit=True) as cur:
+                execute_values(
+                    cur,
+                    f"INSERT INTO {self.collection_name} (id, vector, payload) VALUES %s",
+                    data,
+                )
+
+    def search(
+        self,
+        query: str,
+        vectors: list[float],
+        limit: Optional[int] = 5,
+        filters: Optional[dict] = None,
+    ) -> List[OutputData]:
+        """
+        Search for similar vectors.
+
+        Args:
+            query (str): Query text.
+            vectors (List[float]): Query vector.
+            limit (int, optional): Number of results to return. Defaults to 5.
+            filters (Dict, optional): Filters to apply to the search. Defaults to None.
+
+        Returns:
+            list: Search results.
+        """
+        filter_conditions = []
+        filter_params = []
+
+        if filters:
+            for k, v in filters.items():
+                filter_conditions.append("payload->>%s = %s")
+                filter_params.extend([k, str(v)])
+
+        filter_clause = "WHERE " + " AND ".join(filter_conditions) if filter_conditions else ""
+
+        with self._get_cursor() as cur:
+            # <=> is pgvector's cosine distance operator; ascending order gives nearest neighbors
+            cur.execute(
+                f"""
+                SELECT id, vector <=> %s::vector AS distance, payload
+                FROM {self.collection_name}
+                {filter_clause}
+                ORDER BY distance
+                LIMIT %s
+                """,
+                (vectors, *filter_params, limit),
+            )
+            results = cur.fetchall()
+            return [OutputData(id=str(r[0]), score=float(r[1]), payload=r[2]) for r in results]
+
+    def delete(self, vector_id: str) -> None:
+        """
+        Delete a vector by ID.
+
+        Args:
+            vector_id (str): ID of the vector to delete.
+        """
+        with self._get_cursor(commit=True) as cur:
+            cur.execute(f"DELETE FROM {self.collection_name} WHERE id = %s", (vector_id,))
+
+    def update(
+        self,
+        vector_id: str,
+        vector: Optional[list[float]] = None,
+        payload: Optional[dict] = None,
+    ) -> None:
+        """
+        Update a vector and its payload.
+
+        Args:
+            vector_id (str): ID of the vector to update.
+            vector (List[float], optional): Updated vector.
+            payload (Dict, optional): Updated payload.
+        """
+        with self._get_cursor(commit=True) as cur:
+            if vector:
+                cur.execute(
+                    f"UPDATE {self.collection_name} SET vector = %s WHERE id = %s",
+                    (vector, vector_id),
+                )
+            if payload:
+                # Json is the version-appropriate wrapper: psycopg.types.json.Json on
+                # psycopg3, psycopg2.extras.Json on psycopg2, so one code path suffices
+                cur.execute(
+                    f"UPDATE {self.collection_name} SET payload = %s WHERE id = %s",
+                    (Json(payload), vector_id),
+                )
+
+    def get(self, vector_id: str) -> OutputData:
+        """
+        Retrieve a vector by ID.
+
+        Args:
+            vector_id (str): ID of the vector to retrieve.
+
+        Returns:
+            OutputData: Retrieved vector, or None if it does not exist.
+        """
+        with self._get_cursor() as cur:
+            cur.execute(
+                f"SELECT id, vector, payload FROM {self.collection_name} WHERE id = %s",
+                (vector_id,),
+            )
+            result = cur.fetchone()
+            if not result:
+                return None
+            return OutputData(id=str(result[0]), score=None, payload=result[2])
+
+    def list_cols(self) -> List[str]:
+        """
+        List all collections.
+
+        Returns:
+            List[str]: List of collection names.
+        """
+        with self._get_cursor() as cur:
+            cur.execute("SELECT table_name FROM information_schema.tables WHERE table_schema = 'public'")
+            return [row[0] for row in cur.fetchall()]
+
+    def delete_col(self) -> None:
+        """Delete a collection."""
+        with self._get_cursor(commit=True) as cur:
+            cur.execute(f"DROP TABLE IF EXISTS {self.collection_name}")
+
+    def col_info(self) -> dict[str, Any]:
+        """
+        Get information about a collection.
+
+        Returns:
+            Dict[str, Any]: Collection information.
+        """
+        with self._get_cursor() as cur:
+            cur.execute(
+                f"""
+                SELECT
+                    table_name,
+                    (SELECT COUNT(*) FROM {self.collection_name}) as row_count,
+                    (SELECT pg_size_pretty(pg_total_relation_size('{self.collection_name}'))) as total_size
+                FROM information_schema.tables
+                WHERE table_schema = 'public' AND table_name = %s
+                """,
+                (self.collection_name,),
+            )
+            result = cur.fetchone()
+            return {"name": result[0], "count": result[1], "size": result[2]}
+
+    def list(self, filters: Optional[dict] = None, limit: Optional[int] = 100) -> List[List[OutputData]]:
+        """
+        List all vectors in a collection.
+
+        Args:
+            filters (Dict, optional): Filters to apply to the list.
+            limit (int, optional): Number of vectors to return. Defaults to 100.
+
+        Returns:
+            List[List[OutputData]]: A single-element list containing the result list,
+            matching the VectorStoreBase interface.
+        """
+        filter_conditions = []
+        filter_params = []
+
+        if filters:
+            for k, v in filters.items():
+                filter_conditions.append("payload->>%s = %s")
+                filter_params.extend([k, str(v)])
+
+        filter_clause = "WHERE " + " AND ".join(filter_conditions) if filter_conditions else ""
+
+        query = f"""
+            SELECT id, vector, payload
+            FROM {self.collection_name}
+            {filter_clause}
+            LIMIT %s
+        """
+
+        with self._get_cursor() as cur:
+            cur.execute(query, (*filter_params, limit))
+            results = cur.fetchall()
+            return [[OutputData(id=str(r[0]), score=None, payload=r[2]) for r in results]]
+
+    def __del__(self) -> None:
+        """Close the database connection pool when the object is deleted."""
+        try:
+            if PSYCOPG_VERSION == 3:
+                self.connection_pool.close()
+            else:
+                self.connection_pool.closeall()
+        except Exception:
+            pass
+
+    def reset(self) -> None:
+        """Reset the collection by deleting and recreating it."""
+        logger.warning(f"Resetting index {self.collection_name}...")
+        self.delete_col()
+        self.create_col()
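
And a matching usage sketch for the PGVector store above. Again, this is not part of the diff: connection values are placeholders, and the keyword names simply mirror the __init__ signature shown in the hunk.

    # Hedged sketch, not part of the package. Assumes a local PostgreSQL with the
    # pgvector extension available; credentials and dimensions are placeholders.
    import uuid

    from agentrun_mem0.vector_stores.pgvector import PGVector

    store = PGVector(
        dbname="mem0",
        collection_name="memories",
        embedding_model_dims=1536,
        user="postgres",
        password="postgres",
        host="localhost",
        port=5432,
        diskann=False,
        hnsw=True,
    )

    vec_id = str(uuid.uuid4())  # the table's id column is UUID
    store.insert(vectors=[[0.1] * 1536], payloads=[{"user_id": "alice"}], ids=[vec_id])

    # Scores are cosine distances from the <=> operator, so smaller is closer
    for hit in store.search(query="", vectors=[0.1] * 1536, limit=5, filters={"user_id": "alice"}):
        print(hit.id, hit.score, hit.payload)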