axion-code 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- axion/__init__.py +3 -0
- axion/api/__init__.py +0 -0
- axion/api/anthropic.py +460 -0
- axion/api/client.py +259 -0
- axion/api/error.py +161 -0
- axion/api/ollama.py +597 -0
- axion/api/openai_compat.py +805 -0
- axion/api/openai_responses.py +627 -0
- axion/api/prompt_cache.py +31 -0
- axion/api/sse.py +98 -0
- axion/api/types.py +451 -0
- axion/cli/__init__.py +0 -0
- axion/cli/init_cmd.py +50 -0
- axion/cli/input.py +290 -0
- axion/cli/main.py +2953 -0
- axion/cli/render.py +489 -0
- axion/cli/tui.py +766 -0
- axion/commands/__init__.py +0 -0
- axion/commands/handlers/__init__.py +0 -0
- axion/commands/handlers/agents.py +51 -0
- axion/commands/handlers/builtin_commands.py +367 -0
- axion/commands/handlers/mcp.py +59 -0
- axion/commands/handlers/models.py +75 -0
- axion/commands/handlers/plugins.py +55 -0
- axion/commands/handlers/skills.py +61 -0
- axion/commands/parsing.py +317 -0
- axion/commands/registry.py +166 -0
- axion/compat_harness/__init__.py +0 -0
- axion/compat_harness/extractor.py +145 -0
- axion/plugins/__init__.py +0 -0
- axion/plugins/hooks.py +22 -0
- axion/plugins/manager.py +391 -0
- axion/plugins/manifest.py +270 -0
- axion/runtime/__init__.py +0 -0
- axion/runtime/bash.py +388 -0
- axion/runtime/bootstrap.py +39 -0
- axion/runtime/claude_subscription.py +300 -0
- axion/runtime/compact.py +233 -0
- axion/runtime/config.py +397 -0
- axion/runtime/conversation.py +1073 -0
- axion/runtime/file_ops.py +613 -0
- axion/runtime/git.py +213 -0
- axion/runtime/hooks.py +235 -0
- axion/runtime/image.py +212 -0
- axion/runtime/lanes.py +282 -0
- axion/runtime/lsp.py +425 -0
- axion/runtime/mcp/__init__.py +0 -0
- axion/runtime/mcp/client.py +76 -0
- axion/runtime/mcp/lifecycle.py +96 -0
- axion/runtime/mcp/stdio.py +318 -0
- axion/runtime/mcp/tool_bridge.py +79 -0
- axion/runtime/memory.py +196 -0
- axion/runtime/oauth.py +329 -0
- axion/runtime/openai_subscription.py +346 -0
- axion/runtime/permissions.py +247 -0
- axion/runtime/plan_mode.py +96 -0
- axion/runtime/policy_engine.py +259 -0
- axion/runtime/prompt.py +586 -0
- axion/runtime/recovery.py +261 -0
- axion/runtime/remote.py +28 -0
- axion/runtime/sandbox.py +68 -0
- axion/runtime/scheduler.py +231 -0
- axion/runtime/session.py +365 -0
- axion/runtime/sharing.py +159 -0
- axion/runtime/skills.py +124 -0
- axion/runtime/tasks.py +258 -0
- axion/runtime/usage.py +241 -0
- axion/runtime/workers.py +186 -0
- axion/telemetry/__init__.py +0 -0
- axion/telemetry/events.py +67 -0
- axion/telemetry/profile.py +49 -0
- axion/telemetry/sink.py +60 -0
- axion/telemetry/tracer.py +95 -0
- axion/tools/__init__.py +0 -0
- axion/tools/lane_completion.py +33 -0
- axion/tools/registry.py +853 -0
- axion/tools/tool_search.py +226 -0
- axion_code-1.0.0.dist-info/METADATA +709 -0
- axion_code-1.0.0.dist-info/RECORD +82 -0
- axion_code-1.0.0.dist-info/WHEEL +4 -0
- axion_code-1.0.0.dist-info/entry_points.txt +2 -0
- axion_code-1.0.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,1073 @@
|
|
|
1
|
+
"""Core conversation loop - coordinates model, tools, hooks, and session.
|
|
2
|
+
|
|
3
|
+
Maps to: rust/crates/runtime/src/conversation.rs
|
|
4
|
+
|
|
5
|
+
The ConversationRuntime orchestrates the full model turn loop including:
|
|
6
|
+
- Streaming model responses and assembling tool-use blocks
|
|
7
|
+
- Pre/post tool-use hook integration with permission override support
|
|
8
|
+
- Auto-compaction when cumulative input tokens exceed a threshold
|
|
9
|
+
- Session tracing for observability (turn lifecycle, tool execution)
|
|
10
|
+
- Prompt cache event collection from stream metadata
|
|
11
|
+
- Builder pattern for ergonomic construction
|
|
12
|
+
- Session forking for parallel exploration branches
|
|
13
|
+
"""
|
|
14
|
+
|
|
15
|
+
from __future__ import annotations
|
|
16
|
+
|
|
17
|
+
import json
|
|
18
|
+
import logging
|
|
19
|
+
import os
|
|
20
|
+
import time
|
|
21
|
+
from dataclasses import dataclass, field
|
|
22
|
+
from typing import Any, Callable, Protocol, runtime_checkable
|
|
23
|
+
|
|
24
|
+
from axion.api.client import (
|
|
25
|
+
ProviderClient,
|
|
26
|
+
max_tokens_for_model,
|
|
27
|
+
resolve_model_alias,
|
|
28
|
+
)
|
|
29
|
+
from axion.api.types import (
|
|
30
|
+
ContentBlockDeltaEvent,
|
|
31
|
+
ContentBlockStartEvent,
|
|
32
|
+
InputJsonDelta,
|
|
33
|
+
InputMessage,
|
|
34
|
+
MessageDeltaEvent,
|
|
35
|
+
MessageRequest,
|
|
36
|
+
MessageStartEvent,
|
|
37
|
+
MessageStopEvent,
|
|
38
|
+
TextDelta,
|
|
39
|
+
ThinkingDelta,
|
|
40
|
+
ToolChoice,
|
|
41
|
+
ToolDefinition,
|
|
42
|
+
ToolUseOutputBlock,
|
|
43
|
+
)
|
|
44
|
+
from axion.runtime.compact import (
|
|
45
|
+
CompactionConfig,
|
|
46
|
+
CompactionResult,
|
|
47
|
+
compact_session,
|
|
48
|
+
estimate_session_tokens,
|
|
49
|
+
)
|
|
50
|
+
from axion.runtime.hooks import HookRunner
|
|
51
|
+
from axion.runtime.permissions import (
|
|
52
|
+
TOOL_PERMISSION_REQUIREMENTS,
|
|
53
|
+
PermissionAllow,
|
|
54
|
+
PermissionDeny,
|
|
55
|
+
PermissionMode,
|
|
56
|
+
PermissionOutcome,
|
|
57
|
+
PermissionOverride,
|
|
58
|
+
PermissionPolicy,
|
|
59
|
+
PermissionPromptDecision,
|
|
60
|
+
PermissionPrompter,
|
|
61
|
+
PermissionRequest,
|
|
62
|
+
)
|
|
63
|
+
from axion.runtime.session import (
|
|
64
|
+
ContentBlock,
|
|
65
|
+
ConversationMessage,
|
|
66
|
+
ImageBlock,
|
|
67
|
+
MessageRole,
|
|
68
|
+
Session,
|
|
69
|
+
SessionFork,
|
|
70
|
+
TextBlock,
|
|
71
|
+
ToolResultBlock,
|
|
72
|
+
ToolUseBlock,
|
|
73
|
+
)
|
|
74
|
+
from axion.runtime.usage import TokenUsage, UsageTracker
|
|
75
|
+
from axion.telemetry.tracer import SessionTracer
|
|
76
|
+
|
|
77
|
+
logger = logging.getLogger(__name__)

# Default input-token threshold for auto-compaction. Presumably overridable via
# the environment variable named below (read by _resolve_compaction_threshold,
# which is not shown here) — confirm.
DEFAULT_AUTO_COMPACTION_THRESHOLD = 100000
_ENV_COMPACTION_KEY = "CLAUDE_CODE_AUTO_COMPACT_INPUT_TOKENS"

# Context window size (tokens) keyed by model-family prefix. Looked up by
# prefix match against the resolved model name; insertion order matters
# because the first matching prefix wins.
_CONTEXT_WINDOWS: dict[str, int] = dict.fromkeys(
    ("claude-opus", "claude-sonnet", "claude-haiku"),
    200_000,
)
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
# ---------------------------------------------------------------------------
|
|
91
|
+
# Protocols (traits)
|
|
92
|
+
# ---------------------------------------------------------------------------
|
|
93
|
+
|
|
94
|
+
@runtime_checkable
class ToolExecutor(Protocol):
    """Structural interface for objects that run model-requested tools.

    An implementation receives the tool name and its input text (presumably
    the JSON assembled from the model's stream — confirm against callers) and
    returns the tool's output as a string. Because the protocol is
    ``runtime_checkable``, ``isinstance`` only verifies that an ``execute``
    attribute exists; it does not check the signature or that it is async.
    """

    async def execute(self, tool_name: str, tool_input: str) -> str:
        """Run ``tool_name`` with ``tool_input`` and return its textual output."""
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
# ---------------------------------------------------------------------------
|
|
102
|
+
# Events emitted during a turn
|
|
103
|
+
# ---------------------------------------------------------------------------
|
|
104
|
+
|
|
105
|
+
@dataclass(frozen=True)
|
|
106
|
+
class AssistantTextDelta:
|
|
107
|
+
"""Incremental text chunk from the model."""
|
|
108
|
+
text: str
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
@dataclass(frozen=True)
|
|
112
|
+
class AssistantToolUse:
|
|
113
|
+
"""Model requested a tool invocation."""
|
|
114
|
+
id: str
|
|
115
|
+
name: str
|
|
116
|
+
input: str
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
@dataclass(frozen=True)
|
|
120
|
+
class AssistantUsage:
|
|
121
|
+
"""Token usage snapshot for a single iteration."""
|
|
122
|
+
usage: TokenUsage
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
@dataclass(frozen=True)
|
|
126
|
+
class AssistantPromptCache:
|
|
127
|
+
"""Prompt cache hit/miss information from streaming metadata."""
|
|
128
|
+
cache_creation_input_tokens: int
|
|
129
|
+
cache_read_input_tokens: int
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
@dataclass(frozen=True)
|
|
133
|
+
class AssistantMessageStop:
|
|
134
|
+
"""End of model message, includes stop reason."""
|
|
135
|
+
stop_reason: str | None
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
AssistantEvent = (
|
|
139
|
+
AssistantTextDelta
|
|
140
|
+
| AssistantToolUse
|
|
141
|
+
| AssistantUsage
|
|
142
|
+
| AssistantPromptCache
|
|
143
|
+
| AssistantMessageStop
|
|
144
|
+
)
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
# ---------------------------------------------------------------------------
|
|
148
|
+
# Prompt cache event tracking
|
|
149
|
+
# ---------------------------------------------------------------------------
|
|
150
|
+
|
|
151
|
+
@dataclass
class PromptCacheEvent:
    """Prompt-cache statistics collected from one streaming response.

    Every field defaults to zero so an empty event can be created without
    arguments; instances are mutable.
    """

    # Input tokens written to the prompt cache.
    cache_creation_input_tokens: int = 0
    # Input tokens read back from the prompt cache.
    cache_read_input_tokens: int = 0
    # Wall-clock capture time in milliseconds since the Unix epoch.
    timestamp_ms: int = 0
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
# ---------------------------------------------------------------------------
|
|
160
|
+
# Turn summary
|
|
161
|
+
# ---------------------------------------------------------------------------
|
|
162
|
+
|
|
163
|
+
@dataclass
class TurnSummary:
    """Everything one call to ``ConversationRuntime.run_turn`` produced.

    All fields have defaults, so the runtime builds an empty summary up front
    and fills it in as the turn progresses.
    """

    # Assistant messages appended to the session during this turn.
    assistant_messages: list[ConversationMessage] = field(default_factory=list)
    # Tool-result messages appended to the session during this turn.
    tool_results: list[ConversationMessage] = field(default_factory=list)
    # Number of loop iterations started (a budget stop can end one before its API call).
    iterations: int = 0
    # Token usage summed across all iterations of the turn.
    usage: TokenUsage = field(default_factory=TokenUsage)
    # Concatenated assistant text plus any budget notices appended by the runtime.
    text_output: str = ""
    # One entry per response whose stream reported prompt-cache activity.
    prompt_cache_events: list[PromptCacheEvent] = field(default_factory=list)
    # Result of the latest auto-compaction performed during the turn, if any.
    compaction_result: CompactionResult | None = None
    # Set once auto-compaction has run during the turn.
    was_auto_compacted: bool = False
|
|
175
|
+
|
|
176
|
+
|
|
177
|
+
# ---------------------------------------------------------------------------
|
|
178
|
+
# Errors
|
|
179
|
+
# ---------------------------------------------------------------------------
|
|
180
|
+
|
|
181
|
+
class ConversationError(Exception):
    """Base error for failures while running a conversation turn.

    Attributes:
        cause: The underlying exception, if one was supplied.
    """

    def __init__(self, message: str, *, cause: Exception | None = None) -> None:
        self.cause = cause
        super().__init__(message)


class ToolError(Exception):
    """Failure raised while executing a tool.

    Deliberately a plain ``Exception``, not a ``ConversationError`` subclass.

    Attributes:
        tool_name: Name of the failing tool ("" when unknown).
        tool_use_id: Id of the tool_use request ("" when unknown).
        cause: The underlying exception, if one was supplied.
    """

    def __init__(
        self,
        message: str,
        *,
        tool_name: str = "",
        tool_use_id: str = "",
        cause: Exception | None = None,
    ) -> None:
        self.tool_name, self.tool_use_id, self.cause = tool_name, tool_use_id, cause
        super().__init__(message)


class MaxIterationsError(ConversationError):
    """Signals that the tool loop ran past ``max_iterations``."""


class PermissionDeniedError(ConversationError):
    """Signals that permission policy or a hook refused a tool call."""


class ContextWindowExceededError(ConversationError):
    """Signals that the estimated request size exceeds the model's context window."""
|
|
216
|
+
|
|
217
|
+
|
|
218
|
+
# ---------------------------------------------------------------------------
|
|
219
|
+
# Conversation runtime
|
|
220
|
+
# ---------------------------------------------------------------------------
|
|
221
|
+
|
|
222
|
+
@dataclass
|
|
223
|
+
class ConversationRuntime:
|
|
224
|
+
"""Coordinates the model loop, tool execution, and session updates.
|
|
225
|
+
|
|
226
|
+
Maps to: rust/crates/runtime/src/conversation.rs::ConversationRuntime
|
|
227
|
+
|
|
228
|
+
The runtime implements the full agentic loop:
|
|
229
|
+
1. Send user message + history to model
|
|
230
|
+
2. Stream and assemble the response (text + tool_use blocks)
|
|
231
|
+
3. For each tool_use: run pre-hooks, check permissions, execute, run post-hooks
|
|
232
|
+
4. Append results, check auto-compaction, and loop
|
|
233
|
+
5. Return when model produces final text (end_turn) or max iterations reached
|
|
234
|
+
"""
|
|
235
|
+
|
|
236
|
+
session: Session
|
|
237
|
+
provider: ProviderClient
|
|
238
|
+
tool_executor: ToolExecutor | None = None
|
|
239
|
+
permission_policy: PermissionPolicy = field(default_factory=PermissionPolicy)
|
|
240
|
+
permission_prompter: PermissionPrompter | None = None
|
|
241
|
+
hook_runner: HookRunner | None = None
|
|
242
|
+
session_tracer: SessionTracer | None = None
|
|
243
|
+
system_prompt: str = ""
|
|
244
|
+
model: str = "claude-sonnet-4-6"
|
|
245
|
+
max_iterations: int = 50
|
|
246
|
+
auto_compaction_threshold: int = field(default_factory=lambda: _resolve_compaction_threshold())
|
|
247
|
+
usage_tracker: UsageTracker = field(default_factory=UsageTracker)
|
|
248
|
+
on_event: Callable[[AssistantEvent], None] | None = None
|
|
249
|
+
on_text_delta: Callable[[str], None] | None = None
|
|
250
|
+
on_tool_use: Callable[[str, str], None] | None = None # (tool_name, tool_input)
|
|
251
|
+
on_tool_result: Callable[[str, str, bool], None] | None = None # (tool_name, output, is_error)
|
|
252
|
+
on_thinking: Callable[[str], None] | None = None # (thinking_text)
|
|
253
|
+
cost_budget_usd: float | None = None # Max spend per session (None = unlimited)
|
|
254
|
+
plan_mode_active: bool = False # When True, only read-only tools allowed
|
|
255
|
+
|
|
256
|
+
# -- Builder helpers -----------------------------------------------------
|
|
257
|
+
|
|
258
|
+
def with_max_iterations(self, n: int) -> ConversationRuntime:
    """Fluent setter: cap the model/tool loop at ``n`` iterations per turn."""
    self.max_iterations = n
    return self

def with_auto_compaction_threshold(self, tokens: int) -> ConversationRuntime:
    """Fluent setter: input-token count at which auto-compaction kicks in."""
    self.auto_compaction_threshold = tokens
    return self

def with_hook_runner(self, runner: HookRunner) -> ConversationRuntime:
    """Fluent setter: hook runner used for pre/post tool-use hooks."""
    self.hook_runner = runner
    return self

def with_session_tracer(self, tracer: SessionTracer) -> ConversationRuntime:
    """Fluent setter: tracer that records turn/tool lifecycle events."""
    self.session_tracer = tracer
    return self

def with_permission_prompter(self, prompter: PermissionPrompter) -> ConversationRuntime:
    """Fluent setter: interactive prompter consulted in PROMPT permission mode."""
    self.permission_prompter = prompter
    return self

def with_tool_executor(self, executor: ToolExecutor) -> ConversationRuntime:
    """Fluent setter: executor that runs model-requested tools."""
    self.tool_executor = executor
    return self

def with_system_prompt(self, prompt: str) -> ConversationRuntime:
    """Fluent setter: system prompt sent with every request."""
    self.system_prompt = prompt
    return self
|
|
292
|
+
|
|
293
|
+
# -- Session operations --------------------------------------------------
|
|
294
|
+
|
|
295
|
+
def fork_session(self, branch_name: str | None = None) -> ConversationRuntime:
    """Create a forked copy of this runtime with a new session.

    The fork gets deep copies of the message history and compaction state,
    plus fork metadata that references this session's id, so the two
    conversations diverge from here.

    Every runtime setting carries over into the fork:
    - provider, tool executor, permission policy and prompter;
    - hook runner, tracer, system prompt and model;
    - iteration and compaction limits;
    - all event callbacks, the cost budget and plan mode.

    Only the usage tracker starts fresh. Collaborator objects (provider,
    policy, hooks, ...) are shared by reference, not copied.

    Args:
        branch_name: Optional label recorded in the fork metadata.

    Returns:
        A new ConversationRuntime bound to the forked session.
    """
    import copy
    import dataclasses

    forked_session = Session(
        messages=copy.deepcopy(self.session.messages),
        compaction=copy.deepcopy(self.session.compaction),
        fork=SessionFork(
            parent_session_id=self.session.session_id,
            branch_name=branch_name,
        ),
    )

    # dataclasses.replace copies every field. Listing fields by hand used to
    # drop on_tool_use, on_tool_result, on_thinking, cost_budget_usd and
    # plan_mode_active, so a fork silently reset to defaults. That meant an
    # unlimited budget, and plan mode (read-only tools) switched off.
    return dataclasses.replace(
        self,
        session=forked_session,
        usage_tracker=UsageTracker(),
    )
|
|
328
|
+
|
|
329
|
+
def estimated_tokens(self) -> int:
    """Return a heuristic token estimate for the session's messages.

    Delegates entirely to ``estimate_session_tokens``; the system prompt is
    not counted here.
    """
    current_session = self.session
    return estimate_session_tokens(current_session)
|
|
332
|
+
|
|
333
|
+
# -- Preflight check -----------------------------------------------------
|
|
334
|
+
|
|
335
|
+
def _preflight_check(self) -> None:
    """Refuse to start a request whose estimated size cannot fit the context window.

    How the size is estimated:
    - The system prompt counts at roughly four characters per token.
    - Messages are counted with ``estimate_session_tokens``.
    - The model's full ``max_tokens_for_model`` output allowance is added
      on top.

    The context window comes from the first ``_CONTEXT_WINDOWS`` prefix that
    matches the resolved model name, defaulting to 200k tokens.

    Raises:
        ContextWindowExceededError: when input estimate + max output exceeds
            the context window.
    """
    prompt_tokens = (len(self.system_prompt) // 4) if self.system_prompt else 0
    estimated_total = prompt_tokens + estimate_session_tokens(self.session)

    resolved = resolve_model_alias(self.model)
    context_window = next(
        (
            size
            for family_prefix, size in _CONTEXT_WINDOWS.items()
            if resolved.startswith(family_prefix)
        ),
        200_000,
    )

    output_tokens = max_tokens_for_model(resolved)
    needed = estimated_total + output_tokens
    if needed <= context_window:
        return

    raise ContextWindowExceededError(
        f"Estimated {estimated_total} input tokens + {output_tokens} max output tokens "
        f"= {needed} exceeds context window of {context_window} "
        f"for model {resolved}. Consider compacting the session."
    )
|
|
365
|
+
|
|
366
|
+
# -- Main turn loop ------------------------------------------------------
|
|
367
|
+
|
|
368
|
+
async def run_turn(
    self,
    user_input: str,
    images: list[tuple[str, str]] | None = None,
) -> TurnSummary:
    """Execute a full model turn with tool loop.

    Args:
        user_input: The user's text input.
        images: Optional list of (media_type, base64_data) tuples for
            image inputs (screenshots, pasted images).

    Returns:
        A TurnSummary with the assistant messages, tool results, summed
        usage, concatenated text output and any auto-compaction result.

    Raises:
        ContextWindowExceededError: from the preflight check, raised before
            the loop and not wrapped.
        ConversationError: wrapping any exception raised inside the loop.

    Each iteration:
        1. Stop before calling the model if the cost budget is already spent
           (from the second iteration on).
        2. Stream one model response and record its usage.
        3. Store the assistant message and its text.
        4. If the budget is now exhausted, answer any requested tool calls
           with "not executed" error results and stop.
        5. Stop if no tools were requested (or stop_reason is end_turn);
           otherwise execute the tools, maybe auto-compact, and loop.
    """
    self._trace("turn_started", {"user_input_length": len(user_input)})
    self._push_user_input(user_input, images)

    summary = TurnSummary()
    iteration = 0
    cumulative_input_tokens = 0

    # Preflight: ensure we won't exceed the context window. Outside the try
    # below so callers receive ContextWindowExceededError itself.
    self._preflight_check()

    try:
        api_messages = self._build_api_messages()

        while iteration < self.max_iterations:
            iteration += 1
            summary.iterations = iteration

            self._trace("assistant_iteration_started", {"iteration": iteration})

            # Pre-check budget before making another API call
            if iteration > 1:
                spent = self._spent_if_over_budget()
                if spent is not None:
                    summary.text_output += (
                        f"\n\n[Budget reached: ${spent:.4f} "
                        f"of ${self.cost_budget_usd:.4f}. Stopping before next API call.]"
                    )
                    break

            # Stream one model response
            stream_result = await self._stream_model_response(api_messages)

            # Accumulate usage
            summary.usage += stream_result.usage
            self.usage_tracker.record_turn(stream_result.usage)
            cumulative_input_tokens += stream_result.usage.input_tokens

            # Collect prompt cache events
            if stream_result.prompt_cache_event:
                summary.prompt_cache_events.append(stream_result.prompt_cache_event)

            # Store the assistant message BEFORE any budget decision. The
            # response is already paid for and was streamed to the user.
            # Breaking first (as before) dropped it from both the session
            # and summary.text_output.
            assistant_msg = self._build_assistant_message(
                stream_result.text_parts, stream_result.tool_uses, stream_result.usage
            )
            if assistant_msg:
                self.session.push_message(assistant_msg)
                summary.assistant_messages.append(assistant_msg)

            summary.text_output += stream_result.full_text

            self._trace("assistant_iteration_completed", {
                "iteration": iteration,
                "tool_use_count": len(stream_result.tool_uses),
                "stop_reason": stream_result.stop_reason or "unknown",
            })

            # Check cost budget — soft stop (don't crash, just stop looping)
            spent = self._spent_if_over_budget()
            if spent is not None:
                logger.info(
                    "Cost budget reached: $%.4f >= $%.4f",
                    spent, self.cost_budget_usd,
                )
                if assistant_msg:
                    # The Messages API expects each stored tool_use to be
                    # answered by a tool_result. Close them out without
                    # running the tools.
                    for tu in stream_result.tool_uses:
                        skipped = self._make_tool_result(
                            tu["id"], tu["name"],
                            "Tool not executed: cost budget reached.",
                            is_error=True,
                        )
                        self.session.push_message(skipped)
                        summary.tool_results.append(skipped)
                summary.text_output += (
                    f"\n\n[Budget reached: ${spent:.4f} "
                    f"of ${self.cost_budget_usd:.4f}. "
                    f"Use /cost for details or restart with a higher --budget.]"
                )
                break  # Stop the tool loop gracefully

            # If no tool uses or stop_reason is end_turn, we're done
            if not stream_result.tool_uses or stream_result.stop_reason == "end_turn":
                break

            # Execute tools (with full hook integration)
            tool_result_messages = await self._execute_tools_with_hooks(
                stream_result.tool_uses
            )
            for trm in tool_result_messages:
                self.session.push_message(trm)
                summary.tool_results.append(trm)

            # Auto-compaction check
            compaction = self._maybe_auto_compact(cumulative_input_tokens)
            if compaction is not None:
                summary.compaction_result = compaction
                summary.was_auto_compacted = True
                logger.info(
                    "Auto-compacted session: %d -> %d estimated tokens",
                    compaction.estimated_tokens_before,
                    compaction.estimated_tokens_after,
                )

            # Rebuild API messages for next iteration
            api_messages = self._build_api_messages()

        else:
            # Loop ended without break -- max iterations exceeded
            logger.warning(
                "Turn reached max iterations (%d)", self.max_iterations
            )

    except Exception as exc:
        self._trace("turn_failed", {"error": str(exc)})
        raise ConversationError(
            f"Turn failed at iteration {iteration}: {exc}", cause=exc
        ) from exc

    self._trace("turn_completed", {
        "iteration": summary.iterations,
        "total_input_tokens": summary.usage.input_tokens,
        "total_output_tokens": summary.usage.output_tokens,
        "was_compacted": summary.was_auto_compacted,
    } if False else {
        "iterations": summary.iterations,
        "total_input_tokens": summary.usage.input_tokens,
        "total_output_tokens": summary.usage.output_tokens,
        "was_compacted": summary.was_auto_compacted,
    })

    return summary

def _push_user_input(
    self, user_input: str, images: list[tuple[str, str]] | None
) -> None:
    """Append the user's input to the session.

    Args:
        user_input: The user's text; omitted from an image message when empty.
        images: Optional (media_type, base64_data) tuples. When given, one
            USER message is pushed with the image blocks first, then the text.
    """
    if not images:
        self.session.push_user_text(user_input)
        return

    blocks: list[ContentBlock] = [
        ImageBlock(media_type=media_type, data=b64_data)
        for media_type, b64_data in images
    ]
    if user_input:
        blocks.append(TextBlock(text=user_input))
    self.session.push_message(
        ConversationMessage(role=MessageRole.USER, blocks=blocks)
    )

def _spent_if_over_budget(self) -> float | None:
    """Return the total USD spent if the cost budget is exhausted, else None.

    Always returns None when no budget is configured (``cost_budget_usd`` is None).
    """
    if self.cost_budget_usd is None:
        return None
    spent = self.usage_tracker.total.estimate_cost_usd().total_cost_usd()
    return spent if spent >= self.cost_budget_usd else None
|
|
515
|
+
|
|
516
|
+
# -- Streaming -----------------------------------------------------------
|
|
517
|
+
|
|
518
|
+
@dataclass
class _StreamResult:
    """Internal: what one streamed model response assembled into.

    Text and thinking arrive as ordered fragments. ``tool_uses`` holds dicts
    with ``id``, ``name`` and ``input`` (JSON text) keys, filled once the
    stream has ended.
    """

    # Text fragments, in arrival order.
    text_parts: list[str] = field(default_factory=list)
    # Thinking fragments, in arrival order.
    thinking_parts: list[str] = field(default_factory=list)
    # Assembled tool calls.
    tool_uses: list[dict[str, Any]] = field(default_factory=list)
    # Token usage reported by the stream.
    usage: TokenUsage = field(default_factory=TokenUsage)
    # Stop reason from the final message delta, if any.
    stop_reason: str | None = None
    # Prompt-cache stats, set only when the stream reported cache activity.
    prompt_cache_event: PromptCacheEvent | None = None

    @property
    def full_text(self) -> str:
        """All text fragments joined in arrival order."""
        joined = "".join(self.text_parts)
        return joined
|
|
531
|
+
|
|
532
|
+
def _build_tool_definitions(self) -> list[ToolDefinition] | None:
    """Describe every registered tool to the model.

    Returns:
        One ToolDefinition per tool in the global registry. Returns None when
        no tool executor is attached, or when the registry is empty; in that
        case the request carries no tools.
    """
    if self.tool_executor is None:
        return None

    # Imported only once tools are actually needed (presumably to avoid an
    # import cycle with axion.tools — confirm).
    from axion.tools.registry import get_tool_registry

    definitions = [
        ToolDefinition(
            name=entry.spec.name,
            description=entry.spec.description,
            input_schema=entry.spec.input_schema,
        )
        for entry in get_tool_registry().all_tools()
    ]
    return definitions or None
|
|
551
|
+
|
|
552
|
+
async def _stream_model_response(
|
|
553
|
+
self, api_messages: list[InputMessage]
|
|
554
|
+
) -> _StreamResult:
|
|
555
|
+
"""Stream a single model request and assemble the response."""
|
|
556
|
+
resolved_model = resolve_model_alias(self.model)
|
|
557
|
+
|
|
558
|
+
# Build tool definitions so the model knows what tools are available
|
|
559
|
+
tool_defs = self._build_tool_definitions()
|
|
560
|
+
|
|
561
|
+
request = MessageRequest(
|
|
562
|
+
model=resolved_model,
|
|
563
|
+
max_tokens=max_tokens_for_model(resolved_model),
|
|
564
|
+
messages=api_messages,
|
|
565
|
+
system=self.system_prompt or None,
|
|
566
|
+
tools=tool_defs,
|
|
567
|
+
tool_choice=ToolChoice.auto() if tool_defs else None,
|
|
568
|
+
stream=True,
|
|
569
|
+
)
|
|
570
|
+
|
|
571
|
+
result = ConversationRuntime._StreamResult()
|
|
572
|
+
current_tool_inputs: dict[int, list[str]] = {}
|
|
573
|
+
current_tool_blocks: dict[int, dict[str, Any]] = {}
|
|
574
|
+
|
|
575
|
+
async for event in self.provider.stream_message(request):
|
|
576
|
+
match event:
|
|
577
|
+
case MessageStartEvent(message=msg) if msg is not None:
|
|
578
|
+
result.usage.input_tokens = msg.usage.input_tokens
|
|
579
|
+
result.usage.cache_creation_input_tokens = (
|
|
580
|
+
msg.usage.cache_creation_input_tokens
|
|
581
|
+
)
|
|
582
|
+
result.usage.cache_read_input_tokens = (
|
|
583
|
+
msg.usage.cache_read_input_tokens
|
|
584
|
+
)
|
|
585
|
+
# Collect prompt cache event
|
|
586
|
+
if (
|
|
587
|
+
msg.usage.cache_creation_input_tokens > 0
|
|
588
|
+
or msg.usage.cache_read_input_tokens > 0
|
|
589
|
+
):
|
|
590
|
+
result.prompt_cache_event = PromptCacheEvent(
|
|
591
|
+
cache_creation_input_tokens=msg.usage.cache_creation_input_tokens,
|
|
592
|
+
cache_read_input_tokens=msg.usage.cache_read_input_tokens,
|
|
593
|
+
timestamp_ms=int(time.time() * 1000),
|
|
594
|
+
)
|
|
595
|
+
|
|
596
|
+
case ContentBlockStartEvent(index=idx, content_block=block):
|
|
597
|
+
if isinstance(block, ToolUseOutputBlock):
|
|
598
|
+
current_tool_blocks[idx] = {
|
|
599
|
+
"id": block.id,
|
|
600
|
+
"name": block.name,
|
|
601
|
+
}
|
|
602
|
+
current_tool_inputs[idx] = []
|
|
603
|
+
|
|
604
|
+
case ContentBlockDeltaEvent(index=idx, delta=delta):
|
|
605
|
+
if isinstance(delta, TextDelta) and delta.text:
|
|
606
|
+
result.text_parts.append(delta.text)
|
|
607
|
+
self._emit_event(AssistantTextDelta(text=delta.text))
|
|
608
|
+
if self.on_text_delta:
|
|
609
|
+
self.on_text_delta(delta.text)
|
|
610
|
+
elif isinstance(delta, InputJsonDelta):
|
|
611
|
+
if idx in current_tool_inputs:
|
|
612
|
+
current_tool_inputs[idx].append(delta.partial_json)
|
|
613
|
+
elif isinstance(delta, ThinkingDelta) and delta.thinking:
|
|
614
|
+
result.thinking_parts.append(delta.thinking)
|
|
615
|
+
if self.on_thinking:
|
|
616
|
+
try:
|
|
617
|
+
self.on_thinking(delta.thinking)
|
|
618
|
+
except Exception:
|
|
619
|
+
pass
|
|
620
|
+
|
|
621
|
+
case MessageDeltaEvent(delta=d, usage=u):
|
|
622
|
+
result.usage.output_tokens = u.output_tokens
|
|
623
|
+
result.stop_reason = d.stop_reason
|
|
624
|
+
|
|
625
|
+
case MessageStopEvent():
|
|
626
|
+
self._emit_event(
|
|
627
|
+
AssistantMessageStop(stop_reason=result.stop_reason)
|
|
628
|
+
)
|
|
629
|
+
|
|
630
|
+
# Assemble completed tool uses
|
|
631
|
+
for idx, block_info in current_tool_blocks.items():
|
|
632
|
+
input_json = "".join(current_tool_inputs.get(idx, []))
|
|
633
|
+
tool_use = {
|
|
634
|
+
"id": block_info["id"],
|
|
635
|
+
"name": block_info["name"],
|
|
636
|
+
"input": input_json,
|
|
637
|
+
}
|
|
638
|
+
result.tool_uses.append(tool_use)
|
|
639
|
+
self._emit_event(
|
|
640
|
+
AssistantToolUse(
|
|
641
|
+
id=tool_use["id"],
|
|
642
|
+
name=tool_use["name"],
|
|
643
|
+
input=input_json,
|
|
644
|
+
)
|
|
645
|
+
)
|
|
646
|
+
|
|
647
|
+
# Emit usage event
|
|
648
|
+
self._emit_event(AssistantUsage(usage=result.usage))
|
|
649
|
+
|
|
650
|
+
# Emit prompt cache event if present
|
|
651
|
+
if result.prompt_cache_event:
|
|
652
|
+
self._emit_event(AssistantPromptCache(
|
|
653
|
+
cache_creation_input_tokens=result.prompt_cache_event.cache_creation_input_tokens,
|
|
654
|
+
cache_read_input_tokens=result.prompt_cache_event.cache_read_input_tokens,
|
|
655
|
+
))
|
|
656
|
+
|
|
657
|
+
return result
|
|
658
|
+
|
|
659
|
+
# -- Tool execution with hooks -------------------------------------------
|
|
660
|
+
|
|
661
|
+
async def _execute_tools_with_hooks(
    self, tool_uses: list[dict[str, Any]]
) -> list[ConversationMessage]:
    """Run the model's requested tool calls and collect their result messages.

    Execution order:
    - Tools other than ``Agent`` run first, one at a time in request order,
      because they may touch shared state (filesystem, session, etc.).
    - ``Agent`` calls run afterwards. A single one is awaited directly.
    - Several Agent calls run concurrently via ``asyncio.gather``. In that
      case a failing call becomes an error tool result instead of aborting
      the batch.

    Returns:
        The sequential tools' results (in order), followed by the Agent
        results (in order).
    """
    import asyncio

    parallel_names = {"Agent"}
    agent_calls = [tu for tu in tool_uses if tu["name"] in parallel_names]
    other_calls = [tu for tu in tool_uses if tu["name"] not in parallel_names]

    collected: list[ConversationMessage] = []

    # Sequential tools (file ops, bash, etc.) first.
    for call in other_calls:
        collected.append(await self._execute_single_tool(call))

    if len(agent_calls) == 1:
        collected.append(await self._execute_single_tool(agent_calls[0]))
    elif agent_calls:
        logger.info(
            "Executing %d Agent calls in parallel", len(agent_calls)
        )
        outcomes = await asyncio.gather(
            *(self._execute_single_tool(call) for call in agent_calls),
            return_exceptions=True,
        )
        for call, outcome in zip(agent_calls, outcomes):
            if isinstance(outcome, BaseException):
                collected.append(
                    self._make_tool_result(
                        call["id"], call["name"],
                        f"Agent execution failed: {outcome}",
                        is_error=True,
                    )
                )
            elif isinstance(outcome, ConversationMessage):
                collected.append(outcome)

    return collected
|
|
716
|
+
|
|
717
|
+
async def _execute_single_tool(
    self, tu: dict[str, Any]
) -> ConversationMessage:
    """Execute a single tool call with full hook integration.

    Stages, in order:

    1. Pre-tool-use hooks may deny the call, replace its input, or supply a
       permission override.
    2. In plan mode, tools rejected by ``is_tool_allowed_in_plan_mode`` are
       refused. This runs before the permission check so the user is never
       prompted (and no approval is cached) for a call that would be
       blocked anyway.
    3. Permission resolution via ``_resolve_permission``.
    4. Execution through ``self.tool_executor``; a raised exception becomes
       an error result and triggers the post-tool-use-failure hooks.
    5. On success, post-tool-use hooks may append feedback or retroactively
       mark the result as an error.

    Every return path records a ``tool_execution_finished`` trace event.
    Exceptions raised by the ``on_tool_use`` / ``on_tool_result`` callbacks
    are logged at debug level and otherwise ignored.

    Args:
        tu: Tool-use dict with ``"id"``, ``"name"`` and ``"input"`` keys.

    Returns:
        The tool-result message answering this tool use.
    """
    tool_name = tu["name"]
    tool_input = tu["input"]
    tool_id = tu["id"]

    def finish(output, *, is_error, outcome):
        # Build the result message, then record the matching finish trace.
        result_msg = self._make_tool_result(
            tool_id, tool_name, output, is_error=is_error
        )
        self._trace("tool_execution_finished", {
            "tool_name": tool_name,
            "tool_use_id": tool_id,
            "outcome": outcome,
        })
        return result_msg

    self._trace("tool_execution_started", {
        "tool_name": tool_name,
        "tool_use_id": tool_id,
    })

    # ---- Phase 1: Pre-tool-use hooks ----
    effective_input = tool_input
    permission_override: PermissionOverride | None = None

    if self.hook_runner:
        pre_result = await self.hook_runner.run_pre_tool_use(
            tool_name, tool_input
        )

        # Hook denied execution outright
        if pre_result.denied:
            deny_reason = "; ".join(pre_result.messages) or "Denied by pre-tool-use hook"
            return finish(
                f"Hook denied: {deny_reason}",
                is_error=True,
                outcome="hook_denied",
            )

        # Hook may have updated the input
        if pre_result.updated_input is not None:
            effective_input = pre_result.updated_input
            logger.debug(
                "Pre-hook updated input for tool '%s'", tool_name
            )

        # Hook may have set a permission override
        if pre_result.permission_override is not None:
            try:
                permission_override = PermissionOverride(
                    pre_result.permission_override
                )
            except ValueError:
                logger.warning(
                    "Invalid permission_override from hook: %s",
                    pre_result.permission_override,
                )

    # ---- Phase 2: Plan mode gate (before any permission prompt) ----
    if self.plan_mode_active:
        from axion.runtime.plan_mode import (
            get_plan_mode_denial_message,
            is_tool_allowed_in_plan_mode,
        )
        if not is_tool_allowed_in_plan_mode(tool_name):
            return finish(
                get_plan_mode_denial_message(tool_name),
                is_error=True,
                outcome="plan_mode_denied",
            )

    # ---- Phase 3: Permission check ----
    permission_outcome = await self._resolve_permission(
        tool_name, effective_input, permission_override
    )
    if isinstance(permission_outcome, PermissionDeny):
        return finish(
            f"Permission denied: {permission_outcome.reason}",
            is_error=True,
            outcome="permission_denied",
        )

    # ---- Phase 4: Execute tool ----
    # Notify caller that tool is about to execute (best effort).
    if self.on_tool_use is not None:
        try:
            self.on_tool_use(tool_name, effective_input)
        except Exception:
            logger.debug("on_tool_use callback raised", exc_info=True)

    if self.tool_executor is None:
        output = f"No tool executor configured for '{tool_name}'"
        is_error = True
    else:
        try:
            output = await self.tool_executor.execute(
                tool_name, effective_input
            )
            is_error = False
        except Exception as exc:
            output = f"Tool error: {exc}"
            is_error = True
            logger.warning("Tool '%s' failed: %s", tool_name, exc)

            # ---- Phase 4b: Post-tool-use-failure hooks ----
            if self.hook_runner:
                fail_result = await self.hook_runner.run_post_tool_use_failure(
                    tool_name, effective_input, str(exc)
                )
                if fail_result.messages:
                    output = self._merge_hook_feedback(
                        output, fail_result.messages
                    )

    # ---- Phase 5: Post-tool-use hooks (on success) ----
    if not is_error and self.hook_runner:
        post_result = await self.hook_runner.run_post_tool_use(
            tool_name, effective_input, output, is_error=False
        )

        # Post-hook can retroactively mark as error
        if post_result.denied:
            is_error = True
            deny_reason = (
                "; ".join(post_result.messages)
                or "Retroactively denied by post-tool-use hook"
            )
            output = f"Post-hook error: {deny_reason}\nOriginal output: {output}"
        elif post_result.messages:
            output = self._merge_hook_feedback(output, post_result.messages)

    # Notify caller of tool result (best effort).
    if self.on_tool_result is not None:
        try:
            self.on_tool_result(tool_name, output, is_error)
        except Exception:
            logger.debug("on_tool_result callback raised", exc_info=True)

    return finish(
        output,
        is_error=is_error,
        outcome="error" if is_error else "success",
    )
|
|
866
|
+
|
|
867
|
+
# -- Permission resolution -----------------------------------------------
|
|
868
|
+
|
|
869
|
+
async def _resolve_permission(
    self,
    tool_name: str,
    tool_input: str,
    hook_override: PermissionOverride | None,
) -> PermissionOutcome:
    """Resolve permission for a tool call, respecting hook overrides.

    Priority:
      1. A hook override of ALLOW or DENY short-circuits everything else.
      2. Policy-based authorization (``permission_policy.authorize``).
      3. The interactive prompter is consulted when the policy answers with
         a deny whose reason starts with ``__NEEDS_PROMPT__``, or when a
         hook requested ASK and the policy would otherwise allow the call.
         An ASK override never upgrades a hard deny from the policy.
      4. An interactive ALLOW is recorded via ``remember_decision``.

    When the policy wants a prompt but no prompter is configured, the call
    is denied with a readable reason instead of the internal sentinel. An
    ASK override with no prompter keeps the policy's decision.

    Args:
        tool_name: Name of the tool being invoked.
        tool_input: The (possibly hook-updated) tool input JSON string.
        hook_override: Permission override supplied by a pre-tool-use hook.

    Returns:
        ``PermissionAllow`` or ``PermissionDeny`` for this call.
    """
    needs_prompt_marker = "__NEEDS_PROMPT__"

    if hook_override == PermissionOverride.ALLOW:
        return PermissionAllow()
    if hook_override == PermissionOverride.DENY:
        return PermissionDeny(reason="Denied by hook permission override")

    outcome = self.permission_policy.authorize(tool_name, tool_input)

    # The policy signals "ask the user" with a sentinel-prefixed deny reason.
    policy_wants_prompt = (
        isinstance(outcome, PermissionDeny)
        and outcome.reason.startswith(needs_prompt_marker)
    )
    # A hook ASK forces a prompt for calls the policy would allow outright;
    # it never overrides a hard (non-sentinel) deny from the policy.
    hook_wants_prompt = (
        hook_override == PermissionOverride.ASK
        and not isinstance(outcome, PermissionDeny)
    )

    if not (policy_wants_prompt or hook_wants_prompt):
        return outcome

    if self.permission_prompter is None:
        if hook_wants_prompt:
            # Nobody to ask: keep the policy's (allowing) decision.
            return outcome
        # Don't leak the internal sentinel into the model-visible reason.
        detail = outcome.reason[len(needs_prompt_marker):].strip(" :")
        reason = (
            f"Tool '{tool_name}' requires approval, "
            "but no interactive prompter is available"
        )
        if detail:
            reason = f"{reason} ({detail})"
        return PermissionDeny(reason=reason)

    request = PermissionRequest(
        tool_name=tool_name,
        input_json=tool_input,
        current_mode=self.permission_policy.mode,
        required_mode=TOOL_PERMISSION_REQUIREMENTS.get(
            tool_name, PermissionMode.WORKSPACE_WRITE
        ),
        reason=(
            f"Hook requested approval for tool '{tool_name}'"
            if hook_wants_prompt
            else f"Tool '{tool_name}' requires approval"
        ),
    )
    decision = await self.permission_prompter.decide(request)
    if decision == PermissionPromptDecision.ALLOW:
        # Record the approval via remember_decision (intended to avoid
        # re-prompting for this tool).
        allowed = PermissionAllow()
        self.permission_policy.remember_decision(tool_name, allowed)
        return allowed
    return PermissionDeny(reason=f"User denied '{tool_name}'")
|
|
916
|
+
|
|
917
|
+
# -- Auto-compaction -----------------------------------------------------
|
|
918
|
+
|
|
919
|
+
def _maybe_auto_compact(self, cumulative_input_tokens: int) -> CompactionResult | None:
    """Compact the session once cumulative input tokens reach the threshold.

    Args:
        cumulative_input_tokens: Input tokens consumed so far.

    Returns:
        Whatever ``compact_session`` returned (possibly ``None``) when the
        threshold was reached, otherwise ``None`` without compacting.
    """
    threshold = self.auto_compaction_threshold
    if cumulative_input_tokens >= threshold:
        compaction = compact_session(
            self.session, CompactionConfig(max_tokens=threshold)
        )
        if compaction is not None:
            self._trace("session_auto_compacted", dict(
                tokens_before=compaction.estimated_tokens_before,
                tokens_after=compaction.estimated_tokens_after,
                removed_count=compaction.removed_count,
            ))
        return compaction
    return None
|
|
935
|
+
|
|
936
|
+
# -- Message building helpers --------------------------------------------
|
|
937
|
+
|
|
938
|
+
@staticmethod
def _build_assistant_message(
    text_parts: list[str],
    tool_uses: list[dict[str, Any]],
    usage: TokenUsage,
) -> ConversationMessage | None:
    """Build the assistant message for one streamed turn.

    The text fragments are joined into a single leading ``TextBlock``
    (omitted when the joined text is empty), followed by one
    ``ToolUseBlock`` per tool use, in order.

    Returns:
        The assembled message, or ``None`` when there is neither text nor
        any tool use.
    """
    text = "".join(text_parts)
    leading: list[ContentBlock] = [TextBlock(text=text)] if text else []
    tool_blocks = [
        ToolUseBlock(id=call["id"], name=call["name"], input=call["input"])
        for call in tool_uses
    ]
    content = leading + tool_blocks
    if not content:
        return None
    return ConversationMessage(
        role=MessageRole.ASSISTANT, blocks=content, usage=usage
    )
|
|
963
|
+
|
|
964
|
+
@staticmethod
def _make_tool_result(
    tool_use_id: str,
    tool_name: str,
    output: str,
    *,
    is_error: bool = False,
) -> ConversationMessage:
    """Wrap a tool's output as a user-role message with one ToolResultBlock.

    Args:
        tool_use_id: ID of the tool-use block this result answers.
        tool_name: Name of the tool that produced ``output``.
        output: Text returned to the model.
        is_error: Whether the result is flagged as an error.
    """
    result_block = ToolResultBlock(
        tool_use_id=tool_use_id,
        tool_name=tool_name,
        output=output,
        is_error=is_error,
    )
    return ConversationMessage(role=MessageRole.USER, blocks=[result_block])
|
|
984
|
+
|
|
985
|
+
@staticmethod
def _merge_hook_feedback(output: str, hook_messages: list[str]) -> str:
    """Append each non-empty hook message to ``output`` as a ``[hook]`` line.

    The feedback lines are separated from the original output by a blank
    line. When no message is non-empty, ``output`` is returned unchanged.
    """
    lines = [f"[hook] {message}" for message in hook_messages if message]
    if not lines:
        return output
    return f"{output}\n\n" + "\n".join(lines)
|
|
992
|
+
|
|
993
|
+
# -- API message conversion ----------------------------------------------
|
|
994
|
+
|
|
995
|
+
def _build_api_messages(self) -> list[InputMessage]:
|
|
996
|
+
"""Convert session messages to API input format."""
|
|
997
|
+
from axion.api.types import (
|
|
998
|
+
ImageInputBlock,
|
|
999
|
+
TextInputBlock,
|
|
1000
|
+
ToolResultTextContent,
|
|
1001
|
+
ToolUseInputBlock,
|
|
1002
|
+
)
|
|
1003
|
+
from axion.api.types import (
|
|
1004
|
+
ToolResultBlock as ApiToolResultBlock,
|
|
1005
|
+
)
|
|
1006
|
+
|
|
1007
|
+
api_messages: list[InputMessage] = []
|
|
1008
|
+
|
|
1009
|
+
for msg in self.session.messages:
|
|
1010
|
+
blocks = []
|
|
1011
|
+
for block in msg.blocks:
|
|
1012
|
+
match block:
|
|
1013
|
+
case TextBlock(text=text):
|
|
1014
|
+
blocks.append(TextInputBlock(text=text))
|
|
1015
|
+
case ImageBlock(media_type=mt, data=data):
|
|
1016
|
+
blocks.append(ImageInputBlock(media_type=mt, data=data))
|
|
1017
|
+
case ToolUseBlock(id=tid, name=name, input=inp):
|
|
1018
|
+
try:
|
|
1019
|
+
parsed = json.loads(inp) if inp else {}
|
|
1020
|
+
except json.JSONDecodeError:
|
|
1021
|
+
parsed = {"raw": inp}
|
|
1022
|
+
blocks.append(
|
|
1023
|
+
ToolUseInputBlock(id=tid, name=name, input=parsed)
|
|
1024
|
+
)
|
|
1025
|
+
case ToolResultBlock() as tr:
|
|
1026
|
+
blocks.append(
|
|
1027
|
+
ApiToolResultBlock(
|
|
1028
|
+
tool_use_id=tr.tool_use_id,
|
|
1029
|
+
content=[ToolResultTextContent(text=tr.output)],
|
|
1030
|
+
is_error=tr.is_error,
|
|
1031
|
+
)
|
|
1032
|
+
)
|
|
1033
|
+
|
|
1034
|
+
if blocks:
|
|
1035
|
+
role = "assistant" if msg.role == MessageRole.ASSISTANT else "user"
|
|
1036
|
+
api_messages.append(InputMessage(role=role, content=blocks))
|
|
1037
|
+
|
|
1038
|
+
return api_messages
|
|
1039
|
+
|
|
1040
|
+
# -- Tracing / events ----------------------------------------------------
|
|
1041
|
+
|
|
1042
|
+
def _trace(self, name: str, attributes: dict[str, Any] | None = None) -> None:
|
|
1043
|
+
"""Record a trace event if a session tracer is attached."""
|
|
1044
|
+
if self.session_tracer is not None:
|
|
1045
|
+
self.session_tracer.record(name, attributes)
|
|
1046
|
+
|
|
1047
|
+
def _emit_event(self, event: AssistantEvent) -> None:
    """Deliver ``event`` to ``self.on_event`` if a callback is registered.

    Delivery is best-effort. Any exception raised by the callback is logged
    at debug level (with traceback) and suppressed rather than propagated.
    """
    callback = self.on_event
    if callback is None:
        return
    try:
        callback(event)
    except Exception:
        logger.debug("on_event callback raised", exc_info=True)
|
|
1054
|
+
|
|
1055
|
+
|
|
1056
|
+
# ---------------------------------------------------------------------------
|
|
1057
|
+
# Module-level helpers
|
|
1058
|
+
# ---------------------------------------------------------------------------
|
|
1059
|
+
|
|
1060
|
+
def _resolve_compaction_threshold() -> int:
    """Return the auto-compaction threshold.

    The value is read from the environment variable named by
    ``_ENV_COMPACTION_KEY`` and parsed with ``int()``. If the variable is
    unset, or cannot be parsed (which logs a warning), the result is
    ``DEFAULT_AUTO_COMPACTION_THRESHOLD``.
    """
    raw = os.environ.get(_ENV_COMPACTION_KEY)
    if raw is None:
        return DEFAULT_AUTO_COMPACTION_THRESHOLD
    try:
        return int(raw)
    except ValueError:
        logger.warning(
            "Invalid %s value '%s', using default %d",
            _ENV_COMPACTION_KEY,
            raw,
            DEFAULT_AUTO_COMPACTION_THRESHOLD,
        )
    return DEFAULT_AUTO_COMPACTION_THRESHOLD
|