aethergraph 0.1.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (182) hide show
  1. aethergraph/__init__.py +49 -0
  2. aethergraph/config/__init__.py +0 -0
  3. aethergraph/config/config.py +121 -0
  4. aethergraph/config/context.py +16 -0
  5. aethergraph/config/llm.py +26 -0
  6. aethergraph/config/loader.py +60 -0
  7. aethergraph/config/runtime.py +9 -0
  8. aethergraph/contracts/errors/errors.py +44 -0
  9. aethergraph/contracts/services/artifacts.py +142 -0
  10. aethergraph/contracts/services/channel.py +72 -0
  11. aethergraph/contracts/services/continuations.py +23 -0
  12. aethergraph/contracts/services/eventbus.py +12 -0
  13. aethergraph/contracts/services/kv.py +24 -0
  14. aethergraph/contracts/services/llm.py +17 -0
  15. aethergraph/contracts/services/mcp.py +22 -0
  16. aethergraph/contracts/services/memory.py +108 -0
  17. aethergraph/contracts/services/resume.py +28 -0
  18. aethergraph/contracts/services/state_stores.py +33 -0
  19. aethergraph/contracts/services/wakeup.py +28 -0
  20. aethergraph/core/execution/base_scheduler.py +77 -0
  21. aethergraph/core/execution/forward_scheduler.py +777 -0
  22. aethergraph/core/execution/global_scheduler.py +634 -0
  23. aethergraph/core/execution/retry_policy.py +22 -0
  24. aethergraph/core/execution/step_forward.py +411 -0
  25. aethergraph/core/execution/step_result.py +18 -0
  26. aethergraph/core/execution/wait_types.py +72 -0
  27. aethergraph/core/graph/graph_builder.py +192 -0
  28. aethergraph/core/graph/graph_fn.py +219 -0
  29. aethergraph/core/graph/graph_io.py +67 -0
  30. aethergraph/core/graph/graph_refs.py +154 -0
  31. aethergraph/core/graph/graph_spec.py +115 -0
  32. aethergraph/core/graph/graph_state.py +59 -0
  33. aethergraph/core/graph/graphify.py +128 -0
  34. aethergraph/core/graph/interpreter.py +145 -0
  35. aethergraph/core/graph/node_handle.py +33 -0
  36. aethergraph/core/graph/node_spec.py +46 -0
  37. aethergraph/core/graph/node_state.py +63 -0
  38. aethergraph/core/graph/task_graph.py +747 -0
  39. aethergraph/core/graph/task_node.py +82 -0
  40. aethergraph/core/graph/utils.py +37 -0
  41. aethergraph/core/graph/visualize.py +239 -0
  42. aethergraph/core/runtime/ad_hoc_context.py +61 -0
  43. aethergraph/core/runtime/base_service.py +153 -0
  44. aethergraph/core/runtime/bind_adapter.py +42 -0
  45. aethergraph/core/runtime/bound_memory.py +69 -0
  46. aethergraph/core/runtime/execution_context.py +220 -0
  47. aethergraph/core/runtime/graph_runner.py +349 -0
  48. aethergraph/core/runtime/lifecycle.py +26 -0
  49. aethergraph/core/runtime/node_context.py +203 -0
  50. aethergraph/core/runtime/node_services.py +30 -0
  51. aethergraph/core/runtime/recovery.py +159 -0
  52. aethergraph/core/runtime/run_registration.py +33 -0
  53. aethergraph/core/runtime/runtime_env.py +157 -0
  54. aethergraph/core/runtime/runtime_registry.py +32 -0
  55. aethergraph/core/runtime/runtime_services.py +224 -0
  56. aethergraph/core/runtime/wakeup_watcher.py +40 -0
  57. aethergraph/core/tools/__init__.py +10 -0
  58. aethergraph/core/tools/builtins/channel_tools.py +194 -0
  59. aethergraph/core/tools/builtins/toolset.py +134 -0
  60. aethergraph/core/tools/toolkit.py +510 -0
  61. aethergraph/core/tools/waitable.py +109 -0
  62. aethergraph/plugins/channel/__init__.py +0 -0
  63. aethergraph/plugins/channel/adapters/__init__.py +0 -0
  64. aethergraph/plugins/channel/adapters/console.py +106 -0
  65. aethergraph/plugins/channel/adapters/file.py +102 -0
  66. aethergraph/plugins/channel/adapters/slack.py +285 -0
  67. aethergraph/plugins/channel/adapters/telegram.py +302 -0
  68. aethergraph/plugins/channel/adapters/webhook.py +104 -0
  69. aethergraph/plugins/channel/adapters/webui.py +134 -0
  70. aethergraph/plugins/channel/routes/__init__.py +0 -0
  71. aethergraph/plugins/channel/routes/console_routes.py +86 -0
  72. aethergraph/plugins/channel/routes/slack_routes.py +49 -0
  73. aethergraph/plugins/channel/routes/telegram_routes.py +26 -0
  74. aethergraph/plugins/channel/routes/webui_routes.py +136 -0
  75. aethergraph/plugins/channel/utils/__init__.py +0 -0
  76. aethergraph/plugins/channel/utils/slack_utils.py +278 -0
  77. aethergraph/plugins/channel/utils/telegram_utils.py +324 -0
  78. aethergraph/plugins/channel/websockets/slack_ws.py +68 -0
  79. aethergraph/plugins/channel/websockets/telegram_polling.py +151 -0
  80. aethergraph/plugins/mcp/fs_server.py +128 -0
  81. aethergraph/plugins/mcp/http_server.py +101 -0
  82. aethergraph/plugins/mcp/ws_server.py +180 -0
  83. aethergraph/plugins/net/http.py +10 -0
  84. aethergraph/plugins/utils/data_io.py +359 -0
  85. aethergraph/runner/__init__.py +5 -0
  86. aethergraph/runtime/__init__.py +62 -0
  87. aethergraph/server/__init__.py +3 -0
  88. aethergraph/server/app_factory.py +84 -0
  89. aethergraph/server/start.py +122 -0
  90. aethergraph/services/__init__.py +10 -0
  91. aethergraph/services/artifacts/facade.py +284 -0
  92. aethergraph/services/artifacts/factory.py +35 -0
  93. aethergraph/services/artifacts/fs_store.py +656 -0
  94. aethergraph/services/artifacts/jsonl_index.py +123 -0
  95. aethergraph/services/artifacts/paths.py +23 -0
  96. aethergraph/services/artifacts/sqlite_index.py +209 -0
  97. aethergraph/services/artifacts/utils.py +124 -0
  98. aethergraph/services/auth/dev.py +16 -0
  99. aethergraph/services/channel/channel_bus.py +293 -0
  100. aethergraph/services/channel/factory.py +44 -0
  101. aethergraph/services/channel/session.py +511 -0
  102. aethergraph/services/channel/wait_helpers.py +57 -0
  103. aethergraph/services/clock/clock.py +9 -0
  104. aethergraph/services/container/default_container.py +320 -0
  105. aethergraph/services/continuations/continuation.py +56 -0
  106. aethergraph/services/continuations/factory.py +34 -0
  107. aethergraph/services/continuations/stores/fs_store.py +264 -0
  108. aethergraph/services/continuations/stores/inmem_store.py +95 -0
  109. aethergraph/services/eventbus/inmem.py +21 -0
  110. aethergraph/services/features/static.py +10 -0
  111. aethergraph/services/kv/ephemeral.py +90 -0
  112. aethergraph/services/kv/factory.py +27 -0
  113. aethergraph/services/kv/layered.py +41 -0
  114. aethergraph/services/kv/sqlite_kv.py +128 -0
  115. aethergraph/services/llm/factory.py +157 -0
  116. aethergraph/services/llm/generic_client.py +542 -0
  117. aethergraph/services/llm/providers.py +3 -0
  118. aethergraph/services/llm/service.py +105 -0
  119. aethergraph/services/logger/base.py +36 -0
  120. aethergraph/services/logger/compat.py +50 -0
  121. aethergraph/services/logger/formatters.py +106 -0
  122. aethergraph/services/logger/std.py +203 -0
  123. aethergraph/services/mcp/helpers.py +23 -0
  124. aethergraph/services/mcp/http_client.py +70 -0
  125. aethergraph/services/mcp/mcp_tools.py +21 -0
  126. aethergraph/services/mcp/registry.py +14 -0
  127. aethergraph/services/mcp/service.py +100 -0
  128. aethergraph/services/mcp/stdio_client.py +70 -0
  129. aethergraph/services/mcp/ws_client.py +115 -0
  130. aethergraph/services/memory/bound.py +106 -0
  131. aethergraph/services/memory/distillers/episode.py +116 -0
  132. aethergraph/services/memory/distillers/rolling.py +74 -0
  133. aethergraph/services/memory/facade.py +633 -0
  134. aethergraph/services/memory/factory.py +78 -0
  135. aethergraph/services/memory/hotlog_kv.py +27 -0
  136. aethergraph/services/memory/indices.py +74 -0
  137. aethergraph/services/memory/io_helpers.py +72 -0
  138. aethergraph/services/memory/persist_fs.py +40 -0
  139. aethergraph/services/memory/resolver.py +152 -0
  140. aethergraph/services/metering/noop.py +4 -0
  141. aethergraph/services/prompts/file_store.py +41 -0
  142. aethergraph/services/rag/chunker.py +29 -0
  143. aethergraph/services/rag/facade.py +593 -0
  144. aethergraph/services/rag/index/base.py +27 -0
  145. aethergraph/services/rag/index/faiss_index.py +121 -0
  146. aethergraph/services/rag/index/sqlite_index.py +134 -0
  147. aethergraph/services/rag/index_factory.py +52 -0
  148. aethergraph/services/rag/parsers/md.py +7 -0
  149. aethergraph/services/rag/parsers/pdf.py +14 -0
  150. aethergraph/services/rag/parsers/txt.py +7 -0
  151. aethergraph/services/rag/utils/hybrid.py +39 -0
  152. aethergraph/services/rag/utils/make_fs_key.py +62 -0
  153. aethergraph/services/redactor/simple.py +16 -0
  154. aethergraph/services/registry/key_parsing.py +44 -0
  155. aethergraph/services/registry/registry_key.py +19 -0
  156. aethergraph/services/registry/unified_registry.py +185 -0
  157. aethergraph/services/resume/multi_scheduler_resume_bus.py +65 -0
  158. aethergraph/services/resume/router.py +73 -0
  159. aethergraph/services/schedulers/registry.py +41 -0
  160. aethergraph/services/secrets/base.py +7 -0
  161. aethergraph/services/secrets/env.py +8 -0
  162. aethergraph/services/state_stores/externalize.py +135 -0
  163. aethergraph/services/state_stores/graph_observer.py +131 -0
  164. aethergraph/services/state_stores/json_store.py +67 -0
  165. aethergraph/services/state_stores/resume_policy.py +119 -0
  166. aethergraph/services/state_stores/serialize.py +249 -0
  167. aethergraph/services/state_stores/utils.py +91 -0
  168. aethergraph/services/state_stores/validate.py +78 -0
  169. aethergraph/services/tracing/noop.py +18 -0
  170. aethergraph/services/waits/wait_registry.py +91 -0
  171. aethergraph/services/wakeup/memory_queue.py +57 -0
  172. aethergraph/services/wakeup/scanner_producer.py +56 -0
  173. aethergraph/services/wakeup/worker.py +31 -0
  174. aethergraph/tools/__init__.py +25 -0
  175. aethergraph/utils/optdeps.py +8 -0
  176. aethergraph-0.1.0a1.dist-info/METADATA +410 -0
  177. aethergraph-0.1.0a1.dist-info/RECORD +182 -0
  178. aethergraph-0.1.0a1.dist-info/WHEEL +5 -0
  179. aethergraph-0.1.0a1.dist-info/entry_points.txt +2 -0
  180. aethergraph-0.1.0a1.dist-info/licenses/LICENSE +176 -0
  181. aethergraph-0.1.0a1.dist-info/licenses/NOTICE +31 -0
  182. aethergraph-0.1.0a1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,747 @@
1
+ from __future__ import annotations
2
+
3
+ from collections.abc import Iterable
4
+ from dataclasses import asdict, dataclass, field, is_dataclass
5
+ import inspect
6
+ from typing import Any
7
+ import uuid
8
+
9
+ from .graph_refs import GRAPH_INPUTS_NODE_ID, Ref, normalize_binding, resolve_binding
10
+ from .graph_spec import GraphView, TaskGraphSpec
11
+ from .graph_state import GraphPatch, TaskGraphState
12
+ from .node_spec import TaskNodeSpec
13
+ from .node_state import NodeStatus, TaskNodeState
14
+ from .task_node import TaskNodeRuntime
15
+ from .utils import _logic_label, _short, _status_label
16
+ from .visualize import ascii_overview, to_dot, visualize
17
+
18
+
19
+ # small helper to turn a dataclass (spec) into plain dict safely
20
def _dataclass_to_plain(d):
    """Return ``asdict(d)`` when ``is_dataclass(d)`` holds; otherwise hand ``d`` back untouched."""
    if is_dataclass(d):
        return asdict(d)
    return d
22
+
23
+
24
+ @dataclass
25
+ class TaskGraph:
26
+ spec: TaskGraphSpec
27
+ state: TaskGraphState = field(
28
+ default_factory=TaskGraphState
29
+ ) # mutable state, including node states, patches
30
+ observers: list[Any] = field(default_factory=list)
31
+
32
+ # Expose graph_id as convenient alias; source of truth is spec.graph_id
33
+ graph_id: str = field(init=False, repr=True)
34
+
35
+ # Ephemeral runtime table (not serialized)
36
+ _runtime_nodes: dict[str, TaskNodeRuntime] = field(default_factory=dict, init=False, repr=False)
37
+
38
+ # Inverted indexes for quick lookup (by alias, logic, label); ephemeral
39
+ _idx_ready: bool = field(default=False, init=False, repr=False)
40
+ _by_alias: dict[str, str] = field(
41
+ default_factory=dict, init=False, repr=False
42
+ ) # alias -> node_id
43
+ _by_logic: dict[str, list[str]] = field(
44
+ default_factory=dict, init=False, repr=False
45
+ ) # logic -> [node_id, ...]
46
+ _by_label: dict[str, list[str]] = field(
47
+ default_factory=dict, init=False, repr=False
48
+ ) # label -> {node_id, ...}
49
+ _by_name: dict[str, list[str]] = field(
50
+ default_factory=dict, init=False, repr=False
51
+ ) # display name -> [node_id, ...]: TODO: decided if we need this
52
+
53
@classmethod
def new_run(cls, spec: TaskGraphSpec, *, run_id: str | None = None, **kwargs) -> TaskGraph:
    """Create a TaskGraph for a fresh run of ``spec``.

    Args:
        spec: Graph specification to instantiate.
        run_id: Identifier for the run; a random UUID4 string is generated
            when omitted or empty.
        **kwargs: Forwarded verbatim to ``from_spec``. NOTE(review): the
            ``from_spec`` visible in this file accepts only ``state``, so any
            extra keyword currently raises ``TypeError`` there.

    Returns:
        The graph produced by ``from_spec`` with one fresh ``TaskNodeState``
        per node id in ``spec.nodes``.
    """
    run_id = run_id or str(uuid.uuid4())
    # Every node declared in the spec starts from a default node state.
    nodes = {nid: TaskNodeState() for nid in spec.nodes}
    state = TaskGraphState(run_id=run_id, nodes=nodes)
    return cls.from_spec(spec, state=state, **kwargs)
62
+
63
@classmethod
def from_spec(cls, spec: TaskGraphSpec, *, state: TaskGraphState | None = None):
    """Create a TaskGraph from ``spec`` and an optional pre-existing ``state``.

    A fresh ``TaskGraphState`` is used when ``state`` is falsy. Before the
    graph is returned, the graph-inputs node is ensured and its status is
    set to ``NodeStatus.DONE``.
    """
    graph = cls(spec=spec, state=state or TaskGraphState())
    # Set back-references in nodes
    for node in graph.spec.nodes.values():
        node._parent_graph = graph
    # Make sure the inputs node is in spec.nodes before the runtime table is rebuilt.
    graph.ensure_inputs_node()
    # Re-run post-init: it rebuilds the runtime node table and the lookup
    # indexes, so they now include the inputs node.
    graph.__post_init__()
    graph.ensure_inputs_node()

    # Set the inputs node state to DONE
    input_node = graph.node(GRAPH_INPUTS_NODE_ID)
    input_node.state.status = NodeStatus.DONE
    return graph
78
+
79
+ # Public read-only view
80
def node(self, node_id: str) -> TaskNodeRuntime:
    """Look up the runtime node registered under ``node_id`` (``KeyError`` if absent)."""
    runtime_table = self._runtime_nodes
    return runtime_table[node_id]
82
+
83
@property
def nodes(self) -> list[TaskNodeRuntime]:
    """All runtime nodes, as a new list in the runtime table's iteration order."""
    return [*self._runtime_nodes.values()]
86
+
87
def _apply_patches(self) -> dict[str, TaskNodeSpec]:
    """Compute the patched node-spec mapping used for the runtime *view*.

    ``self.spec.nodes`` is never modified: patches from ``self.state.patches``
    are replayed, in order, onto a shallow copy of the mapping.

    Returns:
        Mapping of node id to the effective ``TaskNodeSpec``.

    Raises:
        KeyError: an ``add_dependency`` patch targets a node id that is not
            present at that point of the replay.
    """
    node_specs = dict(self.spec.nodes)  # shallow copy of mapping

    # The following is used when graph mutations are supported. It is just a sketch now.
    for p in self.state.patches:
        if p.op == "add_or_replace_node":
            ns = TaskNodeSpec(**p.payload)  # validate payload
            node_specs[ns.node_id] = ns
        elif p.op == "remove_node":
            node_specs.pop(p.payload["node_id"], None)
        elif p.op == "add_dependency":
            nid = p.payload["node_id"]
            dep = p.payload["dependency_id"]
            old = node_specs.get(nid)
            if old is None:
                # Fail with a clear message instead of an AttributeError on None.
                raise KeyError(f"Cannot add dependency '{dep}': node '{nid}' not found in graph")
            # create a new spec with updated deps rather than mutating the old one
            node_specs[nid] = TaskNodeSpec(
                **{**old.__dict__, "dependencies": [*old.dependencies, dep]}
            )
        # TODO: add more patch types (unknown ops are ignored for now)

    return node_specs
112
+
113
def _reify_runtime_nodes(self):
    """Rebuild the runtime node table from the patched specs.

    For every effective spec (after ``_apply_patches``) a ``TaskNodeRuntime``
    is created that shares the node's state object from ``self.state.nodes``.
    Nodes without a state entry get a fresh ``TaskNodeState`` that is stored
    back into ``self.state.nodes``.

    The lookup indexes are derived from the runtime table, so they are marked
    stale here; the next ``_ensure_index()`` call rebuilds them.
    """
    effective_specs = self._apply_patches()  # get patched specs
    table = {}
    for nid, nspec in effective_specs.items():
        nstate = self.state.nodes.get(nid)
        if nstate is None:
            nstate = TaskNodeState()  # in case a patch adds a node
            self.state.nodes[nid] = nstate  # persist its state in TaskGraphState
        table[nid] = TaskNodeRuntime(spec=nspec, state=nstate, _parent_graph=self)
    self._runtime_nodes = table
    # The table was replaced: alias/label/logic/name indexes no longer match it.
    self._idx_ready = False
124
+
125
def __post_init__(self):
    """Finish construction: set ``graph_id``, seed node states, build runtime nodes and indexes."""
    # establish graph_id as alias to spec.graph_id
    self.graph_id = self.spec.graph_id

    # seed one default state per spec node when the state carries no nodes yet
    if not getattr(self.state, "nodes", None):
        self.state.nodes = {
            nid: TaskNodeState() for nid in self.spec.nodes
        }  # GraphSpec.nodes is Dict[str, TaskNodeSpec]

    # build the runtime node table (runtime nodes reference this graph)
    self._reify_runtime_nodes()

    # index for quick lookup
    self._reindex()
140
+
141
def _reindex(self):
    """Rebuild the alias / logic / label / display-name indexes from the runtime nodes.

    All four indexes are emptied first, then refilled from each node's spec
    metadata (``alias``, ``labels``, ``display_name``) and its logic name
    (``tool_name`` if set, else ``logic`` when it is a string, else
    ``_short(logic)``). Label buckets are sets; the others are lists.
    """
    for index in (self._by_alias, self._by_logic, self._by_label, self._by_name):
        index.clear()

    for node_id, runtime in self._runtime_nodes.items():
        spec = runtime.spec
        meta = getattr(spec, "metadata", {}) or {}
        alias = meta.get("alias")
        labels = meta.get("labels", [])
        display = meta.get("display_name")

        logic_name = spec.tool_name
        if not logic_name:
            logic_name = spec.logic if isinstance(spec.logic, str) else _short(spec.logic)

        if alias:
            self._by_alias[alias] = node_id
        for label in labels:
            self._by_label.setdefault(label, set()).add(node_id)
        if logic_name:
            self._by_logic.setdefault(logic_name, []).append(node_id)
        if display:
            self._by_name.setdefault(display, []).append(node_id)

    self._idx_ready = True
165
+
166
def _ensure_index(self):
    """Rebuild the lookup indexes, but only when they are flagged as stale."""
    if self._idx_ready:
        return
    self._reindex()
169
+
170
+ # Call when mutate spec.nodes in PatchFlow
171
def index_touch(self):
    """Flag the lookup indexes as stale so the next ``_ensure_index()`` rebuilds them."""
    self._idx_ready = False
173
+
174
+ # Node access
175
def node_ids(self) -> list[str]:
    """Return the ids of all runtime nodes as a new list."""
    return [*self._runtime_nodes]
178
+
179
+ # Node finder
180
def get_by_id(self, node_id: str) -> str:
    """Validate that ``node_id`` names a runtime node and return it unchanged.

    Raises:
        ValueError: the id is not in the runtime node table.
    """
    if node_id in self._runtime_nodes:
        return node_id
    raise ValueError(f"Node ID '{node_id}' not found in graph '{self.graph_id}'")
185
+
186
def get_by_alias(self, alias: str) -> str | None:
    """Resolve ``alias`` to its node id.

    Raises:
        KeyError: no node is registered under ``alias``.
    """
    self._ensure_index()
    found = self._by_alias.get(alias)
    if found:
        return found
    raise KeyError(f"Alias '{alias}' not found in graph '{self.graph_id}'")
193
+
194
def find_by_label(self, label: str) -> list[str]:
    """Return the sorted node ids tagged with ``label`` (empty list when none)."""
    self._ensure_index()
    tagged = list(self._by_label.get(label, set()))
    tagged.sort()
    return tagged
198
+
199
def find_by_logic(self, logic_prefix: str, *, first: bool = False) -> list[str] | str | None:
    """Find node ids by logic name.

    An exact logic-name match wins and its ids are returned in index order.
    Otherwise every logic name starting with ``logic_prefix`` contributes its
    ids and the combined result is sorted.

    With ``first=True`` a single id (or ``None`` when nothing matches) is
    returned instead of a list.

    Usage:
        graph.find_by_logic("my_tool")               # all nodes with logic "my_tool"
        graph.find_by_logic("my_tool", first=True)   # first such node, or None
        graph.find_by_logic("my_tool_v")             # all nodes whose logic starts with "my_tool_v"
    """
    self._ensure_index()
    exact = self._by_logic.get(logic_prefix) if logic_prefix in self._by_logic else None
    if exact is not None:
        ids = list(exact)
    else:
        ids = [
            nid
            for name, bucket in self._by_logic.items()
            if name.startswith(logic_prefix)
            for nid in bucket
        ]
        ids.sort()

    if first:
        # A falsy first id collapses to None, exactly as an empty result does.
        return (ids[0] or None) if ids else None
    return ids or []
219
+
220
def find_by_display(self, name_prefix: str, *, first: bool = False) -> list[str] | str | None:
    """Find node ids by display name.

    An exact display-name match wins and its ids are returned in index order.
    Otherwise every display name starting with ``name_prefix`` contributes its
    ids and the combined result is sorted.

    With ``first=True`` a single id (or ``None`` when nothing matches) is
    returned instead of a list.

    Usage:
        graph.find_by_display("My Node")               # all nodes named "My Node"
        graph.find_by_display("My Node", first=True)   # first such node, or None
        graph.find_by_display("My Node V")             # all nodes whose name starts with "My Node V"
    """
    self._ensure_index()
    exact = self._by_name.get(name_prefix) if name_prefix in self._by_name else None
    if exact is not None:
        ids = list(exact)
    else:
        ids = [
            nid
            for name, bucket in self._by_name.items()
            if name.startswith(name_prefix)
            for nid in bucket
        ]
        ids.sort()

    if first:
        # A falsy first id collapses to None, exactly as an empty result does.
        return (ids[0] or None) if ids else None
    return ids or []
240
+
241
+ # ---------- Unified selector ----------
242
+ # Mini-DSL:
243
+ # "@alias" -> by alias
244
+ # "#label" -> by label (many)
245
+ # "id:<id>" -> exact id
246
+ # "logic:<pref>" -> logic name prefix
247
+ # "name:<pref>" -> display name prefix
248
+ # "/regex/" -> regex on node_id
249
+
250
def select(self, selector: str, *, first: bool = False) -> str | list[str] | None:
    """Resolve a selector string to node id(s).

    Supported forms (after stripping surrounding whitespace):
        "@alias"        -> single id via ``get_by_alias``
        "#label"        -> ids via ``find_by_label``
        "id:<id>"       -> single id via ``get_by_id``
        "logic:<pref>"  -> ``find_by_logic`` (``first`` is forwarded)
        "name:<pref>"   -> ``find_by_display`` (``first`` is forwarded)
        "/regex/"       -> sorted ids where the regex ``search`` matches
        anything else   -> sorted ids starting with the selector

    For the label, regex and prefix forms, ``first=True`` returns the first id
    when there is a match and the (empty) list otherwise.
    """
    selector = selector.strip()

    def _maybe_first(ids):
        return ids[0] if (first and ids) else ids

    if selector.startswith("@"):
        return self.get_by_alias(selector[1:])
    if selector.startswith("#"):
        return _maybe_first(self.find_by_label(selector[1:]))
    if selector.startswith("id:"):
        return self.get_by_id(selector[3:])
    if selector.startswith("logic:"):
        return self.find_by_logic(selector[6:], first=first)
    if selector.startswith("name:"):
        return self.find_by_display(selector[5:], first=first)

    is_regex = len(selector) >= 2 and selector[0] == "/" and selector[-1] == "/"
    if is_regex:
        import re

        pattern = re.compile(selector[1:-1])
        return _maybe_first(sorted(nid for nid in self.node_ids() if pattern.search(nid)))

    # fallback: prefix on node_id
    return _maybe_first(sorted(nid for nid in self.node_ids() if nid.startswith(selector)))
