klaude-code 2.8.1__py3-none-any.whl → 2.9.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (107)
  1. klaude_code/app/runtime.py +2 -1
  2. klaude_code/auth/antigravity/oauth.py +33 -38
  3. klaude_code/auth/antigravity/token_manager.py +0 -18
  4. klaude_code/auth/base.py +53 -0
  5. klaude_code/auth/claude/oauth.py +34 -49
  6. klaude_code/auth/codex/exceptions.py +0 -4
  7. klaude_code/auth/codex/oauth.py +32 -28
  8. klaude_code/auth/codex/token_manager.py +0 -18
  9. klaude_code/cli/cost_cmd.py +128 -39
  10. klaude_code/cli/list_model.py +27 -10
  11. klaude_code/cli/main.py +14 -3
  12. klaude_code/config/assets/builtin_config.yaml +25 -24
  13. klaude_code/config/config.py +47 -25
  14. klaude_code/config/sub_agent_model_helper.py +18 -13
  15. klaude_code/config/thinking.py +0 -8
  16. klaude_code/const.py +1 -1
  17. klaude_code/core/agent_profile.py +11 -56
  18. klaude_code/core/compaction/overflow.py +0 -4
  19. klaude_code/core/executor.py +33 -5
  20. klaude_code/core/manager/llm_clients.py +9 -1
  21. klaude_code/core/prompts/prompt-claude-code.md +4 -4
  22. klaude_code/core/reminders.py +21 -23
  23. klaude_code/core/task.py +1 -5
  24. klaude_code/core/tool/__init__.py +3 -2
  25. klaude_code/core/tool/file/apply_patch.py +0 -27
  26. klaude_code/core/tool/file/read_tool.md +3 -2
  27. klaude_code/core/tool/file/read_tool.py +27 -3
  28. klaude_code/core/tool/offload.py +0 -35
  29. klaude_code/core/tool/shell/bash_tool.py +1 -1
  30. klaude_code/core/tool/sub_agent/__init__.py +6 -0
  31. klaude_code/core/tool/sub_agent/image_gen.md +16 -0
  32. klaude_code/core/tool/sub_agent/image_gen.py +146 -0
  33. klaude_code/core/tool/sub_agent/task.md +20 -0
  34. klaude_code/core/tool/sub_agent/task.py +205 -0
  35. klaude_code/core/tool/tool_registry.py +0 -16
  36. klaude_code/core/turn.py +1 -1
  37. klaude_code/llm/anthropic/input.py +6 -5
  38. klaude_code/llm/antigravity/input.py +14 -7
  39. klaude_code/llm/bedrock_anthropic/__init__.py +3 -0
  40. klaude_code/llm/google/client.py +8 -6
  41. klaude_code/llm/google/input.py +20 -12
  42. klaude_code/llm/image.py +18 -11
  43. klaude_code/llm/input_common.py +32 -6
  44. klaude_code/llm/json_stable.py +37 -0
  45. klaude_code/llm/{codex → openai_codex}/__init__.py +1 -1
  46. klaude_code/llm/{codex → openai_codex}/client.py +24 -2
  47. klaude_code/llm/openai_codex/prompt_sync.py +237 -0
  48. klaude_code/llm/openai_compatible/client.py +3 -1
  49. klaude_code/llm/openai_compatible/input.py +0 -10
  50. klaude_code/llm/openai_compatible/stream.py +35 -10
  51. klaude_code/llm/{responses → openai_responses}/client.py +1 -1
  52. klaude_code/llm/{responses → openai_responses}/input.py +15 -5
  53. klaude_code/llm/registry.py +3 -8
  54. klaude_code/llm/stream_parts.py +3 -1
  55. klaude_code/llm/usage.py +1 -9
  56. klaude_code/protocol/events.py +2 -2
  57. klaude_code/protocol/message.py +3 -2
  58. klaude_code/protocol/model.py +34 -2
  59. klaude_code/protocol/op.py +13 -0
  60. klaude_code/protocol/op_handler.py +5 -0
  61. klaude_code/protocol/sub_agent/AGENTS.md +5 -5
  62. klaude_code/protocol/sub_agent/__init__.py +13 -34
  63. klaude_code/protocol/sub_agent/explore.py +7 -34
  64. klaude_code/protocol/sub_agent/image_gen.py +3 -74
  65. klaude_code/protocol/sub_agent/task.py +3 -47
  66. klaude_code/protocol/sub_agent/web.py +8 -52
  67. klaude_code/protocol/tools.py +2 -0
  68. klaude_code/session/session.py +80 -22
  69. klaude_code/session/store.py +0 -4
  70. klaude_code/skill/assets/deslop/SKILL.md +9 -0
  71. klaude_code/skill/system_skills.py +0 -20
  72. klaude_code/tui/command/fork_session_cmd.py +5 -2
  73. klaude_code/tui/command/resume_cmd.py +9 -2
  74. klaude_code/tui/command/sub_agent_model_cmd.py +85 -18
  75. klaude_code/tui/components/assistant.py +0 -26
  76. klaude_code/tui/components/bash_syntax.py +4 -0
  77. klaude_code/tui/components/command_output.py +3 -1
  78. klaude_code/tui/components/developer.py +3 -0
  79. klaude_code/tui/components/diffs.py +4 -209
  80. klaude_code/tui/components/errors.py +4 -0
  81. klaude_code/tui/components/mermaid_viewer.py +2 -2
  82. klaude_code/tui/components/metadata.py +0 -3
  83. klaude_code/tui/components/rich/markdown.py +120 -87
  84. klaude_code/tui/components/rich/status.py +2 -2
  85. klaude_code/tui/components/rich/theme.py +11 -6
  86. klaude_code/tui/components/sub_agent.py +2 -46
  87. klaude_code/tui/components/thinking.py +0 -33
  88. klaude_code/tui/components/tools.py +65 -21
  89. klaude_code/tui/components/user_input.py +2 -0
  90. klaude_code/tui/input/images.py +21 -18
  91. klaude_code/tui/input/key_bindings.py +2 -2
  92. klaude_code/tui/input/prompt_toolkit.py +49 -49
  93. klaude_code/tui/machine.py +29 -47
  94. klaude_code/tui/renderer.py +48 -33
  95. klaude_code/tui/runner.py +2 -1
  96. klaude_code/tui/terminal/image.py +27 -34
  97. klaude_code/ui/common.py +0 -70
  98. {klaude_code-2.8.1.dist-info → klaude_code-2.9.1.dist-info}/METADATA +3 -6
  99. {klaude_code-2.8.1.dist-info → klaude_code-2.9.1.dist-info}/RECORD +103 -99
  100. klaude_code/core/tool/sub_agent_tool.py +0 -126
  101. klaude_code/llm/bedrock/__init__.py +0 -3
  102. klaude_code/llm/openai_compatible/tool_call_accumulator.py +0 -108
  103. klaude_code/tui/components/rich/searchable_text.py +0 -68
  104. /klaude_code/llm/{bedrock → bedrock_anthropic}/client.py +0 -0
  105. /klaude_code/llm/{responses → openai_responses}/__init__.py +0 -0
  106. {klaude_code-2.8.1.dist-info → klaude_code-2.9.1.dist-info}/WHEEL +0 -0
  107. {klaude_code-2.8.1.dist-info → klaude_code-2.9.1.dist-info}/entry_points.txt +0 -0
