switchplane 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of switchplane might be problematic. Click here for more details.
- switchplane/__init__.py +12 -0
- switchplane/__main__.py +4 -0
- switchplane/_util.py +36 -0
- switchplane/agent.py +46 -0
- switchplane/agent_runtime.py +555 -0
- switchplane/app.py +157 -0
- switchplane/checkpoint.py +365 -0
- switchplane/cli.py +596 -0
- switchplane/config.py +83 -0
- switchplane/control_plane.py +643 -0
- switchplane/daemon.py +350 -0
- switchplane/discovery.py +155 -0
- switchplane/fmt.py +132 -0
- switchplane/llm.py +96 -0
- switchplane/logging.py +103 -0
- switchplane/mcp.py +305 -0
- switchplane/oauth.py +465 -0
- switchplane/persistence.py +498 -0
- switchplane/protocol.py +73 -0
- switchplane/shell.py +386 -0
- switchplane/subprocess_manager.py +425 -0
- switchplane/task.py +204 -0
- switchplane/transport.py +234 -0
- switchplane/tui.py +1380 -0
- switchplane-0.1.0.dist-info/METADATA +802 -0
- switchplane-0.1.0.dist-info/RECORD +28 -0
- switchplane-0.1.0.dist-info/WHEEL +4 -0
- switchplane-0.1.0.dist-info/licenses/LICENSE +191 -0
switchplane/__init__.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
"""Switchplane - Local runtime harness for agent-based task execution."""
|
|
2
|
+
|
|
3
|
+
# Package version string.
__version__ = "0.1.0"

# pydantic's ``Field`` is re-exported (it is listed in ``__all__`` below) so it
# can be imported directly from ``switchplane``.
from pydantic import Field

from switchplane import fmt
from switchplane.app import Application
from switchplane.shell import Shell
from switchplane.task import Task, command

# Names exported by ``from switchplane import *``.
__all__ = ["Application", "Field", "Shell", "Task", "command", "fmt"]
|
switchplane/__main__.py
ADDED
switchplane/_util.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
"""Shared internal utilities."""
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import struct
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
# Upper bound on the payload size of a single length-prefixed frame;
# read_frame raises ValueError for any frame that announces a larger length.
MAX_MESSAGE_SIZE = 64 * 1024 * 1024  # 64 MB
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def deep_merge(base: dict[str, Any], overrides: dict[str, Any]) -> None:
    """Recursively fold *overrides* into *base*, mutating *base* in place.

    Where both mappings hold a dict under the same key, the two dicts are
    merged key-by-key; any other value from *overrides* simply replaces the
    one in *base* (dict values are stored by reference, not copied).
    Returns ``None``.
    """
    for name, incoming in overrides.items():
        merge_nested = (
            isinstance(incoming, dict)
            and name in base
            and isinstance(base[name], dict)
        )
        if not merge_nested:
            base[name] = incoming
            continue
        deep_merge(base[name], incoming)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def encode_frame(data: bytes) -> bytes:
    """Return *data* framed with a 4-byte, big-endian unsigned length header."""
    header = struct.pack(">I", len(data))
    return b"".join((header, data))
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
async def read_frame(reader: asyncio.StreamReader) -> bytes:
    """Read one length-prefixed frame from *reader* and return its payload.

    Raises:
        ValueError: if the 4-byte header announces more than
            ``MAX_MESSAGE_SIZE`` bytes.
        asyncio.IncompleteReadError: if the stream ends mid-frame.
    """
    (size,) = struct.unpack(">I", await reader.readexactly(4))
    if size <= MAX_MESSAGE_SIZE:
        return await reader.readexactly(size)
    raise ValueError(f"Message size {size} exceeds limit of {MAX_MESSAGE_SIZE}")
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
async def write_frame(writer: asyncio.StreamWriter, data: bytes) -> None:
    """Send *data* as one length-prefixed frame, then wait for the write buffer to drain."""
    framed = encode_frame(data)
    writer.write(framed)
    await writer.drain()
|
switchplane/agent.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
"""Agent specifications and records."""
|
|
2
|
+
|
|
3
|
+
from datetime import datetime
|
|
4
|
+
from enum import StrEnum
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
from pydantic import BaseModel, ConfigDict
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class AgentStatus(StrEnum):
    """Lifecycle status of an agent subprocess.

    Being a ``StrEnum``, each member is also a ``str`` equal to its value
    (e.g. ``AgentStatus.IDLE == "idle"``).
    """

    IDLE = "idle"  # default status of a new AgentRecord
    RUNNING = "running"
    STOPPING = "stopping"
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class AgentRecord(BaseModel):
    """Record of an agent instance.

    Fields are validated on construction and again on every attribute
    assignment (``validate_assignment=True``); leading/trailing whitespace is
    stripped from string values (``str_strip_whitespace=True``).
    """

    model_config = ConfigDict(
        str_strip_whitespace=True,
        validate_assignment=True,
    )

    agent_id: str
    agent_name: str
    pid: int | None = None  # presumably the agent subprocess PID; None until known — confirm
    status: AgentStatus = AgentStatus.IDLE
    capabilities_json: str = "{}"  # stored as a string; default is an empty JSON object
    started_at: datetime | None = None
    last_heartbeat: datetime | None = None
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class AgentSpec(BaseModel):
    """Specification for an agent.

    Validated on construction and on attribute assignment; string values have
    surrounding whitespace stripped.
    """

    model_config = ConfigDict(
        str_strip_whitespace=True,
        validate_assignment=True,
    )

    agent_name: str
    # NOTE(review): the mutable ``[]`` / ``{}`` defaults below rely on pydantic
    # copying defaults per instance; ``Field(default_factory=...)`` would make
    # that explicit.
    module_path: str = ""  # Set by discovery — dotted Python path to agent module
    mcp_servers: list[str] = []  # Allowed MCP server names
    tasks: dict[str, Any] = {}  # task_name -> Task class
