mem0ai-azure-mysql 0.1.115__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (116)
  1. mem0/__init__.py +6 -0
  2. mem0/client/__init__.py +0 -0
  3. mem0/client/main.py +1535 -0
  4. mem0/client/project.py +860 -0
  5. mem0/client/utils.py +29 -0
  6. mem0/configs/__init__.py +0 -0
  7. mem0/configs/base.py +90 -0
  8. mem0/configs/dbs/__init__.py +4 -0
  9. mem0/configs/dbs/base.py +41 -0
  10. mem0/configs/dbs/mysql.py +25 -0
  11. mem0/configs/embeddings/__init__.py +0 -0
  12. mem0/configs/embeddings/base.py +108 -0
  13. mem0/configs/enums.py +7 -0
  14. mem0/configs/llms/__init__.py +0 -0
  15. mem0/configs/llms/base.py +152 -0
  16. mem0/configs/prompts.py +333 -0
  17. mem0/configs/vector_stores/__init__.py +0 -0
  18. mem0/configs/vector_stores/azure_ai_search.py +59 -0
  19. mem0/configs/vector_stores/baidu.py +29 -0
  20. mem0/configs/vector_stores/chroma.py +40 -0
  21. mem0/configs/vector_stores/elasticsearch.py +47 -0
  22. mem0/configs/vector_stores/faiss.py +39 -0
  23. mem0/configs/vector_stores/langchain.py +32 -0
  24. mem0/configs/vector_stores/milvus.py +43 -0
  25. mem0/configs/vector_stores/mongodb.py +25 -0
  26. mem0/configs/vector_stores/opensearch.py +41 -0
  27. mem0/configs/vector_stores/pgvector.py +37 -0
  28. mem0/configs/vector_stores/pinecone.py +56 -0
  29. mem0/configs/vector_stores/qdrant.py +49 -0
  30. mem0/configs/vector_stores/redis.py +26 -0
  31. mem0/configs/vector_stores/supabase.py +44 -0
  32. mem0/configs/vector_stores/upstash_vector.py +36 -0
  33. mem0/configs/vector_stores/vertex_ai_vector_search.py +27 -0
  34. mem0/configs/vector_stores/weaviate.py +43 -0
  35. mem0/dbs/__init__.py +4 -0
  36. mem0/dbs/base.py +68 -0
  37. mem0/dbs/configs.py +21 -0
  38. mem0/dbs/mysql.py +321 -0
  39. mem0/embeddings/__init__.py +0 -0
  40. mem0/embeddings/aws_bedrock.py +100 -0
  41. mem0/embeddings/azure_openai.py +43 -0
  42. mem0/embeddings/base.py +31 -0
  43. mem0/embeddings/configs.py +30 -0
  44. mem0/embeddings/gemini.py +39 -0
  45. mem0/embeddings/huggingface.py +41 -0
  46. mem0/embeddings/langchain.py +35 -0
  47. mem0/embeddings/lmstudio.py +29 -0
  48. mem0/embeddings/mock.py +11 -0
  49. mem0/embeddings/ollama.py +53 -0
  50. mem0/embeddings/openai.py +49 -0
  51. mem0/embeddings/together.py +31 -0
  52. mem0/embeddings/vertexai.py +54 -0
  53. mem0/graphs/__init__.py +0 -0
  54. mem0/graphs/configs.py +96 -0
  55. mem0/graphs/neptune/__init__.py +0 -0
  56. mem0/graphs/neptune/base.py +410 -0
  57. mem0/graphs/neptune/main.py +372 -0
  58. mem0/graphs/tools.py +371 -0
  59. mem0/graphs/utils.py +97 -0
  60. mem0/llms/__init__.py +0 -0
  61. mem0/llms/anthropic.py +64 -0
  62. mem0/llms/aws_bedrock.py +270 -0
  63. mem0/llms/azure_openai.py +114 -0
  64. mem0/llms/azure_openai_structured.py +76 -0
  65. mem0/llms/base.py +32 -0
  66. mem0/llms/configs.py +34 -0
  67. mem0/llms/deepseek.py +85 -0
  68. mem0/llms/gemini.py +201 -0
  69. mem0/llms/groq.py +88 -0
  70. mem0/llms/langchain.py +65 -0
  71. mem0/llms/litellm.py +87 -0
  72. mem0/llms/lmstudio.py +53 -0
  73. mem0/llms/ollama.py +94 -0
  74. mem0/llms/openai.py +124 -0
  75. mem0/llms/openai_structured.py +52 -0
  76. mem0/llms/sarvam.py +89 -0
  77. mem0/llms/together.py +88 -0
  78. mem0/llms/vllm.py +89 -0
  79. mem0/llms/xai.py +52 -0
  80. mem0/memory/__init__.py +0 -0
  81. mem0/memory/base.py +63 -0
  82. mem0/memory/graph_memory.py +632 -0
  83. mem0/memory/main.py +1843 -0
  84. mem0/memory/memgraph_memory.py +630 -0
  85. mem0/memory/setup.py +56 -0
  86. mem0/memory/storage.py +218 -0
  87. mem0/memory/telemetry.py +90 -0
  88. mem0/memory/utils.py +133 -0
  89. mem0/proxy/__init__.py +0 -0
  90. mem0/proxy/main.py +194 -0
  91. mem0/utils/factory.py +132 -0
  92. mem0/vector_stores/__init__.py +0 -0
  93. mem0/vector_stores/azure_ai_search.py +383 -0
  94. mem0/vector_stores/baidu.py +368 -0
  95. mem0/vector_stores/base.py +58 -0
  96. mem0/vector_stores/chroma.py +229 -0
  97. mem0/vector_stores/configs.py +60 -0
  98. mem0/vector_stores/elasticsearch.py +235 -0
  99. mem0/vector_stores/faiss.py +473 -0
  100. mem0/vector_stores/langchain.py +179 -0
  101. mem0/vector_stores/milvus.py +245 -0
  102. mem0/vector_stores/mongodb.py +293 -0
  103. mem0/vector_stores/opensearch.py +281 -0
  104. mem0/vector_stores/pgvector.py +294 -0
  105. mem0/vector_stores/pinecone.py +373 -0
  106. mem0/vector_stores/qdrant.py +240 -0
  107. mem0/vector_stores/redis.py +295 -0
  108. mem0/vector_stores/supabase.py +237 -0
  109. mem0/vector_stores/upstash_vector.py +293 -0
  110. mem0/vector_stores/vertex_ai_vector_search.py +629 -0
  111. mem0/vector_stores/weaviate.py +316 -0
  112. mem0ai_azure_mysql-0.1.115.data/data/README.md +169 -0
  113. mem0ai_azure_mysql-0.1.115.dist-info/METADATA +224 -0
  114. mem0ai_azure_mysql-0.1.115.dist-info/RECORD +116 -0
  115. mem0ai_azure_mysql-0.1.115.dist-info/WHEEL +4 -0
  116. mem0ai_azure_mysql-0.1.115.dist-info/licenses/LICENSE +201 -0
