hexdag 0.5.0.dev1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hexdag/__init__.py +116 -0
- hexdag/__main__.py +30 -0
- hexdag/adapters/executors/__init__.py +5 -0
- hexdag/adapters/executors/local_executor.py +316 -0
- hexdag/builtin/__init__.py +6 -0
- hexdag/builtin/adapters/__init__.py +51 -0
- hexdag/builtin/adapters/anthropic/__init__.py +5 -0
- hexdag/builtin/adapters/anthropic/anthropic_adapter.py +151 -0
- hexdag/builtin/adapters/database/__init__.py +6 -0
- hexdag/builtin/adapters/database/csv/csv_adapter.py +249 -0
- hexdag/builtin/adapters/database/pgvector/__init__.py +5 -0
- hexdag/builtin/adapters/database/pgvector/pgvector_adapter.py +478 -0
- hexdag/builtin/adapters/database/sqlalchemy/sqlalchemy_adapter.py +252 -0
- hexdag/builtin/adapters/database/sqlite/__init__.py +5 -0
- hexdag/builtin/adapters/database/sqlite/sqlite_adapter.py +410 -0
- hexdag/builtin/adapters/local/README.md +59 -0
- hexdag/builtin/adapters/local/__init__.py +7 -0
- hexdag/builtin/adapters/local/local_observer_manager.py +696 -0
- hexdag/builtin/adapters/memory/__init__.py +47 -0
- hexdag/builtin/adapters/memory/file_memory_adapter.py +297 -0
- hexdag/builtin/adapters/memory/in_memory_memory.py +216 -0
- hexdag/builtin/adapters/memory/schemas.py +57 -0
- hexdag/builtin/adapters/memory/session_memory.py +178 -0
- hexdag/builtin/adapters/memory/sqlite_memory_adapter.py +215 -0
- hexdag/builtin/adapters/memory/state_memory.py +280 -0
- hexdag/builtin/adapters/mock/README.md +89 -0
- hexdag/builtin/adapters/mock/__init__.py +15 -0
- hexdag/builtin/adapters/mock/hexdag.toml +50 -0
- hexdag/builtin/adapters/mock/mock_database.py +225 -0
- hexdag/builtin/adapters/mock/mock_embedding.py +223 -0
- hexdag/builtin/adapters/mock/mock_llm.py +177 -0
- hexdag/builtin/adapters/mock/mock_tool_adapter.py +192 -0
- hexdag/builtin/adapters/mock/mock_tool_router.py +232 -0
- hexdag/builtin/adapters/openai/__init__.py +5 -0
- hexdag/builtin/adapters/openai/openai_adapter.py +634 -0
- hexdag/builtin/adapters/secret/__init__.py +7 -0
- hexdag/builtin/adapters/secret/local_secret_adapter.py +248 -0
- hexdag/builtin/adapters/unified_tool_router.py +280 -0
- hexdag/builtin/macros/__init__.py +17 -0
- hexdag/builtin/macros/conversation_agent.py +390 -0
- hexdag/builtin/macros/llm_macro.py +151 -0
- hexdag/builtin/macros/reasoning_agent.py +423 -0
- hexdag/builtin/macros/tool_macro.py +380 -0
- hexdag/builtin/nodes/__init__.py +38 -0
- hexdag/builtin/nodes/_discovery.py +123 -0
- hexdag/builtin/nodes/agent_node.py +696 -0
- hexdag/builtin/nodes/base_node_factory.py +242 -0
- hexdag/builtin/nodes/composite_node.py +926 -0
- hexdag/builtin/nodes/data_node.py +201 -0
- hexdag/builtin/nodes/expression_node.py +487 -0
- hexdag/builtin/nodes/function_node.py +454 -0
- hexdag/builtin/nodes/llm_node.py +491 -0
- hexdag/builtin/nodes/loop_node.py +920 -0
- hexdag/builtin/nodes/mapped_input.py +518 -0
- hexdag/builtin/nodes/port_call_node.py +269 -0
- hexdag/builtin/nodes/tool_call_node.py +195 -0
- hexdag/builtin/nodes/tool_utils.py +390 -0
- hexdag/builtin/prompts/__init__.py +68 -0
- hexdag/builtin/prompts/base.py +422 -0
- hexdag/builtin/prompts/chat_prompts.py +303 -0
- hexdag/builtin/prompts/error_correction_prompts.py +320 -0
- hexdag/builtin/prompts/tool_prompts.py +160 -0
- hexdag/builtin/tools/builtin_tools.py +84 -0
- hexdag/builtin/tools/database_tools.py +164 -0
- hexdag/cli/__init__.py +17 -0
- hexdag/cli/__main__.py +7 -0
- hexdag/cli/commands/__init__.py +27 -0
- hexdag/cli/commands/build_cmd.py +812 -0
- hexdag/cli/commands/create_cmd.py +208 -0
- hexdag/cli/commands/docs_cmd.py +293 -0
- hexdag/cli/commands/generate_types_cmd.py +252 -0
- hexdag/cli/commands/init_cmd.py +188 -0
- hexdag/cli/commands/pipeline_cmd.py +494 -0
- hexdag/cli/commands/plugin_dev_cmd.py +529 -0
- hexdag/cli/commands/plugins_cmd.py +441 -0
- hexdag/cli/commands/studio_cmd.py +101 -0
- hexdag/cli/commands/validate_cmd.py +221 -0
- hexdag/cli/main.py +84 -0
- hexdag/core/__init__.py +83 -0
- hexdag/core/config/__init__.py +20 -0
- hexdag/core/config/loader.py +479 -0
- hexdag/core/config/models.py +150 -0
- hexdag/core/configurable.py +294 -0
- hexdag/core/context/__init__.py +37 -0
- hexdag/core/context/execution_context.py +378 -0
- hexdag/core/docs/__init__.py +26 -0
- hexdag/core/docs/extractors.py +678 -0
- hexdag/core/docs/generators.py +890 -0
- hexdag/core/docs/models.py +120 -0
- hexdag/core/domain/__init__.py +10 -0
- hexdag/core/domain/dag.py +1225 -0
- hexdag/core/exceptions.py +234 -0
- hexdag/core/expression_parser.py +569 -0
- hexdag/core/logging.py +449 -0
- hexdag/core/models/__init__.py +17 -0
- hexdag/core/models/base.py +138 -0
- hexdag/core/orchestration/__init__.py +46 -0
- hexdag/core/orchestration/body_executor.py +481 -0
- hexdag/core/orchestration/components/__init__.py +97 -0
- hexdag/core/orchestration/components/adapter_lifecycle_manager.py +113 -0
- hexdag/core/orchestration/components/checkpoint_manager.py +134 -0
- hexdag/core/orchestration/components/execution_coordinator.py +360 -0
- hexdag/core/orchestration/components/health_check_manager.py +176 -0
- hexdag/core/orchestration/components/input_mapper.py +143 -0
- hexdag/core/orchestration/components/lifecycle_manager.py +583 -0
- hexdag/core/orchestration/components/node_executor.py +377 -0
- hexdag/core/orchestration/components/secret_manager.py +202 -0
- hexdag/core/orchestration/components/wave_executor.py +158 -0
- hexdag/core/orchestration/constants.py +17 -0
- hexdag/core/orchestration/events/README.md +312 -0
- hexdag/core/orchestration/events/__init__.py +104 -0
- hexdag/core/orchestration/events/batching.py +330 -0
- hexdag/core/orchestration/events/decorators.py +139 -0
- hexdag/core/orchestration/events/events.py +573 -0
- hexdag/core/orchestration/events/observers/__init__.py +30 -0
- hexdag/core/orchestration/events/observers/core_observers.py +690 -0
- hexdag/core/orchestration/events/observers/models.py +111 -0
- hexdag/core/orchestration/events/taxonomy.py +269 -0
- hexdag/core/orchestration/hook_context.py +237 -0
- hexdag/core/orchestration/hooks.py +437 -0
- hexdag/core/orchestration/models.py +418 -0
- hexdag/core/orchestration/orchestrator.py +910 -0
- hexdag/core/orchestration/orchestrator_factory.py +275 -0
- hexdag/core/orchestration/port_wrappers.py +327 -0
- hexdag/core/orchestration/prompt/__init__.py +32 -0
- hexdag/core/orchestration/prompt/template.py +332 -0
- hexdag/core/pipeline_builder/__init__.py +21 -0
- hexdag/core/pipeline_builder/component_instantiator.py +386 -0
- hexdag/core/pipeline_builder/include_tag.py +265 -0
- hexdag/core/pipeline_builder/pipeline_config.py +133 -0
- hexdag/core/pipeline_builder/py_tag.py +223 -0
- hexdag/core/pipeline_builder/tag_discovery.py +268 -0
- hexdag/core/pipeline_builder/yaml_builder.py +1196 -0
- hexdag/core/pipeline_builder/yaml_validator.py +569 -0
- hexdag/core/ports/__init__.py +65 -0
- hexdag/core/ports/api_call.py +133 -0
- hexdag/core/ports/database.py +489 -0
- hexdag/core/ports/embedding.py +215 -0
- hexdag/core/ports/executor.py +237 -0
- hexdag/core/ports/file_storage.py +117 -0
- hexdag/core/ports/healthcheck.py +87 -0
- hexdag/core/ports/llm.py +551 -0
- hexdag/core/ports/memory.py +70 -0
- hexdag/core/ports/observer_manager.py +130 -0
- hexdag/core/ports/secret.py +145 -0
- hexdag/core/ports/tool_router.py +94 -0
- hexdag/core/ports_builder.py +623 -0
- hexdag/core/protocols.py +273 -0
- hexdag/core/resolver.py +304 -0
- hexdag/core/schema/__init__.py +9 -0
- hexdag/core/schema/generator.py +742 -0
- hexdag/core/secrets.py +242 -0
- hexdag/core/types.py +413 -0
- hexdag/core/utils/async_warnings.py +206 -0
- hexdag/core/utils/schema_conversion.py +78 -0
- hexdag/core/utils/sql_validation.py +86 -0
- hexdag/core/validation/secure_json.py +148 -0
- hexdag/core/yaml_macro.py +517 -0
- hexdag/mcp_server.py +3120 -0
- hexdag/studio/__init__.py +10 -0
- hexdag/studio/build_ui.py +92 -0
- hexdag/studio/server/__init__.py +1 -0
- hexdag/studio/server/main.py +100 -0
- hexdag/studio/server/routes/__init__.py +9 -0
- hexdag/studio/server/routes/execute.py +208 -0
- hexdag/studio/server/routes/export.py +558 -0
- hexdag/studio/server/routes/files.py +207 -0
- hexdag/studio/server/routes/plugins.py +419 -0
- hexdag/studio/server/routes/validate.py +220 -0
- hexdag/studio/ui/index.html +13 -0
- hexdag/studio/ui/package-lock.json +2992 -0
- hexdag/studio/ui/package.json +31 -0
- hexdag/studio/ui/postcss.config.js +6 -0
- hexdag/studio/ui/public/hexdag.svg +5 -0
- hexdag/studio/ui/src/App.tsx +251 -0
- hexdag/studio/ui/src/components/Canvas.tsx +408 -0
- hexdag/studio/ui/src/components/ContextMenu.tsx +187 -0
- hexdag/studio/ui/src/components/FileBrowser.tsx +123 -0
- hexdag/studio/ui/src/components/Header.tsx +181 -0
- hexdag/studio/ui/src/components/HexdagNode.tsx +193 -0
- hexdag/studio/ui/src/components/NodeInspector.tsx +512 -0
- hexdag/studio/ui/src/components/NodePalette.tsx +262 -0
- hexdag/studio/ui/src/components/NodePortsSection.tsx +403 -0
- hexdag/studio/ui/src/components/PluginManager.tsx +347 -0
- hexdag/studio/ui/src/components/PortsEditor.tsx +481 -0
- hexdag/studio/ui/src/components/PythonEditor.tsx +195 -0
- hexdag/studio/ui/src/components/ValidationPanel.tsx +105 -0
- hexdag/studio/ui/src/components/YamlEditor.tsx +196 -0
- hexdag/studio/ui/src/components/index.ts +8 -0
- hexdag/studio/ui/src/index.css +92 -0
- hexdag/studio/ui/src/main.tsx +10 -0
- hexdag/studio/ui/src/types/index.ts +123 -0
- hexdag/studio/ui/src/vite-env.d.ts +1 -0
- hexdag/studio/ui/tailwind.config.js +29 -0
- hexdag/studio/ui/tsconfig.json +37 -0
- hexdag/studio/ui/tsconfig.node.json +13 -0
- hexdag/studio/ui/vite.config.ts +35 -0
- hexdag/visualization/__init__.py +69 -0
- hexdag/visualization/dag_visualizer.py +1020 -0
- hexdag-0.5.0.dev1.dist-info/METADATA +369 -0
- hexdag-0.5.0.dev1.dist-info/RECORD +261 -0
- hexdag-0.5.0.dev1.dist-info/WHEEL +4 -0
- hexdag-0.5.0.dev1.dist-info/entry_points.txt +4 -0
- hexdag-0.5.0.dev1.dist-info/licenses/LICENSE +190 -0
- hexdag_plugins/.gitignore +43 -0
- hexdag_plugins/README.md +73 -0
- hexdag_plugins/__init__.py +1 -0
- hexdag_plugins/azure/LICENSE +21 -0
- hexdag_plugins/azure/README.md +414 -0
- hexdag_plugins/azure/__init__.py +21 -0
- hexdag_plugins/azure/azure_blob_adapter.py +450 -0
- hexdag_plugins/azure/azure_cosmos_adapter.py +383 -0
- hexdag_plugins/azure/azure_keyvault_adapter.py +314 -0
- hexdag_plugins/azure/azure_openai_adapter.py +415 -0
- hexdag_plugins/azure/pyproject.toml +107 -0
- hexdag_plugins/azure/tests/__init__.py +1 -0
- hexdag_plugins/azure/tests/test_azure_blob_adapter.py +350 -0
- hexdag_plugins/azure/tests/test_azure_cosmos_adapter.py +323 -0
- hexdag_plugins/azure/tests/test_azure_keyvault_adapter.py +330 -0
- hexdag_plugins/azure/tests/test_azure_openai_adapter.py +329 -0
- hexdag_plugins/hexdag_etl/README.md +168 -0
- hexdag_plugins/hexdag_etl/__init__.py +53 -0
- hexdag_plugins/hexdag_etl/examples/01_simple_pandas_transform.py +270 -0
- hexdag_plugins/hexdag_etl/examples/02_simple_pandas_only.py +149 -0
- hexdag_plugins/hexdag_etl/examples/03_file_io_pipeline.py +109 -0
- hexdag_plugins/hexdag_etl/examples/test_pandas_transform.py +84 -0
- hexdag_plugins/hexdag_etl/hexdag.toml +25 -0
- hexdag_plugins/hexdag_etl/hexdag_etl/__init__.py +48 -0
- hexdag_plugins/hexdag_etl/hexdag_etl/nodes/__init__.py +13 -0
- hexdag_plugins/hexdag_etl/hexdag_etl/nodes/api_extract.py +230 -0
- hexdag_plugins/hexdag_etl/hexdag_etl/nodes/base_node_factory.py +181 -0
- hexdag_plugins/hexdag_etl/hexdag_etl/nodes/file_io.py +415 -0
- hexdag_plugins/hexdag_etl/hexdag_etl/nodes/outlook.py +492 -0
- hexdag_plugins/hexdag_etl/hexdag_etl/nodes/pandas_transform.py +563 -0
- hexdag_plugins/hexdag_etl/hexdag_etl/nodes/sql_extract_load.py +112 -0
- hexdag_plugins/hexdag_etl/pyproject.toml +82 -0
- hexdag_plugins/hexdag_etl/test_transform.py +54 -0
- hexdag_plugins/hexdag_etl/tests/test_plugin_integration.py +62 -0
- hexdag_plugins/mysql_adapter/LICENSE +21 -0
- hexdag_plugins/mysql_adapter/README.md +224 -0
- hexdag_plugins/mysql_adapter/__init__.py +6 -0
- hexdag_plugins/mysql_adapter/mysql_adapter.py +408 -0
- hexdag_plugins/mysql_adapter/pyproject.toml +93 -0
- hexdag_plugins/mysql_adapter/tests/test_mysql_adapter.py +259 -0
- hexdag_plugins/storage/README.md +184 -0
- hexdag_plugins/storage/__init__.py +19 -0
- hexdag_plugins/storage/file/__init__.py +5 -0
- hexdag_plugins/storage/file/local.py +325 -0
- hexdag_plugins/storage/ports/__init__.py +5 -0
- hexdag_plugins/storage/ports/vector_store.py +236 -0
- hexdag_plugins/storage/sql/__init__.py +7 -0
- hexdag_plugins/storage/sql/base.py +187 -0
- hexdag_plugins/storage/sql/mysql.py +27 -0
- hexdag_plugins/storage/sql/postgresql.py +27 -0
- hexdag_plugins/storage/tests/__init__.py +1 -0
- hexdag_plugins/storage/tests/test_local_file_storage.py +161 -0
- hexdag_plugins/storage/tests/test_sql_adapters.py +212 -0
- hexdag_plugins/storage/vector/__init__.py +7 -0
- hexdag_plugins/storage/vector/chromadb.py +223 -0
- hexdag_plugins/storage/vector/in_memory.py +285 -0
- hexdag_plugins/storage/vector/pgvector.py +502 -0
|
@@ -0,0 +1,1225 @@
|
|
|
1
|
+
"""DAG primitives: NodeSpec and DirectedGraph.
|
|
2
|
+
|
|
3
|
+
This module provides the core building blocks for defining and executing
|
|
4
|
+
directed acyclic graphs of agents in the Hex-DAG framework.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import sys
|
|
8
|
+
from collections import defaultdict
|
|
9
|
+
from collections.abc import Callable, Mapping
|
|
10
|
+
from dataclasses import dataclass, field, replace
|
|
11
|
+
from enum import Enum, auto
|
|
12
|
+
from types import MappingProxyType
|
|
13
|
+
from typing import TYPE_CHECKING, Any, TypeVar
|
|
14
|
+
|
|
15
|
+
if TYPE_CHECKING:
|
|
16
|
+
from collections.abc import ItemsView, Iterator, KeysView, ValuesView # noqa: F401
|
|
17
|
+
|
|
18
|
+
from pydantic import BaseModel
|
|
19
|
+
from pydantic import ValidationError as PydanticValidationError
|
|
20
|
+
|
|
21
|
+
# Generic placeholder for "some Pydantic model class"; used to type the
# model argument and return value of NodeSpec._validate_with_model.
T = TypeVar("T", bound=BaseModel)
|
|
22
|
+
|
|
23
|
+
# Shared immutable empty set of node names.
# NOTE(review): no use is visible in this part of the file - presumably a
# reusable default for "no edges" lookups; confirm against the rest of the module.
_EMPTY_SET: frozenset[str] = frozenset()
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class ValidationError(Exception):
    """Raised when a DAG node's input or output fails validation.

    This is the DAG-domain error and is deliberately distinct from
    ``hexdag.core.exceptions.ValidationError``, which covers general field
    validation.
    """

    # No extra per-instance attributes are declared by this subclass.
    __slots__ = ()
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class Color(Enum):
    """Node marker states for the DFS-based cycle detection algorithm."""

    # Unvisited.
    WHITE = auto()
    # Currently being processed (still on the recursion stack).
    GRAY = auto()
    # Completely processed.
    BLACK = auto()
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class ValidationCacheState(Enum):
    """Whether the graph's cached validation result can be trusted."""

    # Cache was invalidated, or validation has never run.
    INVALID = auto()
    # Structural validation passed.
    VALID = auto()
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
@dataclass(frozen=True, slots=True)
class NodeSpec:
    """Immutable representation of a node in a DAG.

    A NodeSpec defines:
    - A unique name within the DAG
    - The function to execute (agent)
    - Input/output types for validation (Pydantic models or legacy types)
    - Dependencies (explicit and computed)
    - Arbitrary metadata parameters
    - Optional conditional execution via `when` expression
    - Retry configuration with exponential backoff

    Supports fluent chaining via .after() method.

    Retry Configuration
    -------------------
    Nodes can be configured to automatically retry on failure with exponential
    backoff. This is useful for handling transient errors like rate limits,
    network timeouts, or temporary service unavailability. This class only
    stores the settings; the retry loop itself is implemented elsewhere.

    - max_retries: Number of attempts (1 = no retries, 3 = two retries)
    - retry_delay: Initial delay in seconds before first retry (default: 1.0)
    - retry_backoff: Multiplier for exponential backoff (default: 2.0)
    - retry_max_delay: Maximum delay cap in seconds (default: 60.0)

    Example YAML configuration::

        - kind: llm_node
          metadata:
            name: api_caller
          spec:
            max_retries: 3
            retry_delay: 1.0
            retry_backoff: 2.0
            retry_max_delay: 30.0
            # ... other params

    With these settings the delay sequence is 1s, 2s, 4s, ... (each delay
    capped at 30s). Assuming ``max_retries`` counts total attempts as stated
    above, ``max_retries: 3`` uses only the first two delays (1s, 2s).
    """

    name: str
    fn: Callable[..., Any]
    in_model: type[BaseModel] | None = None  # Pydantic model for input validation
    out_model: type[BaseModel] | None = None  # Pydantic model for output validation
    deps: frozenset[str] = field(default_factory=frozenset)
    params: dict[str, Any] = field(default_factory=dict)
    timeout: float | None = None  # Optional timeout in seconds for this node
    max_retries: int | None = None  # Optional max retries for this node (1 = no retries)
    retry_delay: float | None = None  # Initial delay in seconds before first retry
    retry_backoff: float | None = None  # Multiplier for exponential backoff (default: 2.0)
    retry_max_delay: float | None = None  # Maximum delay cap in seconds (default: 60.0)
    when: str | None = None  # Optional expression to evaluate before execution

    def __post_init__(self) -> None:
        """Freeze ``deps`` and ``params`` and intern the name strings.

        The dataclass is frozen, so the normalised values are written with
        ``object.__setattr__``.
        """
        # Intern node name for memory efficiency and faster comparisons
        object.__setattr__(self, "name", sys.intern(self.name))
        # Intern dependency names as well
        object.__setattr__(self, "deps", frozenset(sys.intern(d) for d in self.deps))
        # MappingProxyType is only a live read-only *view*: without the copy the
        # caller could still change this node's params through the dict they
        # passed in. Copying also stops dataclasses.replace() (used by after()
        # and >>) from stacking a proxy around a proxy on every derived node.
        object.__setattr__(self, "params", MappingProxyType(dict(self.params)))

    def _validate_with_model(
        self, data: Any, model: type[T] | None, validation_type: str
    ) -> T | Any:
        """Validate data using the provided Pydantic model.

        Parameters
        ----------
        data : Any
            Data to validate
        model : Type[T] | None
            Pydantic model to validate against
        validation_type : str
            Type of validation for error messages ('input' or 'output')

        Returns
        -------
        T | Any
            Validated model instance or original data if no model

        Raises
        ------
        ValidationError
            If validation fails, or if any other exception is raised while
            converting/validating the data
        """
        if model is None:
            return data

        # Fast path: if already the correct type, return as-is
        if isinstance(data, model):
            return data

        try:
            # A different Pydantic model is dumped to a dict first so that its
            # fields can be re-validated against the target schema; dicts,
            # primitives and other types are validated directly.
            payload = data.model_dump() if isinstance(data, BaseModel) else data
            return model.model_validate(payload)
        except PydanticValidationError as e:
            raise ValidationError(
                f"{validation_type.capitalize()} validation failed for node '{self.name}': {e}"
            ) from e
        except Exception as e:
            # Anything else (e.g. a failing model_dump or custom validator) is
            # surfaced as the same domain error, with a distinct message.
            raise ValidationError(
                f"{validation_type.capitalize()} validation error for node '{self.name}': {e}"
            ) from e

    def validate_input(self, data: Any) -> BaseModel | Any:
        """Validate and convert input data using Pydantic model if available.

        Parameters
        ----------
        data : Any
            Input data to validate

        Returns
        -------
        BaseModel | Any
            Validated/converted data (returned unchanged when ``in_model`` is None)

        Raises
        ------
        ValidationError
            If the data does not satisfy ``in_model``
        """
        return self._validate_with_model(data, self.in_model, "input")

    def validate_output(self, data: Any) -> BaseModel | Any:
        """Validate and convert output data using Pydantic model if available.

        Parameters
        ----------
        data : Any
            Output data to validate

        Returns
        -------
        BaseModel | Any
            Validated/converted data (returned unchanged when ``out_model`` is None)

        Raises
        ------
        ValidationError
            If the data does not satisfy ``out_model``
        """
        return self._validate_with_model(data, self.out_model, "output")

    def after(self, *node_names: str) -> "NodeSpec":
        """Create a new NodeSpec that depends on the specified nodes.

        Args
        ----
            *node_names: Names of nodes this node should run after

        Returns
        -------
            New NodeSpec with updated dependencies (existing deps are kept)

        Examples
        --------
            node_b = NodeSpec("b", my_fn).after("a")
            node_c = NodeSpec("c", my_fn).after("a", "b")
        """
        new_deps = self.deps | frozenset(node_names)
        return replace(self, deps=new_deps)

    def __rshift__(self, other: "NodeSpec") -> "NodeSpec":
        """Create dependency using >> operator: node_a >> node_b means "b depends on a".

        This operator provides a visual, left-to-right data flow representation.
        The node on the right depends on the node on the left.

        Parameters
        ----------
        other : NodeSpec
            The downstream node that will depend on this node

        Returns
        -------
        NodeSpec
            A new NodeSpec with the dependency added to 'other'

        Examples
        --------
        >>> node_a = NodeSpec("a", lambda: "data")
        >>> node_b = NodeSpec("b", lambda x: x.upper())
        >>> node_b_with_dep = node_a >> node_b  # b depends on a
        >>> "a" in node_b_with_dep.deps
        True

        Chain multiple dependencies:
        >>> graph = DirectedGraph()
        >>> dummy = lambda: None
        >>> a = NodeSpec("a", dummy)
        >>> b = NodeSpec("b", dummy)
        >>> c = NodeSpec("c", dummy)
        >>> graph += a
        >>> b_with_dep = a >> b  # b depends on a
        >>> graph += b_with_dep
        >>> c_with_dep = b >> c  # c depends on b
        >>> graph += c_with_dep
        >>> len(graph)
        3
        >>> "a" in graph.nodes["b"].deps
        True
        >>> "b" in graph.nodes["c"].deps
        True

        Notes
        -----
        The >> operator reads naturally as "flows into" or "feeds into".
        For multiple dependencies, use .after() method instead.
        """
        return replace(other, deps=other.deps | frozenset([self.name]))

    def __repr__(self) -> str:
        """Readable representation for debugging.

        Returns
        -------
            String representation of the NodeSpec; model names, deps and
            params are only included when they are set/non-empty.
        """
        deps_str = f", deps={sorted(self.deps)}" if self.deps else ""
        types_str = ""

        # Show Pydantic models if available
        if self.in_model or self.out_model:
            in_name = self.in_model.__name__ if self.in_model else "Any"
            out_name = self.out_model.__name__ if self.out_model else "Any"
            types_str = f", {in_name} -> {out_name}"

        params_str = f", params={dict(self.params)}" if self.params else ""

        return f"NodeSpec('{self.name}'{types_str}{deps_str}{params_str})"
