klaude-code 2.5.2__py3-none-any.whl → 2.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. klaude_code/auth/__init__.py +10 -0
  2. klaude_code/auth/env.py +77 -0
  3. klaude_code/cli/auth_cmd.py +89 -21
  4. klaude_code/cli/config_cmd.py +5 -5
  5. klaude_code/cli/cost_cmd.py +167 -68
  6. klaude_code/cli/main.py +51 -27
  7. klaude_code/cli/self_update.py +7 -7
  8. klaude_code/config/assets/builtin_config.yaml +45 -24
  9. klaude_code/config/builtin_config.py +23 -9
  10. klaude_code/config/config.py +19 -9
  11. klaude_code/config/model_matcher.py +1 -1
  12. klaude_code/const.py +2 -1
  13. klaude_code/core/tool/file/edit_tool.py +1 -1
  14. klaude_code/core/tool/file/read_tool.py +2 -2
  15. klaude_code/core/tool/file/write_tool.py +1 -1
  16. klaude_code/core/turn.py +21 -4
  17. klaude_code/llm/anthropic/client.py +75 -50
  18. klaude_code/llm/anthropic/input.py +20 -9
  19. klaude_code/llm/google/client.py +235 -148
  20. klaude_code/llm/google/input.py +44 -36
  21. klaude_code/llm/openai_compatible/stream.py +114 -100
  22. klaude_code/llm/openrouter/client.py +1 -0
  23. klaude_code/llm/openrouter/reasoning.py +4 -29
  24. klaude_code/llm/partial_message.py +2 -32
  25. klaude_code/llm/responses/client.py +99 -81
  26. klaude_code/llm/responses/input.py +11 -25
  27. klaude_code/llm/stream_parts.py +94 -0
  28. klaude_code/log.py +57 -0
  29. klaude_code/protocol/events.py +214 -0
  30. klaude_code/protocol/sub_agent/image_gen.py +0 -4
  31. klaude_code/session/session.py +51 -18
  32. klaude_code/tui/command/fork_session_cmd.py +14 -23
  33. klaude_code/tui/command/model_picker.py +2 -17
  34. klaude_code/tui/command/resume_cmd.py +2 -18
  35. klaude_code/tui/command/sub_agent_model_cmd.py +5 -19
  36. klaude_code/tui/command/thinking_cmd.py +2 -14
  37. klaude_code/tui/commands.py +0 -5
  38. klaude_code/tui/components/common.py +1 -1
  39. klaude_code/tui/components/metadata.py +21 -21
  40. klaude_code/tui/components/rich/quote.py +36 -8
  41. klaude_code/tui/components/rich/theme.py +2 -0
  42. klaude_code/tui/components/sub_agent.py +6 -0
  43. klaude_code/tui/display.py +11 -1
  44. klaude_code/tui/input/completers.py +11 -7
  45. klaude_code/tui/input/prompt_toolkit.py +3 -1
  46. klaude_code/tui/machine.py +108 -56
  47. klaude_code/tui/renderer.py +4 -65
  48. klaude_code/tui/terminal/selector.py +174 -31
  49. {klaude_code-2.5.2.dist-info → klaude_code-2.6.0.dist-info}/METADATA +23 -31
  50. {klaude_code-2.5.2.dist-info → klaude_code-2.6.0.dist-info}/RECORD +52 -58
  51. klaude_code/cli/session_cmd.py +0 -96
  52. klaude_code/protocol/events/__init__.py +0 -63
  53. klaude_code/protocol/events/base.py +0 -18
  54. klaude_code/protocol/events/chat.py +0 -30
  55. klaude_code/protocol/events/lifecycle.py +0 -23
  56. klaude_code/protocol/events/metadata.py +0 -16
  57. klaude_code/protocol/events/streaming.py +0 -43
  58. klaude_code/protocol/events/system.py +0 -56
  59. klaude_code/protocol/events/tools.py +0 -27
  60. {klaude_code-2.5.2.dist-info → klaude_code-2.6.0.dist-info}/WHEEL +0 -0
  61. {klaude_code-2.5.2.dist-info → klaude_code-2.6.0.dist-info}/entry_points.txt +0 -0
@@ -12,6 +12,12 @@ from klaude_code.auth.codex import (
12
12
  CodexTokenExpiredError,
13
13
  CodexTokenManager,
14
14
  )
15
+ from klaude_code.auth.env import (
16
+ delete_auth_env,
17
+ get_auth_env,
18
+ list_auth_env,
19
+ set_auth_env,
20
+ )
15
21
 
16
22
  __all__ = [
17
23
  "CodexAuthError",
@@ -21,4 +27,8 @@ __all__ = [
21
27
  "CodexOAuthError",
22
28
  "CodexTokenExpiredError",
23
29
  "CodexTokenManager",
30
+ "delete_auth_env",
31
+ "get_auth_env",
32
+ "list_auth_env",
33
+ "set_auth_env",
24
34
  ]
@@ -0,0 +1,77 @@
1
+ """Environment variable configuration stored in klaude-auth.json."""
2
+
3
+ import json
4
+ from typing import Any
5
+
6
+ from klaude_code.auth.base import KLAUDE_AUTH_FILE
7
+
8
+
9
+ def _load_store() -> dict[str, Any]:
10
+ """Load the auth store from file."""
11
+ if not KLAUDE_AUTH_FILE.exists():
12
+ return {}
13
+ try:
14
+ data: Any = json.loads(KLAUDE_AUTH_FILE.read_text())
15
+ if isinstance(data, dict):
16
+ return dict(data)
17
+ return {}
18
+ except (json.JSONDecodeError, ValueError):
19
+ return {}
20
+
21
+
22
+ def _save_store(data: dict[str, Any]) -> None:
23
+ """Save the auth store to file."""
24
+ KLAUDE_AUTH_FILE.parent.mkdir(parents=True, exist_ok=True)
25
+ KLAUDE_AUTH_FILE.write_text(json.dumps(data, indent=2))
26
+
27
+
28
+ def get_auth_env(env_var: str) -> str | None:
29
+ """Get environment variable value from klaude-auth.json 'env' section.
30
+
31
+ This provides a fallback for API keys when real environment variables are not set.
32
+ Priority: os.environ > klaude-auth.json env
33
+ """
34
+ store = _load_store()
35
+ env_section: Any = store.get("env")
36
+ if not isinstance(env_section, dict):
37
+ return None
38
+ value: Any = env_section.get(env_var)
39
+ return str(value) if value is not None else None
40
+
41
+
42
+ def set_auth_env(env_var: str, value: str) -> None:
43
+ """Set environment variable value in klaude-auth.json 'env' section."""
44
+ store = _load_store()
45
+ env_section: Any = store.get("env")
46
+ if not isinstance(env_section, dict):
47
+ env_section = {}
48
+ env_section[env_var] = value
49
+ store["env"] = env_section
50
+ _save_store(store)
51
+
52
+
53
+ def delete_auth_env(env_var: str) -> None:
54
+ """Delete environment variable from klaude-auth.json 'env' section."""
55
+ store = _load_store()
56
+ env_section: Any = store.get("env")
57
+ if not isinstance(env_section, dict):
58
+ return
59
+ env_section.pop(env_var, None)
60
+ if len(env_section) == 0:
61
+ store.pop("env", None)
62
+ else:
63
+ store["env"] = env_section
64
+ if len(store) == 0:
65
+ if KLAUDE_AUTH_FILE.exists():
66
+ KLAUDE_AUTH_FILE.unlink()
67
+ else:
68
+ _save_store(store)
69
+
70
+
71
+ def list_auth_env() -> dict[str, str]:
72
+ """List all environment variables in klaude-auth.json 'env' section."""
73
+ store = _load_store()
74
+ env_section: Any = store.get("env")
75
+ if not isinstance(env_section, dict):
76
+ return {}
77
+ return {k: str(v) for k, v in env_section.items() if v is not None}
@@ -4,41 +4,81 @@ import datetime
4
4
  import webbrowser
5
5
 
6
6
  import typer
7
- from prompt_toolkit.styles import Style
8
7
 
9
8
  from klaude_code.log import log
10
- from klaude_code.tui.terminal.selector import SelectItem, select_one
11
-
12
- _SELECT_STYLE = Style(
13
- [
14
- ("instruction", "ansibrightblack"),
15
- ("pointer", "ansigreen"),
16
- ("highlighted", "ansigreen"),
17
- ("text", "ansibrightblack"),
18
- ("question", "bold"),
19
- ]
20
- )
9
+ from klaude_code.tui.terminal.selector import DEFAULT_PICKER_STYLE, SelectItem, select_one
21
10
 
22
11
 
23
12
  def _select_provider() -> str | None:
24
13
  """Display provider selection menu and return selected provider."""
14
+ from klaude_code.config.builtin_config import SUPPORTED_API_KEYS
15
+
25
16
  items: list[SelectItem[str]] = [
26
- SelectItem(title=[("class:text", "Claude Max/Pro Subscription\n")], value="claude", search_text="claude"),
27
- SelectItem(title=[("class:text", "ChatGPT Codex Subscription\n")], value="codex", search_text="codex"),
17
+ SelectItem(
18
+ title=[("", "Claude Max/Pro Subscription "), ("ansibrightblack", "[OAuth]\n")],
19
+ value="claude",
20
+ search_text="claude",
21
+ ),
22
+ SelectItem(
23
+ title=[("", "ChatGPT Codex Subscription "), ("ansibrightblack", "[OAuth]\n")],
24
+ value="codex",
25
+ search_text="codex",
26
+ ),
28
27
  ]
28
+ # Add API key options
29
+ for key_info in SUPPORTED_API_KEYS:
30
+ items.append(
31
+ SelectItem(
32
+ title=[("", f"{key_info.name} "), ("ansibrightblack", "[API key]\n")],
33
+ value=key_info.env_var,
34
+ search_text=key_info.env_var,
35
+ )
36
+ )
37
+
29
38
  return select_one(
30
39
  message="Select provider to login:",
31
40
  items=items,
32
41
  pointer="→",
33
- style=_SELECT_STYLE,
42
+ style=DEFAULT_PICKER_STYLE,
34
43
  use_search_filter=False,
35
44
  )
36
45
 
37
46
 
47
+ def _configure_api_key(env_var: str) -> None:
48
+ """Configure a specific API key."""
49
+ import os
50
+
51
+ from klaude_code.auth.env import get_auth_env, set_auth_env
52
+
53
+ # Check if already configured
54
+ current_value = os.environ.get(env_var) or get_auth_env(env_var)
55
+ if current_value:
56
+ masked = current_value[:8] + "..." if len(current_value) > 8 else "***"
57
+ log(f"Current {env_var}: {masked}")
58
+ if not typer.confirm("Do you want to update it?"):
59
+ return
60
+
61
+ api_key = typer.prompt(f"Enter {env_var}", hide_input=True)
62
+ if not api_key.strip():
63
+ log(("Error: API key cannot be empty", "red"))
64
+ raise typer.Exit(1)
65
+
66
+ set_auth_env(env_var, api_key.strip())
67
+ log((f"{env_var} saved successfully!", "green"))
68
+
69
+
70
+ def _build_provider_help() -> str:
71
+ from klaude_code.config.builtin_config import SUPPORTED_API_KEYS
72
+
73
+ # Use first word of name for brevity (e.g., "google" instead of "google gemini")
74
+ names = ["codex", "claude"] + [k.name.split()[0].lower() for k in SUPPORTED_API_KEYS]
75
+ return f"Provider name ({', '.join(names)})"
76
+
77
+
38
78
  def login_command(
39
- provider: str | None = typer.Argument(None, help="Provider to login (codex|claude)"),
79
+ provider: str | None = typer.Argument(None, help=_build_provider_help()),
40
80
  ) -> None:
41
- """Login to a provider using OAuth."""
81
+ """Login to a provider or configure API keys."""
42
82
  if provider is None:
43
83
  provider = _select_provider()
44
84
  if provider is None:
@@ -110,8 +150,27 @@ def login_command(
110
150
  log((f"Login failed: {e}", "red"))
111
151
  raise typer.Exit(1) from None
112
152
  case _:
113
- log((f"Error: Unknown provider '{provider}'. Supported: codex, claude", "red"))
114
- raise typer.Exit(1)
153
+ from klaude_code.config.builtin_config import SUPPORTED_API_KEYS
154
+
155
+ # Match by env var (e.g., OPENAI_API_KEY) or name (e.g., openai, google)
156
+ env_var: str | None = None
157
+ provider_lower = provider.lower()
158
+ provider_upper = provider.upper()
159
+ for key_info in SUPPORTED_API_KEYS:
160
+ name_lower = key_info.name.lower()
161
+ # Exact match or starts with (for "google" -> "google gemini")
162
+ if key_info.env_var == provider_upper or name_lower == provider_lower:
163
+ env_var = key_info.env_var
164
+ break
165
+ if name_lower.startswith(provider_lower) or provider_lower in name_lower.split():
166
+ env_var = key_info.env_var
167
+ break
168
+
169
+ if env_var:
170
+ _configure_api_key(env_var)
171
+ else:
172
+ log((f"Error: Unknown provider '{provider}'", "red"))
173
+ raise typer.Exit(1)
115
174
 
116
175
 
117
176
  def logout_command(
@@ -150,5 +209,14 @@ def logout_command(
150
209
 
151
210
  def register_auth_commands(app: typer.Typer) -> None:
152
211
  """Register auth commands to the given Typer app."""
153
- app.command("login")(login_command)
154
- app.command("logout")(logout_command)
212
+ auth_app = typer.Typer(help="Login/logout", invoke_without_command=True)
213
+
214
+ @auth_app.callback()
215
+ def auth_callback(ctx: typer.Context) -> None:
216
+ """Authentication commands for managing provider logins."""
217
+ if ctx.invoked_subcommand is None:
218
+ typer.echo(ctx.get_help())
219
+
220
+ auth_app.command("login")(login_command)
221
+ auth_app.command("logout")(logout_command)
222
+ app.add_typer(auth_app, name="auth")
@@ -11,9 +11,9 @@ from klaude_code.log import log
11
11
 
12
12
 
13
13
  def list_models(
14
- show_all: bool = typer.Option(False, "--all", "-a", help="Show all providers including unavailable ones"),
14
+ show_all: bool = typer.Option(False, "--all", "-a", help="Include unavailable providers"),
15
15
  ) -> None:
16
- """List all models and providers configuration"""
16
+ """List available models"""
17
17
  from klaude_code.cli.list_model import display_models_and_providers
18
18
  from klaude_code.tui.terminal.color import is_light_terminal_background
19
19
 
@@ -31,7 +31,7 @@ def list_models(
31
31
 
32
32
 
33
33
  def edit_config() -> None:
34
- """Open the configuration file in $EDITOR or default system editor"""
34
+ """Edit config file"""
35
35
  editor = os.environ.get("EDITOR")
36
36
 
37
37
  # If no EDITOR is set, prioritize TextEdit on macOS
@@ -89,5 +89,5 @@ def edit_config() -> None:
89
89
  def register_config_commands(app: typer.Typer) -> None:
90
90
  """Register config commands to the given Typer app."""
91
91
  app.command("list")(list_models)
92
- app.command("config")(edit_config)
93
- app.command("conf", hidden=True)(edit_config)
92
+ app.command("conf")(edit_config)
93
+ app.command("config", hidden=True)(edit_config)
@@ -5,6 +5,7 @@ from dataclasses import dataclass, field
5
5
  from datetime import datetime
6
6
  from pathlib import Path
7
7
 
8
+ import pydantic
8
9
  import typer
9
10
  from rich.box import Box
10
11
  from rich.console import Console
@@ -22,6 +23,7 @@ class ModelUsageStats:
22
23
  """Aggregated usage stats for a single model."""
23
24
 
24
25
  model_name: str
26
+ provider: str = ""
25
27
  input_tokens: int = 0
26
28
  output_tokens: int = 0
27
29
  cached_tokens: int = 0
@@ -57,8 +59,11 @@ class DailyStats:
57
59
  return
58
60
 
59
61
  model_key = meta.model_name
62
+ provider = meta.provider or meta.usage.provider or ""
60
63
  if model_key not in self.by_model:
61
- self.by_model[model_key] = ModelUsageStats(model_name=model_key)
64
+ self.by_model[model_key] = ModelUsageStats(model_name=model_key, provider=provider)
65
+ elif not self.by_model[model_key].provider and provider:
66
+ self.by_model[model_key].provider = provider
62
67
 
63
68
  self.by_model[model_key].add_usage(meta.usage)
64
69
 
@@ -115,6 +120,7 @@ def extract_task_metadata_from_events(events_path: Path) -> list[tuple[str, mode
115
120
  """Extract TaskMetadataItem entries from events.jsonl with their dates.
116
121
 
117
122
  Returns list of (date_str, TaskMetadataItem) tuples.
123
+ Skips lines that fail pydantic validation.
118
124
  """
119
125
  results: list[tuple[str, model.TaskMetadataItem]] = []
120
126
  try:
@@ -123,7 +129,10 @@ def extract_task_metadata_from_events(events_path: Path) -> list[tuple[str, mode
123
129
  return results
124
130
 
125
131
  for line in content.splitlines():
126
- item = decode_jsonl_line(line)
132
+ try:
133
+ item = decode_jsonl_line(line)
134
+ except pydantic.ValidationError:
135
+ continue
127
136
  if isinstance(item, model.TaskMetadataItem):
128
137
  date_str = item.created_at.strftime("%Y-%m-%d")
129
138
  results.append((date_str, item))
@@ -183,68 +192,123 @@ def render_cost_table(daily_stats: dict[str, DailyStats]) -> Table:
183
192
  box=ASCII_HORIZONAL,
184
193
  )
185
194
 
186
- table.add_column("Date", style="cyan", no_wrap=True)
187
- table.add_column("Model", no_wrap=True)
188
- table.add_column("Input", justify="right", no_wrap=True)
189
- table.add_column("Output", justify="right", no_wrap=True)
190
- table.add_column("Cache", justify="right", no_wrap=True)
191
- table.add_column("Total", justify="right", no_wrap=True)
192
- table.add_column("USD", justify="right", no_wrap=True)
193
- table.add_column("CNY", justify="right", no_wrap=True)
195
+ table.add_column("Date", style="cyan")
196
+ table.add_column("Model", overflow="ellipsis")
197
+ table.add_column("Input", justify="right")
198
+ table.add_column("Output", justify="right")
199
+ table.add_column("Cache", justify="right")
200
+ table.add_column("Total", justify="right")
201
+ table.add_column("USD", justify="right")
202
+ table.add_column("CNY", justify="right")
194
203
 
195
204
  # Sort dates
196
205
  sorted_dates = sorted(daily_stats.keys())
197
206
 
198
- # Track global totals by model
199
- global_by_model: dict[str, ModelUsageStats] = {}
207
+ # Track global totals by (model, provider)
208
+ global_by_model: dict[tuple[str, str], ModelUsageStats] = {}
200
209
 
201
210
  def sort_by_cost(stats: ModelUsageStats) -> tuple[float, float]:
202
211
  """Sort key: USD desc, then CNY desc."""
203
212
  return (-stats.cost_usd, -stats.cost_cny)
204
213
 
205
- for date_str in sorted_dates:
206
- day = daily_stats[date_str]
207
- sorted_models = [s.model_name for s in sorted(day.by_model.values(), key=sort_by_cost)]
214
+ def render_by_provider(
215
+ models: dict[str, ModelUsageStats],
216
+ date_label: str = "",
217
+ show_subtotal: bool = True,
218
+ ) -> None:
219
+ """Render models grouped by provider with tree structure."""
220
+ # Group models by provider
221
+ models_by_provider: dict[str, list[ModelUsageStats]] = {}
222
+ provider_totals: dict[str, ModelUsageStats] = {}
223
+ for stats in models.values():
224
+ provider_key = stats.provider or "(unknown)"
225
+ if provider_key not in models_by_provider:
226
+ models_by_provider[provider_key] = []
227
+ provider_totals[provider_key] = ModelUsageStats(model_name=provider_key, provider=provider_key)
228
+ models_by_provider[provider_key].append(stats)
229
+ provider_totals[provider_key].input_tokens += stats.input_tokens
230
+ provider_totals[provider_key].output_tokens += stats.output_tokens
231
+ provider_totals[provider_key].cached_tokens += stats.cached_tokens
232
+ provider_totals[provider_key].cost_usd += stats.cost_usd
233
+ provider_totals[provider_key].cost_cny += stats.cost_cny
234
+
235
+ # Sort providers by cost, and models within each provider by cost
236
+ sorted_providers = sorted(provider_totals.keys(), key=lambda p: sort_by_cost(provider_totals[p]))
237
+ for provider_key in models_by_provider:
238
+ models_by_provider[provider_key].sort(key=sort_by_cost)
208
239
 
209
240
  first_row = True
210
- for model_name in sorted_models:
211
- stats = day.by_model[model_name]
212
- usd_str, cny_str = format_cost_dual(stats.cost_usd, stats.cost_cny)
213
-
214
- # Accumulate to global totals
215
- if model_name not in global_by_model:
216
- global_by_model[model_name] = ModelUsageStats(model_name=model_name)
217
- global_by_model[model_name].input_tokens += stats.input_tokens
218
- global_by_model[model_name].output_tokens += stats.output_tokens
219
- global_by_model[model_name].cached_tokens += stats.cached_tokens
220
- global_by_model[model_name].cost_usd += stats.cost_usd
221
- global_by_model[model_name].cost_cny += stats.cost_cny
241
+ for provider_key in sorted_providers:
242
+ provider_stats = provider_totals[provider_key]
243
+ provider_models = models_by_provider[provider_key]
222
244
 
245
+ # Provider row (bold)
246
+ usd_str, cny_str = format_cost_dual(provider_stats.cost_usd, provider_stats.cost_cny)
223
247
  table.add_row(
224
- format_date_display(date_str) if first_row else "",
225
- f"- {model_name}",
226
- format_tokens(stats.input_tokens),
227
- format_tokens(stats.output_tokens),
228
- format_tokens(stats.cached_tokens),
229
- format_tokens(stats.total_tokens),
230
- usd_str,
231
- cny_str,
248
+ date_label if first_row else "",
249
+ f"[bold]{provider_key}[/bold]",
250
+ f"[bold]{format_tokens(provider_stats.input_tokens)}[/bold]",
251
+ f"[bold]{format_tokens(provider_stats.output_tokens)}[/bold]",
252
+ f"[bold]{format_tokens(provider_stats.cached_tokens)}[/bold]",
253
+ f"[bold]{format_tokens(provider_stats.total_tokens)}[/bold]",
254
+ f"[bold]{usd_str}[/bold]",
255
+ f"[bold]{cny_str}[/bold]",
232
256
  )
233
257
  first_row = False
234
258
 
235
- # Add subtotal row for this day
236
- subtotal = day.get_subtotal()
237
- usd_str, cny_str = format_cost_dual(subtotal.cost_usd, subtotal.cost_cny)
238
- table.add_row(
239
- "",
240
- "[cyan] (subtotal)[/cyan]",
241
- f"[cyan]{format_tokens(subtotal.input_tokens)}[/cyan]",
242
- f"[cyan]{format_tokens(subtotal.output_tokens)}[/cyan]",
243
- f"[cyan]{format_tokens(subtotal.cached_tokens)}[/cyan]",
244
- f"[cyan]{format_tokens(subtotal.total_tokens)}[/cyan]",
245
- f"[cyan]{usd_str}[/cyan]",
246
- f"[cyan]{cny_str}[/cyan]",
247
- )
259
+ # Model rows with tree prefix
260
+ for i, stats in enumerate(provider_models):
261
+ is_last = i == len(provider_models) - 1
262
+ prefix = " └─ " if is_last else " ├─ "
263
+ usd_str, cny_str = format_cost_dual(stats.cost_usd, stats.cost_cny)
264
+ table.add_row(
265
+ "",
266
+ f"[bright_black dim]{prefix}[/bright_black dim]{stats.model_name}",
267
+ format_tokens(stats.input_tokens),
268
+ format_tokens(stats.output_tokens),
269
+ format_tokens(stats.cached_tokens),
270
+ format_tokens(stats.total_tokens),
271
+ usd_str,
272
+ cny_str,
273
+ )
274
+
275
+ # Add subtotal row
276
+ if show_subtotal:
277
+ subtotal = ModelUsageStats(model_name="(subtotal)")
278
+ for stats in models.values():
279
+ subtotal.input_tokens += stats.input_tokens
280
+ subtotal.output_tokens += stats.output_tokens
281
+ subtotal.cached_tokens += stats.cached_tokens
282
+ subtotal.cost_usd += stats.cost_usd
283
+ subtotal.cost_cny += stats.cost_cny
284
+ usd_str, cny_str = format_cost_dual(subtotal.cost_usd, subtotal.cost_cny)
285
+ table.add_row(
286
+ "",
287
+ "[bold](subtotal)[/bold]",
288
+ f"[bold]{format_tokens(subtotal.input_tokens)}[/bold]",
289
+ f"[bold]{format_tokens(subtotal.output_tokens)}[/bold]",
290
+ f"[bold]{format_tokens(subtotal.cached_tokens)}[/bold]",
291
+ f"[bold]{format_tokens(subtotal.total_tokens)}[/bold]",
292
+ f"[bold]{usd_str}[/bold]",
293
+ f"[bold]{cny_str}[/bold]",
294
+ )
295
+
296
+ for date_str in sorted_dates:
297
+ day = daily_stats[date_str]
298
+
299
+ # Accumulate to global totals by (model, provider)
300
+ for model_name, stats in day.by_model.items():
301
+ model_key = (model_name, stats.provider or "")
302
+ if model_key not in global_by_model:
303
+ global_by_model[model_key] = ModelUsageStats(model_name=model_name, provider=stats.provider)
304
+ global_by_model[model_key].input_tokens += stats.input_tokens
305
+ global_by_model[model_key].output_tokens += stats.output_tokens
306
+ global_by_model[model_key].cached_tokens += stats.cached_tokens
307
+ global_by_model[model_key].cost_usd += stats.cost_usd
308
+ global_by_model[model_key].cost_cny += stats.cost_cny
309
+
310
+ # Render this day's data grouped by provider
311
+ render_by_provider(day.by_model, date_label=format_date_display(date_str))
248
312
 
249
313
  # Add separator between days
250
314
  if date_str != sorted_dates[-1]:
@@ -264,27 +328,62 @@ def render_cost_table(daily_stats: dict[str, DailyStats]) -> Table:
264
328
  else:
265
329
  total_label = "[bold]Total[/bold]"
266
330
 
267
- # Add per-model totals
268
- sorted_global_models = [s.model_name for s in sorted(global_by_model.values(), key=sort_by_cost)]
269
- first_total_row = True
270
- for model_name in sorted_global_models:
271
- # Add empty row before first model to align with Total date range
272
- if first_total_row:
273
- table.add_row(total_label, "", "", "", "", "", "", "")
274
- first_total_row = False
275
- stats = global_by_model[model_name]
276
- usd_str, cny_str = format_cost_dual(stats.cost_usd, stats.cost_cny)
331
+ # Group models by provider
332
+ models_by_provider: dict[str, list[ModelUsageStats]] = {}
333
+ provider_totals: dict[str, ModelUsageStats] = {}
334
+ for stats in global_by_model.values():
335
+ provider_key = stats.provider or "(unknown)"
336
+ if provider_key not in models_by_provider:
337
+ models_by_provider[provider_key] = []
338
+ provider_totals[provider_key] = ModelUsageStats(model_name=provider_key, provider=provider_key)
339
+ models_by_provider[provider_key].append(stats)
340
+ provider_totals[provider_key].input_tokens += stats.input_tokens
341
+ provider_totals[provider_key].output_tokens += stats.output_tokens
342
+ provider_totals[provider_key].cached_tokens += stats.cached_tokens
343
+ provider_totals[provider_key].cost_usd += stats.cost_usd
344
+ provider_totals[provider_key].cost_cny += stats.cost_cny
345
+
346
+ # Sort providers by cost, and models within each provider by cost
347
+ sorted_providers = sorted(provider_totals.keys(), key=lambda p: sort_by_cost(provider_totals[p]))
348
+ for provider_key in models_by_provider:
349
+ models_by_provider[provider_key].sort(key=sort_by_cost)
350
+
351
+ # Add total label row
352
+ table.add_row(total_label, "", "", "", "", "", "", "")
353
+
354
+ # Render each provider with its models
355
+ for provider_key in sorted_providers:
356
+ provider_stats = provider_totals[provider_key]
357
+ models = models_by_provider[provider_key]
358
+
359
+ # Provider row (bold)
360
+ usd_str, cny_str = format_cost_dual(provider_stats.cost_usd, provider_stats.cost_cny)
277
361
  table.add_row(
278
362
  "",
279
- f"- {model_name}",
280
- format_tokens(stats.input_tokens),
281
- format_tokens(stats.output_tokens),
282
- format_tokens(stats.cached_tokens),
283
- format_tokens(stats.total_tokens),
284
- usd_str,
285
- cny_str,
363
+ f"[bold]{provider_key}[/bold]",
364
+ f"[bold]{format_tokens(provider_stats.input_tokens)}[/bold]",
365
+ f"[bold]{format_tokens(provider_stats.output_tokens)}[/bold]",
366
+ f"[bold]{format_tokens(provider_stats.cached_tokens)}[/bold]",
367
+ f"[bold]{format_tokens(provider_stats.total_tokens)}[/bold]",
368
+ f"[bold]{usd_str}[/bold]",
369
+ f"[bold]{cny_str}[/bold]",
286
370
  )
287
- first_total_row = False
371
+
372
+ # Model rows with tree prefix
373
+ for i, stats in enumerate(models):
374
+ is_last = i == len(models) - 1
375
+ prefix = " └─ " if is_last else " ├─ "
376
+ usd_str, cny_str = format_cost_dual(stats.cost_usd, stats.cost_cny)
377
+ table.add_row(
378
+ "",
379
+ f"[bright_black dim]{prefix}[/bright_black dim]{stats.model_name}",
380
+ format_tokens(stats.input_tokens),
381
+ format_tokens(stats.output_tokens),
382
+ format_tokens(stats.cached_tokens),
383
+ format_tokens(stats.total_tokens),
384
+ usd_str,
385
+ cny_str,
386
+ )
288
387
 
289
388
  # Add grand total row
290
389
  grand_total = ModelUsageStats(model_name="(total)")
@@ -298,7 +397,7 @@ def render_cost_table(daily_stats: dict[str, DailyStats]) -> Table:
298
397
  usd_str, cny_str = format_cost_dual(grand_total.cost_usd, grand_total.cost_cny)
299
398
  table.add_row(
300
399
  "",
301
- "[bold] (total)[/bold]",
400
+ "[bold](total)[/bold]",
302
401
  f"[bold]{format_tokens(grand_total.input_tokens)}[/bold]",
303
402
  f"[bold]{format_tokens(grand_total.output_tokens)}[/bold]",
304
403
  f"[bold]{format_tokens(grand_total.cached_tokens)}[/bold]",
@@ -313,7 +412,7 @@ def render_cost_table(daily_stats: dict[str, DailyStats]) -> Table:
313
412
  def cost_command(
314
413
  days: int | None = typer.Option(None, "--days", "-d", help="Limit to last N days"),
315
414
  ) -> None:
316
- """Display aggregated usage statistics across all sessions."""
415
+ """Show usage stats"""
317
416
  daily_stats = aggregate_all_sessions()
318
417
 
319
418
  if not daily_stats: