klaude-code 1.2.2-py3-none-any.whl → 1.2.3-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (58)
  1. klaude_code/cli/main.py +7 -0
  2. klaude_code/cli/runtime.py +6 -6
  3. klaude_code/command/__init__.py +7 -5
  4. klaude_code/command/clear_cmd.py +3 -24
  5. klaude_code/command/command_abc.py +36 -1
  6. klaude_code/command/export_cmd.py +14 -20
  7. klaude_code/command/help_cmd.py +1 -0
  8. klaude_code/command/model_cmd.py +3 -30
  9. klaude_code/command/{prompt-update-dev-doc.md → prompt-dev-docs-update.md} +3 -2
  10. klaude_code/command/{prompt-dev-doc.md → prompt-dev-docs.md} +3 -2
  11. klaude_code/command/prompt-init.md +2 -5
  12. klaude_code/command/prompt_command.py +3 -3
  13. klaude_code/command/registry.py +6 -7
  14. klaude_code/config/config.py +1 -1
  15. klaude_code/config/list_model.py +1 -1
  16. klaude_code/const/__init__.py +1 -1
  17. klaude_code/core/agent.py +2 -11
  18. klaude_code/core/executor.py +155 -14
  19. klaude_code/core/prompts/prompt-gemini.md +1 -1
  20. klaude_code/core/reminders.py +24 -0
  21. klaude_code/core/task.py +10 -0
  22. klaude_code/core/tool/shell/bash_tool.py +6 -2
  23. klaude_code/core/tool/sub_agent_tool.py +1 -1
  24. klaude_code/core/tool/tool_context.py +1 -1
  25. klaude_code/core/tool/tool_registry.py +1 -1
  26. klaude_code/core/tool/tool_runner.py +1 -1
  27. klaude_code/core/tool/web/mermaid_tool.py +1 -1
  28. klaude_code/llm/__init__.py +3 -4
  29. klaude_code/llm/anthropic/client.py +12 -9
  30. klaude_code/llm/openai_compatible/client.py +2 -18
  31. klaude_code/llm/openai_compatible/tool_call_accumulator.py +2 -2
  32. klaude_code/llm/openrouter/client.py +2 -18
  33. klaude_code/llm/openrouter/input.py +6 -2
  34. klaude_code/llm/registry.py +2 -71
  35. klaude_code/llm/responses/client.py +2 -0
  36. klaude_code/llm/{metadata_tracker.py → usage.py} +49 -2
  37. klaude_code/protocol/llm_param.py +12 -0
  38. klaude_code/protocol/model.py +23 -3
  39. klaude_code/protocol/op.py +14 -14
  40. klaude_code/protocol/op_handler.py +28 -0
  41. klaude_code/protocol/tools.py +0 -2
  42. klaude_code/session/export.py +124 -35
  43. klaude_code/session/session.py +1 -1
  44. klaude_code/session/templates/export_session.html +180 -42
  45. klaude_code/ui/__init__.py +6 -2
  46. klaude_code/ui/modes/exec/display.py +26 -0
  47. klaude_code/ui/modes/repl/event_handler.py +5 -1
  48. klaude_code/ui/renderers/developer.py +6 -10
  49. klaude_code/ui/renderers/metadata.py +33 -24
  50. klaude_code/ui/renderers/sub_agent.py +1 -1
  51. klaude_code/ui/renderers/tools.py +2 -2
  52. klaude_code/ui/renderers/user_input.py +18 -22
  53. klaude_code/ui/rich/status.py +13 -2
  54. {klaude_code-1.2.2.dist-info → klaude_code-1.2.3.dist-info}/METADATA +1 -1
  55. {klaude_code-1.2.2.dist-info → klaude_code-1.2.3.dist-info}/RECORD +58 -57
  56. /klaude_code/{core → protocol}/sub_agent.py +0 -0
  57. {klaude_code-1.2.2.dist-info → klaude_code-1.2.3.dist-info}/WHEEL +0 -0
  58. {klaude_code-1.2.2.dist-info → klaude_code-1.2.3.dist-info}/entry_points.txt +0 -0
klaude_code/core/executor.py CHANGED
@@ -5,19 +5,86 @@ This module implements the submission_loop equivalent for klaude,
 handling operations submitted from the CLI and coordinating with agents.
 """
 
+from __future__ import annotations
+
 import asyncio
 from dataclasses import dataclass
+from dataclasses import field as dataclass_field
 
-from klaude_code.command import dispatch_command
+from klaude_code.command import InputAction, InputActionType, dispatch_command
+from klaude_code.config import Config, load_config
 from klaude_code.core.agent import Agent, DefaultModelProfileProvider, ModelProfileProvider
-from klaude_code.core.sub_agent import SubAgentResult
 from klaude_code.core.tool import current_run_subtask_callback
-from klaude_code.llm import LLMClients
-from klaude_code.protocol import events, model, op
+from klaude_code.llm.client import LLMClientABC
+from klaude_code.llm.registry import create_llm_client
+from klaude_code.protocol import commands, events, model, op
+from klaude_code.protocol.op_handler import OperationHandler
+from klaude_code.protocol.sub_agent import SubAgentResult, get_sub_agent_profile
+from klaude_code.protocol.tools import SubAgentType
 from klaude_code.session.session import Session
 from klaude_code.trace import DebugType, log_debug
 
 
+@dataclass
+class LLMClients:
+    """Container for LLM clients used by main agent and sub-agents."""
+
+    main: LLMClientABC
+    sub_clients: dict[SubAgentType, LLMClientABC] = dataclass_field(default_factory=lambda: {})
+
+    def get_client(self, sub_agent_type: SubAgentType | None = None) -> LLMClientABC:
+        """Get client for given sub-agent type, or main client if None."""
+        if sub_agent_type is None:
+            return self.main
+        return self.sub_clients.get(sub_agent_type) or self.main
+
+    @classmethod
+    def from_config(
+        cls,
+        config: Config,
+        model_override: str | None = None,
+        enabled_sub_agents: list[SubAgentType] | None = None,
+    ) -> LLMClients:
+        """Create LLMClients from application config.
+
+        Args:
+            config: Application configuration
+            model_override: Optional model name to override the main model
+            enabled_sub_agents: List of sub-agent types to initialize clients for
+
+        Returns:
+            LLMClients instance
+        """
+        # Resolve main agent LLM config
+        if model_override:
+            llm_config = config.get_model_config(model_override)
+        else:
+            llm_config = config.get_main_model_config()
+
+        log_debug(
+            "Main LLM config",
+            llm_config.model_dump_json(exclude_none=True),
+            style="yellow",
+            debug_type=DebugType.LLM_CONFIG,
+        )
+
+        main_client = create_llm_client(llm_config)
+        sub_clients: dict[SubAgentType, LLMClientABC] = {}
+
+        # Initialize sub-agent clients
+        for sub_agent_type in enabled_sub_agents or []:
+            model_name = config.subagent_models.get(sub_agent_type)
+            if not model_name:
+                continue
+            profile = get_sub_agent_profile(sub_agent_type)
+            if not profile.enabled_for_model(main_client.model_name):
+                continue
+            sub_llm_config = config.get_model_config(model_name)
+            sub_clients[sub_agent_type] = create_llm_client(sub_llm_config)
+
+        return cls(main=main_client, sub_clients=sub_clients)
+
+
 @dataclass
 class ActiveTask:
     """Track an in-flight task and its owning session."""
@@ -32,6 +99,8 @@ class ExecutorContext:
 
     This context is passed to operations when they execute, allowing them
     to access shared resources like the event queue and active sessions.
+
+    Implements the OperationHandler protocol via structural subtyping.
     """
 
     def __init__(
@@ -65,7 +134,6 @@ class ExecutorContext:
         agent = Agent(
             session=session,
             profile=profile,
-            model_profile_provider=self.model_profile_provider,
         )
 
         async for evt in agent.replay_history():
@@ -109,8 +177,12 @@ class ExecutorContext:
         )
 
         result = await dispatch_command(user_input.text, agent)
-        if not result.agent_input:
-            # If this command do not need run agent, we should append user message to session history here
+
+        actions: list[InputAction] = list(result.actions or [])
+
+        has_run_agent_action = any(action.type is InputActionType.RUN_AGENT for action in actions)
+        if not has_run_agent_action:
+            # No async agent task will run, append user message directly
             agent.session.append_history([model.UserMessageItem(content=user_input.text, images=user_input.images)])
 
         if result.events:
@@ -120,15 +192,80 @@ class ExecutorContext:
         for evt in result.events:
             await self.emit_event(evt)
 
-        if result.agent_input:
-            # Construct new UserInputPayload with command-processed text, preserving original images
-            task_input = model.UserInputPayload(text=result.agent_input, images=user_input.images)
-            # Start task to process user input (do NOT await here so the executor loop stays responsive)
+        for action in actions:
+            await self._run_input_action(action, operation, agent)
+
+    async def _run_input_action(self, action: InputAction, operation: op.UserInputOperation, agent: Agent) -> None:
+        if operation.session_id is None:
+            raise ValueError("session_id cannot be None for input actions")
+
+        session_id = operation.session_id
+
+        if action.type == InputActionType.RUN_AGENT:
+            task_input = model.UserInputPayload(text=action.text, images=operation.input.images)
+
+            existing_active = self.active_tasks.get(operation.id)
+            if existing_active is not None and not existing_active.task.done():
+                raise RuntimeError(f"Active task already registered for operation {operation.id}")
+
             task: asyncio.Task[None] = asyncio.create_task(
                 self._run_agent_task(agent, task_input, operation.id, session_id)
             )
             self.active_tasks[operation.id] = ActiveTask(task=task, session_id=session_id)
-            # Do not await task here; completion will be tracked by the executor
+            return
+
+        if action.type == InputActionType.CHANGE_MODEL:
+            if not action.model_name:
+                raise ValueError("ChangeModel action requires model_name")
+
+            await self._apply_model_change(agent, action.model_name)
+            return
+
+        if action.type == InputActionType.CLEAR:
+            await self._apply_clear(agent)
+            return
+
+        raise ValueError(f"Unsupported input action type: {action.type}")
+
+    async def _apply_model_change(self, agent: Agent, model_name: str) -> None:
+        config = load_config()
+        if config is None:
+            raise ValueError("Configuration must be initialized before changing model")
+
+        llm_config = config.get_model_config(model_name)
+        llm_client = create_llm_client(llm_config)
+        agent.set_model_profile(self.model_profile_provider.build_profile(llm_client))
+
+        developer_item = model.DeveloperMessageItem(
+            content=f"switched to model: {model_name}",
+            command_output=model.CommandOutput(command_name=commands.CommandName.MODEL),
+        )
+        agent.session.append_history([developer_item])
+
+        await self.emit_event(events.DeveloperMessageEvent(session_id=agent.session.id, item=developer_item))
+        await self.emit_event(events.WelcomeEvent(llm_config=llm_config, work_dir=str(agent.session.work_dir)))
+
+    async def _apply_clear(self, agent: Agent) -> None:
+        old_session_id = agent.session.id
+
+        # Create a new session instance to replace the current one
+        new_session = Session(work_dir=agent.session.work_dir)
+        new_session.model_name = agent.session.model_name
+
+        # Replace the agent's session with the new one
+        agent.session = new_session
+        agent.session.save()
+
+        # Update the active_agents mapping
+        self.active_agents.pop(old_session_id, None)
+        self.active_agents[new_session.id] = agent
+
+        developer_item = model.DeveloperMessageItem(
+            content="started new conversation",
+            command_output=model.CommandOutput(command_name=commands.CommandName.CLEAR),
+        )
+
+        await self.emit_event(events.DeveloperMessageEvent(session_id=agent.session.id, item=developer_item))
 
     async def handle_interrupt(self, operation: op.InterruptOperation) -> None:
         """Handle an interrupt by invoking agent.cancel() and cancelling tasks."""
@@ -256,7 +393,6 @@ class ExecutorContext:
         child_agent = Agent(
             session=child_session,
             profile=child_profile,
-            model_profile_provider=self.model_profile_provider,
         )
 
         log_debug(
@@ -439,7 +575,7 @@ class Executor:
         )
 
         # Execute to spawn the agent task in context
-        await submission.operation.execute(self.context)
+        await submission.operation.execute(handler=self.context)
 
         async def _await_agent_and_complete() -> None:
             # Wait for the agent task tied to this submission id
@@ -474,3 +610,8 @@ class Executor:
         event = self._completion_events.get(submission.id)
         if event is not None:
             event.set()
+
+
+# Static type check: ExecutorContext must satisfy OperationHandler protocol.
+# If this line causes a type error, ExecutorContext is missing required methods.
+_: type[OperationHandler] = ExecutorContext  # pyright: ignore[reportUnusedVariable]
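
Note: the trailing `_: type[OperationHandler] = ExecutorContext` line is a compile-time-only assertion. A minimal self-contained sketch of the pattern, assuming `OperationHandler` is a `typing.Protocol` (the member below is a stand-in, not the real one from `klaude_code/protocol/op_handler.py`):

```python
from typing import Protocol


class OperationHandler(Protocol):
    # Stand-in member; the real protocol declares the executor's handler methods.
    async def handle_interrupt(self, operation: object) -> None: ...


class ExecutorContext:
    # No inheritance needed: a matching method set satisfies the protocol.
    async def handle_interrupt(self, operation: object) -> None:
        print(f"interrupt: {operation!r}")


# Type checkers (e.g. pyright) verify the class against the protocol here;
# at runtime the line merely binds the class object.
_: type[OperationHandler] = ExecutorContext
```

This is presumably also why `operation.execute(self.context)` became `operation.execute(handler=self.context)`: operations can be typed against the protocol without importing the executor.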
klaude_code/core/prompts/prompt-gemini.md CHANGED
@@ -1,4 +1,4 @@
-You are a very strong reasoner and planner. Use these critical instructions to structure your plans, thoughts, and responses.
+You are an interactive CLI tool that helps users with software engineering tasks. Use the instructions below and the tools available to you to assist the user.
 
 Before taking any action (either tool calls *or* responses to the user), you must proactively, methodically, and independently plan and reason about:
 
klaude_code/core/reminders.py CHANGED
@@ -241,6 +241,28 @@ class Memory(BaseModel):
     content: str
 
 
+def get_last_user_message_image_count(session: Session) -> int:
+    """Get image count from the last user message in conversation history."""
+    for item in reversed(session.conversation_history):
+        if isinstance(item, model.ToolResultItem):
+            return 0
+        if isinstance(item, model.UserMessageItem):
+            return len(item.images) if item.images else 0
+    return 0
+
+
+async def image_reminder(session: Session) -> model.DeveloperMessageItem | None:
+    """Remind agent about images attached by user in the last message."""
+    image_count = get_last_user_message_image_count(session)
+    if image_count == 0:
+        return None
+
+    return model.DeveloperMessageItem(
+        content=f"<system-reminder>User attached {image_count} image{'s' if image_count > 1 else ''} in their message. Make sure to analyze and reference these images as needed.</system-reminder>",
+        user_image_count=image_count,
+    )
+
+
 async def memory_reminder(session: Session) -> model.DeveloperMessageItem | None:
     """CLAUDE.md AGENTS.md"""
     memory_paths = get_memory_paths()
@@ -386,6 +408,7 @@ ALL_REMINDERS = [
     memory_reminder,
     last_path_memory_reminder,
     at_file_reader_reminder,
+    image_reminder,
 ]
 
 
@@ -415,6 +438,7 @@ def load_agent_reminders(
             last_path_memory_reminder,
             at_file_reader_reminder,
             file_changed_externally_reminder,
+            image_reminder,
         ]
    )
 
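A self-contained miniature of the reversed scan in `get_last_user_message_image_count`, with stand-in item types (the real ones live in `klaude_code.protocol.model`). The walk stops at the first tool result, so the reminder only fires when the user message is the most recent conversational turn:

```python
from dataclasses import dataclass, field


@dataclass
class UserMessageItem:
    images: list[str] = field(default_factory=list)


@dataclass
class ToolResultItem:
    pass


def last_user_image_count(history: list[object]) -> int:
    for item in reversed(history):
        if isinstance(item, ToolResultItem):
            return 0  # a tool result came after the user turn; no reminder
        if isinstance(item, UserMessageItem):
            return len(item.images)
    return 0


print(last_user_image_count([UserMessageItem(images=["a.png", "b.png"])]))  # 2
print(last_user_image_count([UserMessageItem(images=["a.png"]), ToolResultItem()]))  # 0
```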
klaude_code/core/task.py CHANGED
@@ -62,6 +62,16 @@ class MetadataAccumulator:
         self._throughput_weighted_sum += usage.throughput_tps * current_output
         self._throughput_tracked_tokens += current_output
 
+        # Accumulate costs
+        if usage.input_cost is not None:
+            acc_usage.input_cost = (acc_usage.input_cost or 0.0) + usage.input_cost
+        if usage.output_cost is not None:
+            acc_usage.output_cost = (acc_usage.output_cost or 0.0) + usage.output_cost
+        if usage.cache_read_cost is not None:
+            acc_usage.cache_read_cost = (acc_usage.cache_read_cost or 0.0) + usage.cache_read_cost
+        if usage.total_cost is not None:
+            acc_usage.total_cost = (acc_usage.total_cost or 0.0) + usage.total_cost
+
         if turn_metadata.provider is not None:
             accumulated.provider = turn_metadata.provider
         if turn_metadata.model_name:
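
The accumulation rule above keeps each cost field `None` until at least one turn actually reports it, then sums across turns. A tiny standalone illustration of that `(acc or 0.0) + x` pattern:

```python
def accumulate(acc: float | None, turn: float | None) -> float | None:
    """None-preserving sum: unreported costs never materialize as 0.0."""
    if turn is None:
        return acc
    return (acc or 0.0) + turn


assert accumulate(None, None) is None   # nothing reported yet
assert accumulate(None, 0.5) == 0.5     # first reported cost
assert accumulate(0.5, 0.25) == 0.75    # subsequent turns add up
```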
klaude_code/core/tool/shell/bash_tool.py CHANGED
@@ -1,4 +1,5 @@
 import asyncio
+import re
 import subprocess
 from pathlib import Path
 
@@ -10,6 +11,9 @@ from klaude_code.core.tool.tool_abc import ToolABC, load_desc
 from klaude_code.core.tool.tool_registry import register
 from klaude_code.protocol import llm_param, model, tools
 
+# Regex to strip ANSI escape sequences from command output
+_ANSI_ESCAPE_RE = re.compile(r"\x1b\[[0-9;]*m")
+
 
 @register(tools.BASH)
 class BashTool(ToolABC):
@@ -78,8 +82,8 @@ class BashTool(ToolABC):
             check=False,
         )
 
-        stdout = completed.stdout or ""
-        stderr = completed.stderr or ""
+        stdout = _ANSI_ESCAPE_RE.sub("", completed.stdout or "")
+        stderr = _ANSI_ESCAPE_RE.sub("", completed.stderr or "")
        rc = completed.returncode
 
        if rc == 0:
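
Note that the new regex only strips SGR (color/style) sequences of the form `ESC[...m`; other ANSI codes such as cursor movement pass through. A quick demonstration:

```python
import re

# Same pattern as the diff above.
_ANSI_ESCAPE_RE = re.compile(r"\x1b\[[0-9;]*m")

colored = "\x1b[1;31merror:\x1b[0m build failed"
print(_ANSI_ESCAPE_RE.sub("", colored))  # error: build failed

# Non-SGR escapes (here: erase-line) are left intact by this pattern:
print(repr(_ANSI_ESCAPE_RE.sub("", "\x1b[2Kcleared")))  # '\x1b[2Kcleared'
```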
klaude_code/core/tool/sub_agent_tool.py CHANGED
@@ -15,7 +15,7 @@ from klaude_code.core.tool.tool_context import current_run_subtask_callback
 from klaude_code.protocol import llm_param, model
 
 if TYPE_CHECKING:
-    from klaude_code.core.sub_agent import SubAgentProfile
+    from klaude_code.protocol.sub_agent import SubAgentProfile
 
 
 class SubAgentTool(ToolABC):
klaude_code/core/tool/tool_context.py CHANGED
@@ -5,8 +5,8 @@ from contextlib import contextmanager
 from contextvars import ContextVar, Token
 from dataclasses import dataclass
 
-from klaude_code.core.sub_agent import SubAgentResult
 from klaude_code.protocol import model
+from klaude_code.protocol.sub_agent import SubAgentResult
 from klaude_code.session.session import Session
 
 
klaude_code/core/tool/tool_registry.py CHANGED
@@ -1,9 +1,9 @@
 from typing import Callable, TypeVar
 
-from klaude_code.core.sub_agent import get_sub_agent_profile, iter_sub_agent_profiles, sub_agent_tool_names
 from klaude_code.core.tool.sub_agent_tool import SubAgentTool
 from klaude_code.core.tool.tool_abc import ToolABC
 from klaude_code.protocol import llm_param, tools
+from klaude_code.protocol.sub_agent import get_sub_agent_profile, iter_sub_agent_profiles, sub_agent_tool_names
 
 _REGISTRY: dict[str, type[ToolABC]] = {}
 
klaude_code/core/tool/tool_runner.py CHANGED
@@ -3,10 +3,10 @@ from collections.abc import AsyncGenerator, Callable, Iterable, Sequence
 from dataclasses import dataclass
 
 from klaude_code import const
-from klaude_code.core.sub_agent import is_sub_agent_tool
 from klaude_code.core.tool.tool_abc import ToolABC
 from klaude_code.core.tool.truncation import truncate_tool_output
 from klaude_code.protocol import model
+from klaude_code.protocol.sub_agent import is_sub_agent_tool
 
 
 async def run_tool(tool_call: model.ToolCallItem, registry: dict[str, type[ToolABC]]) -> model.ToolResultItem:
klaude_code/core/tool/web/mermaid_tool.py CHANGED
@@ -60,7 +60,7 @@ class MermaidTool(ToolABC):
     def _build_link(code: str) -> str:
         state = {
             "code": code,
-            "mermaid": {"theme": "default"},
+            "mermaid": {"theme": "neutral"},
             "autoSync": True,
             "updateDiagram": True,
         }
klaude_code/llm/__init__.py CHANGED
@@ -1,19 +1,18 @@
 """LLM package init.
 
-Ensures built-in clients are imported so their `@register` decorators run
-and they become available via the registry.
+Imports built-in LLM clients so their ``@register`` decorators run and they
+become available via the registry.
 """
 
 from .anthropic import AnthropicClient
 from .client import LLMClientABC
 from .openai_compatible import OpenAICompatibleClient
 from .openrouter import OpenRouterClient
-from .registry import LLMClients, create_llm_client
+from .registry import create_llm_client
 from .responses import ResponsesClient
 
 __all__ = [
     "LLMClientABC",
-    "LLMClients",
     "ResponsesClient",
     "OpenAICompatibleClient",
     "OpenRouterClient",
klaude_code/llm/anthropic/client.py CHANGED
@@ -22,6 +22,7 @@ from klaude_code.llm.anthropic.input import convert_history_to_input, convert_sy
 from klaude_code.llm.client import LLMClientABC, call_with_logged_payload
 from klaude_code.llm.input_common import apply_config_defaults
 from klaude_code.llm.registry import register
+from klaude_code.llm.usage import calculate_cost
 from klaude_code.protocol import llm_param, model
 from klaude_code.trace import DebugType, log_debug
 
@@ -199,16 +200,18 @@ class AnthropicClient(LLMClientABC):
         if time_duration >= 0.15:
             throughput_tps = output_tokens / time_duration
 
+        usage = model.Usage(
+            input_tokens=input_tokens,
+            output_tokens=output_tokens,
+            cached_tokens=cached_tokens,
+            total_tokens=total_tokens,
+            context_usage_percent=context_usage_percent,
+            throughput_tps=throughput_tps,
+            first_token_latency_ms=first_token_latency_ms,
+        )
+        calculate_cost(usage, self._config.cost)
         yield model.ResponseMetadataItem(
-            usage=model.Usage(
-                input_tokens=input_tokens,
-                output_tokens=output_tokens,
-                cached_tokens=cached_tokens,
-                total_tokens=total_tokens,
-                context_usage_percent=context_usage_percent,
-                throughput_tps=throughput_tps,
-                first_token_latency_ms=first_token_latency_ms,
-            ),
+            usage=usage,
             response_id=response_id,
             model_name=str(param.model),
         )
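
`calculate_cost(usage, self._config.cost)` is called for its side effect before the usage is yielded, so it presumably fills the cost fields in place from configured pricing. A hedged sketch of that assumed behavior (the real implementation lives in `klaude_code/llm/usage.py`, which this diff only renames; the field and config names below are guesses):

```python
from dataclasses import dataclass


@dataclass
class CostConfig:
    input_per_mtok: float   # assumed: price per million input tokens
    output_per_mtok: float  # assumed: price per million output tokens


@dataclass
class Usage:
    input_tokens: int
    output_tokens: int
    input_cost: float | None = None
    output_cost: float | None = None
    total_cost: float | None = None


def calculate_cost(usage: Usage, cost: CostConfig | None) -> None:
    """Mutate usage in place; leave cost fields None when no pricing is set."""
    if cost is None:
        return
    usage.input_cost = usage.input_tokens / 1_000_000 * cost.input_per_mtok
    usage.output_cost = usage.output_tokens / 1_000_000 * cost.output_per_mtok
    usage.total_cost = usage.input_cost + usage.output_cost
```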
klaude_code/llm/openai_compatible/client.py CHANGED
@@ -8,10 +8,10 @@ from openai import APIError, RateLimitError
 
 from klaude_code.llm.client import LLMClientABC, call_with_logged_payload
 from klaude_code.llm.input_common import apply_config_defaults
-from klaude_code.llm.metadata_tracker import MetadataTracker
 from klaude_code.llm.openai_compatible.input import convert_history_to_input, convert_tool_schema
 from klaude_code.llm.openai_compatible.tool_call_accumulator import BasicToolCallAccumulator, ToolCallAccumulatorABC
 from klaude_code.llm.registry import register
+from klaude_code.llm.usage import MetadataTracker, convert_usage
 from klaude_code.protocol import llm_param, model
 from klaude_code.trace import DebugType, log_debug
 
@@ -48,7 +48,7 @@ class OpenAICompatibleClient(LLMClientABC):
         messages = convert_history_to_input(param.input, param.system, param.model)
         tools = convert_tool_schema(param.tools)
 
-        metadata_tracker = MetadataTracker()
+        metadata_tracker = MetadataTracker(cost_config=self._config.cost)
 
         extra_body = {}
         extra_headers = {"extra": json.dumps({"session_id": param.session_id})}
@@ -209,19 +209,3 @@
 
         metadata_tracker.set_response_id(response_id)
         yield metadata_tracker.finalize()
-
-
-def convert_usage(usage: openai.types.CompletionUsage, context_limit: int | None = None) -> model.Usage:
-    total_tokens = usage.total_tokens
-    context_usage_percent = (total_tokens / context_limit) * 100 if context_limit else None
-    return model.Usage(
-        input_tokens=usage.prompt_tokens,
-        cached_tokens=(usage.prompt_tokens_details.cached_tokens if usage.prompt_tokens_details else 0) or 0,
-        reasoning_tokens=(usage.completion_tokens_details.reasoning_tokens if usage.completion_tokens_details else 0)
-        or 0,
-        output_tokens=usage.completion_tokens,
-        total_tokens=total_tokens,
-        context_usage_percent=context_usage_percent,
-        throughput_tps=None,
-        first_token_latency_ms=None,
-    )
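
The removed `convert_usage` is now imported from `klaude_code/llm/usage.py` (the `metadata_tracker.py → usage.py` rename in the file list), deduplicating the copy the OpenRouter client carried below. Its context-usage formula is a plain percentage; a worked example:

```python
total_tokens = 96_000
context_limit = 200_000

context_usage_percent = (total_tokens / context_limit) * 100 if context_limit else None
print(context_usage_percent)  # 48.0 -> roughly half the context window consumed
```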
klaude_code/llm/openai_compatible/tool_call_accumulator.py CHANGED
@@ -8,7 +8,7 @@ from klaude_code.protocol import model
 
 class ToolCallAccumulatorABC(ABC):
     @abstractmethod
-    def add(self, chunks: list[ChoiceDeltaToolCall]):
+    def add(self, chunks: list[ChoiceDeltaToolCall]) -> None:
         pass
 
     @abstractmethod
@@ -50,7 +50,7 @@ class BasicToolCallAccumulator(ToolCallAccumulatorABC, BaseModel):
     chunks_by_step: list[list[ChoiceDeltaToolCall]] = Field(default_factory=list)  # pyright: ignore[reportUnknownVariableType]
     response_id: str | None = None
 
-    def add(self, chunks: list[ChoiceDeltaToolCall]):
+    def add(self, chunks: list[ChoiceDeltaToolCall]) -> None:
         self.chunks_by_step.append(chunks)
 
     def get(self) -> list[model.ToolCallItem]:
klaude_code/llm/openrouter/client.py CHANGED
@@ -6,12 +6,12 @@ import openai
 
 from klaude_code.llm.client import LLMClientABC, call_with_logged_payload
 from klaude_code.llm.input_common import apply_config_defaults
-from klaude_code.llm.metadata_tracker import MetadataTracker
 from klaude_code.llm.openai_compatible.input import convert_tool_schema
 from klaude_code.llm.openai_compatible.tool_call_accumulator import BasicToolCallAccumulator, ToolCallAccumulatorABC
 from klaude_code.llm.openrouter.input import convert_history_to_input, is_claude_model
 from klaude_code.llm.openrouter.reasoning_handler import ReasoningDetail, ReasoningStreamHandler
 from klaude_code.llm.registry import register
+from klaude_code.llm.usage import MetadataTracker, convert_usage
 from klaude_code.protocol import llm_param, model
 from klaude_code.trace import DebugType, log, log_debug
 
@@ -38,7 +38,7 @@ class OpenRouterClient(LLMClientABC):
         messages = convert_history_to_input(param.input, param.system, param.model)
         tools = convert_tool_schema(param.tools)
 
-        metadata_tracker = MetadataTracker()
+        metadata_tracker = MetadataTracker(cost_config=self._config.cost)
 
         extra_body: dict[str, object] = {
             "usage": {"include": True}  # To get the cache tokens at the end of the response
@@ -198,19 +198,3 @@
 
         metadata_tracker.set_response_id(response_id)
         yield metadata_tracker.finalize()
-
-
-def convert_usage(usage: openai.types.CompletionUsage, context_limit: int | None = None) -> model.Usage:
-    total_tokens = usage.total_tokens
-    context_usage_percent = (total_tokens / context_limit) * 100 if context_limit else None
-    return model.Usage(
-        input_tokens=usage.prompt_tokens,
-        cached_tokens=(usage.prompt_tokens_details.cached_tokens if usage.prompt_tokens_details else 0) or 0,
-        reasoning_tokens=(usage.completion_tokens_details.reasoning_tokens if usage.completion_tokens_details else 0)
-        or 0,
-        output_tokens=usage.completion_tokens,
-        total_tokens=total_tokens,
-        context_usage_percent=context_usage_percent,
-        throughput_tps=None,
-        first_token_latency_ms=None,
-    )
klaude_code/llm/openrouter/input.py CHANGED
@@ -13,11 +13,15 @@ from klaude_code.llm.input_common import AssistantGroup, ToolGroup, UserGroup, m
 from klaude_code.protocol import model
 
 
-def is_claude_model(model_name: str | None):
+def is_claude_model(model_name: str | None) -> bool:
+    """Return True if the model name represents an Anthropic Claude model."""
+
     return model_name is not None and model_name.startswith("anthropic/claude")
 
 
-def is_gemini_model(model_name: str | None):
+def is_gemini_model(model_name: str | None) -> bool:
+    """Return True if the model name represents a Google Gemini model."""
+
     return model_name is not None and model_name.startswith("google/gemini")
 
 
klaude_code/llm/registry.py CHANGED
@@ -1,14 +1,7 @@
-from __future__ import annotations
-
-from dataclasses import dataclass, field
-from typing import TYPE_CHECKING, Callable, TypeVar
+from typing import Callable, TypeVar
 
 from klaude_code.llm.client import LLMClientABC
-from klaude_code.protocol import llm_param, tools
-from klaude_code.trace import DebugType, log_debug
-
-if TYPE_CHECKING:
-    from klaude_code.config import Config
+from klaude_code.protocol import llm_param
 
 _REGISTRY: dict[llm_param.LLMClientProtocol, type[LLMClientABC]] = {}
 
@@ -27,65 +20,3 @@ def create_llm_client(config: llm_param.LLMConfigParameter) -> LLMClientABC:
     if config.protocol not in _REGISTRY:
         raise ValueError(f"Unknown LLMClient protocol: {config.protocol}")
     return _REGISTRY[config.protocol].create(config)
-
-
-@dataclass
-class LLMClients:
-    """Container for LLM clients used by main agent and sub-agents."""
-
-    main: LLMClientABC
-    sub_clients: dict[tools.SubAgentType, LLMClientABC] = field(default_factory=lambda: {})
-
-    def get_client(self, sub_agent_type: tools.SubAgentType | None = None) -> LLMClientABC:
-        """Get client for given sub-agent type, or main client if None."""
-        if sub_agent_type is None:
-            return self.main
-        return self.sub_clients.get(sub_agent_type) or self.main
-
-    @classmethod
-    def from_config(
-        cls,
-        config: Config,
-        model_override: str | None = None,
-        enabled_sub_agents: list[tools.SubAgentType] | None = None,
-    ) -> LLMClients:
-        """Create LLMClients from application config.
-
-        Args:
-            config: Application configuration
-            model_override: Optional model name to override the main model
-            enabled_sub_agents: List of sub-agent types to initialize clients for
-
-        Returns:
-            LLMClients instance
-        """
-        from klaude_code.core.sub_agent import get_sub_agent_profile
-
-        # Resolve main agent LLM config
-        if model_override:
-            llm_config = config.get_model_config(model_override)
-        else:
-            llm_config = config.get_main_model_config()
-
-        log_debug(
-            "Main LLM config",
-            llm_config.model_dump_json(exclude_none=True),
-            style="yellow",
-            debug_type=DebugType.LLM_CONFIG,
-        )
-
-        main_client = create_llm_client(llm_config)
-        sub_clients: dict[tools.SubAgentType, LLMClientABC] = {}
-
-        # Initialize sub-agent clients
-        for sub_agent_type in enabled_sub_agents or []:
-            model_name = config.subagent_models.get(sub_agent_type)
-            if not model_name:
-                continue
-            profile = get_sub_agent_profile(sub_agent_type)
-            if not profile.enabled_for_model(main_client.model_name):
-                continue
-            sub_llm_config = config.get_model_config(model_name)
-            sub_clients[sub_agent_type] = create_llm_client(sub_llm_config)
-
-        return cls(main=main_client, sub_clients=sub_clients)
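
With `LLMClients` relocated to executor.py, `registry.py` is reduced to the decorator-based client registry. A minimal runnable sketch of that pattern, with a plain string standing in for `llm_param.LLMClientProtocol` and a bare class for `LLMClientABC`:

```python
from typing import Callable, TypeVar

T = TypeVar("T")

_REGISTRY: dict[str, type] = {}


def register(protocol: str) -> Callable[[type[T]], type[T]]:
    """Class decorator that maps a protocol key to its client class."""
    def decorator(cls: type[T]) -> type[T]:
        _REGISTRY[protocol] = cls
        return cls

    return decorator


@register("openai_compatible")
class OpenAICompatibleClient:
    pass


def create_llm_client(protocol: str) -> object:
    if protocol not in _REGISTRY:
        raise ValueError(f"Unknown LLMClient protocol: {protocol}")
    return _REGISTRY[protocol]()


print(type(create_llm_client("openai_compatible")).__name__)  # OpenAICompatibleClient
```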
klaude_code/llm/responses/client.py CHANGED
@@ -11,6 +11,7 @@ from klaude_code.llm.client import LLMClientABC, call_with_logged_payload
 from klaude_code.llm.input_common import apply_config_defaults
 from klaude_code.llm.registry import register
 from klaude_code.llm.responses.input import convert_history_to_input, convert_tool_schema
+from klaude_code.llm.usage import calculate_cost
 from klaude_code.protocol import llm_param, model
 from klaude_code.trace import DebugType, log_debug
 
@@ -185,6 +186,7 @@ class ResponsesClient(LLMClientABC):
             throughput_tps=throughput_tps,
             first_token_latency_ms=first_token_latency_ms,
         )
+        calculate_cost(usage, self._config.cost)
         yield model.ResponseMetadataItem(
             usage=usage,
             response_id=response_id,