letta-nightly 0.6.2.dev20241210030340__py3-none-any.whl → 0.6.2.dev20241211031658__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of letta-nightly might be problematic. Click here for more details.

Files changed (43) hide show
  1. letta/agent.py +32 -43
  2. letta/agent_store/db.py +12 -54
  3. letta/agent_store/storage.py +10 -9
  4. letta/cli/cli.py +1 -0
  5. letta/client/client.py +4 -3
  6. letta/config.py +2 -2
  7. letta/data_sources/connectors.py +4 -3
  8. letta/embeddings.py +29 -9
  9. letta/functions/function_sets/base.py +36 -11
  10. letta/metadata.py +13 -2
  11. letta/o1_agent.py +2 -3
  12. letta/offline_memory_agent.py +2 -1
  13. letta/orm/__init__.py +1 -0
  14. letta/orm/file.py +1 -0
  15. letta/orm/mixins.py +12 -2
  16. letta/orm/organization.py +3 -0
  17. letta/orm/passage.py +72 -0
  18. letta/orm/sqlalchemy_base.py +66 -10
  19. letta/orm/sqlite_functions.py +140 -0
  20. letta/orm/user.py +1 -1
  21. letta/schemas/agent.py +4 -3
  22. letta/schemas/letta_message.py +5 -1
  23. letta/schemas/letta_request.py +3 -3
  24. letta/schemas/passage.py +6 -4
  25. letta/schemas/sandbox_config.py +1 -0
  26. letta/schemas/tool_rule.py +0 -3
  27. letta/server/rest_api/app.py +34 -12
  28. letta/server/rest_api/routers/v1/agents.py +20 -7
  29. letta/server/server.py +76 -52
  30. letta/server/static_files/assets/{index-4848e3d7.js → index-048c9598.js} +1 -1
  31. letta/server/static_files/assets/{index-43ab4d62.css → index-0e31b727.css} +1 -1
  32. letta/server/static_files/index.html +2 -2
  33. letta/services/message_manager.py +3 -0
  34. letta/services/passage_manager.py +225 -0
  35. letta/services/source_manager.py +2 -1
  36. letta/services/tool_execution_sandbox.py +19 -7
  37. letta/settings.py +2 -0
  38. {letta_nightly-0.6.2.dev20241210030340.dist-info → letta_nightly-0.6.2.dev20241211031658.dist-info}/METADATA +10 -15
  39. {letta_nightly-0.6.2.dev20241210030340.dist-info → letta_nightly-0.6.2.dev20241211031658.dist-info}/RECORD +42 -40
  40. letta/agent_store/chroma.py +0 -297
  41. {letta_nightly-0.6.2.dev20241210030340.dist-info → letta_nightly-0.6.2.dev20241211031658.dist-info}/LICENSE +0 -0
  42. {letta_nightly-0.6.2.dev20241210030340.dist-info → letta_nightly-0.6.2.dev20241211031658.dist-info}/WHEEL +0 -0
  43. {letta_nightly-0.6.2.dev20241210030340.dist-info → letta_nightly-0.6.2.dev20241211031658.dist-info}/entry_points.txt +0 -0
letta/orm/organization.py CHANGED
@@ -33,7 +33,10 @@ class Organization(SqlalchemyBase):
33
33
  sandbox_environment_variables: Mapped[List["SandboxEnvironmentVariable"]] = relationship(
34
34
  "SandboxEnvironmentVariable", back_populates="organization", cascade="all, delete-orphan"
35
35
  )
36
+
37
+ # relationships
36
38
  messages: Mapped[List["Message"]] = relationship("Message", back_populates="organization", cascade="all, delete-orphan")
39
+ passages: Mapped[List["Passage"]] = relationship("Passage", back_populates="organization", cascade="all, delete-orphan")
37
40
 
38
41
  # TODO: Map these relationships later when we actually make these models
39
42
  # below is just a suggestion
