letta-nightly 0.6.4.dev20241215104129__py3-none-any.whl → 0.6.4.dev20241217104233__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of letta-nightly has been flagged as potentially problematic; the original diff page links to further details on why it was flagged.
- letta/agent.py +28 -37
- letta/functions/function_sets/base.py +3 -1
- letta/functions/schema_generator.py +1 -5
- letta/local_llm/function_parser.py +1 -1
- letta/orm/__init__.py +1 -1
- letta/orm/agent.py +19 -1
- letta/orm/file.py +3 -2
- letta/orm/mixins.py +3 -14
- letta/orm/organization.py +19 -3
- letta/orm/passage.py +59 -23
- letta/orm/source.py +4 -0
- letta/orm/sqlalchemy_base.py +2 -2
- letta/prompts/system/memgpt_modified_chat.txt +1 -1
- letta/prompts/system/memgpt_modified_o1.txt +1 -1
- letta/schemas/embedding_config.py +20 -2
- letta/schemas/passage.py +1 -1
- letta/server/rest_api/app.py +13 -0
- letta/server/rest_api/utils.py +24 -5
- letta/server/server.py +31 -114
- letta/server/ws_api/server.py +1 -1
- letta/services/agent_manager.py +341 -9
- letta/services/passage_manager.py +76 -100
- letta/settings.py +1 -1
- {letta_nightly-0.6.4.dev20241215104129.dist-info → letta_nightly-0.6.4.dev20241217104233.dist-info}/METADATA +6 -6
- {letta_nightly-0.6.4.dev20241215104129.dist-info → letta_nightly-0.6.4.dev20241217104233.dist-info}/RECORD +28 -28
- {letta_nightly-0.6.4.dev20241215104129.dist-info → letta_nightly-0.6.4.dev20241217104233.dist-info}/LICENSE +0 -0
- {letta_nightly-0.6.4.dev20241215104129.dist-info → letta_nightly-0.6.4.dev20241217104233.dist-info}/WHEEL +0 -0
- {letta_nightly-0.6.4.dev20241215104129.dist-info → letta_nightly-0.6.4.dev20241217104233.dist-info}/entry_points.txt +0 -0
letta/server/server.py
CHANGED
|
@@ -19,7 +19,6 @@ from letta.agent import Agent, save_agent
|
|
|
19
19
|
from letta.chat_only_agent import ChatOnlyAgent
|
|
20
20
|
from letta.credentials import LettaCredentials
|
|
21
21
|
from letta.data_sources.connectors import DataConnector, load_data
|
|
22
|
-
from letta.errors import LettaAgentNotFoundError
|
|
23
22
|
|
|
24
23
|
# TODO use custom interface
|
|
25
24
|
from letta.interface import AgentInterface # abstract
|
|
@@ -399,9 +398,6 @@ class SyncServer(Server):
|
|
|
399
398
|
with agent_lock:
|
|
400
399
|
agent_state = self.agent_manager.get_agent_by_id(agent_id=agent_id, actor=actor)
|
|
401
400
|
|
|
402
|
-
if agent_state is None:
|
|
403
|
-
raise LettaAgentNotFoundError(f"Agent (agent_id={agent_id}) does not exist")
|
|
404
|
-
|
|
405
401
|
interface = interface or self.default_interface_factory()
|
|
406
402
|
if agent_state.agent_type == AgentType.memgpt_agent:
|
|
407
403
|
agent = Agent(agent_state=agent_state, interface=interface, user=actor)
|
|
@@ -824,13 +820,13 @@ class SyncServer(Server):
|
|
|
824
820
|
actor: User,
|
|
825
821
|
) -> AgentState:
|
|
826
822
|
"""Update the agents core memory block, return the new state"""
|
|
823
|
+
# Update agent state in the db first
|
|
824
|
+
self.agent_manager.update_agent(agent_id=agent_id, agent_update=request, actor=actor)
|
|
825
|
+
|
|
827
826
|
# Get the agent object (loaded in memory)
|
|
828
827
|
letta_agent = self.load_agent(agent_id=agent_id, actor=actor)
|
|
829
828
|
|
|
830
|
-
#
|
|
831
|
-
if request.tags is not None: # Allow for empty list
|
|
832
|
-
letta_agent.agent_state.tags = request.tags
|
|
833
|
-
|
|
829
|
+
# TODO: Everything below needs to get removed, no updating anything in memory
|
|
834
830
|
# update the system prompt
|
|
835
831
|
if request.system:
|
|
836
832
|
letta_agent.update_system_prompt(request.system)
|
|
@@ -844,42 +840,10 @@ class SyncServer(Server):
|
|
|
844
840
|
|
|
845
841
|
# tools
|
|
846
842
|
if request.tool_ids:
|
|
847
|
-
|
|
848
|
-
|
|
849
|
-
|
|
850
|
-
|
|
851
|
-
current_tools = letta_agent.agent_state.tools
|
|
852
|
-
current_tool_ids = set([t.id for t in current_tools])
|
|
853
|
-
target_tool_ids = set(request.tool_ids)
|
|
854
|
-
|
|
855
|
-
# Calculate tools to add and remove
|
|
856
|
-
tool_ids_to_add = target_tool_ids - current_tool_ids
|
|
857
|
-
tools_ids_to_remove = current_tool_ids - target_tool_ids
|
|
858
|
-
|
|
859
|
-
# update agent tool list
|
|
860
|
-
for tool_id in tools_ids_to_remove:
|
|
861
|
-
self.remove_tool_from_agent(agent_id=agent_id, tool_id=tool_id, user_id=actor.id)
|
|
862
|
-
for tool_id in tool_ids_to_add:
|
|
863
|
-
self.add_tool_to_agent(agent_id=agent_id, tool_id=tool_id, user_id=actor.id)
|
|
864
|
-
|
|
865
|
-
# reload agent
|
|
866
|
-
letta_agent = self.load_agent(agent_id=agent_id, actor=actor)
|
|
867
|
-
|
|
868
|
-
# configs
|
|
869
|
-
if request.llm_config:
|
|
870
|
-
letta_agent.agent_state.llm_config = request.llm_config
|
|
871
|
-
if request.embedding_config:
|
|
872
|
-
letta_agent.agent_state.embedding_config = request.embedding_config
|
|
873
|
-
|
|
874
|
-
# other minor updates
|
|
875
|
-
if request.name:
|
|
876
|
-
letta_agent.agent_state.name = request.name
|
|
877
|
-
if request.metadata_:
|
|
878
|
-
letta_agent.agent_state.metadata_ = request.metadata_
|
|
879
|
-
|
|
880
|
-
# save the agent
|
|
881
|
-
save_agent(letta_agent)
|
|
882
|
-
# TODO: probably reload the agent somehow?
|
|
843
|
+
letta_agent.link_tools(letta_agent.agent_state.tools)
|
|
844
|
+
|
|
845
|
+
letta_agent.update_state()
|
|
846
|
+
|
|
883
847
|
return letta_agent.agent_state
|
|
884
848
|
|
|
885
849
|
def get_tools_from_agent(self, agent_id: str, user_id: Optional[str]) -> List[Tool]:
|
|
@@ -901,32 +865,14 @@ class SyncServer(Server):
|
|
|
901
865
|
# TODO: Thread actor directly through this function, since the top level caller most likely already retrieved the user
|
|
902
866
|
actor = self.user_manager.get_user_or_default(user_id=user_id)
|
|
903
867
|
|
|
868
|
+
agent_state = self.agent_manager.attach_tool(agent_id=agent_id, tool_id=tool_id, actor=actor)
|
|
869
|
+
|
|
870
|
+
# TODO: This is very redundant, and should probably be simplified
|
|
904
871
|
# Get the agent object (loaded in memory)
|
|
905
872
|
letta_agent = self.load_agent(agent_id=agent_id, actor=actor)
|
|
873
|
+
letta_agent.link_tools(agent_state.tools)
|
|
906
874
|
|
|
907
|
-
|
|
908
|
-
tool_objs = []
|
|
909
|
-
tool_obj = self.tool_manager.get_tool_by_id(tool_id=tool_id, actor=actor)
|
|
910
|
-
assert tool_obj, f"Tool with id={tool_id} does not exist"
|
|
911
|
-
tool_objs.append(tool_obj)
|
|
912
|
-
|
|
913
|
-
for tool in letta_agent.agent_state.tools:
|
|
914
|
-
tool_obj = self.tool_manager.get_tool_by_id(tool_id=tool.id, actor=actor)
|
|
915
|
-
assert tool_obj, f"Tool with id={tool.id} does not exist"
|
|
916
|
-
|
|
917
|
-
# If it's not the already added tool
|
|
918
|
-
if tool_obj.id != tool_id:
|
|
919
|
-
tool_objs.append(tool_obj)
|
|
920
|
-
|
|
921
|
-
# replace the list of tool names ("ids") inside the agent state
|
|
922
|
-
letta_agent.agent_state.tools = tool_objs
|
|
923
|
-
|
|
924
|
-
# then attempt to link the tools modules
|
|
925
|
-
letta_agent.link_tools(tool_objs)
|
|
926
|
-
|
|
927
|
-
# save the agent
|
|
928
|
-
save_agent(letta_agent)
|
|
929
|
-
return letta_agent.agent_state
|
|
875
|
+
return agent_state
|
|
930
876
|
|
|
931
877
|
def remove_tool_from_agent(
|
|
932
878
|
self,
|
|
@@ -937,29 +883,13 @@ class SyncServer(Server):
|
|
|
937
883
|
"""Remove tools from an existing agent"""
|
|
938
884
|
# TODO: Thread actor directly through this function, since the top level caller most likely already retrieved the user
|
|
939
885
|
actor = self.user_manager.get_user_or_default(user_id=user_id)
|
|
886
|
+
agent_state = self.agent_manager.detach_tool(agent_id=agent_id, tool_id=tool_id, actor=actor)
|
|
940
887
|
|
|
941
888
|
# Get the agent object (loaded in memory)
|
|
942
889
|
letta_agent = self.load_agent(agent_id=agent_id, actor=actor)
|
|
890
|
+
letta_agent.link_tools(agent_state.tools)
|
|
943
891
|
|
|
944
|
-
|
|
945
|
-
tool_objs = []
|
|
946
|
-
for tool in letta_agent.agent_state.tools:
|
|
947
|
-
tool_obj = self.tool_manager.get_tool_by_id(tool_id=tool.id, actor=actor)
|
|
948
|
-
assert tool_obj, f"Tool with id={tool.id} does not exist"
|
|
949
|
-
|
|
950
|
-
# If it's not the tool we want to remove
|
|
951
|
-
if tool_obj.id != tool_id:
|
|
952
|
-
tool_objs.append(tool_obj)
|
|
953
|
-
|
|
954
|
-
# replace the list of tool names ("ids") inside the agent state
|
|
955
|
-
letta_agent.agent_state.tools = tool_objs
|
|
956
|
-
|
|
957
|
-
# then attempt to link the tools modules
|
|
958
|
-
letta_agent.link_tools(tool_objs)
|
|
959
|
-
|
|
960
|
-
# save the agent
|
|
961
|
-
save_agent(letta_agent)
|
|
962
|
-
return letta_agent.agent_state
|
|
892
|
+
return agent_state
|
|
963
893
|
|
|
964
894
|
# convert name->id
|
|
965
895
|
|
|
@@ -970,7 +900,7 @@ class SyncServer(Server):
|
|
|
970
900
|
|
|
971
901
|
def get_archival_memory_summary(self, agent_id: str, actor: User) -> ArchivalMemorySummary:
|
|
972
902
|
agent = self.load_agent(agent_id=agent_id, actor=actor)
|
|
973
|
-
return ArchivalMemorySummary(size=
|
|
903
|
+
return ArchivalMemorySummary(size=self.agent_manager.passage_size(actor=actor, agent_id=agent_id))
|
|
974
904
|
|
|
975
905
|
def get_recall_memory_summary(self, agent_id: str, actor: User) -> RecallMemorySummary:
|
|
976
906
|
agent = self.load_agent(agent_id=agent_id, actor=actor)
|
|
@@ -987,18 +917,9 @@ class SyncServer(Server):
|
|
|
987
917
|
# TODO: Thread actor directly through this function, since the top level caller most likely already retrieved the user
|
|
988
918
|
actor = self.user_manager.get_user_or_default(user_id=user_id)
|
|
989
919
|
|
|
990
|
-
|
|
991
|
-
letta_agent = self.load_agent(agent_id=agent_id, actor=actor)
|
|
992
|
-
|
|
993
|
-
# iterate over records
|
|
994
|
-
records = letta_agent.passage_manager.list_passages(
|
|
995
|
-
actor=actor,
|
|
996
|
-
agent_id=agent_id,
|
|
997
|
-
cursor=cursor,
|
|
998
|
-
limit=limit,
|
|
999
|
-
)
|
|
920
|
+
passages = self.agent_manager.list_passages(agent_id=agent_id, actor=actor)
|
|
1000
921
|
|
|
1001
|
-
return
|
|
922
|
+
return passages
|
|
1002
923
|
|
|
1003
924
|
def get_agent_archival_cursor(
|
|
1004
925
|
self,
|
|
@@ -1012,15 +933,13 @@ class SyncServer(Server):
|
|
|
1012
933
|
# TODO: Thread actor directly through this function, since the top level caller most likely already retrieved the user
|
|
1013
934
|
actor = self.user_manager.get_user_or_default(user_id=user_id)
|
|
1014
935
|
|
|
1015
|
-
# Get the agent object (loaded in memory)
|
|
1016
|
-
letta_agent = self.load_agent(agent_id=agent_id, actor=actor)
|
|
1017
|
-
|
|
1018
936
|
# iterate over records
|
|
1019
|
-
records =
|
|
1020
|
-
actor=
|
|
937
|
+
records = self.agent_manager.list_passages(
|
|
938
|
+
actor=actor,
|
|
1021
939
|
agent_id=agent_id,
|
|
1022
940
|
cursor=cursor,
|
|
1023
941
|
limit=limit,
|
|
942
|
+
ascending=not reverse,
|
|
1024
943
|
)
|
|
1025
944
|
return records
|
|
1026
945
|
|
|
@@ -1105,7 +1024,7 @@ class SyncServer(Server):
|
|
|
1105
1024
|
config_copy[k] = server_utils.shorten_key_middle(v, chars_each_side=5)
|
|
1106
1025
|
return config_copy
|
|
1107
1026
|
|
|
1108
|
-
# TODO: do we need a
|
|
1027
|
+
# TODO: do we need a separate server config?
|
|
1109
1028
|
base_config = vars(self.config)
|
|
1110
1029
|
clean_base_config = clean_keys(base_config)
|
|
1111
1030
|
|
|
@@ -1136,7 +1055,8 @@ class SyncServer(Server):
|
|
|
1136
1055
|
self.source_manager.delete_source(source_id=source_id, actor=actor)
|
|
1137
1056
|
|
|
1138
1057
|
# delete data from passage store
|
|
1139
|
-
self.
|
|
1058
|
+
passages_to_be_deleted = self.agent_manager.list_passages(actor=actor, source_id=source_id, limit=None)
|
|
1059
|
+
self.passage_manager.delete_passages(actor=actor, passages=passages_to_be_deleted)
|
|
1140
1060
|
|
|
1141
1061
|
# TODO: delete data from agent passage stores (?)
|
|
1142
1062
|
|
|
@@ -1167,9 +1087,11 @@ class SyncServer(Server):
|
|
|
1167
1087
|
for agent_state in agent_states:
|
|
1168
1088
|
agent_id = agent_state.id
|
|
1169
1089
|
agent = self.load_agent(agent_id=agent_id, actor=actor)
|
|
1170
|
-
|
|
1090
|
+
|
|
1091
|
+
# Attach source to agent
|
|
1092
|
+
curr_passage_size = self.agent_manager.passage_size(actor=actor, agent_id=agent_id)
|
|
1171
1093
|
agent.attach_source(user=actor, source_id=source_id, source_manager=self.source_manager, agent_manager=self.agent_manager)
|
|
1172
|
-
new_passage_size = self.
|
|
1094
|
+
new_passage_size = self.agent_manager.passage_size(actor=actor, agent_id=agent_id)
|
|
1173
1095
|
assert new_passage_size >= curr_passage_size # in case empty files are added
|
|
1174
1096
|
|
|
1175
1097
|
return job
|
|
@@ -1233,14 +1155,9 @@ class SyncServer(Server):
|
|
|
1233
1155
|
source = self.source_manager.get_source_by_id(source_id=source_id, actor=actor)
|
|
1234
1156
|
elif source_name:
|
|
1235
1157
|
source = self.source_manager.get_source_by_name(source_name=source_name, actor=actor)
|
|
1158
|
+
source_id = source.id
|
|
1236
1159
|
else:
|
|
1237
1160
|
raise ValueError(f"Need to provide at least source_id or source_name to find the source.")
|
|
1238
|
-
source_id = source.id
|
|
1239
|
-
|
|
1240
|
-
# TODO: This should be done with the ORM?
|
|
1241
|
-
# delete all Passage objects with source_id==source_id from agent's archival memory
|
|
1242
|
-
agent = self.load_agent(agent_id=agent_id, actor=actor)
|
|
1243
|
-
agent.passage_manager.delete_passages(actor=actor, limit=100, source_id=source_id)
|
|
1244
1161
|
|
|
1245
1162
|
# delete agent-source mapping
|
|
1246
1163
|
self.agent_manager.detach_source(agent_id=agent_id, source_id=source_id, actor=actor)
|
|
@@ -1262,7 +1179,7 @@ class SyncServer(Server):
|
|
|
1262
1179
|
for source in sources:
|
|
1263
1180
|
|
|
1264
1181
|
# count number of passages
|
|
1265
|
-
num_passages = self.
|
|
1182
|
+
num_passages = self.agent_manager.passage_size(actor=actor, source_id=source.id)
|
|
1266
1183
|
|
|
1267
1184
|
# TODO: add when files table implemented
|
|
1268
1185
|
## count number of files
|
letta/server/ws_api/server.py
CHANGED
|
@@ -33,7 +33,7 @@ class WebSocketServer:
|
|
|
33
33
|
self.initialize_server()
|
|
34
34
|
# Can play with ping_interval and ping_timeout
|
|
35
35
|
# See: https://websockets.readthedocs.io/en/stable/topics/timeouts.html
|
|
36
|
-
# and https://github.com/
|
|
36
|
+
# and https://github.com/letta-ai/letta/issues/471
|
|
37
37
|
async with websockets.serve(self.handle_client, self.host, self.port):
|
|
38
38
|
await asyncio.Future() # Run forever
|
|
39
39
|
|
letta/services/agent_manager.py
CHANGED
|
@@ -1,16 +1,26 @@
|
|
|
1
1
|
from typing import Dict, List, Optional
|
|
2
|
+
from datetime import datetime
|
|
3
|
+
import numpy as np
|
|
2
4
|
|
|
3
|
-
from
|
|
5
|
+
from sqlalchemy import select, union_all, literal, func, Select
|
|
6
|
+
|
|
7
|
+
from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS, MAX_EMBEDDING_DIM
|
|
8
|
+
from letta.embeddings import embedding_model
|
|
9
|
+
from letta.log import get_logger
|
|
4
10
|
from letta.orm import Agent as AgentModel
|
|
5
11
|
from letta.orm import Block as BlockModel
|
|
6
12
|
from letta.orm import Source as SourceModel
|
|
7
13
|
from letta.orm import Tool as ToolModel
|
|
14
|
+
from letta.orm import AgentPassage, SourcePassage
|
|
15
|
+
from letta.orm import SourcesAgents
|
|
8
16
|
from letta.orm.errors import NoResultFound
|
|
17
|
+
from letta.orm.sqlite_functions import adapt_array
|
|
9
18
|
from letta.schemas.agent import AgentState as PydanticAgentState
|
|
10
19
|
from letta.schemas.agent import AgentType, CreateAgent, UpdateAgent
|
|
11
20
|
from letta.schemas.block import Block as PydanticBlock
|
|
12
21
|
from letta.schemas.embedding_config import EmbeddingConfig
|
|
13
22
|
from letta.schemas.llm_config import LLMConfig
|
|
23
|
+
from letta.schemas.passage import Passage as PydanticPassage
|
|
14
24
|
from letta.schemas.source import Source as PydanticSource
|
|
15
25
|
from letta.schemas.tool_rule import ToolRule as PydanticToolRule
|
|
16
26
|
from letta.schemas.user import User as PydanticUser
|
|
@@ -20,11 +30,13 @@ from letta.services.helpers.agent_manager_helper import (
|
|
|
20
30
|
_process_tags,
|
|
21
31
|
derive_system_message,
|
|
22
32
|
)
|
|
23
|
-
from letta.services.passage_manager import PassageManager
|
|
24
33
|
from letta.services.source_manager import SourceManager
|
|
25
34
|
from letta.services.tool_manager import ToolManager
|
|
35
|
+
from letta.settings import settings
|
|
26
36
|
from letta.utils import enforce_types
|
|
27
37
|
|
|
38
|
+
logger = get_logger(__name__)
|
|
39
|
+
|
|
28
40
|
|
|
29
41
|
# Agent Manager Class
|
|
30
42
|
class AgentManager:
|
|
@@ -226,13 +238,6 @@ class AgentManager:
|
|
|
226
238
|
with self.session_maker() as session:
|
|
227
239
|
# Retrieve the agent
|
|
228
240
|
agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
|
|
229
|
-
|
|
230
|
-
# TODO: @mindy delete this piece when we have a proper passages/sources implementation
|
|
231
|
-
# TODO: This is done very hacky on purpose
|
|
232
|
-
# TODO: 1000 limit is also wack
|
|
233
|
-
passage_manager = PassageManager()
|
|
234
|
-
passage_manager.delete_passages(actor=actor, agent_id=agent_id, limit=1000)
|
|
235
|
-
|
|
236
241
|
agent_state = agent.to_pydantic()
|
|
237
242
|
agent.hard_delete(session)
|
|
238
243
|
return agent_state
|
|
@@ -403,3 +408,330 @@ class AgentManager:
|
|
|
403
408
|
|
|
404
409
|
agent.update(session, actor=actor)
|
|
405
410
|
return agent.to_pydantic()
|
|
411
|
+
|
|
412
|
+
# ======================================================================================================================
|
|
413
|
+
# Passage Management
|
|
414
|
+
# ======================================================================================================================
|
|
415
|
+
def _build_passage_query(
|
|
416
|
+
self,
|
|
417
|
+
actor: PydanticUser,
|
|
418
|
+
agent_id: Optional[str] = None,
|
|
419
|
+
file_id: Optional[str] = None,
|
|
420
|
+
query_text: Optional[str] = None,
|
|
421
|
+
start_date: Optional[datetime] = None,
|
|
422
|
+
end_date: Optional[datetime] = None,
|
|
423
|
+
cursor: Optional[str] = None,
|
|
424
|
+
source_id: Optional[str] = None,
|
|
425
|
+
embed_query: bool = False,
|
|
426
|
+
ascending: bool = True,
|
|
427
|
+
embedding_config: Optional[EmbeddingConfig] = None,
|
|
428
|
+
agent_only: bool = False,
|
|
429
|
+
) -> Select:
|
|
430
|
+
"""Helper function to build the base passage query with all filters applied.
|
|
431
|
+
|
|
432
|
+
Returns the query before any limit or count operations are applied.
|
|
433
|
+
"""
|
|
434
|
+
embedded_text = None
|
|
435
|
+
if embed_query:
|
|
436
|
+
assert embedding_config is not None, "embedding_config must be specified for vector search"
|
|
437
|
+
assert query_text is not None, "query_text must be specified for vector search"
|
|
438
|
+
embedded_text = embedding_model(embedding_config).get_text_embedding(query_text)
|
|
439
|
+
embedded_text = np.array(embedded_text)
|
|
440
|
+
embedded_text = np.pad(embedded_text, (0, MAX_EMBEDDING_DIM - embedded_text.shape[0]), mode="constant").tolist()
|
|
441
|
+
|
|
442
|
+
with self.session_maker() as session:
|
|
443
|
+
# Start with base query for source passages
|
|
444
|
+
source_passages = None
|
|
445
|
+
if not agent_only: # Include source passages
|
|
446
|
+
if agent_id is not None:
|
|
447
|
+
source_passages = (
|
|
448
|
+
select(
|
|
449
|
+
SourcePassage,
|
|
450
|
+
literal(None).label('agent_id')
|
|
451
|
+
)
|
|
452
|
+
.join(SourcesAgents, SourcesAgents.source_id == SourcePassage.source_id)
|
|
453
|
+
.where(SourcesAgents.agent_id == agent_id)
|
|
454
|
+
.where(SourcePassage.organization_id == actor.organization_id)
|
|
455
|
+
)
|
|
456
|
+
else:
|
|
457
|
+
source_passages = (
|
|
458
|
+
select(
|
|
459
|
+
SourcePassage,
|
|
460
|
+
literal(None).label('agent_id')
|
|
461
|
+
)
|
|
462
|
+
.where(SourcePassage.organization_id == actor.organization_id)
|
|
463
|
+
)
|
|
464
|
+
|
|
465
|
+
if source_id:
|
|
466
|
+
source_passages = source_passages.where(SourcePassage.source_id == source_id)
|
|
467
|
+
if file_id:
|
|
468
|
+
source_passages = source_passages.where(SourcePassage.file_id == file_id)
|
|
469
|
+
|
|
470
|
+
# Add agent passages query
|
|
471
|
+
agent_passages = None
|
|
472
|
+
if agent_id is not None:
|
|
473
|
+
agent_passages = (
|
|
474
|
+
select(
|
|
475
|
+
AgentPassage.id,
|
|
476
|
+
AgentPassage.text,
|
|
477
|
+
AgentPassage.embedding_config,
|
|
478
|
+
AgentPassage.metadata_,
|
|
479
|
+
AgentPassage.embedding,
|
|
480
|
+
AgentPassage.created_at,
|
|
481
|
+
AgentPassage.updated_at,
|
|
482
|
+
AgentPassage.is_deleted,
|
|
483
|
+
AgentPassage._created_by_id,
|
|
484
|
+
AgentPassage._last_updated_by_id,
|
|
485
|
+
AgentPassage.organization_id,
|
|
486
|
+
literal(None).label('file_id'),
|
|
487
|
+
literal(None).label('source_id'),
|
|
488
|
+
AgentPassage.agent_id
|
|
489
|
+
)
|
|
490
|
+
.where(AgentPassage.agent_id == agent_id)
|
|
491
|
+
.where(AgentPassage.organization_id == actor.organization_id)
|
|
492
|
+
)
|
|
493
|
+
|
|
494
|
+
# Combine queries
|
|
495
|
+
if source_passages is not None and agent_passages is not None:
|
|
496
|
+
combined_query = union_all(source_passages, agent_passages).cte('combined_passages')
|
|
497
|
+
elif agent_passages is not None:
|
|
498
|
+
combined_query = agent_passages.cte('combined_passages')
|
|
499
|
+
elif source_passages is not None:
|
|
500
|
+
combined_query = source_passages.cte('combined_passages')
|
|
501
|
+
else:
|
|
502
|
+
raise ValueError("No passages found")
|
|
503
|
+
|
|
504
|
+
# Build main query from combined CTE
|
|
505
|
+
main_query = select(combined_query)
|
|
506
|
+
|
|
507
|
+
# Apply filters
|
|
508
|
+
if start_date:
|
|
509
|
+
main_query = main_query.where(combined_query.c.created_at >= start_date)
|
|
510
|
+
if end_date:
|
|
511
|
+
main_query = main_query.where(combined_query.c.created_at <= end_date)
|
|
512
|
+
if source_id:
|
|
513
|
+
main_query = main_query.where(combined_query.c.source_id == source_id)
|
|
514
|
+
if file_id:
|
|
515
|
+
main_query = main_query.where(combined_query.c.file_id == file_id)
|
|
516
|
+
|
|
517
|
+
# Vector search
|
|
518
|
+
if embedded_text:
|
|
519
|
+
if settings.letta_pg_uri_no_default:
|
|
520
|
+
# PostgreSQL with pgvector
|
|
521
|
+
main_query = main_query.order_by(
|
|
522
|
+
combined_query.c.embedding.cosine_distance(embedded_text).asc()
|
|
523
|
+
)
|
|
524
|
+
else:
|
|
525
|
+
# SQLite with custom vector type
|
|
526
|
+
query_embedding_binary = adapt_array(embedded_text)
|
|
527
|
+
if ascending:
|
|
528
|
+
main_query = main_query.order_by(
|
|
529
|
+
func.cosine_distance(combined_query.c.embedding, query_embedding_binary).asc(),
|
|
530
|
+
combined_query.c.created_at.asc(),
|
|
531
|
+
combined_query.c.id.asc()
|
|
532
|
+
)
|
|
533
|
+
else:
|
|
534
|
+
main_query = main_query.order_by(
|
|
535
|
+
func.cosine_distance(combined_query.c.embedding, query_embedding_binary).asc(),
|
|
536
|
+
combined_query.c.created_at.desc(),
|
|
537
|
+
combined_query.c.id.asc()
|
|
538
|
+
)
|
|
539
|
+
else:
|
|
540
|
+
if query_text:
|
|
541
|
+
main_query = main_query.where(func.lower(combined_query.c.text).contains(func.lower(query_text)))
|
|
542
|
+
|
|
543
|
+
# Handle cursor-based pagination
|
|
544
|
+
if cursor:
|
|
545
|
+
cursor_query = select(combined_query.c.created_at).where(
|
|
546
|
+
combined_query.c.id == cursor
|
|
547
|
+
).scalar_subquery()
|
|
548
|
+
|
|
549
|
+
if ascending:
|
|
550
|
+
main_query = main_query.where(
|
|
551
|
+
combined_query.c.created_at > cursor_query
|
|
552
|
+
)
|
|
553
|
+
else:
|
|
554
|
+
main_query = main_query.where(
|
|
555
|
+
combined_query.c.created_at < cursor_query
|
|
556
|
+
)
|
|
557
|
+
|
|
558
|
+
# Add ordering if not already ordered by similarity
|
|
559
|
+
if not embed_query:
|
|
560
|
+
if ascending:
|
|
561
|
+
main_query = main_query.order_by(
|
|
562
|
+
combined_query.c.created_at.asc(),
|
|
563
|
+
combined_query.c.id.asc(),
|
|
564
|
+
)
|
|
565
|
+
else:
|
|
566
|
+
main_query = main_query.order_by(
|
|
567
|
+
combined_query.c.created_at.desc(),
|
|
568
|
+
combined_query.c.id.asc(),
|
|
569
|
+
)
|
|
570
|
+
|
|
571
|
+
return main_query
|
|
572
|
+
|
|
573
|
+
@enforce_types
|
|
574
|
+
def list_passages(
|
|
575
|
+
self,
|
|
576
|
+
actor: PydanticUser,
|
|
577
|
+
agent_id: Optional[str] = None,
|
|
578
|
+
file_id: Optional[str] = None,
|
|
579
|
+
limit: Optional[int] = 50,
|
|
580
|
+
query_text: Optional[str] = None,
|
|
581
|
+
start_date: Optional[datetime] = None,
|
|
582
|
+
end_date: Optional[datetime] = None,
|
|
583
|
+
cursor: Optional[str] = None,
|
|
584
|
+
source_id: Optional[str] = None,
|
|
585
|
+
embed_query: bool = False,
|
|
586
|
+
ascending: bool = True,
|
|
587
|
+
embedding_config: Optional[EmbeddingConfig] = None,
|
|
588
|
+
agent_only: bool = False
|
|
589
|
+
) -> List[PydanticPassage]:
|
|
590
|
+
"""Lists all passages attached to an agent."""
|
|
591
|
+
with self.session_maker() as session:
|
|
592
|
+
main_query = self._build_passage_query(
|
|
593
|
+
actor=actor,
|
|
594
|
+
agent_id=agent_id,
|
|
595
|
+
file_id=file_id,
|
|
596
|
+
query_text=query_text,
|
|
597
|
+
start_date=start_date,
|
|
598
|
+
end_date=end_date,
|
|
599
|
+
cursor=cursor,
|
|
600
|
+
source_id=source_id,
|
|
601
|
+
embed_query=embed_query,
|
|
602
|
+
ascending=ascending,
|
|
603
|
+
embedding_config=embedding_config,
|
|
604
|
+
agent_only=agent_only,
|
|
605
|
+
)
|
|
606
|
+
|
|
607
|
+
# Add limit
|
|
608
|
+
if limit:
|
|
609
|
+
main_query = main_query.limit(limit)
|
|
610
|
+
|
|
611
|
+
# Execute query
|
|
612
|
+
results = list(session.execute(main_query))
|
|
613
|
+
|
|
614
|
+
passages = []
|
|
615
|
+
for row in results:
|
|
616
|
+
data = dict(row._mapping)
|
|
617
|
+
if data['agent_id'] is not None:
|
|
618
|
+
# This is an AgentPassage - remove source fields
|
|
619
|
+
data.pop('source_id', None)
|
|
620
|
+
data.pop('file_id', None)
|
|
621
|
+
passage = AgentPassage(**data)
|
|
622
|
+
else:
|
|
623
|
+
# This is a SourcePassage - remove agent field
|
|
624
|
+
data.pop('agent_id', None)
|
|
625
|
+
passage = SourcePassage(**data)
|
|
626
|
+
passages.append(passage)
|
|
627
|
+
|
|
628
|
+
return [p.to_pydantic() for p in passages]
|
|
629
|
+
|
|
630
|
+
|
|
631
|
+
@enforce_types
|
|
632
|
+
def passage_size(
|
|
633
|
+
self,
|
|
634
|
+
actor: PydanticUser,
|
|
635
|
+
agent_id: Optional[str] = None,
|
|
636
|
+
file_id: Optional[str] = None,
|
|
637
|
+
query_text: Optional[str] = None,
|
|
638
|
+
start_date: Optional[datetime] = None,
|
|
639
|
+
end_date: Optional[datetime] = None,
|
|
640
|
+
cursor: Optional[str] = None,
|
|
641
|
+
source_id: Optional[str] = None,
|
|
642
|
+
embed_query: bool = False,
|
|
643
|
+
ascending: bool = True,
|
|
644
|
+
embedding_config: Optional[EmbeddingConfig] = None,
|
|
645
|
+
agent_only: bool = False
|
|
646
|
+
) -> int:
|
|
647
|
+
"""Returns the count of passages matching the given criteria."""
|
|
648
|
+
with self.session_maker() as session:
|
|
649
|
+
main_query = self._build_passage_query(
|
|
650
|
+
actor=actor,
|
|
651
|
+
agent_id=agent_id,
|
|
652
|
+
file_id=file_id,
|
|
653
|
+
query_text=query_text,
|
|
654
|
+
start_date=start_date,
|
|
655
|
+
end_date=end_date,
|
|
656
|
+
cursor=cursor,
|
|
657
|
+
source_id=source_id,
|
|
658
|
+
embed_query=embed_query,
|
|
659
|
+
ascending=ascending,
|
|
660
|
+
embedding_config=embedding_config,
|
|
661
|
+
agent_only=agent_only,
|
|
662
|
+
)
|
|
663
|
+
|
|
664
|
+
# Convert to count query
|
|
665
|
+
count_query = select(func.count()).select_from(main_query.subquery())
|
|
666
|
+
return session.scalar(count_query) or 0
|
|
667
|
+
|
|
668
|
+
# ======================================================================================================================
|
|
669
|
+
# Tool Management
|
|
670
|
+
# ======================================================================================================================
|
|
671
|
+
@enforce_types
|
|
672
|
+
def attach_tool(self, agent_id: str, tool_id: str, actor: PydanticUser) -> PydanticAgentState:
|
|
673
|
+
"""
|
|
674
|
+
Attaches a tool to an agent.
|
|
675
|
+
|
|
676
|
+
Args:
|
|
677
|
+
agent_id: ID of the agent to attach the tool to.
|
|
678
|
+
tool_id: ID of the tool to attach.
|
|
679
|
+
actor: User performing the action.
|
|
680
|
+
|
|
681
|
+
Raises:
|
|
682
|
+
NoResultFound: If the agent or tool is not found.
|
|
683
|
+
|
|
684
|
+
Returns:
|
|
685
|
+
PydanticAgentState: The updated agent state.
|
|
686
|
+
"""
|
|
687
|
+
with self.session_maker() as session:
|
|
688
|
+
# Verify the agent exists and user has permission to access it
|
|
689
|
+
agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
|
|
690
|
+
|
|
691
|
+
# Use the _process_relationship helper to attach the tool
|
|
692
|
+
_process_relationship(
|
|
693
|
+
session=session,
|
|
694
|
+
agent=agent,
|
|
695
|
+
relationship_name="tools",
|
|
696
|
+
model_class=ToolModel,
|
|
697
|
+
item_ids=[tool_id],
|
|
698
|
+
allow_partial=False, # Ensure the tool exists
|
|
699
|
+
replace=False, # Extend the existing tools
|
|
700
|
+
)
|
|
701
|
+
|
|
702
|
+
# Commit and refresh the agent
|
|
703
|
+
agent.update(session, actor=actor)
|
|
704
|
+
return agent.to_pydantic()
|
|
705
|
+
|
|
706
|
+
@enforce_types
|
|
707
|
+
def detach_tool(self, agent_id: str, tool_id: str, actor: PydanticUser) -> PydanticAgentState:
|
|
708
|
+
"""
|
|
709
|
+
Detaches a tool from an agent.
|
|
710
|
+
|
|
711
|
+
Args:
|
|
712
|
+
agent_id: ID of the agent to detach the tool from.
|
|
713
|
+
tool_id: ID of the tool to detach.
|
|
714
|
+
actor: User performing the action.
|
|
715
|
+
|
|
716
|
+
Raises:
|
|
717
|
+
NoResultFound: If the agent or tool is not found.
|
|
718
|
+
|
|
719
|
+
Returns:
|
|
720
|
+
PydanticAgentState: The updated agent state.
|
|
721
|
+
"""
|
|
722
|
+
with self.session_maker() as session:
|
|
723
|
+
# Verify the agent exists and user has permission to access it
|
|
724
|
+
agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
|
|
725
|
+
|
|
726
|
+
# Filter out the tool to be detached
|
|
727
|
+
remaining_tools = [tool for tool in agent.tools if tool.id != tool_id]
|
|
728
|
+
|
|
729
|
+
if len(remaining_tools) == len(agent.tools): # Tool ID was not in the relationship
|
|
730
|
+
logger.warning(f"Attempted to remove unattached tool id={tool_id} from agent id={agent_id} by actor={actor}")
|
|
731
|
+
|
|
732
|
+
# Update the tools relationship
|
|
733
|
+
agent.tools = remaining_tools
|
|
734
|
+
|
|
735
|
+
# Commit and refresh the agent
|
|
736
|
+
agent.update(session, actor=actor)
|
|
737
|
+
return agent.to_pydantic()
|