codex-lb 0.1.5__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56) hide show
  1. app/__init__.py +1 -1
  2. app/core/auth/__init__.py +12 -1
  3. app/core/balancer/logic.py +44 -7
  4. app/core/clients/proxy.py +2 -4
  5. app/core/config/settings.py +4 -1
  6. app/core/plan_types.py +64 -0
  7. app/core/types.py +4 -2
  8. app/core/usage/__init__.py +5 -2
  9. app/core/usage/logs.py +12 -2
  10. app/core/usage/quota.py +64 -0
  11. app/core/usage/types.py +3 -2
  12. app/core/utils/sse.py +6 -2
  13. app/db/migrations/__init__.py +91 -0
  14. app/db/migrations/versions/__init__.py +1 -0
  15. app/db/migrations/versions/add_accounts_chatgpt_account_id.py +29 -0
  16. app/db/migrations/versions/add_accounts_reset_at.py +29 -0
  17. app/db/migrations/versions/add_dashboard_settings.py +31 -0
  18. app/db/migrations/versions/add_request_logs_reasoning_effort.py +21 -0
  19. app/db/migrations/versions/normalize_account_plan_types.py +17 -0
  20. app/db/models.py +33 -0
  21. app/db/session.py +85 -11
  22. app/dependencies.py +27 -9
  23. app/main.py +15 -6
  24. app/modules/accounts/auth_manager.py +121 -0
  25. app/modules/accounts/repository.py +14 -6
  26. app/modules/accounts/service.py +14 -9
  27. app/modules/health/api.py +5 -3
  28. app/modules/health/schemas.py +9 -0
  29. app/modules/oauth/service.py +9 -4
  30. app/modules/proxy/helpers.py +285 -0
  31. app/modules/proxy/load_balancer.py +86 -41
  32. app/modules/proxy/service.py +172 -318
  33. app/modules/proxy/sticky_repository.py +56 -0
  34. app/modules/request_logs/repository.py +6 -3
  35. app/modules/request_logs/schemas.py +2 -0
  36. app/modules/request_logs/service.py +12 -3
  37. app/modules/settings/__init__.py +1 -0
  38. app/modules/settings/api.py +37 -0
  39. app/modules/settings/repository.py +40 -0
  40. app/modules/settings/schemas.py +13 -0
  41. app/modules/settings/service.py +33 -0
  42. app/modules/shared/schemas.py +16 -2
  43. app/modules/usage/schemas.py +1 -0
  44. app/modules/usage/service.py +23 -6
  45. app/modules/{proxy/usage_updater.py → usage/updater.py} +37 -8
  46. app/static/7.css +73 -0
  47. app/static/index.css +33 -4
  48. app/static/index.html +51 -4
  49. app/static/index.js +254 -32
  50. {codex_lb-0.1.5.dist-info → codex_lb-0.3.0.dist-info}/METADATA +2 -2
  51. codex_lb-0.3.0.dist-info/RECORD +97 -0
  52. app/modules/proxy/auth_manager.py +0 -51
  53. codex_lb-0.1.5.dist-info/RECORD +0 -80
  54. {codex_lb-0.1.5.dist-info → codex_lb-0.3.0.dist-info}/WHEEL +0 -0
  55. {codex_lb-0.1.5.dist-info → codex_lb-0.3.0.dist-info}/entry_points.txt +0 -0
  56. {codex_lb-0.1.5.dist-info → codex_lb-0.3.0.dist-info}/licenses/LICENSE +0 -0
app/db/models.py CHANGED
@@ -24,6 +24,7 @@ class Account(Base):
24
24
  __tablename__ = "accounts"
25
25
 
26
26
  id: Mapped[str] = mapped_column(String, primary_key=True)
27
+ chatgpt_account_id: Mapped[str | None] = mapped_column(String, nullable=True)
27
28
  email: Mapped[str] = mapped_column(String, unique=True, nullable=False)
28
29
  plan_type: Mapped[str] = mapped_column(String, nullable=False)
29
30
 
@@ -40,6 +41,7 @@ class Account(Base):
40
41
  nullable=False,
41
42
  )
42
43
  deactivation_reason: Mapped[str | None] = mapped_column(Text, nullable=True)
44
+ reset_at: Mapped[int | None] = mapped_column(Integer, nullable=True)
43
45
 
44
46
 
45
47
  class UsageHistory(Base):
