codex-lb 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (80) hide show
  1. app/__init__.py +5 -0
  2. app/cli.py +24 -0
  3. app/core/__init__.py +0 -0
  4. app/core/auth/__init__.py +96 -0
  5. app/core/auth/models.py +49 -0
  6. app/core/auth/refresh.py +144 -0
  7. app/core/balancer/__init__.py +19 -0
  8. app/core/balancer/logic.py +140 -0
  9. app/core/balancer/types.py +9 -0
  10. app/core/clients/__init__.py +0 -0
  11. app/core/clients/http.py +39 -0
  12. app/core/clients/oauth.py +340 -0
  13. app/core/clients/proxy.py +265 -0
  14. app/core/clients/usage.py +143 -0
  15. app/core/config/__init__.py +0 -0
  16. app/core/config/settings.py +69 -0
  17. app/core/crypto.py +37 -0
  18. app/core/errors.py +73 -0
  19. app/core/openai/__init__.py +0 -0
  20. app/core/openai/models.py +122 -0
  21. app/core/openai/parsing.py +55 -0
  22. app/core/openai/requests.py +59 -0
  23. app/core/types.py +4 -0
  24. app/core/usage/__init__.py +185 -0
  25. app/core/usage/logs.py +57 -0
  26. app/core/usage/models.py +35 -0
  27. app/core/usage/pricing.py +172 -0
  28. app/core/usage/types.py +95 -0
  29. app/core/utils/__init__.py +0 -0
  30. app/core/utils/request_id.py +30 -0
  31. app/core/utils/retry.py +16 -0
  32. app/core/utils/sse.py +13 -0
  33. app/core/utils/time.py +19 -0
  34. app/db/__init__.py +0 -0
  35. app/db/models.py +82 -0
  36. app/db/session.py +44 -0
  37. app/dependencies.py +123 -0
  38. app/main.py +124 -0
  39. app/modules/__init__.py +0 -0
  40. app/modules/accounts/__init__.py +0 -0
  41. app/modules/accounts/api.py +81 -0
  42. app/modules/accounts/repository.py +80 -0
  43. app/modules/accounts/schemas.py +66 -0
  44. app/modules/accounts/service.py +211 -0
  45. app/modules/health/__init__.py +0 -0
  46. app/modules/health/api.py +10 -0
  47. app/modules/oauth/__init__.py +0 -0
  48. app/modules/oauth/api.py +57 -0
  49. app/modules/oauth/schemas.py +32 -0
  50. app/modules/oauth/service.py +356 -0
  51. app/modules/oauth/templates/oauth_success.html +122 -0
  52. app/modules/proxy/__init__.py +0 -0
  53. app/modules/proxy/api.py +76 -0
  54. app/modules/proxy/auth_manager.py +51 -0
  55. app/modules/proxy/load_balancer.py +208 -0
  56. app/modules/proxy/schemas.py +85 -0
  57. app/modules/proxy/service.py +707 -0
  58. app/modules/proxy/types.py +37 -0
  59. app/modules/proxy/usage_updater.py +147 -0
  60. app/modules/request_logs/__init__.py +0 -0
  61. app/modules/request_logs/api.py +31 -0
  62. app/modules/request_logs/repository.py +86 -0
  63. app/modules/request_logs/schemas.py +25 -0
  64. app/modules/request_logs/service.py +77 -0
  65. app/modules/shared/__init__.py +0 -0
  66. app/modules/shared/schemas.py +8 -0
  67. app/modules/usage/__init__.py +0 -0
  68. app/modules/usage/api.py +31 -0
  69. app/modules/usage/repository.py +113 -0
  70. app/modules/usage/schemas.py +62 -0
  71. app/modules/usage/service.py +246 -0
  72. app/static/7.css +1336 -0
  73. app/static/index.css +543 -0
  74. app/static/index.html +457 -0
  75. app/static/index.js +1898 -0
  76. codex_lb-0.1.2.dist-info/METADATA +108 -0
  77. codex_lb-0.1.2.dist-info/RECORD +80 -0
  78. codex_lb-0.1.2.dist-info/WHEEL +4 -0
  79. codex_lb-0.1.2.dist-info/entry_points.txt +2 -0
  80. codex_lb-0.1.2.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,211 @@
