mem0ai-azure-mysql 0.1.115__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mem0/__init__.py +6 -0
- mem0/client/__init__.py +0 -0
- mem0/client/main.py +1535 -0
- mem0/client/project.py +860 -0
- mem0/client/utils.py +29 -0
- mem0/configs/__init__.py +0 -0
- mem0/configs/base.py +90 -0
- mem0/configs/dbs/__init__.py +4 -0
- mem0/configs/dbs/base.py +41 -0
- mem0/configs/dbs/mysql.py +25 -0
- mem0/configs/embeddings/__init__.py +0 -0
- mem0/configs/embeddings/base.py +108 -0
- mem0/configs/enums.py +7 -0
- mem0/configs/llms/__init__.py +0 -0
- mem0/configs/llms/base.py +152 -0
- mem0/configs/prompts.py +333 -0
- mem0/configs/vector_stores/__init__.py +0 -0
- mem0/configs/vector_stores/azure_ai_search.py +59 -0
- mem0/configs/vector_stores/baidu.py +29 -0
- mem0/configs/vector_stores/chroma.py +40 -0
- mem0/configs/vector_stores/elasticsearch.py +47 -0
- mem0/configs/vector_stores/faiss.py +39 -0
- mem0/configs/vector_stores/langchain.py +32 -0
- mem0/configs/vector_stores/milvus.py +43 -0
- mem0/configs/vector_stores/mongodb.py +25 -0
- mem0/configs/vector_stores/opensearch.py +41 -0
- mem0/configs/vector_stores/pgvector.py +37 -0
- mem0/configs/vector_stores/pinecone.py +56 -0
- mem0/configs/vector_stores/qdrant.py +49 -0
- mem0/configs/vector_stores/redis.py +26 -0
- mem0/configs/vector_stores/supabase.py +44 -0
- mem0/configs/vector_stores/upstash_vector.py +36 -0
- mem0/configs/vector_stores/vertex_ai_vector_search.py +27 -0
- mem0/configs/vector_stores/weaviate.py +43 -0
- mem0/dbs/__init__.py +4 -0
- mem0/dbs/base.py +68 -0
- mem0/dbs/configs.py +21 -0
- mem0/dbs/mysql.py +321 -0
- mem0/embeddings/__init__.py +0 -0
- mem0/embeddings/aws_bedrock.py +100 -0
- mem0/embeddings/azure_openai.py +43 -0
- mem0/embeddings/base.py +31 -0
- mem0/embeddings/configs.py +30 -0
- mem0/embeddings/gemini.py +39 -0
- mem0/embeddings/huggingface.py +41 -0
- mem0/embeddings/langchain.py +35 -0
- mem0/embeddings/lmstudio.py +29 -0
- mem0/embeddings/mock.py +11 -0
- mem0/embeddings/ollama.py +53 -0
- mem0/embeddings/openai.py +49 -0
- mem0/embeddings/together.py +31 -0
- mem0/embeddings/vertexai.py +54 -0
- mem0/graphs/__init__.py +0 -0
- mem0/graphs/configs.py +96 -0
- mem0/graphs/neptune/__init__.py +0 -0
- mem0/graphs/neptune/base.py +410 -0
- mem0/graphs/neptune/main.py +372 -0
- mem0/graphs/tools.py +371 -0
- mem0/graphs/utils.py +97 -0
- mem0/llms/__init__.py +0 -0
- mem0/llms/anthropic.py +64 -0
- mem0/llms/aws_bedrock.py +270 -0
- mem0/llms/azure_openai.py +114 -0
- mem0/llms/azure_openai_structured.py +76 -0
- mem0/llms/base.py +32 -0
- mem0/llms/configs.py +34 -0
- mem0/llms/deepseek.py +85 -0
- mem0/llms/gemini.py +201 -0
- mem0/llms/groq.py +88 -0
- mem0/llms/langchain.py +65 -0
- mem0/llms/litellm.py +87 -0
- mem0/llms/lmstudio.py +53 -0
- mem0/llms/ollama.py +94 -0
- mem0/llms/openai.py +124 -0
- mem0/llms/openai_structured.py +52 -0
- mem0/llms/sarvam.py +89 -0
- mem0/llms/together.py +88 -0
- mem0/llms/vllm.py +89 -0
- mem0/llms/xai.py +52 -0
- mem0/memory/__init__.py +0 -0
- mem0/memory/base.py +63 -0
- mem0/memory/graph_memory.py +632 -0
- mem0/memory/main.py +1843 -0
- mem0/memory/memgraph_memory.py +630 -0
- mem0/memory/setup.py +56 -0
- mem0/memory/storage.py +218 -0
- mem0/memory/telemetry.py +90 -0
- mem0/memory/utils.py +133 -0
- mem0/proxy/__init__.py +0 -0
- mem0/proxy/main.py +194 -0
- mem0/utils/factory.py +132 -0
- mem0/vector_stores/__init__.py +0 -0
- mem0/vector_stores/azure_ai_search.py +383 -0
- mem0/vector_stores/baidu.py +368 -0
- mem0/vector_stores/base.py +58 -0
- mem0/vector_stores/chroma.py +229 -0
- mem0/vector_stores/configs.py +60 -0
- mem0/vector_stores/elasticsearch.py +235 -0
- mem0/vector_stores/faiss.py +473 -0
- mem0/vector_stores/langchain.py +179 -0
- mem0/vector_stores/milvus.py +245 -0
- mem0/vector_stores/mongodb.py +293 -0
- mem0/vector_stores/opensearch.py +281 -0
- mem0/vector_stores/pgvector.py +294 -0
- mem0/vector_stores/pinecone.py +373 -0
- mem0/vector_stores/qdrant.py +240 -0
- mem0/vector_stores/redis.py +295 -0
- mem0/vector_stores/supabase.py +237 -0
- mem0/vector_stores/upstash_vector.py +293 -0
- mem0/vector_stores/vertex_ai_vector_search.py +629 -0
- mem0/vector_stores/weaviate.py +316 -0
- mem0ai_azure_mysql-0.1.115.data/data/README.md +169 -0
- mem0ai_azure_mysql-0.1.115.dist-info/METADATA +224 -0
- mem0ai_azure_mysql-0.1.115.dist-info/RECORD +116 -0
- mem0ai_azure_mysql-0.1.115.dist-info/WHEEL +4 -0
- mem0ai_azure_mysql-0.1.115.dist-info/licenses/LICENSE +201 -0
mem0/memory/storage.py
ADDED
|
@@ -0,0 +1,218 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import sqlite3
|
|
3
|
+
import threading
|
|
4
|
+
import uuid
|
|
5
|
+
from typing import Any, Dict, List, Optional
|
|
6
|
+
|
|
7
|
+
# Module-level logger for this file; configuration is left to the host application.
logger = logging.getLogger(__name__)
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class SQLiteManager:
    """Thread-safe SQLite store for memory-change history records.

    A single connection opened with ``check_same_thread=False`` is shared
    across threads; all access is serialized through an internal
    (non-reentrant) ``threading.Lock``.
    """

    # Single source of truth for the history schema, shared by table
    # creation, migration, and reset so the three can never drift apart.
    _HISTORY_DDL = """
        CREATE TABLE IF NOT EXISTS history (
            id TEXT PRIMARY KEY,
            memory_id TEXT,
            old_memory TEXT,
            new_memory TEXT,
            event TEXT,
            created_at DATETIME,
            updated_at DATETIME,
            is_deleted INTEGER,
            actor_id TEXT,
            role TEXT
        )
    """

    def __init__(self, db_path: str = ":memory:"):
        self.db_path = db_path
        # check_same_thread=False is safe here because every operation
        # acquires self._lock before touching the connection.
        self.connection = sqlite3.connect(self.db_path, check_same_thread=False)
        self._lock = threading.Lock()
        self._migrate_history_table()
        self._create_history_table()

    def _migrate_history_table(self) -> None:
        """
        If a pre-existing history table had the old group-chat columns,
        rename it, create the new schema, copy the intersecting data, then
        drop the old table.
        """
        with self._lock:
            try:
                self.connection.execute("BEGIN")
                cur = self.connection.cursor()

                cur.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='history'")
                if cur.fetchone() is None:
                    self.connection.execute("COMMIT")
                    return  # nothing to migrate

                cur.execute("PRAGMA table_info(history)")
                old_cols = {row[1] for row in cur.fetchall()}

                expected_cols = {
                    "id",
                    "memory_id",
                    "old_memory",
                    "new_memory",
                    "event",
                    "created_at",
                    "updated_at",
                    "is_deleted",
                    "actor_id",
                    "role",
                }

                if old_cols == expected_cols:
                    # Already on the current schema.
                    self.connection.execute("COMMIT")
                    return

                logger.info("Migrating history table to new schema (no convo columns).")

                # Clean up any existing history_old table from a previous
                # failed migration, then move the live table aside.
                cur.execute("DROP TABLE IF EXISTS history_old")
                cur.execute("ALTER TABLE history RENAME TO history_old")

                # Create the new history table with the current schema.
                cur.execute(self._HISTORY_DDL)

                # Copy over only the columns both schemas share.
                intersecting = list(expected_cols & old_cols)
                if intersecting:
                    cols_csv = ", ".join(intersecting)
                    cur.execute(f"INSERT INTO history ({cols_csv}) SELECT {cols_csv} FROM history_old")

                cur.execute("DROP TABLE history_old")

                self.connection.execute("COMMIT")
                logger.info("History table migration completed successfully.")

            except Exception as e:
                # Roll back so a failed migration leaves the old table intact.
                self.connection.execute("ROLLBACK")
                logger.error(f"History table migration failed: {e}")
                raise

    def _create_history_table(self) -> None:
        """Create the history table if it does not already exist."""
        with self._lock:
            try:
                self.connection.execute("BEGIN")
                self.connection.execute(self._HISTORY_DDL)
                self.connection.execute("COMMIT")
            except Exception as e:
                self.connection.execute("ROLLBACK")
                logger.error(f"Failed to create history table: {e}")
                raise

    def add_history(
        self,
        memory_id: str,
        old_memory: Optional[str],
        new_memory: Optional[str],
        event: str,
        *,
        created_at: Optional[str] = None,
        updated_at: Optional[str] = None,
        is_deleted: int = 0,
        actor_id: Optional[str] = None,
        role: Optional[str] = None,
    ) -> None:
        """Append one history record; the row id is a fresh UUID4."""
        with self._lock:
            try:
                self.connection.execute("BEGIN")
                self.connection.execute(
                    """
                    INSERT INTO history (
                        id, memory_id, old_memory, new_memory, event,
                        created_at, updated_at, is_deleted, actor_id, role
                    )
                    VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                    """,
                    (
                        str(uuid.uuid4()),
                        memory_id,
                        old_memory,
                        new_memory,
                        event,
                        created_at,
                        updated_at,
                        is_deleted,
                        actor_id,
                        role,
                    ),
                )
                self.connection.execute("COMMIT")
            except Exception as e:
                self.connection.execute("ROLLBACK")
                logger.error(f"Failed to add history record: {e}")
                raise

    def get_history(self, memory_id: str) -> List[Dict[str, Any]]:
        """Return all history rows for *memory_id*, oldest first, as dicts."""
        with self._lock:
            cur = self.connection.execute(
                """
                SELECT id, memory_id, old_memory, new_memory, event,
                       created_at, updated_at, is_deleted, actor_id, role
                FROM history
                WHERE memory_id = ?
                ORDER BY created_at ASC, DATETIME(updated_at) ASC
                """,
                (memory_id,),
            )
            rows = cur.fetchall()

        return [
            {
                "id": r[0],
                "memory_id": r[1],
                "old_memory": r[2],
                "new_memory": r[3],
                "event": r[4],
                "created_at": r[5],
                "updated_at": r[6],
                "is_deleted": bool(r[7]),
                "actor_id": r[8],
                "role": r[9],
            }
            for r in rows
        ]

    def reset(self) -> None:
        """Drop and recreate the history table."""
        with self._lock:
            try:
                self.connection.execute("BEGIN")
                self.connection.execute("DROP TABLE IF EXISTS history")
                # Bug fix: recreate the table inline. The previous code called
                # _create_history_table() while still holding self._lock; the
                # helper re-acquires the same non-reentrant Lock, which
                # deadlocks every call to reset().
                self.connection.execute(self._HISTORY_DDL)
                self.connection.execute("COMMIT")
            except Exception as e:
                self.connection.execute("ROLLBACK")
                logger.error(f"Failed to reset history table: {e}")
                raise

    def close(self) -> None:
        """Close the underlying connection (idempotent)."""
        # getattr: __init__ may have raised before `connection` was assigned,
        # and __del__ still invokes close() on the partially built instance.
        conn = getattr(self, "connection", None)
        if conn is not None:
            conn.close()
            self.connection = None

    def __del__(self):
        self.close()
|
mem0/memory/telemetry.py
ADDED
|
@@ -0,0 +1,90 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import os
|
|
3
|
+
import platform
|
|
4
|
+
import sys
|
|
5
|
+
|
|
6
|
+
from posthog import Posthog
|
|
7
|
+
|
|
8
|
+
import mem0
|
|
9
|
+
from mem0.memory.setup import get_or_create_user_id
|
|
10
|
+
|
|
11
|
+
# --- Telemetry configuration -------------------------------------------------
# Reporting can be disabled by setting the MEM0_TELEMETRY environment variable
# to a falsy string; it defaults to enabled.
MEM0_TELEMETRY = os.environ.get("MEM0_TELEMETRY", "True")
PROJECT_API_KEY = "phc_hgJkUVJFYtmaJqrvf6CYN67TIQ8yhXAkWzUn9AMU4yX"
HOST = "https://us.i.posthog.com"

# Normalize the env-var string into a boolean flag.
if isinstance(MEM0_TELEMETRY, str):
    MEM0_TELEMETRY = MEM0_TELEMETRY.lower() in {"true", "1", "yes"}

if not isinstance(MEM0_TELEMETRY, bool):
    raise ValueError("MEM0_TELEMETRY must be a boolean value.")

# Silence chatty transport loggers so telemetry never pollutes app logs.
for _noisy_logger in ("posthog", "urllib3"):
    logging.getLogger(_noisy_logger).setLevel(logging.CRITICAL + 1)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class AnonymousTelemetry:
    """Anonymous PostHog usage-telemetry reporter.

    Events carry only environment metadata plus a stable anonymous user id;
    when MEM0_TELEMETRY is falsy the PostHog client is created but disabled,
    so nothing is ever sent.
    """

    def __init__(self, vector_store=None):
        self.posthog = Posthog(project_api_key=PROJECT_API_KEY, host=HOST)
        # Stable anonymous identifier (persisted via the vector store when given).
        self.user_id = get_or_create_user_id(vector_store)
        if not MEM0_TELEMETRY:
            # Opt-out: keep the client object around but mute it.
            self.posthog.disabled = True

    def capture_event(self, event_name, properties=None, user_email=None):
        """Send one event, merging caller properties over environment defaults."""
        payload = {
            "client_source": "python",
            "client_version": mem0.__version__,
            "python_version": sys.version,
            "os": sys.platform,
            "os_version": platform.version(),
            "os_release": platform.release(),
            "processor": platform.processor(),
            "machine": platform.machine(),
        }
        # Caller-supplied properties win over the defaults above.
        payload.update(properties or {})
        distinct_id = self.user_id if user_email is None else user_email
        self.posthog.capture(distinct_id=distinct_id, event=event_name, properties=payload)

    def close(self):
        """Flush queued events and shut the PostHog client down."""
        self.posthog.shutdown()
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
# Module-level singleton used by capture_client_event() for platform-client events.
client_telemetry = AnonymousTelemetry()
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def capture_event(event_name, memory_instance, additional_data=None):
    """Emit an anonymous OSS telemetry event describing *memory_instance*.

    A fresh AnonymousTelemetry reporter is built per call, seeded with the
    instance's telemetry vector store when one is attached.
    """

    def _qualname(obj):
        # "module.ClassName" of an object's class, for component fingerprints.
        cls = obj.__class__
        return f"{cls.__module__}.{cls.__name__}"

    # getattr with default mirrors the original hasattr check.
    vector_store = getattr(memory_instance, "_telemetry_vector_store", None)
    oss_telemetry = AnonymousTelemetry(vector_store=vector_store)

    if memory_instance.config.graph_store.config:
        graph_store = _qualname(memory_instance.graph)
    else:
        graph_store = None

    event_data = {
        "collection": memory_instance.collection_name,
        "vector_size": memory_instance.embedding_model.config.embedding_dims,
        "history_store": "sqlite",
        "graph_store": graph_store,
        "vector_store": _qualname(memory_instance.vector_store),
        "llm": _qualname(memory_instance.llm),
        "embedding_model": _qualname(memory_instance.embedding_model),
        "function": f"{_qualname(memory_instance)}.{memory_instance.api_version}",
    }
    if additional_data:
        event_data.update(additional_data)

    oss_telemetry.capture_event(event_name, event_data)
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def capture_client_event(event_name, instance, additional_data=None):
    """Emit a platform-client telemetry event via the module-level reporter.

    The event records which class triggered it; extra data is merged in and
    the caller's user_email becomes the PostHog distinct id.
    """
    payload = {
        "function": f"{instance.__class__.__module__}.{instance.__class__.__name__}",
    }
    payload.update(additional_data or {})
    client_telemetry.capture_event(event_name, payload, instance.user_email)
|
mem0/memory/utils.py
ADDED
|
@@ -0,0 +1,133 @@
|
|
|
1
|
+
import hashlib
|
|
2
|
+
import re
|
|
3
|
+
|
|
4
|
+
from mem0.configs.prompts import FACT_RETRIEVAL_PROMPT
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def get_fact_retrieval_messages(message):
    """Return the (system_prompt, user_prompt) pair for fact extraction."""
    user_prompt = f"Input:\n{message}"
    return FACT_RETRIEVAL_PROMPT, user_prompt
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def parse_messages(messages):
    """Flatten chat messages into a "role: content" transcript string.

    Only system/user/assistant messages are included; any other role
    (e.g. "tool") is silently skipped, matching the original behavior.
    """
    lines = []
    for message in messages:
        if message["role"] in ("system", "user", "assistant"):
            lines.append(f"{message['role']}: {message['content']}\n")
    return "".join(lines)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def format_entities(entities):
    """Render graph relations as "source -- relationship -- destination" lines.

    Returns an empty string for an empty or falsy input.
    """
    if not entities:
        return ""
    return "\n".join(
        f"{entity['source']} -- {entity['relationship']} -- {entity['destination']}"
        for entity in entities
    )
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def remove_code_blocks(content: str) -> str:
    """Strip an enclosing ```lang ... ``` fence from *content*, if present.

    The fence may carry an optional alphanumeric language tag. When no
    fence wraps the whole (stripped) string, the stripped original is
    returned unchanged.
    """
    fence = re.match(r"^```[a-zA-Z0-9]*\n([\s\S]*?)\n```$", content.strip())
    if fence is None:
        return content.strip()
    return fence.group(1).strip()
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def extract_json(text):
    """Return the JSON payload from a ```json fenced block in *text*.

    Falls back to the stripped text itself when no fenced block is found
    (the input is assumed to already be raw JSON).
    """
    stripped = text.strip()
    fenced = re.search(r"```(?:json)?\s*(.*?)\s*```", stripped, re.DOTALL)
    return fenced.group(1) if fenced else stripped
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def get_image_description(image_obj, llm, vision_details):
    """Ask *llm* to describe an image.

    *image_obj* is either a URL string (wrapped into a vision prompt with the
    requested detail level) or an already-built message dict passed through
    as-is.
    """
    # Guard clause: prebuilt message dicts go straight to the model.
    if not isinstance(image_obj, str):
        return llm.generate_response(messages=[image_obj])

    request = {
        "role": "user",
        "content": [
            {
                "type": "text",
                "text": "A user is providing an image. Provide a high level description of the image and do not include any additional text.",
            },
            {"type": "image_url", "image_url": {"url": image_obj, "detail": vision_details}},
        ],
    }
    return llm.generate_response(messages=[request])
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
def parse_vision_messages(messages, llm=None, vision_details="auto"):
    """Replace image-bearing messages with LLM-generated text descriptions.

    System messages and plain-text messages pass through untouched; list
    content and single image_url content are described via
    get_image_description. A failed description for a URL is re-raised as a
    download error.
    """
    parsed = []
    for message in messages:
        # System messages are never rewritten.
        if message["role"] == "system":
            parsed.append(message)
            continue

        content = message["content"]
        if isinstance(content, list):
            # Multi-part content (possibly several image URLs): describe the
            # whole message at once.
            text = get_image_description(message, llm, vision_details)
            parsed.append({"role": message["role"], "content": text})
        elif isinstance(content, dict) and content.get("type") == "image_url":
            # Single image payload.
            image_url = content["image_url"]["url"]
            try:
                text = get_image_description(image_url, llm, vision_details)
            except Exception:
                raise Exception(f"Error while downloading {image_url}.")
            parsed.append({"role": message["role"], "content": text})
        else:
            # Plain text content.
            parsed.append(message)
    return parsed
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
def process_telemetry_filters(filters):
    """Anonymize filter identifiers for telemetry.

    Returns a tuple ``(keys, encoded_ids)`` where *keys* lists all filter
    names present and *encoded_ids* maps user_id/agent_id/run_id to their
    MD5 hex digests, so raw identifiers never leave the process.

    Bug fix: the previous version returned a bare ``{}`` when *filters* was
    None but a two-tuple otherwise, so any caller unpacking
    ``keys, encoded_ids = process_telemetry_filters(None)`` crashed.
    """
    if filters is None:
        # Keep the two-tuple shape so callers can always unpack.
        return [], {}

    encoded_ids = {}
    # Only identity-bearing keys are hashed; other filter keys are reported
    # by name only (via the returned key list).
    for key in ("user_id", "agent_id", "run_id"):
        if key in filters:
            encoded_ids[key] = hashlib.md5(filters[key].encode()).hexdigest()

    return list(filters.keys()), encoded_ids
