ripperdoc 0.2.3__py3-none-any.whl → 0.2.5__py3-none-any.whl

Files changed (76)
  1. ripperdoc/__init__.py +1 -1
  2. ripperdoc/__main__.py +0 -5
  3. ripperdoc/cli/cli.py +37 -16
  4. ripperdoc/cli/commands/__init__.py +2 -0
  5. ripperdoc/cli/commands/agents_cmd.py +12 -9
  6. ripperdoc/cli/commands/compact_cmd.py +7 -3
  7. ripperdoc/cli/commands/context_cmd.py +35 -15
  8. ripperdoc/cli/commands/doctor_cmd.py +27 -14
  9. ripperdoc/cli/commands/exit_cmd.py +1 -1
  10. ripperdoc/cli/commands/mcp_cmd.py +13 -8
  11. ripperdoc/cli/commands/memory_cmd.py +5 -5
  12. ripperdoc/cli/commands/models_cmd.py +47 -16
  13. ripperdoc/cli/commands/permissions_cmd.py +302 -0
  14. ripperdoc/cli/commands/resume_cmd.py +1 -2
  15. ripperdoc/cli/commands/tasks_cmd.py +24 -13
  16. ripperdoc/cli/ui/rich_ui.py +523 -396
  17. ripperdoc/cli/ui/tool_renderers.py +298 -0
  18. ripperdoc/core/agents.py +172 -4
  19. ripperdoc/core/config.py +130 -6
  20. ripperdoc/core/default_tools.py +13 -2
  21. ripperdoc/core/permissions.py +20 -14
  22. ripperdoc/core/providers/__init__.py +31 -15
  23. ripperdoc/core/providers/anthropic.py +122 -8
  24. ripperdoc/core/providers/base.py +93 -15
  25. ripperdoc/core/providers/gemini.py +539 -96
  26. ripperdoc/core/providers/openai.py +371 -26
  27. ripperdoc/core/query.py +301 -62
  28. ripperdoc/core/query_utils.py +51 -7
  29. ripperdoc/core/skills.py +295 -0
  30. ripperdoc/core/system_prompt.py +79 -67
  31. ripperdoc/core/tool.py +15 -6
  32. ripperdoc/sdk/client.py +14 -1
  33. ripperdoc/tools/ask_user_question_tool.py +431 -0
  34. ripperdoc/tools/background_shell.py +82 -26
  35. ripperdoc/tools/bash_tool.py +356 -209
  36. ripperdoc/tools/dynamic_mcp_tool.py +428 -0
  37. ripperdoc/tools/enter_plan_mode_tool.py +226 -0
  38. ripperdoc/tools/exit_plan_mode_tool.py +153 -0
  39. ripperdoc/tools/file_edit_tool.py +53 -10
  40. ripperdoc/tools/file_read_tool.py +17 -7
  41. ripperdoc/tools/file_write_tool.py +49 -13
  42. ripperdoc/tools/glob_tool.py +10 -9
  43. ripperdoc/tools/grep_tool.py +182 -51
  44. ripperdoc/tools/ls_tool.py +6 -6
  45. ripperdoc/tools/mcp_tools.py +172 -413
  46. ripperdoc/tools/multi_edit_tool.py +49 -9
  47. ripperdoc/tools/notebook_edit_tool.py +57 -13
  48. ripperdoc/tools/skill_tool.py +205 -0
  49. ripperdoc/tools/task_tool.py +91 -9
  50. ripperdoc/tools/todo_tool.py +12 -12
  51. ripperdoc/tools/tool_search_tool.py +5 -6
  52. ripperdoc/utils/coerce.py +34 -0
  53. ripperdoc/utils/context_length_errors.py +252 -0
  54. ripperdoc/utils/file_watch.py +5 -4
  55. ripperdoc/utils/json_utils.py +4 -4
  56. ripperdoc/utils/log.py +3 -3
  57. ripperdoc/utils/mcp.py +82 -22
  58. ripperdoc/utils/memory.py +9 -6
  59. ripperdoc/utils/message_compaction.py +19 -16
  60. ripperdoc/utils/messages.py +73 -8
  61. ripperdoc/utils/path_ignore.py +677 -0
  62. ripperdoc/utils/permissions/__init__.py +7 -1
  63. ripperdoc/utils/permissions/path_validation_utils.py +5 -3
  64. ripperdoc/utils/permissions/shell_command_validation.py +496 -18
  65. ripperdoc/utils/prompt.py +1 -1
  66. ripperdoc/utils/safe_get_cwd.py +5 -2
  67. ripperdoc/utils/session_history.py +38 -19
  68. ripperdoc/utils/todo.py +6 -2
  69. ripperdoc/utils/token_estimation.py +34 -0
  70. {ripperdoc-0.2.3.dist-info → ripperdoc-0.2.5.dist-info}/METADATA +14 -1
  71. ripperdoc-0.2.5.dist-info/RECORD +107 -0
  72. ripperdoc-0.2.3.dist-info/RECORD +0 -95
  73. {ripperdoc-0.2.3.dist-info → ripperdoc-0.2.5.dist-info}/WHEEL +0 -0
  74. {ripperdoc-0.2.3.dist-info → ripperdoc-0.2.5.dist-info}/entry_points.txt +0 -0
  75. {ripperdoc-0.2.3.dist-info → ripperdoc-0.2.5.dist-info}/licenses/LICENSE +0 -0
  76. {ripperdoc-0.2.3.dist-info → ripperdoc-0.2.5.dist-info}/top_level.txt +0 -0
ripperdoc/core/providers/gemini.py
@@ -1,10 +1,14 @@
- """Gemini provider client."""
+ """Gemini provider client with function/tool calling support."""
  
  from __future__ import annotations
  
+ import asyncio
+ import copy
+ import inspect
  import os
  import time
- from typing import Any, Dict, List, Optional
+ from typing import Any, AsyncIterator, Dict, List, Optional, Tuple, cast
+ from uuid import uuid4
  
  from ripperdoc.core.config import ModelProfile
  from ripperdoc.core.providers.base import (
@@ -12,43 +16,358 @@ from ripperdoc.core.providers.base import (
      ProviderClient,
      ProviderResponse,
      call_with_timeout_and_retries,
+     iter_with_timeout,
  )
+ from ripperdoc.core.query_utils import _normalize_tool_args, build_tool_description
  from ripperdoc.core.tool import Tool
  from ripperdoc.utils.log import get_logger
+ from ripperdoc.utils.session_usage import record_usage
+ from ripperdoc.core.query_utils import estimate_cost_usd
  
  logger = get_logger()
  
+ # Constants
+ GEMINI_SDK_IMPORT_ERROR = (
+     "Gemini client requires the 'google-genai' package. Install it with: pip install google-genai"
+ )
+ GEMINI_MODELS_ENDPOINT_ERROR = "Gemini client is missing 'models' endpoint"
+ GEMINI_GENERATE_CONTENT_ERROR = "Gemini client is missing generate_content() method"
+
+
+ def _classify_gemini_error(exc: Exception) -> tuple[str, str]:
+     """Classify a Gemini exception into error code and user-friendly message."""
+     exc_type = type(exc).__name__
+     exc_msg = str(exc)
+
+     # Try to import Google's exception types for more specific handling
+     try:
+         from google.api_core import exceptions as google_exceptions  # type: ignore
+
+         if isinstance(exc, google_exceptions.Unauthenticated):
+             return "authentication_error", f"Authentication failed: {exc_msg}"
+         if isinstance(exc, google_exceptions.PermissionDenied):
+             return "permission_denied", f"Permission denied: {exc_msg}"
+         if isinstance(exc, google_exceptions.NotFound):
+             return "model_not_found", f"Model not found: {exc_msg}"
+         if isinstance(exc, google_exceptions.InvalidArgument):
+             if "context" in exc_msg.lower() or "token" in exc_msg.lower():
+                 return "context_length_exceeded", f"Context length exceeded: {exc_msg}"
+             return "bad_request", f"Invalid request: {exc_msg}"
+         if isinstance(exc, google_exceptions.ResourceExhausted):
+             return "rate_limit", f"Rate limit exceeded: {exc_msg}"
+         if isinstance(exc, google_exceptions.ServiceUnavailable):
+             return "service_unavailable", f"Service unavailable: {exc_msg}"
+         if isinstance(exc, google_exceptions.GoogleAPICallError):
+             return "api_error", f"API error: {exc_msg}"
+     except ImportError:
+         pass
+
+     # Fallback for generic exceptions
+     if isinstance(exc, asyncio.TimeoutError):
+         return "timeout", f"Request timed out: {exc_msg}"
+     if isinstance(exc, ConnectionError):
+         return "connection_error", f"Connection error: {exc_msg}"
+     if "quota" in exc_msg.lower() or "limit" in exc_msg.lower():
+         return "rate_limit", f"Rate limit exceeded: {exc_msg}"
+     if "auth" in exc_msg.lower() or "key" in exc_msg.lower():
+         return "authentication_error", f"Authentication error: {exc_msg}"
+     if "not found" in exc_msg.lower():
+         return "model_not_found", f"Model not found: {exc_msg}"
+
+     return "unknown_error", f"Unexpected error ({exc_type}): {exc_msg}"
+
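The fallback branch matters when google.api_core is not importable: classification then relies purely on substring checks against the exception message. A minimal sketch (illustrative only, not part of the diff; assumes the private helper is imported from ripperdoc/core/providers/gemini.py):

    import asyncio
    from ripperdoc.core.providers.gemini import _classify_gemini_error

    code, _ = _classify_gemini_error(Exception("API key not valid"))
    assert code == "authentication_error"   # "key" appears in the message
    code, _ = _classify_gemini_error(Exception("quota exceeded for project"))
    assert code == "rate_limit"             # "quota" appears in the message
    code, _ = _classify_gemini_error(asyncio.TimeoutError())
    assert code == "timeout"                # isinstance check, no message needed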
 
  def _extract_usage_metadata(payload: Any) -> Dict[str, int]:
      """Best-effort token extraction from Gemini responses."""
      usage = getattr(payload, "usage_metadata", None) or getattr(payload, "usageMetadata", None)
      if not usage:
          usage = getattr(payload, "usage", None)
-     get = lambda key: int(getattr(usage, key, 0) or 0) if usage else 0  # noqa: E731
+     if not usage and getattr(payload, "candidates", None):
+         usage = getattr(payload.candidates[0], "usage_metadata", None)
+
+     def safe_get_int(key: str) -> int:
+         """Safely extract integer value from usage metadata."""
+         if not usage:
+             return 0
+         value = getattr(usage, key, 0)
+         return int(value) if value else 0
+
+     thought_tokens = safe_get_int("thoughts_token_count")
+     candidate_tokens = safe_get_int("candidates_token_count")
+
      return {
-         "input_tokens": get("prompt_token_count") + get("cached_content_token_count"),
-         "output_tokens": get("candidates_token_count"),
-         "cache_read_input_tokens": get("cached_content_token_count"),
+         "input_tokens": safe_get_int("prompt_token_count")
+         + safe_get_int("cached_content_token_count"),
+         "output_tokens": candidate_tokens + thought_tokens,
+         "cache_read_input_tokens": safe_get_int("cached_content_token_count"),
          "cache_creation_input_tokens": 0,
      }
  
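Note the behavioral change in token accounting: thought tokens are now folded into output_tokens, and cached content counts toward input_tokens. A sketch with stand-in objects (illustrative only; field names follow the code above):

    class _FakeUsage:
        prompt_token_count = 100
        cached_content_token_count = 20
        candidates_token_count = 50
        thoughts_token_count = 30

    class _FakeResponse:
        usage_metadata = _FakeUsage()

    print(_extract_usage_metadata(_FakeResponse()))
    # {'input_tokens': 120, 'output_tokens': 80,
    #  'cache_read_input_tokens': 20, 'cache_creation_input_tokens': 0}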
  
- def _collect_text_parts(candidate: Any) -> str:
-     parts = getattr(candidate, "content", None)
-     if not parts:
-         return ""
-     if isinstance(parts, list):
-         texts = []
-         for part in parts:
-             text_val = getattr(part, "text", None) or getattr(part, "content", None)
-             if isinstance(text_val, str):
-                 texts.append(text_val)
-         return "".join(texts)
-     return str(parts)
+ def _collect_parts(candidate: Any) -> List[Any]:
+     """Return a list of parts from a candidate regardless of SDK shape."""
+     content = getattr(candidate, "content", None)
+     if content is None:
+         return []
+     if hasattr(content, "parts"):
+         return list(getattr(content, "parts", []) or [])
+     if isinstance(content, list):
+         return content
+     return []
+
+
+ def _collect_text_from_parts(parts: List[Any]) -> str:
+     texts: List[str] = []
+     for part in parts:
+         text_val = (
+             getattr(part, "text", None)
+             or getattr(part, "content", None)
+             or getattr(part, "raw_text", None)
+         )
+         if isinstance(text_val, str):
+             texts.append(text_val)
+     return "".join(texts)
+
+
+ def _extract_function_calls(parts: List[Any]) -> List[Dict[str, Any]]:
+     calls: List[Dict[str, Any]] = []
+     for part in parts:
+         fn_call = getattr(part, "function_call", None) or getattr(part, "functionCall", None)
+         if not fn_call:
+             continue
+         name = getattr(fn_call, "name", None) or getattr(fn_call, "function_name", None)
+         args = getattr(fn_call, "args", None) or getattr(fn_call, "arguments", None) or {}
+         call_id = getattr(fn_call, "id", None) or getattr(fn_call, "call_id", None)
+         calls.append({"name": name, "args": _normalize_tool_args(args), "id": call_id})
+     return calls
+
+
+ def _flatten_schema(schema: Dict[str, Any]) -> Dict[str, Any]:
+     """Inline $ref entries and drop $defs/$ref for Gemini Schema compatibility.
+
+     Gemini API doesn't support JSON Schema references, so this function
+     resolves all $ref pointers by inlining the referenced definitions.
+     """
+     definitions = copy.deepcopy(schema.get("$defs") or schema.get("definitions") or {})
+
+     def _resolve(node: Any) -> Any:
+         """Recursively resolve $ref pointers and remove unsupported fields."""
+         if isinstance(node, dict):
+             # Handle $ref resolution
+             ref = node.get("$ref")
+             if isinstance(ref, str) and ref.startswith("#/"):
+                 ref_key = ref.split("/")[-1]
+                 if ref_key in definitions:
+                     return _resolve(copy.deepcopy(definitions[ref_key]))
+
+             # Process remaining fields, excluding schema metadata
+             resolved: Dict[str, Any] = {}
+             for key, value in node.items():
+                 if key in {"$ref", "$defs", "definitions"}:
+                     continue
+                 resolved[key] = _resolve(value)
+             return resolved
+
+         if isinstance(node, list):
+             return [_resolve(item) for item in node]
+
+         return node
+
+     return cast(Dict[str, Any], _resolve(copy.deepcopy(schema)))
+
+
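Because the Gemini Schema type rejects $ref/$defs, tool input schemas generated from Pydantic models need this flattening pass. A sketch of the transformation (illustrative only):

    schema = {
        "$defs": {"Point": {"type": "object",
                            "properties": {"x": {"type": "number"}}}},
        "type": "object",
        "properties": {"origin": {"$ref": "#/$defs/Point"}},
    }
    print(_flatten_schema(schema))
    # {'type': 'object',
    #  'properties': {'origin': {'type': 'object',
    #                            'properties': {'x': {'type': 'number'}}}}}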
+ def _supports_stream_arg(fn: Any) -> bool:
+     """Return True if the callable appears to accept a 'stream' kwarg."""
+     try:
+         sig = inspect.signature(fn)
+     except (TypeError, ValueError):
+         # If we cannot inspect, avoid passing stream to prevent TypeErrors.
+         return False
+
+     for param in sig.parameters.values():
+         if param.kind == param.VAR_KEYWORD:
+             return True
+         if param.name == "stream":
+             return True
+     return False
+
+
+ def _build_thinking_config(max_thinking_tokens: int, model_name: str) -> Dict[str, Any]:
+     """Map max_thinking_tokens to Gemini thinking_config settings."""
+     if max_thinking_tokens <= 0:
+         return {}
+     name = (model_name or "").lower()
+     config: Dict[str, Any] = {"include_thoughts": True}
+     if "gemini-3" in name:
+         config["thinking_level"] = "low" if max_thinking_tokens <= 2048 else "high"
+     else:
+         config["thinking_budget"] = max_thinking_tokens
+     return config
+
+
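The thinking configuration is model-dependent: names containing "gemini-3" get a coarse thinking_level, everything else gets a raw thinking_budget. For example (model names here are placeholders, illustrative only):

    print(_build_thinking_config(1024, "gemini-2.5-flash"))
    # {'include_thoughts': True, 'thinking_budget': 1024}
    print(_build_thinking_config(4096, "gemini-3-example"))
    # {'include_thoughts': True, 'thinking_level': 'high'}
    print(_build_thinking_config(0, "gemini-2.5-flash"))
    # {}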
+ def _collect_thoughts_from_parts(parts: List[Any]) -> List[str]:
+     """Extract thought summaries from parts flagged as thoughts."""
+     snippets: List[str] = []
+     for part in parts:
+         is_thought = getattr(part, "thought", None)
+         if is_thought is None and isinstance(part, dict):
+             is_thought = part.get("thought")
+         if not is_thought:
+             continue
+         text_val = (
+             getattr(part, "text", None)
+             or getattr(part, "content", None)
+             or getattr(part, "raw_text", None)
+         )
+         if isinstance(text_val, str):
+             snippets.append(text_val)
+     return snippets
+
+
+ async def _async_build_tool_declarations(tools: List[Tool[Any, Any]]) -> List[Dict[str, Any]]:
+     declarations: List[Dict[str, Any]] = []
+     try:
+         from google.genai import types as genai_types  # type: ignore
+     except (ImportError, ModuleNotFoundError):  # pragma: no cover - fallback when SDK not installed
+         genai_types = None  # type: ignore[assignment]
+
+     for tool in tools:
+         description = await build_tool_description(tool, include_examples=True, max_examples=2)
+         parameters_schema = _flatten_schema(tool.input_schema.model_json_schema())
+         if genai_types:
+             func_decl = genai_types.FunctionDeclaration(
+                 name=tool.name,
+                 description=description,
+                 parameters_json_schema=parameters_schema,
+             )
+             declarations.append(
+                 func_decl.model_dump(mode="json", exclude_none=True)
+             )
+         else:
+             declarations.append(
+                 {
+                     "name": tool.name,
+                     "description": description,
+                     "parameters_json_schema": parameters_schema,
+                 }
+             )
+     return declarations
+
+
+ def _convert_messages_to_genai_contents(
+     normalized_messages: List[Dict[str, Any]],
+ ) -> Tuple[List[Any], Dict[str, str]]:
+     """Map normalized OpenAI-style messages to Gemini content payloads.
+
+     Returns:
+         contents: List of Content-like dicts/objects
+         tool_name_by_id: Map of tool_call_id -> function name (for pairing responses)
+     """
+     tool_name_by_id: Dict[str, str] = {}
+     contents: List[Any] = []
+
+     # Lazy import to avoid hard dependency in tests.
+     try:
+         from google.genai import types as genai_types  # type: ignore
+     except (ImportError, ModuleNotFoundError):  # pragma: no cover - fallback when SDK not installed
+         genai_types = None  # type: ignore[assignment]
+
+     def _mk_part_from_text(text: str) -> Any:
+         if genai_types:
+             return genai_types.Part(text=text)
+         return {"text": text}
+
+     def _mk_part_from_function_call(name: str, args: Dict[str, Any], call_id: Optional[str]) -> Any:
+         # Store mapping using actual call_id if available, otherwise generate one
+         actual_id = call_id or str(uuid4())
+         tool_name_by_id[actual_id] = name
+         if genai_types:
+             return genai_types.Part(function_call=genai_types.FunctionCall(name=name, args=args))
+         return {"function_call": {"name": name, "args": args, "id": actual_id}}
+
+     def _mk_part_from_function_response(
+         name: str, response: Dict[str, Any], call_id: Optional[str]
+     ) -> Any:
+         if call_id:
+             response = {**response, "call_id": call_id}
+         if genai_types:
+             return genai_types.Part.from_function_response(name=name, response=response)
+         payload = {"function_response": {"name": name, "response": response}}
+         if call_id:
+             payload["function_response"]["id"] = call_id
+         return payload
+
+     def _mk_content(role: str, parts: List[Any]) -> Any:
+         if genai_types:
+             return genai_types.Content(role=role, parts=parts)
+         return {"role": role, "parts": parts}
+
+     for message in normalized_messages:
+         role = message.get("role") or ""
+         msg_parts: List[Any] = []
+
+         # Assistant tool calls
+         for tool_call in message.get("tool_calls") or []:
+             func = tool_call.get("function") or {}
+             name = func.get("name") or ""
+             args = _normalize_tool_args(func.get("arguments") or {})
+             call_id = tool_call.get("id")
+             msg_parts.append(_mk_part_from_function_call(name, args, call_id))
+
+         content_value = message.get("content")
+         if isinstance(content_value, str) and content_value:
+             msg_parts.append(_mk_part_from_text(content_value))
+
+         if role == "tool":
+             call_id = message.get("tool_call_id") or ""
+             name = tool_name_by_id.get(call_id, call_id or "tool_response")
+             response = {"result": content_value}
+             msg_parts.append(_mk_part_from_function_response(name, response, call_id))
+             role = "user"  # Tool responses are treated as user-provided context
+
+         if not msg_parts:
+             continue
+
+         mapped_role = "user" if role == "user" else "model"
+         contents.append(_mk_content(mapped_role, msg_parts))
+
+     return contents, tool_name_by_id
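In practice this turns an OpenAI-style tool round trip into two Gemini contents, with the tool result replayed as a user turn. A sketch (illustrative only):

    msgs = [
        {"role": "assistant",
         "tool_calls": [{"id": "call_1",
                         "function": {"name": "ls", "arguments": '{"path": "."}'}}]},
        {"role": "tool", "tool_call_id": "call_1", "content": "README.md"},
    ]
    contents, names = _convert_messages_to_genai_contents(msgs)
    # contents[0]: role "model" carrying a function_call part named "ls"
    # contents[1]: role "user" carrying the text plus a function_response
    #              that pairs call_1 back to "ls"
    # names == {"call_1": "ls"}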
  
  
  class GeminiClient(ProviderClient):
-     """Gemini client with streaming and basic text support."""
+     """Gemini client with streaming and function calling support."""
+
+     def __init__(self, client_factory: Optional[Any] = None) -> None:
+         self._client_factory = client_factory
+
+     async def _client(self, model_profile: ModelProfile) -> Any:
+         if self._client_factory is not None:
+             client = self._client_factory
+             if inspect.iscoroutinefunction(client):
+                 return await client()
+             if inspect.isawaitable(client):
+                 return await client  # type: ignore[return-value]
+             if callable(client):
+                 result = client()
+                 return await result if inspect.isawaitable(result) else result
+             return client
+
+         try:
+             from google import genai  # type: ignore
+         except (ImportError, ModuleNotFoundError) as exc:  # pragma: no cover - import guard
+             raise RuntimeError(GEMINI_SDK_IMPORT_ERROR) from exc
+
+         client_kwargs: Dict[str, Any] = {}
+         api_key = (
+             model_profile.api_key or os.getenv("GEMINI_API_KEY") or os.getenv("GOOGLE_API_KEY")
+         )
+         if api_key:
+             client_kwargs["api_key"] = api_key
+         if model_profile.api_base:
+             from google.genai import types as genai_types  # type: ignore
+
+             client_kwargs["http_options"] = genai_types.HttpOptions(base_url=model_profile.api_base)
+         return genai.Client(**client_kwargs)
  
      async def call(
          self,
@@ -62,96 +381,218 @@ class GeminiClient(ProviderClient):
          progress_callback: Optional[ProgressCallback],
          request_timeout: Optional[float],
          max_retries: int,
+         max_thinking_tokens: int,
      ) -> ProviderResponse:
+         start_time = time.time()
+
          try:
-             import google.generativeai as genai  # type: ignore
-         except Exception as exc:  # pragma: no cover - import guard
-             msg = (
-                 "Gemini client requires the 'google-generativeai' package. "
-                 "Install it to enable Gemini support."
+             client = await self._client(model_profile)
+         except asyncio.CancelledError:
+             raise  # Don't suppress task cancellation
+         except Exception as exc:
+             duration_ms = (time.time() - start_time) * 1000
+             error_code, error_message = _classify_gemini_error(exc)
+             logger.error(
+                 "[gemini_client] Initialization failed",
+                 extra={
+                     "model": model_profile.model,
+                     "error_code": error_code,
+                     "error_message": error_message,
+                     "duration_ms": round(duration_ms, 2),
+                 },
              )
-             logger.warning(msg, extra={"error": str(exc)})
-             return ProviderResponse(
-                 content_blocks=[{"type": "text", "text": msg}],
-                 usage_tokens={},
-                 cost_usd=0.0,
-                 duration_ms=0.0,
+             return ProviderResponse.create_error(
+                 error_code=error_code,
+                 error_message=error_message,
+                 duration_ms=duration_ms,
              )
  
+         declarations: List[Dict[str, Any]] = []
          if tools and tool_mode != "text":
-             msg = (
-                 "Gemini client currently supports text-only responses; "
-                 "tool/function calling is not yet implemented."
+             declarations = await _async_build_tool_declarations(tools)
+
+         contents, _ = _convert_messages_to_genai_contents(normalized_messages)
+
+         config: Dict[str, Any] = {"system_instruction": system_prompt}
+         if model_profile.max_tokens:
+             config["max_output_tokens"] = model_profile.max_tokens
+         thinking_config = _build_thinking_config(max_thinking_tokens, model_profile.model)
+         if thinking_config:
+             try:
+                 from google.genai import types as genai_types  # type: ignore
+
+                 config["thinking_config"] = genai_types.ThinkingConfig(**thinking_config)
+             except (ImportError, ModuleNotFoundError, TypeError, ValueError):  # pragma: no cover - fallback when SDK not installed
+                 config["thinking_config"] = thinking_config
+         if declarations:
+             config["tools"] = [{"function_declarations": declarations}]
+
+         generate_kwargs: Dict[str, Any] = {
+             "model": model_profile.model,
+             "contents": contents,
+             "config": config,
+         }
+         usage_tokens: Dict[str, int] = {}
+         collected_text: List[str] = []
+         function_calls: List[Dict[str, Any]] = []
+         reasoning_parts: List[str] = []
+         response_metadata: Dict[str, Any] = {}
+
+         async def _call_generate(streaming: bool) -> Any:
+             models_api = getattr(client, "models", None) or getattr(
+                 getattr(client, "aio", None), "models", None
              )
-             return ProviderResponse(
-                 content_blocks=[{"type": "text", "text": msg}],
-                 usage_tokens={},
-                 cost_usd=0.0,
-                 duration_ms=0.0,
+             if models_api is None:
+                 raise RuntimeError(GEMINI_MODELS_ENDPOINT_ERROR)
+
+             generate_fn = getattr(models_api, "generate_content", None)
+             stream_fn = getattr(models_api, "generate_content_stream", None) or getattr(
+                 models_api, "stream_generate_content", None
              )
  
-         api_key = (
-             model_profile.api_key or os.getenv("GEMINI_API_KEY") or os.getenv("GOOGLE_API_KEY")
-         )
-         genai.configure(api_key=api_key, client_options={"api_endpoint": model_profile.api_base})
+             if streaming:
+                 if stream_fn:
+                     result = stream_fn(**generate_kwargs)
+                     if inspect.isawaitable(result):
+                         return await result
+                     return result
+
+                 if generate_fn is None:
+                     raise RuntimeError(GEMINI_GENERATE_CONTENT_ERROR)
  
-         # Flatten normalized messages into a single text prompt (Gemini supports multi-turn, but keep it simple).
-         prompt_parts: List[str] = [system_prompt]
-         for msg in normalized_messages:  # type: ignore[assignment]
-             role: str = (
-                 str(msg.get("role", "")) if isinstance(msg, dict) else str(getattr(msg, "role", ""))  # type: ignore[assignment]
+                 if _supports_stream_arg(generate_fn):
+                     gen_kwargs: Dict[str, Any] = dict(generate_kwargs)
+                     gen_kwargs["stream"] = True
+                     result = generate_fn(**gen_kwargs)
+                     if inspect.isawaitable(result):
+                         return await result
+                     return result
+
+                 # Fallback: non-streaming generate; wrap to keep downstream iterator usage
+                 result = generate_fn(**generate_kwargs)
+                 if inspect.isawaitable(result):
+                     result = await result
+
+                 async def _single_chunk_stream() -> AsyncIterator[Any]:
+                     yield result
+
+                 return _single_chunk_stream()
+
+             if generate_fn is None:
+                 raise RuntimeError(GEMINI_GENERATE_CONTENT_ERROR)
+
+         try:
+             if stream:
+                 stream_resp = await _call_generate(streaming=True)
+
+                 # Normalize streams into an async iterator to avoid StopIteration surfacing through
+                 # asyncio executors and to handle sync iterables.
+                 def _to_async_iter(obj: Any) -> AsyncIterator[Any]:
+                     """Convert various iterable types to async generator."""
+                     if inspect.isasyncgen(obj) or hasattr(obj, "__aiter__"):
+
+                         async def _wrap_async() -> AsyncIterator[Any]:
+                             async for item in obj:
+                                 yield item
+
+                         return _wrap_async()
+                     if hasattr(obj, "__iter__"):
+
+                         async def _wrap_sync() -> AsyncIterator[Any]:
+                             for item in obj:
+                                 yield item
+
+                         return _wrap_sync()
+
+                     async def _single() -> AsyncIterator[Any]:
+                         yield obj
+
+                     return _single()
+
+                 stream_iter = _to_async_iter(stream_resp)
+
+                 async for chunk in iter_with_timeout(stream_iter, request_timeout):
+                     candidates = getattr(chunk, "candidates", None) or []
+                     for candidate in candidates:
+                         parts = _collect_parts(candidate)
+                         text_chunk = _collect_text_from_parts(parts)
+                         if progress_callback:
+                             if text_chunk:
+                                 try:
+                                     await progress_callback(text_chunk)
+                                 except (RuntimeError, ValueError, TypeError, OSError) as cb_exc:
+                                     logger.warning(
+                                         "[gemini_client] Stream callback failed: %s: %s",
+                                         type(cb_exc).__name__, cb_exc,
+                                     )
+                         if text_chunk:
+                             collected_text.append(text_chunk)
+                         reasoning_parts.extend(_collect_thoughts_from_parts(parts))
+                         function_calls.extend(_extract_function_calls(parts))
+                     usage_tokens = _extract_usage_metadata(chunk) or usage_tokens
+             else:
+                 # Use retry logic for non-streaming calls
+                 response = await call_with_timeout_and_retries(
+                     lambda: _call_generate(streaming=False),
+                     request_timeout,
+                     max_retries,
+                 )
+                 candidates = getattr(response, "candidates", None) or []
+                 if candidates:
+                     parts = _collect_parts(candidates[0])
+                     collected_text.append(_collect_text_from_parts(parts))
+                     reasoning_parts.extend(_collect_thoughts_from_parts(parts))
+                     function_calls.extend(_extract_function_calls(parts))
+                 else:
+                     # Fallback: try to read text directly
+                     collected_text.append(getattr(response, "text", "") or "")
+                 usage_tokens = _extract_usage_metadata(response)
+         except asyncio.CancelledError:
+             raise  # Don't suppress task cancellation
+         except Exception as exc:
+             duration_ms = (time.time() - start_time) * 1000
+             error_code, error_message = _classify_gemini_error(exc)
+             logger.error(
+                 "[gemini_client] API call failed",
+                 extra={
+                     "model": model_profile.model,
+                     "error_code": error_code,
+                     "error_message": error_message,
+                     "duration_ms": round(duration_ms, 2),
+                 },
+             )
+             return ProviderResponse.create_error(
+                 error_code=error_code,
+                 error_message=error_message,
+                 duration_ms=duration_ms,
              )
-             content = msg.get("content") if isinstance(msg, dict) else getattr(msg, "content", "")
-             if isinstance(content, list):
-                 for item in content:
-                     text_val = (
-                         getattr(item, "text", None)
-                         or item.get("text", "")  # type: ignore[union-attr]
-                         if isinstance(item, dict)
-                         else ""
-                     )
-                     if text_val:
-                         prompt_parts.append(f"{role}: {text_val}")
-             elif isinstance(content, str):
-                 prompt_parts.append(f"{role}: {content}")
-         full_prompt = "\n".join(part for part in prompt_parts if part)
-
-         model = genai.GenerativeModel(model_profile.model)
-         collected_text: List[str] = []
-         start_time = time.time()
  
-         async def _stream_request() -> Dict[str, Dict[str, int]]:
-             stream_resp = model.generate_content(full_prompt, stream=True)
-             usage_tokens: Dict[str, int] = {}
-             for chunk in stream_resp:
-                 text_delta = _collect_text_parts(chunk)
-                 if text_delta:
-                     collected_text.append(text_delta)
-                 if progress_callback:
-                     try:
-                         await progress_callback(text_delta)
-                     except Exception:
-                         logger.exception("[gemini_client] Stream callback failed")
-                 usage_tokens = _extract_usage_metadata(chunk) or usage_tokens
-             return {"usage": usage_tokens}
-
-         async def _non_stream_request() -> Any:
-             return model.generate_content(full_prompt)
-
-         response: Any = await call_with_timeout_and_retries(
-             _stream_request if stream and progress_callback else _non_stream_request,
-             request_timeout,
-             max_retries,
-         )
+         content_blocks: List[Dict[str, Any]] = []
+         combined_text = "".join(collected_text).strip()
+         if combined_text:
+             content_blocks.append({"type": "text", "text": combined_text})
+         if reasoning_parts:
+             response_metadata["reasoning_content"] = "".join(reasoning_parts)
  
-         duration_ms = (time.time() - start_time) * 1000
-         usage_tokens = _extract_usage_metadata(response)
-         cost_usd = 0.0  # Pricing unknown; leave as 0
+         for call in function_calls:
+             if not call.get("name"):
+                 continue
+             content_blocks.append(
+                 {
+                     "type": "tool_use",
+                     "tool_use_id": call.get("id") or str(uuid4()),
+                     "name": call["name"],
+                     "input": call.get("args") or {},
+                 }
+             )
  
-         content_blocks = (
-             [{"type": "text", "text": "".join(collected_text)}]
-             if collected_text
-             else [{"type": "text", "text": _collect_text_parts(response)}]
+         duration_ms = (time.time() - start_time) * 1000
+         cost_usd = estimate_cost_usd(model_profile, usage_tokens) if usage_tokens else 0.0
+         record_usage(
+             model_profile.model,
+             duration_ms=duration_ms,
+             cost_usd=cost_usd,
+             **(usage_tokens or {}),
          )
  
          logger.info(
@@ -161,12 +602,14 @@ class GeminiClient(ProviderClient):
                  "duration_ms": round(duration_ms, 2),
                  "tool_mode": tool_mode,
                  "stream": stream,
+                 "function_call_count": len(function_calls),
              },
          )
  
          return ProviderResponse(
-             content_blocks=content_blocks,
+             content_blocks=content_blocks or [{"type": "text", "text": ""}],
              usage_tokens=usage_tokens,
              cost_usd=cost_usd,
              duration_ms=duration_ms,
+             metadata=response_metadata,
          )
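Where the 0.2.3 client could only return a single text block, 0.2.5 returns Anthropic-style tool_use blocks alongside text, using exactly the shapes assembled above. A sketch of a mixed result (illustrative only):

    # content_blocks for a response that both explains and calls a tool:
    [
        {"type": "text", "text": "Let me list the directory."},
        {"type": "tool_use", "tool_use_id": "c0ffee...", "name": "ls",
         "input": {"path": "."}},
    ]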