1
+ from __future__ import annotations
2
+
3
+ from datetime import datetime, timedelta, timezone
4
+
5
+ from app.core import usage as usage_core
6
+ from app.core.auth import (
7
+ DEFAULT_EMAIL,
8
+ DEFAULT_PLAN,
9
+ claims_from_auth,
10
+ extract_id_token_claims,
11
+ fallback_account_id,
12
+ parse_auth_json,
13
+ )
14
+ from app.core.crypto import TokenEncryptor
15
+ from app.core.usage.logs import cost_from_log
16
+ from app.core.utils.time import from_epoch_seconds, to_utc_naive, utcnow
17
+ from app.db.models import Account, AccountStatus, UsageHistory
18
+ from app.modules.accounts.repository import AccountsRepository
19
+ from app.modules.accounts.schemas import (
20
+ AccountAuthStatus,
21
+ AccountImportResponse,
22
+ AccountSummary,
23
+ AccountTokenStatus,
24
+ AccountUsage,
25
+ )
26
+ from app.modules.proxy.usage_updater import UsageUpdater
27
+ from app.modules.request_logs.repository import RequestLogsRepository
28
+ from app.modules.usage.repository import UsageRepository
29
+
30
+
31
class AccountsService:
    """Dashboard operations on stored accounts: listing, importing and status changes.

    The usage and request-log repositories are optional. Without a usage repository
    no usage rows are looked up and no usage refresh is attempted; without a logs
    repository the 24-hour cost of every account is reported as ``0.0``.
    """

    def __init__(
        self,
        repo: AccountsRepository,
        usage_repo: UsageRepository | None = None,
        logs_repo: RequestLogsRepository | None = None,
    ) -> None:
        self._repo = repo
        self._usage_repo = usage_repo
        self._logs_repo = logs_repo
        # The updater depends on the usage repository, so it only exists alongside one.
        self._usage_updater = UsageUpdater(usage_repo, repo) if usage_repo else None
        self._encryptor = TokenEncryptor()

    async def list_accounts(self) -> list[AccountSummary]:
        """Return a dashboard summary for every stored account (``[]`` when there are none).

        Usage is refreshed first; then the latest primary/secondary usage rows and the
        request cost of the last 24 hours are attached to each summary.
        """
        accounts = await self._repo.list_accounts()
        if not accounts:
            return []
        await self._refresh_usage(accounts)
        # Read usage after the refresh so whatever the refresh recorded is reflected.
        primary_usage = await self._usage_repo.latest_by_account(window="primary") if self._usage_repo else {}
        secondary_usage = await self._usage_repo.latest_by_account(window="secondary") if self._usage_repo else {}
        cost_by_account = await self._costs_last_24h()
        return [
            self._account_to_summary(
                account,
                primary_usage.get(account.id),
                secondary_usage.get(account.id),
                cost_by_account.get(account.id),
            )
            for account in accounts
        ]

    async def import_account(self, raw: bytes) -> AccountImportResponse:
        """Create or update an account from a raw auth JSON payload.

        The payload is parsed with ``parse_auth_json``. Missing claims fall back to
        ``DEFAULT_EMAIL`` / ``DEFAULT_PLAN`` and to an account id derived from the email.
        Tokens are stored encrypted, the account is (re)activated, and its usage is
        refreshed when a usage repository is configured.
        """
        auth = parse_auth_json(raw)
        claims = claims_from_auth(auth)

        email = claims.email or DEFAULT_EMAIL
        plan_type = claims.plan_type or DEFAULT_PLAN
        account_id = claims.account_id or fallback_account_id(email)
        last_refresh = to_utc_naive(auth.last_refresh_at) if auth.last_refresh_at else utcnow()

        account = Account(
            id=account_id,
            email=email,
            plan_type=plan_type,
            access_token_encrypted=self._encryptor.encrypt(auth.tokens.access_token),
            refresh_token_encrypted=self._encryptor.encrypt(auth.tokens.refresh_token),
            id_token_encrypted=self._encryptor.encrypt(auth.tokens.id_token),
            last_refresh=last_refresh,
            status=AccountStatus.ACTIVE,
            deactivation_reason=None,
        )

        saved = await self._repo.upsert(account)
        # Same refresh path as list_accounts, so the guard/fetch logic lives in one place.
        await self._refresh_usage([saved])
        return AccountImportResponse(
            account_id=saved.id,
            email=saved.email,
            plan_type=saved.plan_type,
            status=saved.status,
        )

    async def reactivate_account(self, account_id: str) -> bool:
        """Mark the account ACTIVE and clear its deactivation reason.

        Returns the repository's result (presumably ``False`` for an unknown id — confirm
        against ``AccountsRepository.update_status``).
        """
        return await self._repo.update_status(account_id, AccountStatus.ACTIVE, None)

    async def pause_account(self, account_id: str) -> bool:
        """Mark the account PAUSED and clear its deactivation reason; returns the repository's result."""
        return await self._repo.update_status(account_id, AccountStatus.PAUSED, None)

    async def delete_account(self, account_id: str) -> bool:
        """Delete the account; returns the repository's result."""
        return await self._repo.delete(account_id)

    def _account_to_summary(
        self,
        account: Account,
        primary_usage: UsageHistory | None,
        secondary_usage: UsageHistory | None,
        cost_usd_24h: float | None,
    ) -> AccountSummary:
        """Build the dashboard summary for one account.

        A missing usage row (or a missing ``used_percent``) counts as 0% used, and a
        missing cost counts as ``0.0``.
        """
        auth_status = self._build_auth_status(account)
        primary_used_percent = _normalize_used_percent(primary_usage) or 0.0
        secondary_used_percent = _normalize_used_percent(secondary_usage) or 0.0
        primary_remaining_percent = usage_core.remaining_percent_from_used(primary_used_percent) or 0.0
        secondary_remaining_percent = usage_core.remaining_percent_from_used(secondary_used_percent) or 0.0
        reset_at_primary = from_epoch_seconds(primary_usage.reset_at) if primary_usage is not None else None
        reset_at_secondary = from_epoch_seconds(secondary_usage.reset_at) if secondary_usage is not None else None
        capacity_primary = usage_core.capacity_for_plan(account.plan_type, "primary")
        capacity_secondary = usage_core.capacity_for_plan(account.plan_type, "secondary")
        remaining_credits_primary = usage_core.remaining_credits_from_percent(
            primary_used_percent,
            capacity_primary,
        )
        remaining_credits_secondary = usage_core.remaining_credits_from_percent(
            secondary_used_percent,
            capacity_secondary,
        )
        return AccountSummary(
            account_id=account.id,
            email=account.email,
            display_name=account.email,
            plan_type=account.plan_type,
            status=account.status.value,
            usage=AccountUsage(
                primary_remaining_percent=primary_remaining_percent,
                secondary_remaining_percent=secondary_remaining_percent,
            ),
            reset_at_primary=reset_at_primary,
            reset_at_secondary=reset_at_secondary,
            last_refresh_at=account.last_refresh,
            capacity_credits_primary=capacity_primary,
            remaining_credits_primary=remaining_credits_primary,
            capacity_credits_secondary=capacity_secondary,
            remaining_credits_secondary=remaining_credits_secondary,
            cost_usd_24h=cost_usd_24h if cost_usd_24h is not None else 0.0,
            deactivation_reason=account.deactivation_reason,
            auth=auth_status,
        )

    def _build_auth_status(self, account: Account) -> AccountAuthStatus:
        """Describe the stored tokens without exposing them.

        access: expiry taken from the token's ``exp`` claim (``None`` if unavailable);
        refresh: ``"stored"`` or ``"missing"``;
        id token: ``"parsed"`` when at least one claim could be extracted, else ``"unknown"``.
        """
        access_token = self._decrypt_token(account.access_token_encrypted)
        refresh_token = self._decrypt_token(account.refresh_token_encrypted)
        id_token = self._decrypt_token(account.id_token_encrypted)

        access_expires = _token_expiry(access_token)
        refresh_state = "stored" if refresh_token else "missing"
        id_state = "unknown"
        if id_token:
            claims = extract_id_token_claims(id_token)
            if claims.model_dump(exclude_none=True):
                id_state = "parsed"

        return AccountAuthStatus(
            access=AccountTokenStatus(expires_at=access_expires),
            refresh=AccountTokenStatus(state=refresh_state),
            id_token=AccountTokenStatus(state=id_state),
        )

    def _decrypt_token(self, encrypted: bytes | None) -> str | None:
        """Decrypt a stored token, returning ``None`` when it is absent or cannot be decrypted.

        Best-effort on purpose: this only feeds the status display, so an undecryptable
        token is reported like a missing one instead of failing the whole listing.
        """
        if not encrypted:
            return None
        try:
            return self._encryptor.decrypt(encrypted)
        except Exception:
            return None

    async def _refresh_usage(self, accounts: list[Account]) -> None:
        """Ask the usage updater to refresh ``accounts``; no-op without a usage repository."""
        if not self._usage_repo or not self._usage_updater:
            return
        latest_usage = await self._usage_repo.latest_by_account(window="primary")
        await self._usage_updater.refresh_accounts(accounts, latest_usage)

    async def _costs_last_24h(self) -> dict[str, float]:
        """Sum the cost of request logs from the last 24 hours per account id.

        Logs whose cost cannot be determined are skipped; totals are rounded to
        6 decimal places. Returns ``{}`` without a logs repository.
        """
        if not self._logs_repo:
            return {}
        since = utcnow() - timedelta(hours=24)
        logs = await self._logs_repo.list_since(since)
        totals: dict[str, float] = {}
        for log in logs:
            cost = cost_from_log(log)
            if cost is None:
                continue
            totals[log.account_id] = totals.get(log.account_id, 0.0) + cost
        return {account_id: round(total, 6) for account_id, total in totals.items()}
