alita-sdk 0.3.435__py3-none-any.whl → 0.3.449__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of alita-sdk might be problematic. Click here for more details.

@@ -0,0 +1,166 @@
1
+ """
2
+ MCP Remote Tool for direct HTTP/SSE invocation.
3
+ This tool is used for remote MCP servers accessed via HTTP/SSE.
4
+ """
5
+
6
+ import asyncio
7
+ import json
8
+ import logging
9
+ import time
10
+ import uuid
11
+ from concurrent.futures import ThreadPoolExecutor
12
+ from typing import Any, Dict, Optional
13
+
14
+ from .mcp_server_tool import McpServerTool
15
+ from pydantic import Field
16
+ from ..utils.mcp_oauth import (
17
+ McpAuthorizationRequired,
18
+ canonical_resource,
19
+ extract_resource_metadata_url,
20
+ fetch_resource_metadata_async,
21
+ infer_authorization_servers_from_realm,
22
+ )
23
+ from ..utils.mcp_sse_client import McpSseClient
24
+
25
+ logger = logging.getLogger(__name__)
26
+
27
+
28
class McpRemoteTool(McpServerTool):
    """
    Tool for invoking remote MCP server tools via HTTP/SSE.

    Extends McpServerTool and overrides ``_run`` to call the remote server
    directly through ``McpSseClient`` instead of ``client.mcp_tool_call``.
    """

    # Remote MCP connection details
    server_url: str = Field(..., description="URL of the remote MCP server")
    server_headers: Optional[Dict[str, str]] = Field(default=None, description="HTTP headers for authentication")
    original_tool_name: Optional[str] = Field(default=None, description="Original tool name from MCP server (before optimization)")
    is_prompt: bool = False  # Flag to indicate if this is a prompt tool
    prompt_name: Optional[str] = None  # Original prompt name if this is a prompt
    session_id: Optional[str] = Field(default=None, description="MCP session ID for stateful SSE servers")

    def model_post_init(self, __context: Any) -> None:
        """Update metadata with session info after model initialization."""
        super().model_post_init(__context)
        self._update_metadata_with_session()

    def _update_metadata_with_session(self) -> None:
        """Merge the session metadata (see ``get_session_metadata``) into ``self.metadata``.

        Does nothing when there is no session id; otherwise creates
        ``self.metadata`` first if it is None.
        """
        session_metadata = self.get_session_metadata()
        if not session_metadata:
            return
        if self.metadata is None:
            self.metadata = {}
        self.metadata.update(session_metadata)

    def __getstate__(self):
        """Custom serialization for pickle compatibility.

        Ensures ``server_headers`` is stored as a plain ``dict``. Pydantic v2's
        ``BaseModel.__getstate__`` nests field values under a ``'__dict__'`` key
        (which is the live instance ``__dict__``), so the headers are normalized
        there on a copy; a flat state dict is handled as well.
        """
        state = super().__getstate__()
        nested = isinstance(state.get('__dict__'), dict)
        fields = state['__dict__'] if nested else state
        headers = fields.get('server_headers')
        if headers is not None:
            # Build new dicts instead of mutating: in the nested layout `fields`
            # is the instance's own __dict__.
            fields = {**fields, 'server_headers': dict(headers)}
            state = {**state, '__dict__': fields} if nested else fields
        return state

    def _run(self, *args, **kwargs):
        """
        Execute the MCP tool via a direct HTTP/SSE call to the remote server.

        Overrides the parent method to avoid using ``client.mcp_tool_call``. The
        async call runs on a fresh event loop in a worker thread, so this works
        from any synchronous context.

        Returns:
            The formatted tool output, or an ``"Error executing tool: ..."``
            string for any failure other than an authorization requirement
            (including exceeding ``tool_timeout_sec``).

        Raises:
            McpAuthorizationRequired: propagated so LangChain can surface a tool
                error with useful metadata.
        """
        executor = ThreadPoolExecutor(max_workers=1)
        try:
            future = executor.submit(self._run_in_new_loop, kwargs)
            return future.result(timeout=self.tool_timeout_sec)
        except McpAuthorizationRequired:
            # Bubble up so LangChain can surface a tool error with useful metadata
            raise
        except Exception as e:
            # Some exceptions (e.g. timeouts) have an empty message; fall back to the type name.
            detail = str(e) or type(e).__name__
            logger.error(f"Error executing remote MCP tool '{self.name}': {detail}")
            return f"Error executing tool: {detail}"
        finally:
            # Don't wait for a worker that outlived the timeout: a `with` block
            # would call shutdown(wait=True) and block until the call finished,
            # making the timeout passed to future.result() ineffective.
            executor.shutdown(wait=False)

    def _run_in_new_loop(self, kwargs: Dict[str, Any]) -> str:
        """Run the async tool invocation in a new event loop."""
        return asyncio.run(self._execute_remote_tool(kwargs))

    async def _execute_remote_tool(self, kwargs: Dict[str, Any]) -> str:
        """Execute the remote MCP tool call through ``McpSseClient``.

        Args:
            kwargs: tool arguments, forwarded unchanged to ``call_tool``.

        Returns:
            The tool result rendered as text by ``_format_tool_result``.

        Raises:
            Exception: when ``session_id`` is missing. Errors raised by the SSE
                client are re-raised unchanged.
        """
        from ...tools.utils import TOOLKIT_SPLITTER

        # Check for session_id requirement
        if not self.session_id:
            logger.error(f"[MCP Session] Missing session_id for tool '{self.name}'")
            raise Exception("sessionId required. Frontend must generate UUID and send with mcp_tokens.")

        # Use the original tool name from discovery for MCP server invocation;
        # otherwise fall back to the segment after the last toolkit separator.
        tool_name_for_server = self.original_tool_name
        if not tool_name_for_server:
            tool_name_for_server = self.name.rsplit(TOOLKIT_SPLITTER, 1)[-1]
            logger.warning(f"original_tool_name not set for '{self.name}', using extracted: {tool_name_for_server}")

        # NOTE(review): is_prompt / prompt_name are not consulted here, so prompt
        # tools are dispatched through call_tool like regular tools — confirm intended.
        logger.info(f"[MCP SSE] Executing tool '{tool_name_for_server}' with session {self.session_id}")

        try:
            headers = dict(self.server_headers) if self.server_headers else {}
            client = McpSseClient(
                url=self.server_url,
                session_id=self.session_id,
                headers=headers,
                timeout=self.tool_timeout_sec
            )
            result = await client.call_tool(tool_name_for_server, kwargs)
            return self._format_tool_result(result)
        except McpAuthorizationRequired:
            # Expected control flow for OAuth; not worth an error-level traceback.
            raise
        except Exception as e:
            logger.error(f"[MCP SSE] Tool execution failed: {e}", exc_info=True)
            raise

    @staticmethod
    def _format_tool_result(result: Any) -> str:
        """Render an MCP tool result as text.

        A dict whose ``content`` is a list is rendered by joining, with
        newlines, each dict item's ``text`` (or its JSON form when it has no
        ``text``) and ``str()`` of any non-dict item. Any other dict becomes
        indented JSON; non-dict results go through ``str()``.
        """
        if not isinstance(result, dict):
            return str(result)

        content_items = result.get("content")
        if not isinstance(content_items, list):
            return json.dumps(result, indent=2)

        text_parts = []
        for item in content_items:
            if isinstance(item, dict):
                text_parts.append(item["text"] if "text" in item else json.dumps(item))
            else:
                text_parts.append(str(item))
        return "\n".join(text_parts)

    def _parse_sse(self, text: str) -> Dict[str, Any]:
        """Parse the JSON payload of the first ``data:`` line of an SSE response.

        Raises:
            ValueError: if the text contains no ``data:`` line.
        """
        for line in text.split('\n'):
            line = line.strip()
            if line.startswith('data:'):
                json_str = line[5:].strip()
                return json.loads(json_str)
        raise ValueError("No data found in SSE response")

    def get_session_metadata(self) -> dict:
        """Return session metadata to be included in tool responses.

        Empty when no ``session_id`` is set.
        """
        if self.session_id:
            return {
                'mcp_session_id': self.session_id,
                'mcp_server_url': canonical_resource(self.server_url)
            }
        return {}
