inspect-ai 0.3.55__py3-none-any.whl → 0.3.57__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (131) hide show
  1. inspect_ai/__init__.py +1 -0
  2. inspect_ai/_cli/common.py +1 -1
  3. inspect_ai/_cli/trace.py +33 -20
  4. inspect_ai/_display/core/active.py +1 -1
  5. inspect_ai/_display/core/display.py +1 -1
  6. inspect_ai/_display/core/footer.py +1 -1
  7. inspect_ai/_display/core/panel.py +1 -1
  8. inspect_ai/_display/core/progress.py +0 -6
  9. inspect_ai/_display/core/rich.py +1 -1
  10. inspect_ai/_display/rich/display.py +2 -2
  11. inspect_ai/_display/textual/app.py +15 -17
  12. inspect_ai/_display/textual/widgets/clock.py +3 -3
  13. inspect_ai/_display/textual/widgets/samples.py +6 -13
  14. inspect_ai/_eval/context.py +9 -1
  15. inspect_ai/_eval/run.py +16 -11
  16. inspect_ai/_eval/score.py +4 -10
  17. inspect_ai/_eval/task/results.py +5 -4
  18. inspect_ai/_eval/task/run.py +6 -12
  19. inspect_ai/_eval/task/task.py +10 -0
  20. inspect_ai/_util/ansi.py +31 -0
  21. inspect_ai/_util/datetime.py +1 -1
  22. inspect_ai/_util/deprecation.py +1 -1
  23. inspect_ai/_util/format.py +7 -0
  24. inspect_ai/_util/json.py +11 -1
  25. inspect_ai/_util/logger.py +14 -13
  26. inspect_ai/_util/throttle.py +10 -1
  27. inspect_ai/_util/trace.py +79 -47
  28. inspect_ai/_util/transcript.py +37 -4
  29. inspect_ai/_util/vscode.py +51 -0
  30. inspect_ai/_view/notify.py +2 -1
  31. inspect_ai/_view/www/.prettierrc.js +12 -0
  32. inspect_ai/_view/www/App.css +22 -1
  33. inspect_ai/_view/www/dist/assets/index.css +2374 -2
  34. inspect_ai/_view/www/dist/assets/index.js +29752 -24492
  35. inspect_ai/_view/www/log-schema.json +262 -215
  36. inspect_ai/_view/www/package.json +1 -0
  37. inspect_ai/_view/www/src/App.mjs +19 -9
  38. inspect_ai/_view/www/src/Types.mjs +0 -1
  39. inspect_ai/_view/www/src/api/Types.mjs +15 -4
  40. inspect_ai/_view/www/src/api/api-http.mjs +2 -0
  41. inspect_ai/_view/www/src/appearance/Icons.mjs +2 -0
  42. inspect_ai/_view/www/src/components/AsciiCinemaPlayer.mjs +74 -0
  43. inspect_ai/_view/www/src/components/CopyButton.mjs +0 -1
  44. inspect_ai/_view/www/src/components/ExpandablePanel.mjs +2 -2
  45. inspect_ai/_view/www/src/components/FindBand.mjs +5 -4
  46. inspect_ai/_view/www/src/components/HumanBaselineView.mjs +168 -0
  47. inspect_ai/_view/www/src/components/LargeModal.mjs +1 -1
  48. inspect_ai/_view/www/src/components/LightboxCarousel.mjs +217 -0
  49. inspect_ai/_view/www/src/components/MessageContent.mjs +1 -1
  50. inspect_ai/_view/www/src/components/TabSet.mjs +1 -1
  51. inspect_ai/_view/www/src/components/Tools.mjs +28 -5
  52. inspect_ai/_view/www/src/components/VirtualList.mjs +15 -17
  53. inspect_ai/_view/www/src/log/remoteLogFile.mjs +2 -1
  54. inspect_ai/_view/www/src/navbar/Navbar.mjs +44 -32
  55. inspect_ai/_view/www/src/samples/SampleDisplay.mjs +1 -2
  56. inspect_ai/_view/www/src/samples/SampleList.mjs +35 -4
  57. inspect_ai/_view/www/src/samples/SampleScoreView.mjs +13 -2
  58. inspect_ai/_view/www/src/samples/SampleScores.mjs +11 -2
  59. inspect_ai/_view/www/src/samples/SamplesDescriptor.mjs +238 -178
  60. inspect_ai/_view/www/src/samples/SamplesTab.mjs +4 -2
  61. inspect_ai/_view/www/src/samples/tools/SampleFilter.mjs +5 -5
  62. inspect_ai/_view/www/src/samples/tools/SelectScorer.mjs +7 -0
  63. inspect_ai/_view/www/src/samples/tools/SortFilter.mjs +3 -3
  64. inspect_ai/_view/www/src/samples/transcript/ModelEventView.mjs +3 -2
  65. inspect_ai/_view/www/src/samples/transcript/ToolEventView.mjs +1 -1
  66. inspect_ai/_view/www/src/samples/transcript/TranscriptView.mjs +1 -0
  67. inspect_ai/_view/www/src/samples/transcript/state/StateEventRenderers.mjs +56 -0
  68. inspect_ai/_view/www/src/samples/transcript/state/StateEventView.mjs +17 -5
  69. inspect_ai/_view/www/src/types/asciicinema-player.d.ts +26 -0
  70. inspect_ai/_view/www/src/types/log.d.ts +28 -20
  71. inspect_ai/_view/www/src/workspace/WorkSpace.mjs +1 -1
  72. inspect_ai/_view/www/yarn.lock +44 -0
  73. inspect_ai/approval/_apply.py +4 -0
  74. inspect_ai/approval/_human/panel.py +5 -8
  75. inspect_ai/dataset/_dataset.py +51 -10
  76. inspect_ai/dataset/_util.py +31 -3
  77. inspect_ai/log/__init__.py +2 -0
  78. inspect_ai/log/_log.py +30 -2
  79. inspect_ai/log/_recorders/eval.py +2 -0
  80. inspect_ai/model/_call_tools.py +31 -7
  81. inspect_ai/model/_chat_message.py +3 -0
  82. inspect_ai/model/_model.py +42 -1
  83. inspect_ai/model/_providers/anthropic.py +4 -0
  84. inspect_ai/model/_providers/google.py +24 -6
  85. inspect_ai/model/_providers/openai.py +17 -3
  86. inspect_ai/model/_providers/openai_o1.py +10 -12
  87. inspect_ai/model/_render.py +9 -2
  88. inspect_ai/scorer/_metric.py +12 -1
  89. inspect_ai/solver/__init__.py +2 -0
  90. inspect_ai/solver/_human_agent/agent.py +83 -0
  91. inspect_ai/solver/_human_agent/commands/__init__.py +36 -0
  92. inspect_ai/solver/_human_agent/commands/clock.py +70 -0
  93. inspect_ai/solver/_human_agent/commands/command.py +59 -0
  94. inspect_ai/solver/_human_agent/commands/instructions.py +74 -0
  95. inspect_ai/solver/_human_agent/commands/note.py +42 -0
  96. inspect_ai/solver/_human_agent/commands/score.py +80 -0
  97. inspect_ai/solver/_human_agent/commands/status.py +62 -0
  98. inspect_ai/solver/_human_agent/commands/submit.py +151 -0
  99. inspect_ai/solver/_human_agent/install.py +222 -0
  100. inspect_ai/solver/_human_agent/panel.py +252 -0
  101. inspect_ai/solver/_human_agent/service.py +45 -0
  102. inspect_ai/solver/_human_agent/state.py +55 -0
  103. inspect_ai/solver/_human_agent/view.py +24 -0
  104. inspect_ai/solver/_task_state.py +28 -2
  105. inspect_ai/tool/_tool.py +10 -2
  106. inspect_ai/tool/_tool_info.py +2 -1
  107. inspect_ai/tool/_tools/_web_browser/_resources/dm_env_servicer.py +9 -9
  108. inspect_ai/tool/_tools/_web_browser/_web_browser.py +16 -13
  109. inspect_ai/util/__init__.py +12 -4
  110. inspect_ai/{_util/display.py → util/_display.py} +6 -0
  111. inspect_ai/util/_panel.py +31 -9
  112. inspect_ai/util/_sandbox/__init__.py +0 -3
  113. inspect_ai/util/_sandbox/context.py +5 -1
  114. inspect_ai/util/_sandbox/docker/compose.py +17 -13
  115. inspect_ai/util/_sandbox/docker/docker.py +9 -6
  116. inspect_ai/util/_sandbox/docker/internal.py +1 -1
  117. inspect_ai/util/_sandbox/docker/util.py +3 -2
  118. inspect_ai/util/_sandbox/environment.py +6 -5
  119. inspect_ai/util/_sandbox/local.py +1 -1
  120. inspect_ai/util/_sandbox/self_check.py +18 -18
  121. inspect_ai/util/_sandbox/service.py +22 -7
  122. inspect_ai/util/_store.py +7 -8
  123. inspect_ai/util/_store_model.py +110 -0
  124. inspect_ai/util/_subprocess.py +3 -3
  125. inspect_ai/util/_throttle.py +32 -0
  126. {inspect_ai-0.3.55.dist-info → inspect_ai-0.3.57.dist-info}/METADATA +3 -3
  127. {inspect_ai-0.3.55.dist-info → inspect_ai-0.3.57.dist-info}/RECORD +131 -108
  128. {inspect_ai-0.3.55.dist-info → inspect_ai-0.3.57.dist-info}/WHEEL +1 -1
  129. {inspect_ai-0.3.55.dist-info → inspect_ai-0.3.57.dist-info}/LICENSE +0 -0
  130. {inspect_ai-0.3.55.dist-info → inspect_ai-0.3.57.dist-info}/entry_points.txt +0 -0
  131. {inspect_ai-0.3.55.dist-info → inspect_ai-0.3.57.dist-info}/top_level.txt +0 -0
@@ -23,6 +23,7 @@ from ._log import (
23
23
  EvalRevision,
24
24
  EvalSample,
25
25
  EvalSampleReductions,
26
+ EvalSampleScore,
26
27
  EvalScore,
27
28
  EvalSpec,
28
29
  EvalStats,
@@ -60,6 +61,7 @@ __all__ = [
60
61
  "EvalResults",
61
62
  "EvalRevision",
62
63
  "EvalSample",
64
+ "EvalSampleScore",
63
65
  "EvalSampleReductions",
64
66
  "EvalScore",
65
67
  "EvalSpec",
inspect_ai/log/_log.py CHANGED
@@ -16,6 +16,7 @@ from inspect_ai._util.constants import CONSOLE_DISPLAY_WIDTH, PKG_NAME
16
16
  from inspect_ai._util.error import EvalError, exception_message
17
17
  from inspect_ai._util.logger import warn_once
18
18
  from inspect_ai.approval._policy import ApprovalPolicyConfig
19
+ from inspect_ai.dataset._dataset import MT, metadata_as
19
20
  from inspect_ai.model import (
20
21
  ChatMessage,
21
22
  GenerateConfig,
@@ -23,8 +24,9 @@ from inspect_ai.model import (
23
24
  ModelUsage,
24
25
  )
25
26
  from inspect_ai.scorer import Score
26
- from inspect_ai.scorer._metric import SampleScore
27
27
  from inspect_ai.util._sandbox.environment import SandboxEnvironmentSpec
28
+ from inspect_ai.util._store import Store
29
+ from inspect_ai.util._store_model import SMT
28
30
 
29
31
  from ._transcript import Event
30
32
 
@@ -159,9 +161,31 @@ class EvalSample(BaseModel):
159
161
  metadata: dict[str, Any]
160
162
  """Additional sample metadata."""
161
163
 
164
+ def metadata_as(self, metadata_cls: Type[MT]) -> MT:
165
+ """Pydantic model interface to metadata.
166
+
167
+ Args:
168
+ metadata_cls: Pydantic model type
169
+
170
+ Returns:
171
+ BaseModel: Instance of metadata_cls bound to sample metadata.
172
+ """
173
+ return metadata_as(self.metadata, metadata_cls)
174
+
162
175
  store: dict[str, Any] = Field(default_factory=dict)
163
176
  """State at end of sample execution."""
164
177
 
178
+ def store_as(self, model_cls: Type[SMT]) -> SMT:
179
+ """Pydantic model interface to the store.
180
+
181
+ Args:
182
+ model_cls: Pydantic model type (must derive from StoreModel)
183
+
184
+ Returns:
185
+ StoreModel: Instance of model_cls bound to sample store data.
186
+ """
187
+ return model_cls(store=Store(self.store))
188
+
165
189
  events: list[Event] = Field(default_factory=list)
166
190
  """Events that occurred during sample execution."""
167
191
 
@@ -301,6 +325,10 @@ class EvalScore(BaseModel):
301
325
  """Additional scorer metadata."""
302
326
 
303
327
 
328
+ class EvalSampleScore(Score):
329
+ sample_id: str | int | None = Field(default=None)
330
+
331
+
304
332
  class EvalSampleReductions(BaseModel):
305
333
  scorer: str
306
334
  """Name the of scorer"""
@@ -308,7 +336,7 @@ class EvalSampleReductions(BaseModel):
308
336
  reducer: str | None = Field(default=None)
309
337
  """Name the of reducer"""
310
338
 
311
- samples: list[SampleScore]
339
+ samples: list[EvalSampleScore]
312
340
  """List of reduced scores"""
313
341
 
314
342
 
@@ -252,6 +252,8 @@ def text_inputs(inputs: str | list[ChatMessage]) -> str | list[ChatMessage]:
252
252
  filtered_content.append(ContentText(text="(Image)"))
253
253
  message.content = filtered_content
254
254
  input.append(message)
255
+ else:
256
+ input.append(message)
255
257
 
256
258
  return input
257
259
  else:
@@ -1,15 +1,20 @@
1
1
  import asyncio
2
2
  import inspect
3
+ import types
3
4
  from dataclasses import is_dataclass
4
5
  from logging import getLogger
5
6
  from textwrap import dedent
7
+ from types import UnionType
6
8
  from typing import (
7
9
  Any,
8
10
  Callable,
9
11
  Dict,
10
12
  List,
11
13
  NamedTuple,
14
+ Optional,
15
+ Tuple,
12
16
  Type,
17
+ Union,
13
18
  get_args,
14
19
  get_origin,
15
20
  get_type_hints,
@@ -25,10 +30,7 @@ from inspect_ai._util.text import truncate_string_to_bytes
25
30
  from inspect_ai._util.trace import trace_action
26
31
  from inspect_ai.model._trace import trace_tool_mesage
27
32
  from inspect_ai.tool import Tool, ToolCall, ToolError, ToolInfo
28
- from inspect_ai.tool._tool import (
29
- ToolApprovalError,
30
- ToolParsingError,
31
- )
33
+ from inspect_ai.tool._tool import ToolApprovalError, ToolParsingError
32
34
  from inspect_ai.tool._tool_call import ToolCallContent, ToolCallError
33
35
  from inspect_ai.tool._tool_def import ToolDef, tool_defs
34
36
  from inspect_ai.tool._tool_info import parse_docstring
@@ -118,10 +120,12 @@ async def call_tools(
118
120
  # massage result, leave list[Content] alone, convert all other
119
121
  # types to string as that is what the model APIs accept
120
122
  truncated: tuple[int, int] | None = None
121
- if isinstance(result, list) and (
123
+ if isinstance(result, ContentText | ContentImage):
124
+ content: str | list[Content] = [result]
125
+ elif isinstance(result, list) and (
122
126
  isinstance(result[0], ContentText | ContentImage)
123
127
  ):
124
- content: str | list[Content] = result
128
+ content = result
125
129
  else:
126
130
  content = str(result)
127
131
 
@@ -266,6 +270,16 @@ def disable_parallel_tools(
266
270
  return False
267
271
 
268
272
 
273
+ def type_hint_includes_none(type_hint: Type[Any] | None) -> bool:
274
+ origin = get_origin(type_hint)
275
+
276
+ if origin in {Union, UnionType}:
277
+ return type(None) in get_args(type_hint)
278
+ elif origin is Optional:
279
+ return True
280
+ return False
281
+
282
+
269
283
  def tool_params(input: dict[str, Any], func: Callable[..., Any]) -> dict[str, Any]:
270
284
  # parse function typeinfo
271
285
  signature = inspect.signature(func)
@@ -294,7 +308,7 @@ def tool_params(input: dict[str, Any], func: Callable[..., Any]) -> dict[str, An
294
308
  # yield parameter (fail if not passed and there is no default)
295
309
  if param_name in input:
296
310
  params[param_name] = tool_param(type_hint, input.get(param_name))
297
- elif param.default is not None:
311
+ elif param.default is not None or type_hint_includes_none(type_hint):
298
312
  params[param_name] = param.default
299
313
  else:
300
314
  raise ToolParsingError(
@@ -337,11 +351,21 @@ def tool_param(type_hint: Type[Any], input: Any) -> Any:
337
351
  return [tool_param(args[0], x) for x in input]
338
352
  else:
339
353
  return input
354
+ elif origin is tuple or origin is Tuple:
355
+ if args:
356
+ return tuple([tool_param(args[0], x) for x in input])
357
+ else:
358
+ return tuple(input)
340
359
  elif origin is dict or origin is Dict:
341
360
  if args and len(args) > 1:
342
361
  return {k: tool_param(args[1], v) for k, v in input}
343
362
  else:
344
363
  return input
364
+ elif origin is Union or origin is types.UnionType:
365
+ if args[1] is type(None):
366
+ return tool_param(args[0], input)
367
+ else:
368
+ return input
345
369
  else:
346
370
  return input
347
371
 
@@ -74,6 +74,9 @@ class ChatMessageUser(ChatMessageBase):
74
74
  role: Literal["user"] = Field(default="user")
75
75
  """Conversation role."""
76
76
 
77
+ tool_call_id: str | None = Field(default=None)
78
+ """ID of tool call this message has the content payload for."""
79
+
77
80
 
78
81
  class ChatMessageAssistant(ChatMessageBase):
79
82
  role: Literal["assistant"] = Field(default="assistant")
@@ -19,7 +19,7 @@ from tenacity import (
19
19
  )
20
20
 
21
21
  from inspect_ai._util.constants import DEFAULT_MAX_CONNECTIONS
22
- from inspect_ai._util.content import ContentText
22
+ from inspect_ai._util.content import Content, ContentImage, ContentText
23
23
  from inspect_ai._util.hooks import init_hooks, override_api_key, send_telemetry
24
24
  from inspect_ai._util.platform import platform_init
25
25
  from inspect_ai._util.registry import (
@@ -40,6 +40,7 @@ from ._chat_message import (
40
40
  ChatMessage,
41
41
  ChatMessageAssistant,
42
42
  ChatMessageSystem,
43
+ ChatMessageTool,
43
44
  ChatMessageUser,
44
45
  )
45
46
  from ._generate_config import (
@@ -163,6 +164,10 @@ class ModelAPI(abc.ABC):
163
164
  """Any tool use in a message stream means that tools must be passed."""
164
165
  return False
165
166
 
167
+ def tool_result_images(self) -> bool:
168
+ """Tool results can containe images"""
169
+ return False
170
+
166
171
 
167
172
  class Model:
168
173
  """Model interface."""
@@ -291,6 +296,11 @@ class Model:
291
296
  tools = []
292
297
  tool_choice = "none"
293
298
 
299
+ # break tool image content out into user messages if the model doesn't
300
+ # support tools returning images
301
+ if not self.api.tool_result_images():
302
+ input = tool_result_images_as_user_message(input)
303
+
294
304
  # optionally collapse *consecutive* messages into one -
295
305
  # (some apis e.g. anthropic require this)
296
306
  if self.api.collapse_user_messages():
@@ -693,6 +703,37 @@ def simple_input_messages(
693
703
  return messages
694
704
 
695
705
 
706
+ def tool_result_images_as_user_message(
707
+ messages: list[ChatMessage],
708
+ ) -> list[ChatMessage]:
709
+ return functools.reduce(tool_result_images_reducer, messages, [])
710
+
711
+
712
+ def tool_result_images_reducer(
713
+ messages: list[ChatMessage],
714
+ message: ChatMessage,
715
+ ) -> list[ChatMessage]:
716
+ # append the message
717
+ messages.append(message)
718
+
719
+ # if there are tool result images, pull them out into a ChatUserMessage
720
+ if isinstance(message, ChatMessageTool) and isinstance(message.content, list):
721
+ user_content: list[Content] = []
722
+ for i in range(0, len(message.content)):
723
+ if isinstance(message.content[i], ContentImage):
724
+ user_content.append(message.content[i])
725
+ message.content[i] = ContentText(
726
+ text="Image content is in the message below."
727
+ )
728
+ if len(user_content) > 0:
729
+ messages.append(
730
+ ChatMessageUser(content=user_content, tool_call_id=message.tool_call_id)
731
+ )
732
+
733
+ # return messages
734
+ return messages
735
+
736
+
696
737
  # Functions to reduce consecutive user messages to a single user message -> required for some models
697
738
  def collapse_consecutive_user_messages(
698
739
  messages: list[ChatMessage],
@@ -229,6 +229,10 @@ class AnthropicAPI(ModelAPI):
229
229
  def tools_required(self) -> bool:
230
230
  return True
231
231
 
232
+ @override
233
+ def tool_result_images(self) -> bool:
234
+ return True
235
+
232
236
  # convert some common BadRequestError states into 'refusal' model output
233
237
  def handle_bad_request(self, ex: BadRequestError) -> ModelOutput | None:
234
238
  error = exception_message(ex).lower()
@@ -194,7 +194,9 @@ class GoogleAPI(ModelAPI):
194
194
  model=self.model_name, content=ex.message, stop_reason="model_length"
195
195
  )
196
196
  else:
197
- raise ex
197
+ return ModelOutput.from_content(
198
+ model=self.model_name, content=ex.message, stop_reason="unknown"
199
+ )
198
200
 
199
201
  @override
200
202
  def is_rate_limit(self, ex: BaseException) -> bool:
@@ -408,25 +410,34 @@ def chat_tools(tools: list[ToolInfo]) -> list[Tool]:
408
410
  # https://ai.google.dev/gemini-api/tutorials/extract_structured_data#define_the_schema
409
411
 
410
412
 
411
- def schema_from_param(param: ToolParam | ToolParams) -> Schema:
413
+ def schema_from_param(param: ToolParam | ToolParams, nullable: bool = False) -> Schema:
412
414
  if isinstance(param, ToolParams):
413
415
  param = ToolParam(
414
416
  type=param.type, properties=param.properties, required=param.required
415
417
  )
416
418
 
417
419
  if param.type == "number":
418
- return Schema(type=Type.NUMBER, description=param.description)
420
+ return Schema(
421
+ type=Type.NUMBER, description=param.description, nullable=nullable
422
+ )
419
423
  elif param.type == "integer":
420
- return Schema(type=Type.INTEGER, description=param.description)
424
+ return Schema(
425
+ type=Type.INTEGER, description=param.description, nullable=nullable
426
+ )
421
427
  elif param.type == "boolean":
422
- return Schema(type=Type.BOOLEAN, description=param.description)
428
+ return Schema(
429
+ type=Type.BOOLEAN, description=param.description, nullable=nullable
430
+ )
423
431
  elif param.type == "string":
424
- return Schema(type=Type.STRING, description=param.description)
432
+ return Schema(
433
+ type=Type.STRING, description=param.description, nullable=nullable
434
+ )
425
435
  elif param.type == "array":
426
436
  return Schema(
427
437
  type=Type.ARRAY,
428
438
  description=param.description,
429
439
  items=schema_from_param(param.items) if param.items else None,
440
+ nullable=nullable,
430
441
  )
431
442
  elif param.type == "object":
432
443
  return Schema(
@@ -436,7 +447,14 @@ def schema_from_param(param: ToolParam | ToolParams) -> Schema:
436
447
  if param.properties is not None
437
448
  else None,
438
449
  required=param.required,
450
+ nullable=nullable,
439
451
  )
452
+ # convert unions to optional params if the second type is 'null'
453
+ elif param.anyOf:
454
+ if len(param.anyOf) == 2 and param.anyOf[1].type == "null":
455
+ return schema_from_param(param.anyOf[0], nullable=True)
456
+ else:
457
+ return Schema(type=Type.TYPE_UNSPECIFIED)
440
458
  else:
441
459
  return Schema(type=Type.TYPE_UNSPECIFIED)
442
460
 
@@ -51,6 +51,7 @@ from .._model_output import (
51
51
  Logprobs,
52
52
  ModelOutput,
53
53
  ModelUsage,
54
+ StopReason,
54
55
  )
55
56
  from .openai_o1 import generate_o1
56
57
  from .util import (
@@ -262,7 +263,10 @@ class OpenAIAPI(ModelAPI):
262
263
  model=self.model_name,
263
264
  )
264
265
  if config.max_tokens is not None:
265
- params["max_tokens"] = config.max_tokens
266
+ if self.is_o1():
267
+ params["max_completion_tokens"] = config.max_tokens
268
+ else:
269
+ params["max_tokens"] = config.max_tokens
266
270
  if config.frequency_penalty is not None:
267
271
  params["frequency_penalty"] = config.frequency_penalty
268
272
  if config.stop_seqs is not None:
@@ -303,13 +307,23 @@ class OpenAIAPI(ModelAPI):
303
307
 
304
308
  # convert some well known bad request errors into ModelOutput
305
309
  def handle_bad_request(self, e: BadRequestError) -> ModelOutput:
306
- if e.status_code == 400 and e.code == "context_length_exceeded":
310
+ if e.status_code == 400:
311
+ # extract message
307
312
  if isinstance(e.body, dict) and "message" in e.body.keys():
308
313
  content = str(e.body.get("message"))
309
314
  else:
310
315
  content = e.message
316
+
317
+ # narrow stop_reason
318
+ if e.code == "context_length_exceeded":
319
+ stop_reason: StopReason = "model_length"
320
+ elif e.code == "invalid_prompt":
321
+ stop_reason = "content_filter"
322
+ else:
323
+ stop_reason = "unknown"
324
+
311
325
  return ModelOutput.from_content(
312
- model=self.model_name, content=content, stop_reason="model_length"
326
+ model=self.model_name, content=content, stop_reason=stop_reason
313
327
  )
314
328
  else:
315
329
  raise e
@@ -25,7 +25,7 @@ from inspect_ai.model import (
25
25
  from inspect_ai.tool import ToolCall, ToolInfo
26
26
 
27
27
  from .._model_call import ModelCall
28
- from .._model_output import ModelUsage
28
+ from .._model_output import ModelUsage, StopReason
29
29
  from .._providers.util import (
30
30
  ChatAPIHandler,
31
31
  ChatAPIMessage,
@@ -48,12 +48,6 @@ async def generate_o1(
48
48
  # create chatapi handler
49
49
  handler = O1PreviewChatAPIHandler()
50
50
 
51
- # map max_tokens => max_completion_tokens
52
- max_tokens = params.get("max_tokens", None)
53
- if max_tokens:
54
- params["max_completion_tokens"] = max_tokens
55
- del params["max_tokens"]
56
-
57
51
  # call model
58
52
  request = dict(
59
53
  model=model,
@@ -89,12 +83,16 @@ async def generate_o1(
89
83
 
90
84
 
91
85
  def handle_bad_request(model: str, ex: BadRequestError) -> ModelOutput:
92
- if ex.code == "invalid_prompt":
93
- return ModelOutput.from_content(
94
- model=model, content=str(ex), stop_reason="content_filter"
95
- )
86
+ if ex.code == "context_length_exceeded":
87
+ stop_reason: StopReason = "model_length"
88
+ elif ex.code == "invalid_prompt":
89
+ stop_reason = "content_filter"
96
90
  else:
97
- raise ex
91
+ stop_reason = "unknown"
92
+
93
+ return ModelOutput.from_content(
94
+ model=model, content=str(ex), stop_reason=stop_reason
95
+ )
98
96
 
99
97
 
100
98
  def chat_messages(
@@ -3,13 +3,20 @@ from rich.console import RenderableType
3
3
  from inspect_ai.tool._tool_call import ToolCall
4
4
  from inspect_ai.tool._tool_transcript import transcript_tool_call
5
5
 
6
- from ._chat_message import ChatMessage, ChatMessageAssistant, ChatMessageTool
6
+ from ._chat_message import (
7
+ ChatMessage,
8
+ ChatMessageAssistant,
9
+ ChatMessageTool,
10
+ ChatMessageUser,
11
+ )
7
12
 
8
13
 
9
14
  def messages_preceding_assistant(messages: list[ChatMessage]) -> list[ChatMessage]:
10
15
  preceding: list[ChatMessage] = []
11
16
  for m in reversed(messages):
12
- if not isinstance(m, ChatMessageTool | ChatMessageAssistant):
17
+ if not isinstance(m, ChatMessageTool | ChatMessageAssistant) and not (
18
+ isinstance(m, ChatMessageUser) and m.tool_call_id
19
+ ):
13
20
  preceding.append(m)
14
21
  else:
15
22
  break
@@ -90,6 +90,13 @@ class Score(BaseModel):
90
90
  """Read the score as a boolean."""
91
91
  return bool(self._as_scalar())
92
92
 
93
+ def as_list(self) -> list[str | int | float | bool]:
94
+ """Read the score as a list."""
95
+ if isinstance(self.value, list):
96
+ return self.value
97
+ else:
98
+ raise ValueError("This score is not a list")
99
+
93
100
  def as_dict(self) -> dict[str, str | int | float | bool | None]:
94
101
  """Read the score as a dictionary."""
95
102
  if isinstance(self.value, dict):
@@ -104,13 +111,17 @@ class Score(BaseModel):
104
111
  raise ValueError("This score is not a scalar")
105
112
 
106
113
 
107
- class SampleScore(Score):
114
+ class SampleScore(BaseModel):
108
115
  """Score for a Sample
109
116
 
110
117
  Args:
118
+ score: Score
111
119
  sample_id: (str | int | None) Unique id of a sample
112
120
  """
113
121
 
122
+ score: Score
123
+ """A score"""
124
+
114
125
  sample_id: str | int | None = Field(default=None)
115
126
  """A sample id"""
116
127
 
@@ -4,6 +4,7 @@ from ._basic_agent import basic_agent
4
4
  from ._chain import chain
5
5
  from ._critique import self_critique
6
6
  from ._fork import fork
7
+ from ._human_agent.agent import human_agent
7
8
  from ._multiple_choice import MultipleChoiceTemplate, multiple_choice
8
9
  from ._plan import Plan, plan
9
10
  from ._prompt import (
@@ -17,6 +18,7 @@ from ._use_tools import use_tools
17
18
 
18
19
  __all__ = [
19
20
  "basic_agent",
21
+ "human_agent",
20
22
  "chain",
21
23
  "fork",
22
24
  "generate",
@@ -0,0 +1,83 @@
1
+ import asyncio
2
+
3
+ from inspect_ai.util import display_type, input_panel, sandbox
4
+
5
+ from .._solver import Generate, Solver, solver
6
+ from .._task_state import TaskState
7
+ from .commands import human_agent_commands
8
+ from .install import install_human_agent
9
+ from .panel import HumanAgentPanel
10
+ from .service import run_human_agent_service
11
+ from .view import ConsoleView, HumanAgentView
12
+
13
+
14
+ @solver
15
+ def human_agent(
16
+ answer: bool | str = True,
17
+ intermediate_scoring: bool = False,
18
+ record_session: bool = True,
19
+ ) -> Solver:
20
+ """Human solver for agentic tasks that run in a Linux environment.
21
+
22
+ The Human agent solver installs agent task tools in the default
23
+ sandbox and presents the user with both task instructions and
24
+ documentation for the various tools (e.g. `task submit`,
25
+ `task start`, `task stop` `task instructions`, etc.). A human agent panel
26
+ is displayed with instructions for logging in to the sandbox.
27
+
28
+ If the user is running in VS Code with the Inspect extension,
29
+ they will also be presented with links to login to the sandbox
30
+ using a VS Code Window or Terminal.
31
+
32
+ Args:
33
+ answer (bool | str): Is an explicit answer required for this
34
+ task or is it scored based on files in the container? Pass a
35
+ `str` with a regex to validate that the answer matches
36
+ the expected format.
37
+ intermediate_scoring (bool): Allow the human agent to
38
+ check their score while working.
39
+ record_session (bool): Record all user commands and outputs in
40
+ the sandbox bash session.
41
+
42
+ Returns:
43
+ Solver: Human agent solver.
44
+ """
45
+ # we can only run one human agent interaction at a time (use lock to enforce)
46
+ agent_lock = asyncio.Lock()
47
+
48
+ async def solve(state: TaskState, generate: Generate) -> TaskState:
49
+ async with agent_lock:
50
+ # ensure that we have a sandbox to work with
51
+ try:
52
+ connection = await sandbox().connection()
53
+ except ProcessLookupError:
54
+ raise RuntimeError("Human agent must run in a task with a sandbox.")
55
+ except NotImplementedError:
56
+ raise RuntimeError(
57
+ "Human agent must run with a sandbox that supports connections."
58
+ )
59
+
60
+ # helper function to run the agent (called for fullscreen vs. fallback below)
61
+ async def run_human_agent(view: HumanAgentView) -> TaskState:
62
+ # create agent commands
63
+ commands = human_agent_commands(
64
+ state, answer, intermediate_scoring, record_session
65
+ )
66
+
67
+ # install agent tools
68
+ await install_human_agent(state, commands, record_session)
69
+
70
+ # hookup the view ui
71
+ view.connect(connection)
72
+
73
+ # run sandbox service
74
+ return await run_human_agent_service(state, commands, view)
75
+
76
+ # support both fullscreen ui and fallback
77
+ if display_type() == "full":
78
+ async with await input_panel(HumanAgentPanel) as panel:
79
+ return await run_human_agent(panel)
80
+ else:
81
+ return await run_human_agent(ConsoleView())
82
+
83
+ return solve
@@ -0,0 +1,36 @@
1
+ from inspect_ai.solver._task_state import TaskState
2
+
3
+ from .clock import StartCommand, StopCommand
4
+ from .command import HumanAgentCommand
5
+ from .instructions import InstructionsCommand
6
+ from .note import NoteCommand
7
+ from .score import ScoreCommand
8
+ from .status import StatusCommand
9
+ from .submit import SubmitCommand, ValidateCommand
10
+
11
+
12
+ def human_agent_commands(
13
+ state: TaskState,
14
+ answer: bool | str,
15
+ intermediate_scoring: bool,
16
+ record_session: bool,
17
+ ) -> list[HumanAgentCommand]:
18
+ # base submit and validate
19
+ commands = [SubmitCommand(record_session), ValidateCommand(answer)]
20
+
21
+ # optional intermediate scoring
22
+ if intermediate_scoring:
23
+ commands.append(ScoreCommand(state))
24
+
25
+ # remaining commands
26
+ commands.extend(
27
+ [
28
+ NoteCommand(),
29
+ StatusCommand(),
30
+ StartCommand(),
31
+ StopCommand(),
32
+ ]
33
+ )
34
+
35
+ # with instructions (letting it see the other commands)
36
+ return commands + [InstructionsCommand(commands)]