aethergraph 0.1.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aethergraph/__init__.py +49 -0
- aethergraph/config/__init__.py +0 -0
- aethergraph/config/config.py +121 -0
- aethergraph/config/context.py +16 -0
- aethergraph/config/llm.py +26 -0
- aethergraph/config/loader.py +60 -0
- aethergraph/config/runtime.py +9 -0
- aethergraph/contracts/errors/errors.py +44 -0
- aethergraph/contracts/services/artifacts.py +142 -0
- aethergraph/contracts/services/channel.py +72 -0
- aethergraph/contracts/services/continuations.py +23 -0
- aethergraph/contracts/services/eventbus.py +12 -0
- aethergraph/contracts/services/kv.py +24 -0
- aethergraph/contracts/services/llm.py +17 -0
- aethergraph/contracts/services/mcp.py +22 -0
- aethergraph/contracts/services/memory.py +108 -0
- aethergraph/contracts/services/resume.py +28 -0
- aethergraph/contracts/services/state_stores.py +33 -0
- aethergraph/contracts/services/wakeup.py +28 -0
- aethergraph/core/execution/base_scheduler.py +77 -0
- aethergraph/core/execution/forward_scheduler.py +777 -0
- aethergraph/core/execution/global_scheduler.py +634 -0
- aethergraph/core/execution/retry_policy.py +22 -0
- aethergraph/core/execution/step_forward.py +411 -0
- aethergraph/core/execution/step_result.py +18 -0
- aethergraph/core/execution/wait_types.py +72 -0
- aethergraph/core/graph/graph_builder.py +192 -0
- aethergraph/core/graph/graph_fn.py +219 -0
- aethergraph/core/graph/graph_io.py +67 -0
- aethergraph/core/graph/graph_refs.py +154 -0
- aethergraph/core/graph/graph_spec.py +115 -0
- aethergraph/core/graph/graph_state.py +59 -0
- aethergraph/core/graph/graphify.py +128 -0
- aethergraph/core/graph/interpreter.py +145 -0
- aethergraph/core/graph/node_handle.py +33 -0
- aethergraph/core/graph/node_spec.py +46 -0
- aethergraph/core/graph/node_state.py +63 -0
- aethergraph/core/graph/task_graph.py +747 -0
- aethergraph/core/graph/task_node.py +82 -0
- aethergraph/core/graph/utils.py +37 -0
- aethergraph/core/graph/visualize.py +239 -0
- aethergraph/core/runtime/ad_hoc_context.py +61 -0
- aethergraph/core/runtime/base_service.py +153 -0
- aethergraph/core/runtime/bind_adapter.py +42 -0
- aethergraph/core/runtime/bound_memory.py +69 -0
- aethergraph/core/runtime/execution_context.py +220 -0
- aethergraph/core/runtime/graph_runner.py +349 -0
- aethergraph/core/runtime/lifecycle.py +26 -0
- aethergraph/core/runtime/node_context.py +203 -0
- aethergraph/core/runtime/node_services.py +30 -0
- aethergraph/core/runtime/recovery.py +159 -0
- aethergraph/core/runtime/run_registration.py +33 -0
- aethergraph/core/runtime/runtime_env.py +157 -0
- aethergraph/core/runtime/runtime_registry.py +32 -0
- aethergraph/core/runtime/runtime_services.py +224 -0
- aethergraph/core/runtime/wakeup_watcher.py +40 -0
- aethergraph/core/tools/__init__.py +10 -0
- aethergraph/core/tools/builtins/channel_tools.py +194 -0
- aethergraph/core/tools/builtins/toolset.py +134 -0
- aethergraph/core/tools/toolkit.py +510 -0
- aethergraph/core/tools/waitable.py +109 -0
- aethergraph/plugins/channel/__init__.py +0 -0
- aethergraph/plugins/channel/adapters/__init__.py +0 -0
- aethergraph/plugins/channel/adapters/console.py +106 -0
- aethergraph/plugins/channel/adapters/file.py +102 -0
- aethergraph/plugins/channel/adapters/slack.py +285 -0
- aethergraph/plugins/channel/adapters/telegram.py +302 -0
- aethergraph/plugins/channel/adapters/webhook.py +104 -0
- aethergraph/plugins/channel/adapters/webui.py +134 -0
- aethergraph/plugins/channel/routes/__init__.py +0 -0
- aethergraph/plugins/channel/routes/console_routes.py +86 -0
- aethergraph/plugins/channel/routes/slack_routes.py +49 -0
- aethergraph/plugins/channel/routes/telegram_routes.py +26 -0
- aethergraph/plugins/channel/routes/webui_routes.py +136 -0
- aethergraph/plugins/channel/utils/__init__.py +0 -0
- aethergraph/plugins/channel/utils/slack_utils.py +278 -0
- aethergraph/plugins/channel/utils/telegram_utils.py +324 -0
- aethergraph/plugins/channel/websockets/slack_ws.py +68 -0
- aethergraph/plugins/channel/websockets/telegram_polling.py +151 -0
- aethergraph/plugins/mcp/fs_server.py +128 -0
- aethergraph/plugins/mcp/http_server.py +101 -0
- aethergraph/plugins/mcp/ws_server.py +180 -0
- aethergraph/plugins/net/http.py +10 -0
- aethergraph/plugins/utils/data_io.py +359 -0
- aethergraph/runner/__init__.py +5 -0
- aethergraph/runtime/__init__.py +62 -0
- aethergraph/server/__init__.py +3 -0
- aethergraph/server/app_factory.py +84 -0
- aethergraph/server/start.py +122 -0
- aethergraph/services/__init__.py +10 -0
- aethergraph/services/artifacts/facade.py +284 -0
- aethergraph/services/artifacts/factory.py +35 -0
- aethergraph/services/artifacts/fs_store.py +656 -0
- aethergraph/services/artifacts/jsonl_index.py +123 -0
- aethergraph/services/artifacts/paths.py +23 -0
- aethergraph/services/artifacts/sqlite_index.py +209 -0
- aethergraph/services/artifacts/utils.py +124 -0
- aethergraph/services/auth/dev.py +16 -0
- aethergraph/services/channel/channel_bus.py +293 -0
- aethergraph/services/channel/factory.py +44 -0
- aethergraph/services/channel/session.py +511 -0
- aethergraph/services/channel/wait_helpers.py +57 -0
- aethergraph/services/clock/clock.py +9 -0
- aethergraph/services/container/default_container.py +320 -0
- aethergraph/services/continuations/continuation.py +56 -0
- aethergraph/services/continuations/factory.py +34 -0
- aethergraph/services/continuations/stores/fs_store.py +264 -0
- aethergraph/services/continuations/stores/inmem_store.py +95 -0
- aethergraph/services/eventbus/inmem.py +21 -0
- aethergraph/services/features/static.py +10 -0
- aethergraph/services/kv/ephemeral.py +90 -0
- aethergraph/services/kv/factory.py +27 -0
- aethergraph/services/kv/layered.py +41 -0
- aethergraph/services/kv/sqlite_kv.py +128 -0
- aethergraph/services/llm/factory.py +157 -0
- aethergraph/services/llm/generic_client.py +542 -0
- aethergraph/services/llm/providers.py +3 -0
- aethergraph/services/llm/service.py +105 -0
- aethergraph/services/logger/base.py +36 -0
- aethergraph/services/logger/compat.py +50 -0
- aethergraph/services/logger/formatters.py +106 -0
- aethergraph/services/logger/std.py +203 -0
- aethergraph/services/mcp/helpers.py +23 -0
- aethergraph/services/mcp/http_client.py +70 -0
- aethergraph/services/mcp/mcp_tools.py +21 -0
- aethergraph/services/mcp/registry.py +14 -0
- aethergraph/services/mcp/service.py +100 -0
- aethergraph/services/mcp/stdio_client.py +70 -0
- aethergraph/services/mcp/ws_client.py +115 -0
- aethergraph/services/memory/bound.py +106 -0
- aethergraph/services/memory/distillers/episode.py +116 -0
- aethergraph/services/memory/distillers/rolling.py +74 -0
- aethergraph/services/memory/facade.py +633 -0
- aethergraph/services/memory/factory.py +78 -0
- aethergraph/services/memory/hotlog_kv.py +27 -0
- aethergraph/services/memory/indices.py +74 -0
- aethergraph/services/memory/io_helpers.py +72 -0
- aethergraph/services/memory/persist_fs.py +40 -0
- aethergraph/services/memory/resolver.py +152 -0
- aethergraph/services/metering/noop.py +4 -0
- aethergraph/services/prompts/file_store.py +41 -0
- aethergraph/services/rag/chunker.py +29 -0
- aethergraph/services/rag/facade.py +593 -0
- aethergraph/services/rag/index/base.py +27 -0
- aethergraph/services/rag/index/faiss_index.py +121 -0
- aethergraph/services/rag/index/sqlite_index.py +134 -0
- aethergraph/services/rag/index_factory.py +52 -0
- aethergraph/services/rag/parsers/md.py +7 -0
- aethergraph/services/rag/parsers/pdf.py +14 -0
- aethergraph/services/rag/parsers/txt.py +7 -0
- aethergraph/services/rag/utils/hybrid.py +39 -0
- aethergraph/services/rag/utils/make_fs_key.py +62 -0
- aethergraph/services/redactor/simple.py +16 -0
- aethergraph/services/registry/key_parsing.py +44 -0
- aethergraph/services/registry/registry_key.py +19 -0
- aethergraph/services/registry/unified_registry.py +185 -0
- aethergraph/services/resume/multi_scheduler_resume_bus.py +65 -0
- aethergraph/services/resume/router.py +73 -0
- aethergraph/services/schedulers/registry.py +41 -0
- aethergraph/services/secrets/base.py +7 -0
- aethergraph/services/secrets/env.py +8 -0
- aethergraph/services/state_stores/externalize.py +135 -0
- aethergraph/services/state_stores/graph_observer.py +131 -0
- aethergraph/services/state_stores/json_store.py +67 -0
- aethergraph/services/state_stores/resume_policy.py +119 -0
- aethergraph/services/state_stores/serialize.py +249 -0
- aethergraph/services/state_stores/utils.py +91 -0
- aethergraph/services/state_stores/validate.py +78 -0
- aethergraph/services/tracing/noop.py +18 -0
- aethergraph/services/waits/wait_registry.py +91 -0
- aethergraph/services/wakeup/memory_queue.py +57 -0
- aethergraph/services/wakeup/scanner_producer.py +56 -0
- aethergraph/services/wakeup/worker.py +31 -0
- aethergraph/tools/__init__.py +25 -0
- aethergraph/utils/optdeps.py +8 -0
- aethergraph-0.1.0a1.dist-info/METADATA +410 -0
- aethergraph-0.1.0a1.dist-info/RECORD +182 -0
- aethergraph-0.1.0a1.dist-info/WHEEL +5 -0
- aethergraph-0.1.0a1.dist-info/entry_points.txt +2 -0
- aethergraph-0.1.0a1.dist-info/licenses/LICENSE +176 -0
- aethergraph-0.1.0a1.dist-info/licenses/NOTICE +31 -0
- aethergraph-0.1.0a1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,220 @@
# aethergraph/core/execution/context.py
from __future__ import annotations

from collections.abc import Callable
from dataclasses import dataclass
from datetime import datetime
import importlib
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    from aethergraph.core.graph.task_node import TaskNodeRuntime

from aethergraph.services.clock.clock import SystemClock
from aethergraph.services.logger.std import StdLoggerService
from aethergraph.services.resume.router import ResumeRouter

from ..graph.graph_refs import GRAPH_INPUTS_NODE_ID, RESERVED_INJECTABLES
from .bound_memory import BoundMemoryAdapter
from .node_context import NodeContext
from .node_services import NodeServices


@dataclass
class ExecutionContext:
    run_id: str
    graph_id: str | None
    graph_inputs: dict[str, Any]
    outputs_by_node: dict[str, dict[str, Any]]
    services: NodeServices
    logger_factory: StdLoggerService
    clock: SystemClock
    resume_payload: dict[str, Any] | None = None
    should_run_fn: Callable[[], bool] | None = None
    resume_router: ResumeRouter | None = None  # ResumeRouter

    # Back-compat shim
    bound_memory: BoundMemoryAdapter | None = None

    def create_node_context(self, node: TaskNodeRuntime) -> NodeContext:
        return NodeContext(
            run_id=self.run_id,
            graph_id=self.graph_id or "",
            node_id=node.node_id,
            services=self.services,
            resume_payload=self.resume_payload,
            # back-compat for old ctx.mem()
            bound_memory=self.bound_memory,
        )

    # def as_node_context(self, ad) -> "NodeContext":
    #     """Create a NodeContext representing this execution context itself as a node.
    #     Useful for ad-hoc contexts that don't have real nodes.
    #     """
    #     return NodeContext(
    #         run_id=self.run_id,
    #         graph_id=self.graph_id or "",
    #         node_id="ad_",
    #         services=self.services,
    #         resume_payload=self.resume_payload,
    #         # back-compat for old ctx.mem()
    #         bound_memory=self.bound_memory,
    #     )

    # ----- helpers used by step forward() -----
    def now(self) -> datetime:
        return self.clock.now()

    def resolve(self, logic_ref: str):
        """Resolve a logic reference to a callable.
        NOTE: This is not used anymore; prefer get_logic().
        """
        # fallback dotted import
        mod, _, attr = logic_ref.rpartition(".")
        return getattr(importlib.import_module(mod), attr)

    def get_logic(self, logic_ref):
        """Resolve a logic reference to a callable.
        If a registry is available and the ref looks like a registry key, use it.
        Otherwise, if a dotted path, import it.
        Otherwise, return as-is (assumed callable).
        Args:
            logic_ref: A callable, dotted path string, or registry key string.
        Returns:
            The resolved callable.
        """
        if isinstance(logic_ref, str) and logic_ref.startswith("registry:") and self.registry:
            # registry key
            return self.registry.get_logic_ref(logic_ref)
        if isinstance(logic_ref, str):
            # dotted path fallback
            mod, _, attr = logic_ref.rpartition(".")
            return getattr(importlib.import_module(mod), attr)
        return logic_ref

    async def resolve_inputs(self, node) -> dict[str, Any]:
        """
        Materialize a node's input mapping by resolving:
        - {"_type":"arg","key":K} → graph input value (or optional default)
        - {"_type":"ref","from":NODE_ID,"key":OUT} → upstream node's output value
        - {"_type":"context","key":K,"default":D} → memory value (or D if missing)
        Works recursively over dicts/lists/tuples.

        The function works as follows:
        - If the value is a dict with "_type" of "arg", it looks up the graph input.
        - If the value is a dict with "_type" of "ref", it looks up
          the specified node's output.
        - If the value is a dict without special keys, it recursively resolves
          each key-value pair.
        - If the value is a list or tuple, it recursively resolves each element.
        - Otherwise, it returns the value as-is (assumed to be a constant).

        Args:
            node: The TaskNodeRuntime whose inputs to resolve.
        Returns:
            The fully resolved inputs dict for the node.
        Raises:
            KeyError: If a referenced graph input or node output is missing.
        """
        raw = getattr(node, "inputs", {}) or {}
        # Grab optional defaults from the graph spec if available
        opt_defaults: dict[str, Any] = {}
        parent_graph = getattr(node, "_parent_graph", None)
        if parent_graph and getattr(parent_graph, "spec", None):
            # _io_inputs_optional is a dict[str, Any]
            opt_defaults = getattr(parent_graph.spec, "inputs_optional", {}) or {}

        # Allow a fallback to graph.state.node_outputs if scheduler hasn't copied yet
        fallback_outputs = {}
        if parent_graph and getattr(parent_graph, "state", None):
            fallback_outputs = getattr(parent_graph.state, "node_outputs", {}) or {}

        def _err_path(msg: str, path: str):
            raise KeyError(
                f"{msg} (node={getattr(node, 'node_id', getattr(node, 'id', '?'))}, path={path})"
            )

        def _resolve_arg(marker: dict[str, Any], path: str):
            k = marker.get("key")
            if k is None:
                _err_path("Bad arg marker (missing 'key')", path)
            if k in self.graph_inputs:
                return self.graph_inputs[k]
            if k in opt_defaults:
                return opt_defaults[k]
            # Helpful error: show known keys
            known = list(self.graph_inputs.keys())
            _err_path(f"Graph input '{k}' not provided (known inputs: {known})", path)

        def _resolve_ref(marker: dict[str, Any], path: str):
            src = marker.get("from")
            out_key = marker.get("key")
            if src is None or out_key is None:
                _err_path("Bad ref marker (need 'from' and 'key')", path)

            # Tolerate someone emitting a ref to the inputs sentinel
            if src == GRAPH_INPUTS_NODE_ID:
                # Interpret as an 'arg' reference to graph inputs
                return _resolve_arg({"_type": "arg", "key": out_key}, path + ".__graph_inputs__")

            # Primary source: env.outputs_by_node (scheduler publishes here)
            if src in self.outputs_by_node:
                outs = self.outputs_by_node[src] or {}
                if out_key in outs:
                    return outs[out_key]
                _err_path(
                    f"Upstream node '{src}' has no output key '{out_key}'. "
                    f"Available: {list(outs.keys())}",
                    path,
                )

            # Fallback: graph state (useful during tests or if scheduler filled it there)
            if src in fallback_outputs:
                outs = fallback_outputs[src] or {}
                if out_key in outs:
                    return outs[out_key]
                _err_path(
                    f"(fallback) Upstream node '{src}' has no output key '{out_key}'. "
                    f"Available: {list(outs.keys())}",
                    path,
                )

            _err_path(f"Upstream node '{src}' outputs not available yet", path)

        def _resolve_any(val: Any, path: str):
            # Handle dict markers
            if isinstance(val, dict):
                t = val.get("_type")
                if t == "arg":
                    return _resolve_arg(val, path)
                if t == "ref":
                    return _resolve_ref(val, path)
                if t == "context":
                    return self.memory.read(val["key"], val.get("default"))
                # regular dict: recurse keys
                return {k: _resolve_any(v, f"{path}.{k}") for k, v in val.items()}

            # Handle list/tuple
            if isinstance(val, list):
                return [_resolve_any(v, f"{path}[{i}]") for i, v in enumerate(val)]
            if isinstance(val, tuple):
                return tuple(_resolve_any(v, f"{path}[{i}]") for i, v in enumerate(val))

            # Pass-through literal
            return val

        # Make sure we don't mutate node.inputs
        # materialized = _resolve_any(copy.deepcopy(raw), path="inputs")
        materialized = _resolve_any(raw, path="inputs")

        # Strip framework-reserved injectables from *user* inputs.
        # We always inject these later from the execution context.
        if isinstance(materialized, dict):
            for k in list(materialized.keys()):
                if k in RESERVED_INJECTABLES:
                    materialized.pop(k, None)

        # If someone put arguments under "kwargs", keep them;
        # build_call_kwargs will flatten and then drop "kwargs".
        # (No change needed here beyond not touching it.)
        return materialized
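For orientation, here is a minimal sketch of the marker shapes that ExecutionContext.resolve_inputs() above handles; the node name "retrieve" and all key names are hypothetical and not taken from the package:

# Hypothetical raw input mapping as it might appear on a node.
raw_inputs = {
    "query": {"_type": "arg", "key": "user_query"},                      # graph input (or optional default)
    "docs": {"_type": "ref", "from": "retrieve", "key": "chunks"},       # upstream node output
    "style": {"_type": "context", "key": "tone", "default": "neutral"},  # memory read with default
    "top_k": 5,                                                          # literal, passed through
    "extra": [{"_type": "arg", "key": "user_query"}, 7],                 # lists/dicts resolve recursively
}

# With graph_inputs = {"user_query": "hello"} and
# outputs_by_node = {"retrieve": {"chunks": ["a", "b"]}}, the materialized
# result would be roughly:
# {"query": "hello", "docs": ["a", "b"], "style": "neutral", "top_k": 5, "extra": ["hello", 7]}
# A missing graph input or upstream output raises KeyError tagged with the node id and input path.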
@@ -0,0 +1,349 @@
from __future__ import annotations

import asyncio
import threading
from typing import Any
import uuid

from aethergraph.contracts.errors.errors import GraphHasPendingWaits
from aethergraph.contracts.services.state_stores import GraphSnapshot
from aethergraph.core.runtime.recovery import hash_spec, recover_graph_run
from aethergraph.services.container.default_container import build_default_container
from aethergraph.services.state_stores.graph_observer import PersistenceObserver
from aethergraph.services.state_stores.resume_policy import (
    assert_snapshot_json_only,
)
from aethergraph.services.state_stores.utils import snapshot_from_graph

from ..execution.forward_scheduler import ForwardScheduler
from ..execution.retry_policy import RetryPolicy
from ..graph.graph_fn import GraphFunction
from ..graph.graph_refs import resolve_any as _resolve_any
from ..runtime.runtime_env import RuntimeEnv
from ..runtime.runtime_services import ensure_services_installed
from .run_registration import RunRegistrationGuard


# ---------- env helpers ----------
def _get_container():
    # install once if not installed by sidecar/server
    return ensure_services_installed(build_default_container)


async def _attach_persistence(graph, env, spec, snapshot_every=1) -> PersistenceObserver:
    """
    Wire the centralized state_store to the graph via PersistenceObserver.
    Returns the observer instance so caller can optionally force a final snapshot.
    """
    store = getattr(env.container, "state_store", None) or getattr(env, "state_store", None)
    if not store:
        # Safe no-op: resumability won't work but run still executes.
        return None

    obs = PersistenceObserver(
        store=store,
        artifact_store=getattr(env.container, "artifacts", None),
        spec_hash=hash_spec(spec),
        snapshot_every=snapshot_every,
    )
    graph.add_observer(obs)
    return obs


async def _build_env(
    owner, inputs: dict[str, Any], **rt_overrides
) -> tuple[RuntimeEnv, RetryPolicy, int]:
    container = _get_container()
    # apply optional overrides onto the container instance
    for k, v in rt_overrides.items():
        if v is not None and hasattr(container, k):
            setattr(container, k, v)

    run_id = rt_overrides.get("run_id") or f"run-{uuid.uuid4().hex[:8]}"
    env = RuntimeEnv(
        run_id=run_id,
        graph_inputs=inputs,
        outputs_by_node={},
        container=container,
    )
    retry = rt_overrides.get("retry") or RetryPolicy()
    max_conc = rt_overrides.get("max_concurrency", getattr(owner, "max_concurrency", 4))
    return env, retry, max_conc


# ---------- materialization ----------
def _materialize_task_graph(target) -> Any:
    """
    Accept:
    - TaskGraph instance (has io_signature attr)
    - graph builder object with .build()
    - a callable builder that returns a TaskGraph when invoked with no args
    """
    # already a TaskGraph
    if hasattr(target, "io_signature"):
        return target

    # builder pattern with .build()
    if hasattr(target, "build") and callable(target.build):
        g = target.build()
        if hasattr(g, "io_signature"):
            return g

    # callable builder that returns a TaskGraph
    if callable(target):
        g = target()
        if hasattr(g, "io_signature"):
            return g

    raise TypeError(
        "run_async: target must be a TaskGraph instance, a TaskGraph builder, "
        "or a callable returning a TaskGraph."
    )


def _resolve_graph_outputs(
    graph,
    inputs: dict[str, Any],
    env: RuntimeEnv,
):
    bindings = graph.io_signature().get("outputs", {}).get("bindings", {})

    def _res(b):
        return _resolve_any(b, graph_inputs=inputs, outputs_by_node=env.outputs_by_node)

    try:
        result = {k: _res(v) for k, v in bindings.items()}
    except KeyError as e:
        waiting = [
            nid
            for nid, n in graph.state.nodes.items()
            if getattr(n, "status", "").startswith("WAITING_")
        ]
        continuations = []
        if env.continuation_store and hasattr(env.continuation_store, "get"):
            for nid in waiting:
                cont = env.continuation_store.get(run_id=env.run_id, node_id=nid)
                if cont:
                    continuations.append(
                        {
                            "node_id": nid,
                            "kind": cont.kind,
                            "token": cont.token,
                            "channel": cont.channel,
                            "deadline": getattr(cont.deadline, "isoformat", lambda: None)(),
                        }
                    )
        raise GraphHasPendingWaits(
            "Graph quiesced with pending waits; outputs are not yet resolvable.",
            waiting_nodes=waiting,
            continuations=continuations,
        ) from e

    return next(iter(result.values())) if len(result) == 1 else result


def _resolve_graph_outputs_or_waits(graph, inputs, env, *, raise_on_waits: bool = True):
    try:
        return _resolve_graph_outputs(graph, inputs, env)
    except GraphHasPendingWaits as e:
        if raise_on_waits:
            raise
        return {
            "status": "waiting",
            "waiting_nodes": e.waiting_nodes,
            "continuations": e.continuations,
        }


def _seed_outputs_from_snapshot(env, snap: GraphSnapshot):
    env.outputs_by_node = env.outputs_by_node or {}
    nodes = snap.state.get("nodes", {})
    for nid, ns in nodes.items():
        outs = (ns or {}).get("outputs") or {}
        if outs:
            env.outputs_by_node[nid] = outs


def _is_graph_complete(snap: GraphSnapshot) -> bool:
    nodes = snap.state.get("nodes", {})
    if not nodes:
        return False

    # completed if every node is DONE/SKIPPED or has outputs matching spec
    def doneish(st):
        s = (st or {}).get("status", "")
        return s in ("DONE", "SKIPPED")

    return all(doneish(ns) for ns in nodes.values())


async def load_latest_snapshot_json(store, run_id: str) -> dict[str, Any] | None:
    """
    Returns the raw JSON dict of the latest snapshot (or None).
    """
    snap = await store.load_latest_snapshot(run_id)
    if not snap:
        return None
    # JsonGraphStateStore serializes GraphSnapshot via snap.__dict__
    # load_latest_snapshot already returns a GraphSnapshot(**jsondict).
    # Convert back to plain JSON-ish dict:
    return {
        "run_id": snap.run_id,
        "graph_id": snap.graph_id,
        "rev": snap.rev,
        "spec_hash": snap.spec_hash,
        "state": snap.state,
    }


# ---------- public API ----------
async def run_async(target, inputs: dict[str, Any] | None = None, **rt_overrides):
    """
    Generic async runner for TaskGraph or GraphFunction.
    - GraphFunction → delegates to gf.run(env=..., **inputs)
    - TaskGraph/builder → schedules and resolves graph-level outputs
    """
    inputs = inputs or {}
    # GraphFunction path
    if isinstance(target, GraphFunction):
        env, retry, max_conc = await _build_env(target, inputs, **rt_overrides)
        return await target.run(env=env, max_concurrency=max_conc, **inputs)

    # TaskGraph path
    graph = _materialize_task_graph(target)
    env, retry, max_conc = await _build_env(graph, inputs, **rt_overrides)

    # Extract spec for run/recovery ...
    spec = getattr(graph, "spec", None) or getattr(graph, "get_spec", lambda: None)()
    if spec is None:
        spec = graph.spec

    store = getattr(env.container, "state_store", None)
    snap = None
    assert store is None or hasattr(
        store, "load_latest_snapshot"
    ), "state_store must implement load_latest_snapshot(run_id)"

    if store:
        # 1) Attempt cold-resume (build a graph with hydrated state)
        graph = await recover_graph_run(spec=spec, run_id=env.run_id, store=store)

        # 2) Load raw JSON snapshot and ENFORCE strict policy
        snap_json = await load_latest_snapshot_json(store, env.run_id)
        if snap_json:
            # keep for short-circuit + seeding
            snap = await store.load_latest_snapshot(env.run_id)
            # Short-circuit if already complete
            if snap:
                _seed_outputs_from_snapshot(env, snap)
                if _is_graph_complete(snap):
                    return _resolve_graph_outputs(graph, inputs, env)

            # strict policy: block resume if any non-JSON / __aether_ref__ is present
            assert_snapshot_json_only(env.run_id, snap_json, mode="reuse_only")
    else:
        graph = _materialize_task_graph(target)

    # Bind/validate inputs
    graph._validate_and_bind_inputs(inputs)

    # Attach persistence observer + run (unchanged) ...
    obs = await _attach_persistence(graph, env, spec, snapshot_every=1)

    # get logger from env's container
    from ..runtime.runtime_services import current_logger_factory

    logger = current_logger_factory().for_scheduler()

    sched = ForwardScheduler(
        graph,
        env,
        retry_policy=retry,
        max_concurrency=max_conc,
        skip_dep_on_failure=True,
        stop_on_first_error=True,
        logger=logger,
    )

    # Register for resumes and run
    with RunRegistrationGuard(run_id=env.run_id, scheduler=sched, container=env.container):
        try:
            await sched.run()
        except asyncio.CancelledError:
            raise
        finally:
            # FINAL SNAPSHOT on normal or cancelled exit (if store exists)
            if store and obs:
                artifacts = getattr(env.container, "artifacts", None)
                snap = await snapshot_from_graph(
                    run_id=graph.state.run_id or env.run_id,
                    graph_id=graph.graph_id,
                    rev=graph.state.rev,
                    spec_hash=hash_spec(spec),
                    state_obj=graph.state,
                    artifacts=artifacts,
                    allow_externalize=False,  # FIXME: artifact writer async loop error; set False to *avoid* writing artifacts during snapshot
                    include_wait_spec=True,
                )
                await store.save_snapshot(snap)

    # Resolve graph-level outputs (will raise GraphHasPendingWaits if waits)
    return _resolve_graph_outputs_or_waits(graph, inputs, env, raise_on_waits=True)


async def run_or_resume_async(
    target, inputs: dict[str, Any], *, run_id: str | None = None, **rt_overrides
):
    """
    If state exists for run_id → cold resume, else fresh run.
    Exactly the same signature as run_async plus optional run_id.
    """
    if run_id is not None:
        rt_overrides = dict(rt_overrides or {}, run_id=run_id)
    return await run_async(target, inputs, **rt_overrides)


# sync adapter (optional, safe in notebooks/servers)
class _LoopThread:
    def __init__(self):
        self._ev = threading.Event()
        self._thread = threading.Thread(target=self._worker, daemon=True)
        self._loop = None
        self._thread.start()
        self._ev.wait()

    def _worker(self):
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        self._loop = loop
        self._ev.set()
        loop.run_forever()

    def submit_old(self, coro):
        # this will block terminal until coro is done
        fut = asyncio.run_coroutine_threadsafe(coro, self._loop)
        return fut.result()

    def submit(self, coro):
        # this will allow KeyboardInterrupt to propagate -> still not perfect. Use async main if possible.
        fut = asyncio.run_coroutine_threadsafe(coro, self._loop)
        try:
            return fut.result()
        except KeyboardInterrupt:
            # cancel the task in the loop thread and wait for cleanup
            fut.cancel()

            def _cancel_all():
                for t in asyncio.all_tasks(self._loop):
                    t.cancel()

            self._loop.call_soon_threadsafe(_cancel_all)
            raise


_LOOP = _LoopThread()


def run(target, inputs: dict[str, Any] | None = None, **rt_overrides):
    inputs = inputs or {}
    return _LOOP.submit(run_async(target, inputs, **rt_overrides))
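A usage sketch for the public runner API above (this 349-line hunk appears to be aethergraph/core/runtime/graph_runner.py, the only +349 file in the listing; `my_graph` and the input keys below are placeholders, not part of the package):

from aethergraph.core.runtime.graph_runner import run, run_async, run_or_resume_async

async def main(my_graph):
    # `my_graph` stands in for a TaskGraph, a builder exposing .build(),
    # or a zero-argument callable that returns a TaskGraph.
    outputs = await run_async(my_graph, {"user_query": "hello"})

    # Supplying a run_id cold-resumes from the latest persisted snapshot
    # for that run, or starts fresh if none exists.
    outputs = await run_or_resume_async(my_graph, {"user_query": "hello"}, run_id="run-abc12345")
    return outputs

# In synchronous code (scripts, notebooks) the module-level run() adapter submits
# the coroutine to a background event-loop thread:
# result = run(my_graph, {"user_query": "hello"})

Note that when a run quiesces on external waits, run_async raises GraphHasPendingWaits carrying the waiting node ids and any registered continuations instead of returning partial outputs.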
@@ -0,0 +1,26 @@
# lifecycle.py
import asyncio

from aethergraph.core.runtime.runtime_services import current_services


async def start_all_services() -> None:
    svc = current_services()
    tasks = []
    for _, inst in getattr(svc, "ext_services", {}).items():
        start = getattr(inst, "start", None)
        if asyncio.iscoroutinefunction(start):
            tasks.append(start())
    if tasks:
        await asyncio.gather(*tasks)


async def close_all_services() -> None:
    svc = current_services()
    tasks = []
    for _, inst in getattr(svc, "ext_services", {}).items():
        close = getattr(inst, "close", None)
        if asyncio.iscoroutinefunction(close):
            tasks.append(close())
    if tasks:
        await asyncio.gather(*tasks)
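A minimal sketch of how these lifecycle helpers would bracket a run, assuming external services have already been registered on the current container; the serve() wrapper and its body are illustrative only:

from aethergraph.core.runtime.lifecycle import close_all_services, start_all_services

async def serve():
    # Concurrently awaits the async start() of every registered ext_service.
    await start_all_services()
    try:
        ...  # run graphs, handle requests, etc.
    finally:
        # Concurrently awaits the async close() of every registered ext_service.
        await close_all_services()

Because both helpers only gather coroutine functions, services with synchronous start()/close() methods are skipped and must be managed separately.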