codex-lb 0.4.0__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- app/core/config/settings.py +8 -8
- app/core/handlers/__init__.py +3 -0
- app/core/handlers/exceptions.py +39 -0
- app/core/middleware/__init__.py +9 -0
- app/core/middleware/api_errors.py +33 -0
- app/core/middleware/request_decompression.py +101 -0
- app/core/middleware/request_id.py +27 -0
- app/core/openai/chat_requests.py +172 -0
- app/core/openai/chat_responses.py +534 -0
- app/core/openai/message_coercion.py +60 -0
- app/core/openai/models_catalog.py +72 -0
- app/core/openai/requests.py +4 -4
- app/core/openai/v1_requests.py +4 -60
- app/db/session.py +25 -8
- app/dependencies.py +43 -16
- app/main.py +12 -67
- app/modules/accounts/repository.py +21 -9
- app/modules/proxy/api.py +58 -0
- app/modules/proxy/load_balancer.py +75 -58
- app/modules/proxy/repo_bundle.py +23 -0
- app/modules/proxy/service.py +98 -102
- app/modules/request_logs/repository.py +3 -0
- app/modules/usage/service.py +65 -4
- {codex_lb-0.4.0.dist-info → codex_lb-0.5.0.dist-info}/METADATA +3 -2
- {codex_lb-0.4.0.dist-info → codex_lb-0.5.0.dist-info}/RECORD +28 -17
- {codex_lb-0.4.0.dist-info → codex_lb-0.5.0.dist-info}/WHEEL +0 -0
- {codex_lb-0.4.0.dist-info → codex_lb-0.5.0.dist-info}/entry_points.txt +0 -0
- {codex_lb-0.4.0.dist-info → codex_lb-0.5.0.dist-info}/licenses/LICENSE +0 -0
app/core/openai/v1_requests.py
CHANGED
|
@@ -1,9 +1,8 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
-
from typing import cast
|
|
4
|
-
|
|
5
3
|
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
|
|
6
4
|
|
|
5
|
+
from app.core.openai.message_coercion import coerce_messages
|
|
7
6
|
from app.core.openai.requests import (
|
|
8
7
|
ResponsesCompactRequest,
|
|
9
8
|
ResponsesReasoning,
|
|
@@ -21,7 +20,7 @@ class V1ResponsesRequest(BaseModel):
|
|
|
21
20
|
input: list[JsonValue] | None = None
|
|
22
21
|
instructions: str | None = None
|
|
23
22
|
tools: list[JsonValue] = Field(default_factory=list)
|
|
24
|
-
tool_choice: str | None = None
|
|
23
|
+
tool_choice: str | dict[str, JsonValue] | None = None
|
|
25
24
|
parallel_tool_calls: bool | None = None
|
|
26
25
|
reasoning: ResponsesReasoning | None = None
|
|
27
26
|
store: bool | None = None
|
|
@@ -54,7 +53,7 @@ class V1ResponsesRequest(BaseModel):
|
|
|
54
53
|
input_items: list[JsonValue] = input_value if isinstance(input_value, list) else []
|
|
55
54
|
|
|
56
55
|
if messages is not None:
|
|
57
|
-
instruction_text, input_items =
|
|
56
|
+
instruction_text, input_items = coerce_messages(instruction_text, messages)
|
|
58
57
|
|
|
59
58
|
data["instructions"] = instruction_text
|
|
60
59
|
data["input"] = input_items
|
|
@@ -86,63 +85,8 @@ class V1ResponsesCompactRequest(BaseModel):
|
|
|
86
85
|
input_items: list[JsonValue] = input_value if isinstance(input_value, list) else []
|
|
87
86
|
|
|
88
87
|
if messages is not None:
|
|
89
|
-
instruction_text, input_items =
|
|
88
|
+
instruction_text, input_items = coerce_messages(instruction_text, messages)
|
|
90
89
|
|
|
91
90
|
data["instructions"] = instruction_text
|
|
92
91
|
data["input"] = input_items
|
|
93
92
|
return ResponsesCompactRequest.model_validate(data)
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
def _coerce_messages(existing_instructions: str, messages: list[JsonValue]) -> tuple[str, list[JsonValue]]:
|
|
97
|
-
instruction_parts: list[str] = []
|
|
98
|
-
input_messages: list[JsonValue] = []
|
|
99
|
-
for message in messages:
|
|
100
|
-
if not isinstance(message, dict):
|
|
101
|
-
raise ValueError("Each message must be an object.")
|
|
102
|
-
message_dict = cast(dict[str, JsonValue], message)
|
|
103
|
-
role_value = message_dict.get("role")
|
|
104
|
-
role = role_value if isinstance(role_value, str) else None
|
|
105
|
-
if role in ("system", "developer"):
|
|
106
|
-
content_text = _content_to_text(message_dict.get("content"))
|
|
107
|
-
if content_text:
|
|
108
|
-
instruction_parts.append(content_text)
|
|
109
|
-
continue
|
|
110
|
-
input_messages.append(cast(JsonValue, message_dict))
|
|
111
|
-
merged = _merge_instructions(existing_instructions, instruction_parts)
|
|
112
|
-
return merged, input_messages
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
def _merge_instructions(existing: str, extra_parts: list[str]) -> str:
|
|
116
|
-
if not extra_parts:
|
|
117
|
-
return existing
|
|
118
|
-
extra = "\n".join([part for part in extra_parts if part])
|
|
119
|
-
if not extra:
|
|
120
|
-
return existing
|
|
121
|
-
if existing:
|
|
122
|
-
return f"{existing}\n{extra}"
|
|
123
|
-
return extra
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
def _content_to_text(content: object) -> str | None:
|
|
127
|
-
if content is None:
|
|
128
|
-
return None
|
|
129
|
-
if isinstance(content, str):
|
|
130
|
-
return content
|
|
131
|
-
if isinstance(content, list):
|
|
132
|
-
parts: list[str] = []
|
|
133
|
-
for part in content:
|
|
134
|
-
if isinstance(part, str):
|
|
135
|
-
parts.append(part)
|
|
136
|
-
elif isinstance(part, dict):
|
|
137
|
-
part_dict = cast(dict[str, JsonValue], part)
|
|
138
|
-
text = part_dict.get("text")
|
|
139
|
-
if isinstance(text, str):
|
|
140
|
-
parts.append(text)
|
|
141
|
-
return "\n".join([part for part in parts if part])
|
|
142
|
-
if isinstance(content, dict):
|
|
143
|
-
content_dict = cast(dict[str, JsonValue], content)
|
|
144
|
-
text = content_dict.get("text")
|
|
145
|
-
if isinstance(text, str):
|
|
146
|
-
return text
|
|
147
|
-
return None
|
|
148
|
-
return None
|
app/db/session.py
CHANGED
|
@@ -13,7 +13,7 @@ from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_asyn
|
|
|
13
13
|
from app.core.config.settings import get_settings
|
|
14
14
|
from app.db.migrations import run_migrations
|
|
15
15
|
|
|
16
|
-
|
|
16
|
+
_settings = get_settings()
|
|
17
17
|
|
|
18
18
|
logger = logging.getLogger(__name__)
|
|
19
19
|
|
|
@@ -43,15 +43,32 @@ def _configure_sqlite_engine(engine: Engine, *, enable_wal: bool) -> None:
|
|
|
43
43
|
cursor.close()
|
|
44
44
|
|
|
45
45
|
|
|
46
|
-
if _is_sqlite_url(
|
|
46
|
+
if _is_sqlite_url(_settings.database_url):
|
|
47
|
+
is_sqlite_memory = _is_sqlite_memory_url(_settings.database_url)
|
|
48
|
+
if is_sqlite_memory:
|
|
49
|
+
engine = create_async_engine(
|
|
50
|
+
_settings.database_url,
|
|
51
|
+
echo=False,
|
|
52
|
+
connect_args={"timeout": _SQLITE_BUSY_TIMEOUT_SECONDS},
|
|
53
|
+
)
|
|
54
|
+
else:
|
|
55
|
+
engine = create_async_engine(
|
|
56
|
+
_settings.database_url,
|
|
57
|
+
echo=False,
|
|
58
|
+
pool_size=_settings.database_pool_size,
|
|
59
|
+
max_overflow=_settings.database_max_overflow,
|
|
60
|
+
pool_timeout=_settings.database_pool_timeout_seconds,
|
|
61
|
+
connect_args={"timeout": _SQLITE_BUSY_TIMEOUT_SECONDS},
|
|
62
|
+
)
|
|
63
|
+
_configure_sqlite_engine(engine.sync_engine, enable_wal=not is_sqlite_memory)
|
|
64
|
+
else:
|
|
47
65
|
engine = create_async_engine(
|
|
48
|
-
|
|
66
|
+
_settings.database_url,
|
|
49
67
|
echo=False,
|
|
50
|
-
|
|
68
|
+
pool_size=_settings.database_pool_size,
|
|
69
|
+
max_overflow=_settings.database_max_overflow,
|
|
70
|
+
pool_timeout=_settings.database_pool_timeout_seconds,
|
|
51
71
|
)
|
|
52
|
-
_configure_sqlite_engine(engine.sync_engine, enable_wal=not _is_sqlite_memory_url(DATABASE_URL))
|
|
53
|
-
else:
|
|
54
|
-
engine = create_async_engine(DATABASE_URL, echo=False)
|
|
55
72
|
|
|
56
73
|
SessionLocal = async_sessionmaker(engine, expire_on_commit=False, class_=AsyncSession)
|
|
57
74
|
|
|
@@ -116,7 +133,7 @@ async def get_session() -> AsyncIterator[AsyncSession]:
|
|
|
116
133
|
async def init_db() -> None:
|
|
117
134
|
from app.db.models import Base
|
|
118
135
|
|
|
119
|
-
_ensure_sqlite_dir(
|
|
136
|
+
_ensure_sqlite_dir(_settings.database_url)
|
|
120
137
|
|
|
121
138
|
async with engine.begin() as conn:
|
|
122
139
|
await conn.run_sync(Base.metadata.create_all)
|
app/dependencies.py
CHANGED
|
@@ -11,6 +11,7 @@ from app.db.session import SessionLocal, _safe_close, _safe_rollback, get_sessio
|
|
|
11
11
|
from app.modules.accounts.repository import AccountsRepository
|
|
12
12
|
from app.modules.accounts.service import AccountsService
|
|
13
13
|
from app.modules.oauth.service import OauthService
|
|
14
|
+
from app.modules.proxy.repo_bundle import ProxyRepositories
|
|
14
15
|
from app.modules.proxy.service import ProxyService
|
|
15
16
|
from app.modules.proxy.sticky_repository import StickySessionsRepository
|
|
16
17
|
from app.modules.request_logs.repository import RequestLogsRepository
|
|
@@ -79,7 +80,12 @@ def get_usage_context(
|
|
|
79
80
|
usage_repository = UsageRepository(session)
|
|
80
81
|
request_logs_repository = RequestLogsRepository(session)
|
|
81
82
|
accounts_repository = AccountsRepository(session)
|
|
82
|
-
service = UsageService(
|
|
83
|
+
service = UsageService(
|
|
84
|
+
usage_repository,
|
|
85
|
+
request_logs_repository,
|
|
86
|
+
accounts_repository,
|
|
87
|
+
refresh_repo_factory=_usage_refresh_context,
|
|
88
|
+
)
|
|
83
89
|
return UsageContext(
|
|
84
90
|
session=session,
|
|
85
91
|
usage_repository=usage_repository,
|
|
@@ -101,6 +107,40 @@ async def _accounts_repo_context() -> AsyncIterator[AccountsRepository]:
|
|
|
101
107
|
await _safe_close(session)
|
|
102
108
|
|
|
103
109
|
|
|
110
|
+
@asynccontextmanager
|
|
111
|
+
async def _usage_refresh_context() -> AsyncIterator[tuple[UsageRepository, AccountsRepository]]:
|
|
112
|
+
session = SessionLocal()
|
|
113
|
+
try:
|
|
114
|
+
yield UsageRepository(session), AccountsRepository(session)
|
|
115
|
+
except BaseException:
|
|
116
|
+
await _safe_rollback(session)
|
|
117
|
+
raise
|
|
118
|
+
finally:
|
|
119
|
+
if session.in_transaction():
|
|
120
|
+
await _safe_rollback(session)
|
|
121
|
+
await _safe_close(session)
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
@asynccontextmanager
|
|
125
|
+
async def _proxy_repo_context() -> AsyncIterator[ProxyRepositories]:
|
|
126
|
+
session = SessionLocal()
|
|
127
|
+
try:
|
|
128
|
+
yield ProxyRepositories(
|
|
129
|
+
accounts=AccountsRepository(session),
|
|
130
|
+
usage=UsageRepository(session),
|
|
131
|
+
request_logs=RequestLogsRepository(session),
|
|
132
|
+
sticky_sessions=StickySessionsRepository(session),
|
|
133
|
+
settings=SettingsRepository(session),
|
|
134
|
+
)
|
|
135
|
+
except BaseException:
|
|
136
|
+
await _safe_rollback(session)
|
|
137
|
+
raise
|
|
138
|
+
finally:
|
|
139
|
+
if session.in_transaction():
|
|
140
|
+
await _safe_rollback(session)
|
|
141
|
+
await _safe_close(session)
|
|
142
|
+
|
|
143
|
+
|
|
104
144
|
def get_oauth_context(
|
|
105
145
|
session: AsyncSession = Depends(get_session),
|
|
106
146
|
) -> OauthContext:
|
|
@@ -108,21 +148,8 @@ def get_oauth_context(
|
|
|
108
148
|
return OauthContext(service=OauthService(accounts_repository, repo_factory=_accounts_repo_context))
|
|
109
149
|
|
|
110
150
|
|
|
111
|
-
def get_proxy_context(
|
|
112
|
-
|
|
113
|
-
) -> ProxyContext:
|
|
114
|
-
accounts_repository = AccountsRepository(session)
|
|
115
|
-
usage_repository = UsageRepository(session)
|
|
116
|
-
request_logs_repository = RequestLogsRepository(session)
|
|
117
|
-
sticky_repository = StickySessionsRepository(session)
|
|
118
|
-
settings_repository = SettingsRepository(session)
|
|
119
|
-
service = ProxyService(
|
|
120
|
-
accounts_repository,
|
|
121
|
-
usage_repository,
|
|
122
|
-
request_logs_repository,
|
|
123
|
-
sticky_repository,
|
|
124
|
-
settings_repository,
|
|
125
|
-
)
|
|
151
|
+
def get_proxy_context() -> ProxyContext:
|
|
152
|
+
service = ProxyService(repo_factory=_proxy_repo_context)
|
|
126
153
|
return ProxyContext(service=service)
|
|
127
154
|
|
|
128
155
|
|
app/main.py
CHANGED
|
@@ -1,23 +1,19 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
-
import logging
|
|
4
3
|
from contextlib import asynccontextmanager
|
|
5
4
|
from pathlib import Path
|
|
6
|
-
from uuid import uuid4
|
|
7
5
|
|
|
8
|
-
from fastapi import FastAPI
|
|
9
|
-
from fastapi.
|
|
10
|
-
http_exception_handler,
|
|
11
|
-
request_validation_exception_handler,
|
|
12
|
-
)
|
|
13
|
-
from fastapi.exceptions import RequestValidationError
|
|
14
|
-
from fastapi.responses import FileResponse, JSONResponse, RedirectResponse, Response
|
|
6
|
+
from fastapi import FastAPI
|
|
7
|
+
from fastapi.responses import FileResponse, RedirectResponse
|
|
15
8
|
from fastapi.staticfiles import StaticFiles
|
|
16
|
-
from starlette.exceptions import HTTPException as StarletteHTTPException
|
|
17
9
|
|
|
18
10
|
from app.core.clients.http import close_http_client, init_http_client
|
|
19
|
-
from app.core.
|
|
20
|
-
from app.core.
|
|
11
|
+
from app.core.handlers import add_exception_handlers
|
|
12
|
+
from app.core.middleware import (
|
|
13
|
+
add_api_unhandled_error_middleware,
|
|
14
|
+
add_request_decompression_middleware,
|
|
15
|
+
add_request_id_middleware,
|
|
16
|
+
)
|
|
21
17
|
from app.db.session import close_db, init_db
|
|
22
18
|
from app.modules.accounts import api as accounts_api
|
|
23
19
|
from app.modules.health import api as health_api
|
|
@@ -27,8 +23,6 @@ from app.modules.request_logs import api as request_logs_api
|
|
|
27
23
|
from app.modules.settings import api as settings_api
|
|
28
24
|
from app.modules.usage import api as usage_api
|
|
29
25
|
|
|
30
|
-
logger = logging.getLogger(__name__)
|
|
31
|
-
|
|
32
26
|
|
|
33
27
|
@asynccontextmanager
|
|
34
28
|
async def lifespan(_: FastAPI):
|
|
@@ -47,59 +41,10 @@ async def lifespan(_: FastAPI):
|
|
|
47
41
|
def create_app() -> FastAPI:
|
|
48
42
|
app = FastAPI(title="codex-lb", version="0.1.0", lifespan=lifespan)
|
|
49
43
|
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
token = set_request_id(request_id)
|
|
55
|
-
try:
|
|
56
|
-
response = await call_next(request)
|
|
57
|
-
except Exception:
|
|
58
|
-
reset_request_id(token)
|
|
59
|
-
raise
|
|
60
|
-
response.headers.setdefault("x-request-id", request_id)
|
|
61
|
-
return response
|
|
62
|
-
|
|
63
|
-
@app.middleware("http")
|
|
64
|
-
async def api_unhandled_error_middleware(request: Request, call_next) -> Response:
|
|
65
|
-
try:
|
|
66
|
-
return await call_next(request)
|
|
67
|
-
except Exception:
|
|
68
|
-
if request.url.path.startswith("/api/"):
|
|
69
|
-
logger.exception(
|
|
70
|
-
"Unhandled API error request_id=%s",
|
|
71
|
-
get_request_id(),
|
|
72
|
-
)
|
|
73
|
-
return JSONResponse(
|
|
74
|
-
status_code=500,
|
|
75
|
-
content=dashboard_error("internal_error", "Unexpected error"),
|
|
76
|
-
)
|
|
77
|
-
raise
|
|
78
|
-
|
|
79
|
-
@app.exception_handler(RequestValidationError)
|
|
80
|
-
async def _validation_error_handler(
|
|
81
|
-
request: Request,
|
|
82
|
-
exc: RequestValidationError,
|
|
83
|
-
) -> Response:
|
|
84
|
-
if request.url.path.startswith("/api/"):
|
|
85
|
-
return JSONResponse(
|
|
86
|
-
status_code=422,
|
|
87
|
-
content=dashboard_error("validation_error", "Invalid request payload"),
|
|
88
|
-
)
|
|
89
|
-
return await request_validation_exception_handler(request, exc)
|
|
90
|
-
|
|
91
|
-
@app.exception_handler(StarletteHTTPException)
|
|
92
|
-
async def _http_error_handler(
|
|
93
|
-
request: Request,
|
|
94
|
-
exc: StarletteHTTPException,
|
|
95
|
-
) -> Response:
|
|
96
|
-
if request.url.path.startswith("/api/"):
|
|
97
|
-
detail = exc.detail if isinstance(exc.detail, str) else "Request failed"
|
|
98
|
-
return JSONResponse(
|
|
99
|
-
status_code=exc.status_code,
|
|
100
|
-
content=dashboard_error(f"http_{exc.status_code}", detail),
|
|
101
|
-
)
|
|
102
|
-
return await http_exception_handler(request, exc)
|
|
44
|
+
add_request_decompression_middleware(app)
|
|
45
|
+
add_request_id_middleware(app)
|
|
46
|
+
add_api_unhandled_error_middleware(app)
|
|
47
|
+
add_exception_handlers(app)
|
|
103
48
|
|
|
104
49
|
app.include_router(proxy_api.router)
|
|
105
50
|
app.include_router(proxy_api.v1_router)
|
app/modules/accounts/repository.py
CHANGED
|
@@ -19,19 +19,19 @@ class AccountsRepository:
|
|
|
19
19
|
async def upsert(self, account: Account) -> Account:
|
|
20
20
|
existing = await self._session.get(Account, account.id)
|
|
21
21
|
if existing:
|
|
22
|
-
existing
|
|
23
|
-
existing.email = account.email
|
|
24
|
-
existing.plan_type = account.plan_type
|
|
25
|
-
existing.access_token_encrypted = account.access_token_encrypted
|
|
26
|
-
existing.refresh_token_encrypted = account.refresh_token_encrypted
|
|
27
|
-
existing.id_token_encrypted = account.id_token_encrypted
|
|
28
|
-
existing.last_refresh = account.last_refresh
|
|
29
|
-
existing.status = account.status
|
|
30
|
-
existing.deactivation_reason = account.deactivation_reason
|
|
22
|
+
_apply_account_updates(existing, account)
|
|
31
23
|
await self._session.commit()
|
|
32
24
|
await self._session.refresh(existing)
|
|
33
25
|
return existing
|
|
34
26
|
|
|
27
|
+
result = await self._session.execute(select(Account).where(Account.email == account.email))
|
|
28
|
+
existing_by_email = result.scalar_one_or_none()
|
|
29
|
+
if existing_by_email:
|
|
30
|
+
_apply_account_updates(existing_by_email, account)
|
|
31
|
+
await self._session.commit()
|
|
32
|
+
await self._session.refresh(existing_by_email)
|
|
33
|
+
return existing_by_email
|
|
34
|
+
|
|
35
35
|
self._session.add(account)
|
|
36
36
|
await self._session.commit()
|
|
37
37
|
await self._session.refresh(account)
|
|
@@ -89,3 +89,15 @@ class AccountsRepository:
|
|
|
89
89
|
)
|
|
90
90
|
await self._session.commit()
|
|
91
91
|
return result.scalar_one_or_none() is not None
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def _apply_account_updates(target: Account, source: Account) -> None:
|
|
95
|
+
target.chatgpt_account_id = source.chatgpt_account_id
|
|
96
|
+
target.email = source.email
|
|
97
|
+
target.plan_type = source.plan_type
|
|
98
|
+
target.access_token_encrypted = source.access_token_encrypted
|
|
99
|
+
target.refresh_token_encrypted = source.refresh_token_encrypted
|
|
100
|
+
target.id_token_encrypted = source.id_token_encrypted
|
|
101
|
+
target.last_refresh = source.last_refresh
|
|
102
|
+
target.status = source.status
|
|
103
|
+
target.deactivation_reason = source.deactivation_reason
|
app/modules/proxy/api.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
+
import time
|
|
3
4
|
from collections.abc import AsyncIterator
|
|
4
5
|
|
|
5
6
|
from fastapi import APIRouter, Body, Depends, Request, Response
|
|
@@ -7,6 +8,9 @@ from fastapi.responses import JSONResponse, StreamingResponse
|
|
|
7
8
|
|
|
8
9
|
from app.core.clients.proxy import ProxyResponseError
|
|
9
10
|
from app.core.errors import openai_error
|
|
11
|
+
from app.core.openai.chat_requests import ChatCompletionsRequest
|
|
12
|
+
from app.core.openai.chat_responses import collect_chat_completion, stream_chat_chunks
|
|
13
|
+
from app.core.openai.models_catalog import MODEL_CATALOG
|
|
10
14
|
from app.core.openai.requests import ResponsesCompactRequest, ResponsesRequest
|
|
11
15
|
from app.core.openai.v1_requests import V1ResponsesCompactRequest, V1ResponsesRequest
|
|
12
16
|
from app.dependencies import ProxyContext, get_proxy_context
|
|
@@ -35,6 +39,60 @@ async def v1_responses(
|
|
|
35
39
|
return await _stream_responses(request, payload.to_responses_request(), context)
|
|
36
40
|
|
|
37
41
|
|
|
42
|
+
@v1_router.get("/models")
|
|
43
|
+
async def v1_models() -> JSONResponse:
|
|
44
|
+
created = int(time.time())
|
|
45
|
+
items = [
|
|
46
|
+
{
|
|
47
|
+
"id": model_id,
|
|
48
|
+
"object": "model",
|
|
49
|
+
"created": created,
|
|
50
|
+
"owned_by": "codex-lb",
|
|
51
|
+
"metadata": entry.model_dump(mode="json"),
|
|
52
|
+
}
|
|
53
|
+
for model_id, entry in MODEL_CATALOG.items()
|
|
54
|
+
]
|
|
55
|
+
return JSONResponse({"object": "list", "data": items})
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
@v1_router.post("/chat/completions")
|
|
59
|
+
async def v1_chat_completions(
|
|
60
|
+
request: Request,
|
|
61
|
+
payload: ChatCompletionsRequest = Body(...),
|
|
62
|
+
context: ProxyContext = Depends(get_proxy_context),
|
|
63
|
+
) -> Response:
|
|
64
|
+
rate_limit_headers = await context.service.rate_limit_headers()
|
|
65
|
+
responses_payload = payload.to_responses_request()
|
|
66
|
+
responses_payload.stream = True
|
|
67
|
+
stream = context.service.stream_responses(
|
|
68
|
+
responses_payload,
|
|
69
|
+
request.headers,
|
|
70
|
+
propagate_http_errors=True,
|
|
71
|
+
)
|
|
72
|
+
try:
|
|
73
|
+
first = await stream.__anext__()
|
|
74
|
+
except StopAsyncIteration:
|
|
75
|
+
first = None
|
|
76
|
+
except ProxyResponseError as exc:
|
|
77
|
+
return JSONResponse(status_code=exc.status_code, content=exc.payload, headers=rate_limit_headers)
|
|
78
|
+
|
|
79
|
+
stream_with_first = _prepend_first(first, stream)
|
|
80
|
+
if payload.stream:
|
|
81
|
+
return StreamingResponse(
|
|
82
|
+
stream_chat_chunks(stream_with_first, model=payload.model),
|
|
83
|
+
media_type="text/event-stream",
|
|
84
|
+
headers={"Cache-Control": "no-cache", **rate_limit_headers},
|
|
85
|
+
)
|
|
86
|
+
|
|
87
|
+
result = await collect_chat_completion(stream_with_first, model=payload.model)
|
|
88
|
+
status_code = 200
|
|
89
|
+
if isinstance(result, dict) and "error" in result:
|
|
90
|
+
error = result.get("error")
|
|
91
|
+
code = error.get("code") if isinstance(error, dict) else None
|
|
92
|
+
status_code = 503 if code == "no_accounts" else 502
|
|
93
|
+
return JSONResponse(content=result, status_code=status_code, headers=rate_limit_headers)
|
|
94
|
+
|
|
95
|
+
|
|
38
96
|
async def _stream_responses(
|
|
39
97
|
request: Request,
|
|
40
98
|
payload: ResponsesRequest,
|
app/modules/proxy/load_balancer.py
CHANGED
|
@@ -16,8 +16,8 @@ from app.core.balancer.types import UpstreamError
|
|
|
16
16
|
from app.core.usage.quota import apply_usage_quota
|
|
17
17
|
from app.db.models import Account, UsageHistory
|
|
18
18
|
from app.modules.accounts.repository import AccountsRepository
|
|
19
|
+
from app.modules.proxy.repo_bundle import ProxyRepoFactory
|
|
19
20
|
from app.modules.proxy.sticky_repository import StickySessionsRepository
|
|
20
|
-
from app.modules.usage.repository import UsageRepository
|
|
21
21
|
from app.modules.usage.updater import UsageUpdater
|
|
22
22
|
|
|
23
23
|
|
|
@@ -37,16 +37,8 @@ class AccountSelection:
|
|
|
37
37
|
|
|
38
38
|
|
|
39
39
|
class LoadBalancer:
|
|
40
|
-
def __init__(
|
|
41
|
-
self
|
|
42
|
-
accounts_repo: AccountsRepository,
|
|
43
|
-
usage_repo: UsageRepository,
|
|
44
|
-
sticky_repo: StickySessionsRepository | None = None,
|
|
45
|
-
) -> None:
|
|
46
|
-
self._accounts_repo = accounts_repo
|
|
47
|
-
self._usage_repo = usage_repo
|
|
48
|
-
self._usage_updater = UsageUpdater(usage_repo, accounts_repo)
|
|
49
|
-
self._sticky_repo = sticky_repo
|
|
40
|
+
def __init__(self, repo_factory: ProxyRepoFactory) -> None:
|
|
41
|
+
self._repo_factory = repo_factory
|
|
50
42
|
self._runtime: dict[str, RuntimeState] = {}
|
|
51
43
|
|
|
52
44
|
async def select_account(
|
|
@@ -56,43 +48,53 @@ class LoadBalancer:
|
|
|
56
48
|
reallocate_sticky: bool = False,
|
|
57
49
|
prefer_earlier_reset_accounts: bool = False,
|
|
58
50
|
) -> AccountSelection:
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
51
|
+
selected_snapshot: Account | None = None
|
|
52
|
+
error_message: str | None = None
|
|
53
|
+
async with self._repo_factory() as repos:
|
|
54
|
+
accounts = await repos.accounts.list_accounts()
|
|
55
|
+
latest_primary = await repos.usage.latest_by_account()
|
|
56
|
+
updater = UsageUpdater(repos.usage, repos.accounts)
|
|
57
|
+
await updater.refresh_accounts(accounts, latest_primary)
|
|
58
|
+
latest_primary = await repos.usage.latest_by_account()
|
|
59
|
+
latest_secondary = await repos.usage.latest_by_account(window="secondary")
|
|
60
|
+
|
|
61
|
+
states, account_map = _build_states(
|
|
62
|
+
accounts=accounts,
|
|
63
|
+
latest_primary=latest_primary,
|
|
64
|
+
latest_secondary=latest_secondary,
|
|
65
|
+
runtime=self._runtime,
|
|
66
|
+
)
|
|
71
67
|
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
68
|
+
result = await self._select_with_stickiness(
|
|
69
|
+
states=states,
|
|
70
|
+
account_map=account_map,
|
|
71
|
+
sticky_key=sticky_key,
|
|
72
|
+
reallocate_sticky=reallocate_sticky,
|
|
73
|
+
prefer_earlier_reset_accounts=prefer_earlier_reset_accounts,
|
|
74
|
+
sticky_repo=repos.sticky_sessions,
|
|
75
|
+
)
|
|
76
|
+
for state in states:
|
|
77
|
+
account = account_map.get(state.account_id)
|
|
78
|
+
if account:
|
|
79
|
+
await self._sync_state(repos.accounts, account, state)
|
|
80
|
+
|
|
81
|
+
if result.account is None:
|
|
82
|
+
error_message = result.error_message
|
|
83
|
+
else:
|
|
84
|
+
selected = account_map.get(result.account.account_id)
|
|
85
|
+
if selected is None:
|
|
86
|
+
error_message = result.error_message
|
|
87
|
+
else:
|
|
88
|
+
selected.status = result.account.status
|
|
89
|
+
selected.deactivation_reason = result.account.deactivation_reason
|
|
90
|
+
selected_snapshot = _clone_account(selected)
|
|
91
|
+
|
|
92
|
+
if selected_snapshot is None:
|
|
93
|
+
return AccountSelection(account=None, error_message=error_message)
|
|
94
|
+
|
|
95
|
+
runtime = self._runtime.setdefault(selected_snapshot.id, RuntimeState())
|
|
96
|
+
runtime.last_selected_at = time.time()
|
|
97
|
+
return AccountSelection(account=selected_snapshot, error_message=None)
|
|
96
98
|
|
|
97
99
|
async def _select_with_stickiness(
|
|
98
100
|
self,
|
|
@@ -102,21 +104,22 @@ class LoadBalancer:
|
|
|
102
104
|
sticky_key: str | None,
|
|
103
105
|
reallocate_sticky: bool,
|
|
104
106
|
prefer_earlier_reset_accounts: bool,
|
|
107
|
+
sticky_repo: StickySessionsRepository | None,
|
|
105
108
|
) -> SelectionResult:
|
|
106
|
-
if not sticky_key or not
|
|
109
|
+
if not sticky_key or not sticky_repo:
|
|
107
110
|
return select_account(states, prefer_earlier_reset=prefer_earlier_reset_accounts)
|
|
108
111
|
|
|
109
112
|
if reallocate_sticky:
|
|
110
113
|
chosen = select_account(states, prefer_earlier_reset=prefer_earlier_reset_accounts)
|
|
111
114
|
if chosen.account is not None and chosen.account.account_id in account_map:
|
|
112
|
-
await
|
|
115
|
+
await sticky_repo.upsert(sticky_key, chosen.account.account_id)
|
|
113
116
|
return chosen
|
|
114
117
|
|
|
115
|
-
existing = await
|
|
118
|
+
existing = await sticky_repo.get_account_id(sticky_key)
|
|
116
119
|
if existing:
|
|
117
120
|
pinned = next((state for state in states if state.account_id == existing), None)
|
|
118
121
|
if pinned is None:
|
|
119
|
-
await
|
|
122
|
+
await sticky_repo.delete(sticky_key)
|
|
120
123
|
else:
|
|
121
124
|
pinned_result = select_account([pinned], prefer_earlier_reset=prefer_earlier_reset_accounts)
|
|
122
125
|
if pinned_result.account is not None:
|
|
@@ -124,29 +127,33 @@ class LoadBalancer:
|
|
|
124
127
|
|
|
125
128
|
chosen = select_account(states, prefer_earlier_reset=prefer_earlier_reset_accounts)
|
|
126
129
|
if chosen.account is not None and chosen.account.account_id in account_map:
|
|
127
|
-
await
|
|
130
|
+
await sticky_repo.upsert(sticky_key, chosen.account.account_id)
|
|
128
131
|
return chosen
|
|
129
132
|
|
|
130
133
|
async def mark_rate_limit(self, account: Account, error: UpstreamError) -> None:
|
|
131
134
|
state = self._state_for(account)
|
|
132
135
|
handle_rate_limit(state, error)
|
|
133
|
-
|
|
136
|
+
async with self._repo_factory() as repos:
|
|
137
|
+
await self._sync_state(repos.accounts, account, state)
|
|
134
138
|
|
|
135
139
|
async def mark_quota_exceeded(self, account: Account, error: UpstreamError) -> None:
|
|
136
140
|
state = self._state_for(account)
|
|
137
141
|
handle_quota_exceeded(state, error)
|
|
138
|
-
|
|
142
|
+
async with self._repo_factory() as repos:
|
|
143
|
+
await self._sync_state(repos.accounts, account, state)
|
|
139
144
|
|
|
140
145
|
async def mark_permanent_failure(self, account: Account, error_code: str) -> None:
|
|
141
146
|
state = self._state_for(account)
|
|
142
147
|
handle_permanent_failure(state, error_code)
|
|
143
|
-
|
|
148
|
+
async with self._repo_factory() as repos:
|
|
149
|
+
await self._sync_state(repos.accounts, account, state)
|
|
144
150
|
|
|
145
151
|
async def record_error(self, account: Account) -> None:
|
|
146
152
|
state = self._state_for(account)
|
|
147
153
|
state.error_count += 1
|
|
148
154
|
state.last_error_at = time.time()
|
|
149
|
-
|
|
155
|
+
async with self._repo_factory() as repos:
|
|
156
|
+
await self._sync_state(repos.accounts, account, state)
|
|
150
157
|
|
|
151
158
|
def _state_for(self, account: Account) -> AccountState:
|
|
152
159
|
runtime = self._runtime.setdefault(account.id, RuntimeState())
|
|
@@ -164,7 +171,12 @@ class LoadBalancer:
|
|
|
164
171
|
deactivation_reason=account.deactivation_reason,
|
|
165
172
|
)
|
|
166
173
|
|
|
167
|
-
async def _sync_state(
|
|
174
|
+
async def _sync_state(
|
|
175
|
+
self,
|
|
176
|
+
accounts_repo: AccountsRepository,
|
|
177
|
+
account: Account,
|
|
178
|
+
state: AccountState,
|
|
179
|
+
) -> None:
|
|
168
180
|
runtime = self._runtime.setdefault(account.id, RuntimeState())
|
|
169
181
|
runtime.reset_at = state.reset_at
|
|
170
182
|
runtime.cooldown_until = state.cooldown_until
|
|
@@ -177,7 +189,7 @@ class LoadBalancer:
|
|
|
177
189
|
reset_changed = account.reset_at != reset_at_int
|
|
178
190
|
|
|
179
191
|
if status_changed or reason_changed or reset_changed:
|
|
180
|
-
await
|
|
192
|
+
await accounts_repo.update_status(
|
|
181
193
|
account.id,
|
|
182
194
|
state.status,
|
|
183
195
|
state.deactivation_reason,
|
|
@@ -251,3 +263,8 @@ def _state_from_account(
|
|
|
251
263
|
error_count=runtime.error_count,
|
|
252
264
|
deactivation_reason=account.deactivation_reason,
|
|
253
265
|
)
|
|
266
|
+
|
|
267
|
+
|
|
268
|
+
def _clone_account(account: Account) -> Account:
|
|
269
|
+
data = {column.name: getattr(account, column.name) for column in Account.__table__.columns}
|
|
270
|
+
return Account(**data)
|