codex-api-proxy 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- codex_api_proxy/__init__.py +3 -0
- codex_api_proxy/app_server_runner.py +554 -0
- codex_api_proxy/cli.py +570 -0
- codex_api_proxy/codex_runner.py +278 -0
- codex_api_proxy/config.py +83 -0
- codex_api_proxy/main.py +561 -0
- codex_api_proxy/prompt.py +31 -0
- codex_api_proxy/schemas.py +48 -0
- codex_api_proxy-0.1.0.dist-info/METADATA +347 -0
- codex_api_proxy-0.1.0.dist-info/RECORD +13 -0
- codex_api_proxy-0.1.0.dist-info/WHEEL +5 -0
- codex_api_proxy-0.1.0.dist-info/entry_points.txt +2 -0
- codex_api_proxy-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,554 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import json
|
|
5
|
+
import os
|
|
6
|
+
import signal
|
|
7
|
+
import time
|
|
8
|
+
import tomllib
|
|
9
|
+
from collections.abc import AsyncIterator, Callable
|
|
10
|
+
from pathlib import Path
|
|
11
|
+
from typing import Any, Protocol
|
|
12
|
+
|
|
13
|
+
from . import __version__
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class AppServerRunError(RuntimeError):
    """Signal that the app-server engine could not finish a turn.

    Used for JSON-RPC error responses, missing thread/turn ids, a dead or
    unreachable ``codex app-server`` process, and a worker CODEX_HOME that is
    not isolated from the user's own Codex home.
    """
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class AppServerBusy(RuntimeError):
    """Signal that no app-server worker could be obtained within the queue limits.

    Raised either because too many callers are already waiting for a worker or
    because the wait for a free worker timed out.
    """
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
# Skill names that are always disabled for app-server workers, in addition to
# whatever skills are discovered on disk by discover_skill_names().
DEFAULT_DISABLED_SKILL_NAMES = set(
    "imagegen openai-docs plugin-creator skill-creator skill-installer".split()
)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class AppServerClient(Protocol):
|
|
34
|
+
async def start(self) -> None: ...
|
|
35
|
+
|
|
36
|
+
async def close(self) -> None: ...
|
|
37
|
+
|
|
38
|
+
async def request(self, method: str, params: dict[str, Any] | None = None) -> dict[str, Any]: ...
|
|
39
|
+
|
|
40
|
+
async def read_message(self) -> dict[str, Any]: ...
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def _parse_config_value(raw: str) -> Any:
|
|
44
|
+
try:
|
|
45
|
+
return tomllib.loads(f"value = {raw}")["value"]
|
|
46
|
+
except tomllib.TOMLDecodeError:
|
|
47
|
+
return raw
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def codex_configs_to_object(codex_configs: list[str]) -> dict[str, Any]:
    """Fold ``key.path=value`` override strings into a nested dictionary.

    Entries without ``=`` or with an empty key path are skipped. Each dotted
    segment becomes a nested table. A non-table value already sitting on the
    path is replaced by a fresh table. Later entries overwrite earlier ones.
    Values are decoded with ``_parse_config_value``.
    """
    result: dict[str, Any] = {}
    for entry in codex_configs:
        raw_key, has_equals, raw_value = entry.partition("=")
        segments = [segment for segment in raw_key.strip().split(".") if segment]
        if not has_equals or not segments:
            continue
        *parents, leaf = segments
        node = result
        for segment in parents:
            child = node.get(segment)
            if not isinstance(child, dict):
                child = {}
                node[segment] = child
            node = child
        node[leaf] = _parse_config_value(raw_value.strip())
    return result
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def default_source_codex_home() -> Path:
    """Return the user's own Codex home directory.

    Uses ``CODEX_HOME`` when it is set to a non-empty value (with ``~``
    expanded) and falls back to ``~/.codex`` otherwise.

    An empty ``CODEX_HOME`` is treated as unset. ``Path("")`` would otherwise
    resolve to the current working directory, and auth, skill discovery and
    the isolation check would then silently operate on the cwd.
    """
    configured = os.environ.get("CODEX_HOME")
    if not configured:
        return Path.home() / ".codex"
    return Path(configured).expanduser()
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def default_app_server_codex_home() -> Path:
    """Return the private CODEX_HOME given to proxy-managed app-server workers by default."""
    proxy_root = Path.home() / ".codex-api-proxy"
    return proxy_root / "codex-home"
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def default_skill_roots(*, source_codex_home: Path | None = None) -> list[Path]:
|
|
79
|
+
source_home = source_codex_home or default_source_codex_home()
|
|
80
|
+
return [source_home / "skills", Path.home() / ".agents" / "skills"]
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def _parse_skill_name(skill_file: Path) -> str | None:
|
|
84
|
+
try:
|
|
85
|
+
with skill_file.open(encoding="utf-8") as handle:
|
|
86
|
+
for index, line in enumerate(handle):
|
|
87
|
+
if index > 80:
|
|
88
|
+
return None
|
|
89
|
+
key, separator, value = line.partition(":")
|
|
90
|
+
if separator and key.strip() == "name":
|
|
91
|
+
return value.strip().strip("\"'")
|
|
92
|
+
except OSError:
|
|
93
|
+
return None
|
|
94
|
+
return None
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def discover_skill_names(skill_roots: list[Path]) -> set[str]:
    """Return the default-disabled skill names plus every skill found under *skill_roots*.

    Roots are ``~``-expanded and skipped when they do not exist. Inside each
    root, both ``*/SKILL.md`` and ``.system/*/SKILL.md`` are considered. A
    skill is named by its ``name:`` line, or by its directory name when none
    can be parsed.
    """
    discovered: set[str] = set()
    discovered.update(DEFAULT_DISABLED_SKILL_NAMES)
    for skill_root in (root.expanduser() for root in skill_roots):
        if not skill_root.exists():
            continue
        candidates = list(skill_root.glob("*/SKILL.md")) + list(skill_root.glob(".system/*/SKILL.md"))
        for candidate in candidates:
            parsed = _parse_skill_name(candidate)
            discovered.add(parsed if parsed else candidate.parent.name)
    return discovered
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
def build_disabled_skills_config(*, skill_roots: list[Path] | None = None) -> str:
    """Return a ``-c`` override that disables every known skill.

    The result looks like ``skills.config=[{name="a",enabled=false},...]``,
    with names sorted. It covers the names returned by ``discover_skill_names``
    for *skill_roots*. When *skill_roots* is ``None``, ``default_skill_roots()``
    is used. An explicit empty list scans no directories.

    Names are written as TOML basic strings via ``json.dumps``:
    - ``ensure_ascii`` is off because JSON escapes non-BMP characters (emoji,
      etc.) as UTF-16 surrogate pairs. TOML rejects those escapes, while raw
      UTF-8 is valid TOML.
    - DEL (U+007F) is the one character JSON leaves unescaped that TOML
      forbids raw, so it is escaped explicitly.
    """
    roots = default_skill_roots() if skill_roots is None else skill_roots
    names = discover_skill_names(roots)
    entries = ",".join(
        "{name=" + json.dumps(name, ensure_ascii=False).replace("\x7f", "\\u007f") + ",enabled=false}"
        for name in sorted(names)
    )
    return f"skills.config=[{entries}]"
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def build_app_server_command(
    *,
    codex_bin: str,
    codex_configs: list[str],
    disabled_skills_config: str,
) -> list[str]:
    """Assemble the argv for ``codex app-server --stdio``.

    The ``apps``, ``plugins`` and ``skill_mcp_dependency_install`` features are
    turned off with ``--disable``. Next come the caller's ``-c`` overrides. The
    command always ends with ``-c mcp_servers={}`` and ``-c <disabled skills>``.
    NOTE(review): placing these last presumably lets them win over earlier
    overrides; confirm the codex CLI's precedence for repeated ``-c`` flags.
    """
    disabled_features = ("apps", "plugins", "skill_mcp_dependency_install")
    argv = [codex_bin, "app-server", "--stdio"]
    for feature in disabled_features:
        argv += ["--disable", feature]
    for override in [*codex_configs, "mcp_servers={}", disabled_skills_config]:
        argv += ["-c", override]
    return argv
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
def prepare_isolated_codex_home(codex_home: Path, *, source_codex_home: Path | None = None) -> Path:
|
|
138
|
+
source_home = (source_codex_home or default_source_codex_home()).expanduser()
|
|
139
|
+
target_home = codex_home.expanduser()
|
|
140
|
+
if target_home.resolve(strict=False) == source_home.resolve(strict=False):
|
|
141
|
+
raise AppServerRunError("app-server CODEX_HOME must be isolated from the current Codex user home")
|
|
142
|
+
|
|
143
|
+
target_home.mkdir(parents=True, exist_ok=True)
|
|
144
|
+
source_auth = source_home / "auth.json"
|
|
145
|
+
target_auth = target_home / "auth.json"
|
|
146
|
+
if not source_auth.exists():
|
|
147
|
+
return target_home
|
|
148
|
+
|
|
149
|
+
if target_auth.is_symlink():
|
|
150
|
+
try:
|
|
151
|
+
if target_auth.resolve() == source_auth.resolve():
|
|
152
|
+
return target_home
|
|
153
|
+
except FileNotFoundError:
|
|
154
|
+
pass
|
|
155
|
+
target_auth.unlink()
|
|
156
|
+
|
|
157
|
+
if not target_auth.exists():
|
|
158
|
+
target_auth.symlink_to(source_auth)
|
|
159
|
+
return target_home
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
def build_app_server_env(*, proxy: str | None, codex_home: Path) -> dict[str, str]:
|
|
163
|
+
env = os.environ.copy()
|
|
164
|
+
env["CODEX_HOME"] = str(codex_home)
|
|
165
|
+
if proxy:
|
|
166
|
+
env["http_proxy"] = proxy
|
|
167
|
+
env["https_proxy"] = proxy
|
|
168
|
+
env["HTTP_PROXY"] = proxy
|
|
169
|
+
env["HTTPS_PROXY"] = proxy
|
|
170
|
+
return env
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
class StdioJsonRpcClient:
    """JSON-RPC 2.0 client for one ``codex app-server --stdio`` subprocess.

    Messages are newline-delimited JSON. Requests are written to the child's
    stdin. A background task reads stdout, resolves pending requests by ``id``
    and queues notifications for :meth:`read_message`. Requests initiated by
    the server are answered with JSON-RPC error ``-32601``. The last 20
    non-blank stderr lines are kept to enrich "process exited" errors.
    """

    # asyncio's default StreamReader limit is 64 KiB per line. A single
    # app-server message (e.g. a completed item carrying a long reply) can be
    # larger, which would make readline() raise and stop the stdout reader.
    _STREAM_LIMIT = 16 * 1024 * 1024

    def __init__(
        self,
        *,
        codex_bin: str,
        codex_configs: list[str],
        proxy: str | None,
        codex_home: Path | None,
        timeout_seconds: float,
    ) -> None:
        """Store launch settings; the subprocess is spawned lazily by :meth:`start`.

        Args:
            codex_bin: Name or path of the ``codex`` executable.
            codex_configs: Extra ``-c`` overrides passed to the app-server.
            proxy: Optional HTTP(S) proxy URL exported to the child.
            codex_home: Worker CODEX_HOME. ``None`` means
                ``default_app_server_codex_home()``.
            timeout_seconds: Limit for each request and each notification wait.
        """
        self.codex_bin = codex_bin
        self.codex_configs = codex_configs
        self.proxy = proxy
        self.codex_home = codex_home
        self.timeout_seconds = timeout_seconds
        self.process: asyncio.subprocess.Process | None = None
        self._next_id = 1
        self._pending: dict[int, asyncio.Future[dict[str, Any]]] = {}
        self._notifications: asyncio.Queue[dict[str, Any]] = asyncio.Queue()
        self._reader_task: asyncio.Task[None] | None = None
        self._stderr_task: asyncio.Task[None] | None = None
        self._stderr_tail: list[str] = []
        self._write_lock = asyncio.Lock()

    async def start(self) -> None:
        """Spawn the app-server, if needed, and run the ``initialize`` handshake.

        This is a no-op while a process is attached. If the process starts but
        the handshake fails, the process is closed before the error propagates.
        The client is therefore never left "started" but uninitialised, and a
        later call can retry from scratch.
        """
        if self.process:
            return
        source_codex_home = default_source_codex_home()
        codex_home = prepare_isolated_codex_home(
            self.codex_home or default_app_server_codex_home(),
            source_codex_home=source_codex_home,
        )
        command = build_app_server_command(
            codex_bin=self.codex_bin,
            codex_configs=self.codex_configs,
            disabled_skills_config=build_disabled_skills_config(
                skill_roots=default_skill_roots(source_codex_home=source_codex_home)
            ),
        )
        env = build_app_server_env(proxy=self.proxy, codex_home=codex_home)
        self.process = await asyncio.create_subprocess_exec(
            *command,
            stdin=asyncio.subprocess.PIPE,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
            env=env,
            # setsid(): the child gets its own session/process group, so
            # terminal signals aimed at the proxy's group do not reach it.
            start_new_session=True,
            limit=self._STREAM_LIMIT,
        )
        self._reader_task = asyncio.create_task(self._read_stdout())
        self._stderr_task = asyncio.create_task(self._read_stderr())
        try:
            await self.request(
                "initialize",
                {
                    "clientInfo": {"name": "codex-api-proxy", "version": __version__},
                    "capabilities": {"experimentalApi": True},
                },
            )
        except BaseException:
            await self.close()
            raise

    async def close(self) -> None:
        """Stop the subprocess and the reader tasks.

        The process gets SIGTERM first and SIGKILL if it has not exited within
        2 s. Afterwards the client can be started again.
        """
        process = self.process
        if not process:
            return
        if process.returncode is None:
            try:
                process.send_signal(signal.SIGTERM)
            except ProcessLookupError:
                pass
            try:
                await asyncio.wait_for(process.wait(), timeout=2)
            except TimeoutError:
                try:
                    process.kill()
                except ProcessLookupError:
                    # The process exited between the timeout and the kill.
                    pass
                await process.wait()
        for task in (self._reader_task, self._stderr_task):
            if task:
                task.cancel()
        self.process = None

    async def request(self, method: str, params: dict[str, Any] | None = None) -> dict[str, Any]:
        """Send a JSON-RPC request and return its result object.

        Raises:
            AppServerRunError: on an error response, when no process is
                running, or when the app-server's stdout closes.
            TimeoutError: if no response arrives within ``timeout_seconds``.
        """
        await self.start()
        request_id = self._next_id
        self._next_id += 1
        future: asyncio.Future[dict[str, Any]] = asyncio.get_running_loop().create_future()
        # Register before sending so a fast response cannot be missed. The send
        # itself is inside the try so a failed write does not leak the entry.
        self._pending[request_id] = future
        try:
            await self._send({"jsonrpc": "2.0", "id": request_id, "method": method, "params": params})
            return await asyncio.wait_for(future, timeout=self.timeout_seconds)
        finally:
            self._pending.pop(request_id, None)

    async def read_message(self) -> dict[str, Any]:
        """Return the next queued server notification.

        Raises TimeoutError if none arrives within ``timeout_seconds``.
        """
        await self.start()
        return await asyncio.wait_for(self._notifications.get(), timeout=self.timeout_seconds)

    async def _send(self, payload: dict[str, Any]) -> None:
        """Write *payload* as one compact JSON line to the child's stdin.

        Writes are serialised by ``_write_lock``.
        """
        process = self.process
        if not process or not process.stdin:
            raise AppServerRunError("codex app-server process is not running")
        data = json.dumps(payload, separators=(",", ":")).encode("utf-8") + b"\n"
        async with self._write_lock:
            process.stdin.write(data)
            await process.stdin.drain()

    async def _read_stdout(self) -> None:
        """Feed stdout lines to :meth:`_dispatch_message` until EOF.

        Lines that are not valid UTF-8 JSON are skipped. Once stdout closes, no
        further responses can arrive, so every pending request is failed with
        the exit detail instead of being left to time out.
        """
        process = self.process
        if not process or not process.stdout:
            return
        try:
            while line := await process.stdout.readline():
                try:
                    message = json.loads(line.decode("utf-8"))
                except ValueError:
                    # Covers both JSONDecodeError and UnicodeDecodeError.
                    continue
                await self._dispatch_message(message)
            # EOF. Give the process a moment to exit so the stderr tail used in
            # the error detail is complete, but fail pending requests regardless.
            # Checking returncode alone races the child watcher and can leave
            # requests hanging.
            try:
                await asyncio.wait_for(process.wait(), timeout=1)
            except TimeoutError:
                pass
            self._fail_pending(AppServerRunError(self._process_error_detail()))
        except asyncio.CancelledError:
            # close() cancels this task after the process has been reaped.
            if process.returncode is not None:
                self._fail_pending(AppServerRunError(self._process_error_detail()))
            raise
        except Exception as exc:
            self._fail_pending(AppServerRunError(f"app-server stdout reader failed: {exc}"))

    async def _read_stderr(self) -> None:
        """Drain stderr, keeping the last 20 non-blank lines for error messages."""
        process = self.process
        if not process or not process.stderr:
            return
        while True:
            try:
                line = await process.stderr.readline()
            except ValueError:
                # Over-long line: StreamReader has discarded it. Keep draining so
                # the child never blocks on a full stderr pipe.
                continue
            if not line:
                return
            text = line.decode("utf-8", errors="replace").strip()
            if text:
                self._stderr_tail.append(text)
                del self._stderr_tail[:-20]

    async def _dispatch_message(self, message: Any) -> None:
        """Route one decoded stdout message.

        - Non-dict messages are dropped.
        - A response (``id`` without ``method``) resolves the matching pending
          future. An ``error`` member sets AppServerRunError. Otherwise the
          future gets the ``result``: falsy results become ``{}`` and non-dict
          results are wrapped as ``{"value": result}``. Unknown or finished
          ids are ignored.
        - A server request (``id`` and ``method``) is rejected with ``-32601``.
        - Anything else is queued as a notification.
        """
        if not isinstance(message, dict):
            return
        request_id = message.get("id")
        method = message.get("method")
        if request_id is not None and method is None:
            future = self._pending.get(request_id)
            if future and not future.done():
                if "error" in message:
                    future.set_exception(AppServerRunError(str(message["error"])))
                else:
                    result = message.get("result") or {}
                    future.set_result(result if isinstance(result, dict) else {"value": result})
            return
        if request_id is not None and method is not None:
            await self._send(
                {
                    "jsonrpc": "2.0",
                    "id": request_id,
                    "error": {"code": -32601, "message": f"unsupported server request: {method}"},
                }
            )
            return
        await self._notifications.put(message)

    def _fail_pending(self, exc: BaseException) -> None:
        """Set *exc* on every pending request future that is not yet done."""
        for future in self._pending.values():
            if not future.done():
                future.set_exception(exc)

    def _process_error_detail(self) -> str:
        """Describe the app-server exit, including up to 5 recent stderr lines."""
        if self._stderr_tail:
            return "codex app-server exited: " + " ".join(self._stderr_tail[-5:])
        return "codex app-server exited"
|
|
343
|
+
|
|
344
|
+
|
|
345
|
+
# Latency reporting hook: called with a phase name (e.g.
# "app_server_thread_start") and that phase's duration in milliseconds, as
# computed by _elapsed_ms and passed along by _record_latency.
LatencyCallback = Callable[[str, float], None]
|
|
346
|
+
|
|
347
|
+
|
|
348
|
+
def _elapsed_ms(started_at: float) -> float:
|
|
349
|
+
return round((time.perf_counter() - started_at) * 1000, 2)
|
|
350
|
+
|
|
351
|
+
|
|
352
|
+
def _record_latency(callback: LatencyCallback | None, name: str, started_at: float) -> None:
|
|
353
|
+
if callback:
|
|
354
|
+
callback(name, _elapsed_ms(started_at))
|
|
355
|
+
|
|
356
|
+
|
|
357
|
+
async def stream_app_server_turn(
    *,
    client: AppServerClient,
    cwd: Path,
    prompt: str,
    model: str | None,
    codex_configs: list[str],
    ephemeral: bool,
    timeout_seconds: float,
    latency_callback: LatencyCallback | None = None,
) -> AsyncIterator[str]:
    """Run *prompt* as one turn on a fresh app-server thread and yield reply text.

    Thread setup: the thread is started with ``approvalPolicy="never"``, a
    ``read-only`` sandbox, no dynamic tools and no environments. Its config is
    *codex_configs* converted by ``codex_configs_to_object``. The prompt is
    then submitted with ``turn/start``.

    Output:
    - Text from ``item/agentMessage/delta`` notifications for this turn is
      yielded as it arrives. Notifications for other threads are ignored.
    - If no delta was yielded by the time ``turn/completed`` arrives for this
      turn, the text of the last completed ``agentMessage`` item (if any) is
      yielded instead.

    Cleanup: once the thread exists it is archived on the way out, on a best
    effort basis. This also covers ``turn/start`` failing and the consumer
    closing the generator early.

    Phase latencies are reported to *latency_callback*:
    ``app_server_thread_start``, ``app_server_turn_start``,
    ``app_server_first_delta``, ``app_server_turn_complete`` and
    ``app_server_thread_archive``.

    Raises:
        AppServerRunError: if ``thread/start`` or ``turn/start`` returns no id.
        TimeoutError: if a request, or the wait for the next notification,
            exceeds *timeout_seconds*.
    """
    await client.start()
    thread_params: dict[str, Any] = {
        "approvalPolicy": "never",
        "config": codex_configs_to_object(codex_configs),
        "cwd": str(cwd),
        "dynamicTools": [],
        "environments": [],
        "ephemeral": ephemeral,
        "model": model,
        "sandbox": "read-only",
    }
    phase_started_at = time.perf_counter()
    start_response = await asyncio.wait_for(client.request("thread/start", thread_params), timeout=timeout_seconds)
    _record_latency(latency_callback, "app_server_thread_start", phase_started_at)
    thread = start_response.get("thread")
    if not isinstance(thread, dict) or not isinstance(thread.get("id"), str):
        raise AppServerRunError("app-server thread/start returned no thread id")
    thread_id: str = thread["id"]

    # Everything after thread creation is covered by the finally below, so the
    # thread is archived even when turn/start fails or returns no turn id.
    try:
        phase_started_at = time.perf_counter()
        turn_response = await asyncio.wait_for(
            client.request("turn/start", {"input": [{"type": "text", "text": prompt}], "threadId": thread_id}),
            timeout=timeout_seconds,
        )
        _record_latency(latency_callback, "app_server_turn_start", phase_started_at)
        turn = turn_response.get("turn")
        if not isinstance(turn, dict) or not isinstance(turn.get("id"), str):
            raise AppServerRunError("app-server turn/start returned no turn id")
        turn_id = turn["id"]

        emitted = False
        first_delta_recorded = False
        final_text: str | None = None
        event_wait_started_at = time.perf_counter()
        while True:
            message = await asyncio.wait_for(client.read_message(), timeout=timeout_seconds)
            method = message.get("method")
            params = message.get("params")
            if not isinstance(params, dict) or params.get("threadId") != thread_id:
                continue
            if method == "item/agentMessage/delta" and params.get("turnId") == turn_id:
                delta = params.get("delta")
                if isinstance(delta, str) and delta:
                    if not first_delta_recorded:
                        first_delta_recorded = True
                        _record_latency(latency_callback, "app_server_first_delta", event_wait_started_at)
                    emitted = True
                    yield delta
            elif method == "item/completed" and params.get("turnId") == turn_id:
                item = params.get("item")
                if isinstance(item, dict) and item.get("type") == "agentMessage" and isinstance(item.get("text"), str):
                    final_text = item["text"]
            elif method == "turn/completed":
                completed_turn = params.get("turn")
                if isinstance(completed_turn, dict) and completed_turn.get("id") == turn_id:
                    _record_latency(latency_callback, "app_server_turn_complete", event_wait_started_at)
                    if final_text and not emitted:
                        yield final_text
                    return
    finally:
        try:
            phase_started_at = time.perf_counter()
            await client.request("thread/archive", {"threadId": thread_id})
            _record_latency(latency_callback, "app_server_thread_archive", phase_started_at)
        except Exception:
            # Best effort: a failed archive must not mask the turn's own outcome.
            pass
