inspect-ai 0.3.51__py3-none-any.whl → 0.3.53__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- inspect_ai/_cli/eval.py +44 -2
- inspect_ai/_display/core/config.py +4 -0
- inspect_ai/_display/core/panel.py +1 -1
- inspect_ai/_display/core/progress.py +9 -3
- inspect_ai/_display/core/results.py +8 -4
- inspect_ai/_display/textual/widgets/task_detail.py +45 -13
- inspect_ai/_display/textual/widgets/tasks.py +86 -5
- inspect_ai/_display/textual/widgets/transcript.py +4 -17
- inspect_ai/_eval/eval.py +29 -1
- inspect_ai/_eval/evalset.py +7 -0
- inspect_ai/_eval/registry.py +2 -2
- inspect_ai/_eval/task/log.py +6 -1
- inspect_ai/_eval/task/results.py +22 -4
- inspect_ai/_eval/task/run.py +18 -12
- inspect_ai/_eval/task/sandbox.py +72 -43
- inspect_ai/_eval/task/task.py +4 -0
- inspect_ai/_eval/task/util.py +17 -6
- inspect_ai/_util/logger.py +10 -2
- inspect_ai/_util/samples.py +7 -0
- inspect_ai/_util/transcript.py +8 -0
- inspect_ai/_view/www/App.css +13 -0
- inspect_ai/_view/www/dist/assets/index.css +13 -0
- inspect_ai/_view/www/dist/assets/index.js +105 -55
- inspect_ai/_view/www/src/App.mjs +31 -6
- inspect_ai/_view/www/src/Types.mjs +6 -0
- inspect_ai/_view/www/src/components/JsonPanel.mjs +11 -17
- inspect_ai/_view/www/src/components/MessageContent.mjs +9 -2
- inspect_ai/_view/www/src/components/Tools.mjs +46 -18
- inspect_ai/_view/www/src/navbar/Navbar.mjs +12 -0
- inspect_ai/_view/www/src/samples/SampleDisplay.mjs +18 -5
- inspect_ai/_view/www/src/samples/SampleList.mjs +2 -2
- inspect_ai/_view/www/src/samples/transcript/ToolEventView.mjs +2 -2
- inspect_ai/log/_log.py +6 -0
- inspect_ai/log/_recorders/eval.py +8 -7
- inspect_ai/model/_call_tools.py +2 -6
- inspect_ai/model/_generate_config.py +6 -0
- inspect_ai/model/_model.py +18 -4
- inspect_ai/model/_providers/azureai.py +22 -2
- inspect_ai/model/_providers/bedrock.py +17 -1
- inspect_ai/model/_providers/hf.py +1 -1
- inspect_ai/model/_providers/openai.py +32 -8
- inspect_ai/model/_providers/providers.py +1 -1
- inspect_ai/model/_providers/vllm.py +1 -1
- inspect_ai/model/_render.py +7 -6
- inspect_ai/model/_trace.py +1 -1
- inspect_ai/solver/_basic_agent.py +8 -1
- inspect_ai/tool/_tool_transcript.py +28 -0
- inspect_ai/util/_sandbox/context.py +1 -2
- inspect_ai/util/_sandbox/docker/config.py +8 -10
- inspect_ai/util/_sandbox/docker/docker.py +9 -5
- inspect_ai/util/_sandbox/docker/util.py +3 -3
- inspect_ai/util/_sandbox/environment.py +7 -2
- inspect_ai/util/_sandbox/limits.py +1 -1
- inspect_ai/util/_sandbox/local.py +8 -9
- {inspect_ai-0.3.51.dist-info → inspect_ai-0.3.53.dist-info}/METADATA +2 -4
- {inspect_ai-0.3.51.dist-info → inspect_ai-0.3.53.dist-info}/RECORD +60 -59
- {inspect_ai-0.3.51.dist-info → inspect_ai-0.3.53.dist-info}/LICENSE +0 -0
- {inspect_ai-0.3.51.dist-info → inspect_ai-0.3.53.dist-info}/WHEEL +0 -0
- {inspect_ai-0.3.51.dist-info → inspect_ai-0.3.53.dist-info}/entry_points.txt +0 -0
- {inspect_ai-0.3.51.dist-info → inspect_ai-0.3.53.dist-info}/top_level.txt +0 -0
inspect_ai/_view/www/src/samples/SampleDisplay.mjs
CHANGED
@@ -350,7 +350,17 @@ const metadataViewsForSample = (id, sample) => {
   return sampleMetadatas;
 };
 
-const SampleSummary = ({ id, sample, style, sampleDescriptor }) => {
+/**
+ * Component to display a sample with relevant context and visibility control.
+ *
+ * @param {Object} props - The properties passed to the component.
+ * @param {string} props.parent_id - The id of the parent com
+ * @param {import("../types/log").EvalSample} [props.sample] - the sample
+ * @param {Object} [props.style] - Inline styles for the table element.
+ * @param {import("../samples/SamplesDescriptor.mjs").SamplesDescriptor} props.sampleDescriptor - the sample descriptor
+ * @returns {import("preact").JSX.Element} The TranscriptView component.
+ */
+const SampleSummary = ({ parent_id, sample, style, sampleDescriptor }) => {
   const input =
     sampleDescriptor?.messageShape.normalized.input > 0
       ? Math.max(0.15, sampleDescriptor.messageShape.normalized.input)
@@ -386,7 +396,7 @@ const SampleSummary = ({ id, sample, style, sampleDescriptor }) => {
   const columns = [];
   columns.push({
     label: "Id",
-    value: id,
+    value: sample.id,
     size: `${idSize}em`,
   });
 
@@ -412,7 +422,8 @@ const SampleSummary = ({ id, sample, style, sampleDescriptor }) => {
 
   const fullAnswer =
     sample && sampleDescriptor
-      ?
+      ? // @ts-ignore
+        sampleDescriptor.selectedScorer(sample).answer()
       : undefined;
   if (fullAnswer) {
     columns.push({
@@ -445,14 +456,16 @@ const SampleSummary = ({ id, sample, style, sampleDescriptor }) => {
           message=${sample.error.message}
           style=${{ marginTop: "0.4rem" }}
         />`
-      :
+      : // TODO: Cleanup once the PR lands which makes sample / sample summary share common interface
+        // @ts-ignore
+        sampleDescriptor?.selectedScore(sample).render(),
     size: "minmax(2em, auto)",
     center: true,
   });
 
   return html`
     <div
-      id=${`sample-heading-${
+      id=${`sample-heading-${parent_id}`}
      style=${{
        display: "grid",
        gridTemplateColumns: `${columns
inspect_ai/_view/www/src/samples/SampleList.mjs
CHANGED
@@ -145,7 +145,7 @@ export const SampleList = (props) => {
   );
 
   const listStyle = { ...style, flex: "1", overflowY: "auto", outline: "none" };
-  const { limit, answer } = gridColumns(sampleDescriptor);
+  const { limit, answer, target } = gridColumns(sampleDescriptor);
 
   const headerRow = html`<div
     style=${{
@@ -161,7 +161,7 @@ export const SampleList = (props) => {
   >
     <div>Id</div>
     <div>Input</div>
-    <div
+    <div>${target !== "0" ? "Target" : ""}</div>
     <div>${answer !== "0" ? "Answer" : ""}</div>
     <div>${limit !== "0" ? "Limit" : ""}</div>
     <div style=${{ justifySelf: "center" }}>Score</div>
inspect_ai/_view/www/src/samples/transcript/ToolEventView.mjs
CHANGED
@@ -29,10 +29,10 @@ export const ToolEventView = ({ id, event, style, depth }) => {
     return e.event === "approval";
   });
 
-  const title = `Tool: ${event.function}`;
+  const title = `Tool: ${event.view?.title || event.function}`;
   return html`
  <${EventPanel} id=${id} title="${title}" subTitle=${formatDateTime(new Date(event.timestamp))} icon=${ApplicationIcons.solvers.use_tools} style=${style}>
-    <div name="Summary" style=${{ margin: "0.5em 0" }}>
+    <div name="Summary" style=${{ margin: "0.5em 0", width: "100%" }}>
      <${ToolCallView}
        functionCall=${functionCall}
        input=${input}
inspect_ai/log/_log.py
CHANGED
@@ -37,6 +37,9 @@ class EvalConfig(BaseModel):
     limit: int | tuple[int, int] | None = Field(default=None)
     """Sample limit (number of samples or range of samples)."""
 
+    sample_id: str | int | list[str | int] | None = Field(default=None)
+    """Evaluate specific sample(s)."""
+
     epochs: int | None = Field(default=None)
     """Number of epochs to run samples over."""
 
@@ -76,6 +79,9 @@ class EvalConfig(BaseModel):
     max_subprocesses: int | None = Field(default=None)
     """Maximum number of subprocesses to run concurrently."""
 
+    max_sandboxes: int | None = Field(default=None)
+    """Maximum number of sandboxes to run concurrently."""
+
     sandbox_cleanup: bool | None = Field(default=None)
     """Cleanup sandbox environments after task completes."""
 
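The two new `EvalConfig` fields record the `sample_id` and `max_sandboxes` options threaded through `_cli/eval.py`, `_eval/eval.py`, and `_eval/evalset.py` in this release. A minimal sketch of how they would be used, assuming `eval()` accepts matching keyword arguments; the task file and sample ids are hypothetical:

```python
from inspect_ai import eval

# run just two named samples and cap concurrently running sandboxes at 10
logs = eval(
    "ctf_challenge.py",                  # hypothetical task file
    model="openai/gpt-4o",
    sample_id=["intro-01", "intro-02"],  # str | int | list[str | int]
    max_sandboxes=10,
)
```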
inspect_ai/log/_recorders/eval.py
CHANGED
@@ -362,13 +362,14 @@ class ZipLogFile:
                     f"Error occurred during async write to {self._file}: {ex}. Falling back to sync write."
                 )
 
-
-
-
-
-
-
-
+            try:
+                # write sync if we need to
+                if not written:
+                    with file(self._file, "wb") as f:
+                        f.write(log_bytes)
+            finally:
+                # re-open zip file w/ self.temp_file pointer at end
+                self._open()
 
     async def close(self) -> EvalLog:
         async with self._lock:
inspect_ai/model/_call_tools.py
CHANGED
@@ -68,10 +68,6 @@ async def call_tools(
         # create a transript for this call
         init_transcript(Transcript(name=call.function))
 
-        # Amend the tool call with a custom view
-        view = tool_call_view(call, tdefs)
-        call.view = view
-
         result: Any = ""
         tool_error: ToolCallError | None = None
         try:
@@ -142,7 +138,7 @@ async def call_tools(
             arguments=call.arguments,
             result=content,
             truncated=truncated,
-            view=view,
+            view=call.view,
             error=tool_error,
             events=list(transcript().events),
         )
@@ -163,7 +159,7 @@ async def call_tools(
             id=call.id,
             function=call.function,
             arguments=call.arguments,
-            view=
+            view=call.view,
             pending=True,
         )
         transcript()._event(event)
inspect_ai/model/_generate_config.py
CHANGED
@@ -72,6 +72,9 @@ class GenerateConfigArgs(TypedDict, total=False):
     cache_prompt: Literal["auto"] | bool | None
     """Whether to cache the prompt prefix. Defaults to "auto", which will enable caching for requests with tools. Anthropic only."""
 
+    reasoning_effort: Literal["low", "medium", "high"] | None
+    """Constrains effort on reasoning for reasoning models. Open AI o1 models only."""
+
 
 class GenerateConfig(BaseModel):
     """Base class for model generation configs."""
@@ -139,6 +142,9 @@ class GenerateConfig(BaseModel):
     cache_prompt: Literal["auto"] | bool | None = Field(default=None)
     """Whether to cache the prompt prefix. Defaults to "auto", which will enable caching for requests with tools. Anthropic only."""
 
+    reasoning_effort: Literal["low", "medium", "high"] | None = Field(default=None)
+    """Constrains effort on reasoning for reasoning models. Open AI o1 models only."""
+
     def merge(
         self, other: Union["GenerateConfig", GenerateConfigArgs]
     ) -> "GenerateConfig":
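`reasoning_effort` joins the other generation options and, per its docstring, is only honored for OpenAI o1 models (see the `openai.py` changes below). A small sketch, assuming the usual `get_model()` / `GenerateConfig` pattern:

```python
from inspect_ai.model import GenerateConfig, get_model

# request more deliberate reasoning from a full o1 model; other
# providers simply ignore the field
model = get_model("openai/o1", config=GenerateConfig(reasoning_effort="high"))
```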
inspect_ai/model/_model.py
CHANGED
@@ -31,11 +31,11 @@ from inspect_ai._util.registry import (
 )
 from inspect_ai._util.retry import log_rate_limit_retry
 from inspect_ai.tool import Tool, ToolChoice, ToolFunction, ToolInfo
-from inspect_ai.tool._tool_def import ToolDef
+from inspect_ai.tool._tool_def import ToolDef, tool_defs
 from inspect_ai.util import concurrency
 
 from ._cache import CacheEntry, CachePolicy, cache_fetch, cache_store
-from ._call_tools import disable_parallel_tools, tools_info
+from ._call_tools import disable_parallel_tools, tool_call_view, tools_info
 from ._chat_message import (
     ChatMessage,
     ChatMessageAssistant,
@@ -248,7 +248,7 @@ class Model:
         async with self._connection_concurrency(config):
             return await self._generate(
                 input=input,
-                tools=
+                tools=tools,
                 tool_choice=tool_choice,
                 config=config,
                 cache=cache,
@@ -257,7 +257,10 @@ class Model:
     async def _generate(
         self,
         input: list[ChatMessage],
-        tools: list[
+        tools: list[Tool]
+        | list[ToolDef]
+        | list[ToolInfo]
+        | list[Tool | ToolDef | ToolInfo],
         tool_choice: ToolChoice | None,
         config: GenerateConfig,
         cache: bool | CachePolicy = False,
@@ -265,6 +268,12 @@ class Model:
         # default to 'auto' for tool_choice (same as underlying model apis)
         tool_choice = tool_choice if tool_choice else "auto"
 
+        # extract tool defs if we can
+        tdefs = tool_defs([tool for tool in tools if not isinstance(tool, ToolInfo)])
+
+        # resolve all tools into tool_info
+        tools = tools_info(tools)
+
         # if we have a specific tool selected then filter out the others
         if isinstance(tool_choice, ToolFunction):
             tools = [tool for tool in tools if tool.name == tool_choice.name]
@@ -374,6 +383,11 @@ class Model:
             # update output with time elapsed
             output.time = time_elapsed
 
+            # add views to tool calls
+            for choice in output.choices:
+                for tool_call in choice.message.tool_calls or []:
+                    tool_call.view = tool_call_view(tool_call, tdefs)
+
             # complete the transcript event
             complete(output, call)
 
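The net effect of the `_model.py` and `_call_tools.py` changes is that tool-call views are resolved once, from the `ToolDef`s available at generate time, and stored on each `ToolCall` rather than recomputed inside `call_tools()`. A rough sketch of what callers can now observe; the tool and prompt are made up:

```python
from inspect_ai.model import get_model
from inspect_ai.tool import tool

@tool
def add():
    async def execute(x: int, y: int):
        """Add two numbers.

        Args:
            x: First number.
            y: Second number.
        """
        return x + y

    return execute

async def demo():
    model = get_model("openai/gpt-4o")
    output = await model.generate("Use the add tool to compute 2 + 2.", tools=[add()])
    for call in output.choices[0].message.tool_calls or []:
        # view is now populated during generate() (None if the tool has no custom viewer)
        print(call.function, call.arguments, call.view)
```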
inspect_ai/model/_providers/azureai.py
CHANGED
@@ -89,6 +89,19 @@ class AzureAIAPI(ModelAPI):
             config=config,
         )
 
+        # collect known model_args (then delete them so we can pass the rest on)
+        def collect_model_arg(name: str) -> Any | None:
+            nonlocal model_args
+            value = model_args.get(name, None)
+            if value is not None:
+                model_args.pop(name)
+            return value
+
+        emulate_tools = collect_model_arg("emulate_tools")
+        self.emulate_tools = (
+            not not emulate_tools if emulate_tools is not None else None
+        )
+
         # resolve api_key
         if not self.api_key:
             self.api_key = os.environ.get(
@@ -118,8 +131,15 @@ class AzureAIAPI(ModelAPI):
         tool_choice: ToolChoice,
         config: GenerateConfig,
     ) -> ModelOutput | tuple[ModelOutput, ModelCall]:
-        #
-
+        # emulate tools (auto for llama, opt-in for others)
+        if self.emulate_tools is None and self.is_llama():
+            handler: ChatAPIHandler | None = Llama31Handler()
+        elif self.emulate_tools:
+            handler = Llama31Handler()
+        else:
+            handler = None
+
+        # resolve input
         if handler:
             input = handler.input_with_tools(input, tools)
 
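`emulate_tools` is read from the provider's model args: Llama deployments get the `Llama31Handler` automatically, and other Azure AI deployments can opt in. A sketch, assuming model args pass straight through `get_model()`; the deployment name is a placeholder:

```python
from inspect_ai.model import get_model

# opt a non-Llama Azure AI deployment into emulated tool calling
model = get_model("azureai/my-mistral-deployment", emulate_tools=True)
```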
inspect_ai/model/_providers/bedrock.py
CHANGED
@@ -236,15 +236,21 @@ class BedrockAPI(ModelAPI):
         self,
         model_name: str,
         base_url: str | None,
+        api_key: str | None = None,
         config: GenerateConfig = GenerateConfig(),
         **model_args: Any,
     ):
         super().__init__(
             model_name=model_name,
             base_url=model_base_url(base_url, "BEDROCK_BASE_URL"),
+            api_key=api_key,
+            api_key_vars=[],
             config=config,
         )
 
+        # save model_args
+        self.model_args = model_args
+
         # import aioboto3 on demand
         try:
             import aioboto3
@@ -263,6 +269,9 @@ class BedrockAPI(ModelAPI):
 
     @override
     def max_tokens(self) -> int | None:
+        if "llama3-70" in self.model_name or "llama3-8" in self.model_name:
+            return 2048
+
         if "llama3" in self.model_name or "claude3" in self.model_name:
             return 4096
 
@@ -316,6 +325,7 @@ class BedrockAPI(ModelAPI):
                     mode="adaptive",
                 ),
             ),
+            **self.model_args,
         ) as client:
             # Process the tools
             resolved_tools = converse_tools(tools)
@@ -658,6 +668,8 @@ def converse_image_type(type: str) -> ConverseImageFormat:
             return "png"
         case "image/webp":
             return "webp"
+        case "image/jpeg":
+            return "jpeg"
         case _:
             raise ValueError(
                 f"Image mime type {type} is not supported for Bedrock Converse models."
@@ -673,7 +685,11 @@ def converse_tools(tools: list[ToolInfo]) -> list[ConverseTool] | None:
         tool_spec = ConverseToolSpec(
             name=tool.name,
             description=tool.description,
-            inputSchema={
+            inputSchema={
+                "json": tool.parameters.model_dump(
+                    exclude_none=True, exclude={"additionalProperties"}
+                )
+            },
         )
         result.append(ConverseTool(toolSpec=tool_spec))
     return result
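Bedrock now accepts an explicit `api_key` and forwards any remaining model args to the `aioboto3` client it creates, alongside the new jpeg image support and Llama 3 token limits. A sketch, assuming extra keyword args to `get_model()` become model args; the region value is a placeholder:

```python
from inspect_ai.model import get_model

# region_name is not consumed by inspect itself; it is stored in
# self.model_args and passed through to the aioboto3 client constructor
model = get_model(
    "bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0",
    region_name="us-east-1",
)
```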
inspect_ai/model/_providers/openai.py
CHANGED
@@ -18,6 +18,7 @@ from openai.types.chat import (
     ChatCompletionContentPartImageParam,
     ChatCompletionContentPartParam,
     ChatCompletionContentPartTextParam,
+    ChatCompletionDeveloperMessageParam,
     ChatCompletionMessage,
     ChatCompletionMessageParam,
     ChatCompletionMessageToolCallParam,
@@ -141,6 +142,18 @@ class OpenAIAPI(ModelAPI):
             **model_args,
         )
 
+    def is_o1(self) -> bool:
+        return self.model_name.startswith("o1")
+
+    def is_o1_full(self) -> bool:
+        return self.is_o1() and not self.is_o1_mini() and not self.is_o1_preview()
+
+    def is_o1_mini(self) -> bool:
+        return self.model_name.startswith("o1-mini")
+
+    def is_o1_preview(self) -> bool:
+        return self.model_name.startswith("o1-preview")
+
     async def generate(
         self,
         input: list[ChatMessage],
@@ -148,8 +161,8 @@ class OpenAIAPI(ModelAPI):
         tool_choice: ToolChoice,
         config: GenerateConfig,
     ) -> ModelOutput | tuple[ModelOutput, ModelCall]:
-        # short-circuit to call o1-
-        if self.
+        # short-circuit to call o1- models that are text only
+        if self.is_o1_preview() or self.is_o1_mini():
            return await generate_o1(
                client=self.client,
                input=input,
@@ -179,7 +192,7 @@ class OpenAIAPI(ModelAPI):
 
         # prepare request (we do this so we can log the ModelCall)
         request = dict(
-            messages=await as_openai_chat_messages(input),
+            messages=await as_openai_chat_messages(input, self.is_o1_full()),
             tools=chat_tools(tools) if len(tools) > 0 else NOT_GIVEN,
             tool_choice=chat_tool_choice(tool_choice) if len(tools) > 0 else NOT_GIVEN,
             **self.completion_params(config, len(tools) > 0),
@@ -271,8 +284,10 @@ class OpenAIAPI(ModelAPI):
             params["logprobs"] = config.logprobs
         if config.top_logprobs is not None:
             params["top_logprobs"] = config.top_logprobs
-        if tools and config.parallel_tool_calls is not None:
+        if tools and config.parallel_tool_calls is not None and not self.is_o1():
             params["parallel_tool_calls"] = config.parallel_tool_calls
+        if config.reasoning_effort is not None and self.is_o1_full():
+            params["reasoning_effort"] = config.reasoning_effort
 
         return params
 
@@ -291,14 +306,23 @@ class OpenAIAPI(ModelAPI):
 
 
 async def as_openai_chat_messages(
-    messages: list[ChatMessage],
+    messages: list[ChatMessage], o1_full: bool
 ) -> list[ChatCompletionMessageParam]:
-    return [await openai_chat_message(message) for message in messages]
+    return [await openai_chat_message(message, o1_full) for message in messages]
 
 
-async def openai_chat_message(
+async def openai_chat_message(
+    message: ChatMessage, o1_full: bool
+) -> ChatCompletionMessageParam:
     if message.role == "system":
-
+        if o1_full:
+            return ChatCompletionDeveloperMessageParam(
+                role="developer", content=message.text
+            )
+        else:
+            return ChatCompletionSystemMessageParam(
+                role=message.role, content=message.text
+            )
     elif message.role == "user":
         return ChatCompletionUserMessageParam(
             role=message.role,
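The o1 handling boils down to a role mapping: full o1 models receive the system prompt as a `developer` message, while o1-preview / o1-mini still short-circuit to the text-only `generate_o1()` path. A distilled, standalone sketch of that mapping using the same openai types:

```python
from openai.types.chat import (
    ChatCompletionDeveloperMessageParam,
    ChatCompletionSystemMessageParam,
)

def system_message_param(text: str, o1_full: bool):
    # full o1 models take instructions via the "developer" role;
    # everything else keeps the classic "system" role
    if o1_full:
        return ChatCompletionDeveloperMessageParam(role="developer", content=text)
    return ChatCompletionSystemMessageParam(role="system", content=text)

print(system_message_param("Answer concisely.", o1_full=True))
```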
inspect_ai/model/_render.py
CHANGED
@@ -1,8 +1,7 @@
 from rich.console import RenderableType
 
-from inspect_ai._util.format import format_function_call
-from inspect_ai._util.transcript import transcript_markdown
 from inspect_ai.tool._tool_call import ToolCall
+from inspect_ai.tool._tool_transcript import transcript_tool_call
 
 from ._chat_message import ChatMessage, ChatMessageAssistant, ChatMessageTool
 
@@ -17,8 +16,10 @@ def messages_preceding_assistant(messages: list[ChatMessage]) -> list[ChatMessag
     return list(reversed(preceding))
 
 
-def render_tool_calls(tool_calls: list[ToolCall]) -> RenderableType:
-    formatted_calls: list[
+def render_tool_calls(tool_calls: list[ToolCall]) -> list[RenderableType]:
+    formatted_calls: list[RenderableType] = []
+
     for call in tool_calls:
-        formatted_calls.
-
+        formatted_calls.extend(transcript_tool_call(call))
+
+    return formatted_calls
inspect_ai/model/_trace.py
CHANGED
@@ -42,7 +42,7 @@ def trace_assistant_message(
     # print tool calls
     if message.tool_calls:
         content.append(Text())
-        content.
+        content.extend(render_tool_calls(message.tool_calls))
 
     # print the assistant message
     trace_panel(title="Assistant", content=content)
inspect_ai/solver/_basic_agent.py
CHANGED
@@ -54,6 +54,7 @@ def basic_agent(
     max_attempts: int = 1,
     message_limit: int | None = None,
     token_limit: int | None = None,
+    max_tool_output: int | None = None,
     score_value: ValueToFloat | None = None,
     incorrect_message: str
     | Callable[[TaskState, list[Score]], str] = DEFAULT_INCORRECT_MESSAGE,
@@ -87,6 +88,8 @@ def basic_agent(
         If not specified, will use limit_messages defined for the task. If there is none
         defined for the task, 50 will be used as a default.
       token_limit (int | None): Limit on tokens used in sample before terminating agent.
+      max_tool_output (int | None): Maximum output length (in bytes).
+        Defaults to max_tool_output from active GenerateConfig.
       score_value (ValueToFloat): Function used to extract float from scores (defaults
         to standard value_to_float())
       incorrect_message (str | Callable[[TaskState, list[Score]], str]): User message reply for an
@@ -182,7 +185,9 @@ def basic_agent(
             # resolve tools calls (if any)
             if state.output.message.tool_calls:
                 # call tool functions
-                tool_results = await call_tools(
+                tool_results = await call_tools(
+                    state.output.message, state.tools, max_output=max_tool_output
+                )
                 state.messages.extend(tool_results)
 
                 # was an answer submitted?
@@ -194,11 +199,13 @@ def basic_agent(
                     # exit if we are at max_attempts
                     attempts += 1
                     if attempts >= max_attempts:
+                        state.completed = True
                         break
 
                     # exit if the submission is successful
                     answer_scores = await score(state)
                     if score_value_fn(answer_scores[0].value) == 1.0:
+                        state.completed = True
                        break
 
                    # otherwise notify the model that it was incorrect and continue
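Besides marking the task state completed when it stops, `basic_agent()` now lets callers cap tool output independently of the generation config. A sketch of passing the new parameter; the system prompt and tool selection are illustrative:

```python
from inspect_ai.solver import basic_agent, system_message
from inspect_ai.tool import bash

# truncate each tool result at ~16KB rather than the GenerateConfig default
solver = basic_agent(
    init=system_message("Solve the task using the bash tool."),
    tools=[bash(timeout=180)],
    max_attempts=3,
    max_tool_output=16 * 1024,
)
```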
inspect_ai/tool/_tool_transcript.py
ADDED
@@ -0,0 +1,28 @@
+from pydantic import JsonValue
+from rich.console import RenderableType
+from rich.text import Text
+from typing_extensions import Protocol
+
+from inspect_ai._util.transcript import transcript_function, transcript_markdown
+
+from ._tool_call import ToolCallContent
+
+
+class TranscriptToolCall(Protocol):
+    function: str
+    arguments: dict[str, JsonValue]
+    view: ToolCallContent | None
+
+
+def transcript_tool_call(call: TranscriptToolCall) -> list[RenderableType]:
+    content: list[RenderableType] = []
+    if call.view:
+        if call.view.title:
+            content.append(Text.from_markup(f"[bold]{call.view.title}[/bold]\n"))
+        if call.view.format == "markdown":
+            content.append(transcript_markdown(call.view.content))
+        else:
+            content.append(call.view.content)
+    else:
+        content.append(transcript_function(call.function, call.arguments))
+    return content
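`transcript_tool_call()` is what `_render.py` and `_trace.py` above now delegate to: it renders a tool's custom view when present, otherwise a plain function-call signature. A usage sketch with a hand-built call; the `ToolCall` / `ToolCallContent` constructor arguments shown here are an assumption based on the fields referenced in this diff:

```python
from rich.console import Console

from inspect_ai.tool._tool_call import ToolCall, ToolCallContent
from inspect_ai.tool._tool_transcript import transcript_tool_call

# hand-built tool call with a markdown view (field names assumed from the diff)
call = ToolCall(
    id="call_1",
    function="bash",
    arguments={"cmd": "ls -la"},
    type="function",
    view=ToolCallContent(format="markdown", content="**bash**: `ls -la`"),
)

console = Console()
for renderable in transcript_tool_call(call):
    console.print(renderable)
```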
inspect_ai/util/_sandbox/context.py
CHANGED
@@ -109,7 +109,7 @@ def raise_no_sandbox() -> NoReturn:
 
 
 async def init_sandbox_environments_sample(
-
+    sandboxenv_type: type[SandboxEnvironment],
     task_name: str,
     config: SandboxEnvironmentConfigType | None,
     files: dict[str, bytes],
@@ -117,7 +117,6 @@ async def init_sandbox_environments_sample(
     metadata: dict[str, Any],
 ) -> dict[str, SandboxEnvironment]:
     # get setup and cleanup functions
-    sandboxenv_type = registry_find_sandboxenv(type)
     sample_init = cast(SampleInit, getattr(sandboxenv_type, "sample_init"))
     sample_cleanup = cast(SampleCleanup, getattr(sandboxenv_type, "sample_cleanup"))
 
inspect_ai/util/_sandbox/docker/config.py
CHANGED
@@ -2,8 +2,6 @@ import os
 from logging import getLogger
 from pathlib import Path
 
-import aiofiles
-
 logger = getLogger(__name__)
 
 
@@ -17,7 +15,7 @@ CONFIG_FILES = [
 DOCKERFILE = "Dockerfile"
 
 
-async def resolve_compose_file(parent: str = "") -> str:
+def resolve_compose_file(parent: str = "") -> str:
     # existing compose file provides all the config we need
     compose = find_compose_file(parent)
     if compose is not None:
@@ -29,11 +27,11 @@ async def resolve_compose_file(parent: str = "") -> str:
 
     # dockerfile just needs a compose.yaml synthesized
     elif has_dockerfile(parent):
-        return
+        return auto_compose_file(COMPOSE_DOCKERFILE_YAML, parent)
 
     # otherwise provide a generic python container
     else:
-        return
+        return auto_compose_file(COMPOSE_GENERIC_YAML, parent)
 
 
 def find_compose_file(parent: str = "") -> str | None:
@@ -59,9 +57,9 @@ def is_auto_compose_file(file: str) -> bool:
     return os.path.basename(file) == AUTO_COMPOSE_YAML
 
 
-
+def ensure_auto_compose_file(file: str | None) -> None:
     if file is not None and is_auto_compose_file(file) and not os.path.exists(file):
-
+        resolve_compose_file(os.path.dirname(file))
 
 
 def safe_cleanup_auto_compose(file: str | None) -> None:
@@ -100,8 +98,8 @@ services:
 """
 
 
-
+def auto_compose_file(contents: str, parent: str = "") -> str:
     path = os.path.join(parent, AUTO_COMPOSE_YAML)
-
-
+    with open(path, "w", encoding="utf-8") as f:
+        f.write(contents)
     return Path(path).resolve().as_posix()