codex-lb 0.1.5__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- app/__init__.py +1 -1
- app/core/auth/__init__.py +12 -1
- app/core/balancer/logic.py +44 -7
- app/core/clients/proxy.py +2 -4
- app/core/config/settings.py +4 -1
- app/core/plan_types.py +64 -0
- app/core/types.py +4 -2
- app/core/usage/__init__.py +5 -2
- app/core/usage/logs.py +12 -2
- app/core/usage/quota.py +64 -0
- app/core/usage/types.py +3 -2
- app/core/utils/sse.py +6 -2
- app/db/migrations/__init__.py +91 -0
- app/db/migrations/versions/__init__.py +1 -0
- app/db/migrations/versions/add_accounts_chatgpt_account_id.py +29 -0
- app/db/migrations/versions/add_accounts_reset_at.py +29 -0
- app/db/migrations/versions/add_dashboard_settings.py +31 -0
- app/db/migrations/versions/add_request_logs_reasoning_effort.py +21 -0
- app/db/migrations/versions/normalize_account_plan_types.py +17 -0
- app/db/models.py +33 -0
- app/db/session.py +85 -11
- app/dependencies.py +27 -9
- app/main.py +15 -6
- app/modules/accounts/auth_manager.py +121 -0
- app/modules/accounts/repository.py +14 -6
- app/modules/accounts/service.py +14 -9
- app/modules/health/api.py +5 -3
- app/modules/health/schemas.py +9 -0
- app/modules/oauth/service.py +9 -4
- app/modules/proxy/helpers.py +285 -0
- app/modules/proxy/load_balancer.py +86 -41
- app/modules/proxy/service.py +172 -318
- app/modules/proxy/sticky_repository.py +56 -0
- app/modules/request_logs/repository.py +6 -3
- app/modules/request_logs/schemas.py +2 -0
- app/modules/request_logs/service.py +12 -3
- app/modules/settings/__init__.py +1 -0
- app/modules/settings/api.py +37 -0
- app/modules/settings/repository.py +40 -0
- app/modules/settings/schemas.py +13 -0
- app/modules/settings/service.py +33 -0
- app/modules/shared/schemas.py +16 -2
- app/modules/usage/schemas.py +1 -0
- app/modules/usage/service.py +23 -6
- app/modules/{proxy/usage_updater.py → usage/updater.py} +37 -8
- app/static/7.css +73 -0
- app/static/index.css +33 -4
- app/static/index.html +51 -4
- app/static/index.js +254 -32
- {codex_lb-0.1.5.dist-info → codex_lb-0.3.0.dist-info}/METADATA +2 -2
- codex_lb-0.3.0.dist-info/RECORD +97 -0
- app/modules/proxy/auth_manager.py +0 -51
- codex_lb-0.1.5.dist-info/RECORD +0 -80
- {codex_lb-0.1.5.dist-info → codex_lb-0.3.0.dist-info}/WHEEL +0 -0
- {codex_lb-0.1.5.dist-info → codex_lb-0.3.0.dist-info}/entry_points.txt +0 -0
- {codex_lb-0.1.5.dist-info → codex_lb-0.3.0.dist-info}/licenses/LICENSE +0 -0
app/db/models.py
CHANGED
|
@@ -24,6 +24,7 @@ class Account(Base):
|
|
|
24
24
|
__tablename__ = "accounts"
|
|
25
25
|
|
|
26
26
|
id: Mapped[str] = mapped_column(String, primary_key=True)
|
|
27
|
+
chatgpt_account_id: Mapped[str | None] = mapped_column(String, nullable=True)
|
|
27
28
|
email: Mapped[str] = mapped_column(String, unique=True, nullable=False)
|
|
28
29
|
plan_type: Mapped[str] = mapped_column(String, nullable=False)
|
|
29
30
|
|
|
@@ -40,6 +41,7 @@ class Account(Base):
|
|
|
40
41
|
nullable=False,
|
|
41
42
|
)
|
|
42
43
|
deactivation_reason: Mapped[str | None] = mapped_column(Text, nullable=True)
|
|
44
|
+
reset_at: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
|
43
45
|
|
|
44
46
|
|
|
45
47
|
class UsageHistory(Base):
|
|
@@ -71,12 +73,43 @@ class RequestLog(Base):
|
|
|
71
73
|
output_tokens: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
|
72
74
|
cached_input_tokens: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
|
73
75
|
reasoning_tokens: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
|
76
|
+
reasoning_effort: Mapped[str | None] = mapped_column(String, nullable=True)
|
|
74
77
|
latency_ms: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
|
75
78
|
status: Mapped[str] = mapped_column(String, nullable=False)
|
|
76
79
|
error_code: Mapped[str | None] = mapped_column(String, nullable=True)
|
|
77
80
|
error_message: Mapped[str | None] = mapped_column(Text, nullable=True)
|
|
78
81
|
|
|
79
82
|
|
|
83
|
+
class StickySession(Base):
|
|
84
|
+
__tablename__ = "sticky_sessions"
|
|
85
|
+
|
|
86
|
+
key: Mapped[str] = mapped_column(String, primary_key=True)
|
|
87
|
+
account_id: Mapped[str] = mapped_column(String, ForeignKey("accounts.id"), nullable=False)
|
|
88
|
+
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now(), nullable=False)
|
|
89
|
+
updated_at: Mapped[datetime] = mapped_column(
|
|
90
|
+
DateTime,
|
|
91
|
+
server_default=func.now(),
|
|
92
|
+
onupdate=func.now(),
|
|
93
|
+
nullable=False,
|
|
94
|
+
)
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
class DashboardSettings(Base):
|
|
98
|
+
__tablename__ = "dashboard_settings"
|
|
99
|
+
|
|
100
|
+
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=False)
|
|
101
|
+
sticky_threads_enabled: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
|
|
102
|
+
prefer_earlier_reset_accounts: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
|
|
103
|
+
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now(), nullable=False)
|
|
104
|
+
updated_at: Mapped[datetime] = mapped_column(
|
|
105
|
+
DateTime,
|
|
106
|
+
server_default=func.now(),
|
|
107
|
+
onupdate=func.now(),
|
|
108
|
+
nullable=False,
|
|
109
|
+
)
|
|
110
|
+
|
|
111
|
+
|
|
80
112
|
Index("idx_usage_recorded_at", UsageHistory.recorded_at)
|
|
81
113
|
Index("idx_usage_account_time", UsageHistory.account_id, UsageHistory.recorded_at)
|
|
82
114
|
Index("idx_logs_account_time", RequestLog.account_id, RequestLog.requested_at)
|
|
115
|
+
Index("idx_sticky_account", StickySession.account_id)
|
app/db/session.py
CHANGED
|
@@ -1,42 +1,102 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
-
import
|
|
3
|
+
import logging
|
|
4
|
+
import sqlite3
|
|
4
5
|
from pathlib import Path
|
|
5
|
-
from typing import AsyncIterator
|
|
6
|
+
from typing import AsyncIterator, Awaitable, TypeVar
|
|
6
7
|
|
|
8
|
+
import anyio
|
|
9
|
+
from sqlalchemy import event
|
|
10
|
+
from sqlalchemy.engine import Engine
|
|
7
11
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
|
8
12
|
|
|
9
13
|
from app.core.config.settings import get_settings
|
|
14
|
+
from app.db.migrations import run_migrations
|
|
10
15
|
|
|
11
16
|
DATABASE_URL = get_settings().database_url
|
|
12
17
|
|
|
13
|
-
|
|
18
|
+
logger = logging.getLogger(__name__)
|
|
19
|
+
|
|
20
|
+
_SQLITE_BUSY_TIMEOUT_MS = 5_000
|
|
21
|
+
_SQLITE_BUSY_TIMEOUT_SECONDS = _SQLITE_BUSY_TIMEOUT_MS / 1000
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def _is_sqlite_url(url: str) -> bool:
|
|
25
|
+
return url.startswith("sqlite+aiosqlite:///") or url.startswith("sqlite:///")
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def _is_sqlite_memory_url(url: str) -> bool:
|
|
29
|
+
return _is_sqlite_url(url) and ":memory:" in url
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def _configure_sqlite_engine(engine: Engine, *, enable_wal: bool) -> None:
|
|
33
|
+
@event.listens_for(engine, "connect")
|
|
34
|
+
def _set_sqlite_pragmas(dbapi_connection: sqlite3.Connection, _: object) -> None:
|
|
35
|
+
cursor: sqlite3.Cursor = dbapi_connection.cursor()
|
|
36
|
+
try:
|
|
37
|
+
if enable_wal:
|
|
38
|
+
cursor.execute("PRAGMA journal_mode=WAL")
|
|
39
|
+
cursor.execute("PRAGMA synchronous=NORMAL")
|
|
40
|
+
cursor.execute("PRAGMA foreign_keys=ON")
|
|
41
|
+
cursor.execute(f"PRAGMA busy_timeout={_SQLITE_BUSY_TIMEOUT_MS}")
|
|
42
|
+
finally:
|
|
43
|
+
cursor.close()
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
if _is_sqlite_url(DATABASE_URL):
|
|
47
|
+
engine = create_async_engine(
|
|
48
|
+
DATABASE_URL,
|
|
49
|
+
echo=False,
|
|
50
|
+
connect_args={"timeout": _SQLITE_BUSY_TIMEOUT_SECONDS},
|
|
51
|
+
)
|
|
52
|
+
_configure_sqlite_engine(engine.sync_engine, enable_wal=not _is_sqlite_memory_url(DATABASE_URL))
|
|
53
|
+
else:
|
|
54
|
+
engine = create_async_engine(DATABASE_URL, echo=False)
|
|
55
|
+
|
|
14
56
|
SessionLocal = async_sessionmaker(engine, expire_on_commit=False, class_=AsyncSession)
|
|
15
57
|
|
|
58
|
+
_T = TypeVar("_T")
|
|
59
|
+
|
|
16
60
|
|
|
17
61
|
def _ensure_sqlite_dir(url: str) -> None:
|
|
18
|
-
|
|
19
|
-
if not url.startswith(prefix):
|
|
62
|
+
if not (url.startswith("sqlite+aiosqlite:") or url.startswith("sqlite:")):
|
|
20
63
|
return
|
|
21
|
-
|
|
22
|
-
|
|
64
|
+
|
|
65
|
+
marker = ":///"
|
|
66
|
+
marker_index = url.find(marker)
|
|
67
|
+
if marker_index < 0:
|
|
23
68
|
return
|
|
69
|
+
|
|
70
|
+
# Works for both relative (sqlite+aiosqlite:///./db.sqlite) and absolute
|
|
71
|
+
# paths (sqlite+aiosqlite:////var/lib/app/db.sqlite).
|
|
72
|
+
path = url[marker_index + len(marker) :]
|
|
73
|
+
path = path.partition("?")[0]
|
|
74
|
+
path = path.partition("#")[0]
|
|
75
|
+
|
|
76
|
+
if not path or path == ":memory:":
|
|
77
|
+
return
|
|
78
|
+
|
|
24
79
|
Path(path).expanduser().parent.mkdir(parents=True, exist_ok=True)
|
|
25
80
|
|
|
26
81
|
|
|
82
|
+
async def _shielded(awaitable: Awaitable[_T]) -> _T:
|
|
83
|
+
with anyio.CancelScope(shield=True):
|
|
84
|
+
return await awaitable
|
|
85
|
+
|
|
86
|
+
|
|
27
87
|
async def _safe_rollback(session: AsyncSession) -> None:
|
|
28
88
|
if not session.in_transaction():
|
|
29
89
|
return
|
|
30
90
|
try:
|
|
31
|
-
await
|
|
32
|
-
except
|
|
91
|
+
await _shielded(session.rollback())
|
|
92
|
+
except BaseException:
|
|
33
93
|
return
|
|
34
94
|
|
|
35
95
|
|
|
36
96
|
async def _safe_close(session: AsyncSession) -> None:
|
|
37
97
|
try:
|
|
38
|
-
await
|
|
39
|
-
except
|
|
98
|
+
await _shielded(session.close())
|
|
99
|
+
except BaseException:
|
|
40
100
|
return
|
|
41
101
|
|
|
42
102
|
|
|
@@ -60,3 +120,17 @@ async def init_db() -> None:
|
|
|
60
120
|
|
|
61
121
|
async with engine.begin() as conn:
|
|
62
122
|
await conn.run_sync(Base.metadata.create_all)
|
|
123
|
+
|
|
124
|
+
async with SessionLocal() as session:
|
|
125
|
+
try:
|
|
126
|
+
updated = await run_migrations(session)
|
|
127
|
+
if updated:
|
|
128
|
+
logger.info("Applied database migrations count=%s", updated)
|
|
129
|
+
except Exception:
|
|
130
|
+
logger.exception("Failed to apply database migrations")
|
|
131
|
+
if get_settings().database_migrations_fail_fast:
|
|
132
|
+
raise
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
async def close_db() -> None:
|
|
136
|
+
await engine.dispose()
|
app/dependencies.py
CHANGED
|
@@ -12,8 +12,11 @@ from app.modules.accounts.repository import AccountsRepository
|
|
|
12
12
|
from app.modules.accounts.service import AccountsService
|
|
13
13
|
from app.modules.oauth.service import OauthService
|
|
14
14
|
from app.modules.proxy.service import ProxyService
|
|
15
|
+
from app.modules.proxy.sticky_repository import StickySessionsRepository
|
|
15
16
|
from app.modules.request_logs.repository import RequestLogsRepository
|
|
16
17
|
from app.modules.request_logs.service import RequestLogsService
|
|
18
|
+
from app.modules.settings.repository import SettingsRepository
|
|
19
|
+
from app.modules.settings.service import SettingsService
|
|
17
20
|
from app.modules.usage.repository import UsageRepository
|
|
18
21
|
from app.modules.usage.service import UsageService
|
|
19
22
|
|
|
@@ -22,8 +25,6 @@ from app.modules.usage.service import UsageService
|
|
|
22
25
|
class AccountsContext:
|
|
23
26
|
session: AsyncSession
|
|
24
27
|
repository: AccountsRepository
|
|
25
|
-
usage_repository: UsageRepository
|
|
26
|
-
request_logs_repository: RequestLogsRepository
|
|
27
28
|
service: AccountsService
|
|
28
29
|
|
|
29
30
|
|
|
@@ -31,8 +32,6 @@ class AccountsContext:
|
|
|
31
32
|
class UsageContext:
|
|
32
33
|
session: AsyncSession
|
|
33
34
|
usage_repository: UsageRepository
|
|
34
|
-
request_logs_repository: RequestLogsRepository
|
|
35
|
-
accounts_repository: AccountsRepository
|
|
36
35
|
service: UsageService
|
|
37
36
|
|
|
38
37
|
|
|
@@ -53,6 +52,13 @@ class RequestLogsContext:
|
|
|
53
52
|
service: RequestLogsService
|
|
54
53
|
|
|
55
54
|
|
|
55
|
+
@dataclass(slots=True)
|
|
56
|
+
class SettingsContext:
|
|
57
|
+
session: AsyncSession
|
|
58
|
+
repository: SettingsRepository
|
|
59
|
+
service: SettingsService
|
|
60
|
+
|
|
61
|
+
|
|
56
62
|
def get_accounts_context(
|
|
57
63
|
session: AsyncSession = Depends(get_session),
|
|
58
64
|
) -> AccountsContext:
|
|
@@ -63,8 +69,6 @@ def get_accounts_context(
|
|
|
63
69
|
return AccountsContext(
|
|
64
70
|
session=session,
|
|
65
71
|
repository=repository,
|
|
66
|
-
usage_repository=usage_repository,
|
|
67
|
-
request_logs_repository=request_logs_repository,
|
|
68
72
|
service=service,
|
|
69
73
|
)
|
|
70
74
|
|
|
@@ -79,8 +83,6 @@ def get_usage_context(
|
|
|
79
83
|
return UsageContext(
|
|
80
84
|
session=session,
|
|
81
85
|
usage_repository=usage_repository,
|
|
82
|
-
request_logs_repository=request_logs_repository,
|
|
83
|
-
accounts_repository=accounts_repository,
|
|
84
86
|
service=service,
|
|
85
87
|
)
|
|
86
88
|
|
|
@@ -112,7 +114,15 @@ def get_proxy_context(
|
|
|
112
114
|
accounts_repository = AccountsRepository(session)
|
|
113
115
|
usage_repository = UsageRepository(session)
|
|
114
116
|
request_logs_repository = RequestLogsRepository(session)
|
|
115
|
-
|
|
117
|
+
sticky_repository = StickySessionsRepository(session)
|
|
118
|
+
settings_repository = SettingsRepository(session)
|
|
119
|
+
service = ProxyService(
|
|
120
|
+
accounts_repository,
|
|
121
|
+
usage_repository,
|
|
122
|
+
request_logs_repository,
|
|
123
|
+
sticky_repository,
|
|
124
|
+
settings_repository,
|
|
125
|
+
)
|
|
116
126
|
return ProxyContext(service=service)
|
|
117
127
|
|
|
118
128
|
|
|
@@ -122,3 +132,11 @@ def get_request_logs_context(
|
|
|
122
132
|
repository = RequestLogsRepository(session)
|
|
123
133
|
service = RequestLogsService(repository)
|
|
124
134
|
return RequestLogsContext(session=session, repository=repository, service=service)
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
def get_settings_context(
|
|
138
|
+
session: AsyncSession = Depends(get_session),
|
|
139
|
+
) -> SettingsContext:
|
|
140
|
+
repository = SettingsRepository(session)
|
|
141
|
+
service = SettingsService(repository)
|
|
142
|
+
return SettingsContext(session=session, repository=repository, service=service)
|
app/main.py
CHANGED
|
@@ -11,19 +11,20 @@ from fastapi.exception_handlers import (
|
|
|
11
11
|
request_validation_exception_handler,
|
|
12
12
|
)
|
|
13
13
|
from fastapi.exceptions import RequestValidationError
|
|
14
|
-
from fastapi.responses import FileResponse, JSONResponse, RedirectResponse
|
|
14
|
+
from fastapi.responses import FileResponse, JSONResponse, RedirectResponse, Response
|
|
15
15
|
from fastapi.staticfiles import StaticFiles
|
|
16
16
|
from starlette.exceptions import HTTPException as StarletteHTTPException
|
|
17
17
|
|
|
18
18
|
from app.core.clients.http import close_http_client, init_http_client
|
|
19
19
|
from app.core.errors import dashboard_error
|
|
20
20
|
from app.core.utils.request_id import get_request_id, reset_request_id, set_request_id
|
|
21
|
-
from app.db.session import init_db
|
|
21
|
+
from app.db.session import close_db, init_db
|
|
22
22
|
from app.modules.accounts import api as accounts_api
|
|
23
23
|
from app.modules.health import api as health_api
|
|
24
24
|
from app.modules.oauth import api as oauth_api
|
|
25
25
|
from app.modules.proxy import api as proxy_api
|
|
26
26
|
from app.modules.request_logs import api as request_logs_api
|
|
27
|
+
from app.modules.settings import api as settings_api
|
|
27
28
|
from app.modules.usage import api as usage_api
|
|
28
29
|
|
|
29
30
|
logger = logging.getLogger(__name__)
|
|
@@ -37,7 +38,10 @@ async def lifespan(_: FastAPI):
|
|
|
37
38
|
try:
|
|
38
39
|
yield
|
|
39
40
|
finally:
|
|
40
|
-
|
|
41
|
+
try:
|
|
42
|
+
await close_http_client()
|
|
43
|
+
finally:
|
|
44
|
+
await close_db()
|
|
41
45
|
|
|
42
46
|
|
|
43
47
|
def create_app() -> FastAPI:
|
|
@@ -57,7 +61,7 @@ def create_app() -> FastAPI:
|
|
|
57
61
|
return response
|
|
58
62
|
|
|
59
63
|
@app.middleware("http")
|
|
60
|
-
async def api_unhandled_error_middleware(request: Request, call_next) ->
|
|
64
|
+
async def api_unhandled_error_middleware(request: Request, call_next) -> Response:
|
|
61
65
|
try:
|
|
62
66
|
return await call_next(request)
|
|
63
67
|
except Exception:
|
|
@@ -76,7 +80,7 @@ def create_app() -> FastAPI:
|
|
|
76
80
|
async def _validation_error_handler(
|
|
77
81
|
request: Request,
|
|
78
82
|
exc: RequestValidationError,
|
|
79
|
-
) ->
|
|
83
|
+
) -> Response:
|
|
80
84
|
if request.url.path.startswith("/api/"):
|
|
81
85
|
return JSONResponse(
|
|
82
86
|
status_code=422,
|
|
@@ -88,7 +92,7 @@ def create_app() -> FastAPI:
|
|
|
88
92
|
async def _http_error_handler(
|
|
89
93
|
request: Request,
|
|
90
94
|
exc: StarletteHTTPException,
|
|
91
|
-
) ->
|
|
95
|
+
) -> Response:
|
|
92
96
|
if request.url.path.startswith("/api/"):
|
|
93
97
|
detail = exc.detail if isinstance(exc.detail, str) else "Request failed"
|
|
94
98
|
return JSONResponse(
|
|
@@ -103,6 +107,7 @@ def create_app() -> FastAPI:
|
|
|
103
107
|
app.include_router(usage_api.router)
|
|
104
108
|
app.include_router(request_logs_api.router)
|
|
105
109
|
app.include_router(oauth_api.router)
|
|
110
|
+
app.include_router(settings_api.router)
|
|
106
111
|
app.include_router(health_api.router)
|
|
107
112
|
|
|
108
113
|
static_dir = Path(__file__).parent / "static"
|
|
@@ -116,6 +121,10 @@ def create_app() -> FastAPI:
|
|
|
116
121
|
async def spa_accounts():
|
|
117
122
|
return FileResponse(index_html, media_type="text/html")
|
|
118
123
|
|
|
124
|
+
@app.get("/settings", include_in_schema=False)
|
|
125
|
+
async def spa_settings():
|
|
126
|
+
return FileResponse(index_html, media_type="text/html")
|
|
127
|
+
|
|
119
128
|
app.mount("/dashboard", StaticFiles(directory=static_dir, html=True), name="dashboard")
|
|
120
129
|
|
|
121
130
|
return app
|
|
@@ -0,0 +1,121 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
from datetime import datetime
|
|
5
|
+
from typing import Protocol
|
|
6
|
+
|
|
7
|
+
from app.core.auth import DEFAULT_PLAN, OpenAIAuthClaims, extract_id_token_claims
|
|
8
|
+
from app.core.auth.refresh import RefreshError, refresh_access_token, should_refresh
|
|
9
|
+
from app.core.balancer import PERMANENT_FAILURE_CODES
|
|
10
|
+
from app.core.crypto import TokenEncryptor
|
|
11
|
+
from app.core.plan_types import coerce_account_plan_type
|
|
12
|
+
from app.core.utils.time import utcnow
|
|
13
|
+
from app.db.models import Account, AccountStatus
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class AccountsRepositoryPort(Protocol):
|
|
17
|
+
async def update_status(
|
|
18
|
+
self,
|
|
19
|
+
account_id: str,
|
|
20
|
+
status: AccountStatus,
|
|
21
|
+
deactivation_reason: str | None = None,
|
|
22
|
+
) -> bool: ...
|
|
23
|
+
|
|
24
|
+
async def update_tokens(
|
|
25
|
+
self,
|
|
26
|
+
account_id: str,
|
|
27
|
+
access_token_encrypted: bytes,
|
|
28
|
+
refresh_token_encrypted: bytes,
|
|
29
|
+
id_token_encrypted: bytes,
|
|
30
|
+
last_refresh: datetime,
|
|
31
|
+
plan_type: str | None = None,
|
|
32
|
+
email: str | None = None,
|
|
33
|
+
chatgpt_account_id: str | None = None,
|
|
34
|
+
) -> bool: ...
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
logger = logging.getLogger(__name__)
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class AuthManager:
|
|
41
|
+
def __init__(self, repo: AccountsRepositoryPort) -> None:
|
|
42
|
+
self._repo = repo
|
|
43
|
+
self._encryptor = TokenEncryptor()
|
|
44
|
+
|
|
45
|
+
async def ensure_fresh(self, account: Account, *, force: bool = False) -> Account:
|
|
46
|
+
if force or should_refresh(account.last_refresh):
|
|
47
|
+
account = await self.refresh_account(account)
|
|
48
|
+
return await self._ensure_chatgpt_account_id(account)
|
|
49
|
+
|
|
50
|
+
async def refresh_account(self, account: Account) -> Account:
|
|
51
|
+
refresh_token = self._encryptor.decrypt(account.refresh_token_encrypted)
|
|
52
|
+
try:
|
|
53
|
+
result = await refresh_access_token(refresh_token)
|
|
54
|
+
except RefreshError as exc:
|
|
55
|
+
if exc.is_permanent:
|
|
56
|
+
reason = PERMANENT_FAILURE_CODES.get(exc.code, exc.message)
|
|
57
|
+
await self._repo.update_status(account.id, AccountStatus.DEACTIVATED, reason)
|
|
58
|
+
account.status = AccountStatus.DEACTIVATED
|
|
59
|
+
account.deactivation_reason = reason
|
|
60
|
+
raise
|
|
61
|
+
|
|
62
|
+
account.access_token_encrypted = self._encryptor.encrypt(result.access_token)
|
|
63
|
+
account.refresh_token_encrypted = self._encryptor.encrypt(result.refresh_token)
|
|
64
|
+
account.id_token_encrypted = self._encryptor.encrypt(result.id_token)
|
|
65
|
+
account.last_refresh = utcnow()
|
|
66
|
+
if result.account_id:
|
|
67
|
+
account.chatgpt_account_id = result.account_id
|
|
68
|
+
if result.plan_type is not None:
|
|
69
|
+
account.plan_type = coerce_account_plan_type(
|
|
70
|
+
result.plan_type,
|
|
71
|
+
account.plan_type or DEFAULT_PLAN,
|
|
72
|
+
)
|
|
73
|
+
elif not account.plan_type:
|
|
74
|
+
account.plan_type = DEFAULT_PLAN
|
|
75
|
+
if result.email:
|
|
76
|
+
account.email = result.email
|
|
77
|
+
|
|
78
|
+
await self._repo.update_tokens(
|
|
79
|
+
account.id,
|
|
80
|
+
access_token_encrypted=account.access_token_encrypted,
|
|
81
|
+
refresh_token_encrypted=account.refresh_token_encrypted,
|
|
82
|
+
id_token_encrypted=account.id_token_encrypted,
|
|
83
|
+
last_refresh=account.last_refresh,
|
|
84
|
+
plan_type=account.plan_type,
|
|
85
|
+
email=account.email,
|
|
86
|
+
chatgpt_account_id=account.chatgpt_account_id,
|
|
87
|
+
)
|
|
88
|
+
return account
|
|
89
|
+
|
|
90
|
+
async def _ensure_chatgpt_account_id(self, account: Account) -> Account:
|
|
91
|
+
if account.chatgpt_account_id:
|
|
92
|
+
return account
|
|
93
|
+
try:
|
|
94
|
+
id_token = self._encryptor.decrypt(account.id_token_encrypted)
|
|
95
|
+
except Exception:
|
|
96
|
+
return account
|
|
97
|
+
raw_account_id = _chatgpt_account_id_from_id_token(id_token)
|
|
98
|
+
if not raw_account_id:
|
|
99
|
+
return account
|
|
100
|
+
|
|
101
|
+
account.chatgpt_account_id = raw_account_id
|
|
102
|
+
try:
|
|
103
|
+
await self._repo.update_tokens(
|
|
104
|
+
account.id,
|
|
105
|
+
access_token_encrypted=account.access_token_encrypted,
|
|
106
|
+
refresh_token_encrypted=account.refresh_token_encrypted,
|
|
107
|
+
id_token_encrypted=account.id_token_encrypted,
|
|
108
|
+
last_refresh=account.last_refresh,
|
|
109
|
+
plan_type=account.plan_type,
|
|
110
|
+
email=account.email,
|
|
111
|
+
chatgpt_account_id=raw_account_id,
|
|
112
|
+
)
|
|
113
|
+
except Exception:
|
|
114
|
+
logger.warning("Failed to persist chatgpt_account_id account_id=%s", account.id, exc_info=True)
|
|
115
|
+
return account
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
def _chatgpt_account_id_from_id_token(id_token: str) -> str | None:
|
|
119
|
+
claims = extract_id_token_claims(id_token)
|
|
120
|
+
auth_claims = claims.auth or OpenAIAuthClaims()
|
|
121
|
+
return auth_claims.chatgpt_account_id or claims.chatgpt_account_id
|
|
@@ -19,6 +19,7 @@ class AccountsRepository:
|
|
|
19
19
|
async def upsert(self, account: Account) -> Account:
|
|
20
20
|
existing = await self._session.get(Account, account.id)
|
|
21
21
|
if existing:
|
|
22
|
+
existing.chatgpt_account_id = account.chatgpt_account_id
|
|
22
23
|
existing.email = account.email
|
|
23
24
|
existing.plan_type = account.plan_type
|
|
24
25
|
existing.access_token_encrypted = account.access_token_encrypted
|
|
@@ -41,19 +42,21 @@ class AccountsRepository:
|
|
|
41
42
|
account_id: str,
|
|
42
43
|
status: AccountStatus,
|
|
43
44
|
deactivation_reason: str | None = None,
|
|
45
|
+
reset_at: int | None = None,
|
|
44
46
|
) -> bool:
|
|
45
47
|
result = await self._session.execute(
|
|
46
48
|
update(Account)
|
|
47
49
|
.where(Account.id == account_id)
|
|
48
|
-
.values(status=status, deactivation_reason=deactivation_reason)
|
|
50
|
+
.values(status=status, deactivation_reason=deactivation_reason, reset_at=reset_at)
|
|
51
|
+
.returning(Account.id)
|
|
49
52
|
)
|
|
50
53
|
await self._session.commit()
|
|
51
|
-
return
|
|
54
|
+
return result.scalar_one_or_none() is not None
|
|
52
55
|
|
|
53
56
|
async def delete(self, account_id: str) -> bool:
|
|
54
|
-
result = await self._session.execute(delete(Account).where(Account.id == account_id))
|
|
57
|
+
result = await self._session.execute(delete(Account).where(Account.id == account_id).returning(Account.id))
|
|
55
58
|
await self._session.commit()
|
|
56
|
-
return
|
|
59
|
+
return result.scalar_one_or_none() is not None
|
|
57
60
|
|
|
58
61
|
async def update_tokens(
|
|
59
62
|
self,
|
|
@@ -64,6 +67,7 @@ class AccountsRepository:
|
|
|
64
67
|
last_refresh: datetime,
|
|
65
68
|
plan_type: str | None = None,
|
|
66
69
|
email: str | None = None,
|
|
70
|
+
chatgpt_account_id: str | None = None,
|
|
67
71
|
) -> bool:
|
|
68
72
|
values = {
|
|
69
73
|
"access_token_encrypted": access_token_encrypted,
|
|
@@ -75,6 +79,10 @@ class AccountsRepository:
|
|
|
75
79
|
values["plan_type"] = plan_type
|
|
76
80
|
if email is not None:
|
|
77
81
|
values["email"] = email
|
|
78
|
-
|
|
82
|
+
if chatgpt_account_id is not None:
|
|
83
|
+
values["chatgpt_account_id"] = chatgpt_account_id
|
|
84
|
+
result = await self._session.execute(
|
|
85
|
+
update(Account).where(Account.id == account_id).values(**values).returning(Account.id)
|
|
86
|
+
)
|
|
79
87
|
await self._session.commit()
|
|
80
|
-
return
|
|
88
|
+
return result.scalar_one_or_none() is not None
|
app/modules/accounts/service.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
from datetime import datetime, timedelta, timezone
|
|
4
|
+
from typing import cast
|
|
4
5
|
|
|
5
6
|
from app.core import usage as usage_core
|
|
6
7
|
from app.core.auth import (
|
|
@@ -8,11 +9,12 @@ from app.core.auth import (
|
|
|
8
9
|
DEFAULT_PLAN,
|
|
9
10
|
claims_from_auth,
|
|
10
11
|
extract_id_token_claims,
|
|
11
|
-
|
|
12
|
+
generate_unique_account_id,
|
|
12
13
|
parse_auth_json,
|
|
13
14
|
)
|
|
14
15
|
from app.core.crypto import TokenEncryptor
|
|
15
|
-
from app.core.
|
|
16
|
+
from app.core.plan_types import coerce_account_plan_type
|
|
17
|
+
from app.core.usage.logs import RequestLogLike, cost_from_log
|
|
16
18
|
from app.core.utils.time import from_epoch_seconds, to_utc_naive, utcnow
|
|
17
19
|
from app.db.models import Account, AccountStatus, UsageHistory
|
|
18
20
|
from app.modules.accounts.repository import AccountsRepository
|
|
@@ -23,9 +25,9 @@ from app.modules.accounts.schemas import (
|
|
|
23
25
|
AccountTokenStatus,
|
|
24
26
|
AccountUsage,
|
|
25
27
|
)
|
|
26
|
-
from app.modules.proxy.usage_updater import UsageUpdater
|
|
27
28
|
from app.modules.request_logs.repository import RequestLogsRepository
|
|
28
29
|
from app.modules.usage.repository import UsageRepository
|
|
30
|
+
from app.modules.usage.updater import UsageUpdater
|
|
29
31
|
|
|
30
32
|
|
|
31
33
|
class AccountsService:
|
|
@@ -64,12 +66,14 @@ class AccountsService:
|
|
|
64
66
|
claims = claims_from_auth(auth)
|
|
65
67
|
|
|
66
68
|
email = claims.email or DEFAULT_EMAIL
|
|
67
|
-
|
|
68
|
-
account_id =
|
|
69
|
+
raw_account_id = claims.account_id
|
|
70
|
+
account_id = generate_unique_account_id(raw_account_id, email)
|
|
71
|
+
plan_type = coerce_account_plan_type(claims.plan_type, DEFAULT_PLAN)
|
|
69
72
|
last_refresh = to_utc_naive(auth.last_refresh_at) if auth.last_refresh_at else utcnow()
|
|
70
73
|
|
|
71
74
|
account = Account(
|
|
72
75
|
id=account_id,
|
|
76
|
+
chatgpt_account_id=raw_account_id,
|
|
73
77
|
email=email,
|
|
74
78
|
plan_type=plan_type,
|
|
75
79
|
access_token_encrypted=self._encryptor.encrypt(auth.tokens.access_token),
|
|
@@ -107,6 +111,7 @@ class AccountsService:
|
|
|
107
111
|
secondary_usage: UsageHistory | None,
|
|
108
112
|
cost_usd_24h: float | None,
|
|
109
113
|
) -> AccountSummary:
|
|
114
|
+
plan_type = coerce_account_plan_type(account.plan_type, DEFAULT_PLAN)
|
|
110
115
|
auth_status = self._build_auth_status(account)
|
|
111
116
|
primary_used_percent = _normalize_used_percent(primary_usage) or 0.0
|
|
112
117
|
secondary_used_percent = _normalize_used_percent(secondary_usage) or 0.0
|
|
@@ -114,8 +119,8 @@ class AccountsService:
|
|
|
114
119
|
secondary_remaining_percent = usage_core.remaining_percent_from_used(secondary_used_percent) or 0.0
|
|
115
120
|
reset_at_primary = from_epoch_seconds(primary_usage.reset_at) if primary_usage is not None else None
|
|
116
121
|
reset_at_secondary = from_epoch_seconds(secondary_usage.reset_at) if secondary_usage is not None else None
|
|
117
|
-
capacity_primary = usage_core.capacity_for_plan(
|
|
118
|
-
capacity_secondary = usage_core.capacity_for_plan(
|
|
122
|
+
capacity_primary = usage_core.capacity_for_plan(plan_type, "primary")
|
|
123
|
+
capacity_secondary = usage_core.capacity_for_plan(plan_type, "secondary")
|
|
119
124
|
remaining_credits_primary = usage_core.remaining_credits_from_percent(
|
|
120
125
|
primary_used_percent,
|
|
121
126
|
capacity_primary,
|
|
@@ -128,7 +133,7 @@ class AccountsService:
|
|
|
128
133
|
account_id=account.id,
|
|
129
134
|
email=account.email,
|
|
130
135
|
display_name=account.email,
|
|
131
|
-
plan_type=
|
|
136
|
+
plan_type=plan_type,
|
|
132
137
|
status=account.status.value,
|
|
133
138
|
usage=AccountUsage(
|
|
134
139
|
primary_remaining_percent=primary_remaining_percent,
|
|
@@ -186,7 +191,7 @@ class AccountsService:
|
|
|
186
191
|
logs = await self._logs_repo.list_since(since)
|
|
187
192
|
totals: dict[str, float] = {}
|
|
188
193
|
for log in logs:
|
|
189
|
-
cost = cost_from_log(log)
|
|
194
|
+
cost = cost_from_log(cast(RequestLogLike, log))
|
|
190
195
|
if cost is None:
|
|
191
196
|
continue
|
|
192
197
|
totals[log.account_id] = totals.get(log.account_id, 0.0) + cost
|
app/modules/health/api.py
CHANGED
|
@@ -2,9 +2,11 @@ from __future__ import annotations
|
|
|
2
2
|
|
|
3
3
|
from fastapi import APIRouter
|
|
4
4
|
|
|
5
|
+
from app.modules.health.schemas import HealthResponse
|
|
6
|
+
|
|
5
7
|
router = APIRouter(tags=["health"])
|
|
6
8
|
|
|
7
9
|
|
|
8
|
-
@router.get("/health")
|
|
9
|
-
async def health_check() ->
|
|
10
|
-
return
|
|
10
|
+
@router.get("/health", response_model=HealthResponse)
|
|
11
|
+
async def health_check() -> HealthResponse:
|
|
12
|
+
return HealthResponse(status="ok")
|
app/modules/oauth/service.py
CHANGED
|
@@ -15,7 +15,7 @@ from app.core.auth import (
|
|
|
15
15
|
DEFAULT_PLAN,
|
|
16
16
|
OpenAIAuthClaims,
|
|
17
17
|
extract_id_token_claims,
|
|
18
|
-
|
|
18
|
+
generate_unique_account_id,
|
|
19
19
|
)
|
|
20
20
|
from app.core.clients.oauth import (
|
|
21
21
|
OAuthError,
|
|
@@ -28,6 +28,7 @@ from app.core.clients.oauth import (
|
|
|
28
28
|
)
|
|
29
29
|
from app.core.config.settings import get_settings
|
|
30
30
|
from app.core.crypto import TokenEncryptor
|
|
31
|
+
from app.core.plan_types import coerce_account_plan_type
|
|
31
32
|
from app.core.utils.time import utcnow
|
|
32
33
|
from app.db.models import Account, AccountStatus
|
|
33
34
|
from app.modules.accounts.repository import AccountsRepository
|
|
@@ -293,13 +294,17 @@ class OauthService:
|
|
|
293
294
|
async def _persist_tokens(self, tokens: OAuthTokens) -> None:
|
|
294
295
|
claims = extract_id_token_claims(tokens.id_token)
|
|
295
296
|
auth_claims = claims.auth or OpenAIAuthClaims()
|
|
296
|
-
|
|
297
|
+
raw_account_id = auth_claims.chatgpt_account_id or claims.chatgpt_account_id
|
|
297
298
|
email = claims.email or DEFAULT_EMAIL
|
|
298
|
-
|
|
299
|
-
|
|
299
|
+
account_id = generate_unique_account_id(raw_account_id, email)
|
|
300
|
+
plan_type = coerce_account_plan_type(
|
|
301
|
+
auth_claims.chatgpt_plan_type or claims.chatgpt_plan_type,
|
|
302
|
+
DEFAULT_PLAN,
|
|
303
|
+
)
|
|
300
304
|
|
|
301
305
|
account = Account(
|
|
302
306
|
id=account_id,
|
|
307
|
+
chatgpt_account_id=raw_account_id,
|
|
303
308
|
email=email,
|
|
304
309
|
plan_type=plan_type,
|
|
305
310
|
access_token_encrypted=self._encryptor.encrypt(tokens.access_token),
|