letta-nightly 0.6.4.dev20241216104246__py3-none-any.whl → 0.6.4.dev20241217104233__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of letta-nightly might be problematic. Click here for more details.

letta/server/server.py CHANGED
@@ -19,7 +19,6 @@ from letta.agent import Agent, save_agent
19
19
  from letta.chat_only_agent import ChatOnlyAgent
20
20
  from letta.credentials import LettaCredentials
21
21
  from letta.data_sources.connectors import DataConnector, load_data
22
- from letta.errors import LettaAgentNotFoundError
23
22
 
24
23
  # TODO use custom interface
25
24
  from letta.interface import AgentInterface # abstract
@@ -399,9 +398,6 @@ class SyncServer(Server):
399
398
  with agent_lock:
400
399
  agent_state = self.agent_manager.get_agent_by_id(agent_id=agent_id, actor=actor)
401
400
 
402
- if agent_state is None:
403
- raise LettaAgentNotFoundError(f"Agent (agent_id={agent_id}) does not exist")
404
-
405
401
  interface = interface or self.default_interface_factory()
406
402
  if agent_state.agent_type == AgentType.memgpt_agent:
407
403
  agent = Agent(agent_state=agent_state, interface=interface, user=actor)
@@ -824,13 +820,13 @@ class SyncServer(Server):
824
820
  actor: User,
825
821
  ) -> AgentState:
826
822
  """Update the agents core memory block, return the new state"""
823
+ # Update agent state in the db first
824
+ self.agent_manager.update_agent(agent_id=agent_id, agent_update=request, actor=actor)
825
+
827
826
  # Get the agent object (loaded in memory)
828
827
  letta_agent = self.load_agent(agent_id=agent_id, actor=actor)
829
828
 
830
- # Update tags
831
- if request.tags is not None: # Allow for empty list
832
- letta_agent.agent_state.tags = request.tags
833
-
829
+ # TODO: Everything below needs to get removed, no updating anything in memory
834
830
  # update the system prompt
835
831
  if request.system:
836
832
  letta_agent.update_system_prompt(request.system)
@@ -844,42 +840,10 @@ class SyncServer(Server):
844
840
 
845
841
  # tools
846
842
  if request.tool_ids:
847
- # Replace tools and also re-link
848
-
849
- # (1) get tools + make sure they exist
850
- # Current and target tools as sets of tool names
851
- current_tools = letta_agent.agent_state.tools
852
- current_tool_ids = set([t.id for t in current_tools])
853
- target_tool_ids = set(request.tool_ids)
854
-
855
- # Calculate tools to add and remove
856
- tool_ids_to_add = target_tool_ids - current_tool_ids
857
- tools_ids_to_remove = current_tool_ids - target_tool_ids
858
-
859
- # update agent tool list
860
- for tool_id in tools_ids_to_remove:
861
- self.remove_tool_from_agent(agent_id=agent_id, tool_id=tool_id, user_id=actor.id)
862
- for tool_id in tool_ids_to_add:
863
- self.add_tool_to_agent(agent_id=agent_id, tool_id=tool_id, user_id=actor.id)
864
-
865
- # reload agent
866
- letta_agent = self.load_agent(agent_id=agent_id, actor=actor)
867
-
868
- # configs
869
- if request.llm_config:
870
- letta_agent.agent_state.llm_config = request.llm_config
871
- if request.embedding_config:
872
- letta_agent.agent_state.embedding_config = request.embedding_config
873
-
874
- # other minor updates
875
- if request.name:
876
- letta_agent.agent_state.name = request.name
877
- if request.metadata_:
878
- letta_agent.agent_state.metadata_ = request.metadata_
879
-
880
- # save the agent
881
- save_agent(letta_agent)
882
- # TODO: probably reload the agent somehow?
843
+ letta_agent.link_tools(letta_agent.agent_state.tools)
844
+
845
+ letta_agent.update_state()
846
+
883
847
  return letta_agent.agent_state
884
848
 
885
849
  def get_tools_from_agent(self, agent_id: str, user_id: Optional[str]) -> List[Tool]:
@@ -901,32 +865,14 @@ class SyncServer(Server):
901
865
  # TODO: Thread actor directly through this function, since the top level caller most likely already retrieved the user
902
866
  actor = self.user_manager.get_user_or_default(user_id=user_id)
903
867
 
868
+ agent_state = self.agent_manager.attach_tool(agent_id=agent_id, tool_id=tool_id, actor=actor)
869
+
870
+ # TODO: This is very redundant, and should probably be simplified
904
871
  # Get the agent object (loaded in memory)
905
872
  letta_agent = self.load_agent(agent_id=agent_id, actor=actor)
873
+ letta_agent.link_tools(agent_state.tools)
906
874
 
907
- # Get all the tool objects from the request
908
- tool_objs = []
909
- tool_obj = self.tool_manager.get_tool_by_id(tool_id=tool_id, actor=actor)
910
- assert tool_obj, f"Tool with id={tool_id} does not exist"
911
- tool_objs.append(tool_obj)
912
-
913
- for tool in letta_agent.agent_state.tools:
914
- tool_obj = self.tool_manager.get_tool_by_id(tool_id=tool.id, actor=actor)
915
- assert tool_obj, f"Tool with id={tool.id} does not exist"
916
-
917
- # If it's not the already added tool
918
- if tool_obj.id != tool_id:
919
- tool_objs.append(tool_obj)
920
-
921
- # replace the list of tool names ("ids") inside the agent state
922
- letta_agent.agent_state.tools = tool_objs
923
-
924
- # then attempt to link the tools modules
925
- letta_agent.link_tools(tool_objs)
926
-
927
- # save the agent
928
- save_agent(letta_agent)
929
- return letta_agent.agent_state
875
+ return agent_state
930
876
 
931
877
  def remove_tool_from_agent(
932
878
  self,
@@ -937,29 +883,13 @@ class SyncServer(Server):
937
883
  """Remove tools from an existing agent"""
938
884
  # TODO: Thread actor directly through this function, since the top level caller most likely already retrieved the user
939
885
  actor = self.user_manager.get_user_or_default(user_id=user_id)
886
+ agent_state = self.agent_manager.detach_tool(agent_id=agent_id, tool_id=tool_id, actor=actor)
940
887
 
941
888
  # Get the agent object (loaded in memory)
942
889
  letta_agent = self.load_agent(agent_id=agent_id, actor=actor)
890
+ letta_agent.link_tools(agent_state.tools)
943
891
 
944
- # Get all the tool_objs
945
- tool_objs = []
946
- for tool in letta_agent.agent_state.tools:
947
- tool_obj = self.tool_manager.get_tool_by_id(tool_id=tool.id, actor=actor)
948
- assert tool_obj, f"Tool with id={tool.id} does not exist"
949
-
950
- # If it's not the tool we want to remove
951
- if tool_obj.id != tool_id:
952
- tool_objs.append(tool_obj)
953
-
954
- # replace the list of tool names ("ids") inside the agent state
955
- letta_agent.agent_state.tools = tool_objs
956
-
957
- # then attempt to link the tools modules
958
- letta_agent.link_tools(tool_objs)
959
-
960
- # save the agent
961
- save_agent(letta_agent)
962
- return letta_agent.agent_state
892
+ return agent_state
963
893
 
964
894
  # convert name->id
965
895
 
@@ -970,7 +900,7 @@ class SyncServer(Server):
970
900
 
971
901
  def get_archival_memory_summary(self, agent_id: str, actor: User) -> ArchivalMemorySummary:
972
902
  agent = self.load_agent(agent_id=agent_id, actor=actor)
973
- return ArchivalMemorySummary(size=agent.passage_manager.size(actor=self.default_user))
903
+ return ArchivalMemorySummary(size=self.agent_manager.passage_size(actor=actor, agent_id=agent_id))
974
904
 
975
905
  def get_recall_memory_summary(self, agent_id: str, actor: User) -> RecallMemorySummary:
976
906
  agent = self.load_agent(agent_id=agent_id, actor=actor)
@@ -987,18 +917,9 @@ class SyncServer(Server):
987
917
  # TODO: Thread actor directly through this function, since the top level caller most likely already retrieved the user
988
918
  actor = self.user_manager.get_user_or_default(user_id=user_id)
989
919
 
990
- # Get the agent object (loaded in memory)
991
- letta_agent = self.load_agent(agent_id=agent_id, actor=actor)
992
-
993
- # iterate over records
994
- records = letta_agent.passage_manager.list_passages(
995
- actor=actor,
996
- agent_id=agent_id,
997
- cursor=cursor,
998
- limit=limit,
999
- )
920
+ passages = self.agent_manager.list_passages(agent_id=agent_id, actor=actor)
1000
921
 
1001
- return records
922
+ return passages
1002
923
 
1003
924
  def get_agent_archival_cursor(
1004
925
  self,
@@ -1012,15 +933,13 @@ class SyncServer(Server):
1012
933
  # TODO: Thread actor directly through this function, since the top level caller most likely already retrieved the user
1013
934
  actor = self.user_manager.get_user_or_default(user_id=user_id)
1014
935
 
1015
- # Get the agent object (loaded in memory)
1016
- letta_agent = self.load_agent(agent_id=agent_id, actor=actor)
1017
-
1018
936
  # iterate over records
1019
- records = letta_agent.passage_manager.list_passages(
1020
- actor=self.default_user,
937
+ records = self.agent_manager.list_passages(
938
+ actor=actor,
1021
939
  agent_id=agent_id,
1022
940
  cursor=cursor,
1023
941
  limit=limit,
942
+ ascending=not reverse,
1024
943
  )
1025
944
  return records
1026
945
 
@@ -1105,7 +1024,7 @@ class SyncServer(Server):
1105
1024
  config_copy[k] = server_utils.shorten_key_middle(v, chars_each_side=5)
1106
1025
  return config_copy
1107
1026
 
1108
- # TODO: do we need a seperate server config?
1027
+ # TODO: do we need a separate server config?
1109
1028
  base_config = vars(self.config)
1110
1029
  clean_base_config = clean_keys(base_config)
1111
1030
 
@@ -1136,7 +1055,8 @@ class SyncServer(Server):
1136
1055
  self.source_manager.delete_source(source_id=source_id, actor=actor)
1137
1056
 
1138
1057
  # delete data from passage store
1139
- self.passage_manager.delete_passages(actor=actor, limit=None, source_id=source_id)
1058
+ passages_to_be_deleted = self.agent_manager.list_passages(actor=actor, source_id=source_id, limit=None)
1059
+ self.passage_manager.delete_passages(actor=actor, passages=passages_to_be_deleted)
1140
1060
 
1141
1061
  # TODO: delete data from agent passage stores (?)
1142
1062
 
@@ -1167,9 +1087,11 @@ class SyncServer(Server):
1167
1087
  for agent_state in agent_states:
1168
1088
  agent_id = agent_state.id
1169
1089
  agent = self.load_agent(agent_id=agent_id, actor=actor)
1170
- curr_passage_size = self.passage_manager.size(actor=actor, agent_id=agent_id, source_id=source_id)
1090
+
1091
+ # Attach source to agent
1092
+ curr_passage_size = self.agent_manager.passage_size(actor=actor, agent_id=agent_id)
1171
1093
  agent.attach_source(user=actor, source_id=source_id, source_manager=self.source_manager, agent_manager=self.agent_manager)
1172
- new_passage_size = self.passage_manager.size(actor=actor, agent_id=agent_id, source_id=source_id)
1094
+ new_passage_size = self.agent_manager.passage_size(actor=actor, agent_id=agent_id)
1173
1095
  assert new_passage_size >= curr_passage_size # in case empty files are added
1174
1096
 
1175
1097
  return job
@@ -1233,14 +1155,9 @@ class SyncServer(Server):
1233
1155
  source = self.source_manager.get_source_by_id(source_id=source_id, actor=actor)
1234
1156
  elif source_name:
1235
1157
  source = self.source_manager.get_source_by_name(source_name=source_name, actor=actor)
1158
+ source_id = source.id
1236
1159
  else:
1237
1160
  raise ValueError(f"Need to provide at least source_id or source_name to find the source.")
1238
- source_id = source.id
1239
-
1240
- # TODO: This should be done with the ORM?
1241
- # delete all Passage objects with source_id==source_id from agent's archival memory
1242
- agent = self.load_agent(agent_id=agent_id, actor=actor)
1243
- agent.passage_manager.delete_passages(actor=actor, limit=100, source_id=source_id)
1244
1161
 
1245
1162
  # delete agent-source mapping
1246
1163
  self.agent_manager.detach_source(agent_id=agent_id, source_id=source_id, actor=actor)
@@ -1262,7 +1179,7 @@ class SyncServer(Server):
1262
1179
  for source in sources:
1263
1180
 
1264
1181
  # count number of passages
1265
- num_passages = self.passage_manager.size(actor=actor, source_id=source.id)
1182
+ num_passages = self.agent_manager.passage_size(actor=actor, source_id=source.id)
1266
1183
 
1267
1184
  # TODO: add when files table implemented
1268
1185
  ## count number of files
@@ -33,7 +33,7 @@ class WebSocketServer:
33
33
  self.initialize_server()
34
34
  # Can play with ping_interval and ping_timeout
35
35
  # See: https://websockets.readthedocs.io/en/stable/topics/timeouts.html
36
- # and https://github.com/cpacker/Letta/issues/471
36
+ # and https://github.com/letta-ai/letta/issues/471
37
37
  async with websockets.serve(self.handle_client, self.host, self.port):
38
38
  await asyncio.Future() # Run forever
39
39
 
@@ -1,16 +1,26 @@
1
1
  from typing import Dict, List, Optional
2
+ from datetime import datetime
3
+ import numpy as np
2
4
 
3
- from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS
5
+ from sqlalchemy import select, union_all, literal, func, Select
6
+
7
+ from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS, MAX_EMBEDDING_DIM
8
+ from letta.embeddings import embedding_model
9
+ from letta.log import get_logger
4
10
  from letta.orm import Agent as AgentModel
5
11
  from letta.orm import Block as BlockModel
6
12
  from letta.orm import Source as SourceModel
7
13
  from letta.orm import Tool as ToolModel
14
+ from letta.orm import AgentPassage, SourcePassage
15
+ from letta.orm import SourcesAgents
8
16
  from letta.orm.errors import NoResultFound
17
+ from letta.orm.sqlite_functions import adapt_array
9
18
  from letta.schemas.agent import AgentState as PydanticAgentState
10
19
  from letta.schemas.agent import AgentType, CreateAgent, UpdateAgent
11
20
  from letta.schemas.block import Block as PydanticBlock
12
21
  from letta.schemas.embedding_config import EmbeddingConfig
13
22
  from letta.schemas.llm_config import LLMConfig
23
+ from letta.schemas.passage import Passage as PydanticPassage
14
24
  from letta.schemas.source import Source as PydanticSource
15
25
  from letta.schemas.tool_rule import ToolRule as PydanticToolRule
16
26
  from letta.schemas.user import User as PydanticUser
@@ -20,11 +30,13 @@ from letta.services.helpers.agent_manager_helper import (
20
30
  _process_tags,
21
31
  derive_system_message,
22
32
  )
23
- from letta.services.passage_manager import PassageManager
24
33
  from letta.services.source_manager import SourceManager
25
34
  from letta.services.tool_manager import ToolManager
35
+ from letta.settings import settings
26
36
  from letta.utils import enforce_types
27
37
 
38
+ logger = get_logger(__name__)
39
+
28
40
 
29
41
  # Agent Manager Class
30
42
  class AgentManager:
@@ -226,13 +238,6 @@ class AgentManager:
226
238
  with self.session_maker() as session:
227
239
  # Retrieve the agent
228
240
  agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
229
-
230
- # TODO: @mindy delete this piece when we have a proper passages/sources implementation
231
- # TODO: This is done very hacky on purpose
232
- # TODO: 1000 limit is also wack
233
- passage_manager = PassageManager()
234
- passage_manager.delete_passages(actor=actor, agent_id=agent_id, limit=1000)
235
-
236
241
  agent_state = agent.to_pydantic()
237
242
  agent.hard_delete(session)
238
243
  return agent_state
@@ -403,3 +408,330 @@ class AgentManager:
403
408
 
404
409
  agent.update(session, actor=actor)
405
410
  return agent.to_pydantic()
411
+
412
+ # ======================================================================================================================
413
+ # Passage Management
414
+ # ======================================================================================================================
415
+ def _build_passage_query(
416
+ self,
417
+ actor: PydanticUser,
418
+ agent_id: Optional[str] = None,
419
+ file_id: Optional[str] = None,
420
+ query_text: Optional[str] = None,
421
+ start_date: Optional[datetime] = None,
422
+ end_date: Optional[datetime] = None,
423
+ cursor: Optional[str] = None,
424
+ source_id: Optional[str] = None,
425
+ embed_query: bool = False,
426
+ ascending: bool = True,
427
+ embedding_config: Optional[EmbeddingConfig] = None,
428
+ agent_only: bool = False,
429
+ ) -> Select:
430
+ """Helper function to build the base passage query with all filters applied.
431
+
432
+ Returns the query before any limit or count operations are applied.
433
+ """
434
+ embedded_text = None
435
+ if embed_query:
436
+ assert embedding_config is not None, "embedding_config must be specified for vector search"
437
+ assert query_text is not None, "query_text must be specified for vector search"
438
+ embedded_text = embedding_model(embedding_config).get_text_embedding(query_text)
439
+ embedded_text = np.array(embedded_text)
440
+ embedded_text = np.pad(embedded_text, (0, MAX_EMBEDDING_DIM - embedded_text.shape[0]), mode="constant").tolist()
441
+
442
+ with self.session_maker() as session:
443
+ # Start with base query for source passages
444
+ source_passages = None
445
+ if not agent_only: # Include source passages
446
+ if agent_id is not None:
447
+ source_passages = (
448
+ select(
449
+ SourcePassage,
450
+ literal(None).label('agent_id')
451
+ )
452
+ .join(SourcesAgents, SourcesAgents.source_id == SourcePassage.source_id)
453
+ .where(SourcesAgents.agent_id == agent_id)
454
+ .where(SourcePassage.organization_id == actor.organization_id)
455
+ )
456
+ else:
457
+ source_passages = (
458
+ select(
459
+ SourcePassage,
460
+ literal(None).label('agent_id')
461
+ )
462
+ .where(SourcePassage.organization_id == actor.organization_id)
463
+ )
464
+
465
+ if source_id:
466
+ source_passages = source_passages.where(SourcePassage.source_id == source_id)
467
+ if file_id:
468
+ source_passages = source_passages.where(SourcePassage.file_id == file_id)
469
+
470
+ # Add agent passages query
471
+ agent_passages = None
472
+ if agent_id is not None:
473
+ agent_passages = (
474
+ select(
475
+ AgentPassage.id,
476
+ AgentPassage.text,
477
+ AgentPassage.embedding_config,
478
+ AgentPassage.metadata_,
479
+ AgentPassage.embedding,
480
+ AgentPassage.created_at,
481
+ AgentPassage.updated_at,
482
+ AgentPassage.is_deleted,
483
+ AgentPassage._created_by_id,
484
+ AgentPassage._last_updated_by_id,
485
+ AgentPassage.organization_id,
486
+ literal(None).label('file_id'),
487
+ literal(None).label('source_id'),
488
+ AgentPassage.agent_id
489
+ )
490
+ .where(AgentPassage.agent_id == agent_id)
491
+ .where(AgentPassage.organization_id == actor.organization_id)
492
+ )
493
+
494
+ # Combine queries
495
+ if source_passages is not None and agent_passages is not None:
496
+ combined_query = union_all(source_passages, agent_passages).cte('combined_passages')
497
+ elif agent_passages is not None:
498
+ combined_query = agent_passages.cte('combined_passages')
499
+ elif source_passages is not None:
500
+ combined_query = source_passages.cte('combined_passages')
501
+ else:
502
+ raise ValueError("No passages found")
503
+
504
+ # Build main query from combined CTE
505
+ main_query = select(combined_query)
506
+
507
+ # Apply filters
508
+ if start_date:
509
+ main_query = main_query.where(combined_query.c.created_at >= start_date)
510
+ if end_date:
511
+ main_query = main_query.where(combined_query.c.created_at <= end_date)
512
+ if source_id:
513
+ main_query = main_query.where(combined_query.c.source_id == source_id)
514
+ if file_id:
515
+ main_query = main_query.where(combined_query.c.file_id == file_id)
516
+
517
+ # Vector search
518
+ if embedded_text:
519
+ if settings.letta_pg_uri_no_default:
520
+ # PostgreSQL with pgvector
521
+ main_query = main_query.order_by(
522
+ combined_query.c.embedding.cosine_distance(embedded_text).asc()
523
+ )
524
+ else:
525
+ # SQLite with custom vector type
526
+ query_embedding_binary = adapt_array(embedded_text)
527
+ if ascending:
528
+ main_query = main_query.order_by(
529
+ func.cosine_distance(combined_query.c.embedding, query_embedding_binary).asc(),
530
+ combined_query.c.created_at.asc(),
531
+ combined_query.c.id.asc()
532
+ )
533
+ else:
534
+ main_query = main_query.order_by(
535
+ func.cosine_distance(combined_query.c.embedding, query_embedding_binary).asc(),
536
+ combined_query.c.created_at.desc(),
537
+ combined_query.c.id.asc()
538
+ )
539
+ else:
540
+ if query_text:
541
+ main_query = main_query.where(func.lower(combined_query.c.text).contains(func.lower(query_text)))
542
+
543
+ # Handle cursor-based pagination
544
+ if cursor:
545
+ cursor_query = select(combined_query.c.created_at).where(
546
+ combined_query.c.id == cursor
547
+ ).scalar_subquery()
548
+
549
+ if ascending:
550
+ main_query = main_query.where(
551
+ combined_query.c.created_at > cursor_query
552
+ )
553
+ else:
554
+ main_query = main_query.where(
555
+ combined_query.c.created_at < cursor_query
556
+ )
557
+
558
+ # Add ordering if not already ordered by similarity
559
+ if not embed_query:
560
+ if ascending:
561
+ main_query = main_query.order_by(
562
+ combined_query.c.created_at.asc(),
563
+ combined_query.c.id.asc(),
564
+ )
565
+ else:
566
+ main_query = main_query.order_by(
567
+ combined_query.c.created_at.desc(),
568
+ combined_query.c.id.asc(),
569
+ )
570
+
571
+ return main_query
572
+
573
+ @enforce_types
574
+ def list_passages(
575
+ self,
576
+ actor: PydanticUser,
577
+ agent_id: Optional[str] = None,
578
+ file_id: Optional[str] = None,
579
+ limit: Optional[int] = 50,
580
+ query_text: Optional[str] = None,
581
+ start_date: Optional[datetime] = None,
582
+ end_date: Optional[datetime] = None,
583
+ cursor: Optional[str] = None,
584
+ source_id: Optional[str] = None,
585
+ embed_query: bool = False,
586
+ ascending: bool = True,
587
+ embedding_config: Optional[EmbeddingConfig] = None,
588
+ agent_only: bool = False
589
+ ) -> List[PydanticPassage]:
590
+ """Lists all passages attached to an agent."""
591
+ with self.session_maker() as session:
592
+ main_query = self._build_passage_query(
593
+ actor=actor,
594
+ agent_id=agent_id,
595
+ file_id=file_id,
596
+ query_text=query_text,
597
+ start_date=start_date,
598
+ end_date=end_date,
599
+ cursor=cursor,
600
+ source_id=source_id,
601
+ embed_query=embed_query,
602
+ ascending=ascending,
603
+ embedding_config=embedding_config,
604
+ agent_only=agent_only,
605
+ )
606
+
607
+ # Add limit
608
+ if limit:
609
+ main_query = main_query.limit(limit)
610
+
611
+ # Execute query
612
+ results = list(session.execute(main_query))
613
+
614
+ passages = []
615
+ for row in results:
616
+ data = dict(row._mapping)
617
+ if data['agent_id'] is not None:
618
+ # This is an AgentPassage - remove source fields
619
+ data.pop('source_id', None)
620
+ data.pop('file_id', None)
621
+ passage = AgentPassage(**data)
622
+ else:
623
+ # This is a SourcePassage - remove agent field
624
+ data.pop('agent_id', None)
625
+ passage = SourcePassage(**data)
626
+ passages.append(passage)
627
+
628
+ return [p.to_pydantic() for p in passages]
629
+
630
+
631
+ @enforce_types
632
+ def passage_size(
633
+ self,
634
+ actor: PydanticUser,
635
+ agent_id: Optional[str] = None,
636
+ file_id: Optional[str] = None,
637
+ query_text: Optional[str] = None,
638
+ start_date: Optional[datetime] = None,
639
+ end_date: Optional[datetime] = None,
640
+ cursor: Optional[str] = None,
641
+ source_id: Optional[str] = None,
642
+ embed_query: bool = False,
643
+ ascending: bool = True,
644
+ embedding_config: Optional[EmbeddingConfig] = None,
645
+ agent_only: bool = False
646
+ ) -> int:
647
+ """Returns the count of passages matching the given criteria."""
648
+ with self.session_maker() as session:
649
+ main_query = self._build_passage_query(
650
+ actor=actor,
651
+ agent_id=agent_id,
652
+ file_id=file_id,
653
+ query_text=query_text,
654
+ start_date=start_date,
655
+ end_date=end_date,
656
+ cursor=cursor,
657
+ source_id=source_id,
658
+ embed_query=embed_query,
659
+ ascending=ascending,
660
+ embedding_config=embedding_config,
661
+ agent_only=agent_only,
662
+ )
663
+
664
+ # Convert to count query
665
+ count_query = select(func.count()).select_from(main_query.subquery())
666
+ return session.scalar(count_query) or 0
667
+
668
+ # ======================================================================================================================
669
+ # Tool Management
670
+ # ======================================================================================================================
671
+ @enforce_types
672
+ def attach_tool(self, agent_id: str, tool_id: str, actor: PydanticUser) -> PydanticAgentState:
673
+ """
674
+ Attaches a tool to an agent.
675
+
676
+ Args:
677
+ agent_id: ID of the agent to attach the tool to.
678
+ tool_id: ID of the tool to attach.
679
+ actor: User performing the action.
680
+
681
+ Raises:
682
+ NoResultFound: If the agent or tool is not found.
683
+
684
+ Returns:
685
+ PydanticAgentState: The updated agent state.
686
+ """
687
+ with self.session_maker() as session:
688
+ # Verify the agent exists and user has permission to access it
689
+ agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
690
+
691
+ # Use the _process_relationship helper to attach the tool
692
+ _process_relationship(
693
+ session=session,
694
+ agent=agent,
695
+ relationship_name="tools",
696
+ model_class=ToolModel,
697
+ item_ids=[tool_id],
698
+ allow_partial=False, # Ensure the tool exists
699
+ replace=False, # Extend the existing tools
700
+ )
701
+
702
+ # Commit and refresh the agent
703
+ agent.update(session, actor=actor)
704
+ return agent.to_pydantic()
705
+
706
+ @enforce_types
707
+ def detach_tool(self, agent_id: str, tool_id: str, actor: PydanticUser) -> PydanticAgentState:
708
+ """
709
+ Detaches a tool from an agent.
710
+
711
+ Args:
712
+ agent_id: ID of the agent to detach the tool from.
713
+ tool_id: ID of the tool to detach.
714
+ actor: User performing the action.
715
+
716
+ Raises:
717
+ NoResultFound: If the agent or tool is not found.
718
+
719
+ Returns:
720
+ PydanticAgentState: The updated agent state.
721
+ """
722
+ with self.session_maker() as session:
723
+ # Verify the agent exists and user has permission to access it
724
+ agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
725
+
726
+ # Filter out the tool to be detached
727
+ remaining_tools = [tool for tool in agent.tools if tool.id != tool_id]
728
+
729
+ if len(remaining_tools) == len(agent.tools): # Tool ID was not in the relationship
730
+ logger.warning(f"Attempted to remove unattached tool id={tool_id} from agent id={agent_id} by actor={actor}")
731
+
732
+ # Update the tools relationship
733
+ agent.tools = remaining_tools
734
+
735
+ # Commit and refresh the agent
736
+ agent.update(session, actor=actor)
737
+ return agent.to_pydantic()