inspect-ai 0.3.55__py3-none-any.whl → 0.3.57__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- inspect_ai/__init__.py +1 -0
- inspect_ai/_cli/common.py +1 -1
- inspect_ai/_cli/trace.py +33 -20
- inspect_ai/_display/core/active.py +1 -1
- inspect_ai/_display/core/display.py +1 -1
- inspect_ai/_display/core/footer.py +1 -1
- inspect_ai/_display/core/panel.py +1 -1
- inspect_ai/_display/core/progress.py +0 -6
- inspect_ai/_display/core/rich.py +1 -1
- inspect_ai/_display/rich/display.py +2 -2
- inspect_ai/_display/textual/app.py +15 -17
- inspect_ai/_display/textual/widgets/clock.py +3 -3
- inspect_ai/_display/textual/widgets/samples.py +6 -13
- inspect_ai/_eval/context.py +9 -1
- inspect_ai/_eval/run.py +16 -11
- inspect_ai/_eval/score.py +4 -10
- inspect_ai/_eval/task/results.py +5 -4
- inspect_ai/_eval/task/run.py +6 -12
- inspect_ai/_eval/task/task.py +10 -0
- inspect_ai/_util/ansi.py +31 -0
- inspect_ai/_util/datetime.py +1 -1
- inspect_ai/_util/deprecation.py +1 -1
- inspect_ai/_util/format.py +7 -0
- inspect_ai/_util/json.py +11 -1
- inspect_ai/_util/logger.py +14 -13
- inspect_ai/_util/throttle.py +10 -1
- inspect_ai/_util/trace.py +79 -47
- inspect_ai/_util/transcript.py +37 -4
- inspect_ai/_util/vscode.py +51 -0
- inspect_ai/_view/notify.py +2 -1
- inspect_ai/_view/www/.prettierrc.js +12 -0
- inspect_ai/_view/www/App.css +22 -1
- inspect_ai/_view/www/dist/assets/index.css +2374 -2
- inspect_ai/_view/www/dist/assets/index.js +29752 -24492
- inspect_ai/_view/www/log-schema.json +262 -215
- inspect_ai/_view/www/package.json +1 -0
- inspect_ai/_view/www/src/App.mjs +19 -9
- inspect_ai/_view/www/src/Types.mjs +0 -1
- inspect_ai/_view/www/src/api/Types.mjs +15 -4
- inspect_ai/_view/www/src/api/api-http.mjs +2 -0
- inspect_ai/_view/www/src/appearance/Icons.mjs +2 -0
- inspect_ai/_view/www/src/components/AsciiCinemaPlayer.mjs +74 -0
- inspect_ai/_view/www/src/components/CopyButton.mjs +0 -1
- inspect_ai/_view/www/src/components/ExpandablePanel.mjs +2 -2
- inspect_ai/_view/www/src/components/FindBand.mjs +5 -4
- inspect_ai/_view/www/src/components/HumanBaselineView.mjs +168 -0
- inspect_ai/_view/www/src/components/LargeModal.mjs +1 -1
- inspect_ai/_view/www/src/components/LightboxCarousel.mjs +217 -0
- inspect_ai/_view/www/src/components/MessageContent.mjs +1 -1
- inspect_ai/_view/www/src/components/TabSet.mjs +1 -1
- inspect_ai/_view/www/src/components/Tools.mjs +28 -5
- inspect_ai/_view/www/src/components/VirtualList.mjs +15 -17
- inspect_ai/_view/www/src/log/remoteLogFile.mjs +2 -1
- inspect_ai/_view/www/src/navbar/Navbar.mjs +44 -32
- inspect_ai/_view/www/src/samples/SampleDisplay.mjs +1 -2
- inspect_ai/_view/www/src/samples/SampleList.mjs +35 -4
- inspect_ai/_view/www/src/samples/SampleScoreView.mjs +13 -2
- inspect_ai/_view/www/src/samples/SampleScores.mjs +11 -2
- inspect_ai/_view/www/src/samples/SamplesDescriptor.mjs +238 -178
- inspect_ai/_view/www/src/samples/SamplesTab.mjs +4 -2
- inspect_ai/_view/www/src/samples/tools/SampleFilter.mjs +5 -5
- inspect_ai/_view/www/src/samples/tools/SelectScorer.mjs +7 -0
- inspect_ai/_view/www/src/samples/tools/SortFilter.mjs +3 -3
- inspect_ai/_view/www/src/samples/transcript/ModelEventView.mjs +3 -2
- inspect_ai/_view/www/src/samples/transcript/ToolEventView.mjs +1 -1
- inspect_ai/_view/www/src/samples/transcript/TranscriptView.mjs +1 -0
- inspect_ai/_view/www/src/samples/transcript/state/StateEventRenderers.mjs +56 -0
- inspect_ai/_view/www/src/samples/transcript/state/StateEventView.mjs +17 -5
- inspect_ai/_view/www/src/types/asciicinema-player.d.ts +26 -0
- inspect_ai/_view/www/src/types/log.d.ts +28 -20
- inspect_ai/_view/www/src/workspace/WorkSpace.mjs +1 -1
- inspect_ai/_view/www/yarn.lock +44 -0
- inspect_ai/approval/_apply.py +4 -0
- inspect_ai/approval/_human/panel.py +5 -8
- inspect_ai/dataset/_dataset.py +51 -10
- inspect_ai/dataset/_util.py +31 -3
- inspect_ai/log/__init__.py +2 -0
- inspect_ai/log/_log.py +30 -2
- inspect_ai/log/_recorders/eval.py +2 -0
- inspect_ai/model/_call_tools.py +31 -7
- inspect_ai/model/_chat_message.py +3 -0
- inspect_ai/model/_model.py +42 -1
- inspect_ai/model/_providers/anthropic.py +4 -0
- inspect_ai/model/_providers/google.py +24 -6
- inspect_ai/model/_providers/openai.py +17 -3
- inspect_ai/model/_providers/openai_o1.py +10 -12
- inspect_ai/model/_render.py +9 -2
- inspect_ai/scorer/_metric.py +12 -1
- inspect_ai/solver/__init__.py +2 -0
- inspect_ai/solver/_human_agent/agent.py +83 -0
- inspect_ai/solver/_human_agent/commands/__init__.py +36 -0
- inspect_ai/solver/_human_agent/commands/clock.py +70 -0
- inspect_ai/solver/_human_agent/commands/command.py +59 -0
- inspect_ai/solver/_human_agent/commands/instructions.py +74 -0
- inspect_ai/solver/_human_agent/commands/note.py +42 -0
- inspect_ai/solver/_human_agent/commands/score.py +80 -0
- inspect_ai/solver/_human_agent/commands/status.py +62 -0
- inspect_ai/solver/_human_agent/commands/submit.py +151 -0
- inspect_ai/solver/_human_agent/install.py +222 -0
- inspect_ai/solver/_human_agent/panel.py +252 -0
- inspect_ai/solver/_human_agent/service.py +45 -0
- inspect_ai/solver/_human_agent/state.py +55 -0
- inspect_ai/solver/_human_agent/view.py +24 -0
- inspect_ai/solver/_task_state.py +28 -2
- inspect_ai/tool/_tool.py +10 -2
- inspect_ai/tool/_tool_info.py +2 -1
- inspect_ai/tool/_tools/_web_browser/_resources/dm_env_servicer.py +9 -9
- inspect_ai/tool/_tools/_web_browser/_web_browser.py +16 -13
- inspect_ai/util/__init__.py +12 -4
- inspect_ai/{_util/display.py → util/_display.py} +6 -0
- inspect_ai/util/_panel.py +31 -9
- inspect_ai/util/_sandbox/__init__.py +0 -3
- inspect_ai/util/_sandbox/context.py +5 -1
- inspect_ai/util/_sandbox/docker/compose.py +17 -13
- inspect_ai/util/_sandbox/docker/docker.py +9 -6
- inspect_ai/util/_sandbox/docker/internal.py +1 -1
- inspect_ai/util/_sandbox/docker/util.py +3 -2
- inspect_ai/util/_sandbox/environment.py +6 -5
- inspect_ai/util/_sandbox/local.py +1 -1
- inspect_ai/util/_sandbox/self_check.py +18 -18
- inspect_ai/util/_sandbox/service.py +22 -7
- inspect_ai/util/_store.py +7 -8
- inspect_ai/util/_store_model.py +110 -0
- inspect_ai/util/_subprocess.py +3 -3
- inspect_ai/util/_throttle.py +32 -0
- {inspect_ai-0.3.55.dist-info → inspect_ai-0.3.57.dist-info}/METADATA +3 -3
- {inspect_ai-0.3.55.dist-info → inspect_ai-0.3.57.dist-info}/RECORD +131 -108
- {inspect_ai-0.3.55.dist-info → inspect_ai-0.3.57.dist-info}/WHEEL +1 -1
- {inspect_ai-0.3.55.dist-info → inspect_ai-0.3.57.dist-info}/LICENSE +0 -0
- {inspect_ai-0.3.55.dist-info → inspect_ai-0.3.57.dist-info}/entry_points.txt +0 -0
- {inspect_ai-0.3.55.dist-info → inspect_ai-0.3.57.dist-info}/top_level.txt +0 -0
inspect_ai/log/__init__.py
CHANGED
@@ -23,6 +23,7 @@ from ._log import (
|
|
23
23
|
EvalRevision,
|
24
24
|
EvalSample,
|
25
25
|
EvalSampleReductions,
|
26
|
+
EvalSampleScore,
|
26
27
|
EvalScore,
|
27
28
|
EvalSpec,
|
28
29
|
EvalStats,
|
@@ -60,6 +61,7 @@ __all__ = [
|
|
60
61
|
"EvalResults",
|
61
62
|
"EvalRevision",
|
62
63
|
"EvalSample",
|
64
|
+
"EvalSampleScore",
|
63
65
|
"EvalSampleReductions",
|
64
66
|
"EvalScore",
|
65
67
|
"EvalSpec",
|
inspect_ai/log/_log.py
CHANGED
@@ -16,6 +16,7 @@ from inspect_ai._util.constants import CONSOLE_DISPLAY_WIDTH, PKG_NAME
|
|
16
16
|
from inspect_ai._util.error import EvalError, exception_message
|
17
17
|
from inspect_ai._util.logger import warn_once
|
18
18
|
from inspect_ai.approval._policy import ApprovalPolicyConfig
|
19
|
+
from inspect_ai.dataset._dataset import MT, metadata_as
|
19
20
|
from inspect_ai.model import (
|
20
21
|
ChatMessage,
|
21
22
|
GenerateConfig,
|
@@ -23,8 +24,9 @@ from inspect_ai.model import (
|
|
23
24
|
ModelUsage,
|
24
25
|
)
|
25
26
|
from inspect_ai.scorer import Score
|
26
|
-
from inspect_ai.scorer._metric import SampleScore
|
27
27
|
from inspect_ai.util._sandbox.environment import SandboxEnvironmentSpec
|
28
|
+
from inspect_ai.util._store import Store
|
29
|
+
from inspect_ai.util._store_model import SMT
|
28
30
|
|
29
31
|
from ._transcript import Event
|
30
32
|
|
@@ -159,9 +161,31 @@ class EvalSample(BaseModel):
|
|
159
161
|
metadata: dict[str, Any]
|
160
162
|
"""Additional sample metadata."""
|
161
163
|
|
164
|
+
def metadata_as(self, metadata_cls: Type[MT]) -> MT:
|
165
|
+
"""Pydantic model interface to metadata.
|
166
|
+
|
167
|
+
Args:
|
168
|
+
metadata_cls: Pydantic model type
|
169
|
+
|
170
|
+
Returns:
|
171
|
+
BaseModel: Instance of metadata_cls bound to sample metadata.
|
172
|
+
"""
|
173
|
+
return metadata_as(self.metadata, metadata_cls)
|
174
|
+
|
162
175
|
store: dict[str, Any] = Field(default_factory=dict)
|
163
176
|
"""State at end of sample execution."""
|
164
177
|
|
178
|
+
def store_as(self, model_cls: Type[SMT]) -> SMT:
|
179
|
+
"""Pydantic model interface to the store.
|
180
|
+
|
181
|
+
Args:
|
182
|
+
model_cls: Pydantic model type (must derive from StoreModel)
|
183
|
+
|
184
|
+
Returns:
|
185
|
+
StoreModel: Instance of model_cls bound to sample store data.
|
186
|
+
"""
|
187
|
+
return model_cls(store=Store(self.store))
|
188
|
+
|
165
189
|
events: list[Event] = Field(default_factory=list)
|
166
190
|
"""Events that occurred during sample execution."""
|
167
191
|
|
@@ -301,6 +325,10 @@ class EvalScore(BaseModel):
|
|
301
325
|
"""Additional scorer metadata."""
|
302
326
|
|
303
327
|
|
328
|
+
class EvalSampleScore(Score):
|
329
|
+
sample_id: str | int | None = Field(default=None)
|
330
|
+
|
331
|
+
|
304
332
|
class EvalSampleReductions(BaseModel):
|
305
333
|
scorer: str
|
306
334
|
"""Name the of scorer"""
|
@@ -308,7 +336,7 @@ class EvalSampleReductions(BaseModel):
|
|
308
336
|
reducer: str | None = Field(default=None)
|
309
337
|
"""Name the of reducer"""
|
310
338
|
|
311
|
-
samples: list[
|
339
|
+
samples: list[EvalSampleScore]
|
312
340
|
"""List of reduced scores"""
|
313
341
|
|
314
342
|
|
@@ -252,6 +252,8 @@ def text_inputs(inputs: str | list[ChatMessage]) -> str | list[ChatMessage]:
|
|
252
252
|
filtered_content.append(ContentText(text="(Image)"))
|
253
253
|
message.content = filtered_content
|
254
254
|
input.append(message)
|
255
|
+
else:
|
256
|
+
input.append(message)
|
255
257
|
|
256
258
|
return input
|
257
259
|
else:
|
inspect_ai/model/_call_tools.py
CHANGED
@@ -1,15 +1,20 @@
|
|
1
1
|
import asyncio
|
2
2
|
import inspect
|
3
|
+
import types
|
3
4
|
from dataclasses import is_dataclass
|
4
5
|
from logging import getLogger
|
5
6
|
from textwrap import dedent
|
7
|
+
from types import UnionType
|
6
8
|
from typing import (
|
7
9
|
Any,
|
8
10
|
Callable,
|
9
11
|
Dict,
|
10
12
|
List,
|
11
13
|
NamedTuple,
|
14
|
+
Optional,
|
15
|
+
Tuple,
|
12
16
|
Type,
|
17
|
+
Union,
|
13
18
|
get_args,
|
14
19
|
get_origin,
|
15
20
|
get_type_hints,
|
@@ -25,10 +30,7 @@ from inspect_ai._util.text import truncate_string_to_bytes
|
|
25
30
|
from inspect_ai._util.trace import trace_action
|
26
31
|
from inspect_ai.model._trace import trace_tool_mesage
|
27
32
|
from inspect_ai.tool import Tool, ToolCall, ToolError, ToolInfo
|
28
|
-
from inspect_ai.tool._tool import
|
29
|
-
ToolApprovalError,
|
30
|
-
ToolParsingError,
|
31
|
-
)
|
33
|
+
from inspect_ai.tool._tool import ToolApprovalError, ToolParsingError
|
32
34
|
from inspect_ai.tool._tool_call import ToolCallContent, ToolCallError
|
33
35
|
from inspect_ai.tool._tool_def import ToolDef, tool_defs
|
34
36
|
from inspect_ai.tool._tool_info import parse_docstring
|
@@ -118,10 +120,12 @@ async def call_tools(
|
|
118
120
|
# massage result, leave list[Content] alone, convert all other
|
119
121
|
# types to string as that is what the model APIs accept
|
120
122
|
truncated: tuple[int, int] | None = None
|
121
|
-
if isinstance(result,
|
123
|
+
if isinstance(result, ContentText | ContentImage):
|
124
|
+
content: str | list[Content] = [result]
|
125
|
+
elif isinstance(result, list) and (
|
122
126
|
isinstance(result[0], ContentText | ContentImage)
|
123
127
|
):
|
124
|
-
content
|
128
|
+
content = result
|
125
129
|
else:
|
126
130
|
content = str(result)
|
127
131
|
|
@@ -266,6 +270,16 @@ def disable_parallel_tools(
|
|
266
270
|
return False
|
267
271
|
|
268
272
|
|
273
|
+
def type_hint_includes_none(type_hint: Type[Any] | None) -> bool:
|
274
|
+
origin = get_origin(type_hint)
|
275
|
+
|
276
|
+
if origin in {Union, UnionType}:
|
277
|
+
return type(None) in get_args(type_hint)
|
278
|
+
elif origin is Optional:
|
279
|
+
return True
|
280
|
+
return False
|
281
|
+
|
282
|
+
|
269
283
|
def tool_params(input: dict[str, Any], func: Callable[..., Any]) -> dict[str, Any]:
|
270
284
|
# parse function typeinfo
|
271
285
|
signature = inspect.signature(func)
|
@@ -294,7 +308,7 @@ def tool_params(input: dict[str, Any], func: Callable[..., Any]) -> dict[str, An
|
|
294
308
|
# yield parameter (fail if not passed and there is no default)
|
295
309
|
if param_name in input:
|
296
310
|
params[param_name] = tool_param(type_hint, input.get(param_name))
|
297
|
-
elif param.default is not None:
|
311
|
+
elif param.default is not None or type_hint_includes_none(type_hint):
|
298
312
|
params[param_name] = param.default
|
299
313
|
else:
|
300
314
|
raise ToolParsingError(
|
@@ -337,11 +351,21 @@ def tool_param(type_hint: Type[Any], input: Any) -> Any:
|
|
337
351
|
return [tool_param(args[0], x) for x in input]
|
338
352
|
else:
|
339
353
|
return input
|
354
|
+
elif origin is tuple or origin is Tuple:
|
355
|
+
if args:
|
356
|
+
return tuple([tool_param(args[0], x) for x in input])
|
357
|
+
else:
|
358
|
+
return tuple(input)
|
340
359
|
elif origin is dict or origin is Dict:
|
341
360
|
if args and len(args) > 1:
|
342
361
|
return {k: tool_param(args[1], v) for k, v in input}
|
343
362
|
else:
|
344
363
|
return input
|
364
|
+
elif origin is Union or origin is types.UnionType:
|
365
|
+
if args[1] is type(None):
|
366
|
+
return tool_param(args[0], input)
|
367
|
+
else:
|
368
|
+
return input
|
345
369
|
else:
|
346
370
|
return input
|
347
371
|
|
@@ -74,6 +74,9 @@ class ChatMessageUser(ChatMessageBase):
|
|
74
74
|
role: Literal["user"] = Field(default="user")
|
75
75
|
"""Conversation role."""
|
76
76
|
|
77
|
+
tool_call_id: str | None = Field(default=None)
|
78
|
+
"""ID of tool call this message has the content payload for."""
|
79
|
+
|
77
80
|
|
78
81
|
class ChatMessageAssistant(ChatMessageBase):
|
79
82
|
role: Literal["assistant"] = Field(default="assistant")
|
inspect_ai/model/_model.py
CHANGED
@@ -19,7 +19,7 @@ from tenacity import (
|
|
19
19
|
)
|
20
20
|
|
21
21
|
from inspect_ai._util.constants import DEFAULT_MAX_CONNECTIONS
|
22
|
-
from inspect_ai._util.content import ContentText
|
22
|
+
from inspect_ai._util.content import Content, ContentImage, ContentText
|
23
23
|
from inspect_ai._util.hooks import init_hooks, override_api_key, send_telemetry
|
24
24
|
from inspect_ai._util.platform import platform_init
|
25
25
|
from inspect_ai._util.registry import (
|
@@ -40,6 +40,7 @@ from ._chat_message import (
|
|
40
40
|
ChatMessage,
|
41
41
|
ChatMessageAssistant,
|
42
42
|
ChatMessageSystem,
|
43
|
+
ChatMessageTool,
|
43
44
|
ChatMessageUser,
|
44
45
|
)
|
45
46
|
from ._generate_config import (
|
@@ -163,6 +164,10 @@ class ModelAPI(abc.ABC):
|
|
163
164
|
"""Any tool use in a message stream means that tools must be passed."""
|
164
165
|
return False
|
165
166
|
|
167
|
+
def tool_result_images(self) -> bool:
|
168
|
+
"""Tool results can containe images"""
|
169
|
+
return False
|
170
|
+
|
166
171
|
|
167
172
|
class Model:
|
168
173
|
"""Model interface."""
|
@@ -291,6 +296,11 @@ class Model:
|
|
291
296
|
tools = []
|
292
297
|
tool_choice = "none"
|
293
298
|
|
299
|
+
# break tool image content out into user messages if the model doesn't
|
300
|
+
# support tools returning images
|
301
|
+
if not self.api.tool_result_images():
|
302
|
+
input = tool_result_images_as_user_message(input)
|
303
|
+
|
294
304
|
# optionally collapse *consecutive* messages into one -
|
295
305
|
# (some apis e.g. anthropic require this)
|
296
306
|
if self.api.collapse_user_messages():
|
@@ -693,6 +703,37 @@ def simple_input_messages(
|
|
693
703
|
return messages
|
694
704
|
|
695
705
|
|
706
|
+
def tool_result_images_as_user_message(
|
707
|
+
messages: list[ChatMessage],
|
708
|
+
) -> list[ChatMessage]:
|
709
|
+
return functools.reduce(tool_result_images_reducer, messages, [])
|
710
|
+
|
711
|
+
|
712
|
+
def tool_result_images_reducer(
|
713
|
+
messages: list[ChatMessage],
|
714
|
+
message: ChatMessage,
|
715
|
+
) -> list[ChatMessage]:
|
716
|
+
# append the message
|
717
|
+
messages.append(message)
|
718
|
+
|
719
|
+
# if there are tool result images, pull them out into a ChatUserMessage
|
720
|
+
if isinstance(message, ChatMessageTool) and isinstance(message.content, list):
|
721
|
+
user_content: list[Content] = []
|
722
|
+
for i in range(0, len(message.content)):
|
723
|
+
if isinstance(message.content[i], ContentImage):
|
724
|
+
user_content.append(message.content[i])
|
725
|
+
message.content[i] = ContentText(
|
726
|
+
text="Image content is in the message below."
|
727
|
+
)
|
728
|
+
if len(user_content) > 0:
|
729
|
+
messages.append(
|
730
|
+
ChatMessageUser(content=user_content, tool_call_id=message.tool_call_id)
|
731
|
+
)
|
732
|
+
|
733
|
+
# return messages
|
734
|
+
return messages
|
735
|
+
|
736
|
+
|
696
737
|
# Functions to reduce consecutive user messages to a single user message -> required for some models
|
697
738
|
def collapse_consecutive_user_messages(
|
698
739
|
messages: list[ChatMessage],
|
@@ -229,6 +229,10 @@ class AnthropicAPI(ModelAPI):
|
|
229
229
|
def tools_required(self) -> bool:
|
230
230
|
return True
|
231
231
|
|
232
|
+
@override
|
233
|
+
def tool_result_images(self) -> bool:
|
234
|
+
return True
|
235
|
+
|
232
236
|
# convert some common BadRequestError states into 'refusal' model output
|
233
237
|
def handle_bad_request(self, ex: BadRequestError) -> ModelOutput | None:
|
234
238
|
error = exception_message(ex).lower()
|
@@ -194,7 +194,9 @@ class GoogleAPI(ModelAPI):
|
|
194
194
|
model=self.model_name, content=ex.message, stop_reason="model_length"
|
195
195
|
)
|
196
196
|
else:
|
197
|
-
|
197
|
+
return ModelOutput.from_content(
|
198
|
+
model=self.model_name, content=ex.message, stop_reason="unknown"
|
199
|
+
)
|
198
200
|
|
199
201
|
@override
|
200
202
|
def is_rate_limit(self, ex: BaseException) -> bool:
|
@@ -408,25 +410,34 @@ def chat_tools(tools: list[ToolInfo]) -> list[Tool]:
|
|
408
410
|
# https://ai.google.dev/gemini-api/tutorials/extract_structured_data#define_the_schema
|
409
411
|
|
410
412
|
|
411
|
-
def schema_from_param(param: ToolParam | ToolParams) -> Schema:
|
413
|
+
def schema_from_param(param: ToolParam | ToolParams, nullable: bool = False) -> Schema:
|
412
414
|
if isinstance(param, ToolParams):
|
413
415
|
param = ToolParam(
|
414
416
|
type=param.type, properties=param.properties, required=param.required
|
415
417
|
)
|
416
418
|
|
417
419
|
if param.type == "number":
|
418
|
-
return Schema(
|
420
|
+
return Schema(
|
421
|
+
type=Type.NUMBER, description=param.description, nullable=nullable
|
422
|
+
)
|
419
423
|
elif param.type == "integer":
|
420
|
-
return Schema(
|
424
|
+
return Schema(
|
425
|
+
type=Type.INTEGER, description=param.description, nullable=nullable
|
426
|
+
)
|
421
427
|
elif param.type == "boolean":
|
422
|
-
return Schema(
|
428
|
+
return Schema(
|
429
|
+
type=Type.BOOLEAN, description=param.description, nullable=nullable
|
430
|
+
)
|
423
431
|
elif param.type == "string":
|
424
|
-
return Schema(
|
432
|
+
return Schema(
|
433
|
+
type=Type.STRING, description=param.description, nullable=nullable
|
434
|
+
)
|
425
435
|
elif param.type == "array":
|
426
436
|
return Schema(
|
427
437
|
type=Type.ARRAY,
|
428
438
|
description=param.description,
|
429
439
|
items=schema_from_param(param.items) if param.items else None,
|
440
|
+
nullable=nullable,
|
430
441
|
)
|
431
442
|
elif param.type == "object":
|
432
443
|
return Schema(
|
@@ -436,7 +447,14 @@ def schema_from_param(param: ToolParam | ToolParams) -> Schema:
|
|
436
447
|
if param.properties is not None
|
437
448
|
else None,
|
438
449
|
required=param.required,
|
450
|
+
nullable=nullable,
|
439
451
|
)
|
452
|
+
# convert unions to optional params if the second type is 'null'
|
453
|
+
elif param.anyOf:
|
454
|
+
if len(param.anyOf) == 2 and param.anyOf[1].type == "null":
|
455
|
+
return schema_from_param(param.anyOf[0], nullable=True)
|
456
|
+
else:
|
457
|
+
return Schema(type=Type.TYPE_UNSPECIFIED)
|
440
458
|
else:
|
441
459
|
return Schema(type=Type.TYPE_UNSPECIFIED)
|
442
460
|
|
@@ -51,6 +51,7 @@ from .._model_output import (
|
|
51
51
|
Logprobs,
|
52
52
|
ModelOutput,
|
53
53
|
ModelUsage,
|
54
|
+
StopReason,
|
54
55
|
)
|
55
56
|
from .openai_o1 import generate_o1
|
56
57
|
from .util import (
|
@@ -262,7 +263,10 @@ class OpenAIAPI(ModelAPI):
|
|
262
263
|
model=self.model_name,
|
263
264
|
)
|
264
265
|
if config.max_tokens is not None:
|
265
|
-
|
266
|
+
if self.is_o1():
|
267
|
+
params["max_completion_tokens"] = config.max_tokens
|
268
|
+
else:
|
269
|
+
params["max_tokens"] = config.max_tokens
|
266
270
|
if config.frequency_penalty is not None:
|
267
271
|
params["frequency_penalty"] = config.frequency_penalty
|
268
272
|
if config.stop_seqs is not None:
|
@@ -303,13 +307,23 @@ class OpenAIAPI(ModelAPI):
|
|
303
307
|
|
304
308
|
# convert some well known bad request errors into ModelOutput
|
305
309
|
def handle_bad_request(self, e: BadRequestError) -> ModelOutput:
|
306
|
-
if e.status_code == 400
|
310
|
+
if e.status_code == 400:
|
311
|
+
# extract message
|
307
312
|
if isinstance(e.body, dict) and "message" in e.body.keys():
|
308
313
|
content = str(e.body.get("message"))
|
309
314
|
else:
|
310
315
|
content = e.message
|
316
|
+
|
317
|
+
# narrow stop_reason
|
318
|
+
if e.code == "context_length_exceeded":
|
319
|
+
stop_reason: StopReason = "model_length"
|
320
|
+
elif e.code == "invalid_prompt":
|
321
|
+
stop_reason = "content_filter"
|
322
|
+
else:
|
323
|
+
stop_reason = "unknown"
|
324
|
+
|
311
325
|
return ModelOutput.from_content(
|
312
|
-
model=self.model_name, content=content, stop_reason=
|
326
|
+
model=self.model_name, content=content, stop_reason=stop_reason
|
313
327
|
)
|
314
328
|
else:
|
315
329
|
raise e
|
@@ -25,7 +25,7 @@ from inspect_ai.model import (
|
|
25
25
|
from inspect_ai.tool import ToolCall, ToolInfo
|
26
26
|
|
27
27
|
from .._model_call import ModelCall
|
28
|
-
from .._model_output import ModelUsage
|
28
|
+
from .._model_output import ModelUsage, StopReason
|
29
29
|
from .._providers.util import (
|
30
30
|
ChatAPIHandler,
|
31
31
|
ChatAPIMessage,
|
@@ -48,12 +48,6 @@ async def generate_o1(
|
|
48
48
|
# create chatapi handler
|
49
49
|
handler = O1PreviewChatAPIHandler()
|
50
50
|
|
51
|
-
# map max_tokens => max_completion_tokens
|
52
|
-
max_tokens = params.get("max_tokens", None)
|
53
|
-
if max_tokens:
|
54
|
-
params["max_completion_tokens"] = max_tokens
|
55
|
-
del params["max_tokens"]
|
56
|
-
|
57
51
|
# call model
|
58
52
|
request = dict(
|
59
53
|
model=model,
|
@@ -89,12 +83,16 @@ async def generate_o1(
|
|
89
83
|
|
90
84
|
|
91
85
|
def handle_bad_request(model: str, ex: BadRequestError) -> ModelOutput:
|
92
|
-
if ex.code == "
|
93
|
-
|
94
|
-
|
95
|
-
|
86
|
+
if ex.code == "context_length_exceeded":
|
87
|
+
stop_reason: StopReason = "model_length"
|
88
|
+
elif ex.code == "invalid_prompt":
|
89
|
+
stop_reason = "content_filter"
|
96
90
|
else:
|
97
|
-
|
91
|
+
stop_reason = "unknown"
|
92
|
+
|
93
|
+
return ModelOutput.from_content(
|
94
|
+
model=model, content=str(ex), stop_reason=stop_reason
|
95
|
+
)
|
98
96
|
|
99
97
|
|
100
98
|
def chat_messages(
|
inspect_ai/model/_render.py
CHANGED
@@ -3,13 +3,20 @@ from rich.console import RenderableType
|
|
3
3
|
from inspect_ai.tool._tool_call import ToolCall
|
4
4
|
from inspect_ai.tool._tool_transcript import transcript_tool_call
|
5
5
|
|
6
|
-
from ._chat_message import
|
6
|
+
from ._chat_message import (
|
7
|
+
ChatMessage,
|
8
|
+
ChatMessageAssistant,
|
9
|
+
ChatMessageTool,
|
10
|
+
ChatMessageUser,
|
11
|
+
)
|
7
12
|
|
8
13
|
|
9
14
|
def messages_preceding_assistant(messages: list[ChatMessage]) -> list[ChatMessage]:
|
10
15
|
preceding: list[ChatMessage] = []
|
11
16
|
for m in reversed(messages):
|
12
|
-
if not isinstance(m, ChatMessageTool | ChatMessageAssistant)
|
17
|
+
if not isinstance(m, ChatMessageTool | ChatMessageAssistant) and not (
|
18
|
+
isinstance(m, ChatMessageUser) and m.tool_call_id
|
19
|
+
):
|
13
20
|
preceding.append(m)
|
14
21
|
else:
|
15
22
|
break
|
inspect_ai/scorer/_metric.py
CHANGED
@@ -90,6 +90,13 @@ class Score(BaseModel):
|
|
90
90
|
"""Read the score as a boolean."""
|
91
91
|
return bool(self._as_scalar())
|
92
92
|
|
93
|
+
def as_list(self) -> list[str | int | float | bool]:
|
94
|
+
"""Read the score as a list."""
|
95
|
+
if isinstance(self.value, list):
|
96
|
+
return self.value
|
97
|
+
else:
|
98
|
+
raise ValueError("This score is not a list")
|
99
|
+
|
93
100
|
def as_dict(self) -> dict[str, str | int | float | bool | None]:
|
94
101
|
"""Read the score as a dictionary."""
|
95
102
|
if isinstance(self.value, dict):
|
@@ -104,13 +111,17 @@ class Score(BaseModel):
|
|
104
111
|
raise ValueError("This score is not a scalar")
|
105
112
|
|
106
113
|
|
107
|
-
class SampleScore(
|
114
|
+
class SampleScore(BaseModel):
|
108
115
|
"""Score for a Sample
|
109
116
|
|
110
117
|
Args:
|
118
|
+
score: Score
|
111
119
|
sample_id: (str | int | None) Unique id of a sample
|
112
120
|
"""
|
113
121
|
|
122
|
+
score: Score
|
123
|
+
"""A score"""
|
124
|
+
|
114
125
|
sample_id: str | int | None = Field(default=None)
|
115
126
|
"""A sample id"""
|
116
127
|
|
inspect_ai/solver/__init__.py
CHANGED
@@ -4,6 +4,7 @@ from ._basic_agent import basic_agent
|
|
4
4
|
from ._chain import chain
|
5
5
|
from ._critique import self_critique
|
6
6
|
from ._fork import fork
|
7
|
+
from ._human_agent.agent import human_agent
|
7
8
|
from ._multiple_choice import MultipleChoiceTemplate, multiple_choice
|
8
9
|
from ._plan import Plan, plan
|
9
10
|
from ._prompt import (
|
@@ -17,6 +18,7 @@ from ._use_tools import use_tools
|
|
17
18
|
|
18
19
|
__all__ = [
|
19
20
|
"basic_agent",
|
21
|
+
"human_agent",
|
20
22
|
"chain",
|
21
23
|
"fork",
|
22
24
|
"generate",
|
@@ -0,0 +1,83 @@
|
|
1
|
+
import asyncio
|
2
|
+
|
3
|
+
from inspect_ai.util import display_type, input_panel, sandbox
|
4
|
+
|
5
|
+
from .._solver import Generate, Solver, solver
|
6
|
+
from .._task_state import TaskState
|
7
|
+
from .commands import human_agent_commands
|
8
|
+
from .install import install_human_agent
|
9
|
+
from .panel import HumanAgentPanel
|
10
|
+
from .service import run_human_agent_service
|
11
|
+
from .view import ConsoleView, HumanAgentView
|
12
|
+
|
13
|
+
|
14
|
+
@solver
|
15
|
+
def human_agent(
|
16
|
+
answer: bool | str = True,
|
17
|
+
intermediate_scoring: bool = False,
|
18
|
+
record_session: bool = True,
|
19
|
+
) -> Solver:
|
20
|
+
"""Human solver for agentic tasks that run in a Linux environment.
|
21
|
+
|
22
|
+
The Human agent solver installs agent task tools in the default
|
23
|
+
sandbox and presents the user with both task instructions and
|
24
|
+
documentation for the various tools (e.g. `task submit`,
|
25
|
+
`task start`, `task stop` `task instructions`, etc.). A human agent panel
|
26
|
+
is displayed with instructions for logging in to the sandbox.
|
27
|
+
|
28
|
+
If the user is running in VS Code with the Inspect extension,
|
29
|
+
they will also be presented with links to login to the sandbox
|
30
|
+
using a VS Code Window or Terminal.
|
31
|
+
|
32
|
+
Args:
|
33
|
+
answer (bool | str): Is an explicit answer required for this
|
34
|
+
task or is it scored based on files in the container? Pass a
|
35
|
+
`str` with a regex to validate that the answer matches
|
36
|
+
the expected format.
|
37
|
+
intermediate_scoring (bool): Allow the human agent to
|
38
|
+
check their score while working.
|
39
|
+
record_session (bool): Record all user commands and outputs in
|
40
|
+
the sandbox bash session.
|
41
|
+
|
42
|
+
Returns:
|
43
|
+
Solver: Human agent solver.
|
44
|
+
"""
|
45
|
+
# we can only run one human agent interaction at a time (use lock to enforce)
|
46
|
+
agent_lock = asyncio.Lock()
|
47
|
+
|
48
|
+
async def solve(state: TaskState, generate: Generate) -> TaskState:
|
49
|
+
async with agent_lock:
|
50
|
+
# ensure that we have a sandbox to work with
|
51
|
+
try:
|
52
|
+
connection = await sandbox().connection()
|
53
|
+
except ProcessLookupError:
|
54
|
+
raise RuntimeError("Human agent must run in a task with a sandbox.")
|
55
|
+
except NotImplementedError:
|
56
|
+
raise RuntimeError(
|
57
|
+
"Human agent must run with a sandbox that supports connections."
|
58
|
+
)
|
59
|
+
|
60
|
+
# helper function to run the agent (called for fullscreen vs. fallback below)
|
61
|
+
async def run_human_agent(view: HumanAgentView) -> TaskState:
|
62
|
+
# create agent commands
|
63
|
+
commands = human_agent_commands(
|
64
|
+
state, answer, intermediate_scoring, record_session
|
65
|
+
)
|
66
|
+
|
67
|
+
# install agent tools
|
68
|
+
await install_human_agent(state, commands, record_session)
|
69
|
+
|
70
|
+
# hookup the view ui
|
71
|
+
view.connect(connection)
|
72
|
+
|
73
|
+
# run sandbox service
|
74
|
+
return await run_human_agent_service(state, commands, view)
|
75
|
+
|
76
|
+
# support both fullscreen ui and fallback
|
77
|
+
if display_type() == "full":
|
78
|
+
async with await input_panel(HumanAgentPanel) as panel:
|
79
|
+
return await run_human_agent(panel)
|
80
|
+
else:
|
81
|
+
return await run_human_agent(ConsoleView())
|
82
|
+
|
83
|
+
return solve
|
@@ -0,0 +1,36 @@
|
|
1
|
+
from inspect_ai.solver._task_state import TaskState
|
2
|
+
|
3
|
+
from .clock import StartCommand, StopCommand
|
4
|
+
from .command import HumanAgentCommand
|
5
|
+
from .instructions import InstructionsCommand
|
6
|
+
from .note import NoteCommand
|
7
|
+
from .score import ScoreCommand
|
8
|
+
from .status import StatusCommand
|
9
|
+
from .submit import SubmitCommand, ValidateCommand
|
10
|
+
|
11
|
+
|
12
|
+
def human_agent_commands(
|
13
|
+
state: TaskState,
|
14
|
+
answer: bool | str,
|
15
|
+
intermediate_scoring: bool,
|
16
|
+
record_session: bool,
|
17
|
+
) -> list[HumanAgentCommand]:
|
18
|
+
# base submit and validate
|
19
|
+
commands = [SubmitCommand(record_session), ValidateCommand(answer)]
|
20
|
+
|
21
|
+
# optional intermediate scoring
|
22
|
+
if intermediate_scoring:
|
23
|
+
commands.append(ScoreCommand(state))
|
24
|
+
|
25
|
+
# remaining commands
|
26
|
+
commands.extend(
|
27
|
+
[
|
28
|
+
NoteCommand(),
|
29
|
+
StatusCommand(),
|
30
|
+
StartCommand(),
|
31
|
+
StopCommand(),
|
32
|
+
]
|
33
|
+
)
|
34
|
+
|
35
|
+
# with instructions (letting it see the other commands)
|
36
|
+
return commands + [InstructionsCommand(commands)]
|