letta-nightly 0.8.3.dev20250612104349__py3-none-any.whl → 0.8.4.dev20250614104137__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57)
  1. letta/__init__.py +1 -1
  2. letta/agent.py +11 -1
  3. letta/agents/base_agent.py +11 -4
  4. letta/agents/ephemeral_summary_agent.py +3 -2
  5. letta/agents/letta_agent.py +109 -78
  6. letta/agents/letta_agent_batch.py +4 -3
  7. letta/agents/voice_agent.py +3 -3
  8. letta/agents/voice_sleeptime_agent.py +3 -2
  9. letta/client/client.py +6 -3
  10. letta/constants.py +6 -0
  11. letta/data_sources/connectors.py +3 -5
  12. letta/functions/async_composio_toolset.py +4 -1
  13. letta/functions/function_sets/files.py +4 -3
  14. letta/functions/schema_generator.py +5 -2
  15. letta/groups/sleeptime_multi_agent_v2.py +4 -3
  16. letta/helpers/converters.py +7 -1
  17. letta/helpers/message_helper.py +31 -11
  18. letta/helpers/tool_rule_solver.py +69 -4
  19. letta/interfaces/anthropic_streaming_interface.py +8 -1
  20. letta/interfaces/openai_streaming_interface.py +4 -1
  21. letta/llm_api/anthropic_client.py +4 -4
  22. letta/llm_api/openai_client.py +56 -11
  23. letta/local_llm/utils.py +3 -20
  24. letta/orm/sqlalchemy_base.py +7 -1
  25. letta/otel/metric_registry.py +26 -0
  26. letta/otel/metrics.py +78 -14
  27. letta/schemas/letta_message_content.py +64 -3
  28. letta/schemas/letta_request.py +5 -1
  29. letta/schemas/message.py +61 -14
  30. letta/schemas/openai/chat_completion_request.py +1 -1
  31. letta/schemas/providers.py +41 -14
  32. letta/schemas/tool_rule.py +67 -0
  33. letta/schemas/user.py +2 -2
  34. letta/server/rest_api/routers/v1/agents.py +22 -12
  35. letta/server/rest_api/routers/v1/sources.py +13 -25
  36. letta/server/server.py +10 -5
  37. letta/services/agent_manager.py +5 -1
  38. letta/services/file_manager.py +219 -0
  39. letta/services/file_processor/chunker/line_chunker.py +119 -14
  40. letta/services/file_processor/file_processor.py +8 -8
  41. letta/services/file_processor/file_types.py +303 -0
  42. letta/services/file_processor/parser/mistral_parser.py +2 -11
  43. letta/services/helpers/agent_manager_helper.py +6 -0
  44. letta/services/message_manager.py +32 -0
  45. letta/services/organization_manager.py +4 -6
  46. letta/services/passage_manager.py +1 -0
  47. letta/services/source_manager.py +0 -208
  48. letta/services/tool_executor/composio_tool_executor.py +5 -1
  49. letta/services/tool_executor/files_tool_executor.py +291 -15
  50. letta/services/user_manager.py +8 -8
  51. letta/system.py +3 -1
  52. letta/utils.py +7 -13
  53. {letta_nightly-0.8.3.dev20250612104349.dist-info → letta_nightly-0.8.4.dev20250614104137.dist-info}/METADATA +2 -2
  54. {letta_nightly-0.8.3.dev20250612104349.dist-info → letta_nightly-0.8.4.dev20250614104137.dist-info}/RECORD +57 -55
  55. {letta_nightly-0.8.3.dev20250612104349.dist-info → letta_nightly-0.8.4.dev20250614104137.dist-info}/LICENSE +0 -0
  56. {letta_nightly-0.8.3.dev20250612104349.dist-info → letta_nightly-0.8.4.dev20250614104137.dist-info}/WHEEL +0 -0
  57. {letta_nightly-0.8.3.dev20250612104349.dist-info → letta_nightly-0.8.4.dev20250614104137.dist-info}/entry_points.txt +0 -0
letta/server/rest_api/routers/v1/agents.py CHANGED
@@ -12,7 +12,7 @@ from sqlalchemy.exc import IntegrityError, OperationalError
  from starlette.responses import Response, StreamingResponse
 
  from letta.agents.letta_agent import LettaAgent
- from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
+ from letta.constants import DEFAULT_MAX_STEPS, DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
  from letta.groups.sleeptime_multi_agent_v2 import SleeptimeMultiAgentV2
  from letta.helpers.datetime_helpers import get_utc_timestamp_ns
  from letta.log import get_logger
@@ -316,7 +316,7 @@ async def attach_source(
  # Check if the agent is missing any files tools
  agent_state = await server.agent_manager.attach_missing_files_tools_async(agent_state=agent_state, actor=actor)
 
- files = await server.source_manager.list_files(source_id, actor, include_content=True)
+ files = await server.file_manager.list_files(source_id, actor, include_content=True)
  texts = []
  file_ids = []
  file_names = []
@@ -354,7 +354,7 @@ async def detach_source(
  if not agent_state.sources:
  agent_state = await server.agent_manager.detach_all_files_tools_async(agent_state=agent_state, actor=actor)
 
- files = await server.source_manager.list_files(source_id, actor)
+ files = await server.file_manager.list_files(source_id, actor)
  file_ids = [f.id for f in files]
  await server.remove_files_from_context_window(agent_state=agent_state, file_ids=file_ids, actor=actor)
 
@@ -371,6 +371,14 @@ async def detach_source(
  @router.get("/{agent_id}", response_model=AgentState, operation_id="retrieve_agent")
  async def retrieve_agent(
  agent_id: str,
+ include_relationships: Optional[List[str]] = Query(
+ None,
+ description=(
+ "Specify which relational fields (e.g., 'tools', 'sources', 'memory') to include in the response. "
+ "If not provided, all relationships are loaded by default. "
+ "Using this can optimize performance by reducing unnecessary joins."
+ ),
+ ),
  server: "SyncServer" = Depends(get_letta_server),
  actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
  ):
@@ -380,7 +388,7 @@ async def retrieve_agent(
  actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
 
  try:
- return await server.agent_manager.get_agent_by_id_async(agent_id=agent_id, actor=actor)
+ return await server.agent_manager.get_agent_by_id_async(agent_id=agent_id, include_relationships=include_relationships, actor=actor)
  except NoResultFound as e:
  raise HTTPException(status_code=404, detail=str(e))
 
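Note: the following is an illustrative sketch (not part of the diff) of how a client could use the new include_relationships query parameter on retrieve_agent; the base URL, agent ID, and user_id header value are assumptions.

    import requests

    # Request only the 'tools' and 'memory' relationships to avoid unnecessary joins.
    resp = requests.get(
        "http://localhost:8283/v1/agents/agent-123",
        params={"include_relationships": ["tools", "memory"]},
        headers={"user_id": "user-123"},
    )
    resp.raise_for_status()
    agent_state = resp.json()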
@@ -665,13 +673,13 @@ async def send_message(
  Process a user message and return the agent's response.
  This endpoint accepts a message from a user and processes it through the agent.
  """
+ request_start_timestamp_ns = get_utc_timestamp_ns()
  MetricRegistry().user_message_counter.add(1, get_ctx_attributes())
 
  actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
- request_start_timestamp_ns = get_utc_timestamp_ns()
  # TODO: This is redundant, remove soon
  agent = await server.agent_manager.get_agent_by_id_async(agent_id, actor, include_relationships=["multi_agent_group"])
- agent_eligible = agent.enable_sleeptime or agent.agent_type == AgentType.sleeptime_agent or not agent.multi_agent_group
+ agent_eligible = agent.multi_agent_group is None or agent.multi_agent_group.manager_type in ["sleeptime", "voice_sleeptime"]
  model_compatible = agent.llm_config.model_endpoint_type in ["anthropic", "openai", "together", "google_ai", "google_vertex"]
 
  if agent_eligible and model_compatible:
@@ -701,7 +709,7 @@ async def send_message(
 
  result = await agent_loop.step(
  request.messages,
- max_steps=10,
+ max_steps=request.max_steps,
  use_assistant_message=request.use_assistant_message,
  request_start_timestamp_ns=request_start_timestamp_ns,
  include_return_message_types=request.include_return_message_types,
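Note: an illustrative sketch (not part of the diff) of the new per-request max_steps override used by send_message; the payload shape, URL, and IDs are assumptions based on the request fields referenced above.

    import requests

    # Cap the agent loop at 5 steps for this message instead of the server default.
    payload = {
        "messages": [{"role": "user", "content": "Summarize my unread files."}],
        "max_steps": 5,
    }
    resp = requests.post(
        "http://localhost:8283/v1/agents/agent-123/messages",
        json=payload,
        headers={"user_id": "user-123"},
    )
    resp.raise_for_status()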
@@ -747,16 +755,16 @@ async def send_message_streaming(
  This endpoint accepts a message from a user and processes it through the agent.
  It will stream the steps of the response always, and stream the tokens if 'stream_tokens' is set to True.
  """
+ request_start_timestamp_ns = get_utc_timestamp_ns()
  MetricRegistry().user_message_counter.add(1, get_ctx_attributes())
 
  actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
  # TODO: This is redundant, remove soon
  agent = await server.agent_manager.get_agent_by_id_async(agent_id, actor, include_relationships=["multi_agent_group"])
- agent_eligible = agent.enable_sleeptime or agent.agent_type == AgentType.sleeptime_agent or not agent.multi_agent_group
+ agent_eligible = agent.multi_agent_group is None or agent.multi_agent_group.manager_type in ["sleeptime", "voice_sleeptime"]
  model_compatible = agent.llm_config.model_endpoint_type in ["anthropic", "openai", "together", "google_ai", "google_vertex"]
  model_compatible_token_streaming = agent.llm_config.model_endpoint_type in ["anthropic", "openai"]
  not_letta_endpoint = not ("inference.letta.com" in agent.llm_config.model_endpoint)
- request_start_timestamp_ns = get_utc_timestamp_ns()
 
  if agent_eligible and model_compatible:
  if agent.enable_sleeptime and agent.agent_type != AgentType.voice_convo_agent:
@@ -790,7 +798,7 @@ async def send_message_streaming(
  result = StreamingResponseWithStatusCode(
  agent_loop.step_stream(
  input_messages=request.messages,
- max_steps=10,
+ max_steps=request.max_steps,
  use_assistant_message=request.use_assistant_message,
  request_start_timestamp_ns=request_start_timestamp_ns,
  include_return_message_types=request.include_return_message_types,
@@ -801,7 +809,7 @@ async def send_message_streaming(
  result = StreamingResponseWithStatusCode(
  agent_loop.step_stream_no_tokens(
  request.messages,
- max_steps=10,
+ max_steps=request.max_steps,
  use_assistant_message=request.use_assistant_message,
  request_start_timestamp_ns=request_start_timestamp_ns,
  include_return_message_types=request.include_return_message_types,
@@ -835,6 +843,7 @@ async def process_message_background(
  use_assistant_message: bool,
  assistant_message_tool_name: str,
  assistant_message_tool_kwarg: str,
+ max_steps: int = DEFAULT_MAX_STEPS,
  include_return_message_types: Optional[List[MessageType]] = None,
  ) -> None:
  """Background task to process the message and update job status."""
@@ -919,6 +928,7 @@ async def send_message_async(
  use_assistant_message=request.use_assistant_message,
  assistant_message_tool_name=request.assistant_message_tool_name,
  assistant_message_tool_kwarg=request.assistant_message_tool_kwarg,
+ max_steps=request.max_steps,
  include_return_message_types=request.include_return_message_types,
  )
 
@@ -969,7 +979,7 @@ async def summarize_agent_conversation(
 
  actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
  agent = await server.agent_manager.get_agent_by_id_async(agent_id, actor, include_relationships=["multi_agent_group"])
- agent_eligible = agent.enable_sleeptime or agent.agent_type == AgentType.sleeptime_agent or not agent.multi_agent_group
+ agent_eligible = agent.multi_agent_group is None or agent.multi_agent_group.manager_type in ["sleeptime", "voice_sleeptime"]
  model_compatible = agent.llm_config.model_endpoint_type in ["anthropic", "openai", "together", "google_ai", "google_vertex"]
 
  if agent_eligible and model_compatible:
letta/server/rest_api/routers/v1/sources.py CHANGED
@@ -21,16 +21,15 @@ from letta.server.server import SyncServer
  from letta.services.file_processor.chunker.llama_index_chunker import LlamaIndexChunker
  from letta.services.file_processor.embedder.openai_embedder import OpenAIEmbedder
  from letta.services.file_processor.file_processor import FileProcessor
+ from letta.services.file_processor.file_types import get_allowed_media_types, get_extension_to_mime_type_map, register_mime_types
  from letta.services.file_processor.parser.mistral_parser import MistralFileParser
  from letta.settings import model_settings, settings
  from letta.utils import safe_create_task, sanitize_filename
 
  logger = get_logger(__name__)
 
- mimetypes.add_type("text/markdown", ".md")
- mimetypes.add_type("text/markdown", ".markdown")
- mimetypes.add_type("application/jsonl", ".jsonl")
- mimetypes.add_type("application/x-jsonlines", ".jsonl")
+ # Register all supported file types with Python's mimetypes module
+ register_mime_types()
 
 
  router = APIRouter(prefix="/sources", tags=["sources"])
@@ -154,7 +153,7 @@ async def delete_source(
  actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
  source = await server.source_manager.get_source_by_id(source_id=source_id, actor=actor)
  agent_states = await server.source_manager.list_attached_agents(source_id=source_id, actor=actor)
- files = await server.source_manager.list_files(source_id, actor)
+ files = await server.file_manager.list_files(source_id, actor)
  file_ids = [f.id for f in files]
 
  for agent_state in agent_states:
@@ -179,15 +178,7 @@ async def upload_file_to_source(
  """
  Upload a file to a data source.
  """
- allowed_media_types = {
- "application/pdf",
- "text/plain",
- "text/markdown",
- "text/x-markdown",
- "application/json",
- "application/jsonl",
- "application/x-jsonlines",
- }
+ allowed_media_types = get_allowed_media_types()
 
  # Normalize incoming Content-Type header (strip charset or any parameters).
  raw_ct = file.content_type or ""
@@ -201,21 +192,18 @@ async def upload_file_to_source(
 
  if media_type not in allowed_media_types:
  ext = Path(file.filename).suffix.lower()
- ext_map = {
- ".pdf": "application/pdf",
- ".txt": "text/plain",
- ".json": "application/json",
- ".md": "text/markdown",
- ".markdown": "text/markdown",
- ".jsonl": "application/jsonl",
- }
+ ext_map = get_extension_to_mime_type_map()
  media_type = ext_map.get(ext, media_type)
 
  # If still not allowed, reject with 415.
  if media_type not in allowed_media_types:
  raise HTTPException(
  status_code=status.HTTP_415_UNSUPPORTED_MEDIA_TYPE,
- detail=(f"Unsupported file type: {media_type or 'unknown'} " f"(filename: {file.filename}). Only PDF, .txt, or .json allowed."),
+ detail=(
+ f"Unsupported file type: {media_type or 'unknown'} "
+ f"(filename: {file.filename}). "
+ f"Supported types: PDF, text files (.txt, .md), JSON, and code files (.py, .js, .java, etc.)."
+ ),
  )
 
  actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
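Note: a minimal sketch (not part of the diff) of the content-type fallback that upload_file_to_source performs with the new file_types helpers; the normalization step shown here is an assumption, only the fallback lines appear in this diff.

    from pathlib import Path

    from letta.services.file_processor.file_types import (
        get_allowed_media_types,
        get_extension_to_mime_type_map,
    )

    def resolve_media_type(content_type: str, filename: str) -> str:
        # Strip any parameters such as "; charset=utf-8" (assumed normalization).
        media_type = (content_type or "").split(";", 1)[0].strip().lower()
        if media_type not in get_allowed_media_types():
            # Fall back to the file extension, e.g. ".markdown" -> "text/markdown".
            ext = Path(filename).suffix.lower()
            media_type = get_extension_to_mime_type_map().get(ext, media_type)
        return media_type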
@@ -294,7 +282,7 @@ async def list_source_files(
  List paginated files associated with a data source.
  """
  actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
- return await server.source_manager.list_files(
+ return await server.file_manager.list_files(
  source_id=source_id,
  limit=limit,
  after=after,
@@ -317,7 +305,7 @@ async def delete_file_from_source(
  """
  actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
 
- deleted_file = await server.source_manager.delete_file(file_id=file_id, actor=actor)
+ deleted_file = await server.file_manager.delete_file(file_id=file_id, actor=actor)
 
  await server.remove_file_from_context_windows(source_id=source_id, file_id=deleted_file.id, actor=actor)
 
letta/server/server.py CHANGED
@@ -80,6 +80,7 @@ from letta.server.rest_api.interface import StreamingServerInterface
  from letta.server.rest_api.utils import sse_async_generator
  from letta.services.agent_manager import AgentManager
  from letta.services.block_manager import BlockManager
+ from letta.services.file_manager import FileManager
  from letta.services.files_agents_manager import FileAgentManager
  from letta.services.group_manager import GroupManager
  from letta.services.helpers.tool_execution_helper import prepare_local_sandbox
@@ -219,6 +220,7 @@ class SyncServer(Server):
  self.batch_manager = LLMBatchManager()
  self.telemetry_manager = TelemetryManager()
  self.file_agent_manager = FileAgentManager()
+ self.file_manager = FileManager()
 
  # A resusable httpx client
  timeout = httpx.Timeout(connect=10.0, read=20.0, write=10.0, pool=10.0)
@@ -1507,7 +1509,7 @@ class SyncServer(Server):
  raise ValueError(f"Data source {source_name} does not exist for user {user_id}")
 
  # load data into the document store
- passage_count, document_count = await load_data(connector, source, self.passage_manager, self.source_manager, actor=actor)
+ passage_count, document_count = await load_data(connector, source, self.passage_manager, self.file_manager, actor=actor)
  return passage_count, document_count
 
  def list_data_source_passages(self, user_id: str, source_id: str) -> List[Passage]:
@@ -2026,7 +2028,8 @@ class SyncServer(Server):
  )
 
  # Composio wrappers
- def get_composio_client(self, api_key: Optional[str] = None):
+ @staticmethod
+ def get_composio_client(api_key: Optional[str] = None):
  if api_key:
  return Composio(api_key=api_key)
  elif tool_settings.composio_api_key:
@@ -2034,9 +2037,10 @@ class SyncServer(Server):
  else:
  return Composio()
 
- def get_composio_apps(self, api_key: Optional[str] = None) -> List["AppModel"]:
+ @staticmethod
+ def get_composio_apps(api_key: Optional[str] = None) -> List["AppModel"]:
  """Get a list of all Composio apps with actions"""
- apps = self.get_composio_client(api_key=api_key).apps.get()
+ apps = SyncServer.get_composio_client(api_key=api_key).apps.get()
  apps_with_actions = []
  for app in apps:
  # A bit of hacky logic until composio patches this
@@ -2047,7 +2051,8 @@ class SyncServer(Server):
 
  def get_composio_actions_from_app_name(self, composio_app_name: str, api_key: Optional[str] = None) -> List["ActionModel"]:
  actions = self.get_composio_client(api_key=api_key).actions.get(apps=[composio_app_name])
- return actions
+ # Filter out deprecated composio actions
+ return [action for action in actions if "deprecated" not in action.description.lower()]
 
  # MCP wrappers
  # TODO support both command + SSE servers (via config)
letta/services/agent_manager.py CHANGED
@@ -19,6 +19,7 @@ from letta.constants import (
  FILES_TOOLS,
  MULTI_AGENT_TOOLS,
  )
+ from letta.helpers import ToolRulesSolver
  from letta.helpers.datetime_helpers import get_utc_time
  from letta.llm_api.llm_client import LLMClient
  from letta.log import get_logger
@@ -1444,7 +1445,7 @@ class AgentManager:
  @trace_method
  @enforce_types
  async def rebuild_system_prompt_async(
- self, agent_id: str, actor: PydanticUser, force=False, update_timestamp=True
+ self, agent_id: str, actor: PydanticUser, force=False, update_timestamp=True, tool_rules_solver: Optional[ToolRulesSolver] = None
  ) -> PydanticAgentState:
  """Rebuilds the system message with the latest memory object and any shared memory block updates
 
@@ -1453,6 +1454,8 @@ class AgentManager:
  Updates to the memory header should *not* trigger a rebuild, since that will simply flood recall storage with excess messages
  """
  agent_state = await self.get_agent_by_id_async(agent_id=agent_id, include_relationships=["memory"], actor=actor)
+ if not tool_rules_solver:
+ tool_rules_solver = ToolRulesSolver(agent_state.tool_rules)
 
  curr_system_message = await self.get_system_message_async(
  agent_id=agent_id, actor=actor
@@ -1492,6 +1495,7 @@ class AgentManager:
  in_context_memory_last_edit=memory_edit_timestamp,
  previous_message_count=num_messages,
  archival_memory_size=num_archival_memories,
+ tool_rules_solver=tool_rules_solver,
  )
 
  diff = united_diff(curr_system_message_openai["content"], new_system_message_str)
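Note: an illustrative usage sketch (not part of the diff) of the new optional tool_rules_solver parameter; the caller and surrounding variables are assumptions.

    from letta.helpers import ToolRulesSolver
    from letta.services.agent_manager import AgentManager

    async def rebuild_with_existing_solver(agent_manager: AgentManager, agent_state, actor):
        # Reuse a solver already built from the agent's tool rules instead of
        # letting rebuild_system_prompt_async construct a fresh one.
        solver = ToolRulesSolver(agent_state.tool_rules)
        return await agent_manager.rebuild_system_prompt_async(
            agent_id=agent_state.id,
            actor=actor,
            tool_rules_solver=solver,
        )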
letta/services/file_manager.py ADDED
@@ -0,0 +1,219 @@
+ from datetime import datetime
+ from typing import List, Optional
+
+ from sqlalchemy import select, update
+ from sqlalchemy.dialects.postgresql import insert as pg_insert
+ from sqlalchemy.exc import IntegrityError
+ from sqlalchemy.orm import selectinload
+
+ from letta.orm.errors import NoResultFound
+ from letta.orm.file import FileContent as FileContentModel
+ from letta.orm.file import FileMetadata as FileMetadataModel
+ from letta.orm.sqlalchemy_base import AccessType
+ from letta.otel.tracing import trace_method
+ from letta.schemas.enums import FileProcessingStatus
+ from letta.schemas.file import FileMetadata as PydanticFileMetadata
+ from letta.schemas.user import User as PydanticUser
+ from letta.server.db import db_registry
+ from letta.utils import enforce_types
+
+
+ class FileManager:
+ """Manager class to handle business logic related to files."""
+
+ @enforce_types
+ @trace_method
+ async def create_file(
+ self,
+ file_metadata: PydanticFileMetadata,
+ actor: PydanticUser,
+ *,
+ text: Optional[str] = None,
+ ) -> PydanticFileMetadata:
+
+ # short-circuit if it already exists
+ existing = await self.get_file_by_id(file_metadata.id, actor=actor)
+ if existing:
+ return existing
+
+ async with db_registry.async_session() as session:
+ try:
+ file_metadata.organization_id = actor.organization_id
+ file_orm = FileMetadataModel(**file_metadata.model_dump(to_orm=True, exclude_none=True))
+ await file_orm.create_async(session, actor=actor, no_commit=True)
+
+ if text is not None:
+ content_orm = FileContentModel(file_id=file_orm.id, text=text)
+ await content_orm.create_async(session, actor=actor, no_commit=True)
+
+ await session.commit()
+ await session.refresh(file_orm)
+ return await file_orm.to_pydantic_async()
+
+ except IntegrityError:
+ await session.rollback()
+ return await self.get_file_by_id(file_metadata.id, actor=actor)
+
+ # TODO: We make actor optional for now, but should most likely be enforced due to security reasons
+ @enforce_types
+ @trace_method
+ async def get_file_by_id(
+ self,
+ file_id: str,
+ actor: Optional[PydanticUser] = None,
+ *,
+ include_content: bool = False,
+ ) -> Optional[PydanticFileMetadata]:
+ """Retrieve a file by its ID.
+
+ If `include_content=True`, the FileContent relationship is eagerly
+ loaded so `to_pydantic(include_content=True)` never triggers a
+ lazy SELECT (avoids MissingGreenlet).
+ """
+ async with db_registry.async_session() as session:
+ try:
+ if include_content:
+ # explicit eager load
+ query = (
+ select(FileMetadataModel).where(FileMetadataModel.id == file_id).options(selectinload(FileMetadataModel.content))
+ )
+ # apply org-scoping if actor provided
+ if actor:
+ query = FileMetadataModel.apply_access_predicate(
+ query,
+ actor,
+ access=["read"],
+ access_type=AccessType.ORGANIZATION,
+ )
+
+ result = await session.execute(query)
+ file_orm = result.scalar_one()
+ else:
+ # fast path (metadata only)
+ file_orm = await FileMetadataModel.read_async(
+ db_session=session,
+ identifier=file_id,
+ actor=actor,
+ )
+
+ return await file_orm.to_pydantic_async(include_content=include_content)
+
+ except NoResultFound:
+ return None
+
+ @enforce_types
+ @trace_method
+ async def update_file_status(
+ self,
+ *,
+ file_id: str,
+ actor: PydanticUser,
+ processing_status: Optional[FileProcessingStatus] = None,
+ error_message: Optional[str] = None,
+ ) -> PydanticFileMetadata:
+ """
+ Update processing_status and/or error_message on a FileMetadata row.
+
+ * 1st round-trip → UPDATE
+ * 2nd round-trip → SELECT fresh row (same as read_async)
+ """
+
+ if processing_status is None and error_message is None:
+ raise ValueError("Nothing to update")
+
+ values: dict[str, object] = {"updated_at": datetime.utcnow()}
+ if processing_status is not None:
+ values["processing_status"] = processing_status
+ if error_message is not None:
+ values["error_message"] = error_message
+
+ async with db_registry.async_session() as session:
+ # Fast in-place update – no ORM hydration
+ stmt = (
+ update(FileMetadataModel)
+ .where(
+ FileMetadataModel.id == file_id,
+ FileMetadataModel.organization_id == actor.organization_id,
+ )
+ .values(**values)
+ )
+ await session.execute(stmt)
+ await session.commit()
+
+ # Reload via normal accessor so we return a fully-attached object
+ file_orm = await FileMetadataModel.read_async(
+ db_session=session,
+ identifier=file_id,
+ actor=actor,
+ )
+ return await file_orm.to_pydantic_async()
+
+ @enforce_types
+ @trace_method
+ async def upsert_file_content(
+ self,
+ *,
+ file_id: str,
+ text: str,
+ actor: PydanticUser,
+ ) -> PydanticFileMetadata:
+ async with db_registry.async_session() as session:
+ await FileMetadataModel.read_async(session, file_id, actor)
+
+ dialect_name = session.bind.dialect.name
+
+ if dialect_name == "postgresql":
+ stmt = (
+ pg_insert(FileContentModel)
+ .values(file_id=file_id, text=text)
+ .on_conflict_do_update(
+ index_elements=[FileContentModel.file_id],
+ set_={"text": text},
+ )
+ )
+ await session.execute(stmt)
+ else:
+ # Emulate upsert for SQLite and others
+ stmt = select(FileContentModel).where(FileContentModel.file_id == file_id)
+ result = await session.execute(stmt)
+ existing = result.scalar_one_or_none()
+
+ if existing:
+ await session.execute(update(FileContentModel).where(FileContentModel.file_id == file_id).values(text=text))
+ else:
+ session.add(FileContentModel(file_id=file_id, text=text))
+
+ await session.commit()
+
+ # Reload with content
+ query = select(FileMetadataModel).options(selectinload(FileMetadataModel.content)).where(FileMetadataModel.id == file_id)
+ result = await session.execute(query)
+ return await result.scalar_one().to_pydantic_async(include_content=True)
+
+ @enforce_types
+ @trace_method
+ async def list_files(
+ self, source_id: str, actor: PydanticUser, after: Optional[str] = None, limit: Optional[int] = 50, include_content: bool = False
+ ) -> List[PydanticFileMetadata]:
+ """List all files with optional pagination."""
+ async with db_registry.async_session() as session:
+ options = [selectinload(FileMetadataModel.content)] if include_content else None
+
+ files = await FileMetadataModel.list_async(
+ db_session=session,
+ after=after,
+ limit=limit,
+ organization_id=actor.organization_id,
+ source_id=source_id,
+ query_options=options,
+ )
+ return [await file.to_pydantic_async(include_content=include_content) for file in files]
+
+ @enforce_types
+ @trace_method
+ async def delete_file(self, file_id: str, actor: PydanticUser) -> PydanticFileMetadata:
+ """Delete a file by its ID."""
+ async with db_registry.async_session() as session:
+ file = await FileMetadataModel.read_async(db_session=session, identifier=file_id)
+ await file.hard_delete_async(db_session=session, actor=actor)
+ return await file.to_pydantic_async()
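Note: an illustrative usage sketch (not part of the diff) of the new FileManager; it assumes an actor (PydanticUser) and a populated FileMetadata object are already available, and that FileMetadata carries a source_id.

    from letta.services.file_manager import FileManager

    async def file_manager_example(actor, file_metadata):
        fm = FileManager()

        # Create the metadata row and, optionally, its content in one call.
        created = await fm.create_file(file_metadata, actor, text="hello world")

        # Eagerly load content to avoid a lazy SELECT later.
        fetched = await fm.get_file_by_id(created.id, actor=actor, include_content=True)

        # Paginated listing scoped to the actor's organization.
        files = await fm.list_files(source_id=file_metadata.source_id, actor=actor, limit=10)

        await fm.delete_file(created.id, actor)
        return fetched, files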