mem0/memory/storage.py ADDED
@@ -0,0 +1,218 @@
1
+ import logging
2
+ import sqlite3
3
+ import threading
4
+ import uuid
5
+ from typing import Any, Dict, List, Optional
6
+
7
+ logger = logging.getLogger(__name__)
8
+
9
+
10
class SQLiteManager:
    """Store memory-history records in a SQLite database.

    A single connection (opened with ``check_same_thread=False``) is shared by
    all callers; every operation runs in an explicit transaction while holding
    ``self._lock``.
    """

    # Canonical column order of the ``history`` table.
    _HISTORY_COLUMNS = (
        "id",
        "memory_id",
        "old_memory",
        "new_memory",
        "event",
        "created_at",
        "updated_at",
        "is_deleted",
        "actor_id",
        "role",
    )

    # Single definition of the table, shared by migration, creation and reset so
    # the three code paths cannot drift apart.
    _HISTORY_TABLE_DDL = """
        CREATE TABLE IF NOT EXISTS history (
            id           TEXT PRIMARY KEY,
            memory_id    TEXT,
            old_memory   TEXT,
            new_memory   TEXT,
            event        TEXT,
            created_at   DATETIME,
            updated_at   DATETIME,
            is_deleted   INTEGER,
            actor_id     TEXT,
            role         TEXT
        )
    """

    def __init__(self, db_path: str = ":memory:"):
        """Open ``db_path`` and make sure the ``history`` table has the current schema."""
        self.db_path = db_path
        self.connection = sqlite3.connect(self.db_path, check_same_thread=False)
        # Re-entrant so a method holding the lock can safely call another
        # lock-taking method (a plain Lock deadlocked in reset()).
        self._lock = threading.RLock()
        self._migrate_history_table()
        self._create_history_table()

    def _rollback(self) -> None:
        """Roll back only if a transaction is open.

        An unconditional ROLLBACK raises "no transaction is active" when BEGIN
        itself failed, which would hide the original exception.
        """
        if self.connection is not None and self.connection.in_transaction:
            self.connection.execute("ROLLBACK")

    def _migrate_history_table(self) -> None:
        """
        If a pre-existing history table has a different column set (e.g. the old
        group-chat columns), rename it, create the new schema, copy the
        intersecting columns, then drop the old table — all in one transaction.
        """
        with self._lock:
            try:
                self.connection.execute("BEGIN")
                cur = self.connection.cursor()

                cur.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='history'")
                if cur.fetchone() is None:
                    self.connection.execute("COMMIT")
                    return  # nothing to migrate

                cur.execute("PRAGMA table_info(history)")
                old_cols = {row[1] for row in cur.fetchall()}

                if old_cols == set(self._HISTORY_COLUMNS):
                    self.connection.execute("COMMIT")
                    return

                logger.info("Migrating history table to new schema (no convo columns).")

                # Clean up any leftover from a previous failed migration.
                cur.execute("DROP TABLE IF EXISTS history_old")
                cur.execute("ALTER TABLE history RENAME TO history_old")
                cur.execute(self._HISTORY_TABLE_DDL)

                # Column names come from the fixed whitelist above, so building the
                # SQL with an f-string is safe; canonical order keeps it deterministic.
                intersecting = [col for col in self._HISTORY_COLUMNS if col in old_cols]
                if intersecting:
                    cols_csv = ", ".join(intersecting)
                    cur.execute(f"INSERT INTO history ({cols_csv}) SELECT {cols_csv} FROM history_old")

                cur.execute("DROP TABLE history_old")
                self.connection.execute("COMMIT")
                logger.info("History table migration completed successfully.")
            except Exception as e:
                self._rollback()
                logger.error(f"History table migration failed: {e}")
                raise

    def _create_history_table(self) -> None:
        """Create the ``history`` table if it does not exist yet."""
        with self._lock:
            try:
                self.connection.execute("BEGIN")
                self.connection.execute(self._HISTORY_TABLE_DDL)
                self.connection.execute("COMMIT")
            except Exception as e:
                self._rollback()
                logger.error(f"Failed to create history table: {e}")
                raise

    def add_history(
        self,
        memory_id: str,
        old_memory: Optional[str],
        new_memory: Optional[str],
        event: str,
        *,
        created_at: Optional[str] = None,
        updated_at: Optional[str] = None,
        is_deleted: int = 0,
        actor_id: Optional[str] = None,
        role: Optional[str] = None,
    ) -> None:
        """Insert one history row with a freshly generated UUID4 ``id``.

        Raises:
            Exception: any database error, after the transaction is rolled back.
        """
        with self._lock:
            try:
                self.connection.execute("BEGIN")
                self.connection.execute(
                    """
                    INSERT INTO history (
                        id, memory_id, old_memory, new_memory, event,
                        created_at, updated_at, is_deleted, actor_id, role
                    )
                    VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                    """,
                    (
                        str(uuid.uuid4()),
                        memory_id,
                        old_memory,
                        new_memory,
                        event,
                        created_at,
                        updated_at,
                        is_deleted,
                        actor_id,
                        role,
                    ),
                )
                self.connection.execute("COMMIT")
            except Exception as e:
                self._rollback()
                logger.error(f"Failed to add history record: {e}")
                raise

    def get_history(self, memory_id: str) -> List[Dict[str, Any]]:
        """Return all history rows for ``memory_id`` as dicts.

        Rows are ordered by ``created_at`` then ``DATETIME(updated_at)``;
        ``is_deleted`` is converted to ``bool``.
        """
        with self._lock:
            cur = self.connection.execute(
                """
                SELECT id, memory_id, old_memory, new_memory, event,
                       created_at, updated_at, is_deleted, actor_id, role
                FROM history
                WHERE memory_id = ?
                ORDER BY created_at ASC, DATETIME(updated_at) ASC
                """,
                (memory_id,),
            )
            rows = cur.fetchall()

        return [
            {
                "id": r[0],
                "memory_id": r[1],
                "old_memory": r[2],
                "new_memory": r[3],
                "event": r[4],
                "created_at": r[5],
                "updated_at": r[6],
                "is_deleted": bool(r[7]),
                "actor_id": r[8],
                "role": r[9],
            }
            for r in rows
        ]

    def reset(self) -> None:
        """Drop and recreate the history table in a single transaction."""
        with self._lock:
            try:
                self.connection.execute("BEGIN")
                self.connection.execute("DROP TABLE IF EXISTS history")
                self.connection.execute(self._HISTORY_TABLE_DDL)
                self.connection.execute("COMMIT")
            except Exception as e:
                self._rollback()
                logger.error(f"Failed to reset history table: {e}")
                raise

    def close(self) -> None:
        """Close the connection; calling it again is a no-op."""
        # getattr: __del__ may run after __init__ failed before `connection` was set.
        connection = getattr(self, "connection", None)
        if connection is not None:
            connection.close()
            self.connection = None

    def __del__(self):
        self.close()
mem0/memory/telemetry.py ADDED
@@ -0,0 +1,90 @@
1
+ import logging
2
+ import os
3
+ import platform
4
+ import sys
5
+
6
+ from posthog import Posthog
7
+
8
+ import mem0
9
+ from mem0.memory.setup import get_or_create_user_id
10
+
11
# Telemetry is on by default. It stays enabled only when MEM0_TELEMETRY is
# "true", "1" or "yes" (case-insensitive, surrounding whitespace ignored);
# any other value disables it. os.environ.get always returns a str here, so
# the result is always a bool and needs no further validation.
MEM0_TELEMETRY = os.environ.get("MEM0_TELEMETRY", "True").strip().lower() in ("true", "1", "yes")
PROJECT_API_KEY = "phc_hgJkUVJFYtmaJqrvf6CYN67TIQ8yhXAkWzUn9AMU4yX"
HOST = "https://us.i.posthog.com"

# Raise these loggers above CRITICAL so they emit nothing. Note this applies
# process-wide, including to the host application's own urllib3 usage.
logging.getLogger("posthog").setLevel(logging.CRITICAL + 1)
logging.getLogger("urllib3").setLevel(logging.CRITICAL + 1)
23
+
24
+
25
class AnonymousTelemetry:
    """Wrap a PostHog client and tag each event with runtime/host metadata.

    The PostHog client is always constructed; when ``MEM0_TELEMETRY`` is false
    its ``disabled`` attribute is set to ``True``.
    """

    def __init__(self, vector_store=None):
        self.posthog = Posthog(project_api_key=PROJECT_API_KEY, host=HOST)

        # Identifier that events are attributed to unless an email is supplied.
        self.user_id = get_or_create_user_id(vector_store)

        if not MEM0_TELEMETRY:
            self.posthog.disabled = True

    def capture_event(self, event_name, properties=None, user_email=None):
        """Send ``event_name`` to PostHog.

        Python/OS/machine metadata is placed first, so any key in ``properties``
        with the same name wins. The event's ``distinct_id`` is ``user_email``
        when it is not None, otherwise ``self.user_id``.
        """
        extra = {} if properties is None else properties
        payload = {
            "client_source": "python",
            "client_version": mem0.__version__,
            "python_version": sys.version,
            "os": sys.platform,
            "os_version": platform.version(),
            "os_release": platform.release(),
            "processor": platform.processor(),
            "machine": platform.machine(),
            **extra,
        }

        if user_email is None:
            distinct_id = self.user_id
        else:
            distinct_id = user_email

        self.posthog.capture(distinct_id=distinct_id, event=event_name, properties=payload)

    def close(self):
        """Shut down the underlying PostHog client (delegates to its ``shutdown``)."""
        self.posthog.shutdown()
53
+
54
+
55
# Shared instance used by capture_client_event(). It is created at import time,
# so importing this module constructs a Posthog client and calls
# get_or_create_user_id() with no vector store.
client_telemetry = AnonymousTelemetry()
56
+
57
+
58
def capture_event(event_name, memory_instance, additional_data=None):
    """Send a telemetry event describing the components of a memory instance.

    Records the collection name, embedding size, and the dotted class paths of
    the graph (only when a graph store config is set), vector store, LLM,
    embedder and the instance itself (suffixed with its ``api_version``).
    Entries in ``additional_data`` are merged last and override these keys.
    """

    def dotted(obj):
        # "<module>.<ClassName>" for the object's class.
        cls = obj.__class__
        return f"{cls.__module__}.{cls.__name__}"

    # NOTE(review): a new AnonymousTelemetry (and Posthog client) is built on
    # every call and never closed — consider reusing one instance.
    oss_telemetry = AnonymousTelemetry(
        vector_store=getattr(memory_instance, "_telemetry_vector_store", None),
    )

    event_data = {
        "collection": memory_instance.collection_name,
        "vector_size": memory_instance.embedding_model.config.embedding_dims,
        "history_store": "sqlite",
    }
    event_data["graph_store"] = (
        dotted(memory_instance.graph) if memory_instance.config.graph_store.config else None
    )
    event_data["vector_store"] = dotted(memory_instance.vector_store)
    event_data["llm"] = dotted(memory_instance.llm)
    event_data["embedding_model"] = dotted(memory_instance.embedding_model)
    event_data["function"] = f"{dotted(memory_instance)}.{memory_instance.api_version}"

    if additional_data:
        event_data.update(additional_data)

    oss_telemetry.capture_event(event_name, event_data)
