coding-proxy 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- coding/__init__.py +0 -0
- coding/proxy/__init__.py +3 -0
- coding/proxy/__main__.py +5 -0
- coding/proxy/auth/__init__.py +13 -0
- coding/proxy/auth/providers/__init__.py +6 -0
- coding/proxy/auth/providers/base.py +35 -0
- coding/proxy/auth/providers/github.py +133 -0
- coding/proxy/auth/providers/google.py +237 -0
- coding/proxy/auth/runtime.py +122 -0
- coding/proxy/auth/store.py +74 -0
- coding/proxy/cli/__init__.py +151 -0
- coding/proxy/cli/auth_commands.py +224 -0
- coding/proxy/compat/__init__.py +30 -0
- coding/proxy/compat/canonical.py +193 -0
- coding/proxy/compat/session_store.py +137 -0
- coding/proxy/config/__init__.py +6 -0
- coding/proxy/config/auth_schema.py +24 -0
- coding/proxy/config/loader.py +139 -0
- coding/proxy/config/resiliency.py +46 -0
- coding/proxy/config/routing.py +279 -0
- coding/proxy/config/schema.py +280 -0
- coding/proxy/config/server.py +23 -0
- coding/proxy/config/vendors.py +53 -0
- coding/proxy/convert/__init__.py +14 -0
- coding/proxy/convert/anthropic_to_gemini.py +352 -0
- coding/proxy/convert/anthropic_to_openai.py +352 -0
- coding/proxy/convert/gemini_sse_adapter.py +169 -0
- coding/proxy/convert/gemini_to_anthropic.py +98 -0
- coding/proxy/convert/openai_to_anthropic.py +88 -0
- coding/proxy/logging/__init__.py +49 -0
- coding/proxy/logging/db.py +308 -0
- coding/proxy/logging/stats.py +129 -0
- coding/proxy/model/__init__.py +93 -0
- coding/proxy/model/auth.py +32 -0
- coding/proxy/model/compat.py +153 -0
- coding/proxy/model/constants.py +21 -0
- coding/proxy/model/pricing.py +70 -0
- coding/proxy/model/token.py +64 -0
- coding/proxy/model/vendor.py +218 -0
- coding/proxy/pricing.py +100 -0
- coding/proxy/routing/__init__.py +47 -0
- coding/proxy/routing/circuit_breaker.py +152 -0
- coding/proxy/routing/error_classifier.py +67 -0
- coding/proxy/routing/executor.py +453 -0
- coding/proxy/routing/model_mapper.py +90 -0
- coding/proxy/routing/quota_guard.py +169 -0
- coding/proxy/routing/rate_limit.py +159 -0
- coding/proxy/routing/retry.py +82 -0
- coding/proxy/routing/router.py +84 -0
- coding/proxy/routing/session_manager.py +62 -0
- coding/proxy/routing/tier.py +171 -0
- coding/proxy/routing/usage_parser.py +193 -0
- coding/proxy/routing/usage_recorder.py +131 -0
- coding/proxy/server/__init__.py +1 -0
- coding/proxy/server/app.py +142 -0
- coding/proxy/server/factory.py +175 -0
- coding/proxy/server/request_normalizer.py +139 -0
- coding/proxy/server/responses.py +74 -0
- coding/proxy/server/routes.py +264 -0
- coding/proxy/streaming/__init__.py +1 -0
- coding/proxy/streaming/anthropic_compat.py +484 -0
- coding/proxy/vendors/__init__.py +29 -0
- coding/proxy/vendors/anthropic.py +44 -0
- coding/proxy/vendors/antigravity.py +328 -0
- coding/proxy/vendors/base.py +353 -0
- coding/proxy/vendors/copilot.py +702 -0
- coding/proxy/vendors/copilot_models.py +438 -0
- coding/proxy/vendors/copilot_token_manager.py +167 -0
- coding/proxy/vendors/copilot_urls.py +16 -0
- coding/proxy/vendors/mixins.py +71 -0
- coding/proxy/vendors/token_manager.py +128 -0
- coding/proxy/vendors/zhipu.py +243 -0
- coding_proxy-0.1.0.dist-info/METADATA +184 -0
- coding_proxy-0.1.0.dist-info/RECORD +77 -0
- coding_proxy-0.1.0.dist-info/WHEEL +4 -0
- coding_proxy-0.1.0.dist-info/entry_points.txt +2 -0
- coding_proxy-0.1.0.dist-info/licenses/LICENSE +201 -0
coding/__init__.py
ADDED
|
File without changes
|
coding/proxy/__init__.py
ADDED
coding/proxy/__main__.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
"""OAuth 认证管理模块."""
|
|
2
|
+
|
|
3
|
+
from .providers.base import OAuthProvider
|
|
4
|
+
from .providers.github import GitHubDeviceFlowProvider
|
|
5
|
+
from .providers.google import GoogleOAuthProvider
|
|
6
|
+
from .runtime import ReauthState, RuntimeReauthCoordinator
|
|
7
|
+
from .store import ProviderTokens, TokenStoreManager
|
|
8
|
+
|
|
9
|
+
__all__ = [
|
|
10
|
+
"OAuthProvider", "GitHubDeviceFlowProvider", "GoogleOAuthProvider",
|
|
11
|
+
"RuntimeReauthCoordinator", "ReauthState",
|
|
12
|
+
"ProviderTokens", "TokenStoreManager",
|
|
13
|
+
]
|
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
"""OAuth Provider 抽象基类."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from abc import ABC, abstractmethod
|
|
6
|
+
|
|
7
|
+
from ..store import ProviderTokens
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class OAuthProvider(ABC):
    """Abstract base class for OAuth login providers.

    Concrete providers implement the interactive login, the token refresh and
    the token validation. ``needs_login`` and ``close`` ship with default
    implementations that subclasses may override.
    """

    @abstractmethod
    def get_name(self) -> str:
        """Return the provider's unique identifier (used as the token-store key)."""

    @abstractmethod
    async def login(self) -> ProviderTokens:
        """Run the OAuth login flow and return the tokens that were obtained."""

    @abstractmethod
    async def refresh(self, tokens: ProviderTokens) -> ProviderTokens:
        """Obtain a fresh access token from ``tokens`` (via its refresh token)."""

    @abstractmethod
    async def validate(self, tokens: ProviderTokens) -> bool:
        """Return whether ``tokens`` are still valid."""

    def needs_login(self, tokens: ProviderTokens) -> bool:
        """Return whether a new interactive login is required.

        By default a login is needed when there are no credentials at all, or
        when the tokens are expired and no refresh token is available.
        """
        if not tokens.has_credentials:
            return True
        return bool(tokens.is_expired and not tokens.refresh_token)

    async def close(self) -> None:
        """Release resources held by the provider.

        The default implementation does nothing; providers that own network
        clients override it. Declaring it here lets callers that only know the
        ``OAuthProvider`` interface close any provider uniformly.
        """
|
|
@@ -0,0 +1,133 @@
|
|
|
1
|
+
"""GitHub Device Authorization Flow — 浏览器免回调的 OAuth 登录."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import asyncio
|
|
6
|
+
import logging
|
|
7
|
+
import time
|
|
8
|
+
from typing import Any
|
|
9
|
+
|
|
10
|
+
import webbrowser
|
|
11
|
+
|
|
12
|
+
import httpx
|
|
13
|
+
|
|
14
|
+
from ..store import ProviderTokens
|
|
15
|
+
from .base import OAuthProvider
|
|
16
|
+
|
|
17
|
+
logger = logging.getLogger(__name__)
|
|
18
|
+
|
|
19
|
+
# Public client_id of the GitHub Copilot VS Code extension.
# Source of truth: coding.proxy.config.schema.AuthConfig.github_client_id
# The value below is only a fallback; production deployments should override
# it through the ``auth`` section of config.yaml.
_COPILOT_CLIENT_ID = "Iv1.b507a08c87ecfe98"

# GitHub OAuth endpoints used by the device flow.
_DEVICE_CODE_URL = "https://github.com/login/device/code"
_ACCESS_TOKEN_URL = "https://github.com/login/oauth/access_token"

# Polling budget: 60 attempts at the default 5 second interval is 5 minutes.
_POLL_INTERVAL = 5  # seconds
_MAX_POLL_ATTEMPTS = 60

# Space-separated OAuth scopes requested for the device code.
_COPILOT_PERMISSIVE_SCOPES = "read:user user:email repo workflow"
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class GitHubDeviceFlowProvider(OAuthProvider):
    """GitHub Device Authorization Flow implementation.

    No local HTTP server is required: the user enters ``user_code`` in the
    browser to authorize. The resulting GitHub OAuth token can then be used
    for the Copilot token exchange.
    """

    def __init__(self, client_id: str = _COPILOT_CLIENT_ID) -> None:
        """Create the provider.

        Args:
            client_id: OAuth client id sent to GitHub.
        """
        self._client_id = client_id
        self._http = httpx.AsyncClient(timeout=httpx.Timeout(30.0))

    def get_name(self) -> str:
        """Return the token-store key of this provider (``"github"``)."""
        return "github"

    async def login(self) -> ProviderTokens:
        """Run the GitHub device flow and return the OAuth token.

        Raises:
            httpx.HTTPStatusError: The device-code request returned an error
                status.
            RuntimeError: The device-code response is unusable, the user
                denied access, the device code expired, GitHub reported
                another error, or polling ran out of attempts.
        """
        # Step 1: request a device code.
        resp = await self._http.post(
            _DEVICE_CODE_URL,
            data={"client_id": self._client_id, "scope": _COPILOT_PERMISSIVE_SCOPES},
            headers={"accept": "application/json"},
        )
        resp.raise_for_status()
        device_data: dict[str, Any] = resp.json()

        # Fail with a readable message instead of a bare KeyError when the
        # response lacks the fields the flow depends on.
        missing = [
            key
            for key in ("user_code", "verification_uri", "device_code")
            if key not in device_data
        ]
        if missing:
            error = device_data.get("error") or (
                f"device code 响应缺少字段: {', '.join(missing)}"
            )
            raise RuntimeError(f"GitHub OAuth 错误: {error}")

        user_code = device_data["user_code"]
        verification_uri = device_data["verification_uri"]
        device_code = device_data["device_code"]
        interval = device_data.get("interval", _POLL_INTERVAL)

        # Prefer the complete link that already carries the user_code.
        verification_url = device_data.get(
            "verification_uri_complete", verification_uri
        )

        # Step 2: guide the user to authorize in the browser.
        logger.info("请在浏览器中访问 %s 并输入代码: %s", verification_uri, user_code)
        print(f"\n 🔗 请在浏览器中访问: {verification_uri}")
        print(f" 📋 并输入代码: {user_code}\n")

        webbrowser.open(verification_url)

        # Step 3: poll until the user has finished authorizing.
        for _ in range(_MAX_POLL_ATTEMPTS):
            await asyncio.sleep(interval)

            token_resp = await self._http.post(
                _ACCESS_TOKEN_URL,
                data={
                    "client_id": self._client_id,
                    "device_code": device_code,
                    "grant_type": "urn:ietf:params:oauth:grant-type:device_code",
                },
                headers={"accept": "application/json"},
            )
            try:
                token_data = token_resp.json()
            except ValueError:
                # A non-JSON body (e.g. a gateway error page) is treated as a
                # transient failure; keep polling instead of aborting.
                logger.warning(
                    "GitHub token endpoint returned a non-JSON response (HTTP %s)",
                    token_resp.status_code,
                )
                continue

            if "access_token" in token_data:
                logger.info("GitHub OAuth 授权成功")
                return ProviderTokens(
                    access_token=token_data["access_token"],
                    token_type=token_data.get("token_type", "bearer"),
                    scope=token_data.get("scope", ""),
                )

            error = token_data.get("error", "")
            if error == "authorization_pending":
                continue
            if error == "slow_down":
                # Back off by 5 seconds, or further if the server asks for it.
                interval = max(interval + 5, token_data.get("interval", 0))
                continue
            if error == "expired_token":
                raise RuntimeError("Device code 已过期,请重新登录")
            if error == "access_denied":
                raise RuntimeError("用户拒绝了授权")
            raise RuntimeError(f"GitHub OAuth 错误: {error}")

        raise RuntimeError("GitHub Device Flow 超时,请重试")

    async def refresh(self, tokens: ProviderTokens) -> ProviderTokens:
        """Run a new login; this provider never obtains a refresh token.

        ``tokens`` is accepted for interface compatibility and is not used.
        """
        return await self.login()

    async def validate(self, tokens: ProviderTokens) -> bool:
        """Return whether the GitHub token is accepted by ``GET /user``.

        Network errors are reported as an invalid token.
        """
        if not tokens.access_token:
            return False
        try:
            resp = await self._http.get(
                "https://api.github.com/user",
                headers={
                    "authorization": f"token {tokens.access_token}",
                    "accept": "application/json",
                },
            )
            return resp.status_code == 200
        except httpx.HTTPError:
            return False

    async def close(self) -> None:
        """Close the underlying HTTP client if it is still open."""
        if not self._http.is_closed:
            await self._http.aclose()
|
|
@@ -0,0 +1,237 @@
|
|
|
1
|
+
"""Google OAuth2 Authorization Code Flow — 本地回调服务器."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import asyncio
|
|
6
|
+
import logging
|
|
7
|
+
import secrets
|
|
8
|
+
import time
|
|
9
|
+
from http.server import HTTPServer, BaseHTTPRequestHandler
|
|
10
|
+
from typing import Any
|
|
11
|
+
from urllib.parse import parse_qs, urlencode, urlparse
|
|
12
|
+
|
|
13
|
+
import httpx
|
|
14
|
+
|
|
15
|
+
from ..store import ProviderTokens
|
|
16
|
+
from .base import OAuthProvider
|
|
17
|
+
|
|
18
|
+
logger = logging.getLogger(__name__)
|
|
19
|
+
|
|
20
|
+
# Public OAuth credentials of Antigravity Enterprise.
# Source of truth: coding.proxy.config.schema.AuthConfig
# The values below are only fallbacks; production deployments should override
# them through the ``auth`` section of config.yaml.
_DEFAULT_CLIENT_ID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
_DEFAULT_CLIENT_SECRET = "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"

# Google OAuth2 endpoints.
_AUTH_URL = "https://accounts.google.com/o/oauth2/v2/auth"
_TOKEN_URL = "https://oauth2.googleapis.com/token"

# Scopes requested at login; all of them share the googleapis auth prefix.
_SCOPES = [
    f"https://www.googleapis.com/auth/{scope_name}"
    for scope_name in (
        "cloud-platform",
        "userinfo.email",
        "userinfo.profile",
        "cclog",
        "experimentsandconfigs",
    )
]
# Every requested scope must be granted for a token to count as valid.
_REQUIRED_SCOPE_SET = frozenset(_SCOPES)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class _CallbackHandler(BaseHTTPRequestHandler):
    """HTTP handler for the OAuth redirect callback.

    The outcome is written into a ``result`` dict supplied per instance, so
    concurrent logins do not share state through class attributes.
    """

    def __init__(self, *args: Any, result: dict[str, str | None], **kwargs: Any) -> None:
        # Store the dict first: the base-class constructor already processes
        # the request, so do_GET may run before super().__init__ returns.
        self._result = result
        super().__init__(*args, **kwargs)

    def do_GET(self) -> None:
        """Record the callback parameters and answer the browser.

        Only ``/callback`` is served; any other path gets a 404.
        """
        url = urlparse(self.path)
        query = parse_qs(url.query)

        if url.path != "/callback":
            self.send_response(404)
            self.end_headers()
            return

        if "error" in query:
            self._result["error"] = query["error"][0]
            self._respond("授权失败,请关闭此页面返回终端。")
            return

        if "code" not in query or "state" not in query:
            self._respond("无效的回调参数。")
            return

        self._result["auth_code"] = query["code"][0]
        self._result["state"] = query["state"][0]
        self._respond("授权成功!请关闭此页面返回终端。")

    def _respond(self, message: str) -> None:
        """Send a 200 response with ``message`` wrapped in a minimal HTML page."""
        page = f"<html><body><h2>{message}</h2></body></html>"
        self.send_response(200)
        self.send_header("content-type", "text/html; charset=utf-8")
        self.end_headers()
        self.wfile.write(page.encode())

    def log_message(self, format: str, *args: Any) -> None:
        """Suppress the default per-request HTTP log output."""
        return None
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
class GoogleOAuthProvider(OAuthProvider):
    """Google OAuth2 Authorization Code Flow implementation.

    Starts a local HTTP callback server to capture the authorization code and
    exchanges it for an access token plus refresh token.
    """

    def __init__(
        self,
        client_id: str = _DEFAULT_CLIENT_ID,
        client_secret: str = _DEFAULT_CLIENT_SECRET,
    ) -> None:
        """Create the provider.

        Args:
            client_id: OAuth client id sent to Google.
            client_secret: OAuth client secret sent to the token endpoint.
        """
        self._client_id = client_id
        self._client_secret = client_secret
        self._http = httpx.AsyncClient(timeout=httpx.Timeout(30.0))

    def get_name(self) -> str:
        """Return the token-store key of this provider (``"google"``)."""
        return "google"

    @staticmethod
    def has_required_scopes(scope: str) -> bool:
        """Return whether the space-separated ``scope`` grants every required scope."""
        granted = {item for item in scope.split() if item}
        return _REQUIRED_SCOPE_SET.issubset(granted)

    @staticmethod
    def _tokens_from_response(
        data: dict[str, Any],
        *,
        refresh_token: str = "",
        scope: str = "",
    ) -> ProviderTokens:
        """Build ``ProviderTokens`` from a token-endpoint JSON response.

        ``refresh_token`` and ``scope`` are the fallbacks used when the
        response does not carry its own values.
        """
        expires_at = 0.0
        if "expires_in" in data:
            expires_at = time.time() + data["expires_in"]

        return ProviderTokens(
            access_token=data.get("access_token", ""),
            # Keep a newly issued refresh token if the server sent one.
            refresh_token=data.get("refresh_token") or refresh_token,
            expires_at=expires_at,
            scope=data.get("scope", scope),
            token_type=data.get("token_type", "bearer"),
        )

    async def login(self) -> ProviderTokens:
        """Run the Google OAuth2 code flow and return the tokens.

        Raises:
            RuntimeError: Google reported an error, no callback arrived within
                two minutes, or the returned ``state`` does not match.
            httpx.HTTPStatusError: The code exchange returned an error status.
        """
        state = secrets.token_urlsafe(32)
        result: dict[str, str | None] = {"auth_code": None, "state": None, "error": None}

        def _make_handler(*args: Any, **kwargs: Any) -> _CallbackHandler:
            return _CallbackHandler(*args, result=result, **kwargs)

        # Bind to port 0 so the OS picks a free port, avoiding a TOCTOU race.
        server = HTTPServer(("127.0.0.1", 0), _make_handler)
        # Without a timeout handle_request() would block until a request
        # arrives, so the two-minute limit below could never trigger.
        server.timeout = 1.0
        try:
            redirect_port = server.server_address[1]
            redirect_uri = f"http://127.0.0.1:{redirect_port}/callback"

            params = urlencode({
                "client_id": self._client_id,
                "redirect_uri": redirect_uri,
                "response_type": "code",
                "scope": " ".join(_SCOPES),
                "state": state,
                "access_type": "offline",
                "prompt": "consent",
            })
            auth_url = f"{_AUTH_URL}?{params}"

            logger.info("请在浏览器中完成 Google 授权")
            print("\n 🔗 请在浏览器中访问以下链接完成授权:\n")
            print(f" {auth_url}\n")

            # Open the browser.
            import webbrowser
            webbrowser.open(auth_url)

            # Wait for the callback for at most 2 minutes. handle_request()
            # is a blocking call, so it runs in a worker thread to keep the
            # event loop responsive while waiting.
            loop = asyncio.get_running_loop()
            deadline = time.monotonic() + 120
            while time.monotonic() < deadline:
                await loop.run_in_executor(None, server.handle_request)
                if result["auth_code"] or result["error"]:
                    break
        finally:
            # Always release the listening socket, also on errors.
            server.server_close()

        if result["error"]:
            raise RuntimeError(f"Google OAuth 错误: {result['error']}")

        if not result["auth_code"]:
            raise RuntimeError("Google OAuth 超时,请重试")

        if result["state"] != state:
            raise RuntimeError("OAuth state 不匹配,可能遭受 CSRF 攻击")

        # Exchange the code for tokens.
        return await self._exchange_code(result["auth_code"], redirect_uri)

    async def _exchange_code(
        self, code: str, redirect_uri: str
    ) -> ProviderTokens:
        """Exchange an authorization code for an access token and refresh token.

        Raises:
            httpx.HTTPStatusError: The token endpoint returned an error status.
        """
        resp = await self._http.post(
            _TOKEN_URL,
            data={
                "client_id": self._client_id,
                "client_secret": self._client_secret,
                "code": code,
                "redirect_uri": redirect_uri,
                "grant_type": "authorization_code",
            },
            headers={"content-type": "application/x-www-form-urlencoded"},
        )
        resp.raise_for_status()
        return self._tokens_from_response(resp.json())

    async def refresh(self, tokens: ProviderTokens) -> ProviderTokens:
        """Refresh the access token using ``tokens.refresh_token``.

        Falls back to a full interactive login when there is no refresh token
        or the token endpoint rejects the refresh request.
        """
        if not tokens.refresh_token:
            return await self.login()

        resp = await self._http.post(
            _TOKEN_URL,
            data={
                "client_id": self._client_id,
                "client_secret": self._client_secret,
                "refresh_token": tokens.refresh_token,
                "grant_type": "refresh_token",
            },
            headers={"content-type": "application/x-www-form-urlencoded"},
        )

        if resp.status_code >= 400:
            logger.warning("Google token refresh 失败,需要重新登录")
            return await self.login()

        # The refresh token usually stays the same; the helper keeps the old
        # one unless the response carries a replacement.
        return self._tokens_from_response(
            resp.json(),
            refresh_token=tokens.refresh_token,
            scope=tokens.scope,
        )

    async def validate(self, tokens: ProviderTokens) -> bool:
        """Return whether the Google token is valid and has all required scopes.

        Network errors are reported as an invalid token.
        """
        if not tokens.access_token:
            return False
        try:
            resp = await self._http.get(
                "https://www.googleapis.com/oauth2/v1/tokeninfo",
                params={"access_token": tokens.access_token},
            )
            if resp.status_code != 200:
                return False
            data = resp.json()
            return self.has_required_scopes(data.get("scope", tokens.scope))
        except httpx.HTTPError:
            return False

    async def close(self) -> None:
        """Close the underlying HTTP client if it is still open."""
        if not self._http.is_closed:
            await self._http.aclose()