278
+
279
def pick_one(self, selector: str) -> str | None:
    """Resolve ``selector`` to a single node id.

    Raises:
        KeyError: the selector resolves to nothing (an empty or falsy result).
    """
    chosen = self.select(selector, first=True)
    if chosen:
        return chosen
    raise KeyError(f"No node found for selector '{selector}' in graph '{self.graph_id}'")
285
+
286
def pick_all(self, selector: str) -> list[str]:
    """Resolve ``selector`` to a list of node ids; a single-id result is wrapped in a list."""
    matched = self.select(selector, first=False)
    if isinstance(matched, str):
        matched = [matched]
    elif not matched:
        matched = []
    return matched
292
+
293
+ # --------- Read-only views ---------
294
def view(self) -> GraphView:
    """Build a ``GraphView`` from the spec's id, nodes and metadata plus the state's node statuses."""
    spec = self.spec
    return GraphView(
        graph_id=spec.graph_id,
        nodes=spec.nodes,
        # node_status is read from the graph state, not from the spec
        node_status=self.state.node_status,
        metadata=spec.metadata,
    )
302
+
303
+ # -------- Graph mutation APIs ---------
304
def patch_add_or_replace_node(self, node_spec: dict[str, Any]):
    """Record an ``add_or_replace_node`` patch, notify observers, and rebuild runtime nodes.

    ``node_spec`` becomes the patch payload. The state revision is bumped by
    one, and each observer exposing ``on_patch_applied`` is called
    synchronously with ``(graph, patch)``; its return value is ignored.
    """
    patch = GraphPatch(op="add_or_replace_node", payload=node_spec)
    state = self.state
    state.patches.append(patch)
    state.rev += 1

    callbacks = (getattr(observer, "on_patch_applied", None) for observer in self.observers)
    for notify in callbacks:
        if notify:
            notify(self, patch)

    self._reify_runtime_nodes()
318
+
319
def patch_remove_node(self, node_id: str):
    """Record a ``remove_node`` patch for ``node_id``, notify observers, and rebuild runtime nodes.

    The state revision is bumped by one, and each observer exposing
    ``on_patch_applied`` is called synchronously with ``(graph, patch)``;
    its return value is ignored.
    """
    patch = GraphPatch(op="remove_node", payload={"node_id": node_id})
    state = self.state
    state.patches.append(patch)
    state.rev += 1

    callbacks = (getattr(observer, "on_patch_applied", None) for observer in self.observers)
    for notify in callbacks:
        if notify:
            notify(self, patch)

    self._reify_runtime_nodes()
332
+
333
def patch_add_dependency(self, node_id: str, dependency_id: str):
    """Record an ``add_dependency`` patch, notify observers, and rebuild runtime nodes.

    The patch payload carries ``node_id`` and ``dependency_id``. The state
    revision is bumped by one, and each observer exposing ``on_patch_applied``
    is called synchronously with ``(graph, patch)``; its return value is ignored.
    """
    payload = {"node_id": node_id, "dependency_id": dependency_id}
    patch = GraphPatch(op="add_dependency", payload=payload)
    state = self.state
    state.patches.append(patch)
    state.rev += 1

    callbacks = (getattr(observer, "on_patch_applied", None) for observer in self.observers)
    for notify in callbacks:
        if notify:
            notify(self, patch)

    self._reify_runtime_nodes()
348
+
349
+ # --------- Introspection APIs ---------
350
def list_nodes(self, exclude_internal=True) -> list[str]:
    """List node ids from the spec.

    With ``exclude_internal`` (the default) ids starting with ``"_"`` are left out.
    """
    if exclude_internal:
        return [nid for nid in self.spec.nodes if not nid.startswith("_")]
    return list(self.spec.nodes.keys())
357
+
358
+ # --------- Topology helpers ---------
359
def dependents(self, node_id: str) -> list[str]:
    """Return ids of the spec nodes that list ``node_id`` among their dependencies."""
    downstream = []
    for candidate in self.spec.nodes.values():
        if node_id in candidate.dependencies:
            downstream.append(candidate.node_id)
    return downstream
362
+
363
def topological_order(self) -> list[str]:
    """Return the spec's node ids in a topological order.

    Raises:
        ValueError: the dependency graph contains a cycle.
    """
    # Imported lazily so networkx is only required when ordering is requested.
    import networkx as nx

    G = nx.DiGraph()
    for n in self.spec.nodes.values():
        G.add_node(n.node_id)
        # Edges point from a dependency to the node that needs it.
        for dep in n.dependencies:
            G.add_edge(dep, n.node_id)
    try:
        order = list(nx.topological_sort(G))
        return order
    except nx.NetworkXUnfeasible:
        # ``from None`` hides the networkx traceback from callers.
        raise ValueError(
            "Graph has at least one cycle; topological sort not possible."
        ) from None
379
+
380
def get_subgraph_nodes(self, start_node_id: str) -> list[str]:
    """Collect ``start_node_id`` plus every node reachable by following ``dependents``.

    The result comes from a set, so its order is not meaningful.
    """
    visited = set()
    frontier = [start_node_id]
    while frontier:
        current = frontier.pop()
        if current not in visited:
            visited.add(current)
            frontier += self.dependents(current)
    return list(visited)
390
+
391
def get_upstream_nodes(self, start_node_id: str) -> list[str]:
    """Collect ``start_node_id`` plus everything it transitively depends on.

    Dependencies are read from ``self.spec.nodes``; the result comes from a
    set, so its order is not meaningful.
    """
    visited = set()
    frontier = [start_node_id]
    while frontier:
        current = frontier.pop()
        if current not in visited:
            visited.add(current)
            frontier += self.spec.nodes[current].dependencies
    return list(visited)
401
+
402
+ # --------- State mutation APIs ---------
403
async def set_status(self, node_id: str, status: NodeStatus):
    """Placeholder: always raises ``NotImplementedError`` pointing callers at ``set_node_status()``."""
    message = "set_status() is not implemented yet. Use set_node_status() instead."
    raise NotImplementedError(message)
408
+
409
async def set_outputs(self, node_id: str, outputs: dict[str, Any]):
    """Placeholder: always raises ``NotImplementedError`` pointing callers at ``set_node_outputs()``."""
    message = "set_outputs() is not implemented yet. Use set_node_outputs() instead."
    raise NotImplementedError(message)
414
+
415
async def set_node_status(self, node_id: str, status: NodeStatus) -> None:
    """Set a node's status, bump the state revision, and notify observers.

    Nothing happens (no revision bump, no notification) when the node already
    holds exactly this status object.

    Raises:
        ValueError: ``node_id`` has no entry in ``self.state.nodes``.
    """
    state = self.state.nodes.get(node_id)
    if state is None:
        # Fail clearly instead of an AttributeError on None below.
        raise ValueError(f"Node with id {node_id} does not exist in the graph.")
    if state.status is status:
        return
    state.status = status
    self.state.rev += 1
    await self._notify_status_change(node_id)
422
+
423
async def set_node_outputs(self, node_id: str, outputs: dict[str, Any]) -> None:
    """Replace a node's outputs, bump the state revision, and notify observers.

    Raises:
        ValueError: ``node_id`` has no entry in ``self.state.nodes``.
    """
    state = self.state.nodes.get(node_id)
    if state is None:
        # Fail clearly instead of an AttributeError on None below.
        raise ValueError(f"Node with id {node_id} does not exist in the graph.")
    state.outputs = outputs
    self.state.rev += 1
    await self._notify_output_change(node_id)
428
+
429
async def _notify_status_change(self, node_id: str):
    """Call every observer's ``on_node_status_change`` with the node's runtime view.

    Observers without that attribute are skipped. A result exposing
    ``__await__`` is awaited before moving on to the next observer.
    """
    node_rt = self._runtime_nodes.get(node_id)  # runtime view points at same state object
    for observer in self.observers:
        handler = getattr(observer, "on_node_status_change", None)
        if not handler:
            continue
        result = handler(node_rt)
        if hasattr(result, "__await__"):
            await result
437
+
438
def _notify_inputs_bound(self):
    """Call every observer's ``on_inputs_bound(graph)`` hook.

    Synchronous hooks run inline. When a hook returns an awaitable it is no
    longer dropped:

    * inside a running event loop it is scheduled as a background task
      (fire-and-forget; a reference is kept until it finishes so it cannot
      be garbage-collected mid-flight);
    * with no running loop, a coroutine is driven to completion with
      ``asyncio.run`` so the notification is not lost. Other awaitable kinds
      cannot be run without a loop and are skipped.
    """
    import asyncio  # local import: only needed for awaitable observer results

    for obs in self.observers:
        cb = getattr(obs, "on_inputs_bound", None)
        if not cb:
            continue
        out = cb(self)
        if not hasattr(out, "__await__"):
            continue

        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No loop in this thread: run the coroutine now, in observer order.
            if inspect.iscoroutine(out):
                asyncio.run(out)
            continue

        task = asyncio.ensure_future(out)
        pending = self.__dict__.setdefault("_inputs_bound_tasks", set())
        pending.add(task)
        task.add_done_callback(pending.discard)
447
+
448
async def _notify_output_change(self, node_id: str):
    """Call every observer's ``on_node_output_change`` with the node's runtime view.

    Observers without that attribute are skipped. A result exposing
    ``__await__`` is awaited before moving on to the next observer.
    """
    node_rt = self._runtime_nodes.get(node_id)  # runtime view points at same state object
    for observer in self.observers:
        handler = getattr(observer, "on_node_output_change", None)
        if not handler:
            continue
        result = handler(node_rt)
        if hasattr(result, "__await__"):
            await result
456
+
457
+ # --------- Reset paths ---------
458
async def reset_node(self, node_id: str, *, preserve_outputs: bool = False):
    """Reset one node by delegating to its state's ``reset_node``.

    Args:
        node_id: Node to reset; must exist in the spec.
        preserve_outputs: Forwarded to the node state's ``reset_node``.

    Raises:
        ValueError: the node is unknown, or it is the special graph-inputs node.
    """
    known = node_id in self.spec.nodes
    if not known:
        raise ValueError(f"Node with id {node_id} does not exist in the graph.")
    if node_id == GRAPH_INPUTS_NODE_ID:
        raise ValueError("Cannot reset the special graph inputs node.")

    node_state = self.state.nodes[node_id]
    await node_state.reset_node(preserve_outputs=preserve_outputs)
468
+
469
async def reset(
    self,
    node_ids: list[str] | None = None,
    *,
    recursive=True,
    direction="forward",
    preserve_outputs: bool = False,
):
    """Reset the whole graph or a subgraph via ``reset_node``.

    Args:
        node_ids: Nodes to reset. ``None`` or an empty list resets every node
            except the special graph-inputs node.
        recursive: When true, each requested node is expanded to its
            dependents (``direction="forward"``) or its dependencies
            (``direction="backward"``), the node itself included.
        direction: ``"forward"`` or ``"backward"``; only used when
            ``recursive`` is true.
        preserve_outputs: Forwarded to ``reset_node``.

    Returns:
        A dict with ``status`` (``"full_reset"`` or ``"partial_reset"``),
        ``graph_id`` and ``nodes_reset`` (ids in the order they were reset).

    Raises:
        ValueError: invalid ``direction`` for a recursive partial reset, or
            any error raised by ``reset_node`` (unknown node, explicit request
            to reset the graph-inputs node).
    """
    if not node_ids:
        # Reset the entire graph, leaving the synthetic inputs node alone.
        all_ids = [nid for nid in self.spec.nodes if nid != GRAPH_INPUTS_NODE_ID]
        for nid in all_ids:
            await self.reset_node(nid, preserve_outputs=preserve_outputs)
        # Return here: falling through would iterate ``node_ids`` (possibly None).
        return {
            "status": "full_reset",
            "graph_id": self.spec.graph_id,
            "nodes_reset": all_ids,
        }

    # partial reset
    if recursive and direction not in ("forward", "backward"):
        raise ValueError("direction must be 'forward' or 'backward'")

    requested = set(node_ids)
    collected: list[str] = []
    for nid in node_ids:
        if not recursive:
            collected.append(nid)
        elif direction == "forward":
            collected.extend(self.get_subgraph_nodes(nid))
        else:
            collected.extend(self.get_upstream_nodes(nid))

    # De-duplicate while keeping a stable order. The inputs node is skipped
    # when it was only pulled in by expansion (consistent with the full reset);
    # an explicit request for it still reaches reset_node, which rejects it.
    targets = [
        nid
        for nid in dict.fromkeys(collected)
        if nid != GRAPH_INPUTS_NODE_ID or nid in requested
    ]

    for nid in targets:
        await self.reset_node(nid, preserve_outputs=preserve_outputs)

    return {
        "status": "partial_reset",
        "graph_id": self.spec.graph_id,
        "nodes_reset": targets,
    }
