hexdag-0.5.0.dev1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hexdag/__init__.py +116 -0
- hexdag/__main__.py +30 -0
- hexdag/adapters/executors/__init__.py +5 -0
- hexdag/adapters/executors/local_executor.py +316 -0
- hexdag/builtin/__init__.py +6 -0
- hexdag/builtin/adapters/__init__.py +51 -0
- hexdag/builtin/adapters/anthropic/__init__.py +5 -0
- hexdag/builtin/adapters/anthropic/anthropic_adapter.py +151 -0
- hexdag/builtin/adapters/database/__init__.py +6 -0
- hexdag/builtin/adapters/database/csv/csv_adapter.py +249 -0
- hexdag/builtin/adapters/database/pgvector/__init__.py +5 -0
- hexdag/builtin/adapters/database/pgvector/pgvector_adapter.py +478 -0
- hexdag/builtin/adapters/database/sqlalchemy/sqlalchemy_adapter.py +252 -0
- hexdag/builtin/adapters/database/sqlite/__init__.py +5 -0
- hexdag/builtin/adapters/database/sqlite/sqlite_adapter.py +410 -0
- hexdag/builtin/adapters/local/README.md +59 -0
- hexdag/builtin/adapters/local/__init__.py +7 -0
- hexdag/builtin/adapters/local/local_observer_manager.py +696 -0
- hexdag/builtin/adapters/memory/__init__.py +47 -0
- hexdag/builtin/adapters/memory/file_memory_adapter.py +297 -0
- hexdag/builtin/adapters/memory/in_memory_memory.py +216 -0
- hexdag/builtin/adapters/memory/schemas.py +57 -0
- hexdag/builtin/adapters/memory/session_memory.py +178 -0
- hexdag/builtin/adapters/memory/sqlite_memory_adapter.py +215 -0
- hexdag/builtin/adapters/memory/state_memory.py +280 -0
- hexdag/builtin/adapters/mock/README.md +89 -0
- hexdag/builtin/adapters/mock/__init__.py +15 -0
- hexdag/builtin/adapters/mock/hexdag.toml +50 -0
- hexdag/builtin/adapters/mock/mock_database.py +225 -0
- hexdag/builtin/adapters/mock/mock_embedding.py +223 -0
- hexdag/builtin/adapters/mock/mock_llm.py +177 -0
- hexdag/builtin/adapters/mock/mock_tool_adapter.py +192 -0
- hexdag/builtin/adapters/mock/mock_tool_router.py +232 -0
- hexdag/builtin/adapters/openai/__init__.py +5 -0
- hexdag/builtin/adapters/openai/openai_adapter.py +634 -0
- hexdag/builtin/adapters/secret/__init__.py +7 -0
- hexdag/builtin/adapters/secret/local_secret_adapter.py +248 -0
- hexdag/builtin/adapters/unified_tool_router.py +280 -0
- hexdag/builtin/macros/__init__.py +17 -0
- hexdag/builtin/macros/conversation_agent.py +390 -0
- hexdag/builtin/macros/llm_macro.py +151 -0
- hexdag/builtin/macros/reasoning_agent.py +423 -0
- hexdag/builtin/macros/tool_macro.py +380 -0
- hexdag/builtin/nodes/__init__.py +38 -0
- hexdag/builtin/nodes/_discovery.py +123 -0
- hexdag/builtin/nodes/agent_node.py +696 -0
- hexdag/builtin/nodes/base_node_factory.py +242 -0
- hexdag/builtin/nodes/composite_node.py +926 -0
- hexdag/builtin/nodes/data_node.py +201 -0
- hexdag/builtin/nodes/expression_node.py +487 -0
- hexdag/builtin/nodes/function_node.py +454 -0
- hexdag/builtin/nodes/llm_node.py +491 -0
- hexdag/builtin/nodes/loop_node.py +920 -0
- hexdag/builtin/nodes/mapped_input.py +518 -0
- hexdag/builtin/nodes/port_call_node.py +269 -0
- hexdag/builtin/nodes/tool_call_node.py +195 -0
- hexdag/builtin/nodes/tool_utils.py +390 -0
- hexdag/builtin/prompts/__init__.py +68 -0
- hexdag/builtin/prompts/base.py +422 -0
- hexdag/builtin/prompts/chat_prompts.py +303 -0
- hexdag/builtin/prompts/error_correction_prompts.py +320 -0
- hexdag/builtin/prompts/tool_prompts.py +160 -0
- hexdag/builtin/tools/builtin_tools.py +84 -0
- hexdag/builtin/tools/database_tools.py +164 -0
- hexdag/cli/__init__.py +17 -0
- hexdag/cli/__main__.py +7 -0
- hexdag/cli/commands/__init__.py +27 -0
- hexdag/cli/commands/build_cmd.py +812 -0
- hexdag/cli/commands/create_cmd.py +208 -0
- hexdag/cli/commands/docs_cmd.py +293 -0
- hexdag/cli/commands/generate_types_cmd.py +252 -0
- hexdag/cli/commands/init_cmd.py +188 -0
- hexdag/cli/commands/pipeline_cmd.py +494 -0
- hexdag/cli/commands/plugin_dev_cmd.py +529 -0
- hexdag/cli/commands/plugins_cmd.py +441 -0
- hexdag/cli/commands/studio_cmd.py +101 -0
- hexdag/cli/commands/validate_cmd.py +221 -0
- hexdag/cli/main.py +84 -0
- hexdag/core/__init__.py +83 -0
- hexdag/core/config/__init__.py +20 -0
- hexdag/core/config/loader.py +479 -0
- hexdag/core/config/models.py +150 -0
- hexdag/core/configurable.py +294 -0
- hexdag/core/context/__init__.py +37 -0
- hexdag/core/context/execution_context.py +378 -0
- hexdag/core/docs/__init__.py +26 -0
- hexdag/core/docs/extractors.py +678 -0
- hexdag/core/docs/generators.py +890 -0
- hexdag/core/docs/models.py +120 -0
- hexdag/core/domain/__init__.py +10 -0
- hexdag/core/domain/dag.py +1225 -0
- hexdag/core/exceptions.py +234 -0
- hexdag/core/expression_parser.py +569 -0
- hexdag/core/logging.py +449 -0
- hexdag/core/models/__init__.py +17 -0
- hexdag/core/models/base.py +138 -0
- hexdag/core/orchestration/__init__.py +46 -0
- hexdag/core/orchestration/body_executor.py +481 -0
- hexdag/core/orchestration/components/__init__.py +97 -0
- hexdag/core/orchestration/components/adapter_lifecycle_manager.py +113 -0
- hexdag/core/orchestration/components/checkpoint_manager.py +134 -0
- hexdag/core/orchestration/components/execution_coordinator.py +360 -0
- hexdag/core/orchestration/components/health_check_manager.py +176 -0
- hexdag/core/orchestration/components/input_mapper.py +143 -0
- hexdag/core/orchestration/components/lifecycle_manager.py +583 -0
- hexdag/core/orchestration/components/node_executor.py +377 -0
- hexdag/core/orchestration/components/secret_manager.py +202 -0
- hexdag/core/orchestration/components/wave_executor.py +158 -0
- hexdag/core/orchestration/constants.py +17 -0
- hexdag/core/orchestration/events/README.md +312 -0
- hexdag/core/orchestration/events/__init__.py +104 -0
- hexdag/core/orchestration/events/batching.py +330 -0
- hexdag/core/orchestration/events/decorators.py +139 -0
- hexdag/core/orchestration/events/events.py +573 -0
- hexdag/core/orchestration/events/observers/__init__.py +30 -0
- hexdag/core/orchestration/events/observers/core_observers.py +690 -0
- hexdag/core/orchestration/events/observers/models.py +111 -0
- hexdag/core/orchestration/events/taxonomy.py +269 -0
- hexdag/core/orchestration/hook_context.py +237 -0
- hexdag/core/orchestration/hooks.py +437 -0
- hexdag/core/orchestration/models.py +418 -0
- hexdag/core/orchestration/orchestrator.py +910 -0
- hexdag/core/orchestration/orchestrator_factory.py +275 -0
- hexdag/core/orchestration/port_wrappers.py +327 -0
- hexdag/core/orchestration/prompt/__init__.py +32 -0
- hexdag/core/orchestration/prompt/template.py +332 -0
- hexdag/core/pipeline_builder/__init__.py +21 -0
- hexdag/core/pipeline_builder/component_instantiator.py +386 -0
- hexdag/core/pipeline_builder/include_tag.py +265 -0
- hexdag/core/pipeline_builder/pipeline_config.py +133 -0
- hexdag/core/pipeline_builder/py_tag.py +223 -0
- hexdag/core/pipeline_builder/tag_discovery.py +268 -0
- hexdag/core/pipeline_builder/yaml_builder.py +1196 -0
- hexdag/core/pipeline_builder/yaml_validator.py +569 -0
- hexdag/core/ports/__init__.py +65 -0
- hexdag/core/ports/api_call.py +133 -0
- hexdag/core/ports/database.py +489 -0
- hexdag/core/ports/embedding.py +215 -0
- hexdag/core/ports/executor.py +237 -0
- hexdag/core/ports/file_storage.py +117 -0
- hexdag/core/ports/healthcheck.py +87 -0
- hexdag/core/ports/llm.py +551 -0
- hexdag/core/ports/memory.py +70 -0
- hexdag/core/ports/observer_manager.py +130 -0
- hexdag/core/ports/secret.py +145 -0
- hexdag/core/ports/tool_router.py +94 -0
- hexdag/core/ports_builder.py +623 -0
- hexdag/core/protocols.py +273 -0
- hexdag/core/resolver.py +304 -0
- hexdag/core/schema/__init__.py +9 -0
- hexdag/core/schema/generator.py +742 -0
- hexdag/core/secrets.py +242 -0
- hexdag/core/types.py +413 -0
- hexdag/core/utils/async_warnings.py +206 -0
- hexdag/core/utils/schema_conversion.py +78 -0
- hexdag/core/utils/sql_validation.py +86 -0
- hexdag/core/validation/secure_json.py +148 -0
- hexdag/core/yaml_macro.py +517 -0
- hexdag/mcp_server.py +3120 -0
- hexdag/studio/__init__.py +10 -0
- hexdag/studio/build_ui.py +92 -0
- hexdag/studio/server/__init__.py +1 -0
- hexdag/studio/server/main.py +100 -0
- hexdag/studio/server/routes/__init__.py +9 -0
- hexdag/studio/server/routes/execute.py +208 -0
- hexdag/studio/server/routes/export.py +558 -0
- hexdag/studio/server/routes/files.py +207 -0
- hexdag/studio/server/routes/plugins.py +419 -0
- hexdag/studio/server/routes/validate.py +220 -0
- hexdag/studio/ui/index.html +13 -0
- hexdag/studio/ui/package-lock.json +2992 -0
- hexdag/studio/ui/package.json +31 -0
- hexdag/studio/ui/postcss.config.js +6 -0
- hexdag/studio/ui/public/hexdag.svg +5 -0
- hexdag/studio/ui/src/App.tsx +251 -0
- hexdag/studio/ui/src/components/Canvas.tsx +408 -0
- hexdag/studio/ui/src/components/ContextMenu.tsx +187 -0
- hexdag/studio/ui/src/components/FileBrowser.tsx +123 -0
- hexdag/studio/ui/src/components/Header.tsx +181 -0
- hexdag/studio/ui/src/components/HexdagNode.tsx +193 -0
- hexdag/studio/ui/src/components/NodeInspector.tsx +512 -0
- hexdag/studio/ui/src/components/NodePalette.tsx +262 -0
- hexdag/studio/ui/src/components/NodePortsSection.tsx +403 -0
- hexdag/studio/ui/src/components/PluginManager.tsx +347 -0
- hexdag/studio/ui/src/components/PortsEditor.tsx +481 -0
- hexdag/studio/ui/src/components/PythonEditor.tsx +195 -0
- hexdag/studio/ui/src/components/ValidationPanel.tsx +105 -0
- hexdag/studio/ui/src/components/YamlEditor.tsx +196 -0
- hexdag/studio/ui/src/components/index.ts +8 -0
- hexdag/studio/ui/src/index.css +92 -0
- hexdag/studio/ui/src/main.tsx +10 -0
- hexdag/studio/ui/src/types/index.ts +123 -0
- hexdag/studio/ui/src/vite-env.d.ts +1 -0
- hexdag/studio/ui/tailwind.config.js +29 -0
- hexdag/studio/ui/tsconfig.json +37 -0
- hexdag/studio/ui/tsconfig.node.json +13 -0
- hexdag/studio/ui/vite.config.ts +35 -0
- hexdag/visualization/__init__.py +69 -0
- hexdag/visualization/dag_visualizer.py +1020 -0
- hexdag-0.5.0.dev1.dist-info/METADATA +369 -0
- hexdag-0.5.0.dev1.dist-info/RECORD +261 -0
- hexdag-0.5.0.dev1.dist-info/WHEEL +4 -0
- hexdag-0.5.0.dev1.dist-info/entry_points.txt +4 -0
- hexdag-0.5.0.dev1.dist-info/licenses/LICENSE +190 -0
- hexdag_plugins/.gitignore +43 -0
- hexdag_plugins/README.md +73 -0
- hexdag_plugins/__init__.py +1 -0
- hexdag_plugins/azure/LICENSE +21 -0
- hexdag_plugins/azure/README.md +414 -0
- hexdag_plugins/azure/__init__.py +21 -0
- hexdag_plugins/azure/azure_blob_adapter.py +450 -0
- hexdag_plugins/azure/azure_cosmos_adapter.py +383 -0
- hexdag_plugins/azure/azure_keyvault_adapter.py +314 -0
- hexdag_plugins/azure/azure_openai_adapter.py +415 -0
- hexdag_plugins/azure/pyproject.toml +107 -0
- hexdag_plugins/azure/tests/__init__.py +1 -0
- hexdag_plugins/azure/tests/test_azure_blob_adapter.py +350 -0
- hexdag_plugins/azure/tests/test_azure_cosmos_adapter.py +323 -0
- hexdag_plugins/azure/tests/test_azure_keyvault_adapter.py +330 -0
- hexdag_plugins/azure/tests/test_azure_openai_adapter.py +329 -0
- hexdag_plugins/hexdag_etl/README.md +168 -0
- hexdag_plugins/hexdag_etl/__init__.py +53 -0
- hexdag_plugins/hexdag_etl/examples/01_simple_pandas_transform.py +270 -0
- hexdag_plugins/hexdag_etl/examples/02_simple_pandas_only.py +149 -0
- hexdag_plugins/hexdag_etl/examples/03_file_io_pipeline.py +109 -0
- hexdag_plugins/hexdag_etl/examples/test_pandas_transform.py +84 -0
- hexdag_plugins/hexdag_etl/hexdag.toml +25 -0
- hexdag_plugins/hexdag_etl/hexdag_etl/__init__.py +48 -0
- hexdag_plugins/hexdag_etl/hexdag_etl/nodes/__init__.py +13 -0
- hexdag_plugins/hexdag_etl/hexdag_etl/nodes/api_extract.py +230 -0
- hexdag_plugins/hexdag_etl/hexdag_etl/nodes/base_node_factory.py +181 -0
- hexdag_plugins/hexdag_etl/hexdag_etl/nodes/file_io.py +415 -0
- hexdag_plugins/hexdag_etl/hexdag_etl/nodes/outlook.py +492 -0
- hexdag_plugins/hexdag_etl/hexdag_etl/nodes/pandas_transform.py +563 -0
- hexdag_plugins/hexdag_etl/hexdag_etl/nodes/sql_extract_load.py +112 -0
- hexdag_plugins/hexdag_etl/pyproject.toml +82 -0
- hexdag_plugins/hexdag_etl/test_transform.py +54 -0
- hexdag_plugins/hexdag_etl/tests/test_plugin_integration.py +62 -0
- hexdag_plugins/mysql_adapter/LICENSE +21 -0
- hexdag_plugins/mysql_adapter/README.md +224 -0
- hexdag_plugins/mysql_adapter/__init__.py +6 -0
- hexdag_plugins/mysql_adapter/mysql_adapter.py +408 -0
- hexdag_plugins/mysql_adapter/pyproject.toml +93 -0
- hexdag_plugins/mysql_adapter/tests/test_mysql_adapter.py +259 -0
- hexdag_plugins/storage/README.md +184 -0
- hexdag_plugins/storage/__init__.py +19 -0
- hexdag_plugins/storage/file/__init__.py +5 -0
- hexdag_plugins/storage/file/local.py +325 -0
- hexdag_plugins/storage/ports/__init__.py +5 -0
- hexdag_plugins/storage/ports/vector_store.py +236 -0
- hexdag_plugins/storage/sql/__init__.py +7 -0
- hexdag_plugins/storage/sql/base.py +187 -0
- hexdag_plugins/storage/sql/mysql.py +27 -0
- hexdag_plugins/storage/sql/postgresql.py +27 -0
- hexdag_plugins/storage/tests/__init__.py +1 -0
- hexdag_plugins/storage/tests/test_local_file_storage.py +161 -0
- hexdag_plugins/storage/tests/test_sql_adapters.py +212 -0
- hexdag_plugins/storage/vector/__init__.py +7 -0
- hexdag_plugins/storage/vector/chromadb.py +223 -0
- hexdag_plugins/storage/vector/in_memory.py +285 -0
- hexdag_plugins/storage/vector/pgvector.py +502 -0
hexdag/core/orchestration/orchestrator.py
@@ -0,0 +1,910 @@
"""DAG Orchestrator - Core execution engine for the Hex-DAG framework.

The Orchestrator walks DirectedGraphs in topological order, executing nodes
concurrently where possible using asyncio.gather().
"""

import asyncio
import time
import uuid
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager, suppress
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    from hexdag.core.ports.observer_manager import ObserverManagerPort
else:
    ObserverManagerPort = Any

from hexdag.core.context import (
    ExecutionContext,
    get_observer_manager,
    get_ports,
    set_current_graph,
    set_node_results,
    set_ports,
)
from hexdag.core.domain.dag import DirectedGraph, DirectedGraphError
from hexdag.core.exceptions import OrchestratorError
from hexdag.core.logging import get_logger
from hexdag.core.orchestration import NodeExecutionContext
from hexdag.core.orchestration.components import ExecutionCoordinator
from hexdag.core.orchestration.components.lifecycle_manager import (
    HookConfig,
    LifecycleManager,
    PipelineStatus,
    PostDagHookConfig,
)
from hexdag.core.orchestration.constants import (
    EXECUTOR_CONTEXT_GRAPH,
    EXECUTOR_CONTEXT_INITIAL_INPUT,
    EXECUTOR_CONTEXT_NODE_RESULTS,
)
from hexdag.core.orchestration.events import WaveCompleted, WaveStarted
from hexdag.core.orchestration.models import PortsConfiguration
from hexdag.core.orchestration.port_wrappers import wrap_ports_with_observability
from hexdag.core.ports.executor import ExecutionTask, ExecutorPort
from hexdag.core.ports_builder import PortsBuilder

