inspect-ai 0.3.57__py3-none-any.whl → 0.3.59__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- inspect_ai/__init__.py +2 -1
- inspect_ai/_cli/common.py +7 -3
- inspect_ai/_cli/eval.py +17 -2
- inspect_ai/_cli/trace.py +21 -2
- inspect_ai/_display/core/active.py +4 -3
- inspect_ai/_display/core/config.py +3 -3
- inspect_ai/_display/core/panel.py +7 -3
- inspect_ai/_display/plain/__init__.py +0 -0
- inspect_ai/_display/plain/display.py +203 -0
- inspect_ai/_display/rich/display.py +4 -9
- inspect_ai/_display/textual/app.py +4 -1
- inspect_ai/_display/textual/widgets/port_mappings.py +110 -0
- inspect_ai/_display/textual/widgets/samples.py +119 -16
- inspect_ai/_display/textual/widgets/sandbox.py +37 -0
- inspect_ai/_eval/eval.py +32 -20
- inspect_ai/_eval/evalset.py +7 -5
- inspect_ai/_eval/score.py +1 -0
- inspect_ai/_eval/task/__init__.py +2 -2
- inspect_ai/_eval/task/images.py +40 -25
- inspect_ai/_eval/task/results.py +50 -22
- inspect_ai/_eval/task/run.py +180 -124
- inspect_ai/_eval/task/sandbox.py +10 -5
- inspect_ai/_eval/task/task.py +140 -25
- inspect_ai/_util/constants.py +2 -0
- inspect_ai/_util/content.py +23 -1
- inspect_ai/_util/images.py +20 -17
- inspect_ai/_util/kvstore.py +73 -0
- inspect_ai/_util/notgiven.py +18 -0
- inspect_ai/_util/port_names.py +61 -0
- inspect_ai/_util/text.py +23 -0
- inspect_ai/_util/thread.py +5 -0
- inspect_ai/_view/www/App.css +31 -1
- inspect_ai/_view/www/dist/assets/index.css +31 -1
- inspect_ai/_view/www/dist/assets/index.js +25375 -1846
- inspect_ai/_view/www/log-schema.json +129 -15
- inspect_ai/_view/www/package.json +2 -0
- inspect_ai/_view/www/src/App.mjs +8 -10
- inspect_ai/_view/www/src/Types.mjs +0 -1
- inspect_ai/_view/www/src/components/ChatView.mjs +133 -43
- inspect_ai/_view/www/src/components/ExpandablePanel.mjs +0 -4
- inspect_ai/_view/www/src/components/LargeModal.mjs +19 -20
- inspect_ai/_view/www/src/components/MessageBand.mjs +2 -2
- inspect_ai/_view/www/src/components/MessageContent.mjs +43 -1
- inspect_ai/_view/www/src/components/TabSet.mjs +3 -1
- inspect_ai/_view/www/src/components/VirtualList.mjs +266 -84
- inspect_ai/_view/www/src/index.js +75 -2
- inspect_ai/_view/www/src/navbar/Navbar.mjs +3 -0
- inspect_ai/_view/www/src/navbar/SecondaryBar.mjs +18 -9
- inspect_ai/_view/www/src/samples/SampleDialog.mjs +5 -1
- inspect_ai/_view/www/src/samples/SampleDisplay.mjs +23 -15
- inspect_ai/_view/www/src/samples/SampleList.mjs +18 -48
- inspect_ai/_view/www/src/samples/SampleTranscript.mjs +8 -3
- inspect_ai/_view/www/src/samples/SamplesDescriptor.mjs +29 -13
- inspect_ai/_view/www/src/samples/SamplesTab.mjs +4 -1
- inspect_ai/_view/www/src/samples/SamplesTools.mjs +8 -8
- inspect_ai/_view/www/src/samples/tools/SampleFilter.mjs +712 -89
- inspect_ai/_view/www/src/samples/tools/filters.mjs +260 -87
- inspect_ai/_view/www/src/samples/transcript/ErrorEventView.mjs +24 -2
- inspect_ai/_view/www/src/samples/transcript/EventPanel.mjs +29 -24
- inspect_ai/_view/www/src/samples/transcript/EventRow.mjs +1 -1
- inspect_ai/_view/www/src/samples/transcript/InfoEventView.mjs +24 -2
- inspect_ai/_view/www/src/samples/transcript/InputEventView.mjs +24 -2
- inspect_ai/_view/www/src/samples/transcript/ModelEventView.mjs +31 -10
- inspect_ai/_view/www/src/samples/transcript/SampleInitEventView.mjs +24 -2
- inspect_ai/_view/www/src/samples/transcript/SampleLimitEventView.mjs +23 -2
- inspect_ai/_view/www/src/samples/transcript/ScoreEventView.mjs +24 -2
- inspect_ai/_view/www/src/samples/transcript/StepEventView.mjs +33 -3
- inspect_ai/_view/www/src/samples/transcript/SubtaskEventView.mjs +25 -2
- inspect_ai/_view/www/src/samples/transcript/ToolEventView.mjs +25 -2
- inspect_ai/_view/www/src/samples/transcript/TranscriptView.mjs +193 -11
- inspect_ai/_view/www/src/samples/transcript/Types.mjs +10 -0
- inspect_ai/_view/www/src/samples/transcript/state/StateEventView.mjs +26 -2
- inspect_ai/_view/www/src/types/log.d.ts +62 -27
- inspect_ai/_view/www/src/utils/Format.mjs +10 -3
- inspect_ai/_view/www/src/utils/Json.mjs +12 -6
- inspect_ai/_view/www/src/workspace/WorkSpace.mjs +10 -4
- inspect_ai/_view/www/vite.config.js +7 -0
- inspect_ai/_view/www/yarn.lock +116 -0
- inspect_ai/approval/_human/__init__.py +0 -0
- inspect_ai/approval/_human/util.py +2 -2
- inspect_ai/approval/_policy.py +12 -6
- inspect_ai/dataset/_sources/csv.py +2 -1
- inspect_ai/dataset/_sources/json.py +2 -1
- inspect_ai/dataset/_sources/util.py +15 -7
- inspect_ai/log/_condense.py +11 -1
- inspect_ai/log/_log.py +3 -6
- inspect_ai/log/_recorders/eval.py +19 -8
- inspect_ai/log/_samples.py +26 -5
- inspect_ai/log/_transcript.py +32 -2
- inspect_ai/model/__init__.py +10 -2
- inspect_ai/model/_call_tools.py +59 -12
- inspect_ai/model/_chat_message.py +2 -4
- inspect_ai/model/_conversation.py +61 -0
- inspect_ai/model/_generate_config.py +10 -4
- inspect_ai/model/_model.py +117 -18
- inspect_ai/model/_model_output.py +7 -2
- inspect_ai/model/_providers/anthropic.py +109 -51
- inspect_ai/model/_providers/azureai.py +26 -24
- inspect_ai/model/_providers/bedrock.py +43 -44
- inspect_ai/model/_providers/google.py +121 -58
- inspect_ai/model/_providers/groq.py +7 -5
- inspect_ai/model/_providers/hf.py +11 -6
- inspect_ai/model/_providers/mistral.py +17 -20
- inspect_ai/model/_providers/openai.py +32 -21
- inspect_ai/model/_providers/openai_o1.py +9 -8
- inspect_ai/model/_providers/providers.py +1 -1
- inspect_ai/model/_providers/together.py +8 -8
- inspect_ai/model/_providers/vertex.py +18 -8
- inspect_ai/scorer/__init__.py +13 -2
- inspect_ai/scorer/_metrics/__init__.py +2 -2
- inspect_ai/scorer/_metrics/std.py +3 -3
- inspect_ai/scorer/_reducer/reducer.py +1 -1
- inspect_ai/scorer/_scorer.py +2 -2
- inspect_ai/solver/__init__.py +2 -5
- inspect_ai/solver/_prompt.py +35 -5
- inspect_ai/solver/_task_state.py +80 -38
- inspect_ai/tool/__init__.py +11 -1
- inspect_ai/tool/_tool.py +21 -3
- inspect_ai/tool/_tool_call.py +10 -0
- inspect_ai/tool/_tool_def.py +16 -5
- inspect_ai/tool/_tool_with.py +21 -4
- inspect_ai/tool/beta/__init__.py +5 -0
- inspect_ai/tool/beta/_computer/__init__.py +3 -0
- inspect_ai/tool/beta/_computer/_common.py +133 -0
- inspect_ai/tool/beta/_computer/_computer.py +155 -0
- inspect_ai/tool/beta/_computer/_computer_split.py +198 -0
- inspect_ai/tool/beta/_computer/_resources/Dockerfile +100 -0
- inspect_ai/tool/beta/_computer/_resources/README.md +30 -0
- inspect_ai/tool/beta/_computer/_resources/entrypoint/entrypoint.sh +18 -0
- inspect_ai/tool/beta/_computer/_resources/entrypoint/novnc_startup.sh +20 -0
- inspect_ai/tool/beta/_computer/_resources/entrypoint/x11vnc_startup.sh +48 -0
- inspect_ai/tool/beta/_computer/_resources/entrypoint/xfce_startup.sh +13 -0
- inspect_ai/tool/beta/_computer/_resources/entrypoint/xvfb_startup.sh +48 -0
- inspect_ai/tool/beta/_computer/_resources/image_home_dir/Desktop/Firefox Web Browser.desktop +10 -0
- inspect_ai/tool/beta/_computer/_resources/image_home_dir/Desktop/Visual Studio Code.desktop +10 -0
- inspect_ai/tool/beta/_computer/_resources/image_home_dir/Desktop/XPaint.desktop +10 -0
- inspect_ai/tool/beta/_computer/_resources/tool/__init__.py +0 -0
- inspect_ai/tool/beta/_computer/_resources/tool/_logger.py +22 -0
- inspect_ai/tool/beta/_computer/_resources/tool/_run.py +42 -0
- inspect_ai/tool/beta/_computer/_resources/tool/_tool_result.py +33 -0
- inspect_ai/tool/beta/_computer/_resources/tool/_x11_client.py +262 -0
- inspect_ai/tool/beta/_computer/_resources/tool/computer_tool.py +85 -0
- inspect_ai/tool/beta/_computer/_resources/tool/requirements.txt +0 -0
- inspect_ai/util/__init__.py +2 -3
- inspect_ai/util/{_trace.py → _conversation.py} +3 -17
- inspect_ai/util/_display.py +14 -4
- inspect_ai/util/_limit.py +26 -0
- inspect_ai/util/_sandbox/context.py +12 -13
- inspect_ai/util/_sandbox/docker/compose.py +24 -11
- inspect_ai/util/_sandbox/docker/docker.py +84 -14
- inspect_ai/util/_sandbox/docker/internal.py +3 -1
- inspect_ai/util/_sandbox/environment.py +27 -1
- inspect_ai/util/_sandbox/local.py +1 -0
- {inspect_ai-0.3.57.dist-info → inspect_ai-0.3.59.dist-info}/METADATA +2 -2
- {inspect_ai-0.3.57.dist-info → inspect_ai-0.3.59.dist-info}/RECORD +159 -128
- inspect_ai/_view/www/src/samples/transcript/TranscriptState.mjs +0 -70
- inspect_ai/model/_trace.py +0 -48
- {inspect_ai-0.3.57.dist-info → inspect_ai-0.3.59.dist-info}/LICENSE +0 -0
- {inspect_ai-0.3.57.dist-info → inspect_ai-0.3.59.dist-info}/WHEEL +0 -0
- {inspect_ai-0.3.57.dist-info → inspect_ai-0.3.59.dist-info}/entry_points.txt +0 -0
- {inspect_ai-0.3.57.dist-info → inspect_ai-0.3.59.dist-info}/top_level.txt +0 -0
```diff
--- a/inspect_ai/model/_providers/google.py
+++ b/inspect_ai/model/_providers/google.py
@@ -1,6 +1,10 @@
+import asyncio
 import functools
+import hashlib
 import json
 from copy import copy
+from io import BytesIO
+from logging import getLogger
 from typing import Any, cast
 
 import proto  # type: ignore
@@ -24,29 +28,39 @@ from google.api_core.exceptions import (
     TooManyRequests,
 )
 from google.api_core.retry.retry_base import if_transient_error
-from google.generativeai import (
-    GenerationConfig,
-    GenerativeModel,
-    configure,
-)
-from google.generativeai.types import (  # type: ignore
-    AsyncGenerateContentResponse,
+from google.generativeai.client import configure
+from google.generativeai.files import get_file, upload_file
+from google.generativeai.generative_models import GenerativeModel
+from google.generativeai.types import (
     ContentDict,
-
-    HarmCategory,
+    GenerationConfig,
     PartDict,
     PartType,
-    SafetySettingDict,
     Tool,
 )
+from google.generativeai.types.file_types import File
+from google.generativeai.types.generation_types import AsyncGenerateContentResponse
+from google.generativeai.types.safety_types import (
+    EasySafetySettingDict,
+    HarmBlockThreshold,
+    HarmCategory,
+)
 from google.protobuf.json_format import MessageToDict, ParseDict
 from google.protobuf.struct_pb2 import Struct
 from pydantic import JsonValue
 from typing_extensions import override
 
-from inspect_ai._util.constants import BASE_64_DATA_REMOVED
-from inspect_ai._util.content import
-
+from inspect_ai._util.constants import BASE_64_DATA_REMOVED, NO_CONTENT
+from inspect_ai._util.content import (
+    Content,
+    ContentAudio,
+    ContentImage,
+    ContentText,
+    ContentVideo,
+)
+from inspect_ai._util.images import file_as_data
+from inspect_ai._util.kvstore import inspect_kvstore
+from inspect_ai._util.trace import trace_message
 from inspect_ai.tool import ToolCall, ToolChoice, ToolInfo, ToolParam, ToolParams
 
 from .._chat_message import (
@@ -70,9 +84,11 @@ from .._model_output import (
 )
 from .util import model_base_url
 
+logger = getLogger(__name__)
+
 SAFETY_SETTINGS = "safety_settings"
 
-DEFAULT_SAFETY_SETTINGS: SafetySettingDict = {
+DEFAULT_SAFETY_SETTINGS: EasySafetySettingDict = {
     HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
     HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
     HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
@@ -124,7 +140,7 @@ class GoogleAPI(ModelAPI):
         tools: list[ToolInfo],
         tool_choice: ToolChoice,
         config: GenerateConfig,
-    ) -> ModelOutput | tuple[ModelOutput, ModelCall]:
+    ) -> ModelOutput | tuple[ModelOutput | Exception, ModelCall]:
         parameters = GenerationConfig(
             temperature=config.temperature,
             top_p=config.top_p,
@@ -132,11 +148,8 @@ class GoogleAPI(ModelAPI):
             max_output_tokens=config.max_tokens,
             stop_sequences=config.stop_seqs,
             candidate_count=config.num_choices,
-            seed=config.seed,
             presence_penalty=config.presence_penalty,
             frequency_penalty=config.frequency_penalty,
-            response_logprobs=config.logprobs,
-            logprobs=config.top_logprobs,
         )
 
         # google-native messages
@@ -159,18 +172,15 @@ class GoogleAPI(ModelAPI):
             response=response,
         )
 
-        # cast to AsyncGenerateContentResponse since we passed stream=False
         try:
-            response = cast(
-                AsyncGenerateContentResponse,
-                await self.model.generate_content_async(
-                    contents=contents,
-                    safety_settings=self.safety_settings,
-                    generation_config=parameters,
-                    tools=gemini_tools,
-                    tool_config=gemini_tool_config,
-                ),
+            response = await self.model.generate_content_async(
+                contents=contents,
+                safety_settings=self.safety_settings,
+                generation_config=parameters,
+                tools=gemini_tools,
+                tool_config=gemini_tool_config,
             )
+
         except InvalidArgument as ex:
             return self.handle_invalid_argument(ex), model_call()
 
@@ -188,15 +198,13 @@ class GoogleAPI(ModelAPI):
         # return
         return output, model_call()
 
-    def handle_invalid_argument(self, ex: InvalidArgument) -> ModelOutput:
+    def handle_invalid_argument(self, ex: InvalidArgument) -> ModelOutput | Exception:
        if "size exceeds the limit" in ex.message.lower():
             return ModelOutput.from_content(
                 model=self.model_name, content=ex.message, stop_reason="model_length"
             )
         else:
-            return ModelOutput.from_content(
-                model=self.model_name, content=ex.message, stop_reason="unknown"
-            )
+            return ex
 
     @override
     def is_rate_limit(self, ex: BaseException) -> bool:
@@ -214,7 +222,7 @@ class GoogleAPI(ModelAPI):
 def build_model_call(
     contents: list[ContentDict],
     generation_config: GenerationConfig,
-    safety_settings: SafetySettingDict,
+    safety_settings: EasySafetySettingDict,
     tools: list[Tool] | None,
     tool_config: ToolConfig | None,
     response: AsyncGenerateContentResponse | None,
@@ -231,7 +239,7 @@ def build_model_call(
             if tool_config is not None
             else None,
         ),
-        response=response.to_dict() if response is not None else {},
+        response=response.to_dict() if response is not None else {},  # type: ignore[no-untyped-call]
         filter=model_call_filter,
     )
 
@@ -252,12 +260,12 @@ def model_call_content(content: ContentDict) -> ContentDict:
 
 def model_call_part(part: PartType) -> PartType:
     if isinstance(part, proto.Message):
-        return MessageToDict(part._pb)
+        return cast(PartDict, MessageToDict(part._pb))
     elif isinstance(part, dict):
         part = part.copy()
         keys = list(part.keys())
         for key in keys:
-            part[key] = model_call_part(part[key])
+            part[key] = model_call_part(part[key])  # type: ignore[literal-required]
         return part
     else:
         return part
@@ -299,9 +307,6 @@ def consective_tool_message_reducer(
     return messages
 
 
-NO_CONTENT = "(no content)"
-
-
 async def content_dict(
     message: ChatMessageUser | ChatMessageAssistant | ChatMessageTool,
 ) -> ContentDict:
@@ -309,13 +314,13 @@ async def content_dict(
         return ContentDict(
             role="user",
             parts=(
-                [
+                [message.content or NO_CONTENT]
                 if isinstance(message.content, str)
                 else [await content_part(content) for content in message.content]
             ),
         )
     elif isinstance(message, ChatMessageAssistant):
-        content_parts: list[
+        content_parts: list[PartType] = []
         # tool call parts
         if message.tool_calls is not None:
             content_parts.extend(
@@ -364,26 +369,32 @@ def dict_to_struct(x: dict[str, Any]) -> Struct:
     return struct
 
 
-async def content_part(content: Content | str) ->
+async def content_part(content: Content | str) -> PartType:
     if isinstance(content, str):
-        return
+        return content or NO_CONTENT
     elif isinstance(content, ContentText):
-        return
+        return content.text or NO_CONTENT
     else:
-        return
+        return await chat_content_to_part(content)
 
 
-async def
-
-
-
+async def chat_content_to_part(
+    content: ContentImage | ContentAudio | ContentVideo,
+) -> PartType:
+    if isinstance(content, ContentImage):
+        content_bytes, mime_type = await file_as_data(content.image)
+        return Blob(mime_type=mime_type, data=content_bytes)
+    else:
+        return await file_for_content(content)
 
 
 def prepend_system_messages(
     messages: list[ContentDict], system_messages: list[ChatMessageSystem]
 ) -> None:
     # create system_parts
-    system_parts = [
+    system_parts: list[PartType] = [
+        Part(text=message.content) for message in system_messages
+    ]
 
     # we want the system messages to be prepended to the first user message
     # (if there is no first user message then prepend one)
@@ -455,6 +466,8 @@ def schema_from_param(param: ToolParam | ToolParams, nullable: bool = False) -> Schema:
             return schema_from_param(param.anyOf[0], nullable=True)
         else:
             return Schema(type=Type.TYPE_UNSPECIFIED)
+    elif param.enum:
+        return Schema(type=Type.STRING, format="enum", enum=param.enum)
     else:
         return Schema(type=Type.TYPE_UNSPECIFIED)
 
@@ -579,14 +592,14 @@ def gapi_should_retry(ex: BaseException) -> bool:
 
 def parse_safety_settings(
     safety_settings: Any,
-) -> SafetySettingDict:
+) -> EasySafetySettingDict:
     # ensure we have a dict
     if isinstance(safety_settings, str):
         safety_settings = json.loads(safety_settings)
     if not isinstance(safety_settings, dict):
         raise ValueError(f"{SAFETY_SETTINGS} must be dictionary.")
 
-    parsed_settings: SafetySettingDict = {}
+    parsed_settings: EasySafetySettingDict = {}
     for key, value in safety_settings.items():
         if isinstance(key, str):
             key = str_to_harm_category(key)
@@ -602,23 +615,23 @@ def parse_safety_settings(
     return parsed_settings
 
 
-def str_to_harm_category(category: str) -> HarmCategory:
+def str_to_harm_category(category: str) -> int:
     category = category.upper()
     if "HARASSMENT" in category:
-        return HarmCategory.HARM_CATEGORY_HARASSMENT
+        return cast(int, HarmCategory.HARM_CATEGORY_HARASSMENT)
     elif "HATE_SPEECH" in category:
-        return HarmCategory.HARM_CATEGORY_HATE_SPEECH
+        return cast(int, HarmCategory.HARM_CATEGORY_HATE_SPEECH)
     elif "SEXUALLY_EXPLICIT" in category:
-        return HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT
+        return cast(int, HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT)
     elif "DANGEROUS_CONTENT" in category:
-        return HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT
+        return cast(int, HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT)
     else:
         # NOTE: Although there is an "UNSPECIFIED" category, in the
         # documentation, the API does not accept it.
         raise ValueError(f"Unknown HarmCategory: {category}")
 
 
-def str_to_harm_block_threshold(threshold: str) -> HarmBlockThreshold:
+def str_to_harm_block_threshold(threshold: str) -> int:
     threshold = threshold.upper()
     if "LOW" in threshold:
         return HarmBlockThreshold.BLOCK_LOW_AND_ABOVE
@@ -630,3 +643,53 @@ def str_to_harm_block_threshold(threshold: str) -> HarmBlockThreshold:
         return HarmBlockThreshold.BLOCK_NONE
     else:
         raise ValueError(f"Unknown HarmBlockThreshold: {threshold}")
+
+
+async def file_for_content(content: ContentAudio | ContentVideo) -> File:
+    # helper to write trace messages
+    def trace(message: str) -> None:
+        trace_message(logger, "Google Files", message)
+
+    # get the file bytes and compute sha256 hash
+    if isinstance(content, ContentAudio):
+        file = content.audio
+    else:
+        file = content.video
+    content_bytes, mime_type = await file_as_data(file)
+    content_sha256 = hashlib.sha256(content_bytes).hexdigest()
+
+    # we cache uploads for re-use, open the db where we track that
+    # (track up to 1 million previous uploads)
+    with inspect_kvstore("google_files", 1000000) as files_db:
+        # can we serve from existing uploads?
+        uploaded_file = files_db.get(content_sha256)
+        if uploaded_file:
+            try:
+                upload = get_file(uploaded_file)
+                if upload.state.name == "ACTIVE":
+                    trace(f"Using uploaded file: {uploaded_file}")
+                    return upload
+                else:
+                    trace(
+                        f"Not using uploaded file '{uploaded_file} (state was {upload.state})"
+                    )
+            except Exception as ex:
+                trace(f"Error attempting to access uploaded file: {ex}")
+                files_db.delete(content_sha256)
+
+        # do the upload (and record it)
+        upload = upload_file(BytesIO(content_bytes), mime_type=mime_type)
+        while upload.state.name == "PROCESSING":
+            await asyncio.sleep(3)
+            upload = get_file(upload.name)
+
+        if upload.state.name == "FAILED":
+            trace(f"Failed to upload file '{upload.name}: {upload.error}")
+            raise ValueError(f"Google file upload failed: {upload.error}")
+
+        # trace and record it
+        trace(f"Uploaded file: {upload.name}")
+        files_db.put(content_sha256, upload.name)
+
+        # return the file
+        return upload
```
```diff
--- a/inspect_ai/model/_providers/groq.py
+++ b/inspect_ai/model/_providers/groq.py
@@ -23,8 +23,8 @@ from typing_extensions import override
 
 from inspect_ai._util.constants import DEFAULT_MAX_RETRIES, DEFAULT_MAX_TOKENS
 from inspect_ai._util.content import Content
-from inspect_ai._util.images import
-from inspect_ai._util.url import
+from inspect_ai._util.images import file_as_data_uri
+from inspect_ai._util.url import is_http_url
 from inspect_ai.tool import ToolCall, ToolChoice, ToolFunction, ToolInfo
 
 from .._chat_message import (
@@ -248,18 +248,20 @@ async def as_chat_completion_part(
 ) -> ChatCompletionContentPartParam:
     if content.type == "text":
         return ChatCompletionContentPartTextParam(type="text", text=content.text)
-
+    elif content.type == "image":
         # API takes URL or base64 encoded file. If it's a remote file or data URL leave it alone, otherwise encode it
         image_url = content.image
         detail = content.detail
 
-        if not is_http_url(image_url)
-            image_url = await
+        if not is_http_url(image_url):
+            image_url = await file_as_data_uri(image_url)
 
         return ChatCompletionContentPartImageParam(
             type="image_url",
             image_url=dict(url=image_url, detail=detail),
         )
+    else:
+        raise RuntimeError("Groq models do not support audio or video inputs.")
 
 
 def chat_tools(tools: List[ToolInfo]) -> List[Dict[str, Any]]:
```
```diff
--- a/inspect_ai/model/_providers/hf.py
+++ b/inspect_ai/model/_providers/hf.py
@@ -239,12 +239,17 @@ class HuggingFaceAPI(ModelAPI):
         hf_messages = inspect_tools_to_string(hf_messages)
 
         # apply chat template
-        chat = self.tokenizer.apply_chat_template(
-            hf_messages,
-            add_generation_prompt=True,
-            tokenize=False,
-            tools=tools_list if len(tools_list) > 0 else None,
-        )
+        if self.tokenizer.chat_template is not None:
+            chat = self.tokenizer.apply_chat_template(
+                hf_messages,
+                add_generation_prompt=True,
+                tokenize=False,
+                tools=tools_list if len(tools_list) > 0 else None,
+            )
+        else:
+            chat = ""
+            for message in hf_messages:
+                chat += f"{message.role}: {message.content}\n"
         # return
         return cast(str, chat)
 
```
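The Hugging Face provider previously called `apply_chat_template` unconditionally; it now falls back to a plain `role: content` transcript when a tokenizer ships without a chat template. A rough standalone sketch of the same fallback (the `messages` list of dicts is a simplified stand-in for inspect_ai's internal message type):

```python
def render_prompt(tokenizer, messages: list[dict[str, str]]) -> str:
    """Prefer the tokenizer's chat template; degrade to a plain transcript."""
    if getattr(tokenizer, "chat_template", None) is not None:
        # tokenize=False returns the fully templated prompt as a string
        return tokenizer.apply_chat_template(
            messages, add_generation_prompt=True, tokenize=False
        )
    # no template available: simple "role: content" lines
    return "".join(f"{m['role']}: {m['content']}\n" for m in messages)
```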
```diff
--- a/inspect_ai/model/_providers/mistral.py
+++ b/inspect_ai/model/_providers/mistral.py
@@ -40,10 +40,10 @@ from typing_extensions import override
 # https://github.com/mistralai/client-python/blob/main/MIGRATION.md
 from inspect_ai._util.constants import (
     DEFAULT_TIMEOUT,
+    NO_CONTENT,
 )
 from inspect_ai._util.content import Content, ContentImage, ContentText
-from inspect_ai._util.images import image_as_data_uri
-from inspect_ai._util.url import is_data_uri
+from inspect_ai._util.images import file_as_data_uri
 from inspect_ai.tool import ToolCall, ToolChoice, ToolFunction, ToolInfo
 
 from .._chat_message import (
@@ -123,7 +123,7 @@ class MistralAPI(ModelAPI):
         tools: list[ToolInfo],
         tool_choice: ToolChoice,
         config: GenerateConfig,
-    ) -> ModelOutput | tuple[ModelOutput, ModelCall]:
+    ) -> ModelOutput | tuple[ModelOutput | Exception, ModelCall]:
         # build request
         request: dict[str, Any] = dict(
             model=self.model_name,
@@ -147,7 +147,7 @@ class MistralAPI(ModelAPI):
             response = await self.client.chat.complete_async(**request)
         except SDKError as ex:
             if ex.status_code == 400:
-                return self.handle_bad_request(ex)
+                return self.handle_bad_request(ex), mistral_model_call(request, None)
             else:
                 raise ex
 
@@ -182,25 +182,27 @@ class MistralAPI(ModelAPI):
     def connection_key(self) -> str:
         return str(self.api_key)
 
-    def handle_bad_request(self, ex: SDKError) -> ModelOutput:
+    def handle_bad_request(self, ex: SDKError) -> ModelOutput | Exception:
+        body = json.loads(ex.body)
+        content = body.get("message", ex.body)
         if "maximum context length" in ex.body:
-            body = json.loads(ex.body)
-            content = body.get("message", ex.body)
             return ModelOutput.from_content(
                 model=self.model_name, content=content, stop_reason="model_length"
             )
         else:
-
+            return ex
 
 
 def mistral_model_call(
-    request: dict[str, Any], response: MistralChatCompletionResponse
+    request: dict[str, Any], response: MistralChatCompletionResponse | None
 ) -> ModelCall:
     request = request.copy()
     request.update(messages=[message.model_dump() for message in request["messages"]])
     if request.get("tools", None) is not None:
         request["tools"] = [tool.model_dump() for tool in request["tools"]]
-    return ModelCall(
+    return ModelCall(
+        request=request, response=response.model_dump() if response else {}
+    )
 
 
 def mistral_chat_tools(tools: list[ToolInfo]) -> list[MistralTool]:
@@ -327,9 +329,6 @@ async def mistral_chat_message(
     )
 
 
-NO_CONTENT = "(no content)"
-
-
 async def mistral_message_content(
     content: str | list[Content],
 ) -> str | list[ContentChunk]:
@@ -351,16 +350,14 @@ def mistral_system_message_content(
 async def mistral_content_chunk(content: Content) -> ContentChunk:
     if isinstance(content, ContentText):
         return TextChunk(text=content.text or NO_CONTENT)
-
+    elif isinstance(content, ContentImage):
         # resolve image to url
-        image_url = content.image
-        if not is_data_uri(image_url):
-            image_url = await image_as_data_uri(image_url)
+        image_url = await file_as_data_uri(content.image)
 
         # return chunk
-        return ImageURLChunk(
-            image_url=ImageURL(url=image_url, detail=content.detail)
-        )
+        return ImageURLChunk(image_url=ImageURL(url=image_url, detail=content.detail))
+    else:
+        raise RuntimeError("Mistral models do not support audio or video inputs.")
 
 
 def mistral_tool_call(tool_call: ToolCall) -> MistralToolCall:
```
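mistral.py (like groq.py and openai.py above) drops the per-provider `is_data_uri` check and routes images through the shared `file_as_data_uri` helper. A hedged sketch of what a helper like that does for a local file, assuming path input and URL passthrough (an illustration only, not the inspect_ai implementation):

```python
import base64
import mimetypes


def file_as_data_uri(file: str) -> str:
    """Pass remote/data URLs through; base64-encode a local file as a data: URI."""
    if file.startswith(("http://", "https://", "data:")):
        return file
    mime, _ = mimetypes.guess_type(file)
    with open(file, "rb") as f:
        encoded = base64.b64encode(f.read()).decode("ascii")
    return f"data:{mime or 'application/octet-stream'};base64,{encoded}"
```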
```diff
--- a/inspect_ai/model/_providers/openai.py
+++ b/inspect_ai/model/_providers/openai.py
@@ -17,6 +17,7 @@ from openai.types.chat import (
     ChatCompletion,
     ChatCompletionAssistantMessageParam,
     ChatCompletionContentPartImageParam,
+    ChatCompletionContentPartInputAudioParam,
     ChatCompletionContentPartParam,
     ChatCompletionContentPartTextParam,
     ChatCompletionDeveloperMessageParam,
@@ -36,9 +37,9 @@ from typing_extensions import override
 from inspect_ai._util.constants import DEFAULT_MAX_RETRIES
 from inspect_ai._util.content import Content
 from inspect_ai._util.error import PrerequisiteError
-from inspect_ai._util.images import image_as_data_uri
+from inspect_ai._util.images import file_as_data_uri
 from inspect_ai._util.logger import warn_once
-from inspect_ai._util.url import
+from inspect_ai._util.url import is_http_url
 from inspect_ai.tool import ToolCall, ToolChoice, ToolFunction, ToolInfo
 
 from .._chat_message import ChatMessage, ChatMessageAssistant
@@ -165,7 +166,7 @@ class OpenAIAPI(ModelAPI):
         tools: list[ToolInfo],
         tool_choice: ToolChoice,
         config: GenerateConfig,
-    ) -> ModelOutput | tuple[ModelOutput, ModelCall]:
+    ) -> ModelOutput | tuple[ModelOutput | Exception, ModelCall]:
         # short-circuit to call o1- models that are text only
         if self.is_o1_preview() or self.is_o1_mini():
             return await generate_o1(
@@ -306,27 +307,26 @@ class OpenAIAPI(ModelAPI):
         return params
 
     # convert some well known bad request errors into ModelOutput
-    def handle_bad_request(self, e: BadRequestError) -> ModelOutput:
-
-
-
-
-
-        content = e.message
+    def handle_bad_request(self, e: BadRequestError) -> ModelOutput | Exception:
+        # extract message
+        if isinstance(e.body, dict) and "message" in e.body.keys():
+            content = str(e.body.get("message"))
+        else:
+            content = e.message
 
-
-
-
-
-
-
-        stop_reason = "unknown"
+        # narrow stop_reason
+        stop_reason: StopReason | None = None
+        if e.code == "context_length_exceeded":
+            stop_reason = "model_length"
+        elif e.code == "invalid_prompt":
+            stop_reason = "content_filter"
 
+        if stop_reason:
             return ModelOutput.from_content(
                 model=self.model_name, content=content, stop_reason=stop_reason
             )
         else:
-
+            return e
 
 
 async def as_openai_chat_messages(
@@ -463,16 +463,27 @@ async def as_chat_completion_part(
 ) -> ChatCompletionContentPartParam:
     if content.type == "text":
         return ChatCompletionContentPartTextParam(type="text", text=content.text)
-
+    elif content.type == "image":
         # API takes URL or base64 encoded file. If it's a remote file or
         # data URL leave it alone, otherwise encode it
         image_url = content.image
         detail = content.detail
 
-        if not is_http_url(image_url)
-            image_url = await
+        if not is_http_url(image_url):
+            image_url = await file_as_data_uri(image_url)
 
         return ChatCompletionContentPartImageParam(
             type="image_url",
             image_url=dict(url=image_url, detail=detail),
         )
+    elif content.type == "audio":
+        audio_data = await file_as_data_uri(content.audio)
+
+        return ChatCompletionContentPartInputAudioParam(
+            type="input_audio", input_audio=dict(data=audio_data, format=content.format)
+        )
+
+    else:
+        raise RuntimeError(
+            "Video content is not currently supported by Open AI chat models."
+        )
```
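With the new `input_audio` content part, a user message to the OpenAI chat completions API can mix text and audio. A sketch of the payload shape the provider now produces (the base64 clip is a hypothetical placeholder; the provider derives it from `ContentAudio` via `file_as_data_uri`):

```python
import base64

# hypothetical clip; in the provider this comes from the ContentAudio file
audio_b64 = base64.b64encode(b"...wav bytes...").decode("ascii")

message = {
    "role": "user",
    "content": [
        {"type": "text", "text": "Transcribe this clip."},
        {
            "type": "input_audio",
            "input_audio": {"data": audio_b64, "format": "wav"},  # "wav" or "mp3"
        },
    ],
}
```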
```diff
--- a/inspect_ai/model/_providers/openai_o1.py
+++ b/inspect_ai/model/_providers/openai_o1.py
@@ -44,7 +44,7 @@ async def generate_o1(
     input: list[ChatMessage],
     tools: list[ToolInfo],
     **params: Any,
-) -> ModelOutput | tuple[ModelOutput, ModelCall]:
+) -> ModelOutput | tuple[ModelOutput | Exception, ModelCall]:
     # create chatapi handler
     handler = O1PreviewChatAPIHandler()
 
@@ -82,17 +82,18 @@ async def generate_o1(
     ), model_call()
 
 
-def handle_bad_request(model: str, ex: BadRequestError) -> ModelOutput:
+def handle_bad_request(model: str, ex: BadRequestError) -> ModelOutput | Exception:
     if ex.code == "context_length_exceeded":
-        stop_reason: StopReason = "model_length"
+        stop_reason: StopReason | None = "model_length"
     elif ex.code == "invalid_prompt":
         stop_reason = "content_filter"
-    else:
-        stop_reason = "unknown"
 
-    return ModelOutput.from_content(
-        model=model, content=str(ex), stop_reason=stop_reason
-    )
+    if stop_reason:
+        return ModelOutput.from_content(
+            model=model, content=str(ex), stop_reason=stop_reason
+        )
+    else:
+        return ex
 
 
 def chat_messages(
```
```diff
--- a/inspect_ai/model/_providers/providers.py
+++ b/inspect_ai/model/_providers/providers.py
@@ -94,7 +94,7 @@ def vertex() -> type[ModelAPI]:
 def google() -> type[ModelAPI]:
     FEATURE = "Google API"
     PACKAGE = "google-generativeai"
-    MIN_VERSION = "0.8.
+    MIN_VERSION = "0.8.4"
 
     # workaround log spam
     # https://github.com/ray-project/ray/issues/24917
```
```diff
--- a/inspect_ai/model/_providers/together.py
+++ b/inspect_ai/model/_providers/together.py
@@ -103,18 +103,18 @@ class TogetherAIAPI(OpenAIAPI):
         return DEFAULT_MAX_TOKENS
 
     @override
-    def handle_bad_request(self, ex: BadRequestError) -> ModelOutput:
-
-
-
-
-
-
+    def handle_bad_request(self, ex: BadRequestError) -> ModelOutput | Exception:
+        response = ex.response.json()
+        if "error" in response and "message" in response.get("error"):
+            content = response.get("error").get("message")
+        else:
+            content = str(response)
+        if "max_new_tokens" in ex.message:
             return ModelOutput.from_content(
                 model=self.model_name, content=content, stop_reason="model_length"
             )
         else:
-
+            return ex
 
     # Together has a slightly different logprobs structure to OpenAI, so we need to remap it.
     def _chat_choices_from_response(
```