ripperdoc 0.2.3-py3-none-any.whl → 0.2.5-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ripperdoc/__init__.py +1 -1
- ripperdoc/__main__.py +0 -5
- ripperdoc/cli/cli.py +37 -16
- ripperdoc/cli/commands/__init__.py +2 -0
- ripperdoc/cli/commands/agents_cmd.py +12 -9
- ripperdoc/cli/commands/compact_cmd.py +7 -3
- ripperdoc/cli/commands/context_cmd.py +35 -15
- ripperdoc/cli/commands/doctor_cmd.py +27 -14
- ripperdoc/cli/commands/exit_cmd.py +1 -1
- ripperdoc/cli/commands/mcp_cmd.py +13 -8
- ripperdoc/cli/commands/memory_cmd.py +5 -5
- ripperdoc/cli/commands/models_cmd.py +47 -16
- ripperdoc/cli/commands/permissions_cmd.py +302 -0
- ripperdoc/cli/commands/resume_cmd.py +1 -2
- ripperdoc/cli/commands/tasks_cmd.py +24 -13
- ripperdoc/cli/ui/rich_ui.py +523 -396
- ripperdoc/cli/ui/tool_renderers.py +298 -0
- ripperdoc/core/agents.py +172 -4
- ripperdoc/core/config.py +130 -6
- ripperdoc/core/default_tools.py +13 -2
- ripperdoc/core/permissions.py +20 -14
- ripperdoc/core/providers/__init__.py +31 -15
- ripperdoc/core/providers/anthropic.py +122 -8
- ripperdoc/core/providers/base.py +93 -15
- ripperdoc/core/providers/gemini.py +539 -96
- ripperdoc/core/providers/openai.py +371 -26
- ripperdoc/core/query.py +301 -62
- ripperdoc/core/query_utils.py +51 -7
- ripperdoc/core/skills.py +295 -0
- ripperdoc/core/system_prompt.py +79 -67
- ripperdoc/core/tool.py +15 -6
- ripperdoc/sdk/client.py +14 -1
- ripperdoc/tools/ask_user_question_tool.py +431 -0
- ripperdoc/tools/background_shell.py +82 -26
- ripperdoc/tools/bash_tool.py +356 -209
- ripperdoc/tools/dynamic_mcp_tool.py +428 -0
- ripperdoc/tools/enter_plan_mode_tool.py +226 -0
- ripperdoc/tools/exit_plan_mode_tool.py +153 -0
- ripperdoc/tools/file_edit_tool.py +53 -10
- ripperdoc/tools/file_read_tool.py +17 -7
- ripperdoc/tools/file_write_tool.py +49 -13
- ripperdoc/tools/glob_tool.py +10 -9
- ripperdoc/tools/grep_tool.py +182 -51
- ripperdoc/tools/ls_tool.py +6 -6
- ripperdoc/tools/mcp_tools.py +172 -413
- ripperdoc/tools/multi_edit_tool.py +49 -9
- ripperdoc/tools/notebook_edit_tool.py +57 -13
- ripperdoc/tools/skill_tool.py +205 -0
- ripperdoc/tools/task_tool.py +91 -9
- ripperdoc/tools/todo_tool.py +12 -12
- ripperdoc/tools/tool_search_tool.py +5 -6
- ripperdoc/utils/coerce.py +34 -0
- ripperdoc/utils/context_length_errors.py +252 -0
- ripperdoc/utils/file_watch.py +5 -4
- ripperdoc/utils/json_utils.py +4 -4
- ripperdoc/utils/log.py +3 -3
- ripperdoc/utils/mcp.py +82 -22
- ripperdoc/utils/memory.py +9 -6
- ripperdoc/utils/message_compaction.py +19 -16
- ripperdoc/utils/messages.py +73 -8
- ripperdoc/utils/path_ignore.py +677 -0
- ripperdoc/utils/permissions/__init__.py +7 -1
- ripperdoc/utils/permissions/path_validation_utils.py +5 -3
- ripperdoc/utils/permissions/shell_command_validation.py +496 -18
- ripperdoc/utils/prompt.py +1 -1
- ripperdoc/utils/safe_get_cwd.py +5 -2
- ripperdoc/utils/session_history.py +38 -19
- ripperdoc/utils/todo.py +6 -2
- ripperdoc/utils/token_estimation.py +34 -0
- {ripperdoc-0.2.3.dist-info → ripperdoc-0.2.5.dist-info}/METADATA +14 -1
- ripperdoc-0.2.5.dist-info/RECORD +107 -0
- ripperdoc-0.2.3.dist-info/RECORD +0 -95
- {ripperdoc-0.2.3.dist-info → ripperdoc-0.2.5.dist-info}/WHEEL +0 -0
- {ripperdoc-0.2.3.dist-info → ripperdoc-0.2.5.dist-info}/entry_points.txt +0 -0
- {ripperdoc-0.2.3.dist-info → ripperdoc-0.2.5.dist-info}/licenses/LICENSE +0 -0
- {ripperdoc-0.2.3.dist-info → ripperdoc-0.2.5.dist-info}/top_level.txt +0 -0
ripperdoc/core/providers/gemini.py

