letta-nightly 0.1.7.dev20240924104148__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of letta-nightly might be problematic. Click here for more details.
- letta/__init__.py +24 -0
- letta/__main__.py +3 -0
- letta/agent.py +1427 -0
- letta/agent_store/chroma.py +295 -0
- letta/agent_store/db.py +546 -0
- letta/agent_store/lancedb.py +177 -0
- letta/agent_store/milvus.py +198 -0
- letta/agent_store/qdrant.py +201 -0
- letta/agent_store/storage.py +188 -0
- letta/benchmark/benchmark.py +96 -0
- letta/benchmark/constants.py +14 -0
- letta/cli/cli.py +689 -0
- letta/cli/cli_config.py +1282 -0
- letta/cli/cli_load.py +166 -0
- letta/client/__init__.py +0 -0
- letta/client/admin.py +171 -0
- letta/client/client.py +2360 -0
- letta/client/streaming.py +90 -0
- letta/client/utils.py +61 -0
- letta/config.py +484 -0
- letta/configs/anthropic.json +13 -0
- letta/configs/letta_hosted.json +11 -0
- letta/configs/openai.json +12 -0
- letta/constants.py +134 -0
- letta/credentials.py +140 -0
- letta/data_sources/connectors.py +247 -0
- letta/embeddings.py +218 -0
- letta/errors.py +26 -0
- letta/functions/__init__.py +0 -0
- letta/functions/function_sets/base.py +174 -0
- letta/functions/function_sets/extras.py +132 -0
- letta/functions/functions.py +105 -0
- letta/functions/schema_generator.py +205 -0
- letta/humans/__init__.py +0 -0
- letta/humans/examples/basic.txt +1 -0
- letta/humans/examples/cs_phd.txt +9 -0
- letta/interface.py +314 -0
- letta/llm_api/__init__.py +0 -0
- letta/llm_api/anthropic.py +383 -0
- letta/llm_api/azure_openai.py +155 -0
- letta/llm_api/cohere.py +396 -0
- letta/llm_api/google_ai.py +468 -0
- letta/llm_api/llm_api_tools.py +485 -0
- letta/llm_api/openai.py +470 -0
- letta/local_llm/README.md +3 -0
- letta/local_llm/__init__.py +0 -0
- letta/local_llm/chat_completion_proxy.py +279 -0
- letta/local_llm/constants.py +31 -0
- letta/local_llm/function_parser.py +68 -0
- letta/local_llm/grammars/__init__.py +0 -0
- letta/local_llm/grammars/gbnf_grammar_generator.py +1324 -0
- letta/local_llm/grammars/json.gbnf +26 -0
- letta/local_llm/grammars/json_func_calls_with_inner_thoughts.gbnf +32 -0
- letta/local_llm/groq/api.py +97 -0
- letta/local_llm/json_parser.py +202 -0
- letta/local_llm/koboldcpp/api.py +62 -0
- letta/local_llm/koboldcpp/settings.py +23 -0
- letta/local_llm/llamacpp/api.py +58 -0
- letta/local_llm/llamacpp/settings.py +22 -0
- letta/local_llm/llm_chat_completion_wrappers/__init__.py +0 -0
- letta/local_llm/llm_chat_completion_wrappers/airoboros.py +452 -0
- letta/local_llm/llm_chat_completion_wrappers/chatml.py +470 -0
- letta/local_llm/llm_chat_completion_wrappers/configurable_wrapper.py +387 -0
- letta/local_llm/llm_chat_completion_wrappers/dolphin.py +246 -0
- letta/local_llm/llm_chat_completion_wrappers/llama3.py +345 -0
- letta/local_llm/llm_chat_completion_wrappers/simple_summary_wrapper.py +156 -0
- letta/local_llm/llm_chat_completion_wrappers/wrapper_base.py +11 -0
- letta/local_llm/llm_chat_completion_wrappers/zephyr.py +345 -0
- letta/local_llm/lmstudio/api.py +100 -0
- letta/local_llm/lmstudio/settings.py +29 -0
- letta/local_llm/ollama/api.py +88 -0
- letta/local_llm/ollama/settings.py +32 -0
- letta/local_llm/settings/__init__.py +0 -0
- letta/local_llm/settings/deterministic_mirostat.py +45 -0
- letta/local_llm/settings/settings.py +72 -0
- letta/local_llm/settings/simple.py +28 -0
- letta/local_llm/utils.py +265 -0
- letta/local_llm/vllm/api.py +63 -0
- letta/local_llm/webui/api.py +60 -0
- letta/local_llm/webui/legacy_api.py +58 -0
- letta/local_llm/webui/legacy_settings.py +23 -0
- letta/local_llm/webui/settings.py +24 -0
- letta/log.py +76 -0
- letta/main.py +437 -0
- letta/memory.py +440 -0
- letta/metadata.py +884 -0
- letta/openai_backcompat/__init__.py +0 -0
- letta/openai_backcompat/openai_object.py +437 -0
- letta/persistence_manager.py +148 -0
- letta/personas/__init__.py +0 -0
- letta/personas/examples/anna_pa.txt +13 -0
- letta/personas/examples/google_search_persona.txt +15 -0
- letta/personas/examples/memgpt_doc.txt +6 -0
- letta/personas/examples/memgpt_starter.txt +4 -0
- letta/personas/examples/sam.txt +14 -0
- letta/personas/examples/sam_pov.txt +14 -0
- letta/personas/examples/sam_simple_pov_gpt35.txt +13 -0
- letta/personas/examples/sqldb/test.db +0 -0
- letta/prompts/__init__.py +0 -0
- letta/prompts/gpt_summarize.py +14 -0
- letta/prompts/gpt_system.py +26 -0
- letta/prompts/system/memgpt_base.txt +49 -0
- letta/prompts/system/memgpt_chat.txt +58 -0
- letta/prompts/system/memgpt_chat_compressed.txt +13 -0
- letta/prompts/system/memgpt_chat_fstring.txt +51 -0
- letta/prompts/system/memgpt_doc.txt +50 -0
- letta/prompts/system/memgpt_gpt35_extralong.txt +53 -0
- letta/prompts/system/memgpt_intuitive_knowledge.txt +31 -0
- letta/prompts/system/memgpt_modified_chat.txt +23 -0
- letta/pytest.ini +0 -0
- letta/schemas/agent.py +117 -0
- letta/schemas/api_key.py +21 -0
- letta/schemas/block.py +135 -0
- letta/schemas/document.py +21 -0
- letta/schemas/embedding_config.py +54 -0
- letta/schemas/enums.py +35 -0
- letta/schemas/job.py +38 -0
- letta/schemas/letta_base.py +80 -0
- letta/schemas/letta_message.py +175 -0
- letta/schemas/letta_request.py +23 -0
- letta/schemas/letta_response.py +28 -0
- letta/schemas/llm_config.py +54 -0
- letta/schemas/memory.py +224 -0
- letta/schemas/message.py +727 -0
- letta/schemas/openai/chat_completion_request.py +123 -0
- letta/schemas/openai/chat_completion_response.py +136 -0
- letta/schemas/openai/chat_completions.py +123 -0
- letta/schemas/openai/embedding_response.py +11 -0
- letta/schemas/openai/openai.py +157 -0
- letta/schemas/organization.py +20 -0
- letta/schemas/passage.py +80 -0
- letta/schemas/source.py +62 -0
- letta/schemas/tool.py +143 -0
- letta/schemas/usage.py +18 -0
- letta/schemas/user.py +33 -0
- letta/server/__init__.py +0 -0
- letta/server/constants.py +6 -0
- letta/server/rest_api/__init__.py +0 -0
- letta/server/rest_api/admin/__init__.py +0 -0
- letta/server/rest_api/admin/agents.py +21 -0
- letta/server/rest_api/admin/tools.py +83 -0
- letta/server/rest_api/admin/users.py +98 -0
- letta/server/rest_api/app.py +193 -0
- letta/server/rest_api/auth/__init__.py +0 -0
- letta/server/rest_api/auth/index.py +43 -0
- letta/server/rest_api/auth_token.py +22 -0
- letta/server/rest_api/interface.py +726 -0
- letta/server/rest_api/routers/__init__.py +0 -0
- letta/server/rest_api/routers/openai/__init__.py +0 -0
- letta/server/rest_api/routers/openai/assistants/__init__.py +0 -0
- letta/server/rest_api/routers/openai/assistants/assistants.py +115 -0
- letta/server/rest_api/routers/openai/assistants/schemas.py +121 -0
- letta/server/rest_api/routers/openai/assistants/threads.py +336 -0
- letta/server/rest_api/routers/openai/chat_completions/__init__.py +0 -0
- letta/server/rest_api/routers/openai/chat_completions/chat_completions.py +131 -0
- letta/server/rest_api/routers/v1/__init__.py +15 -0
- letta/server/rest_api/routers/v1/agents.py +543 -0
- letta/server/rest_api/routers/v1/blocks.py +73 -0
- letta/server/rest_api/routers/v1/jobs.py +46 -0
- letta/server/rest_api/routers/v1/llms.py +28 -0
- letta/server/rest_api/routers/v1/organizations.py +61 -0
- letta/server/rest_api/routers/v1/sources.py +199 -0
- letta/server/rest_api/routers/v1/tools.py +103 -0
- letta/server/rest_api/routers/v1/users.py +109 -0
- letta/server/rest_api/static_files.py +74 -0
- letta/server/rest_api/utils.py +69 -0
- letta/server/server.py +1995 -0
- letta/server/startup.sh +8 -0
- letta/server/static_files/assets/index-0cbf7ad5.js +274 -0
- letta/server/static_files/assets/index-156816da.css +1 -0
- letta/server/static_files/assets/index-486e3228.js +274 -0
- letta/server/static_files/favicon.ico +0 -0
- letta/server/static_files/index.html +39 -0
- letta/server/static_files/memgpt_logo_transparent.png +0 -0
- letta/server/utils.py +46 -0
- letta/server/ws_api/__init__.py +0 -0
- letta/server/ws_api/example_client.py +104 -0
- letta/server/ws_api/interface.py +108 -0
- letta/server/ws_api/protocol.py +100 -0
- letta/server/ws_api/server.py +145 -0
- letta/settings.py +165 -0
- letta/streaming_interface.py +396 -0
- letta/system.py +207 -0
- letta/utils.py +1065 -0
- letta_nightly-0.1.7.dev20240924104148.dist-info/LICENSE +190 -0
- letta_nightly-0.1.7.dev20240924104148.dist-info/METADATA +98 -0
- letta_nightly-0.1.7.dev20240924104148.dist-info/RECORD +189 -0
- letta_nightly-0.1.7.dev20240924104148.dist-info/WHEEL +4 -0
- letta_nightly-0.1.7.dev20240924104148.dist-info/entry_points.txt +3 -0
letta/memory.py
ADDED
|
@@ -0,0 +1,440 @@
|
|
|
1
|
+
import datetime
|
|
2
|
+
from abc import ABC, abstractmethod
|
|
3
|
+
from typing import Callable, Dict, List, Tuple, Union
|
|
4
|
+
|
|
5
|
+
from letta.constants import MESSAGE_SUMMARY_REQUEST_ACK, MESSAGE_SUMMARY_WARNING_FRAC
|
|
6
|
+
from letta.embeddings import embedding_model, parse_and_chunk_text, query_embedding
|
|
7
|
+
from letta.llm_api.llm_api_tools import create
|
|
8
|
+
from letta.prompts.gpt_summarize import SYSTEM as SUMMARY_PROMPT_SYSTEM
|
|
9
|
+
from letta.schemas.agent import AgentState
|
|
10
|
+
from letta.schemas.memory import Memory
|
|
11
|
+
from letta.schemas.message import Message
|
|
12
|
+
from letta.schemas.passage import Passage
|
|
13
|
+
from letta.utils import (
|
|
14
|
+
count_tokens,
|
|
15
|
+
extract_date_from_timestamp,
|
|
16
|
+
get_local_time,
|
|
17
|
+
printd,
|
|
18
|
+
validate_date_format,
|
|
19
|
+
)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def get_memory_functions(cls: Memory) -> Dict[str, Callable]:
    """Get memory functions for a memory class.

    Walks ``dir(cls)`` and keeps every callable attribute except:
    names starting with an underscore, ``load`` and ``to_dict``, and any name that is
    also a callable attribute of the base ``Memory`` class.

    :param cls: Object to inspect (named ``cls``, so presumably a ``Memory`` subclass).
    :return: Mapping of attribute name to the callable fetched from ``cls``.
    """
    # Names of callables the base Memory class already provides; these are never exposed.
    inherited = {name for name in dir(Memory) if callable(getattr(Memory, name))}

    found = {}
    for name in dir(cls):
        # skip private names and the explicit base helpers
        if name.startswith("_") or name in ["load", "to_dict"]:
            continue
        # dont use BaseMemory functions
        if name in inherited:
            continue
        candidate = getattr(cls, name)
        if callable(candidate):
            found[name] = candidate
    return found
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def _format_summary_history(message_history: List[Message]):
    """Render messages as newline-separated ``"<role>: <text>"`` lines for the summarizer."""
    # TODO use existing prompt formatters for this (eg ChatML)
    rendered = (f"{msg.role}: {msg.text}" for msg in message_history)
    return "\n".join(rendered)
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def summarize_messages(
    agent_state: AgentState,
    message_sequence_to_summarize: List[Message],
    insert_acknowledgement_assistant_message: bool = True,
):
    """Summarize a message sequence using the agent's configured LLM.

    If the formatted history is longer than ``MESSAGE_SUMMARY_WARNING_FRAC`` of the
    model's context window, a leading slice of the messages is summarized first
    (recursively). That partial summary is then prepended to the remaining messages,
    which are rendered in the same ``"role: text"`` format.

    :param agent_state: Agent whose ``llm_config`` (including ``context_window``),
        ``user_id`` and ``id`` are used to build and send the request.
    :param message_sequence_to_summarize: Messages to summarize; the leading slice is
        the part that gets pre-summarized when the input is too long.
    :param insert_acknowledgement_assistant_message: If True, an assistant message with
        ``MESSAGE_SUMMARY_REQUEST_ACK`` is placed between the system prompt and the input.
    :return: The content of the first choice of the LLM response.
    """
    # we need the context_window
    context_window = agent_state.llm_config.context_window

    summary_prompt = SUMMARY_PROMPT_SYSTEM
    summary_input = _format_summary_history(message_sequence_to_summarize)
    summary_input_tkns = count_tokens(summary_input)
    if summary_input_tkns > MESSAGE_SUMMARY_WARNING_FRAC * context_window:
        # Scale the share of messages to pre-summarize to what fits, with a 0.8 safety factor.
        trunc_ratio = (MESSAGE_SUMMARY_WARNING_FRAC * context_window / summary_input_tkns) * 0.8  # For good measure...
        cutoff = int(len(message_sequence_to_summarize) * trunc_ratio)
        remaining_history = _format_summary_history(message_sequence_to_summarize[cutoff:])
        if cutoff > 0:
            partial_summary = summarize_messages(
                agent_state,
                message_sequence_to_summarize=message_sequence_to_summarize[:cutoff],
                insert_acknowledgement_assistant_message=insert_acknowledgement_assistant_message,
            )
            # Keep the prompt in "role: text" form instead of the Python repr of a list
            # of Message objects (which is what str([...] + messages) produced).
            summary_input = f"Summary of earlier messages: {partial_summary}\n{remaining_history}"
        else:
            # Nothing to pre-summarize (few but very long messages): skip an LLM call
            # on an empty history.
            summary_input = remaining_history

    dummy_user_id = agent_state.user_id
    dummy_agent_id = agent_state.id
    message_sequence = [Message(user_id=dummy_user_id, agent_id=dummy_agent_id, role="system", text=summary_prompt)]
    if insert_acknowledgement_assistant_message:
        message_sequence.append(Message(user_id=dummy_user_id, agent_id=dummy_agent_id, role="assistant", text=MESSAGE_SUMMARY_REQUEST_ACK))
    message_sequence.append(Message(user_id=dummy_user_id, agent_id=dummy_agent_id, role="user", text=summary_input))

    response = create(
        llm_config=agent_state.llm_config,
        user_id=agent_state.user_id,
        messages=message_sequence,
        stream=False,
    )

    printd(f"summarize_messages gpt reply: {response.choices[0]}")
    reply = response.choices[0].message.content
    return reply
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
class ArchivalMemory(ABC):
    """Abstract interface implemented by archival memory backends.

    Subclasses must provide ``insert``, ``search``, ``compile`` and ``count``.
    """

    @abstractmethod
    def insert(self, memory_string: str):
        """Store ``memory_string`` in archival memory.

        :param memory_string: Text to store
        :type memory_string: str
        """
        ...

    @abstractmethod
    def search(self, query_string, count=None, start=None) -> Tuple[List[str], int]:
        """Look up archival memory entries relevant to ``query_string``.

        :param query_string: Text to search for
        :type query_string: str
        :param count: Maximum number of results to return (None for all)
        :type count: Optional[int]
        :param start: Offset of the first result to return (None if 0)
        :type start: Optional[int]

        :return: Tuple of (page of results, total number of results).
            NOTE(review): the annotation says ``List[str]``, but implementations in this
            module return dicts — confirm the intended element type.
        """
        ...

    @abstractmethod
    def compile(self) -> str:
        """Render archival memory as a string suitable for inclusion in a prompt."""
        ...

    @abstractmethod
    def count(self) -> int:
        """Return how many memories archival memory holds."""
        ...
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
class RecallMemory(ABC):
    """Abstract interface for recall memory (searchable message history).

    Subclasses must provide ``text_search``, ``date_search``, ``compile``, ``count``
    and ``insert``.
    """

    @abstractmethod
    def text_search(self, query_string, count=None, start=None):
        """Find recall-memory messages matching ``query_string``; ``count``/``start`` page the results."""
        ...

    @abstractmethod
    def date_search(self, start_date, end_date, count=None, start=None):
        """Find recall-memory messages dated between ``start_date`` and ``end_date``."""
        ...

    @abstractmethod
    def compile(self) -> str:
        """Render recall memory as a string suitable for inclusion in a prompt."""
        ...

    @abstractmethod
    def count(self) -> int:
        """Return how many messages recall memory holds."""
        ...

    @abstractmethod
    def insert(self, message: Message):
        """Add ``message`` to recall memory."""
        ...
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
class DummyRecallMemory(RecallMemory):
    """Dummy in-memory version of a recall memory database (eg run on MongoDB)

    Recall memory here is basically just a full conversation history with the user.
    Queryable via string matching, or date matching.

    Recall Memory: The AI's capability to search through past interactions,
    effectively allowing it to 'remember' prior engagements with a user.

    Each entry of ``message_database`` is read as a dict holding a ``"message"`` dict
    (with ``"role"`` and ``"content"`` keys) and a ``"timestamp"``. System and function
    messages are excluded from searches. Searches return ``(page, total_matches)``:
    ``start`` is the page offset, and ``count=None`` means "no limit".
    """

    def __init__(self, message_database=None, restrict_search_to_summaries=False):
        self._message_logs = [] if message_database is None else message_database  # consists of full message dicts

        # If true, the pool of messages that can be queried are the automated summaries only
        # (generated when the conversation window needs to be shortened)
        # NOTE(review): this flag is stored but not consulted by the search methods below.
        self.restrict_search_to_summaries = restrict_search_to_summaries

    def __len__(self):
        return len(self._message_logs)

    def count(self) -> int:
        return len(self)

    def compile(self) -> str:
        # don't dump all the conversations, just statistics
        system_count = user_count = assistant_count = function_count = other_count = 0
        for msg in self._message_logs:
            role = msg["message"]["role"]
            if role == "system":
                system_count += 1
            elif role == "user":
                user_count += 1
            elif role == "assistant":
                assistant_count += 1
            elif role == "function":
                function_count += 1
            else:
                other_count += 1
        memory_str = (
            f"Statistics:"
            + f"\n{len(self._message_logs)} total messages"
            + f"\n{system_count} system"
            + f"\n{user_count} user"
            + f"\n{assistant_count} assistant"
            + f"\n{function_count} function"
            + f"\n{other_count} other"
        )
        return f"\n### RECALL MEMORY ###" + f"\n{memory_str}"

    def insert(self, message):
        raise NotImplementedError("This should be handled by the PersistenceManager, recall memory is just a search layer on top")

    def _searchable_pool(self):
        """Return the logged messages that searches may match (everything except system/function)."""
        return [d for d in self._message_logs if d["message"]["role"] not in ["system", "function"]]

    @staticmethod
    def _paginate(matches, start, count):
        """Return ``matches[start:start + count]``, or everything from ``start`` when ``count`` is None."""
        if count is None:
            return matches[start:]
        return matches[start : start + count]

    def text_search(self, query_string, count=None, start=None):
        """Case-insensitive substring search over message contents.

        :return: ``(page of matching log entries, total number of matches)``
        """
        # in the dummy version, run an (inefficient) case-insensitive match search
        message_pool = self._searchable_pool()
        start = 0 if start is None else int(start)
        # Keep None as "no limit": coercing it to 0 made every un-paged search return [].
        count = None if count is None else int(count)

        printd(
            f"recall_memory.text_search: searching for {query_string} (c={count}, s={start}) in {len(self._message_logs)} total messages"
        )
        needle = query_string.lower()
        matches = [d for d in message_pool if d["message"]["content"] is not None and needle in d["message"]["content"].lower()]

        # start/count support paging through results
        page = self._paginate(matches, start, count)
        printd(f"recall_memory - matches:\n{page}")
        return page, len(matches)

    def date_search(self, start_date, end_date, count=None, start=None):
        """Return messages whose timestamp date lies within [start_date, end_date] (inclusive).

        :param start_date: Lower bound, formatted ``YYYY-MM-DD``
        :param end_date: Upper bound, formatted ``YYYY-MM-DD``
        :raises ValueError: If either date fails ``validate_date_format``.
        :return: ``(page of matching log entries, total number of matches)``
        """
        message_pool = self._searchable_pool()

        # First, validate the start_date and end_date format
        if not validate_date_format(start_date) or not validate_date_format(end_date):
            raise ValueError("Invalid date format. Expected format: YYYY-MM-DD")

        # Convert dates to datetime objects for comparison
        start_date_dt = datetime.datetime.strptime(start_date, "%Y-%m-%d")
        end_date_dt = datetime.datetime.strptime(end_date, "%Y-%m-%d")

        # Next, match items inside self._message_logs
        matches = [
            d
            for d in message_pool
            if start_date_dt <= datetime.datetime.strptime(extract_date_from_timestamp(d["timestamp"]), "%Y-%m-%d") <= end_date_dt
        ]

        # start/count support paging through results (count=None means no limit)
        start = 0 if start is None else int(start)
        count = None if count is None else int(count)
        return self._paginate(matches, start, count), len(matches)
