revive-sdk 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- revive/__init__.py +45 -0
- revive/adapters/__init__.py +5 -0
- revive/adapters/anthropic_tools.py +60 -0
- revive/adapters/langgraph.py +89 -0
- revive/adapters/openai_agents.py +63 -0
- revive/adapters/temporal.py +67 -0
- revive/checkpoint.py +199 -0
- revive/classifier.py +198 -0
- revive/client.py +145 -0
- revive/engine.py +294 -0
- revive/postgres.py +153 -0
- revive/providers.py +89 -0
- revive/rendezvous.py +113 -0
- revive/reporter.py +60 -0
- revive_sdk-0.1.0.dist-info/METADATA +99 -0
- revive_sdk-0.1.0.dist-info/RECORD +17 -0
- revive_sdk-0.1.0.dist-info/WHEEL +4 -0
revive/__init__.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
"""Revive — agent recovery control plane SDK.
|
|
2
|
+
|
|
3
|
+
Primary entrypoint (talks to the hosted control plane, stdlib-only, no deps):
|
|
4
|
+
|
|
5
|
+
from revive import ReviveClient
|
|
6
|
+
revive = ReviveClient("https://revivelabs.app", api_key="rv_live_…")
|
|
7
|
+
result = revive.protect_action(run_id=..., connection_id=..., action_key=...,
|
|
8
|
+
execute=lambda: do_the_side_effect())
|
|
9
|
+
|
|
10
|
+
When a protected action's credential dies mid-run, Revive classifies the
|
|
11
|
+
death, parks the run, routes a single-use reauthorization to the bound
|
|
12
|
+
account, rotates the credential lease, and resumes the same logical run
|
|
13
|
+
without duplicating a side effect that already committed.
|
|
14
|
+
|
|
15
|
+
Framework adapters live under revive.adapters (OpenAI Agents, Anthropic tool
|
|
16
|
+
use, LangGraph). The local self-hosted engine (Engine, CheckpointStore, …)
|
|
17
|
+
remains available for running the park/resume loop in-process.
|
|
18
|
+
"""
|
|
19
|
+
# Hosted SDK entrypoint — the primary developer API.
|
|
20
|
+
from .client import (AmbiguousCommitError, ParkedRun, ReviveClient,
|
|
21
|
+
ReviveParkedError, idempotency_key)
|
|
22
|
+
|
|
23
|
+
from .checkpoint import Checkpoint, CheckpointStore
|
|
24
|
+
from .classifier import ClassifierResult, Verdict, classify
|
|
25
|
+
from .engine import (AmbiguousSideEffect, Completed, Engine, NeedsApproval,
|
|
26
|
+
Parked, StaleCredentialGeneration, Step, StepContext,
|
|
27
|
+
WrongRecoveryIdentity)
|
|
28
|
+
from .providers import AuthError, Provider, Token, TokenError
|
|
29
|
+
from .rendezvous import Kind, Rendezvous, console_channel, webhook_channel
|
|
30
|
+
from .postgres import PostgresCheckpointStore
|
|
31
|
+
from .reporter import Reporter
|
|
32
|
+
|
|
33
|
+
__all__ = [
|
|
34
|
+
"ReviveClient", "ReviveParkedError", "ParkedRun", "AmbiguousCommitError", "idempotency_key",
|
|
35
|
+
"Engine", "Step", "StepContext", "Parked", "Completed", "NeedsApproval",
|
|
36
|
+
"AmbiguousSideEffect",
|
|
37
|
+
"WrongRecoveryIdentity", "StaleCredentialGeneration",
|
|
38
|
+
"Provider", "Token", "TokenError", "AuthError",
|
|
39
|
+
"CheckpointStore", "Checkpoint",
|
|
40
|
+
"PostgresCheckpointStore",
|
|
41
|
+
"Reporter",
|
|
42
|
+
"classify", "ClassifierResult", "Verdict",
|
|
43
|
+
"Rendezvous", "Kind", "console_channel", "webhook_channel",
|
|
44
|
+
]
|
|
45
|
+
__version__ = "0.1.0"
|
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
"""Anthropic tool-use adapter (adapter preview).
|
|
2
|
+
|
|
3
|
+
Guard the tool-execution side of an Anthropic agent loop (Claude API tool use
|
|
4
|
+
or the Claude Agent SDK): each tool_use block that mutates an external system
|
|
5
|
+
becomes a protected action in the Revive ledger.
|
|
6
|
+
|
|
7
|
+
from revive.client import ReviveClient
|
|
8
|
+
from revive.adapters.anthropic_tools import ReviveToolGuard
|
|
9
|
+
|
|
10
|
+
revive = ReviveClient(base_url="https://console.revivelabs.app", api_key="rv_live_…")
|
|
11
|
+
guard = ReviveToolGuard(revive, connection_id="conn_…", protected={"send_email", "create_ticket"})
|
|
12
|
+
|
|
13
|
+
# inside the agent loop, for each tool_use content block:
|
|
14
|
+
result = guard.run(run_id=thread_id, tool_name=block.name, tool_input=block.input,
|
|
15
|
+
execute=lambda: dispatch_tool(block.name, block.input))
|
|
16
|
+
|
|
17
|
+
Committed replays return the stored result without re-dispatching; a dead
|
|
18
|
+
credential raises ReviveParkedError carrying the recovery URL, which the loop
|
|
19
|
+
should return to the model/user as the tool result.
|
|
20
|
+
"""
|
|
21
|
+
|
|
22
|
+
from __future__ import annotations
|
|
23
|
+
|
|
24
|
+
import hashlib
|
|
25
|
+
import json
|
|
26
|
+
from typing import Any, Callable, Iterable, Optional
|
|
27
|
+
|
|
28
|
+
from revive.client import ReviveClient
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class ReviveToolGuard:
    """Route an agent's tool executions through ``client.protect_action``.

    Tools named in ``protected`` (or every tool, when ``protected`` is
    ``None``) are submitted to the client together with an idempotency key
    derived from the run id, the tool name and the tool input. Any other
    tool is executed directly, without touching the client.
    """

    def __init__(
        self,
        client: ReviveClient,
        *,
        connection_id: str,
        protected: Optional[Iterable[str]] = None,
        lease_generation: Optional[int] = None,
    ):
        """Store the client and the guard configuration.

        Args:
            client: Object exposing ``protect_action(...)``.
            connection_id: Forwarded unchanged on every protected call.
            protected: Tool names to guard. ``None`` guards every tool. A
                single name given as a plain string guards just that tool.
            lease_generation: Forwarded unchanged on every protected call.
        """
        self.client = client
        self.connection_id = connection_id
        if isinstance(protected, str):
            # set("send_email") would produce a set of single characters, so
            # the intended tool would silently never be guarded.
            protected = (protected,)
        self.protected = set(protected) if protected is not None else None
        self.lease_generation = lease_generation

    def is_protected(self, tool_name: str) -> bool:
        """Return True when ``tool_name`` must go through the client."""
        return self.protected is None or tool_name in self.protected

    def run(self, *, run_id: str, tool_name: str, tool_input: Any, execute: Callable[[], Any]) -> Any:
        """Execute one tool call, guarding it when the tool is protected.

        Args:
            run_id: Identifier of the run the tool call belongs to.
            tool_name: Name of the tool; also used as the ``action_key``.
            tool_input: Tool arguments; serialised to build the idempotency key.
            execute: Zero-argument callable that performs the tool call.

        Returns:
            ``execute()`` for unprotected tools, otherwise whatever
            ``client.protect_action`` returns.
        """
        if not self.is_protected(tool_name):
            return execute()
        # sort_keys makes the key independent of dict ordering; values that
        # are not JSON-serialisable fall back to repr().
        # NOTE(review): a repr() containing a memory address would make the
        # key differ between processes - confirm inputs are plain JSON.
        payload = json.dumps(tool_input, sort_keys=True, default=repr)
        derived = hashlib.sha256(f"{run_id}:{tool_name}:{payload}".encode()).hexdigest()
        return self.client.protect_action(
            run_id=run_id,
            connection_id=self.connection_id,
            action_key=tool_name,
            idem_key=derived,
            lease_generation=self.lease_generation,
            execute=execute,
        )
|
|
@@ -0,0 +1,89 @@
|
|
|
1
|
+
"""LangGraph adapter.
|
|
2
|
+
|
|
3
|
+
LangGraph already ships durable checkpointing and `interrupt()` / `Command(resume=)`
|
|
4
|
+
for human-in-the-loop. What it lacks is the auth-death TRIGGER and the re-consent
|
|
5
|
+
+ credential rotation. This adapter supplies that trigger, reusing LangGraph's own durable
|
|
6
|
+
pause/resume — so an existing LangGraph agent gains dead-reauth-resume with a
|
|
7
|
+
single call inside its nodes.
|
|
8
|
+
|
|
9
|
+
from revive.adapters.langgraph import revive_refresh
|
|
10
|
+
|
|
11
|
+
def files_node(state):
|
|
12
|
+
try:
|
|
13
|
+
data = call_graph_api(state["access_token"])
|
|
14
|
+
except Unauthorized:
|
|
15
|
+
tok = revive_refresh(provider, state["refresh_token"], SCOPES) # pauses if dead
|
|
16
|
+
state.update(access_token=tok.access_token, refresh_token=tok.refresh_token)
|
|
17
|
+
data = call_graph_api(tok.access_token)
|
|
18
|
+
...
|
|
19
|
+
|
|
20
|
+
When the refresh token is dead, `revive_refresh` calls LangGraph's `interrupt()`,
|
|
21
|
+
which durably parks the run via the configured checkpointer. The driver routes
|
|
22
|
+
the re-consent and resumes with `Command(resume={"connection_id": ..., "lease_generation": ...})`; the
|
|
23
|
+
node re-executes, `interrupt()` returns that payload, and the fresh credential
|
|
24
|
+
generation is used by the same logical thread.
|
|
25
|
+
"""
|
|
26
|
+
from __future__ import annotations
|
|
27
|
+
|
|
28
|
+
from ..classifier import ClassifierResult, classify
|
|
29
|
+
from ..providers import Provider, Token, TokenError
|
|
30
|
+
|
|
31
|
+
# LangGraph is an optional dependency: use its real ``interrupt`` when the
# import succeeds, otherwise install a stand-in that fails loudly when called.
try:
    from langgraph.types import interrupt  # type: ignore
    HAS_LANGGRAPH = True
