klaude-code 1.2.8__py3-none-any.whl → 1.2.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (82) hide show
  1. klaude_code/auth/codex/__init__.py +1 -1
  2. klaude_code/cli/main.py +12 -1
  3. klaude_code/cli/runtime.py +7 -11
  4. klaude_code/command/__init__.py +68 -21
  5. klaude_code/command/clear_cmd.py +6 -2
  6. klaude_code/command/command_abc.py +5 -2
  7. klaude_code/command/diff_cmd.py +5 -2
  8. klaude_code/command/export_cmd.py +7 -4
  9. klaude_code/command/help_cmd.py +6 -2
  10. klaude_code/command/model_cmd.py +5 -2
  11. klaude_code/command/prompt-deslop.md +14 -0
  12. klaude_code/command/prompt_command.py +8 -3
  13. klaude_code/command/refresh_cmd.py +6 -2
  14. klaude_code/command/registry.py +17 -5
  15. klaude_code/command/release_notes_cmd.py +89 -0
  16. klaude_code/command/status_cmd.py +98 -56
  17. klaude_code/command/terminal_setup_cmd.py +7 -4
  18. klaude_code/const/__init__.py +1 -1
  19. klaude_code/core/agent.py +66 -26
  20. klaude_code/core/executor.py +2 -2
  21. klaude_code/core/manager/agent_manager.py +6 -7
  22. klaude_code/core/manager/llm_clients.py +47 -22
  23. klaude_code/core/manager/llm_clients_builder.py +19 -7
  24. klaude_code/core/manager/sub_agent_manager.py +6 -2
  25. klaude_code/core/prompt.py +38 -28
  26. klaude_code/core/reminders.py +4 -7
  27. klaude_code/core/task.py +59 -40
  28. klaude_code/core/tool/__init__.py +2 -0
  29. klaude_code/core/tool/file/_utils.py +30 -0
  30. klaude_code/core/tool/file/apply_patch_tool.py +1 -1
  31. klaude_code/core/tool/file/edit_tool.py +6 -31
  32. klaude_code/core/tool/file/multi_edit_tool.py +7 -32
  33. klaude_code/core/tool/file/read_tool.py +6 -18
  34. klaude_code/core/tool/file/write_tool.py +6 -31
  35. klaude_code/core/tool/memory/__init__.py +5 -0
  36. klaude_code/core/tool/memory/memory_tool.py +2 -2
  37. klaude_code/core/tool/memory/skill_loader.py +2 -1
  38. klaude_code/core/tool/memory/skill_tool.py +13 -0
  39. klaude_code/core/tool/sub_agent_tool.py +2 -1
  40. klaude_code/core/tool/todo/todo_write_tool.py +1 -1
  41. klaude_code/core/tool/todo/update_plan_tool.py +1 -1
  42. klaude_code/core/tool/tool_context.py +21 -4
  43. klaude_code/core/tool/tool_runner.py +5 -8
  44. klaude_code/core/tool/web/mermaid_tool.py +1 -4
  45. klaude_code/core/turn.py +40 -37
  46. klaude_code/llm/__init__.py +2 -12
  47. klaude_code/llm/anthropic/client.py +14 -44
  48. klaude_code/llm/client.py +2 -2
  49. klaude_code/llm/codex/client.py +4 -3
  50. klaude_code/llm/input_common.py +0 -6
  51. klaude_code/llm/openai_compatible/client.py +31 -74
  52. klaude_code/llm/openai_compatible/input.py +6 -4
  53. klaude_code/llm/openai_compatible/stream_processor.py +82 -0
  54. klaude_code/llm/openrouter/client.py +32 -62
  55. klaude_code/llm/openrouter/input.py +4 -27
  56. klaude_code/llm/registry.py +33 -7
  57. klaude_code/llm/responses/client.py +16 -48
  58. klaude_code/llm/responses/input.py +1 -1
  59. klaude_code/llm/usage.py +61 -11
  60. klaude_code/protocol/commands.py +1 -0
  61. klaude_code/protocol/events.py +11 -2
  62. klaude_code/protocol/model.py +147 -24
  63. klaude_code/protocol/op.py +1 -0
  64. klaude_code/protocol/sub_agent.py +5 -1
  65. klaude_code/session/export.py +56 -32
  66. klaude_code/session/session.py +43 -21
  67. klaude_code/session/templates/export_session.html +4 -1
  68. klaude_code/ui/core/input.py +1 -1
  69. klaude_code/ui/modes/repl/__init__.py +1 -5
  70. klaude_code/ui/modes/repl/clipboard.py +5 -5
  71. klaude_code/ui/modes/repl/event_handler.py +153 -54
  72. klaude_code/ui/modes/repl/renderer.py +4 -4
  73. klaude_code/ui/renderers/developer.py +35 -25
  74. klaude_code/ui/renderers/metadata.py +68 -30
  75. klaude_code/ui/renderers/tools.py +53 -87
  76. klaude_code/ui/rich/markdown.py +5 -5
  77. klaude_code/ui/terminal/control.py +2 -2
  78. klaude_code/version.py +3 -3
  79. {klaude_code-1.2.8.dist-info → klaude_code-1.2.10.dist-info}/METADATA +1 -1
  80. {klaude_code-1.2.8.dist-info → klaude_code-1.2.10.dist-info}/RECORD +82 -78
  81. {klaude_code-1.2.8.dist-info → klaude_code-1.2.10.dist-info}/WHEEL +0 -0
  82. {klaude_code-1.2.8.dist-info → klaude_code-1.2.10.dist-info}/entry_points.txt +0 -0
