aethergraph-0.1.0a1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aethergraph/__init__.py +49 -0
- aethergraph/config/__init__.py +0 -0
- aethergraph/config/config.py +121 -0
- aethergraph/config/context.py +16 -0
- aethergraph/config/llm.py +26 -0
- aethergraph/config/loader.py +60 -0
- aethergraph/config/runtime.py +9 -0
- aethergraph/contracts/errors/errors.py +44 -0
- aethergraph/contracts/services/artifacts.py +142 -0
- aethergraph/contracts/services/channel.py +72 -0
- aethergraph/contracts/services/continuations.py +23 -0
- aethergraph/contracts/services/eventbus.py +12 -0
- aethergraph/contracts/services/kv.py +24 -0
- aethergraph/contracts/services/llm.py +17 -0
- aethergraph/contracts/services/mcp.py +22 -0
- aethergraph/contracts/services/memory.py +108 -0
- aethergraph/contracts/services/resume.py +28 -0
- aethergraph/contracts/services/state_stores.py +33 -0
- aethergraph/contracts/services/wakeup.py +28 -0
- aethergraph/core/execution/base_scheduler.py +77 -0
- aethergraph/core/execution/forward_scheduler.py +777 -0
- aethergraph/core/execution/global_scheduler.py +634 -0
- aethergraph/core/execution/retry_policy.py +22 -0
- aethergraph/core/execution/step_forward.py +411 -0
- aethergraph/core/execution/step_result.py +18 -0
- aethergraph/core/execution/wait_types.py +72 -0
- aethergraph/core/graph/graph_builder.py +192 -0
- aethergraph/core/graph/graph_fn.py +219 -0
- aethergraph/core/graph/graph_io.py +67 -0
- aethergraph/core/graph/graph_refs.py +154 -0
- aethergraph/core/graph/graph_spec.py +115 -0
- aethergraph/core/graph/graph_state.py +59 -0
- aethergraph/core/graph/graphify.py +128 -0
- aethergraph/core/graph/interpreter.py +145 -0
- aethergraph/core/graph/node_handle.py +33 -0
- aethergraph/core/graph/node_spec.py +46 -0
- aethergraph/core/graph/node_state.py +63 -0
- aethergraph/core/graph/task_graph.py +747 -0
- aethergraph/core/graph/task_node.py +82 -0
- aethergraph/core/graph/utils.py +37 -0
- aethergraph/core/graph/visualize.py +239 -0
- aethergraph/core/runtime/ad_hoc_context.py +61 -0
- aethergraph/core/runtime/base_service.py +153 -0
- aethergraph/core/runtime/bind_adapter.py +42 -0
- aethergraph/core/runtime/bound_memory.py +69 -0
- aethergraph/core/runtime/execution_context.py +220 -0
- aethergraph/core/runtime/graph_runner.py +349 -0
- aethergraph/core/runtime/lifecycle.py +26 -0
- aethergraph/core/runtime/node_context.py +203 -0
- aethergraph/core/runtime/node_services.py +30 -0
- aethergraph/core/runtime/recovery.py +159 -0
- aethergraph/core/runtime/run_registration.py +33 -0
- aethergraph/core/runtime/runtime_env.py +157 -0
- aethergraph/core/runtime/runtime_registry.py +32 -0
- aethergraph/core/runtime/runtime_services.py +224 -0
- aethergraph/core/runtime/wakeup_watcher.py +40 -0
- aethergraph/core/tools/__init__.py +10 -0
- aethergraph/core/tools/builtins/channel_tools.py +194 -0
- aethergraph/core/tools/builtins/toolset.py +134 -0
- aethergraph/core/tools/toolkit.py +510 -0
- aethergraph/core/tools/waitable.py +109 -0
- aethergraph/plugins/channel/__init__.py +0 -0
- aethergraph/plugins/channel/adapters/__init__.py +0 -0
- aethergraph/plugins/channel/adapters/console.py +106 -0
- aethergraph/plugins/channel/adapters/file.py +102 -0
- aethergraph/plugins/channel/adapters/slack.py +285 -0
- aethergraph/plugins/channel/adapters/telegram.py +302 -0
- aethergraph/plugins/channel/adapters/webhook.py +104 -0
- aethergraph/plugins/channel/adapters/webui.py +134 -0
- aethergraph/plugins/channel/routes/__init__.py +0 -0
- aethergraph/plugins/channel/routes/console_routes.py +86 -0
- aethergraph/plugins/channel/routes/slack_routes.py +49 -0
- aethergraph/plugins/channel/routes/telegram_routes.py +26 -0
- aethergraph/plugins/channel/routes/webui_routes.py +136 -0
- aethergraph/plugins/channel/utils/__init__.py +0 -0
- aethergraph/plugins/channel/utils/slack_utils.py +278 -0
- aethergraph/plugins/channel/utils/telegram_utils.py +324 -0
- aethergraph/plugins/channel/websockets/slack_ws.py +68 -0
- aethergraph/plugins/channel/websockets/telegram_polling.py +151 -0
- aethergraph/plugins/mcp/fs_server.py +128 -0
- aethergraph/plugins/mcp/http_server.py +101 -0
- aethergraph/plugins/mcp/ws_server.py +180 -0
- aethergraph/plugins/net/http.py +10 -0
- aethergraph/plugins/utils/data_io.py +359 -0
- aethergraph/runner/__init__.py +5 -0
- aethergraph/runtime/__init__.py +62 -0
- aethergraph/server/__init__.py +3 -0
- aethergraph/server/app_factory.py +84 -0
- aethergraph/server/start.py +122 -0
- aethergraph/services/__init__.py +10 -0
- aethergraph/services/artifacts/facade.py +284 -0
- aethergraph/services/artifacts/factory.py +35 -0
- aethergraph/services/artifacts/fs_store.py +656 -0
- aethergraph/services/artifacts/jsonl_index.py +123 -0
- aethergraph/services/artifacts/paths.py +23 -0
- aethergraph/services/artifacts/sqlite_index.py +209 -0
- aethergraph/services/artifacts/utils.py +124 -0
- aethergraph/services/auth/dev.py +16 -0
- aethergraph/services/channel/channel_bus.py +293 -0
- aethergraph/services/channel/factory.py +44 -0
- aethergraph/services/channel/session.py +511 -0
- aethergraph/services/channel/wait_helpers.py +57 -0
- aethergraph/services/clock/clock.py +9 -0
- aethergraph/services/container/default_container.py +320 -0
- aethergraph/services/continuations/continuation.py +56 -0
- aethergraph/services/continuations/factory.py +34 -0
- aethergraph/services/continuations/stores/fs_store.py +264 -0
- aethergraph/services/continuations/stores/inmem_store.py +95 -0
- aethergraph/services/eventbus/inmem.py +21 -0
- aethergraph/services/features/static.py +10 -0
- aethergraph/services/kv/ephemeral.py +90 -0
- aethergraph/services/kv/factory.py +27 -0
- aethergraph/services/kv/layered.py +41 -0
- aethergraph/services/kv/sqlite_kv.py +128 -0
- aethergraph/services/llm/factory.py +157 -0
- aethergraph/services/llm/generic_client.py +542 -0
- aethergraph/services/llm/providers.py +3 -0
- aethergraph/services/llm/service.py +105 -0
- aethergraph/services/logger/base.py +36 -0
- aethergraph/services/logger/compat.py +50 -0
- aethergraph/services/logger/formatters.py +106 -0
- aethergraph/services/logger/std.py +203 -0
- aethergraph/services/mcp/helpers.py +23 -0
- aethergraph/services/mcp/http_client.py +70 -0
- aethergraph/services/mcp/mcp_tools.py +21 -0
- aethergraph/services/mcp/registry.py +14 -0
- aethergraph/services/mcp/service.py +100 -0
- aethergraph/services/mcp/stdio_client.py +70 -0
- aethergraph/services/mcp/ws_client.py +115 -0
- aethergraph/services/memory/bound.py +106 -0
- aethergraph/services/memory/distillers/episode.py +116 -0
- aethergraph/services/memory/distillers/rolling.py +74 -0
- aethergraph/services/memory/facade.py +633 -0
- aethergraph/services/memory/factory.py +78 -0
- aethergraph/services/memory/hotlog_kv.py +27 -0
- aethergraph/services/memory/indices.py +74 -0
- aethergraph/services/memory/io_helpers.py +72 -0
- aethergraph/services/memory/persist_fs.py +40 -0
- aethergraph/services/memory/resolver.py +152 -0
- aethergraph/services/metering/noop.py +4 -0
- aethergraph/services/prompts/file_store.py +41 -0
- aethergraph/services/rag/chunker.py +29 -0
- aethergraph/services/rag/facade.py +593 -0
- aethergraph/services/rag/index/base.py +27 -0
- aethergraph/services/rag/index/faiss_index.py +121 -0
- aethergraph/services/rag/index/sqlite_index.py +134 -0
- aethergraph/services/rag/index_factory.py +52 -0
- aethergraph/services/rag/parsers/md.py +7 -0
- aethergraph/services/rag/parsers/pdf.py +14 -0
- aethergraph/services/rag/parsers/txt.py +7 -0
- aethergraph/services/rag/utils/hybrid.py +39 -0
- aethergraph/services/rag/utils/make_fs_key.py +62 -0
- aethergraph/services/redactor/simple.py +16 -0
- aethergraph/services/registry/key_parsing.py +44 -0
- aethergraph/services/registry/registry_key.py +19 -0
- aethergraph/services/registry/unified_registry.py +185 -0
- aethergraph/services/resume/multi_scheduler_resume_bus.py +65 -0
- aethergraph/services/resume/router.py +73 -0
- aethergraph/services/schedulers/registry.py +41 -0
- aethergraph/services/secrets/base.py +7 -0
- aethergraph/services/secrets/env.py +8 -0
- aethergraph/services/state_stores/externalize.py +135 -0
- aethergraph/services/state_stores/graph_observer.py +131 -0
- aethergraph/services/state_stores/json_store.py +67 -0
- aethergraph/services/state_stores/resume_policy.py +119 -0
- aethergraph/services/state_stores/serialize.py +249 -0
- aethergraph/services/state_stores/utils.py +91 -0
- aethergraph/services/state_stores/validate.py +78 -0
- aethergraph/services/tracing/noop.py +18 -0
- aethergraph/services/waits/wait_registry.py +91 -0
- aethergraph/services/wakeup/memory_queue.py +57 -0
- aethergraph/services/wakeup/scanner_producer.py +56 -0
- aethergraph/services/wakeup/worker.py +31 -0
- aethergraph/tools/__init__.py +25 -0
- aethergraph/utils/optdeps.py +8 -0
- aethergraph-0.1.0a1.dist-info/METADATA +410 -0
- aethergraph-0.1.0a1.dist-info/RECORD +182 -0
- aethergraph-0.1.0a1.dist-info/WHEEL +5 -0
- aethergraph-0.1.0a1.dist-info/entry_points.txt +2 -0
- aethergraph-0.1.0a1.dist-info/licenses/LICENSE +176 -0
- aethergraph-0.1.0a1.dist-info/licenses/NOTICE +31 -0
- aethergraph-0.1.0a1.dist-info/top_level.txt +1 -0
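The listing above is the wheel's file manifest. Since a wheel is a zip archive, the same listing can be reproduced locally with nothing but the standard library; a minimal sketch, assuming the wheel named at the top of this page has been downloaded into the working directory:

    import zipfile

    # print each archive member with its uncompressed size
    with zipfile.ZipFile("aethergraph-0.1.0a1-py3-none-any.whl") as whl:
        for info in whl.infolist():
            print(info.filename, info.file_size)

The diff that follows is for aethergraph/core/graph/task_graph.py (+747 lines).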
aethergraph/core/graph/task_graph.py
@@ -0,0 +1,747 @@
+from __future__ import annotations
+
+from collections.abc import Iterable
+from dataclasses import asdict, dataclass, field, is_dataclass
+import inspect
+from typing import Any
+import uuid
+
+from .graph_refs import GRAPH_INPUTS_NODE_ID, Ref, normalize_binding, resolve_binding
+from .graph_spec import GraphView, TaskGraphSpec
+from .graph_state import GraphPatch, TaskGraphState
+from .node_spec import TaskNodeSpec
+from .node_state import NodeStatus, TaskNodeState
+from .task_node import TaskNodeRuntime
+from .utils import _logic_label, _short, _status_label
+from .visualize import ascii_overview, to_dot, visualize
+
+
+# small helper to turn a dataclass (spec) into plain dict safely
+def _dataclass_to_plain(d):
+    return asdict(d) if is_dataclass(d) else d
+
+
+@dataclass
+class TaskGraph:
+    spec: TaskGraphSpec
+    state: TaskGraphState = field(
+        default_factory=TaskGraphState
+    )  # mutable state, including node states, patches
+    observers: list[Any] = field(default_factory=list)
+
+    # Expose graph_id as convenient alias; source of truth is spec.graph_id
+    graph_id: str = field(init=False, repr=True)
+
+    # Ephemeral runtime table (not serialized)
+    _runtime_nodes: dict[str, TaskNodeRuntime] = field(default_factory=dict, init=False, repr=False)
+
+    # Inverted indexeds for quick lookup (by alias, logic, label) ephemeral
+    _idx_ready: bool = field(default=False, init=False, repr=False)
+    _by_alias: dict[str, str] = field(
+        default_factory=dict, init=False, repr=False
+    )  # alias -> node_id
+    _by_logic: dict[str, list[str]] = field(
+        default_factory=dict, init=False, repr=False
+    )  # logic -> [node_id, ...]
+    _by_label: dict[str, list[str]] = field(
+        default_factory=dict, init=False, repr=False
+    )  # label -> {node_id, ...}
+    _by_name: dict[str, list[str]] = field(
+        default_factory=dict, init=False, repr=False
+    )  # display name -> [node_id, ...]: TODO: decided if we need this
+
+    @classmethod
+    def new_run(self, spec: TaskGraphSpec, *, run_id: str | None = None, **kwargs) -> TaskGraph:
+        """Create a new TaskGraph instance for a new run."""
+        run_id = run_id or str(uuid.uuid4())
+        # initialize empty node states
+        nodes = {nid: TaskNodeState() for nid in spec.nodes}
+        state = TaskGraphState(run_id=run_id, nodes=nodes)
+        graph = self.from_spec(spec, state=state, **kwargs)
+        return graph
+
+    @classmethod
+    def from_spec(cls, spec: TaskGraphSpec, *, state: TaskGraphState | None = None):
+        """Create a TaskGraph instance from a TaskGraphSpec and optional state and memory."""
+        graph = cls(spec=spec, state=state or TaskGraphState())
+        # Set back-references in nodes
+        for node in graph.spec.nodes.values():
+            node._parent_graph = graph
+        graph.ensure_inputs_node()
+        graph.__post_init__()
+        graph.ensure_inputs_node()
+
+        # Set the inputs node state to DONE
+        input_node = graph.node(GRAPH_INPUTS_NODE_ID)
+        input_node.state.status = NodeStatus.DONE
+        return graph
+
+    # Publich read-only view
+    def node(self, node_id: str) -> TaskNodeRuntime:
+        return self._runtime_nodes[node_id]
+
+    @property
+    def nodes(self) -> list[TaskNodeRuntime]:
+        return list(self._runtime_nodes.values())
+
+    def _apply_patches(self) -> dict[str, TaskNodeSpec]:
+        """Compute a patched node spec dict for the *view*.
+        The original spec stays immutable
+        """
+        node_specs = dict(self.spec.nodes)  # shallow copy of mapping
+
+        # The following is used when graph mutations are supported. It is just a sketch now.
+        for p in self.state.patches:
+            if p.op == "add_or_replace_node":
+                ns = TaskNodeSpec(**p.payload)  # validate payload
+                node_specs[ns.node_id] = ns
+            elif p.op == "remove_node":
+                node_specs.pop(p.payload["node_id"], None)
+            elif p.op == "add_dependency":
+                nid = p.payload["node_id"]
+                dep = p.payload["dependency_id"]
+                old = node_specs.get(nid)
+                # create a new frozen spec with updated deps
+                node_specs[nid] = TaskNodeSpec(
+                    **{**old.__dict__, "dependencies": [*old.dependencies, dep]}
+                )
+            # TODO: add more patch type
+            pass
+
+        return node_specs
+
+    def _reify_runtime_nodes(self):
+        """Create TaskNodeRuntime instances for all nodes in the graph spec."""
+        effective_specs = self._apply_patches()  # get patched specs
+        table = {}
+        for nid, nspec in effective_specs.items():
+            nstate = self.state.nodes.get(nid)
+            if nstate is None:
+                nstate = TaskNodeState()  # in case a patch add a node
+                self.state.nodes[nid] = nstate  # persist its state in TaskGraphState
+            table[nid] = TaskNodeRuntime(spec=nspec, state=nstate, _parent_graph=self)
+        self._runtime_nodes = table
+
+    def __post_init__(self):
+        # establish graph_id as alias to spec.graph_id
+        self.graph_id = self.spec.graph_id
+
+        # establish back-references in nodes
+        if not getattr(self.state, "nodes", None):
+            self.state.nodes = {
+                nid: TaskNodeState() for nid in self.spec.nodes
+            }  # GraphSpec.nodes is Dict[str, TaskNodeSpec]
+
+        # establish back-references in nodes
+        self._reify_runtime_nodes()
+
+        # index for quick lookup
+        self._reindex()
+
+    def _reindex(self):
+        self._by_alias.clear()
+        self._by_logic.clear()
+        self._by_label.clear()
+        self._by_name.clear()
+        for nid, node in self._runtime_nodes.items():
+            metadata = getattr(node.spec, "metadata", {}) or {}
+            alias = metadata.get("alias")
+            labels = metadata.get("labels", [])
+            display = metadata.get("display_name")
+            logic_name = node.spec.tool_name or (
+                node.spec.logic if isinstance(node.spec.logic, str) else _short(node.spec.logic)
+            )
+
+            if alias:
+                self._by_alias[alias] = nid
+            for label in labels:
+                self._by_label.setdefault(label, set()).add(nid)
+            if logic_name:
+                self._by_logic.setdefault(logic_name, []).append(nid)
+            if display:
+                self._by_name.setdefault(display, []).append(nid)
+
+        self._idx_ready = True
+
+    def _ensure_index(self):
+        if not self._idx_ready:
+            self._reindex()
+
+    # Call when mutate spec.nodes in PatchFlow
+    def index_touch(self):
+        self._idx_ready = False
+
+    # Node access
+    def node_ids(self) -> list[str]:
+        """Get list of all node IDs in the graph."""
+        return list(self._runtime_nodes.keys())
+
+    # Node finder
+    def get_by_id(self, node_id: str) -> str:
+        """Get node ID by ID (identity function)."""
+        if node_id not in self._runtime_nodes:
+            raise ValueError(f"Node ID '{node_id}' not found in graph '{self.graph_id}'")
+        return node_id
+
+    def get_by_alias(self, alias: str) -> str | None:
+        """Get node ID by alias."""
+        self._ensure_index()
+        node_id = self._by_alias.get(alias)
+        if not node_id:
+            raise KeyError(f"Alias '{alias}' not found in graph '{self.graph_id}'")
+        return node_id
+
+    def find_by_label(self, label: str) -> list[str]:
+        """Find node IDs by label."""
+        self._ensure_index()
+        return sorted(self._by_label.get(label, set()))
+
+    def find_by_logic(self, logic_prefix: str, *, first: bool = False) -> list[str] | str | None:
+        """Find node IDs by logic name.
+        If first=True, return only the first match or None if not found.
+
+        Usage:
+            graph.find_by_logic("my_tool")  # all nodes with logic name "my_tool"
+            graph.find_by_logic("my_tool", first=True)  # first node with logic name "my_tool" or None
+            graph.find_by_logic("my_tool_v")  # all nodes with logic name starting with "my_tool_v"
+            graph.find_by_logic("my_tool_v", first=True)  # first node with logic name starting with "my_tool_v" or None
+        """
+        self._ensure_index()
+        if logic_prefix in self._by_logic:
+            ids = list(self._by_logic[logic_prefix])
+        else:
+            ids = []
+            for k, vs in self._by_logic.items():
+                if k.startswith(logic_prefix):
+                    ids.extend(vs)
+        ids.sort()
+        return (ids[0] if (first and ids) else ids) or ([] if not first else None)
+
+    def find_by_display(self, name_prefix: str, *, first: bool = False) -> list[str] | str | None:
+        """Find node IDs by display name.
+        If first=True, return only the first match or None if not found.
+
+        Usage:
+            graph.find_by_display("My Node")  # all nodes with display name "My Node"
+            graph.find_by_display("My Node", first=True)  # first node with display name "My Node" or None
+            graph.find_by_display("My Node V")  # all nodes with display name starting with "My Node V"
+            graph.find_by_display("My Node V", first=True)  # first node with display name starting with "My Node V" or None
+        """
+        self._ensure_index()
+        if name_prefix in self._by_name:
+            ids = list(self._by_name[name_prefix])
+        else:
+            ids = []
+            for k, vs in self._by_name.items():
+                if k.startswith(name_prefix):
+                    ids.extend(vs)
+        ids.sort()
+        return (ids[0] if (first and ids) else ids) or ([] if not first else None)
+
+    # ---------- Unified selector ----------
+    # Mini-DSL:
+    # "@alias" -> by alias
+    # "#label" -> by label (many)
+    # "id:<id>" -> exact id
+    # "logic:<pref>" -> logic name prefix
+    # "name:<pref>" -> display name prefix
+    # "/regex/" -> regex on node_id
+
+    def select(self, selector: str, *, first: bool = False) -> str | list[str] | None:
+        selector = selector.strip()
+        if selector.startswith("@"):
+            return self.get_by_alias(selector[1:])
+        elif selector.startswith("#"):
+            ids = self.find_by_label(selector[1:])
+            return ids[0] if (first and ids) else ids
+
+        elif selector.startswith("id:"):
+            return self.get_by_id(selector[3:])
+        elif selector.startswith("logic:"):
+            ids = self.find_by_logic(selector[6:], first=first)
+            return ids
+        elif selector.startswith("name:"):
+            ids = self.find_by_display(selector[5:], first=first)
+            return ids
+        elif len(selector) >= 2 and selector[0] == "/" and selector[-1] == "/":
+            import re
+
+            pattern = re.compile(selector[1:-1])
+            ids = [nid for nid in self.node_ids() if pattern.search(nid)]
+            ids.sort()
+            return ids[0] if (first and ids) else ids
+        else:
+            # fallback: prefix on node_id
+            ids = [nid for nid in self.node_ids() if nid.startswith(selector)]
+            ids.sort()
+            return ids[0] if (first and ids) else ids
+
+    def pick_one(self, selector: str) -> str | None:
+        """Pick one node ID by selector, or None if not found."""
+        res = self.select(selector, first=True)
+        if not res:
+            raise KeyError(f"No node found for selector '{selector}' in graph '{self.graph_id}'")
+        return res
+
+    def pick_all(self, selector: str) -> list[str]:
+        """Pick all node IDs by selector, or empty list if none found."""
+        res = self.select(selector, first=False)
+        if isinstance(res, str):
+            return [res]
+        return res or []
+
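The selector mini-DSL above funnels everything through select(), with pick_one() and pick_all() as the strict single-result and list-returning wrappers. A minimal usage sketch based only on the methods shown; the aliases, labels, and node IDs are made up for illustration:

    nid = graph.pick_one("@loader")            # alias lookup; raises KeyError if nothing matches
    batch = graph.pick_all("#preprocess")      # every node carrying the "preprocess" label
    tools = graph.pick_all("logic:my_tool")    # logic-name prefix match, returned sorted
    exact = graph.select("id:node_123")        # exact id; get_by_id raises ValueError if unknown
    steps = graph.pick_all("/step_[0-9]+/")    # regex applied to node ids
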
+    # --------- Read-only views ---------
+    def view(self) -> GraphView:
+        """Get a read-only view of the graph's spec and state."""
+        return GraphView(
+            graph_id=self.spec.graph_id,
+            nodes=self.spec.nodes,
+            node_status=self.state.node_status,  # state.node_status is a property in TaskGraphState derived from self.state.nodes
+            metadata=self.spec.metadata,
+        )
+
+    # -------- Graph mutation APIs ---------
+    def patch_add_or_replace_node(self, node_spec: dict[str, Any]):
+        """Patch the graph by adding or replacing a node."""
+        patch = GraphPatch(op="add_or_replace_node", payload=node_spec)
+        self.state.patches.append(patch)
+        self.state.rev += 1
+        # awaitable = None
+        for obs in self.observers:
+            cb = getattr(obs, "on_patch_applied", None)
+            if cb:
+                cb(self, patch)  # r = cb(self, patch) if awaitable is needed
+                # if hasattr(r, "__await__"):
+                #     awaitable = r  # keep last; or gather all
+
+        self._reify_runtime_nodes()
+
+    def patch_remove_node(self, node_id: str):
+        """Patch the graph by removing a node."""
+        patch = GraphPatch(op="remove_node", payload={"node_id": node_id})
+        self.state.patches.append(patch)
+        self.state.rev += 1
+        # awaitable = None
+        for obs in self.observers:
+            cb = getattr(obs, "on_patch_applied", None)
+            if cb:
+                cb(self, patch)  # r = cb(self, patch) if awaitable is needed
+                # if hasattr(r, "__await__"):
+                #     awaitable = r  # keep last; or gather all
+        self._reify_runtime_nodes()
+
+    def patch_add_dependency(self, node_id: str, dependency_id: str):
+        """Patch the graph by adding a dependency to a node."""
+        patch = GraphPatch(
+            op="add_dependency", payload={"node_id": node_id, "dependency_id": dependency_id}
+        )
+        self.state.patches.append(patch)
+        self.state.rev += 1
+        # awaitable = None
+        for obs in self.observers:
+            cb = getattr(obs, "on_patch_applied", None)
+            if cb:
+                cb(self, patch)  # r = cb(self, patch) if awaitable is needed
+                # if hasattr(r, "__await__"):
+                #     awaitable = r  # keep last; or gather all
+        self._reify_runtime_nodes()
+
+    # --------- Introspection APIs ---------
+    def list_nodes(self, exclude_internal=True) -> list[str]:
+        """List all node IDs in the graph."""
+        return (
+            list(self.spec.nodes.keys())
+            if not exclude_internal
+            else [nid for nid in self.spec.nodes if not nid.startswith("_")]
+        )
+
+    # --------- Topology helpers ---------
+    def dependents(self, node_id: str) -> list[str]:
+        """Get list of node_ids that depend on the given node_id."""
+        return [x.node_id for x in self.spec.nodes.values() if node_id in x.dependencies]
+
+    def topological_order(self) -> list[str]:
+        """Get nodes in topological order. Raises error if cycles are detected."""
+        import networkx as nx
+
+        G = nx.DiGraph()
+        for n in self.spec.nodes.values():
+            G.add_node(n.node_id)
+            for dep in n.dependencies:
+                G.add_edge(dep, n.node_id)
+        try:
+            order = list(nx.topological_sort(G))
+            return order
+        except nx.NetworkXUnfeasible:
+            raise ValueError(
+                "Graph has at least one cycle; topological sort not possible."
+            ) from None
+
+    def get_subgraph_nodes(self, start_node_id: str) -> list[str]:
+        """Get all nodes reachable from the given start_node_id (including itself)."""
+        seen, stack = set(), [start_node_id]
+        while stack:
+            nid = stack.pop()
+            if nid in seen:
+                continue
+            seen.add(nid)
+            stack.extend(self.dependents(nid))
+        return list(seen)
+
+    def get_upstream_nodes(self, start_node_id: str) -> list[str]:
+        """Get all upstream nodes that the given node_id depends on (including itself)."""
+        seen, stack = set(), [start_node_id]
+        while stack:
+            nid = stack.pop()
+            if nid in seen:
+                continue
+            seen.add(nid)
+            stack.extend(self.spec.nodes[nid].dependencies)
+        return list(seen)
+
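The topology helpers walk spec.nodes[*].dependencies directly; topological_order() additionally imports networkx at call time and raises ValueError on cycles. A short sketch of how they compose (the node IDs are illustrative):

    order = graph.topological_order()                 # dependency-respecting order, or ValueError on a cycle
    downstream = graph.get_subgraph_nodes("fetch")    # "fetch" plus everything that transitively depends on it
    upstream = graph.get_upstream_nodes("report")     # "report" plus everything it transitively depends on
    direct = graph.dependents("fetch")                # immediate dependents only
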
+    # --------- State mutation APIs ---------
+    async def set_status(self, node_id: str, status: NodeStatus):
+        """Set the status of a node and notify observers."""
+        raise NotImplementedError(
+            "set_status() is not implemented yet. Use set_node_status() instead."
+        )
+
+    async def set_outputs(self, node_id: str, outputs: dict[str, Any]):
+        """Set the outputs of a node."""
+        raise NotImplementedError(
+            "set_outputs() is not implemented yet. Use set_node_outputs() instead."
+        )
+
+    async def set_node_status(self, node_id: str, status: NodeStatus) -> None:
+        state = self.state.nodes.get(node_id)
+        if state.status is status:
+            return
+        state.status = status
+        self.state.rev += 1
+        await self._notify_status_change(node_id)
+
+    async def set_node_outputs(self, node_id: str, outputs: dict[str, Any]) -> None:
+        state = self.state.nodes.get(node_id)
+        state.outputs = outputs
+        self.state.rev += 1
+        await self._notify_output_change(node_id)
+
+    async def _notify_status_change(self, node_id: str):
+        runtime_node = self._runtime_nodes.get(node_id)  # runtime view points at same state object
+        for obs in self.observers:
+            cb = getattr(obs, "on_node_status_change", None)
+            if cb:
+                out = cb(runtime_node)
+                if hasattr(out, "__await__"):
+                    await out
+
+    def _notify_inputs_bound(self):
+        for obs in self.observers:
+            cb = getattr(obs, "on_inputs_bound", None)
+            if cb:
+                out = cb(self)
+                if hasattr(out, "__await__"):
+                    # fire-and-forget is okay; await here to keep ordering
+                    # (code already awaits in other notify paths)
+                    pass
+
+    async def _notify_output_change(self, node_id: str):
+        runtime_node = self._runtime_nodes.get(node_id)  # runtime view points at same state object
+        for obs in self.observers:
+            cb = getattr(obs, "on_node_output_change", None)
+            if cb:
+                out = cb(runtime_node)
+                if hasattr(out, "__await__"):
+                    await out
+
+    # --------- Rest paths ---------
+    async def reset_node(self, node_id: str, *, preserve_outputs: bool = False):
+        """Reset a node to PENDING state. Optionally preserve outputs."""
+        if node_id not in self.spec.nodes:
+            raise ValueError(f"Node with id {node_id} does not exist in the graph.")
+
+        if node_id == GRAPH_INPUTS_NODE_ID:
+            raise ValueError("Cannot reset the special graph inputs node.")
+
+        node = self.state.nodes[node_id]
+        await node.reset_node(preserve_outputs=preserve_outputs)
+
+    async def reset(
+        self,
+        node_ids: list[str] | None = None,
+        *,
+        recursive=True,
+        direction="forward",
+        preserve_outputs: bool = False,
+    ):
+        """
+        Reset the graph or a subgraph to PENDING state.
+        If node_id is None, reset the entire graph.
+        If recursive is True, reset all dependent nodes (forward) or dependencies (backward).
+        """
+        if not node_ids:
+            # Reset the entire graph
+            for nid in list(self.spec.nodes.keys()):
+                if nid == GRAPH_INPUTS_NODE_ID:
+                    continue
+                await self.reset_node(nid, preserve_outputs=preserve_outputs)
+
+        # partial reset
+        target_ids = []
+        for nid in node_ids:
+            if recursive:
+                if direction == "forward":
+                    target_ids.extend(self.get_subgraph_nodes(nid))
+                elif direction == "backward":
+                    target_ids.extend(self.get_upstream_nodes(nid))
+                else:
+                    raise ValueError("direction must be 'forward' or 'backward'")
+            else:
+                target_ids.append(nid)
+
+        for nid in set(target_ids):
+            await self.reset_node(nid, preserve_outputs=preserve_outputs)
+
+        return {
+            "status": "partial_reset",
+            "graph_id": self.spec.graph_id,
+            "nodes_reset": list(set(target_ids)),
+        }
+
+    # --------- Observers and hooks ---------
+    def add_observer(self, observer: Any):
+        self.observers.append(observer)
+
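Observers are duck-typed: the notify paths above probe for on_patch_applied, on_node_status_change, on_node_output_change, and on_inputs_bound via getattr, awaiting the result only where the call site does so. A minimal observer sketch consistent with those call sites (the print statements are placeholders):

    class PrintObserver:
        def on_patch_applied(self, graph, patch):
            # invoked synchronously from the patch_* methods above
            print(f"patch {patch.op} applied to {graph.graph_id}")

        async def on_node_status_change(self, node):
            # receives the TaskNodeRuntime whose status changed; awaited by _notify_status_change
            print(f"{node.spec.node_id} -> {node.state.status}")

    graph.add_observer(PrintObserver())
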
+    # --------- Difference APIs ---------
+    def diff(self, other: TaskGraph) -> dict[str, Any]:
+        """
+        Compute the difference between this graph and another graph.
+        Returns a dict with added, removed, and modified nodes.
+        """
+        if self.spec.graph_id != other.spec.graph_id:
+            raise ValueError("Can only diff graphs with the same graph_id.")
+
+        diff_result = {"added": [], "removed": [], "modified": []}
+
+        # Check for added and modified nodes
+        for nid, node in other.spec.nodes.items():
+            if nid not in self.spec.nodes:
+                diff_result["added"].append(nid)
+            else:
+                # Check for modifications (dependencies or metadata changes)
+                old_node = self.spec.nodes[nid]
+                if (set(old_node.dependencies) != set(node.dependencies)) or (
+                    old_node.metadata != node.metadata
+                ):
+                    diff_result["modified"].append(nid)
+
+        # Check for removed nodes
+        for nid in self.spec.nodes:
+            if nid not in other.spec.nodes:
+                diff_result["removed"].append(nid)
+
+        return diff_result
+
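diff() compares only node membership, dependency sets, and metadata between two graphs sharing a graph_id; node logic and runtime state are not inspected. For example (node names illustrative):

    changes = old_graph.diff(new_graph)
    # -> {"added": ["summarize"], "removed": [], "modified": ["fetch"]}
    # "modified" means the dependency set or metadata of an existing node differs
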
+    # --------- IO definition APIs ---------
+    def declare_inputs(
+        self, *, required: Iterable[str] | None = None, optional: dict[str, Any] | None = None
+    ) -> None:
+        """Declare graph-level inputs."""
+        # if required: self.spec._io_inputs_required.update(required)
+        # if optional: self.spec._io_inputs_optional.update(optional or {})
+
+        from .graph_io import ParamSpec
+
+        required_spec = {
+            k: ParamSpec() for k in (required or [])
+        }  # currently we don't support detailed param spec. Only names are used.
+        optional_spec = {k: ParamSpec(default=v) for k, v in (optional or {}).items()}
+        if required:
+            self.spec.io.required.update(required_spec)
+        if optional:
+            self.spec.io.optional.update(optional_spec)
+
+    def expose(self, name: str, value: Ref | Any) -> None:
+        """Expose a graph-level output.
+        In graph IO, outputs can be references to node outputs or constant values.
+        """
+        if name not in self.spec.io.expose:
+            self.spec.io.expose.append(name)
+        self.spec.io.set_expose(name, normalize_binding(value))
+
+    def require_outputs(self, *names: str) -> None:
+        """Require certain graph-level outputs to be present."""
+        missing = [n for n in names if n not in self.spec._io_outputs]
+        if missing:
+            raise ValueError(f"Missing required outputs: {', '.join(missing)}")
+
+    def io_signature(self, include_values: bool = False) -> dict[str, Any]:
+        """Get the graph's IO signature as a dict.
+        The signature includes:
+        - inputs: {required: [...], optional: {...}}
+        - outputs: {keys: [...], bindings: {...}}
+        Note: Disable include_values when initializing a graph to avoid resolving unbound refs.
+        """
+        if hasattr(self.spec.io, "get_expose_names"):
+            names: list[str] = self.spec.io.get_expose_names()
+        else:
+            names = list(getattr(self.spec.io, "expose", []) or [])
+
+        if hasattr(self.spec.io, "get_expose_bindings"):
+            bindings: dict[str, Any] = self.spec.io.get_expose_bindings()
+        else:
+            bindings = dict(getattr(self.spec, "meta", {}).get("expose_bindings", {}))
+
+        # Build the signature dict with concrete iterables / dicts
+        out = {
+            "inputs": {
+                "required": sorted(self.spec.inputs_required),
+                "optional": dict(self.spec.inputs_optional),
+            },
+            "outputs": {
+                "keys": list(names),
+                "bindings": {n: bindings.get(n) for n in names},
+            },
+        }
+
+        if include_values:
+            out["outputs"]["values"] = {
+                n: resolve_binding(bindings.get(n), self.state.node_outputs) for n in names
+            }
+        return out
+
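Taken together, declare_inputs(), expose(), and io_signature() form the graph-level IO surface: declared inputs become ParamSpec entries on spec.io, exposed outputs are normalized bindings (Refs to node outputs or plain constants), and io_signature() reports both. A sketch of the flow; the constant passed to expose() stands in for a Ref, whose construction lives in graph_refs and is not shown in this file:

    graph.declare_inputs(required=["query"], optional={"top_k": 5})
    graph.expose("version", "0.1.0a1")   # a Ref to a node output would be normalized the same way
    sig = graph.io_signature()           # include_values=True would also resolve the exposed bindings
    # sig has the shape {"inputs": {"required": [...], "optional": {...}},
    #                    "outputs": {"keys": [...], "bindings": {...}}}
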
+    def ensure_inputs_node(self):
+        if GRAPH_INPUTS_NODE_ID not in self.spec.nodes:
+            node_spec = TaskNodeSpec(
+                node_id=GRAPH_INPUTS_NODE_ID,
+                type="inputs",
+                logic=None,
+                inputs={},
+                dependencies=[],
+                metadata={"synthetic": True},
+                expected_input_keys=[],
+                expected_output_keys=[],
+            )
+            self.spec.nodes[GRAPH_INPUTS_NODE_ID] = node_spec
+
+        node_state = self.state.nodes.setdefault(GRAPH_INPUTS_NODE_ID, TaskNodeState())
+        node_state.status = NodeStatus.DONE
+
+    def _validate_and_bind_inputs(self, provided: dict[str, Any]) -> dict[str, Any]:
+        """Validate and bind provided inputs against the graph's IO signature."""
+        req = self.spec.inputs_required
+        # opt = set(self.spec.inputs_optional.keys())
+        missing = [k for k in req if k not in provided]
+        if missing:
+            raise ValueError(f"Missing required inputs: {', '.join(missing)}")
+
+        merged = dict(self.spec.inputs_optional)  # start with optional defaults
+        merged.update(provided)  # override with provided
+        self.state._bound_inputs = merged
+
+        # bump rev; persist an event
+        self.state.rev += 1
+        # notify
+        out = self._notify_inputs_bound()
+        if hasattr(out, "__await__"):
+            # optional: await if later want strict ordering
+            pass
+        return merged
+
+    def _resolve_ref(self, r: Any, node_outputs: dict[str, dict[str, Any]]) -> Any:
+        """Resolve a Ref or return the value as-is."""
+        if not (isinstance(r, dict) and r.get("_type") == "ref"):
+            return r
+
+        src, key = r.get("from"), r.get("key")
+        if src == GRAPH_INPUTS_NODE_ID:
+            if self.state._bound_inputs is None:
+                raise RuntimeError("Graph inputs not bound. Call graph(...) or bind explicitly.")
+            return self.state._bound_inputs.get(key)
+        return node_outputs.get(src, {}).get(key)
+
+    # --------- Execution APIs ---------
+    # Here we have temporary APIs, later we will use tools and scheduler to manage execution
+    def _load_logic(self, logic: Any):
+        # If logic is a callable, return it as-is
+        if callable(logic):
+            return logic
+        # If logic is a string, check if it starts with "registry:"
+        if isinstance(logic, str) and logic.startswith("registry:"):
+            # If it does, look it up in the registry
+            return self._lookup_registry(logic)
+        # If we reach here, logic is not valid
+        raise ValueError(f"Invalid logic: {logic}")
+
+    async def _run_tool(self, logic: Any, **kwargs):
+        fn = self._load_logic(logic)
+        res = fn(**kwargs)
+        if inspect.isawaitable(res):
+            res = await res
+        return res if isinstance(res, dict) else {"result": res}
+
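_run_tool() is one of the "temporary" execution APIs noted above: it resolves logic via _load_logic() (a callable is returned as-is, a "registry:..." string goes through _lookup_registry, which is defined elsewhere), awaits coroutine results, and wraps non-dict returns. A small illustration:

    async def demo(graph):
        out1 = await graph._run_tool(lambda x: x + 1, x=41)    # -> {"result": 42}
        out2 = await graph._run_tool(lambda: {"answer": 42})   # dict results pass through unchanged
        return out1, out2
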
+    # -------- Print and Debug ---------
+    def pretty(self, *, max_nodes: int = 20, max_width: int = 100) -> str:
+        """
+        Human-friendly summary of this TaskGraph.
+        """
+        lines: list[str] = []
+
+        # Header
+        lines.append(
+            f"TaskGraph[{self.spec.graph_id}] "
+            f"nodes={len(self.spec.nodes)} "
+            f"observers={len(self.observers)}"
+        )
+
+        # IO signature
+        lines.append("IO Signature:")
+        for s in self.spec.io_summary_lines():
+            lines.append(f" {s}")
+
+        # State summary
+        lines.append(f"State: {self.state.summary_line()}")
+
+        # Nodes table (compact)
+        lines.append("Nodes:")
+        header = f"{'id':<22} {'type':<10} {'status':<12} {'#deps':<5} logic"
+        lines.append(" " + header)
+        lines.append(" " + "-" * (len(header) + 4))
+
+        def _safe_get(node, attr, default):
+            return getattr(node, attr, default)
+
+        n_items = list(self.spec.nodes.items())
+        for idx, (nid, node) in enumerate(n_items):
+            if idx >= max_nodes:
+                lines.append(f" … ({len(n_items) - max_nodes} more)")
+                break
+
+            ntype = _safe_get(node, "node_type", "?")
+            status = _status_label(self.state.nodes.get(nid, TaskNodeState()).status)
+            deps = _safe_get(node, "dependencies", None) or []
+            logic = _logic_label(_safe_get(node, "logic", None))
+
+            # Width control: keep table tidy
+            row = f" {_short(nid, 22):<22} {_short(ntype, 10):<10} {_short(status, 12):<12} {len(deps):<5} {_short(logic, max_width)}"
+            lines.append(row)
+
+        return "\n".join(lines)
+
+    # Optional: make print(graph) show a compact version
+    def __str__(self) -> str:
+        return self.pretty(max_nodes=12, max_width=96)
+
+    # -------- Persistence conveniences (opt-in) --------
+    def spec_json(self) -> dict[str, Any]:
+        """
+        JSON-safe representation of the graph spec.
+        Keeps TaskGraph storage-agnostic; callers can write to file/db/etc.
+        """
+        return _dataclass_to_plain(self.spec)
+
+
+# --------- Visualization ---------
+TaskGraph.to_dot = to_dot
+TaskGraph.visualize = visualize
+TaskGraph.ascii_overview = ascii_overview
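The final three assignments attach the functions imported from .visualize as methods, so every TaskGraph instance also exposes to_dot(), visualize(), and ascii_overview() even though they are not defined in the class body. Their signatures live in aethergraph/core/graph/visualize.py (not part of this diff), so the calls below assume they take no required arguments beyond the instance:

    dot = graph.to_dot()              # assumed: Graphviz DOT for the node/dependency structure
    print(graph.ascii_overview())     # assumed: terminal-friendly overview of the same graph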