codex-lb 0.1.5__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56)
  1. app/__init__.py +1 -1
  2. app/core/auth/__init__.py +12 -1
  3. app/core/balancer/logic.py +44 -7
  4. app/core/clients/proxy.py +2 -4
  5. app/core/config/settings.py +4 -1
  6. app/core/plan_types.py +64 -0
  7. app/core/types.py +4 -2
  8. app/core/usage/__init__.py +5 -2
  9. app/core/usage/logs.py +12 -2
  10. app/core/usage/quota.py +64 -0
  11. app/core/usage/types.py +3 -2
  12. app/core/utils/sse.py +6 -2
  13. app/db/migrations/__init__.py +91 -0
  14. app/db/migrations/versions/__init__.py +1 -0
  15. app/db/migrations/versions/add_accounts_chatgpt_account_id.py +29 -0
  16. app/db/migrations/versions/add_accounts_reset_at.py +29 -0
  17. app/db/migrations/versions/add_dashboard_settings.py +31 -0
  18. app/db/migrations/versions/add_request_logs_reasoning_effort.py +21 -0
  19. app/db/migrations/versions/normalize_account_plan_types.py +17 -0
  20. app/db/models.py +33 -0
  21. app/db/session.py +85 -11
  22. app/dependencies.py +27 -9
  23. app/main.py +15 -6
  24. app/modules/accounts/auth_manager.py +121 -0
  25. app/modules/accounts/repository.py +14 -6
  26. app/modules/accounts/service.py +14 -9
  27. app/modules/health/api.py +5 -3
  28. app/modules/health/schemas.py +9 -0
  29. app/modules/oauth/service.py +9 -4
  30. app/modules/proxy/helpers.py +285 -0
  31. app/modules/proxy/load_balancer.py +86 -41
  32. app/modules/proxy/service.py +172 -318
  33. app/modules/proxy/sticky_repository.py +56 -0
  34. app/modules/request_logs/repository.py +6 -3
  35. app/modules/request_logs/schemas.py +2 -0
  36. app/modules/request_logs/service.py +12 -3
  37. app/modules/settings/__init__.py +1 -0
  38. app/modules/settings/api.py +37 -0
  39. app/modules/settings/repository.py +40 -0
  40. app/modules/settings/schemas.py +13 -0
  41. app/modules/settings/service.py +33 -0
  42. app/modules/shared/schemas.py +16 -2
  43. app/modules/usage/schemas.py +1 -0
  44. app/modules/usage/service.py +23 -6
  45. app/modules/{proxy/usage_updater.py → usage/updater.py} +37 -8
  46. app/static/7.css +73 -0
  47. app/static/index.css +33 -4
  48. app/static/index.html +51 -4
  49. app/static/index.js +254 -32
  50. {codex_lb-0.1.5.dist-info → codex_lb-0.3.0.dist-info}/METADATA +2 -2
  51. codex_lb-0.3.0.dist-info/RECORD +97 -0
  52. app/modules/proxy/auth_manager.py +0 -51
  53. codex_lb-0.1.5.dist-info/RECORD +0 -80
  54. {codex_lb-0.1.5.dist-info → codex_lb-0.3.0.dist-info}/WHEEL +0 -0
  55. {codex_lb-0.1.5.dist-info → codex_lb-0.3.0.dist-info}/entry_points.txt +0 -0
  56. {codex_lb-0.1.5.dist-info → codex_lb-0.3.0.dist-info}/licenses/LICENSE +0 -0

app/modules/proxy/service.py
@@ -2,10 +2,12 @@ from __future__ import annotations
 
  import logging
  import time
+ from collections.abc import Sequence
  from datetime import timedelta
- from typing import AsyncIterator, Iterable, Mapping
+ from hashlib import sha256
+ from typing import AsyncIterator, Mapping
 
- from pydantic import ValidationError
+ import anyio
 
  from app.core import usage as usage_core
  from app.core.auth.refresh import RefreshError
@@ -14,28 +16,41 @@ from app.core.balancer.types import UpstreamError
  from app.core.clients.proxy import ProxyResponseError, filter_inbound_headers
  from app.core.clients.proxy import compact_responses as core_compact_responses
  from app.core.clients.proxy import stream_responses as core_stream_responses
+ from app.core.config.settings import get_settings
  from app.core.crypto import TokenEncryptor
- from app.core.errors import OpenAIErrorDetail, OpenAIErrorEnvelope, openai_error, response_failed_event
- from app.core.openai.models import OpenAIError, OpenAIResponsePayload
+ from app.core.errors import openai_error, response_failed_event
+ from app.core.openai.models import OpenAIResponsePayload
  from app.core.openai.parsing import parse_sse_event
  from app.core.openai.requests import ResponsesCompactRequest, ResponsesRequest
- from app.core.usage.types import UsageWindowRow, UsageWindowSummary
- from app.core.utils.request_id import ensure_request_id
+ from app.core.usage.types import UsageWindowRow
+ from app.core.utils.request_id import ensure_request_id, get_request_id
  from app.core.utils.sse import format_sse_event
  from app.core.utils.time import utcnow
- from app.db.models import Account, AccountStatus, UsageHistory
+ from app.db.models import Account, UsageHistory
+ from app.modules.accounts.auth_manager import AuthManager
  from app.modules.accounts.repository import AccountsRepository
- from app.modules.proxy.auth_manager import AuthManager
- from app.modules.proxy.load_balancer import LoadBalancer
- from app.modules.proxy.types import (
-     CreditStatusDetailsData,
-     RateLimitStatusDetailsData,
-     RateLimitStatusPayloadData,
-     RateLimitWindowSnapshotData,
+ from app.modules.proxy.helpers import (
+     _apply_error_metadata,
+     _credits_headers,
+     _credits_snapshot,
+     _header_account_id,
+     _normalize_error_code,
+     _parse_openai_error,
+     _plan_type_for_accounts,
+     _rate_limit_details,
+     _rate_limit_headers,
+     _select_accounts_for_limits,
+     _summarize_window,
+     _upstream_error_from_openai,
+     _window_snapshot,
  )
- from app.modules.proxy.usage_updater import UsageUpdater
+ from app.modules.proxy.load_balancer import LoadBalancer
+ from app.modules.proxy.sticky_repository import StickySessionsRepository
+ from app.modules.proxy.types import RateLimitStatusPayloadData
  from app.modules.request_logs.repository import RequestLogsRepository
+ from app.modules.settings.repository import SettingsRepository
  from app.modules.usage.repository import UsageRepository
+ from app.modules.usage.updater import UsageUpdater
 
  logger = logging.getLogger(__name__)
 
@@ -46,13 +61,16 @@ class ProxyService:
          accounts_repo: AccountsRepository,
          usage_repo: UsageRepository,
          logs_repo: RequestLogsRepository,
+         sticky_repo: StickySessionsRepository,
+         settings_repo: SettingsRepository,
      ) -> None:
          self._accounts_repo = accounts_repo
          self._usage_repo = usage_repo
          self._logs_repo = logs_repo
+         self._settings_repo = settings_repo
          self._encryptor = TokenEncryptor()
          self._auth_manager = AuthManager(accounts_repo)
-         self._load_balancer = LoadBalancer(accounts_repo, usage_repo)
+         self._load_balancer = LoadBalancer(accounts_repo, usage_repo, sticky_repo)
          self._usage_updater = UsageUpdater(usage_repo, accounts_repo)
 
      def stream_responses(
@@ -62,6 +80,7 @@ class ProxyService:
          *,
          propagate_http_errors: bool = False,
      ) -> AsyncIterator[str]:
+         _maybe_log_proxy_request_shape("stream", payload, headers)
          filtered = filter_inbound_headers(headers)
          return self._stream_with_retry(
              payload,
@@ -74,8 +93,16 @@ class ProxyService:
          payload: ResponsesCompactRequest,
          headers: Mapping[str, str],
      ) -> OpenAIResponsePayload:
+         _maybe_log_proxy_request_shape("compact", payload, headers)
          filtered = filter_inbound_headers(headers)
-         selection = await self._load_balancer.select_account()
+         settings = await self._settings_repo.get_or_create()
+         prefer_earlier_reset = settings.prefer_earlier_reset_accounts
+         sticky_key = _sticky_key_from_compact_payload(payload) if settings.sticky_threads_enabled else None
+         selection = await self._load_balancer.select_account(
+             sticky_key=sticky_key,
+             reallocate_sticky=sticky_key is not None,
+             prefer_earlier_reset_accounts=prefer_earlier_reset,
+         )
          account = selection.account
          if not account:
              raise ProxyResponseError(
@@ -83,7 +110,7 @@ class ProxyService:
                  openai_error("no_accounts", selection.error_message or "No active accounts available"),
              )
          account = await self._ensure_fresh(account)
-         account_id = _header_account_id(account.id)
+         account_id = _header_account_id(account.chatgpt_account_id)
 
          async def _call_compact(target: Account) -> OpenAIResponsePayload:
              access_token = self._encryptor.decrypt(target.access_token_encrypted)
@@ -181,9 +208,15 @@ class ProxyService:
          propagate_http_errors: bool,
      ) -> AsyncIterator[str]:
          request_id = ensure_request_id()
+         settings = await self._settings_repo.get_or_create()
+         prefer_earlier_reset = settings.prefer_earlier_reset_accounts
+         sticky_key = _sticky_key_from_payload(payload) if settings.sticky_threads_enabled else None
          max_attempts = 3
          for attempt in range(max_attempts):
-             selection = await self._load_balancer.select_account()
+             selection = await self._load_balancer.select_account(
+                 sticky_key=sticky_key,
+                 prefer_earlier_reset_accounts=prefer_earlier_reset,
+             )
              account = selection.account
              if not account:
                  event = response_failed_event(
@@ -281,8 +314,9 @@ class ProxyService:
      ) -> AsyncIterator[str]:
          account_id_value = account.id
          access_token = self._encryptor.decrypt(account.access_token_encrypted)
-         account_id = _header_account_id(account_id_value)
+         account_id = _header_account_id(account.chatgpt_account_id)
          model = payload.model
+         reasoning_effort = payload.reasoning.effort if payload.reasoning else None
          start = time.monotonic()
          status = "success"
          error_code = None
@@ -304,7 +338,11 @@ class ProxyService:
                  return
              event = parse_sse_event(first)
              if event and event.type in ("response.failed", "error"):
-                 error = event.response.error if event.type == "response.failed" else event.error
+                 if event.type == "response.failed":
+                     response = event.response
+                     error = response.error if response else None
+                 else:
+                     error = event.error
                  code = _normalize_error_code(
                      error.code if error else None,
                      error.type if error else None,
@@ -326,7 +364,11 @@ class ProxyService:
                  event_type = event.type
                  if event_type in ("response.failed", "error"):
                      status = "error"
-                     error = event.response.error if event_type == "response.failed" else event.error
+                     if event_type == "response.failed":
+                         response = event.response
+                         error = response.error if response else None
+                     else:
+                         error = event.error
                      error_code = _normalize_error_code(
                          error.code if error else None,
                          error.type if error else None,
@@ -354,27 +396,29 @@ class ProxyService:
              reasoning_tokens = (
                  usage.output_tokens_details.reasoning_tokens if usage and usage.output_tokens_details else None
              )
-             try:
-                 await self._logs_repo.add_log(
-                     account_id=account_id_value,
-                     request_id=request_id,
-                     model=model,
-                     input_tokens=input_tokens,
-                     output_tokens=output_tokens,
-                     cached_input_tokens=cached_input_tokens,
-                     reasoning_tokens=reasoning_tokens,
-                     latency_ms=latency_ms,
-                     status=status,
-                     error_code=error_code,
-                     error_message=error_message,
-                 )
-             except Exception:
-                 logger.warning(
-                     "Failed to persist request log account_id=%s request_id=%s",
-                     account_id_value,
-                     request_id,
-                     exc_info=True,
-                 )
+             with anyio.CancelScope(shield=True):
+                 try:
+                     await self._logs_repo.add_log(
+                         account_id=account_id_value,
+                         request_id=request_id,
+                         model=model,
+                         input_tokens=input_tokens,
+                         output_tokens=output_tokens,
+                         cached_input_tokens=cached_input_tokens,
+                         reasoning_tokens=reasoning_tokens,
+                         reasoning_effort=reasoning_effort,
+                         latency_ms=latency_ms,
+                         status=status,
+                         error_code=error_code,
+                         error_message=error_message,
+                     )
+                 except Exception:
+                     logger.warning(
+                         "Failed to persist request log account_id=%s request_id=%s",
+                         account_id_value,
+                         request_id,
+                         exc_info=True,
+                     )
 
      async def _refresh_usage(self, accounts: list[Account]) -> None:
          latest_usage = await self._usage_repo.latest_by_account(window="primary")
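
Note: the new anyio.CancelScope(shield=True) wrapper appears intended to let the request-log write finish even when the surrounding stream is cancelled, for example on a client disconnect. A minimal sketch of that pattern, where save_log is a hypothetical stand-in for the repository call and not part of the package:

    import anyio

    async def save_log() -> None:
        await anyio.sleep(0.1)      # hypothetical stand-in for the DB insert
        print("log persisted")

    async def stream_and_record() -> None:
        try:
            await anyio.sleep(10)   # stands in for forwarding SSE events
        finally:
            # Shielded: the awaited write completes even though the task was cancelled.
            with anyio.CancelScope(shield=True):
                await save_log()

    async def main() -> None:
        async with anyio.create_task_group() as tg:
            tg.start_soon(stream_and_record)
            await anyio.sleep(0.05)
            tg.cancel_scope.cancel()  # simulate a client disconnect

    anyio.run(main)
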
@@ -432,297 +476,107 @@ class ProxyService:
          await self._load_balancer.record_error(account)
 
 
- def _header_account_id(account_id: str | None) -> str | None:
-     if not account_id:
-         return None
-     if account_id.startswith(("email_", "local_")):
-         return None
-     return account_id
-
-
- KNOWN_PLAN_TYPES = {
-     "guest",
-     "free",
-     "go",
-     "plus",
-     "pro",
-     "free_workspace",
-     "team",
-     "business",
-     "education",
-     "quorum",
-     "k12",
-     "enterprise",
-     "edu",
- }
-
- PLAN_TYPE_PRIORITY = (
-     "enterprise",
-     "business",
-     "team",
-     "pro",
-     "plus",
-     "education",
-     "edu",
-     "free_workspace",
-     "free",
-     "go",
-     "guest",
-     "quorum",
-     "k12",
- )
-
-
- def _select_accounts_for_limits(accounts: Iterable[Account]) -> list[Account]:
-     return [account for account in accounts if account.status not in (AccountStatus.DEACTIVATED, AccountStatus.PAUSED)]
-
-
- def _summarize_window(
-     rows: list[UsageWindowRow],
-     account_map: dict[str, Account],
-     window: str,
- ) -> UsageWindowSummary | None:
-     if not rows:
-         return None
-     return usage_core.summarize_usage_window(rows, account_map, window)
-
-
- def _window_snapshot(
-     summary: UsageWindowSummary | None,
-     rows: list[UsageWindowRow],
-     window: str,
-     now_epoch: int,
- ) -> RateLimitWindowSnapshotData | None:
-     if summary is None:
-         return None
-
-     used_percent = _normalize_used_percent(summary.used_percent, rows)
-     if used_percent is None:
-         return None
-
-     reset_at = summary.reset_at
-     if reset_at is None:
-         return None
+ class _RetryableStreamError(Exception):
+     def __init__(self, code: str, error: UpstreamError) -> None:
+         super().__init__(code)
+         self.code = code
+         self.error = error
 
-     window_minutes = summary.window_minutes or usage_core.default_window_minutes(window)
-     if not window_minutes:
-         return None
 
-     limit_window_seconds = int(window_minutes * 60)
-     reset_after_seconds = max(0, int(reset_at) - now_epoch)
+ def _maybe_log_proxy_request_shape(
+     kind: str,
+     payload: ResponsesRequest | ResponsesCompactRequest,
+     headers: Mapping[str, str],
+ ) -> None:
+     settings = get_settings()
+     if not settings.log_proxy_request_shape:
+         return
 
-     return RateLimitWindowSnapshotData(
-         used_percent=_percent_to_int(used_percent),
-         limit_window_seconds=limit_window_seconds,
-         reset_after_seconds=reset_after_seconds,
-         reset_at=int(reset_at),
+     request_id = get_request_id()
+     prompt_cache_key = getattr(payload, "prompt_cache_key", None)
+     if prompt_cache_key is None and payload.model_extra:
+         extra_value = payload.model_extra.get("prompt_cache_key")
+         if isinstance(extra_value, str):
+             prompt_cache_key = extra_value
+     prompt_cache_key_hash = _hash_identifier(prompt_cache_key) if isinstance(prompt_cache_key, str) else None
+     prompt_cache_key_raw = (
+         _truncate_identifier(prompt_cache_key)
+         if settings.log_proxy_request_shape_raw_cache_key and isinstance(prompt_cache_key, str)
+         else None
      )
 
-
- def _normalize_used_percent(
-     value: float | None,
-     rows: Iterable[UsageWindowRow],
- ) -> float | None:
-     if value is not None:
-         return value
-     values = [row.used_percent for row in rows if row.used_percent is not None]
-     if not values:
-         return None
-     return sum(values) / len(values)
-
-
- def _percent_to_int(value: float) -> int:
-     bounded = max(0.0, min(100.0, value))
-     return int(bounded)
-
-
- def _rate_limit_details(
-     primary: RateLimitWindowSnapshotData | None,
-     secondary: RateLimitWindowSnapshotData | None,
- ) -> RateLimitStatusDetailsData | None:
-     if not primary and not secondary:
-         return None
-     used_percents = [window.used_percent for window in (primary, secondary) if window]
-     limit_reached = any(used >= 100 for used in used_percents)
-     return RateLimitStatusDetailsData(
-         allowed=not limit_reached,
-         limit_reached=limit_reached,
-         primary_window=primary,
-         secondary_window=secondary,
+     extra_keys = sorted(payload.model_extra.keys()) if payload.model_extra else []
+     fields_set = sorted(payload.model_fields_set)
+     input_summary = _summarize_input(payload.input)
+     header_keys = _interesting_header_keys(headers)
+
+     logger.warning(
+         "proxy_request_shape request_id=%s kind=%s model=%s stream=%s input=%s "
+         "prompt_cache_key=%s prompt_cache_key_raw=%s fields=%s extra=%s headers=%s",
+         request_id,
+         kind,
+         payload.model,
+         getattr(payload, "stream", None),
+         input_summary,
+         prompt_cache_key_hash,
+         prompt_cache_key_raw,
+         fields_set,
+         extra_keys,
+         header_keys,
      )
 
 
- def _aggregate_credits(entries: Iterable[UsageHistory]) -> tuple[bool, bool, float] | None:
-     has_data = False
-     has_credits = False
-     unlimited = False
-     balance_total = 0.0
-
-     for entry in entries:
-         credits_has = entry.credits_has
-         credits_unlimited = entry.credits_unlimited
-         credits_balance = entry.credits_balance
-         if credits_has is None and credits_unlimited is None and credits_balance is None:
-             continue
-         has_data = True
-         if credits_has is True:
-             has_credits = True
-         if credits_unlimited is True:
-             unlimited = True
-         if credits_balance is not None and not credits_unlimited:
-             try:
-                 balance_total += float(credits_balance)
-             except (TypeError, ValueError):
-                 continue
-
-     if not has_data:
-         return None
-     if unlimited:
-         has_credits = True
-     return has_credits, unlimited, balance_total
-
+ def _hash_identifier(value: str) -> str:
+     digest = sha256(value.encode("utf-8")).hexdigest()
+     return f"sha256:{digest[:12]}"
 
- def _credits_snapshot(entries: Iterable[UsageHistory]) -> CreditStatusDetailsData | None:
-     aggregate = _aggregate_credits(entries)
-     if aggregate is None:
-         return None
-     has_credits, unlimited, balance_total = aggregate
-     balance_value = str(round(balance_total, 2))
-     return CreditStatusDetailsData(
-         has_credits=has_credits,
-         unlimited=unlimited,
-         balance=balance_value,
-         approx_local_messages=None,
-         approx_cloud_messages=None,
-     )
 
+ def _summarize_input(items: Sequence[object]) -> str:
+     if not items:
+         return "0"
+     type_counts: dict[str, int] = {}
+     for item in items:
+         type_name = type(item).__name__
+         type_counts[type_name] = type_counts.get(type_name, 0) + 1
+     summary = ",".join(f"{key}={type_counts[key]}" for key in sorted(type_counts))
+     return f"{len(items)}({summary})"
 
- def _plan_type_for_accounts(accounts: Iterable[Account]) -> str:
-     normalized = [_normalize_plan_type(account.plan_type) for account in accounts]
-     filtered = [plan for plan in normalized if plan is not None]
-     if not filtered:
-         return "guest"
-     unique = set(filtered)
-     if len(unique) == 1:
-         return filtered[0]
-     for plan in PLAN_TYPE_PRIORITY:
-         if plan in unique:
-             return plan
-     return "guest"
 
-
- def _normalize_plan_type(value: str | None) -> str | None:
-     if not value:
-         return None
-     normalized = value.strip().lower()
-     if normalized not in KNOWN_PLAN_TYPES:
-         return None
-     return normalized
-
-
- def _rate_limit_headers(
-     window_label: str,
-     summary: UsageWindowSummary,
- ) -> dict[str, str]:
-     used_percent = summary.used_percent
-     window_minutes = summary.window_minutes
-     if used_percent is None or window_minutes is None:
-         return {}
-     headers = {
-         f"x-codex-{window_label}-used-percent": str(float(used_percent)),
-         f"x-codex-{window_label}-window-minutes": str(int(window_minutes)),
-     }
-     reset_at = summary.reset_at
-     if reset_at is not None:
-         headers[f"x-codex-{window_label}-reset-at"] = str(int(reset_at))
-     return headers
-
-
- def _credits_headers(entries: Iterable[UsageHistory]) -> dict[str, str]:
-     aggregate = _aggregate_credits(entries)
-     if aggregate is None:
-         return {}
-     has_credits, unlimited, balance_total = aggregate
-     balance_value = f"{balance_total:.2f}"
-     return {
-         "x-codex-credits-has-credits": "true" if has_credits else "false",
-         "x-codex-credits-unlimited": "true" if unlimited else "false",
-         "x-codex-credits-balance": balance_value,
+ def _truncate_identifier(value: str, *, max_length: int = 96) -> str:
+     if len(value) <= max_length:
+         return value
+     return f"{value[:48]}...{value[-16:]}"
+
+
+ def _interesting_header_keys(headers: Mapping[str, str]) -> list[str]:
+     allowlist = {
+         "user-agent",
+         "x-request-id",
+         "request-id",
+         "x-openai-client-id",
+         "x-openai-client-version",
+         "x-openai-client-arch",
+         "x-openai-client-os",
+         "x-openai-client-user-agent",
+         "x-codex-session-id",
+         "x-codex-conversation-id",
      }
+     return sorted({key.lower() for key in headers.keys() if key.lower() in allowlist})
 
 
- def _normalize_error_code(code: str | None, error_type: str | None) -> str:
-     value = code or error_type
+ def _sticky_key_from_payload(payload: ResponsesRequest) -> str | None:
+     value = payload.prompt_cache_key
      if not value:
-         return "upstream_error"
-     return value.lower()
-
-
- def _parse_openai_error(payload: OpenAIErrorEnvelope) -> OpenAIError | None:
-     error = payload.get("error")
-     if not error:
          return None
-     try:
-         return OpenAIError.model_validate(error)
-     except ValidationError:
-         if not isinstance(error, dict):
-             return None
-         return OpenAIError(
-             message=_coerce_str(error.get("message")),
-             type=_coerce_str(error.get("type")),
-             code=_coerce_str(error.get("code")),
-             param=_coerce_str(error.get("param")),
-             plan_type=_coerce_str(error.get("plan_type")),
-             resets_at=_coerce_number(error.get("resets_at")),
-             resets_in_seconds=_coerce_number(error.get("resets_in_seconds")),
-         )
-
-
- def _coerce_str(value: object) -> str | None:
-     return value if isinstance(value, str) else None
-
-
- def _coerce_number(value: object) -> int | float | None:
-     if isinstance(value, (int, float)):
-         return value
-     if isinstance(value, str):
-         try:
-             return float(value.strip())
-         except ValueError:
-             return None
-     return None
-
-
- def _apply_error_metadata(target: OpenAIErrorDetail, error: OpenAIError | None) -> None:
-     if not error:
-         return
-     if error.plan_type is not None:
-         target["plan_type"] = error.plan_type
-     if error.resets_at is not None:
-         target["resets_at"] = error.resets_at
-     if error.resets_in_seconds is not None:
-         target["resets_in_seconds"] = error.resets_in_seconds
-
-
- class _RetryableStreamError(Exception):
-     def __init__(self, code: str, error: UpstreamError) -> None:
-         super().__init__(code)
-         self.code = code
-         self.error = error
+     stripped = value.strip()
+     return stripped or None
 
 
- def _upstream_error_from_openai(error: OpenAIError | None) -> UpstreamError:
-     if not error:
-         return {}
-     data = error.model_dump(exclude_none=True)
-     payload: UpstreamError = {}
-     message = data.get("message")
-     if isinstance(message, str):
-         payload["message"] = message
-     resets_at = data.get("resets_at")
-     if isinstance(resets_at, (int, float)):
-         payload["resets_at"] = resets_at
-     resets_in_seconds = data.get("resets_in_seconds")
-     if isinstance(resets_in_seconds, (int, float)):
-         payload["resets_in_seconds"] = resets_in_seconds
-     return payload
+ def _sticky_key_from_compact_payload(payload: ResponsesCompactRequest) -> str | None:
+     if not payload.model_extra:
+         return None
+     value = payload.model_extra.get("prompt_cache_key")
+     if not isinstance(value, str):
+         return None
+     stripped = value.strip()
+     return stripped or None
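
Note: taken together, the service.py changes derive an optional sticky key from the client-supplied prompt_cache_key and pass it, along with the prefer_earlier_reset_accounts setting, into LoadBalancer.select_account, so requests that belong to the same thread keep landing on the same upstream account. A toy model of that routing behavior (account names and the round-robin fallback here are invented, not the package's actual LoadBalancer API):

    # Toy illustration only.
    sticky: dict[str, str] = {}
    accounts = ["acct-a", "acct-b", "acct-c"]

    def select_account(prompt_cache_key: str | None) -> str:
        key = (prompt_cache_key or "").strip()
        if key and key in sticky:
            return sticky[key]                              # same thread -> same account
        choice = accounts[len(sticky) % len(accounts)]      # stand-in for real load balancing
        if key:
            sticky[key] = choice
        return choice

    assert select_account("thread-42") == select_account("thread-42")
    assert select_account(None) in accounts                 # no cache key -> no stickiness
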
app/modules/proxy/sticky_repository.py
@@ -0,0 +1,56 @@
+ from __future__ import annotations
+
+ from sqlalchemy import delete, select
+ from sqlalchemy.dialects.postgresql import insert as pg_insert
+ from sqlalchemy.dialects.sqlite import insert as sqlite_insert
+ from sqlalchemy.ext.asyncio import AsyncSession
+ from sqlalchemy.sql import Insert, func
+
+ from app.db.models import StickySession
+
+
+ class StickySessionsRepository:
+     def __init__(self, session: AsyncSession) -> None:
+         self._session = session
+
+     async def get_account_id(self, key: str) -> str | None:
+         if not key:
+             return None
+         result = await self._session.execute(select(StickySession.account_id).where(StickySession.key == key))
+         return result.scalar_one_or_none()
+
+     async def upsert(self, key: str, account_id: str) -> StickySession:
+         statement = self._build_upsert_statement(key, account_id)
+         await self._session.execute(statement)
+         await self._session.commit()
+         row = await self._session.get(StickySession, key)
+         if row is None:
+             raise RuntimeError(f"StickySession upsert failed for key={key!r}")
+         await self._session.refresh(row)
+         return row
+
+     async def delete(self, key: str) -> bool:
+         if not key:
+             return False
+         result = await self._session.execute(
+             delete(StickySession).where(StickySession.key == key).returning(StickySession.key)
+         )
+         await self._session.commit()
+         return result.scalar_one_or_none() is not None
+
+     def _build_upsert_statement(self, key: str, account_id: str) -> Insert:
+         dialect = self._session.get_bind().dialect.name
+         if dialect == "postgresql":
+             insert_fn = pg_insert
+         elif dialect == "sqlite":
+             insert_fn = sqlite_insert
+         else:
+             raise RuntimeError(f"StickySession upsert unsupported for dialect={dialect!r}")
+         statement = insert_fn(StickySession).values(key=key, account_id=account_id)
+         return statement.on_conflict_do_update(
+             index_elements=[StickySession.key],
+             set_={
+                 "account_id": account_id,
+                 "updated_at": func.now(),
+             },
+         )
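
Note: the upsert picks the insert construct per dialect because on_conflict_do_update exists only on SQLAlchemy's PostgreSQL and SQLite dialect inserts, not on the generic insert(). Because the conflict target is the key column and the update rewrites account_id and updated_at, re-pointing an existing thread to a different account is a single statement. A rough usage sketch, assuming the project's async session setup and StickySession model are already configured (the key and account values here are invented):

    from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker

    from app.modules.proxy.sticky_repository import StickySessionsRepository

    async def remember_route(session_factory: async_sessionmaker[AsyncSession]) -> None:
        async with session_factory() as session:
            repo = StickySessionsRepository(session)
            await repo.upsert("thread-42", "account-123")    # insert or re-point the mapping
            assert await repo.get_account_id("thread-42") == "account-123"
            await repo.delete("thread-42")                   # forget the mapping
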