codex-lb 0.4.0__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- app/core/config/settings.py +8 -8
- app/core/handlers/__init__.py +3 -0
- app/core/handlers/exceptions.py +39 -0
- app/core/middleware/__init__.py +9 -0
- app/core/middleware/api_errors.py +33 -0
- app/core/middleware/request_decompression.py +158 -0
- app/core/middleware/request_id.py +27 -0
- app/core/openai/chat_requests.py +172 -0
- app/core/openai/chat_responses.py +534 -0
- app/core/openai/message_coercion.py +60 -0
- app/core/openai/models_catalog.py +72 -0
- app/core/openai/requests.py +4 -4
- app/core/openai/v1_requests.py +4 -60
- app/db/session.py +25 -8
- app/dependencies.py +43 -16
- app/main.py +12 -67
- app/modules/accounts/repository.py +21 -9
- app/modules/proxy/api.py +58 -0
- app/modules/proxy/load_balancer.py +75 -58
- app/modules/proxy/repo_bundle.py +23 -0
- app/modules/proxy/service.py +98 -102
- app/modules/request_logs/repository.py +3 -0
- app/modules/usage/service.py +65 -4
- {codex_lb-0.4.0.dist-info → codex_lb-0.5.1.dist-info}/METADATA +4 -2
- {codex_lb-0.4.0.dist-info → codex_lb-0.5.1.dist-info}/RECORD +28 -17
- {codex_lb-0.4.0.dist-info → codex_lb-0.5.1.dist-info}/WHEEL +0 -0
- {codex_lb-0.4.0.dist-info → codex_lb-0.5.1.dist-info}/entry_points.txt +0 -0
- {codex_lb-0.4.0.dist-info → codex_lb-0.5.1.dist-info}/licenses/LICENSE +0 -0
app/core/config/settings.py
CHANGED
|
@@ -3,7 +3,7 @@ from __future__ import annotations
|
|
|
3
3
|
from functools import lru_cache
|
|
4
4
|
from pathlib import Path
|
|
5
5
|
|
|
6
|
-
from pydantic import field_validator
|
|
6
|
+
from pydantic import Field, field_validator
|
|
7
7
|
from pydantic_settings import BaseSettings, SettingsConfigDict
|
|
8
8
|
|
|
9
9
|
BASE_DIR = Path(__file__).resolve().parents[3]
|
|
@@ -22,6 +22,9 @@ class Settings(BaseSettings):
|
|
|
22
22
|
)
|
|
23
23
|
|
|
24
24
|
database_url: str = f"sqlite+aiosqlite:///{DEFAULT_DB_PATH}"
|
|
25
|
+
database_pool_size: int = Field(default=15, gt=0)
|
|
26
|
+
database_max_overflow: int = Field(default=10, ge=0)
|
|
27
|
+
database_pool_timeout_seconds: float = Field(default=30.0, gt=0)
|
|
25
28
|
upstream_base_url: str = "https://chatgpt.com/backend-api"
|
|
26
29
|
upstream_connect_timeout_seconds: float = 30.0
|
|
27
30
|
stream_idle_timeout_seconds: float = 300.0
|
|
@@ -43,24 +46,21 @@ class Settings(BaseSettings):
|
|
|
43
46
|
log_proxy_request_shape: bool = False
|
|
44
47
|
log_proxy_request_shape_raw_cache_key: bool = False
|
|
45
48
|
log_proxy_request_payload: bool = False
|
|
49
|
+
max_decompressed_body_bytes: int = Field(default=32 * 1024 * 1024, gt=0)
|
|
46
50
|
|
|
47
51
|
@field_validator("database_url")
|
|
48
52
|
@classmethod
|
|
49
|
-
def
|
|
50
|
-
if not isinstance(value, str):
|
|
51
|
-
return value
|
|
52
|
-
|
|
53
|
+
def _expand_database_url(cls, value: str) -> str:
|
|
53
54
|
for prefix in ("sqlite+aiosqlite:///", "sqlite:///"):
|
|
54
55
|
if value.startswith(prefix):
|
|
55
56
|
path = value[len(prefix) :]
|
|
56
57
|
if path.startswith("~"):
|
|
57
|
-
|
|
58
|
-
return f"{prefix}{expanded}"
|
|
58
|
+
return f"{prefix}{Path(path).expanduser()}"
|
|
59
59
|
return value
|
|
60
60
|
|
|
61
61
|
@field_validator("encryption_key_file", mode="before")
|
|
62
62
|
@classmethod
|
|
63
|
-
def
|
|
63
|
+
def _expand_encryption_key_file(cls, value: str | Path) -> Path:
|
|
64
64
|
if isinstance(value, Path):
|
|
65
65
|
return value.expanduser()
|
|
66
66
|
if isinstance(value, str):
|
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from fastapi import FastAPI, Request
|
|
4
|
+
from fastapi.exception_handlers import (
|
|
5
|
+
http_exception_handler,
|
|
6
|
+
request_validation_exception_handler,
|
|
7
|
+
)
|
|
8
|
+
from fastapi.exceptions import RequestValidationError
|
|
9
|
+
from fastapi.responses import JSONResponse, Response
|
|
10
|
+
from starlette.exceptions import HTTPException as StarletteHTTPException
|
|
11
|
+
|
|
12
|
+
from app.core.errors import dashboard_error
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def add_exception_handlers(app: FastAPI) -> None:
    """Register error handlers that emit ``dashboard_error`` JSON for ``/api/`` routes.

    Requests outside ``/api/`` fall back to FastAPI's default handlers.

    Args:
        app: The FastAPI application to register the handlers on.
    """

    @app.exception_handler(RequestValidationError)
    async def validation_error_handler(
        request: Request,
        exc: RequestValidationError,
    ) -> Response:
        if request.url.path.startswith("/api/"):
            return JSONResponse(
                status_code=422,
                content=dashboard_error("validation_error", "Invalid request payload"),
            )
        return await request_validation_exception_handler(request, exc)

    @app.exception_handler(StarletteHTTPException)
    async def http_error_handler(
        request: Request,
        exc: StarletteHTTPException,
    ) -> Response:
        # Statuses that must not carry a body (1xx, 204, 205, 304) are delegated to
        # FastAPI's default handler, which returns an empty response with the
        # exception's headers instead of a JSON payload.
        body_allowed = exc.status_code >= 200 and exc.status_code not in {204, 205, 304}
        if request.url.path.startswith("/api/") and body_allowed:
            detail = exc.detail if isinstance(exc.detail, str) else "Request failed"
            return JSONResponse(
                status_code=exc.status_code,
                content=dashboard_error(f"http_{exc.status_code}", detail),
                # Keep headers attached to the exception (e.g. WWW-Authenticate on
                # 401, Allow on 405); dropping them breaks HTTP semantics.
                headers=getattr(exc, "headers", None),
            )
        return await http_exception_handler(request, exc)
