letta-nightly 0.6.1.dev20241206104246__py3-none-any.whl → 0.6.1.dev20241208104134__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of letta-nightly might be problematic. Click here for more details.

Files changed (49)
  1. letta/agent.py +68 -76
  2. letta/agent_store/db.py +1 -77
  3. letta/agent_store/storage.py +0 -5
  4. letta/cli/cli.py +1 -4
  5. letta/client/client.py +11 -14
  6. letta/constants.py +1 -0
  7. letta/functions/function_sets/base.py +33 -5
  8. letta/functions/helpers.py +3 -3
  9. letta/llm_api/openai.py +0 -1
  10. letta/local_llm/llm_chat_completion_wrappers/chatml.py +13 -1
  11. letta/main.py +2 -2
  12. letta/memory.py +4 -82
  13. letta/metadata.py +0 -35
  14. letta/o1_agent.py +7 -2
  15. letta/offline_memory_agent.py +6 -0
  16. letta/orm/__init__.py +2 -0
  17. letta/orm/file.py +1 -1
  18. letta/orm/message.py +64 -0
  19. letta/orm/mixins.py +16 -0
  20. letta/orm/organization.py +1 -0
  21. letta/orm/sqlalchemy_base.py +118 -26
  22. letta/schemas/letta_base.py +7 -6
  23. letta/schemas/message.py +6 -12
  24. letta/schemas/tool.py +18 -11
  25. letta/server/rest_api/app.py +2 -3
  26. letta/server/rest_api/routers/v1/agents.py +7 -6
  27. letta/server/rest_api/routers/v1/blocks.py +2 -2
  28. letta/server/rest_api/routers/v1/tools.py +26 -4
  29. letta/server/rest_api/utils.py +3 -1
  30. letta/server/server.py +67 -62
  31. letta/server/static_files/assets/index-43ab4d62.css +1 -0
  32. letta/server/static_files/assets/index-4848e3d7.js +40 -0
  33. letta/server/static_files/index.html +2 -2
  34. letta/services/block_manager.py +1 -1
  35. letta/services/message_manager.py +194 -0
  36. letta/services/organization_manager.py +6 -9
  37. letta/services/sandbox_config_manager.py +16 -1
  38. letta/services/source_manager.py +1 -1
  39. letta/services/tool_manager.py +2 -4
  40. letta/services/user_manager.py +1 -1
  41. {letta_nightly-0.6.1.dev20241206104246.dist-info → letta_nightly-0.6.1.dev20241208104134.dist-info}/METADATA +2 -2
  42. {letta_nightly-0.6.1.dev20241206104246.dist-info → letta_nightly-0.6.1.dev20241208104134.dist-info}/RECORD +45 -45
  43. letta/agent_store/lancedb.py +0 -177
  44. letta/persistence_manager.py +0 -149
  45. letta/server/static_files/assets/index-1b5d1a41.js +0 -271
  46. letta/server/static_files/assets/index-56a3f8c6.css +0 -1
  47. {letta_nightly-0.6.1.dev20241206104246.dist-info → letta_nightly-0.6.1.dev20241208104134.dist-info}/LICENSE +0 -0
  48. {letta_nightly-0.6.1.dev20241206104246.dist-info → letta_nightly-0.6.1.dev20241208104134.dist-info}/WHEEL +0 -0
  49. {letta_nightly-0.6.1.dev20241206104246.dist-info → letta_nightly-0.6.1.dev20241208104134.dist-info}/entry_points.txt +0 -0
letta/orm/sqlalchemy_base.py CHANGED
@@ -1,8 +1,10 @@
1
+ from datetime import datetime
2
+ from enum import Enum
1
3
  from typing import TYPE_CHECKING, List, Literal, Optional, Type
2
4
 
3
- from sqlalchemy import String, select
5
+ from sqlalchemy import String, func, select
4
6
  from sqlalchemy.exc import DBAPIError
5
- from sqlalchemy.orm import Mapped, mapped_column
7
+ from sqlalchemy.orm import Mapped, Session, mapped_column
6
8
 
7
9
  from letta.log import get_logger
8
10
  from letta.orm.base import Base, CommonSqlalchemyMetaMixins
@@ -20,6 +22,11 @@ if TYPE_CHECKING:
20
22
  logger = get_logger(__name__)
21
23
 
22
24
 
25
+ class AccessType(str, Enum):
26
+ ORGANIZATION = "organization"
27
+ USER = "user"
28
+
29
+
23
30
  class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
24
31
  __abstract__ = True
25
32
 
@@ -28,46 +35,68 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
28
35
  id: Mapped[str] = mapped_column(String, primary_key=True)
29
36
 
30
37
  @classmethod
31
- def list(
32
- cls, *, db_session: "Session", cursor: Optional[str] = None, limit: Optional[int] = 50, **kwargs
33
- ) -> List[Type["SqlalchemyBase"]]:
34
- """
35
- List records with optional cursor (for pagination), limit, and automatic filtering.
38
+ def get(cls, *, db_session: Session, id: str) -> Optional["SqlalchemyBase"]:
39
+ """Get a record by ID.
36
40
 
37
41
  Args:
38
- db_session: The database session to use.
39
- cursor: Optional ID to start pagination from.
40
- limit: Maximum number of records to return.
41
- **kwargs: Filters passed as equality conditions or iterable for IN filtering.
42
+ db_session: SQLAlchemy session
43
+ id: Record ID to retrieve
42
44
 
43
45
  Returns:
44
- A list of model instances matching the filters.
46
+ Optional[SqlalchemyBase]: The record if found, None otherwise
45
47
  """
46
- logger.debug(f"Listing {cls.__name__} with filters {kwargs}")
48
+ try:
49
+ return db_session.query(cls).filter(cls.id == id).first()
50
+ except DBAPIError:
51
+ return None
52
+
53
+ @classmethod
54
+ def list(
55
+ cls,
56
+ *,
57
+ db_session: "Session",
58
+ cursor: Optional[str] = None,
59
+ start_date: Optional[datetime] = None,
60
+ end_date: Optional[datetime] = None,
61
+ limit: Optional[int] = 50,
62
+ query_text: Optional[str] = None,
63
+ **kwargs,
64
+ ) -> List[Type["SqlalchemyBase"]]:
65
+ """List records with advanced filtering and pagination options."""
66
+ if start_date and end_date and start_date > end_date:
67
+ raise ValueError("start_date must be earlier than or equal to end_date")
68
+
69
+ logger.debug(f"Listing {cls.__name__} with kwarg filters {kwargs}")
47
70
  with db_session as session:
48
- # Start with a base query
49
71
  query = select(cls)
50
72
 
51
73
  # Apply filtering logic
52
74
  for key, value in kwargs.items():
53
75
  column = getattr(cls, key)
54
- if isinstance(value, (list, tuple, set)): # Check for iterables
76
+ if isinstance(value, (list, tuple, set)):
55
77
  query = query.where(column.in_(value))
56
- else: # Single value for equality filtering
78
+ else:
57
79
  query = query.where(column == value)
58
80
 
59
- # Apply cursor for pagination
81
+ # Date range filtering
82
+ if start_date:
83
+ query = query.filter(cls.created_at >= start_date)
84
+ if end_date:
85
+ query = query.filter(cls.created_at <= end_date)
86
+
87
+ # Cursor-based pagination
60
88
  if cursor:
61
89
  query = query.where(cls.id > cursor)
62
90
 
63
- # Handle soft deletes if the class has the 'is_deleted' attribute
91
+ # Apply text search
92
+ if query_text:
93
+ query = query.filter(func.lower(cls.text).contains(func.lower(query_text)))
94
+
95
+ # Handle ordering and soft deletes
64
96
  if hasattr(cls, "is_deleted"):
65
97
  query = query.where(cls.is_deleted == False)
66
-
67
- # Add ordering and limit
68
98
  query = query.order_by(cls.id).limit(limit)
69
99
 
70
- # Execute the query and return results as model instances
71
100
  return list(session.execute(query).scalars())
72
101
 
73
102
  @classmethod
@@ -77,6 +106,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
77
106
  identifier: Optional[str] = None,
78
107
  actor: Optional["User"] = None,
79
108
  access: Optional[List[Literal["read", "write", "admin"]]] = ["read"],
109
+ access_type: AccessType = AccessType.ORGANIZATION,
80
110
  **kwargs,
81
111
  ) -> Type["SqlalchemyBase"]:
82
112
  """The primary accessor for an ORM record.
@@ -108,7 +138,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
108
138
  query_conditions.append(", ".join(f"{key}='{value}'" for key, value in kwargs.items()))
109
139
 
110
140
  if actor:
111
- query = cls.apply_access_predicate(query, actor, access)
141
+ query = cls.apply_access_predicate(query, actor, access, access_type)
112
142
  query_conditions.append(f"access level in {access} for actor='{actor}'")
113
143
 
114
144
  if hasattr(cls, "is_deleted"):
@@ -170,12 +200,66 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
170
200
  session.refresh(self)
171
201
  return self
172
202
 
203
+ @classmethod
204
+ def size(
205
+ cls,
206
+ *,
207
+ db_session: "Session",
208
+ actor: Optional["User"] = None,
209
+ access: Optional[List[Literal["read", "write", "admin"]]] = ["read"],
210
+ access_type: AccessType = AccessType.ORGANIZATION,
211
+ **kwargs,
212
+ ) -> int:
213
+ """
214
+ Get the count of rows that match the provided filters.
215
+
216
+ Args:
217
+ db_session: SQLAlchemy session
218
+ **kwargs: Filters to apply to the query (e.g., column_name=value)
219
+
220
+ Returns:
221
+ int: The count of rows that match the filters
222
+
223
+ Raises:
224
+ DBAPIError: If a database error occurs
225
+ """
226
+ logger.debug(f"Calculating size for {cls.__name__} with filters {kwargs}")
227
+
228
+ with db_session as session:
229
+ query = select(func.count()).select_from(cls)
230
+
231
+ if actor:
232
+ query = cls.apply_access_predicate(query, actor, access, access_type)
233
+
234
+ # Apply filtering logic based on kwargs
235
+ for key, value in kwargs.items():
236
+ if value:
237
+ column = getattr(cls, key, None)
238
+ if not column:
239
+ raise AttributeError(f"{cls.__name__} has no attribute '{key}'")
240
+ if isinstance(value, (list, tuple, set)): # Check for iterables
241
+ query = query.where(column.in_(value))
242
+ else: # Single value for equality filtering
243
+ query = query.where(column == value)
244
+
245
+ # Handle soft deletes if the class has the 'is_deleted' attribute
246
+ if hasattr(cls, "is_deleted"):
247
+ query = query.where(cls.is_deleted == False)
248
+
249
+ try:
250
+ count = session.execute(query).scalar()
251
+ return count if count else 0
252
+ except DBAPIError as e:
253
+ logger.exception(f"Failed to calculate size for {cls.__name__}")
254
+ raise e
255
+
173
256
  @classmethod
174
257
  def apply_access_predicate(
175
258
  cls,
176
259
  query: "Select",
177
260
  actor: "User",
178
261
  access: List[Literal["read", "write", "admin"]],
262
+ access_type: AccessType = AccessType.ORGANIZATION,
179
263
  ) -> "Select":
180
264
  """applies a WHERE clause restricting results to the given actor and access level
181
265
  Args:
@@ -189,10 +273,18 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
189
273
  the sqlalchemy select statement restricted to the given access.
190
274
  """
191
275
  del access # entrypoint for row-level permissions. Defaults to "same org as the actor, all permissions" at the moment
192
- org_id = getattr(actor, "organization_id", None)
193
- if not org_id:
194
- raise ValueError(f"object {actor} has no organization accessor")
195
- return query.where(cls.organization_id == org_id, cls.is_deleted == False)
276
+ if access_type == AccessType.ORGANIZATION:
277
+ org_id = getattr(actor, "organization_id", None)
278
+ if not org_id:
279
+ raise ValueError(f"object {actor} has no organization accessor")
280
+ return query.where(cls.organization_id == org_id, cls.is_deleted == False)
281
+ elif access_type == AccessType.USER:
282
+ user_id = getattr(actor, "id", None)
283
+ if not user_id:
284
+ raise ValueError(f"object {actor} has no user accessor")
285
+ return query.where(cls.user_id == user_id, cls.is_deleted == False)
286
+ else:
287
+ raise ValueError(f"unknown access_type: {access_type}")
196
288
 
197
289
  @classmethod
198
290
  def _handle_dbapi_error(cls, e: DBAPIError):
letta/schemas/letta_base.py CHANGED
@@ -33,18 +33,19 @@ class LettaBase(BaseModel):
33
33
  def generate_id_field(cls, prefix: Optional[str] = None) -> "Field":
34
34
  prefix = prefix or cls.__id_prefix__
35
35
 
36
- # TODO: generate ID from regex pattern?
37
- def _generate_id() -> str:
38
- return f"{prefix}-{uuid.uuid4()}"
39
-
40
36
  return Field(
41
37
  ...,
42
38
  description=cls._id_description(prefix),
43
39
  pattern=cls._id_regex_pattern(prefix),
44
40
  examples=[cls._id_example(prefix)],
45
- default_factory=_generate_id,
41
+ default_factory=cls._generate_id,
46
42
  )
47
43
 
44
+ @classmethod
45
+ def _generate_id(cls, prefix: Optional[str] = None) -> str:
46
+ prefix = prefix or cls.__id_prefix__
47
+ return f"{prefix}-{uuid.uuid4()}"
48
+
48
49
  # def _generate_id(self) -> str:
49
50
  # return f"{self.__id_prefix__}-{uuid.uuid4()}"
50
51
 
@@ -78,7 +79,7 @@ class LettaBase(BaseModel):
78
79
  """
79
80
  _ = values # for SCA
80
81
  if isinstance(v, UUID):
81
- logger.warning(f"Bare UUIDs are deprecated, please use the full prefixed id ({cls.__id_prefix__})!")
82
+ logger.debug(f"Bare UUIDs are deprecated, please use the full prefixed id ({cls.__id_prefix__})!")
82
83
  return f"{cls.__id_prefix__}-{v}"
83
84
  return v
84
85
 
letta/schemas/message.py CHANGED
@@ -13,7 +13,7 @@ from letta.constants import (
13
13
  )
14
14
  from letta.local_llm.constants import INNER_THOUGHTS_KWARG
15
15
  from letta.schemas.enums import MessageRole
16
- from letta.schemas.letta_base import LettaBase
16
+ from letta.schemas.letta_base import OrmMetadataBase
17
17
  from letta.schemas.letta_message import (
18
18
  AssistantMessage,
19
19
  FunctionCall,
@@ -50,7 +50,7 @@ def add_inner_thoughts_to_tool_call(
50
50
  raise e
51
51
 
52
52
 
53
- class BaseMessage(LettaBase):
53
+ class BaseMessage(OrmMetadataBase):
54
54
  __id_prefix__ = "message"
55
55
 
56
56
 
@@ -66,10 +66,9 @@ class MessageCreate(BaseMessage):
66
66
  name: Optional[str] = Field(None, description="The name of the participant.")
67
67
 
68
68
 
69
- class UpdateMessage(BaseMessage):
69
+ class MessageUpdate(BaseMessage):
70
70
  """Request to update a message"""
71
71
 
72
- id: str = Field(..., description="The id of the message.")
73
72
  role: Optional[MessageRole] = Field(None, description="The role of the participant.")
74
73
  text: Optional[str] = Field(None, description="The text of the message.")
75
74
  # NOTE: probably doesn't make sense to allow remapping user_id or agent_id (vs creating a new message)
@@ -105,13 +104,14 @@ class Message(BaseMessage):
105
104
  id: str = BaseMessage.generate_id_field()
106
105
  role: MessageRole = Field(..., description="The role of the participant.")
107
106
  text: Optional[str] = Field(None, description="The text of the message.")
108
- user_id: Optional[str] = Field(None, description="The unique identifier of the user.")
107
+ organization_id: Optional[str] = Field(None, description="The unique identifier of the organization.")
109
108
  agent_id: Optional[str] = Field(None, description="The unique identifier of the agent.")
110
109
  model: Optional[str] = Field(None, description="The model used to make the function call.")
111
110
  name: Optional[str] = Field(None, description="The name of the participant.")
112
- created_at: datetime = Field(default_factory=get_utc_time, description="The time the message was created.")
113
111
  tool_calls: Optional[List[ToolCall]] = Field(None, description="The list of tool calls requested.")
114
112
  tool_call_id: Optional[str] = Field(None, description="The id of the tool call.")
113
+ # This overrides the optional base orm schema, created_at MUST exist on all messages objects
114
+ created_at: datetime = Field(default_factory=get_utc_time, description="The timestamp when the object was created.")
115
115
 
116
116
  @field_validator("role")
117
117
  @classmethod
@@ -281,7 +281,6 @@ class Message(BaseMessage):
281
281
  )
282
282
  if id is not None:
283
283
  return Message(
284
- user_id=user_id,
285
284
  agent_id=agent_id,
286
285
  model=model,
287
286
  # standard fields expected in an OpenAI ChatCompletion message object
@@ -295,7 +294,6 @@ class Message(BaseMessage):
295
294
  )
296
295
  else:
297
296
  return Message(
298
- user_id=user_id,
299
297
  agent_id=agent_id,
300
298
  model=model,
301
299
  # standard fields expected in an OpenAI ChatCompletion message object
@@ -328,7 +326,6 @@ class Message(BaseMessage):
328
326
 
329
327
  if id is not None:
330
328
  return Message(
331
- user_id=user_id,
332
329
  agent_id=agent_id,
333
330
  model=model,
334
331
  # standard fields expected in an OpenAI ChatCompletion message object
@@ -342,7 +339,6 @@ class Message(BaseMessage):
342
339
  )
343
340
  else:
344
341
  return Message(
345
- user_id=user_id,
346
342
  agent_id=agent_id,
347
343
  model=model,
348
344
  # standard fields expected in an OpenAI ChatCompletion message object
@@ -375,7 +371,6 @@ class Message(BaseMessage):
375
371
  # If we're going from tool-call style
376
372
  if id is not None:
377
373
  return Message(
378
- user_id=user_id,
379
374
  agent_id=agent_id,
380
375
  model=model,
381
376
  # standard fields expected in an OpenAI ChatCompletion message object
@@ -389,7 +384,6 @@ class Message(BaseMessage):
389
384
  )
390
385
  else:
391
386
  return Message(
392
- user_id=user_id,
393
387
  agent_id=agent_id,
394
388
  model=model,
395
389
  # standard fields expected in an OpenAI ChatCompletion message object
letta/schemas/tool.py CHANGED
@@ -93,7 +93,7 @@ class ToolCreate(LettaBase):
93
93
  )
94
94
 
95
95
  @classmethod
96
- def from_composio(cls, action: "ActionType") -> "ToolCreate":
96
+ def from_composio(cls, action_name: str, api_key: Optional[str] = None) -> "ToolCreate":
97
97
  """
98
98
  Class method to create an instance of Letta-compatible Composio Tool.
99
99
  Check https://docs.composio.dev/introduction/intro/overview to look at options for from_composio
@@ -101,15 +101,20 @@ class ToolCreate(LettaBase):
101
101
  This function will error if we find more than one tool, or 0 tools.
102
102
 
103
103
  Args:
104
- action ActionType: A action name to filter tools by.
104
+ action_name str: A action name to filter tools by.
105
105
  Returns:
106
106
  Tool: A Letta Tool initialized with attributes derived from the Composio tool.
107
107
  """
108
108
  from composio import LogLevel
109
109
  from composio_langchain import ComposioToolSet
110
110
 
111
- composio_toolset = ComposioToolSet(logging_level=LogLevel.ERROR)
112
- composio_tools = composio_toolset.get_tools(actions=[action])
111
+ if api_key:
112
+ # Pass in an external API key
113
+ composio_toolset = ComposioToolSet(logging_level=LogLevel.ERROR, api_key=api_key)
114
+ else:
115
+ # Use environmental variable
116
+ composio_toolset = ComposioToolSet(logging_level=LogLevel.ERROR)
117
+ composio_tools = composio_toolset.get_tools(actions=[action_name])
113
118
 
114
119
  assert len(composio_tools) > 0, "User supplied parameters do not match any Composio tools"
115
120
  assert len(composio_tools) == 1, f"User supplied parameters match too many Composio tools; {len(composio_tools)} > 1"
@@ -119,7 +124,7 @@ class ToolCreate(LettaBase):
119
124
  description = composio_tool.description
120
125
  source_type = "python"
121
126
  tags = ["composio"]
122
- wrapper_func_name, wrapper_function_str = generate_composio_tool_wrapper(action)
127
+ wrapper_func_name, wrapper_function_str = generate_composio_tool_wrapper(action_name)
123
128
  json_schema = generate_schema_from_args_schema_v2(composio_tool.args_schema, name=wrapper_func_name, description=description)
124
129
 
125
130
  return cls(
@@ -177,14 +182,16 @@ class ToolCreate(LettaBase):
177
182
 
178
183
  @classmethod
179
184
  def load_default_composio_tools(cls) -> List["ToolCreate"]:
180
- from composio_langchain import Action
185
+ pass
181
186
 
182
- calculator = ToolCreate.from_composio(action=Action.MATHEMATICAL_CALCULATOR)
183
- serp_news = ToolCreate.from_composio(action=Action.SERPAPI_NEWS_SEARCH)
184
- serp_google_search = ToolCreate.from_composio(action=Action.SERPAPI_SEARCH)
185
- serp_google_maps = ToolCreate.from_composio(action=Action.SERPAPI_GOOGLE_MAPS_SEARCH)
187
+ # TODO: Disable composio tools for now
188
+ # TODO: Naming is causing issues
189
+ # calculator = ToolCreate.from_composio(action_name=Action.MATHEMATICAL_CALCULATOR.name)
190
+ # serp_news = ToolCreate.from_composio(action_name=Action.SERPAPI_NEWS_SEARCH.name)
191
+ # serp_google_search = ToolCreate.from_composio(action_name=Action.SERPAPI_SEARCH.name)
192
+ # serp_google_maps = ToolCreate.from_composio(action_name=Action.SERPAPI_GOOGLE_MAPS_SEARCH.name)
186
193
 
187
- return [calculator, serp_news, serp_google_search, serp_google_maps]
194
+ return []
188
195
 
189
196
 
190
197
  class ToolUpdate(LettaBase):
letta/server/rest_api/app.py CHANGED
@@ -144,9 +144,8 @@ def create_application() -> "FastAPI":
144
144
  debug=True,
145
145
  )
146
146
 
147
- if (os.getenv("LETTA_SERVER_ADE") == "true") or "--ade" in sys.argv:
148
- settings.cors_origins.append("https://app.letta.com")
149
- print(f"▶ View using ADE at: https://app.letta.com/development-servers/local/dashboard")
147
+ settings.cors_origins.append("https://app.letta.com")
148
+ print(f"▶ View using ADE at: https://app.letta.com/development-servers/local/dashboard")
150
149
 
151
150
  if (os.getenv("LETTA_SERVER_SECURE") == "true") or "--secure" in sys.argv:
152
151
  print(f"▶ Using secure mode with password: {random_password}")
letta/server/rest_api/routers/v1/agents.py CHANGED
@@ -28,7 +28,7 @@ from letta.schemas.memory import (
28
28
  Memory,
29
29
  RecallMemorySummary,
30
30
  )
31
- from letta.schemas.message import Message, MessageCreate, UpdateMessage
31
+ from letta.schemas.message import Message, MessageCreate, MessageUpdate
32
32
  from letta.schemas.passage import Passage
33
33
  from letta.schemas.source import Source
34
34
  from letta.schemas.tool import Tool
@@ -409,7 +409,7 @@ def get_agent_messages(
409
409
  return server.get_agent_recall_cursor(
410
410
  user_id=actor.id,
411
411
  agent_id=agent_id,
412
- before=before,
412
+ cursor=before,
413
413
  limit=limit,
414
414
  reverse=True,
415
415
  return_message_object=msg_object,
@@ -422,14 +422,13 @@ def get_agent_messages(
422
422
  def update_message(
423
423
  agent_id: str,
424
424
  message_id: str,
425
- request: UpdateMessage = Body(...),
425
+ request: MessageUpdate = Body(...),
426
426
  server: "SyncServer" = Depends(get_letta_server),
427
427
  ):
428
428
  """
429
429
  Update the details of a message associated with an agent.
430
430
  """
431
- assert request.id == message_id, f"Message ID mismatch: {request.id} != {message_id}"
432
- return server.update_agent_message(agent_id=agent_id, request=request)
431
+ return server.update_agent_message(agent_id=agent_id, message_id=message_id, request=request)
433
432
 
434
433
 
435
434
  @router.post(
@@ -465,7 +464,7 @@ async def send_message(
465
464
  @router.post(
466
465
  "/{agent_id}/messages/stream",
467
466
  response_model=None,
468
- operation_id="create_agent_message",
467
+ operation_id="create_agent_message_stream",
469
468
  responses={
470
469
  200: {
471
470
  "description": "Successful response",
@@ -486,6 +485,8 @@ async def send_message_streaming(
486
485
  This endpoint accepts a message from a user and processes it through the agent.
487
486
  It will stream the steps of the response always, and stream the tokens if 'stream_tokens' is set to True.
488
487
  """
488
+ request.stream_tokens = False
489
+
489
490
  actor = server.get_user_or_default(user_id=user_id)
490
491
  result = await send_message_to_agent(
491
492
  server=server,
letta/server/rest_api/routers/v1/blocks.py CHANGED
@@ -76,7 +76,7 @@ def get_block(
76
76
  raise HTTPException(status_code=404, detail="Block not found")
77
77
 
78
78
 
79
- @router.patch("/{block_id}/attach", response_model=Block, operation_id="update_agent_memory_block")
79
+ @router.patch("/{block_id}/attach", response_model=Block, operation_id="link_agent_memory_block")
80
80
  def link_agent_memory_block(
81
81
  block_id: str,
82
82
  agent_id: str = Query(..., description="The unique identifier of the agent to attach the source to."),
@@ -96,7 +96,7 @@ def link_agent_memory_block(
96
96
  return block
97
97
 
98
98
 
99
- @router.patch("/{block_id}/detach", response_model=Memory, operation_id="update_agent_memory_block")
99
+ @router.patch("/{block_id}/detach", response_model=Memory, operation_id="unlink_agent_memory_block")
100
100
  def unlink_agent_memory_block(
101
101
  block_id: str,
102
102
  agent_id: str = Query(..., description="The unique identifier of the agent to attach the source to."),
letta/server/rest_api/routers/v1/tools.py CHANGED
@@ -7,6 +7,7 @@ from letta.errors import LettaToolCreateError
7
7
  from letta.orm.errors import UniqueConstraintViolationError
8
8
  from letta.schemas.letta_message import FunctionReturn
9
9
  from letta.schemas.tool import Tool, ToolCreate, ToolRunFromSource, ToolUpdate
10
+ from letta.schemas.user import User
10
11
  from letta.server.rest_api.utils import get_letta_server
11
12
  from letta.server.server import SyncServer
12
13
 
@@ -213,22 +214,27 @@ def run_tool_from_source(
213
214
 
214
215
 
215
216
  @router.get("/composio/apps", response_model=List[AppModel], operation_id="list_composio_apps")
216
- def list_composio_apps(server: SyncServer = Depends(get_letta_server)):
217
+ def list_composio_apps(server: SyncServer = Depends(get_letta_server), user_id: Optional[str] = Header(None, alias="user_id")):
217
218
  """
218
219
  Get a list of all Composio apps
219
220
  """
220
- return server.get_composio_apps()
221
+ actor = server.get_user_or_default(user_id=user_id)
222
+ composio_api_key = get_composio_key(server, actor=actor)
223
+ return server.get_composio_apps(api_key=composio_api_key)
221
224
 
222
225
 
223
226
  @router.get("/composio/apps/{composio_app_name}/actions", response_model=List[ActionModel], operation_id="list_composio_actions_by_app")
224
227
  def list_composio_actions_by_app(
225
228
  composio_app_name: str,
226
229
  server: SyncServer = Depends(get_letta_server),
230
+ user_id: Optional[str] = Header(None, alias="user_id"),
227
231
  ):
228
232
  """
229
233
  Get a list of all Composio actions for a specific app
230
234
  """
231
- return server.get_composio_actions_from_app_name(composio_app_name=composio_app_name)
235
+ actor = server.get_user_or_default(user_id=user_id)
236
+ composio_api_key = get_composio_key(server, actor=actor)
237
+ return server.get_composio_actions_from_app_name(composio_app_name=composio_app_name, api_key=composio_api_key)
232
238
 
233
239
 
234
240
  @router.post("/composio/{composio_action_name}", response_model=Tool, operation_id="add_composio_tool")
@@ -241,5 +247,21 @@ def add_composio_tool(
241
247
  Add a new Composio tool by action name (Composio refers to each tool as an `Action`)
242
248
  """
243
249
  actor = server.get_user_or_default(user_id=user_id)
244
- tool_create = ToolCreate.from_composio(action=composio_action_name)
250
+ composio_api_key = get_composio_key(server, actor=actor)
251
+ tool_create = ToolCreate.from_composio(action_name=composio_action_name, api_key=composio_api_key)
245
252
  return server.tool_manager.create_or_update_tool(pydantic_tool=Tool(**tool_create.model_dump()), actor=actor)
253
+
254
+
255
+ # TODO: Factor this out to somewhere else
256
+ def get_composio_key(server: SyncServer, actor: User):
257
+ api_keys = server.sandbox_config_manager.list_sandbox_env_vars_by_key(key="COMPOSIO_API_KEY", actor=actor)
258
+ if not api_keys:
259
+ raise HTTPException(
260
+ status_code=400, # Bad Request
261
+ detail=f"No API keys found for Composio. Please add your Composio API Key as an environment variable for your sandbox configuration.",
262
+ )
263
+
264
+ # TODO: Add more protections around this
265
+ # Ideally, not tied to a specific sandbox, but for now we just get the first one
266
+ # Theoretically possible for someone to have different composio api keys per sandbox
267
+ return api_keys[0].value
letta/server/rest_api/utils.py CHANGED
@@ -1,6 +1,5 @@
1
1
  import asyncio
2
2
  import json
3
- import traceback
4
3
  import warnings
5
4
  from enum import Enum
6
5
  from typing import AsyncGenerator, Optional, Union
@@ -62,6 +61,9 @@ async def sse_async_generator(
62
61
  raise ValueError(f"Expected LettaUsageStatistics, got {type(usage)}")
63
62
  yield sse_formatter({"usage": usage.model_dump()})
64
63
  except Exception as e:
64
+ import traceback
65
+
66
+ traceback.print_exc()
65
67
  warnings.warn(f"Error getting usage data: {e}")
66
68
  yield sse_formatter({"error": "Failed to get usage data"})
67
69