|
|
switchplane/agent_runtime.py
ADDED
|
@@ -0,0 +1,555 @@
|
|
|
1
|
+
"""Agent-side runtime harness for Switchplane.
|
|
2
|
+
|
|
3
|
+
This module runs inside agent subprocesses and provides the execution harness
|
|
4
|
+
that agent code calls into. Communication with the control plane is
|
|
5
|
+
bidirectional over a Unix socketpair passed via --ipc-fd, using
|
|
6
|
+
4-byte length-prefixed JSON framing.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
import argparse
|
|
10
|
+
import asyncio
|
|
11
|
+
import importlib
|
|
12
|
+
import logging as _logging
|
|
13
|
+
import os
|
|
14
|
+
import socket
|
|
15
|
+
import struct
|
|
16
|
+
import traceback
|
|
17
|
+
from pathlib import Path
|
|
18
|
+
from typing import Any
|
|
19
|
+
|
|
20
|
+
import structlog
|
|
21
|
+
|
|
22
|
+
from switchplane._util import MAX_MESSAGE_SIZE
|
|
23
|
+
from switchplane.protocol import AgentCommand, AgentEvent
|
|
24
|
+
|
|
25
|
+
_logger = structlog.get_logger()
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class _IPCLogHandler(_logging.Handler):
|
|
29
|
+
"""Forwards stdlib log records to the control plane as 'log' AgentEvents.
|
|
30
|
+
|
|
31
|
+
Uses StreamMessageFormatter so the format can be swapped without touching
|
|
32
|
+
this handler. Logger name comes from record.name rather than the formatter
|
|
33
|
+
since Formatter.format() returns a single string.
|
|
34
|
+
"""
|
|
35
|
+
|
|
36
|
+
def __init__(self, ctx: "AgentContext"):
|
|
37
|
+
super().__init__()
|
|
38
|
+
self._ctx = ctx
|
|
39
|
+
self.addFilter(self._no_recursion)
|
|
40
|
+
from switchplane.logging import StreamMessageFormatter
|
|
41
|
+
|
|
42
|
+
self.setFormatter(StreamMessageFormatter())
|
|
43
|
+
|
|
44
|
+
@staticmethod
|
|
45
|
+
def _no_recursion(record: _logging.LogRecord) -> bool:
|
|
46
|
+
return not record.name.startswith("switchplane.agent_runtime")
|
|
47
|
+
|
|
48
|
+
def emit(self, record: _logging.LogRecord) -> None:
|
|
49
|
+
try:
|
|
50
|
+
self._ctx.emit(
|
|
51
|
+
"log",
|
|
52
|
+
{
|
|
53
|
+
"message": self.format(record),
|
|
54
|
+
"level": record.levelname.lower(),
|
|
55
|
+
"logger": record.name,
|
|
56
|
+
},
|
|
57
|
+
)
|
|
58
|
+
except Exception:
|
|
59
|
+
pass # never let logging failures crash the agent
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
async def _read_message(reader: asyncio.StreamReader) -> bytes:
    """Read one length-prefixed message from the IPC socket.

    Delegates to :func:`switchplane._util.read_frame`, which implements the
    same 4-byte big-endian length framing and ``MAX_MESSAGE_SIZE`` limit, so
    the framing logic lives in a single place.

    Raises:
        ValueError: if the announced length exceeds ``MAX_MESSAGE_SIZE``.
        asyncio.IncompleteReadError: if the peer closes mid-frame.
    """
    from switchplane._util import read_frame

    return await read_frame(reader)
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def _write_message_sync(sock: socket.socket, data: bytes) -> None:
|
|
72
|
+
"""Write a length-prefixed message synchronously.
|
|
73
|
+
|
|
74
|
+
Temporarily sets the socket to blocking mode if needed, because asyncio
|
|
75
|
+
puts sockets into non-blocking mode and sock.sendall() on a non-blocking
|
|
76
|
+
socket can send partial data then raise BlockingIOError for large messages,
|
|
77
|
+
corrupting the IPC framing.
|
|
78
|
+
"""
|
|
79
|
+
message = struct.pack(">I", len(data)) + data
|
|
80
|
+
was_blocking = sock.getblocking()
|
|
81
|
+
if not was_blocking:
|
|
82
|
+
sock.setblocking(True)
|
|
83
|
+
try:
|
|
84
|
+
sock.sendall(message)
|
|
85
|
+
finally:
|
|
86
|
+
if not was_blocking:
|
|
87
|
+
sock.setblocking(False)
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
class AgentContext:
|
|
91
|
+
"""Context injected into agent task execution. Provides IPC helpers.
|
|
92
|
+
|
|
93
|
+
An ``AgentContext`` is created by the agent runtime and passed to
|
|
94
|
+
``Task.run()``. It is the primary interface for task code to
|
|
95
|
+
communicate with the control plane.
|
|
96
|
+
|
|
97
|
+
Key attributes:
|
|
98
|
+
task_id: Unique identifier for this task execution.
|
|
99
|
+
task_name: The registered name of the task.
|
|
100
|
+
config: Dict of merged app + user configuration (from TOML
|
|
101
|
+
config cascade). Access agent-specific settings, API keys,
|
|
102
|
+
model names, etc. via this dict.
|
|
103
|
+
|
|
104
|
+
Logging:
|
|
105
|
+
Standard library ``logging`` calls are automatically forwarded
|
|
106
|
+
to the control plane as ``log`` events via an IPC log handler
|
|
107
|
+
installed at subprocess startup. Use ``logging.getLogger()`` as
|
|
108
|
+
normal — there is no need for a special logging method.
|
|
109
|
+
|
|
110
|
+
Lifecycle methods:
|
|
111
|
+
Use ``complete(result)`` to signal success, ``fail(error)`` to
|
|
112
|
+
signal failure, and ``progress(message)`` for intermediate
|
|
113
|
+
status updates. For low-level custom events, use ``emit()``.
|
|
114
|
+
"""
|
|
115
|
+
|
|
116
|
+
def __init__(
|
|
117
|
+
self,
|
|
118
|
+
task_id: str,
|
|
119
|
+
task_name: str,
|
|
120
|
+
ipc_sock: socket.socket,
|
|
121
|
+
config: dict[str, Any],
|
|
122
|
+
db_path: str | None = None,
|
|
123
|
+
):
|
|
124
|
+
self.task_id = task_id
|
|
125
|
+
self.task_name = task_name
|
|
126
|
+
self._sock = ipc_sock
|
|
127
|
+
self.config = config
|
|
128
|
+
self._cancelled = asyncio.Event()
|
|
129
|
+
self._completed = False
|
|
130
|
+
self._command_queue: asyncio.Queue[dict[str, Any]] = asyncio.Queue()
|
|
131
|
+
self._mcp: Any = None # McpManager, set during startup if MCP servers configured
|
|
132
|
+
self._db_path = db_path
|
|
133
|
+
self._checkpointer: Any = None
|
|
134
|
+
self._db_conn: Any = None
|
|
135
|
+
self._task: Any = None
|
|
136
|
+
|
|
137
|
+
@property
|
|
138
|
+
def runtime_dir(self) -> Path:
|
|
139
|
+
if self._db_path is None:
|
|
140
|
+
raise RuntimeError("runtime_dir not available")
|
|
141
|
+
return Path(self._db_path).parent
|
|
142
|
+
|
|
143
|
+
@property
|
|
144
|
+
def mcp(self):
|
|
145
|
+
"""Access MCP sessions. Returns McpManager or None if no MCP servers configured."""
|
|
146
|
+
return self._mcp
|
|
147
|
+
|
|
148
|
+
async def mcp_tools(self) -> dict[str, Any]:
|
|
149
|
+
"""Get all MCP tools as LangChain StructuredTool instances."""
|
|
150
|
+
if self._mcp is None:
|
|
151
|
+
return {}
|
|
152
|
+
return {t.name: t for t in await self._mcp.langchain_tools()}
|
|
153
|
+
|
|
154
|
+
@property
|
|
155
|
+
def checkpointer(self):
|
|
156
|
+
"""LangGraph checkpoint saver for resumable workflows. Returns None if db_path not set."""
|
|
157
|
+
return self._checkpointer
|
|
158
|
+
|
|
159
|
+
def emit(self, event_type: str, payload: dict[str, Any] | None = None) -> None:
|
|
160
|
+
"""Send an event to the control plane over the IPC socket.
|
|
161
|
+
|
|
162
|
+
This is the low-level event primitive. Prefer the higher-level
|
|
163
|
+
helpers for standard lifecycle events:
|
|
164
|
+
|
|
165
|
+
- ``progress(message)`` — emits ``task.progress``
|
|
166
|
+
- ``complete(result)`` — emits ``task.completed``
|
|
167
|
+
- ``fail(error)`` — emits ``task.failed``
|
|
168
|
+
|
|
169
|
+
Use ``emit`` directly for custom event types (e.g.
|
|
170
|
+
``"task.metric"``, ``"task.warning"``). The ``event_type`` string
|
|
171
|
+
is stored as-is in the events table and forwarded to CLI/TUI
|
|
172
|
+
subscribers. The ``payload`` dict is serialized as JSON; keep
|
|
173
|
+
values JSON-serializable.
|
|
174
|
+
|
|
175
|
+
Args:
|
|
176
|
+
event_type: Dot-namespaced event type string.
|
|
177
|
+
payload: Arbitrary JSON-serializable dict. Defaults to ``{}``.
|
|
178
|
+
"""
|
|
179
|
+
event = AgentEvent(
|
|
180
|
+
type=event_type,
|
|
181
|
+
task_id=self.task_id,
|
|
182
|
+
payload=payload or {},
|
|
183
|
+
)
|
|
184
|
+
_write_message_sync(self._sock, event.model_dump_json().encode())
|
|
185
|
+
|
|
186
|
+
def progress(self, message: str, detail: str | list[str] | None = None, **extra) -> None:
|
|
187
|
+
payload: dict[str, Any] = {"message": message, **extra}
|
|
188
|
+
if detail is not None:
|
|
189
|
+
payload["detail"] = detail if isinstance(detail, list) else detail.split("\n")
|
|
190
|
+
self.emit("task.progress", payload)
|
|
191
|
+
|
|
192
|
+
def complete(self, result: Any) -> None:
|
|
193
|
+
self._completed = True
|
|
194
|
+
self.emit("task.completed", {"result": result})
|
|
195
|
+
|
|
196
|
+
def fail(self, error: str, traceback_str: str | None = None) -> None:
|
|
197
|
+
self._completed = True
|
|
198
|
+
payload = {"error": error}
|
|
199
|
+
if traceback_str:
|
|
200
|
+
payload["traceback"] = traceback_str
|
|
201
|
+
self.emit("task.failed", payload)
|
|
202
|
+
|
|
203
|
+
@property
|
|
204
|
+
def is_cancelled(self) -> bool:
|
|
205
|
+
"""Whether a cancellation has been requested for this task.
|
|
206
|
+
|
|
207
|
+
Check this in long-running loops to exit gracefully. For an
|
|
208
|
+
async-friendly alternative that raises ``CancelledError``, use
|
|
209
|
+
``check_cancelled()``. For cancellation-aware sleeping, use
|
|
210
|
+
``sleep()`` which returns ``False`` when cancelled.
|
|
211
|
+
"""
|
|
212
|
+
return self._cancelled.is_set()
|
|
213
|
+
|
|
214
|
+
async def check_cancelled(self) -> None:
|
|
215
|
+
"""Raise asyncio.CancelledError if a cancel command has been received."""
|
|
216
|
+
if self._cancelled.is_set():
|
|
217
|
+
raise asyncio.CancelledError("Task cancelled by control plane")
|
|
218
|
+
|
|
219
|
+
async def receive_command(self) -> dict[str, Any]:
|
|
220
|
+
"""Block until a command arrives from the queue."""
|
|
221
|
+
return await self._command_queue.get()
|
|
222
|
+
|
|
223
|
+
def poll_command(self) -> dict[str, Any] | None:
|
|
224
|
+
"""Non-blocking check for commands, returns None if empty."""
|
|
225
|
+
try:
|
|
226
|
+
return self._command_queue.get_nowait()
|
|
227
|
+
except asyncio.QueueEmpty:
|
|
228
|
+
return None
|
|
229
|
+
|
|
230
|
+
def command_result(self, action: str, result: dict[str, Any]) -> None:
|
|
231
|
+
"""Emit a ``task.command_result`` event back to the control plane.
|
|
232
|
+
|
|
233
|
+
Called automatically by ``Task._dispatch_command`` after a
|
|
234
|
+
``@command``-decorated handler returns. You generally do not
|
|
235
|
+
need to call this directly unless you are implementing custom
|
|
236
|
+
command dispatch logic outside the ``@command`` decorator
|
|
237
|
+
framework.
|
|
238
|
+
|
|
239
|
+
Args:
|
|
240
|
+
action: The command action name that was executed.
|
|
241
|
+
result: Dict payload to include in the result event. An
|
|
242
|
+
``{"error": ...}`` key signals failure to the CLI/TUI.
|
|
243
|
+
"""
|
|
244
|
+
self.emit("task.command_result", {"action": action, "result": result})
|
|
245
|
+
|
|
246
|
+
async def wait_for_input(self, prompt: str | None = None) -> str:
|
|
247
|
+
"""Block until the user sends freeform text input.
|
|
248
|
+
|
|
249
|
+
Emits a task.interrupted event (with optional prompt), waits for
|
|
250
|
+
user input via the command queue, then emits task.resumed and returns
|
|
251
|
+
the text. Non-input commands are dispatched normally while waiting.
|
|
252
|
+
|
|
253
|
+
Requires a checkpointer to be configured.
|
|
254
|
+
"""
|
|
255
|
+
if self._checkpointer is None:
|
|
256
|
+
raise RuntimeError(
|
|
257
|
+
"wait_for_input requires a checkpointer — compile your graph with checkpointer=ctx.checkpointer"
|
|
258
|
+
)
|
|
259
|
+
|
|
260
|
+
payload: dict[str, Any] = {"prompt": prompt} if prompt is not None else {}
|
|
261
|
+
self.emit("task.interrupted", payload)
|
|
262
|
+
|
|
263
|
+
while True:
|
|
264
|
+
cmd = await self._command_queue.get()
|
|
265
|
+
if cmd["action"] == "__input__":
|
|
266
|
+
self.emit("task.resumed", {})
|
|
267
|
+
return cmd["params"]["text"]
|
|
268
|
+
# Dispatch non-input commands to the task
|
|
269
|
+
if self._task is not None:
|
|
270
|
+
await self._task._dispatch_command(self, cmd)
|
|
271
|
+
|
|
272
|
+
async def sleep(self, seconds: float) -> bool:
|
|
273
|
+
"""Cancellation-aware sleep.
|
|
274
|
+
|
|
275
|
+
Returns ``True`` if the full duration elapsed, ``False`` if the
|
|
276
|
+
task was cancelled during the wait.
|
|
277
|
+
"""
|
|
278
|
+
try:
|
|
279
|
+
await asyncio.wait_for(self._cancelled.wait(), timeout=seconds)
|
|
280
|
+
return False
|
|
281
|
+
except TimeoutError:
|
|
282
|
+
return True
|
|
283
|
+
|
|
284
|
+
async def poll_until(
|
|
285
|
+
self,
|
|
286
|
+
callback,
|
|
287
|
+
interval: int | float = 60,
|
|
288
|
+
task: Any = None,
|
|
289
|
+
):
|
|
290
|
+
"""Repeatedly call *callback* every *interval* seconds until it
|
|
291
|
+
returns a non-``None`` value or the task is cancelled.
|
|
292
|
+
|
|
293
|
+
*callback* may be sync or async.
|
|
294
|
+
If *task* is provided, pending commands are dispatched between polls.
|
|
295
|
+
Returns the callback result, or ``None`` if cancelled.
|
|
296
|
+
"""
|
|
297
|
+
while not self.is_cancelled:
|
|
298
|
+
try:
|
|
299
|
+
await asyncio.wait_for(self._cancelled.wait(), timeout=interval)
|
|
300
|
+
return None
|
|
301
|
+
except TimeoutError:
|
|
302
|
+
pass
|
|
303
|
+
|
|
304
|
+
if task is not None:
|
|
305
|
+
await task.process_commands(self)
|
|
306
|
+
|
|
307
|
+
result = callback()
|
|
308
|
+
if asyncio.iscoroutine(result):
|
|
309
|
+
result = await result
|
|
310
|
+
if result is not None:
|
|
311
|
+
return result
|
|
312
|
+
return None
|
|
313
|
+
|
|
314
|
+
|
|
315
|
+
async def _listen_for_commands(
|
|
316
|
+
reader: asyncio.StreamReader,
|
|
317
|
+
ctx: AgentContext,
|
|
318
|
+
task_handle: asyncio.Task,
|
|
319
|
+
) -> None:
|
|
320
|
+
"""Listen for incoming commands from the control plane (cancel, shutdown)."""
|
|
321
|
+
try:
|
|
322
|
+
while True:
|
|
323
|
+
data = await _read_message(reader)
|
|
324
|
+
try:
|
|
325
|
+
command = AgentCommand.model_validate_json(data)
|
|
326
|
+
except Exception:
|
|
327
|
+
_logger.warning("malformed_command", data=data[:200])
|
|
328
|
+
continue
|
|
329
|
+
|
|
330
|
+
match command.type:
|
|
331
|
+
case "cancel":
|
|
332
|
+
ctx._cancelled.set()
|
|
333
|
+
task_handle.cancel()
|
|
334
|
+
return
|
|
335
|
+
case "shutdown":
|
|
336
|
+
ctx._cancelled.set()
|
|
337
|
+
task_handle.cancel()
|
|
338
|
+
return
|
|
339
|
+
case "user_command":
|
|
340
|
+
# Put the command payload onto the command queue
|
|
341
|
+
await ctx._command_queue.put(command.payload)
|
|
342
|
+
except (asyncio.IncompleteReadError, ConnectionError, OSError):
|
|
343
|
+
pass # Socket closed — control plane is gone, task will finish or be orphaned
|
|
344
|
+
|
|
345
|
+
|
|
346
|
+
async def _start_checkpointer(ctx: AgentContext) -> None:
    """Open the SQLite checkpoint store for *ctx* when it has a database path.

    On success ``ctx._db_conn`` holds the open aiosqlite connection (rows as
    ``aiosqlite.Row``, WAL journal mode) and ``ctx._checkpointer`` holds the
    set-up ``SqliteCheckpointSaver``. Any failure is logged as
    ``checkpointer_init_failed`` and not re-raised.
    """
    db_path = ctx._db_path
    if not db_path:
        return
    try:
        import aiosqlite

        from switchplane.checkpoint import SqliteCheckpointSaver

        conn = await aiosqlite.connect(db_path)
        ctx._db_conn = conn
        conn.row_factory = aiosqlite.Row
        await conn.execute("PRAGMA journal_mode=WAL")
        checkpointer = SqliteCheckpointSaver(conn)
        await checkpointer.setup()
        ctx._checkpointer = checkpointer
    except Exception as exc:
        _logger.warning("checkpointer_init_failed", error=str(exc))
|
|
363
|
+
|
|
364
|
+
|
|
365
|
+
async def _stop_checkpointer(ctx: AgentContext) -> None:
    """Close the checkpoint database connection, if one was opened.

    Best effort: a failure while closing is logged rather than raised so
    shutdown can continue. The connection reference is cleared before closing,
    so calling this again is a no-op.
    """
    conn = ctx._db_conn
    if conn is None:
        return
    ctx._db_conn = None
    try:
        await conn.close()
    except Exception as e:
        _logger.warning("checkpointer_close_failed", error=str(e))
|
|
372
|
+
|
|
373
|
+
|
|
374
|
+
async def _start_mcp(ctx: AgentContext, mcp_configs: list[dict[str, Any]]) -> None:
    """Start MCP sessions if configured and attach the manager to *ctx*.

    Does nothing when *mcp_configs* is empty. Servers that fail to start are
    logged individually; the function returns normally as long as at least
    one server started.

    Raises:
        RuntimeError: if MCP support is not installed, a server config is
            invalid, the manager raises while starting, or every configured
            server failed — always ``RuntimeError`` so the caller can abort
            the task before execution begins.
    """
    if not mcp_configs:
        return

    try:
        from switchplane.app import McpServerConfig
        from switchplane.mcp import McpManager
    except ImportError:
        raise RuntimeError(
            "MCP support requires the 'mcp' package. Install with: pip install switchplane[mcp]"
        ) from None

    try:
        configs = [McpServerConfig.model_validate(c) for c in mcp_configs]
    except ValueError as e:
        # pydantic's ValidationError subclasses ValueError; convert it so it
        # follows the documented RuntimeError contract instead of escaping.
        raise RuntimeError(f"Invalid MCP server configuration: {e}") from e

    runtime_dir = Path(ctx._db_path).parent if ctx._db_path else None
    manager = McpManager(configs, runtime_dir=runtime_dir)
    try:
        errors = await manager.start()
    except Exception as e:
        # Tear down whatever did start before reporting the failure.
        try:
            await manager.stop()
        except Exception:
            pass  # already failing; the start error below is the one to report
        raise RuntimeError(f"MCP servers failed to start: {e}") from e

    for err in errors:
        _logger.error("mcp_server_start_failed", error=err)

    if len(errors) == len(configs):
        await manager.stop()
        raise RuntimeError(f"All MCP servers failed to start: {'; '.join(errors)}")

    ctx._mcp = manager
|
|
403
|
+
|
|
404
|
+
|
|
405
|
+
async def _stop_mcp(ctx: AgentContext) -> None:
    """Stop MCP sessions if running.

    Best effort: errors during shutdown are logged, not raised. The manager
    is detached from *ctx* first, so repeated calls are no-ops and
    ``ctx.mcp`` returns ``None`` afterwards.
    """
    manager = ctx._mcp
    if manager is None:
        return
    ctx._mcp = None
    try:
        await manager.stop()
    except Exception as e:
        _logger.warning("mcp_stop_failed", error=str(e))
|
|
412
|
+
|
|
413
|
+
|
|
414
|
+
async def _run_task(ctx: AgentContext, task_module_path: str, raw_params: dict) -> None:
    """Import *task_module_path*, instantiate its Task subclass, and run it.

    The Task subclass is chosen from the module's attributes in ``dir()``
    order, preferring classes *defined* in that module over ones it merely
    imports. Otherwise an imported shared base class could be picked
    instead of the module's own task. Validated parameters are assigned onto
    the instance before ``run(ctx)`` is awaited. If ``run`` returns a
    non-``None`` value and no terminal event (complete/fail) was emitted yet,
    the value is reported via ``ctx.complete``.

    Raises:
        RuntimeError: if the module contains no Task subclass.
        Any exception from importing, parameter validation or ``run`` is
        propagated unchanged so the caller can report it with a traceback.
    """
    module = importlib.import_module(task_module_path)

    from switchplane.task import Task

    candidates = [
        obj
        for obj in (getattr(module, name) for name in dir(module))
        if isinstance(obj, type) and issubclass(obj, Task) and obj is not Task
    ]
    defined_here = [cls for cls in candidates if cls.__module__ == module.__name__]
    # Fall back to imported subclasses for modules that only re-export a task.
    ordered = defined_here or candidates
    if not ordered:
        raise RuntimeError(f"No Task subclass found in {task_module_path}")
    task_class = ordered[0]

    task_instance = task_class()
    task_instance._ctx = ctx
    ctx._task = task_instance

    # Validate and set parameter fields
    params_model = task_class.parameters_model()
    if params_model is not None:
        validated = params_model.model_validate(raw_params)
        for field_name in params_model.model_fields:
            setattr(task_instance, field_name, getattr(validated, field_name))

    result = await task_instance.run(ctx)
    # Only auto-complete if the task returned a value AND hasn't already
    # emitted a terminal event (complete/fail) itself.
    if result is not None and not ctx._completed:
        ctx.complete(result)
|
|
450
|
+
|
|
451
|
+
|
|
452
|
+
async def agent_main(ipc_fd: int, entry_point: str) -> None:
    """Main entry point for agent subprocess.

    Opens the IPC socket from the passed fd, reads the initial execute_task
    command, runs the task, and listens for cancel/shutdown concurrently.
    The IPC stream writer and socket are closed on every exit path.

    Args:
        ipc_fd: Inherited file descriptor of the control-plane Unix socket.
            It is duplicated into a socket object and the original closed.
        entry_point: Accepted for CLI compatibility; not used by this
            function (the task module comes from the execute_task payload).
    """
    # Reconstruct socket from the inherited fd; fromfd() dups it, so close the original.
    sock = socket.fromfd(ipc_fd, socket.AF_UNIX, socket.SOCK_STREAM)
    os.close(ipc_fd)

    # Wrap for async reading (commands from CP)
    try:
        reader, writer = await asyncio.open_connection(sock=sock)
    except Exception:
        sock.close()
        raise

    try:
        await _serve_task(reader, sock)
    finally:
        # Runs on early return, error and cancellation alike.
        writer.close()
        sock.close()


async def _serve_task(reader: asyncio.StreamReader, sock: socket.socket) -> None:
    """Read the initial execute_task command and run that task to completion."""
    try:
        data = await _read_message(reader)
    except (asyncio.IncompleteReadError, ConnectionError):
        return  # control plane went away before sending work

    command = AgentCommand.model_validate_json(data)
    if command.type != "execute_task":
        return

    payload = command.payload
    ctx = AgentContext(
        task_id=command.task_id,
        task_name=payload.get("task_name", ""),
        ipc_sock=sock,
        config=payload.get("config", {}),
        db_path=payload.get("db_path"),
    )

    # Configure log level and install the IPC handler so stdlib logging output
    # from agent code is forwarded as "log" events to the control plane.
    root_logger = _logging.getLogger()
    log_level = payload.get("log_level", "debug")
    root_logger.setLevel(getattr(_logging, log_level.upper(), _logging.DEBUG))
    ipc_handler = _IPCLogHandler(ctx)
    root_logger.addHandler(ipc_handler)

    listener_handle: asyncio.Task | None = None
    try:
        # Start checkpointer and MCP sessions before task execution
        await _start_checkpointer(ctx)
        try:
            await _start_mcp(ctx, payload.get("mcp_servers", []))
        except RuntimeError as e:
            ctx.fail(str(e))
            return

        ctx.emit("task.started", {})

        # Run the task and the command listener concurrently
        task_handle = asyncio.create_task(
            _run_task(ctx, payload.get("task_module", ""), payload.get("params", {}))
        )
        listener_handle = asyncio.create_task(_listen_for_commands(reader, ctx, task_handle))

        try:
            await task_handle
        except asyncio.CancelledError:
            ctx.emit("task.cancelled", {})
        except BaseException as e:
            ctx.fail(f"{type(e).__name__}: {e}", traceback.format_exc())
            if isinstance(e, (KeyboardInterrupt, SystemExit)):
                raise
    finally:
        root_logger.removeHandler(ipc_handler)
        await _stop_mcp(ctx)
        await _stop_checkpointer(ctx)
        if listener_handle is not None:
            listener_handle.cancel()
            try:
                await listener_handle
            except asyncio.CancelledError:
                pass
            except Exception as e:
                # A failed listener must not abort cleanup of the IPC socket.
                _logger.warning("command_listener_failed", error=str(e))
|
|
540
|
+
|
|
541
|
+
|
|
542
|
+
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Switchplane agent runtime")
    parser.add_argument("--entry-point", required=True)
    parser.add_argument("--ipc-fd", required=True, type=int)
    parser.add_argument("--log-file", default=None)
    args = parser.parse_args()

    # Aliased so the package's logging module is not confused with the stdlib
    # ``logging`` module (imported in this file as ``_logging``).
    from switchplane import logging as switchplane_logging

    switchplane_logging.configure(log_file=Path(args.log_file) if args.log_file else None)

    asyncio.run(agent_main(args.ipc_fd, args.entry_point))
|