alita-sdk 0.3.379__py3-none-any.whl → 0.3.462__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of alita-sdk might be problematic. (The original page linked to more details here; that link was not preserved in this copy.)

Files changed (110)
  1. alita_sdk/cli/__init__.py +10 -0
  2. alita_sdk/cli/__main__.py +17 -0
  3. alita_sdk/cli/agent_executor.py +144 -0
  4. alita_sdk/cli/agent_loader.py +197 -0
  5. alita_sdk/cli/agent_ui.py +166 -0
  6. alita_sdk/cli/agents.py +1069 -0
  7. alita_sdk/cli/callbacks.py +576 -0
  8. alita_sdk/cli/cli.py +159 -0
  9. alita_sdk/cli/config.py +153 -0
  10. alita_sdk/cli/formatting.py +182 -0
  11. alita_sdk/cli/mcp_loader.py +315 -0
  12. alita_sdk/cli/toolkit.py +330 -0
  13. alita_sdk/cli/toolkit_loader.py +55 -0
  14. alita_sdk/cli/tools/__init__.py +9 -0
  15. alita_sdk/cli/tools/filesystem.py +905 -0
  16. alita_sdk/configurations/bitbucket.py +95 -0
  17. alita_sdk/configurations/confluence.py +96 -1
  18. alita_sdk/configurations/gitlab.py +79 -0
  19. alita_sdk/configurations/jira.py +103 -0
  20. alita_sdk/configurations/testrail.py +88 -0
  21. alita_sdk/configurations/xray.py +93 -0
  22. alita_sdk/configurations/zephyr_enterprise.py +93 -0
  23. alita_sdk/configurations/zephyr_essential.py +75 -0
  24. alita_sdk/runtime/clients/client.py +47 -10
  25. alita_sdk/runtime/clients/mcp_discovery.py +342 -0
  26. alita_sdk/runtime/clients/mcp_manager.py +262 -0
  27. alita_sdk/runtime/clients/sandbox_client.py +8 -0
  28. alita_sdk/runtime/langchain/assistant.py +37 -16
  29. alita_sdk/runtime/langchain/constants.py +6 -1
  30. alita_sdk/runtime/langchain/document_loaders/AlitaDocxMammothLoader.py +315 -3
  31. alita_sdk/runtime/langchain/document_loaders/AlitaJSONLoader.py +4 -1
  32. alita_sdk/runtime/langchain/document_loaders/constants.py +28 -12
  33. alita_sdk/runtime/langchain/langraph_agent.py +146 -31
  34. alita_sdk/runtime/langchain/utils.py +39 -7
  35. alita_sdk/runtime/models/mcp_models.py +61 -0
  36. alita_sdk/runtime/toolkits/__init__.py +24 -0
  37. alita_sdk/runtime/toolkits/application.py +8 -1
  38. alita_sdk/runtime/toolkits/artifact.py +5 -6
  39. alita_sdk/runtime/toolkits/mcp.py +895 -0
  40. alita_sdk/runtime/toolkits/tools.py +137 -56
  41. alita_sdk/runtime/tools/__init__.py +7 -2
  42. alita_sdk/runtime/tools/application.py +7 -0
  43. alita_sdk/runtime/tools/function.py +29 -25
  44. alita_sdk/runtime/tools/graph.py +10 -4
  45. alita_sdk/runtime/tools/image_generation.py +104 -8
  46. alita_sdk/runtime/tools/llm.py +204 -114
  47. alita_sdk/runtime/tools/mcp_inspect_tool.py +284 -0
  48. alita_sdk/runtime/tools/mcp_remote_tool.py +166 -0
  49. alita_sdk/runtime/tools/mcp_server_tool.py +3 -1
  50. alita_sdk/runtime/tools/sandbox.py +57 -43
  51. alita_sdk/runtime/tools/vectorstore.py +2 -1
  52. alita_sdk/runtime/tools/vectorstore_base.py +19 -3
  53. alita_sdk/runtime/utils/mcp_oauth.py +164 -0
  54. alita_sdk/runtime/utils/mcp_sse_client.py +405 -0
  55. alita_sdk/runtime/utils/streamlit.py +34 -3
  56. alita_sdk/runtime/utils/toolkit_utils.py +14 -4
  57. alita_sdk/tools/__init__.py +46 -31
  58. alita_sdk/tools/ado/repos/__init__.py +1 -0
  59. alita_sdk/tools/ado/test_plan/__init__.py +1 -1
  60. alita_sdk/tools/ado/wiki/__init__.py +1 -5
  61. alita_sdk/tools/ado/work_item/__init__.py +1 -5
  62. alita_sdk/tools/ado/work_item/ado_wrapper.py +17 -8
  63. alita_sdk/tools/base_indexer_toolkit.py +105 -43
  64. alita_sdk/tools/bitbucket/__init__.py +1 -0
  65. alita_sdk/tools/chunkers/sematic/proposal_chunker.py +1 -1
  66. alita_sdk/tools/code/sonar/__init__.py +1 -1
  67. alita_sdk/tools/code_indexer_toolkit.py +13 -3
  68. alita_sdk/tools/confluence/__init__.py +2 -2
  69. alita_sdk/tools/confluence/api_wrapper.py +29 -7
  70. alita_sdk/tools/confluence/loader.py +10 -0
  71. alita_sdk/tools/github/__init__.py +2 -2
  72. alita_sdk/tools/gitlab/__init__.py +2 -1
  73. alita_sdk/tools/gitlab/api_wrapper.py +11 -7
  74. alita_sdk/tools/gitlab_org/__init__.py +1 -2
  75. alita_sdk/tools/google_places/__init__.py +2 -1
  76. alita_sdk/tools/jira/__init__.py +1 -0
  77. alita_sdk/tools/jira/api_wrapper.py +1 -1
  78. alita_sdk/tools/memory/__init__.py +1 -1
  79. alita_sdk/tools/openapi/__init__.py +10 -1
  80. alita_sdk/tools/pandas/__init__.py +1 -1
  81. alita_sdk/tools/postman/__init__.py +2 -1
  82. alita_sdk/tools/pptx/__init__.py +2 -2
  83. alita_sdk/tools/qtest/__init__.py +3 -3
  84. alita_sdk/tools/qtest/api_wrapper.py +1708 -76
  85. alita_sdk/tools/rally/__init__.py +1 -2
  86. alita_sdk/tools/report_portal/__init__.py +1 -0
  87. alita_sdk/tools/salesforce/__init__.py +1 -0
  88. alita_sdk/tools/servicenow/__init__.py +2 -3
  89. alita_sdk/tools/sharepoint/__init__.py +1 -0
  90. alita_sdk/tools/sharepoint/api_wrapper.py +125 -34
  91. alita_sdk/tools/sharepoint/authorization_helper.py +191 -1
  92. alita_sdk/tools/sharepoint/utils.py +8 -2
  93. alita_sdk/tools/slack/__init__.py +1 -0
  94. alita_sdk/tools/sql/__init__.py +2 -1
  95. alita_sdk/tools/testio/__init__.py +1 -0
  96. alita_sdk/tools/testrail/__init__.py +1 -3
  97. alita_sdk/tools/utils/content_parser.py +27 -16
  98. alita_sdk/tools/vector_adapters/VectorStoreAdapter.py +18 -5
  99. alita_sdk/tools/xray/__init__.py +2 -1
  100. alita_sdk/tools/zephyr/__init__.py +2 -1
  101. alita_sdk/tools/zephyr_enterprise/__init__.py +1 -0
  102. alita_sdk/tools/zephyr_essential/__init__.py +1 -0
  103. alita_sdk/tools/zephyr_scale/__init__.py +1 -0
  104. alita_sdk/tools/zephyr_squad/__init__.py +1 -0
  105. {alita_sdk-0.3.379.dist-info → alita_sdk-0.3.462.dist-info}/METADATA +8 -2
  106. {alita_sdk-0.3.379.dist-info → alita_sdk-0.3.462.dist-info}/RECORD +110 -86
  107. alita_sdk-0.3.462.dist-info/entry_points.txt +2 -0
  108. {alita_sdk-0.3.379.dist-info → alita_sdk-0.3.462.dist-info}/WHEEL +0 -0
  109. {alita_sdk-0.3.379.dist-info → alita_sdk-0.3.462.dist-info}/licenses/LICENSE +0 -0
  110. {alita_sdk-0.3.379.dist-info → alita_sdk-0.3.462.dist-info}/top_level.txt +0 -0