@@ -71,12 +73,43 @@ class RequestLog(Base):
71
73
  output_tokens: Mapped[int | None] = mapped_column(Integer, nullable=True)
72
74
  cached_input_tokens: Mapped[int | None] = mapped_column(Integer, nullable=True)
73
75
  reasoning_tokens: Mapped[int | None] = mapped_column(Integer, nullable=True)
76
+ reasoning_effort: Mapped[str | None] = mapped_column(String, nullable=True)
74
77
  latency_ms: Mapped[int | None] = mapped_column(Integer, nullable=True)
75
78
  status: Mapped[str] = mapped_column(String, nullable=False)
76
79
  error_code: Mapped[str | None] = mapped_column(String, nullable=True)
77
80
  error_message: Mapped[str | None] = mapped_column(Text, nullable=True)
78
81
 
79
82
 
83
+ class StickySession(Base):
84
+ __tablename__ = "sticky_sessions"
85
+
86
+ key: Mapped[str] = mapped_column(String, primary_key=True)
87
+ account_id: Mapped[str] = mapped_column(String, ForeignKey("accounts.id"), nullable=False)
88
+ created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now(), nullable=False)
89
+ updated_at: Mapped[datetime] = mapped_column(
90
+ DateTime,
91
+ server_default=func.now(),
92
+ onupdate=func.now(),
93
+ nullable=False,
94
+ )
95
+
96
+
97
+ class DashboardSettings(Base):
98
+ __tablename__ = "dashboard_settings"
99
+
100
+ id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=False)
101
+ sticky_threads_enabled: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
102
+ prefer_earlier_reset_accounts: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
103
+ created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now(), nullable=False)
104
+ updated_at: Mapped[datetime] = mapped_column(
105
+ DateTime,
106
+ server_default=func.now(),
107
+ onupdate=func.now(),
108
+ nullable=False,
109
+ )
110
+
111
+
80
112
  Index("idx_usage_recorded_at", UsageHistory.recorded_at)
81
113
  Index("idx_usage_account_time", UsageHistory.account_id, UsageHistory.recorded_at)
82
114
  Index("idx_logs_account_time", RequestLog.account_id, RequestLog.requested_at)
115
+ Index("idx_sticky_account", StickySession.account_id)
app/db/session.py CHANGED
@@ -1,42 +1,102 @@
1
1
  from __future__ import annotations
2
2
 
3
- import asyncio
3
+ import logging
4
+ import sqlite3
4
5
  from pathlib import Path
5
- from typing import AsyncIterator
6
+ from typing import AsyncIterator, Awaitable, TypeVar
6
7
 
8
+ import anyio
9
+ from sqlalchemy import event
10
+ from sqlalchemy.engine import Engine
7
11
  from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
8
12
 
9
13
  from app.core.config.settings import get_settings
14
+ from app.db.migrations import run_migrations
10
15
 
11
16
  DATABASE_URL = get_settings().database_url
12
17
 
13
- engine = create_async_engine(DATABASE_URL, echo=False)
18
+ logger = logging.getLogger(__name__)
19
+
20
+ _SQLITE_BUSY_TIMEOUT_MS = 5_000
21
+ _SQLITE_BUSY_TIMEOUT_SECONDS = _SQLITE_BUSY_TIMEOUT_MS / 1000
22
+
23
+
24
+ def _is_sqlite_url(url: str) -> bool:
25
+ return url.startswith("sqlite+aiosqlite:///") or url.startswith("sqlite:///")
26
+
27
+
28
+ def _is_sqlite_memory_url(url: str) -> bool:
29
+ return _is_sqlite_url(url) and ":memory:" in url
30
+
31
+
32
+ def _configure_sqlite_engine(engine: Engine, *, enable_wal: bool) -> None:
33
+ @event.listens_for(engine, "connect")
34
+ def _set_sqlite_pragmas(dbapi_connection: sqlite3.Connection, _: object) -> None:
35
+ cursor: sqlite3.Cursor = dbapi_connection.cursor()
36
+ try:
37
+ if enable_wal:
38
+ cursor.execute("PRAGMA journal_mode=WAL")
39
+ cursor.execute("PRAGMA synchronous=NORMAL")
40
+ cursor.execute("PRAGMA foreign_keys=ON")
41
+ cursor.execute(f"PRAGMA busy_timeout={_SQLITE_BUSY_TIMEOUT_MS}")
42
+ finally:
43
+ cursor.close()
44
+
45
+
46
+ if _is_sqlite_url(DATABASE_URL):
47
+ engine = create_async_engine(
48
+ DATABASE_URL,
49
+ echo=False,
50
+ connect_args={"timeout": _SQLITE_BUSY_TIMEOUT_SECONDS},
51
+ )
52
+ _configure_sqlite_engine(engine.sync_engine, enable_wal=not _is_sqlite_memory_url(DATABASE_URL))
53
+ else:
54
+ engine = create_async_engine(DATABASE_URL, echo=False)
55
+
14
56
  SessionLocal = async_sessionmaker(engine, expire_on_commit=False, class_=AsyncSession)