@@ -1,51 +1,73 @@
1
+ from typing import TYPE_CHECKING
2
+
1
3
  from klaude_code.command.command_abc import CommandABC, CommandResult
2
4
  from klaude_code.command.registry import register_command
3
- from klaude_code.core.agent import Agent
4
5
  from klaude_code.protocol import commands, events, model
5
6
  from klaude_code.session.session import Session
6
7
 
8
+ if TYPE_CHECKING:
9
+ from klaude_code.core.agent import Agent
10
+
11
+
12
+ class AggregatedUsage(model.BaseModel):
13
+ """Aggregated usage statistics including per-model breakdown."""
14
+
15
+ total: model.Usage
16
+ by_model: list[model.TaskMetadata]
17
+ task_count: int
18
+
7
19
 
8
- def accumulate_session_usage(session: Session) -> tuple[model.Usage, int]:
9
- """Accumulate usage statistics from all ResponseMetadataItems in session history.
20
+ def accumulate_session_usage(session: Session) -> AggregatedUsage:
21
+ """Accumulate usage statistics from all TaskMetadataItems in session history.
10
22
 
11
- Returns:
12
- A tuple of (accumulated_usage, task_count)
23
+ Includes both main agent and sub-agent task metadata, grouped by model+provider.
13
24
  """
14
- total = model.Usage()
25
+ all_metadata: list[model.TaskMetadata] = []
15
26
  task_count = 0
16
- first_currency_set = False
17
27
 
18
28
  for item in session.conversation_history:
19
- if isinstance(item, model.ResponseMetadataItem) and item.usage:
29
+ if isinstance(item, model.TaskMetadataItem):
20
30
  task_count += 1
21
- usage = item.usage
22
-
23
- # Set currency from first usage item
24
- if not first_currency_set and usage.currency:
25
- total.currency = usage.currency
26
- first_currency_set = True
27
-
28
- total.input_tokens += usage.input_tokens
29
- total.cached_tokens += usage.cached_tokens
30
- total.reasoning_tokens += usage.reasoning_tokens
31
- total.output_tokens += usage.output_tokens
32
- total.total_tokens += usage.total_tokens
33
-
34
- # Accumulate costs
35
- if usage.input_cost is not None:
36
- total.input_cost = (total.input_cost or 0.0) + usage.input_cost
37
- if usage.output_cost is not None:
38
- total.output_cost = (total.output_cost or 0.0) + usage.output_cost
39
- if usage.cache_read_cost is not None:
40
- total.cache_read_cost = (total.cache_read_cost or 0.0) + usage.cache_read_cost
41
- if usage.total_cost is not None:
42
- total.total_cost = (total.total_cost or 0.0) + usage.total_cost
43
-
44
- # Keep the latest context_usage_percent
45
- if usage.context_usage_percent is not None:
46
- total.context_usage_percent = usage.context_usage_percent
47
-
48
- return total, task_count
31
+ all_metadata.append(item.main)
32
+ all_metadata.extend(item.sub_agent_task_metadata)
33
+
34
+ # Aggregate by model+provider
35
+ by_model = model.TaskMetadata.aggregate_by_model(all_metadata)
36
+
37
+ # Calculate total from aggregated results
38
+ total = model.Usage()
39
+ for meta in by_model:
40
+ if not meta.usage:
41
+ continue
42
+ usage = meta.usage
43
+
44
+ # Set currency from first
45
+ if total.currency == "USD" and usage.currency:
46
+ total.currency = usage.currency
47
+
48
+ # Accumulate primary token fields (total_tokens is computed)
49
+ total.input_tokens += usage.input_tokens
50
+ total.cached_tokens += usage.cached_tokens
51
+ total.reasoning_tokens += usage.reasoning_tokens
52
+ total.output_tokens += usage.output_tokens
53
+
54
+ # Accumulate cost components (total_cost is computed)
55
+ if usage.input_cost is not None:
56
+ total.input_cost = (total.input_cost or 0.0) + usage.input_cost
57
+ if usage.output_cost is not None:
58
+ total.output_cost = (total.output_cost or 0.0) + usage.output_cost
59
+ if usage.cache_read_cost is not None:
60
+ total.cache_read_cost = (total.cache_read_cost or 0.0) + usage.cache_read_cost
61
+
62
+ # Track peak context window size (max across all tasks)
63
+ if usage.context_token is not None:
64
+ total.context_token = usage.context_token
65
+
66
+ # Keep the latest context_limit for computed context_usage_percent
67
+ if usage.context_limit is not None:
68
+ total.context_limit = usage.context_limit
69
+
70
+ return AggregatedUsage(total=total, by_model=by_model, task_count=task_count)
49
71
 
50
72
 
51
73
  def _format_tokens(tokens: int) -> str:
@@ -67,20 +89,42 @@ def _format_cost(cost: float | None, currency: str = "USD") -> str:
67
89
  return f"{symbol}{cost:.2f}"
68
90
 
69
91
 
70
- def format_status_content(usage: model.Usage) -> str:
71
- """Format session status as comma-separated text."""
72
- parts: list[str] = []
92
+ def _format_model_usage_line(meta: model.TaskMetadata) -> str:
93
+ """Format a single model's usage as a line."""
94
+ model_label = meta.model_name
95
+ if meta.provider:
96
+ model_label = f"{meta.model_name} ({meta.provider})"
97
+
98
+ usage = meta.usage
99
+ if not usage:
100
+ return f" {model_label}: no usage data"
101
+
102
+ cost_str = _format_cost(usage.total_cost, usage.currency)
103
+ return (
104
+ f" {model_label}: "
105
+ f"{_format_tokens(usage.input_tokens)} input, "
106
+ f"{_format_tokens(usage.output_tokens)} output, "
107
+ f"{_format_tokens(usage.cached_tokens)} cache read, "
108
+ f"{_format_tokens(usage.reasoning_tokens)} thinking, "
109
+ f"({cost_str})"
110
+ )
111
+
112
+
113
+ def format_status_content(aggregated: AggregatedUsage) -> str:
114
+ """Format session status with per-model breakdown."""
115
+ lines: list[str] = []
73
116
 