81
+
82
+
83
def capture_client_event(event_name, instance, additional_data=None):
    """Send a telemetry event through the module-level ``client_telemetry``.

    The event records the dotted class path of ``instance`` plus any
    ``additional_data``, and is attributed to ``instance.user_email``.
    """
    cls = instance.__class__
    event_data = {"function": f"{cls.__module__}.{cls.__name__}"}

    if additional_data:
        event_data.update(additional_data)

    client_telemetry.capture_event(event_name, event_data, instance.user_email)
mem0/memory/utils.py ADDED
@@ -0,0 +1,133 @@
1
+ import hashlib
2
+ import re
3
+
4
+ from mem0.configs.prompts import FACT_RETRIEVAL_PROMPT
5
+
6
+
7
def get_fact_retrieval_messages(message):
    """Return ``(system_prompt, user_prompt)`` for the fact-extraction LLM call.

    The user prompt is ``message`` preceded by an ``Input:`` header line.
    """
    user_prompt = f"Input:\n{message}"
    return FACT_RETRIEVAL_PROMPT, user_prompt
9
+
10
+
11
def parse_messages(messages):
    """Render chat messages as a plain-text transcript.

    Every message whose ``role`` equals ``"system"``, ``"user"`` or
    ``"assistant"`` becomes a ``"<role>: <content>\\n"`` line; messages with
    any other role are skipped.

    Args:
        messages: iterable of mappings with ``role`` and ``content`` keys.

    Returns:
        The concatenated transcript ("" for no messages).
    """
    rendered_roles = ("system", "user", "assistant")
    lines = []
    for msg in messages:
        for label in rendered_roles:
            # Emit the literal label (not str(role)) so role objects that merely
            # compare equal to the label render exactly as before.
            if msg["role"] == label:
                lines.append(f"{label}: {msg['content']}\n")
    # join is linear; repeated += on a str can be quadratic.
    return "".join(lines)
21
+
22
+
23
def format_entities(entities):
    """Render graph relations as ``"source -- relationship -- destination"`` lines.

    Returns an empty string when ``entities`` is empty or ``None``.
    """
    if not entities:
        return ""
    return "\n".join(
        f"{rel['source']} -- {rel['relationship']} -- {rel['destination']}" for rel in entities
    )
33
+
34
+
35
def remove_code_blocks(content: str) -> str:
    """
    Removes enclosing code block markers ```[language] and ``` from a given string.

    Remarks:
    - The opening fence may carry an optional alphanumeric language tag, optionally
      followed by spaces/tabs, and both fence lines may end in ``\\n`` or ``\\r\\n``.
    - If the whole (stripped) string is one fenced block, only its inner content is
      returned, stripped of surrounding whitespace.
    - Otherwise the stripped original content is returned.
    """
    # \r? accepts CRLF output; [ \t]* accepts "```json " style openers.
    pattern = r"^```[a-zA-Z0-9]*[ \t]*\r?\n([\s\S]*?)\r?\n```$"
    match = re.match(pattern, content.strip())
    return match.group(1).strip() if match else content.strip()
47
+
48
+
49
def extract_json(text):
    """Return the body of the first ```/```json fenced block in ``text``.

    Whitespace just inside the fences is dropped. When no fenced block is
    present, the whitespace-stripped input is returned unchanged, on the
    assumption that it is already raw JSON.
    """
    stripped = text.strip()
    fenced = re.search(r"```(?:json)?\s*(.*?)\s*```", stripped, re.DOTALL)
    return fenced.group(1) if fenced else stripped
61
+
62
+
63
def get_image_description(image_obj, llm, vision_details):
    """Ask ``llm`` to describe an image and return its response.

    If ``image_obj`` is a string it is treated as an image URL and wrapped in a
    single user message holding a text instruction plus an ``image_url`` part
    with ``detail=vision_details``. Any other value is sent as the only message.
    """
    if not isinstance(image_obj, str):
        return llm.generate_response(messages=[image_obj])

    instruction = {
        "type": "text",
        "text": "A user is providing an image. Provide a high level description of the image and do not include any additional text.",
    }
    image_part = {"type": "image_url", "image_url": {"url": image_obj, "detail": vision_details}}
    request = [{"role": "user", "content": [instruction, image_part]}]
    return llm.generate_response(messages=request)
86
+
87
+
88
def parse_vision_messages(messages, llm=None, vision_details="auto"):
    """Replace image content in ``messages`` with LLM-generated descriptions.

    - ``system`` messages are kept unchanged.
    - A message whose ``content`` is a list is passed whole to
      ``get_image_description``; the result replaces it as
      ``{"role": <same role>, "content": <description>}``.
    - A message whose ``content`` is a dict with ``type == "image_url"`` has its
      ``image_url.url`` described the same way.
    - Any other message (plain text) is kept unchanged.

    Returns:
        A new list; ``messages`` itself is not modified.

    Raises:
        Exception: "Error while downloading <url>." when describing a single
            image URL fails; the underlying error is attached as ``__cause__``.
    """
    returned_messages = []
    for msg in messages:
        if msg["role"] == "system":
            returned_messages.append(msg)
            continue

        content = msg["content"]
        if isinstance(content, list):
            # Multiple parts (e.g. several image URLs) in one message.
            description = get_image_description(msg, llm, vision_details)
            returned_messages.append({"role": msg["role"], "content": description})
        elif isinstance(content, dict) and content.get("type") == "image_url":
            # Single image content.
            image_url = content["image_url"]["url"]
            try:
                description = get_image_description(image_url, llm, vision_details)
            except Exception as e:
                # Chain the original error so the real failure reason is not lost.
                raise Exception(f"Error while downloading {image_url}.") from e
            returned_messages.append({"role": msg["role"], "content": description})
        else:
            # Regular text content.
            returned_messages.append(msg)

    return returned_messages
116
+
117
+
118
def process_telemetry_filters(filters):
    """Summarize ``filters`` for telemetry, hashing identifier values.

    Returns:
        A ``(keys, encoded_ids)`` tuple. ``keys`` lists the filter keys.
        ``encoded_ids`` maps each of ``user_id``/``agent_id``/``run_id`` that is
        present to the MD5 hex digest of its value. ``None`` yields ``([], {})``,
        so callers can always unpack two values (a bare ``{}`` could not be
        unpacked).
    """
    if filters is None:
        return [], {}

    encoded_ids = {}
    for key in ("user_id", "agent_id", "run_id"):
        if key in filters:
            # Send a digest rather than the raw id; str() tolerates non-string ids.
            encoded_ids[key] = hashlib.md5(str(filters[key]).encode()).hexdigest()

    return list(filters.keys()), encoded_ids
mem0/proxy/__init__.py ADDED
File without changes
mem0/proxy/main.py ADDED
@@ -0,0 +1,194 @@
1
+ import logging
2
+ import subprocess
3
+ import sys
4
+ import threading
5
+ from typing import List, Optional, Union
6
+
7
+ import httpx
8
+
9
+ import mem0
10
+
11
try:
    import litellm
except ImportError:
    # Offer an interactive install. In non-interactive contexts (closed stdin
    # raises EOFError; captured stdin can raise OSError) treat it as "no", so the
    # caller sees the ImportError below rather than an unrelated I/O error.
    try:
        user_input = input("The 'litellm' library is required. Install it now? [y/N]: ")
    except (EOFError, OSError):
        user_input = ""
    if user_input.strip().lower() == "y":
        try:
            subprocess.check_call([sys.executable, "-m", "pip", "install", "litellm"])
            import litellm
        except subprocess.CalledProcessError:
            print("Failed to install 'litellm'. Please install it manually using 'pip install litellm'.")
            sys.exit(1)
    else:
        raise ImportError("The required 'litellm' library is not installed.")
25
+
26
+ from mem0 import Memory, MemoryClient
27
+ from mem0.configs.prompts import MEMORY_ANSWER_PROMPT
28
+ from mem0.memory.telemetry import capture_client_event, capture_event
29
+
30
+ logger = logging.getLogger(__name__)
31
+
32
+
33
class Mem0:
    """Entry point exposing ``chat.completions.create`` on top of a mem0 backend.

    When ``api_key`` is given the backend is ``MemoryClient(api_key, host)``;
    otherwise it is a local ``Memory`` built from ``config``, or with defaults
    when ``config`` is falsy.
    """

    def __init__(
        self,
        config: Optional[dict] = None,
        api_key: Optional[str] = None,
        host: Optional[str] = None,
    ):
        if not api_key:
            backend = Memory.from_config(config) if config else Memory()
        else:
            backend = MemoryClient(api_key, host)
        self.mem0_client = backend

        self.chat = Chat(self.mem0_client)
46
+
47
+
48
class Chat:
    """Namespace holding ``completions`` so callers can write ``chat.completions.create(...)``."""

    def __init__(self, mem0_client):
        completions = Completions(mem0_client)
        self.completions = completions