|
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
from app.core.middleware.api_errors import add_api_unhandled_error_middleware
|
|
2
|
+
from app.core.middleware.request_decompression import add_request_decompression_middleware
|
|
3
|
+
from app.core.middleware.request_id import add_request_id_middleware
|
|
4
|
+
|
|
5
|
+
# Public middleware installers re-exported by this package (alphabetical).
__all__ = [
    "add_api_unhandled_error_middleware",
    "add_request_decompression_middleware",
    "add_request_id_middleware",
]
|
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
from collections.abc import Awaitable, Callable
|
|
5
|
+
|
|
6
|
+
from fastapi import FastAPI, Request
|
|
7
|
+
from fastapi.responses import JSONResponse, Response
|
|
8
|
+
|
|
9
|
+
from app.core.errors import dashboard_error
|
|
10
|
+
from app.core.utils.request_id import get_request_id
|
|
11
|
+
|
|
12
|
+
logger = logging.getLogger(__name__)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def add_api_unhandled_error_middleware(app: FastAPI) -> None:
    """Register middleware that converts uncaught errors on ``/api/`` paths into JSON 500s.

    The exception is logged with the current request id. Requests outside
    ``/api/`` re-raise the exception unchanged.
    """

    @app.middleware("http")
    async def api_unhandled_error_middleware(
        request: Request,
        call_next: Callable[[Request], Awaitable[Response]],
    ) -> Response:
        try:
            return await call_next(request)
        except Exception:
            if not request.url.path.startswith("/api/"):
                raise
            logger.exception("Unhandled API error request_id=%s", get_request_id())
            return JSONResponse(
                status_code=500,
                content=dashboard_error("internal_error", "Unexpected error"),
            )