194
+
195
+
196
+ def _token_expiry(token: str | None) -> datetime | None:
197
+ if not token:
198
+ return None
199
+ claims = extract_id_token_claims(token)
200
+ exp = claims.exp
201
+ if isinstance(exp, (int, float)):
202
+ return datetime.fromtimestamp(exp, tz=timezone.utc)
203
+ if isinstance(exp, str) and exp.isdigit():
204
+ return datetime.fromtimestamp(int(exp), tz=timezone.utc)
205
+ return None
206
+
207
+
208
def _normalize_used_percent(entry: UsageHistory | None) -> float | None:
    """Return the ``used_percent`` of a usage row, or ``None`` when there is no row."""
    return entry.used_percent if entry else None
File without changes
@@ -0,0 +1,10 @@
1
+ from __future__ import annotations
2
+
3
+ from fastapi import APIRouter
4
+
5
router = APIRouter(tags=["health"])


@router.get("/health")
async def health_check() -> dict:
    # Static liveness payload; no dependencies (database, upstream) are checked.
    # Kept as a comment rather than a docstring so the OpenAPI description is unchanged.
    payload = {"status": "ok"}
    return payload
File without changes
@@ -0,0 +1,57 @@
1
+ from __future__ import annotations
2
+
3
+ from fastapi import APIRouter, Body, Depends
4
+ from fastapi.responses import JSONResponse
5
+
6
+ from app.core.clients.oauth import OAuthError
7
+ from app.core.errors import dashboard_error
8
+ from app.dependencies import OauthContext, get_oauth_context
9
+ from app.modules.oauth.schemas import (
10
+ OauthCompleteRequest,
11
+ OauthCompleteResponse,
12
+ OauthStartRequest,
13
+ OauthStartResponse,
14
+ OauthStatusResponse,
15
+ )
16
+
17
router = APIRouter(tags=["dashboard"], prefix="/api/oauth")


