switchplane 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
switchplane/app.py ADDED
@@ -0,0 +1,157 @@
1
+ """Application and MCP server configuration."""
2
+
3
import re
from pathlib import Path
from typing import TYPE_CHECKING

from pydantic import BaseModel, ConfigDict, model_validator

if TYPE_CHECKING:
    from switchplane.agent import AgentSpec
    from switchplane.config import AppConfig
12
+
13
+ _VALID_APP_NAME = re.compile(r"^[a-zA-Z][a-zA-Z0-9_-]*$")
14
+
15
+
16
class OAuthConfig(BaseModel):
    """OAuth 2.0 configuration for an MCP server.

    Two modes are supported:

    **MCP-spec OAuth** (default): Leave ``auth_url`` and ``token_url``
    unset. The MCP SDK's ``OAuthClientProvider`` discovers endpoints
    from the MCP server's protected-resource metadata automatically.

    **Direct OIDC**: Set ``auth_url`` and ``token_url`` to point at an
    external identity provider (e.g. Keycloak/QuantumK). Switchplane
    runs the PKCE authorization-code flow against those endpoints
    directly, bypassing MCP metadata discovery.

    String values are whitespace-stripped; an empty endpoint string is
    treated exactly like an unset one.
    """

    model_config = ConfigDict(str_strip_whitespace=True)

    client_id: str
    client_secret: str | None = None
    callback_port: int = 3118
    scopes: str | None = None
    auth_url: str | None = None
    token_url: str | None = None

    @property
    def is_direct(self) -> bool:
        """True when both explicit OIDC endpoints are configured and non-empty."""
        # Truthiness (not ``is not None``) so this agrees with the pairing
        # check in model_post_init: "" counts as "unset" there, so it must
        # not switch the provider into direct mode here.
        return bool(self.auth_url) and bool(self.token_url)

    def model_post_init(self, __context) -> None:
        """Reject configs that set only one of ``auth_url`` / ``token_url``.

        Raises:
            ValueError: If exactly one of the two endpoints is non-empty.
        """
        if bool(self.auth_url) != bool(self.token_url):
            raise ValueError("Provide both 'auth_url' and 'token_url', or neither")
48
+
49
+
50
class McpServerConfig(BaseModel):
    """Configuration for an MCP server.

    Provide `command` for stdio transport (Switchplane manages the process)
    or `url` for HTTP transport (Switchplane connects to an existing server).

    The cross-field rules below live in an ``after`` model validator rather
    than ``model_post_init``: with ``validate_assignment=True`` pydantic
    re-runs model validators on attribute assignment, whereas
    ``model_post_init`` only runs at construction time.
    """

    model_config = ConfigDict(
        str_strip_whitespace=True,
        validate_assignment=True,
    )

    name: str
    command: list[str] | None = None
    url: str | None = None
    # pydantic copies mutable defaults per instance, so this literal is not shared.
    env: dict[str, str] = {}
    http_transport: str | None = None
    oauth: OAuthConfig | None = None
    oauth_group: str | None = None
    timeout: float | None = 30.0

    @property
    def oauth_storage_key(self) -> str:
        """Key for OAuth token storage. Servers sharing an oauth_group share tokens."""
        return self.oauth_group or self.name

    @property
    def transport(self) -> str:
        """``"stdio"`` when a command is configured, otherwise ``"http"``."""
        return "stdio" if self.command else "http"

    @model_validator(mode="after")
    def check_transport(self) -> "McpServerConfig":
        """Enforce the transport/auth invariants.

        Exactly one of ``command`` / ``url`` must be set; ``oauth`` and
        ``http_transport`` are mutually exclusive; ``oauth`` requires ``url``.

        Raises:
            ValueError: If any invariant is violated (pydantic surfaces it as
                a ``ValidationError``, which subclasses ``ValueError``).
        """
        if self.command and self.url:
            raise ValueError("Provide either 'command' (stdio) or 'url' (http), not both")
        if not self.command and not self.url:
            raise ValueError("Provide either 'command' (stdio) or 'url' (http)")
        if self.oauth and self.http_transport:
            raise ValueError("Provide either 'oauth' or 'http_transport', not both")
        if self.oauth and not self.url:
            raise ValueError("OAuth requires HTTP transport ('url' must be set)")
        return self
