letta-nightly 0.6.2.dev20241210030340__py3-none-any.whl → 0.6.2.dev20241211031658__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of letta-nightly might be problematic. Click here for more details.
- letta/agent.py +32 -43
- letta/agent_store/db.py +12 -54
- letta/agent_store/storage.py +10 -9
- letta/cli/cli.py +1 -0
- letta/client/client.py +4 -3
- letta/config.py +2 -2
- letta/data_sources/connectors.py +4 -3
- letta/embeddings.py +29 -9
- letta/functions/function_sets/base.py +36 -11
- letta/metadata.py +13 -2
- letta/o1_agent.py +2 -3
- letta/offline_memory_agent.py +2 -1
- letta/orm/__init__.py +1 -0
- letta/orm/file.py +1 -0
- letta/orm/mixins.py +12 -2
- letta/orm/organization.py +3 -0
- letta/orm/passage.py +72 -0
- letta/orm/sqlalchemy_base.py +66 -10
- letta/orm/sqlite_functions.py +140 -0
- letta/orm/user.py +1 -1
- letta/schemas/agent.py +4 -3
- letta/schemas/letta_message.py +5 -1
- letta/schemas/letta_request.py +3 -3
- letta/schemas/passage.py +6 -4
- letta/schemas/sandbox_config.py +1 -0
- letta/schemas/tool_rule.py +0 -3
- letta/server/rest_api/app.py +34 -12
- letta/server/rest_api/routers/v1/agents.py +20 -7
- letta/server/server.py +76 -52
- letta/server/static_files/assets/{index-4848e3d7.js → index-048c9598.js} +1 -1
- letta/server/static_files/assets/{index-43ab4d62.css → index-0e31b727.css} +1 -1
- letta/server/static_files/index.html +2 -2
- letta/services/message_manager.py +3 -0
- letta/services/passage_manager.py +225 -0
- letta/services/source_manager.py +2 -1
- letta/services/tool_execution_sandbox.py +19 -7
- letta/settings.py +2 -0
- {letta_nightly-0.6.2.dev20241210030340.dist-info → letta_nightly-0.6.2.dev20241211031658.dist-info}/METADATA +10 -15
- {letta_nightly-0.6.2.dev20241210030340.dist-info → letta_nightly-0.6.2.dev20241211031658.dist-info}/RECORD +42 -40
- letta/agent_store/chroma.py +0 -297
- {letta_nightly-0.6.2.dev20241210030340.dist-info → letta_nightly-0.6.2.dev20241211031658.dist-info}/LICENSE +0 -0
- {letta_nightly-0.6.2.dev20241210030340.dist-info → letta_nightly-0.6.2.dev20241211031658.dist-info}/WHEEL +0 -0
- {letta_nightly-0.6.2.dev20241210030340.dist-info → letta_nightly-0.6.2.dev20241211031658.dist-info}/entry_points.txt +0 -0
letta/orm/organization.py
CHANGED
|
@@ -33,7 +33,10 @@ class Organization(SqlalchemyBase):
|
|
|
33
33
|
sandbox_environment_variables: Mapped[List["SandboxEnvironmentVariable"]] = relationship(
|
|
34
34
|
"SandboxEnvironmentVariable", back_populates="organization", cascade="all, delete-orphan"
|
|
35
35
|
)
|
|
36
|
+
|
|
37
|
+
# relationships
|
|
36
38
|
messages: Mapped[List["Message"]] = relationship("Message", back_populates="organization", cascade="all, delete-orphan")
|
|
39
|
+
passages: Mapped[List["Passage"]] = relationship("Passage", back_populates="organization", cascade="all, delete-orphan")
|
|
37
40
|
|
|
38
41
|
# TODO: Map these relationships later when we actually make these models
|
|
39
42
|
# below is just a suggestion
|
letta/orm/passage.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
1
|
+
from datetime import datetime
|
|
2
|
+
from typing import List, Optional, TYPE_CHECKING
|
|
3
|
+
from sqlalchemy import Column, String, DateTime, Index, JSON, UniqueConstraint, ForeignKey
|
|
4
|
+
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
|
5
|
+
from sqlalchemy.types import TypeDecorator, BINARY
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
import base64
|
|
9
|
+
|
|
10
|
+
from letta.orm.source import EmbeddingConfigColumn
|
|
11
|
+
from letta.orm.sqlalchemy_base import SqlalchemyBase
|
|
12
|
+
from letta.orm.mixins import AgentMixin, FileMixin, OrganizationMixin
|
|
13
|
+
from letta.schemas.passage import Passage as PydanticPassage
|
|
14
|
+
|
|
15
|
+
from letta.config import LettaConfig
|
|
16
|
+
from letta.constants import MAX_EMBEDDING_DIM
|
|
17
|
+
from letta.settings import settings
|
|
18
|
+
|
|
19
|
+
# NOTE(review): `config` is not referenced by any code visible in this module, yet
# LettaConfig() is constructed at import time — confirm it is still needed here.
config = LettaConfig()
|
|
20
|
+
|
|
21
|
+
if TYPE_CHECKING:
|
|
22
|
+
from letta.orm.file import File
|
|
23
|
+
from letta.orm.organization import Organization
|
|
24
|
+
|
|
25
|
+
class CommonVector(TypeDecorator):
    """Common type for representing vectors in SQLite.

    Vectors are stored as base64-encoded float32 bytes in a BINARY column.
    ``process_bind_param`` and ``process_result_value`` are exact inverses of
    each other, independent of the dialect in use.
    """

    impl = BINARY
    cache_ok = True

    def load_dialect_impl(self, dialect):
        return dialect.type_descriptor(BINARY())

    def process_bind_param(self, value, dialect):
        """Encode a list/ndarray of numbers as base64 float32 bytes; ``None`` passes through."""
        if value is None:
            return value
        # Always cast to float32: the read path decodes with dtype=float32, so an
        # ndarray of another dtype (e.g. float64) would otherwise be written in its
        # own width and read back as a corrupted, wrongly-sized vector.
        value = np.asarray(value, dtype=np.float32)
        return base64.b64encode(value.tobytes())

    def process_result_value(self, value, dialect):
        """Decode stored base64 float32 bytes back into a float32 ndarray; falsy values pass through."""
        if not value:
            return value
        # The bind path base64-encodes unconditionally, so decode unconditionally
        # as well; decoding only for SQLite left other dialects building arrays
        # out of the raw base64 text.
        return np.frombuffer(base64.b64decode(value), dtype=np.float32)