except Exception:  # pragma: no cover - optional dependency
    # NOTE(review): this catches every Exception, not only ImportError -
    # presumably so a broken langgraph install also falls back; confirm.
    HAS_LANGGRAPH = False

    def interrupt(value):  # type: ignore
        """Stand-in for ``langgraph.types.interrupt``; always raises RuntimeError."""
        raise RuntimeError("langgraph is not installed: pip install langgraph")
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def reconsent_payload(result: ClassifierResult, scopes) -> dict:
    """Build the dict handed to ``interrupt()`` when re-consent is required.

    Copies ``code``, ``title``, ``reason``, ``remediation`` and
    ``confidence`` from ``result``, tags the payload with
    ``kind="reconsent"`` and stores ``scopes`` as a list.
    """
    payload = {"kind": "reconsent"}
    # Insertion order is kept identical to the literal this replaces.
    for attribute in ("code", "title", "reason", "remediation"):
        payload[attribute] = getattr(result, attribute)
    payload["scopes"] = list(scopes)
    payload["confidence"] = result.confidence
    return payload
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def revive_refresh(provider: Provider, refresh_token: str, scopes=(),
                   credential_resolver=None) -> Token:
    """Refresh a token, pausing the graph for re-consent when it is dead.

    ``provider.refresh(refresh_token)`` is tried first. If it raises
    ``TokenError`` and the classifier says no human is needed, the error is
    re-raised unchanged for the caller's retry policy. Otherwise the run is
    paused through ``interrupt()`` with a re-consent payload, and the value
    the driver resumes with selects the new refresh token:

    * a dict with a truthy ``connection_id`` - the opaque form. The token is
      fetched through ``credential_resolver(connection_id, lease_generation)``
      (generation defaults to 1), so the raw token stays out of the resume
      payload, e.g. ``Command(resume={"connection_id": "...",
      "lease_generation": 2})``;
    * any other dict - its ``refresh_token`` entry (sandbox-only fallback);
    * anything else - the value itself (sandbox-only fallback).

    Raises:
        TokenError: when the failure does not need a human.
        RuntimeError: when the resume payload names a connection but no
            ``credential_resolver`` was supplied.
    """
    try:
        return provider.refresh(refresh_token)
    except TokenError as err:
        verdict = classify(err.payload)
        if not verdict.needs_human:
            raise

        # Durable pause: interrupt() hands back whatever the driver resumes with.
        resumed = interrupt(reconsent_payload(verdict, scopes))
        is_mapping = isinstance(resumed, dict)

        if is_mapping and resumed.get("connection_id"):
            if credential_resolver is None:
                raise RuntimeError(
                    "resume payload references a connection, but no "
                    "credential_resolver was configured"
                )
            generation = int(resumed.get("lease_generation") or 1)
            rotated = credential_resolver(resumed["connection_id"], generation)
        elif is_mapping:
            # Sandbox-only fallback: raw token inside the resume dict.
            rotated = resumed["refresh_token"]
        else:
            # Sandbox-only fallback: the resume value is the raw token.
            rotated = resumed
        return provider.refresh(rotated)
