letta-nightly 0.6.4.dev20241213193437__py3-none-any.whl → 0.6.4.dev20241215104129__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of letta-nightly might be problematic. Click here for more details.
- letta/__init__.py +1 -1
- letta/agent.py +54 -45
- letta/chat_only_agent.py +6 -8
- letta/cli/cli.py +2 -10
- letta/client/client.py +121 -138
- letta/config.py +0 -161
- letta/main.py +3 -8
- letta/memory.py +3 -14
- letta/o1_agent.py +1 -5
- letta/offline_memory_agent.py +2 -6
- letta/orm/__init__.py +2 -0
- letta/orm/agent.py +109 -0
- letta/orm/agents_tags.py +10 -18
- letta/orm/block.py +29 -4
- letta/orm/blocks_agents.py +5 -11
- letta/orm/custom_columns.py +152 -0
- letta/orm/message.py +3 -38
- letta/orm/organization.py +2 -7
- letta/orm/passage.py +10 -32
- letta/orm/source.py +5 -25
- letta/orm/sources_agents.py +13 -0
- letta/orm/sqlalchemy_base.py +54 -30
- letta/orm/tool.py +1 -19
- letta/orm/tools_agents.py +7 -24
- letta/orm/user.py +3 -4
- letta/schemas/agent.py +48 -65
- letta/schemas/memory.py +2 -1
- letta/schemas/sandbox_config.py +12 -1
- letta/server/rest_api/app.py +0 -5
- letta/server/rest_api/routers/openai/chat_completions/chat_completions.py +1 -1
- letta/server/rest_api/routers/v1/agents.py +99 -78
- letta/server/rest_api/routers/v1/blocks.py +22 -25
- letta/server/rest_api/routers/v1/jobs.py +4 -4
- letta/server/rest_api/routers/v1/sandbox_configs.py +10 -10
- letta/server/rest_api/routers/v1/sources.py +12 -12
- letta/server/rest_api/routers/v1/tools.py +35 -15
- letta/server/rest_api/routers/v1/users.py +0 -46
- letta/server/server.py +172 -716
- letta/server/ws_api/server.py +0 -5
- letta/services/agent_manager.py +405 -0
- letta/services/block_manager.py +13 -21
- letta/services/helpers/agent_manager_helper.py +90 -0
- letta/services/organization_manager.py +0 -1
- letta/services/passage_manager.py +62 -62
- letta/services/sandbox_config_manager.py +3 -3
- letta/services/source_manager.py +22 -1
- letta/services/user_manager.py +11 -6
- letta/utils.py +2 -2
- {letta_nightly-0.6.4.dev20241213193437.dist-info → letta_nightly-0.6.4.dev20241215104129.dist-info}/METADATA +1 -1
- {letta_nightly-0.6.4.dev20241213193437.dist-info → letta_nightly-0.6.4.dev20241215104129.dist-info}/RECORD +53 -57
- letta/metadata.py +0 -407
- letta/schemas/agents_tags.py +0 -33
- letta/schemas/api_key.py +0 -21
- letta/schemas/blocks_agents.py +0 -32
- letta/schemas/tools_agents.py +0 -32
- letta/server/rest_api/routers/openai/assistants/threads.py +0 -338
- letta/services/agents_tags_manager.py +0 -64
- letta/services/blocks_agents_manager.py +0 -106
- letta/services/tools_agents_manager.py +0 -94
- {letta_nightly-0.6.4.dev20241213193437.dist-info → letta_nightly-0.6.4.dev20241215104129.dist-info}/LICENSE +0 -0
- {letta_nightly-0.6.4.dev20241213193437.dist-info → letta_nightly-0.6.4.dev20241215104129.dist-info}/WHEEL +0 -0
- {letta_nightly-0.6.4.dev20241213193437.dist-info → letta_nightly-0.6.4.dev20241215104129.dist-info}/entry_points.txt +0 -0
|
@@ -0,0 +1,152 @@
|
|
|
1
|
+
import base64
|
|
2
|
+
from typing import List, Union
|
|
3
|
+
|
|
4
|
+
import numpy as np
|
|
5
|
+
from sqlalchemy import JSON
|
|
6
|
+
from sqlalchemy.types import BINARY, TypeDecorator
|
|
7
|
+
|
|
8
|
+
from letta.schemas.embedding_config import EmbeddingConfig
|
|
9
|
+
from letta.schemas.enums import ToolRuleType
|
|
10
|
+
from letta.schemas.llm_config import LLMConfig
|
|
11
|
+
from letta.schemas.openai.chat_completions import ToolCall, ToolCallFunction
|
|
12
|
+
from letta.schemas.tool_rule import ChildToolRule, InitToolRule, TerminalToolRule
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class EmbeddingConfigColumn(TypeDecorator):
    """SQLAlchemy column type that persists an ``EmbeddingConfig`` as JSON.

    ``EmbeddingConfig`` instances are dumped to plain dicts on write and rebuilt
    from the stored dict on read. Any other value is handed to the underlying
    JSON type unchanged.
    """

    impl = JSON
    cache_ok = True

    def load_dialect_impl(self, dialect):
        # Back the column with the generic JSON type for every dialect.
        return dialect.type_descriptor(JSON())

    def process_bind_param(self, value, dialect):
        # Only truthy EmbeddingConfig instances are serialized; everything else passes through.
        if not value or not isinstance(value, EmbeddingConfig):
            return value
        return value.model_dump()

    def process_result_value(self, value, dialect):
        # Falsy payloads (None, empty dict) are returned untouched.
        return EmbeddingConfig(**value) if value else value
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class LLMConfigColumn(TypeDecorator):
    """Column type that stores an ``LLMConfig`` model as a JSON document."""

    impl = JSON
    cache_ok = True

    def load_dialect_impl(self, dialect):
        json_type = JSON()
        return dialect.type_descriptor(json_type)

    def process_bind_param(self, value, dialect):
        """Dump ``LLMConfig`` instances to a dict; any other value is returned as-is."""
        is_config = bool(value) and isinstance(value, LLMConfig)
        if is_config:
            return value.model_dump()
        return value

    def process_result_value(self, value, dialect):
        """Rebuild an ``LLMConfig`` from the stored dict; falsy values are returned as-is."""
        if not value:
            return value
        return LLMConfig(**value)
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
class ToolRulesColumn(TypeDecorator):
    """Custom type for storing a list of ToolRules as JSON.

    On write, each rule model is dumped to a dict and its ``type`` is flattened
    to its raw enum value so the list is JSON-serializable. On read, each dict is
    dispatched back to the matching ToolRule subclass by its ``type``.
    """

    impl = JSON
    cache_ok = True

    def load_dialect_impl(self, dialect):
        return dialect.type_descriptor(JSON())

    def process_bind_param(self, value, dialect):
        """Convert a list of ToolRules to JSON-serializable format.

        Falsy values (``None``, empty list) are passed through unchanged.

        Raises:
            ValueError: if a dumped rule whose type is ``"ToolRule"`` has no
                ``children`` field.
        """
        if not value:
            return value

        data = []
        for rule in value:
            d = rule.model_dump()
            # Store the enum's raw value; tolerate a dump that already holds a plain value
            # (the original unconditional `.value` raised AttributeError in that case).
            d["type"] = getattr(d["type"], "value", d["type"])
            # Explicit raise instead of `assert`, so the check survives `python -O`.
            if d["type"] == "ToolRule" and "children" not in d:
                raise ValueError("ToolRule does not have children field")
            data.append(d)
        return data

    def process_result_value(self, value, dialect) -> List[Union[ChildToolRule, InitToolRule, TerminalToolRule]]:
        """Convert JSON back to a list of ToolRules; falsy values are returned as-is."""
        if value:
            return [self.deserialize_tool_rule(rule_data) for rule_data in value]
        return value

    @staticmethod
    def deserialize_tool_rule(data: dict) -> Union[ChildToolRule, InitToolRule, TerminalToolRule]:
        """Deserialize a dictionary to the appropriate ToolRule subclass based on its 'type'.

        The ``type`` key is left in ``data`` and forwarded to the subclass constructor.

        Raises:
            ValueError: if ``type`` is not a valid ``ToolRuleType`` value, or is a
                valid type with no matching ToolRule subclass here.
        """
        rule_type = ToolRuleType(data.get("type"))
        if rule_type == ToolRuleType.run_first:
            return InitToolRule(**data)
        elif rule_type == ToolRuleType.exit_loop:
            return TerminalToolRule(**data)
        elif rule_type == ToolRuleType.constrain_child_tools:
            return ChildToolRule(**data)
        else:
            raise ValueError(f"Unknown tool rule type: {rule_type}")
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
class ToolCallColumn(TypeDecorator):
    """Custom type for storing a list of ``ToolCall`` models as JSON."""

    impl = JSON
    cache_ok = True

    def load_dialect_impl(self, dialect):
        return dialect.type_descriptor(JSON())

    def process_bind_param(self, value, dialect):
        """Dump ``ToolCall`` models to dicts; other items (e.g. raw dicts) pass through.

        Falsy values (``None``, empty list) are returned unchanged.
        """
        if value:
            return [v.model_dump() if isinstance(v, ToolCall) else v for v in value]
        return value

    def process_result_value(self, value, dialect):
        """Rebuild ``ToolCall`` models from the stored list of dicts.

        Each stored dict is copied before its ``function`` entry is split off, so
        the value handed in by the driver is not mutated. Falsy values are
        returned unchanged.
        """
        if not value:
            return value

        tools = []
        for tool_value in value:
            # Work on a shallow copy instead of deleting keys from the stored dict.
            fields = dict(tool_value)
            if "function" in fields:
                tool_call_function = ToolCallFunction(**fields.pop("function"))
            else:
                tool_call_function = None
            tools.append(ToolCall(function=tool_call_function, **fields))
        return tools
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
class CommonVector(TypeDecorator):
    """Common type for representing vectors in SQLite.

    Vectors are stored as base64-encoded raw float32 bytes and read back as
    1-D float32 numpy arrays.
    """

    impl = BINARY
    cache_ok = True

    def load_dialect_impl(self, dialect):
        return dialect.type_descriptor(BINARY())

    def process_bind_param(self, value, dialect):
        """Encode a vector (list, tuple or ndarray) as base64 float32 bytes; ``None`` passes through."""
        if value is None:
            return value
        # Coerce every input kind to float32. Previously only lists were converted, so an
        # ndarray of another dtype (numpy defaults to float64) was written with the wrong
        # element width and read back as garbage, and a tuple raised AttributeError.
        vector = np.asarray(value, dtype=np.float32)
        return base64.b64encode(vector.tobytes())

    def process_result_value(self, value, dialect):
        """Decode stored base64 bytes back into a float32 numpy array; falsy values pass through."""
        if not value:
            return value
        # process_bind_param base64-encodes for every dialect, so decode for every dialect
        # to keep the round trip symmetric (previously only sqlite was decoded).
        return np.frombuffer(base64.b64decode(value), dtype=np.float32)
