inspect-ai 0.3.68__py3-none-any.whl → 0.3.70__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (82)
  1. inspect_ai/_cli/eval.py +13 -1
  2. inspect_ai/_display/plain/display.py +9 -11
  3. inspect_ai/_display/textual/app.py +5 -5
  4. inspect_ai/_display/textual/widgets/samples.py +47 -18
  5. inspect_ai/_display/textual/widgets/transcript.py +25 -12
  6. inspect_ai/_eval/eval.py +14 -2
  7. inspect_ai/_eval/evalset.py +6 -1
  8. inspect_ai/_eval/run.py +6 -0
  9. inspect_ai/_eval/task/run.py +44 -15
  10. inspect_ai/_eval/task/task.py +26 -3
  11. inspect_ai/_util/interrupt.py +15 -0
  12. inspect_ai/_util/logger.py +23 -0
  13. inspect_ai/_util/rich.py +7 -8
  14. inspect_ai/_util/text.py +301 -1
  15. inspect_ai/_util/transcript.py +10 -2
  16. inspect_ai/_util/working.py +46 -0
  17. inspect_ai/_view/www/dist/assets/index.css +56 -12
  18. inspect_ai/_view/www/dist/assets/index.js +905 -751
  19. inspect_ai/_view/www/log-schema.json +337 -2
  20. inspect_ai/_view/www/node_modules/flatted/python/flatted.py +149 -0
  21. inspect_ai/_view/www/node_modules/flatted/python/test.py +63 -0
  22. inspect_ai/_view/www/src/appearance/icons.ts +3 -1
  23. inspect_ai/_view/www/src/metadata/RenderedContent.tsx +0 -1
  24. inspect_ai/_view/www/src/samples/SampleDisplay.module.css +9 -1
  25. inspect_ai/_view/www/src/samples/SampleDisplay.tsx +28 -1
  26. inspect_ai/_view/www/src/samples/SampleSummaryView.module.css +4 -0
  27. inspect_ai/_view/www/src/samples/SampleSummaryView.tsx +23 -2
  28. inspect_ai/_view/www/src/samples/descriptor/score/ObjectScoreDescriptor.tsx +1 -1
  29. inspect_ai/_view/www/src/samples/transcript/SampleLimitEventView.tsx +4 -0
  30. inspect_ai/_view/www/src/samples/transcript/SandboxEventView.module.css +32 -0
  31. inspect_ai/_view/www/src/samples/transcript/SandboxEventView.tsx +152 -0
  32. inspect_ai/_view/www/src/samples/transcript/StepEventView.tsx +9 -2
  33. inspect_ai/_view/www/src/samples/transcript/TranscriptView.tsx +19 -1
  34. inspect_ai/_view/www/src/samples/transcript/event/EventPanel.tsx +6 -3
  35. inspect_ai/_view/www/src/samples/transcript/types.ts +3 -1
  36. inspect_ai/_view/www/src/types/log.d.ts +188 -108
  37. inspect_ai/_view/www/src/utils/format.ts +7 -4
  38. inspect_ai/_view/www/src/workspace/WorkSpaceView.tsx +9 -6
  39. inspect_ai/log/__init__.py +2 -0
  40. inspect_ai/log/_condense.py +1 -0
  41. inspect_ai/log/_log.py +72 -12
  42. inspect_ai/log/_samples.py +5 -5
  43. inspect_ai/log/_transcript.py +31 -1
  44. inspect_ai/model/_call_tools.py +1 -1
  45. inspect_ai/model/_conversation.py +1 -1
  46. inspect_ai/model/_model.py +35 -16
  47. inspect_ai/model/_model_call.py +10 -3
  48. inspect_ai/model/_providers/anthropic.py +13 -2
  49. inspect_ai/model/_providers/bedrock.py +7 -0
  50. inspect_ai/model/_providers/cloudflare.py +20 -7
  51. inspect_ai/model/_providers/google.py +358 -302
  52. inspect_ai/model/_providers/groq.py +57 -23
  53. inspect_ai/model/_providers/hf.py +6 -0
  54. inspect_ai/model/_providers/mistral.py +81 -52
  55. inspect_ai/model/_providers/openai.py +9 -0
  56. inspect_ai/model/_providers/providers.py +6 -6
  57. inspect_ai/model/_providers/util/tracker.py +92 -0
  58. inspect_ai/model/_providers/vllm.py +13 -5
  59. inspect_ai/solver/_basic_agent.py +1 -3
  60. inspect_ai/solver/_bridge/patch.py +0 -2
  61. inspect_ai/solver/_limit.py +4 -4
  62. inspect_ai/solver/_plan.py +3 -3
  63. inspect_ai/solver/_solver.py +3 -0
  64. inspect_ai/solver/_task_state.py +10 -1
  65. inspect_ai/tool/_tools/_web_search.py +3 -3
  66. inspect_ai/util/_concurrency.py +14 -8
  67. inspect_ai/util/_sandbox/context.py +15 -0
  68. inspect_ai/util/_sandbox/docker/cleanup.py +8 -3
  69. inspect_ai/util/_sandbox/docker/compose.py +5 -9
  70. inspect_ai/util/_sandbox/docker/docker.py +20 -6
  71. inspect_ai/util/_sandbox/docker/util.py +10 -1
  72. inspect_ai/util/_sandbox/environment.py +32 -1
  73. inspect_ai/util/_sandbox/events.py +149 -0
  74. inspect_ai/util/_sandbox/local.py +3 -3
  75. inspect_ai/util/_sandbox/self_check.py +2 -1
  76. inspect_ai/util/_subprocess.py +4 -1
  77. {inspect_ai-0.3.68.dist-info → inspect_ai-0.3.70.dist-info}/METADATA +5 -5
  78. {inspect_ai-0.3.68.dist-info → inspect_ai-0.3.70.dist-info}/RECORD +82 -74
  79. {inspect_ai-0.3.68.dist-info → inspect_ai-0.3.70.dist-info}/LICENSE +0 -0
  80. {inspect_ai-0.3.68.dist-info → inspect_ai-0.3.70.dist-info}/WHEEL +0 -0
  81. {inspect_ai-0.3.68.dist-info → inspect_ai-0.3.70.dist-info}/entry_points.txt +0 -0
  82. {inspect_ai-0.3.68.dist-info → inspect_ai-0.3.70.dist-info}/top_level.txt +0 -0
--- a/inspect_ai/model/_providers/google.py
+++ b/inspect_ai/model/_providers/google.py
@@ -2,103 +2,96 @@ import asyncio
 import functools
 import hashlib
 import json
+import os
 from copy import copy
 from io import BytesIO
 from logging import getLogger
-from typing import Any, MutableSequence, cast
+from typing import Any
 
-import proto  # type: ignore
-from google.ai.generativelanguage import (
-    Blob,
+# SDK Docs: https://googleapis.github.io/python-genai/
+from google.genai import Client  # type: ignore
+from google.genai.errors import APIError, ClientError  # type: ignore
+from google.genai.types import (  # type: ignore
     Candidate,
-    FunctionCall,
+    Content,
+    File,
+    FinishReason,
     FunctionCallingConfig,
     FunctionDeclaration,
     FunctionResponse,
+    GenerateContentConfig,
+    GenerateContentResponse,
+    GenerateContentResponsePromptFeedback,
+    GenerateContentResponseUsageMetadata,
+    GenerationConfig,
+    HarmBlockThreshold,
+    HarmCategory,
     Part,
+    SafetySetting,
+    SafetySettingDict,
     Schema,
+    Tool,
     ToolConfig,
     Type,
 )
-from google.api_core.exceptions import (
-    GatewayTimeout,
-    InternalServerError,
-    InvalidArgument,
-    ServiceUnavailable,
-    TooManyRequests,
-)
-from google.api_core.retry.retry_base import if_transient_error
-from google.generativeai.client import configure
-from google.generativeai.files import get_file, upload_file
-from google.generativeai.generative_models import GenerativeModel
-from google.generativeai.types import (
-    ContentDict,
-    GenerationConfig,
-    PartDict,
-    PartType,
-    Tool,
-)
-from google.generativeai.types.file_types import File
-from google.generativeai.types.generation_types import AsyncGenerateContentResponse
-from google.generativeai.types.safety_types import (
-    EasySafetySettingDict,
-    HarmBlockThreshold,
-    HarmCategory,
-)
-from google.protobuf.json_format import MessageToDict, ParseDict
-from google.protobuf.struct_pb2 import Struct
 from pydantic import JsonValue
 from typing_extensions import override
 
 from inspect_ai._util.constants import BASE_64_DATA_REMOVED, NO_CONTENT
+from inspect_ai._util.content import Content as InspectContent
 from inspect_ai._util.content import (
-    Content,
     ContentAudio,
     ContentImage,
     ContentText,
     ContentVideo,
 )
+from inspect_ai._util.error import PrerequisiteError
 from inspect_ai._util.images import file_as_data
 from inspect_ai._util.kvstore import inspect_kvstore
 from inspect_ai._util.trace import trace_message
-from inspect_ai.tool import ToolCall, ToolChoice, ToolInfo, ToolParam, ToolParams
-
-from .._chat_message import (
+from inspect_ai.model import (
+    ChatCompletionChoice,
     ChatMessage,
     ChatMessageAssistant,
-    ChatMessageSystem,
     ChatMessageTool,
     ChatMessageUser,
-)
-from .._generate_config import GenerateConfig
-from .._model import ModelAPI
-from .._model_call import ModelCall
-from .._model_output import (
-    ChatCompletionChoice,
+    GenerateConfig,
     Logprob,
     Logprobs,
+    ModelAPI,
     ModelOutput,
     ModelUsage,
     StopReason,
     TopLogprob,
 )
-from .util import model_base_url
+from inspect_ai.model._model_call import ModelCall
+from inspect_ai.model._providers.util import model_base_url
+from inspect_ai.tool import (
+    ToolCall,
+    ToolChoice,
+    ToolFunction,
+    ToolInfo,
+    ToolParam,
+    ToolParams,
+)
 
 logger = getLogger(__name__)
 
-SAFETY_SETTINGS = "safety_settings"
 
-DEFAULT_SAFETY_SETTINGS: EasySafetySettingDict = {
-    HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
-    HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
-    HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
+GOOGLE_API_KEY = "GOOGLE_API_KEY"
+VERTEX_API_KEY = "VERTEX_API_KEY"
+
+SAFETY_SETTINGS = "safety_settings"
+DEFAULT_SAFETY_SETTINGS = {
+    HarmCategory.HARM_CATEGORY_CIVIC_INTEGRITY: HarmBlockThreshold.BLOCK_NONE,
     HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
+    HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
+    HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
+    HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
 }
 
-GOOGLE_API_KEY = "GOOGLE_API_KEY"
-
 
-class GoogleAPI(ModelAPI):
+class GoogleGenAIAPI(ModelAPI):
     def __init__(
         self,
         model_name: str,
@@ -111,11 +104,11 @@ class GoogleAPI(ModelAPI):
             model_name=model_name,
             base_url=base_url,
             api_key=api_key,
-            api_key_vars=[GOOGLE_API_KEY],
+            api_key_vars=[GOOGLE_API_KEY, VERTEX_API_KEY],
             config=config,
         )
 
-        # pick out vertex safety settings and merge against default
+        # pick out user-provided safety settings and merge against default
         self.safety_settings = DEFAULT_SAFETY_SETTINGS.copy()
         if SAFETY_SETTINGS in model_args:
            self.safety_settings.update(
@@ -123,22 +116,79 @@
             )
             del model_args[SAFETY_SETTINGS]
 
-        # configure genai client
-        base_url = model_base_url(base_url, "GOOGLE_BASE_URL")
-        configure(
+        # extract any service prefix from model name
+        parts = model_name.split("/")
+        if len(parts) > 1:
+            self.service: str | None = parts[0]
+            model_name = "/".join(parts[1:])
+        else:
+            self.service = None
+
+        # vertex can also be forced by the GOOGLE_GENAI_USE_VERTEXAI flag
+        if self.service is None:
+            if os.environ.get("GOOGLE_GENAI_USE_VERTEXAI", "").lower() == "true":
+                self.service = "vertex"
+
+        # ensure we haven't specified an invalid service
+        if self.service is not None and self.service != "vertex":
+            raise RuntimeError(
+                f"Invalid service name for google: {self.service}. "
+                + "Currently 'vertex' is the only supported service."
+            )
+
+        # handle auth (vertex or standard google api key)
+        if self.is_vertex():
+            # see if we are running in express mode (propagate api key if we are)
+            # https://cloud.google.com/vertex-ai/generative-ai/docs/start/express-mode/overview
+            vertex_api_key = os.environ.get(VERTEX_API_KEY, None)
+            if vertex_api_key and not self.api_key:
+                self.api_key = vertex_api_key
+
+            # When not using express mode the GOOGLE_CLOUD_PROJECT and GOOGLE_CLOUD_LOCATION
+            # environment variables should be set, OR the 'project' and 'location' should be
+            # passed within the model_args.
+            # https://cloud.google.com/vertex-ai/generative-ai/docs/gemini-v2
+            if not vertex_api_key:
+                if not os.environ.get(
+                    "GOOGLE_CLOUD_PROJECT", None
+                ) and not model_args.get("project", None):
+                    raise PrerequisiteError(
+                        "Google provider requires either the GOOGLE_CLOUD_PROJECT environment variable "
+                        + "or the 'project' custom model arg (-M) when running against vertex."
+                    )
+                if not os.environ.get(
+                    "GOOGLE_CLOUD_LOCATION", None
+                ) and not model_args.get("location", None):
+                    raise PrerequisiteError(
+                        "Google provider requires either the GOOGLE_CLOUD_LOCATION environment variable "
+                        + "or the 'location' custom model arg (-M) when running against vertex."
+                    )
+
+        # normal google endpoint
+        else:
+            # read api key from env
+            if not self.api_key:
+                self.api_key = os.environ.get(GOOGLE_API_KEY, None)
+
+            # custom base_url
+            base_url = model_base_url(base_url, "GOOGLE_BASE_URL")
+
+        # create client
+        self.client = Client(
+            vertexai=self.is_vertex(),
             api_key=self.api_key,
-            client_options=dict(api_endpoint=base_url),
+            http_options={"base_url": base_url},
             **model_args,
         )
 
-        # create model
-        self.model = GenerativeModel(self.model_name)
-
     @override
     async def close(self) -> None:
         # GenerativeModel uses a cached/shared client so there is no 'close'
         pass
 
+    def is_vertex(self) -> bool:
+        return self.service == "vertex"
+
     async def generate(
         self,
         input: list[ChatMessage],
@@ -146,7 +196,11 @@ class GoogleAPI(ModelAPI):
         tool_choice: ToolChoice,
         config: GenerateConfig,
     ) -> ModelOutput | tuple[ModelOutput | Exception, ModelCall]:
-        parameters = GenerationConfig(
+        # Create google-genai types.
+        gemini_contents = await as_chat_messages(self.client, input)
+        gemini_tools = chat_tools(tools) if len(tools) > 0 else None
+        gemini_tool_config = chat_tool_config(tool_choice) if len(tools) > 0 else None
+        parameters = GenerateContentConfig(
             temperature=config.temperature,
             top_p=config.top_p,
             top_k=config.top_k,
@@ -155,21 +209,19 @@
             candidate_count=config.num_choices,
             presence_penalty=config.presence_penalty,
             frequency_penalty=config.frequency_penalty,
+            safety_settings=safety_settings_to_list(self.safety_settings),
+            tools=gemini_tools,
+            tool_config=gemini_tool_config,
+            system_instruction=await extract_system_message_as_parts(
+                self.client, input
+            ),
         )
 
-        # google-native messages
-        contents = await as_chat_messages(input)
-
-        # tools
-        gemini_tools = chat_tools(tools) if len(tools) > 0 else None
-        gemini_tool_config = chat_tool_config(tool_choice) if len(tools) > 0 else None
-
-        # response for ModelCall
-        response: AsyncGenerateContentResponse | None = None
+        response: GenerateContentResponse | None = None
 
         def model_call() -> ModelCall:
             return build_model_call(
-                contents=contents,
+                contents=gemini_contents,
                 safety_settings=self.safety_settings,
                 generation_config=parameters,
                 tools=gemini_tools,
@@ -177,164 +229,149 @@
                 response=response,
             )
 
+        # TODO: would need to monkey patch AuthorizedSession.request
+
         try:
-            response = await self.model.generate_content_async(
-                contents=contents,
-                safety_settings=self.safety_settings,
-                generation_config=parameters,
-                tools=gemini_tools,
-                tool_config=gemini_tool_config,
+            response = await self.client.aio.models.generate_content(
+                model=self.model_name,
+                contents=gemini_contents,
+                config=parameters,
             )
+        except ClientError as ex:
+            return self.handle_client_error(ex), model_call()
 
-        except InvalidArgument as ex:
-            return self.handle_invalid_argument(ex), model_call()
-
-        # build output
         output = ModelOutput(
             model=self.model_name,
-            choices=completion_choices_from_candidates(response.candidates),
-            usage=ModelUsage(
-                input_tokens=response.usage_metadata.prompt_token_count,
-                output_tokens=response.usage_metadata.candidates_token_count,
-                total_tokens=response.usage_metadata.total_token_count,
-            ),
+            choices=completion_choices_from_candidates(response),
+            usage=usage_metadata_to_model_usage(response.usage_metadata),
         )
 
-        # return
         return output, model_call()
 
-    def handle_invalid_argument(self, ex: InvalidArgument) -> ModelOutput | Exception:
-        if "size exceeds the limit" in ex.message.lower():
-            return ModelOutput.from_content(
-                model=self.model_name, content=ex.message, stop_reason="model_length"
-            )
-        else:
-            return ex
-
     @override
     def is_rate_limit(self, ex: BaseException) -> bool:
-        return isinstance(
-            ex,
-            TooManyRequests | InternalServerError | ServiceUnavailable | GatewayTimeout,
-        )
+        return isinstance(ex, APIError) and ex.code in (429, 500, 503, 504)
 
     @override
     def connection_key(self) -> str:
         """Scope for enforcing max_connections (could also use endpoint)."""
         return self.model_name
 
+    def handle_client_error(self, ex: ClientError) -> ModelOutput | Exception:
+        if (
+            ex.code == 400
+            and ex.message
+            and (
+                "maximum number of tokens" in ex.message
+                or "size exceeds the limit" in ex.message
+            )
+        ):
+            return ModelOutput.from_content(
+                self.model_name, content=ex.message, stop_reason="model_length"
+            )
+        else:
+            raise ex
+
+
+def safety_settings_to_list(safety_settings: SafetySettingDict) -> list[SafetySetting]:
+    return [
+        SafetySetting(
+            category=category,
+            threshold=threshold,
+        )
+        for category, threshold in safety_settings.items()
+    ]
+
 
 def build_model_call(
-    contents: list[ContentDict],
+    contents: list[Content],
     generation_config: GenerationConfig,
-    safety_settings: EasySafetySettingDict,
+    safety_settings: SafetySettingDict,
     tools: list[Tool] | None,
     tool_config: ToolConfig | None,
-    response: AsyncGenerateContentResponse | None,
+    response: GenerateContentResponse | None,
 ) -> ModelCall:
     return ModelCall.create(
         request=dict(
-            contents=[model_call_content(content) for content in contents],
+            contents=contents,
             generation_config=generation_config,
             safety_settings=safety_settings,
-            tools=[MessageToDict(tool._proto._pb) for tool in tools]
-            if tools is not None
-            else None,
-            tool_config=MessageToDict(tool_config._pb)
-            if tool_config is not None
-            else None,
+            tools=tools if tools is not None else None,
+            tool_config=tool_config if tool_config is not None else None,
         ),
-        response=response.to_dict() if response is not None else {},  # type: ignore[no-untyped-call]
+        response=response if response is not None else {},
        filter=model_call_filter,
     )
 
 
 def model_call_filter(key: JsonValue | None, value: JsonValue) -> JsonValue:
-    # remove images from raw api call
     if key == "inline_data" and isinstance(value, dict) and "data" in value:
         value = copy(value)
         value.update(data=BASE_64_DATA_REMOVED)
     return value
 
 
-def model_call_content(content: ContentDict) -> ContentDict:
-    return ContentDict(
-        role=content["role"], parts=[model_call_part(part) for part in content["parts"]]
-    )
-
-
-def model_call_part(part: PartType) -> PartType:
-    if isinstance(part, proto.Message):
-        return cast(PartDict, MessageToDict(part._pb))
-    elif isinstance(part, dict):
-        part = part.copy()
-        keys = list(part.keys())
-        for key in keys:
-            part[key] = model_call_part(part[key])  # type: ignore[literal-required]
-        return part
-    else:
-        return part
-
-
-async def as_chat_messages(messages: list[ChatMessage]) -> list[ContentDict]:
-    # google does not support system messages so filter them out to start with
-    system_messages = [message for message in messages if message.role == "system"]
+async def as_chat_messages(
+    client: Client, messages: list[ChatMessage]
+) -> list[Content]:
+    # There is no "system" role in the `google-genai` package. Instead, system messages
+    # are included in the `GenerateContentConfig` as a `system_instruction`. Strip any
+    # system messages out.
     supported_messages = [message for message in messages if message.role != "system"]
 
     # build google chat messages
-    chat_messages = [await content_dict(message) for message in supported_messages]
-
-    # we want the system messages to be prepended to the first user message
-    # (if there is no first user message then prepend one)
-    prepend_system_messages(chat_messages, system_messages)
+    chat_messages = [await content(client, message) for message in supported_messages]
 
     # combine consecutive tool messages
-    chat_messages = functools.reduce(consective_tool_message_reducer, chat_messages, [])
+    chat_messages = functools.reduce(
+        consecutive_tool_message_reducer, chat_messages, []
+    )
 
     # return messages
     return chat_messages
 
 
-def consective_tool_message_reducer(
-    messages: list[ContentDict],
-    message: ContentDict,
-) -> list[ContentDict]:
+def consecutive_tool_message_reducer(
+    messages: list[Content],
+    message: Content,
+) -> list[Content]:
     if (
-        message["role"] == "function"
+        message.role == "function"
         and len(messages) > 0
-        and messages[-1]["role"] == "function"
+        and messages[-1].role == "function"
     ):
-        messages[-1] = ContentDict(
-            role="function", parts=messages[-1]["parts"] + message["parts"]
+        messages[-1] = Content(
+            role="function", parts=messages[-1].parts + message.parts
         )
     else:
         messages.append(message)
     return messages
 
 
-async def content_dict(
+async def content(
+    client: Client,
     message: ChatMessageUser | ChatMessageAssistant | ChatMessageTool,
-) -> ContentDict:
+) -> Content:
     if isinstance(message, ChatMessageUser):
-        return ContentDict(
+        if isinstance(message.content, str):
+            return Content(
+                role="user", parts=[await content_part(client, message.content)]
+            )
+        return Content(
             role="user",
             parts=(
-                [message.content or NO_CONTENT]
-                if isinstance(message.content, str)
-                else [await content_part(content) for content in message.content]
+                [await content_part(client, content) for content in message.content]
             ),
         )
     elif isinstance(message, ChatMessageAssistant):
-        content_parts: list[PartType] = []
+        content_parts: list[Part] = []
         # tool call parts
         if message.tool_calls is not None:
             content_parts.extend(
                 [
-                    Part(
-                        function_call=FunctionCall(
-                            name=tool_call.function,
-                            args=dict_to_struct(tool_call.arguments),
-                        )
+                    Part.from_function_call(
+                        name=tool_call.function,
+                        args=tool_call.arguments,
                     )
                     for tool_call in message.tool_calls
                 ]
345
382
  content_parts.append(Part(text=message.content or NO_CONTENT))
346
383
  else:
347
384
  content_parts.extend(
348
- [await content_part(content) for content in message.content]
385
+ [await content_part(client, content) for content in message.content]
349
386
  )
350
387
 
351
388
  # return parts
352
- return ContentDict(role="model", parts=content_parts)
389
+ return Content(role="model", parts=content_parts)
353
390
 
354
391
  elif isinstance(message, ChatMessageTool):
355
392
  response = FunctionResponse(
356
393
  name=message.tool_call_id,
357
- response=ParseDict(
358
- js_dict={
359
- "content": (
360
- message.error.message
361
- if message.error is not None
362
- else message.text
363
- )
364
- },
365
- message=Struct(),
366
- ),
394
+ response={
395
+ "content": (
396
+ message.error.message if message.error is not None else message.text
397
+ )
398
+ },
367
399
  )
368
- return ContentDict(role="function", parts=[Part(function_response=response)])
369
-
400
+ return Content(role="function", parts=[Part(function_response=response)])
370
401
 
371
- def dict_to_struct(x: dict[str, Any]) -> Struct:
372
- struct = Struct()
373
- struct.update(x)
374
- return struct
375
402
 
376
-
377
- async def content_part(content: Content | str) -> PartType:
403
+ async def content_part(client: Client, content: InspectContent | str) -> Part:
378
404
  if isinstance(content, str):
379
- return content or NO_CONTENT
405
+ return Part.from_text(text=content or NO_CONTENT)
380
406
  elif isinstance(content, ContentText):
381
- return content.text or NO_CONTENT
407
+ return Part.from_text(text=content.text or NO_CONTENT)
382
408
  else:
383
- return await chat_content_to_part(content)
409
+ return await chat_content_to_part(client, content)
384
410
 
385
411
 
386
412
  async def chat_content_to_part(
413
+ client: Client,
387
414
  content: ContentImage | ContentAudio | ContentVideo,
388
- ) -> PartType:
415
+ ) -> Part:
389
416
  if isinstance(content, ContentImage):
390
417
  content_bytes, mime_type = await file_as_data(content.image)
391
- return Blob(mime_type=mime_type, data=content_bytes)
392
- else:
393
- return await file_for_content(content)
394
-
395
-
396
- def prepend_system_messages(
397
- messages: list[ContentDict], system_messages: list[ChatMessageSystem]
398
- ) -> None:
399
- # create system_parts
400
- system_parts: list[PartType] = [
401
- Part(text=message.text) for message in system_messages
402
- ]
403
-
404
- # we want the system messages to be prepended to the first user message
405
- # (if there is no first user message then prepend one)
406
- if len(messages) > 0 and messages[0].get("role") == "user":
407
- messages[0]["parts"] = system_parts + messages[0].get("parts", [])
418
+ return Part.from_bytes(mime_type=mime_type, data=content_bytes)
408
419
  else:
409
- messages.insert(0, ContentDict(role="user", parts=system_parts))
420
+ return await file_for_content(client, content)
421
+
422
+
423
+ async def extract_system_message_as_parts(
424
+ client: Client,
425
+ messages: list[ChatMessage],
426
+ ) -> list[Part] | None:
427
+ system_parts: list[Part] = []
428
+ for message in messages:
429
+ if message.role == "system":
430
+ content = message.content
431
+ if isinstance(content, str):
432
+ system_parts.append(Part.from_text(text=content))
433
+ elif isinstance(content, list): # list[InspectContent]
434
+ system_parts.extend(
435
+ [await content_part(client, content) for content in content]
436
+ )
437
+ else:
438
+ raise ValueError(f"Unsupported system message content: {content}")
439
+ # google-genai raises "ValueError: content is required." if the list is empty.
440
+ return system_parts or None
410
441
 
411
442
 
412
443
  def chat_tools(tools: list[ToolInfo]) -> list[Tool]:
@@ -424,8 +455,6 @@ def chat_tools(tools: list[ToolInfo]) -> list[Tool]:
 
 
 # https://ai.google.dev/gemini-api/tutorials/extract_structured_data#define_the_schema
-
-
 def schema_from_param(param: ToolParam | ToolParams, nullable: bool = False) -> Schema:
     if isinstance(param, ToolParams):
         param = ToolParam(
@@ -461,7 +490,7 @@ def schema_from_param(param: ToolParam | ToolParams, nullable: bool = False) -> Schema:
             description=param.description,
             properties={k: schema_from_param(v) for k, v in param.properties.items()}
             if param.properties is not None
-            else None,
+            else {},
             required=param.required,
             nullable=nullable,
         )
478
507
 
479
508
 
480
509
  def chat_tool_config(tool_choice: ToolChoice) -> ToolConfig:
481
- # NOTE: Google seems to sporadically return errors when being
482
- # passed a FunctionCallingConfig with mode="ANY". therefore,
483
- # we 'correct' this to "AUTO" to prevent the errors
484
- mode = "AUTO"
485
- if tool_choice == "none":
486
- mode = "NONE"
487
- return ToolConfig(function_calling_config=FunctionCallingConfig(mode=mode))
488
-
489
- # This is the 'correct' implementation if Google wasn't returning
490
- # errors for mode="ANY". we can test whether this is working properly
491
- # by commenting this back in and running pytest -k google_tools
492
- #
493
- # if isinstance(tool_choice, ToolFunction):
494
- # return ToolConfig(
495
- # function_calling_config=FunctionCallingConfig(
496
- # mode="ANY", allowed_function_names=[tool_choice.name]
497
- # )
498
- # )
499
- # else:
500
- # return ToolConfig(
501
- # function_calling_config=FunctionCallingConfig(mode=tool_choice.upper())
502
- # )
510
+ if isinstance(tool_choice, ToolFunction):
511
+ return ToolConfig(
512
+ function_calling_config=FunctionCallingConfig(
513
+ mode="ANY", allowed_function_names=[tool_choice.name]
514
+ )
515
+ )
516
+ else:
517
+ return ToolConfig(
518
+ function_calling_config=FunctionCallingConfig(mode=tool_choice.upper())
519
+ )
503
520
 
504
521
 
505
522
  def completion_choice_from_candidate(candidate: Candidate) -> ChatCompletionChoice:
506
523
  # check for completion text
507
- content = " ".join(
508
- [part.text for part in candidate.content.parts if part.text is not None]
509
- )
524
+ content = ""
525
+ # content can be None when the finish_reason is SAFETY
526
+ if candidate.content is not None:
527
+ content = " ".join(
528
+ [
529
+ part.text
530
+ for part in candidate.content.parts
531
+ if part.text is not None and candidate.content is not None
532
+ ]
533
+ )
534
+
535
+ # split reasoning
536
+ reasoning, content = split_reasoning(content)
510
537
 
511
538
  # now tool calls
512
539
  tool_calls: list[ToolCall] = []
513
- for part in candidate.content.parts:
514
- if part.function_call:
515
- function_call = MessageToDict(getattr(part.function_call, "_pb"))
516
- tool_calls.append(
517
- ToolCall(
518
- type="function",
519
- id=function_call["name"],
520
- function=function_call["name"],
521
- arguments=function_call["args"],
540
+ if candidate.content is not None and candidate.content.parts is not None:
541
+ for part in candidate.content.parts:
542
+ if part.function_call:
543
+ tool_calls.append(
544
+ ToolCall(
545
+ type="function",
546
+ id=part.function_call.name,
547
+ function=part.function_call.name,
548
+ arguments=part.function_call.args,
549
+ )
522
550
  )
523
- )
524
551
 
525
552
  # stop reason
526
- stop_reason = candidate_stop_reason(candidate.finish_reason)
553
+ stop_reason = finish_reason_to_stop_reason(candidate.finish_reason)
527
554
 
528
- # build choide
555
+ # build choice
529
556
  choice = ChatCompletionChoice(
530
557
  message=ChatMessageAssistant(
531
558
  content=content,
559
+ reasoning=reasoning,
532
560
  tool_calls=tool_calls if len(tool_calls) > 0 else None,
533
561
  source="generate",
534
562
  ),
@@ -558,111 +586,144 @@ def completion_choice_from_candidate(candidate: Candidate) -> ChatCompletionChoi
558
586
 
559
587
 
560
588
  def completion_choices_from_candidates(
561
- candidates: MutableSequence[Candidate],
589
+ response: GenerateContentResponse,
562
590
  ) -> list[ChatCompletionChoice]:
591
+ candidates = response.candidates
563
592
  if candidates:
564
593
  candidates_list = sorted(candidates, key=lambda c: c.index)
565
594
  return [
566
595
  completion_choice_from_candidate(candidate) for candidate in candidates_list
567
596
  ]
568
- else:
597
+ elif response.prompt_feedback:
569
598
  return [
570
599
  ChatCompletionChoice(
571
600
  message=ChatMessageAssistant(
572
- content="I was unable to generate a response.",
601
+ content=prompt_feedback_to_content(response.prompt_feedback),
573
602
  source="generate",
574
603
  ),
575
- stop_reason="unknown",
604
+ stop_reason="content_filter",
576
605
  )
577
606
  ]
607
+ else:
608
+ raise RuntimeError(
609
+ "Google response includes no completion candidates and no block reason: "
610
+ + f"{response.model_dump_json(indent=2)}"
611
+ )
578
612
 
579
613
 
580
- # google doesn't export FinishReason (it's in a sub-namespace with a beta
581
- # designation that seems destined to change, so we vendor the enum here)
582
- class FinishReason:
583
- FINISH_REASON_UNSPECIFIED = 0
584
- STOP = 1
585
- MAX_TOKENS = 2
586
- SAFETY = 3
587
- RECITATION = 4
588
- OTHER = 5
614
+ def split_reasoning(content: str) -> tuple[str | None, str]:
615
+ separator = "\nFinal Answer: "
616
+ if separator in content:
617
+ parts = content.split(separator, 1) # dplit only on first occurrence
618
+ return parts[0].strip(), separator.lstrip() + parts[1].strip()
619
+ else:
620
+ return None, content.strip()
621
+
589
622
 
623
+ def prompt_feedback_to_content(
624
+ feedback: GenerateContentResponsePromptFeedback,
625
+ ) -> str:
626
+ content: list[str] = []
627
+ block_reason = str(feedback.block_reason) if feedback.block_reason else "UNKNOWN"
628
+ content.append(f"BLOCKED: {block_reason}")
590
629
 
591
- def candidate_stop_reason(finish_reason: FinishReason) -> StopReason:
630
+ if feedback.block_reason_message is not None:
631
+ content.append(feedback.block_reason_message)
632
+ if feedback.safety_ratings is not None:
633
+ content.extend(
634
+ [rating.model_dump_json(indent=2) for rating in feedback.safety_ratings]
635
+ )
636
+ return "\n".join(content)
637
+
638
+
639
+ def usage_metadata_to_model_usage(
640
+ metadata: GenerateContentResponseUsageMetadata,
641
+ ) -> ModelUsage | None:
642
+ if metadata is None:
643
+ return None
644
+ return ModelUsage(
645
+ input_tokens=metadata.prompt_token_count or 0,
646
+ output_tokens=metadata.candidates_token_count or 0,
647
+ total_tokens=metadata.total_token_count or 0,
648
+ )
649
+
650
+
651
+ def finish_reason_to_stop_reason(finish_reason: FinishReason) -> StopReason:
592
652
  match finish_reason:
593
653
  case FinishReason.STOP:
594
654
  return "stop"
595
655
  case FinishReason.MAX_TOKENS:
596
656
  return "max_tokens"
597
- case FinishReason.SAFETY | FinishReason.RECITATION:
657
+ case (
658
+ FinishReason.SAFETY
659
+ | FinishReason.RECITATION
660
+ | FinishReason.BLOCKLIST
661
+ | FinishReason.PROHIBITED_CONTENT
662
+ | FinishReason.SPII
663
+ ):
598
664
  return "content_filter"
599
665
  case _:
600
666
  return "unknown"
601
667
 
602
668
 
603
- def gapi_should_retry(ex: BaseException) -> bool:
604
- if isinstance(ex, Exception):
605
- return if_transient_error(ex)
606
- else:
607
- return False
608
-
609
-
610
669
  def parse_safety_settings(
611
670
  safety_settings: Any,
612
- ) -> EasySafetySettingDict:
671
+ ) -> dict[HarmCategory, HarmBlockThreshold]:
613
672
  # ensure we have a dict
614
673
  if isinstance(safety_settings, str):
615
674
  safety_settings = json.loads(safety_settings)
616
675
  if not isinstance(safety_settings, dict):
617
676
  raise ValueError(f"{SAFETY_SETTINGS} must be dictionary.")
618
677
 
619
- parsed_settings: EasySafetySettingDict = {}
678
+ parsed_settings: dict[HarmCategory, HarmBlockThreshold] = {}
620
679
  for key, value in safety_settings.items():
621
- if isinstance(key, str):
622
- key = str_to_harm_category(key)
623
- if not isinstance(key, HarmCategory):
680
+ if not isinstance(key, str):
624
681
  raise ValueError(f"Unexpected type for harm category: {key}")
625
- if isinstance(value, str):
626
- value = str_to_harm_block_threshold(value)
627
- if not isinstance(value, HarmBlockThreshold):
682
+ if not isinstance(value, str):
628
683
  raise ValueError(f"Unexpected type for harm block threshold: {value}")
629
-
684
+ key = str_to_harm_category(key)
685
+ value = str_to_harm_block_threshold(value)
630
686
  parsed_settings[key] = value
631
-
632
687
  return parsed_settings
633
688
 
634
689
 
635
- def str_to_harm_category(category: str) -> int:
690
+ def str_to_harm_category(category: str) -> HarmCategory:
636
691
  category = category.upper()
692
+ # `in` instead of `==` to allow users to pass in short version e.g. "HARASSMENT" or
693
+ # long version e.g. "HARM_CATEGORY_HARASSMENT" strings.
694
+ if "CIVIC_INTEGRITY" in category:
695
+ return HarmCategory.HARM_CATEGORY_CIVIC_INTEGRITY
696
+ if "DANGEROUS_CONTENT" in category:
697
+ return HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT
698
+ if "HATE_SPEECH" in category:
699
+ return HarmCategory.HARM_CATEGORY_HATE_SPEECH
637
700
  if "HARASSMENT" in category:
638
- return cast(int, HarmCategory.HARM_CATEGORY_HARASSMENT)
639
- elif "HATE_SPEECH" in category:
640
- return cast(int, HarmCategory.HARM_CATEGORY_HATE_SPEECH)
641
- elif "SEXUALLY_EXPLICIT" in category:
642
- return cast(int, HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT)
643
- elif "DANGEROUS_CONTENT" in category:
644
- return cast(int, HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT)
645
- else:
646
- # NOTE: Although there is an "UNSPECIFIED" category, in the
647
- # documentation, the API does not accept it.
648
- raise ValueError(f"Unknown HarmCategory: {category}")
701
+ return HarmCategory.HARM_CATEGORY_HARASSMENT
702
+ if "SEXUALLY_EXPLICIT" in category:
703
+ return HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT
704
+ if "UNSPECIFIED" in category:
705
+ return HarmCategory.HARM_CATEGORY_UNSPECIFIED
706
+ raise ValueError(f"Unknown HarmCategory: {category}")
649
707
 
650
708
 
651
- def str_to_harm_block_threshold(threshold: str) -> int:
709
+ def str_to_harm_block_threshold(threshold: str) -> HarmBlockThreshold:
652
710
  threshold = threshold.upper()
653
711
  if "LOW" in threshold:
654
712
  return HarmBlockThreshold.BLOCK_LOW_AND_ABOVE
655
- elif "MEDIUM" in threshold:
713
+ if "MEDIUM" in threshold:
656
714
  return HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE
657
- elif "HIGH" in threshold:
715
+ if "HIGH" in threshold:
658
716
  return HarmBlockThreshold.BLOCK_ONLY_HIGH
659
- elif "NONE" in threshold:
717
+ if "NONE" in threshold:
660
718
  return HarmBlockThreshold.BLOCK_NONE
661
- else:
662
- raise ValueError(f"Unknown HarmBlockThreshold: {threshold}")
719
+ if "OFF" in threshold:
720
+ return HarmBlockThreshold.OFF
721
+ raise ValueError(f"Unknown HarmBlockThreshold: {threshold}")
663
722
 
664
723
 
665
- async def file_for_content(content: ContentAudio | ContentVideo) -> File:
724
+ async def file_for_content(
725
+ client: Client, content: ContentAudio | ContentVideo
726
+ ) -> File:
666
727
  # helper to write trace messages
667
728
  def trace(message: str) -> None:
668
729
  trace_message(logger, "Google Files", message)
@@ -674,7 +735,6 @@ async def file_for_content(content: ContentAudio | ContentVideo) -> File:
674
735
  file = content.video
675
736
  content_bytes, mime_type = await file_as_data(file)
676
737
  content_sha256 = hashlib.sha256(content_bytes).hexdigest()
677
-
678
738
  # we cache uploads for re-use, open the db where we track that
679
739
  # (track up to 1 million previous uploads)
680
740
  with inspect_kvstore("google_files", 1000000) as files_db:
@@ -682,7 +742,7 @@ async def file_for_content(content: ContentAudio | ContentVideo) -> File:
682
742
  uploaded_file = files_db.get(content_sha256)
683
743
  if uploaded_file:
684
744
  try:
685
- upload = get_file(uploaded_file)
745
+ upload: File = client.files.get(uploaded_file)
686
746
  if upload.state.name == "ACTIVE":
687
747
  trace(f"Using uploaded file: {uploaded_file}")
688
748
  return upload
@@ -693,20 +753,16 @@ async def file_for_content(content: ContentAudio | ContentVideo) -> File:
693
753
  except Exception as ex:
694
754
  trace(f"Error attempting to access uploaded file: {ex}")
695
755
  files_db.delete(content_sha256)
696
-
697
756
  # do the upload (and record it)
698
- upload = upload_file(BytesIO(content_bytes), mime_type=mime_type)
757
+ upload = client.files.upload(BytesIO(content_bytes), mime_type=mime_type)
699
758
  while upload.state.name == "PROCESSING":
700
759
  await asyncio.sleep(3)
701
- upload = get_file(upload.name)
702
-
760
+ upload = client.files.get(upload.name)
703
761
  if upload.state.name == "FAILED":
704
762
  trace(f"Failed to upload file '{upload.name}: {upload.error}")
705
763
  raise ValueError(f"Google file upload failed: {upload.error}")
706
-
707
764
  # trace and record it
708
765
  trace(f"Uploaded file: {upload.name}")
709
766
  files_db.put(content_sha256, upload.name)
710
-
711
767
  # return the file
712
768
  return upload
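
The rewritten provider is exercised through inspect_ai's public model API rather than called directly. A minimal sketch follows (not part of the diff): the model names and the project/location values are placeholders, while GOOGLE_API_KEY, the 'vertex' service prefix, and the project/location model args come from the provider code above.

    import asyncio

    from inspect_ai.model import get_model


    async def main() -> None:
        # Gemini API endpoint: the provider reads GOOGLE_API_KEY from the
        # environment (see the auth handling added in __init__ above).
        model = get_model("google/gemini-1.5-flash")
        output = await model.generate("Say hello.")
        print(output.completion)

        # Vertex AI via the new 'vertex' service prefix. project/location can
        # come from GOOGLE_CLOUD_PROJECT / GOOGLE_CLOUD_LOCATION env vars or be
        # passed as custom model args (-M on the CLI maps to kwargs here).
        vertex = get_model(
            "google/vertex/gemini-1.5-flash",
            project="my-project",  # placeholder
            location="us-central1",  # placeholder
        )
        print((await vertex.generate("Say hello.")).completion)


    asyncio.run(main())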