|
mem0/proxy/__init__.py
ADDED
|
File without changes
|
mem0/proxy/main.py
ADDED
|
@@ -0,0 +1,194 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import subprocess
|
|
3
|
+
import sys
|
|
4
|
+
import threading
|
|
5
|
+
from typing import List, Optional, Union
|
|
6
|
+
|
|
7
|
+
import httpx
|
|
8
|
+
|
|
9
|
+
import mem0
|
|
10
|
+
|
|
11
|
+
# Optional-dependency bootstrap: 'litellm' powers the proxy completions API.
# NOTE(review): this runs at import time and prompts on stdin, which will
# block in non-interactive environments — consider a plain ImportError.
try:
    import litellm
except ImportError:
    user_input = input("The 'litellm' library is required. Install it now? [y/N]: ")
    if user_input.lower() == "y":
        try:
            # Install into the current interpreter's environment, then retry.
            subprocess.check_call([sys.executable, "-m", "pip", "install", "litellm"])
            import litellm
        except subprocess.CalledProcessError:
            print("Failed to install 'litellm'. Please install it manually using 'pip install litellm'.")
            sys.exit(1)
    else:
        raise ImportError("The required 'litellm' library is not installed.")
        sys.exit(1)  # NOTE(review): unreachable — the raise above already exits this path
|
|
25
|
+
|
|
26
|
+
from mem0 import Memory, MemoryClient
|
|
27
|
+
from mem0.configs.prompts import MEMORY_ANSWER_PROMPT
|
|
28
|
+
from mem0.memory.telemetry import capture_client_event, capture_event
|
|
29
|
+
|
|
30
|
+
# Module-level logger for the proxy; configuration is left to the host application.
logger = logging.getLogger(__name__)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class Mem0:
    """Drop-in, OpenAI-style client whose chat completions are memory-augmented.

    Providing *api_key* selects the hosted platform backend (MemoryClient);
    otherwise a local Memory instance is used, optionally built from *config*.
    """

    def __init__(
        self,
        config: Optional[dict] = None,
        api_key: Optional[str] = None,
        host: Optional[str] = None,
    ):
        if api_key:
            backend = MemoryClient(api_key, host)
        elif config:
            backend = Memory.from_config(config)
        else:
            backend = Memory()
        self.mem0_client = backend

        # Mirror the OpenAI SDK surface: client.chat.completions.create(...)
        self.chat = Chat(self.mem0_client)
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
class Chat:
    """Namespace object mirroring the OpenAI SDK's ``client.chat`` attribute."""

    def __init__(self, mem0_client):
        # Only the completions endpoint is exposed, matching the OpenAI client.
        self.completions = Completions(mem0_client)
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
class Completions:
    """OpenAI-compatible ``chat.completions`` shim that injects mem0 memories.

    On each ``create`` call, the conversation is stored asynchronously and
    relevant memories are searched and prepended to the final user message
    before delegating to ``litellm.completion``.
    """

    def __init__(self, mem0_client):
        # Either a local Memory or a hosted MemoryClient.
        self.mem0_client = mem0_client

    def create(
        self,
        model: str,
        # Bug fix: was `messages: List = []` — a shared mutable default that
        # every call without `messages` would alias. None is now the sentinel.
        messages: Optional[List] = None,
        # Mem0 arguments
        user_id: Optional[str] = None,
        agent_id: Optional[str] = None,
        run_id: Optional[str] = None,
        metadata: Optional[dict] = None,
        filters: Optional[dict] = None,
        limit: Optional[int] = 10,
        # LLM arguments
        timeout: Optional[Union[float, str, httpx.Timeout]] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        n: Optional[int] = None,
        stream: Optional[bool] = None,
        stream_options: Optional[dict] = None,
        stop=None,
        max_tokens: Optional[int] = None,
        presence_penalty: Optional[float] = None,
        frequency_penalty: Optional[float] = None,
        logit_bias: Optional[dict] = None,
        user: Optional[str] = None,
        # openai v1.0+ new params
        response_format: Optional[dict] = None,
        seed: Optional[int] = None,
        tools: Optional[List] = None,
        tool_choice: Optional[Union[str, dict]] = None,
        logprobs: Optional[bool] = None,
        top_logprobs: Optional[int] = None,
        parallel_tool_calls: Optional[bool] = None,
        deployment_id=None,
        extra_headers: Optional[dict] = None,
        # soon to be deprecated params by OpenAI
        functions: Optional[List] = None,
        function_call: Optional[str] = None,
        # set api_base, api_version, api_key
        base_url: Optional[str] = None,
        api_version: Optional[str] = None,
        api_key: Optional[str] = None,
        model_list: Optional[list] = None,  # pass in a list of api_base,keys, etc.
    ):
        """Run a memory-augmented chat completion.

        Raises ValueError when no session identifier is given or when the
        model does not support function calling. Returns the raw litellm
        completion response.
        """
        if messages is None:
            messages = []

        if not any([user_id, agent_id, run_id]):
            raise ValueError("One of user_id, agent_id, run_id must be provided")

        if not litellm.supports_function_calling(model):
            raise ValueError(
                f"Model '{model}' does not support function calling. Please use a model that supports function calling."
            )

        prepared_messages = self._prepare_messages(messages)
        if prepared_messages[-1]["role"] == "user":
            # Persist the conversation in the background and rewrite the last
            # user message to embed any relevant stored memories.
            self._async_add_to_memory(messages, user_id, agent_id, run_id, metadata, filters)
            relevant_memories = self._fetch_relevant_memories(messages, user_id, agent_id, run_id, filters, limit)
            logger.debug(f"Retrieved {len(relevant_memories)} relevant memories")
            prepared_messages[-1]["content"] = self._format_query_with_memories(messages, relevant_memories)

        response = litellm.completion(
            model=model,
            messages=prepared_messages,
            temperature=temperature,
            top_p=top_p,
            n=n,
            timeout=timeout,
            stream=stream,
            stream_options=stream_options,
            stop=stop,
            max_tokens=max_tokens,
            presence_penalty=presence_penalty,
            frequency_penalty=frequency_penalty,
            logit_bias=logit_bias,
            user=user,
            response_format=response_format,
            seed=seed,
            tools=tools,
            tool_choice=tool_choice,
            logprobs=logprobs,
            top_logprobs=top_logprobs,
            parallel_tool_calls=parallel_tool_calls,
            deployment_id=deployment_id,
            extra_headers=extra_headers,
            functions=functions,
            function_call=function_call,
            base_url=base_url,
            api_version=api_version,
            api_key=api_key,
            model_list=model_list,
        )
        # OSS vs platform telemetry, depending on the backend type.
        if isinstance(self.mem0_client, Memory):
            capture_event("mem0.chat.create", self.mem0_client)
        else:
            capture_client_event("mem0.chat.create", self.mem0_client)
        return response

    def _prepare_messages(self, messages: List[dict]) -> List[dict]:
        """Ensure the conversation starts with the memory-answer system prompt."""
        if not messages or messages[0]["role"] != "system":
            return [{"role": "system", "content": MEMORY_ANSWER_PROMPT}] + messages
        return messages

    def _async_add_to_memory(self, messages, user_id, agent_id, run_id, metadata, filters):
        """Store the conversation in mem0 on a daemon thread (fire-and-forget)."""

        def add_task():
            logger.debug("Adding to memory asynchronously")
            self.mem0_client.add(
                messages=messages,
                user_id=user_id,
                agent_id=agent_id,
                run_id=run_id,
                metadata=metadata,
                filters=filters,
            )

        threading.Thread(target=add_task, daemon=True).start()

    def _fetch_relevant_memories(self, messages, user_id, agent_id, run_id, filters, limit):
        """Search mem0 using (at most) the last 6 messages as the query."""
        # Currently, only pass the last 6 messages to the search API to prevent long query
        message_input = [f"{message['role']}: {message['content']}" for message in messages][-6:]
        # TODO: Make it better by summarizing the past conversation
        return self.mem0_client.search(
            query="\n".join(message_input),
            user_id=user_id,
            agent_id=agent_id,
            run_id=run_id,
            filters=filters,
            limit=limit,
        )

    def _format_query_with_memories(self, messages, relevant_memories):
        """Build the final user prompt embedding memories and graph entities."""
        entities = []
        # Bug fix: memories_text was previously unbound (NameError) when the
        # client matched neither backend type; default to an empty string.
        memories_text = ""
        if isinstance(self.mem0_client, mem0.memory.main.Memory):
            memories_text = "\n".join(memory["memory"] for memory in relevant_memories["results"])
            if relevant_memories.get("relations"):
                entities = list(relevant_memories["relations"])
        elif isinstance(self.mem0_client, mem0.client.main.MemoryClient):
            memories_text = "\n".join(memory["memory"] for memory in relevant_memories)
        return f"- Relevant Memories/Facts: {memories_text}\n\n- Entities: {entities}\n\n- User Question: {messages[-1]['content']}"
|