|
|
@@ -0,0 +1,158 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import gzip
|
|
4
|
+
import io
|
|
5
|
+
import zlib
|
|
6
|
+
from collections.abc import Awaitable, Callable
|
|
7
|
+
from typing import Protocol
|
|
8
|
+
|
|
9
|
+
import zstandard as zstd
|
|
10
|
+
from fastapi import FastAPI, Request
|
|
11
|
+
from fastapi.responses import JSONResponse, Response
|
|
12
|
+
|
|
13
|
+
from app.core.config.settings import get_settings
|
|
14
|
+
from app.core.errors import dashboard_error
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class _DecompressedBodyTooLarge(Exception):
    """Signals that a decompressed request body grew past the configured cap.

    The limit that was exceeded is available as ``max_size``.
    """

    def __init__(self, max_size: int) -> None:
        message = f"Decompressed body exceeded {max_size} bytes"
        super().__init__(message)
        self.max_size = max_size
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class _Readable(Protocol):
    """Structural type for objects exposing ``read(size) -> bytes``."""

    def read(self, size: int = ...) -> bytes:
        """Return up to *size* bytes; ``_read_limited`` treats an empty result as end of stream."""
        ...
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def _read_limited(reader: _Readable, max_size: int) -> bytes:
    """Drain *reader* in 64 KiB reads, refusing to collect more than *max_size* bytes.

    Args:
        reader: Any object with ``read(size)``; an empty read ends the loop.
        max_size: Largest total number of bytes accepted.

    Returns:
        Everything read from *reader*.

    Raises:
        _DecompressedBodyTooLarge: as soon as the running total exceeds *max_size*.
    """
    read_size = 64 * 1024
    collected = bytearray()
    seen = 0
    while piece := reader.read(read_size):
        seen += len(piece)
        if seen > max_size:
            raise _DecompressedBodyTooLarge(max_size)
        collected += piece
    return bytes(collected)
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def _decompress_gzip(data: bytes, max_size: int) -> bytes:
    """Inflate a gzip-encoded body, capping the output at *max_size* bytes.

    Raises:
        _DecompressedBodyTooLarge: propagated from ``_read_limited``.
    """
    source = io.BytesIO(data)
    with gzip.GzipFile(fileobj=source) as gz_stream:
        inflated = _read_limited(gz_stream, max_size)
    return inflated
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def _decompress_deflate(data: bytes, max_size: int) -> bytes:
    """Inflate a zlib-wrapped ("deflate") body, capping the output at *max_size* bytes.

    Input is fed in 64 KiB pieces. Every ``decompress`` call passes a
    ``max_length``, so memory growth stays bounded even for highly compressible
    payloads. Output of exactly *max_size* bytes is accepted, which matches
    ``_read_limited``: only totals strictly greater than the limit are rejected.

    Args:
        data: The compressed request body.
        max_size: Largest accepted decompressed size in bytes.

    Returns:
        The decompressed bytes.

    Raises:
        _DecompressedBodyTooLarge: if the output would exceed *max_size*.
        zlib.error: if the stream is malformed or ends before its trailer.
    """
    decompressor = zlib.decompressobj()
    buffer = bytearray()
    chunk_size = 64 * 1024
    for start in range(0, len(data), chunk_size):
        chunk = data[start : start + chunk_size]
        while chunk:
            # Ask for one byte beyond the remaining budget. Receiving it proves the
            # output is too large, while a stream ending exactly at the limit still
            # fits. The budget is always >= 1 here, so max_length never becomes 0
            # (which zlib treats as "unlimited").
            buffer.extend(
                decompressor.decompress(chunk, max_length=max_size - len(buffer) + 1)
            )
            if len(buffer) > max_size:
                raise _DecompressedBodyTooLarge(max_size)
            chunk = decompressor.unconsumed_tail
    # Drain any output zlib may still hold after all input has been consumed.
    while True:
        drained = decompressor.decompress(b"", max_length=max_size - len(buffer) + 1)
        if not drained:
            break
        buffer.extend(drained)
        if len(buffer) > max_size:
            raise _DecompressedBodyTooLarge(max_size)
    if not decompressor.eof:
        raise zlib.error("Incomplete deflate stream")
    return bytes(buffer)
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def _decompress_zstd(data: bytes, max_size: int) -> bytes:
    """Decompress a zstd-encoded body through a bounded streaming reader.

    The python-zstandard docs warn that the one-shot
    ``ZstdDecompressor.decompress`` can perform a very large allocation. Its
    buffer may be sized from the content size declared in the frame header, so
    a small crafted payload could force a huge allocation before any length
    check runs. Streaming through ``_read_limited`` keeps memory bounded by
    *max_size*, whatever the header claims.

    Args:
        data: The compressed request body.
        max_size: Largest accepted decompressed size in bytes.

    Returns:
        The decompressed bytes of the first zstd frame.

    Raises:
        _DecompressedBodyTooLarge: if the output exceeds *max_size*.
        zstd.ZstdError: if *data* is not valid zstd input.
    """
    with zstd.ZstdDecompressor().stream_reader(io.BytesIO(data)) as reader:
        return _read_limited(reader, max_size)
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def _decompress_body(data: bytes, encodings: list[str], max_size: int) -> bytes:
|
|
87
|
+
supported = {"zstd", "gzip", "deflate", "identity"}
|
|
88
|
+
if any(encoding not in supported for encoding in encodings):
|
|
89
|
+
raise ValueError("Unsupported content-encoding")
|
|
90
|
+
result = data
|
|
91
|
+
for encoding in reversed(encodings):
|
|
92
|
+
if encoding == "zstd":
|
|
93
|
+
result = _decompress_zstd(result, max_size)
|
|
94
|
+
elif encoding == "gzip":
|
|
95
|
+
result = _decompress_gzip(result, max_size)
|
|
96
|
+
elif encoding == "deflate":
|
|
97
|
+
result = _decompress_deflate(result, max_size)
|
|
98
|
+
elif encoding == "identity":
|
|
99
|
+
continue
|
|
100
|
+
return result
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
def _replace_request_body(request: Request, body: bytes) -> None:
    """Swap the cached body on *request* for *body* and fix up its framing headers.

    Sets the request's private ``_body`` attribute. It then drops every
    ``content-encoding`` and ``content-length`` header from the scope and
    appends a ``content-length`` matching *body*. Finally it clears ``_headers``
    so the header view is rebuilt from the new scope.
    """
    request._body = body
    stripped = {b"content-encoding", b"content-length"}
    kept = [
        (name, value)
        for name, value in request.scope.get("headers", [])
        if name.lower() not in stripped
    ]
    kept.append((b"content-length", f"{len(body)}".encode("ascii")))
    request.scope["headers"] = kept
    # Ensure subsequent request.headers reflects the updated scope headers.
    request._headers = None
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
def add_request_decompression_middleware(app: FastAPI) -> None:
    """Register middleware that decodes ``Content-Encoding``-compressed request bodies.

    Requests without a non-empty ``Content-Encoding`` pass through untouched.
    Otherwise the body is decoded with ``_decompress_body``, under the
    ``max_decompressed_body_bytes`` setting, and swapped into the request
    before the route runs. Failures short-circuit with a ``dashboard_error``
    JSON body:
    - 413 when the cap is exceeded;
    - 400 for a ``ValueError`` (unsupported encoding);
    - 400 for any other decoding failure.
    """

    def _error(status_code: int, code: str, message: str) -> JSONResponse:
        # Shared shape for every early-exit error response below.
        return JSONResponse(status_code=status_code, content=dashboard_error(code, message))

    @app.middleware("http")
    async def request_decompression_middleware(
        request: Request,
        call_next: Callable[[Request], Awaitable[Response]],
    ) -> Response:
        header_value = request.headers.get("content-encoding")
        codings: list[str] = []
        if header_value:
            # Comma-separated list, lower-cased; blank tokens are ignored.
            codings = [
                token.strip().lower() for token in header_value.split(",") if token.strip()
            ]
        if not codings:
            return await call_next(request)

        raw_body = await request.body()
        limit = get_settings().max_decompressed_body_bytes
        try:
            decoded = _decompress_body(raw_body, codings, limit)
        except _DecompressedBodyTooLarge:
            return _error(
                413,
                "payload_too_large",
                "Request body exceeds the maximum allowed size",
            )
        except ValueError:
            return _error(400, "invalid_request", "Unsupported Content-Encoding")
        except Exception:
            return _error(
                400,
                "invalid_request",
                "Request body is compressed but could not be decompressed",
            )
        _replace_request_body(request, decoded)
        return await call_next(request)
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from collections.abc import Awaitable, Callable
|
|
4
|
+
from uuid import uuid4
|
|
5
|
+
|
|
6
|
+
from fastapi import FastAPI, Request
|
|
7
|
+
from fastapi.responses import JSONResponse
|
|
8
|
+
|
|
9
|
+
from app.core.utils.request_id import reset_request_id, set_request_id
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def add_request_id_middleware(app: FastAPI) -> None:
    """Register middleware that assigns a request id to every HTTP request.

    The id is taken from an inbound ``x-request-id`` header, then ``request-id``,
    and otherwise generated with ``uuid4``. It is stored with ``set_request_id``
    (presumably a ContextVar, given the token/reset pair) while the downstream
    call runs. It is echoed as ``x-request-id`` on the response unless that
    header is already set.
    """

    @app.middleware("http")
    async def request_id_middleware(
        request: Request,
        call_next: Callable[[Request], Awaitable[JSONResponse]],
    ) -> JSONResponse:
        inbound_request_id = request.headers.get("x-request-id") or request.headers.get("request-id")
        request_id = inbound_request_id or str(uuid4())
        token = set_request_id(request_id)
        try:
            response = await call_next(request)
        finally:
            # Restore the previous value on every exit path, not only on error, so
            # this request's id does not outlive the request in the current context.
            reset_request_id(token)
        response.headers.setdefault("x-request-id", request_id)
        return response