from .events import (
    PipelineCancelled,
    PipelineCompleted,
    PipelineStarted,
)

logger = get_logger(__name__)


def _has_async_lifecycle(obj: Any, method_name: str) -> bool:
    """Check if object has an async lifecycle method (asetup/aclose).

    Parameters
    ----------
    obj : Any
        Object to check
    method_name : str
        Method name to check (e.g., "asetup", "aclose")

    Returns
    -------
    bool
        True if object has the method and it's a coroutine function
    """
    return hasattr(obj, method_name) and asyncio.iscoroutinefunction(
        getattr(obj, method_name, None)
    )


# Default configuration constants
DEFAULT_MAX_CONCURRENT_NODES = 10


@asynccontextmanager
async def _managed_ports(
    base_ports: dict[str, Any],
    additional_ports: dict[str, Any] | None = None,
    executor: ExecutorPort | None = None,
) -> AsyncIterator[dict[str, Any]]:
    """Manage port and executor lifecycle with automatic cleanup.

    Ports and executors that implement asetup()/aclose() methods will be
    automatically initialized and cleaned up.

    Parameters
    ----------
    base_ports : dict[str, Any]
        Base ports to manage
    additional_ports : dict[str, Any] | None
        Additional ports to merge with base ports
    executor : ExecutorPort | None
        Optional executor to manage lifecycle for
    """
    all_ports = {**base_ports}
    if additional_ports:
        all_ports.update(additional_ports)

    # Setup executor first (if provided)
    executor_initialized = False
    if executor is not None and _has_async_lifecycle(executor, "asetup"):
        try:
            await executor.asetup()
            executor_initialized = True
        except Exception as e:
            logger.error(f"Executor setup failed: {e}")
            raise

    # Setup ports
    initialized: list[str] = []
    for name, port in all_ports.items():
        if _has_async_lifecycle(port, "asetup"):
            try:
                await port.asetup()
                initialized.append(name)
            except Exception as e:
                logger.error(f"Port setup failed: {name}: {e}")
                # Cleanup initialized ports
                for cleanup_name in initialized:
                    cleanup_port = all_ports[cleanup_name]
                    if _has_async_lifecycle(cleanup_port, "aclose"):
                        with suppress(Exception):
                            await cleanup_port.aclose()
                # Cleanup executor if initialized
                if (
                    executor_initialized
                    and executor is not None
                    and _has_async_lifecycle(executor, "aclose")
                ):
                    with suppress(Exception):
                        await executor.aclose()
                raise

    try:
        yield all_ports
    finally:
        # Cleanup ports
        for name in initialized:
            port = all_ports[name]
            if _has_async_lifecycle(port, "aclose"):
                try:
                    await port.aclose()
                except Exception as e:
                    logger.warning(f"Port cleanup failed: {name}: {e}")

        # Cleanup executor last
        if (
            executor_initialized
            and executor is not None
            and _has_async_lifecycle(executor, "aclose")
        ):
            try:
                await executor.aclose()
            except Exception as e:
                logger.warning(f"Executor cleanup failed: {e}")


class Orchestrator:
    """Orchestrates DAG execution with concurrent processing and resource management.

    The orchestrator executes DirectedGraphs by:

    1. Computing execution waves via topological sorting
    2. Running each wave's nodes concurrently with configurable limits
    3. Passing outputs between nodes
    4. Tracking execution with events
    """

    def __init__(
        self,
        max_concurrent_nodes: int = DEFAULT_MAX_CONCURRENT_NODES,
        ports: dict[str, Any] | PortsConfiguration | None = None,
        strict_validation: bool = False,
        default_node_timeout: float | None = None,
        pre_hook_config: HookConfig | None = None,
        post_hook_config: PostDagHookConfig | None = None,
        executor: ExecutorPort | None = None,
    ) -> None:
        """Initialize orchestrator with configuration.

        Args
        ----
        max_concurrent_nodes: Maximum number of nodes to execute concurrently
        ports: Shared ports/dependencies for all pipeline executions.
            Can be either a flat dict (backward compatible) or a PortsConfiguration
            for advanced type-specific and node-level port customization.
        strict_validation: If True, raise errors on validation failure
        default_node_timeout: Default timeout in seconds for each node (None = no timeout)
        pre_hook_config: Configuration for pre-DAG hooks (health checks, secrets, etc.)
        post_hook_config: Configuration for post-DAG hooks (cleanup, checkpoints, etc.)
        executor: Optional executor port for pluggable execution strategies.
            If None (default), creates LocalExecutor with the provided configuration.
            Set to a custom ExecutorPort implementation (e.g., CeleryExecutor,
            AzureFunctionsExecutor) for distributed or serverless execution.

        Notes
        -----
        ARCHITECTURAL EXCEPTION: This is the ONLY place in hexdag/core that imports
        from hexdag/adapters (lazy import of LocalExecutor). This exception is
        enforced by pre-commit hooks to ensure it remains isolated.
        """

        self.max_concurrent_nodes = max_concurrent_nodes
        self.strict_validation = strict_validation
        self.default_node_timeout = default_node_timeout

        # Default to LocalExecutor if no executor provided
        # ARCHITECTURAL EXCEPTION: Lazy import to avoid core -> adapters dependency at module level
        if executor is None:
            from hexdag.adapters.executors import LocalExecutor  # noqa: PLC0415

            if default_node_timeout is not None:
                executor = LocalExecutor(
                    max_concurrent_nodes=max_concurrent_nodes,
                    strict_validation=strict_validation,
                    default_node_timeout=default_node_timeout,
                )
            else:
                executor = LocalExecutor(
                    max_concurrent_nodes=max_concurrent_nodes,
                    strict_validation=strict_validation,
                )

        self.executor = executor

        self.ports_config: PortsConfiguration | None
        if isinstance(ports, PortsConfiguration):
            self.ports_config = ports
            self.ports = (
                {k: v.port for k, v in ports.global_ports.items()} if ports.global_ports else {}
            )
        else:
            self.ports_config = None
            self.ports = ports or {}

        # Validate known port types
        self._validate_port_types(self.ports)

        # Unified managers (consolidate 11 managers into 2 unified managers)
        self._execution_coordinator = ExecutionCoordinator()
        self._lifecycle_manager = LifecycleManager(pre_hook_config, post_hook_config)

    async def _notify_observer(
        self, observer_manager: ObserverManagerPort | None, event: Any
    ) -> None:
        """Notify observer if it exists (delegates to ExecutionCoordinator)."""
        await self._execution_coordinator.notify_observer(observer_manager, event)

    def _validate_port_types(self, ports: dict[str, Any]) -> None:
        """Validate that orchestrator ports match expected types.

        Args
        ----
        ports: Dictionary of ports to validate

        Notes
        -----
        This provides helpful warnings if ports don't match expected protocols.
        Currently checks observer_manager.
        """
        # Check observer_manager if provided
        if "observer_manager" in ports:
            obs = ports["observer_manager"]
            if not hasattr(obs, "notify"):
                logger.warning(
                    f"Port 'observer_manager' doesn't have 'notify' method. "
                    f"Expected ObserverManagerPort, got {type(obs).__name__}"
                )

    def _validate_required_ports(
        self, graph: DirectedGraph, available_ports: dict[str, Any]
    ) -> None:
        """Validate that all required ports for nodes in the DAG are available.

        Args
        ----
        graph: The DirectedGraph to validate
        available_ports: Dictionary of available ports

        Raises
        ------
        OrchestratorError: If required ports are missing
        """
        missing_ports: dict[str, list[str]] = {}

        for node_name, node_spec in graph.items():  # Using .items() instead of .nodes.items()
            fn = node_spec.fn
            required_ports: list[str] = []

            # Try to get required_ports from the function/method
            if hasattr(fn, "_hexdag_required_ports"):
                required_ports = getattr(fn, "_hexdag_required_ports", [])
            # Check if bound method - use getattr to avoid type checker issues
            elif (self_obj := getattr(fn, "__self__", None)) is not None:
                # It's a bound method - check the class
                node_class = self_obj.__class__
                required_ports = getattr(node_class, "_hexdag_required_ports", [])

            # Check each required port
            for port_name in required_ports:
                if port_name not in available_ports:
                    if node_name not in missing_ports:
                        missing_ports[node_name] = []
                    missing_ports[node_name].append(port_name)

        # Raise error if any ports are missing
        if missing_ports:
            error_msg = "Missing required ports:\n"
            for node_name, ports in missing_ports.items():
                error_msg += f"  Node '{node_name}': {', '.join(ports)}\n"
            error_msg += f"\nAvailable ports: {', '.join(available_ports.keys())}"
            raise OrchestratorError(error_msg)

    @classmethod
    def from_builder(
        cls,
        builder: PortsBuilder,
        max_concurrent_nodes: int = DEFAULT_MAX_CONCURRENT_NODES,
        strict_validation: bool = False,
        executor: ExecutorPort | None = None,
    ) -> "Orchestrator":
        """Create an Orchestrator using a PortsBuilder.

        This provides a more intuitive way to configure the orchestrator
        with type-safe port configuration.

        Args
        ----
        builder: Configured PortsBuilder instance
        max_concurrent_nodes: Maximum number of nodes to execute concurrently
        strict_validation: If True, raise errors on validation failure
        executor: Optional executor port for pluggable execution strategies

        Returns
        -------
        Orchestrator
            New orchestrator instance with configured ports

        Example
        -------
        ```python
        orchestrator = Orchestrator.from_builder(
            PortsBuilder()
            .with_llm(OpenAIAdapter())
            .with_database(PostgresAdapter())
            .with_observer_manager(LocalObserverManager())
        )
        ```
        """
        return cls(
            max_concurrent_nodes=max_concurrent_nodes,
            ports=builder.build(),
            strict_validation=strict_validation,
            executor=executor,
        )

    async def run(
        self,
        graph: DirectedGraph,
        initial_input: Any,
        additional_ports: dict[str, Any] | PortsBuilder | None = None,
        validate: bool = True,
        dynamic: bool = False,
        max_dynamic_iterations: int = 100,
        **kwargs: Any,
    ) -> dict[str, Any]:
        """Execute a DAG with concurrent processing and resource limits.

        Supports both traditional dictionary-based ports and the new PortsBuilder
        for additional_ports parameter. When using PortsBuilder, it will be
        automatically converted to a dictionary before merging with base ports.

        Args
        ----
        graph: The DirectedGraph to execute
        initial_input: Initial input data for the graph
        additional_ports: Either a dictionary of ports or a PortsBuilder instance
        validate: Whether to validate the graph before execution
        dynamic: Enable dynamic graph expansion (for agent macros).
            When True, supports:
            - Runtime node injection via get_current_graph()
            - Re-execution of nodes that return None
            - Iterative expansion until all nodes complete
        max_dynamic_iterations: Maximum number of expansion iterations (safety limit).
            Prevents infinite loops in dynamic execution.
        **kwargs: Additional keyword arguments

        Returns
        -------
        dict[str, Any]
            Dictionary mapping node names to their execution results

        Examples
        --------
        Using dictionary for additional ports (traditional approach):

        >>> results = await orchestrator.run(  # doctest: +SKIP
        ...     graph,
        ...     input_data,
        ...     additional_ports={"llm": MockLLM()}
        ... )

        Using PortsBuilder for additional ports (new approach):

        >>> results = await orchestrator.run(  # doctest: +SKIP
        ...     graph,
        ...     input_data,
        ...     additional_ports=PortsBuilder().with_llm(MockLLM())
        ... )
        """
        # Prepare additional ports (convert PortsBuilder if needed)
        additional_ports_dict: dict[str, Any] | None = None
        if additional_ports:
            if isinstance(additional_ports, PortsBuilder):
                additional_ports_dict = additional_ports.build()
            else:
                additional_ports_dict = additional_ports

        # Use managed_ports context manager for automatic lifecycle management
        async with _managed_ports(self.ports, additional_ports_dict, self.executor) as all_ports:
            return await self._execute_with_ports(
                graph, initial_input, all_ports, validate, dynamic, max_dynamic_iterations, **kwargs
            )

    async def _execute_with_ports(
        self,
        graph: DirectedGraph,
        initial_input: Any,
        all_ports: dict[str, Any],
        validate: bool,
        dynamic: bool,
        max_dynamic_iterations: int,
        **kwargs: Any,
    ) -> dict[str, Any]:
        """Execute DAG with managed ports (internal method).

        This method is separated to work with the managed_ports context manager.

        Args
        ----
        graph: DirectedGraph to execute
        initial_input: Initial input data
        all_ports: Merged ports dictionary
        validate: Whether to validate graph
        dynamic: Enable dynamic graph expansion
        max_dynamic_iterations: Maximum dynamic expansion iterations
        **kwargs: Additional arguments
        """
        if validate:
            # Validate DAG structure - catch specific DAG errors
            try:
                # By default, skip type checking for backward compatibility
                # Enable via graph.validate(check_type_compatibility=True)
                graph.validate(check_type_compatibility=False)
            except DirectedGraphError as e:
                # DAG-specific errors (cycles, missing nodes, etc.)
                raise OrchestratorError(f"Invalid DAG: {e}") from e
            except (ValueError, TypeError, KeyError) as e:
                # Other validation errors
                raise OrchestratorError(f"Invalid DAG: {e}") from e

            # Validate required ports for all nodes
            self._validate_required_ports(graph, all_ports)

        node_results: dict[str, Any] = {}
        waves = graph.waves()
        pipeline_start_time = time.time()

        observer_manager: ObserverManagerPort | None = all_ports.get("observer_manager")

        wrapped_ports = wrap_ports_with_observability(all_ports)

        pipeline_name = getattr(graph, "name", "unnamed")
        context = NodeExecutionContext(dag_id=pipeline_name)
        run_id = str(uuid.uuid4())

        async with ExecutionContext(
            observer_manager=observer_manager,
            run_id=run_id,
            ports=wrapped_ports,
        ):
            set_current_graph(graph)
            set_node_results(node_results)

            # PRE-DAG LIFECYCLE: Execute before pipeline starts
            pre_hook_results = await self._lifecycle_manager.pre_execute(
                context=context,
                pipeline_name=pipeline_name,
            )
            context.metadata["pre_dag_hooks"] = pre_hook_results

            # Fire pipeline started event
            event = PipelineStarted(
                name=pipeline_name,
                total_waves=len(waves),
                total_nodes=len(graph.nodes),
            )
            await self._notify_observer(observer_manager, event)

            timeout = None
            pipeline_status: PipelineStatus = PipelineStatus.SUCCESS
            pipeline_error: BaseException | None = None
            cancelled = False

            try:
                # Route to appropriate execution mode
                if dynamic:
                    # Dynamic execution with runtime node injection
                    cancelled = await self._execute_dynamic(
                        graph=graph,
                        node_results=node_results,
                        initial_input=initial_input,
                        context=context,
                        timeout=timeout,
                        validate=validate,
                        max_iterations=max_dynamic_iterations,
                        **kwargs,
                    )
                else:
                    # Static execution (traditional wave-based)
                    cancelled = await self._execute_with_executor(
                        waves=waves,
                        graph=graph,
                        node_results=node_results,
                        initial_input=initial_input,
                        context=context,
                        timeout=timeout,
                        validate=validate,
                        **kwargs,
                    )
            except BaseException as e:
                pipeline_error = e
                raise  # Re-raise immediately
            else:
                # Success path - determine status after execution
                if cancelled:
                    pipeline_status = PipelineStatus.CANCELLED
            finally:
                if pipeline_error is not None:
                    pipeline_status = PipelineStatus.FAILED
                # Fire appropriate completion/cancellation event
                duration_ms = (time.time() - pipeline_start_time) * 1000

                if cancelled:
                    pipeline_cancelled = PipelineCancelled(
                        name=pipeline_name,
                        duration_ms=duration_ms,
                        reason="timeout",
                        partial_results=node_results,
                    )
                    await self._notify_observer(observer_manager, pipeline_cancelled)
                elif pipeline_status == PipelineStatus.SUCCESS:
                    pipeline_completed = PipelineCompleted(
                        name=pipeline_name,
                        duration_ms=duration_ms,
                        node_results=node_results,
                    )
                    await self._notify_observer(observer_manager, pipeline_completed)

                # POST-DAG LIFECYCLE: Always execute for cleanup (even on failure)
                try:
                    post_hook_results = await self._lifecycle_manager.post_execute(
                        context=context,
                        pipeline_name=pipeline_name,
                        pipeline_status=pipeline_status.value,
                        node_results=node_results,
                        error=pipeline_error,
                    )
                    context.metadata["post_dag_hooks"] = post_hook_results
                except Exception as post_hook_error:
                    # Log all hook errors but don't fail the pipeline
                    # (hooks are for cleanup/observability, not critical path)
                    logger.error(
                        f"Post-DAG lifecycle failed: {post_hook_error}",
                        exc_info=True,
                    )

        return node_results

    async def _execute_with_executor(
        self,
        waves: list[list[str]],
        graph: DirectedGraph,
        node_results: dict[str, Any],
        initial_input: Any,
        context: NodeExecutionContext,
        timeout: float | None,
        validate: bool,
        **kwargs: Any,
    ) -> bool:
        """Execute all waves using the configured executor.

        This method delegates execution to the executor (LocalExecutor by default,
        or CeleryExecutor, AzureFunctionsExecutor, etc. if provided).

        Parameters
        ----------
        waves : list[list[str]]
            List of execution waves
        graph : DirectedGraph
            The DAG to execute
        node_results : dict[str, Any]
            Accumulated results from previous waves
        initial_input : Any
            Initial input to the pipeline
        context : NodeExecutionContext
            Execution context
        timeout : float | None
            Optional timeout for entire execution
        validate : bool
            Whether to perform validation
        **kwargs : Any
            Additional parameters

        Returns
        -------
        bool
            True if cancelled (timeout), False if completed
        """
        # Note: We extend the existing ports with executor-specific context.
        # This is safe because set_ports() wraps in MappingProxyType (immutable).
        existing_ports_result = get_ports()
        existing_ports: dict[str, Any] = (
            dict(existing_ports_result) if existing_ports_result else {}
        )
        executor_ports = {
            **existing_ports,
            EXECUTOR_CONTEXT_GRAPH: graph,
            EXECUTOR_CONTEXT_NODE_RESULTS: node_results,
            EXECUTOR_CONTEXT_INITIAL_INPUT: initial_input,
        }
        set_ports(executor_ports)

        try:
            async with asyncio.timeout(timeout):
                # Dynamic wave execution - re-compute waves if graph expands
                wave_idx = 0
                previous_wave_count = len(waves)

                while wave_idx < len(waves):
                    wave = waves[wave_idx]
                    wave_start_time = time.time()

                    # Fire wave started event
                    wave_event = WaveStarted(wave_index=wave_idx + 1, nodes=wave)
                    await self._notify_observer(get_observer_manager(), wave_event)

                    tasks = []
                    for node_name in wave:
                        task = ExecutionTask(
                            node_name=node_name,
                            node_input=None,  # Executor will prepare input from graph
                            wave_index=wave_idx + 1,
                            should_validate=validate,
                            context_data={
                                "dag_id": context.dag_id,
                                "run_id": context.metadata.get("run_id"),
                                "attempt": context.attempt,
                            },
                            params=kwargs,
                        )
                        tasks.append(task)

                    # Execute wave using executor
                    wave_results = await self.executor.aexecute_wave(tasks)

                    # Note: Failures propagate as NodeExecutionError from executor,
                    # so we only see SUCCESS results here
                    for node_name, result in wave_results.items():
                        node_results[node_name] = result.output

                    set_node_results(node_results)

                    # Fire wave completed event
                    wave_completed = WaveCompleted(
                        wave_index=wave_idx + 1,
                        duration_ms=(time.time() - wave_start_time) * 1000,
                    )
                    await self._notify_observer(get_observer_manager(), wave_completed)

                    new_waves = graph.waves()
                    if len(new_waves) != previous_wave_count:
                        # Graph expanded! Re-compute waves
                        logger.info(
                            "Dynamic expansion detected: {old} → {new} waves",
                            old=previous_wave_count,
                            new=len(new_waves),
                        )
                        waves = new_waves
                        previous_wave_count = len(new_waves)
                        # Don't increment wave_idx - continue from where we are

                    # Move to next wave
                    wave_idx += 1

            return False  # Not cancelled

        except TimeoutError:
            return True  # Cancelled due to timeout

    async def _execute_dynamic(
        self,
        graph: DirectedGraph,
        node_results: dict[str, Any],
        initial_input: Any,
        context: NodeExecutionContext,
        timeout: float | None,
        validate: bool,
        max_iterations: int,
        **kwargs: Any,
    ) -> bool:
        """Execute graph with dynamic node injection support.

        This method supports:
        1. Runtime node injection via get_current_graph()
        2. Re-execution of nodes that return None
        3. Iterative expansion until completion

        Parameters
        ----------
        graph : DirectedGraph
            The graph being executed (may be modified at runtime)
        node_results : dict[str, Any]
            Dictionary to store node execution results
        initial_input : Any
            Initial input data for the pipeline
        context : NodeExecutionContext
            Execution context for the pipeline
        timeout : float | None
            Optional timeout for the entire execution
        validate : bool
            Whether to validate nodes
        max_iterations : int
            Maximum number of dynamic expansion iterations
        **kwargs : Any
            Additional arguments

        Returns
        -------
        bool
            True if execution was cancelled, False otherwise

        Raises
        ------
        OrchestratorError
            If max_iterations is exceeded (infinite loop protection)
        """
        executed_nodes: set[str] = set()
        iteration = 0
        start_time = time.time()

        logger.info(f"Starting dynamic execution (max_iterations={max_iterations})")

        while iteration < max_iterations:
            iteration += 1

            # Check timeout
            if timeout and (time.time() - start_time) > timeout:
                logger.warning("Dynamic execution timeout reached")
                return True  # Cancelled

            # Get current graph state (may have new nodes injected)
            current_node_names = set(graph.nodes.keys())

            # Find nodes ready to execute
            ready_nodes = self._get_ready_nodes(
                graph=graph,
                all_node_names=current_node_names,
                executed_nodes=executed_nodes,
                node_results=node_results,
            )

            if not ready_nodes:
                # No more nodes to execute - we're done
                logger.info(
                    f"Dynamic execution completed after {iteration} iterations "
                    f"({len(executed_nodes)} nodes executed)"
                )
                break

            logger.debug(
                f"Dynamic iteration {iteration}: executing {len(ready_nodes)} nodes: {ready_nodes}"
            )

            # Execute wave of ready nodes
            wave_cancelled = await self._execute_with_executor(
                waves=[ready_nodes],  # Single wave
                graph=graph,
                node_results=node_results,
                initial_input=initial_input,
                context=context,
                timeout=timeout,
                validate=validate,
                **kwargs,
            )

            if wave_cancelled:
                return True  # Propagate cancellation

            # Mark nodes as executed
            # BUT: don't mark nodes that returned None (they need re-execution)
            for node_name in ready_nodes:
                result = node_results.get(node_name)
                if result is not None:
                    executed_nodes.add(node_name)
                    logger.debug(f"Node {node_name} completed successfully")
                else:
                    logger.debug(
                        f"Node {node_name} returned None, will re-execute after dependencies"
                    )

            # Update context for expander nodes
            set_node_results(node_results)

            # Check if new nodes were added to graph
            new_nodes = current_node_names.symmetric_difference(set(graph.nodes.keys()))
            if new_nodes:
                logger.info(f"Detected {len(new_nodes)} newly injected nodes: {new_nodes}")

        # Check if we exceeded max iterations
        if iteration >= max_iterations:
            unexecuted = set(graph.nodes.keys()) - executed_nodes
            raise OrchestratorError(
                f"Dynamic execution exceeded max_iterations={max_iterations}. "
                f"Possible infinite loop. "
                f"Executed {len(executed_nodes)} nodes, "
                f"{len(unexecuted)} nodes remain: {unexecuted}"
            )

        return False  # Not cancelled

    def _get_ready_nodes(
        self,
        graph: DirectedGraph,
        all_node_names: set[str],
        executed_nodes: set[str],
        node_results: dict[str, Any],
    ) -> list[str]:
        """Get nodes that are ready to execute in dynamic mode.

        A node is ready if:
        1. It hasn't been executed successfully yet (not in executed_nodes)
        2. All its dependencies have completed (results available)

        Parameters
        ----------
        graph : DirectedGraph
            The current graph
        all_node_names : set[str]
            All node names in the current graph
        executed_nodes : set[str]
            Names of nodes that have been executed successfully
        node_results : dict[str, Any]
            Current execution results

        Returns
        -------
        list[str]
            List of node names ready to execute
        """
        ready = []

        for node_name in all_node_names:
            # Skip if already executed successfully
            if node_name in executed_nodes:
                continue

            node_spec = graph.nodes[node_name]

            # Check if all dependencies are satisfied
            deps_satisfied = all(dep in node_results for dep in node_spec.deps)

            if deps_satisfied:
                ready.append(node_name)

        return ready

    def _resolve_ports_for_node(self, node_name: str, node_spec: Any) -> dict[str, Any]:
        """Resolve ports for a specific node.

        Uses PortsConfiguration if available, otherwise returns global ports.
        Resolution order: per-node > per-type > global

        Args
        ----
        node_name: Name of the node
        node_spec: NodeSpec containing node metadata

        Returns
        -------
        dict[str, Any]: Resolved ports for this node
        """
        if self.ports_config is None:
            # No PortsConfiguration, use global ports
            return self.ports

        if node_type := getattr(node_spec, "subtype", None):
            node_type = node_type.value if hasattr(node_type, "value") else str(node_type)

        # Resolve ports using PortsConfiguration
        resolved_ports = self.ports_config.resolve_ports(node_name, node_type)

        return {k: v.port for k, v in resolved_ports.items()}