coding-proxy 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- coding/__init__.py +0 -0
- coding/proxy/__init__.py +3 -0
- coding/proxy/__main__.py +5 -0
- coding/proxy/auth/__init__.py +13 -0
- coding/proxy/auth/providers/__init__.py +6 -0
- coding/proxy/auth/providers/base.py +35 -0
- coding/proxy/auth/providers/github.py +133 -0
- coding/proxy/auth/providers/google.py +237 -0
- coding/proxy/auth/runtime.py +122 -0
- coding/proxy/auth/store.py +74 -0
- coding/proxy/cli/__init__.py +151 -0
- coding/proxy/cli/auth_commands.py +224 -0
- coding/proxy/compat/__init__.py +30 -0
- coding/proxy/compat/canonical.py +193 -0
- coding/proxy/compat/session_store.py +137 -0
- coding/proxy/config/__init__.py +6 -0
- coding/proxy/config/auth_schema.py +24 -0
- coding/proxy/config/loader.py +139 -0
- coding/proxy/config/resiliency.py +46 -0
- coding/proxy/config/routing.py +279 -0
- coding/proxy/config/schema.py +280 -0
- coding/proxy/config/server.py +23 -0
- coding/proxy/config/vendors.py +53 -0
- coding/proxy/convert/__init__.py +14 -0
- coding/proxy/convert/anthropic_to_gemini.py +352 -0
- coding/proxy/convert/anthropic_to_openai.py +352 -0
- coding/proxy/convert/gemini_sse_adapter.py +169 -0
- coding/proxy/convert/gemini_to_anthropic.py +98 -0
- coding/proxy/convert/openai_to_anthropic.py +88 -0
- coding/proxy/logging/__init__.py +49 -0
- coding/proxy/logging/db.py +308 -0
- coding/proxy/logging/stats.py +129 -0
- coding/proxy/model/__init__.py +93 -0
- coding/proxy/model/auth.py +32 -0
- coding/proxy/model/compat.py +153 -0
- coding/proxy/model/constants.py +21 -0
- coding/proxy/model/pricing.py +70 -0
- coding/proxy/model/token.py +64 -0
- coding/proxy/model/vendor.py +218 -0
- coding/proxy/pricing.py +100 -0
- coding/proxy/routing/__init__.py +47 -0
- coding/proxy/routing/circuit_breaker.py +152 -0
- coding/proxy/routing/error_classifier.py +67 -0
- coding/proxy/routing/executor.py +453 -0
- coding/proxy/routing/model_mapper.py +90 -0
- coding/proxy/routing/quota_guard.py +169 -0
- coding/proxy/routing/rate_limit.py +159 -0
- coding/proxy/routing/retry.py +82 -0
- coding/proxy/routing/router.py +84 -0
- coding/proxy/routing/session_manager.py +62 -0
- coding/proxy/routing/tier.py +171 -0
- coding/proxy/routing/usage_parser.py +193 -0
- coding/proxy/routing/usage_recorder.py +131 -0
- coding/proxy/server/__init__.py +1 -0
- coding/proxy/server/app.py +142 -0
- coding/proxy/server/factory.py +175 -0
- coding/proxy/server/request_normalizer.py +139 -0
- coding/proxy/server/responses.py +74 -0
- coding/proxy/server/routes.py +264 -0
- coding/proxy/streaming/__init__.py +1 -0
- coding/proxy/streaming/anthropic_compat.py +484 -0
- coding/proxy/vendors/__init__.py +29 -0
- coding/proxy/vendors/anthropic.py +44 -0
- coding/proxy/vendors/antigravity.py +328 -0
- coding/proxy/vendors/base.py +353 -0
- coding/proxy/vendors/copilot.py +702 -0
- coding/proxy/vendors/copilot_models.py +438 -0
- coding/proxy/vendors/copilot_token_manager.py +167 -0
- coding/proxy/vendors/copilot_urls.py +16 -0
- coding/proxy/vendors/mixins.py +71 -0
- coding/proxy/vendors/token_manager.py +128 -0
- coding/proxy/vendors/zhipu.py +243 -0
- coding_proxy-0.1.0.dist-info/METADATA +184 -0
- coding_proxy-0.1.0.dist-info/RECORD +77 -0
- coding_proxy-0.1.0.dist-info/WHEEL +4 -0
- coding_proxy-0.1.0.dist-info/entry_points.txt +2 -0
- coding_proxy-0.1.0.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,137 @@
|
|
|
1
|
+
"""兼容层会话状态持久化.
|
|
2
|
+
|
|
3
|
+
数据类型 ``CompatSessionRecord`` 已迁移至 :mod:`coding.proxy.model.compat`。
|
|
4
|
+
本文件保留 ``CompatSessionStore`` 持久化管理器,类型通过 re-export 提供。
|
|
5
|
+
|
|
6
|
+
.. deprecated::
|
|
7
|
+
未来版本将移除类型 re-export,请直接从 :mod:`coding.proxy.model.compat` 导入。
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
from __future__ import annotations
|
|
11
|
+
|
|
12
|
+
import json
|
|
13
|
+
import time
|
|
14
|
+
from pathlib import Path
|
|
15
|
+
from typing import Any
|
|
16
|
+
|
|
17
|
+
import aiosqlite
|
|
18
|
+
|
|
19
|
+
# noqa: F401
|
|
20
|
+
from ..model.compat import CompatSessionRecord
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
_CREATE_TABLE = """
|
|
24
|
+
CREATE TABLE IF NOT EXISTS compat_session (
|
|
25
|
+
session_key TEXT PRIMARY KEY,
|
|
26
|
+
trace_id TEXT NOT NULL DEFAULT '',
|
|
27
|
+
tool_call_map_json TEXT NOT NULL DEFAULT '{}',
|
|
28
|
+
thought_signature_map_json TEXT NOT NULL DEFAULT '{}',
|
|
29
|
+
provider_state_json TEXT NOT NULL DEFAULT '{}',
|
|
30
|
+
state_version INTEGER NOT NULL DEFAULT 1,
|
|
31
|
+
updated_at_unix INTEGER NOT NULL DEFAULT 0
|
|
32
|
+
);
|
|
33
|
+
CREATE INDEX IF NOT EXISTS idx_compat_session_updated_at ON compat_session(updated_at_unix);
|
|
34
|
+
"""
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class CompatSessionStore:
    """SQLite-backed persistence for compatibility-layer session state.

    Each ``session_key`` has one row holding the JSON-encoded maps and provider
    state of a :class:`CompatSessionRecord`. A row expires once its
    ``updated_at_unix`` is more than ``ttl_seconds`` in the past. Expired rows
    are purged by :meth:`init` and deleted lazily by :meth:`get`. Rows stamped
    ``updated_at_unix == 0`` never expire.

    Before :meth:`init` has been awaited, and again after :meth:`close`, every
    data method is a no-op and :meth:`get` returns ``None``.
    """

    def __init__(self, db_path: Path, ttl_seconds: int = 86400) -> None:
        """Create an unopened store.

        Args:
            db_path: SQLite database file. Its parent directory is created by
                :meth:`init` if missing.
            ttl_seconds: How long a row stays valid after its last
                :meth:`upsert` (default: 86400, one day).
        """
        self._db_path = db_path
        self._ttl_seconds = ttl_seconds
        self._db: aiosqlite.Connection | None = None

    async def init(self) -> None:
        """Open the database, ensure the schema exists and purge expired rows.

        Re-initialising closes the previously opened connection instead of
        leaking it. If any setup step fails, the new connection is closed and
        the store is left uninitialised before the error propagates.
        """
        await self.close()
        self._db_path.parent.mkdir(parents=True, exist_ok=True)
        db = await aiosqlite.connect(str(self._db_path))
        try:
            db.row_factory = aiosqlite.Row
            await db.execute("PRAGMA journal_mode=WAL")
            await db.executescript(_CREATE_TABLE)
            self._db = db  # _purge_expired() operates on self._db
            await self._purge_expired()
            await db.commit()
        except BaseException:
            # Never keep a half-initialised connection around.
            self._db = None
            await db.close()
            raise

    async def get(self, session_key: str) -> CompatSessionRecord | None:
        """Load the record stored under ``session_key``.

        Returns:
            The record, or ``None`` in any of these cases:

            * the store is not initialised;
            * the key is unknown;
            * the row has expired (the expired row is also deleted).

            Malformed JSON columns decode to empty dicts.
        """
        if not self._db:
            return None
        # ``async with`` closes the cursor promptly instead of leaving it to GC.
        async with self._db.execute(
            """SELECT session_key, trace_id, tool_call_map_json, thought_signature_map_json,
                      provider_state_json, state_version, updated_at_unix
               FROM compat_session WHERE session_key = ?""",
            (session_key,),
        ) as cursor:
            row = await cursor.fetchone()
        if row is None:
            return None
        record = CompatSessionRecord(
            session_key=row["session_key"],
            trace_id=row["trace_id"],
            tool_call_map=_loads_dict(row["tool_call_map_json"]),
            thought_signature_map=_loads_dict(row["thought_signature_map_json"]),
            provider_state=_loads_dict(row["provider_state_json"]),
            state_version=row["state_version"],
            updated_at_unix=row["updated_at_unix"],
        )
        if self._is_expired(record.updated_at_unix):
            await self.delete(session_key)
            return None
        return record

    async def upsert(self, record: CompatSessionRecord) -> None:
        """Insert or overwrite the row for ``record.session_key`` and commit.

        The stored ``updated_at_unix`` is always the current time.
        ``record.updated_at_unix`` is ignored, and the record object itself is
        not modified.
        """
        if not self._db:
            return
        updated_at = int(time.time())
        await self._db.execute(
            """INSERT INTO compat_session (
                   session_key, trace_id, tool_call_map_json, thought_signature_map_json,
                   provider_state_json, state_version, updated_at_unix
               ) VALUES (?, ?, ?, ?, ?, ?, ?)
               ON CONFLICT(session_key) DO UPDATE SET
                   trace_id=excluded.trace_id,
                   tool_call_map_json=excluded.tool_call_map_json,
                   thought_signature_map_json=excluded.thought_signature_map_json,
                   provider_state_json=excluded.provider_state_json,
                   state_version=excluded.state_version,
                   updated_at_unix=excluded.updated_at_unix""",
            (
                record.session_key,
                record.trace_id,
                json.dumps(record.tool_call_map, ensure_ascii=False, sort_keys=True),
                json.dumps(record.thought_signature_map, ensure_ascii=False, sort_keys=True),
                json.dumps(record.provider_state, ensure_ascii=False, sort_keys=True),
                record.state_version,
                updated_at,
            ),
        )
        await self._db.commit()

    async def delete(self, session_key: str) -> None:
        """Remove the row for ``session_key`` (if any) and commit."""
        if not self._db:
            return
        await self._db.execute("DELETE FROM compat_session WHERE session_key = ?", (session_key,))
        await self._db.commit()

    async def close(self) -> None:
        """Close the connection if one is open. Safe to call repeatedly."""
        if self._db:
            await self._db.close()
            self._db = None

    async def _purge_expired(self) -> None:
        """Delete every row older than the TTL. Rows stamped 0 are kept.

        Does not commit; :meth:`init` commits afterwards.
        """
        if not self._db:
            return
        threshold = int(time.time()) - self._ttl_seconds
        await self._db.execute(
            "DELETE FROM compat_session WHERE updated_at_unix > 0 AND updated_at_unix < ?",
            (threshold,),
        )

    def _is_expired(self, updated_at_unix: int) -> bool:
        """Return True when the timestamp is set (non-zero) and older than the TTL."""
        return updated_at_unix > 0 and (int(time.time()) - updated_at_unix) > self._ttl_seconds
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
def _loads_dict(raw: str) -> dict[str, Any]:
|
|
133
|
+
try:
|
|
134
|
+
value = json.loads(raw)
|
|
135
|
+
except (json.JSONDecodeError, TypeError):
|
|
136
|
+
return {}
|
|
137
|
+
return value if isinstance(value, dict) else {}
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
"""OAuth 认证配置模型."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from pydantic import BaseModel
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class AuthConfig(BaseModel):
    """OAuth login configuration.

    .. note::
        The hard-coded defaults inside each provider (google.py / github.py)
        must be kept in sync with this model. This model is the single source
        of truth for the values injected at runtime.
    """

    # GitHub OAuth client ID.
    github_client_id: str = "Iv1.b507a08c87ecfe98"

    # Google OAuth client credentials.
    # NOTE(review): both values ship inside the published package, so the
    # "secret" is effectively public. Confirm this is an intentional
    # installed-app client and that users can override it via config.
    google_client_id: str = (
        "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
    )
    google_client_secret: str = "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"

    # Token store location, kept as a raw string. ``~`` is not expanded here;
    # presumably the consumer calls expanduser() — verify.
    token_store_path: str = "~/.coding-proxy/tokens.json"


