inspect-ai 0.3.58__py3-none-any.whl → 0.3.59__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (127)
  1. inspect_ai/_cli/common.py +3 -1
  2. inspect_ai/_cli/eval.py +15 -2
  3. inspect_ai/_display/core/active.py +4 -1
  4. inspect_ai/_display/core/config.py +3 -3
  5. inspect_ai/_display/core/panel.py +7 -3
  6. inspect_ai/_display/plain/__init__.py +0 -0
  7. inspect_ai/_display/plain/display.py +203 -0
  8. inspect_ai/_display/rich/display.py +0 -5
  9. inspect_ai/_display/textual/widgets/port_mappings.py +110 -0
  10. inspect_ai/_display/textual/widgets/samples.py +78 -11
  11. inspect_ai/_display/textual/widgets/sandbox.py +37 -0
  12. inspect_ai/_eval/score.py +1 -0
  13. inspect_ai/_eval/task/results.py +50 -22
  14. inspect_ai/_eval/task/run.py +41 -7
  15. inspect_ai/_eval/task/sandbox.py +10 -5
  16. inspect_ai/_util/constants.py +1 -0
  17. inspect_ai/_util/port_names.py +61 -0
  18. inspect_ai/_util/text.py +23 -0
  19. inspect_ai/_view/www/App.css +31 -1
  20. inspect_ai/_view/www/dist/assets/index.css +31 -1
  21. inspect_ai/_view/www/dist/assets/index.js +25344 -1849
  22. inspect_ai/_view/www/log-schema.json +32 -2
  23. inspect_ai/_view/www/package.json +2 -0
  24. inspect_ai/_view/www/src/App.mjs +8 -10
  25. inspect_ai/_view/www/src/Types.mjs +0 -1
  26. inspect_ai/_view/www/src/components/ChatView.mjs +133 -43
  27. inspect_ai/_view/www/src/components/ExpandablePanel.mjs +0 -4
  28. inspect_ai/_view/www/src/components/LargeModal.mjs +19 -20
  29. inspect_ai/_view/www/src/components/TabSet.mjs +3 -1
  30. inspect_ai/_view/www/src/components/VirtualList.mjs +266 -84
  31. inspect_ai/_view/www/src/index.js +75 -2
  32. inspect_ai/_view/www/src/navbar/Navbar.mjs +3 -0
  33. inspect_ai/_view/www/src/navbar/SecondaryBar.mjs +18 -9
  34. inspect_ai/_view/www/src/samples/SampleDialog.mjs +5 -1
  35. inspect_ai/_view/www/src/samples/SampleDisplay.mjs +23 -15
  36. inspect_ai/_view/www/src/samples/SampleList.mjs +18 -48
  37. inspect_ai/_view/www/src/samples/SampleTranscript.mjs +8 -3
  38. inspect_ai/_view/www/src/samples/SamplesDescriptor.mjs +24 -12
  39. inspect_ai/_view/www/src/samples/SamplesTab.mjs +4 -1
  40. inspect_ai/_view/www/src/samples/SamplesTools.mjs +8 -8
  41. inspect_ai/_view/www/src/samples/tools/SampleFilter.mjs +712 -89
  42. inspect_ai/_view/www/src/samples/tools/filters.mjs +260 -87
  43. inspect_ai/_view/www/src/samples/transcript/ErrorEventView.mjs +24 -2
  44. inspect_ai/_view/www/src/samples/transcript/EventPanel.mjs +29 -24
  45. inspect_ai/_view/www/src/samples/transcript/EventRow.mjs +1 -1
  46. inspect_ai/_view/www/src/samples/transcript/InfoEventView.mjs +24 -2
  47. inspect_ai/_view/www/src/samples/transcript/InputEventView.mjs +24 -2
  48. inspect_ai/_view/www/src/samples/transcript/ModelEventView.mjs +31 -10
  49. inspect_ai/_view/www/src/samples/transcript/SampleInitEventView.mjs +24 -2
  50. inspect_ai/_view/www/src/samples/transcript/SampleLimitEventView.mjs +23 -2
  51. inspect_ai/_view/www/src/samples/transcript/ScoreEventView.mjs +24 -2
  52. inspect_ai/_view/www/src/samples/transcript/StepEventView.mjs +33 -3
  53. inspect_ai/_view/www/src/samples/transcript/SubtaskEventView.mjs +25 -2
  54. inspect_ai/_view/www/src/samples/transcript/ToolEventView.mjs +25 -2
  55. inspect_ai/_view/www/src/samples/transcript/TranscriptView.mjs +193 -11
  56. inspect_ai/_view/www/src/samples/transcript/Types.mjs +10 -0
  57. inspect_ai/_view/www/src/samples/transcript/state/StateEventView.mjs +26 -2
  58. inspect_ai/_view/www/src/types/log.d.ts +13 -2
  59. inspect_ai/_view/www/src/utils/Format.mjs +10 -3
  60. inspect_ai/_view/www/src/utils/Json.mjs +12 -6
  61. inspect_ai/_view/www/src/workspace/WorkSpace.mjs +10 -4
  62. inspect_ai/_view/www/vite.config.js +7 -0
  63. inspect_ai/_view/www/yarn.lock +116 -0
  64. inspect_ai/approval/_human/__init__.py +0 -0
  65. inspect_ai/approval/_policy.py +12 -6
  66. inspect_ai/log/_log.py +1 -1
  67. inspect_ai/log/_samples.py +16 -0
  68. inspect_ai/log/_transcript.py +4 -1
  69. inspect_ai/model/_call_tools.py +4 -0
  70. inspect_ai/model/_conversation.py +20 -8
  71. inspect_ai/model/_generate_config.py +10 -4
  72. inspect_ai/model/_model.py +117 -18
  73. inspect_ai/model/_model_output.py +7 -2
  74. inspect_ai/model/_providers/anthropic.py +100 -44
  75. inspect_ai/model/_providers/azureai.py +20 -20
  76. inspect_ai/model/_providers/bedrock.py +37 -40
  77. inspect_ai/model/_providers/google.py +46 -54
  78. inspect_ai/model/_providers/mistral.py +11 -11
  79. inspect_ai/model/_providers/openai.py +15 -16
  80. inspect_ai/model/_providers/openai_o1.py +9 -8
  81. inspect_ai/model/_providers/providers.py +1 -1
  82. inspect_ai/model/_providers/together.py +8 -8
  83. inspect_ai/model/_providers/vertex.py +1 -4
  84. inspect_ai/scorer/_reducer/reducer.py +1 -1
  85. inspect_ai/scorer/_scorer.py +2 -2
  86. inspect_ai/solver/__init__.py +2 -5
  87. inspect_ai/solver/_prompt.py +35 -5
  88. inspect_ai/solver/_task_state.py +80 -38
  89. inspect_ai/tool/__init__.py +2 -0
  90. inspect_ai/tool/_tool.py +12 -1
  91. inspect_ai/tool/_tool_call.py +10 -0
  92. inspect_ai/tool/_tool_def.py +16 -5
  93. inspect_ai/tool/_tool_with.py +21 -4
  94. inspect_ai/tool/beta/__init__.py +5 -0
  95. inspect_ai/tool/beta/_computer/__init__.py +3 -0
  96. inspect_ai/tool/beta/_computer/_common.py +133 -0
  97. inspect_ai/tool/beta/_computer/_computer.py +155 -0
  98. inspect_ai/tool/beta/_computer/_computer_split.py +198 -0
  99. inspect_ai/tool/beta/_computer/_resources/Dockerfile +100 -0
  100. inspect_ai/tool/beta/_computer/_resources/README.md +30 -0
  101. inspect_ai/tool/beta/_computer/_resources/entrypoint/entrypoint.sh +18 -0
  102. inspect_ai/tool/beta/_computer/_resources/entrypoint/novnc_startup.sh +20 -0
  103. inspect_ai/tool/beta/_computer/_resources/entrypoint/x11vnc_startup.sh +48 -0
  104. inspect_ai/tool/beta/_computer/_resources/entrypoint/xfce_startup.sh +13 -0
  105. inspect_ai/tool/beta/_computer/_resources/entrypoint/xvfb_startup.sh +48 -0
  106. inspect_ai/tool/beta/_computer/_resources/image_home_dir/Desktop/Firefox Web Browser.desktop +10 -0
  107. inspect_ai/tool/beta/_computer/_resources/image_home_dir/Desktop/Visual Studio Code.desktop +10 -0
  108. inspect_ai/tool/beta/_computer/_resources/image_home_dir/Desktop/XPaint.desktop +10 -0
  109. inspect_ai/tool/beta/_computer/_resources/tool/__init__.py +0 -0
  110. inspect_ai/tool/beta/_computer/_resources/tool/_logger.py +22 -0
  111. inspect_ai/tool/beta/_computer/_resources/tool/_run.py +42 -0
  112. inspect_ai/tool/beta/_computer/_resources/tool/_tool_result.py +33 -0
  113. inspect_ai/tool/beta/_computer/_resources/tool/_x11_client.py +262 -0
  114. inspect_ai/tool/beta/_computer/_resources/tool/computer_tool.py +85 -0
  115. inspect_ai/tool/beta/_computer/_resources/tool/requirements.txt +0 -0
  116. inspect_ai/util/__init__.py +2 -0
  117. inspect_ai/util/_limit.py +26 -0
  118. inspect_ai/util/_sandbox/docker/docker.py +64 -1
  119. inspect_ai/util/_sandbox/docker/internal.py +3 -1
  120. inspect_ai/util/_sandbox/environment.py +14 -0
  121. {inspect_ai-0.3.58.dist-info → inspect_ai-0.3.59.dist-info}/METADATA +2 -2
  122. {inspect_ai-0.3.58.dist-info → inspect_ai-0.3.59.dist-info}/RECORD +126 -98
  123. inspect_ai/_view/www/src/samples/transcript/TranscriptState.mjs +0 -70
  124. {inspect_ai-0.3.58.dist-info → inspect_ai-0.3.59.dist-info}/LICENSE +0 -0
  125. {inspect_ai-0.3.58.dist-info → inspect_ai-0.3.59.dist-info}/WHEEL +0 -0
  126. {inspect_ai-0.3.58.dist-info → inspect_ai-0.3.59.dist-info}/entry_points.txt +0 -0
  127. {inspect_ai-0.3.58.dist-info → inspect_ai-0.3.59.dist-info}/top_level.txt +0 -0
inspect_ai/model/_providers/google.py

@@ -11,7 +11,6 @@ import proto  # type: ignore
 from google.ai.generativelanguage import (
     Blob,
     Candidate,
-    File,
     FunctionCall,
     FunctionCallingConfig,
     FunctionDeclaration,
@@ -29,29 +28,29 @@ from google.api_core.exceptions import (
     TooManyRequests,
 )
 from google.api_core.retry.retry_base import if_transient_error
-from google.generativeai import (  # type: ignore
-    GenerationConfig,
-    GenerativeModel,
-    configure,
-    get_file,
-    upload_file,
-)
-from google.generativeai.types import (  # type: ignore
-    AsyncGenerateContentResponse,
+from google.generativeai.client import configure
+from google.generativeai.files import get_file, upload_file
+from google.generativeai.generative_models import GenerativeModel
+from google.generativeai.types import (
     ContentDict,
-    HarmBlockThreshold,
-    HarmCategory,
+    GenerationConfig,
     PartDict,
     PartType,
-    SafetySettingDict,
     Tool,
 )
+from google.generativeai.types.file_types import File
+from google.generativeai.types.generation_types import AsyncGenerateContentResponse
+from google.generativeai.types.safety_types import (
+    EasySafetySettingDict,
+    HarmBlockThreshold,
+    HarmCategory,
+)
 from google.protobuf.json_format import MessageToDict, ParseDict
 from google.protobuf.struct_pb2 import Struct
 from pydantic import JsonValue
 from typing_extensions import override
 
-from inspect_ai._util.constants import BASE_64_DATA_REMOVED
+from inspect_ai._util.constants import BASE_64_DATA_REMOVED, NO_CONTENT
 from inspect_ai._util.content import (
     Content,
     ContentAudio,
@@ -89,7 +88,7 @@ logger = getLogger(__name__)
 
 SAFETY_SETTINGS = "safety_settings"
 
-DEFAULT_SAFETY_SETTINGS: SafetySettingDict = {
+DEFAULT_SAFETY_SETTINGS: EasySafetySettingDict = {
     HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
     HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
     HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
@@ -141,7 +140,7 @@ class GoogleAPI(ModelAPI):
         tools: list[ToolInfo],
         tool_choice: ToolChoice,
         config: GenerateConfig,
-    ) -> ModelOutput | tuple[ModelOutput, ModelCall]:
+    ) -> ModelOutput | tuple[ModelOutput | Exception, ModelCall]:
         parameters = GenerationConfig(
             temperature=config.temperature,
             top_p=config.top_p,
@@ -149,11 +148,8 @@
             max_output_tokens=config.max_tokens,
             stop_sequences=config.stop_seqs,
             candidate_count=config.num_choices,
-            seed=config.seed,
             presence_penalty=config.presence_penalty,
             frequency_penalty=config.frequency_penalty,
-            response_logprobs=config.logprobs,
-            logprobs=config.top_logprobs,
         )
 
         # google-native messages
@@ -176,18 +172,15 @@
             response=response,
         )
 
-        # cast to AsyncGenerateContentResponse since we passed stream=False
         try:
-            response = cast(
-                AsyncGenerateContentResponse,
-                await self.model.generate_content_async(
-                    contents=contents,
-                    safety_settings=self.safety_settings,
-                    generation_config=parameters,
-                    tools=gemini_tools,
-                    tool_config=gemini_tool_config,
-                ),
+            response = await self.model.generate_content_async(
+                contents=contents,
+                safety_settings=self.safety_settings,
+                generation_config=parameters,
+                tools=gemini_tools,
+                tool_config=gemini_tool_config,
             )
+
         except InvalidArgument as ex:
             return self.handle_invalid_argument(ex), model_call()
 
@@ -205,15 +198,13 @@
         # return
         return output, model_call()
 
-    def handle_invalid_argument(self, ex: InvalidArgument) -> ModelOutput:
+    def handle_invalid_argument(self, ex: InvalidArgument) -> ModelOutput | Exception:
        if "size exceeds the limit" in ex.message.lower():
             return ModelOutput.from_content(
                 model=self.model_name, content=ex.message, stop_reason="model_length"
             )
         else:
-            return ModelOutput.from_content(
-                model=self.model_name, content=ex.message, stop_reason="unknown"
-            )
+            return ex
 
     @override
     def is_rate_limit(self, ex: BaseException) -> bool:
@@ -231,7 +222,7 @@ class GoogleAPI(ModelAPI):
 def build_model_call(
     contents: list[ContentDict],
     generation_config: GenerationConfig,
-    safety_settings: SafetySettingDict,
+    safety_settings: EasySafetySettingDict,
     tools: list[Tool] | None,
     tool_config: ToolConfig | None,
     response: AsyncGenerateContentResponse | None,
@@ -248,7 +239,7 @@ def build_model_call(
             if tool_config is not None
             else None,
         ),
-        response=response.to_dict() if response is not None else {},
+        response=response.to_dict() if response is not None else {},  # type: ignore[no-untyped-call]
         filter=model_call_filter,
     )
 
@@ -269,12 +260,12 @@ def model_call_content(content: ContentDict) -> ContentDict:
 
 def model_call_part(part: PartType) -> PartType:
     if isinstance(part, proto.Message):
-        return MessageToDict(part._pb)
+        return cast(PartDict, MessageToDict(part._pb))
     elif isinstance(part, dict):
         part = part.copy()
         keys = list(part.keys())
         for key in keys:
-            part[key] = model_call_part(part[key])
+            part[key] = model_call_part(part[key])  # type: ignore[literal-required]
         return part
     else:
         return part
@@ -316,9 +307,6 @@ def consective_tool_message_reducer(
     return messages
 
 
-NO_CONTENT = "(no content)"
-
-
 async def content_dict(
     message: ChatMessageUser | ChatMessageAssistant | ChatMessageTool,
 ) -> ContentDict:
@@ -326,13 +314,13 @@ async def content_dict(
         return ContentDict(
             role="user",
             parts=(
-                [PartDict(text=message.content or NO_CONTENT)]
+                [message.content or NO_CONTENT]
                 if isinstance(message.content, str)
                 else [await content_part(content) for content in message.content]
             ),
         )
     elif isinstance(message, ChatMessageAssistant):
-        content_parts: list[Part] = []
+        content_parts: list[PartType] = []
         # tool call parts
         if message.tool_calls is not None:
             content_parts.extend(
@@ -383,9 +371,9 @@ def dict_to_struct(x: dict[str, Any]) -> Struct:
 
 async def content_part(content: Content | str) -> PartType:
     if isinstance(content, str):
-        return PartDict(text=content or NO_CONTENT)
+        return content or NO_CONTENT
     elif isinstance(content, ContentText):
-        return PartDict(text=content.text or NO_CONTENT)
+        return content.text or NO_CONTENT
     else:
         return await chat_content_to_part(content)
 
@@ -404,7 +392,9 @@ def prepend_system_messages(
     messages: list[ContentDict], system_messages: list[ChatMessageSystem]
 ) -> None:
     # create system_parts
-    system_parts = [Part(text=message.content) for message in system_messages]
+    system_parts: list[PartType] = [
+        Part(text=message.content) for message in system_messages
+    ]
 
     # we want the system messages to be prepended to the first user message
     # (if there is no first user message then prepend one)
@@ -476,6 +466,8 @@ def schema_from_param(param: ToolParam | ToolParams, nullable: bool = False) ->
             return schema_from_param(param.anyOf[0], nullable=True)
         else:
             return Schema(type=Type.TYPE_UNSPECIFIED)
+    elif param.enum:
+        return Schema(type=Type.STRING, format="enum", enum=param.enum)
     else:
         return Schema(type=Type.TYPE_UNSPECIFIED)
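
The new `elif param.enum` branch means Gemini tool declarations now carry enumerated string parameters instead of falling through to `TYPE_UNSPECIFIED`. A sketch of a tool that would exercise it, assuming (not shown in this diff) that a `Literal` annotation populates `ToolParam.enum`; the tool name and values are invented:

# Illustrative only: a tool parameter restricted to a fixed set of values.
# Assuming the Literal annotation yields ToolParam.enum, schema_from_param()
# above now emits Schema(type=Type.STRING, format="enum", enum=[...]).
from typing import Literal

from inspect_ai.tool import Tool, tool


@tool
def set_color() -> Tool:
    async def execute(color: Literal["red", "green", "blue"]) -> str:
        """Set the display color.

        Args:
            color: One of the supported colors.
        """
        return f"color set to {color}"

    return execute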

@@ -600,14 +592,14 @@ def gapi_should_retry(ex: BaseException) -> bool:
 
 def parse_safety_settings(
     safety_settings: Any,
-) -> dict[HarmCategory, HarmBlockThreshold]:
+) -> EasySafetySettingDict:
     # ensure we have a dict
     if isinstance(safety_settings, str):
         safety_settings = json.loads(safety_settings)
     if not isinstance(safety_settings, dict):
         raise ValueError(f"{SAFETY_SETTINGS} must be dictionary.")
 
-    parsed_settings: dict[HarmCategory, HarmBlockThreshold] = {}
+    parsed_settings: EasySafetySettingDict = {}
     for key, value in safety_settings.items():
         if isinstance(key, str):
             key = str_to_harm_category(key)
@@ -623,23 +615,23 @@
     return parsed_settings
 
 
-def str_to_harm_category(category: str) -> HarmCategory:
+def str_to_harm_category(category: str) -> int:
     category = category.upper()
     if "HARASSMENT" in category:
-        return HarmCategory.HARM_CATEGORY_HARASSMENT
+        return cast(int, HarmCategory.HARM_CATEGORY_HARASSMENT)
     elif "HATE_SPEECH" in category:
-        return HarmCategory.HARM_CATEGORY_HATE_SPEECH
+        return cast(int, HarmCategory.HARM_CATEGORY_HATE_SPEECH)
     elif "SEXUALLY_EXPLICIT" in category:
-        return HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT
+        return cast(int, HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT)
     elif "DANGEROUS_CONTENT" in category:
-        return HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT
+        return cast(int, HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT)
     else:
         # NOTE: Although there is an "UNSPECIFIED" category, in the
         # documentation, the API does not accept it.
         raise ValueError(f"Unknown HarmCategory: {category}")
 
 
-def str_to_harm_block_threshold(threshold: str) -> HarmBlockThreshold:
+def str_to_harm_block_threshold(threshold: str) -> int:
     threshold = threshold.upper()
     if "LOW" in threshold:
         return HarmBlockThreshold.BLOCK_LOW_AND_ABOVE
@@ -673,7 +665,7 @@ async def file_for_content(content: ContentAudio | ContentVideo) -> File:
     uploaded_file = files_db.get(content_sha256)
     if uploaded_file:
         try:
-            upload = cast(File, get_file(uploaded_file))
+            upload = get_file(uploaded_file)
             if upload.state.name == "ACTIVE":
                 trace(f"Using uploaded file: {uploaded_file}")
                 return upload
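
`parse_safety_settings` above accepts either a dict or a JSON string and resolves categories and thresholds by case-insensitive substring match. A hedged usage sketch; the task path and model id are placeholders, and routing the setting through `model_args` assumes the documented `-M` model-argument wiring:

# Hypothetical invocation (task file and model id are placeholders):
# "harassment" resolves to HARM_CATEGORY_HARASSMENT and "low_and_above" to
# BLOCK_LOW_AND_ABOVE (the one threshold branch visible in this diff).
from inspect_ai import eval

eval(
    "my_task.py",
    model="google/gemini-1.5-pro",
    model_args={"safety_settings": {"harassment": "low_and_above"}},
)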

inspect_ai/model/_providers/mistral.py

@@ -40,6 +40,7 @@ from typing_extensions import override
 # https://github.com/mistralai/client-python/blob/main/MIGRATION.md
 from inspect_ai._util.constants import (
     DEFAULT_TIMEOUT,
+    NO_CONTENT,
 )
 from inspect_ai._util.content import Content, ContentImage, ContentText
 from inspect_ai._util.images import file_as_data_uri
@@ -122,7 +123,7 @@ class MistralAPI(ModelAPI):
         tools: list[ToolInfo],
         tool_choice: ToolChoice,
         config: GenerateConfig,
-    ) -> ModelOutput | tuple[ModelOutput, ModelCall]:
+    ) -> ModelOutput | tuple[ModelOutput | Exception, ModelCall]:
         # build request
         request: dict[str, Any] = dict(
             model=self.model_name,
@@ -146,7 +147,7 @@
             response = await self.client.chat.complete_async(**request)
         except SDKError as ex:
             if ex.status_code == 400:
-                return self.handle_bad_request(ex)
+                return self.handle_bad_request(ex), mistral_model_call(request, None)
             else:
                 raise ex
 
@@ -181,25 +182,27 @@
     def connection_key(self) -> str:
         return str(self.api_key)
 
-    def handle_bad_request(self, ex: SDKError) -> ModelOutput:
+    def handle_bad_request(self, ex: SDKError) -> ModelOutput | Exception:
+        body = json.loads(ex.body)
+        content = body.get("message", ex.body)
         if "maximum context length" in ex.body:
-            body = json.loads(ex.body)
-            content = body.get("message", ex.body)
             return ModelOutput.from_content(
                 model=self.model_name, content=content, stop_reason="model_length"
             )
         else:
-            raise ex
+            return ex
 
 
 def mistral_model_call(
-    request: dict[str, Any], response: MistralChatCompletionResponse
+    request: dict[str, Any], response: MistralChatCompletionResponse | None
 ) -> ModelCall:
     request = request.copy()
     request.update(messages=[message.model_dump() for message in request["messages"]])
     if request.get("tools", None) is not None:
         request["tools"] = [tool.model_dump() for tool in request["tools"]]
-    return ModelCall(request=request, response=response.model_dump())
+    return ModelCall(
+        request=request, response=response.model_dump() if response else {}
+    )
 
 
 def mistral_chat_tools(tools: list[ToolInfo]) -> list[MistralTool]:
@@ -326,9 +329,6 @@ async def mistral_chat_message(
     )
 
 
-NO_CONTENT = "(no content)"
-
-
 async def mistral_message_content(
     content: str | list[Content],
 ) -> str | list[ContentChunk]:

inspect_ai/model/_providers/openai.py

@@ -166,7 +166,7 @@ class OpenAIAPI(ModelAPI):
         tools: list[ToolInfo],
         tool_choice: ToolChoice,
         config: GenerateConfig,
-    ) -> ModelOutput | tuple[ModelOutput, ModelCall]:
+    ) -> ModelOutput | tuple[ModelOutput | Exception, ModelCall]:
         # short-circuit to call o1- models that are text only
         if self.is_o1_preview() or self.is_o1_mini():
             return await generate_o1(
@@ -307,27 +307,26 @@ class OpenAIAPI(ModelAPI):
         return params
 
     # convert some well known bad request errors into ModelOutput
-    def handle_bad_request(self, e: BadRequestError) -> ModelOutput:
-        if e.status_code == 400:
-            # extract message
-            if isinstance(e.body, dict) and "message" in e.body.keys():
-                content = str(e.body.get("message"))
-            else:
-                content = e.message
+    def handle_bad_request(self, e: BadRequestError) -> ModelOutput | Exception:
+        # extract message
+        if isinstance(e.body, dict) and "message" in e.body.keys():
+            content = str(e.body.get("message"))
+        else:
+            content = e.message
 
-            # narrow stop_reason
-            if e.code == "context_length_exceeded":
-                stop_reason: StopReason = "model_length"
-            elif e.code == "invalid_prompt":
-                stop_reason = "content_filter"
-            else:
-                stop_reason = "unknown"
+        # narrow stop_reason
+        stop_reason: StopReason | None = None
+        if e.code == "context_length_exceeded":
+            stop_reason = "model_length"
+        elif e.code == "invalid_prompt":
+            stop_reason = "content_filter"
 
+        if stop_reason:
             return ModelOutput.from_content(
                 model=self.model_name, content=content, stop_reason=stop_reason
             )
         else:
-            raise e
+            return e
 
 
 async def as_openai_chat_messages(
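
A pattern repeated across the google, mistral, openai, openai_o1, and together providers in this release: `handle_bad_request` now returns the exception instead of raising it, and `generate` may return `ModelOutput | Exception`, so the caller can still record the `ModelCall` before deciding how to surface the failure. A minimal stand-alone sketch of that contract, with simplified, hypothetical types (not the actual inspect_ai internals):

# Known bad-request errors map to a normal output value; anything else is
# handed back to the caller, and the request is captured for the log either way.
def handle_bad_request(ex: Exception) -> str | Exception:
    if "maximum context length" in str(ex):
        return "model_length"  # a recognized refusal becomes normal output
    return ex  # unrecognized: caller decides whether to raise or retry


def generate(request: dict) -> tuple[str | Exception, dict]:
    try:
        raise RuntimeError("maximum context length exceeded")  # simulated 400
    except RuntimeError as ex:
        # response is empty, but the request side of the call is still recorded
        return handle_bad_request(ex), {"request": request, "response": {}}


output, call = generate({"model": "demo"})
assert output == "model_length" and call["response"] == {}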

inspect_ai/model/_providers/openai_o1.py

@@ -44,7 +44,7 @@ async def generate_o1(
     input: list[ChatMessage],
     tools: list[ToolInfo],
     **params: Any,
-) -> ModelOutput | tuple[ModelOutput, ModelCall]:
+) -> ModelOutput | tuple[ModelOutput | Exception, ModelCall]:
     # create chatapi handler
     handler = O1PreviewChatAPIHandler()
 
@@ -82,17 +82,18 @@
     ), model_call()
 
 
-def handle_bad_request(model: str, ex: BadRequestError) -> ModelOutput:
+def handle_bad_request(model: str, ex: BadRequestError) -> ModelOutput | Exception:
     if ex.code == "context_length_exceeded":
-        stop_reason: StopReason = "model_length"
+        stop_reason: StopReason | None = "model_length"
     elif ex.code == "invalid_prompt":
         stop_reason = "content_filter"
-    else:
-        stop_reason = "unknown"
 
-    return ModelOutput.from_content(
-        model=model, content=str(ex), stop_reason=stop_reason
-    )
+    if stop_reason:
+        return ModelOutput.from_content(
+            model=model, content=str(ex), stop_reason=stop_reason
+        )
+    else:
+        return ex
 
 
 def chat_messages(

inspect_ai/model/_providers/providers.py

@@ -94,7 +94,7 @@ def vertex() -> type[ModelAPI]:
 def google() -> type[ModelAPI]:
     FEATURE = "Google API"
     PACKAGE = "google-generativeai"
-    MIN_VERSION = "0.8.3"
+    MIN_VERSION = "0.8.4"
 
     # workaround log spam
     # https://github.com/ray-project/ray/issues/24917

inspect_ai/model/_providers/together.py

@@ -103,18 +103,18 @@ class TogetherAIAPI(OpenAIAPI):
         return DEFAULT_MAX_TOKENS
 
     @override
-    def handle_bad_request(self, ex: BadRequestError) -> ModelOutput:
-        if ex.status_code == 400 and "max_new_tokens" in ex.message:
-            response = ex.response.json()
-            if "error" in response and "message" in response.get("error"):
-                content = response.get("error").get("message")
-            else:
-                content = str(response)
+    def handle_bad_request(self, ex: BadRequestError) -> ModelOutput | Exception:
+        response = ex.response.json()
+        if "error" in response and "message" in response.get("error"):
+            content = response.get("error").get("message")
+        else:
+            content = str(response)
+        if "max_new_tokens" in ex.message:
             return ModelOutput.from_content(
                 model=self.model_name, content=content, stop_reason="model_length"
             )
         else:
-            raise ex
+            return ex
 
     # Together has a slightly different logprobs structure to OpenAI, so we need to remap it.
     def _chat_choices_from_response(

inspect_ai/model/_providers/vertex.py

@@ -23,7 +23,7 @@ from vertexai.generative_models import ( # type: ignore
 )
 from vertexai.generative_models import Content as VertexContent
 
-from inspect_ai._util.constants import BASE_64_DATA_REMOVED
+from inspect_ai._util.constants import BASE_64_DATA_REMOVED, NO_CONTENT
 from inspect_ai._util.content import (
     Content,
     ContentAudio,
@@ -250,9 +250,6 @@ def consective_tool_message_reducer(
     return messages
 
 
-NO_CONTENT = "(no content)"
-
-
 async def content_dict(
     message: ChatMessageUser | ChatMessageAssistant | ChatMessageTool,
 ) -> VertexContent:

inspect_ai/scorer/_reducer/reducer.py

@@ -111,7 +111,7 @@ def pass_at(
         if total - correct < k:
             return 1.0
         else:
-            return 1.0 - cast(
+            return 1.0 - cast(  # type: ignore[redundant-cast]
                 float,
                 np.prod(1.0 - k / np.arange(total - correct + 1, total + 1)).item(),
            )
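
For reference, the expression above is the numerically stable product form of the unbiased pass@k estimator, pass@k = 1 - C(n-c, k)/C(n, k). A quick consistency check (the names n, c, k are ours, not the reducer's):

# Verify the product form used above equals the combinatorial estimator.
import math

import numpy as np

n, c, k = 10, 3, 5  # total samples, correct samples, k
product_form = 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1)).item()
combinatorial = 1.0 - math.comb(n - c, k) / math.comb(n, k)
assert math.isclose(product_form, combinatorial)  # both are approximately 0.9167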

inspect_ai/scorer/_scorer.py

@@ -151,8 +151,8 @@ def scorer_metrics(
         return cast(list[Metric | dict[str, list[Metric]]], metrics_raw)
 
 
-def unique_scorer_name(scorer: Scorer, already_used_names: list[str]) -> str:
-    base_name = registry_unqualified_name(scorer)
+def unique_scorer_name(scorer: Scorer | str, already_used_names: list[str]) -> str:
+    base_name = scorer if isinstance(scorer, str) else registry_unqualified_name(scorer)
     scorer_name = base_name
     count = 1
     while scorer_name in already_used_names:
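
With `unique_scorer_name` now accepting a plain `str`, a pre-chosen base name can be deduplicated directly. A hypothetical stand-alone rendering of the naming loop for illustration; the body beyond the context shown in this diff is assumed, not copied from the source:

# Assumed suffixing scheme: first collision yields "name2", then "name3", ...
def unique_name(base: str, already_used: list[str]) -> str:
    name, count = base, 1
    while name in already_used:
        count += 1
        name = f"{base}{count}"
    return name


assert unique_name("accuracy", ["accuracy"]) == "accuracy2"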

inspect_ai/solver/__init__.py

@@ -7,11 +7,7 @@ from ._fork import fork
 from ._human_agent.agent import human_agent
 from ._multiple_choice import MultipleChoiceTemplate, multiple_choice
 from ._plan import Plan, plan
-from ._prompt import (
-    chain_of_thought,
-    prompt_template,
-    system_message,
-)
+from ._prompt import chain_of_thought, prompt_template, system_message, user_message
 from ._solver import Generate, Solver, SolverSpec, generate, solver
 from ._task_state import Choice, Choices, TaskState
 from ._use_tools import use_tools
@@ -26,6 +22,7 @@ __all__ = [
     "chain_of_thought",
     "multiple_choice",
     "system_message",
+    "user_message",
     "self_critique",
     "use_tools",
     "plan",

inspect_ai/solver/_prompt.py

@@ -2,6 +2,7 @@ from typing import Any
 
 from inspect_ai._util.dict import omit
 from inspect_ai.model import ChatMessageSystem
+from inspect_ai.model._chat_message import ChatMessageUser
 from inspect_ai.util import resource
 
 from ._solver import Generate, Solver, solver
@@ -15,7 +16,8 @@ def prompt_template(template: str, **params: Any) -> Solver:
 
     Prompt template containing a `{prompt}` placeholder and any
     number of additional `params`. All values contained in sample
-    `metadata` are also automatically included in the `params`.
+    `metadata` and `store` are also automatically included in the
+    `params`.
 
     Args:
       template: (str): Template for prompt.
@@ -29,7 +31,7 @@ def prompt_template(template: str, **params: Any) -> Solver:
 
     async def solve(state: TaskState, generate: Generate) -> TaskState:
         prompt = state.user_prompt
-        kwargs = omit(state.metadata, ["prompt"]) | params
+        kwargs = omit(state.metadata | state.store._data, ["prompt"]) | params
         prompt.text = prompt_template.format(prompt=prompt.text, **kwargs)
         return state
 
@@ -41,8 +43,9 @@ def system_message(template: str, **params: Any) -> Solver:
     """Solver which inserts a system message into the conversation.
 
     System message template containing any number of optional `params`.
-    for substitution. All values contained in sample `metadata` are also
-    automatically included in the `params`.
+    for substitution using the `str.format()` method. All values
+    contained in sample `metadata` and `store` are also automatically
+    included in the `params`.
 
     The new message will go after other system messages (if there
     are none it will be inserted at the beginning of the conversation).
@@ -58,7 +61,7 @@ def system_message(template: str, **params: Any) -> Solver:
     content = resource(template)
 
     async def solve(state: TaskState, generate: Generate) -> TaskState:
-        kwargs = state.metadata | params
+        kwargs = state.metadata | state.store._data | params
         append_system_message(
             state.messages, ChatMessageSystem(content=content.format(**kwargs))
         )
@@ -67,6 +70,33 @@ def system_message(template: str, **params: Any) -> Solver:
     return solve
 
 
+@solver
+def user_message(template: str, **params: Any) -> Solver:
+    """Solver which inserts a user message into the conversation.
+
+    User message template containing any number of optional `params`.
+    for substitution using the `str.format()` method. All values
+    contained in sample `metadata` and `store` are also automatically
+    included in the `params`.
+
+    Args:
+      template (str): Template for user message.
+      **params (dict[str,Any]): Parameters to fill into the template.
+
+    Returns:
+      A solver that inserts the parameterised user message.
+    """
+    # read template
+    content = resource(template)
+
+    async def solve(state: TaskState, generate: Generate) -> TaskState:
+        kwargs = state.metadata | state.store._data | params
+        state.messages.append(ChatMessageUser(content=content.format(**kwargs)))
+        return state
+
+    return solve
+
+
 DEFAULT_COT_TEMPLATE = r"""
 {prompt}
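
A usage sketch for the new `user_message()` solver exported in this release; the dataset, scorer, and template text are placeholders. Template params draw from sample `metadata`, the `store`, and any explicit keyword arguments:

# Hypothetical task wiring: {subject} is filled from sample metadata by both
# the system_message() and (new) user_message() solvers.
from inspect_ai import Task, task
from inspect_ai.dataset import Sample
from inspect_ai.scorer import match
from inspect_ai.solver import generate, system_message, user_message


@task
def demo() -> Task:
    return Task(
        dataset=[Sample(input="2 + 2 = ?", target="4", metadata={"subject": "math"})],
        solver=[
            system_message("You are a careful {subject} tutor."),
            user_message("Answer with the number only."),
            generate(),
        ],
        scorer=match(),
    )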