|
|
46
|
+
|
|
47
|
+
# TODO: After migration to Passage, will need to manually delete passages where files
# are deleted on web
class Passage(SqlalchemyBase, OrganizationMixin, FileMixin):
    """Defines data model for storing Passages

    Rows live in the ``passages`` table. The ``embedding`` column type is chosen
    when this class body executes: a pgvector ``Vector(MAX_EMBEDDING_DIM)`` column
    when ``settings.letta_pg_uri_no_default`` is truthy, otherwise the
    ``CommonVector`` (base64-encoded bytes in a BINARY column) type.
    """
    __tablename__ = "passages"
    # extend_existing lets this mapping redefine a "passages" Table that is already
    # registered in the MetaData instead of raising an error.
    __table_args__ = {"extend_existing": True}
    __pydantic_model__ = PydanticPassage

    id: Mapped[str] = mapped_column(primary_key=True, doc="Unique passage identifier")
    text: Mapped[str] = mapped_column(doc="Passage text content")
    # NOTE(review): plain nullable column with no ForeignKey — confirm that is intended.
    source_id: Mapped[Optional[str]] = mapped_column(nullable=True, doc="Source identifier")
    embedding_config: Mapped[dict] = mapped_column(EmbeddingConfigColumn, doc="Embedding configuration")
    metadata_: Mapped[dict] = mapped_column(JSON, doc="Additional metadata")
    # NOTE(review): datetime.utcnow produces a naive datetime even though the column
    # is declared timezone=True (and utcnow is deprecated since Python 3.12); consider
    # a timezone-aware default once confirmed consistent with the other tables.
    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=datetime.utcnow)
    if settings.letta_pg_uri_no_default:
        # Postgres configured: native pgvector column. The import is presumably local
        # so pgvector is only needed when a Postgres URI is set.
        # NOTE: an import inside a class body also binds `Vector` as a class attribute.
        from pgvector.sqlalchemy import Vector
        embedding = mapped_column(Vector(MAX_EMBEDDING_DIM))
    else:
        # No Postgres URI (presumably SQLite): store vectors through CommonVector.
        embedding = Column(CommonVector)

    # Foreign keys
    agent_id: Mapped[Optional[str]] = mapped_column(String, ForeignKey("agents.id"), nullable=True)

    # Relationships
    organization: Mapped["Organization"] = relationship("Organization", back_populates="passages", lazy="selectin")
    # NOTE(review): "FileMetadata" is not among this module's TYPE_CHECKING imports
    # (only File and Organization are); the relationship target string is resolved
    # through the SQLAlchemy class registry.
    file: Mapped["FileMetadata"] = relationship("FileMetadata", back_populates="passages", lazy="selectin")