510
+
511
+ # --------- Observers and hooks ---------
512
def add_observer(self, observer: Any):
    """Append ``observer`` to ``self.observers``; hooks are looked up by attribute name when events fire."""
    self.observers.append(observer)
514
+
515
+ # --------- Difference APIs ---------
516
def diff(self, other: TaskGraph) -> dict[str, Any]:
    """Compare this graph's spec nodes with ``other``'s.

    Returns:
        ``{"added": [...], "removed": [...], "modified": [...]}`` where
        *added* ids exist only in ``other``, *removed* ids exist only in this
        graph, and *modified* ids exist in both but differ in their dependency
        set or metadata.

    Raises:
        ValueError: the two graphs have different ``graph_id`` values.
    """
    if self.spec.graph_id != other.spec.graph_id:
        raise ValueError("Can only diff graphs with the same graph_id.")

    mine = self.spec.nodes
    theirs = other.spec.nodes
    added: list[str] = []
    modified: list[str] = []

    for nid, new_node in theirs.items():
        if nid not in mine:
            added.append(nid)
            continue
        old_node = mine[nid]
        deps_changed = set(old_node.dependencies) != set(new_node.dependencies)
        if deps_changed or old_node.metadata != new_node.metadata:
            modified.append(nid)

    removed = [nid for nid in mine if nid not in theirs]
    return {"added": added, "removed": removed, "modified": modified}
544
+
545
+ # --------- IO definition APIs ---------
546
def declare_inputs(
    self, *, required: Iterable[str] | None = None, optional: dict[str, Any] | None = None
) -> None:
    """Declare graph-level inputs on ``self.spec.io``.

    Args:
        required: Names of required inputs; each is registered with a default
            ``ParamSpec()`` (only the names are used for now).
        optional: Mapping of optional input name to default value; each is
            registered as ``ParamSpec(default=value)``.
    """
    from .graph_io import ParamSpec

    # Build both spec tables first, then apply them to the graph IO.
    required_spec = {}
    for key in required or []:
        required_spec[key] = ParamSpec()

    optional_spec = {}
    for key, default in (optional or {}).items():
        optional_spec[key] = ParamSpec(default=default)

    io = self.spec.io
    if required:
        io.required.update(required_spec)
    if optional:
        io.optional.update(optional_spec)
563
+
564
def expose(self, name: str, value: Ref | Any) -> None:
    """Expose ``value`` as the graph-level output ``name``.

    The name is appended to ``self.spec.io.expose`` if it is not already
    listed, and the binding passed to ``set_expose`` is
    ``normalize_binding(value)``.
    """
    io = self.spec.io
    already_listed = name in io.expose
    if not already_listed:
        io.expose.append(name)
    binding = normalize_binding(value)
    io.set_expose(name, binding)
571
+
572
def require_outputs(self, *names: str) -> None:
    """Check that every name in ``names`` is a declared graph-level output.

    Declared outputs are read from ``self.spec._io_outputs`` when that
    attribute exists. Otherwise they come from ``self.spec.io``
    (``get_expose_names()`` if available, else the ``expose`` list), which is
    where ``expose()`` records output names.

    Raises:
        ValueError: one or more names are not declared; the message lists them.
    """
    declared = getattr(self.spec, "_io_outputs", None)
    if declared is None:
        # NOTE(review): ``_io_outputs`` looks like a legacy attribute; fall back
        # to the IO object used by expose()/io_signature().
        io = getattr(self.spec, "io", None)
        if hasattr(io, "get_expose_names"):
            declared = io.get_expose_names()
        else:
            declared = getattr(io, "expose", None) or []

    missing = [n for n in names if n not in declared]
    if missing:
        raise ValueError(f"Missing required outputs: {', '.join(missing)}")
577
+
578
def io_signature(self, include_values: bool = False) -> dict[str, Any]:
    """Return the graph's IO signature as a plain dict.

    Layout of the result:
        - ``inputs``: ``{"required": [...sorted names...], "optional": {...}}``
        - ``outputs``: ``{"keys": [...], "bindings": {...}}`` plus ``"values"``
          when ``include_values`` is true.

    Args:
        include_values: When true, each binding is resolved against
            ``self.state.node_outputs`` via ``resolve_binding``. Leave it
            disabled while a graph is being initialised to avoid resolving
            unbound refs.
    """
    io = self.spec.io

    # Exposed names: use the accessor when the IO object offers one.
    names: list[str] = (
        io.get_expose_names()
        if hasattr(io, "get_expose_names")
        else list(getattr(io, "expose", []) or [])
    )

    # Bindings: accessor first, otherwise the "expose_bindings" entry of spec.meta.
    bindings: dict[str, Any] = (
        io.get_expose_bindings()
        if hasattr(io, "get_expose_bindings")
        else dict(getattr(self.spec, "meta", {}).get("expose_bindings", {}))
    )

    # Concrete list / dict copies so the caller does not share spec containers.
    inputs_part = {
        "required": sorted(self.spec.inputs_required),
        "optional": dict(self.spec.inputs_optional),
    }
    outputs_part: dict[str, Any] = {
        "keys": list(names),
        "bindings": {name: bindings.get(name) for name in names},
    }
    signature = {"inputs": inputs_part, "outputs": outputs_part}

    if include_values:
        outputs_part["values"] = {
            name: resolve_binding(bindings.get(name), self.state.node_outputs) for name in names
        }
    return signature
612
+
613
def ensure_inputs_node(self):
    """Make sure the synthetic graph-inputs node exists and is marked DONE.

    When ``GRAPH_INPUTS_NODE_ID`` is absent from ``self.spec.nodes`` a
    ``TaskNodeSpec`` of type ``"inputs"`` is registered with no logic, no
    inputs, no dependencies and ``{"synthetic": True}`` metadata. The node's
    state entry is created if missing and its status set to
    ``NodeStatus.DONE``.
    """
    spec_nodes = self.spec.nodes
    if GRAPH_INPUTS_NODE_ID not in spec_nodes:
        spec_nodes[GRAPH_INPUTS_NODE_ID] = TaskNodeSpec(
            node_id=GRAPH_INPUTS_NODE_ID,
            type="inputs",
            logic=None,
            inputs={},
            dependencies=[],
            metadata={"synthetic": True},
            expected_input_keys=[],
            expected_output_keys=[],
        )

    # NOTE(review): assumed to run unconditionally (not only when the spec
    # entry was just created) — confirm against the original indentation.
    inputs_state = self.state.nodes.setdefault(GRAPH_INPUTS_NODE_ID, TaskNodeState())
    inputs_state.status = NodeStatus.DONE
