klaude-code 1.9.0__py3-none-any.whl → 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- klaude_code/auth/base.py +2 -6
- klaude_code/cli/auth_cmd.py +4 -4
- klaude_code/cli/list_model.py +1 -1
- klaude_code/cli/main.py +1 -1
- klaude_code/cli/runtime.py +7 -5
- klaude_code/cli/self_update.py +1 -1
- klaude_code/cli/session_cmd.py +1 -1
- klaude_code/command/clear_cmd.py +6 -2
- klaude_code/command/command_abc.py +2 -2
- klaude_code/command/debug_cmd.py +4 -4
- klaude_code/command/export_cmd.py +2 -2
- klaude_code/command/export_online_cmd.py +12 -12
- klaude_code/command/fork_session_cmd.py +29 -23
- klaude_code/command/help_cmd.py +4 -4
- klaude_code/command/model_cmd.py +4 -4
- klaude_code/command/model_select.py +1 -1
- klaude_code/command/prompt-commit.md +11 -2
- klaude_code/command/prompt_command.py +3 -3
- klaude_code/command/refresh_cmd.py +2 -2
- klaude_code/command/registry.py +7 -5
- klaude_code/command/release_notes_cmd.py +4 -4
- klaude_code/command/resume_cmd.py +15 -11
- klaude_code/command/status_cmd.py +4 -4
- klaude_code/command/terminal_setup_cmd.py +8 -8
- klaude_code/command/thinking_cmd.py +4 -4
- klaude_code/config/assets/builtin_config.yaml +16 -0
- klaude_code/config/builtin_config.py +16 -5
- klaude_code/config/config.py +7 -2
- klaude_code/const.py +146 -91
- klaude_code/core/agent.py +3 -12
- klaude_code/core/executor.py +21 -13
- klaude_code/core/manager/sub_agent_manager.py +71 -7
- klaude_code/core/prompts/prompt-sub-agent-image-gen.md +1 -0
- klaude_code/core/prompts/prompt-sub-agent-web.md +27 -1
- klaude_code/core/reminders.py +88 -69
- klaude_code/core/task.py +44 -45
- klaude_code/core/tool/file/apply_patch_tool.py +9 -9
- klaude_code/core/tool/file/diff_builder.py +3 -5
- klaude_code/core/tool/file/edit_tool.py +23 -23
- klaude_code/core/tool/file/move_tool.py +43 -43
- klaude_code/core/tool/file/read_tool.py +44 -39
- klaude_code/core/tool/file/write_tool.py +14 -14
- klaude_code/core/tool/report_back_tool.py +4 -4
- klaude_code/core/tool/shell/bash_tool.py +23 -23
- klaude_code/core/tool/skill/skill_tool.py +7 -7
- klaude_code/core/tool/sub_agent_tool.py +38 -9
- klaude_code/core/tool/todo/todo_write_tool.py +8 -8
- klaude_code/core/tool/todo/update_plan_tool.py +6 -6
- klaude_code/core/tool/tool_abc.py +2 -2
- klaude_code/core/tool/tool_context.py +27 -0
- klaude_code/core/tool/tool_runner.py +88 -42
- klaude_code/core/tool/truncation.py +38 -20
- klaude_code/core/tool/web/mermaid_tool.py +6 -7
- klaude_code/core/tool/web/web_fetch_tool.py +68 -30
- klaude_code/core/tool/web/web_search_tool.py +15 -17
- klaude_code/core/turn.py +120 -73
- klaude_code/llm/anthropic/client.py +79 -44
- klaude_code/llm/anthropic/input.py +116 -108
- klaude_code/llm/bedrock/client.py +8 -5
- klaude_code/llm/claude/client.py +18 -8
- klaude_code/llm/client.py +4 -3
- klaude_code/llm/codex/client.py +15 -9
- klaude_code/llm/google/client.py +122 -60
- klaude_code/llm/google/input.py +94 -108
- klaude_code/llm/image.py +123 -0
- klaude_code/llm/input_common.py +136 -189
- klaude_code/llm/openai_compatible/client.py +17 -7
- klaude_code/llm/openai_compatible/input.py +36 -66
- klaude_code/llm/openai_compatible/stream.py +119 -67
- klaude_code/llm/openai_compatible/tool_call_accumulator.py +23 -11
- klaude_code/llm/openrouter/client.py +34 -9
- klaude_code/llm/openrouter/input.py +63 -64
- klaude_code/llm/openrouter/reasoning.py +22 -24
- klaude_code/llm/registry.py +20 -17
- klaude_code/llm/responses/client.py +107 -45
- klaude_code/llm/responses/input.py +115 -98
- klaude_code/llm/usage.py +52 -25
- klaude_code/protocol/__init__.py +1 -0
- klaude_code/protocol/events.py +16 -12
- klaude_code/protocol/llm_param.py +20 -2
- klaude_code/protocol/message.py +250 -0
- klaude_code/protocol/model.py +94 -281
- klaude_code/protocol/op.py +2 -2
- klaude_code/protocol/sub_agent/__init__.py +1 -0
- klaude_code/protocol/sub_agent/explore.py +10 -0
- klaude_code/protocol/sub_agent/image_gen.py +119 -0
- klaude_code/protocol/sub_agent/task.py +10 -0
- klaude_code/protocol/sub_agent/web.py +10 -0
- klaude_code/session/codec.py +6 -6
- klaude_code/session/export.py +261 -62
- klaude_code/session/selector.py +7 -24
- klaude_code/session/session.py +126 -54
- klaude_code/session/store.py +5 -32
- klaude_code/session/templates/export_session.html +1 -1
- klaude_code/session/templates/mermaid_viewer.html +1 -1
- klaude_code/trace/log.py +11 -6
- klaude_code/ui/core/input.py +1 -1
- klaude_code/ui/core/stage_manager.py +1 -8
- klaude_code/ui/modes/debug/display.py +2 -2
- klaude_code/ui/modes/repl/clipboard.py +2 -2
- klaude_code/ui/modes/repl/completers.py +18 -10
- klaude_code/ui/modes/repl/event_handler.py +136 -127
- klaude_code/ui/modes/repl/input_prompt_toolkit.py +1 -1
- klaude_code/ui/modes/repl/key_bindings.py +1 -1
- klaude_code/ui/modes/repl/renderer.py +107 -15
- klaude_code/ui/renderers/assistant.py +2 -2
- klaude_code/ui/renderers/common.py +65 -7
- klaude_code/ui/renderers/developer.py +7 -6
- klaude_code/ui/renderers/diffs.py +11 -11
- klaude_code/ui/renderers/mermaid_viewer.py +49 -2
- klaude_code/ui/renderers/metadata.py +33 -5
- klaude_code/ui/renderers/sub_agent.py +57 -16
- klaude_code/ui/renderers/thinking.py +37 -2
- klaude_code/ui/renderers/tools.py +180 -165
- klaude_code/ui/rich/live.py +3 -1
- klaude_code/ui/rich/markdown.py +39 -7
- klaude_code/ui/rich/quote.py +76 -1
- klaude_code/ui/rich/status.py +14 -8
- klaude_code/ui/rich/theme.py +8 -2
- klaude_code/ui/terminal/image.py +34 -0
- klaude_code/ui/terminal/notifier.py +2 -1
- klaude_code/ui/terminal/progress_bar.py +4 -4
- klaude_code/ui/terminal/selector.py +22 -4
- klaude_code/ui/utils/common.py +11 -2
- {klaude_code-1.9.0.dist-info → klaude_code-2.0.0.dist-info}/METADATA +4 -2
- klaude_code-2.0.0.dist-info/RECORD +229 -0
- klaude_code-1.9.0.dist-info/RECORD +0 -224
- {klaude_code-1.9.0.dist-info → klaude_code-2.0.0.dist-info}/WHEEL +0 -0
- {klaude_code-1.9.0.dist-info → klaude_code-2.0.0.dist-info}/entry_points.txt +0 -0
|
@@ -9,7 +9,7 @@ from pydantic import BaseModel, field_validator
|
|
|
9
9
|
from klaude_code.core.tool.tool_abc import ToolABC, load_desc
|
|
10
10
|
from klaude_code.core.tool.tool_context import get_current_todo_context
|
|
11
11
|
from klaude_code.core.tool.tool_registry import register
|
|
12
|
-
from klaude_code.protocol import llm_param, model, tools
|
|
12
|
+
from klaude_code.protocol import llm_param, message, model, tools
|
|
13
13
|
|
|
14
14
|
from .todo_write_tool import get_new_completed_todos
|
|
15
15
|
|
|
@@ -79,15 +79,15 @@ class UpdatePlanTool(ToolABC):
|
|
|
79
79
|
)
|
|
80
80
|
|
|
81
81
|
@classmethod
|
|
82
|
-
async def call(cls, arguments: str) ->
|
|
82
|
+
async def call(cls, arguments: str) -> message.ToolResultMessage:
|
|
83
83
|
try:
|
|
84
84
|
args = UpdatePlanArguments.model_validate_json(arguments)
|
|
85
85
|
except ValueError as exc:
|
|
86
|
-
return
|
|
86
|
+
return message.ToolResultMessage(status="error", output_text=f"Invalid arguments: {exc}")
|
|
87
87
|
|
|
88
88
|
todo_context = get_current_todo_context()
|
|
89
89
|
if todo_context is None:
|
|
90
|
-
return
|
|
90
|
+
return message.ToolResultMessage(status="error", output_text="No active session found")
|
|
91
91
|
|
|
92
92
|
new_todos = [model.TodoItem(content=item.step, status=item.status) for item in args.plan]
|
|
93
93
|
old_todos = todo_context.get_todos()
|
|
@@ -96,9 +96,9 @@ class UpdatePlanTool(ToolABC):
|
|
|
96
96
|
|
|
97
97
|
ui_extra = model.TodoUIExtra(todos=new_todos, new_completed=new_completed)
|
|
98
98
|
|
|
99
|
-
return
|
|
99
|
+
return message.ToolResultMessage(
|
|
100
100
|
status="success",
|
|
101
|
-
|
|
101
|
+
output_text="Plan updated",
|
|
102
102
|
ui_extra=model.TodoListUIExtra(todo_list=ui_extra),
|
|
103
103
|
side_effects=[model.ToolSideEffect.TODO_CHANGE],
|
|
104
104
|
)
|
|
@@ -4,7 +4,7 @@ from dataclasses import dataclass
|
|
|
4
4
|
from enum import Enum
|
|
5
5
|
from pathlib import Path
|
|
6
6
|
|
|
7
|
-
from klaude_code.protocol import llm_param,
|
|
7
|
+
from klaude_code.protocol import llm_param, message
|
|
8
8
|
|
|
9
9
|
|
|
10
10
|
def load_desc(path: Path, substitutions: dict[str, str] | None = None) -> str:
|
|
@@ -27,7 +27,7 @@ class ToolABC(ABC):
|
|
|
27
27
|
|
|
28
28
|
@classmethod
|
|
29
29
|
@abstractmethod
|
|
30
|
-
async def call(cls, arguments: str) ->
|
|
30
|
+
async def call(cls, arguments: str) -> message.ToolResultMessage:
|
|
31
31
|
raise NotImplementedError
|
|
32
32
|
|
|
33
33
|
|
|
@@ -119,3 +119,30 @@ def get_current_todo_context() -> TodoContext | None:
|
|
|
119
119
|
current_run_subtask_callback: ContextVar[Callable[[model.SubAgentState], Awaitable[SubAgentResult]] | None] = (
|
|
120
120
|
ContextVar("current_run_subtask_callback", default=None)
|
|
121
121
|
)
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
# Allows sub-agent execution to record the created/used session id for the currently
|
|
125
|
+
# executing tool call (used by ToolExecutor.cancel() to include session_id in UIExtra).
|
|
126
|
+
current_sub_agent_session_id_recorder: ContextVar[Callable[[str], None] | None] = ContextVar(
|
|
127
|
+
"current_sub_agent_session_id_recorder",
|
|
128
|
+
default=None,
|
|
129
|
+
)
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
def record_sub_agent_session_id(session_id: str) -> None:
|
|
133
|
+
"""Record the sub-agent session id for the current tool call, if supported."""
|
|
134
|
+
|
|
135
|
+
recorder = current_sub_agent_session_id_recorder.get()
|
|
136
|
+
if recorder is None:
|
|
137
|
+
return
|
|
138
|
+
recorder(session_id)
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
# Tracks sub-agent resume claims for the current turn.
|
|
142
|
+
#
|
|
143
|
+
# This is used to reject multiple sub-agent tool calls in the same LLM response
|
|
144
|
+
# that attempt to resume the same agent ID.
|
|
145
|
+
current_sub_agent_resume_claims: ContextVar[set[str] | None] = ContextVar(
|
|
146
|
+
"current_sub_agent_resume_claims",
|
|
147
|
+
default=None,
|
|
148
|
+
)
|
|
@@ -2,14 +2,23 @@ import asyncio
|
|
|
2
2
|
from collections.abc import AsyncGenerator, Callable, Iterable, Sequence
|
|
3
3
|
from dataclasses import dataclass
|
|
4
4
|
|
|
5
|
-
from klaude_code import
|
|
5
|
+
from klaude_code.const import CANCEL_OUTPUT
|
|
6
6
|
from klaude_code.core.tool.report_back_tool import ReportBackTool
|
|
7
7
|
from klaude_code.core.tool.tool_abc import ToolABC, ToolConcurrencyPolicy
|
|
8
|
+
from klaude_code.core.tool.tool_context import current_sub_agent_session_id_recorder
|
|
8
9
|
from klaude_code.core.tool.truncation import truncate_tool_output
|
|
9
|
-
from klaude_code.protocol import model, tools
|
|
10
|
+
from klaude_code.protocol import message, model, tools
|
|
10
11
|
|
|
11
12
|
|
|
12
|
-
|
|
13
|
+
@dataclass(frozen=True)
|
|
14
|
+
class ToolCallRequest:
|
|
15
|
+
response_id: str | None
|
|
16
|
+
call_id: str
|
|
17
|
+
tool_name: str
|
|
18
|
+
arguments_json: str
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
async def run_tool(tool_call: ToolCallRequest, registry: dict[str, type[ToolABC]]) -> message.ToolResultMessage:
|
|
13
22
|
"""Execute a tool call and return the result.
|
|
14
23
|
|
|
15
24
|
Args:
|
|
@@ -20,26 +29,26 @@ async def run_tool(tool_call: model.ToolCallItem, registry: dict[str, type[ToolA
|
|
|
20
29
|
The result of the tool execution.
|
|
21
30
|
"""
|
|
22
31
|
# Special handling for report_back tool (not registered in global registry)
|
|
23
|
-
if tool_call.
|
|
24
|
-
tool_result = await ReportBackTool.call(tool_call.
|
|
32
|
+
if tool_call.tool_name == tools.REPORT_BACK:
|
|
33
|
+
tool_result = await ReportBackTool.call(tool_call.arguments_json)
|
|
25
34
|
tool_result.call_id = tool_call.call_id
|
|
26
|
-
tool_result.tool_name = tool_call.
|
|
35
|
+
tool_result.tool_name = tool_call.tool_name
|
|
27
36
|
return tool_result
|
|
28
37
|
|
|
29
|
-
if tool_call.
|
|
30
|
-
return
|
|
38
|
+
if tool_call.tool_name not in registry:
|
|
39
|
+
return message.ToolResultMessage(
|
|
31
40
|
call_id=tool_call.call_id,
|
|
32
|
-
|
|
41
|
+
output_text=f"Tool {tool_call.tool_name} not exists",
|
|
33
42
|
status="error",
|
|
34
|
-
tool_name=tool_call.
|
|
43
|
+
tool_name=tool_call.tool_name,
|
|
35
44
|
)
|
|
36
45
|
try:
|
|
37
|
-
tool_result = await registry[tool_call.
|
|
46
|
+
tool_result = await registry[tool_call.tool_name].call(tool_call.arguments_json)
|
|
38
47
|
tool_result.call_id = tool_call.call_id
|
|
39
|
-
tool_result.tool_name = tool_call.
|
|
40
|
-
if tool_result.
|
|
41
|
-
truncation_result = truncate_tool_output(tool_result.
|
|
42
|
-
tool_result.
|
|
48
|
+
tool_result.tool_name = tool_call.tool_name
|
|
49
|
+
if tool_result.output_text:
|
|
50
|
+
truncation_result = truncate_tool_output(tool_result.output_text, tool_call)
|
|
51
|
+
tool_result.output_text = truncation_result.output
|
|
43
52
|
if truncation_result.was_truncated and truncation_result.saved_file_path:
|
|
44
53
|
tool_result.ui_extra = model.TruncationUIExtra(
|
|
45
54
|
saved_file_path=truncation_result.saved_file_path,
|
|
@@ -51,11 +60,11 @@ async def run_tool(tool_call: model.ToolCallItem, registry: dict[str, type[ToolA
|
|
|
51
60
|
# Propagate cooperative cancellation so outer layers can handle interrupts correctly.
|
|
52
61
|
raise
|
|
53
62
|
except Exception as e:
|
|
54
|
-
return
|
|
63
|
+
return message.ToolResultMessage(
|
|
55
64
|
call_id=tool_call.call_id,
|
|
56
|
-
|
|
65
|
+
output_text=f"Tool {tool_call.tool_name} execution error: {e.__class__.__name__} {e}",
|
|
57
66
|
status="error",
|
|
58
|
-
tool_name=tool_call.
|
|
67
|
+
tool_name=tool_call.tool_name,
|
|
59
68
|
)
|
|
60
69
|
|
|
61
70
|
|
|
@@ -63,15 +72,18 @@ async def run_tool(tool_call: model.ToolCallItem, registry: dict[str, type[ToolA
|
|
|
63
72
|
class ToolExecutionCallStarted:
|
|
64
73
|
"""Represents the start of a tool call execution."""
|
|
65
74
|
|
|
66
|
-
tool_call:
|
|
75
|
+
tool_call: ToolCallRequest
|
|
67
76
|
|
|
68
77
|
|
|
69
78
|
@dataclass
|
|
70
79
|
class ToolExecutionResult:
|
|
71
80
|
"""Represents the completion of a tool call with its result."""
|
|
72
81
|
|
|
73
|
-
tool_call:
|
|
74
|
-
tool_result:
|
|
82
|
+
tool_call: ToolCallRequest
|
|
83
|
+
tool_result: message.ToolResultMessage
|
|
84
|
+
# Whether this is the last ToolExecutionResult emitted in the current turn.
|
|
85
|
+
# Used by UI to decide whether to close the tree prefix.
|
|
86
|
+
is_last_in_turn: bool = False
|
|
75
87
|
|
|
76
88
|
|
|
77
89
|
@dataclass
|
|
@@ -98,16 +110,17 @@ class ToolExecutor:
|
|
|
98
110
|
self,
|
|
99
111
|
*,
|
|
100
112
|
registry: dict[str, type[ToolABC]],
|
|
101
|
-
append_history: Callable[[Sequence[
|
|
113
|
+
append_history: Callable[[Sequence[message.HistoryEvent]], None],
|
|
102
114
|
) -> None:
|
|
103
115
|
self._registry = registry
|
|
104
116
|
self._append_history = append_history
|
|
105
117
|
|
|
106
|
-
self._unfinished_calls: dict[str,
|
|
118
|
+
self._unfinished_calls: dict[str, ToolCallRequest] = {}
|
|
107
119
|
self._call_event_emitted: set[str] = set()
|
|
108
120
|
self._concurrent_tasks: set[asyncio.Task[list[ToolExecutorEvent]]] = set()
|
|
121
|
+
self._sub_agent_session_ids: dict[str, str] = {}
|
|
109
122
|
|
|
110
|
-
async def run_tools(self, tool_calls: list[
|
|
123
|
+
async def run_tools(self, tool_calls: list[ToolCallRequest]) -> AsyncGenerator[ToolExecutorEvent]:
|
|
111
124
|
"""Run the given tool calls and yield execution events.
|
|
112
125
|
|
|
113
126
|
Tool calls are partitioned into regular tools and sub-agent tools. Regular tools
|
|
@@ -120,8 +133,15 @@ class ToolExecutor:
|
|
|
120
133
|
|
|
121
134
|
sequential_tool_calls, concurrent_tool_calls = self._partition_tool_calls(tool_calls)
|
|
122
135
|
|
|
136
|
+
def _mark_last_in_turn(events_to_mark: list[ToolExecutorEvent], *, is_last_in_turn: bool) -> None:
|
|
137
|
+
if not events_to_mark:
|
|
138
|
+
return
|
|
139
|
+
first = events_to_mark[0]
|
|
140
|
+
if isinstance(first, ToolExecutionResult):
|
|
141
|
+
first.is_last_in_turn = is_last_in_turn
|
|
142
|
+
|
|
123
143
|
# Run sequential tools one by one.
|
|
124
|
-
for tool_call in sequential_tool_calls:
|
|
144
|
+
for idx, tool_call in enumerate(sequential_tool_calls):
|
|
125
145
|
tool_call_event = self._build_tool_call_started(tool_call)
|
|
126
146
|
self._call_event_emitted.add(tool_call.call_id)
|
|
127
147
|
yield tool_call_event
|
|
@@ -132,6 +152,9 @@ class ToolExecutor:
|
|
|
132
152
|
# Propagate cooperative cancellation so the agent task can be stopped.
|
|
133
153
|
raise
|
|
134
154
|
|
|
155
|
+
is_last_in_turn = idx == len(sequential_tool_calls) - 1 and not concurrent_tool_calls
|
|
156
|
+
_mark_last_in_turn(result_events, is_last_in_turn=is_last_in_turn)
|
|
157
|
+
|
|
135
158
|
for exec_event in result_events:
|
|
136
159
|
yield exec_event
|
|
137
160
|
|
|
@@ -147,6 +170,7 @@ class ToolExecutor:
|
|
|
147
170
|
self._register_concurrent_task(task)
|
|
148
171
|
execution_tasks.append(task)
|
|
149
172
|
|
|
173
|
+
remaining = len(execution_tasks)
|
|
150
174
|
for task in asyncio.as_completed(execution_tasks):
|
|
151
175
|
# Do not swallow asyncio.CancelledError here:
|
|
152
176
|
# - If the user interrupts the main agent, the executor cancels the
|
|
@@ -158,6 +182,9 @@ class ToolExecutor:
|
|
|
158
182
|
# calling agent can stop cleanly, matching pre-refactor behavior.
|
|
159
183
|
result_events = await task
|
|
160
184
|
|
|
185
|
+
remaining -= 1
|
|
186
|
+
_mark_last_in_turn(result_events, is_last_in_turn=remaining == 0)
|
|
187
|
+
|
|
161
188
|
for exec_event in result_events:
|
|
162
189
|
yield exec_event
|
|
163
190
|
|
|
@@ -168,7 +195,7 @@ class ToolExecutor:
|
|
|
168
195
|
- For each unfinished tool call, yields a ToolExecutionCallStarted (if not
|
|
169
196
|
already emitted for this turn) followed by a ToolExecutionResult with
|
|
170
197
|
error status and a standard cancellation output. The corresponding
|
|
171
|
-
|
|
198
|
+
ToolResultMessage is appended to history via `append_history`.
|
|
172
199
|
"""
|
|
173
200
|
|
|
174
201
|
events_to_yield: list[ToolExecutorEvent] = []
|
|
@@ -182,23 +209,32 @@ class ToolExecutor:
|
|
|
182
209
|
if not self._unfinished_calls:
|
|
183
210
|
return events_to_yield
|
|
184
211
|
|
|
185
|
-
|
|
186
|
-
|
|
212
|
+
unfinished = list(self._unfinished_calls.items())
|
|
213
|
+
for idx, (call_id, tool_call) in enumerate(unfinished):
|
|
214
|
+
session_id = self._sub_agent_session_ids.get(call_id)
|
|
215
|
+
cancel_result = message.ToolResultMessage(
|
|
187
216
|
call_id=tool_call.call_id,
|
|
188
|
-
|
|
189
|
-
status="
|
|
190
|
-
tool_name=tool_call.
|
|
191
|
-
ui_extra=None,
|
|
217
|
+
output_text=CANCEL_OUTPUT,
|
|
218
|
+
status="aborted",
|
|
219
|
+
tool_name=tool_call.tool_name,
|
|
220
|
+
ui_extra=model.SessionIdUIExtra(session_id=session_id) if session_id else None,
|
|
192
221
|
)
|
|
193
222
|
|
|
194
223
|
if call_id not in self._call_event_emitted:
|
|
195
224
|
events_to_yield.append(ToolExecutionCallStarted(tool_call=tool_call))
|
|
196
225
|
self._call_event_emitted.add(call_id)
|
|
197
226
|
|
|
198
|
-
events_to_yield.append(
|
|
227
|
+
events_to_yield.append(
|
|
228
|
+
ToolExecutionResult(
|
|
229
|
+
tool_call=tool_call,
|
|
230
|
+
tool_result=cancel_result,
|
|
231
|
+
is_last_in_turn=idx == len(unfinished) - 1,
|
|
232
|
+
)
|
|
233
|
+
)
|
|
199
234
|
|
|
200
235
|
self._append_history([cancel_result])
|
|
201
236
|
self._unfinished_calls.pop(call_id, None)
|
|
237
|
+
self._sub_agent_session_ids.pop(call_id, None)
|
|
202
238
|
|
|
203
239
|
return events_to_yield
|
|
204
240
|
|
|
@@ -212,12 +248,12 @@ class ToolExecutor:
|
|
|
212
248
|
|
|
213
249
|
def _partition_tool_calls(
|
|
214
250
|
self,
|
|
215
|
-
tool_calls: list[
|
|
216
|
-
) -> tuple[list[
|
|
217
|
-
sequential_tool_calls: list[
|
|
218
|
-
concurrent_tool_calls: list[
|
|
251
|
+
tool_calls: list[ToolCallRequest],
|
|
252
|
+
) -> tuple[list[ToolCallRequest], list[ToolCallRequest]]:
|
|
253
|
+
sequential_tool_calls: list[ToolCallRequest] = []
|
|
254
|
+
concurrent_tool_calls: list[ToolCallRequest] = []
|
|
219
255
|
for tool_call in tool_calls:
|
|
220
|
-
tool_cls = self._registry.get(tool_call.
|
|
256
|
+
tool_cls = self._registry.get(tool_call.tool_name)
|
|
221
257
|
policy = (
|
|
222
258
|
tool_cls.metadata().concurrency_policy if tool_cls is not None else ToolConcurrencyPolicy.SEQUENTIAL
|
|
223
259
|
)
|
|
@@ -227,22 +263,32 @@ class ToolExecutor:
|
|
|
227
263
|
sequential_tool_calls.append(tool_call)
|
|
228
264
|
return sequential_tool_calls, concurrent_tool_calls
|
|
229
265
|
|
|
230
|
-
def _build_tool_call_started(self, tool_call:
|
|
266
|
+
def _build_tool_call_started(self, tool_call: ToolCallRequest) -> ToolExecutionCallStarted:
|
|
231
267
|
return ToolExecutionCallStarted(tool_call=tool_call)
|
|
232
268
|
|
|
233
|
-
async def _run_single_tool_call(self, tool_call:
|
|
234
|
-
|
|
269
|
+
async def _run_single_tool_call(self, tool_call: ToolCallRequest) -> list[ToolExecutorEvent]:
|
|
270
|
+
def _record_sub_agent_session_id(session_id: str) -> None:
|
|
271
|
+
# Keep the first recorded id if multiple writes happen.
|
|
272
|
+
if tool_call.call_id not in self._sub_agent_session_ids:
|
|
273
|
+
self._sub_agent_session_ids[tool_call.call_id] = session_id
|
|
274
|
+
|
|
275
|
+
recorder_token = current_sub_agent_session_id_recorder.set(_record_sub_agent_session_id)
|
|
276
|
+
try:
|
|
277
|
+
tool_result: message.ToolResultMessage = await run_tool(tool_call, self._registry)
|
|
278
|
+
finally:
|
|
279
|
+
current_sub_agent_session_id_recorder.reset(recorder_token)
|
|
235
280
|
|
|
236
281
|
self._append_history([tool_result])
|
|
237
282
|
|
|
238
283
|
result_event = ToolExecutionResult(tool_call=tool_call, tool_result=tool_result)
|
|
239
284
|
|
|
240
285
|
self._unfinished_calls.pop(tool_call.call_id, None)
|
|
286
|
+
self._sub_agent_session_ids.pop(tool_call.call_id, None)
|
|
241
287
|
|
|
242
288
|
extra_events = self._build_tool_side_effect_events(tool_result)
|
|
243
289
|
return [result_event, *extra_events]
|
|
244
290
|
|
|
245
|
-
def _build_tool_side_effect_events(self, tool_result:
|
|
291
|
+
def _build_tool_side_effect_events(self, tool_result: message.ToolResultMessage) -> list[ToolExecutorEvent]:
|
|
246
292
|
side_effects = tool_result.side_effects
|
|
247
293
|
if not side_effects:
|
|
248
294
|
return []
|
|
@@ -4,10 +4,28 @@ import time
|
|
|
4
4
|
from abc import ABC, abstractmethod
|
|
5
5
|
from dataclasses import dataclass
|
|
6
6
|
from pathlib import Path
|
|
7
|
+
from typing import Protocol
|
|
7
8
|
from urllib.parse import urlparse
|
|
8
9
|
|
|
9
|
-
from klaude_code import
|
|
10
|
-
|
|
10
|
+
from klaude_code.const import (
|
|
11
|
+
TOOL_OUTPUT_DISPLAY_HEAD,
|
|
12
|
+
TOOL_OUTPUT_DISPLAY_TAIL,
|
|
13
|
+
TOOL_OUTPUT_MAX_LENGTH,
|
|
14
|
+
TOOL_OUTPUT_TRUNCATION_DIR,
|
|
15
|
+
URL_FILENAME_MAX_LENGTH,
|
|
16
|
+
)
|
|
17
|
+
from klaude_code.protocol import tools
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class ToolCallLike(Protocol):
|
|
21
|
+
@property
|
|
22
|
+
def tool_name(self) -> str: ...
|
|
23
|
+
|
|
24
|
+
@property
|
|
25
|
+
def call_id(self) -> str: ...
|
|
26
|
+
|
|
27
|
+
@property
|
|
28
|
+
def arguments_json(self) -> str: ...
|
|
11
29
|
|
|
12
30
|
|
|
13
31
|
@dataclass
|
|
@@ -40,14 +58,14 @@ def _extract_url_filename(url: str) -> str:
|
|
|
40
58
|
# Sanitize: keep only alphanumeric, underscore, hyphen
|
|
41
59
|
name = re.sub(r"[^a-zA-Z0-9_\-]", "_", name)
|
|
42
60
|
# Limit length
|
|
43
|
-
return name[:
|
|
61
|
+
return name[:URL_FILENAME_MAX_LENGTH] if len(name) > URL_FILENAME_MAX_LENGTH else name
|
|
44
62
|
|
|
45
63
|
|
|
46
64
|
class TruncationStrategy(ABC):
|
|
47
65
|
"""Abstract base class for tool output truncation strategies."""
|
|
48
66
|
|
|
49
67
|
@abstractmethod
|
|
50
|
-
def truncate(self, output: str, tool_call:
|
|
68
|
+
def truncate(self, output: str, tool_call: ToolCallLike | None = None) -> TruncationResult:
|
|
51
69
|
"""Truncate the output according to the strategy."""
|
|
52
70
|
...
|
|
53
71
|
|
|
@@ -55,13 +73,13 @@ class TruncationStrategy(ABC):
|
|
|
55
73
|
class SimpleTruncationStrategy(TruncationStrategy):
|
|
56
74
|
"""Simple character-based truncation strategy."""
|
|
57
75
|
|
|
58
|
-
def __init__(self, max_length: int =
|
|
76
|
+
def __init__(self, max_length: int = TOOL_OUTPUT_MAX_LENGTH):
|
|
59
77
|
self.max_length = max_length
|
|
60
78
|
|
|
61
|
-
def truncate(self, output: str, tool_call:
|
|
79
|
+
def truncate(self, output: str, tool_call: ToolCallLike | None = None) -> TruncationResult:
|
|
62
80
|
if len(output) > self.max_length:
|
|
63
81
|
truncated_length = len(output) - self.max_length
|
|
64
|
-
truncated_output = output[: self.max_length] + f"
|
|
82
|
+
truncated_output = output[: self.max_length] + f"… (truncated {truncated_length} characters)"
|
|
65
83
|
return TruncationResult(
|
|
66
84
|
output=truncated_output,
|
|
67
85
|
was_truncated=True,
|
|
@@ -76,21 +94,21 @@ class SmartTruncationStrategy(TruncationStrategy):
|
|
|
76
94
|
|
|
77
95
|
def __init__(
|
|
78
96
|
self,
|
|
79
|
-
max_length: int =
|
|
80
|
-
head_chars: int =
|
|
81
|
-
tail_chars: int =
|
|
82
|
-
truncation_dir: str =
|
|
97
|
+
max_length: int = TOOL_OUTPUT_MAX_LENGTH,
|
|
98
|
+
head_chars: int = TOOL_OUTPUT_DISPLAY_HEAD,
|
|
99
|
+
tail_chars: int = TOOL_OUTPUT_DISPLAY_TAIL,
|
|
100
|
+
truncation_dir: str = TOOL_OUTPUT_TRUNCATION_DIR,
|
|
83
101
|
):
|
|
84
102
|
self.max_length = max_length
|
|
85
103
|
self.head_chars = head_chars
|
|
86
104
|
self.tail_chars = tail_chars
|
|
87
105
|
self.truncation_dir = Path(truncation_dir)
|
|
88
106
|
|
|
89
|
-
def _get_file_identifier(self, tool_call:
|
|
107
|
+
def _get_file_identifier(self, tool_call: ToolCallLike | None) -> str:
|
|
90
108
|
"""Get a file identifier based on tool call. For WebFetch, use URL; otherwise use call_id."""
|
|
91
|
-
if tool_call and tool_call.
|
|
109
|
+
if tool_call and tool_call.tool_name == tools.WEB_FETCH:
|
|
92
110
|
try:
|
|
93
|
-
args = json.loads(tool_call.
|
|
111
|
+
args = json.loads(tool_call.arguments_json)
|
|
94
112
|
url = args.get("url", "")
|
|
95
113
|
if url:
|
|
96
114
|
return _extract_url_filename(url)
|
|
@@ -101,12 +119,12 @@ class SmartTruncationStrategy(TruncationStrategy):
|
|
|
101
119
|
return tool_call.call_id.replace("/", "_")
|
|
102
120
|
return "unknown"
|
|
103
121
|
|
|
104
|
-
def _save_to_file(self, output: str, tool_call:
|
|
122
|
+
def _save_to_file(self, output: str, tool_call: ToolCallLike | None) -> str | None:
|
|
105
123
|
"""Save full output to file. Returns file path or None on failure."""
|
|
106
124
|
try:
|
|
107
125
|
self.truncation_dir.mkdir(parents=True, exist_ok=True)
|
|
108
126
|
timestamp = int(time.time())
|
|
109
|
-
tool_name = (tool_call.
|
|
127
|
+
tool_name = (tool_call.tool_name if tool_call else "unknown").replace("/", "_")
|
|
110
128
|
identifier = self._get_file_identifier(tool_call)
|
|
111
129
|
filename = f"{tool_name}-{identifier}-{timestamp}.txt"
|
|
112
130
|
file_path = self.truncation_dir / filename
|
|
@@ -115,8 +133,8 @@ class SmartTruncationStrategy(TruncationStrategy):
|
|
|
115
133
|
except OSError:
|
|
116
134
|
return None
|
|
117
135
|
|
|
118
|
-
def truncate(self, output: str, tool_call:
|
|
119
|
-
if tool_call and tool_call.
|
|
136
|
+
def truncate(self, output: str, tool_call: ToolCallLike | None = None) -> TruncationResult:
|
|
137
|
+
if tool_call and tool_call.tool_name == tools.READ:
|
|
120
138
|
# Do not truncate Read tool outputs
|
|
121
139
|
return TruncationResult(output=output, was_truncated=False, original_length=len(output))
|
|
122
140
|
|
|
@@ -153,7 +171,7 @@ class SmartTruncationStrategy(TruncationStrategy):
|
|
|
153
171
|
|
|
154
172
|
truncated_output = (
|
|
155
173
|
f"{header}{head_content}\n\n"
|
|
156
|
-
f"<system-reminder
|
|
174
|
+
f"<system-reminder>… {truncated_length} characters omitted …</system-reminder>\n\n"
|
|
157
175
|
f"{tail_content}"
|
|
158
176
|
)
|
|
159
177
|
|
|
@@ -180,6 +198,6 @@ def set_truncation_strategy(strategy: TruncationStrategy) -> None:
|
|
|
180
198
|
_default_strategy = strategy
|
|
181
199
|
|
|
182
200
|
|
|
183
|
-
def truncate_tool_output(output: str, tool_call:
|
|
201
|
+
def truncate_tool_output(output: str, tool_call: ToolCallLike | None = None) -> TruncationResult:
|
|
184
202
|
"""Truncate tool output using the current strategy."""
|
|
185
203
|
return get_truncation_strategy().truncate(output, tool_call)
|
|
@@ -7,11 +7,10 @@ from pathlib import Path
|
|
|
7
7
|
|
|
8
8
|
from pydantic import BaseModel, Field
|
|
9
9
|
|
|
10
|
+
from klaude_code.const import MERMAID_LIVE_PREFIX
|
|
10
11
|
from klaude_code.core.tool.tool_abc import ToolABC, load_desc
|
|
11
12
|
from klaude_code.core.tool.tool_registry import register
|
|
12
|
-
from klaude_code.protocol import llm_param, model, tools
|
|
13
|
-
|
|
14
|
-
_MERMAID_LIVE_PREFIX = "https://mermaid.live/view#pako:"
|
|
13
|
+
from klaude_code.protocol import llm_param, message, model, tools
|
|
15
14
|
|
|
16
15
|
|
|
17
16
|
@register(tools.MERMAID)
|
|
@@ -41,17 +40,17 @@ class MermaidTool(ToolABC):
|
|
|
41
40
|
)
|
|
42
41
|
|
|
43
42
|
@classmethod
|
|
44
|
-
async def call(cls, arguments: str) ->
|
|
43
|
+
async def call(cls, arguments: str) -> message.ToolResultMessage:
|
|
45
44
|
try:
|
|
46
45
|
args = cls.MermaidArguments.model_validate_json(arguments)
|
|
47
46
|
except Exception as exc: # pragma: no cover - defensive
|
|
48
|
-
return
|
|
47
|
+
return message.ToolResultMessage(status="error", output_text=f"Invalid arguments: {exc}")
|
|
49
48
|
|
|
50
49
|
link = cls._build_link(args.code)
|
|
51
50
|
line_count = cls._count_lines(args.code)
|
|
52
51
|
ui_extra = model.MermaidLinkUIExtra(code=args.code, link=link, line_count=line_count)
|
|
53
52
|
output = f"Mermaid diagram rendered successfully ({line_count} lines)."
|
|
54
|
-
return
|
|
53
|
+
return message.ToolResultMessage(status="success", output_text=output, ui_extra=ui_extra)
|
|
55
54
|
|
|
56
55
|
@staticmethod
|
|
57
56
|
def _build_link(code: str) -> str:
|
|
@@ -64,7 +63,7 @@ class MermaidTool(ToolABC):
|
|
|
64
63
|
json_payload = json.dumps(state, ensure_ascii=False)
|
|
65
64
|
compressed = zlib.compress(json_payload.encode("utf-8"), level=9)
|
|
66
65
|
encoded = base64.urlsafe_b64encode(compressed).decode("ascii").rstrip("=")
|
|
67
|
-
return f"{
|
|
66
|
+
return f"{MERMAID_LIVE_PREFIX}{encoded}"
|
|
68
67
|
|
|
69
68
|
@staticmethod
|
|
70
69
|
def _count_lines(code: str) -> int:
|