@router.post("/start", response_model=OauthStartResponse)
async def start_oauth(
    request: OauthStartRequest,
    context: OauthContext = Depends(get_oauth_context),
) -> OauthStartResponse | JSONResponse:
    # Begin a login flow. Upstream OAuth failures map to 502, missing support to 501.
    # (Comments, not a docstring, so the generated OpenAPI description is unchanged.)
    try:
        started = await context.service.start_oauth(request)
    except OAuthError as exc:
        status_code, body = 502, dashboard_error(exc.code, exc.message)
    except NotImplementedError:
        status_code, body = 501, dashboard_error("not_implemented", "OAuth start is not implemented")
    else:
        return started
    return JSONResponse(status_code=status_code, content=body)
37
+
38
+
39
@router.get("/status", response_model=OauthStatusResponse)
async def oauth_status(
    context: OauthContext = Depends(get_oauth_context),
) -> OauthStatusResponse | JSONResponse:
    # Report progress of the current login attempt, delegated entirely to the service.
    current_status = await context.service.oauth_status()
    return current_status
44
+
45
+
46
@router.post("/complete", response_model=OauthCompleteResponse)
async def complete_oauth(
    request: OauthCompleteRequest | None = Body(default=None),
    context: OauthContext = Depends(get_oauth_context),
) -> OauthCompleteResponse | JSONResponse:
    # Advance the login flow; the request body is optional. Missing support maps to 501.
    try:
        outcome = await context.service.complete_oauth(request)
    except NotImplementedError:
        unavailable = dashboard_error("not_implemented", "OAuth complete is not implemented")
        return JSONResponse(status_code=501, content=unavailable)
    return outcome
@@ -0,0 +1,32 @@
1
+ from __future__ import annotations
2
+
3
+ from app.modules.shared.schemas import DashboardModel
4
+
5
+
6
class OauthStartRequest(DashboardModel):
    # Optional flow override, lowercased by the service. "device" forces the device-code
    # flow; any other non-empty value goes straight to the browser flow; an empty/missing
    # value lets the service decide (it skips login when accounts already exist).
    # Comments rather than a docstring keep the generated JSON schema unchanged.
    force_method: str | None = None
8
+
9
+
10
class OauthStartResponse(DashboardModel):
    # "browser" or "device"; the optional fields below are only filled for the matching flow.
    method: str
    # Browser (PKCE) flow: URL to open and the redirect URI served by the local callback server.
    authorization_url: str | None = None
    callback_url: str | None = None
    # Device-code flow: where to enter the code, the code itself, and polling parameters.
    verification_url: str | None = None
    user_code: str | None = None
    device_auth_id: str | None = None
    interval_seconds: int | None = None
    expires_in_seconds: int | None = None
