letta-nightly 0.1.7.dev20240924104148__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of letta-nightly might be problematic. Click here for more details.
- letta/__init__.py +24 -0
- letta/__main__.py +3 -0
- letta/agent.py +1427 -0
- letta/agent_store/chroma.py +295 -0
- letta/agent_store/db.py +546 -0
- letta/agent_store/lancedb.py +177 -0
- letta/agent_store/milvus.py +198 -0
- letta/agent_store/qdrant.py +201 -0
- letta/agent_store/storage.py +188 -0
- letta/benchmark/benchmark.py +96 -0
- letta/benchmark/constants.py +14 -0
- letta/cli/cli.py +689 -0
- letta/cli/cli_config.py +1282 -0
- letta/cli/cli_load.py +166 -0
- letta/client/__init__.py +0 -0
- letta/client/admin.py +171 -0
- letta/client/client.py +2360 -0
- letta/client/streaming.py +90 -0
- letta/client/utils.py +61 -0
- letta/config.py +484 -0
- letta/configs/anthropic.json +13 -0
- letta/configs/letta_hosted.json +11 -0
- letta/configs/openai.json +12 -0
- letta/constants.py +134 -0
- letta/credentials.py +140 -0
- letta/data_sources/connectors.py +247 -0
- letta/embeddings.py +218 -0
- letta/errors.py +26 -0
- letta/functions/__init__.py +0 -0
- letta/functions/function_sets/base.py +174 -0
- letta/functions/function_sets/extras.py +132 -0
- letta/functions/functions.py +105 -0
- letta/functions/schema_generator.py +205 -0
- letta/humans/__init__.py +0 -0
- letta/humans/examples/basic.txt +1 -0
- letta/humans/examples/cs_phd.txt +9 -0
- letta/interface.py +314 -0
- letta/llm_api/__init__.py +0 -0
- letta/llm_api/anthropic.py +383 -0
- letta/llm_api/azure_openai.py +155 -0
- letta/llm_api/cohere.py +396 -0
- letta/llm_api/google_ai.py +468 -0
- letta/llm_api/llm_api_tools.py +485 -0
- letta/llm_api/openai.py +470 -0
- letta/local_llm/README.md +3 -0
- letta/local_llm/__init__.py +0 -0
- letta/local_llm/chat_completion_proxy.py +279 -0
- letta/local_llm/constants.py +31 -0
- letta/local_llm/function_parser.py +68 -0
- letta/local_llm/grammars/__init__.py +0 -0
- letta/local_llm/grammars/gbnf_grammar_generator.py +1324 -0
- letta/local_llm/grammars/json.gbnf +26 -0
- letta/local_llm/grammars/json_func_calls_with_inner_thoughts.gbnf +32 -0
- letta/local_llm/groq/api.py +97 -0
- letta/local_llm/json_parser.py +202 -0
- letta/local_llm/koboldcpp/api.py +62 -0
- letta/local_llm/koboldcpp/settings.py +23 -0
- letta/local_llm/llamacpp/api.py +58 -0
- letta/local_llm/llamacpp/settings.py +22 -0
- letta/local_llm/llm_chat_completion_wrappers/__init__.py +0 -0
- letta/local_llm/llm_chat_completion_wrappers/airoboros.py +452 -0
- letta/local_llm/llm_chat_completion_wrappers/chatml.py +470 -0
- letta/local_llm/llm_chat_completion_wrappers/configurable_wrapper.py +387 -0
- letta/local_llm/llm_chat_completion_wrappers/dolphin.py +246 -0
- letta/local_llm/llm_chat_completion_wrappers/llama3.py +345 -0
- letta/local_llm/llm_chat_completion_wrappers/simple_summary_wrapper.py +156 -0
- letta/local_llm/llm_chat_completion_wrappers/wrapper_base.py +11 -0
- letta/local_llm/llm_chat_completion_wrappers/zephyr.py +345 -0
- letta/local_llm/lmstudio/api.py +100 -0
- letta/local_llm/lmstudio/settings.py +29 -0
- letta/local_llm/ollama/api.py +88 -0
- letta/local_llm/ollama/settings.py +32 -0
- letta/local_llm/settings/__init__.py +0 -0
- letta/local_llm/settings/deterministic_mirostat.py +45 -0
- letta/local_llm/settings/settings.py +72 -0
- letta/local_llm/settings/simple.py +28 -0
- letta/local_llm/utils.py +265 -0
- letta/local_llm/vllm/api.py +63 -0
- letta/local_llm/webui/api.py +60 -0
- letta/local_llm/webui/legacy_api.py +58 -0
- letta/local_llm/webui/legacy_settings.py +23 -0
- letta/local_llm/webui/settings.py +24 -0
- letta/log.py +76 -0
- letta/main.py +437 -0
- letta/memory.py +440 -0
- letta/metadata.py +884 -0
- letta/openai_backcompat/__init__.py +0 -0
- letta/openai_backcompat/openai_object.py +437 -0
- letta/persistence_manager.py +148 -0
- letta/personas/__init__.py +0 -0
- letta/personas/examples/anna_pa.txt +13 -0
- letta/personas/examples/google_search_persona.txt +15 -0
- letta/personas/examples/memgpt_doc.txt +6 -0
- letta/personas/examples/memgpt_starter.txt +4 -0
- letta/personas/examples/sam.txt +14 -0
- letta/personas/examples/sam_pov.txt +14 -0
- letta/personas/examples/sam_simple_pov_gpt35.txt +13 -0
- letta/personas/examples/sqldb/test.db +0 -0
- letta/prompts/__init__.py +0 -0
- letta/prompts/gpt_summarize.py +14 -0
- letta/prompts/gpt_system.py +26 -0
- letta/prompts/system/memgpt_base.txt +49 -0
- letta/prompts/system/memgpt_chat.txt +58 -0
- letta/prompts/system/memgpt_chat_compressed.txt +13 -0
- letta/prompts/system/memgpt_chat_fstring.txt +51 -0
- letta/prompts/system/memgpt_doc.txt +50 -0
- letta/prompts/system/memgpt_gpt35_extralong.txt +53 -0
- letta/prompts/system/memgpt_intuitive_knowledge.txt +31 -0
- letta/prompts/system/memgpt_modified_chat.txt +23 -0
- letta/pytest.ini +0 -0
- letta/schemas/agent.py +117 -0
- letta/schemas/api_key.py +21 -0
- letta/schemas/block.py +135 -0
- letta/schemas/document.py +21 -0
- letta/schemas/embedding_config.py +54 -0
- letta/schemas/enums.py +35 -0
- letta/schemas/job.py +38 -0
- letta/schemas/letta_base.py +80 -0
- letta/schemas/letta_message.py +175 -0
- letta/schemas/letta_request.py +23 -0
- letta/schemas/letta_response.py +28 -0
- letta/schemas/llm_config.py +54 -0
- letta/schemas/memory.py +224 -0
- letta/schemas/message.py +727 -0
- letta/schemas/openai/chat_completion_request.py +123 -0
- letta/schemas/openai/chat_completion_response.py +136 -0
- letta/schemas/openai/chat_completions.py +123 -0
- letta/schemas/openai/embedding_response.py +11 -0
- letta/schemas/openai/openai.py +157 -0
- letta/schemas/organization.py +20 -0
- letta/schemas/passage.py +80 -0
- letta/schemas/source.py +62 -0
- letta/schemas/tool.py +143 -0
- letta/schemas/usage.py +18 -0
- letta/schemas/user.py +33 -0
- letta/server/__init__.py +0 -0
- letta/server/constants.py +6 -0
- letta/server/rest_api/__init__.py +0 -0
- letta/server/rest_api/admin/__init__.py +0 -0
- letta/server/rest_api/admin/agents.py +21 -0
- letta/server/rest_api/admin/tools.py +83 -0
- letta/server/rest_api/admin/users.py +98 -0
- letta/server/rest_api/app.py +193 -0
- letta/server/rest_api/auth/__init__.py +0 -0
- letta/server/rest_api/auth/index.py +43 -0
- letta/server/rest_api/auth_token.py +22 -0
- letta/server/rest_api/interface.py +726 -0
- letta/server/rest_api/routers/__init__.py +0 -0
- letta/server/rest_api/routers/openai/__init__.py +0 -0
- letta/server/rest_api/routers/openai/assistants/__init__.py +0 -0
- letta/server/rest_api/routers/openai/assistants/assistants.py +115 -0
- letta/server/rest_api/routers/openai/assistants/schemas.py +121 -0
- letta/server/rest_api/routers/openai/assistants/threads.py +336 -0
- letta/server/rest_api/routers/openai/chat_completions/__init__.py +0 -0
- letta/server/rest_api/routers/openai/chat_completions/chat_completions.py +131 -0
- letta/server/rest_api/routers/v1/__init__.py +15 -0
- letta/server/rest_api/routers/v1/agents.py +543 -0
- letta/server/rest_api/routers/v1/blocks.py +73 -0
- letta/server/rest_api/routers/v1/jobs.py +46 -0
- letta/server/rest_api/routers/v1/llms.py +28 -0
- letta/server/rest_api/routers/v1/organizations.py +61 -0
- letta/server/rest_api/routers/v1/sources.py +199 -0
- letta/server/rest_api/routers/v1/tools.py +103 -0
- letta/server/rest_api/routers/v1/users.py +109 -0
- letta/server/rest_api/static_files.py +74 -0
- letta/server/rest_api/utils.py +69 -0
- letta/server/server.py +1995 -0
- letta/server/startup.sh +8 -0
- letta/server/static_files/assets/index-0cbf7ad5.js +274 -0
- letta/server/static_files/assets/index-156816da.css +1 -0
- letta/server/static_files/assets/index-486e3228.js +274 -0
- letta/server/static_files/favicon.ico +0 -0
- letta/server/static_files/index.html +39 -0
- letta/server/static_files/memgpt_logo_transparent.png +0 -0
- letta/server/utils.py +46 -0
- letta/server/ws_api/__init__.py +0 -0
- letta/server/ws_api/example_client.py +104 -0
- letta/server/ws_api/interface.py +108 -0
- letta/server/ws_api/protocol.py +100 -0
- letta/server/ws_api/server.py +145 -0
- letta/settings.py +165 -0
- letta/streaming_interface.py +396 -0
- letta/system.py +207 -0
- letta/utils.py +1065 -0
- letta_nightly-0.1.7.dev20240924104148.dist-info/LICENSE +190 -0
- letta_nightly-0.1.7.dev20240924104148.dist-info/METADATA +98 -0
- letta_nightly-0.1.7.dev20240924104148.dist-info/RECORD +189 -0
- letta_nightly-0.1.7.dev20240924104148.dist-info/WHEEL +4 -0
- letta_nightly-0.1.7.dev20240924104148.dist-info/entry_points.txt +3 -0
letta/server/server.py
ADDED
|
@@ -0,0 +1,1995 @@
|
|
|
1
|
+
# inspecting tools
|
|
2
|
+
import importlib
|
|
3
|
+
import inspect
|
|
4
|
+
import os
|
|
5
|
+
import traceback
|
|
6
|
+
import warnings
|
|
7
|
+
from abc import abstractmethod
|
|
8
|
+
from datetime import datetime
|
|
9
|
+
from typing import Callable, List, Optional, Tuple, Union
|
|
10
|
+
|
|
11
|
+
from fastapi import HTTPException
|
|
12
|
+
|
|
13
|
+
import letta.constants as constants
|
|
14
|
+
import letta.server.utils as server_utils
|
|
15
|
+
import letta.system as system
|
|
16
|
+
from letta.agent import Agent, save_agent
|
|
17
|
+
from letta.agent_store.storage import StorageConnector, TableType
|
|
18
|
+
from letta.cli.cli_config import get_model_options
|
|
19
|
+
from letta.config import LettaConfig
|
|
20
|
+
from letta.credentials import LettaCredentials
|
|
21
|
+
from letta.data_sources.connectors import DataConnector, load_data
|
|
22
|
+
|
|
23
|
+
# from letta.data_types import (
|
|
24
|
+
# AgentState,
|
|
25
|
+
# EmbeddingConfig,
|
|
26
|
+
# LLMConfig,
|
|
27
|
+
# Message,
|
|
28
|
+
# Preset,
|
|
29
|
+
# Source,
|
|
30
|
+
# Token,
|
|
31
|
+
# User,
|
|
32
|
+
# )
|
|
33
|
+
from letta.functions.functions import (
|
|
34
|
+
generate_schema,
|
|
35
|
+
load_function_set,
|
|
36
|
+
parse_source_code,
|
|
37
|
+
)
|
|
38
|
+
from letta.functions.schema_generator import generate_schema
|
|
39
|
+
|
|
40
|
+
# TODO use custom interface
|
|
41
|
+
from letta.interface import AgentInterface # abstract
|
|
42
|
+
from letta.interface import CLIInterface # for printing to terminal
|
|
43
|
+
from letta.log import get_logger
|
|
44
|
+
from letta.memory import get_memory_functions
|
|
45
|
+
from letta.metadata import MetadataStore
|
|
46
|
+
from letta.prompts import gpt_system
|
|
47
|
+
from letta.schemas.agent import AgentState, CreateAgent, UpdateAgentState
|
|
48
|
+
from letta.schemas.api_key import APIKey, APIKeyCreate
|
|
49
|
+
from letta.schemas.block import (
|
|
50
|
+
Block,
|
|
51
|
+
CreateBlock,
|
|
52
|
+
CreateHuman,
|
|
53
|
+
CreatePersona,
|
|
54
|
+
UpdateBlock,
|
|
55
|
+
)
|
|
56
|
+
from letta.schemas.document import Document
|
|
57
|
+
from letta.schemas.embedding_config import EmbeddingConfig
|
|
58
|
+
|
|
59
|
+
# openai schemas
|
|
60
|
+
from letta.schemas.enums import JobStatus
|
|
61
|
+
from letta.schemas.job import Job
|
|
62
|
+
from letta.schemas.letta_message import LettaMessage
|
|
63
|
+
from letta.schemas.llm_config import LLMConfig
|
|
64
|
+
from letta.schemas.memory import ArchivalMemorySummary, Memory, RecallMemorySummary
|
|
65
|
+
from letta.schemas.message import Message, UpdateMessage
|
|
66
|
+
from letta.schemas.openai.chat_completion_response import UsageStatistics
|
|
67
|
+
from letta.schemas.organization import Organization, OrganizationCreate
|
|
68
|
+
from letta.schemas.passage import Passage
|
|
69
|
+
from letta.schemas.source import Source, SourceCreate, SourceUpdate
|
|
70
|
+
from letta.schemas.tool import Tool, ToolCreate, ToolUpdate
|
|
71
|
+
from letta.schemas.usage import LettaUsageStatistics
|
|
72
|
+
from letta.schemas.user import User, UserCreate
|
|
73
|
+
from letta.utils import create_random_username, json_dumps, json_loads
|
|
74
|
+
|
|
75
|
+
# from letta.llm_api_tools import openai_get_model_list, azure_openai_get_model_list, smart_urljoin
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
logger = get_logger(__name__)
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
class Server:
    """Abstract interface for a server that hosts many agents for many users.

    Concrete servers (for example ``SyncServer``) override these methods.
    NOTE: ``Server`` does not use ``abc.ABCMeta``, so ``@abstractmethod`` is
    informational only — the class can still be instantiated, and calling a
    method that was not overridden raises ``NotImplementedError``.
    """

    @abstractmethod
    def list_agents(self, user_id: str) -> dict:
        """Return the agents available to ``user_id``."""
        raise NotImplementedError

    @abstractmethod
    def get_agent_messages(self, user_id: str, agent_id: str, start: int, count: int) -> list:
        """Paginated query (``start`` / ``count``) over the agent's in-context message queue."""
        raise NotImplementedError

    @abstractmethod
    def get_agent_memory(self, user_id: str, agent_id: str) -> dict:
        """Return the agent's core memory together with non-core memory statistics."""
        raise NotImplementedError

    @abstractmethod
    def get_agent_state(self, user_id: str, agent_id: str) -> dict:
        """Return the agent's configuration."""
        raise NotImplementedError

    @abstractmethod
    def get_server_config(self, user_id: str) -> dict:
        """Return the server's base configuration."""
        raise NotImplementedError

    @abstractmethod
    def update_agent_core_memory(self, user_id: str, agent_id: str, new_memory_contents: dict) -> dict:
        """Write ``new_memory_contents`` into the agent's core memory and return the resulting state."""
        raise NotImplementedError

    @abstractmethod
    def create_agent(
        self,
        user_id: str,
        agent_config: Union[dict, AgentState],
        interface: Optional[AgentInterface],
    ) -> str:
        """Create a new agent from ``agent_config``; returns a ``str`` (presumably the new agent's ID)."""
        raise NotImplementedError

    @abstractmethod
    def user_message(self, user_id: str, agent_id: str, message: str) -> None:
        """Handle a user-authored message; implementations are expected to step the agent."""
        raise NotImplementedError

    @abstractmethod
    def system_message(self, user_id: str, agent_id: str, message: str) -> None:
        """Handle a system-authored message; implementations are expected to step the agent."""
        raise NotImplementedError

    @abstractmethod
    def run_command(self, user_id: str, agent_id: str, command: str) -> Optional[str]:
        """Run an agent command such as ``/memory``.

        May return a string message produced by the command, otherwise ``None``.
        """
        raise NotImplementedError
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
from sqlalchemy import create_engine
|
|
144
|
+
from sqlalchemy.orm import declarative_base, sessionmaker
|
|
145
|
+
|
|
146
|
+
from letta.agent_store.db import MessageModel, PassageModel
|
|
147
|
+
from letta.config import LettaConfig
|
|
148
|
+
|
|
149
|
+
# NOTE: hack to see if single session management works
|
|
150
|
+
from letta.metadata import (
|
|
151
|
+
AgentModel,
|
|
152
|
+
AgentSourceMappingModel,
|
|
153
|
+
APIKeyModel,
|
|
154
|
+
BlockModel,
|
|
155
|
+
JobModel,
|
|
156
|
+
OrganizationModel,
|
|
157
|
+
SourceModel,
|
|
158
|
+
ToolModel,
|
|
159
|
+
UserModel,
|
|
160
|
+
)
|
|
161
|
+
from letta.settings import settings
|
|
162
|
+
|
|
163
|
+
# Module-level DB bootstrap: load the on-disk config, choose Postgres (when a URI
# is configured in settings) or a local SQLite file, then create the metadata tables.
config = LettaConfig.load()

