letta-nightly 0.7.30.dev20250603104343__py3-none-any.whl → 0.8.0.dev20250604201135__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- letta/__init__.py +7 -1
- letta/agent.py +14 -7
- letta/agents/base_agent.py +1 -0
- letta/agents/ephemeral_summary_agent.py +104 -0
- letta/agents/helpers.py +35 -3
- letta/agents/letta_agent.py +492 -176
- letta/agents/letta_agent_batch.py +22 -16
- letta/agents/prompts/summary_system_prompt.txt +62 -0
- letta/agents/voice_agent.py +22 -7
- letta/agents/voice_sleeptime_agent.py +13 -8
- letta/constants.py +33 -1
- letta/data_sources/connectors.py +52 -36
- letta/errors.py +4 -0
- letta/functions/ast_parsers.py +13 -30
- letta/functions/function_sets/base.py +3 -1
- letta/functions/functions.py +2 -0
- letta/functions/mcp_client/base_client.py +151 -97
- letta/functions/mcp_client/sse_client.py +49 -31
- letta/functions/mcp_client/stdio_client.py +107 -106
- letta/functions/schema_generator.py +22 -22
- letta/groups/helpers.py +3 -4
- letta/groups/sleeptime_multi_agent.py +4 -4
- letta/groups/sleeptime_multi_agent_v2.py +22 -0
- letta/helpers/composio_helpers.py +16 -0
- letta/helpers/converters.py +20 -0
- letta/helpers/datetime_helpers.py +1 -6
- letta/helpers/tool_rule_solver.py +2 -1
- letta/interfaces/anthropic_streaming_interface.py +17 -2
- letta/interfaces/openai_chat_completions_streaming_interface.py +1 -0
- letta/interfaces/openai_streaming_interface.py +18 -2
- letta/llm_api/anthropic_client.py +24 -3
- letta/llm_api/google_ai_client.py +0 -15
- letta/llm_api/google_vertex_client.py +6 -5
- letta/llm_api/llm_client_base.py +15 -0
- letta/llm_api/openai.py +2 -2
- letta/llm_api/openai_client.py +60 -8
- letta/orm/__init__.py +2 -0
- letta/orm/agent.py +45 -43
- letta/orm/base.py +0 -2
- letta/orm/block.py +1 -0
- letta/orm/custom_columns.py +13 -0
- letta/orm/enums.py +5 -0
- letta/orm/file.py +3 -1
- letta/orm/files_agents.py +68 -0
- letta/orm/mcp_server.py +48 -0
- letta/orm/message.py +1 -0
- letta/orm/organization.py +11 -2
- letta/orm/passage.py +25 -10
- letta/orm/sandbox_config.py +5 -2
- letta/orm/sqlalchemy_base.py +171 -110
- letta/prompts/system/memgpt_base.txt +6 -1
- letta/prompts/system/memgpt_v2_chat.txt +57 -0
- letta/prompts/system/sleeptime.txt +2 -0
- letta/prompts/system/sleeptime_v2.txt +28 -0
- letta/schemas/agent.py +87 -20
- letta/schemas/block.py +7 -1
- letta/schemas/file.py +57 -0
- letta/schemas/mcp.py +74 -0
- letta/schemas/memory.py +5 -2
- letta/schemas/message.py +9 -0
- letta/schemas/openai/openai.py +0 -6
- letta/schemas/providers.py +33 -4
- letta/schemas/tool.py +26 -21
- letta/schemas/tool_execution_result.py +5 -0
- letta/server/db.py +23 -8
- letta/server/rest_api/app.py +73 -56
- letta/server/rest_api/interface.py +4 -4
- letta/server/rest_api/routers/v1/agents.py +132 -47
- letta/server/rest_api/routers/v1/blocks.py +3 -2
- letta/server/rest_api/routers/v1/embeddings.py +3 -3
- letta/server/rest_api/routers/v1/groups.py +3 -3
- letta/server/rest_api/routers/v1/jobs.py +14 -17
- letta/server/rest_api/routers/v1/organizations.py +10 -10
- letta/server/rest_api/routers/v1/providers.py +12 -10
- letta/server/rest_api/routers/v1/runs.py +3 -3
- letta/server/rest_api/routers/v1/sandbox_configs.py +12 -12
- letta/server/rest_api/routers/v1/sources.py +108 -43
- letta/server/rest_api/routers/v1/steps.py +8 -6
- letta/server/rest_api/routers/v1/tools.py +134 -95
- letta/server/rest_api/utils.py +12 -1
- letta/server/server.py +272 -73
- letta/services/agent_manager.py +246 -313
- letta/services/block_manager.py +30 -9
- letta/services/context_window_calculator/__init__.py +0 -0
- letta/services/context_window_calculator/context_window_calculator.py +150 -0
- letta/services/context_window_calculator/token_counter.py +82 -0
- letta/services/file_processor/__init__.py +0 -0
- letta/services/file_processor/chunker/__init__.py +0 -0
- letta/services/file_processor/chunker/llama_index_chunker.py +29 -0
- letta/services/file_processor/embedder/__init__.py +0 -0
- letta/services/file_processor/embedder/openai_embedder.py +84 -0
- letta/services/file_processor/file_processor.py +123 -0
- letta/services/file_processor/parser/__init__.py +0 -0
- letta/services/file_processor/parser/base_parser.py +9 -0
- letta/services/file_processor/parser/mistral_parser.py +54 -0
- letta/services/file_processor/types.py +0 -0
- letta/services/files_agents_manager.py +184 -0
- letta/services/group_manager.py +118 -0
- letta/services/helpers/agent_manager_helper.py +76 -21
- letta/services/helpers/tool_execution_helper.py +3 -0
- letta/services/helpers/tool_parser_helper.py +100 -0
- letta/services/identity_manager.py +44 -42
- letta/services/job_manager.py +21 -10
- letta/services/mcp/base_client.py +5 -2
- letta/services/mcp/sse_client.py +3 -5
- letta/services/mcp/stdio_client.py +3 -5
- letta/services/mcp_manager.py +281 -0
- letta/services/message_manager.py +40 -26
- letta/services/organization_manager.py +55 -19
- letta/services/passage_manager.py +211 -13
- letta/services/provider_manager.py +48 -2
- letta/services/sandbox_config_manager.py +105 -0
- letta/services/source_manager.py +4 -5
- letta/services/step_manager.py +9 -6
- letta/services/summarizer/summarizer.py +50 -23
- letta/services/telemetry_manager.py +7 -0
- letta/services/tool_executor/tool_execution_manager.py +11 -52
- letta/services/tool_executor/tool_execution_sandbox.py +4 -34
- letta/services/tool_executor/tool_executor.py +107 -105
- letta/services/tool_manager.py +56 -17
- letta/services/tool_sandbox/base.py +39 -92
- letta/services/tool_sandbox/e2b_sandbox.py +16 -11
- letta/services/tool_sandbox/local_sandbox.py +51 -23
- letta/services/user_manager.py +36 -3
- letta/settings.py +10 -3
- letta/templates/__init__.py +0 -0
- letta/templates/sandbox_code_file.py.j2 +47 -0
- letta/templates/template_helper.py +16 -0
- letta/tracing.py +30 -1
- letta/types/__init__.py +7 -0
- letta/utils.py +25 -1
- {letta_nightly-0.7.30.dev20250603104343.dist-info → letta_nightly-0.8.0.dev20250604201135.dist-info}/METADATA +7 -2
- {letta_nightly-0.7.30.dev20250603104343.dist-info → letta_nightly-0.8.0.dev20250604201135.dist-info}/RECORD +136 -110
- {letta_nightly-0.7.30.dev20250603104343.dist-info → letta_nightly-0.8.0.dev20250604201135.dist-info}/LICENSE +0 -0
- {letta_nightly-0.7.30.dev20250603104343.dist-info → letta_nightly-0.8.0.dev20250604201135.dist-info}/WHEEL +0 -0
- {letta_nightly-0.7.30.dev20250603104343.dist-info → letta_nightly-0.8.0.dev20250604201135.dist-info}/entry_points.txt +0 -0
@@ -49,7 +49,7 @@ class MessageManager:
|
|
49
49
|
def get_messages_by_ids(self, message_ids: List[str], actor: PydanticUser) -> List[PydanticMessage]:
|
50
50
|
"""Fetch messages by ID and return them in the requested order."""
|
51
51
|
with db_registry.session() as session:
|
52
|
-
results = MessageModel.
|
52
|
+
results = MessageModel.read_multiple(db_session=session, identifiers=message_ids, actor=actor)
|
53
53
|
return self._get_messages_by_id_postprocess(results, message_ids)
|
54
54
|
|
55
55
|
@enforce_types
|
@@ -57,10 +57,8 @@ class MessageManager:
|
|
57
57
|
async def get_messages_by_ids_async(self, message_ids: List[str], actor: PydanticUser) -> List[PydanticMessage]:
|
58
58
|
"""Fetch messages by ID and return them in the requested order. Async version of above function."""
|
59
59
|
async with db_registry.async_session() as session:
|
60
|
-
results = await MessageModel.
|
61
|
-
|
62
|
-
)
|
63
|
-
return self._get_messages_by_id_postprocess(results, message_ids)
|
60
|
+
results = await MessageModel.read_multiple_async(db_session=session, identifiers=message_ids, actor=actor)
|
61
|
+
return self._get_messages_by_id_postprocess(results, message_ids)
|
64
62
|
|
65
63
|
def _get_messages_by_id_postprocess(
|
66
64
|
self,
|
@@ -349,6 +347,29 @@ class MessageManager:
|
|
349
347
|
ascending=ascending,
|
350
348
|
)
|
351
349
|
|
350
|
+
@enforce_types
|
351
|
+
@trace_method
|
352
|
+
async def list_user_messages_for_agent_async(
|
353
|
+
self,
|
354
|
+
agent_id: str,
|
355
|
+
actor: PydanticUser,
|
356
|
+
after: Optional[str] = None,
|
357
|
+
before: Optional[str] = None,
|
358
|
+
query_text: Optional[str] = None,
|
359
|
+
limit: Optional[int] = 50,
|
360
|
+
ascending: bool = True,
|
361
|
+
) -> List[PydanticMessage]:
|
362
|
+
return await self.list_messages_for_agent_async(
|
363
|
+
agent_id=agent_id,
|
364
|
+
actor=actor,
|
365
|
+
after=after,
|
366
|
+
before=before,
|
367
|
+
query_text=query_text,
|
368
|
+
roles=[MessageRole.user],
|
369
|
+
limit=limit,
|
370
|
+
ascending=ascending,
|
371
|
+
)
|
372
|
+
|
352
373
|
@enforce_types
|
353
374
|
@trace_method
|
354
375
|
def list_messages_for_agent(
|
@@ -400,24 +421,17 @@ class MessageManager:
|
|
400
421
|
if group_id:
|
401
422
|
query = query.filter(MessageModel.group_id == group_id)
|
402
423
|
|
403
|
-
# If query_text is provided, filter messages
|
404
|
-
# whose text includes the query string (case-insensitive).
|
424
|
+
# If query_text is provided, filter messages using subquery + json_array_elements.
|
405
425
|
if query_text:
|
406
|
-
|
407
|
-
|
408
|
-
|
409
|
-
|
410
|
-
|
411
|
-
|
412
|
-
|
413
|
-
elif dialect_name == "sqlite": # using `json_each` and JSON path expressions
|
414
|
-
json_item = func.json_each(MessageModel.content).alias("json_item")
|
415
|
-
subquery_sql = text(
|
416
|
-
"json_extract(value, '$.type') = 'text' AND lower(json_extract(value, '$.text')) LIKE lower(:query_text)"
|
426
|
+
content_element = func.json_array_elements(MessageModel.content).alias("content_element")
|
427
|
+
query = query.filter(
|
428
|
+
exists(
|
429
|
+
select(1)
|
430
|
+
.select_from(content_element)
|
431
|
+
.where(text("content_element->>'type' = 'text' AND content_element->>'text' ILIKE :query_text"))
|
432
|
+
.params(query_text=f"%{query_text}%")
|
417
433
|
)
|
418
|
-
|
419
|
-
|
420
|
-
query = query.filter(exists(subquery.params(query_text=f"%{query_text}%")))
|
434
|
+
)
|
421
435
|
|
422
436
|
# If role(s) are provided, filter messages by those roles.
|
423
437
|
if roles:
|
@@ -557,23 +571,23 @@ class MessageManager:
|
|
557
571
|
|
558
572
|
@enforce_types
|
559
573
|
@trace_method
|
560
|
-
def
|
574
|
+
async def delete_all_messages_for_agent_async(self, agent_id: str, actor: PydanticUser) -> int:
|
561
575
|
"""
|
562
576
|
Efficiently deletes all messages associated with a given agent_id,
|
563
577
|
while enforcing permission checks and avoiding any ORM‑level loads.
|
564
578
|
"""
|
565
|
-
with db_registry.
|
579
|
+
async with db_registry.async_session() as session:
|
566
580
|
# 1) verify the agent exists and the actor has access
|
567
|
-
AgentModel.
|
581
|
+
await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor)
|
568
582
|
|
569
583
|
# 2) issue a CORE DELETE against the mapped class
|
570
584
|
stmt = (
|
571
585
|
delete(MessageModel).where(MessageModel.agent_id == agent_id).where(MessageModel.organization_id == actor.organization_id)
|
572
586
|
)
|
573
|
-
result = session.execute(stmt)
|
587
|
+
result = await session.execute(stmt)
|
574
588
|
|
575
589
|
# 3) commit once
|
576
|
-
session.commit()
|
590
|
+
await session.commit()
|
577
591
|
|
578
592
|
# 4) return the number of rows deleted
|
579
593
|
return result.rowcount
|
@@ -17,9 +17,9 @@ class OrganizationManager:
|
|
17
17
|
|
18
18
|
@enforce_types
|
19
19
|
@trace_method
|
20
|
-
def
|
20
|
+
async def get_default_organization_async(self) -> PydanticOrganization:
|
21
21
|
"""Fetch the default organization."""
|
22
|
-
return self.
|
22
|
+
return await self.get_organization_by_id_async(self.DEFAULT_ORG_ID)
|
23
23
|
|
24
24
|
@enforce_types
|
25
25
|
@trace_method
|
@@ -29,52 +29,80 @@ class OrganizationManager:
|
|
29
29
|
organization = OrganizationModel.read(db_session=session, identifier=org_id)
|
30
30
|
return organization.to_pydantic()
|
31
31
|
|
32
|
+
@enforce_types
|
33
|
+
@trace_method
|
34
|
+
async def get_organization_by_id_async(self, org_id: str) -> Optional[PydanticOrganization]:
|
35
|
+
"""Fetch an organization by ID."""
|
36
|
+
async with db_registry.async_session() as session:
|
37
|
+
organization = await OrganizationModel.read_async(db_session=session, identifier=org_id)
|
38
|
+
return organization.to_pydantic()
|
39
|
+
|
32
40
|
@enforce_types
|
33
41
|
@trace_method
|
34
42
|
def create_organization(self, pydantic_org: PydanticOrganization) -> PydanticOrganization:
|
43
|
+
"""Create the default organization."""
|
44
|
+
with db_registry.session() as session:
|
45
|
+
try:
|
46
|
+
organization = OrganizationModel.read(db_session=session, identifier=pydantic_org.id)
|
47
|
+
return organization.to_pydantic()
|
48
|
+
except:
|
49
|
+
organization = OrganizationModel(**pydantic_org.model_dump(to_orm=True))
|
50
|
+
organization = organization.create(session)
|
51
|
+
return organization.to_pydantic()
|
52
|
+
|
53
|
+
@enforce_types
|
54
|
+
@trace_method
|
55
|
+
async def create_organization_async(self, pydantic_org: PydanticOrganization) -> PydanticOrganization:
|
35
56
|
"""Create a new organization."""
|
36
57
|
try:
|
37
|
-
org = self.
|
58
|
+
org = await self.get_organization_by_id_async(pydantic_org.id)
|
38
59
|
return org
|
39
60
|
except NoResultFound:
|
40
|
-
return self.
|
61
|
+
return await self._create_organization_async(pydantic_org=pydantic_org)
|
41
62
|
|
42
63
|
@enforce_types
|
43
64
|
@trace_method
|
44
|
-
def
|
45
|
-
with db_registry.
|
65
|
+
async def _create_organization_async(self, pydantic_org: PydanticOrganization) -> PydanticOrganization:
|
66
|
+
async with db_registry.async_session() as session:
|
46
67
|
org = OrganizationModel(**pydantic_org.model_dump(to_orm=True))
|
47
|
-
org.
|
68
|
+
await org.create_async(session)
|
48
69
|
return org.to_pydantic()
|
49
70
|
|
50
71
|
@enforce_types
|
51
72
|
@trace_method
|
52
73
|
def create_default_organization(self) -> PydanticOrganization:
|
53
74
|
"""Create the default organization."""
|
54
|
-
|
75
|
+
pydantic_org = PydanticOrganization(name=self.DEFAULT_ORG_NAME, id=self.DEFAULT_ORG_ID)
|
76
|
+
return self.create_organization(pydantic_org)
|
55
77
|
|
56
78
|
@enforce_types
|
57
79
|
@trace_method
|
58
|
-
def
|
80
|
+
async def create_default_organization_async(self) -> PydanticOrganization:
|
81
|
+
"""Create the default organization."""
|
82
|
+
return await self.create_organization_async(PydanticOrganization(name=self.DEFAULT_ORG_NAME, id=self.DEFAULT_ORG_ID))
|
83
|
+
|
84
|
+
@enforce_types
|
85
|
+
@trace_method
|
86
|
+
async def update_organization_name_using_id_async(self, org_id: str, name: Optional[str] = None) -> PydanticOrganization:
|
59
87
|
"""Update an organization."""
|
60
|
-
with db_registry.
|
61
|
-
org = OrganizationModel.
|
88
|
+
async with db_registry.async_session() as session:
|
89
|
+
org = await OrganizationModel.read_async(db_session=session, identifier=org_id)
|
62
90
|
if name:
|
63
91
|
org.name = name
|
64
|
-
org.
|
92
|
+
await org.update_async(session)
|
65
93
|
return org.to_pydantic()
|
66
94
|
|
67
95
|
@enforce_types
|
68
96
|
@trace_method
|
69
|
-
def
|
97
|
+
async def update_organization_async(self, org_id: str, org_update: OrganizationUpdate) -> PydanticOrganization:
|
70
98
|
"""Update an organization."""
|
71
|
-
with db_registry.
|
72
|
-
org = OrganizationModel.
|
99
|
+
async with db_registry.async_session() as session:
|
100
|
+
org = await OrganizationModel.read_async(db_session=session, identifier=org_id)
|
73
101
|
if org_update.name:
|
74
102
|
org.name = org_update.name
|
75
103
|
if org_update.privileged_tools:
|
76
104
|
org.privileged_tools = org_update.privileged_tools
|
77
|
-
org.
|
105
|
+
await org.update_async(session)
|
78
106
|
return org.to_pydantic()
|
79
107
|
|
80
108
|
@enforce_types
|
@@ -87,10 +115,18 @@ class OrganizationManager:
|
|
87
115
|
|
88
116
|
@enforce_types
|
89
117
|
@trace_method
|
90
|
-
def
|
118
|
+
async def delete_organization_by_id_async(self, org_id: str):
|
119
|
+
"""Delete an organization by marking it as deleted."""
|
120
|
+
async with db_registry.async_session() as session:
|
121
|
+
organization = await OrganizationModel.read_async(db_session=session, identifier=org_id)
|
122
|
+
await organization.hard_delete_async(session)
|
123
|
+
|
124
|
+
@enforce_types
|
125
|
+
@trace_method
|
126
|
+
async def list_organizations_async(self, after: Optional[str] = None, limit: Optional[int] = 50) -> List[PydanticOrganization]:
|
91
127
|
"""List all organizations with optional pagination."""
|
92
|
-
with db_registry.
|
93
|
-
organizations = OrganizationModel.
|
128
|
+
async with db_registry.async_session() as session:
|
129
|
+
organizations = await OrganizationModel.list_async(
|
94
130
|
db_session=session,
|
95
131
|
after=after,
|
96
132
|
limit=limit,
|
@@ -1,7 +1,11 @@
|
|
1
|
+
import asyncio
|
1
2
|
from datetime import datetime, timezone
|
3
|
+
from functools import lru_cache
|
2
4
|
from typing import List, Optional
|
3
5
|
|
4
|
-
from
|
6
|
+
from async_lru import alru_cache
|
7
|
+
from openai import AsyncOpenAI, OpenAI
|
8
|
+
from sqlalchemy import select
|
5
9
|
|
6
10
|
from letta.constants import MAX_EMBEDDING_DIM
|
7
11
|
from letta.embeddings import embedding_model, parse_and_chunk_text
|
@@ -15,6 +19,26 @@ from letta.tracing import trace_method
|
|
15
19
|
from letta.utils import enforce_types
|
16
20
|
|
17
21
|
|
22
|
+
# TODO: Add redis-backed caching for backend
|
23
|
+
@lru_cache(maxsize=8192)
|
24
|
+
def get_openai_embedding(text: str, model: str, endpoint: str) -> List[float]:
|
25
|
+
from letta.settings import model_settings
|
26
|
+
|
27
|
+
client = OpenAI(api_key=model_settings.openai_api_key, base_url=endpoint, max_retries=0)
|
28
|
+
response = client.embeddings.create(input=text, model=model)
|
29
|
+
return response.data[0].embedding
|
30
|
+
|
31
|
+
|
32
|
+
# TODO: Add redis-backed caching for backend
|
33
|
+
@alru_cache(maxsize=8192)
|
34
|
+
async def get_openai_embedding_async(text: str, model: str, endpoint: str) -> List[float]:
|
35
|
+
from letta.settings import model_settings
|
36
|
+
|
37
|
+
client = AsyncOpenAI(api_key=model_settings.openai_api_key, base_url=endpoint, max_retries=0)
|
38
|
+
response = await client.embeddings.create(input=text, model=model)
|
39
|
+
return response.data[0].embedding
|
40
|
+
|
41
|
+
|
18
42
|
class PassageManager:
|
19
43
|
"""Manager class to handle business logic related to Passages."""
|
20
44
|
|
@@ -35,11 +59,45 @@ class PassageManager:
|
|
35
59
|
except NoResultFound:
|
36
60
|
raise NoResultFound(f"Passage with id {passage_id} not found in database.")
|
37
61
|
|
62
|
+
@enforce_types
|
63
|
+
@trace_method
|
64
|
+
async def get_passage_by_id_async(self, passage_id: str, actor: PydanticUser) -> Optional[PydanticPassage]:
|
65
|
+
"""Fetch a passage by ID."""
|
66
|
+
async with db_registry.async_session() as session:
|
67
|
+
# Try source passages first
|
68
|
+
try:
|
69
|
+
passage = await SourcePassage.read_async(db_session=session, identifier=passage_id, actor=actor)
|
70
|
+
return passage.to_pydantic()
|
71
|
+
except NoResultFound:
|
72
|
+
# Try archival passages
|
73
|
+
try:
|
74
|
+
passage = await AgentPassage.read_async(db_session=session, identifier=passage_id, actor=actor)
|
75
|
+
return passage.to_pydantic()
|
76
|
+
except NoResultFound:
|
77
|
+
raise NoResultFound(f"Passage with id {passage_id} not found in database.")
|
78
|
+
|
38
79
|
@enforce_types
|
39
80
|
@trace_method
|
40
81
|
def create_passage(self, pydantic_passage: PydanticPassage, actor: PydanticUser) -> PydanticPassage:
|
82
|
+
"""Create a new passage in the appropriate table based on whether it has agent_id or source_id."""
|
83
|
+
passage = self._preprocess_passage_for_creation(pydantic_passage=pydantic_passage)
|
84
|
+
|
85
|
+
with db_registry.session() as session:
|
86
|
+
passage.create(session, actor=actor)
|
87
|
+
return passage.to_pydantic()
|
88
|
+
|
89
|
+
@enforce_types
|
90
|
+
@trace_method
|
91
|
+
async def create_passage_async(self, pydantic_passage: PydanticPassage, actor: PydanticUser) -> PydanticPassage:
|
41
92
|
"""Create a new passage in the appropriate table based on whether it has agent_id or source_id."""
|
42
93
|
# Common fields for both passage types
|
94
|
+
passage = self._preprocess_passage_for_creation(pydantic_passage=pydantic_passage)
|
95
|
+
async with db_registry.async_session() as session:
|
96
|
+
passage = await passage.create_async(session, actor=actor)
|
97
|
+
return passage.to_pydantic()
|
98
|
+
|
99
|
+
@trace_method
|
100
|
+
def _preprocess_passage_for_creation(self, pydantic_passage: PydanticPassage) -> "SqlAlchemyBase":
|
43
101
|
data = pydantic_passage.model_dump(to_orm=True)
|
44
102
|
common_fields = {
|
45
103
|
"id": data.get("id"),
|
@@ -68,9 +126,7 @@ class PassageManager:
|
|
68
126
|
else:
|
69
127
|
raise ValueError("Passage must have either agent_id or source_id")
|
70
128
|
|
71
|
-
|
72
|
-
passage.create(session, actor=actor)
|
73
|
-
return passage.to_pydantic()
|
129
|
+
return passage
|
74
130
|
|
75
131
|
@enforce_types
|
76
132
|
@trace_method
|
@@ -78,6 +134,33 @@ class PassageManager:
|
|
78
134
|
"""Create multiple passages."""
|
79
135
|
return [self.create_passage(p, actor) for p in passages]
|
80
136
|
|
137
|
+
@enforce_types
|
138
|
+
@trace_method
|
139
|
+
async def create_many_passages_async(self, passages: List[PydanticPassage], actor: PydanticUser) -> List[PydanticPassage]:
|
140
|
+
"""Create multiple passages."""
|
141
|
+
async with db_registry.async_session() as session:
|
142
|
+
agent_passages = []
|
143
|
+
source_passages = []
|
144
|
+
|
145
|
+
for p in passages:
|
146
|
+
model = self._preprocess_passage_for_creation(p)
|
147
|
+
if isinstance(model, AgentPassage):
|
148
|
+
agent_passages.append(model)
|
149
|
+
elif isinstance(model, SourcePassage):
|
150
|
+
source_passages.append(model)
|
151
|
+
else:
|
152
|
+
raise TypeError(f"Unexpected passage type: {type(model)}")
|
153
|
+
|
154
|
+
results = []
|
155
|
+
if agent_passages:
|
156
|
+
agent_created = await AgentPassage.batch_create_async(items=agent_passages, db_session=session, actor=actor)
|
157
|
+
results.extend(agent_created)
|
158
|
+
if source_passages:
|
159
|
+
source_created = await SourcePassage.batch_create_async(items=source_passages, db_session=session, actor=actor)
|
160
|
+
results.extend(source_created)
|
161
|
+
|
162
|
+
return [p.to_pydantic() for p in results]
|
163
|
+
|
81
164
|
@enforce_types
|
82
165
|
@trace_method
|
83
166
|
def insert_passage(
|
@@ -106,14 +189,11 @@ class PassageManager:
|
|
106
189
|
embedding = embed_model.get_text_embedding(text)
|
107
190
|
else:
|
108
191
|
# TODO should have the settings passed in via the server call
|
109
|
-
|
110
|
-
|
111
|
-
|
112
|
-
|
113
|
-
api_key=model_settings.openai_api_key, base_url=agent_state.embedding_config.embedding_endpoint, max_retries=0
|
192
|
+
embedding = get_openai_embedding(
|
193
|
+
text,
|
194
|
+
agent_state.embedding_config.embedding_model,
|
195
|
+
agent_state.embedding_config.embedding_endpoint,
|
114
196
|
)
|
115
|
-
response = client.embeddings.create(input=text, model=agent_state.embedding_config.embedding_model)
|
116
|
-
embedding = response.data[0].embedding
|
117
197
|
|
118
198
|
if isinstance(embedding, dict):
|
119
199
|
try:
|
@@ -140,6 +220,78 @@ class PassageManager:
|
|
140
220
|
except Exception as e:
|
141
221
|
raise e
|
142
222
|
|
223
|
+
@enforce_types
|
224
|
+
@trace_method
|
225
|
+
async def insert_passage_async(
|
226
|
+
self,
|
227
|
+
agent_state: AgentState,
|
228
|
+
agent_id: str,
|
229
|
+
text: str,
|
230
|
+
actor: PydanticUser,
|
231
|
+
) -> List[PydanticPassage]:
|
232
|
+
"""Insert passage(s) into archival memory"""
|
233
|
+
|
234
|
+
embedding_chunk_size = agent_state.embedding_config.embedding_chunk_size
|
235
|
+
text_chunks = list(parse_and_chunk_text(text, embedding_chunk_size))
|
236
|
+
|
237
|
+
if not text_chunks:
|
238
|
+
return []
|
239
|
+
|
240
|
+
try:
|
241
|
+
embeddings = await self._generate_embeddings_concurrent(text_chunks, agent_state.embedding_config)
|
242
|
+
|
243
|
+
passages = [
|
244
|
+
PydanticPassage(
|
245
|
+
organization_id=actor.organization_id,
|
246
|
+
agent_id=agent_id,
|
247
|
+
text=chunk_text,
|
248
|
+
embedding=embedding,
|
249
|
+
embedding_config=agent_state.embedding_config,
|
250
|
+
)
|
251
|
+
for chunk_text, embedding in zip(text_chunks, embeddings)
|
252
|
+
]
|
253
|
+
|
254
|
+
passages = await self.create_many_passages_async(passages=passages, actor=actor)
|
255
|
+
|
256
|
+
return passages
|
257
|
+
|
258
|
+
except Exception as e:
|
259
|
+
raise e
|
260
|
+
|
261
|
+
async def _generate_embeddings_concurrent(self, text_chunks: List[str], embedding_config) -> List[List[float]]:
|
262
|
+
"""Generate embeddings for all text chunks concurrently"""
|
263
|
+
|
264
|
+
if embedding_config.embedding_endpoint_type != "openai":
|
265
|
+
embed_model = embedding_model(embedding_config)
|
266
|
+
loop = asyncio.get_event_loop()
|
267
|
+
|
268
|
+
tasks = [loop.run_in_executor(None, embed_model.get_text_embedding, text) for text in text_chunks]
|
269
|
+
embeddings = await asyncio.gather(*tasks)
|
270
|
+
else:
|
271
|
+
tasks = [
|
272
|
+
get_openai_embedding_async(
|
273
|
+
text,
|
274
|
+
embedding_config.embedding_model,
|
275
|
+
embedding_config.embedding_endpoint,
|
276
|
+
)
|
277
|
+
for text in text_chunks
|
278
|
+
]
|
279
|
+
embeddings = await asyncio.gather(*tasks)
|
280
|
+
|
281
|
+
processed_embeddings = []
|
282
|
+
for embedding in embeddings:
|
283
|
+
if isinstance(embedding, dict):
|
284
|
+
try:
|
285
|
+
processed_embeddings.append(embedding["data"][0]["embedding"])
|
286
|
+
except (KeyError, IndexError):
|
287
|
+
raise TypeError(
|
288
|
+
f"Got back an unexpected payload from text embedding function, type={type(embedding)}, value={embedding}"
|
289
|
+
)
|
290
|
+
else:
|
291
|
+
processed_embeddings.append(embedding)
|
292
|
+
|
293
|
+
return processed_embeddings
|
294
|
+
|
143
295
|
@enforce_types
|
144
296
|
@trace_method
|
145
297
|
def update_passage_by_id(self, passage_id: str, passage: PydanticPassage, actor: PydanticUser, **kwargs) -> Optional[PydanticPassage]:
|
@@ -197,6 +349,28 @@ class PassageManager:
|
|
197
349
|
except NoResultFound:
|
198
350
|
raise NoResultFound(f"Passage with id {passage_id} not found.")
|
199
351
|
|
352
|
+
@enforce_types
|
353
|
+
@trace_method
|
354
|
+
async def delete_passage_by_id_async(self, passage_id: str, actor: PydanticUser) -> bool:
|
355
|
+
"""Delete a passage from either source or archival passages."""
|
356
|
+
if not passage_id:
|
357
|
+
raise ValueError("Passage ID must be provided.")
|
358
|
+
|
359
|
+
async with db_registry.async_session() as session:
|
360
|
+
# Try source passages first
|
361
|
+
try:
|
362
|
+
passage = await SourcePassage.read_async(db_session=session, identifier=passage_id, actor=actor)
|
363
|
+
await passage.hard_delete_async(session, actor=actor)
|
364
|
+
return True
|
365
|
+
except NoResultFound:
|
366
|
+
# Try archival passages
|
367
|
+
try:
|
368
|
+
passage = await AgentPassage.read_async(db_session=session, identifier=passage_id, actor=actor)
|
369
|
+
await passage.hard_delete_async(session, actor=actor)
|
370
|
+
return True
|
371
|
+
except NoResultFound:
|
372
|
+
raise NoResultFound(f"Passage with id {passage_id} not found.")
|
373
|
+
|
200
374
|
@enforce_types
|
201
375
|
@trace_method
|
202
376
|
def delete_passages(
|
@@ -210,6 +384,17 @@ class PassageManager:
|
|
210
384
|
self.delete_passage_by_id(passage_id=passage.id, actor=actor)
|
211
385
|
return True
|
212
386
|
|
387
|
+
@enforce_types
|
388
|
+
@trace_method
|
389
|
+
async def delete_source_passages_async(
|
390
|
+
self,
|
391
|
+
actor: PydanticUser,
|
392
|
+
passages: List[PydanticPassage],
|
393
|
+
) -> bool:
|
394
|
+
async with db_registry.async_session() as session:
|
395
|
+
await SourcePassage.bulk_hard_delete_async(db_session=session, identifiers=[p.id for p in passages], actor=actor)
|
396
|
+
return True
|
397
|
+
|
213
398
|
@enforce_types
|
214
399
|
@trace_method
|
215
400
|
def size(
|
@@ -243,7 +428,7 @@ class PassageManager:
|
|
243
428
|
|
244
429
|
@enforce_types
|
245
430
|
@trace_method
|
246
|
-
def
|
431
|
+
async def estimate_embeddings_size_async(
|
247
432
|
self,
|
248
433
|
actor: PydanticUser,
|
249
434
|
agent_id: Optional[str] = None,
|
@@ -263,4 +448,17 @@ class PassageManager:
|
|
263
448
|
raise ValueError(f"Invalid storage unit: {storage_unit}. Must be one of {list(BYTES_PER_STORAGE_UNIT.keys())}.")
|
264
449
|
BYTES_PER_EMBEDDING_DIM = 4
|
265
450
|
GB_PER_EMBEDDING = BYTES_PER_EMBEDDING_DIM / BYTES_PER_STORAGE_UNIT[storage_unit] * MAX_EMBEDDING_DIM
|
266
|
-
return self.
|
451
|
+
return await self.size_async(actor=actor, agent_id=agent_id) * GB_PER_EMBEDDING
|
452
|
+
|
453
|
+
@enforce_types
|
454
|
+
@trace_method
|
455
|
+
async def list_passages_by_file_id_async(self, file_id: str, actor: PydanticUser) -> List[PydanticPassage]:
|
456
|
+
"""
|
457
|
+
List all source passages associated with a given file_id.
|
458
|
+
"""
|
459
|
+
async with db_registry.async_session() as session:
|
460
|
+
result = await session.execute(
|
461
|
+
select(SourcePassage).where(SourcePassage.file_id == file_id).where(SourcePassage.organization_id == actor.organization_id)
|
462
|
+
)
|
463
|
+
passages = result.scalars().all()
|
464
|
+
return [p.to_pydantic() for p in passages]
|
@@ -33,13 +33,34 @@ class ProviderManager:
|
|
33
33
|
new_provider.create(session, actor=actor)
|
34
34
|
return new_provider.to_pydantic()
|
35
35
|
|
36
|
+
@enforce_types
|
37
|
+
@trace_method
|
38
|
+
async def create_provider_async(self, request: ProviderCreate, actor: PydanticUser) -> PydanticProvider:
|
39
|
+
"""Create a new provider if it doesn't already exist."""
|
40
|
+
async with db_registry.async_session() as session:
|
41
|
+
provider_create_args = {**request.model_dump(), "provider_category": ProviderCategory.byok}
|
42
|
+
provider = PydanticProvider(**provider_create_args)
|
43
|
+
|
44
|
+
if provider.name == provider.provider_type.value:
|
45
|
+
raise ValueError("Provider name must be unique and different from provider type")
|
46
|
+
|
47
|
+
# Assign the organization id based on the actor
|
48
|
+
provider.organization_id = actor.organization_id
|
49
|
+
|
50
|
+
# Lazily create the provider id prior to persistence
|
51
|
+
provider.resolve_identifier()
|
52
|
+
|
53
|
+
new_provider = ProviderModel(**provider.model_dump(to_orm=True, exclude_unset=True))
|
54
|
+
await new_provider.create_async(session, actor=actor)
|
55
|
+
return new_provider.to_pydantic()
|
56
|
+
|
36
57
|
@enforce_types
|
37
58
|
@trace_method
|
38
59
|
def update_provider(self, provider_id: str, provider_update: ProviderUpdate, actor: PydanticUser) -> PydanticProvider:
|
39
60
|
"""Update provider details."""
|
40
61
|
with db_registry.session() as session:
|
41
62
|
# Retrieve the existing provider by ID
|
42
|
-
existing_provider = ProviderModel.read(db_session=session, identifier=provider_id, actor=actor)
|
63
|
+
existing_provider = ProviderModel.read(db_session=session, identifier=provider_id, actor=actor, check_is_deleted=True)
|
43
64
|
|
44
65
|
# Update only the fields that are provided in ProviderUpdate
|
45
66
|
update_data = provider_update.model_dump(to_orm=True, exclude_unset=True, exclude_none=True)
|
@@ -56,7 +77,7 @@ class ProviderManager:
|
|
56
77
|
"""Delete a provider."""
|
57
78
|
with db_registry.session() as session:
|
58
79
|
# Clear api key field
|
59
|
-
existing_provider = ProviderModel.read(db_session=session, identifier=provider_id, actor=actor)
|
80
|
+
existing_provider = ProviderModel.read(db_session=session, identifier=provider_id, actor=actor, check_is_deleted=True)
|
60
81
|
existing_provider.api_key = None
|
61
82
|
existing_provider.update(session, actor=actor)
|
62
83
|
|
@@ -65,6 +86,23 @@ class ProviderManager:
|
|
65
86
|
|
66
87
|
session.commit()
|
67
88
|
|
89
|
+
@enforce_types
|
90
|
+
@trace_method
|
91
|
+
async def delete_provider_by_id_async(self, provider_id: str, actor: PydanticUser):
|
92
|
+
"""Delete a provider."""
|
93
|
+
async with db_registry.async_session() as session:
|
94
|
+
# Clear api key field
|
95
|
+
existing_provider = await ProviderModel.read_async(
|
96
|
+
db_session=session, identifier=provider_id, actor=actor, check_is_deleted=True
|
97
|
+
)
|
98
|
+
existing_provider.api_key = None
|
99
|
+
await existing_provider.update_async(session, actor=actor)
|
100
|
+
|
101
|
+
# Soft delete in provider table
|
102
|
+
await existing_provider.delete_async(session, actor=actor)
|
103
|
+
|
104
|
+
await session.commit()
|
105
|
+
|
68
106
|
@enforce_types
|
69
107
|
@trace_method
|
70
108
|
def list_providers(
|
@@ -87,6 +125,7 @@ class ProviderManager:
|
|
87
125
|
after=after,
|
88
126
|
limit=limit,
|
89
127
|
actor=actor,
|
128
|
+
check_is_deleted=True,
|
90
129
|
**filter_kwargs,
|
91
130
|
)
|
92
131
|
return [provider.to_pydantic() for provider in providers]
|
@@ -113,6 +152,7 @@ class ProviderManager:
|
|
113
152
|
after=after,
|
114
153
|
limit=limit,
|
115
154
|
actor=actor,
|
155
|
+
check_is_deleted=True,
|
116
156
|
**filter_kwargs,
|
117
157
|
)
|
118
158
|
return [provider.to_pydantic() for provider in providers]
|
@@ -129,6 +169,12 @@ class ProviderManager:
|
|
129
169
|
providers = self.list_providers(name=provider_name, actor=actor)
|
130
170
|
return providers[0].api_key if providers else None
|
131
171
|
|
172
|
+
@enforce_types
|
173
|
+
@trace_method
|
174
|
+
async def get_override_key_async(self, provider_name: Union[str, None], actor: PydanticUser) -> Optional[str]:
|
175
|
+
providers = await self.list_providers_async(name=provider_name, actor=actor)
|
176
|
+
return providers[0].api_key if providers else None
|
177
|
+
|
132
178
|
@enforce_types
|
133
179
|
@trace_method
|
134
180
|
def check_provider_api_key(self, provider_check: ProviderCheck) -> None:
|