|
|
281
|
+
|
|
282
|
+
|
|
283
|
+
class DirectedGraphError(Exception):
    """Root of the exception hierarchy raised by DirectedGraph."""

    # No extra per-instance attributes are declared by this subclass.
    __slots__ = ()
|
|
287
|
+
|
|
288
|
+
|
|
289
|
+
class CycleDetectedError(DirectedGraphError):
    """Signals that the graph contains (or would contain) a cycle."""

    # No extra per-instance attributes are declared by this subclass.
    __slots__ = ()
|
|
293
|
+
|
|
294
|
+
|
|
295
|
+
class MissingDependencyError(DirectedGraphError):
    """Signals that a node lists a dependency that is not in the graph."""

    # No extra per-instance attributes are declared by this subclass.
    __slots__ = ()
|
|
299
|
+
|
|
300
|
+
|
|
301
|
+
class DuplicateNodeError(DirectedGraphError):
    """Signals an attempt to add a node whose name is already taken."""

    # No extra per-instance attributes are declared by this subclass.
    __slots__ = ()
|
|
305
|
+
|
|
306
|
+
|
|
307
|
+
class SchemaCompatibilityError(DirectedGraphError):
    """Signals that two connected nodes have incompatible schemas."""

    # No extra per-instance attributes are declared by this subclass.
    __slots__ = ()
|
|
311
|
+
|
|
312
|
+
|
|
313
|
+
class DirectedGraph:
|
|
314
|
+
"""A directed acyclic graph (DAG) for orchestrating NodeSpec instances.
|
|
315
|
+
|
|
316
|
+
Provides:
|
|
317
|
+
- Node management with cycle detection
|
|
318
|
+
- Dependency validation
|
|
319
|
+
- Topological sorting into execution waves
|
|
320
|
+
- Optional Pydantic model compatibility checking
|
|
321
|
+
"""
|
|
322
|
+
|
|
323
|
+
def __init__(
    self,
    nodes: list[NodeSpec] | None = None,
    strict_add: bool = False,
) -> None:
    """Create a graph, optionally pre-populated with nodes.

    Args
    ----
        nodes: Optional list of NodeSpec instances to add to the graph
        strict_add: If True, validate dependencies and cycles immediately on add().
            If False (default), allow adding nodes with missing dependencies
            and defer validation to validate() call. Set to True for dynamic
            graphs to catch errors early.
    """
    self._strict_add = strict_add
    self.nodes: dict[str, NodeSpec] = {}

    # Adjacency is kept in both directions:
    #   forward: node -> set of dependents
    #   reverse: node -> set of dependencies
    self._forward_edges: defaultdict[str, set[str]] = defaultdict(set)
    self._reverse_edges: defaultdict[str, set[str]] = defaultdict(set)

    # Derived state; both start out "not computed".
    self._validation_cache: ValidationCacheState = ValidationCacheState.INVALID
    self._waves_cache: list[list[str]] | None = None

    if not nodes:
        return
    self.add_many(*nodes)
|
|
352
|
+
|
|
353
|
+
@staticmethod
|
|
354
|
+
def detect_cycle(graph: Mapping[str, set[str] | frozenset[str]]) -> str | None:
|
|
355
|
+
"""Detect cycles in a dependency graph using DFS with three-state coloring.
|
|
356
|
+
|
|
357
|
+
This is a public static method that can be used to detect cycles in simple
|
|
358
|
+
dependency graphs before constructing a full DirectedGraph.
|
|
359
|
+
|
|
360
|
+
Parameters
|
|
361
|
+
----------
|
|
362
|
+
graph : Mapping[str, set[str] | frozenset[str]]
|
|
363
|
+
Dependency graph where keys are node names and values are sets of dependencies
|
|
364
|
+
|
|
365
|
+
Returns
|
|
366
|
+
-------
|
|
367
|
+
str | None
|
|
368
|
+
Cycle description if found, None otherwise
|
|
369
|
+
|
|
370
|
+
Examples
|
|
371
|
+
--------
|
|
372
|
+
>>> graph = {"a": {"b"}, "b": {"c"}, "c": {"a"}} # a->b->c->a
|
|
373
|
+
>>> DirectedGraph.detect_cycle(graph)
|
|
374
|
+
'Cycle detected: a -> b -> c -> a'
|
|
375
|
+
|
|
376
|
+
>>> graph = {"a": {"b"}, "b": {"c"}, "c": set()} # No cycle
|
|
377
|
+
>>> result = DirectedGraph.detect_cycle(graph)
|
|
378
|
+
>>> result is None
|
|
379
|
+
True
|
|
380
|
+
"""
|
|
381
|
+
colors = dict.fromkeys(graph, Color.WHITE)
|
|
382
|
+
|
|
383
|
+
def dfs(node: str, path: list[str]) -> str | None:
|
|
384
|
+
if colors[node] == Color.GRAY:
|
|
385
|
+
# Found a back edge - cycle detected
|
|
386
|
+
cycle_start = path.index(node)
|
|
387
|
+
cycle = path[cycle_start:] + [node]
|
|
388
|
+
return f"Cycle detected: {' -> '.join(cycle)}"
|
|
389
|
+
|
|
390
|
+
if colors[node] == Color.BLACK:
|
|
391
|
+
return None
|
|
392
|
+
|
|
393
|
+
colors[node] = Color.GRAY
|
|
394
|
+
path.append(node)
|
|
395
|
+
|
|
396
|
+
# Visit all dependencies
|
|
397
|
+
for dep in graph.get(node, set()):
|
|
398
|
+
if dep in colors and (result := dfs(dep, path)):
|
|
399
|
+
return result
|
|
400
|
+
|
|
401
|
+
path.pop()
|
|
402
|
+
colors[node] = Color.BLACK
|
|
403
|
+
return None
|
|
404
|
+
|
|
405
|
+
for node in graph:
|
|
406
|
+
if colors[node] == Color.WHITE and (result := dfs(node, [])):
|
|
407
|
+
return result
|
|
408
|
+
|
|
409
|
+
return None # No cycles found
|
|
410
|
+
|
|
411
|
+
def add(self, node_spec: NodeSpec) -> "DirectedGraph":
    """Add a NodeSpec to the graph.

    Args
    ----
        node_spec: NodeSpec instance to add to the graph

    Returns
    -------
        Self for method chaining

    Raises
    ------
    DuplicateNodeError
        If a node with the same name already exists.
    MissingDependencyError
        If strict_add=True and the node depends on non-existent nodes.
    CycleDetectedError
        If strict_add=True and adding the node would create a cycle.

    Notes
    -----
    When strict_add=False (default), nodes can be added with missing dependencies.
    Validation happens later when validate() is called.

    When strict_add=True, the dependency and cycle checks run immediately for
    the node being added, before the graph is modified.
    """
    name = node_spec.name
    if name in self.nodes:
        raise DuplicateNodeError(f"Node '{node_spec.name}' already exists in the graph")

    # Incremental validation only if strict_add=True (for dynamic graphs)
    if self._strict_add:
        # Sorted so the error message does not depend on set iteration order,
        # which varies between runs under string hash randomisation.
        missing_deps = sorted(dep for dep in node_spec.deps if dep not in self.nodes)
        if missing_deps:
            raise MissingDependencyError(
                f"Node '{node_spec.name}' depends on missing node(s): {missing_deps}"
            )

        # Incremental cycle detection: only check if new node creates cycle
        if self._would_create_cycle(node_spec):
            raise CycleDetectedError(
                f"Adding node '{node_spec.name}' would create a cycle in the graph"
            )

    self.nodes[name] = node_spec
    # Make sure the node has a (possibly empty) dependents entry, keeping any
    # dependents that were registered before the node itself was added.
    self._forward_edges.setdefault(name, set())
    self._reverse_edges[name] = set(node_spec.deps)

    # Register this node as a dependent of each of its dependencies.
    for dep in node_spec.deps:
        self._forward_edges[dep].add(name)

    # Invalidate caches when graph structure changes
    self._invalidate_caches()

    return self
|
|
467
|
+
|
|
468
|
+
def _would_create_cycle(self, new_node: NodeSpec) -> bool:
    """Report whether inserting ``new_node`` would close a dependency cycle.

    Starting from each dependency of ``new_node``, forward edges
    (dependency -> dependent) are followed. If a node named
    ``new_node.name`` is reachable that way, the new node would end up
    depending on itself transitively.

    The traversal uses an explicit stack rather than recursion, so very
    long dependency chains cannot hit the interpreter's recursion limit.

    Parameters
    ----------
    new_node : NodeSpec
        The node that is about to be added.

    Returns
    -------
    bool
        True if ``new_node.name`` is reachable from any of its dependencies.

    Examples
    --------
    >>> graph = DirectedGraph()
    >>> a = NodeSpec("a", lambda: None)
    >>> b = NodeSpec("b", lambda: None).after("a")
    >>> graph += [a, b]
    >>> graph._would_create_cycle(NodeSpec("c", lambda: None).after("b"))
    False
    """
    if not new_node.deps:
        return False  # A node without dependencies cannot be part of a cycle

    target = new_node.name
    visited: set[str] = set()
    pending = list(new_node.deps)

    while pending:
        current = pending.pop()
        if current == target:
            return True
        if current in visited:
            continue
        visited.add(current)
        # .get() avoids creating entries in the edge map for unknown names.
        pending.extend(self._forward_edges.get(current, ()))

    return False
|
|
530
|
+
|
|
531
|
+
def _invalidate_caches(self) -> None:
    """Drop every cached result derived from the current graph structure."""
    # Structural validation has to be redone on the next validate() call.
    self._validation_cache = ValidationCacheState.INVALID
    # Forces waves() to recompute the execution order.
    self._waves_cache = None
|
|
535
|
+
|
|
536
|
+
def add_many(self, *node_specs: NodeSpec) -> "DirectedGraph":
    """Add several nodes to the graph in one call.

    All names are checked up front, both against the nodes already in the
    graph and against each other, so a duplicate is reported before any
    node of the batch has been inserted.

    Parameters
    ----------
    *node_specs : NodeSpec
        The node specifications to add, in insertion order.

    Returns
    -------
    DirectedGraph
        Self, for method chaining.

    Raises
    ------
    DuplicateNodeError
        If a name already exists in the graph or occurs more than once
        in ``node_specs``.

    Examples
    --------
    graph.add_many(
        NodeSpec("fetch", fetch_fn),
        NodeSpec("process", process_fn).after("fetch"),
        NodeSpec("analyze", analyze_fn).after("process")
    )
    """
    # Pass 1: reject duplicates without touching the graph.
    batch_names: set[str] = set()
    for node_spec in node_specs:
        name = node_spec.name
        if name in self.nodes or name in batch_names:
            raise DuplicateNodeError(f"Node '{name}' already exists in the graph")
        batch_names.add(name)

    # Pass 2: delegate to add(), which applies its own per-node checks.
    for node_spec in node_specs:
        self.add(node_spec)
    return self
|
|
569
|
+
|
|
570
|
+
def get_dependencies(self, node_name: str) -> frozenset[str]:
    """Return the names of the nodes that ``node_name`` depends on.

    Parameters
    ----------
    node_name : str
        Name of the node to look up.

    Returns
    -------
    frozenset[str]
        The ``deps`` of the stored node specification.

    Raises
    ------
    KeyError
        If no node with that name is in the graph.
    """
    if node_name in self.nodes:
        spec = self.nodes[node_name]
        return spec.deps
    raise KeyError(f"Node '{node_name}' not found in graph")
|
|
589
|
+
|
|
590
|
+
def get_dependents(self, node_name: str) -> set[str]:
    """Return the names of the nodes that depend on ``node_name``.

    Parameters
    ----------
    node_name : str
        Name of the node to look up.

    Returns
    -------
    set[str]
        A fresh set built from the node's forward edges; mutating it
        does not affect the graph.

    Raises
    ------
    KeyError
        If no node with that name is in the graph.
    """
    if node_name in self.nodes:
        children = self._forward_edges.get(node_name, _EMPTY_SET)
        return set(children)
    raise KeyError(f"Node '{node_name}' not found in graph")
|
|
609
|
+
|
|
610
|
+
def validate(self, check_type_compatibility: bool = True) -> None:
    """Check that the graph is a well-formed DAG and, optionally, well-typed.

    Structural checks (missing dependencies, then cycles) run only while
    the validation cache is marked invalid; after they succeed the cache
    is marked valid and they are skipped on later calls. The type
    compatibility check is never cached and runs on every call that
    requests it.

    Parameters
    ----------
    check_type_compatibility : bool
        If True, also verify that connected nodes have compatible types.

    Raises
    ------
    MissingDependencyError
        If any node depends on a node that is not in the graph.
    CycleDetectedError
        If the graph contains a cycle.
    SchemaCompatibilityError
        If connected nodes have incompatible types.
    """
    structure_unchecked = self._validation_cache == ValidationCacheState.INVALID
    if structure_unchecked:
        # Gather every dangling dependency so one error reports them all.
        dangling: list[str] = [
            f"Node '{name}' depends on missing node '{dep}'"
            for name, spec in self.nodes.items()
            for dep in spec.deps
            if dep not in self.nodes
        ]
        if dangling:
            raise MissingDependencyError("; ".join(dangling))

        cycle_message = self._detect_cycles()
        if cycle_message:
            raise CycleDetectedError(cycle_message)

        self._validation_cache = ValidationCacheState.VALID

    if not check_type_compatibility:
        return

    incompatibilities = self._validate_type_compatibility()
    if incompatibilities:
        raise SchemaCompatibilityError("; ".join(incompatibilities))
|
|
655
|
+
|
|
656
|
+
def _detect_cycles(self) -> str | None:
    """Run cycle detection over the current node/dependency mapping.

    Returns
    -------
    str | None
        Whatever ``DirectedGraph.detect_cycle`` returns for the mapping of
        node name to dependency names: a cycle message, or None when no
        cycle is found.
    """
    dependency_map = {}
    for name, node_spec in self.nodes.items():
        dependency_map[name] = node_spec.deps
    return DirectedGraph.detect_cycle(dependency_map)
|
|
666
|
+
|
|
667
|
+
def _validate_type_compatibility(self) -> list[str]:
    """Collect type mismatches between nodes and their sole dependency.

    Only nodes that declare an input model and have exactly one
    dependency are inspected. Nodes with several dependencies are
    skipped. A mismatch is reported when the dependency declares an
    output model that is not equal to the node's input model.

    Returns
    -------
    list[str]
        One message per mismatch; empty when none are found.
    """
    problems = []

    for name, spec in self.nodes.items():
        expected = spec.in_model
        # No declared input, no dependency, or several dependencies: skip.
        if not expected or len(spec.deps) != 1:
            continue

        (dep_name,) = spec.deps
        produced = self.nodes[dep_name].out_model

        # A dependency without a declared output model is not checked.
        if produced and produced != expected:
            problems.append(
                f"Node '{name}' expects {expected.__name__} "
                f"but dependency '{dep_name}' outputs {produced.__name__}"
            )

    return problems
|
|
697
|
+
|
|
698
|
+
def waves(self) -> list[list[str]]:
    """Compute execution waves via topological sorting, with caching.

    Each wave lists, in sorted order, the nodes whose dependencies are all
    satisfied by earlier waves, so the nodes of one wave can run in
    parallel. The result is cached on the instance and the cached list is
    returned directly on later calls until the cache is invalidated.

    Instead of rescanning every remaining node for each wave, the nodes
    that become ready are recorded while in-degrees are decremented.

    Returns
    -------
    list[list[str]]
        List of waves, each a sorted list of node names. An empty graph
        yields an empty list (which is not cached).

    Raises
    ------
    CycleDetectedError
        If nodes remain but none of them has zero in-degree.

    Examples
    --------
    # For DAG: A -> B -> D, A -> C -> D
    # Returns: [["A"], ["B", "C"], ["D"]]
    """
    if self._waves_cache is not None:
        return self._waves_cache

    if not self.nodes:
        return []

    in_degrees = {name: len(spec.deps) for name, spec in self.nodes.items()}
    ready = [name for name, degree in in_degrees.items() if degree == 0]
    result = []

    while in_degrees:
        if not ready:
            remaining_nodes = list(in_degrees.keys())
            raise CycleDetectedError(
                f"No nodes with zero in-degree found. Remaining nodes: {remaining_nodes}"
            )

        current_wave = sorted(ready)
        result.append(current_wave)
        ready = []

        for node in current_wave:
            del in_degrees[node]
            for dependent in self._forward_edges.get(node, ()):
                if dependent in in_degrees:
                    in_degrees[dependent] -= 1
                    # Last unmet dependency satisfied: runs in the next wave.
                    if in_degrees[dependent] == 0:
                        ready.append(dependent)

    self._waves_cache = result
    return result
