hindsight-api 0.4.7__py3-none-any.whl → 0.4.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,502 @@
+ """
+ Google Gemini/Vertex AI LLM provider.
+
+ This provider supports both:
+ 1. Gemini API (generativelanguage.googleapis.com) with API key authentication
+ 2. Vertex AI with service account or Application Default Credentials (ADC)
+ """
+
+ import asyncio
+ import json
+ import logging
+ import os
+ import time
+ from typing import Any
+
+ from google import genai
+ from google.genai import errors as genai_errors
+ from google.genai import types as genai_types
+
+ from hindsight_api.engine.llm_interface import LLMInterface, OutputTooLongError
+ from hindsight_api.engine.response_models import LLMToolCall, LLMToolCallResult, TokenUsage
+ from hindsight_api.metrics import get_metrics_collector
+
+ logger = logging.getLogger(__name__)
+
+ # Vertex AI imports (optional)
+ try:
+     import google.auth
+     from google.oauth2 import service_account
+
+     VERTEXAI_AVAILABLE = True
+ except ImportError:
+     VERTEXAI_AVAILABLE = False
+
+
+ class GeminiLLM(LLMInterface):
+     """
+     LLM provider for Google Gemini and Vertex AI.
+
+     Supports:
+     - Gemini API: provider="gemini", requires api_key
+     - Vertex AI: provider="vertexai", requires project_id and region, uses ADC or service account
+     """
+
+     def __init__(
+         self,
+         provider: str,
+         api_key: str,
+         base_url: str,
+         model: str,
+         reasoning_effort: str = "low",
+         **kwargs: Any,
+     ):
+         """Initialize the Gemini/Vertex AI LLM provider."""
+         super().__init__(provider, api_key, base_url, model, reasoning_effort, **kwargs)
+
+         self._client = None
+         self._is_vertexai = self.provider == "vertexai"
+
+         if self._is_vertexai:
+             self._init_vertexai(**kwargs)
+         else:
+             self._init_gemini()
+
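A minimal construction sketch for the Gemini API path, based on the signature above; the key value is a placeholder, and treating base_url as unused by this provider is an assumption from this file alone:

    # Gemini API path: API-key authentication (placeholder key)
    llm = GeminiLLM(
        provider="gemini",
        api_key="AIza...",  # placeholder
        base_url="",        # part of the shared interface; not consulted in this file
        model="gemini-2.0-flash-lite-001",
    )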
+     def _init_gemini(self) -> None:
+         """Initialize Gemini API client."""
+         if not self.api_key:
+             raise ValueError("Gemini provider requires api_key")
+
+         self._client = genai.Client(api_key=self.api_key)
+         logger.info(f"Gemini API: model={self.model}")
+
+     def _init_vertexai(self, **kwargs: Any) -> None:
+         """Initialize Vertex AI client with project, region, and credentials."""
+         # Extract Vertex AI config from kwargs
+         project_id = kwargs.get("vertexai_project_id")
+         region = kwargs.get("vertexai_region", "us-central1")
+         service_account_key = kwargs.get("vertexai_service_account_key")
+         credentials = kwargs.get("vertexai_credentials")  # Pre-loaded credentials object
+
+         if not project_id:
+             raise ValueError(
+                 "HINDSIGHT_API_LLM_VERTEXAI_PROJECT_ID is required for Vertex AI provider. "
+                 "Set it to your GCP project ID."
+             )
+
+         auth_method = "ADC"
+
+         # Use pre-loaded credentials if provided (passed from LLMProvider)
+         if credentials is not None:
+             auth_method = "service_account"
+         # Otherwise, load explicit service account credentials if path provided
+         elif service_account_key:
+             if not VERTEXAI_AVAILABLE:
+                 raise ValueError(
+                     "Vertex AI service account auth requires 'google-auth' package. "
+                     "Install with: pip install google-auth"
+                 )
+             credentials = service_account.Credentials.from_service_account_file(
+                 service_account_key,
+                 scopes=["https://www.googleapis.com/auth/cloud-platform"],
+             )
+             auth_method = "service_account"
+             logger.info(f"Vertex AI: Using service account key: {service_account_key}")
+
+         # Strip google/ prefix from model name — native SDK uses bare names
+         # e.g. "google/gemini-2.0-flash-lite-001" -> "gemini-2.0-flash-lite-001"
+         if self.model.startswith("google/"):
+             self.model = self.model[len("google/") :]
+
+         # Create Vertex AI client
+         client_kwargs: dict[str, Any] = {
+             "vertexai": True,
+             "project": project_id,
+             "location": region,
+         }
+         if credentials is not None:
+             client_kwargs["credentials"] = credentials
+
+         self._client = genai.Client(**client_kwargs)
+
+         logger.info(f"Vertex AI: project={project_id}, region={region}, model={self.model}, auth={auth_method}")
+
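The Vertex AI path, as a sketch under the assumption that extra constructor kwargs reach _init_vertexai unchanged (they are forwarded as **kwargs above); project and key-file values are placeholders:

    # Vertex AI path: ADC by default, service account when a key file is given
    llm = GeminiLLM(
        provider="vertexai",
        api_key="",  # not used for Vertex AI
        base_url="",
        model="google/gemini-2.0-flash-lite-001",  # "google/" prefix is stripped above
        vertexai_project_id="my-gcp-project",      # placeholder
        vertexai_region="us-central1",
        vertexai_service_account_key="/path/to/key.json",  # optional; omit to use ADC
    )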
+     async def verify_connection(self) -> None:
+         """
+         Verify that the Gemini/Vertex AI provider is configured correctly.
+
+         Raises:
+             RuntimeError: If the connection test fails.
+         """
+         try:
+             logger.info(f"Verifying {self.provider.upper()}: model={self.model}...")
+             await self.call(
+                 messages=[{"role": "user", "content": "Say 'ok'"}],
+                 max_completion_tokens=100,
+                 max_retries=2,
+                 initial_backoff=0.5,
+                 max_backoff=2.0,
+             )
+             logger.info(f"{self.provider.upper()} connection verified successfully")
+         except Exception as e:
+             raise RuntimeError(f"Failed to verify {self.provider.upper()} connection: {e}") from e
+
+     async def call(
+         self,
+         messages: list[dict[str, str]],
+         response_format: Any | None = None,
+         max_completion_tokens: int | None = None,
+         temperature: float | None = None,
+         scope: str = "memory",
+         max_retries: int = 10,
+         initial_backoff: float = 1.0,
+         max_backoff: float = 60.0,
+         skip_validation: bool = False,
+         strict_schema: bool = False,
+         return_usage: bool = False,
+     ) -> Any:
+         """
+         Make a Gemini/Vertex AI API call with retry logic.
+
+         Args:
+             messages: List of message dicts with 'role' and 'content'.
+             response_format: Optional Pydantic model for structured output.
+             max_completion_tokens: Maximum tokens in response (accepted for interface
+                 compatibility; not forwarded to Gemini by this provider).
+             temperature: Sampling temperature (0.0-2.0).
+             scope: Scope identifier for tracking.
+             max_retries: Maximum retry attempts.
+             initial_backoff: Initial backoff time in seconds.
+             max_backoff: Maximum backoff time in seconds.
+             skip_validation: Return raw JSON without Pydantic validation.
+             strict_schema: Use strict JSON schema enforcement (not supported by Gemini).
+             return_usage: If True, return tuple (result, TokenUsage).
+
+         Returns:
+             If return_usage=False: Parsed response if response_format provided, else text.
+             If return_usage=True: Tuple of (result, TokenUsage).
+         """
+         start_time = time.time()
+
+         # Convert OpenAI-style messages to Gemini format
+         system_instruction = None
+         gemini_contents = []
+
+         for msg in messages:
+             role = msg.get("role", "user")
+             content = msg.get("content", "")
+
+             if role == "system":
+                 if system_instruction:
+                     system_instruction += "\n\n" + content
+                 else:
+                     system_instruction = content
+             elif role == "assistant":
+                 gemini_contents.append(genai_types.Content(role="model", parts=[genai_types.Part(text=content)]))
+             else:
+                 gemini_contents.append(genai_types.Content(role="user", parts=[genai_types.Part(text=content)]))
+
+         # Add JSON schema instruction if response_format is provided
+         if response_format is not None and hasattr(response_format, "model_json_schema"):
+             schema = response_format.model_json_schema()
+             schema_msg = f"\n\nYou must respond with valid JSON matching this schema:\n{json.dumps(schema, indent=2)}"
+             if system_instruction:
+                 system_instruction += schema_msg
+             else:
+                 system_instruction = schema_msg
+
+         # Build generation config
+         config_kwargs: dict[str, Any] = {}
+         if system_instruction:
+             config_kwargs["system_instruction"] = system_instruction
+         if response_format is not None:
+             config_kwargs["response_mime_type"] = "application/json"
+             config_kwargs["response_schema"] = response_format
+         if temperature is not None:
+             config_kwargs["temperature"] = temperature
+
+         generation_config = genai_types.GenerateContentConfig(**config_kwargs) if config_kwargs else None
+
+         last_exception = None
+
+         for attempt in range(max_retries + 1):
+             try:
+                 response = await self._client.aio.models.generate_content(
+                     model=self.model,
+                     contents=gemini_contents,
+                     config=generation_config,
+                 )
+
+                 content = response.text
+
+                 # Handle empty response
+                 if content is None:
+                     block_reason = None
+                     if hasattr(response, "candidates") and response.candidates:
+                         candidate = response.candidates[0]
+                         if hasattr(candidate, "finish_reason"):
+                             block_reason = candidate.finish_reason
+
+                     if attempt < max_retries:
+                         logger.warning(f"Gemini returned empty response (reason: {block_reason}), retrying...")
+                         backoff = min(initial_backoff * (2**attempt), max_backoff)
+                         await asyncio.sleep(backoff)
+                         continue
+                     else:
+                         raise RuntimeError(f"Gemini returned empty response after {max_retries + 1} attempts")
+
+                 # Parse structured output if requested
+                 if response_format is not None:
+                     json_data = json.loads(content)
+                     if skip_validation:
+                         result = json_data
+                     else:
+                         result = response_format.model_validate(json_data)
+                 else:
+                     result = content
+
+                 # Extract token usage
+                 input_tokens = 0
+                 output_tokens = 0
+                 if hasattr(response, "usage_metadata") and response.usage_metadata:
+                     usage = response.usage_metadata
+                     input_tokens = usage.prompt_token_count or 0
+                     output_tokens = usage.candidates_token_count or 0
+
+                 # Record metrics
+                 duration = time.time() - start_time
+                 metrics = get_metrics_collector()
+                 metrics.record_llm_call(
+                     provider=self.provider,
+                     model=self.model,
+                     scope=scope,
+                     duration=duration,
+                     input_tokens=input_tokens,
+                     output_tokens=output_tokens,
+                     success=True,
+                 )
+
+                 # Log slow calls
+                 if duration > 10.0 and input_tokens > 0:
+                     logger.info(
+                         f"slow llm call: scope={scope}, model={self.provider}/{self.model}, "
+                         f"input_tokens={input_tokens}, output_tokens={output_tokens}, "
+                         f"time={duration:.3f}s"
+                     )
+
+                 if return_usage:
+                     token_usage = TokenUsage(
+                         input_tokens=input_tokens,
+                         output_tokens=output_tokens,
+                         total_tokens=input_tokens + output_tokens,
+                     )
+                     return result, token_usage
+                 return result
+
+             except json.JSONDecodeError as e:
+                 last_exception = e
+                 if attempt < max_retries:
+                     logger.warning("Gemini returned invalid JSON, retrying...")
+                     backoff = min(initial_backoff * (2**attempt), max_backoff)
+                     await asyncio.sleep(backoff)
+                     continue
+                 else:
+                     logger.error(f"Gemini returned invalid JSON after {max_retries + 1} attempts")
+                     raise
+
+             except genai_errors.APIError as e:
+                 # Fast fail on auth errors - these won't recover with retries
+                 if e.code in (401, 403):
+                     logger.error(f"Gemini auth error (HTTP {e.code}), not retrying: {str(e)}")
+                     raise
+
+                 # Retry on transient errors (rate limits, server errors, retryable client errors)
+                 if e.code in (400, 429, 500, 502, 503, 504) or (e.code and e.code >= 500):
+                     last_exception = e
+                     if attempt < max_retries:
+                         backoff = min(initial_backoff * (2**attempt), max_backoff)
+                         # ±20% jitter derived from the sub-second wall clock
+                         jitter = backoff * 0.2 * (2 * (time.time() % 1) - 1)
+                         await asyncio.sleep(backoff + jitter)
+                     else:
+                         logger.error(f"Gemini API error after {max_retries + 1} attempts: {str(e)}")
+                         raise
+                 else:
+                     logger.error(f"Gemini API error: {type(e).__name__}: {str(e)}")
+                     raise
+
+             except Exception as e:
+                 logger.error(f"Unexpected error during Gemini call: {type(e).__name__}: {str(e)}")
+                 raise
+
+         if last_exception:
+             raise last_exception
+         raise RuntimeError("Gemini call failed after all retries")
+
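A usage sketch for call() with structured output; the Answer model is hypothetical, and the code is assumed to run inside an async context:

    from pydantic import BaseModel

    class Answer(BaseModel):
        text: str
        confidence: float

    # Returns (Answer, TokenUsage) because return_usage=True
    result, usage = await llm.call(
        messages=[{"role": "user", "content": "Summarize this in one line."}],
        response_format=Answer,  # validated via Answer.model_validate(...) above
        temperature=0.2,
        return_usage=True,
    )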
+     async def call_with_tools(
+         self,
+         messages: list[dict[str, Any]],
+         tools: list[dict[str, Any]],
+         max_completion_tokens: int | None = None,
+         temperature: float | None = None,
+         scope: str = "tools",
+         max_retries: int = 5,
+         initial_backoff: float = 1.0,
+         max_backoff: float = 30.0,
+         tool_choice: str | dict[str, Any] = "auto",
+     ) -> LLMToolCallResult:
+         """
+         Make a Gemini/Vertex AI API call with tool/function calling support.
+
+         Args:
+             messages: List of message dicts. Can include tool results with role='tool'.
+             tools: List of tool definitions in OpenAI format.
+             max_completion_tokens: Maximum tokens (accepted for interface compatibility;
+                 not forwarded to Gemini by this provider).
+             temperature: Sampling temperature.
+             scope: Scope identifier for tracking.
+             max_retries: Maximum retry attempts.
+             initial_backoff: Initial backoff time in seconds.
+             max_backoff: Maximum backoff time in seconds.
+             tool_choice: How to choose tools (Gemini uses "auto" only).
+
+         Returns:
+             LLMToolCallResult with content and/or tool_calls.
+         """
+         start_time = time.time()
+
+         # Convert tools to Gemini format
+         gemini_tools = []
+         for tool in tools:
+             func = tool.get("function", {})
+             gemini_tools.append(
+                 genai_types.Tool(
+                     function_declarations=[
+                         genai_types.FunctionDeclaration(
+                             name=func.get("name", ""),
+                             description=func.get("description", ""),
+                             parameters=func.get("parameters"),
+                         )
+                     ]
+                 )
+             )
+
+         # Convert messages
+         system_instruction = None
+         gemini_contents = []
+         for msg in messages:
+             role = msg.get("role", "user")
+             content = msg.get("content", "")
+
+             if role == "system":
+                 system_instruction = (system_instruction + "\n\n" + content) if system_instruction else content
+             elif role == "tool":
+                 # Gemini uses function_response
+                 gemini_contents.append(
+                     genai_types.Content(
+                         role="user",
+                         parts=[
+                             genai_types.Part(
+                                 function_response=genai_types.FunctionResponse(
+                                     name=msg.get("name", ""),
+                                     response={"result": content},
+                                 )
+                             )
+                         ],
+                     )
+                 )
+             elif role == "assistant":
+                 gemini_contents.append(genai_types.Content(role="model", parts=[genai_types.Part(text=content)]))
+             else:
+                 gemini_contents.append(genai_types.Content(role="user", parts=[genai_types.Part(text=content)]))
+
+         config_kwargs: dict[str, Any] = {"tools": gemini_tools}
+         if system_instruction:
+             config_kwargs["system_instruction"] = system_instruction
+         if temperature is not None:
+             config_kwargs["temperature"] = temperature
+
+         config = genai_types.GenerateContentConfig(**config_kwargs)
+
+         last_exception = None
+         for attempt in range(max_retries + 1):
+             try:
+                 response = await self._client.aio.models.generate_content(
+                     model=self.model,
+                     contents=gemini_contents,
+                     config=config,
+                 )
+
+                 # Extract content and tool calls
+                 content = None
+                 tool_calls: list[LLMToolCall] = []
+
+                 if response.candidates and response.candidates[0].content:
+                     parts = response.candidates[0].content.parts
+                     if parts:
+                         for part in parts:
+                             if hasattr(part, "text") and part.text:
+                                 content = part.text
+                             if hasattr(part, "function_call") and part.function_call:
+                                 fc = part.function_call
+                                 tool_calls.append(
+                                     LLMToolCall(
+                                         id=f"gemini_{len(tool_calls)}",
+                                         name=fc.name,
+                                         arguments=dict(fc.args) if fc.args else {},
+                                     )
+                                 )
+
+                 finish_reason = "tool_calls" if tool_calls else "stop"
+
+                 # Extract token usage
+                 input_tokens = 0
+                 output_tokens = 0
+                 if response.usage_metadata:
+                     input_tokens = response.usage_metadata.prompt_token_count or 0
+                     output_tokens = response.usage_metadata.candidates_token_count or 0
+
+                 # Record metrics
+                 duration = time.time() - start_time
+                 metrics = get_metrics_collector()
+                 metrics.record_llm_call(
+                     provider=self.provider,
+                     model=self.model,
+                     scope=scope,
+                     duration=duration,
+                     input_tokens=input_tokens,
+                     output_tokens=output_tokens,
+                     success=True,
+                 )
+
+                 return LLMToolCallResult(
+                     content=content,
+                     tool_calls=tool_calls,
+                     finish_reason=finish_reason,
+                     input_tokens=input_tokens,
+                     output_tokens=output_tokens,
+                 )
+
+             except genai_errors.APIError as e:
+                 # Fast fail on auth errors
+                 if e.code in (401, 403):
+                     logger.error(f"Gemini auth error (HTTP {e.code}), not retrying: {str(e)}")
+                     raise
+
+                 # Retry on all other API errors
+                 last_exception = e
+                 if attempt < max_retries:
+                     backoff = min(initial_backoff * (2**attempt), max_backoff)
+                     await asyncio.sleep(backoff)
+                     continue
+                 raise
+
+             except Exception as e:
+                 logger.error(f"Unexpected error during Gemini tool call: {type(e).__name__}: {str(e)}")
+                 raise
+
+         if last_exception:
+             raise last_exception
+         raise RuntimeError("Gemini tool call failed")
+
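A sketch of the tool-calling flow using the OpenAI-style tool shape this method reads (only the "function" key is consulted above); the get_weather tool and its schema are hypothetical:

    weather_tool = {
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Look up current weather for a city.",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }

    result = await llm.call_with_tools(
        messages=[{"role": "user", "content": "What's the weather in Paris?"}],
        tools=[weather_tool],
    )
    for tc in result.tool_calls:
        print(tc.name, tc.arguments)  # e.g. get_weather {'city': 'Paris'}
    # Tool output goes back as a role="tool" message (converted to a
    # FunctionResponse above), then call_with_tools is invoked again.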
+     async def cleanup(self) -> None:
+         """Clean up resources (close connections, etc.)."""
+         # Gemini client doesn't require explicit cleanup
+         pass