alita-sdk 0.3.379__py3-none-any.whl → 0.3.462__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This release has been flagged as potentially problematic.

Files changed (110):
  1. alita_sdk/cli/__init__.py +10 -0
  2. alita_sdk/cli/__main__.py +17 -0
  3. alita_sdk/cli/agent_executor.py +144 -0
  4. alita_sdk/cli/agent_loader.py +197 -0
  5. alita_sdk/cli/agent_ui.py +166 -0
  6. alita_sdk/cli/agents.py +1069 -0
  7. alita_sdk/cli/callbacks.py +576 -0
  8. alita_sdk/cli/cli.py +159 -0
  9. alita_sdk/cli/config.py +153 -0
  10. alita_sdk/cli/formatting.py +182 -0
  11. alita_sdk/cli/mcp_loader.py +315 -0
  12. alita_sdk/cli/toolkit.py +330 -0
  13. alita_sdk/cli/toolkit_loader.py +55 -0
  14. alita_sdk/cli/tools/__init__.py +9 -0
  15. alita_sdk/cli/tools/filesystem.py +905 -0
  16. alita_sdk/configurations/bitbucket.py +95 -0
  17. alita_sdk/configurations/confluence.py +96 -1
  18. alita_sdk/configurations/gitlab.py +79 -0
  19. alita_sdk/configurations/jira.py +103 -0
  20. alita_sdk/configurations/testrail.py +88 -0
  21. alita_sdk/configurations/xray.py +93 -0
  22. alita_sdk/configurations/zephyr_enterprise.py +93 -0
  23. alita_sdk/configurations/zephyr_essential.py +75 -0
  24. alita_sdk/runtime/clients/client.py +47 -10
  25. alita_sdk/runtime/clients/mcp_discovery.py +342 -0
  26. alita_sdk/runtime/clients/mcp_manager.py +262 -0
  27. alita_sdk/runtime/clients/sandbox_client.py +8 -0
  28. alita_sdk/runtime/langchain/assistant.py +37 -16
  29. alita_sdk/runtime/langchain/constants.py +6 -1
  30. alita_sdk/runtime/langchain/document_loaders/AlitaDocxMammothLoader.py +315 -3
  31. alita_sdk/runtime/langchain/document_loaders/AlitaJSONLoader.py +4 -1
  32. alita_sdk/runtime/langchain/document_loaders/constants.py +28 -12
  33. alita_sdk/runtime/langchain/langraph_agent.py +146 -31
  34. alita_sdk/runtime/langchain/utils.py +39 -7
  35. alita_sdk/runtime/models/mcp_models.py +61 -0
  36. alita_sdk/runtime/toolkits/__init__.py +24 -0
  37. alita_sdk/runtime/toolkits/application.py +8 -1
  38. alita_sdk/runtime/toolkits/artifact.py +5 -6
  39. alita_sdk/runtime/toolkits/mcp.py +895 -0
  40. alita_sdk/runtime/toolkits/tools.py +137 -56
  41. alita_sdk/runtime/tools/__init__.py +7 -2
  42. alita_sdk/runtime/tools/application.py +7 -0
  43. alita_sdk/runtime/tools/function.py +29 -25
  44. alita_sdk/runtime/tools/graph.py +10 -4
  45. alita_sdk/runtime/tools/image_generation.py +104 -8
  46. alita_sdk/runtime/tools/llm.py +204 -114
  47. alita_sdk/runtime/tools/mcp_inspect_tool.py +284 -0
  48. alita_sdk/runtime/tools/mcp_remote_tool.py +166 -0
  49. alita_sdk/runtime/tools/mcp_server_tool.py +3 -1
  50. alita_sdk/runtime/tools/sandbox.py +57 -43
  51. alita_sdk/runtime/tools/vectorstore.py +2 -1
  52. alita_sdk/runtime/tools/vectorstore_base.py +19 -3
  53. alita_sdk/runtime/utils/mcp_oauth.py +164 -0
  54. alita_sdk/runtime/utils/mcp_sse_client.py +405 -0
  55. alita_sdk/runtime/utils/streamlit.py +34 -3
  56. alita_sdk/runtime/utils/toolkit_utils.py +14 -4
  57. alita_sdk/tools/__init__.py +46 -31
  58. alita_sdk/tools/ado/repos/__init__.py +1 -0
  59. alita_sdk/tools/ado/test_plan/__init__.py +1 -1
  60. alita_sdk/tools/ado/wiki/__init__.py +1 -5
  61. alita_sdk/tools/ado/work_item/__init__.py +1 -5
  62. alita_sdk/tools/ado/work_item/ado_wrapper.py +17 -8
  63. alita_sdk/tools/base_indexer_toolkit.py +105 -43
  64. alita_sdk/tools/bitbucket/__init__.py +1 -0
  65. alita_sdk/tools/chunkers/sematic/proposal_chunker.py +1 -1
  66. alita_sdk/tools/code/sonar/__init__.py +1 -1
  67. alita_sdk/tools/code_indexer_toolkit.py +13 -3
  68. alita_sdk/tools/confluence/__init__.py +2 -2
  69. alita_sdk/tools/confluence/api_wrapper.py +29 -7
  70. alita_sdk/tools/confluence/loader.py +10 -0
  71. alita_sdk/tools/github/__init__.py +2 -2
  72. alita_sdk/tools/gitlab/__init__.py +2 -1
  73. alita_sdk/tools/gitlab/api_wrapper.py +11 -7
  74. alita_sdk/tools/gitlab_org/__init__.py +1 -2
  75. alita_sdk/tools/google_places/__init__.py +2 -1
  76. alita_sdk/tools/jira/__init__.py +1 -0
  77. alita_sdk/tools/jira/api_wrapper.py +1 -1
  78. alita_sdk/tools/memory/__init__.py +1 -1
  79. alita_sdk/tools/openapi/__init__.py +10 -1
  80. alita_sdk/tools/pandas/__init__.py +1 -1
  81. alita_sdk/tools/postman/__init__.py +2 -1
  82. alita_sdk/tools/pptx/__init__.py +2 -2
  83. alita_sdk/tools/qtest/__init__.py +3 -3
  84. alita_sdk/tools/qtest/api_wrapper.py +1708 -76
  85. alita_sdk/tools/rally/__init__.py +1 -2
  86. alita_sdk/tools/report_portal/__init__.py +1 -0
  87. alita_sdk/tools/salesforce/__init__.py +1 -0
  88. alita_sdk/tools/servicenow/__init__.py +2 -3
  89. alita_sdk/tools/sharepoint/__init__.py +1 -0
  90. alita_sdk/tools/sharepoint/api_wrapper.py +125 -34
  91. alita_sdk/tools/sharepoint/authorization_helper.py +191 -1
  92. alita_sdk/tools/sharepoint/utils.py +8 -2
  93. alita_sdk/tools/slack/__init__.py +1 -0
  94. alita_sdk/tools/sql/__init__.py +2 -1
  95. alita_sdk/tools/testio/__init__.py +1 -0
  96. alita_sdk/tools/testrail/__init__.py +1 -3
  97. alita_sdk/tools/utils/content_parser.py +27 -16
  98. alita_sdk/tools/vector_adapters/VectorStoreAdapter.py +18 -5
  99. alita_sdk/tools/xray/__init__.py +2 -1
  100. alita_sdk/tools/zephyr/__init__.py +2 -1
  101. alita_sdk/tools/zephyr_enterprise/__init__.py +1 -0
  102. alita_sdk/tools/zephyr_essential/__init__.py +1 -0
  103. alita_sdk/tools/zephyr_scale/__init__.py +1 -0
  104. alita_sdk/tools/zephyr_squad/__init__.py +1 -0
  105. {alita_sdk-0.3.379.dist-info → alita_sdk-0.3.462.dist-info}/METADATA +8 -2
  106. {alita_sdk-0.3.379.dist-info → alita_sdk-0.3.462.dist-info}/RECORD +110 -86
  107. alita_sdk-0.3.462.dist-info/entry_points.txt +2 -0
  108. {alita_sdk-0.3.379.dist-info → alita_sdk-0.3.462.dist-info}/WHEEL +0 -0
  109. {alita_sdk-0.3.379.dist-info → alita_sdk-0.3.462.dist-info}/licenses/LICENSE +0 -0
  110. {alita_sdk-0.3.379.dist-info → alita_sdk-0.3.462.dist-info}/top_level.txt +0 -0
--- a/alita_sdk/runtime/tools/llm.py
+++ b/alita_sdk/runtime/tools/llm.py
@@ -1,3 +1,4 @@
+ import asyncio
  import logging
  from traceback import format_exc
  from typing import Any, Optional, List, Union
@@ -7,6 +8,7 @@ from langchain_core.runnables import RunnableConfig
  from langchain_core.tools import BaseTool, ToolException
  from pydantic import Field

+ from ..langchain.constants import ELITEA_RS
  from ..langchain.utils import create_pydantic_model, propagate_the_input_mapping

  logger = logging.getLogger(__name__)
@@ -30,6 +32,7 @@ class LLMNode(BaseTool):
      structured_output: Optional[bool] = Field(default=False, description='Whether to use structured output')
      available_tools: Optional[List[BaseTool]] = Field(default=None, description='Available tools for binding')
      tool_names: Optional[List[str]] = Field(default=None, description='Specific tool names to filter')
+     steps_limit: Optional[int] = Field(default=25, description='Maximum steps for tool execution')

      def get_filtered_tools(self) -> List[BaseTool]:
          """
@@ -88,8 +91,11 @@
              raise ToolException(f"LLMNode requires 'system' and 'task' parameters in input mapping. "
                                  f"Actual params: {func_args}")
          # cast to str in case user passes variable different from str
-         messages = [SystemMessage(content=str(func_args.get('system'))), HumanMessage(content=str(func_args.get('task')))]
-         messages.extend(func_args.get('chat_history', []))
+         messages = [SystemMessage(content=str(func_args.get('system'))), *func_args.get('chat_history', []), HumanMessage(content=str(func_args.get('task')))]
+         # Remove pre-last item if last two messages are same type and content
+         if len(messages) >= 2 and type(messages[-1]) == type(messages[-2]) and messages[-1].content == messages[
+                 -2].content:
+             messages.pop(-2)
      else:
          # Flow for chat-based LLM node w/o prompt/task from pipeline but with messages in state
          # verify messages structure
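
Aside: the reordering and de-duplication in the hunk above is easy to misread in diff form. A minimal standalone sketch of the same behavior follows; the build_messages helper is hypothetical, not part of the SDK.

    # Hypothetical helper reproducing the new message-assembly behavior:
    # system prompt first, then chat history, then the current task, with
    # the duplicate dropped when the task repeats the last history entry.
    from langchain_core.messages import HumanMessage, SystemMessage

    def build_messages(system, task, chat_history):
        messages = [SystemMessage(content=str(system)), *chat_history, HumanMessage(content=str(task))]
        # Drop the second-to-last message if the last two share type and content
        if len(messages) >= 2 and type(messages[-1]) is type(messages[-2]) \
                and messages[-1].content == messages[-2].content:
            messages.pop(-2)
        return messages

    history = [HumanMessage(content="summarize the diff")]
    print(build_messages("You are a reviewer.", "summarize the diff", history))
    # -> [SystemMessage(...), HumanMessage('summarize the diff')] - one copy, not two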
@@ -122,14 +128,27 @@
              }
              for key, value in (self.structured_output_dict or {}).items()
          }
+         # Add default output field for proper response to user
+         struct_params['elitea_response'] = {'description': 'final output to user', 'type': 'str'}
          struct_model = create_pydantic_model(f"LLMOutput", struct_params)
-         llm = llm_client.with_structured_output(struct_model)
-         completion = llm.invoke(messages, config=config)
-         result = completion.model_dump()
+         completion = llm_client.invoke(messages, config=config)
+         if hasattr(completion, 'tool_calls') and completion.tool_calls:
+             new_messages, _ = self._run_async_in_sync_context(
+                 self.__perform_tool_calling(completion, messages, llm_client, config)
+             )
+             llm = self.__get_struct_output_model(llm_client, struct_model)
+             completion = llm.invoke(new_messages, config=config)
+             result = completion.model_dump()
+         else:
+             llm = self.__get_struct_output_model(llm_client, struct_model)
+             completion = llm.invoke(messages, config=config)
+             result = completion.model_dump()

          # Ensure messages are properly formatted
          if result.get('messages') and isinstance(result['messages'], list):
              result['messages'] = [{'role': 'assistant', 'content': '\n'.join(result['messages'])}]
+         else:
+             result['messages'] = messages + [AIMessage(content=result.get(ELITEA_RS, ''))]

          return result
      else:
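
Aside: the structured-output change above always injects an elitea_response field before the model is built, so even user-defined schemas carry a user-facing answer. A rough sketch of the equivalent model construction, using plain pydantic.create_model as a stand-in for the SDK's create_pydantic_model helper (the summary field is invented for the example):

    from pydantic import create_model

    struct_params = {"summary": (str, ...)}        # user-declared output field
    struct_params["elitea_response"] = (str, ...)  # default field injected by LLMNode
    LLMOutput = create_model("LLMOutput", **struct_params)

    # With a bound chat model this would continue roughly as:
    #   llm = llm_client.with_structured_output(LLMOutput)
    #   result = llm.invoke(messages).model_dump()
    print(LLMOutput(summary="two hunks changed", elitea_response="Done.").model_dump())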
@@ -139,115 +158,17 @@
          # Handle both tool-calling and regular responses
          if hasattr(completion, 'tool_calls') and completion.tool_calls:
              # Handle iterative tool-calling and execution
-             new_messages = messages + [completion]
-             max_iterations = 15
-             iteration = 0
-
-             # Continue executing tools until no more tool calls or max iterations reached
-             current_completion = completion
-             while (hasattr(current_completion, 'tool_calls') and
-                    current_completion.tool_calls and
-                    iteration < max_iterations):
-
-                 iteration += 1
-                 logger.info(f"Tool execution iteration {iteration}/{max_iterations}")
-
-                 # Execute each tool call in the current completion
-                 tool_calls = current_completion.tool_calls if hasattr(current_completion.tool_calls,
-                                                                       '__iter__') else []
-
-                 for tool_call in tool_calls:
-                     tool_name = tool_call.get('name', '') if isinstance(tool_call, dict) else getattr(tool_call,
-                                                                                                       'name',
-                                                                                                       '')
-                     tool_args = tool_call.get('args', {}) if isinstance(tool_call, dict) else getattr(tool_call,
-                                                                                                       'args',
-                                                                                                       {})
-                     tool_call_id = tool_call.get('id', '') if isinstance(tool_call, dict) else getattr(
-                         tool_call, 'id', '')
-
-                     # Find the tool in filtered tools
-                     filtered_tools = self.get_filtered_tools()
-                     tool_to_execute = None
-                     for tool in filtered_tools:
-                         if tool.name == tool_name:
-                             tool_to_execute = tool
-                             break
-
-                     if tool_to_execute:
-                         try:
-                             logger.info(f"Executing tool '{tool_name}' with args: {tool_args}")
-                             tool_result = tool_to_execute.invoke(tool_args)
-
-                             # Create tool message with result - preserve structured content
-                             from langchain_core.messages import ToolMessage
-
-                             # Check if tool_result is structured content (list of dicts)
-                             # TODO: need solid check for being compatible with ToolMessage content format
-                             if isinstance(tool_result, list) and all(
-                                 isinstance(item, dict) and 'type' in item for item in tool_result
-                             ):
-                                 # Use structured content directly for multimodal support
-                                 tool_message = ToolMessage(
-                                     content=tool_result,
-                                     tool_call_id=tool_call_id
-                                 )
-                             else:
-                                 # Fallback to string conversion for other tool results
-                                 tool_message = ToolMessage(
-                                     content=str(tool_result),
-                                     tool_call_id=tool_call_id
-                                 )
-                             new_messages.append(tool_message)
-
-                         except Exception as e:
-                             logger.error(f"Error executing tool '{tool_name}': {e}")
-                             # Create error tool message
-                             from langchain_core.messages import ToolMessage
-                             tool_message = ToolMessage(
-                                 content=f"Error executing {tool_name}: {str(e)}",
-                                 tool_call_id=tool_call_id
-                             )
-                             new_messages.append(tool_message)
-                     else:
-                         logger.warning(f"Tool '{tool_name}' not found in available tools")
-                         # Create error tool message for missing tool
-                         from langchain_core.messages import ToolMessage
-                         tool_message = ToolMessage(
-                             content=f"Tool '{tool_name}' not available",
-                             tool_call_id=tool_call_id
-                         )
-                         new_messages.append(tool_message)
-
-                 # Call LLM again with tool results to get next response
-                 try:
-                     current_completion = llm_client.invoke(new_messages, config=config)
-                     new_messages.append(current_completion)
-
-                     # Check if we still have tool calls
-                     if hasattr(current_completion, 'tool_calls') and current_completion.tool_calls:
-                         logger.info(f"LLM requested {len(current_completion.tool_calls)} more tool calls")
-                     else:
-                         logger.info("LLM completed without requesting more tools")
-                         break
-
-                 except Exception as e:
-                     logger.error(f"Error in LLM call during iteration {iteration}: {e}")
-                     # Add error message and break the loop
-                     error_msg = f"Error processing tool results in iteration {iteration}: {str(e)}"
-                     new_messages.append(AIMessage(content=error_msg))
-                     break
-
-             # Log completion status
-             if iteration >= max_iterations:
-                 logger.warning(f"Reached maximum iterations ({max_iterations}) for tool execution")
-                 # Add a warning message to the chat
-                 warning_msg = f"Maximum tool execution iterations ({max_iterations}) reached. Stopping tool execution."
-                 new_messages.append(AIMessage(content=warning_msg))
-             else:
-                 logger.info(f"Tool execution completed after {iteration} iterations")
+             new_messages, current_completion = self._run_async_in_sync_context(
+                 self.__perform_tool_calling(completion, messages, llm_client, config)
+             )

-             return {"messages": new_messages}
+             output_msgs = {"messages": new_messages}
+             if self.output_variables:
+                 if self.output_variables[0] == 'messages':
+                     return output_msgs
+                 output_msgs[self.output_variables[0]] = current_completion.content if current_completion else None
+
+             return output_msgs
          else:
              # Regular text response
              content = completion.content.strip() if hasattr(completion, 'content') else str(completion)
@@ -273,4 +194,173 @@

      def _run(self, *args, **kwargs):
          # Legacy support for old interface
-         return self.invoke(kwargs, **kwargs)
+         return self.invoke(kwargs, **kwargs)
+
+     def _run_async_in_sync_context(self, coro):
+         """Run async coroutine from sync context.
+
+         For MCP tools with persistent sessions, we reuse the same event loop
+         that was used to create the MCP client and sessions (set by CLI).
+         """
+         try:
+             loop = asyncio.get_running_loop()
+             # Already in async context - run in thread with new loop
+             import threading
+
+             result_container = []
+
+             def run_in_thread():
+                 new_loop = asyncio.new_event_loop()
+                 asyncio.set_event_loop(new_loop)
+                 try:
+                     result_container.append(new_loop.run_until_complete(coro))
+                 finally:
+                     new_loop.close()
+
+             thread = threading.Thread(target=run_in_thread)
+             thread.start()
+             thread.join()
+             return result_container[0] if result_container else None
+
+         except RuntimeError:
+             # No event loop running - use/create persistent loop
+             # This loop is shared with MCP session creation for stateful tools
+             if not hasattr(self.__class__, '_persistent_loop') or \
+                     self.__class__._persistent_loop is None or \
+                     self.__class__._persistent_loop.is_closed():
+                 self.__class__._persistent_loop = asyncio.new_event_loop()
+                 logger.debug("Created persistent event loop for async tools")
+
+             loop = self.__class__._persistent_loop
+             asyncio.set_event_loop(loop)
+             return loop.run_until_complete(coro)
+
+     async def _arun(self, *args, **kwargs):
+         # Legacy async support
+         return self.invoke(kwargs, **kwargs)
+
+     async def __perform_tool_calling(self, completion, messages, llm_client, config):
+         # Handle iterative tool-calling and execution
+         logger.info(f"__perform_tool_calling called with {len(completion.tool_calls) if hasattr(completion, 'tool_calls') else 0} tool calls")
+         new_messages = messages + [completion]
+         iteration = 0
+
+         # Continue executing tools until no more tool calls or max iterations reached
+         current_completion = completion
+         while (hasattr(current_completion, 'tool_calls') and
+                current_completion.tool_calls and
+                iteration < self.steps_limit):
+
+             iteration += 1
+             logger.info(f"Tool execution iteration {iteration}/{self.steps_limit}")
+
+             # Execute each tool call in the current completion
+             tool_calls = current_completion.tool_calls if hasattr(current_completion.tool_calls,
+                                                                   '__iter__') else []
+
+             for tool_call in tool_calls:
+                 tool_name = tool_call.get('name', '') if isinstance(tool_call, dict) else getattr(tool_call,
+                                                                                                   'name',
+                                                                                                   '')
+                 tool_args = tool_call.get('args', {}) if isinstance(tool_call, dict) else getattr(tool_call,
+                                                                                                   'args',
+                                                                                                   {})
+                 tool_call_id = tool_call.get('id', '') if isinstance(tool_call, dict) else getattr(
+                     tool_call, 'id', '')
+
+                 # Find the tool in filtered tools
+                 filtered_tools = self.get_filtered_tools()
+                 tool_to_execute = None
+                 for tool in filtered_tools:
+                     if tool.name == tool_name:
+                         tool_to_execute = tool
+                         break
+
+                 if tool_to_execute:
+                     try:
+                         logger.info(f"Executing tool '{tool_name}' with args: {tool_args}")
+
+                         # Try async invoke first (for MCP tools), fallback to sync
+                         tool_result = None
+                         try:
+                             # Try async invocation first
+                             tool_result = await tool_to_execute.ainvoke(tool_args, config=config)
+                         except NotImplementedError:
+                             # Tool doesn't support async, use sync invoke
+                             logger.debug(f"Tool '{tool_name}' doesn't support async, using sync invoke")
+                             tool_result = tool_to_execute.invoke(tool_args, config=config)
+
+                         # Create tool message with result - preserve structured content
+                         from langchain_core.messages import ToolMessage
+
+                         # Check if tool_result is structured content (list of dicts)
+                         # TODO: need solid check for being compatible with ToolMessage content format
+                         if isinstance(tool_result, list) and all(
+                             isinstance(item, dict) and 'type' in item for item in tool_result
+                         ):
+                             # Use structured content directly for multimodal support
+                             tool_message = ToolMessage(
+                                 content=tool_result,
+                                 tool_call_id=tool_call_id
+                             )
+                         else:
+                             # Fallback to string conversion for other tool results
+                             tool_message = ToolMessage(
+                                 content=str(tool_result),
+                                 tool_call_id=tool_call_id
+                             )
+                         new_messages.append(tool_message)
+
+                     except Exception as e:
+                         import traceback
+                         error_details = traceback.format_exc()
+                         logger.error(f"Error executing tool '{tool_name}': {e}\n{error_details}")
+                         # Create error tool message
+                         from langchain_core.messages import ToolMessage
+                         tool_message = ToolMessage(
+                             content=f"Error executing {tool_name}: {str(e)}",
+                             tool_call_id=tool_call_id
+                         )
+                         new_messages.append(tool_message)
+                 else:
+                     logger.warning(f"Tool '{tool_name}' not found in available tools")
+                     # Create error tool message for missing tool
+                     from langchain_core.messages import ToolMessage
+                     tool_message = ToolMessage(
+                         content=f"Tool '{tool_name}' not available",
+                         tool_call_id=tool_call_id
+                     )
+                     new_messages.append(tool_message)
+
+             # Call LLM again with tool results to get next response
+             try:
+                 current_completion = llm_client.invoke(new_messages, config=config)
+                 new_messages.append(current_completion)
+
+                 # Check if we still have tool calls
+                 if hasattr(current_completion, 'tool_calls') and current_completion.tool_calls:
+                     logger.info(f"LLM requested {len(current_completion.tool_calls)} more tool calls")
+                 else:
+                     logger.info("LLM completed without requesting more tools")
+                     break
+
+             except Exception as e:
+                 logger.error(f"Error in LLM call during iteration {iteration}: {e}")
+                 # Add error message and break the loop
+                 error_msg = f"Error processing tool results in iteration {iteration}: {str(e)}"
+                 new_messages.append(AIMessage(content=error_msg))
+                 break
+
+         # Log completion status
+         if iteration >= self.steps_limit:
+             logger.warning(f"Reached maximum iterations ({self.steps_limit}) for tool execution")
+             # Add a warning message to the chat
+             warning_msg = f"Maximum tool execution iterations ({self.steps_limit}) reached. Stopping tool execution."
+             new_messages.append(AIMessage(content=warning_msg))
+         else:
+             logger.info(f"Tool execution completed after {iteration} iterations")
+
+         return new_messages, current_completion
+
+     def __get_struct_output_model(self, llm_client, pydantic_model):
+         return llm_client.with_structured_output(pydantic_model)
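
Aside: the _run_async_in_sync_context helper is the load-bearing piece of this refactor. Stripped of the MCP specifics, the pattern looks roughly like this standalone sketch (names are illustrative, not SDK API):

    # Sketch of the sync/async bridging pattern used above.
    import asyncio
    import threading

    _persistent_loop = None  # stands in for LLMNode._persistent_loop

    def run_async_in_sync_context(coro):
        global _persistent_loop
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No loop running: lazily create one persistent loop and reuse it,
            # so stateful (e.g. MCP) sessions stay bound to the same loop.
            if _persistent_loop is None or _persistent_loop.is_closed():
                _persistent_loop = asyncio.new_event_loop()
            asyncio.set_event_loop(_persistent_loop)
            return _persistent_loop.run_until_complete(coro)
        # Already inside a running loop: run the coroutine on a worker
        # thread with its own fresh loop and wait for the result.
        box = []
        t = threading.Thread(target=lambda: box.append(asyncio.run(coro)))
        t.start()
        t.join()
        return box[0] if box else None

    async def demo():
        await asyncio.sleep(0)
        return "ran"

    print(run_async_in_sync_context(demo()))  # -> ran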
--- /dev/null
+++ b/alita_sdk/runtime/tools/mcp_inspect_tool.py
@@ -0,0 +1,284 @@
+ """
+ MCP Server Inspection Tool.
+ Allows inspecting available tools, prompts, and resources on an MCP server.
+ """
+
+ import asyncio
+ import json
+ import logging
+ import time
+ from typing import Any, Type, Dict, List, Optional
+
+ from langchain_core.tools import BaseTool
+ from pydantic import BaseModel, Field, ConfigDict
+ import aiohttp
+
+ logger = logging.getLogger(__name__)
+
+
+ class McpInspectInput(BaseModel):
+     """Input schema for MCP server inspection tool."""
+
+     resource_type: str = Field(
+         default="all",
+         description="What to inspect: 'tools', 'prompts', 'resources', or 'all'"
+     )
+
+
+ class McpInspectTool(BaseTool):
+     """Tool for inspecting available tools, prompts, and resources on an MCP server."""
+
+     name: str = "mcp_inspect"
+     description: str = "List available tools, prompts, and resources from the MCP server"
+     args_schema: Type[BaseModel] = McpInspectInput
+     return_type: str = "str"
+
+     # MCP server connection details
+     server_name: str = Field(..., description="Name of the MCP server")
+     server_url: str = Field(..., description="URL of the MCP server")
+     server_headers: Optional[Dict[str, str]] = Field(default=None, description="HTTP headers for authentication")
+     timeout: int = Field(default=30, description="Request timeout in seconds")
+
+     model_config = ConfigDict(arbitrary_types_allowed=True)
+
+     def __getstate__(self):
+         """Custom serialization for pickle compatibility."""
+         state = self.__dict__.copy()
+         # Convert headers dict to regular dict to avoid any reference issues
+         if 'server_headers' in state and state['server_headers'] is not None:
+             state['server_headers'] = dict(state['server_headers'])
+         return state
+
+     def __setstate__(self, state):
+         """Custom deserialization for pickle compatibility."""
+         # Initialize Pydantic internal attributes if needed
+         if '__pydantic_fields_set__' not in state:
+             state['__pydantic_fields_set__'] = set(state.keys())
+         if '__pydantic_extra__' not in state:
+             state['__pydantic_extra__'] = None
+         if '__pydantic_private__' not in state:
+             state['__pydantic_private__'] = None
+
+         # Update object state
+         self.__dict__.update(state)
+
+     def _run(self, resource_type: str = "all") -> str:
+         """Inspect the MCP server for available resources."""
+         try:
+             # Always create a new event loop for sync context
+             # This avoids issues with existing event loops in threads
+             import concurrent.futures
+             with concurrent.futures.ThreadPoolExecutor() as executor:
+                 future = executor.submit(self._run_in_new_loop, resource_type)
+                 return future.result(timeout=self.timeout)
+         except Exception as e:
+             logger.error(f"Error inspecting MCP server '{self.server_name}': {e}")
+             return f"Error inspecting MCP server: {e}"
+
+     def _run_in_new_loop(self, resource_type: str) -> str:
+         """Run the async inspection in a new event loop."""
+         return asyncio.run(self._inspect_server(resource_type))
+
+     async def _inspect_server(self, resource_type: str) -> str:
+         """Perform the actual MCP server inspection."""
+         results = {}
+
+         # Determine what to inspect
+         inspect_tools = resource_type in ["all", "tools"]
+         inspect_prompts = resource_type in ["all", "prompts"]
+         inspect_resources = resource_type in ["all", "resources"]
+
+         async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=self.timeout)) as session:
+
+             # List tools
+             if inspect_tools:
+                 try:
+                     tools = await self._list_tools(session)
+                     results["tools"] = tools
+                 except Exception as e:
+                     logger.warning(f"Failed to list tools from {self.server_name}: {e}")
+                     results["tools"] = {"error": str(e)}
+
+             # List prompts
+             if inspect_prompts:
+                 try:
+                     prompts = await self._list_prompts(session)
+                     results["prompts"] = prompts
+                 except Exception as e:
+                     logger.warning(f"Failed to list prompts from {self.server_name}: {e}")
+                     results["prompts"] = {"error": str(e)}
+
+             # List resources
+             if inspect_resources:
+                 try:
+                     resources = await self._list_resources(session)
+                     results["resources"] = resources
+                 except Exception as e:
+                     logger.warning(f"Failed to list resources from {self.server_name}: {e}")
+                     results["resources"] = {"error": str(e)}
+
+         return self._format_results(results, resource_type)
+
+     def _parse_sse(self, text: str) -> Dict[str, Any]:
+         """Parse Server-Sent Events (SSE) format response."""
+         for line in text.split('\n'):
+             line = line.strip()
+             if line.startswith('data:'):
+                 json_str = line[5:].strip()
+                 return json.loads(json_str)
+         raise ValueError("No data found in SSE response")
+
+     async def _list_tools(self, session: aiohttp.ClientSession) -> Dict[str, Any]:
+         """List available tools from the MCP server."""
+         request = {
+             "jsonrpc": "2.0",
+             "id": f"list_tools_{int(time.time())}",
+             "method": "tools/list",
+             "params": {}
+         }
+
+         headers = {
+             "Content-Type": "application/json",
+             "Accept": "application/json, text/event-stream",
+             **self.server_headers
+         }
+
+         async with session.post(self.server_url, json=request, headers=headers) as response:
+             if response.status != 200:
+                 raise Exception(f"HTTP {response.status}: {await response.text()}")
+
+             # Handle both JSON and SSE responses
+             content_type = response.headers.get('Content-Type', '')
+             if 'text/event-stream' in content_type:
+                 # Parse SSE format
+                 text = await response.text()
+                 data = self._parse_sse(text)
+             else:
+                 data = await response.json()
+
+             if "error" in data:
+                 raise Exception(f"MCP Error: {data['error']}")
+
+             return data.get("result", {})
+
+     async def _list_prompts(self, session: aiohttp.ClientSession) -> Dict[str, Any]:
+         """List available prompts from the MCP server."""
+         request = {
+             "jsonrpc": "2.0",
+             "id": f"list_prompts_{int(time.time())}",
+             "method": "prompts/list",
+             "params": {}
+         }
+
+         headers = {
+             "Content-Type": "application/json",
+             "Accept": "application/json, text/event-stream",
+             **self.server_headers
+         }
+
+         async with session.post(self.server_url, json=request, headers=headers) as response:
+             if response.status != 200:
+                 raise Exception(f"HTTP {response.status}: {await response.text()}")
+
+             # Handle both JSON and SSE responses
+             content_type = response.headers.get('Content-Type', '')
+             if 'text/event-stream' in content_type:
+                 text = await response.text()
+                 data = self._parse_sse(text)
+             else:
+                 data = await response.json()
+
+             if "error" in data:
+                 raise Exception(f"MCP Error: {data['error']}")
+
+             return data.get("result", {})
+
+     async def _list_resources(self, session: aiohttp.ClientSession) -> Dict[str, Any]:
+         """List available resources from the MCP server."""
+         request = {
+             "jsonrpc": "2.0",
+             "id": f"list_resources_{int(time.time())}",
+             "method": "resources/list",
+             "params": {}
+         }
+
+         headers = {
+             "Content-Type": "application/json",
+             "Accept": "application/json, text/event-stream",
+             **self.server_headers
+         }
+
+         async with session.post(self.server_url, json=request, headers=headers) as response:
+             if response.status != 200:
+                 raise Exception(f"HTTP {response.status}: {await response.text()}")
+
+             # Handle both JSON and SSE responses
+             content_type = response.headers.get('Content-Type', '')
+             if 'text/event-stream' in content_type:
+                 text = await response.text()
+                 data = self._parse_sse(text)
+             else:
+                 data = await response.json()
+
+             if "error" in data:
+                 raise Exception(f"MCP Error: {data['error']}")
+
+             return data.get("result", {})
+
+     def _format_results(self, results: Dict[str, Any], resource_type: str) -> str:
+         """Format the inspection results for display."""
+         output_lines = [f"=== MCP Server Inspection: {self.server_name} ==="]
+         output_lines.append(f"Server URL: {self.server_url}")
+         output_lines.append("")
+
+         # Format tools
+         if "tools" in results:
+             if "error" in results["tools"]:
+                 output_lines.append(f"❌ TOOLS: Error - {results['tools']['error']}")
+             else:
+                 tools = results["tools"].get("tools", [])
+                 output_lines.append(f"🔧 TOOLS ({len(tools)} available):")
+                 if tools:
+                     for tool in tools:
+                         name = tool.get("name", "Unknown")
+                         desc = tool.get("description", "No description")
+                         output_lines.append(f"  • {name}: {desc}")
+                 else:
+                     output_lines.append("  (No tools available)")
+             output_lines.append("")
+
+         # Format prompts
+         if "prompts" in results:
+             if "error" in results["prompts"]:
+                 output_lines.append(f"❌ PROMPTS: Error - {results['prompts']['error']}")
+             else:
+                 prompts = results["prompts"].get("prompts", [])
+                 output_lines.append(f"💬 PROMPTS ({len(prompts)} available):")
+                 if prompts:
+                     for prompt in prompts:
+                         name = prompt.get("name", "Unknown")
+                         desc = prompt.get("description", "No description")
+                         output_lines.append(f"  • {name}: {desc}")
+                 else:
+                     output_lines.append("  (No prompts available)")
+             output_lines.append("")
+
+         # Format resources
+         if "resources" in results:
+             if "error" in results["resources"]:
+                 output_lines.append(f"❌ RESOURCES: Error - {results['resources']['error']}")
+             else:
+                 resources = results["resources"].get("resources", [])
+                 output_lines.append(f"📁 RESOURCES ({len(resources)} available):")
+                 if resources:
+                     for resource in resources:
+                         uri = resource.get("uri", "Unknown")
+                         name = resource.get("name", uri)
+                         desc = resource.get("description", "No description")
+                         output_lines.append(f"  • {name}: {desc}")
+                         output_lines.append(f"    URI: {uri}")
+                 else:
+                     output_lines.append("  (No resources available)")
+             output_lines.append("")
+
+         return "\n".join(output_lines)