codex-lb 0.4.0__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,9 +1,8 @@
  from __future__ import annotations

- from typing import cast
-
  from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator

+ from app.core.openai.message_coercion import coerce_messages
  from app.core.openai.requests import (
      ResponsesCompactRequest,
      ResponsesReasoning,
@@ -21,7 +20,7 @@ class V1ResponsesRequest(BaseModel):
      input: list[JsonValue] | None = None
      instructions: str | None = None
      tools: list[JsonValue] = Field(default_factory=list)
-     tool_choice: str | None = None
+     tool_choice: str | dict[str, JsonValue] | None = None
      parallel_tool_calls: bool | None = None
      reasoning: ResponsesReasoning | None = None
      store: bool | None = None
@@ -54,7 +53,7 @@ class V1ResponsesRequest(BaseModel):
          input_items: list[JsonValue] = input_value if isinstance(input_value, list) else []

          if messages is not None:
-             instruction_text, input_items = _coerce_messages(instruction_text, messages)
+             instruction_text, input_items = coerce_messages(instruction_text, messages)

          data["instructions"] = instruction_text
          data["input"] = input_items
@@ -86,63 +85,8 @@ class V1ResponsesCompactRequest(BaseModel):
          input_items: list[JsonValue] = input_value if isinstance(input_value, list) else []

          if messages is not None:
-             instruction_text, input_items = _coerce_messages(instruction_text, messages)
+             instruction_text, input_items = coerce_messages(instruction_text, messages)

          data["instructions"] = instruction_text
          data["input"] = input_items
          return ResponsesCompactRequest.model_validate(data)
-
-
- def _coerce_messages(existing_instructions: str, messages: list[JsonValue]) -> tuple[str, list[JsonValue]]:
-     instruction_parts: list[str] = []
-     input_messages: list[JsonValue] = []
-     for message in messages:
-         if not isinstance(message, dict):
-             raise ValueError("Each message must be an object.")
-         message_dict = cast(dict[str, JsonValue], message)
-         role_value = message_dict.get("role")
-         role = role_value if isinstance(role_value, str) else None
-         if role in ("system", "developer"):
-             content_text = _content_to_text(message_dict.get("content"))
-             if content_text:
-                 instruction_parts.append(content_text)
-             continue
-         input_messages.append(cast(JsonValue, message_dict))
-     merged = _merge_instructions(existing_instructions, instruction_parts)
-     return merged, input_messages
-
-
- def _merge_instructions(existing: str, extra_parts: list[str]) -> str:
-     if not extra_parts:
-         return existing
-     extra = "\n".join([part for part in extra_parts if part])
-     if not extra:
-         return existing
-     if existing:
-         return f"{existing}\n{extra}"
-     return extra
-
-
- def _content_to_text(content: object) -> str | None:
-     if content is None:
-         return None
-     if isinstance(content, str):
-         return content
-     if isinstance(content, list):
-         parts: list[str] = []
-         for part in content:
-             if isinstance(part, str):
-                 parts.append(part)
-             elif isinstance(part, dict):
-                 part_dict = cast(dict[str, JsonValue], part)
-                 text = part_dict.get("text")
-                 if isinstance(text, str):
-                     parts.append(text)
-         return "\n".join([part for part in parts if part])
-     if isinstance(content, dict):
-         content_dict = cast(dict[str, JsonValue], content)
-         text = content_dict.get("text")
-         if isinstance(text, str):
-             return text
-         return None
-     return None
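The helpers deleted above are relocated rather than removed: the new import points at `app/core/openai/message_coercion.py`. A minimal sketch of that module, assuming it simply re-homes the deleted logic in condensed form (sourcing `JsonValue` from pydantic is also an assumption):

```python
# Hypothetical app/core/openai/message_coercion.py, reconstructed from the
# helpers removed in this diff; the shipped module may differ in detail.
from __future__ import annotations

from pydantic import JsonValue  # assumption: JsonValue comes from pydantic v2


def coerce_messages(existing_instructions: str, messages: list[JsonValue]) -> tuple[str, list[JsonValue]]:
    """Split chat-style messages into instruction text and input items."""
    instruction_parts: list[str] = []
    input_messages: list[JsonValue] = []
    for message in messages:
        if not isinstance(message, dict):
            raise ValueError("Each message must be an object.")
        role = message.get("role")
        if isinstance(role, str) and role in ("system", "developer"):
            text = _content_to_text(message.get("content"))
            if text:
                instruction_parts.append(text)
            continue
        input_messages.append(message)
    extra = "\n".join(part for part in instruction_parts if part)
    if extra and existing_instructions:
        return f"{existing_instructions}\n{extra}", input_messages
    return (extra or existing_instructions), input_messages


def _content_to_text(content: object) -> str | None:
    """Flatten str / list / dict message content into plain text."""
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        parts = [p if isinstance(p, str) else p.get("text") for p in content if isinstance(p, (str, dict))]
        return "\n".join(p for p in parts if isinstance(p, str) and p)
    if isinstance(content, dict):
        text = content.get("text")
        return text if isinstance(text, str) else None
    return None
```

The widened `tool_choice` field accepts the dict form of tool selection in addition to the plain string shortcuts.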
app/db/session.py CHANGED
@@ -13,7 +13,7 @@ from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_asyn
  from app.core.config.settings import get_settings
  from app.db.migrations import run_migrations

- DATABASE_URL = get_settings().database_url
+ _settings = get_settings()

  logger = logging.getLogger(__name__)

@@ -43,15 +43,32 @@ def _configure_sqlite_engine(engine: Engine, *, enable_wal: bool) -> None:
          cursor.close()


- if _is_sqlite_url(DATABASE_URL):
+ if _is_sqlite_url(_settings.database_url):
+     is_sqlite_memory = _is_sqlite_memory_url(_settings.database_url)
+     if is_sqlite_memory:
+         engine = create_async_engine(
+             _settings.database_url,
+             echo=False,
+             connect_args={"timeout": _SQLITE_BUSY_TIMEOUT_SECONDS},
+         )
+     else:
+         engine = create_async_engine(
+             _settings.database_url,
+             echo=False,
+             pool_size=_settings.database_pool_size,
+             max_overflow=_settings.database_max_overflow,
+             pool_timeout=_settings.database_pool_timeout_seconds,
+             connect_args={"timeout": _SQLITE_BUSY_TIMEOUT_SECONDS},
+         )
+     _configure_sqlite_engine(engine.sync_engine, enable_wal=not is_sqlite_memory)
+ else:
      engine = create_async_engine(
-         DATABASE_URL,
+         _settings.database_url,
          echo=False,
-         connect_args={"timeout": _SQLITE_BUSY_TIMEOUT_SECONDS},
+         pool_size=_settings.database_pool_size,
+         max_overflow=_settings.database_max_overflow,
+         pool_timeout=_settings.database_pool_timeout_seconds,
      )
-     _configure_sqlite_engine(engine.sync_engine, enable_wal=not _is_sqlite_memory_url(DATABASE_URL))
- else:
-     engine = create_async_engine(DATABASE_URL, echo=False)

  SessionLocal = async_sessionmaker(engine, expire_on_commit=False, class_=AsyncSession)

@@ -116,7 +133,7 @@ async def get_session() -> AsyncIterator[AsyncSession]:
  async def init_db() -> None:
      from app.db.models import Base

-     _ensure_sqlite_dir(DATABASE_URL)
+     _ensure_sqlite_dir(_settings.database_url)

      async with engine.begin() as conn:
          await conn.run_sync(Base.metadata.create_all)
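The module now sizes the connection pool from settings instead of binding a single `DATABASE_URL` constant at import time; the in-memory SQLite branch skips the pool knobs and WAL, presumably because those only make sense for a file-backed database. A hedged sketch of the settings fields this relies on — only the attribute names appear in the diff; the defaults and the pydantic-settings base class are assumptions:

```python
# Hypothetical excerpt of app/core/config/settings.py. Field names mirror the
# attributes read in app/db/session.py; defaults and env handling are assumed.
from pydantic_settings import BaseSettings


class Settings(BaseSettings):
    database_url: str = "sqlite+aiosqlite:///./data/codex_lb.db"  # assumed default
    database_pool_size: int = 5                  # forwarded as create_async_engine(pool_size=...)
    database_max_overflow: int = 10              # forwarded as max_overflow=...
    database_pool_timeout_seconds: float = 30.0  # forwarded as pool_timeout=...
```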
app/dependencies.py CHANGED
@@ -11,6 +11,7 @@ from app.db.session import SessionLocal, _safe_close, _safe_rollback, get_sessio
  from app.modules.accounts.repository import AccountsRepository
  from app.modules.accounts.service import AccountsService
  from app.modules.oauth.service import OauthService
+ from app.modules.proxy.repo_bundle import ProxyRepositories
  from app.modules.proxy.service import ProxyService
  from app.modules.proxy.sticky_repository import StickySessionsRepository
  from app.modules.request_logs.repository import RequestLogsRepository
@@ -79,7 +80,12 @@ def get_usage_context(
      usage_repository = UsageRepository(session)
      request_logs_repository = RequestLogsRepository(session)
      accounts_repository = AccountsRepository(session)
-     service = UsageService(usage_repository, request_logs_repository, accounts_repository)
+     service = UsageService(
+         usage_repository,
+         request_logs_repository,
+         accounts_repository,
+         refresh_repo_factory=_usage_refresh_context,
+     )
      return UsageContext(
          session=session,
          usage_repository=usage_repository,
@@ -101,6 +107,40 @@ async def _accounts_repo_context() -> AsyncIterator[AccountsRepository]:
          await _safe_close(session)


+ @asynccontextmanager
+ async def _usage_refresh_context() -> AsyncIterator[tuple[UsageRepository, AccountsRepository]]:
+     session = SessionLocal()
+     try:
+         yield UsageRepository(session), AccountsRepository(session)
+     except BaseException:
+         await _safe_rollback(session)
+         raise
+     finally:
+         if session.in_transaction():
+             await _safe_rollback(session)
+         await _safe_close(session)
+
+
+ @asynccontextmanager
+ async def _proxy_repo_context() -> AsyncIterator[ProxyRepositories]:
+     session = SessionLocal()
+     try:
+         yield ProxyRepositories(
+             accounts=AccountsRepository(session),
+             usage=UsageRepository(session),
+             request_logs=RequestLogsRepository(session),
+             sticky_sessions=StickySessionsRepository(session),
+             settings=SettingsRepository(session),
+         )
+     except BaseException:
+         await _safe_rollback(session)
+         raise
+     finally:
+         if session.in_transaction():
+             await _safe_rollback(session)
+         await _safe_close(session)
+
+
  def get_oauth_context(
      session: AsyncSession = Depends(get_session),
  ) -> OauthContext:
@@ -108,21 +148,8 @@ def get_oauth_context(
      return OauthContext(service=OauthService(accounts_repository, repo_factory=_accounts_repo_context))


- def get_proxy_context(
-     session: AsyncSession = Depends(get_session),
- ) -> ProxyContext:
-     accounts_repository = AccountsRepository(session)
-     usage_repository = UsageRepository(session)
-     request_logs_repository = RequestLogsRepository(session)
-     sticky_repository = StickySessionsRepository(session)
-     settings_repository = SettingsRepository(session)
-     service = ProxyService(
-         accounts_repository,
-         usage_repository,
-         request_logs_repository,
-         sticky_repository,
-         settings_repository,
-     )
+ def get_proxy_context() -> ProxyContext:
+     service = ProxyService(repo_factory=_proxy_repo_context)
      return ProxyContext(service=service)

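`_proxy_repo_context` yields a `ProxyRepositories` bundle whose shape is only implied by this diff. A sketch of what `app/modules/proxy/repo_bundle.py` presumably contains — the field names and the `ProxyRepoFactory` alias come from the diff, while the dataclass form and the `SettingsRepository` import path are assumptions:

```python
# Hypothetical app/modules/proxy/repo_bundle.py, inferred from its call sites
# in app/dependencies.py and the load balancer changes below.
from __future__ import annotations

from contextlib import AbstractAsyncContextManager
from dataclasses import dataclass
from typing import Callable

from app.modules.accounts.repository import AccountsRepository
from app.modules.proxy.sticky_repository import StickySessionsRepository
from app.modules.request_logs.repository import RequestLogsRepository
from app.modules.settings.repository import SettingsRepository  # assumed path
from app.modules.usage.repository import UsageRepository


@dataclass(frozen=True)
class ProxyRepositories:
    accounts: AccountsRepository
    usage: UsageRepository
    request_logs: RequestLogsRepository
    sticky_sessions: StickySessionsRepository
    settings: SettingsRepository


# A zero-argument callable returning an async context manager that yields a
# bundle bound to a fresh, short-lived session (see _proxy_repo_context above).
ProxyRepoFactory = Callable[[], AbstractAsyncContextManager[ProxyRepositories]]
```

Handing the factory to `ProxyService` (and, further down, to `LoadBalancer`) lets each operation open and close its own session instead of sharing one request-scoped session across long-lived streaming work.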
app/main.py CHANGED
@@ -1,23 +1,19 @@
  from __future__ import annotations

- import logging
  from contextlib import asynccontextmanager
  from pathlib import Path
- from uuid import uuid4

- from fastapi import FastAPI, Request
- from fastapi.exception_handlers import (
-     http_exception_handler,
-     request_validation_exception_handler,
- )
- from fastapi.exceptions import RequestValidationError
- from fastapi.responses import FileResponse, JSONResponse, RedirectResponse, Response
+ from fastapi import FastAPI
+ from fastapi.responses import FileResponse, RedirectResponse
  from fastapi.staticfiles import StaticFiles
- from starlette.exceptions import HTTPException as StarletteHTTPException

  from app.core.clients.http import close_http_client, init_http_client
- from app.core.errors import dashboard_error
- from app.core.utils.request_id import get_request_id, reset_request_id, set_request_id
+ from app.core.handlers import add_exception_handlers
+ from app.core.middleware import (
+     add_api_unhandled_error_middleware,
+     add_request_decompression_middleware,
+     add_request_id_middleware,
+ )
  from app.db.session import close_db, init_db
  from app.modules.accounts import api as accounts_api
  from app.modules.health import api as health_api
@@ -27,8 +23,6 @@ from app.modules.request_logs import api as request_logs_api
  from app.modules.settings import api as settings_api
  from app.modules.usage import api as usage_api

- logger = logging.getLogger(__name__)
-

  @asynccontextmanager
  async def lifespan(_: FastAPI):
@@ -47,59 +41,10 @@ async def lifespan(_: FastAPI):
  def create_app() -> FastAPI:
      app = FastAPI(title="codex-lb", version="0.1.0", lifespan=lifespan)

-     @app.middleware("http")
-     async def request_id_middleware(request: Request, call_next) -> JSONResponse:
-         inbound_request_id = request.headers.get("x-request-id") or request.headers.get("request-id")
-         request_id = inbound_request_id or str(uuid4())
-         token = set_request_id(request_id)
-         try:
-             response = await call_next(request)
-         except Exception:
-             reset_request_id(token)
-             raise
-         response.headers.setdefault("x-request-id", request_id)
-         return response
-
-     @app.middleware("http")
-     async def api_unhandled_error_middleware(request: Request, call_next) -> Response:
-         try:
-             return await call_next(request)
-         except Exception:
-             if request.url.path.startswith("/api/"):
-                 logger.exception(
-                     "Unhandled API error request_id=%s",
-                     get_request_id(),
-                 )
-                 return JSONResponse(
-                     status_code=500,
-                     content=dashboard_error("internal_error", "Unexpected error"),
-                 )
-             raise
-
-     @app.exception_handler(RequestValidationError)
-     async def _validation_error_handler(
-         request: Request,
-         exc: RequestValidationError,
-     ) -> Response:
-         if request.url.path.startswith("/api/"):
-             return JSONResponse(
-                 status_code=422,
-                 content=dashboard_error("validation_error", "Invalid request payload"),
-             )
-         return await request_validation_exception_handler(request, exc)
-
-     @app.exception_handler(StarletteHTTPException)
-     async def _http_error_handler(
-         request: Request,
-         exc: StarletteHTTPException,
-     ) -> Response:
-         if request.url.path.startswith("/api/"):
-             detail = exc.detail if isinstance(exc.detail, str) else "Request failed"
-             return JSONResponse(
-                 status_code=exc.status_code,
-                 content=dashboard_error(f"http_{exc.status_code}", detail),
-             )
-         return await http_exception_handler(request, exc)
+     add_request_decompression_middleware(app)
+     add_request_id_middleware(app)
+     add_api_unhandled_error_middleware(app)
+     add_exception_handlers(app)

      app.include_router(proxy_api.router)
      app.include_router(proxy_api.v1_router)
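The inline middleware and exception handlers move into `app.core.middleware` and `app.core.handlers`; only the registration calls remain in `create_app`. One plausible shape for the relocated request-id middleware, reusing the handler deleted above (the new `add_request_decompression_middleware` has no counterpart in the old code, so it is not sketched here):

```python
# Hypothetical excerpt of app/core/middleware.py; the body mirrors the
# request_id_middleware removed from app/main.py above.
from uuid import uuid4

from fastapi import FastAPI, Request
from fastapi.responses import Response

from app.core.utils.request_id import reset_request_id, set_request_id


def add_request_id_middleware(app: FastAPI) -> None:
    @app.middleware("http")
    async def request_id_middleware(request: Request, call_next) -> Response:
        inbound_request_id = request.headers.get("x-request-id") or request.headers.get("request-id")
        request_id = inbound_request_id or str(uuid4())
        token = set_request_id(request_id)
        try:
            response = await call_next(request)
        except Exception:
            reset_request_id(token)
            raise
        response.headers.setdefault("x-request-id", request_id)
        return response
```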
app/modules/accounts/repository.py CHANGED
@@ -19,19 +19,19 @@ class AccountsRepository:
      async def upsert(self, account: Account) -> Account:
          existing = await self._session.get(Account, account.id)
          if existing:
-             existing.chatgpt_account_id = account.chatgpt_account_id
-             existing.email = account.email
-             existing.plan_type = account.plan_type
-             existing.access_token_encrypted = account.access_token_encrypted
-             existing.refresh_token_encrypted = account.refresh_token_encrypted
-             existing.id_token_encrypted = account.id_token_encrypted
-             existing.last_refresh = account.last_refresh
-             existing.status = account.status
-             existing.deactivation_reason = account.deactivation_reason
+             _apply_account_updates(existing, account)
              await self._session.commit()
              await self._session.refresh(existing)
              return existing

+         result = await self._session.execute(select(Account).where(Account.email == account.email))
+         existing_by_email = result.scalar_one_or_none()
+         if existing_by_email:
+             _apply_account_updates(existing_by_email, account)
+             await self._session.commit()
+             await self._session.refresh(existing_by_email)
+             return existing_by_email
+
          self._session.add(account)
          await self._session.commit()
          await self._session.refresh(account)
@@ -89,3 +89,15 @@
          )
          await self._session.commit()
          return result.scalar_one_or_none() is not None
+
+
+ def _apply_account_updates(target: Account, source: Account) -> None:
+     target.chatgpt_account_id = source.chatgpt_account_id
+     target.email = source.email
+     target.plan_type = source.plan_type
+     target.access_token_encrypted = source.access_token_encrypted
+     target.refresh_token_encrypted = source.refresh_token_encrypted
+     target.id_token_encrypted = source.id_token_encrypted
+     target.last_refresh = source.last_refresh
+     target.status = source.status
+     target.deactivation_reason = source.deactivation_reason
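`upsert` now falls back to matching on email before inserting, so re-importing an account whose upstream id changed updates the existing row instead of creating a duplicate. Illustrative only; the `Account` constructor arguments shown are assumed:

```python
# Hedged usage sketch: the second upsert hits the email fallback and refreshes
# the original row in place rather than adding a second Account.
async def demo_upsert_dedup(repo: AccountsRepository) -> None:
    first = await repo.upsert(Account(id="acct-1", email="user@example.com"))
    second = await repo.upsert(Account(id="acct-2", email="user@example.com"))
    assert second.id == first.id  # matched by email; remaining fields were copied over
```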
app/modules/proxy/api.py CHANGED
@@ -1,5 +1,6 @@
  from __future__ import annotations

+ import time
  from collections.abc import AsyncIterator

  from fastapi import APIRouter, Body, Depends, Request, Response
@@ -7,6 +8,9 @@ from fastapi.responses import JSONResponse, StreamingResponse
  from app.core.clients.proxy import ProxyResponseError
  from app.core.errors import openai_error
  from app.core.errors import openai_error
+ from app.core.openai.chat_requests import ChatCompletionsRequest
+ from app.core.openai.chat_responses import collect_chat_completion, stream_chat_chunks
+ from app.core.openai.models_catalog import MODEL_CATALOG
  from app.core.openai.requests import ResponsesCompactRequest, ResponsesRequest
  from app.core.openai.v1_requests import V1ResponsesCompactRequest, V1ResponsesRequest
  from app.dependencies import ProxyContext, get_proxy_context
@@ -35,6 +39,60 @@ async def v1_responses(
      return await _stream_responses(request, payload.to_responses_request(), context)


+ @v1_router.get("/models")
+ async def v1_models() -> JSONResponse:
+     created = int(time.time())
+     items = [
+         {
+             "id": model_id,
+             "object": "model",
+             "created": created,
+             "owned_by": "codex-lb",
+             "metadata": entry.model_dump(mode="json"),
+         }
+         for model_id, entry in MODEL_CATALOG.items()
+     ]
+     return JSONResponse({"object": "list", "data": items})
+
+
+ @v1_router.post("/chat/completions")
+ async def v1_chat_completions(
+     request: Request,
+     payload: ChatCompletionsRequest = Body(...),
+     context: ProxyContext = Depends(get_proxy_context),
+ ) -> Response:
+     rate_limit_headers = await context.service.rate_limit_headers()
+     responses_payload = payload.to_responses_request()
+     responses_payload.stream = True
+     stream = context.service.stream_responses(
+         responses_payload,
+         request.headers,
+         propagate_http_errors=True,
+     )
+     try:
+         first = await stream.__anext__()
+     except StopAsyncIteration:
+         first = None
+     except ProxyResponseError as exc:
+         return JSONResponse(status_code=exc.status_code, content=exc.payload, headers=rate_limit_headers)
+
+     stream_with_first = _prepend_first(first, stream)
+     if payload.stream:
+         return StreamingResponse(
+             stream_chat_chunks(stream_with_first, model=payload.model),
+             media_type="text/event-stream",
+             headers={"Cache-Control": "no-cache", **rate_limit_headers},
+         )
+
+     result = await collect_chat_completion(stream_with_first, model=payload.model)
+     status_code = 200
+     if isinstance(result, dict) and "error" in result:
+         error = result.get("error")
+         code = error.get("code") if isinstance(error, dict) else None
+         status_code = 503 if code == "no_accounts" else 502
+     return JSONResponse(content=result, status_code=status_code, headers=rate_limit_headers)
+
+
  async def _stream_responses(
      request: Request,
      payload: ResponsesRequest,
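`/v1/chat/completions` always streams from the upstream Responses API, pulls the first event eagerly so connection-level failures can be returned as plain JSON errors, and then re-attaches that event before bridging to chat-completion chunks. `_prepend_first` itself is not shown in this diff; a plausible implementation is a small async generator (element type assumed generic):

```python
# Hypothetical helper matching the call `_prepend_first(first, stream)` above.
from collections.abc import AsyncIterator
from typing import TypeVar

T = TypeVar("T")


async def _prepend_first(first: T | None, rest: AsyncIterator[T]) -> AsyncIterator[T]:
    """Yield the eagerly fetched first event (if any), then the rest of the stream."""
    if first is not None:
        yield first
    async for item in rest:
        yield item
```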
@@ -16,8 +16,8 @@ from app.core.balancer.types import UpstreamError
  from app.core.usage.quota import apply_usage_quota
  from app.db.models import Account, UsageHistory
  from app.modules.accounts.repository import AccountsRepository
+ from app.modules.proxy.repo_bundle import ProxyRepoFactory
  from app.modules.proxy.sticky_repository import StickySessionsRepository
- from app.modules.usage.repository import UsageRepository
  from app.modules.usage.updater import UsageUpdater


@@ -37,16 +37,8 @@ class AccountSelection:


  class LoadBalancer:
-     def __init__(
-         self,
-         accounts_repo: AccountsRepository,
-         usage_repo: UsageRepository,
-         sticky_repo: StickySessionsRepository | None = None,
-     ) -> None:
-         self._accounts_repo = accounts_repo
-         self._usage_repo = usage_repo
-         self._usage_updater = UsageUpdater(usage_repo, accounts_repo)
-         self._sticky_repo = sticky_repo
+     def __init__(self, repo_factory: ProxyRepoFactory) -> None:
+         self._repo_factory = repo_factory
          self._runtime: dict[str, RuntimeState] = {}

      async def select_account(
@@ -56,43 +48,53 @@ class LoadBalancer:
          reallocate_sticky: bool = False,
          prefer_earlier_reset_accounts: bool = False,
      ) -> AccountSelection:
-         accounts = await self._accounts_repo.list_accounts()
-         latest_primary = await self._usage_repo.latest_by_account()
-         await self._usage_updater.refresh_accounts(accounts, latest_primary)
-         latest_primary = await self._usage_repo.latest_by_account()
-         latest_secondary = await self._usage_repo.latest_by_account(window="secondary")
-
-         states, account_map = _build_states(
-             accounts=accounts,
-             latest_primary=latest_primary,
-             latest_secondary=latest_secondary,
-             runtime=self._runtime,
-         )
+         selected_snapshot: Account | None = None
+         error_message: str | None = None
+         async with self._repo_factory() as repos:
+             accounts = await repos.accounts.list_accounts()
+             latest_primary = await repos.usage.latest_by_account()
+             updater = UsageUpdater(repos.usage, repos.accounts)
+             await updater.refresh_accounts(accounts, latest_primary)
+             latest_primary = await repos.usage.latest_by_account()
+             latest_secondary = await repos.usage.latest_by_account(window="secondary")
+
+             states, account_map = _build_states(
+                 accounts=accounts,
+                 latest_primary=latest_primary,
+                 latest_secondary=latest_secondary,
+                 runtime=self._runtime,
+             )

-         result = await self._select_with_stickiness(
-             states=states,
-             account_map=account_map,
-             sticky_key=sticky_key,
-             reallocate_sticky=reallocate_sticky,
-             prefer_earlier_reset_accounts=prefer_earlier_reset_accounts,
-         )
-         for state in states:
-             account = account_map.get(state.account_id)
-             if account:
-                 await self._sync_state(account, state)
-
-         if result.account is None:
-             return AccountSelection(account=None, error_message=result.error_message)
-
-         selected = account_map.get(result.account.account_id)
-         if selected:
-             selected.status = result.account.status
-             selected.deactivation_reason = result.account.deactivation_reason
-             runtime = self._runtime.setdefault(selected.id, RuntimeState())
-             runtime.last_selected_at = time.time()
-         if selected is None:
-             return AccountSelection(account=None, error_message=result.error_message)
-         return AccountSelection(account=selected, error_message=None)
+             result = await self._select_with_stickiness(
+                 states=states,
+                 account_map=account_map,
+                 sticky_key=sticky_key,
+                 reallocate_sticky=reallocate_sticky,
+                 prefer_earlier_reset_accounts=prefer_earlier_reset_accounts,
+                 sticky_repo=repos.sticky_sessions,
+             )
+             for state in states:
+                 account = account_map.get(state.account_id)
+                 if account:
+                     await self._sync_state(repos.accounts, account, state)
+
+             if result.account is None:
+                 error_message = result.error_message
+             else:
+                 selected = account_map.get(result.account.account_id)
+                 if selected is None:
+                     error_message = result.error_message
+                 else:
+                     selected.status = result.account.status
+                     selected.deactivation_reason = result.account.deactivation_reason
+                     selected_snapshot = _clone_account(selected)
+
+         if selected_snapshot is None:
+             return AccountSelection(account=None, error_message=error_message)
+
+         runtime = self._runtime.setdefault(selected_snapshot.id, RuntimeState())
+         runtime.last_selected_at = time.time()
+         return AccountSelection(account=selected_snapshot, error_message=None)

      async def _select_with_stickiness(
          self,
@@ -102,21 +104,22 @@
          sticky_key: str | None,
          reallocate_sticky: bool,
          prefer_earlier_reset_accounts: bool,
+         sticky_repo: StickySessionsRepository | None,
      ) -> SelectionResult:
-         if not sticky_key or not self._sticky_repo:
+         if not sticky_key or not sticky_repo:
              return select_account(states, prefer_earlier_reset=prefer_earlier_reset_accounts)

          if reallocate_sticky:
              chosen = select_account(states, prefer_earlier_reset=prefer_earlier_reset_accounts)
              if chosen.account is not None and chosen.account.account_id in account_map:
-                 await self._sticky_repo.upsert(sticky_key, chosen.account.account_id)
+                 await sticky_repo.upsert(sticky_key, chosen.account.account_id)
              return chosen

-         existing = await self._sticky_repo.get_account_id(sticky_key)
+         existing = await sticky_repo.get_account_id(sticky_key)
          if existing:
              pinned = next((state for state in states if state.account_id == existing), None)
              if pinned is None:
-                 await self._sticky_repo.delete(sticky_key)
+                 await sticky_repo.delete(sticky_key)
              else:
                  pinned_result = select_account([pinned], prefer_earlier_reset=prefer_earlier_reset_accounts)
                  if pinned_result.account is not None:
@@ -124,29 +127,33 @@

          chosen = select_account(states, prefer_earlier_reset=prefer_earlier_reset_accounts)
          if chosen.account is not None and chosen.account.account_id in account_map:
-             await self._sticky_repo.upsert(sticky_key, chosen.account.account_id)
+             await sticky_repo.upsert(sticky_key, chosen.account.account_id)
          return chosen

      async def mark_rate_limit(self, account: Account, error: UpstreamError) -> None:
          state = self._state_for(account)
          handle_rate_limit(state, error)
-         await self._sync_state(account, state)
+         async with self._repo_factory() as repos:
+             await self._sync_state(repos.accounts, account, state)

      async def mark_quota_exceeded(self, account: Account, error: UpstreamError) -> None:
          state = self._state_for(account)
          handle_quota_exceeded(state, error)
-         await self._sync_state(account, state)
+         async with self._repo_factory() as repos:
+             await self._sync_state(repos.accounts, account, state)

      async def mark_permanent_failure(self, account: Account, error_code: str) -> None:
          state = self._state_for(account)
          handle_permanent_failure(state, error_code)
-         await self._sync_state(account, state)
+         async with self._repo_factory() as repos:
+             await self._sync_state(repos.accounts, account, state)

      async def record_error(self, account: Account) -> None:
          state = self._state_for(account)
          state.error_count += 1
          state.last_error_at = time.time()
-         await self._sync_state(account, state)
+         async with self._repo_factory() as repos:
+             await self._sync_state(repos.accounts, account, state)

      def _state_for(self, account: Account) -> AccountState:
          runtime = self._runtime.setdefault(account.id, RuntimeState())
@@ -164,7 +171,12 @@
              deactivation_reason=account.deactivation_reason,
          )

-     async def _sync_state(self, account: Account, state: AccountState) -> None:
+     async def _sync_state(
+         self,
+         accounts_repo: AccountsRepository,
+         account: Account,
+         state: AccountState,
+     ) -> None:
          runtime = self._runtime.setdefault(account.id, RuntimeState())
          runtime.reset_at = state.reset_at
          runtime.cooldown_until = state.cooldown_until
@@ -177,7 +189,7 @@
          reset_changed = account.reset_at != reset_at_int

          if status_changed or reason_changed or reset_changed:
-             await self._accounts_repo.update_status(
+             await accounts_repo.update_status(
                  account.id,
                  state.status,
                  state.deactivation_reason,
@@ -251,3 +263,8 @@ def _state_from_account(
          error_count=runtime.error_count,
          deactivation_reason=account.deactivation_reason,
      )
+
+
+ def _clone_account(account: Account) -> Account:
+     data = {column.name: getattr(account, column.name) for column in Account.__table__.columns}
+     return Account(**data)
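With the constructor change, `LoadBalancer` no longer holds repositories at all: every operation opens a fresh bundle through the factory, and `select_account` hands back a detached `_clone_account` copy so callers can keep using the row after that short-lived session is gone. A hedged wiring sketch (the factory import and the `sticky_key` keyword are assumptions based on this diff):

```python
# Illustrative wiring only; mirrors how app/dependencies.py builds ProxyService.
from app.dependencies import _proxy_repo_context  # assumed to be importable here


async def pick_account(sticky_key: str | None = None):
    balancer = LoadBalancer(repo_factory=_proxy_repo_context)
    selection = await balancer.select_account(sticky_key=sticky_key)
    if selection.account is None:
        raise RuntimeError(selection.error_message or "no account available")
    # The returned Account is a detached clone, safe to use after the session
    # opened by the factory has been closed.
    return selection.account
```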