coding-proxy 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (77) hide show
  1. coding/__init__.py +0 -0
  2. coding/proxy/__init__.py +3 -0
  3. coding/proxy/__main__.py +5 -0
  4. coding/proxy/auth/__init__.py +13 -0
  5. coding/proxy/auth/providers/__init__.py +6 -0
  6. coding/proxy/auth/providers/base.py +35 -0
  7. coding/proxy/auth/providers/github.py +133 -0
  8. coding/proxy/auth/providers/google.py +237 -0
  9. coding/proxy/auth/runtime.py +122 -0
  10. coding/proxy/auth/store.py +74 -0
  11. coding/proxy/cli/__init__.py +151 -0
  12. coding/proxy/cli/auth_commands.py +224 -0
  13. coding/proxy/compat/__init__.py +30 -0
  14. coding/proxy/compat/canonical.py +193 -0
  15. coding/proxy/compat/session_store.py +137 -0
  16. coding/proxy/config/__init__.py +6 -0
  17. coding/proxy/config/auth_schema.py +24 -0
  18. coding/proxy/config/loader.py +139 -0
  19. coding/proxy/config/resiliency.py +46 -0
  20. coding/proxy/config/routing.py +279 -0
  21. coding/proxy/config/schema.py +280 -0
  22. coding/proxy/config/server.py +23 -0
  23. coding/proxy/config/vendors.py +53 -0
  24. coding/proxy/convert/__init__.py +14 -0
  25. coding/proxy/convert/anthropic_to_gemini.py +352 -0
  26. coding/proxy/convert/anthropic_to_openai.py +352 -0
  27. coding/proxy/convert/gemini_sse_adapter.py +169 -0
  28. coding/proxy/convert/gemini_to_anthropic.py +98 -0
  29. coding/proxy/convert/openai_to_anthropic.py +88 -0
  30. coding/proxy/logging/__init__.py +49 -0
  31. coding/proxy/logging/db.py +308 -0
  32. coding/proxy/logging/stats.py +129 -0
  33. coding/proxy/model/__init__.py +93 -0
  34. coding/proxy/model/auth.py +32 -0
  35. coding/proxy/model/compat.py +153 -0
  36. coding/proxy/model/constants.py +21 -0
  37. coding/proxy/model/pricing.py +70 -0
  38. coding/proxy/model/token.py +64 -0
  39. coding/proxy/model/vendor.py +218 -0
  40. coding/proxy/pricing.py +100 -0
  41. coding/proxy/routing/__init__.py +47 -0
  42. coding/proxy/routing/circuit_breaker.py +152 -0
  43. coding/proxy/routing/error_classifier.py +67 -0
  44. coding/proxy/routing/executor.py +453 -0
  45. coding/proxy/routing/model_mapper.py +90 -0
  46. coding/proxy/routing/quota_guard.py +169 -0
  47. coding/proxy/routing/rate_limit.py +159 -0
  48. coding/proxy/routing/retry.py +82 -0
  49. coding/proxy/routing/router.py +84 -0
  50. coding/proxy/routing/session_manager.py +62 -0
  51. coding/proxy/routing/tier.py +171 -0
  52. coding/proxy/routing/usage_parser.py +193 -0
  53. coding/proxy/routing/usage_recorder.py +131 -0
  54. coding/proxy/server/__init__.py +1 -0
  55. coding/proxy/server/app.py +142 -0
  56. coding/proxy/server/factory.py +175 -0
  57. coding/proxy/server/request_normalizer.py +139 -0
  58. coding/proxy/server/responses.py +74 -0
  59. coding/proxy/server/routes.py +264 -0
  60. coding/proxy/streaming/__init__.py +1 -0
  61. coding/proxy/streaming/anthropic_compat.py +484 -0
  62. coding/proxy/vendors/__init__.py +29 -0
  63. coding/proxy/vendors/anthropic.py +44 -0
  64. coding/proxy/vendors/antigravity.py +328 -0
  65. coding/proxy/vendors/base.py +353 -0
  66. coding/proxy/vendors/copilot.py +702 -0
  67. coding/proxy/vendors/copilot_models.py +438 -0
  68. coding/proxy/vendors/copilot_token_manager.py +167 -0
  69. coding/proxy/vendors/copilot_urls.py +16 -0
  70. coding/proxy/vendors/mixins.py +71 -0
  71. coding/proxy/vendors/token_manager.py +128 -0
  72. coding/proxy/vendors/zhipu.py +243 -0
  73. coding_proxy-0.1.0.dist-info/METADATA +184 -0
  74. coding_proxy-0.1.0.dist-info/RECORD +77 -0
  75. coding_proxy-0.1.0.dist-info/WHEEL +4 -0
  76. coding_proxy-0.1.0.dist-info/entry_points.txt +2 -0
  77. coding_proxy-0.1.0.dist-info/licenses/LICENSE +201 -0
coding/__init__.py ADDED
File without changes
@@ -0,0 +1,3 @@
1
+ """coding-proxy: multi-backend smart proxy for Claude Code."""
2
+
3
+ __version__ = "0.1.0"
@@ -0,0 +1,5 @@
1
+ """Entry point: uv run coding-proxy."""
2
+
3
+ from coding.proxy.cli import app
4
+
5
+ app()
@@ -0,0 +1,13 @@
1
+ """OAuth 认证管理模块."""
2
+
3
+ from .providers.base import OAuthProvider
4
+ from .providers.github import GitHubDeviceFlowProvider
5
+ from .providers.google import GoogleOAuthProvider
6
+ from .runtime import ReauthState, RuntimeReauthCoordinator
7
+ from .store import ProviderTokens, TokenStoreManager
8
+
9
# Public API of the auth package, listed in the order the names are imported.
__all__ = [
    "OAuthProvider",
    "GitHubDeviceFlowProvider",
    "GoogleOAuthProvider",
    "RuntimeReauthCoordinator",
    "ReauthState",
    "ProviderTokens",
    "TokenStoreManager",
]
@@ -0,0 +1,6 @@
1
+ """OAuth Provider 实现."""
2
+
3
+ from .github import GitHubDeviceFlowProvider
4
+ from .google import GoogleOAuthProvider
5
+
6
# Concrete OAuth provider implementations exported by this package.
__all__ = [
    "GitHubDeviceFlowProvider",
    "GoogleOAuthProvider",
]
@@ -0,0 +1,35 @@
1
+ """OAuth Provider 抽象基类."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from abc import ABC, abstractmethod
6
+
7
+ from ..store import ProviderTokens
8
+
9
+
10
class OAuthProvider(ABC):
    """Abstract interface implemented by every OAuth login provider."""

    @abstractmethod
    def get_name(self) -> str:
        """Return the provider's unique identifier (used as the token-store key)."""

    @abstractmethod
    async def login(self) -> ProviderTokens:
        """Run the OAuth login flow and return the tokens it obtained."""

    @abstractmethod
    async def refresh(self, tokens: ProviderTokens) -> ProviderTokens:
        """Use the refresh token to obtain a new access token."""

    @abstractmethod
    async def validate(self, tokens: ProviderTokens) -> bool:
        """Report whether the given tokens are still valid."""

    def needs_login(self, tokens: ProviderTokens) -> bool:
        """Decide whether a fresh login is required.

        The default policy asks for a login when there are no credentials
        at all, or when the tokens have expired and no refresh token is
        available to renew them.
        """
        if not tokens.has_credentials:
            return True
        # Expired tokens are only a problem when they cannot be refreshed.
        return bool(tokens.is_expired and not tokens.refresh_token)
