letta-nightly 0.10.0.dev20250806104523__py3-none-any.whl → 0.11.0.dev20250807104511__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in the public registry.
Files changed (66)
  1. letta/__init__.py +1 -4
  2. letta/agent.py +1 -2
  3. letta/agents/base_agent.py +4 -7
  4. letta/agents/letta_agent.py +59 -51
  5. letta/agents/letta_agent_batch.py +1 -2
  6. letta/agents/voice_agent.py +1 -2
  7. letta/agents/voice_sleeptime_agent.py +1 -3
  8. letta/constants.py +4 -1
  9. letta/embeddings.py +1 -1
  10. letta/functions/function_sets/base.py +0 -1
  11. letta/functions/mcp_client/types.py +4 -0
  12. letta/groups/supervisor_multi_agent.py +1 -1
  13. letta/interfaces/anthropic_streaming_interface.py +16 -24
  14. letta/interfaces/openai_streaming_interface.py +16 -28
  15. letta/llm_api/llm_api_tools.py +3 -3
  16. letta/local_llm/vllm/api.py +3 -0
  17. letta/orm/__init__.py +3 -1
  18. letta/orm/agent.py +8 -0
  19. letta/orm/archive.py +86 -0
  20. letta/orm/archives_agents.py +27 -0
  21. letta/orm/job.py +5 -1
  22. letta/orm/mixins.py +8 -0
  23. letta/orm/organization.py +7 -8
  24. letta/orm/passage.py +12 -10
  25. letta/orm/sqlite_functions.py +2 -2
  26. letta/orm/tool.py +5 -4
  27. letta/schemas/agent.py +4 -2
  28. letta/schemas/agent_file.py +18 -1
  29. letta/schemas/archive.py +44 -0
  30. letta/schemas/embedding_config.py +2 -16
  31. letta/schemas/enums.py +2 -1
  32. letta/schemas/group.py +28 -3
  33. letta/schemas/job.py +4 -0
  34. letta/schemas/llm_config.py +29 -14
  35. letta/schemas/memory.py +9 -3
  36. letta/schemas/npm_requirement.py +12 -0
  37. letta/schemas/passage.py +3 -3
  38. letta/schemas/providers/letta.py +1 -1
  39. letta/schemas/providers/vllm.py +4 -4
  40. letta/schemas/sandbox_config.py +3 -1
  41. letta/schemas/tool.py +10 -38
  42. letta/schemas/tool_rule.py +2 -2
  43. letta/server/db.py +8 -2
  44. letta/server/rest_api/routers/v1/agents.py +9 -8
  45. letta/server/server.py +6 -40
  46. letta/server/startup.sh +3 -0
  47. letta/services/agent_manager.py +92 -31
  48. letta/services/agent_serialization_manager.py +62 -3
  49. letta/services/archive_manager.py +269 -0
  50. letta/services/helpers/agent_manager_helper.py +111 -37
  51. letta/services/job_manager.py +24 -0
  52. letta/services/passage_manager.py +98 -54
  53. letta/services/tool_executor/core_tool_executor.py +0 -1
  54. letta/services/tool_executor/sandbox_tool_executor.py +2 -2
  55. letta/services/tool_executor/tool_execution_manager.py +1 -1
  56. letta/services/tool_manager.py +70 -26
  57. letta/services/tool_sandbox/base.py +2 -2
  58. letta/services/tool_sandbox/local_sandbox.py +5 -1
  59. letta/templates/template_helper.py +8 -0
  60. {letta_nightly-0.10.0.dev20250806104523.dist-info → letta_nightly-0.11.0.dev20250807104511.dist-info}/METADATA +5 -6
  61. {letta_nightly-0.10.0.dev20250806104523.dist-info → letta_nightly-0.11.0.dev20250807104511.dist-info}/RECORD +64 -61
  62. letta/client/client.py +0 -2207
  63. letta/orm/enums.py +0 -21
  64. {letta_nightly-0.10.0.dev20250806104523.dist-info → letta_nightly-0.11.0.dev20250807104511.dist-info}/LICENSE +0 -0
  65. {letta_nightly-0.10.0.dev20250806104523.dist-info → letta_nightly-0.11.0.dev20250807104511.dist-info}/WHEEL +0 -0
  66. {letta_nightly-0.10.0.dev20250806104523.dist-info → letta_nightly-0.11.0.dev20250807104511.dist-info}/entry_points.txt +0 -0
letta/services/passage_manager.py

@@ -9,14 +9,16 @@ from sqlalchemy import select
 from letta.constants import MAX_EMBEDDING_DIM
 from letta.embeddings import embedding_model, parse_and_chunk_text
 from letta.helpers.decorators import async_redis_cache
+from letta.orm import ArchivesAgents
 from letta.orm.errors import NoResultFound
-from letta.orm.passage import AgentPassage, SourcePassage
+from letta.orm.passage import ArchivalPassage, SourcePassage
 from letta.otel.tracing import trace_method
 from letta.schemas.agent import AgentState
 from letta.schemas.file import FileMetadata as PydanticFileMetadata
 from letta.schemas.passage import Passage as PydanticPassage
 from letta.schemas.user import User as PydanticUser
 from letta.server.db import db_registry
+from letta.services.archive_manager import ArchiveManager
 from letta.utils import enforce_types
@@ -42,6 +44,9 @@ async def get_openai_embedding_async(text: str, model: str, endpoint: str) -> li
 class PassageManager:
     """Manager class to handle business logic related to Passages."""

+    def __init__(self):
+        self.archive_manager = ArchiveManager()
+
     # AGENT PASSAGE METHODS
     @enforce_types
     @trace_method
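Note: PassageManager now composes an ArchiveManager, and archival passages are keyed by archive rather than directly by agent. A hedged sketch of the resulting flow (`actor`, a PydanticUser, and `agent_state`, an AgentState, are assumed to exist and are not defined in this diff):

    # Sketch: resolve the agent's default archive, then write passages against it.
    manager = PassageManager()
    archive = manager.archive_manager.get_or_create_default_archive_for_agent(
        agent_id=agent_state.id, agent_name=agent_state.name, actor=actor
    )
    print(archive.id)  # passages below are keyed to this archive, not to agent_state.id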
@@ -49,7 +54,7 @@ class PassageManager:
         """Fetch an agent passage by ID."""
         with db_registry.session() as session:
             try:
-                passage = AgentPassage.read(db_session=session, identifier=passage_id, actor=actor)
+                passage = ArchivalPassage.read(db_session=session, identifier=passage_id, actor=actor)
                 return passage.to_pydantic()
             except NoResultFound:
                 raise NoResultFound(f"Agent passage with id {passage_id} not found in database.")

@@ -60,7 +65,7 @@ class PassageManager:
         """Fetch an agent passage by ID."""
         async with db_registry.async_session() as session:
             try:
-                passage = await AgentPassage.read_async(db_session=session, identifier=passage_id, actor=actor)
+                passage = await ArchivalPassage.read_async(db_session=session, identifier=passage_id, actor=actor)
                 return passage.to_pydantic()
             except NoResultFound:
                 raise NoResultFound(f"Agent passage with id {passage_id} not found in database.")

@@ -109,7 +114,7 @@ class PassageManager:
             except NoResultFound:
                 # Try archival passages
                 try:
-                    passage = AgentPassage.read(db_session=session, identifier=passage_id, actor=actor)
+                    passage = ArchivalPassage.read(db_session=session, identifier=passage_id, actor=actor)
                     return passage.to_pydantic()
                 except NoResultFound:
                     raise NoResultFound(f"Passage with id {passage_id} not found in database.")

@@ -134,7 +139,7 @@ class PassageManager:
             except NoResultFound:
                 # Try archival passages
                 try:
-                    passage = await AgentPassage.read_async(db_session=session, identifier=passage_id, actor=actor)
+                    passage = await ArchivalPassage.read_async(db_session=session, identifier=passage_id, actor=actor)
                     return passage.to_pydantic()
                 except NoResultFound:
                     raise NoResultFound(f"Passage with id {passage_id} not found in database.")

@@ -143,8 +148,8 @@ class PassageManager:
     @trace_method
     def create_agent_passage(self, pydantic_passage: PydanticPassage, actor: PydanticUser) -> PydanticPassage:
         """Create a new agent passage."""
-        if not pydantic_passage.agent_id:
-            raise ValueError("Agent passage must have agent_id")
+        if not pydantic_passage.archive_id:
+            raise ValueError("Agent passage must have archive_id")
         if pydantic_passage.source_id:
             raise ValueError("Agent passage cannot have source_id")
@@ -159,8 +164,8 @@ class PassageManager:
             "is_deleted": data.get("is_deleted", False),
             "created_at": data.get("created_at", datetime.now(timezone.utc)),
         }
-        agent_fields = {"agent_id": data["agent_id"]}
-        passage = AgentPassage(**common_fields, **agent_fields)
+        agent_fields = {"archive_id": data["archive_id"]}
+        passage = ArchivalPassage(**common_fields, **agent_fields)

         with db_registry.session() as session:
             passage.create(session, actor=actor)

@@ -170,8 +175,8 @@ class PassageManager:
     @trace_method
     async def create_agent_passage_async(self, pydantic_passage: PydanticPassage, actor: PydanticUser) -> PydanticPassage:
         """Create a new agent passage."""
-        if not pydantic_passage.agent_id:
-            raise ValueError("Agent passage must have agent_id")
+        if not pydantic_passage.archive_id:
+            raise ValueError("Agent passage must have archive_id")
         if pydantic_passage.source_id:
             raise ValueError("Agent passage cannot have source_id")

@@ -186,8 +191,8 @@ class PassageManager:
             "is_deleted": data.get("is_deleted", False),
             "created_at": data.get("created_at", datetime.now(timezone.utc)),
         }
-        agent_fields = {"agent_id": data["agent_id"]}
-        passage = AgentPassage(**common_fields, **agent_fields)
+        agent_fields = {"archive_id": data["archive_id"]}
+        passage = ArchivalPassage(**common_fields, **agent_fields)

         async with db_registry.async_session() as session:
             passage = await passage.create_async(session, actor=actor)

@@ -201,8 +206,8 @@ class PassageManager:
         """Create a new source passage."""
         if not pydantic_passage.source_id:
             raise ValueError("Source passage must have source_id")
-        if pydantic_passage.agent_id:
-            raise ValueError("Source passage cannot have agent_id")
+        if pydantic_passage.archive_id:
+            raise ValueError("Source passage cannot have archive_id")

         data = pydantic_passage.model_dump(to_orm=True)
         common_fields = {

@@ -234,8 +239,8 @@ class PassageManager:
         """Create a new source passage."""
         if not pydantic_passage.source_id:
             raise ValueError("Source passage must have source_id")
-        if pydantic_passage.agent_id:
-            raise ValueError("Source passage cannot have agent_id")
+        if pydantic_passage.archive_id:
+            raise ValueError("Source passage cannot have archive_id")

         data = pydantic_passage.model_dump(to_orm=True)
         common_fields = {

@@ -308,21 +313,21 @@ class PassageManager:
             "created_at": data.get("created_at", datetime.now(timezone.utc)),
         }

-        if "agent_id" in data and data["agent_id"]:
-            assert not data.get("source_id"), "Passage cannot have both agent_id and source_id"
+        if "archive_id" in data and data["archive_id"]:
+            assert not data.get("source_id"), "Passage cannot have both archive_id and source_id"
             agent_fields = {
-                "agent_id": data["agent_id"],
+                "archive_id": data["archive_id"],
             }
-            passage = AgentPassage(**common_fields, **agent_fields)
+            passage = ArchivalPassage(**common_fields, **agent_fields)
         elif "source_id" in data and data["source_id"]:
-            assert not data.get("agent_id"), "Passage cannot have both agent_id and source_id"
+            assert not data.get("archive_id"), "Passage cannot have both archive_id and source_id"
             source_fields = {
                 "source_id": data["source_id"],
                 "file_id": data.get("file_id"),
             }
             passage = SourcePassage(**common_fields, **source_fields)
         else:
-            raise ValueError("Passage must have either agent_id or source_id")
+            raise ValueError("Passage must have either archive_id or source_id")

         return passage
@@ -334,14 +339,14 @@ class PassageManager:

     @enforce_types
     @trace_method
-    async def create_many_agent_passages_async(self, passages: List[PydanticPassage], actor: PydanticUser) -> List[PydanticPassage]:
-        """Create multiple agent passages."""
-        agent_passages = []
+    async def create_many_archival_passages_async(self, passages: List[PydanticPassage], actor: PydanticUser) -> List[PydanticPassage]:
+        """Create multiple archival passages."""
+        archival_passages = []
         for p in passages:
-            if not p.agent_id:
-                raise ValueError("Agent passage must have agent_id")
+            if not p.archive_id:
+                raise ValueError("Archival passage must have archive_id")
             if p.source_id:
-                raise ValueError("Agent passage cannot have source_id")
+                raise ValueError("Archival passage cannot have source_id")

             data = p.model_dump(to_orm=True)
             common_fields = {

@@ -354,12 +359,12 @@ class PassageManager:
                 "is_deleted": data.get("is_deleted", False),
                 "created_at": data.get("created_at", datetime.now(timezone.utc)),
             }
-            agent_fields = {"agent_id": data["agent_id"]}
-            agent_passages.append(AgentPassage(**common_fields, **agent_fields))
+            archival_fields = {"archive_id": data["archive_id"]}
+            archival_passages.append(ArchivalPassage(**common_fields, **archival_fields))

         async with db_registry.async_session() as session:
-            agent_created = await AgentPassage.batch_create_async(items=agent_passages, db_session=session, actor=actor)
-            return [p.to_pydantic() for p in agent_created]
+            archival_created = await ArchivalPassage.batch_create_async(items=archival_passages, db_session=session, actor=actor)
+            return [p.to_pydantic() for p in archival_created]

     @enforce_types
     @trace_method
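The batch creator is renamed to match; a minimal call sketch, reusing `manager`, `passage`, and `actor` from the earlier sketches:

    import asyncio

    async def bulk_insert():
        # Formerly create_many_agent_passages_async; each item must have
        # archive_id set and source_id unset, or a ValueError is raised.
        return await manager.create_many_archival_passages_async(passages=[passage], actor=actor)

    created = asyncio.run(bulk_insert())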
@@ -379,8 +384,8 @@ class PassageManager:
         for p in passages:
             if not p.source_id:
                 raise ValueError("Source passage must have source_id")
-            if p.agent_id:
-                raise ValueError("Source passage cannot have agent_id")
+            if p.archive_id:
+                raise ValueError("Source passage cannot have archive_id")

             data = p.model_dump(to_orm=True)
             common_fields = {

@@ -436,7 +441,7 @@ class PassageManager:

         for p in passages:
             model = self._preprocess_passage_for_creation(p)
-            if isinstance(model, AgentPassage):
+            if isinstance(model, ArchivalPassage):
                 agent_passages.append(model)
             elif isinstance(model, SourcePassage):
                 source_passages.append(model)

@@ -445,7 +450,7 @@ class PassageManager:

         results = []
         if agent_passages:
-            agent_created = await AgentPassage.batch_create_async(items=agent_passages, db_session=session, actor=actor)
+            agent_created = await ArchivalPassage.batch_create_async(items=agent_passages, db_session=session, actor=actor)
             results.extend(agent_created)
         if source_passages:
             source_created = await SourcePassage.batch_create_async(items=source_passages, db_session=session, actor=actor)

@@ -458,7 +463,6 @@ class PassageManager:
     def insert_passage(
         self,
         agent_state: AgentState,
-        agent_id: str,
         text: str,
         actor: PydanticUser,
     ) -> List[PydanticPassage]:
@@ -494,10 +498,15 @@ class PassageManager:
             raise TypeError(
                 f"Got back an unexpected payload from text embedding function, type={type(embedding)}, value={embedding}"
             )
+        # Get or create the default archive for the agent
+        archive = self.archive_manager.get_or_create_default_archive_for_agent(
+            agent_id=agent_state.id, agent_name=agent_state.name, actor=actor
+        )
+
         passage = self.create_agent_passage(
             PydanticPassage(
                 organization_id=actor.organization_id,
-                agent_id=agent_id,
+                archive_id=archive.id,
                 text=text,
                 embedding=embedding,
                 embedding_config=agent_state.embedding_config,

@@ -516,12 +525,18 @@ class PassageManager:
     async def insert_passage_async(
         self,
         agent_state: AgentState,
-        agent_id: str,
         text: str,
         actor: PydanticUser,
         image_ids: Optional[List[str]] = None,
     ) -> List[PydanticPassage]:
         """Insert passage(s) into archival memory"""
+        # Get or create default archive for the agent
+        archive = await self.archive_manager.get_or_create_default_archive_for_agent_async(
+            agent_id=agent_state.id,
+            agent_name=agent_state.name,
+            actor=actor,
+        )
+        archive_id = archive.id

         embedding_chunk_size = agent_state.embedding_config.embedding_chunk_size
         text_chunks = list(parse_and_chunk_text(text, embedding_chunk_size))

@@ -535,7 +550,7 @@ class PassageManager:
         passages = [
             PydanticPassage(
                 organization_id=actor.organization_id,
-                agent_id=agent_id,
+                archive_id=archive_id,
                 text=chunk_text,
                 embedding=embedding,
                 embedding_config=agent_state.embedding_config,

@@ -543,7 +558,7 @@ class PassageManager:
             for chunk_text, embedding in zip(text_chunks, embeddings)
         ]

-        passages = await self.create_many_agent_passages_async(passages=passages, actor=actor)
+        passages = await self.create_many_archival_passages_async(passages=passages, actor=actor)

         return passages
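Callers of insert_passage / insert_passage_async drop the redundant agent_id argument, since the archive is now resolved internally from agent_state. A hedged call sketch (`agent_state` and `actor` are assumed to exist):

    # Chunking, embedding, and archive resolution all happen inside the manager.
    async def remember(agent_state, actor):
        return await PassageManager().insert_passage_async(
            agent_state=agent_state,
            text="something worth remembering",
            actor=actor,
        )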
@@ -595,7 +610,7 @@ class PassageManager:

         with db_registry.session() as session:
             try:
-                curr_passage = AgentPassage.read(
+                curr_passage = ArchivalPassage.read(
                     db_session=session,
                     identifier=passage_id,
                     actor=actor,

@@ -623,7 +638,7 @@ class PassageManager:

         async with db_registry.async_session() as session:
             try:
-                curr_passage = await AgentPassage.read_async(
+                curr_passage = await ArchivalPassage.read_async(
                     db_session=session,
                     identifier=passage_id,
                     actor=actor,

@@ -705,7 +720,7 @@ class PassageManager:

         with db_registry.session() as session:
             try:
-                passage = AgentPassage.read(db_session=session, identifier=passage_id, actor=actor)
+                passage = ArchivalPassage.read(db_session=session, identifier=passage_id, actor=actor)
                 passage.hard_delete(session, actor=actor)
                 return True
             except NoResultFound:

@@ -720,7 +735,7 @@ class PassageManager:

         async with db_registry.async_session() as session:
             try:
-                passage = await AgentPassage.read_async(db_session=session, identifier=passage_id, actor=actor)
+                passage = await ArchivalPassage.read_async(db_session=session, identifier=passage_id, actor=actor)
                 await passage.hard_delete_async(session, actor=actor)
                 return True
             except NoResultFound:

@@ -783,7 +798,7 @@ class PassageManager:
             except NoResultFound:
                 # Try agent passages
                 try:
-                    curr_passage = AgentPassage.read(
+                    curr_passage = ArchivalPassage.read(
                         db_session=session,
                         identifier=passage_id,
                         actor=actor,

@@ -824,7 +839,7 @@ class PassageManager:
             except NoResultFound:
                 # Try archival passages
                 try:
-                    passage = AgentPassage.read(db_session=session, identifier=passage_id, actor=actor)
+                    passage = ArchivalPassage.read(db_session=session, identifier=passage_id, actor=actor)
                     passage.hard_delete(session, actor=actor)
                     return True
                 except NoResultFound:

@@ -854,7 +869,7 @@ class PassageManager:
             except NoResultFound:
                 # Try archival passages
                 try:
-                    passage = await AgentPassage.read_async(db_session=session, identifier=passage_id, actor=actor)
+                    passage = await ArchivalPassage.read_async(db_session=session, identifier=passage_id, actor=actor)
                     await passage.hard_delete_async(session, actor=actor)
                     return True
                 except NoResultFound:

@@ -883,7 +898,7 @@ class PassageManager:
     ) -> bool:
         """Delete multiple agent passages."""
         async with db_registry.async_session() as session:
-            await AgentPassage.bulk_hard_delete_async(db_session=session, identifiers=[p.id for p in passages], actor=actor)
+            await ArchivalPassage.bulk_hard_delete_async(db_session=session, identifiers=[p.id for p in passages], actor=actor)
             return True

     @enforce_types
@@ -947,7 +962,21 @@ class PassageManager:
             agent_id: The agent ID of the messages
         """
         with db_registry.session() as session:
-            return AgentPassage.size(db_session=session, actor=actor, agent_id=agent_id)
+            if agent_id:
+                # Count passages through the archives relationship
+                return (
+                    session.query(ArchivalPassage)
+                    .join(ArchivesAgents, ArchivalPassage.archive_id == ArchivesAgents.archive_id)
+                    .filter(
+                        ArchivesAgents.agent_id == agent_id,
+                        ArchivalPassage.organization_id == actor.organization_id,
+                        ArchivalPassage.is_deleted == False,
+                    )
+                    .count()
+                )
+            else:
+                # Count all archival passages in the organization
+                return ArchivalPassage.size(db_session=session, actor=actor)

     # DEPRECATED - Use agent_passage_size() instead since this only counted agent passages anyway
     @enforce_types
@@ -961,8 +990,7 @@ class PassageManager:
         import warnings

         warnings.warn("size is deprecated. Use agent_passage_size() instead.", DeprecationWarning, stacklevel=2)
-        with db_registry.session() as session:
-            return AgentPassage.size(db_session=session, actor=actor, agent_id=agent_id)
+        return self.agent_passage_size(actor=actor, agent_id=agent_id)

     @enforce_types
     @trace_method

@@ -977,7 +1005,23 @@ class PassageManager:
             agent_id: The agent ID of the messages
         """
         async with db_registry.async_session() as session:
-            return await AgentPassage.size_async(db_session=session, actor=actor, agent_id=agent_id)
+            if agent_id:
+                # Count passages through the archives relationship
+                from sqlalchemy import func, select
+
+                result = await session.execute(
+                    select(func.count(ArchivalPassage.id))
+                    .join(ArchivesAgents, ArchivalPassage.archive_id == ArchivesAgents.archive_id)
+                    .where(
+                        ArchivesAgents.agent_id == agent_id,
+                        ArchivalPassage.organization_id == actor.organization_id,
+                        ArchivalPassage.is_deleted == False,
+                    )
+                )
+                return result.scalar() or 0
+            else:
+                # Count all archival passages in the organization
+                return await ArchivalPassage.size_async(db_session=session, actor=actor)

     @enforce_types
     @trace_method
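Both size methods now count through the archives_agents join table instead of a direct agent_id column. A self-contained sketch of the same join-count pattern against in-memory SQLite (table and column names here are illustrative, not Letta's actual schema):

    from sqlalchemy import Boolean, Column, String, create_engine, func, select
    from sqlalchemy.orm import Session, declarative_base

    Base = declarative_base()

    class Passage(Base):
        __tablename__ = "passages"
        id = Column(String, primary_key=True)
        archive_id = Column(String)
        organization_id = Column(String)
        is_deleted = Column(Boolean, default=False)

    class ArchiveAgent(Base):
        __tablename__ = "archives_agents"
        archive_id = Column(String, primary_key=True)
        agent_id = Column(String, primary_key=True)

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)

    with Session(engine) as session:
        session.add_all([
            ArchiveAgent(archive_id="arch-1", agent_id="agent-1"),
            Passage(id="p-1", archive_id="arch-1", organization_id="org-1"),
            Passage(id="p-2", archive_id="arch-1", organization_id="org-1", is_deleted=True),
        ])
        session.commit()
        count = session.execute(
            select(func.count(Passage.id))
            .join(ArchiveAgent, Passage.archive_id == ArchiveAgent.archive_id)
            .where(
                ArchiveAgent.agent_id == "agent-1",
                Passage.organization_id == "org-1",
                Passage.is_deleted == False,
            )
        ).scalar()
        print(count)  # 1 -- the soft-deleted passage is excluded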
letta/services/tool_executor/core_tool_executor.py

@@ -176,7 +176,6 @@ class LettaCoreToolExecutor(ToolExecutor):
         """
         await PassageManager().insert_passage_async(
             agent_state=agent_state,
-            agent_id=agent_state.id,
             text=content,
             actor=actor,
         )

letta/services/tool_executor/sandbox_tool_executor.py

@@ -42,7 +42,7 @@ class SandboxToolExecutor(ToolExecutor):

         # Store original memory state
         if agent_state:
-            orig_memory_str = await agent_state.memory.compile_async()
+            orig_memory_str = await agent_state.memory.compile_in_thread_async()
         else:
             orig_memory_str = None

@@ -73,7 +73,7 @@ class SandboxToolExecutor(ToolExecutor):

         # Verify memory integrity
         if agent_state:
-            new_memory_str = await agent_state.memory.compile_async()
+            new_memory_str = await agent_state.memory.compile_in_thread_async()
             assert orig_memory_str == new_memory_str, "Memory should not be modified in a sandbox tool"

         # Update agent memory if needed
letta/services/tool_executor/tool_execution_manager.py

@@ -4,11 +4,11 @@ from typing import Any, Dict, Optional, Type
 from letta.constants import FUNCTION_RETURN_VALUE_TRUNCATED
 from letta.helpers.datetime_helpers import AsyncTimer
 from letta.log import get_logger
-from letta.orm.enums import ToolType
 from letta.otel.context import get_ctx_attributes
 from letta.otel.metric_registry import MetricRegistry
 from letta.otel.tracing import trace_method
 from letta.schemas.agent import AgentState
+from letta.schemas.enums import ToolType
 from letta.schemas.sandbox_config import SandboxConfig
 from letta.schemas.tool import Tool
 from letta.schemas.tool_execution_result import ToolExecutionResult

letta/services/tool_manager.py

@@ -22,12 +22,12 @@ from letta.constants import (
 from letta.errors import LettaToolNameConflictError
 from letta.functions.functions import derive_openai_json_schema, load_function_set
 from letta.log import get_logger
-from letta.orm.enums import ToolType

 # TODO: Remove this once we translate all of these to the ORM
 from letta.orm.errors import NoResultFound
 from letta.orm.tool import Tool as ToolModel
 from letta.otel.tracing import trace_method
+from letta.schemas.enums import ToolType
 from letta.schemas.tool import Tool as PydanticTool
 from letta.schemas.tool import ToolCreate, ToolUpdate
 from letta.schemas.user import User as PydanticUser

@@ -46,7 +46,7 @@ class ToolManager:
     # TODO: Refactor this across the codebase to use CreateTool instead of passing in a Tool object
     @enforce_types
     @trace_method
-    def create_or_update_tool(self, pydantic_tool: PydanticTool, actor: PydanticUser) -> PydanticTool:
+    def create_or_update_tool(self, pydantic_tool: PydanticTool, actor: PydanticUser, bypass_name_check: bool = False) -> PydanticTool:
         """Create a new tool based on the ToolCreate schema."""
         tool_id = self.get_tool_id_by_name(tool_name=pydantic_tool.name, actor=actor)
         if tool_id:

@@ -60,7 +60,9 @@ class ToolManager:
             updated_tool_type = None
             if "tool_type" in update_data:
                 updated_tool_type = update_data.get("tool_type")
-            tool = self.update_tool_by_id(tool_id, ToolUpdate(**update_data), actor, updated_tool_type=updated_tool_type)
+            tool = self.update_tool_by_id(
+                tool_id, ToolUpdate(**update_data), actor, updated_tool_type=updated_tool_type, bypass_name_check=bypass_name_check
+            )
         else:
             printd(
                 f"`create_or_update_tool` was called with user_id={actor.id}, organization_id={actor.organization_id}, name={pydantic_tool.name}, but found existing tool with nothing to update."

@@ -73,7 +75,9 @@ class ToolManager:

     @enforce_types
     @trace_method
-    async def create_or_update_tool_async(self, pydantic_tool: PydanticTool, actor: PydanticUser) -> PydanticTool:
+    async def create_or_update_tool_async(
+        self, pydantic_tool: PydanticTool, actor: PydanticUser, bypass_name_check: bool = False
+    ) -> PydanticTool:
         """Create a new tool based on the ToolCreate schema."""
         tool_id = await self.get_tool_id_by_name_async(tool_name=pydantic_tool.name, actor=actor)
         if tool_id:

@@ -88,7 +92,9 @@ class ToolManager:
             updated_tool_type = None
             if "tool_type" in update_data:
                 updated_tool_type = update_data.get("tool_type")
-            tool = await self.update_tool_by_id_async(tool_id, ToolUpdate(**update_data), actor, updated_tool_type=updated_tool_type)
+            tool = await self.update_tool_by_id_async(
+                tool_id, ToolUpdate(**update_data), actor, updated_tool_type=updated_tool_type, bypass_name_check=bypass_name_check
+            )
         else:
             printd(
                 f"`create_or_update_tool` was called with user_id={actor.id}, organization_id={actor.organization_id}, name={pydantic_tool.name}, but found existing tool with nothing to update."
@@ -358,9 +364,13 @@ class ToolManager:
                 results.append(pydantic_tool)
             except (ValueError, ModuleNotFoundError, AttributeError) as e:
                 tools_to_delete.append(tool)
-                logger.warning(f"Deleting malformed tool with id={tool.id} and name={tool.name}, error was:\n{e}")
-                logger.warning("Deleted tool: ")
-                logger.warning(tool.pretty_print_columns())
+                logger.warning(
+                    "Deleting malformed tool with id=%s and name=%s. Error was:\n%s\nDeleted tool:%s",
+                    tool.id,
+                    tool.name,
+                    e,
+                    tool.pretty_print_columns(),
+                )

         for tool in tools_to_delete:
             await self.delete_tool_by_id_async(tool.id, actor=actor)
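Collapsing the three calls into one %-style call produces a single log record and defers string interpolation to the logging machinery, so it is skipped entirely when WARNING is filtered out. A minimal illustration:

    import logging

    logging.basicConfig(level=logging.ERROR)
    logger = logging.getLogger("demo")

    # The %s interpolation never runs here because WARNING is below the threshold;
    # an f-string would have built the message up front regardless.
    logger.warning("Deleting malformed tool with id=%s and name=%s", "tool-123", "send_message")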
@@ -387,7 +397,12 @@ class ToolManager:
     @enforce_types
     @trace_method
     def update_tool_by_id(
-        self, tool_id: str, tool_update: ToolUpdate, actor: PydanticUser, updated_tool_type: Optional[ToolType] = None
+        self,
+        tool_id: str,
+        tool_update: ToolUpdate,
+        actor: PydanticUser,
+        updated_tool_type: Optional[ToolType] = None,
+        bypass_name_check: bool = False,
     ) -> PydanticTool:
         """Update a tool by its ID with the given ToolUpdate object."""
         # First, check if source code update would cause a name conflict

@@ -395,17 +410,29 @@ class ToolManager:
         new_name = None
         new_schema = None

-        if "source_code" in update_data.keys() and "json_schema" not in update_data.keys():
-            # Derive the new schema and name from the source code
-            new_schema = derive_openai_json_schema(source_code=update_data["source_code"])
-            new_name = new_schema["name"]
+        # TODO: Consider this behavior...is this what we want?
+        # TODO: I feel like it's bad if json_schema strays from source code so
+        # if source code is provided, always derive the name from it
+        if "source_code" in update_data.keys() and not bypass_name_check:
+            # derive the schema from source code to get the function name
+            derived_schema = derive_openai_json_schema(source_code=update_data["source_code"])
+            new_name = derived_schema["name"]
+
+            # if json_schema wasn't provided, use the derived schema
+            if "json_schema" not in update_data.keys():
+                new_schema = derived_schema
+            else:
+                # if json_schema was provided, update only its name to match the source code
+                new_schema = update_data["json_schema"].copy()
+                new_schema["name"] = new_name
+                # update the json_schema in update_data so it gets applied in the loop
+                update_data["json_schema"] = new_schema

-        # Get current tool to check if name is changing
+        # get current tool to check if name is changing
         current_tool = self.get_tool_by_id(tool_id=tool_id, actor=actor)
-
-        # Check if the name is changing and if so, verify it doesn't conflict
+        # check if the name is changing and if so, verify it doesn't conflict
         if new_name != current_tool.name:
-            # Check if a tool with the new name already exists
+            # check if a tool with the new name already exists
             existing_tool = self.get_tool_by_name(tool_name=new_name, actor=actor)
             if existing_tool:
                 raise LettaToolNameConflictError(tool_name=new_name)
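In effect, the tool name is now always re-derived from source_code (even when json_schema is also supplied), unless the caller opts out. A hedged sketch; `manager`, `actor`, `my_tool`, and `tool_update` (a ToolUpdate carrying new source_code) are illustrative stand-ins:

    # Default: the name is derived from the function in source_code, and a
    # rename that collides with an existing tool raises LettaToolNameConflictError.
    updated = manager.update_tool_by_id(my_tool.id, tool_update, actor)

    # Opt out: keep the caller-supplied name/schema untouched.
    updated = manager.update_tool_by_id(my_tool.id, tool_update, actor, bypass_name_check=True)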
@@ -433,7 +460,12 @@ class ToolManager:
     @enforce_types
     @trace_method
     async def update_tool_by_id_async(
-        self, tool_id: str, tool_update: ToolUpdate, actor: PydanticUser, updated_tool_type: Optional[ToolType] = None
+        self,
+        tool_id: str,
+        tool_update: ToolUpdate,
+        actor: PydanticUser,
+        updated_tool_type: Optional[ToolType] = None,
+        bypass_name_check: bool = False,
     ) -> PydanticTool:
         """Update a tool by its ID with the given ToolUpdate object."""
         # First, check if source code update would cause a name conflict

@@ -441,17 +473,29 @@ class ToolManager:
         new_name = None
         new_schema = None

-        if "source_code" in update_data.keys() and "json_schema" not in update_data.keys():
-            # Derive the new schema and name from the source code
-            new_schema = derive_openai_json_schema(source_code=update_data["source_code"])
-            new_name = new_schema["name"]
+        # TODO: Consider this behavior...is this what we want?
+        # TODO: I feel like it's bad if json_schema strays from source code so
+        # if source code is provided, always derive the name from it
+        if "source_code" in update_data.keys() and not bypass_name_check:
+            # derive the schema from source code to get the function name
+            derived_schema = derive_openai_json_schema(source_code=update_data["source_code"])
+            new_name = derived_schema["name"]
+
+            # if json_schema wasn't provided, use the derived schema
+            if "json_schema" not in update_data.keys():
+                new_schema = derived_schema
+            else:
+                # if json_schema was provided, update only its name to match the source code
+                new_schema = update_data["json_schema"].copy()
+                new_schema["name"] = new_name
+                # update the json_schema in update_data so it gets applied in the loop
+                update_data["json_schema"] = new_schema

-        # Get current tool to check if name is changing
+        # get current tool to check if name is changing
         current_tool = await self.get_tool_by_id_async(tool_id=tool_id, actor=actor)
-
-        # Check if the name is changing and if so, verify it doesn't conflict
+        # check if the name is changing and if so, verify it doesn't conflict
         if new_name != current_tool.name:
-            # Check if a tool with the new name already exists
+            # check if a tool with the new name already exists
             name_exists = await self.tool_name_exists_async(tool_name=new_name, actor=actor)
             if name_exists:
                 raise LettaToolNameConflictError(tool_name=new_name)
letta/services/tool_sandbox/base.py

@@ -80,7 +80,7 @@ class AsyncToolSandboxBase(ABC):
         Generate code to run inside of execution sandbox. Serialize the agent state and arguments, call the tool,
         then base64-encode/pickle the result. Runs a jinja2 template constructing the python file.
         """
-        from letta.templates.template_helper import render_template_async
+        from letta.templates.template_helper import render_template_in_thread

         # Select the appropriate template based on whether the function is async
         TEMPLATE_NAME = "sandbox_code_file_async.py.j2" if self.is_async_function else "sandbox_code_file.py.j2"

@@ -107,7 +107,7 @@ class AsyncToolSandboxBase(ABC):

         agent_state_pickle = pickle.dumps(agent_state) if self.inject_agent_state else None

-        return await render_template_async(
+        return await render_template_in_thread(
             TEMPLATE_NAME,
             future_import=future_import,
             inject_agent_state=self.inject_agent_state,
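The new helper itself lives in letta/templates/template_helper.py (+8 lines in this release) and is not shown in this diff. One plausible shape for a render_template_in_thread helper, using asyncio.to_thread to keep blocking Jinja2 rendering off the event loop; this is an assumption, not the actual implementation:

    # Sketch only: letta's real render_template_in_thread may differ.
    import asyncio
    from jinja2 import DictLoader, Environment

    env = Environment(loader=DictLoader({"hello.j2": "Hello, {{ name }}!"}))

    async def render_template_in_thread(template_name: str, **context) -> str:
        # Template rendering is synchronous; run it on a worker thread so the
        # event loop stays responsive while large sandbox files are generated.
        template = env.get_template(template_name)
        return await asyncio.to_thread(template.render, **context)

    print(asyncio.run(render_template_in_thread("hello.j2", name="Letta")))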