mem0ai-azure-mysql 0.1.115__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mem0/__init__.py +6 -0
- mem0/client/__init__.py +0 -0
- mem0/client/main.py +1535 -0
- mem0/client/project.py +860 -0
- mem0/client/utils.py +29 -0
- mem0/configs/__init__.py +0 -0
- mem0/configs/base.py +90 -0
- mem0/configs/dbs/__init__.py +4 -0
- mem0/configs/dbs/base.py +41 -0
- mem0/configs/dbs/mysql.py +25 -0
- mem0/configs/embeddings/__init__.py +0 -0
- mem0/configs/embeddings/base.py +108 -0
- mem0/configs/enums.py +7 -0
- mem0/configs/llms/__init__.py +0 -0
- mem0/configs/llms/base.py +152 -0
- mem0/configs/prompts.py +333 -0
- mem0/configs/vector_stores/__init__.py +0 -0
- mem0/configs/vector_stores/azure_ai_search.py +59 -0
- mem0/configs/vector_stores/baidu.py +29 -0
- mem0/configs/vector_stores/chroma.py +40 -0
- mem0/configs/vector_stores/elasticsearch.py +47 -0
- mem0/configs/vector_stores/faiss.py +39 -0
- mem0/configs/vector_stores/langchain.py +32 -0
- mem0/configs/vector_stores/milvus.py +43 -0
- mem0/configs/vector_stores/mongodb.py +25 -0
- mem0/configs/vector_stores/opensearch.py +41 -0
- mem0/configs/vector_stores/pgvector.py +37 -0
- mem0/configs/vector_stores/pinecone.py +56 -0
- mem0/configs/vector_stores/qdrant.py +49 -0
- mem0/configs/vector_stores/redis.py +26 -0
- mem0/configs/vector_stores/supabase.py +44 -0
- mem0/configs/vector_stores/upstash_vector.py +36 -0
- mem0/configs/vector_stores/vertex_ai_vector_search.py +27 -0
- mem0/configs/vector_stores/weaviate.py +43 -0
- mem0/dbs/__init__.py +4 -0
- mem0/dbs/base.py +68 -0
- mem0/dbs/configs.py +21 -0
- mem0/dbs/mysql.py +321 -0
- mem0/embeddings/__init__.py +0 -0
- mem0/embeddings/aws_bedrock.py +100 -0
- mem0/embeddings/azure_openai.py +43 -0
- mem0/embeddings/base.py +31 -0
- mem0/embeddings/configs.py +30 -0
- mem0/embeddings/gemini.py +39 -0
- mem0/embeddings/huggingface.py +41 -0
- mem0/embeddings/langchain.py +35 -0
- mem0/embeddings/lmstudio.py +29 -0
- mem0/embeddings/mock.py +11 -0
- mem0/embeddings/ollama.py +53 -0
- mem0/embeddings/openai.py +49 -0
- mem0/embeddings/together.py +31 -0
- mem0/embeddings/vertexai.py +54 -0
- mem0/graphs/__init__.py +0 -0
- mem0/graphs/configs.py +96 -0
- mem0/graphs/neptune/__init__.py +0 -0
- mem0/graphs/neptune/base.py +410 -0
- mem0/graphs/neptune/main.py +372 -0
- mem0/graphs/tools.py +371 -0
- mem0/graphs/utils.py +97 -0
- mem0/llms/__init__.py +0 -0
- mem0/llms/anthropic.py +64 -0
- mem0/llms/aws_bedrock.py +270 -0
- mem0/llms/azure_openai.py +114 -0
- mem0/llms/azure_openai_structured.py +76 -0
- mem0/llms/base.py +32 -0
- mem0/llms/configs.py +34 -0
- mem0/llms/deepseek.py +85 -0
- mem0/llms/gemini.py +201 -0
- mem0/llms/groq.py +88 -0
- mem0/llms/langchain.py +65 -0
- mem0/llms/litellm.py +87 -0
- mem0/llms/lmstudio.py +53 -0
- mem0/llms/ollama.py +94 -0
- mem0/llms/openai.py +124 -0
- mem0/llms/openai_structured.py +52 -0
- mem0/llms/sarvam.py +89 -0
- mem0/llms/together.py +88 -0
- mem0/llms/vllm.py +89 -0
- mem0/llms/xai.py +52 -0
- mem0/memory/__init__.py +0 -0
- mem0/memory/base.py +63 -0
- mem0/memory/graph_memory.py +632 -0
- mem0/memory/main.py +1843 -0
- mem0/memory/memgraph_memory.py +630 -0
- mem0/memory/setup.py +56 -0
- mem0/memory/storage.py +218 -0
- mem0/memory/telemetry.py +90 -0
- mem0/memory/utils.py +133 -0
- mem0/proxy/__init__.py +0 -0
- mem0/proxy/main.py +194 -0
- mem0/utils/factory.py +132 -0
- mem0/vector_stores/__init__.py +0 -0
- mem0/vector_stores/azure_ai_search.py +383 -0
- mem0/vector_stores/baidu.py +368 -0
- mem0/vector_stores/base.py +58 -0
- mem0/vector_stores/chroma.py +229 -0
- mem0/vector_stores/configs.py +60 -0
- mem0/vector_stores/elasticsearch.py +235 -0
- mem0/vector_stores/faiss.py +473 -0
- mem0/vector_stores/langchain.py +179 -0
- mem0/vector_stores/milvus.py +245 -0
- mem0/vector_stores/mongodb.py +293 -0
- mem0/vector_stores/opensearch.py +281 -0
- mem0/vector_stores/pgvector.py +294 -0
- mem0/vector_stores/pinecone.py +373 -0
- mem0/vector_stores/qdrant.py +240 -0
- mem0/vector_stores/redis.py +295 -0
- mem0/vector_stores/supabase.py +237 -0
- mem0/vector_stores/upstash_vector.py +293 -0
- mem0/vector_stores/vertex_ai_vector_search.py +629 -0
- mem0/vector_stores/weaviate.py +316 -0
- mem0ai_azure_mysql-0.1.115.data/data/README.md +169 -0
- mem0ai_azure_mysql-0.1.115.dist-info/METADATA +224 -0
- mem0ai_azure_mysql-0.1.115.dist-info/RECORD +116 -0
- mem0ai_azure_mysql-0.1.115.dist-info/WHEEL +4 -0
- mem0ai_azure_mysql-0.1.115.dist-info/licenses/LICENSE +201 -0
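Two of the added files are reproduced in full below: mem0/vector_stores/redis.py (+295) and mem0/vector_stores/supabase.py (+237). For orientation, here is a minimal sketch of how stores like these are typically wired up through mem0's config-driven factory; the provider names and config keys follow upstream mem0 conventions and are assumptions, since the diff itself does not show them:

```python
# Hypothetical wiring sketch; "redis" as the provider key and the exact config
# field names are assumed from upstream mem0 conventions, not from this diff.
from mem0 import Memory

config = {
    "vector_store": {
        "provider": "redis",
        "config": {
            "redis_url": "redis://localhost:6379",
            "collection_name": "mem0",
            "embedding_model_dims": 1536,  # must match the embedder in use
        },
    },
}

m = Memory.from_config(config)
m.add("I prefer vegetarian food", user_id="alice")
print(m.search("What do I eat?", user_id="alice"))
```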
mem0/vector_stores/redis.py (new file)
@@ -0,0 +1,295 @@
```python
import copy
import json
import logging
from datetime import datetime
from functools import reduce
from typing import Optional

import numpy as np
import pytz
import redis
from redis.commands.search.query import Query
from redisvl.index import SearchIndex
from redisvl.query import VectorQuery
from redisvl.query.filter import Tag

from mem0.memory.utils import extract_json
from mem0.vector_stores.base import VectorStoreBase

logger = logging.getLogger(__name__)

# TODO: Improve these; they are not ideal field choices from Redis's perspective and may be removed.
DEFAULT_FIELDS = [
    {"name": "memory_id", "type": "tag"},
    {"name": "hash", "type": "tag"},
    {"name": "agent_id", "type": "tag"},
    {"name": "run_id", "type": "tag"},
    {"name": "user_id", "type": "tag"},
    {"name": "memory", "type": "text"},
    {"name": "metadata", "type": "text"},
    # TODO: The field is numeric, but string values are also accepted.
    {"name": "created_at", "type": "numeric"},
    {"name": "updated_at", "type": "numeric"},
    {
        "name": "embedding",
        "type": "vector",
        "attrs": {"distance_metric": "cosine", "algorithm": "flat", "datatype": "float32"},
    },
]

# Payload keys stored as dedicated index fields rather than in the metadata blob.
excluded_keys = {"user_id", "agent_id", "run_id", "hash", "data", "created_at", "updated_at"}


class MemoryResult:
    def __init__(self, id: str, payload: dict, score: Optional[float] = None):
        self.id = id
        self.payload = payload
        self.score = score


class RedisDB(VectorStoreBase):
    def __init__(
        self,
        redis_url: str,
        collection_name: str,
        embedding_model_dims: int,
    ):
        """
        Initialize the Redis vector store.

        Args:
            redis_url (str): Redis URL.
            collection_name (str): Collection name.
            embedding_model_dims (int): Embedding model dimensions.
        """
        self.embedding_model_dims = embedding_model_dims
        index_schema = {
            "name": collection_name,
            "prefix": f"mem0:{collection_name}",
        }

        # Deep copy so that setting the vector dims never mutates the shared DEFAULT_FIELDS.
        fields = copy.deepcopy(DEFAULT_FIELDS)
        fields[-1]["attrs"]["dims"] = embedding_model_dims

        self.schema = {"index": index_schema, "fields": fields}

        self.client = redis.Redis.from_url(redis_url)
        self.index = SearchIndex.from_dict(self.schema)
        self.index.set_client(self.client)
        self.index.create(overwrite=True)

    def create_col(self, name=None, vector_size=None, distance=None):
        """
        Create a new collection (index) in Redis.

        Args:
            name (str, optional): Name for the collection. Defaults to None, which uses the current collection name.
            vector_size (int, optional): Size of the vector embeddings. Defaults to None, which uses the current embedding_model_dims.
            distance (str, optional): Distance metric to use. Defaults to None, which uses 'cosine'.

        Returns:
            The created index object.
        """
        # Use the provided parameters or fall back to instance attributes.
        collection_name = name or self.schema["index"]["name"]
        embedding_dims = vector_size or self.embedding_model_dims
        distance_metric = distance or "cosine"

        # Build a new schema with the specified parameters.
        index_schema = {
            "name": collection_name,
            "prefix": f"mem0:{collection_name}",
        }

        # Deep-copy the default fields and update the vector field with the requested dimensions.
        fields = copy.deepcopy(DEFAULT_FIELDS)
        fields[-1]["attrs"]["dims"] = embedding_dims
        fields[-1]["attrs"]["distance_metric"] = distance_metric

        schema = {"index": index_schema, "fields": fields}

        # Create the index.
        index = SearchIndex.from_dict(schema)
        index.set_client(self.client)
        index.create(overwrite=True)

        # Update instance attributes when a new collection name was given.
        if name:
            self.schema = schema
            self.index = index

        return index

    def insert(self, vectors: list, payloads: list = None, ids: list = None):
        data = []
        for vector, payload, id in zip(vectors, payloads, ids):
            # Start with the required fields.
            entry = {
                "memory_id": id,
                "hash": payload["hash"],
                "memory": payload["data"],
                "created_at": int(datetime.fromisoformat(payload["created_at"]).timestamp()),
                "embedding": np.array(vector, dtype=np.float32).tobytes(),
            }

            # Conditionally add the optional fields.
            for field in ["agent_id", "run_id", "user_id"]:
                if field in payload:
                    entry[field] = payload[field]

            # Add metadata, excluding the reserved keys.
            entry["metadata"] = json.dumps({k: v for k, v in payload.items() if k not in excluded_keys})

            data.append(entry)
        self.index.load(data, id_field="memory_id")

    def search(self, query: str, vectors: list, limit: int = 5, filters: dict = None):
        # At least one non-None filter is required: reduce() raises on an empty sequence.
        conditions = [Tag(key) == value for key, value in filters.items() if value is not None]
        filter = reduce(lambda x, y: x & y, conditions)

        v = VectorQuery(
            vector=np.array(vectors, dtype=np.float32).tobytes(),
            vector_field_name="embedding",
            return_fields=["memory_id", "hash", "agent_id", "run_id", "user_id", "memory", "metadata", "created_at"],
            filter_expression=filter,
            num_results=limit,
        )

        results = self.index.query(v)

        return [
            MemoryResult(
                id=result["memory_id"],
                score=result["vector_distance"],
                payload={
                    "hash": result["hash"],
                    "data": result["memory"],
                    "created_at": datetime.fromtimestamp(
                        int(result["created_at"]), tz=pytz.timezone("US/Pacific")
                    ).isoformat(timespec="microseconds"),
                    **(
                        {
                            "updated_at": datetime.fromtimestamp(
                                int(result["updated_at"]), tz=pytz.timezone("US/Pacific")
                            ).isoformat(timespec="microseconds")
                        }
                        if "updated_at" in result
                        else {}
                    ),
                    **{field: result[field] for field in ["agent_id", "run_id", "user_id"] if field in result},
                    **json.loads(extract_json(result["metadata"])),
                },
            )
            for result in results
        ]

    def delete(self, vector_id):
        self.index.drop_keys(f"{self.schema['index']['prefix']}:{vector_id}")

    def update(self, vector_id=None, vector=None, payload=None):
        data = {
            "memory_id": vector_id,
            "hash": payload["hash"],
            "memory": payload["data"],
            "created_at": int(datetime.fromisoformat(payload["created_at"]).timestamp()),
            "updated_at": int(datetime.fromisoformat(payload["updated_at"]).timestamp()),
            "embedding": np.array(vector, dtype=np.float32).tobytes(),
        }

        for field in ["agent_id", "run_id", "user_id"]:
            if field in payload:
                data[field] = payload[field]

        data["metadata"] = json.dumps({k: v for k, v in payload.items() if k not in excluded_keys})
        self.index.load(data=[data], keys=[f"{self.schema['index']['prefix']}:{vector_id}"], id_field="memory_id")

    def get(self, vector_id):
        result = self.index.fetch(vector_id)
        payload = {
            "hash": result["hash"],
            "data": result["memory"],
            "created_at": datetime.fromtimestamp(int(result["created_at"]), tz=pytz.timezone("US/Pacific")).isoformat(
                timespec="microseconds"
            ),
            **(
                {
                    "updated_at": datetime.fromtimestamp(
                        int(result["updated_at"]), tz=pytz.timezone("US/Pacific")
                    ).isoformat(timespec="microseconds")
                }
                if "updated_at" in result
                else {}
            ),
            **{field: result[field] for field in ["agent_id", "run_id", "user_id"] if field in result},
            **json.loads(extract_json(result["metadata"])),
        }

        return MemoryResult(id=result["memory_id"], payload=payload)

    def list_cols(self):
        return self.index.listall()

    def delete_col(self):
        self.index.delete()

    def col_info(self, name):
        # `name` is accepted for interface compatibility but ignored; info describes the bound index.
        return self.index.info()

    def reset(self):
        """
        Reset the index by deleting and recreating it.
        """
        collection_name = self.schema["index"]["name"]
        logger.warning(f"Resetting index {collection_name}...")
        self.delete_col()
        # Recreate the index with the same parameters.
        self.create_col(collection_name, self.embedding_model_dims)

    def list(self, filters: dict = None, limit: int = None) -> list:
        """
        List the most recently created memories in the vector store.
        """
        # At least one non-None filter is required: reduce() raises on an empty sequence.
        conditions = [Tag(key) == value for key, value in filters.items() if value is not None]
        filter = reduce(lambda x, y: x & y, conditions)
        query = Query(str(filter)).sort_by("created_at", asc=False)
        if limit is not None:
            query = query.paging(0, limit)

        results = self.index.search(query)
        return [
            [
                MemoryResult(
                    id=result["memory_id"],
                    payload={
                        "hash": result["hash"],
                        "data": result["memory"],
                        "created_at": datetime.fromtimestamp(
                            int(result["created_at"]), tz=pytz.timezone("US/Pacific")
                        ).isoformat(timespec="microseconds"),
                        **(
                            {
                                "updated_at": datetime.fromtimestamp(
                                    int(result["updated_at"]), tz=pytz.timezone("US/Pacific")
                                ).isoformat(timespec="microseconds")
                            }
                            if result.__dict__.get("updated_at")
                            else {}
                        ),
                        **{
                            field: result[field]
                            for field in ["agent_id", "run_id", "user_id"]
                            if field in result.__dict__
                        },
                        **json.loads(extract_json(result["metadata"])),
                    },
                )
                for result in results.docs
            ]
        ]
```
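A minimal usage sketch for the RedisDB class above, assuming a local Redis Stack server (the RediSearch module is required) and 8-dimensional embeddings for brevity; the payload keys mirror what insert() expects:

```python
import numpy as np

from mem0.vector_stores.redis import RedisDB

# Assumes Redis Stack (with RediSearch) is reachable at this URL.
store = RedisDB(redis_url="redis://localhost:6379", collection_name="demo", embedding_model_dims=8)

store.insert(
    vectors=[np.random.rand(8).tolist()],
    payloads=[{
        "hash": "abc123",
        "data": "User likes hiking",           # stored in the "memory" text field
        "created_at": "2024-01-01T00:00:00",   # ISO timestamp, stored as epoch seconds
        "user_id": "alice",
    }],
    ids=["mem-1"],
)

# search() folds filters with '&', so at least one filter must be supplied.
hits = store.search(query="hiking", vectors=np.random.rand(8).tolist(), limit=3, filters={"user_id": "alice"})
for hit in hits:
    print(hit.id, hit.score, hit.payload["data"])
```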
mem0/vector_stores/supabase.py (new file)
@@ -0,0 +1,237 @@
```python
import logging
import uuid
from typing import List, Optional

from pydantic import BaseModel

try:
    import vecs
except ImportError:
    raise ImportError("The 'vecs' library is required. Please install it using 'pip install vecs'.")

from mem0.configs.vector_stores.supabase import IndexMeasure, IndexMethod
from mem0.vector_stores.base import VectorStoreBase

logger = logging.getLogger(__name__)


class OutputData(BaseModel):
    id: Optional[str]
    score: Optional[float]
    payload: Optional[dict]


class Supabase(VectorStoreBase):
    def __init__(
        self,
        connection_string: str,
        collection_name: str,
        embedding_model_dims: int,
        index_method: IndexMethod = IndexMethod.AUTO,
        index_measure: IndexMeasure = IndexMeasure.COSINE,
    ):
        """
        Initialize the Supabase vector store using vecs.

        Args:
            connection_string (str): PostgreSQL connection string.
            collection_name (str): Collection name.
            embedding_model_dims (int): Dimension of the embedding vector.
            index_method (IndexMethod): Index method to use. Defaults to AUTO.
            index_measure (IndexMeasure): Distance measure to use. Defaults to COSINE.
        """
        self.db = vecs.create_client(connection_string)
        self.collection_name = collection_name
        self.embedding_model_dims = embedding_model_dims
        self.index_method = index_method
        self.index_measure = index_measure

        collections = self.list_cols()
        if collection_name not in collections:
            self.create_col(embedding_model_dims)
        else:
            # Bind the existing collection so that self.collection is always set.
            self.collection = self.db.get_or_create_collection(name=collection_name, dimension=embedding_model_dims)

    def _preprocess_filters(self, filters: Optional[dict] = None) -> Optional[dict]:
        """
        Preprocess filters to be compatible with vecs.

        Args:
            filters (Dict, optional): Filters to preprocess. Multiple filters will be
                combined with AND logic.
        """
        if filters is None:
            return None

        if len(filters) == 1:
            # For a single filter, keep the simple format.
            key, value = next(iter(filters.items()))
            return {key: {"$eq": value}}

        # For multiple filters, use an $and clause.
        return {"$and": [{key: {"$eq": value}} for key, value in filters.items()]}

    def create_col(self, embedding_model_dims: Optional[int] = None) -> None:
        """
        Create a new collection with vector support.
        Also initializes the vector search index.

        Args:
            embedding_model_dims (int, optional): Dimension of the embedding vector.
                If not provided, uses the dimension specified at initialization.
        """
        dims = embedding_model_dims or self.embedding_model_dims
        if not dims:
            raise ValueError(
                "embedding_model_dims must be provided either during initialization or when creating a collection"
            )

        logger.info(f"Creating new collection: {self.collection_name}")
        try:
            self.collection = self.db.get_or_create_collection(name=self.collection_name, dimension=dims)
            self.collection.create_index(method=self.index_method.value, measure=self.index_measure.value)
            logger.info(f"Successfully created collection {self.collection_name} with dimension {dims}")
        except Exception as e:
            logger.error(f"Failed to create collection: {str(e)}")
            raise

    def insert(
        self, vectors: List[List[float]], payloads: Optional[List[dict]] = None, ids: Optional[List[str]] = None
    ):
        """
        Insert vectors into the collection.

        Args:
            vectors (List[List[float]]): List of vectors to insert
            payloads (List[Dict], optional): List of payloads corresponding to vectors
            ids (List[str], optional): List of IDs corresponding to vectors
        """
        logger.info(f"Inserting {len(vectors)} vectors into collection {self.collection_name}")

        if not ids:
            ids = [str(uuid.uuid4()) for _ in vectors]
        if not payloads:
            payloads = [{} for _ in vectors]

        records = list(zip(ids, vectors, payloads))

        self.collection.upsert(records)

    def search(
        self, query: str, vectors: List[float], limit: int = 5, filters: Optional[dict] = None
    ) -> List[OutputData]:
        """
        Search for similar vectors.

        Args:
            query (str): Query string (unused; the search is purely vector-based).
            vectors (List[float]): Query vector.
            limit (int, optional): Number of results to return. Defaults to 5.
            filters (Dict, optional): Filters to apply to the search. Defaults to None.

        Returns:
            List[OutputData]: Search results
        """
        filters = self._preprocess_filters(filters)
        results = self.collection.query(
            data=vectors, limit=limit, filters=filters, include_metadata=True, include_value=True
        )

        return [OutputData(id=str(result[0]), score=float(result[1]), payload=result[2]) for result in results]

    def delete(self, vector_id: str):
        """
        Delete a vector by ID.

        Args:
            vector_id (str): ID of the vector to delete
        """
        # vecs expects plain string ids.
        self.collection.delete(ids=[vector_id])

    def update(self, vector_id: str, vector: Optional[List[float]] = None, payload: Optional[dict] = None):
        """
        Update a vector and/or its payload.

        Args:
            vector_id (str): ID of the vector to update
            vector (List[float], optional): Updated vector
            payload (Dict, optional): Updated payload
        """
        if vector is None:
            # If only updating metadata, we need the existing vector.
            existing = self.get(vector_id)
            if existing and existing.payload:
                vector = existing.payload.get("vector", [])

        if vector:
            self.collection.upsert([(vector_id, vector, payload or {})])

    def get(self, vector_id: str) -> Optional[OutputData]:
        """
        Retrieve a vector by ID.

        Args:
            vector_id (str): ID of the vector to retrieve

        Returns:
            Optional[OutputData]: Retrieved vector data, or None if not found
        """
        result = self.collection.fetch(ids=[vector_id])
        if not result:
            return None

        # vecs returns records as (id, vector, metadata) tuples.
        record = result[0]
        return OutputData(id=str(record[0]), score=None, payload=record[2])

    def list_cols(self) -> List[str]:
        """
        List all collections.

        Returns:
            List[str]: List of collection names
        """
        return self.db.list_collections()

    def delete_col(self):
        """Delete the collection."""
        self.db.delete_collection(self.collection_name)

    def col_info(self) -> dict:
        """
        Get information about the collection.

        Returns:
            Dict: Collection information, including name and configuration
        """
        info = self.collection.describe()
        return {
            "name": info.name,
            "count": info.vectors,
            "dimension": info.dimension,
            "index": {"method": info.index_method, "metric": info.distance_metric},
        }

    def list(self, filters: Optional[dict] = None, limit: int = 100) -> List[OutputData]:
        """
        List vectors in the collection.

        Args:
            filters (Dict, optional): Filters to apply
            limit (int, optional): Maximum number of results to return. Defaults to 100.

        Returns:
            List[OutputData]: List of vectors
        """
        filters = self._preprocess_filters(filters)
        # Query with a zero vector: only membership matters here, not ranking.
        query = [0] * self.embedding_model_dims
        ids = self.collection.query(
            data=query, limit=limit, filters=filters, include_metadata=True, include_value=False
        )
        ids = [id[0] for id in ids]
        records = self.collection.fetch(ids=ids)

        return [[OutputData(id=str(record[0]), score=None, payload=record[2]) for record in records]]

    def reset(self):
        """Reset the index by deleting and recreating it."""
        logger.warning(f"Resetting index {self.collection_name}...")
        self.delete_col()
        self.create_col(self.embedding_model_dims)
```