letta/orm/passage.py ADDED
@@ -0,0 +1,72 @@
1
+ from datetime import datetime
2
+ from typing import List, Optional, TYPE_CHECKING
3
+ from sqlalchemy import Column, String, DateTime, Index, JSON, UniqueConstraint, ForeignKey
4
+ from sqlalchemy.orm import Mapped, mapped_column, relationship
5
+ from sqlalchemy.types import TypeDecorator, BINARY
6
+
7
+ import numpy as np
8
+ import base64
9
+
10
+ from letta.orm.source import EmbeddingConfigColumn
11
+ from letta.orm.sqlalchemy_base import SqlalchemyBase
12
+ from letta.orm.mixins import AgentMixin, FileMixin, OrganizationMixin
13
+ from letta.schemas.passage import Passage as PydanticPassage
14
+
15
+ from letta.config import LettaConfig
16
+ from letta.constants import MAX_EMBEDDING_DIM
17
+ from letta.settings import settings
18
+
19
+ config = LettaConfig()
20
+
21
+ if TYPE_CHECKING:
22
+ from letta.orm.file import File
23
+ from letta.orm.organization import Organization
24
+
25
class CommonVector(TypeDecorator):
    """Common type for representing vectors in SQLite.

    A vector is persisted in a BINARY column as the base64 encoding of its
    raw float32 bytes, and is returned to Python as a 1-D float32 numpy array.
    """

    impl = BINARY
    cache_ok = True

    def load_dialect_impl(self, dialect):
        """Use a plain BINARY column on every dialect."""
        return dialect.type_descriptor(BINARY())

    def process_bind_param(self, value, dialect):
        """Serialize a list / numpy array into base64-encoded float32 bytes.

        None is passed through unchanged so NULL embeddings stay NULL.
        """
        if value is None:
            return value
        # Always coerce to float32: process_result_value reads the bytes back
        # as float32, so a float64/int array written with its native dtype
        # would be decoded as garbage.
        vector = np.asarray(value, dtype=np.float32)
        return base64.b64encode(vector.tobytes())

    def process_result_value(self, value, dialect):
        """Deserialize stored bytes back into a float32 numpy array.

        Empty / NULL values are returned unchanged.
        """
        if not value:
            return value
        # process_bind_param base64-encodes for every dialect, so decoding must
        # be unconditional as well (the original only decoded for sqlite).
        return np.frombuffer(base64.b64decode(value), dtype=np.float32)
46
+
47
+ # TODO: After migration to Passage, will need to manually delete passages where files
48
+ # are deleted on web
49
+ class Passage(SqlalchemyBase, OrganizationMixin, FileMixin):
50
+ """Defines data model for storing Passages"""
51
+ __tablename__ = "passages"
52
+ __table_args__ = {"extend_existing": True}
53
+ __pydantic_model__ = PydanticPassage
54
+
55
+ id: Mapped[str] = mapped_column(primary_key=True, doc="Unique passage identifier")
56
+ text: Mapped[str] = mapped_column(doc="Passage text content")
57
+ source_id: Mapped[Optional[str]] = mapped_column(nullable=True, doc="Source identifier")
58
+ embedding_config: Mapped[dict] = mapped_column(EmbeddingConfigColumn, doc="Embedding configuration")
59
+ metadata_: Mapped[dict] = mapped_column(JSON, doc="Additional metadata")
60
+ created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=datetime.utcnow)
61
+ if settings.letta_pg_uri_no_default:
62
+ from pgvector.sqlalchemy import Vector
63
+ embedding = mapped_column(Vector(MAX_EMBEDDING_DIM))
64
+ else:
65
+ embedding = Column(CommonVector)
66
+
67
+ # Foreign keys
68
+ agent_id: Mapped[Optional[str]] = mapped_column(String, ForeignKey("agents.id"), nullable=True)
69
+
70
+ # Relationships
71
+ organization: Mapped["Organization"] = relationship("Organization", back_populates="passages", lazy="selectin")
72
+ file: Mapped["FileMetadata"] = relationship("FileMetadata", back_populates="passages", lazy="selectin")
@@ -1,13 +1,15 @@
1
1
  from datetime import datetime
2
2
  from enum import Enum
3
3
  from typing import TYPE_CHECKING, List, Literal, Optional, Type
4
+ import sqlite3
4
5
 
5
- from sqlalchemy import String, func, select
6
+ from sqlalchemy import String, desc, func, or_, select
6
7
  from sqlalchemy.exc import DBAPIError
7
8
  from sqlalchemy.orm import Mapped, Session, mapped_column
8
9
 
9
10
  from letta.log import get_logger
10
11
  from letta.orm.base import Base, CommonSqlalchemyMetaMixins
12
+ from letta.orm.sqlite_functions import adapt_array, convert_array, cosine_distance
11
13
  from letta.orm.errors import (
12
14
  ForeignKeyConstraintViolationError,
13
15
  NoResultFound,
@@ -60,14 +62,26 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
60
62
  end_date: Optional[datetime] = None,
61
63
  limit: Optional[int] = 50,
62
64
  query_text: Optional[str] = None,
65
+ query_embedding: Optional[List[float]] = None,
66
+ ascending: bool = True,
63
67
  **kwargs,
64
68
  ) -> List[Type["SqlalchemyBase"]]:
65
- """List records with advanced filtering and pagination options."""
69
+ """
70
+ List records with cursor-based pagination, ordering by created_at.
71
+ Cursor is an ID, but pagination is based on the cursor object's created_at value.
72
+ """
66
73
  if start_date and end_date and start_date > end_date:
67
74
  raise ValueError("start_date must be earlier than or equal to end_date")
68
75
 
69
76
  logger.debug(f"Listing {cls.__name__} with kwarg filters {kwargs}")
70
77
  with db_session as session:
78
+ # If cursor provided, get the reference object
79
+ cursor_obj = None
80
+ if cursor:
81
+ cursor_obj = session.get(cls, cursor)
82
+ if not cursor_obj:
83
+ raise NoResultFound(f"No {cls.__name__} found with id {cursor}")
84
+
71
85
  query = select(cls)
72
86
 
73
87
  # Apply filtering logic
@@ -80,22 +94,64 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
80
94
 
81
95
  # Date range filtering
82
96
  if start_date:
83
- query = query.filter(cls.created_at >= start_date)
97
+ query = query.filter(cls.created_at > start_date)
84
98
  if end_date:
85
- query = query.filter(cls.created_at <= end_date)
86
-
87
- # Cursor-based pagination
88
- if cursor:
89
- query = query.where(cls.id > cursor)
99
+ query = query.filter(cls.created_at < end_date)
100
+
101
+ # Cursor-based pagination using created_at
102
+ # TODO: There is a really nasty race condition issue here with Sqlite
103
+ # TODO: If they have the same created_at timestamp, this query does NOT match for whatever reason
104
+ if cursor_obj:
105
+ if ascending:
106
+ query = query.where(cls.created_at >= cursor_obj.created_at).where(
107
+ or_(cls.created_at > cursor_obj.created_at, cls.id > cursor_obj.id)
108
+ )
109
+ else:
110
+ query = query.where(cls.created_at <= cursor_obj.created_at).where(
111
+ or_(cls.created_at < cursor_obj.created_at, cls.id < cursor_obj.id)
112
+ )
90
113
 
91
114
  # Apply text search
92
115
  if query_text:
116
+ from sqlalchemy import func
93
117
  query = query.filter(func.lower(cls.text).contains(func.lower(query_text)))
94
118
 
119
+ # Apply embedding search (Passages)
120
+ is_ordered = False
121
+ if query_embedding:
122
+ # check if embedding column exists. should only exist for passages
123
+ if not hasattr(cls, "embedding"):
124
+ raise ValueError(f"Class {cls.__name__} does not have an embedding column")
125
+
126
+ from letta.settings import settings
127
+ if settings.letta_pg_uri_no_default:
128
+ # PostgreSQL with pgvector
129
+ from pgvector.sqlalchemy import Vector
130
+ query = query.order_by(cls.embedding.cosine_distance(query_embedding).asc())
131
+ else:
132
+ # SQLite with custom vector type
133
+ from sqlalchemy import func
134
+
135
+ query_embedding_binary = adapt_array(query_embedding)
136
+ query = query.order_by(
137
+ func.cosine_distance(cls.embedding, query_embedding_binary).asc(),
138
+ cls.created_at.asc(),
139
+ cls.id.asc()
140
+ )
141
+ is_ordered = True
142
+
95
143
  # Handle ordering and soft deletes
96
144
  if hasattr(cls, "is_deleted"):
97
145
  query = query.where(cls.is_deleted == False)
98
- query = query.order_by(cls.id).limit(limit)
146
+
147
+ # Apply ordering by created_at
148
+ if not is_ordered:
149
+ if ascending:
150
+ query = query.order_by(cls.created_at, cls.id)
151
+ else:
152
+ query = query.order_by(desc(cls.created_at), desc(cls.id))
153
+
154
+ query = query.limit(limit)
99
155
 
100
156
  return list(session.execute(query).scalars())
101
157
 
@@ -342,4 +398,4 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
342
398
  def to_record(self) -> Type["BaseModel"]:
343
399
  """Deprecated accessor for to_pydantic"""
344
400
  logger.warning("to_record is deprecated, use to_pydantic instead.")
345
- return self.to_pydantic()
401
+ return self.to_pydantic()
@@ -0,0 +1,140 @@
1
+ from typing import Optional, Union
2
+
3
+ import base64
4
+ import numpy as np
5
+ from sqlalchemy import event
6
+ from sqlalchemy.engine import Engine
7
+ import sqlite3
8
+
9
+ from letta.constants import MAX_EMBEDDING_DIM
10
+
11
def adapt_array(arr):
    """
    Convert an embedding into a base64-encoded binary blob for SQLite storage.

    Args:
        arr: A list of numbers, a numpy array, or None.

    Returns:
        sqlite3.Binary wrapping the base64 encoding of the vector's float32
        bytes, or None when ``arr`` is None.

    Raises:
        ValueError: If ``arr`` is neither a list nor a numpy array.
    """
    if arr is None:
        return None

    if not isinstance(arr, (list, np.ndarray)):
        raise ValueError(f"Unsupported type: {type(arr)}")

    # Always serialize as float32. The stored blob is decoded as float32 on the
    # way back out, so writing a float64 or integer array with its native dtype
    # would silently round-trip into a different (garbage) vector.
    vector = np.asarray(arr, dtype=np.float32)

    # Convert to bytes and then base64 encode
    return sqlite3.Binary(base64.b64encode(vector.tobytes()))
27
+
28
def convert_array(text):
    """
    Convert a stored embedding blob back into a numpy array.

    Args:
        text: Base64-encoded float32 bytes (``bytes`` or ``sqlite3.Binary``),
            a list of numbers, a numpy array, or None.

    Returns:
        A numpy array: lists are converted to float32 arrays, numpy arrays are
        returned unchanged, and blobs are base64-decoded and read as float32.
        None is returned for None input or for a blob that cannot be decoded.
    """
    if text is None:
        return None
    if isinstance(text, list):
        return np.array(text, dtype=np.float32)
    if isinstance(text, np.ndarray):
        return text

    # Handle both bytes and sqlite3.Binary
    binary_data = bytes(text) if isinstance(text, sqlite3.Binary) else text

    try:
        # First decode base64, then reinterpret the raw bytes as float32
        return np.frombuffer(base64.b64decode(binary_data), dtype=np.float32)
    except (ValueError, TypeError):
        # Best-effort decode: malformed base64 (binascii.Error is a ValueError),
        # a byte count that is not a multiple of 4, or a non-bytes input all
        # yield None rather than propagating. Anything else is a real bug and
        # is allowed to surface.
        return None
49
+
50
def verify_embedding_dimension(embedding: np.ndarray, expected_dim: int = MAX_EMBEDDING_DIM) -> bool:
    """
    Verifies that an embedding has the expected dimension

    Args:
        embedding: Input embedding array (None is treated as a mismatch)
        expected_dim: Expected embedding dimension (defaults to MAX_EMBEDDING_DIM)

    Returns:
        bool: True if the first axis has exactly ``expected_dim`` entries,
        False otherwise (including None and zero-dimensional input)
    """
    if embedding is None:
        return False
    # np.shape works for arrays and plain sequences alike and returns () for
    # scalars / 0-d arrays, so those report False instead of raising.
    shape = np.shape(embedding)
    return len(shape) > 0 and shape[0] == expected_dim
64
+
65
def validate_and_transform_embedding(
    embedding: Union[bytes, sqlite3.Binary, list, np.ndarray],
    expected_dim: int = MAX_EMBEDDING_DIM,
    dtype: np.dtype = np.float32,
) -> Optional[np.ndarray]:
    """
    Validates and transforms embeddings to ensure correct dimensionality.

    Args:
        embedding: Input embedding in various possible formats
        expected_dim: Expected embedding dimension (defaults to MAX_EMBEDDING_DIM)
        dtype: NumPy dtype applied to list / ndarray input (default float32);
            binary input is decoded by convert_array

    Returns:
        np.ndarray: Validated and transformed embedding, or None if the input is None

    Raises:
        ValueError: If the input type is unsupported, the binary payload cannot
            be decoded, or the dimension doesn't match the expected dimension
    """
    if embedding is None:
        return None

    # Convert to numpy array based on input type
    if isinstance(embedding, (bytes, sqlite3.Binary)):
        vec = convert_array(embedding)
        if vec is None:
            # convert_array signals an undecodable blob with None; surface it as
            # the documented ValueError instead of failing on `None.shape`.
            raise ValueError("Invalid embedding: binary payload could not be decoded")
    elif isinstance(embedding, list):
        vec = np.array(embedding, dtype=dtype)
    elif isinstance(embedding, np.ndarray):
        vec = embedding.astype(dtype)
    else:
        raise ValueError(f"Unsupported embedding type: {type(embedding)}")

    # Validate dimension (a 0-d array has no first axis, so treat it as length 0)
    actual_dim = vec.shape[0] if vec.ndim > 0 else 0
    if actual_dim != expected_dim:
        raise ValueError(
            f"Invalid embedding dimension: got {actual_dim}, expected {expected_dim}"
        )

    return vec
