letta-nightly 0.6.4.dev20241213193437__py3-none-any.whl → 0.6.4.dev20241215104129__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This release has been flagged as potentially problematic.


This version of letta-nightly might be problematic. Click here for more details.

Files changed (62) hide show
  1. letta/__init__.py +1 -1
  2. letta/agent.py +54 -45
  3. letta/chat_only_agent.py +6 -8
  4. letta/cli/cli.py +2 -10
  5. letta/client/client.py +121 -138
  6. letta/config.py +0 -161
  7. letta/main.py +3 -8
  8. letta/memory.py +3 -14
  9. letta/o1_agent.py +1 -5
  10. letta/offline_memory_agent.py +2 -6
  11. letta/orm/__init__.py +2 -0
  12. letta/orm/agent.py +109 -0
  13. letta/orm/agents_tags.py +10 -18
  14. letta/orm/block.py +29 -4
  15. letta/orm/blocks_agents.py +5 -11
  16. letta/orm/custom_columns.py +152 -0
  17. letta/orm/message.py +3 -38
  18. letta/orm/organization.py +2 -7
  19. letta/orm/passage.py +10 -32
  20. letta/orm/source.py +5 -25
  21. letta/orm/sources_agents.py +13 -0
  22. letta/orm/sqlalchemy_base.py +54 -30
  23. letta/orm/tool.py +1 -19
  24. letta/orm/tools_agents.py +7 -24
  25. letta/orm/user.py +3 -4
  26. letta/schemas/agent.py +48 -65
  27. letta/schemas/memory.py +2 -1
  28. letta/schemas/sandbox_config.py +12 -1
  29. letta/server/rest_api/app.py +0 -5
  30. letta/server/rest_api/routers/openai/chat_completions/chat_completions.py +1 -1
  31. letta/server/rest_api/routers/v1/agents.py +99 -78
  32. letta/server/rest_api/routers/v1/blocks.py +22 -25
  33. letta/server/rest_api/routers/v1/jobs.py +4 -4
  34. letta/server/rest_api/routers/v1/sandbox_configs.py +10 -10
  35. letta/server/rest_api/routers/v1/sources.py +12 -12
  36. letta/server/rest_api/routers/v1/tools.py +35 -15
  37. letta/server/rest_api/routers/v1/users.py +0 -46
  38. letta/server/server.py +172 -716
  39. letta/server/ws_api/server.py +0 -5
  40. letta/services/agent_manager.py +405 -0
  41. letta/services/block_manager.py +13 -21
  42. letta/services/helpers/agent_manager_helper.py +90 -0
  43. letta/services/organization_manager.py +0 -1
  44. letta/services/passage_manager.py +62 -62
  45. letta/services/sandbox_config_manager.py +3 -3
  46. letta/services/source_manager.py +22 -1
  47. letta/services/user_manager.py +11 -6
  48. letta/utils.py +2 -2
  49. {letta_nightly-0.6.4.dev20241213193437.dist-info → letta_nightly-0.6.4.dev20241215104129.dist-info}/METADATA +1 -1
  50. {letta_nightly-0.6.4.dev20241213193437.dist-info → letta_nightly-0.6.4.dev20241215104129.dist-info}/RECORD +53 -57
  51. letta/metadata.py +0 -407
  52. letta/schemas/agents_tags.py +0 -33
  53. letta/schemas/api_key.py +0 -21
  54. letta/schemas/blocks_agents.py +0 -32
  55. letta/schemas/tools_agents.py +0 -32
  56. letta/server/rest_api/routers/openai/assistants/threads.py +0 -338
  57. letta/services/agents_tags_manager.py +0 -64
  58. letta/services/blocks_agents_manager.py +0 -106
  59. letta/services/tools_agents_manager.py +0 -94
  60. {letta_nightly-0.6.4.dev20241213193437.dist-info → letta_nightly-0.6.4.dev20241215104129.dist-info}/LICENSE +0 -0
  61. {letta_nightly-0.6.4.dev20241213193437.dist-info → letta_nightly-0.6.4.dev20241215104129.dist-info}/WHEEL +0 -0
  62. {letta_nightly-0.6.4.dev20241213193437.dist-info → letta_nightly-0.6.4.dev20241215104129.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,152 @@
1
+ import base64
2
+ from typing import List, Union
3
+
4
+ import numpy as np
5
+ from sqlalchemy import JSON
6
+ from sqlalchemy.types import BINARY, TypeDecorator
7
+
8
+ from letta.schemas.embedding_config import EmbeddingConfig
9
+ from letta.schemas.enums import ToolRuleType
10
+ from letta.schemas.llm_config import LLMConfig
11
+ from letta.schemas.openai.chat_completions import ToolCall, ToolCallFunction
12
+ from letta.schemas.tool_rule import ChildToolRule, InitToolRule, TerminalToolRule
13
+
14
+
15
class EmbeddingConfigColumn(TypeDecorator):
    """JSON-backed column type that (de)serializes ``EmbeddingConfig`` objects."""

    impl = JSON
    cache_ok = True

    def load_dialect_impl(self, dialect):
        # Store as plain JSON regardless of backend dialect.
        return dialect.type_descriptor(JSON())

    def process_bind_param(self, value, dialect):
        # Dump pydantic models to a plain dict before writing; anything else
        # (already-serialized dicts, None) passes through unchanged.
        if isinstance(value, EmbeddingConfig):
            return value.model_dump()
        return value

    def process_result_value(self, value, dialect):
        # Rehydrate stored JSON into an EmbeddingConfig; NULL/empty rows
        # come back unchanged.
        return EmbeddingConfig(**value) if value else value
33
+
34
+
35
class LLMConfigColumn(TypeDecorator):
    """JSON-backed column type that (de)serializes ``LLMConfig`` objects."""

    impl = JSON
    cache_ok = True

    def load_dialect_impl(self, dialect):
        # Store as plain JSON regardless of backend dialect.
        return dialect.type_descriptor(JSON())

    def process_bind_param(self, value, dialect):
        # Dump pydantic models to a plain dict before writing; anything else
        # (already-serialized dicts, None) passes through unchanged.
        if isinstance(value, LLMConfig):
            return value.model_dump()
        return value

    def process_result_value(self, value, dialect):
        # Rehydrate stored JSON into an LLMConfig; NULL/empty rows come
        # back unchanged.
        return LLMConfig(**value) if value else value
53
+
54
+
55
class ToolRulesColumn(TypeDecorator):
    """Custom type for storing a list of ToolRules as JSON."""

    impl = JSON
    cache_ok = True

    def load_dialect_impl(self, dialect):
        return dialect.type_descriptor(JSON())

    def process_bind_param(self, value, dialect):
        """Convert a list of ToolRules to a JSON-serializable format.

        Args:
            value: list of ToolRule pydantic objects (or None/empty).
            dialect: SQLAlchemy dialect (unused).

        Returns:
            A list of plain dicts with the enum ``type`` flattened to its
            string value, or the input unchanged when falsy.
        """
        if value:
            # Single pass: dump, flatten the enum, and sanity-check each rule
            # (the original code looped over the data twice for this).
            data = []
            for rule in value:
                d = rule.model_dump()
                # Flatten the ToolRuleType enum so the stored JSON is portable.
                d["type"] = d["type"].value
                assert not (d["type"] == "ToolRule" and "children" not in d), "ToolRule does not have children field"
                data.append(d)
            return data
        return value

    def process_result_value(self, value, dialect) -> List[Union[ChildToolRule, InitToolRule, TerminalToolRule]]:
        """Convert stored JSON back to a list of ToolRule objects."""
        if value:
            return [self.deserialize_tool_rule(rule_data) for rule_data in value]
        return value

    @staticmethod
    def deserialize_tool_rule(data: dict) -> Union[ChildToolRule, InitToolRule, TerminalToolRule]:
        """Deserialize a dict to the ToolRule subclass selected by its 'type' field.

        Raises:
            ValueError: if the 'type' value does not map to a known rule class.
        """
        rule_type = ToolRuleType(data.get("type"))
        if rule_type == ToolRuleType.run_first:
            return InitToolRule(**data)
        elif rule_type == ToolRuleType.exit_loop:
            return TerminalToolRule(**data)
        elif rule_type == ToolRuleType.constrain_child_tools:
            return ChildToolRule(**data)
        else:
            raise ValueError(f"Unknown tool rule type: {rule_type}")
95
+
96
+
97
class ToolCallColumn(TypeDecorator):
    """JSON-backed column type that (de)serializes lists of ``ToolCall`` objects."""

    impl = JSON
    cache_ok = True

    def load_dialect_impl(self, dialect):
        return dialect.type_descriptor(JSON())

    def process_bind_param(self, value, dialect):
        """Dump ToolCall models to plain dicts; pass other items through unchanged."""
        if value:
            values = []
            for v in value:
                if isinstance(v, ToolCall):
                    values.append(v.model_dump())
                else:
                    values.append(v)
            return values

        return value

    def process_result_value(self, value, dialect):
        """Rehydrate stored JSON into a list of ToolCall objects.

        A nested "function" dict becomes a ``ToolCallFunction``; its absence
        yields ``function=None``.
        """
        if value:
            tools = []
            for tool_value in value:
                # Work on a shallow copy: the original implementation deleted
                # the "function" key from the driver-provided row data, mutating
                # input this code does not own.
                tool_data = dict(tool_value)
                if "function" in tool_data:
                    tool_call_function = ToolCallFunction(**tool_data.pop("function"))
                else:
                    tool_call_function = None
                tools.append(ToolCall(function=tool_call_function, **tool_data))
            return tools
        return value
129
+
130
+
131
class CommonVector(TypeDecorator):
    """Common type for representing vectors in SQLite."""

    impl = BINARY
    cache_ok = True

    def load_dialect_impl(self, dialect):
        return dialect.type_descriptor(BINARY())

    def process_bind_param(self, value, dialect):
        # NULLs pass straight through.
        if value is None:
            return value
        # Plain Python lists are normalized to float32 arrays before packing.
        if isinstance(value, list):
            value = np.array(value, dtype=np.float32)
        # Persist as base64 so the raw float bytes survive the BINARY column.
        return base64.b64encode(value.tobytes())

    def process_result_value(self, value, dialect):
        # NULL/empty payloads come back unchanged.
        if not value:
            return value
        # SQLite hands back the base64 text we stored; other dialects are
        # assumed to return raw bytes already — TODO confirm for non-sqlite use.
        raw = base64.b64decode(value) if dialect.name == "sqlite" else value
        return np.frombuffer(raw, dtype=np.float32)
letta/orm/message.py CHANGED
@@ -1,46 +1,12 @@
1
1
  from typing import Optional
2
2
 
3
- from sqlalchemy import JSON, TypeDecorator
4
3
  from sqlalchemy.orm import Mapped, mapped_column, relationship
5
4
 
5
+ from letta.orm.custom_columns import ToolCallColumn
6
6
  from letta.orm.mixins import AgentMixin, OrganizationMixin
7
7
  from letta.orm.sqlalchemy_base import SqlalchemyBase
8
8
  from letta.schemas.message import Message as PydanticMessage
9
- from letta.schemas.openai.chat_completions import ToolCall, ToolCallFunction
10
-
11
-
12
- class ToolCallColumn(TypeDecorator):
13
-
14
- impl = JSON
15
- cache_ok = True
16
-
17
- def load_dialect_impl(self, dialect):
18
- return dialect.type_descriptor(JSON())
19
-
20
- def process_bind_param(self, value, dialect):
21
- if value:
22
- values = []
23
- for v in value:
24
- if isinstance(v, ToolCall):
25
- values.append(v.model_dump())
26
- else:
27
- values.append(v)
28
- return values
29
-
30
- return value
31
-
32
- def process_result_value(self, value, dialect):
33
- if value:
34
- tools = []
35
- for tool_value in value:
36
- if "function" in tool_value:
37
- tool_call_function = ToolCallFunction(**tool_value["function"])
38
- del tool_value["function"]
39
- else:
40
- tool_call_function = None
41
- tools.append(ToolCall(function=tool_call_function, **tool_value))
42
- return tools
43
- return value
9
+ from letta.schemas.openai.chat_completions import ToolCall
44
10
 
45
11
 
46
12
  class Message(SqlalchemyBase, OrganizationMixin, AgentMixin):
@@ -59,6 +25,5 @@ class Message(SqlalchemyBase, OrganizationMixin, AgentMixin):
59
25
  tool_call_id: Mapped[Optional[str]] = mapped_column(nullable=True, doc="ID of the tool call")
60
26
 
61
27
  # Relationships
62
- # TODO: Add in after Agent ORM is created
63
- # agent: Mapped["Agent"] = relationship("Agent", back_populates="messages", lazy="selectin")
28
+ agent: Mapped["Agent"] = relationship("Agent", back_populates="messages", lazy="selectin")
64
29
  organization: Mapped["Organization"] = relationship("Organization", back_populates="messages", lazy="selectin")
letta/orm/organization.py CHANGED
@@ -7,6 +7,7 @@ from letta.schemas.organization import Organization as PydanticOrganization
7
7
 
8
8
  if TYPE_CHECKING:
9
9
 
10
+ from letta.orm.agent import Agent
10
11
  from letta.orm.file import FileMetadata
11
12
  from letta.orm.tool import Tool
12
13
  from letta.orm.user import User
@@ -25,7 +26,6 @@ class Organization(SqlalchemyBase):
25
26
  tools: Mapped[List["Tool"]] = relationship("Tool", back_populates="organization", cascade="all, delete-orphan")
26
27
  blocks: Mapped[List["Block"]] = relationship("Block", back_populates="organization", cascade="all, delete-orphan")
27
28
  sources: Mapped[List["Source"]] = relationship("Source", back_populates="organization", cascade="all, delete-orphan")
28
- agents_tags: Mapped[List["AgentsTags"]] = relationship("AgentsTags", back_populates="organization", cascade="all, delete-orphan")
29
29
  files: Mapped[List["FileMetadata"]] = relationship("FileMetadata", back_populates="organization", cascade="all, delete-orphan")
30
30
  sandbox_configs: Mapped[List["SandboxConfig"]] = relationship(
31
31
  "SandboxConfig", back_populates="organization", cascade="all, delete-orphan"
@@ -36,10 +36,5 @@ class Organization(SqlalchemyBase):
36
36
 
37
37
  # relationships
38
38
  messages: Mapped[List["Message"]] = relationship("Message", back_populates="organization", cascade="all, delete-orphan")
39
+ agents: Mapped[List["Agent"]] = relationship("Agent", back_populates="organization", cascade="all, delete-orphan")
39
40
  passages: Mapped[List["Passage"]] = relationship("Passage", back_populates="organization", cascade="all, delete-orphan")
40
-
41
- # TODO: Map these relationships later when we actually make these models
42
- # below is just a suggestion
43
- # agents: Mapped[List["Agent"]] = relationship("Agent", back_populates="organization", cascade="all, delete-orphan")
44
- # tools: Mapped[List["Tool"]] = relationship("Tool", back_populates="organization", cascade="all, delete-orphan")
45
- # documents: Mapped[List["Document"]] = relationship("Document", back_populates="organization", cascade="all, delete-orphan")
letta/orm/passage.py CHANGED
@@ -1,19 +1,16 @@
1
1
  from datetime import datetime
2
- from typing import Optional, TYPE_CHECKING
3
- from sqlalchemy import Column, String, DateTime, JSON, ForeignKey
4
- from sqlalchemy.orm import Mapped, mapped_column, relationship
5
- from sqlalchemy.types import TypeDecorator, BINARY
2
+ from typing import TYPE_CHECKING, Optional
6
3
 
7
- import numpy as np
8
- import base64
4
+ from sqlalchemy import JSON, Column, DateTime, ForeignKey, String
5
+ from sqlalchemy.orm import Mapped, mapped_column, relationship
9
6
 
7
+ from letta.config import LettaConfig
8
+ from letta.constants import MAX_EMBEDDING_DIM
9
+ from letta.orm.custom_columns import CommonVector
10
+ from letta.orm.mixins import FileMixin, OrganizationMixin
10
11
  from letta.orm.source import EmbeddingConfigColumn
11
12
  from letta.orm.sqlalchemy_base import SqlalchemyBase
12
- from letta.orm.mixins import FileMixin, OrganizationMixin
13
13
  from letta.schemas.passage import Passage as PydanticPassage
14
-
15
- from letta.config import LettaConfig
16
- from letta.constants import MAX_EMBEDDING_DIM
17
14
  from letta.settings import settings
18
15
 
19
16
  config = LettaConfig()
@@ -21,32 +18,12 @@ config = LettaConfig()
21
18
  if TYPE_CHECKING:
22
19
  from letta.orm.organization import Organization
23
20
 
24
- class CommonVector(TypeDecorator):
25
- """Common type for representing vectors in SQLite"""
26
- impl = BINARY
27
- cache_ok = True
28
-
29
- def load_dialect_impl(self, dialect):
30
- return dialect.type_descriptor(BINARY())
31
-
32
- def process_bind_param(self, value, dialect):
33
- if value is None:
34
- return value
35
- if isinstance(value, list):
36
- value = np.array(value, dtype=np.float32)
37
- return base64.b64encode(value.tobytes())
38
21
 
39
- def process_result_value(self, value, dialect):
40
- if not value:
41
- return value
42
- if dialect.name == "sqlite":
43
- value = base64.b64decode(value)
44
- return np.frombuffer(value, dtype=np.float32)
45
-
46
- # TODO: After migration to Passage, will need to manually delete passages where files
22
+ # TODO: After migration to Passage, will need to manually delete passages where files
47
23
  # are deleted on web
48
24
  class Passage(SqlalchemyBase, OrganizationMixin, FileMixin):
49
25
  """Defines data model for storing Passages"""
26
+
50
27
  __tablename__ = "passages"
51
28
  __table_args__ = {"extend_existing": True}
52
29
  __pydantic_model__ = PydanticPassage
@@ -59,6 +36,7 @@ class Passage(SqlalchemyBase, OrganizationMixin, FileMixin):
59
36
  created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=datetime.utcnow)
60
37
  if settings.letta_pg_uri_no_default:
61
38
  from pgvector.sqlalchemy import Vector
39
+
62
40
  embedding = mapped_column(Vector(MAX_EMBEDDING_DIM))
63
41
  else:
64
42
  embedding = Column(CommonVector)
letta/orm/source.py CHANGED
@@ -1,8 +1,10 @@
1
1
  from typing import TYPE_CHECKING, List, Optional
2
2
 
3
- from sqlalchemy import JSON, TypeDecorator
3
+ from sqlalchemy import JSON
4
4
  from sqlalchemy.orm import Mapped, mapped_column, relationship
5
5
 
6
+ from letta.orm import FileMetadata
7
+ from letta.orm.custom_columns import EmbeddingConfigColumn
6
8
  from letta.orm.mixins import OrganizationMixin
7
9
  from letta.orm.sqlalchemy_base import SqlalchemyBase
8
10
  from letta.schemas.embedding_config import EmbeddingConfig
@@ -12,28 +14,6 @@ if TYPE_CHECKING:
12
14
  from letta.orm.organization import Organization
13
15
 
14
16
 
15
- class EmbeddingConfigColumn(TypeDecorator):
16
- """Custom type for storing EmbeddingConfig as JSON"""
17
-
18
- impl = JSON
19
- cache_ok = True
20
-
21
- def load_dialect_impl(self, dialect):
22
- return dialect.type_descriptor(JSON())
23
-
24
- def process_bind_param(self, value, dialect):
25
- if value:
26
- # return vars(value)
27
- if isinstance(value, EmbeddingConfig):
28
- return value.model_dump()
29
- return value
30
-
31
- def process_result_value(self, value, dialect):
32
- if value:
33
- return EmbeddingConfig(**value)
34
- return value
35
-
36
-
37
17
  class Source(SqlalchemyBase, OrganizationMixin):
38
18
  """A source represents an embedded text passage"""
39
19
 
@@ -47,5 +27,5 @@ class Source(SqlalchemyBase, OrganizationMixin):
47
27
 
48
28
  # relationships
49
29
  organization: Mapped["Organization"] = relationship("Organization", back_populates="sources")
50
- files: Mapped[List["Source"]] = relationship("FileMetadata", back_populates="source", cascade="all, delete-orphan")
51
- # agents: Mapped[List["Agent"]] = relationship("Agent", secondary="sources_agents", back_populates="sources")
30
+ files: Mapped[List["FileMetadata"]] = relationship("FileMetadata", back_populates="source", cascade="all, delete-orphan")
31
+ agents: Mapped[List["Agent"]] = relationship("Agent", secondary="sources_agents", back_populates="sources")
@@ -0,0 +1,13 @@
1
+ from sqlalchemy import ForeignKey, String
2
+ from sqlalchemy.orm import Mapped, mapped_column
3
+
4
+ from letta.orm.base import Base
5
+
6
+
7
class SourcesAgents(Base):
    """Association table: an agent can be attached to zero or more sources."""

    __tablename__ = "sources_agents"

    # Annotations use the Python type `str` (what Mapped[] expects), not the
    # SQLAlchemy `String` type object — matching the sibling ToolsAgents table.
    # TODO(review): consider ondelete="CASCADE" on both FKs for parity with
    # tools_agents — confirm the desired delete semantics first.
    agent_id: Mapped[str] = mapped_column(String, ForeignKey("agents.id"), primary_key=True)
    source_id: Mapped[str] = mapped_column(String, ForeignKey("sources.id"), primary_key=True)
@@ -1,7 +1,6 @@
1
1
  from datetime import datetime
2
2
  from enum import Enum
3
- from typing import TYPE_CHECKING, List, Literal, Optional, Type
4
- import sqlite3
3
+ from typing import TYPE_CHECKING, List, Literal, Optional
5
4
 
6
5
  from sqlalchemy import String, desc, func, or_, select
7
6
  from sqlalchemy.exc import DBAPIError
@@ -9,12 +8,12 @@ from sqlalchemy.orm import Mapped, Session, mapped_column
9
8
 
10
9
  from letta.log import get_logger
11
10
  from letta.orm.base import Base, CommonSqlalchemyMetaMixins
12
- from letta.orm.sqlite_functions import adapt_array, convert_array, cosine_distance
13
11
  from letta.orm.errors import (
14
12
  ForeignKeyConstraintViolationError,
15
13
  NoResultFound,
16
14
  UniqueConstraintViolationError,
17
15
  )
16
+ from letta.orm.sqlite_functions import adapt_array
18
17
 
19
18
  if TYPE_CHECKING:
20
19
  from pydantic import BaseModel
@@ -64,11 +63,26 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
64
63
  query_text: Optional[str] = None,
65
64
  query_embedding: Optional[List[float]] = None,
66
65
  ascending: bool = True,
66
+ tags: Optional[List[str]] = None,
67
+ match_all_tags: bool = False,
67
68
  **kwargs,
68
- ) -> List[Type["SqlalchemyBase"]]:
69
+ ) -> List["SqlalchemyBase"]:
69
70
  """
70
71
  List records with cursor-based pagination, ordering by created_at.
71
72
  Cursor is an ID, but pagination is based on the cursor object's created_at value.
73
+
74
+ Args:
75
+ db_session: SQLAlchemy session
76
+ cursor: ID of the last item seen (for pagination)
77
+ start_date: Filter items after this date
78
+ end_date: Filter items before this date
79
+ limit: Maximum number of items to return
80
+ query_text: Text to search for
81
+ query_embedding: Vector to search for similar embeddings
82
+ ascending: Sort direction
83
+ tags: List of tags to filter by
84
+ match_all_tags: If True, return items matching all tags. If False, match any tag.
85
+ **kwargs: Additional filters to apply
72
86
  """