|
|
249
|
+
|
|
250
|
+
|
|
251
|
+
class BaseRecallMemory(RecallMemory):
    """Recall memory based on base functions implemented by storage connectors.

    Every query is delegated to the agent's recall ``StorageConnector``. Returned
    messages are converted with ``to_openai_dict`` / ``to_openai_dict_search_results``.
    NOTE(review): the paging helpers below turn ``count=None`` into ``0`` before calling
    storage; whether the connector treats 0 as "no limit" is not visible here — confirm.
    """

    def __init__(self, agent_state, restrict_search_to_summaries=False):
        # If true, the pool of messages that can be queried are the automated summaries only
        # (generated when the conversation window needs to be shortened)
        self.restrict_search_to_summaries = restrict_search_to_summaries
        from letta.agent_store.storage import StorageConnector

        self.agent_state = agent_state

        # embedding model + chunk size come from the agent's embedding config
        config = agent_state.embedding_config
        self.embed_model = embedding_model(config)
        self.embedding_chunk_size = config.embedding_chunk_size

        # recall storage is scoped to this user/agent pair
        self.storage = StorageConnector.get_recall_storage_connector(user_id=agent_state.user_id, agent_id=agent_state.id)
        # TODO: have some mechanism for cleanup otherwise will lead to OOM
        self.cache = {}

    def get_all(self, start=0, count=None):
        """Return ``(messages as OpenAI dicts, number returned)`` from storage."""
        offset = int(start) if start is not None else 0
        limit = int(count) if count is not None else 0
        rows = self.storage.get_all(offset, limit)
        return [row.to_openai_dict() for row in rows], len(rows)

    def text_search(self, query_string, count=None, start=None):
        """Text search via the storage connector; returns ``(results, number returned)``."""
        offset = int(start) if start is not None else 0
        limit = int(count) if count is not None else 0
        rows = self.storage.query_text(query_string, limit, offset)
        return [row.to_openai_dict_search_results() for row in rows], len(rows)

    def date_search(self, start_date, end_date, count=None, start=None):
        """Date-range search via the storage connector; returns ``(results, number returned)``."""
        offset = int(start) if start is not None else 0
        limit = int(count) if count is not None else 0
        rows = self.storage.query_date(start_date, end_date, limit, offset)
        return [row.to_openai_dict_search_results() for row in rows], len(rows)

    def compile(self) -> str:
        """Summarize recall memory as per-role message counts (no message bodies)."""
        total = self.storage.size()
        per_role = {role: self.storage.size(filters={"role": role}) for role in ("system", "user", "assistant", "function")}
        other = total - sum(per_role.values())

        lines = [
            "Statistics:",
            f"{total} total messages",
            f"{per_role['system']} system",
            f"{per_role['user']} user",
            f"{per_role['assistant']} assistant",
            f"{per_role['function']} function",
            f"{other} other",
        ]
        return "\n### RECALL MEMORY ###" + "\n" + "\n".join(lines)

    def insert(self, message: Message):
        """Persist one message through the storage connector."""
        self.storage.insert(message)

    def insert_many(self, messages: List[Message]):
        """Persist several messages through the storage connector."""
        self.storage.insert_many(messages)

    def save(self):
        """Ask the storage connector to save."""
        self.storage.save()

    def __len__(self):
        return self.storage.size()

    def count(self) -> int:
        return len(self)
|
|
325
|
+
|
|
326
|
+
|
|
327
|
+
class EmbeddingArchivalMemory(ArchivalMemory):
    """Archival memory with embedding based search.

    ``insert`` splits text into chunks, embeds each chunk and stores it as a ``Passage``
    in the agent's archival storage connector. ``search`` embeds the query and asks
    storage for the ``top_k`` nearest passages, caching them per query string.
    """

    def __init__(self, agent_state: AgentState, top_k: int = 100):
        """Init function for archival memory

        :param agent_state: Agent whose ``embedding_config``, ``user_id`` and ``id`` select
            the embedding model and the archival storage connector.
        :type agent_state: AgentState
        :param top_k: Number of passages requested from storage per uncached search.
        :type top_k: int
        :raises ValueError: If ``agent_state.embedding_config.embedding_chunk_size`` is None.
        """
        from letta.agent_store.storage import StorageConnector

        self.top_k = top_k
        self.agent_state = agent_state

        # create embedding model
        self.embed_model = embedding_model(agent_state.embedding_config)
        if agent_state.embedding_config.embedding_chunk_size is None:
            # Name the missing setting (interpolating its value only ever printed "None").
            raise ValueError("Must set embedding_config.embedding_chunk_size")
        self.embedding_chunk_size = agent_state.embedding_config.embedding_chunk_size

        # create storage backend
        self.storage = StorageConnector.get_archival_storage_connector(user_id=agent_state.user_id, agent_id=agent_state.id)
        # query string -> passages returned by storage; cleared whenever new passages are inserted.
        # TODO: have some mechanism for cleanup otherwise will lead to OOM
        self.cache = {}

    def create_passage(self, text, embedding):
        """Build a ``Passage`` for this agent from a text chunk and its embedding."""
        return Passage(
            user_id=self.agent_state.user_id,
            agent_id=self.agent_state.id,
            text=text,
            embedding=embedding,
            embedding_config=self.agent_state.embedding_config,
        )

    def save(self):
        """Save the index to disk"""
        self.storage.save()

    def insert(self, memory_string, return_ids=False) -> Union[bool, List[str]]:
        """Embed and save memory string.

        :param memory_string: Text to store; split into ``embedding_chunk_size`` chunks.
        :param return_ids: If True, return the new passage ids instead of ``True``.
        :raises TypeError: If ``memory_string`` is not a ``str``, or the embedding call
            returns an unrecognized payload.
        :return: ``True``, or the list of inserted passage ids when ``return_ids`` is set.
        """
        if not isinstance(memory_string, str):
            raise TypeError("memory must be a string")

        try:
            passages = []

            # breakup string into passages
            for text in parse_and_chunk_text(memory_string, self.embedding_chunk_size):
                embedding = self.embed_model.get_text_embedding(text)
                # fixing weird bug where type returned isn't a list, but instead is an object
                # eg: embedding={'object': 'list', 'data': [{'object': 'embedding', 'embedding': [-0.0071973633, -0.07893023,
                if isinstance(embedding, dict):
                    try:
                        embedding = embedding["data"][0]["embedding"]
                    except (KeyError, IndexError) as err:
                        # TODO as a fallback, see if we can find any lists in the payload
                        raise TypeError(
                            f"Got back an unexpected payload from text embedding function, type={type(embedding)}, value={embedding}"
                        ) from err
                passages.append(self.create_passage(text, embedding))

            # grab the return IDs before the list gets modified
            ids = [str(p.id) for p in passages]

            # insert passages
            self.storage.insert_many(passages)

            # Cached search results predate these passages; drop them so later searches see them.
            self.cache.clear()

            return ids if return_ids else True

        except Exception as e:
            print("Archival insert error", e)
            raise

    def search(self, query_string, count=None, start=None):
        """Search archival memory for passages similar to ``query_string``.

        :param count: Page size (defaults to ``top_k``).
        :param start: Offset of the first result (defaults to 0).
        :raises TypeError: If ``query_string`` is not a ``str``.
        :return: ``(page of {"timestamp", "content"} dicts, total number of results)``, where
            the total is the number of passages retrieved for the query (at most ``top_k``).
        """
        start = 0 if start is None else int(start)
        count = self.top_k if count is None else int(count)

        if not isinstance(query_string, str):
            # Previously this exception was returned instead of raised.
            raise TypeError("query must be a string")

        try:
            if query_string not in self.cache:
                query_vec = query_embedding(self.embed_model, query_string)
                self.cache[query_string] = self.storage.query(query_string, query_vec, top_k=self.top_k)

            all_matches = self.cache[query_string]
            page = all_matches[start : start + count]
            # NOTE(review): the timestamp is the current local time, not the passage's creation time.
            results = [{"timestamp": get_local_time(), "content": node.text} for node in page]
            return results, len(all_matches)
        except Exception as e:
            print("Archival search error", e)
            raise

    def compile(self) -> str:
        """Render up to the first 10 stored passages plus the total size for a prompt."""
        limit = 10
        passages = [str(passage.text) for passage in self.storage.get_all(limit=limit)]
        memory_str = "\n".join(passages)
        return f"\n### ARCHIVAL MEMORY ###" + f"\n{memory_str}" + f"\nSize: {self.storage.size()}"

    def __len__(self):
        return self.storage.size()

    def count(self) -> int:
        return len(self)
|