15
57
 
58
+ _T = TypeVar("_T")
59
+
16
60
 
17
61
  def _ensure_sqlite_dir(url: str) -> None:
18
- prefix = "sqlite+aiosqlite:///"
19
- if not url.startswith(prefix):
62
+ if not (url.startswith("sqlite+aiosqlite:") or url.startswith("sqlite:")):
20
63
  return
21
- path = url[len(prefix) :]
22
- if path == ":memory:":
64
+
65
+ marker = ":///"
66
+ marker_index = url.find(marker)
67
+ if marker_index < 0:
23
68
  return
69
+
70
+ # Works for both relative (sqlite+aiosqlite:///./db.sqlite) and absolute
71
+ # paths (sqlite+aiosqlite:////var/lib/app/db.sqlite).
72
+ path = url[marker_index + len(marker) :]
73
+ path = path.partition("?")[0]
74
+ path = path.partition("#")[0]
75
+
76
+ if not path or path == ":memory:":
77
+ return
78
+
24
79
  Path(path).expanduser().parent.mkdir(parents=True, exist_ok=True)
25
80
 
26
81
 
82
+ async def _shielded(awaitable: Awaitable[_T]) -> _T:
83
+ with anyio.CancelScope(shield=True):
84
+ return await awaitable
85
+
86
+
27
87
  async def _safe_rollback(session: AsyncSession) -> None:
28
88
  if not session.in_transaction():
29
89
  return
30
90
  try:
31
- await asyncio.shield(session.rollback())
32
- except Exception:
91
+ await _shielded(session.rollback())
92
+ except BaseException:
33
93
  return
34
94
 
35
95
 
36
96
  async def _safe_close(session: AsyncSession) -> None:
37
97
  try:
38
- await asyncio.shield(session.close())
39
- except Exception:
98
+ await _shielded(session.close())
99
+ except BaseException:
40
100
  return
41
101
 
42
102
 
@@ -60,3 +120,17 @@ async def init_db() -> None:
60
120
 
61
121
  async with engine.begin() as conn:
62
122
  await conn.run_sync(Base.metadata.create_all)
123
+
124
+ async with SessionLocal() as session:
125
+ try:
126
+ updated = await run_migrations(session)
127
+ if updated:
128
+ logger.info("Applied database migrations count=%s", updated)
129
+ except Exception:
130
+ logger.exception("Failed to apply database migrations")
131
+ if get_settings().database_migrations_fail_fast:
132
+ raise
133
+
134
+
135
+ async def close_db() -> None:
136
+ await engine.dispose()
app/dependencies.py CHANGED
@@ -12,8 +12,11 @@ from app.modules.accounts.repository import AccountsRepository
12
12
  from app.modules.accounts.service import AccountsService
13
13
  from app.modules.oauth.service import OauthService
14
14
  from app.modules.proxy.service import ProxyService
15
+ from app.modules.proxy.sticky_repository import StickySessionsRepository
15
16
  from app.modules.request_logs.repository import RequestLogsRepository
16
17
  from app.modules.request_logs.service import RequestLogsService
18
+ from app.modules.settings.repository import SettingsRepository
19
+ from app.modules.settings.service import SettingsService
17
20
  from app.modules.usage.repository import UsageRepository
18
21
  from app.modules.usage.service import UsageService
19
22
 
@@ -22,8 +25,6 @@ from app.modules.usage.service import UsageService
22
25
  class AccountsContext:
23
26
  session: AsyncSession
24
27
  repository: AccountsRepository
25
- usage_repository: UsageRepository
26
- request_logs_repository: RequestLogsRepository
27
28
  service: AccountsService
28
29
 
29
30
 
