inspect-ai 0.3.58__py3-none-any.whl → 0.3.59__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- inspect_ai/_cli/common.py +3 -1
- inspect_ai/_cli/eval.py +15 -2
- inspect_ai/_display/core/active.py +4 -1
- inspect_ai/_display/core/config.py +3 -3
- inspect_ai/_display/core/panel.py +7 -3
- inspect_ai/_display/plain/__init__.py +0 -0
- inspect_ai/_display/plain/display.py +203 -0
- inspect_ai/_display/rich/display.py +0 -5
- inspect_ai/_display/textual/widgets/port_mappings.py +110 -0
- inspect_ai/_display/textual/widgets/samples.py +78 -11
- inspect_ai/_display/textual/widgets/sandbox.py +37 -0
- inspect_ai/_eval/score.py +1 -0
- inspect_ai/_eval/task/results.py +50 -22
- inspect_ai/_eval/task/run.py +41 -7
- inspect_ai/_eval/task/sandbox.py +10 -5
- inspect_ai/_util/constants.py +1 -0
- inspect_ai/_util/port_names.py +61 -0
- inspect_ai/_util/text.py +23 -0
- inspect_ai/_view/www/App.css +31 -1
- inspect_ai/_view/www/dist/assets/index.css +31 -1
- inspect_ai/_view/www/dist/assets/index.js +25344 -1849
- inspect_ai/_view/www/log-schema.json +32 -2
- inspect_ai/_view/www/package.json +2 -0
- inspect_ai/_view/www/src/App.mjs +8 -10
- inspect_ai/_view/www/src/Types.mjs +0 -1
- inspect_ai/_view/www/src/components/ChatView.mjs +133 -43
- inspect_ai/_view/www/src/components/ExpandablePanel.mjs +0 -4
- inspect_ai/_view/www/src/components/LargeModal.mjs +19 -20
- inspect_ai/_view/www/src/components/TabSet.mjs +3 -1
- inspect_ai/_view/www/src/components/VirtualList.mjs +266 -84
- inspect_ai/_view/www/src/index.js +75 -2
- inspect_ai/_view/www/src/navbar/Navbar.mjs +3 -0
- inspect_ai/_view/www/src/navbar/SecondaryBar.mjs +18 -9
- inspect_ai/_view/www/src/samples/SampleDialog.mjs +5 -1
- inspect_ai/_view/www/src/samples/SampleDisplay.mjs +23 -15
- inspect_ai/_view/www/src/samples/SampleList.mjs +18 -48
- inspect_ai/_view/www/src/samples/SampleTranscript.mjs +8 -3
- inspect_ai/_view/www/src/samples/SamplesDescriptor.mjs +24 -12
- inspect_ai/_view/www/src/samples/SamplesTab.mjs +4 -1
- inspect_ai/_view/www/src/samples/SamplesTools.mjs +8 -8
- inspect_ai/_view/www/src/samples/tools/SampleFilter.mjs +712 -89
- inspect_ai/_view/www/src/samples/tools/filters.mjs +260 -87
- inspect_ai/_view/www/src/samples/transcript/ErrorEventView.mjs +24 -2
- inspect_ai/_view/www/src/samples/transcript/EventPanel.mjs +29 -24
- inspect_ai/_view/www/src/samples/transcript/EventRow.mjs +1 -1
- inspect_ai/_view/www/src/samples/transcript/InfoEventView.mjs +24 -2
- inspect_ai/_view/www/src/samples/transcript/InputEventView.mjs +24 -2
- inspect_ai/_view/www/src/samples/transcript/ModelEventView.mjs +31 -10
- inspect_ai/_view/www/src/samples/transcript/SampleInitEventView.mjs +24 -2
- inspect_ai/_view/www/src/samples/transcript/SampleLimitEventView.mjs +23 -2
- inspect_ai/_view/www/src/samples/transcript/ScoreEventView.mjs +24 -2
- inspect_ai/_view/www/src/samples/transcript/StepEventView.mjs +33 -3
- inspect_ai/_view/www/src/samples/transcript/SubtaskEventView.mjs +25 -2
- inspect_ai/_view/www/src/samples/transcript/ToolEventView.mjs +25 -2
- inspect_ai/_view/www/src/samples/transcript/TranscriptView.mjs +193 -11
- inspect_ai/_view/www/src/samples/transcript/Types.mjs +10 -0
- inspect_ai/_view/www/src/samples/transcript/state/StateEventView.mjs +26 -2
- inspect_ai/_view/www/src/types/log.d.ts +13 -2
- inspect_ai/_view/www/src/utils/Format.mjs +10 -3
- inspect_ai/_view/www/src/utils/Json.mjs +12 -6
- inspect_ai/_view/www/src/workspace/WorkSpace.mjs +10 -4
- inspect_ai/_view/www/vite.config.js +7 -0
- inspect_ai/_view/www/yarn.lock +116 -0
- inspect_ai/approval/_human/__init__.py +0 -0
- inspect_ai/approval/_policy.py +12 -6
- inspect_ai/log/_log.py +1 -1
- inspect_ai/log/_samples.py +16 -0
- inspect_ai/log/_transcript.py +4 -1
- inspect_ai/model/_call_tools.py +4 -0
- inspect_ai/model/_conversation.py +20 -8
- inspect_ai/model/_generate_config.py +10 -4
- inspect_ai/model/_model.py +117 -18
- inspect_ai/model/_model_output.py +7 -2
- inspect_ai/model/_providers/anthropic.py +100 -44
- inspect_ai/model/_providers/azureai.py +20 -20
- inspect_ai/model/_providers/bedrock.py +37 -40
- inspect_ai/model/_providers/google.py +46 -54
- inspect_ai/model/_providers/mistral.py +11 -11
- inspect_ai/model/_providers/openai.py +15 -16
- inspect_ai/model/_providers/openai_o1.py +9 -8
- inspect_ai/model/_providers/providers.py +1 -1
- inspect_ai/model/_providers/together.py +8 -8
- inspect_ai/model/_providers/vertex.py +1 -4
- inspect_ai/scorer/_reducer/reducer.py +1 -1
- inspect_ai/scorer/_scorer.py +2 -2
- inspect_ai/solver/__init__.py +2 -5
- inspect_ai/solver/_prompt.py +35 -5
- inspect_ai/solver/_task_state.py +80 -38
- inspect_ai/tool/__init__.py +2 -0
- inspect_ai/tool/_tool.py +12 -1
- inspect_ai/tool/_tool_call.py +10 -0
- inspect_ai/tool/_tool_def.py +16 -5
- inspect_ai/tool/_tool_with.py +21 -4
- inspect_ai/tool/beta/__init__.py +5 -0
- inspect_ai/tool/beta/_computer/__init__.py +3 -0
- inspect_ai/tool/beta/_computer/_common.py +133 -0
- inspect_ai/tool/beta/_computer/_computer.py +155 -0
- inspect_ai/tool/beta/_computer/_computer_split.py +198 -0
- inspect_ai/tool/beta/_computer/_resources/Dockerfile +100 -0
- inspect_ai/tool/beta/_computer/_resources/README.md +30 -0
- inspect_ai/tool/beta/_computer/_resources/entrypoint/entrypoint.sh +18 -0
- inspect_ai/tool/beta/_computer/_resources/entrypoint/novnc_startup.sh +20 -0
- inspect_ai/tool/beta/_computer/_resources/entrypoint/x11vnc_startup.sh +48 -0
- inspect_ai/tool/beta/_computer/_resources/entrypoint/xfce_startup.sh +13 -0
- inspect_ai/tool/beta/_computer/_resources/entrypoint/xvfb_startup.sh +48 -0
- inspect_ai/tool/beta/_computer/_resources/image_home_dir/Desktop/Firefox Web Browser.desktop +10 -0
- inspect_ai/tool/beta/_computer/_resources/image_home_dir/Desktop/Visual Studio Code.desktop +10 -0
- inspect_ai/tool/beta/_computer/_resources/image_home_dir/Desktop/XPaint.desktop +10 -0
- inspect_ai/tool/beta/_computer/_resources/tool/__init__.py +0 -0
- inspect_ai/tool/beta/_computer/_resources/tool/_logger.py +22 -0
- inspect_ai/tool/beta/_computer/_resources/tool/_run.py +42 -0
- inspect_ai/tool/beta/_computer/_resources/tool/_tool_result.py +33 -0
- inspect_ai/tool/beta/_computer/_resources/tool/_x11_client.py +262 -0
- inspect_ai/tool/beta/_computer/_resources/tool/computer_tool.py +85 -0
- inspect_ai/tool/beta/_computer/_resources/tool/requirements.txt +0 -0
- inspect_ai/util/__init__.py +2 -0
- inspect_ai/util/_limit.py +26 -0
- inspect_ai/util/_sandbox/docker/docker.py +64 -1
- inspect_ai/util/_sandbox/docker/internal.py +3 -1
- inspect_ai/util/_sandbox/environment.py +14 -0
- {inspect_ai-0.3.58.dist-info → inspect_ai-0.3.59.dist-info}/METADATA +2 -2
- {inspect_ai-0.3.58.dist-info → inspect_ai-0.3.59.dist-info}/RECORD +126 -98
- inspect_ai/_view/www/src/samples/transcript/TranscriptState.mjs +0 -70
- {inspect_ai-0.3.58.dist-info → inspect_ai-0.3.59.dist-info}/LICENSE +0 -0
- {inspect_ai-0.3.58.dist-info → inspect_ai-0.3.59.dist-info}/WHEEL +0 -0
- {inspect_ai-0.3.58.dist-info → inspect_ai-0.3.59.dist-info}/entry_points.txt +0 -0
- {inspect_ai-0.3.58.dist-info → inspect_ai-0.3.59.dist-info}/top_level.txt +0 -0
@@ -11,7 +11,6 @@ import proto # type: ignore
|
|
11
11
|
from google.ai.generativelanguage import (
|
12
12
|
Blob,
|
13
13
|
Candidate,
|
14
|
-
File,
|
15
14
|
FunctionCall,
|
16
15
|
FunctionCallingConfig,
|
17
16
|
FunctionDeclaration,
|
@@ -29,29 +28,29 @@ from google.api_core.exceptions import (
|
|
29
28
|
TooManyRequests,
|
30
29
|
)
|
31
30
|
from google.api_core.retry.retry_base import if_transient_error
|
32
|
-
from google.generativeai import
|
33
|
-
|
34
|
-
|
35
|
-
|
36
|
-
get_file,
|
37
|
-
upload_file,
|
38
|
-
)
|
39
|
-
from google.generativeai.types import ( # type: ignore
|
40
|
-
AsyncGenerateContentResponse,
|
31
|
+
from google.generativeai.client import configure
|
32
|
+
from google.generativeai.files import get_file, upload_file
|
33
|
+
from google.generativeai.generative_models import GenerativeModel
|
34
|
+
from google.generativeai.types import (
|
41
35
|
ContentDict,
|
42
|
-
|
43
|
-
HarmCategory,
|
36
|
+
GenerationConfig,
|
44
37
|
PartDict,
|
45
38
|
PartType,
|
46
|
-
SafetySettingDict,
|
47
39
|
Tool,
|
48
40
|
)
|
41
|
+
from google.generativeai.types.file_types import File
|
42
|
+
from google.generativeai.types.generation_types import AsyncGenerateContentResponse
|
43
|
+
from google.generativeai.types.safety_types import (
|
44
|
+
EasySafetySettingDict,
|
45
|
+
HarmBlockThreshold,
|
46
|
+
HarmCategory,
|
47
|
+
)
|
49
48
|
from google.protobuf.json_format import MessageToDict, ParseDict
|
50
49
|
from google.protobuf.struct_pb2 import Struct
|
51
50
|
from pydantic import JsonValue
|
52
51
|
from typing_extensions import override
|
53
52
|
|
54
|
-
from inspect_ai._util.constants import BASE_64_DATA_REMOVED
|
53
|
+
from inspect_ai._util.constants import BASE_64_DATA_REMOVED, NO_CONTENT
|
55
54
|
from inspect_ai._util.content import (
|
56
55
|
Content,
|
57
56
|
ContentAudio,
|
@@ -89,7 +88,7 @@ logger = getLogger(__name__)
|
|
89
88
|
|
90
89
|
SAFETY_SETTINGS = "safety_settings"
|
91
90
|
|
92
|
-
DEFAULT_SAFETY_SETTINGS:
|
91
|
+
DEFAULT_SAFETY_SETTINGS: EasySafetySettingDict = {
|
93
92
|
HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
|
94
93
|
HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
|
95
94
|
HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
|
@@ -141,7 +140,7 @@ class GoogleAPI(ModelAPI):
|
|
141
140
|
tools: list[ToolInfo],
|
142
141
|
tool_choice: ToolChoice,
|
143
142
|
config: GenerateConfig,
|
144
|
-
) -> ModelOutput | tuple[ModelOutput, ModelCall]:
|
143
|
+
) -> ModelOutput | tuple[ModelOutput | Exception, ModelCall]:
|
145
144
|
parameters = GenerationConfig(
|
146
145
|
temperature=config.temperature,
|
147
146
|
top_p=config.top_p,
|
@@ -149,11 +148,8 @@ class GoogleAPI(ModelAPI):
|
|
149
148
|
max_output_tokens=config.max_tokens,
|
150
149
|
stop_sequences=config.stop_seqs,
|
151
150
|
candidate_count=config.num_choices,
|
152
|
-
seed=config.seed,
|
153
151
|
presence_penalty=config.presence_penalty,
|
154
152
|
frequency_penalty=config.frequency_penalty,
|
155
|
-
response_logprobs=config.logprobs,
|
156
|
-
logprobs=config.top_logprobs,
|
157
153
|
)
|
158
154
|
|
159
155
|
# google-native messages
|
@@ -176,18 +172,15 @@ class GoogleAPI(ModelAPI):
|
|
176
172
|
response=response,
|
177
173
|
)
|
178
174
|
|
179
|
-
# cast to AsyncGenerateContentResponse since we passed stream=False
|
180
175
|
try:
|
181
|
-
response =
|
182
|
-
|
183
|
-
|
184
|
-
|
185
|
-
|
186
|
-
|
187
|
-
tools=gemini_tools,
|
188
|
-
tool_config=gemini_tool_config,
|
189
|
-
),
|
176
|
+
response = await self.model.generate_content_async(
|
177
|
+
contents=contents,
|
178
|
+
safety_settings=self.safety_settings,
|
179
|
+
generation_config=parameters,
|
180
|
+
tools=gemini_tools,
|
181
|
+
tool_config=gemini_tool_config,
|
190
182
|
)
|
183
|
+
|
191
184
|
except InvalidArgument as ex:
|
192
185
|
return self.handle_invalid_argument(ex), model_call()
|
193
186
|
|
@@ -205,15 +198,13 @@ class GoogleAPI(ModelAPI):
|
|
205
198
|
# return
|
206
199
|
return output, model_call()
|
207
200
|
|
208
|
-
def handle_invalid_argument(self, ex: InvalidArgument) -> ModelOutput:
|
201
|
+
def handle_invalid_argument(self, ex: InvalidArgument) -> ModelOutput | Exception:
|
209
202
|
if "size exceeds the limit" in ex.message.lower():
|
210
203
|
return ModelOutput.from_content(
|
211
204
|
model=self.model_name, content=ex.message, stop_reason="model_length"
|
212
205
|
)
|
213
206
|
else:
|
214
|
-
return
|
215
|
-
model=self.model_name, content=ex.message, stop_reason="unknown"
|
216
|
-
)
|
207
|
+
return ex
|
217
208
|
|
218
209
|
@override
|
219
210
|
def is_rate_limit(self, ex: BaseException) -> bool:
|
@@ -231,7 +222,7 @@ class GoogleAPI(ModelAPI):
|
|
231
222
|
def build_model_call(
|
232
223
|
contents: list[ContentDict],
|
233
224
|
generation_config: GenerationConfig,
|
234
|
-
safety_settings:
|
225
|
+
safety_settings: EasySafetySettingDict,
|
235
226
|
tools: list[Tool] | None,
|
236
227
|
tool_config: ToolConfig | None,
|
237
228
|
response: AsyncGenerateContentResponse | None,
|
@@ -248,7 +239,7 @@ def build_model_call(
|
|
248
239
|
if tool_config is not None
|
249
240
|
else None,
|
250
241
|
),
|
251
|
-
response=response.to_dict() if response is not None else {},
|
242
|
+
response=response.to_dict() if response is not None else {}, # type: ignore[no-untyped-call]
|
252
243
|
filter=model_call_filter,
|
253
244
|
)
|
254
245
|
|
@@ -269,12 +260,12 @@ def model_call_content(content: ContentDict) -> ContentDict:
|
|
269
260
|
|
270
261
|
def model_call_part(part: PartType) -> PartType:
|
271
262
|
if isinstance(part, proto.Message):
|
272
|
-
return MessageToDict(part._pb)
|
263
|
+
return cast(PartDict, MessageToDict(part._pb))
|
273
264
|
elif isinstance(part, dict):
|
274
265
|
part = part.copy()
|
275
266
|
keys = list(part.keys())
|
276
267
|
for key in keys:
|
277
|
-
part[key] = model_call_part(part[key])
|
268
|
+
part[key] = model_call_part(part[key]) # type: ignore[literal-required]
|
278
269
|
return part
|
279
270
|
else:
|
280
271
|
return part
|
@@ -316,9 +307,6 @@ def consective_tool_message_reducer(
|
|
316
307
|
return messages
|
317
308
|
|
318
309
|
|
319
|
-
NO_CONTENT = "(no content)"
|
320
|
-
|
321
|
-
|
322
310
|
async def content_dict(
|
323
311
|
message: ChatMessageUser | ChatMessageAssistant | ChatMessageTool,
|
324
312
|
) -> ContentDict:
|
@@ -326,13 +314,13 @@ async def content_dict(
|
|
326
314
|
return ContentDict(
|
327
315
|
role="user",
|
328
316
|
parts=(
|
329
|
-
[
|
317
|
+
[message.content or NO_CONTENT]
|
330
318
|
if isinstance(message.content, str)
|
331
319
|
else [await content_part(content) for content in message.content]
|
332
320
|
),
|
333
321
|
)
|
334
322
|
elif isinstance(message, ChatMessageAssistant):
|
335
|
-
content_parts: list[
|
323
|
+
content_parts: list[PartType] = []
|
336
324
|
# tool call parts
|
337
325
|
if message.tool_calls is not None:
|
338
326
|
content_parts.extend(
|
@@ -383,9 +371,9 @@ def dict_to_struct(x: dict[str, Any]) -> Struct:
|
|
383
371
|
|
384
372
|
async def content_part(content: Content | str) -> PartType:
|
385
373
|
if isinstance(content, str):
|
386
|
-
return
|
374
|
+
return content or NO_CONTENT
|
387
375
|
elif isinstance(content, ContentText):
|
388
|
-
return
|
376
|
+
return content.text or NO_CONTENT
|
389
377
|
else:
|
390
378
|
return await chat_content_to_part(content)
|
391
379
|
|
@@ -404,7 +392,9 @@ def prepend_system_messages(
|
|
404
392
|
messages: list[ContentDict], system_messages: list[ChatMessageSystem]
|
405
393
|
) -> None:
|
406
394
|
# create system_parts
|
407
|
-
system_parts = [
|
395
|
+
system_parts: list[PartType] = [
|
396
|
+
Part(text=message.content) for message in system_messages
|
397
|
+
]
|
408
398
|
|
409
399
|
# we want the system messages to be prepended to the first user message
|
410
400
|
# (if there is no first user message then prepend one)
|
@@ -476,6 +466,8 @@ def schema_from_param(param: ToolParam | ToolParams, nullable: bool = False) ->
|
|
476
466
|
return schema_from_param(param.anyOf[0], nullable=True)
|
477
467
|
else:
|
478
468
|
return Schema(type=Type.TYPE_UNSPECIFIED)
|
469
|
+
elif param.enum:
|
470
|
+
return Schema(type=Type.STRING, format="enum", enum=param.enum)
|
479
471
|
else:
|
480
472
|
return Schema(type=Type.TYPE_UNSPECIFIED)
|
481
473
|
|
@@ -600,14 +592,14 @@ def gapi_should_retry(ex: BaseException) -> bool:
|
|
600
592
|
|
601
593
|
def parse_safety_settings(
|
602
594
|
safety_settings: Any,
|
603
|
-
) ->
|
595
|
+
) -> EasySafetySettingDict:
|
604
596
|
# ensure we have a dict
|
605
597
|
if isinstance(safety_settings, str):
|
606
598
|
safety_settings = json.loads(safety_settings)
|
607
599
|
if not isinstance(safety_settings, dict):
|
608
600
|
raise ValueError(f"{SAFETY_SETTINGS} must be dictionary.")
|
609
601
|
|
610
|
-
parsed_settings:
|
602
|
+
parsed_settings: EasySafetySettingDict = {}
|
611
603
|
for key, value in safety_settings.items():
|
612
604
|
if isinstance(key, str):
|
613
605
|
key = str_to_harm_category(key)
|
@@ -623,23 +615,23 @@ def parse_safety_settings(
|
|
623
615
|
return parsed_settings
|
624
616
|
|
625
617
|
|
626
|
-
def str_to_harm_category(category: str) ->
|
618
|
+
def str_to_harm_category(category: str) -> int:
|
627
619
|
category = category.upper()
|
628
620
|
if "HARASSMENT" in category:
|
629
|
-
return HarmCategory.HARM_CATEGORY_HARASSMENT
|
621
|
+
return cast(int, HarmCategory.HARM_CATEGORY_HARASSMENT)
|
630
622
|
elif "HATE_SPEECH" in category:
|
631
|
-
return HarmCategory.HARM_CATEGORY_HATE_SPEECH
|
623
|
+
return cast(int, HarmCategory.HARM_CATEGORY_HATE_SPEECH)
|
632
624
|
elif "SEXUALLY_EXPLICIT" in category:
|
633
|
-
return HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT
|
625
|
+
return cast(int, HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT)
|
634
626
|
elif "DANGEROUS_CONTENT" in category:
|
635
|
-
return HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT
|
627
|
+
return cast(int, HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT)
|
636
628
|
else:
|
637
629
|
# NOTE: Although there is an "UNSPECIFIED" category, in the
|
638
630
|
# documentation, the API does not accept it.
|
639
631
|
raise ValueError(f"Unknown HarmCategory: {category}")
|
640
632
|
|
641
633
|
|
642
|
-
def str_to_harm_block_threshold(threshold: str) ->
|
634
|
+
def str_to_harm_block_threshold(threshold: str) -> int:
|
643
635
|
threshold = threshold.upper()
|
644
636
|
if "LOW" in threshold:
|
645
637
|
return HarmBlockThreshold.BLOCK_LOW_AND_ABOVE
|
@@ -673,7 +665,7 @@ async def file_for_content(content: ContentAudio | ContentVideo) -> File:
|
|
673
665
|
uploaded_file = files_db.get(content_sha256)
|
674
666
|
if uploaded_file:
|
675
667
|
try:
|
676
|
-
upload =
|
668
|
+
upload = get_file(uploaded_file)
|
677
669
|
if upload.state.name == "ACTIVE":
|
678
670
|
trace(f"Using uploaded file: {uploaded_file}")
|
679
671
|
return upload
|
@@ -40,6 +40,7 @@ from typing_extensions import override
|
|
40
40
|
# https://github.com/mistralai/client-python/blob/main/MIGRATION.md
|
41
41
|
from inspect_ai._util.constants import (
|
42
42
|
DEFAULT_TIMEOUT,
|
43
|
+
NO_CONTENT,
|
43
44
|
)
|
44
45
|
from inspect_ai._util.content import Content, ContentImage, ContentText
|
45
46
|
from inspect_ai._util.images import file_as_data_uri
|
@@ -122,7 +123,7 @@ class MistralAPI(ModelAPI):
|
|
122
123
|
tools: list[ToolInfo],
|
123
124
|
tool_choice: ToolChoice,
|
124
125
|
config: GenerateConfig,
|
125
|
-
) -> ModelOutput | tuple[ModelOutput, ModelCall]:
|
126
|
+
) -> ModelOutput | tuple[ModelOutput | Exception, ModelCall]:
|
126
127
|
# build request
|
127
128
|
request: dict[str, Any] = dict(
|
128
129
|
model=self.model_name,
|
@@ -146,7 +147,7 @@ class MistralAPI(ModelAPI):
|
|
146
147
|
response = await self.client.chat.complete_async(**request)
|
147
148
|
except SDKError as ex:
|
148
149
|
if ex.status_code == 400:
|
149
|
-
return self.handle_bad_request(ex)
|
150
|
+
return self.handle_bad_request(ex), mistral_model_call(request, None)
|
150
151
|
else:
|
151
152
|
raise ex
|
152
153
|
|
@@ -181,25 +182,27 @@ class MistralAPI(ModelAPI):
|
|
181
182
|
def connection_key(self) -> str:
|
182
183
|
return str(self.api_key)
|
183
184
|
|
184
|
-
def handle_bad_request(self, ex: SDKError) -> ModelOutput:
|
185
|
+
def handle_bad_request(self, ex: SDKError) -> ModelOutput | Exception:
|
186
|
+
body = json.loads(ex.body)
|
187
|
+
content = body.get("message", ex.body)
|
185
188
|
if "maximum context length" in ex.body:
|
186
|
-
body = json.loads(ex.body)
|
187
|
-
content = body.get("message", ex.body)
|
188
189
|
return ModelOutput.from_content(
|
189
190
|
model=self.model_name, content=content, stop_reason="model_length"
|
190
191
|
)
|
191
192
|
else:
|
192
|
-
|
193
|
+
return ex
|
193
194
|
|
194
195
|
|
195
196
|
def mistral_model_call(
|
196
|
-
request: dict[str, Any], response: MistralChatCompletionResponse
|
197
|
+
request: dict[str, Any], response: MistralChatCompletionResponse | None
|
197
198
|
) -> ModelCall:
|
198
199
|
request = request.copy()
|
199
200
|
request.update(messages=[message.model_dump() for message in request["messages"]])
|
200
201
|
if request.get("tools", None) is not None:
|
201
202
|
request["tools"] = [tool.model_dump() for tool in request["tools"]]
|
202
|
-
return ModelCall(
|
203
|
+
return ModelCall(
|
204
|
+
request=request, response=response.model_dump() if response else {}
|
205
|
+
)
|
203
206
|
|
204
207
|
|
205
208
|
def mistral_chat_tools(tools: list[ToolInfo]) -> list[MistralTool]:
|
@@ -326,9 +329,6 @@ async def mistral_chat_message(
|
|
326
329
|
)
|
327
330
|
|
328
331
|
|
329
|
-
NO_CONTENT = "(no content)"
|
330
|
-
|
331
|
-
|
332
332
|
async def mistral_message_content(
|
333
333
|
content: str | list[Content],
|
334
334
|
) -> str | list[ContentChunk]:
|
@@ -166,7 +166,7 @@ class OpenAIAPI(ModelAPI):
|
|
166
166
|
tools: list[ToolInfo],
|
167
167
|
tool_choice: ToolChoice,
|
168
168
|
config: GenerateConfig,
|
169
|
-
) -> ModelOutput | tuple[ModelOutput, ModelCall]:
|
169
|
+
) -> ModelOutput | tuple[ModelOutput | Exception, ModelCall]:
|
170
170
|
# short-circuit to call o1- models that are text only
|
171
171
|
if self.is_o1_preview() or self.is_o1_mini():
|
172
172
|
return await generate_o1(
|
@@ -307,27 +307,26 @@ class OpenAIAPI(ModelAPI):
|
|
307
307
|
return params
|
308
308
|
|
309
309
|
# convert some well known bad request errors into ModelOutput
|
310
|
-
def handle_bad_request(self, e: BadRequestError) -> ModelOutput:
|
311
|
-
|
312
|
-
|
313
|
-
|
314
|
-
|
315
|
-
|
316
|
-
content = e.message
|
310
|
+
def handle_bad_request(self, e: BadRequestError) -> ModelOutput | Exception:
|
311
|
+
# extract message
|
312
|
+
if isinstance(e.body, dict) and "message" in e.body.keys():
|
313
|
+
content = str(e.body.get("message"))
|
314
|
+
else:
|
315
|
+
content = e.message
|
317
316
|
|
318
|
-
|
319
|
-
|
320
|
-
|
321
|
-
|
322
|
-
|
323
|
-
|
324
|
-
stop_reason = "unknown"
|
317
|
+
# narrow stop_reason
|
318
|
+
stop_reason: StopReason | None = None
|
319
|
+
if e.code == "context_length_exceeded":
|
320
|
+
stop_reason = "model_length"
|
321
|
+
elif e.code == "invalid_prompt":
|
322
|
+
stop_reason = "content_filter"
|
325
323
|
|
324
|
+
if stop_reason:
|
326
325
|
return ModelOutput.from_content(
|
327
326
|
model=self.model_name, content=content, stop_reason=stop_reason
|
328
327
|
)
|
329
328
|
else:
|
330
|
-
|
329
|
+
return e
|
331
330
|
|
332
331
|
|
333
332
|
async def as_openai_chat_messages(
|
@@ -44,7 +44,7 @@ async def generate_o1(
|
|
44
44
|
input: list[ChatMessage],
|
45
45
|
tools: list[ToolInfo],
|
46
46
|
**params: Any,
|
47
|
-
) -> ModelOutput | tuple[ModelOutput, ModelCall]:
|
47
|
+
) -> ModelOutput | tuple[ModelOutput | Exception, ModelCall]:
|
48
48
|
# create chatapi handler
|
49
49
|
handler = O1PreviewChatAPIHandler()
|
50
50
|
|
@@ -82,17 +82,18 @@ async def generate_o1(
|
|
82
82
|
), model_call()
|
83
83
|
|
84
84
|
|
85
|
-
def handle_bad_request(model: str, ex: BadRequestError) -> ModelOutput:
|
85
|
+
def handle_bad_request(model: str, ex: BadRequestError) -> ModelOutput | Exception:
|
86
86
|
if ex.code == "context_length_exceeded":
|
87
|
-
stop_reason: StopReason = "model_length"
|
87
|
+
stop_reason: StopReason | None = "model_length"
|
88
88
|
elif ex.code == "invalid_prompt":
|
89
89
|
stop_reason = "content_filter"
|
90
|
-
else:
|
91
|
-
stop_reason = "unknown"
|
92
90
|
|
93
|
-
|
94
|
-
|
95
|
-
|
91
|
+
if stop_reason:
|
92
|
+
return ModelOutput.from_content(
|
93
|
+
model=model, content=str(ex), stop_reason=stop_reason
|
94
|
+
)
|
95
|
+
else:
|
96
|
+
return ex
|
96
97
|
|
97
98
|
|
98
99
|
def chat_messages(
|
@@ -94,7 +94,7 @@ def vertex() -> type[ModelAPI]:
|
|
94
94
|
def google() -> type[ModelAPI]:
|
95
95
|
FEATURE = "Google API"
|
96
96
|
PACKAGE = "google-generativeai"
|
97
|
-
MIN_VERSION = "0.8.
|
97
|
+
MIN_VERSION = "0.8.4"
|
98
98
|
|
99
99
|
# workaround log spam
|
100
100
|
# https://github.com/ray-project/ray/issues/24917
|
@@ -103,18 +103,18 @@ class TogetherAIAPI(OpenAIAPI):
|
|
103
103
|
return DEFAULT_MAX_TOKENS
|
104
104
|
|
105
105
|
@override
|
106
|
-
def handle_bad_request(self, ex: BadRequestError) -> ModelOutput:
|
107
|
-
|
108
|
-
|
109
|
-
|
110
|
-
|
111
|
-
|
112
|
-
|
106
|
+
def handle_bad_request(self, ex: BadRequestError) -> ModelOutput | Exception:
|
107
|
+
response = ex.response.json()
|
108
|
+
if "error" in response and "message" in response.get("error"):
|
109
|
+
content = response.get("error").get("message")
|
110
|
+
else:
|
111
|
+
content = str(response)
|
112
|
+
if "max_new_tokens" in ex.message:
|
113
113
|
return ModelOutput.from_content(
|
114
114
|
model=self.model_name, content=content, stop_reason="model_length"
|
115
115
|
)
|
116
116
|
else:
|
117
|
-
|
117
|
+
return ex
|
118
118
|
|
119
119
|
# Together has a slightly different logprobs structure to OpenAI, so we need to remap it.
|
120
120
|
def _chat_choices_from_response(
|
@@ -23,7 +23,7 @@ from vertexai.generative_models import ( # type: ignore
|
|
23
23
|
)
|
24
24
|
from vertexai.generative_models import Content as VertexContent
|
25
25
|
|
26
|
-
from inspect_ai._util.constants import BASE_64_DATA_REMOVED
|
26
|
+
from inspect_ai._util.constants import BASE_64_DATA_REMOVED, NO_CONTENT
|
27
27
|
from inspect_ai._util.content import (
|
28
28
|
Content,
|
29
29
|
ContentAudio,
|
@@ -250,9 +250,6 @@ def consective_tool_message_reducer(
|
|
250
250
|
return messages
|
251
251
|
|
252
252
|
|
253
|
-
NO_CONTENT = "(no content)"
|
254
|
-
|
255
|
-
|
256
253
|
async def content_dict(
|
257
254
|
message: ChatMessageUser | ChatMessageAssistant | ChatMessageTool,
|
258
255
|
) -> VertexContent:
|
inspect_ai/scorer/_scorer.py
CHANGED
@@ -151,8 +151,8 @@ def scorer_metrics(
|
|
151
151
|
return cast(list[Metric | dict[str, list[Metric]]], metrics_raw)
|
152
152
|
|
153
153
|
|
154
|
-
def unique_scorer_name(scorer: Scorer, already_used_names: list[str]) -> str:
|
155
|
-
base_name = registry_unqualified_name(scorer)
|
154
|
+
def unique_scorer_name(scorer: Scorer | str, already_used_names: list[str]) -> str:
|
155
|
+
base_name = scorer if isinstance(scorer, str) else registry_unqualified_name(scorer)
|
156
156
|
scorer_name = base_name
|
157
157
|
count = 1
|
158
158
|
while scorer_name in already_used_names:
|
inspect_ai/solver/__init__.py
CHANGED
@@ -7,11 +7,7 @@ from ._fork import fork
|
|
7
7
|
from ._human_agent.agent import human_agent
|
8
8
|
from ._multiple_choice import MultipleChoiceTemplate, multiple_choice
|
9
9
|
from ._plan import Plan, plan
|
10
|
-
from ._prompt import
|
11
|
-
chain_of_thought,
|
12
|
-
prompt_template,
|
13
|
-
system_message,
|
14
|
-
)
|
10
|
+
from ._prompt import chain_of_thought, prompt_template, system_message, user_message
|
15
11
|
from ._solver import Generate, Solver, SolverSpec, generate, solver
|
16
12
|
from ._task_state import Choice, Choices, TaskState
|
17
13
|
from ._use_tools import use_tools
|
@@ -26,6 +22,7 @@ __all__ = [
|
|
26
22
|
"chain_of_thought",
|
27
23
|
"multiple_choice",
|
28
24
|
"system_message",
|
25
|
+
"user_message",
|
29
26
|
"self_critique",
|
30
27
|
"use_tools",
|
31
28
|
"plan",
|
inspect_ai/solver/_prompt.py
CHANGED
@@ -2,6 +2,7 @@ from typing import Any
|
|
2
2
|
|
3
3
|
from inspect_ai._util.dict import omit
|
4
4
|
from inspect_ai.model import ChatMessageSystem
|
5
|
+
from inspect_ai.model._chat_message import ChatMessageUser
|
5
6
|
from inspect_ai.util import resource
|
6
7
|
|
7
8
|
from ._solver import Generate, Solver, solver
|
@@ -15,7 +16,8 @@ def prompt_template(template: str, **params: Any) -> Solver:
|
|
15
16
|
|
16
17
|
Prompt template containing a `{prompt}` placeholder and any
|
17
18
|
number of additional `params`. All values contained in sample
|
18
|
-
`metadata` are also automatically included in the
|
19
|
+
`metadata` and `store` are also automatically included in the
|
20
|
+
`params`.
|
19
21
|
|
20
22
|
Args:
|
21
23
|
template: (str): Template for prompt.
|
@@ -29,7 +31,7 @@ def prompt_template(template: str, **params: Any) -> Solver:
|
|
29
31
|
|
30
32
|
async def solve(state: TaskState, generate: Generate) -> TaskState:
|
31
33
|
prompt = state.user_prompt
|
32
|
-
kwargs = omit(state.metadata, ["prompt"]) | params
|
34
|
+
kwargs = omit(state.metadata | state.store._data, ["prompt"]) | params
|
33
35
|
prompt.text = prompt_template.format(prompt=prompt.text, **kwargs)
|
34
36
|
return state
|
35
37
|
|
@@ -41,8 +43,9 @@ def system_message(template: str, **params: Any) -> Solver:
|
|
41
43
|
"""Solver which inserts a system message into the conversation.
|
42
44
|
|
43
45
|
System message template containing any number of optional `params`.
|
44
|
-
for substitution
|
45
|
-
|
46
|
+
for substitution using the `str.format()` method. All values
|
47
|
+
contained in sample `metadata` and `store` are also automatically
|
48
|
+
included in the `params`.
|
46
49
|
|
47
50
|
The new message will go after other system messages (if there
|
48
51
|
are none it will be inserted at the beginning of the conversation).
|
@@ -58,7 +61,7 @@ def system_message(template: str, **params: Any) -> Solver:
|
|
58
61
|
content = resource(template)
|
59
62
|
|
60
63
|
async def solve(state: TaskState, generate: Generate) -> TaskState:
|
61
|
-
kwargs = state.metadata | params
|
64
|
+
kwargs = state.metadata | state.store._data | params
|
62
65
|
append_system_message(
|
63
66
|
state.messages, ChatMessageSystem(content=content.format(**kwargs))
|
64
67
|
)
|
@@ -67,6 +70,33 @@ def system_message(template: str, **params: Any) -> Solver:
|
|
67
70
|
return solve
|
68
71
|
|
69
72
|
|
73
|
+
@solver
|
74
|
+
def user_message(template: str, **params: Any) -> Solver:
|
75
|
+
"""Solver which inserts a user message into the conversation.
|
76
|
+
|
77
|
+
User message template containing any number of optional `params`.
|
78
|
+
for substitution using the `str.format()` method. All values
|
79
|
+
contained in sample `metadata` and `store` are also automatically
|
80
|
+
included in the `params`.
|
81
|
+
|
82
|
+
Args:
|
83
|
+
template (str): Template for user message.
|
84
|
+
**params (dict[str,Any]): Parameters to fill into the template.
|
85
|
+
|
86
|
+
Returns:
|
87
|
+
A solver that inserts the parameterised user message.
|
88
|
+
"""
|
89
|
+
# read template
|
90
|
+
content = resource(template)
|
91
|
+
|
92
|
+
async def solve(state: TaskState, generate: Generate) -> TaskState:
|
93
|
+
kwargs = state.metadata | state.store._data | params
|
94
|
+
state.messages.append(ChatMessageUser(content=content.format(**kwargs)))
|
95
|
+
return state
|
96
|
+
|
97
|
+
return solve
|
98
|
+
|
99
|
+
|
70
100
|
DEFAULT_COT_TEMPLATE = r"""
|
71
101
|
{prompt}
|
72
102
|
|