pidriver 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pidriver/__init__.py +132 -0
- pidriver/_transport.py +272 -0
- pidriver/client.py +92 -0
- pidriver/config.py +259 -0
- pidriver/errors.py +77 -0
- pidriver/events.py +441 -0
- pidriver/interaction.py +214 -0
- pidriver/manager.py +192 -0
- pidriver/py.typed +0 -0
- pidriver/session.py +255 -0
- pidriver/usage.py +98 -0
- pidriver-0.0.1.dist-info/METADATA +490 -0
- pidriver-0.0.1.dist-info/RECORD +15 -0
- pidriver-0.0.1.dist-info/WHEEL +4 -0
- pidriver-0.0.1.dist-info/licenses/LICENSE +21 -0
pidriver/__init__.py
ADDED
|
@@ -0,0 +1,132 @@
|
|
|
1
|
+
"""pidriver — an async Python driver for the pi coding agent.
|
|
2
|
+
|
|
3
|
+
Drives ``pi --mode rpc`` over a JSONL stdin/stdout protocol, with first-class
|
|
4
|
+
support for *isolated* agents (private config dir, scrubbed environment) so a pi
|
|
5
|
+
instance developing a project can't touch the operator's global pi setup.
|
|
6
|
+
|
|
7
|
+
Layers:
|
|
8
|
+
|
|
9
|
+
* :mod:`pidriver.config` — :class:`PiConfig` (isolation knobs, ``to_argv``/``to_env``).
|
|
10
|
+
* :mod:`pidriver._transport` — :class:`PiTransport` protocol + :class:`SubprocessTransport`
|
|
11
|
+
(the swap boundary).
|
|
12
|
+
* :mod:`pidriver.events` — the typed :class:`Event` hierarchy + :func:`parse_event`.
|
|
13
|
+
* :mod:`pidriver.interaction` — answering the agent's questions (handlers + policies).
|
|
14
|
+
* :mod:`pidriver.session` — :class:`PiSession` (the live RPC session).
|
|
15
|
+
* :mod:`pidriver.client` — :class:`PiClient` (start sessions).
|
|
16
|
+
* :mod:`pidriver.manager` — :class:`PiSessionManager` (registry, idle reaper, ``stop_all``).
|
|
17
|
+
* :mod:`pidriver.usage` — :class:`UsageTotals`.
|
|
18
|
+
* :mod:`pidriver.errors` — exception hierarchy.
|
|
19
|
+
|
|
20
|
+
Quick start::
|
|
21
|
+
|
|
22
|
+
from pidriver import PiClient, PiConfig, AutoApprove, MessageDelta, AgentEnd
|
|
23
|
+
|
|
24
|
+
client = PiClient(PiConfig(binary="omp", api_key=KEY, model="anthropic/claude-..."))
|
|
25
|
+
session = await client.start("Add a healthcheck endpoint", cwd="/srv/proj",
|
|
26
|
+
interaction_handler=AutoApprove())
|
|
27
|
+
async for event in session:
|
|
28
|
+
match event:
|
|
29
|
+
case MessageDelta(text=text):
|
|
30
|
+
print(text, end="")
|
|
31
|
+
case AgentEnd():
|
|
32
|
+
break
|
|
33
|
+
"""
|
|
34
|
+
|
|
35
|
+
from __future__ import annotations
|
|
36
|
+
|
|
37
|
+
from ._transport import PiTransport, SubprocessTransport
|
|
38
|
+
from .client import PiClient, TransportFactory
|
|
39
|
+
from .config import PROVIDER_API_KEY_ENV, PiConfig
|
|
40
|
+
from .errors import (
|
|
41
|
+
PiCommandError,
|
|
42
|
+
PiDriverError,
|
|
43
|
+
PiNotFoundError,
|
|
44
|
+
PiProtocolError,
|
|
45
|
+
PiStartError,
|
|
46
|
+
PiTimeoutError,
|
|
47
|
+
SessionNotFoundError,
|
|
48
|
+
TransportClosedError,
|
|
49
|
+
)
|
|
50
|
+
from .events import (
|
|
51
|
+
AgentEnd,
|
|
52
|
+
AgentStart,
|
|
53
|
+
Channel,
|
|
54
|
+
Decision,
|
|
55
|
+
Error,
|
|
56
|
+
Event,
|
|
57
|
+
InteractionKind,
|
|
58
|
+
InteractionRequest,
|
|
59
|
+
MessageComplete,
|
|
60
|
+
MessageDelta,
|
|
61
|
+
ToolEnd,
|
|
62
|
+
ToolStart,
|
|
63
|
+
ToolUpdate,
|
|
64
|
+
Usage,
|
|
65
|
+
parse_event,
|
|
66
|
+
)
|
|
67
|
+
from .interaction import (
|
|
68
|
+
AskHost,
|
|
69
|
+
AutoApprove,
|
|
70
|
+
DenyAll,
|
|
71
|
+
InteractionHandler,
|
|
72
|
+
InteractionResponse,
|
|
73
|
+
chain,
|
|
74
|
+
response_to_wire,
|
|
75
|
+
)
|
|
76
|
+
from .manager import ManagedSession, PiSessionManager
|
|
77
|
+
from .session import PiSession
|
|
78
|
+
from .usage import UsageTotals
|
|
79
|
+
|
|
80
|
+
# Package version string.
# NOTE(review): hard-coded — keep in sync with the distribution metadata
# (this build ships as pidriver-0.0.1).
__version__ = "0.0.1"

# Public API of the package: every name below is re-exported from the
# submodule imports above, and ``from pidriver import *`` exposes exactly
# these names. Grouped by originating submodule.
__all__ = [
    "__version__",
    # config
    "PiConfig",
    "PROVIDER_API_KEY_ENV",
    # transport
    "PiTransport",
    "SubprocessTransport",
    # client / session
    "PiClient",
    "TransportFactory",
    "PiSession",
    # manager
    "PiSessionManager",
    "ManagedSession",
    # usage
    "UsageTotals",
    # events
    "Event",
    "AgentStart",
    "MessageDelta",
    "MessageComplete",
    "ToolStart",
    "ToolUpdate",
    "ToolEnd",
    "Usage",
    "AgentEnd",
    "Error",
    "Channel",
    "parse_event",
    # interaction
    "InteractionRequest",
    "InteractionKind",
    "InteractionResponse",
    "InteractionHandler",
    "Decision",
    "AskHost",
    "AutoApprove",
    "DenyAll",
    "chain",
    "response_to_wire",
    # errors
    "PiDriverError",
    "PiNotFoundError",
    "PiStartError",
    "PiProtocolError",
    "PiCommandError",
    "PiTimeoutError",
    "TransportClosedError",
    "SessionNotFoundError",
]
|
pidriver/_transport.py
ADDED
|
@@ -0,0 +1,272 @@
|
|
|
1
|
+
"""The transport boundary: process lifecycle + JSONL framing for ``pi --mode rpc``.
|
|
2
|
+
|
|
3
|
+
This module is intentionally **dumb**. It spawns the pi RPC subprocess, writes
|
|
4
|
+
JSON commands as newline-delimited lines, and yields parsed JSON objects (events
|
|
5
|
+
*and* responses, undifferentiated) from stdout. It does **not** correlate
|
|
6
|
+
request/response ids, dispatch events, or know any command semantics — those live
|
|
7
|
+
in the higher session/client layer.
|
|
8
|
+
|
|
9
|
+
Keeping this layer thin makes it the **swap boundary**: the public surface is the
|
|
10
|
+
:class:`PiTransport` protocol, and :class:`SubprocessTransport` is one concrete
|
|
11
|
+
implementation. A future omp-rpc-backed transport can satisfy the same protocol
|
|
12
|
+
without touching the session/manager code above it.
|
|
13
|
+
|
|
14
|
+
Framing follows pi's RPC contract strictly: records are split on ``\\n`` only, and
|
|
15
|
+
a trailing ``\\r`` is stripped. We never split on U+2028/U+2029 (which are valid
|
|
16
|
+
inside JSON strings), so we do manual buffering rather than a generic line reader.
|
|
17
|
+
"""
|
|
18
|
+
|
|
19
|
+
from __future__ import annotations
|
|
20
|
+
|
|
21
|
+
import asyncio
|
|
22
|
+
import json
|
|
23
|
+
import os
|
|
24
|
+
import shutil
|
|
25
|
+
from collections.abc import AsyncIterator, Callable
|
|
26
|
+
from typing import Any, Protocol, runtime_checkable
|
|
27
|
+
|
|
28
|
+
from .config import PiConfig
|
|
29
|
+
from .errors import (
|
|
30
|
+
PiNotFoundError,
|
|
31
|
+
PiProtocolError,
|
|
32
|
+
PiStartError,
|
|
33
|
+
TransportClosedError,
|
|
34
|
+
)
|
|
35
|
+
|
|
36
|
+
# Names exported by ``from pidriver._transport import *``; the module-level
# constants and helpers below are private (underscore-prefixed).
__all__ = ["PiTransport", "SubprocessTransport"]
|
|
37
|
+
|
|
38
|
+
# Upper bound (64 MiB) on a single buffered stdout record. pi may emit very
# large tool-result lines, but a malformed stream that never produces a newline
# must not be allowed to grow the read buffer without limit.
_MAX_LINE_BYTES = 64 << 20
# Number of bytes requested from pi's stdout per read (64 KiB).
_READ_CHUNK = 64 * 1024
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
@runtime_checkable
class PiTransport(Protocol):
    """Structural interface that every pi transport implementation satisfies.

    Rules an implementation must honour:

    * only **one** consumer reads at a time, either through :meth:`receive`
      or by iterating the transport;
    * :meth:`receive` yields ``None`` once the process's stdout hits EOF;
    * calling :meth:`send` or :meth:`receive` before :meth:`start`, or after
      the process is gone, raises :class:`TransportClosedError`.
    """

    async def start(self) -> None:
        """Bring the transport up (e.g. spawn the pi process)."""

    async def send(self, obj: dict[str, Any]) -> None:
        """Write one JSON command to pi."""

    async def receive(self) -> dict[str, Any] | None:
        """Return the next inbound JSON object, or ``None`` at EOF."""

    def __aiter__(self) -> AsyncIterator[dict[str, Any]]:
        """Iterate inbound JSON objects until EOF."""

    async def aclose(self) -> None:
        """Tear the transport down and release the underlying process."""

    @property
    def pid(self) -> int | None:
        """Process id of the child, or ``None`` if none has been spawned."""

    @property
    def returncode(self) -> int | None:
        """Exit status of the child once it has finished, else ``None``."""
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
class SubprocessTransport:
    """:class:`PiTransport` backed by a ``pi --mode rpc`` child process.

    Args:
        config: how to spawn pi (argv + env + cwd).
        on_stderr: optional callback invoked with each decoded stderr line; pi
            writes diagnostics there. stderr is never interleaved into the dict
            stream returned by :meth:`receive`. An exception raised by the
            callback is suppressed for that line and draining continues.
        spawn: injection seam for tests — defaults to
            :func:`asyncio.create_subprocess_exec`.
        terminate_timeout: seconds to wait after SIGTERM before SIGKILL in
            :meth:`aclose`.
    """

    def __init__(
        self,
        config: PiConfig,
        *,
        on_stderr: Callable[[str], None] | None = None,
        spawn: Callable[..., Any] = asyncio.create_subprocess_exec,
        terminate_timeout: float = 5.0,
    ) -> None:
        self._config = config
        self._on_stderr = on_stderr
        self._spawn = spawn
        self._terminate_timeout = terminate_timeout

        self._proc: Any | None = None
        # Bytes read from stdout that do not yet form a complete record.
        self._buf = bytearray()
        # Set once a stdout read returned b"" (EOF).
        self._eof = False
        self._stderr_task: asyncio.Task[None] | None = None
        # Serialize stdin writes so concurrent sends (e.g. a steer during a prompt)
        # can't interleave bytes of two JSONL records. Mirrors omp's own client lock.
        self._write_lock = asyncio.Lock()

    # ---------------------------------------------------------------- props
    @property
    def pid(self) -> int | None:
        """PID of the spawned process, or ``None`` before :meth:`start`."""
        return self._proc.pid if self._proc is not None else None

    @property
    def returncode(self) -> int | None:
        """Exit code of the process, or ``None`` if not started / still running."""
        return self._proc.returncode if self._proc is not None else None

    # -------------------------------------------------------------- lifecycle
    async def start(self) -> None:
        """Resolve the pi binary and spawn the RPC subprocess.

        Raises:
            PiStartError: if already started, if spawning fails, or if the
                child was created without stdin/stdout pipes.
            PiNotFoundError: if ``argv[0]`` is neither found on ``PATH`` nor
                an existing filesystem path.
        """
        if self._proc is not None:
            raise PiStartError("transport already started")

        argv = self._config.to_argv()
        if shutil.which(argv[0]) is None and not os.path.exists(argv[0]):
            raise PiNotFoundError(argv[0])

        env = self._config.to_env()
        cwd = str(self._config.workspace) if self._config.workspace is not None else None

        try:
            self._proc = await self._spawn(
                *argv,
                stdin=asyncio.subprocess.PIPE,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
                cwd=cwd,
                env=env,
            )
        except (OSError, ValueError) as exc:  # pragma: no cover - spawn failure
            raise PiStartError(f"failed to spawn pi: {exc}") from exc

        if self._proc.stdin is None or self._proc.stdout is None:
            raise PiStartError("pi subprocess started without stdin/stdout pipes")

        if self._proc.stderr is not None:
            self._stderr_task = asyncio.create_task(self._drain_stderr(self._proc.stderr))

    async def _drain_stderr(self, stream: asyncio.StreamReader) -> None:
        """Consume pi's stderr until EOF, forwarding each line to ``on_stderr``.

        The pipe has to keep being read for as long as the process lives: if we
        stopped, pi would block writing diagnostics once the OS pipe buffer
        filled, which would in turn stall the RPC stream on stdout. So neither a
        raising callback nor an over-long stderr line ends this loop.
        """
        while True:
            try:
                line = await stream.readline()
            except ValueError:
                # StreamReader.readline raises ValueError when a line exceeds
                # the reader's limit; the offending bytes have already been
                # discarded, so skip them and keep draining.
                continue
            except Exception:  # pragma: no cover - broken stream, nothing left to drain
                return
            if not line:
                return  # EOF
            callback = self._on_stderr
            if callback is None:
                continue
            try:
                callback(line.decode("utf-8", errors="replace").rstrip("\r\n"))
            except Exception:
                # Best-effort diagnostics: a buggy callback must not stop the
                # drain (see docstring for why that would deadlock pi).
                pass

    # ------------------------------------------------------------------- io
    async def send(self, obj: dict[str, Any]) -> None:
        """Serialize ``obj`` to one JSONL line and write it to pi's stdin.

        Raises:
            TransportClosedError: if not started, if the process has exited,
                or if the stdin pipe is broken.
        """
        proc = self._proc
        if proc is None or proc.stdin is None:
            raise TransportClosedError("transport not started")
        if proc.returncode is not None:
            raise TransportClosedError(f"pi process has exited (code {proc.returncode})")

        line = json.dumps(obj, ensure_ascii=False, separators=(",", ":")) + "\n"
        async with self._write_lock:
            try:
                proc.stdin.write(line.encode("utf-8"))
                await proc.stdin.drain()
            except (ConnectionResetError, BrokenPipeError) as exc:
                raise TransportClosedError(f"pi stdin closed: {exc}") from exc

    async def receive(self) -> dict[str, Any] | None:
        """Return the next parsed inbound JSON object, or ``None`` at EOF.

        Only one coroutine may call this at a time (single-reader contract).

        Raises:
            TransportClosedError: if called before :meth:`start`.
            PiProtocolError: on invalid JSON, a non-object record, or a record
                longer than ``_MAX_LINE_BYTES`` without a newline.
        """
        proc = self._proc
        if proc is None or proc.stdout is None:
            raise TransportClosedError("transport not started")

        while True:
            line = self._take_buffered_line()
            if line is not None:
                if not line.strip():
                    continue  # tolerate blank keepalive lines
                return self._parse(line)

            if self._eof:
                # Flush any trailing partial line at EOF.
                if self._buf:
                    rest = bytes(self._buf)
                    self._buf.clear()
                    text = rest.decode("utf-8", errors="replace").rstrip("\r")
                    if text.strip():
                        return self._parse(text)
                return None

            chunk = await proc.stdout.read(_READ_CHUNK)
            if not chunk:
                self._eof = True
                continue
            self._buf.extend(chunk)
            if len(self._buf) > _MAX_LINE_BYTES and b"\n" not in self._buf:
                raise PiProtocolError(
                    f"line exceeded {_MAX_LINE_BYTES} bytes without a newline"
                )

    def _take_buffered_line(self) -> str | None:
        """Pop one ``\\n``-terminated record from the buffer (``\\r`` stripped)."""
        idx = self._buf.find(b"\n")
        if idx == -1:
            return None
        raw = bytes(self._buf[:idx])
        del self._buf[: idx + 1]
        return raw.decode("utf-8", errors="replace").rstrip("\r")

    @staticmethod
    def _parse(line: str) -> dict[str, Any]:
        """Decode one record, requiring it to be a JSON object."""
        try:
            obj = json.loads(line)
        except json.JSONDecodeError as exc:
            raise PiProtocolError(f"invalid JSON from pi: {exc}", raw=line) from exc
        if not isinstance(obj, dict):
            raise PiProtocolError(
                f"expected JSON object from pi, got {type(obj).__name__}", raw=line
            )
        return obj

    def __aiter__(self) -> AsyncIterator[dict[str, Any]]:
        return self._iter()

    async def _iter(self) -> AsyncIterator[dict[str, Any]]:
        """Yield :meth:`receive` results until it reports EOF."""
        while True:
            obj = await self.receive()
            if obj is None:
                return
            yield obj

    # --------------------------------------------------------------- shutdown
    async def aclose(self) -> None:
        """Close stdin, then terminate and reap the process (SIGKILL fallback).

        Safe to call more than once and before :meth:`start` (no-op then).
        """
        proc = self._proc
        if proc is None:
            return

        if proc.stdin is not None:
            try:
                proc.stdin.close()
            except Exception:  # pragma: no cover
                pass

        if proc.returncode is None:
            try:
                proc.terminate()
            except ProcessLookupError:  # pragma: no cover - already gone
                pass
            try:
                await asyncio.wait_for(proc.wait(), timeout=self._terminate_timeout)
            except asyncio.TimeoutError:
                # asyncio.TimeoutError (not the builtin) is what wait_for raises
                # on Python < 3.11; from 3.11 on the two names are the same class.
                try:
                    proc.kill()
                except ProcessLookupError:  # pragma: no cover
                    pass
                await proc.wait()

        if self._stderr_task is not None:
            self._stderr_task.cancel()
            try:
                await self._stderr_task
            except (asyncio.CancelledError, Exception):  # pragma: no cover
                pass
            self._stderr_task = None
|
pidriver/client.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
1
|
+
"""Starting sessions.
|
|
2
|
+
|
|
3
|
+
:class:`PiClient` is the entry point: configure it once, then call
|
|
4
|
+
:meth:`PiClient.start` per task to launch an isolated ``pi --mode rpc`` process
|
|
5
|
+
and get back a started :class:`~pidriver.session.PiSession`.
|
|
6
|
+
|
|
7
|
+
``PiClient.start`` has the signature a :class:`~pidriver.manager.PiSessionManager`
|
|
8
|
+
``factory`` expects, so you can wire the two together directly::
|
|
9
|
+
|
|
10
|
+
client = PiClient(PiConfig(...))
|
|
11
|
+
manager = PiSessionManager(factory=client.start)
|
|
12
|
+
session = await manager.create("Add a healthcheck endpoint", cwd="/srv/proj")
|
|
13
|
+
"""
|
|
14
|
+
|
|
15
|
+
from __future__ import annotations
|
|
16
|
+
|
|
17
|
+
from collections.abc import Callable
|
|
18
|
+
from dataclasses import replace
|
|
19
|
+
from pathlib import Path
|
|
20
|
+
from typing import TYPE_CHECKING
|
|
21
|
+
|
|
22
|
+
from ._transport import PiTransport, SubprocessTransport
|
|
23
|
+
from .config import PiConfig
|
|
24
|
+
from .interaction import InteractionHandler
|
|
25
|
+
from .session import PiSession
|
|
26
|
+
|
|
27
|
+
if TYPE_CHECKING:
    # NOTE(review): this guard is empty — annotation-only imports could live
    # here; otherwise it (and the TYPE_CHECKING import) can be removed.
    pass