@@ -178,6 +178,7 @@ async def handle_keyboard_interrupt(executor: Executor) -> None:
178
178
  log("Bye!")
179
179
  session_id = executor.context.current_session_id()
180
180
  if session_id and Session.exists(session_id):
181
- log(("Resume with:", "dim"), (f"klaude --resume {session_id}", "green"))
181
+ short_id = Session.shortest_unique_prefix(session_id)
182
+ log(("Resume with:", "dim"), (f"klaude -r {short_id}", "green"))
182
183
  with contextlib.suppress(Exception):
183
184
  await executor.submit(op.InterruptOperation(target_session_id=None))
@@ -258,42 +258,46 @@ class AntigravityOAuth:
258
258
  )
259
259
 
260
260
  def refresh(self) -> AntigravityAuthState:
261
- """Refresh the access token using refresh token."""
262
- state = self.token_manager.get_state()
263
- if state is None:
264
- raise AntigravityNotLoggedInError("Not logged in to Antigravity. Run 'klaude login antigravity' first.")
261
+ """Refresh the access token using refresh token with file locking.
265
262
 
266
- data = {
267
- "client_id": CLIENT_ID,
268
- "client_secret": CLIENT_SECRET,
269
- "refresh_token": state.refresh_token,
270
- "grant_type": "refresh_token",
271
- }
263
+ Uses file locking to prevent multiple instances from refreshing simultaneously.
264
+ If another instance has already refreshed, returns the updated state.
265
+ """
272
266
 
273
- with httpx.Client() as client:
274
- response = client.post(TOKEN_URL, data=data, timeout=30)
267
+ def do_refresh(current_state: AntigravityAuthState) -> AntigravityAuthState:
268
+ data = {
269
+ "client_id": CLIENT_ID,
270
+ "client_secret": CLIENT_SECRET,
271
+ "refresh_token": current_state.refresh_token,
272
+ "grant_type": "refresh_token",
273
+ }
275
274
 
276
- if response.status_code != 200:
277
- raise AntigravityTokenExpiredError(f"Token refresh failed: {response.text}")
275
+ with httpx.Client() as client:
276
+ response = client.post(TOKEN_URL, data=data, timeout=30)
278
277
 
279
- tokens = response.json()
280
- access_token = tokens["access_token"]
281
- refresh_token = tokens.get("refresh_token", state.refresh_token)
282
- expires_in = tokens.get("expires_in", 3600)
278
+ if response.status_code != 200:
279
+ raise AntigravityTokenExpiredError(f"Token refresh failed: {response.text}")
283
280
 
284
- # Calculate expiry time with 5 minute buffer
285
- expires_at = int(time.time()) + expires_in - 300
281
+ tokens = response.json()
282
+ access_token = tokens["access_token"]
283
+ refresh_token = tokens.get("refresh_token", current_state.refresh_token)
284
+ expires_in = tokens.get("expires_in", 3600)
286
285
 
287
- new_state = AntigravityAuthState(
288
- access_token=access_token,
289
- refresh_token=refresh_token,
290
- expires_at=expires_at,
291
- project_id=state.project_id,
292
- email=state.email,
293
- )
286
+ # Calculate expiry time with 5 minute buffer
287
+ expires_at = int(time.time()) + expires_in - 300
294
288
 
295
- self.token_manager.save(new_state)
296
- return new_state
289
+ return AntigravityAuthState(
290
+ access_token=access_token,
291
+ refresh_token=refresh_token,
292
+ expires_at=expires_at,
293
+ project_id=current_state.project_id,
294
+ email=current_state.email,
295
+ )
296
+
297
+ try:
298
+ return self.token_manager.refresh_with_lock(do_refresh)
299
+ except ValueError as e:
300
+ raise AntigravityNotLoggedInError(str(e)) from e
297
301
 
298
302
  def ensure_valid_token(self) -> tuple[str, str]:
299
303
  """Ensure we have a valid access token, refreshing if needed.
@@ -309,12 +313,3 @@ class AntigravityOAuth:
309
313
  state = self.refresh()
310
314
 
311
315
  return state.access_token, state.project_id
312
-
313
- def get_api_key_json(self) -> str:
314
- """Get API key as JSON string for LLM client.
315
-
316
- Returns:
317
- JSON string with token and projectId.
318
- """
319
- access_token, project_id = self.ensure_valid_token()
320
- return json.dumps({"token": access_token, "projectId": project_id})
@@ -25,21 +25,3 @@ class AntigravityTokenManager(BaseTokenManager[AntigravityAuthState]):
25
25
 
26
26
  def _create_state(self, data: dict[str, Any]) -> AntigravityAuthState:
27
27
  return AntigravityAuthState.model_validate(data)
28
-
29
- def get_access_token(self) -> str:
30
- """Get access token, raising if not logged in."""
31
- state = self.get_state()
32
- if state is None:
33
- from klaude_code.auth.antigravity.exceptions import AntigravityNotLoggedInError
34
-
35
- raise AntigravityNotLoggedInError("Not logged in to Antigravity. Run 'klaude login antigravity' first.")
36
- return state.access_token
37
-
38
- def get_project_id(self) -> str:
39
- """Get project ID, raising if not logged in."""
40
- state = self.get_state()
41
- if state is None:
42
- from klaude_code.auth.antigravity.exceptions import AntigravityNotLoggedInError
43
-
44
- raise AntigravityNotLoggedInError("Not logged in to Antigravity. Run 'klaude login antigravity' first.")
45
- return state.project_id
klaude_code/auth/base.py CHANGED
@@ -3,12 +3,15 @@
3
3
  import json
4
4
  import time
5
5
  from abc import ABC, abstractmethod
6
+ from collections.abc import Callable
6
7
  from pathlib import Path
7
8
  from typing import Any, cast
8
9
 
10
+ from filelock import FileLock, Timeout
9
11
  from pydantic import BaseModel
10
12
 
11
13
  KLAUDE_AUTH_FILE = Path.home() / ".klaude" / "klaude-auth.json"
14
+ LOCK_TIMEOUT_SECONDS = 30 # Maximum time to wait for lock acquisition
12
15
 
13
16
 
14
17
  class BaseAuthState(BaseModel):
@@ -99,3 +102,53 @@ class BaseTokenManager[T: BaseAuthState](ABC):
99
102
  def clear_cached_state(self) -> None:
100
103
  """Clear in-memory cached state to force reload from file on next access."""
101
104
  self._state = None
105
+
106
+ def _get_lock_file(self) -> Path:
107
+ """Get the lock file path for this auth file."""
108
+ return self.auth_file.with_suffix(".lock")
109
+
110
+ def refresh_with_lock(self, refresh_fn: Callable[[T], T]) -> T:
111
+ """Refresh token with file locking to prevent concurrent refresh.
112
+
113
+ This prevents multiple instances from simultaneously refreshing the same token.
114
+ If another instance has already refreshed, returns the updated state.
115
+
116
+ Args:
117
+ refresh_fn: Function that takes current state and returns new state.
118
+
119
+ Returns:
120
+ The new or already-refreshed authentication state.
121
+
122
+ Raises:
123
+ Timeout: If unable to acquire the lock within timeout.
124
+ ValueError: If not logged in.
125
+ """
126
+ lock_file = self._get_lock_file()
127
+ lock = FileLock(lock_file, timeout=LOCK_TIMEOUT_SECONDS)
128
+
129
+ try:
130
+ with lock:
131
+ # Re-read file after acquiring lock - another instance may have refreshed
132
+ self.clear_cached_state()
133
+ state = self.load()
134
+
135
+ if state is None:
136
+ raise ValueError(f"Not logged in to {self.storage_key}")
137
+
138
+ # Check if token is still expired after re-reading
139
+ if not state.is_expired():
140
+ # Another instance already refreshed, use their result
141
+ return state
142
+
143
+ # Token still expired, we need to refresh
144
+ new_state = refresh_fn(state)
145
+ self.save(new_state)
146
+ return new_state
147
+
148
+ except Timeout:
149
+ # Lock timeout - try to re-read file in case another instance succeeded
150
+ self.clear_cached_state()
151
+ state = self.load()
152
+ if state and not state.is_expired():
153
+ return state
154
+ raise
@@ -125,60 +125,45 @@ class ClaudeOAuth:
125
125
  expires_at=int(time.time()) + int(expires_in),
126
126
  )
127
127
 
128
- def _do_refresh_request(self, refresh_token: str) -> httpx.Response:
129
- """Send token refresh request to OAuth server."""
130
- payload = {
131
- "grant_type": "refresh_token",
132
- "client_id": CLIENT_ID,
133
- "refresh_token": refresh_token,
134
- }
135
- with httpx.Client() as client:
136
- return client.post(
137
- TOKEN_URL,
138
- json=payload,
139
- headers={"Content-Type": "application/json"},
140
- )
141
-
142
128
  def refresh(self) -> ClaudeAuthState:
143
- """Refresh the access token using refresh token.
129
+ """Refresh the access token using refresh token with file locking.
144
130
 
145
- Handles concurrent refresh race conditions by retrying with freshly loaded token
146
- if the first attempt fails with invalid_grant error.
131
+ Uses file locking to prevent multiple instances from refreshing simultaneously.
132
+ If another instance has already refreshed, returns the updated state.
147
133
  """
148
- state = self.token_manager.get_state()
149
- if state is None:
150
- raise ClaudeNotLoggedInError("Not logged in to Claude. Run 'klaude login claude' first.")
151
134
 
