alita-sdk 0.3.124__py3-none-any.whl → 0.3.126__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,4 +1,3 @@
1
-
2
1
  import logging
3
2
  import importlib
4
3
  from copy import deepcopy as copy
@@ -17,6 +16,7 @@ from .constants import REACT_ADDON, REACT_VARS, XML_ADDON
17
16
  from .chat_message_template import Jinja2TemplatedChatMessagesTemplate
18
17
  from ..tools.echo import EchoTool
19
18
  from ..toolkits.tools import get_tools
19
+ from langchain_core.tools import BaseTool
20
20
 
21
21
  logger = logging.getLogger(__name__)
22
22
 
@@ -53,7 +53,7 @@ class Assistant:
53
53
  target_name
54
54
  )
55
55
  self.client = target_cls(**model_params)
56
- self.tools = get_tools(data['tools'], alita=alita, llm=self.client)
56
+ self.tools = get_tools(data['tools'], alita_client=alita, llm=self.client)
57
57
  if app_type == "pipeline":
58
58
  self.prompt = data['instructions']
59
59
  else:
@@ -106,18 +106,22 @@ class Assistant:
106
106
  max_execution_time=None, return_intermediate_steps=True)
107
107
 
108
108
  def getAgentExecutor(self):
109
- agent = create_json_chat_agent(llm=self.client, tools=self.tools, prompt=self.prompt,
110
- #tools_renderer=render_react_text_description_and_args
111
- )
109
+ # Exclude compiled graph runnables from simple tool agents
110
+ simple_tools = [t for t in self.tools if isinstance(t, BaseTool)]
111
+ agent = create_json_chat_agent(llm=self.client, tools=simple_tools, prompt=self.prompt)
112
112
  return self._agent_executor(agent)
113
113
 
114
114
 
115
115
  def getXMLAgentExecutor(self):
116
- agent = create_xml_chat_agent(llm=self.client, tools=self.tools, prompt=self.prompt)
116
+ # Exclude compiled graph runnables from simple tool agents
117
+ simple_tools = [t for t in self.tools if isinstance(t, BaseTool)]
118
+ agent = create_xml_chat_agent(llm=self.client, tools=simple_tools, prompt=self.prompt)
117
119
  return self._agent_executor(agent)
118
120
 
119
121
  def getOpenAIToolsAgentExecutor(self):
120
- agent = create_openai_tools_agent(llm=self.client, tools=self.tools, prompt=self.prompt)
122
+ # Exclude compiled graph runnables from simple tool agents
123
+ simple_tools = [t for t in self.tools if isinstance(t, BaseTool)]
124
+ agent = create_openai_tools_agent(llm=self.client, tools=simple_tools, prompt=self.prompt)
121
125
  return self._agent_executor(agent)
122
126
 
123
127
  def pipeline(self):
@@ -1,6 +1,7 @@
1
1
  import logging
2
2
  from typing import Union, Any, Optional, Annotated, get_type_hints
3
3
  from uuid import uuid4
4
+ from typing import Dict
4
5
 
5
6
  import yaml
6
7
  from langchain_core.callbacks import dispatch_custom_event
@@ -17,7 +18,7 @@ from langgraph.prebuilt import InjectedStore
17
18
  from langgraph.store.base import BaseStore
18
19
 
19
20
  from .mixedAgentRenderes import convert_message_to_json
20
- from .utils import create_state
21
+ from .utils import create_state, propagate_the_input_mapping
21
22
  from ..tools.function import FunctionTool
22
23
  from ..tools.indexer_tool import IndexerNode
23
24
  from ..tools.llm import LLMNode
@@ -30,15 +31,101 @@ from ..tools.router import RouterNode
30
31
 
31
32
  logger = logging.getLogger(__name__)
32
33
 
