inspect-ai 0.3.93__py3-none-any.whl → 0.3.94__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57)
  1. inspect_ai/_display/textual/widgets/samples.py +3 -3
  2. inspect_ai/_display/textual/widgets/transcript.py +3 -29
  3. inspect_ai/_eval/task/run.py +10 -7
  4. inspect_ai/_util/answer.py +26 -0
  5. inspect_ai/_util/constants.py +0 -1
  6. inspect_ai/_util/local_server.py +51 -21
  7. inspect_ai/_view/www/dist/assets/index.css +14 -13
  8. inspect_ai/_view/www/dist/assets/index.js +400 -84
  9. inspect_ai/_view/www/log-schema.json +375 -0
  10. inspect_ai/_view/www/src/@types/log.d.ts +90 -12
  11. inspect_ai/_view/www/src/app/samples/transcript/SandboxEventView.module.css +2 -1
  12. inspect_ai/_view/www/src/app/samples/transcript/SpanEventView.tsx +174 -0
  13. inspect_ai/_view/www/src/app/samples/transcript/ToolEventView.tsx +8 -8
  14. inspect_ai/_view/www/src/app/samples/transcript/TranscriptView.tsx +12 -2
  15. inspect_ai/_view/www/src/app/samples/transcript/TranscriptVirtualListComponent.module.css +1 -1
  16. inspect_ai/_view/www/src/app/samples/transcript/event/EventPanel.tsx +0 -3
  17. inspect_ai/_view/www/src/app/samples/transcript/transform/fixups.ts +87 -25
  18. inspect_ai/_view/www/src/app/samples/transcript/transform/treeify.ts +229 -17
  19. inspect_ai/_view/www/src/app/samples/transcript/transform/utils.ts +11 -0
  20. inspect_ai/_view/www/src/app/samples/transcript/types.ts +5 -1
  21. inspect_ai/agent/_as_solver.py +3 -1
  22. inspect_ai/agent/_as_tool.py +6 -4
  23. inspect_ai/agent/_handoff.py +5 -1
  24. inspect_ai/agent/_react.py +4 -3
  25. inspect_ai/agent/_run.py +6 -1
  26. inspect_ai/agent/_types.py +9 -0
  27. inspect_ai/dataset/_dataset.py +6 -3
  28. inspect_ai/log/__init__.py +10 -0
  29. inspect_ai/log/_convert.py +4 -9
  30. inspect_ai/log/_samples.py +14 -17
  31. inspect_ai/log/_transcript.py +77 -35
  32. inspect_ai/log/_tree.py +118 -0
  33. inspect_ai/model/_call_tools.py +42 -34
  34. inspect_ai/model/_model.py +45 -40
  35. inspect_ai/model/_providers/hf.py +27 -1
  36. inspect_ai/model/_providers/sglang.py +8 -2
  37. inspect_ai/model/_providers/vllm.py +6 -2
  38. inspect_ai/scorer/_choice.py +1 -2
  39. inspect_ai/solver/_chain.py +1 -1
  40. inspect_ai/solver/_fork.py +1 -1
  41. inspect_ai/solver/_multiple_choice.py +5 -22
  42. inspect_ai/solver/_plan.py +2 -2
  43. inspect_ai/solver/_transcript.py +6 -7
  44. inspect_ai/tool/_mcp/_mcp.py +6 -5
  45. inspect_ai/tool/_tools/_execute.py +4 -1
  46. inspect_ai/util/__init__.py +4 -0
  47. inspect_ai/util/_anyio.py +11 -0
  48. inspect_ai/util/_collect.py +50 -0
  49. inspect_ai/util/_span.py +58 -0
  50. inspect_ai/util/_subtask.py +27 -42
  51. {inspect_ai-0.3.93.dist-info → inspect_ai-0.3.94.dist-info}/METADATA +1 -1
  52. {inspect_ai-0.3.93.dist-info → inspect_ai-0.3.94.dist-info}/RECORD +56 -51
  53. {inspect_ai-0.3.93.dist-info → inspect_ai-0.3.94.dist-info}/WHEEL +1 -1
  54. inspect_ai/_display/core/group.py +0 -79
  55. {inspect_ai-0.3.93.dist-info → inspect_ai-0.3.94.dist-info}/entry_points.txt +0 -0
  56. {inspect_ai-0.3.93.dist-info → inspect_ai-0.3.94.dist-info}/licenses/LICENSE +0 -0
  57. {inspect_ai-0.3.93.dist-info → inspect_ai-0.3.94.dist-info}/top_level.txt +0 -0
inspect_ai/log/_samples.py

@@ -5,12 +5,11 @@ from typing import AsyncGenerator, Iterator, Literal
 
 from shortuuid import uuid
 
-from inspect_ai._util.constants import SAMPLE_SUBTASK
 from inspect_ai.dataset._dataset import Sample
 from inspect_ai.util._sandbox import SandboxConnection
 from inspect_ai.util._sandbox.context import sandbox_connections
 
-from ._transcript import Transcript, transcript
+from ._transcript import ModelEvent, Transcript
 
 
 class ActiveSample:
@@ -47,7 +46,6 @@ class ActiveSample:
         self.total_tokens = 0
         self.transcript = transcript
         self.sandboxes = sandboxes
-        self.retry_count = 0
         self._interrupt_action: Literal["score", "error"] | None = None
 
     @property
@@ -151,27 +149,26 @@ def set_active_sample_total_messages(total_messages: int) -> None:
         active.total_messages = total_messages
 
 
+_active_model_event: ContextVar[ModelEvent | None] = ContextVar(
+    "_active_model_event", default=None
+)
+
+
 @contextlib.contextmanager
-def track_active_sample_retries() -> Iterator[None]:
-    reset_active_sample_retries()
+def track_active_model_event(event: ModelEvent) -> Iterator[None]:
+    token = _active_model_event.set(event)
     try:
         yield
     finally:
-        reset_active_sample_retries()
-
-
-def reset_active_sample_retries() -> None:
-    active = sample_active()
-    if active:
-        active.retry_count = 0
+        _active_model_event.reset(token)
 
 
 def report_active_sample_retry() -> None:
-    active = sample_active()
-    if active:
-        # only do this for the top level subtask
-        if transcript().name == SAMPLE_SUBTASK:
-            active.retry_count = active.retry_count + 1
+    model_event = _active_model_event.get()
+    if model_event is not None:
+        if model_event.retries is None:
+            model_event.retries = 0
+        model_event.retries = model_event.retries + 1
 
 
 _sample_active: ContextVar[ActiveSample | None] = ContextVar(
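
For orientation, a minimal sketch (not part of the diff) of how the reworked retry tracking above fits together: a `ModelEvent` is made active for the duration of a model API call, and `report_active_sample_retry()` now increments that event's `retries` field instead of a counter on `ActiveSample`. The wrapper function below is hypothetical.

```python
from inspect_ai.log._samples import (
    report_active_sample_retry,
    track_active_model_event,
)
from inspect_ai.log._transcript import ModelEvent


def record_one_retry(event: ModelEvent) -> None:
    # make the event "active" while the model API call is in flight
    with track_active_model_event(event):
        # a provider's retry handler would call this on each retry
        report_active_sample_retry()
    # retries starts at None and is initialized to 0 before incrementing
    assert event.retries == 1
```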
inspect_ai/log/_transcript.py

@@ -23,9 +23,10 @@ from pydantic import (
 )
 from shortuuid import uuid
 
-from inspect_ai._util.constants import SAMPLE_SUBTASK
+from inspect_ai._util.constants import DESERIALIZING
 from inspect_ai._util.error import EvalError
-from inspect_ai._util.json import JsonChange, json_changes
+from inspect_ai._util.json import JsonChange
+from inspect_ai._util.logger import warn_once
 from inspect_ai._util.working import sample_working_time
 from inspect_ai.dataset._dataset import Sample
 from inspect_ai.log._message import LoggingMessage
@@ -34,7 +35,6 @@ from inspect_ai.model._generate_config import GenerateConfig
 from inspect_ai.model._model_call import ModelCall
 from inspect_ai.model._model_output import ModelOutput
 from inspect_ai.scorer._metric import Score
-from inspect_ai.solver._task_state import state_jsonable
 from inspect_ai.tool._tool import ToolResult
 from inspect_ai.tool._tool_call import (
     ToolCall,
@@ -44,6 +44,7 @@ from inspect_ai.tool._tool_call import (
 )
 from inspect_ai.tool._tool_choice import ToolChoice
 from inspect_ai.tool._tool_info import ToolInfo
+from inspect_ai.util._span import current_span_id
 from inspect_ai.util._store import store, store_changes, store_jsonable
 
 logger = getLogger(__name__)
@@ -57,6 +58,9 @@ class BaseEvent(BaseModel):
     }
     id_: str = Field(default_factory=lambda: str(uuid()), exclude=True)
 
+    span_id: str | None = Field(default=None)
+    """Span the event occurred within."""
+
     timestamp: datetime = Field(default_factory=datetime.now)
     """Clock time at which event occurred."""
 
@@ -66,6 +70,17 @@ class BaseEvent(BaseModel):
     pending: bool | None = Field(default=None)
     """Is this event pending?"""
 
+    def model_post_init(self, __context: Any) -> None:
+        # check if deserializing
+        is_deserializing = isinstance(__context, dict) and __context.get(
+            DESERIALIZING, False
+        )
+
+        # Generate context id fields if not deserializing
+        if not is_deserializing:
+            if self.span_id is None:
+                self.span_id = current_span_id()
+
     @field_serializer("timestamp")
     def serialize_timestamp(self, dt: datetime) -> str:
         return dt.astimezone().isoformat()
@@ -147,6 +162,9 @@ class ModelEvent(BaseEvent):
     output: ModelOutput
     """Output from model."""
 
+    retries: int | None = Field(default=None)
+    """Retries for the model API request."""
+
     error: str | None = Field(default=None)
     """Error which occurred during model call."""
 
@@ -203,7 +221,13 @@ class ToolEvent(BaseEvent):
     """Error that occurred during tool call."""
 
     events: list["Event"] = Field(default_factory=list)
-    """Transcript of events for tool."""
+    """Transcript of events for tool.
+
+    Note that events are no longer recorded separately within
+    tool events but rather all events are recorded in the main
+    transcript. This field is deprecated and here for backwards
+    compatibility with transcripts that have sub-events.
+    """
 
     completed: datetime | None = Field(default=None)
     """Time that tool call completed (see `timestamp` for started)"""
@@ -222,7 +246,6 @@ class ToolEvent(BaseEvent):
         result: ToolResult,
         truncated: tuple[int, int] | None,
         error: ToolCallError | None,
-        events: list["Event"],
         waiting_time: float,
         agent: str | None,
         failed: bool | None,
@@ -230,7 +253,6 @@ class ToolEvent(BaseEvent):
         self.result = result
         self.truncated = truncated
         self.error = error
-        self.events = events
        self.pending = None
        completed = datetime.now()
        self.completed = completed
@@ -402,6 +424,35 @@ class ScoreEvent(BaseEvent):
     """Was this an intermediate scoring?"""
 
 
+class SpanBeginEvent(BaseEvent):
+    """Mark the beginning of a transcript span."""
+
+    event: Literal["span_begin"] = Field(default="span_begin")
+    """Event type."""
+
+    id: str
+    """Unique identifier for span."""
+
+    parent_id: str | None = Field(default=None)
+    """Identifier for parent span."""
+
+    type: str | None = Field(default=None)
+    """Optional 'type' field for span."""
+
+    name: str
+    """Span name."""
+
+
+class SpanEndEvent(BaseEvent):
+    """Mark the end of a transcript span."""
+
+    event: Literal["span_end"] = Field(default="span_end")
+    """Event type."""
+
+    id: str
+    """Unique identifier for span."""
+
+
 class StepEvent(BaseEvent):
     """Step within current sample or subtask."""
 
@@ -437,7 +488,13 @@ class SubtaskEvent(BaseEvent):
     """Subtask function result."""
 
     events: list["Event"] = Field(default_factory=list)
-    """Transcript of events for subtask."""
+    """Transcript of events for subtask.
+
+    Note that events are no longer recorded separately within
+    subtasks but rather all events are recorded in the main
+    transcript. This field is deprecated and here for backwards
+    compatibility with transcripts that have sub-events.
+    """
 
     completed: datetime | None = Field(default=None)
     """Time that subtask completed (see `timestamp` for started)"""
@@ -467,6 +524,8 @@ Event: TypeAlias = Union[
     | ErrorEvent
     | LoggerEvent
     | InfoEvent
+    | SpanBeginEvent
+    | SpanEndEvent
     | StepEvent
     | SubtaskEvent,
 ]
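
Together with the new `span_id` field on `BaseEvent`, the `SpanBeginEvent`/`SpanEndEvent` pair above lets the flat event list encode nesting. A hedged sketch of what this looks like from user code, assuming the `span()` context manager added in `inspect_ai/util/_span.py` (listed in the files above) with the `name`/`type` signature used elsewhere in this diff:

```python
from inspect_ai.log._transcript import transcript
from inspect_ai.util._span import span


async def tagged_info() -> None:
    # the span brackets its contents with SpanBeginEvent/SpanEndEvent,
    # and the InfoEvent recorded inside it carries that span's id in span_id
    async with span(name="analysis", type="agent"):
        transcript().info("recorded inside the 'analysis' span")
```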
@@ -480,8 +539,7 @@ class Transcript:
 
     _event_logger: Callable[[Event], None] | None
 
-    def __init__(self, name: str = "") -> None:
-        self.name = name
+    def __init__(self) -> None:
         self._event_logger = None
         self._events: list[Event] = []
 
@@ -498,19 +556,20 @@
     def step(self, name: str, type: str | None = None) -> Iterator[None]:
         """Context manager for recording StepEvent.
 
+        The `step()` context manager is deprecated and will be removed in a future version.
+        Please use the `span()` context manager instead.
+
        Args:
            name (str): Step name.
            type (str | None): Optional step type.
        """
-        # step event
-        self._event(StepEvent(action="begin", name=name, type=type))
-
-        # run the step (tracking state/store changes)
-        with track_state_changes(type), track_store_changes():
-            yield
-
-        # end step event
-        self._event(StepEvent(action="end", name=name, type=type))
+        warn_once(
+            logger,
+            "The `transcript().step()` context manager is deprecated and will "
+            + "be removed in a future version. Please replace the call to step() "
+            + "with a call to span().",
+        )
+        yield
 
     @property
     def events(self) -> Sequence[Event]:
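
The deprecation above is a mechanical migration from the synchronous `step()` context manager to the asynchronous `span()` context manager; a hedged before/after sketch:

```python
from inspect_ai.util._span import span


async def analyze() -> None:
    # before (deprecated; step() now only warns and yields):
    # with transcript().step("analyze", type="solver"):
    #     ...

    # after:
    async with span(name="analyze", type="solver"):
        ...  # work previously wrapped in step()
```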
@@ -551,23 +610,6 @@ def track_store_changes() -> Iterator[None]:
         transcript()._event(StoreEvent(changes=changes))
 
 
-@contextlib.contextmanager
-def track_state_changes(type: str | None = None) -> Iterator[None]:
-    # we only want to track for step() inside the the sample
-    # (solver level tracking is handled already and there are
-    # no state changes in subtasks)
-    if transcript().name == SAMPLE_SUBTASK and type != "solver":
-        before = state_jsonable()
-        yield
-        after = state_jsonable()
-
-        changes = json_changes(before, after)
-        if changes:
-            transcript()._event(StateEvent(changes=changes))
-    else:
-        yield
-
-
 def init_transcript(transcript: Transcript) -> None:
     _transcript.set(transcript)
 
inspect_ai/log/_tree.py (new file)

@@ -0,0 +1,118 @@
+from dataclasses import dataclass, field
+from logging import getLogger
+from typing import Iterable, Sequence, TypeAlias
+
+from ._transcript import Event, SpanBeginEvent, SpanEndEvent
+
+logger = getLogger(__name__)
+
+EventNode: TypeAlias = "SpanNode" | Event
+"""Node in an event tree."""
+
+EventTree: TypeAlias = list[EventNode]
+"""Tree of events (has individual events and event spans)."""
+
+
+@dataclass
+class SpanNode:
+    """Event tree node representing a span of events."""
+
+    id: str
+    """Span id."""
+
+    parent_id: str | None
+    """Parent span id."""
+
+    type: str | None
+    """Optional 'type' field for span."""
+
+    name: str
+    """Span name."""
+
+    begin: SpanBeginEvent
+    """Span begin event."""
+
+    end: SpanEndEvent | None = None
+    """Span end event (if any)."""
+
+    children: list[EventNode] = field(default_factory=list)
+    """Children in the span."""
+
+
+def event_tree(events: Sequence[Event]) -> EventTree:
+    """Build a tree representation of a sequence of events.
+
+    Organize events hierarchically into event spans.
+
+    Args:
+        events: Sequence of `Event`.
+
+    Returns:
+        Event tree.
+    """
+    # Convert one flat list of (possibly interleaved) events into a *forest*
+    # (list of root-level items).
+
+    # Pre-create one node per span so we can attach events no matter when they
+    # arrive in the file. A single forward scan guarantees that the order of
+    # `children` inside every span reflects the order in which things appeared
+    # in the transcript.
+    nodes: dict[str, SpanNode] = {
+        ev.id: SpanNode(
+            id=ev.id, parent_id=ev.parent_id, type=ev.type, name=ev.name, begin=ev
+        )
+        for ev in events
+        if isinstance(ev, SpanBeginEvent)
+    }
+
+    roots: list[EventNode] = []
+
+    # Where should an event with `span_id` go?
+    def bucket(span_id: str | None) -> list[EventNode]:
+        if span_id and span_id in nodes:
+            return nodes[span_id].children
+        return roots  # root level
+
+    # Single pass in original order
+    for ev in events:
+        if isinstance(ev, SpanBeginEvent):  # span starts
+            bucket(ev.parent_id).append(nodes[ev.id])
+
+        elif isinstance(ev, SpanEndEvent):  # span ends
+            if n := nodes.get(ev.id):
+                n.end = ev
+            else:
+                logger.warning(f"Span end event (id: {ev.id}) with no span begin")
+
+        else:  # ordinary event
+            bucket(ev.span_id).append(ev)
+
+    return roots
+
+
+def event_sequence(tree: EventTree) -> Iterable[Event]:
+    """Flatten a span forest back into a properly ordered sequence.
+
+    Args:
+        tree: Event tree
+
+    Returns:
+        Sequence of events.
+    """
+    for item in tree:
+        if isinstance(item, SpanNode):
+            yield item.begin
+            yield from event_sequence(item.children)
+            if item.end:
+                yield item.end
+        else:
+            yield item
+
+
+def _print_event_tree(tree: EventTree, indent: str = "") -> None:
+    for item in tree:
+        if isinstance(item, SpanNode):
+            print(f"{indent}span ({item.type}): {item.name}")
+            _print_event_tree(item.children, f"{indent}  ")
+        else:
+            print(f"{indent}{item.event}")
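
A short usage sketch for the new module (the log path and the use of the first sample are placeholders; `read_eval_log()` is the existing log-reading API):

```python
from inspect_ai.log import read_eval_log
from inspect_ai.log._tree import event_sequence, event_tree

log = read_eval_log("logs/example.eval")  # placeholder path
if log.samples:
    # nest the sample's flat event list under its spans...
    tree = event_tree(log.samples[0].events)
    # ...and flatten it back into a span-ordered sequence
    ordered = list(event_sequence(tree))
```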
inspect_ai/model/_call_tools.py

@@ -61,6 +61,7 @@ from inspect_ai.tool._tool_params import ToolParams
 from inspect_ai.util import OutputLimitExceededError
 from inspect_ai.util._anyio import inner_exception
 from inspect_ai.util._limit import LimitExceededError, apply_limits
+from inspect_ai.util._span import span
 
 from ._chat_message import (
     ChatMessage,
@@ -109,26 +110,18 @@ async def execute_tools(
     """
     message = messages[-1]
     if isinstance(message, ChatMessageAssistant) and message.tool_calls:
-        from inspect_ai.log._transcript import (
-            ToolEvent,
-            Transcript,
-            init_transcript,
-            track_store_changes,
-            transcript,
-        )
+        from inspect_ai.log._transcript import ToolEvent, transcript
 
         tdefs = await tool_defs(tools)
 
         async def call_tool_task(
             call: ToolCall,
+            event: ToolEvent,
             conversation: list[ChatMessage],
             send_stream: MemoryObjectSendStream[
                 tuple[ExecuteToolsResult, ToolEvent, Exception | None]
             ],
         ) -> None:
-            # create a transript for this call
-            init_transcript(Transcript(name=call.function))
-
             result: ToolResult = ""
             messages: list[ChatMessage] = []
             output: ModelOutput | None = None
@@ -136,15 +129,14 @@ async def execute_tools(
             tool_error: ToolCallError | None = None
             tool_exception: Exception | None = None
             try:
-                with track_store_changes():
-                    try:
-                        result, messages, output, agent = await call_tool(
-                            tdefs, message.text, call, conversation
-                        )
-                    # unwrap exception group
-                    except Exception as ex:
-                        inner_ex = inner_exception(ex)
-                        raise inner_ex.with_traceback(inner_ex.__traceback__)
+                try:
+                    result, messages, output, agent = await call_tool(
+                        tdefs, message.text, call, event, conversation
+                    )
+                # unwrap exception group
+                except Exception as ex:
+                    inner_ex = inner_exception(ex)
+                    raise inner_ex.with_traceback(inner_ex.__traceback__)
 
             except TimeoutError:
                 tool_error = ToolCallError(
@@ -227,7 +219,6 @@ async def execute_tools(
                 truncated=truncated,
                 view=call.view,
                 error=tool_error,
-                events=list(transcript().events),
                 agent=agent,
             )
 
@@ -270,7 +261,6 @@ async def execute_tools(
                internal=call.internal,
                pending=True,
            )
-            transcript()._event(event)
 
            # execute the tool call. if the operator cancels the
            # tool call then synthesize the appropriate message/event
@@ -280,7 +270,7 @@ async def execute_tools(
 
            result_exception = None
            async with anyio.create_task_group() as tg:
-                tg.start_soon(call_tool_task, call, messages, send_stream)
+                tg.start_soon(call_tool_task, call, event, messages, send_stream)
                event._set_cancel_fn(tg.cancel_scope.cancel)
                async with receive_stream:
                    (
@@ -306,7 +296,6 @@ async def execute_tools(
                    truncated=None,
                    view=call.view,
                    error=tool_message.error,
-                    events=[],
                )
                transcript().info(
                    f"Tool call '{call.function}' was cancelled by operator."
@@ -326,7 +315,6 @@ async def execute_tools(
                result=result_event.result,
                truncated=result_event.truncated,
                error=result_event.error,
-                events=result_event.events,
                waiting_time=waiting_time_end - waiting_time_start,
                agent=result_event.agent,
                failed=True if result_exception else None,
@@ -347,19 +335,34 @@ async def execute_tools(
 
 
 async def call_tool(
-    tools: list[ToolDef], message: str, call: ToolCall, conversation: list[ChatMessage]
+    tools: list[ToolDef],
+    message: str,
+    call: ToolCall,
+    event: BaseModel,
+    conversation: list[ChatMessage],
 ) -> tuple[ToolResult, list[ChatMessage], ModelOutput | None, str | None]:
     from inspect_ai.agent._handoff import AgentTool
-    from inspect_ai.log._transcript import SampleLimitEvent, transcript
+    from inspect_ai.log._transcript import SampleLimitEvent, ToolEvent, transcript
+
+    # dodge circular import
+    assert isinstance(event, ToolEvent)
+
+    # this function is responsible for transcript events so that it can
+    # put them in the right enclosure (e.g. handoff/agent/tool). This
+    # means that if we throw early we need to do the enclosure when raising.
+    async def record_tool_parsing_error(error: str) -> Exception:
+        async with span(name=call.function, type="tool"):
+            transcript()._event(event)
+        return ToolParsingError(error)
 
     # if there was an error parsing the ToolCall, raise that
     if call.parse_error:
-        raise ToolParsingError(call.parse_error)
+        raise await record_tool_parsing_error(call.parse_error)
 
     # find the tool
     tool_def = next((tool for tool in tools if tool.name == call.function), None)
     if tool_def is None:
-        raise ToolParsingError(f"Tool {call.function} not found")
+        raise await record_tool_parsing_error(f"Tool {call.function} not found")
 
     # if we have a tool approver, apply it now
     from inspect_ai.approval._apply import apply_tool_approval
@@ -382,7 +385,7 @@ async def call_tool(
     # validate the schema of the passed object
     validation_errors = validate_tool_input(call.arguments, tool_def.parameters)
     if validation_errors:
-        raise ToolParsingError(validation_errors)
+        raise await record_tool_parsing_error(validation_errors)
 
     # get arguments (with creation of dataclasses, pydantic objects, etc.)
     arguments = tool_params(call.arguments, tool_def.tool)
@@ -391,14 +394,18 @@ async def call_tool(
     with trace_action(
         logger, "Tool Call", format_function_call(tool_def.name, arguments, width=1000)
     ):
-        # agent tools get special handling
         if isinstance(tool_def.tool, AgentTool):
-            return await agent_handoff(tool_def, call, conversation)
+            async with span(tool_def.tool.name, type="handoff"):
+                async with span(name=call.function, type="tool"):
+                    transcript()._event(event)
+                return await agent_handoff(tool_def, call, conversation)
 
         # normal tool call
         else:
-            result: ToolResult = await tool_def.tool(**arguments)
-            return result, [], None, None
+            async with span(name=call.function, type="tool"):
+                transcript()._event(event)
+                result: ToolResult = await tool_def.tool(**arguments)
+                return result, [], None, None
 
 
 async def agent_handoff(
@@ -463,7 +470,8 @@ async def agent_handoff(
     agent_state = AgentState(messages=copy(agent_conversation))
     try:
         with apply_limits(agent_tool.limits):
-            agent_state = await agent_tool.agent(agent_state, **arguments)
+            async with span(name=agent_name, type="agent"):
+                agent_state = await agent_tool.agent(agent_state, **arguments)
     except LimitExceededError as ex:
         limit_error = ex
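
Net effect of the `_call_tools.py` changes: tool calls and agent handoffs no longer accumulate events in per-tool sub-transcripts; everything is recorded into the main transcript inside nested spans. An illustrative (not verbatim) outline of the resulting structure, in the format printed by `_print_event_tree()`; the agent and tool names are hypothetical:

```python
# illustrative only; "my_tool" / "my_agent" are hypothetical names
#
# span (tool): my_tool          <- normal tool call
#   tool                        <- the ToolEvent itself
#
# span (handoff): my_agent      <- agent handoff
#   span (tool): transfer_to_my_agent
#     tool
#   span (agent): my_agent
#     model
#     ...
```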