alita-sdk 0.3.379__py3-none-any.whl → 0.3.462__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of alita-sdk might be problematic. Click here for more details.
- alita_sdk/cli/__init__.py +10 -0
- alita_sdk/cli/__main__.py +17 -0
- alita_sdk/cli/agent_executor.py +144 -0
- alita_sdk/cli/agent_loader.py +197 -0
- alita_sdk/cli/agent_ui.py +166 -0
- alita_sdk/cli/agents.py +1069 -0
- alita_sdk/cli/callbacks.py +576 -0
- alita_sdk/cli/cli.py +159 -0
- alita_sdk/cli/config.py +153 -0
- alita_sdk/cli/formatting.py +182 -0
- alita_sdk/cli/mcp_loader.py +315 -0
- alita_sdk/cli/toolkit.py +330 -0
- alita_sdk/cli/toolkit_loader.py +55 -0
- alita_sdk/cli/tools/__init__.py +9 -0
- alita_sdk/cli/tools/filesystem.py +905 -0
- alita_sdk/configurations/bitbucket.py +95 -0
- alita_sdk/configurations/confluence.py +96 -1
- alita_sdk/configurations/gitlab.py +79 -0
- alita_sdk/configurations/jira.py +103 -0
- alita_sdk/configurations/testrail.py +88 -0
- alita_sdk/configurations/xray.py +93 -0
- alita_sdk/configurations/zephyr_enterprise.py +93 -0
- alita_sdk/configurations/zephyr_essential.py +75 -0
- alita_sdk/runtime/clients/client.py +47 -10
- alita_sdk/runtime/clients/mcp_discovery.py +342 -0
- alita_sdk/runtime/clients/mcp_manager.py +262 -0
- alita_sdk/runtime/clients/sandbox_client.py +8 -0
- alita_sdk/runtime/langchain/assistant.py +37 -16
- alita_sdk/runtime/langchain/constants.py +6 -1
- alita_sdk/runtime/langchain/document_loaders/AlitaDocxMammothLoader.py +315 -3
- alita_sdk/runtime/langchain/document_loaders/AlitaJSONLoader.py +4 -1
- alita_sdk/runtime/langchain/document_loaders/constants.py +28 -12
- alita_sdk/runtime/langchain/langraph_agent.py +146 -31
- alita_sdk/runtime/langchain/utils.py +39 -7
- alita_sdk/runtime/models/mcp_models.py +61 -0
- alita_sdk/runtime/toolkits/__init__.py +24 -0
- alita_sdk/runtime/toolkits/application.py +8 -1
- alita_sdk/runtime/toolkits/artifact.py +5 -6
- alita_sdk/runtime/toolkits/mcp.py +895 -0
- alita_sdk/runtime/toolkits/tools.py +137 -56
- alita_sdk/runtime/tools/__init__.py +7 -2
- alita_sdk/runtime/tools/application.py +7 -0
- alita_sdk/runtime/tools/function.py +29 -25
- alita_sdk/runtime/tools/graph.py +10 -4
- alita_sdk/runtime/tools/image_generation.py +104 -8
- alita_sdk/runtime/tools/llm.py +204 -114
- alita_sdk/runtime/tools/mcp_inspect_tool.py +284 -0
- alita_sdk/runtime/tools/mcp_remote_tool.py +166 -0
- alita_sdk/runtime/tools/mcp_server_tool.py +3 -1
- alita_sdk/runtime/tools/sandbox.py +57 -43
- alita_sdk/runtime/tools/vectorstore.py +2 -1
- alita_sdk/runtime/tools/vectorstore_base.py +19 -3
- alita_sdk/runtime/utils/mcp_oauth.py +164 -0
- alita_sdk/runtime/utils/mcp_sse_client.py +405 -0
- alita_sdk/runtime/utils/streamlit.py +34 -3
- alita_sdk/runtime/utils/toolkit_utils.py +14 -4
- alita_sdk/tools/__init__.py +46 -31
- alita_sdk/tools/ado/repos/__init__.py +1 -0
- alita_sdk/tools/ado/test_plan/__init__.py +1 -1
- alita_sdk/tools/ado/wiki/__init__.py +1 -5
- alita_sdk/tools/ado/work_item/__init__.py +1 -5
- alita_sdk/tools/ado/work_item/ado_wrapper.py +17 -8
- alita_sdk/tools/base_indexer_toolkit.py +105 -43
- alita_sdk/tools/bitbucket/__init__.py +1 -0
- alita_sdk/tools/chunkers/sematic/proposal_chunker.py +1 -1
- alita_sdk/tools/code/sonar/__init__.py +1 -1
- alita_sdk/tools/code_indexer_toolkit.py +13 -3
- alita_sdk/tools/confluence/__init__.py +2 -2
- alita_sdk/tools/confluence/api_wrapper.py +29 -7
- alita_sdk/tools/confluence/loader.py +10 -0
- alita_sdk/tools/github/__init__.py +2 -2
- alita_sdk/tools/gitlab/__init__.py +2 -1
- alita_sdk/tools/gitlab/api_wrapper.py +11 -7
- alita_sdk/tools/gitlab_org/__init__.py +1 -2
- alita_sdk/tools/google_places/__init__.py +2 -1
- alita_sdk/tools/jira/__init__.py +1 -0
- alita_sdk/tools/jira/api_wrapper.py +1 -1
- alita_sdk/tools/memory/__init__.py +1 -1
- alita_sdk/tools/openapi/__init__.py +10 -1
- alita_sdk/tools/pandas/__init__.py +1 -1
- alita_sdk/tools/postman/__init__.py +2 -1
- alita_sdk/tools/pptx/__init__.py +2 -2
- alita_sdk/tools/qtest/__init__.py +3 -3
- alita_sdk/tools/qtest/api_wrapper.py +1708 -76
- alita_sdk/tools/rally/__init__.py +1 -2
- alita_sdk/tools/report_portal/__init__.py +1 -0
- alita_sdk/tools/salesforce/__init__.py +1 -0
- alita_sdk/tools/servicenow/__init__.py +2 -3
- alita_sdk/tools/sharepoint/__init__.py +1 -0
- alita_sdk/tools/sharepoint/api_wrapper.py +125 -34
- alita_sdk/tools/sharepoint/authorization_helper.py +191 -1
- alita_sdk/tools/sharepoint/utils.py +8 -2
- alita_sdk/tools/slack/__init__.py +1 -0
- alita_sdk/tools/sql/__init__.py +2 -1
- alita_sdk/tools/testio/__init__.py +1 -0
- alita_sdk/tools/testrail/__init__.py +1 -3
- alita_sdk/tools/utils/content_parser.py +27 -16
- alita_sdk/tools/vector_adapters/VectorStoreAdapter.py +18 -5
- alita_sdk/tools/xray/__init__.py +2 -1
- alita_sdk/tools/zephyr/__init__.py +2 -1
- alita_sdk/tools/zephyr_enterprise/__init__.py +1 -0
- alita_sdk/tools/zephyr_essential/__init__.py +1 -0
- alita_sdk/tools/zephyr_scale/__init__.py +1 -0
- alita_sdk/tools/zephyr_squad/__init__.py +1 -0
- {alita_sdk-0.3.379.dist-info → alita_sdk-0.3.462.dist-info}/METADATA +8 -2
- {alita_sdk-0.3.379.dist-info → alita_sdk-0.3.462.dist-info}/RECORD +110 -86
- alita_sdk-0.3.462.dist-info/entry_points.txt +2 -0
- {alita_sdk-0.3.379.dist-info → alita_sdk-0.3.462.dist-info}/WHEEL +0 -0
- {alita_sdk-0.3.379.dist-info → alita_sdk-0.3.462.dist-info}/licenses/LICENSE +0 -0
- {alita_sdk-0.3.379.dist-info → alita_sdk-0.3.462.dist-info}/top_level.txt +0 -0
|
@@ -19,8 +19,9 @@ from langgraph.managed.base import is_managed_value
|
|
|
19
19
|
from langgraph.prebuilt import InjectedStore
|
|
20
20
|
from langgraph.store.base import BaseStore
|
|
21
21
|
|
|
22
|
+
from .constants import PRINTER_NODE_RS, PRINTER, PRINTER_COMPLETED_STATE
|
|
22
23
|
from .mixedAgentRenderes import convert_message_to_json
|
|
23
|
-
from .utils import create_state, propagate_the_input_mapping
|
|
24
|
+
from .utils import create_state, propagate_the_input_mapping, safe_format
|
|
24
25
|
from ..tools.function import FunctionTool
|
|
25
26
|
from ..tools.indexer_tool import IndexerNode
|
|
26
27
|
from ..tools.llm import LLMNode
|
|
@@ -232,6 +233,32 @@ class StateDefaultNode(Runnable):
|
|
|
232
233
|
result[key] = temp_value
|
|
233
234
|
return result
|
|
234
235
|
|
|
236
|
+
class PrinterNode(Runnable):
|
|
237
|
+
name = "PrinterNode"
|
|
238
|
+
|
|
239
|
+
def __init__(self, input_mapping: Optional[dict[str, dict]]):
|
|
240
|
+
self.input_mapping = input_mapping
|
|
241
|
+
|
|
242
|
+
def invoke(self, state: BaseStore, config: Optional[RunnableConfig] = None) -> dict:
|
|
243
|
+
logger.info(f"Printer Node - Current state variables: {state}")
|
|
244
|
+
result = {}
|
|
245
|
+
logger.debug(f"Initial text pattern: {self.input_mapping}")
|
|
246
|
+
mapping = propagate_the_input_mapping(self.input_mapping, [], state)
|
|
247
|
+
# for printer node we expect that all the lists will be joined into strings already
|
|
248
|
+
# Join any lists that haven't been converted yet
|
|
249
|
+
for key, value in mapping.items():
|
|
250
|
+
if isinstance(value, list):
|
|
251
|
+
mapping[key] = ', '.join(str(item) for item in value)
|
|
252
|
+
if mapping.get(PRINTER) is None:
|
|
253
|
+
raise ToolException(f"PrinterNode requires '{PRINTER}' field in input mapping")
|
|
254
|
+
formatted_output = mapping[PRINTER]
|
|
255
|
+
# add info label to the printer's output
|
|
256
|
+
if not formatted_output == PRINTER_COMPLETED_STATE:
|
|
257
|
+
formatted_output += f"\n\n-----\n*How to proceed?*\n* *to resume the pipeline - type anything...*"
|
|
258
|
+
logger.debug(f"Formatted output: {formatted_output}")
|
|
259
|
+
result[PRINTER_NODE_RS] = formatted_output
|
|
260
|
+
return result
|
|
261
|
+
|
|
235
262
|
|
|
236
263
|
class StateModifierNode(Runnable):
|
|
237
264
|
name = "StateModifierNode"
|
|
@@ -348,8 +375,8 @@ class StateModifierNode(Runnable):
|
|
|
348
375
|
return result
|
|
349
376
|
|
|
350
377
|
|
|
351
|
-
|
|
352
|
-
|
|
378
|
+
def prepare_output_schema(lg_builder, memory, store, debug=False, interrupt_before=None, interrupt_after=None,
|
|
379
|
+
state_class=None, output_variables=None):
|
|
353
380
|
# prepare output channels
|
|
354
381
|
if interrupt_after is None:
|
|
355
382
|
interrupt_after = []
|
|
@@ -453,10 +480,14 @@ def create_graph(
|
|
|
453
480
|
if toolkit_name:
|
|
454
481
|
tool_name = f"{clean_string(toolkit_name)}{TOOLKIT_SPLITTER}{tool_name}"
|
|
455
482
|
logger.info(f"Node: {node_id} : {node_type} - {tool_name}")
|
|
456
|
-
if node_type in ['function', 'tool', 'loop', 'loop_from_tool', 'indexer', 'subgraph', 'pipeline', 'agent']:
|
|
483
|
+
if node_type in ['function', 'toolkit', 'mcp', 'tool', 'loop', 'loop_from_tool', 'indexer', 'subgraph', 'pipeline', 'agent']:
|
|
484
|
+
if node_type == 'mcp' and tool_name not in [tool.name for tool in tools]:
|
|
485
|
+
# MCP is not connected and node cannot be added
|
|
486
|
+
raise ToolException(f"MCP tool '{tool_name}' not found in the provided tools. "
|
|
487
|
+
f"Make sure it is connected properly. Available tools: {[tool.name for tool in tools]}")
|
|
457
488
|
for tool in tools:
|
|
458
489
|
if tool.name == tool_name:
|
|
459
|
-
if node_type
|
|
490
|
+
if node_type in ['function', 'toolkit', 'mcp']:
|
|
460
491
|
lg_builder.add_node(node_id, FunctionTool(
|
|
461
492
|
tool=tool, name=node_id, return_type='dict',
|
|
462
493
|
output_variables=node.get('output', []),
|
|
@@ -466,11 +497,12 @@ def create_graph(
|
|
|
466
497
|
elif node_type == 'agent':
|
|
467
498
|
input_params = node.get('input', ['messages'])
|
|
468
499
|
input_mapping = node.get('input_mapping',
|
|
469
|
-
|
|
500
|
+
{'messages': {'type': 'variable', 'value': 'messages'}})
|
|
501
|
+
output_vars = node.get('output', [])
|
|
470
502
|
lg_builder.add_node(node_id, FunctionTool(
|
|
471
503
|
client=client, tool=tool,
|
|
472
504
|
name=node_id, return_type='str',
|
|
473
|
-
output_variables=
|
|
505
|
+
output_variables=output_vars + ['messages'] if 'messages' not in output_vars else output_vars,
|
|
474
506
|
input_variables=input_params,
|
|
475
507
|
input_mapping= input_mapping
|
|
476
508
|
))
|
|
@@ -481,7 +513,8 @@ def create_graph(
|
|
|
481
513
|
# wrap with mappings
|
|
482
514
|
pipeline_name = node.get('tool', None)
|
|
483
515
|
if not pipeline_name:
|
|
484
|
-
raise ValueError(
|
|
516
|
+
raise ValueError(
|
|
517
|
+
"Subgraph must have a 'tool' node: add required tool to the subgraph node")
|
|
485
518
|
node_fn = SubgraphRunnable(
|
|
486
519
|
inner=tool.graph,
|
|
487
520
|
name=pipeline_name,
|
|
@@ -499,15 +532,6 @@ def create_graph(
|
|
|
499
532
|
structured_output=node.get('structured_output', False),
|
|
500
533
|
task=node.get('task')
|
|
501
534
|
))
|
|
502
|
-
# TODO: decide on struct output for agent nodes
|
|
503
|
-
# elif node_type == 'agent':
|
|
504
|
-
# lg_builder.add_node(node_id, AgentNode(
|
|
505
|
-
# client=client, tool=tool,
|
|
506
|
-
# name=node['id'], return_type='dict',
|
|
507
|
-
# output_variables=node.get('output', []),
|
|
508
|
-
# input_variables=node.get('input', ['messages']),
|
|
509
|
-
# task=node.get('task')
|
|
510
|
-
# ))
|
|
511
535
|
elif node_type == 'loop':
|
|
512
536
|
lg_builder.add_node(node_id, LoopNode(
|
|
513
537
|
client=client, tool=tool,
|
|
@@ -520,7 +544,8 @@ def create_graph(
|
|
|
520
544
|
loop_toolkit_name = node.get('loop_toolkit_name')
|
|
521
545
|
loop_tool_name = node.get('loop_tool')
|
|
522
546
|
if (loop_toolkit_name and loop_tool_name) or loop_tool_name:
|
|
523
|
-
loop_tool_name = f"{clean_string(loop_toolkit_name)}{TOOLKIT_SPLITTER}{loop_tool_name}" if loop_toolkit_name else clean_string(
|
|
547
|
+
loop_tool_name = f"{clean_string(loop_toolkit_name)}{TOOLKIT_SPLITTER}{loop_tool_name}" if loop_toolkit_name else clean_string(
|
|
548
|
+
loop_tool_name)
|
|
524
549
|
for t in tools:
|
|
525
550
|
if t.name == loop_tool_name:
|
|
526
551
|
logger.debug(f"Loop tool discovered: {t}")
|
|
@@ -555,7 +580,8 @@ def create_graph(
|
|
|
555
580
|
break
|
|
556
581
|
elif node_type == 'code':
|
|
557
582
|
from ..tools.sandbox import create_sandbox_tool
|
|
558
|
-
sandbox_tool = create_sandbox_tool(stateful=False, allow_net=True
|
|
583
|
+
sandbox_tool = create_sandbox_tool(stateful=False, allow_net=True,
|
|
584
|
+
alita_client=kwargs.get('alita_client', None))
|
|
559
585
|
code_data = node.get('code', {'type': 'fixed', 'value': "return 'Code block is empty'"})
|
|
560
586
|
lg_builder.add_node(node_id, FunctionTool(
|
|
561
587
|
tool=sandbox_tool, name=node['id'], return_type='dict',
|
|
@@ -593,7 +619,7 @@ def create_graph(
|
|
|
593
619
|
else:
|
|
594
620
|
# Use all available tools
|
|
595
621
|
available_tools = [tool for tool in tools if isinstance(tool, BaseTool)]
|
|
596
|
-
|
|
622
|
+
|
|
597
623
|
lg_builder.add_node(node_id, LLMNode(
|
|
598
624
|
client=client,
|
|
599
625
|
input_mapping=node.get('input_mapping', {'messages': {'type': 'variable', 'value': 'messages'}}),
|
|
@@ -604,7 +630,9 @@ def create_graph(
|
|
|
604
630
|
input_variables=node.get('input', ['messages']),
|
|
605
631
|
structured_output=node.get('structured_output', False),
|
|
606
632
|
available_tools=available_tools,
|
|
607
|
-
tool_names=tool_names
|
|
633
|
+
tool_names=tool_names,
|
|
634
|
+
steps_limit=kwargs.get('steps_limit', 25)
|
|
635
|
+
))
|
|
608
636
|
elif node_type == 'router':
|
|
609
637
|
# Add a RouterNode as an independent node
|
|
610
638
|
lg_builder.add_node(node_id, RouterNode(
|
|
@@ -624,6 +652,7 @@ def create_graph(
|
|
|
624
652
|
default_output=node.get('default_output', 'END')
|
|
625
653
|
)
|
|
626
654
|
)
|
|
655
|
+
continue
|
|
627
656
|
elif node_type == 'state_modifier':
|
|
628
657
|
lg_builder.add_node(node_id, StateModifierNode(
|
|
629
658
|
template=node.get('template', ''),
|
|
@@ -631,6 +660,22 @@ def create_graph(
|
|
|
631
660
|
input_variables=node.get('input', ['messages']),
|
|
632
661
|
output_variables=node.get('output', [])
|
|
633
662
|
))
|
|
663
|
+
elif node_type == 'printer':
|
|
664
|
+
lg_builder.add_node(node_id, PrinterNode(
|
|
665
|
+
input_mapping=node.get('input_mapping', {'printer': {'type': 'fixed', 'value': ''}}),
|
|
666
|
+
))
|
|
667
|
+
|
|
668
|
+
# add interrupts after printer node if specified
|
|
669
|
+
interrupt_after.append(clean_string(node_id))
|
|
670
|
+
|
|
671
|
+
# reset printer output variable to avoid carrying over
|
|
672
|
+
reset_node_id = f"{node_id}_reset"
|
|
673
|
+
lg_builder.add_node(reset_node_id, PrinterNode(
|
|
674
|
+
input_mapping={'printer': {'type': 'fixed', 'value': PRINTER_COMPLETED_STATE}}
|
|
675
|
+
))
|
|
676
|
+
lg_builder.add_conditional_edges(node_id, TransitionalEdge(reset_node_id))
|
|
677
|
+
lg_builder.add_conditional_edges(reset_node_id, TransitionalEdge(clean_string(node['transition'])))
|
|
678
|
+
continue
|
|
634
679
|
if node.get('transition'):
|
|
635
680
|
next_step = clean_string(node['transition'])
|
|
636
681
|
logger.info(f'Adding transition: {next_step}')
|
|
@@ -777,31 +822,103 @@ class LangGraphAgentRunnable(CompiledStateGraph):
|
|
|
777
822
|
# Convert chat history dict messages to LangChain message objects
|
|
778
823
|
chat_history = input.pop('chat_history')
|
|
779
824
|
input['messages'] = [convert_dict_to_message(msg) for msg in chat_history]
|
|
780
|
-
|
|
825
|
+
|
|
826
|
+
# handler for LLM node: if no input (Chat perspective), then take last human message
|
|
827
|
+
# Track if input came from messages to handle content extraction properly
|
|
828
|
+
input_from_messages = False
|
|
829
|
+
if not input.get('input'):
|
|
830
|
+
if input.get('messages'):
|
|
831
|
+
input['input'] = [next((msg for msg in reversed(input['messages']) if isinstance(msg, HumanMessage)),
|
|
832
|
+
None)]
|
|
833
|
+
if input['input'] is not None:
|
|
834
|
+
input_from_messages = True
|
|
835
|
+
|
|
781
836
|
# Append current input to existing messages instead of overwriting
|
|
782
837
|
if input.get('input'):
|
|
783
838
|
if isinstance(input['input'], str):
|
|
784
839
|
current_message = input['input']
|
|
785
840
|
else:
|
|
841
|
+
# input can be a list of messages or a single message object
|
|
786
842
|
current_message = input.get('input')[-1]
|
|
843
|
+
|
|
787
844
|
# TODO: add handler after we add 2+ inputs (filterByType, etc.)
|
|
788
|
-
|
|
845
|
+
if isinstance(current_message, HumanMessage):
|
|
846
|
+
current_content = current_message.content
|
|
847
|
+
if isinstance(current_content, list):
|
|
848
|
+
# Extract text parts and keep non-text parts (images, etc.)
|
|
849
|
+
text_contents = []
|
|
850
|
+
non_text_parts = []
|
|
851
|
+
|
|
852
|
+
for item in current_content:
|
|
853
|
+
if isinstance(item, dict) and item.get('type') == 'text':
|
|
854
|
+
text_contents.append(item['text'])
|
|
855
|
+
elif isinstance(item, str):
|
|
856
|
+
text_contents.append(item)
|
|
857
|
+
else:
|
|
858
|
+
# Keep image_url and other non-text content
|
|
859
|
+
non_text_parts.append(item)
|
|
860
|
+
|
|
861
|
+
# Set input to the joined text
|
|
862
|
+
input['input'] = ". ".join(text_contents) if text_contents else ""
|
|
863
|
+
|
|
864
|
+
# If this message came from input['messages'], update or remove it
|
|
865
|
+
if input_from_messages:
|
|
866
|
+
if non_text_parts:
|
|
867
|
+
# Keep the message but only with non-text content (images, etc.)
|
|
868
|
+
current_message.content = non_text_parts
|
|
869
|
+
else:
|
|
870
|
+
# All content was text, remove this message from the list
|
|
871
|
+
input['messages'] = [msg for msg in input['messages'] if msg is not current_message]
|
|
872
|
+
|
|
873
|
+
elif isinstance(current_content, str):
|
|
874
|
+
# on regenerate case
|
|
875
|
+
input['input'] = current_content
|
|
876
|
+
# If from messages and all content is text, remove the message
|
|
877
|
+
if input_from_messages:
|
|
878
|
+
input['messages'] = [msg for msg in input['messages'] if msg is not current_message]
|
|
879
|
+
else:
|
|
880
|
+
input['input'] = str(current_content)
|
|
881
|
+
# If from messages, remove since we extracted the content
|
|
882
|
+
if input_from_messages:
|
|
883
|
+
input['messages'] = [msg for msg in input['messages'] if msg is not current_message]
|
|
884
|
+
elif isinstance(current_message, str):
|
|
885
|
+
input['input'] = current_message
|
|
886
|
+
else:
|
|
887
|
+
input['input'] = str(current_message)
|
|
789
888
|
if input.get('messages'):
|
|
790
889
|
# Ensure existing messages are LangChain objects
|
|
791
890
|
input['messages'] = [convert_dict_to_message(msg) for msg in input['messages']]
|
|
792
891
|
# Append to existing messages
|
|
793
|
-
input['messages'].append(current_message)
|
|
794
|
-
else:
|
|
795
|
-
#
|
|
796
|
-
input['messages'] = [current_message]
|
|
892
|
+
# input['messages'].append(current_message)
|
|
893
|
+
# else:
|
|
894
|
+
# NOTE: Commented out to prevent duplicates with input['input']
|
|
895
|
+
# input['messages'] = [current_message]
|
|
896
|
+
|
|
897
|
+
# Validate that input is not empty after all processing
|
|
898
|
+
if not input.get('input'):
|
|
899
|
+
raise RuntimeError(
|
|
900
|
+
"Empty input after processing. Cannot send empty string to LLM. "
|
|
901
|
+
"This likely means the message contained only non-text content "
|
|
902
|
+
"with no accompanying text."
|
|
903
|
+
)
|
|
904
|
+
|
|
797
905
|
logging.info(f"Input: {thread_id} - {input}")
|
|
798
906
|
if self.checkpointer and self.checkpointer.get_tuple(config):
|
|
799
907
|
self.update_state(config, input)
|
|
800
|
-
|
|
908
|
+
if config.pop("should_continue", False):
|
|
909
|
+
invoke_input = input
|
|
910
|
+
else:
|
|
911
|
+
invoke_input = None
|
|
912
|
+
result = super().invoke(invoke_input, config=config, *args, **kwargs)
|
|
801
913
|
else:
|
|
802
914
|
result = super().invoke(input, config=config, *args, **kwargs)
|
|
803
915
|
try:
|
|
804
|
-
|
|
916
|
+
if result.get(PRINTER_NODE_RS) == PRINTER_COMPLETED_STATE:
|
|
917
|
+
output = next((msg.content for msg in reversed(result['messages']) if not isinstance(msg, HumanMessage)),
|
|
918
|
+
result['messages'][-1].content)
|
|
919
|
+
else:
|
|
920
|
+
# used for printer node output - it will be reset by next `reset` node
|
|
921
|
+
output = result.get(PRINTER_NODE_RS)
|
|
805
922
|
except:
|
|
806
923
|
output = list(result.values())[-1]
|
|
807
924
|
config_state = self.get_state(config)
|
|
@@ -809,8 +926,6 @@ class LangGraphAgentRunnable(CompiledStateGraph):
|
|
|
809
926
|
if is_execution_finished:
|
|
810
927
|
thread_id = None
|
|
811
928
|
|
|
812
|
-
|
|
813
|
-
|
|
814
929
|
result_with_state = {
|
|
815
930
|
"output": output,
|
|
816
931
|
"thread_id": thread_id,
|
|
@@ -2,11 +2,12 @@ import builtins
|
|
|
2
2
|
import json
|
|
3
3
|
import logging
|
|
4
4
|
import re
|
|
5
|
-
from pydantic import create_model, Field
|
|
5
|
+
from pydantic import create_model, Field, Json
|
|
6
6
|
from typing import Tuple, TypedDict, Any, Optional, Annotated
|
|
7
7
|
from langchain_core.messages import AnyMessage
|
|
8
|
-
from
|
|
9
|
-
|
|
8
|
+
from langgraph.graph import add_messages
|
|
9
|
+
|
|
10
|
+
from ...runtime.langchain.constants import ELITEA_RS, PRINTER_NODE_RS
|
|
10
11
|
|
|
11
12
|
logger = logging.getLogger(__name__)
|
|
12
13
|
|
|
@@ -130,13 +131,15 @@ def parse_type(type_str):
|
|
|
130
131
|
|
|
131
132
|
|
|
132
133
|
def create_state(data: Optional[dict] = None):
|
|
133
|
-
state_dict = {'input': str, '
|
|
134
|
+
state_dict = {'input': str, 'messages': 'list[str]', 'router_output': str,
|
|
135
|
+
ELITEA_RS: str, PRINTER_NODE_RS: str} # Always include router_output
|
|
134
136
|
types_dict = {}
|
|
135
137
|
if not data:
|
|
136
138
|
data = {'messages': 'list[str]'}
|
|
137
139
|
for key, value in data.items():
|
|
138
140
|
# support of old & new UI
|
|
139
141
|
value = value['type'] if isinstance(value, dict) else value
|
|
142
|
+
value = 'str' if value == 'string' else value # normalize string type (old state support)
|
|
140
143
|
if key == 'messages':
|
|
141
144
|
state_dict[key] = Annotated[list[AnyMessage], add_messages]
|
|
142
145
|
elif value in ['str', 'int', 'float', 'bool', 'list', 'dict', 'number', 'dict']:
|
|
@@ -181,16 +184,45 @@ def propagate_the_input_mapping(input_mapping: dict[str, dict], input_variables:
|
|
|
181
184
|
input_data[key] = value['value'].format(**var_dict)
|
|
182
185
|
except KeyError as e:
|
|
183
186
|
logger.error(f"KeyError in fstring formatting for key '{key}'. Attempt to find proper data in state.\n{e}")
|
|
184
|
-
|
|
187
|
+
try:
|
|
188
|
+
# search for variables in state if not found in var_dict
|
|
189
|
+
input_data[key] = safe_format(value['value'], state)
|
|
190
|
+
except KeyError as no_var_exception:
|
|
191
|
+
logger.error(f"KeyError in fstring formatting for key '{key}' with state data.\n{no_var_exception}")
|
|
192
|
+
# leave value as is if still not found (could be a constant string marked as fstring by mistake)
|
|
193
|
+
input_data[key] = value['value']
|
|
185
194
|
elif value['type'] == 'fixed':
|
|
186
195
|
input_data[key] = value['value']
|
|
187
196
|
else:
|
|
188
197
|
input_data[key] = source.get(value['value'], "")
|
|
189
198
|
return input_data
|
|
190
199
|
|
|
200
|
+
def safe_format(template, mapping):
|
|
201
|
+
"""Format a template string using a mapping, leaving placeholders unchanged if keys are missing."""
|
|
202
|
+
|
|
203
|
+
def replacer(match):
|
|
204
|
+
key = match.group(1)
|
|
205
|
+
return str(mapping.get(key, f'{{{key}}}'))
|
|
206
|
+
return re.sub(r'\{(\w+)\}', replacer, template)
|
|
191
207
|
|
|
192
208
|
def create_pydantic_model(model_name: str, variables: dict[str, dict]):
|
|
193
209
|
fields = {}
|
|
194
210
|
for var_name, var_data in variables.items():
|
|
195
|
-
fields[var_name] = (
|
|
196
|
-
return create_model(model_name, **fields)
|
|
211
|
+
fields[var_name] = (parse_pydantic_type(var_data['type']), Field(description=var_data.get('description', None)))
|
|
212
|
+
return create_model(model_name, **fields)
|
|
213
|
+
|
|
214
|
+
def parse_pydantic_type(type_name: str):
|
|
215
|
+
"""
|
|
216
|
+
Helper function to parse type names into Python types.
|
|
217
|
+
Extend this function to handle custom types like 'dict' -> Json[Any].
|
|
218
|
+
"""
|
|
219
|
+
type_mapping = {
|
|
220
|
+
'str': str,
|
|
221
|
+
'int': int,
|
|
222
|
+
'float': float,
|
|
223
|
+
'bool': bool,
|
|
224
|
+
'dict': Json[Any], # Map 'dict' to Pydantic's Json type
|
|
225
|
+
'list': list,
|
|
226
|
+
'any': Any
|
|
227
|
+
}
|
|
228
|
+
return type_mapping.get(type_name, Any)
|
|
@@ -0,0 +1,61 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Models for MCP (Model Context Protocol) configuration.
|
|
3
|
+
Following MCP specification for remote HTTP servers only.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from typing import Optional, List, Dict, Any
|
|
7
|
+
from pydantic import BaseModel, Field, validator
|
|
8
|
+
from urllib.parse import urlparse
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class McpConnectionConfig(BaseModel):
|
|
12
|
+
"""
|
|
13
|
+
MCP connection configuration for remote HTTP servers.
|
|
14
|
+
Based on https://modelcontextprotocol.io/specification/2025-06-18
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
url: str = Field(description="MCP server HTTP URL (http:// or https://)")
|
|
18
|
+
headers: Optional[Dict[str, str]] = Field(
|
|
19
|
+
default=None,
|
|
20
|
+
description="HTTP headers for the connection (JSON object)"
|
|
21
|
+
)
|
|
22
|
+
session_id: Optional[str] = Field(
|
|
23
|
+
default=None,
|
|
24
|
+
description="MCP session ID for stateful SSE servers (managed by client)"
|
|
25
|
+
)
|
|
26
|
+
|
|
27
|
+
@validator('url')
|
|
28
|
+
def validate_url(cls, v):
|
|
29
|
+
"""Validate URL is HTTP/HTTPS."""
|
|
30
|
+
if not v:
|
|
31
|
+
raise ValueError("URL cannot be empty")
|
|
32
|
+
|
|
33
|
+
parsed = urlparse(v)
|
|
34
|
+
if parsed.scheme not in ['http', 'https']:
|
|
35
|
+
raise ValueError("URL must use http:// or https:// scheme for remote MCP servers")
|
|
36
|
+
|
|
37
|
+
if not parsed.netloc:
|
|
38
|
+
raise ValueError("URL must include host and port")
|
|
39
|
+
|
|
40
|
+
return v
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class McpToolkitConfig(BaseModel):
|
|
44
|
+
"""Configuration for a single remote MCP server toolkit."""
|
|
45
|
+
|
|
46
|
+
server_name: str = Field(description="MCP server name/identifier")
|
|
47
|
+
connection: McpConnectionConfig = Field(description="MCP connection configuration")
|
|
48
|
+
timeout: int = Field(default=60, description="Request timeout in seconds", ge=1, le=3600)
|
|
49
|
+
selected_tools: List[str] = Field(default_factory=list, description="Specific tools to enable (empty = all)")
|
|
50
|
+
enable_caching: bool = Field(default=True, description="Enable tool schema caching")
|
|
51
|
+
cache_ttl: int = Field(default=300, description="Cache TTL in seconds", ge=60, le=3600)
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
class McpToolMetadata(BaseModel):
|
|
55
|
+
"""Metadata about an MCP tool."""
|
|
56
|
+
|
|
57
|
+
name: str = Field(description="Tool name")
|
|
58
|
+
description: str = Field(description="Tool description")
|
|
59
|
+
server: str = Field(description="Source server name")
|
|
60
|
+
input_schema: Dict[str, Any] = Field(description="Tool input schema")
|
|
61
|
+
enabled: bool = Field(default=True, description="Whether tool is enabled")
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Runtime toolkits module for Alita SDK.
|
|
3
|
+
This module provides various toolkit implementations for LangGraph agents.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from .application import ApplicationToolkit
|
|
7
|
+
from .artifact import ArtifactToolkit
|
|
8
|
+
from .datasource import DatasourcesToolkit
|
|
9
|
+
from .prompt import PromptToolkit
|
|
10
|
+
from .subgraph import SubgraphToolkit
|
|
11
|
+
from .vectorstore import VectorStoreToolkit
|
|
12
|
+
from .mcp import McpToolkit
|
|
13
|
+
from ...tools.memory import MemoryToolkit
|
|
14
|
+
|
|
15
|
+
__all__ = [
|
|
16
|
+
"ApplicationToolkit",
|
|
17
|
+
"ArtifactToolkit",
|
|
18
|
+
"DatasourcesToolkit",
|
|
19
|
+
"PromptToolkit",
|
|
20
|
+
"SubgraphToolkit",
|
|
21
|
+
"VectorStoreToolkit",
|
|
22
|
+
"McpToolkit",
|
|
23
|
+
"MemoryToolkit"
|
|
24
|
+
]
|
|
@@ -39,7 +39,14 @@ class ApplicationToolkit(BaseToolkit):
|
|
|
39
39
|
description=app_details.get("description"),
|
|
40
40
|
application=app,
|
|
41
41
|
args_schema=applicationToolSchema,
|
|
42
|
-
return_type='str'
|
|
42
|
+
return_type='str',
|
|
43
|
+
client=client,
|
|
44
|
+
args_runnable={
|
|
45
|
+
"application_id": application_id,
|
|
46
|
+
"application_version_id": application_version_id,
|
|
47
|
+
"store": store,
|
|
48
|
+
"llm": client.get_llm(version_details['llm_settings']['model_name'], model_settings),
|
|
49
|
+
})])
|
|
43
50
|
|
|
44
51
|
def get_tools(self):
|
|
45
52
|
return self.tools
|
|
@@ -23,11 +23,7 @@ class ArtifactToolkit(BaseToolkit):
|
|
|
23
23
|
# client = (Any, FieldInfo(description="Client object", required=True, autopopulate=True)),
|
|
24
24
|
bucket=(str, FieldInfo(
|
|
25
25
|
description="Bucket name",
|
|
26
|
-
pattern=r'^[a-z][a-z0-9-]*$'
|
|
27
|
-
json_schema_extra={
|
|
28
|
-
'toolkit_name': True,
|
|
29
|
-
'max_toolkit_length': ArtifactToolkit.toolkit_max_length
|
|
30
|
-
}
|
|
26
|
+
pattern=r'^[a-z][a-z0-9-]*$'
|
|
31
27
|
)),
|
|
32
28
|
selected_tools=(List[Literal[tuple(selected_tools)]], Field(default=[], json_schema_extra={'args_schemas': selected_tools})),
|
|
33
29
|
# indexer settings
|
|
@@ -37,7 +33,10 @@ class ArtifactToolkit(BaseToolkit):
|
|
|
37
33
|
embedding_model=(Optional[str], Field(default=None, description="Embedding configuration.",
|
|
38
34
|
json_schema_extra={'configuration_model': 'embedding'})),
|
|
39
35
|
|
|
40
|
-
__config__=ConfigDict(json_schema_extra={'metadata': {"label": "Artifact",
|
|
36
|
+
__config__=ConfigDict(json_schema_extra={'metadata': {"label": "Artifact",
|
|
37
|
+
"icon_url": None,
|
|
38
|
+
"max_length": ArtifactToolkit.toolkit_max_length
|
|
39
|
+
}})
|
|
41
40
|
)
|
|
42
41
|
|
|
43
42
|
@classmethod
|