@@ -19,8 +19,9 @@ from langgraph.managed.base import is_managed_value
19
19
  from langgraph.prebuilt import InjectedStore
20
20
  from langgraph.store.base import BaseStore
21
21
 
22
+ from .constants import PRINTER_NODE_RS, PRINTER, PRINTER_COMPLETED_STATE
22
23
  from .mixedAgentRenderes import convert_message_to_json
23
- from .utils import create_state, propagate_the_input_mapping
24
+ from .utils import create_state, propagate_the_input_mapping, safe_format
24
25
  from ..tools.function import FunctionTool
25
26
  from ..tools.indexer_tool import IndexerNode
26
27
  from ..tools.llm import LLMNode
@@ -232,6 +233,32 @@ class StateDefaultNode(Runnable):
232
233
  result[key] = temp_value
233
234
  return result
234
235
 
236
+ class PrinterNode(Runnable):
237
+ name = "PrinterNode"
238
+
239
+ def __init__(self, input_mapping: Optional[dict[str, dict]]):
240
+ self.input_mapping = input_mapping
241
+
242
+ def invoke(self, state: BaseStore, config: Optional[RunnableConfig] = None) -> dict:
243
+ logger.info(f"Printer Node - Current state variables: {state}")
244
+ result = {}
245
+ logger.debug(f"Initial text pattern: {self.input_mapping}")
246
+ mapping = propagate_the_input_mapping(self.input_mapping, [], state)
247
+ # for printer node we expect that all the lists will be joined into strings already
248
+ # Join any lists that haven't been converted yet
249
+ for key, value in mapping.items():
250
+ if isinstance(value, list):
251
+ mapping[key] = ', '.join(str(item) for item in value)
252
+ if mapping.get(PRINTER) is None:
253
+ raise ToolException(f"PrinterNode requires '{PRINTER}' field in input mapping")
254
+ formatted_output = mapping[PRINTER]
255
+ # add info label to the printer's output
256
+ if not formatted_output == PRINTER_COMPLETED_STATE:
257
+ formatted_output += f"\n\n-----\n*How to proceed?*\n* *to resume the pipeline - type anything...*"
258
+ logger.debug(f"Formatted output: {formatted_output}")
259
+ result[PRINTER_NODE_RS] = formatted_output
260
+ return result
261
+
235
262
 
