codex-lb 0.3.1__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to their public registry. It is provided for informational purposes only.
- app/core/clients/proxy.py +33 -3
- app/core/config/settings.py +9 -8
- app/core/handlers/__init__.py +3 -0
- app/core/handlers/exceptions.py +39 -0
- app/core/middleware/__init__.py +9 -0
- app/core/middleware/api_errors.py +33 -0
- app/core/middleware/request_decompression.py +101 -0
- app/core/middleware/request_id.py +27 -0
- app/core/openai/chat_requests.py +172 -0
- app/core/openai/chat_responses.py +534 -0
- app/core/openai/message_coercion.py +60 -0
- app/core/openai/models_catalog.py +72 -0
- app/core/openai/requests.py +23 -5
- app/core/openai/v1_requests.py +92 -0
- app/db/models.py +3 -3
- app/db/session.py +25 -8
- app/dependencies.py +43 -16
- app/main.py +13 -67
- app/modules/accounts/repository.py +25 -10
- app/modules/proxy/api.py +94 -0
- app/modules/proxy/load_balancer.py +75 -58
- app/modules/proxy/repo_bundle.py +23 -0
- app/modules/proxy/service.py +127 -102
- app/modules/request_logs/api.py +61 -7
- app/modules/request_logs/repository.py +131 -16
- app/modules/request_logs/schemas.py +11 -2
- app/modules/request_logs/service.py +97 -20
- app/modules/usage/service.py +65 -4
- app/modules/usage/updater.py +58 -26
- app/static/index.css +378 -1
- app/static/index.html +183 -8
- app/static/index.js +308 -13
- {codex_lb-0.3.1.dist-info → codex_lb-0.5.0.dist-info}/METADATA +42 -3
- {codex_lb-0.3.1.dist-info → codex_lb-0.5.0.dist-info}/RECORD +37 -25
- {codex_lb-0.3.1.dist-info → codex_lb-0.5.0.dist-info}/WHEEL +0 -0
- {codex_lb-0.3.1.dist-info → codex_lb-0.5.0.dist-info}/entry_points.txt +0 -0
- {codex_lb-0.3.1.dist-info → codex_lb-0.5.0.dist-info}/licenses/LICENSE +0 -0
app/core/clients/proxy.py
CHANGED
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 import asyncio
-from typing import AsyncIterator, Mapping
+from typing import AsyncIterator, Mapping, Protocol, TypeAlias
 
 import aiohttp
 
@@ -28,6 +28,18 @@ class StreamIdleTimeoutError(Exception):
     pass
 
 
+class ErrorResponseProtocol(Protocol):
+    status: int
+    reason: str | None
+
+    async def json(self, *, content_type: str | None = None) -> object: ...
+
+    async def text(self, *, encoding: str | None = None, errors: str = "strict") -> str: ...
+
+
+ErrorResponse: TypeAlias = aiohttp.ClientResponse | ErrorResponseProtocol
+
+
 class ProxyResponseError(Exception):
     def __init__(self, status_code: int, payload: OpenAIErrorEnvelope) -> None:
         super().__init__(f"Proxy response error ({status_code})")
@@ -88,8 +100,10 @@ async def _iter_sse_lines(
         yield line
 
 
-async def _error_event_from_response(resp: aiohttp.ClientResponse) -> ResponseFailedEvent:
+async def _error_event_from_response(resp: ErrorResponse) -> ResponseFailedEvent:
     fallback_message = f"Upstream error: HTTP {resp.status}"
+    if resp.reason:
+        fallback_message += f" {resp.reason}"
     try:
         data = await resp.json(content_type=None)
     except Exception:
@@ -112,11 +126,16 @@ async def _error_event_from_response(resp: aiohttp.ClientResponse) -> ResponseFailedEvent:
             if key in payload:
                 event["response"]["error"][key] = payload[key]
         return event
+    message = _extract_upstream_message(data)
+    if message:
+        return response_failed_event("upstream_error", message, response_id=get_request_id())
     return response_failed_event("upstream_error", fallback_message, response_id=get_request_id())
 
 
-async def _error_payload_from_response(resp: aiohttp.ClientResponse) -> OpenAIErrorEnvelope:
+async def _error_payload_from_response(resp: ErrorResponse) -> OpenAIErrorEnvelope:
     fallback_message = f"Upstream error: HTTP {resp.status}"
+    if resp.reason:
+        fallback_message += f" {resp.reason}"
     try:
         data = await resp.json(content_type=None)
     except Exception:
@@ -128,9 +147,20 @@ async def _error_payload_from_response(resp: aiohttp.ClientResponse) -> OpenAIErrorEnvelope:
         error = parse_error_payload(data)
         if error:
             return {"error": error.model_dump(exclude_none=True)}
+    message = _extract_upstream_message(data)
+    if message:
+        return openai_error("upstream_error", message)
     return openai_error("upstream_error", fallback_message)
 
 
+def _extract_upstream_message(data: dict) -> str | None:
+    for key in ("message", "detail", "error"):
+        value = data.get(key)
+        if isinstance(value, str) and value.strip():
+            return value
+    return None
+
+
 async def stream_responses(
     payload: ResponsesRequest,
     headers: Mapping[str, str],
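The new ErrorResponseProtocol widens the error helpers from aiohttp.ClientResponse to anything response-shaped, which makes them testable without a live aiohttp session. A minimal sketch of a conforming test double (FakeUpstreamResponse is hypothetical, not part of the package):

class FakeUpstreamResponse:
    """Hypothetical test double satisfying ErrorResponseProtocol."""

    def __init__(self, status: int, payload: object, reason: str | None = None) -> None:
        self.status = status
        self.reason = reason
        self._payload = payload

    async def json(self, *, content_type: str | None = None) -> object:
        return self._payload

    async def text(self, *, encoding: str | None = None, errors: str = "strict") -> str:
        return str(self._payload)

# _error_payload_from_response(FakeUpstreamResponse(502, {"detail": "upstream down"}))
# now surfaces "upstream down" via _extract_upstream_message instead of the bare
# "Upstream error: HTTP 502" fallback.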
app/core/config/settings.py
CHANGED
@@ -3,7 +3,7 @@ from __future__ import annotations
 from functools import lru_cache
 from pathlib import Path
 
-from pydantic import field_validator
+from pydantic import Field, field_validator
 from pydantic_settings import BaseSettings, SettingsConfigDict
 
 BASE_DIR = Path(__file__).resolve().parents[3]
@@ -22,6 +22,9 @@ class Settings(BaseSettings):
     )
 
     database_url: str = f"sqlite+aiosqlite:///{DEFAULT_DB_PATH}"
+    database_pool_size: int = Field(default=15, gt=0)
+    database_max_overflow: int = Field(default=10, ge=0)
+    database_pool_timeout_seconds: float = Field(default=30.0, gt=0)
     upstream_base_url: str = "https://chatgpt.com/backend-api"
     upstream_connect_timeout_seconds: float = 30.0
     stream_idle_timeout_seconds: float = 300.0
@@ -42,24 +45,22 @@ class Settings(BaseSettings):
     database_migrations_fail_fast: bool = True
     log_proxy_request_shape: bool = False
     log_proxy_request_shape_raw_cache_key: bool = False
+    log_proxy_request_payload: bool = False
+    max_decompressed_body_bytes: int = Field(default=32 * 1024 * 1024, gt=0)
 
     @field_validator("database_url")
     @classmethod
-    def
-        if not isinstance(value, str):
-            return value
-
+    def _expand_database_url(cls, value: str) -> str:
         for prefix in ("sqlite+aiosqlite:///", "sqlite:///"):
             if value.startswith(prefix):
                 path = value[len(prefix) :]
                 if path.startswith("~"):
-
-                    return f"{prefix}{expanded}"
+                    return f"{prefix}{Path(path).expanduser()}"
         return value
 
     @field_validator("encryption_key_file", mode="before")
     @classmethod
-    def
+    def _expand_encryption_key_file(cls, value: str | Path) -> Path:
         if isinstance(value, Path):
             return value.expanduser()
         if isinstance(value, str):
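The three new database_pool_* settings line up with SQLAlchemy's queue-pool keyword arguments. The actual engine setup lives in app/db/session.py (changed +25 -8 in this release but not expanded above), so the following wiring is a hedged sketch, not the package's confirmed code:

from sqlalchemy.ext.asyncio import create_async_engine

from app.core.config.settings import get_settings

settings = get_settings()

# Assumed mapping: setting names correspond 1:1 to SQLAlchemy pool kwargs.
engine = create_async_engine(
    settings.database_url,
    pool_size=settings.database_pool_size,                # default 15, must be > 0
    max_overflow=settings.database_max_overflow,          # default 10, may be 0
    pool_timeout=settings.database_pool_timeout_seconds,  # default 30.0 seconds
)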
app/core/handlers/exceptions.py
ADDED
@@ -0,0 +1,39 @@
+from __future__ import annotations
+
+from fastapi import FastAPI, Request
+from fastapi.exception_handlers import (
+    http_exception_handler,
+    request_validation_exception_handler,
+)
+from fastapi.exceptions import RequestValidationError
+from fastapi.responses import JSONResponse, Response
+from starlette.exceptions import HTTPException as StarletteHTTPException
+
+from app.core.errors import dashboard_error
+
+
+def add_exception_handlers(app: FastAPI) -> None:
+    @app.exception_handler(RequestValidationError)
+    async def validation_error_handler(
+        request: Request,
+        exc: RequestValidationError,
+    ) -> Response:
+        if request.url.path.startswith("/api/"):
+            return JSONResponse(
+                status_code=422,
+                content=dashboard_error("validation_error", "Invalid request payload"),
+            )
+        return await request_validation_exception_handler(request, exc)
+
+    @app.exception_handler(StarletteHTTPException)
+    async def http_error_handler(
+        request: Request,
+        exc: StarletteHTTPException,
+    ) -> Response:
+        if request.url.path.startswith("/api/"):
+            detail = exc.detail if isinstance(exc.detail, str) else "Request failed"
+            return JSONResponse(
+                status_code=exc.status_code,
+                content=dashboard_error(f"http_{exc.status_code}", detail),
+            )
+        return await http_exception_handler(request, exc)
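Wiring sketch: the /api/ prefix check gives dashboard endpoints a uniform dashboard_error JSON body while every other route keeps FastAPI's stock handlers. This assumes app/core/handlers/__init__.py (+3 lines, not expanded here) re-exports the function:

from fastapi import FastAPI

from app.core.handlers import add_exception_handlers

app = FastAPI()
add_exception_handlers(app)

# After registration:
#   - a validation failure on /api/...   -> 422 with dashboard_error("validation_error", ...)
#   - an HTTPException on /api/...       -> its own status with dashboard_error("http_<code>", ...)
#   - the same errors on non-/api routes -> FastAPI's default JSON responses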
app/core/middleware/__init__.py
ADDED
@@ -0,0 +1,9 @@
+from app.core.middleware.api_errors import add_api_unhandled_error_middleware
+from app.core.middleware.request_decompression import add_request_decompression_middleware
+from app.core.middleware.request_id import add_request_id_middleware
+
+__all__ = [
+    "add_api_unhandled_error_middleware",
+    "add_request_decompression_middleware",
+    "add_request_id_middleware",
+]
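These re-exports suggest the three middleware are registered together at startup; app/main.py shrank by a net 54 lines in this release, consistent with that wiring moving out of it. A hypothetical registration sketch (the real order in main.py is not shown in this diff):

from fastapi import FastAPI

from app.core.middleware import (
    add_api_unhandled_error_middleware,
    add_request_decompression_middleware,
    add_request_id_middleware,
)

app = FastAPI()
# Each helper attaches an @app.middleware("http") hook; Starlette runs the
# most recently added middleware outermost, so adding request-id last makes
# the request id available to the error middleware below it.
add_request_decompression_middleware(app)
add_api_unhandled_error_middleware(app)
add_request_id_middleware(app)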
app/core/middleware/api_errors.py
ADDED
@@ -0,0 +1,33 @@
+from __future__ import annotations
+
+import logging
+from collections.abc import Awaitable, Callable
+
+from fastapi import FastAPI, Request
+from fastapi.responses import JSONResponse, Response
+
+from app.core.errors import dashboard_error
+from app.core.utils.request_id import get_request_id
+
+logger = logging.getLogger(__name__)
+
+
+def add_api_unhandled_error_middleware(app: FastAPI) -> None:
+    @app.middleware("http")
+    async def api_unhandled_error_middleware(
+        request: Request,
+        call_next: Callable[[Request], Awaitable[Response]],
+    ) -> Response:
+        try:
+            return await call_next(request)
+        except Exception:
+            if request.url.path.startswith("/api/"):
+                logger.exception(
+                    "Unhandled API error request_id=%s",
+                    get_request_id(),
+                )
+                return JSONResponse(
+                    status_code=500,
+                    content=dashboard_error("internal_error", "Unexpected error"),
+                )
+            raise
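Behavior sketch: only unhandled exceptions under /api/ become a 500 JSON body; everything else re-raises so Starlette's default error handling still applies. A quick check, assuming the middleware is registered as above (the /api/boom route is illustrative):

from fastapi import FastAPI
from fastapi.testclient import TestClient

from app.core.middleware import add_api_unhandled_error_middleware

app = FastAPI()
add_api_unhandled_error_middleware(app)

@app.get("/api/boom")
async def boom() -> None:
    raise RuntimeError("kaboom")

client = TestClient(app, raise_server_exceptions=False)
assert client.get("/api/boom").status_code == 500  # dashboard_error("internal_error", ...)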
app/core/middleware/request_decompression.py
ADDED
@@ -0,0 +1,101 @@
+from __future__ import annotations
+
+import io
+from collections.abc import Awaitable, Callable
+from typing import Protocol
+
+import zstandard as zstd
+from fastapi import FastAPI, Request
+from fastapi.responses import JSONResponse, Response
+
+from app.core.config.settings import get_settings
+from app.core.errors import dashboard_error
+
+
+class _DecompressedBodyTooLarge(Exception):
+    def __init__(self, max_size: int) -> None:
+        super().__init__(f"Decompressed body exceeded {max_size} bytes")
+        self.max_size = max_size
+
+
+class _Readable(Protocol):
+    def read(self, size: int = ...) -> bytes: ...
+
+
+def _read_limited(reader: _Readable, max_size: int) -> bytes:
+    buffer = bytearray()
+    total = 0
+    chunk_size = 64 * 1024
+    while True:
+        chunk = reader.read(chunk_size)
+        if not chunk:
+            break
+        total += len(chunk)
+        if total > max_size:
+            raise _DecompressedBodyTooLarge(max_size)
+        buffer.extend(chunk)
+    return bytes(buffer)
+
+
+def _replace_request_body(request: Request, body: bytes) -> None:
+    request._body = body
+    headers: list[tuple[bytes, bytes]] = []
+    for key, value in request.scope.get("headers", []):
+        if key.lower() in (b"content-encoding", b"content-length"):
+            continue
+        headers.append((key, value))
+    headers.append((b"content-length", str(len(body)).encode("ascii")))
+    request.scope["headers"] = headers
+    # Ensure subsequent request.headers reflects the updated scope headers.
+    request._headers = None
+
+
+def add_request_decompression_middleware(app: FastAPI) -> None:
+    @app.middleware("http")
+    async def request_decompression_middleware(
+        request: Request,
+        call_next: Callable[[Request], Awaitable[Response]],
+    ) -> Response:
+        content_encoding = request.headers.get("content-encoding")
+        if not content_encoding:
+            return await call_next(request)
+        encodings = [enc.strip().lower() for enc in content_encoding.split(",") if enc.strip()]
+        if encodings != ["zstd"]:
+            return await call_next(request)
+        body = await request.body()
+        settings = get_settings()
+        max_size = settings.max_decompressed_body_bytes
+        try:
+            decompressed = zstd.ZstdDecompressor().decompress(body, max_output_size=max_size)
+            if len(decompressed) > max_size:
+                raise _DecompressedBodyTooLarge(max_size)
+        except _DecompressedBodyTooLarge:
+            return JSONResponse(
+                status_code=413,
+                content=dashboard_error(
+                    "payload_too_large",
+                    "Request body exceeds the maximum allowed size",
+                ),
+            )
+        except Exception:
+            try:
+                with zstd.ZstdDecompressor().stream_reader(io.BytesIO(body)) as reader:
+                    decompressed = _read_limited(reader, max_size)
+            except _DecompressedBodyTooLarge:
+                return JSONResponse(
+                    status_code=413,
+                    content=dashboard_error(
+                        "payload_too_large",
+                        "Request body exceeds the maximum allowed size",
+                    ),
+                )
+            except Exception:
+                return JSONResponse(
+                    status_code=400,
+                    content=dashboard_error(
+                        "invalid_request",
+                        "Request body is zstd-compressed but could not be decompressed",
+                    ),
+                )
+        _replace_request_body(request, decompressed)
+        return await call_next(request)
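The one-shot ZstdDecompressor().decompress(body, max_output_size=max_size) handles single-frame bodies; anything it rejects (for example multi-frame input, or a frame whose output would exceed the cap) falls through to the streaming reader, where _read_limited enforces the same limit and failures map to 413 or 400. A client exercising the middleware might look like this (the URL and route are hypothetical):

import json

import httpx
import zstandard as zstd

body = json.dumps({"example": "payload"}).encode("utf-8")

response = httpx.post(
    "http://localhost:8000/api/example",  # hypothetical deployment URL and route
    content=zstd.ZstdCompressor().compress(body),
    headers={
        "content-type": "application/json",
        # Must be exactly "zstd": stacked encodings such as "gzip, zstd"
        # are passed through to the app undecoded.
        "content-encoding": "zstd",
    },
)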
app/core/middleware/request_id.py
ADDED
@@ -0,0 +1,27 @@
+from __future__ import annotations
+
+from collections.abc import Awaitable, Callable
+from uuid import uuid4
+
+from fastapi import FastAPI, Request
+from fastapi.responses import JSONResponse
+
+from app.core.utils.request_id import reset_request_id, set_request_id
+
+
+def add_request_id_middleware(app: FastAPI) -> None:
+    @app.middleware("http")
+    async def request_id_middleware(
+        request: Request,
+        call_next: Callable[[Request], Awaitable[JSONResponse]],
+    ) -> JSONResponse:
+        inbound_request_id = request.headers.get("x-request-id") or request.headers.get("request-id")
+        request_id = inbound_request_id or str(uuid4())
+        token = set_request_id(request_id)
+        try:
+            response = await call_next(request)
+        except Exception:
+            reset_request_id(token)
+            raise
+        response.headers.setdefault("x-request-id", request_id)
+        return response
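The middleware leans on app/core/utils/request_id.py, which is unchanged in this release and therefore not shown. A contextvar-based implementation consistent with the set/reset/get calls used above would be roughly (a sketch, not the package's confirmed code):

from __future__ import annotations

from contextvars import ContextVar, Token

_request_id: ContextVar[str | None] = ContextVar("request_id", default=None)


def get_request_id() -> str | None:
    return _request_id.get()


def set_request_id(request_id: str) -> Token:
    return _request_id.set(request_id)


def reset_request_id(token: Token) -> None:
    _request_id.reset(token)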
app/core/openai/chat_requests.py
ADDED
@@ -0,0 +1,172 @@
+from __future__ import annotations
+
+from collections.abc import Mapping
+from typing import cast
+
+from pydantic import BaseModel, ConfigDict, Field, model_validator
+
+from app.core.openai.message_coercion import coerce_messages
+from app.core.openai.requests import ResponsesRequest, ResponsesTextControls, ResponsesTextFormat
+from app.core.types import JsonValue
+
+
+class ChatCompletionsRequest(BaseModel):
+    model_config = ConfigDict(extra="allow")
+
+    model: str = Field(min_length=1)
+    messages: list[dict[str, JsonValue]]
+    tools: list[JsonValue] = Field(default_factory=list)
+    tool_choice: str | dict[str, JsonValue] | None = None
+    parallel_tool_calls: bool | None = None
+    stream: bool | None = None
+    temperature: float | None = None
+    top_p: float | None = None
+    stop: str | list[str] | None = None
+    n: int | None = None
+    presence_penalty: float | None = None
+    frequency_penalty: float | None = None
+    logprobs: bool | None = None
+    top_logprobs: int | None = None
+    seed: int | None = None
+    response_format: JsonValue | None = None
+    max_tokens: int | None = None
+    max_completion_tokens: int | None = None
+    store: bool | None = None
+
+    @model_validator(mode="after")
+    def _validate_messages(self) -> "ChatCompletionsRequest":
+        if not self.messages:
+            raise ValueError("'messages' must be a non-empty list.")
+        return self
+
+    def to_responses_request(self) -> ResponsesRequest:
+        data = self.model_dump(mode="json", exclude_none=True)
+        messages = data.pop("messages")
+        data.pop("store", None)
+        data.pop("max_tokens", None)
+        data.pop("max_completion_tokens", None)
+        response_format = data.pop("response_format", None)
+        tools = _normalize_chat_tools(data.pop("tools", []))
+        tool_choice = _normalize_tool_choice(data.pop("tool_choice", None))
+        reasoning_effort = data.pop("reasoning_effort", None)
+        if reasoning_effort is not None and "reasoning" not in data:
+            data["reasoning"] = {"effort": reasoning_effort}
+        if response_format is not None:
+            _apply_response_format(data, response_format)
+        instructions, input_items = coerce_messages("", messages)
+        data["instructions"] = instructions
+        data["input"] = input_items
+        data["tools"] = tools
+        if tool_choice is not None:
+            data["tool_choice"] = tool_choice
+        return ResponsesRequest.model_validate(data)
+
+
+class ChatResponseFormatJsonSchema(BaseModel):
+    model_config = ConfigDict(extra="allow", populate_by_name=True)
+
+    name: str | None = None
+    schema_: JsonValue | None = Field(default=None, alias="schema")
+    strict: bool | None = None
+
+
+class ChatResponseFormat(BaseModel):
+    model_config = ConfigDict(extra="allow")
+
+    type: str = Field(min_length=1)
+    json_schema: ChatResponseFormatJsonSchema | None = None
+
+    @model_validator(mode="after")
+    def _validate_schema(self) -> "ChatResponseFormat":
+        if self.type == "json_schema" and self.json_schema is None:
+            raise ValueError("'response_format.json_schema' is required when type is 'json_schema'.")
+        return self
+
+
+def _normalize_chat_tools(tools: list[JsonValue]) -> list[JsonValue]:
+    normalized: list[JsonValue] = []
+    for tool in tools:
+        if not isinstance(tool, dict):
+            continue
+        tool_type = tool.get("type")
+        function = tool.get("function")
+        if isinstance(function, dict):
+            name = function.get("name")
+            if not isinstance(name, str) or not name:
+                continue
+            normalized.append(
+                {
+                    "type": tool_type or "function",
+                    "name": name,
+                    "description": function.get("description"),
+                    "parameters": function.get("parameters"),
+                }
+            )
+            continue
+        name = tool.get("name")
+        if isinstance(name, str) and name:
+            normalized.append(tool)
+    return normalized
+
+
+def _normalize_tool_choice(tool_choice: JsonValue | None) -> JsonValue | None:
+    if not isinstance(tool_choice, dict):
+        return tool_choice
+    tool_type = tool_choice.get("type")
+    function = tool_choice.get("function")
+    if isinstance(function, dict):
+        name = function.get("name")
+        if isinstance(name, str) and name:
+            return {"type": tool_type or "function", "name": name}
+    return tool_choice
+
+
+def _apply_response_format(data: dict[str, JsonValue], response_format: JsonValue) -> None:
+    text_controls = _parse_text_controls(data.get("text"))
+    if text_controls is None:
+        text_controls = ResponsesTextControls()
+    if text_controls.format is not None:
+        raise ValueError("Provide either 'response_format' or 'text.format', not both.")
+    text_controls.format = _response_format_to_text_format(response_format)
+    data["text"] = cast(JsonValue, text_controls.model_dump(mode="json", exclude_none=True))
+
+
+def _parse_text_controls(text: JsonValue | None) -> ResponsesTextControls | None:
+    if text is None:
+        return None
+    if not isinstance(text, Mapping):
+        raise ValueError("'text' must be an object when using 'response_format'.")
+    return ResponsesTextControls.model_validate(text)
+
+
+def _response_format_to_text_format(response_format: JsonValue) -> ResponsesTextFormat:
+    if isinstance(response_format, str):
+        return _text_format_from_type(response_format)
+    if isinstance(response_format, Mapping):
+        parsed = ChatResponseFormat.model_validate(response_format)
+        return _text_format_from_parsed(parsed)
+    raise ValueError("'response_format' must be a string or object.")
+
+
+def _text_format_from_type(format_type: str) -> ResponsesTextFormat:
+    if format_type in ("json_object", "text"):
+        return ResponsesTextFormat(type=format_type)
+    if format_type == "json_schema":
+        raise ValueError("'response_format' must include 'json_schema' when type is 'json_schema'.")
+    raise ValueError(f"Unsupported response_format.type: {format_type}")
+
+
+def _text_format_from_parsed(parsed: ChatResponseFormat) -> ResponsesTextFormat:
+    if parsed.type == "json_schema":
+        json_schema = parsed.json_schema
+        if json_schema is None:
+            raise ValueError("'response_format.json_schema' is required when type is 'json_schema'.")
+        return ResponsesTextFormat(
+            type=parsed.type,
+            schema_=json_schema.schema_,
+            name=json_schema.name,
+            strict=json_schema.strict,
+        )
+    if parsed.type in ("json_object", "text"):
+        return ResponsesTextFormat(type=parsed.type)
+    raise ValueError(f"Unsupported response_format.type: {parsed.type}")
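to_responses_request() is the bridge that lets the proxy accept Chat Completions payloads and forward them as Responses API requests. A usage sketch (the sample payload is illustrative; the transformations noted in comments follow the code above):

from app.core.openai.chat_requests import ChatCompletionsRequest

chat = ChatCompletionsRequest.model_validate(
    {
        "model": "gpt-5",
        "messages": [
            {"role": "system", "content": "You are terse."},
            {"role": "user", "content": "Ping?"},
        ],
        "response_format": {"type": "json_object"},
        "tools": [
            {
                "type": "function",
                "function": {
                    "name": "lookup",
                    "description": "Illustrative tool",
                    "parameters": {"type": "object", "properties": {}},
                },
            }
        ],
    }
)

responses_request = chat.to_responses_request()
# - messages are coerced into instructions + input items via coerce_messages
# - response_format becomes text.format on the Responses request
# - the nested Chat-style {"type": "function", "function": {...}} tool is
#   flattened to {"type": "function", "name": "lookup", ...}
# - max_tokens / max_completion_tokens / store are dropped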