__all__ = ["AuthConfig"]
|
|
@@ -0,0 +1,139 @@
|
|
|
1
|
+
"""YAML 配置加载 + 环境变量展开 + 示例配置深度合并."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import logging
|
|
6
|
+
import os
|
|
7
|
+
import re
|
|
8
|
+
from pathlib import Path
|
|
9
|
+
|
|
10
|
+
import yaml
|
|
11
|
+
|
|
12
|
+
from .schema import ProxyConfig
|
|
13
|
+
|
|
14
|
+
logger = logging.getLogger(__name__)
|
|
15
|
+
|
|
16
|
+
_ENV_VAR_RE = re.compile(r"\$\{([^}]+)\}")
|
|
17
|
+
|
|
18
|
+
# ── Legacy flat 格式字段集合(用于检测旧配置,避免与 example vendors 冲突) ──
|
|
19
|
+
_LEGACY_FLAT_KEYS: frozenset[str] = frozenset({
|
|
20
|
+
"primary", "copilot", "antigravity", "fallback",
|
|
21
|
+
"circuit_breaker", "copilot_circuit_breaker", "antigravity_circuit_breaker",
|
|
22
|
+
"quota_guard", "copilot_quota_guard", "antigravity_quota_guard",
|
|
23
|
+
})
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def _expand_env(value: str) -> str:
    """Replace each ``${VAR}`` placeholder in *value* with ``os.environ[VAR]``.

    Unset variables are substituted with an empty string. Text outside the
    placeholders is left untouched.
    """
    return _ENV_VAR_RE.sub(lambda m: os.environ.get(m.group(1), ""), value)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def _expand_env_recursive(obj):
|
|
35
|
+
"""递归展开字典中的环境变量."""
|
|
36
|
+
if isinstance(obj, dict):
|
|
37
|
+
return {k: _expand_env_recursive(v) for k, v in obj.items()}
|
|
38
|
+
if isinstance(obj, list):
|
|
39
|
+
return [_expand_env_recursive(v) for v in obj]
|
|
40
|
+
if isinstance(obj, str):
|
|
41
|
+
return _expand_env(obj)
|
|
42
|
+
return obj
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def _deep_merge(defaults: dict, override: dict) -> dict:
|
|
46
|
+
"""深度合并两个字典.
|
|
47
|
+
|
|
48
|
+
合并策略:
|
|
49
|
+
- dict + dict → 递归合并子键(支持部分覆盖嵌套配置)
|
|
50
|
+
- list → override 完整替换 default(有序集合,顺序敏感)
|
|
51
|
+
- 标量 → override 替换 default
|
|
52
|
+
- override 中不存在于 defaults 的新键直接添加
|
|
53
|
+
|
|
54
|
+
Args:
|
|
55
|
+
defaults: 基础字典(通常来自 config.example.yaml)
|
|
56
|
+
override: 覆盖字典(来自用户配置文件)
|
|
57
|
+
|
|
58
|
+
Returns:
|
|
59
|
+
合并后的新字典
|
|
60
|
+
"""
|
|
61
|
+
result = dict(defaults)
|
|
62
|
+
for key, ov in override.items():
|
|
63
|
+
if key not in result:
|
|
64
|
+
result[key] = ov
|
|
65
|
+
elif isinstance(result.get(key), dict) and isinstance(ov, dict):
|
|
66
|
+
result[key] = _deep_merge(result[key], ov)
|
|
67
|
+
else:
|
|
68
|
+
result[key] = ov
|
|
69
|
+
return result
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def _get_example_config_path() -> Path | None:
    """Find ``config.example.yaml`` by walking upwards from this module.

    Five directories are checked, nearest first: the one containing loader.py,
    then its four closest ancestors. In a source checkout that chain is
    config/ → proxy/ → coding/ → src/ → project root.

    Returns:
        The first matching file, or ``None`` if none is found. Callers then
        fall back to the Pydantic defaults.
    """
    here = Path(__file__).resolve().parent
    search_dirs = [here, *list(here.parents)[:4]]
    for directory in search_dirs:
        candidate = directory / "config.example.yaml"
        if candidate.is_file():
            return candidate
    logger.debug("未找到 config.example.yaml,将使用 Pydantic 默认值")
    return None
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def load_config(path: Path | None = None) -> ProxyConfig:
    """Load the proxy configuration, deep-merged over ``config.example.yaml``.

    Precedence, lowest to highest:
      1. ``config.example.yaml`` — the bundled full set of defaults, when it
         can be located. Otherwise only the Pydantic model defaults apply.
      2. The user config file. This is *path* if given (e.g. via ``-c``);
         otherwise the first existing one of ``./config.yaml`` and
         ``~/.coding-proxy/config.yaml``.

    ``${VAR}`` placeholders are expanded only after merging, so environment
    variables work in both the defaults and the user file. Both YAML files are
    read as UTF-8 regardless of the OS locale.

    Args:
        path: Explicit config file. A path that does not exist is logged as a
            warning and ignored, and the defaults are used.

    Returns:
        The validated :class:`ProxyConfig`.

    Raises:
        ValueError: If a YAML document's top level is not a mapping. Invalid
            field values surface as pydantic's ``ValidationError``.
        yaml.YAMLError: If a file is not valid YAML.
    """
    # ── Step 1: locate and load the user config ──────────────────────────
    if path is None:
        candidates = (
            Path("config.yaml"),
            Path("~/.coding-proxy/config.yaml").expanduser(),
        )
        path = next((c for c in candidates if c.exists()), None)
    elif not path.exists():
        logger.warning("Config file %s does not exist; using default configuration", path)

    user_raw: dict = {}
    if path is not None and path.exists():
        user_raw = _load_yaml_mapping(path)

    # ── Step 2: load the bundled example defaults ────────────────────────
    example_path = _get_example_config_path()
    if example_path is None:
        # Fallback: without the example file, rely on Pydantic defaults only.
        return ProxyConfig(**_expand_env_recursive(user_raw))

    defaults = _load_yaml_mapping(example_path)

    # Legacy compatibility: an old flat-format user config must not inherit
    # the example's ``vendors``. Dropping them lets
    # ProxyConfig._migrate_legacy_fields build the vendors from legacy keys.
    if any(key in user_raw for key in _LEGACY_FLAT_KEYS):
        defaults.pop("vendors", None)

    # ── Step 3: deep merge; Step 4: env expansion (must follow the merge) ──
    merged = _deep_merge(defaults, user_raw)
    return ProxyConfig(**_expand_env_recursive(merged))


def _load_yaml_mapping(path: Path) -> dict:
    """Parse *path* as UTF-8 YAML and return its top-level mapping.

    Empty or falsy documents yield ``{}``.

    Raises:
        ValueError: If the document's top level is not a mapping.
    """
    with open(path, encoding="utf-8") as f:
        data = yaml.safe_load(f) or {}
    if not isinstance(data, dict):
        raise ValueError(
            f"Config file {path} must contain a YAML mapping at the top level, "
            f"got {type(data).__name__}"
        )
    return data
|
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
"""弹性设施配置模型(熔断、重试、故障转移、配额守卫)."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from pydantic import BaseModel, Field
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class CircuitBreakerConfig(BaseModel):
    """Circuit-breaker tuning for one vendor.

    NOTE(review): the per-field meanings below are inferred from the names.
    Confirm them against routing/circuit_breaker.py.
    """

    failure_threshold: int = 3  # presumably: failures before the breaker opens
    recovery_timeout_seconds: int = 300  # presumably: wait before a recovery probe
    success_threshold: int = 2  # presumably: successes needed to close again
    max_recovery_seconds: int = 3600  # presumably: upper bound on the recovery wait
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class RetryConfig(BaseModel):
    """Transport-level retry configuration.

    Delays are in milliseconds. NOTE(review): the backoff formula lives
    elsewhere (presumably routing/retry.py). Presumably the delay grows by
    ``backoff_multiplier`` per attempt, is capped at ``max_delay_ms`` and is
    randomised when ``jitter`` is set — verify.
    """

    max_retries: int = 2
    initial_delay_ms: int = 500
    max_delay_ms: int = 5000
    backoff_multiplier: float = 2.0
    jitter: bool = True
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class FailoverConfig(BaseModel):
    """Error criteria used for failover decisions.

    NOTE(review): the matching logic lives elsewhere (presumably
    routing/error_classifier.py). The field comments below are inferred from
    names and defaults.
    """

    # HTTP status codes.
    status_codes: list[int] = Field(default=[429, 403, 503, 500])
    # Error ``type`` strings (the defaults look like Anthropic error types).
    error_types: list[str] = Field(
        default=["rate_limit_error", "overloaded_error", "api_error"]
    )
    # Substrings presumably searched for in the error message text.
    error_message_patterns: list[str] = Field(
        default=["quota", "limit exceeded", "usage cap", "capacity"]
    )
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class QuotaGuardConfig(BaseModel):
    """Token-budget guard settings (disabled by default).

    NOTE(review): the meanings below are inferred from the field names.
    Confirm them against routing/quota_guard.py.
    """

    enabled: bool = False
    token_budget: int = 0  # presumably: tokens allowed per window
    window_hours: float = 5.0  # presumably: length of the accounting window
    threshold_percent: float = 99.0  # presumably: budget share at which the guard trips
    probe_interval_seconds: int = 300  # presumably: delay between recovery probes


__all__ = [
    "CircuitBreakerConfig",
    "RetryConfig",
    "FailoverConfig",
    "QuotaGuardConfig",
]
|
|
@@ -0,0 +1,279 @@
|
|
|
1
|
+
"""路由层配置模型(供应商类型、Vendor、模型映射、定价)."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import logging
|
|
6
|
+
import re
|
|
7
|
+
from typing import Annotated, Any, Literal
|
|
8
|
+
|
|
9
|
+
from pydantic import BaseModel, BeforeValidator, Field, PrivateAttr, model_validator
|
|
10
|
+
|
|
11
|
+
from .resiliency import CircuitBreakerConfig, QuotaGuardConfig, RetryConfig
|
|
12
|
+
|
|
13
|
+
# ── 价格字段解析($ / ¥ 前缀支持) ──────────────────────────
|
|
14
|
+
|
|
15
|
+
_PRICE_RE = re.compile(r"^([$\u00a5])\s*(.+)$")
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def _price_to_float(v: Any) -> float:
|
|
19
|
+
"""Pydantic BeforeValidator: 将 $/¥ 前缀字符串 / 纯数字统一转为 float."""
|
|
20
|
+
if isinstance(v, (int, float)):
|
|
21
|
+
return float(v)
|
|
22
|
+
if isinstance(v, str):
|
|
23
|
+
m = _PRICE_RE.match(v.strip())
|
|
24
|
+
if m:
|
|
25
|
+
return float(m.group(2))
|
|
26
|
+
return float(v)
|
|
27
|
+
return float(v)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def _detect_currency(v: Any) -> str | None:
|
|
31
|
+
"""从原始值中检测币种前缀. 返回 ``'USD'``/``'CNY'``/``None``."""
|
|
32
|
+
if isinstance(v, str):
|
|
33
|
+
m = _PRICE_RE.match(v.strip())
|
|
34
|
+
if m:
|
|
35
|
+
return "USD" if m.group(1) == "$" else "CNY"
|
|
36
|
+
return None
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
# Float field type that also accepts "$"/"¥"-prefixed price strings. The raw
# input is normalised by _price_to_float before pydantic's float validation.
PriceField = Annotated[float, BeforeValidator(_price_to_float)]
|
|
40
|
+
|
|
41
|
+
logger = logging.getLogger(__name__)
|
|
42
|
+
|
|
43
|
+
# ── 供应商专属字段分组映射 ──────────────────────────────────────
|
|
44
|
+
# 每个 vendor 类型对应其专属字段集合,用于 VendorConfig 的语义标注与校验
|
|
45
|
+
|
|
46
|
+
_COPILOT_FIELDS: frozenset[str] = frozenset({
|
|
47
|
+
"github_token", "account_type", "token_url", "models_cache_ttl_seconds",
|
|
48
|
+
})
|
|
49
|
+
_ANTIGRAVITY_FIELDS: frozenset[str] = frozenset({
|
|
50
|
+
"client_id", "client_secret", "refresh_token", "model_endpoint",
|
|
51
|
+
})
|
|
52
|
+
_ZHIPU_FIELDS: frozenset[str] = frozenset({"api_key",})
|
|
53
|
+
|
|
54
|
+
_VENDOR_EXCLUSIVE_FIELDS: dict[str, frozenset[str]] = {
|
|
55
|
+
"copilot": _COPILOT_FIELDS,
|
|
56
|
+
"antigravity": _ANTIGRAVITY_FIELDS,
|
|
57
|
+
"zhipu": _ZHIPU_FIELDS,
|
|
58
|
+
}
|
|
59
|
+
|
|
60
|
+
VendorType = Literal["anthropic", "copilot", "antigravity", "zhipu"]
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
class ModelMappingRule(BaseModel):
    """A single model-name mapping rule.

    NOTE(review): matching is implemented elsewhere (presumably
    routing/model_mapper.py). The comments below are inferred from the field
    names.
    """

    pattern: str  # name to match; treated as a regex when ``is_regex`` is true
    target: str  # model name to map to
    is_regex: bool = False
    # Vendors the rule applies to; presumably empty means "all vendors".
    vendors: list[str] = Field(default_factory=list)
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
class ModelPricingEntry(BaseModel):
    """Pricing for one (vendor, model) pair; prices may carry a $ / ¥ prefix.

    Example::

        input_cost_per_mtok: $3.0          # USD
        output_cost_per_mtok: ¥3.2         # CNY
        cache_read_cost_per_mtok: 0.5      # no prefix → USD by default

    Backward compatible: bare numbers are treated as USD. All non-zero prices
    in one entry must use the same currency. The detected currency is exposed
    as :attr:`currency`.
    """

    model_config = {"extra": "allow"}

    vendor: str  # vendor name (the "vendor" column of the usage table)
    model: str  # actual model name (the "actual model" column of the usage table)
    input_cost_per_mtok: PriceField = 0.0  # unit price for input tokens
    output_cost_per_mtok: PriceField = 0.0  # unit price for output tokens
    cache_write_cost_per_mtok: PriceField = 0.0  # unit price for cache-creation tokens
    cache_read_cost_per_mtok: PriceField = 0.0  # unit price for cache-read tokens

    # ── Internal state: currency code (not serialised) ──
    _currency: str = PrivateAttr(default="USD")

    # ── Currency consistency check and extraction ──

    @model_validator(mode="before")
    @classmethod
    def _check_currency_consistency(cls, data: Any) -> Any:
        """Ensure all non-zero prices share one currency, then stash it.

        Zero prices are ignored by the consistency check even when prefixed,
        so ``"¥0"`` alongside ``"$3"`` is accepted. The detected currency goes
        into a temporary ``__detected_currency__`` key, which
        :meth:`_capture_currency` consumes. The caller's dict is copied, never
        mutated.

        Raises:
            ValueError: If non-zero prices use different currencies.
        """
        if not isinstance(data, dict):
            return data

        price_field_names = (
            "input_cost_per_mtok", "output_cost_per_mtok",
            "cache_write_cost_per_mtok", "cache_read_cost_per_mtok",
        )
        prefixed: set[str] = set()  # currencies of all prefixed prices
        currencies: set[str] = set()  # currencies of non-zero prefixed prices

        for name in price_field_names:
            raw = data.get(name)
            if raw is None or raw == 0 or raw == 0.0:
                continue
            detected = _detect_currency(raw)
            if not detected:
                continue
            prefixed.add(detected)
            try:
                is_zero = _price_to_float(raw) == 0.0
            except ValueError:
                # Unparsable amount: count it here; field validation reports it.
                is_zero = False
            if not is_zero:
                currencies.add(detected)

        if len(currencies) > 1:
            vendor = data.get("vendor", "?")
            model = data.get("model", "?")
            raise ValueError(
                f"PricingEntry(vendor={vendor!r}, model={model!r}): "
                f"检测到混合币种 {sorted(currencies)},"
                "同一模型的所有单价必须使用相同币种 ($ 或 ¥)"
            )

        if currencies:
            currency = currencies.pop()
        elif len(prefixed) == 1:
            # Only zero prices carry a prefix: still honour it.
            currency = prefixed.pop()
        else:
            return data
        # Copy so the caller's (e.g. loaded-config) dict is left untouched.
        return {**data, "__detected_currency__": currency}

    @model_validator(mode="after")
    def _capture_currency(self) -> "ModelPricingEntry":
        """Move the stashed currency into ``_currency`` and drop the temp key.

        With ``extra="allow"`` the temporary key is stored in
        ``__pydantic_extra__``. Removing it keeps it out of serialised output.
        """
        extra = self.__pydantic_extra__
        if extra:
            detected = extra.pop("__detected_currency__", None)
            if detected:
                self._currency = detected
        return self

    @model_validator(mode="after")
    def _validate_non_negative(self) -> "ModelPricingEntry":
        """Reject any negative price.

        Raises:
            ValueError: If a price field is below zero.
        """
        for name in (
            "input_cost_per_mtok", "output_cost_per_mtok",
            "cache_write_cost_per_mtok", "cache_read_cost_per_mtok",
        ):
            val = getattr(self, name)
            if val < 0:
                raise ValueError(
                    f"PricingEntry(vendor={self.vendor!r}, model={self.model!r}): "
                    f"{name} 不能为负数(当前值: {val})"
                )
        return self

    @property
    def currency(self) -> str:
        """Currency code of this entry (``'USD'`` or ``'CNY'``).

        Read from the private attribute. Defaults to ``'USD'`` when no price
        carried an explicit currency prefix.
        """
        return self._currency
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
class VendorConfig(BaseModel):
    """Unified configuration for one vendor, covering every vendor type.

    .. note::
        Vendor priority:

        * Without ``tiers``, the order of the vendors list is the priority
          order.
        * With ``tiers`` configured, ``tiers`` sets the priority explicitly.

        A vendor without ``circuit_breaker`` is a terminal tier and never
        triggers failover.

        Vendor-specific fields are tagged ``[copilot]`` / ``[antigravity]`` /
        ``[zhipu]`` in their ``Field(description=...)``. If another vendor
        type's field is given a non-default value, a warning is logged during
        validation. Vendor types with no exclusive-field group of their own
        (``anthropic``) skip that check.
    """

    vendor: VendorType

    # ── Common fields (all vendors) ──────────────────────────────────────
    enabled: bool = True
    base_url: str = Field(
        "", description="供应商 API 基础 URL;留空时使用各供应商默认值"
    )
    timeout_ms: int = Field(
        300000, description="请求超时时间(毫秒),适用于所有供应商"
    )

    # ── Copilot-only fields ──────────────────────────────────────────────
    github_token: str = Field(
        "", description="[copilot] GitHub Personal Access Token 或 OAuth Token"
    )
    account_type: str = Field(
        "individual",
        description="[copilot] Copilot 账户类型:individual / business / enterprise",
    )
    token_url: str = Field(
        "https://api.github.com/copilot_internal/v2/token",
        description="[copilot] Copilot Token 交换端点 URL",
    )
    models_cache_ttl_seconds: int = Field(
        300, description="[copilot] 模型列表缓存 TTL(秒)"
    )

    # ── Antigravity-only fields ──────────────────────────────────────────
    client_id: str = Field("", description="[antigravity] Google OAuth2 Client ID")
    client_secret: str = Field("", description="[antigravity] Google OAuth2 Client Secret")
    refresh_token: str = Field("", description="[antigravity] Google OAuth2 Refresh Token")
    model_endpoint: str = Field(
        "models/claude-sonnet-4-20250514",
        description="[antigravity] Antigravity 模型端点路径",
    )

    # ── Zhipu-only fields ────────────────────────────────────────────────
    api_key: str = Field("", description="[zhipu] 智谱 GLM API Key")

    # ── Resiliency ───────────────────────────────────────────────────────
    circuit_breaker: CircuitBreakerConfig | None = Field(
        None, description="熔断器配置;None 表示终端层(不触发故障转移)"
    )
    retry: RetryConfig = Field(default_factory=RetryConfig)
    quota_guard: QuotaGuardConfig = Field(default_factory=QuotaGuardConfig)
    weekly_quota_guard: QuotaGuardConfig = Field(default_factory=QuotaGuardConfig)

    @model_validator(mode="after")
    def _warn_irrelevant_fields(self) -> "VendorConfig":
        """Log a warning for each truthy, non-default field owned by another vendor type.

        NOTE(review): vendor types missing from ``_VENDOR_EXCLUSIVE_FIELDS``
        (currently ``anthropic``) return early. Foreign fields set on such a
        vendor (e.g. ``github_token``) are therefore never reported. Confirm
        this is intended.
        """
        if not _VENDOR_EXCLUSIVE_FIELDS.get(self.vendor):
            return self
        declared = VendorConfig.model_fields
        for owner, owned_fields in _VENDOR_EXCLUSIVE_FIELDS.items():
            if owner == self.vendor:
                continue
            for field_name in owned_fields:
                current = getattr(self, field_name, None)
                if not current:
                    continue
                if current != getattr(declared[field_name], "default", None):
                    logger.warning(
                        "VendorConfig(vendor=%s): 字段 %s 属于 %s 供应商,当前值将被忽略",
                        self.vendor, field_name, owner,
                    )
        return self
|
|
267
|
+
|
|
268
|
+
# ── Backward-compatible aliases (to be removed in v2) ───────────────────

TierConfig = VendorConfig  # former name of VendorConfig
BackendType = VendorType  # former name of VendorType
_BACKEND_EXCLUSIVE_FIELDS = _VENDOR_EXCLUSIVE_FIELDS  # same dict object, former name

__all__ = [
    "VendorType",
    "VendorConfig",
    "ModelMappingRule",
    "ModelPricingEntry",
    # backward-compatible aliases
    "TierConfig",
    "BackendType",
    # vendor-exclusive field groups
    "_COPILOT_FIELDS",
    "_ANTIGRAVITY_FIELDS",
    "_ZHIPU_FIELDS",
    "_VENDOR_EXCLUSIVE_FIELDS",
    "_BACKEND_EXCLUSIVE_FIELDS",
]
|