@@ -31,8 +32,6 @@ class AccountsContext:
31
32
  class UsageContext:
32
33
  session: AsyncSession
33
34
  usage_repository: UsageRepository
34
- request_logs_repository: RequestLogsRepository
35
- accounts_repository: AccountsRepository
36
35
  service: UsageService
37
36
 
38
37
 
@@ -53,6 +52,13 @@ class RequestLogsContext:
53
52
  service: RequestLogsService
54
53
 
55
54
 
55
+ @dataclass(slots=True)
56
+ class SettingsContext:
57
+ session: AsyncSession
58
+ repository: SettingsRepository
59
+ service: SettingsService
60
+
61
+
56
62
  def get_accounts_context(
57
63
  session: AsyncSession = Depends(get_session),
58
64
  ) -> AccountsContext:
@@ -63,8 +69,6 @@ def get_accounts_context(
63
69
  return AccountsContext(
64
70
  session=session,
65
71
  repository=repository,
66
- usage_repository=usage_repository,
67
- request_logs_repository=request_logs_repository,
68
72
  service=service,
69
73
  )
70
74
 
@@ -79,8 +83,6 @@ def get_usage_context(
79
83
  return UsageContext(
80
84
  session=session,
81
85
  usage_repository=usage_repository,
82
- request_logs_repository=request_logs_repository,
83
- accounts_repository=accounts_repository,
84
86
  service=service,
85
87
  )
86
88
 
@@ -112,7 +114,15 @@ def get_proxy_context(
112
114
  accounts_repository = AccountsRepository(session)
113
115
  usage_repository = UsageRepository(session)
114
116
  request_logs_repository = RequestLogsRepository(session)
115
- service = ProxyService(accounts_repository, usage_repository, request_logs_repository)
117
+ sticky_repository = StickySessionsRepository(session)
118
+ settings_repository = SettingsRepository(session)
119
+ service = ProxyService(
120
+ accounts_repository,
121
+ usage_repository,
122
+ request_logs_repository,
123
+ sticky_repository,
124
+ settings_repository,
125
+ )
116
126
  return ProxyContext(service=service)
117
127
 
118
128
 
@@ -122,3 +132,11 @@ def get_request_logs_context(
122
132
  repository = RequestLogsRepository(session)
123
133
  service = RequestLogsService(repository)
124
134
  return RequestLogsContext(session=session, repository=repository, service=service)
135
+
136
+
137
+ def get_settings_context(
138
+ session: AsyncSession = Depends(get_session),
139
+ ) -> SettingsContext:
140
+ repository = SettingsRepository(session)
141
+ service = SettingsService(repository)
142
+ return SettingsContext(session=session, repository=repository, service=service)
app/main.py CHANGED
@@ -11,19 +11,20 @@ from fastapi.exception_handlers import (
11
11
  request_validation_exception_handler,
12
12
  )
13
13
  from fastapi.exceptions import RequestValidationError
14
- from fastapi.responses import FileResponse, JSONResponse, RedirectResponse
14
+ from fastapi.responses import FileResponse, JSONResponse, RedirectResponse, Response
15
15
  from fastapi.staticfiles import StaticFiles
16
16
  from starlette.exceptions import HTTPException as StarletteHTTPException
17
17
 
18
18
  from app.core.clients.http import close_http_client, init_http_client
19
19
  from app.core.errors import dashboard_error
20
20
  from app.core.utils.request_id import get_request_id, reset_request_id, set_request_id
21
- from app.db.session import init_db
21
+ from app.db.session import close_db, init_db
22
22
  from app.modules.accounts import api as accounts_api
23
23
  from app.modules.health import api as health_api
24
24
  from app.modules.oauth import api as oauth_api
25
25
  from app.modules.proxy import api as proxy_api
26
26
  from app.modules.request_logs import api as request_logs_api
27
+ from app.modules.settings import api as settings_api
27
28
  from app.modules.usage import api as usage_api
28
29
 
29
30
  logger = logging.getLogger(__name__)
@@ -37,7 +38,10 @@ async def lifespan(_: FastAPI):
37
38
  try:
38
39
  yield
39
40
  finally:
40
- await close_http_client()
41
+ try:
42
+ await close_http_client()
43
+ finally:
44
+ await close_db()
41
45
 
42
46
 
43
47
  def create_app() -> FastAPI:
@@ -57,7 +61,7 @@ def create_app() -> FastAPI:
57
61
  return response
58
62
 
59
63
  @app.middleware("http")
60
- async def api_unhandled_error_middleware(request: Request, call_next) -> JSONResponse:
64
+ async def api_unhandled_error_middleware(request: Request, call_next) -> Response:
61
65
  try:
62
66
  return await call_next(request)
63
67
  except Exception:
@@ -76,7 +80,7 @@ def create_app() -> FastAPI:
76
80
  async def _validation_error_handler(
77
81
  request: Request,
78
82
  exc: RequestValidationError,
79
- ) -> JSONResponse:
83
+ ) -> Response:
80
84
  if request.url.path.startswith("/api/"):
81
85
  return JSONResponse(
82
86
  status_code=422,
@@ -88,7 +92,7 @@ def create_app() -> FastAPI:
88
92
  async def _http_error_handler(
89
93
  request: Request,
90
94
  exc: StarletteHTTPException,
91
- ) -> JSONResponse:
95
+ ) -> Response:
92
96
  if request.url.path.startswith("/api/"):
93
97
  detail = exc.detail if isinstance(exc.detail, str) else "Request failed"
94
98
  return JSONResponse(
@@ -103,6 +107,7 @@ def create_app() -> FastAPI:
103
107
  app.include_router(usage_api.router)
104
108
  app.include_router(request_logs_api.router)
105
109
  app.include_router(oauth_api.router)
110
+ app.include_router(settings_api.router)
106
111
  app.include_router(health_api.router)
107
112
 
108
113
  static_dir = Path(__file__).parent / "static"
@@ -116,6 +121,10 @@ def create_app() -> FastAPI:
116
121
  async def spa_accounts():
117
122
  return FileResponse(index_html, media_type="text/html")
118
123
 
124
+ @app.get("/settings", include_in_schema=False)
125
+ async def spa_settings():
126
+ return FileResponse(index_html, media_type="text/html")
127
+
119
128
  app.mount("/dashboard", StaticFiles(directory=static_dir, html=True), name="dashboard")
120
129
 
121
130
  return app
@@ -0,0 +1,121 @@
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ from datetime import datetime
5
+ from typing import Protocol
6
+
7
+ from app.core.auth import DEFAULT_PLAN, OpenAIAuthClaims, extract_id_token_claims
8
+ from app.core.auth.refresh import RefreshError, refresh_access_token, should_refresh
9
+ from app.core.balancer import PERMANENT_FAILURE_CODES
10
+ from app.core.crypto import TokenEncryptor
11
+ from app.core.plan_types import coerce_account_plan_type
12
+ from app.core.utils.time import utcnow
13
+ from app.db.models import Account, AccountStatus
14
+
15
+
16
+ class AccountsRepositoryPort(Protocol):
17
+ async def update_status(
18
+ self,
19
+ account_id: str,
20
+ status: AccountStatus,
21
+ deactivation_reason: str | None = None,
22
+ ) -> bool: ...
23
+
24
+ async def update_tokens(
25
+ self,
26
+ account_id: str,
27
+ access_token_encrypted: bytes,
28
+ refresh_token_encrypted: bytes,
29
+ id_token_encrypted: bytes,
30
+ last_refresh: datetime,
31
+ plan_type: str | None = None,
32
+ email: str | None = None,
33
+ chatgpt_account_id: str | None = None,
34
+ ) -> bool: ...
35
+
36
+
37
+ logger = logging.getLogger(__name__)
38
+
39
+
40
+ class AuthManager:
41
+ def __init__(self, repo: AccountsRepositoryPort) -> None:
42
+ self._repo = repo
43
+ self._encryptor = TokenEncryptor()
44
+
45
+ async def ensure_fresh(self, account: Account, *, force: bool = False) -> Account:
46
+ if force or should_refresh(account.last_refresh):
47
+ account = await self.refresh_account(account)
48
+ return await self._ensure_chatgpt_account_id(account)
49
+
50
+ async def refresh_account(self, account: Account) -> Account:
51
+ refresh_token = self._encryptor.decrypt(account.refresh_token_encrypted)
52
+ try:
53
+ result = await refresh_access_token(refresh_token)
54
+ except RefreshError as exc:
55
+ if exc.is_permanent:
56
+ reason = PERMANENT_FAILURE_CODES.get(exc.code, exc.message)
57
+ await self._repo.update_status(account.id, AccountStatus.DEACTIVATED, reason)
58
+ account.status = AccountStatus.DEACTIVATED
59
+ account.deactivation_reason = reason
60
+ raise
61
+
62
+ account.access_token_encrypted = self._encryptor.encrypt(result.access_token)
63
+ account.refresh_token_encrypted = self._encryptor.encrypt(result.refresh_token)
64
+ account.id_token_encrypted = self._encryptor.encrypt(result.id_token)
65
+ account.last_refresh = utcnow()
66
+ if result.account_id:
67
+ account.chatgpt_account_id = result.account_id
68
+ if result.plan_type is not None:
69
+ account.plan_type = coerce_account_plan_type(
70
+ result.plan_type,
71
+ account.plan_type or DEFAULT_PLAN,
72
+ )
73
+ elif not account.plan_type:
74
+ account.plan_type = DEFAULT_PLAN
75
+ if result.email:
76
+ account.email = result.email
77
+
78
+ await self._repo.update_tokens(
79
+ account.id,
80
+ access_token_encrypted=account.access_token_encrypted,
81
+ refresh_token_encrypted=account.refresh_token_encrypted,
82
+ id_token_encrypted=account.id_token_encrypted,
83
+ last_refresh=account.last_refresh,
84
+ plan_type=account.plan_type,
85
+ email=account.email,
86
+ chatgpt_account_id=account.chatgpt_account_id,
87
+ )
88
+ return account
89
+
90
+ async def _ensure_chatgpt_account_id(self, account: Account) -> Account:
91
+ if account.chatgpt_account_id:
92
+ return account
93
+ try:
94
+ id_token = self._encryptor.decrypt(account.id_token_encrypted)
95
+ except Exception:
96
+ return account
97
+ raw_account_id = _chatgpt_account_id_from_id_token(id_token)
98
+ if not raw_account_id:
99
+ return account
100
+
101
+ account.chatgpt_account_id = raw_account_id
102
+ try:
103
+ await self._repo.update_tokens(
104
+ account.id,
105
+ access_token_encrypted=account.access_token_encrypted,
106
+ refresh_token_encrypted=account.refresh_token_encrypted,
107
+ id_token_encrypted=account.id_token_encrypted,
108
+ last_refresh=account.last_refresh,
109
+ plan_type=account.plan_type,
110
+ email=account.email,
111
+ chatgpt_account_id=raw_account_id,
112
+ )
113
+ except Exception:
114
+ logger.warning("Failed to persist chatgpt_account_id account_id=%s", account.id, exc_info=True)
115
+ return account
116
+
117
+
118
+ def _chatgpt_account_id_from_id_token(id_token: str) -> str | None:
119
+ claims = extract_id_token_claims(id_token)
120
+ auth_claims = claims.auth or OpenAIAuthClaims()
121
+ return auth_claims.chatgpt_account_id or claims.chatgpt_account_id
@@ -19,6 +19,7 @@ class AccountsRepository:
19
19
  async def upsert(self, account: Account) -> Account:
20
20
  existing = await self._session.get(Account, account.id)
21
21
  if existing:
22
+ existing.chatgpt_account_id = account.chatgpt_account_id
22
23
  existing.email = account.email
23
24
  existing.plan_type = account.plan_type
24
25
  existing.access_token_encrypted = account.access_token_encrypted
@@ -41,19 +42,21 @@ class AccountsRepository:
41
42
  account_id: str,
42
43
  status: AccountStatus,
43
44
  deactivation_reason: str | None = None,
45
+ reset_at: int | None = None,
44
46
  ) -> bool:
45
47
  result = await self._session.execute(
46
48
  update(Account)
47
49
  .where(Account.id == account_id)
48
- .values(status=status, deactivation_reason=deactivation_reason)
50
+ .values(status=status, deactivation_reason=deactivation_reason, reset_at=reset_at)
51
+ .returning(Account.id)
49
52
  )
50
53
  await self._session.commit()
51
- return bool(result.rowcount)
54
+ return result.scalar_one_or_none() is not None
52
55
 
53
56
  async def delete(self, account_id: str) -> bool:
54
- result = await self._session.execute(delete(Account).where(Account.id == account_id))
57
+ result = await self._session.execute(delete(Account).where(Account.id == account_id).returning(Account.id))
55
58
  await self._session.commit()
56
- return bool(result.rowcount)
59
+ return result.scalar_one_or_none() is not None
57
60
 
58
61
  async def update_tokens(
59
62
  self,
@@ -64,6 +67,7 @@ class AccountsRepository:
64
67
  last_refresh: datetime,
65
68
  plan_type: str | None = None,
66
69
  email: str | None = None,
70
+ chatgpt_account_id: str | None = None,
67
71
  ) -> bool:
68
72
  values = {
69
73
  "access_token_encrypted": access_token_encrypted,
@@ -75,6 +79,10 @@ class AccountsRepository:
75
79
  values["plan_type"] = plan_type
76
80
  if email is not None:
77
81
  values["email"] = email
78
- result = await self._session.execute(update(Account).where(Account.id == account_id).values(**values))
82
+ if chatgpt_account_id is not None:
83
+ values["chatgpt_account_id"] = chatgpt_account_id
84
+ result = await self._session.execute(
85
+ update(Account).where(Account.id == account_id).values(**values).returning(Account.id)
86
+ )
79
87
  await self._session.commit()
80
- return bool(result.rowcount)
88
+ return result.scalar_one_or_none() is not None
@@ -1,6 +1,7 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  from datetime import datetime, timedelta, timezone
4
+ from typing import cast
4
5
 
5
6
  from app.core import usage as usage_core
6
7
  from app.core.auth import (
@@ -8,11 +9,12 @@ from app.core.auth import (
8
9
  DEFAULT_PLAN,
9
10
  claims_from_auth,
10
11
  extract_id_token_claims,
11
- fallback_account_id,
12
+ generate_unique_account_id,
12
13
  parse_auth_json,
13
14
  )
14
15
  from app.core.crypto import TokenEncryptor
15
- from app.core.usage.logs import cost_from_log
16
+ from app.core.plan_types import coerce_account_plan_type
17
+ from app.core.usage.logs import RequestLogLike, cost_from_log
16
18
  from app.core.utils.time import from_epoch_seconds, to_utc_naive, utcnow
17
19
  from app.db.models import Account, AccountStatus, UsageHistory
18
20
  from app.modules.accounts.repository import AccountsRepository
@@ -23,9 +25,9 @@ from app.modules.accounts.schemas import (
23
25
  AccountTokenStatus,
24
26
  AccountUsage,
25
27
  )
26
- from app.modules.proxy.usage_updater import UsageUpdater
27
28
  from app.modules.request_logs.repository import RequestLogsRepository
28
29
  from app.modules.usage.repository import UsageRepository
30
+ from app.modules.usage.updater import UsageUpdater
29
31
 
30
32
 
31
33
  class AccountsService:
@@ -64,12 +66,14 @@ class AccountsService:
64
66
  claims = claims_from_auth(auth)
65
67
 
66
68
  email = claims.email or DEFAULT_EMAIL
67
- plan_type = claims.plan_type or DEFAULT_PLAN
68
- account_id = claims.account_id or fallback_account_id(email)
69
+ raw_account_id = claims.account_id
70
+ account_id = generate_unique_account_id(raw_account_id, email)
71
+ plan_type = coerce_account_plan_type(claims.plan_type, DEFAULT_PLAN)
69
72
  last_refresh = to_utc_naive(auth.last_refresh_at) if auth.last_refresh_at else utcnow()
70
73
 
71
74
  account = Account(
72
75
  id=account_id,
76
+ chatgpt_account_id=raw_account_id,
73
77
  email=email,
74
78
  plan_type=plan_type,
75
79
  access_token_encrypted=self._encryptor.encrypt(auth.tokens.access_token),
@@ -107,6 +111,7 @@ class AccountsService:
107
111
  secondary_usage: UsageHistory | None,
108
112
  cost_usd_24h: float | None,
109
113
  ) -> AccountSummary:
114
+ plan_type = coerce_account_plan_type(account.plan_type, DEFAULT_PLAN)
110
115
  auth_status = self._build_auth_status(account)
111
116
  primary_used_percent = _normalize_used_percent(primary_usage) or 0.0
112
117
  secondary_used_percent = _normalize_used_percent(secondary_usage) or 0.0
@@ -114,8 +119,8 @@ class AccountsService:
114
119
  secondary_remaining_percent = usage_core.remaining_percent_from_used(secondary_used_percent) or 0.0
115
120
  reset_at_primary = from_epoch_seconds(primary_usage.reset_at) if primary_usage is not None else None
116
121
  reset_at_secondary = from_epoch_seconds(secondary_usage.reset_at) if secondary_usage is not None else None
117
- capacity_primary = usage_core.capacity_for_plan(account.plan_type, "primary")
118
- capacity_secondary = usage_core.capacity_for_plan(account.plan_type, "secondary")
122
+ capacity_primary = usage_core.capacity_for_plan(plan_type, "primary")
123
+ capacity_secondary = usage_core.capacity_for_plan(plan_type, "secondary")
119
124
  remaining_credits_primary = usage_core.remaining_credits_from_percent(
120
125
  primary_used_percent,
121
126
  capacity_primary,
@@ -128,7 +133,7 @@ class AccountsService:
128
133
  account_id=account.id,
129
134
  email=account.email,
130
135
  display_name=account.email,
131
- plan_type=account.plan_type,
136
+ plan_type=plan_type,
132
137
  status=account.status.value,
133
138
  usage=AccountUsage(
134
139
  primary_remaining_percent=primary_remaining_percent,
@@ -186,7 +191,7 @@ class AccountsService:
186
191
  logs = await self._logs_repo.list_since(since)
187
192
  totals: dict[str, float] = {}
188
193
  for log in logs:
189
- cost = cost_from_log(log)
194
+ cost = cost_from_log(cast(RequestLogLike, log))
190
195
  if cost is None:
191
196
  continue
192
197
  totals[log.account_id] = totals.get(log.account_id, 0.0) + cost
app/modules/health/api.py CHANGED
@@ -2,9 +2,11 @@ from __future__ import annotations
2
2
 
3
3
  from fastapi import APIRouter
4
4
 
5
+ from app.modules.health.schemas import HealthResponse
6
+
5
7
  router = APIRouter(tags=["health"])
6
8
 
7
9
 
8
- @router.get("/health")
9
- async def health_check() -> dict:
10
- return {"status": "ok"}
10
+ @router.get("/health", response_model=HealthResponse)
11
+ async def health_check() -> HealthResponse:
12
+ return HealthResponse(status="ok")
@@ -0,0 +1,9 @@
1
+ from __future__ import annotations
2
+
3
+ from pydantic import BaseModel, ConfigDict
4
+
5
+
6
+ class HealthResponse(BaseModel):
7
+ model_config = ConfigDict(extra="ignore")
8
+
9
+ status: str
@@ -15,7 +15,7 @@ from app.core.auth import (
15
15
  DEFAULT_PLAN,
16
16
  OpenAIAuthClaims,
17
17
  extract_id_token_claims,
18
- fallback_account_id,
18
+ generate_unique_account_id,
19
19
  )
20
20
  from app.core.clients.oauth import (
21
21
  OAuthError,
@@ -28,6 +28,7 @@ from app.core.clients.oauth import (
28
28
  )
29
29
  from app.core.config.settings import get_settings
30
30
  from app.core.crypto import TokenEncryptor
31
+ from app.core.plan_types import coerce_account_plan_type
31
32
  from app.core.utils.time import utcnow
32
33
  from app.db.models import Account, AccountStatus
33
34
  from app.modules.accounts.repository import AccountsRepository
@@ -293,13 +294,17 @@ class OauthService:
293
294
  async def _persist_tokens(self, tokens: OAuthTokens) -> None:
294
295
  claims = extract_id_token_claims(tokens.id_token)
295
296
  auth_claims = claims.auth or OpenAIAuthClaims()
296
- account_id = auth_claims.chatgpt_account_id or claims.chatgpt_account_id
297
+ raw_account_id = auth_claims.chatgpt_account_id or claims.chatgpt_account_id
297
298
  email = claims.email or DEFAULT_EMAIL
298
- plan_type = auth_claims.chatgpt_plan_type or claims.chatgpt_plan_type or DEFAULT_PLAN
299
- account_id = account_id or fallback_account_id(email)
299
+ account_id = generate_unique_account_id(raw_account_id, email)
300
+ plan_type = coerce_account_plan_type(
301
+ auth_claims.chatgpt_plan_type or claims.chatgpt_plan_type,
302
+ DEFAULT_PLAN,
303
+ )
300
304
 
301
305
  account = Account(
302
306
  id=account_id,
307
+ chatgpt_account_id=raw_account_id,
303
308
  email=email,
304
309
  plan_type=plan_type,
305
310
  access_token_encrypted=self._encryptor.encrypt(tokens.access_token),