19
+
20
+
21
class OauthStatusResponse(DashboardModel):
    # "pending", "success" or "error" (an idle store is reported as "pending").
    status: str
    # Human-readable failure reason, set when status is "error".
    error_message: str | None = None
24
+
25
+
26
class OauthCompleteRequest(DashboardModel):
    # Optional device-code identifiers; when supplied they overwrite the values stored
    # by the start call before polling begins.
    device_auth_id: str | None = None
    user_code: str | None = None
29
+
30
+
31
class OauthCompleteResponse(DashboardModel):
    # "success" once tokens were stored, "error" if the device flow was never initialised,
    # otherwise "pending".
    status: str
@@ -0,0 +1,356 @@
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ import secrets
5
+ import time
6
+ from contextlib import AbstractAsyncContextManager
7
+ from dataclasses import dataclass
8
+ from pathlib import Path
9
+ from typing import Awaitable, Callable
10
+
11
+ from aiohttp import web
12
+
13
+ from app.core.auth import (
14
+ DEFAULT_EMAIL,
15
+ DEFAULT_PLAN,
16
+ OpenAIAuthClaims,
17
+ extract_id_token_claims,
18
+ fallback_account_id,
19
+ )
20
+ from app.core.clients.oauth import (
21
+ OAuthError,
22
+ OAuthTokens,
23
+ build_authorization_url,
24
+ exchange_authorization_code,
25
+ exchange_device_token,
26
+ generate_pkce_pair,
27
+ request_device_code,
28
+ )
29
+ from app.core.config.settings import get_settings
30
+ from app.core.crypto import TokenEncryptor
31
+ from app.core.utils.time import utcnow
32
+ from app.db.models import Account, AccountStatus
33
+ from app.modules.accounts.repository import AccountsRepository
34
+ from app.modules.oauth.schemas import (
35
+ OauthCompleteRequest,
36
+ OauthCompleteResponse,
37
+ OauthStartRequest,
38
+ OauthStartResponse,
39
+ OauthStatusResponse,
40
+ )
41
+
42
# Indirection used by the device-code poll loop for its delay; presumably kept as a
# module attribute so it can be replaced (e.g. patched in tests).
_async_sleep = asyncio.sleep
# HTML page shown after a successful browser login; it ships in ./templates next to this module.
_SUCCESS_TEMPLATE = Path(__file__).resolve().parent.joinpath("templates", "oauth_success.html")
44
+
45
+
46
@dataclass
class OAuthState:
    """Mutable record of the current OAuth login attempt, guarded by the store's lock."""

    # Lifecycle marker; this module uses "idle", "pending", "success" and "error".
    status: str = "pending"
    # Running flow: "browser" or "device" (None before a flow starts).
    method: str | None = None
    error_message: str | None = None
    # Browser (PKCE) flow: expected ``state`` query value and the PKCE code verifier.
    state_token: str | None = None
    code_verifier: str | None = None
    # Device-code flow: identifiers, poll interval, and absolute expiry as time.time() seconds.
    device_auth_id: str | None = None
    user_code: str | None = None
    interval_seconds: int | None = None
    expires_at: float | None = None
    # Live resources owned by the attempt; released by OAuthStateStore cleanup.
    callback_server: OAuthCallbackServer | None = None
    poll_task: asyncio.Task[None] | None = None
59
+
60
+
61
class OAuthStateStore:
    """Hold the single OAuthState together with the asyncio lock that guards it.

    Methods suffixed ``_locked`` expect the caller to already hold :attr:`lock`
    (``reset`` calls ``_cleanup_locked`` while holding it).
    """

    def __init__(self) -> None:
        self._state = OAuthState(status="idle")
        self._lock = asyncio.Lock()

    @property
    def lock(self) -> asyncio.Lock:
        """The lock serialising access to :attr:`state`."""
        return self._lock

    @property
    def state(self) -> OAuthState:
        """The current attempt's state; replaced wholesale by :meth:`reset`."""
        return self._state

    async def reset(self) -> None:
        """Release the current attempt's resources and start over from an idle state."""
        async with self._lock:
            await self._cleanup_locked()
            self._state = OAuthState(status="idle")

    async def _cleanup_locked(self) -> None:
        """Cancel an unfinished poll task and stop the callback server, if present.

        The caller must hold :attr:`lock`. The references stay on the state object;
        callers replace the state right afterwards.
        """
        current = self._state
        pending_poll = current.poll_task
        if pending_poll and not pending_poll.done():
            pending_poll.cancel()
        if current.callback_server:
            await current.callback_server.stop()
