letta-nightly 0.6.1.dev20241206104246__py3-none-any.whl → 0.6.1.dev20241208104134__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of letta-nightly might be problematic. Click here for more details.
- letta/agent.py +68 -76
- letta/agent_store/db.py +1 -77
- letta/agent_store/storage.py +0 -5
- letta/cli/cli.py +1 -4
- letta/client/client.py +11 -14
- letta/constants.py +1 -0
- letta/functions/function_sets/base.py +33 -5
- letta/functions/helpers.py +3 -3
- letta/llm_api/openai.py +0 -1
- letta/local_llm/llm_chat_completion_wrappers/chatml.py +13 -1
- letta/main.py +2 -2
- letta/memory.py +4 -82
- letta/metadata.py +0 -35
- letta/o1_agent.py +7 -2
- letta/offline_memory_agent.py +6 -0
- letta/orm/__init__.py +2 -0
- letta/orm/file.py +1 -1
- letta/orm/message.py +64 -0
- letta/orm/mixins.py +16 -0
- letta/orm/organization.py +1 -0
- letta/orm/sqlalchemy_base.py +118 -26
- letta/schemas/letta_base.py +7 -6
- letta/schemas/message.py +6 -12
- letta/schemas/tool.py +18 -11
- letta/server/rest_api/app.py +2 -3
- letta/server/rest_api/routers/v1/agents.py +7 -6
- letta/server/rest_api/routers/v1/blocks.py +2 -2
- letta/server/rest_api/routers/v1/tools.py +26 -4
- letta/server/rest_api/utils.py +3 -1
- letta/server/server.py +67 -62
- letta/server/static_files/assets/index-43ab4d62.css +1 -0
- letta/server/static_files/assets/index-4848e3d7.js +40 -0
- letta/server/static_files/index.html +2 -2
- letta/services/block_manager.py +1 -1
- letta/services/message_manager.py +194 -0
- letta/services/organization_manager.py +6 -9
- letta/services/sandbox_config_manager.py +16 -1
- letta/services/source_manager.py +1 -1
- letta/services/tool_manager.py +2 -4
- letta/services/user_manager.py +1 -1
- {letta_nightly-0.6.1.dev20241206104246.dist-info → letta_nightly-0.6.1.dev20241208104134.dist-info}/METADATA +2 -2
- {letta_nightly-0.6.1.dev20241206104246.dist-info → letta_nightly-0.6.1.dev20241208104134.dist-info}/RECORD +45 -45
- letta/agent_store/lancedb.py +0 -177
- letta/persistence_manager.py +0 -149
- letta/server/static_files/assets/index-1b5d1a41.js +0 -271
- letta/server/static_files/assets/index-56a3f8c6.css +0 -1
- {letta_nightly-0.6.1.dev20241206104246.dist-info → letta_nightly-0.6.1.dev20241208104134.dist-info}/LICENSE +0 -0
- {letta_nightly-0.6.1.dev20241206104246.dist-info → letta_nightly-0.6.1.dev20241208104134.dist-info}/WHEEL +0 -0
- {letta_nightly-0.6.1.dev20241206104246.dist-info → letta_nightly-0.6.1.dev20241208104134.dist-info}/entry_points.txt +0 -0
letta/orm/sqlalchemy_base.py
CHANGED
|
@@ -1,8 +1,10 @@
|
|
|
1
|
+
from datetime import datetime
|
|
2
|
+
from enum import Enum
|
|
1
3
|
from typing import TYPE_CHECKING, List, Literal, Optional, Type
|
|
2
4
|
|
|
3
|
-
from sqlalchemy import String, select
|
|
5
|
+
from sqlalchemy import String, func, select
|
|
4
6
|
from sqlalchemy.exc import DBAPIError
|
|
5
|
-
from sqlalchemy.orm import Mapped, mapped_column
|
|
7
|
+
from sqlalchemy.orm import Mapped, Session, mapped_column
|
|
6
8
|
|
|
7
9
|
from letta.log import get_logger
|
|
8
10
|
from letta.orm.base import Base, CommonSqlalchemyMetaMixins
|
|
@@ -20,6 +22,11 @@ if TYPE_CHECKING:
|
|
|
20
22
|
logger = get_logger(__name__)
|
|
21
23
|
|
|
22
24
|
|
|
25
|
+
class AccessType(str, Enum):
|
|
26
|
+
ORGANIZATION = "organization"
|
|
27
|
+
USER = "user"
|
|
28
|
+
|
|
29
|
+
|
|
23
30
|
class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
|
24
31
|
__abstract__ = True
|
|
25
32
|
|
|
@@ -28,46 +35,68 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
|
|
28
35
|
id: Mapped[str] = mapped_column(String, primary_key=True)
|
|
29
36
|
|
|
30
37
|
@classmethod
|
|
31
|
-
def
|
|
32
|
-
|
|
33
|
-
) -> List[Type["SqlalchemyBase"]]:
|
|
34
|
-
"""
|
|
35
|
-
List records with optional cursor (for pagination), limit, and automatic filtering.
|
|
38
|
+
def get(cls, *, db_session: Session, id: str) -> Optional["SqlalchemyBase"]:
|
|
39
|
+
"""Get a record by ID.
|
|
36
40
|
|
|
37
41
|
Args:
|
|
38
|
-
db_session:
|
|
39
|
-
|
|
40
|
-
limit: Maximum number of records to return.
|
|
41
|
-
**kwargs: Filters passed as equality conditions or iterable for IN filtering.
|
|
42
|
+
db_session: SQLAlchemy session
|
|
43
|
+
id: Record ID to retrieve
|
|
42
44
|
|
|
43
45
|
Returns:
|
|
44
|
-
|
|
46
|
+
Optional[SqlalchemyBase]: The record if found, None otherwise
|
|
45
47
|
"""
|
|
46
|
-
|
|
48
|
+
try:
|
|
49
|
+
return db_session.query(cls).filter(cls.id == id).first()
|
|
50
|
+
except DBAPIError:
|
|
51
|
+
return None
|
|
52
|
+
|
|
53
|
+
@classmethod
|
|
54
|
+
def list(
|
|
55
|
+
cls,
|
|
56
|
+
*,
|
|
57
|
+
db_session: "Session",
|
|
58
|
+
cursor: Optional[str] = None,
|
|
59
|
+
start_date: Optional[datetime] = None,
|
|
60
|
+
end_date: Optional[datetime] = None,
|
|
61
|
+
limit: Optional[int] = 50,
|
|
62
|
+
query_text: Optional[str] = None,
|
|
63
|
+
**kwargs,
|
|
64
|
+
) -> List[Type["SqlalchemyBase"]]:
|
|
65
|
+
"""List records with advanced filtering and pagination options."""
|
|
66
|
+
if start_date and end_date and start_date > end_date:
|
|
67
|
+
raise ValueError("start_date must be earlier than or equal to end_date")
|
|
68
|
+
|
|
69
|
+
logger.debug(f"Listing {cls.__name__} with kwarg filters {kwargs}")
|
|
47
70
|
with db_session as session:
|
|
48
|
-
# Start with a base query
|
|
49
71
|
query = select(cls)
|
|
50
72
|
|
|
51
73
|
# Apply filtering logic
|
|
52
74
|
for key, value in kwargs.items():
|
|
53
75
|
column = getattr(cls, key)
|
|
54
|
-
if isinstance(value, (list, tuple, set)):
|
|
76
|
+
if isinstance(value, (list, tuple, set)):
|
|
55
77
|
query = query.where(column.in_(value))
|
|
56
|
-
else:
|
|
78
|
+
else:
|
|
57
79
|
query = query.where(column == value)
|
|
58
80
|
|
|
59
|
-
#
|
|
81
|
+
# Date range filtering
|
|
82
|
+
if start_date:
|
|
83
|
+
query = query.filter(cls.created_at >= start_date)
|
|
84
|
+
if end_date:
|
|
85
|
+
query = query.filter(cls.created_at <= end_date)
|
|
86
|
+
|
|
87
|
+
# Cursor-based pagination
|
|
60
88
|
if cursor:
|
|
61
89
|
query = query.where(cls.id > cursor)
|
|
62
90
|
|
|
63
|
-
#
|
|
91
|
+
# Apply text search
|
|
92
|
+
if query_text:
|
|
93
|
+
query = query.filter(func.lower(cls.text).contains(func.lower(query_text)))
|
|
94
|
+
|
|
95
|
+
# Handle ordering and soft deletes
|
|
64
96
|
if hasattr(cls, "is_deleted"):
|
|
65
97
|
query = query.where(cls.is_deleted == False)
|
|
66
|
-
|
|
67
|
-
# Add ordering and limit
|
|
68
98
|
query = query.order_by(cls.id).limit(limit)
|
|
69
99
|
|
|
70
|
-
# Execute the query and return results as model instances
|
|
71
100
|
return list(session.execute(query).scalars())
|
|
72
101
|
|
|
73
102
|
@classmethod
|
|
@@ -77,6 +106,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
|
|
77
106
|
identifier: Optional[str] = None,
|
|
78
107
|
actor: Optional["User"] = None,
|
|
79
108
|
access: Optional[List[Literal["read", "write", "admin"]]] = ["read"],
|
|
109
|
+
access_type: AccessType = AccessType.ORGANIZATION,
|
|
80
110
|
**kwargs,
|
|
81
111
|
) -> Type["SqlalchemyBase"]:
|
|
82
112
|
"""The primary accessor for an ORM record.
|
|
@@ -108,7 +138,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
|
|
108
138
|
query_conditions.append(", ".join(f"{key}='{value}'" for key, value in kwargs.items()))
|
|
109
139
|
|
|
110
140
|
if actor:
|
|
111
|
-
query = cls.apply_access_predicate(query, actor, access)
|
|
141
|
+
query = cls.apply_access_predicate(query, actor, access, access_type)
|
|
112
142
|
query_conditions.append(f"access level in {access} for actor='{actor}'")
|
|
113
143
|
|
|
114
144
|
if hasattr(cls, "is_deleted"):
|
|
@@ -170,12 +200,66 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
|
|
170
200
|
session.refresh(self)
|
|
171
201
|
return self
|
|
172
202
|
|
|
203
|
+
@classmethod
|
|
204
|
+
def size(
|
|
205
|
+
cls,
|
|
206
|
+
*,
|
|
207
|
+
db_session: "Session",
|
|
208
|
+
actor: Optional["User"] = None,
|
|
209
|
+
access: Optional[List[Literal["read", "write", "admin"]]] = ["read"],
|
|
210
|
+
access_type: AccessType = AccessType.ORGANIZATION,
|
|
211
|
+
**kwargs,
|
|
212
|
+
) -> int:
|
|
213
|
+
"""
|
|
214
|
+
Get the count of rows that match the provided filters.
|
|
215
|
+
|
|
216
|
+
Args:
|
|
217
|
+
db_session: SQLAlchemy session
|
|
218
|
+
**kwargs: Filters to apply to the query (e.g., column_name=value)
|
|
219
|
+
|
|
220
|
+
Returns:
|
|
221
|
+
int: The count of rows that match the filters
|
|
222
|
+
|
|
223
|
+
Raises:
|
|
224
|
+
DBAPIError: If a database error occurs
|
|
225
|
+
"""
|
|
226
|
+
logger.debug(f"Calculating size for {cls.__name__} with filters {kwargs}")
|
|
227
|
+
|
|
228
|
+
with db_session as session:
|
|
229
|
+
query = select(func.count()).select_from(cls)
|
|
230
|
+
|
|
231
|
+
if actor:
|
|
232
|
+
query = cls.apply_access_predicate(query, actor, access, access_type)
|
|
233
|
+
|
|
234
|
+
# Apply filtering logic based on kwargs
|
|
235
|
+
for key, value in kwargs.items():
|
|
236
|
+
if value:
|
|
237
|
+
column = getattr(cls, key, None)
|
|
238
|
+
if not column:
|
|
239
|
+
raise AttributeError(f"{cls.__name__} has no attribute '{key}'")
|
|
240
|
+
if isinstance(value, (list, tuple, set)): # Check for iterables
|
|
241
|
+
query = query.where(column.in_(value))
|
|
242
|
+
else: # Single value for equality filtering
|
|
243
|
+
query = query.where(column == value)
|
|
244
|
+
|
|
245
|
+
# Handle soft deletes if the class has the 'is_deleted' attribute
|
|
246
|
+
if hasattr(cls, "is_deleted"):
|
|
247
|
+
query = query.where(cls.is_deleted == False)
|
|
248
|
+
|
|
249
|
+
try:
|
|
250
|
+
count = session.execute(query).scalar()
|
|
251
|
+
return count if count else 0
|
|
252
|
+
except DBAPIError as e:
|
|
253
|
+
logger.exception(f"Failed to calculate size for {cls.__name__}")
|
|
254
|
+
raise e
|
|
255
|
+
|
|
173
256
|
@classmethod
|
|
174
257
|
def apply_access_predicate(
|
|
175
258
|
cls,
|
|
176
259
|
query: "Select",
|
|
177
260
|
actor: "User",
|
|
178
261
|
access: List[Literal["read", "write", "admin"]],
|
|
262
|
+
access_type: AccessType = AccessType.ORGANIZATION,
|
|
179
263
|
) -> "Select":
|
|
180
264
|
"""applies a WHERE clause restricting results to the given actor and access level
|
|
181
265
|
Args:
|
|
@@ -189,10 +273,18 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
|
|
189
273
|
the sqlalchemy select statement restricted to the given access.
|
|
190
274
|
"""
|
|
191
275
|
del access # entrypoint for row-level permissions. Defaults to "same org as the actor, all permissions" at the moment
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
276
|
+
if access_type == AccessType.ORGANIZATION:
|
|
277
|
+
org_id = getattr(actor, "organization_id", None)
|
|
278
|
+
if not org_id:
|
|
279
|
+
raise ValueError(f"object {actor} has no organization accessor")
|
|
280
|
+
return query.where(cls.organization_id == org_id, cls.is_deleted == False)
|
|
281
|
+
elif access_type == AccessType.USER:
|
|
282
|
+
user_id = getattr(actor, "id", None)
|
|
283
|
+
if not user_id:
|
|
284
|
+
raise ValueError(f"object {actor} has no user accessor")
|
|
285
|
+
return query.where(cls.user_id == user_id, cls.is_deleted == False)
|
|
286
|
+
else:
|
|
287
|
+
raise ValueError(f"unknown access_type: {access_type}")
|
|
196
288
|
|
|
197
289
|
@classmethod
|
|
198
290
|
def _handle_dbapi_error(cls, e: DBAPIError):
|
letta/schemas/letta_base.py
CHANGED
|
@@ -33,18 +33,19 @@ class LettaBase(BaseModel):
|
|
|
33
33
|
def generate_id_field(cls, prefix: Optional[str] = None) -> "Field":
|
|
34
34
|
prefix = prefix or cls.__id_prefix__
|
|
35
35
|
|
|
36
|
-
# TODO: generate ID from regex pattern?
|
|
37
|
-
def _generate_id() -> str:
|
|
38
|
-
return f"{prefix}-{uuid.uuid4()}"
|
|
39
|
-
|
|
40
36
|
return Field(
|
|
41
37
|
...,
|
|
42
38
|
description=cls._id_description(prefix),
|
|
43
39
|
pattern=cls._id_regex_pattern(prefix),
|
|
44
40
|
examples=[cls._id_example(prefix)],
|
|
45
|
-
default_factory=_generate_id,
|
|
41
|
+
default_factory=cls._generate_id,
|
|
46
42
|
)
|
|
47
43
|
|
|
44
|
+
@classmethod
|
|
45
|
+
def _generate_id(cls, prefix: Optional[str] = None) -> str:
|
|
46
|
+
prefix = prefix or cls.__id_prefix__
|
|
47
|
+
return f"{prefix}-{uuid.uuid4()}"
|
|
48
|
+
|
|
48
49
|
# def _generate_id(self) -> str:
|
|
49
50
|
# return f"{self.__id_prefix__}-{uuid.uuid4()}"
|
|
50
51
|
|
|
@@ -78,7 +79,7 @@ class LettaBase(BaseModel):
|
|
|
78
79
|
"""
|
|
79
80
|
_ = values # for SCA
|
|
80
81
|
if isinstance(v, UUID):
|
|
81
|
-
logger.
|
|
82
|
+
logger.debug(f"Bare UUIDs are deprecated, please use the full prefixed id ({cls.__id_prefix__})!")
|
|
82
83
|
return f"{cls.__id_prefix__}-{v}"
|
|
83
84
|
return v
|
|
84
85
|
|
letta/schemas/message.py
CHANGED
|
@@ -13,7 +13,7 @@ from letta.constants import (
|
|
|
13
13
|
)
|
|
14
14
|
from letta.local_llm.constants import INNER_THOUGHTS_KWARG
|
|
15
15
|
from letta.schemas.enums import MessageRole
|
|
16
|
-
from letta.schemas.letta_base import
|
|
16
|
+
from letta.schemas.letta_base import OrmMetadataBase
|
|
17
17
|
from letta.schemas.letta_message import (
|
|
18
18
|
AssistantMessage,
|
|
19
19
|
FunctionCall,
|
|
@@ -50,7 +50,7 @@ def add_inner_thoughts_to_tool_call(
|
|
|
50
50
|
raise e
|
|
51
51
|
|
|
52
52
|
|
|
53
|
-
class BaseMessage(
|
|
53
|
+
class BaseMessage(OrmMetadataBase):
|
|
54
54
|
__id_prefix__ = "message"
|
|
55
55
|
|
|
56
56
|
|
|
@@ -66,10 +66,9 @@ class MessageCreate(BaseMessage):
|
|
|
66
66
|
name: Optional[str] = Field(None, description="The name of the participant.")
|
|
67
67
|
|
|
68
68
|
|
|
69
|
-
class
|
|
69
|
+
class MessageUpdate(BaseMessage):
|
|
70
70
|
"""Request to update a message"""
|
|
71
71
|
|
|
72
|
-
id: str = Field(..., description="The id of the message.")
|
|
73
72
|
role: Optional[MessageRole] = Field(None, description="The role of the participant.")
|
|
74
73
|
text: Optional[str] = Field(None, description="The text of the message.")
|
|
75
74
|
# NOTE: probably doesn't make sense to allow remapping user_id or agent_id (vs creating a new message)
|
|
@@ -105,13 +104,14 @@ class Message(BaseMessage):
|
|
|
105
104
|
id: str = BaseMessage.generate_id_field()
|
|
106
105
|
role: MessageRole = Field(..., description="The role of the participant.")
|
|
107
106
|
text: Optional[str] = Field(None, description="The text of the message.")
|
|
108
|
-
|
|
107
|
+
organization_id: Optional[str] = Field(None, description="The unique identifier of the organization.")
|
|
109
108
|
agent_id: Optional[str] = Field(None, description="The unique identifier of the agent.")
|
|
110
109
|
model: Optional[str] = Field(None, description="The model used to make the function call.")
|
|
111
110
|
name: Optional[str] = Field(None, description="The name of the participant.")
|
|
112
|
-
created_at: datetime = Field(default_factory=get_utc_time, description="The time the message was created.")
|
|
113
111
|
tool_calls: Optional[List[ToolCall]] = Field(None, description="The list of tool calls requested.")
|
|
114
112
|
tool_call_id: Optional[str] = Field(None, description="The id of the tool call.")
|
|
113
|
+
# This overrides the optional base orm schema, created_at MUST exist on all messages objects
|
|
114
|
+
created_at: datetime = Field(default_factory=get_utc_time, description="The timestamp when the object was created.")
|
|
115
115
|
|
|
116
116
|
@field_validator("role")
|
|
117
117
|
@classmethod
|
|
@@ -281,7 +281,6 @@ class Message(BaseMessage):
|
|
|
281
281
|
)
|
|
282
282
|
if id is not None:
|
|
283
283
|
return Message(
|
|
284
|
-
user_id=user_id,
|
|
285
284
|
agent_id=agent_id,
|
|
286
285
|
model=model,
|
|
287
286
|
# standard fields expected in an OpenAI ChatCompletion message object
|
|
@@ -295,7 +294,6 @@ class Message(BaseMessage):
|
|
|
295
294
|
)
|
|
296
295
|
else:
|
|
297
296
|
return Message(
|
|
298
|
-
user_id=user_id,
|
|
299
297
|
agent_id=agent_id,
|
|
300
298
|
model=model,
|
|
301
299
|
# standard fields expected in an OpenAI ChatCompletion message object
|
|
@@ -328,7 +326,6 @@ class Message(BaseMessage):
|
|
|
328
326
|
|
|
329
327
|
if id is not None:
|
|
330
328
|
return Message(
|
|
331
|
-
user_id=user_id,
|
|
332
329
|
agent_id=agent_id,
|
|
333
330
|
model=model,
|
|
334
331
|
# standard fields expected in an OpenAI ChatCompletion message object
|
|
@@ -342,7 +339,6 @@ class Message(BaseMessage):
|
|
|
342
339
|
)
|
|
343
340
|
else:
|
|
344
341
|
return Message(
|
|
345
|
-
user_id=user_id,
|
|
346
342
|
agent_id=agent_id,
|
|
347
343
|
model=model,
|
|
348
344
|
# standard fields expected in an OpenAI ChatCompletion message object
|
|
@@ -375,7 +371,6 @@ class Message(BaseMessage):
|
|
|
375
371
|
# If we're going from tool-call style
|
|
376
372
|
if id is not None:
|
|
377
373
|
return Message(
|
|
378
|
-
user_id=user_id,
|
|
379
374
|
agent_id=agent_id,
|
|
380
375
|
model=model,
|
|
381
376
|
# standard fields expected in an OpenAI ChatCompletion message object
|
|
@@ -389,7 +384,6 @@ class Message(BaseMessage):
|
|
|
389
384
|
)
|
|
390
385
|
else:
|
|
391
386
|
return Message(
|
|
392
|
-
user_id=user_id,
|
|
393
387
|
agent_id=agent_id,
|
|
394
388
|
model=model,
|
|
395
389
|
# standard fields expected in an OpenAI ChatCompletion message object
|
letta/schemas/tool.py
CHANGED
|
@@ -93,7 +93,7 @@ class ToolCreate(LettaBase):
|
|
|
93
93
|
)
|
|
94
94
|
|
|
95
95
|
@classmethod
|
|
96
|
-
def from_composio(cls,
|
|
96
|
+
def from_composio(cls, action_name: str, api_key: Optional[str] = None) -> "ToolCreate":
|
|
97
97
|
"""
|
|
98
98
|
Class method to create an instance of Letta-compatible Composio Tool.
|
|
99
99
|
Check https://docs.composio.dev/introduction/intro/overview to look at options for from_composio
|
|
@@ -101,15 +101,20 @@ class ToolCreate(LettaBase):
|
|
|
101
101
|
This function will error if we find more than one tool, or 0 tools.
|
|
102
102
|
|
|
103
103
|
Args:
|
|
104
|
-
|
|
104
|
+
action_name str: A action name to filter tools by.
|
|
105
105
|
Returns:
|
|
106
106
|
Tool: A Letta Tool initialized with attributes derived from the Composio tool.
|
|
107
107
|
"""
|
|
108
108
|
from composio import LogLevel
|
|
109
109
|
from composio_langchain import ComposioToolSet
|
|
110
110
|
|
|
111
|
-
|
|
112
|
-
|
|
111
|
+
if api_key:
|
|
112
|
+
# Pass in an external API key
|
|
113
|
+
composio_toolset = ComposioToolSet(logging_level=LogLevel.ERROR, api_key=api_key)
|
|
114
|
+
else:
|
|
115
|
+
# Use environmental variable
|
|
116
|
+
composio_toolset = ComposioToolSet(logging_level=LogLevel.ERROR)
|
|
117
|
+
composio_tools = composio_toolset.get_tools(actions=[action_name])
|
|
113
118
|
|
|
114
119
|
assert len(composio_tools) > 0, "User supplied parameters do not match any Composio tools"
|
|
115
120
|
assert len(composio_tools) == 1, f"User supplied parameters match too many Composio tools; {len(composio_tools)} > 1"
|
|
@@ -119,7 +124,7 @@ class ToolCreate(LettaBase):
|
|
|
119
124
|
description = composio_tool.description
|
|
120
125
|
source_type = "python"
|
|
121
126
|
tags = ["composio"]
|
|
122
|
-
wrapper_func_name, wrapper_function_str = generate_composio_tool_wrapper(
|
|
127
|
+
wrapper_func_name, wrapper_function_str = generate_composio_tool_wrapper(action_name)
|
|
123
128
|
json_schema = generate_schema_from_args_schema_v2(composio_tool.args_schema, name=wrapper_func_name, description=description)
|
|
124
129
|
|
|
125
130
|
return cls(
|
|
@@ -177,14 +182,16 @@ class ToolCreate(LettaBase):
|
|
|
177
182
|
|
|
178
183
|
@classmethod
|
|
179
184
|
def load_default_composio_tools(cls) -> List["ToolCreate"]:
|
|
180
|
-
|
|
185
|
+
pass
|
|
181
186
|
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
187
|
+
# TODO: Disable composio tools for now
|
|
188
|
+
# TODO: Naming is causing issues
|
|
189
|
+
# calculator = ToolCreate.from_composio(action_name=Action.MATHEMATICAL_CALCULATOR.name)
|
|
190
|
+
# serp_news = ToolCreate.from_composio(action_name=Action.SERPAPI_NEWS_SEARCH.name)
|
|
191
|
+
# serp_google_search = ToolCreate.from_composio(action_name=Action.SERPAPI_SEARCH.name)
|
|
192
|
+
# serp_google_maps = ToolCreate.from_composio(action_name=Action.SERPAPI_GOOGLE_MAPS_SEARCH.name)
|
|
186
193
|
|
|
187
|
-
return [
|
|
194
|
+
return []
|
|
188
195
|
|
|
189
196
|
|
|
190
197
|
class ToolUpdate(LettaBase):
|
letta/server/rest_api/app.py
CHANGED
|
@@ -144,9 +144,8 @@ def create_application() -> "FastAPI":
|
|
|
144
144
|
debug=True,
|
|
145
145
|
)
|
|
146
146
|
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
print(f"▶ View using ADE at: https://app.letta.com/development-servers/local/dashboard")
|
|
147
|
+
settings.cors_origins.append("https://app.letta.com")
|
|
148
|
+
print(f"▶ View using ADE at: https://app.letta.com/development-servers/local/dashboard")
|
|
150
149
|
|
|
151
150
|
if (os.getenv("LETTA_SERVER_SECURE") == "true") or "--secure" in sys.argv:
|
|
152
151
|
print(f"▶ Using secure mode with password: {random_password}")
|
|
@@ -28,7 +28,7 @@ from letta.schemas.memory import (
|
|
|
28
28
|
Memory,
|
|
29
29
|
RecallMemorySummary,
|
|
30
30
|
)
|
|
31
|
-
from letta.schemas.message import Message, MessageCreate,
|
|
31
|
+
from letta.schemas.message import Message, MessageCreate, MessageUpdate
|
|
32
32
|
from letta.schemas.passage import Passage
|
|
33
33
|
from letta.schemas.source import Source
|
|
34
34
|
from letta.schemas.tool import Tool
|
|
@@ -409,7 +409,7 @@ def get_agent_messages(
|
|
|
409
409
|
return server.get_agent_recall_cursor(
|
|
410
410
|
user_id=actor.id,
|
|
411
411
|
agent_id=agent_id,
|
|
412
|
-
|
|
412
|
+
cursor=before,
|
|
413
413
|
limit=limit,
|
|
414
414
|
reverse=True,
|
|
415
415
|
return_message_object=msg_object,
|
|
@@ -422,14 +422,13 @@ def get_agent_messages(
|
|
|
422
422
|
def update_message(
|
|
423
423
|
agent_id: str,
|
|
424
424
|
message_id: str,
|
|
425
|
-
request:
|
|
425
|
+
request: MessageUpdate = Body(...),
|
|
426
426
|
server: "SyncServer" = Depends(get_letta_server),
|
|
427
427
|
):
|
|
428
428
|
"""
|
|
429
429
|
Update the details of a message associated with an agent.
|
|
430
430
|
"""
|
|
431
|
-
|
|
432
|
-
return server.update_agent_message(agent_id=agent_id, request=request)
|
|
431
|
+
return server.update_agent_message(agent_id=agent_id, message_id=message_id, request=request)
|
|
433
432
|
|
|
434
433
|
|
|
435
434
|
@router.post(
|
|
@@ -465,7 +464,7 @@ async def send_message(
|
|
|
465
464
|
@router.post(
|
|
466
465
|
"/{agent_id}/messages/stream",
|
|
467
466
|
response_model=None,
|
|
468
|
-
operation_id="
|
|
467
|
+
operation_id="create_agent_message_stream",
|
|
469
468
|
responses={
|
|
470
469
|
200: {
|
|
471
470
|
"description": "Successful response",
|
|
@@ -486,6 +485,8 @@ async def send_message_streaming(
|
|
|
486
485
|
This endpoint accepts a message from a user and processes it through the agent.
|
|
487
486
|
It will stream the steps of the response always, and stream the tokens if 'stream_tokens' is set to True.
|
|
488
487
|
"""
|
|
488
|
+
request.stream_tokens = False
|
|
489
|
+
|
|
489
490
|
actor = server.get_user_or_default(user_id=user_id)
|
|
490
491
|
result = await send_message_to_agent(
|
|
491
492
|
server=server,
|
|
@@ -76,7 +76,7 @@ def get_block(
|
|
|
76
76
|
raise HTTPException(status_code=404, detail="Block not found")
|
|
77
77
|
|
|
78
78
|
|
|
79
|
-
@router.patch("/{block_id}/attach", response_model=Block, operation_id="
|
|
79
|
+
@router.patch("/{block_id}/attach", response_model=Block, operation_id="link_agent_memory_block")
|
|
80
80
|
def link_agent_memory_block(
|
|
81
81
|
block_id: str,
|
|
82
82
|
agent_id: str = Query(..., description="The unique identifier of the agent to attach the source to."),
|
|
@@ -96,7 +96,7 @@ def link_agent_memory_block(
|
|
|
96
96
|
return block
|
|
97
97
|
|
|
98
98
|
|
|
99
|
-
@router.patch("/{block_id}/detach", response_model=Memory, operation_id="
|
|
99
|
+
@router.patch("/{block_id}/detach", response_model=Memory, operation_id="unlink_agent_memory_block")
|
|
100
100
|
def unlink_agent_memory_block(
|
|
101
101
|
block_id: str,
|
|
102
102
|
agent_id: str = Query(..., description="The unique identifier of the agent to attach the source to."),
|
|
@@ -7,6 +7,7 @@ from letta.errors import LettaToolCreateError
|
|
|
7
7
|
from letta.orm.errors import UniqueConstraintViolationError
|
|
8
8
|
from letta.schemas.letta_message import FunctionReturn
|
|
9
9
|
from letta.schemas.tool import Tool, ToolCreate, ToolRunFromSource, ToolUpdate
|
|
10
|
+
from letta.schemas.user import User
|
|
10
11
|
from letta.server.rest_api.utils import get_letta_server
|
|
11
12
|
from letta.server.server import SyncServer
|
|
12
13
|
|
|
@@ -213,22 +214,27 @@ def run_tool_from_source(
|
|
|
213
214
|
|
|
214
215
|
|
|
215
216
|
@router.get("/composio/apps", response_model=List[AppModel], operation_id="list_composio_apps")
|
|
216
|
-
def list_composio_apps(server: SyncServer = Depends(get_letta_server)):
|
|
217
|
+
def list_composio_apps(server: SyncServer = Depends(get_letta_server), user_id: Optional[str] = Header(None, alias="user_id")):
|
|
217
218
|
"""
|
|
218
219
|
Get a list of all Composio apps
|
|
219
220
|
"""
|
|
220
|
-
|
|
221
|
+
actor = server.get_user_or_default(user_id=user_id)
|
|
222
|
+
composio_api_key = get_composio_key(server, actor=actor)
|
|
223
|
+
return server.get_composio_apps(api_key=composio_api_key)
|
|
221
224
|
|
|
222
225
|
|
|
223
226
|
@router.get("/composio/apps/{composio_app_name}/actions", response_model=List[ActionModel], operation_id="list_composio_actions_by_app")
|
|
224
227
|
def list_composio_actions_by_app(
|
|
225
228
|
composio_app_name: str,
|
|
226
229
|
server: SyncServer = Depends(get_letta_server),
|
|
230
|
+
user_id: Optional[str] = Header(None, alias="user_id"),
|
|
227
231
|
):
|
|
228
232
|
"""
|
|
229
233
|
Get a list of all Composio actions for a specific app
|
|
230
234
|
"""
|
|
231
|
-
|
|
235
|
+
actor = server.get_user_or_default(user_id=user_id)
|
|
236
|
+
composio_api_key = get_composio_key(server, actor=actor)
|
|
237
|
+
return server.get_composio_actions_from_app_name(composio_app_name=composio_app_name, api_key=composio_api_key)
|
|
232
238
|
|
|
233
239
|
|
|
234
240
|
@router.post("/composio/{composio_action_name}", response_model=Tool, operation_id="add_composio_tool")
|
|
@@ -241,5 +247,21 @@ def add_composio_tool(
|
|
|
241
247
|
Add a new Composio tool by action name (Composio refers to each tool as an `Action`)
|
|
242
248
|
"""
|
|
243
249
|
actor = server.get_user_or_default(user_id=user_id)
|
|
244
|
-
|
|
250
|
+
composio_api_key = get_composio_key(server, actor=actor)
|
|
251
|
+
tool_create = ToolCreate.from_composio(action_name=composio_action_name, api_key=composio_api_key)
|
|
245
252
|
return server.tool_manager.create_or_update_tool(pydantic_tool=Tool(**tool_create.model_dump()), actor=actor)
|
|
253
|
+
|
|
254
|
+
|
|
255
|
+
# TODO: Factor this out to somewhere else
|
|
256
|
+
def get_composio_key(server: SyncServer, actor: User):
|
|
257
|
+
api_keys = server.sandbox_config_manager.list_sandbox_env_vars_by_key(key="COMPOSIO_API_KEY", actor=actor)
|
|
258
|
+
if not api_keys:
|
|
259
|
+
raise HTTPException(
|
|
260
|
+
status_code=400, # Bad Request
|
|
261
|
+
detail=f"No API keys found for Composio. Please add your Composio API Key as an environment variable for your sandbox configuration.",
|
|
262
|
+
)
|
|
263
|
+
|
|
264
|
+
# TODO: Add more protections around this
|
|
265
|
+
# Ideally, not tied to a specific sandbox, but for now we just get the first one
|
|
266
|
+
# Theoretically possible for someone to have different composio api keys per sandbox
|
|
267
|
+
return api_keys[0].value
|
letta/server/rest_api/utils.py
CHANGED
|
@@ -1,6 +1,5 @@
|
|
|
1
1
|
import asyncio
|
|
2
2
|
import json
|
|
3
|
-
import traceback
|
|
4
3
|
import warnings
|
|
5
4
|
from enum import Enum
|
|
6
5
|
from typing import AsyncGenerator, Optional, Union
|
|
@@ -62,6 +61,9 @@ async def sse_async_generator(
|
|
|
62
61
|
raise ValueError(f"Expected LettaUsageStatistics, got {type(usage)}")
|
|
63
62
|
yield sse_formatter({"usage": usage.model_dump()})
|
|
64
63
|
except Exception as e:
|
|
64
|
+
import traceback
|
|
65
|
+
|
|
66
|
+
traceback.print_exc()
|
|
65
67
|
warnings.warn(f"Error getting usage data: {e}")
|
|
66
68
|
yield sse_formatter({"error": "Failed to get usage data"})
|
|
67
69
|
|