inspect-ai 0.3.58__py3-none-any.whl → 0.3.60__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- inspect_ai/_cli/common.py +3 -1
- inspect_ai/_cli/eval.py +15 -9
- inspect_ai/_display/core/active.py +4 -1
- inspect_ai/_display/core/config.py +3 -3
- inspect_ai/_display/core/panel.py +7 -3
- inspect_ai/_display/plain/__init__.py +0 -0
- inspect_ai/_display/plain/display.py +203 -0
- inspect_ai/_display/rich/display.py +0 -5
- inspect_ai/_display/textual/widgets/port_mappings.py +110 -0
- inspect_ai/_display/textual/widgets/samples.py +79 -12
- inspect_ai/_display/textual/widgets/sandbox.py +37 -0
- inspect_ai/_eval/eval.py +10 -1
- inspect_ai/_eval/loader.py +79 -19
- inspect_ai/_eval/registry.py +6 -0
- inspect_ai/_eval/score.py +3 -1
- inspect_ai/_eval/task/results.py +51 -22
- inspect_ai/_eval/task/run.py +47 -13
- inspect_ai/_eval/task/sandbox.py +10 -5
- inspect_ai/_util/constants.py +1 -0
- inspect_ai/_util/port_names.py +61 -0
- inspect_ai/_util/text.py +23 -0
- inspect_ai/_view/www/App.css +31 -1
- inspect_ai/_view/www/dist/assets/index.css +31 -1
- inspect_ai/_view/www/dist/assets/index.js +25498 -2044
- inspect_ai/_view/www/log-schema.json +32 -2
- inspect_ai/_view/www/package.json +2 -0
- inspect_ai/_view/www/src/App.mjs +14 -16
- inspect_ai/_view/www/src/Types.mjs +1 -2
- inspect_ai/_view/www/src/api/Types.ts +133 -0
- inspect_ai/_view/www/src/api/{api-browser.mjs → api-browser.ts} +25 -13
- inspect_ai/_view/www/src/api/api-http.ts +219 -0
- inspect_ai/_view/www/src/api/api-shared.ts +47 -0
- inspect_ai/_view/www/src/api/{api-vscode.mjs → api-vscode.ts} +22 -19
- inspect_ai/_view/www/src/api/{client-api.mjs → client-api.ts} +93 -53
- inspect_ai/_view/www/src/api/index.ts +51 -0
- inspect_ai/_view/www/src/api/jsonrpc.ts +225 -0
- inspect_ai/_view/www/src/components/ChatView.mjs +133 -43
- inspect_ai/_view/www/src/components/DownloadButton.mjs +1 -1
- inspect_ai/_view/www/src/components/ExpandablePanel.mjs +0 -4
- inspect_ai/_view/www/src/components/LargeModal.mjs +19 -20
- inspect_ai/_view/www/src/components/TabSet.mjs +3 -1
- inspect_ai/_view/www/src/components/VirtualList.mjs +266 -84
- inspect_ai/_view/www/src/index.js +77 -4
- inspect_ai/_view/www/src/log/{remoteLogFile.mjs → remoteLogFile.ts} +62 -46
- inspect_ai/_view/www/src/navbar/Navbar.mjs +4 -1
- inspect_ai/_view/www/src/navbar/SecondaryBar.mjs +19 -10
- inspect_ai/_view/www/src/samples/SampleDialog.mjs +5 -1
- inspect_ai/_view/www/src/samples/SampleDisplay.mjs +23 -15
- inspect_ai/_view/www/src/samples/SampleList.mjs +19 -49
- inspect_ai/_view/www/src/samples/SampleScores.mjs +1 -1
- inspect_ai/_view/www/src/samples/SampleTranscript.mjs +8 -3
- inspect_ai/_view/www/src/samples/SamplesDescriptor.mjs +38 -26
- inspect_ai/_view/www/src/samples/SamplesTab.mjs +14 -11
- inspect_ai/_view/www/src/samples/SamplesTools.mjs +8 -8
- inspect_ai/_view/www/src/samples/tools/SampleFilter.mjs +712 -89
- inspect_ai/_view/www/src/samples/tools/SortFilter.mjs +2 -2
- inspect_ai/_view/www/src/samples/tools/filters.mjs +260 -87
- inspect_ai/_view/www/src/samples/transcript/ErrorEventView.mjs +24 -2
- inspect_ai/_view/www/src/samples/transcript/EventPanel.mjs +29 -24
- inspect_ai/_view/www/src/samples/transcript/EventRow.mjs +1 -1
- inspect_ai/_view/www/src/samples/transcript/InfoEventView.mjs +24 -2
- inspect_ai/_view/www/src/samples/transcript/InputEventView.mjs +24 -2
- inspect_ai/_view/www/src/samples/transcript/ModelEventView.mjs +31 -10
- inspect_ai/_view/www/src/samples/transcript/SampleInitEventView.mjs +24 -2
- inspect_ai/_view/www/src/samples/transcript/SampleLimitEventView.mjs +23 -2
- inspect_ai/_view/www/src/samples/transcript/ScoreEventView.mjs +24 -2
- inspect_ai/_view/www/src/samples/transcript/StepEventView.mjs +33 -3
- inspect_ai/_view/www/src/samples/transcript/SubtaskEventView.mjs +25 -2
- inspect_ai/_view/www/src/samples/transcript/ToolEventView.mjs +25 -2
- inspect_ai/_view/www/src/samples/transcript/TranscriptView.mjs +193 -11
- inspect_ai/_view/www/src/samples/transcript/Types.mjs +10 -0
- inspect_ai/_view/www/src/samples/transcript/state/StateEventView.mjs +26 -2
- inspect_ai/_view/www/src/types/log.d.ts +13 -2
- inspect_ai/_view/www/src/utils/Format.mjs +10 -3
- inspect_ai/_view/www/src/utils/{Json.mjs → json-worker.ts} +13 -9
- inspect_ai/_view/www/src/utils/vscode.ts +36 -0
- inspect_ai/_view/www/src/workspace/WorkSpace.mjs +11 -5
- inspect_ai/_view/www/vite.config.js +7 -0
- inspect_ai/_view/www/yarn.lock +116 -0
- inspect_ai/approval/_human/__init__.py +0 -0
- inspect_ai/approval/_human/manager.py +1 -1
- inspect_ai/approval/_policy.py +12 -6
- inspect_ai/log/_log.py +1 -1
- inspect_ai/log/_samples.py +16 -0
- inspect_ai/log/_transcript.py +4 -1
- inspect_ai/model/_call_tools.py +59 -0
- inspect_ai/model/_conversation.py +16 -7
- inspect_ai/model/_generate_config.py +12 -12
- inspect_ai/model/_model.py +117 -18
- inspect_ai/model/_model_output.py +22 -2
- inspect_ai/model/_openai.py +383 -0
- inspect_ai/model/_providers/anthropic.py +152 -55
- inspect_ai/model/_providers/azureai.py +21 -21
- inspect_ai/model/_providers/bedrock.py +37 -40
- inspect_ai/model/_providers/goodfire.py +248 -0
- inspect_ai/model/_providers/google.py +46 -54
- inspect_ai/model/_providers/groq.py +7 -3
- inspect_ai/model/_providers/hf.py +6 -0
- inspect_ai/model/_providers/mistral.py +13 -12
- inspect_ai/model/_providers/openai.py +51 -218
- inspect_ai/model/_providers/openai_o1.py +11 -12
- inspect_ai/model/_providers/providers.py +23 -1
- inspect_ai/model/_providers/together.py +12 -12
- inspect_ai/model/_providers/util/__init__.py +2 -3
- inspect_ai/model/_providers/util/hf_handler.py +1 -1
- inspect_ai/model/_providers/util/llama31.py +1 -1
- inspect_ai/model/_providers/util/util.py +0 -76
- inspect_ai/model/_providers/vertex.py +1 -4
- inspect_ai/scorer/_metric.py +3 -0
- inspect_ai/scorer/_reducer/reducer.py +1 -1
- inspect_ai/scorer/_scorer.py +4 -3
- inspect_ai/solver/__init__.py +4 -5
- inspect_ai/solver/_basic_agent.py +1 -1
- inspect_ai/solver/_bridge/__init__.py +3 -0
- inspect_ai/solver/_bridge/bridge.py +100 -0
- inspect_ai/solver/_bridge/patch.py +170 -0
- inspect_ai/solver/_prompt.py +35 -5
- inspect_ai/solver/_solver.py +6 -0
- inspect_ai/solver/_task_state.py +80 -38
- inspect_ai/tool/__init__.py +2 -0
- inspect_ai/tool/_tool.py +12 -1
- inspect_ai/tool/_tool_call.py +10 -0
- inspect_ai/tool/_tool_def.py +16 -5
- inspect_ai/tool/_tool_with.py +21 -4
- inspect_ai/tool/beta/__init__.py +5 -0
- inspect_ai/tool/beta/_computer/__init__.py +3 -0
- inspect_ai/tool/beta/_computer/_common.py +133 -0
- inspect_ai/tool/beta/_computer/_computer.py +155 -0
- inspect_ai/tool/beta/_computer/_computer_split.py +198 -0
- inspect_ai/tool/beta/_computer/_resources/Dockerfile +100 -0
- inspect_ai/tool/beta/_computer/_resources/README.md +30 -0
- inspect_ai/tool/beta/_computer/_resources/entrypoint/entrypoint.sh +18 -0
- inspect_ai/tool/beta/_computer/_resources/entrypoint/novnc_startup.sh +20 -0
- inspect_ai/tool/beta/_computer/_resources/entrypoint/x11vnc_startup.sh +48 -0
- inspect_ai/tool/beta/_computer/_resources/entrypoint/xfce_startup.sh +13 -0
- inspect_ai/tool/beta/_computer/_resources/entrypoint/xvfb_startup.sh +48 -0
- inspect_ai/tool/beta/_computer/_resources/image_home_dir/Desktop/Firefox Web Browser.desktop +10 -0
- inspect_ai/tool/beta/_computer/_resources/image_home_dir/Desktop/Visual Studio Code.desktop +10 -0
- inspect_ai/tool/beta/_computer/_resources/image_home_dir/Desktop/XPaint.desktop +10 -0
- inspect_ai/tool/beta/_computer/_resources/tool/__init__.py +0 -0
- inspect_ai/tool/beta/_computer/_resources/tool/_logger.py +22 -0
- inspect_ai/tool/beta/_computer/_resources/tool/_run.py +42 -0
- inspect_ai/tool/beta/_computer/_resources/tool/_tool_result.py +33 -0
- inspect_ai/tool/beta/_computer/_resources/tool/_x11_client.py +262 -0
- inspect_ai/tool/beta/_computer/_resources/tool/computer_tool.py +85 -0
- inspect_ai/tool/beta/_computer/_resources/tool/requirements.txt +0 -0
- inspect_ai/util/__init__.py +2 -0
- inspect_ai/util/_display.py +5 -0
- inspect_ai/util/_limit.py +26 -0
- inspect_ai/util/_sandbox/docker/docker.py +64 -1
- inspect_ai/util/_sandbox/docker/internal.py +3 -1
- inspect_ai/util/_sandbox/docker/prereqs.py +1 -1
- inspect_ai/util/_sandbox/environment.py +14 -0
- {inspect_ai-0.3.58.dist-info → inspect_ai-0.3.60.dist-info}/METADATA +3 -2
- {inspect_ai-0.3.58.dist-info → inspect_ai-0.3.60.dist-info}/RECORD +159 -126
- inspect_ai/_view/www/src/api/Types.mjs +0 -117
- inspect_ai/_view/www/src/api/api-http.mjs +0 -300
- inspect_ai/_view/www/src/api/api-shared.mjs +0 -10
- inspect_ai/_view/www/src/api/index.mjs +0 -49
- inspect_ai/_view/www/src/api/jsonrpc.mjs +0 -208
- inspect_ai/_view/www/src/samples/transcript/TranscriptState.mjs +0 -70
- inspect_ai/_view/www/src/utils/vscode.mjs +0 -16
- {inspect_ai-0.3.58.dist-info → inspect_ai-0.3.60.dist-info}/LICENSE +0 -0
- {inspect_ai-0.3.58.dist-info → inspect_ai-0.3.60.dist-info}/WHEEL +0 -0
- {inspect_ai-0.3.58.dist-info → inspect_ai-0.3.60.dist-info}/entry_points.txt +0 -0
- {inspect_ai-0.3.58.dist-info → inspect_ai-0.3.60.dist-info}/top_level.txt +0 -0
inspect_ai/model/_model.py
CHANGED
@@ -33,6 +33,7 @@ from inspect_ai._util.trace import trace_action
 from inspect_ai.tool import Tool, ToolChoice, ToolFunction, ToolInfo
 from inspect_ai.tool._tool_def import ToolDef, tool_defs
 from inspect_ai.util import concurrency
+from inspect_ai.util._limit import SampleLimitExceededError

 from ._cache import CacheEntry, CachePolicy, cache_fetch, cache_store
 from ._call_tools import disable_parallel_tools, tool_call_view, tools_info
@@ -43,7 +44,7 @@ from ._chat_message import (
     ChatMessageTool,
     ChatMessageUser,
 )
-from ._conversation import conversation_assistant_message
+from ._conversation import conversation_assistant_error, conversation_assistant_message
 from ._generate_config import (
     GenerateConfig,
     active_generate_config,
@@ -116,7 +117,7 @@ class ModelAPI(abc.ABC):
         tools: list[ToolInfo],
         tool_choice: ToolChoice,
         config: GenerateConfig,
-    ) -> ModelOutput | tuple[ModelOutput, ModelCall]:
+    ) -> ModelOutput | tuple[ModelOutput | Exception, ModelCall]:
         """Generate output from the model.

         Args:
@@ -165,7 +166,7 @@ class ModelAPI(abc.ABC):
         return False

     def tool_result_images(self) -> bool:
-        """Tool results can
+        """Tool results can contain images"""
         return False

@@ -222,11 +223,17 @@ class Model:
         Returns:
            ModelOutput
         """
+        # if we are the default model then enforce message limit if it
+        # exists (raise an exception if it is exceeded)
+        is_active_model = self == active_model()
+        if is_active_model:
+            handle_sample_message_limit(input)
+
         # base config for this model
         base_config = self.config

         # if we are the active_model then merge active generate config
-        if self == active_model():
+        if is_active_model:
             base_config = base_config.merge(active_generate_config())

         # merge passed config
@@ -296,6 +303,9 @@
             tools = []
             tool_choice = "none"

+        # apply any tool model_input handlers
+        input = resolve_tool_model_input(tdefs, input)
+
         # break tool image content out into user messages if the model doesn't
         # support tools returning images
         if not self.api.tool_result_images():
@@ -389,6 +399,17 @@
                 output = result
                 call = None

+            # raise error
+            if isinstance(output, Exception):
+                complete(output, call)
+
+                # Wrap the error in a runtime error which will show the
+                # request which caused the error
+                error = repr(output)
+                request = json.dumps(call.request, indent=2) if call is not None else ""
+                error_message = f"{error}\n\nRequest:\n{request}"
+                raise RuntimeError(error_message)
+
             # update output with time elapsed
             output.time = time_elapsed

@@ -464,7 +485,7 @@
         cache: Literal["read", "write"] | None,
         output: ModelOutput | None = None,
         call: ModelCall | None = None,
-    ) -> Callable[[ModelOutput, ModelCall | None], None]:
+    ) -> Callable[[ModelOutput | Exception, ModelCall | None], None]:
         from inspect_ai.log._transcript import ModelEvent, transcript

         # create event and add it to the transcript
@@ -484,13 +505,16 @@

         # callable that can be used to update the interaction w/ output
         def complete(
-            updated_output: ModelOutput, updated_call: ModelCall | None
+            result: ModelOutput | Exception, updated_call: ModelCall | None
         ) -> None:
             # trace
-            conversation_assistant_message(input, updated_output.choices[0].message)
+            if isinstance(result, ModelOutput):
+                conversation_assistant_message(input, result.choices[0].message)
+                event.output = result
+            else:
+                conversation_assistant_error(result)
+                event.error = repr(result)

-            # update event
-            event.output = updated_output
             event.call = updated_call
             event.pending = None

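Taken together, the hunks above extend the provider contract: a ModelAPI.generate implementation may now return an Exception (paired with its ModelCall) rather than raising, and Model.generate records the error on the model event via complete() before re-raising. A minimal sketch of the wrapping step, using stand-in values that are not part of the package:

import json

# stand-in provider error and request payload (hypothetical values)
output = ValueError("content filter rejected prompt")
call_request = {"model": "gpt-4o", "messages": [{"role": "user", "content": "hi"}]}

# mirrors the code above: repr() of the error plus the pretty-printed request
error_message = f"{repr(output)}\n\nRequest:\n{json.dumps(call_request, indent=2)}"
raise RuntimeError(error_message)  # intentionally raises, as Model.generate does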
@@ -703,6 +727,40 @@ def simple_input_messages(
     return messages


+def resolve_tool_model_input(
+    tdefs: list[ToolDef], messages: list[ChatMessage]
+) -> list[ChatMessage]:
+    # filter on tooldefs that have a model input handler
+    tdefs = [tdef for tdef in tdefs if tdef.model_input is not None]
+
+    # bail if there are no handlers
+    if len(tdefs) == 0:
+        return messages
+
+    # don't mutate the original messages
+    messages = deepcopy(messages)
+
+    # extract tool messages
+    tool_messages = [
+        message for message in messages if isinstance(message, ChatMessageTool)
+    ]
+    # run model_input handlers over all tool_messages with the same function name
+    for tdef in tdefs:
+        assert tdef.model_input
+        # filter messages down to just this tool
+        tdef_tool_messages = [
+            message for message in tool_messages if message.function == tdef.name
+        ]
+        # call the function for each tool, passing the index, total, and content
+        for index, message in enumerate(tdef_tool_messages):
+            message.content = tdef.model_input(
+                index, len(tool_messages), message.content
+            )
+
+    # return modified messages
+    return messages
+
+
 def tool_result_images_as_user_message(
     messages: list[ChatMessage],
 ) -> list[ChatMessage]:
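The handler contract implied by resolve_tool_model_input is (index, total, content) -> content, applied to each tool message produced by a given tool. A hypothetical handler (not part of the package) that keeps only the most recent tool output and elides earlier ones:

from inspect_ai._util.content import Content, ContentText

def keep_last_tool_output(
    index: int, total: int, content: str | list[Content]
) -> str | list[Content]:
    # keep the newest output verbatim; replace older outputs with a stub
    if index == total - 1:
        return content
    return [ContentText(text="(earlier tool output elided)")]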
@@ -713,16 +771,21 @@ def tool_result_images_reducer(
     messages: list[ChatMessage],
     message: ChatMessage,
 ) -> list[ChatMessage]:
-    # append the message
-    messages.append(message)
-
     # if there are tool result images, pull them out into a ChatUserMessage
     if isinstance(message, ChatMessageTool) and isinstance(message.content, list):
+        tool_message = ChatMessageTool(
+            content=message.content.copy(),
+            tool_call_id=message.tool_call_id,
+            function=message.function,
+        )
+        assert isinstance(tool_message.content, list)
+        messages.append(tool_message)
+
         user_content: list[Content] = []
-        for i in range(0, len(message.content)):
-            if isinstance(message.content[i], ContentImage):
+        for i in range(0, len(tool_message.content)):
+            if isinstance(tool_message.content[i], ContentImage):
                 user_content.append(message.content[i])
-                message.content[i] = ContentText(
+                tool_message.content[i] = ContentText(
                     text="Image content is in the message below."
                 )
         if len(user_content) > 0:
@@ -730,6 +793,9 @@
                 ChatMessageUser(content=user_content, tool_call_id=message.tool_call_id)
             )

+    else:
+        messages.append(message)
+
     # return messages
     return messages

@@ -813,6 +879,24 @@ def active_model() -> Model | None:
 active_model_context_var: ContextVar[Model] = ContextVar("active_model")


+def handle_sample_message_limit(input: str | list[ChatMessage]) -> None:
+    from inspect_ai.log._samples import (
+        active_sample_message_limit,
+        set_active_sample_total_messages,
+    )
+
+    total_messages = 1 if isinstance(input, str) else len(input)
+    message_limit = active_sample_message_limit()
+    if message_limit is not None:
+        if total_messages >= message_limit:
+            raise SampleLimitExceededError(
+                "message", value=total_messages, limit=message_limit
+            )
+
+    # set total messages
+    set_active_sample_total_messages(total_messages)
+
+
 def init_model_usage() -> None:
     model_usage_context_var.set({})
@@ -822,13 +906,28 @@ def init_sample_model_usage() -> None:


 def record_model_usage(model: str, usage: ModelUsage) -> None:
+    from inspect_ai.log._samples import (
+        active_sample_token_limit,
+        set_active_sample_total_tokens,
+    )
+
+    # record usage
     set_model_usage(model, usage, sample_model_usage_context_var.get(None))
     set_model_usage(model, usage, model_usage_context_var.get(None))

-    #
-
+    # compute total tokens
+    total_tokens = sample_total_tokens()

-
+    # update active sample
+    set_active_sample_total_tokens(total_tokens)
+
+    # check for token limit overflow and raise
+    token_limit = active_sample_token_limit()
+    if token_limit is not None:
+        if total_tokens > token_limit:
+            raise SampleLimitExceededError(
+                "token", value=total_tokens, limit=token_limit
+            )


 def set_model_usage(
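Both checks raise the new SampleLimitExceededError (see inspect_ai/util/_limit.py, +26 -0 in the file list). A minimal sketch of the message-limit logic in isolation, assuming only the constructor signature visible in this diff:

from inspect_ai.util._limit import SampleLimitExceededError

def check_message_limit(total_messages: int, message_limit: int | None) -> None:
    # >= mirrors handle_sample_message_limit: the check runs before the next
    # generate, so reaching the limit means no further generation is allowed
    if message_limit is not None and total_messages >= message_limit:
        raise SampleLimitExceededError(
            "message", value=total_messages, limit=message_limit
        )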
inspect_ai/model/_model_output.py
CHANGED
@@ -26,9 +26,14 @@ class ModelUsage(BaseModel):


 StopReason = Literal[
-    "stop",
+    "stop",
+    "max_tokens",
+    "model_length",
+    "tool_calls",
+    "content_filter",
+    "unknown",
 ]
-"""Reason that the model stopped
+"""Reason that the model stopped or failed to generate."""


 class TopLogprob(BaseModel):
@@ -209,3 +214,18 @@ class ModelOutput(BaseModel):
             )
         ],
     )
+
+
+def as_stop_reason(reason: str | None) -> StopReason:
+    """Encode common reason strings into standard StopReason."""
+    match reason:
+        case "stop" | "eos":
+            return "stop"
+        case "length":
+            return "max_tokens"
+        case "tool_calls" | "function_call":
+            return "tool_calls"
+        case "content_filter" | "model_length" | "max_tokens":
+            return reason
+        case _:
+            return "unknown"
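For reference, the normalization above maps common provider finish reasons as follows (sanity checks derived directly from the match cases; the import path is the internal module):

from inspect_ai.model._model_output import as_stop_reason

assert as_stop_reason("eos") == "stop"
assert as_stop_reason("length") == "max_tokens"
assert as_stop_reason("function_call") == "tool_calls"
assert as_stop_reason("model_length") == "model_length"
assert as_stop_reason(None) == "unknown"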
inspect_ai/model/_openai.py
ADDED
@@ -0,0 +1,383 @@
+import json
+from typing import Literal
+
+from openai.types.chat import (
+    ChatCompletion,
+    ChatCompletionAssistantMessageParam,
+    ChatCompletionContentPartImageParam,
+    ChatCompletionContentPartInputAudioParam,
+    ChatCompletionContentPartParam,
+    ChatCompletionContentPartRefusalParam,
+    ChatCompletionContentPartTextParam,
+    ChatCompletionDeveloperMessageParam,
+    ChatCompletionMessage,
+    ChatCompletionMessageParam,
+    ChatCompletionMessageToolCall,
+    ChatCompletionMessageToolCallParam,
+    ChatCompletionNamedToolChoiceParam,
+    ChatCompletionSystemMessageParam,
+    ChatCompletionToolChoiceOptionParam,
+    ChatCompletionToolMessageParam,
+    ChatCompletionToolParam,
+    ChatCompletionUserMessageParam,
+)
+from openai.types.chat.chat_completion import Choice, ChoiceLogprobs
+from openai.types.chat.chat_completion_message_tool_call import Function
+from openai.types.completion_usage import CompletionUsage
+from openai.types.shared_params.function_definition import FunctionDefinition
+
+from inspect_ai._util.content import Content, ContentAudio, ContentImage, ContentText
+from inspect_ai._util.images import file_as_data_uri
+from inspect_ai._util.url import is_http_url
+from inspect_ai.model._call_tools import parse_tool_call
+from inspect_ai.model._model_output import ChatCompletionChoice, Logprobs
+from inspect_ai.tool import ToolCall, ToolChoice, ToolFunction, ToolInfo
+
+from ._chat_message import (
+    ChatMessage,
+    ChatMessageAssistant,
+    ChatMessageSystem,
+    ChatMessageTool,
+    ChatMessageUser,
+)
+from ._model_output import ModelUsage, StopReason, as_stop_reason
+
+
+def is_o1(name: str) -> bool:
+    return name.startswith("o1")
+
+
+def is_o1_full(name: str) -> bool:
+    return is_o1(name) and not is_o1_mini(name) and not is_o1_preview(name)
+
+
+def is_o1_mini(name: str) -> bool:
+    return name.startswith("o1-mini")
+
+
+def is_o1_preview(name: str) -> bool:
+    return name.startswith("o1-preview")
+
+
+def openai_chat_tool_call(tool_call: ToolCall) -> ChatCompletionMessageToolCall:
+    return ChatCompletionMessageToolCall(
+        type="function",
+        id=tool_call.id,
+        function=Function(
+            name=tool_call.function, arguments=json.dumps(tool_call.arguments)
+        ),
+    )
+
+
+def openai_chat_tool_call_param(
+    tool_call: ToolCall,
+) -> ChatCompletionMessageToolCallParam:
+    return ChatCompletionMessageToolCallParam(
+        id=tool_call.id,
+        function=dict(
+            name=tool_call.function, arguments=json.dumps(tool_call.arguments)
+        ),
+        type=tool_call.type,
+    )
+
+
+async def openai_chat_completion_part(
+    content: Content,
+) -> ChatCompletionContentPartParam:
+    if content.type == "text":
+        return ChatCompletionContentPartTextParam(type="text", text=content.text)
+    elif content.type == "image":
+        # API takes URL or base64 encoded file. If it's a remote file or
+        # data URL leave it alone, otherwise encode it
+        image_url = content.image
+        detail = content.detail
+
+        if not is_http_url(image_url):
+            image_url = await file_as_data_uri(image_url)
+
+        return ChatCompletionContentPartImageParam(
+            type="image_url",
+            image_url=dict(url=image_url, detail=detail),
+        )
+    elif content.type == "audio":
+        audio_data = await file_as_data_uri(content.audio)
+
+        return ChatCompletionContentPartInputAudioParam(
+            type="input_audio", input_audio=dict(data=audio_data, format=content.format)
+        )
+
+    else:
+        raise RuntimeError(
+            "Video content is not currently supported by Open AI chat models."
+        )
+
+
+async def openai_chat_message(
+    message: ChatMessage, model: str
+) -> ChatCompletionMessageParam:
+    if message.role == "system":
+        if is_o1(model):
+            return ChatCompletionDeveloperMessageParam(
+                role="developer", content=message.text
+            )
+        else:
+            return ChatCompletionSystemMessageParam(
+                role=message.role, content=message.text
+            )
+    elif message.role == "user":
+        return ChatCompletionUserMessageParam(
+            role=message.role,
+            content=(
+                message.content
+                if isinstance(message.content, str)
+                else [
+                    await openai_chat_completion_part(content)
+                    for content in message.content
+                ]
+            ),
+        )
+    elif message.role == "assistant":
+        if message.tool_calls:
+            return ChatCompletionAssistantMessageParam(
+                role=message.role,
+                content=message.text,
+                tool_calls=[
+                    openai_chat_tool_call_param(call) for call in message.tool_calls
+                ],
+            )
+        else:
+            return ChatCompletionAssistantMessageParam(
+                role=message.role, content=message.text
+            )
+    elif message.role == "tool":
+        return ChatCompletionToolMessageParam(
+            role=message.role,
+            content=(
+                f"Error: {message.error.message}" if message.error else message.text
+            ),
+            tool_call_id=str(message.tool_call_id),
+        )
+    else:
+        raise ValueError(f"Unexpected message role {message.role}")
+
+
+async def openai_chat_messages(
+    messages: list[ChatMessage], model: str
+) -> list[ChatCompletionMessageParam]:
+    return [await openai_chat_message(message, model) for message in messages]
+
+
+def openai_chat_choices(choices: list[ChatCompletionChoice]) -> list[Choice]:
+    oai_choices: list[Choice] = []
+
+    for index, choice in enumerate(choices):
+        if isinstance(choice.message.content, str):
+            content = choice.message.content
+        else:
+            content = "\n".join(
+                [c.text for c in choice.message.content if c.type == "text"]
+            )
+        if choice.message.tool_calls:
+            tool_calls = [openai_chat_tool_call(tc) for tc in choice.message.tool_calls]
+        else:
+            tool_calls = None
+        message = ChatCompletionMessage(
+            role="assistant", content=content, tool_calls=tool_calls
+        )
+        oai_choices.append(
+            Choice(
+                finish_reason=openai_finish_reason(choice.stop_reason),
+                index=index,
+                message=message,
+                logprobs=ChoiceLogprobs(**choice.logprobs.model_dump())
+                if choice.logprobs is not None
+                else None,
+            )
+        )
+
+    return oai_choices
+
+
+def openai_completion_usage(usage: ModelUsage) -> CompletionUsage:
+    return CompletionUsage(
+        completion_tokens=usage.output_tokens,
+        prompt_tokens=usage.input_tokens,
+        total_tokens=usage.total_tokens,
+    )
+
+
+def openai_finish_reason(
+    stop_reason: StopReason,
+) -> Literal["stop", "length", "tool_calls", "content_filter", "function_call"]:
+    match stop_reason:
+        case "stop" | "tool_calls" | "content_filter":
+            return stop_reason
+        case "model_length":
+            return "length"
+        case _:
+            return "stop"
+
+
+def openai_chat_tool_param(tool: ToolInfo) -> ChatCompletionToolParam:
+    function = FunctionDefinition(
+        name=tool.name,
+        description=tool.description,
+        parameters=tool.parameters.model_dump(exclude_none=True),
+    )
+    return ChatCompletionToolParam(type="function", function=function)
+
+
+def openai_chat_tools(tools: list[ToolInfo]) -> list[ChatCompletionToolParam]:
+    return [openai_chat_tool_param(tool) for tool in tools]
+
+
+def openai_chat_tool_choice(
+    tool_choice: ToolChoice,
+) -> ChatCompletionToolChoiceOptionParam:
+    if isinstance(tool_choice, ToolFunction):
+        return ChatCompletionNamedToolChoiceParam(
+            type="function", function=dict(name=tool_choice.name)
+        )
+    # openai supports 'any' via the 'required' keyword
+    elif tool_choice == "any":
+        return "required"
+    else:
+        return tool_choice
+
+
+def chat_tool_calls_from_openai(
+    message: ChatCompletionMessage, tools: list[ToolInfo]
+) -> list[ToolCall] | None:
+    if message.tool_calls:
+        return [
+            parse_tool_call(call.id, call.function.name, call.function.arguments, tools)
+            for call in message.tool_calls
+        ]
+    else:
+        return None
+
+
+def chat_messages_from_openai(
+    messages: list[ChatCompletionMessageParam],
+) -> list[ChatMessage]:
+    # track tool names by id
+    tool_names: dict[str, str] = {}
+
+    chat_messages: list[ChatMessage] = []
+
+    for message in messages:
+        if message["role"] == "system" or message["role"] == "developer":
+            sys_content = message["content"]
+            if isinstance(sys_content, str):
+                chat_messages.append(ChatMessageSystem(content=sys_content))
+            else:
+                chat_messages.append(
+                    ChatMessageSystem(
+                        content=[content_from_openai(c) for c in sys_content]
+                    )
+                )
+        elif message["role"] == "user":
+            user_content = message["content"]
+            if isinstance(user_content, str):
+                chat_messages.append(ChatMessageUser(content=user_content))
+            else:
+                chat_messages.append(
+                    ChatMessageUser(
+                        content=[content_from_openai(c) for c in user_content]
+                    )
+                )
+        elif message["role"] == "assistant":
+            # resolve content
+            asst_content = message["content"]
+            if isinstance(asst_content, str):
+                content: str | list[Content] = asst_content
+            elif asst_content is None:
+                content = message.get("refusal", None) or ""
+            else:
+                content = [content_from_openai(c) for c in asst_content]
+
+            # return message
+            if "tool_calls" in message:
+                tool_calls: list[ToolCall] = []
+                for tc in message["tool_calls"]:
+                    tool_calls.append(tool_call_from_openai(tc))
+                    tool_names[tc["id"]] = tc["function"]["name"]
+
+            else:
+                tool_calls = []
+            chat_messages.append(
+                ChatMessageAssistant(content=content, tool_calls=tool_calls or None)
+            )
+        elif message["role"] == "tool":
+            tool_content = message.get("content", None) or ""
+            if isinstance(tool_content, str):
+                content = tool_content
+            else:
+                content = [content_from_openai(c) for c in tool_content]
+            chat_messages.append(
+                ChatMessageTool(
+                    content=content,
+                    tool_call_id=message["tool_call_id"],
+                    function=tool_names.get(message["tool_call_id"], ""),
+                )
+            )
+        else:
+            raise ValueError(f"Unexpected message param type: {type(message)}")
+
+    return chat_messages
+
+
+def tool_call_from_openai(tool_call: ChatCompletionMessageToolCallParam) -> ToolCall:
+    return parse_tool_call(
+        tool_call["id"],
+        tool_call["function"]["name"],
+        tool_call["function"]["arguments"],
+    )
+
+
+def content_from_openai(
+    content: ChatCompletionContentPartParam | ChatCompletionContentPartRefusalParam,
+) -> Content:
+    if content["type"] == "text":
+        return ContentText(text=content["text"])
+    elif content["type"] == "image_url":
+        return ContentImage(
+            image=content["image_url"]["url"], detail=content["image_url"]["detail"]
+        )
+    elif content["type"] == "input_audio":
+        return ContentAudio(
+            audio=content["input_audio"]["data"],
+            format=content["input_audio"]["format"],
+        )
+    elif content["type"] == "refusal":
+        return ContentText(text=content["refusal"])
+
+
+def chat_message_assistant_from_openai(
+    message: ChatCompletionMessage, tools: list[ToolInfo]
+) -> ChatMessageAssistant:
+    refusal = getattr(message, "refusal", None)
+    return ChatMessageAssistant(
+        content=refusal or message.content or "",
+        source="generate",
+        tool_calls=chat_tool_calls_from_openai(message, tools),
+    )
+
+
+def chat_choices_from_openai(
+    response: ChatCompletion, tools: list[ToolInfo]
+) -> list[ChatCompletionChoice]:
+    choices = list(response.choices)
+    choices.sort(key=lambda c: c.index)
+    return [
+        ChatCompletionChoice(
+            message=chat_message_assistant_from_openai(choice.message, tools),
+            stop_reason=as_stop_reason(choice.finish_reason),
+            logprobs=(
+                Logprobs(**choice.logprobs.model_dump())
+                if choice.logprobs is not None
+                else None
+            ),
+        )
+        for choice in choices
+    ]
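A short usage sketch of the new conversion helpers, assuming the internal import path inspect_ai.model._openai (not part of the public API, so it may change):

import asyncio

from inspect_ai.model._chat_message import ChatMessageSystem, ChatMessageUser
from inspect_ai.model._openai import chat_messages_from_openai, openai_chat_messages

messages = [
    ChatMessageSystem(content="Answer concisely."),
    ChatMessageUser(content="What is 2 + 2?"),
]

# inspect_ai messages -> OpenAI chat params (o1 models get a "developer" message)
oai_params = asyncio.run(openai_chat_messages(messages, model="gpt-4o"))
# [{'role': 'system', 'content': 'Answer concisely.'},
#  {'role': 'user', 'content': 'What is 2 + 2?'}]

# OpenAI chat params -> inspect_ai messages
round_tripped = chat_messages_from_openai(oai_params)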