alita-sdk 0.3.257__py3-none-any.whl → 0.3.584__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of alita-sdk might be problematic. Click here for more details.
- alita_sdk/cli/__init__.py +10 -0
- alita_sdk/cli/__main__.py +17 -0
- alita_sdk/cli/agent/__init__.py +5 -0
- alita_sdk/cli/agent/default.py +258 -0
- alita_sdk/cli/agent_executor.py +155 -0
- alita_sdk/cli/agent_loader.py +215 -0
- alita_sdk/cli/agent_ui.py +228 -0
- alita_sdk/cli/agents.py +3794 -0
- alita_sdk/cli/callbacks.py +647 -0
- alita_sdk/cli/cli.py +168 -0
- alita_sdk/cli/config.py +306 -0
- alita_sdk/cli/context/__init__.py +30 -0
- alita_sdk/cli/context/cleanup.py +198 -0
- alita_sdk/cli/context/manager.py +731 -0
- alita_sdk/cli/context/message.py +285 -0
- alita_sdk/cli/context/strategies.py +289 -0
- alita_sdk/cli/context/token_estimation.py +127 -0
- alita_sdk/cli/formatting.py +182 -0
- alita_sdk/cli/input_handler.py +419 -0
- alita_sdk/cli/inventory.py +1073 -0
- alita_sdk/cli/mcp_loader.py +315 -0
- alita_sdk/cli/toolkit.py +327 -0
- alita_sdk/cli/toolkit_loader.py +85 -0
- alita_sdk/cli/tools/__init__.py +43 -0
- alita_sdk/cli/tools/approval.py +224 -0
- alita_sdk/cli/tools/filesystem.py +1751 -0
- alita_sdk/cli/tools/planning.py +389 -0
- alita_sdk/cli/tools/terminal.py +414 -0
- alita_sdk/community/__init__.py +72 -12
- alita_sdk/community/inventory/__init__.py +236 -0
- alita_sdk/community/inventory/config.py +257 -0
- alita_sdk/community/inventory/enrichment.py +2137 -0
- alita_sdk/community/inventory/extractors.py +1469 -0
- alita_sdk/community/inventory/ingestion.py +3172 -0
- alita_sdk/community/inventory/knowledge_graph.py +1457 -0
- alita_sdk/community/inventory/parsers/__init__.py +218 -0
- alita_sdk/community/inventory/parsers/base.py +295 -0
- alita_sdk/community/inventory/parsers/csharp_parser.py +907 -0
- alita_sdk/community/inventory/parsers/go_parser.py +851 -0
- alita_sdk/community/inventory/parsers/html_parser.py +389 -0
- alita_sdk/community/inventory/parsers/java_parser.py +593 -0
- alita_sdk/community/inventory/parsers/javascript_parser.py +629 -0
- alita_sdk/community/inventory/parsers/kotlin_parser.py +768 -0
- alita_sdk/community/inventory/parsers/markdown_parser.py +362 -0
- alita_sdk/community/inventory/parsers/python_parser.py +604 -0
- alita_sdk/community/inventory/parsers/rust_parser.py +858 -0
- alita_sdk/community/inventory/parsers/swift_parser.py +832 -0
- alita_sdk/community/inventory/parsers/text_parser.py +322 -0
- alita_sdk/community/inventory/parsers/yaml_parser.py +370 -0
- alita_sdk/community/inventory/patterns/__init__.py +61 -0
- alita_sdk/community/inventory/patterns/ast_adapter.py +380 -0
- alita_sdk/community/inventory/patterns/loader.py +348 -0
- alita_sdk/community/inventory/patterns/registry.py +198 -0
- alita_sdk/community/inventory/presets.py +535 -0
- alita_sdk/community/inventory/retrieval.py +1403 -0
- alita_sdk/community/inventory/toolkit.py +173 -0
- alita_sdk/community/inventory/toolkit_utils.py +176 -0
- alita_sdk/community/inventory/visualize.py +1370 -0
- alita_sdk/configurations/__init__.py +11 -0
- alita_sdk/configurations/ado.py +148 -2
- alita_sdk/configurations/azure_search.py +1 -1
- alita_sdk/configurations/bigquery.py +1 -1
- alita_sdk/configurations/bitbucket.py +94 -2
- alita_sdk/configurations/browser.py +18 -0
- alita_sdk/configurations/carrier.py +19 -0
- alita_sdk/configurations/confluence.py +130 -1
- alita_sdk/configurations/delta_lake.py +1 -1
- alita_sdk/configurations/figma.py +76 -5
- alita_sdk/configurations/github.py +65 -1
- alita_sdk/configurations/gitlab.py +81 -0
- alita_sdk/configurations/google_places.py +17 -0
- alita_sdk/configurations/jira.py +103 -0
- alita_sdk/configurations/openapi.py +323 -0
- alita_sdk/configurations/postman.py +1 -1
- alita_sdk/configurations/qtest.py +72 -3
- alita_sdk/configurations/report_portal.py +115 -0
- alita_sdk/configurations/salesforce.py +19 -0
- alita_sdk/configurations/service_now.py +1 -12
- alita_sdk/configurations/sharepoint.py +167 -0
- alita_sdk/configurations/sonar.py +18 -0
- alita_sdk/configurations/sql.py +20 -0
- alita_sdk/configurations/testio.py +101 -0
- alita_sdk/configurations/testrail.py +88 -0
- alita_sdk/configurations/xray.py +94 -1
- alita_sdk/configurations/zephyr_enterprise.py +94 -1
- alita_sdk/configurations/zephyr_essential.py +95 -0
- alita_sdk/runtime/clients/artifact.py +21 -4
- alita_sdk/runtime/clients/client.py +458 -67
- alita_sdk/runtime/clients/mcp_discovery.py +342 -0
- alita_sdk/runtime/clients/mcp_manager.py +262 -0
- alita_sdk/runtime/clients/sandbox_client.py +352 -0
- alita_sdk/runtime/langchain/_constants_bkup.py +1318 -0
- alita_sdk/runtime/langchain/assistant.py +183 -43
- alita_sdk/runtime/langchain/constants.py +647 -1
- alita_sdk/runtime/langchain/document_loaders/AlitaDocxMammothLoader.py +315 -3
- alita_sdk/runtime/langchain/document_loaders/AlitaExcelLoader.py +209 -31
- alita_sdk/runtime/langchain/document_loaders/AlitaImageLoader.py +1 -1
- alita_sdk/runtime/langchain/document_loaders/AlitaJSONLinesLoader.py +77 -0
- alita_sdk/runtime/langchain/document_loaders/AlitaJSONLoader.py +10 -3
- alita_sdk/runtime/langchain/document_loaders/AlitaMarkdownLoader.py +66 -0
- alita_sdk/runtime/langchain/document_loaders/AlitaPDFLoader.py +79 -10
- alita_sdk/runtime/langchain/document_loaders/AlitaPowerPointLoader.py +52 -15
- alita_sdk/runtime/langchain/document_loaders/AlitaPythonLoader.py +9 -0
- alita_sdk/runtime/langchain/document_loaders/AlitaTableLoader.py +1 -4
- alita_sdk/runtime/langchain/document_loaders/AlitaTextLoader.py +15 -2
- alita_sdk/runtime/langchain/document_loaders/ImageParser.py +30 -0
- alita_sdk/runtime/langchain/document_loaders/constants.py +189 -41
- alita_sdk/runtime/langchain/interfaces/llm_processor.py +4 -2
- alita_sdk/runtime/langchain/langraph_agent.py +493 -105
- alita_sdk/runtime/langchain/utils.py +118 -8
- alita_sdk/runtime/llms/preloaded.py +2 -6
- alita_sdk/runtime/models/mcp_models.py +61 -0
- alita_sdk/runtime/skills/__init__.py +91 -0
- alita_sdk/runtime/skills/callbacks.py +498 -0
- alita_sdk/runtime/skills/discovery.py +540 -0
- alita_sdk/runtime/skills/executor.py +610 -0
- alita_sdk/runtime/skills/input_builder.py +371 -0
- alita_sdk/runtime/skills/models.py +330 -0
- alita_sdk/runtime/skills/registry.py +355 -0
- alita_sdk/runtime/skills/skill_runner.py +330 -0
- alita_sdk/runtime/toolkits/__init__.py +28 -0
- alita_sdk/runtime/toolkits/application.py +14 -4
- alita_sdk/runtime/toolkits/artifact.py +25 -9
- alita_sdk/runtime/toolkits/datasource.py +13 -6
- alita_sdk/runtime/toolkits/mcp.py +782 -0
- alita_sdk/runtime/toolkits/planning.py +178 -0
- alita_sdk/runtime/toolkits/skill_router.py +238 -0
- alita_sdk/runtime/toolkits/subgraph.py +11 -6
- alita_sdk/runtime/toolkits/tools.py +314 -70
- alita_sdk/runtime/toolkits/vectorstore.py +11 -5
- alita_sdk/runtime/tools/__init__.py +24 -0
- alita_sdk/runtime/tools/application.py +16 -4
- alita_sdk/runtime/tools/artifact.py +367 -33
- alita_sdk/runtime/tools/data_analysis.py +183 -0
- alita_sdk/runtime/tools/function.py +100 -4
- alita_sdk/runtime/tools/graph.py +81 -0
- alita_sdk/runtime/tools/image_generation.py +218 -0
- alita_sdk/runtime/tools/llm.py +1032 -177
- alita_sdk/runtime/tools/loop.py +3 -1
- alita_sdk/runtime/tools/loop_output.py +3 -1
- alita_sdk/runtime/tools/mcp_inspect_tool.py +284 -0
- alita_sdk/runtime/tools/mcp_remote_tool.py +181 -0
- alita_sdk/runtime/tools/mcp_server_tool.py +3 -1
- alita_sdk/runtime/tools/planning/__init__.py +36 -0
- alita_sdk/runtime/tools/planning/models.py +246 -0
- alita_sdk/runtime/tools/planning/wrapper.py +607 -0
- alita_sdk/runtime/tools/router.py +2 -1
- alita_sdk/runtime/tools/sandbox.py +375 -0
- alita_sdk/runtime/tools/skill_router.py +776 -0
- alita_sdk/runtime/tools/tool.py +3 -1
- alita_sdk/runtime/tools/vectorstore.py +69 -65
- alita_sdk/runtime/tools/vectorstore_base.py +163 -90
- alita_sdk/runtime/utils/AlitaCallback.py +137 -21
- alita_sdk/runtime/utils/constants.py +5 -1
- alita_sdk/runtime/utils/mcp_client.py +492 -0
- alita_sdk/runtime/utils/mcp_oauth.py +361 -0
- alita_sdk/runtime/utils/mcp_sse_client.py +434 -0
- alita_sdk/runtime/utils/mcp_tools_discovery.py +124 -0
- alita_sdk/runtime/utils/streamlit.py +41 -14
- alita_sdk/runtime/utils/toolkit_utils.py +28 -9
- alita_sdk/runtime/utils/utils.py +48 -0
- alita_sdk/tools/__init__.py +135 -37
- alita_sdk/tools/ado/__init__.py +2 -2
- alita_sdk/tools/ado/repos/__init__.py +16 -19
- alita_sdk/tools/ado/repos/repos_wrapper.py +12 -20
- alita_sdk/tools/ado/test_plan/__init__.py +27 -8
- alita_sdk/tools/ado/test_plan/test_plan_wrapper.py +56 -28
- alita_sdk/tools/ado/wiki/__init__.py +28 -12
- alita_sdk/tools/ado/wiki/ado_wrapper.py +114 -40
- alita_sdk/tools/ado/work_item/__init__.py +28 -12
- alita_sdk/tools/ado/work_item/ado_wrapper.py +95 -11
- alita_sdk/tools/advanced_jira_mining/__init__.py +13 -8
- alita_sdk/tools/aws/delta_lake/__init__.py +15 -11
- alita_sdk/tools/aws/delta_lake/tool.py +5 -1
- alita_sdk/tools/azure_ai/search/__init__.py +14 -8
- alita_sdk/tools/base/tool.py +5 -1
- alita_sdk/tools/base_indexer_toolkit.py +454 -110
- alita_sdk/tools/bitbucket/__init__.py +28 -19
- alita_sdk/tools/bitbucket/api_wrapper.py +285 -27
- alita_sdk/tools/bitbucket/cloud_api_wrapper.py +5 -5
- alita_sdk/tools/browser/__init__.py +41 -16
- alita_sdk/tools/browser/crawler.py +3 -1
- alita_sdk/tools/browser/utils.py +15 -6
- alita_sdk/tools/carrier/__init__.py +18 -17
- alita_sdk/tools/carrier/backend_reports_tool.py +8 -4
- alita_sdk/tools/carrier/excel_reporter.py +8 -4
- alita_sdk/tools/chunkers/__init__.py +3 -1
- alita_sdk/tools/chunkers/code/codeparser.py +1 -1
- alita_sdk/tools/chunkers/sematic/json_chunker.py +2 -1
- alita_sdk/tools/chunkers/sematic/markdown_chunker.py +97 -6
- alita_sdk/tools/chunkers/sematic/proposal_chunker.py +1 -1
- alita_sdk/tools/chunkers/universal_chunker.py +270 -0
- alita_sdk/tools/cloud/aws/__init__.py +12 -7
- alita_sdk/tools/cloud/azure/__init__.py +12 -7
- alita_sdk/tools/cloud/gcp/__init__.py +12 -7
- alita_sdk/tools/cloud/k8s/__init__.py +12 -7
- alita_sdk/tools/code/linter/__init__.py +10 -8
- alita_sdk/tools/code/loaders/codesearcher.py +3 -2
- alita_sdk/tools/code/sonar/__init__.py +21 -13
- alita_sdk/tools/code_indexer_toolkit.py +199 -0
- alita_sdk/tools/confluence/__init__.py +22 -14
- alita_sdk/tools/confluence/api_wrapper.py +197 -58
- alita_sdk/tools/confluence/loader.py +14 -2
- alita_sdk/tools/custom_open_api/__init__.py +12 -5
- alita_sdk/tools/elastic/__init__.py +11 -8
- alita_sdk/tools/elitea_base.py +546 -64
- alita_sdk/tools/figma/__init__.py +60 -11
- alita_sdk/tools/figma/api_wrapper.py +1400 -167
- alita_sdk/tools/figma/figma_client.py +73 -0
- alita_sdk/tools/figma/toon_tools.py +2748 -0
- alita_sdk/tools/github/__init__.py +18 -17
- alita_sdk/tools/github/api_wrapper.py +9 -26
- alita_sdk/tools/github/github_client.py +81 -12
- alita_sdk/tools/github/schemas.py +2 -1
- alita_sdk/tools/github/tool.py +5 -1
- alita_sdk/tools/gitlab/__init__.py +19 -13
- alita_sdk/tools/gitlab/api_wrapper.py +256 -80
- alita_sdk/tools/gitlab_org/__init__.py +14 -10
- alita_sdk/tools/google/bigquery/__init__.py +14 -13
- alita_sdk/tools/google/bigquery/tool.py +5 -1
- alita_sdk/tools/google_places/__init__.py +21 -11
- alita_sdk/tools/jira/__init__.py +22 -11
- alita_sdk/tools/jira/api_wrapper.py +315 -168
- alita_sdk/tools/keycloak/__init__.py +11 -8
- alita_sdk/tools/localgit/__init__.py +9 -3
- alita_sdk/tools/localgit/local_git.py +62 -54
- alita_sdk/tools/localgit/tool.py +5 -1
- alita_sdk/tools/memory/__init__.py +38 -14
- alita_sdk/tools/non_code_indexer_toolkit.py +7 -2
- alita_sdk/tools/ocr/__init__.py +11 -8
- alita_sdk/tools/openapi/__init__.py +491 -106
- alita_sdk/tools/openapi/api_wrapper.py +1357 -0
- alita_sdk/tools/openapi/tool.py +20 -0
- alita_sdk/tools/pandas/__init__.py +20 -12
- alita_sdk/tools/pandas/api_wrapper.py +40 -45
- alita_sdk/tools/pandas/dataframe/generator/base.py +3 -1
- alita_sdk/tools/postman/__init__.py +11 -11
- alita_sdk/tools/postman/api_wrapper.py +19 -8
- alita_sdk/tools/postman/postman_analysis.py +8 -1
- alita_sdk/tools/pptx/__init__.py +11 -10
- alita_sdk/tools/qtest/__init__.py +22 -14
- alita_sdk/tools/qtest/api_wrapper.py +1784 -88
- alita_sdk/tools/rally/__init__.py +13 -10
- alita_sdk/tools/report_portal/__init__.py +23 -16
- alita_sdk/tools/salesforce/__init__.py +22 -16
- alita_sdk/tools/servicenow/__init__.py +21 -16
- alita_sdk/tools/servicenow/api_wrapper.py +1 -1
- alita_sdk/tools/sharepoint/__init__.py +17 -14
- alita_sdk/tools/sharepoint/api_wrapper.py +179 -39
- alita_sdk/tools/sharepoint/authorization_helper.py +191 -1
- alita_sdk/tools/sharepoint/utils.py +8 -2
- alita_sdk/tools/slack/__init__.py +13 -8
- alita_sdk/tools/sql/__init__.py +22 -19
- alita_sdk/tools/sql/api_wrapper.py +71 -23
- alita_sdk/tools/testio/__init__.py +21 -13
- alita_sdk/tools/testrail/__init__.py +13 -11
- alita_sdk/tools/testrail/api_wrapper.py +214 -46
- alita_sdk/tools/utils/__init__.py +28 -4
- alita_sdk/tools/utils/content_parser.py +241 -55
- alita_sdk/tools/utils/text_operations.py +254 -0
- alita_sdk/tools/vector_adapters/VectorStoreAdapter.py +83 -27
- alita_sdk/tools/xray/__init__.py +18 -14
- alita_sdk/tools/xray/api_wrapper.py +58 -113
- alita_sdk/tools/yagmail/__init__.py +9 -3
- alita_sdk/tools/zephyr/__init__.py +12 -7
- alita_sdk/tools/zephyr_enterprise/__init__.py +16 -9
- alita_sdk/tools/zephyr_enterprise/api_wrapper.py +30 -15
- alita_sdk/tools/zephyr_essential/__init__.py +16 -10
- alita_sdk/tools/zephyr_essential/api_wrapper.py +297 -54
- alita_sdk/tools/zephyr_essential/client.py +6 -4
- alita_sdk/tools/zephyr_scale/__init__.py +13 -8
- alita_sdk/tools/zephyr_scale/api_wrapper.py +39 -31
- alita_sdk/tools/zephyr_squad/__init__.py +12 -7
- {alita_sdk-0.3.257.dist-info → alita_sdk-0.3.584.dist-info}/METADATA +184 -37
- alita_sdk-0.3.584.dist-info/RECORD +452 -0
- alita_sdk-0.3.584.dist-info/entry_points.txt +2 -0
- alita_sdk/tools/bitbucket/tools.py +0 -304
- alita_sdk-0.3.257.dist-info/RECORD +0 -343
- {alita_sdk-0.3.257.dist-info → alita_sdk-0.3.584.dist-info}/WHEEL +0 -0
- {alita_sdk-0.3.257.dist-info → alita_sdk-0.3.584.dist-info}/licenses/LICENSE +0 -0
- {alita_sdk-0.3.257.dist-info → alita_sdk-0.3.584.dist-info}/top_level.txt +0 -0
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
import logging
|
|
2
|
+
import re
|
|
2
3
|
from typing import Union, Any, Optional, Annotated, get_type_hints
|
|
3
4
|
from uuid import uuid4
|
|
4
5
|
from typing import Dict
|
|
@@ -11,6 +12,7 @@ from langchain_core.runnables import Runnable
|
|
|
11
12
|
from langchain_core.runnables import RunnableConfig
|
|
12
13
|
from langchain_core.tools import BaseTool, ToolException
|
|
13
14
|
from langgraph.channels.ephemeral_value import EphemeralValue
|
|
15
|
+
from langgraph.errors import GraphRecursionError
|
|
14
16
|
from langgraph.graph import StateGraph
|
|
15
17
|
from langgraph.graph.graph import END, START
|
|
16
18
|
from langgraph.graph.state import CompiledStateGraph
|
|
@@ -18,8 +20,10 @@ from langgraph.managed.base import is_managed_value
|
|
|
18
20
|
from langgraph.prebuilt import InjectedStore
|
|
19
21
|
from langgraph.store.base import BaseStore
|
|
20
22
|
|
|
23
|
+
from .constants import PRINTER_NODE_RS, PRINTER, PRINTER_COMPLETED_STATE
|
|
21
24
|
from .mixedAgentRenderes import convert_message_to_json
|
|
22
|
-
from .utils import create_state, propagate_the_input_mapping
|
|
25
|
+
from .utils import create_state, propagate_the_input_mapping, safe_format
|
|
26
|
+
from ..utils.constants import TOOLKIT_NAME_META, TOOL_NAME_META
|
|
23
27
|
from ..tools.function import FunctionTool
|
|
24
28
|
from ..tools.indexer_tool import IndexerNode
|
|
25
29
|
from ..tools.llm import LLMNode
|
|
@@ -27,7 +31,7 @@ from ..tools.loop import LoopNode
|
|
|
27
31
|
from ..tools.loop_output import LoopToolNode
|
|
28
32
|
from ..tools.tool import ToolNode
|
|
29
33
|
from ..utils.evaluate import EvaluateTemplate
|
|
30
|
-
from ..utils.utils import clean_string
|
|
34
|
+
from ..utils.utils import clean_string
|
|
31
35
|
from ..tools.router import RouterNode
|
|
32
36
|
|
|
33
37
|
logger = logging.getLogger(__name__)
|
|
@@ -169,12 +173,13 @@ Answer only with step name, no need to add descrip in case none of the steps are
|
|
|
169
173
|
"""
|
|
170
174
|
|
|
171
175
|
def __init__(self, client, steps: str, description: str = "", decisional_inputs: Optional[list[str]] = [],
|
|
172
|
-
default_output: str = 'END'):
|
|
176
|
+
default_output: str = 'END', is_node: bool = False):
|
|
173
177
|
self.client = client
|
|
174
178
|
self.steps = ",".join([clean_string(step) for step in steps])
|
|
175
179
|
self.description = description
|
|
176
180
|
self.decisional_inputs = decisional_inputs
|
|
177
181
|
self.default_output = default_output if default_output != 'END' else END
|
|
182
|
+
self.is_node = is_node
|
|
178
183
|
|
|
179
184
|
def invoke(self, state: Annotated[BaseStore, InjectedStore()], config: Optional[RunnableConfig] = None) -> str:
|
|
180
185
|
additional_info = ""
|
|
@@ -184,10 +189,10 @@ Answer only with step name, no need to add descrip in case none of the steps are
|
|
|
184
189
|
decision_input = state.get('messages', [])[:]
|
|
185
190
|
else:
|
|
186
191
|
if len(additional_info) == 0:
|
|
187
|
-
additional_info = """###
|
|
192
|
+
additional_info = """### Additional info: """
|
|
188
193
|
additional_info += "{field}: {value}\n".format(field=field, value=state.get(field, ""))
|
|
189
194
|
decision_input.append(HumanMessage(
|
|
190
|
-
self.prompt.format(steps=self.steps, description=self.description, additional_info=additional_info)))
|
|
195
|
+
self.prompt.format(steps=self.steps, description=safe_format(self.description, state), additional_info=additional_info)))
|
|
191
196
|
completion = self.client.invoke(decision_input)
|
|
192
197
|
result = clean_string(completion.content.strip())
|
|
193
198
|
logger.info(f"Plan to transition to: {result}")
|
|
@@ -196,7 +201,8 @@ Answer only with step name, no need to add descrip in case none of the steps are
|
|
|
196
201
|
dispatch_custom_event(
|
|
197
202
|
"on_decision_edge", {"decisional_inputs": self.decisional_inputs, "state": state}, config=config
|
|
198
203
|
)
|
|
199
|
-
|
|
204
|
+
# support of legacy `decision` as part of node
|
|
205
|
+
return {"router_output": result} if self.is_node else result
|
|
200
206
|
|
|
201
207
|
|
|
202
208
|
class TransitionalEdge(Runnable):
|
|
@@ -231,6 +237,35 @@ class StateDefaultNode(Runnable):
|
|
|
231
237
|
result[key] = temp_value
|
|
232
238
|
return result
|
|
233
239
|
|
|
240
|
+
class PrinterNode(Runnable):
|
|
241
|
+
name = "PrinterNode"
|
|
242
|
+
|
|
243
|
+
def __init__(self, input_mapping: Optional[dict[str, dict]]):
|
|
244
|
+
self.input_mapping = input_mapping
|
|
245
|
+
|
|
246
|
+
def invoke(self, state: BaseStore, config: Optional[RunnableConfig] = None) -> dict:
|
|
247
|
+
logger.info(f"Printer Node - Current state variables: {state}")
|
|
248
|
+
result = {}
|
|
249
|
+
logger.debug(f"Initial text pattern: {self.input_mapping}")
|
|
250
|
+
mapping = propagate_the_input_mapping(self.input_mapping, [], state)
|
|
251
|
+
# for printer node we expect that all the lists will be joined into strings already
|
|
252
|
+
# Join any lists that haven't been converted yet
|
|
253
|
+
for key, value in mapping.items():
|
|
254
|
+
if isinstance(value, list):
|
|
255
|
+
mapping[key] = ', '.join(str(item) for item in value)
|
|
256
|
+
if mapping.get(PRINTER) is None:
|
|
257
|
+
raise ToolException(f"PrinterNode requires '{PRINTER}' field in input mapping")
|
|
258
|
+
formatted_output = mapping[PRINTER]
|
|
259
|
+
# add info label to the printer's output
|
|
260
|
+
if not formatted_output == PRINTER_COMPLETED_STATE:
|
|
261
|
+
# convert formatted output to string if it's not
|
|
262
|
+
if not isinstance(formatted_output, str):
|
|
263
|
+
formatted_output = str(formatted_output)
|
|
264
|
+
formatted_output += f"\n\n-----\n*How to proceed?*\n* *to resume the pipeline - type anything...*"
|
|
265
|
+
logger.debug(f"Formatted output: {formatted_output}")
|
|
266
|
+
result[PRINTER_NODE_RS] = formatted_output
|
|
267
|
+
return result
|
|
268
|
+
|
|
234
269
|
|
|
235
270
|
class StateModifierNode(Runnable):
|
|
236
271
|
name = "StateModifierNode"
|
|
@@ -248,19 +283,82 @@ class StateModifierNode(Runnable):
|
|
|
248
283
|
|
|
249
284
|
# Collect input variables from state
|
|
250
285
|
input_data = {}
|
|
286
|
+
|
|
251
287
|
for var in self.input_variables:
|
|
252
288
|
if var in state:
|
|
253
289
|
input_data[var] = state.get(var)
|
|
254
|
-
|
|
290
|
+
type_of_output = type(state.get(self.output_variables[0])) if self.output_variables else None
|
|
255
291
|
# Render the template using Jinja
|
|
256
|
-
|
|
257
|
-
|
|
292
|
+
import json
|
|
293
|
+
import base64
|
|
294
|
+
from jinja2 import Environment
|
|
295
|
+
|
|
296
|
+
def from_json(value):
|
|
297
|
+
"""Convert JSON string to Python object"""
|
|
298
|
+
try:
|
|
299
|
+
return json.loads(value)
|
|
300
|
+
except (json.JSONDecodeError, TypeError) as e:
|
|
301
|
+
logger.warning(f"Failed to parse JSON value: {e}")
|
|
302
|
+
return value
|
|
303
|
+
|
|
304
|
+
def base64_to_string(value):
|
|
305
|
+
"""Convert base64 encoded string to regular string"""
|
|
306
|
+
try:
|
|
307
|
+
return base64.b64decode(value).decode('utf-8')
|
|
308
|
+
except Exception as e:
|
|
309
|
+
logger.warning(f"Failed to decode base64 value: {e}")
|
|
310
|
+
return value
|
|
311
|
+
|
|
312
|
+
def split_by_words(value, chunk_size=100):
|
|
313
|
+
words = value.split()
|
|
314
|
+
return [" ".join(words[i:i + chunk_size]) for i in range(0, len(words), chunk_size)]
|
|
315
|
+
|
|
316
|
+
def split_by_regex(value, pattern):
|
|
317
|
+
"""Splits the provided string using the specified regex pattern."""
|
|
318
|
+
return re.split(pattern, value)
|
|
319
|
+
|
|
320
|
+
env = Environment()
|
|
321
|
+
env.filters['from_json'] = from_json
|
|
322
|
+
env.filters['base64_to_string'] = base64_to_string
|
|
323
|
+
env.filters['split_by_words'] = split_by_words
|
|
324
|
+
env.filters['split_by_regex'] = split_by_regex
|
|
325
|
+
|
|
326
|
+
template = env.from_string(self.template)
|
|
327
|
+
rendered_message = template.render(**input_data)
|
|
258
328
|
result = {}
|
|
259
329
|
# Store the rendered message in the state or messages
|
|
260
330
|
if len(self.output_variables) > 0:
|
|
261
331
|
# Use the first output variable to store the rendered content
|
|
262
332
|
output_var = self.output_variables[0]
|
|
263
|
-
|
|
333
|
+
|
|
334
|
+
# Convert rendered_message to the appropriate type
|
|
335
|
+
if type_of_output is not None:
|
|
336
|
+
try:
|
|
337
|
+
if type_of_output == dict:
|
|
338
|
+
result[output_var] = json.loads(rendered_message) if isinstance(rendered_message, str) else dict(rendered_message)
|
|
339
|
+
elif type_of_output == list:
|
|
340
|
+
result[output_var] = json.loads(rendered_message) if isinstance(rendered_message, str) else list(rendered_message)
|
|
341
|
+
elif type_of_output == int:
|
|
342
|
+
result[output_var] = int(rendered_message)
|
|
343
|
+
elif type_of_output == float:
|
|
344
|
+
result[output_var] = float(rendered_message)
|
|
345
|
+
elif type_of_output == str:
|
|
346
|
+
result[output_var] = str(rendered_message)
|
|
347
|
+
elif type_of_output == bool:
|
|
348
|
+
if isinstance(rendered_message, str):
|
|
349
|
+
result[output_var] = rendered_message.lower() in ('true', '1', 'yes', 'on')
|
|
350
|
+
else:
|
|
351
|
+
result[output_var] = bool(rendered_message)
|
|
352
|
+
elif type_of_output == type(None):
|
|
353
|
+
result[output_var] = None
|
|
354
|
+
else:
|
|
355
|
+
# Fallback to string if type is not recognized
|
|
356
|
+
result[output_var] = str(rendered_message)
|
|
357
|
+
except (ValueError, TypeError, json.JSONDecodeError) as e:
|
|
358
|
+
logger.warning(f"Failed to convert rendered_message to {type_of_output.__name__}: {e}. Using string fallback.")
|
|
359
|
+
result[output_var] = str(rendered_message)
|
|
360
|
+
else:
|
|
361
|
+
result[output_var] = rendered_message
|
|
264
362
|
|
|
265
363
|
# Clean up specified variables (make them empty, not delete)
|
|
266
364
|
|
|
@@ -284,8 +382,8 @@ class StateModifierNode(Runnable):
|
|
|
284
382
|
return result
|
|
285
383
|
|
|
286
384
|
|
|
287
|
-
|
|
288
|
-
|
|
385
|
+
def prepare_output_schema(lg_builder, memory, store, debug=False, interrupt_before=None, interrupt_after=None,
|
|
386
|
+
state_class=None, output_variables=None):
|
|
289
387
|
# prepare output channels
|
|
290
388
|
if interrupt_after is None:
|
|
291
389
|
interrupt_after = []
|
|
@@ -350,6 +448,50 @@ def prepare_output_schema(lg_builder, memory, store, debug=False, interrupt_befo
|
|
|
350
448
|
return compiled
|
|
351
449
|
|
|
352
450
|
|
|
451
|
+
def find_tool_by_name_or_metadata(tools: list, tool_name: str, toolkit_name: Optional[str] = None) -> Optional[BaseTool]:
|
|
452
|
+
"""
|
|
453
|
+
Find a tool by name or by matching metadata (toolkit_name + tool_name).
|
|
454
|
+
|
|
455
|
+
For toolkit nodes with toolkit_name specified, this function checks:
|
|
456
|
+
1. Metadata match first (toolkit_name + tool_name) - PRIORITY when toolkit_name is provided
|
|
457
|
+
2. Direct tool name match (backward compatibility fallback)
|
|
458
|
+
|
|
459
|
+
For toolkit nodes without toolkit_name, or other node types:
|
|
460
|
+
1. Direct tool name match
|
|
461
|
+
|
|
462
|
+
Args:
|
|
463
|
+
tools: List of available tools
|
|
464
|
+
tool_name: The tool name to search for
|
|
465
|
+
toolkit_name: Optional toolkit name for metadata matching
|
|
466
|
+
|
|
467
|
+
Returns:
|
|
468
|
+
The matching tool or None if not found
|
|
469
|
+
"""
|
|
470
|
+
# When toolkit_name is specified, prioritize metadata matching
|
|
471
|
+
if toolkit_name:
|
|
472
|
+
for tool in tools:
|
|
473
|
+
# Check metadata match first
|
|
474
|
+
if hasattr(tool, 'metadata') and tool.metadata:
|
|
475
|
+
metadata_toolkit_name = tool.metadata.get(TOOLKIT_NAME_META)
|
|
476
|
+
metadata_tool_name = tool.metadata.get(TOOL_NAME_META)
|
|
477
|
+
|
|
478
|
+
# Match if both toolkit_name and tool_name in metadata match
|
|
479
|
+
if metadata_toolkit_name == toolkit_name and metadata_tool_name == tool_name:
|
|
480
|
+
return tool
|
|
481
|
+
|
|
482
|
+
# Fallback to direct name match for backward compatibility
|
|
483
|
+
for tool in tools:
|
|
484
|
+
if tool.name == tool_name:
|
|
485
|
+
return tool
|
|
486
|
+
else:
|
|
487
|
+
# No toolkit_name specified, use direct name match only
|
|
488
|
+
for tool in tools:
|
|
489
|
+
if tool.name == tool_name:
|
|
490
|
+
return tool
|
|
491
|
+
|
|
492
|
+
return None
|
|
493
|
+
|
|
494
|
+
|
|
353
495
|
def create_graph(
|
|
354
496
|
client: Any,
|
|
355
497
|
yaml_schema: str,
|
|
@@ -385,44 +527,64 @@ def create_graph(
|
|
|
385
527
|
node_type = node.get('type', 'function')
|
|
386
528
|
node_id = clean_string(node['id'])
|
|
387
529
|
toolkit_name = node.get('toolkit_name')
|
|
388
|
-
tool_name = clean_string(node.get('tool',
|
|
389
|
-
|
|
390
|
-
tool_name = f"{clean_string(toolkit_name)}{TOOLKIT_SPLITTER}{tool_name}"
|
|
530
|
+
tool_name = clean_string(node.get('tool', ''))
|
|
531
|
+
# Tool names are now clean (no prefix needed)
|
|
391
532
|
logger.info(f"Node: {node_id} : {node_type} - {tool_name}")
|
|
392
|
-
if node_type in ['function', 'tool', 'loop', 'loop_from_tool', 'indexer', 'subgraph', 'pipeline', 'agent']:
|
|
393
|
-
|
|
394
|
-
|
|
395
|
-
|
|
533
|
+
if node_type in ['function', 'toolkit', 'mcp', 'tool', 'loop', 'loop_from_tool', 'indexer', 'subgraph', 'pipeline', 'agent']:
|
|
534
|
+
if node_type in ['mcp', 'toolkit', 'agent'] and not tool_name:
|
|
535
|
+
# tool is not specified
|
|
536
|
+
raise ToolException(f"Tool name is required for {node_type} node with id '{node_id}'")
|
|
537
|
+
|
|
538
|
+
# Unified validation and tool finding for toolkit, mcp, and agent node types
|
|
539
|
+
matching_tool = None
|
|
540
|
+
if node_type in ['toolkit', 'mcp', 'agent']:
|
|
541
|
+
# Use enhanced validation that checks both direct name and metadata
|
|
542
|
+
matching_tool = find_tool_by_name_or_metadata(tools, tool_name, toolkit_name)
|
|
543
|
+
if not matching_tool:
|
|
544
|
+
# tool is not found in the provided tools
|
|
545
|
+
error_msg = f"Node `{node_id}` with type `{node_type}` has tool '{tool_name}'"
|
|
546
|
+
if toolkit_name:
|
|
547
|
+
error_msg += f" (toolkit: '{toolkit_name}')"
|
|
548
|
+
error_msg += f" which is not found in the provided tools. Make sure it is connected properly. Available tools: {format_tools(tools)}"
|
|
549
|
+
raise ToolException(error_msg)
|
|
550
|
+
else:
|
|
551
|
+
# For other node types, find tool by direct name match
|
|
552
|
+
for tool in tools:
|
|
553
|
+
if tool.name == tool_name:
|
|
554
|
+
matching_tool = tool
|
|
555
|
+
break
|
|
556
|
+
|
|
557
|
+
if matching_tool:
|
|
558
|
+
if node_type in ['function', 'toolkit', 'mcp']:
|
|
396
559
|
lg_builder.add_node(node_id, FunctionTool(
|
|
397
|
-
tool=
|
|
560
|
+
tool=matching_tool, name=node_id, return_type='dict',
|
|
398
561
|
output_variables=node.get('output', []),
|
|
399
562
|
input_mapping=node.get('input_mapping',
|
|
400
563
|
{'messages': {'type': 'variable', 'value': 'messages'}}),
|
|
401
564
|
input_variables=node.get('input', ['messages'])))
|
|
402
565
|
elif node_type == 'agent':
|
|
403
566
|
input_params = node.get('input', ['messages'])
|
|
404
|
-
input_mapping =
|
|
405
|
-
|
|
406
|
-
|
|
407
|
-
if 'messages' in input_params:
|
|
408
|
-
input_mapping['chat_history'] = {'type': 'variable', 'value': 'messages'}
|
|
567
|
+
input_mapping = node.get('input_mapping',
|
|
568
|
+
{'messages': {'type': 'variable', 'value': 'messages'}})
|
|
569
|
+
output_vars = node.get('output', [])
|
|
409
570
|
lg_builder.add_node(node_id, FunctionTool(
|
|
410
|
-
client=client, tool=
|
|
411
|
-
name=
|
|
412
|
-
output_variables=
|
|
571
|
+
client=client, tool=matching_tool,
|
|
572
|
+
name=node_id, return_type='str',
|
|
573
|
+
output_variables=output_vars + ['messages'] if 'messages' not in output_vars else output_vars,
|
|
413
574
|
input_variables=input_params,
|
|
414
575
|
input_mapping= input_mapping
|
|
415
576
|
))
|
|
416
577
|
elif node_type == 'subgraph' or node_type == 'pipeline':
|
|
417
578
|
# assign parent memory/store
|
|
418
|
-
#
|
|
419
|
-
#
|
|
579
|
+
# matching_tool.checkpointer = memory
|
|
580
|
+
# matching_tool.store = store
|
|
420
581
|
# wrap with mappings
|
|
421
582
|
pipeline_name = node.get('tool', None)
|
|
422
583
|
if not pipeline_name:
|
|
423
|
-
raise ValueError(
|
|
584
|
+
raise ValueError(
|
|
585
|
+
"Subgraph must have a 'tool' node: add required tool to the subgraph node")
|
|
424
586
|
node_fn = SubgraphRunnable(
|
|
425
|
-
inner=
|
|
587
|
+
inner=matching_tool.graph,
|
|
426
588
|
name=pipeline_name,
|
|
427
589
|
input_mapping=node.get('input_mapping', {}),
|
|
428
590
|
output_mapping=node.get('output_mapping', {}),
|
|
@@ -431,26 +593,17 @@ def create_graph(
|
|
|
431
593
|
break # skip legacy handling
|
|
432
594
|
elif node_type == 'tool':
|
|
433
595
|
lg_builder.add_node(node_id, ToolNode(
|
|
434
|
-
client=client, tool=
|
|
435
|
-
name=
|
|
596
|
+
client=client, tool=matching_tool,
|
|
597
|
+
name=node_id, return_type='dict',
|
|
436
598
|
output_variables=node.get('output', []),
|
|
437
599
|
input_variables=node.get('input', ['messages']),
|
|
438
600
|
structured_output=node.get('structured_output', False),
|
|
439
601
|
task=node.get('task')
|
|
440
602
|
))
|
|
441
|
-
# TODO: decide on struct output for agent nodes
|
|
442
|
-
# elif node_type == 'agent':
|
|
443
|
-
# lg_builder.add_node(node_id, AgentNode(
|
|
444
|
-
# client=client, tool=tool,
|
|
445
|
-
# name=node['id'], return_type='dict',
|
|
446
|
-
# output_variables=node.get('output', []),
|
|
447
|
-
# input_variables=node.get('input', ['messages']),
|
|
448
|
-
# task=node.get('task')
|
|
449
|
-
# ))
|
|
450
603
|
elif node_type == 'loop':
|
|
451
604
|
lg_builder.add_node(node_id, LoopNode(
|
|
452
|
-
client=client, tool=
|
|
453
|
-
name=
|
|
605
|
+
client=client, tool=matching_tool,
|
|
606
|
+
name=node_id, return_type='dict',
|
|
454
607
|
output_variables=node.get('output', []),
|
|
455
608
|
input_variables=node.get('input', ['messages']),
|
|
456
609
|
task=node.get('task', '')
|
|
@@ -459,14 +612,15 @@ def create_graph(
|
|
|
459
612
|
loop_toolkit_name = node.get('loop_toolkit_name')
|
|
460
613
|
loop_tool_name = node.get('loop_tool')
|
|
461
614
|
if (loop_toolkit_name and loop_tool_name) or loop_tool_name:
|
|
462
|
-
|
|
615
|
+
# Use clean tool name (no prefix)
|
|
616
|
+
loop_tool_name = clean_string(loop_tool_name)
|
|
463
617
|
for t in tools:
|
|
464
618
|
if t.name == loop_tool_name:
|
|
465
619
|
logger.debug(f"Loop tool discovered: {t}")
|
|
466
620
|
lg_builder.add_node(node_id, LoopToolNode(
|
|
467
621
|
client=client,
|
|
468
|
-
name=
|
|
469
|
-
tool=
|
|
622
|
+
name=node_id, return_type='dict',
|
|
623
|
+
tool=matching_tool, loop_tool=t,
|
|
470
624
|
variables_mapping=node.get('variables_mapping', {}),
|
|
471
625
|
output_variables=node.get('output', []),
|
|
472
626
|
input_variables=node.get('input', ['messages']),
|
|
@@ -482,16 +636,28 @@ def create_graph(
|
|
|
482
636
|
indexer_tool = t
|
|
483
637
|
logger.info(f"Indexer tool: {indexer_tool}")
|
|
484
638
|
lg_builder.add_node(node_id, IndexerNode(
|
|
485
|
-
client=client, tool=
|
|
639
|
+
client=client, tool=matching_tool,
|
|
486
640
|
index_tool=indexer_tool,
|
|
487
641
|
input_mapping=node.get('input_mapping', {}),
|
|
488
|
-
name=
|
|
642
|
+
name=node_id, return_type='dict',
|
|
489
643
|
chunking_tool=node.get('chunking_tool', None),
|
|
490
644
|
chunking_config=node.get('chunking_config', {}),
|
|
491
645
|
output_variables=node.get('output', []),
|
|
492
646
|
input_variables=node.get('input', ['messages']),
|
|
493
647
|
structured_output=node.get('structured_output', False)))
|
|
494
|
-
|
|
648
|
+
elif node_type == 'code':
|
|
649
|
+
from ..tools.sandbox import create_sandbox_tool
|
|
650
|
+
sandbox_tool = create_sandbox_tool(stateful=False, allow_net=True,
|
|
651
|
+
alita_client=kwargs.get('alita_client', None))
|
|
652
|
+
code_data = node.get('code', {'type': 'fixed', 'value': "return 'Code block is empty'"})
|
|
653
|
+
lg_builder.add_node(node_id, FunctionTool(
|
|
654
|
+
tool=sandbox_tool, name=node['id'], return_type='dict',
|
|
655
|
+
output_variables=node.get('output', []),
|
|
656
|
+
input_mapping={'code': code_data},
|
|
657
|
+
input_variables=node.get('input', ['messages']),
|
|
658
|
+
structured_output=node.get('structured_output', False),
|
|
659
|
+
alita_client=kwargs.get('alita_client', None)
|
|
660
|
+
))
|
|
495
661
|
elif node_type == 'llm':
|
|
496
662
|
output_vars = node.get('output', [])
|
|
497
663
|
output_vars_dict = {
|
|
@@ -504,10 +670,10 @@ def create_graph(
|
|
|
504
670
|
tool_names = []
|
|
505
671
|
if isinstance(connected_tools, dict):
|
|
506
672
|
for toolkit, selected_tools in connected_tools.items():
|
|
507
|
-
|
|
508
|
-
|
|
673
|
+
# Add tool names directly (no prefix)
|
|
674
|
+
tool_names.extend(selected_tools)
|
|
509
675
|
elif isinstance(connected_tools, list):
|
|
510
|
-
#
|
|
676
|
+
# Use provided tool names as-is
|
|
511
677
|
tool_names = connected_tools
|
|
512
678
|
|
|
513
679
|
if tool_names:
|
|
@@ -520,28 +686,44 @@ def create_graph(
|
|
|
520
686
|
else:
|
|
521
687
|
# Use all available tools
|
|
522
688
|
available_tools = [tool for tool in tools if isinstance(tool, BaseTool)]
|
|
523
|
-
|
|
689
|
+
|
|
524
690
|
lg_builder.add_node(node_id, LLMNode(
|
|
525
|
-
client=client,
|
|
526
|
-
|
|
527
|
-
name=
|
|
691
|
+
client=client,
|
|
692
|
+
input_mapping=node.get('input_mapping', {'messages': {'type': 'variable', 'value': 'messages'}}),
|
|
693
|
+
name=node_id,
|
|
528
694
|
return_type='dict',
|
|
529
|
-
response_key=node.get('response_key', 'messages'),
|
|
530
695
|
structured_output_dict=output_vars_dict,
|
|
531
696
|
output_variables=output_vars,
|
|
532
697
|
input_variables=node.get('input', ['messages']),
|
|
533
698
|
structured_output=node.get('structured_output', False),
|
|
699
|
+
tool_execution_timeout=node.get('tool_execution_timeout', 900),
|
|
534
700
|
available_tools=available_tools,
|
|
535
|
-
tool_names=tool_names
|
|
536
|
-
|
|
537
|
-
# Add a RouterNode as an independent node
|
|
538
|
-
lg_builder.add_node(node_id, RouterNode(
|
|
539
|
-
name=node['id'],
|
|
540
|
-
condition=node.get('condition', ''),
|
|
541
|
-
routes=node.get('routes', []),
|
|
542
|
-
default_output=node.get('default_output', 'END'),
|
|
543
|
-
input_variables=node.get('input', ['messages'])
|
|
701
|
+
tool_names=tool_names,
|
|
702
|
+
steps_limit=kwargs.get('steps_limit', 25)
|
|
544
703
|
))
|
|
704
|
+
elif node_type in ['router', 'decision']:
|
|
705
|
+
if node_type == 'router':
|
|
706
|
+
# Add a RouterNode as an independent node
|
|
707
|
+
lg_builder.add_node(node_id, RouterNode(
|
|
708
|
+
name=node_id,
|
|
709
|
+
condition=node.get('condition', ''),
|
|
710
|
+
routes=node.get('routes', []),
|
|
711
|
+
default_output=node.get('default_output', 'END'),
|
|
712
|
+
input_variables=node.get('input', ['messages'])
|
|
713
|
+
))
|
|
714
|
+
elif node_type == 'decision':
|
|
715
|
+
logger.info(f'Adding decision: {node["nodes"]}')
|
|
716
|
+
# fallback to old-style decision node
|
|
717
|
+
decisional_inputs = node.get('decisional_inputs')
|
|
718
|
+
decisional_inputs = node.get('input', ['messages']) if not decisional_inputs else decisional_inputs
|
|
719
|
+
lg_builder.add_node(node_id, DecisionEdge(
|
|
720
|
+
client, node['nodes'],
|
|
721
|
+
node.get('description', ""),
|
|
722
|
+
decisional_inputs=decisional_inputs,
|
|
723
|
+
default_output=node.get('default_output', 'END'),
|
|
724
|
+
is_node=True
|
|
725
|
+
))
|
|
726
|
+
|
|
545
727
|
# Add a single conditional edge for all routes
|
|
546
728
|
lg_builder.add_conditional_edges(
|
|
547
729
|
node_id,
|
|
@@ -552,6 +734,7 @@ def create_graph(
|
|
|
552
734
|
default_output=node.get('default_output', 'END')
|
|
553
735
|
)
|
|
554
736
|
)
|
|
737
|
+
continue
|
|
555
738
|
elif node_type == 'state_modifier':
|
|
556
739
|
lg_builder.add_node(node_id, StateModifierNode(
|
|
557
740
|
template=node.get('template', ''),
|
|
@@ -559,6 +742,22 @@ def create_graph(
|
|
|
559
742
|
input_variables=node.get('input', ['messages']),
|
|
560
743
|
output_variables=node.get('output', [])
|
|
561
744
|
))
|
|
745
|
+
elif node_type == 'printer':
|
|
746
|
+
lg_builder.add_node(node_id, PrinterNode(
|
|
747
|
+
input_mapping=node.get('input_mapping', {'printer': {'type': 'fixed', 'value': ''}}),
|
|
748
|
+
))
|
|
749
|
+
|
|
750
|
+
# add interrupts after printer node if specified
|
|
751
|
+
interrupt_after.append(clean_string(node_id))
|
|
752
|
+
|
|
753
|
+
# reset printer output variable to avoid carrying over
|
|
754
|
+
reset_node_id = f"{node_id}_reset"
|
|
755
|
+
lg_builder.add_node(reset_node_id, PrinterNode(
|
|
756
|
+
input_mapping={'printer': {'type': 'fixed', 'value': PRINTER_COMPLETED_STATE}}
|
|
757
|
+
))
|
|
758
|
+
lg_builder.add_conditional_edges(node_id, TransitionalEdge(reset_node_id))
|
|
759
|
+
lg_builder.add_conditional_edges(reset_node_id, TransitionalEdge(clean_string(node['transition'])))
|
|
760
|
+
continue
|
|
562
761
|
if node.get('transition'):
|
|
563
762
|
next_step = clean_string(node['transition'])
|
|
564
763
|
logger.info(f'Adding transition: {next_step}')
|
|
@@ -584,14 +783,11 @@ def create_graph(
|
|
|
584
783
|
entry_point = clean_string(schema['entry_point'])
|
|
585
784
|
except KeyError:
|
|
586
785
|
raise ToolException("Entry point is not defined in the schema. Please define 'entry_point' in the schema.")
|
|
587
|
-
|
|
588
|
-
|
|
589
|
-
|
|
590
|
-
|
|
591
|
-
|
|
592
|
-
lg_builder.set_entry_point(state_default_node.name)
|
|
593
|
-
lg_builder.add_conditional_edges(state_default_node.name, TransitionalEdge(entry_point))
|
|
594
|
-
break
|
|
786
|
+
if state.items():
|
|
787
|
+
state_default_node = StateDefaultNode(default_vars=set_defaults(state))
|
|
788
|
+
lg_builder.add_node(state_default_node.name, state_default_node)
|
|
789
|
+
lg_builder.set_entry_point(state_default_node.name)
|
|
790
|
+
lg_builder.add_conditional_edges(state_default_node.name, TransitionalEdge(entry_point))
|
|
595
791
|
else:
|
|
596
792
|
# if no state variables are defined, set the entry point directly
|
|
597
793
|
lg_builder.set_entry_point(entry_point)
|
|
@@ -619,7 +815,7 @@ def create_graph(
|
|
|
619
815
|
)
|
|
620
816
|
except ValueError as e:
|
|
621
817
|
raise ValueError(
|
|
622
|
-
f"Validation of the schema failed. {e}\n\nDEBUG INFO:**Schema Nodes:**\n\n{lg_builder.nodes}\n\n**Schema
|
|
818
|
+
f"Validation of the schema failed. {e}\n\nDEBUG INFO:**Schema Nodes:**\n\n*{'\n*'.join(lg_builder.nodes.keys())}\n\n**Schema Edges:**\n\n{lg_builder.edges}\n\n**Tools Available:**\n\n{format_tools(tools)}")
|
|
623
819
|
# If building a nested subgraph, return the raw CompiledStateGraph
|
|
624
820
|
if for_subgraph:
|
|
625
821
|
return graph
|
|
@@ -633,6 +829,46 @@ def create_graph(
|
|
|
633
829
|
)
|
|
634
830
|
return compiled.validate()
|
|
635
831
|
|
|
832
|
+
def format_tools(tools_list: list) -> str:
    """Return the names of the tools in ``tools_list`` as a comma-separated string.

    Used to build debug/error messages. Each item contributes its ``name``
    attribute (converted with ``str``); an item without a ``name`` contributes
    ``str(item)`` instead, so one unusual entry does not hide the names of all
    the others.

    Args:
        tools_list: Iterable of tool objects, normally exposing ``.name``.

    Returns:
        The comma-separated names. If the input cannot be iterated at all,
        a warning is logged and ``str(tools_list)`` is returned.
    """
    try:
        return ', '.join(str(getattr(tool, 'name', tool)) for tool in tools_list)
    except Exception as e:
        # Best-effort formatting for diagnostics: never let it mask the real error.
        logger.warning(f"Failed to format tools list: {e}")
        return str(tools_list)
|
|
839
|
+
|
|
840
|
+
def set_defaults(d):
    """Fill in missing default values for state variable definitions, in place.

    ``d`` maps state variable names to definitions. A dict definition without
    a ``'value'`` key (e.g. ``{'type': 'list'}``) gets a type-appropriate
    default:
    ``''`` for str, ``[]`` for list, ``{}`` for dict, ``0`` for int,
    ``0.0`` for float, ``False`` for bool, and ``None`` for any other or
    missing type. Each variable receives its own fresh container, so two
    ``list``/``dict`` variables never share one default object.

    A non-dict definition is treated as a bare type name and left unchanged.
    The ``'input'`` entry is skipped because it is not an initial state
    variable.

    Finally, a ``'state_types'`` entry is written (overwriting any existing
    one). It maps each variable whose type is one of str/int/float/bool/list/
    dict/number to the type *name* as a string, with ``'number'`` reported as
    ``'int'``. Strings are used instead of type objects to avoid
    serialization issues.

    Args:
        d: State definition mapping; mutated in place.

    Returns:
        The same ``d`` object.
    """
    # Factories rather than shared instances, so mutable defaults are never
    # aliased between variables.
    type_defaults = {
        'str': str,
        'list': list,
        'dict': dict,
        'int': int,
        'float': float,
        'bool': bool,
        # add more types as needed
    }
    # Tuple (equality-based membership) so unhashable type values cannot raise.
    known_types = ('str', 'int', 'float', 'bool', 'list', 'dict', 'number')
    state_types = {}

    for k, v in d.items():
        if k == 'input':
            continue
        if isinstance(v, dict):
            var_type = v.get('type')
            if 'value' not in v:
                factory = type_defaults.get(var_type)
                v['value'] = factory() if factory is not None else None
        else:
            # Bare type-name definition: nothing to default, only record its type.
            var_type = v
        if var_type in known_types:
            state_types[k] = 'int' if var_type == 'number' else var_type

    d['state_types'] = {'type': 'dict', 'value': state_types}
    return d
|
|
636
872
|
|
|
637
873
|
def convert_dict_to_message(msg_dict):
|
|
638
874
|
"""Convert a dictionary message to a LangChain message object."""
|
|
@@ -665,56 +901,208 @@ class LangGraphAgentRunnable(CompiledStateGraph):
|
|
|
665
901
|
def invoke(self, input: Union[dict[str, Any], Any],
           config: Optional[RunnableConfig] = None,
           *args, **kwargs):
    """Run (or resume) the graph for one turn and return its output plus state.

    Input normalisation, in order:

    * A missing ``config`` is created. A missing ``thread_id`` is generated,
      while any other ``configurable`` entries are kept.
    * ``chat_history`` is only considered when ``messages`` is empty. It is
      dropped if a checkpoint already exists for the thread; otherwise it is
      converted into ``messages`` with ``convert_dict_to_message``.
    * When ``input['input']`` is empty, the last ``HumanMessage`` in
      ``messages`` is used as the input. Its text becomes ``input['input']``.
      The message is removed from ``messages`` unless it also carries
      non-text parts, which stay there.
    * Non-text parts of a ``HumanMessage`` passed directly as ``input['input']``
      are appended to ``messages`` as a separate ``HumanMessage``.

    If a checkpoint exists, the run is resumed in one of two ways. When
    ``config['should_continue']`` is truthy, ``input`` is passed directly
    (the flag is popped from ``config``). Otherwise ``input`` is first written
    into the state via ``update_state`` and the graph is invoked with ``None``.

    Returns:
        dict with ``output``, ``thread_id`` (``None`` once execution has
        finished), ``execution_finished`` and every state value except
        ``output``. On ``GraphRecursionError`` the result of
        ``_handle_graph_recursion_error`` is returned instead.

    Raises:
        RuntimeError: if no non-empty text input remains after normalisation.
            This includes the case where ``messages`` holds no ``HumanMessage``;
            previously that case sent the literal string "None".
    """
    logger.info(f"Incoming Input: {input}")
    if config is None:
        config = RunnableConfig()
    configurable = config.get("configurable") or {}
    if not configurable.get("thread_id", ""):
        # Add a fresh thread id without discarding other configurable entries.
        config["configurable"] = {**configurable, "thread_id": str(uuid4())}
    thread_id = config.get("configurable", {}).get("thread_id")

    # Look the checkpoint up once. Nothing below writes to the checkpointer
    # before the run starts, so the answer is reused for both decisions.
    checkpoint_exists = bool(self.checkpointer and self.checkpointer.get_tuple(config))

    # Handle chat history and current input properly
    if input.get('chat_history') and not input.get('messages'):
        if checkpoint_exists:
            # Checkpoint already has conversation history - discard redundant chat_history
            input.pop('chat_history', None)
        else:
            # No checkpoint - convert chat history dict messages to LangChain message objects
            chat_history = input.pop('chat_history')
            input['messages'] = [convert_dict_to_message(msg) for msg in chat_history]

    # handler for LLM node: if no input (Chat perspective), then take last human message.
    # Track if input came from messages to handle content extraction properly.
    input_from_messages = False
    if not input.get('input') and input.get('messages'):
        last_human = next(
            (msg for msg in reversed(input['messages']) if isinstance(msg, HumanMessage)),
            None,
        )
        # Only adopt a message that actually exists. Wrapping a missing one as
        # [None] would be truthy and later be stringified to "None".
        if last_human is not None:
            input['input'] = [last_human]
            input_from_messages = True

    if input.get('input'):
        raw_input = input['input']
        if isinstance(raw_input, str):
            current_message = raw_input
        elif isinstance(raw_input, (list, tuple)):
            # input can be a list of messages: use the most recent one
            current_message = raw_input[-1]
        else:
            # ...or a single message object
            current_message = raw_input

        # TODO: add handler after we add 2+ inputs (filterByType, etc.)
        if isinstance(current_message, HumanMessage):
            current_content = current_message.content
            if isinstance(current_content, list):
                # Extract text parts and keep non-text parts (images, etc.)
                text_contents = []
                non_text_parts = []
                for item in current_content:
                    if isinstance(item, dict) and item.get('type') == 'text':
                        text_contents.append(item['text'])
                    elif isinstance(item, str):
                        text_contents.append(item)
                    else:
                        # Keep image_url and other non-text content
                        non_text_parts.append(item)

                input['input'] = ". ".join(text_contents) if text_contents else ""

                if input_from_messages:
                    if non_text_parts:
                        # Keep the message but only with non-text content (images, etc.)
                        current_message.content = non_text_parts
                    else:
                        # All content was text, remove this message from the list
                        input['messages'] = [msg for msg in input['messages'] if msg is not current_message]
                elif non_text_parts:
                    # Message came from input['input']: preserve non-text parts in messages
                    if not input.get('messages'):
                        input['messages'] = []
                    input['messages'].append(HumanMessage(content=non_text_parts))
            elif isinstance(current_content, str):
                # on regenerate case
                input['input'] = current_content
                if input_from_messages:
                    input['messages'] = [msg for msg in input['messages'] if msg is not current_message]
            else:
                input['input'] = str(current_content)
                if input_from_messages:
                    input['messages'] = [msg for msg in input['messages'] if msg is not current_message]
        elif isinstance(current_message, str):
            input['input'] = current_message
        else:
            input['input'] = str(current_message)

        if input.get('messages'):
            # Ensure existing messages are LangChain objects. The current message
            # is deliberately not appended: it is already passed via input['input'].
            input['messages'] = [convert_dict_to_message(msg) for msg in input['messages']]

    # Validate that input is not empty after all processing
    if not input.get('input'):
        raise RuntimeError(
            "Empty input after processing. Cannot send empty string to LLM. "
            "This likely means the message contained only non-text content "
            "with no accompanying text."
        )

    logger.info(f"Input: {thread_id} - {input}")
    try:
        if checkpoint_exists:
            if config.pop("should_continue", False):
                invoke_input = input
            else:
                self.update_state(config, input)
                invoke_input = None
            result = super().invoke(invoke_input, config=config, *args, **kwargs)
        else:
            result = super().invoke(input, config=config, *args, **kwargs)
    except GraphRecursionError as e:
        current_recursion_limit = config.get("recursion_limit", 0)
        logger.warning("ToolExecutionLimitReached caught in LangGraphAgentRunnable: %s", e)
        return self._handle_graph_recursion_error(
            config=config,
            thread_id=thread_id,
            current_recursion_limit=current_recursion_limit,
        )

    # Bound up front so the "output is None" message below can always read it.
    messages = []
    try:
        # Check if printer node output exists
        printer_output = result.get(PRINTER_NODE_RS)
        if printer_output == PRINTER_COMPLETED_STATE:
            # Printer completed, extract last AI message
            messages = result['messages']
            output = next(
                (msg.content for msg in reversed(messages)
                 if not isinstance(msg, HumanMessage)),
                messages[-1].content
            ) if messages else result.get('output')
        elif printer_output is not None:
            # Printer node has output (interrupted state)
            output = printer_output
        else:
            # No printer node, extract last AI message from messages
            messages = result.get('messages', [])
            output = next(
                (msg.content for msg in reversed(messages)
                 if not isinstance(msg, HumanMessage)),
                None
            )
    except Exception:
        # Fallback: try to get last value or last message
        output = str(list(result.values())[-1]) if result else 'Output is undefined'

    config_state = self.get_state(config)
    is_execution_finished = not config_state.next
    if is_execution_finished:
        # A finished run has no thread left to continue.
        thread_id = None

    final_output = f"Assistant run has been completed, but output is None.\nAdding last message if any: {messages[-1] if messages else []}" if is_execution_finished and output is None else output

    result_with_state = {
        "output": final_output,
        "thread_id": thread_id,
        "execution_finished": is_execution_finished
    }

    # Include all state values in the result, except 'output' which is already set
    if hasattr(config_state, 'values') and config_state.values:
        for key, value in config_state.values.items():
            if key != 'output':
                result_with_state[key] = value

    return result_with_state
|
|
1079
|
+
|
|
1080
|
+
def _handle_graph_recursion_error(
    self,
    config: RunnableConfig,
    thread_id: str,
    current_recursion_limit: int,
) -> dict:
    """Handle GraphRecursionError by returning a soft-boundary response."""
    # Instead of propagating the error, tell the caller the run can be
    # continued. The response flags the limit hit, marks execution as
    # unfinished, and merges in the current graph state (except 'output').
    snapshot = self.get_state(config)

    response: dict[str, Any] = dict(
        output=(
            f"Tool step limit {current_recursion_limit} reached for this run. You can continue by sending another "
            "message or refining your request."
        ),
        thread_id=thread_id,
        execution_finished=False,
        tool_execution_limit_reached=True,
    )

    state_values = getattr(snapshot, "values", None)
    if state_values:
        response.update(
            {key: value for key, value in state_values.items() if key != "output"}
        )

    return response
|
|
720
1108
|
|