alita-sdk 0.3.423__py3-none-any.whl → 0.3.449__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: the registry flags this version of alita-sdk as possibly problematic.
- alita_sdk/runtime/clients/client.py +45 -9
- alita_sdk/runtime/clients/mcp_discovery.py +342 -0
- alita_sdk/runtime/clients/mcp_manager.py +262 -0
- alita_sdk/runtime/langchain/assistant.py +10 -2
- alita_sdk/runtime/langchain/constants.py +1 -1
- alita_sdk/runtime/langchain/langraph_agent.py +4 -1
- alita_sdk/runtime/models/mcp_models.py +61 -0
- alita_sdk/runtime/toolkits/__init__.py +24 -0
- alita_sdk/runtime/toolkits/mcp.py +892 -0
- alita_sdk/runtime/toolkits/tools.py +61 -3
- alita_sdk/runtime/tools/mcp_inspect_tool.py +284 -0
- alita_sdk/runtime/tools/mcp_remote_tool.py +166 -0
- alita_sdk/runtime/tools/mcp_server_tool.py +3 -1
- alita_sdk/runtime/utils/mcp_oauth.py +164 -0
- alita_sdk/runtime/utils/mcp_sse_client.py +347 -0
- alita_sdk/runtime/utils/streamlit.py +34 -3
- alita_sdk/runtime/utils/toolkit_utils.py +14 -4
- alita_sdk/tools/__init__.py +5 -0
- alita_sdk/tools/chunkers/sematic/proposal_chunker.py +1 -1
- alita_sdk/tools/gitlab/api_wrapper.py +5 -0
- alita_sdk/tools/qtest/api_wrapper.py +240 -39
- {alita_sdk-0.3.423.dist-info → alita_sdk-0.3.449.dist-info}/METADATA +2 -1
- {alita_sdk-0.3.423.dist-info → alita_sdk-0.3.449.dist-info}/RECORD +26 -18
- {alita_sdk-0.3.423.dist-info → alita_sdk-0.3.449.dist-info}/WHEEL +0 -0
- {alita_sdk-0.3.423.dist-info → alita_sdk-0.3.449.dist-info}/licenses/LICENSE +0 -0
- {alita_sdk-0.3.423.dist-info → alita_sdk-0.3.449.dist-info}/top_level.txt +0 -0
alita_sdk/runtime/utils/mcp_oauth.py
ADDED

@@ -0,0 +1,164 @@
+import json
+import logging
+import re
+from typing import Any, Dict, Optional
+from urllib.parse import urlparse
+
+import requests
+from langchain_core.tools import ToolException
+
+logger = logging.getLogger(__name__)
+
+
+class McpAuthorizationRequired(ToolException):
+    """Raised when an MCP server requires OAuth authorization before use."""
+
+    def __init__(
+        self,
+        message: str,
+        server_url: str,
+        resource_metadata_url: Optional[str] = None,
+        www_authenticate: Optional[str] = None,
+        resource_metadata: Optional[Dict[str, Any]] = None,
+        status: Optional[int] = None,
+        tool_name: Optional[str] = None,
+    ):
+        super().__init__(message)
+        self.server_url = server_url
+        self.resource_metadata_url = resource_metadata_url
+        self.www_authenticate = www_authenticate
+        self.resource_metadata = resource_metadata
+        self.status = status
+        self.tool_name = tool_name
+
+    def to_dict(self) -> Dict[str, Any]:
+        return {
+            "message": str(self),
+            "server_url": self.server_url,
+            "resource_metadata_url": self.resource_metadata_url,
+            "www_authenticate": self.www_authenticate,
+            "resource_metadata": self.resource_metadata,
+            "status": self.status,
+            "tool_name": self.tool_name,
+        }
+
+
+def extract_resource_metadata_url(www_authenticate: Optional[str], server_url: Optional[str] = None) -> Optional[str]:
+    """
+    Pull the resource_metadata URL from a WWW-Authenticate header if present.
+    If not found and server_url is provided, try to construct resource metadata URLs.
+    """
+    if not www_authenticate and not server_url:
+        return None
+
+    # RFC9728 returns `resource_metadata="<url>"` inside the header value
+    if www_authenticate:
+        match = re.search(r'resource_metadata\s*=\s*\"?([^\", ]+)\"?', www_authenticate)
+        if match:
+            return match.group(1)
+
+    # For servers that don't provide resource_metadata in WWW-Authenticate,
+    # we'll return None and rely on inferring authorization servers from the realm
+    # or using well-known OAuth discovery endpoints directly
+    return None
+
+
+def fetch_oauth_authorization_server_metadata(base_url: str, timeout: int = 10) -> Optional[Dict[str, Any]]:
+    """
+    Fetch OAuth authorization server metadata from well-known endpoints.
+    Tries both oauth-authorization-server and openid-configuration discovery endpoints.
+    """
+    discovery_endpoints = [
+        f"{base_url}/.well-known/oauth-authorization-server",
+        f"{base_url}/.well-known/openid-configuration",
+    ]
+
+    for endpoint in discovery_endpoints:
+        try:
+            resp = requests.get(endpoint, timeout=timeout)
+            if resp.status_code == 200:
+                return resp.json()
+        except Exception as exc:
+            logger.debug(f"Failed to fetch OAuth metadata from {endpoint}: {exc}")
+            continue
+
+    return None
+
+
+def infer_authorization_servers_from_realm(www_authenticate: Optional[str], server_url: str) -> Optional[list]:
+    """
+    Infer authorization server URLs from WWW-Authenticate realm or server URL.
+    This is used when the server doesn't provide resource_metadata endpoint.
+    """
+    if not www_authenticate and not server_url:
+        return None
+
+    authorization_servers = []
+
+    # Try to extract realm from WWW-Authenticate header
+    realm = None
+    if www_authenticate:
+        realm_match = re.search(r'realm\s*=\s*\"([^\"]+)\"', www_authenticate)
+        if realm_match:
+            realm = realm_match.group(1)
+
+    # Parse the server URL to get base domain
+    parsed = urlparse(server_url)
+    base_url = f"{parsed.scheme}://{parsed.netloc}"
+
+    # Return the base authorization server URL (not the discovery endpoint)
+    # The client will append .well-known paths when fetching metadata
+    authorization_servers.append(base_url)
+
+    return authorization_servers if authorization_servers else None
+
+
+def fetch_resource_metadata(resource_metadata_url: str, timeout: int = 10) -> Optional[Dict[str, Any]]:
+    """Fetch and parse the protected resource metadata document."""
+    try:
+        resp = requests.get(resource_metadata_url, timeout=timeout)
+        resp.raise_for_status()
+        return resp.json()
+    except Exception as exc:  # broad catch – we want to surface auth requirement even if this fails
+        logger.warning("Failed to fetch resource metadata from %s: %s", resource_metadata_url, exc)
+        return None
+
+
+async def fetch_resource_metadata_async(resource_metadata_url: str, session=None, timeout: int = 10) -> Optional[Dict[str, Any]]:
+    """Async variant for fetching protected resource metadata."""
+    try:
+        import aiohttp
+
+        client_timeout = aiohttp.ClientTimeout(total=timeout)
+        if session:
+            async with session.get(resource_metadata_url, timeout=client_timeout) as resp:
+                text = await resp.text()
+        else:
+            async with aiohttp.ClientSession(timeout=client_timeout) as local_session:
+                async with local_session.get(resource_metadata_url) as resp:
+                    text = await resp.text()
+
+        try:
+            return json.loads(text)
+        except json.JSONDecodeError:
+            logger.warning("Resource metadata at %s is not valid JSON: %s", resource_metadata_url, text[:200])
+            return None
+    except Exception as exc:
+        logger.warning("Failed to fetch resource metadata from %s: %s", resource_metadata_url, exc)
+        return None
+
+
+def canonical_resource(server_url: str) -> str:
+    """Produce a canonical resource identifier for the MCP server."""
+    parsed = urlparse(server_url)
+    # Normalize scheme/host casing per RFC guidance
+    normalized = parsed._replace(
+        scheme=parsed.scheme.lower(),
+        netloc=parsed.netloc.lower(),
+    )
+    resource = normalized.geturl()
+
+    # Prefer form without trailing slash unless path is meaningful
+    if resource.endswith("/") and parsed.path in ("", "/"):
+        resource = resource[:-1]
+    return resource
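Taken together, these helpers cover the discovery chain an MCP client walks after a 401: read `WWW-Authenticate`, fetch the RFC 9728 resource metadata if advertised, otherwise fall back to well-known OAuth discovery on the inferred authorization server, and surface everything through `McpAuthorizationRequired`. A minimal sketch of that composition; the `handle_unauthorized` wrapper and its `resp`/`server_url`/`tool_name` arguments are illustrative, not part of the SDK:

```python
# Hypothetical 401 handler composed from the helpers above.
# `resp` is assumed to be a requests.Response-like object.
from alita_sdk.runtime.utils.mcp_oauth import (
    McpAuthorizationRequired,
    canonical_resource,
    extract_resource_metadata_url,
    fetch_oauth_authorization_server_metadata,
    fetch_resource_metadata,
    infer_authorization_servers_from_realm,
)

def handle_unauthorized(resp, server_url: str, tool_name: str):
    """Turn an HTTP 401/403 from an MCP server into McpAuthorizationRequired."""
    www_auth = resp.headers.get("WWW-Authenticate")
    metadata_url = extract_resource_metadata_url(www_auth, server_url)
    metadata = fetch_resource_metadata(metadata_url) if metadata_url else None

    if metadata is None:
        # No RFC 9728 metadata advertised: fall back to well-known OAuth discovery
        # on the authorization servers inferred from the realm / server host.
        for auth_server in infer_authorization_servers_from_realm(www_auth, server_url) or []:
            metadata = fetch_oauth_authorization_server_metadata(auth_server)
            if metadata:
                break

    raise McpAuthorizationRequired(
        f"MCP server {canonical_resource(server_url)} requires OAuth authorization",
        server_url=server_url,
        resource_metadata_url=metadata_url,
        www_authenticate=www_auth,
        resource_metadata=metadata,
        status=resp.status_code,
        tool_name=tool_name,
    )
```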
alita_sdk/runtime/utils/mcp_sse_client.py
ADDED

@@ -0,0 +1,347 @@
+"""
+MCP SSE (Server-Sent Events) Client
+Handles persistent SSE connections for MCP servers like Atlassian
+"""
+import asyncio
+import json
+import logging
+from typing import Dict, Any, Optional, AsyncIterator
+import aiohttp
+
+logger = logging.getLogger(__name__)
+
+
+class McpSseClient:
+    """
+    Client for MCP servers using SSE (Server-Sent Events) transport.
+
+    For Atlassian-style SSE (dual-connection model):
+    - GET request opens persistent SSE stream for receiving events
+    - POST requests send commands (return 202 Accepted immediately)
+    - Responses come via the GET stream
+
+    This client handles:
+    - Opening persistent SSE connection via GET
+    - Sending JSON-RPC requests via POST
+    - Reading SSE event streams
+    - Matching responses to requests by ID
+    """
+
+    def __init__(self, url: str, session_id: str, headers: Optional[Dict[str, str]] = None, timeout: int = 300):
+        """
+        Initialize SSE client.
+
+        Args:
+            url: Base URL of the MCP SSE server
+            session_id: Client-generated UUID for session
+            headers: Additional headers (e.g., Authorization)
+            timeout: Request timeout in seconds
+        """
+        self.url = url
+        self.session_id = session_id
+        self.headers = headers or {}
+        self.timeout = timeout
+        self.url_with_session = f"{url}?sessionId={session_id}"
+        self._stream_task = None
+        self._pending_requests = {}  # request_id -> asyncio.Future
+        self._stream_session = None
+        self._stream_response = None
+        self._endpoint_ready = asyncio.Event()  # Signal when endpoint is received
+
+        logger.info(f"[MCP SSE Client] Initialized for {url} with session {session_id}")
+
+    async def _ensure_stream_connected(self):
+        """Ensure the GET stream is connected and reading events."""
+        if self._stream_task is None or self._stream_task.done():
+            logger.info(f"[MCP SSE Client] Opening persistent SSE stream...")
+            self._stream_session = aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=None))
+
+            headers = {
+                "Accept": "text/event-stream",
+                **self.headers
+            }
+
+            self._stream_response = await self._stream_session.get(self.url_with_session, headers=headers)
+
+            logger.info(f"[MCP SSE Client] Stream opened: status={self._stream_response.status}")
+
+            if self._stream_response.status != 200:
+                error_text = await self._stream_response.text()
+                raise Exception(f"Failed to open SSE stream: HTTP {self._stream_response.status}: {error_text}")
+
+            # Start background task to read stream
+            self._stream_task = asyncio.create_task(self._read_stream())
+
+    async def _read_stream(self):
+        """Background task that continuously reads the SSE stream."""
+        logger.info(f"[MCP SSE Client] Starting stream reader...")
+
+        try:
+            buffer = ""
+            current_event = {}
+
+            async for chunk in self._stream_response.content.iter_chunked(1024):
+                chunk_str = chunk.decode('utf-8')
+                buffer += chunk_str
+
+                # Process complete lines
+                while '\n' in buffer:
+                    line, buffer = buffer.split('\n', 1)
+                    line_str = line.strip()
+
+                    # Empty line indicates end of event
+                    if not line_str:
+                        if current_event and 'data' in current_event:
+                            self._process_event(current_event)
+                        current_event = {}
+                        continue
+
+                    # Parse SSE fields
+                    if line_str.startswith('event:'):
+                        current_event['event'] = line_str[6:].strip()
+                    elif line_str.startswith('data:'):
+                        data_str = line_str[5:].strip()
+                        current_event['data'] = data_str
+                    elif line_str.startswith('id:'):
+                        current_event['id'] = line_str[3:].strip()
+
+        except Exception as e:
+            logger.error(f"[MCP SSE Client] Stream reader error: {e}")
+            # Fail all pending requests
+            for future in self._pending_requests.values():
+                if not future.done():
+                    future.set_exception(e)
+        finally:
+            logger.info(f"[MCP SSE Client] Stream reader stopped")
+
+    def _process_event(self, event: Dict[str, str]):
+        """Process a complete SSE event."""
+        event_type = event.get('event', 'message')
+        data_str = event.get('data', '')
+
+        # Handle 'endpoint' event - server provides the actual session URL to use
+        if event_type == 'endpoint':
+            # Extract session ID from endpoint URL
+            # Format: /v1/sse?sessionId=<uuid>
+            if 'sessionId=' in data_str:
+                new_session_id = data_str.split('sessionId=')[1].split('&')[0]
+                logger.info(f"[MCP SSE Client] Server provided session ID: {new_session_id}")
+                self.session_id = new_session_id
+                self.url_with_session = f"{self.url}?sessionId={new_session_id}"
+                self._endpoint_ready.set()  # Signal that we can now send requests
+            return
+
+        # Skip other non-message events
+        if event_type != 'message' and not data_str.startswith('{'):
+            return
+
+        if not data_str:
+            return
+
+        try:
+            data = json.loads(data_str)
+            request_id = data.get('id')
+
+            logger.debug(f"[MCP SSE Client] Received response for request {request_id}")
+
+            # Resolve pending request
+            if request_id and request_id in self._pending_requests:
+                future = self._pending_requests.pop(request_id)
+                if not future.done():
+                    future.set_result(data)
+
+        except json.JSONDecodeError as e:
+            logger.warning(f"[MCP SSE Client] Failed to parse SSE data: {e}, data: {repr(data_str)[:200]}")
+
+        except Exception as e:
+            logger.error(f"[MCP SSE Client] Stream reader error: {e}")
+            # Fail all pending requests
+            for future in self._pending_requests.values():
+                if not future.done():
+                    future.set_exception(e)
+        finally:
+            logger.info(f"[MCP SSE Client] Stream reader stopped")
+
+    async def send_request(self, method: str, params: Optional[Dict[str, Any]] = None, request_id: Optional[str] = None) -> Dict[str, Any]:
+        """
+        Send a JSON-RPC request and wait for response via SSE stream.
+
+        Uses dual-connection model:
+        1. GET stream is kept open to receive responses
+        2. POST request sends the command (returns 202 immediately)
+        3. Response comes via the GET stream
+
+        Args:
+            method: JSON-RPC method name (e.g., "tools/list", "tools/call")
+            params: Method parameters
+            request_id: Optional request ID (auto-generated if not provided)
+
+        Returns:
+            Parsed JSON-RPC response
+
+        Raises:
+            Exception: If request fails or times out
+        """
+        import time
+        if request_id is None:
+            request_id = f"{method.replace('/', '_')}_{int(time.time() * 1000)}"
+
+        request = {
+            "jsonrpc": "2.0",
+            "id": request_id,
+            "method": method,
+            "params": params or {}
+        }
+
+        logger.debug(f"[MCP SSE Client] Sending request: {method} (id={request_id})")
+
+        # Ensure stream is connected
+        await self._ensure_stream_connected()
+
+        # Wait for endpoint event (server provides the actual session ID to use)
+        await asyncio.wait_for(self._endpoint_ready.wait(), timeout=10)
+
+        # Create future for this request
+        future = asyncio.Future()
+        self._pending_requests[request_id] = future
+
+        # Send POST request
+        headers = {
+            "Content-Type": "application/json",
+            **self.headers
+        }
+
+        timeout = aiohttp.ClientTimeout(total=30)
+
+        try:
+            async with aiohttp.ClientSession(timeout=timeout) as session:
+                async with session.post(self.url_with_session, json=request, headers=headers) as response:
+                    if response.status == 404:
+                        error_text = await response.text()
+                        raise Exception(f"HTTP 404: {error_text}")
+
+                    # 202 is expected - response will come via stream
+                    if response.status not in [200, 202]:
+                        error_text = await response.text()
+                        raise Exception(f"HTTP {response.status}: {error_text}")
+
+            # Wait for response from stream (with timeout)
+            result = await asyncio.wait_for(future, timeout=self.timeout)
+
+            # Check for JSON-RPC error
+            if 'error' in result:
+                error = result['error']
+                raise Exception(f"MCP Error: {error.get('message', str(error))}")
+
+            return result
+
+        except asyncio.TimeoutError:
+            self._pending_requests.pop(request_id, None)
+            logger.error(f"[MCP SSE Client] Request timeout after {self.timeout}s")
+            raise Exception(f"SSE request timeout after {self.timeout}s")
+        except Exception as e:
+            self._pending_requests.pop(request_id, None)
+            logger.error(f"[MCP SSE Client] Request failed: {e}")
+            raise
+
+    async def close(self):
+        """Close the persistent SSE stream."""
+        logger.info(f"[MCP SSE Client] Closing connection...")
+
+        if self._stream_task and not self._stream_task.done():
+            self._stream_task.cancel()
+            try:
+                await self._stream_task
+            except asyncio.CancelledError:
+                pass
+
+        if self._stream_response:
+            self._stream_response.close()
+
+        if self._stream_session:
+            await self._stream_session.close()
+
+        logger.info(f"[MCP SSE Client] Connection closed")
+
+    async def __aenter__(self):
+        """Async context manager entry."""
+        return self
+
+    async def __aexit__(self, exc_type, exc_val, exc_tb):
+        """Async context manager exit."""
+        await self.close()
+
+    async def initialize(self) -> Dict[str, Any]:
+        """
+        Send initialize request to establish MCP protocol session.
+
+        Returns:
+            Server capabilities and info
+        """
+        response = await self.send_request(
+            method="initialize",
+            params={
+                "protocolVersion": "2024-11-05",
+                "capabilities": {
+                    "roots": {"listChanged": True},
+                    "sampling": {}
+                },
+                "clientInfo": {
+                    "name": "Alita MCP Client",
+                    "version": "1.0.0"
+                }
+            }
+        )
+
+        logger.info(f"[MCP SSE Client] MCP session initialized")
+        return response.get('result', {})
+
+    async def list_tools(self) -> list:
+        """
+        Discover available tools from the MCP server.
+
+        Returns:
+            List of tool definitions
+        """
+        response = await self.send_request(method="tools/list")
+        result = response.get('result', {})
+        tools = result.get('tools', [])
+
+        logger.info(f"[MCP SSE Client] Discovered {len(tools)} tools")
+        return tools
+
+    async def list_prompts(self) -> list:
+        """
+        Discover available prompts from the MCP server.
+
+        Returns:
+            List of prompt definitions
+        """
+        response = await self.send_request(method="prompts/list")
+        result = response.get('result', {})
+        prompts = result.get('prompts', [])
+
+        logger.debug(f"[MCP SSE Client] Discovered {len(prompts)} prompts")
+        return prompts
+
+    async def call_tool(self, tool_name: str, arguments: Dict[str, Any]) -> Any:
+        """
+        Execute a tool on the MCP server.
+
+        Args:
+            tool_name: Name of the tool to call
+            arguments: Tool arguments
+
+        Returns:
+            Tool execution result
+        """
+        response = await self.send_request(
+            method="tools/call",
+            params={
+                "name": tool_name,
+                "arguments": arguments
+            }
+        )
+
+        result = response.get('result', {})
+        return result
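For orientation, a hedged usage sketch of `McpSseClient` built only from the methods above; the endpoint URL, bearer token, and tool name are placeholders, and the import path is inferred from the wheel's file layout:

```python
# Illustrative usage of the SSE client; replace the URL, token, and tool name with real values.
import asyncio
import uuid

from alita_sdk.runtime.utils.mcp_sse_client import McpSseClient

async def main():
    async with McpSseClient(
        url="https://example.atlassian.com/v1/sse",       # placeholder SSE endpoint
        session_id=str(uuid.uuid4()),                      # replaced once the server emits its 'endpoint' event
        headers={"Authorization": "Bearer <token>"},       # placeholder credentials
    ) as client:
        await client.initialize()                          # MCP handshake (protocolVersion 2024-11-05)
        tools = await client.list_tools()                  # responses arrive on the persistent GET stream
        print([t["name"] for t in tools])
        result = await client.call_tool("example_tool", {"query": "hello"})
        print(result)

asyncio.run(main())
```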
alita_sdk/runtime/utils/streamlit.py
CHANGED

@@ -868,10 +868,24 @@ def run_streamlit(st, ai_icon=None, user_icon=None):
             label = f"{'🔒 ' if is_secret else ''}{'*' if is_required else ''}{field_name.replace('_', ' ').title()}"
 
             if field_type == 'string':
-                if is_secret:
+                # Check if this is an enum field
+                if field_schema.get('enum'):
+                    # Dropdown for enum values
+                    options = field_schema['enum']
+                    default_index = 0
+                    if default_value and str(default_value) in options:
+                        default_index = options.index(str(default_value))
+                    toolkit_config_values[field_name] = st.selectbox(
+                        label,
+                        options=options,
+                        index=default_index,
+                        help=field_description,
+                        key=f"config_{field_name}_{selected_toolkit_idx}"
+                    )
+                elif is_secret:
                     toolkit_config_values[field_name] = st.text_input(
                         label,
-                        value=str(default_value) if default_value else '',
+                        value=str(default_value) if default_value else '',
                         help=field_description,
                         type="password",
                         key=f"config_{field_name}_{selected_toolkit_idx}"
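With this branch, any string field whose schema carries an `enum` is rendered as a dropdown rather than a free-text input. An illustrative (made-up) field schema and the same default-index resolution as the diff:

```python
# Hypothetical toolkit field schema; a string field with an "enum" list now becomes st.selectbox.
field_schema = {
    "type": "string",
    "enum": ["sse", "streamable-http", "stdio"],   # illustrative options
    "default": "sse",
    "description": "Transport used to reach the MCP server",
}

options = field_schema["enum"]
default_value = field_schema.get("default")
# Same fallback as the diff: preselect the default when it is one of the options, else the first entry.
default_index = options.index(str(default_value)) if default_value and str(default_value) in options else 0
assert default_index == 0
```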
@@ -879,7 +893,7 @@ def run_streamlit(st, ai_icon=None, user_icon=None):
                 else:
                     toolkit_config_values[field_name] = st.text_input(
                         label,
-                        value=str(default_value) if default_value else '',
+                        value=str(default_value) if default_value else '',
                         help=field_description,
                         key=f"config_{field_name}_{selected_toolkit_idx}"
                     )
@@ -971,6 +985,23 @@ def run_streamlit(st, ai_icon=None, user_icon=None):
                     key=f"config_{field_name}_{selected_toolkit_idx}"
                 )
                 toolkit_config_values[field_name] = [line.strip() for line in array_input.split('\n') if line.strip()]
+            elif field_type == 'object':
+                # Handle object/dict types (like headers)
+                obj_input = st.text_area(
+                    f"{label} (JSON object)",
+                    value=json.dumps(default_value) if isinstance(default_value, dict) else str(default_value) if default_value else '',
+                    help=f"{field_description} - Enter as JSON object, e.g. {{\"Authorization\": \"Bearer token\"}}",
+                    placeholder='{"key": "value"}',
+                    key=f"config_{field_name}_{selected_toolkit_idx}"
+                )
+                try:
+                    if obj_input.strip():
+                        toolkit_config_values[field_name] = json.loads(obj_input)
+                    else:
+                        toolkit_config_values[field_name] = None
+                except json.JSONDecodeError as e:
+                    st.error(f"Invalid JSON format for {field_name}: {e}")
+                    toolkit_config_values[field_name] = None
         else:
             st.info("This toolkit doesn't require additional configuration.")
 
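The new `object` branch lets dict-valued settings (for example MCP request headers) be entered as raw JSON. A tiny sketch of the round trip the widget performs; the header values are placeholders:

```python
import json

# What a user might paste into the new "(JSON object)" text area for a headers field.
obj_input = '{"Authorization": "Bearer <token>", "X-Custom": "value"}'

try:
    headers = json.loads(obj_input) if obj_input.strip() else None
except json.JSONDecodeError:
    headers = None   # the UI surfaces the problem via st.error instead
print(headers)
```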
alita_sdk/runtime/utils/toolkit_utils.py
CHANGED

@@ -29,13 +29,14 @@ def instantiate_toolkit_with_client(toolkit_config: Dict[str, Any],
 
     Raises:
         ValueError: If required configuration or client is missing
+        McpAuthorizationRequired: If MCP server requires OAuth authorization
         Exception: If toolkit instantiation fails
     """
+    toolkit_name = toolkit_config.get('toolkit_name', 'unknown')
    try:
        from ..toolkits.tools import get_tools
 
-        toolkit_name = toolkit_config.get('toolkit_name')
-        if not toolkit_name:
+        if not toolkit_name or toolkit_name == 'unknown':
            raise ValueError("toolkit_name is required in configuration")
 
        if not llm_client:
@@ -46,11 +47,14 @@ def instantiate_toolkit_with_client(toolkit_config: Dict[str, Any],
         # Log the configuration being used
         logger.info(f"Instantiating toolkit {toolkit_name} with LLM client")
         logger.debug(f"Toolkit {toolkit_name} configuration: {toolkit_config}")
-
+
+        # Use toolkit type from config, or fall back to lowercase toolkit name
+        toolkit_type = toolkit_config.get('type', toolkit_name.lower())
+
         # Create a tool configuration dict with required fields
         tool_config = {
             'id': toolkit_config.get('id', random.randint(1, 1000000)),
-            'type': toolkit_config.get('type',
+            'type': toolkit_config.get('type', toolkit_type),
             'settings': settings,
             'toolkit_name': toolkit_name
         }
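The added fallback means a config that omits `type` is keyed by its lowercased toolkit name. A one-line illustration with made-up values:

```python
# Illustrative config with no explicit "type": the lowercased toolkit name is used instead.
toolkit_config = {"toolkit_name": "GitHub"}
toolkit_name = toolkit_config.get("toolkit_name", "unknown")
toolkit_type = toolkit_config.get("type", toolkit_name.lower())
assert toolkit_type == "github"
```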
@@ -67,6 +71,12 @@ def instantiate_toolkit_with_client(toolkit_config: Dict[str, Any],
         return tools
 
     except Exception as e:
+        # Re-raise McpAuthorizationRequired without logging as error
+        from ..utils.mcp_oauth import McpAuthorizationRequired
+        if isinstance(e, McpAuthorizationRequired):
+            logger.info(f"Toolkit {toolkit_name} requires MCP OAuth authorization")
+            raise
+        # Log and re-raise other errors
         logger.error(f"Error instantiating toolkit {toolkit_name} with client: {str(e)}")
         raise
 
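With this change, callers can treat `McpAuthorizationRequired` as a prompt to start an OAuth flow rather than as a failure. A hedged sketch; the call signature is abbreviated here and `start_oauth_flow` is a hypothetical callback, so check `toolkit_utils.py` for the real parameter list:

```python
# Sketch only: the arguments to instantiate_toolkit_with_client are abbreviated;
# consult the function signature in toolkit_utils.py for the actual parameters.
from alita_sdk.runtime.utils.mcp_oauth import McpAuthorizationRequired
from alita_sdk.runtime.utils.toolkit_utils import instantiate_toolkit_with_client

def start_oauth_flow(payload: dict) -> None:
    # Hypothetical callback: a real app would open the authorization URL for the user.
    print("OAuth required:", payload)

def load_toolkit(toolkit_config, settings, llm_client):
    try:
        return instantiate_toolkit_with_client(toolkit_config, settings, llm_client)
    except McpAuthorizationRequired as auth:
        # Not an error: hand the structured payload to the UI so the user can authorize the MCP server.
        start_oauth_flow(auth.to_dict())
        return []
```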
alita_sdk/tools/__init__.py
CHANGED

@@ -131,6 +131,11 @@ def get_tools(tools_list, alita, llm, store: Optional[BaseStore] = None, *args,
                 logger.error(f"Error getting ADO repos tools: {e}")
                 continue
 
+        # Skip MCP toolkit - it's handled by runtime/toolkits/tools.py to avoid duplicate loading
+        if tool_type == 'mcp':
+            logger.debug(f"Skipping MCP toolkit '{tool.get('toolkit_name')}' - handled by runtime toolkit system")
+            continue
+
         # Handle standard tools
         if tool_type in AVAILABLE_TOOLS and 'get_tools' in AVAILABLE_TOOLS[tool_type]:
             try:
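For clarity, this is the kind of `tools_list` entry the community `get_tools` now skips; the MCP entry is loaded by the runtime toolkit system instead. Names and URL are placeholders:

```python
# Illustrative tools_list: the 'mcp' entry is ignored by alita_sdk.tools.get_tools
# and picked up by runtime/toolkits/tools.py; the 'gitlab' entry is handled as before.
tools_list = [
    {"type": "mcp", "toolkit_name": "my_mcp_server",        # placeholder name
     "settings": {"url": "https://example.com/v1/sse"}},    # placeholder URL
    {"type": "gitlab", "toolkit_name": "my_gitlab", "settings": {}},
]
```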
alita_sdk/tools/chunkers/sematic/proposal_chunker.py
CHANGED

@@ -6,7 +6,7 @@ from langchain_core.prompts import ChatPromptTemplate
 from langchain.text_splitter import TokenTextSplitter
 
 from typing import Optional, List
-from
+from pydantic import BaseModel
 from ..utils import tiktoken_length
 
 logger = getLogger(__name__)
alita_sdk/tools/gitlab/api_wrapper.py
CHANGED

@@ -115,6 +115,11 @@ class GitLabAPIWrapper(CodeIndexerToolkit):
         """Remove trailing slash from URL if present."""
         return url.rstrip('/') if url else url
 
+    @model_validator(mode='before')
+    @classmethod
+    def validate_toolkit_before(cls, values: Dict) -> Dict:
+        return super().validate_toolkit(values)
+
     @model_validator(mode='after')
     def validate_toolkit(self):
         try: