posthoganalytics 6.7.5__py3-none-any.whl → 7.4.3__py3-none-any.whl

This diff shows the content of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in that registry.
Files changed (37)
  1. posthoganalytics/__init__.py +84 -7
  2. posthoganalytics/ai/anthropic/anthropic_async.py +30 -67
  3. posthoganalytics/ai/anthropic/anthropic_converter.py +40 -0
  4. posthoganalytics/ai/gemini/__init__.py +3 -0
  5. posthoganalytics/ai/gemini/gemini.py +1 -1
  6. posthoganalytics/ai/gemini/gemini_async.py +423 -0
  7. posthoganalytics/ai/gemini/gemini_converter.py +160 -24
  8. posthoganalytics/ai/langchain/callbacks.py +55 -11
  9. posthoganalytics/ai/openai/openai.py +27 -2
  10. posthoganalytics/ai/openai/openai_async.py +49 -5
  11. posthoganalytics/ai/openai/openai_converter.py +130 -0
  12. posthoganalytics/ai/sanitization.py +27 -5
  13. posthoganalytics/ai/types.py +1 -0
  14. posthoganalytics/ai/utils.py +32 -2
  15. posthoganalytics/client.py +338 -90
  16. posthoganalytics/contexts.py +81 -0
  17. posthoganalytics/exception_utils.py +250 -2
  18. posthoganalytics/feature_flags.py +26 -10
  19. posthoganalytics/flag_definition_cache.py +127 -0
  20. posthoganalytics/integrations/django.py +149 -50
  21. posthoganalytics/request.py +203 -23
  22. posthoganalytics/test/test_client.py +250 -22
  23. posthoganalytics/test/test_exception_capture.py +418 -0
  24. posthoganalytics/test/test_feature_flag_result.py +441 -2
  25. posthoganalytics/test/test_feature_flags.py +306 -102
  26. posthoganalytics/test/test_flag_definition_cache.py +612 -0
  27. posthoganalytics/test/test_module.py +0 -8
  28. posthoganalytics/test/test_request.py +536 -0
  29. posthoganalytics/test/test_utils.py +4 -1
  30. posthoganalytics/types.py +40 -0
  31. posthoganalytics/version.py +1 -1
  32. {posthoganalytics-6.7.5.dist-info → posthoganalytics-7.4.3.dist-info}/METADATA +12 -12
  33. posthoganalytics-7.4.3.dist-info/RECORD +57 -0
  34. posthoganalytics-6.7.5.dist-info/RECORD +0 -54
  35. {posthoganalytics-6.7.5.dist-info → posthoganalytics-7.4.3.dist-info}/WHEEL +0 -0
  36. {posthoganalytics-6.7.5.dist-info → posthoganalytics-7.4.3.dist-info}/licenses/LICENSE +0 -0
  37. {posthoganalytics-6.7.5.dist-info → posthoganalytics-7.4.3.dist-info}/top_level.txt +0 -0
posthoganalytics/ai/gemini/gemini_async.py (new file)
@@ -0,0 +1,423 @@
+import os
+import time
+import uuid
+from typing import Any, Dict, Optional
+
+from posthoganalytics.ai.types import TokenUsage, StreamingEventData
+from posthoganalytics.ai.utils import merge_system_prompt
+
+try:
+    from google import genai
+except ImportError:
+    raise ModuleNotFoundError(
+        "Please install the Google Gemini SDK to use this feature: 'pip install google-genai'"
+    )
+
+from posthoganalytics import setup
+from posthoganalytics.ai.utils import (
+    call_llm_and_track_usage_async,
+    capture_streaming_event,
+    merge_usage_stats,
+)
+from posthoganalytics.ai.gemini.gemini_converter import (
+    extract_gemini_usage_from_chunk,
+    extract_gemini_content_from_chunk,
+    format_gemini_streaming_output,
+)
+from posthoganalytics.ai.sanitization import sanitize_gemini
+from posthoganalytics.client import Client as PostHogClient
+
+
+class AsyncClient:
+    """
+    An async drop-in replacement for genai.Client that automatically sends LLM usage events to PostHog.
+
+    Usage:
+        client = AsyncClient(
+            api_key="your_api_key",
+            posthog_client=posthog_client,
+            posthog_distinct_id="default_user",  # Optional defaults
+            posthog_properties={"team": "ai"}  # Optional defaults
+        )
+        response = await client.models.generate_content(
+            model="gemini-2.0-flash",
+            contents=["Hello world"],
+            posthog_distinct_id="specific_user"  # Override default
+        )
+    """
+
+    _ph_client: PostHogClient
+
+    def __init__(
+        self,
+        api_key: Optional[str] = None,
+        vertexai: Optional[bool] = None,
+        credentials: Optional[Any] = None,
+        project: Optional[str] = None,
+        location: Optional[str] = None,
+        debug_config: Optional[Any] = None,
+        http_options: Optional[Any] = None,
+        posthog_client: Optional[PostHogClient] = None,
+        posthog_distinct_id: Optional[str] = None,
+        posthog_properties: Optional[Dict[str, Any]] = None,
+        posthog_privacy_mode: bool = False,
+        posthog_groups: Optional[Dict[str, Any]] = None,
+        **kwargs,
+    ):
+        """
+        Args:
+            api_key: Google AI API key. If not provided, will use GOOGLE_API_KEY or API_KEY environment variable (not required for Vertex AI)
+            vertexai: Whether to use Vertex AI authentication
+            credentials: Vertex AI credentials object
+            project: GCP project ID for Vertex AI
+            location: GCP location for Vertex AI
+            debug_config: Debug configuration for the client
+            http_options: HTTP options for the client
+            posthog_client: PostHog client for tracking usage
+            posthog_distinct_id: Default distinct ID for all calls (can be overridden per call)
+            posthog_properties: Default properties for all calls (can be overridden per call)
+            posthog_privacy_mode: Default privacy mode for all calls (can be overridden per call)
+            posthog_groups: Default groups for all calls (can be overridden per call)
+            **kwargs: Additional arguments (for future compatibility)
+        """
+
+        self._ph_client = posthog_client or setup()
+
+        if self._ph_client is None:
+            raise ValueError("posthog_client is required for PostHog tracking")
+
+        self.models = AsyncModels(
+            api_key=api_key,
+            vertexai=vertexai,
+            credentials=credentials,
+            project=project,
+            location=location,
+            debug_config=debug_config,
+            http_options=http_options,
+            posthog_client=self._ph_client,
+            posthog_distinct_id=posthog_distinct_id,
+            posthog_properties=posthog_properties,
+            posthog_privacy_mode=posthog_privacy_mode,
+            posthog_groups=posthog_groups,
+            **kwargs,
+        )
+
+
+class AsyncModels:
+    """
+    Async Models interface that mimics genai.Client().aio.models with PostHog tracking.
+    """
+
+    _ph_client: PostHogClient  # Not None after __init__ validation
+
+    def __init__(
+        self,
+        api_key: Optional[str] = None,
+        vertexai: Optional[bool] = None,
+        credentials: Optional[Any] = None,
+        project: Optional[str] = None,
+        location: Optional[str] = None,
+        debug_config: Optional[Any] = None,
+        http_options: Optional[Any] = None,
+        posthog_client: Optional[PostHogClient] = None,
+        posthog_distinct_id: Optional[str] = None,
+        posthog_properties: Optional[Dict[str, Any]] = None,
+        posthog_privacy_mode: bool = False,
+        posthog_groups: Optional[Dict[str, Any]] = None,
+        **kwargs,
+    ):
+        """
+        Args:
+            api_key: Google AI API key. If not provided, will use GOOGLE_API_KEY or API_KEY environment variable (not required for Vertex AI)
+            vertexai: Whether to use Vertex AI authentication
+            credentials: Vertex AI credentials object
+            project: GCP project ID for Vertex AI
+            location: GCP location for Vertex AI
+            debug_config: Debug configuration for the client
+            http_options: HTTP options for the client
+            posthog_client: PostHog client for tracking usage
+            posthog_distinct_id: Default distinct ID for all calls
+            posthog_properties: Default properties for all calls
+            posthog_privacy_mode: Default privacy mode for all calls
+            posthog_groups: Default groups for all calls
+            **kwargs: Additional arguments (for future compatibility)
+        """
+
+        self._ph_client = posthog_client or setup()
+
+        if self._ph_client is None:
+            raise ValueError("posthog_client is required for PostHog tracking")
+
+        # Store default PostHog settings
+        self._default_distinct_id = posthog_distinct_id
+        self._default_properties = posthog_properties or {}
+        self._default_privacy_mode = posthog_privacy_mode
+        self._default_groups = posthog_groups
+
+        # Build genai.Client arguments
+        client_args: Dict[str, Any] = {}
+
+        # Add Vertex AI parameters if provided
+        if vertexai is not None:
+            client_args["vertexai"] = vertexai
+
+        if credentials is not None:
+            client_args["credentials"] = credentials
+
+        if project is not None:
+            client_args["project"] = project
+
+        if location is not None:
+            client_args["location"] = location
+
+        if debug_config is not None:
+            client_args["debug_config"] = debug_config
+
+        if http_options is not None:
+            client_args["http_options"] = http_options
+
+        # Handle API key authentication
+        if vertexai:
+            # For Vertex AI, api_key is optional
+            if api_key is not None:
+                client_args["api_key"] = api_key
+        else:
+            # For non-Vertex AI mode, api_key is required (backwards compatibility)
+            if api_key is None:
+                api_key = os.environ.get("GOOGLE_API_KEY") or os.environ.get("API_KEY")
+
+            if api_key is None:
+                raise ValueError(
+                    "API key must be provided either as parameter or via GOOGLE_API_KEY/API_KEY environment variable"
+                )
+
+            client_args["api_key"] = api_key
+
+        self._client = genai.Client(**client_args)
+        self._base_url = "https://generativelanguage.googleapis.com"
+
+    def _merge_posthog_params(
+        self,
+        call_distinct_id: Optional[str],
+        call_trace_id: Optional[str],
+        call_properties: Optional[Dict[str, Any]],
+        call_privacy_mode: Optional[bool],
+        call_groups: Optional[Dict[str, Any]],
+    ):
+        """Merge call-level PostHog parameters with client defaults."""
+
+        # Use call-level values if provided, otherwise fall back to defaults
+        distinct_id = (
+            call_distinct_id
+            if call_distinct_id is not None
+            else self._default_distinct_id
+        )
+        privacy_mode = (
+            call_privacy_mode
+            if call_privacy_mode is not None
+            else self._default_privacy_mode
+        )
+        groups = call_groups if call_groups is not None else self._default_groups
+
+        # Merge properties: default properties + call properties (call properties override)
+        properties = dict(self._default_properties)
+
+        if call_properties:
+            properties.update(call_properties)
+
+        if call_trace_id is None:
+            call_trace_id = str(uuid.uuid4())
+
+        return distinct_id, call_trace_id, properties, privacy_mode, groups
+
+    async def generate_content(
+        self,
+        model: str,
+        contents,
+        posthog_distinct_id: Optional[str] = None,
+        posthog_trace_id: Optional[str] = None,
+        posthog_properties: Optional[Dict[str, Any]] = None,
+        posthog_privacy_mode: Optional[bool] = None,
+        posthog_groups: Optional[Dict[str, Any]] = None,
+        **kwargs: Any,
+    ):
+        """
+        Generate content using Gemini's API while tracking usage in PostHog.
+
+        This method signature exactly matches genai.Client().aio.models.generate_content()
+        with additional PostHog tracking parameters.
+
+        Args:
+            model: The model to use (e.g., 'gemini-2.0-flash')
+            contents: The input content for generation
+            posthog_distinct_id: ID to associate with the usage event (overrides client default)
+            posthog_trace_id: Trace UUID for linking events (auto-generated if not provided)
+            posthog_properties: Extra properties to include in the event (merged with client defaults)
+            posthog_privacy_mode: Whether to redact sensitive information (overrides client default)
+            posthog_groups: Group analytics properties (overrides client default)
+            **kwargs: Arguments passed to Gemini's generate_content
+        """
+
+        # Merge PostHog parameters
+        distinct_id, trace_id, properties, privacy_mode, groups = (
+            self._merge_posthog_params(
+                posthog_distinct_id,
+                posthog_trace_id,
+                posthog_properties,
+                posthog_privacy_mode,
+                posthog_groups,
+            )
+        )
+
+        kwargs_with_contents = {"model": model, "contents": contents, **kwargs}
+
+        return await call_llm_and_track_usage_async(
+            distinct_id,
+            self._ph_client,
+            "gemini",
+            trace_id,
+            properties,
+            privacy_mode,
+            groups,
+            self._base_url,
+            self._client.aio.models.generate_content,
+            **kwargs_with_contents,
+        )
+
+    async def _generate_content_streaming(
+        self,
+        model: str,
+        contents,
+        distinct_id: Optional[str],
+        trace_id: Optional[str],
+        properties: Optional[Dict[str, Any]],
+        privacy_mode: bool,
+        groups: Optional[Dict[str, Any]],
+        **kwargs: Any,
+    ):
+        start_time = time.time()
+        usage_stats: TokenUsage = TokenUsage(input_tokens=0, output_tokens=0)
+        accumulated_content = []
+
+        kwargs_without_stream = {"model": model, "contents": contents, **kwargs}
+        response = await self._client.aio.models.generate_content_stream(
+            **kwargs_without_stream
+        )
+
+        async def async_generator():
+            nonlocal usage_stats
+            nonlocal accumulated_content
+
+            try:
+                async for chunk in response:
+                    # Extract usage stats from chunk
+                    chunk_usage = extract_gemini_usage_from_chunk(chunk)
+
+                    if chunk_usage:
+                        # Gemini reports cumulative totals, not incremental values
+                        merge_usage_stats(usage_stats, chunk_usage, mode="cumulative")
+
+                    # Extract content from chunk (now returns content blocks)
+                    content_block = extract_gemini_content_from_chunk(chunk)
+
+                    if content_block is not None:
+                        accumulated_content.append(content_block)
+
+                    yield chunk
+
+            finally:
+                end_time = time.time()
+                latency = end_time - start_time
+
+                self._capture_streaming_event(
+                    model,
+                    contents,
+                    distinct_id,
+                    trace_id,
+                    properties,
+                    privacy_mode,
+                    groups,
+                    kwargs,
+                    usage_stats,
+                    latency,
+                    accumulated_content,
+                )
+
+        return async_generator()
+
+    def _capture_streaming_event(
+        self,
+        model: str,
+        contents,
+        distinct_id: Optional[str],
+        trace_id: Optional[str],
+        properties: Optional[Dict[str, Any]],
+        privacy_mode: bool,
+        groups: Optional[Dict[str, Any]],
+        kwargs: Dict[str, Any],
+        usage_stats: TokenUsage,
+        latency: float,
+        output: Any,
+    ):
+        # Prepare standardized event data
+        formatted_input = self._format_input(contents, **kwargs)
+        sanitized_input = sanitize_gemini(formatted_input)
+
+        event_data = StreamingEventData(
+            provider="gemini",
+            model=model,
+            base_url=self._base_url,
+            kwargs=kwargs,
+            formatted_input=sanitized_input,
+            formatted_output=format_gemini_streaming_output(output),
+            usage_stats=usage_stats,
+            latency=latency,
+            distinct_id=distinct_id,
+            trace_id=trace_id,
+            properties=properties,
+            privacy_mode=privacy_mode,
+            groups=groups,
+        )
+
+        # Use the common capture function
+        capture_streaming_event(self._ph_client, event_data)
+
+    def _format_input(self, contents, **kwargs):
+        """Format input contents for PostHog tracking"""
+
+        # Create kwargs dict with contents for merge_system_prompt
+        input_kwargs = {"contents": contents, **kwargs}
+        return merge_system_prompt(input_kwargs, "gemini")
+
+    async def generate_content_stream(
+        self,
+        model: str,
+        contents,
+        posthog_distinct_id: Optional[str] = None,
+        posthog_trace_id: Optional[str] = None,
+        posthog_properties: Optional[Dict[str, Any]] = None,
+        posthog_privacy_mode: Optional[bool] = None,
+        posthog_groups: Optional[Dict[str, Any]] = None,
+        **kwargs: Any,
+    ):
+        # Merge PostHog parameters
+        distinct_id, trace_id, properties, privacy_mode, groups = (
+            self._merge_posthog_params(
+                posthog_distinct_id,
+                posthog_trace_id,
+                posthog_properties,
+                posthog_privacy_mode,
+                posthog_groups,
+            )
+        )
+
+        return await self._generate_content_streaming(
+            model,
+            contents,
+            distinct_id,
+            trace_id,
+            properties,
+            privacy_mode,
+            groups,
+            **kwargs,
+        )
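For orientation before the converter changes below, here is a minimal usage sketch of the new async wrapper. It is not part of the diff; `ph` is assumed to be an already-configured posthoganalytics Client instance, and the key and prompt values are placeholders.

    import asyncio

    from posthoganalytics.ai.gemini.gemini_async import AsyncClient


    async def main(ph):
        client = AsyncClient(
            api_key="your_google_api_key",      # placeholder
            posthog_client=ph,
            posthog_distinct_id="user_123",     # default for every call
            posthog_properties={"team": "ai"},  # merged into each event
        )

        # Non-streaming: the wrapper captures one usage event after the call.
        await client.models.generate_content(
            model="gemini-2.0-flash",
            contents=["Write a haiku about analytics"],
        )

        # Streaming: usage accumulates per chunk and is captured in the
        # generator's `finally` block once the stream is exhausted.
        stream = await client.models.generate_content_stream(
            model="gemini-2.0-flash",
            contents=["Now stream one"],
        )
        async for chunk in stream:
            pass  # consume chunks; capture happens on exhaustion

Both paths route through _merge_posthog_params, so the per-call posthog_* keyword arguments shown in the class docstring override these client-level defaults.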
posthoganalytics/ai/gemini/gemini_converter.py
@@ -29,35 +29,76 @@ class GeminiMessage(TypedDict, total=False):
     text: str


-def _extract_text_from_parts(parts: List[Any]) -> str:
+def _format_parts_as_content_blocks(parts: List[Any]) -> List[FormattedContentItem]:
     """
-    Extract and concatenate text from a parts array.
+    Format Gemini parts array into structured content blocks.
+
+    Preserves structure for multimodal content (text + images) instead of
+    concatenating everything into a string.

     Args:
-        parts: List of parts that may contain text content
+        parts: List of parts that may contain text, inline_data, etc.

     Returns:
-        Concatenated text from all parts
+        List of formatted content blocks
     """
-
-    content_parts = []
+    content_blocks: List[FormattedContentItem] = []

     for part in parts:
+        # Handle dict with text field
         if isinstance(part, dict) and "text" in part:
-            content_parts.append(part["text"])
+            content_blocks.append({"type": "text", "text": part["text"]})

+        # Handle string parts
         elif isinstance(part, str):
-            content_parts.append(part)
+            content_blocks.append({"type": "text", "text": part})
+
+        # Handle dict with inline_data (images, documents, etc.)
+        elif isinstance(part, dict) and "inline_data" in part:
+            inline_data = part["inline_data"]
+            mime_type = inline_data.get("mime_type", "")
+            content_type = "image" if mime_type.startswith("image/") else "document"
+
+            content_blocks.append(
+                {
+                    "type": content_type,
+                    "inline_data": inline_data,
+                }
+            )

+        # Handle object with text attribute
         elif hasattr(part, "text"):
-            # Get the text attribute value
             text_value = getattr(part, "text", "")
-            content_parts.append(text_value if text_value else str(part))
-
-        else:
-            content_parts.append(str(part))
+            if text_value:
+                content_blocks.append({"type": "text", "text": text_value})
+
+        # Handle object with inline_data attribute
+        elif hasattr(part, "inline_data"):
+            inline_data = part.inline_data
+            # Convert to dict if needed
+            if hasattr(inline_data, "mime_type") and hasattr(inline_data, "data"):
+                # Determine type based on mime_type
+                mime_type = inline_data.mime_type
+                content_type = "image" if mime_type.startswith("image/") else "document"
+
+                content_blocks.append(
+                    {
+                        "type": content_type,
+                        "inline_data": {
+                            "mime_type": mime_type,
+                            "data": inline_data.data,
+                        },
+                    }
+                )
+            else:
+                content_blocks.append(
+                    {
+                        "type": "image",
+                        "inline_data": inline_data,
+                    }
+                )

-    return "".join(content_parts)
+    return content_blocks


 def _format_dict_message(item: Dict[str, Any]) -> FormattedMessage:
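The behavioural change in this hunk: multimodal parts keep their structure instead of being flattened to one string. A hypothetical parts list (values invented) and the output shape implied by the branches above:

    parts = [
        {"text": "Describe this image"},
        "briefly",
        {"inline_data": {"mime_type": "image/png", "data": "<base64>"}},
    ]

    # _format_parts_as_content_blocks(parts) returns:
    # [
    #     {"type": "text", "text": "Describe this image"},
    #     {"type": "text", "text": "briefly"},
    #     {"type": "image", "inline_data": {"mime_type": "image/png", "data": "<base64>"}},
    # ]
    # The removed _extract_text_from_parts would instead have produced one
    # concatenated string, stringifying the image dict via its else branch.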
@@ -73,16 +114,17 @@ def _format_dict_message(item: Dict[str, Any]) -> FormattedMessage:

     # Handle dict format with parts array (Gemini-specific format)
     if "parts" in item and isinstance(item["parts"], list):
-        content = _extract_text_from_parts(item["parts"])
-        return {"role": item.get("role", "user"), "content": content}
+        content_blocks = _format_parts_as_content_blocks(item["parts"])
+        return {"role": item.get("role", "user"), "content": content_blocks}

     # Handle dict with content field
     if "content" in item:
         content = item["content"]

         if isinstance(content, list):
-            # If content is a list, extract text from it
-            content = _extract_text_from_parts(content)
+            # If content is a list, format it as content blocks
+            content_blocks = _format_parts_as_content_blocks(content)
+            return {"role": item.get("role", "user"), "content": content_blocks}

         elif not isinstance(content, str):
             content = str(content)
@@ -110,14 +152,14 @@ def _format_object_message(item: Any) -> FormattedMessage:

     # Handle object with parts attribute
     if hasattr(item, "parts") and hasattr(item.parts, "__iter__"):
-        content = _extract_text_from_parts(item.parts)
+        content_blocks = _format_parts_as_content_blocks(list(item.parts))
         role = getattr(item, "role", "user") if hasattr(item, "role") else "user"

         # Ensure role is a string
         if not isinstance(role, str):
             role = "user"

-        return {"role": role, "content": content}
+        return {"role": role, "content": content_blocks}

     # Handle object with text attribute
     if hasattr(item, "text"):
@@ -140,7 +182,8 @@ def _format_object_message(item: Any) -> FormattedMessage:
         content = item.content

         if isinstance(content, list):
-            content = _extract_text_from_parts(content)
+            content_blocks = _format_parts_as_content_blocks(content)
+            return {"role": role, "content": content_blocks}

         elif not isinstance(content, str):
             content = str(content)
@@ -193,6 +236,29 @@ def format_gemini_response(response: Any) -> List[FormattedMessage]:
                     }
                 )

+            elif hasattr(part, "inline_data") and part.inline_data:
+                # Handle audio/media inline data
+                import base64
+
+                inline_data = part.inline_data
+                mime_type = getattr(inline_data, "mime_type", "audio/pcm")
+                raw_data = getattr(inline_data, "data", b"")
+
+                # Encode binary data as base64 string for JSON serialization
+                if isinstance(raw_data, bytes):
+                    data = base64.b64encode(raw_data).decode("utf-8")
+                else:
+                    # Already a string (base64)
+                    data = raw_data
+
+                content.append(
+                    {
+                        "type": "audio",
+                        "mime_type": mime_type,
+                        "data": data,
+                    }
+                )
+
         if content:
             output.append(
                 {
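The base64 step above is what makes binary audio parts JSON-serializable in the captured event; a standalone illustration (bytes invented):

    import base64

    raw_data = b"\x00\x01\x02\x03"  # e.g. PCM bytes from part.inline_data.data
    data = base64.b64encode(raw_data).decode("utf-8")
    assert data == "AAECAw=="  # plain ASCII, safe inside a JSON event payload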
@@ -338,6 +404,61 @@ def format_gemini_input(contents: Any) -> List[FormattedMessage]:
     return [_format_object_message(contents)]


+def extract_gemini_web_search_count(response: Any) -> int:
+    """
+    Extract web search count from Gemini response.
+
+    Gemini bills per request that uses grounding, not per query.
+    Returns 1 if grounding_metadata is present with actual search data, 0 otherwise.
+
+    Args:
+        response: The response from Gemini API
+
+    Returns:
+        1 if web search/grounding was used, 0 otherwise
+    """
+
+    # Check for grounding_metadata in candidates
+    if hasattr(response, "candidates"):
+        for candidate in response.candidates:
+            if (
+                hasattr(candidate, "grounding_metadata")
+                and candidate.grounding_metadata
+            ):
+                grounding_metadata = candidate.grounding_metadata
+
+                # Check if web_search_queries exists and is non-empty
+                if hasattr(grounding_metadata, "web_search_queries"):
+                    queries = grounding_metadata.web_search_queries
+
+                    if queries is not None and len(queries) > 0:
+                        return 1
+
+                # Check if grounding_chunks exists and is non-empty
+                if hasattr(grounding_metadata, "grounding_chunks"):
+                    chunks = grounding_metadata.grounding_chunks
+
+                    if chunks is not None and len(chunks) > 0:
+                        return 1
+
+            # Also check for google_search or grounding in function call names
+            if hasattr(candidate, "content") and candidate.content:
+                if hasattr(candidate.content, "parts") and candidate.content.parts:
+                    for part in candidate.content.parts:
+                        if hasattr(part, "function_call") and part.function_call:
+                            function_name = getattr(
+                                part.function_call, "name", ""
+                            ).lower()
+
+                            if (
+                                "google_search" in function_name
+                                or "grounding" in function_name
+                            ):
+                                return 1
+
+    return 0
+
+
 def _extract_usage_from_metadata(metadata: Any) -> TokenUsage:
     """
     Common logic to extract usage from Gemini metadata.
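Because Gemini bills grounding per request, the count saturates at 1 no matter how many queries ran. A sketch against stubbed objects, not from the diff (SimpleNamespace stands in for the SDK's response types):

    from types import SimpleNamespace

    from posthoganalytics.ai.gemini.gemini_converter import (
        extract_gemini_web_search_count,
    )

    # Two grounded queries still count as one billable search.
    grounded = SimpleNamespace(
        candidates=[
            SimpleNamespace(
                grounding_metadata=SimpleNamespace(
                    web_search_queries=["q1", "q2"],
                    grounding_chunks=None,
                ),
                content=None,
            )
        ]
    )
    assert extract_gemini_web_search_count(grounded) == 1

    # No grounding metadata at all: counts as zero.
    plain = SimpleNamespace(
        candidates=[SimpleNamespace(grounding_metadata=None, content=None)]
    )
    assert extract_gemini_web_search_count(plain) == 0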
@@ -382,7 +503,14 @@ def extract_gemini_usage_from_response(response: Any) -> TokenUsage:
     if not hasattr(response, "usage_metadata") or not response.usage_metadata:
         return TokenUsage(input_tokens=0, output_tokens=0)

-    return _extract_usage_from_metadata(response.usage_metadata)
+    usage = _extract_usage_from_metadata(response.usage_metadata)
+
+    # Add web search count if present
+    web_search_count = extract_gemini_web_search_count(response)
+    if web_search_count > 0:
+        usage["web_search_count"] = web_search_count
+
+    return usage


 def extract_gemini_usage_from_chunk(chunk: Any) -> TokenUsage:
@@ -398,11 +526,19 @@

     usage: TokenUsage = TokenUsage()

+    # Extract web search count from the chunk before checking for usage_metadata
+    # Web search indicators can appear on any chunk, not just those with usage data
+    web_search_count = extract_gemini_web_search_count(chunk)
+    if web_search_count > 0:
+        usage["web_search_count"] = web_search_count
+
     if not hasattr(chunk, "usage_metadata") or not chunk.usage_metadata:
         return usage

-    # Use the shared helper to extract usage
-    usage = _extract_usage_from_metadata(chunk.usage_metadata)
+    usage_from_metadata = _extract_usage_from_metadata(chunk.usage_metadata)
+
+    # Merge the usage from metadata with any web search count we found
+    usage.update(usage_from_metadata)

     return usage
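Note the merge order in this last hunk: web_search_count is written into usage before the usage_metadata check, and usage.update() then layers the token counts on top without disturbing it. A schematic of the result, assuming TokenUsage behaves as a plain mapping (which is how this file manipulates it; values invented):

    usage = {"web_search_count": 1}  # set from grounding indicators
    usage.update({"input_tokens": 42, "output_tokens": 7})  # from usage_metadata
    assert usage == {"web_search_count": 1, "input_tokens": 42, "output_tokens": 7}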