codex-proxy 3.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- codex_proxy/__init__.py +3 -0
- codex_proxy/__main__.py +66 -0
- codex_proxy/circuit_breaker.py +83 -0
- codex_proxy/compaction.py +42 -0
- codex_proxy/config.py +313 -0
- codex_proxy/key_rotation.py +108 -0
- codex_proxy/plugins.py +110 -0
- codex_proxy/plugins_builtin.py +34 -0
- codex_proxy/providers.py +130 -0
- codex_proxy/server.py +647 -0
- codex_proxy/store.py +97 -0
- codex_proxy/translator.py +360 -0
- codex_proxy/tui.py +262 -0
- codex_proxy-3.1.0.dist-info/METADATA +25 -0
- codex_proxy-3.1.0.dist-info/RECORD +18 -0
- codex_proxy-3.1.0.dist-info/WHEEL +4 -0
- codex_proxy-3.1.0.dist-info/entry_points.txt +2 -0
- codex_proxy-3.1.0.dist-info/licenses/LICENSE +21 -0
codex_proxy/server.py
ADDED
|
@@ -0,0 +1,647 @@
|
|
|
1
|
+
"""FastAPI server — WebSocket + HTTP endpoints for Codex CLI."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import asyncio
|
|
6
|
+
import json
|
|
7
|
+
import logging
|
|
8
|
+
import time
|
|
9
|
+
import uuid
|
|
10
|
+
from contextlib import asynccontextmanager
|
|
11
|
+
from dataclasses import dataclass
|
|
12
|
+
from typing import Any
|
|
13
|
+
|
|
14
|
+
import httpx
|
|
15
|
+
import uvicorn
|
|
16
|
+
from fastapi import FastAPI, Header, Request, WebSocket, WebSocketDisconnect
|
|
17
|
+
from fastapi.responses import JSONResponse, StreamingResponse
|
|
18
|
+
|
|
19
|
+
from . import __version__
|
|
20
|
+
from .circuit_breaker import CircuitBreaker
|
|
21
|
+
from .config import ProxyConfig
|
|
22
|
+
from .key_rotation import KeyRotator
|
|
23
|
+
from .plugins import PluginContext, PluginRegistry
|
|
24
|
+
from .providers import ProviderAdapter, get_adapter
|
|
25
|
+
from .store import ResponseStore
|
|
26
|
+
from .translator import (
|
|
27
|
+
accumulate_tool_call,
|
|
28
|
+
build_cc_request,
|
|
29
|
+
build_final_output,
|
|
30
|
+
cc_to_response,
|
|
31
|
+
generate_response_id,
|
|
32
|
+
parse_cc_stream,
|
|
33
|
+
stream_cc_to_response,
|
|
34
|
+
unwrap_envelope,
|
|
35
|
+
)
|
|
36
|
+
|
|
37
|
+
# Logger shared by every handler in this module; run() may attach a file
# handler to the same "codex-proxy" logger when debug logging is configured.
logger = logging.getLogger("codex-proxy")

# Extra attempts _post_with_retry makes after the first POST fails.
MAX_RETRIES = 1
# Seconds _post_with_retry sleeps between attempts.
RETRY_DELAY = 0.5
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
@dataclass
class AppState:
    """Mutable runtime state of the proxy, stored on ``app.state.proxy``.

    ``configure()`` builds one instance; ``reload_config_internal()`` swaps
    individual fields in place, and the request handlers bump the counters.
    """

    # --- collaborators (required) ---
    config: ProxyConfig
    store: ResponseStore
    client: httpx.AsyncClient
    adapter: ProviderAdapter
    # None when the circuit breaker is disabled in the config.
    circuit_breaker: CircuitBreaker | None

    # --- counters and timestamps ---
    start_time: float = 0.0          # time.time() at configure()
    request_count: int = 0           # HTTP requests + WebSocket messages
    success_count: int = 0
    failure_count: int = 0
    last_request_time: float = 0.0   # time.time() of the latest request

    # --- optional features ---
    # Only created when more than one API key is configured.
    key_rotator: KeyRotator | None = None
    # Only created when plugins are enabled and at least one is listed.
    plugin_registry: PluginRegistry | None = None
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
@asynccontextmanager
async def lifespan(app):
    """Run plugin startup/shutdown hooks around the application's lifetime.

    Reads the proxy state from ``app.state.proxy``, so ``configure()`` must
    have been called before the server starts.

    Shutdown is wrapped in ``try/finally`` so that the shared HTTP client is
    closed even if the application exits with an error or a plugin's
    ``on_shutdown`` hook raises.
    """
    state: AppState = app.state.proxy
    if state.plugin_registry:
        await state.plugin_registry.on_startup(state.config)
    try:
        yield
    finally:
        try:
            # Read the registry again here: the /reload endpoint may have
            # replaced it on the same state object since startup.
            if state.plugin_registry:
                await state.plugin_registry.on_shutdown()
        finally:
            # Always release the connection pool, whatever the hooks did.
            if state.client:
                await state.client.aclose()
            logger.info("codex-proxy shut down")


# The ASGI application; endpoint handlers in this module register on it.
app = FastAPI(title="codex-proxy", lifespan=lifespan)
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def configure(config: ProxyConfig) -> None:
    """Create the proxy's runtime state from *config* and attach it to ``app``.

    Builds the response store, the shared HTTP client, the provider adapter,
    the optional circuit breaker, the optional key rotator (only when more
    than one API key is configured) and the optional plugin registry, then
    stores them as an ``AppState`` on ``app.state.proxy``.
    """
    response_store = ResponseStore(
        ttl_seconds=config.store.ttl_seconds,
        max_entries=config.store.max_entries,
    )
    # Long read timeout: upstream completions can take minutes to finish.
    http_timeout = httpx.Timeout(connect=10, read=180, write=10, pool=30)
    http_client = httpx.AsyncClient(timeout=http_timeout)
    provider_adapter = get_adapter(config.provider.name)

    breaker_cfg = config.circuit_breaker
    breaker = None
    if breaker_cfg.enabled:
        breaker = CircuitBreaker(
            failure_threshold=breaker_cfg.failure_threshold,
            recovery_timeout=breaker_cfg.recovery_timeout,
        )

    api_keys = config.provider.effective_api_keys()
    # The rotator reuses the circuit breaker's thresholds for per-key health.
    rotator = (
        KeyRotator(
            keys=api_keys,
            failure_threshold=breaker_cfg.failure_threshold,
            recovery_timeout=breaker_cfg.recovery_timeout,
        )
        if len(api_keys) > 1
        else None
    )

    app.state.proxy = AppState(
        config=config,
        store=response_store,
        client=http_client,
        adapter=provider_adapter,
        circuit_breaker=breaker,
        start_time=time.time(),
        key_rotator=rotator,
        plugin_registry=_build_plugin_registry(config),
    )
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
def _state() -> AppState:
    """Return the ``AppState`` stored on ``app.state.proxy``."""
    proxy_state = app.state.proxy
    # app.state is an untyped attribute bag; narrow the type for callers.
    assert isinstance(proxy_state, AppState)
    return proxy_state
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
def _build_plugin_registry(config: ProxyConfig) -> PluginRegistry | None:
|
|
112
|
+
if config.plugins.enabled and config.plugins.plugins:
|
|
113
|
+
registry = PluginRegistry()
|
|
114
|
+
registry.load(config.plugins.plugins)
|
|
115
|
+
return registry
|
|
116
|
+
return None
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
def _mask_key(key: str) -> str:
|
|
120
|
+
if len(key) <= 7:
|
|
121
|
+
return "***"
|
|
122
|
+
return f"{key[:3]}...{key[-4:]}"
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
def _build_plugin_ctx(request_id: str, method: str, model: str,
                      api_key: str, stream: bool) -> PluginContext:
    """Assemble the ``PluginContext`` passed to plugin hooks for one request.

    The provider name comes from the current config, and the API key is only
    ever exposed to plugins in masked form.
    """
    provider_name = _state().config.provider.name
    masked = _mask_key(api_key)
    return PluginContext(
        request_id=request_id,
        method=method,
        model=model,
        provider=provider_name,
        api_key_masked=masked,
        stream=stream,
    )
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
def _api_key(auth_header: str) -> str:
|
|
136
|
+
if auth_header:
|
|
137
|
+
lower = auth_header.lower()
|
|
138
|
+
if lower.startswith("bearer "):
|
|
139
|
+
return auth_header[7:].strip()
|
|
140
|
+
return auth_header.strip()
|
|
141
|
+
state = _state()
|
|
142
|
+
if state.key_rotator:
|
|
143
|
+
return state.key_rotator.next_key()
|
|
144
|
+
return state.config.provider.effective_api_key()
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
def _cc_headers(api_key: str) -> dict:
    """Build the headers for an upstream chat-completions request.

    Starts from the bearer-auth and JSON content-type headers, overlays the
    provider's configured ``extra_headers``, and finally lets the provider
    adapter adjust the result.
    """
    state = _state()
    outgoing: dict = {}
    outgoing["Authorization"] = f"Bearer {api_key}"
    outgoing["Content-Type"] = "application/json"
    # Configured extras win over the defaults above.
    outgoing.update(state.config.provider.extra_headers)
    return state.adapter.adjust_headers(outgoing)
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
def _backend_url() -> str:
    """Return the provider base URL without a trailing slash.

    Callers append paths such as ``/chat/completions``; stripping trailing
    slashes here keeps a configured ``https://host/v1/`` from turning into
    ``https://host/v1//chat/completions``.
    """
    return _state().config.provider.base_url.rstrip("/")
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
async def _post_with_retry(url: str, json_body: dict, headers: dict) -> httpx.Response:
    """POST *json_body* to *url*, retrying on 5xx responses or transport errors.

    Makes up to ``MAX_RETRIES + 1`` attempts, sleeping ``RETRY_DELAY`` seconds
    between them. The final attempt's response is returned whatever its
    status; a transport error on the final attempt is re-raised.
    """
    http = _state().client
    attempt = 0
    while attempt <= MAX_RETRIES:
        is_last = attempt == MAX_RETRIES
        try:
            response = await http.post(url, json=json_body, headers=headers)
        except httpx.TransportError as exc:
            if is_last:
                raise
            logger.warning("Transport error (attempt %d): %s, retrying...", attempt + 1, exc)
        else:
            if response.status_code < 500 or is_last:
                return response
            logger.warning("Upstream 5xx (attempt %d), retrying...", attempt + 1)
        await asyncio.sleep(RETRY_DELAY)
        attempt += 1
    # Only reachable if MAX_RETRIES is negative and the loop never runs.
    raise httpx.TransportError("Max retries exceeded")
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
async def _request_with_key_failover(url: str, json_body: dict,
                                     headers: dict, api_key: str) -> httpx.Response:
    """POST with retry, switching API keys on auth or rate-limit responses.

    Without a key rotator this is just ``_post_with_retry``. With one, each
    401/403/429 response is recorded against the current key, the next key is
    taken from the rotator and the request is re-sent with fresh headers, for
    at most as many rounds as there are keys. Any other response is returned
    immediately (recording a success for the key when the status is < 400).
    If every round ends in 401/403/429, the last response is returned.
    """
    state = _state()
    if not state.key_rotator:
        return await _post_with_retry(url, json_body, headers)

    rotate_statuses = (401, 403, 429)
    rounds_left = len(state.key_rotator._keys)
    response: httpx.Response | None = None
    while rounds_left > 0:
        rounds_left -= 1
        response = await _post_with_retry(url, json_body, headers)
        if response.status_code not in rotate_statuses:
            if response.status_code < 400:
                state.key_rotator.record_success(api_key)
            return response
        # The key was rejected or throttled: mark it and move to the next one.
        state.key_rotator.record_failure(api_key, response.status_code)
        api_key = state.key_rotator.next_key()
        headers = _cc_headers(api_key)
        logger.warning("Key failed with %d, rotating", response.status_code)

    assert response is not None
    return response
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
# ── HTTP endpoint ───────────────────────────────────────────────────────
|
|
199
|
+
|
|
200
|
+
@app.post("/responses")
async def responses_http(request: Request,
                         authorization: str = Header(default="")):
    """Handle a Responses-style request over HTTP.

    The JSON body is resolved against the response store, translated into a
    chat-completions request and sent to ``{backend}/chat/completions``.

    * ``stream`` true: returns a ``text/event-stream`` whose chunks come from
      ``stream_cc_to_response``; the completed response, if one was captured,
      is saved in the store once the stream ends.
    * otherwise: sends the request with key failover, converts the reply with
      ``cc_to_response``, stores it and returns it as JSON.

    Error handling:
    * circuit breaker open -> 503
    * request body is not valid JSON -> 400
    * upstream status >= 400, or the upstream cannot be reached -> 502 (or an
      ``error`` SSE event when streaming); both are recorded on the circuit
      breaker and the failure counter.
    """
    state = _state()
    state.request_count += 1
    state.last_request_time = time.time()

    if state.circuit_breaker and not state.circuit_breaker.can_execute():
        return JSONResponse(
            {"error": {"message": "Circuit breaker open", "code": "circuit_open"}},
            status_code=503,
        )

    try:
        body = await request.json()
    except ValueError:
        # json.JSONDecodeError and UnicodeDecodeError are both ValueErrors; a
        # malformed body is a client error, not an internal server error.
        return JSONResponse(
            {"error": {"message": "Request body is not valid JSON",
                       "code": "invalid_json"}},
            status_code=400,
        )
    body = state.store.resolve_input(body)
    model = body.get("model", state.config.provider.default_model)
    stream = body.get("stream", False)
    api_key = _api_key(authorization)

    cc_body = build_cc_request(
        body,
        compaction_enabled=state.config.compaction.enabled,
        compaction_max_messages=state.config.compaction.max_messages,
        compaction_keep_last=state.config.compaction.keep_last,
    )
    cc_body["model"] = model
    cc_body["stream"] = stream
    cc_body = state.adapter.adjust_request(cc_body)
    headers = _cc_headers(api_key)

    # Plugin: on_request
    req_id = uuid.uuid4().hex[:12]
    if state.plugin_registry:
        ctx = _build_plugin_ctx(req_id, "http", model, api_key, stream)
        await state.plugin_registry.on_request(ctx)

    if stream:
        result_holder: dict = {}
        original_input = body.get("input", [])

        async def _stream():
            try:
                async with state.client.stream("POST", f"{_backend_url()}/chat/completions",
                                               json=cc_body, headers=headers) as resp:
                    if resp.status_code >= 400:
                        error_body = await resp.aread()
                        logger.error("Upstream error %d: %s", resp.status_code, error_body[:500])
                        if state.circuit_breaker:
                            state.circuit_breaker.record_failure()
                        state.failure_count += 1
                        yield f"event: error\ndata: {json.dumps({'error': {'message': 'upstream error', 'status': resp.status_code}})}\n\n"
                        return

                    if state.circuit_breaker:
                        state.circuit_breaker.record_success()
                    state.success_count += 1

                    gen = stream_cc_to_response(resp.aiter_lines(), model, result=result_holder)
                    async for chunk in gen:
                        yield chunk
            except httpx.TransportError as exc:
                # Connection failures and mid-stream drops would otherwise
                # abort the SSE stream silently and bypass the breaker.
                logger.error("Upstream stream failed: %s", exc)
                if state.circuit_breaker:
                    state.circuit_breaker.record_failure()
                state.failure_count += 1
                yield f"event: error\ndata: {json.dumps({'error': {'message': 'upstream error', 'status': 502}})}\n\n"
                return

            completed = result_holder.get("response")
            if completed:
                state.store.put(completed["id"], {**completed, "_original_input": original_input})

        return StreamingResponse(_stream(), media_type="text/event-stream",
                                 headers={"Cache-Control": "no-cache",
                                          "X-Accel-Buffering": "no"})

    async def _upstream_failure(status_code: int, detail: str,
                                duration_ms: float) -> JSONResponse:
        """Record a failed upstream call, notify plugins, build the 502 reply."""
        if state.circuit_breaker:
            state.circuit_breaker.record_failure()
        state.failure_count += 1
        if state.plugin_registry:
            ectx = _build_plugin_ctx(req_id, "http", model, api_key, stream)
            ectx.status_code = status_code
            ectx.error = detail[:200]
            ectx.duration_ms = duration_ms
            await state.plugin_registry.on_error(ectx)
        return JSONResponse({"error": {"message": "upstream error",
                                       "status": status_code}},
                            status_code=502)

    start_t = time.monotonic()
    try:
        r = await _request_with_key_failover(
            f"{_backend_url()}/chat/completions", cc_body, headers, api_key)
    except httpx.TransportError as exc:
        # Retries are exhausted and the upstream is unreachable: report it
        # like any other upstream failure instead of an unhandled 500.
        duration = (time.monotonic() - start_t) * 1000
        logger.error("Upstream unreachable: %s", exc)
        return await _upstream_failure(502, str(exc), duration)
    duration = (time.monotonic() - start_t) * 1000

    if r.status_code >= 400:
        logger.error("Upstream error %d: %s", r.status_code, r.text[:500])
        return await _upstream_failure(r.status_code, r.text, duration)

    if state.circuit_breaker:
        state.circuit_breaker.record_success()
    state.success_count += 1
    if state.plugin_registry:
        rctx = _build_plugin_ctx(req_id, "http", model, api_key, stream)
        rctx.status_code = r.status_code
        rctx.duration_ms = duration
        await state.plugin_registry.on_response(rctx)
    resp = cc_to_response(r.json(), model)
    # Keep the caller's input next to the response; resolve_input() is
    # presumably what reads it back for follow-up requests (not shown here).
    state.store.put(resp["id"], {**resp, "_original_input": body.get("input", [])})
    return JSONResponse(resp)
|
|
299
|
+
|
|
300
|
+
|
|
301
|
+
# ── WebSocket endpoint ──────────────────────────────────────────────────
|
|
302
|
+
|
|
303
|
+
@app.websocket("/responses")
async def responses_ws(ws: WebSocket):
    """Serve Responses-style requests over a WebSocket connection.

    Each text frame is unwrapped into a request body, translated to a
    streaming chat-completions call, and answered with a sequence of JSON
    events (``response.created`` ... ``response.completed``, or
    ``response.failed``). Frames that are not valid JSON are ignored. The
    connection keeps serving requests until the client disconnects.

    Failure handling per request:
    * circuit breaker open -> ``response.failed`` with code ``circuit_open``
    * upstream answers with status >= 400, or the upstream call raises ->
      ``response.failed`` with code ``upstream_error``; recorded on the
      circuit breaker and the failure counter.
    A client disconnect while events are being relayed ends the connection
    and is not counted as an upstream failure.
    """
    state = _state()

    await ws.accept()
    api_key = ""
    auth = ws.headers.get("authorization", "")
    if auth:
        api_key = _api_key(auth)
    if not api_key:
        if state.key_rotator:
            api_key = state.key_rotator.next_key()
        else:
            api_key = state.config.provider.effective_api_key()

    async def _send(event: dict) -> None:
        """Serialize one event and send it as a text frame."""
        await ws.send_text(json.dumps(event))

    def _failed_event(response_id: str, created_at: int, model: str,
                      message: str, code: str) -> dict:
        """Build a ``response.failed`` event with empty output and usage."""
        return {
            "type": "response.failed",
            "response": {
                "id": response_id, "object": "response",
                "created_at": created_at, "model": model,
                "status": "failed", "output": [],
                "usage": {"input_tokens": 0, "output_tokens": 0,
                          "total_tokens": 0},
                "error": {"message": message, "code": code},
            },
        }

    try:
        while True:
            raw = await ws.receive_text()
            try:
                body = unwrap_envelope(raw)
            except json.JSONDecodeError:
                # Not JSON: drop the frame and wait for the next one.
                continue

            state.request_count += 1
            state.last_request_time = time.time()
            # Rotate key per message when no client auth header
            if not auth and state.key_rotator:
                api_key = state.key_rotator.next_key()
            body = state.store.resolve_input(body)
            model = body.get("model", state.config.provider.default_model)
            cc_body = build_cc_request(
                body,
                compaction_enabled=state.config.compaction.enabled,
                compaction_max_messages=state.config.compaction.max_messages,
                compaction_keep_last=state.config.compaction.keep_last,
            )
            cc_body["model"] = model
            cc_body["stream"] = True
            cc_body = state.adapter.adjust_request(cc_body)
            headers = _cc_headers(api_key)

            if state.circuit_breaker and not state.circuit_breaker.can_execute():
                await _send(_failed_event(
                    generate_response_id(), int(time.time()), model,
                    "Circuit breaker open", "circuit_open"))
                continue

            rid = generate_response_id()
            mid = f"msg_{uuid.uuid4().hex[:24]}"
            now = int(time.time())

            init = {"id": rid, "object": "response", "created_at": now,
                    "model": model, "status": "in_progress", "output": [],
                    "usage": {"input_tokens": 0, "output_tokens": 0,
                              "total_tokens": 0}}

            await _send({"type": "response.created", "response": init})
            await _send({"type": "response.in_progress", "response": init})
            await _send({
                "type": "response.output_item.added", "output_index": 0,
                "item": {"type": "message", "id": mid,
                         "status": "in_progress", "role": "assistant",
                         "content": []}})
            await _send({
                "type": "response.content_part.added", "output_index": 0,
                "content_index": 0,
                "part": {"type": "output_text", "text": "",
                         "annotations": []}})

            full_text = ""
            reasoning_text = ""
            tool_calls: list[dict] = []
            usage_data = {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}

            # Set to the client-facing message when the upstream call fails.
            upstream_failure: str | None = None
            try:
                async with state.client.stream(
                    "POST", f"{_backend_url()}/chat/completions",
                    json=cc_body, headers=headers
                ) as resp:
                    if resp.status_code >= 400:
                        # An error body is not a completion stream; do not
                        # report it as a successful (empty) response.
                        error_body = await resp.aread()
                        logger.error("Upstream error %d: %s",
                                     resp.status_code, error_body[:500])
                        upstream_failure = f"Upstream error: HTTP {resp.status_code}"
                    else:
                        async for event_type, data in parse_cc_stream(resp.aiter_lines()):
                            if event_type == "reasoning":
                                reasoning_text += data
                            elif event_type == "text":
                                full_text += data
                                await _send({
                                    "type": "response.output_text.delta",
                                    "output_index": 0,
                                    "content_index": 0, "delta": data})
                            elif event_type == "tool_call":
                                accumulate_tool_call(tool_calls, data)
                            elif event_type == "usage":
                                usage_data = {
                                    "input_tokens": data.get("prompt_tokens", 0),
                                    "output_tokens": data.get("completion_tokens", 0),
                                    "total_tokens": data.get("total_tokens", 0),
                                }
            except WebSocketDisconnect:
                # The client went away while deltas were being relayed; that
                # is not an upstream fault, so leave the breaker alone.
                raise
            except Exception as e:
                logger.error("WS upstream error: %s", e)
                upstream_failure = f"Upstream error: {e}"

            if upstream_failure is not None:
                if state.circuit_breaker:
                    state.circuit_breaker.record_failure()
                state.failure_count += 1
                await _send(_failed_event(rid, now, model,
                                          upstream_failure, "upstream_error"))
                continue

            # Finish text
            if state.circuit_breaker:
                state.circuit_breaker.record_success()
            state.success_count += 1
            await _send({
                "type": "response.output_text.done", "output_index": 0,
                "content_index": 0, "text": full_text})
            await _send({
                "type": "response.content_part.done", "output_index": 0,
                "content_index": 0,
                "part": {"type": "output_text", "text": full_text,
                         "annotations": []}})
            await _send({
                "type": "response.output_item.done", "output_index": 0,
                "item": {"type": "message", "id": mid,
                         "status": "completed", "role": "assistant",
                         "content": [{"type": "output_text",
                                      "text": full_text,
                                      "annotations": []}]}})

            # Build final output using shared function
            final_out = build_final_output(mid, full_text, reasoning_text, tool_calls)

            # Emit tool call events. This treats the last len(tool_calls)
            # entries of final_out as the tool-call items.
            text_and_reasoning_count = len(final_out) - len(tool_calls)
            for i, item in enumerate(final_out[text_and_reasoning_count:], text_and_reasoning_count):
                fc = {k: v for k, v in item.items() if k != "status"}
                await _send({
                    "type": "response.output_item.added",
                    "output_index": i, "item": fc})
                await _send({
                    "type": "response.output_item.done",
                    "output_index": i, "item": fc})

            completed = {"id": rid, "object": "response",
                         "created_at": now, "model": model,
                         "status": "completed", "output": final_out,
                         "usage": usage_data}
            await _send({"type": "response.completed",
                         "response": completed})

            state.store.put(rid, {**completed,
                                  "_original_input": body.get("input", [])})

    except WebSocketDisconnect:
        logger.info("WebSocket disconnected")
    except Exception as e:
        logger.error("WebSocket error: %s", e)
|
|
476
|
+
|
|
477
|
+
|
|
478
|
+
# ── Utility endpoints ───────────────────────────────────────────────────
|
|
479
|
+
|
|
480
|
+
@app.get("/responses/{response_id}")
async def get_response(response_id: str):
    """Return a stored response by id, or a 404 error payload.

    Keys starting with an underscore (such as ``_original_input``) are
    internal bookkeeping and are stripped from the returned JSON.
    """
    stored = _state().store.get(response_id)
    if stored:
        public = {
            field: value
            for field, value in stored.items()
            if not field.startswith("_")
        }
        return JSONResponse(public)

    not_found = {"error": {"message": "Response not found", "code": "not_found"}}
    return JSONResponse(not_found, status_code=404)
|
|
491
|
+
|
|
492
|
+
|
|
493
|
+
@app.get("/models")
@app.get("/v1/models")
async def models():
    """List the configured provider's models as an ``object: list`` payload."""
    provider = _state().config.provider
    entries = []
    for model_id in provider.models:
        entries.append({"id": model_id, "object": "model",
                        "owned_by": provider.name})
    return JSONResponse({"object": "list", "data": entries})
|
|
503
|
+
|
|
504
|
+
|
|
505
|
+
@app.get("/health")
async def health(request: Request):
    """Report proxy health, optionally probing the backend.

    When the ``check_backend`` query parameter has a non-empty value, the
    handler issues ``GET {base_url}/models`` with a 5 second timeout and adds
    ``backend`` / ``backend_status`` to the result. If the probe raises, the
    backend is reported as ``unreachable`` and the overall status becomes
    ``degraded``.
    """
    state = _state()
    report: dict[str, Any] = {
        "status": "ok",
        "proxy": "codex-proxy",
        "version": __version__,
    }
    if not request.query_params.get("check_backend"):
        return report

    try:
        provider = state.config.provider
        probe = await state.client.get(
            f"{provider.base_url}/models",
            headers={"Authorization": f"Bearer {state.config.provider.effective_api_key()}"},
            timeout=5.0,
        )
        backend_ok = probe.status_code < 400
        report["backend"] = "ok" if backend_ok else "error"
        report["backend_status"] = probe.status_code
    except Exception as exc:
        report["backend"] = "unreachable"
        report["backend_error"] = str(exc)
        report["status"] = "degraded"
    return report
|
|
527
|
+
|
|
528
|
+
|
|
529
|
+
@app.get("/status")
async def status():
    """Return a JSON snapshot of the proxy's runtime state.

    Always includes version, uptime, request total, store size and the
    provider/server settings; adds ``circuit_breaker``, ``key_rotation`` and
    ``plugins`` sections only when those features are active.
    """
    state = _state()
    provider = state.config.provider
    server_cfg = state.config.server
    uptime = int(time.time() - state.start_time)

    provider_info = {
        "name": provider.name,
        "display_name": provider.display_name,
        "base_url": provider.base_url,
        "models": provider.models,
        "default_model": provider.default_model,
    }
    snapshot = {
        "proxy": "codex-proxy",
        "version": __version__,
        "status": "running",
        "uptime_seconds": uptime,
        "requests_total": state.request_count,
        "response_store_size": state.store.size(),
        "provider": provider_info,
        "server": {"host": server_cfg.host, "port": server_cfg.port},
    }

    breaker = state.circuit_breaker
    if breaker:
        snapshot["circuit_breaker"] = breaker.get_status()

    rotator = state.key_rotator
    if rotator:
        snapshot["key_rotation"] = {
            "total_keys": len(rotator._keys),
            "keys": rotator.get_status(),
        }

    registry = state.plugin_registry
    if registry:
        snapshot["plugins"] = {
            "enabled": True,
            "loaded": registry.list_plugins(),
        }
    return JSONResponse(snapshot)
|
|
566
|
+
|
|
567
|
+
|
|
568
|
+
def reload_config_internal(state: AppState) -> tuple[str, str]:
    """Reload the config from disk and update *state* in place.

    Every replacement component (adapter, store, circuit breaker, key
    rotator, plugin registry) is built first and only then assigned, so an
    exception while loading or constructing any of them leaves *state*
    exactly as it was instead of half-updated.

    The response store is only replaced when its TTL or capacity changed;
    the circuit breaker and key rotator are always rebuilt from the new
    settings. Plugin lifecycle hooks are not called here.

    Returns:
        A tuple of (display_name, default_model) from the new provider config.

    Raises:
        Whatever ``load_config`` or the component constructors raise.
    """
    from .config import load_config

    new_config = load_config()

    # ---- build phase: nothing on `state` is touched yet ----
    new_adapter = get_adapter(new_config.provider.name)

    new_store = state.store
    if (new_config.store.ttl_seconds != state.store.ttl_seconds or
            new_config.store.max_entries != state.store.max_entries):
        new_store = ResponseStore(
            ttl_seconds=new_config.store.ttl_seconds,
            max_entries=new_config.store.max_entries,
        )

    cb_config = new_config.circuit_breaker
    new_breaker = CircuitBreaker(
        failure_threshold=cb_config.failure_threshold,
        recovery_timeout=cb_config.recovery_timeout,
    ) if cb_config.enabled else None

    keys = new_config.provider.effective_api_keys()
    new_rotator = KeyRotator(
        keys=keys,
        failure_threshold=cb_config.failure_threshold,
        recovery_timeout=cb_config.recovery_timeout,
    ) if len(keys) > 1 else None

    new_registry = _build_plugin_registry(new_config)

    # ---- commit phase: plain attribute assignments ----
    state.config = new_config
    state.adapter = new_adapter
    state.store = new_store
    state.circuit_breaker = new_breaker
    state.key_rotator = new_rotator
    state.plugin_registry = new_registry
    return new_config.provider.display_name, new_config.provider.default_model
|
|
599
|
+
|
|
600
|
+
|
|
601
|
+
@app.post("/reload")
async def reload_config():
    """Reload the configuration from disk without restarting the server.

    After the state has been refreshed, the plugin registry that was active
    before the reload is shut down and the current one (if any) is started.
    Any exception is logged and reported as a 500 JSON error.
    """
    state = _state()
    previous_plugins = state.plugin_registry
    try:
        provider_label, model_name = reload_config_internal(state)
        if previous_plugins:
            await previous_plugins.on_shutdown()
        if state.plugin_registry:
            await state.plugin_registry.on_startup(state.config)
        logger.info("Config reloaded successfully")
    except Exception as exc:
        logger.error("Config reload failed: %s", exc)
        failure = {"status": "error", "message": str(exc)}
        return JSONResponse(failure, status_code=500)
    return {"status": "reloaded", "provider": provider_label,
            "model": model_name}
|
|
621
|
+
|
|
622
|
+
|
|
623
|
+
def run(config: ProxyConfig, *, tui: bool = False) -> None:
    """Configure logging and application state, then serve with uvicorn.

    Args:
        config: Fully loaded proxy configuration.
        tui: When true, start the terminal UI and quieten uvicorn's access
            log instead of printing the startup banner.

    The configured log level is lower-cased once and that value is used for
    every decision, so ``"DEBUG"`` and ``"debug"`` behave the same: both
    enable the debug log file and both are passed to uvicorn in lower case.
    """
    configure(config)
    log_level = config.server.log_level.lower()
    # Names the logging module does not define fall back to WARNING.
    level = getattr(logging, log_level.upper(), logging.WARNING)
    logging.basicConfig(level=level, format="%(asctime)s [%(name)s] %(levelname)s %(message)s")
    if log_level == "debug":
        # In debug mode, additionally write the proxy's own log to a file.
        config.server.log_dir.mkdir(parents=True, exist_ok=True)
        fh = logging.FileHandler(config.server.log_dir / "proxy.log", encoding="utf-8")
        fh.setLevel(logging.DEBUG)
        fh.setFormatter(logging.Formatter("%(asctime)s [%(name)s] %(levelname)s %(message)s"))
        logging.getLogger("codex-proxy").addHandler(fh)
    logger.info("codex-proxy v%s starting", __version__)

    if tui:
        from .tui import start_tui
        start_tui(app.state.proxy)
        logging.getLogger("uvicorn.access").setLevel(logging.WARNING)
    else:
        print(f" codex-proxy v{__version__}")
        print(f" http://{config.server.host}:{config.server.port}")
        print(f" backend {config.provider.display_name} ({config.provider.base_url})")
        print(f" models {', '.join(config.provider.models)}")

    uvicorn.run(app, host=config.server.host, port=config.server.port,
                log_level=log_level)
|