letta-nightly 0.8.4.dev20250619104255__py3-none-any.whl → 0.8.5.dev20250619180801__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- letta/__init__.py +1 -1
- letta/agents/letta_agent.py +54 -20
- letta/agents/voice_agent.py +47 -31
- letta/constants.py +1 -1
- letta/data_sources/redis_client.py +11 -6
- letta/functions/function_sets/builtin.py +35 -11
- letta/functions/prompts.py +26 -0
- letta/functions/types.py +6 -0
- letta/interfaces/openai_chat_completions_streaming_interface.py +0 -1
- letta/llm_api/anthropic.py +9 -1
- letta/llm_api/anthropic_client.py +8 -11
- letta/llm_api/aws_bedrock.py +10 -6
- letta/llm_api/llm_api_tools.py +3 -0
- letta/llm_api/openai_client.py +1 -1
- letta/orm/agent.py +14 -1
- letta/orm/job.py +3 -0
- letta/orm/provider.py +3 -1
- letta/schemas/agent.py +7 -0
- letta/schemas/embedding_config.py +8 -0
- letta/schemas/enums.py +0 -1
- letta/schemas/job.py +1 -0
- letta/schemas/providers.py +13 -5
- letta/server/rest_api/routers/v1/agents.py +76 -35
- letta/server/rest_api/routers/v1/providers.py +7 -7
- letta/server/rest_api/routers/v1/sources.py +39 -19
- letta/server/rest_api/routers/v1/tools.py +96 -31
- letta/services/agent_manager.py +8 -2
- letta/services/file_processor/chunker/llama_index_chunker.py +89 -1
- letta/services/file_processor/embedder/openai_embedder.py +6 -1
- letta/services/file_processor/parser/mistral_parser.py +2 -2
- letta/services/helpers/agent_manager_helper.py +44 -16
- letta/services/job_manager.py +35 -17
- letta/services/mcp/base_client.py +26 -1
- letta/services/mcp_manager.py +33 -18
- letta/services/provider_manager.py +30 -0
- letta/services/tool_executor/builtin_tool_executor.py +335 -43
- letta/services/tool_manager.py +25 -1
- letta/services/user_manager.py +1 -1
- letta/settings.py +3 -0
- {letta_nightly-0.8.4.dev20250619104255.dist-info → letta_nightly-0.8.5.dev20250619180801.dist-info}/METADATA +4 -3
- {letta_nightly-0.8.4.dev20250619104255.dist-info → letta_nightly-0.8.5.dev20250619180801.dist-info}/RECORD +44 -42
- {letta_nightly-0.8.4.dev20250619104255.dist-info → letta_nightly-0.8.5.dev20250619180801.dist-info}/LICENSE +0 -0
- {letta_nightly-0.8.4.dev20250619104255.dist-info → letta_nightly-0.8.5.dev20250619180801.dist-info}/WHEEL +0 -0
- {letta_nightly-0.8.4.dev20250619104255.dist-info → letta_nightly-0.8.5.dev20250619180801.dist-info}/entry_points.txt +0 -0
@@ -1,4 +1,4 @@
|
|
1
|
-
from typing import List
|
1
|
+
from typing import List, Tuple
|
2
2
|
|
3
3
|
from mistralai import OCRPageObject
|
4
4
|
|
@@ -27,3 +27,91 @@ class LlamaIndexChunker:
|
|
27
27
|
except Exception as e:
|
28
28
|
logger.error(f"Chunking failed: {str(e)}")
|
29
29
|
raise
|
30
|
+
|
31
|
+
|
32
|
+
class MarkdownChunker:
    """Markdown-specific chunker that preserves line numbers for citation purposes.

    Every chunk is returned together with the 1-based, inclusive range of lines
    it was located at in the original document, so callers can cite
    ``start_line``-``end_line``. Chunks never overlap, to keep citations
    unambiguous.
    """

    def __init__(self, chunk_size: int = 2048):
        """
        Args:
            chunk_size: Approximate maximum number of characters per chunk.
                Only the line-based fallback path enforces it; sections
                produced by ``MarkdownNodeParser`` are not re-split.
        """
        self.chunk_size = chunk_size
        # No overlap for line-based citations to avoid ambiguity

        # Imported here rather than at module level, so llama_index is only
        # loaded when a MarkdownChunker is actually constructed.
        from llama_index.core.node_parser import MarkdownNodeParser

        self.parser = MarkdownNodeParser()

    def chunk_markdown_with_line_numbers(self, markdown_content: str) -> List[Tuple[str, int, int]]:
        """
        Chunk markdown content while preserving line number mappings.

        Nodes from the parser are processed in order, and each chunk is looked
        up starting right after the previous chunk's end. This keeps the ranges
        monotonic and stops repeated lines (e.g. two sections with the same
        heading) from all resolving to their first occurrence.

        Returns:
            List of tuples: (chunk_text, start_line, end_line), with 1-based,
            inclusive line numbers. If parsing raises, the result of
            ``_fallback_line_chunking`` is returned instead.
        """
        try:
            # Split content into lines for line number tracking
            lines = markdown_content.split("\n")

            # Create nodes using MarkdownNodeParser
            from llama_index.core import Document

            document = Document(text=markdown_content)
            nodes = self.parser.get_nodes_from_documents([document])

            chunks_with_line_numbers = []
            # 0-based index into `lines` where the search for the next chunk begins.
            search_start = 0
            for node in nodes:
                chunk_text = node.text
                start_line, end_line = self._find_line_numbers(chunk_text, lines, search_start)
                chunks_with_line_numbers.append((chunk_text, start_line, end_line))
                # Chunks do not overlap, so the next one is searched for after this one.
                # (1-based end_line == 0-based index of the following line.)
                search_start = end_line

            return chunks_with_line_numbers

        except Exception as e:
            logger.error(f"Markdown chunking failed: {str(e)}")
            # Fallback to simple line-based chunking
            return self._fallback_line_chunking(markdown_content)

    def _find_line_numbers(self, chunk_text: str, lines: List[str], search_start: int = 0) -> Tuple[int, int]:
        """Find the 1-based (start_line, end_line) span of ``chunk_text`` within ``lines``.

        Args:
            chunk_text: Text of one chunk.
            lines: The original document split on newlines.
            search_start: 0-based index into ``lines`` at which to start
                looking; earlier lines are ignored. Defaults to 0 (whole document).

        The chunk's first line is matched against document lines either
        exactly (after stripping whitespace) or, when it is longer than 10
        characters, as a substring. If nothing matches, the chunk is assumed to
        begin at the line right after ``search_start`` (line 1 by default).
        """
        chunk_lines = chunk_text.split("\n")
        anchor = chunk_lines[0].strip()
        # Substring matches are only trusted for long lines; short ones must match exactly.
        allow_substring = len(anchor) > 10

        start_line = min(search_start + 1, max(len(lines), 1))
        if anchor:
            for i in range(search_start, len(lines)):
                candidate = lines[i].strip()
                if candidate == anchor or (allow_substring and anchor in candidate):
                    start_line = i + 1
                    break

        # Calculate end line, clamped to the document length
        end_line = start_line + len(chunk_lines) - 1
        return start_line, min(end_line, len(lines))

    def _fallback_line_chunking(self, markdown_content: str) -> List[Tuple[str, int, int]]:
        """Fallback chunking method that simply splits by lines with no overlap.

        Lines are accumulated until the chunk reaches ``self.chunk_size``
        characters (counting one newline per line). Every chunk contains at
        least one line, so a single over-long line becomes its own chunk and a
        non-positive ``chunk_size`` cannot stall the loop.
        """
        lines = markdown_content.split("\n")
        chunks = []

        i = 0
        while i < len(lines):
            start_line = i + 1
            # Always take at least one line so the loop makes progress.
            chunk_lines = [lines[i]]
            char_count = len(lines[i]) + 1  # +1 for newline
            i += 1

            # Build chunk until we hit size limit
            while i < len(lines) and char_count < self.chunk_size:
                chunk_lines.append(lines[i])
                char_count += len(lines[i]) + 1
                i += 1

            # `i` now equals the 1-based number of the last line taken.
            chunks.append(("\n".join(chunk_lines), start_line, i))

        # No overlap - each chunk continues from where the previous one stopped
        return chunks
|
@@ -16,7 +16,12 @@ class OpenAIEmbedder:
|
|
16
16
|
"""OpenAI-based embedding generation"""
|
17
17
|
|
18
18
|
def __init__(self, embedding_config: Optional[EmbeddingConfig] = None):
|
19
|
-
self.
|
19
|
+
self.default_embedding_config = (
|
20
|
+
EmbeddingConfig.default_config(model_name="text-embedding-3-small", provider="openai")
|
21
|
+
if model_settings.openai_api_key
|
22
|
+
else EmbeddingConfig.default_config(model_name="letta")
|
23
|
+
)
|
24
|
+
self.embedding_config = embedding_config or self.default_embedding_config
|
20
25
|
|
21
26
|
# TODO: Unify to global OpenAI client
|
22
27
|
self.client = openai.AsyncOpenAI(api_key=model_settings.openai_api_key)
|
@@ -20,11 +20,10 @@ class MistralFileParser(FileParser):
|
|
20
20
|
async def extract_text(self, content: bytes, mime_type: str) -> OCRResponse:
|
21
21
|
"""Extract text using Mistral OCR or shortcut for plain text."""
|
22
22
|
try:
|
23
|
-
logger.info(f"Extracting text using Mistral OCR model: {self.model}")
|
24
|
-
|
25
23
|
# TODO: Kind of hacky...we try to exit early here?
|
26
24
|
# TODO: Create our internal file parser representation we return instead of OCRResponse
|
27
25
|
if is_simple_text_mime_type(mime_type):
|
26
|
+
logger.info(f"Extracting text directly (no Mistral): {self.model}")
|
28
27
|
text = content.decode("utf-8", errors="replace")
|
29
28
|
return OCRResponse(
|
30
29
|
model=self.model,
|
@@ -43,6 +42,7 @@ class MistralFileParser(FileParser):
|
|
43
42
|
base64_encoded_content = base64.b64encode(content).decode("utf-8")
|
44
43
|
document_url = f"data:{mime_type};base64,{base64_encoded_content}"
|
45
44
|
|
45
|
+
logger.info(f"Extracting text using Mistral OCR model: {self.model}")
|
46
46
|
async with Mistral(api_key=settings.mistral_api_key) as mistral:
|
47
47
|
ocr_response = await mistral.ocr.process_async(
|
48
48
|
model="mistral-ocr-latest", document={"type": "document_url", "document_url": document_url}, include_image_base64=False
|
@@ -449,41 +449,69 @@ def _cursor_filter(created_at_col, id_col, ref_created_at, ref_id, forward: bool
|
|
449
449
|
)
|
450
450
|
|
451
451
|
|
452
|
-
def _apply_pagination(
|
452
|
+
def _apply_pagination(
    query, before: Optional[str], after: Optional[str], session, ascending: bool = True, sort_by: str = "created_at"
) -> any:
    """Add cursor filters and a deterministic ORDER BY to an agent query.

    Args:
        query: SQLAlchemy select over ``AgentModel`` to paginate.
        before: Optional agent id; its (sort value, id) pair is passed to
            ``_cursor_filter`` with ``forward=not ascending``.
        after: Optional agent id; its (sort value, id) pair is passed to
            ``_cursor_filter`` with ``forward=ascending``.
        session: Synchronous session used to look up the cursor agents.
        ascending: Sort direction for both the sort column and ``AgentModel.id``.
        sort_by: ``"last_run_completion"`` sorts on that column; any other
            value sorts on ``created_at``.

    Returns:
        The filtered and ordered query. A cursor id that matches no agent adds
        no filter. (NOTE(review): the ``any`` return hint is the builtin
        function, not ``typing.Any``.)
    """
    sort_column = AgentModel.last_run_completion if sort_by == "last_run_completion" else AgentModel.created_at

    for cursor_id, forward in ((after, ascending), (before, not ascending)):
        if not cursor_id:
            continue
        cursor_row = session.execute(select(sort_column, AgentModel.id).where(AgentModel.id == cursor_id)).first()
        if not cursor_row:
            continue
        # NOTE(review): if last_run_completion is nullable, cursor_value may be None here;
        # confirm how _cursor_filter compares against NULL.
        cursor_value, cursor_agent_id = cursor_row
        query = query.where(_cursor_filter(sort_column, AgentModel.id, cursor_value, cursor_agent_id, forward=forward))

    # AgentModel.id is the secondary key so rows with equal sort values have a stable order.
    direction = asc if ascending else desc
    return query.order_by(direction(sort_column), direction(AgentModel.id))
|
469
483
|
|
470
484
|
|
471
|
-
async def _apply_pagination_async(
|
485
|
+
async def _apply_pagination_async(
    query, before: Optional[str], after: Optional[str], session, ascending: bool = True, sort_by: str = "created_at"
) -> any:
    """Async variant: add cursor filters and a deterministic ORDER BY to an agent query.

    Args:
        query: SQLAlchemy select over ``AgentModel`` to paginate.
        before: Optional agent id; its (sort value, id) pair is passed to
            ``_cursor_filter`` with ``forward=not ascending``.
        after: Optional agent id; its (sort value, id) pair is passed to
            ``_cursor_filter`` with ``forward=ascending``.
        session: Async session used (awaited) to look up the cursor agents.
        ascending: Sort direction for both the sort column and ``AgentModel.id``.
        sort_by: ``"last_run_completion"`` sorts on that column; any other
            value sorts on ``created_at``.

    Returns:
        The filtered and ordered query. A cursor id that matches no agent adds
        no filter. (NOTE(review): the ``any`` return hint is the builtin
        function, not ``typing.Any``.)
    """
    sort_column = AgentModel.last_run_completion if sort_by == "last_run_completion" else AgentModel.created_at

    for cursor_id, forward in ((after, ascending), (before, not ascending)):
        if not cursor_id:
            continue
        lookup = await session.execute(select(sort_column, AgentModel.id).where(AgentModel.id == cursor_id))
        cursor_row = lookup.first()
        if not cursor_row:
            continue
        # NOTE(review): if last_run_completion is nullable, cursor_value may be None here;
        # confirm how _cursor_filter compares against NULL.
        cursor_value, cursor_agent_id = cursor_row
        query = query.where(_cursor_filter(sort_column, AgentModel.id, cursor_value, cursor_agent_id, forward=forward))

    # AgentModel.id is the secondary key so rows with equal sort values have a stable order.
    direction = asc if ascending else desc
    return query.order_by(direction(sort_column), direction(AgentModel.id))
|
488
516
|
|
489
517
|
|
letta/services/job_manager.py
CHANGED
@@ -6,6 +6,7 @@ from sqlalchemy import select
|
|
6
6
|
from sqlalchemy.orm import Session
|
7
7
|
|
8
8
|
from letta.helpers.datetime_helpers import get_utc_time
|
9
|
+
from letta.log import get_logger
|
9
10
|
from letta.orm.enums import JobType
|
10
11
|
from letta.orm.errors import NoResultFound
|
11
12
|
from letta.orm.job import Job as JobModel
|
@@ -28,6 +29,8 @@ from letta.schemas.user import User as PydanticUser
|
|
28
29
|
from letta.server.db import db_registry
|
29
30
|
from letta.utils import enforce_types
|
30
31
|
|
32
|
+
logger = get_logger(__name__)
|
33
|
+
|
31
34
|
|
32
35
|
class JobManager:
|
33
36
|
"""Manager class to handle business logic related to Jobs."""
|
@@ -67,18 +70,22 @@ class JobManager:
|
|
67
70
|
with db_registry.session() as session:
|
68
71
|
# Fetch the job by ID
|
69
72
|
job = self._verify_job_access(session=session, job_id=job_id, actor=actor, access=["write"])
|
73
|
+
not_completed_before = not bool(job.completed_at)
|
70
74
|
|
71
75
|
# Update job attributes with only the fields that were explicitly set
|
72
76
|
update_data = job_update.model_dump(to_orm=True, exclude_unset=True, exclude_none=True)
|
73
77
|
|
74
78
|
# Automatically update the completion timestamp if status is set to 'completed'
|
75
79
|
for key, value in update_data.items():
|
80
|
+
# Ensure completed_at is timezone-naive for database compatibility
|
81
|
+
if key == "completed_at" and value is not None and hasattr(value, "replace"):
|
82
|
+
value = value.replace(tzinfo=None)
|
76
83
|
setattr(job, key, value)
|
77
84
|
|
78
|
-
if
|
85
|
+
if job_update.status in {JobStatus.completed, JobStatus.failed} and not_completed_before:
|
79
86
|
job.completed_at = get_utc_time().replace(tzinfo=None)
|
80
87
|
if job.callback_url:
|
81
|
-
self._dispatch_callback(
|
88
|
+
self._dispatch_callback(job)
|
82
89
|
|
83
90
|
# Save the updated job to the database
|
84
91
|
job.update(db_session=session, actor=actor)
|
@@ -92,18 +99,22 @@ class JobManager:
|
|
92
99
|
async with db_registry.async_session() as session:
|
93
100
|
# Fetch the job by ID
|
94
101
|
job = await self._verify_job_access_async(session=session, job_id=job_id, actor=actor, access=["write"])
|
102
|
+
not_completed_before = not bool(job.completed_at)
|
95
103
|
|
96
104
|
# Update job attributes with only the fields that were explicitly set
|
97
105
|
update_data = job_update.model_dump(to_orm=True, exclude_unset=True, exclude_none=True)
|
98
106
|
|
99
107
|
# Automatically update the completion timestamp if status is set to 'completed'
|
100
108
|
for key, value in update_data.items():
|
109
|
+
# Ensure completed_at is timezone-naive for database compatibility
|
110
|
+
if key == "completed_at" and value is not None and hasattr(value, "replace"):
|
111
|
+
value = value.replace(tzinfo=None)
|
101
112
|
setattr(job, key, value)
|
102
113
|
|
103
|
-
if
|
114
|
+
if job_update.status in {JobStatus.completed, JobStatus.failed} and not_completed_before:
|
104
115
|
job.completed_at = get_utc_time().replace(tzinfo=None)
|
105
116
|
if job.callback_url:
|
106
|
-
await self._dispatch_callback_async(
|
117
|
+
await self._dispatch_callback_async(job)
|
107
118
|
|
108
119
|
# Save the updated job to the database
|
109
120
|
await job.update_async(db_session=session, actor=actor)
|
@@ -586,7 +597,7 @@ class JobManager:
|
|
586
597
|
request_config = job.request_config or LettaRequestConfig()
|
587
598
|
return request_config
|
588
599
|
|
589
|
-
def _dispatch_callback(self,
|
600
|
+
def _dispatch_callback(self, job: JobModel) -> None:
|
590
601
|
"""
|
591
602
|
POST a standard JSON payload to job.callback_url
|
592
603
|
and record timestamp + HTTP status.
|
@@ -595,22 +606,25 @@ class JobManager:
|
|
595
606
|
payload = {
|
596
607
|
"job_id": job.id,
|
597
608
|
"status": job.status,
|
598
|
-
"completed_at": job.completed_at.isoformat(),
|
609
|
+
"completed_at": job.completed_at.isoformat() if job.completed_at else None,
|
610
|
+
"metadata": job.metadata_,
|
599
611
|
}
|
600
612
|
try:
|
601
613
|
import httpx
|
602
614
|
|
603
615
|
resp = httpx.post(job.callback_url, json=payload, timeout=5.0)
|
604
|
-
job.callback_sent_at = get_utc_time()
|
616
|
+
job.callback_sent_at = get_utc_time().replace(tzinfo=None)
|
605
617
|
job.callback_status_code = resp.status_code
|
606
618
|
|
607
|
-
except Exception:
|
608
|
-
|
609
|
-
|
610
|
-
|
611
|
-
|
619
|
+
except Exception as e:
|
620
|
+
error_message = f"Failed to dispatch callback for job {job.id} to {job.callback_url}: {str(e)}"
|
621
|
+
logger.error(error_message)
|
622
|
+
# Record the failed attempt
|
623
|
+
job.callback_sent_at = get_utc_time().replace(tzinfo=None)
|
624
|
+
job.callback_error = error_message
|
625
|
+
# Continue silently - callback failures should not affect job completion
|
612
626
|
|
613
|
-
async def _dispatch_callback_async(self,
|
627
|
+
async def _dispatch_callback_async(self, job: JobModel) -> None:
|
614
628
|
"""
|
615
629
|
POST a standard JSON payload to job.callback_url and record timestamp + HTTP status asynchronously.
|
616
630
|
"""
|
@@ -618,6 +632,7 @@ class JobManager:
|
|
618
632
|
"job_id": job.id,
|
619
633
|
"status": job.status,
|
620
634
|
"completed_at": job.completed_at.isoformat() if job.completed_at else None,
|
635
|
+
"metadata": job.metadata_,
|
621
636
|
}
|
622
637
|
|
623
638
|
try:
|
@@ -628,7 +643,10 @@ class JobManager:
|
|
628
643
|
# Ensure timestamp is timezone-naive for DB compatibility
|
629
644
|
job.callback_sent_at = get_utc_time().replace(tzinfo=None)
|
630
645
|
job.callback_status_code = resp.status_code
|
631
|
-
except Exception:
|
632
|
-
|
633
|
-
|
634
|
-
|
646
|
+
except Exception as e:
|
647
|
+
error_message = f"Failed to dispatch callback for job {job.id} to {job.callback_url}: {str(e)}"
|
648
|
+
logger.error(error_message)
|
649
|
+
# Record the failed attempt
|
650
|
+
job.callback_sent_at = get_utc_time().replace(tzinfo=None)
|
651
|
+
job.callback_error = error_message
|
652
|
+
# Continue silently - callback failures should not affect job completion
|
@@ -1,3 +1,4 @@
|
|
1
|
+
import asyncio
|
1
2
|
from contextlib import AsyncExitStack
|
2
3
|
from typing import Optional, Tuple
|
3
4
|
|
@@ -18,6 +19,9 @@ class AsyncBaseMCPClient:
|
|
18
19
|
self.exit_stack = AsyncExitStack()
|
19
20
|
self.session: Optional[ClientSession] = None
|
20
21
|
self.initialized = False
|
22
|
+
# Track the task that created this client
|
23
|
+
self._creation_task = asyncio.current_task()
|
24
|
+
self._cleanup_queue = asyncio.Queue(maxsize=1)
|
21
25
|
|
22
26
|
async def connect_to_server(self):
|
23
27
|
try:
|
@@ -74,8 +78,29 @@ class AsyncBaseMCPClient:
|
|
74
78
|
raise RuntimeError("MCPClient has not been initialized")
|
75
79
|
|
76
80
|
async def cleanup(self):
    """Clean up resources by closing ``self.exit_stack``.

    Every context registered on the exit stack (presumably the transport and
    session entered while connecting — confirm in ``connect_to_server``) is
    exited here, in the task that calls ``cleanup``.
    """
    # Cleanup always runs in the calling task. A hand-off to the creating task
    # via `self._cleanup_queue` is not implemented: nothing shown in this class
    # consumes that queue, so awaiting such a hand-off would block forever.
    # NOTE(review): some async context managers must be exited from the task
    # that entered them — confirm whether callers rely on that.
    await self.exit_stack.aclose()
|
79
94
|
|
80
95
|
def to_sync_client(self):
    """Hook for subclasses to provide a synchronous client (per the method name).

    Raises:
        NotImplementedError: always, in this base implementation.
    """
    message = "Subclasses must implement to_sync_client"
    raise NotImplementedError(message)
|
97
|
+
|
98
|
+
async def __aenter__(self):
    """Support ``async with``: connect via ``connect_to_server`` and yield the client itself."""
    connect = self.connect_to_server
    await connect()
    return self
|
102
|
+
|
103
|
+
async def __aexit__(self, exc_type, exc_val, exc_tb):
    """Run ``cleanup`` when leaving an ``async with`` block.

    Returns False, so an exception raised inside the block is not suppressed.
    """
    await self.cleanup()
    suppress_exception = False
    return suppress_exception
|
letta/services/mcp_manager.py
CHANGED
@@ -75,23 +75,23 @@ class MCPManager:
|
|
75
75
|
server_config = mcp_config[mcp_server_name]
|
76
76
|
|
77
77
|
if isinstance(server_config, SSEServerConfig):
|
78
|
-
mcp_client = AsyncSSEMCPClient(server_config=server_config)
|
78
|
+
# mcp_client = AsyncSSEMCPClient(server_config=server_config)
|
79
|
+
async with AsyncSSEMCPClient(server_config=server_config) as mcp_client:
|
80
|
+
result, success = await mcp_client.execute_tool(tool_name, tool_args)
|
81
|
+
logger.info(f"MCP Result: {result}, Success: {success}")
|
82
|
+
return result, success
|
79
83
|
elif isinstance(server_config, StdioServerConfig):
|
80
|
-
|
84
|
+
async with AsyncStdioMCPClient(server_config=server_config) as mcp_client:
|
85
|
+
result, success = await mcp_client.execute_tool(tool_name, tool_args)
|
86
|
+
logger.info(f"MCP Result: {result}, Success: {success}")
|
87
|
+
return result, success
|
81
88
|
elif isinstance(server_config, StreamableHTTPServerConfig):
|
82
|
-
|
89
|
+
async with AsyncStreamableHTTPMCPClient(server_config=server_config) as mcp_client:
|
90
|
+
result, success = await mcp_client.execute_tool(tool_name, tool_args)
|
91
|
+
logger.info(f"MCP Result: {result}, Success: {success}")
|
92
|
+
return result, success
|
83
93
|
else:
|
84
94
|
raise ValueError(f"Unsupported server config type: {type(server_config)}")
|
85
|
-
await mcp_client.connect_to_server()
|
86
|
-
|
87
|
-
# call tool
|
88
|
-
result, success = await mcp_client.execute_tool(tool_name, tool_args)
|
89
|
-
logger.info(f"MCP Result: {result}, Success: {success}")
|
90
|
-
# TODO: change to pydantic tool
|
91
|
-
|
92
|
-
await mcp_client.cleanup()
|
93
|
-
|
94
|
-
return result, success
|
95
95
|
|
96
96
|
@enforce_types
|
97
97
|
async def add_tool_from_mcp_server(self, mcp_server_name: str, mcp_tool_name: str, actor: PydanticUser) -> PydanticTool:
|
@@ -149,19 +149,19 @@ class MCPManager:
|
|
149
149
|
return mcp_server
|
150
150
|
|
151
151
|
@enforce_types
async def create_mcp_server(self, pydantic_mcp_server: MCPServer, actor: PydanticUser) -> MCPServer:
    """Create a new MCP server.

    The server is scoped to ``actor``'s organization. Note that
    ``organization_id`` is assigned on the passed-in ``pydantic_mcp_server``,
    so the argument is mutated before the ORM row is built.

    Returns:
        The persisted server converted back to its Pydantic schema.
    """
    async with db_registry.async_session() as session:
        # Scope the new server to the actor's organization
        pydantic_mcp_server.organization_id = actor.organization_id
        orm_server = MCPServerModel(**pydantic_mcp_server.model_dump(to_orm=True))
        created = await orm_server.create_async(session, actor=actor)
        return created.to_pydantic()
|
162
162
|
|
163
163
|
@enforce_types
|
164
|
-
async def update_mcp_server_by_id(self, mcp_server_id: str, mcp_server_update: UpdateMCPServer, actor: PydanticUser) ->
|
164
|
+
async def update_mcp_server_by_id(self, mcp_server_id: str, mcp_server_update: UpdateMCPServer, actor: PydanticUser) -> MCPServer:
|
165
165
|
"""Update a tool by its ID with the given ToolUpdate object."""
|
166
166
|
async with db_registry.async_session() as session:
|
167
167
|
# Fetch the tool by ID
|
@@ -177,6 +177,21 @@ class MCPManager:
|
|
177
177
|
# Save the updated tool to the database mcp_server = await mcp_server.update_async(db_session=session, actor=actor)
|
178
178
|
return mcp_server.to_pydantic()
|
179
179
|
|
180
|
+
@enforce_types
async def update_mcp_server_by_name(self, mcp_server_name: str, mcp_server_update: UpdateMCPServer, actor: PydanticUser) -> MCPServer:
    """Update an MCP server identified by its name.

    The name is resolved with ``get_mcp_server_id_by_name`` and the update is
    delegated to ``update_mcp_server_by_id``.

    Raises:
        HTTPException: 404 with code ``MCPServerNotFoundError`` when no id is
            found for ``mcp_server_name``.
    """
    server_id = await self.get_mcp_server_id_by_name(mcp_server_name, actor)
    if server_id:
        return await self.update_mcp_server_by_id(server_id, mcp_server_update, actor)

    not_found = {
        "code": "MCPServerNotFoundError",
        "message": f"MCP server {mcp_server_name} not found",
        "mcp_server_name": mcp_server_name,
    }
    raise HTTPException(status_code=404, detail=not_found)
|
194
|
+
|
180
195
|
@enforce_types
|
181
196
|
async def get_mcp_server_id_by_name(self, mcp_server_name: str, actor: PydanticUser) -> Optional[str]:
|
182
197
|
"""Retrieve a MCP server by its name and a user"""
|
@@ -71,6 +71,25 @@ class ProviderManager:
|
|
71
71
|
existing_provider.update(session, actor=actor)
|
72
72
|
return existing_provider.to_pydantic()
|
73
73
|
|
74
|
+
@enforce_types
@trace_method
async def update_provider_async(self, provider_id: str, provider_update: ProviderUpdate, actor: PydanticUser) -> PydanticProvider:
    """Update provider details.

    Only fields that were explicitly set (and are not None) on
    ``provider_update`` are copied onto the stored provider.

    Returns:
        The updated provider as a Pydantic schema.
    """
    async with db_registry.async_session() as session:
        provider = await ProviderModel.read_async(
            db_session=session, identifier=provider_id, actor=actor, check_is_deleted=True
        )

        changes = provider_update.model_dump(to_orm=True, exclude_unset=True, exclude_none=True)
        for field_name in changes:
            setattr(provider, field_name, changes[field_name])

        # Persist the modified row
        await provider.update_async(session, actor=actor)
        return provider.to_pydantic()
|
92
|
+
|
74
93
|
@enforce_types
|
75
94
|
@trace_method
|
76
95
|
def delete_provider_by_id(self, provider_id: str, actor: PydanticUser):
|
@@ -175,6 +194,15 @@ class ProviderManager:
|
|
175
194
|
providers = await self.list_providers_async(name=provider_name, actor=actor)
|
176
195
|
return providers[0].api_key if providers else None
|
177
196
|
|
197
|
+
@enforce_types
@trace_method
async def get_bedrock_credentials_async(
    self, provider_name: Union[str, None], actor: PydanticUser
) -> tuple[Optional[str], Optional[str], Optional[str]]:
    """Look up AWS Bedrock credentials stored on a provider.

    Uses the first provider returned by ``list_providers_async`` for
    ``provider_name``.

    NOTE(review): if ``list_providers_async(name=None)`` does not filter by
    name, a None ``provider_name`` yields the credentials of whichever
    provider is listed first — confirm this is intended.

    Returns:
        ``(access_key, secret_key, region)`` taken from the provider's
        ``api_key``, ``api_secret`` and ``region`` fields, or
        ``(None, None, None)`` when no provider is found.
    """
    providers = await self.list_providers_async(name=provider_name, actor=actor)
    if not providers:
        return None, None, None
    provider = providers[0]
    return provider.api_key, provider.api_secret, provider.region
|
205
|
+
|
178
206
|
@enforce_types
|
179
207
|
@trace_method
|
180
208
|
def check_provider_api_key(self, provider_check: ProviderCheck) -> None:
|
@@ -183,6 +211,8 @@ class ProviderManager:
|
|
183
211
|
provider_type=provider_check.provider_type,
|
184
212
|
api_key=provider_check.api_key,
|
185
213
|
provider_category=ProviderCategory.byok,
|
214
|
+
secret_key=provider_check.api_secret,
|
215
|
+
region=provider_check.region,
|
186
216
|
).cast_to_subtype()
|
187
217
|
|
188
218
|
# TODO: add more string sanity checks here before we hit actual endpoints
|