34
+ # Global registry for subgraph definitions
35
+ # Structure: {'subgraph_name': {'yaml': 'yaml_def', 'tools': [tools], 'flattened': False}}
36
+ SUBGRAPH_REGISTRY: Dict[str, Dict[str, Any]] = {}
37
+
38
+
39
+ # Wrapper for injecting a compiled subgraph into a parent StateGraph
40
+ class SubgraphRunnable(CompiledStateGraph):
41
+ def __init__(
42
+ self,
43
+ inner: CompiledStateGraph,
44
+ *,
45
+ name: str,
46
+ input_mapping: Dict[str, Any],
47
+ output_mapping: Dict[str, Any]
48
+ ):
49
+ # copy child graph internals
50
+ super().__init__(
51
+ builder=inner.builder,
52
+ config_type=inner.config_type,
53
+ nodes=inner.nodes,
54
+ channels=inner.channels,
55
+ input_channels=inner.input_channels,
56
+ stream_mode=inner.stream_mode,
57
+ output_channels=inner.output_channels,
58
+ stream_channels=inner.stream_channels,
59
+ checkpointer=inner.checkpointer,
60
+ interrupt_before_nodes=inner.interrupt_before_nodes,
61
+ interrupt_after_nodes=inner.interrupt_after_nodes,
62
+ auto_validate=False,
63
+ debug=inner.debug,
64
+ store=inner.store,
65
+ )
66
+ self.inner = inner
67
+ self.name = name
68
+ self.input_mapping = input_mapping or {}
69
+ self.output_mapping = output_mapping or {}
70
+
71
+ def invoke(
72
+ self,
73
+ state: Union[dict[str, Any], Any],
74
+ config: Optional[RunnableConfig] = None,
75
+ **kwargs: Any,
76
+ ) -> Union[dict[str, Any], Any]:
77
+ # Detailed logging for debugging
78
+ logger.debug(f"SubgraphRunnable '{self.name}' invoke called with state: {state}")
79
+ logger.debug(f"SubgraphRunnable '{self.name}' config: {config}")
80
+
81
+ # 1) parent -> child mapping
82
+ if not self.input_mapping:
83
+ child_input = state.copy()
84
+ else:
85
+ child_input = propagate_the_input_mapping(
86
+ self.input_mapping, list(self.input_mapping.keys()), state
87
+ )
88
+ # debug trace of messages flowing into child
89
+ logger.debug(f"SubgraphRunnable '{self.name}' child_input.messages: {child_input.get('messages')}")
90
+ logger.debug(f"SubgraphRunnable '{self.name}' child_input.input: {child_input.get('input')}")
91
+
92
+ # 2) Invoke the child graph.
93
+ # Pass None as the first argument for input if the child is expected to resume
94
+ # using its (now updated) checkpoint. The CompiledStateGraph.invoke method, when
95
+ # input is None but a checkpoint exists, loads from the checkpoint.
96
+ # Any resume commands (if applicable for internal child interrupts) are in 'config'.
97
+ # logger.debug(f"SubgraphRunnable '{self.name}': Invoking child graph super().invoke(None, config).")
98
+ subgraph_output = super().invoke(child_input, config=config, **kwargs)
99
+
100
+ # 3) child complete: apply output_mapping or passthrough
101
+ logger.debug(f"SubgraphRunnable '{self.name}' child complete, applying mappings")
102
+ result: Dict[str, Any] = {}
103
+ if self.output_mapping:
104
+ for child_key, parent_key in self.output_mapping.items():
105
+ if child_key in subgraph_output:
106
+ state[parent_key] = subgraph_output[child_key]
107
+ result[parent_key] = subgraph_output[child_key]
108
+ logger.debug(f"SubgraphRunnable '{self.name}' mapped {child_key} -> {parent_key}")
109
+ else:
110
+ for k, v in subgraph_output.items():
111
+ state[k] = v
112
+ result[k] = v
113
+
114
+ # include full messages history on completion
115
+ if 'messages' not in result:
116
+ result['messages'] = subgraph_output.get('messages', [])
117
+ logger.debug(f"SubgraphRunnable '{self.name}' returning result: {result}")
118
+ return result
119
+
33
120
 
34
121
  class ConditionalEdge(Runnable):
35
122
  name = "ConditionalEdge"
36
123
 
37
124
  def __init__(self, condition: str, condition_inputs: Optional[list[str]] = [],
38
- conditional_outputs: Optional[list[str]] = [], default_output: str = 'END'):
125
+ conditional_outputs: Optional[list[str]] = [], default_output: str = END):
39
126
  self.condition = condition
40
127
  self.condition_inputs = condition_inputs
41
- self.conditional_outputs = {clean_string(cond) for cond in conditional_outputs}
128
+ self.conditional_outputs = {clean_string(cond if not 'END' == cond else '__end__') for cond in conditional_outputs}
42
129
  self.default_output = clean_string(default_output)
43
130
 
44
131
  def invoke(self, state: Annotated[BaseStore, InjectedStore()], config: Optional[RunnableConfig] = None) -> str:
@@ -177,8 +264,13 @@ class StateModifierNode(Runnable):
177
264
  return result
178
265
 
179
266
 
180
- def prepare_output_schema(lg_builder, memory, store, debug=False, interrupt_before=[], interrupt_after=[]):
267
+
268
+ def prepare_output_schema(lg_builder, memory, store, debug=False, interrupt_before=None, interrupt_after=None, state_class=None):
181
269
  # prepare output channels