|
letta/orm/sqlalchemy_base.py
CHANGED
|
@@ -1,13 +1,15 @@
|
|
|
1
1
|
from datetime import datetime
|
|
2
2
|
from enum import Enum
|
|
3
3
|
from typing import TYPE_CHECKING, List, Literal, Optional, Type
|
|
4
|
+
import sqlite3
|
|
4
5
|
|
|
5
|
-
from sqlalchemy import String, func, select
|
|
6
|
+
from sqlalchemy import String, desc, func, or_, select
|
|
6
7
|
from sqlalchemy.exc import DBAPIError
|
|
7
8
|
from sqlalchemy.orm import Mapped, Session, mapped_column
|
|
8
9
|
|
|
9
10
|
from letta.log import get_logger
|
|
10
11
|
from letta.orm.base import Base, CommonSqlalchemyMetaMixins
|
|
12
|
+
from letta.orm.sqlite_functions import adapt_array, convert_array, cosine_distance
|
|
11
13
|
from letta.orm.errors import (
|
|
12
14
|
ForeignKeyConstraintViolationError,
|
|
13
15
|
NoResultFound,
|
|
@@ -60,14 +62,26 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
|
|
60
62
|
end_date: Optional[datetime] = None,
|
|
61
63
|
limit: Optional[int] = 50,
|
|
62
64
|
query_text: Optional[str] = None,
|
|
65
|
+
query_embedding: Optional[List[float]] = None,
|
|
66
|
+
ascending: bool = True,
|
|
63
67
|
**kwargs,
|
|
64
68
|
) -> List[Type["SqlalchemyBase"]]:
|
|
65
|
-
"""
|
|
69
|
+
"""
|
|
70
|
+
List records with cursor-based pagination, ordering by created_at.
|
|
71
|
+
Cursor is an ID, but pagination is based on the cursor object's created_at value.
|
|
72
|
+
"""
|
|
66
73
|
if start_date and end_date and start_date > end_date:
|
|
67
74
|
raise ValueError("start_date must be earlier than or equal to end_date")
|
|
68
75
|
|
|
69
76
|
logger.debug(f"Listing {cls.__name__} with kwarg filters {kwargs}")
|
|
70
77
|
with db_session as session:
|
|
78
|
+
# If cursor provided, get the reference object
|
|
79
|
+
cursor_obj = None
|
|
80
|
+
if cursor:
|
|
81
|
+
cursor_obj = session.get(cls, cursor)
|
|
82
|
+
if not cursor_obj:
|
|
83
|
+
raise NoResultFound(f"No {cls.__name__} found with id {cursor}")
|
|
84
|
+
|
|
71
85
|
query = select(cls)
|
|
72
86
|
|
|
73
87
|
# Apply filtering logic
|
|
@@ -80,22 +94,64 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
|
|
80
94
|
|
|
81
95
|
# Date range filtering
|
|
82
96
|
if start_date:
|
|
83
|
-
query = query.filter(cls.created_at
|
|
97
|
+
query = query.filter(cls.created_at > start_date)
|
|
84
98
|
if end_date:
|
|
85
|
-
query = query.filter(cls.created_at
|
|
86
|
-
|
|
87
|
-
# Cursor-based pagination
|
|
88
|
-
|
|
89
|
-
|
|
99
|
+
query = query.filter(cls.created_at < end_date)
|
|
100
|
+
|
|
101
|
+
# Cursor-based pagination using created_at
|
|
102
|
+
# TODO: There is a really nasty race condition issue here with Sqlite
|
|
103
|
+
# TODO: If they have the same created_at timestamp, this query does NOT match for whatever reason
|
|
104
|
+
if cursor_obj:
|
|
105
|
+
if ascending:
|
|
106
|
+
query = query.where(cls.created_at >= cursor_obj.created_at).where(
|
|
107
|
+
or_(cls.created_at > cursor_obj.created_at, cls.id > cursor_obj.id)
|
|
108
|
+
)
|
|
109
|
+
else:
|
|
110
|
+
query = query.where(cls.created_at <= cursor_obj.created_at).where(
|
|
111
|
+
or_(cls.created_at < cursor_obj.created_at, cls.id < cursor_obj.id)
|
|
112
|
+
)
|
|
90
113
|
|
|
91
114
|
# Apply text search
|
|
92
115
|
if query_text:
|
|
116
|
+
from sqlalchemy import func
|
|
93
117
|
query = query.filter(func.lower(cls.text).contains(func.lower(query_text)))
|
|
94
118
|
|
|
119
|
+
# Apply embedding search (Passages)
|
|
120
|
+
is_ordered = False
|
|
121
|
+
if query_embedding:
|
|
122
|
+
# check if embedding column exists. should only exist for passages
|
|
123
|
+
if not hasattr(cls, "embedding"):
|
|
124
|
+
raise ValueError(f"Class {cls.__name__} does not have an embedding column")
|
|
125
|
+
|
|
126
|
+
from letta.settings import settings
|
|
127
|
+
if settings.letta_pg_uri_no_default:
|
|
128
|
+
# PostgreSQL with pgvector
|
|
129
|
+
from pgvector.sqlalchemy import Vector
|
|
130
|
+
query = query.order_by(cls.embedding.cosine_distance(query_embedding).asc())
|
|
131
|
+
else:
|
|
132
|
+
# SQLite with custom vector type
|
|
133
|
+
from sqlalchemy import func
|
|
134
|
+
|
|
135
|
+
query_embedding_binary = adapt_array(query_embedding)
|
|
136
|
+
query = query.order_by(
|
|
137
|
+
func.cosine_distance(cls.embedding, query_embedding_binary).asc(),
|
|
138
|
+
cls.created_at.asc(),
|
|
139
|
+
cls.id.asc()
|
|
140
|
+
)
|
|
141
|
+
is_ordered = True
|
|
142
|
+
|
|
95
143
|
# Handle ordering and soft deletes
|
|
96
144
|
if hasattr(cls, "is_deleted"):
|
|
97
145
|
query = query.where(cls.is_deleted == False)
|
|
98
|
-
|
|
146
|
+
|
|
147
|
+
# Apply ordering by created_at
|
|
148
|
+
if not is_ordered:
|
|
149
|
+
if ascending:
|
|
150
|
+
query = query.order_by(cls.created_at, cls.id)
|
|
151
|
+
else:
|
|
152
|
+
query = query.order_by(desc(cls.created_at), desc(cls.id))
|
|
153
|
+
|
|
154
|
+
query = query.limit(limit)
|
|
99
155
|
|
|
100
156
|
return list(session.execute(query).scalars())
|
|
101
157
|
|
|
@@ -342,4 +398,4 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
|
|
342
398
|
def to_record(self) -> Type["BaseModel"]:
|
|
343
399
|
"""Deprecated accessor for to_pydantic"""
|
|
344
400
|
logger.warning("to_record is deprecated, use to_pydantic instead.")
|
|
345
|
-
return self.to_pydantic()
|
|
401
|
+
return self.to_pydantic()
|
|
@@ -0,0 +1,140 @@
|
|
|
1
|
+
from typing import Optional, Union
|
|
2
|
+
|
|
3
|
+
import base64
|
|
4
|
+
import numpy as np
|
|
5
|
+
from sqlalchemy import event
|
|
6
|
+
from sqlalchemy.engine import Engine
|
|
7
|
+
import sqlite3
|
|
8
|
+
|
|
9
|
+
from letta.constants import MAX_EMBEDDING_DIM
|
|
10
|
+
|
|
11
|
+
def adapt_array(arr):
    """
    Converts a numpy array (or list of numbers) to base64-encoded binary for SQLite storage.

    Values are always serialized as float32, the dtype ``convert_array`` decodes
    with. Without this cast, an ndarray of another dtype (e.g. float64) would be
    serialized in its own width and decode to a corrupted vector.

    Args:
        arr: A list of numbers, a numpy array, or None.

    Returns:
        sqlite3.Binary wrapping the base64 encoding of the float32 bytes,
        or None if ``arr`` is None.

    Raises:
        ValueError: If ``arr`` is neither a list nor a numpy array.
    """
    if arr is None:
        return None

    if not isinstance(arr, (list, np.ndarray)):
        raise ValueError(f"Unsupported type: {type(arr)}")

    # Normalize to float32 so the byte layout matches what convert_array expects
    # (np.asarray is a no-op for arrays that are already float32).
    arr = np.asarray(arr, dtype=np.float32)

    # Convert to bytes and then base64 encode
    base64_data = base64.b64encode(arr.tobytes())
    return sqlite3.Binary(base64_data)
|
|
27
|
+
|
|
28
|
+
def convert_array(text):
    """
    Converts base64-encoded binary (as produced by ``adapt_array``) back to a numpy array.

    Args:
        text: bytes / sqlite3.Binary holding base64-encoded float32 data. A list
            is converted to a float32 array, and an ndarray is returned unchanged.

    Returns:
        A numpy array, or None when ``text`` is None or cannot be decoded.
    """
    if text is None:
        return None
    if isinstance(text, list):
        return np.array(text, dtype=np.float32)
    if isinstance(text, np.ndarray):
        return text

    # Handle both bytes and sqlite3.Binary
    binary_data = bytes(text) if isinstance(text, sqlite3.Binary) else text

    try:
        # First decode base64, then reinterpret the raw bytes as float32 values
        decoded_data = base64.b64decode(binary_data)
        return np.frombuffer(decoded_data, dtype=np.float32)
    except (ValueError, TypeError):
        # binascii.Error (malformed base64) subclasses ValueError, as does the error
        # np.frombuffer raises when the byte count is not a multiple of 4. TypeError
        # covers inputs that are not bytes-like. Treat these as "undecodable" without
        # masking unrelated failures the way a blanket `except Exception` would.
        return None
|
|
49
|
+
|
|
50
|
+
def verify_embedding_dimension(embedding: np.ndarray, expected_dim: int = MAX_EMBEDDING_DIM) -> bool:
    """
    Verifies that an embedding is a 1-D vector with the expected dimension.

    Args:
        embedding: Input embedding array
        expected_dim: Expected embedding dimension (defaults to MAX_EMBEDDING_DIM)

    Returns:
        bool: True if ``embedding`` is 1-D and its length equals ``expected_dim``,
        False otherwise (including when ``embedding`` is None).
    """
    if embedding is None:
        return False
    # A 0-d array has no shape[0] (previously an IndexError), and a 2-D array's
    # shape[0] is a row count rather than an embedding dimension. Neither is a
    # valid single embedding vector.
    if embedding.ndim != 1:
        return False
    return embedding.shape[0] == expected_dim
|
|
64
|
+
|
|
65
|
+
def validate_and_transform_embedding(
    embedding: Union[bytes, sqlite3.Binary, list, np.ndarray],
    expected_dim: int = MAX_EMBEDDING_DIM,
    dtype: np.dtype = np.float32,
) -> Optional[np.ndarray]:
    """
    Validates and transforms embeddings to ensure correct dimensionality.

    Args:
        embedding: Input embedding as base64-encoded bytes / sqlite3.Binary,
            a list of numbers, or a numpy array
        expected_dim: Expected embedding dimension (defaults to MAX_EMBEDDING_DIM)
        dtype: NumPy dtype for the embedding (default float32)

    Returns:
        np.ndarray: Validated 1-D embedding of ``dtype``, or None if ``embedding`` is None

    Raises:
        ValueError: If the bytes cannot be decoded, the type is unsupported,
            the result is not 1-D, or its dimension doesn't match ``expected_dim``
    """
    if embedding is None:
        return None

    # Convert to numpy array based on input type
    if isinstance(embedding, (bytes, sqlite3.Binary)):
        vec = convert_array(embedding)
        if vec is None:
            # convert_array signals undecodable data with None. Surface it as the
            # documented ValueError instead of failing below with AttributeError.
            raise ValueError("Could not decode embedding bytes")
        # Honor the requested dtype on this path too (no copy when it already matches).
        vec = vec.astype(dtype, copy=False)
    elif isinstance(embedding, list):
        vec = np.array(embedding, dtype=dtype)
    elif isinstance(embedding, np.ndarray):
        vec = embedding.astype(dtype)
    else:
        raise ValueError(f"Unsupported embedding type: {type(embedding)}")

    # Only a 1-D vector has a meaningful dimension; this also prevents an
    # IndexError on 0-d arrays when reading shape[0] below.
    if vec.ndim != 1:
        raise ValueError(f"Invalid embedding shape: expected a 1-D vector, got shape {vec.shape}")

    # Validate dimension
    if vec.shape[0] != expected_dim:
        raise ValueError(
            f"Invalid embedding dimension: got {vec.shape[0]}, expected {expected_dim}"
        )

    return vec
|
|
104
|
+
|
|
105
|
+
def cosine_distance(embedding1, embedding2, expected_dim=MAX_EMBEDDING_DIM):
    """
    Calculate cosine distance between two embeddings.

    This function is registered as the SQLite ``cosine_distance`` SQL function,
    and queries order by its result ascending. Any pair that cannot be compared
    (missing, invalid/wrong-dimension, or zero-length vector) therefore gets the
    maximum cosine distance (2.0), so it sorts last rather than first.

    Args:
        embedding1: First embedding
        embedding2: Second embedding
        expected_dim: Expected embedding dimension (defaults to MAX_EMBEDDING_DIM)

    Returns:
        float: Cosine distance, nominally in [0.0, 2.0]
    """
    # Cosine similarity lies in [-1, 1], so cosine distance lies in [0, 2].
    max_distance = 2.0

    if embedding1 is None or embedding2 is None:
        return max_distance  # Maximum distance if either embedding is None

    try:
        vec1 = validate_and_transform_embedding(embedding1, expected_dim)
        vec2 = validate_and_transform_embedding(embedding2, expected_dim)
    except ValueError:
        return max_distance

    norm_product = np.linalg.norm(vec1) * np.linalg.norm(vec2)
    if norm_product == 0:
        # Cosine similarity is undefined for a zero vector. Avoid returning NaN
        # (and numpy's divide-by-zero warning), which SQLite cannot order sensibly.
        return max_distance

    similarity = np.dot(vec1, vec2) / norm_product
    return float(1.0 - similarity)
|
|
131
|
+
|
|
132
|
+
@event.listens_for(Engine, "connect")
def register_functions(dbapi_connection, connection_record):
    """Expose ``cosine_distance`` as a two-argument SQL function on new SQLite connections.

    Runs on the "connect" event of every SQLAlchemy Engine; connections that are
    not ``sqlite3.Connection`` instances are left untouched.
    """
    if not isinstance(dbapi_connection, sqlite3.Connection):
        return
    dbapi_connection.create_function("cosine_distance", 2, cosine_distance)
|
|
137
|
+
|
|
138
|
+
# Module-level sqlite3 registration for numpy arrays: the converter rebuilds arrays
# from values of columns declared as "ARRAY" (sqlite3 applies converters only on
# connections opened with detect_types), and the adapter serializes ndarray
# parameters through adapt_array.
sqlite3.register_converter("ARRAY", convert_array)
sqlite3.register_adapter(np.ndarray, adapt_array)
|
letta/orm/user.py
CHANGED
|
@@ -20,7 +20,7 @@ class User(SqlalchemyBase, OrganizationMixin):
|
|
|
20
20
|
|
|
21
21
|
# relationships
|
|
22
22
|
organization: Mapped["Organization"] = relationship("Organization", back_populates="users")
|
|
23
|
-
jobs: Mapped[List["Job"]] = relationship("Job", back_populates="user", doc="the jobs associated with this user.")
|
|
23
|
+
jobs: Mapped[List["Job"]] = relationship("Job", back_populates="user", doc="the jobs associated with this user.", cascade="all, delete-orphan")
|
|
24
24
|
|
|
25
25
|
# TODO: Add this back later potentially
|
|
26
26
|
# agents: Mapped[List["Agent"]] = relationship(
|
letta/schemas/agent.py
CHANGED
|
@@ -4,6 +4,7 @@ from typing import Dict, List, Optional
|
|
|
4
4
|
|
|
5
5
|
from pydantic import BaseModel, Field, field_validator
|
|
6
6
|
|
|
7
|
+
from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS
|
|
7
8
|
from letta.schemas.block import CreateBlock
|
|
8
9
|
from letta.schemas.embedding_config import EmbeddingConfig
|
|
9
10
|
from letta.schemas.letta_base import LettaBase
|
|
@@ -108,7 +109,7 @@ class CreateAgent(BaseAgent): #
|
|
|
108
109
|
# all optional as server can generate defaults
|
|
109
110
|
name: Optional[str] = Field(None, description="The name of the agent.")
|
|
110
111
|
message_ids: Optional[List[str]] = Field(None, description="The ids of the messages in the agent's in-context memory.")
|
|
111
|
-
|
|
112
|
+
|
|
112
113
|
# memory creation
|
|
113
114
|
memory_blocks: List[CreateBlock] = Field(
|
|
114
115
|
# [CreateHuman(), CreatePersona()], description="The blocks to create in the agent's in-context memory."
|
|
@@ -116,11 +117,11 @@ class CreateAgent(BaseAgent): #
|
|
|
116
117
|
description="The blocks to create in the agent's in-context memory.",
|
|
117
118
|
)
|
|
118
119
|
|
|
119
|
-
tools:
|
|
120
|
+
tools: List[str] = Field(BASE_TOOLS + BASE_MEMORY_TOOLS, description="The tools used by the agent.")
|
|
120
121
|
tool_rules: Optional[List[ToolRule]] = Field(None, description="The tool rules governing the agent.")
|
|
121
122
|
tags: Optional[List[str]] = Field(None, description="The tags associated with the agent.")
|
|
122
123
|
system: Optional[str] = Field(None, description="The system prompt used by the agent.")
|
|
123
|
-
agent_type:
|
|
124
|
+
agent_type: AgentType = Field(AgentType.memgpt_agent, description="The type of agent.")
|
|
124
125
|
llm_config: Optional[LLMConfig] = Field(None, description="The LLM configuration used by the agent.")
|
|
125
126
|
embedding_config: Optional[EmbeddingConfig] = Field(None, description="The embedding configuration used by the agent.")
|
|
126
127
|
# Note: if this is None, then we'll populate with the standard "more human than human" initial message sequence
|
letta/schemas/letta_message.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
import json
|
|
2
2
|
from datetime import datetime, timezone
|
|
3
|
-
from typing import Annotated, Literal, Optional, Union
|
|
3
|
+
from typing import Annotated, List, Literal, Optional, Union
|
|
4
4
|
|
|
5
5
|
from pydantic import BaseModel, Field, field_serializer, field_validator
|
|
6
6
|
|
|
@@ -150,12 +150,16 @@ class FunctionReturn(LettaMessage):
|
|
|
150
150
|
id (str): The ID of the message
|
|
151
151
|
date (datetime): The date the message was created in ISO format
|
|
152
152
|
function_call_id (str): A unique identifier for the function call that generated this message
|
|
153
|
+
stdout (Optional[List(str)]): Captured stdout (e.g. prints, logs) from the function invocation
|
|
154
|
+
stderr (Optional[List(str)]): Captured stderr from the function invocation
|
|
153
155
|
"""
|
|
154
156
|
|
|
155
157
|
message_type: Literal["function_return"] = "function_return"
|
|
156
158
|
function_return: str
|
|
157
159
|
status: Literal["success", "error"]
|
|
158
160
|
function_call_id: str
|
|
161
|
+
stdout: Optional[List[str]] = None
|
|
162
|
+
stderr: Optional[List[str]] = None
|
|
159
163
|
|
|
160
164
|
|
|
161
165
|
# Legacy Letta API had an additional type "assistant_message" and the "function_call" was a formatted string
|
letta/schemas/letta_request.py
CHANGED
|
@@ -1,13 +1,13 @@
|
|
|
1
|
-
from typing import List
|
|
1
|
+
from typing import List
|
|
2
2
|
|
|
3
3
|
from pydantic import BaseModel, Field
|
|
4
4
|
|
|
5
5
|
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
|
|
6
|
-
from letta.schemas.message import
|
|
6
|
+
from letta.schemas.message import MessageCreate
|
|
7
7
|
|
|
8
8
|
|
|
9
9
|
class LettaRequest(BaseModel):
|
|
10
|
-
messages:
|
|
10
|
+
messages: List[MessageCreate] = Field(..., description="The messages to be sent to the agent.")
|
|
11
11
|
|
|
12
12
|
# Flags to support the use of AssistantMessage message types
|
|
13
13
|
|
letta/schemas/passage.py
CHANGED
|
@@ -5,15 +5,17 @@ from pydantic import Field, field_validator
|
|
|
5
5
|
|
|
6
6
|
from letta.constants import MAX_EMBEDDING_DIM
|
|
7
7
|
from letta.schemas.embedding_config import EmbeddingConfig
|
|
8
|
-
from letta.schemas.letta_base import
|
|
8
|
+
from letta.schemas.letta_base import OrmMetadataBase
|
|
9
9
|
from letta.utils import get_utc_time
|
|
10
10
|
|
|
11
11
|
|
|
12
|
-
class PassageBase(
|
|
13
|
-
__id_prefix__ = "
|
|
12
|
+
class PassageBase(OrmMetadataBase):
|
|
13
|
+
__id_prefix__ = "passage_legacy"
|
|
14
|
+
|
|
15
|
+
is_deleted: bool = Field(False, description="Whether this passage is deleted or not.")
|
|
14
16
|
|
|
15
17
|
# associated user/agent
|
|
16
|
-
|
|
18
|
+
organization_id: Optional[str] = Field(None, description="The unique identifier of the user associated with the passage.")
|
|
17
19
|
agent_id: Optional[str] = Field(None, description="The unique identifier of the agent associated with the passage.")
|
|
18
20
|
|
|
19
21
|
# origin data source
|
letta/schemas/sandbox_config.py
CHANGED
|
@@ -19,6 +19,7 @@ class SandboxRunResult(BaseModel):
|
|
|
19
19
|
func_return: Optional[Any] = Field(None, description="The function return object")
|
|
20
20
|
agent_state: Optional[AgentState] = Field(None, description="The agent state")
|
|
21
21
|
stdout: Optional[List[str]] = Field(None, description="Captured stdout (e.g. prints, logs) from the function invocation")
|
|
22
|
+
stderr: Optional[List[str]] = Field(None, description="Captured stderr from the function invocation")
|
|
22
23
|
sandbox_config_fingerprint: str = Field(None, description="The fingerprint of the config for the sandbox")
|
|
23
24
|
|
|
24
25
|
|
letta/schemas/tool_rule.py
CHANGED
|
@@ -17,7 +17,6 @@ class ChildToolRule(BaseToolRule):
|
|
|
17
17
|
A ToolRule represents a tool that can be invoked by the agent.
|
|
18
18
|
"""
|
|
19
19
|
|
|
20
|
-
# type: str = Field("ToolRule")
|
|
21
20
|
type: ToolRuleType = ToolRuleType.constrain_child_tools
|
|
22
21
|
children: List[str] = Field(..., description="The children tools that can be invoked.")
|
|
23
22
|
|
|
@@ -27,7 +26,6 @@ class InitToolRule(BaseToolRule):
|
|
|
27
26
|
Represents the initial tool rule configuration.
|
|
28
27
|
"""
|
|
29
28
|
|
|
30
|
-
# type: str = Field("InitToolRule")
|
|
31
29
|
type: ToolRuleType = ToolRuleType.run_first
|
|
32
30
|
|
|
33
31
|
|
|
@@ -36,7 +34,6 @@ class TerminalToolRule(BaseToolRule):
|
|
|
36
34
|
Represents a terminal tool rule configuration where if this tool gets called, it must end the agent loop.
|
|
37
35
|
"""
|
|
38
36
|
|
|
39
|
-
# type: str = Field("TerminalToolRule")
|
|
40
37
|
type: ToolRuleType = ToolRuleType.exit_loop
|
|
41
38
|
|
|
42
39
|
|
letta/server/rest_api/app.py
CHANGED
|
@@ -6,7 +6,7 @@ from pathlib import Path
|
|
|
6
6
|
from typing import Optional
|
|
7
7
|
|
|
8
8
|
import uvicorn
|
|
9
|
-
from fastapi import FastAPI
|
|
9
|
+
from fastapi import FastAPI, Request
|
|
10
10
|
from fastapi.responses import JSONResponse
|
|
11
11
|
from starlette.middleware.base import BaseHTTPMiddleware
|
|
12
12
|
from starlette.middleware.cors import CORSMiddleware
|
|
@@ -109,7 +109,13 @@ random_password = os.getenv("LETTA_SERVER_PASSWORD") or generate_password()
|
|
|
109
109
|
|
|
110
110
|
|
|
111
111
|
class CheckPasswordMiddleware(BaseHTTPMiddleware):
|
|
112
|
+
|
|
112
113
|
async def dispatch(self, request, call_next):
|
|
114
|
+
|
|
115
|
+
# Exclude health check endpoint from password protection
|
|
116
|
+
if request.url.path == "/v1/health/" or request.url.path == "/latest/health/":
|
|
117
|
+
return await call_next(request)
|
|
118
|
+
|
|
113
119
|
if request.headers.get("X-BARE-PASSWORD") == f"password {random_password}":
|
|
114
120
|
return await call_next(request)
|
|
115
121
|
|
|
@@ -136,17 +142,18 @@ def create_application() -> "FastAPI":
|
|
|
136
142
|
},
|
|
137
143
|
)
|
|
138
144
|
|
|
145
|
+
debug_mode = "--debug" in sys.argv
|
|
139
146
|
app = FastAPI(
|
|
140
147
|
swagger_ui_parameters={"docExpansion": "none"},
|
|
141
148
|
# openapi_tags=TAGS_METADATA,
|
|
142
149
|
title="Letta",
|
|
143
150
|
summary="Create LLM agents with long-term memory and custom tools 📚🦙",
|
|
144
151
|
version="1.0.0", # TODO wire this up to the version in the package
|
|
145
|
-
debug=True,
|
|
152
|
+
debug=debug_mode, # if True, the stack trace will be printed in the response
|
|
146
153
|
)
|
|
147
154
|
|
|
148
155
|
@app.exception_handler(Exception)
|
|
149
|
-
async def generic_error_handler(request, exc):
|
|
156
|
+
async def generic_error_handler(request: Request, exc: Exception):
|
|
150
157
|
# Log the actual error for debugging
|
|
151
158
|
log.error(f"Unhandled error: {exc}", exc_info=True)
|
|
152
159
|
|
|
@@ -166,16 +173,19 @@ def create_application() -> "FastAPI":
|
|
|
166
173
|
},
|
|
167
174
|
)
|
|
168
175
|
|
|
176
|
+
@app.exception_handler(ValueError)
|
|
177
|
+
async def value_error_handler(request: Request, exc: ValueError):
|
|
178
|
+
return JSONResponse(status_code=400, content={"detail": str(exc)})
|
|
179
|
+
|
|
169
180
|
@app.exception_handler(LettaAgentNotFoundError)
|
|
170
|
-
async def agent_not_found_handler(request, exc):
|
|
181
|
+
async def agent_not_found_handler(request: Request, exc: LettaAgentNotFoundError):
|
|
171
182
|
return JSONResponse(status_code=404, content={"detail": "Agent not found"})
|
|
172
183
|
|
|
173
184
|
@app.exception_handler(LettaUserNotFoundError)
|
|
174
|
-
async def user_not_found_handler(request, exc):
|
|
185
|
+
async def user_not_found_handler(request: Request, exc: LettaUserNotFoundError):
|
|
175
186
|
return JSONResponse(status_code=404, content={"detail": "User not found"})
|
|
176
187
|
|
|
177
188
|
settings.cors_origins.append("https://app.letta.com")
|
|
178
|
-
print(f"▶ View using ADE at: https://app.letta.com/development-servers/local/dashboard")
|
|
179
189
|
|
|
180
190
|
if (os.getenv("LETTA_SERVER_SECURE") == "true") or "--secure" in sys.argv:
|
|
181
191
|
print(f"▶ Using secure mode with password: {random_password}")
|
|
@@ -254,9 +264,21 @@ def start_server(
|
|
|
254
264
|
# Add the handler to the logger
|
|
255
265
|
server_logger.addHandler(stream_handler)
|
|
256
266
|
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
267
|
+
if (os.getenv("LOCAL_HTTPS") == "true") or "--localhttps" in sys.argv:
|
|
268
|
+
uvicorn.run(
|
|
269
|
+
app,
|
|
270
|
+
host=host or "localhost",
|
|
271
|
+
port=port or REST_DEFAULT_PORT,
|
|
272
|
+
ssl_keyfile="certs/localhost-key.pem",
|
|
273
|
+
ssl_certfile="certs/localhost.pem",
|
|
274
|
+
)
|
|
275
|
+
print(f"▶ Server running at: https://{host or 'localhost'}:{port or REST_DEFAULT_PORT}\n")
|
|
276
|
+
else:
|
|
277
|
+
uvicorn.run(
|
|
278
|
+
app,
|
|
279
|
+
host=host or "localhost",
|
|
280
|
+
port=port or REST_DEFAULT_PORT,
|
|
281
|
+
)
|
|
282
|
+
print(f"▶ Server running at: http://{host or 'localhost'}:{port or REST_DEFAULT_PORT}\n")
|
|
283
|
+
|
|
284
|
+
print(f"▶ View using ADE at: https://app.letta.com/development-servers/local/dashboard")
|
|
@@ -14,6 +14,7 @@ from fastapi import (
|
|
|
14
14
|
status,
|
|
15
15
|
)
|
|
16
16
|
from fastapi.responses import JSONResponse, StreamingResponse
|
|
17
|
+
from pydantic import Field
|
|
17
18
|
|
|
18
19
|
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
|
|
19
20
|
from letta.schemas.agent import AgentState, CreateAgent, UpdateAgentState
|
|
@@ -87,9 +88,18 @@ def get_agent_context_window(
|
|
|
87
88
|
return server.get_agent_context_window(user_id=actor.id, agent_id=agent_id)
|
|
88
89
|
|
|
89
90
|
|
|
91
|
+
class CreateAgentRequest(CreateAgent):
    """
    Body model for the create-agent POST endpoint: ``CreateAgent`` whose ``user_id``
    is supplied through the request headers rather than the JSON body.
    """

    # Redeclared as optional (default None) so clients may omit it from the body.
    # exclude=True only keeps it out of serialized output (model_dump / JSON);
    # it does not stop a client from sending the field.
    user_id: Optional[str] = Field(default=None, exclude=True)
|
|
98
|
+
|
|
99
|
+
|
|
90
100
|
@router.post("/", response_model=AgentState, operation_id="create_agent")
|
|
91
101
|
def create_agent(
|
|
92
|
-
agent:
|
|
102
|
+
agent: CreateAgentRequest = Body(...),
|
|
93
103
|
server: "SyncServer" = Depends(get_letta_server),
|
|
94
104
|
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
|
95
105
|
):
|
|
@@ -165,7 +175,7 @@ def get_agent_state(
|
|
|
165
175
|
return server.get_agent_state(user_id=actor.id, agent_id=agent_id)
|
|
166
176
|
|
|
167
177
|
|
|
168
|
-
@router.delete("/{agent_id}", response_model=
|
|
178
|
+
@router.delete("/{agent_id}", response_model=AgentState, operation_id="delete_agent")
|
|
169
179
|
def delete_agent(
|
|
170
180
|
agent_id: str,
|
|
171
181
|
server: "SyncServer" = Depends(get_letta_server),
|
|
@@ -176,7 +186,12 @@ def delete_agent(
|
|
|
176
186
|
"""
|
|
177
187
|
actor = server.get_user_or_default(user_id=user_id)
|
|
178
188
|
|
|
179
|
-
|
|
189
|
+
agent = server.get_agent(agent_id)
|
|
190
|
+
if not agent:
|
|
191
|
+
raise HTTPException(status_code=404, detail=f"Agent agent_id={agent_id} not found.")
|
|
192
|
+
|
|
193
|
+
server.delete_agent(user_id=actor.id, agent_id=agent_id)
|
|
194
|
+
return agent
|
|
180
195
|
|
|
181
196
|
|
|
182
197
|
@router.get("/{agent_id}/sources", response_model=List[Source], operation_id="get_agent_sources")
|
|
@@ -354,8 +369,7 @@ def get_agent_archival_memory(
|
|
|
354
369
|
return server.get_agent_archival_cursor(
|
|
355
370
|
user_id=actor.id,
|
|
356
371
|
agent_id=agent_id,
|
|
357
|
-
|
|
358
|
-
before=before,
|
|
372
|
+
cursor=after, # TODO: deleting before, after. is this expected?
|
|
359
373
|
limit=limit,
|
|
360
374
|
)
|
|
361
375
|
|
|
@@ -420,7 +434,7 @@ def get_agent_messages(
|
|
|
420
434
|
return server.get_agent_recall_cursor(
|
|
421
435
|
user_id=actor.id,
|
|
422
436
|
agent_id=agent_id,
|
|
423
|
-
|
|
437
|
+
before=before,
|
|
424
438
|
limit=limit,
|
|
425
439
|
reverse=True,
|
|
426
440
|
return_message_object=msg_object,
|
|
@@ -496,7 +510,6 @@ async def send_message_streaming(
|
|
|
496
510
|
This endpoint accepts a message from a user and processes it through the agent.
|
|
497
511
|
It will stream the steps of the response always, and stream the tokens if 'stream_tokens' is set to True.
|
|
498
512
|
"""
|
|
499
|
-
request.stream_tokens = False
|
|
500
513
|
|
|
501
514
|
actor = server.get_user_or_default(user_id=user_id)
|
|
502
515
|
result = await send_message_to_agent(
|