letta-nightly 0.6.4.dev20241217104233__py3-none-any.whl → 0.6.5.dev20241218055539__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of letta-nightly might be problematic. Click here for more details.
- letta/__init__.py +1 -1
- letta/agent.py +68 -65
- letta/client/client.py +1 -0
- letta/constants.py +6 -1
- letta/embeddings.py +3 -9
- letta/functions/function_sets/base.py +9 -57
- letta/functions/schema_generator.py +1 -1
- letta/llm_api/anthropic.py +38 -13
- letta/llm_api/llm_api_tools.py +12 -1
- letta/local_llm/function_parser.py +1 -1
- letta/orm/errors.py +8 -0
- letta/orm/sqlalchemy_base.py +24 -17
- letta/providers.py +2 -0
- letta/schemas/agent.py +35 -0
- letta/schemas/sandbox_config.py +2 -1
- letta/server/rest_api/app.py +32 -7
- letta/server/rest_api/routers/v1/tools.py +1 -1
- letta/server/server.py +81 -57
- letta/services/agent_manager.py +3 -0
- letta/services/tool_execution_sandbox.py +54 -45
- letta/settings.py +9 -4
- letta/utils.py +8 -0
- {letta_nightly-0.6.4.dev20241217104233.dist-info → letta_nightly-0.6.5.dev20241218055539.dist-info}/METADATA +1 -1
- {letta_nightly-0.6.4.dev20241217104233.dist-info → letta_nightly-0.6.5.dev20241218055539.dist-info}/RECORD +27 -27
- {letta_nightly-0.6.4.dev20241217104233.dist-info → letta_nightly-0.6.5.dev20241218055539.dist-info}/LICENSE +0 -0
- {letta_nightly-0.6.4.dev20241217104233.dist-info → letta_nightly-0.6.5.dev20241218055539.dist-info}/WHEEL +0 -0
- {letta_nightly-0.6.4.dev20241217104233.dist-info → letta_nightly-0.6.5.dev20241218055539.dist-info}/entry_points.txt +0 -0
letta/orm/sqlalchemy_base.py
CHANGED
|
@@ -1,14 +1,16 @@
|
|
|
1
1
|
from datetime import datetime
|
|
2
2
|
from enum import Enum
|
|
3
|
+
from functools import wraps
|
|
3
4
|
from typing import TYPE_CHECKING, List, Literal, Optional
|
|
4
5
|
|
|
5
6
|
from sqlalchemy import String, desc, func, or_, select
|
|
6
|
-
from sqlalchemy.exc import DBAPIError, IntegrityError
|
|
7
|
+
from sqlalchemy.exc import DBAPIError, IntegrityError, TimeoutError
|
|
7
8
|
from sqlalchemy.orm import Mapped, Session, mapped_column
|
|
8
9
|
|
|
9
10
|
from letta.log import get_logger
|
|
10
11
|
from letta.orm.base import Base, CommonSqlalchemyMetaMixins
|
|
11
12
|
from letta.orm.errors import (
|
|
13
|
+
DatabaseTimeoutError,
|
|
12
14
|
ForeignKeyConstraintViolationError,
|
|
13
15
|
NoResultFound,
|
|
14
16
|
UniqueConstraintViolationError,
|
|
@@ -23,6 +25,20 @@ if TYPE_CHECKING:
|
|
|
23
25
|
logger = get_logger(__name__)
|
|
24
26
|
|
|
25
27
|
|
|
28
|
+
def handle_db_timeout(func):
|
|
29
|
+
"""Decorator to handle SQLAlchemy TimeoutError and wrap it in a custom exception."""
|
|
30
|
+
|
|
31
|
+
@wraps(func)
|
|
32
|
+
def wrapper(*args, **kwargs):
|
|
33
|
+
try:
|
|
34
|
+
return func(*args, **kwargs)
|
|
35
|
+
except TimeoutError as e:
|
|
36
|
+
logger.error(f"Timeout while executing {func.__name__} with args {args} and kwargs {kwargs}: {e}")
|
|
37
|
+
raise DatabaseTimeoutError(message=f"Timeout occurred in {func.__name__}.", original_exception=e)
|
|
38
|
+
|
|
39
|
+
return wrapper
|
|
40
|
+
|
|
41
|
+
|
|
26
42
|
class AccessType(str, Enum):
|
|
27
43
|
ORGANIZATION = "organization"
|
|
28
44
|
USER = "user"
|
|
@@ -36,22 +52,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
|
|
36
52
|
id: Mapped[str] = mapped_column(String, primary_key=True)
|
|
37
53
|
|
|
38
54
|
@classmethod
|
|
39
|
-
|
|
40
|
-
"""Get a record by ID.
|
|
41
|
-
|
|
42
|
-
Args:
|
|
43
|
-
db_session: SQLAlchemy session
|
|
44
|
-
id: Record ID to retrieve
|
|
45
|
-
|
|
46
|
-
Returns:
|
|
47
|
-
Optional[SqlalchemyBase]: The record if found, None otherwise
|
|
48
|
-
"""
|
|
49
|
-
try:
|
|
50
|
-
return db_session.query(cls).filter(cls.id == id).first()
|
|
51
|
-
except DBAPIError:
|
|
52
|
-
return None
|
|
53
|
-
|
|
54
|
-
@classmethod
|
|
55
|
+
@handle_db_timeout
|
|
55
56
|
def list(
|
|
56
57
|
cls,
|
|
57
58
|
*,
|
|
@@ -180,6 +181,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
|
|
180
181
|
return list(session.execute(query).scalars())
|
|
181
182
|
|
|
182
183
|
@classmethod
|
|
184
|
+
@handle_db_timeout
|
|
183
185
|
def read(
|
|
184
186
|
cls,
|
|
185
187
|
db_session: "Session",
|
|
@@ -231,6 +233,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
|
|
231
233
|
conditions_str = ", ".join(query_conditions) if query_conditions else "no specific conditions"
|
|
232
234
|
raise NoResultFound(f"{cls.__name__} not found with {conditions_str}")
|
|
233
235
|
|
|
236
|
+
@handle_db_timeout
|
|
234
237
|
def create(self, db_session: "Session", actor: Optional["User"] = None) -> "SqlalchemyBase":
|
|
235
238
|
logger.debug(f"Creating {self.__class__.__name__} with ID: {self.id} with actor={actor}")
|
|
236
239
|
|
|
@@ -245,6 +248,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
|
|
245
248
|
except (DBAPIError, IntegrityError) as e:
|
|
246
249
|
self._handle_dbapi_error(e)
|
|
247
250
|
|
|
251
|
+
@handle_db_timeout
|
|
248
252
|
def delete(self, db_session: "Session", actor: Optional["User"] = None) -> "SqlalchemyBase":
|
|
249
253
|
logger.debug(f"Soft deleting {self.__class__.__name__} with ID: {self.id} with actor={actor}")
|
|
250
254
|
|
|
@@ -254,6 +258,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
|
|
254
258
|
self.is_deleted = True
|
|
255
259
|
return self.update(db_session)
|
|
256
260
|
|
|
261
|
+
@handle_db_timeout
|
|
257
262
|
def hard_delete(self, db_session: "Session", actor: Optional["User"] = None) -> None:
|
|
258
263
|
"""Permanently removes the record from the database."""
|
|
259
264
|
logger.debug(f"Hard deleting {self.__class__.__name__} with ID: {self.id} with actor={actor}")
|
|
@@ -269,6 +274,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
|
|
269
274
|
else:
|
|
270
275
|
logger.debug(f"{self.__class__.__name__} with ID {self.id} successfully hard deleted")
|
|
271
276
|
|
|
277
|
+
@handle_db_timeout
|
|
272
278
|
def update(self, db_session: "Session", actor: Optional["User"] = None) -> "SqlalchemyBase":
|
|
273
279
|
logger.debug(f"Updating {self.__class__.__name__} with ID: {self.id} with actor={actor}")
|
|
274
280
|
if actor:
|
|
@@ -281,6 +287,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
|
|
281
287
|
return self
|
|
282
288
|
|
|
283
289
|
@classmethod
|
|
290
|
+
@handle_db_timeout
|
|
284
291
|
def size(
|
|
285
292
|
cls,
|
|
286
293
|
*,
|
letta/providers.py
CHANGED
|
@@ -13,6 +13,7 @@ from letta.schemas.llm_config import LLMConfig
|
|
|
13
13
|
|
|
14
14
|
|
|
15
15
|
class Provider(BaseModel):
|
|
16
|
+
name: str = Field(..., description="The name of the provider")
|
|
16
17
|
|
|
17
18
|
def list_llm_models(self) -> List[LLMConfig]:
|
|
18
19
|
return []
|
|
@@ -465,6 +466,7 @@ class TogetherProvider(OpenAIProvider):
|
|
|
465
466
|
|
|
466
467
|
class GoogleAIProvider(Provider):
|
|
467
468
|
# gemini
|
|
469
|
+
name: str = "google_ai"
|
|
468
470
|
api_key: str = Field(..., description="API key for the Google AI API.")
|
|
469
471
|
base_url: str = "https://generativelanguage.googleapis.com"
|
|
470
472
|
|
letta/schemas/agent.py
CHANGED
|
@@ -3,6 +3,7 @@ from typing import Dict, List, Optional
|
|
|
3
3
|
|
|
4
4
|
from pydantic import BaseModel, Field, field_validator
|
|
5
5
|
|
|
6
|
+
from letta.constants import DEFAULT_EMBEDDING_CHUNK_SIZE
|
|
6
7
|
from letta.schemas.block import CreateBlock
|
|
7
8
|
from letta.schemas.embedding_config import EmbeddingConfig
|
|
8
9
|
from letta.schemas.letta_base import OrmMetadataBase
|
|
@@ -107,6 +108,16 @@ class CreateAgent(BaseModel, validate_assignment=True): #
|
|
|
107
108
|
include_base_tools: bool = Field(True, description="The LLM configuration used by the agent.")
|
|
108
109
|
description: Optional[str] = Field(None, description="The description of the agent.")
|
|
109
110
|
metadata_: Optional[Dict] = Field(None, description="The metadata of the agent.", alias="metadata_")
|
|
111
|
+
llm: Optional[str] = Field(
|
|
112
|
+
None,
|
|
113
|
+
description="The LLM configuration handle used by the agent, specified in the format "
|
|
114
|
+
"provider/model-name, as an alternative to specifying llm_config.",
|
|
115
|
+
)
|
|
116
|
+
embedding: Optional[str] = Field(
|
|
117
|
+
None, description="The embedding configuration handle used by the agent, specified in the format provider/model-name."
|
|
118
|
+
)
|
|
119
|
+
context_window_limit: Optional[int] = Field(None, description="The context window limit used by the agent.")
|
|
120
|
+
embedding_chunk_size: Optional[int] = Field(DEFAULT_EMBEDDING_CHUNK_SIZE, description="The embedding chunk size used by the agent.")
|
|
110
121
|
|
|
111
122
|
@field_validator("name")
|
|
112
123
|
@classmethod
|
|
@@ -133,6 +144,30 @@ class CreateAgent(BaseModel, validate_assignment=True): #
|
|
|
133
144
|
|
|
134
145
|
return name
|
|
135
146
|
|
|
147
|
+
@field_validator("llm")
|
|
148
|
+
@classmethod
|
|
149
|
+
def validate_llm(cls, llm: Optional[str]) -> Optional[str]:
|
|
150
|
+
if not llm:
|
|
151
|
+
return llm
|
|
152
|
+
|
|
153
|
+
provider_name, model_name = llm.split("/", 1)
|
|
154
|
+
if not provider_name or not model_name:
|
|
155
|
+
raise ValueError("The llm config handle should be in the format provider/model-name")
|
|
156
|
+
|
|
157
|
+
return llm
|
|
158
|
+
|
|
159
|
+
@field_validator("embedding")
|
|
160
|
+
@classmethod
|
|
161
|
+
def validate_embedding(cls, embedding: Optional[str]) -> Optional[str]:
|
|
162
|
+
if not embedding:
|
|
163
|
+
return embedding
|
|
164
|
+
|
|
165
|
+
provider_name, model_name = embedding.split("/", 1)
|
|
166
|
+
if not provider_name or not model_name:
|
|
167
|
+
raise ValueError("The embedding config handle should be in the format provider/model-name")
|
|
168
|
+
|
|
169
|
+
return embedding
|
|
170
|
+
|
|
136
171
|
|
|
137
172
|
class UpdateAgent(BaseModel):
|
|
138
173
|
name: Optional[str] = Field(None, description="The name of the agent.")
|
letta/schemas/sandbox_config.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
import hashlib
|
|
2
2
|
import json
|
|
3
3
|
from enum import Enum
|
|
4
|
-
from typing import Any, Dict, List, Optional, Union
|
|
4
|
+
from typing import Any, Dict, List, Literal, Optional, Union
|
|
5
5
|
|
|
6
6
|
from pydantic import BaseModel, Field, model_validator
|
|
7
7
|
|
|
@@ -21,6 +21,7 @@ class SandboxRunResult(BaseModel):
|
|
|
21
21
|
agent_state: Optional[AgentState] = Field(None, description="The agent state")
|
|
22
22
|
stdout: Optional[List[str]] = Field(None, description="Captured stdout (e.g. prints, logs) from the function invocation")
|
|
23
23
|
stderr: Optional[List[str]] = Field(None, description="Captured stderr from the function invocation")
|
|
24
|
+
status: Literal["success", "error"] = Field(..., description="The status of the tool execution and return object")
|
|
24
25
|
sandbox_config_fingerprint: str = Field(None, description="The fingerprint of the config for the sandbox")
|
|
25
26
|
|
|
26
27
|
|
letta/server/rest_api/app.py
CHANGED
|
@@ -15,7 +15,12 @@ from letta.__init__ import __version__
|
|
|
15
15
|
from letta.constants import ADMIN_PREFIX, API_PREFIX, OPENAI_API_PREFIX
|
|
16
16
|
from letta.errors import LettaAgentNotFoundError, LettaUserNotFoundError
|
|
17
17
|
from letta.log import get_logger
|
|
18
|
-
from letta.orm.errors import
|
|
18
|
+
from letta.orm.errors import (
|
|
19
|
+
DatabaseTimeoutError,
|
|
20
|
+
ForeignKeyConstraintViolationError,
|
|
21
|
+
NoResultFound,
|
|
22
|
+
UniqueConstraintViolationError,
|
|
23
|
+
)
|
|
19
24
|
from letta.schemas.letta_response import LettaResponse
|
|
20
25
|
from letta.server.constants import REST_DEFAULT_PORT
|
|
21
26
|
|
|
@@ -175,7 +180,6 @@ def create_application() -> "FastAPI":
|
|
|
175
180
|
|
|
176
181
|
@app.exception_handler(NoResultFound)
|
|
177
182
|
async def no_result_found_handler(request: Request, exc: NoResultFound):
|
|
178
|
-
logger.error(f"NoResultFound request: {request}")
|
|
179
183
|
logger.error(f"NoResultFound: {exc}")
|
|
180
184
|
|
|
181
185
|
return JSONResponse(
|
|
@@ -183,6 +187,32 @@ def create_application() -> "FastAPI":
|
|
|
183
187
|
content={"detail": str(exc)},
|
|
184
188
|
)
|
|
185
189
|
|
|
190
|
+
@app.exception_handler(ForeignKeyConstraintViolationError)
|
|
191
|
+
async def foreign_key_constraint_handler(request: Request, exc: ForeignKeyConstraintViolationError):
|
|
192
|
+
logger.error(f"ForeignKeyConstraintViolationError: {exc}")
|
|
193
|
+
|
|
194
|
+
return JSONResponse(
|
|
195
|
+
status_code=409,
|
|
196
|
+
content={"detail": str(exc)},
|
|
197
|
+
)
|
|
198
|
+
|
|
199
|
+
@app.exception_handler(UniqueConstraintViolationError)
|
|
200
|
+
async def unique_key_constraint_handler(request: Request, exc: UniqueConstraintViolationError):
|
|
201
|
+
logger.error(f"UniqueConstraintViolationError: {exc}")
|
|
202
|
+
|
|
203
|
+
return JSONResponse(
|
|
204
|
+
status_code=409,
|
|
205
|
+
content={"detail": str(exc)},
|
|
206
|
+
)
|
|
207
|
+
|
|
208
|
+
@app.exception_handler(DatabaseTimeoutError)
|
|
209
|
+
async def database_timeout_error_handler(request: Request, exc: DatabaseTimeoutError):
|
|
210
|
+
logger.error(f"Timeout occurred: {exc}. Original exception: {exc.original_exception}")
|
|
211
|
+
return JSONResponse(
|
|
212
|
+
status_code=503,
|
|
213
|
+
content={"detail": "The database is temporarily unavailable. Please try again later."},
|
|
214
|
+
)
|
|
215
|
+
|
|
186
216
|
@app.exception_handler(ValueError)
|
|
187
217
|
async def value_error_handler(request: Request, exc: ValueError):
|
|
188
218
|
return JSONResponse(status_code=400, content={"detail": str(exc)})
|
|
@@ -235,11 +265,6 @@ def create_application() -> "FastAPI":
|
|
|
235
265
|
|
|
236
266
|
@app.on_event("startup")
|
|
237
267
|
def on_startup():
|
|
238
|
-
# load the default tools
|
|
239
|
-
# from letta.orm.tool import Tool
|
|
240
|
-
|
|
241
|
-
# Tool.load_default_tools(get_db_session())
|
|
242
|
-
|
|
243
268
|
generate_openapi_schema(app)
|
|
244
269
|
|
|
245
270
|
@app.on_event("shutdown")
|
letta/server/server.py
CHANGED
|
@@ -4,7 +4,6 @@ import os
|
|
|
4
4
|
import traceback
|
|
5
5
|
import warnings
|
|
6
6
|
from abc import abstractmethod
|
|
7
|
-
from asyncio import Lock
|
|
8
7
|
from datetime import datetime
|
|
9
8
|
from typing import Callable, List, Optional, Tuple, Union
|
|
10
9
|
|
|
@@ -75,7 +74,7 @@ from letta.services.source_manager import SourceManager
|
|
|
75
74
|
from letta.services.tool_execution_sandbox import ToolExecutionSandbox
|
|
76
75
|
from letta.services.tool_manager import ToolManager
|
|
77
76
|
from letta.services.user_manager import UserManager
|
|
78
|
-
from letta.utils import get_utc_time, json_dumps, json_loads
|
|
77
|
+
from letta.utils import get_friendly_error_msg, get_utc_time, json_dumps, json_loads
|
|
79
78
|
|
|
80
79
|
logger = get_logger(__name__)
|
|
81
80
|
|
|
@@ -191,7 +190,14 @@ if settings.letta_pg_uri_no_default:
|
|
|
191
190
|
config.archival_storage_uri = settings.letta_pg_uri_no_default
|
|
192
191
|
|
|
193
192
|
# create engine
|
|
194
|
-
engine = create_engine(
|
|
193
|
+
engine = create_engine(
|
|
194
|
+
settings.letta_pg_uri,
|
|
195
|
+
pool_size=settings.pg_pool_size,
|
|
196
|
+
max_overflow=settings.pg_max_overflow,
|
|
197
|
+
pool_timeout=settings.pg_pool_timeout,
|
|
198
|
+
pool_recycle=settings.pg_pool_recycle,
|
|
199
|
+
echo=settings.pg_echo,
|
|
200
|
+
)
|
|
195
201
|
else:
|
|
196
202
|
# TODO: don't rely on config storage
|
|
197
203
|
engine = create_engine("sqlite:///" + os.path.join(config.recall_storage_path, "sqlite.db"))
|
|
@@ -265,9 +271,6 @@ class SyncServer(Server):
|
|
|
265
271
|
|
|
266
272
|
self.credentials = LettaCredentials.load()
|
|
267
273
|
|
|
268
|
-
# Locks
|
|
269
|
-
self.send_message_lock = Lock()
|
|
270
|
-
|
|
271
274
|
# Initialize the metadata store
|
|
272
275
|
config = LettaConfig.load()
|
|
273
276
|
if settings.letta_pg_uri_no_default:
|
|
@@ -773,6 +776,18 @@ class SyncServer(Server):
|
|
|
773
776
|
# interface
|
|
774
777
|
interface: Union[AgentInterface, None] = None,
|
|
775
778
|
) -> AgentState:
|
|
779
|
+
if request.llm_config is None:
|
|
780
|
+
if request.llm is None:
|
|
781
|
+
raise ValueError("Must specify either llm or llm_config in request")
|
|
782
|
+
request.llm_config = self.get_llm_config_from_handle(handle=request.llm, context_window_limit=request.context_window_limit)
|
|
783
|
+
|
|
784
|
+
if request.embedding_config is None:
|
|
785
|
+
if request.embedding is None:
|
|
786
|
+
raise ValueError("Must specify either embedding or embedding_config in request")
|
|
787
|
+
request.embedding_config = self.get_embedding_config_from_handle(
|
|
788
|
+
handle=request.embedding, embedding_chunk_size=request.embedding_chunk_size or constants.DEFAULT_EMBEDDING_CHUNK_SIZE
|
|
789
|
+
)
|
|
790
|
+
|
|
776
791
|
"""Create a new agent using a config"""
|
|
777
792
|
# Invoke manager
|
|
778
793
|
agent_state = self.agent_manager.create_agent(
|
|
@@ -821,7 +836,7 @@ class SyncServer(Server):
|
|
|
821
836
|
) -> AgentState:
|
|
822
837
|
"""Update the agents core memory block, return the new state"""
|
|
823
838
|
# Update agent state in the db first
|
|
824
|
-
self.agent_manager.update_agent(agent_id=agent_id, agent_update=request, actor=actor)
|
|
839
|
+
agent_state = self.agent_manager.update_agent(agent_id=agent_id, agent_update=request, actor=actor)
|
|
825
840
|
|
|
826
841
|
# Get the agent object (loaded in memory)
|
|
827
842
|
letta_agent = self.load_agent(agent_id=agent_id, actor=actor)
|
|
@@ -838,13 +853,9 @@ class SyncServer(Server):
|
|
|
838
853
|
# then (2) setting the attributes ._messages and .state.message_ids
|
|
839
854
|
letta_agent.set_message_buffer(message_ids=request.message_ids)
|
|
840
855
|
|
|
841
|
-
# tools
|
|
842
|
-
if request.tool_ids:
|
|
843
|
-
letta_agent.link_tools(letta_agent.agent_state.tools)
|
|
844
|
-
|
|
845
856
|
letta_agent.update_state()
|
|
846
857
|
|
|
847
|
-
return
|
|
858
|
+
return agent_state
|
|
848
859
|
|
|
849
860
|
def get_tools_from_agent(self, agent_id: str, user_id: Optional[str]) -> List[Tool]:
|
|
850
861
|
"""Get tools from an existing agent"""
|
|
@@ -867,11 +878,6 @@ class SyncServer(Server):
|
|
|
867
878
|
|
|
868
879
|
agent_state = self.agent_manager.attach_tool(agent_id=agent_id, tool_id=tool_id, actor=actor)
|
|
869
880
|
|
|
870
|
-
# TODO: This is very redundant, and should probably be simplified
|
|
871
|
-
# Get the agent object (loaded in memory)
|
|
872
|
-
letta_agent = self.load_agent(agent_id=agent_id, actor=actor)
|
|
873
|
-
letta_agent.link_tools(agent_state.tools)
|
|
874
|
-
|
|
875
881
|
return agent_state
|
|
876
882
|
|
|
877
883
|
def remove_tool_from_agent(
|
|
@@ -885,10 +891,6 @@ class SyncServer(Server):
|
|
|
885
891
|
actor = self.user_manager.get_user_or_default(user_id=user_id)
|
|
886
892
|
agent_state = self.agent_manager.detach_tool(agent_id=agent_id, tool_id=tool_id, actor=actor)
|
|
887
893
|
|
|
888
|
-
# Get the agent object (loaded in memory)
|
|
889
|
-
letta_agent = self.load_agent(agent_id=agent_id, actor=actor)
|
|
890
|
-
letta_agent.link_tools(agent_state.tools)
|
|
891
|
-
|
|
892
894
|
return agent_state
|
|
893
895
|
|
|
894
896
|
# convert name->id
|
|
@@ -1280,6 +1282,55 @@ class SyncServer(Server):
|
|
|
1280
1282
|
warnings.warn(f"An error occurred while listing embedding models for provider {provider}: {e}")
|
|
1281
1283
|
return embedding_models
|
|
1282
1284
|
|
|
1285
|
+
def get_llm_config_from_handle(self, handle: str, context_window_limit: Optional[int] = None) -> LLMConfig:
|
|
1286
|
+
provider_name, model_name = handle.split("/", 1)
|
|
1287
|
+
provider = self.get_provider_from_name(provider_name)
|
|
1288
|
+
|
|
1289
|
+
llm_configs = [config for config in provider.list_llm_models() if config.model == model_name]
|
|
1290
|
+
if not llm_configs:
|
|
1291
|
+
raise ValueError(f"LLM model {model_name} is not supported by {provider_name}")
|
|
1292
|
+
elif len(llm_configs) > 1:
|
|
1293
|
+
raise ValueError(f"Multiple LLM models with name {model_name} supported by {provider_name}")
|
|
1294
|
+
else:
|
|
1295
|
+
llm_config = llm_configs[0]
|
|
1296
|
+
|
|
1297
|
+
if context_window_limit:
|
|
1298
|
+
if context_window_limit > llm_config.context_window:
|
|
1299
|
+
raise ValueError(f"Context window limit ({context_window_limit}) is greater than maximum of ({llm_config.context_window})")
|
|
1300
|
+
llm_config.context_window = context_window_limit
|
|
1301
|
+
|
|
1302
|
+
return llm_config
|
|
1303
|
+
|
|
1304
|
+
def get_embedding_config_from_handle(
|
|
1305
|
+
self, handle: str, embedding_chunk_size: int = constants.DEFAULT_EMBEDDING_CHUNK_SIZE
|
|
1306
|
+
) -> EmbeddingConfig:
|
|
1307
|
+
provider_name, model_name = handle.split("/", 1)
|
|
1308
|
+
provider = self.get_provider_from_name(provider_name)
|
|
1309
|
+
|
|
1310
|
+
embedding_configs = [config for config in provider.list_embedding_models() if config.embedding_model == model_name]
|
|
1311
|
+
if not embedding_configs:
|
|
1312
|
+
raise ValueError(f"Embedding model {model_name} is not supported by {provider_name}")
|
|
1313
|
+
elif len(embedding_configs) > 1:
|
|
1314
|
+
raise ValueError(f"Multiple embedding models with name {model_name} supported by {provider_name}")
|
|
1315
|
+
else:
|
|
1316
|
+
embedding_config = embedding_configs[0]
|
|
1317
|
+
|
|
1318
|
+
if embedding_chunk_size:
|
|
1319
|
+
embedding_config.embedding_chunk_size = embedding_chunk_size
|
|
1320
|
+
|
|
1321
|
+
return embedding_config
|
|
1322
|
+
|
|
1323
|
+
def get_provider_from_name(self, provider_name: str) -> Provider:
|
|
1324
|
+
providers = [provider for provider in self._enabled_providers if provider.name == provider_name]
|
|
1325
|
+
if not providers:
|
|
1326
|
+
raise ValueError(f"Provider {provider_name} is not supported")
|
|
1327
|
+
elif len(providers) > 1:
|
|
1328
|
+
raise ValueError(f"Multiple providers with name {provider_name} supported")
|
|
1329
|
+
else:
|
|
1330
|
+
provider = providers[0]
|
|
1331
|
+
|
|
1332
|
+
return provider
|
|
1333
|
+
|
|
1283
1334
|
def add_llm_model(self, request: LLMConfig) -> LLMConfig:
|
|
1284
1335
|
"""Add a new LLM model"""
|
|
1285
1336
|
|
|
@@ -1300,7 +1351,7 @@ class SyncServer(Server):
|
|
|
1300
1351
|
|
|
1301
1352
|
def run_tool_from_source(
|
|
1302
1353
|
self,
|
|
1303
|
-
|
|
1354
|
+
actor: User,
|
|
1304
1355
|
tool_args: str,
|
|
1305
1356
|
tool_source: str,
|
|
1306
1357
|
tool_source_type: Optional[str] = None,
|
|
@@ -1328,56 +1379,29 @@ class SyncServer(Server):
|
|
|
1328
1379
|
|
|
1329
1380
|
# Next, attempt to run the tool with the sandbox
|
|
1330
1381
|
try:
|
|
1331
|
-
sandbox_run_result = ToolExecutionSandbox(tool.name, tool_args_dict,
|
|
1332
|
-
function_response = str(sandbox_run_result.func_return)
|
|
1333
|
-
stdout = [s for s in sandbox_run_result.stdout if s.strip()]
|
|
1334
|
-
stderr = [s for s in sandbox_run_result.stderr if s.strip()]
|
|
1335
|
-
|
|
1336
|
-
# expected error
|
|
1337
|
-
if stderr:
|
|
1338
|
-
error_msg = self.get_error_msg_for_func_return(tool.name, stderr[-1])
|
|
1339
|
-
return FunctionReturn(
|
|
1340
|
-
id="null",
|
|
1341
|
-
function_call_id="null",
|
|
1342
|
-
date=get_utc_time(),
|
|
1343
|
-
status="error",
|
|
1344
|
-
function_return=error_msg,
|
|
1345
|
-
stdout=stdout,
|
|
1346
|
-
stderr=stderr,
|
|
1347
|
-
)
|
|
1348
|
-
|
|
1382
|
+
sandbox_run_result = ToolExecutionSandbox(tool.name, tool_args_dict, actor, tool_object=tool).run(agent_state=agent_state)
|
|
1349
1383
|
return FunctionReturn(
|
|
1350
1384
|
id="null",
|
|
1351
1385
|
function_call_id="null",
|
|
1352
1386
|
date=get_utc_time(),
|
|
1353
|
-
status=
|
|
1354
|
-
function_return=
|
|
1355
|
-
stdout=stdout,
|
|
1356
|
-
stderr=stderr,
|
|
1387
|
+
status=sandbox_run_result.status,
|
|
1388
|
+
function_return=str(sandbox_run_result.func_return),
|
|
1389
|
+
stdout=sandbox_run_result.stdout,
|
|
1390
|
+
stderr=sandbox_run_result.stderr,
|
|
1357
1391
|
)
|
|
1358
1392
|
|
|
1359
|
-
# unexpected error TODO(@cthomas): consolidate error handling
|
|
1360
1393
|
except Exception as e:
|
|
1361
|
-
|
|
1394
|
+
func_return = get_friendly_error_msg(function_name=tool.name, exception_name=type(e).__name__, exception_message=str(e))
|
|
1362
1395
|
return FunctionReturn(
|
|
1363
1396
|
id="null",
|
|
1364
1397
|
function_call_id="null",
|
|
1365
1398
|
date=get_utc_time(),
|
|
1366
1399
|
status="error",
|
|
1367
|
-
function_return=
|
|
1368
|
-
stdout=[
|
|
1400
|
+
function_return=func_return,
|
|
1401
|
+
stdout=[],
|
|
1369
1402
|
stderr=[traceback.format_exc()],
|
|
1370
1403
|
)
|
|
1371
1404
|
|
|
1372
|
-
def get_error_msg_for_func_return(self, tool_name, exception_message):
|
|
1373
|
-
# same as agent.py
|
|
1374
|
-
from letta.constants import MAX_ERROR_MESSAGE_CHAR_LIMIT
|
|
1375
|
-
|
|
1376
|
-
error_msg = f"Error executing tool {tool_name}: {exception_message}"
|
|
1377
|
-
if len(error_msg) > MAX_ERROR_MESSAGE_CHAR_LIMIT:
|
|
1378
|
-
error_msg = error_msg[:MAX_ERROR_MESSAGE_CHAR_LIMIT]
|
|
1379
|
-
return error_msg
|
|
1380
|
-
|
|
1381
1405
|
# Composio wrappers
|
|
1382
1406
|
def get_composio_client(self, api_key: Optional[str] = None):
|
|
1383
1407
|
if api_key:
|
letta/services/agent_manager.py
CHANGED
|
@@ -61,6 +61,9 @@ class AgentManager:
|
|
|
61
61
|
) -> PydanticAgentState:
|
|
62
62
|
system = derive_system_message(agent_type=agent_create.agent_type, system=agent_create.system)
|
|
63
63
|
|
|
64
|
+
if not agent_create.llm_config or not agent_create.embedding_config:
|
|
65
|
+
raise ValueError("llm_config and embedding_config are required")
|
|
66
|
+
|
|
64
67
|
# create blocks (note: cannot be linked into the agent_id is created)
|
|
65
68
|
block_ids = list(agent_create.block_ids or []) # Create a local copy to avoid modifying the original
|
|
66
69
|
for create_block in agent_create.memory_blocks:
|