hexdag 0.5.0.dev1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hexdag/__init__.py +116 -0
- hexdag/__main__.py +30 -0
- hexdag/adapters/executors/__init__.py +5 -0
- hexdag/adapters/executors/local_executor.py +316 -0
- hexdag/builtin/__init__.py +6 -0
- hexdag/builtin/adapters/__init__.py +51 -0
- hexdag/builtin/adapters/anthropic/__init__.py +5 -0
- hexdag/builtin/adapters/anthropic/anthropic_adapter.py +151 -0
- hexdag/builtin/adapters/database/__init__.py +6 -0
- hexdag/builtin/adapters/database/csv/csv_adapter.py +249 -0
- hexdag/builtin/adapters/database/pgvector/__init__.py +5 -0
- hexdag/builtin/adapters/database/pgvector/pgvector_adapter.py +478 -0
- hexdag/builtin/adapters/database/sqlalchemy/sqlalchemy_adapter.py +252 -0
- hexdag/builtin/adapters/database/sqlite/__init__.py +5 -0
- hexdag/builtin/adapters/database/sqlite/sqlite_adapter.py +410 -0
- hexdag/builtin/adapters/local/README.md +59 -0
- hexdag/builtin/adapters/local/__init__.py +7 -0
- hexdag/builtin/adapters/local/local_observer_manager.py +696 -0
- hexdag/builtin/adapters/memory/__init__.py +47 -0
- hexdag/builtin/adapters/memory/file_memory_adapter.py +297 -0
- hexdag/builtin/adapters/memory/in_memory_memory.py +216 -0
- hexdag/builtin/adapters/memory/schemas.py +57 -0
- hexdag/builtin/adapters/memory/session_memory.py +178 -0
- hexdag/builtin/adapters/memory/sqlite_memory_adapter.py +215 -0
- hexdag/builtin/adapters/memory/state_memory.py +280 -0
- hexdag/builtin/adapters/mock/README.md +89 -0
- hexdag/builtin/adapters/mock/__init__.py +15 -0
- hexdag/builtin/adapters/mock/hexdag.toml +50 -0
- hexdag/builtin/adapters/mock/mock_database.py +225 -0
- hexdag/builtin/adapters/mock/mock_embedding.py +223 -0
- hexdag/builtin/adapters/mock/mock_llm.py +177 -0
- hexdag/builtin/adapters/mock/mock_tool_adapter.py +192 -0
- hexdag/builtin/adapters/mock/mock_tool_router.py +232 -0
- hexdag/builtin/adapters/openai/__init__.py +5 -0
- hexdag/builtin/adapters/openai/openai_adapter.py +634 -0
- hexdag/builtin/adapters/secret/__init__.py +7 -0
- hexdag/builtin/adapters/secret/local_secret_adapter.py +248 -0
- hexdag/builtin/adapters/unified_tool_router.py +280 -0
- hexdag/builtin/macros/__init__.py +17 -0
- hexdag/builtin/macros/conversation_agent.py +390 -0
- hexdag/builtin/macros/llm_macro.py +151 -0
- hexdag/builtin/macros/reasoning_agent.py +423 -0
- hexdag/builtin/macros/tool_macro.py +380 -0
- hexdag/builtin/nodes/__init__.py +38 -0
- hexdag/builtin/nodes/_discovery.py +123 -0
- hexdag/builtin/nodes/agent_node.py +696 -0
- hexdag/builtin/nodes/base_node_factory.py +242 -0
- hexdag/builtin/nodes/composite_node.py +926 -0
- hexdag/builtin/nodes/data_node.py +201 -0
- hexdag/builtin/nodes/expression_node.py +487 -0
- hexdag/builtin/nodes/function_node.py +454 -0
- hexdag/builtin/nodes/llm_node.py +491 -0
- hexdag/builtin/nodes/loop_node.py +920 -0
- hexdag/builtin/nodes/mapped_input.py +518 -0
- hexdag/builtin/nodes/port_call_node.py +269 -0
- hexdag/builtin/nodes/tool_call_node.py +195 -0
- hexdag/builtin/nodes/tool_utils.py +390 -0
- hexdag/builtin/prompts/__init__.py +68 -0
- hexdag/builtin/prompts/base.py +422 -0
- hexdag/builtin/prompts/chat_prompts.py +303 -0
- hexdag/builtin/prompts/error_correction_prompts.py +320 -0
- hexdag/builtin/prompts/tool_prompts.py +160 -0
- hexdag/builtin/tools/builtin_tools.py +84 -0
- hexdag/builtin/tools/database_tools.py +164 -0
- hexdag/cli/__init__.py +17 -0
- hexdag/cli/__main__.py +7 -0
- hexdag/cli/commands/__init__.py +27 -0
- hexdag/cli/commands/build_cmd.py +812 -0
- hexdag/cli/commands/create_cmd.py +208 -0
- hexdag/cli/commands/docs_cmd.py +293 -0
- hexdag/cli/commands/generate_types_cmd.py +252 -0
- hexdag/cli/commands/init_cmd.py +188 -0
- hexdag/cli/commands/pipeline_cmd.py +494 -0
- hexdag/cli/commands/plugin_dev_cmd.py +529 -0
- hexdag/cli/commands/plugins_cmd.py +441 -0
- hexdag/cli/commands/studio_cmd.py +101 -0
- hexdag/cli/commands/validate_cmd.py +221 -0
- hexdag/cli/main.py +84 -0
- hexdag/core/__init__.py +83 -0
- hexdag/core/config/__init__.py +20 -0
- hexdag/core/config/loader.py +479 -0
- hexdag/core/config/models.py +150 -0
- hexdag/core/configurable.py +294 -0
- hexdag/core/context/__init__.py +37 -0
- hexdag/core/context/execution_context.py +378 -0
- hexdag/core/docs/__init__.py +26 -0
- hexdag/core/docs/extractors.py +678 -0
- hexdag/core/docs/generators.py +890 -0
- hexdag/core/docs/models.py +120 -0
- hexdag/core/domain/__init__.py +10 -0
- hexdag/core/domain/dag.py +1225 -0
- hexdag/core/exceptions.py +234 -0
- hexdag/core/expression_parser.py +569 -0
- hexdag/core/logging.py +449 -0
- hexdag/core/models/__init__.py +17 -0
- hexdag/core/models/base.py +138 -0
- hexdag/core/orchestration/__init__.py +46 -0
- hexdag/core/orchestration/body_executor.py +481 -0
- hexdag/core/orchestration/components/__init__.py +97 -0
- hexdag/core/orchestration/components/adapter_lifecycle_manager.py +113 -0
- hexdag/core/orchestration/components/checkpoint_manager.py +134 -0
- hexdag/core/orchestration/components/execution_coordinator.py +360 -0
- hexdag/core/orchestration/components/health_check_manager.py +176 -0
- hexdag/core/orchestration/components/input_mapper.py +143 -0
- hexdag/core/orchestration/components/lifecycle_manager.py +583 -0
- hexdag/core/orchestration/components/node_executor.py +377 -0
- hexdag/core/orchestration/components/secret_manager.py +202 -0
- hexdag/core/orchestration/components/wave_executor.py +158 -0
- hexdag/core/orchestration/constants.py +17 -0
- hexdag/core/orchestration/events/README.md +312 -0
- hexdag/core/orchestration/events/__init__.py +104 -0
- hexdag/core/orchestration/events/batching.py +330 -0
- hexdag/core/orchestration/events/decorators.py +139 -0
- hexdag/core/orchestration/events/events.py +573 -0
- hexdag/core/orchestration/events/observers/__init__.py +30 -0
- hexdag/core/orchestration/events/observers/core_observers.py +690 -0
- hexdag/core/orchestration/events/observers/models.py +111 -0
- hexdag/core/orchestration/events/taxonomy.py +269 -0
- hexdag/core/orchestration/hook_context.py +237 -0
- hexdag/core/orchestration/hooks.py +437 -0
- hexdag/core/orchestration/models.py +418 -0
- hexdag/core/orchestration/orchestrator.py +910 -0
- hexdag/core/orchestration/orchestrator_factory.py +275 -0
- hexdag/core/orchestration/port_wrappers.py +327 -0
- hexdag/core/orchestration/prompt/__init__.py +32 -0
- hexdag/core/orchestration/prompt/template.py +332 -0
- hexdag/core/pipeline_builder/__init__.py +21 -0
- hexdag/core/pipeline_builder/component_instantiator.py +386 -0
- hexdag/core/pipeline_builder/include_tag.py +265 -0
- hexdag/core/pipeline_builder/pipeline_config.py +133 -0
- hexdag/core/pipeline_builder/py_tag.py +223 -0
- hexdag/core/pipeline_builder/tag_discovery.py +268 -0
- hexdag/core/pipeline_builder/yaml_builder.py +1196 -0
- hexdag/core/pipeline_builder/yaml_validator.py +569 -0
- hexdag/core/ports/__init__.py +65 -0
- hexdag/core/ports/api_call.py +133 -0
- hexdag/core/ports/database.py +489 -0
- hexdag/core/ports/embedding.py +215 -0
- hexdag/core/ports/executor.py +237 -0
- hexdag/core/ports/file_storage.py +117 -0
- hexdag/core/ports/healthcheck.py +87 -0
- hexdag/core/ports/llm.py +551 -0
- hexdag/core/ports/memory.py +70 -0
- hexdag/core/ports/observer_manager.py +130 -0
- hexdag/core/ports/secret.py +145 -0
- hexdag/core/ports/tool_router.py +94 -0
- hexdag/core/ports_builder.py +623 -0
- hexdag/core/protocols.py +273 -0
- hexdag/core/resolver.py +304 -0
- hexdag/core/schema/__init__.py +9 -0
- hexdag/core/schema/generator.py +742 -0
- hexdag/core/secrets.py +242 -0
- hexdag/core/types.py +413 -0
- hexdag/core/utils/async_warnings.py +206 -0
- hexdag/core/utils/schema_conversion.py +78 -0
- hexdag/core/utils/sql_validation.py +86 -0
- hexdag/core/validation/secure_json.py +148 -0
- hexdag/core/yaml_macro.py +517 -0
- hexdag/mcp_server.py +3120 -0
- hexdag/studio/__init__.py +10 -0
- hexdag/studio/build_ui.py +92 -0
- hexdag/studio/server/__init__.py +1 -0
- hexdag/studio/server/main.py +100 -0
- hexdag/studio/server/routes/__init__.py +9 -0
- hexdag/studio/server/routes/execute.py +208 -0
- hexdag/studio/server/routes/export.py +558 -0
- hexdag/studio/server/routes/files.py +207 -0
- hexdag/studio/server/routes/plugins.py +419 -0
- hexdag/studio/server/routes/validate.py +220 -0
- hexdag/studio/ui/index.html +13 -0
- hexdag/studio/ui/package-lock.json +2992 -0
- hexdag/studio/ui/package.json +31 -0
- hexdag/studio/ui/postcss.config.js +6 -0
- hexdag/studio/ui/public/hexdag.svg +5 -0
- hexdag/studio/ui/src/App.tsx +251 -0
- hexdag/studio/ui/src/components/Canvas.tsx +408 -0
- hexdag/studio/ui/src/components/ContextMenu.tsx +187 -0
- hexdag/studio/ui/src/components/FileBrowser.tsx +123 -0
- hexdag/studio/ui/src/components/Header.tsx +181 -0
- hexdag/studio/ui/src/components/HexdagNode.tsx +193 -0
- hexdag/studio/ui/src/components/NodeInspector.tsx +512 -0
- hexdag/studio/ui/src/components/NodePalette.tsx +262 -0
- hexdag/studio/ui/src/components/NodePortsSection.tsx +403 -0
- hexdag/studio/ui/src/components/PluginManager.tsx +347 -0
- hexdag/studio/ui/src/components/PortsEditor.tsx +481 -0
- hexdag/studio/ui/src/components/PythonEditor.tsx +195 -0
- hexdag/studio/ui/src/components/ValidationPanel.tsx +105 -0
- hexdag/studio/ui/src/components/YamlEditor.tsx +196 -0
- hexdag/studio/ui/src/components/index.ts +8 -0
- hexdag/studio/ui/src/index.css +92 -0
- hexdag/studio/ui/src/main.tsx +10 -0
- hexdag/studio/ui/src/types/index.ts +123 -0
- hexdag/studio/ui/src/vite-env.d.ts +1 -0
- hexdag/studio/ui/tailwind.config.js +29 -0
- hexdag/studio/ui/tsconfig.json +37 -0
- hexdag/studio/ui/tsconfig.node.json +13 -0
- hexdag/studio/ui/vite.config.ts +35 -0
- hexdag/visualization/__init__.py +69 -0
- hexdag/visualization/dag_visualizer.py +1020 -0
- hexdag-0.5.0.dev1.dist-info/METADATA +369 -0
- hexdag-0.5.0.dev1.dist-info/RECORD +261 -0
- hexdag-0.5.0.dev1.dist-info/WHEEL +4 -0
- hexdag-0.5.0.dev1.dist-info/entry_points.txt +4 -0
- hexdag-0.5.0.dev1.dist-info/licenses/LICENSE +190 -0
- hexdag_plugins/.gitignore +43 -0
- hexdag_plugins/README.md +73 -0
- hexdag_plugins/__init__.py +1 -0
- hexdag_plugins/azure/LICENSE +21 -0
- hexdag_plugins/azure/README.md +414 -0
- hexdag_plugins/azure/__init__.py +21 -0
- hexdag_plugins/azure/azure_blob_adapter.py +450 -0
- hexdag_plugins/azure/azure_cosmos_adapter.py +383 -0
- hexdag_plugins/azure/azure_keyvault_adapter.py +314 -0
- hexdag_plugins/azure/azure_openai_adapter.py +415 -0
- hexdag_plugins/azure/pyproject.toml +107 -0
- hexdag_plugins/azure/tests/__init__.py +1 -0
- hexdag_plugins/azure/tests/test_azure_blob_adapter.py +350 -0
- hexdag_plugins/azure/tests/test_azure_cosmos_adapter.py +323 -0
- hexdag_plugins/azure/tests/test_azure_keyvault_adapter.py +330 -0
- hexdag_plugins/azure/tests/test_azure_openai_adapter.py +329 -0
- hexdag_plugins/hexdag_etl/README.md +168 -0
- hexdag_plugins/hexdag_etl/__init__.py +53 -0
- hexdag_plugins/hexdag_etl/examples/01_simple_pandas_transform.py +270 -0
- hexdag_plugins/hexdag_etl/examples/02_simple_pandas_only.py +149 -0
- hexdag_plugins/hexdag_etl/examples/03_file_io_pipeline.py +109 -0
- hexdag_plugins/hexdag_etl/examples/test_pandas_transform.py +84 -0
- hexdag_plugins/hexdag_etl/hexdag.toml +25 -0
- hexdag_plugins/hexdag_etl/hexdag_etl/__init__.py +48 -0
- hexdag_plugins/hexdag_etl/hexdag_etl/nodes/__init__.py +13 -0
- hexdag_plugins/hexdag_etl/hexdag_etl/nodes/api_extract.py +230 -0
- hexdag_plugins/hexdag_etl/hexdag_etl/nodes/base_node_factory.py +181 -0
- hexdag_plugins/hexdag_etl/hexdag_etl/nodes/file_io.py +415 -0
- hexdag_plugins/hexdag_etl/hexdag_etl/nodes/outlook.py +492 -0
- hexdag_plugins/hexdag_etl/hexdag_etl/nodes/pandas_transform.py +563 -0
- hexdag_plugins/hexdag_etl/hexdag_etl/nodes/sql_extract_load.py +112 -0
- hexdag_plugins/hexdag_etl/pyproject.toml +82 -0
- hexdag_plugins/hexdag_etl/test_transform.py +54 -0
- hexdag_plugins/hexdag_etl/tests/test_plugin_integration.py +62 -0
- hexdag_plugins/mysql_adapter/LICENSE +21 -0
- hexdag_plugins/mysql_adapter/README.md +224 -0
- hexdag_plugins/mysql_adapter/__init__.py +6 -0
- hexdag_plugins/mysql_adapter/mysql_adapter.py +408 -0
- hexdag_plugins/mysql_adapter/pyproject.toml +93 -0
- hexdag_plugins/mysql_adapter/tests/test_mysql_adapter.py +259 -0
- hexdag_plugins/storage/README.md +184 -0
- hexdag_plugins/storage/__init__.py +19 -0
- hexdag_plugins/storage/file/__init__.py +5 -0
- hexdag_plugins/storage/file/local.py +325 -0
- hexdag_plugins/storage/ports/__init__.py +5 -0
- hexdag_plugins/storage/ports/vector_store.py +236 -0
- hexdag_plugins/storage/sql/__init__.py +7 -0
- hexdag_plugins/storage/sql/base.py +187 -0
- hexdag_plugins/storage/sql/mysql.py +27 -0
- hexdag_plugins/storage/sql/postgresql.py +27 -0
- hexdag_plugins/storage/tests/__init__.py +1 -0
- hexdag_plugins/storage/tests/test_local_file_storage.py +161 -0
- hexdag_plugins/storage/tests/test_sql_adapters.py +212 -0
- hexdag_plugins/storage/vector/__init__.py +7 -0
- hexdag_plugins/storage/vector/chromadb.py +223 -0
- hexdag_plugins/storage/vector/in_memory.py +285 -0
- hexdag_plugins/storage/vector/pgvector.py +502 -0
|
@@ -0,0 +1,920 @@
|
|
|
1
|
+
"""LoopNode for creating loop control nodes with conditional execution.
|
|
2
|
+
|
|
3
|
+
.. deprecated::
|
|
4
|
+
LoopNode and ConditionalNode are deprecated. Use CompositeNode instead:
|
|
5
|
+
- LoopNode → CompositeNode with mode='while'
|
|
6
|
+
- ConditionalNode → CompositeNode with mode='switch'
|
|
7
|
+
|
|
8
|
+
This module provides:
|
|
9
|
+
- LoopNode: iterative control with a single while_condition,
|
|
10
|
+
state preservation, and result collection by convention.
|
|
11
|
+
- ConditionalNode: multi-branch router with callable predicates.
|
|
12
|
+
|
|
13
|
+
Single-style (functional) API:
|
|
14
|
+
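- LoopNode conditions may be callables or string expressions; strings are compiled into safe predicates by hexdag.core.expression_parser.compile_expression.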
|
|
15
|
+
|
|
16
|
+
Conventions:
|
|
17
|
+
- Prefer while_condition for loop control.
|
|
18
|
+
- Result collection:
|
|
19
|
+
- If iterating over a collection and you want all outputs → set collect_mode="list".
|
|
20
|
+
- Otherwise → defaults to "last".
|
|
21
|
+
"""
|
|
22
|
+
|
|
23
|
+
import asyncio
|
|
24
|
+
import time
|
|
25
|
+
import warnings
|
|
26
|
+
from collections.abc import Callable, Collection, Iterable
|
|
27
|
+
from dataclasses import dataclass
|
|
28
|
+
from enum import Enum
|
|
29
|
+
from typing import Any, Literal, cast
|
|
30
|
+
|
|
31
|
+
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
|
32
|
+
|
|
33
|
+
from hexdag.builtin.nodes.base_node_factory import BaseNodeFactory
|
|
34
|
+
from hexdag.core.domain.dag import NodeSpec
|
|
35
|
+
from hexdag.core.logging import get_logger
|
|
36
|
+
|
|
37
|
+
# Literal type aliases used by this module's node-construction APIs.
TieBreak = Literal["first_true"]  # the only tie-break policy declared here
CollectMode = Literal["list", "last", "reduce"]  # result-collection strategies for LoopNode

logger = get_logger(__name__)
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class StopReason(str, Enum):
    """Why a LoopNode stopped iterating; reported in the loop metadata.

    Members
    -------
    CONDITION
        ``while_condition`` evaluated to False.
    LIMIT
        The ``max_iterations`` safety cap was reached.
    CONDITION_ERROR
        ``while_condition`` raised an exception.
    BREAK_GUARD
        A ``break_if`` predicate fired.
    NONE
        The loop never ran, or ended for an unexpected reason.
    """

    CONDITION = "condition"
    LIMIT = "limit"
    CONDITION_ERROR = "condition_error"
    BREAK_GUARD = "break_guard"
    NONE = "none"
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
@dataclass(frozen=True)
class ReduceConfig:
    """Immutable holder for the reducer used when ``collect_mode="reduce"``.

    Attributes
    ----------
    reducer : Callable[[Any, Any], Any]
        Binary function called as ``reducer(accumulator, item)`` that returns
        the new accumulator.
    """

    reducer: Callable[[Any, Any], Any]
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
class NodeParams(BaseModel):
    """Parameters for node construction.

    Attributes
    ----------
    in_model, out_model : Any | None
        Optional input/output models; any object is accepted.
    deps : set[str]
        Names of upstream dependencies. ``None``, a single string, or any
        non-mapping collection of strings is accepted and coerced to a set.
    """

    model_config = ConfigDict(extra="forbid")
    in_model: Any | None = None
    out_model: Any | None = None
    deps: set[str] = Field(default_factory=set)

    @field_validator("deps", mode="before")
    @classmethod
    def _coerce_deps(cls, v: str | Collection[str] | None) -> set[str]:
        """Normalize ``deps`` to a set before pydantic's own ``set[str]`` validation.

        A bare string is treated as ONE dependency name (it is not iterated per
        character). Any other collection (list, tuple, set, frozenset, dict views,
        ...) is converted with ``set()``. Mappings and bytes are rejected because
        iterating them does not yield dependency names.

        Raises
        ------
        ValueError
            If ``v`` is not None, a string, or a suitable collection.
        """
        from collections.abc import Mapping

        if v is None:
            return set()
        if isinstance(v, str):
            return {v}
        if isinstance(v, Collection) and not isinstance(v, (bytes, bytearray, Mapping)):
            return set(v)
        raise ValueError("deps must be a collection of strings or a single string")
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def _normalize_input_data(input_data: Any) -> Any:
|
|
81
|
+
"""Normalize input data for loop and conditional nodes."""
|
|
82
|
+
if hasattr(input_data, "model_dump"):
|
|
83
|
+
data = input_data.model_dump()
|
|
84
|
+
elif isinstance(input_data, dict):
|
|
85
|
+
data = dict(input_data)
|
|
86
|
+
else:
|
|
87
|
+
data = {"input": input_data}
|
|
88
|
+
|
|
89
|
+
return data
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def _eval_break_guards(
|
|
93
|
+
guards: Iterable[Callable[[dict, dict], bool]], data: dict, state: dict
|
|
94
|
+
) -> bool:
|
|
95
|
+
"""OR-semantics: return True if any guard signals to break; guard errors are ignored."""
|
|
96
|
+
for idx, g in enumerate(guards):
|
|
97
|
+
try:
|
|
98
|
+
if g(data, state):
|
|
99
|
+
return True
|
|
100
|
+
except Exception as e:
|
|
101
|
+
logger.warning("break_if[%d] raised; ignoring error: %s", idx, e)
|
|
102
|
+
return False
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
def _apply_on_iteration_end(
|
|
106
|
+
on_end: Callable[[dict, Any], dict] | None, state: dict, out: Any
|
|
107
|
+
) -> dict:
|
|
108
|
+
"""Run on_iteration_end; ignore errors; return possibly updated state."""
|
|
109
|
+
if on_end is None:
|
|
110
|
+
return state
|
|
111
|
+
try:
|
|
112
|
+
new_state = on_end(state, out)
|
|
113
|
+
except Exception as e:
|
|
114
|
+
logger.warning("on_iteration_end failed: %s", e)
|
|
115
|
+
return state
|
|
116
|
+
if new_state is state:
|
|
117
|
+
return dict(state)
|
|
118
|
+
|
|
119
|
+
try:
|
|
120
|
+
return dict(new_state)
|
|
121
|
+
except Exception as e:
|
|
122
|
+
logger.warning("on_iteration_end failed: %s", e)
|
|
123
|
+
return state
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
class LoopNode(BaseNodeFactory):
|
|
127
|
+
"""Advanced loop control node (functional-only).
|
|
128
|
+
|
|
129
|
+
.. deprecated::
|
|
130
|
+
LoopNode is deprecated. Use CompositeNode with mode='while' instead.
|
|
131
|
+
See hexdag.builtin.nodes.composite_node for the new unified API.
|
|
132
|
+
|
|
133
|
+
Key points:
|
|
134
|
+
- Single controlling predicate: while_condition(data, state) -> bool (required).
|
|
135
|
+
- State is preserved across iterations via the state dict (updated per iteration).
|
|
136
|
+
- Result shape: {
|
|
137
|
+
"result": <list|last|reduced>,
|
|
138
|
+
"metadata": {
|
|
139
|
+
iterations, stopped_by, max_iterations, state
|
|
140
|
+
}
|
|
141
|
+
}
|
|
142
|
+
- No string predicates, no eval-based resolution.
|
|
143
|
+
Result collection:
|
|
144
|
+
- To collect all outputs, set collect_mode="list".
|
|
145
|
+
- To reduce across iterations, set collect_mode="reduce" and provide a reducer.
|
|
146
|
+
- Otherwise defaults to "last".
|
|
147
|
+
"""
|
|
148
|
+
|
|
149
|
+
# Explicit schema for YAML/MCP usage (builder pattern doesn't expose params well)
|
|
150
|
+
_yaml_schema: dict[str, Any] = {
|
|
151
|
+
"type": "object",
|
|
152
|
+
"description": "Loop control node for iterative processing",
|
|
153
|
+
"properties": {
|
|
154
|
+
"while_condition": {
|
|
155
|
+
"type": "string",
|
|
156
|
+
"description": "Module path to condition function: (data, state) -> bool",
|
|
157
|
+
},
|
|
158
|
+
"body": {
|
|
159
|
+
"type": "string",
|
|
160
|
+
"description": "Module path to body function: (data, state) -> Any",
|
|
161
|
+
},
|
|
162
|
+
"max_iterations": {
|
|
163
|
+
"type": "integer",
|
|
164
|
+
"default": 100,
|
|
165
|
+
"description": "Maximum number of iterations before stopping",
|
|
166
|
+
},
|
|
167
|
+
"collect_mode": {
|
|
168
|
+
"type": "string",
|
|
169
|
+
"enum": ["last", "list", "reduce"],
|
|
170
|
+
"default": "last",
|
|
171
|
+
"description": "How to collect results: last value, all values, or reduced",
|
|
172
|
+
},
|
|
173
|
+
"initial_state": {
|
|
174
|
+
"type": "object",
|
|
175
|
+
"default": {},
|
|
176
|
+
"description": "Initial state dict passed to first iteration",
|
|
177
|
+
},
|
|
178
|
+
"iteration_key": {
|
|
179
|
+
"type": "string",
|
|
180
|
+
"default": "loop_iteration",
|
|
181
|
+
"description": "Key name for current iteration number in state",
|
|
182
|
+
},
|
|
183
|
+
},
|
|
184
|
+
"required": ["while_condition", "body"],
|
|
185
|
+
}
|
|
186
|
+
|
|
187
|
+
def __init__(
    self, name: str | None = None, condition: Callable[[dict, dict], bool] | None = None
) -> None:
    """Initialize the LoopNode factory with builder defaults.

    Parameters
    ----------
    name : str | None
        Optional node name; can also be set later with ``name()``.
    condition : Callable[[dict, dict], bool] | None
        Optional ``(data, state) -> bool`` loop predicate; can also be set with
        ``condition()``, which additionally accepts string expressions.

    Notes
    -----
    ``max_iterations`` defaults to 100, the default advertised by
    ``_yaml_schema`` and by the ``__call__`` docstring. ``__call__`` falls back to
    this value whenever no ``max_iterations`` is passed, so YAML nodes that omit
    it get the documented cap rather than a single iteration.
    """
    super().__init__()

    self._name: str | None = name
    self._while: Callable[[dict, dict], bool] | None = condition
    # No-op body and identity state hook until the builder overrides them.
    self._body: Callable[[dict, dict], Any] = lambda d, s: None
    self._on_end: Callable[[dict, Any], dict] = lambda s, _: s
    self._init_state: dict = {}
    self._collect_mode: CollectMode = "last"
    self._reduce_cfg: ReduceConfig | None = None
    # Safety cap against runaway loops; must match the documented default.
    self._max_iter: int = 100
    self._iter_key: str = "loop_iteration"
    self._break_if: list[Callable[[dict, dict], bool]] = []

    self._deps: set[str] = set()
    self._in_model: Any | None = None
    self._out_model: Any | None = None
|
|
207
|
+
|
|
208
|
+
def name(self, n: str) -> "LoopNode":
    """Builder step: record the node name.

    Parameters
    ----------
    n : str
        Name later passed to the NodeSpec by ``build()``.

    Returns
    -------
    LoopNode
        This factory, for chaining.
    """
    self._name = n
    return self
|
|
212
|
+
|
|
213
|
+
def condition(self, fn: Callable[[dict, dict], bool] | str) -> "LoopNode":
|
|
214
|
+
"""Set the loop continuation condition function.
|
|
215
|
+
|
|
216
|
+
Parameters
|
|
217
|
+
----------
|
|
218
|
+
fn : Callable[[dict, dict], bool] | str
|
|
219
|
+
Either a callable predicate function that takes (data, state)
|
|
220
|
+
and returns bool, or a string expression like "state.iteration < 10"
|
|
221
|
+
that will be compiled into a safe predicate.
|
|
222
|
+
|
|
223
|
+
Returns
|
|
224
|
+
-------
|
|
225
|
+
LoopNode
|
|
226
|
+
Self for method chaining.
|
|
227
|
+
|
|
228
|
+
Examples
|
|
229
|
+
--------
|
|
230
|
+
Using a callable::
|
|
231
|
+
|
|
232
|
+
node.condition(lambda d, s: s.get("iteration", 0) < 10)
|
|
233
|
+
|
|
234
|
+
Using a string expression::
|
|
235
|
+
|
|
236
|
+
node.condition("state.iteration < 10")
|
|
237
|
+
node.condition("not done and count < max_count")
|
|
238
|
+
"""
|
|
239
|
+
if isinstance(fn, str):
|
|
240
|
+
from hexdag.core.expression_parser import compile_expression
|
|
241
|
+
|
|
242
|
+
fn = compile_expression(fn)
|
|
243
|
+
elif not callable(fn):
|
|
244
|
+
raise ValueError("condition(): fn must be callable or a string expression")
|
|
245
|
+
self._while = fn
|
|
246
|
+
return self
|
|
247
|
+
|
|
248
|
+
def do(self, fn: Callable[[dict, dict], Any]) -> "LoopNode":
    """Set the loop body function.

    Parameters
    ----------
    fn : Callable[[dict, dict], Any]
        Body called each iteration as ``fn(data, state)``.

    Returns
    -------
    LoopNode
        ``self``, for chaining.

    Raises
    ------
    ValueError
        If ``fn`` is not callable. The check matches ``condition()``, so the
        mistake surfaces here rather than when the loop runs. The previously
        configured body is kept in that case.
    """
    if not callable(fn):
        raise ValueError("do(): fn must be callable")
    self._body = fn
    return self
|
|
252
|
+
|
|
253
|
+
def on_iteration_end(self, fn: Callable[[dict, Any], dict]) -> "LoopNode":
|
|
254
|
+
"""Set the state update function called after each iteration."""
|
|
255
|
+
self._on_end = fn
|
|
256
|
+
return self
|
|
257
|
+
|
|
258
|
+
def init_state(self, state: dict) -> "LoopNode":
    """Set the state dict the loop starts from (stored as a shallow copy).

    A falsy ``state`` (``None`` or ``{}``) resets the initial state to ``{}``.
    Returns ``self`` for chaining.
    """
    seed = state if state else {}
    self._init_state = dict(seed)
    return self
|
|
261
|
+
|
|
262
|
+
def collect_last(self) -> "LoopNode":
    """Keep only the final iteration's output; clears any reducer. Returns ``self``."""
    self._collect_mode, self._reduce_cfg = "last", None
    return self
|
|
266
|
+
|
|
267
|
+
def collect_list(self) -> "LoopNode":
    """Collect every iteration's output into a list; clears any reducer. Returns ``self``."""
    self._collect_mode, self._reduce_cfg = "list", None
    return self
|
|
271
|
+
|
|
272
|
+
def collect_reduce(self, reducer: Callable[[Any, Any], Any]) -> "LoopNode":
    """Fold iteration outputs with ``reducer(accumulator, item)``.

    The reducer is wrapped in a ``ReduceConfig``; its presence is checked later
    by ``build()``. Returns ``self`` for chaining.
    """
    config = ReduceConfig(reducer=reducer)
    self._collect_mode = "reduce"
    self._reduce_cfg = config
    return self
|
|
276
|
+
|
|
277
|
+
def max_iterations(self, n: int) -> "LoopNode":
    """Set the safety cap on the number of iterations.

    The value is stored as given; ``build()`` is what rejects non-positive
    values. Returns ``self`` for chaining.
    """
    self._max_iter = n
    return self
|
|
280
|
+
|
|
281
|
+
def iteration_key(self, key: str) -> "LoopNode":
    """Set the state key name for the current iteration number.

    Returns ``self`` for chaining.
    """
    self._iter_key = key
    return self
|
|
284
|
+
|
|
285
|
+
def break_if(self, *preds: Callable[[dict, dict], bool]) -> "LoopNode":
    """Append early-exit guards, each called as ``pred(data, state)``.

    Guards accumulate across calls and combine with OR semantics.

    Returns
    -------
    LoopNode
        ``self``, for chaining.

    Raises
    ------
    ValueError
        If any predicate is not callable. All predicates are checked before any
        is added, so a failed call leaves the guard list unchanged. Without this
        check, ``_eval_break_guards`` would swallow the resulting TypeError on
        every iteration, and the guard would silently never fire.
    """
    for idx, pred in enumerate(preds):
        if not callable(pred):
            raise ValueError(f"break_if(): predicate at position {idx} must be callable")
    self._break_if.extend(preds)
    return self
|
|
288
|
+
|
|
289
|
+
def deps(self, deps: Iterable[str]) -> "LoopNode":
|
|
290
|
+
self._deps = set(deps or [])
|
|
291
|
+
return self
|
|
292
|
+
|
|
293
|
+
def in_model(self, model: Any) -> "LoopNode":
    """Set the input model that ``build()`` forwards to the NodeSpec. Returns ``self``."""
    self._in_model = model
    return self
|
|
296
|
+
|
|
297
|
+
def out_model(self, model: Any) -> "LoopNode":
    """Set the output model that ``build()`` forwards to the NodeSpec. Returns ``self``."""
    self._out_model = model
    return self
|
|
300
|
+
|
|
301
|
+
@staticmethod
def _init_collector(
    mode: CollectMode, reduce_cfg: ReduceConfig | None
) -> tuple[Any, Callable[[Any, Any], Any]]:
    """Create the initial accumulator and step function for a collect mode.

    The step function is called as ``step(accumulator, item)`` and returns the
    new accumulator.

    - ``"list"``: starts with ``[]``; appends each item in place.
    - ``"last"``: starts with ``None``; the accumulator becomes the latest item.
    - ``"reduce"``: starts with ``None``; while the accumulator is ``None`` the
      item replaces it, otherwise ``reducer(accumulator, item)`` is applied.
      This means a ``None`` accumulator, including one returned by the reducer,
      is treated as "empty".

    Raises
    ------
    ValueError
        For ``"reduce"`` without a reducer, or for an unknown mode.
    """
    if mode == "list":

        def collect_list(a: Any, x: Any) -> Any:
            a.append(x)
            return a

        return [], collect_list

    if mode == "last":
        return None, lambda _a, x: x

    if mode == "reduce":
        if reduce_cfg is None or reduce_cfg.reducer is None:
            raise ValueError("ReduceConfig is required when collect_mode ='reduce'")
        combine = reduce_cfg.reducer

        def collect_reduce(a: Any, x: Any) -> Any:
            # None doubles as "nothing accumulated yet".
            if a is None:
                return x
            return combine(a, x)

        return None, collect_reduce

    raise ValueError("collect_mode must be one of: list | last | reduce")
|
|
332
|
+
|
|
333
|
+
@staticmethod
def _should_continue(
    while_condition: Callable[[dict, dict], bool], data: dict, state: dict
) -> tuple[bool, StopReason | None]:
    """Evaluate the loop predicate defensively.

    Returns
    -------
    tuple[bool, StopReason | None]
        ``(True, None)`` when ``while_condition(data, state)`` is truthy, and
        ``(False, StopReason.CONDITION)`` when it is falsy. Returns
        ``(False, StopReason.CONDITION_ERROR)`` when calling it, or converting
        its result to bool, raises; that error is logged, not re-raised.
    """
    try:
        keep_going = bool(while_condition(data, state))
    except Exception as exc:
        logger.warning("main condition raised; stop loop: %s", exc)
        return False, StopReason.CONDITION_ERROR
    if keep_going:
        return True, None
    return False, StopReason.CONDITION
|
|
346
|
+
|
|
347
|
+
def build(self) -> NodeSpec:
    """Validate the builder configuration and produce the NodeSpec.

    Checks run in this order, and the first failure raises:
    a name is set, a condition is set, ``max_iterations`` is positive, the
    collect mode is known, and a reducer is present for ``"reduce"``.

    Returns
    -------
    NodeSpec
        Result of ``self(name=..., deps=..., in_model=..., out_model=...)``.
        The loop-specific settings (condition, body, state, collect mode, ...)
        are not passed explicitly; ``__call__`` reads them from the builder fields.

    Raises
    ------
    ValueError
        On the first failed check.
    """
    if not self._name:
        raise ValueError("LoopNode name is required")
    if self._while is None:
        raise ValueError("condition(...) is required")
    if self._max_iter <= 0:
        raise ValueError("max_iterations must be positive")

    mode = self._collect_mode
    if mode not in ("list", "last", "reduce"):
        raise ValueError("collect_mode must be one of: list | last | reduce")
    if mode == "reduce":
        cfg = self._reduce_cfg
        if cfg is None or cfg.reducer is None:
            raise ValueError("ReduceConfig with a reducer is required when collect_mode='reduce'")

    return self(
        out_model=self._out_model,
        in_model=self._in_model,
        deps=self._deps,
        name=self._name,
    )
|
|
369
|
+
|
|
370
|
+
def __call__(
|
|
371
|
+
self,
|
|
372
|
+
name: str,
|
|
373
|
+
while_condition: str | Callable[[dict, dict], bool] | None = None,
|
|
374
|
+
body: str | Callable[[dict, dict], Any] | None = None,
|
|
375
|
+
max_iterations: int | None = None,
|
|
376
|
+
collect_mode: CollectMode | None = None,
|
|
377
|
+
initial_state: dict | None = None,
|
|
378
|
+
iteration_key: str | None = None,
|
|
379
|
+
**kwargs: Any,
|
|
380
|
+
) -> NodeSpec:
|
|
381
|
+
"""Builds a LoopNode NodeSpec.
|
|
382
|
+
|
|
383
|
+
Supports two modes:
|
|
384
|
+
1. Builder pattern: Use .condition(), .do(), etc. methods, then .build()
|
|
385
|
+
2. YAML/direct: Pass while_condition and body as parameters
|
|
386
|
+
|
|
387
|
+
Parameters
|
|
388
|
+
----------
|
|
389
|
+
name : str
|
|
390
|
+
Node name.
|
|
391
|
+
while_condition : str | Callable | None
|
|
392
|
+
For YAML: String expression like "state.iteration < 10".
|
|
393
|
+
For builder: Set via .condition() method.
|
|
394
|
+
body : str | Callable | None
|
|
395
|
+
For YAML: Module path to body function (e.g., "myapp.process").
|
|
396
|
+
For builder: Set via .do() method.
|
|
397
|
+
max_iterations : int | None
|
|
398
|
+
Safety cap to prevent infinite loops (default: 100).
|
|
399
|
+
collect_mode : CollectMode | None
|
|
400
|
+
How to collect results: "last", "list", or "reduce".
|
|
401
|
+
initial_state : dict | None
|
|
402
|
+
Initial state dict passed to first iteration.
|
|
403
|
+
iteration_key : str | None
|
|
404
|
+
Key name for current iteration number in state.
|
|
405
|
+
**kwargs
|
|
406
|
+
Passed through to NodeSpec (e.g., in_model, out_model, deps).
|
|
407
|
+
|
|
408
|
+
Examples
|
|
409
|
+
--------
|
|
410
|
+
YAML usage::
|
|
411
|
+
|
|
412
|
+
- kind: loop_node
|
|
413
|
+
metadata:
|
|
414
|
+
name: retry_loop
|
|
415
|
+
spec:
|
|
416
|
+
while_condition: "state.iteration < 3"
|
|
417
|
+
body: myapp.process_item
|
|
418
|
+
max_iterations: 10
|
|
419
|
+
initial_state:
|
|
420
|
+
counter: 0
|
|
421
|
+
"""
|
|
422
|
+
warnings.warn(
|
|
423
|
+
"LoopNode is deprecated. Use CompositeNode with mode='while' instead. "
|
|
424
|
+
"See hexdag.builtin.nodes.composite_node for the unified API.",
|
|
425
|
+
DeprecationWarning,
|
|
426
|
+
stacklevel=2,
|
|
427
|
+
)
|
|
428
|
+
|
|
429
|
+
from hexdag.core.expression_parser import compile_expression
|
|
430
|
+
from hexdag.core.resolver import resolve
|
|
431
|
+
|
|
432
|
+
# Determine source: YAML parameters or builder state
|
|
433
|
+
if while_condition is not None:
|
|
434
|
+
# YAML mode: Compile string condition
|
|
435
|
+
if isinstance(while_condition, str):
|
|
436
|
+
final_condition = compile_expression(while_condition)
|
|
437
|
+
else:
|
|
438
|
+
final_condition = while_condition
|
|
439
|
+
else:
|
|
440
|
+
# Builder mode: Use internal state
|
|
441
|
+
if self._while is None:
|
|
442
|
+
raise ValueError("while_condition is required")
|
|
443
|
+
final_condition = self._while
|
|
444
|
+
|
|
445
|
+
final_body: Callable[[dict, dict], Any]
|
|
446
|
+
if body is not None:
|
|
447
|
+
# YAML mode: Resolve body function from module path or use callable directly
|
|
448
|
+
resolved = resolve(body) if isinstance(body, str) else body
|
|
449
|
+
final_body = cast("Callable[[dict, dict], Any]", resolved)
|
|
450
|
+
else:
|
|
451
|
+
# Builder mode: Use internal state
|
|
452
|
+
final_body = self._body
|
|
453
|
+
|
|
454
|
+
# Use YAML params or builder state with defaults
|
|
455
|
+
final_max_iter = max_iterations if max_iterations is not None else self._max_iter
|
|
456
|
+
final_collect_mode = collect_mode if collect_mode is not None else self._collect_mode
|
|
457
|
+
final_init_state = dict(initial_state or self._init_state or {})
|
|
458
|
+
final_iter_key = iteration_key if iteration_key is not None else self._iter_key
|
|
459
|
+
|
|
460
|
+
on_iteration_end = self._on_end
|
|
461
|
+
reduce_cfg = self._reduce_cfg
|
|
462
|
+
break_if = list(self._break_if or [])
|
|
463
|
+
|
|
464
|
+
# Capture for closure
|
|
465
|
+
_condition = final_condition
|
|
466
|
+
_body_fn = final_body
|
|
467
|
+
_max_iter = final_max_iter
|
|
468
|
+
_collect_mode = final_collect_mode
|
|
469
|
+
_init_state = final_init_state
|
|
470
|
+
_iter_key = final_iter_key
|
|
471
|
+
_on_end = on_iteration_end
|
|
472
|
+
_reduce_cfg = reduce_cfg
|
|
473
|
+
_break_if = break_if
|
|
474
|
+
|
|
475
|
+
async def loop_fn(input_data: Any, **ports: Any) -> dict[str, Any]:
|
|
476
|
+
"""Execute the enhanced loop.
|
|
477
|
+
|
|
478
|
+
Steps per iteration:
|
|
479
|
+
1) Check safety cap (max_iterations) and main condition.
|
|
480
|
+
2) Run body_fn (sync/async) and collect result according to collect_mode.
|
|
481
|
+
3) Update state via on_iteration_end(state, out).
|
|
482
|
+
|
|
483
|
+
Returns a dict with:
|
|
484
|
+
- Original input fields (normalized to dict).
|
|
485
|
+
- "loop": metadata with iterations, final state, stop flags.
|
|
486
|
+
- One of: "outputs" (list), "output" (last), or "reduced" (accumulator).
|
|
487
|
+
"""
|
|
488
|
+
node_logger = logger.bind(node=name, node_type="loop_node")
|
|
489
|
+
start_time = time.perf_counter()
|
|
490
|
+
|
|
491
|
+
# Log loop start
|
|
492
|
+
node_logger.info(
|
|
493
|
+
"Starting loop",
|
|
494
|
+
max_iterations=_max_iter,
|
|
495
|
+
collect_mode=_collect_mode,
|
|
496
|
+
)
|
|
497
|
+
|
|
498
|
+
# Normalize input to dict without external helpers
|
|
499
|
+
data = _normalize_input_data(input_data)
|
|
500
|
+
|
|
501
|
+
# Local loop state
|
|
502
|
+
state = dict(_init_state or {})
|
|
503
|
+
|
|
504
|
+
acc, collect_fn = self._init_collector(_collect_mode, _reduce_cfg)
|
|
505
|
+
|
|
506
|
+
iteration_count = 0
|
|
507
|
+
stopped_by: StopReason = StopReason.NONE
|
|
508
|
+
|
|
509
|
+
while True:
|
|
510
|
+
# Safety cap
|
|
511
|
+
if iteration_count >= _max_iter:
|
|
512
|
+
stopped_by = StopReason.LIMIT
|
|
513
|
+
node_logger.debug(
|
|
514
|
+
"Loop reached max iterations",
|
|
515
|
+
iteration=iteration_count,
|
|
516
|
+
max_iterations=_max_iter,
|
|
517
|
+
)
|
|
518
|
+
break
|
|
519
|
+
|
|
520
|
+
# Main condition
|
|
521
|
+
ok, reason = self._should_continue(_condition, data, state)
|
|
522
|
+
if not ok:
|
|
523
|
+
stopped_by = reason or StopReason.CONDITION
|
|
524
|
+
node_logger.debug(
|
|
525
|
+
"Loop condition returned False",
|
|
526
|
+
iteration=iteration_count,
|
|
527
|
+
stop_reason=stopped_by.value,
|
|
528
|
+
)
|
|
529
|
+
break
|
|
530
|
+
|
|
531
|
+
# Log iteration start at debug level
|
|
532
|
+
node_logger.debug(
|
|
533
|
+
"Loop iteration",
|
|
534
|
+
iteration=iteration_count + 1,
|
|
535
|
+
state_keys=list(state.keys()),
|
|
536
|
+
)
|
|
537
|
+
|
|
538
|
+
# Body execution (supports async)
|
|
539
|
+
out = _body_fn(data, state)
|
|
540
|
+
if asyncio.iscoroutine(out):
|
|
541
|
+
out = await out
|
|
542
|
+
|
|
543
|
+
# Collect results
|
|
544
|
+
acc = collect_fn(acc, out)
|
|
545
|
+
|
|
546
|
+
# State update after each iteration
|
|
547
|
+
state = _apply_on_iteration_end(_on_end, state, out)
|
|
548
|
+
|
|
549
|
+
# Break guard
|
|
550
|
+
if _eval_break_guards(_break_if, data, state):
|
|
551
|
+
stopped_by = StopReason.BREAK_GUARD
|
|
552
|
+
iteration_count += 1
|
|
553
|
+
state[_iter_key] = iteration_count
|
|
554
|
+
node_logger.debug(
|
|
555
|
+
"Break guard triggered",
|
|
556
|
+
iteration=iteration_count,
|
|
557
|
+
)
|
|
558
|
+
break
|
|
559
|
+
|
|
560
|
+
iteration_count += 1
|
|
561
|
+
state[_iter_key] = iteration_count
|
|
562
|
+
|
|
563
|
+
# Log loop completion
|
|
564
|
+
duration_ms = (time.perf_counter() - start_time) * 1000
|
|
565
|
+
node_logger.info(
|
|
566
|
+
"Loop completed",
|
|
567
|
+
total_iterations=iteration_count,
|
|
568
|
+
stopped_by=stopped_by.value,
|
|
569
|
+
collect_mode=_collect_mode,
|
|
570
|
+
duration_ms=f"{duration_ms:.2f}",
|
|
571
|
+
)
|
|
572
|
+
|
|
573
|
+
# Build final result payload
|
|
574
|
+
return {
|
|
575
|
+
"result": acc,
|
|
576
|
+
"metadata": {
|
|
577
|
+
"iterations": iteration_count,
|
|
578
|
+
"stopped_by": stopped_by,
|
|
579
|
+
"max_iterations": _max_iter,
|
|
580
|
+
"state": state,
|
|
581
|
+
},
|
|
582
|
+
}
|
|
583
|
+
|
|
584
|
+
# Extract framework-level parameters from kwargs
|
|
585
|
+
framework = self.extract_framework_params(kwargs)
|
|
586
|
+
|
|
587
|
+
# Map DirectedGraph-related arguments to NodeSpec fields
|
|
588
|
+
try:
|
|
589
|
+
params_model = NodeParams(**kwargs)
|
|
590
|
+
except Exception as e:
|
|
591
|
+
raise ValueError(f"Invalid node params: {e}") from e
|
|
592
|
+
|
|
593
|
+
return NodeSpec(
|
|
594
|
+
name=name,
|
|
595
|
+
fn=loop_fn,
|
|
596
|
+
in_model=params_model.in_model,
|
|
597
|
+
out_model=params_model.out_model,
|
|
598
|
+
deps=frozenset(params_model.deps),
|
|
599
|
+
params=params_model.model_dump(exclude_none=True),
|
|
600
|
+
timeout=framework["timeout"],
|
|
601
|
+
max_retries=framework["max_retries"],
|
|
602
|
+
when=framework["when"],
|
|
603
|
+
)
|
|
604
|
+
|
|
605
|
+
|
|
606
|
+
class ConditionalNode(BaseNodeFactory):
|
|
607
|
+
"""Multi-branch conditional router (functional-only, breaking change).
|
|
608
|
+
|
|
609
|
+
.. deprecated::
|
|
610
|
+
ConditionalNode is deprecated. Use CompositeNode with mode='switch' instead.
|
|
611
|
+
See hexdag.builtin.nodes.composite_node for the new unified API.
|
|
612
|
+
|
|
613
|
+
API:
|
|
614
|
+
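- branches: list of {"pred": Callable[[dict, dict], bool], "action": str} when built via .when(); YAML/direct calls pass {"condition": <expression string>, "action": str} instead.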
|
|
615
|
+
- else_action: str | None — fallback action if no branch matches.
|
|
616
|
+
- tie_break: currently only "first_true" is supported.
|
|
617
|
+
Return:
|
|
618
|
+
{
|
|
619
|
+
"result": <action | None>,
|
|
620
|
+
"metadata": {
|
|
621
|
+
"matched_branch": <int | None>,
|
|
622
|
+
"evaluations": <list[bool]>,
|
|
623
|
+
"has_else": <bool>,
|
|
624
|
+
}
|
|
625
|
+
}
|
|
626
|
+
|
|
627
|
+
Notes:
|
|
628
|
+
- Predicates are callables, or string expressions compiled into safe predicates (no eval).
|
|
629
|
+
- Input is normalized to dict internally; original input is not echoed back.
|
|
630
|
+
"""
|
|
631
|
+
|
|
632
|
+
# Explicit schema for YAML/MCP usage (builder pattern doesn't expose params well)
|
|
633
|
+
_yaml_schema: dict[str, Any] = {
|
|
634
|
+
"type": "object",
|
|
635
|
+
"description": "Multi-branch conditional router for workflow control flow",
|
|
636
|
+
"properties": {
|
|
637
|
+
"branches": {
|
|
638
|
+
"type": "array",
|
|
639
|
+
"description": "List of condition branches evaluated in order",
|
|
640
|
+
"items": {
|
|
641
|
+
"type": "object",
|
|
642
|
+
"properties": {
|
|
643
|
+
"condition": {
|
|
644
|
+
"type": "string",
|
|
645
|
+
"description": "Expression like 'node.field == value' or callable",
|
|
646
|
+
},
|
|
647
|
+
"action": {
|
|
648
|
+
"type": "string",
|
|
649
|
+
"description": "Action name to return if condition matches",
|
|
650
|
+
},
|
|
651
|
+
},
|
|
652
|
+
"required": ["condition", "action"],
|
|
653
|
+
},
|
|
654
|
+
},
|
|
655
|
+
"else_action": {
|
|
656
|
+
"type": "string",
|
|
657
|
+
"description": "Default action if no branch conditions match",
|
|
658
|
+
},
|
|
659
|
+
"tie_break": {
|
|
660
|
+
"type": "string",
|
|
661
|
+
"enum": ["first_true"],
|
|
662
|
+
"default": "first_true",
|
|
663
|
+
"description": "Strategy for handling multiple matching branches",
|
|
664
|
+
},
|
|
665
|
+
},
|
|
666
|
+
"required": ["branches"],
|
|
667
|
+
}
|
|
668
|
+
|
|
669
|
+
def __init__(self, name: str | None = None) -> None:
    """Create an empty ConditionalNode builder.

    Parameters
    ----------
    name : str | None
        Optional node name; it can also be supplied later through ``.name()``.
    """
    super().__init__()

    # Routing configuration collected by when()/otherwise().
    self._name: str | None = name
    self._branches: list[dict[str, Any]] = []
    self._else_action: str | None = None

    # Values forwarded to the NodeSpec by build().
    self._in_model: Any | None = None
    self._out_model: Any | None = None
    self._deps: set[str] = set()
|
680
|
+
|
|
681
|
+
def name(self, n: str) -> "ConditionalNode":
    """Set the node name and return the builder for chaining."""
    # Fluent setter: record the name, then hand the builder back.
    self._name = n
    return self
|
|
684
|
+
|
|
685
|
+
def when(
|
|
686
|
+
self,
|
|
687
|
+
pred: Callable[[dict, dict], bool] | str,
|
|
688
|
+
action: str,
|
|
689
|
+
) -> "ConditionalNode":
|
|
690
|
+
"""Add a conditional branch.
|
|
691
|
+
|
|
692
|
+
Parameters
|
|
693
|
+
----------
|
|
694
|
+
pred : Callable[[dict, dict], bool] | str
|
|
695
|
+
Either a callable predicate function that takes (data, state)
|
|
696
|
+
and returns bool, or a string expression like "action == 'ACCEPT'"
|
|
697
|
+
that will be compiled into a safe predicate.
|
|
698
|
+
action : str
|
|
699
|
+
The action name to return if this branch matches.
|
|
700
|
+
|
|
701
|
+
Returns
|
|
702
|
+
-------
|
|
703
|
+
ConditionalNode
|
|
704
|
+
Self for method chaining.
|
|
705
|
+
|
|
706
|
+
Examples
|
|
707
|
+
--------
|
|
708
|
+
Using a callable::
|
|
709
|
+
|
|
710
|
+
node.when(lambda d, s: d.get("status") == "active", "process")
|
|
711
|
+
|
|
712
|
+
Using a string expression::
|
|
713
|
+
|
|
714
|
+
node.when("status == 'active'", "process")
|
|
715
|
+
node.when("node.action == 'ACCEPT' and confidence > 0.8", "approve")
|
|
716
|
+
node.when("state.iteration < 10", "continue")
|
|
717
|
+
"""
|
|
718
|
+
if isinstance(pred, str):
|
|
719
|
+
from hexdag.core.expression_parser import compile_expression
|
|
720
|
+
|
|
721
|
+
pred = compile_expression(pred)
|
|
722
|
+
elif not callable(pred):
|
|
723
|
+
raise ValueError("when(): pred must be callable or a string expression")
|
|
724
|
+
if not isinstance(action, str) or not action:
|
|
725
|
+
raise ValueError("when(): action must be a non-empty string")
|
|
726
|
+
self._branches.append({"pred": pred, "action": action})
|
|
727
|
+
return self
|
|
728
|
+
|
|
729
|
+
def otherwise(self, action: str) -> "ConditionalNode":
    """Set the fallback action used when no branch predicate matches.

    Raises
    ------
    ValueError
        If ``action`` is not a non-empty string.
    """
    is_valid = isinstance(action, str) and len(action) > 0
    if is_valid:
        self._else_action = action
        return self
    raise ValueError("otherwise(): action must be a non-empty string")
|
|
734
|
+
|
|
735
|
+
def deps(self, deps: Iterable[str]) -> "ConditionalNode":
|
|
736
|
+
self._deps = set(deps or [])
|
|
737
|
+
return self
|
|
738
|
+
|
|
739
|
+
def in_model(self, model: Any) -> "ConditionalNode":
    """Record the input model forwarded to NodeSpec; returns the builder."""
    # Stored as-is; build() passes it through to the NodeSpec.
    self._in_model = model
    return self
|
|
742
|
+
|
|
743
|
+
def out_model(self, model: Any) -> "ConditionalNode":
    """Record the output model forwarded to NodeSpec; returns the builder."""
    # Stored as-is; build() passes it through to the NodeSpec.
    self._out_model = model
    return self
|
|
746
|
+
|
|
747
|
+
def build(self) -> NodeSpec:
    """Validate the builder state and produce the NodeSpec.

    Delegates to ``__call__`` without ``branches``/``else_action`` so that the
    branches collected via ``when()`` and ``otherwise()`` are used.

    Raises
    ------
    ValueError
        If no name was set, or if neither a branch nor a fallback exists.
    """
    node_name = self._name
    if not node_name:
        raise ValueError("ConditionalNode name is required")

    has_routing = bool(self._branches) or self._else_action is not None
    if not has_routing:
        raise ValueError("At least one branch (when) or otherwise(...) action is required")

    spec_kwargs = {
        "name": node_name,
        "deps": self._deps,
        "in_model": self._in_model,
        "out_model": self._out_model,
    }
    return self(**spec_kwargs)
|
|
760
|
+
|
|
761
|
+
def __call__(
    self,
    name: str,
    branches: list[dict[str, Any]] | None = None,
    else_action: str | None = None,
    tie_break: TieBreak = "first_true",
    **kwargs: Any,
) -> NodeSpec:
    """Builds a ConditionalNode NodeSpec.

    Supports two modes:
    1. Builder pattern: Use .when() and .otherwise() methods, then .build()
    2. YAML/direct: Pass branches and else_action as parameters

    Parameters
    ----------
    name : str
        Node name.
    branches : list[dict[str, Any]] | None
        For YAML usage: List of branches with "condition" and "action" fields.
        A string condition is compiled with ``compile_expression``; a callable
        condition is used directly as ``pred(data, state)``, matching what
        ``.when()`` accepts. Example:
        [{"condition": "action == 'ACCEPT'", "action": "approve"}]
    else_action : str | None
        Optional fallback action when no branch matches.
    tie_break : TieBreak
        Branch selection strategy; only "first_true" supported.
    **kwargs
        Passed through to NodeSpec (e.g., in_model, out_model, deps).

    Raises
    ------
    ValueError
        If a YAML branch lacks a truthy 'condition' or 'action', or if the
        remaining kwargs fail NodeParams validation.

    Examples
    --------
    YAML usage::

        - kind: conditional_node
          metadata:
            name: router
          spec:
            branches:
              - condition: "action == 'ACCEPT'"
                action: approve
              - condition: "confidence < 0.5"
                action: manual_review
            else_action: default_handler
    """
    warnings.warn(
        "ConditionalNode is deprecated. Use CompositeNode with mode='switch' instead. "
        "See hexdag.builtin.nodes.composite_node for the unified API.",
        DeprecationWarning,
        stacklevel=2,
    )

    from hexdag.core.expression_parser import compile_expression

    # Determine source of branches: builder state or YAML parameters
    if branches is not None:
        # YAML mode: turn each branch condition into a predicate
        compiled_branches: list[dict[str, Any]] = []
        for branch in branches:
            condition = branch.get("condition")
            action = branch.get("action")
            if not condition or not action:
                raise ValueError(
                    f"Each branch must have 'condition' and 'action' fields. Got: {branch}"
                )
            # Callables (allowed by the schema and by .when()) are used as-is;
            # anything else goes through the safe expression compiler.
            pred = condition if callable(condition) else compile_expression(condition)
            compiled_branches.append({"pred": pred, "action": action})
        final_branches = compiled_branches
        final_else_action = else_action
    else:
        # Builder mode: Use internal state from .when() calls
        final_branches = list(self._branches or [])
        final_else_action = self._else_action

    # Capture for closure
    _branches = final_branches
    _else_action = final_else_action
    _tie_break = tie_break

    async def conditional_fn(input_data: Any, **ports: Any) -> dict[str, Any]:
        """Evaluate branches in order and pick the routing action.

        - Normalizes input to dict.
        - For callable predicates, passes (data, state) where state may be provided via ports.
        """
        node_logger = logger.bind(node=name, node_type="conditional_node")

        # Log evaluation start
        node_logger.debug(
            "Evaluating conditions",
            branch_count=len(_branches),
            has_else=_else_action is not None,
        )

        # Normalizes input to dict
        data = _normalize_input_data(input_data)

        # A non-dict "state" port is ignored and replaced by an empty dict.
        raw_state = ports.get("state", {})
        state = raw_state if isinstance(raw_state, dict) else {}

        chosen: str | None = None
        chosen_idx: int | None = None
        evaluations: list[bool] = []

        for idx, br in enumerate(_branches):
            ok = False
            try:
                ok = bool(br["pred"](data, state))
            except Exception as e:
                # Best-effort: a failing predicate counts as "no match".
                node_logger.warning(
                    "Branch predicate raised exception",
                    branch_index=idx,
                    error=str(e),
                )
            evaluations.append(ok)
            if ok and chosen is None:
                chosen = br["action"]
                chosen_idx = idx
                if _tie_break == "first_true":
                    break

        if chosen is None:
            chosen = _else_action

        result = {
            "result": chosen,
            "metadata": {
                "matched_branch": chosen_idx,
                "evaluations": evaluations,
                "has_else": _else_action is not None,
            },
        }

        # Log routing decision
        node_logger.info(
            "Routing decision",
            chosen_action=chosen,
            matched_branch=chosen_idx,
            used_else=chosen_idx is None and chosen is not None,
        )
        return result

    # Extract framework-level parameters from kwargs
    framework = self.extract_framework_params(kwargs)

    try:
        params_model = NodeParams(**kwargs)
    except Exception as e:
        raise ValueError(f"Invalid node parameters: {e}") from e

    return NodeSpec(
        name=name,
        fn=conditional_fn,
        in_model=params_model.in_model,
        out_model=params_model.out_model,
        deps=frozenset(params_model.deps),
        params=params_model.model_dump(exclude_none=True),
        timeout=framework["timeout"],
        max_retries=framework["max_retries"],
        when=framework["when"],
    )
|