coding-proxy 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (77)
  1. coding/__init__.py +0 -0
  2. coding/proxy/__init__.py +3 -0
  3. coding/proxy/__main__.py +5 -0
  4. coding/proxy/auth/__init__.py +13 -0
  5. coding/proxy/auth/providers/__init__.py +6 -0
  6. coding/proxy/auth/providers/base.py +35 -0
  7. coding/proxy/auth/providers/github.py +133 -0
  8. coding/proxy/auth/providers/google.py +237 -0
  9. coding/proxy/auth/runtime.py +122 -0
  10. coding/proxy/auth/store.py +74 -0
  11. coding/proxy/cli/__init__.py +151 -0
  12. coding/proxy/cli/auth_commands.py +224 -0
  13. coding/proxy/compat/__init__.py +30 -0
  14. coding/proxy/compat/canonical.py +193 -0
  15. coding/proxy/compat/session_store.py +137 -0
  16. coding/proxy/config/__init__.py +6 -0
  17. coding/proxy/config/auth_schema.py +24 -0
  18. coding/proxy/config/loader.py +139 -0
  19. coding/proxy/config/resiliency.py +46 -0
  20. coding/proxy/config/routing.py +279 -0
  21. coding/proxy/config/schema.py +280 -0
  22. coding/proxy/config/server.py +23 -0
  23. coding/proxy/config/vendors.py +53 -0
  24. coding/proxy/convert/__init__.py +14 -0
  25. coding/proxy/convert/anthropic_to_gemini.py +352 -0
  26. coding/proxy/convert/anthropic_to_openai.py +352 -0
  27. coding/proxy/convert/gemini_sse_adapter.py +169 -0
  28. coding/proxy/convert/gemini_to_anthropic.py +98 -0
  29. coding/proxy/convert/openai_to_anthropic.py +88 -0
  30. coding/proxy/logging/__init__.py +49 -0
  31. coding/proxy/logging/db.py +308 -0
  32. coding/proxy/logging/stats.py +129 -0
  33. coding/proxy/model/__init__.py +93 -0
  34. coding/proxy/model/auth.py +32 -0
  35. coding/proxy/model/compat.py +153 -0
  36. coding/proxy/model/constants.py +21 -0
  37. coding/proxy/model/pricing.py +70 -0
  38. coding/proxy/model/token.py +64 -0
  39. coding/proxy/model/vendor.py +218 -0
  40. coding/proxy/pricing.py +100 -0
  41. coding/proxy/routing/__init__.py +47 -0
  42. coding/proxy/routing/circuit_breaker.py +152 -0
  43. coding/proxy/routing/error_classifier.py +67 -0
  44. coding/proxy/routing/executor.py +453 -0
  45. coding/proxy/routing/model_mapper.py +90 -0
  46. coding/proxy/routing/quota_guard.py +169 -0
  47. coding/proxy/routing/rate_limit.py +159 -0
  48. coding/proxy/routing/retry.py +82 -0
  49. coding/proxy/routing/router.py +84 -0
  50. coding/proxy/routing/session_manager.py +62 -0
  51. coding/proxy/routing/tier.py +171 -0
  52. coding/proxy/routing/usage_parser.py +193 -0
  53. coding/proxy/routing/usage_recorder.py +131 -0
  54. coding/proxy/server/__init__.py +1 -0
  55. coding/proxy/server/app.py +142 -0
  56. coding/proxy/server/factory.py +175 -0
  57. coding/proxy/server/request_normalizer.py +139 -0
  58. coding/proxy/server/responses.py +74 -0
  59. coding/proxy/server/routes.py +264 -0
  60. coding/proxy/streaming/__init__.py +1 -0
  61. coding/proxy/streaming/anthropic_compat.py +484 -0
  62. coding/proxy/vendors/__init__.py +29 -0
  63. coding/proxy/vendors/anthropic.py +44 -0
  64. coding/proxy/vendors/antigravity.py +328 -0
  65. coding/proxy/vendors/base.py +353 -0
  66. coding/proxy/vendors/copilot.py +702 -0
  67. coding/proxy/vendors/copilot_models.py +438 -0
  68. coding/proxy/vendors/copilot_token_manager.py +167 -0
  69. coding/proxy/vendors/copilot_urls.py +16 -0
  70. coding/proxy/vendors/mixins.py +71 -0
  71. coding/proxy/vendors/token_manager.py +128 -0
  72. coding/proxy/vendors/zhipu.py +243 -0
  73. coding_proxy-0.1.0.dist-info/METADATA +184 -0
  74. coding_proxy-0.1.0.dist-info/RECORD +77 -0
  75. coding_proxy-0.1.0.dist-info/WHEEL +4 -0
  76. coding_proxy-0.1.0.dist-info/entry_points.txt +2 -0
  77. coding_proxy-0.1.0.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,137 @@
