letta-nightly 0.1.7.dev20240924104148__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of letta-nightly might be problematic.

Files changed (189)
  1. letta/__init__.py +24 -0
  2. letta/__main__.py +3 -0
  3. letta/agent.py +1427 -0
  4. letta/agent_store/chroma.py +295 -0
  5. letta/agent_store/db.py +546 -0
  6. letta/agent_store/lancedb.py +177 -0
  7. letta/agent_store/milvus.py +198 -0
  8. letta/agent_store/qdrant.py +201 -0
  9. letta/agent_store/storage.py +188 -0
  10. letta/benchmark/benchmark.py +96 -0
  11. letta/benchmark/constants.py +14 -0
  12. letta/cli/cli.py +689 -0
  13. letta/cli/cli_config.py +1282 -0
  14. letta/cli/cli_load.py +166 -0
  15. letta/client/__init__.py +0 -0
  16. letta/client/admin.py +171 -0
  17. letta/client/client.py +2360 -0
  18. letta/client/streaming.py +90 -0
  19. letta/client/utils.py +61 -0
  20. letta/config.py +484 -0
  21. letta/configs/anthropic.json +13 -0
  22. letta/configs/letta_hosted.json +11 -0
  23. letta/configs/openai.json +12 -0
  24. letta/constants.py +134 -0
  25. letta/credentials.py +140 -0
  26. letta/data_sources/connectors.py +247 -0
  27. letta/embeddings.py +218 -0
  28. letta/errors.py +26 -0
  29. letta/functions/__init__.py +0 -0
  30. letta/functions/function_sets/base.py +174 -0
  31. letta/functions/function_sets/extras.py +132 -0
  32. letta/functions/functions.py +105 -0
  33. letta/functions/schema_generator.py +205 -0
  34. letta/humans/__init__.py +0 -0
  35. letta/humans/examples/basic.txt +1 -0
  36. letta/humans/examples/cs_phd.txt +9 -0
  37. letta/interface.py +314 -0
  38. letta/llm_api/__init__.py +0 -0
  39. letta/llm_api/anthropic.py +383 -0
  40. letta/llm_api/azure_openai.py +155 -0
  41. letta/llm_api/cohere.py +396 -0
  42. letta/llm_api/google_ai.py +468 -0
  43. letta/llm_api/llm_api_tools.py +485 -0
  44. letta/llm_api/openai.py +470 -0
  45. letta/local_llm/README.md +3 -0
  46. letta/local_llm/__init__.py +0 -0
  47. letta/local_llm/chat_completion_proxy.py +279 -0
  48. letta/local_llm/constants.py +31 -0
  49. letta/local_llm/function_parser.py +68 -0
  50. letta/local_llm/grammars/__init__.py +0 -0
  51. letta/local_llm/grammars/gbnf_grammar_generator.py +1324 -0
  52. letta/local_llm/grammars/json.gbnf +26 -0
  53. letta/local_llm/grammars/json_func_calls_with_inner_thoughts.gbnf +32 -0
  54. letta/local_llm/groq/api.py +97 -0
  55. letta/local_llm/json_parser.py +202 -0
  56. letta/local_llm/koboldcpp/api.py +62 -0
  57. letta/local_llm/koboldcpp/settings.py +23 -0
  58. letta/local_llm/llamacpp/api.py +58 -0
  59. letta/local_llm/llamacpp/settings.py +22 -0
  60. letta/local_llm/llm_chat_completion_wrappers/__init__.py +0 -0
  61. letta/local_llm/llm_chat_completion_wrappers/airoboros.py +452 -0
  62. letta/local_llm/llm_chat_completion_wrappers/chatml.py +470 -0
  63. letta/local_llm/llm_chat_completion_wrappers/configurable_wrapper.py +387 -0
  64. letta/local_llm/llm_chat_completion_wrappers/dolphin.py +246 -0
  65. letta/local_llm/llm_chat_completion_wrappers/llama3.py +345 -0
  66. letta/local_llm/llm_chat_completion_wrappers/simple_summary_wrapper.py +156 -0
  67. letta/local_llm/llm_chat_completion_wrappers/wrapper_base.py +11 -0
  68. letta/local_llm/llm_chat_completion_wrappers/zephyr.py +345 -0
  69. letta/local_llm/lmstudio/api.py +100 -0
  70. letta/local_llm/lmstudio/settings.py +29 -0
  71. letta/local_llm/ollama/api.py +88 -0
  72. letta/local_llm/ollama/settings.py +32 -0
  73. letta/local_llm/settings/__init__.py +0 -0
  74. letta/local_llm/settings/deterministic_mirostat.py +45 -0
  75. letta/local_llm/settings/settings.py +72 -0
  76. letta/local_llm/settings/simple.py +28 -0
  77. letta/local_llm/utils.py +265 -0
  78. letta/local_llm/vllm/api.py +63 -0
  79. letta/local_llm/webui/api.py +60 -0
  80. letta/local_llm/webui/legacy_api.py +58 -0
  81. letta/local_llm/webui/legacy_settings.py +23 -0
  82. letta/local_llm/webui/settings.py +24 -0
  83. letta/log.py +76 -0
  84. letta/main.py +437 -0
  85. letta/memory.py +440 -0
  86. letta/metadata.py +884 -0
  87. letta/openai_backcompat/__init__.py +0 -0
  88. letta/openai_backcompat/openai_object.py +437 -0
  89. letta/persistence_manager.py +148 -0
  90. letta/personas/__init__.py +0 -0
  91. letta/personas/examples/anna_pa.txt +13 -0
  92. letta/personas/examples/google_search_persona.txt +15 -0
  93. letta/personas/examples/memgpt_doc.txt +6 -0
  94. letta/personas/examples/memgpt_starter.txt +4 -0
  95. letta/personas/examples/sam.txt +14 -0
  96. letta/personas/examples/sam_pov.txt +14 -0
  97. letta/personas/examples/sam_simple_pov_gpt35.txt +13 -0
  98. letta/personas/examples/sqldb/test.db +0 -0
  99. letta/prompts/__init__.py +0 -0
  100. letta/prompts/gpt_summarize.py +14 -0
  101. letta/prompts/gpt_system.py +26 -0
  102. letta/prompts/system/memgpt_base.txt +49 -0
  103. letta/prompts/system/memgpt_chat.txt +58 -0
  104. letta/prompts/system/memgpt_chat_compressed.txt +13 -0
  105. letta/prompts/system/memgpt_chat_fstring.txt +51 -0
  106. letta/prompts/system/memgpt_doc.txt +50 -0
  107. letta/prompts/system/memgpt_gpt35_extralong.txt +53 -0
  108. letta/prompts/system/memgpt_intuitive_knowledge.txt +31 -0
  109. letta/prompts/system/memgpt_modified_chat.txt +23 -0
  110. letta/pytest.ini +0 -0
  111. letta/schemas/agent.py +117 -0
  112. letta/schemas/api_key.py +21 -0
  113. letta/schemas/block.py +135 -0
  114. letta/schemas/document.py +21 -0
  115. letta/schemas/embedding_config.py +54 -0
  116. letta/schemas/enums.py +35 -0
  117. letta/schemas/job.py +38 -0
  118. letta/schemas/letta_base.py +80 -0
  119. letta/schemas/letta_message.py +175 -0
  120. letta/schemas/letta_request.py +23 -0
  121. letta/schemas/letta_response.py +28 -0
  122. letta/schemas/llm_config.py +54 -0
  123. letta/schemas/memory.py +224 -0
  124. letta/schemas/message.py +727 -0
  125. letta/schemas/openai/chat_completion_request.py +123 -0
  126. letta/schemas/openai/chat_completion_response.py +136 -0
  127. letta/schemas/openai/chat_completions.py +123 -0
  128. letta/schemas/openai/embedding_response.py +11 -0
  129. letta/schemas/openai/openai.py +157 -0
  130. letta/schemas/organization.py +20 -0
  131. letta/schemas/passage.py +80 -0
  132. letta/schemas/source.py +62 -0
  133. letta/schemas/tool.py +143 -0
  134. letta/schemas/usage.py +18 -0
  135. letta/schemas/user.py +33 -0
  136. letta/server/__init__.py +0 -0
  137. letta/server/constants.py +6 -0
  138. letta/server/rest_api/__init__.py +0 -0
  139. letta/server/rest_api/admin/__init__.py +0 -0
  140. letta/server/rest_api/admin/agents.py +21 -0
  141. letta/server/rest_api/admin/tools.py +83 -0
  142. letta/server/rest_api/admin/users.py +98 -0
  143. letta/server/rest_api/app.py +193 -0
  144. letta/server/rest_api/auth/__init__.py +0 -0
  145. letta/server/rest_api/auth/index.py +43 -0
  146. letta/server/rest_api/auth_token.py +22 -0
  147. letta/server/rest_api/interface.py +726 -0
  148. letta/server/rest_api/routers/__init__.py +0 -0
  149. letta/server/rest_api/routers/openai/__init__.py +0 -0
  150. letta/server/rest_api/routers/openai/assistants/__init__.py +0 -0
  151. letta/server/rest_api/routers/openai/assistants/assistants.py +115 -0
  152. letta/server/rest_api/routers/openai/assistants/schemas.py +121 -0
  153. letta/server/rest_api/routers/openai/assistants/threads.py +336 -0
  154. letta/server/rest_api/routers/openai/chat_completions/__init__.py +0 -0
  155. letta/server/rest_api/routers/openai/chat_completions/chat_completions.py +131 -0
  156. letta/server/rest_api/routers/v1/__init__.py +15 -0
  157. letta/server/rest_api/routers/v1/agents.py +543 -0
  158. letta/server/rest_api/routers/v1/blocks.py +73 -0
  159. letta/server/rest_api/routers/v1/jobs.py +46 -0
  160. letta/server/rest_api/routers/v1/llms.py +28 -0
  161. letta/server/rest_api/routers/v1/organizations.py +61 -0
  162. letta/server/rest_api/routers/v1/sources.py +199 -0
  163. letta/server/rest_api/routers/v1/tools.py +103 -0
  164. letta/server/rest_api/routers/v1/users.py +109 -0
  165. letta/server/rest_api/static_files.py +74 -0
  166. letta/server/rest_api/utils.py +69 -0
  167. letta/server/server.py +1995 -0
  168. letta/server/startup.sh +8 -0
  169. letta/server/static_files/assets/index-0cbf7ad5.js +274 -0
  170. letta/server/static_files/assets/index-156816da.css +1 -0
  171. letta/server/static_files/assets/index-486e3228.js +274 -0
  172. letta/server/static_files/favicon.ico +0 -0
  173. letta/server/static_files/index.html +39 -0
  174. letta/server/static_files/memgpt_logo_transparent.png +0 -0
  175. letta/server/utils.py +46 -0
  176. letta/server/ws_api/__init__.py +0 -0
  177. letta/server/ws_api/example_client.py +104 -0
  178. letta/server/ws_api/interface.py +108 -0
  179. letta/server/ws_api/protocol.py +100 -0
  180. letta/server/ws_api/server.py +145 -0
  181. letta/settings.py +165 -0
  182. letta/streaming_interface.py +396 -0
  183. letta/system.py +207 -0
  184. letta/utils.py +1065 -0
  185. letta_nightly-0.1.7.dev20240924104148.dist-info/LICENSE +190 -0
  186. letta_nightly-0.1.7.dev20240924104148.dist-info/METADATA +98 -0
  187. letta_nightly-0.1.7.dev20240924104148.dist-info/RECORD +189 -0
  188. letta_nightly-0.1.7.dev20240924104148.dist-info/WHEEL +4 -0
  189. letta_nightly-0.1.7.dev20240924104148.dist-info/entry_points.txt +3 -0
letta/memory.py ADDED
@@ -0,0 +1,440 @@
+ import datetime
+ from abc import ABC, abstractmethod
+ from typing import Callable, Dict, List, Tuple, Union
+
+ from letta.constants import MESSAGE_SUMMARY_REQUEST_ACK, MESSAGE_SUMMARY_WARNING_FRAC
+ from letta.embeddings import embedding_model, parse_and_chunk_text, query_embedding
+ from letta.llm_api.llm_api_tools import create
+ from letta.prompts.gpt_summarize import SYSTEM as SUMMARY_PROMPT_SYSTEM
+ from letta.schemas.agent import AgentState
+ from letta.schemas.memory import Memory
+ from letta.schemas.message import Message
+ from letta.schemas.passage import Passage
+ from letta.utils import (
+     count_tokens,
+     extract_date_from_timestamp,
+     get_local_time,
+     printd,
+     validate_date_format,
+ )
+
+
+ def get_memory_functions(cls: Memory) -> Dict[str, Callable]:
+     """Get memory functions for a memory class"""
+     functions = {}
+
+     # collect base memory functions (should not be included)
+     base_functions = []
+     for func_name in dir(Memory):
+         funct = getattr(Memory, func_name)
+         if callable(funct):
+             base_functions.append(func_name)
+
+     for func_name in dir(cls):
+         if func_name.startswith("_") or func_name in ["load", "to_dict"]:  # skip base functions
+             continue
+         if func_name in base_functions:  # dont use BaseMemory functions
+             continue
+         func = getattr(cls, func_name)
+         if not callable(func):  # not a function
+             continue
+         functions[func_name] = func
+     return functions
+
+
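A minimal usage sketch for get_memory_functions, assuming the base Memory class can be subclassed this way (MyTaskMemory and task_push are hypothetical and not part of the package):

from letta.schemas.memory import Memory

class MyTaskMemory(Memory):
    def task_push(self, description: str):
        """Hypothetical custom memory-editing function."""
        ...

funcs = get_memory_functions(MyTaskMemory)
# -> {"task_push": MyTaskMemory.task_push}; inherited Memory methods and any
#    name starting with "_" are filtered out by the logic above.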
+ def _format_summary_history(message_history: List[Message]):
+     # TODO use existing prompt formatters for this (eg ChatML)
+     return "\n".join([f"{m.role}: {m.text}" for m in message_history])
+
+
+ def summarize_messages(
+     agent_state: AgentState,
+     message_sequence_to_summarize: List[Message],
+     insert_acknowledgement_assistant_message: bool = True,
+ ):
+     """Summarize a message sequence using GPT"""
+     # we need the context_window
+     context_window = agent_state.llm_config.context_window
+
+     summary_prompt = SUMMARY_PROMPT_SYSTEM
+     summary_input = _format_summary_history(message_sequence_to_summarize)
+     summary_input_tkns = count_tokens(summary_input)
+     if summary_input_tkns > MESSAGE_SUMMARY_WARNING_FRAC * context_window:
+         trunc_ratio = (MESSAGE_SUMMARY_WARNING_FRAC * context_window / summary_input_tkns) * 0.8  # For good measure...
+         cutoff = int(len(message_sequence_to_summarize) * trunc_ratio)
+         summary_input = str(
+             [summarize_messages(agent_state, message_sequence_to_summarize=message_sequence_to_summarize[:cutoff])]
+             + message_sequence_to_summarize[cutoff:]
+         )
+
+     dummy_user_id = agent_state.user_id
+     dummy_agent_id = agent_state.id
+     message_sequence = []
+     message_sequence.append(Message(user_id=dummy_user_id, agent_id=dummy_agent_id, role="system", text=summary_prompt))
+     if insert_acknowledgement_assistant_message:
+         message_sequence.append(Message(user_id=dummy_user_id, agent_id=dummy_agent_id, role="assistant", text=MESSAGE_SUMMARY_REQUEST_ACK))
+     message_sequence.append(Message(user_id=dummy_user_id, agent_id=dummy_agent_id, role="user", text=summary_input))
+
+     response = create(
+         llm_config=agent_state.llm_config,
+         user_id=agent_state.user_id,
+         messages=message_sequence,
+         stream=False,
+     )
+
+     printd(f"summarize_messages gpt reply: {response.choices[0]}")
+     reply = response.choices[0].message.content
+     return reply
+
+
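As a usage sketch (illustrative only; compact_history and the three-message tail are assumptions, and agent_state is taken to be an already-configured AgentState), summarize_messages can be used to compact an overflowing buffer and splice the summary back in as a regular Message:

from letta.schemas.message import Message

def compact_history(agent_state, messages):
    # summarize everything except the most recent messages
    to_summarize, tail = messages[:-3], messages[-3:]
    summary_text = summarize_messages(agent_state, message_sequence_to_summarize=to_summarize)
    # wrap the summary the same way summarize_messages builds its own Message objects
    summary_msg = Message(user_id=agent_state.user_id, agent_id=agent_state.id, role="user", text=summary_text)
    return [summary_msg] + tail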
+ class ArchivalMemory(ABC):
+     @abstractmethod
+     def insert(self, memory_string: str):
+         """Insert new archival memory
+
+         :param memory_string: Memory string to insert
+         :type memory_string: str
+         """
+
+     @abstractmethod
+     def search(self, query_string, count=None, start=None) -> Tuple[List[str], int]:
+         """Search archival memory
+
+         :param query_string: Query string
+         :type query_string: str
+         :param count: Number of results to return (None for all)
+         :type count: Optional[int]
+         :param start: Offset to start returning results from (None if 0)
+         :type start: Optional[int]
+
+         :return: Tuple of (list of results, total number of results)
+         """
+
+     @abstractmethod
+     def compile(self) -> str:
+         """Convert archival memory into a string representation for a prompt"""
+
+     @abstractmethod
+     def count(self) -> int:
+         """Count the number of memories in the archival memory"""
+
+
+ class RecallMemory(ABC):
+     @abstractmethod
+     def text_search(self, query_string, count=None, start=None):
+         """Search messages that match query_string in recall memory"""
+
+     @abstractmethod
+     def date_search(self, start_date, end_date, count=None, start=None):
+         """Search messages between start_date and end_date in recall memory"""
+
+     @abstractmethod
+     def compile(self) -> str:
+         """Convert recall memory into a string representation for a prompt"""
+
+     @abstractmethod
+     def count(self) -> int:
+         """Count the number of memories in the recall memory"""
+
+     @abstractmethod
+     def insert(self, message: Message):
+         """Insert message into recall memory"""
+
+
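For illustration only (ListArchivalMemory is hypothetical and not part of the package), the smallest concrete ArchivalMemory satisfying the interface above could use plain substring matching instead of embeddings:

from typing import List, Tuple

class ListArchivalMemory(ArchivalMemory):
    def __init__(self):
        self._memories: List[str] = []

    def insert(self, memory_string: str):
        self._memories.append(memory_string)

    def search(self, query_string, count=None, start=None) -> Tuple[List[str], int]:
        # case-insensitive substring match, with the same start/count paging semantics
        matches = [m for m in self._memories if query_string.lower() in m.lower()]
        start = 0 if start is None else int(start)
        end = len(matches) if count is None else start + int(count)
        return matches[start:end], len(matches)

    def compile(self) -> str:
        return "\n".join(self._memories)

    def count(self) -> int:
        return len(self._memories)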
+ class DummyRecallMemory(RecallMemory):
+     """Dummy in-memory version of a recall memory database (eg run on MongoDB)
+
+     Recall memory here is basically just a full conversation history with the user.
+     Queryable via string matching, or date matching.
+
+     Recall Memory: The AI's capability to search through past interactions,
+     effectively allowing it to 'remember' prior engagements with a user.
+     """
+
+     def __init__(self, message_database=None, restrict_search_to_summaries=False):
+         self._message_logs = [] if message_database is None else message_database  # consists of full message dicts
+
+         # If true, the pool of messages that can be queried are the automated summaries only
+         # (generated when the conversation window needs to be shortened)
+         self.restrict_search_to_summaries = restrict_search_to_summaries
+
+     def __len__(self):
+         return len(self._message_logs)
+
+     def count(self) -> int:
+         return len(self)
+
+     def compile(self) -> str:
+         # don't dump all the conversations, just statistics
+         system_count = user_count = assistant_count = function_count = other_count = 0
+         for msg in self._message_logs:
+             role = msg["message"]["role"]
+             if role == "system":
+                 system_count += 1
+             elif role == "user":
+                 user_count += 1
+             elif role == "assistant":
+                 assistant_count += 1
+             elif role == "function":
+                 function_count += 1
+             else:
+                 other_count += 1
+         memory_str = (
+             f"Statistics:"
+             + f"\n{len(self._message_logs)} total messages"
+             + f"\n{system_count} system"
+             + f"\n{user_count} user"
+             + f"\n{assistant_count} assistant"
+             + f"\n{function_count} function"
+             + f"\n{other_count} other"
+         )
+         return f"\n### RECALL MEMORY ###" + f"\n{memory_str}"
+
+     def insert(self, message):
+         raise NotImplementedError("This should be handled by the PersistenceManager, recall memory is just a search layer on top")
+
+     def text_search(self, query_string, count=None, start=None):
+         # in the dummy version, run an (inefficient) case-insensitive match search
+         message_pool = [d for d in self._message_logs if d["message"]["role"] not in ["system", "function"]]
+         start = 0 if start is None else int(start)
+         count = 0 if count is None else int(count)
+
+         printd(
+             f"recall_memory.text_search: searching for {query_string} (c={count}, s={start}) in {len(self._message_logs)} total messages"
+         )
+         matches = [
+             d for d in message_pool if d["message"]["content"] is not None and query_string.lower() in d["message"]["content"].lower()
+         ]
+         printd(f"recall_memory - matches:\n{matches[start:start+count]}")
+
+         # start/count support paging through results
+         if start is not None and count is not None:
+             return matches[start : start + count], len(matches)
+         elif start is None and count is not None:
+             return matches[:count], len(matches)
+         elif start is not None and count is None:
+             return matches[start:], len(matches)
+         else:
+             return matches, len(matches)
+
+     def date_search(self, start_date, end_date, count=None, start=None):
+         message_pool = [d for d in self._message_logs if d["message"]["role"] not in ["system", "function"]]
+
+         # First, validate the start_date and end_date format
+         if not validate_date_format(start_date) or not validate_date_format(end_date):
+             raise ValueError("Invalid date format. Expected format: YYYY-MM-DD")
+
+         # Convert dates to datetime objects for comparison
+         start_date_dt = datetime.datetime.strptime(start_date, "%Y-%m-%d")
+         end_date_dt = datetime.datetime.strptime(end_date, "%Y-%m-%d")
+
+         # Next, match items inside self._message_logs
+         matches = [
+             d
+             for d in message_pool
+             if start_date_dt <= datetime.datetime.strptime(extract_date_from_timestamp(d["timestamp"]), "%Y-%m-%d") <= end_date_dt
+         ]
+
+         # start/count support paging through results
+         start = 0 if start is None else int(start)
+         count = 0 if count is None else int(count)
+         if start is not None and count is not None:
+             return matches[start : start + count], len(matches)
+         elif start is None and count is not None:
+             return matches[:count], len(matches)
+         elif start is not None and count is None:
+             return matches[start:], len(matches)
+         else:
+             return matches, len(matches)
+
+
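A sketch of the message-dict shape DummyRecallMemory reads in text_search and date_search (the contents below are invented, and the timestamp format is assumed to match letta.utils.get_local_time):

logs = [
    {"timestamp": "2024-09-23 09:15:00 AM PDT-0700", "message": {"role": "user", "content": "remind me about the demo"}},
    {"timestamp": "2024-09-23 09:15:04 AM PDT-0700", "message": {"role": "assistant", "content": "Noted: demo reminder."}},
]
recall = DummyRecallMemory(message_database=logs)
matches, total = recall.text_search("demo", count=10, start=0)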
+ class BaseRecallMemory(RecallMemory):
+     """Recall memory based on base functions implemented by storage connectors"""
+
+     def __init__(self, agent_state, restrict_search_to_summaries=False):
+         # If true, the pool of messages that can be queried are the automated summaries only
+         # (generated when the conversation window needs to be shortened)
+         self.restrict_search_to_summaries = restrict_search_to_summaries
+         from letta.agent_store.storage import StorageConnector
+
+         self.agent_state = agent_state
+
+         # create embedding model
+         self.embed_model = embedding_model(agent_state.embedding_config)
+         self.embedding_chunk_size = agent_state.embedding_config.embedding_chunk_size
+
+         # create storage backend
+         self.storage = StorageConnector.get_recall_storage_connector(user_id=agent_state.user_id, agent_id=agent_state.id)
+         # TODO: have some mechanism for cleanup otherwise will lead to OOM
+         self.cache = {}
+
+     def get_all(self, start=0, count=None):
+         start = 0 if start is None else int(start)
+         count = 0 if count is None else int(count)
+         results = self.storage.get_all(start, count)
+         results_json = [message.to_openai_dict() for message in results]
+         return results_json, len(results)
+
+     def text_search(self, query_string, count=None, start=None):
+         start = 0 if start is None else int(start)
+         count = 0 if count is None else int(count)
+         results = self.storage.query_text(query_string, count, start)
+         results_json = [message.to_openai_dict_search_results() for message in results]
+         return results_json, len(results)
+
+     def date_search(self, start_date, end_date, count=None, start=None):
+         start = 0 if start is None else int(start)
+         count = 0 if count is None else int(count)
+         results = self.storage.query_date(start_date, end_date, count, start)
+         results_json = [message.to_openai_dict_search_results() for message in results]
+         return results_json, len(results)
+
+     def compile(self) -> str:
+         total = self.storage.size()
+         system_count = self.storage.size(filters={"role": "system"})
+         user_count = self.storage.size(filters={"role": "user"})
+         assistant_count = self.storage.size(filters={"role": "assistant"})
+         function_count = self.storage.size(filters={"role": "function"})
+         other_count = total - (system_count + user_count + assistant_count + function_count)
+
+         memory_str = (
+             f"Statistics:"
+             + f"\n{total} total messages"
+             + f"\n{system_count} system"
+             + f"\n{user_count} user"
+             + f"\n{assistant_count} assistant"
+             + f"\n{function_count} function"
+             + f"\n{other_count} other"
+         )
+         return f"\n### RECALL MEMORY ###" + f"\n{memory_str}"
+
+     def insert(self, message: Message):
+         self.storage.insert(message)
+
+     def insert_many(self, messages: List[Message]):
+         self.storage.insert_many(messages)
+
+     def save(self):
+         self.storage.save()
+
+     def __len__(self):
+         return self.storage.size()
+
+     def count(self) -> int:
+         return len(self)
+
+
+ class EmbeddingArchivalMemory(ArchivalMemory):
+     """Archival memory with embedding based search"""
+
+     def __init__(self, agent_state: AgentState, top_k: int = 100):
+         """Init function for archival memory
+
+         :param archival_memory_database: name of dataset to pre-fill archival with
+         :type archival_memory_database: str
+         """
+         from letta.agent_store.storage import StorageConnector
+
+         self.top_k = top_k
+         self.agent_state = agent_state
+
+         # create embedding model
+         self.embed_model = embedding_model(agent_state.embedding_config)
+         if agent_state.embedding_config.embedding_chunk_size is None:
+             raise ValueError(f"Must set {agent_state.embedding_config.embedding_chunk_size}")
+         else:
+             self.embedding_chunk_size = agent_state.embedding_config.embedding_chunk_size
+
+         # create storage backend
+         self.storage = StorageConnector.get_archival_storage_connector(user_id=agent_state.user_id, agent_id=agent_state.id)
+         # TODO: have some mechanism for cleanup otherwise will lead to OOM
+         self.cache = {}
+
+     def create_passage(self, text, embedding):
+         return Passage(
+             user_id=self.agent_state.user_id,
+             agent_id=self.agent_state.id,
+             text=text,
+             embedding=embedding,
+             embedding_config=self.agent_state.embedding_config,
+         )
+
+     def save(self):
+         """Save the index to disk"""
+         self.storage.save()
+
+     def insert(self, memory_string, return_ids=False) -> Union[bool, List[str]]:
+         """Embed and save memory string"""
+
+         if not isinstance(memory_string, str):
+             raise TypeError("memory must be a string")
+
+         try:
+             passages = []
+
+             # breakup string into passages
+             for text in parse_and_chunk_text(memory_string, self.embedding_chunk_size):
+                 embedding = self.embed_model.get_text_embedding(text)
+                 # fixing weird bug where type returned isn't a list, but instead is an object
+                 # eg: embedding={'object': 'list', 'data': [{'object': 'embedding', 'embedding': [-0.0071973633, -0.07893023,
+                 if isinstance(embedding, dict):
+                     try:
+                         embedding = embedding["data"][0]["embedding"]
+                     except (KeyError, IndexError):
+                         # TODO as a fallback, see if we can find any lists in the payload
+                         raise TypeError(
+                             f"Got back an unexpected payload from text embedding function, type={type(embedding)}, value={embedding}"
+                         )
+                 passages.append(self.create_passage(text, embedding))
+
+             # grab the return IDs before the list gets modified
+             ids = [str(p.id) for p in passages]
+
+             # insert passages
+             self.storage.insert_many(passages)
+
+             if return_ids:
+                 return ids
+             else:
+                 return True
+
+         except Exception as e:
+             print("Archival insert error", e)
+             raise e
+
+     def search(self, query_string, count=None, start=None):
+         """Search query string"""
+         start = 0 if start is None else int(start)
+         count = self.top_k if count is None else int(count)
+
+         if not isinstance(query_string, str):
+             return TypeError("query must be a string")
+
+         try:
+             if query_string not in self.cache:
+                 # self.cache[query_string] = self.retriever.retrieve(query_string)
+                 query_vec = query_embedding(self.embed_model, query_string)
+                 self.cache[query_string] = self.storage.query(query_string, query_vec, top_k=self.top_k)
+
+             end = min(count + start, len(self.cache[query_string]))
+
+             results = self.cache[query_string][start:end]
+             results = [{"timestamp": get_local_time(), "content": node.text} for node in results]
+             return results, len(results)
+         except Exception as e:
+             print("Archival search error", e)
+             raise e
+
+     def compile(self) -> str:
+         limit = 10
+         passages = []
+         for passage in list(self.storage.get_all(limit=limit)):  # TODO: only get first 10
+             passages.append(str(passage.text))
+         memory_str = "\n".join(passages)
+         return f"\n### ARCHIVAL MEMORY ###" + f"\n{memory_str}" + f"\nSize: {self.storage.size()}"
+
+     def __len__(self):
+         return self.storage.size()
+
+     def count(self) -> int:
+         return len(self)
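Finally, a usage sketch for EmbeddingArchivalMemory (assumes agent_state is a configured AgentState whose embedding_config and archival storage backend are set up; the strings are illustrative):

archival = EmbeddingArchivalMemory(agent_state, top_k=20)
archival.insert("The user prefers metric units.")
results, total = archival.search("units", count=5)
for hit in results:
    print(hit["timestamp"], hit["content"])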