letta-nightly 0.8.15.dev20250719104256__py3-none-any.whl → 0.8.16.dev20250721070720__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (99)
  1. letta/__init__.py +1 -1
  2. letta/agent.py +27 -11
  3. letta/agents/helpers.py +1 -1
  4. letta/agents/letta_agent.py +518 -322
  5. letta/agents/letta_agent_batch.py +1 -2
  6. letta/agents/voice_agent.py +15 -17
  7. letta/client/client.py +3 -3
  8. letta/constants.py +5 -0
  9. letta/embeddings.py +0 -2
  10. letta/errors.py +8 -0
  11. letta/functions/function_sets/base.py +3 -3
  12. letta/functions/helpers.py +2 -3
  13. letta/groups/sleeptime_multi_agent.py +0 -1
  14. letta/helpers/composio_helpers.py +2 -2
  15. letta/helpers/converters.py +1 -1
  16. letta/helpers/pinecone_utils.py +8 -0
  17. letta/helpers/tool_rule_solver.py +13 -18
  18. letta/llm_api/aws_bedrock.py +16 -2
  19. letta/llm_api/cohere.py +1 -1
  20. letta/llm_api/openai_client.py +1 -1
  21. letta/local_llm/grammars/gbnf_grammar_generator.py +1 -1
  22. letta/local_llm/llm_chat_completion_wrappers/zephyr.py +14 -14
  23. letta/local_llm/utils.py +1 -2
  24. letta/orm/agent.py +3 -3
  25. letta/orm/block.py +4 -4
  26. letta/orm/files_agents.py +0 -1
  27. letta/orm/identity.py +2 -0
  28. letta/orm/mcp_server.py +0 -2
  29. letta/orm/message.py +140 -14
  30. letta/orm/organization.py +5 -5
  31. letta/orm/passage.py +4 -4
  32. letta/orm/source.py +1 -1
  33. letta/orm/sqlalchemy_base.py +61 -39
  34. letta/orm/step.py +2 -0
  35. letta/otel/db_pool_monitoring.py +308 -0
  36. letta/otel/metric_registry.py +94 -1
  37. letta/otel/sqlalchemy_instrumentation.py +548 -0
  38. letta/otel/sqlalchemy_instrumentation_integration.py +124 -0
  39. letta/otel/tracing.py +37 -1
  40. letta/schemas/agent.py +0 -3
  41. letta/schemas/agent_file.py +283 -0
  42. letta/schemas/block.py +0 -3
  43. letta/schemas/file.py +28 -26
  44. letta/schemas/letta_message.py +15 -4
  45. letta/schemas/memory.py +1 -1
  46. letta/schemas/message.py +31 -26
  47. letta/schemas/openai/chat_completion_response.py +0 -1
  48. letta/schemas/providers.py +20 -0
  49. letta/schemas/source.py +11 -13
  50. letta/schemas/step.py +12 -0
  51. letta/schemas/tool.py +0 -4
  52. letta/serialize_schemas/marshmallow_agent.py +14 -1
  53. letta/serialize_schemas/marshmallow_block.py +23 -1
  54. letta/serialize_schemas/marshmallow_message.py +1 -3
  55. letta/serialize_schemas/marshmallow_tool.py +23 -1
  56. letta/server/db.py +110 -6
  57. letta/server/rest_api/app.py +85 -73
  58. letta/server/rest_api/routers/v1/agents.py +68 -53
  59. letta/server/rest_api/routers/v1/blocks.py +2 -2
  60. letta/server/rest_api/routers/v1/jobs.py +3 -0
  61. letta/server/rest_api/routers/v1/organizations.py +2 -2
  62. letta/server/rest_api/routers/v1/sources.py +18 -2
  63. letta/server/rest_api/routers/v1/tools.py +11 -12
  64. letta/server/rest_api/routers/v1/users.py +1 -1
  65. letta/server/rest_api/streaming_response.py +13 -5
  66. letta/server/rest_api/utils.py +8 -25
  67. letta/server/server.py +11 -4
  68. letta/server/ws_api/server.py +2 -2
  69. letta/services/agent_file_manager.py +616 -0
  70. letta/services/agent_manager.py +133 -46
  71. letta/services/block_manager.py +38 -17
  72. letta/services/file_manager.py +106 -21
  73. letta/services/file_processor/file_processor.py +93 -0
  74. letta/services/files_agents_manager.py +28 -0
  75. letta/services/group_manager.py +4 -5
  76. letta/services/helpers/agent_manager_helper.py +57 -9
  77. letta/services/identity_manager.py +22 -0
  78. letta/services/job_manager.py +210 -91
  79. letta/services/llm_batch_manager.py +9 -6
  80. letta/services/mcp/stdio_client.py +1 -2
  81. letta/services/mcp_manager.py +0 -1
  82. letta/services/message_manager.py +49 -26
  83. letta/services/passage_manager.py +0 -1
  84. letta/services/provider_manager.py +1 -1
  85. letta/services/source_manager.py +114 -5
  86. letta/services/step_manager.py +36 -4
  87. letta/services/telemetry_manager.py +9 -2
  88. letta/services/tool_executor/builtin_tool_executor.py +5 -1
  89. letta/services/tool_executor/core_tool_executor.py +3 -3
  90. letta/services/tool_manager.py +95 -20
  91. letta/services/user_manager.py +4 -12
  92. letta/settings.py +23 -6
  93. letta/system.py +1 -1
  94. letta/utils.py +26 -2
  95. {letta_nightly-0.8.15.dev20250719104256.dist-info → letta_nightly-0.8.16.dev20250721070720.dist-info}/METADATA +3 -2
  96. {letta_nightly-0.8.15.dev20250719104256.dist-info → letta_nightly-0.8.16.dev20250721070720.dist-info}/RECORD +99 -94
  97. {letta_nightly-0.8.15.dev20250719104256.dist-info → letta_nightly-0.8.16.dev20250721070720.dist-info}/LICENSE +0 -0
  98. {letta_nightly-0.8.15.dev20250719104256.dist-info → letta_nightly-0.8.16.dev20250721070720.dist-info}/WHEEL +0 -0
  99. {letta_nightly-0.8.15.dev20250719104256.dist-info → letta_nightly-0.8.16.dev20250721070720.dist-info}/entry_points.txt +0 -0
letta/orm/message.py CHANGED
@@ -11,7 +11,7 @@ from letta.schemas.letta_message_content import MessageContent
11
11
  from letta.schemas.letta_message_content import TextContent as PydanticTextContent
12
12
  from letta.schemas.message import Message as PydanticMessage
13
13
  from letta.schemas.message import ToolReturn
14
- from letta.settings import settings
14
+ from letta.settings import DatabaseChoice, settings
15
15
 
16
16
 
17
17
  class Message(SqlalchemyBase, OrganizationMixin, AgentMixin):
@@ -49,6 +49,9 @@ class Message(SqlalchemyBase, OrganizationMixin, AgentMixin):
49
49
  nullable=True,
50
50
  doc="The id of the LLMBatchItem that this message is associated with",
51
51
  )
52
+ is_err: Mapped[Optional[bool]] = mapped_column(
53
+ nullable=True, doc="Whether this message is part of an error step. Used only for debugging purposes."
54
+ )
52
55
 
53
56
  # Monotonically increasing sequence for efficient/correct listing
54
57
  sequence_id: Mapped[int] = mapped_column(
@@ -59,7 +62,7 @@ class Message(SqlalchemyBase, OrganizationMixin, AgentMixin):
59
62
  )
60
63
 
61
64
  # Relationships
62
- organization: Mapped["Organization"] = relationship("Organization", back_populates="messages", lazy="selectin")
65
+ organization: Mapped["Organization"] = relationship("Organization", back_populates="messages", lazy="raise")
63
66
  step: Mapped["Step"] = relationship("Step", back_populates="messages", lazy="selectin")
64
67
 
65
68
  # Job relationship
@@ -78,7 +81,7 @@ class Message(SqlalchemyBase, OrganizationMixin, AgentMixin):
78
81
  if self.text and not model.content:
79
82
  model.content = [PydanticTextContent(text=self.text)]
80
83
  # If there are no tool calls, set tool_calls to None
81
- if len(self.tool_calls) == 0:
84
+ if self.tool_calls is None or len(self.tool_calls) == 0:
82
85
  model.tool_calls = None
83
86
  return model
84
87
 
@@ -86,16 +89,139 @@ class Message(SqlalchemyBase, OrganizationMixin, AgentMixin):
86
89
  # listener
87
90
 
88
91
 
89
- @event.listens_for(Message, "before_insert")
90
- def set_sequence_id_for_sqlite(mapper, connection, target):
91
- # TODO: Kind of hacky, used to detect if we are using sqlite or not
92
- if not settings.letta_pg_uri_no_default:
93
- session = Session.object_session(target)
92
+ @event.listens_for(Session, "before_flush")
93
+ def set_sequence_id_for_sqlite_bulk(session, flush_context, instances):
94
+ # Handle bulk inserts for SQLite
95
+ if settings.database_engine is DatabaseChoice.SQLITE:
96
+ # Find all new Message objects that need sequence IDs
97
+ new_messages = [obj for obj in session.new if isinstance(obj, Message) and obj.sequence_id is None]
98
+
99
+ if new_messages:
100
+ # Create a sequence table if it doesn't exist for atomic increments
101
+ session.execute(
102
+ text(
103
+ """
104
+ CREATE TABLE IF NOT EXISTS message_sequence (
105
+ id INTEGER PRIMARY KEY,
106
+ next_val INTEGER NOT NULL DEFAULT 1
107
+ )
108
+ """
109
+ )
110
+ )
111
+
112
+ # Initialize the sequence table if empty
113
+ session.execute(
114
+ text(
115
+ """
116
+ INSERT OR IGNORE INTO message_sequence (id, next_val)
117
+ SELECT 1, COALESCE(MAX(sequence_id), 0) + 1
118
+ FROM messages
119
+ """
120
+ )
121
+ )
122
+
123
+ # Get the number of records being inserted
124
+ records_count = len(new_messages)
125
+
126
+ # Atomically reserve a range of sequence values for this batch
127
+ result = session.execute(
128
+ text(
129
+ """
130
+ UPDATE message_sequence
131
+ SET next_val = next_val + :count
132
+ WHERE id = 1
133
+ RETURNING next_val - :count
134
+ """
135
+ ),
136
+ {"count": records_count},
137
+ )
138
+
139
+ start_sequence_id = result.scalar()
140
+ if start_sequence_id is None:
141
+ # Fallback if RETURNING doesn't work (older SQLite versions)
142
+ session.execute(
143
+ text(
144
+ """
145
+ UPDATE message_sequence
146
+ SET next_val = next_val + :count
147
+ WHERE id = 1
148
+ """
149
+ ),
150
+ {"count": records_count},
151
+ )
152
+ start_sequence_id = session.execute(
153
+ text(
154
+ """
155
+ SELECT next_val - :count FROM message_sequence WHERE id = 1
156
+ """
157
+ ),
158
+ {"count": records_count},
159
+ ).scalar()
160
+
161
+ # Assign sequential IDs to each record
162
+ for i, obj in enumerate(new_messages):
163
+ obj.sequence_id = start_sequence_id + i
94
164
 
95
- if not hasattr(session, "_sequence_id_counter"):
96
- # Initialize counter for this flush
97
- max_seq = connection.scalar(text("SELECT MAX(sequence_id) FROM messages"))
98
- session._sequence_id_counter = max_seq or 0
99
165
 
