klaude-code 1.2.8__py3-none-any.whl → 1.2.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- klaude_code/auth/codex/__init__.py +1 -1
- klaude_code/cli/main.py +12 -1
- klaude_code/cli/runtime.py +7 -11
- klaude_code/command/__init__.py +68 -21
- klaude_code/command/clear_cmd.py +6 -2
- klaude_code/command/command_abc.py +5 -2
- klaude_code/command/diff_cmd.py +5 -2
- klaude_code/command/export_cmd.py +7 -4
- klaude_code/command/help_cmd.py +6 -2
- klaude_code/command/model_cmd.py +5 -2
- klaude_code/command/prompt-deslop.md +14 -0
- klaude_code/command/prompt_command.py +8 -3
- klaude_code/command/refresh_cmd.py +6 -2
- klaude_code/command/registry.py +17 -5
- klaude_code/command/release_notes_cmd.py +89 -0
- klaude_code/command/status_cmd.py +98 -56
- klaude_code/command/terminal_setup_cmd.py +7 -4
- klaude_code/const/__init__.py +1 -1
- klaude_code/core/agent.py +66 -26
- klaude_code/core/executor.py +2 -2
- klaude_code/core/manager/agent_manager.py +6 -7
- klaude_code/core/manager/llm_clients.py +47 -22
- klaude_code/core/manager/llm_clients_builder.py +19 -7
- klaude_code/core/manager/sub_agent_manager.py +6 -2
- klaude_code/core/prompt.py +38 -28
- klaude_code/core/reminders.py +4 -7
- klaude_code/core/task.py +59 -40
- klaude_code/core/tool/__init__.py +2 -0
- klaude_code/core/tool/file/_utils.py +30 -0
- klaude_code/core/tool/file/apply_patch_tool.py +1 -1
- klaude_code/core/tool/file/edit_tool.py +6 -31
- klaude_code/core/tool/file/multi_edit_tool.py +7 -32
- klaude_code/core/tool/file/read_tool.py +6 -18
- klaude_code/core/tool/file/write_tool.py +6 -31
- klaude_code/core/tool/memory/__init__.py +5 -0
- klaude_code/core/tool/memory/memory_tool.py +2 -2
- klaude_code/core/tool/memory/skill_loader.py +2 -1
- klaude_code/core/tool/memory/skill_tool.py +13 -0
- klaude_code/core/tool/sub_agent_tool.py +2 -1
- klaude_code/core/tool/todo/todo_write_tool.py +1 -1
- klaude_code/core/tool/todo/update_plan_tool.py +1 -1
- klaude_code/core/tool/tool_context.py +21 -4
- klaude_code/core/tool/tool_runner.py +5 -8
- klaude_code/core/tool/web/mermaid_tool.py +1 -4
- klaude_code/core/turn.py +40 -37
- klaude_code/llm/__init__.py +2 -12
- klaude_code/llm/anthropic/client.py +14 -44
- klaude_code/llm/client.py +2 -2
- klaude_code/llm/codex/client.py +4 -3
- klaude_code/llm/input_common.py +0 -6
- klaude_code/llm/openai_compatible/client.py +31 -74
- klaude_code/llm/openai_compatible/input.py +6 -4
- klaude_code/llm/openai_compatible/stream_processor.py +82 -0
- klaude_code/llm/openrouter/client.py +32 -62
- klaude_code/llm/openrouter/input.py +4 -27
- klaude_code/llm/registry.py +33 -7
- klaude_code/llm/responses/client.py +16 -48
- klaude_code/llm/responses/input.py +1 -1
- klaude_code/llm/usage.py +61 -11
- klaude_code/protocol/commands.py +1 -0
- klaude_code/protocol/events.py +11 -2
- klaude_code/protocol/model.py +147 -24
- klaude_code/protocol/op.py +1 -0
- klaude_code/protocol/sub_agent.py +5 -1
- klaude_code/session/export.py +56 -32
- klaude_code/session/session.py +43 -21
- klaude_code/session/templates/export_session.html +4 -1
- klaude_code/ui/core/input.py +1 -1
- klaude_code/ui/modes/repl/__init__.py +1 -5
- klaude_code/ui/modes/repl/clipboard.py +5 -5
- klaude_code/ui/modes/repl/event_handler.py +153 -54
- klaude_code/ui/modes/repl/renderer.py +4 -4
- klaude_code/ui/renderers/developer.py +35 -25
- klaude_code/ui/renderers/metadata.py +68 -30
- klaude_code/ui/renderers/tools.py +53 -87
- klaude_code/ui/rich/markdown.py +5 -5
- klaude_code/ui/terminal/control.py +2 -2
- klaude_code/version.py +3 -3
- {klaude_code-1.2.8.dist-info → klaude_code-1.2.10.dist-info}/METADATA +1 -1
- {klaude_code-1.2.8.dist-info → klaude_code-1.2.10.dist-info}/RECORD +82 -78
- {klaude_code-1.2.8.dist-info → klaude_code-1.2.10.dist-info}/WHEEL +0 -0
- {klaude_code-1.2.8.dist-info → klaude_code-1.2.10.dist-info}/entry_points.txt +0 -0
klaude_code/llm/input_common.py
CHANGED

@@ -49,10 +49,6 @@ class AssistantGroup:
 
     text_content: str | None = None
     tool_calls: list[model.ToolCallItem] = field(default_factory=lambda: [])
-    reasoning_text: list[model.ReasoningTextItem] = field(default_factory=lambda: [])
-    reasoning_encrypted: list[model.ReasoningEncryptedItem] = field(default_factory=lambda: [])
-    # Preserve original ordering of reasoning items for providers that
-    # need to emit them as an ordered stream (e.g. OpenRouter).
    reasoning_items: list[model.ReasoningTextItem | model.ReasoningEncryptedItem] = field(default_factory=lambda: [])
 
 
@@ -184,10 +180,8 @@ def parse_message_groups(history: list[model.ConversationItem]) -> list[MessageGroup]:
             case model.ToolCallItem():
                 group.tool_calls.append(item)
             case model.ReasoningTextItem():
-                group.reasoning_text.append(item)
                 group.reasoning_items.append(item)
             case model.ReasoningEncryptedItem():
-                group.reasoning_encrypted.append(item)
                 group.reasoning_items.append(item)
             case _:
                 pass
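The change above collapses AssistantGroup's two type-specific reasoning lists into the single ordered reasoning_items list, so providers that must replay reasoning as an ordered stream (e.g. OpenRouter) see text and encrypted items in their original interleaving. A minimal sketch of why one list is needed, using illustrative stand-in types rather than the package's model classes:

from dataclasses import dataclass, field

@dataclass
class ReasoningText:  # stand-in for model.ReasoningTextItem
    content: str

@dataclass
class ReasoningEncrypted:  # stand-in for model.ReasoningEncryptedItem
    blob: str

@dataclass
class Group:
    # One ordered list preserves the interleaving that two per-type lists would lose.
    reasoning_items: list[ReasoningText | ReasoningEncrypted] = field(default_factory=list)

g = Group()
for item in (ReasoningText("plan"), ReasoningEncrypted("opaque=="), ReasoningText("revise")):
    g.reasoning_items.append(item)

assert [type(i).__name__ for i in g.reasoning_items] == ["ReasoningText", "ReasoningEncrypted", "ReasoningText"]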
klaude_code/llm/openai_compatible/client.py
CHANGED

@@ -1,15 +1,14 @@
 import json
 from collections.abc import AsyncGenerator
-from typing import Literal, override
+from typing import override
 
 import httpx
 import openai
 
-
 from klaude_code.llm.client import LLMClientABC, call_with_logged_payload
 from klaude_code.llm.input_common import apply_config_defaults
 from klaude_code.llm.openai_compatible.input import convert_history_to_input, convert_tool_schema
-from klaude_code.llm.openai_compatible.tool_call_accumulator import BasicToolCallAccumulator, ToolCallAccumulatorABC
+from klaude_code.llm.openai_compatible.stream_processor import StreamStateManager
 from klaude_code.llm.registry import register
 from klaude_code.llm.usage import MetadataTracker, convert_usage
 from klaude_code.protocol import llm_param, model
@@ -48,10 +47,10 @@ class OpenAICompatibleClient(LLMClientABC):
         messages = convert_history_to_input(param.input, param.system, param.model)
         tools = convert_tool_schema(param.tools)
 
-        metadata_tracker = MetadataTracker(cost_config=self.
+        metadata_tracker = MetadataTracker(cost_config=self.get_llm_config().cost)
 
         extra_body = {}
-        extra_headers = {"extra": json.dumps({"session_id": param.session_id})}
+        extra_headers = {"extra": json.dumps({"session_id": param.session_id}, sort_keys=True)}
 
         if param.thinking:
             extra_body["thinking"] = {
@@ -74,42 +73,7 @@ class OpenAICompatibleClient(LLMClientABC):
             extra_headers=extra_headers,
         )
 
-
-        accumulated_reasoning: list[str] = []
-        accumulated_content: list[str] = []
-        accumulated_tool_calls: ToolCallAccumulatorABC = BasicToolCallAccumulator()
-        emitted_tool_start_indices: set[int] = set()
-        response_id: str | None = None
-
-        def flush_reasoning_items() -> list[model.ConversationItem]:
-            nonlocal accumulated_reasoning
-            if not accumulated_reasoning:
-                return []
-            item = model.ReasoningTextItem(
-                content="".join(accumulated_reasoning),
-                response_id=response_id,
-                model=str(param.model),
-            )
-            accumulated_reasoning = []
-            return [item]
-
-        def flush_assistant_items() -> list[model.ConversationItem]:
-            nonlocal accumulated_content
-            if len(accumulated_content) == 0:
-                return []
-            item = model.AssistantMessageItem(
-                content="".join(accumulated_content),
-                response_id=response_id,
-            )
-            accumulated_content = []
-            return [item]
-
-        def flush_tool_call_items() -> list[model.ToolCallItem]:
-            nonlocal accumulated_tool_calls
-            items: list[model.ToolCallItem] = accumulated_tool_calls.get()
-            if items:
-                accumulated_tool_calls.chunks_by_step = []  # pyright: ignore[reportAttributeAccessIssue]
-            return items
+        state = StreamStateManager(param_model=str(param.model))
 
         try:
             async for event in await stream:
@@ -118,14 +82,13 @@ class OpenAICompatibleClient(LLMClientABC):
                         style="blue",
                         debug_type=DebugType.LLM_STREAM,
                     )
-                if not response_id and event.id:
-                    response_id = event.id
-                    accumulated_tool_calls.response_id = event.id  # pyright: ignore[reportAttributeAccessIssue]
-                    yield model.StartItem(response_id=response_id)
+                if not state.response_id and event.id:
+                    state.set_response_id(event.id)
+                    yield model.StartItem(response_id=event.id)
                 if (
                     event.usage is not None and event.usage.completion_tokens is not None  # pyright: ignore[reportUnnecessaryComparison] gcp gemini will return None usage field
                 ):
-                    metadata_tracker.set_usage(convert_usage(event.usage, param.context_limit))
+                    metadata_tracker.set_usage(convert_usage(event.usage, param.context_limit, param.max_tokens))
                 if event.model:
                     metadata_tracker.set_model_name(event.model)
                 if provider := getattr(event, "provider", None):
@@ -141,6 +104,7 @@ class OpenAICompatibleClient(LLMClientABC):
                             convert_usage(
                                 openai.types.CompletionUsage.model_validate(getattr(event.choices[0], "usage")),
                                 param.context_limit,
+                                param.max_tokens,
                             )
                         )
 
@@ -152,60 +116,53 @@ class OpenAICompatibleClient(LLMClientABC):
                 reasoning_content = getattr(delta, "reasoning_content")
                 if reasoning_content:
                     metadata_tracker.record_token()
-                    stage = "reasoning"
-                    accumulated_reasoning.append(reasoning_content)
+                    state.stage = "reasoning"
+                    state.accumulated_reasoning.append(reasoning_content)
 
                 # Assistant
                 if delta.content and (
-                    stage == "assistant" or delta.content.strip()
+                    state.stage == "assistant" or delta.content.strip()
                 ):  # Process all content in assistant stage, filter empty content in reasoning stage
                     metadata_tracker.record_token()
-                    if stage == "reasoning":
-                        for item in flush_reasoning_items():
+                    if state.stage == "reasoning":
+                        for item in state.flush_reasoning():
                             yield item
-                    elif stage == "tool":
-                        for item in flush_tool_call_items():
+                    elif state.stage == "tool":
+                        for item in state.flush_tool_calls():
                             yield item
-                    stage = "assistant"
-                    accumulated_content.append(delta.content)
+                    state.stage = "assistant"
+                    state.accumulated_content.append(delta.content)
                     yield model.AssistantMessageDelta(
                         content=delta.content,
-                        response_id=response_id,
+                        response_id=state.response_id,
                     )
 
                 # Tool
                 if delta.tool_calls and len(delta.tool_calls) > 0:
                     metadata_tracker.record_token()
-                    if stage == "reasoning":
-                        for item in flush_reasoning_items():
+                    if state.stage == "reasoning":
+                        for item in state.flush_reasoning():
                             yield item
-                    elif stage == "assistant":
-                        for item in flush_assistant_items():
+                    elif state.stage == "assistant":
+                        for item in state.flush_assistant():
                             yield item
-                    stage = "tool"
+                    state.stage = "tool"
                     # Emit ToolCallStartItem for new tool calls
                     for tc in delta.tool_calls:
-                        if tc.index not in emitted_tool_start_indices and tc.function and tc.function.name:
-                            emitted_tool_start_indices.add(tc.index)
+                        if tc.index not in state.emitted_tool_start_indices and tc.function and tc.function.name:
+                            state.emitted_tool_start_indices.add(tc.index)
                             yield model.ToolCallStartItem(
-                                response_id=response_id,
+                                response_id=state.response_id,
                                 call_id=tc.id or "",
                                 name=tc.function.name,
                             )
-                    accumulated_tool_calls.add(delta.tool_calls)
+                    state.accumulated_tool_calls.add(delta.tool_calls)
         except (openai.OpenAIError, httpx.HTTPError) as e:
             yield model.StreamErrorItem(error=f"{e.__class__.__name__} {str(e)}")
 
         # Finalize
-        for item in flush_reasoning_items():
+        for item in state.flush_all():
            yield item
 
-        for item in flush_assistant_items():
-            yield item
-
-        if stage == "tool":
-            for tool_call_item in flush_tool_call_items():
-                yield tool_call_item
-
-        metadata_tracker.set_response_id(response_id)
+        metadata_tracker.set_response_id(state.response_id)
         yield metadata_tracker.finalize()
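The streaming loop above is a small stage machine: deltas of one kind accumulate, and the arrival of a different kind flushes the previous stage's buffer as a complete item, with flush_all() as the final sweep. A compressed, self-contained sketch of that flush-on-transition discipline (illustrative only, not the client's actual types):

def replay(deltas: list[tuple[str, str]]) -> list[str]:
    """Accumulate same-kind deltas; flush the buffer whenever the kind changes."""
    stage, buffer, out = "waiting", [], []
    for kind, text in deltas:
        if kind != stage and buffer:
            out.append(f"{stage}: {''.join(buffer)}")  # flush the finished stage
            buffer = []
        stage = kind
        buffer.append(text)
    if buffer:
        out.append(f"{stage}: {''.join(buffer)}")  # final sweep, like flush_all()
    return out

assert replay([("reasoning", "a"), ("reasoning", "b"), ("assistant", "hi")]) == ["reasoning: ab", "assistant: hi"]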
klaude_code/llm/openai_compatible/input.py
CHANGED

@@ -10,7 +10,8 @@ from klaude_code.llm.input_common import AssistantGroup, ToolGroup, UserGroup, merge_reminder_text, parse_message_groups
 from klaude_code.protocol import llm_param, model
 
 
-def _user_group_to_message(group: UserGroup) -> chat.ChatCompletionMessageParam:
+def user_group_to_openai_message(group: UserGroup) -> chat.ChatCompletionMessageParam:
+    """Convert a UserGroup to an OpenAI-compatible chat message."""
     parts: list[ChatCompletionContentPartParam] = []
     for text in group.text_parts:
         parts.append({"type": "text", "text": text + "\n"})
@@ -21,7 +22,8 @@ def _user_group_to_message(group: UserGroup) -> chat.ChatCompletionMessageParam:
     return {"role": "user", "content": parts}
 
 
-def _tool_group_to_message(group: ToolGroup) -> chat.ChatCompletionMessageParam:
+def tool_group_to_openai_message(group: ToolGroup) -> chat.ChatCompletionMessageParam:
+    """Convert a ToolGroup to an OpenAI-compatible chat message."""
     merged_text = merge_reminder_text(
         group.tool_result.output or "<system-reminder>Tool ran without output or errors</system-reminder>",
         group.reminder_texts,
@@ -82,9 +84,9 @@ def convert_history_to_input(
     for group in parse_message_groups(history):
         match group:
             case UserGroup():
-                messages.append(_user_group_to_message(group))
+                messages.append(user_group_to_openai_message(group))
             case ToolGroup():
-                messages.append(_tool_group_to_message(group))
+                messages.append(tool_group_to_openai_message(group))
             case AssistantGroup():
                 messages.append(_assistant_group_to_message(group))
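The underscore prefixes are dropped because openrouter/input.py now imports these two converters instead of keeping private copies (see its diff below). For reference, the message shapes they produce, as read from the code above and from the copies removed from openrouter/input.py:

# user_group_to_openai_message(group) ->
#     {"role": "user", "content": [{"type": "text", ...}, {"type": "image_url", ...}]}
# tool_group_to_openai_message(group) ->
#     {"role": "tool", "content": [{"type": "text", "text": merged_text}],
#      "tool_call_id": group.tool_result.call_id}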
klaude_code/llm/openai_compatible/stream_processor.py
ADDED

@@ -0,0 +1,82 @@
+"""Shared stream processing utilities for OpenAI-compatible clients.
+
+This module provides a reusable stream state manager that handles the common
+logic for accumulating and flushing reasoning, assistant content, and tool calls
+across different LLM providers (OpenAI-compatible, OpenRouter).
+"""
+
+from typing import Callable, Literal
+
+from klaude_code.llm.openai_compatible.tool_call_accumulator import BasicToolCallAccumulator, ToolCallAccumulatorABC
+from klaude_code.protocol import model
+
+StreamStage = Literal["waiting", "reasoning", "assistant", "tool"]
+
+
+class StreamStateManager:
+    """Manages streaming state and provides flush operations for accumulated content.
+
+    This class encapsulates the common state management logic used by both
+    OpenAI-compatible and OpenRouter clients, reducing code duplication.
+    """
+
+    def __init__(
+        self,
+        param_model: str,
+        response_id: str | None = None,
+        reasoning_flusher: Callable[[], list[model.ConversationItem]] | None = None,
+    ):
+        self.param_model = param_model
+        self.response_id = response_id
+        self.stage: StreamStage = "waiting"
+        self.accumulated_reasoning: list[str] = []
+        self.accumulated_content: list[str] = []
+        self.accumulated_tool_calls: ToolCallAccumulatorABC = BasicToolCallAccumulator()
+        self.emitted_tool_start_indices: set[int] = set()
+        self._reasoning_flusher = reasoning_flusher
+
+    def set_response_id(self, response_id: str) -> None:
+        """Set the response ID once received from the stream."""
+        self.response_id = response_id
+        self.accumulated_tool_calls.response_id = response_id  # pyright: ignore[reportAttributeAccessIssue]
+
+    def flush_reasoning(self) -> list[model.ConversationItem]:
+        """Flush accumulated reasoning content and return items."""
+        if self._reasoning_flusher is not None:
+            return self._reasoning_flusher()
+        if not self.accumulated_reasoning:
+            return []
+        item = model.ReasoningTextItem(
+            content="".join(self.accumulated_reasoning),
+            response_id=self.response_id,
+            model=self.param_model,
+        )
+        self.accumulated_reasoning = []
+        return [item]
+
+    def flush_assistant(self) -> list[model.ConversationItem]:
+        """Flush accumulated assistant content and return items."""
+        if not self.accumulated_content:
+            return []
+        item = model.AssistantMessageItem(
+            content="".join(self.accumulated_content),
+            response_id=self.response_id,
+        )
+        self.accumulated_content = []
+        return [item]
+
+    def flush_tool_calls(self) -> list[model.ToolCallItem]:
+        """Flush accumulated tool calls and return items."""
+        items: list[model.ToolCallItem] = self.accumulated_tool_calls.get()
+        if items:
+            self.accumulated_tool_calls.chunks_by_step = []  # pyright: ignore[reportAttributeAccessIssue]
+        return items
+
+    def flush_all(self) -> list[model.ConversationItem]:
+        """Flush all accumulated content in order: reasoning, assistant, tool calls."""
+        items: list[model.ConversationItem] = []
+        items.extend(self.flush_reasoning())
+        items.extend(self.flush_assistant())
+        if self.stage == "tool":
+            items.extend(self.flush_tool_calls())
+        return items
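A short driving example for the new manager, mirroring how both clients use it; the model name and response id below are illustrative values, not anything shipped in the package:

state = StreamStateManager(param_model="provider/model-x")  # illustrative model name
state.set_response_id("resp_123")  # illustrative id; called once the stream announces it

state.stage = "reasoning"
state.accumulated_reasoning.append("step one...")
items = state.flush_reasoning()  # -> [model.ReasoningTextItem(content="step one...", ...)]

state.stage = "assistant"
state.accumulated_content.append("Hello")
items += state.flush_all()  # flushes assistant text; tool calls only flush when stage == "tool"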
klaude_code/llm/openrouter/client.py
CHANGED

@@ -1,5 +1,5 @@
 from collections.abc import AsyncGenerator
-from typing import Literal, override
+from typing import override
 
 import httpx
 import openai
@@ -7,7 +7,7 @@ import openai
 from klaude_code.llm.client import LLMClientABC, call_with_logged_payload
 from klaude_code.llm.input_common import apply_config_defaults
 from klaude_code.llm.openai_compatible.input import convert_tool_schema
-from klaude_code.llm.openai_compatible.tool_call_accumulator import BasicToolCallAccumulator, ToolCallAccumulatorABC
+from klaude_code.llm.openai_compatible.stream_processor import StreamStateManager
 from klaude_code.llm.openrouter.input import convert_history_to_input, is_claude_model
 from klaude_code.llm.openrouter.reasoning_handler import ReasoningDetail, ReasoningStreamHandler
 from klaude_code.llm.registry import register
@@ -38,7 +38,7 @@ class OpenRouterClient(LLMClientABC):
         messages = convert_history_to_input(param.input, param.system, param.model)
         tools = convert_tool_schema(param.tools)
 
-        metadata_tracker = MetadataTracker(cost_config=self.
+        metadata_tracker = MetadataTracker(cost_config=self.get_llm_config().cost)
 
         extra_body: dict[str, object] = {
             "usage": {"include": True}  # To get the cache tokens at the end of the response
@@ -73,40 +73,18 @@ class OpenRouterClient(LLMClientABC):
             max_tokens=param.max_tokens,
             tools=tools,
             verbosity=param.verbosity,
-            extra_body=extra_body,
+            extra_body=extra_body,
             extra_headers=extra_headers,  # pyright: ignore[reportUnknownArgumentType]
         )
 
-        stage: Literal["waiting", "reasoning", "assistant", "tool", "done"] = "waiting"
-        response_id: str | None = None
-        accumulated_content: list[str] = []
-        accumulated_tool_calls: ToolCallAccumulatorABC = BasicToolCallAccumulator()
-        emitted_tool_start_indices: set[int] = set()
         reasoning_handler = ReasoningStreamHandler(
             param_model=str(param.model),
-            response_id=response_id,
+            response_id=None,
+        )
+        state = StreamStateManager(
+            param_model=str(param.model),
+            reasoning_flusher=reasoning_handler.flush,
         )
-
-        def flush_reasoning_items() -> list[model.ConversationItem]:
-            return reasoning_handler.flush()
-
-        def flush_assistant_items() -> list[model.ConversationItem]:
-            nonlocal accumulated_content
-            if len(accumulated_content) == 0:
-                return []
-            item = model.AssistantMessageItem(
-                content="".join(accumulated_content),
-                response_id=response_id,
-            )
-            accumulated_content = []
-            return [item]
-
-        def flush_tool_call_items() -> list[model.ToolCallItem]:
-            nonlocal accumulated_tool_calls
-            items: list[model.ToolCallItem] = accumulated_tool_calls.get()
-            if items:
-                accumulated_tool_calls.chunks_by_step = []  # pyright: ignore[reportAttributeAccessIssue]
-            return items
 
         try:
             async for event in await stream:
@@ -115,15 +93,14 @@ class OpenRouterClient(LLMClientABC):
                         style="blue",
                         debug_type=DebugType.LLM_STREAM,
                     )
-                if not response_id and event.id:
-                    response_id = event.id
-                    reasoning_handler.set_response_id(event.id)
-                    accumulated_tool_calls.response_id = event.id  # pyright: ignore[reportAttributeAccessIssue]
-                    yield model.StartItem(response_id=response_id)
+                if not state.response_id and event.id:
+                    state.set_response_id(event.id)
+                    reasoning_handler.set_response_id(event.id)
+                    yield model.StartItem(response_id=event.id)
                 if (
                     event.usage is not None and event.usage.completion_tokens is not None  # pyright: ignore[reportUnnecessaryComparison]
                 ):  # gcp gemini will return None usage field
-                    metadata_tracker.set_usage(convert_usage(event.usage, param.context_limit))
+                    metadata_tracker.set_usage(convert_usage(event.usage, param.context_limit, param.max_tokens))
                 if event.model:
                     metadata_tracker.set_model_name(event.model)
                 if provider := getattr(event, "provider", None):
@@ -140,7 +117,7 @@ class OpenRouterClient(LLMClientABC):
                         try:
                             reasoning_detail = ReasoningDetail.model_validate(item)
                             metadata_tracker.record_token()
-                            stage = "reasoning"
+                            state.stage = "reasoning"
                            for conversation_item in reasoning_handler.on_detail(reasoning_detail):
                                 yield conversation_item
                         except Exception as e:
@@ -148,53 +125,46 @@ class OpenRouterClient(LLMClientABC):
 
                 # Assistant
                 if delta.content and (
-                    stage == "assistant" or delta.content.strip()
+                    state.stage == "assistant" or delta.content.strip()
                 ):  # Process all content in assistant stage, filter empty content in reasoning stage
                     metadata_tracker.record_token()
-                    if stage == "reasoning":
-                        for item in flush_reasoning_items():
+                    if state.stage == "reasoning":
+                        for item in state.flush_reasoning():
                             yield item
-                    stage = "assistant"
-                    accumulated_content.append(delta.content)
+                    state.stage = "assistant"
+                    state.accumulated_content.append(delta.content)
                     yield model.AssistantMessageDelta(
                         content=delta.content,
-                        response_id=response_id,
+                        response_id=state.response_id,
                     )
 
                 # Tool
                 if delta.tool_calls and len(delta.tool_calls) > 0:
                     metadata_tracker.record_token()
-                    if stage == "reasoning":
-                        for item in flush_reasoning_items():
+                    if state.stage == "reasoning":
+                        for item in state.flush_reasoning():
                             yield item
-                    elif stage == "assistant":
-                        for item in flush_assistant_items():
+                    elif state.stage == "assistant":
+                        for item in state.flush_assistant():
                             yield item
-                    stage = "tool"
+                    state.stage = "tool"
                     # Emit ToolCallStartItem for new tool calls
                     for tc in delta.tool_calls:
-                        if tc.index not in emitted_tool_start_indices and tc.function and tc.function.name:
-                            emitted_tool_start_indices.add(tc.index)
+                        if tc.index not in state.emitted_tool_start_indices and tc.function and tc.function.name:
+                            state.emitted_tool_start_indices.add(tc.index)
                             yield model.ToolCallStartItem(
-                                response_id=response_id,
+                                response_id=state.response_id,
                                 call_id=tc.id or "",
                                 name=tc.function.name,
                             )
-                    accumulated_tool_calls.add(delta.tool_calls)
+                    state.accumulated_tool_calls.add(delta.tool_calls)
 
         except (openai.OpenAIError, httpx.HTTPError) as e:
             yield model.StreamErrorItem(error=f"{e.__class__.__name__} {str(e)}")
 
         # Finalize
-        for item in flush_reasoning_items():
-            yield item
-
-        for item in flush_assistant_items():
+        for item in state.flush_all():
             yield item
 
-
-        for tool_call_item in flush_tool_call_items():
-            yield tool_call_item
-
-        metadata_tracker.set_response_id(response_id)
+        metadata_tracker.set_response_id(state.response_id)
         yield metadata_tracker.finalize()
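The one OpenRouter-specific concern, reasoning details, stays in ReasoningStreamHandler; it is plugged into the shared manager through the reasoning_flusher hook, so state.flush_reasoning() delegates to the handler instead of building a ReasoningTextItem itself. The wiring, as it appears in the hunk above:

reasoning_handler = ReasoningStreamHandler(
    param_model=str(param.model),
    response_id=None,
)
state = StreamStateManager(
    param_model=str(param.model),
    reasoning_flusher=reasoning_handler.flush,  # flush_reasoning() defers to the handler
)

Note that the client still calls reasoning_handler.set_response_id(event.id) alongside state.set_response_id(event.id), since the handler keeps its own copy of the id.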
klaude_code/llm/openrouter/input.py
CHANGED

@@ -7,9 +7,9 @@
 # pyright: reportGeneralTypeIssues=false
 
 from openai.types import chat
-from openai.types.chat import ChatCompletionContentPartParam
 
-from klaude_code.llm.input_common import AssistantGroup, ToolGroup, UserGroup, merge_reminder_text, parse_message_groups
+from klaude_code.llm.input_common import AssistantGroup, ToolGroup, UserGroup, parse_message_groups
+from klaude_code.llm.openai_compatible.input import tool_group_to_openai_message, user_group_to_openai_message
 from klaude_code.protocol import model
 
 
@@ -25,29 +25,6 @@ def is_gemini_model(model_name: str | None) -> bool:
     return model_name is not None and model_name.startswith("google/gemini")
 
 
-def _user_group_to_message(group: UserGroup) -> chat.ChatCompletionMessageParam:
-    parts: list[ChatCompletionContentPartParam] = []
-    for text in group.text_parts:
-        parts.append({"type": "text", "text": text + "\n"})
-    for image in group.images:
-        parts.append({"type": "image_url", "image_url": {"url": image.image_url.url}})
-    if not parts:
-        parts.append({"type": "text", "text": ""})
-    return {"role": "user", "content": parts}
-
-
-def _tool_group_to_message(group: ToolGroup) -> chat.ChatCompletionMessageParam:
-    merged_text = merge_reminder_text(
-        group.tool_result.output or "<system-reminder>Tool ran without output or errors</system-reminder>",
-        group.reminder_texts,
-    )
-    return {
-        "role": "tool",
-        "content": [{"type": "text", "text": merged_text}],
-        "tool_call_id": group.tool_result.call_id,
-    }
-
-
 def _assistant_group_to_message(group: AssistantGroup, model_name: str | None) -> chat.ChatCompletionMessageParam:
     assistant_message: dict[str, object] = {"role": "assistant"}
 
@@ -150,9 +127,9 @@ def convert_history_to_input(
     for group in parse_message_groups(history):
         match group:
             case UserGroup():
-                messages.append(_user_group_to_message(group))
+                messages.append(user_group_to_openai_message(group))
             case ToolGroup():
-                messages.append(_tool_group_to_message(group))
+                messages.append(tool_group_to_openai_message(group))
             case AssistantGroup():
                 messages.append(_assistant_group_to_message(group, model_name))
klaude_code/llm/registry.py
CHANGED

@@ -1,22 +1,48 @@
-from typing import Callable, TypeVar
+from typing import TYPE_CHECKING, Callable, TypeVar
 
-from klaude_code.llm.client import LLMClientABC
 from klaude_code.protocol import llm_param
 
-_T = TypeVar("_T", bound=type[LLMClientABC])
+if TYPE_CHECKING:
+    from klaude_code.llm.client import LLMClientABC
 
-_REGISTRY: dict[llm_param.LLMClientProtocol, type[LLMClientABC]] = {}
+_T = TypeVar("_T", bound=type["LLMClientABC"])
 
+# Track which protocols have been loaded
+_loaded_protocols: set[llm_param.LLMClientProtocol] = set()
+_REGISTRY: dict[llm_param.LLMClientProtocol, type["LLMClientABC"]] = {}
 
-def register(name: llm_param.LLMClientProtocol) -> Callable[[_T], _T]:
-    def _decorator(cls: _T) -> _T:
+
+def _load_protocol(protocol: llm_param.LLMClientProtocol) -> None:
+    """Load the module for a specific protocol on demand."""
+    if protocol in _loaded_protocols:
+        return
+    _loaded_protocols.add(protocol)
+
+    # Import only the needed module to trigger @register decorator
+    if protocol == llm_param.LLMClientProtocol.ANTHROPIC:
+        from . import anthropic as _  # noqa: F401
+    elif protocol == llm_param.LLMClientProtocol.CODEX:
+        from . import codex as _  # noqa: F401
+    elif protocol == llm_param.LLMClientProtocol.OPENAI:
+        from . import openai_compatible as _  # noqa: F401
+    elif protocol == llm_param.LLMClientProtocol.OPENROUTER:
+        from . import openrouter as _  # noqa: F401
+    elif protocol == llm_param.LLMClientProtocol.RESPONSES:
+        from . import responses as _  # noqa: F401
+
+
+def register(name: llm_param.LLMClientProtocol) -> Callable[[_T], _T]:
+    """Decorator to register an LLM client class for a protocol."""
+
+    def _decorator(cls: _T) -> _T:
         _REGISTRY[name] = cls
         return cls
 
     return _decorator
 
 
-def create_llm_client(config: llm_param.LLMConfigParameter) -> LLMClientABC:
+def create_llm_client(config: llm_param.LLMConfigParameter) -> "LLMClientABC":
+    _load_protocol(config.protocol)
     if config.protocol not in _REGISTRY:
         raise ValueError(f"Unknown LLMClient protocol: {config.protocol}")
     return _REGISTRY[config.protocol].create(config)