73
87
  if start_date and end_date and start_date > end_date:
74
88
  raise ValueError("start_date must be earlier than or equal to end_date")
@@ -84,7 +98,25 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
84
98
 
85
99
  query = select(cls)
86
100
 
87
- # Apply filtering logic
101
+ # Handle tag filtering if the model has tags
102
+ if tags and hasattr(cls, "tags"):
103
+ query = select(cls)
104
+
105
+ if match_all_tags:
106
+ # Match ALL tags - use subqueries
107
+ for tag in tags:
108
+ subquery = select(cls.tags.property.mapper.class_.agent_id).where(cls.tags.property.mapper.class_.tag == tag)
109
+ query = query.filter(cls.id.in_(subquery))
110
+ else:
111
+ # Match ANY tag - use join and filter
112
+ query = (
113
+ query.join(cls.tags).filter(cls.tags.property.mapper.class_.tag.in_(tags)).group_by(cls.id) # Deduplicate results
114
+ )
115
+
116
+ # Group by primary key and all necessary columns to avoid JSON comparison
117
+ query = query.group_by(cls.id)
118
+
119
+ # Apply filtering logic from kwargs
88
120
  for key, value in kwargs.items():
89
121
  column = getattr(cls, key)
90
122
  if isinstance(value, (list, tuple, set)):
