inspect-ai 0.3.93__py3-none-any.whl → 0.3.95__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (115) hide show
  1. inspect_ai/_display/textual/widgets/samples.py +3 -3
  2. inspect_ai/_display/textual/widgets/transcript.py +3 -29
  3. inspect_ai/_eval/loader.py +1 -1
  4. inspect_ai/_eval/task/run.py +21 -12
  5. inspect_ai/_util/answer.py +26 -0
  6. inspect_ai/_util/constants.py +0 -1
  7. inspect_ai/_util/exception.py +4 -0
  8. inspect_ai/_util/hash.py +39 -0
  9. inspect_ai/_util/local_server.py +51 -21
  10. inspect_ai/_util/path.py +22 -0
  11. inspect_ai/_util/trace.py +1 -1
  12. inspect_ai/_util/working.py +4 -0
  13. inspect_ai/_view/www/dist/assets/index.css +23 -22
  14. inspect_ai/_view/www/dist/assets/index.js +517 -204
  15. inspect_ai/_view/www/log-schema.json +375 -0
  16. inspect_ai/_view/www/package.json +1 -1
  17. inspect_ai/_view/www/src/@types/log.d.ts +90 -12
  18. inspect_ai/_view/www/src/app/log-view/navbar/SecondaryBar.tsx +2 -2
  19. inspect_ai/_view/www/src/app/log-view/tabs/SamplesTab.tsx +1 -4
  20. inspect_ai/_view/www/src/app/samples/SamplesTools.tsx +3 -13
  21. inspect_ai/_view/www/src/app/samples/sample-tools/SelectScorer.tsx +45 -48
  22. inspect_ai/_view/www/src/app/samples/sample-tools/filters.ts +16 -15
  23. inspect_ai/_view/www/src/app/samples/sample-tools/sample-filter/SampleFilter.tsx +47 -75
  24. inspect_ai/_view/www/src/app/samples/sample-tools/sample-filter/completions.ts +9 -9
  25. inspect_ai/_view/www/src/app/samples/transcript/SandboxEventView.module.css +2 -1
  26. inspect_ai/_view/www/src/app/samples/transcript/SpanEventView.tsx +174 -0
  27. inspect_ai/_view/www/src/app/samples/transcript/ToolEventView.tsx +8 -8
  28. inspect_ai/_view/www/src/app/samples/transcript/TranscriptView.tsx +12 -2
  29. inspect_ai/_view/www/src/app/samples/transcript/TranscriptVirtualListComponent.module.css +1 -1
  30. inspect_ai/_view/www/src/app/samples/transcript/event/EventPanel.tsx +0 -3
  31. inspect_ai/_view/www/src/app/samples/transcript/transform/fixups.ts +87 -25
  32. inspect_ai/_view/www/src/app/samples/transcript/transform/treeify.ts +229 -17
  33. inspect_ai/_view/www/src/app/samples/transcript/transform/utils.ts +11 -0
  34. inspect_ai/_view/www/src/app/samples/transcript/types.ts +5 -1
  35. inspect_ai/_view/www/src/app/types.ts +12 -2
  36. inspect_ai/_view/www/src/components/ExpandablePanel.module.css +1 -1
  37. inspect_ai/_view/www/src/components/ExpandablePanel.tsx +5 -5
  38. inspect_ai/_view/www/src/state/hooks.ts +19 -3
  39. inspect_ai/_view/www/src/state/logSlice.ts +23 -5
  40. inspect_ai/_view/www/yarn.lock +9 -9
  41. inspect_ai/agent/_as_solver.py +3 -1
  42. inspect_ai/agent/_as_tool.py +6 -4
  43. inspect_ai/agent/_bridge/patch.py +1 -3
  44. inspect_ai/agent/_handoff.py +5 -1
  45. inspect_ai/agent/_react.py +4 -3
  46. inspect_ai/agent/_run.py +6 -1
  47. inspect_ai/agent/_types.py +9 -0
  48. inspect_ai/analysis/__init__.py +0 -0
  49. inspect_ai/analysis/beta/__init__.py +57 -0
  50. inspect_ai/analysis/beta/_dataframe/__init__.py +0 -0
  51. inspect_ai/analysis/beta/_dataframe/columns.py +145 -0
  52. inspect_ai/analysis/beta/_dataframe/evals/__init__.py +0 -0
  53. inspect_ai/analysis/beta/_dataframe/evals/columns.py +132 -0
  54. inspect_ai/analysis/beta/_dataframe/evals/extract.py +23 -0
  55. inspect_ai/analysis/beta/_dataframe/evals/table.py +140 -0
  56. inspect_ai/analysis/beta/_dataframe/events/__init__.py +0 -0
  57. inspect_ai/analysis/beta/_dataframe/events/columns.py +37 -0
  58. inspect_ai/analysis/beta/_dataframe/events/table.py +14 -0
  59. inspect_ai/analysis/beta/_dataframe/extract.py +54 -0
  60. inspect_ai/analysis/beta/_dataframe/messages/__init__.py +0 -0
  61. inspect_ai/analysis/beta/_dataframe/messages/columns.py +60 -0
  62. inspect_ai/analysis/beta/_dataframe/messages/extract.py +21 -0
  63. inspect_ai/analysis/beta/_dataframe/messages/table.py +87 -0
  64. inspect_ai/analysis/beta/_dataframe/record.py +377 -0
  65. inspect_ai/analysis/beta/_dataframe/samples/__init__.py +0 -0
  66. inspect_ai/analysis/beta/_dataframe/samples/columns.py +73 -0
  67. inspect_ai/analysis/beta/_dataframe/samples/extract.py +82 -0
  68. inspect_ai/analysis/beta/_dataframe/samples/table.py +329 -0
  69. inspect_ai/analysis/beta/_dataframe/util.py +157 -0
  70. inspect_ai/analysis/beta/_dataframe/validate.py +171 -0
  71. inspect_ai/dataset/_dataset.py +6 -3
  72. inspect_ai/log/__init__.py +10 -0
  73. inspect_ai/log/_convert.py +4 -9
  74. inspect_ai/log/_file.py +1 -1
  75. inspect_ai/log/_log.py +21 -1
  76. inspect_ai/log/_samples.py +14 -17
  77. inspect_ai/log/_transcript.py +77 -35
  78. inspect_ai/log/_tree.py +118 -0
  79. inspect_ai/model/_call_tools.py +44 -35
  80. inspect_ai/model/_model.py +51 -44
  81. inspect_ai/model/_openai_responses.py +17 -18
  82. inspect_ai/model/_providers/anthropic.py +30 -5
  83. inspect_ai/model/_providers/hf.py +27 -1
  84. inspect_ai/model/_providers/providers.py +1 -1
  85. inspect_ai/model/_providers/sglang.py +8 -2
  86. inspect_ai/model/_providers/vllm.py +6 -2
  87. inspect_ai/scorer/_choice.py +1 -2
  88. inspect_ai/solver/_chain.py +1 -1
  89. inspect_ai/solver/_fork.py +1 -1
  90. inspect_ai/solver/_multiple_choice.py +9 -23
  91. inspect_ai/solver/_plan.py +2 -2
  92. inspect_ai/solver/_task_state.py +7 -3
  93. inspect_ai/solver/_transcript.py +6 -7
  94. inspect_ai/tool/_mcp/_context.py +3 -5
  95. inspect_ai/tool/_mcp/_mcp.py +6 -5
  96. inspect_ai/tool/_mcp/server.py +1 -1
  97. inspect_ai/tool/_tools/_execute.py +4 -1
  98. inspect_ai/tool/_tools/_think.py +1 -1
  99. inspect_ai/tool/_tools/_web_search/__init__.py +3 -0
  100. inspect_ai/tool/_tools/{_web_search.py → _web_search/_google.py} +56 -103
  101. inspect_ai/tool/_tools/_web_search/_tavily.py +77 -0
  102. inspect_ai/tool/_tools/_web_search/_web_search.py +85 -0
  103. inspect_ai/util/__init__.py +4 -0
  104. inspect_ai/util/_anyio.py +11 -0
  105. inspect_ai/util/_collect.py +50 -0
  106. inspect_ai/util/_sandbox/events.py +3 -2
  107. inspect_ai/util/_span.py +58 -0
  108. inspect_ai/util/_subtask.py +27 -42
  109. {inspect_ai-0.3.93.dist-info → inspect_ai-0.3.95.dist-info}/METADATA +8 -1
  110. {inspect_ai-0.3.93.dist-info → inspect_ai-0.3.95.dist-info}/RECORD +114 -82
  111. {inspect_ai-0.3.93.dist-info → inspect_ai-0.3.95.dist-info}/WHEEL +1 -1
  112. inspect_ai/_display/core/group.py +0 -79
  113. {inspect_ai-0.3.93.dist-info → inspect_ai-0.3.95.dist-info}/entry_points.txt +0 -0
  114. {inspect_ai-0.3.93.dist-info → inspect_ai-0.3.95.dist-info}/licenses/LICENSE +0 -0
  115. {inspect_ai-0.3.93.dist-info → inspect_ai-0.3.95.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,118 @@
1
+ from dataclasses import dataclass, field
2
+ from logging import getLogger
3
+ from typing import Iterable, Sequence, TypeAlias
4
+
5
+ from ._transcript import Event, SpanBeginEvent, SpanEndEvent
6
+
7
logger = getLogger(__name__)

# "SpanNode" is a forward reference: the class is defined further down in
# this module, so it has to be written as a string here.
EventNode: TypeAlias = "SpanNode" | Event
"""Node in an event tree (either a span of events or an individual event)."""

EventTree: TypeAlias = list[EventNode]
"""Tree of events (has individual events and event spans)."""
14
+
15
+
16
@dataclass
class SpanNode:
    """A node of an event tree that groups the events recorded within one span."""

    id: str
    """Identifier of this span."""

    parent_id: str | None
    """Identifier of the enclosing span, if any."""

    type: str | None
    """Optional 'type' of the span."""

    name: str
    """Name of the span."""

    begin: SpanBeginEvent
    """Event that opened this span."""

    end: SpanEndEvent | None = field(default=None)
    """Event that closed this span (``None`` when no end event is attached)."""

    children: list[EventNode] = field(default_factory=list)
    """Child events and nested spans belonging to this span."""
40
+
41
+
42
def event_tree(events: Sequence[Event]) -> EventTree:
    """Build a tree representation of a sequence of events.

    Organize events hierarchically into event spans.

    Args:
        events: Sequence of `Event`.

    Returns:
        Event tree.
    """
    # Convert one flat list of (possibly interleaved) events into a *forest*
    # (list of root-level items).

    # Pre-create one node per span so we can attach events no matter when they
    # arrive in the file. A single forward scan guarantees that the order of
    # `children` inside every span reflects the order in which things appeared
    # in the transcript.
    nodes: dict[str, SpanNode] = {
        ev.id: SpanNode(
            id=ev.id, parent_id=ev.parent_id, type=ev.type, name=ev.name, begin=ev
        )
        for ev in events
        if isinstance(ev, SpanBeginEvent)
    }

    roots: list[EventNode] = []

    # Where should an event with `span_id` go? Events with no span id, or
    # with a span id that has no matching begin event, go to the root level.
    def bucket(span_id: str | None) -> list[EventNode]:
        if span_id and span_id in nodes:
            return nodes[span_id].children
        return roots

    # Single pass in original order
    for ev in events:
        if isinstance(ev, SpanBeginEvent):  # span starts
            bucket(ev.parent_id).append(nodes[ev.id])

        elif isinstance(ev, SpanEndEvent):  # span ends
            node = nodes.get(ev.id)
            if node is not None:
                node.end = ev
            else:
                logger.warning("Span end event (id: %s) with no span begin", ev.id)

        else:  # ordinary event
            bucket(ev.span_id).append(ev)

    return roots
91
+
92
+
93
def event_sequence(tree: EventTree) -> Iterable[Event]:
    """Flatten an event tree back into a properly ordered sequence.

    Each span contributes its begin event, then (recursively) its children,
    then its end event when one is attached. Plain events are emitted as-is.

    Args:
        tree: Event tree

    Returns:
        Sequence of events (generated lazily).
    """
    for node in tree:
        if not isinstance(node, SpanNode):
            yield node
            continue

        yield node.begin
        yield from event_sequence(node.children)
        if node.end is not None:
            yield node.end
110
+
111
+
112
def _print_event_tree(tree: EventTree, indent: str = "") -> None:
    """Print an indented outline of an event tree to stdout.

    Spans are printed as ``span (<type>): <name>`` and their children are
    printed one indentation step deeper; other events are printed by their
    ``event`` field.
    """
    for node in tree:
        if not isinstance(node, SpanNode):
            print(f"{indent}{node.event}")
            continue

        print(f"{indent}span ({node.type}): {node.name}")
        _print_event_tree(node.children, indent + " ")
@@ -39,6 +39,7 @@ from inspect_ai._util.content import (
39
39
  ContentText,
40
40
  ContentVideo,
41
41
  )
42
+ from inspect_ai._util.exception import TerminateSampleError
42
43
  from inspect_ai._util.format import format_function_call
43
44
  from inspect_ai._util.logger import warn_once
44
45
  from inspect_ai._util.registry import registry_unqualified_name
@@ -61,6 +62,7 @@ from inspect_ai.tool._tool_params import ToolParams
61
62
  from inspect_ai.util import OutputLimitExceededError
62
63
  from inspect_ai.util._anyio import inner_exception
63
64
  from inspect_ai.util._limit import LimitExceededError, apply_limits
65
+ from inspect_ai.util._span import span
64
66
 
65
67
  from ._chat_message import (
66
68
  ChatMessage,
@@ -109,26 +111,18 @@ async def execute_tools(
109
111
  """
110
112
  message = messages[-1]
111
113
  if isinstance(message, ChatMessageAssistant) and message.tool_calls:
112
- from inspect_ai.log._transcript import (
113
- ToolEvent,
114
- Transcript,
115
- init_transcript,
116
- track_store_changes,
117
- transcript,
118
- )
114
+ from inspect_ai.log._transcript import ToolEvent, transcript
119
115
 
120
116
  tdefs = await tool_defs(tools)
121
117
 
122
118
  async def call_tool_task(
123
119
  call: ToolCall,
120
+ event: ToolEvent,
124
121
  conversation: list[ChatMessage],
125
122
  send_stream: MemoryObjectSendStream[
126
123
  tuple[ExecuteToolsResult, ToolEvent, Exception | None]
127
124
  ],
128
125
  ) -> None:
129
- # create a transript for this call
130
- init_transcript(Transcript(name=call.function))
131
-
132
126
  result: ToolResult = ""
133
127
  messages: list[ChatMessage] = []
134
128
  output: ModelOutput | None = None
@@ -136,15 +130,14 @@ async def execute_tools(
136
130
  tool_error: ToolCallError | None = None
137
131
  tool_exception: Exception | None = None
138
132
  try:
139
- with track_store_changes():
140
- try:
141
- result, messages, output, agent = await call_tool(
142
- tdefs, message.text, call, conversation
143
- )
144
- # unwrap exception group
145
- except Exception as ex:
146
- inner_ex = inner_exception(ex)
147
- raise inner_ex.with_traceback(inner_ex.__traceback__)
133
+ try:
134
+ result, messages, output, agent = await call_tool(
135
+ tdefs, message.text, call, event, conversation
136
+ )
137
+ # unwrap exception group
138
+ except Exception as ex:
139
+ inner_ex = inner_exception(ex)
140
+ raise inner_ex.with_traceback(inner_ex.__traceback__)
148
141
 
149
142
  except TimeoutError:
150
143
  tool_error = ToolCallError(
@@ -227,7 +220,6 @@ async def execute_tools(
227
220
  truncated=truncated,
228
221
  view=call.view,
229
222
  error=tool_error,
230
- events=list(transcript().events),
231
223
  agent=agent,
232
224
  )
233
225
 
@@ -270,7 +262,6 @@ async def execute_tools(
270
262
  internal=call.internal,
271
263
  pending=True,
272
264
  )
273
- transcript()._event(event)
274
265
 
275
266
  # execute the tool call. if the operator cancels the
276
267
  # tool call then synthesize the appropriate message/event
@@ -280,7 +271,7 @@ async def execute_tools(
280
271
 
281
272
  result_exception = None
282
273
  async with anyio.create_task_group() as tg:
283
- tg.start_soon(call_tool_task, call, messages, send_stream)
274
+ tg.start_soon(call_tool_task, call, event, messages, send_stream)
284
275
  event._set_cancel_fn(tg.cancel_scope.cancel)
285
276
  async with receive_stream:
286
277
  (
@@ -306,7 +297,6 @@ async def execute_tools(
306
297
  truncated=None,
307
298
  view=call.view,
308
299
  error=tool_message.error,
309
- events=[],
310
300
  )
311
301
  transcript().info(
312
302
  f"Tool call '{call.function}' was cancelled by operator."
@@ -326,7 +316,6 @@ async def execute_tools(
326
316
  result=result_event.result,
327
317
  truncated=result_event.truncated,
328
318
  error=result_event.error,
329
- events=result_event.events,
330
319
  waiting_time=waiting_time_end - waiting_time_start,
331
320
  agent=result_event.agent,
332
321
  failed=True if result_exception else None,
@@ -347,19 +336,34 @@ async def execute_tools(
347
336
 
348
337
 
349
338
  async def call_tool(
350
- tools: list[ToolDef], message: str, call: ToolCall, conversation: list[ChatMessage]
339
+ tools: list[ToolDef],
340
+ message: str,
341
+ call: ToolCall,
342
+ event: BaseModel,
343
+ conversation: list[ChatMessage],
351
344
  ) -> tuple[ToolResult, list[ChatMessage], ModelOutput | None, str | None]:
352
345
  from inspect_ai.agent._handoff import AgentTool
353
- from inspect_ai.log._transcript import SampleLimitEvent, transcript
346
+ from inspect_ai.log._transcript import SampleLimitEvent, ToolEvent, transcript
347
+
348
+ # dodge circular import
349
+ assert isinstance(event, ToolEvent)
350
+
351
+ # this function is responsible for transcript events so that it can
352
+ # put them in the right enclosure (e.g. handoff/agent/tool). This
353
+ # means that if we throw early we need to do the enclosure when raising.
354
+ async def record_tool_parsing_error(error: str) -> Exception:
355
+ async with span(name=call.function, type="tool"):
356
+ transcript()._event(event)
357
+ return ToolParsingError(error)
354
358
 
355
359
  # if there was an error parsing the ToolCall, raise that
356
360
  if call.parse_error:
357
- raise ToolParsingError(call.parse_error)
361
+ raise await record_tool_parsing_error(call.parse_error)
358
362
 
359
363
  # find the tool
360
364
  tool_def = next((tool for tool in tools if tool.name == call.function), None)
361
365
  if tool_def is None:
362
- raise ToolParsingError(f"Tool {call.function} not found")
366
+ raise await record_tool_parsing_error(f"Tool {call.function} not found")
363
367
 
364
368
  # if we have a tool approver, apply it now
365
369
  from inspect_ai.approval._apply import apply_tool_approval
@@ -373,7 +377,7 @@ async def call_tool(
373
377
  transcript()._event(
374
378
  SampleLimitEvent(type="operator", limit=1, message=message)
375
379
  )
376
- raise LimitExceededError("operator", value=1, limit=1, message=message)
380
+ raise TerminateSampleError(message)
377
381
  else:
378
382
  raise ToolApprovalError(approval.explanation if approval else None)
379
383
  if approval and approval.modified:
@@ -382,7 +386,7 @@ async def call_tool(
382
386
  # validate the schema of the passed object
383
387
  validation_errors = validate_tool_input(call.arguments, tool_def.parameters)
384
388
  if validation_errors:
385
- raise ToolParsingError(validation_errors)
389
+ raise await record_tool_parsing_error(validation_errors)
386
390
 
387
391
  # get arguments (with creation of dataclasses, pydantic objects, etc.)
388
392
  arguments = tool_params(call.arguments, tool_def.tool)
@@ -391,14 +395,18 @@ async def call_tool(
391
395
  with trace_action(
392
396
  logger, "Tool Call", format_function_call(tool_def.name, arguments, width=1000)
393
397
  ):
394
- # agent tools get special handling
395
398
  if isinstance(tool_def.tool, AgentTool):
396
- return await agent_handoff(tool_def, call, conversation)
399
+ async with span(tool_def.tool.name, type="handoff"):
400
+ async with span(name=call.function, type="tool"):
401
+ transcript()._event(event)
402
+ return await agent_handoff(tool_def, call, conversation)
397
403
 
398
404
  # normal tool call
399
405
  else:
400
- result: ToolResult = await tool_def.tool(**arguments)
401
- return result, [], None, None
406
+ async with span(name=call.function, type="tool"):
407
+ transcript()._event(event)
408
+ result: ToolResult = await tool_def.tool(**arguments)
409
+ return result, [], None, None
402
410
 
403
411
 
404
412
  async def agent_handoff(
@@ -463,7 +471,8 @@ async def agent_handoff(
463
471
  agent_state = AgentState(messages=copy(agent_conversation))
464
472
  try:
465
473
  with apply_limits(agent_tool.limits):
466
- agent_state = await agent_tool.agent(agent_state, **arguments)
474
+ async with span(name=agent_name, type="agent"):
475
+ agent_state = await agent_tool.agent(agent_state, **arguments)
467
476
  except LimitExceededError as ex:
468
477
  limit_error = ex
469
478
 
@@ -19,6 +19,7 @@ from typing import (
19
19
  cast,
20
20
  )
21
21
 
22
+ from pydantic import BaseModel
22
23
  from pydantic_core import to_jsonable_python
23
24
  from tenacity import (
24
25
  RetryCallState,
@@ -402,36 +403,32 @@ class Model:
402
403
  start_time = datetime.now()
403
404
  working_start = sample_working_time()
404
405
  async with self._connection_concurrency(config):
405
- from inspect_ai.log._samples import track_active_sample_retries
406
-
407
406
  # generate
408
- with track_active_sample_retries():
409
- output = await self._generate(
410
- input=input,
411
- tools=tools,
412
- tool_choice=tool_choice,
413
- config=config,
414
- cache=cache,
415
- )
407
+ output, event = await self._generate(
408
+ input=input,
409
+ tools=tools,
410
+ tool_choice=tool_choice,
411
+ config=config,
412
+ cache=cache,
413
+ )
416
414
 
417
415
  # update the most recent ModelEvent with the actual start/completed
418
416
  # times as well as a computation of working time (events are
419
417
  # created _after_ the call to _generate, potentially in response
420
418
  # to retries, so they need their timestamp updated so it accurately
421
419
  # reflects the full start/end time which we know here)
422
- from inspect_ai.log._transcript import ModelEvent, transcript
423
-
424
- last_model_event = transcript().find_last_event(ModelEvent)
425
- if last_model_event:
426
- last_model_event.timestamp = start_time
427
- last_model_event.working_start = working_start
428
- completed = datetime.now()
429
- last_model_event.completed = completed
430
- last_model_event.working_time = (
431
- output.time
432
- if output.time is not None
433
- else (completed - start_time).total_seconds()
434
- )
420
+ from inspect_ai.log._transcript import ModelEvent
421
+
422
+ assert isinstance(event, ModelEvent)
423
+ event.timestamp = start_time
424
+ event.working_start = working_start
425
+ completed = datetime.now()
426
+ event.completed = completed
427
+ event.working_time = (
428
+ output.time
429
+ if output.time is not None
430
+ else (completed - start_time).total_seconds()
431
+ )
435
432
 
436
433
  # return output
437
434
  return output
@@ -492,9 +489,12 @@ class Model:
492
489
  tool_choice: ToolChoice | None,
493
490
  config: GenerateConfig,
494
491
  cache: bool | CachePolicy = False,
495
- ) -> ModelOutput:
492
+ ) -> tuple[ModelOutput, BaseModel]:
493
+ from inspect_ai.log._samples import track_active_model_event
494
+ from inspect_ai.log._transcript import ModelEvent
495
+
496
496
  # default to 'auto' for tool_choice (same as underlying model apis)
497
- tool_choice = tool_choice if tool_choice else "auto"
497
+ tool_choice = tool_choice if tool_choice is not None else "auto"
498
498
 
499
499
  # resolve top level tool source
500
500
  if isinstance(tools, ToolSource):
@@ -581,7 +581,10 @@ class Model:
581
581
  stop=stop,
582
582
  before_sleep=functools.partial(log_model_retry, self.api.model_name),
583
583
  )
584
- async def generate() -> ModelOutput:
584
+ async def generate() -> tuple[ModelOutput, BaseModel]:
585
+ # type-checker can't see that we made sure tool_choice is not none in the outer frame
586
+ assert tool_choice is not None
587
+
585
588
  check_sample_interrupt()
586
589
 
587
590
  cache_entry: CacheEntry | None
@@ -602,7 +605,7 @@ class Model:
602
605
  )
603
606
  existing = cache_fetch(cache_entry)
604
607
  if isinstance(existing, ModelOutput):
605
- self._record_model_interaction(
608
+ _, event = self._record_model_interaction(
606
609
  input=input,
607
610
  tools=tools_info,
608
611
  tool_choice=tool_choice,
@@ -611,7 +614,7 @@ class Model:
611
614
  output=existing,
612
615
  call=None,
613
616
  )
614
- return existing
617
+ return existing, event
615
618
  else:
616
619
  cache_entry = None
617
620
 
@@ -620,7 +623,7 @@ class Model:
620
623
 
621
624
  # record the interaction before the call to generate
622
625
  # (we'll update it with the results once we have them)
623
- complete = self._record_model_interaction(
626
+ complete, event = self._record_model_interaction(
624
627
  input=input,
625
628
  tools=tools_info,
626
629
  tool_choice=tool_choice,
@@ -631,12 +634,14 @@ class Model:
631
634
  with trace_action(logger, "Model", f"generate ({str(self)})"):
632
635
  time_start = time.monotonic()
633
636
  try:
634
- result = await self.api.generate(
635
- input=input,
636
- tools=tools_info,
637
- tool_choice=tool_choice,
638
- config=config,
639
- )
637
+ assert isinstance(event, ModelEvent)
638
+ with track_active_model_event(event):
639
+ result = await self.api.generate(
640
+ input=input,
641
+ tools=tools_info,
642
+ tool_choice=tool_choice,
643
+ config=config,
644
+ )
640
645
  finally:
641
646
  time_elapsed = time.monotonic() - time_start
642
647
 
@@ -686,18 +691,18 @@ class Model:
686
691
  if cache and cache_entry:
687
692
  cache_store(entry=cache_entry, output=output)
688
693
 
689
- return output
694
+ return output, event
690
695
 
691
696
  # call the model (this will so retries, etc., so report waiting time
692
697
  # as elapsed time - actual time for successful model call)
693
698
  time_start = time.monotonic()
694
- model_output = await generate()
699
+ model_output, event = await generate()
695
700
  total_time = time.monotonic() - time_start
696
701
  if model_output.time:
697
702
  report_sample_waiting_time(total_time - model_output.time)
698
703
 
699
704
  # return results
700
- return model_output
705
+ return model_output, event
701
706
 
702
707
  def should_retry(self, ex: BaseException) -> bool:
703
708
  if isinstance(ex, Exception):
@@ -769,7 +774,7 @@ class Model:
769
774
  cache: Literal["read", "write"] | None,
770
775
  output: ModelOutput | None = None,
771
776
  call: ModelCall | None = None,
772
- ) -> Callable[[ModelOutput | Exception, ModelCall | None], None]:
777
+ ) -> tuple[Callable[[ModelOutput | Exception, ModelCall | None], None], BaseModel]:
773
778
  from inspect_ai.log._transcript import ModelEvent, transcript
774
779
 
775
780
  # create event and add it to the transcript
@@ -809,7 +814,7 @@ class Model:
809
814
  if output:
810
815
  complete(output, call)
811
816
 
812
- return complete
817
+ return complete, event
813
818
 
814
819
 
815
820
  class ModelName:
@@ -1232,9 +1237,10 @@ def tool_result_images_as_user_message(
1232
1237
 
1233
1238
  Tool responses will have images replaced with "Image content is included below.", and the new user message will contain the images.
1234
1239
  """
1235
- init_accum: ImagesAccumulator = ([], [], [])
1236
1240
  chat_messages, user_message_content, tool_call_ids = functools.reduce(
1237
- tool_result_images_reducer, messages, init_accum
1241
+ tool_result_images_reducer,
1242
+ messages,
1243
+ (list[ChatMessage](), list[Content](), list[str]()),
1238
1244
  )
1239
1245
  # if the last message was a tool result, we may need to flush the pending stuff here
1240
1246
  return maybe_adding_user_message(chat_messages, user_message_content, tool_call_ids)
@@ -1260,9 +1266,10 @@ def tool_result_images_reducer(
1260
1266
  and isinstance(message.content, list)
1261
1267
  and any([isinstance(c, ContentImage) for c in message.content])
1262
1268
  ):
1263
- init_accum: ImageContentAccumulator = ([], [])
1264
1269
  new_user_message_content, edited_tool_message_content = functools.reduce(
1265
- tool_result_image_content_reducer, message.content, init_accum
1270
+ tool_result_image_content_reducer,
1271
+ message.content,
1272
+ (list[Content](), list[Content]()),
1266
1273
  )
1267
1274
 
1268
1275
  return (
@@ -184,24 +184,23 @@ def openai_responses_chat_choices(
184
184
  # │ │ ┌───────────────────┐ │ │ │ │ ┌───────────────────┐ │ │ │ │ ┌───────────────────┐ │ │
185
185
  # │ │ │ type: "reasoning" │ │ │ │ │ │ ContentText │ │ │ │ │ │ type: "reasoning" │ │ │
186
186
  # │ │ │ id: "rs_bbbbbb" │ │ │ │ │ │ text: "" │ │ │ │ │ │ id: "rs_bbbbbb" │ │ │
187
- # │ │ │ summary: [] │ │ │ │ │ └───────────────────┘ │ │ │ │ │ summary: [] │ │ │
188
- # │ │ └───────────────────┘ │ │ │ │ ┌───────────────────┐ │ │ │ │ ┌───────────────────┐ │ │
189
- # │ │ ┌───────────────────┐ │ │ │ │ │ ContentText │ │ │ │ │ │ type: "message" │ │ │
190
- # │ │ │ type: "message" │ │ │ │ │ text: "text1" │ │ │ │ │ id: "msg_ccccccc" │ │ │
191
- # │ │ │ id: "msg_ccccccc" │ │ │ │ │ └───────────────────┘ │ │ │ │ │ role: "assistant" │ │ │
192
- # │ │ │ role: "assistant" │ │ │--->│┌───────────────────┐│--->│ │ │ ┌───────────────┐ │ │ │
193
- # │ │ │ ┌───────────────┐ │ │ │ │ │ ContentText │ │ │ │ │ │ Content │ │ │ │
194
- # │ │ │ │ Content │ │ │ │ │ text: "text2" │ │ │ │ │ │ │ ┌───────────┐ │ │ │ │
195
- # │ │ │ │ ┌───────────┐ │ │ │ │ │ └───────────────────────┘ │ │ │ │ │ │"text1" │ │ │ │ │
196
- # │ │ │ │ │"text1" │ │ │ │ │ │ ┌───────────────────────┐ │ │ │ │ │ └───────────┘ │ │ │ │
197
- # │ │ │ │ └───────────┘ │ │ │ │ │ │ internal │ │ │ │ │ │ ┌───────────┐ │ │ │ │
198
- # │ │ │ │ ┌───────────┐ │ │ │ │ │ │ ┌───────────────────┐ │ │ │ │ │ │ "text2" │ │ │ │
199
- # │ │ │ "text2" │ │ │ │ │ │ │ │ reasoning_id: │ │ │ │ │ │ └───────────┘ │ │ │
200
- # │ │ └───────────┘ │ │ │ │ │ "rs_bbbbbb" │ │ │ │ └───────────────┘ │ │
201
- # │ └───────────────┘ │ │ │ │ │ └───────────────────┘ │ │ │ └───────────────────┘
202
- # │ │ └───────────────────┘ │ │ ┌───────────────────┐ │ └───────────────────────┘ │
203
- # └───────────────────────┘ │ │ output_msg_id: │ │ │ └───────────────────────────┘
204
- # └───────────────────────────┘ │ │ │ "msg_ccccccc" │ │ │
187
+ # │ │ │ summary: [] │ │ │ │ │ ├───────────────────┤ │ │ │ │ │ summary: [] │ │ │
188
+ # │ │ ├───────────────────┤ │ │ │ │ ContentText │ │ │ │ ├───────────────────┤ │ │
189
+ # │ │ type: "message" │ │ │ │ │ text: "text1" │ │ │ │ │ │ type: "message" │ │ │
190
+ # │ │ │ id: "msg_ccccccc" │ │ │ │ │ ├───────────────────┤ │ │ │ │ │ id: "msg_ccccccc" │ │ │
191
+ # │ │ │ role: "assistant" │ │ │ │ │ ContentText │ │ │ │ │ role: "assistant" │ │ │
192
+ # │ │ │ ┌───────────────┐ │ │ ->text: "text2" │ │ │ -> │ │ │ ┌───────────────┐ │ │ │
193
+ # │ │ │ Content │ │ │ │ │ └───────────────────┘ │ │ │ │ │ │ Content │ │ │ │
194
+ # │ │ │ │ ┌───────────┐ │ │ │ │ │ └───────────────────────┘ │ │ │ │ │ ┌───────────┐ │ │ │ │
195
+ # │ │ │ │ │"text1" │ │ │ │ │ │ ┌───────────────────────┐ │ │ │ │ │ │"text1" │ │ │ │ │
196
+ # │ │ │ │ ├───────────┤ │ │ │ │ │ internal │ │ │ │ ├───────────┤ │ │ │ │
197
+ # │ │ │ │ │"text2" │ │ │ │ │ │ │ ┌───────────────────┐ │ │ │ │ │ │ │"text2" │ │ │ │ │
198
+ # │ │ │ │ └───────────┘ │ │ │ │ │ │ reasoning_id: │ │ │ │ │ │ └───────────┘ │ │ │ │
199
+ # │ │ │ └───────────────┘ │ │ │ │ │ │ "rs_bbbbbb" │ │ │ │ │ │ └───────────────┘ │ │ │
200
+ # │ │ └───────────────────┘ │ │ │ │ └───────────────────┘ │ │ │ │ └───────────────────┘ │ │
201
+ # │ └───────────────────────┘ │ │ │ ┌───────────────────┐ │ │ │ └───────────────────────┘
202
+ # └───────────────────────────┘ │ │ │ output_msg_id: │ │ │ └───────────────────────────┘
203
+ # │ │ │ "msg_ccccccc" │ │ │
205
204
  # │ │ └───────────────────┘ │ │
206
205
  # │ └───────────────────────┘ │
207
206
  # └───────────────────────────┘
@@ -33,7 +33,10 @@ from anthropic.types import (
33
33
  ToolUseBlockParam,
34
34
  message_create_params,
35
35
  )
36
- from anthropic.types.beta import BetaToolComputerUse20250124Param
36
+ from anthropic.types.beta import (
37
+ BetaToolComputerUse20250124Param,
38
+ BetaToolTextEditor20241022Param,
39
+ )
37
40
  from pydantic import JsonValue
38
41
  from typing_extensions import override
39
42
 
@@ -218,6 +221,8 @@ class AnthropicAPI(ModelAPI):
218
221
  # tools are generally available for Claude 3.5 Sonnet (new) as well and
219
222
  # can be used without the computer use beta header.
220
223
  betas.append("computer-use-2025-01-24")
224
+ if any("20241022" in str(tool.get("type", "")) for tool in tools_param):
225
+ betas.append("computer-use-2024-10-22")
221
226
  if len(betas) > 0:
222
227
  extra_headers["anthropic-beta"] = ",".join(betas)
223
228
 
@@ -337,6 +342,15 @@ class AnthropicAPI(ModelAPI):
337
342
  @override
338
343
  def should_retry(self, ex: Exception) -> bool:
339
344
  if isinstance(ex, APIStatusError):
345
+ # for unknown reasons, anthropic does not always set status_code == 529
346
+ # for "overloaded_error" so we check for it explicitly
347
+ if (
348
+ isinstance(ex.body, dict)
349
+ and ex.body.get("error", {}).get("type", "") == "overloaded_error"
350
+ ):
351
+ return True
352
+
353
+ # standard http status code checking
340
354
  return is_retryable_http_status(ex.status_code)
341
355
  elif httpx_should_retry(ex):
342
356
  return True
@@ -545,7 +559,7 @@ class AnthropicAPI(ModelAPI):
545
559
 
546
560
  def text_editor_tool_param(
547
561
  self, tool: ToolInfo
548
- ) -> Optional[ToolTextEditor20250124Param]:
562
+ ) -> ToolTextEditor20250124Param | BetaToolTextEditor20241022Param | None:
549
563
  # check for compatible 'text editor' tool
550
564
  if tool.name == "text_editor" and (
551
565
  sorted(tool.parameters.properties.keys())
@@ -561,8 +575,14 @@ class AnthropicAPI(ModelAPI):
561
575
  ]
562
576
  )
563
577
  ):
564
- return ToolTextEditor20250124Param(
565
- type="text_editor_20250124", name="str_replace_editor"
578
+ return (
579
+ BetaToolTextEditor20241022Param(
580
+ type="text_editor_20241022", name="str_replace_editor"
581
+ )
582
+ if self.is_claude_3_5()
583
+ else ToolTextEditor20250124Param(
584
+ type="text_editor_20250124", name="str_replace_editor"
585
+ )
566
586
  )
567
587
  # not a text_editor tool
568
588
  else:
@@ -571,7 +591,10 @@ class AnthropicAPI(ModelAPI):
571
591
 
572
592
  # tools can be either a stock tool param or a special Anthropic native use tool param
573
593
  ToolParamDef = (
574
- ToolParam | BetaToolComputerUse20250124Param | ToolTextEditor20250124Param
594
+ ToolParam
595
+ | BetaToolComputerUse20250124Param
596
+ | ToolTextEditor20250124Param
597
+ | BetaToolTextEditor20241022Param
575
598
  )
576
599
 
577
600
 
@@ -580,6 +603,7 @@ def add_cache_control(
580
603
  | ToolParam
581
604
  | BetaToolComputerUse20250124Param
582
605
  | ToolTextEditor20250124Param
606
+ | BetaToolTextEditor20241022Param
583
607
  | dict[str, Any],
584
608
  ) -> None:
585
609
  cast(dict[str, Any], param)["cache_control"] = {"type": "ephemeral"}
@@ -844,6 +868,7 @@ def _names_for_tool_call(
844
868
  """
845
869
  mappings = (
846
870
  (INTERNAL_COMPUTER_TOOL_NAME, "computer_20250124", "computer"),
871
+ ("str_replace_editor", "text_editor_20241022", "text_editor"),
847
872
  ("str_replace_editor", "text_editor_20250124", "text_editor"),
848
873
  ("bash", "bash_20250124", "bash_session"),
849
874
  )