klaude-code 1.2.7__py3-none-any.whl → 1.2.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. klaude_code/auth/codex/__init__.py +1 -1
  2. klaude_code/command/__init__.py +2 -0
  3. klaude_code/command/prompt-deslop.md +14 -0
  4. klaude_code/command/release_notes_cmd.py +86 -0
  5. klaude_code/command/status_cmd.py +92 -54
  6. klaude_code/core/agent.py +13 -19
  7. klaude_code/core/manager/sub_agent_manager.py +5 -1
  8. klaude_code/core/prompt.py +38 -28
  9. klaude_code/core/reminders.py +4 -4
  10. klaude_code/core/task.py +60 -45
  11. klaude_code/core/tool/__init__.py +2 -0
  12. klaude_code/core/tool/file/apply_patch_tool.py +1 -1
  13. klaude_code/core/tool/file/edit_tool.py +1 -1
  14. klaude_code/core/tool/file/multi_edit_tool.py +1 -1
  15. klaude_code/core/tool/file/write_tool.py +1 -1
  16. klaude_code/core/tool/memory/memory_tool.py +2 -2
  17. klaude_code/core/tool/sub_agent_tool.py +2 -1
  18. klaude_code/core/tool/todo/todo_write_tool.py +1 -1
  19. klaude_code/core/tool/todo/update_plan_tool.py +1 -1
  20. klaude_code/core/tool/tool_context.py +21 -4
  21. klaude_code/core/tool/tool_runner.py +5 -8
  22. klaude_code/core/tool/web/mermaid_tool.py +1 -4
  23. klaude_code/core/turn.py +90 -62
  24. klaude_code/llm/anthropic/client.py +15 -46
  25. klaude_code/llm/client.py +1 -1
  26. klaude_code/llm/codex/client.py +44 -30
  27. klaude_code/llm/input_common.py +0 -6
  28. klaude_code/llm/openai_compatible/client.py +29 -73
  29. klaude_code/llm/openai_compatible/input.py +6 -4
  30. klaude_code/llm/openai_compatible/stream_processor.py +82 -0
  31. klaude_code/llm/openrouter/client.py +29 -59
  32. klaude_code/llm/openrouter/input.py +4 -27
  33. klaude_code/llm/responses/client.py +49 -79
  34. klaude_code/llm/usage.py +51 -10
  35. klaude_code/protocol/commands.py +1 -0
  36. klaude_code/protocol/events.py +12 -2
  37. klaude_code/protocol/model.py +142 -26
  38. klaude_code/protocol/sub_agent.py +5 -1
  39. klaude_code/session/export.py +51 -27
  40. klaude_code/session/session.py +33 -16
  41. klaude_code/session/templates/export_session.html +4 -1
  42. klaude_code/ui/modes/repl/__init__.py +1 -5
  43. klaude_code/ui/modes/repl/event_handler.py +153 -54
  44. klaude_code/ui/modes/repl/renderer.py +6 -4
  45. klaude_code/ui/renderers/developer.py +35 -25
  46. klaude_code/ui/renderers/metadata.py +68 -30
  47. klaude_code/ui/renderers/tools.py +53 -87
  48. klaude_code/ui/rich/markdown.py +5 -5
  49. {klaude_code-1.2.7.dist-info → klaude_code-1.2.9.dist-info}/METADATA +1 -1
  50. {klaude_code-1.2.7.dist-info → klaude_code-1.2.9.dist-info}/RECORD +52 -49
  51. {klaude_code-1.2.7.dist-info → klaude_code-1.2.9.dist-info}/WHEEL +0 -0
  52. {klaude_code-1.2.7.dist-info → klaude_code-1.2.9.dist-info}/entry_points.txt +0 -0
klaude_code/llm/anthropic/client.py CHANGED
@@ -1,11 +1,10 @@
  import json
- import time
  from collections.abc import AsyncGenerator
  from typing import override

  import anthropic
  import httpx
- from anthropic import RateLimitError
+ from anthropic import APIError
  from anthropic.types.beta.beta_input_json_delta import BetaInputJSONDelta
  from anthropic.types.beta.beta_raw_content_block_delta_event import BetaRawContentBlockDeltaEvent
  from anthropic.types.beta.beta_raw_content_block_start_event import BetaRawContentBlockStartEvent
@@ -22,7 +21,7 @@ from klaude_code.llm.anthropic.input import convert_history_to_input, convert_sy
  from klaude_code.llm.client import LLMClientABC, call_with_logged_payload
  from klaude_code.llm.input_common import apply_config_defaults
  from klaude_code.llm.registry import register
- from klaude_code.llm.usage import calculate_cost
+ from klaude_code.llm.usage import MetadataTracker, convert_anthropic_usage
  from klaude_code.protocol import llm_param, model
  from klaude_code.trace import DebugType, log_debug

@@ -47,9 +46,7 @@ class AnthropicClient(LLMClientABC):
  async def call(self, param: llm_param.LLMCallParameter) -> AsyncGenerator[model.ConversationItem, None]:
  param = apply_config_defaults(param, self.get_llm_config())

- request_start_time = time.time()
- first_token_time: float | None = None
- last_token_time: float | None = None
+ metadata_tracker = MetadataTracker(cost_config=self._config.cost)

  messages = convert_history_to_input(param.input, param.model)
  tools = convert_tool_schema(param.tools)
@@ -77,7 +74,7 @@
  else anthropic.types.ThinkingConfigDisabledParam(
  type="disabled",
  ),
- extra_headers={"extra": json.dumps({"session_id": param.session_id})},
+ extra_headers={"extra": json.dumps({"session_id": param.session_id}, sort_keys=True)},
  )

  accumulated_thinking: list[str] = []
@@ -112,32 +109,24 @@
  case BetaRawContentBlockDeltaEvent() as event:
  match event.delta:
  case BetaThinkingDelta() as delta:
- if first_token_time is None:
- first_token_time = time.time()
- last_token_time = time.time()
+ metadata_tracker.record_token()
  accumulated_thinking.append(delta.thinking)
  case BetaSignatureDelta() as delta:
- if first_token_time is None:
- first_token_time = time.time()
- last_token_time = time.time()
+ metadata_tracker.record_token()
  yield model.ReasoningEncryptedItem(
  encrypted_content=delta.signature,
  response_id=response_id,
  model=str(param.model),
  )
  case BetaTextDelta() as delta:
- if first_token_time is None:
- first_token_time = time.time()
- last_token_time = time.time()
+ metadata_tracker.record_token()
  accumulated_content.append(delta.text)
  yield model.AssistantMessageDelta(
  content=delta.text,
  response_id=response_id,
  )
  case BetaInputJSONDelta() as delta:
- if first_token_time is None:
- first_token_time = time.time()
- last_token_time = time.time()
+ metadata_tracker.record_token()
  if current_tool_inputs is not None:
  current_tool_inputs.append(delta.partial_json)
  case _:
@@ -184,38 +173,18 @@
  input_tokens += (event.usage.input_tokens or 0) + (event.usage.cache_creation_input_tokens or 0)
  output_tokens += event.usage.output_tokens or 0
  cached_tokens += event.usage.cache_read_input_tokens or 0
- total_tokens = input_tokens + cached_tokens + output_tokens
- context_usage_percent = (
- (total_tokens / param.context_limit) * 100 if param.context_limit else None
- )
-
- throughput_tps: float | None = None
- first_token_latency_ms: float | None = None
-
- if first_token_time is not None:
- first_token_latency_ms = (first_token_time - request_start_time) * 1000

- if first_token_time is not None and last_token_time is not None and output_tokens > 0:
- time_duration = last_token_time - first_token_time
- if time_duration >= 0.15:
- throughput_tps = output_tokens / time_duration
-
- usage = model.Usage(
+ usage = convert_anthropic_usage(
  input_tokens=input_tokens,
  output_tokens=output_tokens,
  cached_tokens=cached_tokens,
- total_tokens=total_tokens,
- context_usage_percent=context_usage_percent,
- throughput_tps=throughput_tps,
- first_token_latency_ms=first_token_latency_ms,
- )
- calculate_cost(usage, self._config.cost)
- yield model.ResponseMetadataItem(
- usage=usage,
- response_id=response_id,
- model_name=str(param.model),
+ context_limit=param.context_limit,
  )
+ metadata_tracker.set_usage(usage)
+ metadata_tracker.set_model_name(str(param.model))
+ metadata_tracker.set_response_id(response_id)
+ yield metadata_tracker.finalize()
  case _:
  pass
- except RateLimitError as e:
+ except (APIError, httpx.HTTPError) as e:
  yield model.StreamErrorItem(error=f"{e.__class__.__name__} {str(e)}")
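The same substitution recurs in every client in this release: the deleted request_start_time / first_token_time / last_token_time bookkeeping moves into the shared MetadataTracker imported from klaude_code.llm.usage (that module is +51 -10 here, but its body is not shown in this diff). A hedged sketch of what the call sites imply, reconstructed from the deleted inline math; the class below is illustrative, not the package's actual implementation:

    import time


    class MetadataTrackerSketch:
        """Hypothetical stand-in for klaude_code.llm.usage.MetadataTracker."""

        def __init__(self, cost_config=None):
            self._cost_config = cost_config  # presumably consumed by finalize() for cost math
            self._request_start = time.time()
            self._first_token: float | None = None
            self._last_token: float | None = None

        def record_token(self) -> None:
            # One call per streamed delta; replaces the repeated
            # "if first_token_time is None: ..." blocks deleted above.
            now = time.time()
            if self._first_token is None:
                self._first_token = now
            self._last_token = now

        def derived_metrics(self, output_tokens: int) -> tuple[float | None, float | None]:
            # The same math the deleted inline code performed.
            latency_ms: float | None = None
            throughput_tps: float | None = None
            if self._first_token is not None:
                latency_ms = (self._first_token - self._request_start) * 1000
                if self._last_token is not None and output_tokens > 0:
                    duration = self._last_token - self._first_token
                    if duration >= 0.15:  # windows shorter than 150 ms give noisy throughput
                        throughput_tps = output_tokens / duration
            return latency_ms, throughput_tps

The real tracker also exposes set_usage / set_model_name / set_response_id and a finalize() that, judging by the deleted block, folds these metrics and the configured costs into the model.ResponseMetadataItem that used to be assembled by hand.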
klaude_code/llm/client.py CHANGED
@@ -42,7 +42,7 @@ def call_with_logged_payload(func: Callable[P, R], *args: P.args, **kwargs: P.kw

  payload = {k: v for k, v in kwargs.items() if v is not None}
  log_debug(
- json.dumps(payload, ensure_ascii=False, default=str),
+ json.dumps(payload, ensure_ascii=False, default=str, sort_keys=True),
  style="yellow",
  debug_type=DebugType.LLM_PAYLOAD,
  )
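This sort_keys=True, like the ones added to the extra_headers values above, makes the serialized JSON byte-stable regardless of dict insertion order, so repeated runs produce diff-able debug logs and identical header values. For example:

    import json

    a = {"session_id": "s1", "extra": 1}
    b = {"extra": 1, "session_id": "s1"}

    # Same payload, different insertion order: the dumps differ without
    # sort_keys, but are byte-identical with it.
    assert json.dumps(a) != json.dumps(b)
    assert json.dumps(a, sort_keys=True) == json.dumps(b, sort_keys=True)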
klaude_code/llm/codex/client.py CHANGED
@@ -1,10 +1,10 @@
  """Codex LLM client using ChatGPT subscription via OAuth."""

- import time
  from collections.abc import AsyncGenerator
  from typing import override

  import httpx
+ import openai
  from openai import AsyncOpenAI

  from klaude_code.auth.codex.exceptions import CodexNotLoggedInError
@@ -15,14 +15,16 @@ from klaude_code.llm.input_common import apply_config_defaults
  from klaude_code.llm.registry import register
  from klaude_code.llm.responses.client import parse_responses_stream
  from klaude_code.llm.responses.input import convert_history_to_input, convert_tool_schema
+ from klaude_code.llm.usage import MetadataTracker
  from klaude_code.protocol import llm_param, model

  # Codex API configuration
  CODEX_BASE_URL = "https://chatgpt.com/backend-api/codex"
  CODEX_HEADERS = {
- "OpenAI-Beta": "responses=experimental",
  "originator": "codex_cli_rs",
- "User-Agent": "GitHubCopilotChat/0.32.4",
+ # Mocked Codex-style user agent string
+ "User-Agent": "codex_cli_rs/0.0.0-klaude",
+ "OpenAI-Beta": "responses=experimental",
  }


@@ -82,35 +84,47 @@
  # Codex API requires store=False
  param.store = False

- request_start_time = time.time()
+ metadata_tracker = MetadataTracker(cost_config=self._config.cost)

  inputs = convert_history_to_input(param.input, param.model)
  tools = convert_tool_schema(param.tools)

- stream = await call_with_logged_payload(
- self.client.responses.create,
- model=str(param.model),
- tool_choice="auto",
- parallel_tool_calls=True,
- include=[
- "reasoning.encrypted_content",
- ],
- store=False, # Always False for Codex
- stream=True,
- input=inputs,
- instructions=param.system,
- tools=tools,
- text={
- "verbosity": param.verbosity,
- },
- prompt_cache_key=param.session_id or "",
- reasoning={
- "effort": param.thinking.reasoning_effort,
- "summary": param.thinking.reasoning_summary,
- }
- if param.thinking and param.thinking.reasoning_effort
- else None,
- )
-
- async for item in parse_responses_stream(stream, param, self._config.cost, request_start_time):
+ session_id = param.session_id or ""
+ # Must send conversation_id/session_id headers to improve ChatGPT backend prompt cache hit rate.
+ extra_headers: dict[str, str] = {}
+ if session_id:
+ extra_headers["conversation_id"] = session_id
+ extra_headers["session_id"] = session_id
+
+ try:
+ stream = await call_with_logged_payload(
+ self.client.responses.create,
+ model=str(param.model),
+ tool_choice="auto",
+ parallel_tool_calls=True,
+ include=[
+ "reasoning.encrypted_content",
+ ],
+ store=False, # Always False for Codex
+ stream=True,
+ input=inputs,
+ instructions=param.system,
+ tools=tools,
+ text={
+ "verbosity": param.verbosity,
+ },
+ prompt_cache_key=session_id,
+ reasoning={
+ "effort": param.thinking.reasoning_effort,
+ "summary": param.thinking.reasoning_summary,
+ }
+ if param.thinking and param.thinking.reasoning_effort
+ else None,
+ extra_headers=extra_headers,
+ )
+ except (openai.OpenAIError, httpx.HTTPError) as e:
+ yield model.StreamErrorItem(error=f"{e.__class__.__name__} {str(e)}")
+ return
+
+ async for item in parse_responses_stream(stream, param, metadata_tracker):
  yield item
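Two behavioral changes ride along with the MetadataTracker swap here: request construction is wrapped in try/except so connection-phase failures surface as StreamErrorItem instead of raising, and the session id is threaded into per-request headers for prompt-cache affinity. A minimal sketch of the header plumbing, with placeholder token and model name (the openai SDK forwards extra_headers verbatim onto the underlying HTTP request):

    from openai import AsyncOpenAI

    client = AsyncOpenAI(
        base_url="https://chatgpt.com/backend-api/codex",
        api_key="<oauth-access-token>",  # placeholder; klaude-code obtains this via OAuth
    )


    async def start_stream(session_id: str):
        # Same header shape the diff builds; values here are illustrative.
        headers = {"conversation_id": session_id, "session_id": session_id} if session_id else {}
        return await client.responses.create(
            model="gpt-5-placeholder",  # hypothetical model name
            input="hello",
            stream=True,
            store=False,                # Codex backend requires store=False
            prompt_cache_key=session_id,
            extra_headers=headers,      # sent with this request only
        )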
klaude_code/llm/input_common.py CHANGED
@@ -49,10 +49,6 @@ class AssistantGroup:

  text_content: str | None = None
  tool_calls: list[model.ToolCallItem] = field(default_factory=lambda: [])
- reasoning_text: list[model.ReasoningTextItem] = field(default_factory=lambda: [])
- reasoning_encrypted: list[model.ReasoningEncryptedItem] = field(default_factory=lambda: [])
- # Preserve original ordering of reasoning items for providers that
- # need to emit them as an ordered stream (e.g. OpenRouter).
  reasoning_items: list[model.ReasoningTextItem | model.ReasoningEncryptedItem] = field(default_factory=lambda: [])


@@ -184,10 +180,8 @@ def parse_message_groups(history: list[model.ConversationItem]) -> list[MessageG
  case model.ToolCallItem():
  group.tool_calls.append(item)
  case model.ReasoningTextItem():
- group.reasoning_text.append(item)
  group.reasoning_items.append(item)
  case model.ReasoningEncryptedItem():
- group.reasoning_encrypted.append(item)
  group.reasoning_items.append(item)
  case _:
  pass
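The deleted comment states the motive for keeping only reasoning_items: text and encrypted reasoning must be replayed in arrival order for providers like OpenRouter, and two type-segregated lists cannot reconstruct the interleaving once it is lost. A toy illustration with hypothetical items:

    arrival = ["text-1", "enc-1", "text-2"]  # interleaved reasoning stream

    # Old shape: split by type, interleaving lost on replay.
    texts = [i for i in arrival if i.startswith("text")]
    encrypted = [i for i in arrival if i.startswith("enc")]
    assert texts + encrypted == ["text-1", "text-2", "enc-1"]  # wrong order

    # New shape: one ordered list preserves the stream as received.
    reasoning_items = list(arrival)
    assert reasoning_items == ["text-1", "enc-1", "text-2"]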
klaude_code/llm/openai_compatible/client.py CHANGED
@@ -1,15 +1,14 @@
  import json
  from collections.abc import AsyncGenerator
- from typing import Literal, override
+ from typing import override

  import httpx
  import openai
- from openai import APIError, RateLimitError

  from klaude_code.llm.client import LLMClientABC, call_with_logged_payload
  from klaude_code.llm.input_common import apply_config_defaults
  from klaude_code.llm.openai_compatible.input import convert_history_to_input, convert_tool_schema
- from klaude_code.llm.openai_compatible.tool_call_accumulator import BasicToolCallAccumulator, ToolCallAccumulatorABC
+ from klaude_code.llm.openai_compatible.stream_processor import StreamStateManager
  from klaude_code.llm.registry import register
  from klaude_code.llm.usage import MetadataTracker, convert_usage
  from klaude_code.protocol import llm_param, model
@@ -51,7 +50,7 @@ class OpenAICompatibleClient(LLMClientABC):
  metadata_tracker = MetadataTracker(cost_config=self._config.cost)

  extra_body = {}
- extra_headers = {"extra": json.dumps({"session_id": param.session_id})}
+ extra_headers = {"extra": json.dumps({"session_id": param.session_id}, sort_keys=True)}

  if param.thinking:
  extra_body["thinking"] = {
@@ -74,42 +73,7 @@
  extra_headers=extra_headers,
  )

- stage: Literal["waiting", "reasoning", "assistant", "tool", "done"] = "waiting"
- accumulated_reasoning: list[str] = []
- accumulated_content: list[str] = []
- accumulated_tool_calls: ToolCallAccumulatorABC = BasicToolCallAccumulator()
- emitted_tool_start_indices: set[int] = set()
- response_id: str | None = None
-
- def flush_reasoning_items() -> list[model.ConversationItem]:
- nonlocal accumulated_reasoning
- if not accumulated_reasoning:
- return []
- item = model.ReasoningTextItem(
- content="".join(accumulated_reasoning),
- response_id=response_id,
- model=str(param.model),
- )
- accumulated_reasoning = []
- return [item]
-
- def flush_assistant_items() -> list[model.ConversationItem]:
- nonlocal accumulated_content
- if len(accumulated_content) == 0:
- return []
- item = model.AssistantMessageItem(
- content="".join(accumulated_content),
- response_id=response_id,
- )
- accumulated_content = []
- return [item]
-
- def flush_tool_call_items() -> list[model.ToolCallItem]:
- nonlocal accumulated_tool_calls
- items: list[model.ToolCallItem] = accumulated_tool_calls.get()
- if items:
- accumulated_tool_calls.chunks_by_step = [] # pyright: ignore[reportAttributeAccessIssue]
- return items
+ state = StreamStateManager(param_model=str(param.model))

  try:
  async for event in await stream:
@@ -118,10 +82,9 @@
  style="blue",
  debug_type=DebugType.LLM_STREAM,
  )
- if not response_id and event.id:
- response_id = event.id
- accumulated_tool_calls.response_id = response_id
- yield model.StartItem(response_id=response_id)
+ if not state.response_id and event.id:
+ state.set_response_id(event.id)
+ yield model.StartItem(response_id=event.id)
  if (
  event.usage is not None and event.usage.completion_tokens is not None # pyright: ignore[reportUnnecessaryComparison] gcp gemini will return None usage field
  ):
@@ -152,60 +115,53 @@
  reasoning_content = getattr(delta, "reasoning_content")
  if reasoning_content:
  metadata_tracker.record_token()
- stage = "reasoning"
- accumulated_reasoning.append(reasoning_content)
+ state.stage = "reasoning"
+ state.accumulated_reasoning.append(reasoning_content)

  # Assistant
  if delta.content and (
- stage == "assistant" or delta.content.strip()
+ state.stage == "assistant" or delta.content.strip()
  ): # Process all content in assistant stage, filter empty content in reasoning stage
  metadata_tracker.record_token()
- if stage == "reasoning":
- for item in flush_reasoning_items():
+ if state.stage == "reasoning":
+ for item in state.flush_reasoning():
  yield item
- elif stage == "tool":
- for item in flush_tool_call_items():
+ elif state.stage == "tool":
+ for item in state.flush_tool_calls():
  yield item
- stage = "assistant"
- accumulated_content.append(delta.content)
+ state.stage = "assistant"
+ state.accumulated_content.append(delta.content)
  yield model.AssistantMessageDelta(
  content=delta.content,
- response_id=response_id,
+ response_id=state.response_id,
  )

  # Tool
  if delta.tool_calls and len(delta.tool_calls) > 0:
  metadata_tracker.record_token()
- if stage == "reasoning":
- for item in flush_reasoning_items():
+ if state.stage == "reasoning":
+ for item in state.flush_reasoning():
  yield item
- elif stage == "assistant":
- for item in flush_assistant_items():
+ elif state.stage == "assistant":
+ for item in state.flush_assistant():
  yield item
- stage = "tool"
+ state.stage = "tool"
  # Emit ToolCallStartItem for new tool calls
  for tc in delta.tool_calls:
- if tc.index not in emitted_tool_start_indices and tc.function and tc.function.name:
- emitted_tool_start_indices.add(tc.index)
+ if tc.index not in state.emitted_tool_start_indices and tc.function and tc.function.name:
+ state.emitted_tool_start_indices.add(tc.index)
  yield model.ToolCallStartItem(
- response_id=response_id,
+ response_id=state.response_id,
  call_id=tc.id or "",
  name=tc.function.name,
  )
- accumulated_tool_calls.add(delta.tool_calls)
- except (RateLimitError, APIError) as e:
+ state.accumulated_tool_calls.add(delta.tool_calls)
+ except (openai.OpenAIError, httpx.HTTPError) as e:
  yield model.StreamErrorItem(error=f"{e.__class__.__name__} {str(e)}")

  # Finalize
- for item in flush_reasoning_items():
+ for item in state.flush_all():
  yield item

- for item in flush_assistant_items():
- yield item
-
- if stage == "tool":
- for tool_call_item in flush_tool_call_items():
- yield tool_call_item
-
- metadata_tracker.set_response_id(response_id)
+ metadata_tracker.set_response_id(state.response_id)
  yield metadata_tracker.finalize()
klaude_code/llm/openai_compatible/input.py CHANGED
@@ -10,7 +10,8 @@ from klaude_code.llm.input_common import AssistantGroup, ToolGroup, UserGroup, m
  from klaude_code.protocol import llm_param, model


- def _user_group_to_message(group: UserGroup) -> chat.ChatCompletionMessageParam:
+ def user_group_to_openai_message(group: UserGroup) -> chat.ChatCompletionMessageParam:
+ """Convert a UserGroup to an OpenAI-compatible chat message."""
  parts: list[ChatCompletionContentPartParam] = []
  for text in group.text_parts:
  parts.append({"type": "text", "text": text + "\n"})
@@ -21,7 +22,8 @@ def _user_group_to_message(group: UserGroup) -> chat.ChatCompletionMessageParam:
  return {"role": "user", "content": parts}


- def _tool_group_to_message(group: ToolGroup) -> chat.ChatCompletionMessageParam:
+ def tool_group_to_openai_message(group: ToolGroup) -> chat.ChatCompletionMessageParam:
+ """Convert a ToolGroup to an OpenAI-compatible chat message."""
  merged_text = merge_reminder_text(
  group.tool_result.output or "<system-reminder>Tool ran without output or errors</system-reminder>",
  group.reminder_texts,
@@ -82,9 +84,9 @@ def convert_history_to_input(
  for group in parse_message_groups(history):
  match group:
  case UserGroup():
- messages.append(_user_group_to_message(group))
+ messages.append(user_group_to_openai_message(group))
  case ToolGroup():
- messages.append(_tool_group_to_message(group))
+ messages.append(tool_group_to_openai_message(group))
  case AssistantGroup():
  messages.append(_assistant_group_to_message(group))

klaude_code/llm/openai_compatible/stream_processor.py ADDED
@@ -0,0 +1,82 @@
+ """Shared stream processing utilities for OpenAI-compatible clients.
+
+ This module provides a reusable stream state manager that handles the common
+ logic for accumulating and flushing reasoning, assistant content, and tool calls
+ across different LLM providers (OpenAI-compatible, OpenRouter).
+ """
+
+ from typing import Callable, Literal
+
+ from klaude_code.llm.openai_compatible.tool_call_accumulator import BasicToolCallAccumulator, ToolCallAccumulatorABC
+ from klaude_code.protocol import model
+
+ StreamStage = Literal["waiting", "reasoning", "assistant", "tool"]
+
+
+ class StreamStateManager:
+     """Manages streaming state and provides flush operations for accumulated content.
+
+     This class encapsulates the common state management logic used by both
+     OpenAI-compatible and OpenRouter clients, reducing code duplication.
+     """
+
+     def __init__(
+         self,
+         param_model: str,
+         response_id: str | None = None,
+         reasoning_flusher: Callable[[], list[model.ConversationItem]] | None = None,
+     ):
+         self.param_model = param_model
+         self.response_id = response_id
+         self.stage: StreamStage = "waiting"
+         self.accumulated_reasoning: list[str] = []
+         self.accumulated_content: list[str] = []
+         self.accumulated_tool_calls: ToolCallAccumulatorABC = BasicToolCallAccumulator()
+         self.emitted_tool_start_indices: set[int] = set()
+         self._reasoning_flusher = reasoning_flusher
+
+     def set_response_id(self, response_id: str) -> None:
+         """Set the response ID once received from the stream."""
+         self.response_id = response_id
+         self.accumulated_tool_calls.response_id = response_id  # pyright: ignore[reportAttributeAccessIssue]
+
+     def flush_reasoning(self) -> list[model.ConversationItem]:
+         """Flush accumulated reasoning content and return items."""
+         if self._reasoning_flusher is not None:
+             return self._reasoning_flusher()
+         if not self.accumulated_reasoning:
+             return []
+         item = model.ReasoningTextItem(
+             content="".join(self.accumulated_reasoning),
+             response_id=self.response_id,
+             model=self.param_model,
+         )
+         self.accumulated_reasoning = []
+         return [item]
+
+     def flush_assistant(self) -> list[model.ConversationItem]:
+         """Flush accumulated assistant content and return items."""
+         if not self.accumulated_content:
+             return []
+         item = model.AssistantMessageItem(
+             content="".join(self.accumulated_content),
+             response_id=self.response_id,
+         )
+         self.accumulated_content = []
+         return [item]
+
+     def flush_tool_calls(self) -> list[model.ToolCallItem]:
+         """Flush accumulated tool calls and return items."""
+         items: list[model.ToolCallItem] = self.accumulated_tool_calls.get()
+         if items:
+             self.accumulated_tool_calls.chunks_by_step = []  # pyright: ignore[reportAttributeAccessIssue]
+         return items
+
+     def flush_all(self) -> list[model.ConversationItem]:
+         """Flush all accumulated content in order: reasoning, assistant, tool calls."""
+         items: list[model.ConversationItem] = []
+         items.extend(self.flush_reasoning())
+         items.extend(self.flush_assistant())
+         if self.stage == "tool":
+             items.extend(self.flush_tool_calls())
+         return items
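For context, a hedged driver sketch showing how the client hunks above use this class; FakeDelta is a stand-in for an OpenAI chat-completion chunk delta, and the snippet assumes the klaude_code package is importable:

    from dataclasses import dataclass

    from klaude_code.llm.openai_compatible.stream_processor import StreamStateManager


    @dataclass
    class FakeDelta:
        reasoning_content: str | None = None
        content: str | None = None


    def drive(state: StreamStateManager, deltas: list[FakeDelta]):
        for delta in deltas:
            if delta.reasoning_content:
                state.stage = "reasoning"
                state.accumulated_reasoning.append(delta.reasoning_content)
            if delta.content:
                if state.stage == "reasoning":
                    yield from state.flush_reasoning()  # stage transition closes the reasoning block
                state.stage = "assistant"
                state.accumulated_content.append(delta.content)
        yield from state.flush_all()  # end of stream: reasoning, then assistant, then tool calls


    state = StreamStateManager(param_model="example-model")
    state.set_response_id("resp-123")
    items = list(drive(state, [FakeDelta(reasoning_content="thinking"), FakeDelta(content="Hi")]))
    # -> [ReasoningTextItem(content="thinking"), AssistantMessageItem(content="Hi")]

The optional reasoning_flusher hook is presumably what lets the OpenRouter client keep its provider-specific reasoning handling while sharing the rest of this state machine.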