104
+
105
def cosine_distance(embedding1, embedding2, expected_dim=MAX_EMBEDDING_DIM):
    """
    Calculate cosine distance between two embeddings

    Args:
        embedding1: First embedding
        embedding2: Second embedding
        expected_dim: Expected embedding dimension (defaults to MAX_EMBEDDING_DIM)

    Returns:
        float: Cosine distance (1 - cosine similarity), in the range [0, 2].
        The maximum distance (2.0) is returned when either embedding is
        missing, fails validation, or has zero magnitude.
    """
    # Cosine similarity lies in [-1, 1], so 2.0 is the largest possible distance.
    # Unusable embeddings must get the maximum so they sort last in an ascending
    # ORDER BY; returning 0.0 would rank them as perfect matches.
    max_distance = 2.0

    if embedding1 is None or embedding2 is None:
        return max_distance

    try:
        vec1 = validate_and_transform_embedding(embedding1, expected_dim)
        vec2 = validate_and_transform_embedding(embedding2, expected_dim)
    except ValueError:
        return max_distance

    norm_product = np.linalg.norm(vec1) * np.linalg.norm(vec2)
    if norm_product == 0:
        # A zero vector has no direction; dividing would yield NaN.
        return max_distance

    similarity = np.dot(vec1, vec2) / norm_product
    return float(1.0 - similarity)
131
+
132
+ @event.listens_for(Engine, "connect")
133
+ def register_functions(dbapi_connection, connection_record):
134
+ """Register SQLite functions"""
135
+ if isinstance(dbapi_connection, sqlite3.Connection):
136
+ dbapi_connection.create_function("cosine_distance", 2, cosine_distance)
137
+
138
+ # Register adapters and converters for numpy arrays
139
+ sqlite3.register_adapter(np.ndarray, adapt_array)
140
+ sqlite3.register_converter("ARRAY", convert_array)
letta/orm/user.py CHANGED
@@ -20,7 +20,7 @@ class User(SqlalchemyBase, OrganizationMixin):
20
20
 
21
21
  # relationships
22
22
  organization: Mapped["Organization"] = relationship("Organization", back_populates="users")
23
- jobs: Mapped[List["Job"]] = relationship("Job", back_populates="user", doc="the jobs associated with this user.")
23
+ jobs: Mapped[List["Job"]] = relationship("Job", back_populates="user", doc="the jobs associated with this user.", cascade="all, delete-orphan")
24
24
 
25
25
  # TODO: Add this back later potentially
26
26
  # agents: Mapped[List["Agent"]] = relationship(
letta/schemas/agent.py CHANGED
@@ -4,6 +4,7 @@ from typing import Dict, List, Optional
4
4
 
5
5
  from pydantic import BaseModel, Field, field_validator
6
6
 
7
+ from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS
7
8
  from letta.schemas.block import CreateBlock
8
9
  from letta.schemas.embedding_config import EmbeddingConfig
9
10
  from letta.schemas.letta_base import LettaBase
@@ -108,7 +109,7 @@ class CreateAgent(BaseAgent): #
108
109
  # all optional as server can generate defaults
109
110
  name: Optional[str] = Field(None, description="The name of the agent.")
110
111
  message_ids: Optional[List[str]] = Field(None, description="The ids of the messages in the agent's in-context memory.")
111
-
112
+
112
113
  # memory creation
113
114
  memory_blocks: List[CreateBlock] = Field(
114
115
  # [CreateHuman(), CreatePersona()], description="The blocks to create in the agent's in-context memory."
@@ -116,11 +117,11 @@ class CreateAgent(BaseAgent): #
116
117
  description="The blocks to create in the agent's in-context memory.",
117
118
  )
118
119
 
119
- tools: Optional[List[str]] = Field(None, description="The tools used by the agent.")
120
+ tools: List[str] = Field(BASE_TOOLS + BASE_MEMORY_TOOLS, description="The tools used by the agent.")
120
121
  tool_rules: Optional[List[ToolRule]] = Field(None, description="The tool rules governing the agent.")
121
122
  tags: Optional[List[str]] = Field(None, description="The tags associated with the agent.")
122
123
  system: Optional[str] = Field(None, description="The system prompt used by the agent.")
123
- agent_type: Optional[AgentType] = Field(None, description="The type of agent.")
124
+ agent_type: AgentType = Field(AgentType.memgpt_agent, description="The type of agent.")
124
125
  llm_config: Optional[LLMConfig] = Field(None, description="The LLM configuration used by the agent.")
125
126
  embedding_config: Optional[EmbeddingConfig] = Field(None, description="The embedding configuration used by the agent.")
126
127
  # Note: if this is None, then we'll populate with the standard "more human than human" initial message sequence
@@ -1,6 +1,6 @@
1
1
  import json
2
2
  from datetime import datetime, timezone
3
- from typing import Annotated, Literal, Optional, Union
3
+ from typing import Annotated, List, Literal, Optional, Union
4
4
 
5
5
  from pydantic import BaseModel, Field, field_serializer, field_validator
6
6
 
@@ -150,12 +150,16 @@ class FunctionReturn(LettaMessage):
150
150
  id (str): The ID of the message
151
151
  date (datetime): The date the message was created in ISO format
152
152
  function_call_id (str): A unique identifier for the function call that generated this message
153
+ stdout (Optional[List(str)]): Captured stdout (e.g. prints, logs) from the function invocation
154
+ stderr (Optional[List(str)]): Captured stderr from the function invocation
153
155
  """
154
156
 
155
157
  message_type: Literal["function_return"] = "function_return"
156
158
  function_return: str
157
159
  status: Literal["success", "error"]
158
160
  function_call_id: str
161
+ stdout: Optional[List[str]] = None
162
+ stderr: Optional[List[str]] = None
159
163
 
160
164
 
161
165
  # Legacy Letta API had an additional type "assistant_message" and the "function_call" was a formatted string
@@ -1,13 +1,13 @@
1
- from typing import List, Union
1
+ from typing import List
2
2
 
3
3
  from pydantic import BaseModel, Field
4
4
 
5
5
  from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
6
- from letta.schemas.message import Message, MessageCreate
6
+ from letta.schemas.message import MessageCreate
7
7
 
8
8
 
9
9
  class LettaRequest(BaseModel):
10
- messages: Union[List[MessageCreate], List[Message]] = Field(..., description="The messages to be sent to the agent.")
10
+ messages: List[MessageCreate] = Field(..., description="The messages to be sent to the agent.")
11
11
 
12
12
  # Flags to support the use of AssistantMessage message types
13
13
 
letta/schemas/passage.py CHANGED
@@ -5,15 +5,17 @@ from pydantic import Field, field_validator
5
5
 
6
6
  from letta.constants import MAX_EMBEDDING_DIM
7
7
  from letta.schemas.embedding_config import EmbeddingConfig
8
- from letta.schemas.letta_base import LettaBase
8
+ from letta.schemas.letta_base import OrmMetadataBase
9
9
  from letta.utils import get_utc_time
10
10
 
11
11
 
12
- class PassageBase(LettaBase):
13
- __id_prefix__ = "passage"
12
+ class PassageBase(OrmMetadataBase):
13
+ __id_prefix__ = "passage_legacy"
14
+
15
+ is_deleted: bool = Field(False, description="Whether this passage is deleted or not.")
14
16
 
15
17
  # associated user/agent
16
- user_id: Optional[str] = Field(None, description="The unique identifier of the user associated with the passage.")
18
+ organization_id: Optional[str] = Field(None, description="The unique identifier of the user associated with the passage.")
17
19
  agent_id: Optional[str] = Field(None, description="The unique identifier of the agent associated with the passage.")
18
20
 
19
21
  # origin data source
@@ -19,6 +19,7 @@ class SandboxRunResult(BaseModel):
19
19
  func_return: Optional[Any] = Field(None, description="The function return object")
20
20
  agent_state: Optional[AgentState] = Field(None, description="The agent state")
21
21
  stdout: Optional[List[str]] = Field(None, description="Captured stdout (e.g. prints, logs) from the function invocation")
22
+ stderr: Optional[List[str]] = Field(None, description="Captured stderr from the function invocation")
22
23
  sandbox_config_fingerprint: str = Field(None, description="The fingerprint of the config for the sandbox")
23
24
 
24
25
 
@@ -17,7 +17,6 @@ class ChildToolRule(BaseToolRule):
17
17
  A ToolRule represents a tool that can be invoked by the agent.
18
18
  """
19
19
 
20
- # type: str = Field("ToolRule")
21
20
  type: ToolRuleType = ToolRuleType.constrain_child_tools