|
|
@@ -0,0 +1,122 @@
|
|
|
1
|
+
"""运行时 OAuth 重认证协调器 — 后台触发浏览器登录并热更新凭证."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import asyncio
|
|
6
|
+
import enum
|
|
7
|
+
import logging
|
|
8
|
+
import time
|
|
9
|
+
from typing import Callable
|
|
10
|
+
|
|
11
|
+
from .providers.base import OAuthProvider
|
|
12
|
+
from .store import TokenStoreManager
|
|
13
|
+
|
|
14
|
+
logger = logging.getLogger(__name__)


class ReauthState(enum.Enum):
    """Re-authentication state of a provider."""

    IDLE = "idle"
    PENDING = "pending"
    COMPLETED = "completed"
    FAILED = "failed"


class RuntimeReauthCoordinator:
    """Coordinates OAuth re-authentication at runtime.

    ``request_reauth()`` starts the provider's login flow as a background
    task, persists the new tokens and pushes them to the registered updater
    callback. At most one login per provider runs at a time.
    """

    def __init__(
        self,
        token_store: TokenStoreManager,
        providers: dict[str, OAuthProvider],
        token_updaters: dict[str, Callable[[str], None]],
    ) -> None:
        """
        Args:
            token_store: Token persistence manager.
            providers: Mapping of provider name to ``OAuthProvider`` instance.
            token_updaters: Mapping of provider name to the callback that
                receives the new credential after a successful login.
        """
        self._token_store = token_store
        self._providers = providers
        self._token_updaters = token_updaters
        self._states: dict[str, ReauthState] = {
            name: ReauthState.IDLE for name in providers
        }
        self._locks: dict[str, asyncio.Lock] = {
            name: asyncio.Lock() for name in providers
        }
        self._last_error: dict[str, str] = {}
        self._last_completed: dict[str, float] = {}
        # Strong references to running tasks: the event loop itself only keeps
        # weak references, so an unreferenced task may be garbage-collected
        # before it finishes.
        self._tasks: set[asyncio.Task[None]] = set()

    async def request_reauth(self, provider_name: str) -> None:
        """Request re-authentication for a provider (idempotent, runs in background).

        Unknown providers are logged and ignored. If a re-authentication is
        already pending, the call returns without starting another one.
        """
        if provider_name not in self._providers:
            logger.warning("未知 provider: %s", provider_name)
            return

        if self._states.get(provider_name) == ReauthState.PENDING:
            return  # already in progress

        task = asyncio.create_task(self._do_reauth(provider_name))
        self._tasks.add(task)
        task.add_done_callback(self._tasks.discard)

    async def _do_reauth(self, provider_name: str) -> None:
        """Run the login flow for one provider, guarded by its lock."""
        lock = self._locks[provider_name]
        if lock.locked():
            return  # another task is already running the login

        async with lock:
            self._states[provider_name] = ReauthState.PENDING
            logger.info("开始 %s 重认证...", provider_name)

            try:
                provider = self._providers[provider_name]
                tokens = await provider.login()
                self._token_store.set(provider_name, tokens)

                # Push the new credential to the registered updater.
                updater = self._token_updaters.get(provider_name)
                if updater:
                    # GitHub hands over the access_token, Google the
                    # refresh_token; other provider names are not pushed.
                    if provider_name == "github":
                        updater(tokens.access_token)
                    elif provider_name == "google":
                        updater(tokens.refresh_token)

                self._states[provider_name] = ReauthState.COMPLETED
                self._last_completed[provider_name] = time.monotonic()
                self._last_error.pop(provider_name, None)
                logger.info("%s 重认证成功", provider_name)

            except asyncio.CancelledError:
                # Leave PENDING on cancellation; otherwise request_reauth()
                # would refuse every later request for this provider.
                self._states[provider_name] = ReauthState.FAILED
                self._last_error[provider_name] = "cancelled"
                raise
            except Exception as exc:
                self._states[provider_name] = ReauthState.FAILED
                self._last_error[provider_name] = str(exc)
                logger.error("%s 重认证失败: %s", provider_name, exc)

    def get_status(self) -> dict[str, dict[str, str]]:
        """Return the re-authentication status of every provider.

        Each entry has ``state`` and, when available, ``error`` (last failure
        message) and ``completed_ago_seconds`` (whole seconds since the last
        successful login, as a string).
        """
        result = {}
        for name in self._providers:
            info: dict[str, str] = {"state": self._states[name].value}
            if name in self._last_error:
                info["error"] = self._last_error[name]
            if name in self._last_completed:
                info["completed_ago_seconds"] = str(
                    int(time.monotonic() - self._last_completed[name])
                )
            result[name] = info
        return result
|
|
@@ -0,0 +1,74 @@
|
|
|
1
|
+
"""Token 持久化存储 — ~/.coding-proxy/tokens.json.
|
|
2
|
+
|
|
3
|
+
``ProviderTokens`` 数据模型已迁移至 :mod:`coding.proxy.model.auth`。
|
|
4
|
+
本文件保留 ``TokenStoreManager`` 持久化管理器,类型通过 re-export 提供。
|
|
5
|
+
|
|
6
|
+
.. deprecated::
|
|
7
|
+
未来版本将移除类型 re-export,请直接从 :mod:`coding.proxy.model.auth` 导入。
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
from __future__ import annotations
|
|
11
|
+
|
|
12
|
+
import json
|
|
13
|
+
import logging
|
|
14
|
+
import time
|
|
15
|
+
from pathlib import Path
|
|
16
|
+
from typing import Any
|
|
17
|
+
|
|
18
|
+
# noqa: F401
|
|
19
|
+
from ..model.auth import ProviderTokens
|
|
20
|
+
|
|
21
|
+
logger = logging.getLogger(__name__)

