vtx-coding-agent 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vtx/__init__.py +63 -0
- vtx/async_utils.py +40 -0
- vtx/builtin_skills/github/SKILL.md +139 -0
- vtx/builtin_skills/init/SKILL.md +74 -0
- vtx/builtin_skills/review/SKILL.md +73 -0
- vtx/builtin_skills/skill-builder/SKILL.md +133 -0
- vtx/cli.py +90 -0
- vtx/config.py +741 -0
- vtx/context/__init__.py +15 -0
- vtx/context/_xml.py +8 -0
- vtx/context/agent_mds.py +128 -0
- vtx/context/git.py +64 -0
- vtx/context/loader.py +41 -0
- vtx/context/skills.py +423 -0
- vtx/core/__init__.py +47 -0
- vtx/core/compaction.py +89 -0
- vtx/core/errors.py +17 -0
- vtx/core/handoff.py +51 -0
- vtx/core/scratchpad.py +54 -0
- vtx/core/types.py +197 -0
- vtx/defaults/__init__.py +0 -0
- vtx/defaults/config.yml +53 -0
- vtx/diff_display.py +12 -0
- vtx/events.py +224 -0
- vtx/gh_cli.py +82 -0
- vtx/git_branch.py +90 -0
- vtx/headless.py +127 -0
- vtx/llm/__init__.py +93 -0
- vtx/llm/base.py +217 -0
- vtx/llm/context_length.py +150 -0
- vtx/llm/dynamic_models.py +735 -0
- vtx/llm/model_fetcher.py +279 -0
- vtx/llm/models.py +78 -0
- vtx/llm/oauth/__init__.py +59 -0
- vtx/llm/oauth/copilot.py +358 -0
- vtx/llm/oauth/dynamic.py +236 -0
- vtx/llm/oauth/openai.py +400 -0
- vtx/llm/phase_parser.py +270 -0
- vtx/llm/provider.yaml +280 -0
- vtx/llm/provider_catalog.py +230 -0
- vtx/llm/providers/__init__.py +45 -0
- vtx/llm/providers/anthropic_sdk.py +256 -0
- vtx/llm/providers/mock.py +249 -0
- vtx/llm/providers/openai_sdk.py +246 -0
- vtx/llm/providers/sanitize.py +14 -0
- vtx/llm/sdk/__init__.py +13 -0
- vtx/llm/sdk/anthropic.py +382 -0
- vtx/llm/sdk/base.py +82 -0
- vtx/llm/sdk/openai.py +344 -0
- vtx/llm/tool_parser.py +161 -0
- vtx/loop.py +272 -0
- vtx/notify.py +109 -0
- vtx/permissions.py +114 -0
- vtx/prompts/__init__.py +45 -0
- vtx/prompts/builder.py +86 -0
- vtx/prompts/env.py +58 -0
- vtx/prompts/identity.py +166 -0
- vtx/prompts/tooling.py +36 -0
- vtx/py.typed +0 -0
- vtx/runtime.py +580 -0
- vtx/session.py +868 -0
- vtx/sounds/completion.wav +0 -0
- vtx/sounds/error.wav +0 -0
- vtx/sounds/permission.wav +0 -0
- vtx/themes.py +1104 -0
- vtx/tools/__init__.py +68 -0
- vtx/tools/_read_image.py +106 -0
- vtx/tools/_tool_utils.py +90 -0
- vtx/tools/base.py +36 -0
- vtx/tools/bash.py +371 -0
- vtx/tools/edit.py +261 -0
- vtx/tools/find.py +132 -0
- vtx/tools/read.py +238 -0
- vtx/tools/skill.py +278 -0
- vtx/tools/web.py +238 -0
- vtx/tools/write.py +88 -0
- vtx/tools_manager.py +216 -0
- vtx/turn.py +789 -0
- vtx/ui/__init__.py +0 -0
- vtx/ui/agent_runner.py +417 -0
- vtx/ui/app.py +665 -0
- vtx/ui/app_protocol.py +29 -0
- vtx/ui/autocomplete.py +440 -0
- vtx/ui/blocks.py +735 -0
- vtx/ui/chat.py +613 -0
- vtx/ui/clipboard.py +59 -0
- vtx/ui/commands/__init__.py +100 -0
- vtx/ui/commands/auth.py +306 -0
- vtx/ui/commands/base.py +122 -0
- vtx/ui/commands/models.py +144 -0
- vtx/ui/commands/sessions.py +388 -0
- vtx/ui/commands/settings.py +286 -0
- vtx/ui/completion_ui.py +313 -0
- vtx/ui/export.py +703 -0
- vtx/ui/floating_list.py +370 -0
- vtx/ui/formatting.py +287 -0
- vtx/ui/input.py +760 -0
- vtx/ui/latex.py +349 -0
- vtx/ui/launch.py +108 -0
- vtx/ui/path_complete.py +228 -0
- vtx/ui/prompt_history.py +102 -0
- vtx/ui/queue_ui.py +141 -0
- vtx/ui/selection_mode.py +18 -0
- vtx/ui/session_ui.py +235 -0
- vtx/ui/startup.py +124 -0
- vtx/ui/styles.py +327 -0
- vtx/ui/tool_output.py +34 -0
- vtx/ui/tree.py +437 -0
- vtx/ui/welcome.py +51 -0
- vtx/ui/widgets.py +558 -0
- vtx/update_check.py +49 -0
- vtx/version.py +22 -0
- vtx_coding_agent-0.1.1.dist-info/METADATA +259 -0
- vtx_coding_agent-0.1.1.dist-info/RECORD +117 -0
- vtx_coding_agent-0.1.1.dist-info/WHEEL +4 -0
- vtx_coding_agent-0.1.1.dist-info/entry_points.txt +2 -0
- vtx_coding_agent-0.1.1.dist-info/licenses/LICENSE +201 -0
vtx/llm/oauth/copilot.py
ADDED
|
@@ -0,0 +1,358 @@
|
|
|
1
|
+
"""
|
|
2
|
+
GitHub Copilot OAuth device flow.
|
|
3
|
+
|
|
4
|
+
Implements the device code flow to authenticate with GitHub and
|
|
5
|
+
exchange for a Copilot token that can be used with the Copilot API.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import asyncio
|
|
9
|
+
import json
|
|
10
|
+
from base64 import b64decode
|
|
11
|
+
from dataclasses import dataclass
|
|
12
|
+
from pathlib import Path
|
|
13
|
+
from typing import Any
|
|
14
|
+
|
|
15
|
+
import aiohttp
|
|
16
|
+
|
|
17
|
+
from vtx import get_config_dir
|
|
18
|
+
|
|
19
|
+
# OAuth client ID sent to GitHub during the device flow (same as the
# VS Code Copilot extension). Stored base64-encoded and decoded at import time.
_CLIENT_ID = b64decode("SXYxLmI1MDdhMDhjODdlY2ZlOTg=").decode()

# Headers merged into every request made to the Copilot API.
COPILOT_HEADERS = {
    "User-Agent": "GitHubCopilotChat/0.35.0",
    "Editor-Version": "vscode/1.107.0",
    "Editor-Plugin-Version": "copilot-chat/0.35.0",
    "Copilot-Integration-Id": "vscode-chat",
}
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
@dataclass
|
|
32
|
+
class CopilotCredentials:
|
|
33
|
+
github_token: str # Long-lived GitHub OAuth token (refresh token)
|
|
34
|
+
copilot_token: str # Short-lived Copilot API token (access token)
|
|
35
|
+
expires_at: int # Unix timestamp (milliseconds) when copilot_token expires
|
|
36
|
+
enterprise_domain: str | None = None # For GitHub Enterprise
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
@dataclass
class DeviceCodeResponse:
    """Fields taken from the device-code endpoint's JSON response."""

    # Code that is later sent back when polling for the access token.
    device_code: str
    # Code handed to the caller's on_user_code callback.
    user_code: str
    # URL handed to the caller's on_user_code callback.
    verification_uri: str
    # Polling interval reported by the server.
    interval: int
    # Lifetime of the device code reported by the server.
    expires_in: int
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def get_copilot_auth_path() -> Path:
    """Return the path of ``copilot_auth.json`` inside the vtx config directory."""
    config_dir = get_config_dir()
    return config_dir / "copilot_auth.json"
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def load_credentials() -> CopilotCredentials | None:
    """Load the saved Copilot credentials from disk.

    Returns:
        The stored credentials, or ``None`` when the file is missing or
        unreadable, is not valid UTF-8 JSON, does not hold a JSON object,
        or lacks one of the required keys.
    """
    path = get_copilot_auth_path()

    try:
        # Read directly instead of checking exists() first: a missing file
        # surfaces as FileNotFoundError (an OSError) without a check/use race.
        data = json.loads(path.read_text(encoding="utf-8"))
    except (OSError, ValueError):
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        return None

    # A JSON list/string/number would raise TypeError on the lookups below.
    if not isinstance(data, dict):
        return None

    try:
        return CopilotCredentials(
            github_token=data["github_token"],
            copilot_token=data["copilot_token"],
            expires_at=data["expires_at"],
            enterprise_domain=data.get("enterprise_domain"),
        )
    except KeyError:
        return None
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def save_credentials(creds: CopilotCredentials) -> None:
    """Write the credentials to the Copilot auth file as JSON with mode 0600.

    The parent directory is created if needed. ``enterprise_domain`` is only
    written when it is set.

    Args:
        creds: Credentials to persist.
    """
    import os

    path = get_copilot_auth_path()
    path.parent.mkdir(parents=True, exist_ok=True)

    data = {
        "github_token": creds.github_token,
        "copilot_token": creds.copilot_token,
        "expires_at": creds.expires_at,
    }
    if creds.enterprise_domain:
        data["enterprise_domain"] = creds.enterprise_domain

    # Create the file with owner-only permissions from the start; writing
    # first and chmod-ing afterwards leaves a window in which the tokens are
    # readable under the default umask.
    fd = os.open(path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
    with os.fdopen(fd, "w") as handle:
        handle.write(json.dumps(data, indent=2))

    # The mode passed to os.open only applies to newly created files, so a
    # pre-existing file still needs its permissions tightened explicitly.
    path.chmod(0o600)
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def clear_credentials() -> None:
    """Delete the Copilot auth file; do nothing if it is already gone."""
    # missing_ok avoids the exists()/unlink() race in which another process
    # removes the file between the check and the delete.
    get_copilot_auth_path().unlink(missing_ok=True)
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
def is_copilot_logged_in() -> bool:
    """Report whether ``load_credentials()`` yields stored credentials."""
    stored = load_credentials()
    return stored is not None
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
def _get_urls(domain: str) -> dict[str, str]:
|
|
96
|
+
return {
|
|
97
|
+
"device_code": f"https://{domain}/login/device/code",
|
|
98
|
+
"access_token": f"https://{domain}/login/oauth/access_token",
|
|
99
|
+
"copilot_token": f"https://api.{domain}/copilot_internal/v2/token",
|
|
100
|
+
}
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
def get_base_url_from_token(token: str, enterprise_domain: str | None = None) -> str:
|
|
104
|
+
"""
|
|
105
|
+
Extract API base URL from Copilot token.
|
|
106
|
+
|
|
107
|
+
Token format: tid=...;exp=...;proxy-ep=proxy.individual.githubcopilot.com;...
|
|
108
|
+
Returns API URL like https://api.individual.githubcopilot.com
|
|
109
|
+
"""
|
|
110
|
+
import re
|
|
111
|
+
|
|
112
|
+
match = re.search(r"proxy-ep=([^;]+)", token)
|
|
113
|
+
if match:
|
|
114
|
+
proxy_host = match.group(1)
|
|
115
|
+
api_host = proxy_host.replace("proxy.", "api.", 1)
|
|
116
|
+
return f"https://{api_host}"
|
|
117
|
+
|
|
118
|
+
# Fallback
|
|
119
|
+
if enterprise_domain:
|
|
120
|
+
return f"https://copilot-api.{enterprise_domain}"
|
|
121
|
+
return "https://api.individual.githubcopilot.com"
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
async def start_device_flow(domain: str = "github.com") -> DeviceCodeResponse:
    """
    Request a device code from the device-code endpoint of *domain*.

    An HTTP error status is raised via ``response.raise_for_status()``; a
    response lacking an expected field raises ``KeyError``.
    """
    endpoint = _get_urls(domain)["device_code"]
    request_headers = {
        "Accept": "application/json",
        "Content-Type": "application/json",
        "User-Agent": "GitHubCopilotChat/0.35.0",
    }
    payload = {"client_id": _CLIENT_ID, "scope": "read:user"}

    async with aiohttp.ClientSession() as session:
        async with session.post(endpoint, headers=request_headers, json=payload) as response:
            response.raise_for_status()
            body = await response.json()

    return DeviceCodeResponse(
        device_code=body["device_code"],
        user_code=body["user_code"],
        verification_uri=body["verification_uri"],
        interval=body["interval"],
        expires_in=body["expires_in"],
    )
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
async def poll_for_github_token(
    device_code: str,
    interval: int,
    expires_in: int,
    domain: str = "github.com",
    on_poll: Any | None = None,
) -> str:
    """
    Poll GitHub for the access token after the user authorizes the device.

    Args:
        device_code: Device code obtained from the device-code request.
        interval: Seconds to wait between polls (values below 1 are raised to 1).
        expires_in: Seconds until the device code stops being usable.
        domain: Host whose access-token endpoint is polled.
        on_poll: Optional zero-argument callable invoked before every poll.

    Returns:
        The GitHub OAuth access token.

    Raises:
        TimeoutError: If the server reports ``expired_token`` or the local
            deadline passes first.
        RuntimeError: If the server reports any other error.
    """
    import time

    urls = _get_urls(domain)
    # Use the monotonic clock so that a wall-clock adjustment (NTP step,
    # manual change) cannot stretch or cut short the polling window.
    deadline = time.monotonic() + expires_in
    poll_interval = max(1, interval)

    async with aiohttp.ClientSession() as session:
        while time.monotonic() < deadline:
            if on_poll:
                on_poll()

            async with session.post(
                urls["access_token"],
                headers={
                    "Accept": "application/json",
                    "Content-Type": "application/json",
                    "User-Agent": "GitHubCopilotChat/0.35.0",
                },
                json={
                    "client_id": _CLIENT_ID,
                    "device_code": device_code,
                    "grant_type": "urn:ietf:params:oauth:grant-type:device_code",
                },
            ) as response:
                data = await response.json()

            if "access_token" in data:
                return data["access_token"]

            error = data.get("error")
            if error == "expired_token":
                raise TimeoutError("Device code expired")
            if error == "slow_down":
                # The server asked for a lower polling rate: back off by 5 seconds.
                poll_interval += 5
            elif error != "authorization_pending":
                raise RuntimeError(f"OAuth error: {error}")

            # Never sleep past the deadline; the loop condition then ends the
            # flow promptly instead of up to one full interval late.
            remaining = max(0.0, deadline - time.monotonic())
            await asyncio.sleep(min(poll_interval, remaining))

    raise TimeoutError("Device code flow timed out")
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
async def exchange_for_copilot_token(
    github_token: str, domain: str = "github.com"
) -> tuple[str, int]:
    """
    Trade a GitHub OAuth token for a Copilot API token.

    Returns ``(copilot_token, expires_at_ms)``, where the expiry already has a
    five-minute safety margin subtracted.

    Raises ``RuntimeError`` on HTTP 401; any other error status is raised via
    ``response.raise_for_status()``.
    """
    endpoint = _get_urls(domain)["copilot_token"]
    request_headers = {
        "Accept": "application/json",
        "Authorization": f"Bearer {github_token}",
    }
    # Merged last, so these win over the two entries above on a key clash.
    request_headers.update(COPILOT_HEADERS)

    async with aiohttp.ClientSession() as session:
        async with session.get(endpoint, headers=request_headers) as response:
            if response.status == 401:
                raise RuntimeError(
                    "GitHub Copilot subscription not found. "
                    "Make sure you have an active Copilot subscription."
                )
            response.raise_for_status()
            body = await response.json()

    # The response's expires_at is in seconds; convert to milliseconds and
    # subtract the five-minute buffer.
    safety_margin_ms = 5 * 60 * 1000
    expires_at_ms = body["expires_at"] * 1000 - safety_margin_ms
    return body["token"], expires_at_ms
|
|
243
|
+
|
|
244
|
+
|
|
245
|
+
async def refresh_copilot_token(creds: CopilotCredentials) -> CopilotCredentials:
    """
    Fetch a fresh Copilot token using the stored GitHub token.

    The returned credentials keep the original GitHub token and enterprise
    domain, and are saved to disk before being returned.
    """
    host = creds.enterprise_domain if creds.enterprise_domain else "github.com"
    fresh_token, fresh_expiry = await exchange_for_copilot_token(creds.github_token, host)

    refreshed = CopilotCredentials(
        github_token=creds.github_token,
        copilot_token=fresh_token,
        expires_at=fresh_expiry,
        enterprise_domain=creds.enterprise_domain,
    )
    save_credentials(refreshed)
    return refreshed
|
|
257
|
+
|
|
258
|
+
|
|
259
|
+
async def get_valid_token() -> str | None:
    """
    Return a usable Copilot API token, refreshing it first when needed.

    Returns None when no credentials are stored or when the refresh fails
    (in which case the user has to log in again).
    """
    import time

    stored = load_credentials()
    if not stored:
        return None

    # Treat the token as expired one minute early.
    now_ms = time.time() * 1000
    if now_ms < stored.expires_at - 60_000:
        return stored.copilot_token

    try:
        stored = await refresh_copilot_token(stored)
    except Exception:
        # Refresh failed; the caller must trigger a new login.
        return None
    return stored.copilot_token
|
|
280
|
+
|
|
281
|
+
|
|
282
|
+
async def _enable_copilot_model(
    token: str, model_id: str, enterprise_domain: str | None = None
) -> bool:
    """
    POST ``{"state": "enabled"}`` to the policy endpoint of one model.

    Returns True when the response status is below 400, and False when the
    status is 400 or above or the request raises.
    """
    policy_url = f"{get_base_url_from_token(token, enterprise_domain)}/models/{model_id}/policy"

    request_headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {token}",
        **COPILOT_HEADERS,
        "openai-intent": "chat-policy",
        "x-interaction-type": "chat-policy",
    }

    try:
        async with aiohttp.ClientSession() as session:
            async with session.post(
                policy_url, headers=request_headers, json={"state": "enabled"}
            ) as response:
                return response.status < 400
    except Exception:
        # Best effort: a failed request is reported as "not enabled".
        return False
|
|
306
|
+
|
|
307
|
+
|
|
308
|
+
async def enable_all_copilot_models(token: str, enterprise_domain: str | None = None) -> None:
    """
    Try to enable every model whose provider is ``github-copilot``.

    The requests run concurrently; their individual results and any
    exceptions are collected and discarded.
    """
    from ..models import get_all_models

    pending = [
        _enable_copilot_model(token, model.id, enterprise_domain)
        for model in get_all_models()
        if model.provider == "github-copilot"
    ]
    await asyncio.gather(*pending, return_exceptions=True)
|
|
314
|
+
|
|
315
|
+
|
|
316
|
+
async def login(
    on_user_code: Any | None = None, enterprise_domain: str | None = None
) -> CopilotCredentials:
    """
    Run the whole Copilot login: device flow, token exchange, save, enable models.

    Args:
        on_user_code: Optional callback invoked as
            ``on_user_code(verification_uri, user_code)`` once the user has to act.
        enterprise_domain: Optional GitHub Enterprise domain; ``github.com``
            is used when omitted.

    Returns:
        The CopilotCredentials, already saved to disk.
    """
    host = enterprise_domain or "github.com"

    # Step 1: obtain the device code and show it to the user.
    device = await start_device_flow(host)
    if on_user_code:
        on_user_code(device.verification_uri, device.user_code)

    # Step 2: wait for the user to authorize, then trade tokens.
    github_token = await poll_for_github_token(
        device.device_code, device.interval, device.expires_in, host
    )
    copilot_token, expires_at = await exchange_for_copilot_token(github_token, host)

    # Step 3: persist before touching model policies.
    creds = CopilotCredentials(
        github_token=github_token,
        copilot_token=copilot_token,
        expires_at=expires_at,
        enterprise_domain=enterprise_domain,
    )
    save_credentials(creds)

    # Step 4: some models require policy acceptance before use.
    await enable_all_copilot_models(copilot_token, enterprise_domain)
    return creds
