coding-proxy 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (77) hide show
  1. coding/__init__.py +0 -0
  2. coding/proxy/__init__.py +3 -0
  3. coding/proxy/__main__.py +5 -0
  4. coding/proxy/auth/__init__.py +13 -0
  5. coding/proxy/auth/providers/__init__.py +6 -0
  6. coding/proxy/auth/providers/base.py +35 -0
  7. coding/proxy/auth/providers/github.py +133 -0
  8. coding/proxy/auth/providers/google.py +237 -0
  9. coding/proxy/auth/runtime.py +122 -0
  10. coding/proxy/auth/store.py +74 -0
  11. coding/proxy/cli/__init__.py +151 -0
  12. coding/proxy/cli/auth_commands.py +224 -0
  13. coding/proxy/compat/__init__.py +30 -0
  14. coding/proxy/compat/canonical.py +193 -0
  15. coding/proxy/compat/session_store.py +137 -0
  16. coding/proxy/config/__init__.py +6 -0
  17. coding/proxy/config/auth_schema.py +24 -0
  18. coding/proxy/config/loader.py +139 -0
  19. coding/proxy/config/resiliency.py +46 -0
  20. coding/proxy/config/routing.py +279 -0
  21. coding/proxy/config/schema.py +280 -0
  22. coding/proxy/config/server.py +23 -0
  23. coding/proxy/config/vendors.py +53 -0
  24. coding/proxy/convert/__init__.py +14 -0
  25. coding/proxy/convert/anthropic_to_gemini.py +352 -0
  26. coding/proxy/convert/anthropic_to_openai.py +352 -0
  27. coding/proxy/convert/gemini_sse_adapter.py +169 -0
  28. coding/proxy/convert/gemini_to_anthropic.py +98 -0
  29. coding/proxy/convert/openai_to_anthropic.py +88 -0
  30. coding/proxy/logging/__init__.py +49 -0
  31. coding/proxy/logging/db.py +308 -0
  32. coding/proxy/logging/stats.py +129 -0
  33. coding/proxy/model/__init__.py +93 -0
  34. coding/proxy/model/auth.py +32 -0
  35. coding/proxy/model/compat.py +153 -0
  36. coding/proxy/model/constants.py +21 -0
  37. coding/proxy/model/pricing.py +70 -0
  38. coding/proxy/model/token.py +64 -0
  39. coding/proxy/model/vendor.py +218 -0
  40. coding/proxy/pricing.py +100 -0
  41. coding/proxy/routing/__init__.py +47 -0
  42. coding/proxy/routing/circuit_breaker.py +152 -0
  43. coding/proxy/routing/error_classifier.py +67 -0
  44. coding/proxy/routing/executor.py +453 -0
  45. coding/proxy/routing/model_mapper.py +90 -0
  46. coding/proxy/routing/quota_guard.py +169 -0
  47. coding/proxy/routing/rate_limit.py +159 -0
  48. coding/proxy/routing/retry.py +82 -0
  49. coding/proxy/routing/router.py +84 -0
  50. coding/proxy/routing/session_manager.py +62 -0
  51. coding/proxy/routing/tier.py +171 -0
  52. coding/proxy/routing/usage_parser.py +193 -0
  53. coding/proxy/routing/usage_recorder.py +131 -0
  54. coding/proxy/server/__init__.py +1 -0
  55. coding/proxy/server/app.py +142 -0
  56. coding/proxy/server/factory.py +175 -0
  57. coding/proxy/server/request_normalizer.py +139 -0
  58. coding/proxy/server/responses.py +74 -0
  59. coding/proxy/server/routes.py +264 -0
  60. coding/proxy/streaming/__init__.py +1 -0
  61. coding/proxy/streaming/anthropic_compat.py +484 -0
  62. coding/proxy/vendors/__init__.py +29 -0
  63. coding/proxy/vendors/anthropic.py +44 -0
  64. coding/proxy/vendors/antigravity.py +328 -0
  65. coding/proxy/vendors/base.py +353 -0
  66. coding/proxy/vendors/copilot.py +702 -0
  67. coding/proxy/vendors/copilot_models.py +438 -0
  68. coding/proxy/vendors/copilot_token_manager.py +167 -0
  69. coding/proxy/vendors/copilot_urls.py +16 -0
  70. coding/proxy/vendors/mixins.py +71 -0
  71. coding/proxy/vendors/token_manager.py +128 -0
  72. coding/proxy/vendors/zhipu.py +243 -0
  73. coding_proxy-0.1.0.dist-info/METADATA +184 -0
  74. coding_proxy-0.1.0.dist-info/RECORD +77 -0
  75. coding_proxy-0.1.0.dist-info/WHEEL +4 -0
  76. coding_proxy-0.1.0.dist-info/entry_points.txt +2 -0
  77. coding_proxy-0.1.0.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,169 @@
1
+ """用量配额守卫 (Quota Guard) — 滑动窗口限额与探测恢复."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import logging
6
+ import threading
7
+ import time
8
+ from collections import deque
9
+ from enum import Enum
10
+
11
logger = logging.getLogger(__name__)


class QuotaState(Enum):
    """Two-state quota machine: under budget, or over it."""

    WITHIN_QUOTA = "within_quota"
    QUOTA_EXCEEDED = "quota_exceeded"


class QuotaGuard:
    """Sliding-window token-usage quota guard.

    State transitions:
      - WITHIN_QUOTA -> QUOTA_EXCEEDED: window usage reaches
        ``budget * threshold`` or an upstream usage-cap error is reported.
      - QUOTA_EXCEEDED -> WITHIN_QUOTA: usage naturally slides back below
        the threshold, or a probe request succeeds.
    """

    def __init__(
        self,
        *,
        enabled: bool = False,
        token_budget: int = 0,
        window_seconds: int = 18000,
        threshold_percent: float = 99.0,
        probe_interval_seconds: int = 300,
    ) -> None:
        self._enabled = enabled
        self._budget = token_budget
        self._window = window_seconds
        self._threshold = threshold_percent / 100.0
        self._probe_interval = probe_interval_seconds

        # Mutable state, all guarded by self._lock.
        self._state = QuotaState.WITHIN_QUOTA
        self._entries: deque[tuple[float, int]] = deque()  # (monotonic ts, tokens)
        self._total: int = 0  # running sum of tokens inside the window
        self._last_probe: float = 0.0
        self._cap_error_active: bool = False
        self._effective_probe_interval: float = probe_interval_seconds
        self._lock = threading.Lock()

    @property
    def enabled(self) -> bool:
        """Whether quota guarding is active at all."""
        return self._enabled

    @property
    def window_hours(self) -> float:
        """Sliding-window length in hours (used when loading the baseline)."""
        return self._window / 3600

    def can_use_primary(self) -> bool:
        """Return True when the primary backend may be used right now."""
        if not self._enabled:
            return True
        with self._lock:
            self._expire()
            limit = int(self._budget * self._threshold)
            if self._state is QuotaState.WITHIN_QUOTA:
                if self._budget > 0 and self._total >= limit:
                    self._transition_to(QuotaState.QUOTA_EXCEEDED)
                    logger.warning(
                        "Quota guard: WITHIN_QUOTA → EXCEEDED (%.1f%%)",
                        self._total / self._budget * 100,
                    )
                    return False
                return True
            # QUOTA_EXCEEDED: after a cap error only a successful probe may
            # recover; budget-based auto-recovery is disabled.
            if not self._cap_error_active and self._budget > 0 and self._total < limit:
                self._transition_to(QuotaState.WITHIN_QUOTA)
                logger.info("Quota guard: EXCEEDED → WITHIN_QUOTA (usage dropped)")
                return True
            now = time.monotonic()
            if now - self._last_probe < self._effective_probe_interval:
                return False
            self._last_probe = now
            logger.info("Quota guard: allowing probe request")
            return True

    def record_usage(self, tokens: int) -> None:
        """Append new token usage to the sliding window."""
        if not self._enabled or tokens <= 0:
            return
        with self._lock:
            self._entries.append((time.monotonic(), tokens))
            self._total += tokens

    def record_primary_success(self) -> None:
        """Note a successful primary request (probe-recovery trigger)."""
        if not self._enabled:
            return
        with self._lock:
            if self._state is QuotaState.QUOTA_EXCEEDED:
                self._transition_to(QuotaState.WITHIN_QUOTA)
                logger.info("Quota guard: EXCEEDED → WITHIN_QUOTA (probe success)")

    def notify_cap_error(self, retry_after_seconds: float | None = None) -> None:
        """Handle an externally detected usage-cap error.

        Args:
            retry_after_seconds: suggested recovery time parsed from response
                headers; when given, the probe interval is stretched so we do
                not probe before the cap is expected to lift.
        """
        if not self._enabled:
            return
        with self._lock:
            if self._state is not QuotaState.QUOTA_EXCEEDED:
                self._transition_to(QuotaState.QUOTA_EXCEEDED)
            if retry_after_seconds is not None:
                # 10% margin on top of the server's suggestion, never shorter
                # than the configured interval.
                self._effective_probe_interval = max(
                    retry_after_seconds * 1.1,
                    self._probe_interval,
                )
            self._cap_error_active = True
            logger.warning(
                "Quota guard: cap error detected → EXCEEDED (effective_probe=%ds)",
                int(self._effective_probe_interval),
            )

    def load_baseline(self, total_tokens: int) -> None:
        """Seed the window with historical usage loaded from the database."""
        if not self._enabled or total_tokens <= 0:
            return
        with self._lock:
            # Park the baseline at the window midpoint so it ages out
            # gradually rather than all at once.
            midpoint = time.monotonic() - self._window / 2
            self._entries.append((midpoint, total_tokens))
            self._total += total_tokens
            logger.info("Quota guard: loaded baseline %d tokens", total_tokens)

    def reset(self) -> None:
        """Manually force the guard back to WITHIN_QUOTA and clear usage."""
        with self._lock:
            self._transition_to(QuotaState.WITHIN_QUOTA)
            self._entries.clear()
            self._total = 0
            logger.info("Quota guard: manually reset to WITHIN_QUOTA")

    def get_info(self) -> dict:
        """Snapshot of the guard's current state for diagnostics."""
        with self._lock:
            self._expire()
            pct = round(self._total / self._budget * 100, 1) if self._budget > 0 else 0
            return {
                "state": self._state.value,
                "window_usage_tokens": self._total,
                "budget_tokens": self._budget,
                "usage_percent": pct,
                "threshold_percent": self._threshold * 100,
            }

    def _expire(self) -> None:
        """Drop entries that have slid out of the time window."""
        horizon = time.monotonic() - self._window
        entries = self._entries
        while entries and entries[0][0] < horizon:
            self._total -= entries.popleft()[1]

    def _transition_to(self, new_state: QuotaState) -> None:
        """Switch state and apply the new state's entry actions."""
        self._state = new_state
        if new_state is QuotaState.QUOTA_EXCEEDED:
            # Start the probe clock so the first probe waits a full interval.
            self._last_probe = time.monotonic()
        elif new_state is QuotaState.WITHIN_QUOTA:
            self._cap_error_active = False
            self._effective_probe_interval = self._probe_interval
@@ -0,0 +1,159 @@
1
+ """速率限制信息解析 — 从 HTTP 响应头提取恢复时间."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import logging
6
+ import time
7
+ from dataclasses import dataclass
8
+ from datetime import datetime, timezone
9
+ from email.utils import parsedate_to_datetime
10
+ from typing import Any
11
+
12
logger = logging.getLogger(__name__)


@dataclass
class RateLimitInfo:
    """Rate-limit information extracted from a 429/403 response."""

    retry_after_seconds: float | None = None
    requests_reset_at: float | None = None  # monotonic timestamp
    tokens_reset_at: float | None = None  # monotonic timestamp
    is_cap_error: bool = False


def parse_rate_limit_headers(
    headers: Any,
    status_code: int,
    error_body: str | None = None,
) -> RateLimitInfo:
    """Parse rate-limit information from response headers and status code.

    Args:
        headers: httpx.Headers or a dict-like mapping of response headers.
        status_code: HTTP status code of the response.
        error_body: error response body text (used to detect cap errors).

    Returns:
        A populated RateLimitInfo (empty for statuses other than 429/403).
    """
    info = RateLimitInfo()

    if status_code not in (429, 403):
        return info

    # Detect usage-cap errors from the body text.
    if error_body:
        msg = error_body.lower()
        info.is_cap_error = any(
            p in msg for p in ("usage cap", "quota", "limit exceeded")
        )

    # Standard HTTP Retry-After header (delta-seconds or HTTP-date).
    retry_after = _get_header(headers, "retry-after")
    if retry_after:
        info.retry_after_seconds = _parse_retry_after(retry_after)

    # anthropic-ratelimit-requests-reset (ISO 8601 datetime).
    requests_reset = _get_header(headers, "anthropic-ratelimit-requests-reset")
    if requests_reset:
        info.requests_reset_at = _parse_reset_time(requests_reset)

    # anthropic-ratelimit-tokens-reset (ISO 8601 datetime).
    tokens_reset = _get_header(headers, "anthropic-ratelimit-tokens-reset")
    if tokens_reset:
        info.tokens_reset_at = _parse_reset_time(tokens_reset)

    return info


def compute_effective_retry_seconds(info: RateLimitInfo) -> float | None:
    """Compute the most conservative relative recovery wait, in seconds.

    Takes the maximum over all available signals and adds a 10% safety
    margin.  Returns None when no signal is present.
    """
    candidates: list[float] = []

    if info.retry_after_seconds is not None:
        candidates.append(info.retry_after_seconds * 1.1)

    now = time.monotonic()
    if info.requests_reset_at is not None:
        remaining = info.requests_reset_at - now
        if remaining > 0:
            candidates.append(remaining * 1.1)

    if info.tokens_reset_at is not None:
        remaining = info.tokens_reset_at - now
        if remaining > 0:
            candidates.append(remaining * 1.1)

    return max(candidates) if candidates else None


def compute_rate_limit_deadline(info: RateLimitInfo) -> float | None:
    """Compute the most conservative recovery deadline as a monotonic timestamp.

    Complements compute_effective_retry_seconds():
      - that function returns relative seconds (for CircuitBreaker backoff),
      - this one returns an absolute monotonic timestamp (for VendorTier
        gating).

    Takes the maximum over all available time signals plus a 10% margin.
    """
    candidates: list[float] = []
    now = time.monotonic()

    if info.retry_after_seconds is not None:
        candidates.append(now + info.retry_after_seconds * 1.1)

    if info.requests_reset_at is not None and info.requests_reset_at > now:
        remaining = info.requests_reset_at - now
        candidates.append(now + remaining * 1.1)

    if info.tokens_reset_at is not None and info.tokens_reset_at > now:
        remaining = info.tokens_reset_at - now
        candidates.append(now + remaining * 1.1)

    return max(candidates) if candidates else None


def _get_header(headers: Any, name: str) -> str | None:
    """Fetch a header value (httpx.Headers or dict), case-insensitively.

    Bug fix: plain dicts also expose ``.get``, so the previous
    ``hasattr(headers, "get")`` branch shadowed the case-insensitive dict
    fallback, making dict lookups silently case-sensitive.  Dicts are now
    handled first with a case-insensitive scan; empty values normalise to
    None in both paths.
    """
    if headers is None:
        return None
    if isinstance(headers, dict):
        target = name.lower()
        for key, value in headers.items():
            if key.lower() == target:
                return value if value else None
        return None
    if hasattr(headers, "get"):
        # httpx.Headers.get is already case-insensitive.
        val = headers.get(name)
        return val if val else None
    return None


def _parse_retry_after(value: str) -> float | None:
    """Parse a Retry-After header value (delta-seconds or HTTP-date)."""
    try:
        return float(value)
    except ValueError:
        pass
    try:
        dt = parsedate_to_datetime(value)
        return max(0, (dt - datetime.now(timezone.utc)).total_seconds())
    except (ValueError, TypeError):
        logger.warning("Cannot parse retry-after header: %s", value)
        return None


def _parse_reset_time(value: str) -> float | None:
    """Convert an ISO 8601 datetime into a monotonic timestamp."""
    try:
        dt = datetime.fromisoformat(value.replace("Z", "+00:00"))
        if dt.tzinfo is None:
            # Naive timestamps are assumed to be UTC.
            dt = dt.replace(tzinfo=timezone.utc)
        remaining = (dt - datetime.now(timezone.utc)).total_seconds()
        return time.monotonic() + max(0, remaining)
    except (ValueError, TypeError):
        logger.warning("Cannot parse reset time: %s", value)
        return None
@@ -0,0 +1,82 @@
1
+ """传输层重试策略 — 指数退避与 Full Jitter.
2
+
3
+ 与 Circuit Breaker 正交互:
4
+ - Retry 处理瞬态网络抖动(秒级恢复)
5
+ - Circuit Breaker 处理持续故障(分钟级恢复)
6
+ - Retry 失败仅向 Circuit Breaker 贡献 1 次失败计数
7
+
8
+ 参考:
9
+ [1] M. Nygard, "Release It!," Pragmatic Bookshelf, 2nd ed., 2018.
10
+ [2] AWS Architecture Center, "Retry Pattern," 2022.
11
+ """
12
+
13
+ from __future__ import annotations
14
+
15
+ import logging
16
+ import random
17
+ from dataclasses import dataclass
18
+
19
+ import httpx
20
+
21
+ logger = logging.getLogger(__name__)
22
+
23
+
24
@dataclass
class RetryConfig:
    """Runtime configuration for transport-level retries."""

    max_retries: int = 2  # maximum retry count (0 disables retries)
    initial_delay_ms: int = 500  # first backoff delay, milliseconds
    max_delay_ms: int = 5000  # backoff ceiling, milliseconds
    backoff_multiplier: float = 2.0  # exponential growth factor
    jitter: bool = True  # apply Full Jitter randomisation

    @property
    def enabled(self) -> bool:
        """Whether retrying is turned on at all."""
        return self.max_retries > 0

    @property
    def max_attempts(self) -> int:
        """Total attempts, counting the initial request."""
        return self.max_retries + 1
41
+
42
+
43
def is_retryable_error(exc: Exception) -> bool:
    """Decide whether *exc* is worth a transport-level retry.

    Retryable:
      - httpx.TimeoutException (transient timeout)
      - httpx.ConnectError (network connection failure)
      - httpx.HTTPStatusError with a 5xx status (transient server error)

    Not retryable:
      - httpx.HTTPStatusError with a 4xx status (client error)
      - TokenAcquireError (auth-layer error)
      - anything else
    """
    if isinstance(exc, (httpx.TimeoutException, httpx.ConnectError)):
        return True
    return (
        isinstance(exc, httpx.HTTPStatusError)
        and exc.response.status_code >= 500
    )
63
+
64
+
65
def is_retryable_status(status_code: int) -> bool:
    """Return True for HTTP status codes that merit a retry (any 5xx)."""
    return 500 <= status_code
68
+
69
+
70
def calculate_delay(attempt: int, cfg: RetryConfig) -> float:
    """Delay in milliseconds before retry number *attempt*.

    Exponential backoff capped at ``cfg.max_delay_ms``, with optional
    Full Jitter: delay = random(0, min(initial * backoff**attempt, max)).
    Reference: AWS "Exponential Backoff And Jitter" (Marc Brooker, 2015).
    """
    capped = min(
        cfg.initial_delay_ms * cfg.backoff_multiplier ** attempt,
        cfg.max_delay_ms,
    )
    return random.uniform(0, capped) if cfg.jitter else capped
@@ -0,0 +1,84 @@
1
+ """请求路由器 — N-tier 链式路由与自动故障转移(薄代理层).
2
+
3
+ 核心路由逻辑已正交分解至:
4
+ - :mod:`.executor` — 统一的 tier 迭代门控引擎 (_RouteExecutor)
5
+ - :mod:`.usage_recorder` — 用量记录、定价日志与证据构建 (UsageRecorder)
6
+ - :mod:`.session_manager`— 兼容性会话生命周期管理 (RouteSessionManager)
7
+
8
+ 本文件保留 ``RequestRouter`` 公开接口,内部委托给上述模块。
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ from typing import TYPE_CHECKING, Any, AsyncIterator
14
+
15
+ if TYPE_CHECKING:
16
+ from ..pricing import PricingTable
17
+
18
+ from .executor import _RouteExecutor
19
+ from .session_manager import RouteSessionManager
20
+ from .tier import VendorTier
21
+
22
# Backward-compatibility alias for the pre-rename public name
BackendTier = VendorTier
24
+ from .usage_recorder import UsageRecorder
25
+ from ..compat.session_store import CompatSessionStore
26
+ from ..logging.db import TokenLogger
27
+
28
+
29
class RequestRouter:
    """Routes requests to vendor tiers with priority-ordered failover.

    Thin facade: tier iteration, usage recording and compat-session
    handling are delegated to _RouteExecutor, UsageRecorder and
    RouteSessionManager respectively.
    """

    def __init__(
        self,
        tiers: list[VendorTier],
        token_logger: TokenLogger | None = None,
        reauth_coordinator: Any | None = None,
        compat_session_store: CompatSessionStore | None = None,
    ) -> None:
        if not tiers:
            raise ValueError("至少需要一个供应商层级")
        self._tiers = tiers

        # Orthogonally decomposed collaborators.
        self._recorder = UsageRecorder(token_logger=token_logger)
        self._session_mgr = RouteSessionManager(compat_session_store)
        self._executor = _RouteExecutor(
            tiers=tiers,
            usage_recorder=self._recorder,
            session_manager=self._session_mgr,
            reauth_coordinator=reauth_coordinator,
        )

    def set_pricing_table(self, table: PricingTable) -> None:
        """Inject the PricingTable (called by the app lifespan at startup)."""
        self._recorder.set_pricing_table(table)

    @property
    def tiers(self) -> list[VendorTier]:
        """Configured vendor tiers, highest priority first."""
        return self._tiers

    # -- public routing API (delegates to _RouteExecutor) ---------------

    async def route_stream(
        self,
        body: dict[str, Any],
        headers: dict[str, str],
    ) -> AsyncIterator[tuple[bytes, str]]:
        """Route a streaming request down the tier chain by priority."""
        stream = self._executor.execute_stream(body, headers)
        async for chunk, vendor_name in stream:
            yield chunk, vendor_name

    async def route_message(
        self,
        body: dict[str, Any],
        headers: dict[str, str],
    ) -> Any:
        """Route a non-streaming request down the tier chain by priority."""
        return await self._executor.execute_message(body, headers)

    # -- lifecycle -------------------------------------------------------

    async def close(self) -> None:
        """Close every tier's vendor client."""
        for entry in self._tiers:
            await entry.vendor.close()
@@ -0,0 +1,62 @@
1
+ """路由会话管理器 — 封装兼容性会话的创建、上下文应用与持久化."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Any
6
+
7
+ from ..compat.canonical import (
8
+ CompatibilityStatus,
9
+ CompatibilityTrace,
10
+ )
11
+ from ..compat.session_store import CompatSessionRecord, CompatSessionStore
12
+ from .tier import VendorTier
13
+
14
+
15
class RouteSessionManager:
    """Manages the compatibility-session lifecycle of one routed request."""

    def __init__(self, compat_session_store: CompatSessionStore | None = None) -> None:
        self._store = compat_session_store

    async def get_or_create_record(self, session_key: str, trace_id: str) -> CompatSessionRecord | None:
        """Load the stored session record, or start a fresh one.

        Returns None when no session store is configured.
        """
        if self._store is None:
            return None
        existing = await self._store.get(session_key)
        if existing is None:
            return CompatSessionRecord(session_key=session_key, trace_id=trace_id)
        return existing

    def apply_compat_context(
        self,
        *,
        tier: VendorTier,
        canonical_request: Any,
        decision: Any,
        session_record: CompatSessionRecord | None,
    ) -> None:
        """Build a CompatibilityTrace for this tier and hand it to the vendor."""
        # Wire protocol spoken by each known vendor tier.
        protocol_by_vendor = {
            "copilot": "openai_chat_completions",
            "antigravity": "gemini_generate_content",
            "zhipu": "anthropic_messages",
            "anthropic": "anthropic_messages",
        }
        compat_trace = CompatibilityTrace(
            trace_id=canonical_request.trace_id,
            vendor=tier.name,
            session_key=canonical_request.session_key,
            provider_protocol=protocol_by_vendor.get(tier.name, "unknown"),
            compat_mode=decision.status.value,
            simulation_actions=list(decision.simulation_actions),
            unsupported_semantics=list(decision.unsupported_semantics),
            session_state_hits=1 if session_record else 0,
            request_adaptations=[],
        )
        tier.vendor.set_compat_context(trace=compat_trace, session_record=session_record)

    async def persist_session(
        self,
        trace: CompatibilityTrace | None,
        session_record: CompatSessionRecord | None,
    ) -> None:
        """Fold the trace outcome into the record and upsert it in the store."""
        if self._store is None or trace is None or session_record is None:
            return
        states = dict(session_record.provider_state)
        states[trace.vendor] = {
            "compat_mode": trace.compat_mode,
            "simulation_actions": trace.simulation_actions,
            "unsupported_semantics": trace.unsupported_semantics,
            "trace_id": trace.trace_id,
        }
        session_record.trace_id = trace.trace_id
        session_record.provider_state = states
        await self._store.upsert(session_record)