coding-proxy 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- coding/__init__.py +0 -0
- coding/proxy/__init__.py +3 -0
- coding/proxy/__main__.py +5 -0
- coding/proxy/auth/__init__.py +13 -0
- coding/proxy/auth/providers/__init__.py +6 -0
- coding/proxy/auth/providers/base.py +35 -0
- coding/proxy/auth/providers/github.py +133 -0
- coding/proxy/auth/providers/google.py +237 -0
- coding/proxy/auth/runtime.py +122 -0
- coding/proxy/auth/store.py +74 -0
- coding/proxy/cli/__init__.py +151 -0
- coding/proxy/cli/auth_commands.py +224 -0
- coding/proxy/compat/__init__.py +30 -0
- coding/proxy/compat/canonical.py +193 -0
- coding/proxy/compat/session_store.py +137 -0
- coding/proxy/config/__init__.py +6 -0
- coding/proxy/config/auth_schema.py +24 -0
- coding/proxy/config/loader.py +139 -0
- coding/proxy/config/resiliency.py +46 -0
- coding/proxy/config/routing.py +279 -0
- coding/proxy/config/schema.py +280 -0
- coding/proxy/config/server.py +23 -0
- coding/proxy/config/vendors.py +53 -0
- coding/proxy/convert/__init__.py +14 -0
- coding/proxy/convert/anthropic_to_gemini.py +352 -0
- coding/proxy/convert/anthropic_to_openai.py +352 -0
- coding/proxy/convert/gemini_sse_adapter.py +169 -0
- coding/proxy/convert/gemini_to_anthropic.py +98 -0
- coding/proxy/convert/openai_to_anthropic.py +88 -0
- coding/proxy/logging/__init__.py +49 -0
- coding/proxy/logging/db.py +308 -0
- coding/proxy/logging/stats.py +129 -0
- coding/proxy/model/__init__.py +93 -0
- coding/proxy/model/auth.py +32 -0
- coding/proxy/model/compat.py +153 -0
- coding/proxy/model/constants.py +21 -0
- coding/proxy/model/pricing.py +70 -0
- coding/proxy/model/token.py +64 -0
- coding/proxy/model/vendor.py +218 -0
- coding/proxy/pricing.py +100 -0
- coding/proxy/routing/__init__.py +47 -0
- coding/proxy/routing/circuit_breaker.py +152 -0
- coding/proxy/routing/error_classifier.py +67 -0
- coding/proxy/routing/executor.py +453 -0
- coding/proxy/routing/model_mapper.py +90 -0
- coding/proxy/routing/quota_guard.py +169 -0
- coding/proxy/routing/rate_limit.py +159 -0
- coding/proxy/routing/retry.py +82 -0
- coding/proxy/routing/router.py +84 -0
- coding/proxy/routing/session_manager.py +62 -0
- coding/proxy/routing/tier.py +171 -0
- coding/proxy/routing/usage_parser.py +193 -0
- coding/proxy/routing/usage_recorder.py +131 -0
- coding/proxy/server/__init__.py +1 -0
- coding/proxy/server/app.py +142 -0
- coding/proxy/server/factory.py +175 -0
- coding/proxy/server/request_normalizer.py +139 -0
- coding/proxy/server/responses.py +74 -0
- coding/proxy/server/routes.py +264 -0
- coding/proxy/streaming/__init__.py +1 -0
- coding/proxy/streaming/anthropic_compat.py +484 -0
- coding/proxy/vendors/__init__.py +29 -0
- coding/proxy/vendors/anthropic.py +44 -0
- coding/proxy/vendors/antigravity.py +328 -0
- coding/proxy/vendors/base.py +353 -0
- coding/proxy/vendors/copilot.py +702 -0
- coding/proxy/vendors/copilot_models.py +438 -0
- coding/proxy/vendors/copilot_token_manager.py +167 -0
- coding/proxy/vendors/copilot_urls.py +16 -0
- coding/proxy/vendors/mixins.py +71 -0
- coding/proxy/vendors/token_manager.py +128 -0
- coding/proxy/vendors/zhipu.py +243 -0
- coding_proxy-0.1.0.dist-info/METADATA +184 -0
- coding_proxy-0.1.0.dist-info/RECORD +77 -0
- coding_proxy-0.1.0.dist-info/WHEEL +4 -0
- coding_proxy-0.1.0.dist-info/entry_points.txt +2 -0
- coding_proxy-0.1.0.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,453 @@
|
|
|
1
|
+
"""路由执行器 — 统一的 tier 迭代门控引擎.
|
|
2
|
+
|
|
3
|
+
封装 ``route_stream`` / ``route_message`` 共享的 tier 循环、
|
|
4
|
+
门控判断与错误处理逻辑,消除两个路由方法间的重复代码。
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
import json
|
|
10
|
+
import logging
|
|
11
|
+
import time
|
|
12
|
+
from typing import Any, AsyncIterator
|
|
13
|
+
|
|
14
|
+
import httpx
|
|
15
|
+
|
|
16
|
+
from .error_classifier import (
|
|
17
|
+
build_request_capabilities,
|
|
18
|
+
extract_error_payload_from_http_status,
|
|
19
|
+
is_semantic_rejection,
|
|
20
|
+
)
|
|
21
|
+
from .rate_limit import (
|
|
22
|
+
compute_effective_retry_seconds,
|
|
23
|
+
compute_rate_limit_deadline,
|
|
24
|
+
parse_rate_limit_headers,
|
|
25
|
+
)
|
|
26
|
+
from .session_manager import RouteSessionManager
|
|
27
|
+
from .tier import VendorTier
|
|
28
|
+
from .usage_parser import (
|
|
29
|
+
build_usage_evidence_records,
|
|
30
|
+
has_missing_input_usage_signals,
|
|
31
|
+
parse_usage_from_chunk,
|
|
32
|
+
)
|
|
33
|
+
from .usage_recorder import UsageRecorder
|
|
34
|
+
from ..vendors.base import VendorResponse, NoCompatibleVendorError, RequestCapabilities, UsageInfo
|
|
35
|
+
from ..vendors.token_manager import TokenAcquireError
|
|
36
|
+
|
|
37
|
+
from ..compat.canonical import CompatibilityStatus, build_canonical_request

# Backward-compatible aliases for the legacy "backend" naming of vendor types.
BackendResponse = VendorResponse
NoCompatibleBackendError = NoCompatibleVendorError

logger = logging.getLogger(__name__)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def _log_http_error_detail(tier_name: str, exc: Exception, *, is_stream: bool = False) -> None:
    """Log detailed diagnostics for a failed upstream HTTP call.

    Emits one multi-line warning with the exception type and, for
    ``httpx.HTTPStatusError``, the status code, a response-body preview and
    (when the body is a JSON object whose ``error`` is a dict) the error
    ``type`` / ``message``. Other exceptions get a ``message=`` preview.

    This helper is called from inside ``except`` handlers of the routing loop,
    so it must never raise itself: body access on a streaming response that
    was never read (``httpx.ResponseNotRead``) and non-string error messages
    are handled instead of propagating.

    Args:
        tier_name: Name of the tier whose call failed.
        exc: The exception raised by the vendor call.
        is_stream: Whether the failed call was streaming (only affects the label).
    """
    detail_parts = [f"Tier {tier_name} {'stream' if is_stream else 'message'} failed:"]
    detail_parts.append(f" exc_type={type(exc).__name__}")
    if isinstance(exc, httpx.HTTPStatusError) and exc.response is not None:
        resp = exc.response
        detail_parts.append(f" status={resp.status_code}")
        try:
            content = resp.content
        except httpx.StreamError:
            # Streaming responses raise ResponseNotRead (a StreamError) when the
            # body was never read; there is nothing to preview or parse.
            content = None
            body_preview = "(stream not read)"
        else:
            body_preview = (resp.text[:300] or "(empty)") if content else "(no content)"
        detail_parts.append(f" response_body={body_preview}")
        # Best-effort extraction of error type / message from a JSON body.
        payload = None
        if content:
            try:
                payload = resp.json()
            except Exception:
                payload = None
        if isinstance(payload, dict):
            err = payload.get("error", {})
            if isinstance(err, dict):
                detail_parts.append(f" error_type={err.get('type', 'N/A')}")
                # str() guards against null / non-string messages before slicing.
                detail_parts.append(f" error_msg={str(err.get('message', 'N/A'))[:200]}")
    else:
        detail_parts.append(f" message={str(exc)[:300]}")
    logger.warning("\n".join(detail_parts))
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def _has_tool_results(body: dict[str, Any]) -> bool:
    """Return whether any request message contains a ``tool_result`` content block.

    Used by diagnostic logging to flag requests that are inside a tool-execution
    loop, which helps pinpoint vendors that mishandle ``tool_result`` content
    (e.g. Zhipu returning 500).

    This runs on the error-logging path, so malformed shapes are tolerated
    rather than raising: a missing/``None``/non-sequence ``messages``,
    non-dict message entries and non-list ``content`` are simply ignored.

    Args:
        body: Request body dict (Anthropic-style ``messages`` list expected).

    Returns:
        ``True`` if some message's ``content`` list holds a dict whose ``type``
        is ``"tool_result"``; otherwise ``False``.
    """
    messages = body.get("messages") or []
    if not isinstance(messages, (list, tuple)):
        return False
    for msg in messages:
        if not isinstance(msg, dict):
            continue
        content = msg.get("content")
        if isinstance(content, list) and any(
            isinstance(block, dict) and block.get("type") == "tool_result"
            for block in content
        ):
            return True
    return False
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
def _log_vendor_response_error(
    tier_name: str,
    resp: VendorResponse,
    body: dict[str, Any],
    *,
    is_stream: bool = False,
) -> None:
    """Log diagnostics for an error ``VendorResponse`` (status >= 400).

    Complements :func:`_log_http_error_detail`: when ``send_message()`` returns a
    ``VendorResponse`` with ``status_code >= 400`` instead of raising an httpx
    exception, the executor's exception handlers never run, so this function
    provides the equivalent diagnostic log.

    Typical case: thin pass-through vendors such as Zhipu wrap an upstream 500
    into a ``VendorResponse`` instead of raising.

    Args:
        tier_name: Name of the tier that produced the response.
        resp: The error response (status, error type/message, raw body).
        body: The original request body, used for model / tools context.
        is_stream: Whether the request was streaming (only affects the label).
    """
    mode = "stream" if is_stream else "message"
    detail_parts = [
        f"Tier {tier_name} {mode} vendor error response:",
        f" status={resp.status_code}",
        f" error_type={resp.error_type or 'N/A'}",
        f" error_msg={str(resp.error_message or 'N/A')[:300]}",
        # Request context (model / tools / tool results).
        f" model={body.get('model', 'unknown')}",
        f" has_tools={bool(body.get('tools'))}",
        f" has_tool_results={_has_tool_results(body)}",
    ]
    # Response body preview (first 500 characters).
    raw_body = resp.raw_body
    if raw_body:
        if isinstance(raw_body, str):
            raw_text = raw_body[:500]
        elif isinstance(raw_body, (bytes, bytearray)):
            # 500 decoded characters consume at most 2000 UTF-8 bytes, so only
            # that prefix needs decoding; errors="replace" never raises.
            raw_text = bytes(raw_body[:2000]).decode("utf-8", errors="replace")[:500]
        else:
            raw_text = f"(binary, {len(raw_body)} bytes)"
        detail_parts.append(f" response_body_preview={raw_text}")
    logger.warning("\n".join(detail_parts))
|
|
124
|
+
|
|
125
|
+
# Maps tier.name to the upstream wire-protocol label that is attached to
# token-usage log lines (passed as ``vendor_label`` to the usage parser).
_VENDOR_PROTOCOL_LABEL_MAP: dict[str, str] = dict(
    anthropic="Anthropic",
    zhipu="Anthropic",
    copilot="OpenAI",
    antigravity="Gemini",
)
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
class _RouteExecutor:
    """Unified tier-iteration gating engine.

    Responsibilities:
    - Walk tiers in priority order, applying capability gating and health checks.
    - Run the streaming / non-streaming vendor call for each eligible tier.
    - Handle TokenAcquireError / HTTP errors / semantic rejections uniformly.
    - Delegate usage recording to ``UsageRecorder`` once a tier finishes.
    """

    def __init__(
        self,
        tiers: list[VendorTier],
        usage_recorder: UsageRecorder,
        session_manager: RouteSessionManager,
        reauth_coordinator: Any | None = None,
    ) -> None:
        self._tiers = tiers
        self._recorder = usage_recorder
        self._session_mgr = session_manager
        self._reauth_coordinator = reauth_coordinator

        # Tier name -> OAuth provider name; used to request re-authentication
        # when a tier's TokenAcquireError says it is needed.
        self._tier_provider_map: dict[str, str] = {
            "copilot": "github",
            "antigravity": "google",
        }

    # ── Public entry points ──────────────────────────────────────

    async def execute_stream(
        self,
        body: dict[str, Any],
        headers: dict[str, str],
    ) -> AsyncIterator[tuple[bytes, str]]:
        """Route a streaming request, trying tiers in priority order.

        Yields ``(chunk, tier_name)`` pairs from the first tier that serves the
        request. Failover to the next tier only happens before any chunk has
        been yielded: once a tier has started streaming to the client, a later
        failure is recorded and re-raised, because appending a second vendor's
        stream to a partial one would corrupt the client-visible response.

        Raises:
            Exception: the error from the last attempted tier, or from a tier
                that failed after streaming had started (TokenAcquireError,
                httpx errors, or any unexpected error).
            NoCompatibleVendorError: every tier was skipped by gating and no
                upstream error was captured.
        """
        last_idx = len(self._tiers) - 1
        last_exc: Exception | None = None
        failed_tier_name: str | None = None
        request_caps = build_request_capabilities(body)
        canonical_request = build_canonical_request(body, headers)
        session_record = await self._session_mgr.get_or_create_record(
            canonical_request.session_key, canonical_request.trace_id,
        )
        incompatible_reasons: list[str] = []

        for i, tier in enumerate(self._tiers):
            is_last = i == last_idx

            gate = await self._try_gate_tier(tier, is_last, request_caps, canonical_request, session_record, incompatible_reasons)
            if gate == "skip":
                continue

            start = time.monotonic()
            usage: dict[str, Any] = {}
            # Set once the first chunk reaches the client; from then on a
            # failure must not fail over to another tier.
            streamed = False

            try:
                async for chunk in tier.vendor.send_message_stream(body, headers):
                    parse_usage_from_chunk(
                        chunk, usage,
                        vendor_label=_VENDOR_PROTOCOL_LABEL_MAP.get(tier.name),
                    )
                    streamed = True
                    yield chunk, tier.name

                info = self._recorder.build_usage_info(usage)
                if has_missing_input_usage_signals(info):
                    logger.warning(
                        "Stream completed with missing input usage signals: output_tokens=%d, "
                        "cache_creation_tokens=%d, cache_read_tokens=%d, tier=%s, usage_data=%r",
                        info.output_tokens,
                        info.cache_creation_tokens,
                        info.cache_read_tokens,
                        tier.name,
                        usage,
                    )
                tier.record_success(info.input_tokens + info.output_tokens)
                duration = int((time.monotonic() - start) * 1000)
                model = body.get("model", "unknown")
                model_served = usage.get("model_served") or tier.vendor.map_model(model)
                self._recorder.log_model_call(vendor=tier.name, model_requested=model, model_served=model_served, duration_ms=duration, usage=info)
                await self._session_mgr.persist_session(tier.vendor.get_compat_trace(), session_record)
                await self._recorder.record(
                    tier.name, model, model_served, info, duration, True,
                    failed_tier_name is not None, failed_tier_name,
                    evidence_records=build_usage_evidence_records(usage, vendor=tier.name, model_served=model_served, request_id=info.request_id),
                )
                return

            except TokenAcquireError as exc:
                failed_tier_name, last_exc = await self._handle_token_error(tier, exc, is_last, failed_tier_name)
                if is_last or streamed:
                    raise

            except (httpx.HTTPStatusError, httpx.TimeoutException, httpx.ConnectError, httpx.ReadError) as exc:
                _log_http_error_detail(tier.name, exc, is_stream=True)
                should_continue, failed_tier_name, last_exc = await self._handle_http_error(tier, exc, is_last, failed_tier_name, last_exc, is_stream=True)
                if streamed:
                    # Mid-stream failure: already counted above, but the client
                    # holds a partial response, so failover is not possible.
                    raise
                if should_continue:
                    continue
                if is_last:
                    raise
            except Exception as exc:
                logger.error(
                    "Tier %s stream unexpected error: %s: %s",
                    tier.name, type(exc).__name__, exc,
                    exc_info=True,
                )
                tier.record_failure()
                failed_tier_name = tier.name
                # Keep the error so that, if every remaining tier is skipped,
                # the real cause is raised instead of NoCompatibleVendorError.
                last_exc = exc
                if is_last or streamed:
                    raise

        if last_exc is not None:
            raise last_exc
        raise NoCompatibleVendorError("当前请求包含仅客户端/MCP 可安全承接的能力,未找到兼容供应商", reasons=incompatible_reasons)

    async def execute_message(
        self,
        body: dict[str, Any],
        headers: dict[str, str],
    ) -> VendorResponse:
        """Route a non-streaming request, trying tiers in priority order.

        Returns the first successful (< 400) response. A semantic rejection or
        a failover-eligible error status on a non-last tier moves on to the
        next tier. Otherwise the error response is logged, recorded and
        returned as-is.

        Raises:
            TokenAcquireError / httpx transport errors / other exceptions:
                re-raised when they occur on the last tier.
            NoCompatibleVendorError: no tier produced a response and at least
                one tier was rejected for incompatible capabilities.
            RuntimeError: no tier produced a response for any other reason.
                The last captured upstream error, if any, is chained as
                ``__cause__``.
        """
        last_idx = len(self._tiers) - 1
        failed_tier_name: str | None = None
        last_exc: Exception | None = None
        request_caps = build_request_capabilities(body)
        canonical_request = build_canonical_request(body, headers)
        session_record = await self._session_mgr.get_or_create_record(
            canonical_request.session_key, canonical_request.trace_id,
        )
        incompatible_reasons: list[str] = []

        for i, tier in enumerate(self._tiers):
            is_last = i == last_idx

            gate = await self._try_gate_tier(tier, is_last, request_caps, canonical_request, session_record, incompatible_reasons)
            if gate == "skip":
                continue

            # Per-tier timer (as in execute_stream): the recorded duration is
            # that of the tier which produced the response.
            start = time.monotonic()
            try:
                resp = await tier.vendor.send_message(body, headers)

                if resp.status_code < 400:
                    await self._record_message_outcome(
                        tier, body, resp, start, failed_tier_name, session_record, persist_session=True,
                    )
                    return resp

                # Non-streaming semantic rejection / failover decisions come from
                # the response object rather than from an exception.
                if not is_last and is_semantic_rejection(status_code=resp.status_code, error_type=resp.error_type, error_message=resp.error_message):
                    logger.warning("Tier %s semantic rejection (%s), trying next tier without recording failure", tier.name, resp.error_type or resp.status_code)
                    failed_tier_name = tier.name
                    continue

                if not is_last and tier.vendor.should_trigger_failover(resp.status_code, {"error": {"type": resp.error_type, "message": resp.error_message}}):
                    logger.warning("Tier %s error %d, failing over", tier.name, resp.status_code)
                    rl_info = parse_rate_limit_headers(resp.response_headers, resp.status_code, resp.error_message)
                    tier.record_failure(
                        is_cap_error=self._is_cap_error(resp) or rl_info.is_cap_error,
                        retry_after_seconds=compute_effective_retry_seconds(rl_info),
                        rate_limit_deadline=compute_rate_limit_deadline(rl_info),
                    )
                    failed_tier_name = tier.name
                    continue

                # Last tier, or an error that must not fail over: log, record
                # and hand the original response back to the caller.
                _log_vendor_response_error(tier.name, resp, body, is_stream=False)
                await self._record_message_outcome(
                    tier, body, resp, start, failed_tier_name, session_record, persist_session=False,
                )
                return resp

            except TokenAcquireError as exc:
                failed_tier_name, last_exc = await self._handle_token_error(tier, exc, is_last, failed_tier_name)
                if is_last:
                    raise
                continue

            except (httpx.TimeoutException, httpx.ConnectError, httpx.ReadError) as exc:
                _log_http_error_detail(tier.name, exc, is_stream=False)
                tier.record_failure()
                failed_tier_name = tier.name
                last_exc = exc
                if is_last:
                    raise
                continue
            except Exception as exc:
                logger.error(
                    "Tier %s message unexpected error: %s: %s",
                    tier.name, type(exc).__name__, exc,
                    exc_info=True,
                )
                tier.record_failure()
                failed_tier_name = tier.name
                last_exc = exc
                if not is_last:
                    continue
                raise

        if incompatible_reasons:
            raise NoCompatibleVendorError("当前请求包含仅客户端/MCP 可安全承接的能力,未找到兼容供应商", reasons=incompatible_reasons) from last_exc
        raise RuntimeError("无可用供应商层级") from last_exc

    # ── Gating and error handling ──────────────────────────────────────

    async def _record_message_outcome(
        self,
        tier: VendorTier,
        body: dict[str, Any],
        resp: VendorResponse,
        start: float,
        failed_tier_name: str | None,
        session_record: Any,
        *,
        persist_session: bool,
    ) -> None:
        """Log and record the final response of a non-streaming request.

        Computes the duration since ``start`` and resolves the served model
        (response value, else the vendor's mapping of the requested model).
        It then logs the model call and writes the usage record, with success
        meaning ``status_code < 400``. The compat session is persisted only
        when ``persist_session`` is true (the success path).
        """
        duration = int((time.monotonic() - start) * 1000)
        model = body.get("model", "unknown")
        model_served = resp.model_served or tier.vendor.map_model(model)
        self._recorder.log_model_call(vendor=tier.name, model_requested=model, model_served=model_served, duration_ms=duration, usage=resp.usage)
        if persist_session:
            await self._session_mgr.persist_session(tier.vendor.get_compat_trace(), session_record)
        await self._recorder.record(
            tier.name, model, model_served, resp.usage, duration, resp.status_code < 400,
            failed_tier_name is not None, failed_tier_name,
            evidence_records=self._recorder.build_nonstream_evidence_records(vendor=tier.name, model_served=model_served, usage=resp.usage),
        )

    async def _try_gate_tier(
        self,
        tier: VendorTier,
        is_last: bool,
        request_caps: RequestCapabilities,
        canonical_request: Any,
        session_record: Any,
        incompatible_reasons: list[str],
    ) -> str:
        """Run capability gating and compatibility checks for one tier.

        Incompatible tiers append a ``"<tier>:<reasons>"`` entry to
        ``incompatible_reasons``. Non-last tiers are additionally gated by
        ``can_execute_with_health_check()``, while the last tier is gated only
        by ``can_execute()``.

        NOTE(review): the compat context is applied to the session before the
        health/executability check, so a tier skipped at that stage may still
        have updated session state — confirm this is intended.

        Returns:
            "eligible" — passed every gate; the request may be executed.
            "skip" — failed a gate; this tier must be skipped.
        """
        supported, reasons = tier.vendor.supports_request(request_caps)
        if not supported:
            reason_text = ",".join(sorted({r.value for r in reasons}))
            incompatible_reasons.append(f"{tier.name}:{reason_text}")
            logger.info("Tier %s skipped due to incompatible capabilities: %s", tier.name, reason_text)
            return "skip"

        decision = tier.vendor.make_compatibility_decision(canonical_request)
        if decision.status is CompatibilityStatus.UNSAFE:
            reason_text = ",".join(sorted(decision.unsupported_semantics))
            incompatible_reasons.append(f"{tier.name}:{reason_text}")
            logger.info("Tier %s skipped due to compatibility decision: %s", tier.name, reason_text)
            return "skip"

        self._session_mgr.apply_compat_context(
            tier=tier, canonical_request=canonical_request, decision=decision, session_record=session_record,
        )

        # Non-last tiers use the health-check gate; the last tier only checks can_execute.
        if not is_last:
            if not await tier.can_execute_with_health_check():
                return "skip"
        elif not tier.can_execute():
            return "skip"

        return "eligible"

    async def _handle_token_error(
        self,
        tier: VendorTier,
        exc: TokenAcquireError,
        is_last: bool,
        failed_tier_name: str | None,
    ) -> tuple[str | None, Exception]:
        """Shared TokenAcquireError handling.

        Logs a warning and records a tier failure. When ``exc.needs_reauth`` is
        set, a reauth coordinator is configured and the tier maps to an OAuth
        provider, it asks the coordinator to re-authenticate that provider.

        Returns:
            ``(tier.name, exc)`` — the new failed tier name and last exception.
        """
        logger.warning("Tier %s credential expired: %s", tier.name, exc)
        tier.record_failure()
        if exc.needs_reauth and self._reauth_coordinator:
            provider = self._tier_provider_map.get(tier.name)
            if provider:
                await self._reauth_coordinator.request_reauth(provider)
        return tier.name, exc

    async def _handle_http_error(
        self,
        tier: VendorTier,
        exc: Exception,
        is_last: bool,
        failed_tier_name: str | None,
        last_exc: Exception | None,
        *,
        is_stream: bool = False,
    ) -> tuple[bool, str | None, Exception | None]:
        """Shared HTTP-error handling for the streaming path.

        For ``httpx.HTTPStatusError``, a semantic rejection on a non-last tier
        asks the caller to try the next tier without recording a failure.
        Otherwise the failure is recorded with rate-limit information parsed
        from the response headers and body. Transport errors only record a
        plain failure. ``failed_tier_name``, ``last_exc`` and ``is_stream`` are
        accepted for interface compatibility but not used.

        Returns:
            (should_continue, failed_tier_name, last_exc)
        """
        if not (isinstance(exc, httpx.HTTPStatusError) and exc.response is not None):
            tier.record_failure()
            return False, tier.name, exc

        response = exc.response
        payload = extract_error_payload_from_http_status(exc)
        error = payload.get("error", {}) if isinstance(payload, dict) else {}
        if not isinstance(error, dict):
            error = {}
        semantic_rejection = is_semantic_rejection(
            status_code=response.status_code,
            error_type=error.get("type"),
            error_message=error.get("message"),
        )
        if semantic_rejection and not is_last:
            logger.warning("Tier %s semantic rejection, trying next tier without recording failure", tier.name)
            return True, tier.name, exc

        rl_info = parse_rate_limit_headers(response.headers, response.status_code, self._response_text_preview(response, 500))
        tier.record_failure(
            is_cap_error=rl_info.is_cap_error,
            retry_after_seconds=compute_effective_retry_seconds(rl_info),
            rate_limit_deadline=compute_rate_limit_deadline(rl_info),
        )
        return False, tier.name, exc

    @staticmethod
    def _response_text_preview(response: httpx.Response, limit: int) -> str | None:
        """Return up to ``limit`` characters of the response body, or ``None``.

        ``None`` is returned for an empty body and for a streaming response
        whose body was never read (``httpx.ResponseNotRead``, a
        ``StreamError``), so callers inside ``except`` handlers cannot fail.
        """
        try:
            text = response.text
        except httpx.StreamError:
            return None
        return text[:limit] if text else None

    @staticmethod
    def _is_cap_error(resp: VendorResponse) -> bool:
        """Return whether a 429/403 response looks like a subscription usage-cap error."""
        if resp.status_code not in (429, 403):
            return False
        msg = (resp.error_message or "").lower()
        return any(p in msg for p in ("usage cap", "quota", "limit exceeded"))
|
|
@@ -0,0 +1,90 @@
|
|
|
1
|
+
"""模型名称映射器 — 按供应商作用域解析模型名."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import fnmatch
|
|
6
|
+
import logging
|
|
7
|
+
import re
|
|
8
|
+
|
|
9
|
+
from ..config.schema import ModelMappingRule
|
|
10
|
+
|
|
11
|
+
logger = logging.getLogger(__name__)

# Target returned by ModelMapper.map() when no rule matches and no default is given.
_DEFAULT_TARGET = "glm-5.1"
# Vendor-name aliases applied before rule matching; "zhipu" shares the
# "fallback" rule scope.
_VENDOR_ALIASES = {
    "zhipu": "fallback",
    "fallback": "fallback",
    "antigravity": "antigravity",
    "copilot": "copilot",
}


class ModelMapper:
    """Map a requested model name to a target model name, scoped per vendor."""

    def __init__(self, rules: list[ModelMappingRule]) -> None:
        """Store the rules and pre-compile every regex rule's pattern.

        Args:
            rules: Mapping rules, evaluated in list order. The list is kept by
                reference; regex rules added later are compiled on first use.
        """
        self._rules = rules
        self._compiled: dict[str, re.Pattern] = {}
        for rule in rules:
            if rule.is_regex:
                self._compiled[rule.pattern] = re.compile(rule.pattern)

    @staticmethod
    def _normalize_vendor(vendor: str) -> str:
        """Lower-case and strip a vendor name, then resolve it through ``_VENDOR_ALIASES``."""
        normalized = vendor.strip().lower()
        return _VENDOR_ALIASES.get(normalized, normalized)

    def _rule_applies_to_vendor(self, rule: ModelMappingRule, vendor: str) -> bool:
        """Return whether ``rule`` is in scope for the already-normalized ``vendor``."""
        if not rule.vendors:
            # Backward compatibility: rules without a vendor list only serve fallback/zhipu.
            return vendor == "fallback"
        return any(self._normalize_vendor(name) == vendor for name in rule.vendors)

    def _compiled_pattern(self, pattern: str) -> re.Pattern:
        """Return the compiled regex for ``pattern``, compiling and caching it if needed."""
        compiled = self._compiled.get(pattern)
        if compiled is None:
            compiled = self._compiled[pattern] = re.compile(pattern)
        return compiled

    def map(self, model: str, vendor: str = "fallback", default: str | None = None) -> str:
        """Map a source model name to its target model name.

        Priority: exact match > wildcard/regex match (first matching rule in
        list order) > ``default`` > ``_DEFAULT_TARGET``. Only rules in scope for
        ``vendor`` (after alias normalization) are considered. Glob patterns
        are matched case-sensitively on every platform, and regex patterns must
        match the whole model name.

        Args:
            model: Requested model name.
            vendor: Vendor/tier name; case and surrounding whitespace are ignored.
            default: Target used when no rule matches. A falsy value falls back
                to ``_DEFAULT_TARGET``.

        Returns:
            The mapped target model name.
        """
        display_name = vendor.strip().lower()
        match_key = self._normalize_vendor(vendor)
        applicable = [rule for rule in self._rules if self._rule_applies_to_vendor(rule, match_key)]

        # 1. Exact match
        for rule in applicable:
            if not rule.is_regex and "*" not in rule.pattern and rule.pattern == model:
                logger.debug(
                    "Model mapped: %s -> %s (vendor=%s exact)",
                    model, rule.target, display_name,
                )
                return rule.target

        # 2. Wildcard / regex match
        for rule in applicable:
            if rule.is_regex:
                if self._compiled_pattern(rule.pattern).fullmatch(model):
                    logger.debug(
                        "Model mapped: %s -> %s (vendor=%s regex=%s)",
                        model, rule.target, display_name, rule.pattern,
                    )
                    return rule.target
            elif "*" in rule.pattern and fnmatch.fnmatchcase(model, rule.pattern):
                # fnmatchcase: plain fnmatch applies os.path.normcase and would
                # match case-insensitively on Windows.
                logger.debug(
                    "Model mapped: %s -> %s (vendor=%s glob=%s)",
                    model, rule.target, display_name, rule.pattern,
                )
                return rule.target

        # 3. Default value
        fallback_target = default or _DEFAULT_TARGET
        logger.debug(
            "Model unmapped: %s -> %s (vendor=%s default)",
            model, fallback_target, display_name,
        )
        return fallback_target
|