codex-lb 0.2.0__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. app/core/auth/__init__.py +10 -0
  2. app/core/balancer/logic.py +33 -6
  3. app/core/config/settings.py +2 -0
  4. app/core/usage/__init__.py +2 -0
  5. app/core/usage/logs.py +12 -2
  6. app/core/usage/quota.py +10 -4
  7. app/core/usage/types.py +3 -2
  8. app/db/migrations/__init__.py +14 -3
  9. app/db/migrations/versions/add_accounts_chatgpt_account_id.py +29 -0
  10. app/db/migrations/versions/add_accounts_reset_at.py +29 -0
  11. app/db/migrations/versions/add_dashboard_settings.py +31 -0
  12. app/db/migrations/versions/add_request_logs_reasoning_effort.py +21 -0
  13. app/db/models.py +33 -0
  14. app/db/session.py +71 -11
  15. app/dependencies.py +27 -1
  16. app/main.py +11 -2
  17. app/modules/accounts/auth_manager.py +44 -3
  18. app/modules/accounts/repository.py +14 -6
  19. app/modules/accounts/service.py +4 -2
  20. app/modules/oauth/service.py +4 -3
  21. app/modules/proxy/load_balancer.py +74 -5
  22. app/modules/proxy/service.py +155 -31
  23. app/modules/proxy/sticky_repository.py +56 -0
  24. app/modules/request_logs/repository.py +6 -3
  25. app/modules/request_logs/schemas.py +2 -0
  26. app/modules/request_logs/service.py +8 -1
  27. app/modules/settings/__init__.py +1 -0
  28. app/modules/settings/api.py +37 -0
  29. app/modules/settings/repository.py +40 -0
  30. app/modules/settings/schemas.py +13 -0
  31. app/modules/settings/service.py +33 -0
  32. app/modules/shared/schemas.py +16 -2
  33. app/modules/usage/schemas.py +1 -0
  34. app/modules/usage/service.py +17 -1
  35. app/modules/usage/updater.py +36 -7
  36. app/static/7.css +73 -0
  37. app/static/index.css +33 -4
  38. app/static/index.html +51 -4
  39. app/static/index.js +231 -25
  40. {codex_lb-0.2.0.dist-info → codex_lb-0.3.0.dist-info}/METADATA +2 -2
  41. {codex_lb-0.2.0.dist-info → codex_lb-0.3.0.dist-info}/RECORD +44 -34
  42. {codex_lb-0.2.0.dist-info → codex_lb-0.3.0.dist-info}/WHEEL +0 -0
  43. {codex_lb-0.2.0.dist-info → codex_lb-0.3.0.dist-info}/entry_points.txt +0 -0
  44. {codex_lb-0.2.0.dist-info → codex_lb-0.3.0.dist-info}/licenses/LICENSE +0 -0
app/core/auth/__init__.py CHANGED
@@ -90,7 +90,17 @@ def claims_from_auth(auth: AuthFile) -> AccountClaims:
90
90
  )
91
91
 
92
92
 
93
+ def generate_unique_account_id(account_id: str | None, email: str | None) -> str:
94
+ if account_id and email and email != DEFAULT_EMAIL:
95
+ email_hash = hashlib.sha256(email.encode()).hexdigest()[:8]
96
+ return f"{account_id}_{email_hash}"
97
+ if account_id:
98
+ return account_id
99
+ return fallback_account_id(email)
100
+
101
+
93
102
  def fallback_account_id(email: str | None) -> str:
103
+ """Generate a fallback account ID when no OpenAI account ID is available."""
94
104
  if email and email != DEFAULT_EMAIL:
95
105
  digest = hashlib.sha256(email.encode()).hexdigest()[:12]
96
106
  return f"email_{digest}"
@@ -16,6 +16,9 @@ PERMANENT_FAILURE_CODES = {
16
16
  "account_deleted": "Account has been deleted",
17
17
  }
18
18
 
19
+ SECONDS_PER_DAY = 60 * 60 * 24
20
+ UNKNOWN_RESET_BUCKET_DAYS = 10_000
21
+
19
22
 
20
23
  @dataclass
21
24
  class AccountState:
@@ -24,6 +27,8 @@ class AccountState:
24
27
  used_percent: float | None = None
25
28
  reset_at: float | None = None
26
29
  cooldown_until: float | None = None
30
+ secondary_used_percent: float | None = None
31
+ secondary_reset_at: int | None = None
27
32
  last_error_at: float | None = None
28
33
  last_selected_at: float | None = None
29
34
  error_count: int = 0
@@ -36,7 +41,12 @@ class SelectionResult:
36
41
  error_message: str | None
37
42
 
38
43
 
39
- def select_account(states: Iterable[AccountState], now: float | None = None) -> SelectionResult:
44
+ def select_account(
45
+ states: Iterable[AccountState],
46
+ now: float | None = None,
47
+ *,
48
+ prefer_earlier_reset: bool = False,
49
+ ) -> SelectionResult:
40
50
  current = now or time.time()
41
51
  available: list[AccountState] = []
42
52
  all_states = list(states)
@@ -95,18 +105,35 @@ def select_account(states: Iterable[AccountState], now: float | None = None) ->
95
105
  return SelectionResult(None, f"Rate limit exceeded. Try again in {wait_seconds:.0f}s")
96
106
  return SelectionResult(None, "No available accounts")
97
107
 
98
- def _sort_key(state: AccountState) -> tuple[float, float, str]:
99
- used = state.used_percent if state.used_percent is not None else 0.0
108
+ def _usage_sort_key(state: AccountState) -> tuple[float, float, float, str]:
109
+ primary_used = state.used_percent if state.used_percent is not None else 0.0
110
+ secondary_used = state.secondary_used_percent if state.secondary_used_percent is not None else primary_used
100
111
  last_selected = state.last_selected_at or 0.0
101
- return used, last_selected, state.account_id
102
-
103
- selected = min(available, key=_sort_key)
112
+ return secondary_used, primary_used, last_selected, state.account_id
113
+
114
+ def _reset_first_sort_key(state: AccountState) -> tuple[int, float, float, float, str]:
115
+ reset_bucket_days = UNKNOWN_RESET_BUCKET_DAYS
116
+ if state.secondary_reset_at is not None:
117
+ reset_bucket_days = max(
118
+ 0,
119
+ int((state.secondary_reset_at - current) // SECONDS_PER_DAY),
120
+ )
121
+ secondary_used, primary_used, last_selected, account_id = _usage_sort_key(state)
122
+ return reset_bucket_days, secondary_used, primary_used, last_selected, account_id
123
+
124
+ selected = min(available, key=_reset_first_sort_key if prefer_earlier_reset else _usage_sort_key)
104
125
  return SelectionResult(selected, None)
105
126
 
106
127
 
107
128
  def handle_rate_limit(state: AccountState, error: UpstreamError) -> None:
129
+ state.status = AccountStatus.RATE_LIMITED
108
130
  state.error_count += 1
109
131
  state.last_error_at = time.time()
132
+
133
+ reset_at = _extract_reset_at(error)
134
+ if reset_at is not None:
135
+ state.reset_at = reset_at
136
+
110
137
  message = error.get("message")
111
138
  delay = parse_retry_after(message) if message else None
112
139
  if delay is None:
@@ -40,6 +40,8 @@ class Settings(BaseSettings):
40
40
  usage_refresh_interval_seconds: int = 60
41
41
  encryption_key_file: Path = DEFAULT_ENCRYPTION_KEY_FILE
42
42
  database_migrations_fail_fast: bool = True
43
+ log_proxy_request_shape: bool = False
44
+ log_proxy_request_shape_raw_cache_key: bool = False
43
45
 
44
46
  @field_validator("database_url")
45
47
  @classmethod
@@ -17,12 +17,14 @@ from app.db.models import Account
17
17
  PLAN_CAPACITY_CREDITS_PRIMARY = {
18
18
  "plus": 225.0,
19
19
  "business": 225.0,
20
+ "team": 225.0,
20
21
  "pro": 1500.0,
21
22
  }
22
23
 
23
24
  PLAN_CAPACITY_CREDITS_SECONDARY = {
24
25
  "plus": 7560.0,
25
26
  "business": 7560.0,
27
+ "team": 7560.0,
26
28
  "pro": 50400.0,
27
29
  }
28
30
 
app/core/usage/logs.py CHANGED
@@ -13,6 +13,17 @@ class RequestLogLike(Protocol):
13
13
  reasoning_tokens: int | None
14
14
 
15
15
 
16
+ def cached_input_tokens_from_log(log: RequestLogLike) -> int | None:
17
+ cached_tokens = log.cached_input_tokens
18
+ if cached_tokens is None:
19
+ return None
20
+ cached_tokens = max(0, int(cached_tokens))
21
+ input_tokens = log.input_tokens
22
+ if input_tokens is not None:
23
+ cached_tokens = min(cached_tokens, int(input_tokens))
24
+ return cached_tokens
25
+
26
+
16
27
  def usage_tokens_from_log(log: RequestLogLike) -> UsageTokens | None:
17
28
  input_tokens = log.input_tokens
18
29
  if input_tokens is None:
@@ -20,8 +31,7 @@ def usage_tokens_from_log(log: RequestLogLike) -> UsageTokens | None:
20
31
  output_tokens = log.output_tokens if log.output_tokens is not None else log.reasoning_tokens
21
32
  if output_tokens is None:
22
33
  return None
23
- cached_tokens = log.cached_input_tokens or 0
24
- cached_tokens = max(0, min(cached_tokens, input_tokens))
34
+ cached_tokens = cached_input_tokens_from_log(log) or 0
25
35
  return UsageTokens(
26
36
  input_tokens=float(input_tokens),
27
37
  output_tokens=float(output_tokens),
app/core/usage/quota.py CHANGED
@@ -30,8 +30,11 @@ def apply_usage_quota(
30
30
  reset_at = secondary_reset
31
31
  return status, used_percent, reset_at
32
32
  if status == AccountStatus.QUOTA_EXCEEDED:
33
- status = AccountStatus.ACTIVE
34
- reset_at = None
33
+ if runtime_reset and runtime_reset > time.time():
34
+ reset_at = runtime_reset
35
+ else:
36
+ status = AccountStatus.ACTIVE
37
+ reset_at = None
35
38
  elif status == AccountStatus.QUOTA_EXCEEDED and secondary_reset is not None:
36
39
  reset_at = secondary_reset
37
40
 
@@ -45,8 +48,11 @@ def apply_usage_quota(
45
48
  reset_at = _fallback_primary_reset(primary_window_minutes) or reset_at
46
49
  return status, used_percent, reset_at
47
50
  if status == AccountStatus.RATE_LIMITED:
48
- status = AccountStatus.ACTIVE
49
- reset_at = None
51
+ if runtime_reset and runtime_reset > time.time():
52
+ reset_at = runtime_reset
53
+ else:
54
+ status = AccountStatus.ACTIVE
55
+ reset_at = None
50
56
 
51
57
  return status, used_percent, reset_at
52
58
 
app/core/usage/types.py CHANGED
@@ -67,8 +67,9 @@ class UsageCostSummary:
67
67
  class UsageMetricsSummary:
68
68
  requests_7d: int | None
69
69
  tokens_secondary_window: int | None
70
- error_rate_7d: float | None
71
- top_error: str | None
70
+ cached_tokens_secondary_window: int | None = None
71
+ error_rate_7d: float | None = None
72
+ top_error: str | None = None
72
73
 
73
74
 
74
75
  @dataclass(frozen=True)
@@ -8,7 +8,13 @@ from typing import Awaitable, Callable, Final
8
8
  from sqlalchemy import text
9
9
  from sqlalchemy.ext.asyncio import AsyncSession
10
10
 
11
- from app.db.migrations.versions import normalize_account_plan_types
11
+ from app.db.migrations.versions import (
12
+ add_accounts_chatgpt_account_id,
13
+ add_accounts_reset_at,
14
+ add_dashboard_settings,
15
+ add_request_logs_reasoning_effort,
16
+ normalize_account_plan_types,
17
+ )
12
18
 
13
19
  _CREATE_MIGRATIONS_TABLE = """
14
20
  CREATE TABLE IF NOT EXISTS schema_migrations (
@@ -21,6 +27,7 @@ _INSERT_MIGRATION = """
21
27
  INSERT INTO schema_migrations (name, applied_at)
22
28
  VALUES (:name, :applied_at)
23
29
  ON CONFLICT(name) DO NOTHING
30
+ RETURNING name
24
31
  """
25
32
 
26
33
 
@@ -32,6 +39,10 @@ class Migration:
32
39
 
33
40
  MIGRATIONS: Final[tuple[Migration, ...]] = (
34
41
  Migration("001_normalize_account_plan_types", normalize_account_plan_types.run),
42
+ Migration("002_add_request_logs_reasoning_effort", add_request_logs_reasoning_effort.run),
43
+ Migration("003_add_accounts_reset_at", add_accounts_reset_at.run),
44
+ Migration("004_add_accounts_chatgpt_account_id", add_accounts_chatgpt_account_id.run),
45
+ Migration("005_add_dashboard_settings", add_dashboard_settings.run),
35
46
  )
36
47
 
37
48
 
@@ -54,8 +65,8 @@ async def _apply_migration(session: AsyncSession, migration: Migration) -> bool:
54
65
  "applied_at": _utcnow_iso(),
55
66
  },
56
67
  )
57
- rowcount = getattr(result, "rowcount", 0) or 0
58
- if not rowcount:
68
+ inserted = result.scalar_one_or_none()
69
+ if inserted is None:
59
70
  return False
60
71
  await migration.run(session)
61
72
  return True
@@ -0,0 +1,29 @@
1
+ from __future__ import annotations
2
+
3
+ from sqlalchemy import text
4
+ from sqlalchemy.ext.asyncio import AsyncSession
5
+
6
+
7
+ async def run(session: AsyncSession) -> None:
8
+ bind = session.get_bind()
9
+ dialect = getattr(getattr(bind, "dialect", None), "name", None)
10
+ if dialect == "sqlite":
11
+ await _sqlite_add_column_if_missing(session, "accounts", "chatgpt_account_id", "VARCHAR")
12
+ elif dialect == "postgresql":
13
+ await session.execute(
14
+ text("ALTER TABLE accounts ADD COLUMN IF NOT EXISTS chatgpt_account_id VARCHAR"),
15
+ )
16
+
17
+
18
+ async def _sqlite_add_column_if_missing(
19
+ session: AsyncSession,
20
+ table: str,
21
+ column: str,
22
+ column_type: str,
23
+ ) -> None:
24
+ result = await session.execute(text(f"PRAGMA table_info({table})"))
25
+ rows = result.fetchall()
26
+ existing = {row[1] for row in rows if len(row) > 1}
27
+ if column in existing:
28
+ return
29
+ await session.execute(text(f"ALTER TABLE {table} ADD COLUMN {column} {column_type}"))
@@ -0,0 +1,29 @@
1
+ from __future__ import annotations
2
+
3
+ from sqlalchemy import text
4
+ from sqlalchemy.ext.asyncio import AsyncSession
5
+
6
+
7
+ async def run(session: AsyncSession) -> None:
8
+ bind = session.get_bind()
9
+ dialect = getattr(getattr(bind, "dialect", None), "name", None)
10
+ if dialect == "sqlite":
11
+ await _sqlite_add_column_if_missing(session, "accounts", "reset_at", "INTEGER")
12
+ elif dialect == "postgresql":
13
+ await session.execute(
14
+ text("ALTER TABLE accounts ADD COLUMN IF NOT EXISTS reset_at INTEGER"),
15
+ )
16
+
17
+
18
+ async def _sqlite_add_column_if_missing(
19
+ session: AsyncSession,
20
+ table: str,
21
+ column: str,
22
+ column_type: str,
23
+ ) -> None:
24
+ result = await session.execute(text(f"PRAGMA table_info({table})"))
25
+ rows = result.fetchall()
26
+ existing = {row[1] for row in rows if len(row) > 1}
27
+ if column in existing:
28
+ return
29
+ await session.execute(text(f"ALTER TABLE {table} ADD COLUMN {column} {column_type}"))
@@ -0,0 +1,31 @@
1
+ from __future__ import annotations
2
+
3
+ from sqlalchemy import inspect
4
+ from sqlalchemy.ext.asyncio import AsyncSession
5
+ from sqlalchemy.orm import Session
6
+
7
+ from app.db.models import DashboardSettings
8
+
9
+
10
+ def _settings_table_exists(session: Session) -> bool:
11
+ inspector = inspect(session.connection())
12
+ return inspector.has_table("dashboard_settings")
13
+
14
+
15
+ async def run(session: AsyncSession) -> None:
16
+ exists = await session.run_sync(_settings_table_exists)
17
+ if not exists:
18
+ return
19
+
20
+ row = await session.get(DashboardSettings, 1)
21
+ if row is not None:
22
+ return
23
+
24
+ session.add(
25
+ DashboardSettings(
26
+ id=1,
27
+ sticky_threads_enabled=False,
28
+ prefer_earlier_reset_accounts=False,
29
+ )
30
+ )
31
+ await session.flush()
@@ -0,0 +1,21 @@
1
+ from __future__ import annotations
2
+
3
+ from sqlalchemy import inspect, text
4
+ from sqlalchemy.ext.asyncio import AsyncSession
5
+ from sqlalchemy.orm import Session
6
+
7
+
8
+ def _request_logs_column_state(session: Session) -> tuple[bool, bool]:
9
+ conn = session.connection()
10
+ inspector = inspect(conn)
11
+ if not inspector.has_table("request_logs"):
12
+ return False, False
13
+ columns = {column["name"] for column in inspector.get_columns("request_logs")}
14
+ return True, "reasoning_effort" in columns
15
+
16
+
17
+ async def run(session: AsyncSession) -> None:
18
+ has_table, has_column = await session.run_sync(_request_logs_column_state)
19
+ if not has_table or has_column:
20
+ return
21
+ await session.execute(text("ALTER TABLE request_logs ADD COLUMN reasoning_effort VARCHAR"))
app/db/models.py CHANGED
@@ -24,6 +24,7 @@ class Account(Base):
24
24
  __tablename__ = "accounts"
25
25
 
26
26
  id: Mapped[str] = mapped_column(String, primary_key=True)
27
+ chatgpt_account_id: Mapped[str | None] = mapped_column(String, nullable=True)
27
28
  email: Mapped[str] = mapped_column(String, unique=True, nullable=False)
28
29
  plan_type: Mapped[str] = mapped_column(String, nullable=False)
29
30
 
@@ -40,6 +41,7 @@ class Account(Base):
40
41
  nullable=False,
41
42
  )
42
43
  deactivation_reason: Mapped[str | None] = mapped_column(Text, nullable=True)
44
+ reset_at: Mapped[int | None] = mapped_column(Integer, nullable=True)
43
45
 
44
46
 
45
47
  class UsageHistory(Base):
@@ -71,12 +73,43 @@ class RequestLog(Base):
71
73
  output_tokens: Mapped[int | None] = mapped_column(Integer, nullable=True)
72
74
  cached_input_tokens: Mapped[int | None] = mapped_column(Integer, nullable=True)
73
75
  reasoning_tokens: Mapped[int | None] = mapped_column(Integer, nullable=True)
76
+ reasoning_effort: Mapped[str | None] = mapped_column(String, nullable=True)
74
77
  latency_ms: Mapped[int | None] = mapped_column(Integer, nullable=True)
75
78
  status: Mapped[str] = mapped_column(String, nullable=False)
76
79
  error_code: Mapped[str | None] = mapped_column(String, nullable=True)
77
80
  error_message: Mapped[str | None] = mapped_column(Text, nullable=True)
78
81
 
79
82
 
83
+ class StickySession(Base):
84
+ __tablename__ = "sticky_sessions"
85
+
86
+ key: Mapped[str] = mapped_column(String, primary_key=True)
87
+ account_id: Mapped[str] = mapped_column(String, ForeignKey("accounts.id"), nullable=False)
88
+ created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now(), nullable=False)
89
+ updated_at: Mapped[datetime] = mapped_column(
90
+ DateTime,
91
+ server_default=func.now(),
92
+ onupdate=func.now(),
93
+ nullable=False,
94
+ )
95
+
96
+
97
+ class DashboardSettings(Base):
98
+ __tablename__ = "dashboard_settings"
99
+
100
+ id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=False)
101
+ sticky_threads_enabled: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
102
+ prefer_earlier_reset_accounts: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
103
+ created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now(), nullable=False)
104
+ updated_at: Mapped[datetime] = mapped_column(
105
+ DateTime,
106
+ server_default=func.now(),
107
+ onupdate=func.now(),
108
+ nullable=False,
109
+ )
110
+
111
+
80
112
  Index("idx_usage_recorded_at", UsageHistory.recorded_at)
81
113
  Index("idx_usage_account_time", UsageHistory.account_id, UsageHistory.recorded_at)
82
114
  Index("idx_logs_account_time", RequestLog.account_id, RequestLog.requested_at)
115
+ Index("idx_sticky_account", StickySession.account_id)
app/db/session.py CHANGED
@@ -1,10 +1,13 @@
1
1
  from __future__ import annotations
2
2
 
3
- import asyncio
4
3
  import logging
4
+ import sqlite3
5
5
  from pathlib import Path
6
- from typing import AsyncIterator
6
+ from typing import AsyncIterator, Awaitable, TypeVar
7
7
 
8
+ import anyio
9
+ from sqlalchemy import event
10
+ from sqlalchemy.engine import Engine
8
11
  from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
9
12
 
10
13
  from app.core.config.settings import get_settings
@@ -14,33 +17,86 @@ DATABASE_URL = get_settings().database_url
14
17
 
15
18
  logger = logging.getLogger(__name__)
16
19
 
17
- engine = create_async_engine(DATABASE_URL, echo=False)
20
+ _SQLITE_BUSY_TIMEOUT_MS = 5_000
21
+ _SQLITE_BUSY_TIMEOUT_SECONDS = _SQLITE_BUSY_TIMEOUT_MS / 1000
22
+
23
+
24
+ def _is_sqlite_url(url: str) -> bool:
25
+ return url.startswith("sqlite+aiosqlite:///") or url.startswith("sqlite:///")
26
+
27
+
28
+ def _is_sqlite_memory_url(url: str) -> bool:
29
+ return _is_sqlite_url(url) and ":memory:" in url
30
+
31
+
32
+ def _configure_sqlite_engine(engine: Engine, *, enable_wal: bool) -> None:
33
+ @event.listens_for(engine, "connect")
34
+ def _set_sqlite_pragmas(dbapi_connection: sqlite3.Connection, _: object) -> None:
35
+ cursor: sqlite3.Cursor = dbapi_connection.cursor()
36
+ try:
37
+ if enable_wal:
38
+ cursor.execute("PRAGMA journal_mode=WAL")
39
+ cursor.execute("PRAGMA synchronous=NORMAL")
40
+ cursor.execute("PRAGMA foreign_keys=ON")
41
+ cursor.execute(f"PRAGMA busy_timeout={_SQLITE_BUSY_TIMEOUT_MS}")
42
+ finally:
43
+ cursor.close()
44
+
45
+
46
+ if _is_sqlite_url(DATABASE_URL):
47
+ engine = create_async_engine(
48
+ DATABASE_URL,
49
+ echo=False,
50
+ connect_args={"timeout": _SQLITE_BUSY_TIMEOUT_SECONDS},
51
+ )
52
+ _configure_sqlite_engine(engine.sync_engine, enable_wal=not _is_sqlite_memory_url(DATABASE_URL))
53
+ else:
54
+ engine = create_async_engine(DATABASE_URL, echo=False)
55
+
18
56
  SessionLocal = async_sessionmaker(engine, expire_on_commit=False, class_=AsyncSession)
19
57
 
58
+ _T = TypeVar("_T")
59
+
20
60
 
21
61
  def _ensure_sqlite_dir(url: str) -> None:
22
- prefix = "sqlite+aiosqlite:///"
23
- if not url.startswith(prefix):
62
+ if not (url.startswith("sqlite+aiosqlite:") or url.startswith("sqlite:")):
63
+ return
64
+
65
+ marker = ":///"
66
+ marker_index = url.find(marker)
67
+ if marker_index < 0:
24
68
  return
25
- path = url[len(prefix) :]
26
- if path == ":memory:":
69
+
70
+ # Works for both relative (sqlite+aiosqlite:///./db.sqlite) and absolute
71
+ # paths (sqlite+aiosqlite:////var/lib/app/db.sqlite).
72
+ path = url[marker_index + len(marker) :]
73
+ path = path.partition("?")[0]
74
+ path = path.partition("#")[0]
75
+
76
+ if not path or path == ":memory:":
27
77
  return
78
+
28
79
  Path(path).expanduser().parent.mkdir(parents=True, exist_ok=True)
29
80
 
30
81
 
82
+ async def _shielded(awaitable: Awaitable[_T]) -> _T:
83
+ with anyio.CancelScope(shield=True):
84
+ return await awaitable
85
+
86
+
31
87
  async def _safe_rollback(session: AsyncSession) -> None:
32
88
  if not session.in_transaction():
33
89
  return
34
90
  try:
35
- await asyncio.shield(session.rollback())
36
- except Exception:
91
+ await _shielded(session.rollback())
92
+ except BaseException:
37
93
  return
38
94
 
39
95
 
40
96
  async def _safe_close(session: AsyncSession) -> None:
41
97
  try:
42
- await asyncio.shield(session.close())
43
- except Exception:
98
+ await _shielded(session.close())
99
+ except BaseException:
44
100
  return
45
101
 
46
102
 
@@ -74,3 +130,7 @@ async def init_db() -> None:
74
130
  logger.exception("Failed to apply database migrations")
75
131
  if get_settings().database_migrations_fail_fast:
76
132
  raise
133
+
134
+
135
+ async def close_db() -> None:
136
+ await engine.dispose()
app/dependencies.py CHANGED
@@ -12,8 +12,11 @@ from app.modules.accounts.repository import AccountsRepository
12
12
  from app.modules.accounts.service import AccountsService
13
13
  from app.modules.oauth.service import OauthService
14
14
  from app.modules.proxy.service import ProxyService
15
+ from app.modules.proxy.sticky_repository import StickySessionsRepository
15
16
  from app.modules.request_logs.repository import RequestLogsRepository
16
17
  from app.modules.request_logs.service import RequestLogsService
18
+ from app.modules.settings.repository import SettingsRepository
19
+ from app.modules.settings.service import SettingsService
17
20
  from app.modules.usage.repository import UsageRepository
18
21
  from app.modules.usage.service import UsageService
19
22
 
@@ -49,6 +52,13 @@ class RequestLogsContext:
49
52
  service: RequestLogsService
50
53
 
51
54
 
55
+ @dataclass(slots=True)
56
+ class SettingsContext:
57
+ session: AsyncSession
58
+ repository: SettingsRepository
59
+ service: SettingsService
60
+
61
+
52
62
  def get_accounts_context(
53
63
  session: AsyncSession = Depends(get_session),
54
64
  ) -> AccountsContext:
@@ -104,7 +114,15 @@ def get_proxy_context(
104
114
  accounts_repository = AccountsRepository(session)
105
115
  usage_repository = UsageRepository(session)
106
116
  request_logs_repository = RequestLogsRepository(session)
107
- service = ProxyService(accounts_repository, usage_repository, request_logs_repository)
117
+ sticky_repository = StickySessionsRepository(session)
118
+ settings_repository = SettingsRepository(session)
119
+ service = ProxyService(
120
+ accounts_repository,
121
+ usage_repository,
122
+ request_logs_repository,
123
+ sticky_repository,
124
+ settings_repository,
125
+ )
108
126
  return ProxyContext(service=service)
109
127
 
110
128
 
@@ -114,3 +132,11 @@ def get_request_logs_context(
114
132
  repository = RequestLogsRepository(session)
115
133
  service = RequestLogsService(repository)
116
134
  return RequestLogsContext(session=session, repository=repository, service=service)
135
+
136
+
137
+ def get_settings_context(
138
+ session: AsyncSession = Depends(get_session),
139
+ ) -> SettingsContext:
140
+ repository = SettingsRepository(session)
141
+ service = SettingsService(repository)
142
+ return SettingsContext(session=session, repository=repository, service=service)
app/main.py CHANGED
@@ -18,12 +18,13 @@ from starlette.exceptions import HTTPException as StarletteHTTPException
18
18
  from app.core.clients.http import close_http_client, init_http_client
19
19
  from app.core.errors import dashboard_error
20
20
  from app.core.utils.request_id import get_request_id, reset_request_id, set_request_id
21
- from app.db.session import init_db
21
+ from app.db.session import close_db, init_db
22
22
  from app.modules.accounts import api as accounts_api
23
23
  from app.modules.health import api as health_api
24
24
  from app.modules.oauth import api as oauth_api
25
25
  from app.modules.proxy import api as proxy_api
26
26
  from app.modules.request_logs import api as request_logs_api
27
+ from app.modules.settings import api as settings_api
27
28
  from app.modules.usage import api as usage_api
28
29
 
29
30
  logger = logging.getLogger(__name__)
@@ -37,7 +38,10 @@ async def lifespan(_: FastAPI):
37
38
  try:
38
39
  yield
39
40
  finally:
40
- await close_http_client()
41
+ try:
42
+ await close_http_client()
43
+ finally:
44
+ await close_db()
41
45
 
42
46
 
43
47
  def create_app() -> FastAPI:
@@ -103,6 +107,7 @@ def create_app() -> FastAPI:
103
107
  app.include_router(usage_api.router)
104
108
  app.include_router(request_logs_api.router)
105
109
  app.include_router(oauth_api.router)
110
+ app.include_router(settings_api.router)
106
111
  app.include_router(health_api.router)
107
112
 
108
113
  static_dir = Path(__file__).parent / "static"
@@ -116,6 +121,10 @@ def create_app() -> FastAPI:
116
121
  async def spa_accounts():
117
122
  return FileResponse(index_html, media_type="text/html")
118
123
 
124
+ @app.get("/settings", include_in_schema=False)
125
+ async def spa_settings():
126
+ return FileResponse(index_html, media_type="text/html")
127
+
119
128
  app.mount("/dashboard", StaticFiles(directory=static_dir, html=True), name="dashboard")
120
129
 
121
130
  return app