aethergraph-0.1.0a1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (182)
  1. aethergraph/__init__.py +49 -0
  2. aethergraph/config/__init__.py +0 -0
  3. aethergraph/config/config.py +121 -0
  4. aethergraph/config/context.py +16 -0
  5. aethergraph/config/llm.py +26 -0
  6. aethergraph/config/loader.py +60 -0
  7. aethergraph/config/runtime.py +9 -0
  8. aethergraph/contracts/errors/errors.py +44 -0
  9. aethergraph/contracts/services/artifacts.py +142 -0
  10. aethergraph/contracts/services/channel.py +72 -0
  11. aethergraph/contracts/services/continuations.py +23 -0
  12. aethergraph/contracts/services/eventbus.py +12 -0
  13. aethergraph/contracts/services/kv.py +24 -0
  14. aethergraph/contracts/services/llm.py +17 -0
  15. aethergraph/contracts/services/mcp.py +22 -0
  16. aethergraph/contracts/services/memory.py +108 -0
  17. aethergraph/contracts/services/resume.py +28 -0
  18. aethergraph/contracts/services/state_stores.py +33 -0
  19. aethergraph/contracts/services/wakeup.py +28 -0
  20. aethergraph/core/execution/base_scheduler.py +77 -0
  21. aethergraph/core/execution/forward_scheduler.py +777 -0
  22. aethergraph/core/execution/global_scheduler.py +634 -0
  23. aethergraph/core/execution/retry_policy.py +22 -0
  24. aethergraph/core/execution/step_forward.py +411 -0
  25. aethergraph/core/execution/step_result.py +18 -0
  26. aethergraph/core/execution/wait_types.py +72 -0
  27. aethergraph/core/graph/graph_builder.py +192 -0
  28. aethergraph/core/graph/graph_fn.py +219 -0
  29. aethergraph/core/graph/graph_io.py +67 -0
  30. aethergraph/core/graph/graph_refs.py +154 -0
  31. aethergraph/core/graph/graph_spec.py +115 -0
  32. aethergraph/core/graph/graph_state.py +59 -0
  33. aethergraph/core/graph/graphify.py +128 -0
  34. aethergraph/core/graph/interpreter.py +145 -0
  35. aethergraph/core/graph/node_handle.py +33 -0
  36. aethergraph/core/graph/node_spec.py +46 -0
  37. aethergraph/core/graph/node_state.py +63 -0
  38. aethergraph/core/graph/task_graph.py +747 -0
  39. aethergraph/core/graph/task_node.py +82 -0
  40. aethergraph/core/graph/utils.py +37 -0
  41. aethergraph/core/graph/visualize.py +239 -0
  42. aethergraph/core/runtime/ad_hoc_context.py +61 -0
  43. aethergraph/core/runtime/base_service.py +153 -0
  44. aethergraph/core/runtime/bind_adapter.py +42 -0
  45. aethergraph/core/runtime/bound_memory.py +69 -0
  46. aethergraph/core/runtime/execution_context.py +220 -0
  47. aethergraph/core/runtime/graph_runner.py +349 -0
  48. aethergraph/core/runtime/lifecycle.py +26 -0
  49. aethergraph/core/runtime/node_context.py +203 -0
  50. aethergraph/core/runtime/node_services.py +30 -0
  51. aethergraph/core/runtime/recovery.py +159 -0
  52. aethergraph/core/runtime/run_registration.py +33 -0
  53. aethergraph/core/runtime/runtime_env.py +157 -0
  54. aethergraph/core/runtime/runtime_registry.py +32 -0
  55. aethergraph/core/runtime/runtime_services.py +224 -0
  56. aethergraph/core/runtime/wakeup_watcher.py +40 -0
  57. aethergraph/core/tools/__init__.py +10 -0
  58. aethergraph/core/tools/builtins/channel_tools.py +194 -0
  59. aethergraph/core/tools/builtins/toolset.py +134 -0
  60. aethergraph/core/tools/toolkit.py +510 -0
  61. aethergraph/core/tools/waitable.py +109 -0
  62. aethergraph/plugins/channel/__init__.py +0 -0
  63. aethergraph/plugins/channel/adapters/__init__.py +0 -0
  64. aethergraph/plugins/channel/adapters/console.py +106 -0
  65. aethergraph/plugins/channel/adapters/file.py +102 -0
  66. aethergraph/plugins/channel/adapters/slack.py +285 -0
  67. aethergraph/plugins/channel/adapters/telegram.py +302 -0
  68. aethergraph/plugins/channel/adapters/webhook.py +104 -0
  69. aethergraph/plugins/channel/adapters/webui.py +134 -0
  70. aethergraph/plugins/channel/routes/__init__.py +0 -0
  71. aethergraph/plugins/channel/routes/console_routes.py +86 -0
  72. aethergraph/plugins/channel/routes/slack_routes.py +49 -0
  73. aethergraph/plugins/channel/routes/telegram_routes.py +26 -0
  74. aethergraph/plugins/channel/routes/webui_routes.py +136 -0
  75. aethergraph/plugins/channel/utils/__init__.py +0 -0
  76. aethergraph/plugins/channel/utils/slack_utils.py +278 -0
  77. aethergraph/plugins/channel/utils/telegram_utils.py +324 -0
  78. aethergraph/plugins/channel/websockets/slack_ws.py +68 -0
  79. aethergraph/plugins/channel/websockets/telegram_polling.py +151 -0
  80. aethergraph/plugins/mcp/fs_server.py +128 -0
  81. aethergraph/plugins/mcp/http_server.py +101 -0
  82. aethergraph/plugins/mcp/ws_server.py +180 -0
  83. aethergraph/plugins/net/http.py +10 -0
  84. aethergraph/plugins/utils/data_io.py +359 -0
  85. aethergraph/runner/__init__.py +5 -0
  86. aethergraph/runtime/__init__.py +62 -0
  87. aethergraph/server/__init__.py +3 -0
  88. aethergraph/server/app_factory.py +84 -0
  89. aethergraph/server/start.py +122 -0
  90. aethergraph/services/__init__.py +10 -0
  91. aethergraph/services/artifacts/facade.py +284 -0
  92. aethergraph/services/artifacts/factory.py +35 -0
  93. aethergraph/services/artifacts/fs_store.py +656 -0
  94. aethergraph/services/artifacts/jsonl_index.py +123 -0
  95. aethergraph/services/artifacts/paths.py +23 -0
  96. aethergraph/services/artifacts/sqlite_index.py +209 -0
  97. aethergraph/services/artifacts/utils.py +124 -0
  98. aethergraph/services/auth/dev.py +16 -0
  99. aethergraph/services/channel/channel_bus.py +293 -0
  100. aethergraph/services/channel/factory.py +44 -0
  101. aethergraph/services/channel/session.py +511 -0
  102. aethergraph/services/channel/wait_helpers.py +57 -0
  103. aethergraph/services/clock/clock.py +9 -0
  104. aethergraph/services/container/default_container.py +320 -0
  105. aethergraph/services/continuations/continuation.py +56 -0
  106. aethergraph/services/continuations/factory.py +34 -0
  107. aethergraph/services/continuations/stores/fs_store.py +264 -0
  108. aethergraph/services/continuations/stores/inmem_store.py +95 -0
  109. aethergraph/services/eventbus/inmem.py +21 -0
  110. aethergraph/services/features/static.py +10 -0
  111. aethergraph/services/kv/ephemeral.py +90 -0
  112. aethergraph/services/kv/factory.py +27 -0
  113. aethergraph/services/kv/layered.py +41 -0
  114. aethergraph/services/kv/sqlite_kv.py +128 -0
  115. aethergraph/services/llm/factory.py +157 -0
  116. aethergraph/services/llm/generic_client.py +542 -0
  117. aethergraph/services/llm/providers.py +3 -0
  118. aethergraph/services/llm/service.py +105 -0
  119. aethergraph/services/logger/base.py +36 -0
  120. aethergraph/services/logger/compat.py +50 -0
  121. aethergraph/services/logger/formatters.py +106 -0
  122. aethergraph/services/logger/std.py +203 -0
  123. aethergraph/services/mcp/helpers.py +23 -0
  124. aethergraph/services/mcp/http_client.py +70 -0
  125. aethergraph/services/mcp/mcp_tools.py +21 -0
  126. aethergraph/services/mcp/registry.py +14 -0
  127. aethergraph/services/mcp/service.py +100 -0
  128. aethergraph/services/mcp/stdio_client.py +70 -0
  129. aethergraph/services/mcp/ws_client.py +115 -0
  130. aethergraph/services/memory/bound.py +106 -0
  131. aethergraph/services/memory/distillers/episode.py +116 -0
  132. aethergraph/services/memory/distillers/rolling.py +74 -0
  133. aethergraph/services/memory/facade.py +633 -0
  134. aethergraph/services/memory/factory.py +78 -0
  135. aethergraph/services/memory/hotlog_kv.py +27 -0
  136. aethergraph/services/memory/indices.py +74 -0
  137. aethergraph/services/memory/io_helpers.py +72 -0
  138. aethergraph/services/memory/persist_fs.py +40 -0
  139. aethergraph/services/memory/resolver.py +152 -0
  140. aethergraph/services/metering/noop.py +4 -0
  141. aethergraph/services/prompts/file_store.py +41 -0
  142. aethergraph/services/rag/chunker.py +29 -0
  143. aethergraph/services/rag/facade.py +593 -0
  144. aethergraph/services/rag/index/base.py +27 -0
  145. aethergraph/services/rag/index/faiss_index.py +121 -0
  146. aethergraph/services/rag/index/sqlite_index.py +134 -0
  147. aethergraph/services/rag/index_factory.py +52 -0
  148. aethergraph/services/rag/parsers/md.py +7 -0
  149. aethergraph/services/rag/parsers/pdf.py +14 -0
  150. aethergraph/services/rag/parsers/txt.py +7 -0
  151. aethergraph/services/rag/utils/hybrid.py +39 -0
  152. aethergraph/services/rag/utils/make_fs_key.py +62 -0
  153. aethergraph/services/redactor/simple.py +16 -0
  154. aethergraph/services/registry/key_parsing.py +44 -0
  155. aethergraph/services/registry/registry_key.py +19 -0
  156. aethergraph/services/registry/unified_registry.py +185 -0
  157. aethergraph/services/resume/multi_scheduler_resume_bus.py +65 -0
  158. aethergraph/services/resume/router.py +73 -0
  159. aethergraph/services/schedulers/registry.py +41 -0
  160. aethergraph/services/secrets/base.py +7 -0
  161. aethergraph/services/secrets/env.py +8 -0
  162. aethergraph/services/state_stores/externalize.py +135 -0
  163. aethergraph/services/state_stores/graph_observer.py +131 -0
  164. aethergraph/services/state_stores/json_store.py +67 -0
  165. aethergraph/services/state_stores/resume_policy.py +119 -0
  166. aethergraph/services/state_stores/serialize.py +249 -0
  167. aethergraph/services/state_stores/utils.py +91 -0
  168. aethergraph/services/state_stores/validate.py +78 -0
  169. aethergraph/services/tracing/noop.py +18 -0
  170. aethergraph/services/waits/wait_registry.py +91 -0
  171. aethergraph/services/wakeup/memory_queue.py +57 -0
  172. aethergraph/services/wakeup/scanner_producer.py +56 -0
  173. aethergraph/services/wakeup/worker.py +31 -0
  174. aethergraph/tools/__init__.py +25 -0
  175. aethergraph/utils/optdeps.py +8 -0
  176. aethergraph-0.1.0a1.dist-info/METADATA +410 -0
  177. aethergraph-0.1.0a1.dist-info/RECORD +182 -0
  178. aethergraph-0.1.0a1.dist-info/WHEEL +5 -0
  179. aethergraph-0.1.0a1.dist-info/entry_points.txt +2 -0
  180. aethergraph-0.1.0a1.dist-info/licenses/LICENSE +176 -0
  181. aethergraph-0.1.0a1.dist-info/licenses/NOTICE +31 -0
  182. aethergraph-0.1.0a1.dist-info/top_level.txt +1 -0
aethergraph/core/runtime/execution_context.py
@@ -0,0 +1,220 @@
+ # aethergraph/core/execution/context.py
+ from __future__ import annotations
+
+ from collections.abc import Callable
+ from dataclasses import dataclass
+ from datetime import datetime
+ import importlib
+ from typing import TYPE_CHECKING, Any
+
+ if TYPE_CHECKING:
+     from aethergraph.core.graph.task_node import TaskNodeRuntime
+
+ from aethergraph.services.clock.clock import SystemClock
+ from aethergraph.services.logger.std import StdLoggerService
+ from aethergraph.services.resume.router import ResumeRouter
+
+ from ..graph.graph_refs import GRAPH_INPUTS_NODE_ID, RESERVED_INJECTABLES
+ from .bound_memory import BoundMemoryAdapter
+ from .node_context import NodeContext
+ from .node_services import NodeServices
+
+
+ @dataclass
+ class ExecutionContext:
+     run_id: str
+     graph_id: str | None
+     graph_inputs: dict[str, Any]
+     outputs_by_node: dict[str, dict[str, Any]]
+     services: NodeServices
+     logger_factory: StdLoggerService
+     clock: SystemClock
+     resume_payload: dict[str, Any] | None = None
+     should_run_fn: Callable[[], bool] | None = None
+     resume_router: ResumeRouter | None = None  # ResumeRouter
+
+     # Back-compat shim
+     bound_memory: BoundMemoryAdapter | None = None
+
+     def create_node_context(self, node: TaskNodeRuntime) -> NodeContext:
+         return NodeContext(
+             run_id=self.run_id,
+             graph_id=self.graph_id or "",
+             node_id=node.node_id,
+             services=self.services,
+             resume_payload=self.resume_payload,
+             # back-compat for old ctx.mem()
+             bound_memory=self.bound_memory,
+         )
+
+     # def as_node_context(self, ad) -> "NodeContext":
+     #     """ Create a NodeContext representing this execution context itself as a node.
+     #     Useful for ad-hoc contexts that don't have real nodes.
+     #     """
+     #     return NodeContext(
+     #         run_id=self.run_id,
+     #         graph_id=self.graph_id or "",
+     #         node_id="ad_",
+     #         services=self.services,
+     #         resume_payload=self.resume_payload,
+     #         # back-compat for old ctx.mem()
+     #         bound_memory=self.bound_memory,
+     #     )
+
+     # ----- helpers used by step forward() -----
+     def now(self) -> datetime:
+         return self.clock.now()
+
+     def resolve(self, logic_ref: str):
+         """Resolve a logic reference to a callable.
+         NOTE: This is not used anymore; prefer get_logic().
+         """
+         # fallback dotted import
+         mod, _, attr = logic_ref.rpartition(".")
+         return getattr(importlib.import_module(mod), attr)
+
+     def get_logic(self, logic_ref):
+         """Resolve a logic reference to a callable.
+         If a registry is available and the ref looks like a registry key, use it.
+         Otherwise, if a dotted path, import it.
+         Otherwise, return as-is (assumed callable).
+         Args:
+             logic_ref: A callable, dotted path string, or registry key string.
+         Returns:
+             The resolved callable.
+         """
+         if isinstance(logic_ref, str) and logic_ref.startswith("registry:") and self.registry:
+             # registry key
+             return self.registry.get_logic_ref(logic_ref)
+         if isinstance(logic_ref, str):
+             # dotted path fallback
+             mod, _, attr = logic_ref.rpartition(".")
+             return getattr(importlib.import_module(mod), attr)
+         return logic_ref
+
+     async def resolve_inputs(self, node) -> dict[str, Any]:
+         """
+         Materialize a node's input mapping by resolving:
+           - {"_type":"arg","key":K} → graph input value (or optional default)
+           - {"_type":"ref","from":NODE_ID,"key":OUT} → upstream node's output value
+           - {"_type":"context","key":K,"default":D} → memory value (or D if missing)
+         Works recursively over dicts/lists/tuples.
+
+         The function works as follows:
+           - If the value is a dict with "_type" of "arg", it looks up the graph input.
+           - If the value is a dict with "_type" of "ref", it looks up
+             the specified node's output.
+           - If the value is a dict without special keys, it recursively resolves
+             each key-value pair.
+           - If the value is a list or tuple, it recursively resolves each element.
+           - Otherwise, it returns the value as-is (assumed to be a constant).
+
+         Args:
+             node: The TaskNodeRuntime whose inputs to resolve.
+         Returns:
+             The fully resolved inputs dict for the node.
+         Raises:
+             KeyError: If a referenced graph input or node output is missing.
+         """
+         raw = getattr(node, "inputs", {}) or {}
+         # Grab optional defaults from the graph spec if available
+         opt_defaults: dict[str, Any] = {}
+         parent_graph = getattr(node, "_parent_graph", None)
+         if parent_graph and getattr(parent_graph, "spec", None):
+             # _io_inputs_optional is a dict[str, Any]
+             opt_defaults = getattr(parent_graph.spec, "inputs_optional", {}) or {}
+
+         # Allow a fallback to graph.state.node_outputs if scheduler hasn't copied yet
+         fallback_outputs = {}
+         if parent_graph and getattr(parent_graph, "state", None):
+             fallback_outputs = getattr(parent_graph.state, "node_outputs", {}) or {}
+
+         def _err_path(msg: str, path: str):
+             raise KeyError(
+                 f"{msg} (node={getattr(node, 'node_id', getattr(node, 'id', '?'))}, path={path})"
+             )
+
+         def _resolve_arg(marker: dict[str, Any], path: str):
+             k = marker.get("key")
+             if k is None:
+                 _err_path("Bad arg marker (missing 'key')", path)
+             if k in self.graph_inputs:
+                 return self.graph_inputs[k]
+             if k in opt_defaults:
+                 return opt_defaults[k]
+             # Helpful error: show known keys
+             known = list(self.graph_inputs.keys())
+             _err_path(f"Graph input '{k}' not provided (known inputs: {known})", path)
+
+         def _resolve_ref(marker: dict[str, Any], path: str):
+             src = marker.get("from")
+             out_key = marker.get("key")
+             if src is None or out_key is None:
+                 _err_path("Bad ref marker (need 'from' and 'key')", path)
+
+             # Tolerate someone emitting a ref to the inputs sentinel
+             if src == GRAPH_INPUTS_NODE_ID:
+                 # Interpret as an 'arg' reference to graph inputs
+                 return _resolve_arg({"_type": "arg", "key": out_key}, path + ".__graph_inputs__")
+
+             # Primary source: env.outputs_by_node (scheduler publishes here)
+             if src in self.outputs_by_node:
+                 outs = self.outputs_by_node[src] or {}
+                 if out_key in outs:
+                     return outs[out_key]
+                 _err_path(
+                     f"Upstream node '{src}' has no output key '{out_key}'. "
+                     f"Available: {list(outs.keys())}",
+                     path,
+                 )
+
+             # Fallback: graph state (useful during tests or if scheduler filled it there)
+             if src in fallback_outputs:
+                 outs = fallback_outputs[src] or {}
+                 if out_key in outs:
+                     return outs[out_key]
+                 _err_path(
+                     f"(fallback) Upstream node '{src}' has no output key '{out_key}'. "
+                     f"Available: {list(outs.keys())}",
+                     path,
+                 )
+
+             _err_path(f"Upstream node '{src}' outputs not available yet", path)
+
+         def _resolve_any(val: Any, path: str):
+             # Handle dict markers
+             if isinstance(val, dict):
+                 t = val.get("_type")
+                 if t == "arg":
+                     return _resolve_arg(val, path)
+                 if t == "ref":
+                     return _resolve_ref(val, path)
+                 if t == "context":
+                     return self.memory.read(val["key"], val.get("default"))
+                 # regular dict: recurse keys
+                 return {k: _resolve_any(v, f"{path}.{k}") for k, v in val.items()}
+
+             # Handle list/tuple
+             if isinstance(val, list):
+                 return [_resolve_any(v, f"{path}[{i}]") for i, v in enumerate(val)]
+             if isinstance(val, tuple):
+                 return tuple(_resolve_any(v, f"{path}[{i}]") for i, v in enumerate(val))
+
+             # Pass-through literal
+             return val
+
+         # Make sure we don't mutate node.inputs
+         # materialized = _resolve_any(copy.deepcopy(raw), path="inputs")
+         materialized = _resolve_any(raw, path="inputs")
+
+         # Strip framework-reserved injectables from *user* inputs.
+         # We always inject these later from the execution context.
+         if isinstance(materialized, dict):
+             for k in list(materialized.keys()):
+                 if k in RESERVED_INJECTABLES:
+                     materialized.pop(k, None)
+
+         # If someone put arguments under "kwargs", keep them;
+         # build_call_kwargs will flatten and then drop "kwargs".
+         # (No change needed here beyond not touching it.)
+         return materialized
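The three marker shapes that ExecutionContext.resolve_inputs() materializes can be summarized with a short sketch. The node IDs and key names below are hypothetical; only the marker dicts follow the shapes handled by the resolver above:

    # Hypothetical raw inputs for a node, using the marker kinds resolve_inputs() understands.
    raw_inputs = {
        "query": {"_type": "arg", "key": "user_query"},                    # graph input (or optional default)
        "docs": {"_type": "ref", "from": "retrieve", "key": "chunks"},     # output "chunks" of upstream node "retrieve"
        "tone": {"_type": "context", "key": "style", "default": "plain"},  # memory value, falling back to the default
        "top_k": 5,                                                        # plain literals pass through unchanged
        "batch": [{"_type": "ref", "from": "retrieve", "key": "ids"}, 7],  # lists/tuples are resolved recursively
    }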
aethergraph/core/runtime/graph_runner.py
@@ -0,0 +1,349 @@
+ from __future__ import annotations
+
+ import asyncio
+ import threading
+ from typing import Any
+ import uuid
+
+ from aethergraph.contracts.errors.errors import GraphHasPendingWaits
+ from aethergraph.contracts.services.state_stores import GraphSnapshot
+ from aethergraph.core.runtime.recovery import hash_spec, recover_graph_run
+ from aethergraph.services.container.default_container import build_default_container
+ from aethergraph.services.state_stores.graph_observer import PersistenceObserver
+ from aethergraph.services.state_stores.resume_policy import (
+     assert_snapshot_json_only,
+ )
+ from aethergraph.services.state_stores.utils import snapshot_from_graph
+
+ from ..execution.forward_scheduler import ForwardScheduler
+ from ..execution.retry_policy import RetryPolicy
+ from ..graph.graph_fn import GraphFunction
+ from ..graph.graph_refs import resolve_any as _resolve_any
+ from ..runtime.runtime_env import RuntimeEnv
+ from ..runtime.runtime_services import ensure_services_installed
+ from .run_registration import RunRegistrationGuard
+
+
+ # ---------- env helpers ----------
+ def _get_container():
+     # install once if not installed by sidecar/server
+     return ensure_services_installed(build_default_container)
+
+
+ async def _attach_persistence(graph, env, spec, snapshot_every=1) -> PersistenceObserver:
+     """
+     Wire the centralized state_store to the graph via PersistenceObserver.
+     Returns the observer instance so caller can optionally force a final snapshot.
+     """
+     store = getattr(env.container, "state_store", None) or getattr(env, "state_store", None)
+     if not store:
+         # Safe no-op: resumability won't work but run still executes.
+         return None
+
+     obs = PersistenceObserver(
+         store=store,
+         artifact_store=getattr(env.container, "artifacts", None),
+         spec_hash=hash_spec(spec),
+         snapshot_every=snapshot_every,
+     )
+     graph.add_observer(obs)
+     return obs
+
+
+ async def _build_env(
+     owner, inputs: dict[str, Any], **rt_overrides
+ ) -> tuple[RuntimeEnv, RetryPolicy, int]:
+     container = _get_container()
+     # apply optional overrides onto the container instance
+     for k, v in rt_overrides.items():
+         if v is not None and hasattr(container, k):
+             setattr(container, k, v)
+
+     run_id = rt_overrides.get("run_id") or f"run-{uuid.uuid4().hex[:8]}"
+     env = RuntimeEnv(
+         run_id=run_id,
+         graph_inputs=inputs,
+         outputs_by_node={},
+         container=container,
+     )
+     retry = rt_overrides.get("retry") or RetryPolicy()
+     max_conc = rt_overrides.get("max_concurrency", getattr(owner, "max_concurrency", 4))
+     return env, retry, max_conc
+
+
+ # ---------- materialization ----------
+ def _materialize_task_graph(target) -> Any:
+     """
+     Accept:
+       - TaskGraph instance (has io_signature attr)
+       - graph builder object with .build()
+       - a callable builder that returns a TaskGraph when invoked with no args
+     """
+     # already a TaskGraph
+     if hasattr(target, "io_signature"):
+         return target
+
+     # builder pattern with .build()
+     if hasattr(target, "build") and callable(target.build):
+         g = target.build()
+         if hasattr(g, "io_signature"):
+             return g
+
+     # callable builder that returns a TaskGraph
+     if callable(target):
+         g = target()
+         if hasattr(g, "io_signature"):
+             return g
+
+     raise TypeError(
+         "run_async: target must be a TaskGraph instance, a TaskGraph builder, "
+         "or a callable returning a TaskGraph."
+     )
+
+
+ def _resolve_graph_outputs(
+     graph,
+     inputs: dict[str, Any],
+     env: RuntimeEnv,
+ ):
+     bindings = graph.io_signature().get("outputs", {}).get("bindings", {})
+
+     def _res(b):
+         return _resolve_any(b, graph_inputs=inputs, outputs_by_node=env.outputs_by_node)
+
+     try:
+         result = {k: _res(v) for k, v in bindings.items()}
+     except KeyError as e:
+         waiting = [
+             nid
+             for nid, n in graph.state.nodes.items()
+             if getattr(n, "status", "").startswith("WAITING_")
+         ]
+         continuations = []
+         if env.continuation_store and hasattr(env.continuation_store, "get"):
+             for nid in waiting:
+                 cont = env.continuation_store.get(run_id=env.run_id, node_id=nid)
+                 if cont:
+                     continuations.append(
+                         {
+                             "node_id": nid,
+                             "kind": cont.kind,
+                             "token": cont.token,
+                             "channel": cont.channel,
+                             "deadline": getattr(cont.deadline, "isoformat", lambda: None)(),
+                         }
+                     )
+         raise GraphHasPendingWaits(
+             "Graph quiesced with pending waits; outputs are not yet resolvable.",
+             waiting_nodes=waiting,
+             continuations=continuations,
+         ) from e
+
+     return next(iter(result.values())) if len(result) == 1 else result
+
+
+ def _resolve_graph_outputs_or_waits(graph, inputs, env, *, raise_on_waits: bool = True):
+     try:
+         return _resolve_graph_outputs(graph, inputs, env)
+     except GraphHasPendingWaits as e:
+         if raise_on_waits:
+             raise
+         return {
+             "status": "waiting",
+             "waiting_nodes": e.waiting_nodes,
+             "continuations": e.continuations,
+         }
+
+
+ def _seed_outputs_from_snapshot(env, snap: GraphSnapshot):
+     env.outputs_by_node = env.outputs_by_node or {}
+     nodes = snap.state.get("nodes", {})
+     for nid, ns in nodes.items():
+         outs = (ns or {}).get("outputs") or {}
+         if outs:
+             env.outputs_by_node[nid] = outs
+
+
+ def _is_graph_complete(snap: GraphSnapshot) -> bool:
+     nodes = snap.state.get("nodes", {})
+     if not nodes:
+         return False
+
+     # completed if every node is DONE/SKIPPED or has outputs matching spec
+     def doneish(st):
+         s = (st or {}).get("status", "")
+         return s in ("DONE", "SKIPPED")
+
+     return all(doneish(ns) for ns in nodes.values())
+
+
+ async def load_latest_snapshot_json(store, run_id: str) -> dict[str, Any] | None:
+     """
+     Returns the raw JSON dict of the latest snapshot (or None).
+     """
+     snap = await store.load_latest_snapshot(run_id)
+     if not snap:
+         return None
+     # JsonGraphStateStore serializes GraphSnapshot via snap.__dict__
+     # load_latest_snapshot already returns a GraphSnapshot(**jsondict).
+     # Convert back to plain JSON-ish dict:
+     return {
+         "run_id": snap.run_id,
+         "graph_id": snap.graph_id,
+         "rev": snap.rev,
+         "spec_hash": snap.spec_hash,
+         "state": snap.state,
+     }
+
+
+ # ---------- public API ----------
+ async def run_async(target, inputs: dict[str, Any] | None = None, **rt_overrides):
+     """
+     Generic async runner for TaskGraph or GraphFunction.
+     - GraphFunction → delegates to gf.run(env=..., **inputs)
+     - TaskGraph/builder → schedules and resolves graph-level outputs
+     """
+     inputs = inputs or {}
+     # GraphFunction path
+     if isinstance(target, GraphFunction):
+         env, retry, max_conc = await _build_env(target, inputs, **rt_overrides)
+         return await target.run(env=env, max_concurrency=max_conc, **inputs)
+
+     # TaskGraph path
+     graph = _materialize_task_graph(target)
+     env, retry, max_conc = await _build_env(graph, inputs, **rt_overrides)
+
+     # Extract spec for run/recovery ...
+     spec = getattr(graph, "spec", None) or getattr(graph, "get_spec", lambda: None)()
+     if spec is None:
+         spec = graph.spec
+
+     store = getattr(env.container, "state_store", None)
+     snap = None
+     assert store is None or hasattr(
+         store, "load_latest_snapshot"
+     ), "state_store must implement load_latest_snapshot(run_id)"
+
+     if store:
+         # 1) Attempt cold-resume (build a graph with hydrated state)
+         graph = await recover_graph_run(spec=spec, run_id=env.run_id, store=store)
+
+         # 2) Load raw JSON snapshot and ENFORCE strict policy
+         snap_json = await load_latest_snapshot_json(store, env.run_id)
+         if snap_json:
+             # keep for short-circuit + seeding
+             snap = await store.load_latest_snapshot(env.run_id)
+             # Short-circuit if already complete
+             if snap:
+                 _seed_outputs_from_snapshot(env, snap)
+                 if _is_graph_complete(snap):
+                     return _resolve_graph_outputs(graph, inputs, env)
+
+             # strict policy: block resume if any non-JSON / __aether_ref__ is present
+             assert_snapshot_json_only(env.run_id, snap_json, mode="reuse_only")
+     else:
+         graph = _materialize_task_graph(target)
+
+     # Bind/validate inputs
+     graph._validate_and_bind_inputs(inputs)
+
+     # Attach persistence observer + run (unchanged) ...
+     obs = await _attach_persistence(graph, env, spec, snapshot_every=1)
+
+     # get logger from env's container
+     from ..runtime.runtime_services import current_logger_factory
+
+     logger = current_logger_factory().for_scheduler()
+
+     sched = ForwardScheduler(
+         graph,
+         env,
+         retry_policy=retry,
+         max_concurrency=max_conc,
+         skip_dep_on_failure=True,
+         stop_on_first_error=True,
+         logger=logger,
+     )
+
+     # Register for resumes and run
+     with RunRegistrationGuard(run_id=env.run_id, scheduler=sched, container=env.container):
+         try:
+             await sched.run()
+         except asyncio.CancelledError:
+             raise
+         finally:
+             # FINAL SNAPSHOT on normal or cancelled exit (if store exists)
+             if store and obs:
+                 artifacts = getattr(env.container, "artifacts", None)
+                 snap = await snapshot_from_graph(
+                     run_id=graph.state.run_id or env.run_id,
+                     graph_id=graph.graph_id,
+                     rev=graph.state.rev,
+                     spec_hash=hash_spec(spec),
+                     state_obj=graph.state,
+                     artifacts=artifacts,
+                     allow_externalize=False,  # FIXME: artifact writer async loop error; set False to *avoid* writing artifacts during snapshot
+                     include_wait_spec=True,
+                 )
+                 await store.save_snapshot(snap)
+
+     # Resolve graph-level outputs (will raise GraphHasPendingWaits if waits)
+     return _resolve_graph_outputs_or_waits(graph, inputs, env, raise_on_waits=True)
+
+
+ async def run_or_resume_async(
+     target, inputs: dict[str, Any], *, run_id: str | None = None, **rt_overrides
+ ):
+     """
+     If state exists for run_id → cold resume, else fresh run.
+     Exactly the same signature as run_async plus optional run_id.
+     """
+     if run_id is not None:
+         rt_overrides = dict(rt_overrides or {}, run_id=run_id)
+     return await run_async(target, inputs, **rt_overrides)
+
+
+ # sync adapter (optional, safe in notebooks/servers)
+ class _LoopThread:
+     def __init__(self):
+         self._ev = threading.Event()
+         self._thread = threading.Thread(target=self._worker, daemon=True)
+         self._loop = None
+         self._thread.start()
+         self._ev.wait()
+
+     def _worker(self):
+         loop = asyncio.new_event_loop()
+         asyncio.set_event_loop(loop)
+         self._loop = loop
+         self._ev.set()
+         loop.run_forever()
+
+     def submit_old(self, coro):
+         # this will block terminal until coro is done
+         fut = asyncio.run_coroutine_threadsafe(coro, self._loop)
+         return fut.result()
+
+     def submit(self, coro):
+         # this will allow KeyboardInterrupt to propagate -> still not perfect. Use async main if possible.
+         fut = asyncio.run_coroutine_threadsafe(coro, self._loop)
+         try:
+             return fut.result()
+         except KeyboardInterrupt:
+             # cancel the task in the loop thread and wait for cleanup
+             fut.cancel()
+
+             def _cancel_all():
+                 for t in asyncio.all_tasks(self._loop):
+                     t.cancel()
+
+             self._loop.call_soon_threadsafe(_cancel_all)
+             raise
+
+
+ _LOOP = _LoopThread()
+
+
+ def run(target, inputs: dict[str, Any] | None = None, **rt_overrides):
+     inputs = inputs or {}
+     return _LOOP.submit(run_async(target, inputs, **rt_overrides))
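A minimal usage sketch for the runner above. The import path and the build_my_graph builder are assumptions for illustration; run() blocks on the background loop thread, and GraphHasPendingWaits surfaces when the graph quiesces on external waits:

    # Hypothetical caller; build_my_graph is any callable returning a TaskGraph.
    from aethergraph.contracts.errors.errors import GraphHasPendingWaits
    from aethergraph.core.runtime.graph_runner import run  # module path assumed from this diff


    def main(build_my_graph):
        try:
            # Sync adapter: submits run_async() to the background loop thread and blocks.
            result = run(build_my_graph, {"user_query": "hello"}, run_id="run-demo")
            print("outputs:", result)
        except GraphHasPendingWaits as exc:
            # The graph quiesced on external waits; resume later under the same run_id.
            print("still waiting on:", exc.waiting_nodes)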
aethergraph/core/runtime/lifecycle.py
@@ -0,0 +1,26 @@
+ # lifecycle.py
+ import asyncio
+
+ from aethergraph.core.runtime.runtime_services import current_services
+
+
+ async def start_all_services() -> None:
+     svc = current_services()
+     tasks = []
+     for _, inst in getattr(svc, "ext_services", {}).items():
+         start = getattr(inst, "start", None)
+         if asyncio.iscoroutinefunction(start):
+             tasks.append(start())
+     if tasks:
+         await asyncio.gather(*tasks)
+
+
+ async def close_all_services() -> None:
+     svc = current_services()
+     tasks = []
+     for _, inst in getattr(svc, "ext_services", {}).items():
+         close = getattr(inst, "close", None)
+         if asyncio.iscoroutinefunction(close):
+             tasks.append(close())
+     if tasks:
+         await asyncio.gather(*tasks)
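To close the loop, a sketch of wrapping a run with the lifecycle helpers above; build_my_graph is a hypothetical TaskGraph builder, and the import paths follow the file list in this diff:

    import asyncio

    from aethergraph.core.runtime.graph_runner import run_async
    from aethergraph.core.runtime.lifecycle import close_all_services, start_all_services


    async def main(build_my_graph):
        await start_all_services()  # start any registered ext_services exposing an async start()
        try:
            return await run_async(build_my_graph, {"user_query": "hello"})
        finally:
            await close_all_services()  # always close services exposing an async close()


    # asyncio.run(main(build_my_graph))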