ripperdoc 0.2.4-py3-none-any.whl → 0.2.5-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ripperdoc/__init__.py +1 -1
- ripperdoc/__main__.py +0 -5
- ripperdoc/cli/cli.py +37 -16
- ripperdoc/cli/commands/__init__.py +2 -0
- ripperdoc/cli/commands/agents_cmd.py +12 -9
- ripperdoc/cli/commands/compact_cmd.py +7 -3
- ripperdoc/cli/commands/context_cmd.py +33 -13
- ripperdoc/cli/commands/doctor_cmd.py +27 -14
- ripperdoc/cli/commands/exit_cmd.py +1 -1
- ripperdoc/cli/commands/mcp_cmd.py +13 -8
- ripperdoc/cli/commands/memory_cmd.py +5 -5
- ripperdoc/cli/commands/models_cmd.py +47 -16
- ripperdoc/cli/commands/permissions_cmd.py +302 -0
- ripperdoc/cli/commands/resume_cmd.py +1 -2
- ripperdoc/cli/commands/tasks_cmd.py +24 -13
- ripperdoc/cli/ui/rich_ui.py +500 -406
- ripperdoc/cli/ui/tool_renderers.py +298 -0
- ripperdoc/core/agents.py +17 -9
- ripperdoc/core/config.py +130 -6
- ripperdoc/core/default_tools.py +7 -2
- ripperdoc/core/permissions.py +20 -14
- ripperdoc/core/providers/anthropic.py +107 -4
- ripperdoc/core/providers/base.py +33 -4
- ripperdoc/core/providers/gemini.py +169 -50
- ripperdoc/core/providers/openai.py +257 -23
- ripperdoc/core/query.py +294 -61
- ripperdoc/core/query_utils.py +50 -6
- ripperdoc/core/skills.py +295 -0
- ripperdoc/core/system_prompt.py +13 -7
- ripperdoc/core/tool.py +8 -6
- ripperdoc/sdk/client.py +14 -1
- ripperdoc/tools/ask_user_question_tool.py +20 -22
- ripperdoc/tools/background_shell.py +19 -13
- ripperdoc/tools/bash_tool.py +356 -209
- ripperdoc/tools/dynamic_mcp_tool.py +428 -0
- ripperdoc/tools/enter_plan_mode_tool.py +5 -2
- ripperdoc/tools/exit_plan_mode_tool.py +6 -3
- ripperdoc/tools/file_edit_tool.py +53 -10
- ripperdoc/tools/file_read_tool.py +17 -7
- ripperdoc/tools/file_write_tool.py +49 -13
- ripperdoc/tools/glob_tool.py +10 -9
- ripperdoc/tools/grep_tool.py +182 -51
- ripperdoc/tools/ls_tool.py +6 -6
- ripperdoc/tools/mcp_tools.py +106 -456
- ripperdoc/tools/multi_edit_tool.py +49 -9
- ripperdoc/tools/notebook_edit_tool.py +57 -13
- ripperdoc/tools/skill_tool.py +205 -0
- ripperdoc/tools/task_tool.py +7 -8
- ripperdoc/tools/todo_tool.py +12 -12
- ripperdoc/tools/tool_search_tool.py +5 -6
- ripperdoc/utils/coerce.py +34 -0
- ripperdoc/utils/context_length_errors.py +252 -0
- ripperdoc/utils/file_watch.py +5 -4
- ripperdoc/utils/json_utils.py +4 -4
- ripperdoc/utils/log.py +3 -3
- ripperdoc/utils/mcp.py +36 -15
- ripperdoc/utils/memory.py +9 -6
- ripperdoc/utils/message_compaction.py +16 -11
- ripperdoc/utils/messages.py +73 -8
- ripperdoc/utils/path_ignore.py +677 -0
- ripperdoc/utils/permissions/__init__.py +7 -1
- ripperdoc/utils/permissions/path_validation_utils.py +5 -3
- ripperdoc/utils/permissions/shell_command_validation.py +496 -18
- ripperdoc/utils/prompt.py +1 -1
- ripperdoc/utils/safe_get_cwd.py +5 -2
- ripperdoc/utils/session_history.py +38 -19
- ripperdoc/utils/todo.py +6 -2
- ripperdoc/utils/token_estimation.py +4 -3
- {ripperdoc-0.2.4.dist-info → ripperdoc-0.2.5.dist-info}/METADATA +12 -1
- ripperdoc-0.2.5.dist-info/RECORD +107 -0
- ripperdoc-0.2.4.dist-info/RECORD +0 -99
- {ripperdoc-0.2.4.dist-info → ripperdoc-0.2.5.dist-info}/WHEEL +0 -0
- {ripperdoc-0.2.4.dist-info → ripperdoc-0.2.5.dist-info}/entry_points.txt +0 -0
- {ripperdoc-0.2.4.dist-info → ripperdoc-0.2.5.dist-info}/licenses/LICENSE +0 -0
- {ripperdoc-0.2.4.dist-info → ripperdoc-0.2.5.dist-info}/top_level.txt +0 -0
ripperdoc/core/providers/anthropic.py
CHANGED
@@ -6,6 +6,7 @@ import asyncio
 import time
 from typing import Any, Awaitable, Callable, Dict, List, Optional

+import anthropic
 from anthropic import AsyncAnthropic

 from ripperdoc.core.config import ModelProfile
@@ -30,6 +31,38 @@ from ripperdoc.utils.session_usage import record_usage
 logger = get_logger()


+def _classify_anthropic_error(exc: Exception) -> tuple[str, str]:
+    """Classify an Anthropic exception into error code and user-friendly message."""
+    exc_type = type(exc).__name__
+    exc_msg = str(exc)
+
+    if isinstance(exc, anthropic.AuthenticationError):
+        return "authentication_error", f"Authentication failed: {exc_msg}"
+    if isinstance(exc, anthropic.PermissionDeniedError):
+        if "balance" in exc_msg.lower() or "insufficient" in exc_msg.lower():
+            return "insufficient_balance", f"Insufficient balance: {exc_msg}"
+        return "permission_denied", f"Permission denied: {exc_msg}"
+    if isinstance(exc, anthropic.NotFoundError):
+        return "model_not_found", f"Model not found: {exc_msg}"
+    if isinstance(exc, anthropic.BadRequestError):
+        if "context" in exc_msg.lower() or "token" in exc_msg.lower():
+            return "context_length_exceeded", f"Context length exceeded: {exc_msg}"
+        if "content" in exc_msg.lower() and "policy" in exc_msg.lower():
+            return "content_policy_violation", f"Content policy violation: {exc_msg}"
+        return "bad_request", f"Invalid request: {exc_msg}"
+    if isinstance(exc, anthropic.RateLimitError):
+        return "rate_limit", f"Rate limit exceeded: {exc_msg}"
+    if isinstance(exc, anthropic.APIConnectionError):
+        return "connection_error", f"Connection error: {exc_msg}"
+    if isinstance(exc, anthropic.APIStatusError):
+        status = getattr(exc, "status_code", "unknown")
+        return "api_error", f"API error ({status}): {exc_msg}"
+    if isinstance(exc, asyncio.TimeoutError):
+        return "timeout", f"Request timed out: {exc_msg}"
+
+    return "unknown_error", f"Unexpected error ({exc_type}): {exc_msg}"
+
+
 class AnthropicClient(ProviderClient):
     """Anthropic client with streaming and non-streaming support."""

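
For illustration, a minimal sketch of the classifier's behavior on the one branch that needs no SDK error object (assuming the function is importable from the module above; the message text is illustrative):

    import asyncio
    from ripperdoc.core.providers.anthropic import _classify_anthropic_error

    code, message = _classify_anthropic_error(asyncio.TimeoutError("read timed out"))
    assert code == "timeout"
    assert message == "Request timed out: read timed out"

The SDK-specific branches (AuthenticationError, RateLimitError, and so on) behave the same way, mapping each exception type to a stable error code plus a user-facing message.
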
@@ -53,10 +86,64 @@ class AnthropicClient(ProviderClient):
         progress_callback: Optional[ProgressCallback],
         request_timeout: Optional[float],
         max_retries: int,
+        max_thinking_tokens: int,
     ) -> ProviderResponse:
         start_time = time.time()
+
+        try:
+            return await self._call_impl(
+                model_profile=model_profile,
+                system_prompt=system_prompt,
+                normalized_messages=normalized_messages,
+                tools=tools,
+                tool_mode=tool_mode,
+                stream=stream,
+                progress_callback=progress_callback,
+                request_timeout=request_timeout,
+                max_retries=max_retries,
+                max_thinking_tokens=max_thinking_tokens,
+                start_time=start_time,
+            )
+        except asyncio.CancelledError:
+            raise  # Don't suppress task cancellation
+        except Exception as exc:
+            duration_ms = (time.time() - start_time) * 1000
+            error_code, error_message = _classify_anthropic_error(exc)
+            logger.error(
+                "[anthropic_client] API call failed",
+                extra={
+                    "model": model_profile.model,
+                    "error_code": error_code,
+                    "error_message": error_message,
+                    "duration_ms": round(duration_ms, 2),
+                },
+            )
+            return ProviderResponse.create_error(
+                error_code=error_code,
+                error_message=error_message,
+                duration_ms=duration_ms,
+            )
+
+    async def _call_impl(
+        self,
+        *,
+        model_profile: ModelProfile,
+        system_prompt: str,
+        normalized_messages: Any,
+        tools: List[Tool[Any, Any]],
+        tool_mode: str,
+        stream: bool,
+        progress_callback: Optional[ProgressCallback],
+        request_timeout: Optional[float],
+        max_retries: int,
+        max_thinking_tokens: int,
+        start_time: float,
+    ) -> ProviderResponse:
+        """Internal implementation of call, may raise exceptions."""
         tool_schemas = await build_anthropic_tool_schemas(tools)
         collected_text: List[str] = []
+        reasoning_parts: List[str] = []
+        response_metadata: Dict[str, Any] = {}

         anthropic_kwargs = {"base_url": model_profile.api_base}
         if model_profile.api_key:
@@ -67,6 +154,10 @@ class AnthropicClient(ProviderClient):

         normalized_messages = sanitize_tool_history(list(normalized_messages))

+        thinking_payload: Optional[Dict[str, Any]] = None
+        if max_thinking_tokens > 0:
+            thinking_payload = {"type": "enabled", "budget_tokens": max_thinking_tokens}
+
         async with await self._client(anthropic_kwargs) as client:

             async def _stream_request() -> Any:
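
The payload shape matches the `thinking` parameter of Anthropic's Messages API (extended thinking). A minimal standalone sketch of the request this enables; the model name and budget are illustrative, not taken from the diff:

    from anthropic import AsyncAnthropic

    async def demo() -> None:
        client = AsyncAnthropic()
        await client.messages.create(
            model="claude-sonnet-4-20250514",  # illustrative model id
            max_tokens=4096,  # must exceed the thinking budget
            thinking={"type": "enabled", "budget_tokens": 2048},
            messages=[{"role": "user", "content": "Why is the sky blue?"}],
        )
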
@@ -77,6 +168,7 @@ class AnthropicClient(ProviderClient):
                     messages=normalized_messages,  # type: ignore[arg-type]
                     tools=tool_schemas if tool_schemas else None,  # type: ignore
                     temperature=model_profile.temperature,
+                    thinking=thinking_payload,  # type: ignore[arg-type]
                 )
                 stream_resp = (
                     await asyncio.wait_for(stream_cm.__aenter__(), timeout=request_timeout)
@@ -90,8 +182,11 @@ class AnthropicClient(ProviderClient):
                     if progress_callback:
                         try:
                             await progress_callback(text)
-                        except
-                            logger.
+                        except (RuntimeError, ValueError, TypeError, OSError) as cb_exc:
+                            logger.warning(
+                                "[anthropic_client] Stream callback failed: %s: %s",
+                                type(cb_exc).__name__, cb_exc,
+                            )
                 getter = getattr(stream_resp, "get_final_response", None) or getattr(
                     stream_resp, "get_final_message", None
                 )
@@ -109,6 +204,7 @@ class AnthropicClient(ProviderClient):
                     messages=normalized_messages,  # type: ignore[arg-type]
                     tools=tool_schemas if tool_schemas else None,  # type: ignore
                     temperature=model_profile.temperature,
+                    thinking=thinking_payload,  # type: ignore[arg-type]
                 )

             timeout_for_call = None if stream else request_timeout
@@ -126,8 +222,14 @@ class AnthropicClient(ProviderClient):
             )

             content_blocks = content_blocks_from_anthropic_response(response, tool_mode)
-
-
+            for blk in content_blocks:
+                if blk.get("type") == "thinking":
+                    thinking_text = blk.get("thinking") or blk.get("text") or ""
+                    if thinking_text:
+                        reasoning_parts.append(str(thinking_text))
+            if reasoning_parts:
+                response_metadata["reasoning_content"] = "\n".join(reasoning_parts)
+            # Streaming progress is handled via text_stream; final content retains thinking blocks.

             logger.info(
                 "[anthropic_client] Response received",
@@ -144,4 +246,5 @@ class AnthropicClient(ProviderClient):
             usage_tokens=usage_tokens,
             cost_usd=cost_usd,
             duration_ms=duration_ms,
+            metadata=response_metadata,
         )

ripperdoc/core/providers/base.py
CHANGED
@@ -5,7 +5,7 @@ from __future__ import annotations
 import asyncio
 import random
 from abc import ABC, abstractmethod
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from typing import (
     Any,
     AsyncIterable,
@@ -35,6 +35,29 @@ class ProviderResponse:
     usage_tokens: Dict[str, int]
     cost_usd: float
     duration_ms: float
+    metadata: Dict[str, Any] = field(default_factory=dict)
+    # Error handling fields
+    is_error: bool = False
+    error_code: Optional[str] = None  # e.g., "permission_denied", "context_length_exceeded"
+    error_message: Optional[str] = None
+
+    @classmethod
+    def create_error(
+        cls,
+        error_code: str,
+        error_message: str,
+        duration_ms: float = 0.0,
+    ) -> "ProviderResponse":
+        """Create an error response with a text block containing the error message."""
+        return cls(
+            content_blocks=[{"type": "text", "text": f"[API Error] {error_message}"}],
+            usage_tokens={},
+            cost_usd=0.0,
+            duration_ms=duration_ms,
+            is_error=True,
+            error_code=error_code,
+            error_message=error_message,
+        )


 class ProviderClient(ABC):
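
A minimal sketch of how a caller can consume the new error fields; everything referenced here is defined in the hunk above:

    resp = ProviderResponse.create_error(
        error_code="rate_limit",
        error_message="Rate limit exceeded: retry later",
        duration_ms=12.5,
    )
    assert resp.is_error and resp.error_code == "rate_limit"
    # The text block carries a user-facing rendering of the same error:
    assert resp.content_blocks[0]["text"] == "[API Error] Rate limit exceeded: retry later"
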
@@ -53,6 +76,7 @@ class ProviderClient(ABC):
         progress_callback: Optional[ProgressCallback],
         request_timeout: Optional[float],
         max_retries: int,
+        max_thinking_tokens: int,
     ) -> ProviderResponse:
         """Execute a model call and return a normalized response."""

@@ -170,6 +194,7 @@ def _retry_delay_seconds(attempt: int, base_delay: float = 0.5, max_delay: float
     jitter: float = float(random.random() * 0.25 * capped_base)
     return float(capped_base + jitter)

+
 async def iter_with_timeout(
     stream: Iterable[Any] | AsyncIterable[Any], timeout: Optional[float]
 ) -> AsyncIterator[Any]:
@@ -194,7 +219,9 @@ async def iter_with_timeout(
         iterator = iter(stream)
         while True:
             try:
-                next_item = await asyncio.wait_for(
+                next_item = await asyncio.wait_for(
+                    asyncio.to_thread(next, iterator), timeout=timeout
+                )
             except StopIteration:
                 break
             yield next_item
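
The rewrite wraps the blocking next() call in asyncio.to_thread so that asyncio.wait_for has an awaitable it can actually time out, instead of blocking the event loop on a synchronous iterator. A minimal standalone sketch of the same pattern (names are illustrative):

    import asyncio

    async def first_item(it, timeout: float):
        # next(it) runs in a worker thread; wait_for can now enforce the deadline.
        return await asyncio.wait_for(asyncio.to_thread(next, it), timeout=timeout)

    # asyncio.run(first_item(iter([1, 2, 3]), timeout=1.0)) returns 1
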
@@ -228,9 +255,11 @@ async def call_with_timeout_and_retries(
             },
         )
             await asyncio.sleep(delay_seconds)
-        except
+        except asyncio.CancelledError:
+            raise  # Don't suppress task cancellation
+        except (RuntimeError, ValueError, TypeError, OSError, ConnectionError) as exc:
             # Non-timeout errors are not retried; surface immediately.
-            raise
+            raise exc
     if last_error:
         raise RuntimeError(f"Request timed out after {attempts} attempts") from last_error
     raise RuntimeError("Unexpected error executing request with retries")

ripperdoc/core/providers/gemini.py
CHANGED
@@ -2,11 +2,12 @@

 from __future__ import annotations

+import asyncio
 import copy
 import inspect
 import os
 import time
-from typing import Any,
+from typing import Any, AsyncIterator, Dict, List, Optional, Tuple, cast
 from uuid import uuid4

 from ripperdoc.core.config import ModelProfile
@@ -27,13 +28,55 @@ logger = get_logger()

 # Constants
 GEMINI_SDK_IMPORT_ERROR = (
-    "Gemini client requires the 'google-genai' package. "
-    "Install it with: pip install google-genai"
+    "Gemini client requires the 'google-genai' package. Install it with: pip install google-genai"
 )
 GEMINI_MODELS_ENDPOINT_ERROR = "Gemini client is missing 'models' endpoint"
 GEMINI_GENERATE_CONTENT_ERROR = "Gemini client is missing generate_content() method"


+def _classify_gemini_error(exc: Exception) -> tuple[str, str]:
+    """Classify a Gemini exception into error code and user-friendly message."""
+    exc_type = type(exc).__name__
+    exc_msg = str(exc)
+
+    # Try to import Google's exception types for more specific handling
+    try:
+        from google.api_core import exceptions as google_exceptions  # type: ignore
+
+        if isinstance(exc, google_exceptions.Unauthenticated):
+            return "authentication_error", f"Authentication failed: {exc_msg}"
+        if isinstance(exc, google_exceptions.PermissionDenied):
+            return "permission_denied", f"Permission denied: {exc_msg}"
+        if isinstance(exc, google_exceptions.NotFound):
+            return "model_not_found", f"Model not found: {exc_msg}"
+        if isinstance(exc, google_exceptions.InvalidArgument):
+            if "context" in exc_msg.lower() or "token" in exc_msg.lower():
+                return "context_length_exceeded", f"Context length exceeded: {exc_msg}"
+            return "bad_request", f"Invalid request: {exc_msg}"
+        if isinstance(exc, google_exceptions.ResourceExhausted):
+            return "rate_limit", f"Rate limit exceeded: {exc_msg}"
+        if isinstance(exc, google_exceptions.ServiceUnavailable):
+            return "service_unavailable", f"Service unavailable: {exc_msg}"
+        if isinstance(exc, google_exceptions.GoogleAPICallError):
+            return "api_error", f"API error: {exc_msg}"
+    except ImportError:
+        pass
+
+    # Fallback for generic exceptions
+    if isinstance(exc, asyncio.TimeoutError):
+        return "timeout", f"Request timed out: {exc_msg}"
+    if isinstance(exc, ConnectionError):
+        return "connection_error", f"Connection error: {exc_msg}"
+    if "quota" in exc_msg.lower() or "limit" in exc_msg.lower():
+        return "rate_limit", f"Rate limit exceeded: {exc_msg}"
+    if "auth" in exc_msg.lower() or "key" in exc_msg.lower():
+        return "authentication_error", f"Authentication error: {exc_msg}"
+    if "not found" in exc_msg.lower():
+        return "model_not_found", f"Model not found: {exc_msg}"
+
+    return "unknown_error", f"Unexpected error ({exc_type}): {exc_msg}"
+
+
 def _extract_usage_metadata(payload: Any) -> Dict[str, int]:
     """Best-effort token extraction from Gemini responses."""
     usage = getattr(payload, "usage_metadata", None) or getattr(payload, "usageMetadata", None)
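
A minimal sketch of the string-matching fallback, which applies when the google.api_core exception types are unavailable or do not match (assuming the function is importable from the module above; the messages are illustrative):

    code, _ = _classify_gemini_error(RuntimeError("API key not valid"))
    assert code == "authentication_error"  # matched on "key"
    code, _ = _classify_gemini_error(RuntimeError("quota exceeded for project"))
    assert code == "rate_limit"  # matched on "quota"
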
@@ -49,9 +92,13 @@ def _extract_usage_metadata(payload: Any) -> Dict[str, int]:
         value = getattr(usage, key, 0)
         return int(value) if value else 0

+    thought_tokens = safe_get_int("thoughts_token_count")
+    candidate_tokens = safe_get_int("candidates_token_count")
+
     return {
-        "input_tokens": safe_get_int("prompt_token_count")
-
+        "input_tokens": safe_get_int("prompt_token_count")
+        + safe_get_int("cached_content_token_count"),
+        "output_tokens": candidate_tokens + thought_tokens,
         "cache_read_input_tokens": safe_get_int("cached_content_token_count"),
         "cache_creation_input_tokens": 0,
     }
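
For illustration, the arithmetic the new accounting performs on a response whose usage_metadata reports prompt, cached, candidate, and thought tokens (the stub objects below are hypothetical):

    from types import SimpleNamespace

    payload = SimpleNamespace(usage_metadata=SimpleNamespace(
        prompt_token_count=100,
        cached_content_token_count=40,
        candidates_token_count=200,
        thoughts_token_count=50,
    ))
    # _extract_usage_metadata(payload) now yields:
    #   input_tokens=140 (prompt + cached), output_tokens=250 (candidates + thoughts),
    #   cache_read_input_tokens=40, cache_creation_input_tokens=0
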
@@ -72,8 +119,10 @@ def _collect_parts(candidate: Any) -> List[Any]:
 def _collect_text_from_parts(parts: List[Any]) -> str:
     texts: List[str] = []
     for part in parts:
-        text_val =
-            part, "
+        text_val = (
+            getattr(part, "text", None)
+            or getattr(part, "content", None)
+            or getattr(part, "raw_text", None)
         )
         if isinstance(text_val, str):
             texts.append(text_val)
@@ -143,27 +192,64 @@ def _supports_stream_arg(fn: Any) -> bool:
     return False


+def _build_thinking_config(max_thinking_tokens: int, model_name: str) -> Dict[str, Any]:
+    """Map max_thinking_tokens to Gemini thinking_config settings."""
+    if max_thinking_tokens <= 0:
+        return {}
+    name = (model_name or "").lower()
+    config: Dict[str, Any] = {"include_thoughts": True}
+    if "gemini-3" in name:
+        config["thinking_level"] = "low" if max_thinking_tokens <= 2048 else "high"
+    else:
+        config["thinking_budget"] = max_thinking_tokens
+    return config
+
+
+def _collect_thoughts_from_parts(parts: List[Any]) -> List[str]:
+    """Extract thought summaries from parts flagged as thoughts."""
+    snippets: List[str] = []
+    for part in parts:
+        is_thought = getattr(part, "thought", None)
+        if is_thought is None and isinstance(part, dict):
+            is_thought = part.get("thought")
+        if not is_thought:
+            continue
+        text_val = (
+            getattr(part, "text", None)
+            or getattr(part, "content", None)
+            or getattr(part, "raw_text", None)
+        )
+        if isinstance(text_val, str):
+            snippets.append(text_val)
+    return snippets
+
+
 async def _async_build_tool_declarations(tools: List[Tool[Any, Any]]) -> List[Dict[str, Any]]:
     declarations: List[Dict[str, Any]] = []
     try:
         from google.genai import types as genai_types  # type: ignore
-    except
-        genai_types = None
+    except (ImportError, ModuleNotFoundError):  # pragma: no cover - fallback when SDK not installed
+        genai_types = None  # type: ignore[assignment]

     for tool in tools:
         description = await build_tool_description(tool, include_examples=True, max_examples=2)
         parameters_schema = _flatten_schema(tool.input_schema.model_json_schema())
         if genai_types:
+            func_decl = genai_types.FunctionDeclaration(
+                name=tool.name,
+                description=description,
+                parameters_json_schema=parameters_schema,
+            )
             declarations.append(
-
-                    name=tool.name,
-                    description=description,
-                    parameters=genai_types.Schema(**parameters_schema),
-                )
+                func_decl.model_dump(mode="json", exclude_none=True)
             )
         else:
             declarations.append(
-                {
+                {
+                    "name": tool.name,
+                    "description": description,
+                    "parameters_json_schema": parameters_schema,
+                }
             )
     return declarations

@@ -183,8 +269,8 @@ def _convert_messages_to_genai_contents(
     # Lazy import to avoid hard dependency in tests.
     try:
         from google.genai import types as genai_types  # type: ignore
-    except
-        genai_types = None
+    except (ImportError, ModuleNotFoundError):  # pragma: no cover - fallback when SDK not installed
+        genai_types = None  # type: ignore[assignment]

     def _mk_part_from_text(text: str) -> Any:
         if genai_types:
@@ -268,19 +354,19 @@ class GeminiClient(ProviderClient):

         try:
             from google import genai  # type: ignore
-        except
+        except (ImportError, ModuleNotFoundError) as exc:  # pragma: no cover - import guard
             raise RuntimeError(GEMINI_SDK_IMPORT_ERROR) from exc

         client_kwargs: Dict[str, Any] = {}
-        api_key =
+        api_key = (
+            model_profile.api_key or os.getenv("GEMINI_API_KEY") or os.getenv("GOOGLE_API_KEY")
+        )
         if api_key:
             client_kwargs["api_key"] = api_key
         if model_profile.api_base:
             from google.genai import types as genai_types  # type: ignore

-            client_kwargs["http_options"] = genai_types.HttpOptions(
-                base_url=model_profile.api_base
-            )
+            client_kwargs["http_options"] = genai_types.HttpOptions(base_url=model_profile.api_base)
         return genai.Client(**client_kwargs)

     async def call(
@@ -295,19 +381,30 @@ class GeminiClient(ProviderClient):
         progress_callback: Optional[ProgressCallback],
         request_timeout: Optional[float],
         max_retries: int,
+        max_thinking_tokens: int,
     ) -> ProviderResponse:
         start_time = time.time()

         try:
             client = await self._client(model_profile)
+        except asyncio.CancelledError:
+            raise  # Don't suppress task cancellation
         except Exception as exc:
-
-
-
-
-
-
-
+            duration_ms = (time.time() - start_time) * 1000
+            error_code, error_message = _classify_gemini_error(exc)
+            logger.error(
+                "[gemini_client] Initialization failed",
+                extra={
+                    "model": model_profile.model,
+                    "error_code": error_code,
+                    "error_message": error_message,
+                    "duration_ms": round(duration_ms, 2),
+                },
+            )
+            return ProviderResponse.create_error(
+                error_code=error_code,
+                error_message=error_message,
+                duration_ms=duration_ms,
             )

         declarations: List[Dict[str, Any]] = []
@@ -319,13 +416,16 @@ class GeminiClient(ProviderClient):
         config: Dict[str, Any] = {"system_instruction": system_prompt}
         if model_profile.max_tokens:
             config["max_output_tokens"] = model_profile.max_tokens
-
+        thinking_config = _build_thinking_config(max_thinking_tokens, model_profile.model)
+        if thinking_config:
             try:
                 from google.genai import types as genai_types  # type: ignore

-                config["
-            except
-                config["
+                config["thinking_config"] = genai_types.ThinkingConfig(**thinking_config)
+            except (ImportError, ModuleNotFoundError, TypeError, ValueError):  # pragma: no cover - fallback when SDK not installed
+                config["thinking_config"] = thinking_config
+        if declarations:
+            config["tools"] = [{"function_declarations": declarations}]

         generate_kwargs: Dict[str, Any] = {
             "model": model_profile.model,
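
A minimal sketch of what _build_thinking_config (added earlier in this file) contributes to the request config; the model names are illustrative:

    _build_thinking_config(0, "gemini-2.5-flash")
    # -> {} (thinking disabled, no thinking_config is set)
    _build_thinking_config(1024, "gemini-3-pro")
    # -> {"include_thoughts": True, "thinking_level": "low"}
    _build_thinking_config(8192, "gemini-3-pro")
    # -> {"include_thoughts": True, "thinking_level": "high"}
    _build_thinking_config(8192, "gemini-2.5-flash")
    # -> {"include_thoughts": True, "thinking_budget": 8192}
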
@@ -335,6 +435,8 @@ class GeminiClient(ProviderClient):
         usage_tokens: Dict[str, int] = {}
         collected_text: List[str] = []
         function_calls: List[Dict[str, Any]] = []
+        reasoning_parts: List[str] = []
+        response_metadata: Dict[str, Any] = {}

         async def _call_generate(streaming: bool) -> Any:
             models_api = getattr(client, "models", None) or getattr(
@@ -379,11 +481,6 @@ class GeminiClient(ProviderClient):
             if generate_fn is None:
                 raise RuntimeError(GEMINI_GENERATE_CONTENT_ERROR)

-            result = generate_fn(**generate_kwargs)
-            if inspect.isawaitable(result):
-                return await result
-            return result
-
         try:
             if stream:
                 stream_resp = await _call_generate(streaming=True)
@@ -393,12 +490,14 @@ class GeminiClient(ProviderClient):
             def _to_async_iter(obj: Any) -> AsyncIterator[Any]:
                 """Convert various iterable types to async generator."""
                 if inspect.isasyncgen(obj) or hasattr(obj, "__aiter__"):
+
                     async def _wrap_async() -> AsyncIterator[Any]:
                         async for item in obj:
                             yield item

                     return _wrap_async()
                 if hasattr(obj, "__iter__"):
+
                     async def _wrap_sync() -> AsyncIterator[Any]:
                         for item in obj:
                             yield item
@@ -416,14 +515,19 @@ class GeminiClient(ProviderClient):
                 candidates = getattr(chunk, "candidates", None) or []
                 for candidate in candidates:
                     parts = _collect_parts(candidate)
+                    text_chunk = _collect_text_from_parts(parts)
                     if progress_callback:
-
-                        if text_delta:
+                        if text_chunk:
                             try:
-                                await progress_callback(
-                            except
-                                logger.
-
+                                await progress_callback(text_chunk)
+                            except (RuntimeError, ValueError, TypeError, OSError) as cb_exc:
+                                logger.warning(
+                                    "[gemini_client] Stream callback failed: %s: %s",
+                                    type(cb_exc).__name__, cb_exc,
+                                )
+                    if text_chunk:
+                        collected_text.append(text_chunk)
+                    reasoning_parts.extend(_collect_thoughts_from_parts(parts))
                     function_calls.extend(_extract_function_calls(parts))
                     usage_tokens = _extract_usage_metadata(chunk) or usage_tokens
             else:
@@ -437,24 +541,38 @@ class GeminiClient(ProviderClient):
                 if candidates:
                     parts = _collect_parts(candidates[0])
                     collected_text.append(_collect_text_from_parts(parts))
+                    reasoning_parts.extend(_collect_thoughts_from_parts(parts))
                     function_calls.extend(_extract_function_calls(parts))
                 else:
                     # Fallback: try to read text directly
                     collected_text.append(getattr(response, "text", "") or "")
                 usage_tokens = _extract_usage_metadata(response)
+        except asyncio.CancelledError:
+            raise  # Don't suppress task cancellation
         except Exception as exc:
-
-
-
-
-
-
+            duration_ms = (time.time() - start_time) * 1000
+            error_code, error_message = _classify_gemini_error(exc)
+            logger.error(
+                "[gemini_client] API call failed",
+                extra={
+                    "model": model_profile.model,
+                    "error_code": error_code,
+                    "error_message": error_message,
+                    "duration_ms": round(duration_ms, 2),
+                },
+            )
+            return ProviderResponse.create_error(
+                error_code=error_code,
+                error_message=error_message,
+                duration_ms=duration_ms,
             )

         content_blocks: List[Dict[str, Any]] = []
         combined_text = "".join(collected_text).strip()
         if combined_text:
             content_blocks.append({"type": "text", "text": combined_text})
+        if reasoning_parts:
+            response_metadata["reasoning_content"] = "".join(reasoning_parts)

         for call in function_calls:
             if not call.get("name"):
@@ -493,4 +611,5 @@ class GeminiClient(ProviderClient):
             usage_tokens=usage_tokens,
             cost_usd=cost_usd,
             duration_ms=duration_ms,
+            metadata=response_metadata,
         )
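
Taken together, a minimal sketch of how a caller could surface the new fields on the normalized response; the attribute names come from the base.py hunks above, while the call site itself is illustrative:

    resp = await client.call(..., max_thinking_tokens=2048)
    if resp.is_error:
        print(f"{resp.error_code}: {resp.error_message}")
    elif "reasoning_content" in resp.metadata:
        print("model reasoning:", resp.metadata["reasoning_content"])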