axion-code 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- axion/__init__.py +3 -0
- axion/api/__init__.py +0 -0
- axion/api/anthropic.py +460 -0
- axion/api/client.py +259 -0
- axion/api/error.py +161 -0
- axion/api/ollama.py +597 -0
- axion/api/openai_compat.py +805 -0
- axion/api/openai_responses.py +627 -0
- axion/api/prompt_cache.py +31 -0
- axion/api/sse.py +98 -0
- axion/api/types.py +451 -0
- axion/cli/__init__.py +0 -0
- axion/cli/init_cmd.py +50 -0
- axion/cli/input.py +290 -0
- axion/cli/main.py +2953 -0
- axion/cli/render.py +489 -0
- axion/cli/tui.py +766 -0
- axion/commands/__init__.py +0 -0
- axion/commands/handlers/__init__.py +0 -0
- axion/commands/handlers/agents.py +51 -0
- axion/commands/handlers/builtin_commands.py +367 -0
- axion/commands/handlers/mcp.py +59 -0
- axion/commands/handlers/models.py +75 -0
- axion/commands/handlers/plugins.py +55 -0
- axion/commands/handlers/skills.py +61 -0
- axion/commands/parsing.py +317 -0
- axion/commands/registry.py +166 -0
- axion/compat_harness/__init__.py +0 -0
- axion/compat_harness/extractor.py +145 -0
- axion/plugins/__init__.py +0 -0
- axion/plugins/hooks.py +22 -0
- axion/plugins/manager.py +391 -0
- axion/plugins/manifest.py +270 -0
- axion/runtime/__init__.py +0 -0
- axion/runtime/bash.py +388 -0
- axion/runtime/bootstrap.py +39 -0
- axion/runtime/claude_subscription.py +300 -0
- axion/runtime/compact.py +233 -0
- axion/runtime/config.py +397 -0
- axion/runtime/conversation.py +1073 -0
- axion/runtime/file_ops.py +613 -0
- axion/runtime/git.py +213 -0
- axion/runtime/hooks.py +235 -0
- axion/runtime/image.py +212 -0
- axion/runtime/lanes.py +282 -0
- axion/runtime/lsp.py +425 -0
- axion/runtime/mcp/__init__.py +0 -0
- axion/runtime/mcp/client.py +76 -0
- axion/runtime/mcp/lifecycle.py +96 -0
- axion/runtime/mcp/stdio.py +318 -0
- axion/runtime/mcp/tool_bridge.py +79 -0
- axion/runtime/memory.py +196 -0
- axion/runtime/oauth.py +329 -0
- axion/runtime/openai_subscription.py +346 -0
- axion/runtime/permissions.py +247 -0
- axion/runtime/plan_mode.py +96 -0
- axion/runtime/policy_engine.py +259 -0
- axion/runtime/prompt.py +586 -0
- axion/runtime/recovery.py +261 -0
- axion/runtime/remote.py +28 -0
- axion/runtime/sandbox.py +68 -0
- axion/runtime/scheduler.py +231 -0
- axion/runtime/session.py +365 -0
- axion/runtime/sharing.py +159 -0
- axion/runtime/skills.py +124 -0
- axion/runtime/tasks.py +258 -0
- axion/runtime/usage.py +241 -0
- axion/runtime/workers.py +186 -0
- axion/telemetry/__init__.py +0 -0
- axion/telemetry/events.py +67 -0
- axion/telemetry/profile.py +49 -0
- axion/telemetry/sink.py +60 -0
- axion/telemetry/tracer.py +95 -0
- axion/tools/__init__.py +0 -0
- axion/tools/lane_completion.py +33 -0
- axion/tools/registry.py +853 -0
- axion/tools/tool_search.py +226 -0
- axion_code-1.0.0.dist-info/METADATA +709 -0
- axion_code-1.0.0.dist-info/RECORD +82 -0
- axion_code-1.0.0.dist-info/WHEEL +4 -0
- axion_code-1.0.0.dist-info/entry_points.txt +2 -0
- axion_code-1.0.0.dist-info/licenses/LICENSE +21 -0
axion/runtime/tasks.py
ADDED
|
@@ -0,0 +1,258 @@
|
|
|
1
|
+
"""Task packet, registry, team assignment, and cron scheduling.
|
|
2
|
+
|
|
3
|
+
Maps to: rust/crates/runtime/src/task_packet.rs + team/cron registries
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
import time
|
|
9
|
+
from dataclasses import dataclass, field
|
|
10
|
+
|
|
11
|
+
# ---------------------------------------------------------------------------
|
|
12
|
+
# Task packet
|
|
13
|
+
# ---------------------------------------------------------------------------
|
|
14
|
+
|
|
15
|
+
@dataclass
class TaskPacket:
    """Describe one autonomous task and the policies that govern it.

    Only ``objective`` has no default.  ``validate_packet`` in this module
    additionally requires a non-empty ``scope`` and restricts the three
    ``*_policy`` fields to fixed sets of values.
    """

    # What the task is expected to achieve.
    objective: str
    # Boundaries of the work; empty by default.
    scope: str = ""
    # Repository the task applies to; empty by default.
    repo: str = ""
    # Branching strategy name.
    branch_policy: str = "feature-branch"
    # Acceptance test entries; a fresh list is created per instance.
    acceptance_tests: list[str] = field(default_factory=list)
    # Commit strategy name.
    commit_policy: str = "atomic"
    # Free-form reporting contract; empty by default.
    reporting_contract: str = ""
    # Escalation strategy name.
    escalation_policy: str = "alert_human"
    # Integer priority; ordering semantics are not defined in this module.
    priority: int = 0
    # Labels compared against team tags during assignment.
    tags: list[str] = field(default_factory=list)
    # Optional team name that restricts assignment to that team.
    assigned_team: str | None = None
    # Optional cron expression.
    cron_schedule: str | None = None
    # Creation time in milliseconds since the epoch, captured at construction.
    created_at_ms: int = field(default_factory=lambda: int(time.time() * 1000))
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
@dataclass
class TaskPacketValidationError:
    """Carry the list of problems found while validating a task packet.

    This is a plain data container returned by ``validate_packet``; it does
    not derive from ``Exception``.
    """

    # One human-readable message per failed check.
    errors: list[str]
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def validate_packet(packet: TaskPacket) -> TaskPacketValidationError | None:
    """Check a task packet and report every problem found.

    Returns ``None`` when the packet passes all checks, otherwise a
    ``TaskPacketValidationError`` whose ``errors`` list holds one message per
    failed check, in the order the checks run.
    """
    problems: list[str] = []

    # Text fields that must be non-empty.
    for required in ("objective", "scope"):
        if not getattr(packet, required):
            problems.append(f"{required} is required")

    # Policy fields restricted to a closed set of values.
    allowed_by_field = (
        ("branch_policy", ("feature-branch", "direct", "trunk")),
        ("commit_policy", ("atomic", "squash", "incremental")),
        ("escalation_policy", ("alert_human", "log_and_continue", "abort")),
    )
    for policy_field, choices in allowed_by_field:
        chosen = getattr(packet, policy_field)
        if chosen not in choices:
            problems.append(f"invalid {policy_field}: {chosen}")

    if not problems:
        return None
    return TaskPacketValidationError(errors=problems)
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
# ---------------------------------------------------------------------------
|
|
56
|
+
# Task registry
|
|
57
|
+
# ---------------------------------------------------------------------------
|
|
58
|
+
|
|
59
|
+
class TaskStatus:
    """String constants naming the states a registry task can be in."""

    # Created but not yet started.
    PENDING = "pending"
    # Started by a worker.
    RUNNING = "running"
    # Finished with a result.
    COMPLETED = "completed"
    # Finished with an error.
    FAILED = "failed"
    # Cancelled before finishing.
    CANCELLED = "cancelled"
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
@dataclass
class TaskEntry:
    """Pair a task packet with the tracking state kept by the registry."""

    # Registry-assigned identifier.
    task_id: str
    # The specification this entry tracks.
    packet: TaskPacket
    # Current lifecycle state; one of the TaskStatus constants.
    status: str = TaskStatus.PENDING
    # Worker that started the task, once started.
    worker_id: str | None = None
    # Millisecond timestamp set when the task is started.
    started_at_ms: int | None = None
    # Millisecond timestamp set when the task completes, fails or is cancelled.
    completed_at_ms: int | None = None
    # Result text recorded on completion.
    result: str = ""
    # Error text recorded on failure.
    error: str = ""
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
class TaskRegistry:
    """In-memory registry that tracks each task through its lifecycle.

    Maps to: rust/crates/runtime/src/task_packet.rs (registry)

    Lifecycle enforced by the methods below: a task is created PENDING,
    ``start_task`` moves it to RUNNING, ``complete_task`` / ``fail_task`` move
    a RUNNING task to COMPLETED / FAILED, and ``cancel_task`` moves any task
    that has not yet reached a terminal state to CANCELLED.  Every transition
    method returns ``False`` (and changes nothing) when the task is unknown or
    is not in a state that allows the transition.
    """

    def __init__(self) -> None:
        self._tasks: dict[str, TaskEntry] = {}
        # Only ever incremented, so IDs are not reused after remove().
        self._counter = 0

    @staticmethod
    def _now_ms() -> int:
        """Return the current wall-clock time in milliseconds."""
        return int(time.time() * 1000)

    def create(self, packet: TaskPacket) -> str:
        """Create a new PENDING task for ``packet`` and return its ID."""
        self._counter += 1
        task_id = f"task-{self._counter:04d}"
        self._tasks[task_id] = TaskEntry(task_id=task_id, packet=packet)
        return task_id

    def get(self, task_id: str) -> TaskEntry | None:
        """Return the entry for ``task_id``, or ``None`` if unknown."""
        return self._tasks.get(task_id)

    def all_tasks(self) -> list[TaskEntry]:
        """Return every entry, in creation order."""
        return list(self._tasks.values())

    def pending_tasks(self) -> list[TaskEntry]:
        """Return the entries whose status is PENDING."""
        return [t for t in self._tasks.values() if t.status == TaskStatus.PENDING]

    def running_tasks(self) -> list[TaskEntry]:
        """Return the entries whose status is RUNNING."""
        return [t for t in self._tasks.values() if t.status == TaskStatus.RUNNING]

    def start_task(self, task_id: str, worker_id: str) -> bool:
        """Move a PENDING task to RUNNING and record who started it and when."""
        task = self._tasks.get(task_id)
        if task is None or task.status != TaskStatus.PENDING:
            return False
        task.status = TaskStatus.RUNNING
        task.worker_id = worker_id
        task.started_at_ms = self._now_ms()
        return True

    def complete_task(self, task_id: str, result: str = "") -> bool:
        """Move a RUNNING task to COMPLETED and store ``result``."""
        task = self._tasks.get(task_id)
        if task is None or task.status != TaskStatus.RUNNING:
            return False
        task.status = TaskStatus.COMPLETED
        task.result = result
        task.completed_at_ms = self._now_ms()
        return True

    def fail_task(self, task_id: str, error: str = "") -> bool:
        """Move a RUNNING task to FAILED and store ``error``."""
        task = self._tasks.get(task_id)
        if task is None or task.status != TaskStatus.RUNNING:
            return False
        task.status = TaskStatus.FAILED
        task.error = error
        task.completed_at_ms = self._now_ms()
        return True

    def cancel_task(self, task_id: str) -> bool:
        """Cancel a PENDING or RUNNING task.

        Returns ``False`` for unknown tasks and for tasks that already reached
        a terminal state.  FAILED is treated as terminal alongside COMPLETED
        and CANCELLED so that cancelling cannot overwrite a recorded failure
        or its completion timestamp.
        """
        terminal = (TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED)
        task = self._tasks.get(task_id)
        if task is None or task.status in terminal:
            return False
        task.status = TaskStatus.CANCELLED
        task.completed_at_ms = self._now_ms()
        return True

    def remove(self, task_id: str) -> bool:
        """Drop ``task_id`` from the registry; return whether it existed."""
        return self._tasks.pop(task_id, None) is not None

    def summary(self) -> dict[str, int]:
        """Return a mapping of status string to number of tasks in it."""
        counts: dict[str, int] = {}
        for task in self._tasks.values():
            counts[task.status] = counts.get(task.status, 0) + 1
        return counts
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
# ---------------------------------------------------------------------------
|
|
156
|
+
# Team registry
|
|
157
|
+
# ---------------------------------------------------------------------------
|
|
158
|
+
|
|
159
|
+
@dataclass
class Team:
    """Describe a named group of workers that tasks can be assigned to."""

    # Name the team is registered and looked up under.
    name: str
    # IDs of the workers belonging to the team; fresh list per instance.
    worker_ids: list[str] = field(default_factory=list)
    # Concurrency limit; not enforced anywhere in this module.
    max_concurrent: int = 1
    # Labels matched against a task packet's tags during assignment.
    tags: list[str] = field(default_factory=list)
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
class TeamRegistry:
    """Keep teams by name and pick a team for a task."""

    def __init__(self) -> None:
        self._teams: dict[str, Team] = {}

    def register(self, team: Team) -> None:
        """Store ``team`` under its name, replacing any team of that name."""
        self._teams[team.name] = team

    def get(self, name: str) -> Team | None:
        """Return the team registered as ``name``, or ``None``."""
        return self._teams.get(name)

    def all_teams(self) -> list[Team]:
        """Return every registered team, in registration order."""
        return list(self._teams.values())

    def assign_task(self, task: TaskEntry) -> str | None:
        """Return the name of the first registered team that suits ``task``.

        A team suits the task when the packet names no ``assigned_team`` or
        names this team, and the packet has no tags or shares at least one
        tag with the team.  Returns ``None`` when no team qualifies.  Team
        load (``max_concurrent``) is not consulted here.
        """

        def suits(team: Team) -> bool:
            packet = task.packet
            # An explicit team request excludes every other team.
            if packet.assigned_team and packet.assigned_team != team.name:
                return False
            # A tagged task needs at least one tag in common with the team.
            if packet.tags and all(tag not in team.tags for tag in packet.tags):
                return False
            return True

        return next(
            (team.name for team in self._teams.values() if suits(team)),
            None,
        )
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
# ---------------------------------------------------------------------------
|
|
197
|
+
# Cron registry
|
|
198
|
+
# ---------------------------------------------------------------------------
|
|
199
|
+
|
|
200
|
+
@dataclass
class CronEntry:
    """Describe a recurring task together with its run bookkeeping."""

    # Registry-assigned identifier.
    cron_id: str
    # Cron expression (e.g. "*/5 * * * *")
    schedule: str
    # Task specification associated with this schedule.
    packet: TaskPacket
    # Whether the entry is currently enabled.
    enabled: bool = True
    # Millisecond timestamp of the most recent recorded run; 0 initially.
    last_run_ms: int = 0
    # Millisecond timestamp of the next run; nothing in this module sets it.
    next_run_ms: int = 0
    # Number of runs recorded so far.
    run_count: int = 0
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
class CronRegistry:
    """Keep cron entries by ID and track when they were run."""

    def __init__(self) -> None:
        self._entries: dict[str, CronEntry] = {}
        self._counter = 0

    def create(self, schedule: str, packet: TaskPacket) -> str:
        """Register a new enabled entry and return its generated ID."""
        self._counter += 1
        new_id = f"cron-{self._counter:04d}"
        entry = CronEntry(cron_id=new_id, schedule=schedule, packet=packet)
        self._entries[new_id] = entry
        return new_id

    def get(self, cron_id: str) -> CronEntry | None:
        """Return the entry for ``cron_id``, or ``None`` if unknown."""
        return self._entries.get(cron_id)

    def all_entries(self) -> list[CronEntry]:
        """Return every entry, in creation order."""
        return list(self._entries.values())

    def enabled_entries(self) -> list[CronEntry]:
        """Return only the entries whose ``enabled`` flag is set."""
        selected = []
        for entry in self._entries.values():
            if entry.enabled:
                selected.append(entry)
        return selected

    def _set_enabled(self, cron_id: str, flag: bool) -> bool:
        """Set the ``enabled`` flag of a known entry; report whether it exists."""
        entry = self._entries.get(cron_id)
        if entry is None:
            return False
        entry.enabled = flag
        return True

    def enable(self, cron_id: str) -> bool:
        """Enable ``cron_id``; return ``False`` if there is no such entry."""
        return self._set_enabled(cron_id, True)

    def disable(self, cron_id: str) -> bool:
        """Disable ``cron_id``; return ``False`` if there is no such entry."""
        return self._set_enabled(cron_id, False)

    def remove(self, cron_id: str) -> bool:
        """Drop ``cron_id`` from the registry; return whether it existed."""
        missing = object()
        return self._entries.pop(cron_id, missing) is not missing

    def record_run(self, cron_id: str) -> None:
        """Stamp ``cron_id`` with the current time and bump its run count.

        Unknown IDs are ignored.
        """
        entry = self._entries.get(cron_id)
        if not entry:
            return
        entry.last_run_ms = int(time.time() * 1000)
        entry.run_count += 1
|
axion/runtime/usage.py
ADDED
|
@@ -0,0 +1,241 @@
|
|
|
1
|
+
"""Token usage tracking and cost estimation.
|
|
2
|
+
|
|
3
|
+
Maps to: rust/crates/runtime/src/usage.rs
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
from dataclasses import dataclass, field
|
|
9
|
+
from typing import Any
|
|
10
|
+
|
|
11
|
+
# Default pricing (Sonnet 4 tier - $3 input / $15 output per million tokens)
DEFAULT_INPUT_COST_PER_MILLION = 3.0
DEFAULT_OUTPUT_COST_PER_MILLION = 15.0
DEFAULT_CACHE_CREATION_COST_PER_MILLION = 3.75
DEFAULT_CACHE_READ_COST_PER_MILLION = 0.3


@dataclass
class ModelPricing:
    """Hold USD rates per million tokens for each token category.

    Every rate defaults to the module-level ``DEFAULT_*`` constant of the
    same category.
    """

    # Rate for ordinary input tokens.
    input_cost_per_million: float = DEFAULT_INPUT_COST_PER_MILLION
    # Rate for output tokens.
    output_cost_per_million: float = DEFAULT_OUTPUT_COST_PER_MILLION
    # Rate for tokens written to the prompt cache.
    cache_creation_cost_per_million: float = DEFAULT_CACHE_CREATION_COST_PER_MILLION
    # Rate for tokens read from the prompt cache.
    cache_read_cost_per_million: float = DEFAULT_CACHE_READ_COST_PER_MILLION

    @classmethod
    def default_sonnet_tier(cls) -> ModelPricing:
        """Return a new instance carrying only the default rates."""
        return cls()
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
@dataclass
class TokenUsage:
    """Count tokens by category and derive totals, costs and report lines."""

    # Ordinary input tokens.
    input_tokens: int = 0
    # Output tokens.
    output_tokens: int = 0
    # Input tokens written to the prompt cache.
    cache_creation_input_tokens: int = 0
    # Input tokens read from the prompt cache.
    cache_read_input_tokens: int = 0

    def total_tokens(self) -> int:
        """Return the sum of all four counters."""
        total = 0
        for count in (
            self.input_tokens,
            self.output_tokens,
            self.cache_creation_input_tokens,
            self.cache_read_input_tokens,
        ):
            total += count
        return total

    def estimate_cost_usd(self) -> UsageCostEstimate:
        """Estimate the cost of this usage at the default pricing tier."""
        default_pricing = ModelPricing.default_sonnet_tier()
        return self.estimate_cost_usd_with_pricing(default_pricing)

    def estimate_cost_usd_with_pricing(self, pricing: ModelPricing) -> UsageCostEstimate:
        """Estimate the cost of this usage, per category, under ``pricing``."""
        input_cost = _cost_for_tokens(self.input_tokens, pricing.input_cost_per_million)
        output_cost = _cost_for_tokens(self.output_tokens, pricing.output_cost_per_million)
        cache_write_cost = _cost_for_tokens(
            self.cache_creation_input_tokens,
            pricing.cache_creation_cost_per_million,
        )
        cache_read_cost = _cost_for_tokens(
            self.cache_read_input_tokens,
            pricing.cache_read_cost_per_million,
        )
        return UsageCostEstimate(
            input_cost_usd=input_cost,
            output_cost_usd=output_cost,
            cache_creation_cost_usd=cache_write_cost,
            cache_read_cost_usd=cache_read_cost,
        )

    def __iadd__(self, other: TokenUsage) -> TokenUsage:
        """Add ``other``'s counters into this instance and return ``self``."""
        for counter in (
            "input_tokens",
            "output_tokens",
            "cache_creation_input_tokens",
            "cache_read_input_tokens",
        ):
            setattr(self, counter, getattr(self, counter) + getattr(other, counter))
        return self

    def summary_lines(self, label: str, model: str | None = None) -> list[str]:
        """Return two report lines: token counts with total cost, then a cost breakdown.

        When ``model`` is given and ``pricing_for_model`` recognises it, that
        pricing is used; otherwise the default tier is used.  A non-empty
        ``model`` is appended to the first line either way.
        """
        cost = None
        if model:
            known_pricing = pricing_for_model(model)
            if known_pricing:
                cost = self.estimate_cost_usd_with_pricing(known_pricing)
        if cost is None:
            cost = self.estimate_cost_usd()

        if model:
            model_suffix = f" model={model}"
        else:
            model_suffix = ""

        headline = " ".join(
            (
                f"{label}: total_tokens={self.total_tokens()}",
                f"input={self.input_tokens}",
                f"output={self.output_tokens}",
                f"cache_write={self.cache_creation_input_tokens}",
                f"cache_read={self.cache_read_input_tokens}",
                f"estimated_cost={format_usd(cost.total_cost_usd())}{model_suffix}",
            )
        )
        # NOTE(review): the leading whitespace of this literal is copied as it
        # appears in the reviewed text - confirm against the packaged file.
        breakdown = " ".join(
            (
                f" cost breakdown: input={format_usd(cost.input_cost_usd)}",
                f"output={format_usd(cost.output_cost_usd)}",
                f"cache_write={format_usd(cost.cache_creation_cost_usd)}",
                f"cache_read={format_usd(cost.cache_read_cost_usd)}",
            )
        )
        return [headline, breakdown]
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
@dataclass
class UsageCostEstimate:
    """Hold the estimated USD cost of a usage sample, split by token category."""

    # Cost attributed to ordinary input tokens.
    input_cost_usd: float = 0.0
    # Cost attributed to output tokens.
    output_cost_usd: float = 0.0
    # Cost attributed to cache-creation tokens.
    cache_creation_cost_usd: float = 0.0
    # Cost attributed to cache-read tokens.
    cache_read_cost_usd: float = 0.0

    def total_cost_usd(self) -> float:
        """Return the sum of the four category costs."""
        # Accumulate left to right so float rounding matches a plain a+b+c+d.
        running = self.input_cost_usd
        running += self.output_cost_usd
        running += self.cache_creation_cost_usd
        running += self.cache_read_cost_usd
        return running
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def pricing_for_model(model: str | None) -> ModelPricing | None:
    """Look up pricing for a model name.

    The name is lower-cased and tested against an ordered list of rules; the
    first rule that matches wins.  Returns ``None`` when ``model`` is ``None``
    or no rule matches.  A new ``ModelPricing`` is built on every call.
    """
    if model is None:
        return None
    name = model.lower()

    # Each rule: (patterns, exact_match_only, rates).
    # rates is (input, output, cache_creation, cache_read) per million tokens,
    # or None for the default Sonnet tier.  Order matters: more specific
    # patterns must come before the general ones they contain.
    rules = (
        (("haiku",), False, (1.0, 5.0, 1.25, 0.1)),
        # Opus 4.6: $5/MTok input, $25/MTok output
        (("opus",), False, (5.0, 25.0, 6.25, 0.5)),
        (("sonnet",), False, None),
        # OpenAI models — ordered specific to general
        # GPT-4.1 series (1M context)
        (("gpt-4.1-nano",), False, (0.10, 0.40, 0.05, 0.025)),
        (("gpt-4.1-mini",), False, (0.40, 1.60, 0.20, 0.10)),
        (("gpt-4.1",), False, (2.0, 8.0, 1.0, 0.50)),
        # GPT-4o series
        (("gpt-4o-mini", "4o-mini"), False, (0.15, 0.60, 0.075, 0.075)),
        (("gpt-4o", "4o"), False, (2.50, 10.0, 1.25, 1.25)),
        # o-series reasoning
        (("o4-mini",), False, (1.10, 4.40, 0.55, 0.275)),
        (("o3-mini",), False, (1.10, 4.40, 0.55, 0.55)),
        # These three are matched on the whole name, not as substrings.
        (("o1", "o3", "o1-pro"), True, (15.0, 60.0, 7.5, 7.5)),
        (("o1-mini",), False, (1.10, 4.40, 0.55, 0.55)),
        # Codex (Responses API — agent-tuned coding models)
        (("codex-mini",), False, (0.25, 2.0, 0.125, 0.025)),
        (("codex",), False, (1.25, 10.0, 0.625, 0.125)),
        # xAI
        (("grok",), False, (5.0, 15.0, 2.5, 2.5)),
    )

    for patterns, exact_only, rates in rules:
        if exact_only:
            matched = name in patterns
        else:
            matched = any(pattern in name for pattern in patterns)
        if not matched:
            continue
        if rates is None:
            return ModelPricing.default_sonnet_tier()
        input_rate, output_rate, cache_write_rate, cache_read_rate = rates
        return ModelPricing(
            input_cost_per_million=input_rate,
            output_cost_per_million=output_rate,
            cache_creation_cost_per_million=cache_write_rate,
            cache_read_cost_per_million=cache_read_rate,
        )
    return None
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
def _cost_for_tokens(tokens: int, cost_per_million: float) -> float:
|
|
207
|
+
return (tokens / 1_000_000) * cost_per_million
|
|
208
|
+
|
|
209
|
+
|
|
210
|
+
def format_usd(value: float) -> str:
    """Render ``value`` as a dollar amount with four decimal places."""
    return "${:.4f}".format(value)
|
|
213
|
+
|
|
214
|
+
|
|
215
|
+
# ---------------------------------------------------------------------------
|
|
216
|
+
# Usage tracker (accumulates across turns)
|
|
217
|
+
# ---------------------------------------------------------------------------
|
|
218
|
+
|
|
219
|
+
@dataclass
class UsageTracker:
    """Accumulate token usage across multiple turns."""

    # Running sum of every recorded turn; mutated in place by record_turn().
    total: TokenUsage = field(default_factory=TokenUsage)
    # Number of turns recorded so far.
    turn_count: int = 0
    # Model name passed to the summary for pricing lookup; optional.
    model: str | None = None

    def record_turn(self, usage: TokenUsage) -> None:
        """Add one turn's usage to the running total and count the turn."""
        self.total += usage
        self.turn_count += 1

    def summary_lines(self) -> list[str]:
        """Return the report lines for the accumulated total."""
        return self.total.summary_lines("Session total", self.model)

    @classmethod
    def from_session(cls, session: Any) -> UsageTracker:
        """Build a tracker from the usage attached to a session's messages.

        The lookup is tolerant: a session with no ``messages`` attribute (or
        with ``messages`` set to ``None``) yields an empty tracker, and
        messages that carry no ``usage`` attribute, or whose ``usage`` is
        ``None``, are skipped.  The returned tracker's ``model`` is left
        unset.
        """
        tracker = cls()
        for msg in getattr(session, "messages", None) or []:
            # Read usage defensively, matching the lookup of ``messages``.
            usage = getattr(msg, "usage", None)
            if usage is not None:
                tracker.record_turn(usage)
        return tracker
|
axion/runtime/workers.py
ADDED
|
@@ -0,0 +1,186 @@
|
|
|
1
|
+
"""Worker state machine and registry with full lifecycle management.
|
|
2
|
+
|
|
3
|
+
Maps to: rust/crates/runtime/src/worker_boot.rs
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
import enum
|
|
9
|
+
import logging
|
|
10
|
+
import time
|
|
11
|
+
import uuid
|
|
12
|
+
from dataclasses import dataclass, field
|
|
13
|
+
from typing import Any
|
|
14
|
+
|
|
15
|
+
logger = logging.getLogger(__name__)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class WorkerStatus(enum.Enum):
    """Lifecycle states of a worker."""

    # Initial state, also re-entered on restart.
    SPAWNING = "spawning"
    # State in which a trust resolution is accepted.
    TRUST_REQUIRED = "trust_required"
    # Reached after trust is resolved or a misdelivered prompt is retried.
    READY_FOR_PROMPT = "ready_for_prompt"
    # Reached once a prompt has been delivered.
    RUNNING = "running"
    # Terminal state entered via finish().
    FINISHED = "finished"
    # Terminal state entered on failure.
    FAILED = "failed"
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class WorkerFailureKind(enum.Enum):
    """Categories of worker failure; the value is prefixed to the failure detail."""

    TRUST_GATE = "trust_gate"
    PROMPT_DELIVERY = "prompt_delivery"
    PROTOCOL = "protocol"
    PROVIDER = "provider"
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class WorkerTrustResolution(enum.Enum):
    """Ways a worker's trust gate can be resolved."""

    AUTO_ALLOWLISTED = "auto_allowlisted"
    MANUAL_APPROVAL = "manual_approval"
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class WorkerPromptTarget(enum.Enum):
    """Where a prompt delivery attempt ended up."""

    # Successful delivery.
    SHELL = "shell"
    # Misdelivery, which may be retried.
    WRONG_TARGET = "wrong_target"
    # Undetermined target, which fails the worker.
    UNKNOWN = "unknown"
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
@dataclass
class WorkerEvent:
    """Record one timestamped entry in a worker's event log."""

    # Position of the event in the worker's log.
    seq: int
    # Event kind label.
    kind: str
    # Status associated with the event.
    status: WorkerStatus
    # Optional human-readable detail.
    detail: str = ""
    # Optional extra data attached to the event.
    payload: dict[str, Any] | None = None
    # Creation time in milliseconds since the epoch, captured at construction.
    timestamp_ms: int = field(default_factory=lambda: int(time.time() * 1000))
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
@dataclass
class Worker:
    """Model one agent worker and the history of its status changes.

    Intended state machine:
        SPAWNING → TRUST_REQUIRED → READY_FOR_PROMPT → RUNNING → FINISHED/FAILED

    Every status change goes through ``transition``, which appends a
    ``WorkerEvent`` to ``events``.
    """

    # Short random identifier generated per instance.
    worker_id: str = field(default_factory=lambda: f"w-{uuid.uuid4().hex[:8]}")
    # Working directory associated with the worker.
    cwd: str = ""
    # Current lifecycle state.
    status: WorkerStatus = WorkerStatus.SPAWNING
    # Not read by any method of this class.
    trust_auto_resolve: bool = False
    # Set once resolve_trust() succeeds; cleared by restart().
    trust_gate_cleared: bool = False
    # Whether a misdelivered prompt is retried rather than failing at once.
    auto_recover_prompt_misdelivery: bool = True
    # Number of deliver_prompt() calls since creation or the last restart.
    prompt_delivery_attempts: int = 0
    # Attempt count at which a misdelivery stops being retried.
    max_prompt_delivery_attempts: int = 3
    # Chronological log of transitions.
    events: list[WorkerEvent] = field(default_factory=list)

    def transition(self, new_status: WorkerStatus, detail: str = "", **payload: Any) -> None:
        """Record an event for ``new_status`` and make it the current status.

        Extra keyword arguments are stored as the event payload; when there
        are none the payload is ``None``.
        """
        next_seq = len(self.events) + 1
        self.events.append(
            WorkerEvent(
                seq=next_seq,
                kind=f"transition_{new_status.value}",
                status=new_status,
                detail=detail,
                payload=payload or None,
            )
        )
        self.status = new_status
        logger.debug("Worker %s → %s: %s", self.worker_id, new_status.value, detail)

    def resolve_trust(self, resolution: WorkerTrustResolution) -> None:
        """Clear the trust gate and move to READY_FOR_PROMPT.

        Only acts when the worker is in TRUST_REQUIRED; otherwise a warning
        is logged and nothing changes.
        """
        if self.status == WorkerStatus.TRUST_REQUIRED:
            self.trust_gate_cleared = True
            self.transition(
                WorkerStatus.READY_FOR_PROMPT,
                detail=f"Trust resolved via {resolution.value}",
            )
            return
        logger.warning("Cannot resolve trust in status %s", self.status.value)

    def deliver_prompt(self, target: WorkerPromptTarget = WorkerPromptTarget.SHELL) -> bool:
        """Record a prompt delivery attempt and return whether it succeeded.

        A WRONG_TARGET delivery sends the worker back to READY_FOR_PROMPT
        while retries remain (and auto-recovery is on), otherwise to FAILED.
        An UNKNOWN target fails the worker.  Any other target moves the
        worker to RUNNING.
        """
        # NOTE(review): the current status is not checked before delivering -
        # confirm callers only invoke this from READY_FOR_PROMPT.
        self.prompt_delivery_attempts += 1

        if target == WorkerPromptTarget.WRONG_TARGET:
            retries_left = self.prompt_delivery_attempts < self.max_prompt_delivery_attempts
            if self.auto_recover_prompt_misdelivery and retries_left:
                next_status = WorkerStatus.READY_FOR_PROMPT
                note = f"Prompt misdelivered (attempt {self.prompt_delivery_attempts}), retrying"
            else:
                next_status = WorkerStatus.FAILED
                note = "Prompt misdelivery exceeded max attempts"
            self.transition(next_status, detail=note)
            return False

        if target == WorkerPromptTarget.UNKNOWN:
            self.transition(WorkerStatus.FAILED, detail="Unknown prompt target")
            return False

        self.transition(WorkerStatus.RUNNING, detail="Prompt delivered to shell")
        return True

    def finish(self, detail: str = "Completed successfully") -> None:
        """Move the worker to FINISHED."""
        self.transition(WorkerStatus.FINISHED, detail=detail)

    def fail(self, kind: WorkerFailureKind, detail: str = "") -> None:
        """Move the worker to FAILED, prefixing the detail with the failure kind."""
        reason = f"{kind.value}: {detail}"
        self.transition(WorkerStatus.FAILED, detail=reason)

    def restart(self) -> None:
        """Reset the trust and delivery bookkeeping and return to SPAWNING."""
        self.prompt_delivery_attempts = 0
        self.trust_gate_cleared = False
        self.transition(WorkerStatus.SPAWNING, detail="Restarted")

    @property
    def is_active(self) -> bool:
        """Whether the worker is RUNNING or READY_FOR_PROMPT."""
        active_states = (WorkerStatus.RUNNING, WorkerStatus.READY_FOR_PROMPT)
        return self.status in active_states

    @property
    def is_terminal(self) -> bool:
        """Whether the worker is FINISHED or FAILED."""
        terminal_states = (WorkerStatus.FINISHED, WorkerStatus.FAILED)
        return self.status in terminal_states
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
class WorkerRegistry:
    """Hold workers keyed by their ID.

    Maps to: rust/crates/runtime/src/worker_boot.rs::WorkerRegistry
    """

    def __init__(self) -> None:
        self._workers: dict[str, Worker] = {}

    def spawn(self, cwd: str = "", **kwargs: Any) -> Worker:
        """Build a worker, register it, log its SPAWNING event and return it.

        Extra keyword arguments are forwarded to the ``Worker`` constructor.
        """
        created = Worker(cwd=cwd, **kwargs)
        self._workers[created.worker_id] = created
        created.transition(WorkerStatus.SPAWNING, detail="Worker spawned")
        return created

    def register(self, worker: Worker) -> None:
        """Store an existing worker under its ID, replacing any previous one."""
        self._workers[worker.worker_id] = worker

    def get(self, worker_id: str) -> Worker | None:
        """Return the worker for ``worker_id``, or ``None`` if unknown."""
        return self._workers.get(worker_id)

    def all_workers(self) -> list[Worker]:
        """Return every registered worker, in registration order."""
        return list(self._workers.values())

    def active_workers(self) -> list[Worker]:
        """Return the workers whose ``is_active`` is true."""
        selected = []
        for worker in self._workers.values():
            if worker.is_active:
                selected.append(worker)
        return selected

    def finished_workers(self) -> list[Worker]:
        """Return the workers whose ``is_terminal`` is true."""
        selected = []
        for worker in self._workers.values():
            if worker.is_terminal:
                selected.append(worker)
        return selected

    def remove(self, worker_id: str) -> bool:
        """Drop ``worker_id`` from the registry; return whether it existed."""
        missing = object()
        return self._workers.pop(worker_id, missing) is not missing

    def summary(self) -> dict[str, int]:
        """Return a mapping of status value to the number of workers in it."""
        tally: dict[str, int] = {}
        for worker in self._workers.values():
            state = worker.status.value
            if state in tally:
                tally[state] += 1
            else:
                tally[state] = 1
        return tally
|
|
File without changes
|