bareagent-cli 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (121):
  1. bareagent/__init__.py +10 -0
  2. bareagent/concurrency/__init__.py +6 -0
  3. bareagent/concurrency/background.py +97 -0
  4. bareagent/concurrency/notification.py +61 -0
  5. bareagent/concurrency/scheduler.py +136 -0
  6. bareagent/config.toml +299 -0
  7. bareagent/core/__init__.py +1 -0
  8. bareagent/core/config_paths.py +49 -0
  9. bareagent/core/context.py +127 -0
  10. bareagent/core/fileutil.py +103 -0
  11. bareagent/core/goal.py +214 -0
  12. bareagent/core/handlers/__init__.py +1 -0
  13. bareagent/core/handlers/bash.py +79 -0
  14. bareagent/core/handlers/file_edit.py +47 -0
  15. bareagent/core/handlers/file_read.py +270 -0
  16. bareagent/core/handlers/file_write.py +34 -0
  17. bareagent/core/handlers/glob_search.py +30 -0
  18. bareagent/core/handlers/goal.py +60 -0
  19. bareagent/core/handlers/grep_search.py +52 -0
  20. bareagent/core/handlers/memory.py +71 -0
  21. bareagent/core/handlers/plan.py +106 -0
  22. bareagent/core/handlers/search_utils.py +77 -0
  23. bareagent/core/handlers/skill.py +87 -0
  24. bareagent/core/handlers/subagent_send.py +70 -0
  25. bareagent/core/handlers/web_fetch.py +126 -0
  26. bareagent/core/handlers/web_search.py +165 -0
  27. bareagent/core/handlers/workflow.py +190 -0
  28. bareagent/core/loop.py +535 -0
  29. bareagent/core/retry.py +131 -0
  30. bareagent/core/sandbox.py +27 -0
  31. bareagent/core/schema.py +21 -0
  32. bareagent/core/tools.py +779 -0
  33. bareagent/core/workflow.py +517 -0
  34. bareagent/core/workflow_registry.py +219 -0
  35. bareagent/debug/__init__.py +0 -0
  36. bareagent/debug/interaction_log.py +263 -0
  37. bareagent/debug/viewer.html +1750 -0
  38. bareagent/debug/web_viewer.py +157 -0
  39. bareagent/hooks/__init__.py +32 -0
  40. bareagent/hooks/config.py +118 -0
  41. bareagent/hooks/engine.py +197 -0
  42. bareagent/hooks/errors.py +14 -0
  43. bareagent/hooks/events.py +22 -0
  44. bareagent/lsp/__init__.py +63 -0
  45. bareagent/lsp/config.py +134 -0
  46. bareagent/lsp/coord.py +118 -0
  47. bareagent/lsp/diagnostics.py +240 -0
  48. bareagent/lsp/errors.py +24 -0
  49. bareagent/lsp/manager.py +866 -0
  50. bareagent/lsp/tools.py +629 -0
  51. bareagent/lsp/workspace_edit.py +305 -0
  52. bareagent/main.py +4205 -0
  53. bareagent/mcp/__init__.py +69 -0
  54. bareagent/mcp/_sse.py +69 -0
  55. bareagent/mcp/client.py +341 -0
  56. bareagent/mcp/config.py +169 -0
  57. bareagent/mcp/errors.py +32 -0
  58. bareagent/mcp/manager.py +318 -0
  59. bareagent/mcp/protocol.py +187 -0
  60. bareagent/mcp/registry.py +557 -0
  61. bareagent/mcp/transport/__init__.py +15 -0
  62. bareagent/mcp/transport/base.py +149 -0
  63. bareagent/mcp/transport/http_legacy.py +192 -0
  64. bareagent/mcp/transport/http_streamable.py +217 -0
  65. bareagent/mcp/transport/stdio.py +202 -0
  66. bareagent/memory/__init__.py +1 -0
  67. bareagent/memory/compact.py +203 -0
  68. bareagent/memory/conversation_io.py +226 -0
  69. bareagent/memory/embedding.py +194 -0
  70. bareagent/memory/persistent.py +515 -0
  71. bareagent/memory/token_counter.py +67 -0
  72. bareagent/memory/token_tracker.py +262 -0
  73. bareagent/memory/transcript.py +100 -0
  74. bareagent/permission/__init__.py +1 -0
  75. bareagent/permission/guard.py +329 -0
  76. bareagent/permission/rules.py +19 -0
  77. bareagent/planning/__init__.py +19 -0
  78. bareagent/planning/agent_types.py +169 -0
  79. bareagent/planning/skill_gen.py +141 -0
  80. bareagent/planning/skill_store.py +173 -0
  81. bareagent/planning/skills.py +146 -0
  82. bareagent/planning/subagent.py +355 -0
  83. bareagent/planning/subagent_registry.py +77 -0
  84. bareagent/planning/tasks.py +348 -0
  85. bareagent/planning/todo.py +153 -0
  86. bareagent/planning/worktree.py +122 -0
  87. bareagent/provider/__init__.py +1 -0
  88. bareagent/provider/anthropic.py +348 -0
  89. bareagent/provider/base.py +136 -0
  90. bareagent/provider/factory.py +130 -0
  91. bareagent/provider/openai.py +881 -0
  92. bareagent/provider/presets.py +72 -0
  93. bareagent/provider/setup.py +356 -0
  94. bareagent/skills/.gitkeep +1 -0
  95. bareagent/skills/code-review/SKILL.md +68 -0
  96. bareagent/skills/git/SKILL.md +68 -0
  97. bareagent/skills/test/SKILL.md +70 -0
  98. bareagent/team/__init__.py +17 -0
  99. bareagent/team/autonomous.py +193 -0
  100. bareagent/team/mailbox.py +239 -0
  101. bareagent/team/manager.py +155 -0
  102. bareagent/team/protocols.py +129 -0
  103. bareagent/tracing/__init__.py +12 -0
  104. bareagent/tracing/_api.py +92 -0
  105. bareagent/tracing/_proxy.py +60 -0
  106. bareagent/tracing/composite.py +115 -0
  107. bareagent/tracing/json_file.py +115 -0
  108. bareagent/tracing/langfuse.py +139 -0
  109. bareagent/tracing/otel.py +107 -0
  110. bareagent/tracing/setup.py +85 -0
  111. bareagent/ui/__init__.py +24 -0
  112. bareagent/ui/console.py +167 -0
  113. bareagent/ui/prompt.py +78 -0
  114. bareagent/ui/protocol.py +24 -0
  115. bareagent/ui/stream.py +66 -0
  116. bareagent/ui/theme.py +240 -0
  117. bareagent_cli-0.1.0.dist-info/METADATA +331 -0
  118. bareagent_cli-0.1.0.dist-info/RECORD +121 -0
  119. bareagent_cli-0.1.0.dist-info/WHEEL +4 -0
  120. bareagent_cli-0.1.0.dist-info/entry_points.txt +2 -0
  121. bareagent_cli-0.1.0.dist-info/licenses/LICENSE +21 -0