@@ -15,63 +15,12 @@ class McpServerTool(BaseTool):
15
15
  description: str
16
16
  args_schema: Optional[Type[BaseModel]] = None
17
17
  return_type: str = "str"
18
- client: Any = Field(default=None, exclude=True) # Exclude from serialization
18
+ client: Any
19
19
  server: str
20
20
  tool_timeout_sec: int = 60
21
- is_prompt: bool = False # Flag to indicate if this is a prompt tool
22
- prompt_name: Optional[str] = None # Original prompt name if this is a prompt
23
21
 
24
22
  model_config = ConfigDict(arbitrary_types_allowed=True)
25
23
 
26
- def __getstate__(self):
27
- """Custom serialization to exclude non-serializable objects."""
28
- state = self.__dict__.copy()
29
- # Remove the client since it contains threading objects that can't be pickled
30
- state['client'] = None
31
- # Store args_schema as a schema dict instead of the dynamic class
32
- if hasattr(self, 'args_schema') and self.args_schema is not None:
33
- # Convert the Pydantic model back to schema dict for pickling
34
- try:
35
- state['_args_schema_dict'] = self.args_schema.model_json_schema()
36
- state['args_schema'] = None
37
- except Exception as e:
38
- logger.warning(f"Failed to serialize args_schema: {e}")
39
- # If conversion fails, just remove it
40
- state['args_schema'] = None
41
- state['_args_schema_dict'] = {}
42
- return state
43
-
44
- def __setstate__(self, state):
45
- """Custom deserialization to handle missing objects."""
46
- # Restore the args_schema from the stored schema dict
47
- args_schema_dict = state.pop('_args_schema_dict', {})
48
-
49
- # Initialize required Pydantic internal attributes
50
- if '__pydantic_fields_set__' not in state:
51
- state['__pydantic_fields_set__'] = set(state.keys())
52
- if '__pydantic_extra__' not in state:
53
- state['__pydantic_extra__'] = None
54
- if '__pydantic_private__' not in state:
55
- state['__pydantic_private__'] = None
56
-
57
- # Directly update the object's __dict__ to bypass Pydantic validation
58
- self.__dict__.update(state)
59
-
60
- # Recreate the args_schema from the stored dict if available
61
- if args_schema_dict:
62
- try:
63
- recreated_schema = self.create_pydantic_model_from_schema(args_schema_dict)
64
- self.__dict__['args_schema'] = recreated_schema
65
- except Exception as e:
66
- logger.warning(f"Failed to recreate args_schema: {e}")
67
- self.__dict__['args_schema'] = None
68
- else:
69
- self.__dict__['args_schema'] = None
70
-
71
- # Note: client will be None after unpickling
72
- # The toolkit should reinitialize the client when needed
73
-
74
-
75
24
  @staticmethod
76
25
  def create_pydantic_model_from_schema(schema: dict, model_name: str = "ArgsSchema"):
77
26
  def parse_type(field: dict, name: str = "Field") -> Any:
@@ -143,30 +92,14 @@ class McpServerTool(BaseTool):
143
92
 
144
93
  def _run(self, *args, **kwargs):
145
94
  # Extract the actual tool/prompt name (remove toolkit prefix)
146
- actual_name = self.name.rsplit(TOOLKIT_SPLITTER)[1] if TOOLKIT_SPLITTER in self.name else self.name
147
-
148
- if self.is_prompt:
149
- # For prompts, use prompts/get endpoint
150
- call_data = {
151
- "server": self.server,
152
- "tool_timeout_sec": self.tool_timeout_sec,
153
- "tool_call_id": str(uuid.uuid4()),
154
- "method": "prompts/get",
155
- "params": {
156
- "name": self.prompt_name or actual_name.replace("prompt_", ""),
157
- "arguments": kwargs.get("arguments", kwargs)
158
- }
159
- }
160
- else:
161
- # For regular tools, use tools/call endpoint
162
- call_data = {
163
- "server": self.server,
164
- "tool_timeout_sec": self.tool_timeout_sec,
165
- "tool_call_id": str(uuid.uuid4()),
166
- "params": {
167
- "name": actual_name,
168
- "arguments": kwargs
169
- }
95
+ call_data = {
96
+ "server": self.server,
97
+ "tool_timeout_sec": self.tool_timeout_sec,
98
+ "tool_call_id": str(uuid.uuid4()),
99
+ "params": {
100
+ "name": self.name.rsplit(TOOLKIT_SPLITTER)[1] if TOOLKIT_SPLITTER in self.name else self.name,
101
+ "arguments": kwargs
170
102
  }
103
+ }
171
104
 
172
105
  return self.client.mcp_tool_call(call_data)
@@ -0,0 +1,164 @@
1
+ import json
2
+ import logging
3
+ import re
4
+ from typing import Any, Dict, Optional
5
+ from urllib.parse import urlparse
6
+
7
+ import requests
8
+ from langchain_core.tools import ToolException
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+
13
class McpAuthorizationRequired(ToolException):
    """Raised when an MCP server requires OAuth authorization before use.

    Besides the message, it carries what a caller needs to start the
    authorization flow: the server URL, the protected resource metadata URL
    and document (when known), the raw ``WWW-Authenticate`` header, the HTTP
    status, and the name of the tool that triggered it.
    """

    def __init__(
        self,
        message: str,
        server_url: str,
        resource_metadata_url: Optional[str] = None,
        www_authenticate: Optional[str] = None,
        resource_metadata: Optional[Dict[str, Any]] = None,
        status: Optional[int] = None,
        tool_name: Optional[str] = None,
    ):
        super().__init__(message)
        self.server_url = server_url
        self.resource_metadata_url = resource_metadata_url
        self.www_authenticate = www_authenticate
        self.resource_metadata = resource_metadata
        self.status = status
        self.tool_name = tool_name

    def __reduce__(self):
        """Support pickling and ``copy``/``deepcopy``.

        The default exception reduction re-creates the object as
        ``cls(*self.args)``; ``self.args`` only holds ``message``, so that call
        would fail on the required ``server_url``. Pass every constructor
        argument explicitly and keep ``__dict__`` as state for any extras.
        """
        message = self.args[0] if self.args else ""
        return (
            self.__class__,
            (
                message,
                self.server_url,
                self.resource_metadata_url,
                self.www_authenticate,
                self.resource_metadata,
                self.status,
                self.tool_name,
            ),
            self.__dict__,
        )

    def to_dict(self) -> Dict[str, Any]:
        """Return the exception details as a plain dict (e.g. for API payloads)."""
        return {
            "message": str(self),
            "server_url": self.server_url,
            "resource_metadata_url": self.resource_metadata_url,
            "www_authenticate": self.www_authenticate,
            "resource_metadata": self.resource_metadata,
            "status": self.status,
            "tool_name": self.tool_name,
        }
44
+
45
+
46
def extract_resource_metadata_url(www_authenticate: Optional[str], server_url: Optional[str] = None) -> Optional[str]:
    """
    Pull the ``resource_metadata`` URL from a WWW-Authenticate header if present.

    RFC 9728 servers advertise their protected resource metadata document as a
    ``resource_metadata="<url>"`` auth-param inside the header value. The
    parameter name is matched case-insensitively (auth-param names are
    case-insensitive), and the value may be a quoted string or a bare token.

    Args:
        www_authenticate: raw WWW-Authenticate header value, or None.
        server_url: accepted for backward compatibility; currently unused. No
            metadata URL is constructed from it. Callers are expected to fall
            back to realm inference or well-known discovery instead.

    Returns:
        The metadata URL, or None when the header carries no such parameter.
    """
    if not www_authenticate:
        return None

    # \b prevents matching inside a longer parameter name (e.g. x_resource_metadata).
    match = re.search(
        r'\bresource_metadata\s*=\s*(?:"([^"]+)"|([^\s",]+))',
        www_authenticate,
        re.IGNORECASE,
    )
    if not match:
        return None
    return match.group(1) or match.group(2)
64
+
65
+
66
def fetch_oauth_authorization_server_metadata(base_url: str, timeout: int = 10) -> Optional[Dict[str, Any]]:
    """
    Fetch OAuth authorization server metadata from well-known endpoints.

    Tries ``/.well-known/oauth-authorization-server`` first, then
    ``/.well-known/openid-configuration``, appended to ``base_url``.

    Args:
        base_url: authorization server base URL (a trailing slash is ignored).
        timeout: per-request timeout in seconds.

    Returns:
        The decoded JSON of the first endpoint answering HTTP 200, or None if
        none does. Request/JSON errors are logged at debug level and skipped.
    """
    # Strip trailing slashes so we never request "https://host//.well-known/...".
    root = base_url.rstrip("/")
    for suffix in ("oauth-authorization-server", "openid-configuration"):
        endpoint = f"{root}/.well-known/{suffix}"
        try:
            resp = requests.get(endpoint, timeout=timeout)
            if resp.status_code == 200:
                return resp.json()
        except Exception as exc:
            logger.debug(f"Failed to fetch OAuth metadata from {endpoint}: {exc}")

    return None
86
+
87
+
88
def infer_authorization_servers_from_realm(www_authenticate: Optional[str], server_url: str) -> Optional[list]:
    """
    Infer candidate authorization server base URLs for an MCP server.

    Used when the server doesn't provide a resource_metadata endpoint. The
    current heuristic returns the origin (``scheme://host[:port]``) of
    ``server_url``; the client appends the ``.well-known`` discovery paths
    itself when fetching metadata.

    Args:
        www_authenticate: raw WWW-Authenticate header. Kept for backward
            compatibility; its realm is not used by the current heuristic.
        server_url: URL of the MCP server.

    Returns:
        A one-element list with the server origin, or None when
        ``server_url`` is empty or lacks a scheme or host.
    """
    if not server_url:
        return None

    parsed = urlparse(server_url)
    # Without both parts the "origin" would be a meaningless string like "://".
    if not parsed.scheme or not parsed.netloc:
        return None

    return [f"{parsed.scheme}://{parsed.netloc}"]
114
+
115
+
116
def fetch_resource_metadata(resource_metadata_url: str, timeout: int = 10) -> Optional[Dict[str, Any]]:
    """
    Fetch and decode the protected resource metadata document.

    Every failure (request error, non-2xx status, undecodable JSON) is logged
    as a warning and reported as None, so callers can still surface the
    authorization requirement without the metadata.
    """
    try:
        response = requests.get(resource_metadata_url, timeout=timeout)
        response.raise_for_status()
        document = response.json()
    except Exception as exc:  # broad on purpose: metadata is best-effort
        logger.warning("Failed to fetch resource metadata from %s: %s", resource_metadata_url, exc)
        return None
    return document
125
+
126
+
127
async def fetch_resource_metadata_async(resource_metadata_url: str, session=None, timeout: int = 10) -> Optional[Dict[str, Any]]:
    """
    Async variant for fetching protected resource metadata.

    Args:
        resource_metadata_url: URL of the protected resource metadata document.
        session: optional existing aiohttp client session to reuse. When None,
            a short-lived session is opened and closed here.
        timeout: total request timeout in seconds.

    Returns:
        The decoded JSON document, or None on any failure (request error,
        non-2xx status, invalid JSON). Failures are logged as warnings,
        matching the synchronous ``fetch_resource_metadata``.
    """
    try:
        import aiohttp

        client_timeout = aiohttp.ClientTimeout(total=timeout)
        if session is not None:
            async with session.get(resource_metadata_url, timeout=client_timeout) as resp:
                # Don't treat an error page (4xx/5xx body) as metadata.
                resp.raise_for_status()
                text = await resp.text()
        else:
            async with aiohttp.ClientSession(timeout=client_timeout) as local_session:
                async with local_session.get(resource_metadata_url) as resp:
                    resp.raise_for_status()
                    text = await resp.text()

        try:
            return json.loads(text)
        except json.JSONDecodeError:
            logger.warning("Resource metadata at %s is not valid JSON: %s", resource_metadata_url, text[:200])
            return None
    except Exception as exc:
        logger.warning("Failed to fetch resource metadata from %s: %s", resource_metadata_url, exc)
        return None
149
+
150
+
151
def canonical_resource(server_url: str) -> str:
    """
    Produce a canonical resource identifier for the MCP server.

    Lowercases the scheme and host (userinfo is case-sensitive and kept
    as-is), drops any fragment (a resource indicator must not carry one,
    RFC 8707), and removes the trailing slash of an empty or root path.
    Non-root paths and queries are preserved unchanged.
    """
    parsed = urlparse(server_url)
    # Only the host[:port] part is case-insensitive; keep "user:pass@" as given.
    userinfo, at, hostport = parsed.netloc.rpartition("@")
    normalized = parsed._replace(
        scheme=parsed.scheme.lower(),
        netloc=f"{userinfo}{at}{hostport.lower()}",
        fragment="",
    )
    resource = normalized.geturl()

    # Prefer form without trailing slash unless path is meaningful
    if resource.endswith("/") and parsed.path in ("", "/"):
        resource = resource[:-1]
    return resource