letta-nightly 0.8.15.dev20250720104313__py3-none-any.whl → 0.8.16.dev20250721104533__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- letta/__init__.py +1 -1
- letta/agent.py +27 -11
- letta/agents/helpers.py +1 -1
- letta/agents/letta_agent.py +518 -322
- letta/agents/letta_agent_batch.py +1 -2
- letta/agents/voice_agent.py +15 -17
- letta/client/client.py +3 -3
- letta/constants.py +5 -0
- letta/embeddings.py +0 -2
- letta/errors.py +8 -0
- letta/functions/function_sets/base.py +3 -3
- letta/functions/helpers.py +2 -3
- letta/groups/sleeptime_multi_agent.py +0 -1
- letta/helpers/composio_helpers.py +2 -2
- letta/helpers/converters.py +1 -1
- letta/helpers/pinecone_utils.py +8 -0
- letta/helpers/tool_rule_solver.py +13 -18
- letta/llm_api/aws_bedrock.py +16 -2
- letta/llm_api/cohere.py +1 -1
- letta/llm_api/openai_client.py +1 -1
- letta/local_llm/grammars/gbnf_grammar_generator.py +1 -1
- letta/local_llm/llm_chat_completion_wrappers/zephyr.py +14 -14
- letta/local_llm/utils.py +1 -2
- letta/orm/agent.py +3 -3
- letta/orm/block.py +4 -4
- letta/orm/files_agents.py +0 -1
- letta/orm/identity.py +2 -0
- letta/orm/mcp_server.py +0 -2
- letta/orm/message.py +140 -14
- letta/orm/organization.py +5 -5
- letta/orm/passage.py +4 -4
- letta/orm/source.py +1 -1
- letta/orm/sqlalchemy_base.py +61 -39
- letta/orm/step.py +2 -0
- letta/otel/db_pool_monitoring.py +308 -0
- letta/otel/metric_registry.py +94 -1
- letta/otel/sqlalchemy_instrumentation.py +548 -0
- letta/otel/sqlalchemy_instrumentation_integration.py +124 -0
- letta/otel/tracing.py +37 -1
- letta/schemas/agent.py +0 -3
- letta/schemas/agent_file.py +283 -0
- letta/schemas/block.py +0 -3
- letta/schemas/file.py +28 -26
- letta/schemas/letta_message.py +15 -4
- letta/schemas/memory.py +1 -1
- letta/schemas/message.py +31 -26
- letta/schemas/openai/chat_completion_response.py +0 -1
- letta/schemas/providers.py +20 -0
- letta/schemas/source.py +11 -13
- letta/schemas/step.py +12 -0
- letta/schemas/tool.py +0 -4
- letta/serialize_schemas/marshmallow_agent.py +14 -1
- letta/serialize_schemas/marshmallow_block.py +23 -1
- letta/serialize_schemas/marshmallow_message.py +1 -3
- letta/serialize_schemas/marshmallow_tool.py +23 -1
- letta/server/db.py +110 -6
- letta/server/rest_api/app.py +85 -73
- letta/server/rest_api/routers/v1/agents.py +68 -53
- letta/server/rest_api/routers/v1/blocks.py +2 -2
- letta/server/rest_api/routers/v1/jobs.py +3 -0
- letta/server/rest_api/routers/v1/organizations.py +2 -2
- letta/server/rest_api/routers/v1/sources.py +18 -2
- letta/server/rest_api/routers/v1/tools.py +11 -12
- letta/server/rest_api/routers/v1/users.py +1 -1
- letta/server/rest_api/streaming_response.py +13 -5
- letta/server/rest_api/utils.py +8 -25
- letta/server/server.py +11 -4
- letta/server/ws_api/server.py +2 -2
- letta/services/agent_file_manager.py +616 -0
- letta/services/agent_manager.py +133 -46
- letta/services/block_manager.py +38 -17
- letta/services/file_manager.py +106 -21
- letta/services/file_processor/file_processor.py +93 -0
- letta/services/files_agents_manager.py +28 -0
- letta/services/group_manager.py +4 -5
- letta/services/helpers/agent_manager_helper.py +57 -9
- letta/services/identity_manager.py +22 -0
- letta/services/job_manager.py +210 -91
- letta/services/llm_batch_manager.py +9 -6
- letta/services/mcp/stdio_client.py +1 -2
- letta/services/mcp_manager.py +0 -1
- letta/services/message_manager.py +49 -26
- letta/services/passage_manager.py +0 -1
- letta/services/provider_manager.py +1 -1
- letta/services/source_manager.py +114 -5
- letta/services/step_manager.py +36 -4
- letta/services/telemetry_manager.py +9 -2
- letta/services/tool_executor/builtin_tool_executor.py +5 -1
- letta/services/tool_executor/core_tool_executor.py +3 -3
- letta/services/tool_manager.py +95 -20
- letta/services/user_manager.py +4 -12
- letta/settings.py +23 -6
- letta/system.py +1 -1
- letta/utils.py +26 -2
- {letta_nightly-0.8.15.dev20250720104313.dist-info → letta_nightly-0.8.16.dev20250721104533.dist-info}/METADATA +3 -2
- {letta_nightly-0.8.15.dev20250720104313.dist-info → letta_nightly-0.8.16.dev20250721104533.dist-info}/RECORD +99 -94
- {letta_nightly-0.8.15.dev20250720104313.dist-info → letta_nightly-0.8.16.dev20250721104533.dist-info}/LICENSE +0 -0
- {letta_nightly-0.8.15.dev20250720104313.dist-info → letta_nightly-0.8.16.dev20250721104533.dist-info}/WHEEL +0 -0
- {letta_nightly-0.8.15.dev20250720104313.dist-info → letta_nightly-0.8.16.dev20250721104533.dist-info}/entry_points.txt +0 -0
letta/orm/message.py
CHANGED
@@ -11,7 +11,7 @@ from letta.schemas.letta_message_content import MessageContent
|
|
11
11
|
from letta.schemas.letta_message_content import TextContent as PydanticTextContent
|
12
12
|
from letta.schemas.message import Message as PydanticMessage
|
13
13
|
from letta.schemas.message import ToolReturn
|
14
|
-
from letta.settings import settings
|
14
|
+
from letta.settings import DatabaseChoice, settings
|
15
15
|
|
16
16
|
|
17
17
|
class Message(SqlalchemyBase, OrganizationMixin, AgentMixin):
|
@@ -49,6 +49,9 @@ class Message(SqlalchemyBase, OrganizationMixin, AgentMixin):
|
|
49
49
|
nullable=True,
|
50
50
|
doc="The id of the LLMBatchItem that this message is associated with",
|
51
51
|
)
|
52
|
+
is_err: Mapped[Optional[bool]] = mapped_column(
|
53
|
+
nullable=True, doc="Whether this message is part of an error step. Used only for debugging purposes."
|
54
|
+
)
|
52
55
|
|
53
56
|
# Monotonically increasing sequence for efficient/correct listing
|
54
57
|
sequence_id: Mapped[int] = mapped_column(
|
@@ -59,7 +62,7 @@ class Message(SqlalchemyBase, OrganizationMixin, AgentMixin):
|
|
59
62
|
)
|
60
63
|
|
61
64
|
# Relationships
|
62
|
-
organization: Mapped["Organization"] = relationship("Organization", back_populates="messages", lazy="
|
65
|
+
organization: Mapped["Organization"] = relationship("Organization", back_populates="messages", lazy="raise")
|
63
66
|
step: Mapped["Step"] = relationship("Step", back_populates="messages", lazy="selectin")
|
64
67
|
|
65
68
|
# Job relationship
|
@@ -78,7 +81,7 @@ class Message(SqlalchemyBase, OrganizationMixin, AgentMixin):
|
|
78
81
|
if self.text and not model.content:
|
79
82
|
model.content = [PydanticTextContent(text=self.text)]
|
80
83
|
# If there are no tool calls, set tool_calls to None
|
81
|
-
if len(self.tool_calls) == 0:
|
84
|
+
if self.tool_calls is None or len(self.tool_calls) == 0:
|
82
85
|
model.tool_calls = None
|
83
86
|
return model
|
84
87
|
|
@@ -86,16 +89,139 @@ class Message(SqlalchemyBase, OrganizationMixin, AgentMixin):
|
|
86
89
|
# listener
|
87
90
|
|
88
91
|
|
89
|
-
@event.listens_for(
|
90
|
-
def
|
91
|
-
#
|
92
|
-
if
|
93
|
-
|
92
|
+
@event.listens_for(Session, "before_flush")
|
93
|
+
def set_sequence_id_for_sqlite_bulk(session, flush_context, instances):
|
94
|
+
# Handle bulk inserts for SQLite
|
95
|
+
if settings.database_engine is DatabaseChoice.SQLITE:
|
96
|
+
# Find all new Message objects that need sequence IDs
|
97
|
+
new_messages = [obj for obj in session.new if isinstance(obj, Message) and obj.sequence_id is None]
|
98
|
+
|
99
|
+
if new_messages:
|
100
|
+
# Create a sequence table if it doesn't exist for atomic increments
|
101
|
+
session.execute(
|
102
|
+
text(
|
103
|
+
"""
|
104
|
+
CREATE TABLE IF NOT EXISTS message_sequence (
|
105
|
+
id INTEGER PRIMARY KEY,
|
106
|
+
next_val INTEGER NOT NULL DEFAULT 1
|
107
|
+
)
|
108
|
+
"""
|
109
|
+
)
|
110
|
+
)
|
111
|
+
|
112
|
+
# Initialize the sequence table if empty
|
113
|
+
session.execute(
|
114
|
+
text(
|
115
|
+
"""
|
116
|
+
INSERT OR IGNORE INTO message_sequence (id, next_val)
|
117
|
+
SELECT 1, COALESCE(MAX(sequence_id), 0) + 1
|
118
|
+
FROM messages
|
119
|
+
"""
|
120
|
+
)
|
121
|
+
)
|
122
|
+
|
123
|
+
# Get the number of records being inserted
|
124
|
+
records_count = len(new_messages)
|
125
|
+
|
126
|
+
# Atomically reserve a range of sequence values for this batch
|
127
|
+
result = session.execute(
|
128
|
+
text(
|
129
|
+
"""
|
130
|
+
UPDATE message_sequence
|
131
|
+
SET next_val = next_val + :count
|
132
|
+
WHERE id = 1
|
133
|
+
RETURNING next_val - :count
|
134
|
+
"""
|
135
|
+
),
|
136
|
+
{"count": records_count},
|
137
|
+
)
|
138
|
+
|
139
|
+
start_sequence_id = result.scalar()
|
140
|
+
if start_sequence_id is None:
|
141
|
+
# Fallback if RETURNING doesn't work (older SQLite versions)
|
142
|
+
session.execute(
|
143
|
+
text(
|
144
|
+
"""
|
145
|
+
UPDATE message_sequence
|
146
|
+
SET next_val = next_val + :count
|
147
|
+
WHERE id = 1
|
148
|
+
"""
|
149
|
+
),
|
150
|
+
{"count": records_count},
|
151
|
+
)
|
152
|
+
start_sequence_id = session.execute(
|
153
|
+
text(
|
154
|
+
"""
|
155
|
+
SELECT next_val - :count FROM message_sequence WHERE id = 1
|
156
|
+
"""
|
157
|
+
),
|
158
|
+
{"count": records_count},
|
159
|
+
).scalar()
|
160
|
+
|
161
|
+
# Assign sequential IDs to each record
|
162
|
+
for i, obj in enumerate(new_messages):
|
163
|
+
obj.sequence_id = start_sequence_id + i
|
94
164
|
|
95
|
-
if not hasattr(session, "_sequence_id_counter"):
|
96
|
-
# Initialize counter for this flush
|
97
|
-
max_seq = connection.scalar(text("SELECT MAX(sequence_id) FROM messages"))
|
98
|
-
session._sequence_id_counter = max_seq or 0
|
99
165
|
|
100
|
-
|
101
|
-
|
166
|
+
@event.listens_for(Message, "before_insert")
|
167
|
+
def set_sequence_id_for_sqlite(mapper, connection, target):
|
168
|
+
if settings.database_engine is DatabaseChoice.SQLITE:
|
169
|
+
# For SQLite, we need to generate sequence_id manually
|
170
|
+
# Use a database-level atomic operation to avoid race conditions
|
171
|
+
|
172
|
+
# Create a sequence table if it doesn't exist for atomic increments
|
173
|
+
connection.execute(
|
174
|
+
text(
|
175
|
+
"""
|
176
|
+
CREATE TABLE IF NOT EXISTS message_sequence (
|
177
|
+
id INTEGER PRIMARY KEY,
|
178
|
+
next_val INTEGER NOT NULL DEFAULT 1
|
179
|
+
)
|
180
|
+
"""
|
181
|
+
)
|
182
|
+
)
|
183
|
+
|
184
|
+
# Initialize the sequence table if empty
|
185
|
+
connection.execute(
|
186
|
+
text(
|
187
|
+
"""
|
188
|
+
INSERT OR IGNORE INTO message_sequence (id, next_val)
|
189
|
+
SELECT 1, COALESCE(MAX(sequence_id), 0) + 1
|
190
|
+
FROM messages
|
191
|
+
"""
|
192
|
+
)
|
193
|
+
)
|
194
|
+
|
195
|
+
# Atomically get the next sequence value
|
196
|
+
result = connection.execute(
|
197
|
+
text(
|
198
|
+
"""
|
199
|
+
UPDATE message_sequence
|
200
|
+
SET next_val = next_val + 1
|
201
|
+
WHERE id = 1
|
202
|
+
RETURNING next_val - 1
|
203
|
+
"""
|
204
|
+
)
|
205
|
+
)
|
206
|
+
|
207
|
+
sequence_id = result.scalar()
|
208
|
+
if sequence_id is None:
|
209
|
+
# Fallback if RETURNING doesn't work (older SQLite versions)
|
210
|
+
connection.execute(
|
211
|
+
text(
|
212
|
+
"""
|
213
|
+
UPDATE message_sequence
|
214
|
+
SET next_val = next_val + 1
|
215
|
+
WHERE id = 1
|
216
|
+
"""
|
217
|
+
)
|
218
|
+
)
|
219
|
+
sequence_id = connection.execute(
|
220
|
+
text(
|
221
|
+
"""
|
222
|
+
SELECT next_val - 1 FROM message_sequence WHERE id = 1
|
223
|
+
"""
|
224
|
+
)
|
225
|
+
).scalar()
|
226
|
+
|
227
|
+
target.sequence_id = sequence_id
|
letta/orm/organization.py
CHANGED
@@ -6,18 +6,17 @@ from letta.orm.sqlalchemy_base import SqlalchemyBase
|
|
6
6
|
from letta.schemas.organization import Organization as PydanticOrganization
|
7
7
|
|
8
8
|
if TYPE_CHECKING:
|
9
|
+
from letta.orm import Source
|
9
10
|
from letta.orm.agent import Agent
|
10
|
-
from letta.orm.agent_passage import AgentPassage
|
11
11
|
from letta.orm.block import Block
|
12
12
|
from letta.orm.group import Group
|
13
13
|
from letta.orm.identity import Identity
|
14
|
-
from letta.orm.
|
14
|
+
from letta.orm.llm_batch_items import LLMBatchItem
|
15
15
|
from letta.orm.llm_batch_job import LLMBatchJob
|
16
16
|
from letta.orm.message import Message
|
17
|
+
from letta.orm.passage import AgentPassage, SourcePassage
|
17
18
|
from letta.orm.provider import Provider
|
18
|
-
from letta.orm.sandbox_config import AgentEnvironmentVariable, SandboxConfig
|
19
|
-
from letta.orm.sandbox_environment_variable import SandboxEnvironmentVariable
|
20
|
-
from letta.orm.source_passage import SourcePassage
|
19
|
+
from letta.orm.sandbox_config import AgentEnvironmentVariable, SandboxConfig, SandboxEnvironmentVariable
|
21
20
|
from letta.orm.tool import Tool
|
22
21
|
from letta.orm.user import User
|
23
22
|
|
@@ -48,6 +47,7 @@ class Organization(SqlalchemyBase):
|
|
48
47
|
|
49
48
|
# relationships
|
50
49
|
agents: Mapped[List["Agent"]] = relationship("Agent", back_populates="organization", cascade="all, delete-orphan")
|
50
|
+
sources: Mapped[List["Source"]] = relationship("Source", cascade="all, delete-orphan")
|
51
51
|
messages: Mapped[List["Message"]] = relationship("Message", back_populates="organization", cascade="all, delete-orphan")
|
52
52
|
source_passages: Mapped[List["SourcePassage"]] = relationship(
|
53
53
|
"SourcePassage", back_populates="organization", cascade="all, delete-orphan"
|
letta/orm/passage.py
CHANGED
@@ -9,7 +9,7 @@ from letta.orm.custom_columns import CommonVector, EmbeddingConfigColumn
|
|
9
9
|
from letta.orm.mixins import AgentMixin, FileMixin, OrganizationMixin, SourceMixin
|
10
10
|
from letta.orm.sqlalchemy_base import SqlalchemyBase
|
11
11
|
from letta.schemas.passage import Passage as PydanticPassage
|
12
|
-
from letta.settings import settings
|
12
|
+
from letta.settings import DatabaseChoice, settings
|
13
13
|
|
14
14
|
config = LettaConfig()
|
15
15
|
|
@@ -29,7 +29,7 @@ class BasePassage(SqlalchemyBase, OrganizationMixin):
|
|
29
29
|
metadata_: Mapped[dict] = mapped_column(JSON, doc="Additional metadata")
|
30
30
|
|
31
31
|
# Vector embedding field based on database type
|
32
|
-
if settings.
|
32
|
+
if settings.database_engine is DatabaseChoice.POSTGRES:
|
33
33
|
from pgvector.sqlalchemy import Vector
|
34
34
|
|
35
35
|
embedding = mapped_column(Vector(MAX_EMBEDDING_DIM))
|
@@ -56,7 +56,7 @@ class SourcePassage(BasePassage, FileMixin, SourceMixin):
|
|
56
56
|
@declared_attr
|
57
57
|
def __table_args__(cls):
|
58
58
|
# TODO (cliandy): investigate if this is necessary, may be for SQLite compatability or do we need to add as well?
|
59
|
-
if settings.
|
59
|
+
if settings.database_engine is DatabaseChoice.POSTGRES:
|
60
60
|
return (
|
61
61
|
Index("source_passages_org_idx", "organization_id"),
|
62
62
|
Index("source_passages_created_at_id_idx", "created_at", "id"),
|
@@ -81,7 +81,7 @@ class AgentPassage(BasePassage, AgentMixin):
|
|
81
81
|
|
82
82
|
@declared_attr
|
83
83
|
def __table_args__(cls):
|
84
|
-
if settings.
|
84
|
+
if settings.database_engine is DatabaseChoice.POSTGRES:
|
85
85
|
return (
|
86
86
|
Index("agent_passages_org_idx", "organization_id"),
|
87
87
|
Index("ix_agent_passages_org_agent", "organization_id", "agent_id"),
|
letta/orm/source.py
CHANGED
@@ -20,7 +20,7 @@ class Source(SqlalchemyBase, OrganizationMixin):
|
|
20
20
|
__pydantic_model__ = PydanticSource
|
21
21
|
|
22
22
|
__table_args__ = (
|
23
|
-
Index(
|
23
|
+
Index("source_created_at_id_idx", "created_at", "id"),
|
24
24
|
UniqueConstraint("name", "organization_id", name="uq_source_name_organization"),
|
25
25
|
{"extend_existing": True},
|
26
26
|
)
|
letta/orm/sqlalchemy_base.py
CHANGED
@@ -5,7 +5,7 @@ from functools import wraps
|
|
5
5
|
from pprint import pformat
|
6
6
|
from typing import TYPE_CHECKING, List, Literal, Optional, Tuple, Union
|
7
7
|
|
8
|
-
from sqlalchemy import Sequence, String, and_, delete, func, or_, select
|
8
|
+
from sqlalchemy import Sequence, String, and_, delete, func, or_, select
|
9
9
|
from sqlalchemy.exc import DBAPIError, IntegrityError, TimeoutError
|
10
10
|
from sqlalchemy.ext.asyncio import AsyncSession
|
11
11
|
from sqlalchemy.orm import Mapped, Session, mapped_column
|
@@ -15,6 +15,7 @@ from letta.log import get_logger
|
|
15
15
|
from letta.orm.base import Base, CommonSqlalchemyMetaMixins
|
16
16
|
from letta.orm.errors import DatabaseTimeoutError, ForeignKeyConstraintViolationError, NoResultFound, UniqueConstraintViolationError
|
17
17
|
from letta.orm.sqlite_functions import adapt_array
|
18
|
+
from letta.settings import DatabaseChoice
|
18
19
|
|
19
20
|
if TYPE_CHECKING:
|
20
21
|
from pydantic import BaseModel
|
@@ -353,10 +354,12 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
|
353
354
|
|
354
355
|
if before_obj and after_obj:
|
355
356
|
# Window-based query - get records between before and after
|
356
|
-
conditions
|
357
|
-
or_(cls.created_at < before_obj.created_at, and_(cls.created_at == before_obj.created_at, cls.id < before_obj.id))
|
358
|
-
|
359
|
-
|
357
|
+
conditions.append(
|
358
|
+
or_(cls.created_at < before_obj.created_at, and_(cls.created_at == before_obj.created_at, cls.id < before_obj.id))
|
359
|
+
)
|
360
|
+
conditions.append(
|
361
|
+
or_(cls.created_at > after_obj.created_at, and_(cls.created_at == after_obj.created_at, cls.id > after_obj.id))
|
362
|
+
)
|
360
363
|
else:
|
361
364
|
# Pure pagination query
|
362
365
|
if before_obj:
|
@@ -393,7 +396,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
|
393
396
|
|
394
397
|
from letta.settings import settings
|
395
398
|
|
396
|
-
if settings.
|
399
|
+
if settings.database_engine is DatabaseChoice.POSTGRES:
|
397
400
|
# PostgreSQL with pgvector
|
398
401
|
query = query.order_by(cls.embedding.cosine_distance(query_embedding).asc())
|
399
402
|
else:
|
@@ -509,14 +512,9 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
|
509
512
|
query, query_conditions = cls._read_multiple_preprocess(identifiers, actor, access, access_type, check_is_deleted, **kwargs)
|
510
513
|
if query is None:
|
511
514
|
raise NoResultFound(f"{cls.__name__} not found with identifier {identifier}")
|
512
|
-
|
513
|
-
|
514
|
-
|
515
|
-
result = await db_session.execute(query)
|
516
|
-
item = result.scalar_one_or_none()
|
517
|
-
finally:
|
518
|
-
if is_postgresql_session(db_session):
|
519
|
-
await db_session.execute(text("SET LOCAL enable_seqscan = ON"))
|
515
|
+
|
516
|
+
result = await db_session.execute(query)
|
517
|
+
item = result.scalar_one_or_none()
|
520
518
|
|
521
519
|
if item is None:
|
522
520
|
raise NoResultFound(f"{cls.__name__} not found with {', '.join(query_conditions if query_conditions else ['no conditions'])}")
|
@@ -656,7 +654,13 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
|
656
654
|
self._handle_dbapi_error(e)
|
657
655
|
|
658
656
|
@handle_db_timeout
|
659
|
-
async def create_async(
|
657
|
+
async def create_async(
|
658
|
+
self,
|
659
|
+
db_session: "AsyncSession",
|
660
|
+
actor: Optional["User"] = None,
|
661
|
+
no_commit: bool = False,
|
662
|
+
no_refresh: bool = False,
|
663
|
+
) -> "SqlalchemyBase":
|
660
664
|
"""Async version of create function"""
|
661
665
|
logger.debug(f"Creating {self.__class__.__name__} with ID: {self.id} with actor={actor}")
|
662
666
|
|
@@ -668,7 +672,9 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
|
668
672
|
await db_session.flush() # no commit, just flush to get PK
|
669
673
|
else:
|
670
674
|
await db_session.commit()
|
671
|
-
|
675
|
+
|
676
|
+
if not no_refresh:
|
677
|
+
await db_session.refresh(self)
|
672
678
|
return self
|
673
679
|
except (DBAPIError, IntegrityError) as e:
|
674
680
|
self._handle_dbapi_error(e)
|
@@ -717,7 +723,12 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
|
717
723
|
@classmethod
|
718
724
|
@handle_db_timeout
|
719
725
|
async def batch_create_async(
|
720
|
-
cls,
|
726
|
+
cls,
|
727
|
+
items: List["SqlalchemyBase"],
|
728
|
+
db_session: "AsyncSession",
|
729
|
+
actor: Optional["User"] = None,
|
730
|
+
no_commit: bool = False,
|
731
|
+
no_refresh: bool = False,
|
721
732
|
) -> List["SqlalchemyBase"]:
|
722
733
|
"""
|
723
734
|
Async version of batch_create method.
|
@@ -726,10 +737,13 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
|
726
737
|
items: List of model instances to create
|
727
738
|
db_session: AsyncSession session
|
728
739
|
actor: Optional user performing the action
|
740
|
+
no_commit: Whether to commit the transaction
|
741
|
+
no_refresh: Whether to refresh the created objects
|
729
742
|
Returns:
|
730
743
|
List of created model instances
|
731
744
|
"""
|
732
745
|
logger.debug(f"Async batch creating {len(items)} {cls.__name__} items with actor={actor}")
|
746
|
+
|
733
747
|
if not items:
|
734
748
|
return []
|
735
749
|
|
@@ -740,21 +754,22 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
|
740
754
|
|
741
755
|
try:
|
742
756
|
db_session.add_all(items)
|
743
|
-
|
744
|
-
|
745
|
-
|
746
|
-
|
747
|
-
|
748
|
-
await db_session.commit()
|
749
|
-
|
750
|
-
# Re-query the objects to get them with relationships loaded
|
751
|
-
query = select(cls).where(cls.id.in_(item_ids))
|
752
|
-
if hasattr(cls, "created_at"):
|
753
|
-
query = query.order_by(cls.created_at)
|
757
|
+
if no_commit:
|
758
|
+
await db_session.flush()
|
759
|
+
else:
|
760
|
+
await db_session.commit()
|
754
761
|
|
755
|
-
|
756
|
-
|
762
|
+
if no_refresh:
|
763
|
+
return items
|
764
|
+
else:
|
765
|
+
# Re-query the objects to get them with relationships loaded
|
766
|
+
item_ids = [item.id for item in items]
|
767
|
+
query = select(cls).where(cls.id.in_(item_ids))
|
768
|
+
if hasattr(cls, "created_at"):
|
769
|
+
query = query.order_by(cls.created_at)
|
757
770
|
|
771
|
+
result = await db_session.execute(query)
|
772
|
+
return list(result.scalars())
|
758
773
|
except (DBAPIError, IntegrityError) as e:
|
759
774
|
cls._handle_dbapi_error(e)
|
760
775
|
|
@@ -854,20 +869,27 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
|
854
869
|
return self
|
855
870
|
|
856
871
|
@handle_db_timeout
|
857
|
-
async def update_async(
|
872
|
+
async def update_async(
|
873
|
+
self, db_session: "AsyncSession", actor: Optional["User"] = None, no_commit: bool = False, no_refresh: bool = False
|
874
|
+
) -> "SqlalchemyBase":
|
858
875
|
"""Async version of update function"""
|
859
|
-
logger.debug(
|
876
|
+
logger.debug(f"Updating {self.__class__.__name__} with ID: {self.id} with actor={actor}")
|
877
|
+
|
860
878
|
if actor:
|
861
879
|
self._set_created_and_updated_by_fields(actor.id)
|
862
880
|
self.set_updated_at()
|
881
|
+
try:
|
882
|
+
db_session.add(self)
|
883
|
+
if no_commit:
|
884
|
+
await db_session.flush()
|
885
|
+
else:
|
886
|
+
await db_session.commit()
|
863
887
|
|
864
|
-
|
865
|
-
|
866
|
-
|
867
|
-
|
868
|
-
|
869
|
-
await db_session.refresh(self)
|
870
|
-
return self
|
888
|
+
if not no_refresh:
|
889
|
+
await db_session.refresh(self)
|
890
|
+
return self
|
891
|
+
except (DBAPIError, IntegrityError) as e:
|
892
|
+
self._handle_dbapi_error(e)
|
871
893
|
|
872
894
|
@classmethod
|
873
895
|
def _size_preprocess(
|
letta/orm/step.py
CHANGED
@@ -5,6 +5,7 @@ from sqlalchemy import JSON, ForeignKey, String
|
|
5
5
|
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
6
6
|
|
7
7
|
from letta.orm.sqlalchemy_base import SqlalchemyBase
|
8
|
+
from letta.schemas.letta_stop_reason import StopReasonType
|
8
9
|
from letta.schemas.step import Step as PydanticStep
|
9
10
|
|
10
11
|
if TYPE_CHECKING:
|
@@ -45,6 +46,7 @@ class Step(SqlalchemyBase):
|
|
45
46
|
prompt_tokens: Mapped[int] = mapped_column(default=0, doc="Number of tokens in the prompt")
|
46
47
|
total_tokens: Mapped[int] = mapped_column(default=0, doc="Total number of tokens processed by the agent")
|
47
48
|
completion_tokens_details: Mapped[Optional[Dict]] = mapped_column(JSON, nullable=True, doc="metadata for the agent.")
|
49
|
+
stop_reason: Mapped[Optional[StopReasonType]] = mapped_column(None, nullable=True, doc="The stop reason associated with this step.")
|
48
50
|
tags: Mapped[Optional[List]] = mapped_column(JSON, doc="Metadata tags.")
|
49
51
|
tid: Mapped[Optional[str]] = mapped_column(None, nullable=True, doc="Transaction ID that processed the step.")
|
50
52
|
trace_id: Mapped[Optional[str]] = mapped_column(None, nullable=True, doc="The trace id of the agent step.")
|