@@ -1,10 +1,14 @@
-"""Gemini provider client."""
+"""Gemini provider client with function/tool calling support."""
 
 from __future__ import annotations
 
+import asyncio
+import copy
+import inspect
 import os
 import time
-from typing import Any, Dict, List, Optional
+from typing import Any, AsyncIterator, Dict, List, Optional, Tuple, cast
+from uuid import uuid4
 
 from ripperdoc.core.config import ModelProfile
 from ripperdoc.core.providers.base import (
@@ -12,43 +16,358 @@ from ripperdoc.core.providers.base import (
     ProviderClient,
     ProviderResponse,
     call_with_timeout_and_retries,
+    iter_with_timeout,
 )
+from ripperdoc.core.query_utils import _normalize_tool_args, build_tool_description
 from ripperdoc.core.tool import Tool
 from ripperdoc.utils.log import get_logger
+from ripperdoc.utils.session_usage import record_usage
+from ripperdoc.core.query_utils import estimate_cost_usd
 
 logger = get_logger()
 
+# Constants
+GEMINI_SDK_IMPORT_ERROR = (
+    "Gemini client requires the 'google-genai' package. Install it with: pip install google-genai"
+)
+GEMINI_MODELS_ENDPOINT_ERROR = "Gemini client is missing 'models' endpoint"
+GEMINI_GENERATE_CONTENT_ERROR = "Gemini client is missing generate_content() method"
+
+
+def _classify_gemini_error(exc: Exception) -> tuple[str, str]:
+    """Classify a Gemini exception into error code and user-friendly message."""
+    exc_type = type(exc).__name__
+    exc_msg = str(exc)
+
+    # Try to import Google's exception types for more specific handling
+    try:
+        from google.api_core import exceptions as google_exceptions  # type: ignore
+
+        if isinstance(exc, google_exceptions.Unauthenticated):
+            return "authentication_error", f"Authentication failed: {exc_msg}"
+        if isinstance(exc, google_exceptions.PermissionDenied):
+            return "permission_denied", f"Permission denied: {exc_msg}"
+        if isinstance(exc, google_exceptions.NotFound):
+            return "model_not_found", f"Model not found: {exc_msg}"
+        if isinstance(exc, google_exceptions.InvalidArgument):
+            if "context" in exc_msg.lower() or "token" in exc_msg.lower():
+                return "context_length_exceeded", f"Context length exceeded: {exc_msg}"
+            return "bad_request", f"Invalid request: {exc_msg}"
+        if isinstance(exc, google_exceptions.ResourceExhausted):
+            return "rate_limit", f"Rate limit exceeded: {exc_msg}"
+        if isinstance(exc, google_exceptions.ServiceUnavailable):
+            return "service_unavailable", f"Service unavailable: {exc_msg}"
+        if isinstance(exc, google_exceptions.GoogleAPICallError):
+            return "api_error", f"API error: {exc_msg}"
+    except ImportError:
+        pass
+
+    # Fallback for generic exceptions
+    if isinstance(exc, asyncio.TimeoutError):
+        return "timeout", f"Request timed out: {exc_msg}"
+    if isinstance(exc, ConnectionError):
+        return "connection_error", f"Connection error: {exc_msg}"
+    if "quota" in exc_msg.lower() or "limit" in exc_msg.lower():
+        return "rate_limit", f"Rate limit exceeded: {exc_msg}"
+    if "auth" in exc_msg.lower() or "key" in exc_msg.lower():
+        return "authentication_error", f"Authentication error: {exc_msg}"
+    if "not found" in exc_msg.lower():
+        return "model_not_found", f"Model not found: {exc_msg}"
+
+    return "unknown_error", f"Unexpected error ({exc_type}): {exc_msg}"
+
 
 def _extract_usage_metadata(payload: Any) -> Dict[str, int]:
     """Best-effort token extraction from Gemini responses."""
     usage = getattr(payload, "usage_metadata", None) or getattr(payload, "usageMetadata", None)
     if not usage:
         usage = getattr(payload, "usage", None)
-
+    if not usage and getattr(payload, "candidates", None):
+        usage = getattr(payload.candidates[0], "usage_metadata", None)
+
+    def safe_get_int(key: str) -> int:
+        """Safely extract integer value from usage metadata."""
+        if not usage:
+            return 0
+        value = getattr(usage, key, 0)
+        return int(value) if value else 0
+
+    thought_tokens = safe_get_int("thoughts_token_count")
+    candidate_tokens = safe_get_int("candidates_token_count")
+
     return {
-        "input_tokens":
-
-        "
+        "input_tokens": safe_get_int("prompt_token_count")
+        + safe_get_int("cached_content_token_count"),
+        "output_tokens": candidate_tokens + thought_tokens,
+        "cache_read_input_tokens": safe_get_int("cached_content_token_count"),
         "cache_creation_input_tokens": 0,
     }
 
 
-def
-parts
-
-
-
-
-
-
-
-
-
-
+def _collect_parts(candidate: Any) -> List[Any]:
+    """Return a list of parts from a candidate regardless of SDK shape."""
+    content = getattr(candidate, "content", None)
+    if content is None:
+        return []
+    if hasattr(content, "parts"):
+        return list(getattr(content, "parts", []) or [])
+    if isinstance(content, list):
+        return content
+    return []
+
+
+def _collect_text_from_parts(parts: List[Any]) -> str:
+    texts: List[str] = []
+    for part in parts:
+        text_val = (
+            getattr(part, "text", None)
+            or getattr(part, "content", None)
+            or getattr(part, "raw_text", None)
+        )
+        if isinstance(text_val, str):
+            texts.append(text_val)
+    return "".join(texts)
+
+
+def _extract_function_calls(parts: List[Any]) -> List[Dict[str, Any]]:
+    calls: List[Dict[str, Any]] = []
+    for part in parts:
+        fn_call = getattr(part, "function_call", None) or getattr(part, "functionCall", None)
+        if not fn_call:
+            continue
+        name = getattr(fn_call, "name", None) or getattr(fn_call, "function_name", None)
+        args = getattr(fn_call, "args", None) or getattr(fn_call, "arguments", None) or {}
+        call_id = getattr(fn_call, "id", None) or getattr(fn_call, "call_id", None)
+        calls.append({"name": name, "args": _normalize_tool_args(args), "id": call_id})
+    return calls
+
+
+def _flatten_schema(schema: Dict[str, Any]) -> Dict[str, Any]:
+    """Inline $ref entries and drop $defs/$ref for Gemini Schema compatibility.
+
+    Gemini API doesn't support JSON Schema references, so this function
+    resolves all $ref pointers by inlining the referenced definitions.
+    """
+    definitions = copy.deepcopy(schema.get("$defs") or schema.get("definitions") or {})
+
+    def _resolve(node: Any) -> Any:
+        """Recursively resolve $ref pointers and remove unsupported fields."""
+        if isinstance(node, dict):
+            # Handle $ref resolution
+            ref = node.get("$ref")
+            if isinstance(ref, str) and ref.startswith("#/"):
+                ref_key = ref.split("/")[-1]
+                if ref_key in definitions:
+                    return _resolve(copy.deepcopy(definitions[ref_key]))
+
+            # Process remaining fields, excluding schema metadata
+            resolved: Dict[str, Any] = {}
+            for key, value in node.items():
+                if key in {"$ref", "$defs", "definitions"}:
+                    continue
+                resolved[key] = _resolve(value)
+            return resolved
+
+        if isinstance(node, list):
+            return [_resolve(item) for item in node]
+
+        return node
+
+    return cast(Dict[str, Any], _resolve(copy.deepcopy(schema)))
+
+
+def _supports_stream_arg(fn: Any) -> bool:
+    """Return True if the callable appears to accept a 'stream' kwarg."""
+    try:
+        sig = inspect.signature(fn)
+    except (TypeError, ValueError):
+        # If we cannot inspect, avoid passing stream to prevent TypeErrors.
+        return False
+
+    for param in sig.parameters.values():
+        if param.kind == param.VAR_KEYWORD:
+            return True
+        if param.name == "stream":
+            return True
+    return False
+
+
+def _build_thinking_config(max_thinking_tokens: int, model_name: str) -> Dict[str, Any]:
+    """Map max_thinking_tokens to Gemini thinking_config settings."""
+    if max_thinking_tokens <= 0:
+        return {}
+    name = (model_name or "").lower()
+    config: Dict[str, Any] = {"include_thoughts": True}
+    if "gemini-3" in name:
+        config["thinking_level"] = "low" if max_thinking_tokens <= 2048 else "high"
+    else:
+        config["thinking_budget"] = max_thinking_tokens
+    return config
+
+
+def _collect_thoughts_from_parts(parts: List[Any]) -> List[str]:
+    """Extract thought summaries from parts flagged as thoughts."""
+    snippets: List[str] = []
+    for part in parts:
+        is_thought = getattr(part, "thought", None)
+        if is_thought is None and isinstance(part, dict):
+            is_thought = part.get("thought")
+        if not is_thought:
+            continue
+        text_val = (
+            getattr(part, "text", None)
+            or getattr(part, "content", None)
+            or getattr(part, "raw_text", None)
+        )
+        if isinstance(text_val, str):
+            snippets.append(text_val)
+    return snippets
+
+
+async def _async_build_tool_declarations(tools: List[Tool[Any, Any]]) -> List[Dict[str, Any]]:
+    declarations: List[Dict[str, Any]] = []
+    try:
+        from google.genai import types as genai_types  # type: ignore
+    except (ImportError, ModuleNotFoundError):  # pragma: no cover - fallback when SDK not installed
+        genai_types = None  # type: ignore[assignment]
+
+    for tool in tools:
+        description = await build_tool_description(tool, include_examples=True, max_examples=2)
+        parameters_schema = _flatten_schema(tool.input_schema.model_json_schema())
+        if genai_types:
+            func_decl = genai_types.FunctionDeclaration(
+                name=tool.name,
+                description=description,
+                parameters_json_schema=parameters_schema,
+            )
+            declarations.append(
+                func_decl.model_dump(mode="json", exclude_none=True)
+            )
+        else:
+            declarations.append(
+                {
+                    "name": tool.name,
+                    "description": description,
+                    "parameters_json_schema": parameters_schema,
+                }
+            )
+    return declarations
+
+
+def _convert_messages_to_genai_contents(
+    normalized_messages: List[Dict[str, Any]],
+) -> Tuple[List[Any], Dict[str, str]]:
+    """Map normalized OpenAI-style messages to Gemini content payloads.
+
+    Returns:
+        contents: List of Content-like dicts/objects
+        tool_name_by_id: Map of tool_call_id -> function name (for pairing responses)
+    """
+    tool_name_by_id: Dict[str, str] = {}
+    contents: List[Any] = []
+
+    # Lazy import to avoid hard dependency in tests.
+    try:
+        from google.genai import types as genai_types  # type: ignore
+    except (ImportError, ModuleNotFoundError):  # pragma: no cover - fallback when SDK not installed
+        genai_types = None  # type: ignore[assignment]
+
+    def _mk_part_from_text(text: str) -> Any:
+        if genai_types:
+            return genai_types.Part(text=text)
+        return {"text": text}
+
+    def _mk_part_from_function_call(name: str, args: Dict[str, Any], call_id: Optional[str]) -> Any:
+        # Store mapping using actual call_id if available, otherwise generate one
+        actual_id = call_id or str(uuid4())
+        tool_name_by_id[actual_id] = name
+        if genai_types:
+            return genai_types.Part(function_call=genai_types.FunctionCall(name=name, args=args))
+        return {"function_call": {"name": name, "args": args, "id": actual_id}}
+
+    def _mk_part_from_function_response(
+        name: str, response: Dict[str, Any], call_id: Optional[str]
+    ) -> Any:
+        if call_id:
+            response = {**response, "call_id": call_id}
+        if genai_types:
+            return genai_types.Part.from_function_response(name=name, response=response)
+        payload = {"function_response": {"name": name, "response": response}}
+        if call_id:
+            payload["function_response"]["id"] = call_id
+        return payload
+
+    def _mk_content(role: str, parts: List[Any]) -> Any:
+        if genai_types:
+            return genai_types.Content(role=role, parts=parts)
+        return {"role": role, "parts": parts}
+
+    for message in normalized_messages:
+        role = message.get("role") or ""
+        msg_parts: List[Any] = []
+
+        # Assistant tool calls
+        for tool_call in message.get("tool_calls") or []:
+            func = tool_call.get("function") or {}
+            name = func.get("name") or ""
+            args = _normalize_tool_args(func.get("arguments") or {})
+            call_id = tool_call.get("id")
+            msg_parts.append(_mk_part_from_function_call(name, args, call_id))
+
+        content_value = message.get("content")
+        if isinstance(content_value, str) and content_value:
+            msg_parts.append(_mk_part_from_text(content_value))
+
+        if role == "tool":
+            call_id = message.get("tool_call_id") or ""
+            name = tool_name_by_id.get(call_id, call_id or "tool_response")
+            response = {"result": content_value}
+            msg_parts.append(_mk_part_from_function_response(name, response, call_id))
+            role = "user"  # Tool responses are treated as user-provided context
+
+        if not msg_parts:
+            continue
+
+        mapped_role = "user" if role == "user" else "model"
+        contents.append(_mk_content(mapped_role, msg_parts))
+
+    return contents, tool_name_by_id
 
 
 class GeminiClient(ProviderClient):
-    """Gemini client with streaming and
+    """Gemini client with streaming and function calling support."""
+
+    def __init__(self, client_factory: Optional[Any] = None) -> None:
+        self._client_factory = client_factory
+
+    async def _client(self, model_profile: ModelProfile) -> Any:
+        if self._client_factory is not None:
+            client = self._client_factory
+            if inspect.iscoroutinefunction(client):
+                return await client()
+            if inspect.isawaitable(client):
+                return await client  # type: ignore[return-value]
+            if callable(client):
+                result = client()
+                return await result if inspect.isawaitable(result) else result
+            return client
+
+        try:
+            from google import genai  # type: ignore
+        except (ImportError, ModuleNotFoundError) as exc:  # pragma: no cover - import guard
+            raise RuntimeError(GEMINI_SDK_IMPORT_ERROR) from exc
+
+        client_kwargs: Dict[str, Any] = {}
+        api_key = (
+            model_profile.api_key or os.getenv("GEMINI_API_KEY") or os.getenv("GOOGLE_API_KEY")
+        )
+        if api_key:
+            client_kwargs["api_key"] = api_key
+        if model_profile.api_base:
+            from google.genai import types as genai_types  # type: ignore
+
+            client_kwargs["http_options"] = genai_types.HttpOptions(base_url=model_profile.api_base)
+        return genai.Client(**client_kwargs)
 
     async def call(
         self,
@@ -62,96 +381,218 @@ class GeminiClient(ProviderClient):
         progress_callback: Optional[ProgressCallback],
         request_timeout: Optional[float],
         max_retries: int,
+        max_thinking_tokens: int,
     ) -> ProviderResponse:
+        start_time = time.time()
+
         try:
-
-        except
-
-
-
+            client = await self._client(model_profile)
+        except asyncio.CancelledError:
+            raise  # Don't suppress task cancellation
+        except Exception as exc:
+            duration_ms = (time.time() - start_time) * 1000
+            error_code, error_message = _classify_gemini_error(exc)
+            logger.error(
+                "[gemini_client] Initialization failed",
+                extra={
+                    "model": model_profile.model,
+                    "error_code": error_code,
+                    "error_message": error_message,
+                    "duration_ms": round(duration_ms, 2),
+                },
             )
-
-
-
-
-                cost_usd=0.0,
-                duration_ms=0.0,
+            return ProviderResponse.create_error(
+                error_code=error_code,
+                error_message=error_message,
+                duration_ms=duration_ms,
             )
 
+        declarations: List[Dict[str, Any]] = []
         if tools and tool_mode != "text":
-
-
-
+            declarations = await _async_build_tool_declarations(tools)
+
+        contents, _ = _convert_messages_to_genai_contents(normalized_messages)
+
+        config: Dict[str, Any] = {"system_instruction": system_prompt}
+        if model_profile.max_tokens:
+            config["max_output_tokens"] = model_profile.max_tokens
+        thinking_config = _build_thinking_config(max_thinking_tokens, model_profile.model)
+        if thinking_config:
+            try:
+                from google.genai import types as genai_types  # type: ignore
+
+                config["thinking_config"] = genai_types.ThinkingConfig(**thinking_config)
+            except (ImportError, ModuleNotFoundError, TypeError, ValueError):  # pragma: no cover - fallback when SDK not installed
+                config["thinking_config"] = thinking_config
+        if declarations:
+            config["tools"] = [{"function_declarations": declarations}]
+
+        generate_kwargs: Dict[str, Any] = {
+            "model": model_profile.model,
+            "contents": contents,
+            "config": config,
+        }
+        usage_tokens: Dict[str, int] = {}
+        collected_text: List[str] = []
+        function_calls: List[Dict[str, Any]] = []
+        reasoning_parts: List[str] = []
+        response_metadata: Dict[str, Any] = {}
+
+        async def _call_generate(streaming: bool) -> Any:
+            models_api = getattr(client, "models", None) or getattr(
+                getattr(client, "aio", None), "models", None
             )
-
-
-
-
-
+            if models_api is None:
+                raise RuntimeError(GEMINI_MODELS_ENDPOINT_ERROR)
+
+            generate_fn = getattr(models_api, "generate_content", None)
+            stream_fn = getattr(models_api, "generate_content_stream", None) or getattr(
+                models_api, "stream_generate_content", None
             )
 
-
-
-
-
+            if streaming:
+                if stream_fn:
+                    result = stream_fn(**generate_kwargs)
+                    if inspect.isawaitable(result):
+                        return await result
+                    return result
+
+                if generate_fn is None:
+                    raise RuntimeError(GEMINI_GENERATE_CONTENT_ERROR)
 
-
-
-
-
-
+                if _supports_stream_arg(generate_fn):
+                    gen_kwargs: Dict[str, Any] = dict(generate_kwargs)
+                    gen_kwargs["stream"] = True
+                    result = generate_fn(**gen_kwargs)
+                    if inspect.isawaitable(result):
+                        return await result
+                    return result
+
+                # Fallback: non-streaming generate; wrap to keep downstream iterator usage
+                result = generate_fn(**generate_kwargs)
+                if inspect.isawaitable(result):
+                    result = await result
+
+                async def _single_chunk_stream() -> AsyncIterator[Any]:
+                    yield result
+
+                return _single_chunk_stream()
+
+            if generate_fn is None:
+                raise RuntimeError(GEMINI_GENERATE_CONTENT_ERROR)
+
+        try:
+            if stream:
+                stream_resp = await _call_generate(streaming=True)
+
+                # Normalize streams into an async iterator to avoid StopIteration surfacing through
+                # asyncio executors and to handle sync iterables.
+                def _to_async_iter(obj: Any) -> AsyncIterator[Any]:
+                    """Convert various iterable types to async generator."""
+                    if inspect.isasyncgen(obj) or hasattr(obj, "__aiter__"):
+
+                        async def _wrap_async() -> AsyncIterator[Any]:
+                            async for item in obj:
+                                yield item
+
+                        return _wrap_async()
+                    if hasattr(obj, "__iter__"):
+
+                        async def _wrap_sync() -> AsyncIterator[Any]:
+                            for item in obj:
+                                yield item
+
+                        return _wrap_sync()
+
+                    async def _single() -> AsyncIterator[Any]:
+                        yield obj
+
+                    return _single()
+
+                stream_iter = _to_async_iter(stream_resp)
+
+                async for chunk in iter_with_timeout(stream_iter, request_timeout):
+                    candidates = getattr(chunk, "candidates", None) or []
+                    for candidate in candidates:
+                        parts = _collect_parts(candidate)
+                        text_chunk = _collect_text_from_parts(parts)
+                        if progress_callback:
+                            if text_chunk:
+                                try:
+                                    await progress_callback(text_chunk)
+                                except (RuntimeError, ValueError, TypeError, OSError) as cb_exc:
+                                    logger.warning(
+                                        "[gemini_client] Stream callback failed: %s: %s",
+                                        type(cb_exc).__name__, cb_exc,
+                                    )
+                        if text_chunk:
+                            collected_text.append(text_chunk)
+                        reasoning_parts.extend(_collect_thoughts_from_parts(parts))
+                        function_calls.extend(_extract_function_calls(parts))
+                    usage_tokens = _extract_usage_metadata(chunk) or usage_tokens
+            else:
+                # Use retry logic for non-streaming calls
+                response = await call_with_timeout_and_retries(
+                    lambda: _call_generate(streaming=False),
+                    request_timeout,
+                    max_retries,
+                )
+                candidates = getattr(response, "candidates", None) or []
+                if candidates:
+                    parts = _collect_parts(candidates[0])
+                    collected_text.append(_collect_text_from_parts(parts))
+                    reasoning_parts.extend(_collect_thoughts_from_parts(parts))
+                    function_calls.extend(_extract_function_calls(parts))
+                else:
+                    # Fallback: try to read text directly
+                    collected_text.append(getattr(response, "text", "") or "")
+                usage_tokens = _extract_usage_metadata(response)
+        except asyncio.CancelledError:
+            raise  # Don't suppress task cancellation
+        except Exception as exc:
+            duration_ms = (time.time() - start_time) * 1000
+            error_code, error_message = _classify_gemini_error(exc)
+            logger.error(
+                "[gemini_client] API call failed",
+                extra={
+                    "model": model_profile.model,
+                    "error_code": error_code,
+                    "error_message": error_message,
+                    "duration_ms": round(duration_ms, 2),
+                },
+            )
+            return ProviderResponse.create_error(
+                error_code=error_code,
+                error_message=error_message,
+                duration_ms=duration_ms,
             )
-            content = msg.get("content") if isinstance(msg, dict) else getattr(msg, "content", "")
-            if isinstance(content, list):
-                for item in content:
-                    text_val = (
-                        getattr(item, "text", None)
-                        or item.get("text", "")  # type: ignore[union-attr]
-                        if isinstance(item, dict)
-                        else ""
-                    )
-                    if text_val:
-                        prompt_parts.append(f"{role}: {text_val}")
-            elif isinstance(content, str):
-                prompt_parts.append(f"{role}: {content}")
-        full_prompt = "\n".join(part for part in prompt_parts if part)
-
-        model = genai.GenerativeModel(model_profile.model)
-        collected_text: List[str] = []
-        start_time = time.time()
 
-
-
-
-
-
-
-                collected_text.append(text_delta)
-                if progress_callback:
-                    try:
-                        await progress_callback(text_delta)
-                    except Exception:
-                        logger.exception("[gemini_client] Stream callback failed")
-                usage_tokens = _extract_usage_metadata(chunk) or usage_tokens
-            return {"usage": usage_tokens}
-
-        async def _non_stream_request() -> Any:
-            return model.generate_content(full_prompt)
-
-        response: Any = await call_with_timeout_and_retries(
-            _stream_request if stream and progress_callback else _non_stream_request,
-            request_timeout,
-            max_retries,
-        )
+        content_blocks: List[Dict[str, Any]] = []
+        combined_text = "".join(collected_text).strip()
+        if combined_text:
+            content_blocks.append({"type": "text", "text": combined_text})
+        if reasoning_parts:
+            response_metadata["reasoning_content"] = "".join(reasoning_parts)
 
-
-
-
+        for call in function_calls:
+            if not call.get("name"):
+                continue
+            content_blocks.append(
+                {
+                    "type": "tool_use",
+                    "tool_use_id": call.get("id") or str(uuid4()),
+                    "name": call["name"],
+                    "input": call.get("args") or {},
+                }
+            )
 
-
-
-
-
+        duration_ms = (time.time() - start_time) * 1000
+        cost_usd = estimate_cost_usd(model_profile, usage_tokens) if usage_tokens else 0.0
+        record_usage(
+            model_profile.model,
+            duration_ms=duration_ms,
+            cost_usd=cost_usd,
+            **(usage_tokens or {}),
         )
 
         logger.info(
@@ -161,12 +602,14 @@ class GeminiClient(ProviderClient):
                 "duration_ms": round(duration_ms, 2),
                 "tool_mode": tool_mode,
                 "stream": stream,
+                "function_call_count": len(function_calls),
             },
         )
 
         return ProviderResponse(
-            content_blocks=content_blocks,
+            content_blocks=content_blocks or [{"type": "text", "text": ""}],
            usage_tokens=usage_tokens,
            cost_usd=cost_usd,
            duration_ms=duration_ms,
+            metadata=response_metadata,
         )