mem0ai-azure-mysql 0.1.115 (py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mem0/__init__.py +6 -0
- mem0/client/__init__.py +0 -0
- mem0/client/main.py +1535 -0
- mem0/client/project.py +860 -0
- mem0/client/utils.py +29 -0
- mem0/configs/__init__.py +0 -0
- mem0/configs/base.py +90 -0
- mem0/configs/dbs/__init__.py +4 -0
- mem0/configs/dbs/base.py +41 -0
- mem0/configs/dbs/mysql.py +25 -0
- mem0/configs/embeddings/__init__.py +0 -0
- mem0/configs/embeddings/base.py +108 -0
- mem0/configs/enums.py +7 -0
- mem0/configs/llms/__init__.py +0 -0
- mem0/configs/llms/base.py +152 -0
- mem0/configs/prompts.py +333 -0
- mem0/configs/vector_stores/__init__.py +0 -0
- mem0/configs/vector_stores/azure_ai_search.py +59 -0
- mem0/configs/vector_stores/baidu.py +29 -0
- mem0/configs/vector_stores/chroma.py +40 -0
- mem0/configs/vector_stores/elasticsearch.py +47 -0
- mem0/configs/vector_stores/faiss.py +39 -0
- mem0/configs/vector_stores/langchain.py +32 -0
- mem0/configs/vector_stores/milvus.py +43 -0
- mem0/configs/vector_stores/mongodb.py +25 -0
- mem0/configs/vector_stores/opensearch.py +41 -0
- mem0/configs/vector_stores/pgvector.py +37 -0
- mem0/configs/vector_stores/pinecone.py +56 -0
- mem0/configs/vector_stores/qdrant.py +49 -0
- mem0/configs/vector_stores/redis.py +26 -0
- mem0/configs/vector_stores/supabase.py +44 -0
- mem0/configs/vector_stores/upstash_vector.py +36 -0
- mem0/configs/vector_stores/vertex_ai_vector_search.py +27 -0
- mem0/configs/vector_stores/weaviate.py +43 -0
- mem0/dbs/__init__.py +4 -0
- mem0/dbs/base.py +68 -0
- mem0/dbs/configs.py +21 -0
- mem0/dbs/mysql.py +321 -0
- mem0/embeddings/__init__.py +0 -0
- mem0/embeddings/aws_bedrock.py +100 -0
- mem0/embeddings/azure_openai.py +43 -0
- mem0/embeddings/base.py +31 -0
- mem0/embeddings/configs.py +30 -0
- mem0/embeddings/gemini.py +39 -0
- mem0/embeddings/huggingface.py +41 -0
- mem0/embeddings/langchain.py +35 -0
- mem0/embeddings/lmstudio.py +29 -0
- mem0/embeddings/mock.py +11 -0
- mem0/embeddings/ollama.py +53 -0
- mem0/embeddings/openai.py +49 -0
- mem0/embeddings/together.py +31 -0
- mem0/embeddings/vertexai.py +54 -0
- mem0/graphs/__init__.py +0 -0
- mem0/graphs/configs.py +96 -0
- mem0/graphs/neptune/__init__.py +0 -0
- mem0/graphs/neptune/base.py +410 -0
- mem0/graphs/neptune/main.py +372 -0
- mem0/graphs/tools.py +371 -0
- mem0/graphs/utils.py +97 -0
- mem0/llms/__init__.py +0 -0
- mem0/llms/anthropic.py +64 -0
- mem0/llms/aws_bedrock.py +270 -0
- mem0/llms/azure_openai.py +114 -0
- mem0/llms/azure_openai_structured.py +76 -0
- mem0/llms/base.py +32 -0
- mem0/llms/configs.py +34 -0
- mem0/llms/deepseek.py +85 -0
- mem0/llms/gemini.py +201 -0
- mem0/llms/groq.py +88 -0
- mem0/llms/langchain.py +65 -0
- mem0/llms/litellm.py +87 -0
- mem0/llms/lmstudio.py +53 -0
- mem0/llms/ollama.py +94 -0
- mem0/llms/openai.py +124 -0
- mem0/llms/openai_structured.py +52 -0
- mem0/llms/sarvam.py +89 -0
- mem0/llms/together.py +88 -0
- mem0/llms/vllm.py +89 -0
- mem0/llms/xai.py +52 -0
- mem0/memory/__init__.py +0 -0
- mem0/memory/base.py +63 -0
- mem0/memory/graph_memory.py +632 -0
- mem0/memory/main.py +1843 -0
- mem0/memory/memgraph_memory.py +630 -0
- mem0/memory/setup.py +56 -0
- mem0/memory/storage.py +218 -0
- mem0/memory/telemetry.py +90 -0
- mem0/memory/utils.py +133 -0
- mem0/proxy/__init__.py +0 -0
- mem0/proxy/main.py +194 -0
- mem0/utils/factory.py +132 -0
- mem0/vector_stores/__init__.py +0 -0
- mem0/vector_stores/azure_ai_search.py +383 -0
- mem0/vector_stores/baidu.py +368 -0
- mem0/vector_stores/base.py +58 -0
- mem0/vector_stores/chroma.py +229 -0
- mem0/vector_stores/configs.py +60 -0
- mem0/vector_stores/elasticsearch.py +235 -0
- mem0/vector_stores/faiss.py +473 -0
- mem0/vector_stores/langchain.py +179 -0
- mem0/vector_stores/milvus.py +245 -0
- mem0/vector_stores/mongodb.py +293 -0
- mem0/vector_stores/opensearch.py +281 -0
- mem0/vector_stores/pgvector.py +294 -0
- mem0/vector_stores/pinecone.py +373 -0
- mem0/vector_stores/qdrant.py +240 -0
- mem0/vector_stores/redis.py +295 -0
- mem0/vector_stores/supabase.py +237 -0
- mem0/vector_stores/upstash_vector.py +293 -0
- mem0/vector_stores/vertex_ai_vector_search.py +629 -0
- mem0/vector_stores/weaviate.py +316 -0
- mem0ai_azure_mysql-0.1.115.data/data/README.md +169 -0
- mem0ai_azure_mysql-0.1.115.dist-info/METADATA +224 -0
- mem0ai_azure_mysql-0.1.115.dist-info/RECORD +116 -0
- mem0ai_azure_mysql-0.1.115.dist-info/WHEEL +4 -0
- mem0ai_azure_mysql-0.1.115.dist-info/licenses/LICENSE +201 -0

mem0/vector_stores/opensearch.py
@@ -0,0 +1,281 @@
+import logging
+import time
+from typing import Any, Dict, List, Optional
+
+try:
+    from opensearchpy import OpenSearch, RequestsHttpConnection
+except ImportError:
+    raise ImportError("OpenSearch requires extra dependencies. Install with `pip install opensearch-py`") from None
+
+from pydantic import BaseModel
+
+from mem0.configs.vector_stores.opensearch import OpenSearchConfig
+from mem0.vector_stores.base import VectorStoreBase
+
+logger = logging.getLogger(__name__)
+
+
+class OutputData(BaseModel):
+    id: str
+    score: float
+    payload: Dict
+
+
+class OpenSearchDB(VectorStoreBase):
+    def __init__(self, **kwargs):
+        config = OpenSearchConfig(**kwargs)
+
+        # Initialize OpenSearch client
+        self.client = OpenSearch(
+            hosts=[{"host": config.host, "port": config.port or 9200}],
+            http_auth=config.http_auth
+            if config.http_auth
+            else ((config.user, config.password) if (config.user and config.password) else None),
+            use_ssl=config.use_ssl,
+            verify_certs=config.verify_certs,
+            connection_class=RequestsHttpConnection,
+            pool_maxsize=20,
+        )
+
+        self.collection_name = config.collection_name
+        self.embedding_model_dims = config.embedding_model_dims
+        self.create_col(self.collection_name, self.embedding_model_dims)
+
+    def create_index(self) -> None:
+        """Create OpenSearch index with proper mappings if it doesn't exist."""
+        index_settings = {
+            "settings": {
+                "index": {"number_of_replicas": 1, "number_of_shards": 5, "refresh_interval": "10s", "knn": True}
+            },
+            "mappings": {
+                "properties": {
+                    "text": {"type": "text"},
+                    "vector_field": {
+                        "type": "knn_vector",
+                        "dimension": self.embedding_model_dims,
+                        "method": {"engine": "nmslib", "name": "hnsw", "space_type": "cosinesimil"},
+                    },
+                    "metadata": {"type": "object", "properties": {"user_id": {"type": "keyword"}}},
+                }
+            },
+        }
+
+        if not self.client.indices.exists(index=self.collection_name):
+            self.client.indices.create(index=self.collection_name, body=index_settings)
+            logger.info(f"Created index {self.collection_name}")
+        else:
+            logger.info(f"Index {self.collection_name} already exists")
+
+    def create_col(self, name: str, vector_size: int) -> None:
+        """Create a new collection (index in OpenSearch)."""
+        index_settings = {
+            "settings": {"index.knn": True},
+            "mappings": {
+                "properties": {
+                    "vector_field": {
+                        "type": "knn_vector",
+                        "dimension": vector_size,
+                        "method": {"engine": "nmslib", "name": "hnsw", "space_type": "cosinesimil"},
+                    },
+                    "payload": {"type": "object"},
+                    "id": {"type": "keyword"},
+                }
+            },
+        }
+
+        if not self.client.indices.exists(index=name):
+            logger.warning(f"Creating index {name}, it might take 1-2 minutes...")
+            self.client.indices.create(index=name, body=index_settings)
+
+        # Wait for index to be ready
+        max_retries = 180  # 3 minutes timeout
+        retry_count = 0
+        while retry_count < max_retries:
+            try:
+                # Check if index is ready by attempting a simple search
+                self.client.search(index=name, body={"query": {"match_all": {}}})
+                time.sleep(1)
+                logger.info(f"Index {name} is ready")
+                return
+            except Exception:
+                retry_count += 1
+                if retry_count == max_retries:
+                    raise TimeoutError(f"Index {name} creation timed out after {max_retries} seconds")
+                time.sleep(0.5)
+
+    def insert(
+        self, vectors: List[List[float]], payloads: Optional[List[Dict]] = None, ids: Optional[List[str]] = None
+    ) -> List[OutputData]:
+        """Insert vectors into the index."""
+        if not ids:
+            ids = [str(i) for i in range(len(vectors))]
+
+        if payloads is None:
+            payloads = [{} for _ in range(len(vectors))]
+
+        for i, (vec, id_) in enumerate(zip(vectors, ids)):
+            body = {
+                "vector_field": vec,
+                "payload": payloads[i],
+                "id": id_,
+            }
+            self.client.index(index=self.collection_name, body=body)
+
+        results = []
+
+        return results
+
+    def search(
+        self, query: str, vectors: List[float], limit: int = 5, filters: Optional[Dict] = None
+    ) -> List[OutputData]:
+        """Search for similar vectors using OpenSearch k-NN search with optional filters."""
+
+        # Base KNN query
+        knn_query = {
+            "knn": {
+                "vector_field": {
+                    "vector": vectors,
+                    "k": limit * 2,
+                }
+            }
+        }
+
+        # Start building the full query
+        query_body = {"size": limit * 2, "query": None}
+
+        # Prepare filter conditions if applicable
+        filter_clauses = []
+        if filters:
+            for key in ["user_id", "run_id", "agent_id"]:
+                value = filters.get(key)
+                if value:
+                    filter_clauses.append({"term": {f"payload.{key}.keyword": value}})
+
+        # Combine knn with filters if needed
+        if filter_clauses:
+            query_body["query"] = {"bool": {"must": knn_query, "filter": filter_clauses}}
+        else:
+            query_body["query"] = knn_query
+
+        # Execute search
+        response = self.client.search(index=self.collection_name, body=query_body)
+
+        hits = response["hits"]["hits"]
+        results = [
+            OutputData(id=hit["_source"].get("id"), score=hit["_score"], payload=hit["_source"].get("payload", {}))
+            for hit in hits
+        ]
+        return results
+
+    def delete(self, vector_id: str) -> None:
+        """Delete a vector by custom ID."""
+        # First, find the document by custom ID
+        search_query = {"query": {"term": {"id": vector_id}}}
+
+        response = self.client.search(index=self.collection_name, body=search_query)
+        hits = response.get("hits", {}).get("hits", [])
+
+        if not hits:
+            return
+
+        opensearch_id = hits[0]["_id"]
+
+        # Delete using the actual document ID
+        self.client.delete(index=self.collection_name, id=opensearch_id)
+
+    def update(self, vector_id: str, vector: Optional[List[float]] = None, payload: Optional[Dict] = None) -> None:
+        """Update a vector and its payload using the custom 'id' field."""
+
+        # First, find the document by custom ID
+        search_query = {"query": {"term": {"id": vector_id}}}
+
+        response = self.client.search(index=self.collection_name, body=search_query)
+        hits = response.get("hits", {}).get("hits", [])
+
+        if not hits:
+            return
+
+        opensearch_id = hits[0]["_id"]  # The actual document ID in OpenSearch
+
+        # Prepare updated fields
+        doc = {}
+        if vector is not None:
+            doc["vector_field"] = vector
+        if payload is not None:
+            doc["payload"] = payload
+
+        if doc:
+            try:
+                response = self.client.update(index=self.collection_name, id=opensearch_id, body={"doc": doc})
+            except Exception:
+                pass
+
+    def get(self, vector_id: str) -> Optional[OutputData]:
+        """Retrieve a vector by ID."""
+        try:
+            # First check if index exists
+            if not self.client.indices.exists(index=self.collection_name):
+                logger.info(f"Index {self.collection_name} does not exist, creating it...")
+                self.create_col(self.collection_name, self.embedding_model_dims)
+                return None
+
+            search_query = {"query": {"term": {"id": vector_id}}}
+            response = self.client.search(index=self.collection_name, body=search_query)
+
+            hits = response["hits"]["hits"]
+
+            if not hits:
+                return None
+
+            return OutputData(id=hits[0]["_source"].get("id"), score=1.0, payload=hits[0]["_source"].get("payload", {}))
+        except Exception as e:
+            logger.error(f"Error retrieving vector {vector_id}: {str(e)}")
+            return None
+
+    def list_cols(self) -> List[str]:
+        """List all collections (indices)."""
+        return list(self.client.indices.get_alias().keys())
+
+    def delete_col(self) -> None:
+        """Delete a collection (index)."""
+        self.client.indices.delete(index=self.collection_name)
+
+    def col_info(self, name: str) -> Any:
+        """Get information about a collection (index)."""
+        return self.client.indices.get(index=name)
+
+    def list(self, filters: Optional[Dict] = None, limit: Optional[int] = None) -> List[OutputData]:
+        try:
+            """List all memories with optional filters."""
+            query: Dict = {"query": {"match_all": {}}}
+
+            filter_clauses = []
+            if filters:
+                for key in ["user_id", "run_id", "agent_id"]:
+                    value = filters.get(key)
+                    if value:
+                        filter_clauses.append({"term": {f"payload.{key}.keyword": value}})
+
+            if filter_clauses:
+                query["query"] = {"bool": {"filter": filter_clauses}}
+
+            if limit:
+                query["size"] = limit
+
+            response = self.client.search(index=self.collection_name, body=query)
+            hits = response["hits"]["hits"]
+
+            return [
+                [
+                    OutputData(id=hit["_source"].get("id"), score=1.0, payload=hit["_source"].get("payload", {}))
+                    for hit in hits
+                ]
+            ]
+        except Exception:
+            return []
+
+    def reset(self):
+        """Reset the index by deleting and recreating it."""
+        logger.warning(f"Resetting index {self.collection_name}...")
+        self.delete_col()
+        self.create_col(self.collection_name, self.embedding_model_dims)
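
The OpenSearch store keys documents on a custom "id" field rather than the native OpenSearch _id, which is why delete, update, and get each run a term query first to resolve the real document ID. A minimal usage sketch follows (editor's illustration, not part of the package diff): it assumes a local OpenSearch node on localhost:9200 and 1536-dimensional embeddings, and the index name, credentials, and vectors are placeholders.

# Editor's usage sketch; assumes a local node and placeholder names/values.
from mem0.vector_stores.opensearch import OpenSearchDB

store = OpenSearchDB(
    host="localhost",
    port=9200,
    collection_name="mem0",        # placeholder index name
    embedding_model_dims=1536,     # must match the embedding model in use
    use_ssl=False,
    verify_certs=False,
)

# Payload fields user_id/run_id/agent_id are the ones the store can filter on.
store.insert(
    vectors=[[0.1] * 1536],
    payloads=[{"user_id": "alice", "data": "likes tennis"}],
    ids=["vec-1"],
)

# k-NN search restricted to one user via a payload.user_id.keyword term filter.
for hit in store.search(query="tennis", vectors=[0.1] * 1536, limit=5, filters={"user_id": "alice"}):
    print(hit.id, hit.score, hit.payload)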

mem0/vector_stores/pgvector.py
@@ -0,0 +1,294 @@
+import json
+import logging
+from typing import List, Optional
+
+from pydantic import BaseModel
+
+try:
+    import psycopg2
+    from psycopg2.extras import execute_values
+except ImportError:
+    raise ImportError("The 'psycopg2' library is required. Please install it using 'pip install psycopg2'.")
+
+from mem0.vector_stores.base import VectorStoreBase
+
+logger = logging.getLogger(__name__)
+
+
+class OutputData(BaseModel):
+    id: Optional[str]
+    score: Optional[float]
+    payload: Optional[dict]
+
+
+class PGVector(VectorStoreBase):
+    def __init__(
+        self,
+        dbname,
+        collection_name,
+        embedding_model_dims,
+        user,
+        password,
+        host,
+        port,
+        diskann,
+        hnsw,
+    ):
+        """
+        Initialize the PGVector database.
+
+        Args:
+            dbname (str): Database name
+            collection_name (str): Collection name
+            embedding_model_dims (int): Dimension of the embedding vector
+            user (str): Database user
+            password (str): Database password
+            host (str, optional): Database host
+            port (int, optional): Database port
+            diskann (bool, optional): Use DiskANN for faster search
+            hnsw (bool, optional): Use HNSW for faster search
+        """
+        self.collection_name = collection_name
+        self.use_diskann = diskann
+        self.use_hnsw = hnsw
+        self.embedding_model_dims = embedding_model_dims
+
+        self.conn = psycopg2.connect(dbname=dbname, user=user, password=password, host=host, port=port)
+        self.cur = self.conn.cursor()
+
+        collections = self.list_cols()
+        if collection_name not in collections:
+            self.create_col(embedding_model_dims)
+
+    def create_col(self, embedding_model_dims):
+        """
+        Create a new collection (table in PostgreSQL).
+        Will also initialize vector search index if specified.
+
+        Args:
+            embedding_model_dims (int): Dimension of the embedding vector.
+        """
+        self.cur.execute("CREATE EXTENSION IF NOT EXISTS vector")
+        self.cur.execute(
+            f"""
+            CREATE TABLE IF NOT EXISTS {self.collection_name} (
+                id UUID PRIMARY KEY,
+                vector vector({embedding_model_dims}),
+                payload JSONB
+            );
+            """
+        )
+
+        if self.use_diskann and embedding_model_dims < 2000:
+            # Check if vectorscale extension is installed
+            self.cur.execute("SELECT * FROM pg_extension WHERE extname = 'vectorscale'")
+            if self.cur.fetchone():
+                # Create DiskANN index if extension is installed for faster search
+                self.cur.execute(
+                    f"""
+                    CREATE INDEX IF NOT EXISTS {self.collection_name}_diskann_idx
+                    ON {self.collection_name}
+                    USING diskann (vector);
+                    """
+                )
+        elif self.use_hnsw:
+            self.cur.execute(
+                f"""
+                CREATE INDEX IF NOT EXISTS {self.collection_name}_hnsw_idx
+                ON {self.collection_name}
+                USING hnsw (vector vector_cosine_ops)
+                """
+            )
+
+        self.conn.commit()
+
+    def insert(self, vectors, payloads=None, ids=None):
+        """
+        Insert vectors into a collection.
+
+        Args:
+            vectors (List[List[float]]): List of vectors to insert.
+            payloads (List[Dict], optional): List of payloads corresponding to vectors.
+            ids (List[str], optional): List of IDs corresponding to vectors.
+        """
+        logger.info(f"Inserting {len(vectors)} vectors into collection {self.collection_name}")
+        json_payloads = [json.dumps(payload) for payload in payloads]
+
+        data = [(id, vector, payload) for id, vector, payload in zip(ids, vectors, json_payloads)]
+        execute_values(
+            self.cur,
+            f"INSERT INTO {self.collection_name} (id, vector, payload) VALUES %s",
+            data,
+        )
+        self.conn.commit()
+
+    def search(self, query, vectors, limit=5, filters=None):
+        """
+        Search for similar vectors.
+
+        Args:
+            query (str): Query.
+            vectors (List[float]): Query vector.
+            limit (int, optional): Number of results to return. Defaults to 5.
+            filters (Dict, optional): Filters to apply to the search. Defaults to None.
+
+        Returns:
+            list: Search results.
+        """
+        filter_conditions = []
+        filter_params = []
+
+        if filters:
+            for k, v in filters.items():
+                filter_conditions.append("payload->>%s = %s")
+                filter_params.extend([k, str(v)])
+
+        filter_clause = "WHERE " + " AND ".join(filter_conditions) if filter_conditions else ""
+
+        self.cur.execute(
+            f"""
+            SELECT id, vector <=> %s::vector AS distance, payload
+            FROM {self.collection_name}
+            {filter_clause}
+            ORDER BY distance
+            LIMIT %s
+            """,
+            (vectors, *filter_params, limit),
+        )
+
+        results = self.cur.fetchall()
+        return [OutputData(id=str(r[0]), score=float(r[1]), payload=r[2]) for r in results]
+
+    def delete(self, vector_id):
+        """
+        Delete a vector by ID.
+
+        Args:
+            vector_id (str): ID of the vector to delete.
+        """
+        self.cur.execute(f"DELETE FROM {self.collection_name} WHERE id = %s", (vector_id,))
+        self.conn.commit()
+
+    def update(self, vector_id, vector=None, payload=None):
+        """
+        Update a vector and its payload.
+
+        Args:
+            vector_id (str): ID of the vector to update.
+            vector (List[float], optional): Updated vector.
+            payload (Dict, optional): Updated payload.
+        """
+        if vector:
+            self.cur.execute(
+                f"UPDATE {self.collection_name} SET vector = %s WHERE id = %s",
+                (vector, vector_id),
+            )
+        if payload:
+            self.cur.execute(
+                f"UPDATE {self.collection_name} SET payload = %s WHERE id = %s",
+                (psycopg2.extras.Json(payload), vector_id),
+            )
+        self.conn.commit()
+
+    def get(self, vector_id) -> OutputData:
+        """
+        Retrieve a vector by ID.
+
+        Args:
+            vector_id (str): ID of the vector to retrieve.
+
+        Returns:
+            OutputData: Retrieved vector.
+        """
+        self.cur.execute(
+            f"SELECT id, vector, payload FROM {self.collection_name} WHERE id = %s",
+            (vector_id,),
+        )
+        result = self.cur.fetchone()
+        if not result:
+            return None
+        return OutputData(id=str(result[0]), score=None, payload=result[2])
+
+    def list_cols(self) -> List[str]:
+        """
+        List all collections.
+
+        Returns:
+            List[str]: List of collection names.
+        """
+        self.cur.execute("SELECT table_name FROM information_schema.tables WHERE table_schema = 'public'")
+        return [row[0] for row in self.cur.fetchall()]
+
+    def delete_col(self):
+        """Delete a collection."""
+        self.cur.execute(f"DROP TABLE IF EXISTS {self.collection_name}")
+        self.conn.commit()
+
+    def col_info(self):
+        """
+        Get information about a collection.
+
+        Returns:
+            Dict[str, Any]: Collection information.
+        """
+        self.cur.execute(
+            f"""
+            SELECT
+                table_name,
+                (SELECT COUNT(*) FROM {self.collection_name}) as row_count,
+                (SELECT pg_size_pretty(pg_total_relation_size('{self.collection_name}'))) as total_size
+            FROM information_schema.tables
+            WHERE table_schema = 'public' AND table_name = %s
+            """,
+            (self.collection_name,),
+        )
+        result = self.cur.fetchone()
+        return {"name": result[0], "count": result[1], "size": result[2]}
+
+    def list(self, filters=None, limit=100):
+        """
+        List all vectors in a collection.
+
+        Args:
+            filters (Dict, optional): Filters to apply to the list.
+            limit (int, optional): Number of vectors to return. Defaults to 100.
+
+        Returns:
+            List[OutputData]: List of vectors.
+        """
+        filter_conditions = []
+        filter_params = []
+
+        if filters:
+            for k, v in filters.items():
+                filter_conditions.append("payload->>%s = %s")
+                filter_params.extend([k, str(v)])
+
+        filter_clause = "WHERE " + " AND ".join(filter_conditions) if filter_conditions else ""
+
+        query = f"""
+            SELECT id, vector, payload
+            FROM {self.collection_name}
+            {filter_clause}
+            LIMIT %s
+        """
+
+        self.cur.execute(query, (*filter_params, limit))
+
+        results = self.cur.fetchall()
+        return [[OutputData(id=str(r[0]), score=None, payload=r[2]) for r in results]]
+
+    def __del__(self):
+        """
+        Close the database connection when the object is deleted.
+        """
+        if hasattr(self, "cur"):
+            self.cur.close()
+        if hasattr(self, "conn"):
+            self.conn.close()
+
+    def reset(self):
+        """Reset the index by deleting and recreating it."""
+        logger.warning(f"Resetting index {self.collection_name}...")
+        self.delete_col()
+        self.create_col(self.embedding_model_dims)
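
A matching usage sketch for the Postgres-backed store (editor's illustration, not part of the package diff), assuming a local Postgres with the pgvector extension available; connection values are placeholders. Two quirks are visible in the code above: insert iterates over payloads and zips ids without applying defaults, so both must be passed explicitly, and search orders by cosine distance (<=>), so lower scores mean closer matches.

# Editor's usage sketch; assumes a local Postgres and placeholder credentials.
import uuid

from mem0.vector_stores.pgvector import PGVector

store = PGVector(
    dbname="postgres",
    collection_name="mem0",        # placeholder table name
    embedding_model_dims=1536,
    user="postgres",
    password="postgres",
    host="localhost",
    port=5432,
    diskann=False,
    hnsw=True,                     # builds an HNSW index with vector_cosine_ops
)

vector_id = str(uuid.uuid4())      # the id column is a UUID primary key
store.insert(vectors=[[0.1] * 1536], payloads=[{"user_id": "alice"}], ids=[vector_id])

# Results come back ordered by ascending cosine distance.
for r in store.search(query="tennis", vectors=[0.1] * 1536, limit=5, filters={"user_id": "alice"}):
    print(r.id, r.score, r.payload)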