inspect-ai 0.3.93__py3-none-any.whl → 0.3.94__py3-none-any.whl
This diff shows the changes between publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
- inspect_ai/_display/textual/widgets/samples.py +3 -3
- inspect_ai/_display/textual/widgets/transcript.py +3 -29
- inspect_ai/_eval/task/run.py +10 -7
- inspect_ai/_util/answer.py +26 -0
- inspect_ai/_util/constants.py +0 -1
- inspect_ai/_util/local_server.py +51 -21
- inspect_ai/_view/www/dist/assets/index.css +14 -13
- inspect_ai/_view/www/dist/assets/index.js +400 -84
- inspect_ai/_view/www/log-schema.json +375 -0
- inspect_ai/_view/www/src/@types/log.d.ts +90 -12
- inspect_ai/_view/www/src/app/samples/transcript/SandboxEventView.module.css +2 -1
- inspect_ai/_view/www/src/app/samples/transcript/SpanEventView.tsx +174 -0
- inspect_ai/_view/www/src/app/samples/transcript/ToolEventView.tsx +8 -8
- inspect_ai/_view/www/src/app/samples/transcript/TranscriptView.tsx +12 -2
- inspect_ai/_view/www/src/app/samples/transcript/TranscriptVirtualListComponent.module.css +1 -1
- inspect_ai/_view/www/src/app/samples/transcript/event/EventPanel.tsx +0 -3
- inspect_ai/_view/www/src/app/samples/transcript/transform/fixups.ts +87 -25
- inspect_ai/_view/www/src/app/samples/transcript/transform/treeify.ts +229 -17
- inspect_ai/_view/www/src/app/samples/transcript/transform/utils.ts +11 -0
- inspect_ai/_view/www/src/app/samples/transcript/types.ts +5 -1
- inspect_ai/agent/_as_solver.py +3 -1
- inspect_ai/agent/_as_tool.py +6 -4
- inspect_ai/agent/_handoff.py +5 -1
- inspect_ai/agent/_react.py +4 -3
- inspect_ai/agent/_run.py +6 -1
- inspect_ai/agent/_types.py +9 -0
- inspect_ai/dataset/_dataset.py +6 -3
- inspect_ai/log/__init__.py +10 -0
- inspect_ai/log/_convert.py +4 -9
- inspect_ai/log/_samples.py +14 -17
- inspect_ai/log/_transcript.py +77 -35
- inspect_ai/log/_tree.py +118 -0
- inspect_ai/model/_call_tools.py +42 -34
- inspect_ai/model/_model.py +45 -40
- inspect_ai/model/_providers/hf.py +27 -1
- inspect_ai/model/_providers/sglang.py +8 -2
- inspect_ai/model/_providers/vllm.py +6 -2
- inspect_ai/scorer/_choice.py +1 -2
- inspect_ai/solver/_chain.py +1 -1
- inspect_ai/solver/_fork.py +1 -1
- inspect_ai/solver/_multiple_choice.py +5 -22
- inspect_ai/solver/_plan.py +2 -2
- inspect_ai/solver/_transcript.py +6 -7
- inspect_ai/tool/_mcp/_mcp.py +6 -5
- inspect_ai/tool/_tools/_execute.py +4 -1
- inspect_ai/util/__init__.py +4 -0
- inspect_ai/util/_anyio.py +11 -0
- inspect_ai/util/_collect.py +50 -0
- inspect_ai/util/_span.py +58 -0
- inspect_ai/util/_subtask.py +27 -42
- {inspect_ai-0.3.93.dist-info → inspect_ai-0.3.94.dist-info}/METADATA +1 -1
- {inspect_ai-0.3.93.dist-info → inspect_ai-0.3.94.dist-info}/RECORD +56 -51
- {inspect_ai-0.3.93.dist-info → inspect_ai-0.3.94.dist-info}/WHEEL +1 -1
- inspect_ai/_display/core/group.py +0 -79
- {inspect_ai-0.3.93.dist-info → inspect_ai-0.3.94.dist-info}/entry_points.txt +0 -0
- {inspect_ai-0.3.93.dist-info → inspect_ai-0.3.94.dist-info}/licenses/LICENSE +0 -0
- {inspect_ai-0.3.93.dist-info → inspect_ai-0.3.94.dist-info}/top_level.txt +0 -0
inspect_ai/log/_samples.py
CHANGED
@@ -5,12 +5,11 @@ from typing import AsyncGenerator, Iterator, Literal
 
 from shortuuid import uuid
 
-from inspect_ai._util.constants import SAMPLE_SUBTASK
 from inspect_ai.dataset._dataset import Sample
 from inspect_ai.util._sandbox import SandboxConnection
 from inspect_ai.util._sandbox.context import sandbox_connections
 
-from ._transcript import …
+from ._transcript import ModelEvent, Transcript
 
 
 class ActiveSample:
@@ -47,7 +46,6 @@ class ActiveSample:
         self.total_tokens = 0
         self.transcript = transcript
         self.sandboxes = sandboxes
-        self.retry_count = 0
         self._interrupt_action: Literal["score", "error"] | None = None
 
     @property
@@ -151,27 +149,26 @@ def set_active_sample_total_messages(total_messages: int) -> None:
         active.total_messages = total_messages
 
 
+_active_model_event: ContextVar[ModelEvent | None] = ContextVar(
+    "_active_model_event", default=None
+)
+
+
 @contextlib.contextmanager
-def …
-…
+def track_active_model_event(event: ModelEvent) -> Iterator[None]:
+    token = _active_model_event.set(event)
     try:
         yield
     finally:
-…
-
-
-def reset_active_sample_retries() -> None:
-    active = sample_active()
-    if active:
-        active.retry_count = 0
+        _active_model_event.reset(token)
 
 
 def report_active_sample_retry() -> None:
-…
-    if …
-…
-…
-…
+    model_event = _active_model_event.get()
+    if model_event is not None:
+        if model_event.retries is None:
+            model_event.retries = 0
+        model_event.retries = model_event.retries + 1
 
 
 _sample_active: ContextVar[ActiveSample | None] = ContextVar(
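Taken together, these changes move retry bookkeeping off of ActiveSample.retry_count and onto the ModelEvent itself, tracked via a ContextVar. A minimal sketch of how the new helpers interact (function and module paths are the ones added above; constructing a real ModelEvent requires its other fields, so one is assumed to already exist):

# Minimal sketch: retries now accumulate on the active ModelEvent.
from inspect_ai.log._samples import report_active_sample_retry, track_active_model_event
from inspect_ai.log._transcript import ModelEvent

def simulate_two_retries(event: ModelEvent) -> None:
    # a provider enters this context around a generate() call
    with track_active_model_event(event):
        # each retried API request reports itself
        report_active_sample_retry()
        report_active_sample_retry()
    # the count lands on the event rather than on ActiveSample.retry_count
    assert event.retries == 2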
inspect_ai/log/_transcript.py
CHANGED
@@ -23,9 +23,10 @@ from pydantic import (
 )
 from shortuuid import uuid
 
-from inspect_ai._util.constants import …
+from inspect_ai._util.constants import DESERIALIZING
 from inspect_ai._util.error import EvalError
-from inspect_ai._util.json import JsonChange …
+from inspect_ai._util.json import JsonChange
+from inspect_ai._util.logger import warn_once
 from inspect_ai._util.working import sample_working_time
 from inspect_ai.dataset._dataset import Sample
 from inspect_ai.log._message import LoggingMessage
@@ -34,7 +35,6 @@ from inspect_ai.model._generate_config import GenerateConfig
 from inspect_ai.model._model_call import ModelCall
 from inspect_ai.model._model_output import ModelOutput
 from inspect_ai.scorer._metric import Score
-from inspect_ai.solver._task_state import state_jsonable
 from inspect_ai.tool._tool import ToolResult
 from inspect_ai.tool._tool_call import (
     ToolCall,
@@ -44,6 +44,7 @@ from inspect_ai.tool._tool_call import (
 )
 from inspect_ai.tool._tool_choice import ToolChoice
 from inspect_ai.tool._tool_info import ToolInfo
+from inspect_ai.util._span import current_span_id
 from inspect_ai.util._store import store, store_changes, store_jsonable
 
 logger = getLogger(__name__)
@@ -57,6 +58,9 @@ class BaseEvent(BaseModel):
     }
     id_: str = Field(default_factory=lambda: str(uuid()), exclude=True)
 
+    span_id: str | None = Field(default=None)
+    """Span the event occurred within."""
+
     timestamp: datetime = Field(default_factory=datetime.now)
     """Clock time at which event occurred."""
 
@@ -66,6 +70,17 @@ class BaseEvent(BaseModel):
     pending: bool | None = Field(default=None)
     """Is this event pending?"""
 
+    def model_post_init(self, __context: Any) -> None:
+        # check if deserializing
+        is_deserializing = isinstance(__context, dict) and __context.get(
+            DESERIALIZING, False
+        )
+
+        # Generate context id fields if not deserializing
+        if not is_deserializing:
+            if self.span_id is None:
+                self.span_id = current_span_id()
+
     @field_serializer("timestamp")
     def serialize_timestamp(self, dt: datetime) -> str:
         return dt.astimezone().isoformat()
@@ -147,6 +162,9 @@ class ModelEvent(BaseEvent):
     output: ModelOutput
     """Output from model."""
 
+    retries: int | None = Field(default=None)
+    """Retries for the model API request."""
+
     error: str | None = Field(default=None)
     """Error which occurred during model call."""
 
@@ -203,7 +221,13 @@ class ToolEvent(BaseEvent):
     """Error that occurred during tool call."""
 
     events: list["Event"] = Field(default_factory=list)
-    """Transcript of events for tool. …
+    """Transcript of events for tool.
+
+    Note that events are no longer recorded separately within
+    tool events but rather all events are recorded in the main
+    transcript. This field is deprecated and here for backwards
+    compatibility with transcripts that have sub-events.
+    """
 
     completed: datetime | None = Field(default=None)
     """Time that tool call completed (see `timestamp` for started)"""
@@ -222,7 +246,6 @@ class ToolEvent(BaseEvent):
         result: ToolResult,
         truncated: tuple[int, int] | None,
         error: ToolCallError | None,
-        events: list["Event"],
         waiting_time: float,
         agent: str | None,
         failed: bool | None,
@@ -230,7 +253,6 @@ class ToolEvent(BaseEvent):
         self.result = result
         self.truncated = truncated
         self.error = error
-        self.events = events
         self.pending = None
         completed = datetime.now()
         self.completed = completed
@@ -402,6 +424,35 @@ class ScoreEvent(BaseEvent):
     """Was this an intermediate scoring?"""
 
 
+class SpanBeginEvent(BaseEvent):
+    """Mark the beginning of a transcript span."""
+
+    event: Literal["span_begin"] = Field(default="span_begin")
+    """Event type."""
+
+    id: str
+    """Unique identifier for span."""
+
+    parent_id: str | None = Field(default=None)
+    """Identifier for parent span."""
+
+    type: str | None = Field(default=None)
+    """Optional 'type' field for span."""
+
+    name: str
+    """Span name."""
+
+
+class SpanEndEvent(BaseEvent):
+    """Mark the end of a transcript span."""
+
+    event: Literal["span_end"] = Field(default="span_end")
+    """Event type."""
+
+    id: str
+    """Unique identifier for span."""
+
+
 class StepEvent(BaseEvent):
     """Step within current sample or subtask."""
 
@@ -437,7 +488,13 @@ class SubtaskEvent(BaseEvent):
     """Subtask function result."""
 
     events: list["Event"] = Field(default_factory=list)
-    """Transcript of events for subtask. …
+    """Transcript of events for subtask.
+
+    Note that events are no longer recorded separately within
+    subtasks but rather all events are recorded in the main
+    transcript. This field is deprecated and here for backwards
+    compatibility with transcripts that have sub-events.
+    """
 
     completed: datetime | None = Field(default=None)
     """Time that subtask completed (see `timestamp` for started)"""
@@ -467,6 +524,8 @@ Event: TypeAlias = Union[
     | ErrorEvent
     | LoggerEvent
     | InfoEvent
+    | SpanBeginEvent
+    | SpanEndEvent
     | StepEvent
     | SubtaskEvent,
 ]
@@ -480,8 +539,7 @@ class Transcript:
 
     _event_logger: Callable[[Event], None] | None
 
-    def __init__(self …
-        self.name = name
+    def __init__(self) -> None:
         self._event_logger = None
         self._events: list[Event] = []
 
@@ -498,19 +556,20 @@ class Transcript:
     def step(self, name: str, type: str | None = None) -> Iterator[None]:
         """Context manager for recording StepEvent.
 
+        The `step()` context manager is deprecated and will be removed in a future version.
+        Please use the `span()` context manager instead.
+
         Args:
             name (str): Step name.
             type (str | None): Optional step type.
         """
-…
-…
-…
-…
-…
-…
-…
-        # end step event
-        self._event(StepEvent(action="end", name=name, type=type))
+        warn_once(
+            logger,
+            "The `transcript().step()` context manager is deprecated and will "
+            + "be removed in a future version. Please replace the call to step() "
+            + "with a call to span().",
+        )
+        yield
 
     @property
     def events(self) -> Sequence[Event]:
@@ -551,23 +610,6 @@ def track_store_changes() -> Iterator[None]:
         transcript()._event(StoreEvent(changes=changes))
 
 
-@contextlib.contextmanager
-def track_state_changes(type: str | None = None) -> Iterator[None]:
-    # we only want to track for step() inside the the sample
-    # (solver level tracking is handled already and there are
-    # no state changes in subtasks)
-    if transcript().name == SAMPLE_SUBTASK and type != "solver":
-        before = state_jsonable()
-        yield
-        after = state_jsonable()
-
-        changes = json_changes(before, after)
-        if changes:
-            transcript()._event(StateEvent(changes=changes))
-    else:
-        yield
-
-
 def init_transcript(transcript: Transcript) -> None:
     _transcript.set(transcript)
 
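The deprecation message above asks callers to swap transcript().step() for the new span() context manager. A hedged migration sketch; span() is imported here from inspect_ai.util._span, the path used elsewhere in this diff, and the inspect_ai/util/__init__.py change suggests it is also re-exported publicly:

from inspect_ai.util._span import span

async def score_phase() -> None:
    # before (now only warns and yields):
    #   with transcript().step("scoring", type="scorer"):
    #       ...
    # after: events recorded inside carry a span_id and nest under the span
    async with span(name="scoring", type="scorer"):
        ...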
inspect_ai/log/_tree.py
ADDED
@@ -0,0 +1,118 @@
+from dataclasses import dataclass, field
+from logging import getLogger
+from typing import Iterable, Sequence, TypeAlias
+
+from ._transcript import Event, SpanBeginEvent, SpanEndEvent
+
+logger = getLogger(__name__)
+
+EventNode: TypeAlias = "SpanNode" | Event
+"""Node in an event tree."""
+
+EventTree: TypeAlias = list[EventNode]
+"""Tree of events (has invividual events and event spans)."""
+
+
+@dataclass
+class SpanNode:
+    """Event tree node representing a span of events."""
+
+    id: str
+    """Span id."""
+
+    parent_id: str | None
+    """Parent span id."""
+
+    type: str | None
+    """Optional 'type' field for span."""
+
+    name: str
+    """Span name."""
+
+    begin: SpanBeginEvent
+    """Span begin event."""
+
+    end: SpanEndEvent | None = None
+    """Span end event (if any)."""
+
+    children: list[EventNode] = field(default_factory=list)
+    """Children in the span."""
+
+
+def event_tree(events: Sequence[Event]) -> EventTree:
+    """Build a tree representation of a sequence of events.
+
+    Organize events heirarchially into event spans.
+
+    Args:
+        events: Sequence of `Event`.
+
+    Returns:
+        Event tree.
+    """
+    # Convert one flat list of (possibly interleaved) events into *forest*
+    # (list of root-level items).
+
+    # Pre-create one node per span so we can attach events no matter when they
+    # arrive in the file. A single forward scan guarantees that the order of
+    # `children` inside every span reflects the order in which things appeared
+    # in the transcript.
+    nodes: dict[str, SpanNode] = {
+        ev.id: SpanNode(
+            id=ev.id, parent_id=ev.parent_id, type=ev.type, name=ev.name, begin=ev
+        )
+        for ev in events
+        if isinstance(ev, SpanBeginEvent)
+    }
+
+    roots: list[EventNode] = []
+
+    # Where should an event with `span_id` go?
+    def bucket(span_id: str | None) -> list[EventNode]:
+        if span_id and span_id in nodes:
+            return nodes[span_id].children
+        return roots  # root level
+
+    # Single pass in original order
+    for ev in events:
+        if isinstance(ev, SpanBeginEvent):  # span starts
+            bucket(ev.parent_id).append(nodes[ev.id])
+
+        elif isinstance(ev, SpanEndEvent):  # span ends
+            if n := nodes.get(ev.id):
+                n.end = ev
+            else:
+                logger.warning(f"Span end event (id: {ev.id} with no span begin)")
+
+        else:  # ordinary event
+            bucket(ev.span_id).append(ev)
+
+    return roots
+
+
+def event_sequence(tree: EventTree) -> Iterable[Event]:
+    """Flatten a span forest back into a properly ordered seqeunce.
+
+    Args:
+        tree: Event tree
+
+    Returns:
+        Sequence of events.
+    """
+    for item in tree:
+        if isinstance(item, SpanNode):
+            yield item.begin
+            yield from event_sequence(item.children)
+            if item.end:
+                yield item.end
+        else:
+            yield item
+
+
+def _print_event_tree(tree: EventTree, indent: str = "") -> None:
+    for item in tree:
+        if isinstance(item, SpanNode):
+            print(f"{indent}span ({item.type}): {item.name}")
+            _print_event_tree(item.children, f"{indent}  ")
+        else:
+            print(f"{indent}{item.event}")
inspect_ai/model/_call_tools.py
CHANGED
@@ -61,6 +61,7 @@ from inspect_ai.tool._tool_params import ToolParams
 from inspect_ai.util import OutputLimitExceededError
 from inspect_ai.util._anyio import inner_exception
 from inspect_ai.util._limit import LimitExceededError, apply_limits
+from inspect_ai.util._span import span
 
 from ._chat_message import (
     ChatMessage,
@@ -109,26 +110,18 @@ async def execute_tools(
     """
     message = messages[-1]
     if isinstance(message, ChatMessageAssistant) and message.tool_calls:
-        from inspect_ai.log._transcript import (
-            ToolEvent,
-            Transcript,
-            init_transcript,
-            track_store_changes,
-            transcript,
-        )
+        from inspect_ai.log._transcript import ToolEvent, transcript
 
         tdefs = await tool_defs(tools)
 
         async def call_tool_task(
             call: ToolCall,
+            event: ToolEvent,
             conversation: list[ChatMessage],
             send_stream: MemoryObjectSendStream[
                 tuple[ExecuteToolsResult, ToolEvent, Exception | None]
             ],
         ) -> None:
-            # create a transript for this call
-            init_transcript(Transcript(name=call.function))
-
             result: ToolResult = ""
             messages: list[ChatMessage] = []
             output: ModelOutput | None = None
@@ -136,15 +129,14 @@ async def execute_tools(
             tool_error: ToolCallError | None = None
             tool_exception: Exception | None = None
             try:
-…
-…
-…
-…
-…
-…
-…
-…
-                    raise inner_ex.with_traceback(inner_ex.__traceback__)
+                try:
+                    result, messages, output, agent = await call_tool(
+                        tdefs, message.text, call, event, conversation
+                    )
+                # unwrap exception group
+                except Exception as ex:
+                    inner_ex = inner_exception(ex)
+                    raise inner_ex.with_traceback(inner_ex.__traceback__)
 
             except TimeoutError:
                 tool_error = ToolCallError(
@@ -227,7 +219,6 @@ async def execute_tools(
                truncated=truncated,
                view=call.view,
                error=tool_error,
-               events=list(transcript().events),
                agent=agent,
            )
 
@@ -270,7 +261,6 @@ async def execute_tools(
            internal=call.internal,
            pending=True,
        )
-        transcript()._event(event)
 
        # execute the tool call. if the operator cancels the
        # tool call then synthesize the appropriate message/event
@@ -280,7 +270,7 @@ async def execute_tools(
 
        result_exception = None
        async with anyio.create_task_group() as tg:
-            tg.start_soon(call_tool_task, call, messages, send_stream)
+            tg.start_soon(call_tool_task, call, event, messages, send_stream)
            event._set_cancel_fn(tg.cancel_scope.cancel)
            async with receive_stream:
                (
@@ -306,7 +296,6 @@ async def execute_tools(
                    truncated=None,
                    view=call.view,
                    error=tool_message.error,
-                    events=[],
                )
                transcript().info(
                    f"Tool call '{call.function}' was cancelled by operator."
@@ -326,7 +315,6 @@ async def execute_tools(
                result=result_event.result,
                truncated=result_event.truncated,
                error=result_event.error,
-                events=result_event.events,
                waiting_time=waiting_time_end - waiting_time_start,
                agent=result_event.agent,
                failed=True if result_exception else None,
@@ -347,19 +335,34 @@ async def execute_tools(
 
 
 async def call_tool(
-    tools: list[ToolDef], …
+    tools: list[ToolDef],
+    message: str,
+    call: ToolCall,
+    event: BaseModel,
+    conversation: list[ChatMessage],
 ) -> tuple[ToolResult, list[ChatMessage], ModelOutput | None, str | None]:
     from inspect_ai.agent._handoff import AgentTool
-    from inspect_ai.log._transcript import SampleLimitEvent, transcript
+    from inspect_ai.log._transcript import SampleLimitEvent, ToolEvent, transcript
+
+    # dodge circular import
+    assert isinstance(event, ToolEvent)
+
+    # this function is responsible for transcript events so that it can
+    # put them in the right enclosure (e.g. handoff/agent/tool). This
+    # means that if we throw early we need to do the enclosure when raising.
+    async def record_tool_parsing_error(error: str) -> Exception:
+        async with span(name=call.function, type="tool"):
+            transcript()._event(event)
+        return ToolParsingError(error)
 
     # if there was an error parsing the ToolCall, raise that
     if call.parse_error:
-        raise …
+        raise await record_tool_parsing_error(call.parse_error)
 
     # find the tool
     tool_def = next((tool for tool in tools if tool.name == call.function), None)
     if tool_def is None:
-        raise …
+        raise await record_tool_parsing_error(f"Tool {call.function} not found")
 
     # if we have a tool approver, apply it now
     from inspect_ai.approval._apply import apply_tool_approval
@@ -382,7 +385,7 @@ async def call_tool(
     # validate the schema of the passed object
     validation_errors = validate_tool_input(call.arguments, tool_def.parameters)
     if validation_errors:
-        raise …
+        raise await record_tool_parsing_error(validation_errors)
 
     # get arguments (with creation of dataclasses, pydantic objects, etc.)
     arguments = tool_params(call.arguments, tool_def.tool)
@@ -391,14 +394,18 @@ async def call_tool(
     with trace_action(
         logger, "Tool Call", format_function_call(tool_def.name, arguments, width=1000)
     ):
-        # agent tools get special handling
         if isinstance(tool_def.tool, AgentTool):
-            …
+            async with span(tool_def.tool.name, type="handoff"):
+                async with span(name=call.function, type="tool"):
+                    transcript()._event(event)
+                    return await agent_handoff(tool_def, call, conversation)
 
         # normal tool call
         else:
-            …
-            …
+            async with span(name=call.function, type="tool"):
+                transcript()._event(event)
+                result: ToolResult = await tool_def.tool(**arguments)
+                return result, [], None, None
 
 
 async def agent_handoff(
@@ -463,7 +470,8 @@ async def agent_handoff(
     agent_state = AgentState(messages=copy(agent_conversation))
     try:
         with apply_limits(agent_tool.limits):
-            …
+            async with span(name=agent_name, type="agent"):
+                agent_state = await agent_tool.agent(agent_state, **arguments)
     except LimitExceededError as ex:
         limit_error = ex
 
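The net effect of these changes is that tool calls, handoffs, and handed-off agents no longer carry their own nested sub-transcripts; the ToolEvent is emitted between span_begin/span_end markers in the single main transcript. A rough sketch of reading the new shape back (direct children of a tool span only; the event classes are those defined in the _transcript.py diff above):

from inspect_ai.log._transcript import Event, SpanBeginEvent, SpanEndEvent

def direct_children_of_tool_spans(events: list[Event]) -> list[Event]:
    """Collect non-span events recorded directly inside any 'tool' span."""
    open_tool_spans: set[str] = set()
    children: list[Event] = []
    for ev in events:
        if isinstance(ev, SpanBeginEvent) and ev.type == "tool":
            open_tool_spans.add(ev.id)
        elif isinstance(ev, SpanEndEvent):
            open_tool_spans.discard(ev.id)
        elif ev.span_id in open_tool_spans:
            children.append(ev)
    return children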