ripperdoc 0.2.3__py3-none-any.whl → 0.2.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

@@ -1,10 +1,13 @@
-"""Gemini provider client."""
+"""Gemini provider client with function/tool calling support."""
 
 from __future__ import annotations
 
+import copy
+import inspect
 import os
 import time
-from typing import Any, Dict, List, Optional
+from typing import Any, AsyncIterable, AsyncIterator, Dict, List, Optional, Tuple, cast
+from uuid import uuid4
 
 from ripperdoc.core.config import ModelProfile
 from ripperdoc.core.providers.base import (
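
Alongside the typing additions, this release moves from the legacy `google-generativeai` SDK to the newer `google-genai` package, swapping module-level configuration for an explicit client object (visible in the hunks below). A minimal sketch of the two styles, with placeholder key and model values:

    # 0.2.3 style (legacy google-generativeai): configure the module, use a model object.
    import google.generativeai as legacy_genai
    legacy_genai.configure(api_key="YOUR_KEY")  # placeholder key
    model = legacy_genai.GenerativeModel("gemini-1.5-flash")  # placeholder model name
    model.generate_content("hello", stream=True)

    # 0.2.4 style (google-genai): build a Client and call its models endpoint.
    from google import genai
    client = genai.Client(api_key="YOUR_KEY")
    client.models.generate_content(model="gemini-1.5-flash", contents="hello")
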
@@ -12,43 +15,273 @@ from ripperdoc.core.providers.base import (
     ProviderClient,
     ProviderResponse,
     call_with_timeout_and_retries,
+    iter_with_timeout,
 )
+from ripperdoc.core.query_utils import _normalize_tool_args, build_tool_description
 from ripperdoc.core.tool import Tool
 from ripperdoc.utils.log import get_logger
+from ripperdoc.utils.session_usage import record_usage
+from ripperdoc.core.query_utils import estimate_cost_usd
 
 logger = get_logger()
 
+# Constants
+GEMINI_SDK_IMPORT_ERROR = (
+    "Gemini client requires the 'google-genai' package. "
+    "Install it with: pip install google-genai"
+)
+GEMINI_MODELS_ENDPOINT_ERROR = "Gemini client is missing 'models' endpoint"
+GEMINI_GENERATE_CONTENT_ERROR = "Gemini client is missing generate_content() method"
+
 
 def _extract_usage_metadata(payload: Any) -> Dict[str, int]:
     """Best-effort token extraction from Gemini responses."""
     usage = getattr(payload, "usage_metadata", None) or getattr(payload, "usageMetadata", None)
     if not usage:
         usage = getattr(payload, "usage", None)
-    get = lambda key: int(getattr(usage, key, 0) or 0) if usage else 0  # noqa: E731
+    if not usage and getattr(payload, "candidates", None):
+        usage = getattr(payload.candidates[0], "usage_metadata", None)
+
+    def safe_get_int(key: str) -> int:
+        """Safely extract integer value from usage metadata."""
+        if not usage:
+            return 0
+        value = getattr(usage, key, 0)
+        return int(value) if value else 0
+
     return {
-        "input_tokens": get("prompt_token_count") + get("cached_content_token_count"),
-        "output_tokens": get("candidates_token_count"),
-        "cache_read_input_tokens": get("cached_content_token_count"),
+        "input_tokens": safe_get_int("prompt_token_count") + safe_get_int("cached_content_token_count"),
+        "output_tokens": safe_get_int("candidates_token_count"),
+        "cache_read_input_tokens": safe_get_int("cached_content_token_count"),
         "cache_creation_input_tokens": 0,
     }
 
 
-def _collect_text_parts(candidate: Any) -> str:
-    parts = getattr(candidate, "content", None)
-    if not parts:
-        return ""
-    if isinstance(parts, list):
-        texts = []
-        for part in parts:
-            text_val = getattr(part, "text", None) or getattr(part, "content", None)
-            if isinstance(text_val, str):
-                texts.append(text_val)
-        return "".join(texts)
-    return str(parts)
+def _collect_parts(candidate: Any) -> List[Any]:
+    """Return a list of parts from a candidate regardless of SDK shape."""
+    content = getattr(candidate, "content", None)
+    if content is None:
+        return []
+    if hasattr(content, "parts"):
+        return list(getattr(content, "parts", []) or [])
+    if isinstance(content, list):
+        return content
+    return []
+
+
+def _collect_text_from_parts(parts: List[Any]) -> str:
+    texts: List[str] = []
+    for part in parts:
+        text_val = getattr(part, "text", None) or getattr(part, "content", None) or getattr(
+            part, "raw_text", None
+        )
+        if isinstance(text_val, str):
+            texts.append(text_val)
+    return "".join(texts)
+
+
+def _extract_function_calls(parts: List[Any]) -> List[Dict[str, Any]]:
+    calls: List[Dict[str, Any]] = []
+    for part in parts:
+        fn_call = getattr(part, "function_call", None) or getattr(part, "functionCall", None)
+        if not fn_call:
+            continue
+        name = getattr(fn_call, "name", None) or getattr(fn_call, "function_name", None)
+        args = getattr(fn_call, "args", None) or getattr(fn_call, "arguments", None) or {}
+        call_id = getattr(fn_call, "id", None) or getattr(fn_call, "call_id", None)
+        calls.append({"name": name, "args": _normalize_tool_args(args), "id": call_id})
+    return calls
+
+
+def _flatten_schema(schema: Dict[str, Any]) -> Dict[str, Any]:
+    """Inline $ref entries and drop $defs/$ref for Gemini Schema compatibility.
+
+    Gemini API doesn't support JSON Schema references, so this function
+    resolves all $ref pointers by inlining the referenced definitions.
+    """
+    definitions = copy.deepcopy(schema.get("$defs") or schema.get("definitions") or {})
+
+    def _resolve(node: Any) -> Any:
+        """Recursively resolve $ref pointers and remove unsupported fields."""
+        if isinstance(node, dict):
+            # Handle $ref resolution
+            ref = node.get("$ref")
+            if isinstance(ref, str) and ref.startswith("#/"):
+                ref_key = ref.split("/")[-1]
+                if ref_key in definitions:
+                    return _resolve(copy.deepcopy(definitions[ref_key]))
+
+            # Process remaining fields, excluding schema metadata
+            resolved: Dict[str, Any] = {}
+            for key, value in node.items():
+                if key in {"$ref", "$defs", "definitions"}:
+                    continue
+                resolved[key] = _resolve(value)
+            return resolved
+
+        if isinstance(node, list):
+            return [_resolve(item) for item in node]
+
+        return node
+
+    return cast(Dict[str, Any], _resolve(copy.deepcopy(schema)))
+
+
+def _supports_stream_arg(fn: Any) -> bool:
+    """Return True if the callable appears to accept a 'stream' kwarg."""
+    try:
+        sig = inspect.signature(fn)
+    except (TypeError, ValueError):
+        # If we cannot inspect, avoid passing stream to prevent TypeErrors.
+        return False
+
+    for param in sig.parameters.values():
+        if param.kind == param.VAR_KEYWORD:
+            return True
+        if param.name == "stream":
+            return True
+    return False
+
+
+async def _async_build_tool_declarations(tools: List[Tool[Any, Any]]) -> List[Dict[str, Any]]:
+    declarations: List[Dict[str, Any]] = []
+    try:
+        from google.genai import types as genai_types  # type: ignore
+    except Exception:  # pragma: no cover - fallback when SDK not installed
+        genai_types = None
+
+    for tool in tools:
+        description = await build_tool_description(tool, include_examples=True, max_examples=2)
+        parameters_schema = _flatten_schema(tool.input_schema.model_json_schema())
+        if genai_types:
+            declarations.append(
+                genai_types.FunctionDeclaration(
+                    name=tool.name,
+                    description=description,
+                    parameters=genai_types.Schema(**parameters_schema),
+                )
+            )
+        else:
+            declarations.append(
+                {"name": tool.name, "description": description, "parameters": parameters_schema}
+            )
+    return declarations
+
+
+def _convert_messages_to_genai_contents(
+    normalized_messages: List[Dict[str, Any]],
+) -> Tuple[List[Any], Dict[str, str]]:
+    """Map normalized OpenAI-style messages to Gemini content payloads.
+
+    Returns:
+        contents: List of Content-like dicts/objects
+        tool_name_by_id: Map of tool_call_id -> function name (for pairing responses)
+    """
+    tool_name_by_id: Dict[str, str] = {}
+    contents: List[Any] = []
+
+    # Lazy import to avoid hard dependency in tests.
+    try:
+        from google.genai import types as genai_types  # type: ignore
+    except Exception:  # pragma: no cover - fallback when SDK not installed
+        genai_types = None
+
+    def _mk_part_from_text(text: str) -> Any:
+        if genai_types:
+            return genai_types.Part(text=text)
+        return {"text": text}
+
+    def _mk_part_from_function_call(name: str, args: Dict[str, Any], call_id: Optional[str]) -> Any:
+        # Store mapping using actual call_id if available, otherwise generate one
+        actual_id = call_id or str(uuid4())
+        tool_name_by_id[actual_id] = name
+        if genai_types:
+            return genai_types.Part(function_call=genai_types.FunctionCall(name=name, args=args))
+        return {"function_call": {"name": name, "args": args, "id": actual_id}}
+
+    def _mk_part_from_function_response(
+        name: str, response: Dict[str, Any], call_id: Optional[str]
+    ) -> Any:
+        if call_id:
+            response = {**response, "call_id": call_id}
+        if genai_types:
+            return genai_types.Part.from_function_response(name=name, response=response)
+        payload = {"function_response": {"name": name, "response": response}}
+        if call_id:
+            payload["function_response"]["id"] = call_id
+        return payload
+
+    def _mk_content(role: str, parts: List[Any]) -> Any:
+        if genai_types:
+            return genai_types.Content(role=role, parts=parts)
+        return {"role": role, "parts": parts}
+
+    for message in normalized_messages:
+        role = message.get("role") or ""
+        msg_parts: List[Any] = []
+
+        # Assistant tool calls
+        for tool_call in message.get("tool_calls") or []:
+            func = tool_call.get("function") or {}
+            name = func.get("name") or ""
+            args = _normalize_tool_args(func.get("arguments") or {})
+            call_id = tool_call.get("id")
+            msg_parts.append(_mk_part_from_function_call(name, args, call_id))
+
+        content_value = message.get("content")
+        if isinstance(content_value, str) and content_value:
+            msg_parts.append(_mk_part_from_text(content_value))
+
+        if role == "tool":
+            call_id = message.get("tool_call_id") or ""
+            name = tool_name_by_id.get(call_id, call_id or "tool_response")
+            response = {"result": content_value}
+            msg_parts.append(_mk_part_from_function_response(name, response, call_id))
+            role = "user"  # Tool responses are treated as user-provided context
+
+        if not msg_parts:
+            continue
+
+        mapped_role = "user" if role == "user" else "model"
+        contents.append(_mk_content(mapped_role, msg_parts))
+
+    return contents, tool_name_by_id
 
 
 class GeminiClient(ProviderClient):
-    """Gemini client with streaming and basic text support."""
+    """Gemini client with streaming and function calling support."""
+
+    def __init__(self, client_factory: Optional[Any] = None) -> None:
+        self._client_factory = client_factory
+
+    async def _client(self, model_profile: ModelProfile) -> Any:
+        if self._client_factory is not None:
+            client = self._client_factory
+            if inspect.iscoroutinefunction(client):
+                return await client()
+            if inspect.isawaitable(client):
+                return await client  # type: ignore[return-value]
+            if callable(client):
+                result = client()
+                return await result if inspect.isawaitable(result) else result
+            return client
+
+        try:
+            from google import genai  # type: ignore
+        except Exception as exc:  # pragma: no cover - import guard
+            raise RuntimeError(GEMINI_SDK_IMPORT_ERROR) from exc
+
+        client_kwargs: Dict[str, Any] = {}
+        api_key = model_profile.api_key or os.getenv("GEMINI_API_KEY") or os.getenv("GOOGLE_API_KEY")
+        if api_key:
+            client_kwargs["api_key"] = api_key
+        if model_profile.api_base:
+            from google.genai import types as genai_types  # type: ignore
+
+            client_kwargs["http_options"] = genai_types.HttpOptions(
+                base_url=model_profile.api_base
+            )
+        return genai.Client(**client_kwargs)
 
     async def call(
         self,
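
Of the helpers added in this hunk, `_flatten_schema` has the least obvious behavior: Gemini's Schema type rejects JSON Schema references, so the `$ref`/`$defs` structure that pydantic's `model_json_schema()` emits for nested models must be inlined before building a `FunctionDeclaration`. A minimal sketch of the transformation (the `Point` definition and field names are hypothetical, not from the package):

    # Hypothetical input, shaped like pydantic's model_json_schema() output.
    schema = {
        "$defs": {"Point": {"type": "object", "properties": {"x": {"type": "number"}}}},
        "type": "object",
        "properties": {"origin": {"$ref": "#/$defs/Point"}},
    }
    # _flatten_schema(schema) inlines the reference and drops the $defs key:
    # {
    #     "type": "object",
    #     "properties": {"origin": {"type": "object",
    #                               "properties": {"x": {"type": "number"}}}},
    # }
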
@@ -63,95 +296,185 @@ class GeminiClient(ProviderClient):
         request_timeout: Optional[float],
         max_retries: int,
     ) -> ProviderResponse:
+        start_time = time.time()
+
         try:
-            import google.generativeai as genai  # type: ignore
-        except Exception as exc:  # pragma: no cover - import guard
-            msg = (
-                "Gemini client requires the 'google-generativeai' package. "
-                "Install it to enable Gemini support."
-            )
-            logger.warning(msg, extra={"error": str(exc)})
+            client = await self._client(model_profile)
+        except Exception as exc:
+            msg = str(exc)
+            logger.warning("[gemini_client] Initialization failed", extra={"error": msg})
             return ProviderResponse(
                 content_blocks=[{"type": "text", "text": msg}],
                 usage_tokens={},
                 cost_usd=0.0,
-                duration_ms=0.0,
+                duration_ms=(time.time() - start_time) * 1000,
             )
 
+        declarations: List[Dict[str, Any]] = []
         if tools and tool_mode != "text":
-            msg = (
-                "Gemini client currently supports text-only responses; "
-                "tool/function calling is not yet implemented."
+            declarations = await _async_build_tool_declarations(tools)
+
+        contents, _ = _convert_messages_to_genai_contents(normalized_messages)
+
+        config: Dict[str, Any] = {"system_instruction": system_prompt}
+        if model_profile.max_tokens:
+            config["max_output_tokens"] = model_profile.max_tokens
+        if declarations:
+            try:
+                from google.genai import types as genai_types  # type: ignore
+
+                config["tools"] = [genai_types.Tool(function_declarations=declarations)]
+            except Exception:  # pragma: no cover - fallback when SDK not installed
+                config["tools"] = [{"function_declarations": declarations}]
+
+        generate_kwargs: Dict[str, Any] = {
+            "model": model_profile.model,
+            "contents": contents,
+            "config": config,
+        }
+        usage_tokens: Dict[str, int] = {}
+        collected_text: List[str] = []
+        function_calls: List[Dict[str, Any]] = []
+
+        async def _call_generate(streaming: bool) -> Any:
+            models_api = getattr(client, "models", None) or getattr(
+                getattr(client, "aio", None), "models", None
+            )
+            if models_api is None:
+                raise RuntimeError(GEMINI_MODELS_ENDPOINT_ERROR)
+
+            generate_fn = getattr(models_api, "generate_content", None)
+            stream_fn = getattr(models_api, "generate_content_stream", None) or getattr(
+                models_api, "stream_generate_content", None
             )
+
+            if streaming:
+                if stream_fn:
+                    result = stream_fn(**generate_kwargs)
+                    if inspect.isawaitable(result):
+                        return await result
+                    return result
+
+                if generate_fn is None:
+                    raise RuntimeError(GEMINI_GENERATE_CONTENT_ERROR)
+
+                if _supports_stream_arg(generate_fn):
+                    gen_kwargs: Dict[str, Any] = dict(generate_kwargs)
+                    gen_kwargs["stream"] = True
+                    result = generate_fn(**gen_kwargs)
+                    if inspect.isawaitable(result):
+                        return await result
+                    return result
+
+                # Fallback: non-streaming generate; wrap to keep downstream iterator usage
+                result = generate_fn(**generate_kwargs)
+                if inspect.isawaitable(result):
+                    result = await result
+
+                async def _single_chunk_stream() -> AsyncIterator[Any]:
+                    yield result
+
+                return _single_chunk_stream()
+
+            if generate_fn is None:
+                raise RuntimeError(GEMINI_GENERATE_CONTENT_ERROR)
+
+            result = generate_fn(**generate_kwargs)
+            if inspect.isawaitable(result):
+                return await result
+            return result
+
+        try:
+            if stream:
+                stream_resp = await _call_generate(streaming=True)
+
+                # Normalize streams into an async iterator to avoid StopIteration surfacing through
+                # asyncio executors and to handle sync iterables.
+                def _to_async_iter(obj: Any) -> AsyncIterator[Any]:
+                    """Convert various iterable types to async generator."""
+                    if inspect.isasyncgen(obj) or hasattr(obj, "__aiter__"):
+                        async def _wrap_async() -> AsyncIterator[Any]:
+                            async for item in obj:
+                                yield item
+
+                        return _wrap_async()
+                    if hasattr(obj, "__iter__"):
+                        async def _wrap_sync() -> AsyncIterator[Any]:
+                            for item in obj:
+                                yield item
+
+                        return _wrap_sync()
+
+                    async def _single() -> AsyncIterator[Any]:
+                        yield obj
+
+                    return _single()
+
+                stream_iter = _to_async_iter(stream_resp)
+
+                async for chunk in iter_with_timeout(stream_iter, request_timeout):
+                    candidates = getattr(chunk, "candidates", None) or []
+                    for candidate in candidates:
+                        parts = _collect_parts(candidate)
+                        if progress_callback:
+                            text_delta = _collect_text_from_parts(parts)
+                            if text_delta:
+                                try:
+                                    await progress_callback(text_delta)
+                                except Exception:
+                                    logger.exception("[gemini_client] Stream callback failed")
+                        collected_text.append(_collect_text_from_parts(parts))
+                        function_calls.extend(_extract_function_calls(parts))
+                    usage_tokens = _extract_usage_metadata(chunk) or usage_tokens
+            else:
+                # Use retry logic for non-streaming calls
+                response = await call_with_timeout_and_retries(
+                    lambda: _call_generate(streaming=False),
+                    request_timeout,
+                    max_retries,
+                )
+                candidates = getattr(response, "candidates", None) or []
+                if candidates:
+                    parts = _collect_parts(candidates[0])
+                    collected_text.append(_collect_text_from_parts(parts))
+                    function_calls.extend(_extract_function_calls(parts))
+                else:
+                    # Fallback: try to read text directly
+                    collected_text.append(getattr(response, "text", "") or "")
+                usage_tokens = _extract_usage_metadata(response)
+        except Exception as exc:
+            logger.exception("[gemini_client] Error during call", extra={"error": str(exc)})
             return ProviderResponse(
-                content_blocks=[{"type": "text", "text": msg}],
+                content_blocks=[{"type": "text", "text": f"Gemini call failed: {exc}"}],
                 usage_tokens={},
                 cost_usd=0.0,
-                duration_ms=0.0,
+                duration_ms=(time.time() - start_time) * 1000,
             )
 
-        api_key = (
-            model_profile.api_key or os.getenv("GEMINI_API_KEY") or os.getenv("GOOGLE_API_KEY")
-        )
-        genai.configure(api_key=api_key, client_options={"api_endpoint": model_profile.api_base})
+        content_blocks: List[Dict[str, Any]] = []
+        combined_text = "".join(collected_text).strip()
+        if combined_text:
+            content_blocks.append({"type": "text", "text": combined_text})
 
-        # Flatten normalized messages into a single text prompt (Gemini supports multi-turn, but keep it simple).
-        prompt_parts: List[str] = [system_prompt]
-        for msg in normalized_messages:  # type: ignore[assignment]
-            role: str = (
-                str(msg.get("role", "")) if isinstance(msg, dict) else str(getattr(msg, "role", ""))  # type: ignore[assignment]
+        for call in function_calls:
+            if not call.get("name"):
+                continue
+            content_blocks.append(
+                {
+                    "type": "tool_use",
+                    "tool_use_id": call.get("id") or str(uuid4()),
+                    "name": call["name"],
+                    "input": call.get("args") or {},
+                }
             )
-            content = msg.get("content") if isinstance(msg, dict) else getattr(msg, "content", "")
-            if isinstance(content, list):
-                for item in content:
-                    text_val = (
-                        getattr(item, "text", None)
-                        or item.get("text", "")  # type: ignore[union-attr]
-                        if isinstance(item, dict)
-                        else ""
-                    )
-                    if text_val:
-                        prompt_parts.append(f"{role}: {text_val}")
-            elif isinstance(content, str):
-                prompt_parts.append(f"{role}: {content}")
-        full_prompt = "\n".join(part for part in prompt_parts if part)
-
-        model = genai.GenerativeModel(model_profile.model)
-        collected_text: List[str] = []
-        start_time = time.time()
-
-        async def _stream_request() -> Dict[str, Dict[str, int]]:
-            stream_resp = model.generate_content(full_prompt, stream=True)
-            usage_tokens: Dict[str, int] = {}
-            for chunk in stream_resp:
-                text_delta = _collect_text_parts(chunk)
-                if text_delta:
-                    collected_text.append(text_delta)
-                if progress_callback:
-                    try:
-                        await progress_callback(text_delta)
-                    except Exception:
-                        logger.exception("[gemini_client] Stream callback failed")
-                usage_tokens = _extract_usage_metadata(chunk) or usage_tokens
-            return {"usage": usage_tokens}
-
-        async def _non_stream_request() -> Any:
-            return model.generate_content(full_prompt)
-
-        response: Any = await call_with_timeout_and_retries(
-            _stream_request if stream and progress_callback else _non_stream_request,
-            request_timeout,
-            max_retries,
-        )
 
         duration_ms = (time.time() - start_time) * 1000
-        usage_tokens = _extract_usage_metadata(response)
-        cost_usd = 0.0  # Pricing unknown; leave as 0
-
-        content_blocks = (
-            [{"type": "text", "text": "".join(collected_text)}]
-            if collected_text
-            else [{"type": "text", "text": _collect_text_parts(response)}]
+        cost_usd = estimate_cost_usd(model_profile, usage_tokens) if usage_tokens else 0.0
+        record_usage(
+            model_profile.model,
+            duration_ms=duration_ms,
+            cost_usd=cost_usd,
+            **(usage_tokens or {}),
         )
 
         logger.info(
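
The `call` rewrite in this hunk leans on `_convert_messages_to_genai_contents` to translate OpenAI-style history into Gemini contents: assistant `tool_calls` become `model`-role `function_call` parts, and `role == "tool"` results come back as `user`-role `function_response` parts, paired by `tool_call_id`. A rough sketch of that mapping (the message values and tool name are hypothetical):

    # Hypothetical normalized history for one tool round trip.
    messages = [
        {"role": "user", "content": "Weather in Paris?"},
        {"role": "assistant", "tool_calls": [
            {"id": "call_1",
             "function": {"name": "get_weather", "arguments": '{"city": "Paris"}'}},
        ]},
        {"role": "tool", "tool_call_id": "call_1", "content": "18C, clear"},
    ]
    # _convert_messages_to_genai_contents(messages) yields roughly:
    #   Content(role="user",  parts=[Part(text="Weather in Paris?")])
    #   Content(role="model", parts=[Part(function_call=FunctionCall(
    #       name="get_weather", args={"city": "Paris"}))])
    #   Content(role="user",  parts=[Part.from_function_response(
    #       name="get_weather",
    #       response={"result": "18C, clear", "call_id": "call_1"})])

The `_to_async_iter` shim in the same hunk exists because, depending on SDK version and endpoint, the stream call may return an async iterator, a sync iterable, or (via the non-streaming fallback) a single response object; normalizing all three to an async iterator lets `iter_with_timeout` consume them uniformly.
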
@@ -161,11 +484,12 @@ class GeminiClient(ProviderClient):
                 "duration_ms": round(duration_ms, 2),
                 "tool_mode": tool_mode,
                 "stream": stream,
+                "function_call_count": len(function_calls),
             },
         )
 
         return ProviderResponse(
-            content_blocks=content_blocks,
+            content_blocks=content_blocks or [{"type": "text", "text": ""}],
             usage_tokens=usage_tokens,
            cost_usd=cost_usd,
            duration_ms=duration_ms,
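
For downstream consumers, the net effect is that a Gemini ProviderResponse can now interleave text and tool-use blocks rather than carrying text only. The shape, with hypothetical values matching the sketch above:

    # content_blocks after a function-calling turn (values hypothetical).
    [
        {"type": "text", "text": "Checking the weather."},
        {"type": "tool_use", "tool_use_id": "call_1",
         "name": "get_weather", "input": {"city": "Paris"}},
    ]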