|
letta/orm/message.py
CHANGED
|
@@ -1,46 +1,12 @@
|
|
|
1
1
|
from typing import Optional
|
|
2
2
|
|
|
3
|
-
from sqlalchemy import JSON, TypeDecorator
|
|
4
3
|
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
|
5
4
|
|
|
5
|
+
from letta.orm.custom_columns import ToolCallColumn
|
|
6
6
|
from letta.orm.mixins import AgentMixin, OrganizationMixin
|
|
7
7
|
from letta.orm.sqlalchemy_base import SqlalchemyBase
|
|
8
8
|
from letta.schemas.message import Message as PydanticMessage
|
|
9
|
-
from letta.schemas.openai.chat_completions import ToolCall
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
class ToolCallColumn(TypeDecorator):
|
|
13
|
-
|
|
14
|
-
impl = JSON
|
|
15
|
-
cache_ok = True
|
|
16
|
-
|
|
17
|
-
def load_dialect_impl(self, dialect):
|
|
18
|
-
return dialect.type_descriptor(JSON())
|
|
19
|
-
|
|
20
|
-
def process_bind_param(self, value, dialect):
|
|
21
|
-
if value:
|
|
22
|
-
values = []
|
|
23
|
-
for v in value:
|
|
24
|
-
if isinstance(v, ToolCall):
|
|
25
|
-
values.append(v.model_dump())
|
|
26
|
-
else:
|
|
27
|
-
values.append(v)
|
|
28
|
-
return values
|
|
29
|
-
|
|
30
|
-
return value
|
|
31
|
-
|
|
32
|
-
def process_result_value(self, value, dialect):
|
|
33
|
-
if value:
|
|
34
|
-
tools = []
|
|
35
|
-
for tool_value in value:
|
|
36
|
-
if "function" in tool_value:
|
|
37
|
-
tool_call_function = ToolCallFunction(**tool_value["function"])
|
|
38
|
-
del tool_value["function"]
|
|
39
|
-
else:
|
|
40
|
-
tool_call_function = None
|
|
41
|
-
tools.append(ToolCall(function=tool_call_function, **tool_value))
|
|
42
|
-
return tools
|
|
43
|
-
return value
|
|
9
|
+
from letta.schemas.openai.chat_completions import ToolCall
|
|
44
10
|
|
|
45
11
|
|
|
46
12
|
class Message(SqlalchemyBase, OrganizationMixin, AgentMixin):
|
|
@@ -59,6 +25,5 @@ class Message(SqlalchemyBase, OrganizationMixin, AgentMixin):
|
|
|
59
25
|
tool_call_id: Mapped[Optional[str]] = mapped_column(nullable=True, doc="ID of the tool call")
|
|
60
26
|
|
|
61
27
|
# Relationships
|
|
62
|
-
|
|
63
|
-
# agent: Mapped["Agent"] = relationship("Agent", back_populates="messages", lazy="selectin")
|
|
28
|
+
agent: Mapped["Agent"] = relationship("Agent", back_populates="messages", lazy="selectin")
|
|
64
29
|
organization: Mapped["Organization"] = relationship("Organization", back_populates="messages", lazy="selectin")
|
letta/orm/organization.py
CHANGED
|
@@ -7,6 +7,7 @@ from letta.schemas.organization import Organization as PydanticOrganization
|
|
|
7
7
|
|
|
8
8
|
if TYPE_CHECKING:
|
|
9
9
|
|
|
10
|
+
from letta.orm.agent import Agent
|
|
10
11
|
from letta.orm.file import FileMetadata
|
|
11
12
|
from letta.orm.tool import Tool
|
|
12
13
|
from letta.orm.user import User
|
|
@@ -25,7 +26,6 @@ class Organization(SqlalchemyBase):
|
|
|
25
26
|
tools: Mapped[List["Tool"]] = relationship("Tool", back_populates="organization", cascade="all, delete-orphan")
|
|
26
27
|
blocks: Mapped[List["Block"]] = relationship("Block", back_populates="organization", cascade="all, delete-orphan")
|
|
27
28
|
sources: Mapped[List["Source"]] = relationship("Source", back_populates="organization", cascade="all, delete-orphan")
|
|
28
|
-
agents_tags: Mapped[List["AgentsTags"]] = relationship("AgentsTags", back_populates="organization", cascade="all, delete-orphan")
|
|
29
29
|
files: Mapped[List["FileMetadata"]] = relationship("FileMetadata", back_populates="organization", cascade="all, delete-orphan")
|
|
30
30
|
sandbox_configs: Mapped[List["SandboxConfig"]] = relationship(
|
|
31
31
|
"SandboxConfig", back_populates="organization", cascade="all, delete-orphan"
|
|
@@ -36,10 +36,5 @@ class Organization(SqlalchemyBase):
|
|
|
36
36
|
|
|
37
37
|
# relationships
|
|
38
38
|
messages: Mapped[List["Message"]] = relationship("Message", back_populates="organization", cascade="all, delete-orphan")
|
|
39
|
+
agents: Mapped[List["Agent"]] = relationship("Agent", back_populates="organization", cascade="all, delete-orphan")
|
|
39
40
|
passages: Mapped[List["Passage"]] = relationship("Passage", back_populates="organization", cascade="all, delete-orphan")
|
|
40
|
-
|
|
41
|
-
# TODO: Map these relationships later when we actually make these models
|
|
42
|
-
# below is just a suggestion
|
|
43
|
-
# agents: Mapped[List["Agent"]] = relationship("Agent", back_populates="organization", cascade="all, delete-orphan")
|
|
44
|
-
# tools: Mapped[List["Tool"]] = relationship("Tool", back_populates="organization", cascade="all, delete-orphan")
|
|
45
|
-
# documents: Mapped[List["Document"]] = relationship("Document", back_populates="organization", cascade="all, delete-orphan")
|
letta/orm/passage.py
CHANGED
|
@@ -1,19 +1,16 @@
|
|
|
1
1
|
from datetime import datetime
|
|
2
|
-
from typing import
|
|
3
|
-
from sqlalchemy import Column, String, DateTime, JSON, ForeignKey
|
|
4
|
-
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
|
5
|
-
from sqlalchemy.types import TypeDecorator, BINARY
|
|
2
|
+
from typing import TYPE_CHECKING, Optional
|
|
6
3
|
|
|
7
|
-
import
|
|
8
|
-
import
|
|
4
|
+
from sqlalchemy import JSON, Column, DateTime, ForeignKey, String
|
|
5
|
+
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
|
9
6
|
|
|
7
|
+
from letta.config import LettaConfig
|
|
8
|
+
from letta.constants import MAX_EMBEDDING_DIM
|
|
9
|
+
from letta.orm.custom_columns import CommonVector
|
|
10
|
+
from letta.orm.mixins import FileMixin, OrganizationMixin
|
|
10
11
|
from letta.orm.source import EmbeddingConfigColumn
|
|
11
12
|
from letta.orm.sqlalchemy_base import SqlalchemyBase
|
|
12
|
-
from letta.orm.mixins import FileMixin, OrganizationMixin
|
|
13
13
|
from letta.schemas.passage import Passage as PydanticPassage
|
|
14
|
-
|
|
15
|
-
from letta.config import LettaConfig
|
|
16
|
-
from letta.constants import MAX_EMBEDDING_DIM
|
|
17
14
|
from letta.settings import settings
|
|
18
15
|
|
|
19
16
|
config = LettaConfig()
|
|
@@ -21,32 +18,12 @@ config = LettaConfig()
|
|
|
21
18
|
if TYPE_CHECKING:
|
|
22
19
|
from letta.orm.organization import Organization
|
|
23
20
|
|
|
24
|
-
class CommonVector(TypeDecorator):
|
|
25
|
-
"""Common type for representing vectors in SQLite"""
|
|
26
|
-
impl = BINARY
|
|
27
|
-
cache_ok = True
|
|
28
|
-
|
|
29
|
-
def load_dialect_impl(self, dialect):
|
|
30
|
-
return dialect.type_descriptor(BINARY())
|
|
31
|
-
|
|
32
|
-
def process_bind_param(self, value, dialect):
|
|
33
|
-
if value is None:
|
|
34
|
-
return value
|
|
35
|
-
if isinstance(value, list):
|
|
36
|
-
value = np.array(value, dtype=np.float32)
|
|
37
|
-
return base64.b64encode(value.tobytes())
|
|
38
21
|
|
|
39
|
-
|
|
40
|
-
if not value:
|
|
41
|
-
return value
|
|
42
|
-
if dialect.name == "sqlite":
|
|
43
|
-
value = base64.b64decode(value)
|
|
44
|
-
return np.frombuffer(value, dtype=np.float32)
|
|
45
|
-
|
|
46
|
-
# TODO: After migration to Passage, will need to manually delete passages where files
|
|
22
|
+
# TODO: After migration to Passage, will need to manually delete passages where files
|
|
47
23
|
# are deleted on web
|
|
48
24
|
class Passage(SqlalchemyBase, OrganizationMixin, FileMixin):
|
|
49
25
|
"""Defines data model for storing Passages"""
|
|
26
|
+
|
|
50
27
|
__tablename__ = "passages"
|
|
51
28
|
__table_args__ = {"extend_existing": True}
|
|
52
29
|
__pydantic_model__ = PydanticPassage
|
|
@@ -59,6 +36,7 @@ class Passage(SqlalchemyBase, OrganizationMixin, FileMixin):
|
|
|
59
36
|
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=datetime.utcnow)
|
|
60
37
|
if settings.letta_pg_uri_no_default:
|
|
61
38
|
from pgvector.sqlalchemy import Vector
|
|
39
|
+
|
|
62
40
|
embedding = mapped_column(Vector(MAX_EMBEDDING_DIM))
|
|
63
41
|
else:
|
|
64
42
|
embedding = Column(CommonVector)
|
letta/orm/source.py
CHANGED
|
@@ -1,8 +1,10 @@
|
|
|
1
1
|
from typing import TYPE_CHECKING, List, Optional
|
|
2
2
|
|
|
3
|
-
from sqlalchemy import JSON
|
|
3
|
+
from sqlalchemy import JSON
|
|
4
4
|
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
|
5
5
|
|
|
6
|
+
from letta.orm import FileMetadata
|
|
7
|
+
from letta.orm.custom_columns import EmbeddingConfigColumn
|
|
6
8
|
from letta.orm.mixins import OrganizationMixin
|
|
7
9
|
from letta.orm.sqlalchemy_base import SqlalchemyBase
|
|
8
10
|
from letta.schemas.embedding_config import EmbeddingConfig
|
|
@@ -12,28 +14,6 @@ if TYPE_CHECKING:
|
|
|
12
14
|
from letta.orm.organization import Organization
|
|
13
15
|
|
|
14
16
|
|
|
15
|
-
class EmbeddingConfigColumn(TypeDecorator):
|
|
16
|
-
"""Custom type for storing EmbeddingConfig as JSON"""
|
|
17
|
-
|
|
18
|
-
impl = JSON
|
|
19
|
-
cache_ok = True
|
|
20
|
-
|
|
21
|
-
def load_dialect_impl(self, dialect):
|
|
22
|
-
return dialect.type_descriptor(JSON())
|
|
23
|
-
|
|
24
|
-
def process_bind_param(self, value, dialect):
|
|
25
|
-
if value:
|
|
26
|
-
# return vars(value)
|
|
27
|
-
if isinstance(value, EmbeddingConfig):
|
|
28
|
-
return value.model_dump()
|
|
29
|
-
return value
|
|
30
|
-
|
|
31
|
-
def process_result_value(self, value, dialect):
|
|
32
|
-
if value:
|
|
33
|
-
return EmbeddingConfig(**value)
|
|
34
|
-
return value
|
|
35
|
-
|
|
36
|
-
|
|
37
17
|
class Source(SqlalchemyBase, OrganizationMixin):
|
|
38
18
|
"""A source represents an embedded text passage"""
|
|
39
19
|
|
|
@@ -47,5 +27,5 @@ class Source(SqlalchemyBase, OrganizationMixin):
|
|
|
47
27
|
|
|
48
28
|
# relationships
|
|
49
29
|
organization: Mapped["Organization"] = relationship("Organization", back_populates="sources")
|
|
50
|
-
files: Mapped[List["
|
|
51
|
-
|
|
30
|
+
files: Mapped[List["FileMetadata"]] = relationship("FileMetadata", back_populates="source", cascade="all, delete-orphan")
|
|
31
|
+
agents: Mapped[List["Agent"]] = relationship("Agent", secondary="sources_agents", back_populates="sources")
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
from sqlalchemy import ForeignKey, String
|
|
2
|
+
from sqlalchemy.orm import Mapped, mapped_column
|
|
3
|
+
|
|
4
|
+
from letta.orm.base import Base
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class SourcesAgents(Base):
    """Agents can have zero to many sources.

    Association table linking ``agents`` rows to ``sources`` rows; the
    (agent_id, source_id) pair forms the composite primary key.
    """

    __tablename__ = "sources_agents"

    # Mapped[] takes the Python type (str); the SQL type is given explicitly via String.
    agent_id: Mapped[str] = mapped_column(String, ForeignKey("agents.id"), primary_key=True)
    source_id: Mapped[str] = mapped_column(String, ForeignKey("sources.id"), primary_key=True)
    # NOTE(review): these foreign keys have no ondelete="CASCADE" (tools_agents does);
    # confirm that deleting an agent or source removes its rows here some other way.
|
letta/orm/sqlalchemy_base.py
CHANGED
|
@@ -1,7 +1,6 @@
|
|
|
1
1
|
from datetime import datetime
|
|
2
2
|
from enum import Enum
|
|
3
|
-
from typing import TYPE_CHECKING, List, Literal, Optional
|
|
4
|
-
import sqlite3
|
|
3
|
+
from typing import TYPE_CHECKING, List, Literal, Optional
|
|
5
4
|
|
|
6
5
|
from sqlalchemy import String, desc, func, or_, select
|
|
7
6
|
from sqlalchemy.exc import DBAPIError
|
|
@@ -9,12 +8,12 @@ from sqlalchemy.orm import Mapped, Session, mapped_column
|
|
|
9
8
|
|
|
10
9
|
from letta.log import get_logger
|
|
11
10
|
from letta.orm.base import Base, CommonSqlalchemyMetaMixins
|
|
12
|
-
from letta.orm.sqlite_functions import adapt_array, convert_array, cosine_distance
|
|
13
11
|
from letta.orm.errors import (
|
|
14
12
|
ForeignKeyConstraintViolationError,
|
|
15
13
|
NoResultFound,
|
|
16
14
|
UniqueConstraintViolationError,
|
|
17
15
|
)
|
|
16
|
+
from letta.orm.sqlite_functions import adapt_array
|
|
18
17
|
|
|
19
18
|
if TYPE_CHECKING:
|
|
20
19
|
from pydantic import BaseModel
|
|
@@ -64,11 +63,26 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
|
|
64
63
|
query_text: Optional[str] = None,
|
|
65
64
|
query_embedding: Optional[List[float]] = None,
|
|
66
65
|
ascending: bool = True,
|
|
66
|
+
tags: Optional[List[str]] = None,
|
|
67
|
+
match_all_tags: bool = False,
|
|
67
68
|
**kwargs,
|
|
68
|
-
) -> List[
|
|
69
|
+
) -> List["SqlalchemyBase"]:
|
|
69
70
|
"""
|
|
70
71
|
List records with cursor-based pagination, ordering by created_at.
|
|
71
72
|
Cursor is an ID, but pagination is based on the cursor object's created_at value.
|
|
73
|
+
|
|
74
|
+
Args:
|
|
75
|
+
db_session: SQLAlchemy session
|
|
76
|
+
cursor: ID of the last item seen (for pagination)
|
|
77
|
+
start_date: Filter items after this date
|
|
78
|
+
end_date: Filter items before this date
|
|
79
|
+
limit: Maximum number of items to return
|
|
80
|
+
query_text: Text to search for
|
|
81
|
+
query_embedding: Vector to search for similar embeddings
|
|
82
|
+
ascending: Sort direction
|
|
83
|
+
tags: List of tags to filter by
|
|
84
|
+
match_all_tags: If True, return items matching all tags. If False, match any tag.
|
|
85
|
+
**kwargs: Additional filters to apply
|
|
72
86
|
"""
|
|
73
87
|
if start_date and end_date and start_date > end_date:
|
|
74
88
|
raise ValueError("start_date must be earlier than or equal to end_date")
|
|
@@ -84,7 +98,25 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
|
|
84
98
|
|
|
85
99
|
query = select(cls)
|
|
86
100
|
|
|
87
|
-
#
|
|
101
|
+
# Handle tag filtering if the model has tags
|
|
102
|
+
if tags and hasattr(cls, "tags"):
|
|
103
|
+
query = select(cls)
|
|
104
|
+
|
|
105
|
+
if match_all_tags:
|
|
106
|
+
# Match ALL tags - use subqueries
|
|
107
|
+
for tag in tags:
|
|
108
|
+
subquery = select(cls.tags.property.mapper.class_.agent_id).where(cls.tags.property.mapper.class_.tag == tag)
|
|
109
|
+
query = query.filter(cls.id.in_(subquery))
|
|
110
|
+
else:
|
|
111
|
+
# Match ANY tag - use join and filter
|
|
112
|
+
query = (
|
|
113
|
+
query.join(cls.tags).filter(cls.tags.property.mapper.class_.tag.in_(tags)).group_by(cls.id) # Deduplicate results
|
|
114
|
+
)
|
|
115
|
+
|
|
116
|
+
# Group by primary key and all necessary columns to avoid JSON comparison
|
|
117
|
+
query = query.group_by(cls.id)
|
|
118
|
+
|
|
119
|
+
# Apply filtering logic from kwargs
|
|
88
120
|
for key, value in kwargs.items():
|
|
89
121
|
column = getattr(cls, key)
|
|
90
122
|
if isinstance(value, (list, tuple, set)):
|
|
@@ -98,9 +130,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
|
|
98
130
|
if end_date:
|
|
99
131
|
query = query.filter(cls.created_at < end_date)
|
|
100
132
|
|
|
101
|
-
# Cursor-based pagination
|
|
102
|
-
# TODO: There is a really nasty race condition issue here with Sqlite
|
|
103
|
-
# TODO: If they have the same created_at timestamp, this query does NOT match for whatever reason
|
|
133
|
+
# Cursor-based pagination
|
|
104
134
|
if cursor_obj:
|
|
105
135
|
if ascending:
|
|
106
136
|
query = query.where(cls.created_at >= cursor_obj.created_at).where(
|
|
@@ -111,40 +141,34 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
|
|
111
141
|
or_(cls.created_at < cursor_obj.created_at, cls.id < cursor_obj.id)
|
|
112
142
|
)
|
|
113
143
|
|
|
114
|
-
#
|
|
144
|
+
# Text search
|
|
115
145
|
if query_text:
|
|
116
|
-
from sqlalchemy import func
|
|
117
146
|
query = query.filter(func.lower(cls.text).contains(func.lower(query_text)))
|
|
118
147
|
|
|
119
|
-
#
|
|
148
|
+
# Embedding search (for Passages)
|
|
120
149
|
is_ordered = False
|
|
121
150
|
if query_embedding:
|
|
122
|
-
# check if embedding column exists. should only exist for passages
|
|
123
151
|
if not hasattr(cls, "embedding"):
|
|
124
152
|
raise ValueError(f"Class {cls.__name__} does not have an embedding column")
|
|
125
|
-
|
|
153
|
+
|
|
126
154
|
from letta.settings import settings
|
|
155
|
+
|
|
127
156
|
if settings.letta_pg_uri_no_default:
|
|
128
157
|
# PostgreSQL with pgvector
|
|
129
|
-
from pgvector.sqlalchemy import Vector
|
|
130
158
|
query = query.order_by(cls.embedding.cosine_distance(query_embedding).asc())
|
|
131
159
|
else:
|
|
132
160
|
# SQLite with custom vector type
|
|
133
|
-
from sqlalchemy import func
|
|
134
|
-
|
|
135
161
|
query_embedding_binary = adapt_array(query_embedding)
|
|
136
162
|
query = query.order_by(
|
|
137
|
-
func.cosine_distance(cls.embedding, query_embedding_binary).asc(),
|
|
138
|
-
cls.created_at.asc(),
|
|
139
|
-
cls.id.asc()
|
|
163
|
+
func.cosine_distance(cls.embedding, query_embedding_binary).asc(), cls.created_at.asc(), cls.id.asc()
|
|
140
164
|
)
|
|
141
165
|
is_ordered = True
|
|
142
166
|
|
|
143
|
-
# Handle
|
|
167
|
+
# Handle soft deletes
|
|
144
168
|
if hasattr(cls, "is_deleted"):
|
|
145
169
|
query = query.where(cls.is_deleted == False)
|
|
146
|
-
|
|
147
|
-
# Apply ordering
|
|
170
|
+
|
|
171
|
+
# Apply ordering
|
|
148
172
|
if not is_ordered:
|
|
149
173
|
if ascending:
|
|
150
174
|
query = query.order_by(cls.created_at, cls.id)
|
|
@@ -164,7 +188,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
|
|
164
188
|
access: Optional[List[Literal["read", "write", "admin"]]] = ["read"],
|
|
165
189
|
access_type: AccessType = AccessType.ORGANIZATION,
|
|
166
190
|
**kwargs,
|
|
167
|
-
) ->
|
|
191
|
+
) -> "SqlalchemyBase":
|
|
168
192
|
"""The primary accessor for an ORM record.
|
|
169
193
|
Args:
|
|
170
194
|
db_session: the database session to use when retrieving the record
|
|
@@ -207,7 +231,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
|
|
207
231
|
conditions_str = ", ".join(query_conditions) if query_conditions else "no specific conditions"
|
|
208
232
|
raise NoResultFound(f"{cls.__name__} not found with {conditions_str}")
|
|
209
233
|
|
|
210
|
-
def create(self, db_session: "Session", actor: Optional["User"] = None) ->
|
|
234
|
+
def create(self, db_session: "Session", actor: Optional["User"] = None) -> "SqlalchemyBase":
|
|
211
235
|
logger.debug(f"Creating {self.__class__.__name__} with ID: {self.id} with actor={actor}")
|
|
212
236
|
|
|
213
237
|
if actor:
|
|
@@ -221,7 +245,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
|
|
221
245
|
except DBAPIError as e:
|
|
222
246
|
self._handle_dbapi_error(e)
|
|
223
247
|
|
|
224
|
-
def delete(self, db_session: "Session", actor: Optional["User"] = None) ->
|
|
248
|
+
def delete(self, db_session: "Session", actor: Optional["User"] = None) -> "SqlalchemyBase":
|
|
225
249
|
logger.debug(f"Soft deleting {self.__class__.__name__} with ID: {self.id} with actor={actor}")
|
|
226
250
|
|
|
227
251
|
if actor:
|
|
@@ -245,7 +269,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
|
|
245
269
|
else:
|
|
246
270
|
logger.debug(f"{self.__class__.__name__} with ID {self.id} successfully hard deleted")
|
|
247
271
|
|
|
248
|
-
def update(self, db_session: "Session", actor: Optional["User"] = None) ->
|
|
272
|
+
def update(self, db_session: "Session", actor: Optional["User"] = None) -> "SqlalchemyBase":
|
|
249
273
|
logger.debug(f"Updating {self.__class__.__name__} with ID: {self.id} with actor={actor}")
|
|
250
274
|
if actor:
|
|
251
275
|
self._set_created_and_updated_by_fields(actor.id)
|
|
@@ -388,14 +412,14 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
|
|
388
412
|
raise
|
|
389
413
|
|
|
390
414
|
@property
|
|
391
|
-
def __pydantic_model__(self) ->
|
|
415
|
+
def __pydantic_model__(self) -> "BaseModel":
|
|
392
416
|
raise NotImplementedError("Sqlalchemy models must declare a __pydantic_model__ property to be convertable.")
|
|
393
417
|
|
|
394
|
-
def to_pydantic(self) ->
|
|
418
|
+
def to_pydantic(self) -> "BaseModel":
|
|
395
419
|
"""converts to the basic pydantic model counterpart"""
|
|
396
420
|
return self.__pydantic_model__.model_validate(self)
|
|
397
421
|
|
|
398
|
-
def to_record(self) ->
|
|
422
|
+
def to_record(self) -> "BaseModel":
|
|
399
423
|
"""Deprecated accessor for to_pydantic"""
|
|
400
424
|
logger.warning("to_record is deprecated, use to_pydantic instead.")
|
|
401
|
-
return self.to_pydantic()
|
|
425
|
+
return self.to_pydantic()
|
letta/orm/tool.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
from typing import TYPE_CHECKING, List, Optional
|
|
2
2
|
|
|
3
|
-
from sqlalchemy import JSON, String, UniqueConstraint
|
|
3
|
+
from sqlalchemy import JSON, String, UniqueConstraint
|
|
4
4
|
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
|
5
5
|
|
|
6
6
|
# TODO everything in functions should live in this model
|
|
@@ -11,7 +11,6 @@ from letta.schemas.tool import Tool as PydanticTool
|
|
|
11
11
|
|
|
12
12
|
if TYPE_CHECKING:
|
|
13
13
|
from letta.orm.organization import Organization
|
|
14
|
-
from letta.orm.tools_agents import ToolsAgents
|
|
15
14
|
|
|
16
15
|
|
|
17
16
|
class Tool(SqlalchemyBase, OrganizationMixin):
|
|
@@ -42,20 +41,3 @@ class Tool(SqlalchemyBase, OrganizationMixin):
|
|
|
42
41
|
|
|
43
42
|
# relationships
|
|
44
43
|
organization: Mapped["Organization"] = relationship("Organization", back_populates="tools", lazy="selectin")
|
|
45
|
-
tools_agents: Mapped[List["ToolsAgents"]] = relationship("ToolsAgents", back_populates="tool", cascade="all, delete-orphan")
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
# Add event listener to update tool_name in ToolsAgents when Tool name changes
|
|
49
|
-
@event.listens_for(Tool, "before_update")
|
|
50
|
-
def update_tool_name_in_tools_agents(mapper, connection, target):
|
|
51
|
-
"""Update tool_name in ToolsAgents when Tool name changes."""
|
|
52
|
-
state = target._sa_instance_state
|
|
53
|
-
history = state.get_history("name", passive=True)
|
|
54
|
-
if not history.has_changes():
|
|
55
|
-
return
|
|
56
|
-
|
|
57
|
-
# Get the new name and update all associated ToolsAgents records
|
|
58
|
-
new_name = target.name
|
|
59
|
-
from letta.orm.tools_agents import ToolsAgents
|
|
60
|
-
|
|
61
|
-
connection.execute(ToolsAgents.__table__.update().where(ToolsAgents.tool_id == target.id).values(tool_name=new_name))
|
letta/orm/tools_agents.py
CHANGED
|
@@ -1,32 +1,15 @@
|
|
|
1
|
-
from sqlalchemy import ForeignKey,
|
|
2
|
-
from sqlalchemy.orm import Mapped, mapped_column
|
|
1
|
+
from sqlalchemy import ForeignKey, String, UniqueConstraint
|
|
2
|
+
from sqlalchemy.orm import Mapped, mapped_column
|
|
3
3
|
|
|
4
|
-
from letta.orm
|
|
5
|
-
from letta.schemas.tools_agents import ToolsAgents as PydanticToolsAgents
|
|
4
|
+
from letta.orm import Base
|
|
6
5
|
|
|
7
6
|
|
|
8
|
-
class ToolsAgents(
|
|
7
|
+
class ToolsAgents(Base):
|
|
9
8
|
"""Agents can have one or many tools associated with them."""
|
|
10
9
|
|
|
11
10
|
__tablename__ = "tools_agents"
|
|
12
|
-
|
|
13
|
-
__table_args__ = (
|
|
14
|
-
UniqueConstraint(
|
|
15
|
-
"agent_id",
|
|
16
|
-
"tool_name",
|
|
17
|
-
name="unique_tool_per_agent",
|
|
18
|
-
),
|
|
19
|
-
ForeignKeyConstraint(
|
|
20
|
-
["tool_id"],
|
|
21
|
-
["tools.id"],
|
|
22
|
-
name="fk_tool_id",
|
|
23
|
-
),
|
|
24
|
-
)
|
|
11
|
+
__table_args__ = (UniqueConstraint("agent_id", "tool_id", name="unique_agent_tool"),)
|
|
25
12
|
|
|
26
13
|
# Each agent must have unique tool names
|
|
27
|
-
agent_id: Mapped[str] = mapped_column(String, ForeignKey("agents.id"), primary_key=True)
|
|
28
|
-
tool_id: Mapped[str] = mapped_column(String, primary_key=True)
|
|
29
|
-
tool_name: Mapped[str] = mapped_column(String, primary_key=True)
|
|
30
|
-
|
|
31
|
-
# relationships
|
|
32
|
-
tool: Mapped["Tool"] = relationship("Tool", back_populates="tools_agents") # agent: Mapped["Agent"] = relationship("Agent", back_populates="tools_agents")
|
|
14
|
+
agent_id: Mapped[str] = mapped_column(String, ForeignKey("agents.id", ondelete="CASCADE"), primary_key=True)
|
|
15
|
+
tool_id: Mapped[str] = mapped_column(String, ForeignKey("tools.id", ondelete="CASCADE"), primary_key=True)
|
letta/orm/user.py
CHANGED
|
@@ -20,10 +20,9 @@ class User(SqlalchemyBase, OrganizationMixin):
|
|
|
20
20
|
|
|
21
21
|
# relationships
|
|
22
22
|
organization: Mapped["Organization"] = relationship("Organization", back_populates="users")
|
|
23
|
-
jobs: Mapped[List["Job"]] = relationship(
|
|
23
|
+
jobs: Mapped[List["Job"]] = relationship(
|
|
24
|
+
"Job", back_populates="user", doc="the jobs associated with this user.", cascade="all, delete-orphan"
|
|
25
|
+
)
|
|
24
26
|
|
|
25
27
|
# TODO: Add this back later potentially
|
|
26
|
-
# agents: Mapped[List["Agent"]] = relationship(
|
|
27
|
-
# "Agent", secondary="users_agents", back_populates="users", doc="the agents associated with this user."
|
|
28
|
-
# )
|
|
29
28
|
# tokens: Mapped[List["Token"]] = relationship("Token", back_populates="user", doc="the tokens associated with this user.")
|