|
|
@@ -0,0 +1,63 @@
|
|
|
1
|
+
"""OpenAI Agents SDK adapter (adapter preview).
|
|
2
|
+
|
|
3
|
+
Wrap a function tool so every invocation is a protected action in the Revive
|
|
4
|
+
ledger: replays of committed calls return the stored result, uncertain calls
|
|
5
|
+
demand reconciliation, and credential failures park the run with a recovery
|
|
6
|
+
case instead of crashing the agent.
|
|
7
|
+
|
|
8
|
+
from agents import function_tool
|
|
9
|
+
from revive.client import ReviveClient
|
|
10
|
+
from revive.adapters.openai_agents import revive_tool
|
|
11
|
+
|
|
12
|
+
revive = ReviveClient(base_url="https://console.revivelabs.app", api_key="rv_live_…")
|
|
13
|
+
|
|
14
|
+
@function_tool
|
|
15
|
+
@revive_tool(revive, connection_id="conn_…", run_id_from="run_id")
|
|
16
|
+
def send_followup_email(run_id: str, to: str, subject: str) -> dict:
|
|
17
|
+
... # real provider call
|
|
18
|
+
|
|
19
|
+
The wrapped function raises ReviveParkedError when the credential dies; the
|
|
20
|
+
agent loop should surface parked.recovery_url to the account owner.
|
|
21
|
+
"""
|
|
22
|
+
|
|
23
|
+
from __future__ import annotations
|
|
24
|
+
|
|
25
|
+
import functools
|
|
26
|
+
import json
|
|
27
|
+
from typing import Any, Callable, Optional
|
|
28
|
+
|
|
29
|
+
from revive.client import ReviveClient
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def revive_tool(
    client: ReviveClient,
    *,
    connection_id: str,
    run_id_from: str = "run_id",
    action_key: Optional[str] = None,
    lease_generation: Optional[int] = None,
):
    """Decorator: one exactly-once ledger entry per (run, tool, arguments).

    Every call of the decorated function is routed through
    ``client.protect_action`` with an idempotency key derived from the run
    id, the function name and the ``repr`` of the call arguments.

    Args:
        client: Object exposing ``protect_action(...)``.
        connection_id: Forwarded unchanged on every call.
        run_id_from: Name of the wrapped function's parameter that carries
            the run id. It may be passed by keyword or positionally.
        action_key: Ledger action key; defaults to the function's name.
        lease_generation: Forwarded unchanged on every call.

    Raises:
        ValueError: (from the wrapper) when no non-empty run id can be found
            in the call arguments.
    """
    # Imported once per decorator instead of on every tool invocation.
    import hashlib
    import inspect

    def decorate(func: Callable[..., Any]) -> Callable[..., Any]:
        try:
            signature = inspect.signature(func)
        except (TypeError, ValueError):
            # Some callables expose no signature; keyword lookup still works.
            signature = None

        def _resolve_run_id(args: tuple, kwargs: dict) -> str:
            """Find the run id among keyword, then positional, arguments."""
            value = kwargs.get(run_id_from)
            if not value and signature is not None:
                # Tool runners may pass ordinary parameters positionally, so
                # map positions back to parameter names before giving up.
                try:
                    bound = signature.bind_partial(*args, **kwargs)
                except TypeError:
                    return ""
                value = bound.arguments.get(run_id_from)
            return str(value or "")

        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            run_id = _resolve_run_id(args, kwargs)
            if not run_id:
                raise ValueError(f"revive_tool requires the '{run_id_from}' keyword argument")
            # Key derivation is deliberately unchanged so keys computed by
            # earlier keyword-style calls stay identical.
            key_material = json.dumps(
                {
                    "args": [repr(a) for a in args],
                    "kwargs": {k: repr(v) for k, v in sorted(kwargs.items())},
                },
                sort_keys=True,
            )
            derived = hashlib.sha256(f"{run_id}:{func.__name__}:{key_material}".encode()).hexdigest()
            return client.protect_action(
                run_id=run_id,
                connection_id=connection_id,
                action_key=action_key or func.__name__,
                idem_key=derived,
                lease_generation=lease_generation,
                execute=lambda: func(*args, **kwargs),
            )

        return wrapper

    return decorate
