codex-lb 0.1.4__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- app/__init__.py +1 -1
- app/core/auth/__init__.py +2 -1
- app/core/balancer/logic.py +16 -13
- app/core/clients/proxy.py +2 -4
- app/core/config/settings.py +2 -1
- app/core/plan_types.py +64 -0
- app/core/types.py +4 -2
- app/core/usage/__init__.py +3 -2
- app/core/usage/quota.py +58 -0
- app/core/utils/retry.py +14 -0
- app/core/utils/sse.py +6 -2
- app/db/migrations/__init__.py +80 -0
- app/db/migrations/versions/__init__.py +1 -0
- app/db/migrations/versions/normalize_account_plan_types.py +17 -0
- app/db/session.py +14 -0
- app/dependencies.py +0 -8
- app/main.py +4 -4
- app/modules/{proxy → accounts}/auth_manager.py +33 -4
- app/modules/accounts/repository.py +3 -3
- app/modules/accounts/service.py +10 -7
- app/modules/health/api.py +5 -3
- app/modules/health/schemas.py +9 -0
- app/modules/oauth/service.py +5 -1
- app/modules/proxy/helpers.py +285 -0
- app/modules/proxy/load_balancer.py +13 -37
- app/modules/proxy/service.py +37 -307
- app/modules/request_logs/service.py +5 -3
- app/modules/usage/service.py +7 -6
- app/modules/{proxy/usage_updater.py → usage/updater.py} +1 -1
- app/static/index.js +26 -18
- {codex_lb-0.1.4.dist-info → codex_lb-0.2.0.dist-info}/METADATA +1 -1
- {codex_lb-0.1.4.dist-info → codex_lb-0.2.0.dist-info}/RECORD +35 -28
- {codex_lb-0.1.4.dist-info → codex_lb-0.2.0.dist-info}/WHEEL +0 -0
- {codex_lb-0.1.4.dist-info → codex_lb-0.2.0.dist-info}/entry_points.txt +0 -0
- {codex_lb-0.1.4.dist-info → codex_lb-0.2.0.dist-info}/licenses/LICENSE +0 -0
app/__init__.py
CHANGED
app/core/auth/__init__.py
CHANGED
|
@@ -82,10 +82,11 @@ def extract_id_token_claims(id_token: str) -> IdTokenClaims:
|
|
|
82
82
|
def claims_from_auth(auth: AuthFile) -> AccountClaims:
    """Build AccountClaims from the ID token stored in *auth*.

    Values from the nested ``auth`` claim take precedence over the matching
    top-level ID-token claims; for the account id, the token file's own
    ``account_id`` wins over both.
    """
    id_claims = extract_id_token_claims(auth.tokens.id_token)
    # A token without a nested auth section behaves like one with empty claims.
    nested = id_claims.auth or OpenAIAuthClaims()
    plan = nested.chatgpt_plan_type or id_claims.chatgpt_plan_type
    account_id = (
        auth.tokens.account_id
        or nested.chatgpt_account_id
        or id_claims.chatgpt_account_id
    )
    return AccountClaims(account_id=account_id, email=id_claims.email, plan_type=plan)
|
|
90
91
|
|
|
91
92
|
|
app/core/balancer/logic.py
CHANGED
|
@@ -5,7 +5,7 @@ from dataclasses import dataclass
|
|
|
5
5
|
from typing import Iterable
|
|
6
6
|
|
|
7
7
|
from app.core.balancer.types import UpstreamError
|
|
8
|
-
from app.core.utils.retry import parse_retry_after
|
|
8
|
+
from app.core.utils.retry import backoff_seconds, parse_retry_after
|
|
9
9
|
from app.db.models import AccountStatus
|
|
10
10
|
|
|
11
11
|
PERMANENT_FAILURE_CODES = {
|
|
@@ -22,7 +22,8 @@ class AccountState:
|
|
|
22
22
|
account_id: str
|
|
23
23
|
status: AccountStatus
|
|
24
24
|
used_percent: float | None = None
|
|
25
|
-
reset_at:
|
|
25
|
+
reset_at: float | None = None
|
|
26
|
+
cooldown_until: float | None = None
|
|
26
27
|
last_error_at: float | None = None
|
|
27
28
|
last_selected_at: float | None = None
|
|
28
29
|
error_count: int = 0
|
|
@@ -59,6 +60,12 @@ def select_account(states: Iterable[AccountState], now: float | None = None) ->
|
|
|
59
60
|
state.reset_at = None
|
|
60
61
|
else:
|
|
61
62
|
continue
|
|
63
|
+
if state.cooldown_until and current >= state.cooldown_until:
|
|
64
|
+
state.cooldown_until = None
|
|
65
|
+
state.last_error_at = None
|
|
66
|
+
state.error_count = 0
|
|
67
|
+
if state.cooldown_until and current < state.cooldown_until:
|
|
68
|
+
continue
|
|
62
69
|
if state.error_count >= 3:
|
|
63
70
|
backoff = min(300, 30 * (2 ** (state.error_count - 3)))
|
|
64
71
|
if state.last_error_at and current - state.last_error_at < backoff:
|
|
@@ -82,6 +89,10 @@ def select_account(states: Iterable[AccountState], now: float | None = None) ->
|
|
|
82
89
|
if reset_candidates:
|
|
83
90
|
wait_seconds = max(0, min(reset_candidates) - int(current))
|
|
84
91
|
return SelectionResult(None, f"Rate limit exceeded. Try again in {wait_seconds:.0f}s")
|
|
92
|
+
cooldowns = [s.cooldown_until for s in all_states if s.cooldown_until and s.cooldown_until > current]
|
|
93
|
+
if cooldowns:
|
|
94
|
+
wait_seconds = max(0.0, min(cooldowns) - current)
|
|
95
|
+
return SelectionResult(None, f"Rate limit exceeded. Try again in {wait_seconds:.0f}s")
|
|
85
96
|
return SelectionResult(None, "No available accounts")
|
|
86
97
|
|
|
87
98
|
def _sort_key(state: AccountState) -> tuple[float, float, str]:
|
|
@@ -94,21 +105,13 @@ def select_account(states: Iterable[AccountState], now: float | None = None) ->
|
|
|
94
105
|
|
|
95
106
|
|
|
96
107
|
def handle_rate_limit(state: AccountState, error: UpstreamError) -> None:
    """Record a rate-limit error on *state* and start a cooldown.

    The cooldown length is taken from a "try again in ..." hint in the
    upstream error message when one can be parsed. Otherwise it falls back to
    ``backoff_seconds`` keyed on the (just incremented) error count.
    """
    state.error_count += 1
    state.last_error_at = time.time()
    hint = error.get("message")
    retry_after = None if not hint else parse_retry_after(hint)
    cooldown = backoff_seconds(state.error_count) if retry_after is None else retry_after
    state.cooldown_until = time.time() + cooldown
|
|
112
115
|
|
|
113
116
|
|
|
114
117
|
def handle_quota_exceeded(state: AccountState, error: UpstreamError) -> None:
|
app/core/clients/proxy.py
CHANGED
|
@@ -18,7 +18,6 @@ IGNORE_INBOUND_HEADERS = {"authorization", "chatgpt-account-id", "content-length
|
|
|
18
18
|
|
|
19
19
|
_ERROR_TYPE_CODE_MAP = {
|
|
20
20
|
"rate_limit_exceeded": "rate_limit_exceeded",
|
|
21
|
-
"usage_limit_reached": "rate_limit_exceeded",
|
|
22
21
|
"usage_not_included": "usage_not_included",
|
|
23
22
|
"insufficient_quota": "insufficient_quota",
|
|
24
23
|
"quota_exceeded": "quota_exceeded",
|
|
@@ -64,12 +63,11 @@ def _normalize_error_code(code: str | None, error_type: str | None) -> str:
|
|
|
64
63
|
if code:
|
|
65
64
|
normalized_code = code.lower()
|
|
66
65
|
mapped = _ERROR_TYPE_CODE_MAP.get(normalized_code)
|
|
67
|
-
return mapped or
|
|
66
|
+
return mapped or normalized_code
|
|
68
67
|
normalized_type = error_type.lower() if error_type else None
|
|
69
68
|
if normalized_type:
|
|
70
69
|
mapped = _ERROR_TYPE_CODE_MAP.get(normalized_type)
|
|
71
|
-
|
|
72
|
-
return mapped
|
|
70
|
+
return mapped or normalized_type
|
|
73
71
|
return "upstream_error"
|
|
74
72
|
|
|
75
73
|
|
app/core/config/settings.py
CHANGED
|
@@ -39,6 +39,7 @@ class Settings(BaseSettings):
|
|
|
39
39
|
usage_refresh_enabled: bool = True
|
|
40
40
|
usage_refresh_interval_seconds: int = 60
|
|
41
41
|
encryption_key_file: Path = DEFAULT_ENCRYPTION_KEY_FILE
|
|
42
|
+
database_migrations_fail_fast: bool = True
|
|
42
43
|
|
|
43
44
|
@field_validator("database_url")
|
|
44
45
|
@classmethod
|
|
@@ -61,7 +62,7 @@ class Settings(BaseSettings):
|
|
|
61
62
|
return value.expanduser()
|
|
62
63
|
if isinstance(value, str):
|
|
63
64
|
return Path(value).expanduser()
|
|
64
|
-
|
|
65
|
+
raise TypeError("encryption_key_file must be a path")
|
|
65
66
|
|
|
66
67
|
|
|
67
68
|
@lru_cache(maxsize=1)
|
app/core/plan_types.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Final
|
|
4
|
+
|
|
5
|
+
# Canonical (lower-case) plan types that an account row may carry.
# frozenset rather than set: ``Final`` only prevents rebinding the name, so a
# plain set could still be mutated in place by any importer.
ACCOUNT_PLAN_TYPES: Final[frozenset[str]] = frozenset(
    {
        "free",
        "plus",
        "pro",
        "team",
        "business",
        "enterprise",
        "edu",
    }
)

# Plan types accepted by normalize_rate_limit_plan_type: every account plan
# plus additional names that are valid only for rate-limit data.
RATE_LIMIT_PLAN_TYPES: Final[frozenset[str]] = frozenset(
    {
        *ACCOUNT_PLAN_TYPES,
        "guest",
        "go",
        "free_workspace",
        "education",
        "quorum",
        "k12",
    }
)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def _clean_plan_type(value: str | None) -> str | None:
|
|
27
|
+
if value is None:
|
|
28
|
+
return None
|
|
29
|
+
cleaned = value.strip()
|
|
30
|
+
return cleaned or None
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def normalize_account_plan_type(value: str | None) -> str | None:
    """Return the lower-cased plan type if it is a known account plan.

    Surrounding whitespace is ignored. Blank or unrecognized values yield
    ``None``.
    """
    cleaned = _clean_plan_type(value)
    candidate = cleaned.lower() if cleaned else None
    if candidate is not None and candidate in ACCOUNT_PLAN_TYPES:
        return candidate
    return None
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def canonicalize_account_plan_type(value: str | None) -> str | None:
    """Canonicalize a plan type without discarding unknown values.

    Known account plans come back lower-cased. Any other non-blank value is
    returned unchanged apart from whitespace stripping. Blank input yields
    ``None``.
    """
    cleaned = _clean_plan_type(value)
    if not cleaned:
        return None
    lowered = cleaned.lower()
    return lowered if lowered in ACCOUNT_PLAN_TYPES else cleaned
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def coerce_account_plan_type(value: str | None, default: str) -> str:
    """Return the canonical plan type for *value*, or *default* when it is blank.

    Delegates to ``canonicalize_account_plan_type``, which returns ``None``
    only for missing/blank input. *default* is therefore used exactly in that
    case. Unrecognized non-blank values are preserved (whitespace-stripped)
    rather than replaced.
    """
    # The previous version cleaned the value twice and kept a fallback branch
    # that could never run: canonicalization of a non-blank value is never None.
    canonical = canonicalize_account_plan_type(value)
    return default if canonical is None else canonical
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def normalize_rate_limit_plan_type(value: str | None) -> str | None:
    """Return the lower-cased plan type if it is valid for rate-limit data.

    Whitespace is stripped first. Blank or unrecognized values yield ``None``.
    """
    cleaned = _clean_plan_type(value)
    if cleaned:
        lowered = cleaned.lower()
        if lowered in RATE_LIMIT_PLAN_TYPES:
            return lowered
    return None
|
app/core/types.py
CHANGED
|
@@ -1,4 +1,6 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
-
|
|
4
|
-
|
|
3
|
+
from collections.abc import Mapping
|
|
4
|
+
|
|
5
|
+
type JsonValue = bool | int | float | str | None | list[JsonValue] | Mapping[str, JsonValue]
|
|
6
|
+
type JsonObject = Mapping[str, JsonValue]
|
app/core/usage/__init__.py
CHANGED
|
@@ -2,6 +2,7 @@ from __future__ import annotations
|
|
|
2
2
|
|
|
3
3
|
from typing import Iterable, Mapping
|
|
4
4
|
|
|
5
|
+
from app.core.plan_types import normalize_account_plan_type
|
|
5
6
|
from app.core.usage.types import (
|
|
6
7
|
UsageCostSummary,
|
|
7
8
|
UsageHistoryPayload,
|
|
@@ -134,9 +135,9 @@ def summarize_usage_window(
|
|
|
134
135
|
|
|
135
136
|
|
|
136
137
|
def capacity_for_plan(plan_type: str | None, window: str) -> float | None:
|
|
137
|
-
|
|
138
|
+
normalized = normalize_account_plan_type(plan_type)
|
|
139
|
+
if not normalized:
|
|
138
140
|
return None
|
|
139
|
-
normalized = plan_type.lower()
|
|
140
141
|
window_key = _normalize_window_key(window)
|
|
141
142
|
if window_key == "primary":
|
|
142
143
|
return PLAN_CAPACITY_CREDITS_PRIMARY.get(normalized)
|
app/core/usage/quota.py
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import time
|
|
4
|
+
|
|
5
|
+
from app.core import usage as usage_core
|
|
6
|
+
from app.db.models import AccountStatus
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def apply_usage_quota(
    *,
    status: AccountStatus,
    primary_used: float | None,
    primary_reset: int | None,
    primary_window_minutes: int | None,
    runtime_reset: float | None,
    secondary_used: float | None,
    secondary_reset: int | None,
) -> tuple[AccountStatus, float | None, float | None]:
    """Derive an account's status, used percent and reset time from usage data.

    Paused and deactivated accounts are returned untouched.

    Otherwise the secondary window is checked first:
    - At >= 100% the account is QUOTA_EXCEEDED.
    - Below 100%, a QUOTA_EXCEEDED status is restored to ACTIVE.
    - With no secondary data, a QUOTA_EXCEEDED account keeps its status and
      adopts *secondary_reset* if given.

    The primary window is checked next:
    - At >= 100% the account is RATE_LIMITED, with a reset time estimated
      from the window length when *primary_reset* is unknown.
    - Below 100%, a RATE_LIMITED status is restored to ACTIVE.

    NOTE(review): when ``secondary_used`` is None, a QUOTA_EXCEEDED account
    whose primary window is exhausted is reported as RATE_LIMITED — confirm
    this override is intended.

    Returns:
        ``(status, used_percent, reset_at)``. ``used_percent`` is
        *primary_used*, or 100.0 when a window is exhausted. ``reset_at``
        starts from *runtime_reset*; it is replaced by the exhausted window's
        reset when known, and cleared when a limit is lifted.
    """
    new_status = status
    percent = primary_used
    reset = runtime_reset

    if new_status in (AccountStatus.DEACTIVATED, AccountStatus.PAUSED):
        return new_status, percent, reset

    # Secondary window -> QUOTA_EXCEEDED handling.
    if secondary_used is None:
        if new_status == AccountStatus.QUOTA_EXCEEDED and secondary_reset is not None:
            reset = secondary_reset
    elif secondary_used >= 100.0:
        return (
            AccountStatus.QUOTA_EXCEEDED,
            100.0,
            secondary_reset if secondary_reset is not None else reset,
        )
    elif new_status == AccountStatus.QUOTA_EXCEEDED:
        new_status, reset = AccountStatus.ACTIVE, None

    # Primary window -> RATE_LIMITED handling.
    if primary_used is None:
        return new_status, percent, reset
    if primary_used >= 100.0:
        if primary_reset is not None:
            reset = primary_reset
        else:
            reset = _fallback_primary_reset(primary_window_minutes) or reset
        return AccountStatus.RATE_LIMITED, 100.0, reset
    if new_status == AccountStatus.RATE_LIMITED:
        new_status, reset = AccountStatus.ACTIVE, None
    return new_status, percent, reset
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def _fallback_primary_reset(primary_window_minutes: int | None) -> float | None:
    """Estimate a primary-window reset time as "now + one window length".

    Uses *primary_window_minutes* when truthy, otherwise
    ``usage_core.default_window_minutes("primary")``. Returns ``None`` if
    neither yields a usable length.
    """
    minutes = primary_window_minutes
    if not minutes:
        minutes = usage_core.default_window_minutes("primary")
    return time.time() + 60.0 * float(minutes) if minutes else None
|
app/core/utils/retry.py
CHANGED
|
@@ -1,8 +1,13 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
+
import random
|
|
3
4
|
import re
|
|
4
5
|
|
|
5
6
|
_RETRY_PATTERN = re.compile(r"(?i)try again in\s*(\d+(?:\.\d+)?)\s*(s|ms|seconds?)")
|
|
7
|
+
_BACKOFF_INITIAL_DELAY_MS = 200
|
|
8
|
+
_BACKOFF_FACTOR = 2.0
|
|
9
|
+
_BACKOFF_JITTER_MIN = 0.9
|
|
10
|
+
_BACKOFF_JITTER_MAX = 1.1
|
|
6
11
|
|
|
7
12
|
|
|
8
13
|
def parse_retry_after(message: str) -> float | None:
|
|
@@ -14,3 +19,12 @@ def parse_retry_after(message: str) -> float | None:
|
|
|
14
19
|
if unit == "ms":
|
|
15
20
|
return value / 1000
|
|
16
21
|
return value
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def backoff_seconds(attempt: int, max_seconds: float | None = None) -> float:
    """Return a jittered exponential backoff delay, in seconds, for *attempt*.

    The delay is ``_BACKOFF_INITIAL_DELAY_MS * _BACKOFF_FACTOR ** (attempt - 1)``
    milliseconds. It is scaled by a random jitter in
    ``[_BACKOFF_JITTER_MIN, _BACKOFF_JITTER_MAX]``. Attempts below 1 are
    treated as attempt 1.

    Args:
        attempt: 1-based retry attempt number.
        max_seconds: Optional upper bound on the returned delay. The default
            ``None`` keeps the uncapped behaviour. When a cap is given, attempt
            numbers large enough to overflow the exponential term return the
            cap instead of raising.

    Raises:
        ValueError: If *max_seconds* is negative.
        OverflowError: If *max_seconds* is ``None`` and the exponential term
            exceeds the float range.
    """
    if max_seconds is not None and max_seconds < 0:
        raise ValueError("max_seconds must be non-negative")
    attempt = max(attempt, 1)
    try:
        base_ms = _BACKOFF_INITIAL_DELAY_MS * (_BACKOFF_FACTOR ** (attempt - 1))
    except OverflowError:
        # float ** raises rather than returning inf; only absorb it when capped.
        if max_seconds is None:
            raise
        return float(max_seconds)
    jitter = random.uniform(_BACKOFF_JITTER_MIN, _BACKOFF_JITTER_MAX)
    delay = (base_ms * jitter) / 1000.0
    if max_seconds is not None:
        delay = min(delay, float(max_seconds))
    return delay
|
app/core/utils/sse.py
CHANGED
|
@@ -1,11 +1,15 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
import json
|
|
4
|
+
from collections.abc import Mapping
|
|
4
5
|
|
|
5
|
-
from app.core.
|
|
6
|
+
from app.core.errors import ResponseFailedEvent
|
|
7
|
+
from app.core.types import JsonValue
|
|
6
8
|
|
|
9
|
+
type JsonPayload = Mapping[str, JsonValue] | ResponseFailedEvent
|
|
7
10
|
|
|
8
|
-
|
|
11
|
+
|
|
12
|
+
def format_sse_event(payload: JsonPayload) -> str:
|
|
9
13
|
data = json.dumps(payload, ensure_ascii=True, separators=(",", ":"))
|
|
10
14
|
event_type = payload.get("type")
|
|
11
15
|
if isinstance(event_type, str) and event_type:
|
|
@@ -0,0 +1,80 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from contextlib import asynccontextmanager
|
|
4
|
+
from dataclasses import dataclass
|
|
5
|
+
from datetime import datetime, timezone
|
|
6
|
+
from typing import Awaitable, Callable, Final
|
|
7
|
+
|
|
8
|
+
from sqlalchemy import text
|
|
9
|
+
from sqlalchemy.ext.asyncio import AsyncSession
|
|
10
|
+
|
|
11
|
+
from app.db.migrations.versions import normalize_account_plan_types
|
|
12
|
+
|
|
13
|
+
_CREATE_MIGRATIONS_TABLE = """
|
|
14
|
+
CREATE TABLE IF NOT EXISTS schema_migrations (
|
|
15
|
+
name TEXT PRIMARY KEY,
|
|
16
|
+
applied_at TEXT NOT NULL
|
|
17
|
+
)
|
|
18
|
+
"""
|
|
19
|
+
|
|
20
|
+
_INSERT_MIGRATION = """
|
|
21
|
+
INSERT INTO schema_migrations (name, applied_at)
|
|
22
|
+
VALUES (:name, :applied_at)
|
|
23
|
+
ON CONFLICT(name) DO NOTHING
|
|
24
|
+
"""
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
@dataclass(frozen=True)
class Migration:
    """A named, one-shot data migration.

    ``name`` is the key stored in ``schema_migrations``. ``run`` is the
    coroutine function that performs the migration using the session it is
    given.
    """

    name: str  # unique identifier, recorded once applied
    run: Callable[[AsyncSession], Awaitable[None]]  # migration body
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
# Ordered registry of data migrations. run_migrations walks it in order and
# each entry runs at most once, keyed by its name in schema_migrations.
MIGRATIONS: Final[tuple[Migration, ...]] = (
    Migration(
        "001_normalize_account_plan_types",
        normalize_account_plan_types.run,
    ),
)
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
async def run_migrations(session: AsyncSession) -> int:
    """Apply every pending migration in ``MIGRATIONS`` order using *session*.

    The ``schema_migrations`` bookkeeping table is created first if missing.
    Migrations already recorded there are skipped.

    Returns:
        The number of migrations applied by this call.
    """
    await _ensure_schema_migrations(session)
    applied_flags = [await _apply_migration(session, entry) for entry in MIGRATIONS]
    return sum(applied_flags)
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
async def _apply_migration(session: AsyncSession, migration: Migration) -> bool:
    """Claim and run *migration* if it has not been recorded yet.

    Steps, all in one transaction/savepoint:
    1. Insert the bookkeeping row; name conflicts are ignored.
    2. If a row was inserted, run the migration.
    Because both steps share the transaction, a failing migration also rolls
    back its bookkeeping row.

    NOTE(review): any truthy rowcount counts as "inserted", including a
    driver-reported -1 ("unknown"). Confirm non-SQLite backends report an
    accurate count here.

    Returns:
        True if the migration ran in this call, False if it was already applied.
    """
    async with _migration_transaction(session):
        params = {"name": migration.name, "applied_at": _utcnow_iso()}
        result = await session.execute(text(_INSERT_MIGRATION), params)
        claimed = bool(getattr(result, "rowcount", 0) or 0)
        if claimed:
            await migration.run(session)
        return claimed
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
async def _ensure_schema_migrations(session: AsyncSession) -> None:
    """Create the ``schema_migrations`` bookkeeping table if it does not exist."""
    create_stmt = text(_CREATE_MIGRATIONS_TABLE)
    async with _migration_transaction(session):
        await session.execute(create_stmt)
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
@asynccontextmanager
async def _migration_transaction(session: AsyncSession):
    """Scope a unit of migration work to a transaction on *session*.

    - If the session already has a transaction open, a SAVEPOINT
      (``begin_nested``) is used. The work can then roll back on its own, and
      committing the outer transaction remains the caller's job.
    - Otherwise a new transaction is opened with ``begin()``. It is committed
      on normal exit and rolled back if the body raises.
    """
    already_open = session.in_transaction()
    transaction = session.begin_nested() if already_open else session.begin()
    async with transaction:
        yield
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def _utcnow_iso() -> str:
|
|
80
|
+
return datetime.now(timezone.utc).isoformat()
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from sqlalchemy import select
|
|
4
|
+
from sqlalchemy.ext.asyncio import AsyncSession
|
|
5
|
+
|
|
6
|
+
from app.core.auth import DEFAULT_PLAN
|
|
7
|
+
from app.core.plan_types import coerce_account_plan_type
|
|
8
|
+
from app.db.models import Account
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
async def run(session: AsyncSession) -> None:
    """Rewrite every stored ``Account.plan_type`` into its canonical form.

    Values are coerced with ``coerce_account_plan_type``:
    - known plan names are lower-cased;
    - NULL/blank values become ``DEFAULT_PLAN``;
    - unrecognized values are kept, whitespace-stripped.
    Only accounts whose value actually changes are modified. No commit
    happens here; the caller owns the transaction.
    """
    accounts = (await session.execute(select(Account))).scalars().all()
    for account in accounts:
        canonical = coerce_account_plan_type(account.plan_type, DEFAULT_PLAN)
        if canonical != account.plan_type:
            account.plan_type = canonical
|
app/db/session.py
CHANGED
|
@@ -1,15 +1,19 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
import asyncio
|
|
4
|
+
import logging
|
|
4
5
|
from pathlib import Path
|
|
5
6
|
from typing import AsyncIterator
|
|
6
7
|
|
|
7
8
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
|
8
9
|
|
|
9
10
|
from app.core.config.settings import get_settings
|
|
11
|
+
from app.db.migrations import run_migrations
|
|
10
12
|
|
|
11
13
|
DATABASE_URL = get_settings().database_url
|
|
12
14
|
|
|
15
|
+
logger = logging.getLogger(__name__)
|
|
16
|
+
|
|
13
17
|
engine = create_async_engine(DATABASE_URL, echo=False)
|
|
14
18
|
SessionLocal = async_sessionmaker(engine, expire_on_commit=False, class_=AsyncSession)
|
|
15
19
|
|
|
@@ -60,3 +64,13 @@ async def init_db() -> None:
|
|
|
60
64
|
|
|
61
65
|
async with engine.begin() as conn:
|
|
62
66
|
await conn.run_sync(Base.metadata.create_all)
|
|
67
|
+
|
|
68
|
+
async with SessionLocal() as session:
|
|
69
|
+
try:
|
|
70
|
+
updated = await run_migrations(session)
|
|
71
|
+
if updated:
|
|
72
|
+
logger.info("Applied database migrations count=%s", updated)
|
|
73
|
+
except Exception:
|
|
74
|
+
logger.exception("Failed to apply database migrations")
|
|
75
|
+
if get_settings().database_migrations_fail_fast:
|
|
76
|
+
raise
|
app/dependencies.py
CHANGED
|
@@ -22,8 +22,6 @@ from app.modules.usage.service import UsageService
|
|
|
22
22
|
class AccountsContext:
|
|
23
23
|
session: AsyncSession
|
|
24
24
|
repository: AccountsRepository
|
|
25
|
-
usage_repository: UsageRepository
|
|
26
|
-
request_logs_repository: RequestLogsRepository
|
|
27
25
|
service: AccountsService
|
|
28
26
|
|
|
29
27
|
|
|
@@ -31,8 +29,6 @@ class AccountsContext:
|
|
|
31
29
|
class UsageContext:
|
|
32
30
|
session: AsyncSession
|
|
33
31
|
usage_repository: UsageRepository
|
|
34
|
-
request_logs_repository: RequestLogsRepository
|
|
35
|
-
accounts_repository: AccountsRepository
|
|
36
32
|
service: UsageService
|
|
37
33
|
|
|
38
34
|
|
|
@@ -63,8 +59,6 @@ def get_accounts_context(
|
|
|
63
59
|
return AccountsContext(
|
|
64
60
|
session=session,
|
|
65
61
|
repository=repository,
|
|
66
|
-
usage_repository=usage_repository,
|
|
67
|
-
request_logs_repository=request_logs_repository,
|
|
68
62
|
service=service,
|
|
69
63
|
)
|
|
70
64
|
|
|
@@ -79,8 +73,6 @@ def get_usage_context(
|
|
|
79
73
|
return UsageContext(
|
|
80
74
|
session=session,
|
|
81
75
|
usage_repository=usage_repository,
|
|
82
|
-
request_logs_repository=request_logs_repository,
|
|
83
|
-
accounts_repository=accounts_repository,
|
|
84
76
|
service=service,
|
|
85
77
|
)
|
|
86
78
|
|
app/main.py
CHANGED
|
@@ -11,7 +11,7 @@ from fastapi.exception_handlers import (
|
|
|
11
11
|
request_validation_exception_handler,
|
|
12
12
|
)
|
|
13
13
|
from fastapi.exceptions import RequestValidationError
|
|
14
|
-
from fastapi.responses import FileResponse, JSONResponse, RedirectResponse
|
|
14
|
+
from fastapi.responses import FileResponse, JSONResponse, RedirectResponse, Response
|
|
15
15
|
from fastapi.staticfiles import StaticFiles
|
|
16
16
|
from starlette.exceptions import HTTPException as StarletteHTTPException
|
|
17
17
|
|
|
@@ -57,7 +57,7 @@ def create_app() -> FastAPI:
|
|
|
57
57
|
return response
|
|
58
58
|
|
|
59
59
|
@app.middleware("http")
|
|
60
|
-
async def api_unhandled_error_middleware(request: Request, call_next) ->
|
|
60
|
+
async def api_unhandled_error_middleware(request: Request, call_next) -> Response:
|
|
61
61
|
try:
|
|
62
62
|
return await call_next(request)
|
|
63
63
|
except Exception:
|
|
@@ -76,7 +76,7 @@ def create_app() -> FastAPI:
|
|
|
76
76
|
async def _validation_error_handler(
|
|
77
77
|
request: Request,
|
|
78
78
|
exc: RequestValidationError,
|
|
79
|
-
) ->
|
|
79
|
+
) -> Response:
|
|
80
80
|
if request.url.path.startswith("/api/"):
|
|
81
81
|
return JSONResponse(
|
|
82
82
|
status_code=422,
|
|
@@ -88,7 +88,7 @@ def create_app() -> FastAPI:
|
|
|
88
88
|
async def _http_error_handler(
|
|
89
89
|
request: Request,
|
|
90
90
|
exc: StarletteHTTPException,
|
|
91
|
-
) ->
|
|
91
|
+
) -> Response:
|
|
92
92
|
if request.url.path.startswith("/api/"):
|
|
93
93
|
detail = exc.detail if isinstance(exc.detail, str) else "Request failed"
|
|
94
94
|
return JSONResponse(
|
|
@@ -1,15 +1,39 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
+
from datetime import datetime
|
|
4
|
+
from typing import Protocol
|
|
5
|
+
|
|
6
|
+
from app.core.auth import DEFAULT_PLAN
|
|
3
7
|
from app.core.auth.refresh import RefreshError, refresh_access_token, should_refresh
|
|
4
8
|
from app.core.balancer import PERMANENT_FAILURE_CODES
|
|
5
9
|
from app.core.crypto import TokenEncryptor
|
|
10
|
+
from app.core.plan_types import coerce_account_plan_type
|
|
6
11
|
from app.core.utils.time import utcnow
|
|
7
12
|
from app.db.models import Account, AccountStatus
|
|
8
|
-
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class AccountsRepositoryPort(Protocol):
    """Accounts-repository operations that AuthManager depends on.

    Structural interface: any object exposing these coroutine methods with
    compatible signatures can be passed to ``AuthManager``.
    """

    async def update_status(
        self,
        account_id: str,
        status: AccountStatus,
        deactivation_reason: str | None = None,
    ) -> bool:
        """Set an account's status and optional deactivation reason.

        Returns whether the account row was updated.
        """
        ...

    async def update_tokens(
        self,
        account_id: str,
        access_token_encrypted: bytes,
        refresh_token_encrypted: bytes,
        id_token_encrypted: bytes,
        last_refresh: datetime,
        plan_type: str | None = None,
        email: str | None = None,
    ) -> bool:
        """Persist freshly encrypted tokens and the refresh timestamp.

        ``plan_type`` and ``email`` are optional. Presumably ``None`` leaves
        the stored value unchanged — confirm against the implementation.
        Returns whether the account row was updated.
        """
        ...
|
|
9
33
|
|
|
10
34
|
|
|
11
35
|
class AuthManager:
|
|
12
|
-
def __init__(self, repo:
|
|
36
|
+
def __init__(self, repo: AccountsRepositoryPort) -> None:
|
|
13
37
|
self._repo = repo
|
|
14
38
|
self._encryptor = TokenEncryptor()
|
|
15
39
|
|
|
@@ -34,8 +58,13 @@ class AuthManager:
|
|
|
34
58
|
account.refresh_token_encrypted = self._encryptor.encrypt(result.refresh_token)
|
|
35
59
|
account.id_token_encrypted = self._encryptor.encrypt(result.id_token)
|
|
36
60
|
account.last_refresh = utcnow()
|
|
37
|
-
if result.plan_type:
|
|
38
|
-
account.plan_type =
|
|
61
|
+
if result.plan_type is not None:
|
|
62
|
+
account.plan_type = coerce_account_plan_type(
|
|
63
|
+
result.plan_type,
|
|
64
|
+
account.plan_type or DEFAULT_PLAN,
|
|
65
|
+
)
|
|
66
|
+
elif not account.plan_type:
|
|
67
|
+
account.plan_type = DEFAULT_PLAN
|
|
39
68
|
if result.email:
|
|
40
69
|
account.email = result.email
|
|
41
70
|
|
|
@@ -48,12 +48,12 @@ class AccountsRepository:
|
|
|
48
48
|
.values(status=status, deactivation_reason=deactivation_reason)
|
|
49
49
|
)
|
|
50
50
|
await self._session.commit()
|
|
51
|
-
return bool(result
|
|
51
|
+
return bool(getattr(result, "rowcount", 0) or 0)
|
|
52
52
|
|
|
53
53
|
async def delete(self, account_id: str) -> bool:
|
|
54
54
|
result = await self._session.execute(delete(Account).where(Account.id == account_id))
|
|
55
55
|
await self._session.commit()
|
|
56
|
-
return bool(result
|
|
56
|
+
return bool(getattr(result, "rowcount", 0) or 0)
|
|
57
57
|
|
|
58
58
|
async def update_tokens(
|
|
59
59
|
self,
|
|
@@ -77,4 +77,4 @@ class AccountsRepository:
|
|
|
77
77
|
values["email"] = email
|
|
78
78
|
result = await self._session.execute(update(Account).where(Account.id == account_id).values(**values))
|
|
79
79
|
await self._session.commit()
|
|
80
|
-
return bool(result
|
|
80
|
+
return bool(getattr(result, "rowcount", 0) or 0)
|