letta-nightly 0.7.29.dev20250602104315__py3-none-any.whl → 0.8.0.dev20250604104349__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (138)
  1. letta/__init__.py +7 -1
  2. letta/agent.py +16 -9
  3. letta/agents/base_agent.py +1 -0
  4. letta/agents/ephemeral_summary_agent.py +104 -0
  5. letta/agents/helpers.py +35 -3
  6. letta/agents/letta_agent.py +492 -176
  7. letta/agents/letta_agent_batch.py +22 -16
  8. letta/agents/prompts/summary_system_prompt.txt +62 -0
  9. letta/agents/voice_agent.py +22 -7
  10. letta/agents/voice_sleeptime_agent.py +13 -8
  11. letta/constants.py +33 -1
  12. letta/data_sources/connectors.py +52 -36
  13. letta/errors.py +4 -0
  14. letta/functions/ast_parsers.py +13 -30
  15. letta/functions/function_sets/base.py +3 -1
  16. letta/functions/functions.py +2 -0
  17. letta/functions/mcp_client/base_client.py +151 -97
  18. letta/functions/mcp_client/sse_client.py +49 -31
  19. letta/functions/mcp_client/stdio_client.py +107 -106
  20. letta/functions/schema_generator.py +22 -22
  21. letta/groups/helpers.py +3 -4
  22. letta/groups/sleeptime_multi_agent.py +4 -4
  23. letta/groups/sleeptime_multi_agent_v2.py +22 -0
  24. letta/helpers/composio_helpers.py +16 -0
  25. letta/helpers/converters.py +20 -0
  26. letta/helpers/datetime_helpers.py +1 -6
  27. letta/helpers/tool_rule_solver.py +2 -1
  28. letta/interfaces/anthropic_streaming_interface.py +17 -2
  29. letta/interfaces/openai_chat_completions_streaming_interface.py +1 -0
  30. letta/interfaces/openai_streaming_interface.py +18 -2
  31. letta/jobs/llm_batch_job_polling.py +1 -1
  32. letta/jobs/scheduler.py +1 -1
  33. letta/llm_api/anthropic_client.py +24 -3
  34. letta/llm_api/google_ai_client.py +0 -15
  35. letta/llm_api/google_vertex_client.py +6 -5
  36. letta/llm_api/llm_client_base.py +15 -0
  37. letta/llm_api/openai.py +2 -2
  38. letta/llm_api/openai_client.py +60 -8
  39. letta/orm/__init__.py +2 -0
  40. letta/orm/agent.py +45 -43
  41. letta/orm/base.py +0 -2
  42. letta/orm/block.py +1 -0
  43. letta/orm/custom_columns.py +13 -0
  44. letta/orm/enums.py +5 -0
  45. letta/orm/file.py +3 -1
  46. letta/orm/files_agents.py +68 -0
  47. letta/orm/mcp_server.py +48 -0
  48. letta/orm/message.py +1 -0
  49. letta/orm/organization.py +11 -2
  50. letta/orm/passage.py +25 -10
  51. letta/orm/sandbox_config.py +5 -2
  52. letta/orm/sqlalchemy_base.py +171 -110
  53. letta/prompts/system/memgpt_base.txt +6 -1
  54. letta/prompts/system/memgpt_v2_chat.txt +57 -0
  55. letta/prompts/system/sleeptime.txt +2 -0
  56. letta/prompts/system/sleeptime_v2.txt +28 -0
  57. letta/schemas/agent.py +87 -20
  58. letta/schemas/block.py +7 -1
  59. letta/schemas/file.py +57 -0
  60. letta/schemas/mcp.py +74 -0
  61. letta/schemas/memory.py +5 -2
  62. letta/schemas/message.py +9 -0
  63. letta/schemas/openai/openai.py +0 -6
  64. letta/schemas/providers.py +33 -4
  65. letta/schemas/tool.py +26 -21
  66. letta/schemas/tool_execution_result.py +5 -0
  67. letta/server/db.py +23 -8
  68. letta/server/rest_api/app.py +73 -56
  69. letta/server/rest_api/interface.py +4 -4
  70. letta/server/rest_api/routers/v1/agents.py +132 -47
  71. letta/server/rest_api/routers/v1/blocks.py +3 -2
  72. letta/server/rest_api/routers/v1/embeddings.py +3 -3
  73. letta/server/rest_api/routers/v1/groups.py +3 -3
  74. letta/server/rest_api/routers/v1/jobs.py +14 -17
  75. letta/server/rest_api/routers/v1/organizations.py +10 -10
  76. letta/server/rest_api/routers/v1/providers.py +12 -10
  77. letta/server/rest_api/routers/v1/runs.py +3 -3
  78. letta/server/rest_api/routers/v1/sandbox_configs.py +12 -12
  79. letta/server/rest_api/routers/v1/sources.py +108 -43
  80. letta/server/rest_api/routers/v1/steps.py +8 -6
  81. letta/server/rest_api/routers/v1/tools.py +134 -95
  82. letta/server/rest_api/utils.py +12 -1
  83. letta/server/server.py +272 -73
  84. letta/services/agent_manager.py +246 -313
  85. letta/services/block_manager.py +30 -9
  86. letta/services/context_window_calculator/__init__.py +0 -0
  87. letta/services/context_window_calculator/context_window_calculator.py +150 -0
  88. letta/services/context_window_calculator/token_counter.py +82 -0
  89. letta/services/file_processor/__init__.py +0 -0
  90. letta/services/file_processor/chunker/__init__.py +0 -0
  91. letta/services/file_processor/chunker/llama_index_chunker.py +29 -0
  92. letta/services/file_processor/embedder/__init__.py +0 -0
  93. letta/services/file_processor/embedder/openai_embedder.py +84 -0
  94. letta/services/file_processor/file_processor.py +123 -0
  95. letta/services/file_processor/parser/__init__.py +0 -0
  96. letta/services/file_processor/parser/base_parser.py +9 -0
  97. letta/services/file_processor/parser/mistral_parser.py +54 -0
  98. letta/services/file_processor/types.py +0 -0
  99. letta/services/files_agents_manager.py +184 -0
  100. letta/services/group_manager.py +118 -0
  101. letta/services/helpers/agent_manager_helper.py +76 -21
  102. letta/services/helpers/tool_execution_helper.py +3 -0
  103. letta/services/helpers/tool_parser_helper.py +100 -0
  104. letta/services/identity_manager.py +44 -42
  105. letta/services/job_manager.py +21 -10
  106. letta/services/mcp/base_client.py +5 -2
  107. letta/services/mcp/sse_client.py +3 -5
  108. letta/services/mcp/stdio_client.py +3 -5
  109. letta/services/mcp_manager.py +281 -0
  110. letta/services/message_manager.py +40 -26
  111. letta/services/organization_manager.py +55 -19
  112. letta/services/passage_manager.py +211 -13
  113. letta/services/provider_manager.py +48 -2
  114. letta/services/sandbox_config_manager.py +105 -0
  115. letta/services/source_manager.py +4 -5
  116. letta/services/step_manager.py +9 -6
  117. letta/services/summarizer/summarizer.py +50 -23
  118. letta/services/telemetry_manager.py +7 -0
  119. letta/services/tool_executor/tool_execution_manager.py +11 -52
  120. letta/services/tool_executor/tool_execution_sandbox.py +4 -34
  121. letta/services/tool_executor/tool_executor.py +107 -105
  122. letta/services/tool_manager.py +56 -17
  123. letta/services/tool_sandbox/base.py +39 -92
  124. letta/services/tool_sandbox/e2b_sandbox.py +16 -11
  125. letta/services/tool_sandbox/local_sandbox.py +51 -23
  126. letta/services/user_manager.py +36 -3
  127. letta/settings.py +10 -3
  128. letta/templates/__init__.py +0 -0
  129. letta/templates/sandbox_code_file.py.j2 +47 -0
  130. letta/templates/template_helper.py +16 -0
  131. letta/tracing.py +30 -1
  132. letta/types/__init__.py +7 -0
  133. letta/utils.py +25 -1
  134. {letta_nightly-0.7.29.dev20250602104315.dist-info → letta_nightly-0.8.0.dev20250604104349.dist-info}/METADATA +7 -2
  135. {letta_nightly-0.7.29.dev20250602104315.dist-info → letta_nightly-0.8.0.dev20250604104349.dist-info}/RECORD +138 -112
  136. {letta_nightly-0.7.29.dev20250602104315.dist-info → letta_nightly-0.8.0.dev20250604104349.dist-info}/LICENSE +0 -0
  137. {letta_nightly-0.7.29.dev20250602104315.dist-info → letta_nightly-0.8.0.dev20250604104349.dist-info}/WHEEL +0 -0
  138. {letta_nightly-0.7.29.dev20250602104315.dist-info → letta_nightly-0.8.0.dev20250604104349.dist-info}/entry_points.txt +0 -0
letta/services/message_manager.py

@@ -49,7 +49,7 @@ class MessageManager:
     def get_messages_by_ids(self, message_ids: List[str], actor: PydanticUser) -> List[PydanticMessage]:
         """Fetch messages by ID and return them in the requested order."""
         with db_registry.session() as session:
-            results = MessageModel.list(db_session=session, id=message_ids, organization_id=actor.organization_id, limit=len(message_ids))
+            results = MessageModel.read_multiple(db_session=session, identifiers=message_ids, actor=actor)
             return self._get_messages_by_id_postprocess(results, message_ids)
 
     @enforce_types
@@ -57,10 +57,8 @@ class MessageManager:
     async def get_messages_by_ids_async(self, message_ids: List[str], actor: PydanticUser) -> List[PydanticMessage]:
         """Fetch messages by ID and return them in the requested order. Async version of above function."""
         async with db_registry.async_session() as session:
-            results = await MessageModel.list_async(
-                db_session=session, id=message_ids, organization_id=actor.organization_id, limit=len(message_ids)
-            )
-            return self._get_messages_by_id_postprocess(results, message_ids)
+            results = await MessageModel.read_multiple_async(db_session=session, identifiers=message_ids, actor=actor)
+            return self._get_messages_by_id_postprocess(results, message_ids)
 
     def _get_messages_by_id_postprocess(
         self,
@@ -349,6 +347,29 @@
             ascending=ascending,
         )
 
+    @enforce_types
+    @trace_method
+    async def list_user_messages_for_agent_async(
+        self,
+        agent_id: str,
+        actor: PydanticUser,
+        after: Optional[str] = None,
+        before: Optional[str] = None,
+        query_text: Optional[str] = None,
+        limit: Optional[int] = 50,
+        ascending: bool = True,
+    ) -> List[PydanticMessage]:
+        return await self.list_messages_for_agent_async(
+            agent_id=agent_id,
+            actor=actor,
+            after=after,
+            before=before,
+            query_text=query_text,
+            roles=[MessageRole.user],
+            limit=limit,
+            ascending=ascending,
+        )
+
     @enforce_types
     @trace_method
     def list_messages_for_agent(
@@ -400,24 +421,17 @@
             if group_id:
                 query = query.filter(MessageModel.group_id == group_id)
 
-            # If query_text is provided, filter messages by matching any "text" type content block
-            # whose text includes the query string (case-insensitive).
+            # If query_text is provided, filter messages using subquery + json_array_elements.
             if query_text:
-                dialect_name = session.bind.dialect.name
-
-                if dialect_name == "postgresql":  # using subquery + json_array_elements.
-                    content_element = func.json_array_elements(MessageModel.content).alias("content_element")
-                    subquery_sql = text("content_element->>'type' = 'text' AND content_element->>'text' ILIKE :query_text")
-                    subquery = select(1).select_from(content_element).where(subquery_sql)
-
-                elif dialect_name == "sqlite":  # using `json_each` and JSON path expressions
-                    json_item = func.json_each(MessageModel.content).alias("json_item")
-                    subquery_sql = text(
-                        "json_extract(value, '$.type') = 'text' AND lower(json_extract(value, '$.text')) LIKE lower(:query_text)"
+                content_element = func.json_array_elements(MessageModel.content).alias("content_element")
+                query = query.filter(
+                    exists(
+                        select(1)
+                        .select_from(content_element)
+                        .where(text("content_element->>'type' = 'text' AND content_element->>'text' ILIKE :query_text"))
+                        .params(query_text=f"%{query_text}%")
                     )
-                    subquery = select(1).select_from(json_item).where(subquery_sql)
-
-                query = query.filter(exists(subquery.params(query_text=f"%{query_text}%")))
+                )
 
             # If role(s) are provided, filter messages by those roles.
             if roles:
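Note: the rewritten filter above drops the old SQLite json_each branch, so the text search now assumes PostgreSQL's json_array_elements. A minimal, self-contained sketch of the same pattern is shown below; the Message model and table layout are illustrative stand-ins, not Letta's actual schema.

    from sqlalchemy import JSON, Column, String, exists, func, select, text
    from sqlalchemy.orm import Session, declarative_base

    Base = declarative_base()


    class Message(Base):
        __tablename__ = "messages"
        id = Column(String, primary_key=True)
        # JSON array of content blocks, e.g. [{"type": "text", "text": "hello"}]
        content = Column(JSON)


    def filter_messages_by_text(session: Session, query_text: str):
        # Expand the JSON array into rows (PostgreSQL json_array_elements) and keep
        # messages with at least one "text" block matching the query, case-insensitively.
        content_element = func.json_array_elements(Message.content).alias("content_element")
        stmt = select(Message).where(
            exists(
                select(1)
                .select_from(content_element)
                .where(text("content_element->>'type' = 'text' AND content_element->>'text' ILIKE :query_text"))
                .params(query_text=f"%{query_text}%")
            )
        )
        return session.execute(stmt).scalars().all()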
@@ -557,23 +571,23 @@
 
     @enforce_types
     @trace_method
-    def delete_all_messages_for_agent(self, agent_id: str, actor: PydanticUser) -> int:
+    async def delete_all_messages_for_agent_async(self, agent_id: str, actor: PydanticUser) -> int:
         """
         Efficiently deletes all messages associated with a given agent_id,
         while enforcing permission checks and avoiding any ORM‑level loads.
         """
-        with db_registry.session() as session:
+        async with db_registry.async_session() as session:
             # 1) verify the agent exists and the actor has access
-            AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
+            await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor)
 
             # 2) issue a CORE DELETE against the mapped class
             stmt = (
                 delete(MessageModel).where(MessageModel.agent_id == agent_id).where(MessageModel.organization_id == actor.organization_id)
             )
-            result = session.execute(stmt)
+            result = await session.execute(stmt)
 
             # 3) commit once
-            session.commit()
+            await session.commit()
 
             # 4) return the number of rows deleted
             return result.rowcount
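Note: the async rewrite above keeps the same core-DELETE strategy: no ORM objects are loaded, and the row count is read off the statement result. A rough, self-contained sketch of that pattern with SQLAlchemy's asyncio extension follows; the table, engine URL, and aiosqlite driver are illustrative choices to keep the example runnable, not Letta's configuration.

    import asyncio

    from sqlalchemy import Column, String, delete
    from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
    from sqlalchemy.orm import declarative_base

    Base = declarative_base()


    class Message(Base):
        __tablename__ = "messages"
        id = Column(String, primary_key=True)
        agent_id = Column(String)
        organization_id = Column(String)


    async def delete_all_messages(agent_id: str, organization_id: str) -> int:
        engine = create_async_engine("sqlite+aiosqlite:///:memory:")
        async with engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)

        async with async_sessionmaker(engine)() as session:
            # Bulk DELETE runs in the database; nothing is loaded into the session.
            stmt = delete(Message).where(Message.agent_id == agent_id, Message.organization_id == organization_id)
            result = await session.execute(stmt)
            await session.commit()
            return result.rowcount


    if __name__ == "__main__":
        print(asyncio.run(delete_all_messages("agent-1", "org-1")))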
letta/services/organization_manager.py

@@ -17,9 +17,9 @@ class OrganizationManager:
 
     @enforce_types
     @trace_method
-    def get_default_organization(self) -> PydanticOrganization:
+    async def get_default_organization_async(self) -> PydanticOrganization:
         """Fetch the default organization."""
-        return self.get_organization_by_id(self.DEFAULT_ORG_ID)
+        return await self.get_organization_by_id_async(self.DEFAULT_ORG_ID)
 
     @enforce_types
     @trace_method
@@ -29,52 +29,80 @@
             organization = OrganizationModel.read(db_session=session, identifier=org_id)
             return organization.to_pydantic()
 
+    @enforce_types
+    @trace_method
+    async def get_organization_by_id_async(self, org_id: str) -> Optional[PydanticOrganization]:
+        """Fetch an organization by ID."""
+        async with db_registry.async_session() as session:
+            organization = await OrganizationModel.read_async(db_session=session, identifier=org_id)
+            return organization.to_pydantic()
+
     @enforce_types
     @trace_method
     def create_organization(self, pydantic_org: PydanticOrganization) -> PydanticOrganization:
+        """Create the default organization."""
+        with db_registry.session() as session:
+            try:
+                organization = OrganizationModel.read(db_session=session, identifier=pydantic_org.id)
+                return organization.to_pydantic()
+            except:
+                organization = OrganizationModel(**pydantic_org.model_dump(to_orm=True))
+                organization = organization.create(session)
+                return organization.to_pydantic()
+
+    @enforce_types
+    @trace_method
+    async def create_organization_async(self, pydantic_org: PydanticOrganization) -> PydanticOrganization:
         """Create a new organization."""
         try:
-            org = self.get_organization_by_id(pydantic_org.id)
+            org = await self.get_organization_by_id_async(pydantic_org.id)
             return org
         except NoResultFound:
-            return self._create_organization(pydantic_org=pydantic_org)
+            return await self._create_organization_async(pydantic_org=pydantic_org)
 
     @enforce_types
     @trace_method
-    def _create_organization(self, pydantic_org: PydanticOrganization) -> PydanticOrganization:
-        with db_registry.session() as session:
+    async def _create_organization_async(self, pydantic_org: PydanticOrganization) -> PydanticOrganization:
+        async with db_registry.async_session() as session:
            org = OrganizationModel(**pydantic_org.model_dump(to_orm=True))
-            org.create(session)
+            await org.create_async(session)
            return org.to_pydantic()
 
     @enforce_types
     @trace_method
     def create_default_organization(self) -> PydanticOrganization:
         """Create the default organization."""
-        return self.create_organization(PydanticOrganization(name=self.DEFAULT_ORG_NAME, id=self.DEFAULT_ORG_ID))
+        pydantic_org = PydanticOrganization(name=self.DEFAULT_ORG_NAME, id=self.DEFAULT_ORG_ID)
+        return self.create_organization(pydantic_org)
 
     @enforce_types
     @trace_method
-    def update_organization_name_using_id(self, org_id: str, name: Optional[str] = None) -> PydanticOrganization:
+    async def create_default_organization_async(self) -> PydanticOrganization:
+        """Create the default organization."""
+        return await self.create_organization_async(PydanticOrganization(name=self.DEFAULT_ORG_NAME, id=self.DEFAULT_ORG_ID))
+
+    @enforce_types
+    @trace_method
+    async def update_organization_name_using_id_async(self, org_id: str, name: Optional[str] = None) -> PydanticOrganization:
         """Update an organization."""
-        with db_registry.session() as session:
-            org = OrganizationModel.read(db_session=session, identifier=org_id)
+        async with db_registry.async_session() as session:
+            org = await OrganizationModel.read_async(db_session=session, identifier=org_id)
             if name:
                 org.name = name
-            org.update(session)
+            await org.update_async(session)
             return org.to_pydantic()
 
     @enforce_types
     @trace_method
-    def update_organization(self, org_id: str, org_update: OrganizationUpdate) -> PydanticOrganization:
+    async def update_organization_async(self, org_id: str, org_update: OrganizationUpdate) -> PydanticOrganization:
         """Update an organization."""
-        with db_registry.session() as session:
-            org = OrganizationModel.read(db_session=session, identifier=org_id)
+        async with db_registry.async_session() as session:
+            org = await OrganizationModel.read_async(db_session=session, identifier=org_id)
             if org_update.name:
                 org.name = org_update.name
             if org_update.privileged_tools:
                 org.privileged_tools = org_update.privileged_tools
-            org.update(session)
+            await org.update_async(session)
             return org.to_pydantic()
 
     @enforce_types
@@ -87,10 +115,18 @@
 
     @enforce_types
     @trace_method
-    def list_organizations(self, after: Optional[str] = None, limit: Optional[int] = 50) -> List[PydanticOrganization]:
+    async def delete_organization_by_id_async(self, org_id: str):
+        """Delete an organization by marking it as deleted."""
+        async with db_registry.async_session() as session:
+            organization = await OrganizationModel.read_async(db_session=session, identifier=org_id)
+            await organization.hard_delete_async(session)
+
+    @enforce_types
+    @trace_method
+    async def list_organizations_async(self, after: Optional[str] = None, limit: Optional[int] = 50) -> List[PydanticOrganization]:
         """List all organizations with optional pagination."""
-        with db_registry.session() as session:
-            organizations = OrganizationModel.list(
+        async with db_registry.async_session() as session:
+            organizations = await OrganizationModel.list_async(
                 db_session=session,
                 after=after,
                 limit=limit,
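Note: callers need a matching migration, since the synchronous organization accessors are replaced by *_async variants that must be awaited. A hedged sketch of the new calling convention follows; the import path matches the file list above, while the no-argument constructor is an assumption made for illustration.

    import asyncio

    from letta.services.organization_manager import OrganizationManager


    async def main() -> None:
        manager = OrganizationManager()  # constructor arguments assumed for this sketch

        # 0.7.x (sync):  org = manager.get_default_organization()
        # 0.8.0 (async): the same lookup is awaited inside an event loop.
        org = await manager.get_default_organization_async()
        orgs = await manager.list_organizations_async(limit=10)
        print(org.id, len(orgs))


    asyncio.run(main())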
letta/services/passage_manager.py

@@ -1,7 +1,11 @@
+import asyncio
 from datetime import datetime, timezone
+from functools import lru_cache
 from typing import List, Optional
 
-from openai import OpenAI
+from async_lru import alru_cache
+from openai import AsyncOpenAI, OpenAI
+from sqlalchemy import select
 
 from letta.constants import MAX_EMBEDDING_DIM
 from letta.embeddings import embedding_model, parse_and_chunk_text
@@ -15,6 +19,26 @@ from letta.tracing import trace_method
 from letta.utils import enforce_types
 
 
+# TODO: Add redis-backed caching for backend
+@lru_cache(maxsize=8192)
+def get_openai_embedding(text: str, model: str, endpoint: str) -> List[float]:
+    from letta.settings import model_settings
+
+    client = OpenAI(api_key=model_settings.openai_api_key, base_url=endpoint, max_retries=0)
+    response = client.embeddings.create(input=text, model=model)
+    return response.data[0].embedding
+
+
+# TODO: Add redis-backed caching for backend
+@alru_cache(maxsize=8192)
+async def get_openai_embedding_async(text: str, model: str, endpoint: str) -> List[float]:
+    from letta.settings import model_settings
+
+    client = AsyncOpenAI(api_key=model_settings.openai_api_key, base_url=endpoint, max_retries=0)
+    response = await client.embeddings.create(input=text, model=model)
+    return response.data[0].embedding
+
+
 class PassageManager:
     """Manager class to handle business logic related to Passages."""
 
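Note: the two helpers above memoize embeddings in-process, functools.lru_cache on the sync path and alru_cache from the async-lru package on the async path, keyed on the (text, model, endpoint) argument tuple; the Redis-backed cache is left as a TODO. A small runnable sketch of the async caching behavior, with a stand-in coroutine in place of the AsyncOpenAI call:

    import asyncio

    from async_lru import alru_cache

    CALLS = 0


    @alru_cache(maxsize=8192)
    async def fake_embedding(text: str, model: str, endpoint: str) -> list:
        # Stand-in for the AsyncOpenAI embeddings call; counts real invocations.
        global CALLS
        CALLS += 1
        return [float(len(text))]


    async def main() -> None:
        await fake_embedding("hello", "text-embedding-3-small", "https://api.openai.com/v1")
        await fake_embedding("hello", "text-embedding-3-small", "https://api.openai.com/v1")
        print(CALLS)  # 1 -- the second call is served from the cache


    asyncio.run(main())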
@@ -35,11 +59,45 @@
         except NoResultFound:
             raise NoResultFound(f"Passage with id {passage_id} not found in database.")
 
+    @enforce_types
+    @trace_method
+    async def get_passage_by_id_async(self, passage_id: str, actor: PydanticUser) -> Optional[PydanticPassage]:
+        """Fetch a passage by ID."""
+        async with db_registry.async_session() as session:
+            # Try source passages first
+            try:
+                passage = await SourcePassage.read_async(db_session=session, identifier=passage_id, actor=actor)
+                return passage.to_pydantic()
+            except NoResultFound:
+                # Try archival passages
+                try:
+                    passage = await AgentPassage.read_async(db_session=session, identifier=passage_id, actor=actor)
+                    return passage.to_pydantic()
+                except NoResultFound:
+                    raise NoResultFound(f"Passage with id {passage_id} not found in database.")
+
     @enforce_types
     @trace_method
     def create_passage(self, pydantic_passage: PydanticPassage, actor: PydanticUser) -> PydanticPassage:
+        """Create a new passage in the appropriate table based on whether it has agent_id or source_id."""
+        passage = self._preprocess_passage_for_creation(pydantic_passage=pydantic_passage)
+
+        with db_registry.session() as session:
+            passage.create(session, actor=actor)
+            return passage.to_pydantic()
+
+    @enforce_types
+    @trace_method
+    async def create_passage_async(self, pydantic_passage: PydanticPassage, actor: PydanticUser) -> PydanticPassage:
         """Create a new passage in the appropriate table based on whether it has agent_id or source_id."""
         # Common fields for both passage types
+        passage = self._preprocess_passage_for_creation(pydantic_passage=pydantic_passage)
+        async with db_registry.async_session() as session:
+            passage = await passage.create_async(session, actor=actor)
+            return passage.to_pydantic()
+
+    @trace_method
+    def _preprocess_passage_for_creation(self, pydantic_passage: PydanticPassage) -> "SqlAlchemyBase":
         data = pydantic_passage.model_dump(to_orm=True)
         common_fields = {
             "id": data.get("id"),
@@ -68,9 +126,7 @@
         else:
             raise ValueError("Passage must have either agent_id or source_id")
 
-        with db_registry.session() as session:
-            passage.create(session, actor=actor)
-            return passage.to_pydantic()
+        return passage
 
     @enforce_types
     @trace_method
@@ -78,6 +134,33 @@
         """Create multiple passages."""
         return [self.create_passage(p, actor) for p in passages]
 
+    @enforce_types
+    @trace_method
+    async def create_many_passages_async(self, passages: List[PydanticPassage], actor: PydanticUser) -> List[PydanticPassage]:
+        """Create multiple passages."""
+        async with db_registry.async_session() as session:
+            agent_passages = []
+            source_passages = []
+
+            for p in passages:
+                model = self._preprocess_passage_for_creation(p)
+                if isinstance(model, AgentPassage):
+                    agent_passages.append(model)
+                elif isinstance(model, SourcePassage):
+                    source_passages.append(model)
+                else:
+                    raise TypeError(f"Unexpected passage type: {type(model)}")
+
+            results = []
+            if agent_passages:
+                agent_created = await AgentPassage.batch_create_async(items=agent_passages, db_session=session, actor=actor)
+                results.extend(agent_created)
+            if source_passages:
+                source_created = await SourcePassage.batch_create_async(items=source_passages, db_session=session, actor=actor)
+                results.extend(source_created)
+
+            return [p.to_pydantic() for p in results]
+
     @enforce_types
     @trace_method
     def insert_passage(
@@ -106,14 +189,11 @@
                     embedding = embed_model.get_text_embedding(text)
                 else:
                     # TODO should have the settings passed in via the server call
-                    from letta.settings import model_settings
-
-                    # Simple OpenAI client code
-                    client = OpenAI(
-                        api_key=model_settings.openai_api_key, base_url=agent_state.embedding_config.embedding_endpoint, max_retries=0
+                    embedding = get_openai_embedding(
+                        text,
+                        agent_state.embedding_config.embedding_model,
+                        agent_state.embedding_config.embedding_endpoint,
                     )
-                    response = client.embeddings.create(input=text, model=agent_state.embedding_config.embedding_model)
-                    embedding = response.data[0].embedding
 
                 if isinstance(embedding, dict):
                     try:
@@ -140,6 +220,78 @@
         except Exception as e:
             raise e
 
+    @enforce_types
+    @trace_method
+    async def insert_passage_async(
+        self,
+        agent_state: AgentState,
+        agent_id: str,
+        text: str,
+        actor: PydanticUser,
+    ) -> List[PydanticPassage]:
+        """Insert passage(s) into archival memory"""
+
+        embedding_chunk_size = agent_state.embedding_config.embedding_chunk_size
+        text_chunks = list(parse_and_chunk_text(text, embedding_chunk_size))
+
+        if not text_chunks:
+            return []
+
+        try:
+            embeddings = await self._generate_embeddings_concurrent(text_chunks, agent_state.embedding_config)
+
+            passages = [
+                PydanticPassage(
+                    organization_id=actor.organization_id,
+                    agent_id=agent_id,
+                    text=chunk_text,
+                    embedding=embedding,
+                    embedding_config=agent_state.embedding_config,
+                )
+                for chunk_text, embedding in zip(text_chunks, embeddings)
+            ]
+
+            passages = await self.create_many_passages_async(passages=passages, actor=actor)
+
+            return passages
+
+        except Exception as e:
+            raise e
+
+    async def _generate_embeddings_concurrent(self, text_chunks: List[str], embedding_config) -> List[List[float]]:
+        """Generate embeddings for all text chunks concurrently"""
+
+        if embedding_config.embedding_endpoint_type != "openai":
+            embed_model = embedding_model(embedding_config)
+            loop = asyncio.get_event_loop()
+
+            tasks = [loop.run_in_executor(None, embed_model.get_text_embedding, text) for text in text_chunks]
+            embeddings = await asyncio.gather(*tasks)
+        else:
+            tasks = [
+                get_openai_embedding_async(
+                    text,
+                    embedding_config.embedding_model,
+                    embedding_config.embedding_endpoint,
+                )
+                for text in text_chunks
+            ]
+            embeddings = await asyncio.gather(*tasks)
+
+        processed_embeddings = []
+        for embedding in embeddings:
+            if isinstance(embedding, dict):
+                try:
+                    processed_embeddings.append(embedding["data"][0]["embedding"])
+                except (KeyError, IndexError):
+                    raise TypeError(
+                        f"Got back an unexpected payload from text embedding function, type={type(embedding)}, value={embedding}"
+                    )
+            else:
+                processed_embeddings.append(embedding)
+
+        return processed_embeddings
+
     @enforce_types
     @trace_method
     def update_passage_by_id(self, passage_id: str, passage: PydanticPassage, actor: PydanticUser, **kwargs) -> Optional[PydanticPassage]:
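Note: _generate_embeddings_concurrent above fans the per-chunk embedding calls out with asyncio.gather, pushing synchronous embedders onto the default thread-pool executor via run_in_executor. A minimal sketch of that fan-out shape, with a dummy blocking embedder standing in for a real model:

    import asyncio
    import time
    from typing import List


    def blocking_embed(text: str) -> List[float]:
        # Stand-in for a synchronous embedding model call.
        time.sleep(0.1)
        return [float(len(text))]


    async def embed_chunks_concurrently(chunks: List[str]) -> List[List[float]]:
        loop = asyncio.get_event_loop()
        # Each blocking call runs in the default executor; gather preserves input order.
        tasks = [loop.run_in_executor(None, blocking_embed, chunk) for chunk in chunks]
        return await asyncio.gather(*tasks)


    if __name__ == "__main__":
        embeddings = asyncio.run(embed_chunks_concurrently(["one", "two", "three"]))
        print(len(embeddings))  # 3 embeddings, computed concurrently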
@@ -197,6 +349,28 @@
         except NoResultFound:
             raise NoResultFound(f"Passage with id {passage_id} not found.")
 
+    @enforce_types
+    @trace_method
+    async def delete_passage_by_id_async(self, passage_id: str, actor: PydanticUser) -> bool:
+        """Delete a passage from either source or archival passages."""
+        if not passage_id:
+            raise ValueError("Passage ID must be provided.")
+
+        async with db_registry.async_session() as session:
+            # Try source passages first
+            try:
+                passage = await SourcePassage.read_async(db_session=session, identifier=passage_id, actor=actor)
+                await passage.hard_delete_async(session, actor=actor)
+                return True
+            except NoResultFound:
+                # Try archival passages
+                try:
+                    passage = await AgentPassage.read_async(db_session=session, identifier=passage_id, actor=actor)
+                    await passage.hard_delete_async(session, actor=actor)
+                    return True
+                except NoResultFound:
+                    raise NoResultFound(f"Passage with id {passage_id} not found.")
+
     @enforce_types
     @trace_method
     def delete_passages(
@@ -210,6 +384,17 @@
             self.delete_passage_by_id(passage_id=passage.id, actor=actor)
         return True
 
+    @enforce_types
+    @trace_method
+    async def delete_source_passages_async(
+        self,
+        actor: PydanticUser,
+        passages: List[PydanticPassage],
+    ) -> bool:
+        async with db_registry.async_session() as session:
+            await SourcePassage.bulk_hard_delete_async(db_session=session, identifiers=[p.id for p in passages], actor=actor)
+            return True
+
     @enforce_types
     @trace_method
     def size(
@@ -243,7 +428,7 @@
 
     @enforce_types
     @trace_method
-    def estimate_embeddings_size(
+    async def estimate_embeddings_size_async(
         self,
         actor: PydanticUser,
         agent_id: Optional[str] = None,
@@ -263,4 +448,17 @@
             raise ValueError(f"Invalid storage unit: {storage_unit}. Must be one of {list(BYTES_PER_STORAGE_UNIT.keys())}.")
         BYTES_PER_EMBEDDING_DIM = 4
         GB_PER_EMBEDDING = BYTES_PER_EMBEDDING_DIM / BYTES_PER_STORAGE_UNIT[storage_unit] * MAX_EMBEDDING_DIM
-        return self.size(actor=actor, agent_id=agent_id) * GB_PER_EMBEDDING
+        return await self.size_async(actor=actor, agent_id=agent_id) * GB_PER_EMBEDDING
+
+    @enforce_types
+    @trace_method
+    async def list_passages_by_file_id_async(self, file_id: str, actor: PydanticUser) -> List[PydanticPassage]:
+        """
+        List all source passages associated with a given file_id.
+        """
+        async with db_registry.async_session() as session:
+            result = await session.execute(
+                select(SourcePassage).where(SourcePassage.file_id == file_id).where(SourcePassage.organization_id == actor.organization_id)
+            )
+            passages = result.scalars().all()
+            return [p.to_pydantic() for p in passages]
letta/services/provider_manager.py

@@ -33,13 +33,34 @@ class ProviderManager:
             new_provider.create(session, actor=actor)
             return new_provider.to_pydantic()
 
+    @enforce_types
+    @trace_method
+    async def create_provider_async(self, request: ProviderCreate, actor: PydanticUser) -> PydanticProvider:
+        """Create a new provider if it doesn't already exist."""
+        async with db_registry.async_session() as session:
+            provider_create_args = {**request.model_dump(), "provider_category": ProviderCategory.byok}
+            provider = PydanticProvider(**provider_create_args)
+
+            if provider.name == provider.provider_type.value:
+                raise ValueError("Provider name must be unique and different from provider type")
+
+            # Assign the organization id based on the actor
+            provider.organization_id = actor.organization_id
+
+            # Lazily create the provider id prior to persistence
+            provider.resolve_identifier()
+
+            new_provider = ProviderModel(**provider.model_dump(to_orm=True, exclude_unset=True))
+            await new_provider.create_async(session, actor=actor)
+            return new_provider.to_pydantic()
+
     @enforce_types
     @trace_method
     def update_provider(self, provider_id: str, provider_update: ProviderUpdate, actor: PydanticUser) -> PydanticProvider:
         """Update provider details."""
         with db_registry.session() as session:
             # Retrieve the existing provider by ID
-            existing_provider = ProviderModel.read(db_session=session, identifier=provider_id, actor=actor)
+            existing_provider = ProviderModel.read(db_session=session, identifier=provider_id, actor=actor, check_is_deleted=True)
 
             # Update only the fields that are provided in ProviderUpdate
             update_data = provider_update.model_dump(to_orm=True, exclude_unset=True, exclude_none=True)
@@ -56,7 +77,7 @@
         """Delete a provider."""
         with db_registry.session() as session:
             # Clear api key field
-            existing_provider = ProviderModel.read(db_session=session, identifier=provider_id, actor=actor)
+            existing_provider = ProviderModel.read(db_session=session, identifier=provider_id, actor=actor, check_is_deleted=True)
             existing_provider.api_key = None
             existing_provider.update(session, actor=actor)
 
@@ -65,6 +86,23 @@
 
             session.commit()
 
+    @enforce_types
+    @trace_method
+    async def delete_provider_by_id_async(self, provider_id: str, actor: PydanticUser):
+        """Delete a provider."""
+        async with db_registry.async_session() as session:
+            # Clear api key field
+            existing_provider = await ProviderModel.read_async(
+                db_session=session, identifier=provider_id, actor=actor, check_is_deleted=True
+            )
+            existing_provider.api_key = None
+            await existing_provider.update_async(session, actor=actor)
+
+            # Soft delete in provider table
+            await existing_provider.delete_async(session, actor=actor)
+
+            await session.commit()
+
     @enforce_types
     @trace_method
     def list_providers(
@@ -87,6 +125,7 @@
                 after=after,
                 limit=limit,
                 actor=actor,
+                check_is_deleted=True,
                 **filter_kwargs,
             )
             return [provider.to_pydantic() for provider in providers]
@@ -113,6 +152,7 @@
                 after=after,
                 limit=limit,
                 actor=actor,
+                check_is_deleted=True,
                 **filter_kwargs,
            )
            return [provider.to_pydantic() for provider in providers]
@@ -129,6 +169,12 @@
         providers = self.list_providers(name=provider_name, actor=actor)
         return providers[0].api_key if providers else None
 
+    @enforce_types
+    @trace_method
+    async def get_override_key_async(self, provider_name: Union[str, None], actor: PydanticUser) -> Optional[str]:
+        providers = await self.list_providers_async(name=provider_name, actor=actor)
+        return providers[0].api_key if providers else None
+
     @enforce_types
     @trace_method
     def check_provider_api_key(self, provider_check: ProviderCheck) -> None: