inspect-ai 0.3.93__py3-none-any.whl → 0.3.94__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57)
  1. inspect_ai/_display/textual/widgets/samples.py +3 -3
  2. inspect_ai/_display/textual/widgets/transcript.py +3 -29
  3. inspect_ai/_eval/task/run.py +10 -7
  4. inspect_ai/_util/answer.py +26 -0
  5. inspect_ai/_util/constants.py +0 -1
  6. inspect_ai/_util/local_server.py +51 -21
  7. inspect_ai/_view/www/dist/assets/index.css +14 -13
  8. inspect_ai/_view/www/dist/assets/index.js +400 -84
  9. inspect_ai/_view/www/log-schema.json +375 -0
  10. inspect_ai/_view/www/src/@types/log.d.ts +90 -12
  11. inspect_ai/_view/www/src/app/samples/transcript/SandboxEventView.module.css +2 -1
  12. inspect_ai/_view/www/src/app/samples/transcript/SpanEventView.tsx +174 -0
  13. inspect_ai/_view/www/src/app/samples/transcript/ToolEventView.tsx +8 -8
  14. inspect_ai/_view/www/src/app/samples/transcript/TranscriptView.tsx +12 -2
  15. inspect_ai/_view/www/src/app/samples/transcript/TranscriptVirtualListComponent.module.css +1 -1
  16. inspect_ai/_view/www/src/app/samples/transcript/event/EventPanel.tsx +0 -3
  17. inspect_ai/_view/www/src/app/samples/transcript/transform/fixups.ts +87 -25
  18. inspect_ai/_view/www/src/app/samples/transcript/transform/treeify.ts +229 -17
  19. inspect_ai/_view/www/src/app/samples/transcript/transform/utils.ts +11 -0
  20. inspect_ai/_view/www/src/app/samples/transcript/types.ts +5 -1
  21. inspect_ai/agent/_as_solver.py +3 -1
  22. inspect_ai/agent/_as_tool.py +6 -4
  23. inspect_ai/agent/_handoff.py +5 -1
  24. inspect_ai/agent/_react.py +4 -3
  25. inspect_ai/agent/_run.py +6 -1
  26. inspect_ai/agent/_types.py +9 -0
  27. inspect_ai/dataset/_dataset.py +6 -3
  28. inspect_ai/log/__init__.py +10 -0
  29. inspect_ai/log/_convert.py +4 -9
  30. inspect_ai/log/_samples.py +14 -17
  31. inspect_ai/log/_transcript.py +77 -35
  32. inspect_ai/log/_tree.py +118 -0
  33. inspect_ai/model/_call_tools.py +42 -34
  34. inspect_ai/model/_model.py +45 -40
  35. inspect_ai/model/_providers/hf.py +27 -1
  36. inspect_ai/model/_providers/sglang.py +8 -2
  37. inspect_ai/model/_providers/vllm.py +6 -2
  38. inspect_ai/scorer/_choice.py +1 -2
  39. inspect_ai/solver/_chain.py +1 -1
  40. inspect_ai/solver/_fork.py +1 -1
  41. inspect_ai/solver/_multiple_choice.py +5 -22
  42. inspect_ai/solver/_plan.py +2 -2
  43. inspect_ai/solver/_transcript.py +6 -7
  44. inspect_ai/tool/_mcp/_mcp.py +6 -5
  45. inspect_ai/tool/_tools/_execute.py +4 -1
  46. inspect_ai/util/__init__.py +4 -0
  47. inspect_ai/util/_anyio.py +11 -0
  48. inspect_ai/util/_collect.py +50 -0
  49. inspect_ai/util/_span.py +58 -0
  50. inspect_ai/util/_subtask.py +27 -42
  51. {inspect_ai-0.3.93.dist-info → inspect_ai-0.3.94.dist-info}/METADATA +1 -1
  52. {inspect_ai-0.3.93.dist-info → inspect_ai-0.3.94.dist-info}/RECORD +56 -51
  53. {inspect_ai-0.3.93.dist-info → inspect_ai-0.3.94.dist-info}/WHEEL +1 -1
  54. inspect_ai/_display/core/group.py +0 -79
  55. {inspect_ai-0.3.93.dist-info → inspect_ai-0.3.94.dist-info}/entry_points.txt +0 -0
  56. {inspect_ai-0.3.93.dist-info → inspect_ai-0.3.94.dist-info}/licenses/LICENSE +0 -0
  57. {inspect_ai-0.3.93.dist-info → inspect_ai-0.3.94.dist-info}/top_level.txt +0 -0
inspect_ai/model/_model.py CHANGED
@@ -19,6 +19,7 @@ from typing import (
     cast,
 )
 
+from pydantic import BaseModel
 from pydantic_core import to_jsonable_python
 from tenacity import (
     RetryCallState,
@@ -402,36 +403,32 @@ class Model:
         start_time = datetime.now()
         working_start = sample_working_time()
         async with self._connection_concurrency(config):
-            from inspect_ai.log._samples import track_active_sample_retries
-
             # generate
-            with track_active_sample_retries():
-                output = await self._generate(
-                    input=input,
-                    tools=tools,
-                    tool_choice=tool_choice,
-                    config=config,
-                    cache=cache,
-                )
+            output, event = await self._generate(
+                input=input,
+                tools=tools,
+                tool_choice=tool_choice,
+                config=config,
+                cache=cache,
+            )
 
             # update the most recent ModelEvent with the actual start/completed
             # times as well as a computation of working time (events are
             # created _after_ the call to _generate, potentially in response
             # to retries, so they need their timestamp updated so it accurately
             # reflects the full start/end time which we know here)
-            from inspect_ai.log._transcript import ModelEvent, transcript
-
-            last_model_event = transcript().find_last_event(ModelEvent)
-            if last_model_event:
-                last_model_event.timestamp = start_time
-                last_model_event.working_start = working_start
-                completed = datetime.now()
-                last_model_event.completed = completed
-                last_model_event.working_time = (
-                    output.time
-                    if output.time is not None
-                    else (completed - start_time).total_seconds()
-                )
+            from inspect_ai.log._transcript import ModelEvent
+
+            assert isinstance(event, ModelEvent)
+            event.timestamp = start_time
+            event.working_start = working_start
+            completed = datetime.now()
+            event.completed = completed
+            event.working_time = (
+                output.time
+                if output.time is not None
+                else (completed - start_time).total_seconds()
+            )
 
             # return output
             return output
@@ -492,9 +489,12 @@ class Model:
         tool_choice: ToolChoice | None,
         config: GenerateConfig,
         cache: bool | CachePolicy = False,
-    ) -> ModelOutput:
+    ) -> tuple[ModelOutput, BaseModel]:
+        from inspect_ai.log._samples import track_active_model_event
+        from inspect_ai.log._transcript import ModelEvent
+
         # default to 'auto' for tool_choice (same as underlying model apis)
-        tool_choice = tool_choice if tool_choice else "auto"
+        tool_choice = tool_choice if tool_choice is not None else "auto"
 
         # resolve top level tool source
         if isinstance(tools, ToolSource):
@@ -581,7 +581,10 @@ class Model:
             stop=stop,
             before_sleep=functools.partial(log_model_retry, self.api.model_name),
         )
-        async def generate() -> ModelOutput:
+        async def generate() -> tuple[ModelOutput, BaseModel]:
+            # type-checker can't see that we made sure tool_choice is not none in the outer frame
+            assert tool_choice is not None
+
             check_sample_interrupt()
 
             cache_entry: CacheEntry | None
@@ -602,7 +605,7 @@
                 )
                 existing = cache_fetch(cache_entry)
                 if isinstance(existing, ModelOutput):
-                    self._record_model_interaction(
+                    _, event = self._record_model_interaction(
                        input=input,
                        tools=tools_info,
                        tool_choice=tool_choice,
@@ -611,7 +614,7 @@
                        output=existing,
                        call=None,
                    )
-                    return existing
+                    return existing, event
                 else:
                    cache_entry = None
 
@@ -620,7 +623,7 @@
 
            # record the interaction before the call to generate
            # (we'll update it with the results once we have them)
-           complete = self._record_model_interaction(
+           complete, event = self._record_model_interaction(
                input=input,
                tools=tools_info,
                tool_choice=tool_choice,
@@ -631,12 +634,14 @@
            with trace_action(logger, "Model", f"generate ({str(self)})"):
                time_start = time.monotonic()
                try:
-                   result = await self.api.generate(
-                       input=input,
-                       tools=tools_info,
-                       tool_choice=tool_choice,
-                       config=config,
-                   )
+                   assert isinstance(event, ModelEvent)
+                   with track_active_model_event(event):
+                       result = await self.api.generate(
+                           input=input,
+                           tools=tools_info,
+                           tool_choice=tool_choice,
+                           config=config,
+                       )
                finally:
                    time_elapsed = time.monotonic() - time_start
 
@@ -686,18 +691,18 @@
            if cache and cache_entry:
                cache_store(entry=cache_entry, output=output)
 
-           return output
+           return output, event
 
        # call the model (this will so retries, etc., so report waiting time
        # as elapsed time - actual time for successful model call)
        time_start = time.monotonic()
-       model_output = await generate()
+       model_output, event = await generate()
        total_time = time.monotonic() - time_start
        if model_output.time:
            report_sample_waiting_time(total_time - model_output.time)
 
        # return results
-       return model_output
+       return model_output, event
 
    def should_retry(self, ex: BaseException) -> bool:
        if isinstance(ex, Exception):
@@ -769,7 +774,7 @@
        cache: Literal["read", "write"] | None,
        output: ModelOutput | None = None,
        call: ModelCall | None = None,
-   ) -> Callable[[ModelOutput | Exception, ModelCall | None], None]:
+   ) -> tuple[Callable[[ModelOutput | Exception, ModelCall | None], None], BaseModel]:
        from inspect_ai.log._transcript import ModelEvent, transcript
 
        # create event and add it to the transcript
@@ -809,7 +814,7 @@
        if output:
            complete(output, call)
 
-       return complete
+       return complete, event
 
 
 class ModelName:
inspect_ai/model/_providers/hf.py CHANGED
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import concurrent
 import concurrent.futures
 import copy
@@ -26,7 +28,12 @@ from transformers import ( # type: ignore
 from typing_extensions import override
 
 from inspect_ai._util.constants import DEFAULT_MAX_TOKENS
-from inspect_ai._util.content import ContentText
+from inspect_ai._util.content import (
+    ContentAudio,
+    ContentImage,
+    ContentText,
+    ContentVideo,
+)
 from inspect_ai._util.trace import trace_action
 from inspect_ai.tool import ToolChoice, ToolInfo
 
@@ -85,6 +92,7 @@ class HuggingFaceAPI(ModelAPI):
         self.batch_size = collect_model_arg("batch_size")
         self.chat_template = collect_model_arg("chat_template")
         self.tokenizer_call_args = collect_model_arg("tokenizer_call_args")
+        self.enable_thinking = collect_model_arg("enable_thinking")
         if self.tokenizer_call_args is None:
             self.tokenizer_call_args = {}
 
@@ -263,6 +271,7 @@ class HuggingFaceAPI(ModelAPI):
     elif "qwen" in self.model_name.lower():
         hf_messages = inspect_tools_to_string(hf_messages)
 
+    hf_messages = message_content_to_string(hf_messages)
     # apply chat template
     if self.tokenizer.chat_template is not None:
         chat = self.tokenizer.apply_chat_template(
@@ -270,6 +279,7 @@
             add_generation_prompt=True,
             tokenize=False,
             tools=tools_list if len(tools_list) > 0 else None,
+            enable_thinking=self.enable_thinking,  # not all models use this, check if it is supported
         )
     else:
         chat = ""
@@ -279,6 +289,22 @@
     return cast(str, chat)
 
 
+def message_content_to_string(messages: list[ChatMessage]) -> list[ChatMessage]:
+    """Convert list of content in `ChatMessageAssistant`, `ChatMessageUser` or `ChatMessageSystem` to a string."""
+    for message in messages:
+        if isinstance(message.content, list):
+            is_multimodal = any(
+                isinstance(item, ContentAudio | ContentImage | ContentVideo)
+                for item in message.content
+            )
+            if is_multimodal:
+                raise NotImplementedError(
+                    "HuggingFace provider does not support multimodal content, please provide text inputs only."
+                )
+            message.content = message.text
+    return messages
+
+
 def shorten_tool_id(messages: list[ChatMessage]) -> list[ChatMessage]:
     """Shorten the tool_call_id in the messages to the last 9 characters for Mistral."""
     for i, message in enumerate(messages):
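Usage note (not part of the diff): the new `enable_thinking` option is read via `collect_model_arg()` and forwarded to `apply_chat_template()`, so it can be passed as a model argument. A minimal sketch, assuming the standard `get_model()` entry point; the model name is illustrative and the flag is only honored by chat templates that support it:

    from inspect_ai.model import get_model

    # hypothetical example: forward enable_thinking to the HuggingFace provider
    model = get_model("hf/Qwen/Qwen3-4B", enable_thinking=True)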
inspect_ai/model/_providers/sglang.py CHANGED
@@ -71,6 +71,7 @@ class SGLangAPI(OpenAICompatibleAPI):
             SGLANG_DEFAULT_SERVER_ARGS, server_args, logger
         )
 
+        self.server_found = True
         try:
             # Try to initialize with existing server
             super().__init__(
@@ -83,7 +84,9 @@
             )
             logger.info(f"Using existing SGLang server at {self.base_url}")
         except PrerequisiteError:
-            # No existing server found, start a new one
+            self.server_found = False
+
+        if not self.server_found:
             logger.warning(
                 f"Existing SGLang server not found. Starting new server for {model_name}."
             )
@@ -125,7 +128,9 @@
             api_key = "inspectai"  # Create a default API key if not provided
 
             # Handle device configuration
-            self.server_args = configure_devices(self.server_args, parallel_size_param="tp")
+            self.server_args, env_vars = configure_devices(
+                self.server_args, parallel_size_param="tp"
+            )
 
             timeout = self.server_args.pop("timeout", None)
             host = self.server_args.pop("host", "0.0.0.0")
@@ -149,6 +154,7 @@
                 server_type="SGLang",
                 timeout=timeout,
                 server_args=self.server_args,
+                env=env_vars,
             )
 
             # Register cleanup function to run when Python exits
inspect_ai/model/_providers/vllm.py CHANGED
@@ -76,6 +76,7 @@ class VLLMAPI(OpenAICompatibleAPI):
             VLLM_DEFAULT_SERVER_ARGS, server_args, logger
         )
 
+        self.server_found = True
         try:
             # Try to initialize with existing server
             super().__init__(
@@ -88,7 +89,9 @@
             )
             logger.info(f"Using existing vLLM server at {self.base_url}")
         except PrerequisiteError:
-            # No existing server found, start a new one
+            self.server_found = False
+
+        if not self.server_found:
             logger.warning(
                 f"Existing vLLM server not found. Starting new server for {model_name}."
             )
@@ -131,7 +134,7 @@
                 raise pip_dependency_error("vLLM Server", ["vllm"])
 
             # Handle device configuration
-            self.server_args = configure_devices(
+            self.server_args, env_vars = configure_devices(
                 self.server_args, parallel_size_param="tensor_parallel_size"
             )
 
@@ -152,6 +155,7 @@
                 server_type="vLLM",
                 timeout=timeout,
                 server_args=self.server_args,
+                env=env_vars,
             )
 
             # Register cleanup function to run when Python exits
inspect_ai/scorer/_choice.py CHANGED
@@ -1,6 +1,5 @@
+from inspect_ai._util.answer import answer_character, answer_index
 from inspect_ai.solver._multiple_choice import (
-    answer_character,
-    answer_index,
     answer_options,
     unshuffle_choices,
 )
inspect_ai/solver/_chain.py CHANGED
@@ -82,7 +82,7 @@ class Chain(Sequence[Solver], Solver):
         from ._transcript import solver_transcript
 
         for slv in self._solvers:
-            with solver_transcript(slv, state) as st:
+            async with solver_transcript(slv, state) as st:
                 state = await slv(state, generate)
                 st.complete(state)
             if state.completed:
inspect_ai/solver/_fork.py CHANGED
@@ -73,7 +73,7 @@ async def solver_subtask(state: TaskState, solver: Solver) -> TaskState:
     @subtask(name=name, store=state.store, type="fork", input=input)  # type: ignore
     async def solve() -> TaskState:
         if not isinstance(solver, Chain):
-            with solver_transcript(solver, state) as st:
+            async with solver_transcript(solver, state) as st:
                 new_state = await solver(state, generate)
                 st.complete(new_state)
             return new_state
inspect_ai/solver/_multiple_choice.py CHANGED
@@ -6,6 +6,7 @@ from typing import Match, TypedDict
 
 from typing_extensions import Unpack
 
+from inspect_ai._util.answer import answer_character, answer_index
 from inspect_ai._util.logger import warn_once
 from inspect_ai.util import resource
 
@@ -64,31 +65,13 @@ def answer_options(choices: Choices) -> str:
     indexes = list(range(len(choices)))
 
     return "\n".join(
-        [f"{chr(65 + i)}) {choices[j].value}" for i, j in enumerate(indexes)]
+        [f"{answer_character(i)}) {choices[j].value}" for i, j in enumerate(indexes)]
     )
 
 
-def answer_character(index: int) -> str:
-    r"""
-    Helper to go from array index to char, for example:
-
-        0 -> 'A', 1 -> 'B', etc
-    """
-    return chr(ord("A") + index)
-
-
-def answer_index(char: str) -> int:
-    r"""
-    Helper to go from char to array index, for example:
-
-        'A' -> 0, 'B' -> 1, etc
-    """
-    return ord(char.upper()) - ord("A")
-
-
 def prompt(question: str, choices: Choices, template: str) -> str:
     choices_text = answer_options(choices)
-    letters = ",".join(chr(65 + i) for i in range(len(choices)))
+    letters = ",".join(answer_character(i) for i in range(len(choices)))
 
     return template.format(
         choices=choices_text,
@@ -112,7 +95,7 @@ def parse_answers(state: TaskState) -> Match[str] | None:
     # In this case, we're looking for a single line which contains the expected
     # ANSWER: B,C string with only whitespace after it
     match = re.search(
-        r"(?i)^ANSWER\s*:\s*([A-Za-z ,]+)\s*(?:$|\n)",
+        r"(?i)^ANSWER\s*:\s*([A-Za-z\d ,]+)\s*(?:$|\n)",
         state.output.completion,
         flags=re.MULTILINE,
     )
@@ -121,7 +104,7 @@
     # version for backward compatibility
     if match is None:
         return re.search(
-            r"(?i)ANSWER\s*:\s*([A-Za-z ,]+)(?:[^\w]|\n|$)", state.output.completion
+            r"(?i)ANSWER\s*:\s*([A-Za-z\d ,]+)(?:[^\w]|\n|$)", state.output.completion
         )
     else:
         return match
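For reference (not part of the diff): widening the character class to `[A-Za-z\d ,]` means numeric answer labels now parse as well as letters. A quick standalone check using the pattern copied from the first `parse_answers()` hunk above:

    import re

    # pattern copied verbatim from the updated parse_answers()
    pattern = r"(?i)^ANSWER\s*:\s*([A-Za-z\d ,]+)\s*(?:$|\n)"

    match = re.search(pattern, "ANSWER: 1,3", flags=re.MULTILINE)
    print(match.group(1) if match else None)  # -> "1,3"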
inspect_ai/solver/_plan.py CHANGED
@@ -102,7 +102,7 @@ class Plan(Solver):
         # execute steps
         for index, solver in enumerate(self.steps):
             # run solver
-            with solver_transcript(solver, state) as st:
+            async with solver_transcript(solver, state) as st:
                 state = await solver(state, generate)
                 st.complete(state)
 
@@ -113,7 +113,7 @@
 
         # execute finish
         if self.finish:
-            with solver_transcript(self.finish, state) as st:
+            async with solver_transcript(self.finish, state) as st:
                 state = await self.finish(state, generate)
                 st.complete(state)
             check_sample_interrupt()
inspect_ai/solver/_transcript.py CHANGED
@@ -1,8 +1,9 @@
 import contextlib
-from typing import Iterator
+from typing import AsyncIterator
 
 from inspect_ai._util.json import json_changes
 from inspect_ai._util.registry import registry_log_name
+from inspect_ai.util._span import span
 
 from ._solver import Solver
 from ._task_state import TaskState, state_jsonable
@@ -22,12 +23,10 @@ class SolverTranscript:
         transcript()._event(StateEvent(changes=changes))
 
 
-@contextlib.contextmanager
-def solver_transcript(
+@contextlib.asynccontextmanager
+async def solver_transcript(
     solver: Solver, state: TaskState, name: str | None = None
-) -> Iterator[SolverTranscript]:
-    from inspect_ai.log._transcript import transcript
-
+) -> AsyncIterator[SolverTranscript]:
     name = registry_log_name(name or solver)
-    with transcript().step(name=name, type="solver"):
+    async with span(name=name, type="solver"):
         yield SolverTranscript(name, state)
inspect_ai/tool/_mcp/_mcp.py CHANGED
@@ -61,16 +61,17 @@ class MCPServerImpl(MCPServer):
     ) -> list[Tool]:
         return await self._task_session()._list_tools(tools)
 
-    # create a separate MCPServer session per async task
-    _task_sessions: dict[int, "MCPServerSession"] = {}
+    # create a separate MCPServer session per async task / server name
+    _task_sessions: dict[str, "MCPServerSession"] = {}
 
     def _task_session(self) -> "MCPServerSession":
         task_id = anyio.get_current_task().id
-        if task_id not in self._task_sessions:
-            MCPServerImpl._task_sessions[task_id] = MCPServerSession(
+        session_key = f"{task_id}_{self._name}"
+        if session_key not in self._task_sessions:
+            MCPServerImpl._task_sessions[session_key] = MCPServerSession(
                 self._client, name=self._name, events=self._events
             )
-        return MCPServerImpl._task_sessions[task_id]
+        return MCPServerImpl._task_sessions[session_key]
 
 
 class MCPServerSession(MCPServer):
inspect_ai/tool/_tools/_execute.py CHANGED
@@ -96,7 +96,10 @@ def python(
           The output of the Python code.
         """
         result = await sandbox_env(sandbox).exec(
-            cmd=["python3"], input=code, timeout=timeout, user=user
+            cmd=["bash", "--login", "-c", "python3 -"],
+            input=code,
+            timeout=timeout,
+            user=user,
         )
         # return output (including stderr if any)
         output = ""
inspect_ai/util/__init__.py CHANGED
@@ -8,6 +8,7 @@ from inspect_ai.util._limit import (
     token_limit,
 )
 
+from ._collect import collect
 from ._concurrency import concurrency
 from ._console import input_screen
 from ._display import DisplayType, display_counter, display_type
@@ -28,6 +29,7 @@ from ._sandbox import (
     sandbox_with,
     sandboxenv,
 )
+from ._span import span
 from ._store import Store, store
 from ._store_model import StoreModel, store_as
 from ._subprocess import (
@@ -71,6 +73,8 @@ __all__ = [
     "store",
     "StoreModel",
     "store_as",
+    "span",
+    "collect",
     "Subtask",
     "subtask",
     "throttle",
inspect_ai/util/_anyio.py CHANGED
@@ -1,6 +1,10 @@
 import itertools
 import sys
 
+import anyio
+
+from inspect_ai._util._async import current_async_backend
+
 if sys.version_info < (3, 11):
     from exceptiongroup import ExceptionGroup
 
@@ -36,3 +40,10 @@ def _flatten_exception(exc: Exception) -> list[Exception]:
     ]
 
     return maybe_this_exception + other_exceptions
+
+
+def safe_current_task_id() -> int | None:
+    if current_async_backend() is not None:
+        return anyio.get_current_task().id
+    else:
+        return None
inspect_ai/util/_collect.py ADDED
@@ -0,0 +1,50 @@
+import sys
+from typing import Awaitable, TypeVar, cast
+
+import anyio
+
+from ._span import span
+
+if sys.version_info < (3, 11):
+    from exceptiongroup import ExceptionGroup
+
+
+T = TypeVar("T")
+
+
+async def collect(*tasks: Awaitable[T]) -> list[T]:
+    """Run and collect the results of one or more async coroutines.
+
+    Similar to [`asyncio.gather()`](https://docs.python.org/3/library/asyncio-task.html#asyncio.gather),
+    but also works when [Trio](https://trio.readthedocs.io/en/stable/) is the async backend.
+
+    Automatically includes each task in a `span()`, which
+    ensures that its events are grouped together in the transcript.
+
+    Using `collect()` in preference to `asyncio.gather()` is highly recommended
+    for both Trio compatibility and more legible transcript output.
+
+    Args:
+        *tasks: Tasks to run
+
+    Returns:
+        List of task results.
+    """
+    results: list[None | T] = [None] * len(tasks)
+
+    try:
+        async with anyio.create_task_group() as tg:
+
+            async def run_task(index: int, task: Awaitable[T]) -> None:
+                async with span(f"task-{index + 1}", type="task"):
+                    results[index] = await task
+
+            for i, task in enumerate(tasks):
+                tg.start_soon(run_task, i, task)
+    except ExceptionGroup as ex:
+        if len(ex.exceptions) == 1:
+            raise ex.exceptions[0] from None
+        else:
+            raise
+
+    return cast(list[T], results)
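Usage sketch for the new `collect()` helper (not part of the diff); it is exported from `inspect_ai.util` per the `__init__.py` changes above, and the model name and prompts are illustrative:

    from inspect_ai.model import get_model
    from inspect_ai.util import collect

    async def compare_prompts() -> None:
        model = get_model("openai/gpt-4o-mini")
        # run both generations concurrently; each awaitable is wrapped in its
        # own span, so its events are grouped together in the transcript
        output_a, output_b = await collect(
            model.generate("Summarize special relativity in one sentence."),
            model.generate("Summarize quantum mechanics in one sentence."),
        )
        print(output_a.completion)
        print(output_b.completion)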
inspect_ai/util/_span.py ADDED
@@ -0,0 +1,58 @@
+import contextlib
+from contextvars import ContextVar
+from typing import AsyncIterator
+from uuid import uuid4
+
+
+@contextlib.asynccontextmanager
+async def span(name: str, *, type: str | None = None) -> AsyncIterator[None]:
+    """Context manager for establishing a transcript span.
+
+    Args:
+        name (str): Step name.
+        type (str | None): Optional span type.
+    """
+    from inspect_ai.log._transcript import (
+        SpanBeginEvent,
+        SpanEndEvent,
+        track_store_changes,
+        transcript,
+    )
+
+    # span id
+    id = uuid4().hex
+
+    # capture parent id
+    parent_id = _current_span_id.get()
+
+    # set new current span (reset at the end)
+    token = _current_span_id.set(id)
+
+    # run the span
+    try:
+        # span begin event
+        transcript()._event(
+            SpanBeginEvent(
+                id=id,
+                parent_id=parent_id,
+                type=type,
+                name=name,
+            )
+        )
+
+        # run span w/ store change events
+        with track_store_changes():
+            yield
+
+    finally:
+        # send end event
+        transcript()._event(SpanEndEvent(id=id))
+
+        _current_span_id.reset(token)
+
+
+def current_span_id() -> str | None:
+    return _current_span_id.get()
+
+
+_current_span_id: ContextVar[str | None] = ContextVar("_current_span_id", default=None)
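Usage sketch for `span()` itself (not part of the diff); it is also exported from `inspect_ai.util`, and the span name, type, and model are illustrative:

    from inspect_ai.model import get_model
    from inspect_ai.util import span

    async def research_step(question: str) -> str:
        model = get_model("openai/gpt-4o-mini")
        # events emitted inside the block (model events, store changes) are
        # grouped under a "research" span in the transcript
        async with span("research", type="agent"):
            output = await model.generate(question)
            return output.completion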