klaude-code 1.2.2__py3-none-any.whl → 1.2.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- klaude_code/cli/main.py +7 -0
- klaude_code/cli/runtime.py +6 -6
- klaude_code/command/__init__.py +9 -5
- klaude_code/command/clear_cmd.py +3 -24
- klaude_code/command/command_abc.py +36 -1
- klaude_code/command/export_cmd.py +16 -20
- klaude_code/command/help_cmd.py +1 -0
- klaude_code/command/model_cmd.py +3 -30
- klaude_code/command/{prompt-update-dev-doc.md → prompt-dev-docs-update.md} +3 -2
- klaude_code/command/{prompt-dev-doc.md → prompt-dev-docs.md} +3 -2
- klaude_code/command/prompt-init.md +2 -5
- klaude_code/command/prompt_command.py +3 -3
- klaude_code/command/registry.py +6 -7
- klaude_code/command/status_cmd.py +111 -0
- klaude_code/config/config.py +1 -1
- klaude_code/config/list_model.py +1 -1
- klaude_code/const/__init__.py +1 -1
- klaude_code/core/agent.py +2 -11
- klaude_code/core/executor.py +155 -14
- klaude_code/core/prompts/prompt-gemini.md +1 -1
- klaude_code/core/reminders.py +24 -0
- klaude_code/core/task.py +10 -0
- klaude_code/core/tool/shell/bash_tool.py +6 -2
- klaude_code/core/tool/sub_agent_tool.py +1 -1
- klaude_code/core/tool/tool_context.py +1 -1
- klaude_code/core/tool/tool_registry.py +1 -1
- klaude_code/core/tool/tool_runner.py +1 -1
- klaude_code/core/tool/web/mermaid_tool.py +1 -1
- klaude_code/llm/__init__.py +3 -4
- klaude_code/llm/anthropic/client.py +12 -9
- klaude_code/llm/openai_compatible/client.py +2 -18
- klaude_code/llm/openai_compatible/tool_call_accumulator.py +2 -2
- klaude_code/llm/openrouter/client.py +2 -18
- klaude_code/llm/openrouter/input.py +6 -2
- klaude_code/llm/registry.py +2 -71
- klaude_code/llm/responses/client.py +2 -0
- klaude_code/llm/{metadata_tracker.py → usage.py} +49 -2
- klaude_code/protocol/commands.py +1 -0
- klaude_code/protocol/llm_param.py +12 -0
- klaude_code/protocol/model.py +30 -3
- klaude_code/protocol/op.py +14 -14
- klaude_code/protocol/op_handler.py +28 -0
- klaude_code/protocol/tools.py +0 -2
- klaude_code/session/export.py +124 -35
- klaude_code/session/session.py +1 -1
- klaude_code/session/templates/export_session.html +383 -39
- klaude_code/ui/__init__.py +6 -2
- klaude_code/ui/modes/exec/display.py +26 -0
- klaude_code/ui/modes/repl/event_handler.py +5 -1
- klaude_code/ui/renderers/developer.py +62 -11
- klaude_code/ui/renderers/metadata.py +33 -24
- klaude_code/ui/renderers/sub_agent.py +1 -1
- klaude_code/ui/renderers/tools.py +2 -2
- klaude_code/ui/renderers/user_input.py +18 -22
- klaude_code/ui/rich/status.py +13 -2
- {klaude_code-1.2.2.dist-info → klaude_code-1.2.4.dist-info}/METADATA +1 -1
- {klaude_code-1.2.2.dist-info → klaude_code-1.2.4.dist-info}/RECORD +60 -58
- /klaude_code/{core → protocol}/sub_agent.py +0 -0
- {klaude_code-1.2.2.dist-info → klaude_code-1.2.4.dist-info}/WHEEL +0 -0
- {klaude_code-1.2.2.dist-info → klaude_code-1.2.4.dist-info}/entry_points.txt +0 -0
klaude_code/core/executor.py
CHANGED
@@ -5,19 +5,86 @@ This module implements the submission_loop equivalent for klaude,
 handling operations submitted from the CLI and coordinating with agents.
 """
 
+from __future__ import annotations
+
 import asyncio
 from dataclasses import dataclass
+from dataclasses import field as dataclass_field
 
-from klaude_code.command import dispatch_command
+from klaude_code.command import InputAction, InputActionType, dispatch_command
+from klaude_code.config import Config, load_config
 from klaude_code.core.agent import Agent, DefaultModelProfileProvider, ModelProfileProvider
-from klaude_code.core.sub_agent import SubAgentResult
 from klaude_code.core.tool import current_run_subtask_callback
-from klaude_code.llm import
-from klaude_code.
+from klaude_code.llm.client import LLMClientABC
+from klaude_code.llm.registry import create_llm_client
+from klaude_code.protocol import commands, events, model, op
+from klaude_code.protocol.op_handler import OperationHandler
+from klaude_code.protocol.sub_agent import SubAgentResult, get_sub_agent_profile
+from klaude_code.protocol.tools import SubAgentType
 from klaude_code.session.session import Session
 from klaude_code.trace import DebugType, log_debug
 
 
+@dataclass
+class LLMClients:
+    """Container for LLM clients used by main agent and sub-agents."""
+
+    main: LLMClientABC
+    sub_clients: dict[SubAgentType, LLMClientABC] = dataclass_field(default_factory=lambda: {})
+
+    def get_client(self, sub_agent_type: SubAgentType | None = None) -> LLMClientABC:
+        """Get client for given sub-agent type, or main client if None."""
+        if sub_agent_type is None:
+            return self.main
+        return self.sub_clients.get(sub_agent_type) or self.main
+
+    @classmethod
+    def from_config(
+        cls,
+        config: Config,
+        model_override: str | None = None,
+        enabled_sub_agents: list[SubAgentType] | None = None,
+    ) -> LLMClients:
+        """Create LLMClients from application config.
+
+        Args:
+            config: Application configuration
+            model_override: Optional model name to override the main model
+            enabled_sub_agents: List of sub-agent types to initialize clients for
+
+        Returns:
+            LLMClients instance
+        """
+        # Resolve main agent LLM config
+        if model_override:
+            llm_config = config.get_model_config(model_override)
+        else:
+            llm_config = config.get_main_model_config()
+
+        log_debug(
+            "Main LLM config",
+            llm_config.model_dump_json(exclude_none=True),
+            style="yellow",
+            debug_type=DebugType.LLM_CONFIG,
+        )
+
+        main_client = create_llm_client(llm_config)
+        sub_clients: dict[SubAgentType, LLMClientABC] = {}
+
+        # Initialize sub-agent clients
+        for sub_agent_type in enabled_sub_agents or []:
+            model_name = config.subagent_models.get(sub_agent_type)
+            if not model_name:
+                continue
+            profile = get_sub_agent_profile(sub_agent_type)
+            if not profile.enabled_for_model(main_client.model_name):
+                continue
+            sub_llm_config = config.get_model_config(model_name)
+            sub_clients[sub_agent_type] = create_llm_client(sub_llm_config)
+
+        return cls(main=main_client, sub_clients=sub_clients)
+
+
 @dataclass
 class ActiveTask:
     """Track an in-flight task and its owning session."""
@@ -32,6 +99,8 @@ class ExecutorContext:
 
     This context is passed to operations when they execute, allowing them
     to access shared resources like the event queue and active sessions.
+
+    Implements the OperationHandler protocol via structural subtyping.
     """
 
     def __init__(
@@ -65,7 +134,6 @@ class ExecutorContext:
         agent = Agent(
             session=session,
             profile=profile,
-            model_profile_provider=self.model_profile_provider,
         )
 
         async for evt in agent.replay_history():
@@ -109,8 +177,12 @@ class ExecutorContext:
         )
 
         result = await dispatch_command(user_input.text, agent)
-
-
+
+        actions: list[InputAction] = list(result.actions or [])
+
+        has_run_agent_action = any(action.type is InputActionType.RUN_AGENT for action in actions)
+        if not has_run_agent_action:
+            # No async agent task will run, append user message directly
             agent.session.append_history([model.UserMessageItem(content=user_input.text, images=user_input.images)])
 
         if result.events:
@@ -120,15 +192,80 @@ class ExecutorContext:
             for evt in result.events:
                 await self.emit_event(evt)
 
-
-
-
-
+        for action in actions:
+            await self._run_input_action(action, operation, agent)
+
+    async def _run_input_action(self, action: InputAction, operation: op.UserInputOperation, agent: Agent) -> None:
+        if operation.session_id is None:
+            raise ValueError("session_id cannot be None for input actions")
+
+        session_id = operation.session_id
+
+        if action.type == InputActionType.RUN_AGENT:
+            task_input = model.UserInputPayload(text=action.text, images=operation.input.images)
+
+            existing_active = self.active_tasks.get(operation.id)
+            if existing_active is not None and not existing_active.task.done():
+                raise RuntimeError(f"Active task already registered for operation {operation.id}")
+
             task: asyncio.Task[None] = asyncio.create_task(
                 self._run_agent_task(agent, task_input, operation.id, session_id)
             )
             self.active_tasks[operation.id] = ActiveTask(task=task, session_id=session_id)
-
+            return
+
+        if action.type == InputActionType.CHANGE_MODEL:
+            if not action.model_name:
+                raise ValueError("ChangeModel action requires model_name")
+
+            await self._apply_model_change(agent, action.model_name)
+            return
+
+        if action.type == InputActionType.CLEAR:
+            await self._apply_clear(agent)
+            return
+
+        raise ValueError(f"Unsupported input action type: {action.type}")
+
+    async def _apply_model_change(self, agent: Agent, model_name: str) -> None:
+        config = load_config()
+        if config is None:
+            raise ValueError("Configuration must be initialized before changing model")
+
+        llm_config = config.get_model_config(model_name)
+        llm_client = create_llm_client(llm_config)
+        agent.set_model_profile(self.model_profile_provider.build_profile(llm_client))
+
+        developer_item = model.DeveloperMessageItem(
+            content=f"switched to model: {model_name}",
+            command_output=model.CommandOutput(command_name=commands.CommandName.MODEL),
+        )
+        agent.session.append_history([developer_item])
+
+        await self.emit_event(events.DeveloperMessageEvent(session_id=agent.session.id, item=developer_item))
+        await self.emit_event(events.WelcomeEvent(llm_config=llm_config, work_dir=str(agent.session.work_dir)))
+
+    async def _apply_clear(self, agent: Agent) -> None:
+        old_session_id = agent.session.id
+
+        # Create a new session instance to replace the current one
+        new_session = Session(work_dir=agent.session.work_dir)
+        new_session.model_name = agent.session.model_name
+
+        # Replace the agent's session with the new one
+        agent.session = new_session
+        agent.session.save()
+
+        # Update the active_agents mapping
+        self.active_agents.pop(old_session_id, None)
+        self.active_agents[new_session.id] = agent
+
+        developer_item = model.DeveloperMessageItem(
+            content="started new conversation",
+            command_output=model.CommandOutput(command_name=commands.CommandName.CLEAR),
+        )
+
+        await self.emit_event(events.DeveloperMessageEvent(session_id=agent.session.id, item=developer_item))
 
     async def handle_interrupt(self, operation: op.InterruptOperation) -> None:
         """Handle an interrupt by invoking agent.cancel() and cancelling tasks."""
@@ -256,7 +393,6 @@ class ExecutorContext:
         child_agent = Agent(
             session=child_session,
             profile=child_profile,
-            model_profile_provider=self.model_profile_provider,
         )
 
         log_debug(
@@ -439,7 +575,7 @@ class Executor:
             )
 
             # Execute to spawn the agent task in context
-            await submission.operation.execute(self.context)
+            await submission.operation.execute(handler=self.context)
 
             async def _await_agent_and_complete() -> None:
                 # Wait for the agent task tied to this submission id
@@ -474,3 +610,8 @@
         event = self._completion_events.get(submission.id)
         if event is not None:
             event.set()
+
+
+# Static type check: ExecutorContext must satisfy OperationHandler protocol.
+# If this line causes a type error, ExecutorContext is missing required methods.
+_: type[OperationHandler] = ExecutorContext  # pyright: ignore[reportUnusedVariable]
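The trailing assignment above is the usual trick for making a type checker prove structural conformance to a Protocol without inheritance. A minimal sketch of the idiom in isolation (the one-method protocol here is an invented stand-in, not the real OperationHandler):

    from typing import Protocol


    class OperationHandler(Protocol):
        async def handle_interrupt(self, operation: object) -> None: ...


    class ExecutorContext:
        async def handle_interrupt(self, operation: object) -> None:
            print("cancelling active tasks")


    # The annotation forces pyright/mypy to verify that ExecutorContext
    # provides every protocol member; at runtime this is just a plain
    # class-object assignment with no inheritance involved.
    _: type[OperationHandler] = ExecutorContext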
klaude_code/core/prompts/prompt-gemini.md
CHANGED
@@ -1,4 +1,4 @@
-You are
+You are an interactive CLI tool that helps users with software engineering tasks. Use the instructions below and the tools available to you to assist the user.
 
 Before taking any action (either tool calls *or* responses to the user), you must proactively, methodically, and independently plan and reason about:
 
klaude_code/core/reminders.py
CHANGED
@@ -241,6 +241,28 @@ class Memory(BaseModel):
     content: str
 
 
+def get_last_user_message_image_count(session: Session) -> int:
+    """Get image count from the last user message in conversation history."""
+    for item in reversed(session.conversation_history):
+        if isinstance(item, model.ToolResultItem):
+            return 0
+        if isinstance(item, model.UserMessageItem):
+            return len(item.images) if item.images else 0
+    return 0
+
+
+async def image_reminder(session: Session) -> model.DeveloperMessageItem | None:
+    """Remind agent about images attached by user in the last message."""
+    image_count = get_last_user_message_image_count(session)
+    if image_count == 0:
+        return None
+
+    return model.DeveloperMessageItem(
+        content=f"<system-reminder>User attached {image_count} image{'s' if image_count > 1 else ''} in their message. Make sure to analyze and reference these images as needed.</system-reminder>",
+        user_image_count=image_count,
+    )
+
+
 async def memory_reminder(session: Session) -> model.DeveloperMessageItem | None:
     """CLAUDE.md AGENTS.md"""
     memory_paths = get_memory_paths()
@@ -386,6 +408,7 @@ ALL_REMINDERS = [
     memory_reminder,
     last_path_memory_reminder,
     at_file_reader_reminder,
+    image_reminder,
 ]
 
 
@@ -415,6 +438,7 @@ def load_agent_reminders(
         last_path_memory_reminder,
         at_file_reader_reminder,
         file_changed_externally_reminder,
+        image_reminder,
     ]
 )
 
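The reverse scan in get_last_user_message_image_count stops at the first ToolResultItem, so the image reminder fires only when the most recent turn is a fresh user message rather than a tool round-trip. A self-contained sketch of that logic with stand-in item types (the real ones live in klaude_code.protocol.model):

    from dataclasses import dataclass, field


    @dataclass
    class UserMessageItem:
        images: list[str] = field(default_factory=list)


    @dataclass
    class ToolResultItem:
        output: str = ""


    def last_user_image_count(history: list[object]) -> int:
        # Walk newest-to-oldest; hitting a tool result first means the last
        # turn was not a fresh user message, so no reminder should fire.
        for item in reversed(history):
            if isinstance(item, ToolResultItem):
                return 0
            if isinstance(item, UserMessageItem):
                return len(item.images)
        return 0


    assert last_user_image_count([UserMessageItem(images=["a.png", "b.png"])]) == 2
    assert last_user_image_count([UserMessageItem(images=["a.png"]), ToolResultItem()]) == 0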
klaude_code/core/task.py
CHANGED
@@ -62,6 +62,16 @@ class MetadataAccumulator:
             self._throughput_weighted_sum += usage.throughput_tps * current_output
             self._throughput_tracked_tokens += current_output
 
+        # Accumulate costs
+        if usage.input_cost is not None:
+            acc_usage.input_cost = (acc_usage.input_cost or 0.0) + usage.input_cost
+        if usage.output_cost is not None:
+            acc_usage.output_cost = (acc_usage.output_cost or 0.0) + usage.output_cost
+        if usage.cache_read_cost is not None:
+            acc_usage.cache_read_cost = (acc_usage.cache_read_cost or 0.0) + usage.cache_read_cost
+        if usage.total_cost is not None:
+            acc_usage.total_cost = (acc_usage.total_cost or 0.0) + usage.total_cost
+
         if turn_metadata.provider is not None:
             accumulated.provider = turn_metadata.provider
         if turn_metadata.model_name:
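The (acc_usage.input_cost or 0.0) + usage.input_cost pattern exists because the accumulator's cost fields start as None and should stay None when a provider never reports cost, rather than being coerced to 0.0. A small illustration of the same None-preserving accumulation:

    def accumulate(acc: float | None, turn: float | None) -> float | None:
        # Nothing reported this turn: keep the prior value, which may be None.
        if turn is None:
            return acc
        # The first real value promotes None to 0.0 before adding.
        return (acc or 0.0) + turn


    total: float | None = None
    for turn_cost in [None, 0.0021, None, 0.0013]:
        total = accumulate(total, turn_cost)

    assert total is not None and abs(total - 0.0034) < 1e-9
    assert accumulate(None, None) is None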
klaude_code/core/tool/shell/bash_tool.py
CHANGED
@@ -1,4 +1,5 @@
 import asyncio
+import re
 import subprocess
 from pathlib import Path
 
@@ -10,6 +11,9 @@ from klaude_code.core.tool.tool_abc import ToolABC, load_desc
 from klaude_code.core.tool.tool_registry import register
 from klaude_code.protocol import llm_param, model, tools
 
+# Regex to strip ANSI escape sequences from command output
+_ANSI_ESCAPE_RE = re.compile(r"\x1b\[[0-9;]*m")
+
 
 @register(tools.BASH)
 class BashTool(ToolABC):
@@ -78,8 +82,8 @@ class BashTool(ToolABC):
             check=False,
         )
 
-        stdout = completed.stdout or ""
-        stderr = completed.stderr or ""
+        stdout = _ANSI_ESCAPE_RE.sub("", completed.stdout or "")
+        stderr = _ANSI_ESCAPE_RE.sub("", completed.stderr or "")
        rc = completed.returncode
 
        if rc == 0:
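The new regex matches only SGR (color and style) sequences of the form ESC [ ... m; other terminal escapes, such as cursor movement or line clearing, pass through untouched. A quick check of the exact pattern from the diff:

    import re

    # Same pattern as bash_tool.py: CSI parameters terminated by 'm' (SGR).
    _ANSI_ESCAPE_RE = re.compile(r"\x1b\[[0-9;]*m")

    colored = "\x1b[1;31merror:\x1b[0m something failed"
    assert _ANSI_ESCAPE_RE.sub("", colored) == "error: something failed"

    # A non-SGR escape (erase-line) is deliberately left alone.
    assert _ANSI_ESCAPE_RE.sub("", "\x1b[2Kdone") == "\x1b[2Kdone"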
klaude_code/core/tool/sub_agent_tool.py
CHANGED
@@ -15,7 +15,7 @@ from klaude_code.core.tool.tool_context import current_run_subtask_callback
 from klaude_code.protocol import llm_param, model
 
 if TYPE_CHECKING:
-    from klaude_code.
+    from klaude_code.protocol.sub_agent import SubAgentProfile
 
 
 class SubAgentTool(ToolABC):
klaude_code/core/tool/tool_context.py
CHANGED
@@ -5,8 +5,8 @@ from contextlib import contextmanager
 from contextvars import ContextVar, Token
 from dataclasses import dataclass
 
-from klaude_code.core.sub_agent import SubAgentResult
 from klaude_code.protocol import model
+from klaude_code.protocol.sub_agent import SubAgentResult
 from klaude_code.session.session import Session
 
 
klaude_code/core/tool/tool_registry.py
CHANGED
@@ -1,9 +1,9 @@
 from typing import Callable, TypeVar
 
-from klaude_code.core.sub_agent import get_sub_agent_profile, iter_sub_agent_profiles, sub_agent_tool_names
 from klaude_code.core.tool.sub_agent_tool import SubAgentTool
 from klaude_code.core.tool.tool_abc import ToolABC
 from klaude_code.protocol import llm_param, tools
+from klaude_code.protocol.sub_agent import get_sub_agent_profile, iter_sub_agent_profiles, sub_agent_tool_names
 
 _REGISTRY: dict[str, type[ToolABC]] = {}
 
klaude_code/core/tool/tool_runner.py
CHANGED
@@ -3,10 +3,10 @@ from collections.abc import AsyncGenerator, Callable, Iterable, Sequence
 from dataclasses import dataclass
 
 from klaude_code import const
-from klaude_code.core.sub_agent import is_sub_agent_tool
 from klaude_code.core.tool.tool_abc import ToolABC
 from klaude_code.core.tool.truncation import truncate_tool_output
 from klaude_code.protocol import model
+from klaude_code.protocol.sub_agent import is_sub_agent_tool
 
 
 async def run_tool(tool_call: model.ToolCallItem, registry: dict[str, type[ToolABC]]) -> model.ToolResultItem:
klaude_code/llm/__init__.py
CHANGED
@@ -1,19 +1,18 @@
 """LLM package init.
 
-
-
+Imports built-in LLM clients so their ``@register`` decorators run and they
+become available via the registry.
 """
 
 from .anthropic import AnthropicClient
 from .client import LLMClientABC
 from .openai_compatible import OpenAICompatibleClient
 from .openrouter import OpenRouterClient
-from .registry import
+from .registry import create_llm_client
 from .responses import ResponsesClient
 
 __all__ = [
     "LLMClientABC",
-    "LLMClients",
     "ResponsesClient",
     "OpenAICompatibleClient",
     "OpenRouterClient",
klaude_code/llm/anthropic/client.py
CHANGED
@@ -22,6 +22,7 @@ from klaude_code.llm.anthropic.input import convert_history_to_input, convert_sy
 from klaude_code.llm.client import LLMClientABC, call_with_logged_payload
 from klaude_code.llm.input_common import apply_config_defaults
 from klaude_code.llm.registry import register
+from klaude_code.llm.usage import calculate_cost
 from klaude_code.protocol import llm_param, model
 from klaude_code.trace import DebugType, log_debug
 
@@ -199,16 +200,18 @@ class AnthropicClient(LLMClientABC):
         if time_duration >= 0.15:
             throughput_tps = output_tokens / time_duration
 
+        usage = model.Usage(
+            input_tokens=input_tokens,
+            output_tokens=output_tokens,
+            cached_tokens=cached_tokens,
+            total_tokens=total_tokens,
+            context_usage_percent=context_usage_percent,
+            throughput_tps=throughput_tps,
+            first_token_latency_ms=first_token_latency_ms,
+        )
+        calculate_cost(usage, self._config.cost)
         yield model.ResponseMetadataItem(
-            usage=
-                input_tokens=input_tokens,
-                output_tokens=output_tokens,
-                cached_tokens=cached_tokens,
-                total_tokens=total_tokens,
-                context_usage_percent=context_usage_percent,
-                throughput_tps=throughput_tps,
-                first_token_latency_ms=first_token_latency_ms,
-            ),
+            usage=usage,
             response_id=response_id,
             model_name=str(param.model),
         )
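calculate_cost(usage, self._config.cost) is called for its side effect and the same usage object is yielded afterwards, which implies it fills the Usage cost fields in place. The diff does not include its body, so the following is only a plausible sketch; the CostConfig field names and the cached-token discount are assumptions:

    from dataclasses import dataclass


    @dataclass
    class CostConfig:
        # Assumed per-million-token rates in USD.
        input_per_mtok: float
        output_per_mtok: float
        cache_read_per_mtok: float


    @dataclass
    class Usage:
        input_tokens: int
        output_tokens: int
        cached_tokens: int
        input_cost: float | None = None
        output_cost: float | None = None
        cache_read_cost: float | None = None
        total_cost: float | None = None


    def calculate_cost(usage: Usage, cost: CostConfig | None) -> None:
        if cost is None:
            return  # no pricing configured: leave every cost field as None
        # Assumption: cached input tokens are billed at the cache-read rate.
        usage.input_cost = (usage.input_tokens - usage.cached_tokens) / 1e6 * cost.input_per_mtok
        usage.cache_read_cost = usage.cached_tokens / 1e6 * cost.cache_read_per_mtok
        usage.output_cost = usage.output_tokens / 1e6 * cost.output_per_mtok
        usage.total_cost = usage.input_cost + usage.cache_read_cost + usage.output_cost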
klaude_code/llm/openai_compatible/client.py
CHANGED
@@ -8,10 +8,10 @@ from openai import APIError, RateLimitError
 
 from klaude_code.llm.client import LLMClientABC, call_with_logged_payload
 from klaude_code.llm.input_common import apply_config_defaults
-from klaude_code.llm.metadata_tracker import MetadataTracker
 from klaude_code.llm.openai_compatible.input import convert_history_to_input, convert_tool_schema
 from klaude_code.llm.openai_compatible.tool_call_accumulator import BasicToolCallAccumulator, ToolCallAccumulatorABC
 from klaude_code.llm.registry import register
+from klaude_code.llm.usage import MetadataTracker, convert_usage
 from klaude_code.protocol import llm_param, model
 from klaude_code.trace import DebugType, log_debug
 
@@ -48,7 +48,7 @@ class OpenAICompatibleClient(LLMClientABC):
         messages = convert_history_to_input(param.input, param.system, param.model)
         tools = convert_tool_schema(param.tools)
 
-        metadata_tracker = MetadataTracker()
+        metadata_tracker = MetadataTracker(cost_config=self._config.cost)
 
         extra_body = {}
         extra_headers = {"extra": json.dumps({"session_id": param.session_id})}
@@ -209,19 +209,3 @@ class OpenAICompatibleClient(LLMClientABC):
 
         metadata_tracker.set_response_id(response_id)
         yield metadata_tracker.finalize()
-
-
-def convert_usage(usage: openai.types.CompletionUsage, context_limit: int | None = None) -> model.Usage:
-    total_tokens = usage.total_tokens
-    context_usage_percent = (total_tokens / context_limit) * 100 if context_limit else None
-    return model.Usage(
-        input_tokens=usage.prompt_tokens,
-        cached_tokens=(usage.prompt_tokens_details.cached_tokens if usage.prompt_tokens_details else 0) or 0,
-        reasoning_tokens=(usage.completion_tokens_details.reasoning_tokens if usage.completion_tokens_details else 0)
-        or 0,
-        output_tokens=usage.completion_tokens,
-        total_tokens=total_tokens,
-        context_usage_percent=context_usage_percent,
-        throughput_tps=None,
-        first_token_latency_ms=None,
-    )
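The two deleted copies of convert_usage were identical, which is why the function now lives once in klaude_code/llm/usage.py next to MetadataTracker. The defensive "... or 0" tails it carries exist because, in recent openai SDKs, both the details objects and their token counts are optional:

    # Assumes a recent openai SDK where these pydantic models exist.
    from openai.types.completion_usage import CompletionUsage, PromptTokensDetails

    u = CompletionUsage(
        completion_tokens=200,
        prompt_tokens=1800,
        total_tokens=2000,
        prompt_tokens_details=PromptTokensDetails(cached_tokens=None),
    )
    # Either the details object or cached_tokens itself may be None.
    cached = (u.prompt_tokens_details.cached_tokens if u.prompt_tokens_details else 0) or 0
    assert cached == 0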
klaude_code/llm/openai_compatible/tool_call_accumulator.py
CHANGED
@@ -8,7 +8,7 @@ from klaude_code.protocol import model
 
 class ToolCallAccumulatorABC(ABC):
     @abstractmethod
-    def add(self, chunks: list[ChoiceDeltaToolCall]):
+    def add(self, chunks: list[ChoiceDeltaToolCall]) -> None:
         pass
 
     @abstractmethod
@@ -50,7 +50,7 @@ class BasicToolCallAccumulator(ToolCallAccumulatorABC, BaseModel):
     chunks_by_step: list[list[ChoiceDeltaToolCall]] = Field(default_factory=list)  # pyright: ignore[reportUnknownVariableType]
     response_id: str | None = None
 
-    def add(self, chunks: list[ChoiceDeltaToolCall]):
+    def add(self, chunks: list[ChoiceDeltaToolCall]) -> None:
         self.chunks_by_step.append(chunks)
 
     def get(self) -> list[model.ToolCallItem]:
klaude_code/llm/openrouter/client.py
CHANGED
@@ -6,12 +6,12 @@ import openai
 
 from klaude_code.llm.client import LLMClientABC, call_with_logged_payload
 from klaude_code.llm.input_common import apply_config_defaults
-from klaude_code.llm.metadata_tracker import MetadataTracker
 from klaude_code.llm.openai_compatible.input import convert_tool_schema
 from klaude_code.llm.openai_compatible.tool_call_accumulator import BasicToolCallAccumulator, ToolCallAccumulatorABC
 from klaude_code.llm.openrouter.input import convert_history_to_input, is_claude_model
 from klaude_code.llm.openrouter.reasoning_handler import ReasoningDetail, ReasoningStreamHandler
 from klaude_code.llm.registry import register
+from klaude_code.llm.usage import MetadataTracker, convert_usage
 from klaude_code.protocol import llm_param, model
 from klaude_code.trace import DebugType, log, log_debug
 
@@ -38,7 +38,7 @@ class OpenRouterClient(LLMClientABC):
         messages = convert_history_to_input(param.input, param.system, param.model)
         tools = convert_tool_schema(param.tools)
 
-        metadata_tracker = MetadataTracker()
+        metadata_tracker = MetadataTracker(cost_config=self._config.cost)
 
         extra_body: dict[str, object] = {
             "usage": {"include": True}  # To get the cache tokens at the end of the response
@@ -198,19 +198,3 @@ class OpenRouterClient(LLMClientABC):
 
         metadata_tracker.set_response_id(response_id)
         yield metadata_tracker.finalize()
-
-
-def convert_usage(usage: openai.types.CompletionUsage, context_limit: int | None = None) -> model.Usage:
-    total_tokens = usage.total_tokens
-    context_usage_percent = (total_tokens / context_limit) * 100 if context_limit else None
-    return model.Usage(
-        input_tokens=usage.prompt_tokens,
-        cached_tokens=(usage.prompt_tokens_details.cached_tokens if usage.prompt_tokens_details else 0) or 0,
-        reasoning_tokens=(usage.completion_tokens_details.reasoning_tokens if usage.completion_tokens_details else 0)
-        or 0,
-        output_tokens=usage.completion_tokens,
-        total_tokens=total_tokens,
-        context_usage_percent=context_usage_percent,
-        throughput_tps=None,
-        first_token_latency_ms=None,
-    )
klaude_code/llm/openrouter/input.py
CHANGED
@@ -13,11 +13,15 @@ from klaude_code.llm.input_common import AssistantGroup, ToolGroup, UserGroup, m
 from klaude_code.protocol import model
 
 
-def is_claude_model(model_name: str | None):
+def is_claude_model(model_name: str | None) -> bool:
+    """Return True if the model name represents an Anthropic Claude model."""
+
     return model_name is not None and model_name.startswith("anthropic/claude")
 
 
-def is_gemini_model(model_name: str | None):
+def is_gemini_model(model_name: str | None) -> bool:
+    """Return True if the model name represents a Google Gemini model."""
+
     return model_name is not None and model_name.startswith("google/gemini")
 
 
klaude_code/llm/registry.py
CHANGED
@@ -1,14 +1,7 @@
-from
-
-from dataclasses import dataclass, field
-from typing import TYPE_CHECKING, Callable, TypeVar
+from typing import Callable, TypeVar
 
 from klaude_code.llm.client import LLMClientABC
-from klaude_code.protocol import llm_param
-from klaude_code.trace import DebugType, log_debug
-
-if TYPE_CHECKING:
-    from klaude_code.config import Config
+from klaude_code.protocol import llm_param
 
 _REGISTRY: dict[llm_param.LLMClientProtocol, type[LLMClientABC]] = {}
 
@@ -27,65 +20,3 @@ def create_llm_client(config: llm_param.LLMConfigParameter) -> LLMClientABC:
     if config.protocol not in _REGISTRY:
         raise ValueError(f"Unknown LLMClient protocol: {config.protocol}")
     return _REGISTRY[config.protocol].create(config)
-
-
-@dataclass
-class LLMClients:
-    """Container for LLM clients used by main agent and sub-agents."""
-
-    main: LLMClientABC
-    sub_clients: dict[tools.SubAgentType, LLMClientABC] = field(default_factory=lambda: {})
-
-    def get_client(self, sub_agent_type: tools.SubAgentType | None = None) -> LLMClientABC:
-        """Get client for given sub-agent type, or main client if None."""
-        if sub_agent_type is None:
-            return self.main
-        return self.sub_clients.get(sub_agent_type) or self.main
-
-    @classmethod
-    def from_config(
-        cls,
-        config: Config,
-        model_override: str | None = None,
-        enabled_sub_agents: list[tools.SubAgentType] | None = None,
-    ) -> LLMClients:
-        """Create LLMClients from application config.
-
-        Args:
-            config: Application configuration
-            model_override: Optional model name to override the main model
-            enabled_sub_agents: List of sub-agent types to initialize clients for
-
-        Returns:
-            LLMClients instance
-        """
-        from klaude_code.core.sub_agent import get_sub_agent_profile
-
-        # Resolve main agent LLM config
-        if model_override:
-            llm_config = config.get_model_config(model_override)
-        else:
-            llm_config = config.get_main_model_config()
-
-        log_debug(
-            "Main LLM config",
-            llm_config.model_dump_json(exclude_none=True),
-            style="yellow",
-            debug_type=DebugType.LLM_CONFIG,
-        )
-
-        main_client = create_llm_client(llm_config)
-        sub_clients: dict[tools.SubAgentType, LLMClientABC] = {}
-
-        # Initialize sub-agent clients
-        for sub_agent_type in enabled_sub_agents or []:
-            model_name = config.subagent_models.get(sub_agent_type)
-            if not model_name:
-                continue
-            profile = get_sub_agent_profile(sub_agent_type)
-            if not profile.enabled_for_model(main_client.model_name):
-                continue
-            sub_llm_config = config.get_model_config(model_name)
-            sub_clients[sub_agent_type] = create_llm_client(sub_llm_config)
-
-        return cls(main=main_client, sub_clients=sub_clients)
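After this cleanup, registry.py holds only the protocol-keyed registry plus the register/create pair; clients add themselves when klaude_code/llm/__init__.py imports their modules and the @register decorators run. A minimal self-contained sketch of that pattern (simplified: the real registry is keyed by llm_param.LLMClientProtocol and builds clients from an LLMConfigParameter):

    from typing import Callable, TypeVar


    class LLMClientABC:
        @classmethod
        def create(cls, config: dict) -> "LLMClientABC":
            return cls()


    _REGISTRY: dict[str, type[LLMClientABC]] = {}
    T = TypeVar("T", bound=type[LLMClientABC])


    def register(protocol: str) -> Callable[[T], T]:
        # Class decorator: importing the client module is enough to register it.
        def decorator(client_cls: T) -> T:
            _REGISTRY[protocol] = client_cls
            return client_cls

        return decorator


    @register("openai_compatible")
    class OpenAICompatibleClient(LLMClientABC):
        pass


    def create_llm_client(protocol: str, config: dict) -> LLMClientABC:
        if protocol not in _REGISTRY:
            raise ValueError(f"Unknown LLMClient protocol: {protocol}")
        return _REGISTRY[protocol].create(config)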
klaude_code/llm/responses/client.py
CHANGED
@@ -11,6 +11,7 @@ from klaude_code.llm.client import LLMClientABC, call_with_logged_payload
 from klaude_code.llm.input_common import apply_config_defaults
 from klaude_code.llm.registry import register
 from klaude_code.llm.responses.input import convert_history_to_input, convert_tool_schema
+from klaude_code.llm.usage import calculate_cost
 from klaude_code.protocol import llm_param, model
 from klaude_code.trace import DebugType, log_debug
 
@@ -185,6 +186,7 @@ class ResponsesClient(LLMClientABC):
             throughput_tps=throughput_tps,
             first_token_latency_ms=first_token_latency_ms,
         )
+        calculate_cost(usage, self._config.cost)
         yield model.ResponseMetadataItem(
             usage=usage,
             response_id=response_id,