alita-sdk 0.3.449__py3-none-any.whl → 0.3.465__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of alita-sdk might be problematic. Click here for more details.
- alita_sdk/cli/__init__.py +10 -0
- alita_sdk/cli/__main__.py +17 -0
- alita_sdk/cli/agent/__init__.py +0 -0
- alita_sdk/cli/agent/default.py +176 -0
- alita_sdk/cli/agent_executor.py +155 -0
- alita_sdk/cli/agent_loader.py +197 -0
- alita_sdk/cli/agent_ui.py +218 -0
- alita_sdk/cli/agents.py +1911 -0
- alita_sdk/cli/callbacks.py +576 -0
- alita_sdk/cli/cli.py +159 -0
- alita_sdk/cli/config.py +164 -0
- alita_sdk/cli/formatting.py +182 -0
- alita_sdk/cli/input_handler.py +256 -0
- alita_sdk/cli/mcp_loader.py +315 -0
- alita_sdk/cli/toolkit.py +330 -0
- alita_sdk/cli/toolkit_loader.py +55 -0
- alita_sdk/cli/tools/__init__.py +36 -0
- alita_sdk/cli/tools/approval.py +224 -0
- alita_sdk/cli/tools/filesystem.py +905 -0
- alita_sdk/cli/tools/planning.py +403 -0
- alita_sdk/cli/tools/terminal.py +280 -0
- alita_sdk/runtime/clients/client.py +16 -1
- alita_sdk/runtime/langchain/constants.py +2 -1
- alita_sdk/runtime/langchain/langraph_agent.py +74 -20
- alita_sdk/runtime/langchain/utils.py +20 -4
- alita_sdk/runtime/toolkits/artifact.py +5 -6
- alita_sdk/runtime/toolkits/mcp.py +5 -2
- alita_sdk/runtime/toolkits/tools.py +1 -0
- alita_sdk/runtime/tools/function.py +19 -6
- alita_sdk/runtime/tools/llm.py +65 -7
- alita_sdk/runtime/tools/vectorstore_base.py +17 -2
- alita_sdk/runtime/utils/mcp_sse_client.py +64 -6
- alita_sdk/tools/ado/repos/__init__.py +1 -0
- alita_sdk/tools/ado/test_plan/__init__.py +1 -1
- alita_sdk/tools/ado/wiki/__init__.py +1 -5
- alita_sdk/tools/ado/work_item/__init__.py +1 -5
- alita_sdk/tools/base_indexer_toolkit.py +64 -8
- alita_sdk/tools/bitbucket/__init__.py +1 -0
- alita_sdk/tools/code/sonar/__init__.py +1 -1
- alita_sdk/tools/confluence/__init__.py +2 -2
- alita_sdk/tools/github/__init__.py +2 -2
- alita_sdk/tools/gitlab/__init__.py +2 -1
- alita_sdk/tools/gitlab_org/__init__.py +1 -2
- alita_sdk/tools/google_places/__init__.py +2 -1
- alita_sdk/tools/jira/__init__.py +1 -0
- alita_sdk/tools/memory/__init__.py +1 -1
- alita_sdk/tools/pandas/__init__.py +1 -1
- alita_sdk/tools/postman/__init__.py +2 -1
- alita_sdk/tools/pptx/__init__.py +2 -2
- alita_sdk/tools/qtest/__init__.py +3 -3
- alita_sdk/tools/qtest/api_wrapper.py +1235 -51
- alita_sdk/tools/rally/__init__.py +1 -2
- alita_sdk/tools/report_portal/__init__.py +1 -0
- alita_sdk/tools/salesforce/__init__.py +1 -0
- alita_sdk/tools/servicenow/__init__.py +2 -3
- alita_sdk/tools/sharepoint/__init__.py +1 -0
- alita_sdk/tools/sharepoint/api_wrapper.py +22 -2
- alita_sdk/tools/sharepoint/authorization_helper.py +17 -1
- alita_sdk/tools/slack/__init__.py +1 -0
- alita_sdk/tools/sql/__init__.py +2 -1
- alita_sdk/tools/testio/__init__.py +1 -0
- alita_sdk/tools/testrail/__init__.py +1 -3
- alita_sdk/tools/xray/__init__.py +2 -1
- alita_sdk/tools/zephyr/__init__.py +2 -1
- alita_sdk/tools/zephyr_enterprise/__init__.py +1 -0
- alita_sdk/tools/zephyr_essential/__init__.py +1 -0
- alita_sdk/tools/zephyr_scale/__init__.py +1 -0
- alita_sdk/tools/zephyr_squad/__init__.py +1 -0
- {alita_sdk-0.3.449.dist-info → alita_sdk-0.3.465.dist-info}/METADATA +145 -2
- {alita_sdk-0.3.449.dist-info → alita_sdk-0.3.465.dist-info}/RECORD +74 -52
- alita_sdk-0.3.465.dist-info/entry_points.txt +2 -0
- {alita_sdk-0.3.449.dist-info → alita_sdk-0.3.465.dist-info}/WHEEL +0 -0
- {alita_sdk-0.3.449.dist-info → alita_sdk-0.3.465.dist-info}/licenses/LICENSE +0 -0
- {alita_sdk-0.3.449.dist-info → alita_sdk-0.3.465.dist-info}/top_level.txt +0 -0
|
@@ -68,6 +68,7 @@ class AlitaClient:
|
|
|
68
68
|
self.bucket_url = f"{self.base_url}{self.api_path}/artifacts/buckets/{self.project_id}"
|
|
69
69
|
self.configurations_url = f'{self.base_url}{self.api_path}/integrations/integrations/default/{self.project_id}?section=configurations&unsecret=true'
|
|
70
70
|
self.ai_section_url = f'{self.base_url}{self.api_path}/integrations/integrations/default/{self.project_id}?section=ai'
|
|
71
|
+
self.models_url = f'{self.base_url}{self.api_path}/configurations/models/{self.project_id}?include_shared=true'
|
|
71
72
|
self.image_generation_url = f"{self.base_url}{self.llm_path}/images/generations"
|
|
72
73
|
self.configurations: list = configurations or []
|
|
73
74
|
self.model_timeout = kwargs.get('model_timeout', 120)
|
|
@@ -175,6 +176,20 @@ class AlitaClient:
|
|
|
175
176
|
return resp.json()
|
|
176
177
|
return []
|
|
177
178
|
|
|
179
|
+
def get_available_models(self):
|
|
180
|
+
"""Get list of available models from the configurations API.
|
|
181
|
+
|
|
182
|
+
Returns:
|
|
183
|
+
List of model dictionaries with 'name' and other properties,
|
|
184
|
+
or empty list if request fails.
|
|
185
|
+
"""
|
|
186
|
+
resp = requests.get(self.models_url, headers=self.headers, verify=False)
|
|
187
|
+
if resp.ok:
|
|
188
|
+
data = resp.json()
|
|
189
|
+
# API returns {"items": [...], ...}
|
|
190
|
+
return data.get('items', [])
|
|
191
|
+
return []
|
|
192
|
+
|
|
178
193
|
def get_embeddings(self, embedding_model: str) -> OpenAIEmbeddings:
|
|
179
194
|
"""
|
|
180
195
|
Get an instance of OpenAIEmbeddings configured with the project ID and auth token.
|
|
@@ -565,7 +580,7 @@ class AlitaClient:
|
|
|
565
580
|
monitoring_meta = tasknode_task.meta.get("monitoring", {})
|
|
566
581
|
return monitoring_meta["user_id"]
|
|
567
582
|
except Exception as e:
|
|
568
|
-
logger.
|
|
583
|
+
logger.debug(f"Error: Could not determine user ID for MCP tool: {e}")
|
|
569
584
|
return None
|
|
570
585
|
|
|
571
586
|
def predict_agent(self, llm: ChatOpenAI, instructions: str = "You are a helpful assistant.",
|
|
@@ -19,7 +19,7 @@ from langgraph.managed.base import is_managed_value
|
|
|
19
19
|
from langgraph.prebuilt import InjectedStore
|
|
20
20
|
from langgraph.store.base import BaseStore
|
|
21
21
|
|
|
22
|
-
from .constants import PRINTER_NODE_RS, PRINTER
|
|
22
|
+
from .constants import PRINTER_NODE_RS, PRINTER, PRINTER_COMPLETED_STATE
|
|
23
23
|
from .mixedAgentRenderes import convert_message_to_json
|
|
24
24
|
from .utils import create_state, propagate_the_input_mapping, safe_format
|
|
25
25
|
from ..tools.function import FunctionTool
|
|
@@ -244,11 +244,19 @@ class PrinterNode(Runnable):
|
|
|
244
244
|
result = {}
|
|
245
245
|
logger.debug(f"Initial text pattern: {self.input_mapping}")
|
|
246
246
|
mapping = propagate_the_input_mapping(self.input_mapping, [], state)
|
|
247
|
+
# for printer node we expect that all the lists will be joined into strings already
|
|
248
|
+
# Join any lists that haven't been converted yet
|
|
249
|
+
for key, value in mapping.items():
|
|
250
|
+
if isinstance(value, list):
|
|
251
|
+
mapping[key] = ', '.join(str(item) for item in value)
|
|
247
252
|
if mapping.get(PRINTER) is None:
|
|
248
253
|
raise ToolException(f"PrinterNode requires '{PRINTER}' field in input mapping")
|
|
249
254
|
formatted_output = mapping[PRINTER]
|
|
250
255
|
# add info label to the printer's output
|
|
251
|
-
if formatted_output:
|
|
256
|
+
if not formatted_output == PRINTER_COMPLETED_STATE:
|
|
257
|
+
# convert formatted output to string if it's not
|
|
258
|
+
if not isinstance(formatted_output, str):
|
|
259
|
+
formatted_output = str(formatted_output)
|
|
252
260
|
formatted_output += f"\n\n-----\n*How to proceed?*\n* *to resume the pipeline - type anything...*"
|
|
253
261
|
logger.debug(f"Formatted output: {formatted_output}")
|
|
254
262
|
result[PRINTER_NODE_RS] = formatted_output
|
|
@@ -475,10 +483,14 @@ def create_graph(
|
|
|
475
483
|
if toolkit_name:
|
|
476
484
|
tool_name = f"{clean_string(toolkit_name)}{TOOLKIT_SPLITTER}{tool_name}"
|
|
477
485
|
logger.info(f"Node: {node_id} : {node_type} - {tool_name}")
|
|
478
|
-
if node_type in ['function', 'tool', 'loop', 'loop_from_tool', 'indexer', 'subgraph', 'pipeline', 'agent']:
|
|
486
|
+
if node_type in ['function', 'toolkit', 'mcp', 'tool', 'loop', 'loop_from_tool', 'indexer', 'subgraph', 'pipeline', 'agent']:
|
|
487
|
+
if node_type == 'mcp' and tool_name not in [tool.name for tool in tools]:
|
|
488
|
+
# MCP is not connected and node cannot be added
|
|
489
|
+
raise ToolException(f"MCP tool '{tool_name}' not found in the provided tools. "
|
|
490
|
+
f"Make sure it is connected properly. Available tools: {[tool.name for tool in tools]}")
|
|
479
491
|
for tool in tools:
|
|
480
492
|
if tool.name == tool_name:
|
|
481
|
-
if node_type
|
|
493
|
+
if node_type in ['function', 'toolkit', 'mcp']:
|
|
482
494
|
lg_builder.add_node(node_id, FunctionTool(
|
|
483
495
|
tool=tool, name=node_id, return_type='dict',
|
|
484
496
|
output_variables=node.get('output', []),
|
|
@@ -643,6 +655,7 @@ def create_graph(
|
|
|
643
655
|
default_output=node.get('default_output', 'END')
|
|
644
656
|
)
|
|
645
657
|
)
|
|
658
|
+
continue
|
|
646
659
|
elif node_type == 'state_modifier':
|
|
647
660
|
lg_builder.add_node(node_id, StateModifierNode(
|
|
648
661
|
template=node.get('template', ''),
|
|
@@ -661,9 +674,9 @@ def create_graph(
|
|
|
661
674
|
# reset printer output variable to avoid carrying over
|
|
662
675
|
reset_node_id = f"{node_id}_reset"
|
|
663
676
|
lg_builder.add_node(reset_node_id, PrinterNode(
|
|
664
|
-
input_mapping={'printer': {'type': 'fixed', 'value':
|
|
677
|
+
input_mapping={'printer': {'type': 'fixed', 'value': PRINTER_COMPLETED_STATE}}
|
|
665
678
|
))
|
|
666
|
-
lg_builder.
|
|
679
|
+
lg_builder.add_conditional_edges(node_id, TransitionalEdge(reset_node_id))
|
|
667
680
|
lg_builder.add_conditional_edges(reset_node_id, TransitionalEdge(clean_string(node['transition'])))
|
|
668
681
|
continue
|
|
669
682
|
if node.get('transition'):
|
|
@@ -814,35 +827,63 @@ class LangGraphAgentRunnable(CompiledStateGraph):
|
|
|
814
827
|
input['messages'] = [convert_dict_to_message(msg) for msg in chat_history]
|
|
815
828
|
|
|
816
829
|
# handler for LLM node: if no input (Chat perspective), then take last human message
|
|
830
|
+
# Track if input came from messages to handle content extraction properly
|
|
831
|
+
input_from_messages = False
|
|
817
832
|
if not input.get('input'):
|
|
818
833
|
if input.get('messages'):
|
|
819
834
|
input['input'] = [next((msg for msg in reversed(input['messages']) if isinstance(msg, HumanMessage)),
|
|
820
|
-
|
|
835
|
+
None)]
|
|
836
|
+
if input['input'] is not None:
|
|
837
|
+
input_from_messages = True
|
|
821
838
|
|
|
822
839
|
# Append current input to existing messages instead of overwriting
|
|
823
840
|
if input.get('input'):
|
|
824
841
|
if isinstance(input['input'], str):
|
|
825
842
|
current_message = input['input']
|
|
826
843
|
else:
|
|
844
|
+
# input can be a list of messages or a single message object
|
|
827
845
|
current_message = input.get('input')[-1]
|
|
828
846
|
|
|
829
847
|
# TODO: add handler after we add 2+ inputs (filterByType, etc.)
|
|
830
848
|
if isinstance(current_message, HumanMessage):
|
|
831
849
|
current_content = current_message.content
|
|
832
850
|
if isinstance(current_content, list):
|
|
833
|
-
|
|
834
|
-
|
|
835
|
-
|
|
836
|
-
|
|
837
|
-
|
|
838
|
-
|
|
839
|
-
|
|
840
|
-
|
|
851
|
+
# Extract text parts and keep non-text parts (images, etc.)
|
|
852
|
+
text_contents = []
|
|
853
|
+
non_text_parts = []
|
|
854
|
+
|
|
855
|
+
for item in current_content:
|
|
856
|
+
if isinstance(item, dict) and item.get('type') == 'text':
|
|
857
|
+
text_contents.append(item['text'])
|
|
858
|
+
elif isinstance(item, str):
|
|
859
|
+
text_contents.append(item)
|
|
860
|
+
else:
|
|
861
|
+
# Keep image_url and other non-text content
|
|
862
|
+
non_text_parts.append(item)
|
|
863
|
+
|
|
864
|
+
# Set input to the joined text
|
|
865
|
+
input['input'] = ". ".join(text_contents) if text_contents else ""
|
|
866
|
+
|
|
867
|
+
# If this message came from input['messages'], update or remove it
|
|
868
|
+
if input_from_messages:
|
|
869
|
+
if non_text_parts:
|
|
870
|
+
# Keep the message but only with non-text content (images, etc.)
|
|
871
|
+
current_message.content = non_text_parts
|
|
872
|
+
else:
|
|
873
|
+
# All content was text, remove this message from the list
|
|
874
|
+
input['messages'] = [msg for msg in input['messages'] if msg is not current_message]
|
|
875
|
+
|
|
841
876
|
elif isinstance(current_content, str):
|
|
842
877
|
# on regenerate case
|
|
843
878
|
input['input'] = current_content
|
|
879
|
+
# If from messages and all content is text, remove the message
|
|
880
|
+
if input_from_messages:
|
|
881
|
+
input['messages'] = [msg for msg in input['messages'] if msg is not current_message]
|
|
844
882
|
else:
|
|
845
883
|
input['input'] = str(current_content)
|
|
884
|
+
# If from messages, remove since we extracted the content
|
|
885
|
+
if input_from_messages:
|
|
886
|
+
input['messages'] = [msg for msg in input['messages'] if msg is not current_message]
|
|
846
887
|
elif isinstance(current_message, str):
|
|
847
888
|
input['input'] = current_message
|
|
848
889
|
else:
|
|
@@ -852,17 +893,30 @@ class LangGraphAgentRunnable(CompiledStateGraph):
|
|
|
852
893
|
input['messages'] = [convert_dict_to_message(msg) for msg in input['messages']]
|
|
853
894
|
# Append to existing messages
|
|
854
895
|
# input['messages'].append(current_message)
|
|
855
|
-
else:
|
|
856
|
-
#
|
|
857
|
-
input['messages'] = [current_message]
|
|
896
|
+
# else:
|
|
897
|
+
# NOTE: Commented out to prevent duplicates with input['input']
|
|
898
|
+
# input['messages'] = [current_message]
|
|
899
|
+
|
|
900
|
+
# Validate that input is not empty after all processing
|
|
901
|
+
if not input.get('input'):
|
|
902
|
+
raise RuntimeError(
|
|
903
|
+
"Empty input after processing. Cannot send empty string to LLM. "
|
|
904
|
+
"This likely means the message contained only non-text content "
|
|
905
|
+
"with no accompanying text."
|
|
906
|
+
)
|
|
907
|
+
|
|
858
908
|
logging.info(f"Input: {thread_id} - {input}")
|
|
859
909
|
if self.checkpointer and self.checkpointer.get_tuple(config):
|
|
860
910
|
self.update_state(config, input)
|
|
861
|
-
|
|
911
|
+
if config.pop("should_continue", False):
|
|
912
|
+
invoke_input = input
|
|
913
|
+
else:
|
|
914
|
+
invoke_input = None
|
|
915
|
+
result = super().invoke(invoke_input, config=config, *args, **kwargs)
|
|
862
916
|
else:
|
|
863
917
|
result = super().invoke(input, config=config, *args, **kwargs)
|
|
864
918
|
try:
|
|
865
|
-
if
|
|
919
|
+
if result.get(PRINTER_NODE_RS) == PRINTER_COMPLETED_STATE:
|
|
866
920
|
output = next((msg.content for msg in reversed(result['messages']) if not isinstance(msg, HumanMessage)),
|
|
867
921
|
result['messages'][-1].content)
|
|
868
922
|
else:
|
|
@@ -2,7 +2,7 @@ import builtins
|
|
|
2
2
|
import json
|
|
3
3
|
import logging
|
|
4
4
|
import re
|
|
5
|
-
from pydantic import create_model, Field
|
|
5
|
+
from pydantic import create_model, Field, Json
|
|
6
6
|
from typing import Tuple, TypedDict, Any, Optional, Annotated
|
|
7
7
|
from langchain_core.messages import AnyMessage
|
|
8
8
|
from langgraph.graph import add_messages
|
|
@@ -131,7 +131,7 @@ def parse_type(type_str):
|
|
|
131
131
|
|
|
132
132
|
|
|
133
133
|
def create_state(data: Optional[dict] = None):
|
|
134
|
-
state_dict = {'input': str, 'router_output': str,
|
|
134
|
+
state_dict = {'input': str, 'messages': 'list[str]', 'router_output': str,
|
|
135
135
|
ELITEA_RS: str, PRINTER_NODE_RS: str} # Always include router_output
|
|
136
136
|
types_dict = {}
|
|
137
137
|
if not data:
|
|
@@ -208,5 +208,21 @@ def safe_format(template, mapping):
|
|
|
208
208
|
def create_pydantic_model(model_name: str, variables: dict[str, dict]):
|
|
209
209
|
fields = {}
|
|
210
210
|
for var_name, var_data in variables.items():
|
|
211
|
-
fields[var_name] = (
|
|
212
|
-
return create_model(model_name, **fields)
|
|
211
|
+
fields[var_name] = (parse_pydantic_type(var_data['type']), Field(description=var_data.get('description', None)))
|
|
212
|
+
return create_model(model_name, **fields)
|
|
213
|
+
|
|
214
|
+
def parse_pydantic_type(type_name: str):
|
|
215
|
+
"""
|
|
216
|
+
Helper function to parse type names into Python types.
|
|
217
|
+
Extend this function to handle custom types like 'dict' -> Json[Any].
|
|
218
|
+
"""
|
|
219
|
+
type_mapping = {
|
|
220
|
+
'str': str,
|
|
221
|
+
'int': int,
|
|
222
|
+
'float': float,
|
|
223
|
+
'bool': bool,
|
|
224
|
+
'dict': Json[Any], # Map 'dict' to Pydantic's Json type
|
|
225
|
+
'list': list,
|
|
226
|
+
'any': Any
|
|
227
|
+
}
|
|
228
|
+
return type_mapping.get(type_name, Any)
|
|
@@ -23,11 +23,7 @@ class ArtifactToolkit(BaseToolkit):
|
|
|
23
23
|
# client = (Any, FieldInfo(description="Client object", required=True, autopopulate=True)),
|
|
24
24
|
bucket=(str, FieldInfo(
|
|
25
25
|
description="Bucket name",
|
|
26
|
-
pattern=r'^[a-z][a-z0-9-]*$'
|
|
27
|
-
json_schema_extra={
|
|
28
|
-
'toolkit_name': True,
|
|
29
|
-
'max_toolkit_length': ArtifactToolkit.toolkit_max_length
|
|
30
|
-
}
|
|
26
|
+
pattern=r'^[a-z][a-z0-9-]*$'
|
|
31
27
|
)),
|
|
32
28
|
selected_tools=(List[Literal[tuple(selected_tools)]], Field(default=[], json_schema_extra={'args_schemas': selected_tools})),
|
|
33
29
|
# indexer settings
|
|
@@ -37,7 +33,10 @@ class ArtifactToolkit(BaseToolkit):
|
|
|
37
33
|
embedding_model=(Optional[str], Field(default=None, description="Embedding configuration.",
|
|
38
34
|
json_schema_extra={'configuration_model': 'embedding'})),
|
|
39
35
|
|
|
40
|
-
__config__=ConfigDict(json_schema_extra={'metadata': {"label": "Artifact",
|
|
36
|
+
__config__=ConfigDict(json_schema_extra={'metadata': {"label": "Artifact",
|
|
37
|
+
"icon_url": None,
|
|
38
|
+
"max_length": ArtifactToolkit.toolkit_max_length
|
|
39
|
+
}})
|
|
41
40
|
)
|
|
42
41
|
|
|
43
42
|
@classmethod
|
|
@@ -498,9 +498,12 @@ class McpToolkit(BaseToolkit):
|
|
|
498
498
|
all_tools = []
|
|
499
499
|
session_id = connection_config.session_id
|
|
500
500
|
|
|
501
|
+
# Generate temporary session_id if not provided (for OAuth flow)
|
|
502
|
+
# The real session_id should come from frontend after OAuth completes
|
|
501
503
|
if not session_id:
|
|
502
|
-
|
|
503
|
-
|
|
504
|
+
import uuid
|
|
505
|
+
session_id = str(uuid.uuid4())
|
|
506
|
+
logger.info(f"[MCP SSE] Generated temporary session_id for OAuth: {session_id}")
|
|
504
507
|
|
|
505
508
|
logger.info(f"[MCP SSE] Discovering from {connection_config.url} with session {session_id}")
|
|
506
509
|
|
|
@@ -110,6 +110,7 @@ def get_tools(tools_list: list, alita_client, llm, memory_store: BaseStore = Non
|
|
|
110
110
|
toolkit_name=tool.get('toolkit_name', ''),
|
|
111
111
|
**tool['settings']).get_tools())
|
|
112
112
|
elif tool['type'] == 'mcp':
|
|
113
|
+
# remote mcp tool initialization with token injection
|
|
113
114
|
settings = dict(tool['settings'])
|
|
114
115
|
url = settings.get('url')
|
|
115
116
|
headers = settings.get('headers')
|
|
@@ -16,6 +16,18 @@ from ..langchain.utils import propagate_the_input_mapping
|
|
|
16
16
|
logger = logging.getLogger(__name__)
|
|
17
17
|
|
|
18
18
|
|
|
19
|
+
def replace_escaped_newlines(data):
|
|
20
|
+
"""
|
|
21
|
+
Replace \\n with \n in all string values recursively.
|
|
22
|
+
Required for sanitization of state variables in code node
|
|
23
|
+
"""
|
|
24
|
+
if isinstance(data, dict):
|
|
25
|
+
return {key: replace_escaped_newlines(value) for key, value in data.items()}
|
|
26
|
+
elif isinstance(data, str):
|
|
27
|
+
return data.replace('\\n', '\n')
|
|
28
|
+
else:
|
|
29
|
+
return data
|
|
30
|
+
|
|
19
31
|
class FunctionTool(BaseTool):
|
|
20
32
|
name: str = 'FunctionalTool'
|
|
21
33
|
description: str = 'This is direct call node for tools'
|
|
@@ -30,11 +42,13 @@ class FunctionTool(BaseTool):
|
|
|
30
42
|
def _prepare_pyodide_input(self, state: Union[str, dict, ToolCall]) -> str:
|
|
31
43
|
"""Prepare input for PyodideSandboxTool by injecting state into the code block."""
|
|
32
44
|
# add state into the code block here since it might be changed during the execution of the code
|
|
33
|
-
state_copy = deepcopy(state)
|
|
45
|
+
state_copy = replace_escaped_newlines(deepcopy(state))
|
|
34
46
|
|
|
35
47
|
del state_copy['messages'] # remove messages to avoid issues with pickling without langchain-core
|
|
36
48
|
# inject state into the code block as alita_state variable
|
|
37
|
-
|
|
49
|
+
state_json = json.dumps(state_copy, ensure_ascii=False)
|
|
50
|
+
pyodide_predata = f'#state dict\nimport json\nalita_state = json.loads({json.dumps(state_json)})\n'
|
|
51
|
+
|
|
38
52
|
return pyodide_predata
|
|
39
53
|
|
|
40
54
|
def _handle_pyodide_output(self, tool_result: Any) -> dict:
|
|
@@ -94,9 +108,7 @@ class FunctionTool(BaseTool):
|
|
|
94
108
|
# special handler for PyodideSandboxTool
|
|
95
109
|
if self._is_pyodide_tool():
|
|
96
110
|
code = func_args['code']
|
|
97
|
-
func_args['code'] =
|
|
98
|
-
# handle new lines in the code properly
|
|
99
|
-
.replace('\\n','\\\\n'))
|
|
111
|
+
func_args['code'] = f"{self._prepare_pyodide_input(state)}\n{code}"
|
|
100
112
|
try:
|
|
101
113
|
tool_result = self.tool.invoke(func_args, config, **kwargs)
|
|
102
114
|
dispatch_custom_event(
|
|
@@ -120,7 +132,8 @@ class FunctionTool(BaseTool):
|
|
|
120
132
|
messages_dict = {
|
|
121
133
|
"messages": [{
|
|
122
134
|
"role": "assistant",
|
|
123
|
-
"content": dumps(tool_result)
|
|
135
|
+
"content": dumps(tool_result)
|
|
136
|
+
if not isinstance(tool_result, ToolException) and not isinstance(tool_result, str)
|
|
124
137
|
else str(tool_result)
|
|
125
138
|
}]
|
|
126
139
|
}
|
alita_sdk/runtime/tools/llm.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
import asyncio
|
|
1
2
|
import logging
|
|
2
3
|
from traceback import format_exc
|
|
3
4
|
from typing import Any, Optional, List, Union
|
|
@@ -132,7 +133,9 @@ class LLMNode(BaseTool):
|
|
|
132
133
|
struct_model = create_pydantic_model(f"LLMOutput", struct_params)
|
|
133
134
|
completion = llm_client.invoke(messages, config=config)
|
|
134
135
|
if hasattr(completion, 'tool_calls') and completion.tool_calls:
|
|
135
|
-
new_messages, _ = self.
|
|
136
|
+
new_messages, _ = self._run_async_in_sync_context(
|
|
137
|
+
self.__perform_tool_calling(completion, messages, llm_client, config)
|
|
138
|
+
)
|
|
136
139
|
llm = self.__get_struct_output_model(llm_client, struct_model)
|
|
137
140
|
completion = llm.invoke(new_messages, config=config)
|
|
138
141
|
result = completion.model_dump()
|
|
@@ -155,7 +158,9 @@ class LLMNode(BaseTool):
|
|
|
155
158
|
# Handle both tool-calling and regular responses
|
|
156
159
|
if hasattr(completion, 'tool_calls') and completion.tool_calls:
|
|
157
160
|
# Handle iterative tool-calling and execution
|
|
158
|
-
new_messages, current_completion = self.
|
|
161
|
+
new_messages, current_completion = self._run_async_in_sync_context(
|
|
162
|
+
self.__perform_tool_calling(completion, messages, llm_client, config)
|
|
163
|
+
)
|
|
159
164
|
|
|
160
165
|
output_msgs = {"messages": new_messages}
|
|
161
166
|
if self.output_variables:
|
|
@@ -190,9 +195,53 @@ class LLMNode(BaseTool):
|
|
|
190
195
|
def _run(self, *args, **kwargs):
|
|
191
196
|
# Legacy support for old interface
|
|
192
197
|
return self.invoke(kwargs, **kwargs)
|
|
198
|
+
|
|
199
|
+
def _run_async_in_sync_context(self, coro):
|
|
200
|
+
"""Run async coroutine from sync context.
|
|
201
|
+
|
|
202
|
+
For MCP tools with persistent sessions, we reuse the same event loop
|
|
203
|
+
that was used to create the MCP client and sessions (set by CLI).
|
|
204
|
+
"""
|
|
205
|
+
try:
|
|
206
|
+
loop = asyncio.get_running_loop()
|
|
207
|
+
# Already in async context - run in thread with new loop
|
|
208
|
+
import threading
|
|
209
|
+
|
|
210
|
+
result_container = []
|
|
211
|
+
|
|
212
|
+
def run_in_thread():
|
|
213
|
+
new_loop = asyncio.new_event_loop()
|
|
214
|
+
asyncio.set_event_loop(new_loop)
|
|
215
|
+
try:
|
|
216
|
+
result_container.append(new_loop.run_until_complete(coro))
|
|
217
|
+
finally:
|
|
218
|
+
new_loop.close()
|
|
219
|
+
|
|
220
|
+
thread = threading.Thread(target=run_in_thread)
|
|
221
|
+
thread.start()
|
|
222
|
+
thread.join()
|
|
223
|
+
return result_container[0] if result_container else None
|
|
224
|
+
|
|
225
|
+
except RuntimeError:
|
|
226
|
+
# No event loop running - use/create persistent loop
|
|
227
|
+
# This loop is shared with MCP session creation for stateful tools
|
|
228
|
+
if not hasattr(self.__class__, '_persistent_loop') or \
|
|
229
|
+
self.__class__._persistent_loop is None or \
|
|
230
|
+
self.__class__._persistent_loop.is_closed():
|
|
231
|
+
self.__class__._persistent_loop = asyncio.new_event_loop()
|
|
232
|
+
logger.debug("Created persistent event loop for async tools")
|
|
233
|
+
|
|
234
|
+
loop = self.__class__._persistent_loop
|
|
235
|
+
asyncio.set_event_loop(loop)
|
|
236
|
+
return loop.run_until_complete(coro)
|
|
237
|
+
|
|
238
|
+
async def _arun(self, *args, **kwargs):
|
|
239
|
+
# Legacy async support
|
|
240
|
+
return self.invoke(kwargs, **kwargs)
|
|
193
241
|
|
|
194
|
-
def __perform_tool_calling(self, completion, messages, llm_client, config):
|
|
242
|
+
async def __perform_tool_calling(self, completion, messages, llm_client, config):
|
|
195
243
|
# Handle iterative tool-calling and execution
|
|
244
|
+
logger.info(f"__perform_tool_calling called with {len(completion.tool_calls) if hasattr(completion, 'tool_calls') else 0} tool calls")
|
|
196
245
|
new_messages = messages + [completion]
|
|
197
246
|
iteration = 0
|
|
198
247
|
|
|
@@ -230,9 +279,16 @@ class LLMNode(BaseTool):
|
|
|
230
279
|
if tool_to_execute:
|
|
231
280
|
try:
|
|
232
281
|
logger.info(f"Executing tool '{tool_name}' with args: {tool_args}")
|
|
233
|
-
|
|
234
|
-
#
|
|
235
|
-
tool_result =
|
|
282
|
+
|
|
283
|
+
# Try async invoke first (for MCP tools), fallback to sync
|
|
284
|
+
tool_result = None
|
|
285
|
+
try:
|
|
286
|
+
# Try async invocation first
|
|
287
|
+
tool_result = await tool_to_execute.ainvoke(tool_args, config=config)
|
|
288
|
+
except NotImplementedError:
|
|
289
|
+
# Tool doesn't support async, use sync invoke
|
|
290
|
+
logger.debug(f"Tool '{tool_name}' doesn't support async, using sync invoke")
|
|
291
|
+
tool_result = tool_to_execute.invoke(tool_args, config=config)
|
|
236
292
|
|
|
237
293
|
# Create tool message with result - preserve structured content
|
|
238
294
|
from langchain_core.messages import ToolMessage
|
|
@@ -256,7 +312,9 @@ class LLMNode(BaseTool):
|
|
|
256
312
|
new_messages.append(tool_message)
|
|
257
313
|
|
|
258
314
|
except Exception as e:
|
|
259
|
-
|
|
315
|
+
import traceback
|
|
316
|
+
error_details = traceback.format_exc()
|
|
317
|
+
logger.error(f"Error executing tool '{tool_name}': {e}\n{error_details}")
|
|
260
318
|
# Create error tool message
|
|
261
319
|
from langchain_core.messages import ToolMessage
|
|
262
320
|
tool_message = ToolMessage(
|
|
@@ -1,9 +1,9 @@
|
|
|
1
1
|
import json
|
|
2
|
-
import math
|
|
3
2
|
from collections import OrderedDict
|
|
4
3
|
from logging import getLogger
|
|
5
4
|
from typing import Any, Optional, List, Dict, Generator
|
|
6
5
|
|
|
6
|
+
import math
|
|
7
7
|
from langchain_core.documents import Document
|
|
8
8
|
from langchain_core.messages import HumanMessage
|
|
9
9
|
from langchain_core.tools import ToolException
|
|
@@ -12,7 +12,7 @@ from pydantic import BaseModel, model_validator, Field
|
|
|
12
12
|
|
|
13
13
|
from alita_sdk.tools.elitea_base import BaseToolApiWrapper
|
|
14
14
|
from alita_sdk.tools.vector_adapters.VectorStoreAdapter import VectorStoreAdapterFactory
|
|
15
|
-
from
|
|
15
|
+
from ...runtime.utils.utils import IndexerKeywords
|
|
16
16
|
|
|
17
17
|
logger = getLogger(__name__)
|
|
18
18
|
|
|
@@ -222,6 +222,21 @@ class VectorStoreWrapperBase(BaseToolApiWrapper):
|
|
|
222
222
|
raise RuntimeError(f"Multiple index_meta documents found: {index_metas}")
|
|
223
223
|
return index_metas[0] if index_metas else None
|
|
224
224
|
|
|
225
|
+
def get_indexed_count(self, index_name: str) -> int:
|
|
226
|
+
from sqlalchemy.orm import Session
|
|
227
|
+
from sqlalchemy import func, or_
|
|
228
|
+
|
|
229
|
+
with Session(self.vectorstore.session_maker.bind) as session:
|
|
230
|
+
return session.query(
|
|
231
|
+
self.vectorstore.EmbeddingStore.id,
|
|
232
|
+
).filter(
|
|
233
|
+
func.jsonb_extract_path_text(self.vectorstore.EmbeddingStore.cmetadata, 'collection') == index_name,
|
|
234
|
+
or_(
|
|
235
|
+
func.jsonb_extract_path_text(self.vectorstore.EmbeddingStore.cmetadata, 'type').is_(None),
|
|
236
|
+
func.jsonb_extract_path_text(self.vectorstore.EmbeddingStore.cmetadata, 'type') != IndexerKeywords.INDEX_META_TYPE.value
|
|
237
|
+
)
|
|
238
|
+
).count()
|
|
239
|
+
|
|
225
240
|
def _clean_collection(self, index_name: str = ''):
|
|
226
241
|
"""
|
|
227
242
|
Clean the vectorstore collection by deleting all indexed data.
|
|
@@ -65,6 +65,53 @@ class McpSseClient:
|
|
|
65
65
|
|
|
66
66
|
logger.info(f"[MCP SSE Client] Stream opened: status={self._stream_response.status}")
|
|
67
67
|
|
|
68
|
+
# Handle 401 Unauthorized - need OAuth
|
|
69
|
+
if self._stream_response.status == 401:
|
|
70
|
+
from ..utils.mcp_oauth import (
|
|
71
|
+
McpAuthorizationRequired,
|
|
72
|
+
canonical_resource,
|
|
73
|
+
extract_resource_metadata_url,
|
|
74
|
+
fetch_resource_metadata_async,
|
|
75
|
+
infer_authorization_servers_from_realm,
|
|
76
|
+
fetch_oauth_authorization_server_metadata
|
|
77
|
+
)
|
|
78
|
+
|
|
79
|
+
auth_header = self._stream_response.headers.get('WWW-Authenticate', '')
|
|
80
|
+
resource_metadata_url = extract_resource_metadata_url(auth_header, self.url)
|
|
81
|
+
|
|
82
|
+
metadata = None
|
|
83
|
+
if resource_metadata_url:
|
|
84
|
+
metadata = await fetch_resource_metadata_async(
|
|
85
|
+
resource_metadata_url,
|
|
86
|
+
session=self._stream_session,
|
|
87
|
+
timeout=30
|
|
88
|
+
)
|
|
89
|
+
|
|
90
|
+
# Infer authorization servers if not in metadata
|
|
91
|
+
if not metadata or not metadata.get('authorization_servers'):
|
|
92
|
+
inferred_servers = infer_authorization_servers_from_realm(auth_header, self.url)
|
|
93
|
+
if inferred_servers:
|
|
94
|
+
if not metadata:
|
|
95
|
+
metadata = {}
|
|
96
|
+
metadata['authorization_servers'] = inferred_servers
|
|
97
|
+
logger.info(f"[MCP SSE Client] Inferred authorization servers: {inferred_servers}")
|
|
98
|
+
|
|
99
|
+
# Fetch OAuth metadata
|
|
100
|
+
auth_server_metadata = fetch_oauth_authorization_server_metadata(inferred_servers[0], timeout=30)
|
|
101
|
+
if auth_server_metadata:
|
|
102
|
+
metadata['oauth_authorization_server'] = auth_server_metadata
|
|
103
|
+
logger.info(f"[MCP SSE Client] Fetched OAuth metadata")
|
|
104
|
+
|
|
105
|
+
raise McpAuthorizationRequired(
|
|
106
|
+
message=f"MCP server {self.url} requires OAuth authorization",
|
|
107
|
+
server_url=canonical_resource(self.url),
|
|
108
|
+
resource_metadata_url=resource_metadata_url,
|
|
109
|
+
www_authenticate=auth_header,
|
|
110
|
+
resource_metadata=metadata,
|
|
111
|
+
status=self._stream_response.status,
|
|
112
|
+
tool_name=self.url,
|
|
113
|
+
)
|
|
114
|
+
|
|
68
115
|
if self._stream_response.status != 200:
|
|
69
116
|
error_text = await self._stream_response.text()
|
|
70
117
|
raise Exception(f"Failed to open SSE stream: HTTP {self._stream_response.status}: {error_text}")
|
|
@@ -248,18 +295,29 @@ class McpSseClient:
|
|
|
248
295
|
"""Close the persistent SSE stream."""
|
|
249
296
|
logger.info(f"[MCP SSE Client] Closing connection...")
|
|
250
297
|
|
|
298
|
+
# Cancel background stream reader task
|
|
251
299
|
if self._stream_task and not self._stream_task.done():
|
|
252
300
|
self._stream_task.cancel()
|
|
253
301
|
try:
|
|
254
302
|
await self._stream_task
|
|
255
|
-
except asyncio.CancelledError:
|
|
256
|
-
|
|
303
|
+
except (asyncio.CancelledError, Exception) as e:
|
|
304
|
+
logger.debug(f"[MCP SSE Client] Stream task cleanup: {e}")
|
|
257
305
|
|
|
258
|
-
|
|
259
|
-
|
|
306
|
+
# Close response stream
|
|
307
|
+
if self._stream_response and not self._stream_response.closed:
|
|
308
|
+
try:
|
|
309
|
+
self._stream_response.close()
|
|
310
|
+
except Exception as e:
|
|
311
|
+
logger.debug(f"[MCP SSE Client] Response close error: {e}")
|
|
260
312
|
|
|
261
|
-
|
|
262
|
-
|
|
313
|
+
# Close session
|
|
314
|
+
if self._stream_session and not self._stream_session.closed:
|
|
315
|
+
try:
|
|
316
|
+
await self._stream_session.close()
|
|
317
|
+
# Give aiohttp time to cleanup
|
|
318
|
+
await asyncio.sleep(0.1)
|
|
319
|
+
except Exception as e:
|
|
320
|
+
logger.debug(f"[MCP SSE Client] Session close error: {e}")
|
|
263
321
|
|
|
264
322
|
logger.info(f"[MCP SSE Client] Connection closed")
|
|
265
323
|
|
|
@@ -27,7 +27,6 @@ class AzureDevOpsPlansToolkit(BaseToolkit):
|
|
|
27
27
|
AzureDevOpsPlansToolkit.toolkit_max_length = get_max_toolkit_length(selected_tools)
|
|
28
28
|
m = create_model(
|
|
29
29
|
name_alias,
|
|
30
|
-
name=(str, Field(description="Toolkit name", json_schema_extra={'toolkit_name': True, 'max_toolkit_length': AzureDevOpsPlansToolkit.toolkit_max_length})),
|
|
31
30
|
ado_configuration=(AdoConfiguration, Field(description="Ado configuration", json_schema_extra={'configuration_types': ['ado']})),
|
|
32
31
|
limit=(Optional[int], Field(description="ADO plans limit used for limitation of the list with results", default=5)),
|
|
33
32
|
# indexer settings
|
|
@@ -40,6 +39,7 @@ class AzureDevOpsPlansToolkit(BaseToolkit):
|
|
|
40
39
|
{
|
|
41
40
|
"label": "ADO plans",
|
|
42
41
|
"icon_url": "ado-plans.svg",
|
|
42
|
+
"max_length": AzureDevOpsPlansToolkit.toolkit_max_length,
|
|
43
43
|
"categories": ["test management"],
|
|
44
44
|
"extra_categories": ["test case management", "qa"],
|
|
45
45
|
"sections": {
|