152
- response = self._do_refresh_request(state.refresh_token)
153
-
154
- # Handle race condition: another process may have refreshed the token already
155
- if response.status_code != 200 and "invalid_grant" in response.text:
156
- # Reload token from file (another process may have updated it)
157
- self.token_manager.clear_cached_state()
158
- fresh_state = self.token_manager.load()
159
- if fresh_state and fresh_state.refresh_token != state.refresh_token:
160
- # Token was updated by another process
161
- if not fresh_state.is_expired():
162
- # New token is still valid, use it directly
163
- return fresh_state
164
- # New token expired, try refreshing with the new refresh_token
165
- response = self._do_refresh_request(fresh_state.refresh_token)
166
-
167
- if response.status_code != 200:
168
- raise ClaudeAuthError(f"Token refresh failed: {response.text}")
169
-
170
- tokens = response.json()
171
- access_token = tokens["access_token"]
172
- refresh_token = tokens.get("refresh_token", state.refresh_token)
173
- expires_in = tokens.get("expires_in", 3600)
135
+ def do_refresh(current_state: ClaudeAuthState) -> ClaudeAuthState:
136
+ payload = {
137
+ "grant_type": "refresh_token",
138
+ "client_id": CLIENT_ID,
139
+ "refresh_token": current_state.refresh_token,
140
+ }
141
+
142
+ with httpx.Client() as client:
143
+ response = client.post(
144
+ TOKEN_URL,
145
+ json=payload,
146
+ headers={"Content-Type": "application/json"},
147
+ )
148
+
149
+ if response.status_code != 200:
150
+ raise ClaudeAuthError(f"Token refresh failed: {response.text}")
151
+
152
+ tokens = response.json()
153
+ access_token = tokens["access_token"]
154
+ refresh_token = tokens.get("refresh_token", current_state.refresh_token)
155
+ expires_in = tokens.get("expires_in", 3600)
156
+
157
+ return ClaudeAuthState(
158
+ access_token=access_token,
159
+ refresh_token=refresh_token,
160
+ expires_at=int(time.time()) + int(expires_in),
161
+ )
174
162
 
175
- new_state = ClaudeAuthState(
176
- access_token=access_token,
177
- refresh_token=refresh_token,
178
- expires_at=int(time.time()) + int(expires_in),
179
- )
180
- self.token_manager.save(new_state)
181
- return new_state
163
+ try:
164
+ return self.token_manager.refresh_with_lock(do_refresh)
165
+ except ValueError as e:
166
+ raise ClaudeNotLoggedInError(str(e)) from e
182
167
 
183
168
  def ensure_valid_token(self) -> str:
184
169
  """Ensure we have a valid access token, refreshing if needed."""
@@ -15,7 +15,3 @@ class CodexTokenExpiredError(CodexAuthError):
15
15
 
16
16
  class CodexOAuthError(CodexAuthError):
17
17
  """OAuth flow failed."""
18
-
19
-
20
- class CodexUnsupportedModelError(CodexAuthError):
21
- """Model is not supported by codex_oauth protocol."""
@@ -177,43 +177,47 @@ class CodexOAuth:
177
177
  )
178
178
 
179
179
  def refresh(self) -> CodexAuthState:
180
- """Refresh the access token using refresh token."""
181
- state = self.token_manager.get_state()
182
- if state is None:
183
- from klaude_code.auth.codex.exceptions import CodexNotLoggedInError
180
+ """Refresh the access token using refresh token with file locking.
184
181
 
185
- raise CodexNotLoggedInError("Not logged in to Codex. Run 'klaude login codex' first.")
182
+ Uses file locking to prevent multiple instances from refreshing simultaneously.
183
+ If another instance has already refreshed, returns the updated state.
184
+ """
186
185
 
187
- data = {
188
- "grant_type": "refresh_token",
189
- "client_id": CLIENT_ID,
190
- "refresh_token": state.refresh_token,
191
- }
186
+ def do_refresh(current_state: CodexAuthState) -> CodexAuthState:
187
+ data = {
188
+ "grant_type": "refresh_token",
189
+ "client_id": CLIENT_ID,
190
+ "refresh_token": current_state.refresh_token,
191
+ }
192
192
 
193
- with httpx.Client() as client:
194
- response = client.post(TOKEN_URL, data=data)
193
+ with httpx.Client() as client:
194
+ response = client.post(TOKEN_URL, data=data)
195
195
 
196
- if response.status_code != 200:
197
- from klaude_code.auth.codex.exceptions import CodexTokenExpiredError
196
+ if response.status_code != 200:
197
+ from klaude_code.auth.codex.exceptions import CodexTokenExpiredError
198
198
 
199
- raise CodexTokenExpiredError(f"Token refresh failed: {response.text}")
199
+ raise CodexTokenExpiredError(f"Token refresh failed: {response.text}")
200
200
 
201
- tokens = response.json()
202
- access_token = tokens["access_token"]
203
- refresh_token = tokens.get("refresh_token", state.refresh_token)
204
- expires_in = tokens.get("expires_in", 3600)
201
+ tokens = response.json()
202
+ access_token = tokens["access_token"]
203
+ refresh_token = tokens.get("refresh_token", current_state.refresh_token)
204
+ expires_in = tokens.get("expires_in", 3600)
205
205
 
206
- account_id = extract_account_id(access_token)
206
+ account_id = extract_account_id(access_token)
207
207
 
208
- new_state = CodexAuthState(
209
- access_token=access_token,
210
- refresh_token=refresh_token,
211
- expires_at=int(time.time()) + expires_in,
212
- account_id=account_id,
213
- )
208
+ return CodexAuthState(
209
+ access_token=access_token,
210
+ refresh_token=refresh_token,
211
+ expires_at=int(time.time()) + expires_in,
212
+ account_id=account_id,
213
+ )
214
+
215
+ try:
216
+ return self.token_manager.refresh_with_lock(do_refresh)
217
+ except ValueError as e:
218
+ from klaude_code.auth.codex.exceptions import CodexNotLoggedInError
214
219
 
215
- self.token_manager.save(new_state)
216
- return new_state
220
+ raise CodexNotLoggedInError(str(e)) from e
217
221
 
218
222
  def ensure_valid_token(self) -> str:
219
223
  """Ensure we have a valid access token, refreshing if needed."""
@@ -24,21 +24,3 @@ class CodexTokenManager(BaseTokenManager[CodexAuthState]):
24
24
 
25
25
  def _create_state(self, data: dict[str, Any]) -> CodexAuthState:
26
26
  return CodexAuthState.model_validate(data)
27
-
28
- def get_access_token(self) -> str:
29
- """Get access token, raising if not logged in."""
30
- state = self.get_state()
31
- if state is None:
32
- from klaude_code.auth.codex.exceptions import CodexNotLoggedInError
33
-
34
- raise CodexNotLoggedInError("Not logged in to Codex. Run 'klaude login codex' first.")
35
- return state.access_token
36
-
37
- def get_account_id(self) -> str:
38
- """Get account ID, raising if not logged in."""
39
- state = self.get_state()
40
- if state is None:
41
- from klaude_code.auth.codex.exceptions import CodexNotLoggedInError
42
-
43
- raise CodexNotLoggedInError("Not logged in to Codex. Run 'klaude login codex' first.")
44
- return state.account_id
@@ -34,6 +34,16 @@ class ModelUsageStats:
34
34
  def total_tokens(self) -> int:
35
35
  return self.input_tokens + self.output_tokens
36
36
 
37
+ @property
38
+ def non_cached_input_tokens(self) -> int:
39
+ """Non-cached prompt tokens.
40
+
41
+ We store `input_tokens` as the provider-reported prompt token count, which
42
+ includes cached tokens for providers that support prompt caching.
43
+ """
44
+
45
+ return max(0, self.input_tokens - self.cached_tokens)
46
+
37
47
  def add_usage(self, usage: model.Usage) -> None:
38
48
  self.input_tokens += usage.input_tokens
39
49
  self.output_tokens += usage.output_tokens
@@ -48,41 +58,99 @@ class ModelUsageStats:
48
58
  ModelKey = tuple[str, str] # (model_name, provider)
49
59
 
50
60
 
51
- def group_models_by_provider(
52
- models: dict[ModelKey, ModelUsageStats],
53
- ) -> tuple[dict[str, list[ModelUsageStats]], dict[str, ModelUsageStats]]:
54
- """Group models by provider and compute provider totals.
61
+ @dataclass
62
+ class SubProviderGroup:
63
+ """Group of models under a sub-provider."""
64
+
65
+ name: str
66
+ models: list[ModelUsageStats]
67
+ total: ModelUsageStats
68
+
69
+
70
+ @dataclass
71
+ class ProviderGroup:
72
+ """Group of models/sub-providers under a top-level provider."""
73
+
74
+ name: str
75
+ sub_providers: dict[str, SubProviderGroup] # empty if no sub-providers
76
+ models: list[ModelUsageStats] # direct models (when no sub-provider)
77
+ total: ModelUsageStats
78
+
55
79
 
56
- Returns (models_by_provider, provider_totals) where both are sorted by cost desc.
80
+ def _sort_by_cost(stats: ModelUsageStats) -> tuple[float, float]:
81
+ return (-stats.cost_usd, -stats.cost_cny)
82
+
83
+
84
+ def group_models_by_provider(models: dict[ModelKey, ModelUsageStats]) -> dict[str, ProviderGroup]:
85
+ """Group models by provider with three-level hierarchy.
86
+
87
+ Provider strings like "openrouter/Anthropic" are split into:
88
+ - Top-level: "openrouter"
89
+ - Sub-provider: "Anthropic"
90
+
91
+ Returns dict of ProviderGroup sorted by cost desc.
57
92
  """
58
- models_by_provider: dict[str, list[ModelUsageStats]] = {}
59
- provider_totals: dict[str, ModelUsageStats] = {}
93
+ provider_groups: dict[str, ProviderGroup] = {}
60
94
 
61
95
  for stats in models.values():
62
- provider_key = stats.provider or "(unknown)"
63
- if provider_key not in models_by_provider:
64
- models_by_provider[provider_key] = []
65
- provider_totals[provider_key] = ModelUsageStats(model_name=provider_key, provider=provider_key)
66
- models_by_provider[provider_key].append(stats)
67
- provider_totals[provider_key].input_tokens += stats.input_tokens
68
- provider_totals[provider_key].output_tokens += stats.output_tokens
69
- provider_totals[provider_key].cached_tokens += stats.cached_tokens
70
- provider_totals[provider_key].cost_usd += stats.cost_usd
71
- provider_totals[provider_key].cost_cny += stats.cost_cny
96
+ provider_raw = stats.provider or "(unknown)"
72
97
 
73
- def sort_by_cost(stats: ModelUsageStats) -> tuple[float, float]:
74
- return (-stats.cost_usd, -stats.cost_cny)
98
+ # Split provider by first "/"
99
+ if "/" in provider_raw:
100
+ parts = provider_raw.split("/", 1)
101
+ top_provider, sub_provider = parts[0], parts[1]
102
+ else:
103
+ top_provider, sub_provider = provider_raw, ""
104
+
105
+ # Initialize top-level provider group
106
+ if top_provider not in provider_groups:
107
+ provider_groups[top_provider] = ProviderGroup(
108
+ name=top_provider,
109
+ sub_providers={},
110
+ models=[],
111
+ total=ModelUsageStats(model_name=top_provider),
112
+ )
113
+
114
+ group = provider_groups[top_provider]
115
+
116
+ # Accumulate to top-level total
117
+ group.total.input_tokens += stats.input_tokens
118
+ group.total.output_tokens += stats.output_tokens
119
+ group.total.cached_tokens += stats.cached_tokens
120
+ group.total.cost_usd += stats.cost_usd
121
+ group.total.cost_cny += stats.cost_cny
122
+
123
+ if sub_provider:
124
+ # Has sub-provider, add to sub-provider group
125
+ if sub_provider not in group.sub_providers:
126
+ group.sub_providers[sub_provider] = SubProviderGroup(
127
+ name=sub_provider,
128
+ models=[],
129
+ total=ModelUsageStats(model_name=sub_provider),
130
+ )
131
+ sub_group = group.sub_providers[sub_provider]
132
+ sub_group.models.append(stats)
133
+ sub_group.total.input_tokens += stats.input_tokens
134
+ sub_group.total.output_tokens += stats.output_tokens
135
+ sub_group.total.cached_tokens += stats.cached_tokens
136
+ sub_group.total.cost_usd += stats.cost_usd
137
+ sub_group.total.cost_cny += stats.cost_cny
138
+ else:
139
+ # No sub-provider, add directly to models
140
+ group.models.append(stats)
75
141
 
76
- # Sort providers by cost, and models within each provider
77
- sorted_providers = sorted(provider_totals.keys(), key=lambda p: sort_by_cost(provider_totals[p]))
78
- for provider_key in models_by_provider:
79
- models_by_provider[provider_key].sort(key=sort_by_cost)
142
+ # Sort everything by cost
143
+ for group in provider_groups.values():
144
+ group.models.sort(key=_sort_by_cost)
145
+ for sub_group in group.sub_providers.values():
146
+ sub_group.models.sort(key=_sort_by_cost)
147
+ # Sort sub-providers by cost
148
+ group.sub_providers = dict(sorted(group.sub_providers.items(), key=lambda x: _sort_by_cost(x[1].total)))
80
149
 
81
- # Rebuild dicts in sorted order
82
- sorted_models_by_provider = {p: models_by_provider[p] for p in sorted_providers}
83
- sorted_provider_totals = {p: provider_totals[p] for p in sorted_providers}
150
+ # Sort top-level providers by cost
151
+ sorted_groups = dict(sorted(provider_groups.items(), key=lambda x: _sort_by_cost(x[1].total)))
84
152
 
85
- return sorted_models_by_provider, sorted_provider_totals
153
+ return sorted_groups
86
154
 
87
155
 
88
156
  @dataclass
@@ -223,8 +291,8 @@ def render_cost_table(daily_stats: dict[str, DailyStats]) -> Table:
223
291
  table.add_column("Date", style="cyan")
224
292
  table.add_column("Model", overflow="ellipsis")
225
293
  table.add_column("Input", justify="right")
226
- table.add_column("Output", justify="right")
227
294
  table.add_column("Cache", justify="right")
295
+ table.add_column("Output", justify="right")
228
296
  table.add_column("Total", justify="right")
229
297
  table.add_column("USD", justify="right")
230
298
  table.add_column("CNY", justify="right")
@@ -248,9 +316,9 @@ def render_cost_table(daily_stats: dict[str, DailyStats]) -> Table:
248
316
  table.add_row(
249
317
  date_label,
250
318
  model_col,
251
- fmt(format_tokens(stats.input_tokens)),
252
- fmt(format_tokens(stats.output_tokens)),
319
+ fmt(format_tokens(stats.non_cached_input_tokens)),
253
320
  fmt(format_tokens(stats.cached_tokens)),
321
+ fmt(format_tokens(stats.output_tokens)),
254
322
  fmt(format_tokens(stats.total_tokens)),
255
323
  fmt(usd_str),
256
324
  fmt(cny_str),
@@ -261,19 +329,40 @@ def render_cost_table(daily_stats: dict[str, DailyStats]) -> Table:
261
329
  date_label: str = "",
262
330
  show_subtotal: bool = True,
263
331
  ) -> None:
264
- """Render models grouped by provider with tree structure."""
265
- models_by_provider, provider_totals = group_models_by_provider(models)
332
+ """Render models grouped by provider with three-level tree structure."""
333
+ provider_groups = group_models_by_provider(models)
266
334
 
267
335
  first_row = True
268
- for provider_key, provider_models in models_by_provider.items():
269
- provider_stats = provider_totals[provider_key]
270
- add_stats_row(provider_stats, date_label=date_label if first_row else "", bold=True)
336
+ for group in provider_groups.values():
337
+ # Top-level provider
338
+ add_stats_row(group.total, date_label=date_label if first_row else "", bold=True)
271
339
  first_row = False
272
340
 
273
- for i, stats in enumerate(provider_models):
274
- is_last = i == len(provider_models) - 1
275
- prefix = " └─ " if is_last else " ├─ "
276
- add_stats_row(stats, prefix=prefix)
341
+ if group.sub_providers:
342
+ # Has sub-providers: render three-level tree
343
+ sub_list = list(group.sub_providers.values())
344
+ for sub_idx, sub_group in enumerate(sub_list):
345
+ is_last_sub = sub_idx == len(sub_list) - 1
346
+ sub_prefix = " └─ " if is_last_sub else " ├─ "
347
+
348
+ # Sub-provider row
349
+ add_stats_row(sub_group.total, prefix=sub_prefix, bold=True)
350
+
351
+ # Models under sub-provider
352
+ for model_idx, stats in enumerate(sub_group.models):
353
+ is_last_model = model_idx == len(sub_group.models) - 1
354
+ # Indent based on whether sub-provider is last
355
+ if is_last_sub:
356
+ model_prefix = " └─ " if is_last_model else " ├─ "
357
+ else:
358
+ model_prefix = " │ └─ " if is_last_model else " │ ├─ "
359
+ add_stats_row(stats, prefix=model_prefix)
360
+ else:
361
+ # No sub-providers: render two-level tree (direct models)
362
+ for model_idx, stats in enumerate(group.models):
363
+ is_last_model = model_idx == len(group.models) - 1
364
+ model_prefix = " └─ " if is_last_model else " ├─ "
365
+ add_stats_row(stats, prefix=model_prefix)
277
366
 
278
367
  if show_subtotal:
279
368
  subtotal = ModelUsageStats(model_name="(subtotal)")