alita-sdk 0.3.379__py3-none-any.whl → 0.3.627__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- alita_sdk/cli/__init__.py +10 -0
- alita_sdk/cli/__main__.py +17 -0
- alita_sdk/cli/agent/__init__.py +5 -0
- alita_sdk/cli/agent/default.py +258 -0
- alita_sdk/cli/agent_executor.py +156 -0
- alita_sdk/cli/agent_loader.py +245 -0
- alita_sdk/cli/agent_ui.py +228 -0
- alita_sdk/cli/agents.py +3113 -0
- alita_sdk/cli/callbacks.py +647 -0
- alita_sdk/cli/cli.py +168 -0
- alita_sdk/cli/config.py +306 -0
- alita_sdk/cli/context/__init__.py +30 -0
- alita_sdk/cli/context/cleanup.py +198 -0
- alita_sdk/cli/context/manager.py +731 -0
- alita_sdk/cli/context/message.py +285 -0
- alita_sdk/cli/context/strategies.py +289 -0
- alita_sdk/cli/context/token_estimation.py +127 -0
- alita_sdk/cli/formatting.py +182 -0
- alita_sdk/cli/input_handler.py +419 -0
- alita_sdk/cli/inventory.py +1073 -0
- alita_sdk/cli/mcp_loader.py +315 -0
- alita_sdk/cli/testcases/__init__.py +94 -0
- alita_sdk/cli/testcases/data_generation.py +119 -0
- alita_sdk/cli/testcases/discovery.py +96 -0
- alita_sdk/cli/testcases/executor.py +84 -0
- alita_sdk/cli/testcases/logger.py +85 -0
- alita_sdk/cli/testcases/parser.py +172 -0
- alita_sdk/cli/testcases/prompts.py +91 -0
- alita_sdk/cli/testcases/reporting.py +125 -0
- alita_sdk/cli/testcases/setup.py +108 -0
- alita_sdk/cli/testcases/test_runner.py +282 -0
- alita_sdk/cli/testcases/utils.py +39 -0
- alita_sdk/cli/testcases/validation.py +90 -0
- alita_sdk/cli/testcases/workflow.py +196 -0
- alita_sdk/cli/toolkit.py +327 -0
- alita_sdk/cli/toolkit_loader.py +85 -0
- alita_sdk/cli/tools/__init__.py +43 -0
- alita_sdk/cli/tools/approval.py +224 -0
- alita_sdk/cli/tools/filesystem.py +1751 -0
- alita_sdk/cli/tools/planning.py +389 -0
- alita_sdk/cli/tools/terminal.py +414 -0
- alita_sdk/community/__init__.py +72 -12
- alita_sdk/community/inventory/__init__.py +236 -0
- alita_sdk/community/inventory/config.py +257 -0
- alita_sdk/community/inventory/enrichment.py +2137 -0
- alita_sdk/community/inventory/extractors.py +1469 -0
- alita_sdk/community/inventory/ingestion.py +3172 -0
- alita_sdk/community/inventory/knowledge_graph.py +1457 -0
- alita_sdk/community/inventory/parsers/__init__.py +218 -0
- alita_sdk/community/inventory/parsers/base.py +295 -0
- alita_sdk/community/inventory/parsers/csharp_parser.py +907 -0
- alita_sdk/community/inventory/parsers/go_parser.py +851 -0
- alita_sdk/community/inventory/parsers/html_parser.py +389 -0
- alita_sdk/community/inventory/parsers/java_parser.py +593 -0
- alita_sdk/community/inventory/parsers/javascript_parser.py +629 -0
- alita_sdk/community/inventory/parsers/kotlin_parser.py +768 -0
- alita_sdk/community/inventory/parsers/markdown_parser.py +362 -0
- alita_sdk/community/inventory/parsers/python_parser.py +604 -0
- alita_sdk/community/inventory/parsers/rust_parser.py +858 -0
- alita_sdk/community/inventory/parsers/swift_parser.py +832 -0
- alita_sdk/community/inventory/parsers/text_parser.py +322 -0
- alita_sdk/community/inventory/parsers/yaml_parser.py +370 -0
- alita_sdk/community/inventory/patterns/__init__.py +61 -0
- alita_sdk/community/inventory/patterns/ast_adapter.py +380 -0
- alita_sdk/community/inventory/patterns/loader.py +348 -0
- alita_sdk/community/inventory/patterns/registry.py +198 -0
- alita_sdk/community/inventory/presets.py +535 -0
- alita_sdk/community/inventory/retrieval.py +1403 -0
- alita_sdk/community/inventory/toolkit.py +173 -0
- alita_sdk/community/inventory/toolkit_utils.py +176 -0
- alita_sdk/community/inventory/visualize.py +1370 -0
- alita_sdk/configurations/__init__.py +1 -1
- alita_sdk/configurations/ado.py +141 -20
- alita_sdk/configurations/bitbucket.py +94 -2
- alita_sdk/configurations/confluence.py +130 -1
- alita_sdk/configurations/figma.py +76 -0
- alita_sdk/configurations/gitlab.py +91 -0
- alita_sdk/configurations/jira.py +103 -0
- alita_sdk/configurations/openapi.py +329 -0
- alita_sdk/configurations/qtest.py +72 -1
- alita_sdk/configurations/report_portal.py +96 -0
- alita_sdk/configurations/sharepoint.py +148 -0
- alita_sdk/configurations/testio.py +83 -0
- alita_sdk/configurations/testrail.py +88 -0
- alita_sdk/configurations/xray.py +93 -0
- alita_sdk/configurations/zephyr_enterprise.py +93 -0
- alita_sdk/configurations/zephyr_essential.py +75 -0
- alita_sdk/runtime/clients/artifact.py +3 -3
- alita_sdk/runtime/clients/client.py +388 -46
- alita_sdk/runtime/clients/mcp_discovery.py +342 -0
- alita_sdk/runtime/clients/mcp_manager.py +262 -0
- alita_sdk/runtime/clients/sandbox_client.py +8 -21
- alita_sdk/runtime/langchain/_constants_bkup.py +1318 -0
- alita_sdk/runtime/langchain/assistant.py +157 -39
- alita_sdk/runtime/langchain/constants.py +647 -1
- alita_sdk/runtime/langchain/document_loaders/AlitaDocxMammothLoader.py +315 -3
- alita_sdk/runtime/langchain/document_loaders/AlitaExcelLoader.py +103 -60
- alita_sdk/runtime/langchain/document_loaders/AlitaJSONLinesLoader.py +77 -0
- alita_sdk/runtime/langchain/document_loaders/AlitaJSONLoader.py +10 -4
- alita_sdk/runtime/langchain/document_loaders/AlitaPowerPointLoader.py +226 -7
- alita_sdk/runtime/langchain/document_loaders/AlitaTextLoader.py +5 -2
- alita_sdk/runtime/langchain/document_loaders/constants.py +40 -19
- alita_sdk/runtime/langchain/langraph_agent.py +405 -84
- alita_sdk/runtime/langchain/utils.py +106 -7
- alita_sdk/runtime/llms/preloaded.py +2 -6
- alita_sdk/runtime/models/mcp_models.py +61 -0
- alita_sdk/runtime/skills/__init__.py +91 -0
- alita_sdk/runtime/skills/callbacks.py +498 -0
- alita_sdk/runtime/skills/discovery.py +540 -0
- alita_sdk/runtime/skills/executor.py +610 -0
- alita_sdk/runtime/skills/input_builder.py +371 -0
- alita_sdk/runtime/skills/models.py +330 -0
- alita_sdk/runtime/skills/registry.py +355 -0
- alita_sdk/runtime/skills/skill_runner.py +330 -0
- alita_sdk/runtime/toolkits/__init__.py +31 -0
- alita_sdk/runtime/toolkits/application.py +29 -10
- alita_sdk/runtime/toolkits/artifact.py +20 -11
- alita_sdk/runtime/toolkits/datasource.py +13 -6
- alita_sdk/runtime/toolkits/mcp.py +783 -0
- alita_sdk/runtime/toolkits/mcp_config.py +1048 -0
- alita_sdk/runtime/toolkits/planning.py +178 -0
- alita_sdk/runtime/toolkits/skill_router.py +238 -0
- alita_sdk/runtime/toolkits/subgraph.py +251 -6
- alita_sdk/runtime/toolkits/tools.py +356 -69
- alita_sdk/runtime/toolkits/vectorstore.py +11 -5
- alita_sdk/runtime/tools/__init__.py +10 -3
- alita_sdk/runtime/tools/application.py +27 -6
- alita_sdk/runtime/tools/artifact.py +511 -28
- alita_sdk/runtime/tools/data_analysis.py +183 -0
- alita_sdk/runtime/tools/function.py +67 -35
- alita_sdk/runtime/tools/graph.py +10 -4
- alita_sdk/runtime/tools/image_generation.py +148 -46
- alita_sdk/runtime/tools/llm.py +1003 -128
- alita_sdk/runtime/tools/loop.py +3 -1
- alita_sdk/runtime/tools/loop_output.py +3 -1
- alita_sdk/runtime/tools/mcp_inspect_tool.py +284 -0
- alita_sdk/runtime/tools/mcp_remote_tool.py +181 -0
- alita_sdk/runtime/tools/mcp_server_tool.py +8 -5
- alita_sdk/runtime/tools/planning/__init__.py +36 -0
- alita_sdk/runtime/tools/planning/models.py +246 -0
- alita_sdk/runtime/tools/planning/wrapper.py +607 -0
- alita_sdk/runtime/tools/router.py +2 -4
- alita_sdk/runtime/tools/sandbox.py +65 -48
- alita_sdk/runtime/tools/skill_router.py +776 -0
- alita_sdk/runtime/tools/tool.py +3 -1
- alita_sdk/runtime/tools/vectorstore.py +9 -3
- alita_sdk/runtime/tools/vectorstore_base.py +70 -14
- alita_sdk/runtime/utils/AlitaCallback.py +137 -21
- alita_sdk/runtime/utils/constants.py +5 -1
- alita_sdk/runtime/utils/mcp_client.py +492 -0
- alita_sdk/runtime/utils/mcp_oauth.py +361 -0
- alita_sdk/runtime/utils/mcp_sse_client.py +434 -0
- alita_sdk/runtime/utils/mcp_tools_discovery.py +124 -0
- alita_sdk/runtime/utils/serialization.py +155 -0
- alita_sdk/runtime/utils/streamlit.py +40 -13
- alita_sdk/runtime/utils/toolkit_utils.py +30 -9
- alita_sdk/runtime/utils/utils.py +36 -0
- alita_sdk/tools/__init__.py +134 -35
- alita_sdk/tools/ado/repos/__init__.py +51 -32
- alita_sdk/tools/ado/repos/repos_wrapper.py +148 -89
- alita_sdk/tools/ado/test_plan/__init__.py +25 -9
- alita_sdk/tools/ado/test_plan/test_plan_wrapper.py +23 -1
- alita_sdk/tools/ado/utils.py +1 -18
- alita_sdk/tools/ado/wiki/__init__.py +25 -12
- alita_sdk/tools/ado/wiki/ado_wrapper.py +291 -22
- alita_sdk/tools/ado/work_item/__init__.py +26 -13
- alita_sdk/tools/ado/work_item/ado_wrapper.py +73 -11
- alita_sdk/tools/advanced_jira_mining/__init__.py +11 -8
- alita_sdk/tools/aws/delta_lake/__init__.py +13 -9
- alita_sdk/tools/aws/delta_lake/tool.py +5 -1
- alita_sdk/tools/azure_ai/search/__init__.py +11 -8
- alita_sdk/tools/azure_ai/search/api_wrapper.py +1 -1
- alita_sdk/tools/base/tool.py +5 -1
- alita_sdk/tools/base_indexer_toolkit.py +271 -84
- alita_sdk/tools/bitbucket/__init__.py +17 -11
- alita_sdk/tools/bitbucket/api_wrapper.py +59 -11
- alita_sdk/tools/bitbucket/cloud_api_wrapper.py +49 -35
- alita_sdk/tools/browser/__init__.py +5 -4
- alita_sdk/tools/carrier/__init__.py +5 -6
- alita_sdk/tools/carrier/backend_reports_tool.py +6 -6
- alita_sdk/tools/carrier/run_ui_test_tool.py +6 -6
- alita_sdk/tools/carrier/ui_reports_tool.py +5 -5
- alita_sdk/tools/chunkers/__init__.py +3 -1
- alita_sdk/tools/chunkers/code/treesitter/treesitter.py +37 -13
- alita_sdk/tools/chunkers/sematic/json_chunker.py +1 -0
- alita_sdk/tools/chunkers/sematic/markdown_chunker.py +97 -6
- alita_sdk/tools/chunkers/sematic/proposal_chunker.py +1 -1
- alita_sdk/tools/chunkers/universal_chunker.py +270 -0
- alita_sdk/tools/cloud/aws/__init__.py +10 -7
- alita_sdk/tools/cloud/azure/__init__.py +10 -7
- alita_sdk/tools/cloud/gcp/__init__.py +10 -7
- alita_sdk/tools/cloud/k8s/__init__.py +10 -7
- alita_sdk/tools/code/linter/__init__.py +10 -8
- alita_sdk/tools/code/loaders/codesearcher.py +3 -2
- alita_sdk/tools/code/sonar/__init__.py +11 -8
- alita_sdk/tools/code_indexer_toolkit.py +82 -22
- alita_sdk/tools/confluence/__init__.py +22 -16
- alita_sdk/tools/confluence/api_wrapper.py +107 -30
- alita_sdk/tools/confluence/loader.py +14 -2
- alita_sdk/tools/custom_open_api/__init__.py +12 -5
- alita_sdk/tools/elastic/__init__.py +11 -8
- alita_sdk/tools/elitea_base.py +493 -30
- alita_sdk/tools/figma/__init__.py +58 -11
- alita_sdk/tools/figma/api_wrapper.py +1235 -143
- alita_sdk/tools/figma/figma_client.py +73 -0
- alita_sdk/tools/figma/toon_tools.py +2748 -0
- alita_sdk/tools/github/__init__.py +14 -15
- alita_sdk/tools/github/github_client.py +224 -100
- alita_sdk/tools/github/graphql_client_wrapper.py +119 -33
- alita_sdk/tools/github/schemas.py +14 -5
- alita_sdk/tools/github/tool.py +5 -1
- alita_sdk/tools/github/tool_prompts.py +9 -22
- alita_sdk/tools/gitlab/__init__.py +16 -11
- alita_sdk/tools/gitlab/api_wrapper.py +218 -48
- alita_sdk/tools/gitlab_org/__init__.py +10 -9
- alita_sdk/tools/gitlab_org/api_wrapper.py +63 -64
- alita_sdk/tools/google/bigquery/__init__.py +13 -12
- alita_sdk/tools/google/bigquery/tool.py +5 -1
- alita_sdk/tools/google_places/__init__.py +11 -8
- alita_sdk/tools/google_places/api_wrapper.py +1 -1
- alita_sdk/tools/jira/__init__.py +17 -10
- alita_sdk/tools/jira/api_wrapper.py +92 -41
- alita_sdk/tools/keycloak/__init__.py +11 -8
- alita_sdk/tools/localgit/__init__.py +9 -3
- alita_sdk/tools/localgit/local_git.py +62 -54
- alita_sdk/tools/localgit/tool.py +5 -1
- alita_sdk/tools/memory/__init__.py +12 -4
- alita_sdk/tools/non_code_indexer_toolkit.py +1 -0
- alita_sdk/tools/ocr/__init__.py +11 -8
- alita_sdk/tools/openapi/__init__.py +491 -106
- alita_sdk/tools/openapi/api_wrapper.py +1368 -0
- alita_sdk/tools/openapi/tool.py +20 -0
- alita_sdk/tools/pandas/__init__.py +20 -12
- alita_sdk/tools/pandas/api_wrapper.py +38 -25
- alita_sdk/tools/pandas/dataframe/generator/base.py +3 -1
- alita_sdk/tools/postman/__init__.py +10 -9
- alita_sdk/tools/pptx/__init__.py +11 -10
- alita_sdk/tools/pptx/pptx_wrapper.py +1 -1
- alita_sdk/tools/qtest/__init__.py +31 -11
- alita_sdk/tools/qtest/api_wrapper.py +2135 -86
- alita_sdk/tools/rally/__init__.py +10 -9
- alita_sdk/tools/rally/api_wrapper.py +1 -1
- alita_sdk/tools/report_portal/__init__.py +12 -8
- alita_sdk/tools/salesforce/__init__.py +10 -8
- alita_sdk/tools/servicenow/__init__.py +17 -15
- alita_sdk/tools/servicenow/api_wrapper.py +1 -1
- alita_sdk/tools/sharepoint/__init__.py +10 -7
- alita_sdk/tools/sharepoint/api_wrapper.py +129 -38
- alita_sdk/tools/sharepoint/authorization_helper.py +191 -1
- alita_sdk/tools/sharepoint/utils.py +8 -2
- alita_sdk/tools/slack/__init__.py +10 -7
- alita_sdk/tools/slack/api_wrapper.py +2 -2
- alita_sdk/tools/sql/__init__.py +12 -9
- alita_sdk/tools/testio/__init__.py +10 -7
- alita_sdk/tools/testrail/__init__.py +11 -10
- alita_sdk/tools/testrail/api_wrapper.py +1 -1
- alita_sdk/tools/utils/__init__.py +9 -4
- alita_sdk/tools/utils/content_parser.py +103 -18
- alita_sdk/tools/utils/text_operations.py +410 -0
- alita_sdk/tools/utils/tool_prompts.py +79 -0
- alita_sdk/tools/vector_adapters/VectorStoreAdapter.py +30 -13
- alita_sdk/tools/xray/__init__.py +13 -9
- alita_sdk/tools/yagmail/__init__.py +9 -3
- alita_sdk/tools/zephyr/__init__.py +10 -7
- alita_sdk/tools/zephyr_enterprise/__init__.py +11 -7
- alita_sdk/tools/zephyr_essential/__init__.py +10 -7
- alita_sdk/tools/zephyr_essential/api_wrapper.py +30 -13
- alita_sdk/tools/zephyr_essential/client.py +2 -2
- alita_sdk/tools/zephyr_scale/__init__.py +11 -8
- alita_sdk/tools/zephyr_scale/api_wrapper.py +2 -2
- alita_sdk/tools/zephyr_squad/__init__.py +10 -7
- {alita_sdk-0.3.379.dist-info → alita_sdk-0.3.627.dist-info}/METADATA +154 -8
- alita_sdk-0.3.627.dist-info/RECORD +468 -0
- alita_sdk-0.3.627.dist-info/entry_points.txt +2 -0
- alita_sdk-0.3.379.dist-info/RECORD +0 -360
- {alita_sdk-0.3.379.dist-info → alita_sdk-0.3.627.dist-info}/WHEEL +0 -0
- {alita_sdk-0.3.379.dist-info → alita_sdk-0.3.627.dist-info}/licenses/LICENSE +0 -0
- {alita_sdk-0.3.379.dist-info → alita_sdk-0.3.627.dist-info}/top_level.txt +0 -0
|
@@ -12,6 +12,7 @@ from langchain_core.runnables import Runnable
|
|
|
12
12
|
from langchain_core.runnables import RunnableConfig
|
|
13
13
|
from langchain_core.tools import BaseTool, ToolException
|
|
14
14
|
from langgraph.channels.ephemeral_value import EphemeralValue
|
|
15
|
+
from langgraph.errors import GraphRecursionError
|
|
15
16
|
from langgraph.graph import StateGraph
|
|
16
17
|
from langgraph.graph.graph import END, START
|
|
17
18
|
from langgraph.graph.state import CompiledStateGraph
|
|
@@ -19,8 +20,10 @@ from langgraph.managed.base import is_managed_value
|
|
|
19
20
|
from langgraph.prebuilt import InjectedStore
|
|
20
21
|
from langgraph.store.base import BaseStore
|
|
21
22
|
|
|
23
|
+
from .constants import PRINTER_NODE_RS, PRINTER, PRINTER_COMPLETED_STATE
|
|
22
24
|
from .mixedAgentRenderes import convert_message_to_json
|
|
23
|
-
from .utils import create_state, propagate_the_input_mapping
|
|
25
|
+
from .utils import create_state, propagate_the_input_mapping, safe_format
|
|
26
|
+
from ..utils.constants import TOOLKIT_NAME_META, TOOL_NAME_META
|
|
24
27
|
from ..tools.function import FunctionTool
|
|
25
28
|
from ..tools.indexer_tool import IndexerNode
|
|
26
29
|
from ..tools.llm import LLMNode
|
|
@@ -28,7 +31,7 @@ from ..tools.loop import LoopNode
|
|
|
28
31
|
from ..tools.loop_output import LoopToolNode
|
|
29
32
|
from ..tools.tool import ToolNode
|
|
30
33
|
from ..utils.evaluate import EvaluateTemplate
|
|
31
|
-
from ..utils.utils import clean_string
|
|
34
|
+
from ..utils.utils import clean_string
|
|
32
35
|
from ..tools.router import RouterNode
|
|
33
36
|
|
|
34
37
|
logger = logging.getLogger(__name__)
|
|
@@ -170,12 +173,13 @@ Answer only with step name, no need to add descrip in case none of the steps are
|
|
|
170
173
|
"""
|
|
171
174
|
|
|
172
175
|
def __init__(self, client, steps: str, description: str = "", decisional_inputs: Optional[list[str]] = [],
|
|
173
|
-
default_output: str = 'END'):
|
|
176
|
+
default_output: str = 'END', is_node: bool = False):
|
|
174
177
|
self.client = client
|
|
175
178
|
self.steps = ",".join([clean_string(step) for step in steps])
|
|
176
179
|
self.description = description
|
|
177
180
|
self.decisional_inputs = decisional_inputs
|
|
178
181
|
self.default_output = default_output if default_output != 'END' else END
|
|
182
|
+
self.is_node = is_node
|
|
179
183
|
|
|
180
184
|
def invoke(self, state: Annotated[BaseStore, InjectedStore()], config: Optional[RunnableConfig] = None) -> str:
|
|
181
185
|
additional_info = ""
|
|
@@ -185,10 +189,10 @@ Answer only with step name, no need to add descrip in case none of the steps are
|
|
|
185
189
|
decision_input = state.get('messages', [])[:]
|
|
186
190
|
else:
|
|
187
191
|
if len(additional_info) == 0:
|
|
188
|
-
additional_info = """###
|
|
192
|
+
additional_info = """### Additional info: """
|
|
189
193
|
additional_info += "{field}: {value}\n".format(field=field, value=state.get(field, ""))
|
|
190
194
|
decision_input.append(HumanMessage(
|
|
191
|
-
self.prompt.format(steps=self.steps, description=self.description, additional_info=additional_info)))
|
|
195
|
+
self.prompt.format(steps=self.steps, description=safe_format(self.description, state), additional_info=additional_info)))
|
|
192
196
|
completion = self.client.invoke(decision_input)
|
|
193
197
|
result = clean_string(completion.content.strip())
|
|
194
198
|
logger.info(f"Plan to transition to: {result}")
|
|
@@ -197,7 +201,8 @@ Answer only with step name, no need to add descrip in case none of the steps are
|
|
|
197
201
|
dispatch_custom_event(
|
|
198
202
|
"on_decision_edge", {"decisional_inputs": self.decisional_inputs, "state": state}, config=config
|
|
199
203
|
)
|
|
200
|
-
|
|
204
|
+
# support of legacy `decision` as part of node
|
|
205
|
+
return {"router_output": result} if self.is_node else result
|
|
201
206
|
|
|
202
207
|
|
|
203
208
|
class TransitionalEdge(Runnable):
|
|
@@ -225,11 +230,54 @@ class StateDefaultNode(Runnable):
|
|
|
225
230
|
for key, value in self.default_vars.items():
|
|
226
231
|
if isinstance(value, dict) and 'value' in value:
|
|
227
232
|
temp_value = value['value']
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
233
|
+
declared_type = value.get('type', '').lower()
|
|
234
|
+
|
|
235
|
+
# If the declared type is 'str' or 'string', preserve the string value
|
|
236
|
+
# Don't auto-convert even if it looks like a valid Python literal
|
|
237
|
+
if declared_type in ('str', 'string'):
|
|
232
238
|
result[key] = temp_value
|
|
239
|
+
else:
|
|
240
|
+
# For other types, try to evaluate as Python literal
|
|
241
|
+
try:
|
|
242
|
+
result[key] = ast.literal_eval(temp_value)
|
|
243
|
+
except:
|
|
244
|
+
logger.debug("Unable to evaluate value, using as is")
|
|
245
|
+
result[key] = temp_value
|
|
246
|
+
return result
|
|
247
|
+
|
|
248
|
+
class PrinterNode(Runnable):
|
|
249
|
+
name = "PrinterNode"
|
|
250
|
+
DEFAULT_FINAL_MSG = "How to proceed? To resume the pipeline - type anything..."
|
|
251
|
+
|
|
252
|
+
def __init__(self, input_mapping: Optional[dict[str, dict]], final_message: Optional[str] = None):
|
|
253
|
+
self.input_mapping = input_mapping
|
|
254
|
+
# Apply fallback logic for empty/None values
|
|
255
|
+
if final_message and final_message.strip():
|
|
256
|
+
self.final_message = final_message.strip()
|
|
257
|
+
else:
|
|
258
|
+
self.final_message = self.DEFAULT_FINAL_MSG
|
|
259
|
+
|
|
260
|
+
def invoke(self, state: BaseStore, config: Optional[RunnableConfig] = None) -> dict:
|
|
261
|
+
logger.info(f"Printer Node - Current state variables: {state}")
|
|
262
|
+
result = {}
|
|
263
|
+
logger.debug(f"Initial text pattern: {self.input_mapping}")
|
|
264
|
+
mapping = propagate_the_input_mapping(self.input_mapping, [], state)
|
|
265
|
+
# for printer node we expect that all the lists will be joined into strings already
|
|
266
|
+
# Join any lists that haven't been converted yet
|
|
267
|
+
for key, value in mapping.items():
|
|
268
|
+
if isinstance(value, list):
|
|
269
|
+
mapping[key] = ', '.join(str(item) for item in value)
|
|
270
|
+
if mapping.get(PRINTER) is None:
|
|
271
|
+
raise ToolException(f"PrinterNode requires '{PRINTER}' field in input mapping")
|
|
272
|
+
formatted_output = mapping[PRINTER]
|
|
273
|
+
# add info label to the printer's output
|
|
274
|
+
if not formatted_output == PRINTER_COMPLETED_STATE:
|
|
275
|
+
# convert formatted output to string if it's not
|
|
276
|
+
if not isinstance(formatted_output, str):
|
|
277
|
+
formatted_output = str(formatted_output)
|
|
278
|
+
formatted_output += f"\n\n-----\n*{self.final_message}*"
|
|
279
|
+
logger.debug(f"Formatted output: {formatted_output}")
|
|
280
|
+
result[PRINTER_NODE_RS] = formatted_output
|
|
233
281
|
return result
|
|
234
282
|
|
|
235
283
|
|
|
@@ -348,8 +396,8 @@ class StateModifierNode(Runnable):
|
|
|
348
396
|
return result
|
|
349
397
|
|
|
350
398
|
|
|
351
|
-
|
|
352
|
-
|
|
399
|
+
def prepare_output_schema(lg_builder, memory, store, debug=False, interrupt_before=None, interrupt_after=None,
|
|
400
|
+
state_class=None, output_variables=None):
|
|
353
401
|
# prepare output channels
|
|
354
402
|
if interrupt_after is None:
|
|
355
403
|
interrupt_after = []
|
|
@@ -414,6 +462,50 @@ def prepare_output_schema(lg_builder, memory, store, debug=False, interrupt_befo
|
|
|
414
462
|
return compiled
|
|
415
463
|
|
|
416
464
|
|
|
465
|
+
def find_tool_by_name_or_metadata(tools: list, tool_name: str, toolkit_name: Optional[str] = None) -> Optional[BaseTool]:
|
|
466
|
+
"""
|
|
467
|
+
Find a tool by name or by matching metadata (toolkit_name + tool_name).
|
|
468
|
+
|
|
469
|
+
For toolkit nodes with toolkit_name specified, this function checks:
|
|
470
|
+
1. Metadata match first (toolkit_name + tool_name) - PRIORITY when toolkit_name is provided
|
|
471
|
+
2. Direct tool name match (backward compatibility fallback)
|
|
472
|
+
|
|
473
|
+
For toolkit nodes without toolkit_name, or other node types:
|
|
474
|
+
1. Direct tool name match
|
|
475
|
+
|
|
476
|
+
Args:
|
|
477
|
+
tools: List of available tools
|
|
478
|
+
tool_name: The tool name to search for
|
|
479
|
+
toolkit_name: Optional toolkit name for metadata matching
|
|
480
|
+
|
|
481
|
+
Returns:
|
|
482
|
+
The matching tool or None if not found
|
|
483
|
+
"""
|
|
484
|
+
# When toolkit_name is specified, prioritize metadata matching
|
|
485
|
+
if toolkit_name:
|
|
486
|
+
for tool in tools:
|
|
487
|
+
# Check metadata match first
|
|
488
|
+
if hasattr(tool, 'metadata') and tool.metadata:
|
|
489
|
+
metadata_toolkit_name = tool.metadata.get(TOOLKIT_NAME_META)
|
|
490
|
+
metadata_tool_name = tool.metadata.get(TOOL_NAME_META)
|
|
491
|
+
|
|
492
|
+
# Match if both toolkit_name and tool_name in metadata match
|
|
493
|
+
if metadata_toolkit_name == toolkit_name and metadata_tool_name == tool_name:
|
|
494
|
+
return tool
|
|
495
|
+
|
|
496
|
+
# Fallback to direct name match for backward compatibility
|
|
497
|
+
for tool in tools:
|
|
498
|
+
if tool.name == tool_name:
|
|
499
|
+
return tool
|
|
500
|
+
else:
|
|
501
|
+
# No toolkit_name specified, use direct name match only
|
|
502
|
+
for tool in tools:
|
|
503
|
+
if tool.name == tool_name:
|
|
504
|
+
return tool
|
|
505
|
+
|
|
506
|
+
return None
|
|
507
|
+
|
|
508
|
+
|
|
417
509
|
def create_graph(
|
|
418
510
|
client: Any,
|
|
419
511
|
yaml_schema: str,
|
|
@@ -427,13 +519,25 @@ def create_graph(
|
|
|
427
519
|
):
|
|
428
520
|
""" Create a message graph from a yaml schema """
|
|
429
521
|
|
|
522
|
+
# TODO: deprecate next release (1/15/2026)
|
|
430
523
|
# For top-level graphs (not subgraphs), detect and flatten any subgraphs
|
|
431
|
-
if not for_subgraph:
|
|
432
|
-
|
|
433
|
-
|
|
434
|
-
|
|
435
|
-
|
|
436
|
-
|
|
524
|
+
# if not for_subgraph:
|
|
525
|
+
# flattened_yaml, additional_tools = detect_and_flatten_subgraphs(yaml_schema)
|
|
526
|
+
# # Add collected tools from subgraphs to the tools list
|
|
527
|
+
# tools = list(tools) + additional_tools
|
|
528
|
+
# # Use the flattened YAML for building the graph
|
|
529
|
+
# yaml_schema = flattened_yaml
|
|
530
|
+
# else:
|
|
531
|
+
# # For subgraphs, filter out PrinterNodes from YAML
|
|
532
|
+
# from ..toolkits.subgraph import _filter_printer_nodes_from_yaml
|
|
533
|
+
# yaml_schema = _filter_printer_nodes_from_yaml(yaml_schema)
|
|
534
|
+
# logger.info("Filtered PrinterNodes from subgraph YAML in create_graph")
|
|
535
|
+
|
|
536
|
+
if for_subgraph:
|
|
537
|
+
# Sanitization for sub-graphs
|
|
538
|
+
from ..toolkits.subgraph import _filter_printer_nodes_from_yaml
|
|
539
|
+
yaml_schema = _filter_printer_nodes_from_yaml(yaml_schema)
|
|
540
|
+
logger.info("Filtered PrinterNodes from subgraph YAML in create_graph")
|
|
437
541
|
|
|
438
542
|
schema = yaml.safe_load(yaml_schema)
|
|
439
543
|
logger.debug(f"Schema: {schema}")
|
|
@@ -449,16 +553,37 @@ def create_graph(
|
|
|
449
553
|
node_type = node.get('type', 'function')
|
|
450
554
|
node_id = clean_string(node['id'])
|
|
451
555
|
toolkit_name = node.get('toolkit_name')
|
|
452
|
-
tool_name = clean_string(node.get('tool',
|
|
453
|
-
|
|
454
|
-
tool_name = f"{clean_string(toolkit_name)}{TOOLKIT_SPLITTER}{tool_name}"
|
|
556
|
+
tool_name = clean_string(node.get('tool', ''))
|
|
557
|
+
# Tool names are now clean (no prefix needed)
|
|
455
558
|
logger.info(f"Node: {node_id} : {node_type} - {tool_name}")
|
|
456
|
-
if node_type in ['function', 'tool', 'loop', 'loop_from_tool', 'indexer', 'subgraph', 'pipeline', 'agent']:
|
|
457
|
-
|
|
458
|
-
|
|
459
|
-
|
|
559
|
+
if node_type in ['function', 'toolkit', 'mcp', 'tool', 'loop', 'loop_from_tool', 'indexer', 'subgraph', 'pipeline', 'agent']:
|
|
560
|
+
if node_type in ['mcp', 'toolkit', 'agent'] and not tool_name:
|
|
561
|
+
# tool is not specified
|
|
562
|
+
raise ToolException(f"Tool name is required for {node_type} node with id '{node_id}'")
|
|
563
|
+
|
|
564
|
+
# Unified validation and tool finding for toolkit, mcp, and agent node types
|
|
565
|
+
matching_tool = None
|
|
566
|
+
if node_type in ['toolkit', 'mcp', 'agent']:
|
|
567
|
+
# Use enhanced validation that checks both direct name and metadata
|
|
568
|
+
matching_tool = find_tool_by_name_or_metadata(tools, tool_name, toolkit_name)
|
|
569
|
+
if not matching_tool:
|
|
570
|
+
# tool is not found in the provided tools
|
|
571
|
+
error_msg = f"Node `{node_id}` with type `{node_type}` has tool '{tool_name}'"
|
|
572
|
+
if toolkit_name:
|
|
573
|
+
error_msg += f" (toolkit: '{toolkit_name}')"
|
|
574
|
+
error_msg += f" which is not found in the provided tools. Make sure it is connected properly. Available tools: {format_tools(tools)}"
|
|
575
|
+
raise ToolException(error_msg)
|
|
576
|
+
else:
|
|
577
|
+
# For other node types, find tool by direct name match
|
|
578
|
+
for tool in tools:
|
|
579
|
+
if tool.name == tool_name:
|
|
580
|
+
matching_tool = tool
|
|
581
|
+
break
|
|
582
|
+
|
|
583
|
+
if matching_tool:
|
|
584
|
+
if node_type in ['function', 'toolkit', 'mcp']:
|
|
460
585
|
lg_builder.add_node(node_id, FunctionTool(
|
|
461
|
-
tool=
|
|
586
|
+
tool=matching_tool, name=node_id, return_type='dict',
|
|
462
587
|
output_variables=node.get('output', []),
|
|
463
588
|
input_mapping=node.get('input_mapping',
|
|
464
589
|
{'messages': {'type': 'variable', 'value': 'messages'}}),
|
|
@@ -466,24 +591,26 @@ def create_graph(
|
|
|
466
591
|
elif node_type == 'agent':
|
|
467
592
|
input_params = node.get('input', ['messages'])
|
|
468
593
|
input_mapping = node.get('input_mapping',
|
|
469
|
-
|
|
594
|
+
{'messages': {'type': 'variable', 'value': 'messages'}})
|
|
595
|
+
output_vars = node.get('output', [])
|
|
470
596
|
lg_builder.add_node(node_id, FunctionTool(
|
|
471
|
-
client=client, tool=
|
|
597
|
+
client=client, tool=matching_tool,
|
|
472
598
|
name=node_id, return_type='str',
|
|
473
|
-
output_variables=
|
|
599
|
+
output_variables=output_vars + ['messages'] if 'messages' not in output_vars else output_vars,
|
|
474
600
|
input_variables=input_params,
|
|
475
601
|
input_mapping= input_mapping
|
|
476
602
|
))
|
|
477
603
|
elif node_type == 'subgraph' or node_type == 'pipeline':
|
|
478
604
|
# assign parent memory/store
|
|
479
|
-
#
|
|
480
|
-
#
|
|
605
|
+
# matching_tool.checkpointer = memory
|
|
606
|
+
# matching_tool.store = store
|
|
481
607
|
# wrap with mappings
|
|
482
608
|
pipeline_name = node.get('tool', None)
|
|
483
609
|
if not pipeline_name:
|
|
484
|
-
raise ValueError(
|
|
610
|
+
raise ValueError(
|
|
611
|
+
"Subgraph must have a 'tool' node: add required tool to the subgraph node")
|
|
485
612
|
node_fn = SubgraphRunnable(
|
|
486
|
-
inner=
|
|
613
|
+
inner=matching_tool.graph,
|
|
487
614
|
name=pipeline_name,
|
|
488
615
|
input_mapping=node.get('input_mapping', {}),
|
|
489
616
|
output_mapping=node.get('output_mapping', {}),
|
|
@@ -492,25 +619,16 @@ def create_graph(
|
|
|
492
619
|
break # skip legacy handling
|
|
493
620
|
elif node_type == 'tool':
|
|
494
621
|
lg_builder.add_node(node_id, ToolNode(
|
|
495
|
-
client=client, tool=
|
|
622
|
+
client=client, tool=matching_tool,
|
|
496
623
|
name=node_id, return_type='dict',
|
|
497
624
|
output_variables=node.get('output', []),
|
|
498
625
|
input_variables=node.get('input', ['messages']),
|
|
499
626
|
structured_output=node.get('structured_output', False),
|
|
500
627
|
task=node.get('task')
|
|
501
628
|
))
|
|
502
|
-
# TODO: decide on struct output for agent nodes
|
|
503
|
-
# elif node_type == 'agent':
|
|
504
|
-
# lg_builder.add_node(node_id, AgentNode(
|
|
505
|
-
# client=client, tool=tool,
|
|
506
|
-
# name=node['id'], return_type='dict',
|
|
507
|
-
# output_variables=node.get('output', []),
|
|
508
|
-
# input_variables=node.get('input', ['messages']),
|
|
509
|
-
# task=node.get('task')
|
|
510
|
-
# ))
|
|
511
629
|
elif node_type == 'loop':
|
|
512
630
|
lg_builder.add_node(node_id, LoopNode(
|
|
513
|
-
client=client, tool=
|
|
631
|
+
client=client, tool=matching_tool,
|
|
514
632
|
name=node_id, return_type='dict',
|
|
515
633
|
output_variables=node.get('output', []),
|
|
516
634
|
input_variables=node.get('input', ['messages']),
|
|
@@ -520,14 +638,15 @@ def create_graph(
|
|
|
520
638
|
loop_toolkit_name = node.get('loop_toolkit_name')
|
|
521
639
|
loop_tool_name = node.get('loop_tool')
|
|
522
640
|
if (loop_toolkit_name and loop_tool_name) or loop_tool_name:
|
|
523
|
-
|
|
641
|
+
# Use clean tool name (no prefix)
|
|
642
|
+
loop_tool_name = clean_string(loop_tool_name)
|
|
524
643
|
for t in tools:
|
|
525
644
|
if t.name == loop_tool_name:
|
|
526
645
|
logger.debug(f"Loop tool discovered: {t}")
|
|
527
646
|
lg_builder.add_node(node_id, LoopToolNode(
|
|
528
647
|
client=client,
|
|
529
648
|
name=node_id, return_type='dict',
|
|
530
|
-
tool=
|
|
649
|
+
tool=matching_tool, loop_tool=t,
|
|
531
650
|
variables_mapping=node.get('variables_mapping', {}),
|
|
532
651
|
output_variables=node.get('output', []),
|
|
533
652
|
input_variables=node.get('input', ['messages']),
|
|
@@ -543,7 +662,7 @@ def create_graph(
|
|
|
543
662
|
indexer_tool = t
|
|
544
663
|
logger.info(f"Indexer tool: {indexer_tool}")
|
|
545
664
|
lg_builder.add_node(node_id, IndexerNode(
|
|
546
|
-
client=client, tool=
|
|
665
|
+
client=client, tool=matching_tool,
|
|
547
666
|
index_tool=indexer_tool,
|
|
548
667
|
input_mapping=node.get('input_mapping', {}),
|
|
549
668
|
name=node_id, return_type='dict',
|
|
@@ -552,10 +671,10 @@ def create_graph(
|
|
|
552
671
|
output_variables=node.get('output', []),
|
|
553
672
|
input_variables=node.get('input', ['messages']),
|
|
554
673
|
structured_output=node.get('structured_output', False)))
|
|
555
|
-
break
|
|
556
674
|
elif node_type == 'code':
|
|
557
675
|
from ..tools.sandbox import create_sandbox_tool
|
|
558
|
-
sandbox_tool = create_sandbox_tool(stateful=False, allow_net=True
|
|
676
|
+
sandbox_tool = create_sandbox_tool(stateful=False, allow_net=True,
|
|
677
|
+
alita_client=kwargs.get('alita_client', None))
|
|
559
678
|
code_data = node.get('code', {'type': 'fixed', 'value': "return 'Code block is empty'"})
|
|
560
679
|
lg_builder.add_node(node_id, FunctionTool(
|
|
561
680
|
tool=sandbox_tool, name=node['id'], return_type='dict',
|
|
@@ -577,10 +696,10 @@ def create_graph(
|
|
|
577
696
|
tool_names = []
|
|
578
697
|
if isinstance(connected_tools, dict):
|
|
579
698
|
for toolkit, selected_tools in connected_tools.items():
|
|
580
|
-
|
|
581
|
-
|
|
699
|
+
# Add tool names directly (no prefix)
|
|
700
|
+
tool_names.extend(selected_tools)
|
|
582
701
|
elif isinstance(connected_tools, list):
|
|
583
|
-
#
|
|
702
|
+
# Use provided tool names as-is
|
|
584
703
|
tool_names = connected_tools
|
|
585
704
|
|
|
586
705
|
if tool_names:
|
|
@@ -593,7 +712,7 @@ def create_graph(
|
|
|
593
712
|
else:
|
|
594
713
|
# Use all available tools
|
|
595
714
|
available_tools = [tool for tool in tools if isinstance(tool, BaseTool)]
|
|
596
|
-
|
|
715
|
+
|
|
597
716
|
lg_builder.add_node(node_id, LLMNode(
|
|
598
717
|
client=client,
|
|
599
718
|
input_mapping=node.get('input_mapping', {'messages': {'type': 'variable', 'value': 'messages'}}),
|
|
@@ -603,17 +722,34 @@ def create_graph(
|
|
|
603
722
|
output_variables=output_vars,
|
|
604
723
|
input_variables=node.get('input', ['messages']),
|
|
605
724
|
structured_output=node.get('structured_output', False),
|
|
725
|
+
tool_execution_timeout=node.get('tool_execution_timeout', 900),
|
|
606
726
|
available_tools=available_tools,
|
|
607
|
-
tool_names=tool_names
|
|
608
|
-
|
|
609
|
-
# Add a RouterNode as an independent node
|
|
610
|
-
lg_builder.add_node(node_id, RouterNode(
|
|
611
|
-
name=node_id,
|
|
612
|
-
condition=node.get('condition', ''),
|
|
613
|
-
routes=node.get('routes', []),
|
|
614
|
-
default_output=node.get('default_output', 'END'),
|
|
615
|
-
input_variables=node.get('input', ['messages'])
|
|
727
|
+
tool_names=tool_names,
|
|
728
|
+
steps_limit=kwargs.get('steps_limit', 25)
|
|
616
729
|
))
|
|
730
|
+
elif node_type in ['router', 'decision']:
|
|
731
|
+
if node_type == 'router':
|
|
732
|
+
# Add a RouterNode as an independent node
|
|
733
|
+
lg_builder.add_node(node_id, RouterNode(
|
|
734
|
+
name=node_id,
|
|
735
|
+
condition=node.get('condition', ''),
|
|
736
|
+
routes=node.get('routes', []),
|
|
737
|
+
default_output=node.get('default_output', 'END'),
|
|
738
|
+
input_variables=node.get('input', ['messages'])
|
|
739
|
+
))
|
|
740
|
+
elif node_type == 'decision':
|
|
741
|
+
logger.info(f'Adding decision: {node["nodes"]}')
|
|
742
|
+
# fallback to old-style decision node
|
|
743
|
+
decisional_inputs = node.get('decisional_inputs')
|
|
744
|
+
decisional_inputs = node.get('input', ['messages']) if not decisional_inputs else decisional_inputs
|
|
745
|
+
lg_builder.add_node(node_id, DecisionEdge(
|
|
746
|
+
client, node['nodes'],
|
|
747
|
+
node.get('description', ""),
|
|
748
|
+
decisional_inputs=decisional_inputs,
|
|
749
|
+
default_output=node.get('default_output', 'END'),
|
|
750
|
+
is_node=True
|
|
751
|
+
))
|
|
752
|
+
|
|
617
753
|
# Add a single conditional edge for all routes
|
|
618
754
|
lg_builder.add_conditional_edges(
|
|
619
755
|
node_id,
|
|
@@ -624,6 +760,7 @@ def create_graph(
|
|
|
624
760
|
default_output=node.get('default_output', 'END')
|
|
625
761
|
)
|
|
626
762
|
)
|
|
763
|
+
continue
|
|
627
764
|
elif node_type == 'state_modifier':
|
|
628
765
|
lg_builder.add_node(node_id, StateModifierNode(
|
|
629
766
|
template=node.get('template', ''),
|
|
@@ -631,6 +768,23 @@ def create_graph(
|
|
|
631
768
|
input_variables=node.get('input', ['messages']),
|
|
632
769
|
output_variables=node.get('output', [])
|
|
633
770
|
))
|
|
771
|
+
elif node_type == 'printer':
|
|
772
|
+
lg_builder.add_node(node_id, PrinterNode(
|
|
773
|
+
input_mapping=node.get('input_mapping', {'printer': {'type': 'fixed', 'value': ''}}),
|
|
774
|
+
final_message=node.get('final_message'),
|
|
775
|
+
))
|
|
776
|
+
|
|
777
|
+
# add interrupts after printer node if specified
|
|
778
|
+
interrupt_after.append(clean_string(node_id))
|
|
779
|
+
|
|
780
|
+
# reset printer output variable to avoid carrying over
|
|
781
|
+
reset_node_id = f"{node_id}_reset"
|
|
782
|
+
lg_builder.add_node(reset_node_id, PrinterNode(
|
|
783
|
+
input_mapping={'printer': {'type': 'fixed', 'value': PRINTER_COMPLETED_STATE}}
|
|
784
|
+
))
|
|
785
|
+
lg_builder.add_conditional_edges(node_id, TransitionalEdge(reset_node_id))
|
|
786
|
+
lg_builder.add_conditional_edges(reset_node_id, TransitionalEdge(clean_string(node['transition'])))
|
|
787
|
+
continue
|
|
634
788
|
if node.get('transition'):
|
|
635
789
|
next_step = clean_string(node['transition'])
|
|
636
790
|
logger.info(f'Adding transition: {next_step}')
|
|
@@ -687,8 +841,20 @@ def create_graph(
|
|
|
687
841
|
debug=debug,
|
|
688
842
|
)
|
|
689
843
|
except ValueError as e:
|
|
690
|
-
|
|
691
|
-
|
|
844
|
+
# Build a clearer debug message without complex f-string expressions
|
|
845
|
+
debug_nodes = "\n*".join(lg_builder.nodes.keys()) if lg_builder and lg_builder.nodes else ""
|
|
846
|
+
debug_message = (
|
|
847
|
+
"Validation of the schema failed. {err}\n\n"
|
|
848
|
+
"DEBUG INFO:**Schema Nodes:**\n\n*{nodes}\n\n"
|
|
849
|
+
"**Schema Edges:**\n\n{edges}\n\n"
|
|
850
|
+
"**Tools Available:**\n\n{tools}"
|
|
851
|
+
).format(
|
|
852
|
+
err=e,
|
|
853
|
+
nodes=debug_nodes,
|
|
854
|
+
edges=lg_builder.edges if lg_builder else {},
|
|
855
|
+
tools=format_tools(tools),
|
|
856
|
+
)
|
|
857
|
+
raise ValueError(debug_message)
|
|
692
858
|
# If building a nested subgraph, return the raw CompiledStateGraph
|
|
693
859
|
if for_subgraph:
|
|
694
860
|
return graph
|
|
@@ -702,6 +868,14 @@ def create_graph(
|
|
|
702
868
|
)
|
|
703
869
|
return compiled.validate()
|
|
704
870
|
|
|
871
|
+
def format_tools(tools_list: list) -> str:
|
|
872
|
+
"""Format a list of tool names into a comma-separated string."""
|
|
873
|
+
try:
|
|
874
|
+
return ', '.join([tool.name for tool in tools_list])
|
|
875
|
+
except Exception as e:
|
|
876
|
+
logger.warning(f"Failed to format tools list: {e}")
|
|
877
|
+
return str(tools_list)
|
|
878
|
+
|
|
705
879
|
def set_defaults(d):
|
|
706
880
|
"""Set default values for dictionary entries based on their type."""
|
|
707
881
|
type_defaults = {
|
|
@@ -772,55 +946,202 @@ class LangGraphAgentRunnable(CompiledStateGraph):
|
|
|
772
946
|
if not config.get("configurable", {}).get("thread_id", ""):
|
|
773
947
|
config["configurable"] = {"thread_id": str(uuid4())}
|
|
774
948
|
thread_id = config.get("configurable", {}).get("thread_id")
|
|
949
|
+
|
|
950
|
+
# Check if checkpoint exists early for chat_history handling
|
|
951
|
+
checkpoint_exists = self.checkpointer and self.checkpointer.get_tuple(config)
|
|
952
|
+
|
|
775
953
|
# Handle chat history and current input properly
|
|
776
954
|
if input.get('chat_history') and not input.get('messages'):
|
|
777
|
-
|
|
778
|
-
|
|
779
|
-
|
|
780
|
-
|
|
955
|
+
if checkpoint_exists:
|
|
956
|
+
# Checkpoint already has conversation history - discard redundant chat_history
|
|
957
|
+
input.pop('chat_history', None)
|
|
958
|
+
else:
|
|
959
|
+
# No checkpoint - convert chat history dict messages to LangChain message objects
|
|
960
|
+
chat_history = input.pop('chat_history')
|
|
961
|
+
input['messages'] = [convert_dict_to_message(msg) for msg in chat_history]
|
|
962
|
+
|
|
963
|
+
# handler for LLM node: if no input (Chat perspective), then take last human message
|
|
964
|
+
# Track if input came from messages to handle content extraction properly
|
|
965
|
+
input_from_messages = False
|
|
966
|
+
if not input.get('input'):
|
|
967
|
+
if input.get('messages'):
|
|
968
|
+
input['input'] = [next((msg for msg in reversed(input['messages']) if isinstance(msg, HumanMessage)),
|
|
969
|
+
None)]
|
|
970
|
+
if input['input'] is not None:
|
|
971
|
+
input_from_messages = True
|
|
972
|
+
|
|
781
973
|
# Append current input to existing messages instead of overwriting
|
|
782
974
|
if input.get('input'):
|
|
783
975
|
if isinstance(input['input'], str):
|
|
784
976
|
current_message = input['input']
|
|
785
977
|
else:
|
|
978
|
+
# input can be a list of messages or a single message object
|
|
786
979
|
current_message = input.get('input')[-1]
|
|
980
|
+
|
|
787
981
|
# TODO: add handler after we add 2+ inputs (filterByType, etc.)
|
|
788
|
-
|
|
982
|
+
if isinstance(current_message, HumanMessage):
|
|
983
|
+
current_content = current_message.content
|
|
984
|
+
if isinstance(current_content, list):
|
|
985
|
+
# Extract text parts and keep non-text parts (images, etc.)
|
|
986
|
+
text_contents = []
|
|
987
|
+
non_text_parts = []
|
|
988
|
+
|
|
989
|
+
for item in current_content:
|
|
990
|
+
if isinstance(item, dict) and item.get('type') == 'text':
|
|
991
|
+
text_contents.append(item['text'])
|
|
992
|
+
elif isinstance(item, str):
|
|
993
|
+
text_contents.append(item)
|
|
994
|
+
else:
|
|
995
|
+
# Keep image_url and other non-text content
|
|
996
|
+
non_text_parts.append(item)
|
|
997
|
+
|
|
998
|
+
# Set input to the joined text
|
|
999
|
+
input['input'] = ". ".join(text_contents) if text_contents else ""
|
|
1000
|
+
|
|
1001
|
+
# If this message came from input['messages'], update or remove it
|
|
1002
|
+
if input_from_messages:
|
|
1003
|
+
if non_text_parts:
|
|
1004
|
+
# Keep the message but only with non-text content (images, etc.)
|
|
1005
|
+
current_message.content = non_text_parts
|
|
1006
|
+
else:
|
|
1007
|
+
# All content was text, remove this message from the list
|
|
1008
|
+
input['messages'] = [msg for msg in input['messages'] if msg is not current_message]
|
|
1009
|
+
else:
|
|
1010
|
+
# Message came from input['input'], not from input['messages']
|
|
1011
|
+
# If there are non-text parts (images, etc.), preserve them in messages
|
|
1012
|
+
if non_text_parts:
|
|
1013
|
+
# Initialize messages if it doesn't exist or is empty
|
|
1014
|
+
if not input.get('messages'):
|
|
1015
|
+
input['messages'] = []
|
|
1016
|
+
# Create a new message with only non-text content
|
|
1017
|
+
non_text_message = HumanMessage(content=non_text_parts)
|
|
1018
|
+
input['messages'].append(non_text_message)
|
|
1019
|
+
|
|
1020
|
+
elif isinstance(current_content, str):
|
|
1021
|
+
# on regenerate case
|
|
1022
|
+
input['input'] = current_content
|
|
1023
|
+
# If from messages and all content is text, remove the message
|
|
1024
|
+
if input_from_messages:
|
|
1025
|
+
input['messages'] = [msg for msg in input['messages'] if msg is not current_message]
|
|
1026
|
+
else:
|
|
1027
|
+
input['input'] = str(current_content)
|
|
1028
|
+
# If from messages, remove since we extracted the content
|
|
1029
|
+
if input_from_messages:
|
|
1030
|
+
input['messages'] = [msg for msg in input['messages'] if msg is not current_message]
|
|
1031
|
+
elif isinstance(current_message, str):
|
|
1032
|
+
input['input'] = current_message
|
|
1033
|
+
else:
|
|
1034
|
+
input['input'] = str(current_message)
|
|
789
1035
|
if input.get('messages'):
|
|
790
1036
|
# Ensure existing messages are LangChain objects
|
|
791
1037
|
input['messages'] = [convert_dict_to_message(msg) for msg in input['messages']]
|
|
792
1038
|
# Append to existing messages
|
|
793
|
-
input['messages'].append(current_message)
|
|
1039
|
+
# input['messages'].append(current_message)
|
|
1040
|
+
# else:
|
|
1041
|
+
# NOTE: Commented out to prevent duplicates with input['input']
|
|
1042
|
+
# input['messages'] = [current_message]
|
|
1043
|
+
|
|
1044
|
+
# Validate that input is not empty after all processing
|
|
1045
|
+
if not input.get('input'):
|
|
1046
|
+
raise RuntimeError(
|
|
1047
|
+
"Empty input after processing. Cannot send empty string to LLM. "
|
|
1048
|
+
"This likely means the message contained only non-text content "
|
|
1049
|
+
"with no accompanying text."
|
|
1050
|
+
)
|
|
1051
|
+
|
|
1052
|
+
logger.info(f"Input: {thread_id} - {input}")
|
|
1053
|
+
try:
|
|
1054
|
+
if self.checkpointer and self.checkpointer.get_tuple(config):
|
|
1055
|
+
if config.pop("should_continue", False):
|
|
1056
|
+
invoke_input = input
|
|
1057
|
+
else:
|
|
1058
|
+
self.update_state(config, input)
|
|
1059
|
+
invoke_input = None
|
|
1060
|
+
result = super().invoke(invoke_input, config=config, *args, **kwargs)
|
|
794
1061
|
else:
|
|
795
|
-
|
|
796
|
-
|
|
797
|
-
|
|
798
|
-
|
|
799
|
-
self.
|
|
800
|
-
|
|
801
|
-
|
|
802
|
-
|
|
1062
|
+
result = super().invoke(input, config=config, *args, **kwargs)
|
|
1063
|
+
except GraphRecursionError as e:
|
|
1064
|
+
current_recursion_limit = config.get("recursion_limit", 0)
|
|
1065
|
+
logger.warning("ToolExecutionLimitReached caught in LangGraphAgentRunnable: %s", e)
|
|
1066
|
+
return self._handle_graph_recursion_error(
|
|
1067
|
+
config=config,
|
|
1068
|
+
thread_id=thread_id,
|
|
1069
|
+
current_recursion_limit=current_recursion_limit,
|
|
1070
|
+
)
|
|
1071
|
+
|
|
803
1072
|
try:
|
|
804
|
-
|
|
805
|
-
|
|
806
|
-
|
|
1073
|
+
# Check if printer node output exists
|
|
1074
|
+
printer_output = result.get(PRINTER_NODE_RS)
|
|
1075
|
+
if printer_output == PRINTER_COMPLETED_STATE:
|
|
1076
|
+
# Printer completed, extract last AI message
|
|
1077
|
+
messages = result['messages']
|
|
1078
|
+
output = next(
|
|
1079
|
+
(msg.content for msg in reversed(messages)
|
|
1080
|
+
if not isinstance(msg, HumanMessage)),
|
|
1081
|
+
messages[-1].content
|
|
1082
|
+
) if messages else result.get('output')
|
|
1083
|
+
elif printer_output is not None:
|
|
1084
|
+
# Printer node has output (interrupted state)
|
|
1085
|
+
output = printer_output
|
|
1086
|
+
else:
|
|
1087
|
+
# No printer node, extract last AI message from messages
|
|
1088
|
+
messages = result.get('messages', [])
|
|
1089
|
+
output = next(
|
|
1090
|
+
(msg.content for msg in reversed(messages)
|
|
1091
|
+
if not isinstance(msg, HumanMessage)),
|
|
1092
|
+
None
|
|
1093
|
+
)
|
|
1094
|
+
except Exception:
|
|
1095
|
+
# Fallback: try to get last value or last message
|
|
1096
|
+
output = str(list(result.values())[-1]) if result else 'Output is undefined'
|
|
807
1097
|
config_state = self.get_state(config)
|
|
808
1098
|
is_execution_finished = not config_state.next
|
|
809
1099
|
if is_execution_finished:
|
|
810
1100
|
thread_id = None
|
|
811
1101
|
|
|
812
|
-
|
|
1102
|
+
final_output = f"Assistant run has been completed, but output is None.\nAdding last message if any: {messages[-1] if messages else []}" if is_execution_finished and output is None else output
|
|
813
1103
|
|
|
814
1104
|
result_with_state = {
|
|
815
|
-
"output":
|
|
1105
|
+
"output": final_output,
|
|
816
1106
|
"thread_id": thread_id,
|
|
817
1107
|
"execution_finished": is_execution_finished
|
|
818
1108
|
}
|
|
819
1109
|
|
|
820
1110
|
# Include all state values in the result
|
|
821
1111
|
if hasattr(config_state, 'values') and config_state.values:
|
|
1112
|
+
# except of key = 'output' which is already included
|
|
1113
|
+
for key, value in config_state.values.items():
|
|
1114
|
+
if key != 'output':
|
|
1115
|
+
result_with_state[key] = value
|
|
1116
|
+
|
|
1117
|
+
return result_with_state
|
|
1118
|
+
|
|
1119
|
+
def _handle_graph_recursion_error(
|
|
1120
|
+
self,
|
|
1121
|
+
config: RunnableConfig,
|
|
1122
|
+
thread_id: str,
|
|
1123
|
+
current_recursion_limit: int,
|
|
1124
|
+
) -> dict:
|
|
1125
|
+
"""Handle GraphRecursionError by returning a soft-boundary response."""
|
|
1126
|
+
config_state = self.get_state(config)
|
|
1127
|
+
is_execution_finished = False
|
|
1128
|
+
|
|
1129
|
+
friendly_output = (
|
|
1130
|
+
f"Tool step limit {current_recursion_limit} reached for this run. You can continue by sending another "
|
|
1131
|
+
"message or refining your request."
|
|
1132
|
+
)
|
|
1133
|
+
|
|
1134
|
+
result_with_state: dict[str, Any] = {
|
|
1135
|
+
"output": friendly_output,
|
|
1136
|
+
"thread_id": thread_id,
|
|
1137
|
+
"execution_finished": is_execution_finished,
|
|
1138
|
+
"tool_execution_limit_reached": True,
|
|
1139
|
+
}
|
|
1140
|
+
|
|
1141
|
+
if hasattr(config_state, "values") and config_state.values:
|
|
822
1142
|
for key, value in config_state.values.items():
|
|
823
|
-
|
|
1143
|
+
if key != "output":
|
|
1144
|
+
result_with_state[key] = value
|
|
824
1145
|
|
|
825
1146
|
return result_with_state
|
|
826
1147
|
|