claude-code-generator 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- claude_code_generator-0.1.0.dist-info/METADATA +176 -0
- claude_code_generator-0.1.0.dist-info/RECORD +49 -0
- claude_code_generator-0.1.0.dist-info/WHEEL +5 -0
- claude_code_generator-0.1.0.dist-info/entry_points.txt +2 -0
- claude_code_generator-0.1.0.dist-info/licenses/LICENSE +21 -0
- claude_code_generator-0.1.0.dist-info/top_level.txt +1 -0
- code_generator/__init__.py +3 -0
- code_generator/agents.py +177 -0
- code_generator/cli.py +49 -0
- code_generator/commands/__init__.py +1 -0
- code_generator/commands/generate.py +252 -0
- code_generator/commands/init.py +72 -0
- code_generator/commands/review.py +117 -0
- code_generator/commands/status.py +83 -0
- code_generator/env.py +55 -0
- code_generator/gh.py +331 -0
- code_generator/logging_setup.py +73 -0
- code_generator/orchestrator/__init__.py +4 -0
- code_generator/orchestrator/cycle_loop.py +371 -0
- code_generator/orchestrator/phase0_complexity.py +159 -0
- code_generator/orchestrator/phase1_plan.py +170 -0
- code_generator/orchestrator/phase2_review.py +126 -0
- code_generator/orchestrator/phase3_4_implement.py +164 -0
- code_generator/orchestrator/phase5_closure.py +154 -0
- code_generator/orchestrator/phase6_test.py +98 -0
- code_generator/orchestrator/phase7_commit.py +167 -0
- code_generator/prompts/__init__.py +86 -0
- code_generator/prompts/prompt-phase-0-complexity.md +85 -0
- code_generator/prompts/prompt-phase-1-planning.md +209 -0
- code_generator/prompts/prompt-phase-2-issue-review.md +84 -0
- code_generator/prompts/prompt-phase-3-implementation.md +191 -0
- code_generator/prompts/prompt-phase-5-final-review.md +135 -0
- code_generator/prompts/prompt-phase-6-test.md +102 -0
- code_generator/prompts/prompt-phase-7-commit.md +103 -0
- code_generator/prompts/prompt-review.md +124 -0
- code_generator/runner/__init__.py +26 -0
- code_generator/runner/rate_limit.py +113 -0
- code_generator/runner/retry.py +165 -0
- code_generator/runner/sdk_runner.py +267 -0
- code_generator/runner/subprocess_runner.py +200 -0
- code_generator/state.py +178 -0
- code_generator/templates/__init__.py +1 -0
- code_generator/templates/angular.md +12 -0
- code_generator/templates/base.md +28 -0
- code_generator/templates/fastapi.md +12 -0
- code_generator/templates/finance.md +9 -0
- code_generator/templates/fullstack.md +24 -0
- code_generator/templates/nestjs.md +9 -0
- code_generator/templates/python-cli.md +9 -0
|
@@ -0,0 +1,165 @@
|
|
|
1
|
+
"""Retry logic with exponential backoff and a circuit breaker.
|
|
2
|
+
|
|
3
|
+
Backoff applies only to transient errors. RateLimitHit and OverageAbort
|
|
4
|
+
bypass both backoff and the circuit breaker — they are handled by the
|
|
5
|
+
wait-and-resume loop in rate_limit.py, not by blind retrying.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import asyncio
|
|
11
|
+
from typing import TYPE_CHECKING, TypeVar
|
|
12
|
+
|
|
13
|
+
if TYPE_CHECKING:
|
|
14
|
+
from collections.abc import Awaitable, Callable
|
|
15
|
+
|
|
16
|
+
from code_generator.runner.sdk_runner import OverageAbort, RateLimitHit
|
|
17
|
+
|
|
18
|
+
T = TypeVar("T")

# ---------------------------------------------------------------------------
# Backoff schedule (seconds). Capped at 120.
# ---------------------------------------------------------------------------

# Doubling from 10 s, clamped at 120 s -> [10, 20, 40, 80, 120].
# with_backoff() indexes this by attempt number and reuses the last entry for
# any attempt beyond the end of the list.
_BACKOFF_SCHEDULE: list[int] = [min(10 * 2**step, 120) for step in range(5)]
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
# ---------------------------------------------------------------------------
|
|
28
|
+
# Sentinel exception for transient runner errors
|
|
29
|
+
# ---------------------------------------------------------------------------
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class TransientRunnerError(RuntimeError):
    """Marker exception for transient runner failures that should be retried."""


# ---------------------------------------------------------------------------
# Transient-error types (retry-eligible)
# ---------------------------------------------------------------------------

_TRANSIENT_TYPES: tuple[type[BaseException], ...] = (
    asyncio.TimeoutError,
    ConnectionError,
    OSError,
    TransientRunnerError,
)

# OSError subclasses that signal a configuration/environment problem rather
# than a flaky one: e.g. subprocess.Popen raises FileNotFoundError when the
# `claude` executable or the requested cwd does not exist. Retrying cannot fix
# these, so they are excluded even though they are OSErrors.
_PERMANENT_OS_ERRORS: tuple[type[BaseException], ...] = (
    FileNotFoundError,
    PermissionError,
    NotADirectoryError,
    IsADirectoryError,
)


def _is_transient(exc: BaseException) -> bool:
    """Return True for exceptions that warrant a retry via backoff.

    An exception is transient when it is an instance of one of
    ``_TRANSIENT_TYPES`` and not one of the permanent OS errors listed in
    ``_PERMANENT_OS_ERRORS``.

    Args:
        exc: The exception raised by a runner attempt.

    Returns:
        True if the attempt should be retried after a backoff delay.
    """
    if isinstance(exc, _PERMANENT_OS_ERRORS):
        return False
    return isinstance(exc, _TRANSIENT_TYPES)
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def _is_rate_limit_or_overage(exc: BaseException) -> bool:
    """Tell whether *exc* is a rate-limit or overage signal.

    Such exceptions are handled by the wait-and-resume loop, so they skip
    both backoff and circuit-breaker accounting.

    Args:
        exc: The exception to classify.

    Returns:
        True for RateLimitHit / OverageAbort instances.
    """
    bypass_types = (RateLimitHit, OverageAbort)
    return isinstance(exc, bypass_types)
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
# ---------------------------------------------------------------------------
|
|
59
|
+
# Circuit breaker
|
|
60
|
+
# ---------------------------------------------------------------------------
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
class CircuitOpen(RuntimeError):
    """Signals that the circuit breaker has tripped.

    Raised once the number of consecutive counted failures reaches the
    breaker's threshold.
    """
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
class CircuitBreaker:
    """Counts consecutive non-rate-limit failures and opens after max_failures.

    There is no separate "open" flag: once the threshold is reached, every
    further recorded failure raises CircuitOpen again until record_success()
    resets the counter.

    Args:
        max_failures: Number of consecutive failures before the circuit opens.
    """

    def __init__(self, max_failures: int = 3) -> None:
        self._max_failures = max_failures
        self._consecutive_failures = 0

    def record_success(self) -> None:
        """Reset the failure counter after a successful operation."""
        self._consecutive_failures = 0

    def record_failure(self, exc: BaseException) -> None:
        """Increment the failure counter; raise CircuitOpen when threshold is hit.

        Rate-limit and overage exceptions are excluded from the counter —
        they are expected transients managed elsewhere.

        Args:
            exc: The exception that caused the failure.

        Raises:
            CircuitOpen: When consecutive failures reach max_failures. The
                triggering exception is attached as ``__cause__`` so the root
                failure shows up in the traceback.
        """
        if _is_rate_limit_or_overage(exc):
            # Non-negotiable bypass — never count rate-limit/overage towards circuit.
            return

        self._consecutive_failures += 1
        if self._consecutive_failures >= self._max_failures:
            raise CircuitOpen(
                f"Circuit opened after {self._consecutive_failures} consecutive failures."
            ) from exc
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
# ---------------------------------------------------------------------------
|
|
106
|
+
# with_backoff
|
|
107
|
+
# ---------------------------------------------------------------------------
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
async def with_backoff(
    coro_factory: Callable[[], Awaitable[T]],
    *,
    attempts: int = 5,
    sleep_fn: Callable[[float], Awaitable[None]] = asyncio.sleep,
    breaker: CircuitBreaker | None = None,
) -> T:
    """Execute coro_factory with exponential backoff on transient errors.

    Only errors classified as transient by ``_is_transient`` are retried.
    RateLimitHit and OverageAbort propagate immediately without touching the
    circuit breaker; any other non-transient error also propagates at once.

    The delay before retry ``n`` (0-based) is ``_BACKOFF_SCHEDULE[n]``,
    clamped to the last schedule entry. No delay is taken after the final
    attempt: once attempts are exhausted the last error is re-raised
    immediately instead of after a pointless sleep.

    Args:
        coro_factory: Zero-argument callable returning an awaitable. It is
            called once per attempt, so each attempt gets a fresh awaitable.
        attempts: Maximum number of attempts (1 means no retries). Must be >= 1.
        sleep_fn: Async sleep function (injectable for tests).
        breaker: Optional CircuitBreaker instance. Successes reset it and
            transient failures are recorded against it.

    Returns:
        The value returned by coro_factory on success.

    Raises:
        ValueError: If attempts is less than 1.
        CircuitOpen: When the breaker trips.
        RateLimitHit: Immediately, bypassing backoff.
        OverageAbort: Immediately, bypassing backoff.
        The original exception: When it is not transient, or when all
            attempts are exhausted.
    """
    if attempts < 1:
        # Without this guard the loop never runs and there is nothing to re-raise.
        raise ValueError(f"attempts must be >= 1, got {attempts!r}")

    final_attempt = attempts - 1
    for attempt in range(attempts):
        try:
            result = await coro_factory()
        except (RateLimitHit, OverageAbort):
            # Non-negotiable: bypass backoff and breaker for rate-limit/overage.
            raise
        except BaseException as exc:
            if not _is_transient(exc):
                # Non-transient error — propagate without retry or breaker.
                raise

            if breaker is not None:
                # record_failure may raise CircuitOpen — let it propagate.
                breaker.record_failure(exc)

            if attempt == final_attempt:
                # Attempts exhausted: re-raise now rather than sleeping first.
                raise

            delay = _BACKOFF_SCHEDULE[min(attempt, len(_BACKOFF_SCHEDULE) - 1)]
            await sleep_fn(float(delay))
        else:
            if breaker is not None:
                breaker.record_success()
            return result

    # Unreachable: every iteration returns or raises, and attempts >= 1.
    raise AssertionError("with_backoff: retry loop exited without a result")
|
|
@@ -0,0 +1,267 @@
|
|
|
1
|
+
"""Primary runner: async wrapper around claude_agent_sdk.ClaudeSDKClient.
|
|
2
|
+
|
|
3
|
+
Handles RateLimitEvent processing, overage abort, and session resume.
|
|
4
|
+
Non-negotiables: #1 (strip env), #3 (bypassPermissions), #4 (overage abort).
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
from dataclasses import dataclass
|
|
10
|
+
from typing import TYPE_CHECKING, Any
|
|
11
|
+
|
|
12
|
+
from code_generator import env as _env
|
|
13
|
+
from code_generator import state as _state
|
|
14
|
+
|
|
15
|
+
if TYPE_CHECKING:
|
|
16
|
+
import logging
|
|
17
|
+
from pathlib import Path
|
|
18
|
+
|
|
19
|
+
# Guard SDK import so the module can be loaded without the SDK installed
|
|
20
|
+
# (used by get_runner() to detect availability).
|
|
21
|
+
try:
    from claude_agent_sdk import (
        AssistantMessage,
        ClaudeSDKClient,
        ResultMessage,
        TextBlock,
    )
    from claude_agent_sdk.types import RateLimitEvent
except ImportError:  # pragma: no cover
    _SDK_AVAILABLE = False
    # SDK missing: bind every imported name to a None sentinel so the module
    # still imports and duck-typed tests / type checkers keep working.
    AssistantMessage = None  # type: ignore[assignment]
    ClaudeSDKClient = None  # type: ignore[assignment,misc]
    RateLimitEvent = None  # type: ignore[assignment]
    ResultMessage = None  # type: ignore[assignment]
    TextBlock = None  # type: ignore[assignment]
else:
    _SDK_AVAILABLE = True
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
# ---------------------------------------------------------------------------
|
|
42
|
+
# Public exception types
|
|
43
|
+
# ---------------------------------------------------------------------------
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class OverageAbort(RuntimeError):
    """Abort signal raised while overage billing is active.

    Raising it stops the run before any overage charges are incurred.
    """
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
class RateLimitHit(Exception):
    """Raised when the rate limit is rejected and the session is paused.

    Attributes:
        resets_at: When the limit resets, passed through unchanged from the
            caller (the subprocess runner supplies a Unix timestamp).
        rate_limit_type: Which limit was hit (e.g. ``"five_hour"``).
    """

    def __init__(self, resets_at: int, rate_limit_type: str) -> None:
        self.resets_at = resets_at
        self.rate_limit_type = rate_limit_type
        super().__init__(f"Rate limit hit; resets at {resets_at}, type={rate_limit_type}")

    def __reduce__(self) -> tuple[Any, ...]:
        # BaseException reconstructs itself as ``cls(*self.args)``, but
        # ``args`` only holds the formatted message, which does not match this
        # __init__ signature, so pickle/copy would fail with TypeError.
        # Rebuild from the real constructor arguments instead.
        return (type(self), (self.resets_at, self.rate_limit_type), self.__dict__)
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
# ---------------------------------------------------------------------------
|
|
60
|
+
# Result type
|
|
61
|
+
# ---------------------------------------------------------------------------
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
@dataclass
|
|
65
|
+
class RunResult:
|
|
66
|
+
"""Captured output from a single SDK run."""
|
|
67
|
+
|
|
68
|
+
text: str
|
|
69
|
+
session_id: str | None
|
|
70
|
+
tokens_in: int | None
|
|
71
|
+
tokens_out: int | None
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
# ---------------------------------------------------------------------------
|
|
75
|
+
# Internal helpers
|
|
76
|
+
# ---------------------------------------------------------------------------
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def _extract_text(msg: Any) -> str:
    """Extract concatenated text from an AssistantMessage's content blocks.

    Blocks without a ``text`` attribute (or with ``text=None``) are skipped.
    A message with no ``content`` attribute, or with ``content=None``, yields
    an empty string instead of raising TypeError.

    Args:
        msg: An AssistantMessage or duck-typed equivalent.

    Returns:
        The ``text`` of every text-bearing block, joined without separators.
    """
    # `or ()` treats a missing, None or empty content uniformly.
    content = getattr(msg, "content", None) or ()
    parts: list[str] = []
    for block in content:
        text = getattr(block, "text", None)
        if text is not None:
            parts.append(text)
    return "".join(parts)
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def _extract_usage(msg: Any) -> tuple[int | None, int | None]:
|
|
90
|
+
"""Return (tokens_in, tokens_out) from a ResultMessage, defensively."""
|
|
91
|
+
usage = getattr(msg, "usage", None)
|
|
92
|
+
if usage is None:
|
|
93
|
+
return None, None
|
|
94
|
+
if isinstance(usage, dict):
|
|
95
|
+
return usage.get("input_tokens"), usage.get("output_tokens")
|
|
96
|
+
return getattr(usage, "input_tokens", None), getattr(usage, "output_tokens", None)
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
def _handle_rate_limit_event(
    msg: Any,
    logger: logging.Logger,
    state_path: Path,
) -> None:
    """React to a RateLimitEvent by aborting, pausing, or just warning.

    Decision order:

    1. No ``rate_limit_info`` -> warn and ignore.
    2. ``overage_status`` is ``"allowed"`` / ``"allowed_warning"`` (overage
       billing active) -> raise OverageAbort (non-negotiable #4).
    3. ``status == "allowed_warning"`` -> log the utilization and continue.
    4. ``status == "rejected"`` with a ``resets_at`` -> persist the pause to
       state.json and raise RateLimitHit; without ``resets_at`` only warn.

    Args:
        msg: The RateLimitEvent message (duck-typed).
        logger: Phase logger for warnings.
        state_path: Path to state.json for persistence.

    Raises:
        OverageAbort: When overage billing is active.
        RateLimitHit: When the rate limit is rejected and resets_at is known.
    """
    info = getattr(msg, "rate_limit_info", None)
    if info is None:
        logger.warning("RateLimitEvent received but has no rate_limit_info — skipping.")
        return

    # "allowed" / "allowed_warning" mean overage billing IS active (unsafe);
    # "rejected", "disabled" or None mean it is off (safe to continue).
    overage = getattr(info, "overage_status", None)
    if overage in ("allowed", "allowed_warning"):
        raise OverageAbort(
            f"Overage API billing active ({overage}) — aborting to avoid charges"
        )

    status = getattr(info, "status", None)

    if status == "allowed_warning":
        logger.warning(
            "Rate limit allowed_warning: approaching limit (utilization=%s).",
            getattr(info, "utilization", None),
        )
        return

    if status != "rejected":
        return

    resets_at = getattr(info, "resets_at", None)
    if resets_at is None:
        logger.warning("Rate limit rejected but no resets_at provided — cannot pause.")
        return

    limit_kind = str(getattr(info, "rate_limit_type", None) or "unknown")
    state = _state.load_state(state_path)
    _state.mark_paused(
        state,
        resets_at=resets_at,
        rate_limit_type=limit_kind,
        session_id=getattr(msg, "session_id", None),
    )
    _state.save_state(state_path, state)
    raise RateLimitHit(resets_at, limit_kind)
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
# ---------------------------------------------------------------------------
|
|
158
|
+
# Public API
|
|
159
|
+
# ---------------------------------------------------------------------------
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
async def run(
    prompt: str,
    options: Any,
    *,
    logger: logging.Logger,
    state_path: Path,
) -> RunResult:
    """Run a prompt via ClaudeSDKClient and return the accumulated result.

    Non-negotiable #1: strips dangerous env vars at the start of every call.
    Non-negotiable #3: forces permission_mode to bypassPermissions (this
    mutates the caller's ``options`` object).

    Messages are consumed up to and including the first ResultMessage, which
    ends the turn. The loop breaks there explicitly. Per the claude_agent_sdk
    docs, ``receive_messages()`` keeps yielding for the life of the client;
    ``receive_response()`` is the variant that stops at ResultMessage. Without
    the break, the call would keep waiting on an idle session.

    Args:
        prompt: The prompt text to send.
        options: ClaudeAgentOptions (or duck-typed equivalent).
        logger: Phase logger for rate-limit warnings and debug output.
        state_path: Path to state.json used for pause persistence.

    Returns:
        RunResult with accumulated text, session_id, and token counts.

    Raises:
        RuntimeError: When claude_agent_sdk could not be imported.
        OverageAbort: When overage billing is detected.
        RateLimitHit: When the rate limit is rejected.
    """
    if ClaudeSDKClient is None:
        # The import guard left a None sentinel; fail with a clear message
        # instead of "'NoneType' object is not callable".
        raise RuntimeError(
            "claude_agent_sdk is not installed — use the subprocess runner instead."
        )

    # Non-negotiable #1 — strip dangerous env vars on every invocation.
    _env.strip_dangerous_env()

    # Non-negotiable #3 — always force bypassPermissions; never trust caller.
    options.permission_mode = "bypassPermissions"

    text_parts: list[str] = []
    session_id: str | None = None
    tokens_in: int | None = None
    tokens_out: int | None = None

    logger.info("SDK session starting (model=%s).", getattr(options, "model", "?"))
    msg_count = 0
    async with ClaudeSDKClient(options=options) as client:
        await client.query(prompt)
        async for msg in client.receive_messages():
            msg_count += 1
            # Detect message type by duck-typing on attributes rather than
            # isinstance() so fake test objects work without SDK installed.
            if _is_rate_limit_event(msg):
                logger.info("msg #%d: RateLimitEvent", msg_count)
                _handle_rate_limit_event(msg, logger, state_path)
                continue

            if _is_assistant_message(msg):
                chunk = _extract_text(msg)
                text_parts.append(chunk)
                if chunk.strip():
                    logger.info(
                        "msg #%d: assistant (%d chars)\n%s",
                        msg_count, len(chunk), chunk.rstrip(),
                    )
                else:
                    logger.debug("msg #%d: assistant (empty)", msg_count)
                # Some SDK versions attach session_id to AssistantMessage.
                if session_id is None:
                    session_id = getattr(msg, "session_id", None)

            elif _is_result_message(msg):
                session_id = getattr(msg, "session_id", session_id)
                tokens_in, tokens_out = _extract_usage(msg)
                logger.info(
                    "msg #%d: ResultMessage (tokens_in=%s tokens_out=%s)",
                    msg_count, tokens_in, tokens_out,
                )
                # The ResultMessage closes the turn; stop here instead of
                # waiting on receive_messages() for messages that won't come.
                break

            else:
                logger.debug("msg #%d: %s", msg_count, type(msg).__name__)

    logger.info("SDK session complete (%d messages).", msg_count)

    return RunResult(
        text="".join(text_parts),
        session_id=session_id,
        tokens_in=tokens_in,
        tokens_out=tokens_out,
    )
|
|
243
|
+
|
|
244
|
+
|
|
245
|
+
# ---------------------------------------------------------------------------
|
|
246
|
+
# Duck-type discriminators
|
|
247
|
+
# ---------------------------------------------------------------------------
|
|
248
|
+
|
|
249
|
+
|
|
250
|
+
def _is_rate_limit_event(msg: Any) -> bool:
    """Duck-type check: anything exposing ``rate_limit_info`` is a RateLimitEvent."""
    marker = "rate_limit_info"
    return hasattr(msg, marker)
|
|
253
|
+
|
|
254
|
+
|
|
255
|
+
def _is_assistant_message(msg: Any) -> bool:
    """Duck-type check: an AssistantMessage carries both ``content`` and ``model``."""
    return all(hasattr(msg, attr) for attr in ("content", "model"))
|
|
258
|
+
|
|
259
|
+
|
|
260
|
+
def _is_result_message(msg: Any) -> bool:
    """Duck-type check for a ResultMessage.

    A ResultMessage has ``session_id`` and ``is_error``. RateLimitEvents and
    AssistantMessages can also carry ``session_id``, so anything exposing
    ``rate_limit_info`` or ``model`` is rejected.
    """
    if not (hasattr(msg, "session_id") and hasattr(msg, "is_error")):
        return False
    return not any(hasattr(msg, attr) for attr in ("rate_limit_info", "model"))