@@ -0,0 +1,133 @@
1
+ """GitHub Device Authorization Flow — 浏览器免回调的 OAuth 登录."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import asyncio
6
+ import logging
7
+ import time
8
+ from typing import Any
9
+
10
+ import webbrowser
11
+
12
+ import httpx
13
+
14
+ from ..store import ProviderTokens
15
+ from .base import OAuthProvider
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+ # Public client_id of the GitHub Copilot VS Code extension
20
+ # SOT (source of truth): coding.proxy.config.schema.AuthConfig.github_client_id
21
+ # The default here is only a fallback; production deployments should override it via the auth section of config.yaml
22
+ _COPILOT_CLIENT_ID = "Iv1.b507a08c87ecfe98"
23
+
24
+ _DEVICE_CODE_URL = "https://github.com/login/device/code"
25
+ _ACCESS_TOKEN_URL = "https://github.com/login/oauth/access_token"
26
+ _POLL_INTERVAL = 5 # seconds; used when the device-code response carries no "interval"
27
+ _MAX_POLL_ATTEMPTS = 60 # 5 minutes total at the default 5-second interval
28
+ _COPILOT_PERMISSIVE_SCOPES = "read:user user:email repo workflow"
29
+
30
+
31
+ class GitHubDeviceFlowProvider(OAuthProvider):
32
+ """GitHub Device Authorization Flow 实现.
33
+
34
+ 无需本地 HTTP 服务器,用户在浏览器中输入 user_code 即可完成授权。
35
+ 获取的 GitHub OAuth token 可用于 Copilot token 交换。
36
+ """
37
+
38
+ def __init__(self, client_id: str = _COPILOT_CLIENT_ID) -> None:
39
+ self._client_id = client_id
40
+ self._http = httpx.AsyncClient(timeout=httpx.Timeout(30.0))
41
+
42
+ def get_name(self) -> str:
43
+ return "github"
44
+
45
+ async def login(self) -> ProviderTokens:
46
+ """执行 GitHub Device Flow,返回 OAuth token."""
47
+ # Step 1: 请求 device code
48
+ resp = await self._http.post(
49
+ _DEVICE_CODE_URL,
50
+ data={"client_id": self._client_id, "scope": _COPILOT_PERMISSIVE_SCOPES},
51
+ headers={"accept": "application/json"},
52
+ )
53
+ resp.raise_for_status()
54
+ device_data: dict[str, Any] = resp.json()
55
+
56
+ user_code = device_data["user_code"]
57
+ verification_uri = device_data["verification_uri"]
58
+ device_code = device_data["device_code"]
59
+ interval = device_data.get("interval", _POLL_INTERVAL)
60
+
61
+ # 优先使用预填充 user_code 的完整链接
62
+ verification_url = device_data.get(
63
+ "verification_uri_complete", verification_uri
64
+ )
65
+
66
+ # Step 2: 引导用户在浏览器中授权
67
+ logger.info("请在浏览器中访问 %s 并输入代码: %s", verification_uri, user_code)
68
+ print(f"\n 🔗 请在浏览器中访问: {verification_uri}")
69
+ print(f" 📋 并输入代码: {user_code}\n")
70
+
71
+ webbrowser.open(verification_url)
72
+
73
+ # Step 3: 轮询等待用户完成授权
74
+ for attempt in range(_MAX_POLL_ATTEMPTS):
75
+ await asyncio.sleep(interval)
76
+
77
+ token_resp = await self._http.post(
78
+ _ACCESS_TOKEN_URL,
79
+ data={
80
+ "client_id": self._client_id,
81
+ "device_code": device_code,
82
+ "grant_type": "urn:ietf:params:oauth:grant-type:device_code",
83
+ },
84
+ headers={"accept": "application/json"},
85
+ )
86
+ token_data = token_resp.json()
87
+
88
+ if "access_token" in token_data:
89
+ logger.info("GitHub OAuth 授权成功")
90
+ return ProviderTokens(
91
+ access_token=token_data["access_token"],
92
+ token_type=token_data.get("token_type", "bearer"),
93
+ scope=token_data.get("scope", ""),
94
+ )
95
+
96
+ error = token_data.get("error", "")
97
+ if error == "authorization_pending":
98
+ continue
99
+ elif error == "slow_down":
100
+ interval += 5
101
+ continue
102
+ elif error == "expired_token":
103
+ raise RuntimeError("Device code 已过期,请重新登录")
104
+ elif error == "access_denied":
105
+ raise RuntimeError("用户拒绝了授权")
106
+ else:
107
+ raise RuntimeError(f"GitHub OAuth 错误: {error}")
108
+
109
+ raise RuntimeError("GitHub Device Flow 超时,请重试")
110
+
111
+ async def refresh(self, tokens: ProviderTokens) -> ProviderTokens:
112
+ """GitHub Device Flow 不支持 refresh_token,需要重新登录."""
113
+ return await self.login()
114
+
115
+ async def validate(self, tokens: ProviderTokens) -> bool:
116
+ """验证 GitHub token 是否有效."""
117
+ if not tokens.access_token:
118
+ return False
119
+ try:
120
+ resp = await self._http.get(
121
+ "https://api.github.com/user",
122
+ headers={
123
+ "authorization": f"token {tokens.access_token}",
124
+ "accept": "application/json",
125
+ },
126
+ )
127
+ return resp.status_code == 200
128
+ except httpx.HTTPError:
129
+ return False
130
+
131
+ async def close(self) -> None:
132
+ if not self._http.is_closed:
133
+ await self._http.aclose()
@@ -0,0 +1,237 @@
1
+ """Google OAuth2 Authorization Code Flow — 本地回调服务器."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import asyncio
6
+ import logging
7
+ import secrets
8
+ import time
9
+ from http.server import HTTPServer, BaseHTTPRequestHandler
10
+ from typing import Any
11
+ from urllib.parse import parse_qs, urlencode, urlparse
12
+
13
+ import httpx
14
+
15
+ from ..store import ProviderTokens
16
+ from .base import OAuthProvider
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+ # Antigravity Enterprise 公开 OAuth 凭据
21
+ # SOT(权威源): coding.proxy.config.schema.AuthConfig
22
+ # 此处默认值仅作 fallback,生产环境应通过 config.yaml 的 auth 段覆盖
23
+ _DEFAULT_CLIENT_ID = (
24
+ "1071006060591-tmhssin2h21lcre235vtolojh4g403ep"
25
+ ".apps.googleusercontent.com"
26
+ )
27
+ _DEFAULT_CLIENT_SECRET = "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"
28
+
29
+ _AUTH_URL = "https://accounts.google.com/o/oauth2/v2/auth"
30
+ _TOKEN_URL = "https://oauth2.googleapis.com/token"
31
+ _SCOPES = [
32
+ "https://www.googleapis.com/auth/cloud-platform",
33
+ "https://www.googleapis.com/auth/userinfo.email",
34
+ "https://www.googleapis.com/auth/userinfo.profile",
35
+ "https://www.googleapis.com/auth/cclog",
36
+ "https://www.googleapis.com/auth/experimentsandconfigs",
37
+ ]
38
+ _REQUIRED_SCOPE_SET = frozenset(_SCOPES)
39
+
40
+
41
+ class _CallbackHandler(BaseHTTPRequestHandler):
42
+ """OAuth 回调 HTTP 处理器.
43
+
44
+ 使用实例级 result dict 避免类属性在并发场景下的交叉污染.
45
+ """
46
+
47
+ def __init__(self, *args: Any, result: dict[str, str | None], **kwargs: Any) -> None:
48
+ self._result = result
49
+ super().__init__(*args, **kwargs)
50
+
51
+ def do_GET(self) -> None:
52
+ parsed = urlparse(self.path)
53
+ params = parse_qs(parsed.query)
54
+
55
+ if parsed.path == "/callback":
56
+ if "error" in params:
57
+ self._result["error"] = params["error"][0]
58
+ self._respond("授权失败,请关闭此页面返回终端。")
59
+ elif "code" in params and "state" in params:
60
+ self._result["auth_code"] = params["code"][0]
61
+ self._result["state"] = params["state"][0]
62
+ self._respond("授权成功!请关闭此页面返回终端。")
63
+ else:
64
+ self._respond("无效的回调参数。")
65
+ else:
66
+ self.send_response(404)
67
+ self.end_headers()
68
+
69
+ def _respond(self, message: str) -> None:
70
+ self.send_response(200)
71
+ self.send_header("content-type", "text/html; charset=utf-8")
72
+ self.end_headers()
73
+ self.wfile.write(f"<html><body><h2>{message}</h2></body></html>".encode())
74
+
75
+ def log_message(self, format: str, *args: Any) -> None:
76
+ pass # 静默 HTTP 日志
77
+
78
+
79
class GoogleOAuthProvider(OAuthProvider):
    """Google OAuth2 Authorization Code Flow provider.

    Starts a short-lived HTTP server on 127.0.0.1 to capture the
    authorization code from the browser redirect, then exchanges the code
    for an access token and a refresh token.
    """

    def __init__(
        self,
        client_id: str = _DEFAULT_CLIENT_ID,
        client_secret: str = _DEFAULT_CLIENT_SECRET,
    ) -> None:
        """Store the OAuth client credentials and create the HTTP client (30 s timeout)."""
        self._client_id = client_id
        self._client_secret = client_secret
        self._http = httpx.AsyncClient(timeout=httpx.Timeout(30.0))

    def get_name(self) -> str:
        """Return the provider key, ``"google"``."""
        return "google"

    @staticmethod
    def has_required_scopes(scope: str) -> bool:
        """Return True when the space-separated ``scope`` string grants every required scope."""
        granted = set(scope.split())
        return _REQUIRED_SCOPE_SET.issubset(granted)

    @staticmethod
    def _absolute_expiry(data: dict[str, Any]) -> float:
        """Turn a token response's relative ``expires_in`` into an epoch timestamp.

        Returns 0.0 when the response carries no ``expires_in`` field.
        """
        if "expires_in" in data:
            return time.time() + data["expires_in"]
        return 0.0

    async def login(self) -> ProviderTokens:
        """Run the Google OAuth2 code flow and return the resulting tokens.

        Raises:
            RuntimeError: If the callback reports an error, no callback
                arrives within two minutes, or the returned ``state`` does
                not match the one that was sent.
            httpx.HTTPStatusError: If the code-for-token exchange is rejected.
        """
        state = secrets.token_urlsafe(32)
        result: dict[str, str | None] = {"auth_code": None, "state": None, "error": None}

        def _make_handler(*args: Any, **kwargs: Any) -> _CallbackHandler:
            return _CallbackHandler(*args, result=result, **kwargs)

        # Bind to port 0 so the OS picks a free port, avoiding a TOCTOU race.
        server = HTTPServer(("127.0.0.1", 0), _make_handler)
        try:
            # Without a timeout handle_request() blocks until a request
            # arrives, which would make the deadline below meaningless.
            server.timeout = 1.0
            redirect_port = server.server_address[1]
            redirect_uri = f"http://127.0.0.1:{redirect_port}/callback"

            params = urlencode({
                "client_id": self._client_id,
                "redirect_uri": redirect_uri,
                "response_type": "code",
                "scope": " ".join(_SCOPES),
                "state": state,
                "access_type": "offline",
                "prompt": "consent",
            })
            auth_url = f"{_AUTH_URL}?{params}"

            logger.info("请在浏览器中完成 Google 授权")
            print(f"\n 🔗 请在浏览器中访问以下链接完成授权:\n")
            print(f" {auth_url}\n")

            # Open the browser.
            import webbrowser
            webbrowser.open(auth_url)

            # Wait for the callback for at most two minutes. The blocking
            # handle_request() call runs in the default executor so the
            # event loop keeps serving other coroutines while we wait.
            loop = asyncio.get_running_loop()
            deadline = time.monotonic() + 120
            while time.monotonic() < deadline:
                await loop.run_in_executor(None, server.handle_request)
                if result["auth_code"] or result["error"]:
                    break
        finally:
            # Release the listening socket on every exit path, including
            # exceptions and task cancellation.
            server.server_close()

        if result["error"]:
            raise RuntimeError(f"Google OAuth 错误: {result['error']}")

        if not result["auth_code"]:
            raise RuntimeError("Google OAuth 超时,请重试")

        if result["state"] != state:
            raise RuntimeError("OAuth state 不匹配,可能遭受 CSRF 攻击")

        # Exchange code -> token.
        return await self._exchange_code(result["auth_code"], redirect_uri)

    async def _exchange_code(
        self, code: str, redirect_uri: str
    ) -> ProviderTokens:
        """Exchange an authorization code for an access token and refresh token.

        Raises:
            httpx.HTTPStatusError: If the token endpoint returns an error status.
        """
        resp = await self._http.post(
            _TOKEN_URL,
            data={
                "client_id": self._client_id,
                "client_secret": self._client_secret,
                "code": code,
                "redirect_uri": redirect_uri,
                "grant_type": "authorization_code",
            },
            headers={"content-type": "application/x-www-form-urlencoded"},
        )
        resp.raise_for_status()
        data = resp.json()

        return ProviderTokens(
            access_token=data.get("access_token", ""),
            refresh_token=data.get("refresh_token", ""),
            expires_at=self._absolute_expiry(data),
            scope=data.get("scope", ""),
            token_type=data.get("token_type", "bearer"),
        )

    async def refresh(self, tokens: ProviderTokens) -> ProviderTokens:
        """Refresh the access token, falling back to a full login.

        A full login is performed when no refresh token is stored or when
        the token endpoint answers with an error status.
        """
        if not tokens.refresh_token:
            return await self.login()

        resp = await self._http.post(
            _TOKEN_URL,
            data={
                "client_id": self._client_id,
                "client_secret": self._client_secret,
                "refresh_token": tokens.refresh_token,
                "grant_type": "refresh_token",
            },
            headers={"content-type": "application/x-www-form-urlencoded"},
        )

        if resp.status_code >= 400:
            logger.warning("Google token refresh 失败,需要重新登录")
            return await self.login()

        data = resp.json()

        return ProviderTokens(
            access_token=data.get("access_token", ""),
            # The refresh token normally stays the same, but keep a rotated
            # one if the response includes it so it is not lost.
            refresh_token=data.get("refresh_token") or tokens.refresh_token,
            expires_at=self._absolute_expiry(data),
            scope=data.get("scope", tokens.scope),
            token_type=data.get("token_type", "bearer"),
        )

    async def validate(self, tokens: ProviderTokens) -> bool:
        """Check the access token against Google's tokeninfo endpoint.

        Returns ``False`` for an empty token, a non-200 response, a request
        or body-decoding failure, or a token missing a required scope.
        """
        if not tokens.access_token:
            return False
        try:
            resp = await self._http.get(
                "https://www.googleapis.com/oauth2/v1/tokeninfo",
                params={"access_token": tokens.access_token},
            )
            if resp.status_code != 200:
                return False
            data = resp.json()
        except (httpx.HTTPError, ValueError):
            # ValueError covers a 200 response whose body is not valid JSON.
            return False
        return self.has_required_scopes(data.get("scope", tokens.scope))

    async def close(self) -> None:
        """Close the HTTP client unless it is already closed."""
        if not self._http.is_closed:
            await self._http.aclose()