|
|
438
|
+
|
|
439
|
+
|
|
440
|
+
# Zero-argument factory returning a new AppServerClient. AppServerWorkerPool
# calls it once per worker and again for replacements, and calls start() on
# each client it builds.
ClientFactory = Callable[[], AppServerClient]
|
|
441
|
+
|
|
442
|
+
|
|
443
|
+
class AppServerWorkerPool:
    """Fixed-size pool of started app-server clients.

    Each streamed completion checks out one worker for the whole turn. A worker
    whose turn raised is closed and replaced by a freshly started client.

    Limits on waiting callers (both surface as AppServerBusy):
    - ``max_queue_size`` bounds how many callers may wait for a worker at once;
    - ``queue_timeout_seconds`` bounds how long each caller waits.
    """

    def __init__(
        self,
        *,
        client_factory: ClientFactory,
        workers: int,
        max_queue_size: int,
        queue_timeout_seconds: float,
    ) -> None:
        """Configure the pool; no client is created until :meth:`start`.

        Args:
            client_factory: Builds a new client. Called once per worker and
                again for each replacement.
            workers: Number of clients started by :meth:`start`.
            max_queue_size: Maximum number of callers allowed to wait for a
                busy pool at the same time.
            queue_timeout_seconds: How long a caller waits for a worker.
        """
        self.client_factory = client_factory
        self.worker_count = workers
        self.max_queue_size = max_queue_size
        self.queue_timeout_seconds = queue_timeout_seconds
        self._idle: asyncio.Queue[AppServerClient] = asyncio.Queue()
        self._started = False
        self._start_lock = asyncio.Lock()
        self._waiters = 0
        self._waiters_lock = asyncio.Lock()

    async def start(self) -> None:
        """Create and start ``workers`` clients once; later calls are no-ops.

        If any client fails to start, every client created by this attempt is
        closed and the error is re-raised. The pool then stays unstarted, so a
        later call retries from scratch instead of stacking extra workers.
        """
        async with self._start_lock:
            if self._started:
                return
            launched: list[AppServerClient] = []
            try:
                for _ in range(self.worker_count):
                    client = self.client_factory()
                    launched.append(client)
                    await client.start()
            except BaseException:
                for client in launched:
                    try:
                        await client.close()
                    except Exception:
                        # Cleanup is best effort; the start failure is what
                        # the caller needs to see.
                        pass
                raise
            for client in launched:
                await self._idle.put(client)
            self._started = True

    async def close(self) -> None:
        """Close every worker currently idle in the pool.

        Workers that are checked out at the time are not touched.
        """
        while not self._idle.empty():
            client = await self._idle.get()
            await client.close()

    async def stream_completion(
        self,
        *,
        cwd: Path,
        prompt: str,
        model: str | None,
        codex_configs: list[str],
        ephemeral: bool,
        timeout_seconds: float,
        latency_callback: LatencyCallback | None = None,
    ) -> AsyncIterator[str]:
        """Stream one turn's text on a checked-out worker via ``stream_app_server_turn``.

        Afterwards the worker goes back to the idle queue. If the turn raised
        an Exception, the worker is closed instead and a freshly started
        replacement takes its place. If the replacement fails to start, the
        pool keeps running with one worker fewer.

        Raises:
            AppServerBusy: when no worker can be acquired.
        """
        client = await self._acquire()
        healthy = True
        try:
            async for chunk in stream_app_server_turn(
                client=client,
                cwd=cwd,
                prompt=prompt,
                model=model,
                codex_configs=codex_configs,
                ephemeral=ephemeral,
                timeout_seconds=timeout_seconds,
                latency_callback=latency_callback,
            ):
                yield chunk
        except Exception:
            healthy = False
            await client.close()
            raise
        finally:
            if healthy:
                await self._idle.put(client)
            else:
                replacement = self.client_factory()
                try:
                    await replacement.start()
                    await self._idle.put(replacement)
                except Exception:
                    # NOTE(review): capacity drops by one here and nothing
                    # retries creating the worker later.
                    await replacement.close()

    async def _acquire(self) -> AppServerClient:
        """Check out a worker, waiting in a bounded queue if none is idle.

        Raises:
            AppServerBusy: if ``max_queue_size`` callers are already waiting,
                or no worker frees up within ``queue_timeout_seconds``.
        """
        await self.start()
        try:
            # An idle worker is handed out immediately. max_queue_size only
            # limits callers that actually have to wait, so it is not consumed
            # here. With max_queue_size=0, idle workers remain usable.
            return self._idle.get_nowait()
        except asyncio.QueueEmpty:
            pass
        async with self._waiters_lock:
            if self._waiters >= self.max_queue_size:
                raise AppServerBusy("app-server worker queue is full")
            self._waiters += 1
        try:
            return await asyncio.wait_for(self._idle.get(), timeout=self.queue_timeout_seconds)
        except TimeoutError as exc:
            raise AppServerBusy("timed out waiting for an app-server worker") from exc
        finally:
            async with self._waiters_lock:
                self._waiters -= 1
|
|
530
|
+
|
|
531
|
+
|
|
532
|
+
def create_stdio_app_server_pool(
    *,
    codex_bin: str,
    proxy: str | None,
    codex_home: Path | None,
    codex_configs: list[str],
    workers: int,
    max_queue_size: int,
    queue_timeout_seconds: float,
    timeout_seconds: float,
) -> AppServerWorkerPool:
    """Build a worker pool whose workers are ``StdioJsonRpcClient`` subprocess clients.

    Every worker shares the same launch settings. Nothing is spawned here;
    workers are created and started by the pool's ``start()``.
    """

    def make_client():
        return StdioJsonRpcClient(
            codex_bin=codex_bin,
            codex_configs=codex_configs,
            proxy=proxy,
            codex_home=codex_home,
            timeout_seconds=timeout_seconds,
        )

    return AppServerWorkerPool(
        client_factory=make_client,
        workers=workers,
        max_queue_size=max_queue_size,
        queue_timeout_seconds=queue_timeout_seconds,
    )
|