letta-nightly 0.6.1.dev20241206104246__py3-none-any.whl → 0.6.1.dev20241207104149__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of letta-nightly might be problematic. Click here for more details.
- letta/agent.py +54 -37
- letta/agent_store/db.py +1 -77
- letta/agent_store/storage.py +0 -5
- letta/cli/cli.py +0 -1
- letta/client/client.py +3 -7
- letta/constants.py +1 -0
- letta/functions/function_sets/base.py +33 -5
- letta/main.py +2 -2
- letta/memory.py +4 -82
- letta/metadata.py +0 -35
- letta/o1_agent.py +7 -2
- letta/offline_memory_agent.py +6 -0
- letta/orm/__init__.py +2 -0
- letta/orm/file.py +1 -1
- letta/orm/message.py +66 -0
- letta/orm/mixins.py +16 -0
- letta/orm/organization.py +1 -0
- letta/orm/sqlalchemy_base.py +118 -26
- letta/schemas/letta_base.py +7 -6
- letta/schemas/message.py +1 -7
- letta/server/rest_api/routers/v1/agents.py +2 -2
- letta/server/rest_api/routers/v1/blocks.py +2 -2
- letta/server/server.py +52 -50
- letta/server/static_files/assets/index-43ab4d62.css +1 -0
- letta/server/static_files/assets/index-4848e3d7.js +40 -0
- letta/server/static_files/index.html +2 -2
- letta/services/block_manager.py +1 -1
- letta/services/message_manager.py +182 -0
- letta/services/organization_manager.py +6 -9
- letta/services/source_manager.py +1 -1
- letta/services/tool_manager.py +1 -1
- letta/services/user_manager.py +1 -1
- {letta_nightly-0.6.1.dev20241206104246.dist-info → letta_nightly-0.6.1.dev20241207104149.dist-info}/METADATA +1 -1
- {letta_nightly-0.6.1.dev20241206104246.dist-info → letta_nightly-0.6.1.dev20241207104149.dist-info}/RECORD +37 -37
- letta/agent_store/lancedb.py +0 -177
- letta/persistence_manager.py +0 -149
- letta/server/static_files/assets/index-1b5d1a41.js +0 -271
- letta/server/static_files/assets/index-56a3f8c6.css +0 -1
- {letta_nightly-0.6.1.dev20241206104246.dist-info → letta_nightly-0.6.1.dev20241207104149.dist-info}/LICENSE +0 -0
- {letta_nightly-0.6.1.dev20241206104246.dist-info → letta_nightly-0.6.1.dev20241207104149.dist-info}/WHEEL +0 -0
- {letta_nightly-0.6.1.dev20241206104246.dist-info → letta_nightly-0.6.1.dev20241207104149.dist-info}/entry_points.txt +0 -0
letta/memory.py
CHANGED
|
@@ -67,14 +67,12 @@ def summarize_messages(
|
|
|
67
67
|
+ message_sequence_to_summarize[cutoff:]
|
|
68
68
|
)
|
|
69
69
|
|
|
70
|
-
|
|
70
|
+
agent_state.user_id
|
|
71
71
|
dummy_agent_id = agent_state.id
|
|
72
72
|
message_sequence = []
|
|
73
|
-
message_sequence.append(Message(
|
|
74
|
-
message_sequence.append(
|
|
75
|
-
|
|
76
|
-
)
|
|
77
|
-
message_sequence.append(Message(user_id=dummy_user_id, agent_id=dummy_agent_id, role=MessageRole.user, text=summary_input))
|
|
73
|
+
message_sequence.append(Message(agent_id=dummy_agent_id, role=MessageRole.system, text=summary_prompt))
|
|
74
|
+
message_sequence.append(Message(agent_id=dummy_agent_id, role=MessageRole.assistant, text=MESSAGE_SUMMARY_REQUEST_ACK))
|
|
75
|
+
message_sequence.append(Message(agent_id=dummy_agent_id, role=MessageRole.user, text=summary_input))
|
|
78
76
|
|
|
79
77
|
# TODO: We need to eventually have a separate LLM config for the summarizer LLM
|
|
80
78
|
llm_config_no_inner_thoughts = agent_state.llm_config.model_copy(deep=True)
|
|
@@ -252,82 +250,6 @@ class DummyRecallMemory(RecallMemory):
|
|
|
252
250
|
return matches, len(matches)
|
|
253
251
|
|
|
254
252
|
|
|
255
|
-
class BaseRecallMemory(RecallMemory):
|
|
256
|
-
"""Recall memory based on base functions implemented by storage connectors"""
|
|
257
|
-
|
|
258
|
-
def __init__(self, agent_state, restrict_search_to_summaries=False):
|
|
259
|
-
# If true, the pool of messages that can be queried are the automated summaries only
|
|
260
|
-
# (generated when the conversation window needs to be shortened)
|
|
261
|
-
self.restrict_search_to_summaries = restrict_search_to_summaries
|
|
262
|
-
from letta.agent_store.storage import StorageConnector
|
|
263
|
-
|
|
264
|
-
self.agent_state = agent_state
|
|
265
|
-
|
|
266
|
-
# create embedding model
|
|
267
|
-
self.embed_model = embedding_model(agent_state.embedding_config)
|
|
268
|
-
self.embedding_chunk_size = agent_state.embedding_config.embedding_chunk_size
|
|
269
|
-
|
|
270
|
-
# create storage backend
|
|
271
|
-
self.storage = StorageConnector.get_recall_storage_connector(user_id=agent_state.user_id, agent_id=agent_state.id)
|
|
272
|
-
# TODO: have some mechanism for cleanup otherwise will lead to OOM
|
|
273
|
-
self.cache = {}
|
|
274
|
-
|
|
275
|
-
def get_all(self, start=0, count=None):
|
|
276
|
-
start = 0 if start is None else int(start)
|
|
277
|
-
count = 0 if count is None else int(count)
|
|
278
|
-
results = self.storage.get_all(start, count)
|
|
279
|
-
results_json = [message.to_openai_dict() for message in results]
|
|
280
|
-
return results_json, len(results)
|
|
281
|
-
|
|
282
|
-
def text_search(self, query_string, count=None, start=None):
|
|
283
|
-
start = 0 if start is None else int(start)
|
|
284
|
-
count = 0 if count is None else int(count)
|
|
285
|
-
results = self.storage.query_text(query_string, count, start)
|
|
286
|
-
results_json = [message.to_openai_dict_search_results() for message in results]
|
|
287
|
-
return results_json, len(results)
|
|
288
|
-
|
|
289
|
-
def date_search(self, start_date, end_date, count=None, start=None):
|
|
290
|
-
start = 0 if start is None else int(start)
|
|
291
|
-
count = 0 if count is None else int(count)
|
|
292
|
-
results = self.storage.query_date(start_date, end_date, count, start)
|
|
293
|
-
results_json = [message.to_openai_dict_search_results() for message in results]
|
|
294
|
-
return results_json, len(results)
|
|
295
|
-
|
|
296
|
-
def compile(self) -> str:
|
|
297
|
-
total = self.storage.size()
|
|
298
|
-
system_count = self.storage.size(filters={"role": "system"})
|
|
299
|
-
user_count = self.storage.size(filters={"role": "user"})
|
|
300
|
-
assistant_count = self.storage.size(filters={"role": "assistant"})
|
|
301
|
-
function_count = self.storage.size(filters={"role": "function"})
|
|
302
|
-
other_count = total - (system_count + user_count + assistant_count + function_count)
|
|
303
|
-
|
|
304
|
-
memory_str = (
|
|
305
|
-
f"Statistics:"
|
|
306
|
-
+ f"\n{total} total messages"
|
|
307
|
-
+ f"\n{system_count} system"
|
|
308
|
-
+ f"\n{user_count} user"
|
|
309
|
-
+ f"\n{assistant_count} assistant"
|
|
310
|
-
+ f"\n{function_count} function"
|
|
311
|
-
+ f"\n{other_count} other"
|
|
312
|
-
)
|
|
313
|
-
return f"\n### RECALL MEMORY ###" + f"\n{memory_str}"
|
|
314
|
-
|
|
315
|
-
def insert(self, message: Message):
|
|
316
|
-
self.storage.insert(message)
|
|
317
|
-
|
|
318
|
-
def insert_many(self, messages: List[Message]):
|
|
319
|
-
self.storage.insert_many(messages)
|
|
320
|
-
|
|
321
|
-
def save(self):
|
|
322
|
-
self.storage.save()
|
|
323
|
-
|
|
324
|
-
def __len__(self):
|
|
325
|
-
return self.storage.size()
|
|
326
|
-
|
|
327
|
-
def count(self) -> int:
|
|
328
|
-
return len(self)
|
|
329
|
-
|
|
330
|
-
|
|
331
253
|
class EmbeddingArchivalMemory(ArchivalMemory):
|
|
332
254
|
"""Archival memory with embedding based search"""
|
|
333
255
|
|
letta/metadata.py
CHANGED
|
@@ -14,7 +14,6 @@ from letta.schemas.api_key import APIKey
|
|
|
14
14
|
from letta.schemas.embedding_config import EmbeddingConfig
|
|
15
15
|
from letta.schemas.enums import ToolRuleType
|
|
16
16
|
from letta.schemas.llm_config import LLMConfig
|
|
17
|
-
from letta.schemas.openai.chat_completions import ToolCall, ToolCallFunction
|
|
18
17
|
from letta.schemas.tool_rule import ChildToolRule, InitToolRule, TerminalToolRule
|
|
19
18
|
from letta.schemas.user import User
|
|
20
19
|
from letta.services.per_agent_lock_manager import PerAgentLockManager
|
|
@@ -66,40 +65,6 @@ class EmbeddingConfigColumn(TypeDecorator):
|
|
|
66
65
|
return value
|
|
67
66
|
|
|
68
67
|
|
|
69
|
-
class ToolCallColumn(TypeDecorator):
|
|
70
|
-
|
|
71
|
-
impl = JSON
|
|
72
|
-
cache_ok = True
|
|
73
|
-
|
|
74
|
-
def load_dialect_impl(self, dialect):
|
|
75
|
-
return dialect.type_descriptor(JSON())
|
|
76
|
-
|
|
77
|
-
def process_bind_param(self, value, dialect):
|
|
78
|
-
if value:
|
|
79
|
-
values = []
|
|
80
|
-
for v in value:
|
|
81
|
-
if isinstance(v, ToolCall):
|
|
82
|
-
values.append(v.model_dump())
|
|
83
|
-
else:
|
|
84
|
-
values.append(v)
|
|
85
|
-
return values
|
|
86
|
-
|
|
87
|
-
return value
|
|
88
|
-
|
|
89
|
-
def process_result_value(self, value, dialect):
|
|
90
|
-
if value:
|
|
91
|
-
tools = []
|
|
92
|
-
for tool_value in value:
|
|
93
|
-
if "function" in tool_value:
|
|
94
|
-
tool_call_function = ToolCallFunction(**tool_value["function"])
|
|
95
|
-
del tool_value["function"]
|
|
96
|
-
else:
|
|
97
|
-
tool_call_function = None
|
|
98
|
-
tools.append(ToolCall(function=tool_call_function, **tool_value))
|
|
99
|
-
return tools
|
|
100
|
-
return value
|
|
101
|
-
|
|
102
|
-
|
|
103
68
|
# TODO: eventually store providers?
|
|
104
69
|
# class Provider(Base):
|
|
105
70
|
# __tablename__ = "providers"
|
letta/o1_agent.py
CHANGED
|
@@ -20,7 +20,7 @@ def send_thinking_message(self: "Agent", message: str) -> Optional[str]:
|
|
|
20
20
|
Returns:
|
|
21
21
|
Optional[str]: None is always returned as this function does not produce a response.
|
|
22
22
|
"""
|
|
23
|
-
self.interface.internal_monologue(message
|
|
23
|
+
self.interface.internal_monologue(message)
|
|
24
24
|
return None
|
|
25
25
|
|
|
26
26
|
|
|
@@ -34,7 +34,7 @@ def send_final_message(self: "Agent", message: str) -> Optional[str]:
|
|
|
34
34
|
Returns:
|
|
35
35
|
Optional[str]: None is always returned as this function does not produce a response.
|
|
36
36
|
"""
|
|
37
|
-
self.interface.internal_monologue(message
|
|
37
|
+
self.interface.internal_monologue(message)
|
|
38
38
|
return None
|
|
39
39
|
|
|
40
40
|
|
|
@@ -62,10 +62,15 @@ class O1Agent(Agent):
|
|
|
62
62
|
"""Run Agent.inner_step in a loop, terminate when final thinking message is sent or max_thinking_steps is reached"""
|
|
63
63
|
# assert ms is not None, "MetadataStore is required"
|
|
64
64
|
next_input_message = messages if isinstance(messages, list) else [messages]
|
|
65
|
+
|
|
65
66
|
counter = 0
|
|
66
67
|
total_usage = UsageStatistics()
|
|
67
68
|
step_count = 0
|
|
68
69
|
while step_count < self.max_thinking_steps:
|
|
70
|
+
# This is hacky but we need to do this for now
|
|
71
|
+
for m in next_input_message:
|
|
72
|
+
m.id = m._generate_id()
|
|
73
|
+
|
|
69
74
|
kwargs["ms"] = ms
|
|
70
75
|
kwargs["first_message"] = False
|
|
71
76
|
step_response = self.inner_step(
|
letta/offline_memory_agent.py
CHANGED
|
@@ -18,6 +18,7 @@ def trigger_rethink_memory(agent_state: "AgentState", message: Optional[str]) ->
|
|
|
18
18
|
|
|
19
19
|
"""
|
|
20
20
|
from letta import create_client
|
|
21
|
+
|
|
21
22
|
client = create_client()
|
|
22
23
|
agents = client.list_agents()
|
|
23
24
|
for agent in agents:
|
|
@@ -149,6 +150,11 @@ class OfflineMemoryAgent(Agent):
|
|
|
149
150
|
step_count = 0
|
|
150
151
|
|
|
151
152
|
while counter < self.max_memory_rethinks:
|
|
153
|
+
# This is hacky but we need to do this for now
|
|
154
|
+
# TODO: REMOVE THIS
|
|
155
|
+
for m in next_input_message:
|
|
156
|
+
m.id = m._generate_id()
|
|
157
|
+
|
|
152
158
|
kwargs["ms"] = ms
|
|
153
159
|
kwargs["first_message"] = False
|
|
154
160
|
step_response = self.inner_step(
|
letta/orm/__init__.py
CHANGED
|
@@ -1,8 +1,10 @@
|
|
|
1
|
+
from letta.orm.agents_tags import AgentsTags
|
|
1
2
|
from letta.orm.base import Base
|
|
2
3
|
from letta.orm.block import Block
|
|
3
4
|
from letta.orm.blocks_agents import BlocksAgents
|
|
4
5
|
from letta.orm.file import FileMetadata
|
|
5
6
|
from letta.orm.job import Job
|
|
7
|
+
from letta.orm.message import Message
|
|
6
8
|
from letta.orm.organization import Organization
|
|
7
9
|
from letta.orm.sandbox_config import SandboxConfig, SandboxEnvironmentVariable
|
|
8
10
|
from letta.orm.source import Source
|
letta/orm/file.py
CHANGED
letta/orm/message.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
1
|
+
from datetime import datetime
|
|
2
|
+
from typing import Optional
|
|
3
|
+
|
|
4
|
+
from sqlalchemy import JSON, DateTime, TypeDecorator
|
|
5
|
+
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
|
6
|
+
|
|
7
|
+
from letta.orm.mixins import AgentMixin, OrganizationMixin
|
|
8
|
+
from letta.orm.sqlalchemy_base import SqlalchemyBase
|
|
9
|
+
from letta.schemas.message import Message as PydanticMessage
|
|
10
|
+
from letta.schemas.openai.chat_completions import ToolCall, ToolCallFunction
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class ToolCallColumn(TypeDecorator):
|
|
14
|
+
|
|
15
|
+
impl = JSON
|
|
16
|
+
cache_ok = True
|
|
17
|
+
|
|
18
|
+
def load_dialect_impl(self, dialect):
|
|
19
|
+
return dialect.type_descriptor(JSON())
|
|
20
|
+
|
|
21
|
+
def process_bind_param(self, value, dialect):
|
|
22
|
+
if value:
|
|
23
|
+
values = []
|
|
24
|
+
for v in value:
|
|
25
|
+
if isinstance(v, ToolCall):
|
|
26
|
+
values.append(v.model_dump())
|
|
27
|
+
else:
|
|
28
|
+
values.append(v)
|
|
29
|
+
return values
|
|
30
|
+
|
|
31
|
+
return value
|
|
32
|
+
|
|
33
|
+
def process_result_value(self, value, dialect):
|
|
34
|
+
if value:
|
|
35
|
+
tools = []
|
|
36
|
+
for tool_value in value:
|
|
37
|
+
if "function" in tool_value:
|
|
38
|
+
tool_call_function = ToolCallFunction(**tool_value["function"])
|
|
39
|
+
del tool_value["function"]
|
|
40
|
+
else:
|
|
41
|
+
tool_call_function = None
|
|
42
|
+
tools.append(ToolCall(function=tool_call_function, **tool_value))
|
|
43
|
+
return tools
|
|
44
|
+
return value
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
class Message(SqlalchemyBase, OrganizationMixin, AgentMixin):
|
|
48
|
+
"""Defines data model for storing Message objects"""
|
|
49
|
+
|
|
50
|
+
__tablename__ = "messages"
|
|
51
|
+
__table_args__ = {"extend_existing": True}
|
|
52
|
+
__pydantic_model__ = PydanticMessage
|
|
53
|
+
|
|
54
|
+
id: Mapped[str] = mapped_column(primary_key=True, doc="Unique message identifier")
|
|
55
|
+
role: Mapped[str] = mapped_column(doc="Message role (user/assistant/system/tool)")
|
|
56
|
+
text: Mapped[Optional[str]] = mapped_column(nullable=True, doc="Message content")
|
|
57
|
+
model: Mapped[Optional[str]] = mapped_column(nullable=True, doc="LLM model used")
|
|
58
|
+
name: Mapped[Optional[str]] = mapped_column(nullable=True, doc="Name for multi-agent scenarios")
|
|
59
|
+
tool_calls: Mapped[ToolCall] = mapped_column(ToolCallColumn, doc="Tool call information")
|
|
60
|
+
tool_call_id: Mapped[Optional[str]] = mapped_column(nullable=True, doc="ID of the tool call")
|
|
61
|
+
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=datetime.utcnow)
|
|
62
|
+
|
|
63
|
+
# Relationships
|
|
64
|
+
# TODO: Add in after Agent ORM is created
|
|
65
|
+
# agent: Mapped["Agent"] = relationship("Agent", back_populates="messages", lazy="selectin")
|
|
66
|
+
organization: Mapped["Organization"] = relationship("Organization", back_populates="messages", lazy="selectin")
|
letta/orm/mixins.py
CHANGED
|
@@ -31,6 +31,22 @@ class UserMixin(Base):
|
|
|
31
31
|
user_id: Mapped[str] = mapped_column(String, ForeignKey("users.id"))
|
|
32
32
|
|
|
33
33
|
|
|
34
|
+
class AgentMixin(Base):
|
|
35
|
+
"""Mixin for models that belong to an agent."""
|
|
36
|
+
|
|
37
|
+
__abstract__ = True
|
|
38
|
+
|
|
39
|
+
agent_id: Mapped[str] = mapped_column(String, ForeignKey("agents.id"))
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class FileMixin(Base):
|
|
43
|
+
"""Mixin for models that belong to a file."""
|
|
44
|
+
|
|
45
|
+
__abstract__ = True
|
|
46
|
+
|
|
47
|
+
file_id: Mapped[str] = mapped_column(String, ForeignKey("files.id"))
|
|
48
|
+
|
|
49
|
+
|
|
34
50
|
class SourceMixin(Base):
|
|
35
51
|
"""Mixin for models (e.g. file) that belong to a source."""
|
|
36
52
|
|
letta/orm/organization.py
CHANGED
|
@@ -33,6 +33,7 @@ class Organization(SqlalchemyBase):
|
|
|
33
33
|
sandbox_environment_variables: Mapped[List["SandboxEnvironmentVariable"]] = relationship(
|
|
34
34
|
"SandboxEnvironmentVariable", back_populates="organization", cascade="all, delete-orphan"
|
|
35
35
|
)
|
|
36
|
+
messages: Mapped[List["Message"]] = relationship("Message", back_populates="organization", cascade="all, delete-orphan")
|
|
36
37
|
|
|
37
38
|
# TODO: Map these relationships later when we actually make these models
|
|
38
39
|
# below is just a suggestion
|
letta/orm/sqlalchemy_base.py
CHANGED
|
@@ -1,8 +1,10 @@
|
|
|
1
|
+
from datetime import datetime
|
|
2
|
+
from enum import Enum
|
|
1
3
|
from typing import TYPE_CHECKING, List, Literal, Optional, Type
|
|
2
4
|
|
|
3
|
-
from sqlalchemy import String, select
|
|
5
|
+
from sqlalchemy import String, func, select
|
|
4
6
|
from sqlalchemy.exc import DBAPIError
|
|
5
|
-
from sqlalchemy.orm import Mapped, mapped_column
|
|
7
|
+
from sqlalchemy.orm import Mapped, Session, mapped_column
|
|
6
8
|
|
|
7
9
|
from letta.log import get_logger
|
|
8
10
|
from letta.orm.base import Base, CommonSqlalchemyMetaMixins
|
|
@@ -20,6 +22,11 @@ if TYPE_CHECKING:
|
|
|
20
22
|
logger = get_logger(__name__)
|
|
21
23
|
|
|
22
24
|
|
|
25
|
+
class AccessType(str, Enum):
|
|
26
|
+
ORGANIZATION = "organization"
|
|
27
|
+
USER = "user"
|
|
28
|
+
|
|
29
|
+
|
|
23
30
|
class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
|
24
31
|
__abstract__ = True
|
|
25
32
|
|
|
@@ -28,46 +35,68 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
|
|
28
35
|
id: Mapped[str] = mapped_column(String, primary_key=True)
|
|
29
36
|
|
|
30
37
|
@classmethod
|
|
31
|
-
def
|
|
32
|
-
|
|
33
|
-
) -> List[Type["SqlalchemyBase"]]:
|
|
34
|
-
"""
|
|
35
|
-
List records with optional cursor (for pagination), limit, and automatic filtering.
|
|
38
|
+
def get(cls, *, db_session: Session, id: str) -> Optional["SqlalchemyBase"]:
|
|
39
|
+
"""Get a record by ID.
|
|
36
40
|
|
|
37
41
|
Args:
|
|
38
|
-
db_session:
|
|
39
|
-
|
|
40
|
-
limit: Maximum number of records to return.
|
|
41
|
-
**kwargs: Filters passed as equality conditions or iterable for IN filtering.
|
|
42
|
+
db_session: SQLAlchemy session
|
|
43
|
+
id: Record ID to retrieve
|
|
42
44
|
|
|
43
45
|
Returns:
|
|
44
|
-
|
|
46
|
+
Optional[SqlalchemyBase]: The record if found, None otherwise
|
|
45
47
|
"""
|
|
46
|
-
|
|
48
|
+
try:
|
|
49
|
+
return db_session.query(cls).filter(cls.id == id).first()
|
|
50
|
+
except DBAPIError:
|
|
51
|
+
return None
|
|
52
|
+
|
|
53
|
+
@classmethod
|
|
54
|
+
def list(
|
|
55
|
+
cls,
|
|
56
|
+
*,
|
|
57
|
+
db_session: "Session",
|
|
58
|
+
cursor: Optional[str] = None,
|
|
59
|
+
start_date: Optional[datetime] = None,
|
|
60
|
+
end_date: Optional[datetime] = None,
|
|
61
|
+
limit: Optional[int] = 50,
|
|
62
|
+
query_text: Optional[str] = None,
|
|
63
|
+
**kwargs,
|
|
64
|
+
) -> List[Type["SqlalchemyBase"]]:
|
|
65
|
+
"""List records with advanced filtering and pagination options."""
|
|
66
|
+
if start_date and end_date and start_date > end_date:
|
|
67
|
+
raise ValueError("start_date must be earlier than or equal to end_date")
|
|
68
|
+
|
|
69
|
+
logger.debug(f"Listing {cls.__name__} with kwarg filters {kwargs}")
|
|
47
70
|
with db_session as session:
|
|
48
|
-
# Start with a base query
|
|
49
71
|
query = select(cls)
|
|
50
72
|
|
|
51
73
|
# Apply filtering logic
|
|
52
74
|
for key, value in kwargs.items():
|
|
53
75
|
column = getattr(cls, key)
|
|
54
|
-
if isinstance(value, (list, tuple, set)):
|
|
76
|
+
if isinstance(value, (list, tuple, set)):
|
|
55
77
|
query = query.where(column.in_(value))
|
|
56
|
-
else:
|
|
78
|
+
else:
|
|
57
79
|
query = query.where(column == value)
|
|
58
80
|
|
|
59
|
-
#
|
|
81
|
+
# Date range filtering
|
|
82
|
+
if start_date:
|
|
83
|
+
query = query.filter(cls.created_at >= start_date)
|
|
84
|
+
if end_date:
|
|
85
|
+
query = query.filter(cls.created_at <= end_date)
|
|
86
|
+
|
|
87
|
+
# Cursor-based pagination
|
|
60
88
|
if cursor:
|
|
61
89
|
query = query.where(cls.id > cursor)
|
|
62
90
|
|
|
63
|
-
#
|
|
91
|
+
# Apply text search
|
|
92
|
+
if query_text:
|
|
93
|
+
query = query.filter(func.lower(cls.text).contains(func.lower(query_text)))
|
|
94
|
+
|
|
95
|
+
# Handle ordering and soft deletes
|
|
64
96
|
if hasattr(cls, "is_deleted"):
|
|
65
97
|
query = query.where(cls.is_deleted == False)
|
|
66
|
-
|
|
67
|
-
# Add ordering and limit
|
|
68
98
|
query = query.order_by(cls.id).limit(limit)
|
|
69
99
|
|
|
70
|
-
# Execute the query and return results as model instances
|
|
71
100
|
return list(session.execute(query).scalars())
|
|
72
101
|
|
|
73
102
|
@classmethod
|
|
@@ -77,6 +106,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
|
|
77
106
|
identifier: Optional[str] = None,
|
|
78
107
|
actor: Optional["User"] = None,
|
|
79
108
|
access: Optional[List[Literal["read", "write", "admin"]]] = ["read"],
|
|
109
|
+
access_type: AccessType = AccessType.ORGANIZATION,
|
|
80
110
|
**kwargs,
|
|
81
111
|
) -> Type["SqlalchemyBase"]:
|
|
82
112
|
"""The primary accessor for an ORM record.
|
|
@@ -108,7 +138,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
|
|
108
138
|
query_conditions.append(", ".join(f"{key}='{value}'" for key, value in kwargs.items()))
|
|
109
139
|
|
|
110
140
|
if actor:
|
|
111
|
-
query = cls.apply_access_predicate(query, actor, access)
|
|
141
|
+
query = cls.apply_access_predicate(query, actor, access, access_type)
|
|
112
142
|
query_conditions.append(f"access level in {access} for actor='{actor}'")
|
|
113
143
|
|
|
114
144
|
if hasattr(cls, "is_deleted"):
|
|
@@ -170,12 +200,66 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
|
|
170
200
|
session.refresh(self)
|
|
171
201
|
return self
|
|
172
202
|
|
|
203
|
+
@classmethod
|
|
204
|
+
def size(
|
|
205
|
+
cls,
|
|
206
|
+
*,
|
|
207
|
+
db_session: "Session",
|
|
208
|
+
actor: Optional["User"] = None,
|
|
209
|
+
access: Optional[List[Literal["read", "write", "admin"]]] = ["read"],
|
|
210
|
+
access_type: AccessType = AccessType.ORGANIZATION,
|
|
211
|
+
**kwargs,
|
|
212
|
+
) -> int:
|
|
213
|
+
"""
|
|
214
|
+
Get the count of rows that match the provided filters.
|
|
215
|
+
|
|
216
|
+
Args:
|
|
217
|
+
db_session: SQLAlchemy session
|
|
218
|
+
**kwargs: Filters to apply to the query (e.g., column_name=value)
|
|
219
|
+
|
|
220
|
+
Returns:
|
|
221
|
+
int: The count of rows that match the filters
|
|
222
|
+
|
|
223
|
+
Raises:
|
|
224
|
+
DBAPIError: If a database error occurs
|
|
225
|
+
"""
|
|
226
|
+
logger.debug(f"Calculating size for {cls.__name__} with filters {kwargs}")
|
|
227
|
+
|
|
228
|
+
with db_session as session:
|
|
229
|
+
query = select(func.count()).select_from(cls)
|
|
230
|
+
|
|
231
|
+
if actor:
|
|
232
|
+
query = cls.apply_access_predicate(query, actor, access, access_type)
|
|
233
|
+
|
|
234
|
+
# Apply filtering logic based on kwargs
|
|
235
|
+
for key, value in kwargs.items():
|
|
236
|
+
if value:
|
|
237
|
+
column = getattr(cls, key, None)
|
|
238
|
+
if not column:
|
|
239
|
+
raise AttributeError(f"{cls.__name__} has no attribute '{key}'")
|
|
240
|
+
if isinstance(value, (list, tuple, set)): # Check for iterables
|
|
241
|
+
query = query.where(column.in_(value))
|
|
242
|
+
else: # Single value for equality filtering
|
|
243
|
+
query = query.where(column == value)
|
|
244
|
+
|
|
245
|
+
# Handle soft deletes if the class has the 'is_deleted' attribute
|
|
246
|
+
if hasattr(cls, "is_deleted"):
|
|
247
|
+
query = query.where(cls.is_deleted == False)
|
|
248
|
+
|
|
249
|
+
try:
|
|
250
|
+
count = session.execute(query).scalar()
|
|
251
|
+
return count if count else 0
|
|
252
|
+
except DBAPIError as e:
|
|
253
|
+
logger.exception(f"Failed to calculate size for {cls.__name__}")
|
|
254
|
+
raise e
|
|
255
|
+
|
|
173
256
|
@classmethod
|
|
174
257
|
def apply_access_predicate(
|
|
175
258
|
cls,
|
|
176
259
|
query: "Select",
|
|
177
260
|
actor: "User",
|
|
178
261
|
access: List[Literal["read", "write", "admin"]],
|
|
262
|
+
access_type: AccessType = AccessType.ORGANIZATION,
|
|
179
263
|
) -> "Select":
|
|
180
264
|
"""applies a WHERE clause restricting results to the given actor and access level
|
|
181
265
|
Args:
|
|
@@ -189,10 +273,18 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
|
|
189
273
|
the sqlalchemy select statement restricted to the given access.
|
|
190
274
|
"""
|
|
191
275
|
del access # entrypoint for row-level permissions. Defaults to "same org as the actor, all permissions" at the moment
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
276
|
+
if access_type == AccessType.ORGANIZATION:
|
|
277
|
+
org_id = getattr(actor, "organization_id", None)
|
|
278
|
+
if not org_id:
|
|
279
|
+
raise ValueError(f"object {actor} has no organization accessor")
|
|
280
|
+
return query.where(cls.organization_id == org_id, cls.is_deleted == False)
|
|
281
|
+
elif access_type == AccessType.USER:
|
|
282
|
+
user_id = getattr(actor, "id", None)
|
|
283
|
+
if not user_id:
|
|
284
|
+
raise ValueError(f"object {actor} has no user accessor")
|
|
285
|
+
return query.where(cls.user_id == user_id, cls.is_deleted == False)
|
|
286
|
+
else:
|
|
287
|
+
raise ValueError(f"unknown access_type: {access_type}")
|
|
196
288
|
|
|
197
289
|
@classmethod
|
|
198
290
|
def _handle_dbapi_error(cls, e: DBAPIError):
|
letta/schemas/letta_base.py
CHANGED
|
@@ -33,18 +33,19 @@ class LettaBase(BaseModel):
|
|
|
33
33
|
def generate_id_field(cls, prefix: Optional[str] = None) -> "Field":
|
|
34
34
|
prefix = prefix or cls.__id_prefix__
|
|
35
35
|
|
|
36
|
-
# TODO: generate ID from regex pattern?
|
|
37
|
-
def _generate_id() -> str:
|
|
38
|
-
return f"{prefix}-{uuid.uuid4()}"
|
|
39
|
-
|
|
40
36
|
return Field(
|
|
41
37
|
...,
|
|
42
38
|
description=cls._id_description(prefix),
|
|
43
39
|
pattern=cls._id_regex_pattern(prefix),
|
|
44
40
|
examples=[cls._id_example(prefix)],
|
|
45
|
-
default_factory=_generate_id,
|
|
41
|
+
default_factory=cls._generate_id,
|
|
46
42
|
)
|
|
47
43
|
|
|
44
|
+
@classmethod
|
|
45
|
+
def _generate_id(cls, prefix: Optional[str] = None) -> str:
|
|
46
|
+
prefix = prefix or cls.__id_prefix__
|
|
47
|
+
return f"{prefix}-{uuid.uuid4()}"
|
|
48
|
+
|
|
48
49
|
# def _generate_id(self) -> str:
|
|
49
50
|
# return f"{self.__id_prefix__}-{uuid.uuid4()}"
|
|
50
51
|
|
|
@@ -78,7 +79,7 @@ class LettaBase(BaseModel):
|
|
|
78
79
|
"""
|
|
79
80
|
_ = values # for SCA
|
|
80
81
|
if isinstance(v, UUID):
|
|
81
|
-
logger.
|
|
82
|
+
logger.debug(f"Bare UUIDs are deprecated, please use the full prefixed id ({cls.__id_prefix__})!")
|
|
82
83
|
return f"{cls.__id_prefix__}-{v}"
|
|
83
84
|
return v
|
|
84
85
|
|
letta/schemas/message.py
CHANGED
|
@@ -105,7 +105,7 @@ class Message(BaseMessage):
|
|
|
105
105
|
id: str = BaseMessage.generate_id_field()
|
|
106
106
|
role: MessageRole = Field(..., description="The role of the participant.")
|
|
107
107
|
text: Optional[str] = Field(None, description="The text of the message.")
|
|
108
|
-
|
|
108
|
+
organization_id: Optional[str] = Field(None, description="The unique identifier of the organization.")
|
|
109
109
|
agent_id: Optional[str] = Field(None, description="The unique identifier of the agent.")
|
|
110
110
|
model: Optional[str] = Field(None, description="The model used to make the function call.")
|
|
111
111
|
name: Optional[str] = Field(None, description="The name of the participant.")
|
|
@@ -281,7 +281,6 @@ class Message(BaseMessage):
|
|
|
281
281
|
)
|
|
282
282
|
if id is not None:
|
|
283
283
|
return Message(
|
|
284
|
-
user_id=user_id,
|
|
285
284
|
agent_id=agent_id,
|
|
286
285
|
model=model,
|
|
287
286
|
# standard fields expected in an OpenAI ChatCompletion message object
|
|
@@ -295,7 +294,6 @@ class Message(BaseMessage):
|
|
|
295
294
|
)
|
|
296
295
|
else:
|
|
297
296
|
return Message(
|
|
298
|
-
user_id=user_id,
|
|
299
297
|
agent_id=agent_id,
|
|
300
298
|
model=model,
|
|
301
299
|
# standard fields expected in an OpenAI ChatCompletion message object
|
|
@@ -328,7 +326,6 @@ class Message(BaseMessage):
|
|
|
328
326
|
|
|
329
327
|
if id is not None:
|
|
330
328
|
return Message(
|
|
331
|
-
user_id=user_id,
|
|
332
329
|
agent_id=agent_id,
|
|
333
330
|
model=model,
|
|
334
331
|
# standard fields expected in an OpenAI ChatCompletion message object
|
|
@@ -342,7 +339,6 @@ class Message(BaseMessage):
|
|
|
342
339
|
)
|
|
343
340
|
else:
|
|
344
341
|
return Message(
|
|
345
|
-
user_id=user_id,
|
|
346
342
|
agent_id=agent_id,
|
|
347
343
|
model=model,
|
|
348
344
|
# standard fields expected in an OpenAI ChatCompletion message object
|
|
@@ -375,7 +371,6 @@ class Message(BaseMessage):
|
|
|
375
371
|
# If we're going from tool-call style
|
|
376
372
|
if id is not None:
|
|
377
373
|
return Message(
|
|
378
|
-
user_id=user_id,
|
|
379
374
|
agent_id=agent_id,
|
|
380
375
|
model=model,
|
|
381
376
|
# standard fields expected in an OpenAI ChatCompletion message object
|
|
@@ -389,7 +384,6 @@ class Message(BaseMessage):
|
|
|
389
384
|
)
|
|
390
385
|
else:
|
|
391
386
|
return Message(
|
|
392
|
-
user_id=user_id,
|
|
393
387
|
agent_id=agent_id,
|
|
394
388
|
model=model,
|
|
395
389
|
# standard fields expected in an OpenAI ChatCompletion message object
|
|
@@ -409,7 +409,7 @@ def get_agent_messages(
|
|
|
409
409
|
return server.get_agent_recall_cursor(
|
|
410
410
|
user_id=actor.id,
|
|
411
411
|
agent_id=agent_id,
|
|
412
|
-
|
|
412
|
+
cursor=before,
|
|
413
413
|
limit=limit,
|
|
414
414
|
reverse=True,
|
|
415
415
|
return_message_object=msg_object,
|
|
@@ -465,7 +465,7 @@ async def send_message(
|
|
|
465
465
|
@router.post(
|
|
466
466
|
"/{agent_id}/messages/stream",
|
|
467
467
|
response_model=None,
|
|
468
|
-
operation_id="
|
|
468
|
+
operation_id="create_agent_message_stream",
|
|
469
469
|
responses={
|
|
470
470
|
200: {
|
|
471
471
|
"description": "Successful response",
|
|
@@ -76,7 +76,7 @@ def get_block(
|
|
|
76
76
|
raise HTTPException(status_code=404, detail="Block not found")
|
|
77
77
|
|
|
78
78
|
|
|
79
|
-
@router.patch("/{block_id}/attach", response_model=Block, operation_id="
|
|
79
|
+
@router.patch("/{block_id}/attach", response_model=Block, operation_id="link_agent_memory_block")
|
|
80
80
|
def link_agent_memory_block(
|
|
81
81
|
block_id: str,
|
|
82
82
|
agent_id: str = Query(..., description="The unique identifier of the agent to attach the source to."),
|
|
@@ -96,7 +96,7 @@ def link_agent_memory_block(
|
|
|
96
96
|
return block
|
|
97
97
|
|
|
98
98
|
|
|
99
|
-
@router.patch("/{block_id}/detach", response_model=Memory, operation_id="
|
|
99
|
+
@router.patch("/{block_id}/detach", response_model=Memory, operation_id="unlink_agent_memory_block")
|
|
100
100
|
def unlink_agent_memory_block(
|
|
101
101
|
block_id: str,
|
|
102
102
|
agent_id: str = Query(..., description="The unique identifier of the agent to attach the source to."),
|