236
263
  class StateModifierNode(Runnable):
237
264
  name = "StateModifierNode"
@@ -348,8 +375,8 @@ class StateModifierNode(Runnable):
348
375
  return result
349
376
 
350
377
 
351
-
352
- def prepare_output_schema(lg_builder, memory, store, debug=False, interrupt_before=None, interrupt_after=None, state_class=None, output_variables=None):
378
+ def prepare_output_schema(lg_builder, memory, store, debug=False, interrupt_before=None, interrupt_after=None,
379
+ state_class=None, output_variables=None):
353
380
  # prepare output channels
354
381
  if interrupt_after is None:
355
382
  interrupt_after = []
@@ -453,10 +480,14 @@ def create_graph(
453
480
  if toolkit_name:
454
481
  tool_name = f"{clean_string(toolkit_name)}{TOOLKIT_SPLITTER}{tool_name}"
455
482
  logger.info(f"Node: {node_id} : {node_type} - {tool_name}")
456
- if node_type in ['function', 'tool', 'loop', 'loop_from_tool', 'indexer', 'subgraph', 'pipeline', 'agent']:
483
+ if node_type in ['function', 'toolkit', 'mcp', 'tool', 'loop', 'loop_from_tool', 'indexer', 'subgraph', 'pipeline', 'agent']:
484
+ if node_type == 'mcp' and tool_name not in [tool.name for tool in tools]:
485
+ # MCP is not connected and node cannot be added
486
+ raise ToolException(f"MCP tool '{tool_name}' not found in the provided tools. "
487
+ f"Make sure it is connected properly. Available tools: {[tool.name for tool in tools]}")
457
488
  for tool in tools:
458
489
  if tool.name == tool_name:
459
- if node_type == 'function':
490
+ if node_type in ['function', 'toolkit', 'mcp']:
460
491
  lg_builder.add_node(node_id, FunctionTool(
461
492
  tool=tool, name=node_id, return_type='dict',
462
493
  output_variables=node.get('output', []),
@@ -466,11 +497,12 @@ def create_graph(
466
497
  elif node_type == 'agent':
467
498
  input_params = node.get('input', ['messages'])
468
499
  input_mapping = node.get('input_mapping',
469
- {'messages': {'type': 'variable', 'value': 'messages'}})
500
+ {'messages': {'type': 'variable', 'value': 'messages'}})
501
+ output_vars = node.get('output', [])
470
502
  lg_builder.add_node(node_id, FunctionTool(
471
503
  client=client, tool=tool,
472
504
  name=node_id, return_type='str',
473
- output_variables=node.get('output', []),
505
+ output_variables=output_vars + ['messages'] if 'messages' not in output_vars else output_vars,
474
506
  input_variables=input_params,
475
507
  input_mapping= input_mapping
476
508
  ))
@@ -481,7 +513,8 @@ def create_graph(
481
513
  # wrap with mappings
482
514
  pipeline_name = node.get('tool', None)
483
515
  if not pipeline_name:
484
- raise ValueError("Subgraph must have a 'tool' node: add required tool to the subgraph node")
516
+ raise ValueError(
517
+ "Subgraph must have a 'tool' node: add required tool to the subgraph node")
485
518
  node_fn = SubgraphRunnable(
486
519
  inner=tool.graph,
487
520
  name=pipeline_name,
@@ -499,15 +532,6 @@ def create_graph(
499
532
  structured_output=node.get('structured_output', False),
500
533
  task=node.get('task')
501
534
  ))
502
- # TODO: decide on struct output for agent nodes
503
- # elif node_type == 'agent':
504
- # lg_builder.add_node(node_id, AgentNode(
505
- # client=client, tool=tool,
506
- # name=node['id'], return_type='dict',
507
- # output_variables=node.get('output', []),
508
- # input_variables=node.get('input', ['messages']),
509
- # task=node.get('task')
510
- # ))
511
535
  elif node_type == 'loop':
512
536
  lg_builder.add_node(node_id, LoopNode(
513
537
  client=client, tool=tool,
@@ -520,7 +544,8 @@ def create_graph(
520
544
  loop_toolkit_name = node.get('loop_toolkit_name')
521
545
  loop_tool_name = node.get('loop_tool')
522
546
  if (loop_toolkit_name and loop_tool_name) or loop_tool_name:
523
- loop_tool_name = f"{clean_string(loop_toolkit_name)}{TOOLKIT_SPLITTER}{loop_tool_name}" if loop_toolkit_name else clean_string(loop_tool_name)
547
+ loop_tool_name = f"{clean_string(loop_toolkit_name)}{TOOLKIT_SPLITTER}{loop_tool_name}" if loop_toolkit_name else clean_string(
548
+ loop_tool_name)
524
549
  for t in tools:
525
550
  if t.name == loop_tool_name:
526
551
  logger.debug(f"Loop tool discovered: {t}")
@@ -555,7 +580,8 @@ def create_graph(
555
580
  break
556
581
  elif node_type == 'code':
557
582
  from ..tools.sandbox import create_sandbox_tool
558
- sandbox_tool = create_sandbox_tool(stateful=False, allow_net=True)
583
+ sandbox_tool = create_sandbox_tool(stateful=False, allow_net=True,
584
+ alita_client=kwargs.get('alita_client', None))
559
585
  code_data = node.get('code', {'type': 'fixed', 'value': "return 'Code block is empty'"})
560
586
  lg_builder.add_node(node_id, FunctionTool(
561
587
  tool=sandbox_tool, name=node['id'], return_type='dict',
@@ -593,7 +619,7 @@ def create_graph(
593
619
  else:
594
620
  # Use all available tools
595
621
  available_tools = [tool for tool in tools if isinstance(tool, BaseTool)]
596
-
622
+
597
623
  lg_builder.add_node(node_id, LLMNode(
598
624
  client=client,
599
625
  input_mapping=node.get('input_mapping', {'messages': {'type': 'variable', 'value': 'messages'}}),
@@ -604,7 +630,9 @@ def create_graph(
604
630
  input_variables=node.get('input', ['messages']),
605
631
  structured_output=node.get('structured_output', False),
606
632
  available_tools=available_tools,
607
- tool_names=tool_names))
633
+ tool_names=tool_names,
634
+ steps_limit=kwargs.get('steps_limit', 25)
635
+ ))
608
636
  elif node_type == 'router':
609
637
  # Add a RouterNode as an independent node
610
638
  lg_builder.add_node(node_id, RouterNode(
@@ -624,6 +652,7 @@ def create_graph(
624
652
  default_output=node.get('default_output', 'END')
625
653
  )
626
654
  )
655
+ continue
627
656
  elif node_type == 'state_modifier':
628
657
  lg_builder.add_node(node_id, StateModifierNode(
629
658
  template=node.get('template', ''),
@@ -631,6 +660,22 @@ def create_graph(
631
660
  input_variables=node.get('input', ['messages']),
632
661
  output_variables=node.get('output', [])
633
662
  ))
663
+ elif node_type == 'printer':
664
+ lg_builder.add_node(node_id, PrinterNode(
665
+ input_mapping=node.get('input_mapping', {'printer': {'type': 'fixed', 'value': ''}}),
666
+ ))
667
+
668
+ # add interrupts after printer node if specified
669
+ interrupt_after.append(clean_string(node_id))
670
+
671
+ # reset printer output variable to avoid carrying over
672
+ reset_node_id = f"{node_id}_reset"
673
+ lg_builder.add_node(reset_node_id, PrinterNode(
674
+ input_mapping={'printer': {'type': 'fixed', 'value': PRINTER_COMPLETED_STATE}}
675
+ ))
676
+ lg_builder.add_conditional_edges(node_id, TransitionalEdge(reset_node_id))
677
+ lg_builder.add_conditional_edges(reset_node_id, TransitionalEdge(clean_string(node['transition'])))
678
+ continue
634
679
  if node.get('transition'):
635
680
  next_step = clean_string(node['transition'])
636
681
  logger.info(f'Adding transition: {next_step}')
@@ -777,31 +822,103 @@ class LangGraphAgentRunnable(CompiledStateGraph):
777
822
  # Convert chat history dict messages to LangChain message objects
778
823
  chat_history = input.pop('chat_history')
779
824
  input['messages'] = [convert_dict_to_message(msg) for msg in chat_history]
780
-
825
+
826
+ # handler for LLM node: if no input (Chat perspective), then take last human message
827
+ # Track if input came from messages to handle content extraction properly
828
+ input_from_messages = False
829
+ if not input.get('input'):
830
+ if input.get('messages'):
831
+ input['input'] = [next((msg for msg in reversed(input['messages']) if isinstance(msg, HumanMessage)),
832
+ None)]
833
+ if input['input'] is not None:
834
+ input_from_messages = True
835
+
781
836
  # Append current input to existing messages instead of overwriting
782
837
  if input.get('input'):
783
838
  if isinstance(input['input'], str):
784
839
  current_message = input['input']
785
840
  else:
841
+ # input can be a list of messages or a single message object
786
842
  current_message = input.get('input')[-1]
843
+
787
844
  # TODO: add handler after we add 2+ inputs (filterByType, etc.)
788
- input['input'] = current_message if isinstance(current_message, str) else str(current_message)
845
+ if isinstance(current_message, HumanMessage):
846
+ current_content = current_message.content
847
+ if isinstance(current_content, list):
848
+ # Extract text parts and keep non-text parts (images, etc.)
849
+ text_contents = []
850
+ non_text_parts = []
851
+
852
+ for item in current_content:
853
+ if isinstance(item, dict) and item.get('type') == 'text':
854
+ text_contents.append(item['text'])
855
+ elif isinstance(item, str):
856
+ text_contents.append(item)
857
+ else:
858
+ # Keep image_url and other non-text content
859
+ non_text_parts.append(item)
860
+
861
+ # Set input to the joined text
862
+ input['input'] = ". ".join(text_contents) if text_contents else ""
863
+
864
+ # If this message came from input['messages'], update or remove it
865
+ if input_from_messages:
866
+ if non_text_parts:
867
+ # Keep the message but only with non-text content (images, etc.)
868
+ current_message.content = non_text_parts
869
+ else:
870
+ # All content was text, remove this message from the list
871
+ input['messages'] = [msg for msg in input['messages'] if msg is not current_message]
872
+
873
+ elif isinstance(current_content, str):
874
+ # on regenerate case
875
+ input['input'] = current_content
876
+ # If from messages and all content is text, remove the message
877
+ if input_from_messages:
878
+ input['messages'] = [msg for msg in input['messages'] if msg is not current_message]
879
+ else:
880
+ input['input'] = str(current_content)
881
+ # If from messages, remove since we extracted the content
882
+ if input_from_messages:
883
+ input['messages'] = [msg for msg in input['messages'] if msg is not current_message]
884
+ elif isinstance(current_message, str):
885
+ input['input'] = current_message
886
+ else:
887
+ input['input'] = str(current_message)
789
888
  if input.get('messages'):
790
889
  # Ensure existing messages are LangChain objects
791
890
  input['messages'] = [convert_dict_to_message(msg) for msg in input['messages']]
792
891
  # Append to existing messages
793
- input['messages'].append(current_message)
794
- else:
795
- # No existing messages, create new list
796
- input['messages'] = [current_message]
892
+ # input['messages'].append(current_message)
893
+ # else:
894
+ # NOTE: Commented out to prevent duplicates with input['input']
895
+ # input['messages'] = [current_message]
896
+
897
+ # Validate that input is not empty after all processing
898
+ if not input.get('input'):
899
+ raise RuntimeError(
900
+ "Empty input after processing. Cannot send empty string to LLM. "
901
+ "This likely means the message contained only non-text content "
902
+ "with no accompanying text."
903
+ )
904
+
797
905
  logging.info(f"Input: {thread_id} - {input}")
798
906
  if self.checkpointer and self.checkpointer.get_tuple(config):
799
907
  self.update_state(config, input)
800
- result = super().invoke(None, config=config, *args, **kwargs)
908
+ if config.pop("should_continue", False):
909
+ invoke_input = input
910
+ else:
911
+ invoke_input = None
912
+ result = super().invoke(invoke_input, config=config, *args, **kwargs)
801
913
  else:
802
914
  result = super().invoke(input, config=config, *args, **kwargs)
803
915
  try:
804
- output = next((msg.content for msg in reversed(result['messages']) if not isinstance(msg, HumanMessage)), result['messages'][-1].content)
916
+ if result.get(PRINTER_NODE_RS) == PRINTER_COMPLETED_STATE:
917
+ output = next((msg.content for msg in reversed(result['messages']) if not isinstance(msg, HumanMessage)),
918
+ result['messages'][-1].content)
919
+ else:
920
+ # used for printer node output - it will be reset by next `reset` node
921
+ output = result.get(PRINTER_NODE_RS)
805
922
  except:
806
923
  output = list(result.values())[-1]
807
924
  config_state = self.get_state(config)
@@ -809,8 +926,6 @@ class LangGraphAgentRunnable(CompiledStateGraph):
809
926
  if is_execution_finished:
810
927
  thread_id = None
811
928
 
812
-
813
-
814
929
  result_with_state = {
815
930
  "output": output,
816
931
  "thread_id": thread_id,
@@ -2,11 +2,12 @@ import builtins
2
2
  import json
3
3
  import logging
4
4
  import re
5
- from pydantic import create_model, Field
5
+ from pydantic import create_model, Field, Json
6
6
  from typing import Tuple, TypedDict, Any, Optional, Annotated
7
7
  from langchain_core.messages import AnyMessage
8
- from langchain_core.prompts import PromptTemplate
9
- from langgraph.graph import MessagesState, add_messages
8
+ from langgraph.graph import add_messages
9
+
10
+ from ...runtime.langchain.constants import ELITEA_RS, PRINTER_NODE_RS
10
11
 
11
12
  logger = logging.getLogger(__name__)
12
13
 
@@ -130,13 +131,15 @@ def parse_type(type_str):
130
131
 
131
132
 
132
133
  def create_state(data: Optional[dict] = None):
133
- state_dict = {'input': str, 'router_output': str} # Always include router_output
134
+ state_dict = {'input': str, 'messages': 'list[str]', 'router_output': str,
135
+ ELITEA_RS: str, PRINTER_NODE_RS: str} # Always include router_output
134
136
  types_dict = {}
135
137
  if not data:
136
138
  data = {'messages': 'list[str]'}
137
139
  for key, value in data.items():
138
140
  # support of old & new UI
139
141
  value = value['type'] if isinstance(value, dict) else value
142
+ value = 'str' if value == 'string' else value # normalize string type (old state support)
140
143
  if key == 'messages':
141
144
  state_dict[key] = Annotated[list[AnyMessage], add_messages]
142
145
  elif value in ['str', 'int', 'float', 'bool', 'list', 'dict', 'number', 'dict']:
@@ -181,16 +184,45 @@ def propagate_the_input_mapping(input_mapping: dict[str, dict], input_variables:
181
184
  input_data[key] = value['value'].format(**var_dict)
182
185
  except KeyError as e:
183
186
  logger.error(f"KeyError in fstring formatting for key '{key}'. Attempt to find proper data in state.\n{e}")
184
- input_data[key] = value['value'].format(**state)
187
+ try:
188
+ # search for variables in state if not found in var_dict
189
+ input_data[key] = safe_format(value['value'], state)
190
+ except KeyError as no_var_exception:
191
+ logger.error(f"KeyError in fstring formatting for key '{key}' with state data.\n{no_var_exception}")
192
+ # leave value as is if still not found (could be a constant string marked as fstring by mistake)
193
+ input_data[key] = value['value']
185
194
  elif value['type'] == 'fixed':
186
195
  input_data[key] = value['value']
187
196
  else:
188
197
  input_data[key] = source.get(value['value'], "")
189
198
  return input_data
190
199
 
200
+ def safe_format(template, mapping):
201
+ """Format a template string using a mapping, leaving placeholders unchanged if keys are missing."""
202
+
203
+ def replacer(match):
204
+ key = match.group(1)
205
+ return str(mapping.get(key, f'{{{key}}}'))
206
+ return re.sub(r'\{(\w+)\}', replacer, template)
191
207
 
192
208
  def create_pydantic_model(model_name: str, variables: dict[str, dict]):
193
209
  fields = {}
194
210
  for var_name, var_data in variables.items():
195
- fields[var_name] = (parse_type(var_data['type']), Field(description=var_data.get('description', None)))
196
- return create_model(model_name, **fields)
211
+ fields[var_name] = (parse_pydantic_type(var_data['type']), Field(description=var_data.get('description', None)))
212
+ return create_model(model_name, **fields)
213
+
214
+ def parse_pydantic_type(type_name: str):
215
+ """
216
+ Helper function to parse type names into Python types.
217
+ Extend this function to handle custom types like 'dict' -> Json[Any].
218
+ """
219
+ type_mapping = {
220
+ 'str': str,
221
+ 'int': int,
222
+ 'float': float,
223
+ 'bool': bool,
224
+ 'dict': Json[Any], # Map 'dict' to Pydantic's Json type
225
+ 'list': list,
226
+ 'any': Any
227
+ }
228
+ return type_mapping.get(type_name, Any)
@@ -0,0 +1,61 @@
1
+ """
2
+ Models for MCP (Model Context Protocol) configuration.
3
+ Following MCP specification for remote HTTP servers only.
4
+ """
5
+
6
+ from typing import Optional, List, Dict, Any
7
+ from pydantic import BaseModel, Field, validator
8
+ from urllib.parse import urlparse
9
+
10
+
11
+ class McpConnectionConfig(BaseModel):
12
+ """
13
+ MCP connection configuration for remote HTTP servers.
14
+ Based on https://modelcontextprotocol.io/specification/2025-06-18
15
+ """
16
+
17
+ url: str = Field(description="MCP server HTTP URL (http:// or https://)")
18
+ headers: Optional[Dict[str, str]] = Field(
19
+ default=None,
20
+ description="HTTP headers for the connection (JSON object)"
21
+ )
22
+ session_id: Optional[str] = Field(
23
+ default=None,
24
+ description="MCP session ID for stateful SSE servers (managed by client)"
25
+ )
26
+
27
+ @validator('url')
28
+ def validate_url(cls, v):
29
+ """Validate URL is HTTP/HTTPS."""
30
+ if not v:
31
+ raise ValueError("URL cannot be empty")
32
+
33
+ parsed = urlparse(v)
34
+ if parsed.scheme not in ['http', 'https']:
35
+ raise ValueError("URL must use http:// or https:// scheme for remote MCP servers")
36
+
37
+ if not parsed.netloc:
38
+ raise ValueError("URL must include host and port")
39
+
40
+ return v
41
+
42
+
43
+ class McpToolkitConfig(BaseModel):
44
+ """Configuration for a single remote MCP server toolkit."""
45
+
46
+ server_name: str = Field(description="MCP server name/identifier")
47
+ connection: McpConnectionConfig = Field(description="MCP connection configuration")
48
+ timeout: int = Field(default=60, description="Request timeout in seconds", ge=1, le=3600)
49
+ selected_tools: List[str] = Field(default_factory=list, description="Specific tools to enable (empty = all)")
50
+ enable_caching: bool = Field(default=True, description="Enable tool schema caching")
51
+ cache_ttl: int = Field(default=300, description="Cache TTL in seconds", ge=60, le=3600)
52
+
53
+
54
+ class McpToolMetadata(BaseModel):
55
+ """Metadata about an MCP tool."""
56
+
57
+ name: str = Field(description="Tool name")
58
+ description: str = Field(description="Tool description")
59
+ server: str = Field(description="Source server name")
60
+ input_schema: Dict[str, Any] = Field(description="Tool input schema")
61
+ enabled: bool = Field(default=True, description="Whether tool is enabled")
@@ -0,0 +1,24 @@
1
+ """
2
+ Runtime toolkits module for Alita SDK.
3
+ This module provides various toolkit implementations for LangGraph agents.
4
+ """
5
+
6
+ from .application import ApplicationToolkit
7
+ from .artifact import ArtifactToolkit
8
+ from .datasource import DatasourcesToolkit
9
+ from .prompt import PromptToolkit
10
+ from .subgraph import SubgraphToolkit
11
+ from .vectorstore import VectorStoreToolkit
12
+ from .mcp import McpToolkit
13
+ from ...tools.memory import MemoryToolkit
14
+
15
+ __all__ = [
16
+ "ApplicationToolkit",
17
+ "ArtifactToolkit",
18
+ "DatasourcesToolkit",
19
+ "PromptToolkit",
20
+ "SubgraphToolkit",
21
+ "VectorStoreToolkit",
22
+ "McpToolkit",
23
+ "MemoryToolkit"
24
+ ]
@@ -39,7 +39,14 @@ class ApplicationToolkit(BaseToolkit):
39
39
  description=app_details.get("description"),
40
40
  application=app,
41
41
  args_schema=applicationToolSchema,
42
- return_type='str')])
42
+ return_type='str',
43
+ client=client,
44
+ args_runnable={
45
+ "application_id": application_id,
46
+ "application_version_id": application_version_id,
47
+ "store": store,
48
+ "llm": client.get_llm(version_details['llm_settings']['model_name'], model_settings),
49
+ })])
43
50
 
44
51
  def get_tools(self):
45
52
  return self.tools
@@ -23,11 +23,7 @@ class ArtifactToolkit(BaseToolkit):
23
23
  # client = (Any, FieldInfo(description="Client object", required=True, autopopulate=True)),
24
24
  bucket=(str, FieldInfo(
25
25
  description="Bucket name",
26
- pattern=r'^[a-z][a-z0-9-]*$',
27
- json_schema_extra={
28
- 'toolkit_name': True,
29
- 'max_toolkit_length': ArtifactToolkit.toolkit_max_length
30
- }
26
+ pattern=r'^[a-z][a-z0-9-]*$'
31
27
  )),
32
28
  selected_tools=(List[Literal[tuple(selected_tools)]], Field(default=[], json_schema_extra={'args_schemas': selected_tools})),
33
29
  # indexer settings
@@ -37,7 +33,10 @@ class ArtifactToolkit(BaseToolkit):
37
33
  embedding_model=(Optional[str], Field(default=None, description="Embedding configuration.",
38
34
  json_schema_extra={'configuration_model': 'embedding'})),
39
35
 
40
- __config__=ConfigDict(json_schema_extra={'metadata': {"label": "Artifact", "icon_url": None}})
36
+ __config__=ConfigDict(json_schema_extra={'metadata': {"label": "Artifact",
37
+ "icon_url": None,
38
+ "max_length": ArtifactToolkit.toolkit_max_length
39
+ }})
41
40
  )
42
41
 
43
42
  @classmethod