posthoganalytics 6.7.5__py3-none-any.whl → 7.4.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- posthoganalytics/__init__.py +84 -7
- posthoganalytics/ai/anthropic/anthropic_async.py +30 -67
- posthoganalytics/ai/anthropic/anthropic_converter.py +40 -0
- posthoganalytics/ai/gemini/__init__.py +3 -0
- posthoganalytics/ai/gemini/gemini.py +1 -1
- posthoganalytics/ai/gemini/gemini_async.py +423 -0
- posthoganalytics/ai/gemini/gemini_converter.py +160 -24
- posthoganalytics/ai/langchain/callbacks.py +55 -11
- posthoganalytics/ai/openai/openai.py +27 -2
- posthoganalytics/ai/openai/openai_async.py +49 -5
- posthoganalytics/ai/openai/openai_converter.py +130 -0
- posthoganalytics/ai/sanitization.py +27 -5
- posthoganalytics/ai/types.py +1 -0
- posthoganalytics/ai/utils.py +32 -2
- posthoganalytics/client.py +338 -90
- posthoganalytics/contexts.py +81 -0
- posthoganalytics/exception_utils.py +250 -2
- posthoganalytics/feature_flags.py +26 -10
- posthoganalytics/flag_definition_cache.py +127 -0
- posthoganalytics/integrations/django.py +149 -50
- posthoganalytics/request.py +203 -23
- posthoganalytics/test/test_client.py +250 -22
- posthoganalytics/test/test_exception_capture.py +418 -0
- posthoganalytics/test/test_feature_flag_result.py +441 -2
- posthoganalytics/test/test_feature_flags.py +306 -102
- posthoganalytics/test/test_flag_definition_cache.py +612 -0
- posthoganalytics/test/test_module.py +0 -8
- posthoganalytics/test/test_request.py +536 -0
- posthoganalytics/test/test_utils.py +4 -1
- posthoganalytics/types.py +40 -0
- posthoganalytics/version.py +1 -1
- {posthoganalytics-6.7.5.dist-info → posthoganalytics-7.4.3.dist-info}/METADATA +12 -12
- posthoganalytics-7.4.3.dist-info/RECORD +57 -0
- posthoganalytics-6.7.5.dist-info/RECORD +0 -54
- {posthoganalytics-6.7.5.dist-info → posthoganalytics-7.4.3.dist-info}/WHEEL +0 -0
- {posthoganalytics-6.7.5.dist-info → posthoganalytics-7.4.3.dist-info}/licenses/LICENSE +0 -0
- {posthoganalytics-6.7.5.dist-info → posthoganalytics-7.4.3.dist-info}/top_level.txt +0 -0
posthoganalytics/ai/gemini/gemini_async.py (new file)
@@ -0,0 +1,423 @@
+import os
+import time
+import uuid
+from typing import Any, Dict, Optional
+
+from posthoganalytics.ai.types import TokenUsage, StreamingEventData
+from posthoganalytics.ai.utils import merge_system_prompt
+
+try:
+    from google import genai
+except ImportError:
+    raise ModuleNotFoundError(
+        "Please install the Google Gemini SDK to use this feature: 'pip install google-genai'"
+    )
+
+from posthoganalytics import setup
+from posthoganalytics.ai.utils import (
+    call_llm_and_track_usage_async,
+    capture_streaming_event,
+    merge_usage_stats,
+)
+from posthoganalytics.ai.gemini.gemini_converter import (
+    extract_gemini_usage_from_chunk,
+    extract_gemini_content_from_chunk,
+    format_gemini_streaming_output,
+)
+from posthoganalytics.ai.sanitization import sanitize_gemini
+from posthoganalytics.client import Client as PostHogClient
+
+
+class AsyncClient:
+    """
+    An async drop-in replacement for genai.Client that automatically sends LLM usage events to PostHog.
+
+    Usage:
+        client = AsyncClient(
+            api_key="your_api_key",
+            posthog_client=posthog_client,
+            posthog_distinct_id="default_user",  # Optional defaults
+            posthog_properties={"team": "ai"}  # Optional defaults
+        )
+        response = await client.models.generate_content(
+            model="gemini-2.0-flash",
+            contents=["Hello world"],
+            posthog_distinct_id="specific_user"  # Override default
+        )
+    """
+
+    _ph_client: PostHogClient
+
+    def __init__(
+        self,
+        api_key: Optional[str] = None,
+        vertexai: Optional[bool] = None,
+        credentials: Optional[Any] = None,
+        project: Optional[str] = None,
+        location: Optional[str] = None,
+        debug_config: Optional[Any] = None,
+        http_options: Optional[Any] = None,
+        posthog_client: Optional[PostHogClient] = None,
+        posthog_distinct_id: Optional[str] = None,
+        posthog_properties: Optional[Dict[str, Any]] = None,
+        posthog_privacy_mode: bool = False,
+        posthog_groups: Optional[Dict[str, Any]] = None,
+        **kwargs,
+    ):
+        """
+        Args:
+            api_key: Google AI API key. If not provided, will use GOOGLE_API_KEY or API_KEY environment variable (not required for Vertex AI)
+            vertexai: Whether to use Vertex AI authentication
+            credentials: Vertex AI credentials object
+            project: GCP project ID for Vertex AI
+            location: GCP location for Vertex AI
+            debug_config: Debug configuration for the client
+            http_options: HTTP options for the client
+            posthog_client: PostHog client for tracking usage
+            posthog_distinct_id: Default distinct ID for all calls (can be overridden per call)
+            posthog_properties: Default properties for all calls (can be overridden per call)
+            posthog_privacy_mode: Default privacy mode for all calls (can be overridden per call)
+            posthog_groups: Default groups for all calls (can be overridden per call)
+            **kwargs: Additional arguments (for future compatibility)
+        """
+
+        self._ph_client = posthog_client or setup()
+
+        if self._ph_client is None:
+            raise ValueError("posthog_client is required for PostHog tracking")
+
+        self.models = AsyncModels(
+            api_key=api_key,
+            vertexai=vertexai,
+            credentials=credentials,
+            project=project,
+            location=location,
+            debug_config=debug_config,
+            http_options=http_options,
+            posthog_client=self._ph_client,
+            posthog_distinct_id=posthog_distinct_id,
+            posthog_properties=posthog_properties,
+            posthog_privacy_mode=posthog_privacy_mode,
+            posthog_groups=posthog_groups,
+            **kwargs,
+        )
+
+
+class AsyncModels:
+    """
+    Async Models interface that mimics genai.Client().aio.models with PostHog tracking.
+    """
+
+    _ph_client: PostHogClient  # Not None after __init__ validation
+
+    def __init__(
+        self,
+        api_key: Optional[str] = None,
+        vertexai: Optional[bool] = None,
+        credentials: Optional[Any] = None,
+        project: Optional[str] = None,
+        location: Optional[str] = None,
+        debug_config: Optional[Any] = None,
+        http_options: Optional[Any] = None,
+        posthog_client: Optional[PostHogClient] = None,
+        posthog_distinct_id: Optional[str] = None,
+        posthog_properties: Optional[Dict[str, Any]] = None,
+        posthog_privacy_mode: bool = False,
+        posthog_groups: Optional[Dict[str, Any]] = None,
+        **kwargs,
+    ):
+        """
+        Args:
+            api_key: Google AI API key. If not provided, will use GOOGLE_API_KEY or API_KEY environment variable (not required for Vertex AI)
+            vertexai: Whether to use Vertex AI authentication
+            credentials: Vertex AI credentials object
+            project: GCP project ID for Vertex AI
+            location: GCP location for Vertex AI
+            debug_config: Debug configuration for the client
+            http_options: HTTP options for the client
+            posthog_client: PostHog client for tracking usage
+            posthog_distinct_id: Default distinct ID for all calls
+            posthog_properties: Default properties for all calls
+            posthog_privacy_mode: Default privacy mode for all calls
+            posthog_groups: Default groups for all calls
+            **kwargs: Additional arguments (for future compatibility)
+        """
+
+        self._ph_client = posthog_client or setup()
+
+        if self._ph_client is None:
+            raise ValueError("posthog_client is required for PostHog tracking")
+
+        # Store default PostHog settings
+        self._default_distinct_id = posthog_distinct_id
+        self._default_properties = posthog_properties or {}
+        self._default_privacy_mode = posthog_privacy_mode
+        self._default_groups = posthog_groups
+
+        # Build genai.Client arguments
+        client_args: Dict[str, Any] = {}
+
+        # Add Vertex AI parameters if provided
+        if vertexai is not None:
+            client_args["vertexai"] = vertexai
+
+        if credentials is not None:
+            client_args["credentials"] = credentials
+
+        if project is not None:
+            client_args["project"] = project
+
+        if location is not None:
+            client_args["location"] = location
+
+        if debug_config is not None:
+            client_args["debug_config"] = debug_config
+
+        if http_options is not None:
+            client_args["http_options"] = http_options
+
+        # Handle API key authentication
+        if vertexai:
+            # For Vertex AI, api_key is optional
+            if api_key is not None:
+                client_args["api_key"] = api_key
+        else:
+            # For non-Vertex AI mode, api_key is required (backwards compatibility)
+            if api_key is None:
+                api_key = os.environ.get("GOOGLE_API_KEY") or os.environ.get("API_KEY")
+
+            if api_key is None:
+                raise ValueError(
+                    "API key must be provided either as parameter or via GOOGLE_API_KEY/API_KEY environment variable"
+                )
+
+            client_args["api_key"] = api_key
+
+        self._client = genai.Client(**client_args)
+        self._base_url = "https://generativelanguage.googleapis.com"
+
+    def _merge_posthog_params(
+        self,
+        call_distinct_id: Optional[str],
+        call_trace_id: Optional[str],
+        call_properties: Optional[Dict[str, Any]],
+        call_privacy_mode: Optional[bool],
+        call_groups: Optional[Dict[str, Any]],
+    ):
+        """Merge call-level PostHog parameters with client defaults."""
+
+        # Use call-level values if provided, otherwise fall back to defaults
+        distinct_id = (
+            call_distinct_id
+            if call_distinct_id is not None
+            else self._default_distinct_id
+        )
+        privacy_mode = (
+            call_privacy_mode
+            if call_privacy_mode is not None
+            else self._default_privacy_mode
+        )
+        groups = call_groups if call_groups is not None else self._default_groups
+
+        # Merge properties: default properties + call properties (call properties override)
+        properties = dict(self._default_properties)
+
+        if call_properties:
+            properties.update(call_properties)
+
+        if call_trace_id is None:
+            call_trace_id = str(uuid.uuid4())
+
+        return distinct_id, call_trace_id, properties, privacy_mode, groups
+
+    async def generate_content(
+        self,
+        model: str,
+        contents,
+        posthog_distinct_id: Optional[str] = None,
+        posthog_trace_id: Optional[str] = None,
+        posthog_properties: Optional[Dict[str, Any]] = None,
+        posthog_privacy_mode: Optional[bool] = None,
+        posthog_groups: Optional[Dict[str, Any]] = None,
+        **kwargs: Any,
+    ):
+        """
+        Generate content using Gemini's API while tracking usage in PostHog.
+
+        This method signature exactly matches genai.Client().aio.models.generate_content()
+        with additional PostHog tracking parameters.
+
+        Args:
+            model: The model to use (e.g., 'gemini-2.0-flash')
+            contents: The input content for generation
+            posthog_distinct_id: ID to associate with the usage event (overrides client default)
+            posthog_trace_id: Trace UUID for linking events (auto-generated if not provided)
+            posthog_properties: Extra properties to include in the event (merged with client defaults)
+            posthog_privacy_mode: Whether to redact sensitive information (overrides client default)
+            posthog_groups: Group analytics properties (overrides client default)
+            **kwargs: Arguments passed to Gemini's generate_content
+        """
+
+        # Merge PostHog parameters
+        distinct_id, trace_id, properties, privacy_mode, groups = (
+            self._merge_posthog_params(
+                posthog_distinct_id,
+                posthog_trace_id,
+                posthog_properties,
+                posthog_privacy_mode,
+                posthog_groups,
+            )
+        )
+
+        kwargs_with_contents = {"model": model, "contents": contents, **kwargs}
+
+        return await call_llm_and_track_usage_async(
+            distinct_id,
+            self._ph_client,
+            "gemini",
+            trace_id,
+            properties,
+            privacy_mode,
+            groups,
+            self._base_url,
+            self._client.aio.models.generate_content,
+            **kwargs_with_contents,
+        )
+
+    async def _generate_content_streaming(
+        self,
+        model: str,
+        contents,
+        distinct_id: Optional[str],
+        trace_id: Optional[str],
+        properties: Optional[Dict[str, Any]],
+        privacy_mode: bool,
+        groups: Optional[Dict[str, Any]],
+        **kwargs: Any,
+    ):
+        start_time = time.time()
+        usage_stats: TokenUsage = TokenUsage(input_tokens=0, output_tokens=0)
+        accumulated_content = []
+
+        kwargs_without_stream = {"model": model, "contents": contents, **kwargs}
+        response = await self._client.aio.models.generate_content_stream(
+            **kwargs_without_stream
+        )
+
+        async def async_generator():
+            nonlocal usage_stats
+            nonlocal accumulated_content
+
+            try:
+                async for chunk in response:
+                    # Extract usage stats from chunk
+                    chunk_usage = extract_gemini_usage_from_chunk(chunk)
+
+                    if chunk_usage:
+                        # Gemini reports cumulative totals, not incremental values
+                        merge_usage_stats(usage_stats, chunk_usage, mode="cumulative")
+
+                    # Extract content from chunk (now returns content blocks)
+                    content_block = extract_gemini_content_from_chunk(chunk)
+
+                    if content_block is not None:
+                        accumulated_content.append(content_block)
+
+                    yield chunk
+
+            finally:
+                end_time = time.time()
+                latency = end_time - start_time
+
+                self._capture_streaming_event(
+                    model,
+                    contents,
+                    distinct_id,
+                    trace_id,
+                    properties,
+                    privacy_mode,
+                    groups,
+                    kwargs,
+                    usage_stats,
+                    latency,
+                    accumulated_content,
+                )
+
+        return async_generator()
+
+    def _capture_streaming_event(
+        self,
+        model: str,
+        contents,
+        distinct_id: Optional[str],
+        trace_id: Optional[str],
+        properties: Optional[Dict[str, Any]],
+        privacy_mode: bool,
+        groups: Optional[Dict[str, Any]],
+        kwargs: Dict[str, Any],
+        usage_stats: TokenUsage,
+        latency: float,
+        output: Any,
+    ):
+        # Prepare standardized event data
+        formatted_input = self._format_input(contents, **kwargs)
+        sanitized_input = sanitize_gemini(formatted_input)
+
+        event_data = StreamingEventData(
+            provider="gemini",
+            model=model,
+            base_url=self._base_url,
+            kwargs=kwargs,
+            formatted_input=sanitized_input,
+            formatted_output=format_gemini_streaming_output(output),
+            usage_stats=usage_stats,
+            latency=latency,
+            distinct_id=distinct_id,
+            trace_id=trace_id,
+            properties=properties,
+            privacy_mode=privacy_mode,
+            groups=groups,
+        )
+
+        # Use the common capture function
+        capture_streaming_event(self._ph_client, event_data)
+
+    def _format_input(self, contents, **kwargs):
+        """Format input contents for PostHog tracking"""
+
+        # Create kwargs dict with contents for merge_system_prompt
+        input_kwargs = {"contents": contents, **kwargs}
+        return merge_system_prompt(input_kwargs, "gemini")
+
+    async def generate_content_stream(
+        self,
+        model: str,
+        contents,
+        posthog_distinct_id: Optional[str] = None,
+        posthog_trace_id: Optional[str] = None,
+        posthog_properties: Optional[Dict[str, Any]] = None,
+        posthog_privacy_mode: Optional[bool] = None,
+        posthog_groups: Optional[Dict[str, Any]] = None,
+        **kwargs: Any,
+    ):
+        # Merge PostHog parameters
+        distinct_id, trace_id, properties, privacy_mode, groups = (
+            self._merge_posthog_params(
+                posthog_distinct_id,
+                posthog_trace_id,
+                posthog_properties,
+                posthog_privacy_mode,
+                posthog_groups,
+            )
+        )
+
+        return await self._generate_content_streaming(
+            model,
+            contents,
+            distinct_id,
+            trace_id,
+            properties,
+            privacy_mode,
+            groups,
+            **kwargs,
+        )
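A minimal usage sketch of the new async wrapper, pieced together from the docstrings above. The keys are placeholders, the PostHog Client constructor call and the chunk attributes are assumptions about the posthoganalytics and google-genai APIs, so treat this as orientation rather than documented usage:

import asyncio

from posthoganalytics.client import Client as PostHogClient
from posthoganalytics.ai.gemini.gemini_async import AsyncClient


async def main():
    posthog = PostHogClient("phc_your_project_api_key")  # placeholder project key

    client = AsyncClient(
        api_key="your_google_api_key",  # placeholder; GOOGLE_API_KEY/API_KEY env vars also work
        posthog_client=posthog,
        posthog_distinct_id="default_user",
    )

    # Non-streaming: tracked through call_llm_and_track_usage_async
    response = await client.models.generate_content(
        model="gemini-2.0-flash",
        contents=["Hello world"],
        posthog_distinct_id="specific_user",  # overrides the client default
    )
    print(response.text)

    # Streaming: usage accumulates per chunk; the event is captured in the
    # generator's finally block once the stream is exhausted
    stream = await client.models.generate_content_stream(
        model="gemini-2.0-flash",
        contents=["Tell me a short story"],
    )
    async for chunk in stream:
        print(chunk.text or "", end="")


asyncio.run(main())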
posthoganalytics/ai/gemini/gemini_converter.py
@@ -29,35 +29,76 @@ class GeminiMessage(TypedDict, total=False):
     text: str
 
 
-def
+def _format_parts_as_content_blocks(parts: List[Any]) -> List[FormattedContentItem]:
     """
-
+    Format Gemini parts array into structured content blocks.
+
+    Preserves structure for multimodal content (text + images) instead of
+    concatenating everything into a string.
 
     Args:
-        parts: List of parts that may contain text
+        parts: List of parts that may contain text, inline_data, etc.
 
     Returns:
-
+        List of formatted content blocks
     """
-
-    content_parts = []
+    content_blocks: List[FormattedContentItem] = []
 
     for part in parts:
+        # Handle dict with text field
         if isinstance(part, dict) and "text" in part:
-
+            content_blocks.append({"type": "text", "text": part["text"]})
 
+        # Handle string parts
         elif isinstance(part, str):
-
+            content_blocks.append({"type": "text", "text": part})
+
+        # Handle dict with inline_data (images, documents, etc.)
+        elif isinstance(part, dict) and "inline_data" in part:
+            inline_data = part["inline_data"]
+            mime_type = inline_data.get("mime_type", "")
+            content_type = "image" if mime_type.startswith("image/") else "document"
+
+            content_blocks.append(
+                {
+                    "type": content_type,
+                    "inline_data": inline_data,
+                }
+            )
 
+        # Handle object with text attribute
         elif hasattr(part, "text"):
-            # Get the text attribute value
            text_value = getattr(part, "text", "")
-
-
-
-
+            if text_value:
+                content_blocks.append({"type": "text", "text": text_value})
+
+        # Handle object with inline_data attribute
+        elif hasattr(part, "inline_data"):
+            inline_data = part.inline_data
+            # Convert to dict if needed
+            if hasattr(inline_data, "mime_type") and hasattr(inline_data, "data"):
+                # Determine type based on mime_type
+                mime_type = inline_data.mime_type
+                content_type = "image" if mime_type.startswith("image/") else "document"
+
+                content_blocks.append(
+                    {
+                        "type": content_type,
+                        "inline_data": {
+                            "mime_type": mime_type,
+                            "data": inline_data.data,
+                        },
+                    }
+                )
+            else:
+                content_blocks.append(
+                    {
+                        "type": "image",
+                        "inline_data": inline_data,
+                    }
+                )
 
-    return
+    return content_blocks
 
 
 def _format_dict_message(item: Dict[str, Any]) -> FormattedMessage:
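To make the new output shape concrete, here is an illustrative call to the helper added above (it is a private function, so this is a sketch of behavior, not public API):

from posthoganalytics.ai.gemini.gemini_converter import _format_parts_as_content_blocks

# A mixed parts list: a plain string plus an inline image
parts = [
    "What is in this image?",
    {"inline_data": {"mime_type": "image/png", "data": "<base64>"}},
]

print(_format_parts_as_content_blocks(parts))
# [{'type': 'text', 'text': 'What is in this image?'},
#  {'type': 'image', 'inline_data': {'mime_type': 'image/png', 'data': '<base64>'}}]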
@@ -73,16 +114,17 @@ def _format_dict_message(item: Dict[str, Any]) -> FormattedMessage:
 
     # Handle dict format with parts array (Gemini-specific format)
     if "parts" in item and isinstance(item["parts"], list):
-
-        return {"role": item.get("role", "user"), "content":
+        content_blocks = _format_parts_as_content_blocks(item["parts"])
+        return {"role": item.get("role", "user"), "content": content_blocks}
 
     # Handle dict with content field
     if "content" in item:
         content = item["content"]
 
         if isinstance(content, list):
-            # If content is a list,
-
+            # If content is a list, format it as content blocks
+            content_blocks = _format_parts_as_content_blocks(content)
+            return {"role": item.get("role", "user"), "content": content_blocks}
 
         elif not isinstance(content, str):
             content = str(content)
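This means a Gemini-style dict message with a parts array now keeps its structure end to end; again an illustrative call to a private function:

from posthoganalytics.ai.gemini.gemini_converter import _format_dict_message

message = {"role": "user", "parts": [{"text": "Hi"}, "How are you?"]}

print(_format_dict_message(message))
# {'role': 'user', 'content': [{'type': 'text', 'text': 'Hi'},
#                              {'type': 'text', 'text': 'How are you?'}]}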
@@ -110,14 +152,14 @@ def _format_object_message(item: Any) -> FormattedMessage:
 
     # Handle object with parts attribute
     if hasattr(item, "parts") and hasattr(item.parts, "__iter__"):
-
+        content_blocks = _format_parts_as_content_blocks(list(item.parts))
         role = getattr(item, "role", "user") if hasattr(item, "role") else "user"
 
         # Ensure role is a string
         if not isinstance(role, str):
             role = "user"
 
-        return {"role": role, "content":
+        return {"role": role, "content": content_blocks}
 
     # Handle object with text attribute
     if hasattr(item, "text"):
@@ -140,7 +182,8 @@ def _format_object_message(item: Any) -> FormattedMessage:
         content = item.content
 
         if isinstance(content, list):
-
+            content_blocks = _format_parts_as_content_blocks(content)
+            return {"role": role, "content": content_blocks}
 
         elif not isinstance(content, str):
             content = str(content)
@@ -193,6 +236,29 @@ def format_gemini_response(response: Any) -> List[FormattedMessage]:
                     }
                 )
 
+            elif hasattr(part, "inline_data") and part.inline_data:
+                # Handle audio/media inline data
+                import base64
+
+                inline_data = part.inline_data
+                mime_type = getattr(inline_data, "mime_type", "audio/pcm")
+                raw_data = getattr(inline_data, "data", b"")
+
+                # Encode binary data as base64 string for JSON serialization
+                if isinstance(raw_data, bytes):
+                    data = base64.b64encode(raw_data).decode("utf-8")
+                else:
+                    # Already a string (base64)
+                    data = raw_data
+
+                content.append(
+                    {
+                        "type": "audio",
+                        "mime_type": mime_type,
+                        "data": data,
+                    }
+                )
+
         if content:
             output.append(
                 {
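The base64 step exists because raw audio bytes are not JSON-serializable, so the event payload would fail to encode without it. A stdlib-only sketch of the failure it avoids:

import base64
import json

raw = b"\x00\x01\x02"  # stand-in for PCM audio bytes

json.dumps({"data": base64.b64encode(raw).decode("utf-8")})  # serializes fine
# json.dumps({"data": raw}) raises TypeError: Object of type bytes is not JSON serializable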
@@ -338,6 +404,61 @@ def format_gemini_input(contents: Any) -> List[FormattedMessage]:
     return [_format_object_message(contents)]
 
 
+def extract_gemini_web_search_count(response: Any) -> int:
+    """
+    Extract web search count from Gemini response.
+
+    Gemini bills per request that uses grounding, not per query.
+    Returns 1 if grounding_metadata is present with actual search data, 0 otherwise.
+
+    Args:
+        response: The response from Gemini API
+
+    Returns:
+        1 if web search/grounding was used, 0 otherwise
+    """
+
+    # Check for grounding_metadata in candidates
+    if hasattr(response, "candidates"):
+        for candidate in response.candidates:
+            if (
+                hasattr(candidate, "grounding_metadata")
+                and candidate.grounding_metadata
+            ):
+                grounding_metadata = candidate.grounding_metadata
+
+                # Check if web_search_queries exists and is non-empty
+                if hasattr(grounding_metadata, "web_search_queries"):
+                    queries = grounding_metadata.web_search_queries
+
+                    if queries is not None and len(queries) > 0:
+                        return 1
+
+                # Check if grounding_chunks exists and is non-empty
+                if hasattr(grounding_metadata, "grounding_chunks"):
+                    chunks = grounding_metadata.grounding_chunks
+
+                    if chunks is not None and len(chunks) > 0:
+                        return 1
+
+            # Also check for google_search or grounding in function call names
+            if hasattr(candidate, "content") and candidate.content:
+                if hasattr(candidate.content, "parts") and candidate.content.parts:
+                    for part in candidate.content.parts:
+                        if hasattr(part, "function_call") and part.function_call:
+                            function_name = getattr(
+                                part.function_call, "name", ""
+                            ).lower()
+
+                            if (
+                                "google_search" in function_name
+                                or "grounding" in function_name
+                            ):
+                                return 1
+
+    return 0
+
+
 def _extract_usage_from_metadata(metadata: Any) -> TokenUsage:
     """
     Common logic to extract usage from Gemini metadata.
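A sketch of when the new counter fires, using types.SimpleNamespace to stand in for google-genai response objects (stubbed attributes only, not real SDK types):

from types import SimpleNamespace

from posthoganalytics.ai.gemini.gemini_converter import extract_gemini_web_search_count

grounded = SimpleNamespace(
    candidates=[
        SimpleNamespace(
            grounding_metadata=SimpleNamespace(
                web_search_queries=["current weather in SF"],
                grounding_chunks=None,
            ),
            content=None,
        )
    ]
)
plain = SimpleNamespace(
    candidates=[SimpleNamespace(grounding_metadata=None, content=None)]
)

print(extract_gemini_web_search_count(grounded))  # 1, billed per request, not per query
print(extract_gemini_web_search_count(plain))     # 0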
@@ -382,7 +503,14 @@ def extract_gemini_usage_from_response(response: Any) -> TokenUsage:
     if not hasattr(response, "usage_metadata") or not response.usage_metadata:
         return TokenUsage(input_tokens=0, output_tokens=0)
 
-
+    usage = _extract_usage_from_metadata(response.usage_metadata)
+
+    # Add web search count if present
+    web_search_count = extract_gemini_web_search_count(response)
+    if web_search_count > 0:
+        usage["web_search_count"] = web_search_count
+
+    return usage
 
 
 def extract_gemini_usage_from_chunk(chunk: Any) -> TokenUsage:
@@ -398,11 +526,19 @@ def extract_gemini_usage_from_chunk(chunk: Any) -> TokenUsage:
 
     usage: TokenUsage = TokenUsage()
 
+    # Extract web search count from the chunk before checking for usage_metadata
+    # Web search indicators can appear on any chunk, not just those with usage data
+    web_search_count = extract_gemini_web_search_count(chunk)
+    if web_search_count > 0:
+        usage["web_search_count"] = web_search_count
+
     if not hasattr(chunk, "usage_metadata") or not chunk.usage_metadata:
         return usage
 
-
-
+    usage_from_metadata = _extract_usage_from_metadata(chunk.usage_metadata)
+
+    # Merge the usage from metadata with any web search count we found
+    usage.update(usage_from_metadata)
 
     return usage
 
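Because the grounding check now runs before the usage_metadata early return, a chunk that carries search indicators but no token counts still contributes a web_search_count. A stubbed sketch (SimpleNamespace again standing in for SDK types):

from types import SimpleNamespace

from posthoganalytics.ai.gemini.gemini_converter import extract_gemini_usage_from_chunk

chunk = SimpleNamespace(
    usage_metadata=None,  # no token counts on this chunk
    candidates=[
        SimpleNamespace(
            grounding_metadata=SimpleNamespace(
                web_search_queries=["a query"],
                grounding_chunks=None,
            ),
            content=None,
        )
    ],
)

print(extract_gemini_usage_from_chunk(chunk))  # {'web_search_count': 1}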