1
+ """兼容层会话状态持久化.
2
+
3
+ 数据类型 ``CompatSessionRecord`` 已迁移至 :mod:`coding.proxy.model.compat`。
4
+ 本文件保留 ``CompatSessionStore`` 持久化管理器,类型通过 re-export 提供。
5
+
6
+ .. deprecated::
7
+ 未来版本将移除类型 re-export,请直接从 :mod:`coding.proxy.model.compat` 导入。
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import json
13
+ import time
14
+ from pathlib import Path
15
+ from typing import Any
16
+
17
+ import aiosqlite
18
+
19
+ # noqa: F401
20
+ from ..model.compat import CompatSessionRecord
21
+
22
+
23
# Schema DDL executed by ``CompatSessionStore.init``. Both statements use
# IF NOT EXISTS, so re-running the script on an existing database is harmless.
# The ``*_json`` columns hold JSON-encoded objects; ``updated_at_unix`` is
# indexed because the TTL purge filters on it.
_CREATE_TABLE = """
CREATE TABLE IF NOT EXISTS compat_session (
    session_key TEXT PRIMARY KEY,
    trace_id TEXT NOT NULL DEFAULT '',
    tool_call_map_json TEXT NOT NULL DEFAULT '{}',
    thought_signature_map_json TEXT NOT NULL DEFAULT '{}',
    provider_state_json TEXT NOT NULL DEFAULT '{}',
    state_version INTEGER NOT NULL DEFAULT 1,
    updated_at_unix INTEGER NOT NULL DEFAULT 0
);
CREATE INDEX IF NOT EXISTS idx_compat_session_updated_at ON compat_session(updated_at_unix);
"""
35
+
36
+
37
class CompatSessionStore:
    """SQLite-backed persistence for compat-layer session state.

    One row per ``session_key`` stores the JSON-encoded maps of a
    :class:`CompatSessionRecord`. Rows whose ``updated_at_unix`` is older than
    ``ttl_seconds`` are expired. They are purged by :meth:`init` and deleted
    lazily when read through :meth:`get`. A timestamp of ``0`` never expires.

    Until :meth:`init` succeeds, and again after :meth:`close`, every method is
    a no-op and :meth:`get` returns ``None``.
    """

    def __init__(self, db_path: Path, ttl_seconds: int = 86400) -> None:
        """Create an unopened store.

        Args:
            db_path: SQLite database file; its parent directory is created by
                :meth:`init`.
            ttl_seconds: Age in seconds after which a session row counts as
                expired (default: one day).
        """
        self._db_path = db_path
        self._ttl_seconds = ttl_seconds
        self._db: aiosqlite.Connection | None = None

    async def init(self) -> None:
        """Open the database, create the schema if needed and purge expired rows.

        Calling this on an already-open store is a no-op, so a second call does
        not leak the first connection. If any setup step fails, the connection
        is closed again and the error re-raised, leaving the store unopened.
        """
        if self._db is not None:
            return
        self._db_path.parent.mkdir(parents=True, exist_ok=True)
        db = await aiosqlite.connect(str(self._db_path))
        # _purge_expired() works through self._db, so publish it before setup.
        self._db = db
        try:
            db.row_factory = aiosqlite.Row  # enables row["column"] access in get()
            await db.execute("PRAGMA journal_mode=WAL")
            await db.executescript(_CREATE_TABLE)
            await self._purge_expired()
            await db.commit()
        except BaseException:
            self._db = None
            await db.close()
            raise

    async def get(self, session_key: str) -> CompatSessionRecord | None:
        """Return the stored record for *session_key*, or ``None``.

        ``None`` is also returned when the store is not open, and for an
        expired row, which is deleted as a side effect. Corrupt JSON columns
        decode to empty dicts via ``_loads_dict``.
        """
        if self._db is None:
            return None
        # Use the cursor as a context manager so the SELECT statement is
        # finalized before a possible DELETE/COMMIT on the same connection.
        async with self._db.execute(
            """SELECT session_key, trace_id, tool_call_map_json, thought_signature_map_json,
                      provider_state_json, state_version, updated_at_unix
               FROM compat_session WHERE session_key = ?""",
            (session_key,),
        ) as cursor:
            row = await cursor.fetchone()
        if row is None:
            return None
        record = CompatSessionRecord(
            session_key=row["session_key"],
            trace_id=row["trace_id"],
            tool_call_map=_loads_dict(row["tool_call_map_json"]),
            thought_signature_map=_loads_dict(row["thought_signature_map_json"]),
            provider_state=_loads_dict(row["provider_state_json"]),
            state_version=row["state_version"],
            updated_at_unix=row["updated_at_unix"],
        )
        if self._is_expired(record.updated_at_unix):
            await self.delete(session_key)
            return None
        return record

    async def upsert(self, record: CompatSessionRecord) -> None:
        """Insert or overwrite the row for ``record.session_key`` and commit.

        ``updated_at_unix`` is always stamped with the current time;
        ``record.updated_at_unix`` is ignored. Map fields are stored as JSON
        with sorted keys and non-ASCII characters kept as-is.
        """
        if self._db is None:
            return
        updated_at = int(time.time())
        await self._db.execute(
            """INSERT INTO compat_session (
                session_key, trace_id, tool_call_map_json, thought_signature_map_json,
                provider_state_json, state_version, updated_at_unix
            ) VALUES (?, ?, ?, ?, ?, ?, ?)
            ON CONFLICT(session_key) DO UPDATE SET
                trace_id=excluded.trace_id,
                tool_call_map_json=excluded.tool_call_map_json,
                thought_signature_map_json=excluded.thought_signature_map_json,
                provider_state_json=excluded.provider_state_json,
                state_version=excluded.state_version,
                updated_at_unix=excluded.updated_at_unix""",
            (
                record.session_key,
                record.trace_id,
                json.dumps(record.tool_call_map, ensure_ascii=False, sort_keys=True),
                json.dumps(record.thought_signature_map, ensure_ascii=False, sort_keys=True),
                json.dumps(record.provider_state, ensure_ascii=False, sort_keys=True),
                record.state_version,
                updated_at,
            ),
        )
        await self._db.commit()

    async def delete(self, session_key: str) -> None:
        """Delete the row for *session_key* (if any) and commit."""
        if self._db is None:
            return
        await self._db.execute("DELETE FROM compat_session WHERE session_key = ?", (session_key,))
        await self._db.commit()

    async def close(self) -> None:
        """Close the connection; safe to call repeatedly or on an unopened store.

        The store is marked closed before closing the connection, so a failure
        while closing does not leave a half-closed connection in use.
        """
        db, self._db = self._db, None
        if db is not None:
            await db.close()

    async def _purge_expired(self) -> None:
        """Delete rows older than the TTL. The caller is responsible for committing."""
        if self._db is None:
            return
        threshold = int(time.time()) - self._ttl_seconds
        await self._db.execute(
            "DELETE FROM compat_session WHERE updated_at_unix > 0 AND updated_at_unix < ?",
            (threshold,),
        )

    def _is_expired(self, updated_at_unix: int) -> bool:
        """Return True if a non-zero timestamp is older than the TTL."""
        return updated_at_unix > 0 and (int(time.time()) - updated_at_unix) > self._ttl_seconds
130
+
131
+
132
+ def _loads_dict(raw: str) -> dict[str, Any]:
133
+ try:
134
+ value = json.loads(raw)
135
+ except (json.JSONDecodeError, TypeError):
136
+ return {}
137
+ return value if isinstance(value, dict) else {}
@@ -0,0 +1,6 @@
1
+ """配置模块."""
2
+
3
+ from .loader import load_config
4
+ from .schema import ProxyConfig
5
+
6
+ __all__ = ["load_config", "ProxyConfig"]
@@ -0,0 +1,24 @@
1
+ """OAuth 认证配置模型."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from pydantic import BaseModel
6
+
7
+
8
class AuthConfig(BaseModel):
    """OAuth login configuration.

    .. note::
        The hard-coded defaults inside each provider (google.py / github.py)
        should be kept in sync with this config. This config is the
        authoritative source (single source of truth) for values injected at
        runtime.
    """

    # Default OAuth client id for the GitHub login flow.
    github_client_id: str = "Iv1.b507a08c87ecfe98"
    # Default Google OAuth client id. It is split across two literals only for
    # line length; implicit concatenation yields a single string.
    google_client_id: str = (
        "1071006060591-tmhssin2h21lcre235vtolojh4g403ep"
        ".apps.googleusercontent.com"
    )
    # NOTE(review): a client secret shipped as a source default presumably
    # belongs to an installed-app (non-confidential) OAuth client; confirm
    # before treating this value as sensitive or rotating it.
    google_client_secret: str = "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"
    # The leading "~" is not expanded here; presumably the token store expands
    # it (TODO confirm in auth/store.py).
    token_store_path: str = "~/.coding-proxy/tokens.json"
