axion-code 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (82) hide show
  1. axion/__init__.py +3 -0
  2. axion/api/__init__.py +0 -0
  3. axion/api/anthropic.py +460 -0
  4. axion/api/client.py +259 -0
  5. axion/api/error.py +161 -0
  6. axion/api/ollama.py +597 -0
  7. axion/api/openai_compat.py +805 -0
  8. axion/api/openai_responses.py +627 -0
  9. axion/api/prompt_cache.py +31 -0
  10. axion/api/sse.py +98 -0
  11. axion/api/types.py +451 -0
  12. axion/cli/__init__.py +0 -0
  13. axion/cli/init_cmd.py +50 -0
  14. axion/cli/input.py +290 -0
  15. axion/cli/main.py +2953 -0
  16. axion/cli/render.py +489 -0
  17. axion/cli/tui.py +766 -0
  18. axion/commands/__init__.py +0 -0
  19. axion/commands/handlers/__init__.py +0 -0
  20. axion/commands/handlers/agents.py +51 -0
  21. axion/commands/handlers/builtin_commands.py +367 -0
  22. axion/commands/handlers/mcp.py +59 -0
  23. axion/commands/handlers/models.py +75 -0
  24. axion/commands/handlers/plugins.py +55 -0
  25. axion/commands/handlers/skills.py +61 -0
  26. axion/commands/parsing.py +317 -0
  27. axion/commands/registry.py +166 -0
  28. axion/compat_harness/__init__.py +0 -0
  29. axion/compat_harness/extractor.py +145 -0
  30. axion/plugins/__init__.py +0 -0
  31. axion/plugins/hooks.py +22 -0
  32. axion/plugins/manager.py +391 -0
  33. axion/plugins/manifest.py +270 -0
  34. axion/runtime/__init__.py +0 -0
  35. axion/runtime/bash.py +388 -0
  36. axion/runtime/bootstrap.py +39 -0
  37. axion/runtime/claude_subscription.py +300 -0
  38. axion/runtime/compact.py +233 -0
  39. axion/runtime/config.py +397 -0
  40. axion/runtime/conversation.py +1073 -0
  41. axion/runtime/file_ops.py +613 -0
  42. axion/runtime/git.py +213 -0
  43. axion/runtime/hooks.py +235 -0
  44. axion/runtime/image.py +212 -0
  45. axion/runtime/lanes.py +282 -0
  46. axion/runtime/lsp.py +425 -0
  47. axion/runtime/mcp/__init__.py +0 -0
  48. axion/runtime/mcp/client.py +76 -0
  49. axion/runtime/mcp/lifecycle.py +96 -0
  50. axion/runtime/mcp/stdio.py +318 -0
  51. axion/runtime/mcp/tool_bridge.py +79 -0
  52. axion/runtime/memory.py +196 -0
  53. axion/runtime/oauth.py +329 -0
  54. axion/runtime/openai_subscription.py +346 -0
  55. axion/runtime/permissions.py +247 -0
  56. axion/runtime/plan_mode.py +96 -0
  57. axion/runtime/policy_engine.py +259 -0
  58. axion/runtime/prompt.py +586 -0
  59. axion/runtime/recovery.py +261 -0
  60. axion/runtime/remote.py +28 -0
  61. axion/runtime/sandbox.py +68 -0
  62. axion/runtime/scheduler.py +231 -0
  63. axion/runtime/session.py +365 -0
  64. axion/runtime/sharing.py +159 -0
  65. axion/runtime/skills.py +124 -0
  66. axion/runtime/tasks.py +258 -0
  67. axion/runtime/usage.py +241 -0
  68. axion/runtime/workers.py +186 -0
  69. axion/telemetry/__init__.py +0 -0
  70. axion/telemetry/events.py +67 -0
  71. axion/telemetry/profile.py +49 -0
  72. axion/telemetry/sink.py +60 -0
  73. axion/telemetry/tracer.py +95 -0
  74. axion/tools/__init__.py +0 -0
  75. axion/tools/lane_completion.py +33 -0
  76. axion/tools/registry.py +853 -0
  77. axion/tools/tool_search.py +226 -0
  78. axion_code-1.0.0.dist-info/METADATA +709 -0
  79. axion_code-1.0.0.dist-info/RECORD +82 -0
  80. axion_code-1.0.0.dist-info/WHEEL +4 -0
  81. axion_code-1.0.0.dist-info/entry_points.txt +2 -0
  82. axion_code-1.0.0.dist-info/licenses/LICENSE +21 -0
axion/runtime/tasks.py ADDED
@@ -0,0 +1,258 @@
1
+ """Task packet, registry, team assignment, and cron scheduling.
2
+
3
+ Maps to: rust/crates/runtime/src/task_packet.rs + team/cron registries
4
+ """
5
+
6
+ from __future__ import annotations
7
+
8
+ import time
9
+ from dataclasses import dataclass, field
10
+
11
+ # ---------------------------------------------------------------------------
12
+ # Task packet
13
+ # ---------------------------------------------------------------------------
14
+
15
@dataclass
class TaskPacket:
    """Describes one autonomous task together with the policies that govern it.

    Only ``objective`` has to be supplied when constructing a packet; every
    other field falls back to the default declared below.
    """

    # What the task should achieve and where it applies.
    objective: str
    scope: str = ""
    repo: str = ""

    # Policy settings; ``validate_packet`` checks them against fixed choices.
    branch_policy: str = "feature-branch"
    acceptance_tests: list[str] = field(default_factory=list)
    commit_policy: str = "atomic"
    reporting_contract: str = ""
    escalation_policy: str = "alert_human"

    # Routing / scheduling metadata.
    priority: int = 0
    tags: list[str] = field(default_factory=list)
    assigned_team: str | None = None
    cron_schedule: str | None = None

    # Wall-clock creation time in milliseconds since the epoch.
    created_at_ms: int = field(default_factory=lambda: int(1000 * time.time()))
32
+
33
+
34
@dataclass
class TaskPacketValidationError:
    """Carries the list of problems found while validating a task packet.

    This is a plain data holder, not an exception type.
    """

    # One human-readable message per failed check.
    errors: list[str]
37
+
38
+
39
def validate_packet(packet: TaskPacket) -> TaskPacketValidationError | None:
    """Check a task packet and report every problem found.

    Returns ``None`` when the packet passes all checks, otherwise a
    ``TaskPacketValidationError`` listing the messages in check order:
    required fields first, then the three policy fields.
    """
    problems: list[str] = []

    # Fields that must be non-empty.
    for required in ("objective", "scope"):
        if not getattr(packet, required):
            problems.append(f"{required} is required")

    # Policy fields restricted to a fixed set of values.
    allowed_values = (
        ("branch_policy", ("feature-branch", "direct", "trunk")),
        ("commit_policy", ("atomic", "squash", "incremental")),
        ("escalation_policy", ("alert_human", "log_and_continue", "abort")),
    )
    for policy, choices in allowed_values:
        value = getattr(packet, policy)
        if value not in choices:
            problems.append(f"invalid {policy}: {value}")

    if not problems:
        return None
    return TaskPacketValidationError(errors=problems)
53
+
54
+
55
+ # ---------------------------------------------------------------------------
56
+ # Task registry
57
+ # ---------------------------------------------------------------------------
58
+
59
class TaskStatus:
    """String constants naming the states a registry task can be in."""

    # Not yet picked up by a worker.
    PENDING = "pending"
    # Currently being worked on.
    RUNNING = "running"
    # Finished states.
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"
65
+
66
+
67
@dataclass
class TaskEntry:
    """Registry record pairing a task packet with its tracking state."""

    # Identity and payload.
    task_id: str
    packet: TaskPacket

    # Lifecycle state; one of the ``TaskStatus`` string constants.
    status: str = TaskStatus.PENDING
    worker_id: str | None = None

    # Millisecond timestamps, unset until the matching transition happens.
    started_at_ms: int | None = None
    completed_at_ms: int | None = None

    # Outcome text: ``result`` on success, ``error`` on failure.
    result: str = ""
    error: str = ""
79
+
80
+
81
class TaskRegistry:
    """In-memory registry that tracks each task through its lifecycle.

    Transitions enforced by the methods below::

        PENDING -> RUNNING -> COMPLETED | FAILED
        any non-terminal status -> CANCELLED

    ``COMPLETED``, ``FAILED`` and ``CANCELLED`` are terminal: once a task is in
    one of them no method of this class changes its status again.

    Maps to: rust/crates/runtime/src/task_packet.rs (registry)
    """

    # Statuses that must never be overwritten by a later transition.
    _TERMINAL_STATUSES = (
        TaskStatus.COMPLETED,
        TaskStatus.FAILED,
        TaskStatus.CANCELLED,
    )

    def __init__(self) -> None:
        self._tasks: dict[str, TaskEntry] = {}
        # Monotonic counter used to mint task IDs; never reused after removal.
        self._counter = 0

    def create(self, packet: TaskPacket) -> str:
        """Register *packet* as a new pending task and return its ID."""
        self._counter += 1
        task_id = f"task-{self._counter:04d}"
        self._tasks[task_id] = TaskEntry(task_id=task_id, packet=packet)
        return task_id

    def get(self, task_id: str) -> TaskEntry | None:
        """Return the entry for *task_id*, or ``None`` if it is unknown."""
        return self._tasks.get(task_id)

    def all_tasks(self) -> list[TaskEntry]:
        """Return every registered task in creation order."""
        return list(self._tasks.values())

    def pending_tasks(self) -> list[TaskEntry]:
        """Return the tasks that have not been started yet."""
        return [t for t in self._tasks.values() if t.status == TaskStatus.PENDING]

    def running_tasks(self) -> list[TaskEntry]:
        """Return the tasks that are currently running."""
        return [t for t in self._tasks.values() if t.status == TaskStatus.RUNNING]

    def start_task(self, task_id: str, worker_id: str) -> bool:
        """Move a pending task to RUNNING and record the worker that took it.

        Returns ``False`` (and changes nothing) when the task is unknown or is
        not pending.
        """
        task = self._tasks.get(task_id)
        if task is None or task.status != TaskStatus.PENDING:
            return False
        task.status = TaskStatus.RUNNING
        task.worker_id = worker_id
        task.started_at_ms = int(time.time() * 1000)
        return True

    def complete_task(self, task_id: str, result: str = "") -> bool:
        """Move a running task to COMPLETED and store *result*.

        Returns ``False`` when the task is unknown or is not running.
        """
        task = self._tasks.get(task_id)
        if task is None or task.status != TaskStatus.RUNNING:
            return False
        task.status = TaskStatus.COMPLETED
        task.result = result
        task.completed_at_ms = int(time.time() * 1000)
        return True

    def fail_task(self, task_id: str, error: str = "") -> bool:
        """Move a running task to FAILED and store *error*.

        Returns ``False`` when the task is unknown or is not running.
        """
        task = self._tasks.get(task_id)
        if task is None or task.status != TaskStatus.RUNNING:
            return False
        task.status = TaskStatus.FAILED
        task.error = error
        task.completed_at_ms = int(time.time() * 1000)
        return True

    def cancel_task(self, task_id: str) -> bool:
        """Cancel a task that has not reached a terminal status.

        Returns ``False`` when the task is unknown or already completed,
        failed or cancelled.
        """
        task = self._tasks.get(task_id)
        # A failed task is just as final as a completed one: cancelling it
        # would erase the failure status and overwrite its completion time.
        if task is None or task.status in self._TERMINAL_STATUSES:
            return False
        task.status = TaskStatus.CANCELLED
        task.completed_at_ms = int(time.time() * 1000)
        return True

    def remove(self, task_id: str) -> bool:
        """Drop *task_id* from the registry; ``True`` if it was present."""
        return self._tasks.pop(task_id, None) is not None

    def summary(self) -> dict[str, int]:
        """Return the number of tasks per status (statuses with no tasks are omitted)."""
        counts: dict[str, int] = {}
        for task in self._tasks.values():
            counts[task.status] = counts.get(task.status, 0) + 1
        return counts
153
+
154
+
155
+ # ---------------------------------------------------------------------------
156
+ # Team registry
157
+ # ---------------------------------------------------------------------------
158
+
159
@dataclass
class Team:
    """Named collection of workers that tasks can be routed to."""

    name: str
    # IDs of the workers belonging to this team.
    worker_ids: list[str] = field(default_factory=list)
    # Concurrency limit declared for the team (not enforced in this class).
    max_concurrent: int = 1
    # Labels compared against a task packet's tags during assignment.
    tags: list[str] = field(default_factory=list)
167
+
168
+
169
class TeamRegistry:
    """Keeps the known teams, keyed by name, and routes tasks to them."""

    def __init__(self) -> None:
        self._teams: dict[str, Team] = {}

    def register(self, team: Team) -> None:
        """Store *team* under its name, replacing any team with that name."""
        self._teams[team.name] = team

    def get(self, name: str) -> Team | None:
        """Look up a team by name; ``None`` when it is not registered."""
        return self._teams.get(name)

    def all_teams(self) -> list[Team]:
        """Return all teams in registration order."""
        return list(self._teams.values())

    @staticmethod
    def _accepts(team: Team, task: TaskEntry) -> bool:
        """Tell whether *team* is a valid destination for *task*."""
        packet = task.packet
        # An explicit team on the packet restricts the match to that team.
        if packet.assigned_team and packet.assigned_team != team.name:
            return False
        # A tagged packet needs at least one tag in common with the team.
        if packet.tags and not any(tag in team.tags for tag in packet.tags):
            return False
        return True

    def assign_task(self, task: TaskEntry) -> str | None:
        """Return the name of the first registered team matching *task*.

        A team matches when it satisfies the packet's ``assigned_team`` (if
        set) and shares a tag with the packet (if the packet has tags).
        Returns ``None`` when no team matches.
        """
        return next(
            (team.name for team in self._teams.values() if self._accepts(team, task)),
            None,
        )
194
+
195
+
196
+ # ---------------------------------------------------------------------------
197
+ # Cron registry
198
+ # ---------------------------------------------------------------------------
199
+
200
@dataclass
class CronEntry:
    """One recurring task: a schedule string plus the packet to run."""

    cron_id: str
    # Cron expression, for example "*/5 * * * *".
    schedule: str
    packet: TaskPacket
    # Disabled entries stay registered but are left out of enabled listings.
    enabled: bool = True
    # Run bookkeeping; timestamps are milliseconds, 0 meaning "not set".
    last_run_ms: int = 0
    next_run_ms: int = 0
    run_count: int = 0
211
+
212
+
213
class CronRegistry:
    """Holds cron entries by ID and tracks their enabled flag and run counts."""

    def __init__(self) -> None:
        self._entries: dict[str, CronEntry] = {}
        # Source of sequential IDs; keeps growing even after removals.
        self._counter = 0

    def create(self, schedule: str, packet: TaskPacket) -> str:
        """Register a new enabled entry and return its generated ID."""
        self._counter += 1
        cron_id = f"cron-{self._counter:04d}"
        entry = CronEntry(cron_id=cron_id, schedule=schedule, packet=packet)
        self._entries[cron_id] = entry
        return cron_id

    def get(self, cron_id: str) -> CronEntry | None:
        """Return the entry for *cron_id*, or ``None`` if there is none."""
        return self._entries.get(cron_id)

    def all_entries(self) -> list[CronEntry]:
        """Return every entry in creation order."""
        return list(self._entries.values())

    def enabled_entries(self) -> list[CronEntry]:
        """Return only the entries whose ``enabled`` flag is set."""
        return [entry for entry in self._entries.values() if entry.enabled]

    def _set_enabled(self, cron_id: str, flag: bool) -> bool:
        """Write the enabled flag of an entry; ``False`` if the ID is unknown."""
        target = self._entries.get(cron_id)
        if target is None:
            return False
        target.enabled = flag
        return True

    def enable(self, cron_id: str) -> bool:
        """Switch an entry on; ``False`` when *cron_id* is unknown."""
        return self._set_enabled(cron_id, True)

    def disable(self, cron_id: str) -> bool:
        """Switch an entry off; ``False`` when *cron_id* is unknown."""
        return self._set_enabled(cron_id, False)

    def remove(self, cron_id: str) -> bool:
        """Delete an entry; ``True`` if it existed."""
        return self._entries.pop(cron_id, None) is not None

    def record_run(self, cron_id: str) -> None:
        """Stamp the current time on an entry and bump its run counter.

        Unknown IDs are ignored. ``next_run_ms`` is not touched here.
        """
        entry = self._entries.get(cron_id)
        if not entry:
            return
        entry.last_run_ms = int(time.time() * 1000)
        entry.run_count += 1
axion/runtime/usage.py ADDED
@@ -0,0 +1,241 @@
1
+ """Token usage tracking and cost estimation.
2
+
3
+ Maps to: rust/crates/runtime/src/usage.rs
4
+ """
5
+
6
+ from __future__ import annotations
7
+
8
+ from dataclasses import dataclass, field
9
+ from typing import Any
10
+
11
# Fallback per-million-token prices in USD (Sonnet 4 tier).
# Regular traffic: $3 per million input tokens, $15 per million output tokens.
DEFAULT_INPUT_COST_PER_MILLION = 3.0
DEFAULT_OUTPUT_COST_PER_MILLION = 15.0
# Prompt-cache traffic: writing to the cache and reading from it.
DEFAULT_CACHE_CREATION_COST_PER_MILLION = 3.75
DEFAULT_CACHE_READ_COST_PER_MILLION = 0.3
16
+
17
+
18
@dataclass
class ModelPricing:
    """USD price per million tokens for each token category.

    Every field defaults to the module-level ``DEFAULT_*`` constant of the
    same category.
    """

    input_cost_per_million: float = DEFAULT_INPUT_COST_PER_MILLION
    output_cost_per_million: float = DEFAULT_OUTPUT_COST_PER_MILLION
    cache_creation_cost_per_million: float = DEFAULT_CACHE_CREATION_COST_PER_MILLION
    cache_read_cost_per_million: float = DEFAULT_CACHE_READ_COST_PER_MILLION

    @classmethod
    def default_sonnet_tier(cls) -> ModelPricing:
        """Build a pricing object that uses the default value of every field."""
        sonnet_defaults = cls()
        return sonnet_defaults
30
+
31
+
32
@dataclass
class TokenUsage:
    """Four token counters for one turn or for a whole session."""

    input_tokens: int = 0
    output_tokens: int = 0
    cache_creation_input_tokens: int = 0
    cache_read_input_tokens: int = 0

    def total_tokens(self) -> int:
        """Return the sum of all four counters."""
        total = self.input_tokens
        for extra in (
            self.output_tokens,
            self.cache_creation_input_tokens,
            self.cache_read_input_tokens,
        ):
            total = total + extra
        return total

    def estimate_cost_usd(self) -> UsageCostEstimate:
        """Estimate the cost of this usage at the default Sonnet-tier prices."""
        default_pricing = ModelPricing.default_sonnet_tier()
        return self.estimate_cost_usd_with_pricing(default_pricing)

    def estimate_cost_usd_with_pricing(self, pricing: ModelPricing) -> UsageCostEstimate:
        """Estimate the cost of this usage, category by category, under *pricing*."""
        input_usd = _cost_for_tokens(self.input_tokens, pricing.input_cost_per_million)
        output_usd = _cost_for_tokens(self.output_tokens, pricing.output_cost_per_million)
        cache_write_usd = _cost_for_tokens(
            self.cache_creation_input_tokens,
            pricing.cache_creation_cost_per_million,
        )
        cache_read_usd = _cost_for_tokens(
            self.cache_read_input_tokens,
            pricing.cache_read_cost_per_million,
        )
        return UsageCostEstimate(
            input_cost_usd=input_usd,
            output_cost_usd=output_usd,
            cache_creation_cost_usd=cache_write_usd,
            cache_read_cost_usd=cache_read_usd,
        )

    def __iadd__(self, other: TokenUsage) -> TokenUsage:
        """Add *other*'s counters onto this object in place and return it."""
        for counter in (
            "input_tokens",
            "output_tokens",
            "cache_creation_input_tokens",
            "cache_read_input_tokens",
        ):
            setattr(self, counter, getattr(self, counter) + getattr(other, counter))
        return self

    def summary_lines(self, label: str, model: str | None = None) -> list[str]:
        """Render two report lines: totals with estimated cost, then a cost breakdown.

        When *model* is given and ``pricing_for_model`` recognises it, that
        pricing is used; otherwise the default estimate is used.
        """
        pricing = pricing_for_model(model) if model else None
        if pricing:
            cost = self.estimate_cost_usd_with_pricing(pricing)
        else:
            cost = self.estimate_cost_usd()
        model_suffix = f" model={model}" if model else ""

        headline = (
            f"{label}: total_tokens={self.total_tokens()} "
            f"input={self.input_tokens} output={self.output_tokens} "
            f"cache_write={self.cache_creation_input_tokens} "
            f"cache_read={self.cache_read_input_tokens} "
            f"estimated_cost={format_usd(cost.total_cost_usd())}{model_suffix}"
        )
        breakdown = (
            f" cost breakdown: input={format_usd(cost.input_cost_usd)} "
            f"output={format_usd(cost.output_cost_usd)} "
            f"cache_write={format_usd(cost.cache_creation_cost_usd)} "
            f"cache_read={format_usd(cost.cache_read_cost_usd)}"
        )
        return [headline, breakdown]
94
+
95
+
96
@dataclass
class UsageCostEstimate:
    """Per-category dollar amounts estimated for a token usage sample."""

    input_cost_usd: float = 0.0
    output_cost_usd: float = 0.0
    cache_creation_cost_usd: float = 0.0
    cache_read_cost_usd: float = 0.0

    def total_cost_usd(self) -> float:
        """Return the four category costs added together."""
        # Added left to right so float rounding is unchanged.
        running = self.input_cost_usd + self.output_cost_usd
        running = running + self.cache_creation_cost_usd
        running = running + self.cache_read_cost_usd
        return running
112
+
113
+
114
def pricing_for_model(model: str | None) -> ModelPricing | None:
    """Look up per-million-token pricing for a model name.

    The name is lower-cased and tested against an ordered list of rules; the
    first rule that matches wins, so more specific names are listed before the
    general family they belong to. Returns ``None`` for ``None`` input or when
    no rule matches.
    """
    if model is None:
        return None
    name = model.lower()

    # Each rule: (needles, exact_match, rates).
    #   needles     - strings tested against the lower-cased model name
    #   exact_match - True: name must equal a needle; False: substring match
    #   rates       - (input, output, cache_creation, cache_read) per million,
    #                 or None to use the default Sonnet tier
    rules = (
        # Anthropic families
        (("haiku",), False, (1.0, 5.0, 1.25, 0.1)),
        (("opus",), False, (5.0, 25.0, 6.25, 0.5)),  # Opus 4.6: $5 in / $25 out
        (("sonnet",), False, None),
        # OpenAI GPT-4.1 series (1M context), specific before general
        (("gpt-4.1-nano",), False, (0.10, 0.40, 0.05, 0.025)),
        (("gpt-4.1-mini",), False, (0.40, 1.60, 0.20, 0.10)),
        (("gpt-4.1",), False, (2.0, 8.0, 1.0, 0.50)),
        # OpenAI GPT-4o series
        (("gpt-4o-mini", "4o-mini"), False, (0.15, 0.60, 0.075, 0.075)),
        (("gpt-4o", "4o"), False, (2.50, 10.0, 1.25, 1.25)),
        # OpenAI o-series reasoning models
        (("o4-mini",), False, (1.10, 4.40, 0.55, 0.275)),
        (("o3-mini",), False, (1.10, 4.40, 0.55, 0.55)),
        (("o1", "o3", "o1-pro"), True, (15.0, 60.0, 7.5, 7.5)),
        (("o1-mini",), False, (1.10, 4.40, 0.55, 0.55)),
        # Codex (Responses API, agent-tuned coding models)
        (("codex-mini",), False, (0.25, 2.0, 0.125, 0.025)),
        (("codex",), False, (1.25, 10.0, 0.625, 0.125)),
        # xAI
        (("grok",), False, (5.0, 15.0, 2.5, 2.5)),
    )

    for needles, exact_match, rates in rules:
        if exact_match:
            matched = name in needles
        else:
            matched = any(needle in name for needle in needles)
        if not matched:
            continue
        if rates is None:
            return ModelPricing.default_sonnet_tier()
        input_rate, output_rate, cache_write_rate, cache_read_rate = rates
        return ModelPricing(
            input_cost_per_million=input_rate,
            output_cost_per_million=output_rate,
            cache_creation_cost_per_million=cache_write_rate,
            cache_read_cost_per_million=cache_read_rate,
        )
    return None
204
+
205
+
206
+ def _cost_for_tokens(tokens: int, cost_per_million: float) -> float:
207
+ return (tokens / 1_000_000) * cost_per_million
208
+
209
+
210
def format_usd(value: float) -> str:
    """Render *value* as a dollar amount with four digits after the point."""
    return "$" + format(value, ".4f")
213
+
214
+
215
+ # ---------------------------------------------------------------------------
216
+ # Usage tracker (accumulates across turns)
217
+ # ---------------------------------------------------------------------------
218
+
219
@dataclass
class UsageTracker:
    """Running total of token usage over a sequence of turns."""

    # Sum of every usage sample recorded so far.
    total: TokenUsage = field(default_factory=TokenUsage)
    # Number of samples added through ``record_turn``.
    turn_count: int = 0
    # Model name passed to the summary for pricing; ``None`` uses the default.
    model: str | None = None

    def record_turn(self, usage: TokenUsage) -> None:
        """Add one turn's usage to the running total and count the turn."""
        self.total += usage
        self.turn_count += 1

    def summary_lines(self) -> list[str]:
        """Return the summary lines of the total, labelled "Session total"."""
        return self.total.summary_lines("Session total", self.model)

    @classmethod
    def from_session(cls, session: Any) -> UsageTracker:
        """Build a tracker from the usage attached to a session's messages.

        *session* is duck-typed: a missing or ``None`` ``messages`` attribute
        is treated as an empty list, and messages whose ``usage`` attribute is
        missing or ``None`` are skipped. The tracker's ``model`` is left at
        its default.
        """
        tracker = cls()
        # ``or []`` also covers an explicit ``messages = None``.
        for msg in getattr(session, "messages", None) or []:
            # Read usage as defensively as messages: not every message
            # object is guaranteed to carry the attribute.
            usage = getattr(msg, "usage", None)
            if usage is not None:
                tracker.record_turn(usage)
        return tracker
@@ -0,0 +1,186 @@
1
+ """Worker state machine and registry with full lifecycle management.
2
+
3
+ Maps to: rust/crates/runtime/src/worker_boot.rs
4
+ """
5
+
6
+ from __future__ import annotations
7
+
8
+ import enum
9
+ import logging
10
+ import time
11
+ import uuid
12
+ from dataclasses import dataclass, field
13
+ from typing import Any
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
class WorkerStatus(enum.Enum):
    """States a worker can be in."""

    # Start-up phase.
    SPAWNING = "spawning"
    TRUST_REQUIRED = "trust_required"
    READY_FOR_PROMPT = "ready_for_prompt"
    # Executing.
    RUNNING = "running"
    # End states.
    FINISHED = "finished"
    FAILED = "failed"
25
+
26
+
27
class WorkerFailureKind(enum.Enum):
    """Categories used to label why a worker failed."""

    TRUST_GATE = "trust_gate"
    PROMPT_DELIVERY = "prompt_delivery"
    PROTOCOL = "protocol"
    PROVIDER = "provider"
32
+
33
+
34
class WorkerTrustResolution(enum.Enum):
    """Ways in which a worker's trust gate can be resolved."""

    AUTO_ALLOWLISTED = "auto_allowlisted"
    MANUAL_APPROVAL = "manual_approval"
37
+
38
+
39
class WorkerPromptTarget(enum.Enum):
    """Where a prompt delivery attempt ended up."""

    SHELL = "shell"
    WRONG_TARGET = "wrong_target"
    UNKNOWN = "unknown"
43
+
44
+
45
@dataclass
class WorkerEvent:
    """One entry in a worker's event log."""

    # Sequence number assigned by whoever creates the event.
    seq: int
    kind: str
    # Status associated with this event.
    status: WorkerStatus
    detail: str = ""
    # Optional extra data attached to the event.
    payload: dict[str, Any] | None = None
    # Creation time in milliseconds since the epoch.
    timestamp_ms: int = field(default_factory=lambda: int(1000 * time.time()))