_DEFAULT_STORE_PATH = Path("~/.coding-proxy/tokens.json")


class TokenStoreManager:
    """Manages on-disk persistence of the tokens of all providers.

    Tokens are kept in one JSON object keyed by provider name.
    """

    def __init__(self, store_path: Path | None = None) -> None:
        """Create the manager.

        Args:
            store_path: Location of the JSON file; defaults to
                ``~/.coding-proxy/tokens.json``. ``~`` is expanded.
        """
        self._path = (store_path or _DEFAULT_STORE_PATH).expanduser()
        self._data: dict[str, dict[str, Any]] = {}

    def load(self) -> None:
        """Load the token store from disk.

        A missing, unreadable or malformed file results in an empty store; the
        problem is logged instead of raised.
        """
        self._data = {}
        if not self._path.exists():
            return

        try:
            with open(self._path, encoding="utf-8") as f:
                loaded = json.load(f)
        except (ValueError, OSError) as exc:
            # ValueError covers both invalid JSON and undecodable bytes.
            logger.warning("Failed to load token store: %s", exc)
            return

        if not isinstance(loaded, dict):
            # Valid JSON of the wrong shape would break get()/remove() later.
            logger.warning(
                "Failed to load token store: expected a JSON object, got %s",
                type(loaded).__name__,
            )
            return

        self._data = loaded
        logger.debug("Token store loaded from %s", self._path)

    def save(self) -> None:
        """Persist the tokens to disk.

        The data is written to a sibling temporary file that is restricted to
        the owner before any content is written, and then moved over the
        target. The token file is therefore never readable by other users and
        never left half-written.
        """
        self._path.parent.mkdir(parents=True, exist_ok=True)
        tmp_path = self._path.with_name(self._path.name + ".tmp")
        # Create the file owner-only, and tighten a leftover one as well.
        tmp_path.touch(mode=0o600, exist_ok=True)
        tmp_path.chmod(0o600)
        try:
            # UTF-8 is required because ensure_ascii=False emits raw
            # non-ASCII characters.
            with open(tmp_path, "w", encoding="utf-8") as f:
                json.dump(self._data, f, indent=2, ensure_ascii=False)
            tmp_path.replace(self._path)
        except Exception:
            if tmp_path.exists():
                tmp_path.unlink()
            raise
        logger.debug("Token store saved to %s", self._path)

    def get(self, provider: str) -> ProviderTokens:
        """Return the tokens of ``provider``, or empty tokens if none are stored."""
        raw = self._data.get(provider, {})
        return ProviderTokens(**raw) if raw else ProviderTokens()

    def set(self, provider: str, tokens: ProviderTokens) -> None:
        """Store the tokens of ``provider`` and persist them immediately."""
        self._data[provider] = tokens.model_dump()
        self.save()
        logger.info("Token updated for provider: %s", provider)

    def remove(self, provider: str) -> None:
        """Remove the tokens of ``provider`` and persist; no-op if absent."""
        if provider in self._data:
            del self._data[provider]
            self.save()

    def list_providers(self) -> list[str]:
        """Return the names of all providers that have stored tokens."""
        return list(self._data.keys())
|