letta-nightly 0.6.4.dev20241216104246__py3-none-any.whl → 0.6.5.dev20241218055539__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of letta-nightly might be problematic.
- letta/__init__.py +1 -1
- letta/agent.py +95 -101
- letta/client/client.py +1 -0
- letta/constants.py +6 -1
- letta/embeddings.py +3 -9
- letta/functions/function_sets/base.py +11 -57
- letta/functions/schema_generator.py +2 -6
- letta/llm_api/anthropic.py +38 -13
- letta/llm_api/llm_api_tools.py +12 -1
- letta/local_llm/function_parser.py +2 -2
- letta/orm/__init__.py +1 -1
- letta/orm/agent.py +19 -1
- letta/orm/errors.py +8 -0
- letta/orm/file.py +3 -2
- letta/orm/mixins.py +3 -14
- letta/orm/organization.py +19 -3
- letta/orm/passage.py +59 -23
- letta/orm/source.py +4 -0
- letta/orm/sqlalchemy_base.py +25 -18
- letta/prompts/system/memgpt_modified_chat.txt +1 -1
- letta/prompts/system/memgpt_modified_o1.txt +1 -1
- letta/providers.py +2 -0
- letta/schemas/agent.py +35 -0
- letta/schemas/embedding_config.py +20 -2
- letta/schemas/passage.py +1 -1
- letta/schemas/sandbox_config.py +2 -1
- letta/server/rest_api/app.py +43 -5
- letta/server/rest_api/routers/v1/tools.py +1 -1
- letta/server/rest_api/utils.py +24 -5
- letta/server/server.py +105 -164
- letta/server/ws_api/server.py +1 -1
- letta/services/agent_manager.py +344 -9
- letta/services/passage_manager.py +76 -100
- letta/services/tool_execution_sandbox.py +54 -45
- letta/settings.py +10 -5
- letta/utils.py +8 -0
- {letta_nightly-0.6.4.dev20241216104246.dist-info → letta_nightly-0.6.5.dev20241218055539.dist-info}/METADATA +6 -6
- {letta_nightly-0.6.4.dev20241216104246.dist-info → letta_nightly-0.6.5.dev20241218055539.dist-info}/RECORD +41 -41
- {letta_nightly-0.6.4.dev20241216104246.dist-info → letta_nightly-0.6.5.dev20241218055539.dist-info}/LICENSE +0 -0
- {letta_nightly-0.6.4.dev20241216104246.dist-info → letta_nightly-0.6.5.dev20241218055539.dist-info}/WHEEL +0 -0
- {letta_nightly-0.6.4.dev20241216104246.dist-info → letta_nightly-0.6.5.dev20241218055539.dist-info}/entry_points.txt +0 -0
letta/schemas/embedding_config.py
CHANGED

@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Literal, Optional

 from pydantic import BaseModel, Field

@@ -20,7 +20,25 @@ class EmbeddingConfig(BaseModel):

     """

-    embedding_endpoint_type:
+    embedding_endpoint_type: Literal[
+        "openai",
+        "anthropic",
+        "cohere",
+        "google_ai",
+        "azure",
+        "groq",
+        "ollama",
+        "webui",
+        "webui-legacy",
+        "lmstudio",
+        "lmstudio-legacy",
+        "llamacpp",
+        "koboldcpp",
+        "vllm",
+        "hugging-face",
+        "mistral",
+        "together",  # completions endpoint
+    ] = Field(..., description="The endpoint type for the model.")
     embedding_endpoint: Optional[str] = Field(None, description="The endpoint for the model (`None` if local).")
     embedding_model: str = Field(..., description="The model for the embedding.")
     embedding_dim: int = Field(..., description="The dimension of the embedding.")
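The hunk above narrows embedding_endpoint_type from a free-form value to a Literal of known backends, so Pydantic rejects unrecognized endpoint types at validation time instead of accepting any string. A minimal sketch of that behavior (not the packaged class; the Literal set is shortened here for brevity):

# Sketch only: a trimmed-down model showing how the Literal field behaves under Pydantic validation.
from typing import Literal, Optional

from pydantic import BaseModel, Field, ValidationError


class EmbeddingConfigSketch(BaseModel):
    # With Literal, Pydantic checks the value against the allowed set rather than accepting any string.
    embedding_endpoint_type: Literal["openai", "azure", "ollama"] = Field(..., description="The endpoint type for the model.")
    embedding_endpoint: Optional[str] = Field(None, description="The endpoint for the model (`None` if local).")
    embedding_model: str = Field(..., description="The model for the embedding.")
    embedding_dim: int = Field(..., description="The dimension of the embedding.")


EmbeddingConfigSketch(embedding_endpoint_type="openai", embedding_model="text-embedding-ada-002", embedding_dim=1536)

try:
    EmbeddingConfigSketch(embedding_endpoint_type="not-a-backend", embedding_model="x", embedding_dim=8)
except ValidationError as e:
    print(e)  # rejected: value is not one of the permitted literals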
letta/schemas/passage.py
CHANGED
letta/schemas/sandbox_config.py
CHANGED
@@ -1,7 +1,7 @@
 import hashlib
 import json
 from enum import Enum
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Literal, Optional, Union

 from pydantic import BaseModel, Field, model_validator

@@ -21,6 +21,7 @@ class SandboxRunResult(BaseModel):
     agent_state: Optional[AgentState] = Field(None, description="The agent state")
     stdout: Optional[List[str]] = Field(None, description="Captured stdout (e.g. prints, logs) from the function invocation")
     stderr: Optional[List[str]] = Field(None, description="Captured stderr from the function invocation")
+    status: Literal["success", "error"] = Field(..., description="The status of the tool execution and return object")
     sandbox_config_fingerprint: str = Field(None, description="The fingerprint of the config for the sandbox")

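With the new required status field, callers of the sandbox can branch on a typed "success"/"error" value rather than inferring failure from captured stderr. A hedged sketch of how a caller might consume it (summarize_run is an illustrative helper, not part of the package):

# Sketch only: `result` is any SandboxRunResult produced by the tool execution sandbox.
from letta.schemas.sandbox_config import SandboxRunResult


def summarize_run(result: SandboxRunResult) -> str:
    # status is a required Literal["success", "error"], so the check is exhaustive.
    if result.status == "success":
        return "tool ran cleanly"
    return "tool failed:\n" + "\n".join(result.stderr or [])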
letta/server/rest_api/app.py
CHANGED
@@ -14,6 +14,13 @@ from starlette.middleware.cors import CORSMiddleware
 from letta.__init__ import __version__
 from letta.constants import ADMIN_PREFIX, API_PREFIX, OPENAI_API_PREFIX
 from letta.errors import LettaAgentNotFoundError, LettaUserNotFoundError
+from letta.log import get_logger
+from letta.orm.errors import (
+    DatabaseTimeoutError,
+    ForeignKeyConstraintViolationError,
+    NoResultFound,
+    UniqueConstraintViolationError,
+)
 from letta.schemas.letta_response import LettaResponse
 from letta.server.constants import REST_DEFAULT_PORT

@@ -45,6 +52,7 @@ from letta.settings import settings
 # NOTE(charles): @ethan I had to add this to get the global as the bottom to work
 interface: StreamingServerInterface = StreamingServerInterface
 server = SyncServer(default_interface_factory=lambda: interface())
+logger = get_logger(__name__)

 # TODO: remove
 password = None
@@ -170,6 +178,41 @@ def create_application() -> "FastAPI":
         },
     )

+    @app.exception_handler(NoResultFound)
+    async def no_result_found_handler(request: Request, exc: NoResultFound):
+        logger.error(f"NoResultFound: {exc}")
+
+        return JSONResponse(
+            status_code=404,
+            content={"detail": str(exc)},
+        )
+
+    @app.exception_handler(ForeignKeyConstraintViolationError)
+    async def foreign_key_constraint_handler(request: Request, exc: ForeignKeyConstraintViolationError):
+        logger.error(f"ForeignKeyConstraintViolationError: {exc}")
+
+        return JSONResponse(
+            status_code=409,
+            content={"detail": str(exc)},
+        )
+
+    @app.exception_handler(UniqueConstraintViolationError)
+    async def unique_key_constraint_handler(request: Request, exc: UniqueConstraintViolationError):
+        logger.error(f"UniqueConstraintViolationError: {exc}")
+
+        return JSONResponse(
+            status_code=409,
+            content={"detail": str(exc)},
+        )
+
+    @app.exception_handler(DatabaseTimeoutError)
+    async def database_timeout_error_handler(request: Request, exc: DatabaseTimeoutError):
+        logger.error(f"Timeout occurred: {exc}. Original exception: {exc.original_exception}")
+        return JSONResponse(
+            status_code=503,
+            content={"detail": "The database is temporarily unavailable. Please try again later."},
+        )
+
     @app.exception_handler(ValueError)
     async def value_error_handler(request: Request, exc: ValueError):
         return JSONResponse(status_code=400, content={"detail": str(exc)})
@@ -222,11 +265,6 @@ def create_application() -> "FastAPI":

     @app.on_event("startup")
     def on_startup():
-        # load the default tools
-        # from letta.orm.tool import Tool
-
-        # Tool.load_default_tools(get_db_session())
-
         generate_openapi_schema(app)

     @app.on_event("shutdown")
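The handlers added above all follow one pattern: register a per-exception handler inside create_application() and translate each ORM error into a specific HTTP status (404 for NoResultFound, 409 for constraint violations, 503 for database timeouts) instead of letting it surface as a generic 500. A minimal, self-contained sketch of that pattern with a stand-in exception class (NotFoundInStore is illustrative, not a letta type):

# Sketch only: the same FastAPI exception-handler pattern with a hypothetical exception class.
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse


class NotFoundInStore(Exception):
    """Stand-in for an ORM 'no result' error."""


app = FastAPI()


@app.exception_handler(NotFoundInStore)
async def not_found_handler(request: Request, exc: NotFoundInStore):
    # Convert the internal exception into a clean HTTP error instead of a 500.
    return JSONResponse(status_code=404, content={"detail": str(exc)})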
letta/server/rest_api/utils.py
CHANGED
@@ -1,5 +1,6 @@
 import asyncio
 import json
+import os
 import warnings
 from enum import Enum
 from typing import AsyncGenerator, Optional, Union
@@ -64,13 +65,31 @@ async def sse_async_generator(
         import traceback

         traceback.print_exc()
-        warnings.warn(f"
-
+        warnings.warn(f"SSE stream generator failed: {e}")
+
+        # Log the error, since the exception handler upstack (in FastAPI) won't catch it, because this may be a 200 response
+        # Print the stack trace
+        if (os.getenv("SENTRY_DSN") is not None) and (os.getenv("SENTRY_DSN") != ""):
+            import sentry_sdk
+
+            sentry_sdk.capture_exception(e)
+
+        yield sse_formatter({"error": f"Stream failed (internal error occured)"})

     except Exception as e:
-
-
-
+        import traceback
+
+        traceback.print_exc()
+        warnings.warn(f"SSE stream generator failed: {e}")
+
+        # Log the error, since the exception handler upstack (in FastAPI) won't catch it, because this may be a 200 response
+        # Print the stack trace
+        if (os.getenv("SENTRY_DSN") is not None) and (os.getenv("SENTRY_DSN") != ""):
+            import sentry_sdk
+
+            sentry_sdk.capture_exception(e)
+
+        yield sse_formatter({"error": "Stream failed (decoder encountered an error)"})

     finally:
         if finish_message:
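Both except branches now log the failure, optionally forward it to Sentry when SENTRY_DSN is set, and emit a final SSE error event so the client learns the stream broke even though the HTTP status was already 200. A hedged sketch of the guarded Sentry call used in both branches (the helper name is illustrative):

# Sketch only: report an exception to Sentry only when a DSN is configured,
# keeping the sentry_sdk dependency optional at runtime.
import os


def report_if_sentry_enabled(exc: Exception) -> None:
    if os.getenv("SENTRY_DSN"):
        import sentry_sdk

        sentry_sdk.capture_exception(exc)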
letta/server/server.py
CHANGED
@@ -4,7 +4,6 @@ import os
 import traceback
 import warnings
 from abc import abstractmethod
-from asyncio import Lock
 from datetime import datetime
 from typing import Callable, List, Optional, Tuple, Union

@@ -19,7 +18,6 @@ from letta.agent import Agent, save_agent
 from letta.chat_only_agent import ChatOnlyAgent
 from letta.credentials import LettaCredentials
 from letta.data_sources.connectors import DataConnector, load_data
-from letta.errors import LettaAgentNotFoundError

 # TODO use custom interface
 from letta.interface import AgentInterface  # abstract
@@ -76,7 +74,7 @@ from letta.services.source_manager import SourceManager
 from letta.services.tool_execution_sandbox import ToolExecutionSandbox
 from letta.services.tool_manager import ToolManager
 from letta.services.user_manager import UserManager
-from letta.utils import get_utc_time, json_dumps, json_loads
+from letta.utils import get_friendly_error_msg, get_utc_time, json_dumps, json_loads

 logger = get_logger(__name__)

@@ -192,7 +190,14 @@ if settings.letta_pg_uri_no_default:
     config.archival_storage_uri = settings.letta_pg_uri_no_default

     # create engine
-    engine = create_engine(
+    engine = create_engine(
+        settings.letta_pg_uri,
+        pool_size=settings.pg_pool_size,
+        max_overflow=settings.pg_max_overflow,
+        pool_timeout=settings.pg_pool_timeout,
+        pool_recycle=settings.pg_pool_recycle,
+        echo=settings.pg_echo,
+    )
 else:
     # TODO: don't rely on config storage
     engine = create_engine("sqlite:///" + os.path.join(config.recall_storage_path, "sqlite.db"))
@@ -266,9 +271,6 @@ class SyncServer(Server):

         self.credentials = LettaCredentials.load()

-        # Locks
-        self.send_message_lock = Lock()
-
         # Initialize the metadata store
         config = LettaConfig.load()
         if settings.letta_pg_uri_no_default:
@@ -399,9 +401,6 @@ class SyncServer(Server):
         with agent_lock:
             agent_state = self.agent_manager.get_agent_by_id(agent_id=agent_id, actor=actor)

-            if agent_state is None:
-                raise LettaAgentNotFoundError(f"Agent (agent_id={agent_id}) does not exist")
-
             interface = interface or self.default_interface_factory()
             if agent_state.agent_type == AgentType.memgpt_agent:
                 agent = Agent(agent_state=agent_state, interface=interface, user=actor)
@@ -777,6 +776,18 @@ class SyncServer(Server):
         # interface
         interface: Union[AgentInterface, None] = None,
     ) -> AgentState:
+        if request.llm_config is None:
+            if request.llm is None:
+                raise ValueError("Must specify either llm or llm_config in request")
+            request.llm_config = self.get_llm_config_from_handle(handle=request.llm, context_window_limit=request.context_window_limit)
+
+        if request.embedding_config is None:
+            if request.embedding is None:
+                raise ValueError("Must specify either embedding or embedding_config in request")
+            request.embedding_config = self.get_embedding_config_from_handle(
+                handle=request.embedding, embedding_chunk_size=request.embedding_chunk_size or constants.DEFAULT_EMBEDDING_CHUNK_SIZE
+            )
+
         """Create a new agent using a config"""
         # Invoke manager
         agent_state = self.agent_manager.create_agent(
@@ -824,13 +835,13 @@ class SyncServer(Server):
         actor: User,
     ) -> AgentState:
         """Update the agents core memory block, return the new state"""
+        # Update agent state in the db first
+        agent_state = self.agent_manager.update_agent(agent_id=agent_id, agent_update=request, actor=actor)
+
         # Get the agent object (loaded in memory)
         letta_agent = self.load_agent(agent_id=agent_id, actor=actor)

-        #
-        if request.tags is not None:  # Allow for empty list
-            letta_agent.agent_state.tags = request.tags
-
+        # TODO: Everything below needs to get removed, no updating anything in memory
         # update the system prompt
         if request.system:
             letta_agent.update_system_prompt(request.system)
@@ -842,45 +853,9 @@ class SyncServer(Server):
             # then (2) setting the attributes ._messages and .state.message_ids
             letta_agent.set_message_buffer(message_ids=request.message_ids)

-
-
-
-
-        # (1) get tools + make sure they exist
-        # Current and target tools as sets of tool names
-        current_tools = letta_agent.agent_state.tools
-        current_tool_ids = set([t.id for t in current_tools])
-        target_tool_ids = set(request.tool_ids)
-
-        # Calculate tools to add and remove
-        tool_ids_to_add = target_tool_ids - current_tool_ids
-        tools_ids_to_remove = current_tool_ids - target_tool_ids
-
-        # update agent tool list
-        for tool_id in tools_ids_to_remove:
-            self.remove_tool_from_agent(agent_id=agent_id, tool_id=tool_id, user_id=actor.id)
-        for tool_id in tool_ids_to_add:
-            self.add_tool_to_agent(agent_id=agent_id, tool_id=tool_id, user_id=actor.id)
-
-        # reload agent
-        letta_agent = self.load_agent(agent_id=agent_id, actor=actor)
-
-        # configs
-        if request.llm_config:
-            letta_agent.agent_state.llm_config = request.llm_config
-        if request.embedding_config:
-            letta_agent.agent_state.embedding_config = request.embedding_config
-
-        # other minor updates
-        if request.name:
-            letta_agent.agent_state.name = request.name
-        if request.metadata_:
-            letta_agent.agent_state.metadata_ = request.metadata_
-
-        # save the agent
-        save_agent(letta_agent)
-        # TODO: probably reload the agent somehow?
-        return letta_agent.agent_state
+        letta_agent.update_state()
+
+        return agent_state

     def get_tools_from_agent(self, agent_id: str, user_id: Optional[str]) -> List[Tool]:
         """Get tools from an existing agent"""
@@ -901,32 +876,9 @@ class SyncServer(Server):
         # TODO: Thread actor directly through this function, since the top level caller most likely already retrieved the user
         actor = self.user_manager.get_user_or_default(user_id=user_id)

-
-        letta_agent = self.load_agent(agent_id=agent_id, actor=actor)
-
-        # Get all the tool objects from the request
-        tool_objs = []
-        tool_obj = self.tool_manager.get_tool_by_id(tool_id=tool_id, actor=actor)
-        assert tool_obj, f"Tool with id={tool_id} does not exist"
-        tool_objs.append(tool_obj)
-
-        for tool in letta_agent.agent_state.tools:
-            tool_obj = self.tool_manager.get_tool_by_id(tool_id=tool.id, actor=actor)
-            assert tool_obj, f"Tool with id={tool.id} does not exist"
+        agent_state = self.agent_manager.attach_tool(agent_id=agent_id, tool_id=tool_id, actor=actor)

-
-            if tool_obj.id != tool_id:
-                tool_objs.append(tool_obj)
-
-        # replace the list of tool names ("ids") inside the agent state
-        letta_agent.agent_state.tools = tool_objs
-
-        # then attempt to link the tools modules
-        letta_agent.link_tools(tool_objs)
-
-        # save the agent
-        save_agent(letta_agent)
-        return letta_agent.agent_state
+        return agent_state

     def remove_tool_from_agent(
         self,
@@ -937,29 +889,9 @@ class SyncServer(Server):
         """Remove tools from an existing agent"""
         # TODO: Thread actor directly through this function, since the top level caller most likely already retrieved the user
         actor = self.user_manager.get_user_or_default(user_id=user_id)
+        agent_state = self.agent_manager.detach_tool(agent_id=agent_id, tool_id=tool_id, actor=actor)

-
-        letta_agent = self.load_agent(agent_id=agent_id, actor=actor)
-
-        # Get all the tool_objs
-        tool_objs = []
-        for tool in letta_agent.agent_state.tools:
-            tool_obj = self.tool_manager.get_tool_by_id(tool_id=tool.id, actor=actor)
-            assert tool_obj, f"Tool with id={tool.id} does not exist"
-
-            # If it's not the tool we want to remove
-            if tool_obj.id != tool_id:
-                tool_objs.append(tool_obj)
-
-        # replace the list of tool names ("ids") inside the agent state
-        letta_agent.agent_state.tools = tool_objs
-
-        # then attempt to link the tools modules
-        letta_agent.link_tools(tool_objs)
-
-        # save the agent
-        save_agent(letta_agent)
-        return letta_agent.agent_state
+        return agent_state

        # convert name->id

@@ -970,7 +902,7 @@ class SyncServer(Server):

     def get_archival_memory_summary(self, agent_id: str, actor: User) -> ArchivalMemorySummary:
         agent = self.load_agent(agent_id=agent_id, actor=actor)
-        return ArchivalMemorySummary(size=
+        return ArchivalMemorySummary(size=self.agent_manager.passage_size(actor=actor, agent_id=agent_id))

     def get_recall_memory_summary(self, agent_id: str, actor: User) -> RecallMemorySummary:
         agent = self.load_agent(agent_id=agent_id, actor=actor)
@@ -987,18 +919,9 @@ class SyncServer(Server):
         # TODO: Thread actor directly through this function, since the top level caller most likely already retrieved the user
         actor = self.user_manager.get_user_or_default(user_id=user_id)

-
-        letta_agent = self.load_agent(agent_id=agent_id, actor=actor)
-
-        # iterate over records
-        records = letta_agent.passage_manager.list_passages(
-            actor=actor,
-            agent_id=agent_id,
-            cursor=cursor,
-            limit=limit,
-        )
+        passages = self.agent_manager.list_passages(agent_id=agent_id, actor=actor)

-        return
+        return passages

     def get_agent_archival_cursor(
         self,
@@ -1012,15 +935,13 @@ class SyncServer(Server):
         # TODO: Thread actor directly through this function, since the top level caller most likely already retrieved the user
         actor = self.user_manager.get_user_or_default(user_id=user_id)

-        # Get the agent object (loaded in memory)
-        letta_agent = self.load_agent(agent_id=agent_id, actor=actor)
-
         # iterate over records
-        records =
-            actor=
+        records = self.agent_manager.list_passages(
+            actor=actor,
             agent_id=agent_id,
             cursor=cursor,
             limit=limit,
+            ascending=not reverse,
         )
         return records

@@ -1105,7 +1026,7 @@ class SyncServer(Server):
                     config_copy[k] = server_utils.shorten_key_middle(v, chars_each_side=5)
             return config_copy

-        # TODO: do we need a
+        # TODO: do we need a separate server config?
         base_config = vars(self.config)
         clean_base_config = clean_keys(base_config)

@@ -1136,7 +1057,8 @@ class SyncServer(Server):
         self.source_manager.delete_source(source_id=source_id, actor=actor)

         # delete data from passage store
-        self.
+        passages_to_be_deleted = self.agent_manager.list_passages(actor=actor, source_id=source_id, limit=None)
+        self.passage_manager.delete_passages(actor=actor, passages=passages_to_be_deleted)

         # TODO: delete data from agent passage stores (?)

@@ -1167,9 +1089,11 @@ class SyncServer(Server):
         for agent_state in agent_states:
             agent_id = agent_state.id
             agent = self.load_agent(agent_id=agent_id, actor=actor)
-
+
+            # Attach source to agent
+            curr_passage_size = self.agent_manager.passage_size(actor=actor, agent_id=agent_id)
             agent.attach_source(user=actor, source_id=source_id, source_manager=self.source_manager, agent_manager=self.agent_manager)
-            new_passage_size = self.
+            new_passage_size = self.agent_manager.passage_size(actor=actor, agent_id=agent_id)
             assert new_passage_size >= curr_passage_size  # in case empty files are added

         return job
@@ -1233,14 +1157,9 @@ class SyncServer(Server):
             source = self.source_manager.get_source_by_id(source_id=source_id, actor=actor)
         elif source_name:
             source = self.source_manager.get_source_by_name(source_name=source_name, actor=actor)
+            source_id = source.id
         else:
             raise ValueError(f"Need to provide at least source_id or source_name to find the source.")
-        source_id = source.id
-
-        # TODO: This should be done with the ORM?
-        # delete all Passage objects with source_id==source_id from agent's archival memory
-        agent = self.load_agent(agent_id=agent_id, actor=actor)
-        agent.passage_manager.delete_passages(actor=actor, limit=100, source_id=source_id)

         # delete agent-source mapping
         self.agent_manager.detach_source(agent_id=agent_id, source_id=source_id, actor=actor)
@@ -1262,7 +1181,7 @@ class SyncServer(Server):
         for source in sources:

             # count number of passages
-            num_passages = self.
+            num_passages = self.agent_manager.passage_size(actor=actor, source_id=source.id)

             # TODO: add when files table implemented
             ## count number of files
@@ -1363,6 +1282,55 @@ class SyncServer(Server):
                 warnings.warn(f"An error occurred while listing embedding models for provider {provider}: {e}")
         return embedding_models

+    def get_llm_config_from_handle(self, handle: str, context_window_limit: Optional[int] = None) -> LLMConfig:
+        provider_name, model_name = handle.split("/", 1)
+        provider = self.get_provider_from_name(provider_name)
+
+        llm_configs = [config for config in provider.list_llm_models() if config.model == model_name]
+        if not llm_configs:
+            raise ValueError(f"LLM model {model_name} is not supported by {provider_name}")
+        elif len(llm_configs) > 1:
+            raise ValueError(f"Multiple LLM models with name {model_name} supported by {provider_name}")
+        else:
+            llm_config = llm_configs[0]
+
+        if context_window_limit:
+            if context_window_limit > llm_config.context_window:
+                raise ValueError(f"Context window limit ({context_window_limit}) is greater than maximum of ({llm_config.context_window})")
+            llm_config.context_window = context_window_limit
+
+        return llm_config
+
+    def get_embedding_config_from_handle(
+        self, handle: str, embedding_chunk_size: int = constants.DEFAULT_EMBEDDING_CHUNK_SIZE
+    ) -> EmbeddingConfig:
+        provider_name, model_name = handle.split("/", 1)
+        provider = self.get_provider_from_name(provider_name)
+
+        embedding_configs = [config for config in provider.list_embedding_models() if config.embedding_model == model_name]
+        if not embedding_configs:
+            raise ValueError(f"Embedding model {model_name} is not supported by {provider_name}")
+        elif len(embedding_configs) > 1:
+            raise ValueError(f"Multiple embedding models with name {model_name} supported by {provider_name}")
+        else:
+            embedding_config = embedding_configs[0]
+
+        if embedding_chunk_size:
+            embedding_config.embedding_chunk_size = embedding_chunk_size
+
+        return embedding_config
+
+    def get_provider_from_name(self, provider_name: str) -> Provider:
+        providers = [provider for provider in self._enabled_providers if provider.name == provider_name]
+        if not providers:
+            raise ValueError(f"Provider {provider_name} is not supported")
+        elif len(providers) > 1:
+            raise ValueError(f"Multiple providers with name {provider_name} supported")
+        else:
+            provider = providers[0]
+
+        return provider
+
     def add_llm_model(self, request: LLMConfig) -> LLMConfig:
         """Add a new LLM model"""

@@ -1383,7 +1351,7 @@ class SyncServer(Server):

     def run_tool_from_source(
         self,
-
+        actor: User,
         tool_args: str,
         tool_source: str,
         tool_source_type: Optional[str] = None,
@@ -1411,56 +1379,29 @@ class SyncServer(Server):

         # Next, attempt to run the tool with the sandbox
         try:
-            sandbox_run_result = ToolExecutionSandbox(tool.name, tool_args_dict,
-            function_response = str(sandbox_run_result.func_return)
-            stdout = [s for s in sandbox_run_result.stdout if s.strip()]
-            stderr = [s for s in sandbox_run_result.stderr if s.strip()]
-
-            # expected error
-            if stderr:
-                error_msg = self.get_error_msg_for_func_return(tool.name, stderr[-1])
-                return FunctionReturn(
-                    id="null",
-                    function_call_id="null",
-                    date=get_utc_time(),
-                    status="error",
-                    function_return=error_msg,
-                    stdout=stdout,
-                    stderr=stderr,
-                )
-
+            sandbox_run_result = ToolExecutionSandbox(tool.name, tool_args_dict, actor, tool_object=tool).run(agent_state=agent_state)
             return FunctionReturn(
                 id="null",
                 function_call_id="null",
                 date=get_utc_time(),
-                status=
-                function_return=
-                stdout=stdout,
-                stderr=stderr,
+                status=sandbox_run_result.status,
+                function_return=str(sandbox_run_result.func_return),
+                stdout=sandbox_run_result.stdout,
+                stderr=sandbox_run_result.stderr,
             )

-        # unexpected error TODO(@cthomas): consolidate error handling
         except Exception as e:
-
+            func_return = get_friendly_error_msg(function_name=tool.name, exception_name=type(e).__name__, exception_message=str(e))
             return FunctionReturn(
                 id="null",
                 function_call_id="null",
                 date=get_utc_time(),
                 status="error",
-                function_return=
-                stdout=[
+                function_return=func_return,
+                stdout=[],
                 stderr=[traceback.format_exc()],
             )

-    def get_error_msg_for_func_return(self, tool_name, exception_message):
-        # same as agent.py
-        from letta.constants import MAX_ERROR_MESSAGE_CHAR_LIMIT
-
-        error_msg = f"Error executing tool {tool_name}: {exception_message}"
-        if len(error_msg) > MAX_ERROR_MESSAGE_CHAR_LIMIT:
-            error_msg = error_msg[:MAX_ERROR_MESSAGE_CHAR_LIMIT]
-        return error_msg
-
     # Composio wrappers
     def get_composio_client(self, api_key: Optional[str] = None):
         if api_key:
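The new get_llm_config_from_handle and get_embedding_config_from_handle resolve a handle of the form "provider/model" against the enabled providers, and context_window_limit may only shrink, never exceed, the model's advertised context window. A small sketch of the handle convention as implied by the split("/", 1) call above (handle values here are illustrative, not letta defaults):

# Sketch only: the "provider/model" handle convention.
handle = "openai/gpt-4o-mini"
provider_name, model_name = handle.split("/", 1)
assert provider_name == "openai" and model_name == "gpt-4o-mini"

# Splitting on the first "/" means model names may themselves contain slashes.
handle = "together/meta-llama/Llama-3-70b"
provider_name, model_name = handle.split("/", 1)
assert model_name == "meta-llama/Llama-3-70b"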
letta/server/ws_api/server.py
CHANGED
@@ -33,7 +33,7 @@ class WebSocketServer:
         self.initialize_server()
         # Can play with ping_interval and ping_timeout
         # See: https://websockets.readthedocs.io/en/stable/topics/timeouts.html
-        # and https://github.com/
+        # and https://github.com/letta-ai/letta/issues/471
         async with websockets.serve(self.handle_client, self.host, self.port):
             await asyncio.Future()  # Run forever

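For reference, the comment in this hunk refers to the keepalive knobs of websockets.serve. A hedged sketch of passing them explicitly (the values mirror the websockets library's documented defaults, not letta settings):

# Sketch only: explicit keepalive settings for a websockets server.
import asyncio

import websockets


async def run_ws_server(handler, host: str, port: int):
    # ping_interval/ping_timeout control the keepalive pings mentioned in the comment above.
    async with websockets.serve(handler, host, port, ping_interval=20, ping_timeout=20):
        await asyncio.Future()  # run forever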