86
+
87
+
88
class OAuthCallbackServer:
    """Small aiohttp server that answers the OAuth redirect on ``GET /auth/callback``."""

    def __init__(
        self,
        handler: Callable[[web.Request], Awaitable[web.StreamResponse]],
        host: str = "127.0.0.1",
        port: int = 1455,
    ) -> None:
        self._handler = handler
        self._host = host
        self._port = port
        self._runner: web.AppRunner | None = None
        self._site: web.TCPSite | None = None

    async def start(self) -> None:
        """Set up the app and start listening on ``host:port``.

        Raises:
            OSError: if the address cannot be bound (e.g. the port is already in use).
                The runner created before binding is cleaned up first, so a failed
                start leaves nothing behind. That matters because the caller falls back
                to another flow and never gets a chance to stop this instance.
        """
        app = web.Application()
        app.router.add_get("/auth/callback", self._handler)
        runner = web.AppRunner(app)
        await runner.setup()
        self._runner = runner
        site = web.TCPSite(runner, self._host, self._port)
        try:
            await site.start()
        except BaseException:
            await self.stop()
            raise
        self._site = site

    async def stop(self) -> None:
        """Shut the server down; safe to call when it was never started or already stopped."""
        if self._runner:
            await self._runner.cleanup()
        self._runner = None
        self._site = None
114
+
115
+
116
# Process-wide store shared by every OauthService instance, so a login started in one
# request can be observed and completed by later requests.
_OAUTH_STORE = OAuthStateStore()
117
+
118
+
119
class OauthService:
    """Run the dashboard's OAuth login flows and persist the resulting account.

    Two flows exist: a browser flow (PKCE authorization code delivered to a local
    callback server) and a device-code flow (background token polling). Flow state
    lives in the module-level ``_OAUTH_STORE`` shared by all instances, so a flow
    started through one instance can be observed and completed through another.
    """

    # Strong references to fire-and-forget tasks. The event loop only keeps weak
    # references to tasks, so an unreferenced task can be garbage-collected before it
    # finishes. Class-level on purpose: the tasks must outlive the instance that
    # spawned them.
    _background_tasks: set[asyncio.Task[None]] = set()

    def __init__(
        self,
        accounts_repo: AccountsRepository,
        repo_factory: Callable[[], AbstractAsyncContextManager[AccountsRepository]] | None = None,
    ) -> None:
        """Create the service.

        Args:
            accounts_repo: Repository used for reads and, when ``repo_factory`` is not
                given, for persisting accounts.
            repo_factory: Optional factory yielding a fresh repository. When provided,
                every persist goes through it, presumably because callbacks and
                background polling may outlive a request-scoped ``accounts_repo``.
        """
        self._accounts_repo = accounts_repo
        self._encryptor = TokenEncryptor()
        self._store = _OAUTH_STORE
        self._repo_factory = repo_factory

    async def start_oauth(self, request: OauthStartRequest) -> OauthStartResponse:
        """Start a login flow.

        Without ``force_method``, an existing account short-circuits the flow: the state
        becomes ``success`` and a bare ``browser`` response is returned. ``"device"``
        selects the device-code flow. Otherwise the browser flow is tried, and if it
        raises ``OSError`` (e.g. the callback port is busy) the device-code flow is used.

        Raises:
            OAuthError: if requesting a device code fails.
        """
        force_method = (request.force_method or "").lower()
        if not force_method:
            accounts = await self._accounts_repo.list_accounts()
            if accounts:
                async with self._store.lock:
                    await self._store._cleanup_locked()
                    self._store._state = OAuthState(status="success")
                return OauthStartResponse(method="browser")

        if force_method == "device":
            return await self._start_device_flow()

        try:
            return await self._start_browser_flow()
        except OSError:
            return await self._start_device_flow()

    async def oauth_status(self) -> OauthStatusResponse:
        """Report the current flow status; an idle store is reported as ``pending``."""
        async with self._store.lock:
            state = self._store.state
            status = state.status if state.status != "idle" else "pending"
            return OauthStatusResponse(status=status, error_message=state.error_message)

    async def complete_oauth(self, request: OauthCompleteRequest | None = None) -> OauthCompleteResponse:
        """Start background polling for a device-code flow, when appropriate.

        Identifiers in ``request`` overwrite the stored ones. Returns ``success`` if the
        flow already succeeded, ``error`` if the device flow was never initialised, and
        ``pending`` otherwise: not a device flow, polling already running, or polling
        just started.
        """
        payload = request or OauthCompleteRequest()
        async with self._store.lock:
            state = self._store.state
            if payload.device_auth_id:
                state.device_auth_id = payload.device_auth_id
            if payload.user_code:
                state.user_code = payload.user_code
            if state.status == "success":
                return OauthCompleteResponse(status="success")
            if state.method != "device":
                return OauthCompleteResponse(status="pending")
            if state.poll_task and not state.poll_task.done():
                return OauthCompleteResponse(status="pending")
            if not state.device_auth_id or not state.user_code or not state.expires_at:
                state.status = "error"
                state.error_message = "Device code flow is not initialized."
                return OauthCompleteResponse(status="error")

            # NOTE(review): a missing/zero interval makes the poller call the token
            # endpoint back-to-back; consider enforcing a minimum delay.
            interval = state.interval_seconds if state.interval_seconds is not None else 0
            interval = max(interval, 0)
            poll_context = DevicePollContext(
                device_auth_id=state.device_auth_id,
                user_code=state.user_code,
                interval_seconds=interval,
                expires_at=state.expires_at,
            )
            # The task reference is kept on the state; cleanup cancels it.
            state.poll_task = asyncio.create_task(self._poll_device_tokens(poll_context))
            return OauthCompleteResponse(status="pending")

    async def _start_browser_flow(self) -> OauthStartResponse:
        """Reset the store, prepare PKCE/CSRF values and start the local callback server.

        Raises:
            OSError: if the callback server cannot be started.
        """
        await self._store.reset()
        code_verifier, code_challenge = generate_pkce_pair()
        state_token = secrets.token_urlsafe(16)
        authorization_url = build_authorization_url(state=state_token, code_challenge=code_challenge)
        settings = get_settings()

        async with self._store.lock:
            state = self._store.state
            state.status = "pending"
            state.method = "browser"
            state.state_token = state_token
            state.code_verifier = code_verifier
            state.error_message = None

        callback_server = OAuthCallbackServer(
            self._handle_callback,
            host=settings.oauth_callback_host,
            port=settings.oauth_callback_port,
        )
        await callback_server.start()

        async with self._store.lock:
            self._store.state.callback_server = callback_server

        return OauthStartResponse(
            method="browser",
            authorization_url=authorization_url,
            callback_url=settings.oauth_redirect_uri,
        )

    async def _start_device_flow(self) -> OauthStartResponse:
        """Reset the store and request a device code.

        Raises:
            OAuthError: after recording the error in the store, if the request fails.
        """
        await self._store.reset()
        try:
            device = await request_device_code()
        except OAuthError as exc:
            await self._set_error(exc.message)
            raise

        async with self._store.lock:
            state = self._store.state
            state.status = "pending"
            state.method = "device"
            state.device_auth_id = device.device_auth_id
            state.user_code = device.user_code
            state.interval_seconds = device.interval_seconds
            state.expires_at = time.time() + device.expires_in_seconds
            state.error_message = None

        return OauthStartResponse(
            method="device",
            verification_url=device.verification_url,
            user_code=device.user_code,
            device_auth_id=device.device_auth_id,
            interval_seconds=device.interval_seconds,
            expires_in_seconds=device.expires_in_seconds,
        )

    async def _handle_callback(self, request: web.Request) -> web.Response:
        """Handle the browser redirect: validate ``state``, exchange the code, store the account.

        Any outcome is recorded in the store. Unexpected failures (not only
        ``OAuthError``) are also recorded, so the dashboard does not stay ``pending``
        forever. After a code exchange attempt the callback server is shut down in the
        background; it cannot be stopped from inside its own request handler.
        """
        params = request.rel_url.query
        error = params.get("error")
        code = params.get("code")
        state = params.get("state")

        if error:
            await self._set_error(f"OAuth error: {error}")
            return self._html_response(_error_html("Authorization failed."))

        async with self._store.lock:
            expected_state = self._store.state.state_token
            verifier = self._store.state.code_verifier

        if not code or not state or state != expected_state or not verifier:
            await self._set_error("Invalid OAuth callback state.")
            return self._html_response(_error_html("Invalid OAuth callback."))

        try:
            tokens = await exchange_authorization_code(code=code, code_verifier=verifier)
            await self._persist_tokens(tokens)
        except OAuthError as exc:
            await self._set_error(exc.message)
            html = _error_html(exc.message)
        except Exception:
            self._log_unexpected("Unexpected failure while completing the OAuth browser login")
            await self._set_error("Failed to complete OAuth login.")
            html = _error_html("Failed to complete OAuth login.")
        else:
            await self._set_success()
            html = _success_html()

        stop_task = asyncio.create_task(self._stop_callback_server())
        self._background_tasks.add(stop_task)
        stop_task.add_done_callback(self._background_tasks.discard)
        return self._html_response(html)

    async def _poll_device_tokens(self, context: "DevicePollContext") -> None:
        """Poll for device-code tokens until success, expiry, or failure.

        Every exit path records a final status in the store; unexpected exceptions are
        logged and reported as an error instead of silently ending the task.
        Cancellation (``CancelledError``) is deliberately not caught, so a store reset
        can stop polling.
        """
        try:
            while time.time() < context.expires_at:
                tokens = await exchange_device_token(
                    device_auth_id=context.device_auth_id,
                    user_code=context.user_code,
                )
                if tokens:
                    await self._persist_tokens(tokens)
                    await self._set_success()
                    return
                await _async_sleep(context.interval_seconds)
            await self._set_error("Device code expired.")
        except OAuthError as exc:
            await self._set_error(exc.message)
        except Exception:
            self._log_unexpected("Unexpected failure while polling for device-code tokens")
            await self._set_error("Device code polling failed unexpectedly.")
        finally:
            async with self._store.lock:
                current = asyncio.current_task()
                # Only clear the reference if the state still points at this task.
                if self._store.state.poll_task is current:
                    self._store.state.poll_task = None

    async def _persist_tokens(self, tokens: OAuthTokens) -> None:
        """Derive the account identity from the ID token claims and upsert it with encrypted tokens.

        Missing claims fall back to ``DEFAULT_EMAIL`` / ``DEFAULT_PLAN`` and to an account
        id derived from the email.
        """
        claims = extract_id_token_claims(tokens.id_token)
        auth_claims = claims.auth or OpenAIAuthClaims()
        account_id = auth_claims.chatgpt_account_id or claims.chatgpt_account_id
        email = claims.email or DEFAULT_EMAIL
        plan_type = auth_claims.chatgpt_plan_type or claims.chatgpt_plan_type or DEFAULT_PLAN
        account_id = account_id or fallback_account_id(email)

        account = Account(
            id=account_id,
            email=email,
            plan_type=plan_type,
            access_token_encrypted=self._encryptor.encrypt(tokens.access_token),
            refresh_token_encrypted=self._encryptor.encrypt(tokens.refresh_token),
            id_token_encrypted=self._encryptor.encrypt(tokens.id_token),
            last_refresh=utcnow(),
            status=AccountStatus.ACTIVE,
            deactivation_reason=None,
        )
        if self._repo_factory:
            async with self._repo_factory() as repo:
                await repo.upsert(account)
        else:
            await self._accounts_repo.upsert(account)

    async def _set_success(self) -> None:
        """Mark the current flow as succeeded and clear any error message."""
        async with self._store.lock:
            self._store.state.status = "success"
            self._store.state.error_message = None

    async def _set_error(self, message: str) -> None:
        """Mark the current flow as failed with ``message``."""
        async with self._store.lock:
            self._store.state.status = "error"
            self._store.state.error_message = message

    async def _stop_callback_server(self) -> None:
        """Detach the callback server from the state under the lock, then stop it outside the lock."""
        async with self._store.lock:
            server = self._store.state.callback_server
            self._store.state.callback_server = None
        if server:
            await server.stop()

    @staticmethod
    def _log_unexpected(message: str) -> None:
        """Log ``message`` with the traceback of the exception being handled.

        Call from inside an ``except`` block.
        """
        import logging  # local import keeps the module's import block unchanged

        logging.getLogger(__name__).exception(message)

    @staticmethod
    def _html_response(html: str) -> web.Response:
        """Wrap an HTML document in an aiohttp response."""
        return web.Response(text=html, content_type="text/html")
338
+
339
+
340
@dataclass(frozen=True)
class DevicePollContext:
    """Immutable snapshot of the device-code parameters handed to one polling task.

    Built from the shared state while the store lock is held; the poller then reads
    only this snapshot.
    """

    device_auth_id: str
    user_code: str
    # Delay between token requests, in seconds (already clamped to >= 0 by the caller).
    interval_seconds: int
    # Absolute deadline in time.time() seconds.
    expires_at: float
346
+
347
+
348
def _success_html() -> str:
    """Return the post-login page, or a minimal built-in page if the template cannot be read."""
    fallback = "<html><body><h1>Login complete</h1><p>Return to the dashboard.</p></body></html>"
    try:
        page = _SUCCESS_TEMPLATE.read_text(encoding="utf-8")
    except OSError:
        page = fallback
    return page
353
+
354
+
355
def _error_html(message: str) -> str:
    """Render the minimal page shown in the browser when the OAuth callback fails.

    ``message`` is HTML-escaped. Callers pass ``OAuthError.message``, whose text may
    originate outside this process, so it must not be interpreted as markup.
    """
    import html  # local import keeps the module's import block unchanged

    return f"<html><body><h1>Login failed</h1><p>{html.escape(message)}</p></body></html>"