inspect-ai 0.3.93__py3-none-any.whl → 0.3.95__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- inspect_ai/_display/textual/widgets/samples.py +3 -3
- inspect_ai/_display/textual/widgets/transcript.py +3 -29
- inspect_ai/_eval/loader.py +1 -1
- inspect_ai/_eval/task/run.py +21 -12
- inspect_ai/_util/answer.py +26 -0
- inspect_ai/_util/constants.py +0 -1
- inspect_ai/_util/exception.py +4 -0
- inspect_ai/_util/hash.py +39 -0
- inspect_ai/_util/local_server.py +51 -21
- inspect_ai/_util/path.py +22 -0
- inspect_ai/_util/trace.py +1 -1
- inspect_ai/_util/working.py +4 -0
- inspect_ai/_view/www/dist/assets/index.css +23 -22
- inspect_ai/_view/www/dist/assets/index.js +517 -204
- inspect_ai/_view/www/log-schema.json +375 -0
- inspect_ai/_view/www/package.json +1 -1
- inspect_ai/_view/www/src/@types/log.d.ts +90 -12
- inspect_ai/_view/www/src/app/log-view/navbar/SecondaryBar.tsx +2 -2
- inspect_ai/_view/www/src/app/log-view/tabs/SamplesTab.tsx +1 -4
- inspect_ai/_view/www/src/app/samples/SamplesTools.tsx +3 -13
- inspect_ai/_view/www/src/app/samples/sample-tools/SelectScorer.tsx +45 -48
- inspect_ai/_view/www/src/app/samples/sample-tools/filters.ts +16 -15
- inspect_ai/_view/www/src/app/samples/sample-tools/sample-filter/SampleFilter.tsx +47 -75
- inspect_ai/_view/www/src/app/samples/sample-tools/sample-filter/completions.ts +9 -9
- inspect_ai/_view/www/src/app/samples/transcript/SandboxEventView.module.css +2 -1
- inspect_ai/_view/www/src/app/samples/transcript/SpanEventView.tsx +174 -0
- inspect_ai/_view/www/src/app/samples/transcript/ToolEventView.tsx +8 -8
- inspect_ai/_view/www/src/app/samples/transcript/TranscriptView.tsx +12 -2
- inspect_ai/_view/www/src/app/samples/transcript/TranscriptVirtualListComponent.module.css +1 -1
- inspect_ai/_view/www/src/app/samples/transcript/event/EventPanel.tsx +0 -3
- inspect_ai/_view/www/src/app/samples/transcript/transform/fixups.ts +87 -25
- inspect_ai/_view/www/src/app/samples/transcript/transform/treeify.ts +229 -17
- inspect_ai/_view/www/src/app/samples/transcript/transform/utils.ts +11 -0
- inspect_ai/_view/www/src/app/samples/transcript/types.ts +5 -1
- inspect_ai/_view/www/src/app/types.ts +12 -2
- inspect_ai/_view/www/src/components/ExpandablePanel.module.css +1 -1
- inspect_ai/_view/www/src/components/ExpandablePanel.tsx +5 -5
- inspect_ai/_view/www/src/state/hooks.ts +19 -3
- inspect_ai/_view/www/src/state/logSlice.ts +23 -5
- inspect_ai/_view/www/yarn.lock +9 -9
- inspect_ai/agent/_as_solver.py +3 -1
- inspect_ai/agent/_as_tool.py +6 -4
- inspect_ai/agent/_bridge/patch.py +1 -3
- inspect_ai/agent/_handoff.py +5 -1
- inspect_ai/agent/_react.py +4 -3
- inspect_ai/agent/_run.py +6 -1
- inspect_ai/agent/_types.py +9 -0
- inspect_ai/analysis/__init__.py +0 -0
- inspect_ai/analysis/beta/__init__.py +57 -0
- inspect_ai/analysis/beta/_dataframe/__init__.py +0 -0
- inspect_ai/analysis/beta/_dataframe/columns.py +145 -0
- inspect_ai/analysis/beta/_dataframe/evals/__init__.py +0 -0
- inspect_ai/analysis/beta/_dataframe/evals/columns.py +132 -0
- inspect_ai/analysis/beta/_dataframe/evals/extract.py +23 -0
- inspect_ai/analysis/beta/_dataframe/evals/table.py +140 -0
- inspect_ai/analysis/beta/_dataframe/events/__init__.py +0 -0
- inspect_ai/analysis/beta/_dataframe/events/columns.py +37 -0
- inspect_ai/analysis/beta/_dataframe/events/table.py +14 -0
- inspect_ai/analysis/beta/_dataframe/extract.py +54 -0
- inspect_ai/analysis/beta/_dataframe/messages/__init__.py +0 -0
- inspect_ai/analysis/beta/_dataframe/messages/columns.py +60 -0
- inspect_ai/analysis/beta/_dataframe/messages/extract.py +21 -0
- inspect_ai/analysis/beta/_dataframe/messages/table.py +87 -0
- inspect_ai/analysis/beta/_dataframe/record.py +377 -0
- inspect_ai/analysis/beta/_dataframe/samples/__init__.py +0 -0
- inspect_ai/analysis/beta/_dataframe/samples/columns.py +73 -0
- inspect_ai/analysis/beta/_dataframe/samples/extract.py +82 -0
- inspect_ai/analysis/beta/_dataframe/samples/table.py +329 -0
- inspect_ai/analysis/beta/_dataframe/util.py +157 -0
- inspect_ai/analysis/beta/_dataframe/validate.py +171 -0
- inspect_ai/dataset/_dataset.py +6 -3
- inspect_ai/log/__init__.py +10 -0
- inspect_ai/log/_convert.py +4 -9
- inspect_ai/log/_file.py +1 -1
- inspect_ai/log/_log.py +21 -1
- inspect_ai/log/_samples.py +14 -17
- inspect_ai/log/_transcript.py +77 -35
- inspect_ai/log/_tree.py +118 -0
- inspect_ai/model/_call_tools.py +44 -35
- inspect_ai/model/_model.py +51 -44
- inspect_ai/model/_openai_responses.py +17 -18
- inspect_ai/model/_providers/anthropic.py +30 -5
- inspect_ai/model/_providers/hf.py +27 -1
- inspect_ai/model/_providers/providers.py +1 -1
- inspect_ai/model/_providers/sglang.py +8 -2
- inspect_ai/model/_providers/vllm.py +6 -2
- inspect_ai/scorer/_choice.py +1 -2
- inspect_ai/solver/_chain.py +1 -1
- inspect_ai/solver/_fork.py +1 -1
- inspect_ai/solver/_multiple_choice.py +9 -23
- inspect_ai/solver/_plan.py +2 -2
- inspect_ai/solver/_task_state.py +7 -3
- inspect_ai/solver/_transcript.py +6 -7
- inspect_ai/tool/_mcp/_context.py +3 -5
- inspect_ai/tool/_mcp/_mcp.py +6 -5
- inspect_ai/tool/_mcp/server.py +1 -1
- inspect_ai/tool/_tools/_execute.py +4 -1
- inspect_ai/tool/_tools/_think.py +1 -1
- inspect_ai/tool/_tools/_web_search/__init__.py +3 -0
- inspect_ai/tool/_tools/{_web_search.py → _web_search/_google.py} +56 -103
- inspect_ai/tool/_tools/_web_search/_tavily.py +77 -0
- inspect_ai/tool/_tools/_web_search/_web_search.py +85 -0
- inspect_ai/util/__init__.py +4 -0
- inspect_ai/util/_anyio.py +11 -0
- inspect_ai/util/_collect.py +50 -0
- inspect_ai/util/_sandbox/events.py +3 -2
- inspect_ai/util/_span.py +58 -0
- inspect_ai/util/_subtask.py +27 -42
- {inspect_ai-0.3.93.dist-info → inspect_ai-0.3.95.dist-info}/METADATA +8 -1
- {inspect_ai-0.3.93.dist-info → inspect_ai-0.3.95.dist-info}/RECORD +114 -82
- {inspect_ai-0.3.93.dist-info → inspect_ai-0.3.95.dist-info}/WHEEL +1 -1
- inspect_ai/_display/core/group.py +0 -79
- {inspect_ai-0.3.93.dist-info → inspect_ai-0.3.95.dist-info}/entry_points.txt +0 -0
- {inspect_ai-0.3.93.dist-info → inspect_ai-0.3.95.dist-info}/licenses/LICENSE +0 -0
- {inspect_ai-0.3.93.dist-info → inspect_ai-0.3.95.dist-info}/top_level.txt +0 -0
inspect_ai/log/_tree.py
ADDED
@@ -0,0 +1,118 @@
+from dataclasses import dataclass, field
+from logging import getLogger
+from typing import Iterable, Sequence, TypeAlias, Union
+
+from ._transcript import Event, SpanBeginEvent, SpanEndEvent
+
+logger = getLogger(__name__)
+
+EventNode: TypeAlias = Union["SpanNode", Event]
+"""Node in an event tree."""
+
+EventTree: TypeAlias = list[EventNode]
+"""Tree of events (has individual events and event spans)."""
+
+
+@dataclass
+class SpanNode:
+    """Event tree node representing a span of events."""
+
+    id: str
+    """Span id."""
+
+    parent_id: str | None
+    """Parent span id."""
+
+    type: str | None
+    """Optional 'type' field for span."""
+
+    name: str
+    """Span name."""
+
+    begin: SpanBeginEvent
+    """Span begin event."""
+
+    end: SpanEndEvent | None = None
+    """Span end event (if any)."""
+
+    children: list[EventNode] = field(default_factory=list)
+    """Children in the span."""
+
+
+def event_tree(events: Sequence[Event]) -> EventTree:
+    """Build a tree representation of a sequence of events.
+
+    Organize events hierarchically into event spans.
+
+    Args:
+        events: Sequence of `Event`.
+
+    Returns:
+        Event tree.
+    """
+    # Convert one flat list of (possibly interleaved) events into a *forest*
+    # (list of root-level items).
+
+    # Pre-create one node per span so we can attach events no matter when they
+    # arrive in the file. A single forward scan guarantees that the order of
+    # `children` inside every span reflects the order in which things appeared
+    # in the transcript.
+    nodes: dict[str, SpanNode] = {
+        ev.id: SpanNode(
+            id=ev.id, parent_id=ev.parent_id, type=ev.type, name=ev.name, begin=ev
+        )
+        for ev in events
+        if isinstance(ev, SpanBeginEvent)
+    }
+
+    roots: list[EventNode] = []
+
+    # Where should an event with `span_id` go?
+    def bucket(span_id: str | None) -> list[EventNode]:
+        if span_id and span_id in nodes:
+            return nodes[span_id].children
+        return roots  # root level
+
+    # Single pass in original order
+    for ev in events:
+        if isinstance(ev, SpanBeginEvent):  # span starts
+            bucket(ev.parent_id).append(nodes[ev.id])
+
+        elif isinstance(ev, SpanEndEvent):  # span ends
+            if n := nodes.get(ev.id):
+                n.end = ev
+            else:
+                logger.warning(f"Span end event (id: {ev.id}) with no span begin")
+
+        else:  # ordinary event
+            bucket(ev.span_id).append(ev)
+
+    return roots
+
+
+def event_sequence(tree: EventTree) -> Iterable[Event]:
+    """Flatten a span forest back into a properly ordered sequence.
+
+    Args:
+        tree: Event tree
+
+    Returns:
+        Sequence of events.
+    """
+    for item in tree:
+        if isinstance(item, SpanNode):
+            yield item.begin
+            yield from event_sequence(item.children)
+            if item.end:
+                yield item.end
+        else:
+            yield item
+
+
+def _print_event_tree(tree: EventTree, indent: str = "") -> None:
+    for item in tree:
+        if isinstance(item, SpanNode):
+            print(f"{indent}span ({item.type}): {item.name}")
+            _print_event_tree(item.children, f"{indent}  ")
+        else:
+            print(f"{indent}{item.event}")
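The single-pass bucketing in `event_tree` is worth a closer look: span nodes are pre-created so children can attach regardless of interleaving, and anything without a known `span_id` lands at the root. Here is a minimal standalone sketch of the same idea, using stand-in `Begin`/`End`/`Log` classes rather than the real `SpanBeginEvent`/`SpanEndEvent`/`Event` pydantic types:

```python
from dataclasses import dataclass

@dataclass
class Begin:
    id: str
    name: str
    parent_id: str | None = None

@dataclass
class End:
    id: str

@dataclass
class Log:
    text: str
    span_id: str | None = None

def tree(events: list) -> list:
    # pre-create one bucket per span so order of arrival doesn't matter
    nodes = {e.id: [e] for e in events if isinstance(e, Begin)}
    roots: list = []

    def bucket(span_id):
        # unknown or missing span ids fall back to the root level
        return nodes.get(span_id, roots) if span_id else roots

    for e in events:
        if isinstance(e, Begin):
            bucket(e.parent_id).append(nodes[e.id])  # attach span under its parent
        elif isinstance(e, End):
            pass  # the real code records the end event on the span node
        else:
            bucket(e.span_id).append(e)  # ordinary event
    return roots

events = [
    Begin("s1", "react"),
    Log("model output", span_id="s1"),
    Begin("s2", "bash", parent_id="s1"),
    Log("tool result", span_id="s2"),
    End("s2"),
    End("s1"),
    Log("score"),  # no span_id -> root level
]
print(tree(events))
# [[Begin('s1', ...), Log('model output', ...), [Begin('s2', ...), Log('tool result', ...)]], Log('score')]
```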
inspect_ai/model/_call_tools.py
CHANGED
@@ -39,6 +39,7 @@ from inspect_ai._util.content import (
     ContentText,
     ContentVideo,
 )
+from inspect_ai._util.exception import TerminateSampleError
 from inspect_ai._util.format import format_function_call
 from inspect_ai._util.logger import warn_once
 from inspect_ai._util.registry import registry_unqualified_name
@@ -61,6 +62,7 @@ from inspect_ai.tool._tool_params import ToolParams
 from inspect_ai.util import OutputLimitExceededError
 from inspect_ai.util._anyio import inner_exception
 from inspect_ai.util._limit import LimitExceededError, apply_limits
+from inspect_ai.util._span import span
 
 from ._chat_message import (
     ChatMessage,
@@ -109,26 +111,18 @@ async def execute_tools(
     """
     message = messages[-1]
     if isinstance(message, ChatMessageAssistant) and message.tool_calls:
-        from inspect_ai.log._transcript import (
-            ToolEvent,
-            Transcript,
-            init_transcript,
-            track_store_changes,
-            transcript,
-        )
+        from inspect_ai.log._transcript import ToolEvent, transcript
 
         tdefs = await tool_defs(tools)
 
         async def call_tool_task(
             call: ToolCall,
+            event: ToolEvent,
             conversation: list[ChatMessage],
             send_stream: MemoryObjectSendStream[
                 tuple[ExecuteToolsResult, ToolEvent, Exception | None]
             ],
         ) -> None:
-            # create a transcript for this call
-            init_transcript(Transcript(name=call.function))
-
             result: ToolResult = ""
             messages: list[ChatMessage] = []
             output: ModelOutput | None = None
@@ -136,15 +130,14 @@ async def execute_tools(
             tool_error: ToolCallError | None = None
             tool_exception: Exception | None = None
             try:
-                with track_store_changes():
-                    try:
-                        result, messages, output, agent = await call_tool(
-                            tdefs, message.text, call, conversation
-                        )
-                    # unwrap exception group
-                    except Exception as ex:
-                        inner_ex = inner_exception(ex)
-                        raise inner_ex.with_traceback(inner_ex.__traceback__)
+                try:
+                    result, messages, output, agent = await call_tool(
+                        tdefs, message.text, call, event, conversation
+                    )
+                # unwrap exception group
+                except Exception as ex:
+                    inner_ex = inner_exception(ex)
+                    raise inner_ex.with_traceback(inner_ex.__traceback__)
 
             except TimeoutError:
                 tool_error = ToolCallError(
@@ -227,7 +220,6 @@ async def execute_tools(
                     truncated=truncated,
                     view=call.view,
                     error=tool_error,
-                    events=list(transcript().events),
                     agent=agent,
                 )
 
@@ -270,7 +262,6 @@ async def execute_tools(
                 internal=call.internal,
                 pending=True,
             )
-            transcript()._event(event)
 
             # execute the tool call. if the operator cancels the
             # tool call then synthesize the appropriate message/event
@@ -280,7 +271,7 @@ async def execute_tools(
 
             result_exception = None
             async with anyio.create_task_group() as tg:
-                tg.start_soon(call_tool_task, call, messages, send_stream)
+                tg.start_soon(call_tool_task, call, event, messages, send_stream)
                 event._set_cancel_fn(tg.cancel_scope.cancel)
                 async with receive_stream:
                     (
@@ -306,7 +297,6 @@ async def execute_tools(
                         truncated=None,
                         view=call.view,
                         error=tool_message.error,
-                        events=[],
                     )
                     transcript().info(
                         f"Tool call '{call.function}' was cancelled by operator."
@@ -326,7 +316,6 @@ async def execute_tools(
                     result=result_event.result,
                     truncated=result_event.truncated,
                     error=result_event.error,
-                    events=result_event.events,
                     waiting_time=waiting_time_end - waiting_time_start,
                     agent=result_event.agent,
                     failed=True if result_exception else None,
@@ -347,19 +336,34 @@ async def execute_tools(
 
 
 async def call_tool(
-    tools: list[ToolDef], message: str, call: ToolCall, conversation: list[ChatMessage]
+    tools: list[ToolDef],
+    message: str,
+    call: ToolCall,
+    event: BaseModel,
+    conversation: list[ChatMessage],
 ) -> tuple[ToolResult, list[ChatMessage], ModelOutput | None, str | None]:
     from inspect_ai.agent._handoff import AgentTool
-    from inspect_ai.log._transcript import SampleLimitEvent, transcript
+    from inspect_ai.log._transcript import SampleLimitEvent, ToolEvent, transcript
+
+    # dodge circular import
+    assert isinstance(event, ToolEvent)
+
+    # this function is responsible for transcript events so that it can
+    # put them in the right enclosure (e.g. handoff/agent/tool). This
+    # means that if we throw early we need to do the enclosure when raising.
+    async def record_tool_parsing_error(error: str) -> Exception:
+        async with span(name=call.function, type="tool"):
+            transcript()._event(event)
+        return ToolParsingError(error)
 
     # if there was an error parsing the ToolCall, raise that
     if call.parse_error:
-        raise ToolParsingError(call.parse_error)
+        raise await record_tool_parsing_error(call.parse_error)
 
     # find the tool
     tool_def = next((tool for tool in tools if tool.name == call.function), None)
    if tool_def is None:
-        raise ToolParsingError(f"Tool {call.function} not found")
+        raise await record_tool_parsing_error(f"Tool {call.function} not found")
 
     # if we have a tool approver, apply it now
     from inspect_ai.approval._apply import apply_tool_approval
@@ -373,7 +377,7 @@ async def call_tool(
             transcript()._event(
                 SampleLimitEvent(type="operator", limit=1, message=message)
             )
-            raise LimitExceededError("operator", value=1, limit=1, message=message)
+            raise TerminateSampleError(message)
         else:
             raise ToolApprovalError(approval.explanation if approval else None)
     if approval and approval.modified:
@@ -382,7 +386,7 @@ async def call_tool(
     # validate the schema of the passed object
     validation_errors = validate_tool_input(call.arguments, tool_def.parameters)
     if validation_errors:
-        raise ToolParsingError(validation_errors)
+        raise await record_tool_parsing_error(validation_errors)
 
     # get arguments (with creation of dataclasses, pydantic objects, etc.)
     arguments = tool_params(call.arguments, tool_def.tool)
@@ -391,14 +395,18 @@ async def call_tool(
     with trace_action(
         logger, "Tool Call", format_function_call(tool_def.name, arguments, width=1000)
     ):
-        # agent tools get special handling
         if isinstance(tool_def.tool, AgentTool):
-            return await agent_handoff(tool_def, call, conversation)
+            async with span(tool_def.tool.name, type="handoff"):
+                async with span(name=call.function, type="tool"):
+                    transcript()._event(event)
+                return await agent_handoff(tool_def, call, conversation)
 
         # normal tool call
        else:
-            result: ToolResult = await tool_def.tool(**arguments)
-            return result, [], None, None
+            async with span(name=call.function, type="tool"):
+                transcript()._event(event)
+                result: ToolResult = await tool_def.tool(**arguments)
+                return result, [], None, None
@@ -463,7 +471,8 @@ async def agent_handoff(
     agent_state = AgentState(messages=copy(agent_conversation))
     try:
         with apply_limits(agent_tool.limits):
-            agent_state = await agent_tool.agent(agent_state, **arguments)
+            async with span(name=agent_name, type="agent"):
+                agent_state = await agent_tool.agent(agent_state, **arguments)
     except LimitExceededError as ex:
         limit_error = ex
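The through-line of these changes: `ToolEvent` no longer carries its own nested `events` list (note the removed `events=...` arguments and the dropped per-call `init_transcript`). Nesting is instead expressed by the `SpanBeginEvent`/`SpanEndEvent` pairs that the new `span()` context manager writes to the one shared transcript, which `event_tree` can later fold back into a hierarchy. A rough sketch of the pattern, with a plain list standing in for the transcript and without the ContextVar tracking of the current parent span that the real `inspect_ai/util/_span.py` needs:

```python
import asyncio
import uuid
from contextlib import asynccontextmanager

transcript: list[dict] = []  # stand-in for the sample's shared transcript

@asynccontextmanager
async def span(name: str, type: str | None = None):
    # emit a begin/end pair; everything recorded in between is "inside" the span
    span_id = uuid.uuid4().hex
    transcript.append({"event": "span_begin", "id": span_id, "name": name, "type": type})
    try:
        yield
    finally:
        transcript.append({"event": "span_end", "id": span_id})

async def demo() -> None:
    async with span(name="bash", type="tool"):
        transcript.append({"event": "tool", "function": "bash"})

asyncio.run(demo())
print([e["event"] for e in transcript])  # ['span_begin', 'tool', 'span_end']
```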
inspect_ai/model/_model.py
CHANGED
@@ -19,6 +19,7 @@ from typing import (
     cast,
 )
 
+from pydantic import BaseModel
 from pydantic_core import to_jsonable_python
 from tenacity import (
     RetryCallState,
@@ -402,36 +403,32 @@ class Model:
         start_time = datetime.now()
         working_start = sample_working_time()
         async with self._connection_concurrency(config):
-            from inspect_ai.log._samples import track_active_sample_retries
-
             # generate
-            with track_active_sample_retries():
-                output = await self._generate(
-                    input=input,
-                    tools=tools,
-                    tool_choice=tool_choice,
-                    config=config,
-                    cache=cache,
-                )
+            output, event = await self._generate(
+                input=input,
+                tools=tools,
+                tool_choice=tool_choice,
+                config=config,
+                cache=cache,
+            )
 
             # update the most recent ModelEvent with the actual start/completed
             # times as well as a computation of working time (events are
             # created _after_ the call to _generate, potentially in response
             # to retries, so they need their timestamp updated so it accurately
             # reflects the full start/end time which we know here)
-            from inspect_ai.log._transcript import ModelEvent, transcript
-
-            last_model_event = transcript().find_last_event(ModelEvent)
-            if last_model_event:
-                last_model_event.timestamp = start_time
-                last_model_event.working_start = working_start
-                completed = datetime.now()
-                last_model_event.completed = completed
-                last_model_event.working_time = (
-                    output.time
-                    if output.time is not None
-                    else (completed - start_time).total_seconds()
-                )
+            from inspect_ai.log._transcript import ModelEvent
+
+            assert isinstance(event, ModelEvent)
+            event.timestamp = start_time
+            event.working_start = working_start
+            completed = datetime.now()
+            event.completed = completed
+            event.working_time = (
+                output.time
+                if output.time is not None
+                else (completed - start_time).total_seconds()
+            )
 
             # return output
             return output
@@ -492,9 +489,12 @@ class Model:
         tool_choice: ToolChoice | None,
         config: GenerateConfig,
         cache: bool | CachePolicy = False,
-    ) -> ModelOutput:
+    ) -> tuple[ModelOutput, BaseModel]:
+        from inspect_ai.log._samples import track_active_model_event
+        from inspect_ai.log._transcript import ModelEvent
+
         # default to 'auto' for tool_choice (same as underlying model apis)
-        tool_choice = tool_choice if tool_choice else "auto"
+        tool_choice = tool_choice if tool_choice is not None else "auto"
 
         # resolve top level tool source
         if isinstance(tools, ToolSource):
@@ -581,7 +581,10 @@ class Model:
             stop=stop,
             before_sleep=functools.partial(log_model_retry, self.api.model_name),
         )
-        async def generate() -> ModelOutput:
+        async def generate() -> tuple[ModelOutput, BaseModel]:
+            # type-checker can't see that we made sure tool_choice is not none in the outer frame
+            assert tool_choice is not None
+
             check_sample_interrupt()
 
             cache_entry: CacheEntry | None
@@ -602,7 +605,7 @@ class Model:
                 )
                 existing = cache_fetch(cache_entry)
                 if isinstance(existing, ModelOutput):
-                    self._record_model_interaction(
+                    _, event = self._record_model_interaction(
                         input=input,
                         tools=tools_info,
                         tool_choice=tool_choice,
@@ -611,7 +614,7 @@ class Model:
                         output=existing,
                         call=None,
                     )
-                    return existing
+                    return existing, event
                 else:
                     cache_entry = None
 
@@ -620,7 +623,7 @@ class Model:
 
             # record the interaction before the call to generate
             # (we'll update it with the results once we have them)
-            complete = self._record_model_interaction(
+            complete, event = self._record_model_interaction(
                 input=input,
                 tools=tools_info,
                 tool_choice=tool_choice,
@@ -631,12 +634,14 @@ class Model:
             with trace_action(logger, "Model", f"generate ({str(self)})"):
                 time_start = time.monotonic()
                 try:
-                    result = await self.api.generate(
-                        input=input,
-                        tools=tools_info,
-                        tool_choice=tool_choice,
-                        config=config,
-                    )
+                    assert isinstance(event, ModelEvent)
+                    with track_active_model_event(event):
+                        result = await self.api.generate(
+                            input=input,
+                            tools=tools_info,
+                            tool_choice=tool_choice,
+                            config=config,
+                        )
                 finally:
                     time_elapsed = time.monotonic() - time_start
 
@@ -686,18 +691,18 @@ class Model:
             if cache and cache_entry:
                 cache_store(entry=cache_entry, output=output)
 
-            return output
+            return output, event
 
         # call the model (this will do retries, etc., so report waiting time
         # as elapsed time - actual time for successful model call)
         time_start = time.monotonic()
-        model_output = await generate()
+        model_output, event = await generate()
         total_time = time.monotonic() - time_start
         if model_output.time:
            report_sample_waiting_time(total_time - model_output.time)
 
         # return results
-        return model_output
+        return model_output, event
 
     def should_retry(self, ex: BaseException) -> bool:
         if isinstance(ex, Exception):
@@ -769,7 +774,7 @@ class Model:
         cache: Literal["read", "write"] | None,
         output: ModelOutput | None = None,
         call: ModelCall | None = None,
-    ) -> Callable[[ModelOutput | Exception, ModelCall | None], None]:
+    ) -> tuple[Callable[[ModelOutput | Exception, ModelCall | None], None], BaseModel]:
         from inspect_ai.log._transcript import ModelEvent, transcript
 
         # create event and add it to the transcript
@@ -809,7 +814,7 @@ class Model:
         if output:
             complete(output, call)
 
-        return complete
+        return complete, event
 
 
 class ModelName:
@@ -1232,9 +1237,10 @@ def tool_result_images_as_user_message(
 
     Tool responses will have images replaced with "Image content is included below.", and the new user message will contain the images.
     """
-    init_accum: ImagesAccumulator = ([], [], [])
     chat_messages, user_message_content, tool_call_ids = functools.reduce(
-        tool_result_images_reducer, messages, init_accum
+        tool_result_images_reducer,
+        messages,
+        (list[ChatMessage](), list[Content](), list[str]()),
     )
     # if the last message was a tool result, we may need to flush the pending stuff here
     return maybe_adding_user_message(chat_messages, user_message_content, tool_call_ids)
@@ -1260,9 +1266,10 @@ def tool_result_images_reducer(
         and isinstance(message.content, list)
         and any([isinstance(c, ContentImage) for c in message.content])
     ):
-        init_accum: ImageContentAccumulator = ([], [])
         new_user_message_content, edited_tool_message_content = functools.reduce(
-            tool_result_image_content_reducer, message.content, init_accum
+            tool_result_image_content_reducer,
+            message.content,
+            (list[Content](), list[Content]()),
         )
 
         return (
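One detail in the two `functools.reduce` hunks above: the named `init_accum` variables gave the accumulator its type via an annotation, while the new code inlines initializers like `list[ChatMessage]()`. Calling a subscripted generic builds an ordinary empty list at runtime while still telling the type checker the element type. A small self-contained illustration of the idiom:

```python
import functools

# list[int]() is just list() at runtime, but type checkers read the element
# type from the subscript, so the reduce accumulator needs no separately
# annotated variable.
evens, odds = functools.reduce(
    lambda acc, n: (acc[0] + [n], acc[1]) if n % 2 == 0 else (acc[0], acc[1] + [n]),
    range(6),
    (list[int](), list[int]()),
)
print(evens, odds)  # [0, 2, 4] [1, 3, 5]
```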
inspect_ai/model/_openai_responses.py
CHANGED
@@ -184,24 +184,23 @@ def openai_responses_chat_choices(
[comment-only change: the ASCII diagram above openai_responses_chat_choices is redrawn to show how "reasoning" and "message" output items (id: "rs_bbbbbb"; id: "msg_ccccccc", role: "assistant") map to ContentText parts ("text1", "text2"), with the internal metadata now carrying reasoning_id: "rs_bbbbbb" and output_msg_id: "msg_ccccccc"; the diagram's column alignment is not recoverable in this view]
inspect_ai/model/_providers/anthropic.py
CHANGED
@@ -33,7 +33,10 @@ from anthropic.types import (
     ToolUseBlockParam,
     message_create_params,
 )
-from anthropic.types.beta import BetaToolComputerUse20250124Param
+from anthropic.types.beta import (
+    BetaToolComputerUse20250124Param,
+    BetaToolTextEditor20241022Param,
+)
 from pydantic import JsonValue
 from typing_extensions import override
 
@@ -218,6 +221,8 @@ class AnthropicAPI(ModelAPI):
         # tools are generally available for Claude 3.5 Sonnet (new) as well and
         # can be used without the computer use beta header.
         betas.append("computer-use-2025-01-24")
+        if any("20241022" in str(tool.get("type", "")) for tool in tools_param):
+            betas.append("computer-use-2024-10-22")
         if len(betas) > 0:
             extra_headers["anthropic-beta"] = ",".join(betas)
 
@@ -337,6 +342,15 @@ class AnthropicAPI(ModelAPI):
     @override
     def should_retry(self, ex: Exception) -> bool:
         if isinstance(ex, APIStatusError):
+            # for unknown reasons, anthropic does not always set status_code == 529
+            # for "overloaded_error" so we check for it explicitly
+            if (
+                isinstance(ex.body, dict)
+                and ex.body.get("error", {}).get("type", "") == "overloaded_error"
+            ):
+                return True
+
+            # standard http status code checking
             return is_retryable_http_status(ex.status_code)
         elif httpx_should_retry(ex):
             return True
@@ -545,7 +559,7 @@ class AnthropicAPI(ModelAPI):
 
     def text_editor_tool_param(
         self, tool: ToolInfo
-    ) -> ToolTextEditor20250124Param | None:
+    ) -> ToolTextEditor20250124Param | BetaToolTextEditor20241022Param | None:
         # check for compatible 'text editor' tool
         if tool.name == "text_editor" and (
             sorted(tool.parameters.properties.keys())
@@ -561,8 +575,14 @@ class AnthropicAPI(ModelAPI):
             ]
         )
     ):
-        return ToolTextEditor20250124Param(
-            type="text_editor_20250124", name="str_replace_editor"
+        return (
+            BetaToolTextEditor20241022Param(
+                type="text_editor_20241022", name="str_replace_editor"
+            )
+            if self.is_claude_3_5()
+            else ToolTextEditor20250124Param(
+                type="text_editor_20250124", name="str_replace_editor"
+            )
         )
     # not a text_editor tool
     else:
@@ -571,7 +591,10 @@ class AnthropicAPI(ModelAPI):
 
 # tools can be either a stock tool param or a special Anthropic native use tool param
 ToolParamDef = (
-    ToolParam | BetaToolComputerUse20250124Param | ToolTextEditor20250124Param
+    ToolParam
+    | BetaToolComputerUse20250124Param
+    | ToolTextEditor20250124Param
+    | BetaToolTextEditor20241022Param
 )
@@ -580,6 +603,7 @@ def add_cache_control(
     | ToolParam
     | BetaToolComputerUse20250124Param
     | ToolTextEditor20250124Param
+    | BetaToolTextEditor20241022Param
     | dict[str, Any],
 ) -> None:
     cast(dict[str, Any], param)["cache_control"] = {"type": "ephemeral"}
@@ -844,6 +868,7 @@ def _names_for_tool_call(
     """
     mappings = (
         (INTERNAL_COMPUTER_TOOL_NAME, "computer_20250124", "computer"),
+        ("str_replace_editor", "text_editor_20241022", "text_editor"),
         ("str_replace_editor", "text_editor_20250124", "text_editor"),
         ("bash", "bash_20250124", "bash_session"),
     )
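On the `should_retry` change: Anthropic returns overload errors with the type nested under an `error` key in the response body, which is what the added check walks before falling back to status-code matching. A quick sketch of that predicate against synthetic bodies (the payloads here are illustrative):

```python
def is_overloaded(body: object) -> bool:
    # mirrors the nested .get() chain added in the diff above
    return (
        isinstance(body, dict)
        and body.get("error", {}).get("type", "") == "overloaded_error"
    )

print(is_overloaded({"type": "error", "error": {"type": "overloaded_error"}}))  # True
print(is_overloaded({"error": {"type": "rate_limit_error"}}))  # False
print(is_overloaded(None))  # False
```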