klaude-code 2.8.1__py3-none-any.whl → 2.9.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- klaude_code/app/runtime.py +2 -1
- klaude_code/auth/antigravity/oauth.py +33 -38
- klaude_code/auth/antigravity/token_manager.py +0 -18
- klaude_code/auth/base.py +53 -0
- klaude_code/auth/claude/oauth.py +34 -49
- klaude_code/auth/codex/exceptions.py +0 -4
- klaude_code/auth/codex/oauth.py +32 -28
- klaude_code/auth/codex/token_manager.py +0 -18
- klaude_code/cli/cost_cmd.py +128 -39
- klaude_code/cli/list_model.py +27 -10
- klaude_code/cli/main.py +14 -3
- klaude_code/config/assets/builtin_config.yaml +25 -24
- klaude_code/config/config.py +47 -25
- klaude_code/config/sub_agent_model_helper.py +18 -13
- klaude_code/config/thinking.py +0 -8
- klaude_code/const.py +1 -1
- klaude_code/core/agent_profile.py +11 -56
- klaude_code/core/compaction/overflow.py +0 -4
- klaude_code/core/executor.py +33 -5
- klaude_code/core/manager/llm_clients.py +9 -1
- klaude_code/core/prompts/prompt-claude-code.md +4 -4
- klaude_code/core/reminders.py +21 -23
- klaude_code/core/task.py +1 -5
- klaude_code/core/tool/__init__.py +3 -2
- klaude_code/core/tool/file/apply_patch.py +0 -27
- klaude_code/core/tool/file/read_tool.md +3 -2
- klaude_code/core/tool/file/read_tool.py +27 -3
- klaude_code/core/tool/offload.py +0 -35
- klaude_code/core/tool/shell/bash_tool.py +1 -1
- klaude_code/core/tool/sub_agent/__init__.py +6 -0
- klaude_code/core/tool/sub_agent/image_gen.md +16 -0
- klaude_code/core/tool/sub_agent/image_gen.py +146 -0
- klaude_code/core/tool/sub_agent/task.md +20 -0
- klaude_code/core/tool/sub_agent/task.py +205 -0
- klaude_code/core/tool/tool_registry.py +0 -16
- klaude_code/core/turn.py +1 -1
- klaude_code/llm/anthropic/input.py +6 -5
- klaude_code/llm/antigravity/input.py +14 -7
- klaude_code/llm/bedrock_anthropic/__init__.py +3 -0
- klaude_code/llm/google/client.py +8 -6
- klaude_code/llm/google/input.py +20 -12
- klaude_code/llm/image.py +18 -11
- klaude_code/llm/input_common.py +32 -6
- klaude_code/llm/json_stable.py +37 -0
- klaude_code/llm/{codex → openai_codex}/__init__.py +1 -1
- klaude_code/llm/{codex → openai_codex}/client.py +24 -2
- klaude_code/llm/openai_codex/prompt_sync.py +237 -0
- klaude_code/llm/openai_compatible/client.py +3 -1
- klaude_code/llm/openai_compatible/input.py +0 -10
- klaude_code/llm/openai_compatible/stream.py +35 -10
- klaude_code/llm/{responses → openai_responses}/client.py +1 -1
- klaude_code/llm/{responses → openai_responses}/input.py +15 -5
- klaude_code/llm/registry.py +3 -8
- klaude_code/llm/stream_parts.py +3 -1
- klaude_code/llm/usage.py +1 -9
- klaude_code/protocol/events.py +2 -2
- klaude_code/protocol/message.py +3 -2
- klaude_code/protocol/model.py +34 -2
- klaude_code/protocol/op.py +13 -0
- klaude_code/protocol/op_handler.py +5 -0
- klaude_code/protocol/sub_agent/AGENTS.md +5 -5
- klaude_code/protocol/sub_agent/__init__.py +13 -34
- klaude_code/protocol/sub_agent/explore.py +7 -34
- klaude_code/protocol/sub_agent/image_gen.py +3 -74
- klaude_code/protocol/sub_agent/task.py +3 -47
- klaude_code/protocol/sub_agent/web.py +8 -52
- klaude_code/protocol/tools.py +2 -0
- klaude_code/session/session.py +80 -22
- klaude_code/session/store.py +0 -4
- klaude_code/skill/assets/deslop/SKILL.md +9 -0
- klaude_code/skill/system_skills.py +0 -20
- klaude_code/tui/command/fork_session_cmd.py +5 -2
- klaude_code/tui/command/resume_cmd.py +9 -2
- klaude_code/tui/command/sub_agent_model_cmd.py +85 -18
- klaude_code/tui/components/assistant.py +0 -26
- klaude_code/tui/components/bash_syntax.py +4 -0
- klaude_code/tui/components/command_output.py +3 -1
- klaude_code/tui/components/developer.py +3 -0
- klaude_code/tui/components/diffs.py +4 -209
- klaude_code/tui/components/errors.py +4 -0
- klaude_code/tui/components/mermaid_viewer.py +2 -2
- klaude_code/tui/components/metadata.py +0 -3
- klaude_code/tui/components/rich/markdown.py +120 -87
- klaude_code/tui/components/rich/status.py +2 -2
- klaude_code/tui/components/rich/theme.py +11 -6
- klaude_code/tui/components/sub_agent.py +2 -46
- klaude_code/tui/components/thinking.py +0 -33
- klaude_code/tui/components/tools.py +65 -21
- klaude_code/tui/components/user_input.py +2 -0
- klaude_code/tui/input/images.py +21 -18
- klaude_code/tui/input/key_bindings.py +2 -2
- klaude_code/tui/input/prompt_toolkit.py +49 -49
- klaude_code/tui/machine.py +29 -47
- klaude_code/tui/renderer.py +48 -33
- klaude_code/tui/runner.py +2 -1
- klaude_code/tui/terminal/image.py +27 -34
- klaude_code/ui/common.py +0 -70
- {klaude_code-2.8.1.dist-info → klaude_code-2.9.1.dist-info}/METADATA +3 -6
- {klaude_code-2.8.1.dist-info → klaude_code-2.9.1.dist-info}/RECORD +103 -99
- klaude_code/core/tool/sub_agent_tool.py +0 -126
- klaude_code/llm/bedrock/__init__.py +0 -3
- klaude_code/llm/openai_compatible/tool_call_accumulator.py +0 -108
- klaude_code/tui/components/rich/searchable_text.py +0 -68
- /klaude_code/llm/{bedrock → bedrock_anthropic}/client.py +0 -0
- /klaude_code/llm/{responses → openai_responses}/__init__.py +0 -0
- {klaude_code-2.8.1.dist-info → klaude_code-2.9.1.dist-info}/WHEEL +0 -0
- {klaude_code-2.8.1.dist-info → klaude_code-2.9.1.dist-info}/entry_points.txt +0 -0
klaude_code/app/runtime.py
CHANGED
|
@@ -178,6 +178,7 @@ async def handle_keyboard_interrupt(executor: Executor) -> None:
|
|
|
178
178
|
log("Bye!")
|
|
179
179
|
session_id = executor.context.current_session_id()
|
|
180
180
|
if session_id and Session.exists(session_id):
|
|
181
|
-
|
|
181
|
+
short_id = Session.shortest_unique_prefix(session_id)
|
|
182
|
+
log(("Resume with:", "dim"), (f"klaude -r {short_id}", "green"))
|
|
182
183
|
with contextlib.suppress(Exception):
|
|
183
184
|
await executor.submit(op.InterruptOperation(target_session_id=None))
|
|
@@ -258,42 +258,46 @@ class AntigravityOAuth:
|
|
|
258
258
|
)
|
|
259
259
|
|
|
260
260
|
def refresh(self) -> AntigravityAuthState:
|
|
261
|
-
"""Refresh the access token using refresh token.
|
|
262
|
-
state = self.token_manager.get_state()
|
|
263
|
-
if state is None:
|
|
264
|
-
raise AntigravityNotLoggedInError("Not logged in to Antigravity. Run 'klaude login antigravity' first.")
|
|
261
|
+
"""Refresh the access token using refresh token with file locking.
|
|
265
262
|
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
"refresh_token": state.refresh_token,
|
|
270
|
-
"grant_type": "refresh_token",
|
|
271
|
-
}
|
|
263
|
+
Uses file locking to prevent multiple instances from refreshing simultaneously.
|
|
264
|
+
If another instance has already refreshed, returns the updated state.
|
|
265
|
+
"""
|
|
272
266
|
|
|
273
|
-
|
|
274
|
-
|
|
267
|
+
def do_refresh(current_state: AntigravityAuthState) -> AntigravityAuthState:
|
|
268
|
+
data = {
|
|
269
|
+
"client_id": CLIENT_ID,
|
|
270
|
+
"client_secret": CLIENT_SECRET,
|
|
271
|
+
"refresh_token": current_state.refresh_token,
|
|
272
|
+
"grant_type": "refresh_token",
|
|
273
|
+
}
|
|
275
274
|
|
|
276
|
-
|
|
277
|
-
|
|
275
|
+
with httpx.Client() as client:
|
|
276
|
+
response = client.post(TOKEN_URL, data=data, timeout=30)
|
|
278
277
|
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
refresh_token = tokens.get("refresh_token", state.refresh_token)
|
|
282
|
-
expires_in = tokens.get("expires_in", 3600)
|
|
278
|
+
if response.status_code != 200:
|
|
279
|
+
raise AntigravityTokenExpiredError(f"Token refresh failed: {response.text}")
|
|
283
280
|
|
|
284
|
-
|
|
285
|
-
|
|
281
|
+
tokens = response.json()
|
|
282
|
+
access_token = tokens["access_token"]
|
|
283
|
+
refresh_token = tokens.get("refresh_token", current_state.refresh_token)
|
|
284
|
+
expires_in = tokens.get("expires_in", 3600)
|
|
286
285
|
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
refresh_token=refresh_token,
|
|
290
|
-
expires_at=expires_at,
|
|
291
|
-
project_id=state.project_id,
|
|
292
|
-
email=state.email,
|
|
293
|
-
)
|
|
286
|
+
# Calculate expiry time with 5 minute buffer
|
|
287
|
+
expires_at = int(time.time()) + expires_in - 300
|
|
294
288
|
|
|
295
|
-
|
|
296
|
-
|
|
289
|
+
return AntigravityAuthState(
|
|
290
|
+
access_token=access_token,
|
|
291
|
+
refresh_token=refresh_token,
|
|
292
|
+
expires_at=expires_at,
|
|
293
|
+
project_id=current_state.project_id,
|
|
294
|
+
email=current_state.email,
|
|
295
|
+
)
|
|
296
|
+
|
|
297
|
+
try:
|
|
298
|
+
return self.token_manager.refresh_with_lock(do_refresh)
|
|
299
|
+
except ValueError as e:
|
|
300
|
+
raise AntigravityNotLoggedInError(str(e)) from e
|
|
297
301
|
|
|
298
302
|
def ensure_valid_token(self) -> tuple[str, str]:
|
|
299
303
|
"""Ensure we have a valid access token, refreshing if needed.
|
|
@@ -309,12 +313,3 @@ class AntigravityOAuth:
|
|
|
309
313
|
state = self.refresh()
|
|
310
314
|
|
|
311
315
|
return state.access_token, state.project_id
|
|
312
|
-
|
|
313
|
-
def get_api_key_json(self) -> str:
|
|
314
|
-
"""Get API key as JSON string for LLM client.
|
|
315
|
-
|
|
316
|
-
Returns:
|
|
317
|
-
JSON string with token and projectId.
|
|
318
|
-
"""
|
|
319
|
-
access_token, project_id = self.ensure_valid_token()
|
|
320
|
-
return json.dumps({"token": access_token, "projectId": project_id})
|
|
@@ -25,21 +25,3 @@ class AntigravityTokenManager(BaseTokenManager[AntigravityAuthState]):
|
|
|
25
25
|
|
|
26
26
|
def _create_state(self, data: dict[str, Any]) -> AntigravityAuthState:
|
|
27
27
|
return AntigravityAuthState.model_validate(data)
|
|
28
|
-
|
|
29
|
-
def get_access_token(self) -> str:
|
|
30
|
-
"""Get access token, raising if not logged in."""
|
|
31
|
-
state = self.get_state()
|
|
32
|
-
if state is None:
|
|
33
|
-
from klaude_code.auth.antigravity.exceptions import AntigravityNotLoggedInError
|
|
34
|
-
|
|
35
|
-
raise AntigravityNotLoggedInError("Not logged in to Antigravity. Run 'klaude login antigravity' first.")
|
|
36
|
-
return state.access_token
|
|
37
|
-
|
|
38
|
-
def get_project_id(self) -> str:
|
|
39
|
-
"""Get project ID, raising if not logged in."""
|
|
40
|
-
state = self.get_state()
|
|
41
|
-
if state is None:
|
|
42
|
-
from klaude_code.auth.antigravity.exceptions import AntigravityNotLoggedInError
|
|
43
|
-
|
|
44
|
-
raise AntigravityNotLoggedInError("Not logged in to Antigravity. Run 'klaude login antigravity' first.")
|
|
45
|
-
return state.project_id
|
klaude_code/auth/base.py
CHANGED
|
@@ -3,12 +3,15 @@
|
|
|
3
3
|
import json
|
|
4
4
|
import time
|
|
5
5
|
from abc import ABC, abstractmethod
|
|
6
|
+
from collections.abc import Callable
|
|
6
7
|
from pathlib import Path
|
|
7
8
|
from typing import Any, cast
|
|
8
9
|
|
|
10
|
+
from filelock import FileLock, Timeout
|
|
9
11
|
from pydantic import BaseModel
|
|
10
12
|
|
|
11
13
|
KLAUDE_AUTH_FILE = Path.home() / ".klaude" / "klaude-auth.json"
|
|
14
|
+
LOCK_TIMEOUT_SECONDS = 30 # Maximum time to wait for lock acquisition
|
|
12
15
|
|
|
13
16
|
|
|
14
17
|
class BaseAuthState(BaseModel):
|
|
@@ -99,3 +102,53 @@ class BaseTokenManager[T: BaseAuthState](ABC):
|
|
|
99
102
|
def clear_cached_state(self) -> None:
|
|
100
103
|
"""Clear in-memory cached state to force reload from file on next access."""
|
|
101
104
|
self._state = None
|
|
105
|
+
|
|
106
|
+
def _get_lock_file(self) -> Path:
|
|
107
|
+
"""Get the lock file path for this auth file."""
|
|
108
|
+
return self.auth_file.with_suffix(".lock")
|
|
109
|
+
|
|
110
|
+
def refresh_with_lock(self, refresh_fn: Callable[[T], T]) -> T:
|
|
111
|
+
"""Refresh token with file locking to prevent concurrent refresh.
|
|
112
|
+
|
|
113
|
+
This prevents multiple instances from simultaneously refreshing the same token.
|
|
114
|
+
If another instance has already refreshed, returns the updated state.
|
|
115
|
+
|
|
116
|
+
Args:
|
|
117
|
+
refresh_fn: Function that takes current state and returns new state.
|
|
118
|
+
|
|
119
|
+
Returns:
|
|
120
|
+
The new or already-refreshed authentication state.
|
|
121
|
+
|
|
122
|
+
Raises:
|
|
123
|
+
Timeout: If unable to acquire the lock within timeout.
|
|
124
|
+
ValueError: If not logged in.
|
|
125
|
+
"""
|
|
126
|
+
lock_file = self._get_lock_file()
|
|
127
|
+
lock = FileLock(lock_file, timeout=LOCK_TIMEOUT_SECONDS)
|
|
128
|
+
|
|
129
|
+
try:
|
|
130
|
+
with lock:
|
|
131
|
+
# Re-read file after acquiring lock - another instance may have refreshed
|
|
132
|
+
self.clear_cached_state()
|
|
133
|
+
state = self.load()
|
|
134
|
+
|
|
135
|
+
if state is None:
|
|
136
|
+
raise ValueError(f"Not logged in to {self.storage_key}")
|
|
137
|
+
|
|
138
|
+
# Check if token is still expired after re-reading
|
|
139
|
+
if not state.is_expired():
|
|
140
|
+
# Another instance already refreshed, use their result
|
|
141
|
+
return state
|
|
142
|
+
|
|
143
|
+
# Token still expired, we need to refresh
|
|
144
|
+
new_state = refresh_fn(state)
|
|
145
|
+
self.save(new_state)
|
|
146
|
+
return new_state
|
|
147
|
+
|
|
148
|
+
except Timeout:
|
|
149
|
+
# Lock timeout - try to re-read file in case another instance succeeded
|
|
150
|
+
self.clear_cached_state()
|
|
151
|
+
state = self.load()
|
|
152
|
+
if state and not state.is_expired():
|
|
153
|
+
return state
|
|
154
|
+
raise
|
klaude_code/auth/claude/oauth.py
CHANGED
|
@@ -125,60 +125,45 @@ class ClaudeOAuth:
|
|
|
125
125
|
expires_at=int(time.time()) + int(expires_in),
|
|
126
126
|
)
|
|
127
127
|
|
|
128
|
-
def _do_refresh_request(self, refresh_token: str) -> httpx.Response:
|
|
129
|
-
"""Send token refresh request to OAuth server."""
|
|
130
|
-
payload = {
|
|
131
|
-
"grant_type": "refresh_token",
|
|
132
|
-
"client_id": CLIENT_ID,
|
|
133
|
-
"refresh_token": refresh_token,
|
|
134
|
-
}
|
|
135
|
-
with httpx.Client() as client:
|
|
136
|
-
return client.post(
|
|
137
|
-
TOKEN_URL,
|
|
138
|
-
json=payload,
|
|
139
|
-
headers={"Content-Type": "application/json"},
|
|
140
|
-
)
|
|
141
|
-
|
|
142
128
|
def refresh(self) -> ClaudeAuthState:
|
|
143
|
-
"""Refresh the access token using refresh token.
|
|
129
|
+
"""Refresh the access token using refresh token with file locking.
|
|
144
130
|
|
|
145
|
-
|
|
146
|
-
|
|
131
|
+
Uses file locking to prevent multiple instances from refreshing simultaneously.
|
|
132
|
+
If another instance has already refreshed, returns the updated state.
|
|
147
133
|
"""
|
|
148
|
-
state = self.token_manager.get_state()
|
|
149
|
-
if state is None:
|
|
150
|
-
raise ClaudeNotLoggedInError("Not logged in to Claude. Run 'klaude login claude' first.")
|
|
151
134
|
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
135
|
+
def do_refresh(current_state: ClaudeAuthState) -> ClaudeAuthState:
|
|
136
|
+
payload = {
|
|
137
|
+
"grant_type": "refresh_token",
|
|
138
|
+
"client_id": CLIENT_ID,
|
|
139
|
+
"refresh_token": current_state.refresh_token,
|
|
140
|
+
}
|
|
141
|
+
|
|
142
|
+
with httpx.Client() as client:
|
|
143
|
+
response = client.post(
|
|
144
|
+
TOKEN_URL,
|
|
145
|
+
json=payload,
|
|
146
|
+
headers={"Content-Type": "application/json"},
|
|
147
|
+
)
|
|
148
|
+
|
|
149
|
+
if response.status_code != 200:
|
|
150
|
+
raise ClaudeAuthError(f"Token refresh failed: {response.text}")
|
|
151
|
+
|
|
152
|
+
tokens = response.json()
|
|
153
|
+
access_token = tokens["access_token"]
|
|
154
|
+
refresh_token = tokens.get("refresh_token", current_state.refresh_token)
|
|
155
|
+
expires_in = tokens.get("expires_in", 3600)
|
|
156
|
+
|
|
157
|
+
return ClaudeAuthState(
|
|
158
|
+
access_token=access_token,
|
|
159
|
+
refresh_token=refresh_token,
|
|
160
|
+
expires_at=int(time.time()) + int(expires_in),
|
|
161
|
+
)
|
|
174
162
|
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
)
|
|
180
|
-
self.token_manager.save(new_state)
|
|
181
|
-
return new_state
|
|
163
|
+
try:
|
|
164
|
+
return self.token_manager.refresh_with_lock(do_refresh)
|
|
165
|
+
except ValueError as e:
|
|
166
|
+
raise ClaudeNotLoggedInError(str(e)) from e
|
|
182
167
|
|
|
183
168
|
def ensure_valid_token(self) -> str:
|
|
184
169
|
"""Ensure we have a valid access token, refreshing if needed."""
|
klaude_code/auth/codex/oauth.py
CHANGED
|
@@ -177,43 +177,47 @@ class CodexOAuth:
|
|
|
177
177
|
)
|
|
178
178
|
|
|
179
179
|
def refresh(self) -> CodexAuthState:
|
|
180
|
-
"""Refresh the access token using refresh token.
|
|
181
|
-
state = self.token_manager.get_state()
|
|
182
|
-
if state is None:
|
|
183
|
-
from klaude_code.auth.codex.exceptions import CodexNotLoggedInError
|
|
180
|
+
"""Refresh the access token using refresh token with file locking.
|
|
184
181
|
|
|
185
|
-
|
|
182
|
+
Uses file locking to prevent multiple instances from refreshing simultaneously.
|
|
183
|
+
If another instance has already refreshed, returns the updated state.
|
|
184
|
+
"""
|
|
186
185
|
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
186
|
+
def do_refresh(current_state: CodexAuthState) -> CodexAuthState:
|
|
187
|
+
data = {
|
|
188
|
+
"grant_type": "refresh_token",
|
|
189
|
+
"client_id": CLIENT_ID,
|
|
190
|
+
"refresh_token": current_state.refresh_token,
|
|
191
|
+
}
|
|
192
192
|
|
|
193
|
-
|
|
194
|
-
|
|
193
|
+
with httpx.Client() as client:
|
|
194
|
+
response = client.post(TOKEN_URL, data=data)
|
|
195
195
|
|
|
196
|
-
|
|
197
|
-
|
|
196
|
+
if response.status_code != 200:
|
|
197
|
+
from klaude_code.auth.codex.exceptions import CodexTokenExpiredError
|
|
198
198
|
|
|
199
|
-
|
|
199
|
+
raise CodexTokenExpiredError(f"Token refresh failed: {response.text}")
|
|
200
200
|
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
201
|
+
tokens = response.json()
|
|
202
|
+
access_token = tokens["access_token"]
|
|
203
|
+
refresh_token = tokens.get("refresh_token", current_state.refresh_token)
|
|
204
|
+
expires_in = tokens.get("expires_in", 3600)
|
|
205
205
|
|
|
206
|
-
|
|
206
|
+
account_id = extract_account_id(access_token)
|
|
207
207
|
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
208
|
+
return CodexAuthState(
|
|
209
|
+
access_token=access_token,
|
|
210
|
+
refresh_token=refresh_token,
|
|
211
|
+
expires_at=int(time.time()) + expires_in,
|
|
212
|
+
account_id=account_id,
|
|
213
|
+
)
|
|
214
|
+
|
|
215
|
+
try:
|
|
216
|
+
return self.token_manager.refresh_with_lock(do_refresh)
|
|
217
|
+
except ValueError as e:
|
|
218
|
+
from klaude_code.auth.codex.exceptions import CodexNotLoggedInError
|
|
214
219
|
|
|
215
|
-
|
|
216
|
-
return new_state
|
|
220
|
+
raise CodexNotLoggedInError(str(e)) from e
|
|
217
221
|
|
|
218
222
|
def ensure_valid_token(self) -> str:
|
|
219
223
|
"""Ensure we have a valid access token, refreshing if needed."""
|
|
@@ -24,21 +24,3 @@ class CodexTokenManager(BaseTokenManager[CodexAuthState]):
|
|
|
24
24
|
|
|
25
25
|
def _create_state(self, data: dict[str, Any]) -> CodexAuthState:
|
|
26
26
|
return CodexAuthState.model_validate(data)
|
|
27
|
-
|
|
28
|
-
def get_access_token(self) -> str:
|
|
29
|
-
"""Get access token, raising if not logged in."""
|
|
30
|
-
state = self.get_state()
|
|
31
|
-
if state is None:
|
|
32
|
-
from klaude_code.auth.codex.exceptions import CodexNotLoggedInError
|
|
33
|
-
|
|
34
|
-
raise CodexNotLoggedInError("Not logged in to Codex. Run 'klaude login codex' first.")
|
|
35
|
-
return state.access_token
|
|
36
|
-
|
|
37
|
-
def get_account_id(self) -> str:
|
|
38
|
-
"""Get account ID, raising if not logged in."""
|
|
39
|
-
state = self.get_state()
|
|
40
|
-
if state is None:
|
|
41
|
-
from klaude_code.auth.codex.exceptions import CodexNotLoggedInError
|
|
42
|
-
|
|
43
|
-
raise CodexNotLoggedInError("Not logged in to Codex. Run 'klaude login codex' first.")
|
|
44
|
-
return state.account_id
|
klaude_code/cli/cost_cmd.py
CHANGED
|
@@ -34,6 +34,16 @@ class ModelUsageStats:
|
|
|
34
34
|
def total_tokens(self) -> int:
|
|
35
35
|
return self.input_tokens + self.output_tokens
|
|
36
36
|
|
|
37
|
+
@property
|
|
38
|
+
def non_cached_input_tokens(self) -> int:
|
|
39
|
+
"""Non-cached prompt tokens.
|
|
40
|
+
|
|
41
|
+
We store `input_tokens` as the provider-reported prompt token count, which
|
|
42
|
+
includes cached tokens for providers that support prompt caching.
|
|
43
|
+
"""
|
|
44
|
+
|
|
45
|
+
return max(0, self.input_tokens - self.cached_tokens)
|
|
46
|
+
|
|
37
47
|
def add_usage(self, usage: model.Usage) -> None:
|
|
38
48
|
self.input_tokens += usage.input_tokens
|
|
39
49
|
self.output_tokens += usage.output_tokens
|
|
@@ -48,41 +58,99 @@ class ModelUsageStats:
|
|
|
48
58
|
ModelKey = tuple[str, str] # (model_name, provider)
|
|
49
59
|
|
|
50
60
|
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
61
|
+
@dataclass
|
|
62
|
+
class SubProviderGroup:
|
|
63
|
+
"""Group of models under a sub-provider."""
|
|
64
|
+
|
|
65
|
+
name: str
|
|
66
|
+
models: list[ModelUsageStats]
|
|
67
|
+
total: ModelUsageStats
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
@dataclass
|
|
71
|
+
class ProviderGroup:
|
|
72
|
+
"""Group of models/sub-providers under a top-level provider."""
|
|
73
|
+
|
|
74
|
+
name: str
|
|
75
|
+
sub_providers: dict[str, SubProviderGroup] # empty if no sub-providers
|
|
76
|
+
models: list[ModelUsageStats] # direct models (when no sub-provider)
|
|
77
|
+
total: ModelUsageStats
|
|
78
|
+
|
|
55
79
|
|
|
56
|
-
|
|
80
|
+
def _sort_by_cost(stats: ModelUsageStats) -> tuple[float, float]:
|
|
81
|
+
return (-stats.cost_usd, -stats.cost_cny)
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
def group_models_by_provider(models: dict[ModelKey, ModelUsageStats]) -> dict[str, ProviderGroup]:
|
|
85
|
+
"""Group models by provider with three-level hierarchy.
|
|
86
|
+
|
|
87
|
+
Provider strings like "openrouter/Anthropic" are split into:
|
|
88
|
+
- Top-level: "openrouter"
|
|
89
|
+
- Sub-provider: "Anthropic"
|
|
90
|
+
|
|
91
|
+
Returns dict of ProviderGroup sorted by cost desc.
|
|
57
92
|
"""
|
|
58
|
-
|
|
59
|
-
provider_totals: dict[str, ModelUsageStats] = {}
|
|
93
|
+
provider_groups: dict[str, ProviderGroup] = {}
|
|
60
94
|
|
|
61
95
|
for stats in models.values():
|
|
62
|
-
|
|
63
|
-
if provider_key not in models_by_provider:
|
|
64
|
-
models_by_provider[provider_key] = []
|
|
65
|
-
provider_totals[provider_key] = ModelUsageStats(model_name=provider_key, provider=provider_key)
|
|
66
|
-
models_by_provider[provider_key].append(stats)
|
|
67
|
-
provider_totals[provider_key].input_tokens += stats.input_tokens
|
|
68
|
-
provider_totals[provider_key].output_tokens += stats.output_tokens
|
|
69
|
-
provider_totals[provider_key].cached_tokens += stats.cached_tokens
|
|
70
|
-
provider_totals[provider_key].cost_usd += stats.cost_usd
|
|
71
|
-
provider_totals[provider_key].cost_cny += stats.cost_cny
|
|
96
|
+
provider_raw = stats.provider or "(unknown)"
|
|
72
97
|
|
|
73
|
-
|
|
74
|
-
|
|
98
|
+
# Split provider by first "/"
|
|
99
|
+
if "/" in provider_raw:
|
|
100
|
+
parts = provider_raw.split("/", 1)
|
|
101
|
+
top_provider, sub_provider = parts[0], parts[1]
|
|
102
|
+
else:
|
|
103
|
+
top_provider, sub_provider = provider_raw, ""
|
|
104
|
+
|
|
105
|
+
# Initialize top-level provider group
|
|
106
|
+
if top_provider not in provider_groups:
|
|
107
|
+
provider_groups[top_provider] = ProviderGroup(
|
|
108
|
+
name=top_provider,
|
|
109
|
+
sub_providers={},
|
|
110
|
+
models=[],
|
|
111
|
+
total=ModelUsageStats(model_name=top_provider),
|
|
112
|
+
)
|
|
113
|
+
|
|
114
|
+
group = provider_groups[top_provider]
|
|
115
|
+
|
|
116
|
+
# Accumulate to top-level total
|
|
117
|
+
group.total.input_tokens += stats.input_tokens
|
|
118
|
+
group.total.output_tokens += stats.output_tokens
|
|
119
|
+
group.total.cached_tokens += stats.cached_tokens
|
|
120
|
+
group.total.cost_usd += stats.cost_usd
|
|
121
|
+
group.total.cost_cny += stats.cost_cny
|
|
122
|
+
|
|
123
|
+
if sub_provider:
|
|
124
|
+
# Has sub-provider, add to sub-provider group
|
|
125
|
+
if sub_provider not in group.sub_providers:
|
|
126
|
+
group.sub_providers[sub_provider] = SubProviderGroup(
|
|
127
|
+
name=sub_provider,
|
|
128
|
+
models=[],
|
|
129
|
+
total=ModelUsageStats(model_name=sub_provider),
|
|
130
|
+
)
|
|
131
|
+
sub_group = group.sub_providers[sub_provider]
|
|
132
|
+
sub_group.models.append(stats)
|
|
133
|
+
sub_group.total.input_tokens += stats.input_tokens
|
|
134
|
+
sub_group.total.output_tokens += stats.output_tokens
|
|
135
|
+
sub_group.total.cached_tokens += stats.cached_tokens
|
|
136
|
+
sub_group.total.cost_usd += stats.cost_usd
|
|
137
|
+
sub_group.total.cost_cny += stats.cost_cny
|
|
138
|
+
else:
|
|
139
|
+
# No sub-provider, add directly to models
|
|
140
|
+
group.models.append(stats)
|
|
75
141
|
|
|
76
|
-
# Sort
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
142
|
+
# Sort everything by cost
|
|
143
|
+
for group in provider_groups.values():
|
|
144
|
+
group.models.sort(key=_sort_by_cost)
|
|
145
|
+
for sub_group in group.sub_providers.values():
|
|
146
|
+
sub_group.models.sort(key=_sort_by_cost)
|
|
147
|
+
# Sort sub-providers by cost
|
|
148
|
+
group.sub_providers = dict(sorted(group.sub_providers.items(), key=lambda x: _sort_by_cost(x[1].total)))
|
|
80
149
|
|
|
81
|
-
#
|
|
82
|
-
|
|
83
|
-
sorted_provider_totals = {p: provider_totals[p] for p in sorted_providers}
|
|
150
|
+
# Sort top-level providers by cost
|
|
151
|
+
sorted_groups = dict(sorted(provider_groups.items(), key=lambda x: _sort_by_cost(x[1].total)))
|
|
84
152
|
|
|
85
|
-
return
|
|
153
|
+
return sorted_groups
|
|
86
154
|
|
|
87
155
|
|
|
88
156
|
@dataclass
|
|
@@ -223,8 +291,8 @@ def render_cost_table(daily_stats: dict[str, DailyStats]) -> Table:
|
|
|
223
291
|
table.add_column("Date", style="cyan")
|
|
224
292
|
table.add_column("Model", overflow="ellipsis")
|
|
225
293
|
table.add_column("Input", justify="right")
|
|
226
|
-
table.add_column("Output", justify="right")
|
|
227
294
|
table.add_column("Cache", justify="right")
|
|
295
|
+
table.add_column("Output", justify="right")
|
|
228
296
|
table.add_column("Total", justify="right")
|
|
229
297
|
table.add_column("USD", justify="right")
|
|
230
298
|
table.add_column("CNY", justify="right")
|
|
@@ -248,9 +316,9 @@ def render_cost_table(daily_stats: dict[str, DailyStats]) -> Table:
|
|
|
248
316
|
table.add_row(
|
|
249
317
|
date_label,
|
|
250
318
|
model_col,
|
|
251
|
-
fmt(format_tokens(stats.
|
|
252
|
-
fmt(format_tokens(stats.output_tokens)),
|
|
319
|
+
fmt(format_tokens(stats.non_cached_input_tokens)),
|
|
253
320
|
fmt(format_tokens(stats.cached_tokens)),
|
|
321
|
+
fmt(format_tokens(stats.output_tokens)),
|
|
254
322
|
fmt(format_tokens(stats.total_tokens)),
|
|
255
323
|
fmt(usd_str),
|
|
256
324
|
fmt(cny_str),
|
|
@@ -261,19 +329,40 @@ def render_cost_table(daily_stats: dict[str, DailyStats]) -> Table:
|
|
|
261
329
|
date_label: str = "",
|
|
262
330
|
show_subtotal: bool = True,
|
|
263
331
|
) -> None:
|
|
264
|
-
"""Render models grouped by provider with tree structure."""
|
|
265
|
-
|
|
332
|
+
"""Render models grouped by provider with three-level tree structure."""
|
|
333
|
+
provider_groups = group_models_by_provider(models)
|
|
266
334
|
|
|
267
335
|
first_row = True
|
|
268
|
-
for
|
|
269
|
-
|
|
270
|
-
add_stats_row(
|
|
336
|
+
for group in provider_groups.values():
|
|
337
|
+
# Top-level provider
|
|
338
|
+
add_stats_row(group.total, date_label=date_label if first_row else "", bold=True)
|
|
271
339
|
first_row = False
|
|
272
340
|
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
341
|
+
if group.sub_providers:
|
|
342
|
+
# Has sub-providers: render three-level tree
|
|
343
|
+
sub_list = list(group.sub_providers.values())
|
|
344
|
+
for sub_idx, sub_group in enumerate(sub_list):
|
|
345
|
+
is_last_sub = sub_idx == len(sub_list) - 1
|
|
346
|
+
sub_prefix = " └─ " if is_last_sub else " ├─ "
|
|
347
|
+
|
|
348
|
+
# Sub-provider row
|
|
349
|
+
add_stats_row(sub_group.total, prefix=sub_prefix, bold=True)
|
|
350
|
+
|
|
351
|
+
# Models under sub-provider
|
|
352
|
+
for model_idx, stats in enumerate(sub_group.models):
|
|
353
|
+
is_last_model = model_idx == len(sub_group.models) - 1
|
|
354
|
+
# Indent based on whether sub-provider is last
|
|
355
|
+
if is_last_sub:
|
|
356
|
+
model_prefix = " └─ " if is_last_model else " ├─ "
|
|
357
|
+
else:
|
|
358
|
+
model_prefix = " │ └─ " if is_last_model else " │ ├─ "
|
|
359
|
+
add_stats_row(stats, prefix=model_prefix)
|
|
360
|
+
else:
|
|
361
|
+
# No sub-providers: render two-level tree (direct models)
|
|
362
|
+
for model_idx, stats in enumerate(group.models):
|
|
363
|
+
is_last_model = model_idx == len(group.models) - 1
|
|
364
|
+
model_prefix = " └─ " if is_last_model else " ├─ "
|
|
365
|
+
add_stats_row(stats, prefix=model_prefix)
|
|
277
366
|
|
|
278
367
|
if show_subtotal:
|
|
279
368
|
subtotal = ModelUsageStats(model_name="(subtotal)")
|