if not settings.letta_pg_uri_no_default:
    # No Postgres URI configured: use a SQLite file under the recall storage path.
    # TODO: don't rely on config storage
    engine = create_engine(f"sqlite:///{os.path.join(config.recall_storage_path, 'sqlite.db')}")
else:
    # Route both recall and archival storage to the configured Postgres database.
    config.recall_storage_type = "postgres"
    config.recall_storage_uri = settings.letta_pg_uri_no_default
    config.archival_storage_type = "postgres"
    config.archival_storage_uri = settings.letta_pg_uri_no_default

    engine = create_engine(settings.letta_pg_uri)

Base = declarative_base()
SessionLocal = sessionmaker(bind=engine, autoflush=False, autocommit=False)

# Create every table the server relies on (no-op for tables that already exist).
Base.metadata.create_all(
    bind=engine,
    tables=[
        model.__table__
        for model in (
            UserModel,
            AgentModel,
            SourceModel,
            AgentSourceMappingModel,
            APIKeyModel,
            BlockModel,
            ToolModel,
            JobModel,
            PassageModel,
            MessageModel,
            OrganizationModel,
        )
    ],
)
|
|
195
|
+
|
|
196
|
+
|
|
197
|
+
# Dependency: one SQLAlchemy session per caller, always closed afterwards.
def get_db():
    """Yield a fresh ``SessionLocal()`` session and close it once the caller is done.

    Written as a generator so it can be used as a request dependency
    (presumably FastAPI ``Depends``) and, through ``db_context`` below, as a
    ``with`` block.
    """
    session = SessionLocal()
    try:
        yield session
    finally:
        # Runs on normal exhaustion, on generator close, and on exceptions.
        session.close()


from contextlib import contextmanager

# ``with db_context() as db: ...`` — the context-manager form of ``get_db``.
db_context = contextmanager(get_db)
|
|
209
|
+
|
|
210
|
+
|
|
211
|
+
class SyncServer(Server):
|
|
212
|
+
"""Simple single-threaded / blocking server process"""
|
|
213
|
+
|
|
214
|
+
def __init__(
    self,
    chaining: bool = True,
    max_chaining_steps: Optional[int] = None,
    default_interface_factory: Callable[[], AgentInterface] = lambda: CLIInterface(),
):
    """Create an in-process server that keeps loaded agents in memory.

    Args:
        chaining: If True, keep stepping an agent after a step that requested a
            heartbeat, had a failed function call, or hit the context token warning.
        max_chaining_steps: Limit on chained steps per request; ``None`` means no limit.
            (Previously mis-annotated as ``bool``.)
        default_interface_factory: Zero-argument callable that builds the interface
            assigned to agents that are loaded without an explicit one.

    Raises:
        AssertionError: If the configured embedding config has no ``embedding_model``.
    """
    # In-memory agent cache: list of {"user_id": ..., "agent_id": ..., "agent": Agent} dicts.
    self.active_agents = []

    # Whether to run again if request_heartbeat=true, and how many times at most.
    self.chaining = chaining
    self.max_chaining_steps = max_chaining_steps

    # Interface factory used for agents ON LOAD when no interface is passed.
    self.default_interface_factory = default_interface_factory

    # TODO figure out how to handle credentials for the server
    self.credentials = LettaCredentials.load()

    # Server-wide default LLM/embedding configs come from settings.
    # TODO: we may also want to do the same thing with default persona/human/etc.
    self.server_llm_config = settings.llm_config
    self.server_embedding_config = settings.embedding_config
    assert self.server_embedding_config.embedding_model is not None, vars(self.server_embedding_config)

    # Load the persisted config, override it with settings, save it, and build the metadata store.
    config = LettaConfig.load()
    if settings.letta_pg_uri_no_default:
        config.recall_storage_type = "postgres"
        config.recall_storage_uri = settings.letta_pg_uri_no_default
        config.archival_storage_type = "postgres"
        config.archival_storage_uri = settings.letta_pg_uri_no_default
    config.default_llm_config = self.server_llm_config
    config.default_embedding_config = self.server_embedding_config
    config.save()
    self.config = config
    self.ms = MetadataStore(self.config)

    # TODO: this should be removed
    # add global default tools (for admin)
    self.add_default_tools(module_name="base")
|
|
298
|
+
|
|
299
|
+
def save_agents(self):
    """Persist every agent held in the in-memory cache (``self.active_agents``).

    A failure while saving one agent is logged and the remaining agents are
    still saved; nothing is raised to the caller.
    """
    for record in self.active_agents:
        try:
            save_agent(record["agent"], self.ms)
            logger.debug(f"Saved agent {record['agent_id']}")
        except Exception as err:
            logger.exception(f"Error occurred while trying to save agent {record['agent_id']}:\n{err}")
|
|
307
|
+
|
|
308
|
+
def _get_agent(self, user_id: str, agent_id: str) -> Union[Agent, None]:
    """Look up an already-loaded agent in the in-memory cache.

    Both IDs are compared in their ``str()`` form. Returns the first cached
    agent matching ``(user_id, agent_id)``, or ``None`` if it is not loaded.
    """
    uid, aid = str(user_id), str(agent_id)
    matches = (
        entry["agent"]
        for entry in self.active_agents
        if entry["user_id"] == uid and entry["agent_id"] == aid
    )
    return next(matches, None)
|
|
314
|
+
|
|
315
|
+
def _add_agent(self, user_id: str, agent_id: str, agent_obj: Agent) -> None:
    """Put ``agent_obj`` into the in-memory cache under ``(user_id, agent_id)``.

    IDs are stored as strings. If the pair is already cached, the call is a
    no-op that logs a warning instead of raising (this can be triggered by
    concurrent requests loading the same agent).
    """
    if self._get_agent(user_id=user_id, agent_id=agent_id) is not None:
        # Use warning, not logger.exception: no exception is active here, so
        # logger.exception would log at ERROR with a meaningless "NoneType: None" traceback.
        logger.warning(f"Agent (user={user_id}, agent={agent_id}) is already loaded")
        return
    self.active_agents.append(
        {
            "user_id": str(user_id),
            "agent_id": str(agent_id),
            "agent": agent_obj,
        }
    )
|
|
330
|
+
|
|
331
|
+
def _load_agent(self, user_id: str, agent_id: str, interface: Union[AgentInterface, None] = None) -> Agent:
    """Load a saved agent from the metadata store and cache it in memory.

    Args:
        user_id: Owner of the agent; must be a ``str``.
        agent_id: ID of the agent to load; must be a ``str``.
        interface: Interface to attach to the agent. When ``None``, a new one is
            built with ``self.default_interface_factory()``.

    Returns:
        The newly constructed ``Agent``.

    Raises:
        AssertionError: If an ID is not a ``str`` or the stored memory is not a ``Memory``.
        ValueError: If the agent, or any tool listed in its state, does not exist.
        Any exception from construction is logged and re-raised.
    """
    assert isinstance(user_id, str), user_id
    assert isinstance(agent_id, str), agent_id

    # If an interface isn't specified, use the default
    if interface is None:
        interface = self.default_interface_factory()

    try:
        logger.debug(f"Grabbing agent user_id={user_id} agent_id={agent_id} from database")
        agent_state = self.ms.get_agent(agent_id=agent_id, user_id=user_id)
        if not agent_state:
            # No exception is active yet, so log at error level (logger.exception
            # would attach an empty "NoneType: None" traceback).
            logger.error(f"agent_id {agent_id} does not exist")
            raise ValueError(f"agent_id {agent_id} does not exist")

        # Resolve every tool named in the agent state; a missing tool is fatal.
        logger.debug("Creating an agent object")
        tool_objs = []
        for name in agent_state.tools:
            tool_obj = self.ms.get_tool(tool_name=name, user_id=user_id)
            if not tool_obj:
                logger.error(f"Tool {name} does not exist for user {user_id}")
                raise ValueError(f"Tool {name} does not exist for user {user_id}")
            tool_objs.append(tool_obj)

        # Make sure the memory is a memory object
        assert isinstance(agent_state.memory, Memory)

        letta_agent = Agent(agent_state=agent_state, interface=interface, tools=tool_objs)

        # NOTE(review): if another request cached this agent first, _add_agent skips the
        # insert, and the object returned here is not the cached instance.
        logger.debug(f"Adding agent to the agent cache: user_id={user_id}, agent_id={agent_id}")
        self._add_agent(user_id=user_id, agent_id=agent_id, agent_obj=letta_agent)
        return letta_agent

    except Exception as e:
        logger.exception(f"Error occurred while trying to get agent {agent_id}:\n{e}")
        raise
|
|
370
|
+
|
|
371
|
+
def _get_or_load_agent(self, agent_id: str) -> Agent:
    """Return the agent for ``agent_id``, loading it into memory on a cache miss.

    The owning user is read from the stored agent state, so callers only need
    to supply the agent ID.

    Raises:
        ValueError: If the metadata store has no agent with ``agent_id``.
    """
    agent_state = self.ms.get_agent(agent_id=agent_id)
    if not agent_state:
        raise ValueError("Agent does not exist")
    owner_id = agent_state.user_id

    logger.debug(f"Checking for agent user_id={owner_id} agent_id={agent_id}")
    # TODO: consider disabling loading cached agents due to potential concurrency issues
    cached = self._get_agent(user_id=owner_id, agent_id=agent_id)
    if cached:
        return cached

    logger.debug(f"Agent not loaded, loading agent user_id={owner_id} agent_id={agent_id}")
    return self._load_agent(user_id=owner_id, agent_id=agent_id)
|
|
385
|
+
|
|
386
|
+
def _step(self, user_id: str, agent_id: str, input_message: Union[str, Message], timestamp: Optional[datetime]) -> LettaUsageStatistics:
    """Send ``input_message`` through an agent, chaining follow-up steps as needed.

    The agent state is saved after every step. When ``self.chaining`` is on, the
    agent is stepped again with a system message if the previous step hit the
    context token warning, had a failed function call, or requested a
    heartbeat. With ``self.max_chaining_steps = N``, at most ``N + 1`` steps run
    in total; ``None`` means no limit.

    Args:
        user_id: Only used in error messages; the owner is resolved from the stored agent.
        agent_id: Agent to step.
        input_message: First message fed to the agent.
        timestamp: Passed through to ``Agent.step``.

    Returns:
        Usage summed across all steps, plus ``step_count``.

    Raises:
        Any exception from loading or stepping the agent, after it is logged.
    """
    logger.debug(f"Got input message: {input_message}")
    # Bound before ``try`` so ``finally`` can tell whether the agent was loaded.
    # Previously a failed load raised UnboundLocalError from ``finally`` and hid the real error.
    letta_agent = None
    try:
        # Get the agent object (loaded in memory)
        letta_agent = self._get_or_load_agent(agent_id=agent_id)
        if letta_agent is None:
            raise KeyError(f"Agent (user={user_id}, agent={agent_id}) is not loaded")

        # Token-stream only if the interface advertises a streaming mode
        token_streaming = getattr(letta_agent.interface, "streaming_mode", False)

        logger.debug("Starting agent step")
        next_input_message = input_message
        total_usage = UsageStatistics()
        step_count = 0
        while True:
            step_response = letta_agent.step(
                next_input_message,
                first_message=False,
                skip_verify=True,
                return_dicts=False,
                stream=token_streaming,
                timestamp=timestamp,
                ms=self.ms,
            )
            heartbeat_request = step_response.heartbeat_request
            function_failed = step_response.function_failed
            token_warning = step_response.in_context_memory_warning
            usage = step_response.usage

            step_count += 1
            total_usage += usage
            letta_agent.interface.step_complete()

            logger.debug("Saving agent state")
            save_agent(letta_agent, self.ms)

            # Chain stops
            if not self.chaining:
                logger.debug("No chaining, stopping after one step")
                break
            if self.max_chaining_steps is not None and step_count > self.max_chaining_steps:
                logger.debug(f"Hit max chaining steps, stopping after {step_count} steps")
                break

            # Chain handlers (first match wins)
            if token_warning:
                next_input_message = system.get_token_limit_warning()
            elif function_failed:
                next_input_message = system.get_heartbeat(constants.FUNC_FAILED_HEARTBEAT_MESSAGE)
            elif heartbeat_request:
                next_input_message = system.get_heartbeat(constants.REQ_HEARTBEAT_MESSAGE)
            else:
                # Letta no-op / yield
                break

    except Exception as e:
        logger.error(f"Error in server._step: {e}")
        # print_exc() writes the traceback itself; wrapping it in print() only emitted "None".
        traceback.print_exc()
        raise
    finally:
        logger.debug("Calling step_yield()")
        if letta_agent is not None:
            letta_agent.interface.step_yield()

    return LettaUsageStatistics(**total_usage.dict(), step_count=step_count)
|
|
460
|
+
|
|
461
|
+
def _command(self, user_id: str, agent_id: str, command: str) -> LettaUsageStatistics:
    """Process a CLI-style command (without its leading "/") against an agent.

    Args:
        user_id: ID of the user issuing the command.
        agent_id: ID of the agent the command targets.
        command: Command text, e.g. "dump 5", "pop 2", "attach <source>",
            "rethink <text>". Matching is case-insensitive; trailing text for
            "rethink"/"rewrite" keeps the caller's casing.

    Returns:
        Usage statistics from stepping the agent for "heartbeat" and
        "memorywarning"; an empty ``LettaUsageStatistics`` for every other
        command (including unrecognized ones); except "memory", which returns
        a human-readable string dump of the agent's memory.

    Raises:
        ValueError: for "exit" and "wipe" (not supported on the server) and
            for "attach" without a data source argument.
    """
    logger.debug(f"Got command: {command}")

    # Get the agent object (loaded in memory)
    letta_agent = self._get_or_load_agent(agent_id=agent_id)
    usage = None

    # Lower-case once for matching; `command` keeps the original casing
    # because rethink/rewrite copy the trailing text verbatim.
    cmd = command.lower()

    if cmd == "exit":
        # exit not supported on server.py
        raise ValueError(command)

    elif cmd in ("save", "savechat"):
        save_agent(letta_agent, self.ms)

    elif cmd == "attach" or cmd.startswith("attach "):
        # Different from CLI, we extract the data source from the command: "attach <source>".
        # (Previously only the bare word "attach" matched, so the argument lookup always failed.)
        parts = command.strip().split()
        if len(parts) < 2:
            raise ValueError(parts)
        data_source = parts[1]

        # attach data to agent from source
        source_connector = StorageConnector.get_storage_connector(TableType.PASSAGES, self.config, user_id=user_id)
        letta_agent.attach_source(data_source, source_connector, self.ms)

    elif cmd == "dump" or cmd.startswith("dump "):
        # Optional integer argument: how many of the most recent messages to print (0/omitted = all)
        parts = command.strip().split()
        amount = int(parts[1]) if len(parts) > 1 and parts[1].isdigit() else 0
        if amount == 0:
            letta_agent.interface.print_messages(letta_agent.messages, dump=True)
        else:
            letta_agent.interface.print_messages(letta_agent.messages[-min(amount, len(letta_agent.messages)) :], dump=True)

    elif cmd == "dumpraw":
        letta_agent.interface.print_messages_raw(letta_agent.messages)

    elif cmd == "memory":
        ret_str = (
            f"\nDumping memory contents:\n"
            + f"\n{str(letta_agent.memory)}"
            + f"\n{str(letta_agent.persistence_manager.archival_memory)}"
            + f"\n{str(letta_agent.persistence_manager.recall_memory)}"
        )
        return ret_str

    elif cmd == "pop" or cmd.startswith("pop "):
        # Optional integer argument: how many messages to pop (default 3)
        parts = command.strip().split()
        pop_amount = int(parts[1]) if len(parts) > 1 and parts[1].isdigit() else 3
        # Bind the list once so every pop acts on the same object.
        # NOTE(review): if Agent.messages returns a fresh copy, these pops do not persist -- confirm.
        messages = letta_agent.messages
        n_messages = len(messages)
        MIN_MESSAGES = 2
        if n_messages <= MIN_MESSAGES:
            logger.debug(f"Agent only has {n_messages} messages in stack, none left to pop")
        elif n_messages - pop_amount < MIN_MESSAGES:
            logger.debug(f"Agent only has {n_messages} messages in stack, cannot pop more than {n_messages - MIN_MESSAGES}")
        else:
            logger.debug(f"Popping last {pop_amount} messages from stack")
            for _ in range(min(pop_amount, len(messages))):
                messages.pop()

    elif cmd == "retry":
        # TODO this needs to also modify the persistence manager
        logger.debug(f"Retrying for another answer")
        # Bind the list once: re-reading a computed `messages` property each
        # iteration would never shrink and the loop could spin forever.
        messages = letta_agent.messages
        # Pop everything up to and including the most recent user message
        while len(messages) > 0:
            if messages[-1].get("role") == "user":
                messages.pop()
                break
            messages.pop()

    elif cmd == "rethink" or cmd.startswith("rethink "):
        # TODO this needs to also modify the persistence manager
        if len(command) < len("rethink "):
            logger.warning("Missing text after the command")
        else:
            messages = letta_agent.messages
            # Walk backwards (index 0 is never touched) to the latest assistant message
            for x in range(len(messages) - 1, 0, -1):
                if messages[x].get("role") == "assistant":
                    text = command[len("rethink ") :].strip()
                    messages[x].update({"content": text})
                    break

    elif cmd == "rewrite" or cmd.startswith("rewrite "):
        # TODO this needs to also modify the persistence manager
        if len(command) < len("rewrite "):
            logger.warning("Missing text after the command")
        else:
            messages = letta_agent.messages
            for x in range(len(messages) - 1, 0, -1):
                if messages[x].get("role") == "assistant":
                    function_call = messages[x].get("function_call")
                    if function_call is None:
                        # Nothing to rewrite; avoid AttributeError on None
                        logger.warning("Latest assistant message has no function call to rewrite")
                        break
                    text = command[len("rewrite ") :].strip()
                    args = json_loads(function_call.get("arguments"))
                    args["message"] = text
                    function_call.update({"arguments": json_dumps(args)})
                    break

    # No skip options
    elif cmd == "wipe":
        # wipe not supported on server.py
        raise ValueError(command)

    elif cmd == "heartbeat":
        input_message = system.get_heartbeat()
        usage = self._step(user_id=user_id, agent_id=agent_id, input_message=input_message)

    elif cmd == "memorywarning":
        input_message = system.get_token_limit_warning()
        usage = self._step(user_id=user_id, agent_id=agent_id, input_message=input_message)

    if not usage:
        usage = LettaUsageStatistics()

    return usage
|
|
577
|
+
|
|
578
|
+
def user_message(
    self,
    user_id: str,
    agent_id: str,
    message: Union[str, Message],
    timestamp: Optional[datetime] = None,
) -> LettaUsageStatistics:
    """Process an incoming user message and feed it through the Letta agent.

    Args:
        user_id: ID of the sending user; must exist.
        agent_id: ID of the target agent; must exist for ``user_id``.
        message: Raw message text, or a pre-built ``Message``.
        timestamp: Optional time to attach to the message.

    Returns:
        Usage statistics accumulated while stepping the agent.

    Raises:
        ValueError: if the user/agent does not exist, or the message text is
            empty/None or starts with the command prefix "/".
        TypeError: if ``message`` is neither ``str`` nor ``Message``.
    """
    if self.ms.get_user(user_id=user_id) is None:
        raise ValueError(f"User user_id={user_id} does not exist")
    if self.ms.get_agent(agent_id=agent_id, user_id=user_id) is None:
        raise ValueError(f"Agent agent_id={agent_id} does not exist")

    if isinstance(message, str):
        # Basic input sanitization
        if len(message) == 0:
            raise ValueError(f"Invalid input: '{message}'")
        # If the input begins with a command prefix, reject
        elif message.startswith("/"):
            raise ValueError(f"Invalid input: '{message}'")

        # Raw strings are wrapped in the user-message envelope before stepping.
        # NOTE: eventually deprecate and only allow passing Message types
        input_message = system.package_user_message(
            user_message=message,
            time=timestamp.isoformat() if timestamp else None,
        )
    elif isinstance(message, Message):
        # Can't have a null/empty text field (check None before len())
        if message.text is None or len(message.text) == 0:
            raise ValueError(f"Invalid input: '{message.text}'")
        # If the input begins with a command prefix, reject
        elif message.text.startswith("/"):
            raise ValueError(f"Invalid input: '{message.text}'")
        if timestamp:
            # Override the timestamp with what the caller provided
            message.created_at = timestamp
        # Pre-built Message objects are forwarded as-is.
        # NOTE(review): relies on _step / Agent.step accepting a Message input -- confirm.
        input_message = message
    else:
        raise TypeError(f"Invalid input: '{message}' - type {type(message)}")

    # Run the agent state forward
    usage = self._step(user_id=user_id, agent_id=agent_id, input_message=input_message, timestamp=timestamp)
    return usage
