coderouter-cli 2.0.0__py3-none-any.whl → 2.2.0__py3-none-any.whl

@@ -0,0 +1,413 @@
+ """Self-healing routing orchestrator (v2.0-J).
+
+ When the L5 :class:`BackendHealthMonitor` transitions a provider to
+ UNHEALTHY and the profile's ``backend_health_action`` is ``exclude``,
+ this orchestrator:
+
+ 1. **Excludes** the provider from the chain (complete removal, not
+    just demotion to the back).
+ 2. **Attempts restart** if the provider declares a ``restart_command``
+    in ``providers.yaml``.
+ 3. **Schedules recovery probes** with exponential backoff (default
+    30 s → 60 s → 120 s → 240 s → 300 s cap) to detect when the
+    backend comes back online.
+ 4. **Restores** the provider to its original chain position on the
+    first successful recovery probe.
+
+ Architecture
+ ============
+
+ ::
+
+     BackendHealthMonitor
+       └─ transition → UNHEALTHY
+            └─ engine calls orchestrator.on_unhealthy(provider)
+                 ├─ add to _excluded set
+                 ├─ try restart_command (if configured)
+                 └─ schedule recovery probe (async)
+
+     recovery_probe_loop:
+         while provider in _excluded:
+             sleep(backoff interval)
+             probe_one(provider)
+             if success:
+                 remove from _excluded
+                 record_attempt(success=True) → snap to HEALTHY
+                 log restore
+                 break
+             else:
+                 interval = min(interval * 2, max_interval)
+
+     _resolve_chain (engine):
+         Pass 4b: if action == "exclude":
+             filter out providers in orchestrator._excluded
+
+ Design choices
+ ==============
+
+ - **Thread-safe** via an internal ``RLock`` (same pattern as
+   ``BackendHealthMonitor``). The exclude set and restart lock
+   are guarded independently.
+ - **No new dependency** — subprocess (stdlib) for restart commands,
+   asyncio for recovery probe scheduling.
+ - **Restart is opt-in** — only providers with ``restart_command``
+   set get automatic restart. Others rely solely on recovery probes
+   (waiting for a manual restart by the operator). See the wiring
+   sketch below.
+ - **Double-restart prevention** — a per-provider ``_restart_lock``
+   prevents concurrent restart attempts (e.g. two profiles both
+   hitting UNHEALTHY on the same provider).
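+
+ Illustrative wiring (a minimal sketch, not the engine's actual hook;
+ ``provider_cfg``, ``profile``, ``failures`` and ``shutdown`` are
+ assumed names)::
+
+     orchestrator = SelfHealingOrchestrator()
+
+     def handle_unhealthy(provider_cfg, profile, failures):
+         newly = orchestrator.on_unhealthy(
+             provider_cfg.name,
+             profile=profile,
+             consecutive_failures=failures,
+         )
+         if newly:  # spawn exactly one probe task per exclusion
+             asyncio.create_task(recovery_probe_loop(
+                 provider_cfg,
+                 orchestrator=orchestrator,
+                 profile=profile,
+                 shutdown_event=shutdown,
+             ))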
+ """
+
+ from __future__ import annotations
+
+ import asyncio
+ import contextlib
+ import subprocess
+ import threading
+ import time
+ from dataclasses import dataclass
+
+ from coderouter.config.schemas import ProviderConfig
+ from coderouter.logging import (
+     get_logger,
+     log_self_healing_exclude,
+     log_self_healing_recovery_probe,
+     log_self_healing_restart,
+     log_self_healing_restore,
+ )
+
+ logger = get_logger(__name__)
+
+
+ @dataclass(slots=True)
+ class _ExcludedProvider:
+     """Metadata for a provider currently excluded from the chain."""
+
+     provider: str
+     excluded_at: float
+     profile: str
+     consecutive_failures: int
+
+
+ class SelfHealingOrchestrator:
+     """Manages provider exclusion, restart, and recovery probing.
+
+     Public API:
+
+     - :meth:`on_unhealthy(provider, ...)` — called by the engine when
+       a provider transitions to UNHEALTHY with action ``exclude``.
+     - :meth:`on_recovered(provider, ...)` — called when a recovery
+       probe succeeds or a regular request succeeds on a previously
+       excluded provider.
+     - :meth:`is_excluded(provider)` — True iff the provider is
+       currently excluded from the chain.
+     - :meth:`excluded_providers()` — set of currently excluded
+       provider names.
+     - :meth:`try_restart(provider_config, ...)` — attempt to restart
+       a provider's backend process.
+     - :meth:`reset()` — clear all state. Mainly for tests.
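+
+     A minimal interaction sketch (provider and profile names are
+     placeholders; return values follow the methods above)::
+
+         orch = SelfHealingOrchestrator()
+         orch.on_unhealthy("vllm", profile="default", consecutive_failures=3)
+         # → True (newly excluded); a repeat call returns False
+         orch.excluded_providers()   # → {"vllm"}
+         orch.on_recovered("vllm", profile="default")
+         # → seconds spent excluded, or None if not excluded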
+     """
+
+     def __init__(self) -> None:
+         self._lock: threading.RLock = threading.RLock()
+         self._excluded: dict[str, _ExcludedProvider] = {}
+         # Per-provider lock to prevent concurrent restart attempts.
+         self._restart_locks: dict[str, threading.Lock] = {}
+
+     # ------------------------------------------------------------------
+     # Exclusion management
+     # ------------------------------------------------------------------
+
+     def on_unhealthy(
+         self,
+         provider: str,
+         *,
+         profile: str,
+         consecutive_failures: int,
+     ) -> bool:
+         """Mark a provider as excluded from the chain.
+
+         Returns True if the provider was newly excluded (not already
+         excluded). Returns False if it was already excluded (idempotent).
+         """
+         with self._lock:
+             if provider in self._excluded:
+                 return False
+             self._excluded[provider] = _ExcludedProvider(
+                 provider=provider,
+                 excluded_at=time.monotonic(),
+                 profile=profile,
+                 consecutive_failures=consecutive_failures,
+             )
+             log_self_healing_exclude(
+                 logger,
+                 provider=provider,
+                 profile=profile,
+                 consecutive_failures=consecutive_failures,
+             )
+             return True
+
+     def on_recovered(
+         self,
+         provider: str,
+         *,
+         profile: str,
+     ) -> float | None:
+         """Restore a provider to the chain after recovery.
+
+         Returns the duration (seconds) the provider was excluded,
+         or None if the provider was not in the excluded set.
+         """
+         with self._lock:
+             entry = self._excluded.pop(provider, None)
+             if entry is None:
+                 return None
+             duration = time.monotonic() - entry.excluded_at
+             log_self_healing_restore(
+                 logger,
+                 provider=provider,
+                 profile=profile,
+                 excluded_duration_s=duration,
+             )
+             return duration
+
+     def is_excluded(self, provider: str) -> bool:
+         """True iff the provider is currently excluded from the chain."""
+         with self._lock:
+             return provider in self._excluded
+
+     def excluded_providers(self) -> set[str]:
+         """Return a snapshot of currently excluded provider names."""
+         with self._lock:
+             return set(self._excluded.keys())
+
+     def reset(self) -> None:
+         """Drop all state. Mainly for tests."""
+         with self._lock:
+             self._excluded.clear()
+             self._restart_locks.clear()
+
+     # ------------------------------------------------------------------
+     # v2.0-K: Persistence
+     # ------------------------------------------------------------------
+
+     def save_state(self) -> dict[str, object]:
+         """Export the current excluded-provider set for persistence."""
+         with self._lock:
+             return {
+                 name: {
+                     "profile": entry.profile,
+                     "consecutive_failures": entry.consecutive_failures,
+                 }
+                 for name, entry in self._excluded.items()
+             }
+
+     def load_state(self, state: dict[str, object]) -> None:
+         """Restore excluded providers from a previously saved dict.
+
+         Re-creates ``_ExcludedProvider`` entries with ``excluded_at``
+         set to the current time (the original exclude timestamp is
+         lost across restarts — the important thing is that the provider
+         *stays* excluded until a recovery probe succeeds).
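+
+         Illustrative state shape, mirroring :meth:`save_state` output
+         (provider and profile names are placeholders)::
+
+             {"vllm": {"profile": "default", "consecutive_failures": 3}}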
+         """
+         if not isinstance(state, dict):
+             return
+
+         with self._lock:
+             for name, data in state.items():
+                 if not isinstance(data, dict):
+                     continue
+                 if name in self._excluded:
+                     continue  # already excluded
+                 profile = data.get("profile", "")
+                 failures = data.get("consecutive_failures", 0)
+                 if not isinstance(failures, int):
+                     failures = 0
+                 self._excluded[name] = _ExcludedProvider(
+                     provider=name,
+                     excluded_at=time.monotonic(),
+                     profile=str(profile),
+                     consecutive_failures=failures,
+                 )
+
+     # ------------------------------------------------------------------
+     # Restart helper
+     # ------------------------------------------------------------------
+
+     def try_restart(
+         self,
+         provider_config: ProviderConfig,
+         *,
+         timeout_s: float = 30.0,
+     ) -> bool:
+         """Attempt to restart a provider's backend process.
+
+         Returns True if the restart command succeeded (exit code 0),
+         False otherwise (or if no restart_command is configured).
+
+         Thread-safe: only one restart per provider at a time. A
+         concurrent call returns False immediately without blocking.
+         """
+         command = provider_config.restart_command
+         if not command:
+             return False
+
+         provider = provider_config.name
+
+         # Get or create a per-provider lock.
+         with self._lock:
+             if provider not in self._restart_locks:
+                 self._restart_locks[provider] = threading.Lock()
+             restart_lock = self._restart_locks[provider]
+
+         # Non-blocking acquire — if another thread is already
+         # restarting this provider, we skip silently.
+         if not restart_lock.acquire(blocking=False):
+             return False
+
+         try:
+             result = subprocess.run(
+                 command,
+                 shell=True,
+                 capture_output=True,
+                 timeout=timeout_s,
+                 text=True,
+             )
+             success = result.returncode == 0
+             error = result.stderr.strip() if not success else None
+             log_self_healing_restart(
+                 logger,
+                 provider=provider,
+                 command=command,
+                 success=success,
+                 error=error,
+             )
+             return success
+         except subprocess.TimeoutExpired:
+             log_self_healing_restart(
+                 logger,
+                 provider=provider,
+                 command=command,
+                 success=False,
+                 error=f"timeout after {timeout_s}s",
+             )
+             return False
+         except OSError as exc:
+             log_self_healing_restart(
+                 logger,
+                 provider=provider,
+                 command=command,
+                 success=False,
+                 error=str(exc),
+             )
+             return False
+         finally:
+             restart_lock.release()
+
+
+ # ---------------------------------------------------------------------------
+ # Recovery probe loop (async, runs as a background task)
+ # ---------------------------------------------------------------------------
+
+
+ async def recovery_probe_loop(
+     provider_config: ProviderConfig,
+     *,
+     orchestrator: SelfHealingOrchestrator,
+     record_fn: object | None = None,
+     health_threshold: int = 3,
+     initial_interval_s: float = 30.0,
+     max_interval_s: float = 300.0,
+     restart_timeout_s: float = 30.0,
+     probe_timeout_s: float = 10.0,
+     shutdown_event: asyncio.Event | None = None,
+     profile: str = "",
+ ) -> None:
+     """Probe an excluded provider with exponential backoff until recovery.
+
+     This function runs as a long-lived asyncio task, one per excluded
+     provider. It terminates when:
+
+     - The provider recovers (probe succeeds) → restores to chain.
+     - The shutdown event is set → graceful exit.
+     - The provider is no longer excluded (external recovery).
+
+     On first invocation, attempts a restart if configured, then waits
+     for the initial interval before the first probe.
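+
+     The backoff recurrence is ``min(interval * 2, max_interval_s)``;
+     with the defaults this spaces failed probes at 30 s, 60 s, 120 s,
+     240 s, and then 300 s thereafter (illustrative arithmetic, not a
+     configuration guarantee).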
+     """
+     from coderouter.guards.continuous_probe import probe_one
+
+     _shutdown = shutdown_event or asyncio.Event()
+     provider_name = provider_config.name
+     interval = initial_interval_s
+
+     # Step 1: attempt restart if configured.
+     if provider_config.restart_command:
+         # Run in executor to avoid blocking the event loop.
+         loop = asyncio.get_running_loop()
+         await loop.run_in_executor(
+             None,
+             lambda: orchestrator.try_restart(
+                 provider_config,
+                 timeout_s=restart_timeout_s,
+             ),
+         )
+
+     # Step 2: exponential backoff recovery probes.
+     while not _shutdown.is_set():
+         # Check if still excluded (external code may have restored it).
+         if not orchestrator.is_excluded(provider_name):
+             return
+
+         # Wait for the current interval (or shutdown).
+         try:
+             await asyncio.wait_for(_shutdown.wait(), timeout=interval)
+             return  # shutdown signalled
+         except (TimeoutError, asyncio.TimeoutError):
+             # Normal: the interval elapsed. asyncio.TimeoutError is only
+             # an alias of the builtin TimeoutError on Python 3.11+, so
+             # catch both for 3.10 compatibility.
+             pass
+
+         # Still excluded? Probe.
+         if not orchestrator.is_excluded(provider_name):
+             return
+
+         result = await probe_one(provider_config, timeout_s=probe_timeout_s)
+
+         if result.success:
+             # Provider is back! Restore it.
+             orchestrator.on_recovered(provider_name, profile=profile)
+
+             # Feed success into the backend health state machine
+             # so it snaps back to HEALTHY.
+             if record_fn is not None:
+                 with contextlib.suppress(Exception):
+                     record_fn(  # type: ignore[operator]
+                         provider_name,
+                         success=True,
+                         threshold=health_threshold,
+                     )
+
+             log_self_healing_recovery_probe(
+                 logger,
+                 provider=provider_name,
+                 success=True,
+                 next_interval_s=0,
+                 latency_ms=result.latency_ms,
+             )
+             return
+
+         # Failed — exponential backoff.
+         next_interval = min(interval * 2, max_interval_s)
+         log_self_healing_recovery_probe(
+             logger,
+             provider=provider_name,
+             success=False,
+             next_interval_s=next_interval,
+             latency_ms=result.latency_ms,
+         )
+         interval = next_interval
+
+
+ __all__ = [
+     "SelfHealingOrchestrator",
+     "recovery_probe_loop",
+ ]
@@ -148,6 +148,44 @@ class ToolLoopDetection:
      """
 
 
+ @dataclass(frozen=True)
+ class ToolCountExceeded:
+     """The outcome of a total tool-call count check.
+
+     Returned by :func:`check_total_tool_count` when the conversation's
+     cumulative tool_use count exceeds the configured hard cap. This is
+     a safety valve against runaway agents that call many *different*
+     tools without looping (which L3's identical-streak detector misses).
+     """
+
+     total_count: int
+     """How many tool_use blocks the conversation currently contains."""
+     max_allowed: int
+     """The configured ceiling that was exceeded."""
+
+
+ class ToolCountExceededError(CodeRouterError):
+     """Raised when total tool-call count exceeds the hard cap.
+
+     The ingress converts this into a structured ``400`` response with
+     ``error: "tool_count_exceeded"`` so the client sees a programmable
+     failure rather than a 5xx.
+     """
+
+     def __init__(
+         self,
+         exceeded: ToolCountExceeded,
+         profile: str,
+     ) -> None:
+         super().__init__(
+             f"tool count exceeded on profile={profile!r}: "
+             f"{exceeded.total_count} tool calls exceed the limit of "
+             f"{exceeded.max_allowed}."
+         )
+         self.exceeded = exceeded
+         self.profile = profile
+
+
  class ToolLoopBreakError(CodeRouterError):
      """Raised when a loop is detected and the configured action is ``break``.
 
@@ -337,3 +375,36 @@ def inject_loop_break_hint(
      new_system = [*list(system), {"type": "text", "text": hint}]
 
      return request.model_copy(update={"system": new_system})
+
+
+ # ---------------------------------------------------------------------------
+ # Total tool-call count hard cap (v2.2)
+ # ---------------------------------------------------------------------------
+
+
+ def check_total_tool_count(
+     request: AnthropicRequest,
+     *,
+     max_calls: int,
+ ) -> ToolCountExceeded | None:
+     """Return a detection if total tool_use count exceeds ``max_calls``.
+
+     Unlike :func:`detect_tool_loop`, which catches *identical*
+     consecutive calls, this is a blunt hard cap on the cumulative
+     number of tool_use blocks across the entire conversation. It
+     catches runaway agents that cycle through many *different* tools
+     without ever repeating the same (name, args) pair — a pattern
+     that the streak-based L3 detector cannot see.
+
+     The default ceiling is 50 (configurable per profile). This is
+     deliberately more permissive than Unsloth Studio's 25 — Claude
+     Code's long-running agent sessions routinely reach 25+ tool calls
+     in normal operation.
+
+     Returns ``None`` when the count is within limits.
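+
+     Illustrative call site (``req`` is an ``AnthropicRequest``; the
+     profile name is a placeholder)::
+
+         exceeded = check_total_tool_count(req, max_calls=50)
+         if exceeded is not None:
+             raise ToolCountExceededError(exceeded, profile="default")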
+     """
+     history = _extract_tool_use_history(request)
+     count = len(history)
+     if count > max_calls:
+         return ToolCountExceeded(total_count=count, max_allowed=max_calls)
+     return None
@@ -28,7 +28,7 @@ from typing import Any
  from fastapi import APIRouter, Header, HTTPException, Request
  from fastapi.responses import JSONResponse, StreamingResponse
 
- from coderouter.guards.tool_loop import ToolLoopBreakError
+ from coderouter.guards.tool_loop import ToolCountExceededError, ToolLoopBreakError
  from coderouter.logging import get_logger
  from coderouter.routing import (
      FallbackEngine,
@@ -49,6 +49,7 @@ _MODE_HEADER = "x-coderouter-mode"
  _ANTHROPIC_VERSION_HEADER = "anthropic-version"
  _ANTHROPIC_BETA_HEADER = "anthropic-beta"
  _CTX_BUDGET_HEADER = "X-CodeRouter-Context-Budget"
+ _DRIFT_HEADER = "X-CodeRouter-Drift"
 
 
  @router.post("/messages", response_model=None)
@@ -144,6 +145,12 @@ async def messages(
          }
          if ctx_budget_status:
              stream_headers[_CTX_BUDGET_HEADER] = ctx_budget_status
+         # v2.0-G: for streaming we cannot know the drift verdict before
+         # the first chunk ships, and headers cannot be added after the
+         # stream starts (no trailer mechanism), so report the
+         # pre-existing drift state instead.
+         drift_severity = engine.last_drift_severity
+         if drift_severity:
+             stream_headers[_DRIFT_HEADER] = drift_severity
          return StreamingResponse(
              _anthropic_sse_iterator(engine, anth_req),
              media_type="text/event-stream",
@@ -166,11 +173,31 @@ async def messages(
              status_code=400,
              detail=_tool_loop_break_detail(exc),
          ) from exc
+     except ToolCountExceededError as exc:
+         # v2.2: total tool-call count exceeded — surface as a 400.
+         raise HTTPException(
+             status_code=400,
+             detail={
+                 "error": "tool_count_exceeded",
+                 "message": str(exc),
+                 "total_count": exc.exceeded.total_count,
+                 "max_allowed": exc.exceeded.max_allowed,
+                 "profile": exc.profile,
+             },
+         ) from exc
 
+     # v2.0-G: collect drift header after engine dispatch.
+     drift_severity = engine.last_drift_severity
+     resp_headers: dict[str, str] = {}
      if ctx_budget_status:
+         resp_headers[_CTX_BUDGET_HEADER] = ctx_budget_status
+     if drift_severity:
+         resp_headers[_DRIFT_HEADER] = drift_severity
+
+     if resp_headers:
          return JSONResponse(
              content=anth_resp.model_dump(exclude_none=True),
-             headers={_CTX_BUDGET_HEADER: ctx_budget_status},
+             headers=resp_headers,
          )
      return anth_resp.model_dump(exclude_none=True)
 
@@ -224,26 +251,93 @@ async def _anthropic_sse_iterator(
              },
          )
          yield _format_anthropic_sse(err_event)
-     except MidStreamError as exc:
-         # v0.3-B: a provider failed AFTER emitting at least one event. We
-         # cannot fall back (client already received partial content), so
-         # close the stream with an explicit error event. `api_error`
-         # distinguishes this from "no provider could start" (overloaded).
-         logger.warning(
-             "sse-midstream-error",
-             extra={"provider": exc.provider, "original": str(exc.original)},
-         )
+     except ToolCountExceededError as exc:
+         # v2.2: streaming counterpart of the tool-count-exceeded 400.
          err_event = AnthropicStreamEvent(
              type="error",
              data={
                  "type": "error",
                  "error": {
-                     "type": "api_error",
+                     "type": "invalid_request_error",
                      "message": str(exc),
+                     "tool_count": {
+                         "total_count": exc.exceeded.total_count,
+                         "max_allowed": exc.exceeded.max_allowed,
+                         "profile": exc.profile,
+                     },
                  },
              },
          )
          yield _format_anthropic_sse(err_event)
+     except MidStreamError as exc:
+         # v0.3-B: a provider failed AFTER emitting at least one event. We
+         # cannot fall back (client already received partial content), so
+         # close the stream with an explicit error event. `api_error`
+         # distinguishes this from "no provider could start" (overloaded).
+         logger.warning(
+             "sse-midstream-error",
+             extra={"provider": exc.provider, "original": str(exc.original)},
+         )
+
+         # v2.0-H (L6): partial stitch surface mode — synthesize a graceful
+         # stream termination that delivers accumulated text to the client.
+         profile_name = anth_req.profile or engine.config.default_profile
+         partial_action = "off"
+         try:
+             chain_cfg = engine.config.profile_by_name(profile_name)
+             partial_action = chain_cfg.partial_stitch_action
+         except (KeyError, ValueError):
+             pass
+
+         if partial_action == "surface" and exc.partial_content:
+ if partial_action == "surface" and exc.partial_content:
293
+ # Emit message_delta with accumulated usage (signals stream end).
294
+ yield _format_anthropic_sse(AnthropicStreamEvent(
295
+ type="message_delta",
296
+ data={
297
+ "type": "message_delta",
298
+ "delta": {"stop_reason": None, "stop_sequence": None},
299
+ "usage": {"output_tokens": 0},
300
+ },
301
+ ))
302
+ # Emit message_stop so the client sees a complete stream.
303
+ yield _format_anthropic_sse(AnthropicStreamEvent(
304
+ type="message_stop",
305
+ data={"type": "message_stop"},
306
+ ))
307
+ # Emit coderouter_partial metadata event (client-optional).
308
+ yield _format_anthropic_sse(AnthropicStreamEvent(
309
+ type="coderouter_partial",
310
+ data={
311
+ "type": "coderouter_partial",
312
+ "partial_content": exc.partial_content,
313
+ "provider": exc.provider,
314
+ "reason": "mid_stream_failure",
315
+ "original_error": str(exc.original)[:200],
316
+ },
317
+ ))
318
+ logger.info(
319
+ "partial-stitch-surfaced",
320
+ extra={
321
+ "provider": exc.provider,
322
+ "profile": profile_name,
323
+ "text_blocks": len(exc.partial_content),
324
+ "text_length": sum(
325
+ len(b.get("text", "")) for b in exc.partial_content
326
+ ),
327
+ },
328
+ )
329
+ else:
330
+ err_event = AnthropicStreamEvent(
331
+ type="error",
332
+ data={
333
+ "type": "error",
334
+ "error": {
335
+ "type": "api_error",
336
+ "message": str(exc),
337
+ },
338
+ },
339
+ )
340
+ yield _format_anthropic_sse(err_event)
247
341
 
248
342
 
249
343
  def _format_anthropic_sse(ev: AnthropicStreamEvent) -> str: