ripperdoc 0.2.0__py3-none-any.whl → 0.2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ripperdoc/__init__.py +1 -1
- ripperdoc/cli/cli.py +74 -9
- ripperdoc/cli/commands/__init__.py +4 -0
- ripperdoc/cli/commands/agents_cmd.py +30 -4
- ripperdoc/cli/commands/context_cmd.py +11 -1
- ripperdoc/cli/commands/cost_cmd.py +5 -0
- ripperdoc/cli/commands/doctor_cmd.py +208 -0
- ripperdoc/cli/commands/memory_cmd.py +202 -0
- ripperdoc/cli/commands/models_cmd.py +61 -6
- ripperdoc/cli/commands/resume_cmd.py +4 -2
- ripperdoc/cli/commands/status_cmd.py +1 -1
- ripperdoc/cli/commands/tasks_cmd.py +27 -0
- ripperdoc/cli/ui/rich_ui.py +258 -11
- ripperdoc/cli/ui/thinking_spinner.py +128 -0
- ripperdoc/core/agents.py +14 -4
- ripperdoc/core/config.py +56 -3
- ripperdoc/core/default_tools.py +16 -2
- ripperdoc/core/permissions.py +19 -0
- ripperdoc/core/providers/__init__.py +31 -0
- ripperdoc/core/providers/anthropic.py +136 -0
- ripperdoc/core/providers/base.py +187 -0
- ripperdoc/core/providers/gemini.py +172 -0
- ripperdoc/core/providers/openai.py +142 -0
- ripperdoc/core/query.py +510 -386
- ripperdoc/core/query_utils.py +578 -0
- ripperdoc/core/system_prompt.py +2 -1
- ripperdoc/core/tool.py +16 -1
- ripperdoc/sdk/client.py +12 -1
- ripperdoc/tools/background_shell.py +63 -21
- ripperdoc/tools/bash_tool.py +48 -13
- ripperdoc/tools/file_edit_tool.py +20 -0
- ripperdoc/tools/file_read_tool.py +23 -0
- ripperdoc/tools/file_write_tool.py +20 -0
- ripperdoc/tools/glob_tool.py +59 -15
- ripperdoc/tools/grep_tool.py +7 -0
- ripperdoc/tools/ls_tool.py +246 -73
- ripperdoc/tools/mcp_tools.py +32 -10
- ripperdoc/tools/multi_edit_tool.py +23 -0
- ripperdoc/tools/notebook_edit_tool.py +18 -3
- ripperdoc/tools/task_tool.py +7 -0
- ripperdoc/tools/todo_tool.py +157 -25
- ripperdoc/tools/tool_search_tool.py +17 -4
- ripperdoc/utils/file_watch.py +134 -0
- ripperdoc/utils/git_utils.py +274 -0
- ripperdoc/utils/json_utils.py +27 -0
- ripperdoc/utils/log.py +129 -29
- ripperdoc/utils/mcp.py +71 -6
- ripperdoc/utils/memory.py +12 -1
- ripperdoc/utils/message_compaction.py +22 -5
- ripperdoc/utils/messages.py +72 -17
- ripperdoc/utils/output_utils.py +34 -9
- ripperdoc/utils/permissions/path_validation_utils.py +6 -0
- ripperdoc/utils/prompt.py +17 -0
- ripperdoc/utils/safe_get_cwd.py +4 -0
- ripperdoc/utils/session_history.py +27 -9
- ripperdoc/utils/session_usage.py +7 -0
- ripperdoc/utils/shell_utils.py +159 -0
- ripperdoc/utils/todo.py +2 -2
- {ripperdoc-0.2.0.dist-info → ripperdoc-0.2.3.dist-info}/METADATA +4 -2
- ripperdoc-0.2.3.dist-info/RECORD +95 -0
- ripperdoc-0.2.0.dist-info/RECORD +0 -81
- {ripperdoc-0.2.0.dist-info → ripperdoc-0.2.3.dist-info}/WHEEL +0 -0
- {ripperdoc-0.2.0.dist-info → ripperdoc-0.2.3.dist-info}/entry_points.txt +0 -0
- {ripperdoc-0.2.0.dist-info → ripperdoc-0.2.3.dist-info}/licenses/LICENSE +0 -0
- {ripperdoc-0.2.0.dist-info → ripperdoc-0.2.3.dist-info}/top_level.txt +0 -0
ripperdoc/core/config.py
CHANGED
@@ -7,7 +7,7 @@ including API keys, model settings, and user preferences.
 import json
 import os
 from pathlib import Path
-from typing import Dict, Optional
+from typing import Dict, Optional, Literal
 from pydantic import BaseModel, Field
 from enum import Enum
 
@@ -100,11 +100,19 @@ class ModelProfile(BaseModel):
     provider: ProviderType
     model: str
     api_key: Optional[str] = None
+    # Anthropic supports either api_key or auth_token; api_key takes precedence when both are set.
+    auth_token: Optional[str] = None
     api_base: Optional[str] = None
     max_tokens: int = 4096
     temperature: float = 0.7
     # Total context window in tokens (if known). Falls back to heuristics when unset.
     context_window: Optional[int] = None
+    # Tool handling for OpenAI-compatible providers. "native" uses tool_calls, "text" flattens tool
+    # interactions into plain text to support providers that reject tool roles.
+    openai_tool_mode: Literal["native", "text"] = "native"
+    # Pricing (USD per 1M tokens). Leave as 0 to skip cost calculation.
+    input_cost_per_million_tokens: float = 0.0
+    output_cost_per_million_tokens: float = 0.0
 
 
 class ModelPointers(BaseModel):
@@ -185,17 +193,36 @@ class ConfigManager:
         try:
             data = json.loads(self.global_config_path.read_text())
             self._global_config = GlobalConfig(**data)
+            logger.debug(
+                "[config] Loaded global configuration",
+                extra={
+                    "path": str(self.global_config_path),
+                    "profile_count": len(self._global_config.model_profiles),
+                },
+            )
         except Exception as e:
-            logger.
+            logger.exception("Error loading global config", extra={"error": str(e)})
             self._global_config = GlobalConfig()
         else:
             self._global_config = GlobalConfig()
+            logger.debug(
+                "[config] Global config not found; using defaults",
+                extra={"path": str(self.global_config_path)},
+            )
         return self._global_config
 
     def save_global_config(self, config: GlobalConfig) -> None:
         """Save global configuration."""
         self._global_config = config
         self.global_config_path.write_text(config.model_dump_json(indent=2))
+        logger.debug(
+            "[config] Saved global configuration",
+            extra={
+                "path": str(self.global_config_path),
+                "profile_count": len(config.model_profiles),
+                "pointers": config.model_pointers.model_dump(),
+            },
+        )
 
     def get_project_config(self, project_path: Optional[Path] = None) -> ProjectConfig:
         """Load and return project configuration."""
@@ -215,11 +242,29 @@ class ConfigManager:
         try:
             data = json.loads(config_path.read_text())
             self._project_config = ProjectConfig(**data)
+            logger.debug(
+                "[config] Loaded project config",
+                extra={
+                    "path": str(config_path),
+                    "project_path": str(self.current_project_path),
+                    "allowed_tools": len(self._project_config.allowed_tools),
+                },
+            )
         except Exception as e:
-            logger.
+            logger.exception(
+                "Error loading project config",
+                extra={"error": str(e), "path": str(config_path)},
+            )
             self._project_config = ProjectConfig()
         else:
             self._project_config = ProjectConfig()
+            logger.debug(
+                "[config] Project config not found; using defaults",
+                extra={
+                    "path": str(config_path),
+                    "project_path": str(self.current_project_path),
+                },
+            )
 
         return self._project_config
 
@@ -239,6 +284,14 @@ class ConfigManager:
         config_path = config_dir / "config.json"
         self._project_config = config
         config_path.write_text(config.model_dump_json(indent=2))
+        logger.debug(
+            "[config] Saved project config",
+            extra={
+                "path": str(config_path),
+                "project_path": str(self.current_project_path),
+                "allowed_tools": len(config.allowed_tools),
+            },
+        )
 
     def get_api_key(self, provider: ProviderType) -> Optional[str]:
         """Get API key for a provider."""
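The new `ModelProfile` pricing fields feed the cost figures recorded by the provider clients further down; the actual helper is `estimate_cost_usd` in the new `ripperdoc/core/query_utils.py`, whose body is not shown in this diff. A minimal sketch of the arithmetic the fields imply, assuming usage is reported under `input_tokens`/`output_tokens` keys (an assumption, not confirmed here):

```python
# Sketch only: the real helper is estimate_cost_usd in ripperdoc/core/query_utils.py.
# The "input_tokens"/"output_tokens" keys are assumptions, not confirmed by this diff.
from typing import Dict


def estimate_cost_sketch(
    input_cost_per_million_tokens: float,
    output_cost_per_million_tokens: float,
    usage_tokens: Dict[str, int],
) -> float:
    # A price of 0 contributes nothing, matching the diff's
    # "leave as 0 to skip cost calculation" comment.
    return (
        usage_tokens.get("input_tokens", 0) * input_cost_per_million_tokens
        + usage_tokens.get("output_tokens", 0) * output_cost_per_million_tokens
    ) / 1_000_000


# 12,000 input + 800 output tokens at $3/$15 per 1M tokens -> 0.036 + 0.012 = $0.048
print(estimate_cost_sketch(3.0, 15.0, {"input_tokens": 12_000, "output_tokens": 800}))
```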
ripperdoc/core/default_tools.py
CHANGED
@@ -26,6 +26,9 @@ from ripperdoc.tools.mcp_tools import (
     ReadMcpResourceTool,
     load_dynamic_mcp_tools_sync,
 )
+from ripperdoc.utils.log import get_logger
+
+logger = get_logger()
 
 
 def get_default_tools() -> List[Tool[Any, Any]]:
@@ -49,15 +52,26 @@ def get_default_tools() -> List[Tool[Any, Any]]:
         ListMcpResourcesTool(),
         ReadMcpResourceTool(),
     ]
+    dynamic_tools: List[Tool[Any, Any]] = []
     try:
         mcp_tools = load_dynamic_mcp_tools_sync()
         # Filter to ensure only Tool instances are added
         for tool in mcp_tools:
             if isinstance(tool, Tool):
                 base_tools.append(tool)
+                dynamic_tools.append(tool)
     except Exception:
         # If MCP runtime is not available, continue with base tools only.
-
+        logger.exception("[default_tools] Failed to load dynamic MCP tools")
 
     task_tool = TaskTool(lambda: base_tools)
-
+    all_tools = base_tools + [task_tool]
+    logger.debug(
+        "[default_tools] Built tool inventory",
+        extra={
+            "base_tools": len(base_tools),
+            "dynamic_mcp_tools": len(dynamic_tools),
+            "total_tools": len(all_tools),
+        },
+    )
+    return all_tools
ripperdoc/core/permissions.py
CHANGED
@@ -11,6 +11,9 @@ from typing import Any, Awaitable, Callable, Optional, Set
 from ripperdoc.core.config import config_manager
 from ripperdoc.core.tool import Tool
 from ripperdoc.utils.permissions import PermissionDecision, ToolRule
+from ripperdoc.utils.log import get_logger
+
+logger = get_logger()
 
 
 @dataclass
@@ -46,11 +49,19 @@ def permission_key(tool: Tool[Any, Any], parsed_input: Any) -> str:
         try:
             return f"{tool.name}::path::{Path(getattr(parsed_input, 'file_path')).resolve()}"
         except Exception:
+            logger.exception(
+                "[permissions] Failed to resolve file_path for permission key",
+                extra={"tool": getattr(tool, "name", None)},
+            )
             return f"{tool.name}::path::{getattr(parsed_input, 'file_path')}"
     if hasattr(parsed_input, "path"):
         try:
             return f"{tool.name}::path::{Path(getattr(parsed_input, 'path')).resolve()}"
         except Exception:
+            logger.exception(
+                "[permissions] Failed to resolve path for permission key",
+                extra={"tool": getattr(tool, "name", None)},
+            )
             return f"{tool.name}::path::{getattr(parsed_input, 'path')}"
     return tool.name
 
@@ -116,6 +127,10 @@ def make_permission_checker(
         if hasattr(tool, "needs_permissions") and not tool.needs_permissions(parsed_input):
             return PermissionResult(result=True)
     except Exception:
+        logger.exception(
+            "[permissions] Tool needs_permissions check failed",
+            extra={"tool": getattr(tool, "name", None)},
+        )
         return PermissionResult(
             result=False,
             message="Permission check failed for this tool invocation.",
@@ -153,6 +168,10 @@ def make_permission_checker(
             if isinstance(decision, dict) and "behavior" in decision:
                 decision = PermissionDecision(**decision)
         except Exception:
+            logger.exception(
+                "[permissions] Tool check_permissions failed",
+                extra={"tool": getattr(tool, "name", None)},
+            )
             decision = PermissionDecision(
                 behavior="ask",
                 message="Error checking permissions for this tool.",
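The new logging instruments `permission_key`'s path-resolution fallbacks. For reference, a self-contained sketch of the key scheme those f-strings implement (the function and tool names here are hypothetical, not part of the package):

```python
# Shape of the keys permission_key produces; names here are hypothetical.
from pathlib import Path
from typing import Optional


def permission_key_sketch(tool_name: str, file_path: Optional[str]) -> str:
    if file_path is None:
        return tool_name  # tool-wide grant, no path scoping
    try:
        return f"{tool_name}::path::{Path(file_path).resolve()}"
    except Exception:
        # The fallback branch the new logging covers: keep the raw path.
        return f"{tool_name}::path::{file_path}"


print(permission_key_sketch("FileReadTool", "notes.txt"))  # e.g. FileReadTool::path::/cwd/notes.txt
print(permission_key_sketch("BashTool", None))             # BashTool
```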
ripperdoc/core/providers/__init__.py
ADDED

@@ -0,0 +1,31 @@
+"""Provider client registry."""
+
+from __future__ import annotations
+
+from typing import Optional
+
+from ripperdoc.core.config import ProviderType
+from ripperdoc.core.providers.anthropic import AnthropicClient
+from ripperdoc.core.providers.base import ProviderClient
+from ripperdoc.core.providers.gemini import GeminiClient
+from ripperdoc.core.providers.openai import OpenAIClient
+
+
+def get_provider_client(provider: ProviderType) -> Optional[ProviderClient]:
+    """Return a provider client for the given protocol."""
+    if provider == ProviderType.ANTHROPIC:
+        return AnthropicClient()
+    if provider == ProviderType.OPENAI_COMPATIBLE:
+        return OpenAIClient()
+    if provider == ProviderType.GEMINI:
+        return GeminiClient()
+    return None
+
+
+__all__ = [
+    "ProviderClient",
+    "AnthropicClient",
+    "GeminiClient",
+    "OpenAIClient",
+    "get_provider_client",
+]
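A hedged caller-side sketch of the registry: resolve a client for a profile's provider, then invoke the shared `call()` interface defined in `base.py`. The model name and profile construction are placeholders (this assumes `provider` and `model` are the only required `ModelProfile` fields, which the diff's partial hunk does not confirm), and a real run needs valid credentials:

```python
# Hypothetical usage of get_provider_client; values are placeholders.
import asyncio

from ripperdoc.core.config import ModelProfile, ProviderType
from ripperdoc.core.providers import get_provider_client


async def main() -> None:
    profile = ModelProfile(provider=ProviderType.ANTHROPIC, model="claude-example")
    client = get_provider_client(profile.provider)
    if client is None:
        raise RuntimeError(f"no provider client for {profile.provider}")
    response = await client.call(
        model_profile=profile,
        system_prompt="You are a helpful assistant.",
        normalized_messages=[{"role": "user", "content": "ping"}],
        tools=[],
        tool_mode="native",
        stream=False,
        progress_callback=None,
        request_timeout=60.0,
        max_retries=2,
    )
    print(response.content_blocks, response.cost_usd)


asyncio.run(main())
```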
ripperdoc/core/providers/anthropic.py
ADDED

@@ -0,0 +1,136 @@
+"""Anthropic provider client."""
+
+from __future__ import annotations
+
+import time
+from typing import Any, Awaitable, Callable, Dict, List, Optional
+
+from anthropic import AsyncAnthropic
+
+from ripperdoc.core.config import ModelProfile
+from ripperdoc.core.providers.base import (
+    ProgressCallback,
+    ProviderClient,
+    ProviderResponse,
+    call_with_timeout_and_retries,
+    sanitize_tool_history,
+)
+from ripperdoc.core.query_utils import (
+    anthropic_usage_tokens,
+    build_anthropic_tool_schemas,
+    content_blocks_from_anthropic_response,
+    estimate_cost_usd,
+)
+from ripperdoc.core.tool import Tool
+from ripperdoc.utils.log import get_logger
+from ripperdoc.utils.session_usage import record_usage
+
+logger = get_logger()
+
+
+class AnthropicClient(ProviderClient):
+    """Anthropic client with streaming and non-streaming support."""
+
+    def __init__(self, client_factory: Optional[Callable[[], Awaitable[AsyncAnthropic]]] = None):
+        self._client_factory = client_factory
+
+    async def _client(self, kwargs: Dict[str, Any]) -> AsyncAnthropic:
+        if self._client_factory:
+            return await self._client_factory()
+        return AsyncAnthropic(**kwargs)
+
+    async def call(
+        self,
+        *,
+        model_profile: ModelProfile,
+        system_prompt: str,
+        normalized_messages: Any,
+        tools: List[Tool[Any, Any]],
+        tool_mode: str,
+        stream: bool,
+        progress_callback: Optional[ProgressCallback],
+        request_timeout: Optional[float],
+        max_retries: int,
+    ) -> ProviderResponse:
+        start_time = time.time()
+        tool_schemas = await build_anthropic_tool_schemas(tools)
+        collected_text: List[str] = []
+
+        anthropic_kwargs = {"base_url": model_profile.api_base}
+        if model_profile.api_key:
+            anthropic_kwargs["api_key"] = model_profile.api_key
+        auth_token = getattr(model_profile, "auth_token", None)
+        if auth_token:
+            anthropic_kwargs["auth_token"] = auth_token
+
+        normalized_messages = sanitize_tool_history(list(normalized_messages))
+
+        async with await self._client(anthropic_kwargs) as client:
+
+            async def _stream_request() -> Any:
+                async with client.messages.stream(
+                    model=model_profile.model,
+                    max_tokens=model_profile.max_tokens,
+                    system=system_prompt,
+                    messages=normalized_messages,  # type: ignore[arg-type]
+                    tools=tool_schemas if tool_schemas else None,  # type: ignore
+                    temperature=model_profile.temperature,
+                ) as stream_resp:
+                    async for text in stream_resp.text_stream:
+                        if text:
+                            collected_text.append(text)
+                            if progress_callback:
+                                try:
+                                    await progress_callback(text)
+                                except Exception:
+                                    logger.exception("[anthropic_client] Stream callback failed")
+                    getter = getattr(stream_resp, "get_final_response", None) or getattr(
+                        stream_resp, "get_final_message", None
+                    )
+                    if getter:
+                        return await getter()
+                    return None
+
+            async def _non_stream_request() -> Any:
+                return await client.messages.create(
+                    model=model_profile.model,
+                    max_tokens=model_profile.max_tokens,
+                    system=system_prompt,
+                    messages=normalized_messages,  # type: ignore[arg-type]
+                    tools=tool_schemas if tool_schemas else None,  # type: ignore
+                    temperature=model_profile.temperature,
+                )
+
+            response = await call_with_timeout_and_retries(
+                _stream_request if stream else _non_stream_request,
+                request_timeout,
+                max_retries,
+            )
+
+        duration_ms = (time.time() - start_time) * 1000
+        usage_tokens = anthropic_usage_tokens(getattr(response, "usage", None))
+        cost_usd = estimate_cost_usd(model_profile, usage_tokens)
+        record_usage(
+            model_profile.model, duration_ms=duration_ms, cost_usd=cost_usd, **usage_tokens
+        )
+
+        content_blocks = content_blocks_from_anthropic_response(response, tool_mode)
+        if stream and collected_text and tool_mode == "text":
+            content_blocks = [{"type": "text", "text": "".join(collected_text)}]
+
+        logger.info(
+            "[anthropic_client] Response received",
+            extra={
+                "model": model_profile.model,
+                "duration_ms": round(duration_ms, 2),
+                "tool_mode": tool_mode,
+                "tool_schemas": len(tool_schemas),
+            },
+        )
+
+        return ProviderResponse(
+            content_blocks=content_blocks,
+            usage_tokens=usage_tokens,
+            cost_usd=cost_usd,
+            duration_ms=duration_ms,
+        )
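`progress_callback` receives each streamed text delta; per the `except` block above, an exception raised inside the callback is logged and swallowed rather than aborting the stream. A minimal handler matching the `ProgressCallback` alias from `base.py` (the function name is hypothetical):

```python
# Minimal sketch of a ProgressCallback (Callable[[str], Awaitable[None]]).
import sys


async def print_delta(text: str) -> None:
    # Invoked once per streamed text chunk. AnthropicClient logs and swallows
    # any exception raised here, so a faulty callback cannot abort the stream.
    sys.stdout.write(text)
    sys.stdout.flush()
```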
ripperdoc/core/providers/base.py
ADDED

@@ -0,0 +1,187 @@
+"""Shared abstractions for provider clients."""
+
+from __future__ import annotations
+
+import asyncio
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from typing import Any, Awaitable, Callable, Dict, List, Optional
+
+from ripperdoc.core.config import ModelProfile
+from ripperdoc.core.tool import Tool
+from ripperdoc.utils.log import get_logger
+
+logger = get_logger()
+
+ProgressCallback = Callable[[str], Awaitable[None]]
+
+
+@dataclass
+class ProviderResponse:
+    """Normalized provider response payload."""
+
+    content_blocks: List[Dict[str, Any]]
+    usage_tokens: Dict[str, int]
+    cost_usd: float
+    duration_ms: float
+
+
+class ProviderClient(ABC):
+    """Abstract base for model provider clients."""
+
+    @abstractmethod
+    async def call(
+        self,
+        *,
+        model_profile: ModelProfile,
+        system_prompt: str,
+        normalized_messages: List[Dict[str, Any]],
+        tools: List[Tool[Any, Any]],
+        tool_mode: str,
+        stream: bool,
+        progress_callback: Optional[ProgressCallback],
+        request_timeout: Optional[float],
+        max_retries: int,
+    ) -> ProviderResponse:
+        """Execute a model call and return a normalized response."""
+
+
+def sanitize_tool_history(normalized_messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+    """Strip tool_use blocks that lack a following tool_result to satisfy provider constraints."""
+
+    def _tool_result_ids(msg: Dict[str, Any]) -> set[str]:
+        ids: set[str] = set()
+        content = msg.get("content")
+        if isinstance(content, list):
+            for part in content:
+                part_type = getattr(
+                    part, "get", lambda k, default=None: part.__dict__.get(k, default)
+                )("type", None)
+                if part_type == "tool_result":
+                    tid = (
+                        getattr(part, "tool_use_id", None)
+                        or getattr(part, "id", None)
+                        or part.get("tool_use_id")
+                        or part.get("id")
+                    )
+                    if tid:
+                        ids.add(str(tid))
+        return ids
+
+    # Build a lookahead map so we can pair tool_use blocks with tool_results that may
+    # appear in any later message (not just the immediate next one).
+    tool_results_after: List[set[str]] = []
+    if normalized_messages:
+        tool_results_after = [set() for _ in normalized_messages]
+        future_ids: set[str] = set()
+        for idx in range(len(normalized_messages) - 1, -1, -1):
+            tool_results_after[idx] = set(future_ids)
+            future_ids.update(_tool_result_ids(normalized_messages[idx]))
+
+    sanitized: List[Dict[str, Any]] = []
+    for idx, message in enumerate(normalized_messages):
+        if message.get("role") != "assistant":
+            sanitized.append(message)
+            continue
+
+        content = message.get("content")
+        if not isinstance(content, list):
+            sanitized.append(message)
+            continue
+
+        tool_use_blocks = [
+            part
+            for part in content
+            if (
+                getattr(part, "type", None)
+                or (part.get("type") if isinstance(part, dict) else None)
+            )
+            == "tool_use"
+        ]
+        if not tool_use_blocks:
+            sanitized.append(message)
+            continue
+
+        future_results = tool_results_after[idx] if tool_results_after else set()
+
+        # Identify unpaired tool_use IDs
+        unpaired_ids: set[str] = set()
+        for block in tool_use_blocks:
+            block_id = (
+                getattr(block, "tool_use_id", None)
+                or getattr(block, "id", None)
+                or (block.get("tool_use_id") if isinstance(block, dict) else None)
+                or (block.get("id") if isinstance(block, dict) else None)
+            )
+            if block_id and str(block_id) not in future_results:
+                unpaired_ids.add(str(block_id))
+
+        if not unpaired_ids:
+            sanitized.append(message)
+            continue
+
+        # Drop unpaired tool_use blocks
+        filtered_content = []
+        for part in content:
+            part_type = getattr(part, "type", None) or (
+                part.get("type") if isinstance(part, dict) else None
+            )
+            if part_type == "tool_use":
+                block_id = (
+                    getattr(part, "tool_use_id", None)
+                    or getattr(part, "id", None)
+                    or (part.get("tool_use_id") if isinstance(part, dict) else None)
+                    or (part.get("id") if isinstance(part, dict) else None)
+                )
+                if block_id and str(block_id) in unpaired_ids:
+                    continue
+            filtered_content.append(part)
+
+        if not filtered_content:
+            logger.debug(
+                "[provider_clients] Dropped assistant message with unpaired tool_use blocks",
+                extra={"unpaired_ids": list(unpaired_ids)},
+            )
+            continue
+
+        sanitized.append({**message, "content": filtered_content})
+        logger.debug(
+            "[provider_clients] Sanitized message to remove unpaired tool_use blocks",
+            extra={"unpaired_ids": list(unpaired_ids)},
+        )
+
+    return sanitized
+
+
+async def call_with_timeout_and_retries(
+    coro_factory: Callable[[], Awaitable[Any]],
+    request_timeout: Optional[float],
+    max_retries: int,
+) -> Any:
+    """Run a coroutine with timeout and limited retries."""
+    attempts = max(0, int(max_retries)) + 1
+    last_error: Optional[Exception] = None
+    for attempt in range(1, attempts + 1):
+        try:
+            if request_timeout and request_timeout > 0:
+                return await asyncio.wait_for(coro_factory(), timeout=request_timeout)
+            return await coro_factory()
+        except asyncio.TimeoutError as exc:
+            last_error = exc
+            logger.warning(
+                "[provider_clients] Request timed out; retrying",
+                extra={"attempt": attempt, "max_retries": attempts - 1},
+            )
+            if attempt == attempts:
+                raise
+        except Exception as exc:
+            last_error = exc
+            if attempt == attempts:
+                raise
+            logger.warning(
+                "[provider_clients] Request failed; retrying",
+                extra={"attempt": attempt, "max_retries": attempts - 1, "error": str(exc)},
+            )
+    if last_error:
+        raise last_error
+    raise RuntimeError("Unexpected error executing request with retries")
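To make `sanitize_tool_history`'s pairing rule concrete, here is a small demo on dict-shaped messages (the IDs and tool names are hypothetical): a `tool_use` answered by a later `tool_result` survives untouched, while an unanswered one is stripped, and its message dropped when nothing else remains in it.

```python
# Demo of sanitize_tool_history; IDs and tool names are hypothetical.
from ripperdoc.core.providers.base import sanitize_tool_history

history = [
    {"role": "user", "content": "list files"},
    {"role": "assistant", "content": [{"type": "tool_use", "id": "tu_1", "name": "LS", "input": {}}]},
    {"role": "user", "content": [{"type": "tool_result", "tool_use_id": "tu_1", "content": "README.md"}]},
    {"role": "assistant", "content": [{"type": "tool_use", "id": "tu_2", "name": "LS", "input": {}}]},
]

clean = sanitize_tool_history(history)
assert len(clean) == 3  # tu_1 is paired and kept; the message carrying unpaired tu_2 is gone
```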