270
+ if interrupt_after is None:
271
+ interrupt_after = []
272
+ if interrupt_before is None:
273
+ interrupt_before = []
182
274
  output_channels = (
183
275
  "__root__"
184
276
  if len(lg_builder.schemas[lg_builder.output]) == 1
@@ -239,14 +331,24 @@ def prepare_output_schema(lg_builder, memory, store, debug=False, interrupt_befo
239
331
  def create_graph(
240
332
  client: Any,
241
333
  yaml_schema: str,
242
- tools: list[BaseTool],
334
+ tools: list[Union[BaseTool, CompiledStateGraph]],
243
335
  *args,
244
336
  memory: Optional[Any] = None,
245
337
  store: Optional[BaseStore] = None,
246
338
  debug: bool = False,
339
+ for_subgraph: bool = False,
247
340
  **kwargs
248
341
  ):
249
342
  """ Create a message graph from a yaml schema """
343
+
344
+ # For top-level graphs (not subgraphs), detect and flatten any subgraphs
345
+ if not for_subgraph:
346
+ flattened_yaml, additional_tools = detect_and_flatten_subgraphs(yaml_schema)
347
+ # Add collected tools from subgraphs to the tools list
348
+ tools = list(tools) + additional_tools
349
+ # Use the flattened YAML for building the graph
350
+ yaml_schema = flattened_yaml
351
+
250
352
  schema = yaml.safe_load(yaml_schema)
251
353
  logger.debug(f"Schema: {schema}")
252
354
  logger.debug(f"Tools: {tools}")
@@ -264,7 +366,7 @@ def create_graph(
264
366
  if toolkit_name:
265
367
  tool_name = f"{clean_string(toolkit_name)}{TOOLKIT_SPLITTER}{tool_name}"
266
368
  logger.info(f"Node: {node_id} : {node_type} - {tool_name}")
267
- if node_type in ['function', 'tool', 'loop', 'loop_from_tool', 'indexer']:
369
+ if node_type in ['function', 'tool', 'loop', 'loop_from_tool', 'indexer', 'subgraph']:
268
370
  for tool in tools:
269
371
  if tool.name == tool_name:
270
372
  if node_type == 'function':
@@ -274,6 +376,19 @@ def create_graph(
274
376
  input_mapping=node.get('input_mapping',
275
377
  {'messages': {'type': 'variable', 'value': 'messages'}}),
276
378
  input_variables=node.get('input', ['messages'])))
379
+ elif node_type == 'subgraph':
380
+ # assign parent memory/store
381
+ # tool.checkpointer = memory
382
+ # tool.store = store
383
+ # wrap with mappings
384
+ node_fn = SubgraphRunnable(
385
+ inner=tool,
386
+ name=node['id'],
387
+ input_mapping=node.get('input_mapping', {}),
388
+ output_mapping=node.get('output_mapping', {}),
389
+ )
390
+ lg_builder.add_node(node_id, node_fn)
391
+ break # skip legacy handling
277
392
  elif node_type == 'tool':
278
393
  lg_builder.add_node(node_id, ToolNode(
279
394
  client=client, tool=tool,
@@ -394,20 +509,37 @@ def create_graph(
394
509
  interrupt_before = interrupt_before or []
395
510
  interrupt_after = interrupt_after or []
396
511
 
397
- # validate the graph
398
- lg_builder.validate(
399
- interrupt=(
400
- (interrupt_before if interrupt_before != "*" else []) + interrupt_after
401
- if interrupt_after != "*"
402
- else []
512
+ if not for_subgraph:
513
+ # validate the graph for LangGraphAgentRunnable before the actual construction
514
+ lg_builder.validate(
515
+ interrupt=(
516
+ (interrupt_before if interrupt_before != "*" else []) + interrupt_after
517
+ if interrupt_after != "*"
518
+ else []
519
+ )
403
520
  )
521
+
522
+ # Compile into a CompiledStateGraph for the subgraph
523
+ graph = lg_builder.compile(
524
+ checkpointer=True,
525
+ interrupt_before=interrupt_before,
526
+ interrupt_after=interrupt_after,
527
+ store=store,
528
+ debug=debug,
404
529
  )
405
530
  except ValueError as e:
406
531
  raise ValueError(
407
532
  f"Validation of the schema failed. {e}\n\nDEBUG INFO:**Schema Nodes:**\n\n{lg_builder.nodes}\n\n**Schema Enges:**\n\n{lg_builder.edges}\n\n**Tools Available:**\n\n{tools}")
408
- compiled = prepare_output_schema(lg_builder, memory, store, debug,
409
- interrupt_before=interrupt_before,
410
- interrupt_after=interrupt_after)
533
+ # If building a nested subgraph, return the raw CompiledStateGraph
534
+ if for_subgraph:
535
+ return graph
536
+ # Otherwise prepare top-level runnable wrapper and validate
537
+ compiled = prepare_output_schema(
538
+ lg_builder, memory, store, debug,
539
+ interrupt_before=interrupt_before,
540
+ interrupt_after=interrupt_after,
541
+ state_class={state_class: None}
542
+ )
411
543
  return compiled.validate()
412
544
 
413
545
 
@@ -440,8 +572,249 @@ class LangGraphAgentRunnable(CompiledStateGraph):
440
572
  config_state = self.get_state(config)
441
573
  if config_state.next:
442
574
  thread_id = config['configurable']['thread_id']
443
- return {
575
+
576
+ result_with_state = {
444
577
  "output": output,
445
578
  "thread_id": thread_id,
446
579
  "execution_finished": not config_state.next
447
580
  }
581
+
582
+ # Include all state values in the result
583
+ if hasattr(config_state, 'values') and config_state.values:
584
+ for key, value in config_state.values.items():
585
+ result_with_state[key] = value
586
+
587
+ return result_with_state
588
+
589
+ def merge_subgraphs(parent_yaml: str, registry: Dict[str, Dict[str, Any]]) -> Dict[str, Any]:
590
+ """
591
+ Merge subgraphs into parent graph by flattening YAML structures.
592
+
593
+ This function implements the complete flattening approach:
594
+ 1. Parse parent YAML
595
+ 2. Detect subgraph nodes
596
+ 3. Recursively flatten subgraphs
597
+ 4. Merge states, nodes, interrupts, and transitions
598
+ 5. Return single unified graph definition
599
+
600
+ Args:
601
+ parent_yaml: YAML string of parent graph
602
+ registry: Global subgraph registry
603
+
604
+ Returns:
605
+ Dict containing flattened graph definition
606
+ """
607
+ import copy
608
+
609
+ # Parse parent YAML
610
+ parent_def = yaml.safe_load(parent_yaml)
611
+
612
+ # Check if already flattened (prevent infinite recursion)
613
+ if parent_def.get('_flattened', False):
614
+ return parent_def
615
+
616
+ # Find subgraph nodes in parent
617
+ subgraph_nodes = []
618
+ regular_nodes = []
619
+
620
+ for node in parent_def.get('nodes', []):
621
+ if node.get('type') == 'subgraph':
622
+ subgraph_nodes.append(node)
623
+ else:
624
+ regular_nodes.append(node)
625
+
626
+ # If no subgraphs, return as-is
627
+ if not subgraph_nodes:
628
+ parent_def['_flattened'] = True
629
+ return parent_def
630
+
631
+ # Start with parent state and merge subgraph states
632
+ merged_state = copy.deepcopy(parent_def.get('state', {}))
633
+ merged_nodes = copy.deepcopy(regular_nodes)
634
+ merged_interrupts_before = set(parent_def.get('interrupt_before', []))
635
+ merged_interrupts_after = set(parent_def.get('interrupt_after', []))
636
+ all_tools = []
637
+
638
+ # Track node remapping for transition rewiring
639
+ node_mapping = {} # subgraph_node_id -> actual_internal_node_id
640
+
641
+ # Process each subgraph
642
+ for subgraph_node in subgraph_nodes:
643
+ # Support both 'tool' and 'subgraph' fields for subgraph name
644
+ subgraph_name = subgraph_node.get('tool') or subgraph_node.get('subgraph')
645
+ subgraph_node_id = subgraph_node['id']
646
+
647
+ if subgraph_name not in registry:
648
+ logger.warning(f"Subgraph '{subgraph_name}' not found in registry")
649
+ continue
650
+
651
+ # Get subgraph definition
652
+ subgraph_entry = registry[subgraph_name]
653
+ subgraph_yaml = subgraph_entry['yaml']
654
+ subgraph_tools = subgraph_entry.get('tools', [])
655
+
656
+ # Recursively flatten the subgraph (in case it has nested subgraphs)
657
+ subgraph_def = merge_subgraphs(subgraph_yaml, registry)
658
+
659
+ # Collect tools from subgraph
660
+ all_tools.extend(subgraph_tools)
661
+
662
+ # Merge state (union of all fields)
663
+ for field_name, field_type in subgraph_def.get('state', {}).items():
664
+ if field_name not in merged_state:
665
+ merged_state[field_name] = field_type
666
+ elif merged_state[field_name] != field_type:
667
+ logger.warning(f"State field '{field_name}' type mismatch: {merged_state[field_name]} vs {field_type}")
668
+
669
+ # Map subgraph node to its entry point
670
+ subgraph_entry_point = subgraph_def.get('entry_point')
671
+ if subgraph_entry_point:
672
+ node_mapping[subgraph_node_id] = subgraph_entry_point
673
+ logger.debug(f"Mapped subgraph node '{subgraph_node_id}' to entry point '{subgraph_entry_point}'")
674
+
675
+ # Add subgraph nodes without prefixing (keep original IDs)
676
+ for sub_node in subgraph_def.get('nodes', []):
677
+ # Keep original node ID - no prefixing
678
+ new_node = copy.deepcopy(sub_node)
679
+ merged_nodes.append(new_node)
680
+
681
+ # Handle the original subgraph node's transition - apply it to nodes that end with END
682
+ original_transition = subgraph_node.get('transition')
683
+ if original_transition and original_transition != 'END' and original_transition != END:
684
+ # Find nodes in this subgraph that have END transitions and update them
685
+ for node in merged_nodes:
686
+ # Check if this is a node from the current subgraph by checking if it was just added
687
+ # and has an END transition
688
+ if node.get('transition') == 'END' and node in subgraph_def.get('nodes', []):
689
+ node['transition'] = original_transition
690
+
691
+ # Merge interrupts without prefixing (keep original names)
692
+ for interrupt in subgraph_def.get('interrupt_before', []):
693
+ merged_interrupts_before.add(interrupt) # No prefixing
694
+ for interrupt in subgraph_def.get('interrupt_after', []):
695
+ merged_interrupts_after.add(interrupt) # No prefixing
696
+
697
+ # Handle entry point - keep parent's unless it's a subgraph node
698
+ entry_point = parent_def.get('entry_point')
699
+ logger.debug(f"Original entry point: {entry_point}")
700
+ logger.debug(f"Node mapping: {node_mapping}")
701
+ if entry_point in node_mapping:
702
+ # Parent entry point is a subgraph, redirect to subgraph's entry point
703
+ old_entry_point = entry_point
704
+ entry_point = node_mapping[entry_point]
705
+ logger.debug(f"Entry point changed from {old_entry_point} to {entry_point}")
706
+ else:
707
+ logger.debug(f"Entry point {entry_point} not in node mapping, keeping as-is")
708
+
709
+ # Rewrite transitions in regular nodes that point to subgraph nodes
710
+ for node in merged_nodes:
711
+ # Handle direct transitions
712
+ if 'transition' in node:
713
+ transition = node['transition']
714
+ if transition in node_mapping:
715
+ node['transition'] = node_mapping[transition]
716
+
717
+ # Handle conditional transitions
718
+ if 'condition' in node:
719
+ condition = node['condition']
720
+ if 'conditional_outputs' in condition:
721
+ new_outputs = []
722
+ for output in condition['conditional_outputs']:
723
+ if output in node_mapping:
724
+ new_outputs.append(node_mapping[output])
725
+ else:
726
+ new_outputs.append(output)
727
+ condition['conditional_outputs'] = new_outputs
728
+
729
+ if 'default_output' in condition:
730
+ default = condition['default_output']
731
+ if default in node_mapping:
732
+ condition['default_output'] = node_mapping[default]
733
+
734
+ # Update condition_definition Jinja2 template to replace subgraph node references
735
+ if 'condition_definition' in condition:
736
+ condition_definition = condition['condition_definition']
737
+ # Replace subgraph node references in the Jinja2 template
738
+ for subgraph_node_id, subgraph_entry_point in node_mapping.items():
739
+ condition_definition = condition_definition.replace(subgraph_node_id, subgraph_entry_point)
740
+ condition['condition_definition'] = condition_definition
741
+
742
+ # Handle decision nodes
743
+ if 'decision' in node:
744
+ decision = node['decision']
745
+ # Update decision.nodes list to replace subgraph node references
746
+ if 'nodes' in decision:
747
+ new_nodes = []
748
+ for decision_node in decision['nodes']:
749
+ if decision_node in node_mapping:
750
+ new_nodes.append(node_mapping[decision_node])
751
+ else:
752
+ new_nodes.append(decision_node)
753
+ decision['nodes'] = new_nodes
754
+
755
+ # Update decision.default_output to replace subgraph node references
756
+ if 'default_output' in decision:
757
+ default_output = decision['default_output']
758
+ if default_output in node_mapping:
759
+ decision['default_output'] = node_mapping[default_output]
760
+
761
+ # Build final flattened definition
762
+ flattened = {
763
+ 'name': parent_def.get('name', 'FlattenedGraph'),
764
+ 'state': merged_state,
765
+ 'nodes': merged_nodes,
766
+ 'entry_point': entry_point,
767
+ '_flattened': True,
768
+ '_all_tools': all_tools # Store tools for later collection
769
+ }
770
+
771
+ # Add interrupts if present
772
+ if merged_interrupts_before:
773
+ flattened['interrupt_before'] = list(merged_interrupts_before)
774
+ if merged_interrupts_after:
775
+ flattened['interrupt_after'] = list(merged_interrupts_after)
776
+
777
+ return flattened
778
+
779
+
780
+ def detect_and_flatten_subgraphs(yaml_schema: str) -> tuple[str, list]:
781
+ """
782
+ Detect subgraphs in YAML and flatten them if found.
783
+
784
+ Returns:
785
+ tuple: (flattened_yaml_string, collected_tools)
786
+ """
787
+ # Parse to check for subgraphs
788
+ schema_dict = yaml.safe_load(yaml_schema)
789
+ subgraph_nodes = [
790
+ node for node in schema_dict.get('nodes', [])
791
+ if node.get('type') == 'subgraph'
792
+ ]
793
+
794
+ if not subgraph_nodes:
795
+ return yaml_schema, []
796
+
797
+ # Check if all required subgraphs are available in registry
798
+ missing_subgraphs = []
799
+ for node in subgraph_nodes:
800
+ # Support both 'tool' and 'subgraph' fields for subgraph name
801
+ # Don't clean the string - registry keys use original names
802
+ subgraph_name = node.get('tool') or node.get('subgraph')
803
+ if subgraph_name and subgraph_name not in SUBGRAPH_REGISTRY:
804
+ missing_subgraphs.append(subgraph_name)
805
+
806
+ if missing_subgraphs:
807
+ logger.warning(f"Cannot flatten - missing subgraphs: {missing_subgraphs}")
808
+ return yaml_schema, []
809
+
810
+ # Flatten the graph
811
+ flattened_def = merge_subgraphs(yaml_schema, SUBGRAPH_REGISTRY)
812
+
813
+ # Extract tools
814
+ all_tools = flattened_def.pop('_all_tools', [])
815
+
816
+ # Convert back to YAML
817
+ flattened_yaml = yaml.dump(flattened_def, default_flow_style=False)
818
+
819
+ return flattened_yaml, all_tools
820
+
@@ -0,0 +1,53 @@
1
+ from typing import List, Any
2
+
3
+ from langgraph.graph.state import CompiledStateGraph
4
+
5
+ from ..langchain.langraph_agent import create_graph, SUBGRAPH_REGISTRY
6
+ from ..utils.utils import clean_string
7
+
8
+
9
+ class SubgraphToolkit:
10
+
11
+ @staticmethod
12
+ def get_toolkit(
13
+ client: Any,
14
+ application_id: int,
15
+ application_version_id: int,
16
+ llm,
17
+ app_api_key: str,
18
+ selected_tools: list[str] = []
19
+ ) -> List[CompiledStateGraph]:
20
+ from .tools import get_tools
21
+ # from langgraph.checkpoint.memory import MemorySaver
22
+
23
+ app_details = client.get_app_details(application_id)
24
+ version_details = client.get_app_version_details(application_id, application_version_id)
25
+ tools = get_tools(version_details['tools'], alita_client=client, llm=llm)
26
+
27
+ # Get the subgraph name
28
+ subgraph_name = app_details.get("name")
29
+
30
+ # Populate the registry for flattening approach
31
+ SUBGRAPH_REGISTRY[subgraph_name] = {
32
+ 'yaml': version_details['instructions'],
33
+ 'tools': tools,
34
+ 'flattened': False
35
+ }
36
+
37
+ # For backward compatibility, still create a compiled graph stub
38
+ # This is mainly used for identification in the parent graph's tools list
39
+ graph = create_graph(
40
+ client=llm,
41
+ tools=tools,
42
+ yaml_schema=version_details['instructions'],
43
+ debug=False,
44
+ store=None,
45
+ memory=None,
46
+ for_subgraph=True, # compile as raw subgraph
47
+ )
48
+
49
+ # Tag the graph stub for parent lookup
50
+ graph.name = clean_string(subgraph_name)
51
+
52
+ # Return the compiled graph stub for backward compatibility
53
+ return [graph]
@@ -1,14 +1,14 @@
1
1
  import logging
2
2
 
3
- from alita_tools import get_tools as alita_tools
4
3
  from alita_tools import get_toolkits as alita_toolkits
4
+ from alita_tools import get_tools as alita_tools
5
5
 
6
- from .prompt import PromptToolkit
7
- from .datasource import DatasourcesToolkit
8
6
  from .application import ApplicationToolkit
9
7
  from .artifact import ArtifactToolkit
8
+ from .datasource import DatasourcesToolkit
9
+ from .prompt import PromptToolkit
10
+ from .subgraph import SubgraphToolkit
10
11
  from .vectorstore import VectorStoreToolkit
11
-
12
12
  ## Community tools and toolkits
13
13
  from ..community.analysis.jira_analyse import AnalyseJira
14
14
  from ..community.browseruse import BrowserUseToolkit
@@ -35,7 +35,7 @@ def get_toolkits():
35
35
  return core_toolkits + community_toolkits + alita_toolkits()
36
36
 
37
37
 
38
- def get_tools(tools_list: list, alita: 'AlitaClient', llm: 'LLMLikeObject') -> list:
38
+ def get_tools(tools_list: list, alita_client, llm) -> list:
39
39
  prompts = []
40
40
  tools = []
41
41
 
@@ -47,33 +47,43 @@ def get_tools(tools_list: list, alita: 'AlitaClient', llm: 'LLMLikeObject') -> l
47
47
  ])
48
48
  elif tool['type'] == 'datasource':
49
49
  tools.extend(DatasourcesToolkit.get_toolkit(
50
- alita,
50
+ alita_client,
51
51
  datasource_ids=[int(tool['settings']['datasource_id'])],
52
52
  selected_tools=tool['settings']['selected_tools'],
53
53
  toolkit_name=tool.get('toolkit_name', '') or tool.get('name', '')
54
54
  ).get_tools())
55
55
  elif tool['type'] == 'application':
56
56
  tools.extend(ApplicationToolkit.get_toolkit(
57
- alita,
57
+ alita_client,
58
58
  application_id=int(tool['settings']['application_id']),
59
59
  application_version_id=int(tool['settings']['application_version_id']),
60
- app_api_key=alita.auth_token,
60
+ app_api_key=alita_client.auth_token,
61
61
  selected_tools=[]
62
62
  ).get_tools())
63
+ elif tool['type'] == 'subgraph':
64
+ # static get_toolkit returns a list of CompiledStateGraph stubs
65
+ tools.extend(SubgraphToolkit.get_toolkit(
66
+ alita_client,
67
+ application_id=int(tool['settings']['application_id']),
68
+ application_version_id=int(tool['settings']['application_version_id']),
69
+ app_api_key=alita_client.auth_token,
70
+ selected_tools=[],
71
+ llm=llm
72
+ ))
63
73
  elif tool['type'] == 'artifact':
64
74
  tools.extend(ArtifactToolkit.get_toolkit(
65
- client=alita,
75
+ client=alita_client,
66
76
  bucket=tool['settings']['bucket'],
67
77
  toolkit_name=tool.get('toolkit_name', ''),
68
78
  selected_tools=tool['settings'].get('selected_tools', [])
69
79
  ).get_tools())
70
80
  if tool['type'] == 'analyse_jira':
71
81
  tools.extend(AnalyseJira.get_toolkit(
72
- client=alita,
82
+ client=alita_client,
73
83
  **tool['settings']).get_tools())
74
84
  if tool['type'] == 'browser_use':
75
85
  tools.extend(BrowserUseToolkit.get_toolkit(
76
- client=alita,
86
+ client=alita_client,
77
87
  llm=llm,
78
88
  toolkit_name=tool.get('toolkit_name', ''),
79
89
  **tool['settings']).get_tools())
@@ -83,33 +93,37 @@ def get_tools(tools_list: list, alita: 'AlitaClient', llm: 'LLMLikeObject') -> l
83
93
  toolkit_name=tool.get('toolkit_name', ''),
84
94
  **tool['settings']).get_tools())
85
95
  if len(prompts) > 0:
86
- tools += PromptToolkit.get_toolkit(alita, prompts).get_tools()
87
- tools += alita_tools(tools_list, alita, llm)
88
- tools += _mcp_tools(tools_list, alita)
96
+ tools += PromptToolkit.get_toolkit(alita_client, prompts).get_tools()
97
+ tools += alita_tools(tools_list, alita_client, llm)
98
+ tools += _mcp_tools(tools_list, alita_client)
89
99
  return tools
90
100
 
91
101
 
92
102
  def _mcp_tools(tools_list, alita):
93
- all_available_toolkits = alita.get_mcp_toolkits()
94
- toolkit_lookup = {tk["name"].lower(): tk for tk in all_available_toolkits}
95
- tools = []
96
- #
97
- for selected_toolkit in tools_list:
98
- toolkit_name = selected_toolkit['type'].lower()
99
- toolkit_conf = toolkit_lookup.get(toolkit_name)
100
- #
101
- if not toolkit_conf:
102
- logger.warning(f"Toolkit '{toolkit_name}' not found in available toolkits.")
103
- continue
103
+ try:
104
+ all_available_toolkits = alita.get_mcp_toolkits()
105
+ toolkit_lookup = {tk["name"].lower(): tk for tk in all_available_toolkits}
106
+ tools = []
104
107
  #
105
- available_tools = toolkit_conf.get("tools", [])
106
- selected_tools = [name.lower() for name in selected_toolkit['settings'].get('selected_tools', [])]
107
- for available_tool in available_tools:
108
- tool_name = available_tool.get("name", "").lower()
109
- if not selected_tools or tool_name in selected_tools:
110
- if server_tool := _init_single_mcp_tool(toolkit_name, available_tool, alita, selected_toolkit['settings']):
111
- tools.append(server_tool)
112
- return tools
108
+ for selected_toolkit in tools_list:
109
+ toolkit_name = selected_toolkit['type'].lower()
110
+ toolkit_conf = toolkit_lookup.get(toolkit_name)
111
+ #
112
+ if not toolkit_conf:
113
+ logger.warning(f"Toolkit '{toolkit_name}' not found in available toolkits.")
114
+ continue
115
+ #
116
+ available_tools = toolkit_conf.get("tools", [])
117
+ selected_tools = [name.lower() for name in selected_toolkit['settings'].get('selected_tools', [])]
118
+ for available_tool in available_tools:
119
+ tool_name = available_tool.get("name", "").lower()
120
+ if not selected_tools or tool_name in selected_tools:
121
+ if server_tool := _init_single_mcp_tool(toolkit_name, available_tool, alita, selected_toolkit['settings']):
122
+ tools.append(server_tool)
123
+ return tools
124
+ except Exception:
125
+ logger.error("Error while fetching MCP tools", exc_info=True)
126
+ return []
113
127
 
114
128
 
115
129
  def _init_single_mcp_tool(toolkit_name, available_tool, alita, toolkit_settings):
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: alita_sdk
3
- Version: 0.3.124
3
+ Version: 0.3.126
4
4
  Summary: SDK for building langchain agents using resouces from Alita
5
5
  Author-email: Artem Rozumenko <artyom.rozumenko@gmail.com>, Mikalai Biazruchka <mikalai_biazruchka@epam.com>, Roman Mitusov <roman_mitusov@epam.com>, Ivan Krakhmaliuk <lifedjik@gmail.com>
6
6
  Project-URL: Homepage, https://projectalita.ai
@@ -15,11 +15,11 @@ alita_sdk/community/analysis/jira_analyse/api_wrapper.py,sha256=JqGSxg_3x0ErzII3
15
15
  alita_sdk/community/browseruse/__init__.py,sha256=uAxPZEX7ihpt8HtcGDFrzTNv9WcklT1wG1ItTwUO8y4,3601
16
16
  alita_sdk/community/browseruse/api_wrapper.py,sha256=Y05NKWfTROPmBxe8ZFIELSGBX5v3RTNP30OTO2Tj8uI,10838
17
17
  alita_sdk/langchain/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
18
- alita_sdk/langchain/assistant.py,sha256=V0MpZG9IQRlKMr1SfAabVbUsIp-gNPO9A1SWxuFikNY,5930
18
+ alita_sdk/langchain/assistant.py,sha256=J_xhwbNl934BgDKSpAMC9a1u6v03DZQcTYaamCztEPk,6272
19
19
  alita_sdk/langchain/chat_message_template.py,sha256=kPz8W2BG6IMyITFDA5oeb5BxVRkHEVZhuiGl4MBZKdc,2176
20
20
  alita_sdk/langchain/constants.py,sha256=eHVJ_beJNTf1WJo4yq7KMK64fxsRvs3lKc34QCXSbpk,3319
21
21
  alita_sdk/langchain/indexer.py,sha256=0ENHy5EOhThnAiYFc7QAsaTNp9rr8hDV_hTK8ahbatk,37592
22
- alita_sdk/langchain/langraph_agent.py,sha256=f9rGk6QGQbfbGudwpM5ax9yS-xDlElFqGAOsZbGvrtI,20919
22
+ alita_sdk/langchain/langraph_agent.py,sha256=HwopuxCWDOg6i-ZKbxZzrqnRZ84pGIS7kVN349ER8bs,36510
23
23
  alita_sdk/langchain/mixedAgentParser.py,sha256=M256lvtsL3YtYflBCEp-rWKrKtcY1dJIyRGVv7KW9ME,2611
24
24
  alita_sdk/langchain/mixedAgentRenderes.py,sha256=asBtKqm88QhZRILditjYICwFVKF5KfO38hu2O-WrSWE,5964
25
25
  alita_sdk/langchain/utils.py,sha256=Npferkn10dvdksnKzLJLBI5bNGQyVWTBwqp3vQtUqmY,6631
@@ -68,7 +68,8 @@ alita_sdk/toolkits/application.py,sha256=LrxbBV05lkRP3_WtKGBKtMdoQHXVY-_AtFr1cUu
68
68
  alita_sdk/toolkits/artifact.py,sha256=7zb17vhJ3CigeTqvzQ4VNBsU5UOCJqAwz7fOJGMYqXw,2348
69
69
  alita_sdk/toolkits/datasource.py,sha256=v3FQu8Gmvq7gAGAnFEbA8qofyUhh98rxgIjY6GHBfyI,2494
70
70
  alita_sdk/toolkits/prompt.py,sha256=WIpTkkVYWqIqOWR_LlSWz3ug8uO9tm5jJ7aZYdiGRn0,1192
71
- alita_sdk/toolkits/tools.py,sha256=ZYi-oZwezaqcZe6Ft6w2hlpbuCKYyDNjl3WRWViVT08,5074
71
+ alita_sdk/toolkits/subgraph.py,sha256=ZYqI4yVLbEPAjCR8dpXbjbL2ipX598Hk3fL6AgaqFD4,1758
72
+ alita_sdk/toolkits/tools.py,sha256=gk3nvQBdab3QM8v93ff2nrN4ZfcT779yae2RygkTl8s,5834
72
73
  alita_sdk/toolkits/vectorstore.py,sha256=di08-CRl0KJ9xSZ8_24VVnPZy58iLqHtXW8vuF29P64,2893
73
74
  alita_sdk/tools/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
74
75
  alita_sdk/tools/application.py,sha256=UJlYd3Sub10LpAoKkKEpvd4miWyrS-yYE5NKyqx-H4Q,2194
@@ -92,10 +93,10 @@ alita_sdk/utils/evaluate.py,sha256=iM1P8gzBLHTuSCe85_Ng_h30m52hFuGuhNXJ7kB1tgI,1
92
93
  alita_sdk/utils/logging.py,sha256=hBE3qAzmcLMdamMp2YRXwOOK9P4lmNaNhM76kntVljs,3124
93
94
  alita_sdk/utils/streamlit.py,sha256=zp8owZwHI3HZplhcExJf6R3-APtWx-z6s5jznT2hY_k,29124
94
95
  alita_sdk/utils/utils.py,sha256=dM8whOJAuFJFe19qJ69-FLzrUp6d2G-G6L7d4ss2XqM,346
95
- alita_sdk-0.3.124.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
96
+ alita_sdk-0.3.126.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
96
97
  tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
97
98
  tests/test_jira_analysis.py,sha256=I0cErH5R_dHVyutpXrM1QEo7jfBuKWTmDQvJBPjx18I,3281
98
- alita_sdk-0.3.124.dist-info/METADATA,sha256=SLr3I6hYo-6OFuiJYygEXnREyu4nt2GlWBoqh4jHw6M,7075
99
- alita_sdk-0.3.124.dist-info/WHEEL,sha256=zaaOINJESkSfm_4HQVc5ssNzHCPXhJm0kEUakpsEHaU,91
100
- alita_sdk-0.3.124.dist-info/top_level.txt,sha256=SWRhxB7Et3cOy3RkE5hR7OIRnHoo3K8EXzoiNlkfOmc,25
101
- alita_sdk-0.3.124.dist-info/RECORD,,
99
+ alita_sdk-0.3.126.dist-info/METADATA,sha256=eX2afqBm4mw5_LMMdJ_HXMXHTp93O8bWSIQ3jCVy8go,7075
100
+ alita_sdk-0.3.126.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
101
+ alita_sdk-0.3.126.dist-info/top_level.txt,sha256=SWRhxB7Et3cOy3RkE5hR7OIRnHoo3K8EXzoiNlkfOmc,25
102
+ alita_sdk-0.3.126.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (80.8.0)
2
+ Generator: setuptools (80.9.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5