klaude-code 1.2.8__py3-none-any.whl → 1.2.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- klaude_code/auth/codex/__init__.py +1 -1
- klaude_code/cli/main.py +12 -1
- klaude_code/cli/runtime.py +7 -11
- klaude_code/command/__init__.py +68 -21
- klaude_code/command/clear_cmd.py +6 -2
- klaude_code/command/command_abc.py +5 -2
- klaude_code/command/diff_cmd.py +5 -2
- klaude_code/command/export_cmd.py +7 -4
- klaude_code/command/help_cmd.py +6 -2
- klaude_code/command/model_cmd.py +5 -2
- klaude_code/command/prompt-deslop.md +14 -0
- klaude_code/command/prompt_command.py +8 -3
- klaude_code/command/refresh_cmd.py +6 -2
- klaude_code/command/registry.py +17 -5
- klaude_code/command/release_notes_cmd.py +89 -0
- klaude_code/command/status_cmd.py +98 -56
- klaude_code/command/terminal_setup_cmd.py +7 -4
- klaude_code/const/__init__.py +1 -1
- klaude_code/core/agent.py +66 -26
- klaude_code/core/executor.py +2 -2
- klaude_code/core/manager/agent_manager.py +6 -7
- klaude_code/core/manager/llm_clients.py +47 -22
- klaude_code/core/manager/llm_clients_builder.py +19 -7
- klaude_code/core/manager/sub_agent_manager.py +6 -2
- klaude_code/core/prompt.py +38 -28
- klaude_code/core/reminders.py +4 -7
- klaude_code/core/task.py +59 -40
- klaude_code/core/tool/__init__.py +2 -0
- klaude_code/core/tool/file/_utils.py +30 -0
- klaude_code/core/tool/file/apply_patch_tool.py +1 -1
- klaude_code/core/tool/file/edit_tool.py +6 -31
- klaude_code/core/tool/file/multi_edit_tool.py +7 -32
- klaude_code/core/tool/file/read_tool.py +6 -18
- klaude_code/core/tool/file/write_tool.py +6 -31
- klaude_code/core/tool/memory/__init__.py +5 -0
- klaude_code/core/tool/memory/memory_tool.py +2 -2
- klaude_code/core/tool/memory/skill_loader.py +2 -1
- klaude_code/core/tool/memory/skill_tool.py +13 -0
- klaude_code/core/tool/sub_agent_tool.py +2 -1
- klaude_code/core/tool/todo/todo_write_tool.py +1 -1
- klaude_code/core/tool/todo/update_plan_tool.py +1 -1
- klaude_code/core/tool/tool_context.py +21 -4
- klaude_code/core/tool/tool_runner.py +5 -8
- klaude_code/core/tool/web/mermaid_tool.py +1 -4
- klaude_code/core/turn.py +40 -37
- klaude_code/llm/__init__.py +2 -12
- klaude_code/llm/anthropic/client.py +14 -44
- klaude_code/llm/client.py +2 -2
- klaude_code/llm/codex/client.py +4 -3
- klaude_code/llm/input_common.py +0 -6
- klaude_code/llm/openai_compatible/client.py +31 -74
- klaude_code/llm/openai_compatible/input.py +6 -4
- klaude_code/llm/openai_compatible/stream_processor.py +82 -0
- klaude_code/llm/openrouter/client.py +32 -62
- klaude_code/llm/openrouter/input.py +4 -27
- klaude_code/llm/registry.py +33 -7
- klaude_code/llm/responses/client.py +16 -48
- klaude_code/llm/responses/input.py +1 -1
- klaude_code/llm/usage.py +61 -11
- klaude_code/protocol/commands.py +1 -0
- klaude_code/protocol/events.py +11 -2
- klaude_code/protocol/model.py +147 -24
- klaude_code/protocol/op.py +1 -0
- klaude_code/protocol/sub_agent.py +5 -1
- klaude_code/session/export.py +56 -32
- klaude_code/session/session.py +43 -21
- klaude_code/session/templates/export_session.html +4 -1
- klaude_code/ui/core/input.py +1 -1
- klaude_code/ui/modes/repl/__init__.py +1 -5
- klaude_code/ui/modes/repl/clipboard.py +5 -5
- klaude_code/ui/modes/repl/event_handler.py +153 -54
- klaude_code/ui/modes/repl/renderer.py +4 -4
- klaude_code/ui/renderers/developer.py +35 -25
- klaude_code/ui/renderers/metadata.py +68 -30
- klaude_code/ui/renderers/tools.py +53 -87
- klaude_code/ui/rich/markdown.py +5 -5
- klaude_code/ui/terminal/control.py +2 -2
- klaude_code/version.py +3 -3
- {klaude_code-1.2.8.dist-info → klaude_code-1.2.10.dist-info}/METADATA +1 -1
- {klaude_code-1.2.8.dist-info → klaude_code-1.2.10.dist-info}/RECORD +82 -78
- {klaude_code-1.2.8.dist-info → klaude_code-1.2.10.dist-info}/WHEEL +0 -0
- {klaude_code-1.2.8.dist-info → klaude_code-1.2.10.dist-info}/entry_points.txt +0 -0
|
@@ -1,51 +1,73 @@
|
|
|
1
|
+
from typing import TYPE_CHECKING
|
|
2
|
+
|
|
1
3
|
from klaude_code.command.command_abc import CommandABC, CommandResult
|
|
2
4
|
from klaude_code.command.registry import register_command
|
|
3
|
-
from klaude_code.core.agent import Agent
|
|
4
5
|
from klaude_code.protocol import commands, events, model
|
|
5
6
|
from klaude_code.session.session import Session
|
|
6
7
|
|
|
8
|
+
if TYPE_CHECKING:
|
|
9
|
+
from klaude_code.core.agent import Agent
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class AggregatedUsage(model.BaseModel):
|
|
13
|
+
"""Aggregated usage statistics including per-model breakdown."""
|
|
14
|
+
|
|
15
|
+
total: model.Usage
|
|
16
|
+
by_model: list[model.TaskMetadata]
|
|
17
|
+
task_count: int
|
|
18
|
+
|
|
7
19
|
|
|
8
|
-
def accumulate_session_usage(session: Session) ->
|
|
9
|
-
"""Accumulate usage statistics from all
|
|
20
|
+
def accumulate_session_usage(session: Session) -> AggregatedUsage:
|
|
21
|
+
"""Accumulate usage statistics from all TaskMetadataItems in session history.
|
|
10
22
|
|
|
11
|
-
|
|
12
|
-
A tuple of (accumulated_usage, task_count)
|
|
23
|
+
Includes both main agent and sub-agent task metadata, grouped by model+provider.
|
|
13
24
|
"""
|
|
14
|
-
|
|
25
|
+
all_metadata: list[model.TaskMetadata] = []
|
|
15
26
|
task_count = 0
|
|
16
|
-
first_currency_set = False
|
|
17
27
|
|
|
18
28
|
for item in session.conversation_history:
|
|
19
|
-
if isinstance(item, model.
|
|
29
|
+
if isinstance(item, model.TaskMetadataItem):
|
|
20
30
|
task_count += 1
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
31
|
+
all_metadata.append(item.main)
|
|
32
|
+
all_metadata.extend(item.sub_agent_task_metadata)
|
|
33
|
+
|
|
34
|
+
# Aggregate by model+provider
|
|
35
|
+
by_model = model.TaskMetadata.aggregate_by_model(all_metadata)
|
|
36
|
+
|
|
37
|
+
# Calculate total from aggregated results
|
|
38
|
+
total = model.Usage()
|
|
39
|
+
for meta in by_model:
|
|
40
|
+
if not meta.usage:
|
|
41
|
+
continue
|
|
42
|
+
usage = meta.usage
|
|
43
|
+
|
|
44
|
+
# Set currency from first
|
|
45
|
+
if total.currency == "USD" and usage.currency:
|
|
46
|
+
total.currency = usage.currency
|
|
47
|
+
|
|
48
|
+
# Accumulate primary token fields (total_tokens is computed)
|
|
49
|
+
total.input_tokens += usage.input_tokens
|
|
50
|
+
total.cached_tokens += usage.cached_tokens
|
|
51
|
+
total.reasoning_tokens += usage.reasoning_tokens
|
|
52
|
+
total.output_tokens += usage.output_tokens
|
|
53
|
+
|
|
54
|
+
# Accumulate cost components (total_cost is computed)
|
|
55
|
+
if usage.input_cost is not None:
|
|
56
|
+
total.input_cost = (total.input_cost or 0.0) + usage.input_cost
|
|
57
|
+
if usage.output_cost is not None:
|
|
58
|
+
total.output_cost = (total.output_cost or 0.0) + usage.output_cost
|
|
59
|
+
if usage.cache_read_cost is not None:
|
|
60
|
+
total.cache_read_cost = (total.cache_read_cost or 0.0) + usage.cache_read_cost
|
|
61
|
+
|
|
62
|
+
# Track peak context window size (max across all tasks)
|
|
63
|
+
if usage.context_token is not None:
|
|
64
|
+
total.context_token = usage.context_token
|
|
65
|
+
|
|
66
|
+
# Keep the latest context_limit for computed context_usage_percent
|
|
67
|
+
if usage.context_limit is not None:
|
|
68
|
+
total.context_limit = usage.context_limit
|
|
69
|
+
|
|
70
|
+
return AggregatedUsage(total=total, by_model=by_model, task_count=task_count)
|
|
49
71
|
|
|
50
72
|
|
|
51
73
|
def _format_tokens(tokens: int) -> str:
|
|
@@ -67,20 +89,42 @@ def _format_cost(cost: float | None, currency: str = "USD") -> str:
|
|
|
67
89
|
return f"{symbol}{cost:.2f}"
|
|
68
90
|
|
|
69
91
|
|
|
70
|
-
def
|
|
71
|
-
"""Format
|
|
72
|
-
|
|
92
|
+
def _format_model_usage_line(meta: model.TaskMetadata) -> str:
|
|
93
|
+
"""Format a single model's usage as a line."""
|
|
94
|
+
model_label = meta.model_name
|
|
95
|
+
if meta.provider:
|
|
96
|
+
model_label = f"{meta.model_name} ({meta.provider})"
|
|
97
|
+
|
|
98
|
+
usage = meta.usage
|
|
99
|
+
if not usage:
|
|
100
|
+
return f" {model_label}: no usage data"
|
|
101
|
+
|
|
102
|
+
cost_str = _format_cost(usage.total_cost, usage.currency)
|
|
103
|
+
return (
|
|
104
|
+
f" {model_label}: "
|
|
105
|
+
f"{_format_tokens(usage.input_tokens)} input, "
|
|
106
|
+
f"{_format_tokens(usage.output_tokens)} output, "
|
|
107
|
+
f"{_format_tokens(usage.cached_tokens)} cache read, "
|
|
108
|
+
f"{_format_tokens(usage.reasoning_tokens)} thinking, "
|
|
109
|
+
f"({cost_str})"
|
|
110
|
+
)
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
def format_status_content(aggregated: AggregatedUsage) -> str:
|
|
114
|
+
"""Format session status with per-model breakdown."""
|
|
115
|
+
lines: list[str] = []
|
|
73
116
|
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
parts.append(f"Output: {_format_tokens(usage.output_tokens)}")
|
|
78
|
-
parts.append(f"Total: {_format_tokens(usage.total_tokens)}")
|
|
117
|
+
# Total cost line
|
|
118
|
+
total_cost_str = _format_cost(aggregated.total.total_cost, aggregated.total.currency)
|
|
119
|
+
lines.append(f"Total cost: {total_cost_str}")
|
|
79
120
|
|
|
80
|
-
|
|
81
|
-
|
|
121
|
+
# Per-model breakdown
|
|
122
|
+
if aggregated.by_model:
|
|
123
|
+
lines.append("Usage by model:")
|
|
124
|
+
for stats in aggregated.by_model:
|
|
125
|
+
lines.append(_format_model_usage_line(stats))
|
|
82
126
|
|
|
83
|
-
return "
|
|
127
|
+
return "\n".join(lines)
|
|
84
128
|
|
|
85
129
|
|
|
86
130
|
@register_command
|
|
@@ -95,22 +139,20 @@ class StatusCommand(CommandABC):
|
|
|
95
139
|
def summary(self) -> str:
|
|
96
140
|
return "Show session usage statistics"
|
|
97
141
|
|
|
98
|
-
async def run(self, raw: str, agent: Agent) -> CommandResult:
|
|
142
|
+
async def run(self, raw: str, agent: "Agent") -> CommandResult:
|
|
99
143
|
session = agent.session
|
|
100
|
-
|
|
144
|
+
aggregated = accumulate_session_usage(session)
|
|
101
145
|
|
|
102
146
|
event = events.DeveloperMessageEvent(
|
|
103
147
|
session_id=session.id,
|
|
104
148
|
item=model.DeveloperMessageItem(
|
|
105
|
-
content=format_status_content(
|
|
149
|
+
content=format_status_content(aggregated),
|
|
106
150
|
command_output=model.CommandOutput(
|
|
107
151
|
command_name=self.name,
|
|
108
|
-
ui_extra=model.
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
task_count=task_count,
|
|
113
|
-
),
|
|
152
|
+
ui_extra=model.SessionStatusUIExtra(
|
|
153
|
+
usage=aggregated.total,
|
|
154
|
+
task_count=aggregated.task_count,
|
|
155
|
+
by_model=aggregated.by_model,
|
|
114
156
|
),
|
|
115
157
|
),
|
|
116
158
|
),
|
|
@@ -1,12 +1,15 @@
|
|
|
1
1
|
import os
|
|
2
2
|
import subprocess
|
|
3
3
|
from pathlib import Path
|
|
4
|
+
from typing import TYPE_CHECKING
|
|
4
5
|
|
|
5
6
|
from klaude_code.command.command_abc import CommandABC, CommandResult
|
|
6
7
|
from klaude_code.command.registry import register_command
|
|
7
|
-
from klaude_code.core.agent import Agent
|
|
8
8
|
from klaude_code.protocol import commands, events, model
|
|
9
9
|
|
|
10
|
+
if TYPE_CHECKING:
|
|
11
|
+
from klaude_code.core.agent import Agent
|
|
12
|
+
|
|
10
13
|
|
|
11
14
|
@register_command
|
|
12
15
|
class TerminalSetupCommand(CommandABC):
|
|
@@ -24,7 +27,7 @@ class TerminalSetupCommand(CommandABC):
|
|
|
24
27
|
def is_interactive(self) -> bool:
|
|
25
28
|
return False
|
|
26
29
|
|
|
27
|
-
async def run(self, raw: str, agent: Agent) -> CommandResult:
|
|
30
|
+
async def run(self, raw: str, agent: "Agent") -> CommandResult:
|
|
28
31
|
term_program = os.environ.get("TERM_PROGRAM", "").lower()
|
|
29
32
|
|
|
30
33
|
try:
|
|
@@ -223,7 +226,7 @@ class TerminalSetupCommand(CommandABC):
|
|
|
223
226
|
|
|
224
227
|
return message
|
|
225
228
|
|
|
226
|
-
def _create_success_result(self, agent: Agent, message: str) -> CommandResult:
|
|
229
|
+
def _create_success_result(self, agent: "Agent", message: str) -> CommandResult:
|
|
227
230
|
"""Create success result"""
|
|
228
231
|
return CommandResult(
|
|
229
232
|
events=[
|
|
@@ -237,7 +240,7 @@ class TerminalSetupCommand(CommandABC):
|
|
|
237
240
|
]
|
|
238
241
|
)
|
|
239
242
|
|
|
240
|
-
def _create_error_result(self, agent: Agent, message: str) -> CommandResult:
|
|
243
|
+
def _create_error_result(self, agent: "Agent", message: str) -> CommandResult:
|
|
241
244
|
"""Create error result"""
|
|
242
245
|
return CommandResult(
|
|
243
246
|
events=[
|
klaude_code/const/__init__.py
CHANGED
|
@@ -91,7 +91,7 @@ INVALID_TOOL_CALL_MAX_LENGTH = 500
|
|
|
91
91
|
TRUNCATE_DISPLAY_MAX_LINE_LENGTH = 1000
|
|
92
92
|
|
|
93
93
|
# Maximum lines for truncated display output
|
|
94
|
-
TRUNCATE_DISPLAY_MAX_LINES =
|
|
94
|
+
TRUNCATE_DISPLAY_MAX_LINES = 20
|
|
95
95
|
|
|
96
96
|
# Maximum lines for sub-agent result display
|
|
97
97
|
SUB_AGENT_RESULT_MAX_LINES = 12
|
klaude_code/core/agent.py
CHANGED
|
@@ -1,34 +1,51 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
-
from collections.abc import AsyncGenerator, Iterable
|
|
3
|
+
from collections.abc import AsyncGenerator, Callable, Iterable
|
|
4
4
|
from dataclasses import dataclass
|
|
5
|
-
from typing import Protocol
|
|
5
|
+
from typing import TYPE_CHECKING, Protocol
|
|
6
6
|
|
|
7
7
|
from klaude_code.core.prompt import get_system_prompt as load_system_prompt
|
|
8
8
|
from klaude_code.core.reminders import Reminder, load_agent_reminders
|
|
9
|
-
from klaude_code.core.task import TaskExecutionContext, TaskExecutor
|
|
10
|
-
from klaude_code.core.tool import
|
|
9
|
+
from klaude_code.core.task import SessionContext, TaskExecutionContext, TaskExecutor
|
|
10
|
+
from klaude_code.core.tool import build_todo_context, get_registry, load_agent_tools
|
|
11
11
|
from klaude_code.llm import LLMClientABC
|
|
12
12
|
from klaude_code.protocol import events, llm_param, model, tools
|
|
13
13
|
from klaude_code.protocol.model import UserInputPayload
|
|
14
14
|
from klaude_code.session import Session
|
|
15
15
|
from klaude_code.trace import DebugType, log_debug
|
|
16
16
|
|
|
17
|
+
if TYPE_CHECKING:
|
|
18
|
+
from klaude_code.core.manager.llm_clients import LLMClients
|
|
19
|
+
|
|
17
20
|
|
|
18
21
|
@dataclass(frozen=True)
|
|
19
22
|
class AgentProfile:
|
|
20
23
|
"""Encapsulates the active LLM client plus prompts/tools/reminders."""
|
|
21
24
|
|
|
22
|
-
|
|
25
|
+
llm_client_factory: Callable[[], LLMClientABC]
|
|
23
26
|
system_prompt: str | None
|
|
24
27
|
tools: list[llm_param.ToolSchema]
|
|
25
28
|
reminders: list[Reminder]
|
|
26
29
|
|
|
30
|
+
_llm_client: LLMClientABC | None = None
|
|
31
|
+
|
|
32
|
+
@property
|
|
33
|
+
def llm_client(self) -> LLMClientABC:
|
|
34
|
+
if self._llm_client is None:
|
|
35
|
+
object.__setattr__(self, "_llm_client", self.llm_client_factory())
|
|
36
|
+
return self._llm_client # type: ignore[return-value]
|
|
37
|
+
|
|
27
38
|
|
|
28
39
|
class ModelProfileProvider(Protocol):
|
|
29
40
|
"""Strategy interface for constructing agent profiles."""
|
|
30
41
|
|
|
31
42
|
def build_profile(
|
|
43
|
+
self,
|
|
44
|
+
llm_clients: LLMClients,
|
|
45
|
+
sub_agent_type: tools.SubAgentType | None = None,
|
|
46
|
+
) -> AgentProfile: ...
|
|
47
|
+
|
|
48
|
+
def build_profile_eager(
|
|
32
49
|
self,
|
|
33
50
|
llm_client: LLMClientABC,
|
|
34
51
|
sub_agent_type: tools.SubAgentType | None = None,
|
|
@@ -39,13 +56,26 @@ class DefaultModelProfileProvider(ModelProfileProvider):
|
|
|
39
56
|
"""Default provider backed by global prompts/tool/reminder registries."""
|
|
40
57
|
|
|
41
58
|
def build_profile(
|
|
59
|
+
self,
|
|
60
|
+
llm_clients: LLMClients,
|
|
61
|
+
sub_agent_type: tools.SubAgentType | None = None,
|
|
62
|
+
) -> AgentProfile:
|
|
63
|
+
model_name = llm_clients.main_model_name
|
|
64
|
+
return AgentProfile(
|
|
65
|
+
llm_client_factory=lambda: llm_clients.main,
|
|
66
|
+
system_prompt=load_system_prompt(model_name, sub_agent_type),
|
|
67
|
+
tools=load_agent_tools(model_name, sub_agent_type),
|
|
68
|
+
reminders=load_agent_reminders(model_name, sub_agent_type),
|
|
69
|
+
)
|
|
70
|
+
|
|
71
|
+
def build_profile_eager(
|
|
42
72
|
self,
|
|
43
73
|
llm_client: LLMClientABC,
|
|
44
74
|
sub_agent_type: tools.SubAgentType | None = None,
|
|
45
75
|
) -> AgentProfile:
|
|
46
76
|
model_name = llm_client.model_name
|
|
47
77
|
return AgentProfile(
|
|
48
|
-
|
|
78
|
+
llm_client_factory=lambda: llm_client,
|
|
49
79
|
system_prompt=load_system_prompt(model_name, sub_agent_type),
|
|
50
80
|
tools=load_agent_tools(model_name, sub_agent_type),
|
|
51
81
|
reminders=load_agent_reminders(model_name, sub_agent_type),
|
|
@@ -56,13 +86,26 @@ class VanillaModelProfileProvider(ModelProfileProvider):
|
|
|
56
86
|
"""Provider that strips prompts, reminders, and tools for vanilla mode."""
|
|
57
87
|
|
|
58
88
|
def build_profile(
|
|
89
|
+
self,
|
|
90
|
+
llm_clients: LLMClients,
|
|
91
|
+
sub_agent_type: tools.SubAgentType | None = None,
|
|
92
|
+
) -> AgentProfile:
|
|
93
|
+
model_name = llm_clients.main_model_name
|
|
94
|
+
return AgentProfile(
|
|
95
|
+
llm_client_factory=lambda: llm_clients.main,
|
|
96
|
+
system_prompt=None,
|
|
97
|
+
tools=load_agent_tools(model_name, vanilla=True),
|
|
98
|
+
reminders=load_agent_reminders(model_name, vanilla=True),
|
|
99
|
+
)
|
|
100
|
+
|
|
101
|
+
def build_profile_eager(
|
|
59
102
|
self,
|
|
60
103
|
llm_client: LLMClientABC,
|
|
61
104
|
sub_agent_type: tools.SubAgentType | None = None,
|
|
62
105
|
) -> AgentProfile:
|
|
63
106
|
model_name = llm_client.model_name
|
|
64
107
|
return AgentProfile(
|
|
65
|
-
|
|
108
|
+
llm_client_factory=lambda: llm_client,
|
|
66
109
|
system_prompt=None,
|
|
67
110
|
tools=load_agent_tools(model_name, vanilla=True),
|
|
68
111
|
reminders=load_agent_reminders(model_name, vanilla=True),
|
|
@@ -74,13 +117,13 @@ class Agent:
|
|
|
74
117
|
self,
|
|
75
118
|
session: Session,
|
|
76
119
|
profile: AgentProfile,
|
|
120
|
+
model_name: str | None = None,
|
|
77
121
|
):
|
|
78
122
|
self.session: Session = session
|
|
79
|
-
self.profile: AgentProfile
|
|
80
|
-
# Active task executor, if any
|
|
123
|
+
self.profile: AgentProfile = profile
|
|
81
124
|
self._current_task: TaskExecutor | None = None
|
|
82
|
-
|
|
83
|
-
|
|
125
|
+
if not self.session.model_name and model_name:
|
|
126
|
+
self.session.model_name = model_name
|
|
84
127
|
|
|
85
128
|
def cancel(self) -> Iterable[events.Event]:
|
|
86
129
|
"""Handle agent cancellation and persist an interrupt marker and tool cancellations.
|
|
@@ -106,17 +149,17 @@ class Agent:
|
|
|
106
149
|
)
|
|
107
150
|
|
|
108
151
|
async def run_task(self, user_input: UserInputPayload) -> AsyncGenerator[events.Event, None]:
|
|
109
|
-
|
|
152
|
+
session_ctx = SessionContext(
|
|
110
153
|
session_id=self.session.id,
|
|
111
|
-
profile=self._require_profile(),
|
|
112
154
|
get_conversation_history=lambda: self.session.conversation_history,
|
|
113
155
|
append_history=self.session.append_history,
|
|
114
|
-
tool_registry=get_registry(),
|
|
115
156
|
file_tracker=self.session.file_tracker,
|
|
116
|
-
todo_context=
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
157
|
+
todo_context=build_todo_context(self.session),
|
|
158
|
+
)
|
|
159
|
+
context = TaskExecutionContext(
|
|
160
|
+
session_ctx=session_ctx,
|
|
161
|
+
profile=self.profile,
|
|
162
|
+
tool_registry=get_registry(),
|
|
120
163
|
process_reminder=self._process_reminder,
|
|
121
164
|
sub_agent_state=self.session.sub_agent_state,
|
|
122
165
|
)
|
|
@@ -149,17 +192,14 @@ class Agent:
|
|
|
149
192
|
self.session.append_history([item])
|
|
150
193
|
yield events.DeveloperMessageEvent(session_id=self.session.id, item=item)
|
|
151
194
|
|
|
152
|
-
def set_model_profile(self, profile: AgentProfile) -> None:
|
|
195
|
+
def set_model_profile(self, profile: AgentProfile, model_name: str | None = None) -> None:
|
|
153
196
|
"""Apply a fully constructed profile to the agent."""
|
|
154
197
|
|
|
155
198
|
self.profile = profile
|
|
156
|
-
if
|
|
199
|
+
if model_name:
|
|
200
|
+
self.session.model_name = model_name
|
|
201
|
+
elif not self.session.model_name:
|
|
157
202
|
self.session.model_name = profile.llm_client.model_name
|
|
158
203
|
|
|
159
204
|
def get_llm_client(self) -> LLMClientABC:
|
|
160
|
-
return self.
|
|
161
|
-
|
|
162
|
-
def _require_profile(self) -> AgentProfile:
|
|
163
|
-
if self.profile is None:
|
|
164
|
-
raise RuntimeError("Agent profile is not initialized")
|
|
165
|
-
return self.profile
|
|
205
|
+
return self.profile.llm_client
|
klaude_code/core/executor.py
CHANGED
|
@@ -117,7 +117,7 @@ class ExecutorContext:
|
|
|
117
117
|
if operation.session_id is None:
|
|
118
118
|
raise ValueError("session_id cannot be None")
|
|
119
119
|
|
|
120
|
-
await self.agent_manager.ensure_agent(operation.session_id)
|
|
120
|
+
await self.agent_manager.ensure_agent(operation.session_id, is_new_session=operation.is_new_session)
|
|
121
121
|
|
|
122
122
|
async def handle_user_input(self, operation: op.UserInputOperation) -> None:
|
|
123
123
|
"""Handle a user input operation by running it through an agent."""
|
|
@@ -482,4 +482,4 @@ class Executor:
|
|
|
482
482
|
|
|
483
483
|
# Static type check: ExecutorContext must satisfy OperationHandler protocol.
|
|
484
484
|
# If this line causes a type error, ExecutorContext is missing required methods.
|
|
485
|
-
_: type[OperationHandler] = ExecutorContext
|
|
485
|
+
_: type[OperationHandler] = ExecutorContext
|
|
@@ -38,16 +38,15 @@ class AgentManager:
|
|
|
38
38
|
|
|
39
39
|
await self._event_queue.put(event)
|
|
40
40
|
|
|
41
|
-
async def ensure_agent(self, session_id: str) -> Agent:
|
|
41
|
+
async def ensure_agent(self, session_id: str, *, is_new_session: bool = False) -> Agent:
|
|
42
42
|
"""Return an existing agent for the session or create a new one."""
|
|
43
|
-
|
|
44
43
|
agent = self._active_agents.get(session_id)
|
|
45
44
|
if agent is not None:
|
|
46
45
|
return agent
|
|
47
46
|
|
|
48
|
-
session = Session.load(session_id)
|
|
49
|
-
profile = self._model_profile_provider.build_profile(self._llm_clients
|
|
50
|
-
agent = Agent(session=session, profile=profile)
|
|
47
|
+
session = Session.load(session_id, skip_if_missing=is_new_session)
|
|
48
|
+
profile = self._model_profile_provider.build_profile(self._llm_clients)
|
|
49
|
+
agent = Agent(session=session, profile=profile, model_name=self._llm_clients.main_model_name)
|
|
51
50
|
|
|
52
51
|
async for evt in agent.replay_history():
|
|
53
52
|
await self.emit_event(evt)
|
|
@@ -55,7 +54,7 @@ class AgentManager:
|
|
|
55
54
|
await self.emit_event(
|
|
56
55
|
events.WelcomeEvent(
|
|
57
56
|
work_dir=str(session.work_dir),
|
|
58
|
-
llm_config=self._llm_clients.
|
|
57
|
+
llm_config=self._llm_clients.get_llm_config(),
|
|
59
58
|
)
|
|
60
59
|
)
|
|
61
60
|
|
|
@@ -76,7 +75,7 @@ class AgentManager:
|
|
|
76
75
|
|
|
77
76
|
llm_config = config.get_model_config(model_name)
|
|
78
77
|
llm_client = create_llm_client(llm_config)
|
|
79
|
-
agent.set_model_profile(self._model_profile_provider.
|
|
78
|
+
agent.set_model_profile(self._model_profile_provider.build_profile_eager(llm_client), model_name=model_name)
|
|
80
79
|
|
|
81
80
|
developer_item = model.DeveloperMessageItem(
|
|
82
81
|
content=f"switched to model: {model_name}",
|
|
@@ -2,41 +2,66 @@
|
|
|
2
2
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
5
|
-
from
|
|
6
|
-
from dataclasses import field as dataclass_field
|
|
5
|
+
from collections.abc import Callable
|
|
7
6
|
|
|
8
7
|
from klaude_code.llm.client import LLMClientABC
|
|
8
|
+
from klaude_code.protocol import llm_param
|
|
9
9
|
from klaude_code.protocol.tools import SubAgentType
|
|
10
10
|
|
|
11
11
|
|
|
12
|
-
|
|
13
|
-
"""
|
|
12
|
+
class LLMClients:
|
|
13
|
+
"""Container for LLM clients used by main agent and sub-agents."""
|
|
14
14
|
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
15
|
+
def __init__(
|
|
16
|
+
self,
|
|
17
|
+
main_factory: Callable[[], LLMClientABC],
|
|
18
|
+
main_model_name: str,
|
|
19
|
+
main_llm_config: llm_param.LLMConfigParameter,
|
|
20
|
+
) -> None:
|
|
21
|
+
self._main_factory: Callable[[], LLMClientABC] | None = main_factory
|
|
22
|
+
self._main_client: LLMClientABC | None = None
|
|
23
|
+
self._main_model_name: str = main_model_name
|
|
24
|
+
self._main_llm_config: llm_param.LLMConfigParameter = main_llm_config
|
|
25
|
+
self._sub_clients: dict[SubAgentType, LLMClientABC] = {}
|
|
26
|
+
self._sub_factories: dict[SubAgentType, Callable[[], LLMClientABC]] = {}
|
|
18
27
|
|
|
19
|
-
|
|
28
|
+
@property
|
|
29
|
+
def main_model_name(self) -> str:
|
|
30
|
+
return self._main_model_name
|
|
20
31
|
|
|
32
|
+
def get_llm_config(self) -> llm_param.LLMConfigParameter:
|
|
33
|
+
return self._main_llm_config
|
|
21
34
|
|
|
22
|
-
@
|
|
23
|
-
|
|
24
|
-
|
|
35
|
+
@property
|
|
36
|
+
def main(self) -> LLMClientABC:
|
|
37
|
+
if self._main_client is None:
|
|
38
|
+
if self._main_factory is None:
|
|
39
|
+
raise RuntimeError("Main client factory not set")
|
|
40
|
+
self._main_client = self._main_factory()
|
|
41
|
+
self._main_factory = None
|
|
42
|
+
return self._main_client
|
|
25
43
|
|
|
26
|
-
|
|
27
|
-
|
|
44
|
+
def register_sub_client_factory(
|
|
45
|
+
self,
|
|
46
|
+
sub_agent_type: SubAgentType,
|
|
47
|
+
factory: Callable[[], LLMClientABC],
|
|
48
|
+
) -> None:
|
|
49
|
+
self._sub_factories[sub_agent_type] = factory
|
|
28
50
|
|
|
29
51
|
def get_client(self, sub_agent_type: SubAgentType | None = None) -> LLMClientABC:
|
|
30
|
-
"""Return client for a sub-agent type or the main client.
|
|
52
|
+
"""Return client for a sub-agent type or the main client."""
|
|
31
53
|
|
|
32
|
-
|
|
33
|
-
|
|
54
|
+
if sub_agent_type is None:
|
|
55
|
+
return self.main
|
|
34
56
|
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
"""
|
|
57
|
+
existing = self._sub_clients.get(sub_agent_type)
|
|
58
|
+
if existing is not None:
|
|
59
|
+
return existing
|
|
39
60
|
|
|
40
|
-
|
|
61
|
+
factory = self._sub_factories.get(sub_agent_type)
|
|
62
|
+
if factory is None:
|
|
41
63
|
return self.main
|
|
42
|
-
|
|
64
|
+
|
|
65
|
+
client = factory()
|
|
66
|
+
self._sub_clients[sub_agent_type] = client
|
|
67
|
+
return client
|
|
@@ -32,18 +32,30 @@ def build_llm_clients(
|
|
|
32
32
|
debug_type=DebugType.LLM_CONFIG,
|
|
33
33
|
)
|
|
34
34
|
|
|
35
|
-
|
|
36
|
-
|
|
35
|
+
main_model_name = str(llm_config.model)
|
|
36
|
+
|
|
37
|
+
def _main_factory() -> LLMClientABC:
|
|
38
|
+
return create_llm_client(llm_config)
|
|
39
|
+
|
|
40
|
+
clients = LLMClients(
|
|
41
|
+
main_factory=_main_factory,
|
|
42
|
+
main_model_name=main_model_name,
|
|
43
|
+
main_llm_config=llm_config,
|
|
44
|
+
)
|
|
37
45
|
|
|
38
|
-
# Initialize sub-agent clients
|
|
39
46
|
for sub_agent_type in enabled_sub_agents or []:
|
|
40
47
|
model_name = config.subagent_models.get(sub_agent_type)
|
|
41
48
|
if not model_name:
|
|
42
49
|
continue
|
|
50
|
+
|
|
43
51
|
profile = get_sub_agent_profile(sub_agent_type)
|
|
44
|
-
if not profile.enabled_for_model(
|
|
52
|
+
if not profile.enabled_for_model(main_model_name):
|
|
45
53
|
continue
|
|
46
|
-
sub_llm_config = config.get_model_config(model_name)
|
|
47
|
-
sub_clients[sub_agent_type] = create_llm_client(sub_llm_config)
|
|
48
54
|
|
|
49
|
-
|
|
55
|
+
def _factory(model_name_for_factory: str = model_name) -> LLMClientABC:
|
|
56
|
+
sub_llm_config = config.get_model_config(model_name_for_factory)
|
|
57
|
+
return create_llm_client(sub_llm_config)
|
|
58
|
+
|
|
59
|
+
clients.register_sub_client_factory(sub_agent_type, _factory)
|
|
60
|
+
|
|
61
|
+
return clients
|
|
@@ -43,7 +43,7 @@ class SubAgentManager:
|
|
|
43
43
|
child_session = Session(work_dir=parent_session.work_dir)
|
|
44
44
|
child_session.sub_agent_state = state
|
|
45
45
|
|
|
46
|
-
child_profile = self._model_profile_provider.
|
|
46
|
+
child_profile = self._model_profile_provider.build_profile_eager(
|
|
47
47
|
self._llm_clients.get_client(state.sub_agent_type),
|
|
48
48
|
state.sub_agent_type,
|
|
49
49
|
)
|
|
@@ -58,13 +58,17 @@ class SubAgentManager:
|
|
|
58
58
|
try:
|
|
59
59
|
# Not emit the subtask's user input since task tool call is already rendered
|
|
60
60
|
result: str = ""
|
|
61
|
+
task_metadata: model.TaskMetadata | None = None
|
|
61
62
|
sub_agent_input = model.UserInputPayload(text=state.sub_agent_prompt, images=None)
|
|
62
63
|
async for event in child_agent.run_task(sub_agent_input):
|
|
63
64
|
# Capture TaskFinishEvent content for return
|
|
64
65
|
if isinstance(event, events.TaskFinishEvent):
|
|
65
66
|
result = event.task_result
|
|
67
|
+
# Capture TaskMetadataEvent for metadata propagation
|
|
68
|
+
elif isinstance(event, events.TaskMetadataEvent):
|
|
69
|
+
task_metadata = event.metadata.main
|
|
66
70
|
await self.emit_event(event)
|
|
67
|
-
return SubAgentResult(task_result=result, session_id=child_session.id)
|
|
71
|
+
return SubAgentResult(task_result=result, session_id=child_session.id, task_metadata=task_metadata)
|
|
68
72
|
except asyncio.CancelledError:
|
|
69
73
|
# Propagate cancellation so tooling can treat it as user interrupt
|
|
70
74
|
log_debug(
|