codex-lb 0.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- app/__init__.py +5 -0
- app/cli.py +24 -0
- app/core/__init__.py +0 -0
- app/core/auth/__init__.py +96 -0
- app/core/auth/models.py +49 -0
- app/core/auth/refresh.py +144 -0
- app/core/balancer/__init__.py +19 -0
- app/core/balancer/logic.py +140 -0
- app/core/balancer/types.py +9 -0
- app/core/clients/__init__.py +0 -0
- app/core/clients/http.py +39 -0
- app/core/clients/oauth.py +340 -0
- app/core/clients/proxy.py +265 -0
- app/core/clients/usage.py +143 -0
- app/core/config/__init__.py +0 -0
- app/core/config/settings.py +69 -0
- app/core/crypto.py +37 -0
- app/core/errors.py +73 -0
- app/core/openai/__init__.py +0 -0
- app/core/openai/models.py +122 -0
- app/core/openai/parsing.py +55 -0
- app/core/openai/requests.py +59 -0
- app/core/types.py +4 -0
- app/core/usage/__init__.py +185 -0
- app/core/usage/logs.py +57 -0
- app/core/usage/models.py +35 -0
- app/core/usage/pricing.py +172 -0
- app/core/usage/types.py +95 -0
- app/core/utils/__init__.py +0 -0
- app/core/utils/request_id.py +30 -0
- app/core/utils/retry.py +16 -0
- app/core/utils/sse.py +13 -0
- app/core/utils/time.py +19 -0
- app/db/__init__.py +0 -0
- app/db/models.py +82 -0
- app/db/session.py +44 -0
- app/dependencies.py +123 -0
- app/main.py +124 -0
- app/modules/__init__.py +0 -0
- app/modules/accounts/__init__.py +0 -0
- app/modules/accounts/api.py +81 -0
- app/modules/accounts/repository.py +80 -0
- app/modules/accounts/schemas.py +66 -0
- app/modules/accounts/service.py +211 -0
- app/modules/health/__init__.py +0 -0
- app/modules/health/api.py +10 -0
- app/modules/oauth/__init__.py +0 -0
- app/modules/oauth/api.py +57 -0
- app/modules/oauth/schemas.py +32 -0
- app/modules/oauth/service.py +356 -0
- app/modules/oauth/templates/oauth_success.html +122 -0
- app/modules/proxy/__init__.py +0 -0
- app/modules/proxy/api.py +76 -0
- app/modules/proxy/auth_manager.py +51 -0
- app/modules/proxy/load_balancer.py +208 -0
- app/modules/proxy/schemas.py +85 -0
- app/modules/proxy/service.py +707 -0
- app/modules/proxy/types.py +37 -0
- app/modules/proxy/usage_updater.py +147 -0
- app/modules/request_logs/__init__.py +0 -0
- app/modules/request_logs/api.py +31 -0
- app/modules/request_logs/repository.py +86 -0
- app/modules/request_logs/schemas.py +25 -0
- app/modules/request_logs/service.py +77 -0
- app/modules/shared/__init__.py +0 -0
- app/modules/shared/schemas.py +8 -0
- app/modules/usage/__init__.py +0 -0
- app/modules/usage/api.py +31 -0
- app/modules/usage/repository.py +113 -0
- app/modules/usage/schemas.py +62 -0
- app/modules/usage/service.py +246 -0
- app/static/7.css +1336 -0
- app/static/index.css +543 -0
- app/static/index.html +457 -0
- app/static/index.js +1898 -0
- codex_lb-0.1.2.dist-info/METADATA +108 -0
- codex_lb-0.1.2.dist-info/RECORD +80 -0
- codex_lb-0.1.2.dist-info/WHEEL +4 -0
- codex_lb-0.1.2.dist-info/entry_points.txt +2 -0
- codex_lb-0.1.2.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,211 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from datetime import datetime, timedelta, timezone
|
|
4
|
+
|
|
5
|
+
from app.core import usage as usage_core
|
|
6
|
+
from app.core.auth import (
|
|
7
|
+
DEFAULT_EMAIL,
|
|
8
|
+
DEFAULT_PLAN,
|
|
9
|
+
claims_from_auth,
|
|
10
|
+
extract_id_token_claims,
|
|
11
|
+
fallback_account_id,
|
|
12
|
+
parse_auth_json,
|
|
13
|
+
)
|
|
14
|
+
from app.core.crypto import TokenEncryptor
|
|
15
|
+
from app.core.usage.logs import cost_from_log
|
|
16
|
+
from app.core.utils.time import from_epoch_seconds, to_utc_naive, utcnow
|
|
17
|
+
from app.db.models import Account, AccountStatus, UsageHistory
|
|
18
|
+
from app.modules.accounts.repository import AccountsRepository
|
|
19
|
+
from app.modules.accounts.schemas import (
|
|
20
|
+
AccountAuthStatus,
|
|
21
|
+
AccountImportResponse,
|
|
22
|
+
AccountSummary,
|
|
23
|
+
AccountTokenStatus,
|
|
24
|
+
AccountUsage,
|
|
25
|
+
)
|
|
26
|
+
from app.modules.proxy.usage_updater import UsageUpdater
|
|
27
|
+
from app.modules.request_logs.repository import RequestLogsRepository
|
|
28
|
+
from app.modules.usage.repository import UsageRepository
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class AccountsService:
    """Account management operations backing the dashboard.

    Wraps an ``AccountsRepository`` and, when supplied, the usage-history and
    request-log repositories whose data enriches each account summary.
    """

    def __init__(
        self,
        repo: AccountsRepository,
        usage_repo: UsageRepository | None = None,
        logs_repo: RequestLogsRepository | None = None,
    ) -> None:
        """Create the service.

        Args:
            repo: Repository used to list, upsert, update and delete accounts.
            usage_repo: Optional usage-history repository. Without it, usage is
                never refreshed and summaries are built with no usage rows.
            logs_repo: Optional request-log repository. Without it, every
                account's 24h cost is reported as 0.0.
        """
        self._repo = repo
        self._usage_repo = usage_repo
        self._logs_repo = logs_repo
        # The updater is built from the usage repository, so it only exists
        # when one was supplied.
        self._usage_updater = UsageUpdater(usage_repo, repo) if usage_repo else None
        self._encryptor = TokenEncryptor()

    async def list_accounts(self) -> list[AccountSummary]:
        """Return a dashboard summary for every stored account.

        Usage is refreshed first (when a usage repository is configured); the
        latest primary/secondary usage rows and the last 24h of request costs
        are then folded into each summary.
        """
        accounts = await self._repo.list_accounts()
        if not accounts:
            return []
        await self._refresh_usage(accounts)
        if self._usage_repo:
            primary_usage = await self._usage_repo.latest_by_account(window="primary")
            secondary_usage = await self._usage_repo.latest_by_account(window="secondary")
        else:
            primary_usage, secondary_usage = {}, {}
        cost_by_account = await self._costs_last_24h()
        return [
            self._account_to_summary(
                account,
                primary_usage.get(account.id),
                secondary_usage.get(account.id),
                cost_by_account.get(account.id),
            )
            for account in accounts
        ]

    async def import_account(self, raw: bytes) -> AccountImportResponse:
        """Import (or update) an account from a raw auth JSON payload.

        The payload is parsed with ``parse_auth_json``. Missing claims fall back
        to ``DEFAULT_EMAIL`` / ``DEFAULT_PLAN``, and a missing account id to
        ``fallback_account_id(email)``. Tokens are stored encrypted, and the
        account is saved as ``ACTIVE`` with no deactivation reason. Usage for
        the saved account is refreshed when a usage repository is configured.
        """
        auth = parse_auth_json(raw)
        claims = claims_from_auth(auth)

        email = claims.email or DEFAULT_EMAIL
        plan_type = claims.plan_type or DEFAULT_PLAN
        account_id = claims.account_id or fallback_account_id(email)
        last_refresh = to_utc_naive(auth.last_refresh_at) if auth.last_refresh_at else utcnow()

        account = Account(
            id=account_id,
            email=email,
            plan_type=plan_type,
            access_token_encrypted=self._encryptor.encrypt(auth.tokens.access_token),
            refresh_token_encrypted=self._encryptor.encrypt(auth.tokens.refresh_token),
            id_token_encrypted=self._encryptor.encrypt(auth.tokens.id_token),
            last_refresh=last_refresh,
            status=AccountStatus.ACTIVE,
            deactivation_reason=None,
        )

        saved = await self._repo.upsert(account)
        # Same guarded refresh as list_accounts, restricted to the saved account.
        await self._refresh_usage([saved])
        return AccountImportResponse(
            account_id=saved.id,
            email=saved.email,
            plan_type=saved.plan_type,
            status=saved.status,
        )

    async def reactivate_account(self, account_id: str) -> bool:
        """Set the account to ``ACTIVE`` and clear its deactivation reason.

        Returns the repository's result (presumably whether the account exists).
        """
        return await self._repo.update_status(account_id, AccountStatus.ACTIVE, None)

    async def pause_account(self, account_id: str) -> bool:
        """Set the account to ``PAUSED`` with no deactivation reason.

        Returns the repository's result (presumably whether the account exists).
        """
        return await self._repo.update_status(account_id, AccountStatus.PAUSED, None)

    async def delete_account(self, account_id: str) -> bool:
        """Delete the account; returns the repository's result."""
        return await self._repo.delete(account_id)

    def _account_to_summary(
        self,
        account: Account,
        primary_usage: UsageHistory | None,
        secondary_usage: UsageHistory | None,
        cost_usd_24h: float | None,
    ) -> AccountSummary:
        """Build the dashboard summary for one account.

        A missing usage row counts as 0% used, and a missing cost is reported
        as 0.0. Remaining credits are derived from the used percentage and the
        plan's capacity for each window.
        """
        auth_status = self._build_auth_status(account)
        primary_used_percent = _normalize_used_percent(primary_usage) or 0.0
        secondary_used_percent = _normalize_used_percent(secondary_usage) or 0.0
        primary_remaining_percent = usage_core.remaining_percent_from_used(primary_used_percent) or 0.0
        secondary_remaining_percent = usage_core.remaining_percent_from_used(secondary_used_percent) or 0.0
        reset_at_primary = from_epoch_seconds(primary_usage.reset_at) if primary_usage is not None else None
        reset_at_secondary = from_epoch_seconds(secondary_usage.reset_at) if secondary_usage is not None else None
        capacity_primary = usage_core.capacity_for_plan(account.plan_type, "primary")
        capacity_secondary = usage_core.capacity_for_plan(account.plan_type, "secondary")
        remaining_credits_primary = usage_core.remaining_credits_from_percent(
            primary_used_percent,
            capacity_primary,
        )
        remaining_credits_secondary = usage_core.remaining_credits_from_percent(
            secondary_used_percent,
            capacity_secondary,
        )
        return AccountSummary(
            account_id=account.id,
            email=account.email,
            display_name=account.email,
            plan_type=account.plan_type,
            status=account.status.value,
            usage=AccountUsage(
                primary_remaining_percent=primary_remaining_percent,
                secondary_remaining_percent=secondary_remaining_percent,
            ),
            reset_at_primary=reset_at_primary,
            reset_at_secondary=reset_at_secondary,
            last_refresh_at=account.last_refresh,
            capacity_credits_primary=capacity_primary,
            remaining_credits_primary=remaining_credits_primary,
            capacity_credits_secondary=capacity_secondary,
            remaining_credits_secondary=remaining_credits_secondary,
            cost_usd_24h=cost_usd_24h if cost_usd_24h is not None else 0.0,
            deactivation_reason=account.deactivation_reason,
            auth=auth_status,
        )

    def _build_auth_status(self, account: Account) -> AccountAuthStatus:
        """Describe the stored tokens without exposing their values.

        The access token's ``exp`` claim becomes ``expires_at``. The refresh
        token is reported as ``stored`` or ``missing``. The id token is
        ``parsed`` when it yields at least one non-null claim, otherwise
        ``unknown``.
        """
        access_token = self._decrypt_token(account.access_token_encrypted)
        refresh_token = self._decrypt_token(account.refresh_token_encrypted)
        id_token = self._decrypt_token(account.id_token_encrypted)

        access_expires = _token_expiry(access_token)
        refresh_state = "stored" if refresh_token else "missing"
        id_state = "unknown"
        if id_token:
            claims = extract_id_token_claims(id_token)
            if claims.model_dump(exclude_none=True):
                id_state = "parsed"

        return AccountAuthStatus(
            access=AccountTokenStatus(expires_at=access_expires),
            refresh=AccountTokenStatus(state=refresh_state),
            id_token=AccountTokenStatus(state=id_state),
        )

    def _decrypt_token(self, encrypted: bytes | None) -> str | None:
        """Decrypt a stored token, returning ``None`` if absent or undecryptable.

        Failures are deliberately swallowed: this only feeds the status view,
        so one unreadable token should not break the account listing.
        """
        if not encrypted:
            return None
        try:
            return self._encryptor.decrypt(encrypted)
        except Exception:
            return None

    async def _refresh_usage(self, accounts: list[Account]) -> None:
        """Refresh usage for ``accounts`` using the latest primary-window history.

        No-op when no usage repository/updater is configured.
        """
        if not self._usage_repo or not self._usage_updater:
            return
        latest_usage = await self._usage_repo.latest_by_account(window="primary")
        await self._usage_updater.refresh_accounts(accounts, latest_usage)

    async def _costs_last_24h(self) -> dict[str, float]:
        """Sum request costs per account over the last 24 hours.

        Logs whose cost cannot be determined (``cost_from_log`` returns ``None``)
        are skipped, and totals are rounded to 6 decimal places. Returns an
        empty dict when no request-log repository is configured.
        """
        if not self._logs_repo:
            return {}
        since = utcnow() - timedelta(hours=24)
        logs = await self._logs_repo.list_since(since)
        totals: dict[str, float] = {}
        for log in logs:
            cost = cost_from_log(log)
            if cost is None:
                continue
            totals[log.account_id] = totals.get(log.account_id, 0.0) + cost
        return {account_id: round(total, 6) for account_id, total in totals.items()}
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
def _token_expiry(token: str | None) -> datetime | None:
    """Return the ``exp`` claim of a token as an aware UTC datetime.

    Returns ``None`` in three cases:

    - the token is missing or empty;
    - ``exp`` is neither a number nor a string of decimal digits;
    - the value cannot be converted to a datetime on this platform.

    Malformed claims never raise.
    """
    if not token:
        return None
    exp = extract_id_token_claims(token).exp
    if isinstance(exp, str):
        # ``isdecimal`` (unlike ``isdigit``) accepts only characters ``int()``
        # can parse; ``"²".isdigit()`` is True but ``int("²")`` raises.
        if not exp.isdecimal():
            return None
        exp = int(exp)
    if not isinstance(exp, (int, float)):
        return None
    try:
        return datetime.fromtimestamp(exp, tz=timezone.utc)
    except (OverflowError, OSError, ValueError):
        # Out-of-range or NaN timestamps: treat as "no known expiry".
        return None