629
+
630
def _validate_and_bind_inputs(self, provided: dict[str, Any]) -> dict[str, Any]:
    """Validate ``provided`` against the graph's IO signature and bind it.

    Args:
        provided: Input values keyed by input name.

    Returns:
        The optional defaults overlaid with ``provided``; the same dict is
        stored on ``self.state._bound_inputs``.

    Raises:
        ValueError: If any name in ``self.spec.inputs_required`` is missing
            from ``provided``.
    """
    absent = [name for name in self.spec.inputs_required if name not in provided]
    if absent:
        raise ValueError(f"Missing required inputs: {', '.join(absent)}")

    # Optional defaults first, then caller-supplied values win.
    bound = dict(self.spec.inputs_optional)
    bound |= provided
    self.state._bound_inputs = bound

    # Bump the state revision, then notify listeners.
    self.state.rev += 1
    # NOTE(review): the notifier's result is intentionally not awaited, even if
    # it is awaitable — await it here if strict ordering is ever required.
    self._notify_inputs_bound()
    return bound
650
+
651
def _resolve_ref(self, r: Any, node_outputs: dict[str, dict[str, Any]]) -> Any:
    """Resolve a serialized ref, or hand back ``r`` unchanged.

    A ref is a dict whose ``"_type"`` is ``"ref"``; its ``"from"`` entry names
    the source node and ``"key"`` the output key.

    Args:
        r: A ref dict or any other value.
        node_outputs: Outputs per node id, each a dict keyed by output name.

    Returns:
        ``r`` itself when it is not a ref; otherwise the looked-up value, or
        ``None`` when the source node or key is unknown.

    Raises:
        RuntimeError: If the ref points at the graph-inputs node but no
            inputs have been bound yet.
    """
    is_ref = isinstance(r, dict) and r.get("_type") == "ref"
    if not is_ref:
        return r

    src = r.get("from")
    key = r.get("key")

    if src != GRAPH_INPUTS_NODE_ID:
        # Ordinary node: read from the collected outputs.
        return node_outputs.get(src, {}).get(key)

    bound = self.state._bound_inputs
    if bound is None:
        raise RuntimeError("Graph inputs not bound. Call graph(...) or bind explicitly.")
    return bound.get(key)
662
+
663
+ # --------- Execution APIs ---------
664
+ # Here we have temporary APIs, later we will use tools and scheduler to manage execution
665
def _load_logic(self, logic: Any):
    """Turn a node's ``logic`` entry into something callable.

    Args:
        logic: Either a callable (returned unchanged) or a string starting
            with ``"registry:"`` (resolved through ``self._lookup_registry``).

    Raises:
        ValueError: For any other kind of value.
    """
    if not callable(logic):
        if isinstance(logic, str) and logic.startswith("registry:"):
            # Registry reference: delegate the lookup.
            return self._lookup_registry(logic)
        raise ValueError(f"Invalid logic: {logic}")
    return logic
675
+
676
async def _run_tool(self, logic: Any, **kwargs):
    """Load ``logic``, call it with ``kwargs`` and return a dict result.

    The callable is obtained from ``self._load_logic``. An awaitable return
    value is awaited. A dict result is returned as-is; anything else is
    wrapped as ``{"result": value}``.
    """
    tool = self._load_logic(logic)
    outcome = tool(**kwargs)
    if inspect.isawaitable(outcome):
        outcome = await outcome
    if isinstance(outcome, dict):
        return outcome
    return {"result": outcome}
682
+
683
+ # -------- Print and Debug ---------
684
def pretty(self, *, max_nodes: int = 20, max_width: int = 100) -> str:
    """Build a human-friendly, multi-line summary of this TaskGraph.

    The text contains a header line (graph id, node count, observer count),
    the IO signature lines from ``self.spec.io_summary_lines()``, a one-line
    state summary and a compact node table.

    Args:
        max_nodes: Number of node rows listed before the table is cut off
            with a "… (N more)" line.
        max_width: Width passed to ``_short`` for the logic column.

    Returns:
        The summary as a single newline-joined string.
    """
    table_header = f"{'id':<22} {'type':<10} {'status':<12} {'#deps':<5} logic"

    out: list[str] = [
        # Header
        f"TaskGraph[{self.spec.graph_id}] "
        f"nodes={len(self.spec.nodes)} "
        f"observers={len(self.observers)}",
        # IO signature
        "IO Signature:",
        *[f" {entry}" for entry in self.spec.io_summary_lines()],
        # State summary
        f"State: {self.state.summary_line()}",
        # Nodes table (compact)
        "Nodes:",
        " " + table_header,
        " " + "-" * (len(table_header) + 4),
    ]

    node_entries = list(self.spec.nodes.items())
    total = len(node_entries)
    for position, (nid, node) in enumerate(node_entries):
        if position >= max_nodes:
            out.append(f" … ({total - max_nodes} more)")
            break

        # Missing attributes degrade to placeholders instead of raising.
        ntype = getattr(node, "node_type", "?")
        status = _status_label(self.state.nodes.get(nid, TaskNodeState()).status)
        deps = getattr(node, "dependencies", None) or []
        logic = _logic_label(getattr(node, "logic", None))

        # Fixed-width cells keep the table tidy; only the logic cell is variable.
        cells = (
            f"{_short(nid, 22):<22}",
            f"{_short(ntype, 10):<10}",
            f"{_short(status, 12):<12}",
            f"{len(deps):<5}",
            f"{_short(logic, max_width)}",
        )
        out.append(" " + " ".join(cells))

    return "\n".join(out)
730
+
731
+ # Optional: make print(graph) show a compact version
732
def __str__(self) -> str:
    """Return the compact ``pretty`` summary (12 nodes, logic width 96)."""
    compact = self.pretty(max_nodes=12, max_width=96)
    return compact
734
+
735
+ # -------- Persistence conveniences (opt-in) --------
736
def spec_json(self) -> dict[str, Any]:
    """Return a JSON-safe representation of the graph spec.

    The conversion is delegated to ``_dataclass_to_plain``. Nothing is
    written anywhere: TaskGraph stays storage-agnostic and callers decide
    whether the result goes to a file, a database, etc.
    """
    plain_spec = _dataclass_to_plain(self.spec)
    return plain_spec
742
+
743
+
744
+ # --------- Visualization ---------
745
# Bind the to_dot / visualize / ascii_overview objects (defined or imported
# elsewhere in this module) onto TaskGraph as class attributes. If they are
# plain functions they become callable as graph.to_dot(), graph.visualize()
# and graph.ascii_overview().
TaskGraph.to_dot = to_dot
TaskGraph.visualize = visualize
TaskGraph.ascii_overview = ascii_overview