coding-proxy 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- coding/__init__.py +0 -0
- coding/proxy/__init__.py +3 -0
- coding/proxy/__main__.py +5 -0
- coding/proxy/auth/__init__.py +13 -0
- coding/proxy/auth/providers/__init__.py +6 -0
- coding/proxy/auth/providers/base.py +35 -0
- coding/proxy/auth/providers/github.py +133 -0
- coding/proxy/auth/providers/google.py +237 -0
- coding/proxy/auth/runtime.py +122 -0
- coding/proxy/auth/store.py +74 -0
- coding/proxy/cli/__init__.py +151 -0
- coding/proxy/cli/auth_commands.py +224 -0
- coding/proxy/compat/__init__.py +30 -0
- coding/proxy/compat/canonical.py +193 -0
- coding/proxy/compat/session_store.py +137 -0
- coding/proxy/config/__init__.py +6 -0
- coding/proxy/config/auth_schema.py +24 -0
- coding/proxy/config/loader.py +139 -0
- coding/proxy/config/resiliency.py +46 -0
- coding/proxy/config/routing.py +279 -0
- coding/proxy/config/schema.py +280 -0
- coding/proxy/config/server.py +23 -0
- coding/proxy/config/vendors.py +53 -0
- coding/proxy/convert/__init__.py +14 -0
- coding/proxy/convert/anthropic_to_gemini.py +352 -0
- coding/proxy/convert/anthropic_to_openai.py +352 -0
- coding/proxy/convert/gemini_sse_adapter.py +169 -0
- coding/proxy/convert/gemini_to_anthropic.py +98 -0
- coding/proxy/convert/openai_to_anthropic.py +88 -0
- coding/proxy/logging/__init__.py +49 -0
- coding/proxy/logging/db.py +308 -0
- coding/proxy/logging/stats.py +129 -0
- coding/proxy/model/__init__.py +93 -0
- coding/proxy/model/auth.py +32 -0
- coding/proxy/model/compat.py +153 -0
- coding/proxy/model/constants.py +21 -0
- coding/proxy/model/pricing.py +70 -0
- coding/proxy/model/token.py +64 -0
- coding/proxy/model/vendor.py +218 -0
- coding/proxy/pricing.py +100 -0
- coding/proxy/routing/__init__.py +47 -0
- coding/proxy/routing/circuit_breaker.py +152 -0
- coding/proxy/routing/error_classifier.py +67 -0
- coding/proxy/routing/executor.py +453 -0
- coding/proxy/routing/model_mapper.py +90 -0
- coding/proxy/routing/quota_guard.py +169 -0
- coding/proxy/routing/rate_limit.py +159 -0
- coding/proxy/routing/retry.py +82 -0
- coding/proxy/routing/router.py +84 -0
- coding/proxy/routing/session_manager.py +62 -0
- coding/proxy/routing/tier.py +171 -0
- coding/proxy/routing/usage_parser.py +193 -0
- coding/proxy/routing/usage_recorder.py +131 -0
- coding/proxy/server/__init__.py +1 -0
- coding/proxy/server/app.py +142 -0
- coding/proxy/server/factory.py +175 -0
- coding/proxy/server/request_normalizer.py +139 -0
- coding/proxy/server/responses.py +74 -0
- coding/proxy/server/routes.py +264 -0
- coding/proxy/streaming/__init__.py +1 -0
- coding/proxy/streaming/anthropic_compat.py +484 -0
- coding/proxy/vendors/__init__.py +29 -0
- coding/proxy/vendors/anthropic.py +44 -0
- coding/proxy/vendors/antigravity.py +328 -0
- coding/proxy/vendors/base.py +353 -0
- coding/proxy/vendors/copilot.py +702 -0
- coding/proxy/vendors/copilot_models.py +438 -0
- coding/proxy/vendors/copilot_token_manager.py +167 -0
- coding/proxy/vendors/copilot_urls.py +16 -0
- coding/proxy/vendors/mixins.py +71 -0
- coding/proxy/vendors/token_manager.py +128 -0
- coding/proxy/vendors/zhipu.py +243 -0
- coding_proxy-0.1.0.dist-info/METADATA +184 -0
- coding_proxy-0.1.0.dist-info/RECORD +77 -0
- coding_proxy-0.1.0.dist-info/WHEEL +4 -0
- coding_proxy-0.1.0.dist-info/entry_points.txt +2 -0
- coding_proxy-0.1.0.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,175 @@
|
|
|
1
|
+
"""FastAPI 应用工厂函数 — 供应商实例化与凭证解析."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import logging
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from typing import Any
|
|
8
|
+
|
|
9
|
+
from ..auth.providers.google import (
|
|
10
|
+
GoogleOAuthProvider,
|
|
11
|
+
_DEFAULT_CLIENT_ID as _GOOGLE_DEFAULT_CLIENT_ID,
|
|
12
|
+
_DEFAULT_CLIENT_SECRET as _GOOGLE_DEFAULT_CLIENT_SECRET,
|
|
13
|
+
_REQUIRED_SCOPE_SET as _GOOGLE_REQUIRED_SCOPE_SET,
|
|
14
|
+
)
|
|
15
|
+
from ..auth.runtime import RuntimeReauthCoordinator
|
|
16
|
+
from ..auth.store import TokenStoreManager
|
|
17
|
+
from ..vendors.antigravity import AntigravityVendor
|
|
18
|
+
from ..vendors.anthropic import AnthropicVendor
|
|
19
|
+
from ..vendors.copilot import CopilotVendor
|
|
20
|
+
from ..vendors.zhipu import ZhipuVendor
|
|
21
|
+
from ..config.schema import (
|
|
22
|
+
AntigravityConfig,
|
|
23
|
+
AnthropicConfig,
|
|
24
|
+
CircuitBreakerConfig,
|
|
25
|
+
CopilotConfig,
|
|
26
|
+
FailoverConfig,
|
|
27
|
+
QuotaGuardConfig,
|
|
28
|
+
TierConfig,
|
|
29
|
+
ZhipuConfig,
|
|
30
|
+
)
|
|
31
|
+
from ..routing.circuit_breaker import CircuitBreaker
|
|
32
|
+
from ..routing.model_mapper import ModelMapper
|
|
33
|
+
from ..routing.quota_guard import QuotaGuard
|
|
34
|
+
from ..routing.tier import VendorTier
|
|
35
|
+
# Backward-compatible alias (deprecated): prefer ``VendorTier`` in new code.
BackendTier = VendorTier  # noqa: F401 (deprecated)

logger = logging.getLogger(__name__)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def _find_anthropic_vendor(router: Any) -> AnthropicVendor | None:
    """Return the first Anthropic vendor in the router's tier chain, or ``None``.

    Used for the direct (bypass) passthrough to Anthropic.
    """
    vendors = (tier.vendor for tier in router.tiers)
    return next((vendor for vendor in vendors if isinstance(vendor, AnthropicVendor)), None)
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def _find_copilot_vendor(router: Any) -> CopilotVendor | None:
    """Return the first Copilot vendor in the router's tier chain, or ``None``.

    Used by the Copilot diagnostics and model-probe endpoints.
    """
    vendors = (tier.vendor for tier in router.tiers)
    return next((vendor for vendor in vendors if isinstance(vendor, CopilotVendor)), None)
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def _build_circuit_breaker(cfg: CircuitBreakerConfig) -> CircuitBreaker:
    """Instantiate a ``CircuitBreaker`` from its configuration section.

    Every threshold/timeout field is copied verbatim from *cfg*.
    """
    settings = {
        "failure_threshold": cfg.failure_threshold,
        "recovery_timeout_seconds": cfg.recovery_timeout_seconds,
        "success_threshold": cfg.success_threshold,
        "max_recovery_seconds": cfg.max_recovery_seconds,
    }
    return CircuitBreaker(**settings)
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def _build_quota_guard(cfg: QuotaGuardConfig) -> QuotaGuard:
    """Instantiate a ``QuotaGuard`` from its configuration section.

    ``cfg.window_hours`` is converted to seconds and truncated to an ``int``;
    all other fields are passed through unchanged.
    """
    window_seconds = int(cfg.window_hours * 3600)
    return QuotaGuard(
        enabled=cfg.enabled,
        token_budget=cfg.token_budget,
        window_seconds=window_seconds,
        threshold_percent=cfg.threshold_percent,
        probe_interval_seconds=cfg.probe_interval_seconds,
    )
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def _create_vendor_from_config(
    vendor_cfg: TierConfig,
    failover_cfg: FailoverConfig,
    mapper: ModelMapper,
    token_store: TokenStoreManager,
) -> Any:
    """Instantiate the vendor named by ``vendor_cfg.vendor`` (simple factory).

    Each branch copies the relevant fields of the generic tier config into the
    vendor-specific config model. Anthropic, Antigravity and Zhipu get a default
    ``base_url`` when none is configured. Copilot and Antigravity credentials
    are then completed from *token_store*.

    Raises:
        ValueError: If ``vendor_cfg.vendor`` is not ``"anthropic"``,
            ``"copilot"``, ``"antigravity"`` or ``"zhipu"``.
    """
    kind = vendor_cfg.vendor

    if kind == "anthropic":
        anthropic_cfg = AnthropicConfig(
            enabled=vendor_cfg.enabled,
            base_url=vendor_cfg.base_url or "https://api.anthropic.com",
            timeout_ms=vendor_cfg.timeout_ms,
        )
        return AnthropicVendor(anthropic_cfg, failover_cfg)

    if kind == "copilot":
        copilot_cfg = CopilotConfig(
            enabled=vendor_cfg.enabled,
            github_token=vendor_cfg.github_token,
            account_type=vendor_cfg.account_type,
            token_url=vendor_cfg.token_url,
            base_url=vendor_cfg.base_url,
            models_cache_ttl_seconds=vendor_cfg.models_cache_ttl_seconds,
            timeout_ms=vendor_cfg.timeout_ms,
        )
        copilot_cfg = _resolve_copilot_credentials(copilot_cfg, token_store)
        return CopilotVendor(copilot_cfg, failover_cfg, mapper)

    if kind == "antigravity":
        antigravity_cfg = AntigravityConfig(
            enabled=vendor_cfg.enabled,
            client_id=vendor_cfg.client_id,
            client_secret=vendor_cfg.client_secret,
            refresh_token=vendor_cfg.refresh_token,
            base_url=vendor_cfg.base_url or "https://generativelanguage.googleapis.com/v1beta",
            model_endpoint=vendor_cfg.model_endpoint,
            timeout_ms=vendor_cfg.timeout_ms,
        )
        antigravity_cfg = _resolve_antigravity_credentials(antigravity_cfg, token_store)
        return AntigravityVendor(antigravity_cfg, failover_cfg, mapper)

    if kind == "zhipu":
        zhipu_cfg = ZhipuConfig(
            enabled=vendor_cfg.enabled,
            base_url=vendor_cfg.base_url or "https://open.bigmodel.cn/api/anthropic",
            api_key=vendor_cfg.api_key,
            timeout_ms=vendor_cfg.timeout_ms,
        )
        # Constructed without failover_cfg, unlike the other vendors above.
        return ZhipuVendor(zhipu_cfg, mapper)

    raise ValueError(f"未知的 vendor 类型: {kind!r}")
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
def _resolve_copilot_credentials(cfg: CopilotConfig, token_store: TokenStoreManager) -> CopilotConfig:
    """Complete the Copilot GitHub token: config.yaml first, Token Store as fallback.

    Returns *cfg* itself when it already carries a ``github_token`` or when the
    store holds no GitHub access token. Otherwise it returns a copy with
    ``github_token`` taken from the store.
    """
    if not cfg.github_token:
        stored_token = token_store.get("github").access_token
        if stored_token:
            cfg = cfg.model_copy(update={"github_token": stored_token})
            logger.info("Copilot: 使用 Token Store 中的 GitHub 凭证")
    return cfg
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
def _resolve_antigravity_credentials(cfg: AntigravityConfig, token_store: TokenStoreManager) -> AntigravityConfig:
    """Complete Antigravity (Google OAuth) credentials.

    Precedence for the refresh token: config.yaml first, Token Store as fallback.
    Whenever a refresh token is available from either source, a missing
    ``client_id`` / ``client_secret`` is filled with the default public Google
    OAuth client. Previously this happened only on the Token Store path, so a
    config.yaml containing just ``refresh_token`` was left with empty client
    credentials.

    When the refresh token comes from the Token Store and its recorded scope is
    incomplete, a warning lists the missing scopes.

    Returns:
        *cfg* itself when nothing needs to change (no refresh token anywhere, or
        config.yaml already complete); otherwise an updated copy.
    """
    updates: dict[str, str] = {}

    if not cfg.refresh_token:
        tokens = token_store.get("google")
        if not tokens.refresh_token:
            # No refresh token in either source; leave cfg untouched.
            return cfg
        updates["refresh_token"] = tokens.refresh_token
        logger.info("Antigravity: 使用 Token Store 中的 Google 凭证")
        if tokens.scope and not GoogleOAuthProvider.has_required_scopes(tokens.scope):
            missing = sorted(_GOOGLE_REQUIRED_SCOPE_SET.difference(tokens.scope.split()))
            logger.warning("Antigravity: Token Store 中的 Google scope 不完整,缺少: %s", ", ".join(missing))

    # Fill in the default public OAuth client for whichever half is missing.
    if not cfg.client_id:
        updates["client_id"] = _GOOGLE_DEFAULT_CLIENT_ID
    if not cfg.client_secret:
        updates["client_secret"] = _GOOGLE_DEFAULT_CLIENT_SECRET

    return cfg.model_copy(update=updates) if updates else cfg
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
# ── Backward-compatible aliases (deprecated) ──────────────────────────────
# Older "backend" names bound to the current "vendor" implementations;
# prefer the vendor names in new code.

_find_anthropic_backend = _find_anthropic_vendor  # noqa: F401
_find_copilot_backend = _find_copilot_vendor  # noqa: F401
_create_backend_from_tier = _create_vendor_from_config  # noqa: F401
|
|
@@ -0,0 +1,139 @@
|
|
|
1
|
+
"""入站 Anthropic Messages 请求规范化."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import copy
|
|
6
|
+
import re
|
|
7
|
+
from dataclasses import dataclass, field
|
|
8
|
+
from typing import Any
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
_ANTHROPIC_TOOL_USE_ID_RE = re.compile(r"^toolu_[A-Za-z0-9_]+$")
_ANTHROPIC_SERVER_TOOL_USE_ID_RE = re.compile(r"^srvtoolu_[A-Za-z0-9_]+$")
# Vendor-private block types that are stripped from request content before routing.
_VENDOR_TOOL_BLOCK_TYPES = {
    "server_tool_use_delta",
}


@dataclass
class NormalizationResult:
    """Outcome of :func:`normalize_anthropic_request`."""

    # The cleaned request body (a deep copy; the caller's input is never mutated).
    body: dict[str, Any]
    # Sorted, de-duplicated tags describing every repair that was applied.
    adaptations: list[str] = field(default_factory=list)
    # One entry per content block that could not be repaired and was dropped.
    fatal_reasons: list[str] = field(default_factory=list)

    @property
    def recoverable(self) -> bool:
        """True when no block had to be dropped as unrepairable."""
        return not self.fatal_reasons


def normalize_anthropic_request(body: dict[str, Any]) -> NormalizationResult:
    """Strip vendor-private blocks and repair tool ids so *body* is a valid Anthropic Messages request.

    *body* is deep-copied and never mutated.

    - Blocks whose ``type`` is in ``_VENDOR_TOOL_BLOCK_TYPES`` are removed.
    - Assistant ``tool_use`` / ``server_tool_use`` blocks whose id is a
      ``srvtoolu_*`` id, or any other non-``toolu_*`` id (the latter only when
      the block has a ``name``), get a fresh ``toolu_normalized_<n>`` id and
      type ``tool_use``.
    - User ``tool_result`` blocks that reference a rewritten id are updated to
      the new id.
    - Blocks that cannot be repaired are dropped and described in
      ``fatal_reasons``.

    Ids are validated with ``fullmatch`` so that a trailing newline (which
    ``$`` alone tolerates) is not mistaken for a valid id. A missing or
    non-list ``messages`` field, or a non-dict body, is returned unchanged
    instead of raising.
    """
    normalized = copy.deepcopy(body)
    adaptations: list[str] = []
    fatal_reasons: list[str] = []
    # Original tool id -> rewritten id, so later tool_result blocks follow the rewrite.
    tool_id_map: dict[str, str] = {}
    normalized_counter = 0

    def next_tool_id() -> str:
        """Return a fresh, request-unique Anthropic-style tool id."""
        nonlocal normalized_counter
        normalized_counter += 1
        return f"toolu_normalized_{normalized_counter}"

    def normalize_content_block(
        block: Any,
        *,
        message_role: str,
        message_index: int,
        block_index: int,
    ) -> dict[str, Any] | None:
        """Normalize one content block; return ``None`` to drop it."""
        if not isinstance(block, dict):
            return None

        block_type = block.get("type")
        if not isinstance(block_type, str):
            # Non-string (possibly unhashable) types are passed through untouched
            # rather than crashing the set-membership tests below.
            return dict(block)

        if block_type in _VENDOR_TOOL_BLOCK_TYPES:
            adaptations.append(f"vendor_block_removed:{block_type}")
            return None

        if message_role == "assistant" and block_type in {"tool_use", "server_tool_use"}:
            normalized_block = dict(block)
            tool_id = normalized_block.get("id")
            if isinstance(tool_id, str) and _ANTHROPIC_SERVER_TOOL_USE_ID_RE.fullmatch(tool_id):
                new_id = next_tool_id()
                tool_id_map[tool_id] = new_id
                normalized_block["id"] = new_id
                normalized_block["type"] = "tool_use"
                adaptations.append("server_tool_use_id_rewritten_for_anthropic")
            elif isinstance(tool_id, str) and _ANTHROPIC_TOOL_USE_ID_RE.fullmatch(tool_id):
                normalized_block["type"] = "tool_use"
            elif isinstance(tool_id, str) and tool_id:
                if "name" in normalized_block:
                    new_id = next_tool_id()
                    tool_id_map[tool_id] = new_id
                    normalized_block["id"] = new_id
                    normalized_block["type"] = "tool_use"
                    adaptations.append("invalid_tool_use_id_rewritten_for_anthropic")
                else:
                    fatal_reasons.append(
                        f"messages.{message_index}.content.{block_index}: tool block missing name for id rewrite"
                    )
                    return None
            else:
                fatal_reasons.append(
                    f"messages.{message_index}.content.{block_index}: tool block missing id"
                )
                return None
            return normalized_block

        if message_role == "user" and block_type == "tool_result":
            normalized_block = dict(block)
            tool_use_id = normalized_block.get("tool_use_id")
            if isinstance(tool_use_id, str) and tool_use_id in tool_id_map:
                normalized_block["tool_use_id"] = tool_id_map[tool_use_id]
                adaptations.append("tool_result_tool_use_id_rewritten")
            elif isinstance(tool_use_id, str) and (
                _ANTHROPIC_TOOL_USE_ID_RE.fullmatch(tool_use_id)
                or _ANTHROPIC_SERVER_TOOL_USE_ID_RE.fullmatch(tool_use_id)
            ):
                # Keep as-is. A server_tool_use id that does not appear in this
                # request body is left for the upstream to accept or reject,
                # rather than guessing at a cross-turn association.
                return normalized_block
            elif isinstance(tool_use_id, str) and tool_use_id:
                fatal_reasons.append(
                    f"messages.{message_index}.content.{block_index}: tool_result references unknown tool_use_id"
                )
                return None
            else:
                fatal_reasons.append(
                    f"messages.{message_index}.content.{block_index}: tool_result missing tool_use_id"
                )
                return None
            return normalized_block

        return dict(block)

    messages = normalized.get("messages") if isinstance(normalized, dict) else None
    if isinstance(messages, list):
        for message_index, message in enumerate(messages):
            if not isinstance(message, dict):
                continue
            content = message.get("content")
            if not isinstance(content, list):
                continue
            role = str(message.get("role") or "")
            new_content: list[Any] = []
            for block_index, block in enumerate(content):
                normalized_block = normalize_content_block(
                    block,
                    message_role=role,
                    message_index=message_index,
                    block_index=block_index,
                )
                if normalized_block is not None:
                    new_content.append(normalized_block)
            message["content"] = new_content

    return NormalizationResult(
        body=normalized,
        adaptations=sorted(set(adaptations)),
        fatal_reasons=fatal_reasons,
    )
|
|
@@ -0,0 +1,74 @@
|
|
|
1
|
+
"""HTTP 错误响应构造工具."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import json
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
import httpx
|
|
9
|
+
from fastapi import Response
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def json_error_response(
    status_code: int,
    *,
    error_type: str,
    message: str,
    details: list[str] | None = None,
) -> Response:
    """Build an Anthropic-style JSON error response.

    The body is ``{"error": {"type": ..., "message": ..., "details": [...]}}``.
    ``details`` is included only when it is non-empty. Non-ASCII text is
    emitted as UTF-8 instead of being ASCII-escaped.
    """
    error_body: dict[str, Any] = {"type": error_type, "message": message}
    if details:
        error_body["details"] = details
    encoded = json.dumps({"error": error_body}, ensure_ascii=False).encode()
    return Response(content=encoded, status_code=status_code, media_type="application/json")
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def stream_error_event(error_type: str, message: str, details: list[str] | None = None) -> bytes:
    """Encode one Anthropic-style SSE ``error`` event as UTF-8 bytes.

    The event carries ``{"type": "error", "error": {"type", "message"[, "details"]}}``.
    ``details`` is included only when it is non-empty.
    """
    error_body: dict[str, Any] = {"type": error_type, "message": message}
    if details:
        error_body["details"] = details
    data = json.dumps({"type": "error", "error": error_body}, ensure_ascii=False)
    return ("event: error\n" + "data: " + data + "\n\n").encode()
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def extract_stream_http_error(exc: httpx.HTTPStatusError) -> tuple[str, str]:
    """Derive an Anthropic-style ``(error_type, message)`` pair from an upstream HTTP error.

    Lookup order over the response body:

    1. ``{"error": {"type": ..., "message": ...}}``. If ``type`` is missing or
       not a non-empty string (e.g. vendors that send ``code`` instead), the
       message is still used, with ``"api_error"`` as the type.
    2. ``{"error": "<message>"}``.
    3. ``{"message": "<message>"}``.
    4. The raw response text, stripped and truncated to 500 characters.
    5. ``str(exc)``.

    Returns ``("api_error", str(exc))`` when there is no response or when its
    body was never read. httpx raises ``ResponseNotRead`` when ``.content`` of
    an unread streaming response is accessed. Letting that escape here would
    abort the caller's SSE stream instead of producing an error event.
    """
    response = exc.response
    if response is None:
        return "api_error", str(exc)

    try:
        raw = response.content
    except httpx.ResponseNotRead:
        return "api_error", str(exc)

    payload: Any = None
    if raw:
        try:
            payload = response.json()
        except (TypeError, ValueError):  # ValueError covers JSONDecodeError and UnicodeDecodeError
            payload = None

    if isinstance(payload, dict):
        error = payload.get("error")
        if isinstance(error, dict):
            message = error.get("message")
            if isinstance(message, str) and message:
                error_type = error.get("type")
                if isinstance(error_type, str) and error_type:
                    return error_type, message
                return "api_error", message
        elif isinstance(error, str) and error:
            return "api_error", error
        message = payload.get("message")
        if isinstance(message, str) and message:
            return "api_error", message

    text = response.text.strip() if raw else ""
    if text:
        return "api_error", text[:500]
    return "api_error", str(exc)
|
|
@@ -0,0 +1,264 @@
|
|
|
1
|
+
"""路由注册 — 将 FastAPI 路由端点按职责分组注册到 app 实例."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import json
|
|
6
|
+
import logging
|
|
7
|
+
from typing import Any
|
|
8
|
+
|
|
9
|
+
import httpx
|
|
10
|
+
from fastapi import Request, Response
|
|
11
|
+
from fastapi.responses import StreamingResponse
|
|
12
|
+
|
|
13
|
+
from ..vendors.base import NoCompatibleVendorError
|
|
14
|
+
# 向后兼容别名
|
|
15
|
+
NoCompatibleBackendError = NoCompatibleVendorError # noqa: F401 (deprecated)
|
|
16
|
+
from ..vendors.token_manager import TokenAcquireError
|
|
17
|
+
from .responses import (
|
|
18
|
+
extract_stream_http_error,
|
|
19
|
+
json_error_response,
|
|
20
|
+
stream_error_event,
|
|
21
|
+
)
|
|
22
|
+
|
|
23
|
+
logger = logging.getLogger(__name__)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
async def _stream_proxy(router: Any, body: dict, headers: dict) -> Any:
    """Async generator behind the streaming ``/v1/messages`` response.

    Relays the SSE chunks produced by ``router.route_stream``. Any failure,
    whether before the first chunk or mid-stream, is converted into a single
    SSE ``error`` event built with :func:`stream_error_event`. By the time this
    generator runs, the streaming response status has already been committed.
    """
    try:
        async for chunk, _vendor_name in router.route_stream(body, headers):
            yield chunk
    except NoCompatibleVendorError as exc:
        yield stream_error_event("invalid_request_error", str(exc), exc.reasons)
    except TokenAcquireError as exc:
        yield stream_error_event("authentication_error", str(exc))
    except (httpx.TimeoutException, httpx.ConnectError, httpx.ReadError) as exc:
        yield stream_error_event("api_error", f"上游不可达: {exc}")
    except httpx.HTTPStatusError as exc:
        error_type, message = extract_stream_http_error(exc)
        yield stream_error_event(error_type, message)
    except Exception as exc:
        # Last-resort boundary: log with traceback, then tell the client.
        logger.error(
            "_stream_proxy 未预期异常: %s: %s",
            type(exc).__name__, exc,
            exc_info=True,
        )
        yield stream_error_event(
            "api_error",
            f"内部错误: {type(exc).__name__}: {exc}",
        )
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def register_core_routes(app: Any, router: Any) -> None:
    """Register the core API routes: the Messages proxy and token counting.

    Args:
        app: Application object the routes are registered on.
        router: Routing chain providing ``route_message``, ``route_stream`` and
            ``tiers``. ``tiers`` is used to locate the Anthropic vendor for
            ``count_tokens``.
    """
    from .request_normalizer import normalize_anthropic_request

    async def _read_json_object(request: Request) -> dict[str, Any] | Response:
        """Parse the request body as a JSON object, or build a 400 error response.

        Without this, malformed JSON or a non-object body escapes as an
        unhandled exception (a generic 500) instead of an
        ``invalid_request_error``.
        """
        try:
            parsed = await request.json()
        except ValueError as exc:
            # json.JSONDecodeError and UnicodeDecodeError both subclass ValueError.
            return json_error_response(
                400, error_type="invalid_request_error", message=f"invalid JSON body: {exc}"
            )
        if not isinstance(parsed, dict):
            return json_error_response(
                400, error_type="invalid_request_error", message="request body must be a JSON object"
            )
        return parsed

    @app.post("/v1/messages")
    async def messages(request: Request) -> Response:
        """Anthropic Messages API proxy endpoint."""
        parsed = await _read_json_object(request)
        if isinstance(parsed, Response):
            return parsed
        headers = dict(request.headers)
        normalization = normalize_anthropic_request(parsed)
        body = normalization.body
        is_streaming = body.get("stream", False)

        if normalization.adaptations:
            logger.debug("Request normalized before routing: %s", ", ".join(normalization.adaptations))
        if not normalization.recoverable:
            # Unrepairable blocks were dropped from the body. Surface that rather
            # than forwarding silently; the upstream may reject the trimmed request.
            logger.warning(
                "Request normalization dropped unrecoverable blocks: %s",
                "; ".join(normalization.fatal_reasons),
            )

        if is_streaming:
            return StreamingResponse(
                _stream_proxy(router, body, headers),
                media_type="text/event-stream",
                headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
            )

        try:
            resp = await router.route_message(body, headers)
        except NoCompatibleVendorError as exc:
            return json_error_response(400, error_type="invalid_request_error", message=str(exc), details=exc.reasons)
        except TokenAcquireError as exc:
            return json_error_response(503, error_type="authentication_error", message=str(exc))
        except (httpx.TimeoutException, httpx.ConnectError, httpx.ReadError) as exc:
            return json_error_response(502, error_type="api_error", message=f"上游不可达: {exc}")
        except Exception as exc:
            logger.error(
                "messages() 非流式路径未预期异常: %s: %s",
                type(exc).__name__, exc,
                exc_info=True,
            )
            return json_error_response(500, error_type="api_error", message=f"内部错误: {type(exc).__name__}")

        # Diagnostic only: log upstream 5xx error objects that use a non-standard
        # ``code`` field instead of ``type`` (e.g. Zhipu).
        if resp.status_code >= 500 and resp.raw_body:
            try:
                payload = json.loads(resp.raw_body)
                if isinstance(payload, dict) and "error" in payload:
                    err = payload["error"]
                    if isinstance(err, dict) and "type" not in err and "code" in err:
                        logger.debug(
                            "检测到非标准上游错误格式(含 code 非 type): vendor_error=%s",
                            json.dumps(err, ensure_ascii=False)[:200],
                        )
            except (json.JSONDecodeError, UnicodeDecodeError):
                pass

        return Response(content=resp.raw_body or b"{}", status_code=resp.status_code, media_type="application/json")

    @app.post("/v1/messages/count_tokens")
    async def count_tokens(request: Request) -> Response:
        """Token-count passthrough: bypasses the routing chain and calls Anthropic directly.

        Only available when an Anthropic vendor is present in the routing chain;
        the other vendors do not support this protocol.
        """
        from .factory import _find_anthropic_vendor

        anthropic_vendor = _find_anthropic_vendor(router)
        if anthropic_vendor is None:
            return json_error_response(
                404, error_type="not_found", message="count_tokens requires anthropic vendor"
            )

        parsed = await _read_json_object(request)
        if isinstance(parsed, Response):
            return parsed
        headers = dict(request.headers)
        prepared_body, prepared_headers = await anthropic_vendor._prepare_request(parsed, headers)

        client = anthropic_vendor._get_client()
        url = "/v1/messages/count_tokens"
        if request.query_params:
            url = f"{url}?{request.query_params}"

        try:
            response = await client.post(url, json=prepared_body, headers=prepared_headers)
        except (httpx.TimeoutException, httpx.ConnectError, httpx.ReadError) as exc:
            logger.warning("count_tokens proxy failed: %s", exc)
            return json_error_response(
                502, error_type="api_error", message="count_tokens upstream unreachable"
            )
        return Response(content=response.content, status_code=response.status_code, media_type="application/json")
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
def register_health_routes(app: Any) -> None:
    """Register the liveness and connectivity probe endpoints (``/health`` and ``/``)."""

    # Plain liveness check.
    async def health() -> dict:
        return {"status": "ok"}

    async def root() -> Response:
        """Root-path connectivity probe — Claude Code sends HEAD / as a health probe before connecting."""
        return Response(status_code=200)

    # Explicit registration in the same order the original stacked decorators
    # produced (they apply bottom-up): GET /health, GET /, then HEAD /.
    app.get("/health")(health)
    app.get("/")(root)
    app.head("/")(root)
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
def register_status_route(app: Any, router: Any) -> None:
    """Register ``GET /api/status``, a per-tier snapshot of routing state."""

    def _describe_tier(tier: Any) -> dict[str, Any]:
        """Collect breaker, quota, rate-limit and vendor diagnostics for one tier."""
        info: dict[str, Any] = {"name": tier.name}
        breaker = tier.circuit_breaker
        if breaker:
            info["circuit_breaker"] = breaker.get_info()
        quota = tier.quota_guard
        if quota and quota.enabled:
            info["quota_guard"] = quota.get_info()
        weekly = tier.weekly_quota_guard
        if weekly and weekly.enabled:
            info["weekly_quota_guard"] = weekly.get_info()
        info["rate_limit"] = tier.get_rate_limit_info()
        diagnostics = tier.vendor.get_diagnostics()
        if diagnostics:
            info["diagnostics"] = diagnostics
        return info

    @app.get("/api/status")
    async def status() -> dict:
        # One entry per tier, in routing order.
        return {"tiers": [_describe_tier(tier) for tier in router.tiers]}
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
def register_copilot_routes(app: Any, router: Any) -> None:
    """Register the Copilot diagnostics and model-probe endpoints.

    Both endpoints answer 404 ``not_found`` when no Copilot vendor is in the
    routing chain.
    """
    from .factory import _find_copilot_vendor

    @app.get("/api/copilot/diagnostics")
    async def copilot_diagnostics() -> Response:
        """Return redacted diagnostics for the Copilot authentication / token-exchange chain."""
        vendor = _find_copilot_vendor(router)
        if vendor is None:
            return json_error_response(404, error_type="not_found", message="copilot vendor not enabled")
        return Response(
            content=json.dumps(vendor.get_diagnostics(), ensure_ascii=False).encode(),
            status_code=200,
            media_type="application/json",
        )

    @app.get("/api/copilot/models")
    async def copilot_models() -> Response:
        """Probe, on demand, the model list visible to the current Copilot session."""
        vendor = _find_copilot_vendor(router)
        if vendor is None:
            return json_error_response(404, error_type="not_found", message="copilot vendor not enabled")
        try:
            probe = await vendor.probe_models()
        except TokenAcquireError as exc:
            return json_error_response(503, error_type="authentication_error", message=str(exc))
        except (httpx.TimeoutException, httpx.ConnectError, httpx.ReadError) as exc:
            # ReadError included to match the other upstream-call handlers in this module.
            return json_error_response(502, error_type="api_error", message=f"copilot models probe failed: {exc}")
        return Response(
            content=json.dumps(probe, ensure_ascii=False).encode(),
            status_code=200 if probe.get("probe_status") == "ok" else 502,
            media_type="application/json",
        )
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
def register_admin_routes(app: Any, router: Any) -> None:
    """Register administrative endpoints (currently a global state reset)."""

    def _reset_tier(tier: Any) -> None:
        """Reset the breaker, both quota guards (when present) and the rate limit of one tier."""
        for component in (tier.circuit_breaker, tier.quota_guard, tier.weekly_quota_guard):
            if component:
                component.reset()
        tier.reset_rate_limit()

    @app.post("/api/reset")
    async def reset_circuit() -> dict:
        # Reset every tier, in routing order.
        for tier in router.tiers:
            _reset_tier(tier)
        return {"status": "ok"}
|
|
235
|
+
|
|
236
|
+
|
|
237
|
+
def register_reauth_routes(app: Any, reauth_coordinator: Any) -> None:
    """Register the runtime re-authentication endpoints.

    Both handlers tolerate a falsy *reauth_coordinator*. The status endpoint then
    reports no providers, and the trigger endpoint answers 404.
    """

    @app.get("/api/reauth/status")
    async def reauth_status() -> dict:
        """Query the runtime re-authentication status."""
        providers = reauth_coordinator.get_status() if reauth_coordinator else {}
        return {"providers": providers}

    @app.post("/api/reauth/{provider}")
    async def trigger_reauth(provider: str) -> Response:
        """Manually trigger runtime re-authentication for the given provider."""
        if reauth_coordinator:
            await reauth_coordinator.request_reauth(provider)
            return Response(content=b'{"status":"reauth requested"}', status_code=202, media_type="application/json")
        return Response(content=b'{"error":"reauth not available"}', status_code=404, media_type="application/json")
|
|
254
|
+
|
|
255
|
+
|
|
256
|
+
def register_all_routes(app: Any, router: Any, reauth_coordinator: Any | None = None) -> None:
    """Register every route group on *app*, in a fixed order.

    The re-authentication endpoints are added only when a truthy
    *reauth_coordinator* is supplied.
    """
    register_core_routes(app, router)
    register_health_routes(app)
    for register_with_router in (register_status_route, register_copilot_routes, register_admin_routes):
        register_with_router(app, router)
    if reauth_coordinator:
        register_reauth_routes(app, reauth_coordinator)
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""流处理模块."""
|