inspect-ai 0.3.56__py3-none-any.whl → 0.3.58__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (107)
  1. inspect_ai/__init__.py +2 -1
  2. inspect_ai/_cli/common.py +4 -2
  3. inspect_ai/_cli/eval.py +2 -0
  4. inspect_ai/_cli/trace.py +21 -2
  5. inspect_ai/_display/core/active.py +0 -2
  6. inspect_ai/_display/core/panel.py +1 -1
  7. inspect_ai/_display/rich/display.py +4 -4
  8. inspect_ai/_display/textual/app.py +4 -1
  9. inspect_ai/_display/textual/widgets/samples.py +41 -5
  10. inspect_ai/_eval/eval.py +32 -20
  11. inspect_ai/_eval/evalset.py +7 -5
  12. inspect_ai/_eval/run.py +16 -11
  13. inspect_ai/_eval/task/__init__.py +2 -2
  14. inspect_ai/_eval/task/images.py +40 -25
  15. inspect_ai/_eval/task/run.py +141 -119
  16. inspect_ai/_eval/task/task.py +140 -25
  17. inspect_ai/_util/constants.py +1 -0
  18. inspect_ai/_util/content.py +23 -1
  19. inspect_ai/_util/datetime.py +1 -1
  20. inspect_ai/_util/deprecation.py +1 -1
  21. inspect_ai/_util/images.py +20 -17
  22. inspect_ai/_util/json.py +11 -1
  23. inspect_ai/_util/kvstore.py +73 -0
  24. inspect_ai/_util/logger.py +2 -1
  25. inspect_ai/_util/notgiven.py +18 -0
  26. inspect_ai/_util/thread.py +5 -0
  27. inspect_ai/_util/trace.py +39 -3
  28. inspect_ai/_util/transcript.py +36 -7
  29. inspect_ai/_view/www/.prettierrc.js +12 -0
  30. inspect_ai/_view/www/dist/assets/index.js +322 -226
  31. inspect_ai/_view/www/log-schema.json +221 -138
  32. inspect_ai/_view/www/src/App.mjs +18 -9
  33. inspect_ai/_view/www/src/Types.mjs +0 -1
  34. inspect_ai/_view/www/src/api/Types.mjs +15 -4
  35. inspect_ai/_view/www/src/api/api-http.mjs +2 -0
  36. inspect_ai/_view/www/src/components/ExpandablePanel.mjs +2 -2
  37. inspect_ai/_view/www/src/components/FindBand.mjs +5 -4
  38. inspect_ai/_view/www/src/components/LargeModal.mjs +1 -1
  39. inspect_ai/_view/www/src/components/MessageBand.mjs +2 -2
  40. inspect_ai/_view/www/src/components/MessageContent.mjs +44 -2
  41. inspect_ai/_view/www/src/components/TabSet.mjs +1 -1
  42. inspect_ai/_view/www/src/components/Tools.mjs +18 -3
  43. inspect_ai/_view/www/src/components/VirtualList.mjs +15 -17
  44. inspect_ai/_view/www/src/log/remoteLogFile.mjs +2 -1
  45. inspect_ai/_view/www/src/navbar/Navbar.mjs +44 -32
  46. inspect_ai/_view/www/src/samples/SampleDisplay.mjs +1 -2
  47. inspect_ai/_view/www/src/samples/SampleList.mjs +35 -4
  48. inspect_ai/_view/www/src/samples/SampleScoreView.mjs +13 -2
  49. inspect_ai/_view/www/src/samples/SampleScores.mjs +11 -2
  50. inspect_ai/_view/www/src/samples/SamplesDescriptor.mjs +242 -178
  51. inspect_ai/_view/www/src/samples/SamplesTab.mjs +4 -2
  52. inspect_ai/_view/www/src/samples/tools/SampleFilter.mjs +5 -5
  53. inspect_ai/_view/www/src/samples/tools/SelectScorer.mjs +7 -0
  54. inspect_ai/_view/www/src/samples/tools/SortFilter.mjs +3 -3
  55. inspect_ai/_view/www/src/samples/transcript/ToolEventView.mjs +1 -1
  56. inspect_ai/_view/www/src/types/log.d.ts +53 -35
  57. inspect_ai/_view/www/src/workspace/WorkSpace.mjs +1 -1
  58. inspect_ai/approval/_human/util.py +2 -2
  59. inspect_ai/dataset/_sources/csv.py +2 -1
  60. inspect_ai/dataset/_sources/json.py +2 -1
  61. inspect_ai/dataset/_sources/util.py +15 -7
  62. inspect_ai/log/_condense.py +11 -1
  63. inspect_ai/log/_log.py +27 -5
  64. inspect_ai/log/_recorders/eval.py +21 -8
  65. inspect_ai/log/_samples.py +10 -5
  66. inspect_ai/log/_transcript.py +28 -1
  67. inspect_ai/model/__init__.py +10 -2
  68. inspect_ai/model/_call_tools.py +82 -17
  69. inspect_ai/model/_chat_message.py +2 -4
  70. inspect_ai/model/{_trace.py → _conversation.py} +9 -8
  71. inspect_ai/model/_model.py +2 -2
  72. inspect_ai/model/_providers/anthropic.py +9 -7
  73. inspect_ai/model/_providers/azureai.py +6 -4
  74. inspect_ai/model/_providers/bedrock.py +6 -4
  75. inspect_ai/model/_providers/google.py +103 -14
  76. inspect_ai/model/_providers/groq.py +7 -5
  77. inspect_ai/model/_providers/hf.py +11 -6
  78. inspect_ai/model/_providers/mistral.py +6 -9
  79. inspect_ai/model/_providers/openai.py +34 -8
  80. inspect_ai/model/_providers/openai_o1.py +10 -12
  81. inspect_ai/model/_providers/vertex.py +17 -4
  82. inspect_ai/scorer/__init__.py +13 -2
  83. inspect_ai/scorer/_metrics/__init__.py +2 -2
  84. inspect_ai/scorer/_metrics/std.py +3 -3
  85. inspect_ai/tool/__init__.py +9 -1
  86. inspect_ai/tool/_tool.py +9 -2
  87. inspect_ai/tool/_tool_info.py +2 -1
  88. inspect_ai/tool/_tools/_web_browser/_resources/dm_env_servicer.py +9 -9
  89. inspect_ai/tool/_tools/_web_browser/_web_browser.py +3 -3
  90. inspect_ai/util/__init__.py +4 -3
  91. inspect_ai/util/{_trace.py → _conversation.py} +3 -17
  92. inspect_ai/util/_display.py +14 -4
  93. inspect_ai/util/_sandbox/context.py +12 -13
  94. inspect_ai/util/_sandbox/docker/compose.py +24 -13
  95. inspect_ai/util/_sandbox/docker/docker.py +20 -13
  96. inspect_ai/util/_sandbox/docker/util.py +2 -1
  97. inspect_ai/util/_sandbox/environment.py +13 -1
  98. inspect_ai/util/_sandbox/local.py +1 -0
  99. inspect_ai/util/_sandbox/self_check.py +18 -18
  100. inspect_ai/util/_store.py +2 -2
  101. inspect_ai/util/_subprocess.py +3 -3
  102. {inspect_ai-0.3.56.dist-info → inspect_ai-0.3.58.dist-info}/METADATA +3 -3
  103. {inspect_ai-0.3.56.dist-info → inspect_ai-0.3.58.dist-info}/RECORD +107 -103
  104. {inspect_ai-0.3.56.dist-info → inspect_ai-0.3.58.dist-info}/WHEEL +1 -1
  105. {inspect_ai-0.3.56.dist-info → inspect_ai-0.3.58.dist-info}/LICENSE +0 -0
  106. {inspect_ai-0.3.56.dist-info → inspect_ai-0.3.58.dist-info}/entry_points.txt +0 -0
  107. {inspect_ai-0.3.56.dist-info → inspect_ai-0.3.58.dist-info}/top_level.txt +0 -0

inspect_ai/model/_call_tools.py
@@ -1,15 +1,20 @@
 import asyncio
 import inspect
+import types
 from dataclasses import is_dataclass
 from logging import getLogger
 from textwrap import dedent
+from types import UnionType
 from typing import (
     Any,
     Callable,
     Dict,
     List,
     NamedTuple,
+    Optional,
+    Tuple,
     Type,
+    Union,
     get_args,
     get_origin,
     get_type_hints,
@@ -19,16 +24,19 @@ from typing import (
 from jsonschema import Draft7Validator
 from pydantic import BaseModel
 
-from inspect_ai._util.content import Content, ContentImage, ContentText
+from inspect_ai._util.content import (
+    Content,
+    ContentAudio,
+    ContentImage,
+    ContentText,
+    ContentVideo,
+)
 from inspect_ai._util.format import format_function_call
 from inspect_ai._util.text import truncate_string_to_bytes
 from inspect_ai._util.trace import trace_action
-from inspect_ai.model._trace import trace_tool_mesage
+from inspect_ai.model._conversation import conversation_tool_mesage
 from inspect_ai.tool import Tool, ToolCall, ToolError, ToolInfo
-from inspect_ai.tool._tool import (
-    ToolApprovalError,
-    ToolParsingError,
-)
+from inspect_ai.tool._tool import ToolApprovalError, ToolParsingError
 from inspect_ai.tool._tool_call import ToolCallContent, ToolCallError
 from inspect_ai.tool._tool_def import ToolDef, tool_defs
 from inspect_ai.tool._tool_info import parse_docstring
@@ -118,10 +126,14 @@ async def call_tools(
         # massage result, leave list[Content] alone, convert all other
         # types to string as that is what the model APIs accept
         truncated: tuple[int, int] | None = None
-        if isinstance(result, ContentText | ContentImage):
+        if isinstance(
+            result, ContentText | ContentImage | ContentAudio | ContentVideo
+        ):
             content: str | list[Content] = [result]
         elif isinstance(result, list) and (
-            isinstance(result[0], ContentText | ContentImage)
+            isinstance(
+                result[0], ContentText | ContentImage | ContentAudio | ContentVideo
+            )
         ):
             content = result
         else:
@@ -161,6 +173,9 @@ async def call_tools(
     # call tools
     tool_messages: list[ChatMessageTool] = []
     for call in message.tool_calls:
+        # create the task
+        task = asyncio.create_task(call_tool_task(call))
+
         # create pending tool event and add it to the transcript
         event = ToolEvent(
             id=call.id,
@@ -169,15 +184,44 @@ async def call_tools(
             view=call.view,
             pending=True,
         )
+        event.set_task(task)
         transcript()._event(event)
 
-        # execute the tool call
-        task = asyncio.create_task(call_tool_task(call))
-        tool_message, result_event = await task
+        # execute the tool call. if the operator cancelled the
+        # tool call then synthesize the appropriate message/event
+        try:
+            tool_message, result_event = await task
+        except asyncio.CancelledError:
+            if event.cancelled:
+                tool_message = ChatMessageTool(
+                    content="",
+                    function=call.function,
+                    tool_call_id=call.id,
+                    error=ToolCallError(
+                        "timeout", "Command timed out before completing."
+                    ),
+                )
+                result_event = ToolEvent(
+                    id=call.id,
+                    function=call.function,
+                    arguments=call.arguments,
+                    result=tool_message.content,
+                    truncated=None,
+                    view=call.view,
+                    error=tool_message.error,
+                    events=[],
+                )
+                transcript().info(
+                    f"Tool call '{call.function}' was cancelled by operator."
+                )
+            else:
+                raise
+
+        # update return messages
         tool_messages.append(tool_message)
 
-        # trace if we are tracing
-        trace_tool_mesage(tool_message)
+        # print conversation if display is conversation
+        conversation_tool_mesage(tool_message)
 
         # update the event with the results
         event.set_result(
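
Note: the hunks above reorder task creation so that the pending ToolEvent holds a handle to the asyncio task (event.set_task) before it is awaited, which is what lets an operator cancel a running tool call. A minimal standalone sketch of the await-a-cancelled-task pattern this relies on (all names here are illustrative, not from the package):

    import asyncio

    async def main() -> None:
        async def slow_tool() -> str:
            await asyncio.sleep(60)  # stands in for a long-running tool call
            return "done"

        task = asyncio.create_task(slow_tool())
        task.cancel()  # what an operator-driven cancellation ultimately does
        try:
            print(await task)
        except asyncio.CancelledError:
            # control returns here, so the caller can synthesize a
            # timeout-style tool message instead of propagating
            print("cancelled: synthesizing a timeout result")

    asyncio.run(main())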

@@ -268,6 +312,16 @@ def disable_parallel_tools(
     return False
 
 
+def type_hint_includes_none(type_hint: Type[Any] | None) -> bool:
+    origin = get_origin(type_hint)
+
+    if origin in {Union, UnionType}:
+        return type(None) in get_args(type_hint)
+    elif origin is Optional:
+        return True
+    return False
+
+
 def tool_params(input: dict[str, Any], func: Callable[..., Any]) -> dict[str, Any]:
     # parse function typeinfo
     signature = inspect.signature(func)
@@ -296,7 +350,7 @@ def tool_params(input: dict[str, Any], func: Callable[..., Any]) -> dict[str, Any]:
         # yield parameter (fail if not passed and there is no default)
         if param_name in input:
             params[param_name] = tool_param(type_hint, input.get(param_name))
-        elif param.default is not None:
+        elif param.default is not None or type_hint_includes_none(type_hint):
             params[param_name] = param.default
         else:
             raise ToolParsingError(
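
Note: type_hint_includes_none is what allows a tool parameter typed as `int | None` to be omitted from a call and fall back to its default. A quick standalone check of the union detection it performs (assumes Python 3.10+ for PEP 604 unions):

    from types import UnionType
    from typing import Optional, Union, get_args, get_origin

    # Optional[int], int | None and Union[int, None] all normalize to a
    # union whose args include NoneType, which is what the check looks for
    for hint in (Optional[int], int | None, Union[int, None]):
        assert get_origin(hint) in {Union, UnionType}
        assert type(None) in get_args(hint)
    assert get_origin(int) is None  # plain types have no origin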

@@ -339,11 +393,21 @@ def tool_param(type_hint: Type[Any], input: Any) -> Any:
             return [tool_param(args[0], x) for x in input]
         else:
             return input
+    elif origin is tuple or origin is Tuple:
+        if args:
+            return tuple([tool_param(args[0], x) for x in input])
+        else:
+            return tuple(input)
     elif origin is dict or origin is Dict:
         if args and len(args) > 1:
             return {k: tool_param(args[1], v) for k, v in input}
         else:
             return input
+    elif origin is Union or origin is types.UnionType:
+        if args[1] is type(None):
+            return tool_param(args[0], input)
+        else:
+            return input
     else:
         return input
 
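
Note: these branches let JSON tool arguments be coerced into tuple-typed and optional parameters. A condensed sketch of the two new cases (element coercion elided; illustrative, not the package's code):

    import types
    from typing import Any, Union, get_args, get_origin

    def coerce(type_hint: Any, value: Any) -> Any:
        origin, args = get_origin(type_hint), get_args(type_hint)
        if origin is tuple:
            return tuple(value)  # rebuild lists as tuples
        elif origin in (Union, types.UnionType) and args[1] is type(None):
            return coerce(args[0], value)  # unwrap X | None to X
        return value

    print(coerce(tuple[int, ...], [1, 2, 3]))  # (1, 2, 3)
    print(coerce(int | None, 5))               # 5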

@@ -389,12 +453,13 @@ def truncate_tool_output(
     # truncate if required
     truncated = truncate_string_to_bytes(output, active_max_output)
     if truncated:
-        truncated_output = dedent(f"""
+        truncated_output = dedent("""
            The output of your call to {tool_name} was too long to be displayed.
            Here is a truncated version:
            <START_TOOL_OUTPUT>
-           {truncated.output}
-           <END_TOOL_OUTPUT>""")
+           {truncated_output}
+           <END_TOOL_OUTPUT>
+           """).format(tool_name=tool_name, truncated_output=truncated.output)
        return TruncatedToolOutput(
            truncated_output, truncated.original_bytes, active_max_output
        )
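
Note: the truncation fix matters because an f-string interpolates before dedent() runs: multi-line tool output with no leading whitespace defeats the common-indent calculation, leaving the template indented. Formatting after dedent avoids that:

    from textwrap import dedent

    output = "line1\nline2"

    # before: "line2" carries no indent, so dedent() finds no common
    # prefix and strips nothing
    broken = dedent(f"""
        <START_TOOL_OUTPUT>
        {output}
        <END_TOOL_OUTPUT>""")

    # after: dedent the literal template first, then substitute
    fixed = dedent("""
        <START_TOOL_OUTPUT>
        {output}
        <END_TOOL_OUTPUT>
        """).format(output=output)

    assert broken.startswith("\n    ")  # indentation survived
    assert fixed.startswith("\n<START_TOOL_OUTPUT>")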

inspect_ai/model/_chat_message.py
@@ -59,10 +59,8 @@ class ChatMessageBase(BaseModel):
         if isinstance(self.content, str):
             self.content = text
         else:
-            all_images = [
-                content for content in self.content if content.type == "image"
-            ]
-            self.content = [ContentText(text=text)] + all_images
+            all_other = [content for content in self.content if content.type != "text"]
+            self.content = [ContentText(text=text)] + all_other
 
 
 class ChatMessageSystem(ChatMessageBase):
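
Note: previously the text setter preserved only image parts when replacing a message's text; with audio and video now in the content model, it keeps every non-text part. Sketch of the effect (hypothetical message):

    # content before assigning msg.text = "new":
    #   [ContentText("old"), ContentImage(...), ContentAudio(...)]
    # content after:
    #   [ContentText("new"), ContentImage(...), ContentAudio(...)]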

inspect_ai/model/_trace.py → inspect_ai/model/_conversation.py
@@ -3,7 +3,8 @@ from rich.text import Text
 
 from inspect_ai._util.rich import lines_display
 from inspect_ai._util.transcript import transcript_markdown
-from inspect_ai.util._trace import trace_enabled, trace_panel
+from inspect_ai.util._conversation import conversation_panel
+from inspect_ai.util._display import display_type
 
 from ._chat_message import ChatMessage, ChatMessageAssistant, ChatMessageTool
 from ._render import messages_preceding_assistant, render_tool_calls
@@ -11,25 +12,25 @@ from ._render import messages_preceding_assistant, render_tool_calls
 MESSAGE_TITLE = "Message"
 
 
-def trace_tool_mesage(message: ChatMessageTool) -> None:
-    if trace_enabled():
+def conversation_tool_mesage(message: ChatMessageTool) -> None:
+    if display_type() == "conversation":
         # truncate output to 100 lines
         output = message.error.message if message.error else message.text.strip()
         content = lines_display(output, 100)
 
-        trace_panel(
+        conversation_panel(
             title=f"Tool Output: {message.function}",
             content=content,
         )
 
 
-def trace_assistant_message(
+def conversation_assistant_message(
     input: list[ChatMessage], message: ChatMessageAssistant
 ) -> None:
-    if trace_enabled():
+    if display_type() == "conversation":
         # print precding messages that aren't tool or assistant
         for m in messages_preceding_assistant(input):
-            trace_panel(
+            conversation_panel(
                 title=m.role.capitalize(),
                 content=transcript_markdown(m.text, escape=True),
             )
@@ -45,4 +46,4 @@ def trace_assistant_message(
         content.extend(render_tool_calls(message.tool_calls))
 
         # print the assistant message
-        trace_panel(title="Assistant", content=content)
+        conversation_panel(title="Assistant", content=content)
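
Note: the trace_* helpers become conversation_* because rendering is now keyed off the global display type rather than a dedicated trace flag. A minimal gate in the new style (assumes display_type is re-exported from inspect_ai.util, per the util/__init__.py change listed above):

    from inspect_ai.util import display_type

    def maybe_panel(title: str, content: str) -> None:
        # panels render only when running under `--display conversation`
        if display_type() == "conversation":
            print(f"[{title}] {content}")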

inspect_ai/model/_model.py
@@ -43,6 +43,7 @@ from ._chat_message import (
     ChatMessageTool,
     ChatMessageUser,
 )
+from ._conversation import conversation_assistant_message
 from ._generate_config import (
     GenerateConfig,
     active_generate_config,
@@ -50,7 +51,6 @@ from ._generate_config import (
 )
 from ._model_call import ModelCall
 from ._model_output import ModelOutput, ModelUsage
-from ._trace import trace_assistant_message
 
 logger = logging.getLogger(__name__)
 
@@ -487,7 +487,7 @@ class Model:
         updated_output: ModelOutput, updated_call: ModelCall | None
     ) -> None:
         # trace
-        trace_assistant_message(input, updated_output.choices[0].message)
+        conversation_assistant_message(input, updated_output.choices[0].message)
 
         # update event
         event.output = updated_output

inspect_ai/model/_providers/anthropic.py
@@ -28,11 +28,11 @@ from pydantic import JsonValue
 from typing_extensions import override
 
 from inspect_ai._util.constants import BASE_64_DATA_REMOVED, DEFAULT_MAX_RETRIES
-from inspect_ai._util.content import Content, ContentText
+from inspect_ai._util.content import Content, ContentImage, ContentText
 from inspect_ai._util.error import exception_message
-from inspect_ai._util.images import image_as_data_uri
+from inspect_ai._util.images import file_as_data_uri
 from inspect_ai._util.logger import warn_once
-from inspect_ai._util.url import data_uri_mime_type, data_uri_to_base64, is_data_uri
+from inspect_ai._util.url import data_uri_mime_type, data_uri_to_base64
 from inspect_ai.tool import ToolCall, ToolChoice, ToolFunction, ToolInfo
 
 from .._chat_message import (
@@ -584,11 +584,9 @@ async def message_param_content(
 ) -> TextBlockParam | ImageBlockParam:
     if isinstance(content, ContentText):
         return TextBlockParam(type="text", text=content.text or NO_CONTENT)
-    else:
+    elif isinstance(content, ContentImage):
         # resolve to url
-        image = content.image
-        if not is_data_uri(image):
-            image = await image_as_data_uri(image)
+        image = await file_as_data_uri(content.image)
 
         # resolve mime type and base64 content
         media_type = data_uri_mime_type(image) or "image/png"
@@ -601,6 +599,10 @@ async def message_param_content(
             type="image",
             source=dict(type="base64", media_type=cast(Any, media_type), data=image),
         )
+    else:
+        raise RuntimeError(
+            "Anthropic models do not currently support audio or video inputs."
+        )
 
 
 def model_call_filter(key: JsonValue | None, value: JsonValue) -> JsonValue:

inspect_ai/model/_providers/azureai.py
@@ -31,8 +31,8 @@ from azure.core.exceptions import AzureError, HttpResponseError
 from typing_extensions import override
 
 from inspect_ai._util.constants import DEFAULT_MAX_TOKENS
-from inspect_ai._util.content import Content, ContentText
-from inspect_ai._util.images import image_as_data_uri
+from inspect_ai._util.content import Content, ContentImage, ContentText
+from inspect_ai._util.images import file_as_data_uri
 from inspect_ai.tool import ToolChoice, ToolInfo
 from inspect_ai.tool._tool_call import ToolCall
 from inspect_ai.tool._tool_choice import ToolFunction
@@ -312,12 +312,14 @@ async def chat_request_message(
 async def chat_content_item(content: Content) -> ContentItem:
     if isinstance(content, ContentText):
         return TextContentItem(text=content.text)
-    else:
+    elif isinstance(content, ContentImage):
         return ImageContentItem(
             image_url=ImageUrl(
-                url=await image_as_data_uri(content.image), detail=content.detail
+                url=await file_as_data_uri(content.image), detail=content.detail
             )
         )
+    else:
+        raise RuntimeError("Azure AI models do not support audio or video inputs.")
 
 
 def chat_tool_call(tool_call: ToolCall) -> ChatCompletionsToolCall:

inspect_ai/model/_providers/bedrock.py
@@ -11,7 +11,7 @@ from inspect_ai._util.constants import (
 )
 from inspect_ai._util.content import Content, ContentImage, ContentText
 from inspect_ai._util.error import pip_dependency_error
-from inspect_ai._util.images import image_as_data
+from inspect_ai._util.images import file_as_data
 from inspect_ai._util.version import verify_required_version
 from inspect_ai.tool import ToolChoice, ToolInfo
 from inspect_ai.tool._tool_call import ToolCall
@@ -430,7 +430,9 @@ def model_output_from_response(
             content.append(ContentText(type="text", text=c.text))
         elif c.image is not None:
             base64_image = base64.b64encode(c.image.source.bytes).decode("utf-8")
-            content.append(ContentImage(image=base64_image))
+            content.append(
+                ContentImage(image=f"data:image/{c.image.format};base64,{base64_image}")
+            )
         elif c.toolUse is not None:
             tool_calls.append(
                 ToolCall(
@@ -565,7 +567,7 @@ async def converse_chat_message(
             if c.type == "text":
                 tool_result_content.append(ConverseToolResultContent(text=c.text))
             elif c.type == "image":
-                image_data, image_type = await image_as_data(c.image)
+                image_data, image_type = await file_as_data(c.image)
                 tool_result_content.append(
                     ConverseToolResultContent(
                         image=ConverseImage(
@@ -604,7 +606,7 @@ async def converse_contents(
     result: list[ConverseMessageContent] = []
     for c in content:
         if c.type == "image":
-            image_data, image_type = await image_as_data(c.image)
+            image_data, image_type = await file_as_data(c.image)
             result.append(
                 ConverseMessageContent(
                     image=ConverseImage(
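
Note: the first Bedrock hunk wraps the raw base64 payload in a data URI so the mime type travels with the image. Sketch of the construction (hypothetical bytes and format):

    import base64

    image_bytes = b"\x89PNG..."  # hypothetical image payload
    image_format = "png"         # Bedrock reports the format with the bytes
    base64_image = base64.b64encode(image_bytes).decode("utf-8")
    data_uri = f"data:image/{image_format};base64,{base64_image}"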

inspect_ai/model/_providers/google.py
@@ -1,12 +1,17 @@
+import asyncio
 import functools
+import hashlib
 import json
 from copy import copy
+from io import BytesIO
+from logging import getLogger
 from typing import Any, cast
 
 import proto  # type: ignore
 from google.ai.generativelanguage import (
     Blob,
     Candidate,
+    File,
     FunctionCall,
     FunctionCallingConfig,
     FunctionDeclaration,
@@ -28,6 +33,8 @@ from google.generativeai import (  # type: ignore
     GenerationConfig,
     GenerativeModel,
     configure,
+    get_file,
+    upload_file,
 )
 from google.generativeai.types import (  # type: ignore
     AsyncGenerateContentResponse,
@@ -45,8 +52,16 @@ from pydantic import JsonValue
 from typing_extensions import override
 
 from inspect_ai._util.constants import BASE_64_DATA_REMOVED
-from inspect_ai._util.content import Content, ContentImage, ContentText
-from inspect_ai._util.images import image_as_data
+from inspect_ai._util.content import (
+    Content,
+    ContentAudio,
+    ContentImage,
+    ContentText,
+    ContentVideo,
+)
+from inspect_ai._util.images import file_as_data
+from inspect_ai._util.kvstore import inspect_kvstore
+from inspect_ai._util.trace import trace_message
 from inspect_ai.tool import ToolCall, ToolChoice, ToolInfo, ToolParam, ToolParams
 
 from .._chat_message import (
@@ -70,6 +85,8 @@ from .._model_output import (
 )
 from .util import model_base_url
 
+logger = getLogger(__name__)
+
 SAFETY_SETTINGS = "safety_settings"
 
 DEFAULT_SAFETY_SETTINGS: SafetySettingDict = {
@@ -194,7 +211,9 @@ class GoogleAPI(ModelAPI):
                 model=self.model_name, content=ex.message, stop_reason="model_length"
             )
         else:
-            raise ex
+            return ModelOutput.from_content(
+                model=self.model_name, content=ex.message, stop_reason="unknown"
+            )
 
     @override
     def is_rate_limit(self, ex: BaseException) -> bool:
@@ -362,19 +381,23 @@ def dict_to_struct(x: dict[str, Any]) -> Struct:
     return struct
 
 
-async def content_part(content: Content | str) -> PartDict:
+async def content_part(content: Content | str) -> PartType:
     if isinstance(content, str):
         return PartDict(text=content or NO_CONTENT)
     elif isinstance(content, ContentText):
         return PartDict(text=content.text or NO_CONTENT)
     else:
-        return PartDict(inline_data=await chat_content_image_to_blob(content))
+        return await chat_content_to_part(content)
 
 
-async def chat_content_image_to_blob(image: ContentImage) -> Blob:
-    image_url = image.image
-    image_bytes, mime_type = await image_as_data(image_url)
-    return Blob(mime_type=mime_type, data=image_bytes)
+async def chat_content_to_part(
+    content: ContentImage | ContentAudio | ContentVideo,
+) -> PartType:
+    if isinstance(content, ContentImage):
+        content_bytes, mime_type = await file_as_data(content.image)
+        return Blob(mime_type=mime_type, data=content_bytes)
+    else:
+        return await file_for_content(content)
 
 
 def prepend_system_messages(
@@ -408,25 +431,34 @@ def chat_tools(tools: list[ToolInfo]) -> list[Tool]:
 # https://ai.google.dev/gemini-api/tutorials/extract_structured_data#define_the_schema
 
 
-def schema_from_param(param: ToolParam | ToolParams) -> Schema:
+def schema_from_param(param: ToolParam | ToolParams, nullable: bool = False) -> Schema:
     if isinstance(param, ToolParams):
         param = ToolParam(
             type=param.type, properties=param.properties, required=param.required
         )
 
     if param.type == "number":
-        return Schema(type=Type.NUMBER, description=param.description)
+        return Schema(
+            type=Type.NUMBER, description=param.description, nullable=nullable
+        )
     elif param.type == "integer":
-        return Schema(type=Type.INTEGER, description=param.description)
+        return Schema(
+            type=Type.INTEGER, description=param.description, nullable=nullable
+        )
    elif param.type == "boolean":
-        return Schema(type=Type.BOOLEAN, description=param.description)
+        return Schema(
+            type=Type.BOOLEAN, description=param.description, nullable=nullable
+        )
    elif param.type == "string":
-        return Schema(type=Type.STRING, description=param.description)
+        return Schema(
+            type=Type.STRING, description=param.description, nullable=nullable
+        )
    elif param.type == "array":
        return Schema(
            type=Type.ARRAY,
            description=param.description,
            items=schema_from_param(param.items) if param.items else None,
+           nullable=nullable,
        )
    elif param.type == "object":
        return Schema(
@@ -436,7 +468,14 @@ def schema_from_param(param: ToolParam | ToolParams) -> Schema:
             if param.properties is not None
             else None,
             required=param.required,
+            nullable=nullable,
         )
+    # convert unions to optional params if the second type is 'null'
+    elif param.anyOf:
+        if len(param.anyOf) == 2 and param.anyOf[1].type == "null":
+            return schema_from_param(param.anyOf[0], nullable=True)
+        else:
+            return Schema(type=Type.TYPE_UNSPECIFIED)
     else:
         return Schema(type=Type.TYPE_UNSPECIFIED)
 
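
Note: together with the _call_tools.py changes above, this lets an optional tool parameter such as `count: int | None` survive the trip into Gemini's schema: the JSON-schema shape anyOf: [X, null] collapses to X with nullable=True instead of falling through to TYPE_UNSPECIFIED. Roughly (field names per the ToolParam model in this package):

    # JSON-schema rendering of `count: int | None`
    param = ToolParam(anyOf=[ToolParam(type="integer"), ToolParam(type="null")])
    # schema_from_param(param) -> Schema(type=Type.INTEGER, nullable=True)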

@@ -612,3 +651,53 @@ def str_to_harm_block_threshold(threshold: str) -> HarmBlockThreshold:
         return HarmBlockThreshold.BLOCK_NONE
     else:
         raise ValueError(f"Unknown HarmBlockThreshold: {threshold}")
+
+
+async def file_for_content(content: ContentAudio | ContentVideo) -> File:
+    # helper to write trace messages
+    def trace(message: str) -> None:
+        trace_message(logger, "Google Files", message)
+
+    # get the file bytes and compute sha256 hash
+    if isinstance(content, ContentAudio):
+        file = content.audio
+    else:
+        file = content.video
+    content_bytes, mime_type = await file_as_data(file)
+    content_sha256 = hashlib.sha256(content_bytes).hexdigest()
+
+    # we cache uploads for re-use, open the db where we track that
+    # (track up to 1 million previous uploads)
+    with inspect_kvstore("google_files", 1000000) as files_db:
+        # can we serve from existing uploads?
+        uploaded_file = files_db.get(content_sha256)
+        if uploaded_file:
+            try:
+                upload = cast(File, get_file(uploaded_file))
+                if upload.state.name == "ACTIVE":
+                    trace(f"Using uploaded file: {uploaded_file}")
+                    return upload
+                else:
+                    trace(
+                        f"Not using uploaded file '{uploaded_file} (state was {upload.state})"
+                    )
+            except Exception as ex:
+                trace(f"Error attempting to access uploaded file: {ex}")
+                files_db.delete(content_sha256)
+
+        # do the upload (and record it)
+        upload = upload_file(BytesIO(content_bytes), mime_type=mime_type)
+        while upload.state.name == "PROCESSING":
+            await asyncio.sleep(3)
+            upload = get_file(upload.name)
+
+        if upload.state.name == "FAILED":
+            trace(f"Failed to upload file '{upload.name}: {upload.error}")
+            raise ValueError(f"Google file upload failed: {upload.error}")
+
+        # trace and record it
+        trace(f"Uploaded file: {upload.name}")
+        files_db.put(content_sha256, upload.name)
+
+        # return the file
+        return upload
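
Note: the upload cache above is built on the new inspect_kvstore helper (added in inspect_ai/_util/kvstore.py in this release). The pattern reduced to its essentials, with get/put signatures inferred from the usage above and a hypothetical upload() helper:

    import hashlib

    from inspect_ai._util.kvstore import inspect_kvstore

    def cached_upload(content_bytes: bytes) -> str:
        digest = hashlib.sha256(content_bytes).hexdigest()
        with inspect_kvstore("google_files", 1000000) as kv:
            name = kv.get(digest)  # cache hit: reuse the prior upload
            if name is None:
                name = upload(content_bytes)  # hypothetical upload helper
                kv.put(digest, name)  # record for next time
            return name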

inspect_ai/model/_providers/groq.py
@@ -23,8 +23,8 @@ from typing_extensions import override
 
 from inspect_ai._util.constants import DEFAULT_MAX_RETRIES, DEFAULT_MAX_TOKENS
 from inspect_ai._util.content import Content
-from inspect_ai._util.images import image_as_data_uri
-from inspect_ai._util.url import is_data_uri, is_http_url
+from inspect_ai._util.images import file_as_data_uri
+from inspect_ai._util.url import is_http_url
 from inspect_ai.tool import ToolCall, ToolChoice, ToolFunction, ToolInfo
 
 from .._chat_message import (
@@ -248,18 +248,20 @@ async def as_chat_completion_part(
 ) -> ChatCompletionContentPartParam:
     if content.type == "text":
         return ChatCompletionContentPartTextParam(type="text", text=content.text)
-    else:
+    elif content.type == "image":
         # API takes URL or base64 encoded file. If it's a remote file or data URL leave it alone, otherwise encode it
         image_url = content.image
         detail = content.detail
 
-        if not is_http_url(image_url) and not is_data_uri(image_url):
-            image_url = await image_as_data_uri(image_url)
+        if not is_http_url(image_url):
+            image_url = await file_as_data_uri(image_url)
 
         return ChatCompletionContentPartImageParam(
             type="image_url",
             image_url=dict(url=image_url, detail=detail),
         )
+    else:
+        raise RuntimeError("Groq models do not support audio or video inputs.")
 
 
 def chat_tools(tools: List[ToolInfo]) -> List[Dict[str, Any]]:

inspect_ai/model/_providers/hf.py
@@ -239,12 +239,17 @@ class HuggingFaceAPI(ModelAPI):
             hf_messages = inspect_tools_to_string(hf_messages)
 
         # apply chat template
-        chat = self.tokenizer.apply_chat_template(
-            hf_messages,
-            add_generation_prompt=True,
-            tokenize=False,
-            tools=tools_list if len(tools_list) > 0 else None,
-        )
+        if self.tokenizer.chat_template is not None:
+            chat = self.tokenizer.apply_chat_template(
+                hf_messages,
+                add_generation_prompt=True,
+                tokenize=False,
+                tools=tools_list if len(tools_list) > 0 else None,
+            )
+        else:
+            chat = ""
+            for message in hf_messages:
+                chat += f"{message.role}: {message.content}\n"
         # return
         return cast(str, chat)
 
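
Note: the new else branch handles tokenizers that ship without a chat template by degrading to a plain role-prefixed transcript. For example (illustrative messages, not the package's types):

    # hf_messages roughly equivalent to:
    #   system: "Be terse."  /  user: "What is 2+2?"
    # render as the prompt string:
    #   system: Be terse.
    #   user: What is 2+2?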

inspect_ai/model/_providers/mistral.py
@@ -42,8 +42,7 @@ from inspect_ai._util.constants import (
     DEFAULT_TIMEOUT,
 )
 from inspect_ai._util.content import Content, ContentImage, ContentText
-from inspect_ai._util.images import image_as_data_uri
-from inspect_ai._util.url import is_data_uri
+from inspect_ai._util.images import file_as_data_uri
 from inspect_ai.tool import ToolCall, ToolChoice, ToolFunction, ToolInfo
 
 from .._chat_message import (
@@ -351,16 +350,14 @@ def mistral_system_message_content(
 async def mistral_content_chunk(content: Content) -> ContentChunk:
     if isinstance(content, ContentText):
         return TextChunk(text=content.text or NO_CONTENT)
-    else:
+    elif isinstance(content, ContentImage):
         # resolve image to url
-        image_url = content.image
-        if not is_data_uri(image_url):
-            image_url = await image_as_data_uri(image_url)
+        image_url = await file_as_data_uri(content.image)
 
         # return chunk
-        return ImageURLChunk(
-            image_url=ImageURL(url=content.image, detail=content.detail)
-        )
+        return ImageURLChunk(image_url=ImageURL(url=image_url, detail=content.detail))
+    else:
+        raise RuntimeError("Mistral models do not support audio or video inputs.")
 
 
 def mistral_tool_call(tool_call: ToolCall) -> MistralToolCall: