inspect-ai 0.3.56__py3-none-any.whl → 0.3.58__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- inspect_ai/__init__.py +2 -1
- inspect_ai/_cli/common.py +4 -2
- inspect_ai/_cli/eval.py +2 -0
- inspect_ai/_cli/trace.py +21 -2
- inspect_ai/_display/core/active.py +0 -2
- inspect_ai/_display/core/panel.py +1 -1
- inspect_ai/_display/rich/display.py +4 -4
- inspect_ai/_display/textual/app.py +4 -1
- inspect_ai/_display/textual/widgets/samples.py +41 -5
- inspect_ai/_eval/eval.py +32 -20
- inspect_ai/_eval/evalset.py +7 -5
- inspect_ai/_eval/run.py +16 -11
- inspect_ai/_eval/task/__init__.py +2 -2
- inspect_ai/_eval/task/images.py +40 -25
- inspect_ai/_eval/task/run.py +141 -119
- inspect_ai/_eval/task/task.py +140 -25
- inspect_ai/_util/constants.py +1 -0
- inspect_ai/_util/content.py +23 -1
- inspect_ai/_util/datetime.py +1 -1
- inspect_ai/_util/deprecation.py +1 -1
- inspect_ai/_util/images.py +20 -17
- inspect_ai/_util/json.py +11 -1
- inspect_ai/_util/kvstore.py +73 -0
- inspect_ai/_util/logger.py +2 -1
- inspect_ai/_util/notgiven.py +18 -0
- inspect_ai/_util/thread.py +5 -0
- inspect_ai/_util/trace.py +39 -3
- inspect_ai/_util/transcript.py +36 -7
- inspect_ai/_view/www/.prettierrc.js +12 -0
- inspect_ai/_view/www/dist/assets/index.js +322 -226
- inspect_ai/_view/www/log-schema.json +221 -138
- inspect_ai/_view/www/src/App.mjs +18 -9
- inspect_ai/_view/www/src/Types.mjs +0 -1
- inspect_ai/_view/www/src/api/Types.mjs +15 -4
- inspect_ai/_view/www/src/api/api-http.mjs +2 -0
- inspect_ai/_view/www/src/components/ExpandablePanel.mjs +2 -2
- inspect_ai/_view/www/src/components/FindBand.mjs +5 -4
- inspect_ai/_view/www/src/components/LargeModal.mjs +1 -1
- inspect_ai/_view/www/src/components/MessageBand.mjs +2 -2
- inspect_ai/_view/www/src/components/MessageContent.mjs +44 -2
- inspect_ai/_view/www/src/components/TabSet.mjs +1 -1
- inspect_ai/_view/www/src/components/Tools.mjs +18 -3
- inspect_ai/_view/www/src/components/VirtualList.mjs +15 -17
- inspect_ai/_view/www/src/log/remoteLogFile.mjs +2 -1
- inspect_ai/_view/www/src/navbar/Navbar.mjs +44 -32
- inspect_ai/_view/www/src/samples/SampleDisplay.mjs +1 -2
- inspect_ai/_view/www/src/samples/SampleList.mjs +35 -4
- inspect_ai/_view/www/src/samples/SampleScoreView.mjs +13 -2
- inspect_ai/_view/www/src/samples/SampleScores.mjs +11 -2
- inspect_ai/_view/www/src/samples/SamplesDescriptor.mjs +242 -178
- inspect_ai/_view/www/src/samples/SamplesTab.mjs +4 -2
- inspect_ai/_view/www/src/samples/tools/SampleFilter.mjs +5 -5
- inspect_ai/_view/www/src/samples/tools/SelectScorer.mjs +7 -0
- inspect_ai/_view/www/src/samples/tools/SortFilter.mjs +3 -3
- inspect_ai/_view/www/src/samples/transcript/ToolEventView.mjs +1 -1
- inspect_ai/_view/www/src/types/log.d.ts +53 -35
- inspect_ai/_view/www/src/workspace/WorkSpace.mjs +1 -1
- inspect_ai/approval/_human/util.py +2 -2
- inspect_ai/dataset/_sources/csv.py +2 -1
- inspect_ai/dataset/_sources/json.py +2 -1
- inspect_ai/dataset/_sources/util.py +15 -7
- inspect_ai/log/_condense.py +11 -1
- inspect_ai/log/_log.py +27 -5
- inspect_ai/log/_recorders/eval.py +21 -8
- inspect_ai/log/_samples.py +10 -5
- inspect_ai/log/_transcript.py +28 -1
- inspect_ai/model/__init__.py +10 -2
- inspect_ai/model/_call_tools.py +82 -17
- inspect_ai/model/_chat_message.py +2 -4
- inspect_ai/model/{_trace.py → _conversation.py} +9 -8
- inspect_ai/model/_model.py +2 -2
- inspect_ai/model/_providers/anthropic.py +9 -7
- inspect_ai/model/_providers/azureai.py +6 -4
- inspect_ai/model/_providers/bedrock.py +6 -4
- inspect_ai/model/_providers/google.py +103 -14
- inspect_ai/model/_providers/groq.py +7 -5
- inspect_ai/model/_providers/hf.py +11 -6
- inspect_ai/model/_providers/mistral.py +6 -9
- inspect_ai/model/_providers/openai.py +34 -8
- inspect_ai/model/_providers/openai_o1.py +10 -12
- inspect_ai/model/_providers/vertex.py +17 -4
- inspect_ai/scorer/__init__.py +13 -2
- inspect_ai/scorer/_metrics/__init__.py +2 -2
- inspect_ai/scorer/_metrics/std.py +3 -3
- inspect_ai/tool/__init__.py +9 -1
- inspect_ai/tool/_tool.py +9 -2
- inspect_ai/tool/_tool_info.py +2 -1
- inspect_ai/tool/_tools/_web_browser/_resources/dm_env_servicer.py +9 -9
- inspect_ai/tool/_tools/_web_browser/_web_browser.py +3 -3
- inspect_ai/util/__init__.py +4 -3
- inspect_ai/util/{_trace.py → _conversation.py} +3 -17
- inspect_ai/util/_display.py +14 -4
- inspect_ai/util/_sandbox/context.py +12 -13
- inspect_ai/util/_sandbox/docker/compose.py +24 -13
- inspect_ai/util/_sandbox/docker/docker.py +20 -13
- inspect_ai/util/_sandbox/docker/util.py +2 -1
- inspect_ai/util/_sandbox/environment.py +13 -1
- inspect_ai/util/_sandbox/local.py +1 -0
- inspect_ai/util/_sandbox/self_check.py +18 -18
- inspect_ai/util/_store.py +2 -2
- inspect_ai/util/_subprocess.py +3 -3
- {inspect_ai-0.3.56.dist-info → inspect_ai-0.3.58.dist-info}/METADATA +3 -3
- {inspect_ai-0.3.56.dist-info → inspect_ai-0.3.58.dist-info}/RECORD +107 -103
- {inspect_ai-0.3.56.dist-info → inspect_ai-0.3.58.dist-info}/WHEEL +1 -1
- {inspect_ai-0.3.56.dist-info → inspect_ai-0.3.58.dist-info}/LICENSE +0 -0
- {inspect_ai-0.3.56.dist-info → inspect_ai-0.3.58.dist-info}/entry_points.txt +0 -0
- {inspect_ai-0.3.56.dist-info → inspect_ai-0.3.58.dist-info}/top_level.txt +0 -0
inspect_ai/model/_call_tools.py
CHANGED
@@ -1,15 +1,20 @@
 import asyncio
 import inspect
+import types
 from dataclasses import is_dataclass
 from logging import getLogger
 from textwrap import dedent
+from types import UnionType
 from typing import (
     Any,
     Callable,
     Dict,
     List,
     NamedTuple,
+    Optional,
+    Tuple,
     Type,
+    Union,
     get_args,
     get_origin,
     get_type_hints,
@@ -19,16 +24,19 @@ from typing import (
 from jsonschema import Draft7Validator
 from pydantic import BaseModel
 
-from inspect_ai._util.content import Content, ContentImage, ContentText
+from inspect_ai._util.content import (
+    Content,
+    ContentAudio,
+    ContentImage,
+    ContentText,
+    ContentVideo,
+)
 from inspect_ai._util.format import format_function_call
 from inspect_ai._util.text import truncate_string_to_bytes
 from inspect_ai._util.trace import trace_action
-from inspect_ai.model._trace import trace_tool_mesage
+from inspect_ai.model._conversation import conversation_tool_mesage
 from inspect_ai.tool import Tool, ToolCall, ToolError, ToolInfo
-from inspect_ai.tool._tool import (
-    ToolApprovalError,
-    ToolParsingError,
-)
+from inspect_ai.tool._tool import ToolApprovalError, ToolParsingError
 from inspect_ai.tool._tool_call import ToolCallContent, ToolCallError
 from inspect_ai.tool._tool_def import ToolDef, tool_defs
 from inspect_ai.tool._tool_info import parse_docstring
@@ -118,10 +126,14 @@ async def call_tools(
         # massage result, leave list[Content] alone, convert all other
         # types to string as that is what the model APIs accept
         truncated: tuple[int, int] | None = None
-        if isinstance(result, ContentText | ContentImage):
+        if isinstance(
+            result, ContentText | ContentImage | ContentAudio | ContentVideo
+        ):
             content: str | list[Content] = [result]
         elif isinstance(result, list) and (
-            isinstance(result[0], ContentText | ContentImage)
+            isinstance(
+                result[0], ContentText | ContentImage | ContentAudio | ContentVideo
+            )
         ):
             content = result
         else:
@@ -161,6 +173,9 @@ async def call_tools(
         # call tools
         tool_messages: list[ChatMessageTool] = []
         for call in message.tool_calls:
+            # create the task
+            task = asyncio.create_task(call_tool_task(call))
+
             # create pending tool event and add it to the transcript
             event = ToolEvent(
                 id=call.id,
@@ -169,15 +184,44 @@ async def call_tools(
                 view=call.view,
                 pending=True,
             )
+            event.set_task(task)
             transcript()._event(event)
 
-            # execute the tool call
-            tool_message, result_event = await call_tool_task(call)
-
+            # execute the tool call. if the operator cancelled the
+            # tool call then synthesize the appropriate message/event
+            try:
+                tool_message, result_event = await task
+            except asyncio.CancelledError:
+                if event.cancelled:
+                    tool_message = ChatMessageTool(
+                        content="",
+                        function=call.function,
+                        tool_call_id=call.id,
+                        error=ToolCallError(
+                            "timeout", "Command timed out before completing."
+                        ),
+                    )
+                    result_event = ToolEvent(
+                        id=call.id,
+                        function=call.function,
+                        arguments=call.arguments,
+                        result=tool_message.content,
+                        truncated=None,
+                        view=call.view,
+                        error=tool_message.error,
+                        events=[],
+                    )
+                    transcript().info(
+                        f"Tool call '{call.function}' was cancelled by operator."
+                    )
+                else:
+                    raise
+
+            # update return messages
             tool_messages.append(tool_message)
 
-            # print tool message
-            trace_tool_mesage(tool_message)
+            # print conversation if display is conversation
+            conversation_tool_mesage(tool_message)
 
             # update the event with the results
             event.set_result(
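The change above moves tool execution onto an explicit `asyncio` task so a pending call can be cancelled externally, with a synthesized timeout message standing in for the real result. A minimal, runnable sketch of that pattern (the `call_tool` coroutine and the direct `task.cancel()` are stand-ins for Inspect's event-driven cancellation):

```python
import asyncio

async def call_tool() -> str:
    await asyncio.sleep(10)  # a slow tool call
    return "result"

async def main() -> None:
    # run the call as a task so it can be cancelled from outside
    task = asyncio.create_task(call_tool())
    task.cancel()  # stands in for the operator cancelling via the ToolEvent

    try:
        message = await task
    except asyncio.CancelledError:
        # the real code re-raises unless event.cancelled was set
        message = "Command timed out before completing."

    print(message)

asyncio.run(main())
```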
@@ -268,6 +312,16 @@ def disable_parallel_tools(
     return False
 
 
+def type_hint_includes_none(type_hint: Type[Any] | None) -> bool:
+    origin = get_origin(type_hint)
+
+    if origin in {Union, UnionType}:
+        return type(None) in get_args(type_hint)
+    elif origin is Optional:
+        return True
+    return False
+
+
 def tool_params(input: dict[str, Any], func: Callable[..., Any]) -> dict[str, Any]:
     # parse function typeinfo
     signature = inspect.signature(func)
@@ -296,7 +350,7 @@ def tool_params(input: dict[str, Any], func: Callable[..., Any]) -> dict[str, An
         # yield parameter (fail if not passed and there is no default)
         if param_name in input:
             params[param_name] = tool_param(type_hint, input.get(param_name))
-        elif param.default is not None:
+        elif param.default is not None or type_hint_includes_none(type_hint):
             params[param_name] = param.default
         else:
             raise ToolParsingError(
@@ -339,11 +393,21 @@ def tool_param(type_hint: Type[Any], input: Any) -> Any:
             return [tool_param(args[0], x) for x in input]
         else:
             return input
+    elif origin is tuple or origin is Tuple:
+        if args:
+            return tuple([tool_param(args[0], x) for x in input])
+        else:
+            return tuple(input)
     elif origin is dict or origin is Dict:
         if args and len(args) > 1:
             return {k: tool_param(args[1], v) for k, v in input}
         else:
             return input
+    elif origin is Union or origin is types.UnionType:
+        if args[1] is type(None):
+            return tool_param(args[0], input)
+        else:
+            return input
     else:
         return input
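The new `type_hint_includes_none` helper lets an omitted argument fall back to `param.default` whenever the parameter's annotation admits `None`. A small self-contained sketch of the detection logic; note that `Optional[X]` normalizes to `Union[X, None]` at runtime, so the union check covers both spellings:

```python
import types
from typing import Optional, Union, get_args, get_origin

def hint_includes_none(type_hint: object) -> bool:
    # Optional[X] is Union[X, None]; X | None is a types.UnionType
    origin = get_origin(type_hint)
    if origin in {Union, types.UnionType}:
        return type(None) in get_args(type_hint)
    return False

assert hint_includes_none(Optional[int])
assert hint_includes_none(int | None)
assert hint_includes_none(Union[str, int, None])
assert not hint_includes_none(int)
assert not hint_includes_none(list[int])
```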
@@ -389,12 +453,13 @@ def truncate_tool_output(
     # truncate if required
     truncated = truncate_string_to_bytes(output, active_max_output)
     if truncated:
-        truncated_output = dedent(f"""
+        truncated_output = dedent("""
            The output of your call to {tool_name} was too long to be displayed.
            Here is a truncated version:
            <START_TOOL_OUTPUT>
-            {truncated.output}
-            <END_TOOL_OUTPUT>""")
+            {truncated_output}
+            <END_TOOL_OUTPUT>
+            """).format(tool_name=tool_name, truncated_output=truncated.output)
         return TruncatedToolOutput(
             truncated_output, truncated.original_bytes, active_max_output
         )
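This fix matters because `dedent` strips only the longest whitespace prefix common to every line: interpolating multi-line tool output into the template via an f-string first destroys that common prefix (the output's own lines carry no indent), so nothing was dedented. Dedenting the raw template and substituting with `.format()` afterwards restores the intended behavior. A small illustration:

```python
from textwrap import dedent

output = "line 1\nline 2"  # multi-line tool output

# f-string first: "line 2" has no indent, so the common prefix is
# empty and dedent() removes nothing
broken = dedent(f"""
    <START_TOOL_OUTPUT>
    {output}
    <END_TOOL_OUTPUT>
""")

# template first: dedent() sees a uniform 4-space prefix and strips it,
# then the output is substituted into the already-dedented text
fixed = dedent("""
    <START_TOOL_OUTPUT>
    {output}
    <END_TOOL_OUTPUT>
""").format(output=output)

assert "    <START_TOOL_OUTPUT>" in broken   # indentation survived (the bug)
assert fixed.startswith("\n<START_TOOL_OUTPUT>")
```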
inspect_ai/model/_chat_message.py
CHANGED
@@ -59,10 +59,8 @@ class ChatMessageBase(BaseModel):
         if isinstance(self.content, str):
             self.content = text
         else:
-            all_images = [
-                content for content in self.content if content.type == "image"
-            ]
-            self.content = [ContentText(text=text)] + all_images
+            all_other = [content for content in self.content if content.type != "text"]
+            self.content = [ContentText(text=text)] + all_other
 
 
 class ChatMessageSystem(ChatMessageBase):
inspect_ai/model/{_trace.py → _conversation.py}
RENAMED
@@ -3,7 +3,8 @@ from rich.text import Text
 
 from inspect_ai._util.rich import lines_display
 from inspect_ai._util.transcript import transcript_markdown
-from inspect_ai.util._trace import trace_enabled, trace_panel
+from inspect_ai.util._conversation import conversation_panel
+from inspect_ai.util._display import display_type
 
 from ._chat_message import ChatMessage, ChatMessageAssistant, ChatMessageTool
 from ._render import messages_preceding_assistant, render_tool_calls
@@ -11,25 +12,25 @@ from ._render import messages_preceding_assistant, render_tool_calls
 MESSAGE_TITLE = "Message"
 
 
-def trace_tool_mesage(message: ChatMessageTool) -> None:
-    if trace_enabled():
+def conversation_tool_mesage(message: ChatMessageTool) -> None:
+    if display_type() == "conversation":
         # truncate output to 100 lines
         output = message.error.message if message.error else message.text.strip()
         content = lines_display(output, 100)
 
-        trace_panel(
+        conversation_panel(
             title=f"Tool Output: {message.function}",
             content=content,
         )
 
 
-def trace_assistant_message(
+def conversation_assistant_message(
     input: list[ChatMessage], message: ChatMessageAssistant
 ) -> None:
-    if trace_enabled():
+    if display_type() == "conversation":
         # print precding messages that aren't tool or assistant
         for m in messages_preceding_assistant(input):
-            trace_panel(
+            conversation_panel(
                 title=m.role.capitalize(),
                 content=transcript_markdown(m.text, escape=True),
             )
@@ -45,4 +46,4 @@ def trace_assistant_message(
     content.extend(render_tool_calls(message.tool_calls))
 
     # print the assistant message
-    trace_panel(title="Assistant", content=content)
+    conversation_panel(title="Assistant", content=content)
inspect_ai/model/_model.py
CHANGED
@@ -43,6 +43,7 @@ from ._chat_message import (
     ChatMessageTool,
     ChatMessageUser,
 )
+from ._conversation import conversation_assistant_message
 from ._generate_config import (
     GenerateConfig,
     active_generate_config,
@@ -50,7 +51,6 @@ from ._generate_config import (
 )
 from ._model_call import ModelCall
 from ._model_output import ModelOutput, ModelUsage
-from ._trace import trace_assistant_message
 
 logger = logging.getLogger(__name__)
 
@@ -487,7 +487,7 @@ class Model:
         updated_output: ModelOutput, updated_call: ModelCall | None
     ) -> None:
         # trace
-        trace_assistant_message(input, updated_output.choices[0].message)
+        conversation_assistant_message(input, updated_output.choices[0].message)
 
         # update event
         event.output = updated_output
inspect_ai/model/_providers/anthropic.py
CHANGED
@@ -28,11 +28,11 @@ from pydantic import JsonValue
 from typing_extensions import override
 
 from inspect_ai._util.constants import BASE_64_DATA_REMOVED, DEFAULT_MAX_RETRIES
-from inspect_ai._util.content import Content, ContentText
+from inspect_ai._util.content import Content, ContentImage, ContentText
 from inspect_ai._util.error import exception_message
-from inspect_ai._util.images import image_as_data_uri
+from inspect_ai._util.images import file_as_data_uri
 from inspect_ai._util.logger import warn_once
-from inspect_ai._util.url import data_uri_mime_type, data_uri_to_base64, is_data_uri
+from inspect_ai._util.url import data_uri_mime_type, data_uri_to_base64
 from inspect_ai.tool import ToolCall, ToolChoice, ToolFunction, ToolInfo
 
 from .._chat_message import (
@@ -584,11 +584,9 @@ async def message_param_content(
 ) -> TextBlockParam | ImageBlockParam:
     if isinstance(content, ContentText):
         return TextBlockParam(type="text", text=content.text or NO_CONTENT)
-    else:
+    elif isinstance(content, ContentImage):
         # resolve to url
-        image = content.image
-        if not is_data_uri(image):
-            image = await image_as_data_uri(image)
+        image = await file_as_data_uri(content.image)
 
         # resolve mime type and base64 content
         media_type = data_uri_mime_type(image) or "image/png"
@@ -601,6 +599,10 @@ async def message_param_content(
             type="image",
             source=dict(type="base64", media_type=cast(Any, media_type), data=image),
         )
+    else:
+        raise RuntimeError(
+            "Anthropic models do not currently support audio or video inputs."
+        )
 
 
 def model_call_filter(key: JsonValue | None, value: JsonValue) -> JsonValue:
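With the `ContentImage` branch, every image is first normalized to a data URI and then split into the media type and base64 payload that `ImageBlockParam` expects. A rough sketch of that split, approximating what `data_uri_mime_type` and `data_uri_to_base64` do (not the library's actual implementations):

```python
def split_data_uri(uri: str) -> tuple[str, str]:
    # "data:image/png;base64,AAAA" -> ("image/png", "AAAA")
    header, b64_data = uri.split(",", 1)
    media_type = header.removeprefix("data:").split(";", 1)[0]
    return media_type, b64_data

media_type, data = split_data_uri("data:image/png;base64,iVBORw0KGgo=")
assert media_type == "image/png"
```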
inspect_ai/model/_providers/azureai.py
CHANGED
@@ -31,8 +31,8 @@ from azure.core.exceptions import AzureError, HttpResponseError
 from typing_extensions import override
 
 from inspect_ai._util.constants import DEFAULT_MAX_TOKENS
-from inspect_ai._util.content import Content, ContentText
-from inspect_ai._util.images import image_as_data_uri
+from inspect_ai._util.content import Content, ContentImage, ContentText
+from inspect_ai._util.images import file_as_data_uri
 from inspect_ai.tool import ToolChoice, ToolInfo
 from inspect_ai.tool._tool_call import ToolCall
 from inspect_ai.tool._tool_choice import ToolFunction
@@ -312,12 +312,14 @@ async def chat_request_message(
 async def chat_content_item(content: Content) -> ContentItem:
     if isinstance(content, ContentText):
         return TextContentItem(text=content.text)
-    else:
+    elif isinstance(content, ContentImage):
         return ImageContentItem(
             image_url=ImageUrl(
-                url=await image_as_data_uri(content.image), detail=content.detail
+                url=await file_as_data_uri(content.image), detail=content.detail
             )
         )
+    else:
+        raise RuntimeError("Azure AI models do not support audio or video inputs.")
 
 
 def chat_tool_call(tool_call: ToolCall) -> ChatCompletionsToolCall:
inspect_ai/model/_providers/bedrock.py
CHANGED
@@ -11,7 +11,7 @@ from inspect_ai._util.constants import (
 )
 from inspect_ai._util.content import Content, ContentImage, ContentText
 from inspect_ai._util.error import pip_dependency_error
-from inspect_ai._util.images import image_as_data
+from inspect_ai._util.images import file_as_data
 from inspect_ai._util.version import verify_required_version
 from inspect_ai.tool import ToolChoice, ToolInfo
 from inspect_ai.tool._tool_call import ToolCall
@@ -430,7 +430,9 @@ def model_output_from_response(
             content.append(ContentText(type="text", text=c.text))
         elif c.image is not None:
             base64_image = base64.b64encode(c.image.source.bytes).decode("utf-8")
-            content.append(ContentImage(image=f"data:image/{c.image.format};base64,{base64_image}"))
+            content.append(
+                ContentImage(image=f"data:image/{c.image.format};base64,{base64_image}")
+            )
         elif c.toolUse is not None:
             tool_calls.append(
                 ToolCall(
@@ -565,7 +567,7 @@ async def converse_chat_message(
             if c.type == "text":
                 tool_result_content.append(ConverseToolResultContent(text=c.text))
             elif c.type == "image":
-                image_data, image_type = await image_as_data(c.image)
+                image_data, image_type = await file_as_data(c.image)
                 tool_result_content.append(
                     ConverseToolResultContent(
                         image=ConverseImage(
@@ -604,7 +606,7 @@ async def converse_contents(
     result: list[ConverseMessageContent] = []
     for c in content:
         if c.type == "image":
-            image_data, image_type = await image_as_data(c.image)
+            image_data, image_type = await file_as_data(c.image)
             result.append(
                 ConverseMessageContent(
                     image=ConverseImage(
inspect_ai/model/_providers/google.py
CHANGED
@@ -1,12 +1,17 @@
+import asyncio
 import functools
+import hashlib
 import json
 from copy import copy
+from io import BytesIO
+from logging import getLogger
 from typing import Any, cast
 
 import proto  # type: ignore
 from google.ai.generativelanguage import (
     Blob,
     Candidate,
+    File,
     FunctionCall,
     FunctionCallingConfig,
     FunctionDeclaration,
@@ -28,6 +33,8 @@ from google.generativeai import (  # type: ignore
     GenerationConfig,
     GenerativeModel,
     configure,
+    get_file,
+    upload_file,
 )
 from google.generativeai.types import (  # type: ignore
     AsyncGenerateContentResponse,
@@ -45,8 +52,16 @@ from pydantic import JsonValue
 from typing_extensions import override
 
 from inspect_ai._util.constants import BASE_64_DATA_REMOVED
-from inspect_ai._util.content import Content, ContentImage, ContentText
-from inspect_ai._util.images import image_as_data
+from inspect_ai._util.content import (
+    Content,
+    ContentAudio,
+    ContentImage,
+    ContentText,
+    ContentVideo,
+)
+from inspect_ai._util.images import file_as_data
+from inspect_ai._util.kvstore import inspect_kvstore
+from inspect_ai._util.trace import trace_message
 from inspect_ai.tool import ToolCall, ToolChoice, ToolInfo, ToolParam, ToolParams
 
 from .._chat_message import (
@@ -70,6 +85,8 @@ from .._model_output import (
 )
 from .util import model_base_url
 
+logger = getLogger(__name__)
+
 SAFETY_SETTINGS = "safety_settings"
 
 DEFAULT_SAFETY_SETTINGS: SafetySettingDict = {
@@ -194,7 +211,9 @@ class GoogleAPI(ModelAPI):
                 model=self.model_name, content=ex.message, stop_reason="model_length"
             )
         else:
-            raise ex
+            return ModelOutput.from_content(
+                model=self.model_name, content=ex.message, stop_reason="unknown"
+            )
 
     @override
     def is_rate_limit(self, ex: BaseException) -> bool:
@@ -362,19 +381,23 @@ def dict_to_struct(x: dict[str, Any]) -> Struct:
     return struct
 
 
-async def content_part(content: Content | str) ->
+async def content_part(content: Content | str) -> PartType:
     if isinstance(content, str):
         return PartDict(text=content or NO_CONTENT)
     elif isinstance(content, ContentText):
         return PartDict(text=content.text or NO_CONTENT)
     else:
-        return
+        return await chat_content_to_part(content)
 
 
-async def
-
-
-
+async def chat_content_to_part(
+    content: ContentImage | ContentAudio | ContentVideo,
+) -> PartType:
+    if isinstance(content, ContentImage):
+        content_bytes, mime_type = await file_as_data(content.image)
+        return Blob(mime_type=mime_type, data=content_bytes)
+    else:
+        return await file_for_content(content)
 
 
 def prepend_system_messages(
@@ -408,25 +431,34 @@ def chat_tools(tools: list[ToolInfo]) -> list[Tool]:
 # https://ai.google.dev/gemini-api/tutorials/extract_structured_data#define_the_schema
 
 
-def schema_from_param(param: ToolParam | ToolParams) -> Schema:
+def schema_from_param(param: ToolParam | ToolParams, nullable: bool = False) -> Schema:
     if isinstance(param, ToolParams):
         param = ToolParam(
             type=param.type, properties=param.properties, required=param.required
         )
 
     if param.type == "number":
-        return Schema(type=Type.NUMBER, description=param.description)
+        return Schema(
+            type=Type.NUMBER, description=param.description, nullable=nullable
+        )
     elif param.type == "integer":
-        return Schema(type=Type.INTEGER, description=param.description)
+        return Schema(
+            type=Type.INTEGER, description=param.description, nullable=nullable
+        )
     elif param.type == "boolean":
-        return Schema(type=Type.BOOLEAN, description=param.description)
+        return Schema(
+            type=Type.BOOLEAN, description=param.description, nullable=nullable
+        )
     elif param.type == "string":
-        return Schema(type=Type.STRING, description=param.description)
+        return Schema(
+            type=Type.STRING, description=param.description, nullable=nullable
+        )
     elif param.type == "array":
         return Schema(
             type=Type.ARRAY,
             description=param.description,
             items=schema_from_param(param.items) if param.items else None,
+            nullable=nullable,
        )
     elif param.type == "object":
         return Schema(
@@ -436,7 +468,14 @@ def schema_from_param(param: ToolParam | ToolParams) -> Schema:
             if param.properties is not None
             else None,
             required=param.required,
+            nullable=nullable,
         )
+    # convert unions to optional params if the second type is 'null'
+    elif param.anyOf:
+        if len(param.anyOf) == 2 and param.anyOf[1].type == "null":
+            return schema_from_param(param.anyOf[0], nullable=True)
+        else:
+            return Schema(type=Type.TYPE_UNSPECIFIED)
     else:
         return Schema(type=Type.TYPE_UNSPECIFIED)
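The new `anyOf` branch maps JSON Schema unions of the form `T | null` (how optional tool parameters are rendered) onto Gemini's `nullable` flag rather than discarding the type. A sketch of the same conversion over plain JSON-schema dicts; `to_gemini_schema` is a hypothetical simplification of `schema_from_param`:

```python
def to_gemini_schema(param: dict) -> dict:
    any_of = param.get("anyOf")
    if any_of:
        # {"anyOf": [{"type": "string"}, {"type": "null"}]} -> nullable string
        if len(any_of) == 2 and any_of[1].get("type") == "null":
            return to_gemini_schema(any_of[0]) | {"nullable": True}
        return {"type": "TYPE_UNSPECIFIED"}
    return {"type": param["type"]}

assert to_gemini_schema(
    {"anyOf": [{"type": "string"}, {"type": "null"}]}
) == {"type": "string", "nullable": True}
```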
@@ -612,3 +651,53 @@ def str_to_harm_block_threshold(threshold: str) -> HarmBlockThreshold:
         return HarmBlockThreshold.BLOCK_NONE
     else:
         raise ValueError(f"Unknown HarmBlockThreshold: {threshold}")
+
+
+async def file_for_content(content: ContentAudio | ContentVideo) -> File:
+    # helper to write trace messages
+    def trace(message: str) -> None:
+        trace_message(logger, "Google Files", message)
+
+    # get the file bytes and compute sha256 hash
+    if isinstance(content, ContentAudio):
+        file = content.audio
+    else:
+        file = content.video
+    content_bytes, mime_type = await file_as_data(file)
+    content_sha256 = hashlib.sha256(content_bytes).hexdigest()
+
+    # we cache uploads for re-use, open the db where we track that
+    # (track up to 1 million previous uploads)
+    with inspect_kvstore("google_files", 1000000) as files_db:
+        # can we serve from existing uploads?
+        uploaded_file = files_db.get(content_sha256)
+        if uploaded_file:
+            try:
+                upload = cast(File, get_file(uploaded_file))
+                if upload.state.name == "ACTIVE":
+                    trace(f"Using uploaded file: {uploaded_file}")
+                    return upload
+                else:
+                    trace(
+                        f"Not using uploaded file '{uploaded_file} (state was {upload.state})"
+                    )
+            except Exception as ex:
+                trace(f"Error attempting to access uploaded file: {ex}")
+                files_db.delete(content_sha256)
+
+        # do the upload (and record it)
+        upload = upload_file(BytesIO(content_bytes), mime_type=mime_type)
+        while upload.state.name == "PROCESSING":
+            await asyncio.sleep(3)
+            upload = get_file(upload.name)
+
+        if upload.state.name == "FAILED":
+            trace(f"Failed to upload file '{upload.name}: {upload.error}")
+            raise ValueError(f"Google file upload failed: {upload.error}")
+
+        # trace and record it
+        trace(f"Uploaded file: {upload.name}")
+        files_db.put(content_sha256, upload.name)
+
+        # return the file
+        return upload
inspect_ai/model/_providers/groq.py
CHANGED
@@ -23,8 +23,8 @@ from typing_extensions import override
 
 from inspect_ai._util.constants import DEFAULT_MAX_RETRIES, DEFAULT_MAX_TOKENS
 from inspect_ai._util.content import Content
-from inspect_ai._util.images import image_as_data_uri
-from inspect_ai._util.url import is_data_uri, is_http_url
+from inspect_ai._util.images import file_as_data_uri
+from inspect_ai._util.url import is_http_url
 from inspect_ai.tool import ToolCall, ToolChoice, ToolFunction, ToolInfo
 
 from .._chat_message import (
@@ -248,18 +248,20 @@ async def as_chat_completion_part(
 ) -> ChatCompletionContentPartParam:
     if content.type == "text":
         return ChatCompletionContentPartTextParam(type="text", text=content.text)
-    else:
+    elif content.type == "image":
         # API takes URL or base64 encoded file. If it's a remote file or data URL leave it alone, otherwise encode it
         image_url = content.image
         detail = content.detail
 
-        if not is_http_url(image_url) and not is_data_uri(image_url):
-            image_url = await image_as_data_uri(image_url)
+        if not is_http_url(image_url):
+            image_url = await file_as_data_uri(image_url)
 
         return ChatCompletionContentPartImageParam(
             type="image_url",
             image_url=dict(url=image_url, detail=detail),
         )
+    else:
+        raise RuntimeError("Groq models do not support audio or video inputs.")
 
 
 def chat_tools(tools: List[ToolInfo]) -> List[Dict[str, Any]]:
inspect_ai/model/_providers/hf.py
CHANGED
@@ -239,12 +239,17 @@ class HuggingFaceAPI(ModelAPI):
         hf_messages = inspect_tools_to_string(hf_messages)
 
         # apply chat template
-        chat = self.tokenizer.apply_chat_template(
-            hf_messages,
-            add_generation_prompt=True,
-            tokenize=False,
-            tools=tools_list if len(tools_list) > 0 else None,
-        )
+        if self.tokenizer.chat_template is not None:
+            chat = self.tokenizer.apply_chat_template(
+                hf_messages,
+                add_generation_prompt=True,
+                tokenize=False,
+                tools=tools_list if len(tools_list) > 0 else None,
+            )
+        else:
+            chat = ""
+            for message in hf_messages:
+                chat += f"{message.role}: {message.content}\n"
         # return
         return cast(str, chat)
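When a tokenizer ships without a chat template, the fallback above renders a plain `role: content` transcript instead of calling `apply_chat_template`. A toy illustration of what that fallback produces (dict messages stand in for the message objects Inspect actually passes):

```python
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello"},
]

chat = ""
for message in messages:
    chat += f"{message['role']}: {message['content']}\n"

# chat is now:
# system: You are a helpful assistant.
# user: Hello
```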
inspect_ai/model/_providers/mistral.py
CHANGED
@@ -42,8 +42,7 @@ from inspect_ai._util.constants import (
     DEFAULT_TIMEOUT,
 )
 from inspect_ai._util.content import Content, ContentImage, ContentText
-from inspect_ai._util.images import image_as_data_uri
-from inspect_ai._util.url import is_data_uri
+from inspect_ai._util.images import file_as_data_uri
 from inspect_ai.tool import ToolCall, ToolChoice, ToolFunction, ToolInfo
 
 from .._chat_message import (
@@ -351,16 +350,14 @@ def mistral_system_message_content(
 
 async def mistral_content_chunk(content: Content) -> ContentChunk:
     if isinstance(content, ContentText):
         return TextChunk(text=content.text or NO_CONTENT)
-    else:
+    elif isinstance(content, ContentImage):
         # resolve image to url
-        image_url = content.image
-        if not is_data_uri(image_url):
-            image_url = await image_as_data_uri(image_url)
+        image_url = await file_as_data_uri(content.image)
 
         # return chunk
-        return ImageURLChunk(
-            image_url=ImageURL(url=image_url, detail=content.detail)
-        )
+        return ImageURLChunk(image_url=ImageURL(url=image_url, detail=content.detail))
+    else:
+        raise RuntimeError("Mistral models do not support audio or video inputs.")
 
 
 def mistral_tool_call(tool_call: ToolCall) -> MistralToolCall:
|