55
+
56
+
57
@dataclass
class Worker:
    """A single agent worker together with its lifecycle event log.

    Intended progression of ``status``:
        SPAWNING → TRUST_REQUIRED → READY_FOR_PROMPT → RUNNING → FINISHED/FAILED

    Every status change goes through ``transition``, which appends a
    ``WorkerEvent`` to ``events``.
    """

    worker_id: str = field(default_factory=lambda: "w-" + uuid.uuid4().hex[:8])
    cwd: str = ""
    status: WorkerStatus = WorkerStatus.SPAWNING
    trust_auto_resolve: bool = False
    trust_gate_cleared: bool = False
    # Misdelivery handling: whether to retry, and the attempt budget.
    auto_recover_prompt_misdelivery: bool = True
    prompt_delivery_attempts: int = 0
    max_prompt_delivery_attempts: int = 3
    events: list[WorkerEvent] = field(default_factory=list)

    def transition(self, new_status: WorkerStatus, detail: str = "", **payload: Any) -> None:
        """Record an event for *new_status* and make it the current status.

        Extra keyword arguments become the event payload; with none given the
        payload is ``None``.
        """
        next_seq = len(self.events) + 1
        self.events.append(
            WorkerEvent(
                seq=next_seq,
                kind=f"transition_{new_status.value}",
                status=new_status,
                detail=detail,
                payload=payload or None,
            )
        )
        self.status = new_status
        logger.debug("Worker %s → %s: %s", self.worker_id, new_status.value, detail)

    def resolve_trust(self, resolution: WorkerTrustResolution) -> None:
        """Clear the trust gate and move to READY_FOR_PROMPT.

        Does nothing except log a warning unless the worker is currently in
        TRUST_REQUIRED.
        """
        if self.status == WorkerStatus.TRUST_REQUIRED:
            self.trust_gate_cleared = True
            self.transition(
                WorkerStatus.READY_FOR_PROMPT,
                detail=f"Trust resolved via {resolution.value}",
            )
            return
        logger.warning("Cannot resolve trust in status %s", self.status.value)

    def deliver_prompt(self, target: WorkerPromptTarget = WorkerPromptTarget.SHELL) -> bool:
        """Record the outcome of one prompt delivery attempt.

        Each call counts as an attempt. Returns ``True`` only when the prompt
        reached the shell (status becomes RUNNING). A wrong target leads back
        to READY_FOR_PROMPT while retries are allowed and to FAILED otherwise;
        an unknown target leads straight to FAILED.
        """
        self.prompt_delivery_attempts += 1

        if target == WorkerPromptTarget.UNKNOWN:
            self.transition(WorkerStatus.FAILED, detail="Unknown prompt target")
            return False

        if target != WorkerPromptTarget.WRONG_TARGET:
            self.transition(WorkerStatus.RUNNING, detail="Prompt delivered to shell")
            return True

        # Misdelivery: retry only if recovery is enabled and budget remains.
        may_retry = (
            self.auto_recover_prompt_misdelivery
            and self.prompt_delivery_attempts < self.max_prompt_delivery_attempts
        )
        if may_retry:
            self.transition(
                WorkerStatus.READY_FOR_PROMPT,
                detail=f"Prompt misdelivered (attempt {self.prompt_delivery_attempts}), retrying",
            )
        else:
            self.transition(
                WorkerStatus.FAILED,
                detail="Prompt misdelivery exceeded max attempts",
            )
        return False

    def finish(self, detail: str = "Completed successfully") -> None:
        """Move the worker to FINISHED."""
        self.transition(WorkerStatus.FINISHED, detail=detail)

    def fail(self, kind: WorkerFailureKind, detail: str = "") -> None:
        """Move the worker to FAILED, prefixing the detail with the failure kind."""
        message = f"{kind.value}: {detail}"
        self.transition(WorkerStatus.FAILED, detail=message)

    def restart(self) -> None:
        """Reset trust and delivery counters and go back to SPAWNING.

        The existing event log is kept; the restart is appended to it.
        """
        self.trust_gate_cleared = False
        self.prompt_delivery_attempts = 0
        self.transition(WorkerStatus.SPAWNING, detail="Restarted")

    @property
    def is_active(self) -> bool:
        """True while the worker is RUNNING or READY_FOR_PROMPT."""
        current = self.status
        return current == WorkerStatus.RUNNING or current == WorkerStatus.READY_FOR_PROMPT

    @property
    def is_terminal(self) -> bool:
        """True once the worker is FINISHED or FAILED."""
        current = self.status
        return current == WorkerStatus.FINISHED or current == WorkerStatus.FAILED
145
+
146
+
147
class WorkerRegistry:
    """Keeps track of worker instances by their ``worker_id``.

    Maps to: rust/crates/runtime/src/worker_boot.rs::WorkerRegistry
    """

    def __init__(self) -> None:
        self._workers: dict[str, Worker] = {}

    def spawn(self, cwd: str = "", **kwargs: Any) -> Worker:
        """Construct a worker, register it and log its initial SPAWNING event.

        Extra keyword arguments are forwarded to the ``Worker`` constructor.
        """
        new_worker = Worker(cwd=cwd, **kwargs)
        self._workers[new_worker.worker_id] = new_worker
        new_worker.transition(WorkerStatus.SPAWNING, detail="Worker spawned")
        return new_worker

    def register(self, worker: Worker) -> None:
        """Add an already-built worker, replacing any worker with the same ID."""
        self._workers[worker.worker_id] = worker

    def get(self, worker_id: str) -> Worker | None:
        """Return the worker with *worker_id*, or ``None`` if unknown."""
        return self._workers.get(worker_id)

    def all_workers(self) -> list[Worker]:
        """Return every registered worker in registration order."""
        return list(self._workers.values())

    def active_workers(self) -> list[Worker]:
        """Return the workers whose ``is_active`` is true."""
        return [worker for worker in self._workers.values() if worker.is_active]

    def finished_workers(self) -> list[Worker]:
        """Return the workers whose ``is_terminal`` is true."""
        return [worker for worker in self._workers.values() if worker.is_terminal]

    def remove(self, worker_id: str) -> bool:
        """Forget a worker; ``True`` if it was registered."""
        return self._workers.pop(worker_id, None) is not None

    def summary(self) -> dict[str, int]:
        """Return how many workers are in each status, keyed by status value."""
        tally: dict[str, int] = {}
        for worker in self._workers.values():
            key = worker.status.value
            tally.setdefault(key, 0)
            tally[key] += 1
        return tally
File without changes