inspect-ai 0.3.57__py3-none-any.whl → 0.3.59__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (161)
  1. inspect_ai/__init__.py +2 -1
  2. inspect_ai/_cli/common.py +7 -3
  3. inspect_ai/_cli/eval.py +17 -2
  4. inspect_ai/_cli/trace.py +21 -2
  5. inspect_ai/_display/core/active.py +4 -3
  6. inspect_ai/_display/core/config.py +3 -3
  7. inspect_ai/_display/core/panel.py +7 -3
  8. inspect_ai/_display/plain/__init__.py +0 -0
  9. inspect_ai/_display/plain/display.py +203 -0
  10. inspect_ai/_display/rich/display.py +4 -9
  11. inspect_ai/_display/textual/app.py +4 -1
  12. inspect_ai/_display/textual/widgets/port_mappings.py +110 -0
  13. inspect_ai/_display/textual/widgets/samples.py +119 -16
  14. inspect_ai/_display/textual/widgets/sandbox.py +37 -0
  15. inspect_ai/_eval/eval.py +32 -20
  16. inspect_ai/_eval/evalset.py +7 -5
  17. inspect_ai/_eval/score.py +1 -0
  18. inspect_ai/_eval/task/__init__.py +2 -2
  19. inspect_ai/_eval/task/images.py +40 -25
  20. inspect_ai/_eval/task/results.py +50 -22
  21. inspect_ai/_eval/task/run.py +180 -124
  22. inspect_ai/_eval/task/sandbox.py +10 -5
  23. inspect_ai/_eval/task/task.py +140 -25
  24. inspect_ai/_util/constants.py +2 -0
  25. inspect_ai/_util/content.py +23 -1
  26. inspect_ai/_util/images.py +20 -17
  27. inspect_ai/_util/kvstore.py +73 -0
  28. inspect_ai/_util/notgiven.py +18 -0
  29. inspect_ai/_util/port_names.py +61 -0
  30. inspect_ai/_util/text.py +23 -0
  31. inspect_ai/_util/thread.py +5 -0
  32. inspect_ai/_view/www/App.css +31 -1
  33. inspect_ai/_view/www/dist/assets/index.css +31 -1
  34. inspect_ai/_view/www/dist/assets/index.js +25375 -1846
  35. inspect_ai/_view/www/log-schema.json +129 -15
  36. inspect_ai/_view/www/package.json +2 -0
  37. inspect_ai/_view/www/src/App.mjs +8 -10
  38. inspect_ai/_view/www/src/Types.mjs +0 -1
  39. inspect_ai/_view/www/src/components/ChatView.mjs +133 -43
  40. inspect_ai/_view/www/src/components/ExpandablePanel.mjs +0 -4
  41. inspect_ai/_view/www/src/components/LargeModal.mjs +19 -20
  42. inspect_ai/_view/www/src/components/MessageBand.mjs +2 -2
  43. inspect_ai/_view/www/src/components/MessageContent.mjs +43 -1
  44. inspect_ai/_view/www/src/components/TabSet.mjs +3 -1
  45. inspect_ai/_view/www/src/components/VirtualList.mjs +266 -84
  46. inspect_ai/_view/www/src/index.js +75 -2
  47. inspect_ai/_view/www/src/navbar/Navbar.mjs +3 -0
  48. inspect_ai/_view/www/src/navbar/SecondaryBar.mjs +18 -9
  49. inspect_ai/_view/www/src/samples/SampleDialog.mjs +5 -1
  50. inspect_ai/_view/www/src/samples/SampleDisplay.mjs +23 -15
  51. inspect_ai/_view/www/src/samples/SampleList.mjs +18 -48
  52. inspect_ai/_view/www/src/samples/SampleTranscript.mjs +8 -3
  53. inspect_ai/_view/www/src/samples/SamplesDescriptor.mjs +29 -13
  54. inspect_ai/_view/www/src/samples/SamplesTab.mjs +4 -1
  55. inspect_ai/_view/www/src/samples/SamplesTools.mjs +8 -8
  56. inspect_ai/_view/www/src/samples/tools/SampleFilter.mjs +712 -89
  57. inspect_ai/_view/www/src/samples/tools/filters.mjs +260 -87
  58. inspect_ai/_view/www/src/samples/transcript/ErrorEventView.mjs +24 -2
  59. inspect_ai/_view/www/src/samples/transcript/EventPanel.mjs +29 -24
  60. inspect_ai/_view/www/src/samples/transcript/EventRow.mjs +1 -1
  61. inspect_ai/_view/www/src/samples/transcript/InfoEventView.mjs +24 -2
  62. inspect_ai/_view/www/src/samples/transcript/InputEventView.mjs +24 -2
  63. inspect_ai/_view/www/src/samples/transcript/ModelEventView.mjs +31 -10
  64. inspect_ai/_view/www/src/samples/transcript/SampleInitEventView.mjs +24 -2
  65. inspect_ai/_view/www/src/samples/transcript/SampleLimitEventView.mjs +23 -2
  66. inspect_ai/_view/www/src/samples/transcript/ScoreEventView.mjs +24 -2
  67. inspect_ai/_view/www/src/samples/transcript/StepEventView.mjs +33 -3
  68. inspect_ai/_view/www/src/samples/transcript/SubtaskEventView.mjs +25 -2
  69. inspect_ai/_view/www/src/samples/transcript/ToolEventView.mjs +25 -2
  70. inspect_ai/_view/www/src/samples/transcript/TranscriptView.mjs +193 -11
  71. inspect_ai/_view/www/src/samples/transcript/Types.mjs +10 -0
  72. inspect_ai/_view/www/src/samples/transcript/state/StateEventView.mjs +26 -2
  73. inspect_ai/_view/www/src/types/log.d.ts +62 -27
  74. inspect_ai/_view/www/src/utils/Format.mjs +10 -3
  75. inspect_ai/_view/www/src/utils/Json.mjs +12 -6
  76. inspect_ai/_view/www/src/workspace/WorkSpace.mjs +10 -4
  77. inspect_ai/_view/www/vite.config.js +7 -0
  78. inspect_ai/_view/www/yarn.lock +116 -0
  79. inspect_ai/approval/_human/__init__.py +0 -0
  80. inspect_ai/approval/_human/util.py +2 -2
  81. inspect_ai/approval/_policy.py +12 -6
  82. inspect_ai/dataset/_sources/csv.py +2 -1
  83. inspect_ai/dataset/_sources/json.py +2 -1
  84. inspect_ai/dataset/_sources/util.py +15 -7
  85. inspect_ai/log/_condense.py +11 -1
  86. inspect_ai/log/_log.py +3 -6
  87. inspect_ai/log/_recorders/eval.py +19 -8
  88. inspect_ai/log/_samples.py +26 -5
  89. inspect_ai/log/_transcript.py +32 -2
  90. inspect_ai/model/__init__.py +10 -2
  91. inspect_ai/model/_call_tools.py +59 -12
  92. inspect_ai/model/_chat_message.py +2 -4
  93. inspect_ai/model/_conversation.py +61 -0
  94. inspect_ai/model/_generate_config.py +10 -4
  95. inspect_ai/model/_model.py +117 -18
  96. inspect_ai/model/_model_output.py +7 -2
  97. inspect_ai/model/_providers/anthropic.py +109 -51
  98. inspect_ai/model/_providers/azureai.py +26 -24
  99. inspect_ai/model/_providers/bedrock.py +43 -44
  100. inspect_ai/model/_providers/google.py +121 -58
  101. inspect_ai/model/_providers/groq.py +7 -5
  102. inspect_ai/model/_providers/hf.py +11 -6
  103. inspect_ai/model/_providers/mistral.py +17 -20
  104. inspect_ai/model/_providers/openai.py +32 -21
  105. inspect_ai/model/_providers/openai_o1.py +9 -8
  106. inspect_ai/model/_providers/providers.py +1 -1
  107. inspect_ai/model/_providers/together.py +8 -8
  108. inspect_ai/model/_providers/vertex.py +18 -8
  109. inspect_ai/scorer/__init__.py +13 -2
  110. inspect_ai/scorer/_metrics/__init__.py +2 -2
  111. inspect_ai/scorer/_metrics/std.py +3 -3
  112. inspect_ai/scorer/_reducer/reducer.py +1 -1
  113. inspect_ai/scorer/_scorer.py +2 -2
  114. inspect_ai/solver/__init__.py +2 -5
  115. inspect_ai/solver/_prompt.py +35 -5
  116. inspect_ai/solver/_task_state.py +80 -38
  117. inspect_ai/tool/__init__.py +11 -1
  118. inspect_ai/tool/_tool.py +21 -3
  119. inspect_ai/tool/_tool_call.py +10 -0
  120. inspect_ai/tool/_tool_def.py +16 -5
  121. inspect_ai/tool/_tool_with.py +21 -4
  122. inspect_ai/tool/beta/__init__.py +5 -0
  123. inspect_ai/tool/beta/_computer/__init__.py +3 -0
  124. inspect_ai/tool/beta/_computer/_common.py +133 -0
  125. inspect_ai/tool/beta/_computer/_computer.py +155 -0
  126. inspect_ai/tool/beta/_computer/_computer_split.py +198 -0
  127. inspect_ai/tool/beta/_computer/_resources/Dockerfile +100 -0
  128. inspect_ai/tool/beta/_computer/_resources/README.md +30 -0
  129. inspect_ai/tool/beta/_computer/_resources/entrypoint/entrypoint.sh +18 -0
  130. inspect_ai/tool/beta/_computer/_resources/entrypoint/novnc_startup.sh +20 -0
  131. inspect_ai/tool/beta/_computer/_resources/entrypoint/x11vnc_startup.sh +48 -0
  132. inspect_ai/tool/beta/_computer/_resources/entrypoint/xfce_startup.sh +13 -0
  133. inspect_ai/tool/beta/_computer/_resources/entrypoint/xvfb_startup.sh +48 -0
  134. inspect_ai/tool/beta/_computer/_resources/image_home_dir/Desktop/Firefox Web Browser.desktop +10 -0
  135. inspect_ai/tool/beta/_computer/_resources/image_home_dir/Desktop/Visual Studio Code.desktop +10 -0
  136. inspect_ai/tool/beta/_computer/_resources/image_home_dir/Desktop/XPaint.desktop +10 -0
  137. inspect_ai/tool/beta/_computer/_resources/tool/__init__.py +0 -0
  138. inspect_ai/tool/beta/_computer/_resources/tool/_logger.py +22 -0
  139. inspect_ai/tool/beta/_computer/_resources/tool/_run.py +42 -0
  140. inspect_ai/tool/beta/_computer/_resources/tool/_tool_result.py +33 -0
  141. inspect_ai/tool/beta/_computer/_resources/tool/_x11_client.py +262 -0
  142. inspect_ai/tool/beta/_computer/_resources/tool/computer_tool.py +85 -0
  143. inspect_ai/tool/beta/_computer/_resources/tool/requirements.txt +0 -0
  144. inspect_ai/util/__init__.py +2 -3
  145. inspect_ai/util/{_trace.py → _conversation.py} +3 -17
  146. inspect_ai/util/_display.py +14 -4
  147. inspect_ai/util/_limit.py +26 -0
  148. inspect_ai/util/_sandbox/context.py +12 -13
  149. inspect_ai/util/_sandbox/docker/compose.py +24 -11
  150. inspect_ai/util/_sandbox/docker/docker.py +84 -14
  151. inspect_ai/util/_sandbox/docker/internal.py +3 -1
  152. inspect_ai/util/_sandbox/environment.py +27 -1
  153. inspect_ai/util/_sandbox/local.py +1 -0
  154. {inspect_ai-0.3.57.dist-info → inspect_ai-0.3.59.dist-info}/METADATA +2 -2
  155. {inspect_ai-0.3.57.dist-info → inspect_ai-0.3.59.dist-info}/RECORD +159 -128
  156. inspect_ai/_view/www/src/samples/transcript/TranscriptState.mjs +0 -70
  157. inspect_ai/model/_trace.py +0 -48
  158. {inspect_ai-0.3.57.dist-info → inspect_ai-0.3.59.dist-info}/LICENSE +0 -0
  159. {inspect_ai-0.3.57.dist-info → inspect_ai-0.3.59.dist-info}/WHEEL +0 -0
  160. {inspect_ai-0.3.57.dist-info → inspect_ai-0.3.59.dist-info}/entry_points.txt +0 -0
  161. {inspect_ai-0.3.57.dist-info → inspect_ai-0.3.59.dist-info}/top_level.txt +0 -0

inspect_ai/model/_providers/google.py

@@ -1,6 +1,10 @@
+import asyncio
 import functools
+import hashlib
 import json
 from copy import copy
+from io import BytesIO
+from logging import getLogger
 from typing import Any, cast

 import proto  # type: ignore
@@ -24,29 +28,39 @@ from google.api_core.exceptions import (
     TooManyRequests,
 )
 from google.api_core.retry.retry_base import if_transient_error
-from google.generativeai import (  # type: ignore
-    GenerationConfig,
-    GenerativeModel,
-    configure,
-)
-from google.generativeai.types import (  # type: ignore
-    AsyncGenerateContentResponse,
+from google.generativeai.client import configure
+from google.generativeai.files import get_file, upload_file
+from google.generativeai.generative_models import GenerativeModel
+from google.generativeai.types import (
     ContentDict,
-    HarmBlockThreshold,
-    HarmCategory,
+    GenerationConfig,
     PartDict,
     PartType,
-    SafetySettingDict,
     Tool,
 )
+from google.generativeai.types.file_types import File
+from google.generativeai.types.generation_types import AsyncGenerateContentResponse
+from google.generativeai.types.safety_types import (
+    EasySafetySettingDict,
+    HarmBlockThreshold,
+    HarmCategory,
+)
 from google.protobuf.json_format import MessageToDict, ParseDict
 from google.protobuf.struct_pb2 import Struct
 from pydantic import JsonValue
 from typing_extensions import override

-from inspect_ai._util.constants import BASE_64_DATA_REMOVED
-from inspect_ai._util.content import Content, ContentImage, ContentText
-from inspect_ai._util.images import image_as_data
+from inspect_ai._util.constants import BASE_64_DATA_REMOVED, NO_CONTENT
+from inspect_ai._util.content import (
+    Content,
+    ContentAudio,
+    ContentImage,
+    ContentText,
+    ContentVideo,
+)
+from inspect_ai._util.images import file_as_data
+from inspect_ai._util.kvstore import inspect_kvstore
+from inspect_ai._util.trace import trace_message
 from inspect_ai.tool import ToolCall, ToolChoice, ToolInfo, ToolParam, ToolParams

 from .._chat_message import (
@@ -70,9 +84,11 @@ from .._model_output import (
 )
 from .util import model_base_url

+logger = getLogger(__name__)
+
 SAFETY_SETTINGS = "safety_settings"

-DEFAULT_SAFETY_SETTINGS: SafetySettingDict = {
+DEFAULT_SAFETY_SETTINGS: EasySafetySettingDict = {
     HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
     HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
     HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
@@ -124,7 +140,7 @@ class GoogleAPI(ModelAPI):
         tools: list[ToolInfo],
         tool_choice: ToolChoice,
         config: GenerateConfig,
-    ) -> ModelOutput | tuple[ModelOutput, ModelCall]:
+    ) -> ModelOutput | tuple[ModelOutput | Exception, ModelCall]:
         parameters = GenerationConfig(
             temperature=config.temperature,
             top_p=config.top_p,
@@ -132,11 +148,8 @@
             max_output_tokens=config.max_tokens,
             stop_sequences=config.stop_seqs,
             candidate_count=config.num_choices,
-            seed=config.seed,
             presence_penalty=config.presence_penalty,
             frequency_penalty=config.frequency_penalty,
-            response_logprobs=config.logprobs,
-            logprobs=config.top_logprobs,
         )

         # google-native messages
@@ -159,18 +172,15 @@
                 response=response,
             )

-        # cast to AsyncGenerateContentResponse since we passed stream=False
         try:
-            response = cast(
-                AsyncGenerateContentResponse,
-                await self.model.generate_content_async(
-                    contents=contents,
-                    safety_settings=self.safety_settings,
-                    generation_config=parameters,
-                    tools=gemini_tools,
-                    tool_config=gemini_tool_config,
-                ),
+            response = await self.model.generate_content_async(
+                contents=contents,
+                safety_settings=self.safety_settings,
+                generation_config=parameters,
+                tools=gemini_tools,
+                tool_config=gemini_tool_config,
             )
+
         except InvalidArgument as ex:
             return self.handle_invalid_argument(ex), model_call()

@@ -188,15 +198,13 @@
         # return
         return output, model_call()

-    def handle_invalid_argument(self, ex: InvalidArgument) -> ModelOutput:
+    def handle_invalid_argument(self, ex: InvalidArgument) -> ModelOutput | Exception:
         if "size exceeds the limit" in ex.message.lower():
             return ModelOutput.from_content(
                 model=self.model_name, content=ex.message, stop_reason="model_length"
             )
         else:
-            return ModelOutput.from_content(
-                model=self.model_name, content=ex.message, stop_reason="unknown"
-            )
+            return ex

     @override
     def is_rate_limit(self, ex: BaseException) -> bool:
@@ -214,7 +222,7 @@ class GoogleAPI(ModelAPI):
 def build_model_call(
     contents: list[ContentDict],
     generation_config: GenerationConfig,
-    safety_settings: SafetySettingDict,
+    safety_settings: EasySafetySettingDict,
     tools: list[Tool] | None,
     tool_config: ToolConfig | None,
     response: AsyncGenerateContentResponse | None,
@@ -231,7 +239,7 @@ def build_model_call(
             if tool_config is not None
             else None,
         ),
-        response=response.to_dict() if response is not None else {},
+        response=response.to_dict() if response is not None else {},  # type: ignore[no-untyped-call]
         filter=model_call_filter,
     )

@@ -252,12 +260,12 @@ def model_call_content(content: ContentDict) -> ContentDict:

 def model_call_part(part: PartType) -> PartType:
     if isinstance(part, proto.Message):
-        return MessageToDict(part._pb)
+        return cast(PartDict, MessageToDict(part._pb))
     elif isinstance(part, dict):
         part = part.copy()
         keys = list(part.keys())
         for key in keys:
-            part[key] = model_call_part(part[key])
+            part[key] = model_call_part(part[key])  # type: ignore[literal-required]
         return part
     else:
         return part
@@ -299,9 +307,6 @@ def consective_tool_message_reducer(
     return messages


-NO_CONTENT = "(no content)"
-
-
 async def content_dict(
     message: ChatMessageUser | ChatMessageAssistant | ChatMessageTool,
 ) -> ContentDict:
@@ -309,13 +314,13 @@ async def content_dict(
         return ContentDict(
             role="user",
             parts=(
-                [PartDict(text=message.content or NO_CONTENT)]
+                [message.content or NO_CONTENT]
                 if isinstance(message.content, str)
                 else [await content_part(content) for content in message.content]
             ),
         )
     elif isinstance(message, ChatMessageAssistant):
-        content_parts: list[Part] = []
+        content_parts: list[PartType] = []
         # tool call parts
         if message.tool_calls is not None:
             content_parts.extend(
@@ -364,26 +369,32 @@ def dict_to_struct(x: dict[str, Any]) -> Struct:
     return struct


-async def content_part(content: Content | str) -> PartDict:
+async def content_part(content: Content | str) -> PartType:
     if isinstance(content, str):
-        return PartDict(text=content or NO_CONTENT)
+        return content or NO_CONTENT
     elif isinstance(content, ContentText):
-        return PartDict(text=content.text or NO_CONTENT)
+        return content.text or NO_CONTENT
     else:
-        return PartDict(inline_data=await chat_content_image_to_blob(content))
+        return await chat_content_to_part(content)


-async def chat_content_image_to_blob(image: ContentImage) -> Blob:
-    image_url = image.image
-    image_bytes, mime_type = await image_as_data(image_url)
-    return Blob(mime_type=mime_type, data=image_bytes)
+async def chat_content_to_part(
+    content: ContentImage | ContentAudio | ContentVideo,
+) -> PartType:
+    if isinstance(content, ContentImage):
+        content_bytes, mime_type = await file_as_data(content.image)
+        return Blob(mime_type=mime_type, data=content_bytes)
+    else:
+        return await file_for_content(content)


 def prepend_system_messages(
     messages: list[ContentDict], system_messages: list[ChatMessageSystem]
 ) -> None:
     # create system_parts
-    system_parts = [Part(text=message.content) for message in system_messages]
+    system_parts: list[PartType] = [
+        Part(text=message.content) for message in system_messages
+    ]

     # we want the system messages to be prepended to the first user message
     # (if there is no first user message then prepend one)
@@ -455,6 +466,8 @@ def schema_from_param(param: ToolParam | ToolParams, nullable: bool = False) ->
             return schema_from_param(param.anyOf[0], nullable=True)
         else:
             return Schema(type=Type.TYPE_UNSPECIFIED)
+    elif param.enum:
+        return Schema(type=Type.STRING, format="enum", enum=param.enum)
     else:
         return Schema(type=Type.TYPE_UNSPECIFIED)

@@ -579,14 +592,14 @@ def gapi_should_retry(ex: BaseException) -> bool:

 def parse_safety_settings(
     safety_settings: Any,
-) -> dict[HarmCategory, HarmBlockThreshold]:
+) -> EasySafetySettingDict:
     # ensure we have a dict
     if isinstance(safety_settings, str):
         safety_settings = json.loads(safety_settings)
     if not isinstance(safety_settings, dict):
         raise ValueError(f"{SAFETY_SETTINGS} must be dictionary.")

-    parsed_settings: dict[HarmCategory, HarmBlockThreshold] = {}
+    parsed_settings: EasySafetySettingDict = {}
     for key, value in safety_settings.items():
         if isinstance(key, str):
             key = str_to_harm_category(key)
@@ -602,23 +615,23 @@ def parse_safety_settings(
     return parsed_settings


-def str_to_harm_category(category: str) -> HarmCategory:
+def str_to_harm_category(category: str) -> int:
     category = category.upper()
     if "HARASSMENT" in category:
-        return HarmCategory.HARM_CATEGORY_HARASSMENT
+        return cast(int, HarmCategory.HARM_CATEGORY_HARASSMENT)
     elif "HATE_SPEECH" in category:
-        return HarmCategory.HARM_CATEGORY_HATE_SPEECH
+        return cast(int, HarmCategory.HARM_CATEGORY_HATE_SPEECH)
     elif "SEXUALLY_EXPLICIT" in category:
-        return HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT
+        return cast(int, HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT)
     elif "DANGEROUS_CONTENT" in category:
-        return HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT
+        return cast(int, HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT)
     else:
         # NOTE: Although there is an "UNSPECIFIED" category, in the
        # documentation, the API does not accept it.
        raise ValueError(f"Unknown HarmCategory: {category}")


-def str_to_harm_block_threshold(threshold: str) -> HarmBlockThreshold:
+def str_to_harm_block_threshold(threshold: str) -> int:
     threshold = threshold.upper()
     if "LOW" in threshold:
         return HarmBlockThreshold.BLOCK_LOW_AND_ABOVE
@@ -630,3 +643,53 @@ def str_to_harm_block_threshold(threshold: str) -> HarmBlockThreshold:
         return HarmBlockThreshold.BLOCK_NONE
     else:
         raise ValueError(f"Unknown HarmBlockThreshold: {threshold}")
+
+
+async def file_for_content(content: ContentAudio | ContentVideo) -> File:
+    # helper to write trace messages
+    def trace(message: str) -> None:
+        trace_message(logger, "Google Files", message)
+
+    # get the file bytes and compute sha256 hash
+    if isinstance(content, ContentAudio):
+        file = content.audio
+    else:
+        file = content.video
+    content_bytes, mime_type = await file_as_data(file)
+    content_sha256 = hashlib.sha256(content_bytes).hexdigest()
+
+    # we cache uploads for re-use, open the db where we track that
+    # (track up to 1 million previous uploads)
+    with inspect_kvstore("google_files", 1000000) as files_db:
+        # can we serve from existing uploads?
+        uploaded_file = files_db.get(content_sha256)
+        if uploaded_file:
+            try:
+                upload = get_file(uploaded_file)
+                if upload.state.name == "ACTIVE":
+                    trace(f"Using uploaded file: {uploaded_file}")
+                    return upload
+                else:
+                    trace(
+                        f"Not using uploaded file '{uploaded_file} (state was {upload.state})"
+                    )
+            except Exception as ex:
+                trace(f"Error attempting to access uploaded file: {ex}")
+                files_db.delete(content_sha256)
+
+        # do the upload (and record it)
+        upload = upload_file(BytesIO(content_bytes), mime_type=mime_type)
+        while upload.state.name == "PROCESSING":
+            await asyncio.sleep(3)
+            upload = get_file(upload.name)
+
+        if upload.state.name == "FAILED":
+            trace(f"Failed to upload file '{upload.name}: {upload.error}")
+            raise ValueError(f"Google file upload failed: {upload.error}")
+
+        # trace and record it
+        trace(f"Uploaded file: {upload.name}")
+        files_db.put(content_sha256, upload.name)
+
+        # return the file
+        return upload
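
The new file_for_content helper above is what gives the Google provider audio and video support: non-image media is uploaded via the Gemini Files API and keyed by its SHA-256 hash in a local kvstore, so repeated runs re-use a previous upload instead of re-sending the bytes. Below is a minimal usage sketch from the caller's side; the model name and file path are placeholders, and it assumes the new ContentAudio type is re-exported from inspect_ai.model in this release (per the inspect_ai/model/__init__.py change listed above).

from inspect_ai.model import ChatMessageUser, ContentAudio, ContentText, get_model

async def summarize_recording() -> str:
    # audio content is uploaded (and cached by hash) via the Files API helper above
    model = get_model("google/gemini-1.5-pro")
    message = ChatMessageUser(
        content=[
            ContentText(text="Summarize this recording in two sentences."),
            ContentAudio(audio="interview.mp3", format="mp3"),
        ]
    )
    output = await model.generate([message])
    return output.completion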

inspect_ai/model/_providers/groq.py

@@ -23,8 +23,8 @@ from typing_extensions import override

 from inspect_ai._util.constants import DEFAULT_MAX_RETRIES, DEFAULT_MAX_TOKENS
 from inspect_ai._util.content import Content
-from inspect_ai._util.images import image_as_data_uri
-from inspect_ai._util.url import is_data_uri, is_http_url
+from inspect_ai._util.images import file_as_data_uri
+from inspect_ai._util.url import is_http_url
 from inspect_ai.tool import ToolCall, ToolChoice, ToolFunction, ToolInfo

 from .._chat_message import (
@@ -248,18 +248,20 @@ async def as_chat_completion_part(
 ) -> ChatCompletionContentPartParam:
     if content.type == "text":
         return ChatCompletionContentPartTextParam(type="text", text=content.text)
-    else:
+    elif content.type == "image":
         # API takes URL or base64 encoded file. If it's a remote file or data URL leave it alone, otherwise encode it
         image_url = content.image
         detail = content.detail

-        if not is_http_url(image_url) and not is_data_uri(image_url):
-            image_url = await image_as_data_uri(image_url)
+        if not is_http_url(image_url):
+            image_url = await file_as_data_uri(image_url)

         return ChatCompletionContentPartImageParam(
             type="image_url",
             image_url=dict(url=image_url, detail=detail),
         )
+    else:
+        raise RuntimeError("Groq models do not support audio or video inputs.")


 def chat_tools(tools: List[ToolInfo]) -> List[Dict[str, Any]]:

inspect_ai/model/_providers/hf.py

@@ -239,12 +239,17 @@ class HuggingFaceAPI(ModelAPI):
             hf_messages = inspect_tools_to_string(hf_messages)

         # apply chat template
-        chat = self.tokenizer.apply_chat_template(
-            hf_messages,
-            add_generation_prompt=True,
-            tokenize=False,
-            tools=tools_list if len(tools_list) > 0 else None,
-        )
+        if self.tokenizer.chat_template is not None:
+            chat = self.tokenizer.apply_chat_template(
+                hf_messages,
+                add_generation_prompt=True,
+                tokenize=False,
+                tools=tools_list if len(tools_list) > 0 else None,
+            )
+        else:
+            chat = ""
+            for message in hf_messages:
+                chat += f"{message.role}: {message.content}\n"
         # return
         return cast(str, chat)
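
The change above lets the HuggingFace provider degrade gracefully for models whose tokenizer ships no chat template: rather than letting apply_chat_template fail, it falls back to a plain role-prefixed rendering. A standalone sketch of the same check against the transformers API (the model name is a placeholder):

from transformers import AutoTokenizer

def render_chat(model_name: str, messages: list[dict[str, str]]) -> str:
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.chat_template is not None:
        # tokenizer provides a chat template: let transformers build the prompt
        return tokenizer.apply_chat_template(
            messages, add_generation_prompt=True, tokenize=False
        )
    # base/completion models without a template: simple "role: content" fallback
    return "".join(f"{m['role']}: {m['content']}\n" for m in messages)

# e.g. render_chat("gpt2", [{"role": "user", "content": "Hello!"}])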

inspect_ai/model/_providers/mistral.py

@@ -40,10 +40,10 @@ from typing_extensions import override
 # https://github.com/mistralai/client-python/blob/main/MIGRATION.md
 from inspect_ai._util.constants import (
     DEFAULT_TIMEOUT,
+    NO_CONTENT,
 )
 from inspect_ai._util.content import Content, ContentImage, ContentText
-from inspect_ai._util.images import image_as_data_uri
-from inspect_ai._util.url import is_data_uri
+from inspect_ai._util.images import file_as_data_uri
 from inspect_ai.tool import ToolCall, ToolChoice, ToolFunction, ToolInfo

 from .._chat_message import (
@@ -123,7 +123,7 @@ class MistralAPI(ModelAPI):
         tools: list[ToolInfo],
         tool_choice: ToolChoice,
         config: GenerateConfig,
-    ) -> ModelOutput | tuple[ModelOutput, ModelCall]:
+    ) -> ModelOutput | tuple[ModelOutput | Exception, ModelCall]:
         # build request
         request: dict[str, Any] = dict(
             model=self.model_name,
@@ -147,7 +147,7 @@ class MistralAPI(ModelAPI):
             response = await self.client.chat.complete_async(**request)
         except SDKError as ex:
             if ex.status_code == 400:
-                return self.handle_bad_request(ex)
+                return self.handle_bad_request(ex), mistral_model_call(request, None)
             else:
                 raise ex

@@ -182,25 +182,27 @@ class MistralAPI(ModelAPI):
     def connection_key(self) -> str:
         return str(self.api_key)

-    def handle_bad_request(self, ex: SDKError) -> ModelOutput:
+    def handle_bad_request(self, ex: SDKError) -> ModelOutput | Exception:
+        body = json.loads(ex.body)
+        content = body.get("message", ex.body)
         if "maximum context length" in ex.body:
-            body = json.loads(ex.body)
-            content = body.get("message", ex.body)
             return ModelOutput.from_content(
                 model=self.model_name, content=content, stop_reason="model_length"
             )
         else:
-            raise ex
+            return ex


 def mistral_model_call(
-    request: dict[str, Any], response: MistralChatCompletionResponse
+    request: dict[str, Any], response: MistralChatCompletionResponse | None
 ) -> ModelCall:
     request = request.copy()
     request.update(messages=[message.model_dump() for message in request["messages"]])
     if request.get("tools", None) is not None:
         request["tools"] = [tool.model_dump() for tool in request["tools"]]
-    return ModelCall(request=request, response=response.model_dump())
+    return ModelCall(
+        request=request, response=response.model_dump() if response else {}
+    )


 def mistral_chat_tools(tools: list[ToolInfo]) -> list[MistralTool]:
@@ -327,9 +329,6 @@ async def mistral_chat_message(
     )


-NO_CONTENT = "(no content)"
-
-
 async def mistral_message_content(
     content: str | list[Content],
 ) -> str | list[ContentChunk]:
@@ -351,16 +350,14 @@ def mistral_system_message_content(
 async def mistral_content_chunk(content: Content) -> ContentChunk:
     if isinstance(content, ContentText):
         return TextChunk(text=content.text or NO_CONTENT)
-    else:
+    elif isinstance(content, ContentImage):
         # resolve image to url
-        image_url = content.image
-        if not is_data_uri(image_url):
-            image_url = await image_as_data_uri(image_url)
+        image_url = await file_as_data_uri(content.image)

         # return chunk
-        return ImageURLChunk(
-            image_url=ImageURL(url=content.image, detail=content.detail)
-        )
+        return ImageURLChunk(image_url=ImageURL(url=image_url, detail=content.detail))
+    else:
+        raise RuntimeError("Mistral models do not support audio or video inputs.")


 def mistral_tool_call(tool_call: ToolCall) -> MistralToolCall:

inspect_ai/model/_providers/openai.py

@@ -17,6 +17,7 @@ from openai.types.chat import (
     ChatCompletion,
     ChatCompletionAssistantMessageParam,
     ChatCompletionContentPartImageParam,
+    ChatCompletionContentPartInputAudioParam,
     ChatCompletionContentPartParam,
     ChatCompletionContentPartTextParam,
     ChatCompletionDeveloperMessageParam,
@@ -36,9 +37,9 @@ from typing_extensions import override
 from inspect_ai._util.constants import DEFAULT_MAX_RETRIES
 from inspect_ai._util.content import Content
 from inspect_ai._util.error import PrerequisiteError
-from inspect_ai._util.images import image_as_data_uri
+from inspect_ai._util.images import file_as_data_uri
 from inspect_ai._util.logger import warn_once
-from inspect_ai._util.url import is_data_uri, is_http_url
+from inspect_ai._util.url import is_http_url
 from inspect_ai.tool import ToolCall, ToolChoice, ToolFunction, ToolInfo

 from .._chat_message import ChatMessage, ChatMessageAssistant
@@ -165,7 +166,7 @@ class OpenAIAPI(ModelAPI):
         tools: list[ToolInfo],
         tool_choice: ToolChoice,
         config: GenerateConfig,
-    ) -> ModelOutput | tuple[ModelOutput, ModelCall]:
+    ) -> ModelOutput | tuple[ModelOutput | Exception, ModelCall]:
         # short-circuit to call o1- models that are text only
         if self.is_o1_preview() or self.is_o1_mini():
             return await generate_o1(
@@ -306,27 +307,26 @@ class OpenAIAPI(ModelAPI):
         return params

     # convert some well known bad request errors into ModelOutput
-    def handle_bad_request(self, e: BadRequestError) -> ModelOutput:
-        if e.status_code == 400:
-            # extract message
-            if isinstance(e.body, dict) and "message" in e.body.keys():
-                content = str(e.body.get("message"))
-            else:
-                content = e.message
+    def handle_bad_request(self, e: BadRequestError) -> ModelOutput | Exception:
+        # extract message
+        if isinstance(e.body, dict) and "message" in e.body.keys():
+            content = str(e.body.get("message"))
+        else:
+            content = e.message

-            # narrow stop_reason
-            if e.code == "context_length_exceeded":
-                stop_reason: StopReason = "model_length"
-            elif e.code == "invalid_prompt":
-                stop_reason = "content_filter"
-            else:
-                stop_reason = "unknown"
+        # narrow stop_reason
+        stop_reason: StopReason | None = None
+        if e.code == "context_length_exceeded":
+            stop_reason = "model_length"
+        elif e.code == "invalid_prompt":
+            stop_reason = "content_filter"

+        if stop_reason:
             return ModelOutput.from_content(
                 model=self.model_name, content=content, stop_reason=stop_reason
             )
         else:
-            raise e
+            return e


 async def as_openai_chat_messages(
@@ -463,16 +463,27 @@ async def as_chat_completion_part(
 ) -> ChatCompletionContentPartParam:
     if content.type == "text":
         return ChatCompletionContentPartTextParam(type="text", text=content.text)
-    else:
+    elif content.type == "image":
         # API takes URL or base64 encoded file. If it's a remote file or
         # data URL leave it alone, otherwise encode it
         image_url = content.image
         detail = content.detail

-        if not is_http_url(image_url) and not is_data_uri(image_url):
-            image_url = await image_as_data_uri(image_url)
+        if not is_http_url(image_url):
+            image_url = await file_as_data_uri(image_url)

         return ChatCompletionContentPartImageParam(
             type="image_url",
             image_url=dict(url=image_url, detail=detail),
         )
+    elif content.type == "audio":
+        audio_data = await file_as_data_uri(content.audio)
+
+        return ChatCompletionContentPartInputAudioParam(
+            type="input_audio", input_audio=dict(data=audio_data, format=content.format)
+        )
+
+    else:
+        raise RuntimeError(
+            "Video content is not currently supported by Open AI chat models."
+        )
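
With ChatCompletionContentPartInputAudioParam, audio content is now forwarded to the OpenAI chat completions API as an "input_audio" content part. For reference, this is the wire shape of such a part built standalone; the file name is a placeholder, and sending it requires an audio-capable model such as gpt-4o-audio-preview.

import base64

def input_audio_part(path: str, fmt: str) -> dict:
    # OpenAI expects base64-encoded audio bytes plus the format ("wav" or "mp3")
    with open(path, "rb") as f:
        data = base64.b64encode(f.read()).decode("ascii")
    return {"type": "input_audio", "input_audio": {"data": data, "format": fmt}}

# part = input_audio_part("question.wav", "wav")
# messages = [{"role": "user", "content": [{"type": "text", "text": "Answer the question."}, part]}]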

inspect_ai/model/_providers/openai_o1.py

@@ -44,7 +44,7 @@ async def generate_o1(
     input: list[ChatMessage],
     tools: list[ToolInfo],
     **params: Any,
-) -> ModelOutput | tuple[ModelOutput, ModelCall]:
+) -> ModelOutput | tuple[ModelOutput | Exception, ModelCall]:
     # create chatapi handler
     handler = O1PreviewChatAPIHandler()

@@ -82,17 +82,18 @@ async def generate_o1(
     ), model_call()


-def handle_bad_request(model: str, ex: BadRequestError) -> ModelOutput:
+def handle_bad_request(model: str, ex: BadRequestError) -> ModelOutput | Exception:
     if ex.code == "context_length_exceeded":
-        stop_reason: StopReason = "model_length"
+        stop_reason: StopReason | None = "model_length"
     elif ex.code == "invalid_prompt":
         stop_reason = "content_filter"
-    else:
-        stop_reason = "unknown"

-    return ModelOutput.from_content(
-        model=model, content=str(ex), stop_reason=stop_reason
-    )
+    if stop_reason:
+        return ModelOutput.from_content(
+            model=model, content=str(ex), stop_reason=stop_reason
+        )
+    else:
+        return ex


 def chat_messages(

inspect_ai/model/_providers/providers.py

@@ -94,7 +94,7 @@ def vertex() -> type[ModelAPI]:
 def google() -> type[ModelAPI]:
     FEATURE = "Google API"
     PACKAGE = "google-generativeai"
-    MIN_VERSION = "0.8.3"
+    MIN_VERSION = "0.8.4"

     # workaround log spam
     # https://github.com/ray-project/ray/issues/24917

inspect_ai/model/_providers/together.py

@@ -103,18 +103,18 @@ class TogetherAIAPI(OpenAIAPI):
         return DEFAULT_MAX_TOKENS

     @override
-    def handle_bad_request(self, ex: BadRequestError) -> ModelOutput:
-        if ex.status_code == 400 and "max_new_tokens" in ex.message:
-            response = ex.response.json()
-            if "error" in response and "message" in response.get("error"):
-                content = response.get("error").get("message")
-            else:
-                content = str(response)
+    def handle_bad_request(self, ex: BadRequestError) -> ModelOutput | Exception:
+        response = ex.response.json()
+        if "error" in response and "message" in response.get("error"):
+            content = response.get("error").get("message")
+        else:
+            content = str(response)
+        if "max_new_tokens" in ex.message:
             return ModelOutput.from_content(
                 model=self.model_name, content=content, stop_reason="model_length"
             )
         else:
-            raise ex
+            return ex

     # Together has a slightly different logprobs structure to OpenAI, so we need to remap it.
     def _chat_choices_from_response(
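
A pattern shared by the provider changes above (Google, Mistral, OpenAI, o1, Together): handle_bad_request and handle_invalid_argument now return ModelOutput | Exception instead of raising or synthesizing an "unknown" stop reason, leaving it to the caller to decide how to surface unrecognized errors. A minimal sketch of consuming such a result (the function name and call site are illustrative, not the actual inspect_ai internals):

from inspect_ai.model import ModelOutput

def resolve_result(result: ModelOutput | Exception) -> ModelOutput:
    # recognized refusals (e.g. context length, content filter) arrive as a
    # synthetic ModelOutput; anything unrecognized is the original exception
    if isinstance(result, Exception):
        raise result
    return result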