alita-sdk 0.3.125__py3-none-any.whl → 0.3.127__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
alita_sdk/langchain/assistant.py CHANGED
@@ -1,4 +1,3 @@
-
 import logging
 import importlib
 from copy import deepcopy as copy
@@ -17,6 +16,7 @@ from .constants import REACT_ADDON, REACT_VARS, XML_ADDON
 from .chat_message_template import Jinja2TemplatedChatMessagesTemplate
 from ..tools.echo import EchoTool
 from ..toolkits.tools import get_tools
+from langchain_core.tools import BaseTool
 
 logger = logging.getLogger(__name__)
 
@@ -53,7 +53,7 @@ class Assistant:
             target_name
         )
         self.client = target_cls(**model_params)
-        self.tools = get_tools(data['tools'], alita=alita, llm=self.client)
+        self.tools = get_tools(data['tools'], alita_client=alita, llm=self.client)
         if app_type == "pipeline":
             self.prompt = data['instructions']
         else:
@@ -106,18 +106,22 @@ class Assistant:
                                       max_execution_time=None, return_intermediate_steps=True)
 
     def getAgentExecutor(self):
-        agent = create_json_chat_agent(llm=self.client, tools=self.tools, prompt=self.prompt,
-                                       #tools_renderer=render_react_text_description_and_args
-                                       )
+        # Exclude compiled graph runnables from simple tool agents
+        simple_tools = [t for t in self.tools if isinstance(t, BaseTool)]
+        agent = create_json_chat_agent(llm=self.client, tools=simple_tools, prompt=self.prompt)
         return self._agent_executor(agent)
 
 
     def getXMLAgentExecutor(self):
-        agent = create_xml_chat_agent(llm=self.client, tools=self.tools, prompt=self.prompt)
+        # Exclude compiled graph runnables from simple tool agents
+        simple_tools = [t for t in self.tools if isinstance(t, BaseTool)]
+        agent = create_xml_chat_agent(llm=self.client, tools=simple_tools, prompt=self.prompt)
         return self._agent_executor(agent)
 
     def getOpenAIToolsAgentExecutor(self):
-        agent = create_openai_tools_agent(llm=self.client, tools=self.tools, prompt=self.prompt)
+        # Exclude compiled graph runnables from simple tool agents
+        simple_tools = [t for t in self.tools if isinstance(t, BaseTool)]
+        agent = create_openai_tools_agent(llm=self.client, tools=simple_tools, prompt=self.prompt)
         return self._agent_executor(agent)
 
     def pipeline(self):
alita_sdk/langchain/langraph_agent.py CHANGED
@@ -1,8 +1,10 @@
 import logging
 from typing import Union, Any, Optional, Annotated, get_type_hints
 from uuid import uuid4
+from typing import Dict
 
 import yaml
+import ast
 from langchain_core.callbacks import dispatch_custom_event
 from langchain_core.messages import HumanMessage
 from langchain_core.runnables import Runnable
@@ -17,7 +19,7 @@ from langgraph.prebuilt import InjectedStore
 from langgraph.store.base import BaseStore
 
 from .mixedAgentRenderes import convert_message_to_json
-from .utils import create_state
+from .utils import create_state, propagate_the_input_mapping
 from ..tools.function import FunctionTool
 from ..tools.indexer_tool import IndexerNode
 from ..tools.llm import LLMNode
@@ -30,15 +32,101 @@ from ..tools.router import RouterNode
 
 logger = logging.getLogger(__name__)
 
+# Global registry for subgraph definitions
+# Structure: {'subgraph_name': {'yaml': 'yaml_def', 'tools': [tools], 'flattened': False}}
+SUBGRAPH_REGISTRY: Dict[str, Dict[str, Any]] = {}
+
+
+# Wrapper for injecting a compiled subgraph into a parent StateGraph
+class SubgraphRunnable(CompiledStateGraph):
+    def __init__(
+        self,
+        inner: CompiledStateGraph,
+        *,
+        name: str,
+        input_mapping: Dict[str, Any],
+        output_mapping: Dict[str, Any]
+    ):
+        # copy child graph internals
+        super().__init__(
+            builder=inner.builder,
+            config_type=inner.config_type,
+            nodes=inner.nodes,
+            channels=inner.channels,
+            input_channels=inner.input_channels,
+            stream_mode=inner.stream_mode,
+            output_channels=inner.output_channels,
+            stream_channels=inner.stream_channels,
+            checkpointer=inner.checkpointer,
+            interrupt_before_nodes=inner.interrupt_before_nodes,
+            interrupt_after_nodes=inner.interrupt_after_nodes,
+            auto_validate=False,
+            debug=inner.debug,
+            store=inner.store,
+        )
+        self.inner = inner
+        self.name = name
+        self.input_mapping = input_mapping or {}
+        self.output_mapping = output_mapping or {}
+
+    def invoke(
+        self,
+        state: Union[dict[str, Any], Any],
+        config: Optional[RunnableConfig] = None,
+        **kwargs: Any,
+    ) -> Union[dict[str, Any], Any]:
+        # Detailed logging for debugging
+        logger.debug(f"SubgraphRunnable '{self.name}' invoke called with state: {state}")
+        logger.debug(f"SubgraphRunnable '{self.name}' config: {config}")
+
+        # 1) parent -> child mapping
+        if not self.input_mapping:
+            child_input = state.copy()
+        else:
+            child_input = propagate_the_input_mapping(
+                self.input_mapping, list(self.input_mapping.keys()), state
+            )
+        # debug trace of messages flowing into child
+        logger.debug(f"SubgraphRunnable '{self.name}' child_input.messages: {child_input.get('messages')}")
+        logger.debug(f"SubgraphRunnable '{self.name}' child_input.input: {child_input.get('input')}")
+
+        # 2) Invoke the child graph.
+        # Pass None as the first argument for input if the child is expected to resume
+        # using its (now updated) checkpoint. The CompiledStateGraph.invoke method, when
+        # input is None but a checkpoint exists, loads from the checkpoint.
+        # Any resume commands (if applicable for internal child interrupts) are in 'config'.
+        # logger.debug(f"SubgraphRunnable '{self.name}': Invoking child graph super().invoke(None, config).")
+        subgraph_output = super().invoke(child_input, config=config, **kwargs)
+
+        # 3) child complete: apply output_mapping or passthrough
+        logger.debug(f"SubgraphRunnable '{self.name}' child complete, applying mappings")
+        result: Dict[str, Any] = {}
+        if self.output_mapping:
+            for child_key, parent_key in self.output_mapping.items():
+                if child_key in subgraph_output:
+                    state[parent_key] = subgraph_output[child_key]
+                    result[parent_key] = subgraph_output[child_key]
+                    logger.debug(f"SubgraphRunnable '{self.name}' mapped {child_key} -> {parent_key}")
+        else:
+            for k, v in subgraph_output.items():
+                state[k] = v
+                result[k] = v
+
+        # include full messages history on completion
+        if 'messages' not in result:
+            result['messages'] = subgraph_output.get('messages', [])
+        logger.debug(f"SubgraphRunnable '{self.name}' returning result: {result}")
+        return result
+
 
 class ConditionalEdge(Runnable):
     name = "ConditionalEdge"
 
     def __init__(self, condition: str, condition_inputs: Optional[list[str]] = [],
-                 conditional_outputs: Optional[list[str]] = [], default_output: str = 'END'):
+                 conditional_outputs: Optional[list[str]] = [], default_output: str = END):
         self.condition = condition
         self.condition_inputs = condition_inputs
-        self.conditional_outputs = {clean_string(cond) for cond in conditional_outputs}
+        self.conditional_outputs = {clean_string(cond if not 'END' == cond else '__end__') for cond in conditional_outputs}
         self.default_output = clean_string(default_output)
 
     def invoke(self, state: Annotated[BaseStore, InjectedStore()], config: Optional[RunnableConfig] = None) -> str:
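
The contract of `SubgraphRunnable.invoke` above: with no `input_mapping` the parent state is passed through unchanged, and with an `output_mapping` only the mapped keys (plus `messages`) flow back into the parent state. A self-contained sketch of that dictionary plumbing; the `variable`-mapping behavior mimics what `propagate_the_input_mapping` appears to do, which is an assumption since its source is not part of this diff:

    from typing import Any, Dict

    def map_inputs(input_mapping: Dict[str, Any], parent_state: dict) -> dict:
        # Assumed behavior: {'type': 'variable', 'value': <parent_key>} copies
        # parent_state[<parent_key>] into the child key.
        child = {}
        for child_key, spec in input_mapping.items():
            if spec.get('type') == 'variable':
                child[child_key] = parent_state.get(spec['value'])
        return child

    def map_outputs(output_mapping: Dict[str, str], child_output: dict, parent_state: dict) -> dict:
        # Mirrors step 3 of SubgraphRunnable.invoke: child_key -> parent_key.
        result = {}
        for child_key, parent_key in output_mapping.items():
            if child_key in child_output:
                parent_state[parent_key] = child_output[child_key]
                result[parent_key] = child_output[child_key]
        if 'messages' not in result:
            result['messages'] = child_output.get('messages', [])
        return result

    parent = {'task': 'summarize', 'messages': []}
    child_in = map_inputs({'input': {'type': 'variable', 'value': 'task'}}, parent)
    child_out = {'answer': 'done', 'messages': ['hi']}
    print(map_outputs({'answer': 'summary'}, child_out, parent))  # {'summary': 'done', 'messages': ['hi']}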
@@ -124,6 +212,25 @@ class TransitionalEdge(Runnable):
         )
         return self.next_step if self.next_step != 'END' else END
 
+class StateDefaultNode(Runnable):
+    name = "StateDefaultNode"
+
+    def __init__(self, default_vars: dict = {}):
+        self.default_vars = default_vars
+
+    def invoke(self, state: BaseStore, config: Optional[RunnableConfig] = None) -> dict:
+        logger.info("Setting default state variables")
+        result = {}
+        for key, value in self.default_vars.items():
+            if isinstance(value, dict) and 'value' in value:
+                temp_value = value['value']
+                try:
+                    result[key] = ast.literal_eval(temp_value)
+                except:
+                    logger.debug("Unable to evaluate value, using as is")
+                    result[key] = temp_value
+        return result
+
 
 
 class StateModifierNode(Runnable):
@@ -177,8 +284,13 @@ class StateModifierNode(Runnable):
         return result
 
 
-def prepare_output_schema(lg_builder, memory, store, debug=False, interrupt_before=[], interrupt_after=[]):
+
+def prepare_output_schema(lg_builder, memory, store, debug=False, interrupt_before=None, interrupt_after=None, state_class=None):
     # prepare output channels
+    if interrupt_after is None:
+        interrupt_after = []
+    if interrupt_before is None:
+        interrupt_before = []
     output_channels = (
         "__root__"
         if len(lg_builder.schemas[lg_builder.output]) == 1
@@ -239,19 +351,30 @@ def prepare_output_schema(lg_builder, memory, store, debug=False, interrupt_befo
 def create_graph(
         client: Any,
         yaml_schema: str,
-        tools: list[BaseTool],
+        tools: list[Union[BaseTool, CompiledStateGraph]],
         *args,
         memory: Optional[Any] = None,
        store: Optional[BaseStore] = None,
        debug: bool = False,
+        for_subgraph: bool = False,
        **kwargs
 ):
     """ Create a message graph from a yaml schema """
+
+    # For top-level graphs (not subgraphs), detect and flatten any subgraphs
+    if not for_subgraph:
+        flattened_yaml, additional_tools = detect_and_flatten_subgraphs(yaml_schema)
+        # Add collected tools from subgraphs to the tools list
+        tools = list(tools) + additional_tools
+        # Use the flattened YAML for building the graph
+        yaml_schema = flattened_yaml
+
     schema = yaml.safe_load(yaml_schema)
     logger.debug(f"Schema: {schema}")
     logger.debug(f"Tools: {tools}")
     logger.info(f"Tools: {[tool.name for tool in tools]}")
-    state_class = create_state(schema.get('state', {}))
+    state = schema.get('state', {})
+    state_class = create_state(state)
     lg_builder = StateGraph(state_class)
     interrupt_before = [clean_string(every) for every in schema.get('interrupt_before', [])]
     interrupt_after = [clean_string(every) for every in schema.get('interrupt_after', [])]
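
The `state` dict split out here feeds two things: `create_state`, and, later in this function, the decision to prepend a `StateDefaultNode` ahead of the declared entry point. A runnable sketch of that decision against a guessed-at minimal schema (the keys follow exactly what this function reads; the YAML itself is illustrative, not a documented example):

    import yaml

    yaml_schema = """
    name: demo
    state:
      chat_history:
        type: list
        value: "[]"
    entry_point: call_llm
    nodes:
      - id: call_llm
        type: llm
        transition: END
    """

    schema = yaml.safe_load(yaml_schema)
    state = schema.get('state', {})
    # Same check create_graph uses to decide whether to prepend StateDefaultNode:
    needs_default_node = any('type' in v and 'value' in v for v in state.values())
    print(needs_default_node)  # True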
@@ -264,7 +387,7 @@ def create_graph(
             if toolkit_name:
                 tool_name = f"{clean_string(toolkit_name)}{TOOLKIT_SPLITTER}{tool_name}"
             logger.info(f"Node: {node_id} : {node_type} - {tool_name}")
-            if node_type in ['function', 'tool', 'loop', 'loop_from_tool', 'indexer']:
+            if node_type in ['function', 'tool', 'loop', 'loop_from_tool', 'indexer', 'subgraph']:
                 for tool in tools:
                     if tool.name == tool_name:
                         if node_type == 'function':
@@ -274,6 +397,19 @@ def create_graph(
                                 input_mapping=node.get('input_mapping',
                                                        {'messages': {'type': 'variable', 'value': 'messages'}}),
                                 input_variables=node.get('input', ['messages'])))
+                        elif node_type == 'subgraph':
+                            # assign parent memory/store
+                            # tool.checkpointer = memory
+                            # tool.store = store
+                            # wrap with mappings
+                            node_fn = SubgraphRunnable(
+                                inner=tool,
+                                name=node['id'],
+                                input_mapping=node.get('input_mapping', {}),
+                                output_mapping=node.get('output_mapping', {}),
+                            )
+                            lg_builder.add_node(node_id, node_fn)
+                            break  # skip legacy handling
                         elif node_type == 'tool':
                             lg_builder.add_node(node_id, ToolNode(
                                 client=client, tool=tool,
@@ -388,26 +524,54 @@ def create_graph(
                     conditional_outputs=node['condition'].get('conditional_outputs', []),
                     default_output=node['condition'].get('default_output', 'END')))
 
-        lg_builder.set_entry_point(clean_string(schema['entry_point']))
+        # set default value for state variable at START
+        entry_point = clean_string(schema['entry_point'])
+        for key, value in state.items():
+            if 'type' in value and 'value' in value:
+                # set default value for state variable if it is defined in the schema
+                state_default_node = StateDefaultNode(default_vars=state)
+                lg_builder.add_node(state_default_node.name, state_default_node)
+                lg_builder.set_entry_point(state_default_node.name)
+                lg_builder.add_conditional_edges(state_default_node.name, TransitionalEdge(entry_point))
+                break
+        else:
+            # if no state variables are defined, set the entry point directly
+            lg_builder.set_entry_point(entry_point)
 
-        # assign default values
         interrupt_before = interrupt_before or []
         interrupt_after = interrupt_after or []
 
-        # validate the graph
-        lg_builder.validate(
-            interrupt=(
-                (interrupt_before if interrupt_before != "*" else []) + interrupt_after
-                if interrupt_after != "*"
-                else []
+        if not for_subgraph:
+            # validate the graph for LangGraphAgentRunnable before the actual construction
+            lg_builder.validate(
+                interrupt=(
+                    (interrupt_before if interrupt_before != "*" else []) + interrupt_after
+                    if interrupt_after != "*"
+                    else []
+                )
             )
+
+        # Compile into a CompiledStateGraph for the subgraph
+        graph = lg_builder.compile(
+            checkpointer=True,
+            interrupt_before=interrupt_before,
+            interrupt_after=interrupt_after,
+            store=store,
+            debug=debug,
         )
     except ValueError as e:
         raise ValueError(
             f"Validation of the schema failed. {e}\n\nDEBUG INFO:**Schema Nodes:**\n\n{lg_builder.nodes}\n\n**Schema Enges:**\n\n{lg_builder.edges}\n\n**Tools Available:**\n\n{tools}")
-    compiled = prepare_output_schema(lg_builder, memory, store, debug,
-                                     interrupt_before=interrupt_before,
-                                     interrupt_after=interrupt_after)
+    # If building a nested subgraph, return the raw CompiledStateGraph
+    if for_subgraph:
+        return graph
+    # Otherwise prepare top-level runnable wrapper and validate
+    compiled = prepare_output_schema(
+        lg_builder, memory, store, debug,
+        interrupt_before=interrupt_before,
+        interrupt_after=interrupt_after,
+        state_class={state_class: None}
+    )
     return compiled.validate()
 
 
@@ -440,8 +604,249 @@ class LangGraphAgentRunnable(CompiledStateGraph):
             config_state = self.get_state(config)
             if config_state.next:
                 thread_id = config['configurable']['thread_id']
-        return {
+
+        result_with_state = {
             "output": output,
             "thread_id": thread_id,
             "execution_finished": not config_state.next
         }
+
+        # Include all state values in the result
+        if hasattr(config_state, 'values') and config_state.values:
+            for key, value in config_state.values.items():
+                result_with_state[key] = value
+
+        return result_with_state
+
+def merge_subgraphs(parent_yaml: str, registry: Dict[str, Dict[str, Any]]) -> Dict[str, Any]:
+    """
+    Merge subgraphs into parent graph by flattening YAML structures.
+
+    This function implements the complete flattening approach:
+    1. Parse parent YAML
+    2. Detect subgraph nodes
+    3. Recursively flatten subgraphs
+    4. Merge states, nodes, interrupts, and transitions
+    5. Return single unified graph definition
+
+    Args:
+        parent_yaml: YAML string of parent graph
+        registry: Global subgraph registry
+
+    Returns:
+        Dict containing flattened graph definition
+    """
+    import copy
+
+    # Parse parent YAML
+    parent_def = yaml.safe_load(parent_yaml)
+
+    # Check if already flattened (prevent infinite recursion)
+    if parent_def.get('_flattened', False):
+        return parent_def
+
+    # Find subgraph nodes in parent
+    subgraph_nodes = []
+    regular_nodes = []
+
+    for node in parent_def.get('nodes', []):
+        if node.get('type') == 'subgraph':
+            subgraph_nodes.append(node)
+        else:
+            regular_nodes.append(node)
+
+    # If no subgraphs, return as-is
+    if not subgraph_nodes:
+        parent_def['_flattened'] = True
+        return parent_def
+
+    # Start with parent state and merge subgraph states
+    merged_state = copy.deepcopy(parent_def.get('state', {}))
+    merged_nodes = copy.deepcopy(regular_nodes)
+    merged_interrupts_before = set(parent_def.get('interrupt_before', []))
+    merged_interrupts_after = set(parent_def.get('interrupt_after', []))
+    all_tools = []
+
+    # Track node remapping for transition rewiring
+    node_mapping = {}  # subgraph_node_id -> actual_internal_node_id
+
+    # Process each subgraph
+    for subgraph_node in subgraph_nodes:
+        # Support both 'tool' and 'subgraph' fields for subgraph name
+        subgraph_name = subgraph_node.get('tool') or subgraph_node.get('subgraph')
+        subgraph_node_id = subgraph_node['id']
+
+        if subgraph_name not in registry:
+            logger.warning(f"Subgraph '{subgraph_name}' not found in registry")
+            continue
+
+        # Get subgraph definition
+        subgraph_entry = registry[subgraph_name]
+        subgraph_yaml = subgraph_entry['yaml']
+        subgraph_tools = subgraph_entry.get('tools', [])
+
+        # Recursively flatten the subgraph (in case it has nested subgraphs)
+        subgraph_def = merge_subgraphs(subgraph_yaml, registry)
+
+        # Collect tools from subgraph
+        all_tools.extend(subgraph_tools)
+
+        # Merge state (union of all fields)
+        for field_name, field_type in subgraph_def.get('state', {}).items():
+            if field_name not in merged_state:
+                merged_state[field_name] = field_type
+            elif merged_state[field_name] != field_type:
+                logger.warning(f"State field '{field_name}' type mismatch: {merged_state[field_name]} vs {field_type}")
+
+        # Map subgraph node to its entry point
+        subgraph_entry_point = subgraph_def.get('entry_point')
+        if subgraph_entry_point:
+            node_mapping[subgraph_node_id] = subgraph_entry_point
+            logger.debug(f"Mapped subgraph node '{subgraph_node_id}' to entry point '{subgraph_entry_point}'")
+
+        # Add subgraph nodes without prefixing (keep original IDs)
+        for sub_node in subgraph_def.get('nodes', []):
+            # Keep original node ID - no prefixing
+            new_node = copy.deepcopy(sub_node)
+            merged_nodes.append(new_node)
+
+        # Handle the original subgraph node's transition - apply it to nodes that end with END
+        original_transition = subgraph_node.get('transition')
+        if original_transition and original_transition != 'END' and original_transition != END:
+            # Find nodes in this subgraph that have END transitions and update them
+            for node in merged_nodes:
+                # Check if this is a node from the current subgraph by checking if it was just added
+                # and has an END transition
+                if node.get('transition') == 'END' and node in subgraph_def.get('nodes', []):
+                    node['transition'] = original_transition
+
+        # Merge interrupts without prefixing (keep original names)
+        for interrupt in subgraph_def.get('interrupt_before', []):
+            merged_interrupts_before.add(interrupt)  # No prefixing
+        for interrupt in subgraph_def.get('interrupt_after', []):
+            merged_interrupts_after.add(interrupt)  # No prefixing
+
+    # Handle entry point - keep parent's unless it's a subgraph node
+    entry_point = parent_def.get('entry_point')
+    logger.debug(f"Original entry point: {entry_point}")
+    logger.debug(f"Node mapping: {node_mapping}")
+    if entry_point in node_mapping:
+        # Parent entry point is a subgraph, redirect to subgraph's entry point
+        old_entry_point = entry_point
+        entry_point = node_mapping[entry_point]
+        logger.debug(f"Entry point changed from {old_entry_point} to {entry_point}")
+    else:
+        logger.debug(f"Entry point {entry_point} not in node mapping, keeping as-is")
+
+    # Rewrite transitions in regular nodes that point to subgraph nodes
+    for node in merged_nodes:
+        # Handle direct transitions
+        if 'transition' in node:
+            transition = node['transition']
+            if transition in node_mapping:
+                node['transition'] = node_mapping[transition]
+
+        # Handle conditional transitions
+        if 'condition' in node:
+            condition = node['condition']
+            if 'conditional_outputs' in condition:
+                new_outputs = []
+                for output in condition['conditional_outputs']:
+                    if output in node_mapping:
+                        new_outputs.append(node_mapping[output])
+                    else:
+                        new_outputs.append(output)
+                condition['conditional_outputs'] = new_outputs
+
+            if 'default_output' in condition:
+                default = condition['default_output']
+                if default in node_mapping:
+                    condition['default_output'] = node_mapping[default]
+
+            # Update condition_definition Jinja2 template to replace subgraph node references
+            if 'condition_definition' in condition:
+                condition_definition = condition['condition_definition']
+                # Replace subgraph node references in the Jinja2 template
+                for subgraph_node_id, subgraph_entry_point in node_mapping.items():
+                    condition_definition = condition_definition.replace(subgraph_node_id, subgraph_entry_point)
+                condition['condition_definition'] = condition_definition
+
+        # Handle decision nodes
+        if 'decision' in node:
+            decision = node['decision']
+            # Update decision.nodes list to replace subgraph node references
+            if 'nodes' in decision:
+                new_nodes = []
+                for decision_node in decision['nodes']:
+                    if decision_node in node_mapping:
+                        new_nodes.append(node_mapping[decision_node])
+                    else:
+                        new_nodes.append(decision_node)
+                decision['nodes'] = new_nodes
+
+            # Update decision.default_output to replace subgraph node references
+            if 'default_output' in decision:
+                default_output = decision['default_output']
+                if default_output in node_mapping:
+                    decision['default_output'] = node_mapping[default_output]
+
+    # Build final flattened definition
+    flattened = {
+        'name': parent_def.get('name', 'FlattenedGraph'),
+        'state': merged_state,
+        'nodes': merged_nodes,
+        'entry_point': entry_point,
+        '_flattened': True,
+        '_all_tools': all_tools  # Store tools for later collection
+    }
+
+    # Add interrupts if present
+    if merged_interrupts_before:
+        flattened['interrupt_before'] = list(merged_interrupts_before)
+    if merged_interrupts_after:
+        flattened['interrupt_after'] = list(merged_interrupts_after)
+
+    return flattened
+
+
+def detect_and_flatten_subgraphs(yaml_schema: str) -> tuple[str, list]:
+    """
+    Detect subgraphs in YAML and flatten them if found.
+
+    Returns:
+        tuple: (flattened_yaml_string, collected_tools)
+    """
+    # Parse to check for subgraphs
+    schema_dict = yaml.safe_load(yaml_schema)
+    subgraph_nodes = [
+        node for node in schema_dict.get('nodes', [])
+        if node.get('type') == 'subgraph'
+    ]
+
+    if not subgraph_nodes:
+        return yaml_schema, []
+
+    # Check if all required subgraphs are available in registry
+    missing_subgraphs = []
+    for node in subgraph_nodes:
+        # Support both 'tool' and 'subgraph' fields for subgraph name
+        # Don't clean the string - registry keys use original names
+        subgraph_name = node.get('tool') or node.get('subgraph')
+        if subgraph_name and subgraph_name not in SUBGRAPH_REGISTRY:
+            missing_subgraphs.append(subgraph_name)
+
+    if missing_subgraphs:
+        logger.warning(f"Cannot flatten - missing subgraphs: {missing_subgraphs}")
+        return yaml_schema, []
+
+    # Flatten the graph
+    flattened_def = merge_subgraphs(yaml_schema, SUBGRAPH_REGISTRY)
+
+    # Extract tools
+    all_tools = flattened_def.pop('_all_tools', [])
+
+    # Convert back to YAML
+    flattened_yaml = yaml.dump(flattened_def, default_flow_style=False)
+
+    return flattened_yaml, all_tools
+
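
End to end, the flattening pass rewrites a parent schema in place of its `subgraph` nodes: states are unioned, the child's nodes are spliced in under their original ids, and the parent entry point is redirected if it pointed at a subgraph. A runnable miniature, assuming the import path matches the wheel layout and with the registry entry built by hand (in the SDK it is populated by `SubgraphToolkit.get_toolkit`):

    from alita_sdk.langchain.langraph_agent import merge_subgraphs, SUBGRAPH_REGISTRY

    child_yaml = """
    name: child
    state:
      draft: {type: str}
    entry_point: write
    nodes:
      - id: write
        type: llm
        transition: END
    """

    parent_yaml = """
    name: parent
    state:
      task: {type: str}
    entry_point: delegate
    nodes:
      - id: delegate
        type: subgraph
        tool: child
        transition: END
    """

    SUBGRAPH_REGISTRY['child'] = {'yaml': child_yaml, 'tools': [], 'flattened': False}
    flat = merge_subgraphs(parent_yaml, SUBGRAPH_REGISTRY)
    print(flat['entry_point'])               # 'write' (redirected into the child)
    print([n['id'] for n in flat['nodes']])  # ['write']
    print(sorted(flat['state']))             # ['draft', 'task']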
alita_sdk/toolkits/subgraph.py ADDED
@@ -0,0 +1,53 @@
+from typing import List, Any
+
+from langgraph.graph.state import CompiledStateGraph
+
+from ..langchain.langraph_agent import create_graph, SUBGRAPH_REGISTRY
+from ..utils.utils import clean_string
+
+
+class SubgraphToolkit:
+
+    @staticmethod
+    def get_toolkit(
+        client: Any,
+        application_id: int,
+        application_version_id: int,
+        llm,
+        app_api_key: str,
+        selected_tools: list[str] = []
+    ) -> List[CompiledStateGraph]:
+        from .tools import get_tools
+        # from langgraph.checkpoint.memory import MemorySaver
+
+        app_details = client.get_app_details(application_id)
+        version_details = client.get_app_version_details(application_id, application_version_id)
+        tools = get_tools(version_details['tools'], alita_client=client, llm=llm)
+
+        # Get the subgraph name
+        subgraph_name = app_details.get("name")
+
+        # Populate the registry for flattening approach
+        SUBGRAPH_REGISTRY[subgraph_name] = {
+            'yaml': version_details['instructions'],
+            'tools': tools,
+            'flattened': False
+        }
+
+        # For backward compatibility, still create a compiled graph stub
+        # This is mainly used for identification in the parent graph's tools list
+        graph = create_graph(
+            client=llm,
+            tools=tools,
+            yaml_schema=version_details['instructions'],
+            debug=False,
+            store=None,
+            memory=None,
+            for_subgraph=True,  # compile as raw subgraph
+        )
+
+        # Tag the graph stub for parent lookup
+        graph.name = clean_string(subgraph_name)
+
+        # Return the compiled graph stub for backward compatibility
+        return [graph]
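
A hedged sketch of a call site for the toolkit above; `client` stands for an AlitaClient-like object exposing the two methods this code relies on (`get_app_details`, `get_app_version_details`), and the ids are placeholders:

    # Hypothetical wiring, not taken from the SDK's own call sites:
    stubs = SubgraphToolkit.get_toolkit(
        client,                    # AlitaClient-like object (assumption)
        application_id=42,         # placeholder ids
        application_version_id=7,
        llm=llm,
        app_api_key=client.auth_token,
    )
    child_stub = stubs[0]          # CompiledStateGraph tagged with clean_string(name)
    # The stub's .name is what create_graph's node loop matches against tool_name,
    # and the call has also registered the child's YAML for flattening as a side effect.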
alita_sdk/toolkits/tools.py CHANGED
@@ -1,14 +1,14 @@
 import logging
 
-from alita_tools import get_tools as alita_tools
 from alita_tools import get_toolkits as alita_toolkits
+from alita_tools import get_tools as alita_tools
 
-from .prompt import PromptToolkit
-from .datasource import DatasourcesToolkit
 from .application import ApplicationToolkit
 from .artifact import ArtifactToolkit
+from .datasource import DatasourcesToolkit
+from .prompt import PromptToolkit
+from .subgraph import SubgraphToolkit
 from .vectorstore import VectorStoreToolkit
-
 ## Community tools and toolkits
 from ..community.analysis.jira_analyse import AnalyseJira
 from ..community.browseruse import BrowserUseToolkit
@@ -35,7 +35,7 @@ def get_toolkits():
     return core_toolkits + community_toolkits + alita_toolkits()
 
 
-def get_tools(tools_list: list, alita: 'AlitaClient', llm: 'LLMLikeObject') -> list:
+def get_tools(tools_list: list, alita_client, llm) -> list:
     prompts = []
     tools = []
 
@@ -47,33 +47,43 @@ def get_tools(tools_list: list, alita: 'AlitaClient', llm: 'LLMLikeObject') -> l
             ])
         elif tool['type'] == 'datasource':
             tools.extend(DatasourcesToolkit.get_toolkit(
-                alita,
+                alita_client,
                 datasource_ids=[int(tool['settings']['datasource_id'])],
                 selected_tools=tool['settings']['selected_tools'],
                 toolkit_name=tool.get('toolkit_name', '') or tool.get('name', '')
             ).get_tools())
         elif tool['type'] == 'application':
             tools.extend(ApplicationToolkit.get_toolkit(
-                alita,
+                alita_client,
                 application_id=int(tool['settings']['application_id']),
                 application_version_id=int(tool['settings']['application_version_id']),
-                app_api_key=alita.auth_token,
+                app_api_key=alita_client.auth_token,
                 selected_tools=[]
             ).get_tools())
+        elif tool['type'] == 'subgraph':
+            # static get_toolkit returns a list of CompiledStateGraph stubs
+            tools.extend(SubgraphToolkit.get_toolkit(
+                alita_client,
+                application_id=int(tool['settings']['application_id']),
+                application_version_id=int(tool['settings']['application_version_id']),
+                app_api_key=alita_client.auth_token,
+                selected_tools=[],
+                llm=llm
+            ))
         elif tool['type'] == 'artifact':
             tools.extend(ArtifactToolkit.get_toolkit(
-                client=alita,
+                client=alita_client,
                 bucket=tool['settings']['bucket'],
                 toolkit_name=tool.get('toolkit_name', ''),
                 selected_tools=tool['settings'].get('selected_tools', [])
            ).get_tools())
        if tool['type'] == 'analyse_jira':
            tools.extend(AnalyseJira.get_toolkit(
-                client=alita,
+                client=alita_client,
                **tool['settings']).get_tools())
        if tool['type'] == 'browser_use':
            tools.extend(BrowserUseToolkit.get_toolkit(
-                client=alita,
+                client=alita_client,
                llm=llm,
                toolkit_name=tool.get('toolkit_name', ''),
                **tool['settings']).get_tools())
@@ -83,9 +93,9 @@ def get_tools(tools_list: list, alita: 'AlitaClient', llm: 'LLMLikeObject') -> l
                toolkit_name=tool.get('toolkit_name', ''),
                **tool['settings']).get_tools())
    if len(prompts) > 0:
-        tools += PromptToolkit.get_toolkit(alita, prompts).get_tools()
-    tools += alita_tools(tools_list, alita, llm)
-    tools += _mcp_tools(tools_list, alita)
+        tools += PromptToolkit.get_toolkit(alita_client, prompts).get_tools()
+    tools += alita_tools(tools_list, alita_client, llm)
+    tools += _mcp_tools(tools_list, alita_client)
    return tools
 
 
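
The dispatch above is keyed on plain dict entries in `tools_list`. For the new branch, an entry presumably looks like the following (the `settings` keys are exactly the ones read above; the `name` field is an assumption based on the fallbacks the other branches use):

    tool_entry = {
        'type': 'subgraph',
        'name': 'child_pipeline',        # assumed display-name field
        'settings': {
            'application_id': 42,        # placeholder values
            'application_version_id': 7,
        },
    }
    # tools = get_tools([tool_entry], alita_client=alita_client, llm=llm)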
alita_sdk/tools/tool.py CHANGED
@@ -10,6 +10,7 @@ from langchain_core.tools import BaseTool
 from langchain_core.utils.function_calling import convert_to_openai_tool
 from pydantic import ValidationError, BaseModel, create_model
 
+from .application import Application
 from ..langchain.utils import _extract_json
 
 logger = logging.getLogger(__name__)
@@ -74,8 +75,9 @@ Anwer must be JSON only extractable by JSON.LOADS."""
             ))
         ]
         if self.structured_output:
-            # cut defaults from schema
-            fields = {name: (field.annotation, ...) for name, field in self.tool.args_schema.model_fields.items()}
+            # cut defaults from schema and remove chat_history for application as a tool
+            fields = {name: (field.annotation, ...) for name, field
+                      in self.tool.args_schema.model_fields.items() if name != 'chat_history'}
             input_schema = create_model('NewModel', **fields)
 
             llm = self.client.with_structured_output(input_schema)
@@ -87,6 +89,10 @@
             result = _extract_json(completion.content.strip())
         logger.info(f"ToolNode tool params: {result}")
         try:
+            # handler for application added as a tool
+            if isinstance(self.tool, Application):
+                # set empty chat history
+                result['chat_history'] = None
             tool_result = self.tool.invoke(result, config=config, kwargs=kwargs)
             dispatch_custom_event(
                 "on_tool_node", {
alita_sdk-0.3.125.dist-info/METADATA → alita_sdk-0.3.127.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: alita_sdk
-Version: 0.3.125
+Version: 0.3.127
 Summary: SDK for building langchain agents using resouces from Alita
 Author-email: Artem Rozumenko <artyom.rozumenko@gmail.com>, Mikalai Biazruchka <mikalai_biazruchka@epam.com>, Roman Mitusov <roman_mitusov@epam.com>, Ivan Krakhmaliuk <lifedjik@gmail.com>
 Project-URL: Homepage, https://projectalita.ai
alita_sdk-0.3.125.dist-info/RECORD → alita_sdk-0.3.127.dist-info/RECORD RENAMED
@@ -15,11 +15,11 @@ alita_sdk/community/analysis/jira_analyse/api_wrapper.py,sha256=JqGSxg_3x0ErzII3
 alita_sdk/community/browseruse/__init__.py,sha256=uAxPZEX7ihpt8HtcGDFrzTNv9WcklT1wG1ItTwUO8y4,3601
 alita_sdk/community/browseruse/api_wrapper.py,sha256=Y05NKWfTROPmBxe8ZFIELSGBX5v3RTNP30OTO2Tj8uI,10838
 alita_sdk/langchain/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-alita_sdk/langchain/assistant.py,sha256=V0MpZG9IQRlKMr1SfAabVbUsIp-gNPO9A1SWxuFikNY,5930
+alita_sdk/langchain/assistant.py,sha256=J_xhwbNl934BgDKSpAMC9a1u6v03DZQcTYaamCztEPk,6272
 alita_sdk/langchain/chat_message_template.py,sha256=kPz8W2BG6IMyITFDA5oeb5BxVRkHEVZhuiGl4MBZKdc,2176
 alita_sdk/langchain/constants.py,sha256=eHVJ_beJNTf1WJo4yq7KMK64fxsRvs3lKc34QCXSbpk,3319
 alita_sdk/langchain/indexer.py,sha256=0ENHy5EOhThnAiYFc7QAsaTNp9rr8hDV_hTK8ahbatk,37592
-alita_sdk/langchain/langraph_agent.py,sha256=f9rGk6QGQbfbGudwpM5ax9yS-xDlElFqGAOsZbGvrtI,20919
+alita_sdk/langchain/langraph_agent.py,sha256=PrD_9XEX7_LDOT_SuohW_nhqLMlzc14xmByHPrO0V6E,37951
 alita_sdk/langchain/mixedAgentParser.py,sha256=M256lvtsL3YtYflBCEp-rWKrKtcY1dJIyRGVv7KW9ME,2611
 alita_sdk/langchain/mixedAgentRenderes.py,sha256=asBtKqm88QhZRILditjYICwFVKF5KfO38hu2O-WrSWE,5964
 alita_sdk/langchain/utils.py,sha256=Npferkn10dvdksnKzLJLBI5bNGQyVWTBwqp3vQtUqmY,6631
@@ -68,7 +68,8 @@ alita_sdk/toolkits/application.py,sha256=LrxbBV05lkRP3_WtKGBKtMdoQHXVY-_AtFr1cUu
 alita_sdk/toolkits/artifact.py,sha256=7zb17vhJ3CigeTqvzQ4VNBsU5UOCJqAwz7fOJGMYqXw,2348
 alita_sdk/toolkits/datasource.py,sha256=v3FQu8Gmvq7gAGAnFEbA8qofyUhh98rxgIjY6GHBfyI,2494
 alita_sdk/toolkits/prompt.py,sha256=WIpTkkVYWqIqOWR_LlSWz3ug8uO9tm5jJ7aZYdiGRn0,1192
-alita_sdk/toolkits/tools.py,sha256=RFMJ2bYNotPNmz6-clUdAakuIyrunSA16nFZ1gDFSF0,5273
+alita_sdk/toolkits/subgraph.py,sha256=ZYqI4yVLbEPAjCR8dpXbjbL2ipX598Hk3fL6AgaqFD4,1758
+alita_sdk/toolkits/tools.py,sha256=gk3nvQBdab3QM8v93ff2nrN4ZfcT779yae2RygkTl8s,5834
 alita_sdk/toolkits/vectorstore.py,sha256=di08-CRl0KJ9xSZ8_24VVnPZy58iLqHtXW8vuF29P64,2893
 alita_sdk/tools/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 alita_sdk/tools/application.py,sha256=UJlYd3Sub10LpAoKkKEpvd4miWyrS-yYE5NKyqx-H4Q,2194
@@ -84,7 +85,7 @@ alita_sdk/tools/mcp_server_tool.py,sha256=xcH9AiqfR2TYrwJ3Ixw-_A7XDodtJCnwmq1Ssi
 alita_sdk/tools/pgvector_search.py,sha256=NN2BGAnq4SsDHIhUcFZ8d_dbEOM8QwB0UwpsWCYruXU,11692
 alita_sdk/tools/prompt.py,sha256=nJafb_e5aOM1Rr3qGFCR-SKziU9uCsiP2okIMs9PppM,741
 alita_sdk/tools/router.py,sha256=wCvZjVkdXK9dMMeEerrgKf5M790RudH68pDortnHSz0,1517
-alita_sdk/tools/tool.py,sha256=jFRq8BeC55NwpgdpsqGk_Y3tZL4YKN0rE7sVS5OE3yg,5092
+alita_sdk/tools/tool.py,sha256=f2ULDU4PU4PlLgygT_lsInLgNROJeWUNXLe0i0uOcqI,5419
 alita_sdk/tools/vectorstore.py,sha256=F-DoHxPa4UVsKB-FEd-wWa59QGQifKMwcSNcZ5WZOKc,23496
 alita_sdk/utils/AlitaCallback.py,sha256=cvpDhR4QLVCNQci6CO6TEUrUVDZU9_CRSwzcHGm3SGw,7356
 alita_sdk/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -92,10 +93,10 @@ alita_sdk/utils/evaluate.py,sha256=iM1P8gzBLHTuSCe85_Ng_h30m52hFuGuhNXJ7kB1tgI,1
 alita_sdk/utils/logging.py,sha256=hBE3qAzmcLMdamMp2YRXwOOK9P4lmNaNhM76kntVljs,3124
 alita_sdk/utils/streamlit.py,sha256=zp8owZwHI3HZplhcExJf6R3-APtWx-z6s5jznT2hY_k,29124
 alita_sdk/utils/utils.py,sha256=dM8whOJAuFJFe19qJ69-FLzrUp6d2G-G6L7d4ss2XqM,346
-alita_sdk-0.3.125.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+alita_sdk-0.3.127.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
 tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 tests/test_jira_analysis.py,sha256=I0cErH5R_dHVyutpXrM1QEo7jfBuKWTmDQvJBPjx18I,3281
-alita_sdk-0.3.125.dist-info/METADATA,sha256=raIYIJIfySOwIA5jp0DHo4UlVizFOhLpWwLwx9AhQi8,7075
-alita_sdk-0.3.125.dist-info/WHEEL,sha256=zaaOINJESkSfm_4HQVc5ssNzHCPXhJm0kEUakpsEHaU,91
-alita_sdk-0.3.125.dist-info/top_level.txt,sha256=SWRhxB7Et3cOy3RkE5hR7OIRnHoo3K8EXzoiNlkfOmc,25
-alita_sdk-0.3.125.dist-info/RECORD,,
+alita_sdk-0.3.127.dist-info/METADATA,sha256=Ox_VkvvGHqTNfe_wFqkVXU0etHQmW92EziEoEM5D158,7075
+alita_sdk-0.3.127.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+alita_sdk-0.3.127.dist-info/top_level.txt,sha256=SWRhxB7Et3cOy3RkE5hR7OIRnHoo3K8EXzoiNlkfOmc,25
+alita_sdk-0.3.127.dist-info/RECORD,,
alita_sdk-0.3.125.dist-info/WHEEL → alita_sdk-0.3.127.dist-info/WHEEL RENAMED
@@ -1,5 +1,5 @@
 Wheel-Version: 1.0
-Generator: setuptools (80.8.0)
+Generator: setuptools (80.9.0)
 Root-Is-Purelib: true
 Tag: py3-none-any
 