alita-sdk 0.3.486__py3-none-any.whl → 0.3.515__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of alita-sdk might be problematic. Click here for more details.
- alita_sdk/cli/agent_loader.py +27 -6
- alita_sdk/cli/agents.py +10 -1
- alita_sdk/cli/inventory.py +12 -195
- alita_sdk/cli/tools/filesystem.py +95 -9
- alita_sdk/community/inventory/__init__.py +12 -0
- alita_sdk/community/inventory/toolkit.py +9 -5
- alita_sdk/community/inventory/toolkit_utils.py +176 -0
- alita_sdk/configurations/ado.py +144 -0
- alita_sdk/configurations/confluence.py +76 -42
- alita_sdk/configurations/figma.py +76 -0
- alita_sdk/configurations/gitlab.py +2 -0
- alita_sdk/configurations/qtest.py +72 -1
- alita_sdk/configurations/report_portal.py +96 -0
- alita_sdk/configurations/sharepoint.py +148 -0
- alita_sdk/configurations/testio.py +83 -0
- alita_sdk/runtime/clients/artifact.py +2 -2
- alita_sdk/runtime/clients/client.py +64 -40
- alita_sdk/runtime/clients/sandbox_client.py +14 -0
- alita_sdk/runtime/langchain/assistant.py +48 -2
- alita_sdk/runtime/langchain/constants.py +3 -1
- alita_sdk/runtime/langchain/document_loaders/AlitaExcelLoader.py +103 -60
- alita_sdk/runtime/langchain/document_loaders/AlitaJSONLinesLoader.py +77 -0
- alita_sdk/runtime/langchain/document_loaders/AlitaJSONLoader.py +2 -1
- alita_sdk/runtime/langchain/document_loaders/constants.py +12 -7
- alita_sdk/runtime/langchain/langraph_agent.py +10 -10
- alita_sdk/runtime/langchain/utils.py +6 -1
- alita_sdk/runtime/toolkits/artifact.py +14 -5
- alita_sdk/runtime/toolkits/datasource.py +13 -6
- alita_sdk/runtime/toolkits/mcp.py +94 -219
- alita_sdk/runtime/toolkits/planning.py +13 -6
- alita_sdk/runtime/toolkits/tools.py +60 -25
- alita_sdk/runtime/toolkits/vectorstore.py +11 -5
- alita_sdk/runtime/tools/artifact.py +185 -23
- alita_sdk/runtime/tools/function.py +2 -1
- alita_sdk/runtime/tools/llm.py +155 -34
- alita_sdk/runtime/tools/mcp_remote_tool.py +25 -10
- alita_sdk/runtime/tools/mcp_server_tool.py +2 -4
- alita_sdk/runtime/tools/vectorstore_base.py +3 -3
- alita_sdk/runtime/utils/AlitaCallback.py +136 -21
- alita_sdk/runtime/utils/mcp_client.py +492 -0
- alita_sdk/runtime/utils/mcp_oauth.py +125 -8
- alita_sdk/runtime/utils/mcp_sse_client.py +35 -6
- alita_sdk/runtime/utils/mcp_tools_discovery.py +124 -0
- alita_sdk/runtime/utils/toolkit_utils.py +7 -13
- alita_sdk/runtime/utils/utils.py +2 -0
- alita_sdk/tools/__init__.py +15 -0
- alita_sdk/tools/ado/repos/__init__.py +10 -12
- alita_sdk/tools/ado/test_plan/__init__.py +23 -8
- alita_sdk/tools/ado/wiki/__init__.py +24 -8
- alita_sdk/tools/ado/wiki/ado_wrapper.py +21 -7
- alita_sdk/tools/ado/work_item/__init__.py +24 -8
- alita_sdk/tools/advanced_jira_mining/__init__.py +10 -8
- alita_sdk/tools/aws/delta_lake/__init__.py +12 -9
- alita_sdk/tools/aws/delta_lake/tool.py +5 -1
- alita_sdk/tools/azure_ai/search/__init__.py +9 -7
- alita_sdk/tools/base/tool.py +5 -1
- alita_sdk/tools/base_indexer_toolkit.py +26 -1
- alita_sdk/tools/bitbucket/__init__.py +14 -10
- alita_sdk/tools/bitbucket/api_wrapper.py +50 -2
- alita_sdk/tools/browser/__init__.py +5 -4
- alita_sdk/tools/carrier/__init__.py +5 -6
- alita_sdk/tools/chunkers/sematic/json_chunker.py +1 -0
- alita_sdk/tools/chunkers/sematic/markdown_chunker.py +2 -0
- alita_sdk/tools/chunkers/universal_chunker.py +1 -0
- alita_sdk/tools/cloud/aws/__init__.py +9 -7
- alita_sdk/tools/cloud/azure/__init__.py +9 -7
- alita_sdk/tools/cloud/gcp/__init__.py +9 -7
- alita_sdk/tools/cloud/k8s/__init__.py +9 -7
- alita_sdk/tools/code/linter/__init__.py +9 -8
- alita_sdk/tools/code/loaders/codesearcher.py +3 -2
- alita_sdk/tools/code/sonar/__init__.py +9 -7
- alita_sdk/tools/confluence/__init__.py +15 -10
- alita_sdk/tools/confluence/api_wrapper.py +63 -14
- alita_sdk/tools/custom_open_api/__init__.py +11 -5
- alita_sdk/tools/elastic/__init__.py +10 -8
- alita_sdk/tools/elitea_base.py +387 -9
- alita_sdk/tools/figma/__init__.py +8 -7
- alita_sdk/tools/github/__init__.py +12 -14
- alita_sdk/tools/github/github_client.py +68 -2
- alita_sdk/tools/github/tool.py +5 -1
- alita_sdk/tools/gitlab/__init__.py +14 -11
- alita_sdk/tools/gitlab/api_wrapper.py +81 -1
- alita_sdk/tools/gitlab_org/__init__.py +9 -8
- alita_sdk/tools/google/bigquery/__init__.py +12 -12
- alita_sdk/tools/google/bigquery/tool.py +5 -1
- alita_sdk/tools/google_places/__init__.py +9 -8
- alita_sdk/tools/jira/__init__.py +15 -10
- alita_sdk/tools/keycloak/__init__.py +10 -8
- alita_sdk/tools/localgit/__init__.py +8 -3
- alita_sdk/tools/localgit/local_git.py +62 -54
- alita_sdk/tools/localgit/tool.py +5 -1
- alita_sdk/tools/memory/__init__.py +11 -3
- alita_sdk/tools/ocr/__init__.py +10 -8
- alita_sdk/tools/openapi/__init__.py +6 -2
- alita_sdk/tools/pandas/__init__.py +9 -7
- alita_sdk/tools/postman/__init__.py +10 -11
- alita_sdk/tools/pptx/__init__.py +9 -9
- alita_sdk/tools/qtest/__init__.py +9 -8
- alita_sdk/tools/rally/__init__.py +9 -8
- alita_sdk/tools/report_portal/__init__.py +11 -9
- alita_sdk/tools/salesforce/__init__.py +9 -9
- alita_sdk/tools/servicenow/__init__.py +10 -8
- alita_sdk/tools/sharepoint/__init__.py +9 -8
- alita_sdk/tools/sharepoint/api_wrapper.py +2 -2
- alita_sdk/tools/slack/__init__.py +8 -7
- alita_sdk/tools/sql/__init__.py +9 -8
- alita_sdk/tools/testio/__init__.py +9 -8
- alita_sdk/tools/testrail/__init__.py +10 -8
- alita_sdk/tools/utils/__init__.py +9 -4
- alita_sdk/tools/utils/text_operations.py +254 -0
- alita_sdk/tools/vector_adapters/VectorStoreAdapter.py +16 -18
- alita_sdk/tools/xray/__init__.py +10 -8
- alita_sdk/tools/yagmail/__init__.py +8 -3
- alita_sdk/tools/zephyr/__init__.py +8 -7
- alita_sdk/tools/zephyr_enterprise/__init__.py +10 -8
- alita_sdk/tools/zephyr_essential/__init__.py +9 -8
- alita_sdk/tools/zephyr_scale/__init__.py +9 -8
- alita_sdk/tools/zephyr_squad/__init__.py +9 -8
- {alita_sdk-0.3.486.dist-info → alita_sdk-0.3.515.dist-info}/METADATA +1 -1
- {alita_sdk-0.3.486.dist-info → alita_sdk-0.3.515.dist-info}/RECORD +124 -119
- {alita_sdk-0.3.486.dist-info → alita_sdk-0.3.515.dist-info}/WHEEL +0 -0
- {alita_sdk-0.3.486.dist-info → alita_sdk-0.3.515.dist-info}/entry_points.txt +0 -0
- {alita_sdk-0.3.486.dist-info → alita_sdk-0.3.515.dist-info}/licenses/LICENSE +0 -0
- {alita_sdk-0.3.486.dist-info → alita_sdk-0.3.515.dist-info}/top_level.txt +0 -0
alita_sdk/runtime/tools/llm.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
import asyncio
|
|
2
2
|
import logging
|
|
3
3
|
from traceback import format_exc
|
|
4
|
-
from typing import Any, Optional, List, Union
|
|
4
|
+
from typing import Any, Optional, List, Union, Literal
|
|
5
5
|
|
|
6
6
|
from langchain_core.messages import HumanMessage, SystemMessage, AIMessage
|
|
7
7
|
from langchain_core.runnables import RunnableConfig
|
|
@@ -34,6 +34,7 @@ class LLMNode(BaseTool):
|
|
|
34
34
|
available_tools: Optional[List[BaseTool]] = Field(default=None, description='Available tools for binding')
|
|
35
35
|
tool_names: Optional[List[str]] = Field(default=None, description='Specific tool names to filter')
|
|
36
36
|
steps_limit: Optional[int] = Field(default=25, description='Maximum steps for tool execution')
|
|
37
|
+
tool_execution_timeout: Optional[int] = Field(default=900, description='Timeout (seconds) for tool execution. Default is 15 minutes.')
|
|
37
38
|
|
|
38
39
|
def get_filtered_tools(self) -> List[BaseTool]:
|
|
39
40
|
"""
|
|
@@ -129,7 +130,9 @@ class LLMNode(BaseTool):
|
|
|
129
130
|
# or standalone LLM node for chat (with messages only)
|
|
130
131
|
if 'system' in func_args.keys():
|
|
131
132
|
# Flow for LLM node with prompt/task from pipeline
|
|
132
|
-
if
|
|
133
|
+
if func_args.get('system') is None or func_args.get('task') is None:
|
|
134
|
+
raise ToolException(f"LLMNode requires 'system' and 'task' parameters in input mapping. "
|
|
135
|
+
f"Actual params: {func_args}")
|
|
133
136
|
raise ToolException(f"LLMNode requires 'system' and 'task' parameters in input mapping. "
|
|
134
137
|
f"Actual params: {func_args}")
|
|
135
138
|
# cast to str in case user passes variable different from str
|
|
@@ -171,26 +174,36 @@ class LLMNode(BaseTool):
|
|
|
171
174
|
for key, value in (self.structured_output_dict or {}).items()
|
|
172
175
|
}
|
|
173
176
|
# Add default output field for proper response to user
|
|
174
|
-
struct_params['elitea_response'] = {
|
|
177
|
+
struct_params['elitea_response'] = {
|
|
178
|
+
'description': 'final output to user (summarized output from LLM)', 'type': 'str',
|
|
179
|
+
"default": None}
|
|
175
180
|
struct_model = create_pydantic_model(f"LLMOutput", struct_params)
|
|
176
|
-
|
|
177
|
-
if hasattr(
|
|
181
|
+
initial_completion = llm_client.invoke(messages, config=config)
|
|
182
|
+
if hasattr(initial_completion, 'tool_calls') and initial_completion.tool_calls:
|
|
178
183
|
new_messages, _ = self._run_async_in_sync_context(
|
|
179
|
-
self.__perform_tool_calling(
|
|
184
|
+
self.__perform_tool_calling(initial_completion, messages, llm_client, config)
|
|
180
185
|
)
|
|
181
186
|
llm = self.__get_struct_output_model(llm_client, struct_model)
|
|
182
187
|
completion = llm.invoke(new_messages, config=config)
|
|
183
188
|
result = completion.model_dump()
|
|
184
189
|
else:
|
|
185
|
-
|
|
186
|
-
|
|
190
|
+
try:
|
|
191
|
+
llm = self.__get_struct_output_model(llm_client, struct_model)
|
|
192
|
+
completion = llm.invoke(messages, config=config)
|
|
193
|
+
except ValueError as e:
|
|
194
|
+
logger.error(f"Error invoking structured output model: {format_exc()}")
|
|
195
|
+
logger.info("Attemping to fall back to json mode")
|
|
196
|
+
# Fallback to regular LLM with JSON extraction
|
|
197
|
+
completion = self.__get_struct_output_model(llm_client, struct_model,
|
|
198
|
+
method="json_mode").invoke(messages, config=config)
|
|
187
199
|
result = completion.model_dump()
|
|
188
200
|
|
|
189
201
|
# Ensure messages are properly formatted
|
|
190
202
|
if result.get('messages') and isinstance(result['messages'], list):
|
|
191
203
|
result['messages'] = [{'role': 'assistant', 'content': '\n'.join(result['messages'])}]
|
|
192
204
|
else:
|
|
193
|
-
result['messages'] = messages + [
|
|
205
|
+
result['messages'] = messages + [
|
|
206
|
+
AIMessage(content=result.get(ELITEA_RS, '') or initial_completion.content)]
|
|
194
207
|
|
|
195
208
|
return result
|
|
196
209
|
else:
|
|
@@ -243,40 +256,146 @@ class LLMNode(BaseTool):
|
|
|
243
256
|
|
|
244
257
|
For MCP tools with persistent sessions, we reuse the same event loop
|
|
245
258
|
that was used to create the MCP client and sessions (set by CLI).
|
|
259
|
+
|
|
260
|
+
When called from within a running event loop (e.g., nested LLM nodes),
|
|
261
|
+
we need to handle this carefully to avoid "event loop already running" errors.
|
|
262
|
+
|
|
263
|
+
This method handles three scenarios:
|
|
264
|
+
1. Called from async context (event loop running) - creates new thread with new loop
|
|
265
|
+
2. Called from sync context with persistent loop - reuses persistent loop
|
|
266
|
+
3. Called from sync context without loop - creates new persistent loop
|
|
246
267
|
"""
|
|
268
|
+
import threading
|
|
269
|
+
|
|
270
|
+
# Check if there's a running loop
|
|
247
271
|
try:
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
272
|
+
running_loop = asyncio.get_running_loop()
|
|
273
|
+
loop_is_running = True
|
|
274
|
+
logger.debug(f"Detected running event loop (id: {id(running_loop)}), executing tool calls in separate thread")
|
|
275
|
+
except RuntimeError:
|
|
276
|
+
loop_is_running = False
|
|
277
|
+
|
|
278
|
+
# Scenario 1: Loop is currently running - MUST use thread
|
|
279
|
+
if loop_is_running:
|
|
252
280
|
result_container = []
|
|
253
|
-
|
|
281
|
+
exception_container = []
|
|
282
|
+
|
|
283
|
+
# Try to capture Streamlit context from current thread for propagation
|
|
284
|
+
streamlit_ctx = None
|
|
285
|
+
try:
|
|
286
|
+
from streamlit.runtime.scriptrunner import get_script_run_ctx, add_script_run_ctx
|
|
287
|
+
streamlit_ctx = get_script_run_ctx()
|
|
288
|
+
if streamlit_ctx:
|
|
289
|
+
logger.debug("Captured Streamlit context for propagation to worker thread")
|
|
290
|
+
except (ImportError, Exception) as e:
|
|
291
|
+
logger.debug(f"Streamlit context not available or failed to capture: {e}")
|
|
292
|
+
|
|
254
293
|
def run_in_thread():
|
|
294
|
+
"""Run coroutine in a new thread with its own event loop."""
|
|
255
295
|
new_loop = asyncio.new_event_loop()
|
|
256
296
|
asyncio.set_event_loop(new_loop)
|
|
257
297
|
try:
|
|
258
|
-
|
|
298
|
+
result = new_loop.run_until_complete(coro)
|
|
299
|
+
result_container.append(result)
|
|
300
|
+
except Exception as e:
|
|
301
|
+
logger.debug(f"Exception in async thread: {e}")
|
|
302
|
+
exception_container.append(e)
|
|
259
303
|
finally:
|
|
260
304
|
new_loop.close()
|
|
261
|
-
|
|
262
|
-
|
|
305
|
+
asyncio.set_event_loop(None)
|
|
306
|
+
|
|
307
|
+
thread = threading.Thread(target=run_in_thread, daemon=False)
|
|
308
|
+
|
|
309
|
+
# Propagate Streamlit context to the worker thread if available
|
|
310
|
+
if streamlit_ctx is not None:
|
|
311
|
+
try:
|
|
312
|
+
add_script_run_ctx(thread, streamlit_ctx)
|
|
313
|
+
logger.debug("Successfully propagated Streamlit context to worker thread")
|
|
314
|
+
except Exception as e:
|
|
315
|
+
logger.warning(f"Failed to propagate Streamlit context to worker thread: {e}")
|
|
316
|
+
|
|
263
317
|
thread.start()
|
|
264
|
-
thread.join()
|
|
318
|
+
thread.join(timeout=self.tool_execution_timeout) # 15 minute timeout for safety
|
|
319
|
+
|
|
320
|
+
if thread.is_alive():
|
|
321
|
+
logger.error("Async operation timed out after 5 minutes")
|
|
322
|
+
raise TimeoutError("Async operation in thread timed out")
|
|
323
|
+
|
|
324
|
+
# Re-raise exception if one occurred
|
|
325
|
+
if exception_container:
|
|
326
|
+
raise exception_container[0]
|
|
327
|
+
|
|
265
328
|
return result_container[0] if result_container else None
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
#
|
|
329
|
+
|
|
330
|
+
# Scenario 2 & 3: No loop running - use or create persistent loop
|
|
331
|
+
else:
|
|
332
|
+
# Get or create persistent loop
|
|
270
333
|
if not hasattr(self.__class__, '_persistent_loop') or \
|
|
271
334
|
self.__class__._persistent_loop is None or \
|
|
272
335
|
self.__class__._persistent_loop.is_closed():
|
|
273
336
|
self.__class__._persistent_loop = asyncio.new_event_loop()
|
|
274
337
|
logger.debug("Created persistent event loop for async tools")
|
|
275
|
-
|
|
338
|
+
|
|
276
339
|
loop = self.__class__._persistent_loop
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
340
|
+
|
|
341
|
+
# Double-check the loop is not running (safety check)
|
|
342
|
+
if loop.is_running():
|
|
343
|
+
logger.debug("Persistent loop is unexpectedly running, using thread execution")
|
|
344
|
+
|
|
345
|
+
result_container = []
|
|
346
|
+
exception_container = []
|
|
347
|
+
|
|
348
|
+
# Try to capture Streamlit context from current thread for propagation
|
|
349
|
+
streamlit_ctx = None
|
|
350
|
+
try:
|
|
351
|
+
from streamlit.runtime.scriptrunner import get_script_run_ctx, add_script_run_ctx
|
|
352
|
+
streamlit_ctx = get_script_run_ctx()
|
|
353
|
+
if streamlit_ctx:
|
|
354
|
+
logger.debug("Captured Streamlit context for propagation to worker thread")
|
|
355
|
+
except (ImportError, Exception) as e:
|
|
356
|
+
logger.debug(f"Streamlit context not available or failed to capture: {e}")
|
|
357
|
+
|
|
358
|
+
def run_in_thread():
|
|
359
|
+
"""Run coroutine in a new thread with its own event loop."""
|
|
360
|
+
new_loop = asyncio.new_event_loop()
|
|
361
|
+
asyncio.set_event_loop(new_loop)
|
|
362
|
+
try:
|
|
363
|
+
result = new_loop.run_until_complete(coro)
|
|
364
|
+
result_container.append(result)
|
|
365
|
+
except Exception as ex:
|
|
366
|
+
logger.debug(f"Exception in async thread: {ex}")
|
|
367
|
+
exception_container.append(ex)
|
|
368
|
+
finally:
|
|
369
|
+
new_loop.close()
|
|
370
|
+
asyncio.set_event_loop(None)
|
|
371
|
+
|
|
372
|
+
thread = threading.Thread(target=run_in_thread, daemon=False)
|
|
373
|
+
|
|
374
|
+
# Propagate Streamlit context to the worker thread if available
|
|
375
|
+
if streamlit_ctx is not None:
|
|
376
|
+
try:
|
|
377
|
+
add_script_run_ctx(thread, streamlit_ctx)
|
|
378
|
+
logger.debug("Successfully propagated Streamlit context to worker thread")
|
|
379
|
+
except Exception as e:
|
|
380
|
+
logger.warning(f"Failed to propagate Streamlit context to worker thread: {e}")
|
|
381
|
+
|
|
382
|
+
thread.start()
|
|
383
|
+
thread.join(timeout=self.tool_execution_timeout)
|
|
384
|
+
|
|
385
|
+
if thread.is_alive():
|
|
386
|
+
logger.error("Async operation timed out after 15 minutes")
|
|
387
|
+
raise TimeoutError("Async operation in thread timed out")
|
|
388
|
+
|
|
389
|
+
if exception_container:
|
|
390
|
+
raise exception_container[0]
|
|
391
|
+
|
|
392
|
+
return result_container[0] if result_container else None
|
|
393
|
+
else:
|
|
394
|
+
# Loop exists but not running - safe to use run_until_complete
|
|
395
|
+
logger.debug(f"Using persistent loop (id: {id(loop)}) with run_until_complete")
|
|
396
|
+
asyncio.set_event_loop(loop)
|
|
397
|
+
return loop.run_until_complete(coro)
|
|
398
|
+
|
|
280
399
|
async def _arun(self, *args, **kwargs):
|
|
281
400
|
# Legacy async support
|
|
282
401
|
return self.invoke(kwargs, **kwargs)
|
|
@@ -324,12 +443,14 @@ class LLMNode(BaseTool):
|
|
|
324
443
|
|
|
325
444
|
# Try async invoke first (for MCP tools), fallback to sync
|
|
326
445
|
tool_result = None
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
|
|
446
|
+
if hasattr(tool_to_execute, 'ainvoke'):
|
|
447
|
+
try:
|
|
448
|
+
tool_result = await tool_to_execute.ainvoke(tool_args, config=config)
|
|
449
|
+
except (NotImplementedError, AttributeError):
|
|
450
|
+
logger.debug(f"Tool '{tool_name}' ainvoke failed, falling back to sync invoke")
|
|
451
|
+
tool_result = tool_to_execute.invoke(tool_args, config=config)
|
|
452
|
+
else:
|
|
453
|
+
# Sync-only tool
|
|
333
454
|
tool_result = tool_to_execute.invoke(tool_args, config=config)
|
|
334
455
|
|
|
335
456
|
# Create tool message with result - preserve structured content
|
|
@@ -539,5 +660,5 @@ class LLMNode(BaseTool):
|
|
|
539
660
|
|
|
540
661
|
return new_messages, current_completion
|
|
541
662
|
|
|
542
|
-
def __get_struct_output_model(self, llm_client, pydantic_model):
|
|
543
|
-
return llm_client.with_structured_output(pydantic_model)
|
|
663
|
+
def __get_struct_output_model(self, llm_client, pydantic_model, method: Literal["function_calling", "json_mode", "json_schema"] = "json_schema"):
|
|
664
|
+
return llm_client.with_structured_output(pydantic_model, method=method)
|
|
@@ -20,10 +20,14 @@ from ..utils.mcp_oauth import (
|
|
|
20
20
|
fetch_resource_metadata_async,
|
|
21
21
|
infer_authorization_servers_from_realm,
|
|
22
22
|
)
|
|
23
|
-
from ..utils.
|
|
23
|
+
from ..utils.mcp_client import McpClient
|
|
24
24
|
|
|
25
25
|
logger = logging.getLogger(__name__)
|
|
26
26
|
|
|
27
|
+
# Global registry to store MCP tool session metadata by tool name
|
|
28
|
+
# This is used to pass session info to callbacks since LangChain's serialization doesn't include all fields
|
|
29
|
+
MCP_TOOL_SESSION_REGISTRY: Dict[str, Dict[str, Any]] = {}
|
|
30
|
+
|
|
27
31
|
|
|
28
32
|
class McpRemoteTool(McpServerTool):
|
|
29
33
|
"""
|
|
@@ -43,6 +47,7 @@ class McpRemoteTool(McpServerTool):
|
|
|
43
47
|
"""Update metadata with session info after model initialization."""
|
|
44
48
|
super().model_post_init(__context)
|
|
45
49
|
self._update_metadata_with_session()
|
|
50
|
+
self._register_session_metadata()
|
|
46
51
|
|
|
47
52
|
def _update_metadata_with_session(self):
|
|
48
53
|
"""Update the metadata dict with current session information."""
|
|
@@ -54,6 +59,15 @@ class McpRemoteTool(McpServerTool):
|
|
|
54
59
|
'mcp_server_url': canonical_resource(self.server_url)
|
|
55
60
|
})
|
|
56
61
|
|
|
62
|
+
def _register_session_metadata(self):
|
|
63
|
+
"""Register session metadata in global registry for callback access."""
|
|
64
|
+
if self.session_id and self.server_url:
|
|
65
|
+
MCP_TOOL_SESSION_REGISTRY[self.name] = {
|
|
66
|
+
'mcp_session_id': self.session_id,
|
|
67
|
+
'mcp_server_url': canonical_resource(self.server_url)
|
|
68
|
+
}
|
|
69
|
+
logger.debug(f"[MCP] Registered session metadata for tool '{self.name}': session={self.session_id}")
|
|
70
|
+
|
|
57
71
|
def __getstate__(self):
|
|
58
72
|
"""Custom serialization for pickle compatibility."""
|
|
59
73
|
state = super().__getstate__()
|
|
@@ -85,7 +99,6 @@ class McpRemoteTool(McpServerTool):
|
|
|
85
99
|
|
|
86
100
|
async def _execute_remote_tool(self, kwargs: Dict[str, Any]) -> str:
|
|
87
101
|
"""Execute the actual remote MCP tool call using SSE client."""
|
|
88
|
-
from ...tools.utils import TOOLKIT_SPLITTER
|
|
89
102
|
|
|
90
103
|
# Check for session_id requirement
|
|
91
104
|
if not self.session_id:
|
|
@@ -95,10 +108,10 @@ class McpRemoteTool(McpServerTool):
|
|
|
95
108
|
# Use the original tool name from discovery for MCP server invocation
|
|
96
109
|
tool_name_for_server = self.original_tool_name
|
|
97
110
|
if not tool_name_for_server:
|
|
98
|
-
tool_name_for_server = self.name
|
|
99
|
-
logger.warning(f"original_tool_name not set for '{self.name}', using
|
|
111
|
+
tool_name_for_server = self.name
|
|
112
|
+
logger.warning(f"original_tool_name not set for '{self.name}', using: {tool_name_for_server}")
|
|
100
113
|
|
|
101
|
-
logger.info(f"[MCP
|
|
114
|
+
logger.info(f"[MCP] Executing tool '{tool_name_for_server}' with session {self.session_id}")
|
|
102
115
|
|
|
103
116
|
try:
|
|
104
117
|
# Prepare headers
|
|
@@ -106,16 +119,18 @@ class McpRemoteTool(McpServerTool):
|
|
|
106
119
|
if self.server_headers:
|
|
107
120
|
headers.update(self.server_headers)
|
|
108
121
|
|
|
109
|
-
# Create
|
|
110
|
-
client =
|
|
122
|
+
# Create unified MCP client (auto-detects transport)
|
|
123
|
+
client = McpClient(
|
|
111
124
|
url=self.server_url,
|
|
112
125
|
session_id=self.session_id,
|
|
113
126
|
headers=headers,
|
|
114
127
|
timeout=self.tool_timeout_sec
|
|
115
128
|
)
|
|
116
129
|
|
|
117
|
-
# Execute tool call
|
|
118
|
-
|
|
130
|
+
# Execute tool call (client auto-detects SSE vs Streamable HTTP)
|
|
131
|
+
async with client:
|
|
132
|
+
await client.initialize()
|
|
133
|
+
result = await client.call_tool(tool_name_for_server, kwargs)
|
|
119
134
|
|
|
120
135
|
# Format the result
|
|
121
136
|
if isinstance(result, dict):
|
|
@@ -144,7 +159,7 @@ class McpRemoteTool(McpServerTool):
|
|
|
144
159
|
return str(result)
|
|
145
160
|
|
|
146
161
|
except Exception as e:
|
|
147
|
-
logger.error(f"[MCP
|
|
162
|
+
logger.error(f"[MCP] Tool execution failed: {e}", exc_info=True)
|
|
148
163
|
raise
|
|
149
164
|
|
|
150
165
|
def _parse_sse(self, text: str) -> Dict[str, Any]:
|
|
@@ -5,8 +5,6 @@ from typing import Any, Type, Literal, Optional, Union, List
|
|
|
5
5
|
from langchain_core.tools import BaseTool
|
|
6
6
|
from pydantic import BaseModel, Field, create_model, EmailStr, constr, ConfigDict
|
|
7
7
|
|
|
8
|
-
from ...tools.utils import TOOLKIT_SPLITTER
|
|
9
|
-
|
|
10
8
|
logger = getLogger(__name__)
|
|
11
9
|
|
|
12
10
|
|
|
@@ -91,13 +89,13 @@ class McpServerTool(BaseTool):
|
|
|
91
89
|
return create_model(model_name, **fields)
|
|
92
90
|
|
|
93
91
|
def _run(self, *args, **kwargs):
|
|
94
|
-
#
|
|
92
|
+
# Use the tool name directly (no prefix extraction needed)
|
|
95
93
|
call_data = {
|
|
96
94
|
"server": self.server,
|
|
97
95
|
"tool_timeout_sec": self.tool_timeout_sec,
|
|
98
96
|
"tool_call_id": str(uuid.uuid4()),
|
|
99
97
|
"params": {
|
|
100
|
-
"name": self.name
|
|
98
|
+
"name": self.name,
|
|
101
99
|
"arguments": kwargs
|
|
102
100
|
}
|
|
103
101
|
}
|
|
@@ -270,7 +270,7 @@ class VectorStoreWrapperBase(BaseToolApiWrapper):
|
|
|
270
270
|
)
|
|
271
271
|
).count()
|
|
272
272
|
|
|
273
|
-
def _clean_collection(self, index_name: str = ''):
|
|
273
|
+
def _clean_collection(self, index_name: str = '', including_index_meta: bool = False):
|
|
274
274
|
"""
|
|
275
275
|
Clean the vectorstore collection by deleting all indexed data.
|
|
276
276
|
"""
|
|
@@ -279,7 +279,7 @@ class VectorStoreWrapperBase(BaseToolApiWrapper):
|
|
|
279
279
|
f"Cleaning collection '{self.dataset}'",
|
|
280
280
|
tool_name="_clean_collection"
|
|
281
281
|
)
|
|
282
|
-
self.vector_adapter.clean_collection(self, index_name)
|
|
282
|
+
self.vector_adapter.clean_collection(self, index_name, including_index_meta)
|
|
283
283
|
self._log_tool_event(
|
|
284
284
|
f"Collection '{self.dataset}' has been cleaned. ",
|
|
285
285
|
tool_name="_clean_collection"
|
|
@@ -303,7 +303,7 @@ class VectorStoreWrapperBase(BaseToolApiWrapper):
|
|
|
303
303
|
logger.info("Cleaning index before re-indexing all documents.")
|
|
304
304
|
self._log_tool_event("Cleaning index before re-indexing all documents. Previous index will be removed", tool_name="index_documents")
|
|
305
305
|
try:
|
|
306
|
-
self._clean_collection(index_name)
|
|
306
|
+
self._clean_collection(index_name, including_index_meta=False)
|
|
307
307
|
self._log_tool_event("Previous index has been removed",
|
|
308
308
|
tool_name="index_documents")
|
|
309
309
|
except Exception as e:
|
|
@@ -23,9 +23,45 @@ class AlitaStreamlitCallback(BaseCallbackHandler):
|
|
|
23
23
|
self.tokens_out = 0
|
|
24
24
|
self.pending_llm_requests = defaultdict(int)
|
|
25
25
|
self.current_model_name = 'gpt-4'
|
|
26
|
+
self._event_queue = [] # Queue for events when context is unavailable
|
|
26
27
|
#
|
|
27
28
|
super().__init__()
|
|
28
29
|
|
|
30
|
+
def _has_streamlit_context(self) -> bool:
|
|
31
|
+
"""Check if Streamlit context is available in the current thread."""
|
|
32
|
+
try:
|
|
33
|
+
# Try to import streamlit runtime context checker
|
|
34
|
+
from streamlit.runtime.scriptrunner import get_script_run_ctx
|
|
35
|
+
ctx = get_script_run_ctx()
|
|
36
|
+
return ctx is not None
|
|
37
|
+
except (ImportError, Exception) as e:
|
|
38
|
+
if self.debug:
|
|
39
|
+
log.debug(f"Streamlit context check failed: {e}")
|
|
40
|
+
return False
|
|
41
|
+
|
|
42
|
+
def _safe_streamlit_call(self, func, *args, **kwargs):
|
|
43
|
+
"""Safely execute a Streamlit UI operation, handling missing context gracefully."""
|
|
44
|
+
if not self._has_streamlit_context():
|
|
45
|
+
func_name = getattr(func, '__name__', str(func))
|
|
46
|
+
if self.debug:
|
|
47
|
+
log.warning(f"Streamlit context not available for {func_name}, queueing event")
|
|
48
|
+
# Store the event for potential replay when context is available
|
|
49
|
+
self._event_queue.append({
|
|
50
|
+
'func': func_name,
|
|
51
|
+
'args': args,
|
|
52
|
+
'kwargs': kwargs,
|
|
53
|
+
'timestamp': datetime.now(tz=timezone.utc)
|
|
54
|
+
})
|
|
55
|
+
return None
|
|
56
|
+
|
|
57
|
+
try:
|
|
58
|
+
return func(*args, **kwargs)
|
|
59
|
+
except Exception as e:
|
|
60
|
+
func_name = getattr(func, '__name__', str(func))
|
|
61
|
+
# Handle any Streamlit-specific exceptions gracefully
|
|
62
|
+
log.warning(f"Streamlit operation {func_name} failed: {e}")
|
|
63
|
+
return None
|
|
64
|
+
|
|
29
65
|
#
|
|
30
66
|
# Chain
|
|
31
67
|
#
|
|
@@ -76,10 +112,14 @@ class AlitaStreamlitCallback(BaseCallbackHandler):
|
|
|
76
112
|
json.dumps(payload, ensure_ascii=False, default=lambda o: str(o))
|
|
77
113
|
)
|
|
78
114
|
|
|
79
|
-
|
|
80
|
-
|
|
115
|
+
status_widget = self._safe_streamlit_call(
|
|
116
|
+
self.st.status,
|
|
117
|
+
f"Running {payload.get('tool_name')}...",
|
|
118
|
+
expanded=True
|
|
81
119
|
)
|
|
82
|
-
|
|
120
|
+
if status_widget:
|
|
121
|
+
self.callback_state[str(run_id)] = status_widget
|
|
122
|
+
self._safe_streamlit_call(status_widget.write, f"Tool inputs: {payload}")
|
|
83
123
|
|
|
84
124
|
def on_tool_start(self, *args, run_id: UUID, **kwargs):
|
|
85
125
|
""" Callback """
|
|
@@ -88,15 +128,51 @@ class AlitaStreamlitCallback(BaseCallbackHandler):
|
|
|
88
128
|
|
|
89
129
|
tool_name = args[0].get("name")
|
|
90
130
|
tool_run_id = str(run_id)
|
|
131
|
+
|
|
132
|
+
# Extract metadata from tool if available (from BaseAction.metadata)
|
|
133
|
+
# Try multiple sources for metadata with toolkit_name
|
|
134
|
+
tool_meta = args[0].copy()
|
|
135
|
+
|
|
136
|
+
# Source 1: kwargs['serialized']['metadata'] - LangChain's full tool serialization
|
|
137
|
+
if 'serialized' in kwargs and 'metadata' in kwargs['serialized']:
|
|
138
|
+
tool_meta['metadata'] = kwargs['serialized']['metadata']
|
|
139
|
+
log.info(f"[METADATA] Extracted from serialized: {kwargs['serialized']['metadata']}")
|
|
140
|
+
# Source 2: Check if metadata is directly in args[0] (some LangChain versions)
|
|
141
|
+
elif 'metadata' in args[0]:
|
|
142
|
+
tool_meta['metadata'] = args[0]['metadata']
|
|
143
|
+
log.info(f"[METADATA] Extracted from args[0]: {args[0]['metadata']}")
|
|
144
|
+
else:
|
|
145
|
+
log.info(f"[METADATA] No metadata found. args[0] keys: {list(args[0].keys())}, kwargs keys: {list(kwargs.keys())}")
|
|
146
|
+
# Fallback: Try to extract toolkit_name from description
|
|
147
|
+
description = args[0].get('description', '')
|
|
148
|
+
if description:
|
|
149
|
+
import re
|
|
150
|
+
# Try pattern 1: [Toolkit: name]
|
|
151
|
+
match = re.search(r'\[Toolkit:\s*([^\]]+)\]', description)
|
|
152
|
+
if not match:
|
|
153
|
+
# Try pattern 2: Toolkit: name at start or end
|
|
154
|
+
match = re.search(r'(?:^|\n)Toolkit:\s*([^\n]+)', description)
|
|
155
|
+
if match:
|
|
156
|
+
toolkit_name = match.group(1).strip()
|
|
157
|
+
tool_meta['metadata'] = {'toolkit_name': toolkit_name}
|
|
158
|
+
log.info(f"[METADATA] Extracted toolkit_name from description: {toolkit_name}")
|
|
159
|
+
|
|
91
160
|
payload = {
|
|
92
161
|
"tool_name": tool_name,
|
|
93
162
|
"tool_run_id": tool_run_id,
|
|
94
|
-
"tool_meta":
|
|
163
|
+
"tool_meta": tool_meta,
|
|
95
164
|
"tool_inputs": kwargs.get('inputs')
|
|
96
165
|
}
|
|
97
166
|
payload = json.loads(json.dumps(payload, ensure_ascii=False, default=lambda o: str(o)))
|
|
98
|
-
|
|
99
|
-
|
|
167
|
+
|
|
168
|
+
status_widget = self._safe_streamlit_call(
|
|
169
|
+
self.st.status,
|
|
170
|
+
f"Running {tool_name}...",
|
|
171
|
+
expanded=True
|
|
172
|
+
)
|
|
173
|
+
if status_widget:
|
|
174
|
+
self.callback_state[tool_run_id] = status_widget
|
|
175
|
+
self._safe_streamlit_call(status_widget.write, f"Tool inputs: {kwargs.get('inputs')}")
|
|
100
176
|
|
|
101
177
|
def on_tool_end(self, *args, run_id: UUID, **kwargs):
|
|
102
178
|
""" Callback """
|
|
@@ -104,11 +180,16 @@ class AlitaStreamlitCallback(BaseCallbackHandler):
|
|
|
104
180
|
log.info("on_tool_end(%s, %s)", args, kwargs)
|
|
105
181
|
tool_run_id = str(run_id)
|
|
106
182
|
tool_output = args[0]
|
|
107
|
-
if self.callback_state
|
|
108
|
-
self.callback_state[tool_run_id]
|
|
109
|
-
self.
|
|
183
|
+
if self.callback_state.get(tool_run_id):
|
|
184
|
+
status_widget = self.callback_state[tool_run_id]
|
|
185
|
+
self._safe_streamlit_call(status_widget.write, f"Tool output: {tool_output}")
|
|
186
|
+
self._safe_streamlit_call(
|
|
187
|
+
status_widget.update,
|
|
188
|
+
label=f"Completed {kwargs.get('name')}",
|
|
189
|
+
state="complete",
|
|
190
|
+
expanded=False
|
|
191
|
+
)
|
|
110
192
|
self.callback_state.pop(tool_run_id, None)
|
|
111
|
-
del self.callback_state[run_id]
|
|
112
193
|
|
|
113
194
|
def on_tool_error(self, *args, run_id: UUID, **kwargs):
|
|
114
195
|
""" Callback """
|
|
@@ -116,9 +197,19 @@ class AlitaStreamlitCallback(BaseCallbackHandler):
|
|
|
116
197
|
log.info("on_tool_error(%s, %s)", args, kwargs)
|
|
117
198
|
tool_run_id = str(run_id)
|
|
118
199
|
tool_exception = args[0]
|
|
119
|
-
self.callback_state
|
|
120
|
-
|
|
121
|
-
|
|
200
|
+
if self.callback_state.get(tool_run_id):
|
|
201
|
+
status_widget = self.callback_state[tool_run_id]
|
|
202
|
+
self._safe_streamlit_call(
|
|
203
|
+
status_widget.write,
|
|
204
|
+
f"{traceback.format_exception(tool_exception)}"
|
|
205
|
+
)
|
|
206
|
+
self._safe_streamlit_call(
|
|
207
|
+
status_widget.update,
|
|
208
|
+
label=f"Error {kwargs.get('name')}",
|
|
209
|
+
state="error",
|
|
210
|
+
expanded=False
|
|
211
|
+
)
|
|
212
|
+
self.callback_state.pop(tool_run_id, None)
|
|
122
213
|
|
|
123
214
|
#
|
|
124
215
|
# Agent
|
|
@@ -156,8 +247,14 @@ class AlitaStreamlitCallback(BaseCallbackHandler):
|
|
|
156
247
|
self.current_model_name = metadata.get('ls_model_name', self.current_model_name)
|
|
157
248
|
llm_run_id = str(run_id)
|
|
158
249
|
|
|
159
|
-
|
|
160
|
-
|
|
250
|
+
status_widget = self._safe_streamlit_call(
|
|
251
|
+
self.st.status,
|
|
252
|
+
f"Running LLM ...",
|
|
253
|
+
expanded=True
|
|
254
|
+
)
|
|
255
|
+
if status_widget:
|
|
256
|
+
self.callback_state[llm_run_id] = status_widget
|
|
257
|
+
self._safe_streamlit_call(status_widget.write, f"LLM inputs: {messages}")
|
|
161
258
|
|
|
162
259
|
def on_llm_start(self, *args, **kwargs):
|
|
163
260
|
""" Callback """
|
|
@@ -178,16 +275,27 @@ class AlitaStreamlitCallback(BaseCallbackHandler):
|
|
|
178
275
|
content = None
|
|
179
276
|
if chunk:
|
|
180
277
|
content = chunk.text
|
|
181
|
-
|
|
278
|
+
|
|
279
|
+
llm_run_id = str(run_id)
|
|
280
|
+
if self.callback_state.get(llm_run_id):
|
|
281
|
+
status_widget = self.callback_state[llm_run_id]
|
|
282
|
+
self._safe_streamlit_call(status_widget.write, content)
|
|
182
283
|
|
|
183
284
|
def on_llm_error(self, *args, run_id: UUID, **kwargs):
|
|
184
285
|
""" Callback """
|
|
185
286
|
if self.debug:
|
|
186
287
|
log.error("on_llm_error(%s, %s)", args, kwargs)
|
|
187
288
|
llm_run_id = str(run_id)
|
|
188
|
-
self.callback_state
|
|
189
|
-
|
|
190
|
-
|
|
289
|
+
if self.callback_state.get(llm_run_id):
|
|
290
|
+
status_widget = self.callback_state[llm_run_id]
|
|
291
|
+
self._safe_streamlit_call(status_widget.write, f"on_llm_error({args}, {kwargs})")
|
|
292
|
+
self._safe_streamlit_call(
|
|
293
|
+
status_widget.update,
|
|
294
|
+
label=f"Error {kwargs.get('name')}",
|
|
295
|
+
state="error",
|
|
296
|
+
expanded=False
|
|
297
|
+
)
|
|
298
|
+
self.callback_state.pop(llm_run_id, None)
|
|
191
299
|
#
|
|
192
300
|
# exception = args[0]
|
|
193
301
|
# FIXME: should we emit an error here too?
|
|
@@ -205,5 +313,12 @@ class AlitaStreamlitCallback(BaseCallbackHandler):
|
|
|
205
313
|
if self.debug:
|
|
206
314
|
log.debug("on_llm_end(%s, %s)", response, kwargs)
|
|
207
315
|
llm_run_id = str(run_id)
|
|
208
|
-
self.callback_state
|
|
209
|
-
|
|
316
|
+
if self.callback_state.get(llm_run_id):
|
|
317
|
+
status_widget = self.callback_state[llm_run_id]
|
|
318
|
+
self._safe_streamlit_call(
|
|
319
|
+
status_widget.update,
|
|
320
|
+
label=f"Completed LLM call",
|
|
321
|
+
state="complete",
|
|
322
|
+
expanded=False
|
|
323
|
+
)
|
|
324
|
+
self.callback_state.pop(llm_run_id, None)
|