|
|
@@ -0,0 +1,200 @@
|
|
|
1
|
+
"""Subprocess fallback runner: shells out to the `claude` CLI.
|
|
2
|
+
|
|
3
|
+
Used when claude_agent_sdk is not importable. Spawns:
|
|
4
|
+
claude -p <prompt> --output-format stream-json --verbose
|
|
5
|
+
--dangerously-skip-permissions --model <model>
|
|
6
|
+
--max-turns <n> --allowedTools <csv>
|
|
7
|
+
|
|
8
|
+
Non-negotiables: #1 (build_agent_env), #2 (never --bare).
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
from __future__ import annotations
|
|
12
|
+
|
|
13
|
+
import asyncio
|
|
14
|
+
import json
|
|
15
|
+
import subprocess
|
|
16
|
+
import time
|
|
17
|
+
from typing import TYPE_CHECKING, Any
|
|
18
|
+
|
|
19
|
+
from code_generator import env as _env
|
|
20
|
+
from code_generator import state as _state
|
|
21
|
+
|
|
22
|
+
# Re-export the shared RateLimitHit exception and RunResult type so callers can import both from one place.
|
|
23
|
+
from code_generator.runner.sdk_runner import RateLimitHit, RunResult
|
|
24
|
+
|
|
25
|
+
if TYPE_CHECKING:
|
|
26
|
+
import logging
|
|
27
|
+
from pathlib import Path
|
|
28
|
+
|
|
29
|
+
# ---------------------------------------------------------------------------
|
|
30
|
+
# Model name translation map
|
|
31
|
+
# SDK uses full model IDs; the `claude` CLI may use short aliases.
|
|
32
|
+
# ---------------------------------------------------------------------------
|
|
33
|
+
|
|
34
|
+
_MODEL_ALIASES: dict[str, str] = {
|
|
35
|
+
"claude-opus-4-6": "claude-opus-4-6",
|
|
36
|
+
"claude-sonnet-4-6": "claude-sonnet-4-6",
|
|
37
|
+
"claude-haiku-4-5": "claude-haiku-4-5",
|
|
38
|
+
}
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def _translate_model(model: str | None) -> str:
|
|
42
|
+
"""Return the CLI model alias for a given SDK model name."""
|
|
43
|
+
if model is None:
|
|
44
|
+
return "claude-sonnet-4-6"
|
|
45
|
+
return _MODEL_ALIASES.get(model, model)
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
# ---------------------------------------------------------------------------
|
|
49
|
+
# NDJSON line handlers
|
|
50
|
+
# ---------------------------------------------------------------------------
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def _handle_line(
    line: str,
    text_parts: list[str],
    result_holder: dict[str, Any],
    logger: logging.Logger,
    state_path: Path,
    proc: Any,
) -> None:
    """Parse one NDJSON line and mutate text_parts / result_holder in place.

    Recognised events:

    * ``assistant``: the ``text`` of every ``type == "text"`` content block
      is appended to *text_parts*. A ``null`` text counts as ``""``.
    * ``result``: ``session_id`` and the input/output token counts are
      stored in *result_holder*.
    * ``system`` / ``api_retry`` with ``error == "rate_limit"``: delays up to
      120 s are logged and left to the CLI. Longer delays pause the run.

    Lines that are not valid JSON, are not JSON objects, or have an
    unexpected shape (non-list content, non-dict blocks or usage,
    non-numeric delay) are ignored rather than crashing the reader loop.

    Args:
        line: A single JSON line from the CLI stdout.
        text_parts: Accumulator for assistant text fragments.
        result_holder: Dict capturing session_id and token counts.
        logger: Phase logger.
        state_path: Path to state.json for pause persistence.
        proc: The Popen process (so it can be killed on hard rate-limit).

    Raises:
        RateLimitHit: When retry_delay_ms exceeds 120_000 ms. Before raising,
            the pause is persisted to state_path and proc is killed.
    """
    try:
        data = json.loads(line)
    except json.JSONDecodeError:
        logger.debug("Ignoring malformed NDJSON line: %r", line[:200])
        return
    if not isinstance(data, dict):
        # Valid JSON but not an event object (e.g. a bare string or list).
        logger.debug("Ignoring non-object NDJSON line: %r", line[:200])
        return

    msg_type = data.get("type")

    if msg_type == "assistant":
        message = data.get("message")
        content = message.get("content") if isinstance(message, dict) else None
        if not isinstance(content, list):
            return
        for block in content:
            if isinstance(block, dict) and block.get("type") == "text":
                # `or ""` keeps a null text from breaking the final "".join().
                text_parts.append(block.get("text") or "")

    elif msg_type == "result":
        result_holder["session_id"] = data.get("session_id")
        usage = data.get("usage")
        if not isinstance(usage, dict):
            usage = {}
        result_holder["tokens_in"] = usage.get("input_tokens")
        result_holder["tokens_out"] = usage.get("output_tokens")

    elif msg_type == "system" and data.get("subtype") == "api_retry":
        if data.get("error") != "rate_limit":
            return
        delay_ms = data.get("retry_delay_ms")
        if not isinstance(delay_ms, (int, float)):
            # Missing / null / non-numeric delay: treat as an immediate retry.
            delay_ms = 0
        session_id: str | None = data.get("session_id")
        if delay_ms > 120_000:
            paused_until = int(time.time() + delay_ms / 1000)
            st = _state.load_state(state_path)
            _state.mark_paused(
                st,
                resets_at=paused_until,
                rate_limit_type="five_hour",
                session_id=session_id,
            )
            _state.save_state(state_path, st)
            proc.kill()
            raise RateLimitHit(paused_until, "five_hour")
        else:
            logger.info(
                "CLI rate-limit retry in %dms — letting CLI handle internally.", delay_ms
            )
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
# ---------------------------------------------------------------------------
|
|
117
|
+
# Public API
|
|
118
|
+
# ---------------------------------------------------------------------------
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
async def run(
    prompt: str,
    options: Any,
    *,
    logger: logging.Logger,
    state_path: Path,
) -> RunResult:
    """Spawn the Claude CLI and parse its stream-json output.

    Non-negotiable #1: uses build_agent_env() — no dangerous vars reach the CLI.
    Non-negotiable #2: --bare is never included. An explicit check guards
    this; a plain ``assert`` would be stripped under ``python -O``.

    stdout is parsed line by line as NDJSON. stderr goes to an anonymous
    temporary file rather than a pipe. A pipe that is never read fills up
    and blocks the CLI, which would stall the stdout loop forever. If the
    CLI exits with a non-zero status, the tail of its stderr is logged.

    Args:
        prompt: The prompt text to pass via -p.
        options: Duck-typed options object with model, max_turns, allowed_tools, cwd.
            A missing or None ``max_turns`` falls back to 50.
        logger: Phase logger for debug/info/warning messages.
        state_path: Path to state.json for pause persistence.

    Returns:
        RunResult with accumulated text, session_id, and token counts.

    Raises:
        RateLimitHit: When retry_delay_ms > 120_000 in the stream.
        AssertionError: If ``--bare`` ever ends up in argv (a programming error).
    """
    import io
    import tempfile

    model = _translate_model(getattr(options, "model", None))
    max_turns = getattr(options, "max_turns", None)
    if max_turns is None:
        # Missing and explicitly-unset both use the default; the literal string
        # "None" is not a valid --max-turns value.
        max_turns = 50
    allowed_tools: list[str] = getattr(options, "allowed_tools", None) or []
    cwd: str | None = getattr(options, "cwd", None)

    argv: list[str] = [
        "claude",
        "-p",
        prompt,
        "--output-format",
        "stream-json",
        "--verbose",
        "--dangerously-skip-permissions",
        "--model",
        model,
        "--max-turns",
        str(max_turns),
    ]
    if allowed_tools:
        argv += ["--allowedTools", ",".join(allowed_tools)]

    # Non-negotiable #2 — never pass --bare.
    if "--bare" in argv:
        raise AssertionError("BUG: --bare must never appear in the claude CLI argv.")

    safe_env = _env.build_agent_env()

    text_parts: list[str] = []
    result_holder: dict[str, Any] = {
        "session_id": None,
        "tokens_in": None,
        "tokens_out": None,
    }

    def _stderr_tail(fh: Any, limit: int = 4000) -> str:
        """Return at most the last *limit* bytes of *fh*, decoded leniently."""
        size = fh.seek(0, io.SEEK_END)
        fh.seek(max(0, size - limit))
        return fh.read().decode(errors="replace").rstrip()

    # Run subprocess in a thread to keep the coroutine non-blocking.
    def _run_subprocess() -> None:
        with tempfile.TemporaryFile() as stderr_file, subprocess.Popen(
            argv,
            stdout=subprocess.PIPE,
            stderr=stderr_file,
            env=safe_env,
            cwd=cwd,
        ) as proc:
            try:
                for raw_line in proc.stdout:  # type: ignore[union-attr]
                    line = raw_line.decode(errors="replace").rstrip()
                    if not line:
                        continue
                    _handle_line(line, text_parts, result_holder, logger, state_path, proc)
            except BaseException:
                # Stop the CLI before Popen.__exit__ waits on it. This is
                # harmless when _handle_line already killed it for a hard
                # rate limit.
                proc.kill()
                raise

            returncode = proc.wait()
            if returncode != 0:
                logger.warning(
                    "claude CLI exited with status %d; stderr tail:\n%s",
                    returncode,
                    _stderr_tail(stderr_file),
                )

    await asyncio.to_thread(_run_subprocess)

    return RunResult(
        text="".join(text_parts),
        session_id=result_holder["session_id"],
        tokens_in=result_holder["tokens_in"],
        tokens_out=result_holder["tokens_out"],
    )
|