gac 3.6.0__py3-none-any.whl → 3.10.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gac/__init__.py +4 -6
- gac/__version__.py +1 -1
- gac/ai_utils.py +59 -43
- gac/auth_cli.py +181 -36
- gac/cli.py +26 -9
- gac/commit_executor.py +59 -0
- gac/config.py +81 -2
- gac/config_cli.py +19 -7
- gac/constants/__init__.py +34 -0
- gac/constants/commit.py +63 -0
- gac/constants/defaults.py +40 -0
- gac/constants/file_patterns.py +110 -0
- gac/constants/languages.py +119 -0
- gac/diff_cli.py +0 -22
- gac/errors.py +8 -2
- gac/git.py +6 -6
- gac/git_state_validator.py +193 -0
- gac/grouped_commit_workflow.py +458 -0
- gac/init_cli.py +2 -1
- gac/interactive_mode.py +179 -0
- gac/language_cli.py +0 -1
- gac/main.py +231 -926
- gac/model_cli.py +67 -11
- gac/model_identifier.py +70 -0
- gac/oauth/__init__.py +26 -0
- gac/oauth/claude_code.py +89 -22
- gac/oauth/qwen_oauth.py +327 -0
- gac/oauth/token_store.py +81 -0
- gac/oauth_retry.py +161 -0
- gac/postprocess.py +155 -0
- gac/prompt.py +21 -479
- gac/prompt_builder.py +88 -0
- gac/providers/README.md +437 -0
- gac/providers/__init__.py +70 -78
- gac/providers/anthropic.py +12 -46
- gac/providers/azure_openai.py +48 -88
- gac/providers/base.py +329 -0
- gac/providers/cerebras.py +10 -33
- gac/providers/chutes.py +16 -62
- gac/providers/claude_code.py +64 -87
- gac/providers/custom_anthropic.py +51 -81
- gac/providers/custom_openai.py +29 -83
- gac/providers/deepseek.py +10 -33
- gac/providers/error_handler.py +139 -0
- gac/providers/fireworks.py +10 -33
- gac/providers/gemini.py +66 -63
- gac/providers/groq.py +10 -58
- gac/providers/kimi_coding.py +19 -55
- gac/providers/lmstudio.py +64 -43
- gac/providers/minimax.py +10 -33
- gac/providers/mistral.py +10 -33
- gac/providers/moonshot.py +10 -33
- gac/providers/ollama.py +56 -33
- gac/providers/openai.py +30 -36
- gac/providers/openrouter.py +15 -52
- gac/providers/protocol.py +71 -0
- gac/providers/qwen.py +64 -0
- gac/providers/registry.py +58 -0
- gac/providers/replicate.py +140 -82
- gac/providers/streamlake.py +26 -46
- gac/providers/synthetic.py +35 -37
- gac/providers/together.py +10 -33
- gac/providers/zai.py +29 -57
- gac/py.typed +0 -0
- gac/security.py +1 -1
- gac/templates/__init__.py +1 -0
- gac/templates/question_generation.txt +60 -0
- gac/templates/system_prompt.txt +224 -0
- gac/templates/user_prompt.txt +28 -0
- gac/utils.py +36 -6
- gac/workflow_context.py +162 -0
- gac/workflow_utils.py +3 -8
- {gac-3.6.0.dist-info → gac-3.10.10.dist-info}/METADATA +6 -4
- gac-3.10.10.dist-info/RECORD +79 -0
- gac/constants.py +0 -321
- gac-3.6.0.dist-info/RECORD +0 -53
- {gac-3.6.0.dist-info → gac-3.10.10.dist-info}/WHEEL +0 -0
- {gac-3.6.0.dist-info → gac-3.10.10.dist-info}/entry_points.txt +0 -0
- {gac-3.6.0.dist-info → gac-3.10.10.dist-info}/licenses/LICENSE +0 -0
gac/oauth/qwen_oauth.py
ADDED
|
@@ -0,0 +1,327 @@
|
|
|
1
|
+
"""Qwen OAuth device flow implementation.
|
|
2
|
+
|
|
3
|
+
Implements OAuth 2.0 Device Authorization Grant (RFC 8628) with PKCE.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import base64
|
|
7
|
+
import hashlib
|
|
8
|
+
import logging
|
|
9
|
+
import os
|
|
10
|
+
import secrets
|
|
11
|
+
import time
|
|
12
|
+
import webbrowser
|
|
13
|
+
from dataclasses import dataclass, field
|
|
14
|
+
|
|
15
|
+
import httpx
|
|
16
|
+
|
|
17
|
+
from gac import __version__
|
|
18
|
+
from gac.errors import AIError
|
|
19
|
+
from gac.oauth.token_store import OAuthToken, TokenStore
|
|
20
|
+
from gac.utils import get_ssl_verify
|
|
21
|
+
|
|
22
|
+
logger = logging.getLogger(__name__)

# OAuth client identifier sent as ``client_id`` on every Qwen OAuth request.
QWEN_CLIENT_ID = "f0304373b74a44d2b584a3fb70ca9e56"
# User-Agent header identifying this gac release to the Qwen endpoints.
USER_AGENT = f"gac/{__version__}"

# All Qwen OAuth endpoints share this prefix.
_QWEN_OAUTH_BASE_URL = "https://chat.qwen.ai/api/v1/oauth2"
QWEN_DEVICE_CODE_ENDPOINT = f"{_QWEN_OAUTH_BASE_URL}/device/code"
QWEN_TOKEN_ENDPOINT = f"{_QWEN_OAUTH_BASE_URL}/token"

# Scopes requested during device authorization (joined with spaces when sent).
QWEN_SCOPES = "openid profile email model.completion".split()
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
@dataclass
|
|
32
|
+
class DeviceCodeResponse:
|
|
33
|
+
"""Response from the device authorization endpoint."""
|
|
34
|
+
|
|
35
|
+
device_code: str
|
|
36
|
+
user_code: str
|
|
37
|
+
verification_uri: str
|
|
38
|
+
verification_uri_complete: str | None
|
|
39
|
+
expires_in: int
|
|
40
|
+
interval: int = 5
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
@dataclass
class QwenDeviceFlow:
    """Qwen OAuth device authorization flow (RFC 8628) with PKCE (RFC 7636).

    ``initiate_device_flow()`` returns the user code and verification URL to show
    the user. ``poll_for_token()`` then waits for the user's approval and returns
    the issued token. ``refresh_token()`` exchanges a refresh token for a new
    access token.
    """

    client_id: str = QWEN_CLIENT_ID
    authorization_endpoint: str = QWEN_DEVICE_CODE_ENDPOINT
    token_endpoint: str = QWEN_TOKEN_ENDPOINT
    scopes: list[str] = field(default_factory=lambda: QWEN_SCOPES.copy())
    # PKCE verifier created by initiate_device_flow() and sent back while polling.
    _pkce_verifier: str = field(default="", init=False)
    # Polling interval (seconds) suggested by the last initiate_device_flow()
    # response; poll_for_token() starts from it unless given an explicit value.
    _poll_interval: int = field(default=5, init=False)

    # Seconds subtracted from the server-reported lifetime when recording a
    # token's expiry. Applied to both freshly issued and refreshed tokens so the
    # stored expiry is computed the same way on both paths.
    _EXPIRY_MARGIN_SECONDS = 30

    @staticmethod
    def _headers() -> dict[str, str]:
        """Return the headers sent with every request to the Qwen OAuth endpoints."""
        return {
            "Content-Type": "application/x-www-form-urlencoded",
            "Accept": "application/json",
            "User-Agent": USER_AGENT,
        }

    def _post(self, url: str, params: dict[str, str]) -> httpx.Response:
        """POST ``params`` form-encoded to ``url`` with a 30 s timeout and gac's SSL setting.

        Raises:
            httpx.RequestError: On network-level failures.
        """
        return httpx.post(
            url,
            data=params,
            headers=self._headers(),
            timeout=30,
            verify=get_ssl_verify(),
        )

    def _generate_pkce(self) -> tuple[str, str]:
        """Generate a PKCE code verifier and its S256 challenge.

        Returns:
            Tuple of (verifier, challenge). The challenge is the unpadded
            URL-safe base64 encoding of SHA-256(verifier).
        """
        verifier = secrets.token_urlsafe(32)
        challenge = base64.urlsafe_b64encode(hashlib.sha256(verifier.encode()).digest()).rstrip(b"=").decode()
        return verifier, challenge

    @classmethod
    def _token_from_payload(cls, data: object, fallback_refresh_token: str | None = None) -> OAuthToken:
        """Build an OAuthToken from the decoded JSON body of a successful token response.

        Args:
            data: Decoded JSON body.
            fallback_refresh_token: Refresh token to keep when the response does
                not include a new one.

        Returns:
            Token whose ``expiry`` is now + ``expires_in`` (3600 if absent) minus
            ``_EXPIRY_MARGIN_SECONDS``.

        Raises:
            ValueError: If ``data`` is not a JSON object or lacks ``access_token``.
        """
        if not isinstance(data, dict) or "access_token" not in data:
            raise ValueError("token response is missing 'access_token'")
        expires_in = data.get("expires_in", 3600)
        return OAuthToken(
            access_token=data["access_token"],
            token_type="Bearer",
            expiry=int(time.time()) + expires_in - cls._EXPIRY_MARGIN_SECONDS,
            refresh_token=data.get("refresh_token") or fallback_refresh_token,
            scope=data.get("scope"),
            resource_url=data.get("resource_url"),
        )

    @staticmethod
    def _error_code(response: httpx.Response) -> str:
        """Return the OAuth ``error`` code of an error response, or "" if none can be read."""
        try:
            payload = response.json()
        except ValueError:
            # Body is not JSON; there is no error code to read.
            return ""
        if not isinstance(payload, dict):
            return ""
        error = payload.get("error", "")
        return error if isinstance(error, str) else ""

    def initiate_device_flow(self) -> DeviceCodeResponse:
        """Initiate the device authorization flow.

        Generates a fresh PKCE pair (the verifier is kept for polling) and
        remembers the server's suggested polling interval.

        Returns:
            DeviceCodeResponse with the device code and verification URIs.

        Raises:
            AIError: If the endpoint returns a non-2xx status or a malformed body.
            httpx.RequestError: On network failures.
        """
        verifier, challenge = self._generate_pkce()
        self._pkce_verifier = verifier

        params = {
            "client_id": self.client_id,
            "code_challenge": challenge,
            "code_challenge_method": "S256",
        }
        if self.scopes:
            params["scope"] = " ".join(self.scopes)

        response = self._post(self.authorization_endpoint, params)
        if not response.is_success:
            raise AIError.connection_error(f"Failed to initiate device flow: HTTP {response.status_code}")

        try:
            data = response.json()
            device_response = DeviceCodeResponse(
                device_code=data["device_code"],
                user_code=data["user_code"],
                verification_uri=data["verification_uri"],
                verification_uri_complete=data.get("verification_uri_complete"),
                expires_in=data["expires_in"],
                interval=data.get("interval", 5),
            )
        except (ValueError, KeyError, TypeError) as e:
            raise AIError.connection_error(f"Invalid device flow response: {e}") from e

        # Start polling at the server's requested pace (RFC 8628 §3.2 "interval").
        if isinstance(device_response.interval, int) and device_response.interval > 0:
            self._poll_interval = device_response.interval
        return device_response

    def poll_for_token(self, device_code: str, max_duration: int = 900, interval: int | None = None) -> OAuthToken:
        """Poll the token endpoint until the user approves, denies, or time runs out.

        Args:
            device_code: Device code from the initiation response.
            max_duration: Maximum polling duration in seconds (default 15 minutes).
            interval: Initial seconds between polls. Defaults to the interval
                suggested by the last ``initiate_device_flow()`` call (5 if none).

        Returns:
            OAuthToken with the access token and metadata.

        Raises:
            AIError: On denial, device-code expiry, an unexpected error response,
                a malformed success body, or when ``max_duration`` elapses.
        """
        start_time = time.time()
        wait = self._poll_interval if interval is None else interval
        params = {
            "grant_type": "urn:ietf:params:oauth:grant-type:device_code",
            "device_code": device_code,
            "client_id": self.client_id,
            "code_verifier": self._pkce_verifier,
        }

        while time.time() - start_time < max_duration:
            try:
                response = self._post(self.token_endpoint, params)
            except httpx.RequestError as e:
                # Back off on network errors, capped at one minute.
                wait = int(min(wait * 1.5, 60))
                logger.debug(f"Network error during polling, retrying in {wait}s: {e}")
                time.sleep(wait)
                continue

            if response.is_success:
                try:
                    return self._token_from_payload(response.json())
                except (ValueError, TypeError) as e:
                    raise AIError.connection_error(f"Invalid token response: {e}") from e

            error = self._error_code(response)
            if error == "authorization_pending":
                time.sleep(wait)
                continue
            if error == "slow_down":
                # RFC 8628 §3.5: increase the interval by 5 s for all further polls.
                wait += 5
                time.sleep(wait)
                continue
            if error == "access_denied":
                raise AIError.authentication_error("Authorization was denied by user")
            if error == "expired_token":
                raise AIError.authentication_error("Device code expired. Please try again.")

            raise AIError.connection_error(f"Token request failed: {response.status_code}")

        raise AIError.timeout_error("Authorization timeout exceeded. Please try again.")

    def refresh_token(self, refresh_token: str) -> OAuthToken:
        """Exchange a refresh token for a new access token.

        Args:
            refresh_token: Valid refresh token.

        Returns:
            New OAuthToken. The old refresh token is kept if the server does not
            issue a new one.

        Raises:
            AIError: If the server rejects the request or returns a malformed body.
            httpx.RequestError: On network failures.
        """
        params = {
            "grant_type": "refresh_token",
            "refresh_token": refresh_token,
            "client_id": self.client_id,
        }

        response = self._post(self.token_endpoint, params)
        if not response.is_success:
            raise AIError.authentication_error(f"Token refresh failed: HTTP {response.status_code}")

        try:
            return self._token_from_payload(response.json(), fallback_refresh_token=refresh_token)
        except (ValueError, TypeError) as e:
            raise AIError.authentication_error(f"Invalid token refresh response: {e}") from e
|
|
221
|
+
|
|
222
|
+
|
|
223
|
+
class QwenOAuthProvider:
    """Manage the stored Qwen OAuth token: login, refresh, logout and status.

    Tokens are persisted through a ``TokenStore`` under the key ``"qwen"``.
    """

    name = "qwen"

    def __init__(self, token_store: TokenStore | None = None):
        self.token_store = token_store or TokenStore()
        self.device_flow = QwenDeviceFlow()

    def _is_token_expired(self, token: OAuthToken) -> bool:
        """Return True if ``token`` expires within 30 seconds or has no usable expiry.

        A stored token without a numeric ``expiry`` (e.g. a hand-edited or
        truncated file) counts as expired. It is then refreshed or discarded
        instead of raising ``KeyError``.
        """
        expiry = token.get("expiry")
        if isinstance(expiry, bool) or not isinstance(expiry, (int, float)):
            return True
        buffer = 30
        return expiry <= time.time() + buffer

    def initiate_auth(self, open_browser: bool = True) -> None:
        """Run the interactive device flow and store the resulting token.

        Prints the verification URL and user code, optionally opens a browser,
        then blocks until ``poll_for_token`` returns or raises.

        Args:
            open_browser: Whether to try opening the verification URL automatically.
        """
        device_response = self.device_flow.initiate_device_flow()

        auth_url = device_response.verification_uri_complete or (
            f"{device_response.verification_uri}?user_code={device_response.user_code}"
        )

        print("\nQwen OAuth Authentication")
        print("-" * 40)
        print("Please visit the following URL to authorize:")
        print(auth_url)
        print(f"\nUser code: {device_response.user_code}")

        if open_browser and self._should_launch_browser():
            print("Opening browser for authentication...")
            try:
                webbrowser.open(auth_url)
            except Exception as e:
                # Best effort: the URL has already been printed for manual use.
                logger.debug(f"Failed to open browser: {e}")
                print("Failed to open browser automatically. Please open the URL manually.")

        print("-" * 40)
        print("Waiting for authorization...\n")

        token = self.device_flow.poll_for_token(device_response.device_code)
        self.token_store.save_token("qwen", token)

        print("Authentication successful!")

    def _should_launch_browser(self) -> bool:
        """Return False in SSH sessions and on non-macOS POSIX systems without DISPLAY."""
        if os.getenv("SSH_CLIENT") or os.getenv("SSH_TTY"):
            return False
        # Windows and macOS are exempt from the DISPLAY requirement.
        if not os.getenv("DISPLAY") and os.name != "nt":
            if os.uname().sysname != "Darwin":
                return False
        return True

    def get_token(self) -> OAuthToken | None:
        """Return the stored access token, refreshing it first if it has expired."""
        token = self.token_store.get_token("qwen")
        if not token:
            return None

        if self._is_token_expired(token):
            return self.refresh_if_needed()

        return token

    def refresh_if_needed(self) -> OAuthToken | None:
        """Refresh the stored token if it has expired.

        Returns:
            The current token if still valid, the refreshed token on success, or
            None if there is no token or the refresh failed. Stored credentials are
            removed when refresh is impossible or rejected, but kept on transient
            network errors so a later attempt can still refresh them.
        """
        current_token = self.token_store.get_token("qwen")
        if not current_token:
            return None

        if not self._is_token_expired(current_token):
            return current_token

        refresh_token = current_token.get("refresh_token")
        if not refresh_token:
            self.token_store.remove_token("qwen")
            return None

        try:
            refreshed_token = self.device_flow.refresh_token(refresh_token)
            self.token_store.save_token("qwen", refreshed_token)
        except httpx.RequestError as e:
            # Network failure: the refresh token may still be valid, so keep it
            # instead of forcing a full re-login.
            logger.debug(f"Token refresh failed due to a network error: {e}")
            return None
        except Exception as e:
            # Refresh rejected or unusable, or the new token could not be saved.
            logger.debug(f"Token refresh failed: {e}")
            self.token_store.remove_token("qwen")
            return None

        return refreshed_token

    def logout(self) -> None:
        """Log out by removing stored tokens."""
        self.token_store.remove_token("qwen")
        print("Successfully logged out from Qwen")

    def is_authenticated(self) -> bool:
        """Return True if a valid token is available; this may trigger a refresh."""
        token = self.get_token()
        return token is not None
|
gac/oauth/token_store.py
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
1
|
+
"""Token storage for OAuth authentication."""
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
import os
|
|
5
|
+
import stat
|
|
6
|
+
from dataclasses import dataclass
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
from typing import TypedDict, cast
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class OAuthToken(TypedDict, total=False):
|
|
12
|
+
"""OAuth token structure."""
|
|
13
|
+
|
|
14
|
+
access_token: str
|
|
15
|
+
refresh_token: str | None
|
|
16
|
+
expiry: int
|
|
17
|
+
token_type: str
|
|
18
|
+
scope: str | None
|
|
19
|
+
resource_url: str | None
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
@dataclass
class TokenStore:
    """File-based OAuth token storage with owner-only permissions.

    Each provider's token is stored as ``<base_dir>/<provider>.json``.
    """

    base_dir: Path

    def __init__(self, base_dir: Path | None = None):
        if base_dir is None:
            base_dir = Path.home() / ".gac" / "oauth"
        self.base_dir = base_dir
        self._ensure_directory()

    def _ensure_directory(self) -> None:
        """Create the token directory if needed and restrict it to the owner (0700)."""
        # exist_ok avoids a race if another process creates it concurrently.
        # The explicit chmod also tightens a directory created under a loose umask
        # or pre-existing with broader permissions.
        self.base_dir.mkdir(parents=True, mode=0o700, exist_ok=True)
        os.chmod(self.base_dir, stat.S_IRWXU)

    def _get_token_path(self, provider: str) -> Path:
        """Get the path for a provider's token file."""
        return self.base_dir / f"{provider}.json"

    def save_token(self, provider: str, token: OAuthToken) -> None:
        """Save a token to disk with owner-only (0600) permissions.

        Writes to a temp file and then atomically replaces the target, so readers
        never see a partially written file.
        """
        token_path = self._get_token_path(provider)
        temp_path = token_path.with_suffix(".tmp")
        owner_rw = stat.S_IRUSR | stat.S_IWUSR

        # Create the file owner-only from the start. Writing with open() and
        # chmod-ing afterwards would leave the secret readable under the umask
        # for a moment.
        fd = os.open(temp_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, owner_rw)
        try:
            with os.fdopen(fd, "w", encoding="utf-8") as f:
                json.dump(token, f, indent=2)
            # Still needed if the temp file pre-existed with looser permissions.
            os.chmod(temp_path, owner_rw)
            # os.replace overwrites on both POSIX and Windows. Path.rename raises
            # on Windows when the destination already exists.
            os.replace(temp_path, token_path)
        except BaseException:
            temp_path.unlink(missing_ok=True)
            raise

    def get_token(self, provider: str) -> OAuthToken | None:
        """Return the stored token for ``provider``.

        Returns None if the file is missing, is not valid JSON, or lacks a
        string ``access_token``.
        """
        token_path = self._get_token_path(provider)
        try:
            with open(token_path, encoding="utf-8") as f:
                token_data = json.load(f)
        except FileNotFoundError:
            return None
        except ValueError:
            # Corrupt or non-UTF-8 token file: treat as "not logged in" instead of crashing.
            return None

        if isinstance(token_data, dict) and isinstance(token_data.get("access_token"), str):
            return cast("OAuthToken", token_data)
        return None

    def remove_token(self, provider: str) -> None:
        """Remove a provider's token file if present."""
        self._get_token_path(provider).unlink(missing_ok=True)

    def list_providers(self) -> list[str]:
        """List all providers with stored tokens."""
        if not self.base_dir.exists():
            return []
        return [f.stem for f in self.base_dir.glob("*.json")]
|
gac/oauth_retry.py
ADDED
|
@@ -0,0 +1,161 @@
|
|
|
1
|
+
"""OAuth retry handling for expired tokens.
|
|
2
|
+
|
|
3
|
+
This module provides a unified mechanism for handling OAuth token expiration
|
|
4
|
+
across different providers (Claude Code, Qwen, etc.).
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
import logging
|
|
10
|
+
from collections.abc import Callable
|
|
11
|
+
from dataclasses import dataclass
|
|
12
|
+
from typing import TYPE_CHECKING
|
|
13
|
+
|
|
14
|
+
from rich.console import Console
|
|
15
|
+
|
|
16
|
+
from gac.errors import AIError, ConfigError
|
|
17
|
+
|
|
18
|
+
if TYPE_CHECKING:
|
|
19
|
+
from gac.workflow_context import WorkflowContext
|
|
20
|
+
|
|
21
|
+
# Shared Rich console for the user-facing re-authentication messages below.
console = Console()
# Module logger; error details are logged here before retry handling.
logger = logging.getLogger(__name__)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@dataclass
class OAuthProviderConfig:
    """Describe how to detect and recover from an expired OAuth token for one provider.

    Attributes are documented inline; ``extra_error_check`` is optional.
    """

    # Model-name prefix (e.g. "qwen:") that selects this provider.
    provider_prefix: str
    # Human-readable provider name used in console messages.
    display_name: str
    # Hint shown when automatic re-authentication does not succeed.
    manual_auth_hint: str
    # Callback taking ``quiet`` and returning True if re-authentication succeeded.
    authenticate: Callable[[bool], bool]
    # Optional extra predicate; when set, the error must also satisfy it.
    extra_error_check: Callable[[AIError], bool] | None = None
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def _create_claude_code_authenticator() -> Callable[[bool], bool]:
|
|
37
|
+
"""Create authenticator function for Claude Code."""
|
|
38
|
+
|
|
39
|
+
def authenticate(quiet: bool) -> bool:
|
|
40
|
+
from gac.oauth.claude_code import authenticate_and_save
|
|
41
|
+
|
|
42
|
+
return authenticate_and_save(quiet=quiet)
|
|
43
|
+
|
|
44
|
+
return authenticate
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def _create_qwen_authenticator() -> Callable[[bool], bool]:
|
|
48
|
+
"""Create authenticator function for Qwen."""
|
|
49
|
+
|
|
50
|
+
def authenticate(quiet: bool) -> bool:
|
|
51
|
+
from gac.oauth import QwenOAuthProvider, TokenStore
|
|
52
|
+
|
|
53
|
+
try:
|
|
54
|
+
oauth_provider = QwenOAuthProvider(TokenStore())
|
|
55
|
+
oauth_provider.initiate_auth(open_browser=True)
|
|
56
|
+
return True
|
|
57
|
+
except (AIError, ConfigError, OSError):
|
|
58
|
+
return False
|
|
59
|
+
|
|
60
|
+
return authenticate
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def _claude_code_extra_check(e: AIError) -> bool:
|
|
64
|
+
"""Extra check for Claude Code - verify error message contains expired/oauth."""
|
|
65
|
+
error_str = str(e).lower()
|
|
66
|
+
return "expired" in error_str or "oauth" in error_str
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
# Claude Code: only errors mentioning "expired"/"oauth" trigger re-authentication.
_CLAUDE_CODE_OAUTH = OAuthProviderConfig(
    display_name="Claude Code",
    provider_prefix="claude-code:",
    authenticate=_create_claude_code_authenticator(),
    extra_error_check=_claude_code_extra_check,
    manual_auth_hint="Run 'gac model' to re-authenticate manually.",
)

# Qwen: any authentication error on a "qwen:" model triggers re-authentication.
_QWEN_OAUTH = OAuthProviderConfig(
    display_name="Qwen",
    provider_prefix="qwen:",
    authenticate=_create_qwen_authenticator(),
    manual_auth_hint="Run 'gac auth qwen login' to re-authenticate manually.",
)

# Providers eligible for automatic re-auth, checked in this order.
OAUTH_PROVIDERS: list[OAuthProviderConfig] = [_CLAUDE_CODE_OAUTH, _QWEN_OAUTH]
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def _find_oauth_provider(model: str, error: AIError) -> OAuthProviderConfig | None:
    """Return the first OAuth provider config that should handle ``error`` for ``model``.

    Only errors whose ``error_type`` is ``"authentication"`` qualify. A provider
    matches when ``model`` starts with its ``provider_prefix`` and its optional
    ``extra_error_check``, if set, accepts the error. Returns None when nothing
    matches.
    """
    if error.error_type == "authentication":
        candidates = (
            candidate
            for candidate in OAUTH_PROVIDERS
            if model.startswith(candidate.provider_prefix)
            and (not candidate.extra_error_check or candidate.extra_error_check(error))
        )
        return next(candidates, None)
    return None
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def _attempt_reauth_and_retry(
    provider: OAuthProviderConfig,
    quiet: bool,
    retry_workflow: Callable[[], int],
) -> int:
    """Attempt re-authentication and, if it succeeds, retry the workflow.

    Errors raised while re-authenticating and errors raised by the retried
    workflow are reported separately. Previously both shared one ``try``, so a
    failure of the retried commit was misreported as a re-authentication error.

    Args:
        provider: The OAuth provider configuration.
        quiet: Whether to suppress output (passed to the authenticator).
        retry_workflow: Callable that re-runs the workflow after successful re-auth.

    Returns:
        The retried workflow's exit code, or 1 if re-authentication or the retry failed.
    """
    console.print(f"[yellow]⚠ {provider.display_name} OAuth token has expired[/yellow]")
    console.print("[cyan]🔐 Starting automatic re-authentication...[/cyan]")

    try:
        authenticated = provider.authenticate(quiet)
    except (AIError, ConfigError, OSError) as auth_error:
        console.print(f"[red]Re-authentication error: {auth_error}[/red]")
        console.print(f"[yellow]{provider.manual_auth_hint}[/yellow]")
        return 1

    if not authenticated:
        console.print("[red]Re-authentication failed.[/red]")
        console.print(f"[yellow]{provider.manual_auth_hint}[/yellow]")
        return 1

    console.print("[green]✓ Re-authentication successful![/green]")
    console.print("[cyan]Retrying commit...[/cyan]\n")
    try:
        return retry_workflow()
    except (AIError, ConfigError, OSError) as retry_error:
        # Authentication itself succeeded, so the manual-auth hint would mislead here.
        console.print(f"[red]Retry after re-authentication failed: {retry_error}[/red]")
        return 1
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
def handle_oauth_retry(e: AIError, ctx: WorkflowContext) -> int:
    """Recover from an expired OAuth token by re-authenticating and retrying once.

    The error is always logged. If it is an authentication error for a known
    OAuth provider (selected by ``ctx.model``), re-authentication is attempted
    and, on success, the single-commit workflow is re-run with the same context.
    Otherwise a failure message is printed.

    Args:
        e: The AIError that triggered this handler.
        ctx: WorkflowContext holding the workflow configuration and state.

    Returns:
        Exit code: 0 for success, 1 for failure.
    """
    logger.error(str(e))

    matched = _find_oauth_provider(ctx.model, e)
    if matched is not None:

        def _rerun_single_commit() -> int:
            # gac.main is imported lazily, only when a retry actually happens
            # (presumably to avoid an import cycle; confirm).
            from gac.main import _execute_single_commit_workflow

            return _execute_single_commit_workflow(ctx)

        return _attempt_reauth_and_retry(matched, ctx.quiet, _rerun_single_commit)

    console.print(f"[red]Failed to generate commit message: {e!s}[/red]")
    return 1
|