inspect-ai 0.3.51__py3-none-any.whl → 0.3.53__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between those versions as they appear in their respective public registries.
Files changed (60)
  1. inspect_ai/_cli/eval.py +44 -2
  2. inspect_ai/_display/core/config.py +4 -0
  3. inspect_ai/_display/core/panel.py +1 -1
  4. inspect_ai/_display/core/progress.py +9 -3
  5. inspect_ai/_display/core/results.py +8 -4
  6. inspect_ai/_display/textual/widgets/task_detail.py +45 -13
  7. inspect_ai/_display/textual/widgets/tasks.py +86 -5
  8. inspect_ai/_display/textual/widgets/transcript.py +4 -17
  9. inspect_ai/_eval/eval.py +29 -1
  10. inspect_ai/_eval/evalset.py +7 -0
  11. inspect_ai/_eval/registry.py +2 -2
  12. inspect_ai/_eval/task/log.py +6 -1
  13. inspect_ai/_eval/task/results.py +22 -4
  14. inspect_ai/_eval/task/run.py +18 -12
  15. inspect_ai/_eval/task/sandbox.py +72 -43
  16. inspect_ai/_eval/task/task.py +4 -0
  17. inspect_ai/_eval/task/util.py +17 -6
  18. inspect_ai/_util/logger.py +10 -2
  19. inspect_ai/_util/samples.py +7 -0
  20. inspect_ai/_util/transcript.py +8 -0
  21. inspect_ai/_view/www/App.css +13 -0
  22. inspect_ai/_view/www/dist/assets/index.css +13 -0
  23. inspect_ai/_view/www/dist/assets/index.js +105 -55
  24. inspect_ai/_view/www/src/App.mjs +31 -6
  25. inspect_ai/_view/www/src/Types.mjs +6 -0
  26. inspect_ai/_view/www/src/components/JsonPanel.mjs +11 -17
  27. inspect_ai/_view/www/src/components/MessageContent.mjs +9 -2
  28. inspect_ai/_view/www/src/components/Tools.mjs +46 -18
  29. inspect_ai/_view/www/src/navbar/Navbar.mjs +12 -0
  30. inspect_ai/_view/www/src/samples/SampleDisplay.mjs +18 -5
  31. inspect_ai/_view/www/src/samples/SampleList.mjs +2 -2
  32. inspect_ai/_view/www/src/samples/transcript/ToolEventView.mjs +2 -2
  33. inspect_ai/log/_log.py +6 -0
  34. inspect_ai/log/_recorders/eval.py +8 -7
  35. inspect_ai/model/_call_tools.py +2 -6
  36. inspect_ai/model/_generate_config.py +6 -0
  37. inspect_ai/model/_model.py +18 -4
  38. inspect_ai/model/_providers/azureai.py +22 -2
  39. inspect_ai/model/_providers/bedrock.py +17 -1
  40. inspect_ai/model/_providers/hf.py +1 -1
  41. inspect_ai/model/_providers/openai.py +32 -8
  42. inspect_ai/model/_providers/providers.py +1 -1
  43. inspect_ai/model/_providers/vllm.py +1 -1
  44. inspect_ai/model/_render.py +7 -6
  45. inspect_ai/model/_trace.py +1 -1
  46. inspect_ai/solver/_basic_agent.py +8 -1
  47. inspect_ai/tool/_tool_transcript.py +28 -0
  48. inspect_ai/util/_sandbox/context.py +1 -2
  49. inspect_ai/util/_sandbox/docker/config.py +8 -10
  50. inspect_ai/util/_sandbox/docker/docker.py +9 -5
  51. inspect_ai/util/_sandbox/docker/util.py +3 -3
  52. inspect_ai/util/_sandbox/environment.py +7 -2
  53. inspect_ai/util/_sandbox/limits.py +1 -1
  54. inspect_ai/util/_sandbox/local.py +8 -9
  55. {inspect_ai-0.3.51.dist-info → inspect_ai-0.3.53.dist-info}/METADATA +2 -4
  56. {inspect_ai-0.3.51.dist-info → inspect_ai-0.3.53.dist-info}/RECORD +60 -59
  57. {inspect_ai-0.3.51.dist-info → inspect_ai-0.3.53.dist-info}/LICENSE +0 -0
  58. {inspect_ai-0.3.51.dist-info → inspect_ai-0.3.53.dist-info}/WHEEL +0 -0
  59. {inspect_ai-0.3.51.dist-info → inspect_ai-0.3.53.dist-info}/entry_points.txt +0 -0
  60. {inspect_ai-0.3.51.dist-info → inspect_ai-0.3.53.dist-info}/top_level.txt +0 -0
inspect_ai/_view/www/src/samples/SampleDisplay.mjs CHANGED
@@ -350,7 +350,17 @@ const metadataViewsForSample = (id, sample) => {
  return sampleMetadatas;
  };

- const SampleSummary = ({ id, sample, style, sampleDescriptor }) => {
+ /**
+ * Component to display a sample with relevant context and visibility control.
+ *
+ * @param {Object} props - The properties passed to the component.
+ * @param {string} props.parent_id - The id of the parent com
+ * @param {import("../types/log").EvalSample} [props.sample] - the sample
+ * @param {Object} [props.style] - Inline styles for the table element.
+ * @param {import("../samples/SamplesDescriptor.mjs").SamplesDescriptor} props.sampleDescriptor - the sample descriptor
+ * @returns {import("preact").JSX.Element} The TranscriptView component.
+ */
+ const SampleSummary = ({ parent_id, sample, style, sampleDescriptor }) => {
  const input =
  sampleDescriptor?.messageShape.normalized.input > 0
  ? Math.max(0.15, sampleDescriptor.messageShape.normalized.input)
@@ -386,7 +396,7 @@ const SampleSummary = ({ id, sample, style, sampleDescriptor }) => {
  const columns = [];
  columns.push({
  label: "Id",
- value: id,
+ value: sample.id,
  size: `${idSize}em`,
  });

@@ -412,7 +422,8 @@ const SampleSummary = ({ id, sample, style, sampleDescriptor }) => {

  const fullAnswer =
  sample && sampleDescriptor
- ? sampleDescriptor.selectedScorer(sample).answer()
+ ? // @ts-ignore
+ sampleDescriptor.selectedScorer(sample).answer()
  : undefined;
  if (fullAnswer) {
  columns.push({
@@ -445,14 +456,16 @@ const SampleSummary = ({ id, sample, style, sampleDescriptor }) => {
  message=${sample.error.message}
  style=${{ marginTop: "0.4rem" }}
  />`
- : sampleDescriptor?.selectedScore(sample).render(),
+ : // TODO: Cleanup once the PR lands which makes sample / sample summary share common interface
+ // @ts-ignore
+ sampleDescriptor?.selectedScore(sample).render(),
  size: "minmax(2em, auto)",
  center: true,
  });

  return html`
  <div
- id=${`sample-heading-${id}`}
+ id=${`sample-heading-${parent_id}`}
  style=${{
  display: "grid",
  gridTemplateColumns: `${columns
inspect_ai/_view/www/src/samples/SampleList.mjs CHANGED
@@ -145,7 +145,7 @@ export const SampleList = (props) => {
  );

  const listStyle = { ...style, flex: "1", overflowY: "auto", outline: "none" };
- const { limit, answer } = gridColumns(sampleDescriptor);
+ const { limit, answer, target } = gridColumns(sampleDescriptor);

  const headerRow = html`<div
  style=${{
@@ -161,7 +161,7 @@ export const SampleList = (props) => {
  >
  <div>Id</div>
  <div>Input</div>
- <div>Target</div>
+ <div>${target !== "0" ? "Target" : ""}</div>
  <div>${answer !== "0" ? "Answer" : ""}</div>
  <div>${limit !== "0" ? "Limit" : ""}</div>
  <div style=${{ justifySelf: "center" }}>Score</div>
inspect_ai/_view/www/src/samples/transcript/ToolEventView.mjs CHANGED
@@ -29,10 +29,10 @@ export const ToolEventView = ({ id, event, style, depth }) => {
  return e.event === "approval";
  });

- const title = `Tool: ${event.function}`;
+ const title = `Tool: ${event.view?.title || event.function}`;
  return html`
  <${EventPanel} id=${id} title="${title}" subTitle=${formatDateTime(new Date(event.timestamp))} icon=${ApplicationIcons.solvers.use_tools} style=${style}>
- <div name="Summary" style=${{ margin: "0.5em 0" }}>
+ <div name="Summary" style=${{ margin: "0.5em 0", width: "100%" }}>
  <${ToolCallView}
  functionCall=${functionCall}
  input=${input}
inspect_ai/log/_log.py CHANGED
@@ -37,6 +37,9 @@ class EvalConfig(BaseModel):
  limit: int | tuple[int, int] | None = Field(default=None)
  """Sample limit (number of samples or range of samples)."""

+ sample_id: str | int | list[str | int] | None = Field(default=None)
+ """Evaluate specific sample(s)."""
+
  epochs: int | None = Field(default=None)
  """Number of epochs to run samples over."""

@@ -76,6 +79,9 @@ class EvalConfig(BaseModel):
  max_subprocesses: int | None = Field(default=None)
  """Maximum number of subprocesses to run concurrently."""

+ max_sandboxes: int | None = Field(default=None)
+ """Maximum number of sandboxes to run concurrently."""
+
  sandbox_cleanup: bool | None = Field(default=None)
  """Cleanup sandbox environments after task completes."""

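The two new EvalConfig fields record new eval options (see the _cli/eval.py and _eval/eval.py entries in the file list above). A minimal sketch of how they would be passed, assuming eval() exposes matching sample_id and max_sandboxes arguments in this release; the task file and model are placeholders:

    from inspect_ai import eval

    # run only two specific samples and cap concurrent sandbox environments;
    # both values are recorded in the log's EvalConfig
    logs = eval(
        "security_guide.py",   # placeholder task file
        model="openai/gpt-4o",
        sample_id=[1, 5],      # str | int | list[str | int]
        max_sandboxes=10,
    )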
inspect_ai/log/_recorders/eval.py CHANGED
@@ -362,13 +362,14 @@ class ZipLogFile:
  f"Error occurred during async write to {self._file}: {ex}. Falling back to sync write."
  )

- # write sync if we need to
- if not written:
- with file(self._file, "wb") as f:
- f.write(log_bytes)
-
- # re-open zip file w/ self.temp_file pointer at end
- self._open()
+ try:
+ # write sync if we need to
+ if not written:
+ with file(self._file, "wb") as f:
+ f.write(log_bytes)
+ finally:
+ # re-open zip file w/ self.temp_file pointer at end
+ self._open()

  async def close(self) -> EvalLog:
  async with self._lock:
inspect_ai/model/_call_tools.py CHANGED
@@ -68,10 +68,6 @@ async def call_tools(
  # create a transript for this call
  init_transcript(Transcript(name=call.function))

- # Amend the tool call with a custom view
- view = tool_call_view(call, tdefs)
- call.view = view
-
  result: Any = ""
  tool_error: ToolCallError | None = None
  try:
@@ -142,7 +138,7 @@ async def call_tools(
  arguments=call.arguments,
  result=content,
  truncated=truncated,
- view=view,
+ view=call.view,
  error=tool_error,
  events=list(transcript().events),
  )
@@ -163,7 +159,7 @@ async def call_tools(
  id=call.id,
  function=call.function,
  arguments=call.arguments,
- view=tool_call_view(call, tdefs),
+ view=call.view,
  pending=True,
  )
  transcript()._event(event)
inspect_ai/model/_generate_config.py CHANGED
@@ -72,6 +72,9 @@ class GenerateConfigArgs(TypedDict, total=False):
  cache_prompt: Literal["auto"] | bool | None
  """Whether to cache the prompt prefix. Defaults to "auto", which will enable caching for requests with tools. Anthropic only."""

+ reasoning_effort: Literal["low", "medium", "high"] | None
+ """Constrains effort on reasoning for reasoning models. Open AI o1 models only."""
+

  class GenerateConfig(BaseModel):
  """Base class for model generation configs."""
@@ -139,6 +142,9 @@ class GenerateConfig(BaseModel):
  cache_prompt: Literal["auto"] | bool | None = Field(default=None)
  """Whether to cache the prompt prefix. Defaults to "auto", which will enable caching for requests with tools. Anthropic only."""

+ reasoning_effort: Literal["low", "medium", "high"] | None = Field(default=None)
+ """Constrains effort on reasoning for reasoning models. Open AI o1 models only."""
+
  def merge(
  self, other: Union["GenerateConfig", GenerateConfigArgs]
  ) -> "GenerateConfig":
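The new reasoning_effort option is forwarded by the OpenAI provider only for full o1 models (see the openai.py hunks below). A minimal sketch of setting it, assuming the usual GenerateConfig pass-through; the task name is a placeholder:

    from inspect_ai import eval
    from inspect_ai.model import GenerateConfig, get_model

    # attach the option to a specific model instance ...
    model = get_model("openai/o1", config=GenerateConfig(reasoning_effort="high"))

    # ... or pass it through eval() like other generation options
    eval("theory_of_mind.py", model="openai/o1", reasoning_effort="high")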
inspect_ai/model/_model.py CHANGED
@@ -31,11 +31,11 @@ from inspect_ai._util.registry import (
  )
  from inspect_ai._util.retry import log_rate_limit_retry
  from inspect_ai.tool import Tool, ToolChoice, ToolFunction, ToolInfo
- from inspect_ai.tool._tool_def import ToolDef
+ from inspect_ai.tool._tool_def import ToolDef, tool_defs
  from inspect_ai.util import concurrency

  from ._cache import CacheEntry, CachePolicy, cache_fetch, cache_store
- from ._call_tools import disable_parallel_tools, tools_info
+ from ._call_tools import disable_parallel_tools, tool_call_view, tools_info
  from ._chat_message import (
  ChatMessage,
  ChatMessageAssistant,
@@ -248,7 +248,7 @@ class Model:
  async with self._connection_concurrency(config):
  return await self._generate(
  input=input,
- tools=tools_info(tools),
+ tools=tools,
  tool_choice=tool_choice,
  config=config,
  cache=cache,
@@ -257,7 +257,10 @@ class Model:
  async def _generate(
  self,
  input: list[ChatMessage],
- tools: list[ToolInfo],
+ tools: list[Tool]
+ | list[ToolDef]
+ | list[ToolInfo]
+ | list[Tool | ToolDef | ToolInfo],
  tool_choice: ToolChoice | None,
  config: GenerateConfig,
  cache: bool | CachePolicy = False,
@@ -265,6 +268,12 @@ class Model:
  # default to 'auto' for tool_choice (same as underlying model apis)
  tool_choice = tool_choice if tool_choice else "auto"

+ # extract tool defs if we can
+ tdefs = tool_defs([tool for tool in tools if not isinstance(tool, ToolInfo)])
+
+ # resolve all tools into tool_info
+ tools = tools_info(tools)
+
  # if we have a specific tool selected then filter out the others
  if isinstance(tool_choice, ToolFunction):
  tools = [tool for tool in tools if tool.name == tool_choice.name]
@@ -374,6 +383,11 @@ class Model:
  # update output with time elapsed
  output.time = time_elapsed

+ # add views to tool calls
+ for choice in output.choices:
+ for tool_call in choice.message.tool_calls or []:
+ tool_call.view = tool_call_view(tool_call, tdefs)
+
  # complete the transcript event
  complete(output, call)

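With this change, tool views are resolved once at generate time and attached to the assistant message's tool calls, and _generate accepts Tool, ToolDef, or ToolInfo objects. A minimal sketch of calling generate() with a Tool directly (inside an async solver or script); the add tool and model name are illustrative, not part of the package:

    from inspect_ai.model import get_model
    from inspect_ai.tool import tool

    @tool
    def add():
        async def execute(x: int, y: int) -> int:
            """Add two numbers.

            Args:
                x: First number.
                y: Second number.
            """
            return x + y

        return execute

    async def demo():
        model = get_model("openai/gpt-4o")
        # tools may be Tool objects (as here), ToolDefs, or ToolInfos; any tool
        # calls in the output now carry a resolved `view` for transcript display
        output = await model.generate("Use add to compute 2 + 2.", tools=[add()])
        return output.message.tool_calls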
inspect_ai/model/_providers/azureai.py CHANGED
@@ -89,6 +89,19 @@ class AzureAIAPI(ModelAPI):
  config=config,
  )

+ # collect known model_args (then delete them so we can pass the rest on)
+ def collect_model_arg(name: str) -> Any | None:
+ nonlocal model_args
+ value = model_args.get(name, None)
+ if value is not None:
+ model_args.pop(name)
+ return value
+
+ emulate_tools = collect_model_arg("emulate_tools")
+ self.emulate_tools = (
+ not not emulate_tools if emulate_tools is not None else None
+ )
+
  # resolve api_key
  if not self.api_key:
  self.api_key = os.environ.get(
@@ -118,8 +131,15 @@ class AzureAIAPI(ModelAPI):
  tool_choice: ToolChoice,
  config: GenerateConfig,
  ) -> ModelOutput | tuple[ModelOutput, ModelCall]:
- # if its llama then do fake tool calls
- handler: ChatAPIHandler | None = Llama31Handler() if self.is_llama() else None
+ # emulate tools (auto for llama, opt-in for others)
+ if self.emulate_tools is None and self.is_llama():
+ handler: ChatAPIHandler | None = Llama31Handler()
+ elif self.emulate_tools:
+ handler = Llama31Handler()
+ else:
+ handler = None
+
+ # resolve input
  if handler:
  input = handler.input_with_tools(input, tools)

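Because emulate_tools is read from model_args, prompt-based tool emulation can now be switched on for non-Llama Azure AI endpoints (it remains automatic for Llama). A minimal sketch, assuming the standard model_args / -M pass-through; the task file and deployment name are placeholders:

    from inspect_ai import eval

    # opt in to emulated tool calling for an Azure AI deployment without
    # native tools support (equivalent to `-M emulate_tools=true` on the CLI)
    eval(
        "intervention.py",              # placeholder task file
        model="azureai/my-deployment",  # placeholder deployment name
        model_args={"emulate_tools": True},
    )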
inspect_ai/model/_providers/bedrock.py CHANGED
@@ -236,15 +236,21 @@ class BedrockAPI(ModelAPI):
  self,
  model_name: str,
  base_url: str | None,
+ api_key: str | None = None,
  config: GenerateConfig = GenerateConfig(),
  **model_args: Any,
  ):
  super().__init__(
  model_name=model_name,
  base_url=model_base_url(base_url, "BEDROCK_BASE_URL"),
+ api_key=api_key,
+ api_key_vars=[],
  config=config,
  )

+ # save model_args
+ self.model_args = model_args
+
  # import aioboto3 on demand
  try:
  import aioboto3
@@ -263,6 +269,9 @@ class BedrockAPI(ModelAPI):

  @override
  def max_tokens(self) -> int | None:
+ if "llama3-70" in self.model_name or "llama3-8" in self.model_name:
+ return 2048
+
  if "llama3" in self.model_name or "claude3" in self.model_name:
  return 4096

@@ -316,6 +325,7 @@ class BedrockAPI(ModelAPI):
  mode="adaptive",
  ),
  ),
+ **self.model_args,
  ) as client:
  # Process the tools
  resolved_tools = converse_tools(tools)
@@ -658,6 +668,8 @@ def converse_image_type(type: str) -> ConverseImageFormat:
  return "png"
  case "image/webp":
  return "webp"
+ case "image/jpeg":
+ return "jpeg"
  case _:
  raise ValueError(
  f"Image mime type {type} is not supported for Bedrock Converse models."
@@ -673,7 +685,11 @@ def converse_tools(tools: list[ToolInfo]) -> list[ConverseTool] | None:
  tool_spec = ConverseToolSpec(
  name=tool.name,
  description=tool.description,
- inputSchema={"json": tool.parameters.model_dump(exclude_none=True)},
+ inputSchema={
+ "json": tool.parameters.model_dump(
+ exclude_none=True, exclude={"additionalProperties"}
+ )
+ },
  )
  result.append(ConverseTool(toolSpec=tool_spec))
  return result
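Since BedrockAPI now stores model_args and forwards them to the aioboto3 client, standard boto3 client options can be supplied per eval. A minimal sketch, assuming boto3-style keyword arguments such as region_name; the task file and model id are placeholders:

    from inspect_ai import eval

    # pass boto3 client options straight through to the Bedrock client
    eval(
        "gaia.py",                                        # placeholder task file
        model="bedrock/meta.llama3-1-70b-instruct-v1:0",  # placeholder model id
        model_args={"region_name": "us-west-2"},
    )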
inspect_ai/model/_providers/hf.py CHANGED
@@ -64,7 +64,7 @@ class HuggingFaceAPI(ModelAPI):
  def collect_model_arg(name: str) -> Any | None:
  nonlocal model_args
  value = model_args.get(name, None)
- if value:
+ if value is not None:
  model_args.pop(name)
  return value

inspect_ai/model/_providers/openai.py CHANGED
@@ -18,6 +18,7 @@ from openai.types.chat import (
  ChatCompletionContentPartImageParam,
  ChatCompletionContentPartParam,
  ChatCompletionContentPartTextParam,
+ ChatCompletionDeveloperMessageParam,
  ChatCompletionMessage,
  ChatCompletionMessageParam,
  ChatCompletionMessageToolCallParam,
@@ -141,6 +142,18 @@ class OpenAIAPI(ModelAPI):
  **model_args,
  )

+ def is_o1(self) -> bool:
+ return self.model_name.startswith("o1")
+
+ def is_o1_full(self) -> bool:
+ return self.is_o1() and not self.is_o1_mini() and not self.is_o1_preview()
+
+ def is_o1_mini(self) -> bool:
+ return self.model_name.startswith("o1-mini")
+
+ def is_o1_preview(self) -> bool:
+ return self.model_name.startswith("o1-preview")
+
  async def generate(
  self,
  input: list[ChatMessage],
@@ -148,8 +161,8 @@ class OpenAIAPI(ModelAPI):
  tool_choice: ToolChoice,
  config: GenerateConfig,
  ) -> ModelOutput | tuple[ModelOutput, ModelCall]:
- # short-circuit to call o1- model
- if self.model_name.startswith("o1-"):
+ # short-circuit to call o1- models that are text only
+ if self.is_o1_preview() or self.is_o1_mini():
  return await generate_o1(
  client=self.client,
  input=input,
@@ -179,7 +192,7 @@ class OpenAIAPI(ModelAPI):

  # prepare request (we do this so we can log the ModelCall)
  request = dict(
- messages=await as_openai_chat_messages(input),
+ messages=await as_openai_chat_messages(input, self.is_o1_full()),
  tools=chat_tools(tools) if len(tools) > 0 else NOT_GIVEN,
  tool_choice=chat_tool_choice(tool_choice) if len(tools) > 0 else NOT_GIVEN,
  **self.completion_params(config, len(tools) > 0),
@@ -271,8 +284,10 @@ class OpenAIAPI(ModelAPI):
  params["logprobs"] = config.logprobs
  if config.top_logprobs is not None:
  params["top_logprobs"] = config.top_logprobs
- if tools and config.parallel_tool_calls is not None:
+ if tools and config.parallel_tool_calls is not None and not self.is_o1():
  params["parallel_tool_calls"] = config.parallel_tool_calls
+ if config.reasoning_effort is not None and self.is_o1_full():
+ params["reasoning_effort"] = config.reasoning_effort

  return params

@@ -291,14 +306,23 @@ class OpenAIAPI(ModelAPI):


  async def as_openai_chat_messages(
- messages: list[ChatMessage],
+ messages: list[ChatMessage], o1_full: bool
  ) -> list[ChatCompletionMessageParam]:
- return [await openai_chat_message(message) for message in messages]
+ return [await openai_chat_message(message, o1_full) for message in messages]


- async def openai_chat_message(message: ChatMessage) -> ChatCompletionMessageParam:
+ async def openai_chat_message(
+ message: ChatMessage, o1_full: bool
+ ) -> ChatCompletionMessageParam:
  if message.role == "system":
- return ChatCompletionSystemMessageParam(role=message.role, content=message.text)
+ if o1_full:
+ return ChatCompletionDeveloperMessageParam(
+ role="developer", content=message.text
+ )
+ else:
+ return ChatCompletionSystemMessageParam(
+ role=message.role, content=message.text
+ )
  elif message.role == "user":
  return ChatCompletionUserMessageParam(
  role=message.role,
inspect_ai/model/_providers/providers.py CHANGED
@@ -242,7 +242,7 @@ def mockllm() -> type[ModelAPI]:
  def validate_openai_client(feature: str) -> None:
  FEATURE = feature
  PACKAGE = "openai"
- MIN_VERSION = "1.45.0"
+ MIN_VERSION = "1.58.1"

  # verify we have the package
  try:
inspect_ai/model/_providers/vllm.py CHANGED
@@ -75,7 +75,7 @@ class VLLMAPI(ModelAPI):
  def collect_model_arg(name: str) -> Any | None:
  nonlocal model_args
  value = model_args.get(name, None)
- if value:
+ if value is not None:
  model_args.pop(name)
  return value

inspect_ai/model/_render.py CHANGED
@@ -1,8 +1,7 @@
  from rich.console import RenderableType

- from inspect_ai._util.format import format_function_call
- from inspect_ai._util.transcript import transcript_markdown
  from inspect_ai.tool._tool_call import ToolCall
+ from inspect_ai.tool._tool_transcript import transcript_tool_call

  from ._chat_message import ChatMessage, ChatMessageAssistant, ChatMessageTool

@@ -17,8 +16,10 @@ def messages_preceding_assistant(messages: list[ChatMessage]) -> list[ChatMessag
  return list(reversed(preceding))


- def render_tool_calls(tool_calls: list[ToolCall]) -> RenderableType:
- formatted_calls: list[str] = []
+ def render_tool_calls(tool_calls: list[ToolCall]) -> list[RenderableType]:
+ formatted_calls: list[RenderableType] = []
+
  for call in tool_calls:
- formatted_calls.append(format_function_call(call.function, call.arguments))
- return transcript_markdown("```python\n" + "\n\n".join(formatted_calls) + "\n```\n")
+ formatted_calls.extend(transcript_tool_call(call))
+
+ return formatted_calls
inspect_ai/model/_trace.py CHANGED
@@ -42,7 +42,7 @@ def trace_assistant_message(
  # print tool calls
  if message.tool_calls:
  content.append(Text())
- content.append(render_tool_calls(message.tool_calls))
+ content.extend(render_tool_calls(message.tool_calls))

  # print the assistant message
  trace_panel(title="Assistant", content=content)
inspect_ai/solver/_basic_agent.py CHANGED
@@ -54,6 +54,7 @@ def basic_agent(
  max_attempts: int = 1,
  message_limit: int | None = None,
  token_limit: int | None = None,
+ max_tool_output: int | None = None,
  score_value: ValueToFloat | None = None,
  incorrect_message: str
  | Callable[[TaskState, list[Score]], str] = DEFAULT_INCORRECT_MESSAGE,
@@ -87,6 +88,8 @@ def basic_agent(
  If not specified, will use limit_messages defined for the task. If there is none
  defined for the task, 50 will be used as a default.
  token_limit (int | None): Limit on tokens used in sample before terminating agent.
+ max_tool_output (int | None): Maximum output length (in bytes).
+ Defaults to max_tool_output from active GenerateConfig.
  score_value (ValueToFloat): Function used to extract float from scores (defaults
  to standard value_to_float())
  incorrect_message (str | Callable[[TaskState, list[Score]], str]): User message reply for an
@@ -182,7 +185,9 @@ def basic_agent(
  # resolve tools calls (if any)
  if state.output.message.tool_calls:
  # call tool functions
- tool_results = await call_tools(state.output.message, state.tools)
+ tool_results = await call_tools(
+ state.output.message, state.tools, max_output=max_tool_output
+ )
  state.messages.extend(tool_results)

  # was an answer submitted?
@@ -194,11 +199,13 @@ def basic_agent(
  # exit if we are at max_attempts
  attempts += 1
  if attempts >= max_attempts:
+ state.completed = True
  break

  # exit if the submission is successful
  answer_scores = await score(state)
  if score_value_fn(answer_scores[0].value) == 1.0:
+ state.completed = True
  break

  # otherwise notify the model that it was incorrect and continue
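A minimal sketch of the new basic_agent parameter; the dataset name is a placeholder and the surrounding task is illustrative:

    from inspect_ai import Task, task
    from inspect_ai.dataset import example_dataset
    from inspect_ai.scorer import includes
    from inspect_ai.solver import basic_agent
    from inspect_ai.tool import bash

    @task
    def ctf():
        return Task(
            dataset=example_dataset("intercode_ctf"),  # placeholder dataset
            solver=basic_agent(
                tools=[bash(timeout=180)],
                max_attempts=3,
                max_tool_output=16 * 1024,  # truncate tool output beyond 16KB
            ),
            scorer=includes(),
            sandbox="docker",
        )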
inspect_ai/tool/_tool_transcript.py ADDED
@@ -0,0 +1,28 @@
+ from pydantic import JsonValue
+ from rich.console import RenderableType
+ from rich.text import Text
+ from typing_extensions import Protocol
+
+ from inspect_ai._util.transcript import transcript_function, transcript_markdown
+
+ from ._tool_call import ToolCallContent
+
+
+ class TranscriptToolCall(Protocol):
+ function: str
+ arguments: dict[str, JsonValue]
+ view: ToolCallContent | None
+
+
+ def transcript_tool_call(call: TranscriptToolCall) -> list[RenderableType]:
+ content: list[RenderableType] = []
+ if call.view:
+ if call.view.title:
+ content.append(Text.from_markup(f"[bold]{call.view.title}[/bold]\n"))
+ if call.view.format == "markdown":
+ content.append(transcript_markdown(call.view.content))
+ else:
+ content.append(call.view.content)
+ else:
+ content.append(transcript_function(call.function, call.arguments))
+ return content
inspect_ai/util/_sandbox/context.py CHANGED
@@ -109,7 +109,7 @@ def raise_no_sandbox() -> NoReturn:


  async def init_sandbox_environments_sample(
- type: str,
+ sandboxenv_type: type[SandboxEnvironment],
  task_name: str,
  config: SandboxEnvironmentConfigType | None,
  files: dict[str, bytes],
@@ -117,7 +117,6 @@ async def init_sandbox_environments_sample(
  metadata: dict[str, Any],
  ) -> dict[str, SandboxEnvironment]:
  # get setup and cleanup functions
- sandboxenv_type = registry_find_sandboxenv(type)
  sample_init = cast(SampleInit, getattr(sandboxenv_type, "sample_init"))
  sample_cleanup = cast(SampleCleanup, getattr(sandboxenv_type, "sample_cleanup"))

inspect_ai/util/_sandbox/docker/config.py CHANGED
@@ -2,8 +2,6 @@ import os
  from logging import getLogger
  from pathlib import Path

- import aiofiles
-
  logger = getLogger(__name__)


@@ -17,7 +15,7 @@ CONFIG_FILES = [
  DOCKERFILE = "Dockerfile"


- async def resolve_compose_file(parent: str = "") -> str:
+ def resolve_compose_file(parent: str = "") -> str:
  # existing compose file provides all the config we need
  compose = find_compose_file(parent)
  if compose is not None:
@@ -29,11 +27,11 @@ async def resolve_compose_file(parent: str = "") -> str:

  # dockerfile just needs a compose.yaml synthesized
  elif has_dockerfile(parent):
- return await auto_compose_file(COMPOSE_DOCKERFILE_YAML, parent)
+ return auto_compose_file(COMPOSE_DOCKERFILE_YAML, parent)

  # otherwise provide a generic python container
  else:
- return await auto_compose_file(COMPOSE_GENERIC_YAML, parent)
+ return auto_compose_file(COMPOSE_GENERIC_YAML, parent)


  def find_compose_file(parent: str = "") -> str | None:
@@ -59,9 +57,9 @@ def is_auto_compose_file(file: str) -> bool:
  return os.path.basename(file) == AUTO_COMPOSE_YAML


- async def ensure_auto_compose_file(file: str | None) -> None:
+ def ensure_auto_compose_file(file: str | None) -> None:
  if file is not None and is_auto_compose_file(file) and not os.path.exists(file):
- await resolve_compose_file(os.path.dirname(file))
+ resolve_compose_file(os.path.dirname(file))


  def safe_cleanup_auto_compose(file: str | None) -> None:
@@ -100,8 +98,8 @@ services:
  """


- async def auto_compose_file(contents: str, parent: str = "") -> str:
+ def auto_compose_file(contents: str, parent: str = "") -> str:
  path = os.path.join(parent, AUTO_COMPOSE_YAML)
- async with aiofiles.open(path, "w", encoding="utf-8") as f:
- await f.write(contents)
+ with open(path, "w", encoding="utf-8") as f:
+ f.write(contents)
  return Path(path).resolve().as_posix()
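For context, these helpers back the docker sandbox's compose file resolution (now synchronous): when a task directory has no compose.yaml, auto_compose_file() synthesizes one from the Dockerfile or falls back to a generic python container. A minimal sketch of a task relying on that behavior; the dataset contents are illustrative:

    from inspect_ai import Task, task
    from inspect_ai.dataset import Sample
    from inspect_ai.scorer import includes
    from inspect_ai.solver import generate, use_tools
    from inspect_ai.tool import bash

    @task
    def sandbox_demo():
        return Task(
            dataset=[Sample(input="Print the word hello with bash.", target="hello")],
            solver=[use_tools([bash()]), generate()],
            scorer=includes(),
            # no compose.yaml / Dockerfile needed: resolve_compose_file() falls
            # back to the generic python container definition
            sandbox="docker",
        )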