|
|
206
|
+
|
|
207
|
+
|
|
208
|
+
def _normalize_used_percent(entry: UsageHistory | None) -> float | None:
    """Return the ``used_percent`` of a usage snapshot, or ``None`` when there is none."""
    return entry.used_percent if entry else None
|
|
File without changes
|
|
File without changes
|
app/modules/oauth/api.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from fastapi import APIRouter, Body, Depends
|
|
4
|
+
from fastapi.responses import JSONResponse
|
|
5
|
+
|
|
6
|
+
from app.core.clients.oauth import OAuthError
|
|
7
|
+
from app.core.errors import dashboard_error
|
|
8
|
+
from app.dependencies import OauthContext, get_oauth_context
|
|
9
|
+
from app.modules.oauth.schemas import (
|
|
10
|
+
OauthCompleteRequest,
|
|
11
|
+
OauthCompleteResponse,
|
|
12
|
+
OauthStartRequest,
|
|
13
|
+
OauthStartResponse,
|
|
14
|
+
OauthStatusResponse,
|
|
15
|
+
)
|
|
16
|
+
|
|
17
|
+
# Dashboard OAuth endpoints; every route below is registered under /api/oauth.
router = APIRouter(tags=["dashboard"], prefix="/api/oauth")
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
@router.post("/start", response_model=OauthStartResponse)
async def start_oauth(
    request: OauthStartRequest,
    context: OauthContext = Depends(get_oauth_context),
) -> OauthStartResponse | JSONResponse:
    """Start an OAuth login through the service.

    Returns the service response. An ``OAuthError`` becomes an HTTP 502
    dashboard error; a ``NotImplementedError`` becomes an HTTP 501 one.
    """
    service = context.service
    try:
        started = await service.start_oauth(request)
    except OAuthError as exc:
        error_status, error_body = 502, dashboard_error(exc.code, exc.message)
    except NotImplementedError:
        error_status, error_body = 501, dashboard_error("not_implemented", "OAuth start is not implemented")
    else:
        return started
    return JSONResponse(content=error_body, status_code=error_status)
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
@router.get("/status", response_model=OauthStatusResponse)
async def oauth_status(
    context: OauthContext = Depends(get_oauth_context),
) -> OauthStatusResponse:
    """Report the status of the current OAuth flow.

    Unlike the other endpoints, this handler translates no exceptions into
    error responses, so it only ever returns the response model. The return
    annotation says exactly that.
    """
    return await context.service.oauth_status()
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
@router.post("/complete", response_model=OauthCompleteResponse)
async def complete_oauth(
    request: OauthCompleteRequest | None = Body(default=None),
    context: OauthContext = Depends(get_oauth_context),
) -> OauthCompleteResponse | JSONResponse:
    """Ask the service to finish the current OAuth flow.

    The request body is optional. A ``NotImplementedError`` from the service
    becomes an HTTP 501 dashboard error.
    """
    service = context.service
    try:
        completed = await service.complete_oauth(request)
    except NotImplementedError:
        body = dashboard_error("not_implemented", "OAuth complete is not implemented")
        return JSONResponse(content=body, status_code=501)
    return completed
|
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from app.modules.shared.schemas import DashboardModel
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class OauthStartRequest(DashboardModel):
    # Optional flow override. The service lower-cases it: "device" selects the
    # device-code flow, any other non-empty value the browser flow. When empty,
    # the service may skip login entirely if accounts already exist.
    force_method: str | None = None
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class OauthStartResponse(DashboardModel):
    # Flow that was started: "browser" or "device".
    method: str
    # Browser flow: URL the user opens to authorize, and the redirect URI.
    authorization_url: str | None = None
    callback_url: str | None = None
    # Device flow: page where the user enters ``user_code``, the device id,
    # and the polling interval / lifetime reported by the device-code request.
    verification_url: str | None = None
    user_code: str | None = None
    device_auth_id: str | None = None
    interval_seconds: int | None = None
    expires_in_seconds: int | None = None
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class OauthStatusResponse(DashboardModel):
    # "pending", "success" or "error"; the service reports an idle store as "pending".
    status: str
    # Failure reason recorded by the service when ``status`` is "error".
    error_message: str | None = None
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class OauthCompleteRequest(DashboardModel):
    # Optional device-flow identifiers; when provided, the service overwrites
    # the values stored for the current flow before it starts polling.
    device_auth_id: str | None = None
    user_code: str | None = None
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class OauthCompleteResponse(DashboardModel):
    # "pending", "success" or "error".
    status: str
|
|
@@ -0,0 +1,356 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import secrets
|
|
5
|
+
import time
|
|
6
|
+
from contextlib import AbstractAsyncContextManager
|
|
7
|
+
from dataclasses import dataclass
|
|
8
|
+
from pathlib import Path
|
|
9
|
+
from typing import Awaitable, Callable
|
|
10
|
+
|
|
11
|
+
from aiohttp import web
|
|
12
|
+
|
|
13
|
+
from app.core.auth import (
|
|
14
|
+
DEFAULT_EMAIL,
|
|
15
|
+
DEFAULT_PLAN,
|
|
16
|
+
OpenAIAuthClaims,
|
|
17
|
+
extract_id_token_claims,
|
|
18
|
+
fallback_account_id,
|
|
19
|
+
)
|
|
20
|
+
from app.core.clients.oauth import (
|
|
21
|
+
OAuthError,
|
|
22
|
+
OAuthTokens,
|
|
23
|
+
build_authorization_url,
|
|
24
|
+
exchange_authorization_code,
|
|
25
|
+
exchange_device_token,
|
|
26
|
+
generate_pkce_pair,
|
|
27
|
+
request_device_code,
|
|
28
|
+
)
|
|
29
|
+
from app.core.config.settings import get_settings
|
|
30
|
+
from app.core.crypto import TokenEncryptor
|
|
31
|
+
from app.core.utils.time import utcnow
|
|
32
|
+
from app.db.models import Account, AccountStatus
|
|
33
|
+
from app.modules.accounts.repository import AccountsRepository
|
|
34
|
+
from app.modules.oauth.schemas import (
|
|
35
|
+
OauthCompleteRequest,
|
|
36
|
+
OauthCompleteResponse,
|
|
37
|
+
OauthStartRequest,
|
|
38
|
+
OauthStartResponse,
|
|
39
|
+
OauthStatusResponse,
|
|
40
|
+
)
|
|
41
|
+
|
|
42
|
+
# Alias of asyncio.sleep used by the device-token poll loop; kept as a
# separate module-level name so the delay can be swapped out (presumably in tests).
_async_sleep = asyncio.sleep
# Page shown in the browser tab after a successful OAuth callback.
_SUCCESS_TEMPLATE = Path(__file__).resolve().parent.joinpath("templates", "oauth_success.html")
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
@dataclass
class OAuthState:
    """Mutable snapshot of the single in-progress OAuth login flow."""

    # Flow status; this module writes "idle", "pending", "success" and "error".
    status: str = "pending"
    # "browser" or "device" once a flow has been started.
    method: str | None = None
    # Failure reason surfaced by the status endpoint.
    error_message: str | None = None
    # Browser flow: random ``state`` value the callback must echo back.
    state_token: str | None = None
    # Browser flow: PKCE verifier sent when exchanging the authorization code.
    code_verifier: str | None = None
    # Device flow: identifiers returned by the device-code request.
    device_auth_id: str | None = None
    user_code: str | None = None
    # Device flow: polling interval reported by the device-code request.
    interval_seconds: int | None = None
    # Device flow: absolute expiry as a ``time.time()`` timestamp.
    expires_at: float | None = None
    # Browser flow: local server listening for the OAuth redirect.
    callback_server: "OAuthCallbackServer | None" = None
    # Device flow: background task polling for tokens.
    poll_task: asyncio.Task[None] | None = None
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
class OAuthStateStore:
    """Holder for the current OAuth flow state and the lock that guards it."""

    def __init__(self) -> None:
        self._state = OAuthState(status="idle")
        self._lock = asyncio.Lock()

    @property
    def lock(self) -> asyncio.Lock:
        """Lock used to serialize reads and writes of ``state``."""
        return self._lock

    @property
    def state(self) -> OAuthState:
        """The current flow state object."""
        return self._state

    async def reset(self) -> None:
        """Tear down any running flow and start over from a fresh ``idle`` state."""
        async with self.lock:
            await self._cleanup_locked()
            self._state = OAuthState(status="idle")

    async def _cleanup_locked(self) -> None:
        """Cancel a still-running poll task and stop the callback server, if any.

        Expects the caller to hold ``lock``. The cancelled task is not awaited.
        """
        current = self._state
        poll_task = current.poll_task
        if poll_task is not None and not poll_task.done():
            poll_task.cancel()
        if current.callback_server:
            await current.callback_server.stop()
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
class OAuthCallbackServer:
    """Small aiohttp server that serves the OAuth redirect at ``GET /auth/callback``."""

    def __init__(
        self,
        handler: Callable[[web.Request], Awaitable[web.StreamResponse]],
        host: str = "127.0.0.1",
        port: int = 1455,
    ) -> None:
        """Remember the handler and bind address; nothing is bound until :meth:`start`.

        Args:
            handler: Coroutine invoked for ``GET /auth/callback``.
            host: Interface to listen on (loopback by default).
            port: TCP port to listen on.
        """
        self._handler = handler
        self._host = host
        self._port = port
        self._runner: web.AppRunner | None = None
        self._site: web.TCPSite | None = None

    async def start(self) -> None:
        """Set up the app runner, bind the listening socket and start serving.

        If starting the site fails (typically ``OSError`` when the port is
        already in use), the runner that was already set up is cleaned up
        before the exception propagates. A failed start therefore leaks
        nothing and leaves the server stopped.

        Raises:
            OSError: Typically when the address cannot be bound.
        """
        app = web.Application()
        app.router.add_get("/auth/callback", self._handler)
        self._runner = web.AppRunner(app)
        await self._runner.setup()
        site = web.TCPSite(self._runner, self._host, self._port)
        try:
            await site.start()
        except Exception:
            # Callers may recover from a bind failure (e.g. by falling back to
            # another flow) and drop this object, so release the runner now.
            await self.stop()
            raise
        self._site = site

    async def stop(self) -> None:
        """Shut the server down; a no-op when it is not running."""
        if self._runner:
            await self._runner.cleanup()
        self._runner = None
        self._site = None
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
# Module-wide OAuth state: every OauthService instance points its ``_store`` at
# this single object, so status polls, completion requests and the callback
# server observe the same flow regardless of which request created the service.
_OAUTH_STORE = OAuthStateStore()
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
class OauthService:
    """Run the dashboard's OAuth login flows (browser + PKCE, or device code).

    Every instance works against the module-level ``_OAUTH_STORE``, so status
    polls, completion requests and the local callback server all see the
    same in-progress flow.
    """

    # Strong references to fire-and-forget tasks. The event loop keeps only
    # weak references to tasks, so an unreferenced task can be garbage
    # collected before it finishes (see the ``asyncio.create_task`` docs).
    _background_tasks: set[asyncio.Task[None]] = set()

    # RFC 8628 section 3.2: if the server supplies no polling interval,
    # clients must wait 5 seconds between token requests.
    _DEFAULT_DEVICE_POLL_INTERVAL_SECONDS = 5

    def __init__(
        self,
        accounts_repo: AccountsRepository,
        repo_factory: Callable[[], AbstractAsyncContextManager[AccountsRepository]] | None = None,
    ) -> None:
        """Create the service.

        Args:
            accounts_repo: Repository used to check for existing accounts and,
                when no ``repo_factory`` is given, to persist new ones.
            repo_factory: Optional factory returning an async context manager
                that yields a repository. When given, it is used to persist
                tokens, which can happen in a background poll task or in the
                callback server rather than in the request that built the
                service (presumably why a fresh repository is preferred).
        """
        self._accounts_repo = accounts_repo
        self._encryptor = TokenEncryptor()
        self._store = _OAUTH_STORE
        self._repo_factory = repo_factory

    async def start_oauth(self, request: OauthStartRequest) -> OauthStartResponse:
        """Begin an OAuth login.

        Without ``force_method``, an existing account short-circuits the flow:
        the shared state becomes ``success`` and a bare browser response is
        returned. ``force_method == "device"`` (case-insensitive) starts the
        device-code flow. Anything else starts the browser flow, which falls
        back to the device flow on ``OSError``.

        Raises:
            OAuthError: If requesting a device code fails.
        """
        force_method = (request.force_method or "").lower()
        if not force_method:
            accounts = await self._accounts_repo.list_accounts()
            if accounts:
                async with self._store.lock:
                    await self._store._cleanup_locked()
                    self._store._state = OAuthState(status="success")
                return OauthStartResponse(method="browser")

        if force_method == "device":
            return await self._start_device_flow()

        try:
            return await self._start_browser_flow()
        except OSError:
            # Most likely the local callback server could not be started.
            return await self._start_device_flow()

    async def oauth_status(self) -> OauthStatusResponse:
        """Return the current flow status; an ``idle`` store is reported as ``pending``."""
        async with self._store.lock:
            state = self._store.state
            status = state.status if state.status != "idle" else "pending"
            return OauthStatusResponse(status=status, error_message=state.error_message)

    async def complete_oauth(self, request: OauthCompleteRequest | None = None) -> OauthCompleteResponse:
        """Make sure a device-code token poll is running for the current flow.

        Device identifiers in ``request`` (if any) overwrite the stored ones.
        Returns:

        - ``success`` once tokens have been stored;
        - ``pending`` when the flow is not a device flow, when a poll is
          already running, or after a new poll was started;
        - ``error`` when the device flow was never initialized.
        """
        payload = request or OauthCompleteRequest()
        async with self._store.lock:
            state = self._store.state
            if payload.device_auth_id:
                state.device_auth_id = payload.device_auth_id
            if payload.user_code:
                state.user_code = payload.user_code
            if state.status == "success":
                return OauthCompleteResponse(status="success")
            if state.method != "device":
                return OauthCompleteResponse(status="pending")
            if state.poll_task and not state.poll_task.done():
                return OauthCompleteResponse(status="pending")
            if not state.device_auth_id or not state.user_code or not state.expires_at:
                state.status = "error"
                state.error_message = "Device code flow is not initialized."
                return OauthCompleteResponse(status="error")

            poll_context = DevicePollContext(
                device_auth_id=state.device_auth_id,
                user_code=state.user_code,
                interval_seconds=self._poll_interval(state.interval_seconds),
                expires_at=state.expires_at,
            )
            state.poll_task = asyncio.create_task(self._poll_device_tokens(poll_context))
            return OauthCompleteResponse(status="pending")

    @classmethod
    def _poll_interval(cls, interval_seconds: int | None) -> int:
        """Return the delay between device-token polls.

        ``None`` (no interval reported) falls back to the RFC 8628 default
        instead of 0, which would poll the token endpoint in a tight loop.
        Negative values are clamped to 0.
        """
        if interval_seconds is None:
            return cls._DEFAULT_DEVICE_POLL_INTERVAL_SECONDS
        return max(interval_seconds, 0)

    async def _start_browser_flow(self) -> OauthStartResponse:
        """Reset the store, record PKCE/state values and start the callback server.

        Raises:
            OSError: Propagated from starting the callback server.
        """
        await self._store.reset()
        code_verifier, code_challenge = generate_pkce_pair()
        state_token = secrets.token_urlsafe(16)
        authorization_url = build_authorization_url(state=state_token, code_challenge=code_challenge)
        settings = get_settings()

        async with self._store.lock:
            state = self._store.state
            state.status = "pending"
            state.method = "browser"
            state.state_token = state_token
            state.code_verifier = code_verifier
            state.error_message = None

        callback_server = OAuthCallbackServer(
            self._handle_callback,
            host=settings.oauth_callback_host,
            port=settings.oauth_callback_port,
        )
        await callback_server.start()

        async with self._store.lock:
            self._store.state.callback_server = callback_server

        return OauthStartResponse(
            method="browser",
            authorization_url=authorization_url,
            callback_url=settings.oauth_redirect_uri,
        )

    async def _start_device_flow(self) -> OauthStartResponse:
        """Reset the store, request a device code and record it as the pending flow.

        Raises:
            OAuthError: If the device-code request fails; the error is also
                recorded in the shared state.
        """
        await self._store.reset()
        try:
            device = await request_device_code()
        except OAuthError as exc:
            await self._set_error(exc.message)
            raise

        async with self._store.lock:
            state = self._store.state
            state.status = "pending"
            state.method = "device"
            state.device_auth_id = device.device_auth_id
            state.user_code = device.user_code
            state.interval_seconds = device.interval_seconds
            state.expires_at = time.time() + device.expires_in_seconds
            state.error_message = None

        return OauthStartResponse(
            method="device",
            verification_url=device.verification_url,
            user_code=device.user_code,
            device_auth_id=device.device_auth_id,
            interval_seconds=device.interval_seconds,
            expires_in_seconds=device.expires_in_seconds,
        )

    async def _handle_callback(self, request: web.Request) -> web.Response:
        """Handle the browser redirect of the PKCE flow.

        Validates ``state`` against the stored token, exchanges the code,
        persists the tokens and renders a result page. After an exchange
        attempt, the callback server is stopped from a separate task rather
        than inline.
        """
        params = request.rel_url.query
        error = params.get("error")
        code = params.get("code")
        state = params.get("state")

        if error:
            await self._set_error(f"OAuth error: {error}")
            return self._html_response(_error_html("Authorization failed."))

        async with self._store.lock:
            expected_state = self._store.state.state_token
            verifier = self._store.state.code_verifier

        if not code or not state or state != expected_state or not verifier:
            await self._set_error("Invalid OAuth callback state.")
            return self._html_response(_error_html("Invalid OAuth callback."))

        try:
            tokens = await exchange_authorization_code(code=code, code_verifier=verifier)
            await self._persist_tokens(tokens)
            await self._set_success()
            html = _success_html()
        except OAuthError as exc:
            await self._set_error(exc.message)
            html = _error_html(exc.message)

        # Keep a strong reference until the task completes so it cannot be
        # garbage-collected mid-shutdown.
        stop_task = asyncio.create_task(self._stop_callback_server())
        self._background_tasks.add(stop_task)
        stop_task.add_done_callback(self._background_tasks.discard)
        return self._html_response(html)

    async def _poll_device_tokens(self, context: "DevicePollContext") -> None:
        """Poll for device-flow tokens until success, expiry or an ``OAuthError``.

        The outcome is recorded in the shared state. On exit, ``poll_task`` is
        cleared if it still refers to this task.
        """
        try:
            while time.time() < context.expires_at:
                tokens = await exchange_device_token(
                    device_auth_id=context.device_auth_id,
                    user_code=context.user_code,
                )
                if tokens:
                    await self._persist_tokens(tokens)
                    await self._set_success()
                    return
                await _async_sleep(context.interval_seconds)
            await self._set_error("Device code expired.")
        except OAuthError as exc:
            await self._set_error(exc.message)
        finally:
            async with self._store.lock:
                current = asyncio.current_task()
                if self._store.state.poll_task is current:
                    self._store.state.poll_task = None

    async def _persist_tokens(self, tokens: OAuthTokens) -> None:
        """Upsert an ACTIVE account built from the id-token claims and encrypted tokens.

        Values in the nested ``auth`` claims take precedence over top-level
        ones. Defaults are used for missing email/plan, and the account id
        falls back to ``fallback_account_id(email)``.
        """
        claims = extract_id_token_claims(tokens.id_token)
        auth_claims = claims.auth or OpenAIAuthClaims()
        account_id = auth_claims.chatgpt_account_id or claims.chatgpt_account_id
        email = claims.email or DEFAULT_EMAIL
        plan_type = auth_claims.chatgpt_plan_type or claims.chatgpt_plan_type or DEFAULT_PLAN
        account_id = account_id or fallback_account_id(email)

        account = Account(
            id=account_id,
            email=email,
            plan_type=plan_type,
            access_token_encrypted=self._encryptor.encrypt(tokens.access_token),
            refresh_token_encrypted=self._encryptor.encrypt(tokens.refresh_token),
            id_token_encrypted=self._encryptor.encrypt(tokens.id_token),
            last_refresh=utcnow(),
            status=AccountStatus.ACTIVE,
            deactivation_reason=None,
        )
        if self._repo_factory:
            async with self._repo_factory() as repo:
                await repo.upsert(account)
        else:
            await self._accounts_repo.upsert(account)

    async def _set_success(self) -> None:
        """Mark the shared flow as successful and clear any error."""
        async with self._store.lock:
            self._store.state.status = "success"
            self._store.state.error_message = None

    async def _set_error(self, message: str) -> None:
        """Mark the shared flow as failed with ``message``."""
        async with self._store.lock:
            self._store.state.status = "error"
            self._store.state.error_message = message

    async def _stop_callback_server(self) -> None:
        """Detach the callback server from the shared state and stop it, if present."""
        async with self._store.lock:
            server = self._store.state.callback_server
            self._store.state.callback_server = None
        if server:
            await server.stop()

    @staticmethod
    def _html_response(html: str) -> web.Response:
        """Wrap an HTML string in an aiohttp ``text/html`` response."""
        return web.Response(text=html, content_type="text/html")
|
|
338
|
+
|
|
339
|
+
|
|
340
|
+
@dataclass(frozen=True)
class DevicePollContext:
    """Immutable inputs for one device-code token polling loop."""

    # Identifiers from the device-code request, sent with every token poll.
    device_auth_id: str
    user_code: str
    # Seconds to sleep between polls.
    interval_seconds: int
    # ``time.time()`` timestamp after which polling stops.
    expires_at: float
|
|
346
|
+
|
|
347
|
+
|
|
348
|
+
def _success_html() -> str:
    """Return the success page, or a minimal inline page if the template cannot be read."""
    fallback_page = "<html><body><h1>Login complete</h1><p>Return to the dashboard.</p></body></html>"
    try:
        page = _SUCCESS_TEMPLATE.read_text(encoding="utf-8")
    except OSError:
        page = fallback_page
    return page
|
|
353
|
+
|
|
354
|
+
|
|
355
|
+
def _error_html(message: str) -> str:
    """Render a minimal "Login failed" page for the OAuth callback tab.

    ``message`` may carry text from an OAuth error (e.g. the code-exchange
    failure message), so it is HTML-escaped before being embedded. This
    prevents markup or script injection into the page.
    """
    from html import escape

    return f"<html><body><h1>Login failed</h1><p>{escape(message)}</p></body></html>"
|