inspect-ai 0.3.68__py3-none-any.whl → 0.3.69__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
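In substance, this release migrates the provider from the legacy `google-generativeai` SDK to the unified `google-genai` SDK (and renames the provider class from GoogleAPI to GoogleGenAIAPI), adding Vertex AI routing, prompt-feedback handling, and reasoning extraction. As a minimal sketch of the API shift, condensed from the calls visible in the diff below (variable names here are illustrative, not part of the package):

    # before: google-generativeai -- module-level configure() plus a model object
    from google.generativeai.client import configure
    from google.generativeai.generative_models import GenerativeModel

    configure(api_key=api_key, client_options=dict(api_endpoint=base_url))
    model = GenerativeModel(model_name)
    response = await model.generate_content_async(contents=contents)

    # after: google-genai -- explicit Client, async calls under client.aio
    from google.genai import Client

    client = Client(vertexai=False, api_key=api_key, http_options={"base_url": base_url})
    response = await client.aio.models.generate_content(
        model=model_name, contents=contents, config=config
    )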
@@ -2,103 +2,96 @@ import asyncio
 import functools
 import hashlib
 import json
+import os
 from copy import copy
 from io import BytesIO
 from logging import getLogger
-from typing import Any, MutableSequence, cast
+from typing import Any
 
-import proto  # type: ignore
-from google.ai.generativelanguage import (
-    Blob,
+# SDK Docs: https://googleapis.github.io/python-genai/
+from google.genai import Client  # type: ignore
+from google.genai.errors import APIError, ClientError  # type: ignore
+from google.genai.types import (  # type: ignore
     Candidate,
-    FunctionCall,
+    Content,
+    File,
+    FinishReason,
     FunctionCallingConfig,
     FunctionDeclaration,
     FunctionResponse,
+    GenerateContentConfig,
+    GenerateContentResponse,
+    GenerateContentResponsePromptFeedback,
+    GenerateContentResponseUsageMetadata,
+    GenerationConfig,
+    HarmBlockThreshold,
+    HarmCategory,
     Part,
+    SafetySetting,
+    SafetySettingDict,
     Schema,
+    Tool,
     ToolConfig,
     Type,
 )
-from google.api_core.exceptions import (
-    GatewayTimeout,
-    InternalServerError,
-    InvalidArgument,
-    ServiceUnavailable,
-    TooManyRequests,
-)
-from google.api_core.retry.retry_base import if_transient_error
-from google.generativeai.client import configure
-from google.generativeai.files import get_file, upload_file
-from google.generativeai.generative_models import GenerativeModel
-from google.generativeai.types import (
-    ContentDict,
-    GenerationConfig,
-    PartDict,
-    PartType,
-    Tool,
-)
-from google.generativeai.types.file_types import File
-from google.generativeai.types.generation_types import AsyncGenerateContentResponse
-from google.generativeai.types.safety_types import (
-    EasySafetySettingDict,
-    HarmBlockThreshold,
-    HarmCategory,
-)
-from google.protobuf.json_format import MessageToDict, ParseDict
-from google.protobuf.struct_pb2 import Struct
 from pydantic import JsonValue
 from typing_extensions import override
 
 from inspect_ai._util.constants import BASE_64_DATA_REMOVED, NO_CONTENT
+from inspect_ai._util.content import Content as InspectContent
 from inspect_ai._util.content import (
-    Content,
     ContentAudio,
     ContentImage,
     ContentText,
     ContentVideo,
 )
+from inspect_ai._util.error import PrerequisiteError
 from inspect_ai._util.images import file_as_data
 from inspect_ai._util.kvstore import inspect_kvstore
 from inspect_ai._util.trace import trace_message
-from inspect_ai.tool import ToolCall, ToolChoice, ToolInfo, ToolParam, ToolParams
-
-from .._chat_message import (
+from inspect_ai.model import (
+    ChatCompletionChoice,
     ChatMessage,
     ChatMessageAssistant,
-    ChatMessageSystem,
     ChatMessageTool,
     ChatMessageUser,
-)
-from .._generate_config import GenerateConfig
-from .._model import ModelAPI
-from .._model_call import ModelCall
-from .._model_output import (
-    ChatCompletionChoice,
+    GenerateConfig,
     Logprob,
     Logprobs,
+    ModelAPI,
     ModelOutput,
     ModelUsage,
     StopReason,
     TopLogprob,
 )
-from .util import model_base_url
+from inspect_ai.model._model_call import ModelCall
+from inspect_ai.model._providers.util import model_base_url
+from inspect_ai.tool import (
+    ToolCall,
+    ToolChoice,
+    ToolFunction,
+    ToolInfo,
+    ToolParam,
+    ToolParams,
+)
 
 logger = getLogger(__name__)
 
-SAFETY_SETTINGS = "safety_settings"
 
-DEFAULT_SAFETY_SETTINGS: EasySafetySettingDict = {
-    HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
-    HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
-    HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
+GOOGLE_API_KEY = "GOOGLE_API_KEY"
+VERTEX_API_KEY = "VERTEX_API_KEY"
+
+SAFETY_SETTINGS = "safety_settings"
+DEFAULT_SAFETY_SETTINGS = {
+    HarmCategory.HARM_CATEGORY_CIVIC_INTEGRITY: HarmBlockThreshold.BLOCK_NONE,
     HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
+    HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
+    HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
+    HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
 }
 
-GOOGLE_API_KEY = "GOOGLE_API_KEY"
-
 
-class GoogleAPI(ModelAPI):
+class GoogleGenAIAPI(ModelAPI):
     def __init__(
         self,
         model_name: str,
@@ -111,11 +104,11 @@ class GoogleAPI(ModelAPI):
             model_name=model_name,
             base_url=base_url,
             api_key=api_key,
-            api_key_vars=[GOOGLE_API_KEY],
+            api_key_vars=[GOOGLE_API_KEY, VERTEX_API_KEY],
             config=config,
         )
 
-        # pick out vertex safety settings and merge against default
+        # pick out user-provided safety settings and merge against default
         self.safety_settings = DEFAULT_SAFETY_SETTINGS.copy()
         if SAFETY_SETTINGS in model_args:
             self.safety_settings.update(
@@ -123,22 +116,79 @@ class GoogleAPI(ModelAPI):
             )
             del model_args[SAFETY_SETTINGS]
 
-        # configure genai client
-        base_url = model_base_url(base_url, "GOOGLE_BASE_URL")
-        configure(
+        # extract any service prefix from model name
+        parts = model_name.split("/")
+        if len(parts) > 1:
+            self.service: str | None = parts[0]
+            model_name = "/".join(parts[1:])
+        else:
+            self.service = None
+
+        # vertex can also be forced by the GOOGLE_GENAI_USE_VERTEXAI flag
+        if self.service is None:
+            if os.environ.get("GOOGLE_GENAI_USE_VERTEXAI", "").lower() == "true":
+                self.service = "vertex"
+
+        # ensure we haven't specified an invalid service
+        if self.service is not None and self.service != "vertex":
+            raise RuntimeError(
+                f"Invalid service name for google: {self.service}. "
+                + "Currently 'vertex' is the only supported service."
+            )
+
+        # handle auth (vertex or standard google api key)
+        if self.is_vertex():
+            # see if we are running in express mode (propagate api key if we are)
+            # https://cloud.google.com/vertex-ai/generative-ai/docs/start/express-mode/overview
+            vertex_api_key = os.environ.get(VERTEX_API_KEY, None)
+            if vertex_api_key and not self.api_key:
+                self.api_key = vertex_api_key
+
+            # When not using express mode the GOOGLE_CLOUD_PROJECT and GOOGLE_CLOUD_LOCATION
+            # environment variables should be set, OR the 'project' and 'location' should be
+            # passed within the model_args.
+            # https://cloud.google.com/vertex-ai/generative-ai/docs/gemini-v2
+            if not vertex_api_key:
+                if not os.environ.get(
+                    "GOOGLE_CLOUD_PROJECT", None
+                ) and not model_args.get("project", None):
+                    raise PrerequisiteError(
+                        "Google provider requires either the GOOGLE_CLOUD_PROJECT environment variable "
+                        + "or the 'project' custom model arg (-M) when running against vertex."
+                    )
+                if not os.environ.get(
+                    "GOOGLE_CLOUD_LOCATION", None
+                ) and not model_args.get("location", None):
+                    raise PrerequisiteError(
+                        "Google provider requires either the GOOGLE_CLOUD_LOCATION environment variable "
+                        + "or the 'location' custom model arg (-M) when running against vertex."
+                    )
+
+        # normal google endpoint
+        else:
+            # read api key from env
+            if not self.api_key:
+                self.api_key = os.environ.get(GOOGLE_API_KEY, None)
+
+            # custom base_url
+            base_url = model_base_url(base_url, "GOOGLE_BASE_URL")
+
+        # create client
+        self.client = Client(
+            vertexai=self.is_vertex(),
             api_key=self.api_key,
-            client_options=dict(api_endpoint=base_url),
+            http_options={"base_url": base_url},
             **model_args,
         )
 
-        # create model
-        self.model = GenerativeModel(self.model_name)
-
     @override
     async def close(self) -> None:
         # GenerativeModel uses a cached/shared client so there is no 'close'
         pass
 
+    def is_vertex(self) -> bool:
+        return self.service == "vertex"
+
     async def generate(
         self,
         input: list[ChatMessage],
@@ -146,7 +196,11 @@ class GoogleAPI(ModelAPI):
         tool_choice: ToolChoice,
         config: GenerateConfig,
     ) -> ModelOutput | tuple[ModelOutput | Exception, ModelCall]:
-        parameters = GenerationConfig(
+        # Create google-genai types.
+        gemini_contents = await as_chat_messages(self.client, input)
+        gemini_tools = chat_tools(tools) if len(tools) > 0 else None
+        gemini_tool_config = chat_tool_config(tool_choice) if len(tools) > 0 else None
+        parameters = GenerateContentConfig(
             temperature=config.temperature,
             top_p=config.top_p,
             top_k=config.top_k,
@@ -155,21 +209,19 @@ class GoogleAPI(ModelAPI):
             candidate_count=config.num_choices,
             presence_penalty=config.presence_penalty,
             frequency_penalty=config.frequency_penalty,
+            safety_settings=safety_settings_to_list(self.safety_settings),
+            tools=gemini_tools,
+            tool_config=gemini_tool_config,
+            system_instruction=await extract_system_message_as_parts(
+                self.client, input
+            ),
         )
 
-        # google-native messages
-        contents = await as_chat_messages(input)
-
-        # tools
-        gemini_tools = chat_tools(tools) if len(tools) > 0 else None
-        gemini_tool_config = chat_tool_config(tool_choice) if len(tools) > 0 else None
-
-        # response for ModelCall
-        response: AsyncGenerateContentResponse | None = None
+        response: GenerateContentResponse | None = None
 
         def model_call() -> ModelCall:
             return build_model_call(
-                contents=contents,
+                contents=gemini_contents,
                 safety_settings=self.safety_settings,
                 generation_config=parameters,
                 tools=gemini_tools,
@@ -178,163 +230,146 @@ class GoogleAPI(ModelAPI):
             )
 
         try:
-            response = await self.model.generate_content_async(
-                contents=contents,
-                safety_settings=self.safety_settings,
-                generation_config=parameters,
-                tools=gemini_tools,
-                tool_config=gemini_tool_config,
+            response = await self.client.aio.models.generate_content(
+                model=self.model_name,
+                contents=gemini_contents,
+                config=parameters,
             )
+        except ClientError as ex:
+            return self.handle_client_error(ex), model_call()
 
-        except InvalidArgument as ex:
-            return self.handle_invalid_argument(ex), model_call()
-
-        # build output
         output = ModelOutput(
             model=self.model_name,
-            choices=completion_choices_from_candidates(response.candidates),
-            usage=ModelUsage(
-                input_tokens=response.usage_metadata.prompt_token_count,
-                output_tokens=response.usage_metadata.candidates_token_count,
-                total_tokens=response.usage_metadata.total_token_count,
-            ),
+            choices=completion_choices_from_candidates(response),
+            usage=usage_metadata_to_model_usage(response.usage_metadata),
         )
 
-        # return
         return output, model_call()
 
-    def handle_invalid_argument(self, ex: InvalidArgument) -> ModelOutput | Exception:
-        if "size exceeds the limit" in ex.message.lower():
-            return ModelOutput.from_content(
-                model=self.model_name, content=ex.message, stop_reason="model_length"
-            )
-        else:
-            return ex
-
     @override
     def is_rate_limit(self, ex: BaseException) -> bool:
-        return isinstance(
-            ex,
-            TooManyRequests | InternalServerError | ServiceUnavailable | GatewayTimeout,
-        )
+        return isinstance(ex, APIError) and ex.code in (429, 500, 503, 504)
 
     @override
     def connection_key(self) -> str:
         """Scope for enforcing max_connections (could also use endpoint)."""
         return self.model_name
 
+    def handle_client_error(self, ex: ClientError) -> ModelOutput | Exception:
+        if (
+            ex.code == 400
+            and ex.message
+            and (
+                "maximum number of tokens" in ex.message
+                or "size exceeds the limit" in ex.message
+            )
+        ):
+            return ModelOutput.from_content(
+                self.model_name, content=ex.message, stop_reason="model_length"
+            )
+        else:
+            raise ex
+
+
+def safety_settings_to_list(safety_settings: SafetySettingDict) -> list[SafetySetting]:
+    return [
+        SafetySetting(
+            category=category,
+            threshold=threshold,
+        )
+        for category, threshold in safety_settings.items()
+    ]
+
 
 def build_model_call(
-    contents: list[ContentDict],
+    contents: list[Content],
     generation_config: GenerationConfig,
-    safety_settings: EasySafetySettingDict,
+    safety_settings: SafetySettingDict,
     tools: list[Tool] | None,
     tool_config: ToolConfig | None,
-    response: AsyncGenerateContentResponse | None,
+    response: GenerateContentResponse | None,
 ) -> ModelCall:
     return ModelCall.create(
         request=dict(
-            contents=[model_call_content(content) for content in contents],
+            contents=contents,
             generation_config=generation_config,
             safety_settings=safety_settings,
-            tools=[MessageToDict(tool._proto._pb) for tool in tools]
-            if tools is not None
-            else None,
-            tool_config=MessageToDict(tool_config._pb)
-            if tool_config is not None
-            else None,
+            tools=tools if tools is not None else None,
+            tool_config=tool_config if tool_config is not None else None,
         ),
-        response=response.to_dict() if response is not None else {},  # type: ignore[no-untyped-call]
+        response=response if response is not None else {},
        filter=model_call_filter,
    )
 
 
 def model_call_filter(key: JsonValue | None, value: JsonValue) -> JsonValue:
-    # remove images from raw api call
     if key == "inline_data" and isinstance(value, dict) and "data" in value:
         value = copy(value)
         value.update(data=BASE_64_DATA_REMOVED)
     return value
 
 
-def model_call_content(content: ContentDict) -> ContentDict:
-    return ContentDict(
-        role=content["role"], parts=[model_call_part(part) for part in content["parts"]]
-    )
-
-
-def model_call_part(part: PartType) -> PartType:
-    if isinstance(part, proto.Message):
-        return cast(PartDict, MessageToDict(part._pb))
-    elif isinstance(part, dict):
-        part = part.copy()
-        keys = list(part.keys())
-        for key in keys:
-            part[key] = model_call_part(part[key])  # type: ignore[literal-required]
-        return part
-    else:
-        return part
-
-
-async def as_chat_messages(messages: list[ChatMessage]) -> list[ContentDict]:
-    # google does not support system messages so filter them out to start with
-    system_messages = [message for message in messages if message.role == "system"]
+async def as_chat_messages(
+    client: Client, messages: list[ChatMessage]
+) -> list[Content]:
+    # There is no "system" role in the `google-genai` package. Instead, system messages
+    # are included in the `GenerateContentConfig` as a `system_instruction`. Strip any
+    # system messages out.
     supported_messages = [message for message in messages if message.role != "system"]
 
     # build google chat messages
-    chat_messages = [await content_dict(message) for message in supported_messages]
-
-    # we want the system messages to be prepended to the first user message
-    # (if there is no first user message then prepend one)
-    prepend_system_messages(chat_messages, system_messages)
+    chat_messages = [await content(client, message) for message in supported_messages]
 
     # combine consecutive tool messages
-    chat_messages = functools.reduce(consective_tool_message_reducer, chat_messages, [])
+    chat_messages = functools.reduce(
+        consecutive_tool_message_reducer, chat_messages, []
+    )
 
     # return messages
     return chat_messages
 
 
-def consective_tool_message_reducer(
-    messages: list[ContentDict],
-    message: ContentDict,
-) -> list[ContentDict]:
+def consecutive_tool_message_reducer(
+    messages: list[Content],
+    message: Content,
+) -> list[Content]:
     if (
-        message["role"] == "function"
+        message.role == "function"
         and len(messages) > 0
-        and messages[-1]["role"] == "function"
+        and messages[-1].role == "function"
     ):
-        messages[-1] = ContentDict(
-            role="function", parts=messages[-1]["parts"] + message["parts"]
+        messages[-1] = Content(
+            role="function", parts=messages[-1].parts + message.parts
        )
    else:
        messages.append(message)
    return messages
 
 
-async def content_dict(
+async def content(
+    client: Client,
     message: ChatMessageUser | ChatMessageAssistant | ChatMessageTool,
-) -> ContentDict:
+) -> Content:
     if isinstance(message, ChatMessageUser):
-        return ContentDict(
+        if isinstance(message.content, str):
+            return Content(
+                role="user", parts=[await content_part(client, message.content)]
+            )
+        return Content(
             role="user",
             parts=(
-                [message.content or NO_CONTENT]
-                if isinstance(message.content, str)
-                else [await content_part(content) for content in message.content]
+                [await content_part(client, content) for content in message.content]
             ),
         )
     elif isinstance(message, ChatMessageAssistant):
-        content_parts: list[PartType] = []
+        content_parts: list[Part] = []
         # tool call parts
         if message.tool_calls is not None:
             content_parts.extend(
                 [
-                    Part(
-                        function_call=FunctionCall(
-                            name=tool_call.function,
-                            args=dict_to_struct(tool_call.arguments),
-                        )
+                    Part.from_function_call(
+                        name=tool_call.function,
+                        args=tool_call.arguments,
                     )
                     for tool_call in message.tool_calls
                 ]
@@ -345,68 +380,62 @@ async def content_dict(
             content_parts.append(Part(text=message.content or NO_CONTENT))
         else:
             content_parts.extend(
-                [await content_part(content) for content in message.content]
+                [await content_part(client, content) for content in message.content]
            )
 
         # return parts
-        return ContentDict(role="model", parts=content_parts)
+        return Content(role="model", parts=content_parts)
 
     elif isinstance(message, ChatMessageTool):
         response = FunctionResponse(
             name=message.tool_call_id,
-            response=ParseDict(
-                js_dict={
-                    "content": (
-                        message.error.message
-                        if message.error is not None
-                        else message.text
-                    )
-                },
-                message=Struct(),
-            ),
+            response={
+                "content": (
+                    message.error.message if message.error is not None else message.text
+                )
+            },
         )
-        return ContentDict(role="function", parts=[Part(function_response=response)])
-
+        return Content(role="function", parts=[Part(function_response=response)])
 
-def dict_to_struct(x: dict[str, Any]) -> Struct:
-    struct = Struct()
-    struct.update(x)
-    return struct
 
-
-async def content_part(content: Content | str) -> PartType:
+async def content_part(client: Client, content: InspectContent | str) -> Part:
     if isinstance(content, str):
-        return content or NO_CONTENT
+        return Part.from_text(text=content or NO_CONTENT)
     elif isinstance(content, ContentText):
-        return content.text or NO_CONTENT
+        return Part.from_text(text=content.text or NO_CONTENT)
     else:
-        return await chat_content_to_part(content)
+        return await chat_content_to_part(client, content)
 
 
 async def chat_content_to_part(
+    client: Client,
     content: ContentImage | ContentAudio | ContentVideo,
-) -> PartType:
+) -> Part:
     if isinstance(content, ContentImage):
         content_bytes, mime_type = await file_as_data(content.image)
-        return Blob(mime_type=mime_type, data=content_bytes)
-    else:
-        return await file_for_content(content)
-
-
-def prepend_system_messages(
-    messages: list[ContentDict], system_messages: list[ChatMessageSystem]
-) -> None:
-    # create system_parts
-    system_parts: list[PartType] = [
-        Part(text=message.text) for message in system_messages
-    ]
-
-    # we want the system messages to be prepended to the first user message
-    # (if there is no first user message then prepend one)
-    if len(messages) > 0 and messages[0].get("role") == "user":
-        messages[0]["parts"] = system_parts + messages[0].get("parts", [])
+        return Part.from_bytes(mime_type=mime_type, data=content_bytes)
     else:
-        messages.insert(0, ContentDict(role="user", parts=system_parts))
+        return await file_for_content(client, content)
+
+
+async def extract_system_message_as_parts(
+    client: Client,
+    messages: list[ChatMessage],
+) -> list[Part] | None:
+    system_parts: list[Part] = []
+    for message in messages:
+        if message.role == "system":
+            content = message.content
+            if isinstance(content, str):
+                system_parts.append(Part.from_text(text=content))
+            elif isinstance(content, list):  # list[InspectContent]
+                system_parts.extend(
+                    [await content_part(client, content) for content in content]
+                )
+            else:
+                raise ValueError(f"Unsupported system message content: {content}")
+    # google-genai raises "ValueError: content is required." if the list is empty.
+    return system_parts or None
 
 
 def chat_tools(tools: list[ToolInfo]) -> list[Tool]:
@@ -424,8 +453,6 @@ def chat_tools(tools: list[ToolInfo]) -> list[Tool]:
 
 
 # https://ai.google.dev/gemini-api/tutorials/extract_structured_data#define_the_schema
-
-
 def schema_from_param(param: ToolParam | ToolParams, nullable: bool = False) -> Schema:
     if isinstance(param, ToolParams):
         param = ToolParam(
@@ -461,7 +488,7 @@ def schema_from_param(param: ToolParam | ToolParams, nullable: bool = False) ->
             description=param.description,
             properties={k: schema_from_param(v) for k, v in param.properties.items()}
             if param.properties is not None
-            else None,
+            else {},
             required=param.required,
             nullable=nullable,
         )
@@ -478,57 +505,56 @@ def schema_from_param(param: ToolParam | ToolParams, nullable: bool = False) ->
 
 
 def chat_tool_config(tool_choice: ToolChoice) -> ToolConfig:
-    # NOTE: Google seems to sporadically return errors when being
-    # passed a FunctionCallingConfig with mode="ANY". therefore,
-    # we 'correct' this to "AUTO" to prevent the errors
-    mode = "AUTO"
-    if tool_choice == "none":
-        mode = "NONE"
-    return ToolConfig(function_calling_config=FunctionCallingConfig(mode=mode))
-
-    # This is the 'correct' implementation if Google wasn't returning
-    # errors for mode="ANY". we can test whether this is working properly
-    # by commenting this back in and running pytest -k google_tools
-    #
-    # if isinstance(tool_choice, ToolFunction):
-    #     return ToolConfig(
-    #         function_calling_config=FunctionCallingConfig(
-    #             mode="ANY", allowed_function_names=[tool_choice.name]
-    #         )
-    #     )
-    # else:
-    #     return ToolConfig(
-    #         function_calling_config=FunctionCallingConfig(mode=tool_choice.upper())
-    #     )
+    if isinstance(tool_choice, ToolFunction):
+        return ToolConfig(
+            function_calling_config=FunctionCallingConfig(
+                mode="ANY", allowed_function_names=[tool_choice.name]
+            )
+        )
+    else:
+        return ToolConfig(
+            function_calling_config=FunctionCallingConfig(mode=tool_choice.upper())
+        )
 
 
 def completion_choice_from_candidate(candidate: Candidate) -> ChatCompletionChoice:
     # check for completion text
-    content = " ".join(
-        [part.text for part in candidate.content.parts if part.text is not None]
-    )
+    content = ""
+    # content can be None when the finish_reason is SAFETY
+    if candidate.content is not None:
+        content = " ".join(
+            [
+                part.text
+                for part in candidate.content.parts
+                if part.text is not None and candidate.content is not None
+            ]
+        )
+
+    # split reasoning
+    reasoning, content = split_reasoning(content)
 
     # now tool calls
     tool_calls: list[ToolCall] = []
-    for part in candidate.content.parts:
-        if part.function_call:
-            function_call = MessageToDict(getattr(part.function_call, "_pb"))
-            tool_calls.append(
-                ToolCall(
-                    type="function",
-                    id=function_call["name"],
-                    function=function_call["name"],
-                    arguments=function_call["args"],
+    if candidate.content is not None and candidate.content.parts is not None:
+        for part in candidate.content.parts:
+            if part.function_call:
+                tool_calls.append(
+                    ToolCall(
+                        type="function",
+                        id=part.function_call.name,
+                        function=part.function_call.name,
+                        arguments=part.function_call.args,
+                    )
                 )
-            )
 
     # stop reason
-    stop_reason = candidate_stop_reason(candidate.finish_reason)
+    stop_reason = finish_reason_to_stop_reason(candidate.finish_reason)
 
-    # build choide
+    # build choice
     choice = ChatCompletionChoice(
         message=ChatMessageAssistant(
             content=content,
+            reasoning=reasoning,
             tool_calls=tool_calls if len(tool_calls) > 0 else None,
             source="generate",
        ),
@@ -558,111 +584,144 @@ def completion_choice_from_candidate(candidate: Candidate) -> ChatCompletionChoi
 
 
 def completion_choices_from_candidates(
-    candidates: MutableSequence[Candidate],
+    response: GenerateContentResponse,
 ) -> list[ChatCompletionChoice]:
+    candidates = response.candidates
     if candidates:
         candidates_list = sorted(candidates, key=lambda c: c.index)
         return [
             completion_choice_from_candidate(candidate) for candidate in candidates_list
         ]
-    else:
+    elif response.prompt_feedback:
         return [
             ChatCompletionChoice(
                 message=ChatMessageAssistant(
-                    content="I was unable to generate a response.",
+                    content=prompt_feedback_to_content(response.prompt_feedback),
                     source="generate",
                 ),
-                stop_reason="unknown",
+                stop_reason="content_filter",
             )
         ]
+    else:
+        raise RuntimeError(
+            "Google response includes no completion candidates and no block reason: "
+            + f"{response.model_dump_json(indent=2)}"
+        )
 
 
-# google doesn't export FinishReason (it's in a sub-namespace with a beta
-# designation that seems destined to change, so we vendor the enum here)
-class FinishReason:
-    FINISH_REASON_UNSPECIFIED = 0
-    STOP = 1
-    MAX_TOKENS = 2
-    SAFETY = 3
-    RECITATION = 4
-    OTHER = 5
+def split_reasoning(content: str) -> tuple[str | None, str]:
+    separator = "\nFinal Answer: "
+    if separator in content:
+        parts = content.split(separator, 1)  # split only on first occurrence
+        return parts[0].strip(), separator.lstrip() + parts[1].strip()
+    else:
+        return None, content.strip()
+
 
+def prompt_feedback_to_content(
+    feedback: GenerateContentResponsePromptFeedback,
+) -> str:
+    content: list[str] = []
+    block_reason = str(feedback.block_reason) if feedback.block_reason else "UNKNOWN"
+    content.append(f"BLOCKED: {block_reason}")
 
-def candidate_stop_reason(finish_reason: FinishReason) -> StopReason:
+    if feedback.block_reason_message is not None:
+        content.append(feedback.block_reason_message)
+    if feedback.safety_ratings is not None:
+        content.extend(
+            [rating.model_dump_json(indent=2) for rating in feedback.safety_ratings]
+        )
+    return "\n".join(content)
+
+
+def usage_metadata_to_model_usage(
+    metadata: GenerateContentResponseUsageMetadata,
+) -> ModelUsage | None:
+    if metadata is None:
+        return None
+    return ModelUsage(
+        input_tokens=metadata.prompt_token_count or 0,
+        output_tokens=metadata.candidates_token_count or 0,
+        total_tokens=metadata.total_token_count or 0,
+    )
+
+
+def finish_reason_to_stop_reason(finish_reason: FinishReason) -> StopReason:
     match finish_reason:
         case FinishReason.STOP:
             return "stop"
         case FinishReason.MAX_TOKENS:
             return "max_tokens"
-        case FinishReason.SAFETY | FinishReason.RECITATION:
+        case (
+            FinishReason.SAFETY
+            | FinishReason.RECITATION
+            | FinishReason.BLOCKLIST
+            | FinishReason.PROHIBITED_CONTENT
+            | FinishReason.SPII
+        ):
             return "content_filter"
         case _:
             return "unknown"
 
 
-def gapi_should_retry(ex: BaseException) -> bool:
-    if isinstance(ex, Exception):
-        return if_transient_error(ex)
-    else:
-        return False
-
-
 def parse_safety_settings(
     safety_settings: Any,
-) -> EasySafetySettingDict:
+) -> dict[HarmCategory, HarmBlockThreshold]:
     # ensure we have a dict
     if isinstance(safety_settings, str):
         safety_settings = json.loads(safety_settings)
     if not isinstance(safety_settings, dict):
         raise ValueError(f"{SAFETY_SETTINGS} must be dictionary.")
 
-    parsed_settings: EasySafetySettingDict = {}
+    parsed_settings: dict[HarmCategory, HarmBlockThreshold] = {}
     for key, value in safety_settings.items():
-        if isinstance(key, str):
-            key = str_to_harm_category(key)
-        if not isinstance(key, HarmCategory):
+        if not isinstance(key, str):
             raise ValueError(f"Unexpected type for harm category: {key}")
-        if isinstance(value, str):
-            value = str_to_harm_block_threshold(value)
-        if not isinstance(value, HarmBlockThreshold):
+        if not isinstance(value, str):
             raise ValueError(f"Unexpected type for harm block threshold: {value}")
-
+        key = str_to_harm_category(key)
+        value = str_to_harm_block_threshold(value)
        parsed_settings[key] = value
-
    return parsed_settings
 
 
-def str_to_harm_category(category: str) -> int:
+def str_to_harm_category(category: str) -> HarmCategory:
     category = category.upper()
+    # `in` instead of `==` to allow users to pass in short version e.g. "HARASSMENT" or
+    # long version e.g. "HARM_CATEGORY_HARASSMENT" strings.
+    if "CIVIC_INTEGRITY" in category:
+        return HarmCategory.HARM_CATEGORY_CIVIC_INTEGRITY
+    if "DANGEROUS_CONTENT" in category:
+        return HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT
+    if "HATE_SPEECH" in category:
+        return HarmCategory.HARM_CATEGORY_HATE_SPEECH
     if "HARASSMENT" in category:
-        return cast(int, HarmCategory.HARM_CATEGORY_HARASSMENT)
-    elif "HATE_SPEECH" in category:
-        return cast(int, HarmCategory.HARM_CATEGORY_HATE_SPEECH)
-    elif "SEXUALLY_EXPLICIT" in category:
-        return cast(int, HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT)
-    elif "DANGEROUS_CONTENT" in category:
-        return cast(int, HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT)
-    else:
-        # NOTE: Although there is an "UNSPECIFIED" category, in the
-        # documentation, the API does not accept it.
-        raise ValueError(f"Unknown HarmCategory: {category}")
+        return HarmCategory.HARM_CATEGORY_HARASSMENT
+    if "SEXUALLY_EXPLICIT" in category:
+        return HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT
+    if "UNSPECIFIED" in category:
+        return HarmCategory.HARM_CATEGORY_UNSPECIFIED
+    raise ValueError(f"Unknown HarmCategory: {category}")
 
 
-def str_to_harm_block_threshold(threshold: str) -> int:
+def str_to_harm_block_threshold(threshold: str) -> HarmBlockThreshold:
     threshold = threshold.upper()
     if "LOW" in threshold:
         return HarmBlockThreshold.BLOCK_LOW_AND_ABOVE
-    elif "MEDIUM" in threshold:
+    if "MEDIUM" in threshold:
         return HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE
-    elif "HIGH" in threshold:
+    if "HIGH" in threshold:
         return HarmBlockThreshold.BLOCK_ONLY_HIGH
-    elif "NONE" in threshold:
+    if "NONE" in threshold:
         return HarmBlockThreshold.BLOCK_NONE
-    else:
-        raise ValueError(f"Unknown HarmBlockThreshold: {threshold}")
+    if "OFF" in threshold:
+        return HarmBlockThreshold.OFF
+    raise ValueError(f"Unknown HarmBlockThreshold: {threshold}")
 
 
-async def file_for_content(content: ContentAudio | ContentVideo) -> File:
+async def file_for_content(
+    client: Client, content: ContentAudio | ContentVideo
+) -> File:
     # helper to write trace messages
     def trace(message: str) -> None:
         trace_message(logger, "Google Files", message)
@@ -674,7 +733,6 @@ async def file_for_content(content: ContentAudio | ContentVideo) -> File:
         file = content.video
     content_bytes, mime_type = await file_as_data(file)
     content_sha256 = hashlib.sha256(content_bytes).hexdigest()
-
     # we cache uploads for re-use, open the db where we track that
     # (track up to 1 million previous uploads)
     with inspect_kvstore("google_files", 1000000) as files_db:
@@ -682,7 +740,7 @@ async def file_for_content(content: ContentAudio | ContentVideo) -> File:
         uploaded_file = files_db.get(content_sha256)
         if uploaded_file:
             try:
-                upload = get_file(uploaded_file)
+                upload: File = client.files.get(uploaded_file)
                 if upload.state.name == "ACTIVE":
                     trace(f"Using uploaded file: {uploaded_file}")
                     return upload
@@ -693,20 +751,16 @@ async def file_for_content(content: ContentAudio | ContentVideo) -> File:
             except Exception as ex:
                 trace(f"Error attempting to access uploaded file: {ex}")
                 files_db.delete(content_sha256)
-
         # do the upload (and record it)
-        upload = upload_file(BytesIO(content_bytes), mime_type=mime_type)
+        upload = client.files.upload(BytesIO(content_bytes), mime_type=mime_type)
         while upload.state.name == "PROCESSING":
             await asyncio.sleep(3)
-            upload = get_file(upload.name)
-
+            upload = client.files.get(upload.name)
         if upload.state.name == "FAILED":
             trace(f"Failed to upload file '{upload.name}: {upload.error}")
             raise ValueError(f"Google file upload failed: {upload.error}")
-
         # trace and record it
         trace(f"Uploaded file: {upload.name}")
         files_db.put(content_sha256, upload.name)
-
         # return the file
         return upload
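
Two usage notes follow from the new code above. Vertex AI is selected either by a `vertex/` prefix on the model name or by setting GOOGLE_GENAI_USE_VERTEXAI=true (with GOOGLE_CLOUD_PROJECT and GOOGLE_CLOUD_LOCATION set, or a VERTEX_API_KEY in express mode). And safety settings may be passed as a dict or JSON string using short or long category names. A hypothetical round trip through parse_safety_settings (the values are illustrative):

    settings = parse_safety_settings(
        '{"HARASSMENT": "medium", "HARM_CATEGORY_HATE_SPEECH": "none"}'
    )
    # result:
    # {
    #     HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
    #     HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
    # }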