posthog 6.7.2__py3-none-any.whl → 6.9.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
posthog/__init__.py CHANGED
@@ -10,7 +10,11 @@ from posthog.contexts import (
  tag as inner_tag,
  set_context_session as inner_set_context_session,
  identify_context as inner_identify_context,
+ set_capture_exception_code_variables_context as inner_set_capture_exception_code_variables_context,
+ set_code_variables_mask_patterns_context as inner_set_code_variables_mask_patterns_context,
+ set_code_variables_ignore_patterns_context as inner_set_code_variables_ignore_patterns_context,
  )
+ from posthog.feature_flags import InconclusiveMatchError, RequiresServerEvaluation
  from posthog.types import FeatureFlag, FlagsAndPayloads, FeatureFlagResult
  from posthog.version import VERSION

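The new import also re-exports the feature-flag exceptions at the top level, so callers no longer need to reach into `posthog.feature_flags`. A minimal sketch of what that enables; the surrounding local-evaluation call is illustrative, not taken from this diff:

```python
# Hypothetical usage: with 6.9.0 these exceptions are importable directly
# from the top-level package thanks to the new re-export in __init__.py.
from posthog import InconclusiveMatchError, RequiresServerEvaluation

def evaluate_flag_locally(matcher):
    try:
        return matcher.evaluate()  # placeholder for a local flag evaluation call
    except (InconclusiveMatchError, RequiresServerEvaluation):
        # Fall back to asking the PostHog server when local data is not enough.
        return None
```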
@@ -19,13 +23,14 @@ __version__ = VERSION
  """Context management."""


- def new_context(fresh=False, capture_exceptions=True):
+ def new_context(fresh=False, capture_exceptions=True, client=None):
  """
  Create a new context scope that will be active for the duration of the with block.

  Args:
  fresh: Whether to start with a fresh context (default: False)
  capture_exceptions: Whether to capture exceptions raised within the context (default: True)
+ client: Optional Posthog client instance to use for this context (default: None)

  Examples:
  ```python
@@ -38,7 +43,9 @@ def new_context(fresh=False, capture_exceptions=True):
  Category:
  Contexts
  """
- return inner_new_context(fresh=fresh, capture_exceptions=capture_exceptions)
+ return inner_new_context(
+ fresh=fresh, capture_exceptions=capture_exceptions, client=client
+ )


  def scoped(fresh=False, capture_exceptions=True):
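With the new `client` parameter, a context can be bound to a specific client instance instead of the module-level default. A minimal sketch, assuming the standard `Posthog` client constructor; the key and host are placeholders:

```python
from posthog import Posthog, new_context, tag

# Assumption: a dedicated client instance; API key and host are placeholders.
analytics = Posthog("phc_project_api_key", host="https://us.i.posthog.com")

with new_context(fresh=True, client=analytics):
    # Tags and captured exceptions inside this block are routed through
    # the `analytics` client rather than the default module-level client.
    tag("job", "nightly-export")
```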
@@ -102,6 +109,27 @@ def identify_context(distinct_id: str):
  return inner_identify_context(distinct_id)


+ def set_capture_exception_code_variables_context(enabled: bool):
+ """
+ Set whether code variables are captured for the current context.
+ """
+ return inner_set_capture_exception_code_variables_context(enabled)
+
+
+ def set_code_variables_mask_patterns_context(mask_patterns: list):
+ """
+ Variable names matching these patterns will be masked with *** when capturing code variables.
+ """
+ return inner_set_code_variables_mask_patterns_context(mask_patterns)
+
+
+ def set_code_variables_ignore_patterns_context(ignore_patterns: list):
+ """
+ Variable names matching these patterns will be ignored completely when capturing code variables.
+ """
+ return inner_set_code_variables_ignore_patterns_context(ignore_patterns)
+
+
  def tag(name: str, value: Any):
  """
  Add a tag to the current context.
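Taken together, the three new setters control how local variables are attached to exceptions captured in the current context. A minimal sketch; the pattern syntax (exact name, wildcard, or regex) is not shown in this diff, so the values below are illustrative assumptions:

```python
from posthog import (
    new_context,
    set_capture_exception_code_variables_context,
    set_code_variables_mask_patterns_context,
    set_code_variables_ignore_patterns_context,
)

with new_context():
    # Opt this context into capturing local variables on exceptions.
    set_capture_exception_code_variables_context(True)
    # Illustrative patterns: mask secrets, drop bulky internals entirely.
    set_code_variables_mask_patterns_context(["password", "api_key"])
    set_code_variables_ignore_patterns_context(["_internal_buffer"])
    ...  # code whose exceptions are captured with (masked) variables
```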
@@ -10,7 +10,7 @@ import time
  import uuid
  from typing import Any, Dict, List, Optional

- from posthog.ai.types import StreamingContentBlock, ToolInProgress
+ from posthog.ai.types import StreamingContentBlock, TokenUsage, ToolInProgress
  from posthog.ai.utils import (
  call_llm_and_track_usage,
  merge_usage_stats,
@@ -126,7 +126,7 @@ class WrappedMessages(Messages):
  **kwargs: Any,
  ):
  start_time = time.time()
- usage_stats: Dict[str, int] = {"input_tokens": 0, "output_tokens": 0}
+ usage_stats: TokenUsage = TokenUsage(input_tokens=0, output_tokens=0)
  accumulated_content = ""
  content_blocks: List[StreamingContentBlock] = []
  tools_in_progress: Dict[str, ToolInProgress] = {}
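The streaming wrappers now accumulate usage in the shared `TokenUsage` type from `posthog.ai.types` instead of a plain `Dict[str, int]`. Its definition is not part of this diff; from the way it is constructed with keyword arguments, indexed, and `update()`d here, it behaves like a `TypedDict` with optional keys, roughly as sketched below (field list inferred from this diff, so treat it as an approximation):

```python
from typing import TypedDict

class TokenUsage(TypedDict, total=False):
    # Inferred from how the diff reads and writes these keys; the real
    # definition lives in posthog.ai.types and may differ.
    input_tokens: int
    output_tokens: int
    cache_read_input_tokens: int
    cache_creation_input_tokens: int
    reasoning_tokens: int
    web_search_count: int

usage: TokenUsage = TokenUsage(input_tokens=0, output_tokens=0)
usage["cache_read_input_tokens"] = 128  # optional keys are added only when non-zero
```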
@@ -210,14 +210,13 @@ class WrappedMessages(Messages):
  posthog_privacy_mode: bool,
  posthog_groups: Optional[Dict[str, Any]],
  kwargs: Dict[str, Any],
- usage_stats: Dict[str, int],
+ usage_stats: TokenUsage,
  latency: float,
  content_blocks: List[StreamingContentBlock],
  accumulated_content: str,
  ):
  from posthog.ai.types import StreamingEventData
  from posthog.ai.anthropic.anthropic_converter import (
- standardize_anthropic_usage,
  format_anthropic_streaming_input,
  format_anthropic_streaming_output_complete,
  )
@@ -236,7 +235,7 @@ class WrappedMessages(Messages):
  formatted_output=format_anthropic_streaming_output_complete(
  content_blocks, accumulated_content
  ),
- usage_stats=standardize_anthropic_usage(usage_stats),
+ usage_stats=usage_stats,
  latency=latency,
  distinct_id=posthog_distinct_id,
  trace_id=posthog_trace_id,
@@ -11,17 +11,12 @@ import uuid
  from typing import Any, Dict, List, Optional

  from posthog import setup
- from posthog.ai.types import StreamingContentBlock, ToolInProgress
+ from posthog.ai.types import StreamingContentBlock, TokenUsage, ToolInProgress
  from posthog.ai.utils import (
  call_llm_and_track_usage_async,
- extract_available_tool_calls,
- get_model_params,
- merge_system_prompt,
  merge_usage_stats,
- with_privacy_mode,
  )
  from posthog.ai.anthropic.anthropic_converter import (
- format_anthropic_streaming_content,
  extract_anthropic_usage_from_event,
  handle_anthropic_content_block_start,
  handle_anthropic_text_delta,
@@ -131,7 +126,7 @@ class AsyncWrappedMessages(AsyncMessages):
  **kwargs: Any,
  ):
  start_time = time.time()
- usage_stats: Dict[str, int] = {"input_tokens": 0, "output_tokens": 0}
+ usage_stats: TokenUsage = TokenUsage(input_tokens=0, output_tokens=0)
  accumulated_content = ""
  content_blocks: List[StreamingContentBlock] = []
  tools_in_progress: Dict[str, ToolInProgress] = {}
@@ -215,71 +210,39 @@ class AsyncWrappedMessages(AsyncMessages):
  posthog_privacy_mode: bool,
  posthog_groups: Optional[Dict[str, Any]],
  kwargs: Dict[str, Any],
- usage_stats: Dict[str, int],
+ usage_stats: TokenUsage,
  latency: float,
  content_blocks: List[StreamingContentBlock],
  accumulated_content: str,
  ):
- if posthog_trace_id is None:
- posthog_trace_id = str(uuid.uuid4())
-
- # Format output using converter
- formatted_content = format_anthropic_streaming_content(content_blocks)
- formatted_output = []
-
- if formatted_content:
- formatted_output = [{"role": "assistant", "content": formatted_content}]
- else:
- # Fallback to accumulated content if no blocks
- formatted_output = [
- {
- "role": "assistant",
- "content": [{"type": "text", "text": accumulated_content}],
- }
- ]
-
- event_properties = {
- "$ai_provider": "anthropic",
- "$ai_model": kwargs.get("model"),
- "$ai_model_parameters": get_model_params(kwargs),
- "$ai_input": with_privacy_mode(
- self._client._ph_client,
- posthog_privacy_mode,
- sanitize_anthropic(merge_system_prompt(kwargs, "anthropic")),
- ),
- "$ai_output_choices": with_privacy_mode(
- self._client._ph_client,
- posthog_privacy_mode,
- formatted_output,
- ),
- "$ai_http_status": 200,
- "$ai_input_tokens": usage_stats.get("input_tokens", 0),
- "$ai_output_tokens": usage_stats.get("output_tokens", 0),
- "$ai_cache_read_input_tokens": usage_stats.get(
- "cache_read_input_tokens", 0
- ),
- "$ai_cache_creation_input_tokens": usage_stats.get(
- "cache_creation_input_tokens", 0
+ from posthog.ai.types import StreamingEventData
+ from posthog.ai.anthropic.anthropic_converter import (
+ format_anthropic_streaming_input,
+ format_anthropic_streaming_output_complete,
+ )
+ from posthog.ai.utils import capture_streaming_event
+
+ # Prepare standardized event data
+ formatted_input = format_anthropic_streaming_input(kwargs)
+ sanitized_input = sanitize_anthropic(formatted_input)
+
+ event_data = StreamingEventData(
+ provider="anthropic",
+ model=kwargs.get("model", "unknown"),
+ base_url=str(self._client.base_url),
+ kwargs=kwargs,
+ formatted_input=sanitized_input,
+ formatted_output=format_anthropic_streaming_output_complete(
+ content_blocks, accumulated_content
  ),
- "$ai_latency": latency,
- "$ai_trace_id": posthog_trace_id,
- "$ai_base_url": str(self._client.base_url),
- **(posthog_properties or {}),
- }
-
- # Add tools if available
- available_tools = extract_available_tool_calls("anthropic", kwargs)
-
- if available_tools:
- event_properties["$ai_tools"] = available_tools
-
- if posthog_distinct_id is None:
- event_properties["$process_person_profile"] = False
-
- if hasattr(self._client._ph_client, "capture"):
- self._client._ph_client.capture(
- distinct_id=posthog_distinct_id or posthog_trace_id,
- event="$ai_generation",
- properties=event_properties,
- groups=posthog_groups,
- )
+ usage_stats=usage_stats,
+ latency=latency,
+ distinct_id=posthog_distinct_id,
+ trace_id=posthog_trace_id,
+ properties=posthog_properties,
+ privacy_mode=posthog_privacy_mode,
+ groups=posthog_groups,
+ )
+
+ # Use the common capture function
+ capture_streaming_event(self._client._ph_client, event_data)
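Both Anthropic wrappers (and the Gemini wrapper below) now funnel streaming analytics through the same two pieces: a `StreamingEventData` record and `capture_streaming_event`, replacing the hand-built `$ai_generation` property dict. A condensed sketch of that call pattern, using only the fields visible at the call sites in this diff (the actual `StreamingEventData` definition may carry more):

```python
# Sketch of the shared capture path introduced in 6.9.0; field names are
# copied from the call sites in this diff, not from the type's definition.
from posthog.ai.types import StreamingEventData, TokenUsage
from posthog.ai.utils import capture_streaming_event

def emit_stream_event(ph_client, kwargs, formatted_input, formatted_output,
                      usage: TokenUsage, latency: float):
    event_data = StreamingEventData(
        provider="anthropic",
        model=kwargs.get("model", "unknown"),
        base_url="https://api.anthropic.com",  # placeholder; the wrappers pass client.base_url
        kwargs=kwargs,
        formatted_input=formatted_input,
        formatted_output=formatted_output,
        usage_stats=usage,
        latency=latency,
        distinct_id=None,
        trace_id=None,
        properties=None,
        privacy_mode=False,
        groups=None,
    )
    capture_streaming_event(ph_client, event_data)
```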
@@ -14,7 +14,6 @@ from posthog.ai.types import (
  FormattedMessage,
  FormattedTextContent,
  StreamingContentBlock,
- StreamingUsageStats,
  TokenUsage,
  ToolInProgress,
  )
@@ -164,7 +163,68 @@ def format_anthropic_streaming_content(
  return formatted


- def extract_anthropic_usage_from_event(event: Any) -> StreamingUsageStats:
+ def extract_anthropic_web_search_count(response: Any) -> int:
+ """
+ Extract web search count from Anthropic response.
+
+ Anthropic provides exact web search counts via usage.server_tool_use.web_search_requests.
+
+ Args:
+ response: The response from Anthropic API
+
+ Returns:
+ Number of web search requests (0 if none)
+ """
+ if not hasattr(response, "usage"):
+ return 0
+
+ if not hasattr(response.usage, "server_tool_use"):
+ return 0
+
+ server_tool_use = response.usage.server_tool_use
+
+ if hasattr(server_tool_use, "web_search_requests"):
+ return max(0, int(getattr(server_tool_use, "web_search_requests", 0)))
+
+ return 0
+
+
+ def extract_anthropic_usage_from_response(response: Any) -> TokenUsage:
+ """
+ Extract usage from a full Anthropic response (non-streaming).
+
+ Args:
+ response: The complete response from Anthropic API
+
+ Returns:
+ TokenUsage with standardized usage
+ """
+ if not hasattr(response, "usage"):
+ return TokenUsage(input_tokens=0, output_tokens=0)
+
+ result = TokenUsage(
+ input_tokens=getattr(response.usage, "input_tokens", 0),
+ output_tokens=getattr(response.usage, "output_tokens", 0),
+ )
+
+ if hasattr(response.usage, "cache_read_input_tokens"):
+ cache_read = response.usage.cache_read_input_tokens
+ if cache_read and cache_read > 0:
+ result["cache_read_input_tokens"] = cache_read
+
+ if hasattr(response.usage, "cache_creation_input_tokens"):
+ cache_creation = response.usage.cache_creation_input_tokens
+ if cache_creation and cache_creation > 0:
+ result["cache_creation_input_tokens"] = cache_creation
+
+ web_search_count = extract_anthropic_web_search_count(response)
+ if web_search_count > 0:
+ result["web_search_count"] = web_search_count
+
+ return result
+
+
+ def extract_anthropic_usage_from_event(event: Any) -> TokenUsage:
  """
  Extract usage statistics from an Anthropic streaming event.

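The new response-level extractor relies only on attribute access, so it can be exercised with a stand-in object. A minimal sketch; the stand-in mimics just the attributes the function reads, while real Anthropic response objects carry more:

```python
from types import SimpleNamespace

# Assumption: importing from the converter module shown in this diff.
from posthog.ai.anthropic.anthropic_converter import extract_anthropic_usage_from_response

fake_response = SimpleNamespace(
    usage=SimpleNamespace(
        input_tokens=120,
        output_tokens=45,
        cache_read_input_tokens=30,
        cache_creation_input_tokens=0,  # zero values are not added to the result
        server_tool_use=SimpleNamespace(web_search_requests=2),
    )
)

usage = extract_anthropic_usage_from_response(fake_response)
# -> {"input_tokens": 120, "output_tokens": 45,
#     "cache_read_input_tokens": 30, "web_search_count": 2}
```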
@@ -175,7 +235,7 @@ def extract_anthropic_usage_from_event(event: Any) -> StreamingUsageStats:
  Dictionary of usage statistics
  """

- usage: StreamingUsageStats = {}
+ usage: TokenUsage = TokenUsage()

  # Handle usage stats from message_start event
  if hasattr(event, "type") and event.type == "message_start":
@@ -192,6 +252,16 @@ def extract_anthropic_usage_from_event(event: Any) -> StreamingUsageStats:
  if hasattr(event, "usage") and event.usage:
  usage["output_tokens"] = getattr(event.usage, "output_tokens", 0)

+ # Extract web search count from usage
+ if hasattr(event.usage, "server_tool_use"):
+ server_tool_use = event.usage.server_tool_use
+ if hasattr(server_tool_use, "web_search_requests"):
+ web_search_count = int(
+ getattr(server_tool_use, "web_search_requests", 0)
+ )
+ if web_search_count > 0:
+ usage["web_search_count"] = web_search_count
+
  return usage

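For the streaming path, the new lines read the same `server_tool_use.web_search_requests` counter off the event's `usage` object. A small stand-alone illustration of the attribute shape the new branch expects (values are made up; real events come from the Anthropic SDK):

```python
from types import SimpleNamespace

# Shape of the usage object the new branch inspects on a streaming event.
event_usage = SimpleNamespace(
    output_tokens=57,
    server_tool_use=SimpleNamespace(web_search_requests=1),
)

usage = {"output_tokens": event_usage.output_tokens}
if hasattr(event_usage, "server_tool_use"):
    count = int(getattr(event_usage.server_tool_use, "web_search_requests", 0))
    if count > 0:
        usage["web_search_count"] = count
# usage == {"output_tokens": 57, "web_search_count": 1}
```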
@@ -329,26 +399,6 @@ def finalize_anthropic_tool_input(
  del tools_in_progress[block["id"]]


- def standardize_anthropic_usage(usage: Dict[str, Any]) -> TokenUsage:
- """
- Standardize Anthropic usage statistics to common TokenUsage format.
-
- Anthropic already uses standard field names, so this mainly structures the data.
-
- Args:
- usage: Raw usage statistics from Anthropic
-
- Returns:
- Standardized TokenUsage dict
- """
- return TokenUsage(
- input_tokens=usage.get("input_tokens", 0),
- output_tokens=usage.get("output_tokens", 0),
- cache_read_input_tokens=usage.get("cache_read_input_tokens"),
- cache_creation_input_tokens=usage.get("cache_creation_input_tokens"),
- )
-
-
  def format_anthropic_streaming_input(kwargs: Dict[str, Any]) -> Any:
  """
  Format Anthropic streaming input using system prompt merging.
@@ -3,6 +3,9 @@ import time
  import uuid
  from typing import Any, Dict, Optional

+ from posthog.ai.types import TokenUsage, StreamingEventData
+ from posthog.ai.utils import merge_system_prompt
+
  try:
  from google import genai
  except ImportError:
@@ -17,7 +20,6 @@ from posthog.ai.utils import (
  merge_usage_stats,
  )
  from posthog.ai.gemini.gemini_converter import (
- format_gemini_input,
  extract_gemini_usage_from_chunk,
  extract_gemini_content_from_chunk,
  format_gemini_streaming_output,
@@ -294,7 +296,7 @@ class Models:
  **kwargs: Any,
  ):
  start_time = time.time()
- usage_stats: Dict[str, int] = {"input_tokens": 0, "output_tokens": 0}
+ usage_stats: TokenUsage = TokenUsage(input_tokens=0, output_tokens=0)
  accumulated_content = []

  kwargs_without_stream = {"model": model, "contents": contents, **kwargs}
@@ -350,15 +352,12 @@ class Models:
  privacy_mode: bool,
  groups: Optional[Dict[str, Any]],
  kwargs: Dict[str, Any],
- usage_stats: Dict[str, int],
+ usage_stats: TokenUsage,
  latency: float,
  output: Any,
  ):
- from posthog.ai.types import StreamingEventData
- from posthog.ai.gemini.gemini_converter import standardize_gemini_usage
-
  # Prepare standardized event data
- formatted_input = self._format_input(contents)
+ formatted_input = self._format_input(contents, **kwargs)
  sanitized_input = sanitize_gemini(formatted_input)
  event_data = StreamingEventData(
@@ -368,7 +367,7 @@ class Models:
  kwargs=kwargs,
  formatted_input=sanitized_input,
  formatted_output=format_gemini_streaming_output(output),
- usage_stats=standardize_gemini_usage(usage_stats),
+ usage_stats=usage_stats,
  latency=latency,
  distinct_id=distinct_id,
  trace_id=trace_id,
@@ -380,10 +379,12 @@ class Models:
  # Use the common capture function
  capture_streaming_event(self._ph_client, event_data)

- def _format_input(self, contents):
+ def _format_input(self, contents, **kwargs):
  """Format input contents for PostHog tracking"""

- return format_gemini_input(contents)
+ # Create kwargs dict with contents for merge_system_prompt
+ input_kwargs = {"contents": contents, **kwargs}
+ return merge_system_prompt(input_kwargs, "gemini")

  def generate_content_stream(
  self,
@@ -10,7 +10,6 @@ from typing import Any, Dict, List, Optional, TypedDict, Union
  from posthog.ai.types import (
  FormattedContentItem,
  FormattedMessage,
- StreamingUsageStats,
  TokenUsage,
  )

@@ -221,6 +220,30 @@ def format_gemini_response(response: Any) -> List[FormattedMessage]:
  return output


+ def extract_gemini_system_instruction(config: Any) -> Optional[str]:
+ """
+ Extract system instruction from Gemini config parameter.
+
+ Args:
+ config: Config object or dict that may contain system instruction
+
+ Returns:
+ System instruction string if present, None otherwise
+ """
+ if config is None:
+ return None
+
+ # Handle different config formats
+ if hasattr(config, "system_instruction"):
+ return config.system_instruction
+ elif isinstance(config, dict) and "system_instruction" in config:
+ return config["system_instruction"]
+ elif isinstance(config, dict) and "systemInstruction" in config:
+ return config["systemInstruction"]
+
+ return None
+
+
  def extract_gemini_tools(kwargs: Dict[str, Any]) -> Optional[Any]:
  """
  Extract tool definitions from Gemini API kwargs.
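The new helper accepts either the SDK's config object or a plain dict, with both snake_case and camelCase keys. A quick sketch of the dict paths, with expected results shown as comments:

```python
from posthog.ai.gemini.gemini_converter import extract_gemini_system_instruction

extract_gemini_system_instruction({"system_instruction": "Be brief."})  # -> "Be brief."
extract_gemini_system_instruction({"systemInstruction": "Be brief."})   # -> "Be brief."
extract_gemini_system_instruction({"temperature": 0.2})                 # -> None
extract_gemini_system_instruction(None)                                 # -> None
```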
@@ -238,6 +261,38 @@ def extract_gemini_tools(kwargs: Dict[str, Any]) -> Optional[Any]:
  return None


+ def format_gemini_input_with_system(
+ contents: Any, config: Any = None
+ ) -> List[FormattedMessage]:
+ """
+ Format Gemini input contents into standardized message format, including system instruction handling.
+
+ Args:
+ contents: Input contents in various possible formats
+ config: Config object or dict that may contain system instruction
+
+ Returns:
+ List of formatted messages with role and content fields, with system message prepended if needed
+ """
+ formatted_messages = format_gemini_input(contents)
+
+ # Check if system instruction is provided in config parameter
+ system_instruction = extract_gemini_system_instruction(config)
+
+ if system_instruction is not None:
+ has_system = any(msg.get("role") == "system" for msg in formatted_messages)
+ if not has_system:
+ from posthog.ai.types import FormattedMessage
+
+ system_message: FormattedMessage = {
+ "role": "system",
+ "content": system_instruction,
+ }
+ formatted_messages = [system_message] + list(formatted_messages)
+
+ return formatted_messages
+
+
  def format_gemini_input(contents: Any) -> List[FormattedMessage]:
  """
  Format Gemini input contents into standardized message format for PostHog tracking.
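A short usage sketch for the new wrapper. The exact shape of the formatted user messages depends on `format_gemini_input`, which is only partially shown in this diff, so the expected output is limited to the prepended system message:

```python
from posthog.ai.gemini.gemini_converter import format_gemini_input_with_system

contents = [{"role": "user", "parts": [{"text": "What changed in 6.9.0?"}]}]
config = {"system_instruction": "Answer in one sentence."}

messages = format_gemini_input_with_system(contents, config)
# messages[0] == {"role": "system", "content": "Answer in one sentence."}
# messages[1:] are the user turns as formatted by format_gemini_input(contents)
```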
@@ -283,7 +338,116 @@ def format_gemini_input(contents: Any) -> List[FormattedMessage]:
  return [_format_object_message(contents)]


- def extract_gemini_usage_from_chunk(chunk: Any) -> StreamingUsageStats:
+ def extract_gemini_web_search_count(response: Any) -> int:
+ """
+ Extract web search count from Gemini response.
+
+ Gemini bills per request that uses grounding, not per query.
+ Returns 1 if grounding_metadata is present with actual search data, 0 otherwise.
+
+ Args:
+ response: The response from Gemini API
+
+ Returns:
+ 1 if web search/grounding was used, 0 otherwise
+ """
+
+ # Check for grounding_metadata in candidates
+ if hasattr(response, "candidates"):
+ for candidate in response.candidates:
+ if (
+ hasattr(candidate, "grounding_metadata")
+ and candidate.grounding_metadata
+ ):
+ grounding_metadata = candidate.grounding_metadata
+
+ # Check if web_search_queries exists and is non-empty
+ if hasattr(grounding_metadata, "web_search_queries"):
+ queries = grounding_metadata.web_search_queries
+
+ if queries is not None and len(queries) > 0:
+ return 1
+
+ # Check if grounding_chunks exists and is non-empty
+ if hasattr(grounding_metadata, "grounding_chunks"):
+ chunks = grounding_metadata.grounding_chunks
+
+ if chunks is not None and len(chunks) > 0:
+ return 1
+
+ # Also check for google_search or grounding in function call names
+ if hasattr(candidate, "content") and candidate.content:
+ if hasattr(candidate.content, "parts") and candidate.content.parts:
+ for part in candidate.content.parts:
+ if hasattr(part, "function_call") and part.function_call:
+ function_name = getattr(
+ part.function_call, "name", ""
+ ).lower()
+
+ if (
+ "google_search" in function_name
+ or "grounding" in function_name
+ ):
+ return 1
+
+ return 0
+
+
+ def _extract_usage_from_metadata(metadata: Any) -> TokenUsage:
+ """
+ Common logic to extract usage from Gemini metadata.
+ Used by both streaming and non-streaming paths.
+
+ Args:
+ metadata: usage_metadata from Gemini response or chunk
+
+ Returns:
+ TokenUsage with standardized usage
+ """
+ usage = TokenUsage(
+ input_tokens=getattr(metadata, "prompt_token_count", 0),
+ output_tokens=getattr(metadata, "candidates_token_count", 0),
+ )
+
+ # Add cache tokens if present (don't add if 0)
+ if hasattr(metadata, "cached_content_token_count"):
+ cache_tokens = metadata.cached_content_token_count
+ if cache_tokens and cache_tokens > 0:
+ usage["cache_read_input_tokens"] = cache_tokens
+
+ # Add reasoning tokens if present (don't add if 0)
+ if hasattr(metadata, "thoughts_token_count"):
+ reasoning_tokens = metadata.thoughts_token_count
+ if reasoning_tokens and reasoning_tokens > 0:
+ usage["reasoning_tokens"] = reasoning_tokens
+
+ return usage
+
+
+ def extract_gemini_usage_from_response(response: Any) -> TokenUsage:
+ """
+ Extract usage statistics from a full Gemini response (non-streaming).
+
+ Args:
+ response: The complete response from Gemini API
+
+ Returns:
+ TokenUsage with standardized usage statistics
+ """
+ if not hasattr(response, "usage_metadata") or not response.usage_metadata:
+ return TokenUsage(input_tokens=0, output_tokens=0)
+
+ usage = _extract_usage_from_metadata(response.usage_metadata)
+
+ # Add web search count if present
+ web_search_count = extract_gemini_web_search_count(response)
+ if web_search_count > 0:
+ usage["web_search_count"] = web_search_count
+
+ return usage
+
+
+ def extract_gemini_usage_from_chunk(chunk: Any) -> TokenUsage:
  """
  Extract usage statistics from a Gemini streaming chunk.

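As with the Anthropic helper, the Gemini extractor only reads attributes, so a stand-in object is enough to see how `usage_metadata` field names map onto the standardized keys (values are illustrative):

```python
from types import SimpleNamespace
from posthog.ai.gemini.gemini_converter import extract_gemini_usage_from_response

fake_response = SimpleNamespace(
    usage_metadata=SimpleNamespace(
        prompt_token_count=200,
        candidates_token_count=80,
        cached_content_token_count=50,
        thoughts_token_count=0,  # zero values are skipped
    ),
    candidates=[],  # no grounding metadata -> no web_search_count
)

usage = extract_gemini_usage_from_response(fake_response)
# -> {"input_tokens": 200, "output_tokens": 80, "cache_read_input_tokens": 50}
```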
@@ -291,21 +455,24 @@ def extract_gemini_usage_from_chunk(chunk: Any) -> StreamingUsageStats:
  chunk: Streaming chunk from Gemini API

  Returns:
- Dictionary of usage statistics
+ TokenUsage with standardized usage statistics
  """

- usage: StreamingUsageStats = {}
+ usage: TokenUsage = TokenUsage()
+
+ # Extract web search count from the chunk before checking for usage_metadata
+ # Web search indicators can appear on any chunk, not just those with usage data
+ web_search_count = extract_gemini_web_search_count(chunk)
+ if web_search_count > 0:
+ usage["web_search_count"] = web_search_count

  if not hasattr(chunk, "usage_metadata") or not chunk.usage_metadata:
  return usage

- # Gemini uses prompt_token_count and candidates_token_count
- usage["input_tokens"] = getattr(chunk.usage_metadata, "prompt_token_count", 0)
- usage["output_tokens"] = getattr(chunk.usage_metadata, "candidates_token_count", 0)
+ usage_from_metadata = _extract_usage_from_metadata(chunk.usage_metadata)

- # Calculate total if both values are defined (including 0)
- if "input_tokens" in usage and "output_tokens" in usage:
- usage["total_tokens"] = usage["input_tokens"] + usage["output_tokens"]
+ # Merge the usage from metadata with any web search count we found
+ usage.update(usage_from_metadata)

  return usage

@@ -417,22 +584,3 @@ def format_gemini_streaming_output(

  # Fallback for empty or unexpected input
  return [{"role": "assistant", "content": [{"type": "text", "text": ""}]}]
-
-
- def standardize_gemini_usage(usage: Dict[str, Any]) -> TokenUsage:
- """
- Standardize Gemini usage statistics to common TokenUsage format.
-
- Gemini already uses standard field names (input_tokens/output_tokens).
-
- Args:
- usage: Raw usage statistics from Gemini
-
- Returns:
- Standardized TokenUsage dict
- """
- return TokenUsage(
- input_tokens=usage.get("input_tokens", 0),
- output_tokens=usage.get("output_tokens", 0),
- # Gemini doesn't currently support cache or reasoning tokens
- )