letta-nightly 0.8.4.dev20250618104304__py3-none-any.whl → 0.8.5.dev20250619180801__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. letta/__init__.py +1 -1
  2. letta/agents/letta_agent.py +54 -20
  3. letta/agents/voice_agent.py +47 -31
  4. letta/constants.py +1 -1
  5. letta/data_sources/redis_client.py +11 -6
  6. letta/functions/function_sets/builtin.py +35 -11
  7. letta/functions/prompts.py +26 -0
  8. letta/functions/types.py +6 -0
  9. letta/interfaces/openai_chat_completions_streaming_interface.py +0 -1
  10. letta/llm_api/anthropic.py +9 -1
  11. letta/llm_api/anthropic_client.py +22 -3
  12. letta/llm_api/aws_bedrock.py +10 -6
  13. letta/llm_api/llm_api_tools.py +3 -0
  14. letta/llm_api/openai_client.py +1 -1
  15. letta/orm/agent.py +14 -1
  16. letta/orm/job.py +3 -0
  17. letta/orm/provider.py +3 -1
  18. letta/schemas/agent.py +7 -0
  19. letta/schemas/embedding_config.py +8 -0
  20. letta/schemas/enums.py +0 -1
  21. letta/schemas/job.py +1 -0
  22. letta/schemas/providers.py +13 -5
  23. letta/server/rest_api/routers/v1/agents.py +76 -35
  24. letta/server/rest_api/routers/v1/providers.py +7 -7
  25. letta/server/rest_api/routers/v1/sources.py +39 -19
  26. letta/server/rest_api/routers/v1/tools.py +96 -31
  27. letta/services/agent_manager.py +8 -2
  28. letta/services/file_processor/chunker/llama_index_chunker.py +89 -1
  29. letta/services/file_processor/embedder/openai_embedder.py +6 -1
  30. letta/services/file_processor/parser/mistral_parser.py +2 -2
  31. letta/services/helpers/agent_manager_helper.py +44 -16
  32. letta/services/job_manager.py +35 -17
  33. letta/services/mcp/base_client.py +26 -1
  34. letta/services/mcp_manager.py +33 -18
  35. letta/services/provider_manager.py +30 -0
  36. letta/services/tool_executor/builtin_tool_executor.py +335 -43
  37. letta/services/tool_manager.py +25 -1
  38. letta/services/user_manager.py +1 -1
  39. letta/settings.py +3 -0
  40. {letta_nightly-0.8.4.dev20250618104304.dist-info → letta_nightly-0.8.5.dev20250619180801.dist-info}/METADATA +4 -3
  41. {letta_nightly-0.8.4.dev20250618104304.dist-info → letta_nightly-0.8.5.dev20250619180801.dist-info}/RECORD +44 -42
  42. {letta_nightly-0.8.4.dev20250618104304.dist-info → letta_nightly-0.8.5.dev20250619180801.dist-info}/LICENSE +0 -0
  43. {letta_nightly-0.8.4.dev20250618104304.dist-info → letta_nightly-0.8.5.dev20250619180801.dist-info}/WHEEL +0 -0
  44. {letta_nightly-0.8.4.dev20250618104304.dist-info → letta_nightly-0.8.5.dev20250619180801.dist-info}/entry_points.txt +0 -0
@@ -1,4 +1,4 @@
1
- from typing import List
1
+ from typing import List, Tuple
2
2
 
3
3
  from mistralai import OCRPageObject
4
4
 
@@ -27,3 +27,91 @@ class LlamaIndexChunker:
27
27
  except Exception as e:
28
28
  logger.error(f"Chunking failed: {str(e)}")
29
29
  raise
30
+
31
+
32
class MarkdownChunker:
    """Markdown-specific chunker that preserves line numbers for citation purposes.

    Chunk text comes from llama_index's ``MarkdownNodeParser``; each chunk is
    then mapped back to a 1-based, inclusive line range of the original
    document. If the parser path fails for any reason, a plain line-based
    splitter is used instead.
    """

    def __init__(self, chunk_size: int = 2048):
        """Create the chunker.

        Args:
            chunk_size: Character budget per chunk. It is only consulted by the
                line-based fallback splitter.
        """
        self.chunk_size = chunk_size
        # No overlap for line-based citations to avoid ambiguity

        # Imported lazily so the dependency is only loaded when a chunker is built.
        from llama_index.core.node_parser import MarkdownNodeParser

        self.parser = MarkdownNodeParser()

    def chunk_markdown_with_line_numbers(self, markdown_content: str) -> List[Tuple[str, int, int]]:
        """Chunk markdown content while preserving line number mappings.

        Args:
            markdown_content: The full markdown document.

        Returns:
            List of tuples: (chunk_text, start_line, end_line), with 1-based
            inclusive line numbers.
        """
        try:
            # Split content into lines for line number tracking
            lines = markdown_content.split("\n")

            # Create nodes using MarkdownNodeParser
            from llama_index.core import Document

            document = Document(text=markdown_content)
            nodes = self.parser.get_nodes_from_documents([document])

            chunks_with_line_numbers = []
            # 0-based index of the first line not yet claimed by a chunk.
            # Searching forward from here keeps repeated headings (e.g. several
            # "## Example" sections) from all resolving to the first occurrence.
            # NOTE(review): assumes the parser yields nodes in document order;
            # _find_line_numbers retries from the top when that does not hold.
            search_from = 0

            for node in nodes:
                chunk_text = node.text

                # Find the line numbers for this chunk
                start_line, end_line = self._find_line_numbers(chunk_text, lines, search_from)
                chunks_with_line_numbers.append((chunk_text, start_line, end_line))

                # end_line is 1-based inclusive, i.e. the 0-based index of the next line
                search_from = end_line

            return chunks_with_line_numbers

        except Exception as e:
            logger.error(f"Markdown chunking failed: {str(e)}")
            # Fallback to simple line-based chunking
            return self._fallback_line_chunking(markdown_content)

    @staticmethod
    def _locate_anchor(anchor: str, stripped_lines: List[str], begin: int) -> int:
        """Return the 0-based index of the line matching ``anchor``, or -1.

        An exact (whitespace-stripped) match is preferred. A substring match is
        only attempted for anchors longer than 10 characters, to avoid matching
        short fragments inside unrelated lines.
        """
        try:
            return stripped_lines.index(anchor, begin)
        except ValueError:
            pass

        if len(anchor) > 10:
            for position in range(begin, len(stripped_lines)):
                if anchor in stripped_lines[position]:
                    return position
        return -1

    def _find_line_numbers(self, chunk_text: str, lines: List[str], search_from: int = 0) -> Tuple[int, int]:
        """Find the start and end line numbers for a given chunk of text.

        Args:
            chunk_text: Text of the chunk; its first line is used as the anchor.
            lines: The original document split on newlines.
            search_from: 0-based line index at which to start looking. Defaults
                to the top of the document.

        Returns:
            ``(start_line, end_line)``, 1-based and inclusive. When the anchor
            cannot be located, the range starts at ``search_from + 1`` (line 1
            by default).
        """
        chunk_lines = chunk_text.split("\n")
        anchor = chunk_lines[0].strip()
        total = len(lines)
        cursor = max(0, min(search_from, total))
        stripped_lines = [line.strip() for line in lines]

        index = -1
        if anchor:
            index = self._locate_anchor(anchor, stripped_lines, cursor)
            if index < 0 and cursor > 0:
                # Not found after the cursor: retry over the whole document
                index = self._locate_anchor(anchor, stripped_lines, 0)
        if index < 0:
            # Best guess: the chunk begins where the search started
            index = min(cursor, max(total - 1, 0))

        start_line = index + 1

        # Calculate end line, clamped to the document and never before the start
        end_line = start_line + len(chunk_lines) - 1
        return start_line, max(start_line, min(end_line, total))

    def _fallback_line_chunking(self, markdown_content: str) -> List[Tuple[str, int, int]]:
        """Fallback chunking method that simply splits by lines with no overlap.

        Lines are appended to a chunk until its running character count (line
        length plus one for the newline) reaches ``chunk_size``. The limit is
        checked before each line is added, so a chunk may exceed the budget by
        up to one line. Every chunk takes at least one line, which guarantees
        progress even when ``chunk_size`` is zero or negative.

        Returns:
            List of tuples: (chunk_text, start_line, end_line), 1-based inclusive.
        """
        lines = markdown_content.split("\n")
        total = len(lines)
        chunks = []

        i = 0
        while i < total:
            chunk_lines = []
            start_line = i + 1
            char_count = 0

            # Build chunk until we hit size limit (always consuming >= 1 line)
            while i < total and (not chunk_lines or char_count < self.chunk_size):
                line = lines[i]
                chunk_lines.append(line)
                char_count += len(line) + 1  # +1 for newline
                i += 1

            # i now points one past the last consumed line, i.e. the 1-based end
            chunks.append(("\n".join(chunk_lines), start_line, i))

            # No overlap - continue from where we left off

        return chunks
@@ -16,7 +16,12 @@ class OpenAIEmbedder:
16
16
  """OpenAI-based embedding generation"""
17
17
 
18
18
  def __init__(self, embedding_config: Optional[EmbeddingConfig] = None):
19
- self.embedding_config = embedding_config or EmbeddingConfig.default_config(provider="openai")
19
+ self.default_embedding_config = (
20
+ EmbeddingConfig.default_config(model_name="text-embedding-3-small", provider="openai")
21
+ if model_settings.openai_api_key
22
+ else EmbeddingConfig.default_config(model_name="letta")
23
+ )
24
+ self.embedding_config = embedding_config or self.default_embedding_config
20
25
 
21
26
  # TODO: Unify to global OpenAI client
22
27
  self.client = openai.AsyncOpenAI(api_key=model_settings.openai_api_key)
@@ -20,11 +20,10 @@ class MistralFileParser(FileParser):
20
20
  async def extract_text(self, content: bytes, mime_type: str) -> OCRResponse:
21
21
  """Extract text using Mistral OCR or shortcut for plain text."""
22
22
  try:
23
- logger.info(f"Extracting text using Mistral OCR model: {self.model}")
24
-
25
23
  # TODO: Kind of hacky...we try to exit early here?
26
24
  # TODO: Create our internal file parser representation we return instead of OCRResponse
27
25
  if is_simple_text_mime_type(mime_type):
26
+ logger.info(f"Extracting text directly (no Mistral): {self.model}")
28
27
  text = content.decode("utf-8", errors="replace")
29
28
  return OCRResponse(
30
29
  model=self.model,
@@ -43,6 +42,7 @@ class MistralFileParser(FileParser):
43
42
  base64_encoded_content = base64.b64encode(content).decode("utf-8")
44
43
  document_url = f"data:{mime_type};base64,{base64_encoded_content}"
45
44
 
45
+ logger.info(f"Extracting text using Mistral OCR model: {self.model}")
46
46
  async with Mistral(api_key=settings.mistral_api_key) as mistral:
47
47
  ocr_response = await mistral.ocr.process_async(
48
48
  model="mistral-ocr-latest", document={"type": "document_url", "document_url": document_url}, include_image_base64=False
@@ -449,41 +449,69 @@ def _cursor_filter(created_at_col, id_col, ref_created_at, ref_id, forward: bool
449
449
  )
450
450
 
451
451
 
452
- def _apply_pagination(query, before: Optional[str], after: Optional[str], session, ascending: bool = True) -> any:
452
+ def _apply_pagination(
453
+ query, before: Optional[str], after: Optional[str], session, ascending: bool = True, sort_by: str = "created_at"
454
+ ) -> any:
455
+ # Determine the sort column
456
+ if sort_by == "last_run_completion":
457
+ sort_column = AgentModel.last_run_completion
458
+ else:
459
+ sort_column = AgentModel.created_at
460
+
453
461
  if after:
454
- result = session.execute(select(AgentModel.created_at, AgentModel.id).where(AgentModel.id == after)).first()
462
+ if sort_by == "last_run_completion":
463
+ result = session.execute(select(AgentModel.last_run_completion, AgentModel.id).where(AgentModel.id == after)).first()
464
+ else:
465
+ result = session.execute(select(AgentModel.created_at, AgentModel.id).where(AgentModel.id == after)).first()
455
466
  if result:
456
- after_created_at, after_id = result
457
- query = query.where(_cursor_filter(AgentModel.created_at, AgentModel.id, after_created_at, after_id, forward=ascending))
467
+ after_sort_value, after_id = result
468
+ query = query.where(_cursor_filter(sort_column, AgentModel.id, after_sort_value, after_id, forward=ascending))
458
469
 
459
470
  if before:
460
- result = session.execute(select(AgentModel.created_at, AgentModel.id).where(AgentModel.id == before)).first()
471
+ if sort_by == "last_run_completion":
472
+ result = session.execute(select(AgentModel.last_run_completion, AgentModel.id).where(AgentModel.id == before)).first()
473
+ else:
474
+ result = session.execute(select(AgentModel.created_at, AgentModel.id).where(AgentModel.id == before)).first()
461
475
  if result:
462
- before_created_at, before_id = result
463
- query = query.where(_cursor_filter(AgentModel.created_at, AgentModel.id, before_created_at, before_id, forward=not ascending))
476
+ before_sort_value, before_id = result
477
+ query = query.where(_cursor_filter(sort_column, AgentModel.id, before_sort_value, before_id, forward=not ascending))
464
478
 
465
479
  # Apply ordering
466
480
  order_fn = asc if ascending else desc
467
- query = query.order_by(order_fn(AgentModel.created_at), order_fn(AgentModel.id))
481
+ query = query.order_by(order_fn(sort_column), order_fn(AgentModel.id))
468
482
  return query
469
483
 
470
484
 
471
- async def _apply_pagination_async(query, before: Optional[str], after: Optional[str], session, ascending: bool = True) -> any:
485
+ async def _apply_pagination_async(
486
+ query, before: Optional[str], after: Optional[str], session, ascending: bool = True, sort_by: str = "created_at"
487
+ ) -> any:
488
+ # Determine the sort column
489
+ if sort_by == "last_run_completion":
490
+ sort_column = AgentModel.last_run_completion
491
+ else:
492
+ sort_column = AgentModel.created_at
493
+
472
494
  if after:
473
- result = (await session.execute(select(AgentModel.created_at, AgentModel.id).where(AgentModel.id == after))).first()
495
+ if sort_by == "last_run_completion":
496
+ result = (await session.execute(select(AgentModel.last_run_completion, AgentModel.id).where(AgentModel.id == after))).first()
497
+ else:
498
+ result = (await session.execute(select(AgentModel.created_at, AgentModel.id).where(AgentModel.id == after))).first()
474
499
  if result:
475
- after_created_at, after_id = result
476
- query = query.where(_cursor_filter(AgentModel.created_at, AgentModel.id, after_created_at, after_id, forward=ascending))
500
+ after_sort_value, after_id = result
501
+ query = query.where(_cursor_filter(sort_column, AgentModel.id, after_sort_value, after_id, forward=ascending))
477
502
 
478
503
  if before:
479
- result = (await session.execute(select(AgentModel.created_at, AgentModel.id).where(AgentModel.id == before))).first()
504
+ if sort_by == "last_run_completion":
505
+ result = (await session.execute(select(AgentModel.last_run_completion, AgentModel.id).where(AgentModel.id == before))).first()
506
+ else:
507
+ result = (await session.execute(select(AgentModel.created_at, AgentModel.id).where(AgentModel.id == before))).first()
480
508
  if result:
481
- before_created_at, before_id = result
482
- query = query.where(_cursor_filter(AgentModel.created_at, AgentModel.id, before_created_at, before_id, forward=not ascending))
509
+ before_sort_value, before_id = result
510
+ query = query.where(_cursor_filter(sort_column, AgentModel.id, before_sort_value, before_id, forward=not ascending))
483
511
 
484
512
  # Apply ordering
485
513
  order_fn = asc if ascending else desc
486
- query = query.order_by(order_fn(AgentModel.created_at), order_fn(AgentModel.id))
514
+ query = query.order_by(order_fn(sort_column), order_fn(AgentModel.id))
487
515
  return query
488
516
 
489
517
 
@@ -6,6 +6,7 @@ from sqlalchemy import select
6
6
  from sqlalchemy.orm import Session
7
7
 
8
8
  from letta.helpers.datetime_helpers import get_utc_time
9
+ from letta.log import get_logger
9
10
  from letta.orm.enums import JobType
10
11
  from letta.orm.errors import NoResultFound
11
12
  from letta.orm.job import Job as JobModel
@@ -28,6 +29,8 @@ from letta.schemas.user import User as PydanticUser
28
29
  from letta.server.db import db_registry
29
30
  from letta.utils import enforce_types
30
31
 
32
+ logger = get_logger(__name__)
33
+
31
34
 
32
35
  class JobManager:
33
36
  """Manager class to handle business logic related to Jobs."""
@@ -67,18 +70,22 @@ class JobManager:
67
70
  with db_registry.session() as session:
68
71
  # Fetch the job by ID
69
72
  job = self._verify_job_access(session=session, job_id=job_id, actor=actor, access=["write"])
73
+ not_completed_before = not bool(job.completed_at)
70
74
 
71
75
  # Update job attributes with only the fields that were explicitly set
72
76
  update_data = job_update.model_dump(to_orm=True, exclude_unset=True, exclude_none=True)
73
77
 
74
78
  # Automatically update the completion timestamp if status is set to 'completed'
75
79
  for key, value in update_data.items():
80
+ # Ensure completed_at is timezone-naive for database compatibility
81
+ if key == "completed_at" and value is not None and hasattr(value, "replace"):
82
+ value = value.replace(tzinfo=None)
76
83
  setattr(job, key, value)
77
84
 
78
- if update_data.get("status") == JobStatus.completed and not job.completed_at:
85
+ if job_update.status in {JobStatus.completed, JobStatus.failed} and not_completed_before:
79
86
  job.completed_at = get_utc_time().replace(tzinfo=None)
80
87
  if job.callback_url:
81
- self._dispatch_callback(session, job)
88
+ self._dispatch_callback(job)
82
89
 
83
90
  # Save the updated job to the database
84
91
  job.update(db_session=session, actor=actor)
@@ -92,18 +99,22 @@ class JobManager:
92
99
  async with db_registry.async_session() as session:
93
100
  # Fetch the job by ID
94
101
  job = await self._verify_job_access_async(session=session, job_id=job_id, actor=actor, access=["write"])
102
+ not_completed_before = not bool(job.completed_at)
95
103
 
96
104
  # Update job attributes with only the fields that were explicitly set
97
105
  update_data = job_update.model_dump(to_orm=True, exclude_unset=True, exclude_none=True)
98
106
 
99
107
  # Automatically update the completion timestamp if status is set to 'completed'
100
108
  for key, value in update_data.items():
109
+ # Ensure completed_at is timezone-naive for database compatibility
110
+ if key == "completed_at" and value is not None and hasattr(value, "replace"):
111
+ value = value.replace(tzinfo=None)
101
112
  setattr(job, key, value)
102
113
 
103
- if update_data.get("status") == JobStatus.completed and not job.completed_at:
114
+ if job_update.status in {JobStatus.completed, JobStatus.failed} and not_completed_before:
104
115
  job.completed_at = get_utc_time().replace(tzinfo=None)
105
116
  if job.callback_url:
106
- await self._dispatch_callback_async(session, job)
117
+ await self._dispatch_callback_async(job)
107
118
 
108
119
  # Save the updated job to the database
109
120
  await job.update_async(db_session=session, actor=actor)
@@ -586,7 +597,7 @@ class JobManager:
586
597
  request_config = job.request_config or LettaRequestConfig()
587
598
  return request_config
588
599
 
589
- def _dispatch_callback(self, session: Session, job: JobModel) -> None:
600
+ def _dispatch_callback(self, job: JobModel) -> None:
590
601
  """
591
602
  POST a standard JSON payload to job.callback_url
592
603
  and record timestamp + HTTP status.
@@ -595,22 +606,25 @@ class JobManager:
595
606
  payload = {
596
607
  "job_id": job.id,
597
608
  "status": job.status,
598
- "completed_at": job.completed_at.isoformat(),
609
+ "completed_at": job.completed_at.isoformat() if job.completed_at else None,
610
+ "metadata": job.metadata_,
599
611
  }
600
612
  try:
601
613
  import httpx
602
614
 
603
615
  resp = httpx.post(job.callback_url, json=payload, timeout=5.0)
604
- job.callback_sent_at = get_utc_time()
616
+ job.callback_sent_at = get_utc_time().replace(tzinfo=None)
605
617
  job.callback_status_code = resp.status_code
606
618
 
607
- except Exception:
608
- return
609
-
610
- session.add(job)
611
- session.commit()
619
+ except Exception as e:
620
+ error_message = f"Failed to dispatch callback for job {job.id} to {job.callback_url}: {str(e)}"
621
+ logger.error(error_message)
622
+ # Record the failed attempt
623
+ job.callback_sent_at = get_utc_time().replace(tzinfo=None)
624
+ job.callback_error = error_message
625
+ # Continue silently - callback failures should not affect job completion
612
626
 
613
- async def _dispatch_callback_async(self, session, job: JobModel) -> None:
627
+ async def _dispatch_callback_async(self, job: JobModel) -> None:
614
628
  """
615
629
  POST a standard JSON payload to job.callback_url and record timestamp + HTTP status asynchronously.
616
630
  """
@@ -618,6 +632,7 @@ class JobManager:
618
632
  "job_id": job.id,
619
633
  "status": job.status,
620
634
  "completed_at": job.completed_at.isoformat() if job.completed_at else None,
635
+ "metadata": job.metadata_,
621
636
  }
622
637
 
623
638
  try:
@@ -628,7 +643,10 @@ class JobManager:
628
643
  # Ensure timestamp is timezone-naive for DB compatibility
629
644
  job.callback_sent_at = get_utc_time().replace(tzinfo=None)
630
645
  job.callback_status_code = resp.status_code
631
- except Exception:
632
- # Silently fail on callback errors - job updates should still succeed
633
- # In production, this would include proper error logging
634
- pass
646
+ except Exception as e:
647
+ error_message = f"Failed to dispatch callback for job {job.id} to {job.callback_url}: {str(e)}"
648
+ logger.error(error_message)
649
+ # Record the failed attempt
650
+ job.callback_sent_at = get_utc_time().replace(tzinfo=None)
651
+ job.callback_error = error_message
652
+ # Continue silently - callback failures should not affect job completion
@@ -1,3 +1,4 @@
1
+ import asyncio
1
2
  from contextlib import AsyncExitStack
2
3
  from typing import Optional, Tuple
3
4
 
@@ -18,6 +19,9 @@ class AsyncBaseMCPClient:
18
19
  self.exit_stack = AsyncExitStack()
19
20
  self.session: Optional[ClientSession] = None
20
21
  self.initialized = False
22
+ # Track the task that created this client
23
+ self._creation_task = asyncio.current_task()
24
+ self._cleanup_queue = asyncio.Queue(maxsize=1)
21
25
 
22
26
  async def connect_to_server(self):
23
27
  try:
@@ -74,8 +78,29 @@ class AsyncBaseMCPClient:
74
78
  raise RuntimeError("MCPClient has not been initialized")
75
79
 
76
80
  async def cleanup(self):
77
- """Clean up resources"""
81
+ """Clean up resources - ensure this runs in the same task"""
82
+ if hasattr(self, "_cleanup_task"):
83
+ # If we're in a different task, schedule cleanup in original task
84
+ current_task = asyncio.current_task()
85
+ if current_task != self._creation_task:
86
+ # Create a future to signal completion
87
+ cleanup_done = asyncio.Future()
88
+ self._cleanup_queue.put_nowait((self.exit_stack, cleanup_done))
89
+ await cleanup_done
90
+ return
91
+
92
+ # Normal cleanup
78
93
  await self.exit_stack.aclose()
79
94
 
80
95
  def to_sync_client(self):
81
96
  raise NotImplementedError("Subclasses must implement to_sync_client")
97
+
98
async def __aenter__(self):
    """Enter the async context manager.

    Connects to the server and returns the client itself so it can be bound
    with ``async with ... as client``.

    Raises:
        Whatever ``connect_to_server()`` raises; the client is cleaned up first.
    """
    try:
        await self.connect_to_server()
    except BaseException:
        # When __aenter__ raises, __aexit__ is never invoked, so release here
        # whatever was set up before the failure.
        # NOTE(review): assumes cleanup() is safe to call on a partially
        # connected client — confirm against connect_to_server().
        await self.cleanup()
        raise
    return self
102
+
103
async def __aexit__(self, exc_type, exc_val, exc_tb):
    """Leave the async context manager, releasing the client's resources.

    The exception details are ignored. The method always reports ``False``,
    so an exception raised inside the ``async with`` body keeps propagating.
    """
    suppress_exception = False
    await self.cleanup()
    return suppress_exception
@@ -75,23 +75,23 @@ class MCPManager:
75
75
  server_config = mcp_config[mcp_server_name]
76
76
 
77
77
  if isinstance(server_config, SSEServerConfig):
78
- mcp_client = AsyncSSEMCPClient(server_config=server_config)
78
+ # mcp_client = AsyncSSEMCPClient(server_config=server_config)
79
+ async with AsyncSSEMCPClient(server_config=server_config) as mcp_client:
80
+ result, success = await mcp_client.execute_tool(tool_name, tool_args)
81
+ logger.info(f"MCP Result: {result}, Success: {success}")
82
+ return result, success
79
83
  elif isinstance(server_config, StdioServerConfig):
80
- mcp_client = AsyncStdioMCPClient(server_config=server_config)
84
+ async with AsyncStdioMCPClient(server_config=server_config) as mcp_client:
85
+ result, success = await mcp_client.execute_tool(tool_name, tool_args)
86
+ logger.info(f"MCP Result: {result}, Success: {success}")
87
+ return result, success
81
88
  elif isinstance(server_config, StreamableHTTPServerConfig):
82
- mcp_client = AsyncStreamableHTTPMCPClient(server_config=server_config)
89
+ async with AsyncStreamableHTTPMCPClient(server_config=server_config) as mcp_client:
90
+ result, success = await mcp_client.execute_tool(tool_name, tool_args)
91
+ logger.info(f"MCP Result: {result}, Success: {success}")
92
+ return result, success
83
93
  else:
84
94
  raise ValueError(f"Unsupported server config type: {type(server_config)}")
85
- await mcp_client.connect_to_server()
86
-
87
- # call tool
88
- result, success = await mcp_client.execute_tool(tool_name, tool_args)
89
- logger.info(f"MCP Result: {result}, Success: {success}")
90
- # TODO: change to pydantic tool
91
-
92
- await mcp_client.cleanup()
93
-
94
- return result, success
95
95
 
96
96
  @enforce_types
97
97
  async def add_tool_from_mcp_server(self, mcp_server_name: str, mcp_tool_name: str, actor: PydanticUser) -> PydanticTool:
@@ -149,19 +149,19 @@ class MCPManager:
149
149
  return mcp_server
150
150
 
151
151
  @enforce_types
152
- async def create_mcp_server(self, pydantic_mcp_server: MCPServer, actor: PydanticUser) -> PydanticTool:
153
- """Create a new tool based on the ToolCreate schema."""
154
- with db_registry.session() as session:
152
+ async def create_mcp_server(self, pydantic_mcp_server: MCPServer, actor: PydanticUser) -> MCPServer:
153
+ """Create a new MCP server."""
154
+ async with db_registry.async_session() as session:
155
155
  # Set the organization id at the ORM layer
156
156
  pydantic_mcp_server.organization_id = actor.organization_id
157
157
  mcp_server_data = pydantic_mcp_server.model_dump(to_orm=True)
158
158
 
159
159
  mcp_server = MCPServerModel(**mcp_server_data)
160
- mcp_server.create(session, actor=actor) # Re-raise other database-related errors
160
+ mcp_server = await mcp_server.create_async(session, actor=actor)
161
161
  return mcp_server.to_pydantic()
162
162
 
163
163
  @enforce_types
164
- async def update_mcp_server_by_id(self, mcp_server_id: str, mcp_server_update: UpdateMCPServer, actor: PydanticUser) -> PydanticTool:
164
+ async def update_mcp_server_by_id(self, mcp_server_id: str, mcp_server_update: UpdateMCPServer, actor: PydanticUser) -> MCPServer:
165
165
  """Update a tool by its ID with the given ToolUpdate object."""
166
166
  async with db_registry.async_session() as session:
167
167
  # Fetch the tool by ID
@@ -177,6 +177,21 @@ class MCPManager:
177
177
  # Save the updated tool to the database mcp_server = await mcp_server.update_async(db_session=session, actor=actor)
178
178
  return mcp_server.to_pydantic()
179
179
 
180
@enforce_types
async def update_mcp_server_by_name(self, mcp_server_name: str, mcp_server_update: UpdateMCPServer, actor: PydanticUser) -> MCPServer:
    """Update an MCP server identified by name.

    The name is resolved to an ID via ``get_mcp_server_id_by_name`` and the
    work is delegated to ``update_mcp_server_by_id``.

    Raises:
        HTTPException: 404 with an ``MCPServerNotFoundError`` detail payload
            when the name does not resolve to an ID.
    """
    resolved_id = await self.get_mcp_server_id_by_name(mcp_server_name, actor)
    if resolved_id:
        return await self.update_mcp_server_by_id(resolved_id, mcp_server_update, actor)

    # No ID for this name: report it as a structured 404
    not_found_detail = {
        "code": "MCPServerNotFoundError",
        "message": f"MCP server {mcp_server_name} not found",
        "mcp_server_name": mcp_server_name,
    }
    raise HTTPException(status_code=404, detail=not_found_detail)
194
+
180
195
  @enforce_types
181
196
  async def get_mcp_server_id_by_name(self, mcp_server_name: str, actor: PydanticUser) -> Optional[str]:
182
197
  """Retrieve a MCP server by its name and a user"""
@@ -71,6 +71,25 @@ class ProviderManager:
71
71
  existing_provider.update(session, actor=actor)
72
72
  return existing_provider.to_pydantic()
73
73
 
74
@enforce_types
@trace_method
async def update_provider_async(self, provider_id: str, provider_update: ProviderUpdate, actor: PydanticUser) -> PydanticProvider:
    """Apply a partial update to a provider and return it in pydantic form.

    Only fields that were explicitly set on ``provider_update`` and are not
    ``None`` are copied onto the stored provider.
    """
    async with db_registry.async_session() as session:
        # Load the provider to modify (lookup is done with check_is_deleted=True)
        provider = await ProviderModel.read_async(
            db_session=session,
            identifier=provider_id,
            actor=actor,
            check_is_deleted=True,
        )

        # Copy over just the fields the caller supplied
        changes = provider_update.model_dump(to_orm=True, exclude_unset=True, exclude_none=True)
        for field_name in changes:
            setattr(provider, field_name, changes[field_name])

        # Persist, then hand back the pydantic view
        await provider.update_async(session, actor=actor)
        return provider.to_pydantic()
92
+
74
93
  @enforce_types
75
94
  @trace_method
76
95
  def delete_provider_by_id(self, provider_id: str, actor: PydanticUser):
@@ -175,6 +194,15 @@ class ProviderManager:
175
194
  providers = await self.list_providers_async(name=provider_name, actor=actor)
176
195
  return providers[0].api_key if providers else None
177
196
 
197
@enforce_types
@trace_method
async def get_bedrock_credentials_async(
    self, provider_name: Union[str, None], actor: PydanticUser
) -> tuple[Optional[str], Optional[str], Optional[str]]:
    """Look up the Bedrock credentials stored on a named provider.

    Providers are listed by name for ``actor`` and the first match is used.

    Returns:
        ``(access_key, secret_key, region)``, read from the provider's
        ``api_key``, ``api_secret`` and ``region`` fields. All three are
        ``None`` when no provider matches.
    """
    providers = await self.list_providers_async(name=provider_name, actor=actor)
    if not providers:
        return None, None, None

    provider = providers[0]
    return provider.api_key, provider.api_secret, provider.region
205
+
178
206
  @enforce_types
179
207
  @trace_method
180
208
  def check_provider_api_key(self, provider_check: ProviderCheck) -> None:
@@ -183,6 +211,8 @@ class ProviderManager:
183
211
  provider_type=provider_check.provider_type,
184
212
  api_key=provider_check.api_key,
185
213
  provider_category=ProviderCategory.byok,
214
+ secret_key=provider_check.api_secret,
215
+ region=provider_check.region,
186
216
  ).cast_to_subtype()
187
217
 
188
218
  # TODO: add more string sanity checks here before we hit actual endpoints