bareagent/core/loop.py ADDED
@@ -0,0 +1,535 @@
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ import os
5
+ import time
6
+ from collections.abc import Callable
7
+ from typing import Any, cast
8
+
9
+ from bareagent.concurrency.notification import inject_notifications
10
+ from bareagent.core.fileutil import stringify
11
+ from bareagent.core.retry import RetryPolicy, run_with_retry
12
+ from bareagent.provider.base import BaseLLMProvider, LLMResponse, StreamEvent, ToolCall
13
+ from bareagent.tracing import tracer as global_tracer
14
+ from bareagent.ui.protocol import StreamProtocol, UIProtocol
15
+ from bareagent.ui.stream import StreamPrinter
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+
20
class LLMCallError(Exception):
    """Signal that the agent loop could not produce a final answer.

    Raised in two situations: a provider call failed with an ``Exception``
    (the original error is attached as ``__cause__``), or the loop used up
    its iteration budget without the model finishing.
    """
22
+
23
+
24
+ class _StreamingUnavailableError(RuntimeError):
25
+ """Raised when streaming is explicitly unsupported before any events arrive."""
26
+
27
+
28
def agent_loop(
    provider: BaseLLMProvider,
    messages: list[dict[str, Any]],
    tools: list[dict[str, Any]],
    handlers: dict[str, Callable[..., Any]],
    permission: Any = None,
    compact_fn: Callable[[list[dict[str, Any]]], None] | None = None,
    bg_manager: Any = None,
    stream: bool = False,
    console: UIProtocol | None = None,
    max_iterations: int = 200,
    interaction_logger: Any = None,
    token_tracker: Any = None,
    hook_engine: Any = None,
    retry_policy: RetryPolicy | None = None,
    skill_gen: Any = None,
) -> str:
    """Drive one user turn: call the LLM, run requested tools, repeat until done.

    Each iteration (at most ``max_iterations``):

    1. injects background-task notifications from ``bg_manager`` and then runs
       the compaction callback on ``messages``;
    2. calls the provider through ``_invoke_provider`` (streaming when
       ``stream`` is true, retried per ``retry_policy``), recording a tracing
       span, token usage and a debug-log entry;
    3. appends the assistant message; if it contains no tool calls, returns
       its text;
    4. otherwise runs every tool call (subject to ``permission`` and the
       ``hook_engine`` PreToolUse hook) and appends all results as one
       ``{"role": "user", "content": [...tool_result blocks...]}`` message.

    ``messages`` is mutated in place.

    Args:
        provider: LLM provider used for every call in the turn.
        messages: Conversation history; appended to in place.
        tools: Tool schemas passed to the provider.
        handlers: Maps tool name to a callable invoked as
            ``handler(**call.input)``.
        permission: Optional object exposing ``requires_confirm(name, input)``
            and ``ask_user(call)``. When it is ``None``, every tool call is
            treated as needing confirmation and is denied. A call that needs
            confirmation is also denied when there is no callable
            ``ask_user``.
        compact_fn: Optional in-place compaction of ``messages``. It may also
            carry a ``get_session_id()`` attribute used for hook payloads.
        bg_manager: Optional background manager whose notifications are
            injected before each LLM call.
        stream: Request a streaming response. The call falls back to
            non-streaming when the provider signals ``NotImplementedError``.
        console: Optional UI for assistant text, tool calls/results, status
            and errors.
        max_iterations: Maximum number of LLM calls for this turn.
        interaction_logger: Optional debug logger (``log_request`` /
            ``log_response``). Its failures are reported, not raised.
        token_tracker: Optional object whose ``record(response, model_name)``
            is called after each successful LLM call.
        hook_engine: Optional engine exposing ``run_pre_tool_use`` and
            ``run_post_tool_use``.
        retry_policy: Optional retry policy for the provider call.
        skill_gen: Optional object whose ``note_turn(n)`` receives the number
            of tool calls made, only when the turn completes normally.

    Returns:
        The final assistant text, or ``""`` when it had none.

    Raises:
        LLMCallError: when a provider call raises an ``Exception`` (chained as
            ``__cause__``) or when ``max_iterations`` is exhausted.
        BaseException: non-``Exception`` errors from the provider call (e.g.
            ``KeyboardInterrupt``) are logged, then propagate unchanged.
    """
    # Fall back to a no-op compaction when the caller supplied none.
    compact = compact_fn or (lambda _messages: None)
    hook_session_id = _resolve_hook_session_id(compact_fn)
    # Read once per turn; every hook call in this turn reports this cwd.
    hook_cwd = os.getcwd()
    # Tool calls made during THIS user turn (accumulated across iterations).
    # Fed to ``skill_gen`` only on normal completion so a failed/aborted turn
    # never counts toward experiential skill generation. Sub-agents pass
    # skill_gen=None, keeping generation a main-loop-only concern (like hooks).
    turn_tool_calls = 0

    for _iteration in range(max_iterations):
        # Background notifications are injected before compaction runs.
        _run_background(bg_manager, messages)
        compact(messages)

        # (None, 0.0) when logging is disabled or log_request failed; the
        # matching _safe_log_response calls below are then no-ops.
        log_seq, log_started_at = _safe_log_request(
            interaction_logger=interaction_logger,
            messages=messages,
            tools=tools,
            provider=provider,
            console=console,
        )

        model_name = getattr(provider, "model", "unknown")
        with global_tracer.trace("llm_call", tags={"model": model_name}) as llm_span:
            try:
                response, streamed_output, displayed_tool_calls = _invoke_provider(
                    provider=provider,
                    messages=messages,
                    tools=tools,
                    stream=stream,
                    console=console,
                    retry_policy=retry_policy,
                )
            except BaseException as exc:
                # Record every failure (interrupts included) on the span and in
                # the debug log before deciding how to propagate it.
                llm_span.set_error(str(exc) or type(exc).__name__)
                _safe_log_response(
                    interaction_logger=interaction_logger,
                    log_seq=log_seq,
                    console=console,
                    duration_ms=(time.monotonic() - log_started_at) * 1000,
                    error=str(exc) or type(exc).__name__,
                )
                # KeyboardInterrupt / SystemExit etc. are re-raised unchanged.
                if not isinstance(exc, Exception):
                    raise
                msg = f"LLM call failed: {type(exc).__name__}: {exc}"
                if console is not None:
                    console.print_error(msg)
                raise LLMCallError(msg) from exc

            llm_span.set_tag("input_tokens", response.input_tokens)
            llm_span.set_tag("output_tokens", response.output_tokens)
            llm_span.set_content_tag("output", response.text)

        # Aggregate token usage here so both streaming and non-streaming paths
        # (both return through _invoke_provider) are covered at a single point.
        if token_tracker is not None:
            token_tracker.record(response, model_name)

        _safe_log_response(
            interaction_logger=interaction_logger,
            log_seq=log_seq,
            console=console,
            text=response.text,
            thinking=response.thinking,
            tool_calls=_serialize_tool_calls(response.tool_calls),
            input_tokens=response.input_tokens,
            output_tokens=response.output_tokens,
            duration_ms=(time.monotonic() - log_started_at) * 1000,
        )

        messages.append(response.to_message())
        # Text already rendered during streaming is not printed a second time.
        if response.text and console is not None and not streamed_output:
            console.print_assistant(response.text)
        # No tool calls: the model has finished this turn.
        if not response.has_tool_calls:
            if not response.text:
                _warn_empty_response(response, console)
            if skill_gen is not None:
                skill_gen.note_turn(turn_tool_calls)
            return response.text or ""

        turn_tool_calls += len(response.tool_calls)
        # One tool_result per tool call, in call order; appended below as a
        # single user message.
        results: list[dict[str, Any]] = []
        for call in response.tool_calls:
            # Tool calls already printed while streaming are skipped here.
            if console is not None and call.id not in displayed_tool_calls:
                console.print_tool_call(call.name, call.input)

            # Gate 1: permission. A denied call produces an error result and
            # is not executed.
            if _requires_confirmation(permission, call):
                if not _ask_permission(permission, call):
                    denied = "User denied."
                    if console is not None:
                        console.print_tool_result(call.name, denied)
                    results.append(_tool_result(call.id, denied, is_error=True))
                    continue

            # Gate 2: a PreToolUse hook may block the call.
            if hook_engine is not None:
                outcome = hook_engine.run_pre_tool_use(
                    call.name,
                    call.input,
                    session_id=hook_session_id,
                    cwd=hook_cwd,
                )
                if outcome.block:
                    blocked = outcome.reason or "Blocked by PreToolUse hook."
                    if console is not None:
                        console.print_tool_result(call.name, blocked)
                    results.append(_tool_result(call.id, blocked, is_error=True))
                    continue

            handler = handlers.get(call.name)
            if handler is None:
                output = f"Unknown tool: {call.name}"
                if console is not None:
                    console.print_tool_result(call.name, output)
                results.append(_tool_result(call.id, output, is_error=True))
                continue

            # A handler exception becomes an error result for the model rather
            # than aborting the loop; the PostToolUse hook is not run for it.
            try:
                with global_tracer.trace("tool_execution", tags={"tool": call.name}) as tool_span:
                    tool_span.set_content_tag("input", call.input)
                    output = handler(**call.input)
                    tool_span.set_content_tag("output", stringify(output))
            except Exception as exc:
                output = f"Error: {type(exc).__name__}: {exc}"
                if console is not None:
                    console.print_tool_result(call.name, output)
                results.append(_tool_result(call.id, output, is_error=True))
                continue

            if hook_engine is not None:
                hook_engine.run_post_tool_use(
                    call.name,
                    call.input,
                    output,
                    is_error=False,
                    session_id=hook_session_id,
                    cwd=hook_cwd,
                )

            if console is not None:
                console.print_tool_result(call.name, output)
            results.append(_tool_result(call.id, output))

        messages.append({"role": "user", "content": results})

    msg = f"Agent loop exceeded {max_iterations} iterations"
    if console is not None:
        console.print_error(msg)
    raise LLMCallError(msg)