|
|
@@ -0,0 +1,172 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from collections.abc import Mapping
|
|
4
|
+
from typing import cast
|
|
5
|
+
|
|
6
|
+
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
|
7
|
+
|
|
8
|
+
from app.core.openai.message_coercion import coerce_messages
|
|
9
|
+
from app.core.openai.requests import ResponsesRequest, ResponsesTextControls, ResponsesTextFormat
|
|
10
|
+
from app.core.types import JsonValue
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class ChatCompletionsRequest(BaseModel):
    """Chat Completions request body as accepted by the proxy.

    Unknown keys are retained (``extra="allow"``), so extensions such as
    ``reasoning_effort`` reach ``to_responses_request``.
    """

    model_config = ConfigDict(extra="allow")

    model: str = Field(min_length=1)
    messages: list[dict[str, JsonValue]]
    tools: list[JsonValue] = Field(default_factory=list)
    tool_choice: str | dict[str, JsonValue] | None = None
    parallel_tool_calls: bool | None = None
    stream: bool | None = None
    temperature: float | None = None
    top_p: float | None = None
    stop: str | list[str] | None = None
    n: int | None = None
    presence_penalty: float | None = None
    frequency_penalty: float | None = None
    logprobs: bool | None = None
    top_logprobs: int | None = None
    seed: int | None = None
    response_format: JsonValue | None = None
    max_tokens: int | None = None
    max_completion_tokens: int | None = None
    store: bool | None = None

    @model_validator(mode="after")
    def _validate_messages(self) -> "ChatCompletionsRequest":
        # An empty conversation cannot be translated into a Responses request.
        if self.messages:
            return self
        raise ValueError("'messages' must be a non-empty list.")

    def to_responses_request(self) -> ResponsesRequest:
        """Translate this request into a ``ResponsesRequest``.

        - ``store``, ``max_tokens`` and ``max_completion_tokens`` are discarded.
        - ``reasoning_effort`` becomes ``reasoning.effort`` unless ``reasoning``
          is already present.
        - ``response_format`` is folded into ``text.format``.
        - ``tools`` and ``tool_choice`` pass through the normalizers.
        - Messages are split by ``coerce_messages`` into ``instructions`` and
          ``input``.
        """
        payload = self.model_dump(mode="json", exclude_none=True)
        chat_messages = payload.pop("messages")
        for dropped in ("store", "max_tokens", "max_completion_tokens"):
            payload.pop(dropped, None)
        response_format = payload.pop("response_format", None)
        normalized_tools = _normalize_chat_tools(payload.pop("tools", []))
        normalized_choice = _normalize_tool_choice(payload.pop("tool_choice", None))
        effort = payload.pop("reasoning_effort", None)
        if effort is not None and "reasoning" not in payload:
            payload["reasoning"] = {"effort": effort}
        if response_format is not None:
            _apply_response_format(payload, response_format)
        instructions, input_items = coerce_messages("", chat_messages)
        payload.update(instructions=instructions, input=input_items, tools=normalized_tools)
        if normalized_choice is not None:
            payload["tool_choice"] = normalized_choice
        return ResponsesRequest.model_validate(payload)
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
class ChatResponseFormatJsonSchema(BaseModel):
    """The ``json_schema`` object inside a Chat Completions ``response_format``.

    The wire key ``schema`` is stored as ``schema_`` through an alias,
    presumably to avoid shadowing ``BaseModel.schema``. Either name is accepted
    on input, and extra keys are kept.
    """

    model_config = ConfigDict(populate_by_name=True, extra="allow")

    name: str | None = None
    schema_: JsonValue | None = Field(None, alias="schema")
    strict: bool | None = None
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
class ChatResponseFormat(BaseModel):
    """Object form of a Chat Completions ``response_format``; extra keys are kept."""

    model_config = ConfigDict(extra="allow")

    type: str = Field(min_length=1)
    json_schema: ChatResponseFormatJsonSchema | None = None

    @model_validator(mode="after")
    def _validate_schema(self) -> "ChatResponseFormat":
        # A json_schema format is meaningless without the accompanying schema object.
        needs_schema = self.type == "json_schema"
        if needs_schema and self.json_schema is None:
            raise ValueError("'response_format.json_schema' is required when type is 'json_schema'.")
        return self
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def _normalize_chat_tools(tools: list[JsonValue]) -> list[JsonValue]:
    """Convert Chat Completions tool definitions into the flat Responses shape.

    ``{"type": "function", "function": {"name": ..., ...}}`` becomes
    ``{"type": "function", "name": ..., "description": ..., "parameters": ...}``.
    A boolean ``function.strict`` is carried over as ``"strict"``. Tools that are
    already flat (top-level ``name``) pass through unchanged. Non-dict entries
    and tools without a non-empty string name are dropped.

    Args:
        tools: Raw ``tools`` list from the request.

    Returns:
        The normalized tool list.
    """
    normalized: list[JsonValue] = []
    for tool in tools:
        if not isinstance(tool, dict):
            continue
        function = tool.get("function")
        if isinstance(function, dict):
            name = function.get("name")
            if not isinstance(name, str) or not name:
                continue
            converted: dict[str, JsonValue] = {
                "type": tool.get("type") or "function",
                "name": name,
                "description": function.get("description"),
                "parameters": function.get("parameters"),
            }
            # Chat Completions nests ``strict`` under ``function``. Without copying it,
            # the strict-schema behaviour the client asked for would be silently lost.
            strict = function.get("strict")
            if isinstance(strict, bool):
                converted["strict"] = strict
            normalized.append(converted)
            continue
        name = tool.get("name")
        if isinstance(name, str) and name:
            normalized.append(tool)
    return normalized
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
def _normalize_tool_choice(tool_choice: JsonValue | None) -> JsonValue | None:
    """Flatten a Chat Completions ``tool_choice`` object into the Responses shape.

    ``{"type": "function", "function": {"name": "x"}}`` becomes
    ``{"type": "function", "name": "x"}``. Strings, ``None`` and objects without
    a usable ``function.name`` are returned as-is.
    """
    if isinstance(tool_choice, dict):
        function = tool_choice.get("function")
        name = function.get("name") if isinstance(function, dict) else None
        if isinstance(name, str) and name:
            return {"type": tool_choice.get("type") or "function", "name": name}
    return tool_choice
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
def _apply_response_format(data: dict[str, JsonValue], response_format: JsonValue) -> None:
    """Fold a Chat Completions ``response_format`` into ``data["text"]``, in place.

    An existing ``text`` object is preserved. Supplying both ``response_format``
    and ``text.format`` raises ``ValueError``.
    """
    existing = _parse_text_controls(data.get("text"))
    controls = ResponsesTextControls() if existing is None else existing
    if controls.format is not None:
        raise ValueError("Provide either 'response_format' or 'text.format', not both.")
    controls.format = _response_format_to_text_format(response_format)
    data["text"] = cast(JsonValue, controls.model_dump(mode="json", exclude_none=True))
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
def _parse_text_controls(text: JsonValue | None) -> ResponsesTextControls | None:
    """Validate an incoming ``text`` value as ``ResponsesTextControls``.

    Returns ``None`` when *text* is absent. Raises ``ValueError`` for any
    non-mapping value.
    """
    if text is None:
        return None
    if isinstance(text, Mapping):
        return ResponsesTextControls.model_validate(text)
    raise ValueError("'text' must be an object when using 'response_format'.")
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
def _response_format_to_text_format(response_format: JsonValue) -> ResponsesTextFormat:
    """Convert a ``response_format`` into a ``ResponsesTextFormat``.

    Accepts either the object form or a bare type string; anything else raises
    ``ValueError``.
    """
    if isinstance(response_format, Mapping):
        return _text_format_from_parsed(ChatResponseFormat.model_validate(response_format))
    if isinstance(response_format, str):
        return _text_format_from_type(response_format)
    raise ValueError("'response_format' must be a string or object.")
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
def _text_format_from_type(format_type: str) -> ResponsesTextFormat:
    """Build a ``ResponsesTextFormat`` from a bare ``response_format`` type string.

    Only ``"text"`` and ``"json_object"`` are valid in string form.
    ``"json_schema"`` needs the object form because it carries a schema.
    """
    if format_type == "json_schema":
        raise ValueError("'response_format' must include 'json_schema' when type is 'json_schema'.")
    if format_type not in ("json_object", "text"):
        raise ValueError(f"Unsupported response_format.type: {format_type}")
    return ResponsesTextFormat(type=format_type)
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
def _text_format_from_parsed(parsed: ChatResponseFormat) -> ResponsesTextFormat:
    """Build a ``ResponsesTextFormat`` from a validated ``ChatResponseFormat``.

    ``json_schema`` formats copy ``schema``, ``name`` and ``strict`` across;
    ``text`` and ``json_object`` carry only the type. Any other type raises
    ``ValueError``.
    """
    kind = parsed.type
    if kind in ("json_object", "text"):
        return ResponsesTextFormat(type=kind)
    if kind != "json_schema":
        raise ValueError(f"Unsupported response_format.type: {kind}")
    schema = parsed.json_schema
    if schema is None:
        raise ValueError("'response_format.json_schema' is required when type is 'json_schema'.")
    return ResponsesTextFormat(
        type=kind,
        schema_=schema.schema_,
        name=schema.name,
        strict=schema.strict,
    )
|