51
+
52
+
53
class Completions:
    """Chat-completion endpoint that injects mem0 memories into the last user turn.

    ``create`` forwards LLM arguments to ``litellm.completion`` and adds mem0
    scoping arguments (``user_id``/``agent_id``/``run_id``, ``metadata``,
    ``filters``, ``limit``).
    """

    def __init__(self, mem0_client):
        self.mem0_client = mem0_client

    def create(
        self,
        model: str,
        messages: Optional[List] = None,
        # Mem0 arguments
        user_id: Optional[str] = None,
        agent_id: Optional[str] = None,
        run_id: Optional[str] = None,
        metadata: Optional[dict] = None,
        filters: Optional[dict] = None,
        limit: Optional[int] = 10,
        # LLM arguments
        timeout: Optional[Union[float, str, httpx.Timeout]] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        n: Optional[int] = None,
        stream: Optional[bool] = None,
        stream_options: Optional[dict] = None,
        stop=None,
        max_tokens: Optional[int] = None,
        presence_penalty: Optional[float] = None,
        frequency_penalty: Optional[float] = None,
        logit_bias: Optional[dict] = None,
        user: Optional[str] = None,
        # openai v1.0+ new params
        response_format: Optional[dict] = None,
        seed: Optional[int] = None,
        tools: Optional[List] = None,
        tool_choice: Optional[Union[str, dict]] = None,
        logprobs: Optional[bool] = None,
        top_logprobs: Optional[int] = None,
        parallel_tool_calls: Optional[bool] = None,
        deployment_id=None,
        extra_headers: Optional[dict] = None,
        # soon to be deprecated params by OpenAI
        functions: Optional[List] = None,
        function_call: Optional[str] = None,
        # set api_base, api_version, api_key
        base_url: Optional[str] = None,
        api_version: Optional[str] = None,
        api_key: Optional[str] = None,
        model_list: Optional[list] = None,  # pass in a list of api_base,keys, etc.
    ):
        """Run a chat completion augmented with relevant memories.

        If the last message (after a system prompt is ensured) is from the
        user, the conversation is added to memory on a daemon thread, related
        memories are searched, and the last message sent to the LLM is replaced
        by a copy whose content includes those memories. The caller's
        ``messages`` list and message dicts are never modified.

        Raises:
            ValueError: if none of ``user_id``/``agent_id``/``run_id`` is given,
                or litellm reports that ``model`` lacks function calling.

        Returns:
            The value returned by ``litellm.completion``.
        """
        # None instead of a shared mutable [] default.
        if messages is None:
            messages = []

        if not any([user_id, agent_id, run_id]):
            raise ValueError("One of user_id, agent_id, run_id must be provided")

        if not litellm.supports_function_calling(model):
            raise ValueError(
                f"Model '{model}' does not support function calling. Please use a model that supports function calling."
            )

        prepared_messages = self._prepare_messages(messages)
        if prepared_messages[-1]["role"] == "user":
            self._async_add_to_memory(messages, user_id, agent_id, run_id, metadata, filters)
            relevant_memories = self._fetch_relevant_memories(messages, user_id, agent_id, run_id, filters, limit)
            logger.debug(f"Retrieved {len(relevant_memories)} relevant memories")
            # Replace rather than mutate: the last dict may be the caller's own,
            # and the background add thread reads those same dicts.
            augmented = dict(prepared_messages[-1])
            augmented["content"] = self._format_query_with_memories(messages, relevant_memories)
            prepared_messages[-1] = augmented

        response = litellm.completion(
            model=model,
            messages=prepared_messages,
            temperature=temperature,
            top_p=top_p,
            n=n,
            timeout=timeout,
            stream=stream,
            stream_options=stream_options,
            stop=stop,
            max_tokens=max_tokens,
            presence_penalty=presence_penalty,
            frequency_penalty=frequency_penalty,
            logit_bias=logit_bias,
            user=user,
            response_format=response_format,
            seed=seed,
            tools=tools,
            tool_choice=tool_choice,
            logprobs=logprobs,
            top_logprobs=top_logprobs,
            parallel_tool_calls=parallel_tool_calls,
            deployment_id=deployment_id,
            extra_headers=extra_headers,
            functions=functions,
            function_call=function_call,
            base_url=base_url,
            api_version=api_version,
            api_key=api_key,
            model_list=model_list,
        )
        if isinstance(self.mem0_client, Memory):
            capture_event("mem0.chat.create", self.mem0_client)
        else:
            capture_client_event("mem0.chat.create", self.mem0_client)
        return response

    def _prepare_messages(self, messages: List[dict]) -> List[dict]:
        """Return a NEW list that starts with a system message.

        ``MEMORY_ANSWER_PROMPT`` is prepended as the system message unless
        ``messages`` already starts with one. The input list is never returned
        as-is, so later edits to the result cannot leak back to the caller.
        """
        if not messages or messages[0]["role"] != "system":
            return [{"role": "system", "content": MEMORY_ANSWER_PROMPT}] + list(messages)
        return list(messages)

    def _async_add_to_memory(self, messages, user_id, agent_id, run_id, metadata, filters):
        """Call ``mem0_client.add`` on a daemon thread (fire-and-forget)."""

        def add_task():
            logger.debug("Adding to memory asynchronously")
            self.mem0_client.add(
                messages=messages,
                user_id=user_id,
                agent_id=agent_id,
                run_id=run_id,
                metadata=metadata,
                filters=filters,
            )

        threading.Thread(target=add_task, daemon=True).start()

    def _fetch_relevant_memories(self, messages, user_id, agent_id, run_id, filters, limit):
        """Search memories using the last six messages as the query."""
        # Only the last 6 messages are used, to keep the search query short.
        message_input = [f"{message['role']}: {message['content']}" for message in messages][-6:]
        # TODO: Make it better by summarizing the past conversation
        return self.mem0_client.search(
            query="\n".join(message_input),
            user_id=user_id,
            agent_id=agent_id,
            run_id=run_id,
            filters=filters,
            limit=limit,
        )

    def _format_query_with_memories(self, messages, relevant_memories):
        """Build the memory-augmented prompt for the last message.

        ``relevant_memories`` may be a dict with a ``results`` list (and
        optionally ``relations``), as returned for ``Memory``, or a plain list
        of memory dicts. Each memory item must have a ``memory`` key. The shape
        is detected from the data itself, so any client type works (previously
        an unknown client type left ``memories_text`` unbound).
        """
        entities = []
        if isinstance(relevant_memories, dict):
            results = relevant_memories.get("results", [])
            if relevant_memories.get("relations"):
                entities = list(relevant_memories["relations"])
        else:
            results = relevant_memories
        memories_text = "\n".join(memory["memory"] for memory in results)
        return f"- Relevant Memories/Facts: {memories_text}\n\n- Entities: {entities}\n\n- User Question: {messages[-1]['content']}"