letta-nightly 0.10.0.dev20250805104522__py3-none-any.whl → 0.11.0.dev20250807000848__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- letta/__init__.py +1 -4
- letta/agent.py +1 -2
- letta/agents/base_agent.py +4 -7
- letta/agents/letta_agent.py +59 -51
- letta/agents/letta_agent_batch.py +1 -2
- letta/agents/voice_agent.py +1 -2
- letta/agents/voice_sleeptime_agent.py +1 -3
- letta/constants.py +4 -1
- letta/embeddings.py +1 -1
- letta/functions/function_sets/base.py +0 -1
- letta/functions/mcp_client/types.py +4 -0
- letta/groups/supervisor_multi_agent.py +1 -1
- letta/interfaces/anthropic_streaming_interface.py +16 -24
- letta/interfaces/openai_streaming_interface.py +16 -28
- letta/llm_api/llm_api_tools.py +3 -3
- letta/local_llm/vllm/api.py +3 -0
- letta/orm/__init__.py +3 -1
- letta/orm/agent.py +8 -0
- letta/orm/archive.py +86 -0
- letta/orm/archives_agents.py +27 -0
- letta/orm/job.py +5 -1
- letta/orm/mixins.py +8 -0
- letta/orm/organization.py +7 -8
- letta/orm/passage.py +12 -10
- letta/orm/sqlite_functions.py +2 -2
- letta/orm/tool.py +5 -4
- letta/schemas/agent.py +4 -2
- letta/schemas/agent_file.py +18 -1
- letta/schemas/archive.py +44 -0
- letta/schemas/embedding_config.py +2 -16
- letta/schemas/enums.py +2 -1
- letta/schemas/group.py +28 -3
- letta/schemas/job.py +4 -0
- letta/schemas/llm_config.py +29 -14
- letta/schemas/memory.py +9 -3
- letta/schemas/npm_requirement.py +12 -0
- letta/schemas/passage.py +3 -3
- letta/schemas/providers/letta.py +1 -1
- letta/schemas/providers/vllm.py +4 -4
- letta/schemas/sandbox_config.py +3 -1
- letta/schemas/tool.py +10 -38
- letta/schemas/tool_rule.py +2 -2
- letta/server/db.py +8 -2
- letta/server/rest_api/routers/v1/agents.py +9 -8
- letta/server/server.py +6 -40
- letta/server/startup.sh +3 -0
- letta/services/agent_manager.py +92 -31
- letta/services/agent_serialization_manager.py +62 -3
- letta/services/archive_manager.py +269 -0
- letta/services/helpers/agent_manager_helper.py +111 -37
- letta/services/job_manager.py +24 -0
- letta/services/passage_manager.py +98 -54
- letta/services/tool_executor/core_tool_executor.py +0 -1
- letta/services/tool_executor/sandbox_tool_executor.py +2 -2
- letta/services/tool_executor/tool_execution_manager.py +1 -1
- letta/services/tool_manager.py +70 -26
- letta/services/tool_sandbox/base.py +2 -2
- letta/services/tool_sandbox/local_sandbox.py +5 -1
- letta/templates/template_helper.py +8 -0
- {letta_nightly-0.10.0.dev20250805104522.dist-info → letta_nightly-0.11.0.dev20250807000848.dist-info}/METADATA +5 -6
- {letta_nightly-0.10.0.dev20250805104522.dist-info → letta_nightly-0.11.0.dev20250807000848.dist-info}/RECORD +64 -61
- letta/client/client.py +0 -2207
- letta/orm/enums.py +0 -21
- {letta_nightly-0.10.0.dev20250805104522.dist-info → letta_nightly-0.11.0.dev20250807000848.dist-info}/LICENSE +0 -0
- {letta_nightly-0.10.0.dev20250805104522.dist-info → letta_nightly-0.11.0.dev20250807000848.dist-info}/WHEEL +0 -0
- {letta_nightly-0.10.0.dev20250805104522.dist-info → letta_nightly-0.11.0.dev20250807000848.dist-info}/entry_points.txt +0 -0
@@ -9,14 +9,16 @@ from sqlalchemy import select
|
|
9
9
|
from letta.constants import MAX_EMBEDDING_DIM
|
10
10
|
from letta.embeddings import embedding_model, parse_and_chunk_text
|
11
11
|
from letta.helpers.decorators import async_redis_cache
|
12
|
+
from letta.orm import ArchivesAgents
|
12
13
|
from letta.orm.errors import NoResultFound
|
13
|
-
from letta.orm.passage import
|
14
|
+
from letta.orm.passage import ArchivalPassage, SourcePassage
|
14
15
|
from letta.otel.tracing import trace_method
|
15
16
|
from letta.schemas.agent import AgentState
|
16
17
|
from letta.schemas.file import FileMetadata as PydanticFileMetadata
|
17
18
|
from letta.schemas.passage import Passage as PydanticPassage
|
18
19
|
from letta.schemas.user import User as PydanticUser
|
19
20
|
from letta.server.db import db_registry
|
21
|
+
from letta.services.archive_manager import ArchiveManager
|
20
22
|
from letta.utils import enforce_types
|
21
23
|
|
22
24
|
|
@@ -42,6 +44,9 @@ async def get_openai_embedding_async(text: str, model: str, endpoint: str) -> li
|
|
42
44
|
class PassageManager:
|
43
45
|
"""Manager class to handle business logic related to Passages."""
|
44
46
|
|
47
|
+
def __init__(self):
|
48
|
+
self.archive_manager = ArchiveManager()
|
49
|
+
|
45
50
|
# AGENT PASSAGE METHODS
|
46
51
|
@enforce_types
|
47
52
|
@trace_method
|
@@ -49,7 +54,7 @@ class PassageManager:
|
|
49
54
|
"""Fetch an agent passage by ID."""
|
50
55
|
with db_registry.session() as session:
|
51
56
|
try:
|
52
|
-
passage =
|
57
|
+
passage = ArchivalPassage.read(db_session=session, identifier=passage_id, actor=actor)
|
53
58
|
return passage.to_pydantic()
|
54
59
|
except NoResultFound:
|
55
60
|
raise NoResultFound(f"Agent passage with id {passage_id} not found in database.")
|
@@ -60,7 +65,7 @@ class PassageManager:
|
|
60
65
|
"""Fetch an agent passage by ID."""
|
61
66
|
async with db_registry.async_session() as session:
|
62
67
|
try:
|
63
|
-
passage = await
|
68
|
+
passage = await ArchivalPassage.read_async(db_session=session, identifier=passage_id, actor=actor)
|
64
69
|
return passage.to_pydantic()
|
65
70
|
except NoResultFound:
|
66
71
|
raise NoResultFound(f"Agent passage with id {passage_id} not found in database.")
|
@@ -109,7 +114,7 @@ class PassageManager:
|
|
109
114
|
except NoResultFound:
|
110
115
|
# Try archival passages
|
111
116
|
try:
|
112
|
-
passage =
|
117
|
+
passage = ArchivalPassage.read(db_session=session, identifier=passage_id, actor=actor)
|
113
118
|
return passage.to_pydantic()
|
114
119
|
except NoResultFound:
|
115
120
|
raise NoResultFound(f"Passage with id {passage_id} not found in database.")
|
@@ -134,7 +139,7 @@ class PassageManager:
|
|
134
139
|
except NoResultFound:
|
135
140
|
# Try archival passages
|
136
141
|
try:
|
137
|
-
passage = await
|
142
|
+
passage = await ArchivalPassage.read_async(db_session=session, identifier=passage_id, actor=actor)
|
138
143
|
return passage.to_pydantic()
|
139
144
|
except NoResultFound:
|
140
145
|
raise NoResultFound(f"Passage with id {passage_id} not found in database.")
|
@@ -143,8 +148,8 @@ class PassageManager:
|
|
143
148
|
@trace_method
|
144
149
|
def create_agent_passage(self, pydantic_passage: PydanticPassage, actor: PydanticUser) -> PydanticPassage:
|
145
150
|
"""Create a new agent passage."""
|
146
|
-
if not pydantic_passage.
|
147
|
-
raise ValueError("Agent passage must have
|
151
|
+
if not pydantic_passage.archive_id:
|
152
|
+
raise ValueError("Agent passage must have archive_id")
|
148
153
|
if pydantic_passage.source_id:
|
149
154
|
raise ValueError("Agent passage cannot have source_id")
|
150
155
|
|
@@ -159,8 +164,8 @@ class PassageManager:
|
|
159
164
|
"is_deleted": data.get("is_deleted", False),
|
160
165
|
"created_at": data.get("created_at", datetime.now(timezone.utc)),
|
161
166
|
}
|
162
|
-
agent_fields = {"
|
163
|
-
passage =
|
167
|
+
agent_fields = {"archive_id": data["archive_id"]}
|
168
|
+
passage = ArchivalPassage(**common_fields, **agent_fields)
|
164
169
|
|
165
170
|
with db_registry.session() as session:
|
166
171
|
passage.create(session, actor=actor)
|
@@ -170,8 +175,8 @@ class PassageManager:
|
|
170
175
|
@trace_method
|
171
176
|
async def create_agent_passage_async(self, pydantic_passage: PydanticPassage, actor: PydanticUser) -> PydanticPassage:
|
172
177
|
"""Create a new agent passage."""
|
173
|
-
if not pydantic_passage.
|
174
|
-
raise ValueError("Agent passage must have
|
178
|
+
if not pydantic_passage.archive_id:
|
179
|
+
raise ValueError("Agent passage must have archive_id")
|
175
180
|
if pydantic_passage.source_id:
|
176
181
|
raise ValueError("Agent passage cannot have source_id")
|
177
182
|
|
@@ -186,8 +191,8 @@ class PassageManager:
|
|
186
191
|
"is_deleted": data.get("is_deleted", False),
|
187
192
|
"created_at": data.get("created_at", datetime.now(timezone.utc)),
|
188
193
|
}
|
189
|
-
agent_fields = {"
|
190
|
-
passage =
|
194
|
+
agent_fields = {"archive_id": data["archive_id"]}
|
195
|
+
passage = ArchivalPassage(**common_fields, **agent_fields)
|
191
196
|
|
192
197
|
async with db_registry.async_session() as session:
|
193
198
|
passage = await passage.create_async(session, actor=actor)
|
@@ -201,8 +206,8 @@ class PassageManager:
|
|
201
206
|
"""Create a new source passage."""
|
202
207
|
if not pydantic_passage.source_id:
|
203
208
|
raise ValueError("Source passage must have source_id")
|
204
|
-
if pydantic_passage.
|
205
|
-
raise ValueError("Source passage cannot have
|
209
|
+
if pydantic_passage.archive_id:
|
210
|
+
raise ValueError("Source passage cannot have archive_id")
|
206
211
|
|
207
212
|
data = pydantic_passage.model_dump(to_orm=True)
|
208
213
|
common_fields = {
|
@@ -234,8 +239,8 @@ class PassageManager:
|
|
234
239
|
"""Create a new source passage."""
|
235
240
|
if not pydantic_passage.source_id:
|
236
241
|
raise ValueError("Source passage must have source_id")
|
237
|
-
if pydantic_passage.
|
238
|
-
raise ValueError("Source passage cannot have
|
242
|
+
if pydantic_passage.archive_id:
|
243
|
+
raise ValueError("Source passage cannot have archive_id")
|
239
244
|
|
240
245
|
data = pydantic_passage.model_dump(to_orm=True)
|
241
246
|
common_fields = {
|
@@ -308,21 +313,21 @@ class PassageManager:
|
|
308
313
|
"created_at": data.get("created_at", datetime.now(timezone.utc)),
|
309
314
|
}
|
310
315
|
|
311
|
-
if "
|
312
|
-
assert not data.get("source_id"), "Passage cannot have both
|
316
|
+
if "archive_id" in data and data["archive_id"]:
|
317
|
+
assert not data.get("source_id"), "Passage cannot have both archive_id and source_id"
|
313
318
|
agent_fields = {
|
314
|
-
"
|
319
|
+
"archive_id": data["archive_id"],
|
315
320
|
}
|
316
|
-
passage =
|
321
|
+
passage = ArchivalPassage(**common_fields, **agent_fields)
|
317
322
|
elif "source_id" in data and data["source_id"]:
|
318
|
-
assert not data.get("
|
323
|
+
assert not data.get("archive_id"), "Passage cannot have both archive_id and source_id"
|
319
324
|
source_fields = {
|
320
325
|
"source_id": data["source_id"],
|
321
326
|
"file_id": data.get("file_id"),
|
322
327
|
}
|
323
328
|
passage = SourcePassage(**common_fields, **source_fields)
|
324
329
|
else:
|
325
|
-
raise ValueError("Passage must have either
|
330
|
+
raise ValueError("Passage must have either archive_id or source_id")
|
326
331
|
|
327
332
|
return passage
|
328
333
|
|
@@ -334,14 +339,14 @@ class PassageManager:
|
|
334
339
|
|
335
340
|
@enforce_types
|
336
341
|
@trace_method
|
337
|
-
async def
|
338
|
-
"""Create multiple
|
339
|
-
|
342
|
+
async def create_many_archival_passages_async(self, passages: List[PydanticPassage], actor: PydanticUser) -> List[PydanticPassage]:
|
343
|
+
"""Create multiple archival passages."""
|
344
|
+
archival_passages = []
|
340
345
|
for p in passages:
|
341
|
-
if not p.
|
342
|
-
raise ValueError("
|
346
|
+
if not p.archive_id:
|
347
|
+
raise ValueError("Archival passage must have archive_id")
|
343
348
|
if p.source_id:
|
344
|
-
raise ValueError("
|
349
|
+
raise ValueError("Archival passage cannot have source_id")
|
345
350
|
|
346
351
|
data = p.model_dump(to_orm=True)
|
347
352
|
common_fields = {
|
@@ -354,12 +359,12 @@ class PassageManager:
|
|
354
359
|
"is_deleted": data.get("is_deleted", False),
|
355
360
|
"created_at": data.get("created_at", datetime.now(timezone.utc)),
|
356
361
|
}
|
357
|
-
|
358
|
-
|
362
|
+
archival_fields = {"archive_id": data["archive_id"]}
|
363
|
+
archival_passages.append(ArchivalPassage(**common_fields, **archival_fields))
|
359
364
|
|
360
365
|
async with db_registry.async_session() as session:
|
361
|
-
|
362
|
-
return [p.to_pydantic() for p in
|
366
|
+
archival_created = await ArchivalPassage.batch_create_async(items=archival_passages, db_session=session, actor=actor)
|
367
|
+
return [p.to_pydantic() for p in archival_created]
|
363
368
|
|
364
369
|
@enforce_types
|
365
370
|
@trace_method
|
@@ -379,8 +384,8 @@ class PassageManager:
|
|
379
384
|
for p in passages:
|
380
385
|
if not p.source_id:
|
381
386
|
raise ValueError("Source passage must have source_id")
|
382
|
-
if p.
|
383
|
-
raise ValueError("Source passage cannot have
|
387
|
+
if p.archive_id:
|
388
|
+
raise ValueError("Source passage cannot have archive_id")
|
384
389
|
|
385
390
|
data = p.model_dump(to_orm=True)
|
386
391
|
common_fields = {
|
@@ -436,7 +441,7 @@ class PassageManager:
|
|
436
441
|
|
437
442
|
for p in passages:
|
438
443
|
model = self._preprocess_passage_for_creation(p)
|
439
|
-
if isinstance(model,
|
444
|
+
if isinstance(model, ArchivalPassage):
|
440
445
|
agent_passages.append(model)
|
441
446
|
elif isinstance(model, SourcePassage):
|
442
447
|
source_passages.append(model)
|
@@ -445,7 +450,7 @@ class PassageManager:
|
|
445
450
|
|
446
451
|
results = []
|
447
452
|
if agent_passages:
|
448
|
-
agent_created = await
|
453
|
+
agent_created = await ArchivalPassage.batch_create_async(items=agent_passages, db_session=session, actor=actor)
|
449
454
|
results.extend(agent_created)
|
450
455
|
if source_passages:
|
451
456
|
source_created = await SourcePassage.batch_create_async(items=source_passages, db_session=session, actor=actor)
|
@@ -458,7 +463,6 @@ class PassageManager:
|
|
458
463
|
def insert_passage(
|
459
464
|
self,
|
460
465
|
agent_state: AgentState,
|
461
|
-
agent_id: str,
|
462
466
|
text: str,
|
463
467
|
actor: PydanticUser,
|
464
468
|
) -> List[PydanticPassage]:
|
@@ -494,10 +498,15 @@ class PassageManager:
|
|
494
498
|
raise TypeError(
|
495
499
|
f"Got back an unexpected payload from text embedding function, type={type(embedding)}, value={embedding}"
|
496
500
|
)
|
501
|
+
# Get or create the default archive for the agent
|
502
|
+
archive = self.archive_manager.get_or_create_default_archive_for_agent(
|
503
|
+
agent_id=agent_state.id, agent_name=agent_state.name, actor=actor
|
504
|
+
)
|
505
|
+
|
497
506
|
passage = self.create_agent_passage(
|
498
507
|
PydanticPassage(
|
499
508
|
organization_id=actor.organization_id,
|
500
|
-
|
509
|
+
archive_id=archive.id,
|
501
510
|
text=text,
|
502
511
|
embedding=embedding,
|
503
512
|
embedding_config=agent_state.embedding_config,
|
@@ -516,12 +525,18 @@ class PassageManager:
|
|
516
525
|
async def insert_passage_async(
|
517
526
|
self,
|
518
527
|
agent_state: AgentState,
|
519
|
-
agent_id: str,
|
520
528
|
text: str,
|
521
529
|
actor: PydanticUser,
|
522
530
|
image_ids: Optional[List[str]] = None,
|
523
531
|
) -> List[PydanticPassage]:
|
524
532
|
"""Insert passage(s) into archival memory"""
|
533
|
+
# Get or create default archive for the agent
|
534
|
+
archive = await self.archive_manager.get_or_create_default_archive_for_agent_async(
|
535
|
+
agent_id=agent_state.id,
|
536
|
+
agent_name=agent_state.name,
|
537
|
+
actor=actor,
|
538
|
+
)
|
539
|
+
archive_id = archive.id
|
525
540
|
|
526
541
|
embedding_chunk_size = agent_state.embedding_config.embedding_chunk_size
|
527
542
|
text_chunks = list(parse_and_chunk_text(text, embedding_chunk_size))
|
@@ -535,7 +550,7 @@ class PassageManager:
|
|
535
550
|
passages = [
|
536
551
|
PydanticPassage(
|
537
552
|
organization_id=actor.organization_id,
|
538
|
-
|
553
|
+
archive_id=archive_id,
|
539
554
|
text=chunk_text,
|
540
555
|
embedding=embedding,
|
541
556
|
embedding_config=agent_state.embedding_config,
|
@@ -543,7 +558,7 @@ class PassageManager:
|
|
543
558
|
for chunk_text, embedding in zip(text_chunks, embeddings)
|
544
559
|
]
|
545
560
|
|
546
|
-
passages = await self.
|
561
|
+
passages = await self.create_many_archival_passages_async(passages=passages, actor=actor)
|
547
562
|
|
548
563
|
return passages
|
549
564
|
|
@@ -595,7 +610,7 @@ class PassageManager:
|
|
595
610
|
|
596
611
|
with db_registry.session() as session:
|
597
612
|
try:
|
598
|
-
curr_passage =
|
613
|
+
curr_passage = ArchivalPassage.read(
|
599
614
|
db_session=session,
|
600
615
|
identifier=passage_id,
|
601
616
|
actor=actor,
|
@@ -623,7 +638,7 @@ class PassageManager:
|
|
623
638
|
|
624
639
|
async with db_registry.async_session() as session:
|
625
640
|
try:
|
626
|
-
curr_passage = await
|
641
|
+
curr_passage = await ArchivalPassage.read_async(
|
627
642
|
db_session=session,
|
628
643
|
identifier=passage_id,
|
629
644
|
actor=actor,
|
@@ -705,7 +720,7 @@ class PassageManager:
|
|
705
720
|
|
706
721
|
with db_registry.session() as session:
|
707
722
|
try:
|
708
|
-
passage =
|
723
|
+
passage = ArchivalPassage.read(db_session=session, identifier=passage_id, actor=actor)
|
709
724
|
passage.hard_delete(session, actor=actor)
|
710
725
|
return True
|
711
726
|
except NoResultFound:
|
@@ -720,7 +735,7 @@ class PassageManager:
|
|
720
735
|
|
721
736
|
async with db_registry.async_session() as session:
|
722
737
|
try:
|
723
|
-
passage = await
|
738
|
+
passage = await ArchivalPassage.read_async(db_session=session, identifier=passage_id, actor=actor)
|
724
739
|
await passage.hard_delete_async(session, actor=actor)
|
725
740
|
return True
|
726
741
|
except NoResultFound:
|
@@ -783,7 +798,7 @@ class PassageManager:
|
|
783
798
|
except NoResultFound:
|
784
799
|
# Try agent passages
|
785
800
|
try:
|
786
|
-
curr_passage =
|
801
|
+
curr_passage = ArchivalPassage.read(
|
787
802
|
db_session=session,
|
788
803
|
identifier=passage_id,
|
789
804
|
actor=actor,
|
@@ -824,7 +839,7 @@ class PassageManager:
|
|
824
839
|
except NoResultFound:
|
825
840
|
# Try archival passages
|
826
841
|
try:
|
827
|
-
passage =
|
842
|
+
passage = ArchivalPassage.read(db_session=session, identifier=passage_id, actor=actor)
|
828
843
|
passage.hard_delete(session, actor=actor)
|
829
844
|
return True
|
830
845
|
except NoResultFound:
|
@@ -854,7 +869,7 @@ class PassageManager:
|
|
854
869
|
except NoResultFound:
|
855
870
|
# Try archival passages
|
856
871
|
try:
|
857
|
-
passage = await
|
872
|
+
passage = await ArchivalPassage.read_async(db_session=session, identifier=passage_id, actor=actor)
|
858
873
|
await passage.hard_delete_async(session, actor=actor)
|
859
874
|
return True
|
860
875
|
except NoResultFound:
|
@@ -883,7 +898,7 @@ class PassageManager:
|
|
883
898
|
) -> bool:
|
884
899
|
"""Delete multiple agent passages."""
|
885
900
|
async with db_registry.async_session() as session:
|
886
|
-
await
|
901
|
+
await ArchivalPassage.bulk_hard_delete_async(db_session=session, identifiers=[p.id for p in passages], actor=actor)
|
887
902
|
return True
|
888
903
|
|
889
904
|
@enforce_types
|
@@ -947,7 +962,21 @@ class PassageManager:
|
|
947
962
|
agent_id: The agent ID of the messages
|
948
963
|
"""
|
949
964
|
with db_registry.session() as session:
|
950
|
-
|
965
|
+
if agent_id:
|
966
|
+
# Count passages through the archives relationship
|
967
|
+
return (
|
968
|
+
session.query(ArchivalPassage)
|
969
|
+
.join(ArchivesAgents, ArchivalPassage.archive_id == ArchivesAgents.archive_id)
|
970
|
+
.filter(
|
971
|
+
ArchivesAgents.agent_id == agent_id,
|
972
|
+
ArchivalPassage.organization_id == actor.organization_id,
|
973
|
+
ArchivalPassage.is_deleted == False,
|
974
|
+
)
|
975
|
+
.count()
|
976
|
+
)
|
977
|
+
else:
|
978
|
+
# Count all archival passages in the organization
|
979
|
+
return ArchivalPassage.size(db_session=session, actor=actor)
|
951
980
|
|
952
981
|
# DEPRECATED - Use agent_passage_size() instead since this only counted agent passages anyway
|
953
982
|
@enforce_types
|
@@ -961,8 +990,7 @@ class PassageManager:
|
|
961
990
|
import warnings
|
962
991
|
|
963
992
|
warnings.warn("size is deprecated. Use agent_passage_size() instead.", DeprecationWarning, stacklevel=2)
|
964
|
-
|
965
|
-
return AgentPassage.size(db_session=session, actor=actor, agent_id=agent_id)
|
993
|
+
return self.agent_passage_size(actor=actor, agent_id=agent_id)
|
966
994
|
|
967
995
|
@enforce_types
|
968
996
|
@trace_method
|
@@ -977,7 +1005,23 @@ class PassageManager:
|
|
977
1005
|
agent_id: The agent ID of the messages
|
978
1006
|
"""
|
979
1007
|
async with db_registry.async_session() as session:
|
980
|
-
|
1008
|
+
if agent_id:
|
1009
|
+
# Count passages through the archives relationship
|
1010
|
+
from sqlalchemy import func, select
|
1011
|
+
|
1012
|
+
result = await session.execute(
|
1013
|
+
select(func.count(ArchivalPassage.id))
|
1014
|
+
.join(ArchivesAgents, ArchivalPassage.archive_id == ArchivesAgents.archive_id)
|
1015
|
+
.where(
|
1016
|
+
ArchivesAgents.agent_id == agent_id,
|
1017
|
+
ArchivalPassage.organization_id == actor.organization_id,
|
1018
|
+
ArchivalPassage.is_deleted == False,
|
1019
|
+
)
|
1020
|
+
)
|
1021
|
+
return result.scalar() or 0
|
1022
|
+
else:
|
1023
|
+
# Count all archival passages in the organization
|
1024
|
+
return await ArchivalPassage.size_async(db_session=session, actor=actor)
|
981
1025
|
|
982
1026
|
@enforce_types
|
983
1027
|
@trace_method
|
@@ -42,7 +42,7 @@ class SandboxToolExecutor(ToolExecutor):
|
|
42
42
|
|
43
43
|
# Store original memory state
|
44
44
|
if agent_state:
|
45
|
-
orig_memory_str = await agent_state.memory.
|
45
|
+
orig_memory_str = await agent_state.memory.compile_in_thread_async()
|
46
46
|
else:
|
47
47
|
orig_memory_str = None
|
48
48
|
|
@@ -73,7 +73,7 @@ class SandboxToolExecutor(ToolExecutor):
|
|
73
73
|
|
74
74
|
# Verify memory integrity
|
75
75
|
if agent_state:
|
76
|
-
new_memory_str = await agent_state.memory.
|
76
|
+
new_memory_str = await agent_state.memory.compile_in_thread_async()
|
77
77
|
assert orig_memory_str == new_memory_str, "Memory should not be modified in a sandbox tool"
|
78
78
|
|
79
79
|
# Update agent memory if needed
|
@@ -4,11 +4,11 @@ from typing import Any, Dict, Optional, Type
|
|
4
4
|
from letta.constants import FUNCTION_RETURN_VALUE_TRUNCATED
|
5
5
|
from letta.helpers.datetime_helpers import AsyncTimer
|
6
6
|
from letta.log import get_logger
|
7
|
-
from letta.orm.enums import ToolType
|
8
7
|
from letta.otel.context import get_ctx_attributes
|
9
8
|
from letta.otel.metric_registry import MetricRegistry
|
10
9
|
from letta.otel.tracing import trace_method
|
11
10
|
from letta.schemas.agent import AgentState
|
11
|
+
from letta.schemas.enums import ToolType
|
12
12
|
from letta.schemas.sandbox_config import SandboxConfig
|
13
13
|
from letta.schemas.tool import Tool
|
14
14
|
from letta.schemas.tool_execution_result import ToolExecutionResult
|
letta/services/tool_manager.py
CHANGED
@@ -22,12 +22,12 @@ from letta.constants import (
|
|
22
22
|
from letta.errors import LettaToolNameConflictError
|
23
23
|
from letta.functions.functions import derive_openai_json_schema, load_function_set
|
24
24
|
from letta.log import get_logger
|
25
|
-
from letta.orm.enums import ToolType
|
26
25
|
|
27
26
|
# TODO: Remove this once we translate all of these to the ORM
|
28
27
|
from letta.orm.errors import NoResultFound
|
29
28
|
from letta.orm.tool import Tool as ToolModel
|
30
29
|
from letta.otel.tracing import trace_method
|
30
|
+
from letta.schemas.enums import ToolType
|
31
31
|
from letta.schemas.tool import Tool as PydanticTool
|
32
32
|
from letta.schemas.tool import ToolCreate, ToolUpdate
|
33
33
|
from letta.schemas.user import User as PydanticUser
|
@@ -46,7 +46,7 @@ class ToolManager:
|
|
46
46
|
# TODO: Refactor this across the codebase to use CreateTool instead of passing in a Tool object
|
47
47
|
@enforce_types
|
48
48
|
@trace_method
|
49
|
-
def create_or_update_tool(self, pydantic_tool: PydanticTool, actor: PydanticUser) -> PydanticTool:
|
49
|
+
def create_or_update_tool(self, pydantic_tool: PydanticTool, actor: PydanticUser, bypass_name_check: bool = False) -> PydanticTool:
|
50
50
|
"""Create a new tool based on the ToolCreate schema."""
|
51
51
|
tool_id = self.get_tool_id_by_name(tool_name=pydantic_tool.name, actor=actor)
|
52
52
|
if tool_id:
|
@@ -60,7 +60,9 @@ class ToolManager:
|
|
60
60
|
updated_tool_type = None
|
61
61
|
if "tool_type" in update_data:
|
62
62
|
updated_tool_type = update_data.get("tool_type")
|
63
|
-
tool = self.update_tool_by_id(
|
63
|
+
tool = self.update_tool_by_id(
|
64
|
+
tool_id, ToolUpdate(**update_data), actor, updated_tool_type=updated_tool_type, bypass_name_check=bypass_name_check
|
65
|
+
)
|
64
66
|
else:
|
65
67
|
printd(
|
66
68
|
f"`create_or_update_tool` was called with user_id={actor.id}, organization_id={actor.organization_id}, name={pydantic_tool.name}, but found existing tool with nothing to update."
|
@@ -73,7 +75,9 @@ class ToolManager:
|
|
73
75
|
|
74
76
|
@enforce_types
|
75
77
|
@trace_method
|
76
|
-
async def create_or_update_tool_async(
|
78
|
+
async def create_or_update_tool_async(
|
79
|
+
self, pydantic_tool: PydanticTool, actor: PydanticUser, bypass_name_check: bool = False
|
80
|
+
) -> PydanticTool:
|
77
81
|
"""Create a new tool based on the ToolCreate schema."""
|
78
82
|
tool_id = await self.get_tool_id_by_name_async(tool_name=pydantic_tool.name, actor=actor)
|
79
83
|
if tool_id:
|
@@ -88,7 +92,9 @@ class ToolManager:
|
|
88
92
|
updated_tool_type = None
|
89
93
|
if "tool_type" in update_data:
|
90
94
|
updated_tool_type = update_data.get("tool_type")
|
91
|
-
tool = await self.update_tool_by_id_async(
|
95
|
+
tool = await self.update_tool_by_id_async(
|
96
|
+
tool_id, ToolUpdate(**update_data), actor, updated_tool_type=updated_tool_type, bypass_name_check=bypass_name_check
|
97
|
+
)
|
92
98
|
else:
|
93
99
|
printd(
|
94
100
|
f"`create_or_update_tool` was called with user_id={actor.id}, organization_id={actor.organization_id}, name={pydantic_tool.name}, but found existing tool with nothing to update."
|
@@ -358,9 +364,13 @@ class ToolManager:
|
|
358
364
|
results.append(pydantic_tool)
|
359
365
|
except (ValueError, ModuleNotFoundError, AttributeError) as e:
|
360
366
|
tools_to_delete.append(tool)
|
361
|
-
logger.warning(
|
362
|
-
|
363
|
-
|
367
|
+
logger.warning(
|
368
|
+
"Deleting malformed tool with id=%s and name=%s. Error was:\n%s\nDeleted tool:%s",
|
369
|
+
tool.id,
|
370
|
+
tool.name,
|
371
|
+
e,
|
372
|
+
tool.pretty_print_columns(),
|
373
|
+
)
|
364
374
|
|
365
375
|
for tool in tools_to_delete:
|
366
376
|
await self.delete_tool_by_id_async(tool.id, actor=actor)
|
@@ -387,7 +397,12 @@ class ToolManager:
|
|
387
397
|
@enforce_types
|
388
398
|
@trace_method
|
389
399
|
def update_tool_by_id(
|
390
|
-
self,
|
400
|
+
self,
|
401
|
+
tool_id: str,
|
402
|
+
tool_update: ToolUpdate,
|
403
|
+
actor: PydanticUser,
|
404
|
+
updated_tool_type: Optional[ToolType] = None,
|
405
|
+
bypass_name_check: bool = False,
|
391
406
|
) -> PydanticTool:
|
392
407
|
"""Update a tool by its ID with the given ToolUpdate object."""
|
393
408
|
# First, check if source code update would cause a name conflict
|
@@ -395,17 +410,29 @@ class ToolManager:
|
|
395
410
|
new_name = None
|
396
411
|
new_schema = None
|
397
412
|
|
398
|
-
|
399
|
-
|
400
|
-
|
401
|
-
|
413
|
+
# TODO: Consider this behavior...is this what we want?
|
414
|
+
# TODO: I feel like it's bad if json_schema strays from source code so
|
415
|
+
# if source code is provided, always derive the name from it
|
416
|
+
if "source_code" in update_data.keys() and not bypass_name_check:
|
417
|
+
# derive the schema from source code to get the function name
|
418
|
+
derived_schema = derive_openai_json_schema(source_code=update_data["source_code"])
|
419
|
+
new_name = derived_schema["name"]
|
420
|
+
|
421
|
+
# if json_schema wasn't provided, use the derived schema
|
422
|
+
if "json_schema" not in update_data.keys():
|
423
|
+
new_schema = derived_schema
|
424
|
+
else:
|
425
|
+
# if json_schema was provided, update only its name to match the source code
|
426
|
+
new_schema = update_data["json_schema"].copy()
|
427
|
+
new_schema["name"] = new_name
|
428
|
+
# update the json_schema in update_data so it gets applied in the loop
|
429
|
+
update_data["json_schema"] = new_schema
|
402
430
|
|
403
|
-
#
|
431
|
+
# get current tool to check if name is changing
|
404
432
|
current_tool = self.get_tool_by_id(tool_id=tool_id, actor=actor)
|
405
|
-
|
406
|
-
# Check if the name is changing and if so, verify it doesn't conflict
|
433
|
+
# check if the name is changing and if so, verify it doesn't conflict
|
407
434
|
if new_name != current_tool.name:
|
408
|
-
#
|
435
|
+
# check if a tool with the new name already exists
|
409
436
|
existing_tool = self.get_tool_by_name(tool_name=new_name, actor=actor)
|
410
437
|
if existing_tool:
|
411
438
|
raise LettaToolNameConflictError(tool_name=new_name)
|
@@ -433,7 +460,12 @@ class ToolManager:
|
|
433
460
|
@enforce_types
|
434
461
|
@trace_method
|
435
462
|
async def update_tool_by_id_async(
|
436
|
-
self,
|
463
|
+
self,
|
464
|
+
tool_id: str,
|
465
|
+
tool_update: ToolUpdate,
|
466
|
+
actor: PydanticUser,
|
467
|
+
updated_tool_type: Optional[ToolType] = None,
|
468
|
+
bypass_name_check: bool = False,
|
437
469
|
) -> PydanticTool:
|
438
470
|
"""Update a tool by its ID with the given ToolUpdate object."""
|
439
471
|
# First, check if source code update would cause a name conflict
|
@@ -441,17 +473,29 @@ class ToolManager:
|
|
441
473
|
new_name = None
|
442
474
|
new_schema = None
|
443
475
|
|
444
|
-
|
445
|
-
|
446
|
-
|
447
|
-
|
476
|
+
# TODO: Consider this behavior...is this what we want?
|
477
|
+
# TODO: I feel like it's bad if json_schema strays from source code so
|
478
|
+
# if source code is provided, always derive the name from it
|
479
|
+
if "source_code" in update_data.keys() and not bypass_name_check:
|
480
|
+
# derive the schema from source code to get the function name
|
481
|
+
derived_schema = derive_openai_json_schema(source_code=update_data["source_code"])
|
482
|
+
new_name = derived_schema["name"]
|
483
|
+
|
484
|
+
# if json_schema wasn't provided, use the derived schema
|
485
|
+
if "json_schema" not in update_data.keys():
|
486
|
+
new_schema = derived_schema
|
487
|
+
else:
|
488
|
+
# if json_schema was provided, update only its name to match the source code
|
489
|
+
new_schema = update_data["json_schema"].copy()
|
490
|
+
new_schema["name"] = new_name
|
491
|
+
# update the json_schema in update_data so it gets applied in the loop
|
492
|
+
update_data["json_schema"] = new_schema
|
448
493
|
|
449
|
-
#
|
494
|
+
# get current tool to check if name is changing
|
450
495
|
current_tool = await self.get_tool_by_id_async(tool_id=tool_id, actor=actor)
|
451
|
-
|
452
|
-
# Check if the name is changing and if so, verify it doesn't conflict
|
496
|
+
# check if the name is changing and if so, verify it doesn't conflict
|
453
497
|
if new_name != current_tool.name:
|
454
|
-
#
|
498
|
+
# check if a tool with the new name already exists
|
455
499
|
name_exists = await self.tool_name_exists_async(tool_name=new_name, actor=actor)
|
456
500
|
if name_exists:
|
457
501
|
raise LettaToolNameConflictError(tool_name=new_name)
|
@@ -80,7 +80,7 @@ class AsyncToolSandboxBase(ABC):
|
|
80
80
|
Generate code to run inside of execution sandbox. Serialize the agent state and arguments, call the tool,
|
81
81
|
then base64-encode/pickle the result. Runs a jinja2 template constructing the python file.
|
82
82
|
"""
|
83
|
-
from letta.templates.template_helper import
|
83
|
+
from letta.templates.template_helper import render_template_in_thread
|
84
84
|
|
85
85
|
# Select the appropriate template based on whether the function is async
|
86
86
|
TEMPLATE_NAME = "sandbox_code_file_async.py.j2" if self.is_async_function else "sandbox_code_file.py.j2"
|
@@ -107,7 +107,7 @@ class AsyncToolSandboxBase(ABC):
|
|
107
107
|
|
108
108
|
agent_state_pickle = pickle.dumps(agent_state) if self.inject_agent_state else None
|
109
109
|
|
110
|
-
return await
|
110
|
+
return await render_template_in_thread(
|
111
111
|
TEMPLATE_NAME,
|
112
112
|
future_import=future_import,
|
113
113
|
inject_agent_state=self.inject_agent_state,
|