alita-sdk 0.3.124__py3-none-any.whl → 0.3.126__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- alita_sdk/langchain/assistant.py +11 -7
- alita_sdk/langchain/langraph_agent.py +389 -16
- alita_sdk/toolkits/subgraph.py +53 -0
- alita_sdk/toolkits/tools.py +47 -33
- {alita_sdk-0.3.124.dist-info → alita_sdk-0.3.126.dist-info}/METADATA +1 -1
- {alita_sdk-0.3.124.dist-info → alita_sdk-0.3.126.dist-info}/RECORD +9 -8
- {alita_sdk-0.3.124.dist-info → alita_sdk-0.3.126.dist-info}/WHEEL +1 -1
- {alita_sdk-0.3.124.dist-info → alita_sdk-0.3.126.dist-info}/licenses/LICENSE +0 -0
- {alita_sdk-0.3.124.dist-info → alita_sdk-0.3.126.dist-info}/top_level.txt +0 -0
alita_sdk/langchain/assistant.py
CHANGED
@@ -1,4 +1,3 @@
-
 import logging
 import importlib
 from copy import deepcopy as copy
@@ -17,6 +16,7 @@ from .constants import REACT_ADDON, REACT_VARS, XML_ADDON
 from .chat_message_template import Jinja2TemplatedChatMessagesTemplate
 from ..tools.echo import EchoTool
 from ..toolkits.tools import get_tools
+from langchain_core.tools import BaseTool
 
 logger = logging.getLogger(__name__)
 
@@ -53,7 +53,7 @@ class Assistant:
             target_name
         )
         self.client = target_cls(**model_params)
-        self.tools = get_tools(data['tools'],
+        self.tools = get_tools(data['tools'], alita_client=alita, llm=self.client)
         if app_type == "pipeline":
             self.prompt = data['instructions']
         else:
@@ -106,18 +106,22 @@ class Assistant:
             max_execution_time=None, return_intermediate_steps=True)
 
     def getAgentExecutor(self):
-
-
-
+        # Exclude compiled graph runnables from simple tool agents
+        simple_tools = [t for t in self.tools if isinstance(t, BaseTool)]
+        agent = create_json_chat_agent(llm=self.client, tools=simple_tools, prompt=self.prompt)
         return self._agent_executor(agent)
 
 
     def getXMLAgentExecutor(self):
-
+        # Exclude compiled graph runnables from simple tool agents
+        simple_tools = [t for t in self.tools if isinstance(t, BaseTool)]
+        agent = create_xml_chat_agent(llm=self.client, tools=simple_tools, prompt=self.prompt)
        return self._agent_executor(agent)
 
     def getOpenAIToolsAgentExecutor(self):
-
+        # Exclude compiled graph runnables from simple tool agents
+        simple_tools = [t for t in self.tools if isinstance(t, BaseTool)]
+        agent = create_openai_tools_agent(llm=self.client, tools=simple_tools, prompt=self.prompt)
        return self._agent_executor(agent)
 
     def pipeline(self):
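The new simple_tools filter matters because self.tools may now mix BaseTool instances with compiled subgraph stubs (see SubgraphToolkit below), and LangChain's create_*_agent helpers expect a sequence of BaseTool. A minimal sketch of the effect, with a hypothetical mixed tools list:

    from langchain_core.tools import BaseTool

    # tools as get_tools() may now return them: BaseTool instances plus
    # CompiledStateGraph stubs, which are not BaseTool subclasses
    simple_tools = [t for t in tools if isinstance(t, BaseTool)]
    # only simple_tools reach create_json_chat_agent / create_xml_chat_agent /
    # create_openai_tools_agent; the graph stubs are consumed by create_graph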
alita_sdk/langchain/langraph_agent.py
CHANGED
@@ -1,6 +1,7 @@
 import logging
 from typing import Union, Any, Optional, Annotated, get_type_hints
 from uuid import uuid4
+from typing import Dict
 
 import yaml
 from langchain_core.callbacks import dispatch_custom_event
@@ -17,7 +18,7 @@ from langgraph.prebuilt import InjectedStore
 from langgraph.store.base import BaseStore
 
 from .mixedAgentRenderes import convert_message_to_json
-from .utils import create_state
+from .utils import create_state, propagate_the_input_mapping
 from ..tools.function import FunctionTool
 from ..tools.indexer_tool import IndexerNode
 from ..tools.llm import LLMNode
@@ -30,15 +31,101 @@ from ..tools.router import RouterNode
 
 logger = logging.getLogger(__name__)
 
+# Global registry for subgraph definitions
+# Structure: {'subgraph_name': {'yaml': 'yaml_def', 'tools': [tools], 'flattened': False}}
+SUBGRAPH_REGISTRY: Dict[str, Dict[str, Any]] = {}
+
+
+# Wrapper for injecting a compiled subgraph into a parent StateGraph
+class SubgraphRunnable(CompiledStateGraph):
+    def __init__(
+        self,
+        inner: CompiledStateGraph,
+        *,
+        name: str,
+        input_mapping: Dict[str, Any],
+        output_mapping: Dict[str, Any]
+    ):
+        # copy child graph internals
+        super().__init__(
+            builder=inner.builder,
+            config_type=inner.config_type,
+            nodes=inner.nodes,
+            channels=inner.channels,
+            input_channels=inner.input_channels,
+            stream_mode=inner.stream_mode,
+            output_channels=inner.output_channels,
+            stream_channels=inner.stream_channels,
+            checkpointer=inner.checkpointer,
+            interrupt_before_nodes=inner.interrupt_before_nodes,
+            interrupt_after_nodes=inner.interrupt_after_nodes,
+            auto_validate=False,
+            debug=inner.debug,
+            store=inner.store,
+        )
+        self.inner = inner
+        self.name = name
+        self.input_mapping = input_mapping or {}
+        self.output_mapping = output_mapping or {}
+
+    def invoke(
+        self,
+        state: Union[dict[str, Any], Any],
+        config: Optional[RunnableConfig] = None,
+        **kwargs: Any,
+    ) -> Union[dict[str, Any], Any]:
+        # Detailed logging for debugging
+        logger.debug(f"SubgraphRunnable '{self.name}' invoke called with state: {state}")
+        logger.debug(f"SubgraphRunnable '{self.name}' config: {config}")
+
+        # 1) parent -> child mapping
+        if not self.input_mapping:
+            child_input = state.copy()
+        else:
+            child_input = propagate_the_input_mapping(
+                self.input_mapping, list(self.input_mapping.keys()), state
+            )
+        # debug trace of messages flowing into child
+        logger.debug(f"SubgraphRunnable '{self.name}' child_input.messages: {child_input.get('messages')}")
+        logger.debug(f"SubgraphRunnable '{self.name}' child_input.input: {child_input.get('input')}")
+
+        # 2) Invoke the child graph.
+        # Pass None as the first argument for input if the child is expected to resume
+        # using its (now updated) checkpoint. The CompiledStateGraph.invoke method, when
+        # input is None but a checkpoint exists, loads from the checkpoint.
+        # Any resume commands (if applicable for internal child interrupts) are in 'config'.
+        # logger.debug(f"SubgraphRunnable '{self.name}': Invoking child graph super().invoke(None, config).")
+        subgraph_output = super().invoke(child_input, config=config, **kwargs)
+
+        # 3) child complete: apply output_mapping or passthrough
+        logger.debug(f"SubgraphRunnable '{self.name}' child complete, applying mappings")
+        result: Dict[str, Any] = {}
+        if self.output_mapping:
+            for child_key, parent_key in self.output_mapping.items():
+                if child_key in subgraph_output:
+                    state[parent_key] = subgraph_output[child_key]
+                    result[parent_key] = subgraph_output[child_key]
+                    logger.debug(f"SubgraphRunnable '{self.name}' mapped {child_key} -> {parent_key}")
+        else:
+            for k, v in subgraph_output.items():
+                state[k] = v
+                result[k] = v
+
+        # include full messages history on completion
+        if 'messages' not in result:
+            result['messages'] = subgraph_output.get('messages', [])
+        logger.debug(f"SubgraphRunnable '{self.name}' returning result: {result}")
+        return result
+
 
 class ConditionalEdge(Runnable):
     name = "ConditionalEdge"
 
     def __init__(self, condition: str, condition_inputs: Optional[list[str]] = [],
-                 conditional_outputs: Optional[list[str]] = [], default_output: str =
+                 conditional_outputs: Optional[list[str]] = [], default_output: str = END):
         self.condition = condition
         self.condition_inputs = condition_inputs
-        self.conditional_outputs = {clean_string(cond) for cond in conditional_outputs}
+        self.conditional_outputs = {clean_string(cond if not 'END' == cond else '__end__') for cond in conditional_outputs}
         self.default_output = clean_string(default_output)
 
     def invoke(self, state: Annotated[BaseStore, InjectedStore()], config: Optional[RunnableConfig] = None) -> str:
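To make the mapping step concrete, here is a minimal, self-contained sketch of the output_mapping pass above (hypothetical keys; propagate_the_input_mapping comes from .utils and is not reproduced here):

    # child output and a child_key -> parent_key mapping, as in SubgraphRunnable.invoke
    subgraph_output = {"summary": "done", "messages": ["hi"]}
    output_mapping = {"summary": "child_summary"}

    state = {"messages": []}
    result = {}
    for child_key, parent_key in output_mapping.items():
        if child_key in subgraph_output:
            state[parent_key] = subgraph_output[child_key]
            result[parent_key] = subgraph_output[child_key]
    if "messages" not in result:
        result["messages"] = subgraph_output.get("messages", [])
    # result == {"child_summary": "done", "messages": ["hi"]}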
@@ -177,8 +264,13 @@ class StateModifierNode(Runnable):
         return result
 
 
-
+
+def prepare_output_schema(lg_builder, memory, store, debug=False, interrupt_before=None, interrupt_after=None, state_class=None):
     # prepare output channels
+    if interrupt_after is None:
+        interrupt_after = []
+    if interrupt_before is None:
+        interrupt_before = []
     output_channels = (
         "__root__"
         if len(lg_builder.schemas[lg_builder.output]) == 1
@@ -239,14 +331,24 @@ def prepare_output_schema(lg_builder, memory, store, debug=False, interrupt_befo
 def create_graph(
         client: Any,
         yaml_schema: str,
-        tools: list[BaseTool],
+        tools: list[Union[BaseTool, CompiledStateGraph]],
         *args,
         memory: Optional[Any] = None,
         store: Optional[BaseStore] = None,
         debug: bool = False,
+        for_subgraph: bool = False,
         **kwargs
 ):
     """ Create a message graph from a yaml schema """
+
+    # For top-level graphs (not subgraphs), detect and flatten any subgraphs
+    if not for_subgraph:
+        flattened_yaml, additional_tools = detect_and_flatten_subgraphs(yaml_schema)
+        # Add collected tools from subgraphs to the tools list
+        tools = list(tools) + additional_tools
+        # Use the flattened YAML for building the graph
+        yaml_schema = flattened_yaml
+
     schema = yaml.safe_load(yaml_schema)
     logger.debug(f"Schema: {schema}")
     logger.debug(f"Tools: {tools}")
@@ -264,7 +366,7 @@ def create_graph(
         if toolkit_name:
             tool_name = f"{clean_string(toolkit_name)}{TOOLKIT_SPLITTER}{tool_name}"
         logger.info(f"Node: {node_id} : {node_type} - {tool_name}")
-        if node_type in ['function', 'tool', 'loop', 'loop_from_tool', 'indexer']:
+        if node_type in ['function', 'tool', 'loop', 'loop_from_tool', 'indexer', 'subgraph']:
            for tool in tools:
                if tool.name == tool_name:
                    if node_type == 'function':
@@ -274,6 +376,19 @@ def create_graph(
                             input_mapping=node.get('input_mapping',
                                                    {'messages': {'type': 'variable', 'value': 'messages'}}),
                             input_variables=node.get('input', ['messages'])))
+                    elif node_type == 'subgraph':
+                        # assign parent memory/store
+                        # tool.checkpointer = memory
+                        # tool.store = store
+                        # wrap with mappings
+                        node_fn = SubgraphRunnable(
+                            inner=tool,
+                            name=node['id'],
+                            input_mapping=node.get('input_mapping', {}),
+                            output_mapping=node.get('output_mapping', {}),
+                        )
+                        lg_builder.add_node(node_id, node_fn)
+                        break  # skip legacy handling
                     elif node_type == 'tool':
                         lg_builder.add_node(node_id, ToolNode(
                             client=client, tool=tool,
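For reference, the parsed node dict this new 'subgraph' branch expects would look roughly like the following (hypothetical values; 'tool' must match the compiled stub's name for the loop above to find it):

    node = {
        "id": "call_child",
        "type": "subgraph",
        "tool": "child_app",  # matched against tool.name in the loop above
        "input_mapping": {"messages": {"type": "variable", "value": "messages"}},
        "output_mapping": {"summary": "child_summary"},  # child_key -> parent_key
    }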
@@ -394,20 +509,37 @@ def create_graph(
         interrupt_before = interrupt_before or []
         interrupt_after = interrupt_after or []
 
-
-
-
-        (
-
-
+        if not for_subgraph:
+            # validate the graph for LangGraphAgentRunnable before the actual construction
+            lg_builder.validate(
+                interrupt=(
+                    (interrupt_before if interrupt_before != "*" else []) + interrupt_after
+                    if interrupt_after != "*"
+                    else []
+                )
             )
+
+        # Compile into a CompiledStateGraph for the subgraph
+        graph = lg_builder.compile(
+            checkpointer=True,
+            interrupt_before=interrupt_before,
+            interrupt_after=interrupt_after,
+            store=store,
+            debug=debug,
         )
     except ValueError as e:
         raise ValueError(
             f"Validation of the schema failed. {e}\n\nDEBUG INFO:**Schema Nodes:**\n\n{lg_builder.nodes}\n\n**Schema Enges:**\n\n{lg_builder.edges}\n\n**Tools Available:**\n\n{tools}")
-
-
-
+    # If building a nested subgraph, return the raw CompiledStateGraph
+    if for_subgraph:
+        return graph
+    # Otherwise prepare top-level runnable wrapper and validate
+    compiled = prepare_output_schema(
+        lg_builder, memory, store, debug,
+        interrupt_before=interrupt_before,
+        interrupt_after=interrupt_after,
+        state_class={state_class: None}
+    )
     return compiled.validate()
 
 
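Net effect: create_graph now has two exit paths. A rough usage sketch (hypothetical YAML strings and client objects):

    # nested build: skip validation/wrapping, return the raw CompiledStateGraph
    child_graph = create_graph(client, child_yaml, tools, for_subgraph=True)

    # top-level build: flatten registered subgraphs into the YAML first,
    # then validate, wrap, and return the runnable
    agent = create_graph(client, parent_yaml, tools, memory=memory, store=store)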
@@ -440,8 +572,249 @@ class LangGraphAgentRunnable(CompiledStateGraph):
         config_state = self.get_state(config)
         if config_state.next:
             thread_id = config['configurable']['thread_id']
-
+
+        result_with_state = {
             "output": output,
             "thread_id": thread_id,
             "execution_finished": not config_state.next
         }
+
+        # Include all state values in the result
+        if hasattr(config_state, 'values') and config_state.values:
+            for key, value in config_state.values.items():
+                result_with_state[key] = value
+
+        return result_with_state
+
+
+def merge_subgraphs(parent_yaml: str, registry: Dict[str, Dict[str, Any]]) -> Dict[str, Any]:
+    """
+    Merge subgraphs into parent graph by flattening YAML structures.
+
+    This function implements the complete flattening approach:
+    1. Parse parent YAML
+    2. Detect subgraph nodes
+    3. Recursively flatten subgraphs
+    4. Merge states, nodes, interrupts, and transitions
+    5. Return single unified graph definition
+
+    Args:
+        parent_yaml: YAML string of parent graph
+        registry: Global subgraph registry
+
+    Returns:
+        Dict containing flattened graph definition
+    """
+    import copy
+
+    # Parse parent YAML
+    parent_def = yaml.safe_load(parent_yaml)
+
+    # Check if already flattened (prevent infinite recursion)
+    if parent_def.get('_flattened', False):
+        return parent_def
+
+    # Find subgraph nodes in parent
+    subgraph_nodes = []
+    regular_nodes = []
+
+    for node in parent_def.get('nodes', []):
+        if node.get('type') == 'subgraph':
+            subgraph_nodes.append(node)
+        else:
+            regular_nodes.append(node)
+
+    # If no subgraphs, return as-is
+    if not subgraph_nodes:
+        parent_def['_flattened'] = True
+        return parent_def
+
+    # Start with parent state and merge subgraph states
+    merged_state = copy.deepcopy(parent_def.get('state', {}))
+    merged_nodes = copy.deepcopy(regular_nodes)
+    merged_interrupts_before = set(parent_def.get('interrupt_before', []))
+    merged_interrupts_after = set(parent_def.get('interrupt_after', []))
+    all_tools = []
+
+    # Track node remapping for transition rewiring
+    node_mapping = {}  # subgraph_node_id -> actual_internal_node_id
+
+    # Process each subgraph
+    for subgraph_node in subgraph_nodes:
+        # Support both 'tool' and 'subgraph' fields for subgraph name
+        subgraph_name = subgraph_node.get('tool') or subgraph_node.get('subgraph')
+        subgraph_node_id = subgraph_node['id']
+
+        if subgraph_name not in registry:
+            logger.warning(f"Subgraph '{subgraph_name}' not found in registry")
+            continue
+
+        # Get subgraph definition
+        subgraph_entry = registry[subgraph_name]
+        subgraph_yaml = subgraph_entry['yaml']
+        subgraph_tools = subgraph_entry.get('tools', [])
+
+        # Recursively flatten the subgraph (in case it has nested subgraphs)
+        subgraph_def = merge_subgraphs(subgraph_yaml, registry)
+
+        # Collect tools from subgraph
+        all_tools.extend(subgraph_tools)
+
+        # Merge state (union of all fields)
+        for field_name, field_type in subgraph_def.get('state', {}).items():
+            if field_name not in merged_state:
+                merged_state[field_name] = field_type
+            elif merged_state[field_name] != field_type:
+                logger.warning(f"State field '{field_name}' type mismatch: {merged_state[field_name]} vs {field_type}")
+
+        # Map subgraph node to its entry point
+        subgraph_entry_point = subgraph_def.get('entry_point')
+        if subgraph_entry_point:
+            node_mapping[subgraph_node_id] = subgraph_entry_point
+            logger.debug(f"Mapped subgraph node '{subgraph_node_id}' to entry point '{subgraph_entry_point}'")
+
+        # Add subgraph nodes without prefixing (keep original IDs)
+        for sub_node in subgraph_def.get('nodes', []):
+            # Keep original node ID - no prefixing
+            new_node = copy.deepcopy(sub_node)
+            merged_nodes.append(new_node)
+
+        # Handle the original subgraph node's transition - apply it to nodes that end with END
+        original_transition = subgraph_node.get('transition')
+        if original_transition and original_transition != 'END' and original_transition != END:
+            # Find nodes in this subgraph that have END transitions and update them
+            for node in merged_nodes:
+                # Check if this is a node from the current subgraph by checking if it was just added
+                # and has an END transition
+                if node.get('transition') == 'END' and node in subgraph_def.get('nodes', []):
+                    node['transition'] = original_transition
+
+        # Merge interrupts without prefixing (keep original names)
+        for interrupt in subgraph_def.get('interrupt_before', []):
+            merged_interrupts_before.add(interrupt)  # No prefixing
+        for interrupt in subgraph_def.get('interrupt_after', []):
+            merged_interrupts_after.add(interrupt)  # No prefixing
+
+    # Handle entry point - keep parent's unless it's a subgraph node
+    entry_point = parent_def.get('entry_point')
+    logger.debug(f"Original entry point: {entry_point}")
+    logger.debug(f"Node mapping: {node_mapping}")
+    if entry_point in node_mapping:
+        # Parent entry point is a subgraph, redirect to subgraph's entry point
+        old_entry_point = entry_point
+        entry_point = node_mapping[entry_point]
+        logger.debug(f"Entry point changed from {old_entry_point} to {entry_point}")
+    else:
+        logger.debug(f"Entry point {entry_point} not in node mapping, keeping as-is")
+
+    # Rewrite transitions in regular nodes that point to subgraph nodes
+    for node in merged_nodes:
+        # Handle direct transitions
+        if 'transition' in node:
+            transition = node['transition']
+            if transition in node_mapping:
+                node['transition'] = node_mapping[transition]
+
+        # Handle conditional transitions
+        if 'condition' in node:
+            condition = node['condition']
+            if 'conditional_outputs' in condition:
+                new_outputs = []
+                for output in condition['conditional_outputs']:
+                    if output in node_mapping:
+                        new_outputs.append(node_mapping[output])
+                    else:
+                        new_outputs.append(output)
+                condition['conditional_outputs'] = new_outputs
+
+            if 'default_output' in condition:
+                default = condition['default_output']
+                if default in node_mapping:
+                    condition['default_output'] = node_mapping[default]
+
+            # Update condition_definition Jinja2 template to replace subgraph node references
+            if 'condition_definition' in condition:
+                condition_definition = condition['condition_definition']
+                # Replace subgraph node references in the Jinja2 template
+                for subgraph_node_id, subgraph_entry_point in node_mapping.items():
+                    condition_definition = condition_definition.replace(subgraph_node_id, subgraph_entry_point)
+                condition['condition_definition'] = condition_definition
+
+        # Handle decision nodes
+        if 'decision' in node:
+            decision = node['decision']
+            # Update decision.nodes list to replace subgraph node references
+            if 'nodes' in decision:
+                new_nodes = []
+                for decision_node in decision['nodes']:
+                    if decision_node in node_mapping:
+                        new_nodes.append(node_mapping[decision_node])
+                    else:
+                        new_nodes.append(decision_node)
+                decision['nodes'] = new_nodes
+
+            # Update decision.default_output to replace subgraph node references
+            if 'default_output' in decision:
+                default_output = decision['default_output']
+                if default_output in node_mapping:
+                    decision['default_output'] = node_mapping[default_output]
+
+    # Build final flattened definition
+    flattened = {
+        'name': parent_def.get('name', 'FlattenedGraph'),
+        'state': merged_state,
+        'nodes': merged_nodes,
+        'entry_point': entry_point,
+        '_flattened': True,
+        '_all_tools': all_tools  # Store tools for later collection
+    }
+
+    # Add interrupts if present
+    if merged_interrupts_before:
+        flattened['interrupt_before'] = list(merged_interrupts_before)
+    if merged_interrupts_after:
+        flattened['interrupt_after'] = list(merged_interrupts_after)
+
+    return flattened
+
+
+def detect_and_flatten_subgraphs(yaml_schema: str) -> tuple[str, list]:
+    """
+    Detect subgraphs in YAML and flatten them if found.
+
+    Returns:
+        tuple: (flattened_yaml_string, collected_tools)
+    """
+    # Parse to check for subgraphs
+    schema_dict = yaml.safe_load(yaml_schema)
+    subgraph_nodes = [
+        node for node in schema_dict.get('nodes', [])
+        if node.get('type') == 'subgraph'
+    ]
+
+    if not subgraph_nodes:
+        return yaml_schema, []
+
+    # Check if all required subgraphs are available in registry
+    missing_subgraphs = []
+    for node in subgraph_nodes:
+        # Support both 'tool' and 'subgraph' fields for subgraph name
+        # Don't clean the string - registry keys use original names
+        subgraph_name = node.get('tool') or node.get('subgraph')
+        if subgraph_name and subgraph_name not in SUBGRAPH_REGISTRY:
+            missing_subgraphs.append(subgraph_name)
+
+    if missing_subgraphs:
+        logger.warning(f"Cannot flatten - missing subgraphs: {missing_subgraphs}")
+        return yaml_schema, []
+
+    # Flatten the graph
+    flattened_def = merge_subgraphs(yaml_schema, SUBGRAPH_REGISTRY)
+
+    # Extract tools
+    all_tools = flattened_def.pop('_all_tools', [])
+
+    # Convert back to YAML
+    flattened_yaml = yaml.dump(flattened_def, default_flow_style=False)
+
+    return flattened_yaml, all_tools
+
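A minimal end-to-end sketch of the flattening flow (hypothetical node ids and YAML; assumes the child was registered the way SubgraphToolkit does below):

    SUBGRAPH_REGISTRY["child_app"] = {
        "yaml": (
            "entry_point: child_start\n"
            "nodes:\n"
            "  - id: child_start\n"
            "    type: llm\n"
            "    transition: END\n"
        ),
        "tools": [],
        "flattened": False,
    }
    parent_yaml = (
        "entry_point: call_child\n"
        "nodes:\n"
        "  - id: call_child\n"
        "    type: subgraph\n"
        "    tool: child_app\n"
        "    transition: END\n"
    )
    flat_yaml, extra_tools = detect_and_flatten_subgraphs(parent_yaml)
    # flat_yaml now has entry_point: child_start with the child's nodes inlined;
    # extra_tools carries the child's tools up to the parent graph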
alita_sdk/toolkits/subgraph.py
ADDED
@@ -0,0 +1,53 @@
+from typing import List, Any
+
+from langgraph.graph.state import CompiledStateGraph
+
+from ..langchain.langraph_agent import create_graph, SUBGRAPH_REGISTRY
+from ..utils.utils import clean_string
+
+
+class SubgraphToolkit:
+
+    @staticmethod
+    def get_toolkit(
+        client: Any,
+        application_id: int,
+        application_version_id: int,
+        llm,
+        app_api_key: str,
+        selected_tools: list[str] = []
+    ) -> List[CompiledStateGraph]:
+        from .tools import get_tools
+        # from langgraph.checkpoint.memory import MemorySaver
+
+        app_details = client.get_app_details(application_id)
+        version_details = client.get_app_version_details(application_id, application_version_id)
+        tools = get_tools(version_details['tools'], alita_client=client, llm=llm)
+
+        # Get the subgraph name
+        subgraph_name = app_details.get("name")
+
+        # Populate the registry for flattening approach
+        SUBGRAPH_REGISTRY[subgraph_name] = {
+            'yaml': version_details['instructions'],
+            'tools': tools,
+            'flattened': False
+        }
+
+        # For backward compatibility, still create a compiled graph stub
+        # This is mainly used for identification in the parent graph's tools list
+        graph = create_graph(
+            client=llm,
+            tools=tools,
+            yaml_schema=version_details['instructions'],
+            debug=False,
+            store=None,
+            memory=None,
+            for_subgraph=True,  # compile as raw subgraph
+        )
+
+        # Tag the graph stub for parent lookup
+        graph.name = clean_string(subgraph_name)
+
+        # Return the compiled graph stub for backward compatibility
+        return [graph]
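Taken together with create_graph above: the returned stub only lets the parent's node loop match the subgraph by name, while the actual execution path comes from flattening the registered YAML. A rough call sketch (hypothetical ids):

    graphs = SubgraphToolkit.get_toolkit(
        alita_client,               # platform client with get_app_details(...)
        application_id=1,           # hypothetical
        application_version_id=2,   # hypothetical
        llm=llm,
        app_api_key=alita_client.auth_token,
    )
    # side effect: SUBGRAPH_REGISTRY[app_name] now holds the child's YAML and tools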
alita_sdk/toolkits/tools.py
CHANGED
@@ -1,14 +1,14 @@
 import logging
 
-from alita_tools import get_tools as alita_tools
 from alita_tools import get_toolkits as alita_toolkits
+from alita_tools import get_tools as alita_tools
 
-from .prompt import PromptToolkit
-from .datasource import DatasourcesToolkit
 from .application import ApplicationToolkit
 from .artifact import ArtifactToolkit
+from .datasource import DatasourcesToolkit
+from .prompt import PromptToolkit
+from .subgraph import SubgraphToolkit
 from .vectorstore import VectorStoreToolkit
-
 ## Community tools and toolkits
 from ..community.analysis.jira_analyse import AnalyseJira
 from ..community.browseruse import BrowserUseToolkit
@@ -35,7 +35,7 @@ def get_toolkits():
     return core_toolkits + community_toolkits + alita_toolkits()
 
 
-def get_tools(tools_list: list,
+def get_tools(tools_list: list, alita_client, llm) -> list:
     prompts = []
     tools = []
 
@@ -47,33 +47,43 @@ def get_tools(tools_list: list, alita: 'AlitaClient', llm: 'LLMLikeObject') -> l
             ])
         elif tool['type'] == 'datasource':
             tools.extend(DatasourcesToolkit.get_toolkit(
-
+                alita_client,
                 datasource_ids=[int(tool['settings']['datasource_id'])],
                 selected_tools=tool['settings']['selected_tools'],
                 toolkit_name=tool.get('toolkit_name', '') or tool.get('name', '')
             ).get_tools())
         elif tool['type'] == 'application':
             tools.extend(ApplicationToolkit.get_toolkit(
-
+                alita_client,
                 application_id=int(tool['settings']['application_id']),
                 application_version_id=int(tool['settings']['application_version_id']),
-                app_api_key=
+                app_api_key=alita_client.auth_token,
                 selected_tools=[]
             ).get_tools())
+        elif tool['type'] == 'subgraph':
+            # static get_toolkit returns a list of CompiledStateGraph stubs
+            tools.extend(SubgraphToolkit.get_toolkit(
+                alita_client,
+                application_id=int(tool['settings']['application_id']),
+                application_version_id=int(tool['settings']['application_version_id']),
+                app_api_key=alita_client.auth_token,
+                selected_tools=[],
+                llm=llm
+            ))
         elif tool['type'] == 'artifact':
             tools.extend(ArtifactToolkit.get_toolkit(
-                client=
+                client=alita_client,
                 bucket=tool['settings']['bucket'],
                 toolkit_name=tool.get('toolkit_name', ''),
                 selected_tools=tool['settings'].get('selected_tools', [])
             ).get_tools())
         if tool['type'] == 'analyse_jira':
             tools.extend(AnalyseJira.get_toolkit(
-                client=
+                client=alita_client,
                 **tool['settings']).get_tools())
         if tool['type'] == 'browser_use':
             tools.extend(BrowserUseToolkit.get_toolkit(
-                client=
+                client=alita_client,
                 llm=llm,
                 toolkit_name=tool.get('toolkit_name', ''),
                 **tool['settings']).get_tools())
@@ -83,33 +93,37 @@ def get_tools(tools_list: list, alita: 'AlitaClient', llm: 'LLMLikeObject') -> l
             toolkit_name=tool.get('toolkit_name', ''),
             **tool['settings']).get_tools())
     if len(prompts) > 0:
-        tools += PromptToolkit.get_toolkit(
-        tools += alita_tools(tools_list,
-        tools += _mcp_tools(tools_list,
+        tools += PromptToolkit.get_toolkit(alita_client, prompts).get_tools()
+    tools += alita_tools(tools_list, alita_client, llm)
+    tools += _mcp_tools(tools_list, alita_client)
     return tools
 
 
 def _mcp_tools(tools_list, alita):
-
-
-
-
-    for selected_toolkit in tools_list:
-        toolkit_name = selected_toolkit['type'].lower()
-        toolkit_conf = toolkit_lookup.get(toolkit_name)
-        #
-        if not toolkit_conf:
-            logger.warning(f"Toolkit '{toolkit_name}' not found in available toolkits.")
-            continue
+    try:
+        all_available_toolkits = alita.get_mcp_toolkits()
+        toolkit_lookup = {tk["name"].lower(): tk for tk in all_available_toolkits}
+        tools = []
         #
-
-
-
-
-        if not
-
-
-
+        for selected_toolkit in tools_list:
+            toolkit_name = selected_toolkit['type'].lower()
+            toolkit_conf = toolkit_lookup.get(toolkit_name)
+            #
+            if not toolkit_conf:
+                logger.warning(f"Toolkit '{toolkit_name}' not found in available toolkits.")
+                continue
+            #
+            available_tools = toolkit_conf.get("tools", [])
+            selected_tools = [name.lower() for name in selected_toolkit['settings'].get('selected_tools', [])]
+            for available_tool in available_tools:
+                tool_name = available_tool.get("name", "").lower()
+                if not selected_tools or tool_name in selected_tools:
+                    if server_tool := _init_single_mcp_tool(toolkit_name, available_tool, alita, selected_toolkit['settings']):
+                        tools.append(server_tool)
+        return tools
+    except Exception:
+        logger.error("Error while fetching MCP tools", exc_info=True)
+        return []
 
 
 def _init_single_mcp_tool(toolkit_name, available_tool, alita, toolkit_settings):
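The selection filter in the rewritten loop is case-insensitive and treats an empty selection as "take everything". A small sketch with a hypothetical toolkit entry:

    toolkit_conf = {"name": "jira", "tools": [{"name": "Search"}, {"name": "Create"}]}
    selected_tools = ["search"]  # lower-cased names from the toolkit settings

    picked = [t for t in toolkit_conf.get("tools", [])
              if not selected_tools or t.get("name", "").lower() in selected_tools]
    # picked == [{"name": "Search"}]; with selected_tools == [] every tool is kept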
{alita_sdk-0.3.124.dist-info → alita_sdk-0.3.126.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: alita_sdk
-Version: 0.3.124
+Version: 0.3.126
 Summary: SDK for building langchain agents using resouces from Alita
 Author-email: Artem Rozumenko <artyom.rozumenko@gmail.com>, Mikalai Biazruchka <mikalai_biazruchka@epam.com>, Roman Mitusov <roman_mitusov@epam.com>, Ivan Krakhmaliuk <lifedjik@gmail.com>
 Project-URL: Homepage, https://projectalita.ai
{alita_sdk-0.3.124.dist-info → alita_sdk-0.3.126.dist-info}/RECORD
CHANGED
@@ -15,11 +15,11 @@ alita_sdk/community/analysis/jira_analyse/api_wrapper.py,sha256=JqGSxg_3x0ErzII3
 alita_sdk/community/browseruse/__init__.py,sha256=uAxPZEX7ihpt8HtcGDFrzTNv9WcklT1wG1ItTwUO8y4,3601
 alita_sdk/community/browseruse/api_wrapper.py,sha256=Y05NKWfTROPmBxe8ZFIELSGBX5v3RTNP30OTO2Tj8uI,10838
 alita_sdk/langchain/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-alita_sdk/langchain/assistant.py,sha256=
+alita_sdk/langchain/assistant.py,sha256=J_xhwbNl934BgDKSpAMC9a1u6v03DZQcTYaamCztEPk,6272
 alita_sdk/langchain/chat_message_template.py,sha256=kPz8W2BG6IMyITFDA5oeb5BxVRkHEVZhuiGl4MBZKdc,2176
 alita_sdk/langchain/constants.py,sha256=eHVJ_beJNTf1WJo4yq7KMK64fxsRvs3lKc34QCXSbpk,3319
 alita_sdk/langchain/indexer.py,sha256=0ENHy5EOhThnAiYFc7QAsaTNp9rr8hDV_hTK8ahbatk,37592
-alita_sdk/langchain/langraph_agent.py,sha256=
+alita_sdk/langchain/langraph_agent.py,sha256=HwopuxCWDOg6i-ZKbxZzrqnRZ84pGIS7kVN349ER8bs,36510
 alita_sdk/langchain/mixedAgentParser.py,sha256=M256lvtsL3YtYflBCEp-rWKrKtcY1dJIyRGVv7KW9ME,2611
 alita_sdk/langchain/mixedAgentRenderes.py,sha256=asBtKqm88QhZRILditjYICwFVKF5KfO38hu2O-WrSWE,5964
 alita_sdk/langchain/utils.py,sha256=Npferkn10dvdksnKzLJLBI5bNGQyVWTBwqp3vQtUqmY,6631
@@ -68,7 +68,8 @@ alita_sdk/toolkits/application.py,sha256=LrxbBV05lkRP3_WtKGBKtMdoQHXVY-_AtFr1cUu
 alita_sdk/toolkits/artifact.py,sha256=7zb17vhJ3CigeTqvzQ4VNBsU5UOCJqAwz7fOJGMYqXw,2348
 alita_sdk/toolkits/datasource.py,sha256=v3FQu8Gmvq7gAGAnFEbA8qofyUhh98rxgIjY6GHBfyI,2494
 alita_sdk/toolkits/prompt.py,sha256=WIpTkkVYWqIqOWR_LlSWz3ug8uO9tm5jJ7aZYdiGRn0,1192
-alita_sdk/toolkits/tools.py,sha256=
+alita_sdk/toolkits/subgraph.py,sha256=ZYqI4yVLbEPAjCR8dpXbjbL2ipX598Hk3fL6AgaqFD4,1758
+alita_sdk/toolkits/tools.py,sha256=gk3nvQBdab3QM8v93ff2nrN4ZfcT779yae2RygkTl8s,5834
 alita_sdk/toolkits/vectorstore.py,sha256=di08-CRl0KJ9xSZ8_24VVnPZy58iLqHtXW8vuF29P64,2893
 alita_sdk/tools/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 alita_sdk/tools/application.py,sha256=UJlYd3Sub10LpAoKkKEpvd4miWyrS-yYE5NKyqx-H4Q,2194
@@ -92,10 +93,10 @@ alita_sdk/utils/evaluate.py,sha256=iM1P8gzBLHTuSCe85_Ng_h30m52hFuGuhNXJ7kB1tgI,1
 alita_sdk/utils/logging.py,sha256=hBE3qAzmcLMdamMp2YRXwOOK9P4lmNaNhM76kntVljs,3124
 alita_sdk/utils/streamlit.py,sha256=zp8owZwHI3HZplhcExJf6R3-APtWx-z6s5jznT2hY_k,29124
 alita_sdk/utils/utils.py,sha256=dM8whOJAuFJFe19qJ69-FLzrUp6d2G-G6L7d4ss2XqM,346
-alita_sdk-0.3.124.dist-info/licenses/LICENSE,sha256=
+alita_sdk-0.3.126.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
 tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 tests/test_jira_analysis.py,sha256=I0cErH5R_dHVyutpXrM1QEo7jfBuKWTmDQvJBPjx18I,3281
-alita_sdk-0.3.124.dist-info/METADATA,sha256=
-alita_sdk-0.3.124.dist-info/WHEEL,sha256=
-alita_sdk-0.3.124.dist-info/top_level.txt,sha256=
-alita_sdk-0.3.124.dist-info/RECORD,,
+alita_sdk-0.3.126.dist-info/METADATA,sha256=eX2afqBm4mw5_LMMdJ_HXMXHTp93O8bWSIQ3jCVy8go,7075
+alita_sdk-0.3.126.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+alita_sdk-0.3.126.dist-info/top_level.txt,sha256=SWRhxB7Et3cOy3RkE5hR7OIRnHoo3K8EXzoiNlkfOmc,25
+alita_sdk-0.3.126.dist-info/RECORD,,
{alita_sdk-0.3.124.dist-info → alita_sdk-0.3.126.dist-info}/licenses/LICENSE
File without changes
{alita_sdk-0.3.124.dist-info → alita_sdk-0.3.126.dist-info}/top_level.txt
File without changes