|
vtx/llm/oauth/dynamic.py
ADDED
|
@@ -0,0 +1,236 @@
|
|
|
1
|
+
"""
|
|
2
|
+
API-key storage for dynamic OpenAI-compatible providers.
|
|
3
|
+
|
|
4
|
+
The dynamic providers (``airouter``, ``opencode``, ``kilo``, ``tokenrouter``) do
|
|
5
|
+
not need an OAuth flow — they just need an API key. Users can set one of three
|
|
6
|
+
ways, in priority order:
|
|
7
|
+
|
|
8
|
+
1. The provider's ``<NAME>_API_KEY`` environment variable (e.g. ``KILO_API_KEY``).
|
|
9
|
+
2. The plaintext on-disk key file at the configured location (mode 0600),
|
|
10
|
+
written by the in-app ``/login`` command.
|
|
11
|
+
3. None — for providers that support a free tier (airouter, kilo), vtx will
|
|
12
|
+
fall back to a placeholder key.
|
|
13
|
+
|
|
14
|
+
This module owns path #2: it reads/writes the key file and exposes a small
|
|
15
|
+
helper, :func:`get_dynamic_api_key`, that already implements the env-var-first
|
|
16
|
+
priority so the rest of vtx does not have to.
|
|
17
|
+
|
|
18
|
+
The storage location and format can be configured via the vtx-api-key-storage skill.
|
|
19
|
+
"""
|
|
20
|
+
|
|
21
|
+
from __future__ import annotations
|
|
22
|
+
|
|
23
|
+
import contextlib
|
|
24
|
+
import json
|
|
25
|
+
import os
|
|
26
|
+
from dataclasses import dataclass
|
|
27
|
+
from pathlib import Path
|
|
28
|
+
|
|
29
|
+
from vtx import get_config_dir
|
|
30
|
+
from vtx.llm.dynamic_models import DYNAMIC_PROVIDERS
|
|
31
|
+
from vtx.llm.provider_catalog import get as get_provider_info
|
|
32
|
+
|
|
33
|
+
# Name of the JSON key file.
AUTH_FILENAME = "dynamic_auth.json"
# Directory that is checked for the YAML key file (``dynamic_auth.yml``).
Vtx_STORAGE_DIR = Path.home() / "vtx"
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
@dataclass
class DynamicProviderStatus:
    """Snapshot of which credentials are available for one provider."""

    provider: str
    env_var: str | None
    has_env_key: bool
    has_stored_key: bool
    api_key_optional: bool

    @property
    def is_configured(self) -> bool:
        """Whether an env key or a stored key exists, or the key is optional."""
        stored_or_optional = self.has_stored_key or self.api_key_optional
        return self.has_env_key or stored_or_optional
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def get_dynamic_auth_path() -> Path:
    """Locate the API key storage file.

    Existing files are preferred in this order: the YAML file in the vtx
    storage directory, then the JSON file under ``$XDG_CONFIG_HOME/vtx``.
    When neither exists, the JSON path inside the config directory is
    returned, whether or not that file exists.
    """
    candidates = [Vtx_STORAGE_DIR / "dynamic_auth.yml"]

    xdg_root = os.environ.get("XDG_CONFIG_HOME")
    if xdg_root:
        candidates.append(Path(xdg_root) / "vtx" / AUTH_FILENAME)

    for candidate in candidates:
        if candidate.exists():
            return candidate

    # Backward-compatible default location.
    return get_config_dir() / AUTH_FILENAME
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def _read_all() -> dict[str, str]:
    """Read every stored API key from the key file.

    The file is parsed as YAML when its suffix is ``.yml``/``.yaml`` and as
    JSON otherwise.

    Returns:
        A mapping of provider name to API key. The mapping is empty when the
        file is missing, cannot be read or parsed, or does not hold a
        mapping. Entries whose key or value is not a string are dropped.
    """
    path = get_dynamic_auth_path()
    if not path.exists():
        return {}

    try:
        content = path.read_text(encoding="utf-8")
        if path.suffix.lower() in {".yml", ".yaml"}:
            import yaml

            data = yaml.safe_load(content) or {}
        else:
            data = json.loads(content)
    except Exception:
        # Deliberately best-effort: an unreadable or corrupt key file (or a
        # missing yaml package) is treated as "no keys stored", not a crash.
        return {}

    if not isinstance(data, dict):
        return {}
    # Only keep str→str entries; ignore anything weird.
    return {k: v for k, v in data.items() if isinstance(k, str) and isinstance(v, str)}
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
def _write_all(keys: dict[str, str]) -> None:
    """Replace the key file's contents with *keys*.

    The data is serialized as YAML when the key file has a ``.yml``/``.yaml``
    suffix and the ``yaml`` package can be imported, and as JSON otherwise.
    It is written to a temporary sibling file that is then renamed over the
    key file.

    Args:
        keys: Mapping of provider name to API key to persist.

    Raises:
        OSError: If the temporary file cannot be written or renamed.
    """
    path = get_dynamic_auth_path()
    path.parent.mkdir(parents=True, exist_ok=True)

    # Default to JSON; switch to YAML only for a YAML key file with yaml available.
    payload = json.dumps(keys, indent=2)
    tmp = path.with_suffix(".json.tmp")
    if path.suffix.lower() in {".yml", ".yaml"}:
        try:
            import yaml
        except ImportError:
            pass  # Fall back to the JSON payload prepared above.
        else:
            payload = yaml.dump(keys, default_flow_style=False)
            tmp = path.with_suffix(".yml.tmp")

    try:
        # Create the temp file with owner-only permissions up front, so the
        # keys are never on disk under the default umask.
        fd = os.open(tmp, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
        with os.fdopen(fd, "w", encoding="utf-8") as handle:
            handle.write(payload)
        with contextlib.suppress(OSError):
            # A stale temp file keeps its old mode; tighten it. Non-POSIX
            # filesystems (e.g. Windows) don't support chmod; ignore.
            os.chmod(tmp, 0o600)
        tmp.replace(path)
    except Exception:
        # Do not leave a half-written file full of secrets behind.
        with contextlib.suppress(OSError):
            tmp.unlink()
        raise

    with contextlib.suppress(OSError):
        os.chmod(path, 0o600)
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
def load_api_key(provider: str) -> str | None:
    """Look up the on-disk API key for *provider*; ``None`` when there is none."""
    stored = _read_all()
    return stored.get(provider)
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
def save_api_key(provider: str, key: str) -> None:
    """Store *key* (stripped of surrounding whitespace) for *provider*.

    Raises ``ValueError`` when the stripped key is empty, or when the provider
    is neither in ``DYNAMIC_PROVIDERS`` nor known to the provider catalog.
    """
    cleaned = key.strip()
    if not cleaned:
        raise ValueError("API key must not be empty")

    is_known = provider in DYNAMIC_PROVIDERS or get_provider_info(provider) is not None
    if not is_known:
        raise ValueError(f"Unknown provider: {provider}")

    # Merge into whatever is already stored, overwriting any previous key.
    _write_all({**_read_all(), provider: cleaned})
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
def clear_api_key(provider: str) -> bool:
    """Delete the stored key for *provider*; return whether one existed."""
    stored = _read_all()
    if provider in stored:
        stored.pop(provider)
        _write_all(stored)
        return True
    return False
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
def has_api_key(provider: str) -> bool:
    """Report whether the key file holds an entry for *provider*."""
    stored = _read_all()
    return provider in stored
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
def _env_var_for(provider: str) -> str | None:
    """Return the API-key environment variable name for *provider*, if known.

    ``DYNAMIC_PROVIDERS`` is consulted first, then the provider catalog.
    """
    builtin = DYNAMIC_PROVIDERS.get(provider)
    if builtin is None:
        catalog_entry = get_provider_info(provider)
        return catalog_entry.api_key_env if catalog_entry else None
    return builtin.env_var
|
|
178
|
+
|
|
179
|
+
|
|
180
|
+
def get_dynamic_api_key(provider: str) -> str | None:
    """Pick the API key to use for a dynamic provider.

    A non-blank value of the provider's environment variable wins (returned
    stripped); otherwise the key stored on disk is used. Returns ``None``
    when neither is available.
    """
    env_name = _env_var_for(provider)
    from_env = os.environ.get(env_name) if env_name else None
    if from_env:
        trimmed = from_env.strip()
        if trimmed:
            return trimmed
    return load_api_key(provider)
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
def get_provider_status(provider: str) -> DynamicProviderStatus | None:
    """Describe which credentials exist for *provider*.

    ``DYNAMIC_PROVIDERS`` is consulted first, then the provider catalog.
    Returns ``None`` when the provider is in neither, or when its catalog
    entry has no ``base_url``.
    """
    builtin = DYNAMIC_PROVIDERS.get(provider)
    if builtin is not None:
        env_var = builtin.env_var
        key_optional = builtin.api_key_optional
    else:
        catalog_entry = get_provider_info(provider)
        if catalog_entry is None or not catalog_entry.base_url:
            return None
        env_var = catalog_entry.api_key_env
        key_optional = catalog_entry.api_key_optional

    # A blank or whitespace-only environment value does not count as a key.
    env_key_present = bool(env_var and os.environ.get(env_var, "").strip())
    return DynamicProviderStatus(
        provider=provider,
        env_var=env_var,
        has_env_key=env_key_present,
        has_stored_key=has_api_key(provider),
        api_key_optional=key_optional,
    )
|
|
224
|
+
|
|
225
|
+
|
|
226
|
+
# Public names exported by this module.
__all__ = [
    "AUTH_FILENAME",
    "DynamicProviderStatus",
    "clear_api_key",
    "get_dynamic_api_key",
    "get_dynamic_auth_path",
    "get_provider_status",
    "has_api_key",
    "load_api_key",
    "save_api_key",
]
|