|
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
"""Temporal signal adapter for Revive recovery rendezvous.
|
|
2
|
+
|
|
3
|
+
Activities detect/classify credential failures. A workflow parks by waiting on a
|
|
4
|
+
RecoveryGate. The OAuth callback signals the existing workflow ID with a new
|
|
5
|
+
opaque lease reference; raw refresh tokens do not enter workflow history.
|
|
6
|
+
"""
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
from dataclasses import asdict, dataclass
|
|
10
|
+
from typing import Any, Optional
|
|
11
|
+
|
|
12
|
+
SIGNAL_NAME = "revive_reauthorized"
|
|
13
|
+
|
|
14
|
+
# temporalio is an optional dependency: use the real ``workflow`` module when
# it imports, otherwise bind ``workflow`` to a minimal stand-in so this module
# (and classes decorated with ``@workflow.signal``) can still be imported.
try:
    from temporalio import workflow  # type: ignore
    HAS_TEMPORAL = True
except ImportError:  # pragma: no cover - optional dependency
    HAS_TEMPORAL = False

    class _WorkflowFallback:
        """Stand-in exposing the two ``workflow`` members this module uses."""

        @staticmethod
        def signal(*args, **kwargs):
            """Return a decorator that leaves the decorated function unchanged."""
            def decorate(fn):
                return fn
            return decorate

        @staticmethod
        async def wait_condition(predicate):
            """Always fail: waiting on a condition needs the real Temporal runtime."""
            # The hint names the PyPI package that provides ``temporalio``
            # rather than an extra of a differently named distribution.
            raise RuntimeError("temporalio is not installed: pip install temporalio")

    workflow = _WorkflowFallback()  # type: ignore
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
@dataclass(frozen=True)
class ReauthorizationSignal:
    """Immutable payload describing a reauthorization to signal to a workflow.

    ``TemporalRecoveryClient.resume`` converts an instance to a plain dict
    with ``asdict`` before sending it as the signal argument.
    """

    # Identifies the recovery case; ``RecoveryGate`` keys received signals by
    # the ``recovery_case_id`` entry of the payload.
    recovery_case_id: str
    # Presumably the connection whose credential was reauthorized - confirm.
    connection_id: str
    # Presumably the credential lease generation after rotation - confirm.
    lease_generation: int
    # Presumably the identity/OAuth provider name - confirm.
    provider: str
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class RecoveryGate:
    """Mixin/state holder for a Temporal workflow that can wait for reauthorization.

    Signal payloads are kept in a dict keyed by recovery case id until the
    matching ``wait_for_reauthorization`` call collects them.
    """

    def __init__(self) -> None:
        # recovery_case_id -> copy of the most recent signal payload
        self._revive_signals: dict[str, dict[str, Any]] = {}

    @workflow.signal(name=SIGNAL_NAME)
    def revive_reauthorized(self, payload: dict[str, Any]) -> None:
        """Signal handler: remember a copy of ``payload`` under its case id."""
        self._revive_signals.update(
            {str(payload["recovery_case_id"]): dict(payload)}
        )

    async def wait_for_reauthorization(self, recovery_case_id: str) -> dict[str, Any]:
        """Block until a signal for ``recovery_case_id`` arrived, then return
        its payload and forget it."""

        def _arrived() -> bool:
            return recovery_case_id in self._revive_signals

        await workflow.wait_condition(_arrived)
        return self._revive_signals.pop(recovery_case_id)
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
class TemporalRecoveryClient:
    """Sends the reauthorization signal to an existing workflow.

    ``client`` must expose ``get_workflow_handle(workflow_id, ...)`` returning
    a handle with an awaitable ``signal(name, payload)`` method.
    """

    def __init__(self, client: Any, signal_name: str = SIGNAL_NAME):
        self.client, self.signal_name = client, signal_name

    async def resume(self, workflow_id: str, signal: ReauthorizationSignal,
                     run_id: Optional[str] = None) -> None:
        """Signal ``workflow_id`` with ``signal`` converted to a plain dict.

        ``run_id`` is passed to ``get_workflow_handle`` only when truthy.
        """
        if run_id:
            target = self.client.get_workflow_handle(workflow_id, run_id=run_id)
        else:
            target = self.client.get_workflow_handle(workflow_id)
        await target.signal(self.signal_name, asdict(signal))
|
revive/checkpoint.py
ADDED
|
@@ -0,0 +1,199 @@
|
|
|
1
|
+
"""Durable, step-accurate checkpoints — real SQLite persistence.
|
|
2
|
+
|
|
3
|
+
A checkpoint is captured the instant a run parks (dead token, or any rendezvous).
|
|
4
|
+
Because it is durable, the run survives a process restart: a worker can pick the
|
|
5
|
+
run back up and resume it from the exact step after the human responds.
|
|
6
|
+
"""
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
import json
|
|
10
|
+
import sqlite3
|
|
11
|
+
import time
|
|
12
|
+
from dataclasses import asdict, dataclass, field
|
|
13
|
+
from typing import Any, Optional
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@dataclass
class Checkpoint:
    """Snapshot of a run's position, as saved and loaded by ``CheckpointStore``."""

    # Primary key of the ``checkpoints`` table: one checkpoint per run.
    run_id: str
    # Position and identifier of the step the snapshot refers to.
    step_index: int
    step_id: str
    # Free-form resume state; stored as JSON text by ``CheckpointStore.save``.
    cursor: dict[str, Any]
    # Presumably a non-secret fingerprint of the credential in use - confirm.
    token_fingerprint: str
    # Stored as a JSON list by ``CheckpointStore.save``.
    scopes: list[str]
    status: str = "parked"  # parked | running | done | dead
    # Creation time in seconds since the epoch (``time.time()``).
    taken_at: float = field(default_factory=time.time)

    def to_dict(self) -> dict:
        """Return the checkpoint's fields as a plain dict (``dataclasses.asdict``)."""
        return asdict(self)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class CheckpointStore:
    """SQLite persistence for checkpoints, rendezvous, action attempts and
    credential leases.

    A single connection is opened with ``check_same_thread=False``. Every
    mutating statement runs inside ``with self._conn:`` so it is committed on
    success and rolled back on failure; a failed write therefore never leaves
    an implicit transaction open on the shared connection.

    The store can be used as a context manager, which closes the connection
    on exit.
    """

    def __init__(self, path: str = "revive.db"):
        """Open (or create) the database at ``path`` and ensure all tables exist."""
        self.path = path
        self._conn = sqlite3.connect(path, check_same_thread=False)
        with self._conn:
            self._conn.execute(
                """CREATE TABLE IF NOT EXISTS checkpoints (
                    run_id TEXT PRIMARY KEY,
                    step_index INTEGER, step_id TEXT, cursor TEXT,
                    token_fingerprint TEXT, scopes TEXT, status TEXT, taken_at REAL
                )"""
            )
            self._conn.execute(
                """CREATE TABLE IF NOT EXISTS rendezvous (
                    id TEXT PRIMARY KEY, run_id TEXT UNIQUE, kind TEXT, prompt TEXT,
                    url TEXT, context TEXT, status TEXT, reply TEXT,
                    created_at REAL, expires_at REAL
                )"""
            )
            self._conn.execute(
                """CREATE TABLE IF NOT EXISTS action_attempts (
                    action_id TEXT PRIMARY KEY, run_id TEXT, step_id TEXT,
                    state TEXT, attempts INTEGER, updated_at REAL
                )"""
            )
            self._conn.execute(
                """CREATE TABLE IF NOT EXISTS credential_leases (
                    lease_id TEXT PRIMARY KEY, generation INTEGER NOT NULL,
                    updated_at REAL NOT NULL
                )"""
            )

    def __enter__(self) -> "CheckpointStore":
        """Return the store itself so it can be used in a ``with`` statement."""
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        """Close the connection when the ``with`` block ends."""
        self.close()

    def _write(self, sql: str, params: tuple = ()) -> sqlite3.Cursor:
        """Run one mutating statement; commit on success, roll back on error."""
        with self._conn:
            return self._conn.execute(sql, params)

    def save(self, cp: Checkpoint) -> None:
        """Insert the checkpoint, or overwrite the existing one for its run."""
        self._write(
            """INSERT INTO checkpoints
                (run_id, step_index, step_id, cursor, token_fingerprint, scopes, status, taken_at)
                VALUES (?,?,?,?,?,?,?,?)
                ON CONFLICT(run_id) DO UPDATE SET
                step_index=excluded.step_index, step_id=excluded.step_id,
                cursor=excluded.cursor, token_fingerprint=excluded.token_fingerprint,
                scopes=excluded.scopes, status=excluded.status, taken_at=excluded.taken_at""",
            (cp.run_id, cp.step_index, cp.step_id, json.dumps(cp.cursor),
             cp.token_fingerprint, json.dumps(cp.scopes), cp.status, cp.taken_at),
        )

    def load(self, run_id: str) -> Optional[Checkpoint]:
        """Return the checkpoint stored for ``run_id``, or ``None``."""
        row = self._conn.execute(
            "SELECT run_id, step_index, step_id, cursor, token_fingerprint, scopes, status, taken_at "
            "FROM checkpoints WHERE run_id=?", (run_id,)
        ).fetchone()
        if not row:
            return None
        return Checkpoint(
            run_id=row[0], step_index=row[1], step_id=row[2], cursor=json.loads(row[3]),
            token_fingerprint=row[4], scopes=json.loads(row[5]), status=row[6], taken_at=row[7],
        )

    def set_status(self, run_id: str, status: str) -> None:
        """Overwrite the status of the checkpoint for ``run_id`` (no-op if absent)."""
        self._write("UPDATE checkpoints SET status=? WHERE run_id=?", (status, run_id))

    def save_rendezvous(self, data: dict[str, Any]) -> None:
        """Insert a rendezvous, or replace the existing one for the same run.

        ``context`` defaults to ``{}`` and ``status`` to ``"open"``; ``context``
        and ``reply`` are stored as JSON text.
        """
        self._write(
            """INSERT INTO rendezvous
                (id, run_id, kind, prompt, url, context, status, reply, created_at, expires_at)
                VALUES (?,?,?,?,?,?,?,?,?,?)
                ON CONFLICT(run_id) DO UPDATE SET
                id=excluded.id, kind=excluded.kind, prompt=excluded.prompt,
                url=excluded.url, context=excluded.context, status=excluded.status,
                reply=excluded.reply, created_at=excluded.created_at,
                expires_at=excluded.expires_at""",
            (data["id"], data["run_id"], data["kind"], data["prompt"], data["url"],
             json.dumps(data.get("context", {})), data.get("status", "open"),
             json.dumps(data.get("reply")), data["created_at"], data["expires_at"]),
        )

    def load_rendezvous(self, run_id: str) -> Optional[dict[str, Any]]:
        """Return the rendezvous stored for ``run_id`` as a dict, or ``None``."""
        row = self._conn.execute(
            "SELECT id, run_id, kind, prompt, url, context, status, reply, created_at, expires_at "
            "FROM rendezvous WHERE run_id=?", (run_id,)
        ).fetchone()
        if not row:
            return None
        return {
            "id": row[0], "run_id": row[1], "kind": row[2], "prompt": row[3],
            "url": row[4], "context": json.loads(row[5]), "status": row[6],
            "reply": json.loads(row[7]) if row[7] else None,
            "created_at": row[8], "expires_at": row[9],
        }

    def consume_rendezvous(self, run_id: str, reply: dict[str, Any]) -> bool:
        """Atomically consume one open, unexpired rendezvous.

        Returns True when exactly one row was switched to ``answered``.
        """
        cur = self._write(
            "UPDATE rendezvous SET status='answered', reply=? "
            "WHERE run_id=? AND status='open' AND expires_at>?",
            (json.dumps(reply), run_id, time.time()),
        )
        return cur.rowcount == 1

    def action_state(self, action_id: str) -> Optional[str]:
        """Return the recorded state of ``action_id``, or ``None`` if unknown."""
        row = self._conn.execute(
            "SELECT state FROM action_attempts WHERE action_id=?", (action_id,)
        ).fetchone()
        return row[0] if row else None

    def start_action(self, action_id: str, run_id: str, step_id: str) -> None:
        """Record an attempt: insert as ``started`` with one attempt, or bump
        the attempt counter of an existing row (its state is left as is)."""
        self._write(
            """INSERT INTO action_attempts
                (action_id, run_id, step_id, state, attempts, updated_at)
                VALUES (?,?,?,'started',1,?)
                ON CONFLICT(action_id) DO UPDATE SET
                attempts=action_attempts.attempts+1, updated_at=excluded.updated_at""",
            (action_id, run_id, step_id, time.time()),
        )

    def complete_action(self, action_id: str) -> None:
        """Mark ``action_id`` as ``completed``."""
        self._write(
            "UPDATE action_attempts SET state='completed', updated_at=? WHERE action_id=?",
            (time.time(), action_id),
        )

    def reset_action(self, action_id: str) -> None:
        """Forget ``action_id`` entirely so it can be attempted afresh."""
        # A provider 401 proves the protected side effect was not accepted.
        self._write("DELETE FROM action_attempts WHERE action_id=?", (action_id,))

    def ensure_lease(self, lease_id: str, generation: int) -> None:
        """Create a lease once without allowing an old worker to lower it."""
        self._write(
            """INSERT INTO credential_leases (lease_id, generation, updated_at)
                VALUES (?,?,?) ON CONFLICT(lease_id) DO NOTHING""",
            (lease_id, generation, time.time()),
        )

    def assert_lease_generation(self, lease_id: str, generation: int) -> None:
        """Raise ``ValueError`` unless the lease exists at exactly ``generation``."""
        row = self._conn.execute(
            "SELECT generation FROM credential_leases WHERE lease_id=?", (lease_id,)
        ).fetchone()
        if row is None or int(row[0]) != generation:
            current = int(row[0]) if row else None
            raise ValueError(
                f"stale credential generation for {lease_id}: got {generation}, current {current}"
            )

    def rotate_lease(self, lease_id: str, expected_generation: int) -> int:
        """Atomically fence the prior generation and return the new one.

        Raises:
            ValueError: when the lease is missing or no longer at
                ``expected_generation``.
        """
        next_generation = expected_generation + 1
        cur = self._write(
            """UPDATE credential_leases
                SET generation=?, updated_at=?
                WHERE lease_id=? AND generation=?""",
            (next_generation, time.time(), lease_id, expected_generation),
        )
        if cur.rowcount != 1:
            # Prefer the detailed "stale generation" error when it applies.
            self.assert_lease_generation(lease_id, expected_generation)
            raise ValueError(f"could not rotate credential lease {lease_id}")
        return next_generation

    def close(self) -> None:
        """Close the underlying SQLite connection."""
        self._conn.close()
|