switchplane 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- switchplane/__init__.py +12 -0
- switchplane/__main__.py +4 -0
- switchplane/_util.py +36 -0
- switchplane/agent.py +46 -0
- switchplane/agent_runtime.py +555 -0
- switchplane/app.py +157 -0
- switchplane/checkpoint.py +365 -0
- switchplane/cli.py +596 -0
- switchplane/config.py +83 -0
- switchplane/control_plane.py +643 -0
- switchplane/daemon.py +350 -0
- switchplane/discovery.py +155 -0
- switchplane/fmt.py +132 -0
- switchplane/llm.py +96 -0
- switchplane/logging.py +103 -0
- switchplane/mcp.py +305 -0
- switchplane/oauth.py +465 -0
- switchplane/persistence.py +498 -0
- switchplane/protocol.py +73 -0
- switchplane/shell.py +386 -0
- switchplane/subprocess_manager.py +425 -0
- switchplane/task.py +204 -0
- switchplane/transport.py +234 -0
- switchplane/tui.py +1380 -0
- switchplane-0.1.0.dist-info/METADATA +802 -0
- switchplane-0.1.0.dist-info/RECORD +28 -0
- switchplane-0.1.0.dist-info/WHEEL +4 -0
- switchplane-0.1.0.dist-info/licenses/LICENSE +191 -0
switchplane/app.py
ADDED
|
@@ -0,0 +1,157 @@
|
|
|
1
|
+
"""Application and MCP server configuration."""
|
|
2
|
+
|
|
3
|
+
import re
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from typing import TYPE_CHECKING
|
|
6
|
+
|
|
7
|
+
from pydantic import BaseModel, ConfigDict
|
|
8
|
+
|
|
9
|
+
if TYPE_CHECKING:
|
|
10
|
+
from switchplane.agent import AgentSpec
|
|
11
|
+
from switchplane.config import AppConfig
|
|
12
|
+
|
|
13
|
+
# Accepted application names: one ASCII letter followed by any number of
# ASCII letters, digits, underscores, or hyphens.
# NOTE: `$` also matches just before a trailing newline, so a strict
# whole-string check needs `fullmatch()` rather than `match()`.
_VALID_APP_NAME = re.compile(r"^[a-zA-Z][a-zA-Z0-9_-]*$")
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class OAuthConfig(BaseModel):
    """OAuth 2.0 configuration for an MCP server.

    Two modes are supported:

    **MCP-spec OAuth** (default): Leave ``auth_url`` and ``token_url``
    unset. The MCP SDK's ``OAuthClientProvider`` discovers endpoints
    from the MCP server's protected-resource metadata automatically.

    **Direct OIDC**: Set ``auth_url`` and ``token_url`` to point at an
    external identity provider (e.g. Keycloak/QuantumK). Switchplane
    runs the PKCE authorization-code flow against those endpoints
    directly, bypassing MCP metadata discovery.

    ``auth_url`` and ``token_url`` must be provided together; an empty
    string counts as "not provided" for both validation and mode detection.
    """

    # String fields are stripped of surrounding whitespace during validation,
    # so a whitespace-only URL ends up as an empty string.
    model_config = ConfigDict(str_strip_whitespace=True)

    client_id: str
    client_secret: str | None = None
    callback_port: int = 3118
    scopes: str | None = None
    auth_url: str | None = None
    token_url: str | None = None

    @property
    def is_direct(self) -> bool:
        """True when explicit, non-empty OIDC endpoints are configured."""
        # Truthiness (not ``is not None``) keeps this in agreement with
        # model_post_init: two empty strings pass validation there and must
        # not be reported as usable direct-OIDC endpoints here.
        return bool(self.auth_url) and bool(self.token_url)

    def model_post_init(self, __context) -> None:
        """Reject configurations that set only one of the two endpoints.

        Raises:
            ValueError: If exactly one of ``auth_url`` / ``token_url`` is set
                to a non-empty value.
        """
        if bool(self.auth_url) != bool(self.token_url):
            raise ValueError("Provide both 'auth_url' and 'token_url', or neither")
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
class McpServerConfig(BaseModel):
    """Settings describing how to reach a single MCP server.

    Exactly one of the two connection styles must be chosen:

    * `command` - stdio transport; Switchplane manages the process.
    * `url` - HTTP transport; Switchplane connects to an existing server.

    `oauth` is only accepted together with `url` and cannot be combined
    with `http_transport`.
    """

    model_config = ConfigDict(
        str_strip_whitespace=True,
        validate_assignment=True,
    )

    name: str
    command: list[str] | None = None
    url: str | None = None
    env: dict[str, str] = {}
    http_transport: str | None = None
    oauth: OAuthConfig | None = None
    oauth_group: str | None = None
    timeout: float | None = 30.0

    @property
    def oauth_storage_key(self) -> str:
        """Token-storage key: the ``oauth_group`` when set, else the server name.

        Servers that share an ``oauth_group`` therefore share tokens.
        """
        if self.oauth_group:
            return self.oauth_group
        return self.name

    @property
    def transport(self) -> str:
        """Return ``"stdio"`` when a command is configured, ``"http"`` otherwise."""
        if self.command:
            return "stdio"
        return "http"

    def model_post_init(self, __context) -> None:
        """Check that the transport and OAuth options are mutually consistent.

        Raises:
            ValueError: If both or neither of ``command`` / ``url`` are given,
                if ``oauth`` is combined with ``http_transport``, or if
                ``oauth`` is used without ``url``.
        """
        has_command = bool(self.command)
        has_url = bool(self.url)

        if has_command and has_url:
            raise ValueError("Provide either 'command' (stdio) or 'url' (http), not both")
        if not (has_command or has_url):
            raise ValueError("Provide either 'command' (stdio) or 'url' (http)")

        if self.oauth:
            if self.http_transport:
                raise ValueError("Provide either 'oauth' or 'http_transport', not both")
            if not has_url:
                raise ValueError("OAuth requires HTTP transport ('url' must be set)")
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
class Application:
    """Application container for agents and MCP servers."""

    def __init__(
        self,
        name: str,
        runtime_dir: str | Path | None = None,
        default_config: Path | None = None,
        config_class: "type[AppConfig] | None" = None,
    ) -> None:
        """Initialize application.

        Args:
            name: Application name. Must start with a letter and contain only
                letters, digits, hyphens, or underscores.
            runtime_dir: Optional runtime directory path (defaults to ~/.{name})
            default_config: Optional path to default config file bundled with the app
            config_class: Pydantic model class used to parse the merged config.
                Defaults to AppConfig; pass a subclass to support app-specific
                config sections that are otherwise dropped by the base model.

        Raises:
            ValueError: If ``name`` is not a valid application name.
        """
        # Imported here rather than at module level; the module-level import
        # of AppConfig is only active under TYPE_CHECKING.
        from switchplane.config import AppConfig as _AppConfig

        # fullmatch() rather than match(): the pattern ends in `$`, which also
        # matches just before a trailing newline, so match() would accept a
        # name such as "app\n" and let it flow into the runtime directory path.
        if not _VALID_APP_NAME.fullmatch(name):
            raise ValueError(
                f"Invalid application name {name!r}. "
                "Names must start with a letter and contain only letters, digits, hyphens, or underscores."
            )
        self.name = name
        self.runtime_dir = Path(runtime_dir).expanduser() if runtime_dir else Path.home() / f".{name}"
        self.default_config_path = default_config
        self.config_class: type[AppConfig] = config_class or _AppConfig
        self.agents: dict[str, AgentSpec] = {}
        self.mcp_servers: dict[str, McpServerConfig] = {}
        self._discovery_roots: list[str] = []

    def discover_agents(self, root: str) -> None:
        """Store root directory for later agent discovery.

        The root is only recorded here; nothing is scanned by this call.

        Args:
            root: Root directory path to discover agents from
        """
        self._discovery_roots.append(root)

    def register_agent(self, spec: "AgentSpec") -> None:
        """Register an agent specification.

        A spec registered under an already-used ``agent_name`` replaces the
        earlier one.

        Args:
            spec: Agent specification to register
        """
        self.agents[spec.agent_name] = spec

    def register_mcp_server(self, config: McpServerConfig) -> None:
        """Register an MCP server configuration.

        A config registered under an already-used ``name`` replaces the
        earlier one.

        Args:
            config: MCP server configuration to register
        """
        self.mcp_servers[config.name] = config

    def run(self) -> None:
        """Discover agents and start the CLI."""
        from switchplane.cli import build_cli
        from switchplane.discovery import discover_agents_for_app

        discover_agents_for_app(self)
        cli = build_cli(self)
        cli()
|
|
@@ -0,0 +1,365 @@
|
|
|
1
|
+
"""Custom LangGraph checkpoint saver backed by SQLite."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from collections.abc import AsyncIterator
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
import aiosqlite
|
|
9
|
+
from langchain_core.runnables import RunnableConfig
|
|
10
|
+
from langgraph.checkpoint.base import (
|
|
11
|
+
BaseCheckpointSaver,
|
|
12
|
+
ChannelVersions,
|
|
13
|
+
Checkpoint,
|
|
14
|
+
CheckpointMetadata,
|
|
15
|
+
CheckpointTuple,
|
|
16
|
+
JsonPlusSerializer,
|
|
17
|
+
)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
async def setup_tables(db: aiosqlite.Connection) -> None:
    """Ensure the ``checkpoints`` and ``checkpoint_writes`` tables exist.

    Both statements use ``CREATE TABLE IF NOT EXISTS``, so calling this on an
    already-initialised database leaves existing tables untouched. The
    connection is committed once after both statements have run.
    """
    checkpoints_ddl = """
        CREATE TABLE IF NOT EXISTS checkpoints (
            thread_id TEXT NOT NULL,
            checkpoint_ns TEXT NOT NULL DEFAULT '',
            checkpoint_id TEXT NOT NULL,
            parent_checkpoint_id TEXT,
            type TEXT,
            checkpoint BLOB,
            metadata BLOB,
            metadata_type TEXT,
            PRIMARY KEY (thread_id, checkpoint_ns, checkpoint_id)
        )
    """
    writes_ddl = """
        CREATE TABLE IF NOT EXISTS checkpoint_writes (
            thread_id TEXT NOT NULL,
            checkpoint_ns TEXT NOT NULL DEFAULT '',
            checkpoint_id TEXT NOT NULL,
            task_id TEXT NOT NULL,
            idx INTEGER NOT NULL,
            channel TEXT NOT NULL,
            type TEXT,
            blob BLOB,
            PRIMARY KEY (thread_id, checkpoint_ns, checkpoint_id, task_id, idx)
        )
    """

    # Checkpoints table first, then the writes table, then a single commit.
    for ddl in (checkpoints_ddl, writes_ddl):
        await db.execute(ddl)
    await db.commit()
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
class SqliteCheckpointSaver(BaseCheckpointSaver):
    """LangGraph checkpoint saver backed by SQLite.

    Checkpoints are stored in the ``checkpoints`` table and pending writes in
    ``checkpoint_writes``; :meth:`setup` creates both. Only the async API
    (``aget_tuple``, ``alist``, ``aput``, ``aput_writes``) is implemented; the
    sync counterparts raise ``NotImplementedError``.

    NOTE: this class defines a method named ``list``. Annotations written
    after it in the class body (e.g. ``list[tuple[str, Any]]``) are only safe
    because the module uses ``from __future__ import annotations``.
    """

    serde = JsonPlusSerializer()

    def __init__(self, db: aiosqlite.Connection) -> None:
        """Initialize with an existing SQLite connection."""
        super().__init__()
        self.db = db

    async def setup(self) -> None:
        """Create checkpoint tables if they don't exist."""
        await setup_tables(self.db)

    @staticmethod
    def _config_for(thread_id: str, checkpoint_ns: str, checkpoint_id: str) -> RunnableConfig:
        """Build the ``configurable`` config that addresses one checkpoint."""
        return {
            "configurable": {
                "thread_id": thread_id,
                "checkpoint_ns": checkpoint_ns,
                "checkpoint_id": checkpoint_id,
            }
        }

    def _decode_row(self, row: Any) -> tuple[str, str | None, Any, Any]:
        """Deserialize one ``checkpoints`` row.

        Returns:
            ``(checkpoint_id, parent_checkpoint_id, checkpoint, metadata)``.
            A missing checkpoint or metadata blob decodes to ``{}``.
        """
        checkpoint_id, parent_id, checkpoint_type, checkpoint_blob, metadata_blob, metadata_type = row
        checkpoint = self.serde.loads_typed((checkpoint_type, checkpoint_blob)) if checkpoint_blob else {}
        # Fall back to the checkpoint's type tag when the row has no
        # separate metadata_type.
        metadata = (
            self.serde.loads_typed((metadata_type or checkpoint_type, metadata_blob)) if metadata_blob else {}
        )
        return checkpoint_id, parent_id, checkpoint, metadata

    async def _load_pending_writes(
        self, thread_id: str, checkpoint_ns: str, checkpoint_id: str
    ) -> list[tuple[str, str, Any]]:
        """Fetch and deserialize the pending writes stored for one checkpoint.

        Returns:
            ``(task_id, channel, value)`` tuples ordered by task and write
            index. A write with no blob decodes to ``None``.
        """
        cursor = await self.db.execute(
            """
            SELECT task_id, channel, type, blob
            FROM checkpoint_writes
            WHERE thread_id = ? AND checkpoint_ns = ? AND checkpoint_id = ?
            ORDER BY task_id, idx
            """,
            (thread_id, checkpoint_ns, checkpoint_id),
        )
        return [
            (task_id, channel, self.serde.loads_typed((write_type, blob)) if blob else None)
            for task_id, channel, write_type, blob in await cursor.fetchall()
        ]

    def _make_tuple(
        self,
        thread_id: str,
        checkpoint_ns: str,
        checkpoint_id: str,
        parent_id: str | None,
        checkpoint: Any,
        metadata: Any,
        pending_writes: list[tuple[str, str, Any]],
    ) -> CheckpointTuple:
        """Assemble a ``CheckpointTuple``; ``parent_config`` is ``None`` without a parent."""
        parent_config = self._config_for(thread_id, checkpoint_ns, parent_id) if parent_id else None
        return CheckpointTuple(
            config=self._config_for(thread_id, checkpoint_ns, checkpoint_id),
            checkpoint=checkpoint,
            metadata=metadata,
            parent_config=parent_config,
            pending_writes=pending_writes,
        )

    async def aget_tuple(self, config: RunnableConfig) -> CheckpointTuple | None:
        """Get a checkpoint tuple for the given config.

        When the config carries a ``checkpoint_id`` that exact checkpoint is
        loaded; otherwise the row with the greatest ``checkpoint_id`` for the
        thread/namespace is used.

        Returns:
            The matching checkpoint with its pending writes, or ``None`` when
            no row matches.
        """
        thread_id = config["configurable"]["thread_id"]
        checkpoint_ns = config["configurable"].get("checkpoint_ns", "")
        requested_id = config["configurable"].get("checkpoint_id")

        if requested_id:
            cursor = await self.db.execute(
                """
                SELECT checkpoint_id, parent_checkpoint_id, type, checkpoint, metadata, metadata_type
                FROM checkpoints
                WHERE thread_id = ? AND checkpoint_ns = ? AND checkpoint_id = ?
                """,
                (thread_id, checkpoint_ns, requested_id),
            )
        else:
            # "Latest" means greatest checkpoint_id.
            # NOTE(review): assumes IDs sort chronologically - confirm.
            cursor = await self.db.execute(
                """
                SELECT checkpoint_id, parent_checkpoint_id, type, checkpoint, metadata, metadata_type
                FROM checkpoints
                WHERE thread_id = ? AND checkpoint_ns = ?
                ORDER BY checkpoint_id DESC
                LIMIT 1
                """,
                (thread_id, checkpoint_ns),
            )

        row = await cursor.fetchone()
        if not row:
            return None

        checkpoint_id, parent_id, checkpoint, metadata = self._decode_row(row)
        pending_writes = await self._load_pending_writes(thread_id, checkpoint_ns, checkpoint_id)
        return self._make_tuple(
            thread_id, checkpoint_ns, checkpoint_id, parent_id, checkpoint, metadata, pending_writes
        )

    async def alist(
        self,
        config: RunnableConfig | None,
        *,
        filter: dict[str, Any] | None = None,
        before: RunnableConfig | None = None,
        limit: int | None = 10,
    ) -> AsyncIterator[CheckpointTuple]:
        """List checkpoints for a thread, in descending ``checkpoint_id`` order.

        Args:
            config: Identifies the thread (and optional namespace). When
                ``None`` nothing is yielded.
            filter: Metadata key/value pairs that must all match exactly.
            before: When it carries a ``checkpoint_id``, only checkpoints
                with a smaller ID are listed.
            limit: Maximum number of checkpoints to yield. ``None`` or a
                negative value means no limit.

        Yields:
            Matching checkpoints together with their pending writes.
        """
        if config is None:
            return
        thread_id = config["configurable"]["thread_id"]
        checkpoint_ns = config["configurable"].get("checkpoint_ns", "")

        query = """
            SELECT checkpoint_id, parent_checkpoint_id, type, checkpoint, metadata, metadata_type
            FROM checkpoints
            WHERE thread_id = ? AND checkpoint_ns = ?
        """
        params: list[Any] = [thread_id, checkpoint_ns]

        before_id = before["configurable"].get("checkpoint_id") if before else None
        if before_id:
            query += " AND checkpoint_id < ?"
            params.append(before_id)

        query += " ORDER BY checkpoint_id DESC"

        # SQLite rejects `LIMIT NULL`, and treats a negative LIMIT as "no
        # limit", so only a non-negative integer counts as a real cap.
        capped = limit is not None and limit >= 0
        # The metadata filter runs in Python after deserialization, so the cap
        # can only be pushed into SQL when there is no filter. Otherwise
        # LIMIT would truncate rows before filtering and matches beyond the
        # first `limit` rows would be silently lost.
        if capped and not filter:
            query += " LIMIT ?"
            params.append(limit)

        cursor = await self.db.execute(query, params)
        rows = await cursor.fetchall()

        emitted = 0
        for row in rows:
            if capped and emitted >= limit:
                break

            checkpoint_id, parent_id, checkpoint, metadata = self._decode_row(row)

            if filter and not all(metadata.get(k) == v for k, v in filter.items()):
                continue

            # Writes are only loaded for checkpoints that passed the filter.
            pending_writes = await self._load_pending_writes(thread_id, checkpoint_ns, checkpoint_id)

            yield self._make_tuple(
                thread_id, checkpoint_ns, checkpoint_id, parent_id, checkpoint, metadata, pending_writes
            )
            emitted += 1

    async def aput(
        self,
        config: RunnableConfig,
        checkpoint: Checkpoint,
        metadata: CheckpointMetadata,
        new_versions: ChannelVersions,
    ) -> RunnableConfig:
        """Save a checkpoint.

        The ``checkpoint_id`` found in ``config`` (if any) is recorded as the
        parent; the new row's ID is taken from ``checkpoint["id"]``. An
        existing row with the same key is replaced. ``new_versions`` is
        accepted for interface compatibility and not used.

        Returns:
            A config addressing the checkpoint that was just stored.
        """
        thread_id = config["configurable"]["thread_id"]
        checkpoint_ns = config["configurable"].get("checkpoint_ns", "")
        parent_checkpoint_id = config["configurable"].get("checkpoint_id")

        # The ID is supplied by the checkpoint itself, not generated here.
        checkpoint_id = checkpoint["id"]

        checkpoint_type, checkpoint_blob = self.serde.dumps_typed(checkpoint)
        metadata_type, metadata_blob = self.serde.dumps_typed(metadata)

        await self.db.execute(
            """
            INSERT OR REPLACE INTO checkpoints
            (thread_id, checkpoint_ns, checkpoint_id, parent_checkpoint_id, type, checkpoint, metadata, metadata_type)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?)
            """,
            (
                thread_id,
                checkpoint_ns,
                checkpoint_id,
                parent_checkpoint_id,
                checkpoint_type,
                checkpoint_blob,
                metadata_blob,
                metadata_type,
            ),
        )
        await self.db.commit()

        return self._config_for(thread_id, checkpoint_ns, checkpoint_id)

    async def aput_writes(
        self,
        config: RunnableConfig,
        writes: list[tuple[str, Any]],
        task_id: str,
        task_path: str = "",
    ) -> None:
        """Save pending writes for a checkpoint.

        Each ``(channel, value)`` pair is stored under its position in
        ``writes``; an existing row with the same key is replaced.
        ``task_path`` is accepted for interface compatibility and not stored.

        Raises:
            KeyError: If ``config`` has no ``checkpoint_id``.
        """
        thread_id = config["configurable"]["thread_id"]
        checkpoint_ns = config["configurable"].get("checkpoint_ns", "")
        checkpoint_id = config["configurable"]["checkpoint_id"]

        rows = []
        for idx, (channel, value) in enumerate(writes):
            write_type, write_blob = self.serde.dumps_typed(value)
            rows.append(
                (
                    thread_id,
                    checkpoint_ns,
                    checkpoint_id,
                    task_id,
                    idx,
                    channel,
                    write_type,
                    write_blob,
                )
            )

        await self.db.executemany(
            """
            INSERT OR REPLACE INTO checkpoint_writes
            (thread_id, checkpoint_ns, checkpoint_id, task_id, idx, channel, type, blob)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?)
            """,
            rows,
        )
        await self.db.commit()

    # Sync methods - not implemented

    def get_tuple(self, config: RunnableConfig) -> CheckpointTuple | None:
        """Sync version not implemented - use async version."""
        raise NotImplementedError("Use async version: aget_tuple")

    def list(
        self,
        config: RunnableConfig | None,
        *,
        filter: dict[str, Any] | None = None,
        before: RunnableConfig | None = None,
        limit: int = 10,
    ):
        """Sync version not implemented - use async version."""
        raise NotImplementedError("Use async version: alist")

    def put(
        self,
        config: RunnableConfig,
        checkpoint: Checkpoint,
        metadata: CheckpointMetadata,
        new_versions: ChannelVersions,
    ) -> RunnableConfig:
        """Sync version not implemented - use async version."""
        raise NotImplementedError("Use async version: aput")

    def put_writes(
        self,
        config: RunnableConfig,
        writes: list[tuple[str, Any]],
        task_id: str,
        task_path: str = "",
    ) -> None:
        """Sync version not implemented - use async version."""
        raise NotImplementedError("Use async version: aput_writes")
|