74
- parts.append(f"Input: {_format_tokens(usage.input_tokens)}")
75
- if usage.cached_tokens > 0:
76
- parts.append(f"Cached: {_format_tokens(usage.cached_tokens)}")
77
- parts.append(f"Output: {_format_tokens(usage.output_tokens)}")
78
- parts.append(f"Total: {_format_tokens(usage.total_tokens)}")
117
+ # Total cost line
118
+ total_cost_str = _format_cost(aggregated.total.total_cost, aggregated.total.currency)
119
+ lines.append(f"Total cost: {total_cost_str}")
79
120
 
80
- if usage.total_cost is not None:
81
- parts.append(f"Cost: {_format_cost(usage.total_cost, usage.currency)}")
121
+ # Per-model breakdown
122
+ if aggregated.by_model:
123
+ lines.append("Usage by model:")
124
+ for stats in aggregated.by_model:
125
+ lines.append(_format_model_usage_line(stats))
82
126
 
83
- return ", ".join(parts)
127
+ return "\n".join(lines)
84
128
 
85
129
 
86
130
  @register_command
@@ -95,22 +139,20 @@ class StatusCommand(CommandABC):
95
139
  def summary(self) -> str:
96
140
  return "Show session usage statistics"
97
141
 
98
- async def run(self, raw: str, agent: Agent) -> CommandResult:
142
+ async def run(self, raw: str, agent: "Agent") -> CommandResult:
99
143
  session = agent.session
100
- usage, task_count = accumulate_session_usage(session)
144
+ aggregated = accumulate_session_usage(session)
101
145
 
102
146
  event = events.DeveloperMessageEvent(
103
147
  session_id=session.id,
104
148
  item=model.DeveloperMessageItem(
105
- content=format_status_content(usage),
149
+ content=format_status_content(aggregated),
106
150
  command_output=model.CommandOutput(
107
151
  command_name=self.name,
108
- ui_extra=model.ToolResultUIExtra(
109
- type=model.ToolResultUIExtraType.SESSION_STATUS,
110
- session_status=model.SessionStatusUIExtra(
111
- usage=usage,
112
- task_count=task_count,
113
- ),
152
+ ui_extra=model.SessionStatusUIExtra(
153
+ usage=aggregated.total,
154
+ task_count=aggregated.task_count,
155
+ by_model=aggregated.by_model,
114
156
  ),
115
157
  ),
116
158
  ),
@@ -1,12 +1,15 @@
1
1
  import os
2
2
  import subprocess
3
3
  from pathlib import Path
4
+ from typing import TYPE_CHECKING
4
5
 
5
6
  from klaude_code.command.command_abc import CommandABC, CommandResult
6
7
  from klaude_code.command.registry import register_command
7
- from klaude_code.core.agent import Agent
8
8
  from klaude_code.protocol import commands, events, model
9
9
 
10
+ if TYPE_CHECKING:
11
+ from klaude_code.core.agent import Agent
12
+
10
13
 
11
14
  @register_command
12
15
  class TerminalSetupCommand(CommandABC):
@@ -24,7 +27,7 @@ class TerminalSetupCommand(CommandABC):
24
27
  def is_interactive(self) -> bool:
25
28
  return False
26
29
 
27
- async def run(self, raw: str, agent: Agent) -> CommandResult:
30
+ async def run(self, raw: str, agent: "Agent") -> CommandResult:
28
31
  term_program = os.environ.get("TERM_PROGRAM", "").lower()
29
32
 
30
33
  try:
@@ -223,7 +226,7 @@ class TerminalSetupCommand(CommandABC):
223
226
 
224
227
  return message
225
228
 
226
- def _create_success_result(self, agent: Agent, message: str) -> CommandResult:
229
+ def _create_success_result(self, agent: "Agent", message: str) -> CommandResult:
227
230
  """Create success result"""
228
231
  return CommandResult(
229
232
  events=[
@@ -237,7 +240,7 @@ class TerminalSetupCommand(CommandABC):
237
240
  ]
238
241
  )
239
242
 
240
- def _create_error_result(self, agent: Agent, message: str) -> CommandResult:
243
+ def _create_error_result(self, agent: "Agent", message: str) -> CommandResult:
241
244
  """Create error result"""
242
245
  return CommandResult(
243
246
  events=[
@@ -91,7 +91,7 @@ INVALID_TOOL_CALL_MAX_LENGTH = 500
91
91
  TRUNCATE_DISPLAY_MAX_LINE_LENGTH = 1000
92
92
 
93
93
  # Maximum lines for truncated display output
94
- TRUNCATE_DISPLAY_MAX_LINES = 10
94
+ TRUNCATE_DISPLAY_MAX_LINES = 20
95
95
 
96
96
  # Maximum lines for sub-agent result display
97
97
  SUB_AGENT_RESULT_MAX_LINES = 12
klaude_code/core/agent.py CHANGED
@@ -1,34 +1,51 @@
1
1
  from __future__ import annotations
2
2
 
3
- from collections.abc import AsyncGenerator, Iterable
3
+ from collections.abc import AsyncGenerator, Callable, Iterable
4
4
  from dataclasses import dataclass
5
- from typing import Protocol
5
+ from typing import TYPE_CHECKING, Protocol
6
6
 
7
7
  from klaude_code.core.prompt import get_system_prompt as load_system_prompt
8
8
  from klaude_code.core.reminders import Reminder, load_agent_reminders
9
- from klaude_code.core.task import TaskExecutionContext, TaskExecutor
10
- from klaude_code.core.tool import TodoContext, get_registry, load_agent_tools
9
+ from klaude_code.core.task import SessionContext, TaskExecutionContext, TaskExecutor
10
+ from klaude_code.core.tool import build_todo_context, get_registry, load_agent_tools
11
11
  from klaude_code.llm import LLMClientABC
12
12
  from klaude_code.protocol import events, llm_param, model, tools
13
13
  from klaude_code.protocol.model import UserInputPayload
14
14
  from klaude_code.session import Session
15
15
  from klaude_code.trace import DebugType, log_debug
16
16
 
17
+ if TYPE_CHECKING:
18
+ from klaude_code.core.manager.llm_clients import LLMClients
19
+
17
20
 
18
21
  @dataclass(frozen=True)
19
22
  class AgentProfile:
20
23
  """Encapsulates the active LLM client plus prompts/tools/reminders."""
21
24
 
22
- llm_client: LLMClientABC
25
+ llm_client_factory: Callable[[], LLMClientABC]
23
26
  system_prompt: str | None
24
27
  tools: list[llm_param.ToolSchema]
25
28
  reminders: list[Reminder]
26
29
 
30
+ _llm_client: LLMClientABC | None = None
31
+
32
+ @property
33
+ def llm_client(self) -> LLMClientABC:
34
+ if self._llm_client is None:
35
+ object.__setattr__(self, "_llm_client", self.llm_client_factory())
36
+ return self._llm_client # type: ignore[return-value]
37
+
27
38
 
28
39
  class ModelProfileProvider(Protocol):
29
40
  """Strategy interface for constructing agent profiles."""
30
41
 
31
42
  def build_profile(
43
+ self,
44
+ llm_clients: LLMClients,
45
+ sub_agent_type: tools.SubAgentType | None = None,
46
+ ) -> AgentProfile: ...
47
+
48
+ def build_profile_eager(
32
49
  self,
33
50
  llm_client: LLMClientABC,
34
51
  sub_agent_type: tools.SubAgentType | None = None,
@@ -39,13 +56,26 @@ class DefaultModelProfileProvider(ModelProfileProvider):
39
56
  """Default provider backed by global prompts/tool/reminder registries."""
40
57
 
41
58
  def build_profile(
59
+ self,
60
+ llm_clients: LLMClients,
61
+ sub_agent_type: tools.SubAgentType | None = None,
62
+ ) -> AgentProfile:
63
+ model_name = llm_clients.main_model_name
64
+ return AgentProfile(
65
+ llm_client_factory=lambda: llm_clients.main,
66
+ system_prompt=load_system_prompt(model_name, sub_agent_type),
67
+ tools=load_agent_tools(model_name, sub_agent_type),
68
+ reminders=load_agent_reminders(model_name, sub_agent_type),
69
+ )
70
+
71
+ def build_profile_eager(
42
72
  self,
43
73
  llm_client: LLMClientABC,
44
74
  sub_agent_type: tools.SubAgentType | None = None,
45
75
  ) -> AgentProfile:
46
76
  model_name = llm_client.model_name
47
77
  return AgentProfile(
48
- llm_client=llm_client,
78
+ llm_client_factory=lambda: llm_client,
49
79
  system_prompt=load_system_prompt(model_name, sub_agent_type),
50
80
  tools=load_agent_tools(model_name, sub_agent_type),
51
81
  reminders=load_agent_reminders(model_name, sub_agent_type),
@@ -56,13 +86,26 @@ class VanillaModelProfileProvider(ModelProfileProvider):
56
86
  """Provider that strips prompts, reminders, and tools for vanilla mode."""
57
87
 
58
88
  def build_profile(
89
+ self,
90
+ llm_clients: LLMClients,
91
+ sub_agent_type: tools.SubAgentType | None = None,
92
+ ) -> AgentProfile:
93
+ model_name = llm_clients.main_model_name
94
+ return AgentProfile(
95
+ llm_client_factory=lambda: llm_clients.main,
96
+ system_prompt=None,
97
+ tools=load_agent_tools(model_name, vanilla=True),
98
+ reminders=load_agent_reminders(model_name, vanilla=True),
99
+ )
100
+
101
+ def build_profile_eager(
59
102
  self,
60
103
  llm_client: LLMClientABC,
61
104
  sub_agent_type: tools.SubAgentType | None = None,
62
105
  ) -> AgentProfile:
63
106
  model_name = llm_client.model_name
64
107
  return AgentProfile(
65
- llm_client=llm_client,
108
+ llm_client_factory=lambda: llm_client,
66
109
  system_prompt=None,
67
110
  tools=load_agent_tools(model_name, vanilla=True),
68
111
  reminders=load_agent_reminders(model_name, vanilla=True),
@@ -74,13 +117,13 @@ class Agent:
74
117
  self,
75
118
  session: Session,
76
119
  profile: AgentProfile,
120
+ model_name: str | None = None,
77
121
  ):
78
122
  self.session: Session = session
79
- self.profile: AgentProfile | None = None
80
- # Active task executor, if any
123
+ self.profile: AgentProfile = profile
81
124
  self._current_task: TaskExecutor | None = None
82
- # Ensure runtime configuration matches the active model on initialization
83
- self.set_model_profile(profile)
125
+ if not self.session.model_name and model_name:
126
+ self.session.model_name = model_name
84
127
 
85
128
  def cancel(self) -> Iterable[events.Event]:
86
129
  """Handle agent cancellation and persist an interrupt marker and tool cancellations.
@@ -106,17 +149,17 @@ class Agent:
106
149
  )
107
150
 
108
151
  async def run_task(self, user_input: UserInputPayload) -> AsyncGenerator[events.Event, None]:
109
- context = TaskExecutionContext(
152
+ session_ctx = SessionContext(
110
153
  session_id=self.session.id,
111
- profile=self._require_profile(),
112
154
  get_conversation_history=lambda: self.session.conversation_history,
113
155
  append_history=self.session.append_history,
114
- tool_registry=get_registry(),
115
156
  file_tracker=self.session.file_tracker,
116
- todo_context=TodoContext(
117
- get_todos=lambda: self.session.todos,
118
- set_todos=lambda todos: setattr(self.session, "todos", todos),
119
- ),
157
+ todo_context=build_todo_context(self.session),
158
+ )
159
+ context = TaskExecutionContext(
160
+ session_ctx=session_ctx,
161
+ profile=self.profile,
162
+ tool_registry=get_registry(),
120
163
  process_reminder=self._process_reminder,
121
164
  sub_agent_state=self.session.sub_agent_state,
122
165
  )
@@ -149,17 +192,14 @@ class Agent:
149
192
  self.session.append_history([item])
150
193
  yield events.DeveloperMessageEvent(session_id=self.session.id, item=item)
151
194
 
152
- def set_model_profile(self, profile: AgentProfile) -> None:
195
+ def set_model_profile(self, profile: AgentProfile, model_name: str | None = None) -> None:
153
196
  """Apply a fully constructed profile to the agent."""
154
197
 
155
198
  self.profile = profile
156
- if not self.session.model_name:
199
+ if model_name:
200
+ self.session.model_name = model_name
201
+ elif not self.session.model_name:
157
202
  self.session.model_name = profile.llm_client.model_name
158
203
 
159
204
  def get_llm_client(self) -> LLMClientABC:
160
- return self._require_profile().llm_client
161
-
162
- def _require_profile(self) -> AgentProfile:
163
- if self.profile is None:
164
- raise RuntimeError("Agent profile is not initialized")
165
- return self.profile
205
+ return self.profile.llm_client
@@ -117,7 +117,7 @@ class ExecutorContext:
117
117
  if operation.session_id is None:
118
118
  raise ValueError("session_id cannot be None")
119
119
 
120
- await self.agent_manager.ensure_agent(operation.session_id)
120
+ await self.agent_manager.ensure_agent(operation.session_id, is_new_session=operation.is_new_session)
121
121
 
122
122
  async def handle_user_input(self, operation: op.UserInputOperation) -> None:
123
123
  """Handle a user input operation by running it through an agent."""
@@ -482,4 +482,4 @@ class Executor:
482
482
 
483
483
  # Static type check: ExecutorContext must satisfy OperationHandler protocol.
484
484
  # If this line causes a type error, ExecutorContext is missing required methods.
485
- _: type[OperationHandler] = ExecutorContext # pyright: ignore[reportUnusedVariable]
485
+ _: type[OperationHandler] = ExecutorContext
@@ -38,16 +38,15 @@ class AgentManager:
38
38
 
39
39
  await self._event_queue.put(event)
40
40
 
41
- async def ensure_agent(self, session_id: str) -> Agent:
41
+ async def ensure_agent(self, session_id: str, *, is_new_session: bool = False) -> Agent:
42
42
  """Return an existing agent for the session or create a new one."""
43
-
44
43
  agent = self._active_agents.get(session_id)
45
44
  if agent is not None:
46
45
  return agent
47
46
 
48
- session = Session.load(session_id)
49
- profile = self._model_profile_provider.build_profile(self._llm_clients.main)
50
- agent = Agent(session=session, profile=profile)
47
+ session = Session.load(session_id, skip_if_missing=is_new_session)
48
+ profile = self._model_profile_provider.build_profile(self._llm_clients)
49
+ agent = Agent(session=session, profile=profile, model_name=self._llm_clients.main_model_name)
51
50
 
52
51
  async for evt in agent.replay_history():
53
52
  await self.emit_event(evt)
@@ -55,7 +54,7 @@ class AgentManager:
55
54
  await self.emit_event(
56
55
  events.WelcomeEvent(
57
56
  work_dir=str(session.work_dir),
58
- llm_config=self._llm_clients.main.get_llm_config(),
57
+ llm_config=self._llm_clients.get_llm_config(),
59
58
  )
60
59
  )
61
60
 
@@ -76,7 +75,7 @@ class AgentManager:
76
75
 
77
76
  llm_config = config.get_model_config(model_name)
78
77
  llm_client = create_llm_client(llm_config)
79
- agent.set_model_profile(self._model_profile_provider.build_profile(llm_client))
78
+ agent.set_model_profile(self._model_profile_provider.build_profile_eager(llm_client), model_name=model_name)
80
79
 
81
80
  developer_item = model.DeveloperMessageItem(
82
81
  content=f"switched to model: {model_name}",
@@ -2,41 +2,66 @@
2
2
 
3
3
  from __future__ import annotations
4
4
 
5
- from dataclasses import dataclass
6
- from dataclasses import field as dataclass_field
5
+ from collections.abc import Callable
7
6
 
8
7
  from klaude_code.llm.client import LLMClientABC
8
+ from klaude_code.protocol import llm_param
9
9
  from klaude_code.protocol.tools import SubAgentType
10
10
 
11
11
 
12
- def _default_sub_clients() -> dict[SubAgentType, LLMClientABC]:
13
- """Return an empty mapping for sub-agent clients.
12
+ class LLMClients:
13
+ """Container for LLM clients used by main agent and sub-agents."""
14
14
 
15
- Defined separately so static type checkers can infer the dictionary
16
- key and value types instead of treating them as ``Unknown``.
17
- """
15
+ def __init__(
16
+ self,
17
+ main_factory: Callable[[], LLMClientABC],
18
+ main_model_name: str,
19
+ main_llm_config: llm_param.LLMConfigParameter,
20
+ ) -> None:
21
+ self._main_factory: Callable[[], LLMClientABC] | None = main_factory
22
+ self._main_client: LLMClientABC | None = None
23
+ self._main_model_name: str = main_model_name
24
+ self._main_llm_config: llm_param.LLMConfigParameter = main_llm_config
25
+ self._sub_clients: dict[SubAgentType, LLMClientABC] = {}
26
+ self._sub_factories: dict[SubAgentType, Callable[[], LLMClientABC]] = {}
18
27
 
19
- return {}
28
+ @property
29
+ def main_model_name(self) -> str:
30
+ return self._main_model_name
20
31
 
32
+ def get_llm_config(self) -> llm_param.LLMConfigParameter:
33
+ return self._main_llm_config
21
34
 
22
- @dataclass
23
- class LLMClients:
24
- """Container for LLM clients used by main agent and sub-agents."""
35
+ @property
36
+ def main(self) -> LLMClientABC:
37
+ if self._main_client is None:
38
+ if self._main_factory is None:
39
+ raise RuntimeError("Main client factory not set")
40
+ self._main_client = self._main_factory()
41
+ self._main_factory = None
42
+ return self._main_client
25
43
 
26
- main: LLMClientABC
27
- sub_clients: dict[SubAgentType, LLMClientABC] = dataclass_field(default_factory=_default_sub_clients)
44
+ def register_sub_client_factory(
45
+ self,
46
+ sub_agent_type: SubAgentType,
47
+ factory: Callable[[], LLMClientABC],
48
+ ) -> None:
49
+ self._sub_factories[sub_agent_type] = factory
28
50
 
29
51
  def get_client(self, sub_agent_type: SubAgentType | None = None) -> LLMClientABC:
30
- """Return client for a sub-agent type or the main client.
52
+ """Return client for a sub-agent type or the main client."""
31
53
 
32
- Args:
33
- sub_agent_type: Optional sub-agent type whose client should be returned.
54
+ if sub_agent_type is None:
55
+ return self.main
34
56
 
35
- Returns:
36
- The LLM client corresponding to the sub-agent type, or the main client
37
- when no specialized client is available.
38
- """
57
+ existing = self._sub_clients.get(sub_agent_type)
58
+ if existing is not None:
59
+ return existing
39
60
 
40
- if sub_agent_type is None:
61
+ factory = self._sub_factories.get(sub_agent_type)
62
+ if factory is None:
41
63
  return self.main
42
- return self.sub_clients.get(sub_agent_type) or self.main
64
+
65
+ client = factory()
66
+ self._sub_clients[sub_agent_type] = client
67
+ return client
@@ -32,18 +32,30 @@ def build_llm_clients(
32
32
  debug_type=DebugType.LLM_CONFIG,
33
33
  )
34
34
 
35
- main_client = create_llm_client(llm_config)
36
- sub_clients: dict[SubAgentType, LLMClientABC] = {}
35
+ main_model_name = str(llm_config.model)
36
+
37
+ def _main_factory() -> LLMClientABC:
38
+ return create_llm_client(llm_config)
39
+
40
+ clients = LLMClients(
41
+ main_factory=_main_factory,
42
+ main_model_name=main_model_name,
43
+ main_llm_config=llm_config,
44
+ )
37
45
 
38
- # Initialize sub-agent clients
39
46
  for sub_agent_type in enabled_sub_agents or []:
40
47
  model_name = config.subagent_models.get(sub_agent_type)
41
48
  if not model_name:
42
49
  continue
50
+
43
51
  profile = get_sub_agent_profile(sub_agent_type)
44
- if not profile.enabled_for_model(main_client.model_name):
52
+ if not profile.enabled_for_model(main_model_name):
45
53
  continue
46
- sub_llm_config = config.get_model_config(model_name)
47
- sub_clients[sub_agent_type] = create_llm_client(sub_llm_config)
48
54
 
49
- return LLMClients(main=main_client, sub_clients=sub_clients)
55
+ def _factory(model_name_for_factory: str = model_name) -> LLMClientABC:
56
+ sub_llm_config = config.get_model_config(model_name_for_factory)
57
+ return create_llm_client(sub_llm_config)
58
+
59
+ clients.register_sub_client_factory(sub_agent_type, _factory)
60
+
61
+ return clients
@@ -43,7 +43,7 @@ class SubAgentManager:
43
43
  child_session = Session(work_dir=parent_session.work_dir)
44
44
  child_session.sub_agent_state = state
45
45
 
46
- child_profile = self._model_profile_provider.build_profile(
46
+ child_profile = self._model_profile_provider.build_profile_eager(
47
47
  self._llm_clients.get_client(state.sub_agent_type),
48
48
  state.sub_agent_type,
49
49
  )
@@ -58,13 +58,17 @@ class SubAgentManager:
58
58
  try:
59
59
  # Not emit the subtask's user input since task tool call is already rendered
60
60
  result: str = ""
61
+ task_metadata: model.TaskMetadata | None = None
61
62
  sub_agent_input = model.UserInputPayload(text=state.sub_agent_prompt, images=None)
62
63
  async for event in child_agent.run_task(sub_agent_input):
63
64
  # Capture TaskFinishEvent content for return
64
65
  if isinstance(event, events.TaskFinishEvent):
65
66
  result = event.task_result
67
+ # Capture TaskMetadataEvent for metadata propagation
68
+ elif isinstance(event, events.TaskMetadataEvent):
69
+ task_metadata = event.metadata.main
66
70
  await self.emit_event(event)
67
- return SubAgentResult(task_result=result, session_id=child_session.id)
71
+ return SubAgentResult(task_result=result, session_id=child_session.id, task_metadata=task_metadata)
68
72
  except asyncio.CancelledError:
69
73
  # Propagate cancellation so tooling can treat it as user interrupt
70
74
  log_debug(