@@ -98,9 +130,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
98
130
  if end_date:
99
131
  query = query.filter(cls.created_at < end_date)
100
132
 
101
- # Cursor-based pagination using created_at
102
- # TODO: There is a really nasty race condition issue here with Sqlite
103
- # TODO: If they have the same created_at timestamp, this query does NOT match for whatever reason
133
+ # Cursor-based pagination
104
134
  if cursor_obj:
105
135
  if ascending:
106
136
  query = query.where(cls.created_at >= cursor_obj.created_at).where(
@@ -111,40 +141,34 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
111
141
  or_(cls.created_at < cursor_obj.created_at, cls.id < cursor_obj.id)
112
142
  )
113
143
 
114
- # Apply text search
144
+ # Text search
115
145
  if query_text:
116
- from sqlalchemy import func
117
146
  query = query.filter(func.lower(cls.text).contains(func.lower(query_text)))
118
147
 
119
- # Apply embedding search (Passages)
148
+ # Embedding search (for Passages)
120
149
  is_ordered = False
121
150
  if query_embedding:
122
- # check if embedding column exists. should only exist for passages
123
151
  if not hasattr(cls, "embedding"):
124
152
  raise ValueError(f"Class {cls.__name__} does not have an embedding column")
125
-
153
+
126
154
  from letta.settings import settings
155
+
127
156
  if settings.letta_pg_uri_no_default:
128
157
  # PostgreSQL with pgvector
129
- from pgvector.sqlalchemy import Vector
130
158
  query = query.order_by(cls.embedding.cosine_distance(query_embedding).asc())
131
159
  else:
132
160
  # SQLite with custom vector type
133
- from sqlalchemy import func
134
-
135
161
  query_embedding_binary = adapt_array(query_embedding)
136
162
  query = query.order_by(
137
- func.cosine_distance(cls.embedding, query_embedding_binary).asc(),
138
- cls.created_at.asc(),
139
- cls.id.asc()
163
+ func.cosine_distance(cls.embedding, query_embedding_binary).asc(), cls.created_at.asc(), cls.id.asc()
140
164
  )
141
165
  is_ordered = True
142
166
 
143
- # Handle ordering and soft deletes
167
+ # Handle soft deletes
144
168
  if hasattr(cls, "is_deleted"):
145
169
  query = query.where(cls.is_deleted == False)
146
-
147
- # Apply ordering by created_at
170
+
171
+ # Apply ordering
148
172
  if not is_ordered:
149
173
  if ascending:
150
174
  query = query.order_by(cls.created_at, cls.id)
@@ -164,7 +188,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
164
188
  access: Optional[List[Literal["read", "write", "admin"]]] = ["read"],
165
189
  access_type: AccessType = AccessType.ORGANIZATION,
166
190
  **kwargs,
167
- ) -> Type["SqlalchemyBase"]:
191
+ ) -> "SqlalchemyBase":
168
192
  """The primary accessor for an ORM record.
169
193
  Args:
170
194
  db_session: the database session to use when retrieving the record
@@ -207,7 +231,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
207
231
  conditions_str = ", ".join(query_conditions) if query_conditions else "no specific conditions"
208
232
  raise NoResultFound(f"{cls.__name__} not found with {conditions_str}")
209
233
 
210
- def create(self, db_session: "Session", actor: Optional["User"] = None) -> Type["SqlalchemyBase"]:
234
+ def create(self, db_session: "Session", actor: Optional["User"] = None) -> "SqlalchemyBase":
211
235
  logger.debug(f"Creating {self.__class__.__name__} with ID: {self.id} with actor={actor}")
212
236
 
213
237
  if actor:
@@ -221,7 +245,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
221
245
  except DBAPIError as e:
222
246
  self._handle_dbapi_error(e)
223
247
 
224
- def delete(self, db_session: "Session", actor: Optional["User"] = None) -> Type["SqlalchemyBase"]:
248
+ def delete(self, db_session: "Session", actor: Optional["User"] = None) -> "SqlalchemyBase":
225
249
  logger.debug(f"Soft deleting {self.__class__.__name__} with ID: {self.id} with actor={actor}")
226
250
 
227
251
  if actor:
@@ -245,7 +269,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
245
269
  else:
246
270
  logger.debug(f"{self.__class__.__name__} with ID {self.id} successfully hard deleted")
247
271
 
248
- def update(self, db_session: "Session", actor: Optional["User"] = None) -> Type["SqlalchemyBase"]:
272
+ def update(self, db_session: "Session", actor: Optional["User"] = None) -> "SqlalchemyBase":
249
273
  logger.debug(f"Updating {self.__class__.__name__} with ID: {self.id} with actor={actor}")
250
274
  if actor:
251
275
  self._set_created_and_updated_by_fields(actor.id)
@@ -388,14 +412,14 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
388
412
  raise
389
413
 
390
414
  @property
391
- def __pydantic_model__(self) -> Type["BaseModel"]:
415
+ def __pydantic_model__(self) -> "BaseModel":
392
416
  raise NotImplementedError("Sqlalchemy models must declare a __pydantic_model__ property to be convertable.")
393
417
 
394
- def to_pydantic(self) -> Type["BaseModel"]:
418
+ def to_pydantic(self) -> "BaseModel":
395
419
  """converts to the basic pydantic model counterpart"""
396
420
  return self.__pydantic_model__.model_validate(self)
397
421
 
398
- def to_record(self) -> Type["BaseModel"]:
422
+ def to_record(self) -> "BaseModel":
399
423
  """Deprecated accessor for to_pydantic"""
400
424
  logger.warning("to_record is deprecated, use to_pydantic instead.")
401
- return self.to_pydantic()
425
+ return self.to_pydantic()
letta/orm/tool.py CHANGED
@@ -1,6 +1,6 @@
1
1
  from typing import TYPE_CHECKING, List, Optional
2
2
 
3
- from sqlalchemy import JSON, String, UniqueConstraint, event
3
+ from sqlalchemy import JSON, String, UniqueConstraint
4
4
  from sqlalchemy.orm import Mapped, mapped_column, relationship
5
5
 
6
6
  # TODO everything in functions should live in this model
@@ -11,7 +11,6 @@ from letta.schemas.tool import Tool as PydanticTool
11
11
 
12
12
  if TYPE_CHECKING:
13
13
  from letta.orm.organization import Organization
14
- from letta.orm.tools_agents import ToolsAgents
15
14
 
16
15
 
17
16
  class Tool(SqlalchemyBase, OrganizationMixin):
@@ -42,20 +41,3 @@ class Tool(SqlalchemyBase, OrganizationMixin):
42
41
 
43
42
  # relationships
44
43
  organization: Mapped["Organization"] = relationship("Organization", back_populates="tools", lazy="selectin")
45
- tools_agents: Mapped[List["ToolsAgents"]] = relationship("ToolsAgents", back_populates="tool", cascade="all, delete-orphan")
46
-
47
-
48
- # Add event listener to update tool_name in ToolsAgents when Tool name changes
49
- @event.listens_for(Tool, "before_update")
50
- def update_tool_name_in_tools_agents(mapper, connection, target):
51
- """Update tool_name in ToolsAgents when Tool name changes."""
52
- state = target._sa_instance_state
53
- history = state.get_history("name", passive=True)
54
- if not history.has_changes():
55
- return
56
-
57
- # Get the new name and update all associated ToolsAgents records
58
- new_name = target.name
59
- from letta.orm.tools_agents import ToolsAgents
60
-
61
- connection.execute(ToolsAgents.__table__.update().where(ToolsAgents.tool_id == target.id).values(tool_name=new_name))
letta/orm/tools_agents.py CHANGED
@@ -1,32 +1,15 @@
1
- from sqlalchemy import ForeignKey, ForeignKeyConstraint, String, UniqueConstraint
2
- from sqlalchemy.orm import Mapped, mapped_column, relationship
1
+ from sqlalchemy import ForeignKey, String, UniqueConstraint
2
+ from sqlalchemy.orm import Mapped, mapped_column
3
3
 
4
- from letta.orm.sqlalchemy_base import SqlalchemyBase
5
- from letta.schemas.tools_agents import ToolsAgents as PydanticToolsAgents
4
+ from letta.orm import Base
6
5
 
7
6
 
8
class ToolsAgents(Base):
    """Agents can have one or many tools associated with them."""

    __tablename__ = "tools_agents"
    # The composite primary key already makes (agent_id, tool_id) unique; the
    # named constraint keeps that guarantee explicit at the schema level.
    __table_args__ = (UniqueConstraint("agent_id", "tool_id", name="unique_agent_tool"),)

    # Each (agent, tool) pairing appears at most once; rows are dropped by the
    # database when either the agent or the tool is deleted (ON DELETE CASCADE).
    agent_id: Mapped[str] = mapped_column(String, ForeignKey("agents.id", ondelete="CASCADE"), primary_key=True)
    tool_id: Mapped[str] = mapped_column(String, ForeignKey("tools.id", ondelete="CASCADE"), primary_key=True)
letta/orm/user.py CHANGED
@@ -20,10 +20,9 @@ class User(SqlalchemyBase, OrganizationMixin):
20
20
 
21
21
  # relationships
22
22
  organization: Mapped["Organization"] = relationship("Organization", back_populates="users")
23
- jobs: Mapped[List["Job"]] = relationship("Job", back_populates="user", doc="the jobs associated with this user.", cascade="all, delete-orphan")
23
+ jobs: Mapped[List["Job"]] = relationship(
24
+ "Job", back_populates="user", doc="the jobs associated with this user.", cascade="all, delete-orphan"
25
+ )
24
26
 
25
27
  # TODO: Add this back later potentially
26
- # agents: Mapped[List["Agent"]] = relationship(
27
- # "Agent", secondary="users_agents", back_populates="users", doc="the agents associated with this user."
28
- # )
29
28
  # tokens: Mapped[List["Token"]] = relationship("Token", back_populates="user", doc="the tokens associated with this user.")