mem0ai-azure-mysql 0.1.115__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mem0/__init__.py +6 -0
- mem0/client/__init__.py +0 -0
- mem0/client/main.py +1535 -0
- mem0/client/project.py +860 -0
- mem0/client/utils.py +29 -0
- mem0/configs/__init__.py +0 -0
- mem0/configs/base.py +90 -0
- mem0/configs/dbs/__init__.py +4 -0
- mem0/configs/dbs/base.py +41 -0
- mem0/configs/dbs/mysql.py +25 -0
- mem0/configs/embeddings/__init__.py +0 -0
- mem0/configs/embeddings/base.py +108 -0
- mem0/configs/enums.py +7 -0
- mem0/configs/llms/__init__.py +0 -0
- mem0/configs/llms/base.py +152 -0
- mem0/configs/prompts.py +333 -0
- mem0/configs/vector_stores/__init__.py +0 -0
- mem0/configs/vector_stores/azure_ai_search.py +59 -0
- mem0/configs/vector_stores/baidu.py +29 -0
- mem0/configs/vector_stores/chroma.py +40 -0
- mem0/configs/vector_stores/elasticsearch.py +47 -0
- mem0/configs/vector_stores/faiss.py +39 -0
- mem0/configs/vector_stores/langchain.py +32 -0
- mem0/configs/vector_stores/milvus.py +43 -0
- mem0/configs/vector_stores/mongodb.py +25 -0
- mem0/configs/vector_stores/opensearch.py +41 -0
- mem0/configs/vector_stores/pgvector.py +37 -0
- mem0/configs/vector_stores/pinecone.py +56 -0
- mem0/configs/vector_stores/qdrant.py +49 -0
- mem0/configs/vector_stores/redis.py +26 -0
- mem0/configs/vector_stores/supabase.py +44 -0
- mem0/configs/vector_stores/upstash_vector.py +36 -0
- mem0/configs/vector_stores/vertex_ai_vector_search.py +27 -0
- mem0/configs/vector_stores/weaviate.py +43 -0
- mem0/dbs/__init__.py +4 -0
- mem0/dbs/base.py +68 -0
- mem0/dbs/configs.py +21 -0
- mem0/dbs/mysql.py +321 -0
- mem0/embeddings/__init__.py +0 -0
- mem0/embeddings/aws_bedrock.py +100 -0
- mem0/embeddings/azure_openai.py +43 -0
- mem0/embeddings/base.py +31 -0
- mem0/embeddings/configs.py +30 -0
- mem0/embeddings/gemini.py +39 -0
- mem0/embeddings/huggingface.py +41 -0
- mem0/embeddings/langchain.py +35 -0
- mem0/embeddings/lmstudio.py +29 -0
- mem0/embeddings/mock.py +11 -0
- mem0/embeddings/ollama.py +53 -0
- mem0/embeddings/openai.py +49 -0
- mem0/embeddings/together.py +31 -0
- mem0/embeddings/vertexai.py +54 -0
- mem0/graphs/__init__.py +0 -0
- mem0/graphs/configs.py +96 -0
- mem0/graphs/neptune/__init__.py +0 -0
- mem0/graphs/neptune/base.py +410 -0
- mem0/graphs/neptune/main.py +372 -0
- mem0/graphs/tools.py +371 -0
- mem0/graphs/utils.py +97 -0
- mem0/llms/__init__.py +0 -0
- mem0/llms/anthropic.py +64 -0
- mem0/llms/aws_bedrock.py +270 -0
- mem0/llms/azure_openai.py +114 -0
- mem0/llms/azure_openai_structured.py +76 -0
- mem0/llms/base.py +32 -0
- mem0/llms/configs.py +34 -0
- mem0/llms/deepseek.py +85 -0
- mem0/llms/gemini.py +201 -0
- mem0/llms/groq.py +88 -0
- mem0/llms/langchain.py +65 -0
- mem0/llms/litellm.py +87 -0
- mem0/llms/lmstudio.py +53 -0
- mem0/llms/ollama.py +94 -0
- mem0/llms/openai.py +124 -0
- mem0/llms/openai_structured.py +52 -0
- mem0/llms/sarvam.py +89 -0
- mem0/llms/together.py +88 -0
- mem0/llms/vllm.py +89 -0
- mem0/llms/xai.py +52 -0
- mem0/memory/__init__.py +0 -0
- mem0/memory/base.py +63 -0
- mem0/memory/graph_memory.py +632 -0
- mem0/memory/main.py +1843 -0
- mem0/memory/memgraph_memory.py +630 -0
- mem0/memory/setup.py +56 -0
- mem0/memory/storage.py +218 -0
- mem0/memory/telemetry.py +90 -0
- mem0/memory/utils.py +133 -0
- mem0/proxy/__init__.py +0 -0
- mem0/proxy/main.py +194 -0
- mem0/utils/factory.py +132 -0
- mem0/vector_stores/__init__.py +0 -0
- mem0/vector_stores/azure_ai_search.py +383 -0
- mem0/vector_stores/baidu.py +368 -0
- mem0/vector_stores/base.py +58 -0
- mem0/vector_stores/chroma.py +229 -0
- mem0/vector_stores/configs.py +60 -0
- mem0/vector_stores/elasticsearch.py +235 -0
- mem0/vector_stores/faiss.py +473 -0
- mem0/vector_stores/langchain.py +179 -0
- mem0/vector_stores/milvus.py +245 -0
- mem0/vector_stores/mongodb.py +293 -0
- mem0/vector_stores/opensearch.py +281 -0
- mem0/vector_stores/pgvector.py +294 -0
- mem0/vector_stores/pinecone.py +373 -0
- mem0/vector_stores/qdrant.py +240 -0
- mem0/vector_stores/redis.py +295 -0
- mem0/vector_stores/supabase.py +237 -0
- mem0/vector_stores/upstash_vector.py +293 -0
- mem0/vector_stores/vertex_ai_vector_search.py +629 -0
- mem0/vector_stores/weaviate.py +316 -0
- mem0ai_azure_mysql-0.1.115.data/data/README.md +169 -0
- mem0ai_azure_mysql-0.1.115.dist-info/METADATA +224 -0
- mem0ai_azure_mysql-0.1.115.dist-info/RECORD +116 -0
- mem0ai_azure_mysql-0.1.115.dist-info/WHEEL +4 -0
- mem0ai_azure_mysql-0.1.115.dist-info/licenses/LICENSE +201 -0
mem0/dbs/mysql.py
ADDED
@@ -0,0 +1,321 @@
import ssl
import logging
import threading
import uuid
from typing import Any, Dict, List, Optional

from sqlalchemy import create_engine, text
from sqlalchemy.engine import Engine
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import sessionmaker, Session

from mem0.configs.dbs.mysql import MySQLConfig
from mem0.dbs.base import DBBase

logger = logging.getLogger(__name__)


class MySQLManager(DBBase):
    """MySQL implementation of DBBase for managing memory history using SQLAlchemy."""

    def __init__(self, config: Optional[MySQLConfig] = None):
        super().__init__(config)
        if config is None:
            self.config = MySQLConfig()
        else:
            self.config = config

        self.engine: Optional[Engine] = None
        self.Session: Optional[sessionmaker] = None
        self._lock = threading.Lock()
        self._connect()
        self._migrate_history_table()
        self._create_history_table()

    def _connect(self) -> None:
        """Establish connection to MySQL database using SQLAlchemy."""
        try:
            # Build connection URL
            connection_url = self._build_connection_url()

            # Create engine with connection pooling
            connect_args = {}
            if hasattr(self.config, 'connection_params'):
                # Add valid MySQL connection parameters
                valid_params = {
                    'ssl_ca', 'ssl_cert', 'ssl_key', 'ssl_verify_cert', 'ssl_verify_identity',
                    'connect_timeout', 'charset', 'init_command'
                }
                for key, value in self.config.connection_params.items():
                    if key in valid_params:
                        connect_args[key] = value

            if self.config.ssl_enabled:
                connect_args['ssl'] = ssl.create_default_context()

            self.engine = create_engine(
                connection_url,
                connect_args=connect_args,
                pool_pre_ping=True,
                pool_recycle=3600,
                echo=False
            )

            self.Session = sessionmaker(bind=self.engine)

            # Test connection
            with self.engine.connect() as conn:
                conn.execute(text("SELECT 1"))

            logger.info("Successfully connected to MySQL database using SQLAlchemy")

        except SQLAlchemyError as e:
            logger.error(f"Failed to connect to MySQL: {e}")
            raise

    def _build_connection_url(self) -> str:
        """Build SQLAlchemy connection URL for MySQL."""
        # Use PyMySQL as the driver (more reliable than mysqlclient)
        url_parts = ["mysql+pymysql://"]

        if self.config.user:
            url_parts.append(self.config.user)
        if self.config.password:
            url_parts.append(f":{self.config.password}")
        url_parts.append("@")

        url_parts.append(self.config.host or "localhost")

        if self.config.port:
            url_parts.append(f":{self.config.port}")

        if self.config.database:
            url_parts.append(f"/{self.config.database}")

        return "".join(url_parts)

    def _migrate_history_table(self) -> None:
        """
        If a pre-existing history table had the old schema,
        rename it, create the new schema, copy the intersecting data, then
        drop the old table.
        """
        with self._lock:
            if self.engine is None:
                raise RuntimeError("Database connection is not established")
            try:
                with self.engine.begin() as conn:
                    # Check if history table exists
                    result = conn.execute(text("""
                        SELECT COUNT(*)
                        FROM information_schema.tables
                        WHERE table_schema = DATABASE()
                        AND table_name = 'history'
                    """))

                    count = result.scalar()
                    if count == 0:
                        return  # nothing to migrate

                    # Get current table columns
                    result = conn.execute(text("""
                        SELECT column_name
                        FROM information_schema.columns
                        WHERE table_schema = DATABASE()
                        AND table_name = 'history'
                    """))
                    old_cols = {row[0] for row in result.fetchall()}

                    expected_cols = {
                        "id",
                        "memory_id",
                        "old_memory",
                        "new_memory",
                        "event",
                        "created_at",
                        "updated_at",
                        "is_deleted",
                        "actor_id",
                        "role",
                    }

                    if old_cols == expected_cols:
                        return

                    logger.info("Migrating history table to new schema (no convo columns).")

                    # Clean up any existing history_old table from a previous failed migration
                    conn.execute(text("DROP TABLE IF EXISTS history_old"))

                    # Rename the current history table
                    conn.execute(text("ALTER TABLE history RENAME TO history_old"))

                    # Create the new history table with the updated schema
                    conn.execute(text("""
                        CREATE TABLE history (
                            id VARCHAR(36) PRIMARY KEY,
                            memory_id TEXT,
                            old_memory TEXT,
                            new_memory TEXT,
                            event TEXT,
                            created_at TIMESTAMP NULL,
                            updated_at TIMESTAMP NULL,
                            is_deleted INT,
                            actor_id TEXT,
                            role TEXT
                        )
                    """))

                    # Copy the intersecting columns from the old table to the new one
                    intersecting = list(expected_cols & old_cols)
                    if intersecting:
                        cols_str = ", ".join(f"`{col}`" for col in intersecting)
                        query = f"INSERT INTO history ({cols_str}) SELECT {cols_str} FROM history_old"
                        conn.execute(text(query))

                    # Drop the old table
                    conn.execute(text("DROP TABLE history_old"))

                    logger.info("History table migration completed successfully.")

            except SQLAlchemyError as e:
                logger.error(f"History table migration failed: {e}")
                raise

    def _create_history_table(self) -> None:
        """Create the history table if it doesn't exist."""
        with self._lock:
            if self.engine is None:
                raise RuntimeError("Database connection is not established")
            try:
                with self.engine.begin() as conn:
                    conn.execute(text("""
                        CREATE TABLE IF NOT EXISTS history (
                            id VARCHAR(36) PRIMARY KEY,
                            memory_id TEXT,
                            old_memory TEXT,
                            new_memory TEXT,
                            event TEXT,
                            created_at TIMESTAMP NULL,
                            updated_at TIMESTAMP NULL,
                            is_deleted INT,
                            actor_id TEXT,
                            role TEXT
                        )
                    """))

            except SQLAlchemyError as e:
                logger.error(f"Failed to create history table: {e}")
                raise

    def add_history(
        self,
        memory_id: str,
        old_memory: Optional[str],
        new_memory: Optional[str],
        event: str,
        *,
        created_at: Optional[str] = None,
        updated_at: Optional[str] = None,
        is_deleted: int = 0,
        actor_id: Optional[str] = None,
        role: Optional[str] = None,
    ) -> None:
        """Add a history record to the database.

        :param memory_id: The ID of the memory being tracked
        :param old_memory: The previous memory content
        :param new_memory: The new memory content
        :param event: The type of event that occurred
        :param created_at: When the record was created
        :param updated_at: When the record was last updated
        :param is_deleted: Whether the record is deleted (0 or 1)
        :param actor_id: ID of the actor who made the change
        :param role: Role of the actor
        """
        with self._lock:
            if self.engine is None:
                raise RuntimeError("Database connection is not established")
            try:
                with self.engine.begin() as conn:
                    conn.execute(text("""
                        INSERT INTO history (
                            id, memory_id, old_memory, new_memory, event,
                            created_at, updated_at, is_deleted, actor_id, role
                        )
                        VALUES (:id, :memory_id, :old_memory, :new_memory, :event,
                                :created_at, :updated_at, :is_deleted, :actor_id, :role)
                    """), {
                        "id": str(uuid.uuid4()),
                        "memory_id": memory_id,
                        "old_memory": old_memory,
                        "new_memory": new_memory,
                        "event": event,
                        "created_at": created_at,
                        "updated_at": updated_at,
                        "is_deleted": is_deleted,
                        "actor_id": actor_id,
                        "role": role,
                    })

            except SQLAlchemyError as e:
                logger.error(f"Failed to add history record: {e}")
                raise

    def get_history(self, memory_id: str) -> List[Dict[str, Any]]:
        """Retrieve history records for a given memory ID.

        :param memory_id: The ID of the memory to get history for
        :return: List of history records as dictionaries
        """
        with self._lock:
            if self.engine is None:
                raise RuntimeError("Database connection is not established")

            with self.engine.connect() as conn:
                result = conn.execute(text("""
                    SELECT id, memory_id, old_memory, new_memory, event,
                           created_at, updated_at, is_deleted, actor_id, role
                    FROM history
                    WHERE memory_id = :memory_id
                    ORDER BY created_at ASC, updated_at ASC
                """), {"memory_id": memory_id})

                rows = result.fetchall()

                return [
                    {
                        "id": r.id,
                        "memory_id": r.memory_id,
                        "old_memory": r.old_memory,
                        "new_memory": r.new_memory,
                        "event": r.event,
                        "created_at": r.created_at,
                        "updated_at": r.updated_at,
                        "is_deleted": bool(r.is_deleted),
                        "actor_id": r.actor_id,
                        "role": r.role,
                    }
                    for r in rows
                ]
    def reset(self) -> None:
        """Reset/clear all data in the database."""
        with self._lock:
            if self.engine is None:
                raise RuntimeError("Database connection is not established")
            try:
                with self.engine.begin() as conn:
                    conn.execute(text("DROP TABLE IF EXISTS history"))
            except SQLAlchemyError as e:
                logger.error(f"Failed to reset history table: {e}")
                raise
        # Recreate outside the lock: _create_history_table acquires the same
        # non-reentrant threading.Lock, so calling it while the lock is held
        # would deadlock.
        self._create_history_table()

    def close(self) -> None:
        """Close the database connection and clean up resources."""
        if self.engine:
            self.engine.dispose()
            self.engine = None
        self.Session = None
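
For orientation, a minimal usage sketch for MySQLManager. The MySQLConfig keyword arguments shown are assumptions inferred from the attributes _build_connection_url reads (user, password, host, port, database), and the connection values are illustrative:

from mem0.configs.dbs.mysql import MySQLConfig
from mem0.dbs.mysql import MySQLManager

# Illustrative settings; field names inferred from _build_connection_url.
config = MySQLConfig(
    host="localhost",
    port=3306,
    user="mem0",
    password="secret",
    database="mem0",
)
db = MySQLManager(config)  # connects, migrates, and creates the history table

# Record an ADD event for a memory, then read its history back.
db.add_history(
    memory_id="mem-1",
    old_memory=None,
    new_memory="User prefers dark mode",
    event="ADD",
)
for record in db.get_history("mem-1"):
    print(record["event"], record["new_memory"])

db.close()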

mem0/embeddings/__init__.py
File without changes

mem0/embeddings/aws_bedrock.py
ADDED
@@ -0,0 +1,100 @@
import json
import os
from typing import Literal, Optional

try:
    import boto3
except ImportError:
    raise ImportError("The 'boto3' library is required. Please install it using 'pip install boto3'.")

import numpy as np

from mem0.configs.embeddings.base import BaseEmbedderConfig
from mem0.embeddings.base import EmbeddingBase


class AWSBedrockEmbedding(EmbeddingBase):
    """AWS Bedrock embedding implementation.

    This class uses AWS Bedrock's embedding models.
    """

    def __init__(self, config: Optional[BaseEmbedderConfig] = None):
        super().__init__(config)

        self.config.model = self.config.model or "amazon.titan-embed-text-v1"

        # Get AWS config from environment variables or use defaults
        aws_access_key = os.environ.get("AWS_ACCESS_KEY_ID", "")
        aws_secret_key = os.environ.get("AWS_SECRET_ACCESS_KEY", "")
        aws_session_token = os.environ.get("AWS_SESSION_TOKEN", "")
        aws_region = os.environ.get("AWS_REGION", "us-west-2")

        # Check if AWS config is provided in the config
        if hasattr(self.config, "aws_access_key_id"):
            aws_access_key = self.config.aws_access_key_id
        if hasattr(self.config, "aws_secret_access_key"):
            aws_secret_key = self.config.aws_secret_access_key
        if hasattr(self.config, "aws_region"):
            aws_region = self.config.aws_region

        self.client = boto3.client(
            "bedrock-runtime",
            region_name=aws_region,
            aws_access_key_id=aws_access_key if aws_access_key else None,
            aws_secret_access_key=aws_secret_key if aws_secret_key else None,
            aws_session_token=aws_session_token if aws_session_token else None,
        )

    def _normalize_vector(self, embeddings):
        """Normalize the embedding to a unit vector."""
        emb = np.array(embeddings)
        norm_emb = emb / np.linalg.norm(emb)
        return norm_emb.tolist()

    def _get_embedding(self, text):
        """Call out to Bedrock embedding endpoint."""

        # Format input body based on the provider
        provider = self.config.model.split(".")[0]
        input_body = {}

        if provider == "cohere":
            input_body["input_type"] = "search_document"
            input_body["texts"] = [text]
        else:
            # Amazon and other providers
            input_body["inputText"] = text

        body = json.dumps(input_body)

        try:
            response = self.client.invoke_model(
                body=body,
                modelId=self.config.model,
                accept="application/json",
                contentType="application/json",
            )

            response_body = json.loads(response.get("body").read())

            if provider == "cohere":
                embeddings = response_body.get("embeddings")[0]
            else:
                embeddings = response_body.get("embedding")

            return embeddings
        except Exception as e:
            raise ValueError(f"Error getting embedding from AWS Bedrock: {e}")

    def embed(self, text, memory_action: Optional[Literal["add", "search", "update"]] = None):
        """
        Get the embedding for the given text using AWS Bedrock.

        Args:
            text (str): The text to embed.
            memory_action (optional): The type of embedding to use. Must be one of "add", "search", or "update". Defaults to None.
        Returns:
            list: The embedding vector.
        """
        return self._get_embedding(text)
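
A hedged usage sketch for the Bedrock embedder above. It assumes AWS credentials are already available in the environment (AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, optionally AWS_SESSION_TOKEN and AWS_REGION); note how the provider prefix of the model ID ("amazon." vs. "cohere.") selects the request and response shapes:

from mem0.configs.embeddings.base import BaseEmbedderConfig
from mem0.embeddings.aws_bedrock import AWSBedrockEmbedding

# "amazon.titan-embed-text-v1".split(".")[0] == "amazon" -> {"inputText": ...}
# a "cohere.*" model would send {"input_type": "search_document", "texts": [...]}
embedder = AWSBedrockEmbedding(BaseEmbedderConfig(model="amazon.titan-embed-text-v1"))
vector = embedder.embed("hello world")
print(len(vector))  # titan-embed-text-v1 produces 1536-dimension vectors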

mem0/embeddings/azure_openai.py
ADDED
@@ -0,0 +1,43 @@
import os
from typing import Literal, Optional

from openai import AzureOpenAI
from azure.identity import DefaultAzureCredential, get_bearer_token_provider

from mem0.configs.embeddings.base import BaseEmbedderConfig
from mem0.embeddings.base import EmbeddingBase


class AzureOpenAIEmbedding(EmbeddingBase):
    def __init__(self, config: Optional[BaseEmbedderConfig] = None):
        super().__init__(config)

        azure_deployment = self.config.azure_kwargs.azure_deployment or os.getenv("EMBEDDING_AZURE_DEPLOYMENT")
        azure_endpoint = self.config.azure_kwargs.azure_endpoint or os.getenv("EMBEDDING_AZURE_ENDPOINT")
        api_version = self.config.azure_kwargs.api_version or os.getenv("EMBEDDING_AZURE_API_VERSION")
        default_headers = self.config.azure_kwargs.default_headers

        credential = DefaultAzureCredential()
        token_provider = get_bearer_token_provider(credential, "https://cognitiveservices.azure.com/.default")

        self.client = AzureOpenAI(
            azure_deployment=azure_deployment,
            azure_endpoint=azure_endpoint,
            api_version=api_version,
            azure_ad_token_provider=token_provider,
            http_client=self.config.http_client,
            default_headers=default_headers,
        )

    def embed(self, text, memory_action: Optional[Literal["add", "search", "update"]] = None):
        """
        Get the embedding for the given text using Azure OpenAI.

        Args:
            text (str): The text to embed.
            memory_action (optional): The type of embedding to use. Must be one of "add", "search", or "update". Defaults to None.
        Returns:
            list: The embedding vector.
        """
        text = text.replace("\n", " ")
        return self.client.embeddings.create(input=[text], model=self.config.model).data[0].embedding
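
Note that the constructor above is keyless: DefaultAzureCredential resolves an identity from the environment (az login, managed identity, or service-principal variables), so no api_key is passed to AzureOpenAI. A configuration sketch with illustrative endpoint, deployment, and API-version values, assuming BaseEmbedderConfig carries the azure_kwargs fields the constructor reads:

import os

from mem0.configs.embeddings.base import BaseEmbedderConfig
from mem0.embeddings.azure_openai import AzureOpenAIEmbedding

# Illustrative values; the constructor falls back to these env vars when the
# corresponding azure_kwargs fields are unset.
os.environ.setdefault("EMBEDDING_AZURE_ENDPOINT", "https://my-resource.openai.azure.com")
os.environ.setdefault("EMBEDDING_AZURE_DEPLOYMENT", "text-embedding-3-small")
os.environ.setdefault("EMBEDDING_AZURE_API_VERSION", "2024-02-01")

embedder = AzureOpenAIEmbedding(BaseEmbedderConfig(model="text-embedding-3-small"))
vector = embedder.embed("hello world")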
mem0/embeddings/base.py
ADDED
@@ -0,0 +1,31 @@
from abc import ABC, abstractmethod
from typing import Literal, Optional

from mem0.configs.embeddings.base import BaseEmbedderConfig


class EmbeddingBase(ABC):
    """Initialize a base embedding class.

    :param config: Embedding configuration option class, defaults to None
    :type config: Optional[BaseEmbedderConfig], optional
    """

    def __init__(self, config: Optional[BaseEmbedderConfig] = None):
        if config is None:
            self.config = BaseEmbedderConfig()
        else:
            self.config = config

    @abstractmethod
    def embed(self, text, memory_action: Optional[Literal["add", "search", "update"]]):
        """
        Get the embedding for the given text.

        Args:
            text (str): The text to embed.
            memory_action (optional): The type of embedding to use. Must be one of "add", "search", or "update". Defaults to None.
        Returns:
            list: The embedding vector.
        """
        pass
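
To make the contract concrete, a hypothetical minimal subclass (similar in spirit to mem0/embeddings/mock.py listed above); the fixed vector and the 8-dimension fallback are illustrative, not part of the package:

from mem0.embeddings.base import EmbeddingBase


class ConstantEmbedding(EmbeddingBase):
    """Hypothetical embedder returning a fixed vector; handy in tests."""

    def embed(self, text, memory_action=None):
        # Fall back to 8 dimensions when the config leaves embedding_dims unset.
        dims = self.config.embedding_dims or 8
        return [0.0] * dims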

mem0/embeddings/configs.py
ADDED
@@ -0,0 +1,30 @@
from typing import Optional

from pydantic import BaseModel, Field, field_validator


class EmbedderConfig(BaseModel):
    provider: str = Field(
        description="Provider of the embedding model (e.g., 'ollama', 'openai')",
        default="openai",
    )
    config: Optional[dict] = Field(description="Configuration for the specific embedding model", default={})

    @field_validator("config")
    def validate_config(cls, v, values):
        provider = values.data.get("provider")
        if provider in [
            "openai",
            "ollama",
            "huggingface",
            "azure_openai",
            "gemini",
            "vertexai",
            "together",
            "lmstudio",
            "langchain",
            "aws_bedrock",
        ]:
            return v
        else:
            raise ValueError(f"Unsupported embedding provider: {provider}")
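
A short sketch of the validator's behavior. Because the validator is attached to the config field, the check fires when config is supplied, and pydantic surfaces the ValueError as a ValidationError:

from pydantic import ValidationError

from mem0.embeddings.configs import EmbedderConfig

ok = EmbedderConfig(provider="openai", config={"model": "text-embedding-3-small"})

try:
    EmbedderConfig(provider="not-a-provider", config={})
except ValidationError as err:
    print(err)  # ... Unsupported embedding provider: not-a-provider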

mem0/embeddings/gemini.py
ADDED
@@ -0,0 +1,39 @@
import os
from typing import Literal, Optional

from google import genai
from google.genai import types

from mem0.configs.embeddings.base import BaseEmbedderConfig
from mem0.embeddings.base import EmbeddingBase


class GoogleGenAIEmbedding(EmbeddingBase):
    def __init__(self, config: Optional[BaseEmbedderConfig] = None):
        super().__init__(config)

        self.config.model = self.config.model or "models/text-embedding-004"
        self.config.embedding_dims = self.config.embedding_dims or self.config.output_dimensionality or 768

        api_key = self.config.api_key or os.getenv("GOOGLE_API_KEY")

        self.client = genai.Client(api_key=api_key)

    def embed(self, text, memory_action: Optional[Literal["add", "search", "update"]] = None):
        """
        Get the embedding for the given text using Google Generative AI.

        Args:
            text (str): The text to embed.
            memory_action (optional): The type of embedding to use. Must be one of "add", "search", or "update". Defaults to None.
        Returns:
            list: The embedding vector.
        """
        text = text.replace("\n", " ")

        # Create config for embedding parameters
        config = types.EmbedContentConfig(output_dimensionality=self.config.embedding_dims)

        # Call the embed_content method with the correct parameters
        response = self.client.models.embed_content(model=self.config.model, contents=text, config=config)

        return response.embeddings[0].values

mem0/embeddings/huggingface.py
ADDED
@@ -0,0 +1,41 @@
import logging
from typing import Literal, Optional

from openai import OpenAI
from sentence_transformers import SentenceTransformer

from mem0.configs.embeddings.base import BaseEmbedderConfig
from mem0.embeddings.base import EmbeddingBase

logging.getLogger("transformers").setLevel(logging.WARNING)
logging.getLogger("sentence_transformers").setLevel(logging.WARNING)
logging.getLogger("huggingface_hub").setLevel(logging.WARNING)


class HuggingFaceEmbedding(EmbeddingBase):
    def __init__(self, config: Optional[BaseEmbedderConfig] = None):
        super().__init__(config)

        # Use self.config here, not the raw `config` argument: the base class
        # substitutes a default config when config is None, so reading the
        # argument directly would raise AttributeError in that case.
        if self.config.huggingface_base_url:
            self.client = OpenAI(base_url=self.config.huggingface_base_url)
        else:
            self.config.model = self.config.model or "multi-qa-MiniLM-L6-cos-v1"

            self.model = SentenceTransformer(self.config.model, **self.config.model_kwargs)

            self.config.embedding_dims = self.config.embedding_dims or self.model.get_sentence_embedding_dimension()

    def embed(self, text, memory_action: Optional[Literal["add", "search", "update"]] = None):
        """
        Get the embedding for the given text using Hugging Face.

        Args:
            text (str): The text to embed.
            memory_action (optional): The type of embedding to use. Must be one of "add", "search", or "update". Defaults to None.
        Returns:
            list: The embedding vector.
        """
        if self.config.huggingface_base_url:
            return self.client.embeddings.create(input=text, model="tei").data[0].embedding
        else:
            return self.model.encode(text, convert_to_numpy=True).tolist()
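
The class above runs in one of two modes, chosen by huggingface_base_url; a sketch of each, with an illustrative model name and endpoint URL, assuming BaseEmbedderConfig exposes the huggingface_base_url and model_kwargs fields read by the constructor:

from mem0.configs.embeddings.base import BaseEmbedderConfig
from mem0.embeddings.huggingface import HuggingFaceEmbedding

# Local mode: loads a sentence-transformers model and encodes in-process.
local = HuggingFaceEmbedding(BaseEmbedderConfig(model="multi-qa-MiniLM-L6-cos-v1"))

# Server mode: talks to a text-embeddings-inference (TEI) endpoint through the
# OpenAI-compatible client; the URL is illustrative.
remote = HuggingFaceEmbedding(
    BaseEmbedderConfig(huggingface_base_url="http://localhost:8080/v1")
)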

mem0/embeddings/langchain.py
ADDED
@@ -0,0 +1,35 @@
from typing import Literal, Optional

from mem0.configs.embeddings.base import BaseEmbedderConfig
from mem0.embeddings.base import EmbeddingBase

try:
    from langchain.embeddings.base import Embeddings
except ImportError:
    raise ImportError("langchain is not installed. Please install it using `pip install langchain`")


class LangchainEmbedding(EmbeddingBase):
    def __init__(self, config: Optional[BaseEmbedderConfig] = None):
        super().__init__(config)

        if self.config.model is None:
            raise ValueError("`model` parameter is required")

        if not isinstance(self.config.model, Embeddings):
            raise ValueError("`model` must be an instance of Embeddings")

        self.langchain_model = self.config.model

    def embed(self, text, memory_action: Optional[Literal["add", "search", "update"]] = None):
        """
        Get the embedding for the given text using Langchain.

        Args:
            text (str): The text to embed.
            memory_action (optional): The type of embedding to use. Must be one of "add", "search", or "update". Defaults to None.
        Returns:
            list: The embedding vector.
        """

        return self.langchain_model.embed_query(text)
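
A usage sketch for the wrapper above; it assumes the langchain-openai package is installed, and the key point is that model is a live Embeddings instance rather than a model-name string:

from langchain_openai import OpenAIEmbeddings  # assumed installed

from mem0.configs.embeddings.base import BaseEmbedderConfig
from mem0.embeddings.langchain import LangchainEmbedding

# Pass an Embeddings instance, not a string; the isinstance check enforces this.
config = BaseEmbedderConfig(model=OpenAIEmbeddings(model="text-embedding-3-small"))
embedder = LangchainEmbedding(config)
vector = embedder.embed("hello world")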

mem0/embeddings/lmstudio.py
ADDED
@@ -0,0 +1,29 @@
from typing import Literal, Optional

from openai import OpenAI

from mem0.configs.embeddings.base import BaseEmbedderConfig
from mem0.embeddings.base import EmbeddingBase


class LMStudioEmbedding(EmbeddingBase):
    def __init__(self, config: Optional[BaseEmbedderConfig] = None):
        super().__init__(config)

        self.config.model = self.config.model or "nomic-ai/nomic-embed-text-v1.5-GGUF/nomic-embed-text-v1.5.f16.gguf"
        self.config.embedding_dims = self.config.embedding_dims or 1536
        self.config.api_key = self.config.api_key or "lm-studio"

        self.client = OpenAI(base_url=self.config.lmstudio_base_url, api_key=self.config.api_key)

    def embed(self, text, memory_action: Optional[Literal["add", "search", "update"]] = None):
        """
        Get the embedding for the given text using LM Studio.

        Args:
            text (str): The text to embed.
            memory_action (optional): The type of embedding to use. Must be one of "add", "search", or "update". Defaults to None.
        Returns:
            list: The embedding vector.
        """
        text = text.replace("\n", " ")
        return self.client.embeddings.create(input=[text], model=self.config.model).data[0].embedding