23
+
24
+ __all__ = ["AuthConfig"]
@@ -0,0 +1,139 @@
1
+ """YAML 配置加载 + 环境变量展开 + 示例配置深度合并."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import logging
6
+ import os
7
+ import re
8
+ from pathlib import Path
9
+
10
+ import yaml
11
+
12
+ from .schema import ProxyConfig
13
+
14
logger = logging.getLogger(__name__)

# Matches ``${NAME}`` placeholders; group 1 captures the variable name.
_ENV_VAR_RE = re.compile(r"\$\{([^}]+)\}")

# ── Legacy flat-format top-level keys. If any of them appears in the user
#    config, load_config() treats the file as legacy and drops the example
#    file's ``vendors`` section so the legacy migration can build vendors. ──
_LEGACY_FLAT_KEYS: frozenset[str] = frozenset({
    "primary", "copilot", "antigravity", "fallback",
    "circuit_breaker", "copilot_circuit_breaker", "antigravity_circuit_breaker",
    "quota_guard", "copilot_quota_guard", "antigravity_quota_guard",
})
24
+
25
+
26
def _expand_env(value: str) -> str:
    """Return *value* with each ``${NAME}`` replaced by the environment variable ``NAME``.

    Variables that are not set expand to an empty string; all other text is
    left untouched.
    """
    lookup = os.environ.get
    return _ENV_VAR_RE.sub(lambda found: lookup(found.group(1), ""), value)
32
+
33
+
34
+ def _expand_env_recursive(obj):
35
+ """递归展开字典中的环境变量."""
36
+ if isinstance(obj, dict):
37
+ return {k: _expand_env_recursive(v) for k, v in obj.items()}
38
+ if isinstance(obj, list):
39
+ return [_expand_env_recursive(v) for v in obj]
40
+ if isinstance(obj, str):
41
+ return _expand_env(obj)
42
+ return obj
43
+
44
+
45
+ def _deep_merge(defaults: dict, override: dict) -> dict:
46
+ """深度合并两个字典.
47
+
48
+ 合并策略:
49
+ - dict + dict → 递归合并子键(支持部分覆盖嵌套配置)
50
+ - list → override 完整替换 default(有序集合,顺序敏感)
51
+ - 标量 → override 替换 default
52
+ - override 中不存在于 defaults 的新键直接添加
53
+
54
+ Args:
55
+ defaults: 基础字典(通常来自 config.example.yaml)
56
+ override: 覆盖字典(来自用户配置文件)
57
+
58
+ Returns:
59
+ 合并后的新字典
60
+ """
61
+ result = dict(defaults)
62
+ for key, ov in override.items():
63
+ if key not in result:
64
+ result[key] = ov
65
+ elif isinstance(result.get(key), dict) and isinstance(ov, dict):
66
+ result[key] = _deep_merge(result[key], ov)
67
+ else:
68
+ result[key] = ov
69
+ return result
70
+
71
+
72
def _get_example_config_path() -> Path | None:
    """Find ``config.example.yaml`` by walking upward from this module's directory.

    The search starts in the directory containing loader.py and checks at most
    five directories. In a source checkout that covers
    config/ -> proxy/ -> coding/ -> src/ -> project root. The first regular
    file found wins.

    Returns:
        The file's path, or ``None`` if it was not found. The caller then falls
        back to the Pydantic defaults.
    """
    start = Path(__file__).resolve().parent
    search_dirs = [start, *start.parents][:5]
    for directory in search_dirs:
        candidate = directory / "config.example.yaml"
        if candidate.is_file():
            return candidate
    logger.debug("未找到 config.example.yaml,将使用 Pydantic 默认值")
    return None
90
+
91
+
92
def _load_yaml_mapping(path: Path) -> dict:
    """Parse *path* as YAML and return its top-level mapping.

    An empty (or otherwise falsy) document yields ``{}``, as before.

    Raises:
        ValueError: if the document's top level is not a mapping.
    """
    # Binary mode lets PyYAML detect the encoding (UTF-8/UTF-16 per the YAML
    # spec) instead of relying on the platform default text encoding, which
    # can fail on non-ASCII content such as Chinese comments.
    with open(path, "rb") as f:
        data = yaml.safe_load(f) or {}
    if not isinstance(data, dict):
        raise ValueError(
            f"Config file {path} must contain a YAML mapping at the top level, "
            f"got {type(data).__name__}"
        )
    return data


def load_config(path: Path | None = None) -> ProxyConfig:
    """Load the proxy configuration, layered on top of ``config.example.yaml``.

    Layers, from lowest to highest precedence:
      1. The shipped ``config.example.yaml`` defaults, if the file can be located.
      2. The user config file. This is *path* when given (e.g. via ``-c``);
         otherwise the first existing of ``./config.yaml`` and
         ``~/.coding-proxy/config.yaml``.

    ``${VAR}`` environment references are expanded after merging, so
    placeholders coming from either layer are resolved.

    Args:
        path: Explicit config file. A ``str`` is accepted and converted to
            ``Path``. A missing explicit file is logged as a warning and ignored.

    Returns:
        The validated :class:`ProxyConfig`.

    Raises:
        ValueError: if a config file's top level is not a YAML mapping.
    """
    # ── Step 1: locate and load the user config ─────────────────────
    if path is not None:
        path = Path(path)
        if not path.exists():
            logger.warning("Config file %s does not exist; using defaults only", path)
    else:
        for candidate in (
            Path("config.yaml"),
            Path("~/.coding-proxy/config.yaml").expanduser(),
        ):
            if candidate.exists():
                path = candidate
                break

    user_raw: dict = {}
    if path is not None and path.exists():
        user_raw = _load_yaml_mapping(path)

    # ── Step 2: load the example defaults ───────────────────────────
    example_path = _get_example_config_path()
    if example_path is None:
        # Fallback: no example file, so rely purely on Pydantic defaults.
        return ProxyConfig(**_expand_env_recursive(user_raw))

    defaults = _load_yaml_mapping(example_path)

    # Legacy flat-format user configs must not inherit the example's vendors.
    # Dropping them lets ProxyConfig's legacy migration build vendors itself.
    if any(k in user_raw for k in _LEGACY_FLAT_KEYS):
        defaults.pop("vendors", None)

    # ── Step 3: deep merge ──────────────────────────────────────────
    merged = _deep_merge(defaults, user_raw)

    # ── Step 4: expand environment variables (must run after merging) ─
    expanded = _expand_env_recursive(merged)

    return ProxyConfig(**expanded)
@@ -0,0 +1,46 @@
1
+ """弹性设施配置模型(熔断、重试、故障转移、配额守卫)."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from pydantic import BaseModel, Field
6
+
7
+
8
class CircuitBreakerConfig(BaseModel):
    """Circuit-breaker settings for one vendor.

    NOTE(review): the logic consuming these values is not in this file
    (presumably routing/circuit_breaker.py); the comments below are inferred
    from field names.
    """

    failure_threshold: int = 3  # presumably failures needed to open the breaker
    recovery_timeout_seconds: int = 300  # presumably wait before probing again (seconds)
    success_threshold: int = 2  # presumably successes needed to close it again
    max_recovery_seconds: int = 3600  # presumably upper bound on the recovery wait (seconds)
13
+
14
+
15
class RetryConfig(BaseModel):
    """Transport-layer retry configuration.

    NOTE(review): the retry loop is not in this file (presumably
    routing/retry.py); semantics below are inferred from field names.
    """

    max_retries: int = 2  # presumably retries after the first attempt — TODO confirm
    initial_delay_ms: int = 500  # first backoff delay, in milliseconds
    max_delay_ms: int = 5000  # cap on a single backoff delay, in milliseconds
    backoff_multiplier: float = 2.0  # presumably per-attempt delay growth factor
    jitter: bool = True  # presumably randomizes each delay
23
+
24
+
25
# Error criteria consulted when deciding whether to fail over to another
# vendor. The matching itself happens elsewhere (presumably
# routing/error_classifier.py — not visible here).
class FailoverConfig(BaseModel):
    status_codes: list[int] = [429, 403, 503, 500]
    error_types: list[str] = ["rate_limit_error", "overloaded_error", "api_error"]
    error_message_patterns: list[str] = ["quota", "limit exceeded", "usage cap", "capacity"]
35
+
36
+
37
class QuotaGuardConfig(BaseModel):
    """Quota-guard settings; the guard is disabled by default.

    NOTE(review): enforcement is not in this file (presumably
    routing/quota_guard.py); field comments are inferred from names.
    """

    enabled: bool = False
    token_budget: int = 0  # token budget per window; presumably 0 means "not configured"
    window_hours: float = 5.0  # accounting window length, in hours
    threshold_percent: float = 99.0  # presumably % of the budget at which the guard trips
    probe_interval_seconds: int = 300  # presumably interval between recovery probes (seconds)
43
+
44
+ __all__ = [
45
+ "CircuitBreakerConfig", "RetryConfig", "FailoverConfig", "QuotaGuardConfig",
46
+ ]
@@ -0,0 +1,279 @@
1
+ """路由层配置模型(供应商类型、Vendor、模型映射、定价)."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import logging
6
+ import re
7
+ from typing import Annotated, Any, Literal
8
+
9
+ from pydantic import BaseModel, BeforeValidator, Field, PrivateAttr, model_validator
10
+
11
+ from .resiliency import CircuitBreakerConfig, QuotaGuardConfig, RetryConfig
12
+
13
+ # ── 价格字段解析($ / ¥ 前缀支持) ──────────────────────────
14
+
15
+ _PRICE_RE = re.compile(r"^([$\u00a5])\s*(.+)$")
16
+
17
+
18
+ def _price_to_float(v: Any) -> float:
19
+ """Pydantic BeforeValidator: 将 $/¥ 前缀字符串 / 纯数字统一转为 float."""
20
+ if isinstance(v, (int, float)):
21
+ return float(v)
22
+ if isinstance(v, str):
23
+ m = _PRICE_RE.match(v.strip())
24
+ if m:
25
+ return float(m.group(2))
26
+ return float(v)
27
+ return float(v)
28
+
29
+
30
+ def _detect_currency(v: Any) -> str | None:
31
+ """从原始值中检测币种前缀. 返回 ``'USD'``/``'CNY'``/``None``."""
32
+ if isinstance(v, str):
33
+ m = _PRICE_RE.match(v.strip())
34
+ if m:
35
+ return "USD" if m.group(1) == "$" else "CNY"
36
+ return None
37
+
38
+
39
# float field type that also accepts "$"/"¥"-prefixed strings via _price_to_float.
PriceField = Annotated[float, BeforeValidator(_price_to_float)]

logger = logging.getLogger(__name__)

# ── Vendor-exclusive field groups ───────────────────────────────────────
# Maps each vendor type to the VendorConfig fields that apply only to it.
# VendorConfig's validator uses this to warn about fields set for another type.

_COPILOT_FIELDS: frozenset[str] = frozenset({
    "github_token", "account_type", "token_url", "models_cache_ttl_seconds",
})
_ANTIGRAVITY_FIELDS: frozenset[str] = frozenset({
    "client_id", "client_secret", "refresh_token", "model_endpoint",
})
_ZHIPU_FIELDS: frozenset[str] = frozenset({"api_key",})

# "anthropic" has no exclusive fields and therefore no entry here.
_VENDOR_EXCLUSIVE_FIELDS: dict[str, frozenset[str]] = {
    "copilot": _COPILOT_FIELDS,
    "antigravity": _ANTIGRAVITY_FIELDS,
    "zhipu": _ZHIPU_FIELDS,
}

# Vendor types accepted by VendorConfig.vendor.
VendorType = Literal["anthropic", "copilot", "antigravity", "zhipu"]
61
+
62
+
63
class ModelMappingRule(BaseModel):
    """A single model-name mapping rule.

    NOTE(review): the matching logic lives outside this file (presumably
    routing/model_mapper.py); the field comments below are inferred from names.
    """

    pattern: str  # model-name pattern; treated as a regex when is_regex is True
    target: str  # presumably the model name substituted on a match
    is_regex: bool = False
    vendors: list[str] = Field(default_factory=list)  # presumably restricts the rule to these vendors; empty = all (TODO confirm)
68
+
69
+
70
# Names of the price fields that accept an optional "$" / "¥" currency prefix.
_PRICE_FIELD_NAMES: tuple[str, ...] = (
    "input_cost_per_mtok", "output_cost_per_mtok",
    "cache_write_cost_per_mtok", "cache_read_cost_per_mtok",
)
# Transient extra key that carries the detected currency from the mode="before"
# validator to the mode="after" validator; it is removed before validation ends.
_DETECTED_CURRENCY_KEY = "__detected_currency__"


class ModelPricingEntry(BaseModel):
    """Pricing for one (vendor, model) pair; prices may carry a ``$``/``¥`` prefix.

    Example::

        input_cost_per_mtok: $3.0          # USD
        output_cost_per_mtok: ¥3.2         # CNY
        cache_read_cost_per_mtok: 0.5      # no prefix

    All non-zero prefixed prices of one entry must use the same currency, which
    becomes the entry's :attr:`currency`. When no price carries a prefix, the
    currency defaults to USD (backward compatible with plain-number configs).
    """

    model_config = {"extra": "allow"}

    vendor: str  # vendor name (matches the usage table's "vendor" column)
    model: str  # actual model name (matches the usage table's "actual model" column)
    input_cost_per_mtok: PriceField = 0.0  # input token unit price
    output_cost_per_mtok: PriceField = 0.0  # output token unit price
    cache_write_cost_per_mtok: PriceField = 0.0  # cache-creation token unit price
    cache_read_cost_per_mtok: PriceField = 0.0  # cache-read token unit price

    # Internal state: currency code of this entry (private, not serialized).
    _currency: str = PrivateAttr(default="USD")

    @model_validator(mode="before")
    @classmethod
    def _check_currency_consistency(cls, data: Any) -> Any:
        """Reject mixed currencies within one entry and record the detected one.

        Missing and zero prices are ignored. The detected currency is handed to
        ``_capture_currency`` through a transient key on a *copy* of the input,
        so the caller's mapping (e.g. the parsed config) is never modified.

        Raises:
            ValueError: if prefixed prices use more than one currency.
        """
        if not isinstance(data, dict):
            return data

        currencies: set[str] = set()
        for name in _PRICE_FIELD_NAMES:
            raw = data.get(name)
            if raw is None or raw == 0:
                continue
            detected = _detect_currency(raw)
            if detected:
                currencies.add(detected)

        if len(currencies) > 1:
            vendor = data.get("vendor", "?")
            model = data.get("model", "?")
            raise ValueError(
                f"PricingEntry(vendor={vendor!r}, model={model!r}): "
                f"检测到混合币种 {sorted(currencies)},"
                "同一模型的所有单价必须使用相同币种 ($ 或 ¥)"
            )

        if currencies:
            # Copy instead of mutating: writing into ``data`` would leak the
            # transient key into the caller's dict.
            data = {**data, _DETECTED_CURRENCY_KEY: currencies.pop()}
        return data

    @model_validator(mode="after")
    def _capture_currency(self) -> "ModelPricingEntry":
        """Move the currency found by the mode="before" validator into ``_currency``.

        The transient key is popped from ``__pydantic_extra__`` so it never
        appears in serialized output.
        """
        extra = self.__pydantic_extra__
        detected = extra.pop(_DETECTED_CURRENCY_KEY, None) if extra else None
        if detected:
            self._currency = detected
        return self

    @model_validator(mode="after")
    def _validate_non_negative(self) -> "ModelPricingEntry":
        """Reject negative prices.

        Raises:
            ValueError: naming the first negative price field.
        """
        for name in _PRICE_FIELD_NAMES:
            val = getattr(self, name)
            if val < 0:
                raise ValueError(
                    f"PricingEntry(vendor={self.vendor!r}, model={self.model!r}): "
                    f"{name} 不能为负数(当前值: {val})"
                )
        return self

    @property
    def currency(self) -> str:
        """Currency code of this entry (``'USD'`` or ``'CNY'``).

        Defaults to ``'USD'`` when no price carried an explicit prefix.
        """
        return self._currency
171
+
172
+
173
class VendorConfig(BaseModel):
    """Unified configuration for one vendor, covering every vendor type.

    .. note::
        Without ``tiers``, the order of the vendors list is the priority order;
        with ``tiers`` configured, that explicit order wins.

        A vendor without ``circuit_breaker`` is a terminal tier (it never
        triggers failover).

        Vendor-specific fields are tagged with ``[vendor]`` in their
        descriptions. During validation, a warning is logged for any such field
        that holds a non-default value while belonging to a different vendor
        type. This check is skipped for vendor types that have no exclusive
        fields of their own.
    """

    vendor: VendorType

    # ── Common fields (all vendors) ─────────────────────────────────────
    enabled: bool = True
    base_url: str = Field(default="", description="供应商 API 基础 URL;留空时使用各供应商默认值")
    timeout_ms: int = Field(default=300000, description="请求超时时间(毫秒),适用于所有供应商")

    # ── Copilot-only fields ─────────────────────────────────────────────
    github_token: str = Field(default="", description="[copilot] GitHub Personal Access Token 或 OAuth Token")
    account_type: str = Field(
        default="individual",
        description="[copilot] Copilot 账户类型:individual / business / enterprise",
    )
    token_url: str = Field(
        default="https://api.github.com/copilot_internal/v2/token",
        description="[copilot] Copilot Token 交换端点 URL",
    )
    models_cache_ttl_seconds: int = Field(default=300, description="[copilot] 模型列表缓存 TTL(秒)")

    # ── Antigravity-only fields ─────────────────────────────────────────
    client_id: str = Field(default="", description="[antigravity] Google OAuth2 Client ID")
    client_secret: str = Field(default="", description="[antigravity] Google OAuth2 Client Secret")
    refresh_token: str = Field(default="", description="[antigravity] Google OAuth2 Refresh Token")
    model_endpoint: str = Field(
        default="models/claude-sonnet-4-20250514",
        description="[antigravity] Antigravity 模型端点路径",
    )

    # ── Zhipu-only fields ───────────────────────────────────────────────
    api_key: str = Field(default="", description="[zhipu] 智谱 GLM API Key")

    # ── Resiliency ──────────────────────────────────────────────────────
    circuit_breaker: CircuitBreakerConfig | None = Field(
        default=None,
        description="熔断器配置;None 表示终端层(不触发故障转移)",
    )
    retry: RetryConfig = Field(default_factory=RetryConfig)
    quota_guard: QuotaGuardConfig = Field(default_factory=QuotaGuardConfig)
    weekly_quota_guard: QuotaGuardConfig = Field(default_factory=QuotaGuardConfig)

    @model_validator(mode="after")
    def _warn_irrelevant_fields(self) -> "VendorConfig":
        """Log a warning for each truthy, non-default field owned by another vendor type."""
        # NOTE(review): this early exit means a vendor type with no exclusive
        # fields (e.g. "anthropic") never warns, even if e.g. api_key is set.
        # Confirm that this is intended.
        if not _VENDOR_EXCLUSIVE_FIELDS.get(self.vendor):
            return self
        declared = VendorConfig.model_fields
        for owner, owned_fields in _VENDOR_EXCLUSIVE_FIELDS.items():
            if owner == self.vendor:
                continue
            for field_name in owned_fields:
                current = getattr(self, field_name, None)
                if not current or current == getattr(declared[field_name], "default", None):
                    continue
                logger.warning(
                    "VendorConfig(vendor=%s): 字段 %s 属于 %s 供应商,当前值将被忽略",
                    self.vendor, field_name, owner,
                )
        return self
267
+
268
# ── Backward-compatible aliases (to be removed in v2) ──────────────────

TierConfig = VendorConfig
BackendType = VendorType
_BACKEND_EXCLUSIVE_FIELDS = _VENDOR_EXCLUSIVE_FIELDS

# Underscore-prefixed names are listed explicitly so `from ... import *`
# re-exports them as well.
__all__ = [
    "VendorType", "VendorConfig", "ModelMappingRule", "ModelPricingEntry",
    "TierConfig", "BackendType",  # backward-compatible aliases
    "_COPILOT_FIELDS", "_ANTIGRAVITY_FIELDS", "_ZHIPU_FIELDS",
    "_VENDOR_EXCLUSIVE_FIELDS", "_BACKEND_EXCLUSIVE_FIELDS",
]