|
|
626
|
+
|
|
627
|
+
def system_message(
    self,
    user_id: str,
    agent_id: str,
    message: Union[str, Message],
    timestamp: Optional[datetime] = None,
) -> LettaUsageStatistics:
    """Process an incoming system message and feed it through the Letta agent.

    Args:
        user_id: ID of the user; must exist.
        agent_id: ID of the target agent; must exist for ``user_id``.
        message: Raw system text, or a pre-built ``Message``.
        timestamp: Optional time to attach to the message.

    Returns:
        Usage statistics accumulated while stepping the agent.

    Raises:
        ValueError: if the user/agent does not exist, or the message text is
            empty/None or starts with the command prefix "/".
        TypeError: if ``message`` is neither ``str`` nor ``Message``.
    """
    if self.ms.get_user(user_id=user_id) is None:
        raise ValueError(f"User user_id={user_id} does not exist")
    if self.ms.get_agent(agent_id=agent_id, user_id=user_id) is None:
        raise ValueError(f"Agent agent_id={agent_id} does not exist")

    if isinstance(message, str):
        # Basic input sanitization
        if len(message) == 0:
            raise ValueError(f"Invalid input: '{message}'")
        # If the input begins with a command prefix, reject
        elif message.startswith("/"):
            raise ValueError(f"Invalid input: '{message}'")

        # NOTE: eventually deprecate and only allow passing Message types
        input_message = system.package_system_message(system_message=message)
    elif isinstance(message, Message):
        # Can't have a null/empty text field; test None first so len() never sees None
        if message.text is None or len(message.text) == 0:
            raise ValueError(f"Invalid input: '{message.text}'")
        # If the input begins with a command prefix, reject
        elif message.text.startswith("/"):
            raise ValueError(f"Invalid input: '{message.text}'")
        if timestamp:
            # Override the timestamp with what the caller provided
            message.created_at = timestamp
        # Pre-built Message objects are forwarded as-is.
        # NOTE(review): relies on _step / Agent.step accepting a Message input -- confirm.
        input_message = message
    else:
        raise TypeError(f"Invalid input: '{message}' - type {type(message)}")

    # Run the agent state forward
    return self._step(user_id=user_id, agent_id=agent_id, input_message=input_message, timestamp=timestamp)
|
|
687
|
+
|
|
688
|
+
# @LockingServer.agent_lock_decorator
|
|
689
|
+
def run_command(self, user_id: str, agent_id: str, command: str) -> LettaUsageStatistics:
    """Run a slash-command such as "/dump 5" on the agent.

    One leading "/" is removed (unless the command is exactly "/") and the
    remaining text is dispatched to ``_command``.

    Raises:
        ValueError: if the user or agent does not exist.
    """
    if self.ms.get_user(user_id=user_id) is None:
        raise ValueError(f"User user_id={user_id} does not exist")
    if self.ms.get_agent(agent_id=agent_id, user_id=user_id) is None:
        raise ValueError(f"Agent agent_id={agent_id} does not exist")

    strip_prefix = command.startswith("/") and len(command) > 1
    dispatched = command[1:] if strip_prefix else command
    return self._command(user_id=user_id, agent_id=agent_id, command=dispatched)
|
|
701
|
+
|
|
702
|
+
def list_users_paginated(self, cursor: str, limit: int) -> tuple:
    """List users one page at a time.

    Args:
        cursor: Pagination cursor forwarded to ``self.ms.get_all_users``.
        limit: Maximum number of users in the page.

    Returns:
        A ``(next_cursor, users)`` tuple, exactly as produced by
        ``self.ms.get_all_users`` (the previous ``List[User]`` annotation was wrong).
    """
    next_cursor, users = self.ms.get_all_users(cursor, limit)
    return next_cursor, users
|
|
707
|
+
|
|
708
|
+
def create_user(self, request: UserCreate) -> User:
    """Create and persist a new user, then seed their default blocks and tools.

    If ``request.name`` is empty, a random name is generated and written back
    onto ``request`` before the user is built.
    """
    if not request.name:
        # auto-generate a name (mutates the caller's request)
        request.name = create_random_username()

    new_user = User(name=request.name, org_id=request.org_id)
    self.ms.create_user(new_user)
    logger.debug(f"Created new user from config: {new_user}")

    # Seed per-user defaults
    # TODO: move to org
    assert new_user.id is not None, f"User id is None: {new_user}"
    new_user_id = new_user.id
    self.add_default_blocks(new_user_id)
    self.add_default_tools(module_name="base", user_id=new_user_id)

    return new_user
|
|
724
|
+
|
|
725
|
+
def create_organization(self, request: OrganizationCreate) -> Organization:
    """Create and persist a new organization.

    If ``request.name`` is empty, a random name is generated and written back
    onto ``request``.
    """
    if not request.name:
        request.name = create_random_username()  # auto-generate a name

    new_org = Organization(name=request.name)
    self.ms.create_organization(new_org)
    logger.info(f"Created new org from config: {new_org}")

    # TODO: add default data for the org
    return new_org
|
|
738
|
+
|
|
739
|
+
def create_agent(
    self,
    request: CreateAgent,
    user_id: str,
    # interface
    interface: Union[AgentInterface, None] = None,
) -> AgentState:
    """Create a new agent from a ``CreateAgent`` request and persist it.

    Missing request fields are filled in: a random name, the "memgpt_chat"
    system prompt, and the server's default LLM / embedding configs. The
    memory class's functions are upserted as tools and attached to the agent.
    NOTE: ``request.name``, ``request.system`` and ``request.tools`` may be
    modified in place.

    Args:
        request: Agent creation request.
        user_id: Owner of the new agent; must exist.
        interface: Agent interface; defaults to ``self.default_interface_factory()``.

    Returns:
        The new agent's ``AgentState``.

    Raises:
        ValueError: if the user does not exist.
        AssertionError: if a requested tool does not exist or ``request.memory`` is None.
        Exception: any error raised while building the agent is logged and re-raised.
    """
    if self.ms.get_user(user_id=user_id) is None:
        raise ValueError(f"User user_id={user_id} does not exist")

    if interface is None:
        interface = self.default_interface_factory()

    # create agent name
    if request.name is None:
        request.name = create_random_username()

    # system debug
    if request.system is None:
        # TODO: don't hardcode
        request.system = gpt_system.get_system_text("memgpt_chat")

    # Bound before the try so the cleanup handler can always test it: if
    # construction fails before Agent(...) returns there is nothing to delete
    # (previously this raised UnboundLocalError inside the handler).
    agent = None
    try:
        # model configuration
        llm_config = request.llm_config if request.llm_config else self.server_llm_config
        embedding_config = request.embedding_config if request.embedding_config else self.server_embedding_config

        # get tools + make sure they exist
        tool_objs = []
        if request.tools:
            for tool_name in request.tools:
                tool_obj = self.ms.get_tool(tool_name=tool_name, user_id=user_id)
                assert tool_obj, f"Tool {tool_name} does not exist"
                tool_objs.append(tool_obj)

        # Register the memory class's functions as tools (upserted with update=True)
        assert request.memory is not None
        memory_functions = get_memory_functions(request.memory)
        for func_name, func in memory_functions.items():
            if request.tools and func_name in request.tools:
                # tool already added
                continue
            source_code = parse_source_code(func)
            json_schema = generate_schema(func, func_name)
            source_type = "python"
            tags = ["memory", "memgpt-base"]
            tool = self.create_tool(
                request=ToolCreate(
                    source_code=source_code,
                    source_type=source_type,
                    tags=tags,
                    json_schema=json_schema,
                    user_id=user_id,
                ),
                update=True,
                user_id=user_id,
            )
            tool_objs.append(tool)
            if not request.tools:
                request.tools = []
            request.tools.append(tool.name)

        # TODO: save the agent state
        agent_state = AgentState(
            name=request.name,
            user_id=user_id,
            tools=request.tools if request.tools else [],
            llm_config=llm_config,
            embedding_config=embedding_config,
            system=request.system,
            memory=request.memory,
            description=request.description,
            metadata_=request.metadata_,
        )
        agent = Agent(
            interface=interface,
            agent_state=agent_state,
            tools=tool_objs,
            # gpt-3.5-turbo tends to omit inner monologue, relax this requirement for now
            first_message_verify_mono=llm_config.model is not None and "gpt-4" in llm_config.model,
        )
        # rebuilding agent memory on agent create in case shared memory blocks
        # were specified in the new agent's memory config. we're doing this for two reasons:
        # 1. if only the ID of the shared memory block was specified, we can fetch its most recent value
        # 2. if the shared block state changed since this agent initialization started, we can be sure to have the latest value
        agent.rebuild_memory(force=True, ms=self.ms)
        # FIXME: this is a hacky way to get the system prompts injected into agent into the DB
        # self.ms.update_agent(agent.agent_state)
    except Exception as e:
        logger.exception(e)
        try:
            if agent is not None:
                self.ms.delete_agent(agent_id=agent.agent_state.id)
        except Exception as delete_e:
            logger.exception(f"Failed to delete_agent:\n{delete_e}")
        raise e

    # save agent
    save_agent(agent, self.ms)
    logger.debug(f"Created new agent from config: {agent}")

    assert isinstance(agent.agent_state.memory, Memory), f"Invalid memory type: {type(agent.agent_state.memory)}"
    # return AgentState
    return agent.agent_state
|
|
849
|
+
|
|
850
|
+
def update_agent(
    self,
    request: UpdateAgentState,
    user_id: str,
) -> AgentState:
    """Apply a partial update to an agent, persist it, and return the new state.

    Each truthy field of ``request`` is applied: core memory, system prompt,
    in-context message IDs, tools (re-linked), LLM / embedding configs, name
    and metadata. Falsy fields leave the agent unchanged.

    Args:
        request: Update request; ``request.id`` identifies the agent.
        user_id: ID of the requesting user; the agent must exist for this user.

    Returns:
        The agent's updated ``AgentState``.

    Raises:
        ValueError: if the user does not exist or the agent does not exist for that user.
        AssertionError: if ``request.memory`` is not a ``Memory`` or a requested tool does not exist.
    """
    if self.ms.get_user(user_id=user_id) is None:
        raise ValueError(f"User user_id={user_id} does not exist")
    # Scope the lookup to the caller, as every other agent accessor in this server does
    if self.ms.get_agent(agent_id=request.id, user_id=user_id) is None:
        raise ValueError(f"Agent agent_id={request.id} does not exist")

    # Get the agent object (loaded in memory)
    letta_agent = self._get_or_load_agent(agent_id=request.id)

    # update the core memory of the agent
    if request.memory:
        assert isinstance(request.memory, Memory), type(request.memory)
        new_memory_contents = request.memory.to_flat_dict()
        _ = self.update_agent_core_memory(user_id=user_id, agent_id=request.id, new_memory_contents=new_memory_contents)

    # update the system prompt
    if request.system:
        letta_agent.update_system_prompt(request.system)

    # update in-context messages
    if request.message_ids:
        # This means the user is trying to change what messages are in the message buffer
        # Internally this requires (1) pulling from recall,
        # then (2) setting the attributes ._messages and .state.message_ids
        letta_agent.set_message_buffer(message_ids=request.message_ids)

    # tools
    if request.tools:
        # Replace tools and also re-link

        # (1) get tools + make sure they exist
        tool_objs = []
        for tool_name in request.tools:
            tool_obj = self.ms.get_tool(tool_name=tool_name, user_id=user_id)
            assert tool_obj, f"Tool {tool_name} does not exist"
            tool_objs.append(tool_obj)

        # (2) replace the list of tool names ("ids") inside the agent state
        letta_agent.agent_state.tools = request.tools

        # (3) then attempt to link the tools modules
        letta_agent.link_tools(tool_objs)

    # configs
    if request.llm_config:
        letta_agent.agent_state.llm_config = request.llm_config
    if request.embedding_config:
        letta_agent.agent_state.embedding_config = request.embedding_config

    # other minor updates
    if request.name:
        letta_agent.agent_state.name = request.name
    if request.metadata_:
        letta_agent.agent_state.metadata_ = request.metadata_

    # save the agent
    assert isinstance(letta_agent.memory, Memory)
    save_agent(letta_agent, self.ms)
    # TODO: probably reload the agent somehow?
    return letta_agent.agent_state
|
|
915
|
+
|
|
916
|
+
def _agent_state_to_config(self, agent_state: AgentState) -> dict:
    """Convert an AgentState into the legacy JSON-friendly summary dict.

    Returns:
        A dict with "id", "name", "human" and "persona" (read from the
        agent's ``metadata_`` dict, ``None`` when absent) and "created_at"
        as an ISO-8601 string.
    """
    assert agent_state is not None

    # Agent metadata lives in `metadata_` (the field create_agent/update_agent
    # set); the old `_metadata` attribute is not populated anywhere. It may be None.
    metadata = agent_state.metadata_ or {}

    agent_config = {
        "id": agent_state.id,
        "name": agent_state.name,
        "human": metadata.get("human", None),
        "persona": metadata.get("persona", None),
        "created_at": agent_state.created_at.isoformat(),
    }
    return agent_config
|
|
928
|
+
|
|
929
|
+
def list_agents(
    self,
    user_id: str,
) -> List[AgentState]:
    """Return the ``AgentState`` of every agent owned by ``user_id``.

    Raises:
        ValueError: if the user does not exist.
    """
    owner = self.ms.get_user(user_id=user_id)
    if owner is None:
        raise ValueError(f"User user_id={user_id} does not exist")
    return self.ms.list_agents(user_id=user_id)