# Names exported by ``from pidriver.client import *``.
__all__ = ["PiClient", "TransportFactory"]

# Builds a (not-yet-started) transport for a resolved config. Override in tests.
# PiClient.start calls it once per session and then awaits ``start()`` on the
# returned transport.
TransportFactory = Callable[[PiConfig], PiTransport]
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class PiClient:
    """Launches isolated pi/omp sessions from a base :class:`PiConfig`.

    Args:
        config: base configuration for every session this client starts;
            individual :meth:`start` calls may override it.
        transport_factory: builds a not-yet-started transport from a resolved
            config. Defaults to :class:`SubprocessTransport`.
        interaction_handler: handler given to sessions whose :meth:`start`
            call does not supply one.
    """

    def __init__(
        self,
        config: PiConfig,
        *,
        transport_factory: TransportFactory | None = None,
        interaction_handler: InteractionHandler | None = None,
    ) -> None:
        self._config = config
        self._make_transport: TransportFactory = (
            transport_factory if transport_factory is not None else SubprocessTransport
        )
        self._default_handler = interaction_handler

    async def start(
        self,
        task: str,
        *,
        cwd: str | Path | None = None,
        config: PiConfig | None = None,
        interaction_handler: InteractionHandler | None = None,
        resume: str | None = None,
        session_id: str | None = None,
        send_initial: bool = True,
    ) -> PiSession:
        """Launch a session and (by default) send ``task`` as the first prompt.

        :param task: the initial instruction for the agent.
        :param cwd: working directory the agent operates in (the target project);
            overrides ``config.workspace``.
        :param config: per-session override of the client's base config.
        :param interaction_handler: auto-answer policy; defaults to the client's
            handler, else whatever :class:`PiSession` uses when given ``None``
            (documented as :class:`~pidriver.interaction.AskHost`).
        :param resume: a prior pi session path/id to ``--continue`` from.
        :param session_id: stable local id for the manager registry key
            (auto-generated when omitted).
        :param send_initial: set ``False`` to open the session without prompting.
        :raises: whatever the transport or session raise while starting up. In
            that case the transport is closed before the error propagates, so
            no pi process is left running.
        """
        resolved = config if config is not None else self._config
        if resolved is not None:
            if cwd is not None:
                resolved = replace(resolved, workspace=Path(cwd))
            if resume:
                resolved = replace(resolved, session_path=resume, continue_session=True)

        handler = (
            interaction_handler
            if interaction_handler is not None
            else self._default_handler
        )

        transport = self._make_transport(resolved)
        try:
            await transport.start()
            session = PiSession(
                transport,
                session_id=session_id,
                config=resolved,
                interaction_handler=handler,
            )
            if task and send_initial:
                await session.prompt(task)
        except BaseException:
            # Don't orphan the pi subprocess when startup fails (partial spawn,
            # session construction error, failed or cancelled first prompt).
            # NOTE(review): if PiSession owns background tasks, a session-level
            # close may be preferable here — confirm against session.py.
            try:
                await transport.aclose()
            except Exception:
                pass  # surface the original startup error, not the cleanup one
            raise
        return session
|