100
- session._sequence_id_counter += 1
101
- target.sequence_id = session._sequence_id_counter
166
+ @event.listens_for(Message, "before_insert")
167
+ def set_sequence_id_for_sqlite(mapper, connection, target):
168
+ if settings.database_engine is DatabaseChoice.SQLITE:
169
+ # For SQLite, we need to generate sequence_id manually
170
+ # Use a database-level atomic operation to avoid race conditions
171
+
172
+ # Create a sequence table if it doesn't exist for atomic increments
173
+ connection.execute(
174
+ text(
175
+ """
176
+ CREATE TABLE IF NOT EXISTS message_sequence (
177
+ id INTEGER PRIMARY KEY,
178
+ next_val INTEGER NOT NULL DEFAULT 1
179
+ )
180
+ """
181
+ )
182
+ )
183
+
184
+ # Initialize the sequence table if empty
185
+ connection.execute(
186
+ text(
187
+ """
188
+ INSERT OR IGNORE INTO message_sequence (id, next_val)
189
+ SELECT 1, COALESCE(MAX(sequence_id), 0) + 1
190
+ FROM messages
191
+ """
192
+ )
193
+ )
194
+
195
+ # Atomically get the next sequence value
196
+ result = connection.execute(
197
+ text(
198
+ """
199
+ UPDATE message_sequence
200
+ SET next_val = next_val + 1
201
+ WHERE id = 1
202
+ RETURNING next_val - 1
203
+ """
204
+ )
205
+ )
206
+
207
+ sequence_id = result.scalar()
208
+ if sequence_id is None:
209
+ # Fallback if RETURNING doesn't work (older SQLite versions)
210
+ connection.execute(
211
+ text(
212
+ """
213
+ UPDATE message_sequence
214
+ SET next_val = next_val + 1
215
+ WHERE id = 1
216
+ """
217
+ )
218
+ )
219
+ sequence_id = connection.execute(
220
+ text(
221
+ """
222
+ SELECT next_val - 1 FROM message_sequence WHERE id = 1
223
+ """
224
+ )
225
+ ).scalar()
226
+
227
+ target.sequence_id = sequence_id
letta/orm/organization.py CHANGED
@@ -6,18 +6,17 @@ from letta.orm.sqlalchemy_base import SqlalchemyBase
6
6
  from letta.schemas.organization import Organization as PydanticOrganization
7
7
 
8
8
  if TYPE_CHECKING:
9
+ from letta.orm import Source
9
10
  from letta.orm.agent import Agent
10
- from letta.orm.agent_passage import AgentPassage
11
11
  from letta.orm.block import Block
12
12
  from letta.orm.group import Group
13
13
  from letta.orm.identity import Identity
14
- from letta.orm.llm_batch_item import LLMBatchItem
14
+ from letta.orm.llm_batch_items import LLMBatchItem
15
15
  from letta.orm.llm_batch_job import LLMBatchJob
16
16
  from letta.orm.message import Message
17
+ from letta.orm.passage import AgentPassage, SourcePassage
17
18
  from letta.orm.provider import Provider
18
- from letta.orm.sandbox_config import AgentEnvironmentVariable, SandboxConfig
19
- from letta.orm.sandbox_environment_variable import SandboxEnvironmentVariable
20
- from letta.orm.source_passage import SourcePassage
19
+ from letta.orm.sandbox_config import AgentEnvironmentVariable, SandboxConfig, SandboxEnvironmentVariable
21
20
  from letta.orm.tool import Tool
22
21
  from letta.orm.user import User
23
22
 
@@ -48,6 +47,7 @@ class Organization(SqlalchemyBase):
48
47
 
49
48
  # relationships
50
49
  agents: Mapped[List["Agent"]] = relationship("Agent", back_populates="organization", cascade="all, delete-orphan")
50
+ sources: Mapped[List["Source"]] = relationship("Source", cascade="all, delete-orphan")
51
51
  messages: Mapped[List["Message"]] = relationship("Message", back_populates="organization", cascade="all, delete-orphan")
52
52
  source_passages: Mapped[List["SourcePassage"]] = relationship(
53
53
  "SourcePassage", back_populates="organization", cascade="all, delete-orphan"
letta/orm/passage.py CHANGED
@@ -9,7 +9,7 @@ from letta.orm.custom_columns import CommonVector, EmbeddingConfigColumn
9
9
  from letta.orm.mixins import AgentMixin, FileMixin, OrganizationMixin, SourceMixin
10
10
  from letta.orm.sqlalchemy_base import SqlalchemyBase
11
11
  from letta.schemas.passage import Passage as PydanticPassage
12
- from letta.settings import settings
12
+ from letta.settings import DatabaseChoice, settings
13
13
 
14
14
  config = LettaConfig()
15
15
 
@@ -29,7 +29,7 @@ class BasePassage(SqlalchemyBase, OrganizationMixin):
29
29
  metadata_: Mapped[dict] = mapped_column(JSON, doc="Additional metadata")
30
30
 
31
31
  # Vector embedding field based on database type
32
- if settings.letta_pg_uri_no_default:
32
+ if settings.database_engine is DatabaseChoice.POSTGRES:
33
33
  from pgvector.sqlalchemy import Vector
34
34
 
35
35
  embedding = mapped_column(Vector(MAX_EMBEDDING_DIM))
@@ -56,7 +56,7 @@ class SourcePassage(BasePassage, FileMixin, SourceMixin):
56
56
  @declared_attr
57
57
  def __table_args__(cls):
58
58
  # TODO (cliandy): investigate if this is necessary, may be for SQLite compatability or do we need to add as well?
59
- if settings.letta_pg_uri_no_default:
59
+ if settings.database_engine is DatabaseChoice.POSTGRES:
60
60
  return (
61
61
  Index("source_passages_org_idx", "organization_id"),
62
62
  Index("source_passages_created_at_id_idx", "created_at", "id"),
@@ -81,7 +81,7 @@ class AgentPassage(BasePassage, AgentMixin):
81
81
 
82
82
  @declared_attr
83
83
  def __table_args__(cls):
84
- if settings.letta_pg_uri_no_default:
84
+ if settings.database_engine is DatabaseChoice.POSTGRES:
85
85
  return (
86
86
  Index("agent_passages_org_idx", "organization_id"),
87
87
  Index("ix_agent_passages_org_agent", "organization_id", "agent_id"),
letta/orm/source.py CHANGED
@@ -20,7 +20,7 @@ class Source(SqlalchemyBase, OrganizationMixin):
20
20
  __pydantic_model__ = PydanticSource
21
21
 
22
22
  __table_args__ = (
23
- Index(f"source_created_at_id_idx", "created_at", "id"),
23
+ Index("source_created_at_id_idx", "created_at", "id"),
24
24
  UniqueConstraint("name", "organization_id", name="uq_source_name_organization"),
25
25
  {"extend_existing": True},
26
26
  )
letta/orm/sqlalchemy_base.py CHANGED
@@ -5,7 +5,7 @@ from functools import wraps
5
5
  from pprint import pformat
6
6
  from typing import TYPE_CHECKING, List, Literal, Optional, Tuple, Union
7
7
 
8
- from sqlalchemy import Sequence, String, and_, delete, func, or_, select, text
8
+ from sqlalchemy import Sequence, String, and_, delete, func, or_, select
9
9
  from sqlalchemy.exc import DBAPIError, IntegrityError, TimeoutError
10
10
  from sqlalchemy.ext.asyncio import AsyncSession
11
11
  from sqlalchemy.orm import Mapped, Session, mapped_column
@@ -15,6 +15,7 @@ from letta.log import get_logger
15
15
  from letta.orm.base import Base, CommonSqlalchemyMetaMixins
16
16
  from letta.orm.errors import DatabaseTimeoutError, ForeignKeyConstraintViolationError, NoResultFound, UniqueConstraintViolationError
17
17
  from letta.orm.sqlite_functions import adapt_array
18
+ from letta.settings import DatabaseChoice
18
19
 
19
20
  if TYPE_CHECKING:
20
21
  from pydantic import BaseModel
@@ -353,10 +354,12 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
353
354
 
354
355
  if before_obj and after_obj:
355
356
  # Window-based query - get records between before and after
356
- conditions = [
357
- or_(cls.created_at < before_obj.created_at, and_(cls.created_at == before_obj.created_at, cls.id < before_obj.id)),
358
- or_(cls.created_at > after_obj.created_at, and_(cls.created_at == after_obj.created_at, cls.id > after_obj.id)),
359
- ]
357
+ conditions.append(
358
+ or_(cls.created_at < before_obj.created_at, and_(cls.created_at == before_obj.created_at, cls.id < before_obj.id))
359
+ )
360
+ conditions.append(
361
+ or_(cls.created_at > after_obj.created_at, and_(cls.created_at == after_obj.created_at, cls.id > after_obj.id))
362
+ )
360
363
  else:
361
364
  # Pure pagination query
362
365
  if before_obj:
@@ -393,7 +396,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
393
396
 
394
397
  from letta.settings import settings
395
398
 
396
- if settings.letta_pg_uri_no_default:
399
+ if settings.database_engine is DatabaseChoice.POSTGRES:
397
400
  # PostgreSQL with pgvector
398
401
  query = query.order_by(cls.embedding.cosine_distance(query_embedding).asc())
399
402
  else:
@@ -509,14 +512,9 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
509
512
  query, query_conditions = cls._read_multiple_preprocess(identifiers, actor, access, access_type, check_is_deleted, **kwargs)
510
513
  if query is None:
511
514
  raise NoResultFound(f"{cls.__name__} not found with identifier {identifier}")
512
- if is_postgresql_session(db_session):
513
- await db_session.execute(text("SET LOCAL enable_seqscan = OFF"))
514
- try:
515
- result = await db_session.execute(query)
516
- item = result.scalar_one_or_none()
517
- finally:
518
- if is_postgresql_session(db_session):
519
- await db_session.execute(text("SET LOCAL enable_seqscan = ON"))
515
+
516
+ result = await db_session.execute(query)
517
+ item = result.scalar_one_or_none()
520
518
 
521
519
  if item is None:
522
520
  raise NoResultFound(f"{cls.__name__} not found with {', '.join(query_conditions if query_conditions else ['no conditions'])}")
@@ -656,7 +654,13 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
656
654
  self._handle_dbapi_error(e)
657
655
 
658
656
  @handle_db_timeout
659
- async def create_async(self, db_session: "AsyncSession", actor: Optional["User"] = None, no_commit: bool = False) -> "SqlalchemyBase":
657
+ async def create_async(
658
+ self,
659
+ db_session: "AsyncSession",
660
+ actor: Optional["User"] = None,
661
+ no_commit: bool = False,
662
+ no_refresh: bool = False,
663
+ ) -> "SqlalchemyBase":
660
664
  """Async version of create function"""
661
665
  logger.debug(f"Creating {self.__class__.__name__} with ID: {self.id} with actor={actor}")
662
666
 
@@ -668,7 +672,9 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
668
672
  await db_session.flush() # no commit, just flush to get PK
669
673
  else:
670
674
  await db_session.commit()
671
- await db_session.refresh(self)
675
+
676
+ if not no_refresh:
677
+ await db_session.refresh(self)
672
678
  return self
673
679
  except (DBAPIError, IntegrityError) as e:
674
680
  self._handle_dbapi_error(e)
@@ -717,7 +723,12 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
717
723
  @classmethod
718
724
  @handle_db_timeout
719
725
  async def batch_create_async(
720
- cls, items: List["SqlalchemyBase"], db_session: "AsyncSession", actor: Optional["User"] = None
726
+ cls,
727
+ items: List["SqlalchemyBase"],
728
+ db_session: "AsyncSession",
729
+ actor: Optional["User"] = None,
730
+ no_commit: bool = False,
731
+ no_refresh: bool = False,
721
732
  ) -> List["SqlalchemyBase"]:
722
733
  """
723
734
  Async version of batch_create method.
@@ -726,10 +737,13 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
726
737
  items: List of model instances to create
727
738
  db_session: AsyncSession session
728
739
  actor: Optional user performing the action
740
+ no_commit: Whether to commit the transaction
741
+ no_refresh: Whether to refresh the created objects
729
742
  Returns:
730
743
  List of created model instances
731
744
  """
732
745
  logger.debug(f"Async batch creating {len(items)} {cls.__name__} items with actor={actor}")
746
+
733
747
  if not items:
734
748
  return []
735
749
 
@@ -740,21 +754,22 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
740
754
 
741
755
  try:
742
756
  db_session.add_all(items)
743
- await db_session.flush() # Flush to generate IDs but don't commit yet
744
-
745
- # Collect IDs to fetch the complete objects after commit
746
- item_ids = [item.id for item in items]
747
-
748
- await db_session.commit()
749
-
750
- # Re-query the objects to get them with relationships loaded
751
- query = select(cls).where(cls.id.in_(item_ids))
752
- if hasattr(cls, "created_at"):
753
- query = query.order_by(cls.created_at)
757
+ if no_commit:
758
+ await db_session.flush()
759
+ else:
760
+ await db_session.commit()
754
761
 
755
- result = await db_session.execute(query)
756
- return list(result.scalars())
762
+ if no_refresh:
763
+ return items
764
+ else:
765
+ # Re-query the objects to get them with relationships loaded
766
+ item_ids = [item.id for item in items]
767
+ query = select(cls).where(cls.id.in_(item_ids))
768
+ if hasattr(cls, "created_at"):
769
+ query = query.order_by(cls.created_at)
757
770
 
771
+ result = await db_session.execute(query)
772
+ return list(result.scalars())
758
773
  except (DBAPIError, IntegrityError) as e:
759
774
  cls._handle_dbapi_error(e)
760
775
 
@@ -854,20 +869,27 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
854
869
  return self
855
870
 
856
871
  @handle_db_timeout
857
- async def update_async(self, db_session: AsyncSession, actor: "User | None" = None, no_commit: bool = False) -> "SqlalchemyBase":
872
+ async def update_async(
873
+ self, db_session: "AsyncSession", actor: Optional["User"] = None, no_commit: bool = False, no_refresh: bool = False
874
+ ) -> "SqlalchemyBase":
858
875
  """Async version of update function"""
859
- logger.debug(...)
876
+ logger.debug(f"Updating {self.__class__.__name__} with ID: {self.id} with actor={actor}")
877
+
860
878
  if actor:
861
879
  self._set_created_and_updated_by_fields(actor.id)
862
880
  self.set_updated_at()
881
+ try:
882
+ db_session.add(self)
883
+ if no_commit:
884
+ await db_session.flush()
885
+ else:
886
+ await db_session.commit()
863
887
 
864
- db_session.add(self)
865
- if no_commit:
866
- await db_session.flush()
867
- else:
868
- await db_session.commit()
869
- await db_session.refresh(self)
870
- return self
888
+ if not no_refresh:
889
+ await db_session.refresh(self)
890
+ return self
891
+ except (DBAPIError, IntegrityError) as e:
892
+ self._handle_dbapi_error(e)
871
893
 
872
894
  @classmethod
873
895
  def _size_preprocess(
letta/orm/step.py CHANGED
@@ -5,6 +5,7 @@ from sqlalchemy import JSON, ForeignKey, String
5
5
  from sqlalchemy.orm import Mapped, mapped_column, relationship
6
6
 
7
7
  from letta.orm.sqlalchemy_base import SqlalchemyBase
8
+ from letta.schemas.letta_stop_reason import StopReasonType
8
9
  from letta.schemas.step import Step as PydanticStep
9
10
 
10
11
  if TYPE_CHECKING:
@@ -45,6 +46,7 @@ class Step(SqlalchemyBase):
45
46
  prompt_tokens: Mapped[int] = mapped_column(default=0, doc="Number of tokens in the prompt")
46
47
  total_tokens: Mapped[int] = mapped_column(default=0, doc="Total number of tokens processed by the agent")
47
48
  completion_tokens_details: Mapped[Optional[Dict]] = mapped_column(JSON, nullable=True, doc="metadata for the agent.")
49
+ stop_reason: Mapped[Optional[StopReasonType]] = mapped_column(None, nullable=True, doc="The stop reason associated with this step.")
48
50
  tags: Mapped[Optional[List]] = mapped_column(JSON, doc="Metadata tags.")
49
51
  tid: Mapped[Optional[str]] = mapped_column(None, nullable=True, doc="Transaction ID that processed the step.")
50
52
  trace_id: Mapped[Optional[str]] = mapped_column(None, nullable=True, doc="The trace id of the agent step.")