alita-sdk 0.3.435__py3-none-any.whl → 0.3.457__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of alita-sdk might be problematic. Click here for more details.
- alita_sdk/runtime/clients/client.py +39 -7
- alita_sdk/runtime/langchain/assistant.py +10 -2
- alita_sdk/runtime/langchain/langraph_agent.py +57 -15
- alita_sdk/runtime/langchain/utils.py +19 -3
- alita_sdk/runtime/models/mcp_models.py +4 -0
- alita_sdk/runtime/toolkits/artifact.py +5 -6
- alita_sdk/runtime/toolkits/mcp.py +258 -150
- alita_sdk/runtime/toolkits/tools.py +44 -2
- alita_sdk/runtime/tools/function.py +2 -1
- alita_sdk/runtime/tools/mcp_remote_tool.py +166 -0
- alita_sdk/runtime/tools/mcp_server_tool.py +9 -76
- alita_sdk/runtime/tools/vectorstore_base.py +17 -2
- alita_sdk/runtime/utils/mcp_oauth.py +164 -0
- alita_sdk/runtime/utils/mcp_sse_client.py +405 -0
- alita_sdk/runtime/utils/toolkit_utils.py +9 -2
- alita_sdk/tools/ado/repos/__init__.py +1 -0
- alita_sdk/tools/ado/test_plan/__init__.py +1 -1
- alita_sdk/tools/ado/wiki/__init__.py +1 -5
- alita_sdk/tools/ado/work_item/__init__.py +1 -5
- alita_sdk/tools/base_indexer_toolkit.py +10 -6
- alita_sdk/tools/bitbucket/__init__.py +1 -0
- alita_sdk/tools/code/sonar/__init__.py +1 -1
- alita_sdk/tools/confluence/__init__.py +2 -2
- alita_sdk/tools/github/__init__.py +2 -2
- alita_sdk/tools/gitlab/__init__.py +2 -1
- alita_sdk/tools/gitlab_org/__init__.py +1 -2
- alita_sdk/tools/google_places/__init__.py +2 -1
- alita_sdk/tools/jira/__init__.py +1 -0
- alita_sdk/tools/memory/__init__.py +1 -1
- alita_sdk/tools/pandas/__init__.py +1 -1
- alita_sdk/tools/postman/__init__.py +2 -1
- alita_sdk/tools/pptx/__init__.py +2 -2
- alita_sdk/tools/qtest/__init__.py +3 -3
- alita_sdk/tools/qtest/api_wrapper.py +374 -29
- alita_sdk/tools/rally/__init__.py +1 -2
- alita_sdk/tools/report_portal/__init__.py +1 -0
- alita_sdk/tools/salesforce/__init__.py +1 -0
- alita_sdk/tools/servicenow/__init__.py +2 -3
- alita_sdk/tools/sharepoint/__init__.py +1 -0
- alita_sdk/tools/slack/__init__.py +1 -0
- alita_sdk/tools/sql/__init__.py +2 -1
- alita_sdk/tools/testio/__init__.py +1 -0
- alita_sdk/tools/testrail/__init__.py +1 -3
- alita_sdk/tools/xray/__init__.py +2 -1
- alita_sdk/tools/zephyr/__init__.py +2 -1
- alita_sdk/tools/zephyr_enterprise/__init__.py +1 -0
- alita_sdk/tools/zephyr_essential/__init__.py +1 -0
- alita_sdk/tools/zephyr_scale/__init__.py +1 -0
- alita_sdk/tools/zephyr_squad/__init__.py +1 -0
- {alita_sdk-0.3.435.dist-info → alita_sdk-0.3.457.dist-info}/METADATA +2 -1
- {alita_sdk-0.3.435.dist-info → alita_sdk-0.3.457.dist-info}/RECORD +54 -51
- {alita_sdk-0.3.435.dist-info → alita_sdk-0.3.457.dist-info}/WHEEL +0 -0
- {alita_sdk-0.3.435.dist-info → alita_sdk-0.3.457.dist-info}/licenses/LICENSE +0 -0
- {alita_sdk-0.3.435.dist-info → alita_sdk-0.3.457.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,166 @@
|
|
|
1
|
+
"""
|
|
2
|
+
MCP Remote Tool for direct HTTP/SSE invocation.
|
|
3
|
+
This tool is used for remote MCP servers accessed via HTTP/SSE.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import asyncio
|
|
7
|
+
import json
|
|
8
|
+
import logging
|
|
9
|
+
import time
|
|
10
|
+
import uuid
|
|
11
|
+
from concurrent.futures import ThreadPoolExecutor
|
|
12
|
+
from typing import Any, Dict, Optional
|
|
13
|
+
|
|
14
|
+
from .mcp_server_tool import McpServerTool
|
|
15
|
+
from pydantic import Field
|
|
16
|
+
from ..utils.mcp_oauth import (
|
|
17
|
+
McpAuthorizationRequired,
|
|
18
|
+
canonical_resource,
|
|
19
|
+
extract_resource_metadata_url,
|
|
20
|
+
fetch_resource_metadata_async,
|
|
21
|
+
infer_authorization_servers_from_realm,
|
|
22
|
+
)
|
|
23
|
+
from ..utils.mcp_sse_client import McpSseClient
|
|
24
|
+
|
|
25
|
+
# Module-level logger used by McpRemoteTool for execution/diagnostic messages.
logger = logging.getLogger(__name__)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class McpRemoteTool(McpServerTool):
    """
    Tool for invoking remote MCP server tools via HTTP/SSE.

    Extends McpServerTool and overrides ``_run`` to talk to the remote server
    directly through ``McpSseClient`` instead of ``client.mcp_tool_call``.
    A ``session_id`` is required to execute; when one is set it is also
    mirrored into the tool's ``metadata`` (see ``get_session_metadata``).
    """

    # Remote MCP connection details
    server_url: str = Field(..., description="URL of the remote MCP server")
    server_headers: Optional[Dict[str, str]] = Field(default=None, description="HTTP headers for authentication")
    original_tool_name: Optional[str] = Field(default=None, description="Original tool name from MCP server (before optimization)")
    is_prompt: bool = False  # Flag to indicate if this is a prompt tool
    prompt_name: Optional[str] = None  # Original prompt name if this is a prompt
    session_id: Optional[str] = Field(default=None, description="MCP session ID for stateful SSE servers")

    def model_post_init(self, __context: Any) -> None:
        """Update metadata with session info after model initialization."""
        super().model_post_init(__context)
        self._update_metadata_with_session()

    def _update_metadata_with_session(self):
        """Merge the session metadata into ``self.metadata`` when a session id is set."""
        if not self.session_id:
            return
        if self.metadata is None:
            self.metadata = {}
        self.metadata.update(self.get_session_metadata())

    def __getstate__(self):
        """
        Custom serialization for pickle compatibility.

        Coerces ``server_headers`` to a plain ``dict``. Pydantic v2's
        ``BaseModel.__getstate__`` nests field values under ``'__dict__'``,
        while older/custom implementations return a flat mapping; both shapes
        are handled. The live instance ``__dict__`` is never mutated.
        """
        state = dict(super().__getstate__())
        fields = state.get('__dict__')
        if isinstance(fields, dict):
            if fields.get('server_headers') is not None:
                fields = dict(fields)
                fields['server_headers'] = dict(fields['server_headers'])
                state['__dict__'] = fields
        elif state.get('server_headers') is not None:
            state['server_headers'] = dict(state['server_headers'])
        return state

    def _run(self, *args, **kwargs):
        """
        Execute the MCP tool via a direct HTTP/SSE call to the remote server.

        The async call runs in a fresh event loop on a worker thread, and the
        caller waits at most ``tool_timeout_sec`` seconds for it.

        :raises McpAuthorizationRequired: re-raised unchanged so LangChain can
            surface the OAuth requirement with its metadata.
        :return: the formatted tool output, or an ``"Error executing tool: ..."``
            string for any other failure (including a timeout).
        """
        from concurrent.futures import TimeoutError as FutureTimeoutError

        # Deliberately not a ``with`` block: leaving one calls
        # ``shutdown(wait=True)``, which would block until a hung call finishes
        # and make the timeout below ineffective.
        executor = ThreadPoolExecutor(max_workers=1)
        try:
            future = executor.submit(self._run_in_new_loop, kwargs)
            return future.result(timeout=self.tool_timeout_sec)
        except McpAuthorizationRequired:
            # Bubble up so LangChain can surface a tool error with useful metadata
            raise
        except FutureTimeoutError:
            logger.error(f"Remote MCP tool '{self.name}' timed out after {self.tool_timeout_sec}s")
            return f"Error executing tool: timed out after {self.tool_timeout_sec}s"
        except Exception as e:
            logger.error(f"Error executing remote MCP tool '{self.name}': {e}")
            return f"Error executing tool: {e}"
        finally:
            executor.shutdown(wait=False)

    def _run_in_new_loop(self, kwargs: Dict[str, Any]) -> str:
        """Run the async tool invocation in a new event loop."""
        return asyncio.run(self._execute_remote_tool(kwargs))

    async def _execute_remote_tool(self, kwargs: Dict[str, Any]) -> str:
        """
        Execute the actual remote MCP tool call using the SSE client.

        :param kwargs: tool arguments, forwarded unchanged to ``call_tool``.
        :raises Exception: when ``session_id`` is missing, or re-raised from
            the SSE client after logging.
        """
        from ...tools.utils import TOOLKIT_SPLITTER

        if not self.session_id:
            logger.error(f"[MCP Session] Missing session_id for tool '{self.name}'")
            raise Exception("sessionId required. Frontend must generate UUID and send with mcp_tokens.")

        # Prefer the tool name as discovered on the server; fall back to the
        # part of the (possibly toolkit-prefixed) name after the last splitter.
        tool_name_for_server = self.original_tool_name
        if not tool_name_for_server:
            tool_name_for_server = self.name.rsplit(TOOLKIT_SPLITTER, 1)[-1]
            logger.warning(f"original_tool_name not set for '{self.name}', using extracted: {tool_name_for_server}")

        logger.info(f"[MCP SSE] Executing tool '{tool_name_for_server}' with session {self.session_id}")

        try:
            client = McpSseClient(
                url=self.server_url,
                session_id=self.session_id,
                headers=dict(self.server_headers or {}),
                timeout=self.tool_timeout_sec,
            )
            result = await client.call_tool(tool_name_for_server, kwargs)
            return self._format_result(result)
        except Exception as e:
            logger.error(f"[MCP SSE] Tool execution failed: {e}", exc_info=True)
            raise

    @staticmethod
    def _format_result(result: Any) -> str:
        """
        Render an MCP tool-call result as text.

        - dict whose ``content`` is a list: one line per item (dict items with a
          ``text`` key contribute that value, other dicts are JSON-encoded,
          non-dict items are ``str()``-ed), joined with newlines;
        - any other dict: pretty-printed JSON;
        - anything else: ``str(result)``.
        """
        if not isinstance(result, dict):
            return str(result)
        content_items = result.get("content")
        if isinstance(content_items, list):
            text_parts = []
            for item in content_items:
                if isinstance(item, dict):
                    text_parts.append(item["text"] if "text" in item else json.dumps(item))
                else:
                    text_parts.append(str(item))
            return "\n".join(text_parts)
        return json.dumps(result, indent=2)

    def _parse_sse(self, text: str) -> Dict[str, Any]:
        """
        Parse a Server-Sent Events (SSE) response body.

        Returns the JSON payload of the first ``data:`` line.

        :raises ValueError: if no ``data:`` line is present.
        """
        for line in text.split('\n'):
            line = line.strip()
            if line.startswith('data:'):
                return json.loads(line[5:].strip())
        raise ValueError("No data found in SSE response")

    def get_session_metadata(self) -> dict:
        """Return session metadata to be included in tool responses (empty without a session)."""
        if self.session_id:
            return {
                'mcp_session_id': self.session_id,
                'mcp_server_url': canonical_resource(self.server_url)
            }
        return {}
|
|
@@ -15,63 +15,12 @@ class McpServerTool(BaseTool):
|
|
|
15
15
|
description: str
|
|
16
16
|
args_schema: Optional[Type[BaseModel]] = None
|
|
17
17
|
return_type: str = "str"
|
|
18
|
-
client: Any
|
|
18
|
+
client: Any
|
|
19
19
|
server: str
|
|
20
20
|
tool_timeout_sec: int = 60
|
|
21
|
-
is_prompt: bool = False # Flag to indicate if this is a prompt tool
|
|
22
|
-
prompt_name: Optional[str] = None # Original prompt name if this is a prompt
|
|
23
21
|
|
|
24
22
|
model_config = ConfigDict(arbitrary_types_allowed=True)
|
|
25
23
|
|
|
26
|
-
def __getstate__(self):
|
|
27
|
-
"""Custom serialization to exclude non-serializable objects."""
|
|
28
|
-
state = self.__dict__.copy()
|
|
29
|
-
# Remove the client since it contains threading objects that can't be pickled
|
|
30
|
-
state['client'] = None
|
|
31
|
-
# Store args_schema as a schema dict instead of the dynamic class
|
|
32
|
-
if hasattr(self, 'args_schema') and self.args_schema is not None:
|
|
33
|
-
# Convert the Pydantic model back to schema dict for pickling
|
|
34
|
-
try:
|
|
35
|
-
state['_args_schema_dict'] = self.args_schema.model_json_schema()
|
|
36
|
-
state['args_schema'] = None
|
|
37
|
-
except Exception as e:
|
|
38
|
-
logger.warning(f"Failed to serialize args_schema: {e}")
|
|
39
|
-
# If conversion fails, just remove it
|
|
40
|
-
state['args_schema'] = None
|
|
41
|
-
state['_args_schema_dict'] = {}
|
|
42
|
-
return state
|
|
43
|
-
|
|
44
|
-
def __setstate__(self, state):
|
|
45
|
-
"""Custom deserialization to handle missing objects."""
|
|
46
|
-
# Restore the args_schema from the stored schema dict
|
|
47
|
-
args_schema_dict = state.pop('_args_schema_dict', {})
|
|
48
|
-
|
|
49
|
-
# Initialize required Pydantic internal attributes
|
|
50
|
-
if '__pydantic_fields_set__' not in state:
|
|
51
|
-
state['__pydantic_fields_set__'] = set(state.keys())
|
|
52
|
-
if '__pydantic_extra__' not in state:
|
|
53
|
-
state['__pydantic_extra__'] = None
|
|
54
|
-
if '__pydantic_private__' not in state:
|
|
55
|
-
state['__pydantic_private__'] = None
|
|
56
|
-
|
|
57
|
-
# Directly update the object's __dict__ to bypass Pydantic validation
|
|
58
|
-
self.__dict__.update(state)
|
|
59
|
-
|
|
60
|
-
# Recreate the args_schema from the stored dict if available
|
|
61
|
-
if args_schema_dict:
|
|
62
|
-
try:
|
|
63
|
-
recreated_schema = self.create_pydantic_model_from_schema(args_schema_dict)
|
|
64
|
-
self.__dict__['args_schema'] = recreated_schema
|
|
65
|
-
except Exception as e:
|
|
66
|
-
logger.warning(f"Failed to recreate args_schema: {e}")
|
|
67
|
-
self.__dict__['args_schema'] = None
|
|
68
|
-
else:
|
|
69
|
-
self.__dict__['args_schema'] = None
|
|
70
|
-
|
|
71
|
-
# Note: client will be None after unpickling
|
|
72
|
-
# The toolkit should reinitialize the client when needed
|
|
73
|
-
|
|
74
|
-
|
|
75
24
|
@staticmethod
|
|
76
25
|
def create_pydantic_model_from_schema(schema: dict, model_name: str = "ArgsSchema"):
|
|
77
26
|
def parse_type(field: dict, name: str = "Field") -> Any:
|
|
@@ -143,30 +92,14 @@ class McpServerTool(BaseTool):
|
|
|
143
92
|
|
|
144
93
|
def _run(self, *args, **kwargs):
|
|
145
94
|
# Extract the actual tool/prompt name (remove toolkit prefix)
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
"
|
|
152
|
-
"
|
|
153
|
-
"tool_call_id": str(uuid.uuid4()),
|
|
154
|
-
"method": "prompts/get",
|
|
155
|
-
"params": {
|
|
156
|
-
"name": self.prompt_name or actual_name.replace("prompt_", ""),
|
|
157
|
-
"arguments": kwargs.get("arguments", kwargs)
|
|
158
|
-
}
|
|
159
|
-
}
|
|
160
|
-
else:
|
|
161
|
-
# For regular tools, use tools/call endpoint
|
|
162
|
-
call_data = {
|
|
163
|
-
"server": self.server,
|
|
164
|
-
"tool_timeout_sec": self.tool_timeout_sec,
|
|
165
|
-
"tool_call_id": str(uuid.uuid4()),
|
|
166
|
-
"params": {
|
|
167
|
-
"name": actual_name,
|
|
168
|
-
"arguments": kwargs
|
|
169
|
-
}
|
|
95
|
+
call_data = {
|
|
96
|
+
"server": self.server,
|
|
97
|
+
"tool_timeout_sec": self.tool_timeout_sec,
|
|
98
|
+
"tool_call_id": str(uuid.uuid4()),
|
|
99
|
+
"params": {
|
|
100
|
+
"name": self.name.rsplit(TOOLKIT_SPLITTER)[1] if TOOLKIT_SPLITTER in self.name else self.name,
|
|
101
|
+
"arguments": kwargs
|
|
170
102
|
}
|
|
103
|
+
}
|
|
171
104
|
|
|
172
105
|
return self.client.mcp_tool_call(call_data)
|
|
@@ -1,9 +1,9 @@
|
|
|
1
1
|
import json
|
|
2
|
-
import math
|
|
3
2
|
from collections import OrderedDict
|
|
4
3
|
from logging import getLogger
|
|
5
4
|
from typing import Any, Optional, List, Dict, Generator
|
|
6
5
|
|
|
6
|
+
import math
|
|
7
7
|
from langchain_core.documents import Document
|
|
8
8
|
from langchain_core.messages import HumanMessage
|
|
9
9
|
from langchain_core.tools import ToolException
|
|
@@ -12,7 +12,7 @@ from pydantic import BaseModel, model_validator, Field
|
|
|
12
12
|
|
|
13
13
|
from alita_sdk.tools.elitea_base import BaseToolApiWrapper
|
|
14
14
|
from alita_sdk.tools.vector_adapters.VectorStoreAdapter import VectorStoreAdapterFactory
|
|
15
|
-
from
|
|
15
|
+
from ...runtime.utils.utils import IndexerKeywords
|
|
16
16
|
|
|
17
17
|
logger = getLogger(__name__)
|
|
18
18
|
|
|
@@ -222,6 +222,21 @@ class VectorStoreWrapperBase(BaseToolApiWrapper):
|
|
|
222
222
|
raise RuntimeError(f"Multiple index_meta documents found: {index_metas}")
|
|
223
223
|
return index_metas[0] if index_metas else None
|
|
224
224
|
|
|
225
|
+
def get_indexed_count(self, index_name: str) -> int:
    """
    Count the rows stored under collection ``index_name``, ignoring index-meta rows.

    A row is counted when its ``cmetadata['collection']`` equals ``index_name``
    and its ``cmetadata['type']`` is either absent/NULL or different from
    ``IndexerKeywords.INDEX_META_TYPE.value``.
    """
    from sqlalchemy import func, or_
    from sqlalchemy.orm import Session

    store = self.vectorstore.EmbeddingStore
    collection_value = func.jsonb_extract_path_text(store.cmetadata, 'collection')
    type_value = func.jsonb_extract_path_text(store.cmetadata, 'type')
    # ``!=`` alone would drop rows whose type is NULL, so match them explicitly.
    not_index_meta = or_(
        type_value.is_(None),
        type_value != IndexerKeywords.INDEX_META_TYPE.value,
    )

    with Session(self.vectorstore.session_maker.bind) as session:
        id_query = session.query(store.id).filter(collection_value == index_name, not_index_meta)
        return id_query.count()
|
|
239
|
+
|
|
225
240
|
def _clean_collection(self, index_name: str = ''):
|
|
226
241
|
"""
|
|
227
242
|
Clean the vectorstore collection by deleting all indexed data.
|
|
@@ -0,0 +1,164 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import logging
|
|
3
|
+
import re
|
|
4
|
+
from typing import Any, Dict, Optional
|
|
5
|
+
from urllib.parse import urlparse
|
|
6
|
+
|
|
7
|
+
import requests
|
|
8
|
+
from langchain_core.tools import ToolException
|
|
9
|
+
|
|
10
|
+
# Module-level logger for the MCP OAuth discovery helpers below.
logger = logging.getLogger(__name__)
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class McpAuthorizationRequired(ToolException):
    """
    Raised when an MCP server requires OAuth authorization before use.

    Carries what a caller needs to start an authorization flow: the server URL,
    the protected-resource metadata URL and/or document when known, the raw
    ``WWW-Authenticate`` header, the HTTP status, and the triggering tool name.
    """

    def __init__(
        self,
        message: str,
        server_url: str,
        resource_metadata_url: Optional[str] = None,
        www_authenticate: Optional[str] = None,
        resource_metadata: Optional[Dict[str, Any]] = None,
        status: Optional[int] = None,
        tool_name: Optional[str] = None,
    ):
        super().__init__(message)
        self.server_url = server_url
        self.resource_metadata_url = resource_metadata_url
        self.www_authenticate = www_authenticate
        self.resource_metadata = resource_metadata
        self.status = status
        self.tool_name = tool_name

    def __reduce__(self):
        # BaseException pickles as ``cls(*self.args)``, and ``args`` only holds
        # ``message``. Without this override, unpickling would call __init__
        # without the required ``server_url`` and raise TypeError.
        message = self.args[0] if self.args else ""
        return (
            self.__class__,
            (
                message,
                self.server_url,
                self.resource_metadata_url,
                self.www_authenticate,
                self.resource_metadata,
                self.status,
                self.tool_name,
            ),
            self.__dict__.copy(),
        )

    def to_dict(self) -> Dict[str, Any]:
        """Return the exception's message and attributes as a plain dict."""
        return {
            "message": str(self),
            "server_url": self.server_url,
            "resource_metadata_url": self.resource_metadata_url,
            "www_authenticate": self.www_authenticate,
            "resource_metadata": self.resource_metadata,
            "status": self.status,
            "tool_name": self.tool_name,
        }
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def extract_resource_metadata_url(www_authenticate: Optional[str], server_url: Optional[str] = None) -> Optional[str]:
    """
    Pull the ``resource_metadata`` URL from a ``WWW-Authenticate`` header (RFC 9728).

    The auth-parameter name is matched case-insensitively (RFC 7235) and only
    as a whole name, so e.g. ``xresource_metadata=...`` is ignored. Quoted and
    unquoted values are both accepted.

    :param www_authenticate: raw ``WWW-Authenticate`` header value, or None.
    :param server_url: accepted for API compatibility; currently unused — no
        metadata URL is constructed from it.
    :return: the metadata URL, or None when the header carries no such parameter.
    """
    if not www_authenticate:
        # Callers fall back to realm inference / well-known OAuth discovery.
        return None

    match = re.search(
        r'(?<![\w-])resource_metadata\s*=\s*"?([^",\s]+)"?',
        www_authenticate,
        re.IGNORECASE,
    )
    return match.group(1) if match else None
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def fetch_oauth_authorization_server_metadata(base_url: str, timeout: int = 10) -> Optional[Dict[str, Any]]:
    """
    Fetch OAuth authorization server metadata from well-known endpoints.

    Tries ``/.well-known/oauth-authorization-server`` first, then
    ``/.well-known/openid-configuration``, each appended to ``base_url``.

    :param base_url: authorization server base URL; a trailing slash is ignored.
    :param timeout: per-request timeout in seconds.
    :return: the parsed JSON of the first endpoint answering HTTP 200 with
        valid JSON, or None if none does. Failures are logged at debug level.
    """
    # Avoid ``https://host//.well-known/...`` when the base URL ends with '/'.
    root = base_url.rstrip("/")
    discovery_endpoints = [
        f"{root}/.well-known/oauth-authorization-server",
        f"{root}/.well-known/openid-configuration",
    ]

    for endpoint in discovery_endpoints:
        try:
            resp = requests.get(endpoint, timeout=timeout)
            if resp.status_code == 200:
                return resp.json()
        except Exception as exc:
            # Best-effort discovery: try the next endpoint.
            logger.debug(f"Failed to fetch OAuth metadata from {endpoint}: {exc}")

    return None
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
def infer_authorization_servers_from_realm(www_authenticate: Optional[str], server_url: str) -> Optional[list]:
    """
    Infer candidate authorization server base URLs for an MCP server.

    Used when the server doesn't provide a resource_metadata endpoint. The
    candidate is the origin (``scheme://netloc``) of ``server_url``; callers
    append the ``.well-known`` discovery paths themselves.

    :param www_authenticate: raw ``WWW-Authenticate`` header; accepted for API
        compatibility, but its ``realm`` is not currently used for inference.
    :param server_url: URL of the MCP server.
    :return: a one-element list with the server origin, or None when
        ``server_url`` is missing or lacks a scheme/host.
    """
    if not server_url:
        return None

    parsed = urlparse(server_url)
    if not parsed.scheme or not parsed.netloc:
        # A relative or malformed URL has no origin to offer.
        return None

    return [f"{parsed.scheme}://{parsed.netloc}"]
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
def fetch_resource_metadata(resource_metadata_url: str, timeout: int = 10) -> Optional[Dict[str, Any]]:
    """
    Download and decode the protected resource metadata document.

    Any failure (network error, HTTP error status, invalid JSON) is logged as a
    warning and yields None, so the caller can still report that authorization
    is required even without the metadata.
    """
    try:
        response = requests.get(resource_metadata_url, timeout=timeout)
        response.raise_for_status()
        document = response.json()
    except Exception as exc:  # broad catch – we want to surface auth requirement even if this fails
        logger.warning("Failed to fetch resource metadata from %s: %s", resource_metadata_url, exc)
        return None
    return document
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
async def fetch_resource_metadata_async(resource_metadata_url: str, session=None, timeout: int = 10) -> Optional[Dict[str, Any]]:
    """
    Async variant for fetching protected resource metadata.

    :param resource_metadata_url: URL of the metadata document.
    :param session: optional existing aiohttp-style client session to reuse;
        a temporary one is created (and closed) when omitted.
    :param timeout: total request timeout in seconds.
    :return: the decoded JSON document, or None on any failure. Like the sync
        ``fetch_resource_metadata``, an HTTP error status (>= 400) counts as a
        failure rather than returning the error body as metadata. Failures are
        logged as warnings.
    """
    try:
        import aiohttp

        client_timeout = aiohttp.ClientTimeout(total=timeout)
        if session is not None:
            async with session.get(resource_metadata_url, timeout=client_timeout) as resp:
                status = resp.status
                text = await resp.text()
        else:
            async with aiohttp.ClientSession(timeout=client_timeout) as local_session:
                async with local_session.get(resource_metadata_url) as resp:
                    status = resp.status
                    text = await resp.text()

        if status >= 400:
            logger.warning("Resource metadata request to %s returned HTTP %s", resource_metadata_url, status)
            return None

        try:
            return json.loads(text)
        except json.JSONDecodeError:
            logger.warning("Resource metadata at %s is not valid JSON: %s", resource_metadata_url, text[:200])
            return None
    except Exception as exc:
        logger.warning("Failed to fetch resource metadata from %s: %s", resource_metadata_url, exc)
        return None
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
def canonical_resource(server_url: str) -> str:
    """
    Produce a canonical resource identifier for the MCP server.

    - Lower-cases the scheme and the whole network location. Note this also
      lower-cases any ``user:pass@`` userinfo.
    - Drops any fragment, since RFC 8707 resource indicators must not carry one.
    - Strips the trailing slash when the path is empty or just ``/``.
    """
    parsed = urlparse(server_url)
    normalized = parsed._replace(
        scheme=parsed.scheme.lower(),
        netloc=parsed.netloc.lower(),
        fragment="",
    )
    resource = normalized.geturl()

    # Prefer form without trailing slash unless path is meaningful
    if resource.endswith("/") and parsed.path in ("", "/"):
        resource = resource[:-1]
    return resource
|