22
21
  children: List[str] = Field(..., description="The children tools that can be invoked.")
23
22
 
@@ -27,7 +26,6 @@ class InitToolRule(BaseToolRule):
27
26
  Represents the initial tool rule configuration.
28
27
  """
29
28
 
30
- # type: str = Field("InitToolRule")
31
29
  type: ToolRuleType = ToolRuleType.run_first
32
30
 
33
31
 
@@ -36,7 +34,6 @@ class TerminalToolRule(BaseToolRule):
36
34
  Represents a terminal tool rule configuration where if this tool gets called, it must end the agent loop.
37
35
  """
38
36
 
39
- # type: str = Field("TerminalToolRule")
40
37
  type: ToolRuleType = ToolRuleType.exit_loop
41
38
 
42
39
 
@@ -6,7 +6,7 @@ from pathlib import Path
6
6
  from typing import Optional
7
7
 
8
8
  import uvicorn
9
- from fastapi import FastAPI
9
+ from fastapi import FastAPI, Request
10
10
  from fastapi.responses import JSONResponse
11
11
  from starlette.middleware.base import BaseHTTPMiddleware
12
12
  from starlette.middleware.cors import CORSMiddleware
@@ -109,7 +109,13 @@ random_password = os.getenv("LETTA_SERVER_PASSWORD") or generate_password()
109
109
 
110
110
 
111
111
  class CheckPasswordMiddleware(BaseHTTPMiddleware):
112
+
112
113
  async def dispatch(self, request, call_next):
114
+
115
+ # Exclude health check endpoint from password protection
116
+ if request.url.path == "/v1/health/" or request.url.path == "/latest/health/":
117
+ return await call_next(request)
118
+
113
119
  if request.headers.get("X-BARE-PASSWORD") == f"password {random_password}":
114
120
  return await call_next(request)
115
121
 
@@ -136,17 +142,18 @@ def create_application() -> "FastAPI":
136
142
  },
137
143
  )
138
144
 
145
+ debug_mode = "--debug" in sys.argv
139
146
  app = FastAPI(
140
147
  swagger_ui_parameters={"docExpansion": "none"},
141
148
  # openapi_tags=TAGS_METADATA,
142
149
  title="Letta",
143
150
  summary="Create LLM agents with long-term memory and custom tools 📚🦙",
144
151
  version="1.0.0", # TODO wire this up to the version in the package
145
- debug=True,
152
+ debug=debug_mode, # if True, the stack trace will be printed in the response
146
153
  )
147
154
 
148
155
  @app.exception_handler(Exception)
149
- async def generic_error_handler(request, exc):
156
+ async def generic_error_handler(request: Request, exc: Exception):
150
157
  # Log the actual error for debugging
151
158
  log.error(f"Unhandled error: {exc}", exc_info=True)
152
159
 
@@ -166,16 +173,19 @@ def create_application() -> "FastAPI":
166
173
  },
167
174
  )
168
175
 
176
+ @app.exception_handler(ValueError)
177
+ async def value_error_handler(request: Request, exc: ValueError):
178
+ return JSONResponse(status_code=400, content={"detail": str(exc)})
179
+
169
180
  @app.exception_handler(LettaAgentNotFoundError)
170
- async def agent_not_found_handler(request, exc):
181
+ async def agent_not_found_handler(request: Request, exc: LettaAgentNotFoundError):
171
182
  return JSONResponse(status_code=404, content={"detail": "Agent not found"})
172
183
 
173
184
  @app.exception_handler(LettaUserNotFoundError)
174
- async def user_not_found_handler(request, exc):
185
+ async def user_not_found_handler(request: Request, exc: LettaUserNotFoundError):
175
186
  return JSONResponse(status_code=404, content={"detail": "User not found"})
176
187
 
177
188
  settings.cors_origins.append("https://app.letta.com")
178
- print(f"▶ View using ADE at: https://app.letta.com/development-servers/local/dashboard")
179
189
 
180
190
  if (os.getenv("LETTA_SERVER_SECURE") == "true") or "--secure" in sys.argv:
181
191
  print(f"▶ Using secure mode with password: {random_password}")
@@ -254,9 +264,21 @@ def start_server(
254
264
  # Add the handler to the logger
255
265
  server_logger.addHandler(stream_handler)
256
266
 
257
- print(f" Server running at: http://{host or 'localhost'}:{port or REST_DEFAULT_PORT}\n")
258
- uvicorn.run(
259
- app,
260
- host=host or "localhost",
261
- port=port or REST_DEFAULT_PORT,
262
- )
267
+ if (os.getenv("LOCAL_HTTPS") == "true") or "--localhttps" in sys.argv:
268
+ uvicorn.run(
269
+ app,
270
+ host=host or "localhost",
271
+ port=port or REST_DEFAULT_PORT,
272
+ ssl_keyfile="certs/localhost-key.pem",
273
+ ssl_certfile="certs/localhost.pem",
274
+ )
275
+ print(f"▶ Server running at: https://{host or 'localhost'}:{port or REST_DEFAULT_PORT}\n")
276
+ else:
277
+ uvicorn.run(
278
+ app,
279
+ host=host or "localhost",
280
+ port=port or REST_DEFAULT_PORT,
281
+ )
282
+ print(f"▶ Server running at: http://{host or 'localhost'}:{port or REST_DEFAULT_PORT}\n")
283
+
284
+ print(f"▶ View using ADE at: https://app.letta.com/development-servers/local/dashboard")
@@ -14,6 +14,7 @@ from fastapi import (
14
14
  status,
15
15
  )
16
16
  from fastapi.responses import JSONResponse, StreamingResponse
17
+ from pydantic import Field
17
18
 
18
19
  from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
19
20
  from letta.schemas.agent import AgentState, CreateAgent, UpdateAgentState
@@ -87,9 +88,18 @@ def get_agent_context_window(
87
88
  return server.get_agent_context_window(user_id=actor.id, agent_id=agent_id)
88
89
 
89
90
 
91
+ class CreateAgentRequest(CreateAgent):
92
+ """
93
+ CreateAgent model specifically for POST request body, excluding user_id which comes from headers
94
+ """
95
+
96
+ # Override the user_id field to exclude it from the request body validation
97
+ user_id: Optional[str] = Field(None, exclude=True)
98
+
99
+
90
100
  @router.post("/", response_model=AgentState, operation_id="create_agent")
91
101
  def create_agent(
92
- agent: CreateAgent = Body(...),
102
+ agent: CreateAgentRequest = Body(...),
93
103
  server: "SyncServer" = Depends(get_letta_server),
94
104
  user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
95
105
  ):
@@ -165,7 +175,7 @@ def get_agent_state(
165
175
  return server.get_agent_state(user_id=actor.id, agent_id=agent_id)
166
176
 
167
177
 
168
- @router.delete("/{agent_id}", response_model=None, operation_id="delete_agent")
178
+ @router.delete("/{agent_id}", response_model=AgentState, operation_id="delete_agent")
169
179
  def delete_agent(
170
180
  agent_id: str,
171
181
  server: "SyncServer" = Depends(get_letta_server),
@@ -176,7 +186,12 @@ def delete_agent(
176
186
  """
177
187
  actor = server.get_user_or_default(user_id=user_id)
178
188
 
179
- return server.delete_agent(user_id=actor.id, agent_id=agent_id)
189
+ agent = server.get_agent(agent_id)
190
+ if not agent:
191
+ raise HTTPException(status_code=404, detail=f"Agent agent_id={agent_id} not found.")
192
+
193
+ server.delete_agent(user_id=actor.id, agent_id=agent_id)
194
+ return agent
180
195
 
181
196
 
182
197
  @router.get("/{agent_id}/sources", response_model=List[Source], operation_id="get_agent_sources")
@@ -354,8 +369,7 @@ def get_agent_archival_memory(
354
369
  return server.get_agent_archival_cursor(
355
370
  user_id=actor.id,
356
371
  agent_id=agent_id,
357
- after=after,
358
- before=before,
372
+ cursor=after, # TODO: deleting before, after. is this expected?
359
373
  limit=limit,
360
374
  )
361
375
 
@@ -420,7 +434,7 @@ def get_agent_messages(
420
434
  return server.get_agent_recall_cursor(
421
435
  user_id=actor.id,
422
436
  agent_id=agent_id,
423
- cursor=before,
437
+ before=before,
424
438
  limit=limit,
425
439
  reverse=True,
426
440
  return_message_object=msg_object,
@@ -496,7 +510,6 @@ async def send_message_streaming(
496
510
  This endpoint accepts a message from a user and processes it through the agent.
497
511
  It will stream the steps of the response always, and stream the tokens if 'stream_tokens' is set to True.
498
512
  """
499
- request.stream_tokens = False
500
513
 
501
514
  actor = server.get_user_or_default(user_id=user_id)
502
515
  result = await send_message_to_agent(