coding-proxy 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (77)
  1. coding/__init__.py +0 -0
  2. coding/proxy/__init__.py +3 -0
  3. coding/proxy/__main__.py +5 -0
  4. coding/proxy/auth/__init__.py +13 -0
  5. coding/proxy/auth/providers/__init__.py +6 -0
  6. coding/proxy/auth/providers/base.py +35 -0
  7. coding/proxy/auth/providers/github.py +133 -0
  8. coding/proxy/auth/providers/google.py +237 -0
  9. coding/proxy/auth/runtime.py +122 -0
  10. coding/proxy/auth/store.py +74 -0
  11. coding/proxy/cli/__init__.py +151 -0
  12. coding/proxy/cli/auth_commands.py +224 -0
  13. coding/proxy/compat/__init__.py +30 -0
  14. coding/proxy/compat/canonical.py +193 -0
  15. coding/proxy/compat/session_store.py +137 -0
  16. coding/proxy/config/__init__.py +6 -0
  17. coding/proxy/config/auth_schema.py +24 -0
  18. coding/proxy/config/loader.py +139 -0
  19. coding/proxy/config/resiliency.py +46 -0
  20. coding/proxy/config/routing.py +279 -0
  21. coding/proxy/config/schema.py +280 -0
  22. coding/proxy/config/server.py +23 -0
  23. coding/proxy/config/vendors.py +53 -0
  24. coding/proxy/convert/__init__.py +14 -0
  25. coding/proxy/convert/anthropic_to_gemini.py +352 -0
  26. coding/proxy/convert/anthropic_to_openai.py +352 -0
  27. coding/proxy/convert/gemini_sse_adapter.py +169 -0
  28. coding/proxy/convert/gemini_to_anthropic.py +98 -0
  29. coding/proxy/convert/openai_to_anthropic.py +88 -0
  30. coding/proxy/logging/__init__.py +49 -0
  31. coding/proxy/logging/db.py +308 -0
  32. coding/proxy/logging/stats.py +129 -0
  33. coding/proxy/model/__init__.py +93 -0
  34. coding/proxy/model/auth.py +32 -0
  35. coding/proxy/model/compat.py +153 -0
  36. coding/proxy/model/constants.py +21 -0
  37. coding/proxy/model/pricing.py +70 -0
  38. coding/proxy/model/token.py +64 -0
  39. coding/proxy/model/vendor.py +218 -0
  40. coding/proxy/pricing.py +100 -0
  41. coding/proxy/routing/__init__.py +47 -0
  42. coding/proxy/routing/circuit_breaker.py +152 -0
  43. coding/proxy/routing/error_classifier.py +67 -0
  44. coding/proxy/routing/executor.py +453 -0
  45. coding/proxy/routing/model_mapper.py +90 -0
  46. coding/proxy/routing/quota_guard.py +169 -0
  47. coding/proxy/routing/rate_limit.py +159 -0
  48. coding/proxy/routing/retry.py +82 -0
  49. coding/proxy/routing/router.py +84 -0
  50. coding/proxy/routing/session_manager.py +62 -0
  51. coding/proxy/routing/tier.py +171 -0
  52. coding/proxy/routing/usage_parser.py +193 -0
  53. coding/proxy/routing/usage_recorder.py +131 -0
  54. coding/proxy/server/__init__.py +1 -0
  55. coding/proxy/server/app.py +142 -0
  56. coding/proxy/server/factory.py +175 -0
  57. coding/proxy/server/request_normalizer.py +139 -0
  58. coding/proxy/server/responses.py +74 -0
  59. coding/proxy/server/routes.py +264 -0
  60. coding/proxy/streaming/__init__.py +1 -0
  61. coding/proxy/streaming/anthropic_compat.py +484 -0
  62. coding/proxy/vendors/__init__.py +29 -0
  63. coding/proxy/vendors/anthropic.py +44 -0
  64. coding/proxy/vendors/antigravity.py +328 -0
  65. coding/proxy/vendors/base.py +353 -0
  66. coding/proxy/vendors/copilot.py +702 -0
  67. coding/proxy/vendors/copilot_models.py +438 -0
  68. coding/proxy/vendors/copilot_token_manager.py +167 -0
  69. coding/proxy/vendors/copilot_urls.py +16 -0
  70. coding/proxy/vendors/mixins.py +71 -0
  71. coding/proxy/vendors/token_manager.py +128 -0
  72. coding/proxy/vendors/zhipu.py +243 -0
  73. coding_proxy-0.1.0.dist-info/METADATA +184 -0
  74. coding_proxy-0.1.0.dist-info/RECORD +77 -0
  75. coding_proxy-0.1.0.dist-info/WHEEL +4 -0
  76. coding_proxy-0.1.0.dist-info/entry_points.txt +2 -0
  77. coding_proxy-0.1.0.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,453 @@
1
+ """路由执行器 — 统一的 tier 迭代门控引擎.
2
+
3
+ 封装 ``route_stream`` / ``route_message`` 共享的 tier 循环、
4
+ 门控判断与错误处理逻辑,消除两个路由方法间的重复代码。
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import json
10
+ import logging
11
+ import time
12
+ from typing import Any, AsyncIterator
13
+
14
+ import httpx
15
+
16
+ from .error_classifier import (
17
+ build_request_capabilities,
18
+ extract_error_payload_from_http_status,
19
+ is_semantic_rejection,
20
+ )
21
+ from .rate_limit import (
22
+ compute_effective_retry_seconds,
23
+ compute_rate_limit_deadline,
24
+ parse_rate_limit_headers,
25
+ )
26
+ from .session_manager import RouteSessionManager
27
+ from .tier import VendorTier
28
+ from .usage_parser import (
29
+ build_usage_evidence_records,
30
+ has_missing_input_usage_signals,
31
+ parse_usage_from_chunk,
32
+ )
33
+ from .usage_recorder import UsageRecorder
34
+ from ..vendors.base import VendorResponse, NoCompatibleVendorError, RequestCapabilities, UsageInfo
35
+ from ..vendors.token_manager import TokenAcquireError
36
+
37
+ # 向后兼容别名
38
+ BackendResponse = VendorResponse
39
+ NoCompatibleBackendError = NoCompatibleVendorError
40
+ from ..compat.canonical import CompatibilityStatus, build_canonical_request
41
+
42
+ logger = logging.getLogger(__name__)
43
+
44
+
45
def _log_http_error_detail(tier_name: str, exc: Exception, *, is_stream: bool = False) -> None:
    """Log a multi-line diagnostic record for a failed upstream call.

    For ``httpx.HTTPStatusError`` the record contains the status code, a
    preview of the response body (first 300 characters) and, when the body is
    a JSON object with an ``error`` object, its ``type`` and ``message``.
    For every other exception only the exception type and the first 300
    characters of ``str(exc)`` are logged.

    This helper is called from inside ``except`` handlers, so it must never
    raise itself: an unread streamed body or a non-string ``message`` field is
    reported in the log instead of propagating a secondary exception.

    Args:
        tier_name: Name of the tier whose request failed.
        exc: The exception raised by the vendor call.
        is_stream: Whether the failure happened on the streaming path; only
            affects the wording of the first log line.
    """
    mode = "stream" if is_stream else "message"
    detail_parts = [
        f"Tier {tier_name} {mode} failed:",
        f" exc_type={type(exc).__name__}",
    ]

    if isinstance(exc, httpx.HTTPStatusError) and exc.response is not None:
        resp = exc.response
        detail_parts.append(f" status={resp.status_code}")

        body_available = True
        try:
            if resp.content:
                body_preview = resp.text[:300] if resp.text else "(empty)"
            else:
                body_preview = "(no content)"
        except httpx.StreamError:
            # A streamed response whose body was never read (or whose stream is
            # already closed) cannot expose .content/.text.
            body_available = False
            body_preview = "(unread stream)"
        detail_parts.append(f" response_body={body_preview}")

        # Best effort: pull error type / message out of a JSON error payload.
        payload = None
        if body_available:
            try:
                payload = resp.json() if resp.content else None
            except Exception:
                payload = None
        if isinstance(payload, dict):
            err = payload.get("error", {})
            if isinstance(err, dict):
                raw_message = err.get("message", "N/A")
                # "message" may be JSON null or a non-string value; slicing it
                # directly would raise TypeError.
                message = "N/A" if raw_message is None else str(raw_message)
                detail_parts.append(f" error_type={err.get('type', 'N/A')}")
                detail_parts.append(f" error_msg={message[:200]}")
    else:
        detail_parts.append(f" message={str(exc)[:300]}")

    logger.warning("\n".join(detail_parts))
71
+
72
+
73
def _has_tool_results(body: dict[str, Any]) -> bool:
    """Return True when any message in *body* contains a ``tool_result`` block.

    Used by diagnostic logging to flag that the request is part of a tool
    execution loop. Because it only feeds log output, malformed input is
    tolerated rather than raised on: a missing or ``None`` ``messages`` value,
    non-dict message entries and non-list ``content`` values are all treated
    as "no tool results".

    Args:
        body: The request body; ``body["messages"]`` is expected to be a list
            of message dicts whose ``content`` may be a list of block dicts.

    Returns:
        True if at least one content block has ``type == "tool_result"``.
    """
    messages = body.get("messages")
    if not isinstance(messages, (list, tuple)):
        return False

    for message in messages:
        if not isinstance(message, dict):
            continue
        content = message.get("content")
        if not isinstance(content, list):
            # Plain-string content cannot carry structured blocks.
            continue
        if any(
            isinstance(block, dict) and block.get("type") == "tool_result"
            for block in content
        ):
            return True
    return False
86
+
87
+
88
def _log_vendor_response_error(
    tier_name: str,
    resp: VendorResponse,
    body: dict[str, Any],
    *,
    is_stream: bool = False,
) -> None:
    """Log diagnostics for a vendor response whose status signals an error.

    Complements :func:`_log_http_error_detail`: it is used when the vendor
    hands back a ``VendorResponse`` carrying an error status instead of
    raising an httpx exception, so the exception-based logging never runs.

    The record contains the status, error type, error message (first 300
    characters), the requested model, whether the request declared tools,
    whether it contains ``tool_result`` blocks, and - when ``resp.raw_body``
    is non-empty - the first 500 characters of the body.

    Args:
        tier_name: Name of the tier that produced the response.
        resp: The vendor response to describe.
        body: The request body that was sent.
        is_stream: Only affects the wording of the first log line.
    """
    kind = "stream" if is_stream else "message"
    lines = [
        f"Tier {tier_name} {kind} vendor error response:",
        f" status={resp.status_code}",
        f" error_type={resp.error_type or 'N/A'}",
        f" error_msg={(resp.error_message or 'N/A')[:300]}",
        # Request context: model / tools / tool results.
        f" model={body.get('model', 'unknown')}",
        f" has_tools={bool(body.get('tools'))}",
        f" has_tool_results={_has_tool_results(body)}",
    ]

    # Response body preview.
    payload = resp.raw_body
    if payload:
        try:
            preview = payload.decode("utf-8", errors="replace")[:500]
        except (AttributeError, UnicodeDecodeError):
            preview = f"(binary, {len(payload)} bytes)"
        lines.append(f" response_body_preview={preview}")

    logger.warning("\n".join(lines))
124
+
125
# Maps tier.name to the label of the upstream vendor protocol
# (used to annotate token-usage logs).
_VENDOR_PROTOCOL_LABEL_MAP: dict[str, str] = dict(
    anthropic="Anthropic",
    zhipu="Anthropic",
    copilot="OpenAI",
    antigravity="Gemini",
)
132
+
133
+
134
class _RouteExecutor:
    """Shared tier-iteration and gating engine for routed requests.

    Responsibilities:
        - Walk the tiers in priority order, applying capability gating,
          compatibility decisions and health checks.
        - Run the streaming or non-streaming vendor call for each eligible tier.
        - Handle ``TokenAcquireError``, HTTP errors and semantic rejections
          uniformly, failing over to the next tier where appropriate.
        - Hand successful calls to the ``UsageRecorder``.
    """

    def __init__(
        self,
        tiers: list[VendorTier],
        usage_recorder: UsageRecorder,
        session_manager: RouteSessionManager,
        reauth_coordinator: Any | None = None,
    ) -> None:
        """Store collaborators.

        Args:
            tiers: Vendor tiers in priority order; the last one is the
                terminal tier.
            usage_recorder: Records usage and model-call logs.
            session_manager: Provides and persists per-session compat state.
            reauth_coordinator: Optional object exposing an awaitable
                ``request_reauth(provider)``; used when a token error asks
                for re-authentication.
        """
        self._tiers = tiers
        self._recorder = usage_recorder
        self._session_mgr = session_manager
        self._reauth_coordinator = reauth_coordinator

        # Tier name -> OAuth provider name.
        self._tier_provider_map: dict[str, str] = {
            "copilot": "github",
            "antigravity": "google",
        }

    # -- Public entry points ------------------------------------------------

    async def execute_stream(
        self,
        body: dict[str, Any],
        headers: dict[str, str],
    ) -> AsyncIterator[tuple[bytes, str]]:
        """Route a streaming request, trying tiers in priority order.

        Yields:
            ``(chunk, tier_name)`` pairs as produced by the serving tier.

        Raises:
            TokenAcquireError: The terminal tier could not obtain a token.
            httpx.HTTPStatusError, httpx.TimeoutException, httpx.ConnectError,
            httpx.ReadError: The terminal tier failed with an HTTP error, or
                an earlier tier did and no later tier produced a result.
            NoCompatibleVendorError: No tier served the request and no
                exception was remembered.
            Exception: Any unexpected error raised by the terminal tier.
        """
        last_idx = len(self._tiers) - 1
        last_exc: Exception | None = None
        failed_tier_name: str | None = None
        request_caps = build_request_capabilities(body)
        canonical_request = build_canonical_request(body, headers)
        session_record = await self._session_mgr.get_or_create_record(
            canonical_request.session_key, canonical_request.trace_id,
        )
        incompatible_reasons: list[str] = []

        for i, tier in enumerate(self._tiers):
            is_last = i == last_idx

            gate = await self._try_gate_tier(
                tier, is_last, request_caps, canonical_request, session_record, incompatible_reasons,
            )
            if gate == "skip":
                continue

            start = time.monotonic()
            usage: dict[str, Any] = {}

            # NOTE(review): a failure raised after chunks were already yielded
            # still falls through to the next tier, whose stream starts from
            # the beginning - confirm that consumers tolerate this.
            try:
                async for chunk in tier.vendor.send_message_stream(body, headers):
                    parse_usage_from_chunk(
                        chunk, usage,
                        vendor_label=_VENDOR_PROTOCOL_LABEL_MAP.get(tier.name),
                    )
                    yield chunk, tier.name

                info = self._recorder.build_usage_info(usage)
                if has_missing_input_usage_signals(info):
                    logger.warning(
                        "Stream completed with missing input usage signals: output_tokens=%d, "
                        "cache_creation_tokens=%d, cache_read_tokens=%d, tier=%s, usage_data=%r",
                        info.output_tokens,
                        info.cache_creation_tokens,
                        info.cache_read_tokens,
                        tier.name,
                        usage,
                    )
                tier.record_success(info.input_tokens + info.output_tokens)
                duration = int((time.monotonic() - start) * 1000)
                model = body.get("model", "unknown")
                model_served = usage.get("model_served") or tier.vendor.map_model(model)
                self._recorder.log_model_call(
                    vendor=tier.name,
                    model_requested=model,
                    model_served=model_served,
                    duration_ms=duration,
                    usage=info,
                )
                await self._session_mgr.persist_session(tier.vendor.get_compat_trace(), session_record)
                await self._recorder.record(
                    tier.name, model, model_served, info, duration, True,
                    failed_tier_name is not None, failed_tier_name,
                    evidence_records=build_usage_evidence_records(
                        usage, vendor=tier.name, model_served=model_served, request_id=info.request_id,
                    ),
                )
                return

            except TokenAcquireError as exc:
                failed_tier_name, last_exc = await self._handle_token_error(
                    tier, exc, is_last, failed_tier_name,
                )
                if is_last and last_exc is exc:
                    raise

            except (httpx.HTTPStatusError, httpx.TimeoutException, httpx.ConnectError, httpx.ReadError) as exc:
                _log_http_error_detail(tier.name, exc, is_stream=True)
                should_continue, failed_tier_name, last_exc = await self._handle_http_error(
                    tier, exc, is_last, failed_tier_name, last_exc, is_stream=True,
                )
                if should_continue:
                    continue
                if is_last:
                    raise

            except Exception as exc:
                logger.error(
                    "Tier %s stream unexpected error: %s: %s",
                    tier.name, type(exc).__name__, exc,
                    exc_info=True,
                )
                tier.record_failure()
                failed_tier_name = tier.name
                if not is_last:
                    continue
                raise

        # Reached when the terminal tier was skipped by gating.
        if last_exc:
            raise last_exc
        # NOTE(review): raised even when ``incompatible_reasons`` is empty (all
        # tiers skipped by health gating); execute_message raises RuntimeError
        # in that case. Left unchanged because callers may depend on the type.
        raise NoCompatibleVendorError(
            "当前请求包含仅客户端/MCP 可安全承接的能力,未找到兼容供应商",
            reasons=incompatible_reasons,
        )

    async def execute_message(
        self,
        body: dict[str, Any],
        headers: dict[str, str],
    ) -> VendorResponse:
        """Route a non-streaming request, trying tiers in priority order.

        Returns:
            The first response with ``status_code < 400``, or the error
            response of the tier at which failover stopped (the terminal tier
            or a tier whose error does not qualify for failover).

        Raises:
            TokenAcquireError: The terminal tier could not obtain a token.
            httpx.TimeoutException, httpx.ConnectError, httpx.ReadError: The
                terminal tier failed at the transport level.
            NoCompatibleVendorError: No tier served the request and at least
                one was skipped as incompatible.
            RuntimeError: No tier served the request for any other reason.
            Exception: Any unexpected error raised by the terminal tier.
        """
        last_idx = len(self._tiers) - 1
        # Started once, before the loop: the recorded duration includes time
        # spent on tiers that failed before the serving tier.
        start = time.monotonic()
        failed_tier_name: str | None = None
        request_caps = build_request_capabilities(body)
        canonical_request = build_canonical_request(body, headers)
        session_record = await self._session_mgr.get_or_create_record(
            canonical_request.session_key, canonical_request.trace_id,
        )
        incompatible_reasons: list[str] = []

        for i, tier in enumerate(self._tiers):
            is_last = i == last_idx

            gate = await self._try_gate_tier(
                tier, is_last, request_caps, canonical_request, session_record, incompatible_reasons,
            )
            if gate == "skip":
                continue

            try:
                resp = await tier.vendor.send_message(body, headers)

                if resp.status_code < 400:
                    # Mirror the streaming path: report the success (and the
                    # tokens consumed) to the tier. Previously only
                    # execute_stream did this.
                    resp_usage = resp.usage
                    tier.record_success(
                        (resp_usage.input_tokens + resp_usage.output_tokens)
                        if resp_usage is not None
                        else 0
                    )
                    duration = int((time.monotonic() - start) * 1000)
                    model = body.get("model", "unknown")
                    model_served = resp.model_served or tier.vendor.map_model(model)
                    self._recorder.log_model_call(
                        vendor=tier.name,
                        model_requested=model,
                        model_served=model_served,
                        duration_ms=duration,
                        usage=resp.usage,
                    )
                    await self._session_mgr.persist_session(tier.vendor.get_compat_trace(), session_record)
                    await self._recorder.record(
                        tier.name, model, model_served, resp.usage, duration, True,
                        failed_tier_name is not None, failed_tier_name,
                        evidence_records=self._recorder.build_nonstream_evidence_records(
                            vendor=tier.name, model_served=model_served, usage=resp.usage,
                        ),
                    )
                    return resp

                # Semantic rejection / failover decisions are taken from the
                # response object here (not from an exception).
                if not is_last and is_semantic_rejection(
                    status_code=resp.status_code,
                    error_type=resp.error_type,
                    error_message=resp.error_message,
                ):
                    logger.warning(
                        "Tier %s semantic rejection (%s), trying next tier without recording failure",
                        tier.name, resp.error_type or resp.status_code,
                    )
                    failed_tier_name = tier.name
                    continue

                if not is_last and tier.vendor.should_trigger_failover(
                    resp.status_code,
                    {"error": {"type": resp.error_type, "message": resp.error_message}},
                ):
                    logger.warning("Tier %s error %d, failing over", tier.name, resp.status_code)
                    rl_info = parse_rate_limit_headers(
                        resp.response_headers, resp.status_code, resp.error_message,
                    )
                    tier.record_failure(
                        is_cap_error=self._is_cap_error(resp) or rl_info.is_cap_error,
                        retry_after_seconds=compute_effective_retry_seconds(rl_info),
                        rate_limit_deadline=compute_rate_limit_deadline(rl_info),
                    )
                    failed_tier_name = tier.name
                    continue

                # Terminal tier, or an error that does not qualify for
                # failover: log it, record it and return the original response.
                _log_vendor_response_error(tier.name, resp, body, is_stream=False)
                duration = int((time.monotonic() - start) * 1000)
                model = body.get("model", "unknown")
                model_served = resp.model_served or tier.vendor.map_model(model)
                self._recorder.log_model_call(
                    vendor=tier.name,
                    model_requested=model,
                    model_served=model_served,
                    duration_ms=duration,
                    usage=resp.usage,
                )
                await self._recorder.record(
                    # success flag is always False here: status_code >= 400.
                    tier.name, model, model_served, resp.usage, duration, False,
                    failed_tier_name is not None, failed_tier_name,
                    evidence_records=self._recorder.build_nonstream_evidence_records(
                        vendor=tier.name, model_served=model_served, usage=resp.usage,
                    ),
                )
                return resp

            except TokenAcquireError as exc:
                failed_tier_name, _ = await self._handle_token_error(
                    tier, exc, is_last, failed_tier_name,
                )
                if is_last:
                    raise
                continue

            except (httpx.TimeoutException, httpx.ConnectError, httpx.ReadError) as exc:
                _log_http_error_detail(tier.name, exc, is_stream=False)
                tier.record_failure()
                failed_tier_name = tier.name
                if is_last:
                    raise
                continue

            except Exception as exc:
                logger.error(
                    "Tier %s message unexpected error: %s: %s",
                    tier.name, type(exc).__name__, exc,
                    exc_info=True,
                )
                tier.record_failure()
                failed_tier_name = tier.name
                if not is_last:
                    continue
                raise

        if incompatible_reasons:
            raise NoCompatibleVendorError(
                "当前请求包含仅客户端/MCP 可安全承接的能力,未找到兼容供应商",
                reasons=incompatible_reasons,
            )
        raise RuntimeError("无可用供应商层级")

    # -- Gating and error handling ------------------------------------------

    async def _try_gate_tier(
        self,
        tier: VendorTier,
        is_last: bool,
        request_caps: RequestCapabilities,
        canonical_request: Any,
        session_record: Any,
        incompatible_reasons: list[str],
    ) -> str:
        """Run capability, compatibility and availability gates for one tier.

        Appends ``"<tier>:<reasons>"`` to *incompatible_reasons* when the tier
        is skipped for capability or compatibility reasons. When both of those
        gates pass, the compat context is applied to the session before the
        availability check.

        Returns:
            ``"eligible"`` when the tier may execute the request, ``"skip"``
            otherwise.
        """
        supported, reasons = tier.vendor.supports_request(request_caps)
        if not supported:
            reason_text = ",".join(sorted({r.value for r in reasons}))
            incompatible_reasons.append(f"{tier.name}:{reason_text}")
            logger.info("Tier %s skipped due to incompatible capabilities: %s", tier.name, reason_text)
            return "skip"

        decision = tier.vendor.make_compatibility_decision(canonical_request)
        if decision.status is CompatibilityStatus.UNSAFE:
            reason_text = ",".join(sorted(decision.unsupported_semantics))
            incompatible_reasons.append(f"{tier.name}:{reason_text}")
            logger.info("Tier %s skipped due to compatibility decision: %s", tier.name, reason_text)
            return "skip"

        self._session_mgr.apply_compat_context(
            tier=tier,
            canonical_request=canonical_request,
            decision=decision,
            session_record=session_record,
        )

        # Non-terminal tiers are gated by the health check; the terminal tier
        # only consults can_execute().
        if is_last:
            available = tier.can_execute()
        else:
            available = await tier.can_execute_with_health_check()
        return "eligible" if available else "skip"

    async def _handle_token_error(
        self,
        tier: VendorTier,
        exc: TokenAcquireError,
        is_last: bool,
        failed_tier_name: str | None,
    ) -> tuple[str | None, Exception]:
        """Shared handling of ``TokenAcquireError``.

        Records a failure on the tier and, when the error requests
        re-authentication and a coordinator plus a provider mapping exist,
        asks the coordinator to re-authenticate that provider.

        ``is_last`` and ``failed_tier_name`` are accepted for signature
        compatibility and are not consulted.

        Returns:
            ``(tier.name, exc)`` - the new failed-tier name and the exception
            to remember.
        """
        logger.warning("Tier %s credential expired: %s", tier.name, exc)
        tier.record_failure()
        if exc.needs_reauth and self._reauth_coordinator:
            provider = self._tier_provider_map.get(tier.name)
            if provider:
                await self._reauth_coordinator.request_reauth(provider)
        return tier.name, exc

    async def _handle_http_error(
        self,
        tier: VendorTier,
        exc: Exception,
        is_last: bool,
        failed_tier_name: str | None,
        last_exc: Exception | None,
        *,
        is_stream: bool = False,
    ) -> tuple[bool, str | None, Exception | None]:
        """Shared handling of HTTP errors (used by the streaming path).

        For ``httpx.HTTPStatusError`` a semantic rejection on a non-terminal
        tier is passed on without recording a failure; otherwise the failure
        is recorded together with rate-limit information parsed from the
        response. Any other exception records a plain failure.

        ``failed_tier_name``, ``last_exc`` and ``is_stream`` are accepted for
        signature compatibility and are not consulted.

        Returns:
            ``(should_continue, failed_tier_name, last_exc)`` where
            ``should_continue`` is True only for a semantic rejection on a
            non-terminal tier.
        """
        if isinstance(exc, httpx.HTTPStatusError) and exc.response is not None:
            response = exc.response
            payload = extract_error_payload_from_http_status(exc)
            error = payload.get("error", {}) if isinstance(payload, dict) else {}
            semantic_rejection = is_semantic_rejection(
                status_code=response.status_code,
                error_type=error.get("type") if isinstance(error, dict) else None,
                error_message=error.get("message") if isinstance(error, dict) else None,
            )
            if semantic_rejection and not is_last:
                logger.warning(
                    "Tier %s semantic rejection, trying next tier without recording failure",
                    tier.name,
                )
                return True, tier.name, exc

            try:
                body_text = response.text
            except httpx.StreamError:
                # Streamed error response whose body was never read: there is
                # no text to inspect, fall back to header-only parsing.
                body_text = None
            rl_info = parse_rate_limit_headers(
                response.headers,
                response.status_code,
                body_text[:500] if body_text else None,
            )
            tier.record_failure(
                is_cap_error=rl_info.is_cap_error,
                retry_after_seconds=compute_effective_retry_seconds(rl_info),
                rate_limit_deadline=compute_rate_limit_deadline(rl_info),
            )
        else:
            tier.record_failure()

        return False, tier.name, exc

    @staticmethod
    def _is_cap_error(resp: VendorResponse) -> bool:
        """Return True when *resp* looks like a subscription usage-cap error.

        Only status codes 429 and 403 qualify, and the lower-cased error
        message must contain "usage cap", "quota" or "limit exceeded".
        """
        if resp.status_code not in (429, 403):
            return False
        msg = (resp.error_message or "").lower()
        return any(p in msg for p in ("usage cap", "quota", "limit exceeded"))
@@ -0,0 +1,90 @@
1
+ """模型名称映射器 — 按供应商作用域解析模型名."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import fnmatch
6
+ import logging
7
+ import re
8
+
9
+ from ..config.schema import ModelMappingRule
10
+
11
logger = logging.getLogger(__name__)

# Target model returned when no rule matches and the caller gives no default.
_DEFAULT_TARGET = "glm-5.1"
# Vendor name -> canonical scope key used for rule matching
# ("zhipu" shares the "fallback" scope).
_VENDOR_ALIASES = {
    "zhipu": "fallback",
    "fallback": "fallback",
    "antigravity": "antigravity",
    "copilot": "copilot",
}


class ModelMapper:
    """Map a requested model name to a vendor-specific target model name."""

    def __init__(self, rules: list[ModelMappingRule]) -> None:
        """Store *rules* and pre-compile the patterns of all regex rules.

        Args:
            rules: Mapping rules; each exposes ``pattern``, ``target``,
                ``is_regex`` and ``vendors``.
        """
        self._rules = rules
        # Pre-compiled regular expressions, keyed by pattern text.
        self._compiled: dict[str, re.Pattern] = {
            rule.pattern: re.compile(rule.pattern) for rule in rules if rule.is_regex
        }

    @staticmethod
    def _normalize_vendor(vendor: str) -> str:
        """Return the canonical scope key for *vendor* (trimmed, lower-cased, alias-resolved)."""
        normalized = vendor.strip().lower()
        return _VENDOR_ALIASES.get(normalized, normalized)

    def _rule_applies_to_vendor(self, rule: ModelMappingRule, vendor: str) -> bool:
        """Return True when *rule* is in scope for the normalized *vendor* key."""
        if not rule.vendors:
            # Backward compatibility: legacy rules without a vendor list only
            # serve the fallback/zhipu scope.
            return vendor == "fallback"
        return vendor in {self._normalize_vendor(name) for name in rule.vendors}

    def map(self, model: str, vendor: str = "fallback", default: str | None = None) -> str:
        """Map the source model name to a target model name.

        Priority: exact match > glob/regex match (in rule order) > *default*
        (or ``_DEFAULT_TARGET`` when *default* is empty).

        A non-regex pattern is a glob only when it contains ``*``; otherwise
        it is compared for equality. Glob matching is case-sensitive on every
        platform. Regex rules must match the whole model name.

        Args:
            model: Requested model name.
            vendor: Vendor whose rule scope applies; aliases are resolved.
            default: Target to return when no rule matches.

        Returns:
            The target model name.
        """
        display_name = vendor.strip().lower()
        match_key = self._normalize_vendor(vendor)
        # Vendor applicability is independent of the model; evaluate it once
        # for both passes.
        applicable = [
            rule for rule in self._rules if self._rule_applies_to_vendor(rule, match_key)
        ]

        # 1. Exact match.
        for rule in applicable:
            if not rule.is_regex and "*" not in rule.pattern and rule.pattern == model:
                logger.debug(
                    "Model mapped: %s -> %s (vendor=%s exact)",
                    model, rule.target, display_name,
                )
                return rule.target

        # 2. Glob / regex match, first hit in rule order wins.
        for rule in applicable:
            if rule.is_regex:
                if self._compiled[rule.pattern].fullmatch(model):
                    logger.debug(
                        "Model mapped: %s -> %s (vendor=%s regex=%s)",
                        model, rule.target, display_name, rule.pattern,
                    )
                    return rule.target
            elif "*" in rule.pattern:
                # fnmatchcase, not fnmatch: fnmatch normalizes case and path
                # separators per OS, which would make model matching
                # case-insensitive on Windows.
                if fnmatch.fnmatchcase(model, rule.pattern):
                    logger.debug(
                        "Model mapped: %s -> %s (vendor=%s glob=%s)",
                        model, rule.target, display_name, rule.pattern,
                    )
                    return rule.target

        # 3. Default.
        fallback_target = default or _DEFAULT_TARGET
        logger.debug(
            "Model unmapped: %s -> %s (vendor=%s default)",
            model, fallback_target, display_name,
        )
        return fallback_target