ripperdoc 0.2.2__py3-none-any.whl → 0.2.3__py3-none-any.whl
This diff shows the differences between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- ripperdoc/__init__.py +1 -1
- ripperdoc/cli/cli.py +9 -2
- ripperdoc/cli/commands/agents_cmd.py +8 -4
- ripperdoc/cli/commands/cost_cmd.py +5 -0
- ripperdoc/cli/commands/doctor_cmd.py +12 -4
- ripperdoc/cli/commands/memory_cmd.py +6 -13
- ripperdoc/cli/commands/models_cmd.py +36 -6
- ripperdoc/cli/commands/resume_cmd.py +4 -2
- ripperdoc/cli/commands/status_cmd.py +1 -1
- ripperdoc/cli/ui/rich_ui.py +102 -2
- ripperdoc/cli/ui/thinking_spinner.py +128 -0
- ripperdoc/core/agents.py +13 -5
- ripperdoc/core/config.py +9 -1
- ripperdoc/core/providers/__init__.py +31 -0
- ripperdoc/core/providers/anthropic.py +136 -0
- ripperdoc/core/providers/base.py +187 -0
- ripperdoc/core/providers/gemini.py +172 -0
- ripperdoc/core/providers/openai.py +142 -0
- ripperdoc/core/query.py +331 -141
- ripperdoc/core/query_utils.py +64 -23
- ripperdoc/core/tool.py +5 -3
- ripperdoc/sdk/client.py +12 -1
- ripperdoc/tools/background_shell.py +54 -18
- ripperdoc/tools/bash_tool.py +33 -13
- ripperdoc/tools/file_edit_tool.py +13 -0
- ripperdoc/tools/file_read_tool.py +16 -0
- ripperdoc/tools/file_write_tool.py +13 -0
- ripperdoc/tools/glob_tool.py +5 -1
- ripperdoc/tools/ls_tool.py +14 -10
- ripperdoc/tools/multi_edit_tool.py +12 -0
- ripperdoc/tools/notebook_edit_tool.py +12 -0
- ripperdoc/tools/todo_tool.py +1 -3
- ripperdoc/tools/tool_search_tool.py +8 -4
- ripperdoc/utils/file_watch.py +134 -0
- ripperdoc/utils/git_utils.py +36 -38
- ripperdoc/utils/json_utils.py +1 -2
- ripperdoc/utils/log.py +3 -4
- ripperdoc/utils/memory.py +1 -3
- ripperdoc/utils/message_compaction.py +2 -6
- ripperdoc/utils/messages.py +9 -13
- ripperdoc/utils/output_utils.py +1 -3
- ripperdoc/utils/prompt.py +17 -0
- ripperdoc/utils/session_usage.py +7 -0
- ripperdoc/utils/shell_utils.py +159 -0
- {ripperdoc-0.2.2.dist-info → ripperdoc-0.2.3.dist-info}/METADATA +1 -1
- ripperdoc-0.2.3.dist-info/RECORD +95 -0
- ripperdoc-0.2.2.dist-info/RECORD +0 -86
- {ripperdoc-0.2.2.dist-info → ripperdoc-0.2.3.dist-info}/WHEEL +0 -0
- {ripperdoc-0.2.2.dist-info → ripperdoc-0.2.3.dist-info}/entry_points.txt +0 -0
- {ripperdoc-0.2.2.dist-info → ripperdoc-0.2.3.dist-info}/licenses/LICENSE +0 -0
- {ripperdoc-0.2.2.dist-info → ripperdoc-0.2.3.dist-info}/top_level.txt +0 -0

ripperdoc/core/providers/anthropic.py

@@ -0,0 +1,136 @@
+"""Anthropic provider client."""
+
+from __future__ import annotations
+
+import time
+from typing import Any, Awaitable, Callable, Dict, List, Optional
+
+from anthropic import AsyncAnthropic
+
+from ripperdoc.core.config import ModelProfile
+from ripperdoc.core.providers.base import (
+    ProgressCallback,
+    ProviderClient,
+    ProviderResponse,
+    call_with_timeout_and_retries,
+    sanitize_tool_history,
+)
+from ripperdoc.core.query_utils import (
+    anthropic_usage_tokens,
+    build_anthropic_tool_schemas,
+    content_blocks_from_anthropic_response,
+    estimate_cost_usd,
+)
+from ripperdoc.core.tool import Tool
+from ripperdoc.utils.log import get_logger
+from ripperdoc.utils.session_usage import record_usage
+
+logger = get_logger()
+
+
+class AnthropicClient(ProviderClient):
+    """Anthropic client with streaming and non-streaming support."""
+
+    def __init__(self, client_factory: Optional[Callable[[], Awaitable[AsyncAnthropic]]] = None):
+        self._client_factory = client_factory
+
+    async def _client(self, kwargs: Dict[str, Any]) -> AsyncAnthropic:
+        if self._client_factory:
+            return await self._client_factory()
+        return AsyncAnthropic(**kwargs)
+
+    async def call(
+        self,
+        *,
+        model_profile: ModelProfile,
+        system_prompt: str,
+        normalized_messages: Any,
+        tools: List[Tool[Any, Any]],
+        tool_mode: str,
+        stream: bool,
+        progress_callback: Optional[ProgressCallback],
+        request_timeout: Optional[float],
+        max_retries: int,
+    ) -> ProviderResponse:
+        start_time = time.time()
+        tool_schemas = await build_anthropic_tool_schemas(tools)
+        collected_text: List[str] = []
+
+        anthropic_kwargs = {"base_url": model_profile.api_base}
+        if model_profile.api_key:
+            anthropic_kwargs["api_key"] = model_profile.api_key
+        auth_token = getattr(model_profile, "auth_token", None)
+        if auth_token:
+            anthropic_kwargs["auth_token"] = auth_token
+
+        normalized_messages = sanitize_tool_history(list(normalized_messages))
+
+        async with await self._client(anthropic_kwargs) as client:
+
+            async def _stream_request() -> Any:
+                async with client.messages.stream(
+                    model=model_profile.model,
+                    max_tokens=model_profile.max_tokens,
+                    system=system_prompt,
+                    messages=normalized_messages,  # type: ignore[arg-type]
+                    tools=tool_schemas if tool_schemas else None,  # type: ignore
+                    temperature=model_profile.temperature,
+                ) as stream_resp:
+                    async for text in stream_resp.text_stream:
+                        if text:
+                            collected_text.append(text)
+                            if progress_callback:
+                                try:
+                                    await progress_callback(text)
+                                except Exception:
+                                    logger.exception("[anthropic_client] Stream callback failed")
+                    getter = getattr(stream_resp, "get_final_response", None) or getattr(
+                        stream_resp, "get_final_message", None
+                    )
+                    if getter:
+                        return await getter()
+                    return None
+
+            async def _non_stream_request() -> Any:
+                return await client.messages.create(
+                    model=model_profile.model,
+                    max_tokens=model_profile.max_tokens,
+                    system=system_prompt,
+                    messages=normalized_messages,  # type: ignore[arg-type]
+                    tools=tool_schemas if tool_schemas else None,  # type: ignore
+                    temperature=model_profile.temperature,
+                )
+
+            response = await call_with_timeout_and_retries(
+                _stream_request if stream else _non_stream_request,
+                request_timeout,
+                max_retries,
+            )
+
+        duration_ms = (time.time() - start_time) * 1000
+        usage_tokens = anthropic_usage_tokens(getattr(response, "usage", None))
+        cost_usd = estimate_cost_usd(model_profile, usage_tokens)
+        record_usage(
+            model_profile.model, duration_ms=duration_ms, cost_usd=cost_usd, **usage_tokens
+        )
+
+        content_blocks = content_blocks_from_anthropic_response(response, tool_mode)
+        if stream and collected_text and tool_mode == "text":
+            content_blocks = [{"type": "text", "text": "".join(collected_text)}]
+
+        logger.info(
+            "[anthropic_client] Response received",
+            extra={
+                "model": model_profile.model,
+                "duration_ms": round(duration_ms, 2),
+                "tool_mode": tool_mode,
+                "tool_schemas": len(tool_schemas),
+            },
+        )
+
+        return ProviderResponse(
+            content_blocks=content_blocks,
+            usage_tokens=usage_tokens,
+            cost_usd=cost_usd,
+            duration_ms=duration_ms,
+        )
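
The `client_factory` hook on `AnthropicClient` lets callers inject a preconfigured `AsyncAnthropic` (or a test double) instead of having the client build one from the `ModelProfile`. A minimal sketch, assuming the package is installed; the key and URL below are placeholders:

from anthropic import AsyncAnthropic

from ripperdoc.core.providers.anthropic import AnthropicClient


async def make_anthropic() -> AsyncAnthropic:
    # Hypothetical factory: return a preconfigured SDK client
    # (in tests this could be a stub exposing messages.stream / messages.create).
    return AsyncAnthropic(api_key="test-key", base_url="https://example.invalid")


client = AnthropicClient(client_factory=make_anthropic)
# client.call(...) will now use the injected AsyncAnthropic rather than
# constructing one from model_profile.api_key / api_base / auth_token.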

ripperdoc/core/providers/base.py

@@ -0,0 +1,187 @@
+"""Shared abstractions for provider clients."""
+
+from __future__ import annotations
+
+import asyncio
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from typing import Any, Awaitable, Callable, Dict, List, Optional
+
+from ripperdoc.core.config import ModelProfile
+from ripperdoc.core.tool import Tool
+from ripperdoc.utils.log import get_logger
+
+logger = get_logger()
+
+ProgressCallback = Callable[[str], Awaitable[None]]
+
+
+@dataclass
+class ProviderResponse:
+    """Normalized provider response payload."""
+
+    content_blocks: List[Dict[str, Any]]
+    usage_tokens: Dict[str, int]
+    cost_usd: float
+    duration_ms: float
+
+
+class ProviderClient(ABC):
+    """Abstract base for model provider clients."""
+
+    @abstractmethod
+    async def call(
+        self,
+        *,
+        model_profile: ModelProfile,
+        system_prompt: str,
+        normalized_messages: List[Dict[str, Any]],
+        tools: List[Tool[Any, Any]],
+        tool_mode: str,
+        stream: bool,
+        progress_callback: Optional[ProgressCallback],
+        request_timeout: Optional[float],
+        max_retries: int,
+    ) -> ProviderResponse:
+        """Execute a model call and return a normalized response."""
+
+
+def sanitize_tool_history(normalized_messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+    """Strip tool_use blocks that lack a following tool_result to satisfy provider constraints."""
+
+    def _tool_result_ids(msg: Dict[str, Any]) -> set[str]:
+        ids: set[str] = set()
+        content = msg.get("content")
+        if isinstance(content, list):
+            for part in content:
+                part_type = getattr(
+                    part, "get", lambda k, default=None: part.__dict__.get(k, default)
+                )("type", None)
+                if part_type == "tool_result":
+                    tid = (
+                        getattr(part, "tool_use_id", None)
+                        or getattr(part, "id", None)
+                        or part.get("tool_use_id")
+                        or part.get("id")
+                    )
+                    if tid:
+                        ids.add(str(tid))
+        return ids
+
+    # Build a lookahead map so we can pair tool_use blocks with tool_results that may
+    # appear in any later message (not just the immediate next one).
+    tool_results_after: List[set[str]] = []
+    if normalized_messages:
+        tool_results_after = [set() for _ in normalized_messages]
+        future_ids: set[str] = set()
+        for idx in range(len(normalized_messages) - 1, -1, -1):
+            tool_results_after[idx] = set(future_ids)
+            future_ids.update(_tool_result_ids(normalized_messages[idx]))
+
+    sanitized: List[Dict[str, Any]] = []
+    for idx, message in enumerate(normalized_messages):
+        if message.get("role") != "assistant":
+            sanitized.append(message)
+            continue
+
+        content = message.get("content")
+        if not isinstance(content, list):
+            sanitized.append(message)
+            continue
+
+        tool_use_blocks = [
+            part
+            for part in content
+            if (
+                getattr(part, "type", None)
+                or (part.get("type") if isinstance(part, dict) else None)
+            )
+            == "tool_use"
+        ]
+        if not tool_use_blocks:
+            sanitized.append(message)
+            continue
+
+        future_results = tool_results_after[idx] if tool_results_after else set()
+
+        # Identify unpaired tool_use IDs
+        unpaired_ids: set[str] = set()
+        for block in tool_use_blocks:
+            block_id = (
+                getattr(block, "tool_use_id", None)
+                or getattr(block, "id", None)
+                or (block.get("tool_use_id") if isinstance(block, dict) else None)
+                or (block.get("id") if isinstance(block, dict) else None)
+            )
+            if block_id and str(block_id) not in future_results:
+                unpaired_ids.add(str(block_id))
+
+        if not unpaired_ids:
+            sanitized.append(message)
+            continue
+
+        # Drop unpaired tool_use blocks
+        filtered_content = []
+        for part in content:
+            part_type = getattr(part, "type", None) or (
+                part.get("type") if isinstance(part, dict) else None
+            )
+            if part_type == "tool_use":
+                block_id = (
+                    getattr(part, "tool_use_id", None)
+                    or getattr(part, "id", None)
+                    or (part.get("tool_use_id") if isinstance(part, dict) else None)
+                    or (part.get("id") if isinstance(part, dict) else None)
+                )
+                if block_id and str(block_id) in unpaired_ids:
+                    continue
+            filtered_content.append(part)
+
+        if not filtered_content:
+            logger.debug(
+                "[provider_clients] Dropped assistant message with unpaired tool_use blocks",
+                extra={"unpaired_ids": list(unpaired_ids)},
+            )
+            continue
+
+        sanitized.append({**message, "content": filtered_content})
+        logger.debug(
+            "[provider_clients] Sanitized message to remove unpaired tool_use blocks",
+            extra={"unpaired_ids": list(unpaired_ids)},
+        )
+
+    return sanitized
+
+
+async def call_with_timeout_and_retries(
+    coro_factory: Callable[[], Awaitable[Any]],
+    request_timeout: Optional[float],
+    max_retries: int,
+) -> Any:
+    """Run a coroutine with timeout and limited retries."""
+    attempts = max(0, int(max_retries)) + 1
+    last_error: Optional[Exception] = None
+    for attempt in range(1, attempts + 1):
+        try:
+            if request_timeout and request_timeout > 0:
+                return await asyncio.wait_for(coro_factory(), timeout=request_timeout)
+            return await coro_factory()
+        except asyncio.TimeoutError as exc:
+            last_error = exc
+            logger.warning(
+                "[provider_clients] Request timed out; retrying",
+                extra={"attempt": attempt, "max_retries": attempts - 1},
+            )
+            if attempt == attempts:
+                raise
+        except Exception as exc:
+            last_error = exc
+            if attempt == attempts:
+                raise
+            logger.warning(
+                "[provider_clients] Request failed; retrying",
+                extra={"attempt": attempt, "max_retries": attempts - 1, "error": str(exc)},
+            )
+    if last_error:
+        raise last_error
+    raise RuntimeError("Unexpected error executing request with retries")
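
To make the behavior of `sanitize_tool_history` concrete, here is a small illustrative example (message shapes follow the dict form the function accepts; the IDs and tool names are made up): an assistant `tool_use` block whose ID never shows up in a later `tool_result` is dropped, while a paired block is kept.

from ripperdoc.core.providers.base import sanitize_tool_history

history = [
    {"role": "user", "content": "list the repo files"},
    {
        "role": "assistant",
        "content": [
            {"type": "tool_use", "id": "tool_a", "name": "ls", "input": {}},
            {"type": "tool_use", "id": "tool_b", "name": "glob", "input": {}},  # never answered
        ],
    },
    {
        "role": "user",
        "content": [{"type": "tool_result", "tool_use_id": "tool_a", "content": "README.md"}],
    },
]

cleaned = sanitize_tool_history(history)
# The assistant message keeps the "tool_a" block (a tool_result for it appears later)
# and drops the unpaired "tool_b" block before the history is sent to a provider.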

ripperdoc/core/providers/gemini.py

@@ -0,0 +1,172 @@
+"""Gemini provider client."""
+
+from __future__ import annotations
+
+import os
+import time
+from typing import Any, Dict, List, Optional
+
+from ripperdoc.core.config import ModelProfile
+from ripperdoc.core.providers.base import (
+    ProgressCallback,
+    ProviderClient,
+    ProviderResponse,
+    call_with_timeout_and_retries,
+)
+from ripperdoc.core.tool import Tool
+from ripperdoc.utils.log import get_logger
+
+logger = get_logger()
+
+
+def _extract_usage_metadata(payload: Any) -> Dict[str, int]:
+    """Best-effort token extraction from Gemini responses."""
+    usage = getattr(payload, "usage_metadata", None) or getattr(payload, "usageMetadata", None)
+    if not usage:
+        usage = getattr(payload, "usage", None)
+    get = lambda key: int(getattr(usage, key, 0) or 0) if usage else 0  # noqa: E731
+    return {
+        "input_tokens": get("prompt_token_count") + get("cached_content_token_count"),
+        "output_tokens": get("candidates_token_count"),
+        "cache_read_input_tokens": get("cached_content_token_count"),
+        "cache_creation_input_tokens": 0,
+    }
+
+
+def _collect_text_parts(candidate: Any) -> str:
+    parts = getattr(candidate, "content", None)
+    if not parts:
+        return ""
+    if isinstance(parts, list):
+        texts = []
+        for part in parts:
+            text_val = getattr(part, "text", None) or getattr(part, "content", None)
+            if isinstance(text_val, str):
+                texts.append(text_val)
+        return "".join(texts)
+    return str(parts)
+
+
+class GeminiClient(ProviderClient):
+    """Gemini client with streaming and basic text support."""
+
+    async def call(
+        self,
+        *,
+        model_profile: ModelProfile,
+        system_prompt: str,
+        normalized_messages: List[Dict[str, Any]],
+        tools: List[Tool[Any, Any]],
+        tool_mode: str,
+        stream: bool,
+        progress_callback: Optional[ProgressCallback],
+        request_timeout: Optional[float],
+        max_retries: int,
+    ) -> ProviderResponse:
+        try:
+            import google.generativeai as genai  # type: ignore
+        except Exception as exc:  # pragma: no cover - import guard
+            msg = (
+                "Gemini client requires the 'google-generativeai' package. "
+                "Install it to enable Gemini support."
+            )
+            logger.warning(msg, extra={"error": str(exc)})
+            return ProviderResponse(
+                content_blocks=[{"type": "text", "text": msg}],
+                usage_tokens={},
+                cost_usd=0.0,
+                duration_ms=0.0,
+            )
+
+        if tools and tool_mode != "text":
+            msg = (
+                "Gemini client currently supports text-only responses; "
+                "tool/function calling is not yet implemented."
+            )
+            return ProviderResponse(
+                content_blocks=[{"type": "text", "text": msg}],
+                usage_tokens={},
+                cost_usd=0.0,
+                duration_ms=0.0,
+            )
+
+        api_key = (
+            model_profile.api_key or os.getenv("GEMINI_API_KEY") or os.getenv("GOOGLE_API_KEY")
+        )
+        genai.configure(api_key=api_key, client_options={"api_endpoint": model_profile.api_base})
+
+        # Flatten normalized messages into a single text prompt (Gemini supports multi-turn, but keep it simple).
+        prompt_parts: List[str] = [system_prompt]
+        for msg in normalized_messages:  # type: ignore[assignment]
+            role: str = (
+                str(msg.get("role", "")) if isinstance(msg, dict) else str(getattr(msg, "role", ""))  # type: ignore[assignment]
+            )
+            content = msg.get("content") if isinstance(msg, dict) else getattr(msg, "content", "")
+            if isinstance(content, list):
+                for item in content:
+                    text_val = (
+                        getattr(item, "text", None)
+                        or item.get("text", "")  # type: ignore[union-attr]
+                        if isinstance(item, dict)
+                        else ""
+                    )
+                    if text_val:
+                        prompt_parts.append(f"{role}: {text_val}")
+            elif isinstance(content, str):
+                prompt_parts.append(f"{role}: {content}")
+        full_prompt = "\n".join(part for part in prompt_parts if part)
+
+        model = genai.GenerativeModel(model_profile.model)
+        collected_text: List[str] = []
+        start_time = time.time()
+
+        async def _stream_request() -> Dict[str, Dict[str, int]]:
+            stream_resp = model.generate_content(full_prompt, stream=True)
+            usage_tokens: Dict[str, int] = {}
+            for chunk in stream_resp:
+                text_delta = _collect_text_parts(chunk)
+                if text_delta:
+                    collected_text.append(text_delta)
+                    if progress_callback:
+                        try:
+                            await progress_callback(text_delta)
+                        except Exception:
+                            logger.exception("[gemini_client] Stream callback failed")
+                usage_tokens = _extract_usage_metadata(chunk) or usage_tokens
+            return {"usage": usage_tokens}
+
+        async def _non_stream_request() -> Any:
+            return model.generate_content(full_prompt)
+
+        response: Any = await call_with_timeout_and_retries(
+            _stream_request if stream and progress_callback else _non_stream_request,
+            request_timeout,
+            max_retries,
+        )
+
+        duration_ms = (time.time() - start_time) * 1000
+        usage_tokens = _extract_usage_metadata(response)
+        cost_usd = 0.0  # Pricing unknown; leave as 0
+
+        content_blocks = (
+            [{"type": "text", "text": "".join(collected_text)}]
+            if collected_text
+            else [{"type": "text", "text": _collect_text_parts(response)}]
+        )
+
+        logger.info(
+            "[gemini_client] Response received",
+            extra={
+                "model": model_profile.model,
+                "duration_ms": round(duration_ms, 2),
+                "tool_mode": tool_mode,
+                "stream": stream,
+            },
+        )
+
+        return ProviderResponse(
+            content_blocks=content_blocks,
+            usage_tokens=usage_tokens,
+            cost_usd=cost_usd,
+            duration_ms=duration_ms,
+        )
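
`_extract_usage_metadata` maps Gemini's usage counters onto the token fields ripperdoc tracks elsewhere (prompt plus cached tokens become `input_tokens`, candidate tokens become `output_tokens`). A rough sketch of that mapping using a stand-in object in place of a real Gemini response; the numbers are arbitrary:

from types import SimpleNamespace

from ripperdoc.core.providers.gemini import _extract_usage_metadata

fake_response = SimpleNamespace(
    usage_metadata=SimpleNamespace(
        prompt_token_count=120,
        candidates_token_count=45,
        cached_content_token_count=30,
    )
)

tokens = _extract_usage_metadata(fake_response)
# -> {"input_tokens": 150, "output_tokens": 45,
#     "cache_read_input_tokens": 30, "cache_creation_input_tokens": 0}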

ripperdoc/core/providers/openai.py

@@ -0,0 +1,142 @@
+"""OpenAI-compatible provider client."""
+
+from __future__ import annotations
+
+import time
+from typing import Any, Dict, List, Optional, cast
+
+from openai import AsyncOpenAI
+
+from ripperdoc.core.config import ModelProfile
+from ripperdoc.core.providers.base import (
+    ProgressCallback,
+    ProviderClient,
+    ProviderResponse,
+    call_with_timeout_and_retries,
+    sanitize_tool_history,
+)
+from ripperdoc.core.query_utils import (
+    build_openai_tool_schemas,
+    content_blocks_from_openai_choice,
+    estimate_cost_usd,
+    openai_usage_tokens,
+)
+from ripperdoc.core.tool import Tool
+from ripperdoc.utils.log import get_logger
+from ripperdoc.utils.session_usage import record_usage
+
+logger = get_logger()
+
+
+class OpenAIClient(ProviderClient):
+    """OpenAI-compatible client with streaming and non-streaming support."""
+
+    async def call(
+        self,
+        *,
+        model_profile: ModelProfile,
+        system_prompt: str,
+        normalized_messages: List[Dict[str, Any]],
+        tools: List[Tool[Any, Any]],
+        tool_mode: str,
+        stream: bool,
+        progress_callback: Optional[ProgressCallback],
+        request_timeout: Optional[float],
+        max_retries: int,
+    ) -> ProviderResponse:
+        start_time = time.time()
+        openai_tools = await build_openai_tool_schemas(tools)
+        openai_messages: List[Dict[str, object]] = [
+            {"role": "system", "content": system_prompt}
+        ] + sanitize_tool_history(list(normalized_messages))
+        collected_text: List[str] = []
+
+        can_stream = stream and tool_mode == "text" and not openai_tools
+
+        async with AsyncOpenAI(
+            api_key=model_profile.api_key, base_url=model_profile.api_base
+        ) as client:
+
+            async def _stream_request() -> Dict[str, Dict[str, int]]:
+                stream_resp = await client.chat.completions.create(  # type: ignore[call-overload]
+                    model=model_profile.model,
+                    messages=cast(Any, openai_messages),
+                    tools=None,
+                    temperature=model_profile.temperature,
+                    max_tokens=model_profile.max_tokens,
+                    stream=True,
+                )
+                usage_tokens: Dict[str, int] = {}
+                async for chunk in stream_resp:
+                    delta = getattr(chunk.choices[0], "delta", None)
+                    delta_content = getattr(delta, "content", None) if delta else None
+                    text_delta = ""
+                    if delta_content:
+                        if isinstance(delta_content, list):
+                            for part in delta_content:
+                                text_val = getattr(part, "text", None) or getattr(
+                                    part, "content", None
+                                )
+                                if isinstance(text_val, str):
+                                    text_delta += text_val
+                        elif isinstance(delta_content, str):
+                            text_delta += delta_content
+                    if text_delta:
+                        collected_text.append(text_delta)
+                        if progress_callback:
+                            try:
+                                await progress_callback(text_delta)
+                            except Exception:
+                                logger.exception("[openai_client] Stream callback failed")
+                    if getattr(chunk, "usage", None):
+                        usage_tokens = openai_usage_tokens(chunk.usage)
+                return {"usage": usage_tokens}
+
+            async def _non_stream_request() -> Any:
+                return await client.chat.completions.create(  # type: ignore[call-overload]
+                    model=model_profile.model,
+                    messages=cast(Any, openai_messages),
+                    tools=openai_tools if openai_tools else None,  # type: ignore[arg-type]
+                    temperature=model_profile.temperature,
+                    max_tokens=model_profile.max_tokens,
+                )
+
+            openai_response: Any = await call_with_timeout_and_retries(
+                _stream_request if can_stream else _non_stream_request,
+                request_timeout,
+                max_retries,
+            )
+
+        duration_ms = (time.time() - start_time) * 1000
+        usage_tokens = openai_usage_tokens(getattr(openai_response, "usage", None))
+        cost_usd = estimate_cost_usd(model_profile, usage_tokens)
+        record_usage(
+            model_profile.model, duration_ms=duration_ms, cost_usd=cost_usd, **usage_tokens
+        )
+
+        finish_reason: Optional[str]
+        if can_stream:
+            content_blocks = [{"type": "text", "text": "".join(collected_text)}]
+            finish_reason = "stream"
+        else:
+            choice = openai_response.choices[0]
+            content_blocks = content_blocks_from_openai_choice(choice, tool_mode)
+            finish_reason = cast(Optional[str], getattr(choice, "finish_reason", None))
+
+        logger.info(
+            "[openai_client] Response received",
+            extra={
+                "model": model_profile.model,
+                "duration_ms": round(duration_ms, 2),
+                "tool_mode": tool_mode,
+                "tool_count": len(openai_tools),
+                "finish_reason": finish_reason,
+            },
+        )
+
+        return ProviderResponse(
+            content_blocks=content_blocks,
+            usage_tokens=usage_tokens,
+            cost_usd=cost_usd,
+            duration_ms=duration_ms,
+        )
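
All three provider clients implement the same keyword-only `ProviderClient.call` signature, so calling code can stay provider-agnostic. A hedged sketch of a caller; the `ModelProfile` field names are assumed from how these clients read the profile, and the exact constructor may differ:

import asyncio

from ripperdoc.core.config import ModelProfile
from ripperdoc.core.providers.openai import OpenAIClient


async def main() -> None:
    profile = ModelProfile(  # field names assumed from usage in the clients above
        model="gpt-4o-mini",
        api_key="test-key",
        api_base="https://api.openai.com/v1",
        max_tokens=1024,
        temperature=0.2,
    )
    response = await OpenAIClient().call(
        model_profile=profile,
        system_prompt="You are a helpful assistant.",
        normalized_messages=[{"role": "user", "content": "hello"}],
        tools=[],
        tool_mode="text",
        stream=False,
        progress_callback=None,
        request_timeout=60.0,
        max_retries=2,
    )
    print(response.content_blocks, response.usage_tokens, response.cost_usd)


asyncio.run(main())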