|
|
939
|
+
|
|
940
|
+
# TODO make return type pydantic
|
|
941
|
+
def list_agents_legacy(
    self,
    user_id: str,
) -> dict:
    """List agents with the extra detail expected by the legacy UI.

    Args:
        user_id: Owner whose agents to list; if ``None``, every agent in the
            metadata store is listed.

    Returns:
        ``{"num_agents": int, "agents": [dict, ...]}``. Each agent dict holds
        the ``_agent_state_to_config`` summary plus "tools", "memory",
        "last_run" and "sources". Agents are sorted by "last_run", most recent
        first; agents without in-context messages (``last_run`` None) go last.

    Raises:
        ValueError: if ``user_id`` is given but the user does not exist.
    """
    if user_id is None:
        agents_states = self.ms.list_all_agents()
    else:
        if self.ms.get_user(user_id=user_id) is None:
            raise ValueError(f"User user_id={user_id} does not exist")
        agents_states = self.ms.list_agents(user_id=user_id)

    agents_states_dicts = [self._agent_state_to_config(state) for state in agents_states]

    # TODO add a get_message_obj_from_message_id(...) function
    # this would allow grabbing Message.created_by without having to load the agent object

    for agent_state, return_dict in zip(agents_states, agents_states_dicts):
        # Get the agent object (loaded in memory); keyed by agent_id like every other caller
        letta_agent = self._get_or_load_agent(agent_id=agent_state.id)

        # TODO hack for frontend, remove (persona_name / human_name columns in the UI)
        # Metadata lives in `metadata_`, which may be None.
        metadata = agent_state.metadata_ or {}
        return_dict["persona"] = metadata.get("persona", None)
        return_dict["human"] = metadata.get("human", None)

        # Tool info from the agent state. Look tools up under the agent's own
        # owner so this also works when listing all agents (user_id=None).
        return_dict["tools"] = [self.ms.get_tool(tool_name=tool_name, user_id=agent_state.user_id) for tool_name in agent_state.tools]

        # Add information about memory (raw core, size of recall, size of archival)
        core_memory = letta_agent.memory
        recall_memory = letta_agent.persistence_manager.recall_memory
        archival_memory = letta_agent.persistence_manager.archival_memory
        return_dict["memory"] = {
            "core_memory": core_memory.to_flat_dict(),
            "recall_memory": len(recall_memory) if recall_memory is not None else None,
            "archival_memory": len(archival_memory) if archival_memory is not None else None,
        }

        # NOTE: 'last_run' is just the timestamp on the latest in-context message (None if the buffer is empty)
        in_context = letta_agent._messages
        return_dict["last_run"] = in_context[-1].created_at if in_context else None

        # Add information about attached sources
        sources_ids = self.ms.list_attached_sources(agent_id=agent_state.id)
        sources = [self.ms.get_source(source_id=s_id) for s_id in sources_ids]
        return_dict["sources"] = [vars(s) for s in sources]

    # Sort by "last_run" descending; the boolean first element keeps None values
    # out of datetime comparisons and places them last.
    agents_states_dicts.sort(key=lambda x: (x["last_run"] is not None, x["last_run"]), reverse=True)

    logger.debug(f"Retrieved {len(agents_states)} agents for user {user_id}")
    return {
        "num_agents": len(agents_states),
        "agents": agents_states_dicts,
    }
|
|
1017
|
+
|
|
1018
|
+
# blocks
|
|
1019
|
+
|
|
1020
|
+
def get_blocks(
    self,
    user_id: Optional[str] = None,
    label: Optional[str] = None,
    template: Optional[bool] = None,
    name: Optional[str] = None,
    id: Optional[str] = None,
) -> Optional[List[Block]]:
    """Query memory blocks; every filter is optional and forwarded unchanged to the metadata store."""
    filters = dict(user_id=user_id, label=label, template=template, name=name, id=id)
    return self.ms.get_blocks(**filters)
|
|
1030
|
+
|
|
1031
|
+
def get_block(self, block_id: str):
    """Fetch exactly one block by ID.

    Raises:
        ValueError: if no block, or more than one block, matches ``block_id``.
    """
    matches = self.get_blocks(id=block_id)
    if not matches:
        raise ValueError("Block does not exist")
    if len(matches) != 1:
        raise ValueError("Multiple blocks with the same id")
    (only_block,) = matches
    return only_block
|
|
1039
|
+
|
|
1040
|
+
def create_block(self, request: CreateBlock, user_id: str, update: bool = False) -> Block:
    """Create a memory block, or update the existing block with the same identity.

    A block "already exists" when the metadata store returns a match for the
    same name, user, template flag and label.

    Args:
        request: Block creation request.
        user_id: Owner used for the existence lookup.
        update: If True, an existing match is updated instead of raising.

    Returns:
        The created (or updated) block.

    Raises:
        ValueError: if a matching block exists and ``update`` is False.
    """
    existing_blocks = self.ms.get_blocks(name=request.name, user_id=user_id, template=request.template, label=request.label)
    # get_blocks may return None or an empty list when nothing matches;
    # only index into it when it actually has entries.
    if existing_blocks:
        assert len(existing_blocks) == 1
        existing_block = existing_blocks[0]
        if update:
            # update_block takes only the request; passing user_id raised TypeError
            return self.update_block(UpdateBlock(id=existing_block.id, **vars(request)))
        raise ValueError(f"Block with name {request.name} already exists")
    block = Block(**vars(request))
    self.ms.create_block(block)
    return block
|
|
1052
|
+
|
|
1053
|
+
def update_block(self, request: UpdateBlock) -> Block:
    """Apply ``request``'s limit/value/name to an existing block and persist it.

    Fields left as ``None`` in the request keep the block's current value.

    Raises:
        ValueError: (via ``get_block``) if the block does not exist.
    """
    block = self.get_block(request.id)
    # Assign every field in order, falling back to the current value when the request omits it
    for field_name in ("limit", "value", "name"):
        requested = getattr(request, field_name)
        setattr(block, field_name, getattr(block, field_name) if requested is None else requested)
    self.ms.update_block(block=block)
    return block
|
|
1060
|
+
|
|
1061
|
+
def delete_block(self, block_id: str):
    """Delete a block by ID and return the block as it was before deletion.

    Raises:
        ValueError: (via ``get_block``) if the block does not exist.
    """
    doomed = self.get_block(block_id)  # also validates existence
    self.ms.delete_block(block_id)
    return doomed
|
|
1065
|
+
|
|
1066
|
+
# convert name->id
|
|
1067
|
+
|
|
1068
|
+
def get_agent_id(self, name: str, user_id: str):
    """Resolve an agent name to its ID for the given user, or ``None`` if no agent matches."""
    found = self.ms.get_agent(agent_name=name, user_id=user_id)
    return found.id if found else None
|
|
1073
|
+
|
|
1074
|
+
def get_source(self, source_id: str, user_id: str) -> Source:
    """Fetch a data source by ID for the given user.

    Raises:
        ValueError: if no matching source is found.
    """
    source = self.ms.get_source(source_id=source_id, user_id=user_id)
    if source:
        return source
    raise ValueError("Source does not exist")
|
|
1079
|
+
|
|
1080
|
+
def get_source_id(self, source_name: str, user_id: str) -> str:
    """Resolve a data-source name to its ID for the given user.

    Raises:
        ValueError: if no matching source is found.
    """
    source = self.ms.get_source(source_name=source_name, user_id=user_id)
    if source:
        return source.id
    raise ValueError("Source does not exist")
|
|
1085
|
+
|
|
1086
|
+
def get_agent(self, user_id: str, agent_id: str, agent_name: Optional[str] = None):
    """Get an agent's state by ID (or by name) for the given user.

    Args:
        user_id: Owner of the agent.
        agent_id: Agent ID to look up.
        agent_name: Optional agent name; previously accepted but silently
            ignored, now forwarded to the metadata store lookup.

    Returns:
        Whatever ``self.ms.get_agent`` returns for the lookup.
    """
    return self.ms.get_agent(agent_id=agent_id, agent_name=agent_name, user_id=user_id)
|
|
1089
|
+
|
|
1090
|
+
def get_user(self, user_id: str) -> User:
    """Look up a user by ID; returns whatever the metadata store returns (``None`` if not found)."""
    user = self.ms.get_user(user_id=user_id)
    return user
|
|
1093
|
+
|
|
1094
|
+
def get_agent_memory(self, agent_id: str) -> Memory:
    """Return the core ``Memory`` object of the (in-memory) agent."""
    return self._get_or_load_agent(agent_id=agent_id).memory
|
|
1098
|
+
|
|
1099
|
+
def get_archival_memory_summary(self, agent_id: str) -> ArchivalMemorySummary:
    """Summarize an agent's archival memory; currently just its size (``len``)."""
    archival = self._get_or_load_agent(agent_id=agent_id).persistence_manager.archival_memory
    return ArchivalMemorySummary(size=len(archival))
|
|
1102
|
+
|
|
1103
|
+
def get_recall_memory_summary(self, agent_id: str) -> RecallMemorySummary:
    """Summarize an agent's recall memory; currently just its size (``len``)."""
    recall = self._get_or_load_agent(agent_id=agent_id).persistence_manager.recall_memory
    return RecallMemorySummary(size=len(recall))
|
|
1106
|
+
|
|
1107
|
+
def get_in_context_message_ids(self, agent_id: str) -> List[str]:
    """Return the IDs of the agent's in-context messages, in buffer order."""
    agent = self._get_or_load_agent(agent_id=agent_id)
    message_ids = []
    for in_context_msg in agent._messages:
        message_ids.append(in_context_msg.id)
    return message_ids
|
|
1112
|
+
|
|
1113
|
+
def get_in_context_messages(self, agent_id: str) -> List[Message]:
    """Return the agent's in-context message list itself (not a copy)."""
    return self._get_or_load_agent(agent_id=agent_id)._messages
|
|
1118
|
+
|
|
1119
|
+
def get_agent_message(self, agent_id: str, message_id: str) -> Message:
    """Fetch a single message by ID from the agent's recall storage."""
    recall_storage = self._get_or_load_agent(agent_id=agent_id).persistence_manager.recall_memory.storage
    return recall_storage.get(id=message_id)
|
|
1125
|
+
|
|
1126
|
+
def get_agent_messages(
    self,
    agent_id: str,
    start: int,
    count: int,
    return_message_object: bool = True,
) -> Union[List[Message], List[LettaMessage]]:
    """Return one page of an agent's messages, newest first.

    When ``start + count`` is strictly less than the in-context buffer size,
    the page is sliced from the buffer, counting back from the newest message.
    Otherwise one page (``page_size=count``, ``offset=start``) is fetched from
    recall storage and sorted by ``created_at`` descending.
    NOTE(review): the offset direction of ``get_all_paginated`` determines
    whether both paths index ``start`` from the same end -- confirm.

    Args:
        agent_id: Agent whose messages to page through.
        start: Number of messages to skip.
        count: Page size.
        return_message_object: If False, each ``Message`` is expanded via
            ``to_letta_message()`` into ``LettaMessage`` items.

    Raises:
        ValueError: if ``start`` or ``count`` is negative.
    """
    # Get the agent object (loaded in memory)
    letta_agent = self._get_or_load_agent(agent_id=agent_id)

    if start < 0 or count < 0:
        raise ValueError("Start and count values should be non-negative")

    in_context = letta_agent._messages
    if start + count < len(in_context):
        # Serve the page from the in-context buffer, newest first
        newest_first = in_context[::-1]
        if start >= len(newest_first):
            raise IndexError("Start index is out of range")
        stop = min(start + count, len(newest_first))
        messages = newest_first[start:stop]
    else:
        # Fall back to recall storage for anything beyond the buffer
        pages = letta_agent.persistence_manager.recall_memory.storage.get_all_paginated(page_size=count, offset=start)
        # TODO: handle stop iteration
        first_page = next(pages, [])
        messages = sorted(first_page, key=lambda record: record.created_at, reverse=True)
        assert all(isinstance(m, Message) for m in messages)

    if return_message_object:
        return messages
    return [letta_msg for record in messages for letta_msg in record.to_letta_message()]
|
|
1180
|
+
|
|
1181
|
+
def get_agent_archival(self, user_id: str, agent_id: str, start: int, count: int) -> List[Passage]:
    """Return one page of passages from an agent's archival memory.

    Args:
        user_id: ID of the user that owns the agent.
        agent_id: ID of the agent whose archival memory is read.
        start: Offset handed to the storage connector's paginator.
        count: Page size handed to the storage connector's paginator.

    Returns:
        The first page yielded by ``get_all_paginated``, or ``[]`` when it yields nothing.

    Raises:
        ValueError: If the user or the agent cannot be found.
    """
    if self.ms.get_user(user_id=user_id) is None:
        raise ValueError(f"User user_id={user_id} does not exist")
    if self.ms.get_agent(agent_id=agent_id, user_id=user_id) is None:
        raise ValueError(f"Agent agent_id={agent_id} does not exist")

    archival_storage = self._get_or_load_agent(agent_id=agent_id).persistence_manager.archival_memory.storage
    pages = archival_storage.get_all_paginated(page_size=count, offset=start)

    # Only the first page is consumed; an exhausted iterator falls back to an empty list.
    return next(pages, [])
|
|
1197
|
+
|
|
1198
|
+
def get_agent_archival_cursor(
    self,
    user_id: str,
    agent_id: str,
    after: Optional[str] = None,
    before: Optional[str] = None,
    limit: Optional[int] = 100,
    order_by: Optional[str] = "created_at",
    reverse: Optional[bool] = False,
) -> List[Passage]:
    """Cursor-paginated read of an agent's archival memory.

    All cursor arguments are forwarded unchanged to the archival storage
    connector's ``get_all_cursor``; only the records half of its result is returned.

    Raises:
        ValueError: If the user or the agent cannot be found.
    """
    if self.ms.get_user(user_id=user_id) is None:
        raise ValueError(f"User user_id={user_id} does not exist")
    if self.ms.get_agent(agent_id=agent_id, user_id=user_id) is None:
        raise ValueError(f"Agent agent_id={agent_id} does not exist")

    agent = self._get_or_load_agent(agent_id=agent_id)
    storage = agent.persistence_manager.archival_memory.storage

    # The connector also returns the next cursor, which this method does not expose.
    _next_cursor, passages = storage.get_all_cursor(
        after=after,
        before=before,
        limit=limit,
        order_by=order_by,
        reverse=reverse,
    )
    return passages
|
|
1221
|
+
|
|
1222
|
+
def insert_archival_memory(self, user_id: str, agent_id: str, memory_contents: str) -> List[Passage]:
    """Insert text into an agent's archival memory and return the stored passages.

    Args:
        user_id: ID of the user that owns the agent.
        agent_id: ID of the target agent.
        memory_contents: Text handed to ``archival_memory.insert``.

    Returns:
        The passages re-read from storage, one per ID returned by the insert.

    Raises:
        ValueError: If the user or the agent cannot be found.
    """
    if self.ms.get_user(user_id=user_id) is None:
        raise ValueError(f"User user_id={user_id} does not exist")
    if self.ms.get_agent(agent_id=agent_id, user_id=user_id) is None:
        raise ValueError(f"Agent agent_id={agent_id} does not exist")

    agent = self._get_or_load_agent(agent_id=agent_id)
    archival = agent.persistence_manager.archival_memory
    new_ids = archival.insert(memory_string=memory_contents, return_ids=True)

    # TODO: this is gross, fix -- each passage is fetched back individually by ID.
    inserted = []
    for passage_id in new_ids:
        inserted.append(archival.storage.get(id=passage_id))
    return inserted
|
|
1236
|
+
|
|
1237
|
+
def delete_archival_memory(self, user_id: str, agent_id: str, memory_id: str):
    """Delete one passage, by ID, from an agent's archival memory.

    Raises:
        ValueError: If the user or the agent cannot be found.
    """
    if self.ms.get_user(user_id=user_id) is None:
        raise ValueError(f"User user_id={user_id} does not exist")
    if self.ms.get_agent(agent_id=agent_id, user_id=user_id) is None:
        raise ValueError(f"Agent agent_id={agent_id} does not exist")

    archival_storage = self._get_or_load_agent(agent_id=agent_id).persistence_manager.archival_memory.storage

    # TODO: check that the passage exists first (and raise if not); also return the deleted passage.
    archival_storage.delete({"id": memory_id})
|
|
1253
|
+
|
|
1254
|
+
def get_agent_recall_cursor(
    self,
    user_id: str,
    agent_id: str,
    after: Optional[str] = None,
    before: Optional[str] = None,
    limit: Optional[int] = 100,
    order_by: Optional[str] = "created_at",
    order: Optional[str] = "asc",
    reverse: Optional[bool] = False,
    return_message_object: bool = True,
) -> Union[List[Message], List[LettaMessage]]:
    """Cursor-paginated read of an agent's recall (message) memory.

    Args:
        after, before, limit, order_by, reverse: Forwarded to the recall storage
            connector's ``get_all_cursor``.
        order: NOTE(review): accepted but not used by this method.
        return_message_object: When True, return the stored ``Message`` records;
            otherwise expand each record via ``to_letta_message()``.

    Raises:
        ValueError: If the user or the agent cannot be found.
    """
    if self.ms.get_user(user_id=user_id) is None:
        raise ValueError(f"User user_id={user_id} does not exist")
    if self.ms.get_agent(agent_id=agent_id, user_id=user_id) is None:
        raise ValueError(f"Agent agent_id={agent_id} does not exist")

    recall_storage = self._get_or_load_agent(agent_id=agent_id).persistence_manager.recall_memory.storage
    _next_cursor, records = recall_storage.get_all_cursor(
        after=after,
        before=before,
        limit=limit,
        order_by=order_by,
        reverse=reverse,
    )
    assert all(isinstance(m, Message) for m in records)

    if return_message_object:
        return records

    letta_messages = []
    for record in records:
        parts = record.to_letta_message()
        # When paging in reverse, the sub-messages expanded from one record are reversed too.
        letta_messages.extend(parts[::-1] if reverse else parts)
    return letta_messages
|
|
1289
|
+
|
|
1290
|
+
def get_agent_state(self, user_id: str, agent_id: Optional[str], agent_name: Optional[str] = None) -> Optional[AgentState]:
    """Return a deep copy of an agent's state (its config).

    The agent is looked up by ``agent_id`` when it is truthy, otherwise by ``agent_name``.

    Returns:
        A deep copy of the loaded agent's ``agent_state``, or None when ``agent_id``
        is given but no such agent exists for the user.

    Raises:
        ValueError: If the user does not exist, or the name lookup finds no agent.
    """
    if self.ms.get_user(user_id=user_id) is None:
        raise ValueError(f"User user_id={user_id} does not exist")

    if not agent_id:
        named_state = self.ms.get_agent(agent_name=agent_name, user_id=user_id)
        if named_state is None:
            raise ValueError(f"Agent agent_name={agent_name} does not exist")
        agent_id = named_state.id
    elif self.ms.get_agent(agent_id=agent_id, user_id=user_id) is None:
        return None

    agent = self._get_or_load_agent(agent_id=agent_id)
    assert isinstance(agent.memory, Memory)
    assert isinstance(agent.agent_state.memory, Memory)
    # Copy so callers cannot mutate the cached agent's state.
    return agent.agent_state.model_copy(deep=True)
|
|
1308
|
+
|
|
1309
|
+
def get_server_config(self, include_defaults: bool = False) -> dict:
    """Return the server config as a dict, with secret-looking entries shortened.

    Args:
        include_defaults: When True, also include a ``"defaults"`` entry built from a
            freshly constructed ``LettaConfig()``.

    Returns:
        ``{"config": {...}}``, plus ``"defaults"`` when requested. In each section the
        ``default_llm_config`` / ``default_embedding_config`` entries are plain dicts.
    """

    def clean_keys(config):
        # Shallow copy so the config object's own attribute dict is never modified.
        config_copy = config.copy()
        for k, v in config.items():
            if k == "key" or "_key" in k:
                config_copy[k] = server_utils.shorten_key_middle(v, chars_each_side=5)
        return config_copy

    def as_dict(sub_config):
        # vars() returns the object's live __dict__; copy it so mutating the response
        # cannot mutate the server's config objects, and so sections do not share a dict.
        return dict(vars(sub_config))

    # TODO: do we need a seperate server config?
    clean_base_config = clean_keys(vars(self.config))
    base_llm_config = as_dict(clean_base_config["default_llm_config"])
    base_embedding_config = as_dict(clean_base_config["default_embedding_config"])
    clean_base_config["default_llm_config"] = base_llm_config
    clean_base_config["default_embedding_config"] = base_embedding_config
    response = {"config": clean_base_config}

    if include_defaults:
        clean_default_config = clean_keys(vars(LettaConfig()))
        # Report the defaults' own model/embedding configs. Fall back to the server's
        # values only when a fresh LettaConfig leaves them unset (None).
        default_llm = clean_default_config.get("default_llm_config")
        default_embedding = clean_default_config.get("default_embedding_config")
        clean_default_config["default_llm_config"] = (
            as_dict(default_llm) if default_llm is not None else dict(base_llm_config)
        )
        clean_default_config["default_embedding_config"] = (
            as_dict(default_embedding) if default_embedding is not None else dict(base_embedding_config)
        )
        response["defaults"] = clean_default_config

    return response
|
|
1338
|
+
|
|
1339
|
+
def get_available_models(self) -> List[LLMConfig]:
    """Ask the default LLM endpoint which models it offers.

    Returns:
        Whatever ``get_model_options`` returns for the server's default LLM endpoint.

    Raises:
        Exception: Any error from the lookup is logged and re-raised unchanged.
    """
    credentials = LettaCredentials().load()

    try:
        llm_config = self.config.default_llm_config
        return get_model_options(
            credentials=credentials,
            model_endpoint_type=llm_config.model_endpoint_type,
            model_endpoint=llm_config.model_endpoint,
        )
    except Exception as e:
        logger.exception(f"Failed to get list of available models from LLM endpoint:\n{str(e)}")
        raise
|
|
1355
|
+
|
|
1356
|
+
def update_agent_core_memory(self, user_id: str, agent_id: str, new_memory_contents: dict) -> Memory:
    """Update the agent's core memory blocks and return the persisted memory.

    Args:
        user_id: ID of the user that owns the agent.
        agent_id: ID of the agent to update.
        new_memory_contents: Mapping of block name -> new value. Entries whose value
            is None are skipped (after the block name has been validated).

    Returns:
        The agent's ``memory`` as re-read from the metadata store.

    Raises:
        ValueError: If the user or agent does not exist, or a key names no memory block.
    """
    if self.ms.get_user(user_id=user_id) is None:
        raise ValueError(f"User user_id={user_id} does not exist")
    if self.ms.get_agent(agent_id=agent_id, user_id=user_id) is None:
        raise ValueError(f"Agent agent_id={agent_id} does not exist")

    letta_agent = self._get_or_load_agent(agent_id=agent_id)
    memory = letta_agent.memory

    modified = False
    for key, value in new_memory_contents.items():
        block = memory.get_block(key)
        # NOTE(review): get_block may signal a missing block by *returning* an exception
        # object rather than raising it -- confirm against Memory.get_block. Treat that
        # case like None so unknown keys are rejected here with a clear message.
        if block is None or isinstance(block, Exception):
            raise ValueError(f"Key {key} not found in agent memory {str(letta_agent.memory.memory)}")
        if value is None:
            continue
        # Compare against the block's stored text. Comparing the Block object itself to
        # the new value is always unequal, which forced a rebuild and save on every call.
        if block.value != value:
            memory.update_block_value(name=key, value=value)
            modified = True

    # The system message embeds the memory contents, so rebuild and persist only on change.
    if modified:
        letta_agent.rebuild_memory()
        save_agent(letta_agent, self.ms)

    return self.ms.get_agent(agent_id=agent_id).memory
|
|
1386
|
+
|
|
1387
|
+
def rename_agent(self, user_id: str, agent_id: str, new_agent_name: str) -> AgentState:
    """Rename an agent in memory and in the metadata store.

    Args:
        user_id: ID of the user that owns the agent.
        agent_id: ID of the agent to rename.
        new_agent_name: The new name; must differ from the current one.

    Returns:
        The loaded agent's (updated) ``agent_state``.

    Raises:
        ValueError: If the user or agent does not exist, the name is unchanged, or the
            database update fails. In the last case the in-memory name is restored.
    """
    if self.ms.get_user(user_id=user_id) is None:
        raise ValueError(f"User user_id={user_id} does not exist")
    if self.ms.get_agent(agent_id=agent_id, user_id=user_id) is None:
        raise ValueError(f"Agent agent_id={agent_id} does not exist")

    letta_agent = self._get_or_load_agent(agent_id=agent_id)

    current_name = letta_agent.agent_state.name
    if current_name == new_agent_name:
        raise ValueError(f"New name ({new_agent_name}) is the same as the current name")

    try:
        letta_agent.agent_state.name = new_agent_name
        self.ms.update_agent(agent=letta_agent.agent_state)
    except Exception as e:
        # Roll back so the cached agent does not keep a name the database never stored.
        letta_agent.agent_state.name = current_name
        logger.exception(f"Failed to update agent name with:\n{str(e)}")
        raise ValueError("Failed to update agent name in database") from e

    assert isinstance(letta_agent.agent_state.id, str)
    return letta_agent.agent_state
|
|
1410
|
+
|
|
1411
|
+
def delete_user(self, user_id: str):
    """Delete a user. Not implemented yet: this is currently a no-op that returns None."""
    # TODO: delete user
    return None
|
|
1414
|
+
|
|
1415
|
+
def delete_agent(self, user_id: str, agent_id: str):
    """Delete an agent from the in-memory cache and from the metadata store.

    Args:
        user_id: ID of the user that must own the agent.
        agent_id: ID of the agent to delete.

    Raises:
        ValueError: If the user or agent does not exist, the agent belongs to another
            user, or removal from the cache or the database fails.
    """
    if self.ms.get_user(user_id=user_id) is None:
        raise ValueError(f"User user_id={user_id} does not exist")

    # One lookup both verifies existence and provides the state for the ownership check
    # (previously the agent was fetched twice, leaving the second None-check unreachable).
    agent_state = self.ms.get_agent(agent_id=agent_id, user_id=user_id)
    if agent_state is None:
        raise ValueError(f"Agent agent_id={agent_id} does not exist")
    if agent_state.user_id != user_id:
        raise ValueError(f"Could not authorize agent_id={agent_id} with user_id={user_id}")

    # First drop the agent from the in-memory cache.
    # List of {'user_id': user_id, 'agent_id': agent_id, 'agent': agent_obj} dicts
    try:
        self.active_agents = [d for d in self.active_agents if str(d["agent_id"]) != str(agent_id)]
    except Exception as e:
        logger.exception(f"Failed to delete agent {agent_id} from cache via ID with:\n{str(e)}")
        raise ValueError(f"Failed to delete agent {agent_id} from cache") from e

    # Then delete it from the database.
    try:
        self.ms.delete_agent(agent_id=agent_id)
    except Exception as e:
        logger.exception(f"Failed to delete agent {agent_id} via ID with:\n{str(e)}")
        raise ValueError(f"Failed to delete agent {agent_id} in database") from e
