letta-nightly 0.6.4.dev20241216104246__py3-none-any.whl → 0.6.5.dev20241218055539__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of letta-nightly might be problematic. Click here for more details.
- letta/__init__.py +1 -1
- letta/agent.py +95 -101
- letta/client/client.py +1 -0
- letta/constants.py +6 -1
- letta/embeddings.py +3 -9
- letta/functions/function_sets/base.py +11 -57
- letta/functions/schema_generator.py +2 -6
- letta/llm_api/anthropic.py +38 -13
- letta/llm_api/llm_api_tools.py +12 -1
- letta/local_llm/function_parser.py +2 -2
- letta/orm/__init__.py +1 -1
- letta/orm/agent.py +19 -1
- letta/orm/errors.py +8 -0
- letta/orm/file.py +3 -2
- letta/orm/mixins.py +3 -14
- letta/orm/organization.py +19 -3
- letta/orm/passage.py +59 -23
- letta/orm/source.py +4 -0
- letta/orm/sqlalchemy_base.py +25 -18
- letta/prompts/system/memgpt_modified_chat.txt +1 -1
- letta/prompts/system/memgpt_modified_o1.txt +1 -1
- letta/providers.py +2 -0
- letta/schemas/agent.py +35 -0
- letta/schemas/embedding_config.py +20 -2
- letta/schemas/passage.py +1 -1
- letta/schemas/sandbox_config.py +2 -1
- letta/server/rest_api/app.py +43 -5
- letta/server/rest_api/routers/v1/tools.py +1 -1
- letta/server/rest_api/utils.py +24 -5
- letta/server/server.py +105 -164
- letta/server/ws_api/server.py +1 -1
- letta/services/agent_manager.py +344 -9
- letta/services/passage_manager.py +76 -100
- letta/services/tool_execution_sandbox.py +54 -45
- letta/settings.py +10 -5
- letta/utils.py +8 -0
- {letta_nightly-0.6.4.dev20241216104246.dist-info → letta_nightly-0.6.5.dev20241218055539.dist-info}/METADATA +6 -6
- {letta_nightly-0.6.4.dev20241216104246.dist-info → letta_nightly-0.6.5.dev20241218055539.dist-info}/RECORD +41 -41
- {letta_nightly-0.6.4.dev20241216104246.dist-info → letta_nightly-0.6.5.dev20241218055539.dist-info}/LICENSE +0 -0
- {letta_nightly-0.6.4.dev20241216104246.dist-info → letta_nightly-0.6.5.dev20241218055539.dist-info}/WHEEL +0 -0
- {letta_nightly-0.6.4.dev20241216104246.dist-info → letta_nightly-0.6.5.dev20241218055539.dist-info}/entry_points.txt +0 -0
letta/services/agent_manager.py
CHANGED
|
@@ -1,16 +1,26 @@
|
|
|
1
1
|
from typing import Dict, List, Optional
|
|
2
|
+
from datetime import datetime
|
|
3
|
+
import numpy as np
|
|
2
4
|
|
|
3
|
-
from
|
|
5
|
+
from sqlalchemy import select, union_all, literal, func, Select
|
|
6
|
+
|
|
7
|
+
from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS, MAX_EMBEDDING_DIM
|
|
8
|
+
from letta.embeddings import embedding_model
|
|
9
|
+
from letta.log import get_logger
|
|
4
10
|
from letta.orm import Agent as AgentModel
|
|
5
11
|
from letta.orm import Block as BlockModel
|
|
6
12
|
from letta.orm import Source as SourceModel
|
|
7
13
|
from letta.orm import Tool as ToolModel
|
|
14
|
+
from letta.orm import AgentPassage, SourcePassage
|
|
15
|
+
from letta.orm import SourcesAgents
|
|
8
16
|
from letta.orm.errors import NoResultFound
|
|
17
|
+
from letta.orm.sqlite_functions import adapt_array
|
|
9
18
|
from letta.schemas.agent import AgentState as PydanticAgentState
|
|
10
19
|
from letta.schemas.agent import AgentType, CreateAgent, UpdateAgent
|
|
11
20
|
from letta.schemas.block import Block as PydanticBlock
|
|
12
21
|
from letta.schemas.embedding_config import EmbeddingConfig
|
|
13
22
|
from letta.schemas.llm_config import LLMConfig
|
|
23
|
+
from letta.schemas.passage import Passage as PydanticPassage
|
|
14
24
|
from letta.schemas.source import Source as PydanticSource
|
|
15
25
|
from letta.schemas.tool_rule import ToolRule as PydanticToolRule
|
|
16
26
|
from letta.schemas.user import User as PydanticUser
|
|
@@ -20,11 +30,13 @@ from letta.services.helpers.agent_manager_helper import (
|
|
|
20
30
|
_process_tags,
|
|
21
31
|
derive_system_message,
|
|
22
32
|
)
|
|
23
|
-
from letta.services.passage_manager import PassageManager
|
|
24
33
|
from letta.services.source_manager import SourceManager
|
|
25
34
|
from letta.services.tool_manager import ToolManager
|
|
35
|
+
from letta.settings import settings
|
|
26
36
|
from letta.utils import enforce_types
|
|
27
37
|
|
|
38
|
+
logger = get_logger(__name__)
|
|
39
|
+
|
|
28
40
|
|
|
29
41
|
# Agent Manager Class
|
|
30
42
|
class AgentManager:
|
|
@@ -49,6 +61,9 @@ class AgentManager:
|
|
|
49
61
|
) -> PydanticAgentState:
|
|
50
62
|
system = derive_system_message(agent_type=agent_create.agent_type, system=agent_create.system)
|
|
51
63
|
|
|
64
|
+
if not agent_create.llm_config or not agent_create.embedding_config:
|
|
65
|
+
raise ValueError("llm_config and embedding_config are required")
|
|
66
|
+
|
|
52
67
|
# create blocks (note: cannot be linked into the agent_id is created)
|
|
53
68
|
block_ids = list(agent_create.block_ids or []) # Create a local copy to avoid modifying the original
|
|
54
69
|
for create_block in agent_create.memory_blocks:
|
|
@@ -226,13 +241,6 @@ class AgentManager:
|
|
|
226
241
|
with self.session_maker() as session:
|
|
227
242
|
# Retrieve the agent
|
|
228
243
|
agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
|
|
229
|
-
|
|
230
|
-
# TODO: @mindy delete this piece when we have a proper passages/sources implementation
|
|
231
|
-
# TODO: This is done very hacky on purpose
|
|
232
|
-
# TODO: 1000 limit is also wack
|
|
233
|
-
passage_manager = PassageManager()
|
|
234
|
-
passage_manager.delete_passages(actor=actor, agent_id=agent_id, limit=1000)
|
|
235
|
-
|
|
236
244
|
agent_state = agent.to_pydantic()
|
|
237
245
|
agent.hard_delete(session)
|
|
238
246
|
return agent_state
|
|
@@ -403,3 +411,330 @@ class AgentManager:
|
|
|
403
411
|
|
|
404
412
|
agent.update(session, actor=actor)
|
|
405
413
|
return agent.to_pydantic()
|
|
414
|
+
|
|
415
|
+
# ======================================================================================================================
|
|
416
|
+
# Passage Management
|
|
417
|
+
# ======================================================================================================================
|
|
418
|
+
def _build_passage_query(
|
|
419
|
+
self,
|
|
420
|
+
actor: PydanticUser,
|
|
421
|
+
agent_id: Optional[str] = None,
|
|
422
|
+
file_id: Optional[str] = None,
|
|
423
|
+
query_text: Optional[str] = None,
|
|
424
|
+
start_date: Optional[datetime] = None,
|
|
425
|
+
end_date: Optional[datetime] = None,
|
|
426
|
+
cursor: Optional[str] = None,
|
|
427
|
+
source_id: Optional[str] = None,
|
|
428
|
+
embed_query: bool = False,
|
|
429
|
+
ascending: bool = True,
|
|
430
|
+
embedding_config: Optional[EmbeddingConfig] = None,
|
|
431
|
+
agent_only: bool = False,
|
|
432
|
+
) -> Select:
|
|
433
|
+
"""Helper function to build the base passage query with all filters applied.
|
|
434
|
+
|
|
435
|
+
Returns the query before any limit or count operations are applied.
|
|
436
|
+
"""
|
|
437
|
+
embedded_text = None
|
|
438
|
+
if embed_query:
|
|
439
|
+
assert embedding_config is not None, "embedding_config must be specified for vector search"
|
|
440
|
+
assert query_text is not None, "query_text must be specified for vector search"
|
|
441
|
+
embedded_text = embedding_model(embedding_config).get_text_embedding(query_text)
|
|
442
|
+
embedded_text = np.array(embedded_text)
|
|
443
|
+
embedded_text = np.pad(embedded_text, (0, MAX_EMBEDDING_DIM - embedded_text.shape[0]), mode="constant").tolist()
|
|
444
|
+
|
|
445
|
+
with self.session_maker() as session:
|
|
446
|
+
# Start with base query for source passages
|
|
447
|
+
source_passages = None
|
|
448
|
+
if not agent_only: # Include source passages
|
|
449
|
+
if agent_id is not None:
|
|
450
|
+
source_passages = (
|
|
451
|
+
select(
|
|
452
|
+
SourcePassage,
|
|
453
|
+
literal(None).label('agent_id')
|
|
454
|
+
)
|
|
455
|
+
.join(SourcesAgents, SourcesAgents.source_id == SourcePassage.source_id)
|
|
456
|
+
.where(SourcesAgents.agent_id == agent_id)
|
|
457
|
+
.where(SourcePassage.organization_id == actor.organization_id)
|
|
458
|
+
)
|
|
459
|
+
else:
|
|
460
|
+
source_passages = (
|
|
461
|
+
select(
|
|
462
|
+
SourcePassage,
|
|
463
|
+
literal(None).label('agent_id')
|
|
464
|
+
)
|
|
465
|
+
.where(SourcePassage.organization_id == actor.organization_id)
|
|
466
|
+
)
|
|
467
|
+
|
|
468
|
+
if source_id:
|
|
469
|
+
source_passages = source_passages.where(SourcePassage.source_id == source_id)
|
|
470
|
+
if file_id:
|
|
471
|
+
source_passages = source_passages.where(SourcePassage.file_id == file_id)
|
|
472
|
+
|
|
473
|
+
# Add agent passages query
|
|
474
|
+
agent_passages = None
|
|
475
|
+
if agent_id is not None:
|
|
476
|
+
agent_passages = (
|
|
477
|
+
select(
|
|
478
|
+
AgentPassage.id,
|
|
479
|
+
AgentPassage.text,
|
|
480
|
+
AgentPassage.embedding_config,
|
|
481
|
+
AgentPassage.metadata_,
|
|
482
|
+
AgentPassage.embedding,
|
|
483
|
+
AgentPassage.created_at,
|
|
484
|
+
AgentPassage.updated_at,
|
|
485
|
+
AgentPassage.is_deleted,
|
|
486
|
+
AgentPassage._created_by_id,
|
|
487
|
+
AgentPassage._last_updated_by_id,
|
|
488
|
+
AgentPassage.organization_id,
|
|
489
|
+
literal(None).label('file_id'),
|
|
490
|
+
literal(None).label('source_id'),
|
|
491
|
+
AgentPassage.agent_id
|
|
492
|
+
)
|
|
493
|
+
.where(AgentPassage.agent_id == agent_id)
|
|
494
|
+
.where(AgentPassage.organization_id == actor.organization_id)
|
|
495
|
+
)
|
|
496
|
+
|
|
497
|
+
# Combine queries
|
|
498
|
+
if source_passages is not None and agent_passages is not None:
|
|
499
|
+
combined_query = union_all(source_passages, agent_passages).cte('combined_passages')
|
|
500
|
+
elif agent_passages is not None:
|
|
501
|
+
combined_query = agent_passages.cte('combined_passages')
|
|
502
|
+
elif source_passages is not None:
|
|
503
|
+
combined_query = source_passages.cte('combined_passages')
|
|
504
|
+
else:
|
|
505
|
+
raise ValueError("No passages found")
|
|
506
|
+
|
|
507
|
+
# Build main query from combined CTE
|
|
508
|
+
main_query = select(combined_query)
|
|
509
|
+
|
|
510
|
+
# Apply filters
|
|
511
|
+
if start_date:
|
|
512
|
+
main_query = main_query.where(combined_query.c.created_at >= start_date)
|
|
513
|
+
if end_date:
|
|
514
|
+
main_query = main_query.where(combined_query.c.created_at <= end_date)
|
|
515
|
+
if source_id:
|
|
516
|
+
main_query = main_query.where(combined_query.c.source_id == source_id)
|
|
517
|
+
if file_id:
|
|
518
|
+
main_query = main_query.where(combined_query.c.file_id == file_id)
|
|
519
|
+
|
|
520
|
+
# Vector search
|
|
521
|
+
if embedded_text:
|
|
522
|
+
if settings.letta_pg_uri_no_default:
|
|
523
|
+
# PostgreSQL with pgvector
|
|
524
|
+
main_query = main_query.order_by(
|
|
525
|
+
combined_query.c.embedding.cosine_distance(embedded_text).asc()
|
|
526
|
+
)
|
|
527
|
+
else:
|
|
528
|
+
# SQLite with custom vector type
|
|
529
|
+
query_embedding_binary = adapt_array(embedded_text)
|
|
530
|
+
if ascending:
|
|
531
|
+
main_query = main_query.order_by(
|
|
532
|
+
func.cosine_distance(combined_query.c.embedding, query_embedding_binary).asc(),
|
|
533
|
+
combined_query.c.created_at.asc(),
|
|
534
|
+
combined_query.c.id.asc()
|
|
535
|
+
)
|
|
536
|
+
else:
|
|
537
|
+
main_query = main_query.order_by(
|
|
538
|
+
func.cosine_distance(combined_query.c.embedding, query_embedding_binary).asc(),
|
|
539
|
+
combined_query.c.created_at.desc(),
|
|
540
|
+
combined_query.c.id.asc()
|
|
541
|
+
)
|
|
542
|
+
else:
|
|
543
|
+
if query_text:
|
|
544
|
+
main_query = main_query.where(func.lower(combined_query.c.text).contains(func.lower(query_text)))
|
|
545
|
+
|
|
546
|
+
# Handle cursor-based pagination
|
|
547
|
+
if cursor:
|
|
548
|
+
cursor_query = select(combined_query.c.created_at).where(
|
|
549
|
+
combined_query.c.id == cursor
|
|
550
|
+
).scalar_subquery()
|
|
551
|
+
|
|
552
|
+
if ascending:
|
|
553
|
+
main_query = main_query.where(
|
|
554
|
+
combined_query.c.created_at > cursor_query
|
|
555
|
+
)
|
|
556
|
+
else:
|
|
557
|
+
main_query = main_query.where(
|
|
558
|
+
combined_query.c.created_at < cursor_query
|
|
559
|
+
)
|
|
560
|
+
|
|
561
|
+
# Add ordering if not already ordered by similarity
|
|
562
|
+
if not embed_query:
|
|
563
|
+
if ascending:
|
|
564
|
+
main_query = main_query.order_by(
|
|
565
|
+
combined_query.c.created_at.asc(),
|
|
566
|
+
combined_query.c.id.asc(),
|
|
567
|
+
)
|
|
568
|
+
else:
|
|
569
|
+
main_query = main_query.order_by(
|
|
570
|
+
combined_query.c.created_at.desc(),
|
|
571
|
+
combined_query.c.id.asc(),
|
|
572
|
+
)
|
|
573
|
+
|
|
574
|
+
return main_query
|
|
575
|
+
|
|
576
|
+
@enforce_types
|
|
577
|
+
def list_passages(
|
|
578
|
+
self,
|
|
579
|
+
actor: PydanticUser,
|
|
580
|
+
agent_id: Optional[str] = None,
|
|
581
|
+
file_id: Optional[str] = None,
|
|
582
|
+
limit: Optional[int] = 50,
|
|
583
|
+
query_text: Optional[str] = None,
|
|
584
|
+
start_date: Optional[datetime] = None,
|
|
585
|
+
end_date: Optional[datetime] = None,
|
|
586
|
+
cursor: Optional[str] = None,
|
|
587
|
+
source_id: Optional[str] = None,
|
|
588
|
+
embed_query: bool = False,
|
|
589
|
+
ascending: bool = True,
|
|
590
|
+
embedding_config: Optional[EmbeddingConfig] = None,
|
|
591
|
+
agent_only: bool = False
|
|
592
|
+
) -> List[PydanticPassage]:
|
|
593
|
+
"""Lists all passages attached to an agent."""
|
|
594
|
+
with self.session_maker() as session:
|
|
595
|
+
main_query = self._build_passage_query(
|
|
596
|
+
actor=actor,
|
|
597
|
+
agent_id=agent_id,
|
|
598
|
+
file_id=file_id,
|
|
599
|
+
query_text=query_text,
|
|
600
|
+
start_date=start_date,
|
|
601
|
+
end_date=end_date,
|
|
602
|
+
cursor=cursor,
|
|
603
|
+
source_id=source_id,
|
|
604
|
+
embed_query=embed_query,
|
|
605
|
+
ascending=ascending,
|
|
606
|
+
embedding_config=embedding_config,
|
|
607
|
+
agent_only=agent_only,
|
|
608
|
+
)
|
|
609
|
+
|
|
610
|
+
# Add limit
|
|
611
|
+
if limit:
|
|
612
|
+
main_query = main_query.limit(limit)
|
|
613
|
+
|
|
614
|
+
# Execute query
|
|
615
|
+
results = list(session.execute(main_query))
|
|
616
|
+
|
|
617
|
+
passages = []
|
|
618
|
+
for row in results:
|
|
619
|
+
data = dict(row._mapping)
|
|
620
|
+
if data['agent_id'] is not None:
|
|
621
|
+
# This is an AgentPassage - remove source fields
|
|
622
|
+
data.pop('source_id', None)
|
|
623
|
+
data.pop('file_id', None)
|
|
624
|
+
passage = AgentPassage(**data)
|
|
625
|
+
else:
|
|
626
|
+
# This is a SourcePassage - remove agent field
|
|
627
|
+
data.pop('agent_id', None)
|
|
628
|
+
passage = SourcePassage(**data)
|
|
629
|
+
passages.append(passage)
|
|
630
|
+
|
|
631
|
+
return [p.to_pydantic() for p in passages]
|
|
632
|
+
|
|
633
|
+
|
|
634
|
+
@enforce_types
|
|
635
|
+
def passage_size(
|
|
636
|
+
self,
|
|
637
|
+
actor: PydanticUser,
|
|
638
|
+
agent_id: Optional[str] = None,
|
|
639
|
+
file_id: Optional[str] = None,
|
|
640
|
+
query_text: Optional[str] = None,
|
|
641
|
+
start_date: Optional[datetime] = None,
|
|
642
|
+
end_date: Optional[datetime] = None,
|
|
643
|
+
cursor: Optional[str] = None,
|
|
644
|
+
source_id: Optional[str] = None,
|
|
645
|
+
embed_query: bool = False,
|
|
646
|
+
ascending: bool = True,
|
|
647
|
+
embedding_config: Optional[EmbeddingConfig] = None,
|
|
648
|
+
agent_only: bool = False
|
|
649
|
+
) -> int:
|
|
650
|
+
"""Returns the count of passages matching the given criteria."""
|
|
651
|
+
with self.session_maker() as session:
|
|
652
|
+
main_query = self._build_passage_query(
|
|
653
|
+
actor=actor,
|
|
654
|
+
agent_id=agent_id,
|
|
655
|
+
file_id=file_id,
|
|
656
|
+
query_text=query_text,
|
|
657
|
+
start_date=start_date,
|
|
658
|
+
end_date=end_date,
|
|
659
|
+
cursor=cursor,
|
|
660
|
+
source_id=source_id,
|
|
661
|
+
embed_query=embed_query,
|
|
662
|
+
ascending=ascending,
|
|
663
|
+
embedding_config=embedding_config,
|
|
664
|
+
agent_only=agent_only,
|
|
665
|
+
)
|
|
666
|
+
|
|
667
|
+
# Convert to count query
|
|
668
|
+
count_query = select(func.count()).select_from(main_query.subquery())
|
|
669
|
+
return session.scalar(count_query) or 0
|
|
670
|
+
|
|
671
|
+
# ======================================================================================================================
|
|
672
|
+
# Tool Management
|
|
673
|
+
# ======================================================================================================================
|
|
674
|
+
@enforce_types
|
|
675
|
+
def attach_tool(self, agent_id: str, tool_id: str, actor: PydanticUser) -> PydanticAgentState:
|
|
676
|
+
"""
|
|
677
|
+
Attaches a tool to an agent.
|
|
678
|
+
|
|
679
|
+
Args:
|
|
680
|
+
agent_id: ID of the agent to attach the tool to.
|
|
681
|
+
tool_id: ID of the tool to attach.
|
|
682
|
+
actor: User performing the action.
|
|
683
|
+
|
|
684
|
+
Raises:
|
|
685
|
+
NoResultFound: If the agent or tool is not found.
|
|
686
|
+
|
|
687
|
+
Returns:
|
|
688
|
+
PydanticAgentState: The updated agent state.
|
|
689
|
+
"""
|
|
690
|
+
with self.session_maker() as session:
|
|
691
|
+
# Verify the agent exists and user has permission to access it
|
|
692
|
+
agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
|
|
693
|
+
|
|
694
|
+
# Use the _process_relationship helper to attach the tool
|
|
695
|
+
_process_relationship(
|
|
696
|
+
session=session,
|
|
697
|
+
agent=agent,
|
|
698
|
+
relationship_name="tools",
|
|
699
|
+
model_class=ToolModel,
|
|
700
|
+
item_ids=[tool_id],
|
|
701
|
+
allow_partial=False, # Ensure the tool exists
|
|
702
|
+
replace=False, # Extend the existing tools
|
|
703
|
+
)
|
|
704
|
+
|
|
705
|
+
# Commit and refresh the agent
|
|
706
|
+
agent.update(session, actor=actor)
|
|
707
|
+
return agent.to_pydantic()
|
|
708
|
+
|
|
709
|
+
@enforce_types
|
|
710
|
+
def detach_tool(self, agent_id: str, tool_id: str, actor: PydanticUser) -> PydanticAgentState:
|
|
711
|
+
"""
|
|
712
|
+
Detaches a tool from an agent.
|
|
713
|
+
|
|
714
|
+
Args:
|
|
715
|
+
agent_id: ID of the agent to detach the tool from.
|
|
716
|
+
tool_id: ID of the tool to detach.
|
|
717
|
+
actor: User performing the action.
|
|
718
|
+
|
|
719
|
+
Raises:
|
|
720
|
+
NoResultFound: If the agent or tool is not found.
|
|
721
|
+
|
|
722
|
+
Returns:
|
|
723
|
+
PydanticAgentState: The updated agent state.
|
|
724
|
+
"""
|
|
725
|
+
with self.session_maker() as session:
|
|
726
|
+
# Verify the agent exists and user has permission to access it
|
|
727
|
+
agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
|
|
728
|
+
|
|
729
|
+
# Filter out the tool to be detached
|
|
730
|
+
remaining_tools = [tool for tool in agent.tools if tool.id != tool_id]
|
|
731
|
+
|
|
732
|
+
if len(remaining_tools) == len(agent.tools): # Tool ID was not in the relationship
|
|
733
|
+
logger.warning(f"Attempted to remove unattached tool id={tool_id} from agent id={agent_id} by actor={actor}")
|
|
734
|
+
|
|
735
|
+
# Update the tools relationship
|
|
736
|
+
agent.tools = remaining_tools
|
|
737
|
+
|
|
738
|
+
# Commit and refresh the agent
|
|
739
|
+
agent.update(session, actor=actor)
|
|
740
|
+
return agent.to_pydantic()
|
|
@@ -1,12 +1,13 @@
|
|
|
1
|
-
from datetime import datetime
|
|
2
1
|
from typing import List, Optional
|
|
3
|
-
|
|
2
|
+
from datetime import datetime
|
|
4
3
|
import numpy as np
|
|
5
4
|
|
|
5
|
+
from sqlalchemy import select, union_all, literal
|
|
6
|
+
|
|
6
7
|
from letta.constants import MAX_EMBEDDING_DIM
|
|
7
8
|
from letta.embeddings import embedding_model, parse_and_chunk_text
|
|
8
9
|
from letta.orm.errors import NoResultFound
|
|
9
|
-
from letta.orm.passage import
|
|
10
|
+
from letta.orm.passage import AgentPassage, SourcePassage
|
|
10
11
|
from letta.schemas.agent import AgentState
|
|
11
12
|
from letta.schemas.embedding_config import EmbeddingConfig
|
|
12
13
|
from letta.schemas.passage import Passage as PydanticPassage
|
|
@@ -14,6 +15,7 @@ from letta.schemas.user import User as PydanticUser
|
|
|
14
15
|
from letta.utils import enforce_types
|
|
15
16
|
|
|
16
17
|
|
|
18
|
+
|
|
17
19
|
class PassageManager:
|
|
18
20
|
"""Manager class to handle business logic related to Passages."""
|
|
19
21
|
|
|
@@ -26,14 +28,51 @@ class PassageManager:
|
|
|
26
28
|
def get_passage_by_id(self, passage_id: str, actor: PydanticUser) -> Optional[PydanticPassage]:
|
|
27
29
|
"""Fetch a passage by ID."""
|
|
28
30
|
with self.session_maker() as session:
|
|
29
|
-
|
|
30
|
-
|
|
31
|
+
# Try source passages first
|
|
32
|
+
try:
|
|
33
|
+
passage = SourcePassage.read(db_session=session, identifier=passage_id, actor=actor)
|
|
34
|
+
return passage.to_pydantic()
|
|
35
|
+
except NoResultFound:
|
|
36
|
+
# Try archival passages
|
|
37
|
+
try:
|
|
38
|
+
passage = AgentPassage.read(db_session=session, identifier=passage_id, actor=actor)
|
|
39
|
+
return passage.to_pydantic()
|
|
40
|
+
except NoResultFound:
|
|
41
|
+
raise NoResultFound(f"Passage with id {passage_id} not found in database.")
|
|
31
42
|
|
|
32
43
|
@enforce_types
|
|
33
44
|
def create_passage(self, pydantic_passage: PydanticPassage, actor: PydanticUser) -> PydanticPassage:
|
|
34
|
-
"""Create a new passage."""
|
|
45
|
+
"""Create a new passage in the appropriate table based on whether it has agent_id or source_id."""
|
|
46
|
+
# Common fields for both passage types
|
|
47
|
+
data = pydantic_passage.model_dump()
|
|
48
|
+
common_fields = {
|
|
49
|
+
"id": data.get("id"),
|
|
50
|
+
"text": data["text"],
|
|
51
|
+
"embedding": data["embedding"],
|
|
52
|
+
"embedding_config": data["embedding_config"],
|
|
53
|
+
"organization_id": data["organization_id"],
|
|
54
|
+
"metadata_": data.get("metadata_", {}),
|
|
55
|
+
"is_deleted": data.get("is_deleted", False),
|
|
56
|
+
"created_at": data.get("created_at", datetime.utcnow()),
|
|
57
|
+
}
|
|
58
|
+
|
|
59
|
+
if "agent_id" in data and data["agent_id"]:
|
|
60
|
+
assert not data.get("source_id"), "Passage cannot have both agent_id and source_id"
|
|
61
|
+
agent_fields = {
|
|
62
|
+
"agent_id": data["agent_id"],
|
|
63
|
+
}
|
|
64
|
+
passage = AgentPassage(**common_fields, **agent_fields)
|
|
65
|
+
elif "source_id" in data and data["source_id"]:
|
|
66
|
+
assert not data.get("agent_id"), "Passage cannot have both agent_id and source_id"
|
|
67
|
+
source_fields = {
|
|
68
|
+
"source_id": data["source_id"],
|
|
69
|
+
"file_id": data.get("file_id"),
|
|
70
|
+
}
|
|
71
|
+
passage = SourcePassage(**common_fields, **source_fields)
|
|
72
|
+
else:
|
|
73
|
+
raise ValueError("Passage must have either agent_id or source_id")
|
|
74
|
+
|
|
35
75
|
with self.session_maker() as session:
|
|
36
|
-
passage = PassageModel(**pydantic_passage.model_dump())
|
|
37
76
|
passage.create(session, actor=actor)
|
|
38
77
|
return passage.to_pydantic()
|
|
39
78
|
|
|
@@ -93,14 +132,23 @@ class PassageManager:
|
|
|
93
132
|
raise ValueError("Passage ID must be provided.")
|
|
94
133
|
|
|
95
134
|
with self.session_maker() as session:
|
|
96
|
-
#
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
135
|
+
# Try source passages first
|
|
136
|
+
try:
|
|
137
|
+
curr_passage = SourcePassage.read(
|
|
138
|
+
db_session=session,
|
|
139
|
+
identifier=passage_id,
|
|
140
|
+
actor=actor,
|
|
141
|
+
)
|
|
142
|
+
except NoResultFound:
|
|
143
|
+
# Try agent passages
|
|
144
|
+
try:
|
|
145
|
+
curr_passage = AgentPassage.read(
|
|
146
|
+
db_session=session,
|
|
147
|
+
identifier=passage_id,
|
|
148
|
+
actor=actor,
|
|
149
|
+
)
|
|
150
|
+
except NoResultFound:
|
|
151
|
+
raise ValueError(f"Passage with id {passage_id} does not exist.")
|
|
104
152
|
|
|
105
153
|
# Update the database record with values from the provided record
|
|
106
154
|
update_data = passage.model_dump(exclude_unset=True, exclude_none=True)
|
|
@@ -113,104 +161,32 @@ class PassageManager:
|
|
|
113
161
|
|
|
114
162
|
@enforce_types
|
|
115
163
|
def delete_passage_by_id(self, passage_id: str, actor: PydanticUser) -> bool:
|
|
116
|
-
"""Delete a passage."""
|
|
164
|
+
"""Delete a passage from either source or archival passages."""
|
|
117
165
|
if not passage_id:
|
|
118
166
|
raise ValueError("Passage ID must be provided.")
|
|
119
167
|
|
|
120
168
|
with self.session_maker() as session:
|
|
169
|
+
# Try source passages first
|
|
121
170
|
try:
|
|
122
|
-
passage =
|
|
171
|
+
passage = SourcePassage.read(db_session=session, identifier=passage_id, actor=actor)
|
|
123
172
|
passage.hard_delete(session, actor=actor)
|
|
173
|
+
return True
|
|
124
174
|
except NoResultFound:
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
file_id: Optional[str] = None,
|
|
133
|
-
cursor: Optional[str] = None,
|
|
134
|
-
limit: Optional[int] = 50,
|
|
135
|
-
query_text: Optional[str] = None,
|
|
136
|
-
start_date: Optional[datetime] = None,
|
|
137
|
-
end_date: Optional[datetime] = None,
|
|
138
|
-
ascending: bool = True,
|
|
139
|
-
source_id: Optional[str] = None,
|
|
140
|
-
embed_query: bool = False,
|
|
141
|
-
embedding_config: Optional[EmbeddingConfig] = None,
|
|
142
|
-
) -> List[PydanticPassage]:
|
|
143
|
-
"""List passages with pagination."""
|
|
144
|
-
with self.session_maker() as session:
|
|
145
|
-
filters = {"organization_id": actor.organization_id}
|
|
146
|
-
if agent_id:
|
|
147
|
-
filters["agent_id"] = agent_id
|
|
148
|
-
if file_id:
|
|
149
|
-
filters["file_id"] = file_id
|
|
150
|
-
if source_id:
|
|
151
|
-
filters["source_id"] = source_id
|
|
152
|
-
|
|
153
|
-
embedded_text = None
|
|
154
|
-
if embed_query:
|
|
155
|
-
assert embedding_config is not None
|
|
156
|
-
|
|
157
|
-
# Embed the text
|
|
158
|
-
embedded_text = embedding_model(embedding_config).get_text_embedding(query_text)
|
|
159
|
-
|
|
160
|
-
# Pad the embedding with zeros
|
|
161
|
-
embedded_text = np.array(embedded_text)
|
|
162
|
-
embedded_text = np.pad(embedded_text, (0, MAX_EMBEDDING_DIM - embedded_text.shape[0]), mode="constant").tolist()
|
|
163
|
-
|
|
164
|
-
results = PassageModel.list(
|
|
165
|
-
db_session=session,
|
|
166
|
-
cursor=cursor,
|
|
167
|
-
start_date=start_date,
|
|
168
|
-
end_date=end_date,
|
|
169
|
-
limit=limit,
|
|
170
|
-
ascending=ascending,
|
|
171
|
-
query_text=query_text if not embedded_text else None,
|
|
172
|
-
query_embedding=embedded_text,
|
|
173
|
-
**filters,
|
|
174
|
-
)
|
|
175
|
-
return [p.to_pydantic() for p in results]
|
|
176
|
-
|
|
177
|
-
@enforce_types
|
|
178
|
-
def size(self, actor: PydanticUser, agent_id: Optional[str] = None, **kwargs) -> int:
|
|
179
|
-
"""Get the total count of messages with optional filters.
|
|
180
|
-
|
|
181
|
-
Args:
|
|
182
|
-
actor : The user requesting the count
|
|
183
|
-
agent_id: The agent ID
|
|
184
|
-
"""
|
|
185
|
-
with self.session_maker() as session:
|
|
186
|
-
return PassageModel.size(db_session=session, actor=actor, agent_id=agent_id, **kwargs)
|
|
175
|
+
# Try archival passages
|
|
176
|
+
try:
|
|
177
|
+
passage = AgentPassage.read(db_session=session, identifier=passage_id, actor=actor)
|
|
178
|
+
passage.hard_delete(session, actor=actor)
|
|
179
|
+
return True
|
|
180
|
+
except NoResultFound:
|
|
181
|
+
raise NoResultFound(f"Passage with id {passage_id} not found.")
|
|
187
182
|
|
|
188
183
|
def delete_passages(
|
|
189
184
|
self,
|
|
190
185
|
actor: PydanticUser,
|
|
191
|
-
|
|
192
|
-
file_id: Optional[str] = None,
|
|
193
|
-
start_date: Optional[datetime] = None,
|
|
194
|
-
end_date: Optional[datetime] = None,
|
|
195
|
-
limit: Optional[int] = 50,
|
|
196
|
-
cursor: Optional[str] = None,
|
|
197
|
-
query_text: Optional[str] = None,
|
|
198
|
-
source_id: Optional[str] = None,
|
|
186
|
+
passages: List[PydanticPassage],
|
|
199
187
|
) -> bool:
|
|
200
|
-
|
|
201
|
-
passages = self.list_passages(
|
|
202
|
-
actor=actor,
|
|
203
|
-
agent_id=agent_id,
|
|
204
|
-
file_id=file_id,
|
|
205
|
-
cursor=cursor,
|
|
206
|
-
limit=limit,
|
|
207
|
-
start_date=start_date,
|
|
208
|
-
end_date=end_date,
|
|
209
|
-
query_text=query_text,
|
|
210
|
-
source_id=source_id,
|
|
211
|
-
)
|
|
212
|
-
|
|
213
188
|
# TODO: This is very inefficient
|
|
214
189
|
# TODO: We should have a base `delete_all_matching_filters`-esque function
|
|
215
190
|
for passage in passages:
|
|
216
191
|
self.delete_passage_by_id(passage_id=passage.id, actor=actor)
|
|
192
|
+
return True
|