89
+
90
+
91
class Application:
    """Application container for agents and MCP servers.

    Collects agent specs, MCP server configs and agent-discovery roots;
    :meth:`run` performs discovery and then builds and invokes the CLI.
    """

    def __init__(
        self,
        name: str,
        runtime_dir: str | Path | None = None,
        default_config: str | Path | None = None,
        config_class: "type[AppConfig] | None" = None,
    ) -> None:
        """Initialize application.

        Args:
            name: Application name. Must start with a letter and contain only
                letters, digits, hyphens, or underscores.
            runtime_dir: Optional runtime directory path (defaults to ~/.{name});
                ``~`` is expanded.
            default_config: Optional path to default config file bundled with
                the app. A plain string is converted to ``Path``; any other
                value is stored unchanged.
            config_class: Pydantic model class used to parse the merged config.
                Defaults to AppConfig; pass a subclass to support app-specific
                config sections that are otherwise dropped by the base model.

        Raises:
            ValueError: If ``name`` is not a valid application name.
        """
        from switchplane.config import AppConfig as _AppConfig

        # fullmatch, not match: a "$"-anchored pattern used with match() also
        # accepts a trailing newline ("app\n"), which would then end up in the
        # runtime directory name.
        if not _VALID_APP_NAME.fullmatch(name):
            raise ValueError(
                f"Invalid application name {name!r}. "
                "Names must start with a letter and contain only letters, digits, hyphens, or underscores."
            )
        self.name = name
        # A falsy runtime_dir (None or "") falls back to ~/.{name}.
        self.runtime_dir = Path(runtime_dir).expanduser() if runtime_dir else Path.home() / f".{name}"
        self.default_config_path = Path(default_config) if isinstance(default_config, str) else default_config
        self.config_class: type[AppConfig] = config_class or _AppConfig
        self.agents: dict[str, AgentSpec] = {}
        self.mcp_servers: dict[str, McpServerConfig] = {}
        self._discovery_roots: list[str] = []

    def discover_agents(self, root: str) -> None:
        """Store root directory for later agent discovery.

        Discovery itself happens in :meth:`run`.

        Args:
            root: Root directory path to discover agents from
        """
        self._discovery_roots.append(root)

    def register_agent(self, spec: "AgentSpec") -> None:
        """Register an agent specification.

        A later registration with the same ``agent_name`` replaces the earlier one.

        Args:
            spec: Agent specification to register
        """
        self.agents[spec.agent_name] = spec

    def register_mcp_server(self, config: McpServerConfig) -> None:
        """Register an MCP server configuration.

        A later registration with the same ``name`` replaces the earlier one.

        Args:
            config: MCP server configuration to register
        """
        self.mcp_servers[config.name] = config

    def run(self) -> None:
        """Discover agents and start the CLI."""
        # Imported lazily, presumably to avoid import cycles with modules
        # that import this one.
        from switchplane.cli import build_cli
        from switchplane.discovery import discover_agents_for_app

        discover_agents_for_app(self)
        cli = build_cli(self)
        cli()
@@ -0,0 +1,365 @@
1
+ """Custom LangGraph checkpoint saver backed by SQLite."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from collections.abc import AsyncIterator
6
+ from typing import Any
7
+
8
+ import aiosqlite
9
+ from langchain_core.runnables import RunnableConfig
10
+ from langgraph.checkpoint.base import (
11
+ BaseCheckpointSaver,
12
+ ChannelVersions,
13
+ Checkpoint,
14
+ CheckpointMetadata,
15
+ CheckpointTuple,
16
+ JsonPlusSerializer,
17
+ )
18
+
19
+
20
async def setup_tables(db: aiosqlite.Connection) -> None:
    """Ensure the checkpoint schema exists in ``db``, then commit.

    Creates two tables:

    * ``checkpoints``: one row per (thread_id, checkpoint_ns, checkpoint_id)
      holding the serialized checkpoint and metadata plus their serializer
      type tags and the parent checkpoint id.
    * ``checkpoint_writes``: pending writes, keyed additionally by task_id
      and idx.

    Both statements use ``IF NOT EXISTS``, so repeated calls are harmless.
    """
    schema = (
        """
        CREATE TABLE IF NOT EXISTS checkpoints (
            thread_id TEXT NOT NULL,
            checkpoint_ns TEXT NOT NULL DEFAULT '',
            checkpoint_id TEXT NOT NULL,
            parent_checkpoint_id TEXT,
            type TEXT,
            checkpoint BLOB,
            metadata BLOB,
            metadata_type TEXT,
            PRIMARY KEY (thread_id, checkpoint_ns, checkpoint_id)
        )
        """,
        """
        CREATE TABLE IF NOT EXISTS checkpoint_writes (
            thread_id TEXT NOT NULL,
            checkpoint_ns TEXT NOT NULL DEFAULT '',
            checkpoint_id TEXT NOT NULL,
            task_id TEXT NOT NULL,
            idx INTEGER NOT NULL,
            channel TEXT NOT NULL,
            type TEXT,
            blob BLOB,
            PRIMARY KEY (thread_id, checkpoint_ns, checkpoint_id, task_id, idx)
        )
        """,
    )
    for ddl in schema:
        await db.execute(ddl)
    await db.commit()
51
+
52
+
53
class SqliteCheckpointSaver(BaseCheckpointSaver):
    """LangGraph checkpoint saver backed by SQLite.

    Only the async API is implemented; the synchronous methods raise
    ``NotImplementedError``. The ``aiosqlite`` connection is supplied by the
    caller and is never closed here. Every write commits before returning.
    """

    serde = JsonPlusSerializer()

    # Shared prefix of every checkpoint lookup; callers append further
    # predicates, ordering and limits.
    _SELECT_CHECKPOINTS = (
        "SELECT checkpoint_id, parent_checkpoint_id, type, checkpoint, metadata, metadata_type "
        "FROM checkpoints "
        "WHERE thread_id = ? AND checkpoint_ns = ?"
    )

    def __init__(self, db: aiosqlite.Connection):
        """Initialize with an existing SQLite connection."""
        super().__init__()
        self.db = db

    async def setup(self) -> None:
        """Create checkpoint tables if they don't exist."""
        await setup_tables(self.db)

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _config_for(thread_id: str, checkpoint_ns: str, checkpoint_id: str) -> RunnableConfig:
        """Build the ``configurable`` config that addresses one checkpoint."""
        return {
            "configurable": {
                "thread_id": thread_id,
                "checkpoint_ns": checkpoint_ns,
                "checkpoint_id": checkpoint_id,
            }
        }

    def _decode_row(self, row: Any) -> tuple[str, str | None, Checkpoint, CheckpointMetadata]:
        """Deserialize a ``_SELECT_CHECKPOINTS`` row.

        Returns ``(checkpoint_id, parent_checkpoint_id, checkpoint, metadata)``;
        missing blobs decode to ``{}``.
        """
        checkpoint_id, parent_checkpoint_id, checkpoint_type, checkpoint_blob, metadata_blob, metadata_type = row
        checkpoint = self.serde.loads_typed((checkpoint_type, checkpoint_blob)) if checkpoint_blob else {}
        # Falls back to the checkpoint's type tag when metadata_type is NULL
        # (presumably rows written before metadata_type was recorded).
        metadata = self.serde.loads_typed((metadata_type or checkpoint_type, metadata_blob)) if metadata_blob else {}
        return checkpoint_id, parent_checkpoint_id, checkpoint, metadata

    async def _load_pending_writes(
        self, thread_id: str, checkpoint_ns: str, checkpoint_id: str
    ) -> list[tuple[str, str, Any]]:
        """Return ``(task_id, channel, value)`` for every stored write of a checkpoint."""
        cursor = await self.db.execute(
            """
            SELECT task_id, channel, type, blob
            FROM checkpoint_writes
            WHERE thread_id = ? AND checkpoint_ns = ? AND checkpoint_id = ?
            ORDER BY task_id, idx
            """,
            (thread_id, checkpoint_ns, checkpoint_id),
        )
        rows = await cursor.fetchall()
        return [
            (task_id, channel, self.serde.loads_typed((write_type, write_blob)) if write_blob else None)
            for task_id, channel, write_type, write_blob in rows
        ]

    async def _build_tuple(
        self,
        thread_id: str,
        checkpoint_ns: str,
        checkpoint_id: str,
        parent_checkpoint_id: str | None,
        checkpoint: Checkpoint,
        metadata: CheckpointMetadata,
    ) -> CheckpointTuple:
        """Assemble a ``CheckpointTuple``, loading its pending writes."""
        parent_config = (
            self._config_for(thread_id, checkpoint_ns, parent_checkpoint_id) if parent_checkpoint_id else None
        )
        return CheckpointTuple(
            config=self._config_for(thread_id, checkpoint_ns, checkpoint_id),
            checkpoint=checkpoint,
            metadata=metadata,
            parent_config=parent_config,
            pending_writes=await self._load_pending_writes(thread_id, checkpoint_ns, checkpoint_id),
        )

    # ------------------------------------------------------------------
    # Async API
    # ------------------------------------------------------------------

    async def aget_tuple(self, config: RunnableConfig) -> CheckpointTuple | None:
        """Get a checkpoint tuple for the given config.

        Loads the checkpoint named by ``configurable.checkpoint_id`` when
        present; otherwise the thread/namespace's checkpoint with the highest
        ``checkpoint_id``. Returns ``None`` when no row matches.
        """
        configurable = config["configurable"]
        thread_id = configurable["thread_id"]
        checkpoint_ns = configurable.get("checkpoint_ns", "")
        requested_id = configurable.get("checkpoint_id")

        if requested_id:
            cursor = await self.db.execute(
                self._SELECT_CHECKPOINTS + " AND checkpoint_id = ?",
                (thread_id, checkpoint_ns, requested_id),
            )
        else:
            # "Latest" = greatest checkpoint_id as text; assumes IDs sort
            # chronologically (LangGraph's are time-ordered UUIDs — confirm).
            cursor = await self.db.execute(
                self._SELECT_CHECKPOINTS + " ORDER BY checkpoint_id DESC LIMIT 1",
                (thread_id, checkpoint_ns),
            )

        row = await cursor.fetchone()
        if row is None:
            return None

        checkpoint_id, parent_checkpoint_id, checkpoint, metadata = self._decode_row(row)
        return await self._build_tuple(
            thread_id, checkpoint_ns, checkpoint_id, parent_checkpoint_id, checkpoint, metadata
        )

    async def alist(
        self,
        config: RunnableConfig | None,
        *,
        filter: dict[str, Any] | None = None,
        before: RunnableConfig | None = None,
        limit: int | None = 10,
    ) -> AsyncIterator[CheckpointTuple]:
        """List checkpoints for a thread, newest first.

        Args:
            config: Selects the thread (and optional ``checkpoint_ns``). When
                ``None`` nothing is yielded.
            filter: Metadata key/value pairs that must all match.
            before: If it carries a ``checkpoint_id``, only checkpoints with a
                smaller id are listed.
            limit: Maximum number of checkpoints yielded *after* filtering.
                ``None`` or a negative value means no limit.
        """
        if config is None:
            return
        thread_id = config["configurable"]["thread_id"]
        checkpoint_ns = config["configurable"].get("checkpoint_ns", "")

        # Negative limits are unbounded, matching SQLite's LIMIT semantics.
        if limit is not None and limit < 0:
            limit = None

        query = self._SELECT_CHECKPOINTS
        params: list[Any] = [thread_id, checkpoint_ns]
        if before:
            before_checkpoint_id = before["configurable"].get("checkpoint_id")
            if before_checkpoint_id:
                query += " AND checkpoint_id < ?"
                params.append(before_checkpoint_id)
        query += " ORDER BY checkpoint_id DESC"
        # The metadata filter runs in Python after decoding, so LIMIT can only
        # be pushed into SQL when there is no filter; otherwise non-matching
        # rows would consume the limit. limit=None omits the clause instead of
        # binding NULL.
        if limit is not None and not filter:
            query += " LIMIT ?"
            params.append(limit)

        cursor = await self.db.execute(query, params)
        rows = await cursor.fetchall()

        remaining = limit
        for row in rows:
            if remaining is not None and remaining <= 0:
                break
            checkpoint_id, parent_checkpoint_id, checkpoint, metadata = self._decode_row(row)
            if filter and not all(metadata.get(k) == v for k, v in filter.items()):
                continue
            yield await self._build_tuple(
                thread_id, checkpoint_ns, checkpoint_id, parent_checkpoint_id, checkpoint, metadata
            )
            if remaining is not None:
                remaining -= 1

    async def aput(
        self,
        config: RunnableConfig,
        checkpoint: Checkpoint,
        metadata: CheckpointMetadata,
        new_versions: ChannelVersions,
    ) -> RunnableConfig:
        """Save a checkpoint and return the config addressing it.

        The checkpoint id comes from ``checkpoint["id"]``; the id in the
        incoming ``config`` (if any) is recorded as its parent. An existing
        row with the same key is replaced. ``new_versions`` is not used.
        """
        thread_id = config["configurable"]["thread_id"]
        checkpoint_ns = config["configurable"].get("checkpoint_ns", "")
        parent_checkpoint_id = config["configurable"].get("checkpoint_id")

        # The new checkpoint's id is carried in the checkpoint itself.
        checkpoint_id = checkpoint["id"]

        checkpoint_type, checkpoint_blob = self.serde.dumps_typed(checkpoint)
        metadata_type, metadata_blob = self.serde.dumps_typed(metadata)

        await self.db.execute(
            """
            INSERT OR REPLACE INTO checkpoints
            (thread_id, checkpoint_ns, checkpoint_id, parent_checkpoint_id, type, checkpoint, metadata, metadata_type)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?)
            """,
            (
                thread_id,
                checkpoint_ns,
                checkpoint_id,
                parent_checkpoint_id,
                checkpoint_type,
                checkpoint_blob,
                metadata_blob,
                metadata_type,
            ),
        )
        await self.db.commit()

        return self._config_for(thread_id, checkpoint_ns, checkpoint_id)

    async def aput_writes(
        self,
        config: RunnableConfig,
        writes: list[tuple[str, Any]],
        task_id: str,
        task_path: str = "",
    ) -> None:
        """Save pending writes produced by ``task_id`` for a checkpoint.

        Writes to LangGraph's special channels (the keys of
        ``WRITES_IDX_MAP``) are stored at the fixed index from that map, so
        they cannot collide with the task's regular writes, which use their
        position in ``writes``. A batch consisting only of special-channel
        writes replaces existing rows; any other batch uses
        ``INSERT OR IGNORE`` so writes already stored for the same
        (task, idx) are kept. ``task_path`` is accepted for interface
        compatibility and is not stored.
        """
        from langgraph.checkpoint.base import WRITES_IDX_MAP

        thread_id = config["configurable"]["thread_id"]
        checkpoint_ns = config["configurable"].get("checkpoint_ns", "")
        checkpoint_id = config["configurable"]["checkpoint_id"]

        rows = [
            (
                thread_id,
                checkpoint_ns,
                checkpoint_id,
                task_id,
                WRITES_IDX_MAP.get(channel, idx),
                channel,
                *self.serde.dumps_typed(value),
            )
            for idx, (channel, value) in enumerate(writes)
        ]
        # Conflict clause comes from a fixed two-value set, never from input.
        conflict = "REPLACE" if all(channel in WRITES_IDX_MAP for channel, _ in writes) else "IGNORE"

        await self.db.executemany(
            f"""
            INSERT OR {conflict} INTO checkpoint_writes
            (thread_id, checkpoint_ns, checkpoint_id, task_id, idx, channel, type, blob)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?)
            """,
            rows,
        )
        await self.db.commit()

    # ------------------------------------------------------------------
    # Sync API - not implemented
    # ------------------------------------------------------------------

    def get_tuple(self, config: RunnableConfig) -> CheckpointTuple | None:
        """Sync version not implemented - use async version."""
        raise NotImplementedError("Use async version: aget_tuple")

    def list(
        self,
        config: RunnableConfig | None,
        *,
        filter: dict[str, Any] | None = None,
        before: RunnableConfig | None = None,
        limit: int = 10,
    ):
        """Sync version not implemented - use async version."""
        raise NotImplementedError("Use async version: alist")

    def put(
        self,
        config: RunnableConfig,
        checkpoint: Checkpoint,
        metadata: CheckpointMetadata,
        new_versions: ChannelVersions,
    ) -> RunnableConfig:
        """Sync version not implemented - use async version."""
        raise NotImplementedError("Use async version: aput")

    def put_writes(
        self,
        config: RunnableConfig,
        writes: list[tuple[str, Any]],
        task_id: str,
        task_path: str = "",
    ) -> None:
        """Sync version not implemented - use async version."""
        raise NotImplementedError("Use async version: aput_writes")