|
|
752
|
+
|
|
753
|
+
def waves_remaining(self, completed: frozenset[str] | set[str]) -> list[list[str]]:
    """Compute execution waves for the nodes that have not completed yet.

    Intended for graphs that grow during execution (for example through
    macro expansion): completed nodes are left out and dependencies on
    them count as satisfied. The result is not cached. When ``completed``
    is empty the call is forwarded to ``waves()``.

    Nodes that become ready are recorded while in-degrees are decremented,
    so the remaining nodes are not rescanned for every wave.

    Parameters
    ----------
    completed : frozenset[str] | set[str]
        Names of nodes that have already been executed.

    Returns
    -------
    list[list[str]]
        Remaining waves, each a sorted list of node names.

    Raises
    ------
    CycleDetectedError
        If nodes remain but none of them has zero in-degree.

    Examples
    --------
    >>> graph = DirectedGraph()
    >>> graph += NodeSpec("a", lambda: None)
    >>> graph += NodeSpec("b", lambda: None).after("a")
    >>> graph += NodeSpec("c", lambda: None).after("a")
    >>> graph.waves_remaining(frozenset(["a"]))
    [['b', 'c']]
    >>> graph += NodeSpec("d", lambda: None).after("b", "c")
    >>> len(graph.waves_remaining(frozenset(["a"])))  # [b, c] then [d]
    2
    """
    if not completed:
        # Nothing has run yet: the cached full computation applies.
        return self.waves()

    if not self.nodes:
        return []

    # In-degree counts only dependencies that have not completed. Built in
    # node insertion order so the error message below is deterministic.
    in_degrees: dict[str, int] = {
        name: len(spec.deps - completed)
        for name, spec in self.nodes.items()
        if name not in completed
    }
    ready = [name for name, degree in in_degrees.items() if degree == 0]
    result: list[list[str]] = []

    while in_degrees:
        if not ready:
            remaining = list(in_degrees.keys())
            raise CycleDetectedError(
                f"No nodes with zero in-degree found. Remaining nodes: {remaining}"
            )

        current_wave = sorted(ready)
        result.append(current_wave)
        ready = []

        for node in current_wave:
            del in_degrees[node]
            for dependent in self._forward_edges.get(node, ()):
                if dependent in in_degrees:
                    in_degrees[dependent] -= 1
                    # Last unmet dependency satisfied: runs in the next wave.
                    if in_degrees[dependent] == 0:
                        ready.append(dependent)

    return result
|
|
844
|
+
|
|
845
|
+
def __repr__(self) -> str:
    """Return a debugging representation listing every node name.

    The names are rendered in sorted order inside set-style braces, so
    the output is the same on every run regardless of hash randomisation.

    Returns
    -------
    str
        For example ``DirectedGraph(nodes={'a', 'b', 'c'})``, or
        ``DirectedGraph(nodes=set())`` for an empty graph.

    Examples
    --------
    >>> graph = DirectedGraph()
    >>> graph += NodeSpec("b", lambda: None)
    >>> graph += NodeSpec("a", lambda: None)
    >>> repr(graph)
    "DirectedGraph(nodes={'a', 'b'})"
    """
    if not self.nodes:
        return "DirectedGraph(nodes=set())"
    # Format the braces by hand: passing the sorted names through set()
    # would throw the ordering away again.
    rendered = ", ".join(repr(name) for name in sorted(self.nodes))
    return f"DirectedGraph(nodes={{{rendered}}})"
|
|
867
|
+
|
|
868
|
+
def __str__(self) -> str:
    """Return a short, human-readable summary of the graph.

    Node names are listed in sorted order. At most five are shown;
    larger graphs end the list with ``...``.

    Returns
    -------
    str
        Summary with the node count and (up to five) node names.

    Examples
    --------
    >>> graph = DirectedGraph()
    >>> graph += NodeSpec("a", lambda: None)
    >>> graph += NodeSpec("b", lambda: None)
    >>> str(graph)
    'DirectedGraph(2 nodes: a, b)'
    """
    if not self.nodes:
        return "DirectedGraph(empty)"

    ordered = sorted(self.nodes.keys())
    total = len(ordered)
    preview = ", ".join(ordered[:5])
    # Mark the list as truncated only when names were actually dropped.
    suffix = ", ..." if total > 5 else ""
    return f"DirectedGraph({total} nodes: {preview}{suffix})"
|
|
894
|
+
|
|
895
|
+
def __len__(self) -> int:
    """Return how many nodes the graph currently holds."""
    node_count = len(self.nodes)
    return node_count
|
|
898
|
+
|
|
899
|
+
def __bool__(self) -> bool:
    """Return True when the graph contains at least one node."""
    has_nodes = bool(self.nodes)
    return has_nodes
|
|
902
|
+
|
|
903
|
+
def __contains__(self, node_name: str) -> bool:
    """Report whether a node with the given name is in the graph.

    Examples
    --------
    >>> graph = DirectedGraph()
    >>> _ = graph.add(NodeSpec("a", lambda: None))
    >>> "a" in graph
    True
    >>> "b" in graph
    False
    """
    known_nodes = self.nodes
    return node_name in known_nodes
|
|
916
|
+
|
|
917
|
+
def __iadd__(
    self, other: "NodeSpec | list[NodeSpec] | tuple[NodeSpec, ...] | DirectedGraph"
) -> "DirectedGraph":
    """Add node(s) or a whole graph in place using the ``+=`` operator.

    Dispatches on the operand type:

    - a ``DirectedGraph`` is merged via ``merge()``;
    - a list or tuple of nodes is added via ``add_many()``;
    - anything else is treated as a single node and passed to ``add()``.

    Parameters
    ----------
    other : NodeSpec | list[NodeSpec] | tuple[NodeSpec, ...] | DirectedGraph
        Single node, sequence of nodes, or graph to add.

    Returns
    -------
    DirectedGraph
        Self, so the augmented assignment rebinds the same object.

    Examples
    --------
    >>> graph = DirectedGraph()
    >>> node = NodeSpec("a", lambda: "result")
    >>> graph += node  # Add single node
    >>> len(graph)
    1
    >>> graph += [NodeSpec("b", lambda: None), NodeSpec("c", lambda: None)]
    >>> len(graph)
    3
    """
    if isinstance(other, DirectedGraph):
        return self.merge(other)
    # Tuples are accepted like lists, matching what ``<<`` accepts.
    if isinstance(other, (list, tuple)):
        return self.add_many(*other)
    return self.add(other)
|
|
949
|
+
|
|
950
|
+
def __iter__(self) -> "Iterator[NodeSpec]":
    """Iterate over the node specifications stored in the graph.

    Returns
    -------
    Iterator[NodeSpec]
        Iterator over the values of the node mapping.

    Examples
    --------
    >>> graph = DirectedGraph()
    >>> graph += NodeSpec("a", lambda: None)
    >>> graph += NodeSpec("b", lambda: None)
    >>> for node in graph:
    ...     print(node.name)
    a
    b
    """
    node_specs = self.nodes.values()
    return iter(node_specs)