|
|
1443
|
+
|
|
1444
|
+
def authenticate_user(self) -> str:
    """Return the locally configured anonymous client ID as the user ID.

    No credentials are checked.
    """
    # TODO: Implement actual authentication to enable multi user setup
    config = LettaConfig.load()
    return str(config.anon_clientid)
|
|
1447
|
+
|
|
1448
|
+
def api_key_to_user(self, api_key: str) -> str:
    """Resolve an API key to the ID of the user it belongs to.

    Raises:
        HTTPException: 403 when the key matches no user.
    """
    user = self.ms.get_user_from_api_key(api_key=api_key)
    if user is not None:
        return user.id
    raise HTTPException(status_code=403, detail="Invalid credentials")
|
|
1455
|
+
|
|
1456
|
+
def create_api_key(self, request: APIKeyCreate) -> APIKey:  # TODO: add other fields
    """Create a new API key for ``request.user_id``.

    When ``request.name`` is None, it is filled in (on the request object itself)
    with a local-time timestamped default before the key is created.
    """
    if request.name is None:
        timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        request.name = f"API Key {timestamp}"
    return self.ms.create_api_key(user_id=request.user_id, name=request.name)
|
|
1462
|
+
|
|
1463
|
+
def list_api_keys(self, user_id: str) -> List[APIKey]:
    """Return every API key registered to ``user_id``."""
    keys = self.ms.get_all_api_keys_for_user(user_id=user_id)
    return keys
|
|
1466
|
+
|
|
1467
|
+
def delete_api_key(self, api_key: str) -> APIKey:
    """Delete an API key and return the record it had before deletion.

    Raises:
        ValueError: If the key is unknown.
    """
    existing = self.ms.get_api_key(api_key=api_key)
    if existing is None:
        raise ValueError("API key does not exist")
    self.ms.delete_api_key(api_key=api_key)
    # The record was fetched before deletion so it can still be handed back.
    return existing
|
|
1473
|
+
|
|
1474
|
+
def create_source(self, request: SourceCreate, user_id: str) -> Source:  # TODO: add other fields
    """Create a data source for ``user_id`` using the server's default embedding config.

    Returns:
        The ``Source`` object that was handed to the metadata store.
    """
    new_source = Source(
        name=request.name,
        user_id=user_id,
        embedding_config=self.config.default_embedding_config,
    )
    self.ms.create_source(new_source)
    # Sanity check that the store can now find the source by name for this user.
    assert self.ms.get_source(source_name=request.name, user_id=user_id) is not None, f"Failed to create source {request.name}"
    return new_source
|
|
1484
|
+
|
|
1485
|
+
def update_source(self, request: SourceUpdate, user_id: str) -> Source:
    """Apply the truthy fields of ``request`` to an existing data source.

    The source is located by ``request.id`` when set, otherwise by ``request.name``
    for ``user_id``. Only ``name``, ``metadata_`` and ``description`` are copied, and
    only when truthy.

    Raises:
        ValueError: If no matching source exists.
    """
    if request.id:
        existing_source = self.ms.get_source(source_id=request.id)
    else:
        existing_source = self.ms.get_source(source_name=request.name, user_id=user_id)
    if not existing_source:
        raise ValueError("Source does not exist")

    # override updated fields
    for field in ("name", "metadata_", "description"):
        new_value = getattr(request, field)
        if new_value:
            setattr(existing_source, field, new_value)

    self.ms.update_source(existing_source)
    return existing_source
|
|
1504
|
+
|
|
1505
|
+
def delete_source(self, source_id: str, user_id: str):
    """Delete a data source and the passages loaded from it.

    Args:
        source_id: ID of the source to delete.
        user_id: ID of the user that must own the source.

    Raises:
        ValueError: If the source does not exist or belongs to a different user.
    """
    source = self.ms.get_source(source_id=source_id, user_id=user_id)
    # The lookup result was previously ignored, so deletion went ahead even for a missing
    # or foreign source. Verify existence and ownership first (mirrors delete_agent).
    if source is None:
        raise ValueError("Source does not exist")
    if source.user_id != user_id:
        raise ValueError(f"Could not authorize source_id={source_id} with user_id={user_id}")

    self.ms.delete_source(source_id)

    # delete data from passage store
    passage_store = StorageConnector.get_storage_connector(TableType.PASSAGES, self.config, user_id=user_id)
    passage_store.delete({"source_id": source_id})

    # TODO: delete data from agent passage stores (?)
|
|
1515
|
+
|
|
1516
|
+
def create_job(self, user_id: str) -> Job:
    """Create and persist a job for ``user_id`` in the ``created`` state."""
    new_job = Job(status=JobStatus.created, user_id=user_id)
    self.ms.create_job(new_job)
    return new_job
|
|
1524
|
+
|
|
1525
|
+
def delete_job(self, job_id: str):
    """Remove the job with ``job_id`` from the metadata store."""
    store = self.ms
    store.delete_job(job_id)
|
|
1528
|
+
|
|
1529
|
+
def get_job(self, job_id: str) -> Job:
    """Fetch the job with ``job_id`` from the metadata store."""
    job = self.ms.get_job(job_id)
    return job
|
|
1532
|
+
|
|
1533
|
+
def list_jobs(self, user_id: str) -> List[Job]:
    """Return every job belonging to ``user_id``, whatever its status."""
    jobs = self.ms.list_jobs(user_id=user_id)
    return jobs
|
|
1536
|
+
|
|
1537
|
+
def list_active_jobs(self, user_id: str) -> List[Job]:
    """Return the user's jobs that are still ``created`` or ``running``."""
    active_statuses = (JobStatus.created, JobStatus.running)
    return [job for job in self.ms.list_jobs(user_id=user_id) if job.status in active_statuses]
|
|
1541
|
+
|
|
1542
|
+
def load_file_to_source(self, source_id: str, file_path: str, job_id: str) -> Job:
    """Load one file into a data source, recording progress on an existing job.

    The job is marked ``running``, the file is loaded through a ``DirectoryConnector``,
    and the job is marked ``completed`` with ``num_passages`` / ``num_documents`` in its
    metadata.

    Returns:
        The updated job.

    Raises:
        ValueError: If the source does not exist.
        Exception: Any loading error is re-raised after the job has been marked
            ``failed`` with the message stored under ``metadata_["error"]``.
    """
    # update job
    job = self.ms.get_job(job_id)
    job.status = JobStatus.running
    self.ms.update_job(job)

    try:
        from letta.data_sources.connectors import DirectoryConnector

        source = self.ms.get_source(source_id=source_id)
        if source is None:
            raise ValueError(f"Source source_id={source_id} does not exist")
        connector = DirectoryConnector(input_files=[file_path])
        num_passages, num_documents = self.load_data(user_id=source.user_id, source_name=source.name, connector=connector)
    except Exception as e:
        # Without this the job would stay "running" forever after a failed load.
        job.status = JobStatus.failed
        job.metadata_["error"] = str(e)
        self.ms.update_job(job)
        # TODO: delete any associated passages/documents?
        raise

    # update job status
    job.status = JobStatus.completed
    job.metadata_["num_passages"] = num_passages
    job.metadata_["num_documents"] = num_documents
    self.ms.update_job(job)

    return job
|
|
1574
|
+
|
|
1575
|
+
def load_data(
    self,
    user_id: str,
    connector: DataConnector,
    source_name: str,
) -> Tuple[int, int]:
    """Load data from a DataConnector into a user's named source.

    Returns:
        ``(passage_count, document_count)`` as reported by the module-level loader.

    Raises:
        ValueError: If the source does not exist for the user.
    """
    # TODO: this should be implemented as a batch job or at least async, since it may take a long time
    source = self.ms.get_source(source_name=source_name, user_id=user_id)
    if source is None:
        raise ValueError(f"Data source {source_name} does not exist for user {user_id}")

    passage_store = StorageConnector.get_storage_connector(TableType.PASSAGES, self.config, user_id=user_id)
    # TODO: add document store support
    # (would be StorageConnector.get_storage_connector(TableType.DOCUMENTS, self.config, user_id=user_id))
    document_store = None

    # Inside this method the bare name `load_data` resolves to the module-level function,
    # not to this method.
    counts = load_data(connector, source, passage_store, document_store)
    num_passages, num_documents = counts
    return (num_passages, num_documents)
|
|
1597
|
+
|
|
1598
|
+
def attach_source_to_agent(
    self,
    user_id: str,
    agent_id: str,
    # source_id: str,
    source_id: Optional[str] = None,
    source_name: Optional[str] = None,
) -> Source:
    """Attach a user's data source (found by ID and/or name) to an agent.

    Returns:
        The attached source.

    Raises:
        ValueError: If the source cannot be found for the user.
    """
    data_source = self.ms.get_source(source_id=source_id, user_id=user_id, source_name=source_name)
    if data_source is None:
        raise ValueError(f"Data source id={source_id} name={source_name} does not exist for user_id {user_id}")

    # Connection to the user's passage storage, which the agent reads the source's data from.
    passages = StorageConnector.get_storage_connector(TableType.PASSAGES, self.config, user_id=user_id)

    target_agent = self._get_or_load_agent(agent_id=agent_id)
    target_agent.attach_source(data_source.id, passages, self.ms)

    return data_source
|
|
1621
|
+
|
|
1622
|
+
def detach_source_from_agent(
    self,
    user_id: str,
    agent_id: str,
    # source_id: str,
    source_id: Optional[str] = None,
    source_name: Optional[str] = None,
) -> Source:
    """Detach a data source from an agent. Not implemented yet.

    Raises:
        NotImplementedError: Always.
    """
    # TODO: remove all passages coresponding to source from agent's archival memory
    raise NotImplementedError
|
|
1632
|
+
|
|
1633
|
+
def list_attached_sources(self, agent_id: str) -> List[Source]:
    """Return the data sources currently attached to ``agent_id``."""
    attached = self.ms.list_attached_sources(agent_id)
    return attached
|
|
1636
|
+
|
|
1637
|
+
def list_data_source_passages(self, user_id: str, source_id: str) -> List[Passage]:
    """Placeholder: always emits a UserWarning and returns an empty list."""
    message = "list_data_source_passages is not yet implemented, returning empty list."
    warnings.warn(message, category=UserWarning)
    return []
|
|
1640
|
+
|
|
1641
|
+
def list_data_source_documents(self, user_id: str, source_id: str) -> List[Document]:
    """Placeholder: always emits a UserWarning and returns an empty list."""
    message = "list_data_source_documents is not yet implemented, returning empty list."
    warnings.warn(message, category=UserWarning)
    return []
|
|
1644
|
+
|
|
1645
|
+
def list_all_sources(self, user_id: str) -> List[Source]:
    """List all sources belonging to a user, annotated with usage metadata.

    Each returned source has its ``metadata_`` overwritten with:
      - ``num_documents``: currently always 0 (documents table not implemented),
      - ``num_passages``: passages in the user's passage store with this ``source_id``,
      - ``attached_agents``: list of ``{"id": ..., "name": ...}`` for attached agents;
        ``name`` is None when an attached agent can no longer be found.
    """
    sources = self.ms.list_sources(user_id=user_id)
    if not sources:
        return []

    # The passage connector depends only on user_id, so open it once rather than per source.
    passage_conn = StorageConnector.get_storage_connector(TableType.PASSAGES, self.config, user_id=user_id)

    sources_with_metadata = []
    for source in sources:
        num_passages = passage_conn.size({"source_id": source.id})

        # TODO: count documents once the documents table is implemented, e.g. via
        # StorageConnector.get_storage_connector(TableType.DOCUMENTS, ...).size({"data_source": source.name})
        num_documents = 0

        attached_agents = []
        for a_id in self.ms.list_attached_agents(source_id=source.id):
            agent_state = self.ms.get_agent(user_id=user_id, agent_id=a_id)
            # Guard against a dangling attachment instead of crashing on `.name` of None.
            attached_agents.append(
                {
                    "id": str(a_id),
                    "name": agent_state.name if agent_state is not None else None,
                }
            )

        # Overwrite metadata field, should be empty anyways
        source.metadata_ = dict(
            num_documents=num_documents,
            num_passages=num_passages,
            attached_agents=attached_agents,
        )
        sources_with_metadata.append(source)

    return sources_with_metadata
|
|
1684
|
+
|
|
1685
|
+
def get_tool(self, tool_id: str) -> Optional[Tool]:
    """Look up a tool by its ID (whatever the metadata store returns, possibly None)."""
    tool = self.ms.get_tool(tool_id=tool_id)
    return tool
|
|
1688
|
+
|
|
1689
|
+
def get_tool_id(self, name: str, user_id: str) -> Optional[str]:
    """Return the ID of the user's tool called ``name``, or None if there is no such tool."""
    tool = self.ms.get_tool(tool_name=name, user_id=user_id)
    # A found tool whose id is None also yields None.
    return tool.id if tool else None
|
|
1695
|
+
|
|
1696
|
+
def update_tool(
    self,
    request: ToolUpdate,
) -> Tool:
    """Apply the truthy fields of ``request`` to the stored tool ``request.id``.

    Returns:
        The tool as re-read from the metadata store after the update.

    Raises:
        ValueError: If no tool with ``request.id`` exists.
    """
    existing_tool = self.ms.get_tool(tool_id=request.id)
    if not existing_tool:
        raise ValueError("Tool does not exist")

    # override updated fields (falsy request values leave the stored value untouched)
    for field in ("source_code", "source_type", "tags", "json_schema", "name"):
        new_value = getattr(request, field)
        if new_value:
            setattr(existing_tool, field, new_value)

    self.ms.update_tool(existing_tool)
    return self.ms.get_tool(tool_id=request.id)
|
|
1719
|
+
|
|
1720
|
+
def create_tool(self, request: ToolCreate, user_id: Optional[str] = None, update: bool = True) -> Tool:  # TODO: add other fields
    """Create a tool, or update the user's existing tool with the same name.

    When ``request.json_schema`` is empty, ``request.source_code`` is executed and a
    schema is generated from the last callable it defines. When ``request.name`` is
    empty, it is filled in (on the request) from the schema's ``"name"``.

    Args:
        request: Tool definition.
        user_id: Owner of the tool.
        update: When True, an existing tool with the same name is updated in place;
            when False, that situation raises.

    Returns:
        The created or updated tool as read back from the metadata store.

    Raises:
        ValueError: If the source code fails to execute or defines no callable, or the
            tool exists and ``update`` is False.
    """
    if not request.json_schema:
        # NOTE(security): this executes caller-supplied source code in-process. If
        # request.source_code can come from untrusted clients, this is arbitrary code
        # execution and should be sandboxed or restricted.
        base_env = dict(globals())
        env = dict(base_env)
        try:
            exec(request.source_code, env)
        except Exception as e:
            logger.error(f"Failed to execute source code: {e}")
            # Previously execution continued and failed with an unbound `functions` name.
            raise ValueError(f"Failed to execute source code: {e}") from e

        # Consider only callables that the tool source bound or rebound, not this
        # module's globals that seeded the namespace.
        functions = [
            name
            for name, obj in env.items()
            if callable(obj) and (name not in base_env or base_env[name] is not obj)
        ]
        if not functions:
            raise ValueError("No function found in tool source code")

        # TODO: not sure if this always works -- picks the last callable defined
        func = env[functions[-1]]
        json_schema = generate_schema(func, request.name)
    else:
        # provided by client
        json_schema = request.json_schema

    if not request.name:
        # use name from JSON schema
        request.name = json_schema["name"]
        assert request.name, f"Tool name must be provided in json_schema {json_schema}. This should never happen."

    # check if already exists:
    existing_tool = self.ms.get_tool(tool_name=request.name, user_id=user_id)
    if existing_tool:
        if not update:
            raise ValueError(f"Tool {request.name} already exists and update=False")
        # Pass the resolved schema explicitly: when it was auto-generated,
        # request.json_schema is still empty and update_tool would keep the stale one.
        update_fields = {**vars(request), "json_schema": json_schema}
        updated_tool = self.update_tool(ToolUpdate(id=existing_tool.id, **update_fields))
        assert updated_tool is not None, f"Failed to update tool {request.name}"
        return updated_tool

    tool = Tool(
        name=request.name,
        source_code=request.source_code,
        source_type=request.source_type,
        tags=request.tags,
        json_schema=json_schema,
        user_id=user_id,
    )
    self.ms.create_tool(tool)
    created_tool = self.ms.get_tool(tool_name=request.name, user_id=user_id)
    return created_tool
|
|
1775
|
+
|
|
1776
|
+
def delete_tool(self, tool_id: str):
    """Remove the tool with ``tool_id`` from the metadata store."""
    store = self.ms
    store.delete_tool(tool_id)
|
|
1779
|
+
|
|
1780
|
+
def list_tools(self, user_id: str) -> List[Tool]:
    """Return the tools the metadata store reports as available to ``user_id``."""
    return self.ms.list_tools(user_id)
|
|
1784
|
+
|
|
1785
|
+
def add_default_tools(self, module_name="base", user_id: Optional[str] = None):
    """Register every function in ``letta.functions.function_sets.{module_name}`` as a tool.

    Each function becomes (or updates) a tool tagged with ``module_name``, plus
    ``"letta-base"`` for the ``base`` module.

    Raises:
        ImportError: If the function-set module cannot be imported.
        ValueError: If the function set fails to load.
    """
    full_module_name = f"letta.functions.function_sets.{module_name}"
    module = importlib.import_module(full_module_name)

    try:
        # Load the function set
        functions_to_schema = load_function_set(module)
    except ValueError as e:
        # Previously this message was built and dropped, leaving functions_to_schema
        # unbound and failing below with a NameError.
        raise ValueError(f"Error loading function set '{module_name}': {e}") from e

    # create tool in db
    for name, schema in functions_to_schema.items():
        source_code = inspect.getsource(schema["python_function"])
        tags = [module_name]
        if module_name == "base":
            tags.append("letta-base")

        self.create_tool(
            ToolCreate(
                name=name,
                tags=tags,
                source_type="python",
                module=schema["module"],
                source_code=source_code,
                json_schema=schema["json_schema"],
                user_id=user_id,
            ),
            update=True,
        )
|
|
1821
|
+
|
|
1822
|
+
def add_default_blocks(self, user_id: str):
    """Create (or update) template persona and human blocks from the bundled text files.

    Each file's contents become the block value. Its base name with ``.txt`` removed
    becomes the block name.
    """
    from letta.utils import list_human_files, list_persona_files

    assert user_id is not None, "User ID must be provided"

    for persona_file in list_persona_files():
        # Context manager closes the handle (previously left to the garbage collector).
        with open(persona_file, "r", encoding="utf-8") as f:
            text = f.read()
        name = os.path.basename(persona_file).replace(".txt", "")
        self.create_block(CreatePersona(user_id=user_id, name=name, value=text, template=True), user_id=user_id, update=True)

    for human_file in list_human_files():
        with open(human_file, "r", encoding="utf-8") as f:
            text = f.read()
        name = os.path.basename(human_file).replace(".txt", "")
        self.create_block(CreateHuman(user_id=user_id, name=name, value=text, template=True), user_id=user_id, update=True)
|
|
1836
|
+
|
|
1837
|
+
def get_agent_message(self, agent_id: str, message_id: str) -> Optional[Message]:
    """Fetch one message by ID from the agent's recall-memory storage."""
    recall_storage = self._get_or_load_agent(agent_id=agent_id).persistence_manager.recall_memory.storage
    return recall_storage.get(id=message_id)
|
|
1843
|
+
|
|
1844
|
+
def update_agent_message(self, agent_id: str, request: UpdateMessage) -> Message:
    """Update the details of a message associated with an agent.

    Loads the agent via ``_get_or_load_agent`` and delegates the edit to
    the agent's ``update_message`` method.

    Args:
        agent_id: ID of the agent that owns the message.
        request: The update request, forwarded unchanged.

    Returns:
        The value returned by the agent's ``update_message``.
    """
    # TODO decide whether this should be done in the server.py or agent.py
    # Reason to put it in agent.py:
    # - we use the agent object's persistence_manager to update the message
    # - it makes it easy to do things like `retry`, `rethink`, etc.
    # Reason to put it in server.py:
    # - fundamentally, we should be able to edit a message (without agent id)
    #   in the server by directly accessing the DB / message store
    letta_agent = self._get_or_load_agent(agent_id=agent_id)
    return letta_agent.update_message(request=request)
|
|
1887
|
+
|
|
1888
|
+
def rewrite_agent_message(self, agent_id: str, new_text: str) -> Message:
    """Ask an agent to rewrite one of its messages with ``new_text``.

    Which message is rewritten is decided by the agent's
    ``rewrite_message`` method, not here.

    Args:
        agent_id: ID of the agent (obtained through ``_get_or_load_agent``).
        new_text: Replacement text forwarded to the agent.

    Returns:
        The result of the agent's ``rewrite_message``.
    """
    agent = self._get_or_load_agent(agent_id=agent_id)
    rewritten = agent.rewrite_message(new_text=new_text)
    return rewritten
|
|
1893
|
+
|
|
1894
|
+
def rethink_agent_message(self, agent_id: str, new_thought: str) -> Message:
    """Ask an agent to replace a message's thought with ``new_thought``.

    Which message is targeted is decided by the agent's
    ``rethink_message`` method, not here.

    Args:
        agent_id: ID of the agent (obtained through ``_get_or_load_agent``).
        new_thought: Replacement thought forwarded to the agent.

    Returns:
        The result of the agent's ``rethink_message``.
    """
    agent = self._get_or_load_agent(agent_id=agent_id)
    rethought = agent.rethink_message(new_thought=new_thought)
    return rethought
|
|
1899
|
+
|
|
1900
|
+
def retry_agent_message(self, agent_id: str) -> List[Message]:
    """Ask an agent to retry its last step via its ``retry_message`` method.

    Args:
        agent_id: ID of the agent (obtained through ``_get_or_load_agent``).

    Returns:
        The result of the agent's ``retry_message`` (annotated as a list of
        messages).
    """
    agent = self._get_or_load_agent(agent_id=agent_id)
    retried = agent.retry_message()
    return retried
|
|
1905
|
+
|
|
1906
|
+
def set_current_user(self, user_id: Optional[str]):
    """Record ``user_id`` as the server's current user.

    This is a very hacky stopgap, to be replaced once the server becomes
    stateless.

    NOTE: clearly not thread-safe. The value lives on the shared server
    instance; it only exists to give the REST API basic user_id support
    for now.

    Args:
        user_id: ID of an existing user, or None to clear the current user.

    Raises:
        ValueError: If ``user_id`` is given but ``get_user`` returns a falsy
            value for it.
    """
    # None is accepted as-is. Any other ID must resolve to a real user
    # before it is stored.
    if user_id is not None and not self.get_user(user_id):
        raise ValueError(f"User with id {user_id} not found")
    self._current_user = user_id
|
|
1919
|
+
|
|
1920
|
+
def get_default_user(self) -> User:
    """Return the default user, creating the default org/user if missing.

    - If ``self.ms`` has no organization with ``DEFAULT_ORG_ID``, one is
      created.
    - If ``get_user`` finds no user with ``DEFAULT_USER_ID``, one is created
      under the default org and seeded with:
        - the default blocks (``add_default_blocks``);
        - the "base" tools (``add_default_tools``).

    Returns:
        The result of ``self.get_user(DEFAULT_USER_ID)`` after any creation.
    """
    from letta.constants import (
        DEFAULT_ORG_ID,
        DEFAULT_ORG_NAME,
        DEFAULT_USER_ID,
        DEFAULT_USER_NAME,
    )

    # Ensure the default organization exists.
    if not self.ms.get_organization(DEFAULT_ORG_ID):
        self.ms.create_organization(Organization(name=DEFAULT_ORG_NAME, id=DEFAULT_ORG_ID))

    # Ensure the default user exists. Only a freshly created user is seeded
    # with default data (TODO: move to org).
    if not self.get_user(DEFAULT_USER_ID):
        new_user = User(name=DEFAULT_USER_NAME, org_id=DEFAULT_ORG_ID, id=DEFAULT_USER_ID)
        self.ms.create_user(new_user)
        self.add_default_blocks(new_user.id)
        self.add_default_tools(module_name="base", user_id=new_user.id)

    # Look the user up again instead of returning the locally built object.
    return self.get_user(DEFAULT_USER_ID)
|
|
1947
|
+
|
|
1948
|
+
# TODO(ethan) wire back to real method in future ORM PR
def get_current_user(self) -> User:
    """Return the currently authed user.

    Since the server is the core gateway, this needs to pass through the
    server as the first touchpoint.

    Resolution order:

    - If a user ID was stored in ``self._current_user`` and ``get_user``
      finds it, that user is returned.
    - If the stored ID is not found, a warning is emitted and the default
      user is returned.
    - If no ID is stored, the default user is returned.

    Returns:
        The current user, or the result of ``get_default_user()``.
    """
    # _current_user may never have been assigned on this instance, so fall
    # back to None instead of raising AttributeError.
    current_user_id = getattr(self, "_current_user", None)
    if current_user_id is not None:
        current_user = self.get_user(current_user_id)
        if current_user:
            return current_user
        warnings.warn(f"Provided user '{current_user_id}' not found, using default user")

    return self.get_default_user()
|
|
1978
|
+
|
|
1979
|
+
def list_models(self) -> List[LLMConfig]:
    """List the available LLM configurations.

    Currently this is always a one-element list holding
    ``settings.llm_config``.

    TODO: allow multiple options from the endpoint (e.g. via a
    ``get_model_options`` lookup against the configured LLM endpoint).
    """
    configured = settings.llm_config
    return [configured]
|
|
1990
|
+
|
|
1991
|
+
def list_embedding_models(self) -> List[EmbeddingConfig]:
    """List the available embedding configurations.

    Currently this is always a one-element list holding
    ``settings.embedding_config``.

    TODO: support multiple models.
    """
    configured = settings.embedding_config
    return [configured]
|