192
+
193
+
194
def _warn_empty_response(response: LLMResponse, console: UIProtocol | None) -> None:
    """Emit a diagnostic for a response that has neither text nor tool calls.

    This usually indicates a wire_api/model mismatch, a relay that returned an
    empty output array, or a model that declined to answer. The warning is
    purely informational; the caller's control flow is unchanged and it still
    returns ``""``. The message always goes to the module logger, so
    console-less sub-agents/teammates leave a trace. It is additionally shown
    via ``console.print_status`` when a console is attached.
    """
    details = f"stop_reason={response.stop_reason!r}, output_tokens={response.output_tokens}. "
    warning_text = (
        "LLM returned an empty response (no text or tool calls) -- "
        + details
        + "Possible wire_api/model mismatch or relay issue."
    )
    logger.warning(warning_text)
    if console is None:
        return
    console.print_status(warning_text)
212
+
213
+
214
def _invoke_provider(
    provider: BaseLLMProvider,
    messages: list[dict[str, Any]],
    tools: list[dict[str, Any]],
    *,
    stream: bool,
    console: UIProtocol | None,
    retry_policy: RetryPolicy | None = None,
) -> tuple[LLMResponse, bool, set[str]]:
    """Perform one logical provider call, optionally under a retry policy.

    A single attempt covers the whole call, including consumption of the
    stream. With a policy, attempts go through ``run_with_retry`` so transient
    failures (429 / 5xx / connection timeouts) back off and retry. Without a
    policy, or with a disabled one, exactly one attempt is made.

    The streaming-fallback signals (``_StreamingUnavailableError`` /
    ``NotImplementedError``) have no status code and class names outside the
    retryable set, so the retry layer never retries them and the fallback is
    unaffected.

    Returns:
        ``(response, streamed_output, displayed_tool_call_ids)``.
    """

    def attempt() -> tuple[LLMResponse, bool, set[str]]:
        return _invoke_provider_once(
            provider=provider,
            messages=messages,
            tools=tools,
            stream=stream,
            console=console,
        )

    if retry_policy is not None:
        notifier = _make_retry_notifier(console, retry_policy)
        return run_with_retry(attempt, retry_policy, on_retry=notifier)
    return attempt()
247
+
248
+
249
+ def _make_retry_notifier(
250
+ console: UIProtocol | None,
251
+ policy: RetryPolicy,
252
+ ) -> Callable[[BaseException, int, float], None] | None:
253
+ if console is None:
254
+ return None
255
+
256
+ def _notify(exc: BaseException, next_attempt: int, delay: float) -> None:
257
+ console.print_status(
258
+ f"LLM call failed ({type(exc).__name__}), retrying in {delay:.1f}s "
259
+ f"(attempt {next_attempt}/{policy.max_attempts})..."
260
+ )
261
+
262
+ return _notify
263
+
264
+
265
+ def _invoke_provider_once(
266
+ provider: BaseLLMProvider,
267
+ messages: list[dict[str, Any]],
268
+ tools: list[dict[str, Any]],
269
+ *,
270
+ stream: bool,
271
+ console: UIProtocol | None,
272
+ ) -> tuple[LLMResponse, bool, set[str]]:
273
+ if not stream:
274
+ return provider.create(messages=messages, tools=tools), False, set()
275
+
276
+ try:
277
+ stream_iter = provider.create_stream(messages=messages, tools=tools)
278
+ except Exception as exc:
279
+ if not _is_streaming_unsupported(exc):
280
+ raise
281
+ return _fallback_to_non_stream(
282
+ provider=provider,
283
+ messages=messages,
284
+ tools=tools,
285
+ console=console,
286
+ exc=exc,
287
+ )
288
+
289
+ try:
290
+ return _consume_stream(stream_iter, console=console)
291
+ except _StreamingUnavailableError as exc:
292
+ cause = exc.__cause__ or exc
293
+ return _fallback_to_non_stream(
294
+ provider=provider,
295
+ messages=messages,
296
+ tools=tools,
297
+ console=console,
298
+ exc=cause,
299
+ )
300
+
301
+
302
+ def _fallback_to_non_stream(
303
+ provider: BaseLLMProvider,
304
+ messages: list[dict[str, Any]],
305
+ tools: list[dict[str, Any]],
306
+ *,
307
+ console: UIProtocol | None,
308
+ exc: BaseException,
309
+ ) -> tuple[LLMResponse, bool, set[str]]:
310
+ if console is not None:
311
+ console.print_status(
312
+ f"Streaming unavailable, falling back to non-stream mode ({type(exc).__name__})."
313
+ )
314
+ return provider.create(messages=messages, tools=tools), False, set()
315
+
316
+
317
+ def _is_streaming_unsupported(exc: Exception) -> bool:
318
+ return isinstance(exc, NotImplementedError)
319
+
320
+
321
def _consume_stream(
    stream_iter: Any,
    *,
    console: UIProtocol | None,
) -> tuple[LLMResponse, bool, set[str]]:
    """Drain a provider stream, rendering text and tool calls as they arrive.

    ``stream_iter`` is advanced with ``next()``. The final ``LLMResponse`` is
    the ``StopIteration.value``, i.e. the streaming generator's ``return``
    value.

    The stream printer is finished exactly once on every exit path. That
    includes non-``Exception`` interrupts such as ``KeyboardInterrupt``, so
    partially streamed output is always closed out.

    Returns:
        ``(response, streamed_output, displayed_tool_calls)``:
        ``streamed_output`` is True if any text was shown while streaming, and
        ``displayed_tool_calls`` holds the ids of tool calls already printed.

    Raises:
        _StreamingUnavailableError: the stream raised a streaming-unsupported
            error before yielding any event (chained as ``__cause__``).
        RuntimeError: the stream ended without returning a response.
        Any other exception raised by the stream or by event handling.
    """
    printer = _get_stream_printer(console)
    displayed_tool_calls: set[str] = set()
    saw_stream_event = False
    streamed_any_text = False
    response: LLMResponse | None = None
    printer.start()

    try:
        while True:
            try:
                event = next(stream_iter)
            except StopIteration as stop:
                response = stop.value
                break

            saw_stream_event = True
            if event.type == "text" and bool(event.text):
                streamed_any_text = True
            _handle_stream_event(
                event=event,
                printer=printer,
                console=console,
                displayed_tool_calls=displayed_tool_calls,
            )
    except Exception as exc:
        printer.finish()
        if not saw_stream_event and _is_streaming_unsupported(exc):
            raise _StreamingUnavailableError() from exc
        raise
    except BaseException:
        # KeyboardInterrupt / SystemExit etc.: still close the live output so
        # the terminal is left clean, then propagate unchanged.
        printer.finish()
        raise

    # Normal end of stream: finish once. This is outside the try, so a
    # missing response no longer triggers a second finish() in the handler.
    streamed_text = printer.finish()
    if response is None:
        raise RuntimeError("Streaming provider did not return a response.")
    return (
        response,
        streamed_any_text or bool(streamed_text),
        displayed_tool_calls,
    )
361
+
362
+
363
+ def _handle_stream_event(
364
+ event: StreamEvent,
365
+ *,
366
+ printer: StreamProtocol,
367
+ console: UIProtocol | None,
368
+ displayed_tool_calls: set[str],
369
+ ) -> None:
370
+ if event.type == "text":
371
+ printer.feed(event.text)
372
+ return
373
+
374
+ if event.type != "tool_call":
375
+ return
376
+
377
+ printer.finish()
378
+ if event.tool_call_id:
379
+ displayed_tool_calls.add(event.tool_call_id)
380
+ if console is not None:
381
+ console.print_tool_call(event.name, event.input)
382
+
383
+
384
+ def _get_stream_printer(console: UIProtocol | None) -> StreamProtocol:
385
+ if console is None:
386
+ return StreamPrinter()
387
+
388
+ get_stream_printer = getattr(console, "get_stream_printer", None)
389
+ if callable(get_stream_printer):
390
+ return cast(StreamProtocol, get_stream_printer())
391
+
392
+ # Backward compatibility for older console duck types that exposed `.console`
393
+ # but not a `get_stream_printer()` hook.
394
+ return StreamPrinter(getattr(console, "console", None))
395
+
396
+
397
+ def _tool_result(
398
+ tool_use_id: str,
399
+ output: str | list[dict[str, Any]] | Any,
400
+ *,
401
+ is_error: bool = False,
402
+ ) -> dict[str, Any]:
403
+ """Wrap a handler output into a tool_result block.
404
+
405
+ The ``output`` may be:
406
+ - ``str`` — used as-is.
407
+ - ``list[dict]`` — passed through verbatim (multimodal MCP path: text + image blocks).
408
+ - anything else — coerced via :func:`stringify`.
409
+ """
410
+ content: Any
411
+ if isinstance(output, list):
412
+ content = output
413
+ else:
414
+ content = stringify(output)
415
+ result: dict[str, Any] = {
416
+ "type": "tool_result",
417
+ "tool_use_id": tool_use_id,
418
+ "content": content,
419
+ }
420
+ if is_error:
421
+ result["is_error"] = True
422
+ return result
423
+
424
+
425
+ def _provider_info(provider: BaseLLMProvider) -> dict[str, Any]:
426
+ info: dict[str, Any] = {"provider_type": type(provider).__name__}
427
+ for name in ("model", "base_url", "wire_api"):
428
+ value = getattr(provider, name, None)
429
+ if value not in {None, ""}:
430
+ info[name] = value
431
+ return info
432
+
433
+
434
+ def _serialize_tool_calls(tool_calls: list[ToolCall]) -> list[dict[str, Any]]:
435
+ return [
436
+ {
437
+ "id": tool_call.id,
438
+ "name": tool_call.name,
439
+ "input": tool_call.input,
440
+ }
441
+ for tool_call in tool_calls
442
+ ]
443
+
444
+
445
+ def _resolve_hook_session_id(compact_fn: Any) -> str:
446
+ """Best-effort session id for hook JSON payloads.
447
+
448
+ Reuses the ``get_session_id`` attribute the REPL attaches to ``compact_fn``
449
+ (see ``main.py:_build_loop_compact``) rather than threading a new parameter
450
+ through every caller. Falls back to ``"default"`` when unavailable (tests,
451
+ sub-agents) — hooks don't run for sub-agents anyway.
452
+ """
453
+ getter = getattr(compact_fn, "get_session_id", None)
454
+ if callable(getter):
455
+ try:
456
+ return str(getter())
457
+ except Exception:
458
+ return "default"
459
+ return "default"
460
+
461
+
462
+ def _run_background(bg_manager: Any, messages: list[dict[str, Any]]) -> None:
463
+ if bg_manager is None:
464
+ return
465
+ inject_notifications(messages, bg_manager)
466
+
467
+
468
+ def _requires_confirmation(permission: Any, call: ToolCall) -> bool:
469
+ if permission is None:
470
+ return True
471
+
472
+ requires_confirm = getattr(permission, "requires_confirm", None)
473
+ if callable(requires_confirm):
474
+ return bool(requires_confirm(call.name, call.input))
475
+ return True
476
+
477
+
478
+ def _ask_permission(permission: Any, call: ToolCall) -> bool:
479
+ ask_user = getattr(permission, "ask_user", None)
480
+ if callable(ask_user):
481
+ return bool(ask_user(call))
482
+ return False
483
+
484
+
485
+ def _safe_log_request(
486
+ *,
487
+ interaction_logger: Any,
488
+ messages: list[dict[str, Any]],
489
+ tools: list[dict[str, Any]],
490
+ provider: BaseLLMProvider,
491
+ console: UIProtocol | None,
492
+ ) -> tuple[int | None, float]:
493
+ if interaction_logger is None:
494
+ return None, 0.0
495
+
496
+ try:
497
+ log_seq = interaction_logger.log_request(
498
+ messages,
499
+ tools,
500
+ provider_info=_provider_info(provider),
501
+ )
502
+ except Exception as exc:
503
+ _report_log_failure(console, "request", exc)
504
+ return None, 0.0
505
+
506
+ return log_seq, time.monotonic()
507
+
508
+
509
+ def _safe_log_response(
510
+ *,
511
+ interaction_logger: Any,
512
+ log_seq: int | None,
513
+ console: UIProtocol | None,
514
+ **kwargs: Any,
515
+ ) -> None:
516
+ if interaction_logger is None or log_seq is None:
517
+ return
518
+
519
+ try:
520
+ interaction_logger.log_response(log_seq, **kwargs)
521
+ except Exception as exc:
522
+ _report_log_failure(console, "response", exc)
523
+
524
+
525
+ def _report_log_failure(
526
+ console: UIProtocol | None,
527
+ phase: str,
528
+ exc: Exception,
529
+ ) -> None:
530
+ if console is None:
531
+ return
532
+
533
+ console.print_status(
534
+ f"Debug logging failed during {phase} capture ({type(exc).__name__}: {exc})."
535
+ )
bareagent/core/retry.py ADDED
@@ -0,0 +1,131 @@
1
+ """Provider-agnostic LLM retry policy: exponential backoff + retryable classification.
2
+
3
+ A pure module (no LLM / loop / SDK dependencies) so the policy, classifier, and
4
+ backoff math are fully unit-testable with injected ``sleep`` / ``rng``. The
5
+ agent loop wraps the single provider call site (``_invoke_provider``) with
6
+ :func:`run_with_retry`; the SDK clients are constructed with ``max_retries=0``
7
+ so this layer owns retries exclusively (no 2xN compound amplification).
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import logging
13
+ import random
14
+ import time
15
+ from collections.abc import Callable
16
+ from dataclasses import dataclass
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+ # Retryable HTTP status codes (aligned with the anthropic/openai SDKs' own
21
+ # retryable set + 529 overloaded).
22
+ _RETRYABLE_STATUS = frozenset({408, 409, 429, 500, 502, 503, 504, 529})
23
+ # Non-retryable status codes (auth / bad request / model-not-found, etc.) —
24
+ # raise immediately so a config error is never masked by retries.
25
+ _NON_RETRYABLE_STATUS = frozenset({400, 401, 403, 404, 413, 422})
26
+ # Retryable connection / timeout / server classes recognized by class name
27
+ # (no SDK import, so this stays cross-provider).
28
+ _RETRYABLE_NAMES = frozenset(
29
+ {
30
+ "APIConnectionError",
31
+ "APITimeoutError",
32
+ "APIConnectionTimeoutError",
33
+ "InternalServerError",
34
+ "OverloadedError",
35
+ "ServiceUnavailableError",
36
+ "ConnectionError",
37
+ "Timeout",
38
+ "TimeoutError",
39
+ "ReadTimeout",
40
+ "ConnectTimeout",
41
+ }
42
+ )
43
+
44
+
45
+ @dataclass(slots=True)
46
+ class RetryPolicy:
47
+ enabled: bool = True
48
+ max_attempts: int = 3 # total attempts (incl. first), <=1 disables retries
49
+ base_delay_sec: float = 1.0
50
+ max_delay_sec: float = 30.0
51
+ multiplier: float = 2.0
52
+ jitter: bool = True
53
+
54
+
55
def is_retryable(exc: BaseException) -> bool:
    """Return whether *exc* looks like a transient failure worth retrying.

    Provider-agnostic: nothing is imported from any SDK. Classification:

    * Non-``Exception`` (``KeyboardInterrupt``, ``SystemExit``, ...) is never
      retried.
    * An integer HTTP status is taken from ``exc.status_code``, else
      ``exc.status``, else ``exc.response.status_code`` (the shape used by
      e.g. ``requests.HTTPError`` / ``httpx.HTTPStatusError``). The status is
      then classified as:
      - in ``_NON_RETRYABLE_STATUS``: not retryable;
      - in ``_RETRYABLE_STATUS``: retryable;
      - any other 5xx: retryable;
      - anything else: not retryable.
    * Without an integer status, the exception is retryable when its class or
      any base class is named in ``_RETRYABLE_NAMES``.
    * Unknown exceptions are not retryable.
    """
    # Non-Exception (KeyboardInterrupt / SystemExit) is never retried.
    if not isinstance(exc, Exception):
        return False

    status = getattr(exc, "status_code", None)
    if status is None:
        status = getattr(exc, "status", None)
    if not isinstance(status, int):
        # Some HTTP client errors keep the code only on the attached response
        # object; without this their 429/5xx responses were never retried.
        response = getattr(exc, "response", None)
        status = getattr(response, "status_code", None)

    if isinstance(status, int):
        if status in _NON_RETRYABLE_STATUS:
            return False
        if status in _RETRYABLE_STATUS:
            return True
        # Any other 5xx is retryable; everything else is explicitly not.
        return 500 <= status < 600

    # No status code: match connection / timeout classes by name (including MRO).
    return any(klass.__name__ in _RETRYABLE_NAMES for klass in type(exc).__mro__)
78
+
79
+
80
def compute_delay(
    attempt: int,
    policy: RetryPolicy,
    rng: Callable[[float, float], float] = random.uniform,
) -> float:
    """Exponential backoff + cap + optional full jitter.

    ``attempt`` starts at 1 (the wait before the first retry). The un-jittered
    delay is ``base_delay_sec * multiplier ** (attempt - 1)``, with the
    exponent floored at 0, capped at ``max_delay_sec``. With ``policy.jitter``
    the result is ``rng(0.0, capped)`` (full jitter); otherwise it is the
    capped value itself.

    Args:
        attempt: 1-based retry number.
        policy: Supplies ``base_delay_sec``, ``multiplier``, ``max_delay_sec``
            and ``jitter``.
        rng: ``(low, high) -> float`` sampler; injectable for tests.

    Returns:
        Seconds to wait before the next attempt. Very large ``attempt`` values
        saturate at the cap instead of raising ``OverflowError``.
    """
    exponent = max(0, attempt - 1)
    try:
        raw = policy.base_delay_sec * (policy.multiplier ** exponent)
    except OverflowError:
        # float ** int raises instead of returning inf once the result leaves
        # the float range (e.g. 2.0 ** 1024). Inside run_with_retry that error
        # would replace the real failure; the value would be capped anyway.
        raw = float("inf") if policy.base_delay_sec > 0 else 0.0
    capped = min(policy.max_delay_sec, raw)
    if policy.jitter:
        return rng(0.0, capped)
    return capped
94
+
95
+
96
+ def run_with_retry[T](
97
+ fn: Callable[[], T],
98
+ policy: RetryPolicy,
99
+ *,
100
+ on_retry: Callable[[BaseException, int, float], None] | None = None,
101
+ sleep: Callable[[float], None] = time.sleep,
102
+ rng: Callable[[float, float], float] = random.uniform,
103
+ ) -> T:
104
+ """Run ``fn``, backing off and retrying retryable exceptions per ``policy``.
105
+
106
+ - Non-retryable exceptions / non-Exception (KeyboardInterrupt, etc.) re-raise immediately.
107
+ - After exhausting ``max_attempts``, re-raises the **last** original exception
108
+ (preserving ``__cause__`` is the caller's responsibility).
109
+ - ``on_retry(exc, next_attempt, delay)`` is invoked before each sleep (observability).
110
+ """
111
+ if not policy.enabled or policy.max_attempts <= 1:
112
+ return fn()
113
+ attempt = 0
114
+ while True:
115
+ attempt += 1
116
+ try:
117
+ return fn()
118
+ except BaseException as exc: # noqa: BLE001 - must propagate non-Exception
119
+ if attempt >= policy.max_attempts or not is_retryable(exc):
120
+ raise
121
+ delay = compute_delay(attempt, policy, rng=rng)
122
+ if on_retry is not None:
123
+ on_retry(exc, attempt + 1, delay)
124
+ logger.warning(
125
+ "LLM call failed (%s), retrying in %.2fs (attempt %d/%d)",
126
+ type(exc).__name__,
127
+ delay,
128
+ attempt + 1,
129
+ policy.max_attempts,
130
+ )
131
+ sleep(delay)
bareagent/core/sandbox.py ADDED
@@ -0,0 +1,27 @@
1
+ from __future__ import annotations
2
+
3
+ from pathlib import Path
4
+
5
+
6
def safe_path(path: str, workspace: Path) -> Path:
    """Resolve a workspace-relative *path* and ensure it stays within the workspace.

    Returns:
        The resolved absolute path (it need not exist).

    Raises:
        PermissionError: if *path* starts with ``~``, is absolute, resolves
            outside the workspace, or any component under the workspace is a
            symlink.
    """
    if path.startswith("~"):
        raise PermissionError(f"Home-relative paths are not allowed: {path!r}")
    root = workspace.resolve(strict=False)
    relative = Path(path)
    if relative.is_absolute():
        raise PermissionError(f"Absolute paths are not allowed: {path!r}")
    target = (root / relative).resolve(strict=False)
    if target.is_relative_to(root):
        _check_no_symlink_in_chain(root, relative)
        return target
    raise PermissionError(f"Path {path!r} escapes workspace {root}")
19
+
20
+
21
+ def _check_no_symlink_in_chain(workspace: Path, candidate: Path) -> None:
22
+ """Walk each component of *candidate* under *workspace* and reject symlinks."""
23
+ current = workspace
24
+ for part in candidate.parts:
25
+ current = current / part
26
+ if current.is_symlink():
27
+ raise PermissionError(f"Symlink detected in path chain: {current}")