|
|
969
|
+
|
|
970
|
+
def keys(self) -> "KeysView[str]":
    """Return a view of the node names (dict-like interface).

    Returns
    -------
    KeysView
        The keys view of the underlying node mapping.

    Examples
    --------
    >>> graph = DirectedGraph()
    >>> graph += NodeSpec("a", lambda: None)
    >>> list(graph.keys())
    ['a']
    """
    names = self.nodes.keys()
    return names
|
|
986
|
+
|
|
987
|
+
def values(self) -> "ValuesView[NodeSpec]":
    """Return a view of the node specifications (dict-like interface).

    Returns
    -------
    ValuesView
        The values view of the underlying node mapping.

    Examples
    --------
    >>> graph = DirectedGraph()
    >>> graph += NodeSpec("a", lambda: None)
    >>> nodes = list(graph.values())
    >>> len(nodes)
    1
    """
    specs = self.nodes.values()
    return specs
|
|
1004
|
+
|
|
1005
|
+
def items(self) -> "ItemsView[str, NodeSpec]":
    """Return a view of ``(name, NodeSpec)`` pairs (dict-like interface).

    Returns
    -------
    ItemsView
        The items view of the underlying node mapping.

    Examples
    --------
    >>> graph = DirectedGraph()
    >>> graph += NodeSpec("a", lambda: None)
    >>> [name for name, _spec in graph.items()]
    ['a']
    """
    pairs = self.nodes.items()
    return pairs
|
|
1022
|
+
|
|
1023
|
+
def merge(self, other: "DirectedGraph") -> "DirectedGraph":
    """Merge all nodes of another graph into this one.

    Useful for expanding a graph during execution, e.g. when a macro
    expands into a subgraph.

    Duplicate names are rejected before anything is inserted. Merges of
    fewer than five nodes then go through ``add()`` node by node. Larger
    merges are validated as a whole (missing dependencies, then cycles)
    *before* any state is changed, so a rejected batch merge leaves this
    graph untouched and no rollback is required.

    NOTE: on the small-merge path, missing-dependency and cycle checks
    are whatever ``add()`` performs, which only validates those when the
    graph was created with ``strict_add``.

    Parameters
    ----------
    other : DirectedGraph
        The graph whose nodes are merged into this one.

    Returns
    -------
    DirectedGraph
        Self, for method chaining.

    Raises
    ------
    DuplicateNodeError
        If any node name of ``other`` already exists in this graph.
    MissingDependencyError
        If merged nodes depend on names present in neither graph.
    CycleDetectedError
        If the merged graph would contain a cycle.

    Examples
    --------
    .. code-block:: python

        main_graph = DirectedGraph()
        main_graph += NodeSpec("llm", llm_fn)

        # At runtime, expand tool calls into a subgraph
        tool_graph = create_tool_subgraph(tool_calls)
        main_graph.merge(tool_graph)  # Add tool nodes dynamically
    """
    if not other.nodes:
        return self  # Nothing to merge

    # Step 1: reject duplicates up front, for both paths, so that a failing
    # merge never leaves some of the other graph's nodes behind.
    overlap = self.nodes.keys() & other.nodes.keys()
    if overlap:
        raise DuplicateNodeError(f"Cannot merge: duplicate node(s) found: {sorted(overlap)}")

    # Fast path for small merges (< 5 nodes): reuse add().
    if len(other.nodes) < 5:
        for node in other:
            self.add(node)
        return self

    # Step 2: every dependency must exist in one of the two graphs.
    combined_nodes = self.nodes.keys() | other.nodes.keys()
    missing_deps = [
        f"Node '{node_name}' depends on missing '{dep}'"
        for node_name, node_spec in other.nodes.items()
        for dep in node_spec.deps
        if dep not in combined_nodes
    ]
    if missing_deps:
        raise MissingDependencyError("; ".join(missing_deps))

    # Step 3: cycle check on the would-be merged dependency map. Done before
    # mutating so no edge bookkeeping has to be undone on failure.
    combined_deps = {name: spec.deps for name, spec in self.nodes.items()}
    combined_deps.update((name, spec.deps) for name, spec in other.nodes.items())
    if DirectedGraph.detect_cycle(combined_deps):
        raise CycleDetectedError("Merging graphs would create a cycle")

    # Step 4: insert all nodes and wire up edges (validation already done).
    for node_name, node_spec in other.nodes.items():
        self.nodes[node_name] = node_spec
        self._forward_edges[node_name]  # Ensure key exists
        self._reverse_edges[node_name] = set(node_spec.deps)

        for dep in node_spec.deps:
            self._forward_edges[dep].add(node_name)

    # Step 5: invalidate caches once for the whole batch.
    self._invalidate_caches()

    return self
|
|
1127
|
+
|
|
1128
|
+
def get_exit_nodes(self) -> list[str]:
    """Return the names of nodes that no other node depends on.

    A node counts as an exit (leaf) node when it has no forward-edge
    entry or its forward-edge set is empty.

    Returns
    -------
    list[str]
        Exit node names, in the iteration order of the node mapping.

    Examples
    --------
    .. code-block:: python

        graph = DirectedGraph()
        graph += NodeSpec("a", lambda: None)
        graph += NodeSpec("b", lambda: None).after("a")
        graph += NodeSpec("c", lambda: None).after("a")

        exit_nodes = graph.get_exit_nodes()
        assert set(exit_nodes) == {"b", "c"}  # Both are leaves
    """
    leaves = []
    for node_name in self.nodes:
        if self._forward_edges.get(node_name):
            continue  # Something depends on this node
        leaves.append(node_name)
    return leaves
|
|
1154
|
+
|
|
1155
|
+
def __ior__(self, other: "DirectedGraph") -> "DirectedGraph":
    """Merge another graph into this one using the ``|=`` operator.

    Thin operator form of ``merge()``; its result is returned unchanged.

    Parameters
    ----------
    other : DirectedGraph
        The graph to merge into this one.

    Returns
    -------
    DirectedGraph
        The result of ``merge(other)``.

    Examples
    --------
    .. code-block:: python

        main_graph = DirectedGraph()
        main_graph += NodeSpec("a", lambda: None)
        subgraph = DirectedGraph()
        subgraph += NodeSpec("b", lambda: None)
        subgraph += NodeSpec("c", lambda: None)
        main_graph |= subgraph  # Merge subgraph into main
        assert len(main_graph) == 3
    """
    merged = self.merge(other)
    return merged
|
|
1186
|
+
|
|
1187
|
+
def __lshift__(self, other: NodeSpec | tuple) -> "DirectedGraph":
    """Add node(s) with the ``<<`` operator: ``graph << node``.

    A tuple operand is unpacked and each element is passed to ``add()``
    in order; any other operand is passed to ``add()`` as a single node.

    Parameters
    ----------
    other : NodeSpec | tuple
        Single node or tuple of nodes to add.

    Returns
    -------
    DirectedGraph
        Self, so calls can be chained.

    Examples
    --------
    .. code-block:: python

        graph = DirectedGraph()
        a = NodeSpec("a", lambda: None)
        b = NodeSpec("b", lambda: None)
        graph << a << b  # Fluent chaining
        assert len(graph) == 2

        # With pipeline operator:
        graph2 = DirectedGraph()
        c = NodeSpec("c", lambda: None)
        graph2 << (a >> b >> c)  # Add pipeline
    """
    # Normalise to a tuple so one loop serves both operand shapes.
    batch = other if isinstance(other, tuple) else (other,)
    for spec in batch:
        self.add(spec)
    return self
|