@@ -0,0 +1,122 @@
1
+ """运行时 OAuth 重认证协调器 — 后台触发浏览器登录并热更新凭证."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import asyncio
6
+ import enum
7
+ import logging
8
+ import time
9
+ from typing import Callable
10
+
11
+ from .providers.base import OAuthProvider
12
+ from .store import TokenStoreManager
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+
17
+ class ReauthState(enum.Enum):
18
+ """重认证状态."""
19
+
20
+ IDLE = "idle"
21
+ PENDING = "pending"
22
+ COMPLETED = "completed"
23
+ FAILED = "failed"
24
+
25
+
26
+ class RuntimeReauthCoordinator:
27
+ """运行时 OAuth 重认证协调器.
28
+
29
+ 当 TokenManager 报告 needs_reauth=True 时,Router 调用
30
+ ``request_reauth()`` 在后台触发浏览器登录流程。
31
+
32
+ 与熔断器的协同:
33
+ - 重认证期间 TokenManager 持续抛 TokenAcquireError
34
+ - Router 触发 failover → 熔断器 OPEN → 请求路由到下一层级
35
+ - 重认证完成后 TokenManager 获得新凭证 → 熔断器恢复 → 后端可用
36
+ """
37
+
38
+ def __init__(
39
+ self,
40
+ token_store: TokenStoreManager,
41
+ providers: dict[str, OAuthProvider],
42
+ token_updaters: dict[str, Callable[[str], None]],
43
+ ) -> None:
44
+ """
45
+ Args:
46
+ token_store: Token 持久化管理器
47
+ providers: provider_name → OAuthProvider 实例
48
+ token_updaters: provider_name → 更新 TokenManager 凭证的回调
49
+ """
50
+ self._token_store = token_store
51
+ self._providers = providers
52
+ self._token_updaters = token_updaters
53
+ self._states: dict[str, ReauthState] = {
54
+ name: ReauthState.IDLE for name in providers
55
+ }
56
+ self._locks: dict[str, asyncio.Lock] = {
57
+ name: asyncio.Lock() for name in providers
58
+ }
59
+ self._last_error: dict[str, str] = {}
60
+ self._last_completed: dict[str, float] = {}
61
+
62
+ async def request_reauth(self, provider_name: str) -> None:
63
+ """请求对指定 provider 进行重认证(幂等,后台执行).
64
+
65
+ 若已在进行中,直接返回不重复触发。
66
+ """
67
+ if provider_name not in self._providers:
68
+ logger.warning("未知 provider: %s", provider_name)
69
+ return
70
+
71
+ if self._states.get(provider_name) == ReauthState.PENDING:
72
+ return # 已在进行中
73
+
74
+ asyncio.create_task(self._do_reauth(provider_name))
75
+
76
+ async def _do_reauth(self, provider_name: str) -> None:
77
+ """执行重认证流程(带锁保护的幂等实现)."""
78
+ lock = self._locks[provider_name]
79
+ if lock.locked():
80
+ return # 另一个任务正在执行
81
+
82
+ async with lock:
83
+ self._states[provider_name] = ReauthState.PENDING
84
+ logger.info("开始 %s 重认证...", provider_name)
85
+
86
+ try:
87
+ provider = self._providers[provider_name]
88
+ tokens = await provider.login()
89
+ self._token_store.set(provider_name, tokens)
90
+
91
+ # 调用热更新回调
92
+ updater = self._token_updaters.get(provider_name)
93
+ if updater:
94
+ # GitHub → access_token, Google → refresh_token
95
+ if provider_name == "github":
96
+ updater(tokens.access_token)
97
+ elif provider_name == "google":
98
+ updater(tokens.refresh_token)
99
+
100
+ self._states[provider_name] = ReauthState.COMPLETED
101
+ self._last_completed[provider_name] = time.monotonic()
102
+ self._last_error.pop(provider_name, None)
103
+ logger.info("%s 重认证成功", provider_name)
104
+
105
+ except Exception as exc:
106
+ self._states[provider_name] = ReauthState.FAILED
107
+ self._last_error[provider_name] = str(exc)
108
+ logger.error("%s 重认证失败: %s", provider_name, exc)
109
+
110
+ def get_status(self) -> dict[str, dict[str, str]]:
111
+ """返回所有 provider 的重认证状态."""
112
+ result = {}
113
+ for name in self._providers:
114
+ info: dict[str, str] = {"state": self._states[name].value}
115
+ if name in self._last_error:
116
+ info["error"] = self._last_error[name]
117
+ if name in self._last_completed:
118
+ info["completed_ago_seconds"] = str(
119
+ int(time.monotonic() - self._last_completed[name])
120
+ )
121
+ result[name] = info
122
+ return result
@@ -0,0 +1,74 @@
1
+ """Token 持久化存储 — ~/.coding-proxy/tokens.json.
2
+
3
+ ``ProviderTokens`` 数据模型已迁移至 :mod:`coding.proxy.model.auth`。
4
+ 本文件保留 ``TokenStoreManager`` 持久化管理器,类型通过 re-export 提供。
5
+
6
+ .. deprecated::
7
+ 未来版本将移除类型 re-export,请直接从 :mod:`coding.proxy.model.auth` 导入。
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import json
13
+ import logging
14
+ import time
15
+ from pathlib import Path
16
+ from typing import Any
17
+
18
+ # noqa: F401
19
+ from ..model.auth import ProviderTokens
20
+
21
+ logger = logging.getLogger(__name__)
22
+
23
+ _DEFAULT_STORE_PATH = Path("~/.coding-proxy/tokens.json")
24
+
25
+
26
class TokenStoreManager:
    """Persist the tokens of all providers in a single JSON file."""

    def __init__(self, store_path: Path | None = None) -> None:
        """Use ``store_path`` (or the default location) with ``~`` expanded.

        Nothing is read from disk until ``load()`` is called.
        """
        self._path = (store_path or _DEFAULT_STORE_PATH).expanduser()
        self._data: dict[str, dict[str, Any]] = {}

    def load(self) -> None:
        """Load the token store from disk.

        The in-memory store ends up empty when the file is missing,
        unreadable, not valid JSON, or does not contain a JSON object.
        """
        self._data = {}
        if not self._path.exists():
            return

        try:
            with open(self._path, encoding="utf-8") as f:
                loaded = json.load(f)
        except (ValueError, OSError) as exc:
            # ValueError covers both JSONDecodeError and UnicodeDecodeError.
            logger.warning("Failed to load token store: %s", exc)
            return

        if not isinstance(loaded, dict):
            # Every other method treats the store as a mapping, so a JSON
            # list or scalar is handled like a corrupt file.
            logger.warning(
                "Failed to load token store: %s",
                "top-level JSON value is not an object",
            )
            return

        self._data = loaded
        logger.debug("Token store loaded from %s", self._path)

    def save(self) -> None:
        """Persist the tokens to disk in a file restricted to its owner."""
        self._path.parent.mkdir(parents=True, exist_ok=True)
        # Create the file and restrict it to owner read/write BEFORE any
        # secret is written, so the tokens never sit in a file that is
        # readable under the default umask.
        self._path.touch(mode=0o600, exist_ok=True)
        self._path.chmod(0o600)
        # ensure_ascii=False can emit non-ASCII text, so pin the encoding
        # instead of relying on the locale default.
        with open(self._path, "w", encoding="utf-8") as f:
            json.dump(self._data, f, indent=2, ensure_ascii=False)
        logger.debug("Token store saved to %s", self._path)

    def get(self, provider: str) -> ProviderTokens:
        """Return the stored tokens for ``provider``, or empty tokens if none exist."""
        raw = self._data.get(provider, {})
        return ProviderTokens(**raw) if raw else ProviderTokens()

    def set(self, provider: str, tokens: ProviderTokens) -> None:
        """Store the tokens for ``provider`` and persist the store."""
        self._data[provider] = tokens.model_dump()
        self.save()
        logger.info("Token updated for provider: %s", provider)

    def remove(self, provider: str) -> None:
        """Delete the tokens for ``provider`` and persist; no-op when absent."""
        if provider in self._data:
            del self._data[provider]
            self.save()

    def list_providers(self) -> list[str]:
        """List every provider that currently has stored tokens."""
        return list(self._data.keys())