inspect-ai 0.3.93__py3-none-any.whl → 0.3.94__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- inspect_ai/_display/textual/widgets/samples.py +3 -3
- inspect_ai/_display/textual/widgets/transcript.py +3 -29
- inspect_ai/_eval/task/run.py +10 -7
- inspect_ai/_util/answer.py +26 -0
- inspect_ai/_util/constants.py +0 -1
- inspect_ai/_util/local_server.py +51 -21
- inspect_ai/_view/www/dist/assets/index.css +14 -13
- inspect_ai/_view/www/dist/assets/index.js +400 -84
- inspect_ai/_view/www/log-schema.json +375 -0
- inspect_ai/_view/www/src/@types/log.d.ts +90 -12
- inspect_ai/_view/www/src/app/samples/transcript/SandboxEventView.module.css +2 -1
- inspect_ai/_view/www/src/app/samples/transcript/SpanEventView.tsx +174 -0
- inspect_ai/_view/www/src/app/samples/transcript/ToolEventView.tsx +8 -8
- inspect_ai/_view/www/src/app/samples/transcript/TranscriptView.tsx +12 -2
- inspect_ai/_view/www/src/app/samples/transcript/TranscriptVirtualListComponent.module.css +1 -1
- inspect_ai/_view/www/src/app/samples/transcript/event/EventPanel.tsx +0 -3
- inspect_ai/_view/www/src/app/samples/transcript/transform/fixups.ts +87 -25
- inspect_ai/_view/www/src/app/samples/transcript/transform/treeify.ts +229 -17
- inspect_ai/_view/www/src/app/samples/transcript/transform/utils.ts +11 -0
- inspect_ai/_view/www/src/app/samples/transcript/types.ts +5 -1
- inspect_ai/agent/_as_solver.py +3 -1
- inspect_ai/agent/_as_tool.py +6 -4
- inspect_ai/agent/_handoff.py +5 -1
- inspect_ai/agent/_react.py +4 -3
- inspect_ai/agent/_run.py +6 -1
- inspect_ai/agent/_types.py +9 -0
- inspect_ai/dataset/_dataset.py +6 -3
- inspect_ai/log/__init__.py +10 -0
- inspect_ai/log/_convert.py +4 -9
- inspect_ai/log/_samples.py +14 -17
- inspect_ai/log/_transcript.py +77 -35
- inspect_ai/log/_tree.py +118 -0
- inspect_ai/model/_call_tools.py +42 -34
- inspect_ai/model/_model.py +45 -40
- inspect_ai/model/_providers/hf.py +27 -1
- inspect_ai/model/_providers/sglang.py +8 -2
- inspect_ai/model/_providers/vllm.py +6 -2
- inspect_ai/scorer/_choice.py +1 -2
- inspect_ai/solver/_chain.py +1 -1
- inspect_ai/solver/_fork.py +1 -1
- inspect_ai/solver/_multiple_choice.py +5 -22
- inspect_ai/solver/_plan.py +2 -2
- inspect_ai/solver/_transcript.py +6 -7
- inspect_ai/tool/_mcp/_mcp.py +6 -5
- inspect_ai/tool/_tools/_execute.py +4 -1
- inspect_ai/util/__init__.py +4 -0
- inspect_ai/util/_anyio.py +11 -0
- inspect_ai/util/_collect.py +50 -0
- inspect_ai/util/_span.py +58 -0
- inspect_ai/util/_subtask.py +27 -42
- {inspect_ai-0.3.93.dist-info → inspect_ai-0.3.94.dist-info}/METADATA +1 -1
- {inspect_ai-0.3.93.dist-info → inspect_ai-0.3.94.dist-info}/RECORD +56 -51
- {inspect_ai-0.3.93.dist-info → inspect_ai-0.3.94.dist-info}/WHEEL +1 -1
- inspect_ai/_display/core/group.py +0 -79
- {inspect_ai-0.3.93.dist-info → inspect_ai-0.3.94.dist-info}/entry_points.txt +0 -0
- {inspect_ai-0.3.93.dist-info → inspect_ai-0.3.94.dist-info}/licenses/LICENSE +0 -0
- {inspect_ai-0.3.93.dist-info → inspect_ai-0.3.94.dist-info}/top_level.txt +0 -0
inspect_ai/model/_model.py
CHANGED
@@ -19,6 +19,7 @@ from typing import (
     cast,
 )
 
+from pydantic import BaseModel
 from pydantic_core import to_jsonable_python
 from tenacity import (
     RetryCallState,
@@ -402,36 +403,32 @@ class Model:
         start_time = datetime.now()
         working_start = sample_working_time()
         async with self._connection_concurrency(config):
-            from inspect_ai.log._samples import track_active_sample_retries
-
             # generate
-
-
-
-
-
-
-
-            )
+            output, event = await self._generate(
+                input=input,
+                tools=tools,
+                tool_choice=tool_choice,
+                config=config,
+                cache=cache,
+            )
 
             # update the most recent ModelEvent with the actual start/completed
             # times as well as a computation of working time (events are
             # created _after_ the call to _generate, potentially in response
             # to retries, so they need their timestamp updated so it accurately
             # reflects the full start/end time which we know here)
-            from inspect_ai.log._transcript import ModelEvent
-
-
-
-
-
-
-
-
-
-
-            )
+            from inspect_ai.log._transcript import ModelEvent
+
+            assert isinstance(event, ModelEvent)
+            event.timestamp = start_time
+            event.working_start = working_start
+            completed = datetime.now()
+            event.completed = completed
+            event.working_time = (
+                output.time
+                if output.time is not None
+                else (completed - start_time).total_seconds()
+            )
 
             # return output
             return output
@@ -492,9 +489,12 @@ class Model:
         tool_choice: ToolChoice | None,
         config: GenerateConfig,
         cache: bool | CachePolicy = False,
-    ) -> ModelOutput:
+    ) -> tuple[ModelOutput, BaseModel]:
+        from inspect_ai.log._samples import track_active_model_event
+        from inspect_ai.log._transcript import ModelEvent
+
         # default to 'auto' for tool_choice (same as underlying model apis)
-        tool_choice = tool_choice if tool_choice else "auto"
+        tool_choice = tool_choice if tool_choice is not None else "auto"
 
         # resolve top level tool source
         if isinstance(tools, ToolSource):
@@ -581,7 +581,10 @@
             stop=stop,
             before_sleep=functools.partial(log_model_retry, self.api.model_name),
         )
-        async def generate() -> ModelOutput:
+        async def generate() -> tuple[ModelOutput, BaseModel]:
+            # type-checker can't see that we made sure tool_choice is not none in the outer frame
+            assert tool_choice is not None
+
             check_sample_interrupt()
 
             cache_entry: CacheEntry | None
@@ -602,7 +605,7 @@
                 )
                 existing = cache_fetch(cache_entry)
                 if isinstance(existing, ModelOutput):
-                    self._record_model_interaction(
+                    _, event = self._record_model_interaction(
                         input=input,
                         tools=tools_info,
                         tool_choice=tool_choice,
@@ -611,7 +614,7 @@
                         output=existing,
                         call=None,
                     )
-                    return existing
+                    return existing, event
                 else:
                     cache_entry = None
 
@@ -620,7 +623,7 @@
 
             # record the interaction before the call to generate
             # (we'll update it with the results once we have them)
-            complete = self._record_model_interaction(
+            complete, event = self._record_model_interaction(
                 input=input,
                 tools=tools_info,
                 tool_choice=tool_choice,
@@ -631,12 +634,14 @@
             with trace_action(logger, "Model", f"generate ({str(self)})"):
                 time_start = time.monotonic()
                 try:
-
-
-
-
-
-
+                    assert isinstance(event, ModelEvent)
+                    with track_active_model_event(event):
+                        result = await self.api.generate(
+                            input=input,
+                            tools=tools_info,
+                            tool_choice=tool_choice,
+                            config=config,
+                        )
                 finally:
                     time_elapsed = time.monotonic() - time_start
 
@@ -686,18 +691,18 @@
             if cache and cache_entry:
                 cache_store(entry=cache_entry, output=output)
 
-            return output
+            return output, event
 
         # call the model (this will so retries, etc., so report waiting time
        # as elapsed time - actual time for successful model call)
         time_start = time.monotonic()
-        model_output = await generate()
+        model_output, event = await generate()
         total_time = time.monotonic() - time_start
         if model_output.time:
             report_sample_waiting_time(total_time - model_output.time)
 
         # return results
-        return model_output
+        return model_output, event
 
     def should_retry(self, ex: BaseException) -> bool:
         if isinstance(ex, Exception):
@@ -769,7 +774,7 @@
         cache: Literal["read", "write"] | None,
         output: ModelOutput | None = None,
         call: ModelCall | None = None,
-    ) -> Callable[[ModelOutput | Exception, ModelCall | None], None]:
+    ) -> tuple[Callable[[ModelOutput | Exception, ModelCall | None], None], BaseModel]:
         from inspect_ai.log._transcript import ModelEvent, transcript
 
         # create event and add it to the transcript
@@ -809,7 +814,7 @@
         if output:
             complete(output, call)
 
-        return complete
+        return complete, event
 
 
 class ModelName:
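The event patched above follows a simple working-time rule: prefer the generation time reported on the ModelOutput, and fall back to wall-clock elapsed time otherwise. A standalone restatement of that rule (illustrative only, not the library's code):

from datetime import datetime

def resolve_working_time(start: datetime, completed: datetime, reported: float | None) -> float:
    # prefer the provider-reported generation time; otherwise use wall-clock elapsed time
    return reported if reported is not None else (completed - start).total_seconds()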
inspect_ai/model/_providers/hf.py
CHANGED
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import concurrent
 import concurrent.futures
 import copy
@@ -26,7 +28,12 @@ from transformers import ( # type: ignore
 from typing_extensions import override
 
 from inspect_ai._util.constants import DEFAULT_MAX_TOKENS
-from inspect_ai._util.content import
+from inspect_ai._util.content import (
+    ContentAudio,
+    ContentImage,
+    ContentText,
+    ContentVideo,
+)
 from inspect_ai._util.trace import trace_action
 from inspect_ai.tool import ToolChoice, ToolInfo
 
@@ -85,6 +92,7 @@ class HuggingFaceAPI(ModelAPI):
         self.batch_size = collect_model_arg("batch_size")
         self.chat_template = collect_model_arg("chat_template")
         self.tokenizer_call_args = collect_model_arg("tokenizer_call_args")
+        self.enable_thinking = collect_model_arg("enable_thinking")
         if self.tokenizer_call_args is None:
             self.tokenizer_call_args = {}
 
@@ -263,6 +271,7 @@ class HuggingFaceAPI(ModelAPI):
         elif "qwen" in self.model_name.lower():
             hf_messages = inspect_tools_to_string(hf_messages)
 
+        hf_messages = message_content_to_string(hf_messages)
         # apply chat template
         if self.tokenizer.chat_template is not None:
             chat = self.tokenizer.apply_chat_template(
@@ -270,6 +279,7 @@ class HuggingFaceAPI(ModelAPI):
                 add_generation_prompt=True,
                 tokenize=False,
                 tools=tools_list if len(tools_list) > 0 else None,
+                enable_thinking=self.enable_thinking,  # not all models use this, check if it is supported
             )
         else:
             chat = ""
@@ -279,6 +289,22 @@ class HuggingFaceAPI(ModelAPI):
         return cast(str, chat)
 
 
+def message_content_to_string(messages: list[ChatMessage]) -> list[ChatMessage]:
+    """Convert list of content in `ChatMessageAssistant`, `ChatMessageUser` or `ChatMessageSystem` to a string."""
+    for message in messages:
+        if isinstance(message.content, list):
+            is_multimodal = any(
+                isinstance(item, ContentAudio | ContentImage | ContentVideo)
+                for item in message.content
+            )
+            if is_multimodal:
+                raise NotImplementedError(
+                    "HuggingFace provider does not support multimodal content, please provide text inputs only."
+                )
+            message.content = message.text
+    return messages
+
+
 def shorten_tool_id(messages: list[ChatMessage]) -> list[ChatMessage]:
     """Shorten the tool_call_id in the messages to the last 9 characters for Mistral."""
     for i, message in enumerate(messages):
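Because enable_thinking is read with collect_model_arg, it can be supplied like any other HuggingFace model argument. A hedged sketch (the model name is a placeholder, and whether the flag has any effect depends on the model's chat template):

from inspect_ai.model import get_model

# assumption: extra keyword arguments to get_model() are forwarded to the provider as model args
model = get_model("hf/Qwen/Qwen3-0.6B", enable_thinking=False)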
inspect_ai/model/_providers/sglang.py
CHANGED
@@ -71,6 +71,7 @@ class SGLangAPI(OpenAICompatibleAPI):
             SGLANG_DEFAULT_SERVER_ARGS, server_args, logger
         )
 
+        self.server_found = True
         try:
             # Try to initialize with existing server
             super().__init__(
@@ -83,7 +84,9 @@
             )
             logger.info(f"Using existing SGLang server at {self.base_url}")
         except PrerequisiteError:
-
+            self.server_found = False
+
+        if not self.server_found:
             logger.warning(
                 f"Existing SGLang server not found. Starting new server for {model_name}."
             )
@@ -125,7 +128,9 @@
             api_key = "inspectai"  # Create a default API key if not provided
 
         # Handle device configuration
-        self.server_args = configure_devices(
+        self.server_args, env_vars = configure_devices(
+            self.server_args, parallel_size_param="tp"
+        )
 
         timeout = self.server_args.pop("timeout", None)
         host = self.server_args.pop("host", "0.0.0.0")
@@ -149,6 +154,7 @@
             server_type="SGLang",
             timeout=timeout,
             server_args=self.server_args,
+            env=env_vars,
         )
 
         # Register cleanup function to run when Python exits
inspect_ai/model/_providers/vllm.py
CHANGED
@@ -76,6 +76,7 @@ class VLLMAPI(OpenAICompatibleAPI):
             VLLM_DEFAULT_SERVER_ARGS, server_args, logger
         )
 
+        self.server_found = True
         try:
             # Try to initialize with existing server
             super().__init__(
@@ -88,7 +89,9 @@
             )
             logger.info(f"Using existing vLLM server at {self.base_url}")
         except PrerequisiteError:
-
+            self.server_found = False
+
+        if not self.server_found:
             logger.warning(
                 f"Existing vLLM server not found. Starting new server for {model_name}."
             )
@@ -131,7 +134,7 @@
             raise pip_dependency_error("vLLM Server", ["vllm"])
 
         # Handle device configuration
-        self.server_args = configure_devices(
+        self.server_args, env_vars = configure_devices(
             self.server_args, parallel_size_param="tensor_parallel_size"
         )
 
@@ -152,6 +155,7 @@
             server_type="vLLM",
             timeout=timeout,
             server_args=self.server_args,
+            env=env_vars,
         )
 
         # Register cleanup function to run when Python exits
inspect_ai/scorer/_choice.py
CHANGED
inspect_ai/solver/_chain.py
CHANGED
@@ -82,7 +82,7 @@ class Chain(Sequence[Solver], Solver):
         from ._transcript import solver_transcript
 
         for slv in self._solvers:
-            with solver_transcript(slv, state) as st:
+            async with solver_transcript(slv, state) as st:
                 state = await slv(state, generate)
                 st.complete(state)
             if state.completed:
inspect_ai/solver/_fork.py
CHANGED
@@ -73,7 +73,7 @@ async def solver_subtask(state: TaskState, solver: Solver) -> TaskState:
     @subtask(name=name, store=state.store, type="fork", input=input)  # type: ignore
     async def solve() -> TaskState:
         if not isinstance(solver, Chain):
-            with solver_transcript(solver, state) as st:
+            async with solver_transcript(solver, state) as st:
                 new_state = await solver(state, generate)
                 st.complete(new_state)
                 return new_state
inspect_ai/solver/_multiple_choice.py
CHANGED
@@ -6,6 +6,7 @@ from typing import Match, TypedDict
 
 from typing_extensions import Unpack
 
+from inspect_ai._util.answer import answer_character, answer_index
 from inspect_ai._util.logger import warn_once
 from inspect_ai.util import resource
 
@@ -64,31 +65,13 @@ def answer_options(choices: Choices) -> str:
     indexes = list(range(len(choices)))
 
     return "\n".join(
-        [f"{
+        [f"{answer_character(i)}) {choices[j].value}" for i, j in enumerate(indexes)]
     )
 
 
-def answer_character(index: int) -> str:
-    r"""
-    Helper to go from array index to char, for example:
-
-        0 -> 'A', 1 -> 'B', etc
-    """
-    return chr(ord("A") + index)
-
-
-def answer_index(char: str) -> int:
-    r"""
-    Helper to go from char to array index, for example:
-
-        'A' -> 0, 'B' -> 1, etc
-    """
-    return ord(char.upper()) - ord("A")
-
-
 def prompt(question: str, choices: Choices, template: str) -> str:
     choices_text = answer_options(choices)
-    letters = ",".join(
+    letters = ",".join(answer_character(i) for i in range(len(choices)))
 
     return template.format(
         choices=choices_text,
@@ -112,7 +95,7 @@ def parse_answers(state: TaskState) -> Match[str] | None:
     # In this case, we're looking for a single line which contains the expected
     # ANSWER: B,C string with only whitespace after it
     match = re.search(
-        r"(?i)^ANSWER\s*:\s*([A-Za-z ,]+)\s*(?:$|\n)",
+        r"(?i)^ANSWER\s*:\s*([A-Za-z\d ,]+)\s*(?:$|\n)",
         state.output.completion,
         flags=re.MULTILINE,
     )
@@ -121,7 +104,7 @@ def parse_answers(state: TaskState) -> Match[str] | None:
     # version for backward compatibility
     if match is None:
         return re.search(
-            r"(?i)ANSWER\s*:\s*([A-Za-z ,]+)(?:[^\w]|\n|$)", state.output.completion
+            r"(?i)ANSWER\s*:\s*([A-Za-z\d ,]+)(?:[^\w]|\n|$)", state.output.completion
         )
     else:
         return match
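The answer_character/answer_index helpers removed above now live in the new inspect_ai/_util/answer.py module listed at the top of this diff, and the relaxed regexes also accept digits in parsed answers. Based on the removed definitions, the helpers behave like this (a quick sketch, assuming the moved code keeps the same behavior):

from inspect_ai._util.answer import answer_character, answer_index

assert answer_character(0) == "A"  # array index -> answer letter
assert answer_index("b") == 1      # answer letter -> array index (case-insensitive)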
inspect_ai/solver/_plan.py
CHANGED
@@ -102,7 +102,7 @@ class Plan(Solver):
         # execute steps
         for index, solver in enumerate(self.steps):
             # run solver
-            with solver_transcript(solver, state) as st:
+            async with solver_transcript(solver, state) as st:
                 state = await solver(state, generate)
                 st.complete(state)
 
@@ -113,7 +113,7 @@ class Plan(Solver):
 
         # execute finish
         if self.finish:
-            with solver_transcript(self.finish, state) as st:
+            async with solver_transcript(self.finish, state) as st:
                 state = await self.finish(state, generate)
                 st.complete(state)
             check_sample_interrupt()
inspect_ai/solver/_transcript.py
CHANGED
@@ -1,8 +1,9 @@
 import contextlib
-from typing import
+from typing import AsyncIterator
 
 from inspect_ai._util.json import json_changes
 from inspect_ai._util.registry import registry_log_name
+from inspect_ai.util._span import span
 
 from ._solver import Solver
 from ._task_state import TaskState, state_jsonable
@@ -22,12 +23,10 @@ class SolverTranscript:
         transcript()._event(StateEvent(changes=changes))
 
 
-@contextlib.
-def solver_transcript(
+@contextlib.asynccontextmanager
+async def solver_transcript(
     solver: Solver, state: TaskState, name: str | None = None
-) ->
-    from inspect_ai.log._transcript import transcript
-
+) -> AsyncIterator[SolverTranscript]:
     name = registry_log_name(name or solver)
-    with
+    async with span(name=name, type="solver"):
         yield SolverTranscript(name, state)
inspect_ai/tool/_mcp/_mcp.py
CHANGED
@@ -61,16 +61,17 @@ class MCPServerImpl(MCPServer):
     ) -> list[Tool]:
         return await self._task_session()._list_tools(tools)
 
-    # create a separate MCPServer session per async task
-    _task_sessions: dict[
+    # create a separate MCPServer session per async task / server name
+    _task_sessions: dict[str, "MCPServerSession"] = {}
 
     def _task_session(self) -> "MCPServerSession":
         task_id = anyio.get_current_task().id
-
-
+        session_key = f"{task_id}_{self._name}"
+        if session_key not in self._task_sessions:
+            MCPServerImpl._task_sessions[session_key] = MCPServerSession(
                 self._client, name=self._name, events=self._events
             )
-        return MCPServerImpl._task_sessions[
+        return MCPServerImpl._task_sessions[session_key]
 
 
 class MCPServerSession(MCPServer):
inspect_ai/tool/_tools/_execute.py
CHANGED
@@ -96,7 +96,10 @@ def python(
         The output of the Python code.
         """
         result = await sandbox_env(sandbox).exec(
-            cmd=["
+            cmd=["bash", "--login", "-c", "python3 -"],
+            input=code,
+            timeout=timeout,
+            user=user,
         )
         # return output (including stderr if any)
         output = ""
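The python() tool now pipes its code to python3 via a login shell and forwards the user argument to the sandbox exec call. Tool usage itself is unchanged; a small sketch of wiring it into a solver (argument values are illustrative):

from inspect_ai.solver import generate, use_tools
from inspect_ai.tool import python

# a minimal plan that exposes the python tool with a 60 second timeout
plan = [use_tools(python(timeout=60)), generate()]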
inspect_ai/util/__init__.py
CHANGED
@@ -8,6 +8,7 @@ from inspect_ai.util._limit import (
     token_limit,
 )
 
+from ._collect import collect
 from ._concurrency import concurrency
 from ._console import input_screen
 from ._display import DisplayType, display_counter, display_type
@@ -28,6 +29,7 @@ from ._sandbox import (
     sandbox_with,
     sandboxenv,
 )
+from ._span import span
 from ._store import Store, store
 from ._store_model import StoreModel, store_as
 from ._subprocess import (
@@ -71,6 +73,8 @@ __all__ = [
     "store",
     "StoreModel",
     "store_as",
+    "span",
+    "collect",
     "Subtask",
     "subtask",
     "throttle",
inspect_ai/util/_anyio.py
CHANGED
@@ -1,6 +1,10 @@
 import itertools
 import sys
 
+import anyio
+
+from inspect_ai._util._async import current_async_backend
+
 if sys.version_info < (3, 11):
     from exceptiongroup import ExceptionGroup
 
@@ -36,3 +40,10 @@ def _flatten_exception(exc: Exception) -> list[Exception]:
     ]
 
     return maybe_this_exception + other_exceptions
+
+
+def safe_current_task_id() -> int | None:
+    if current_async_backend() is not None:
+        return anyio.get_current_task().id
+    else:
+        return None
inspect_ai/util/_collect.py
ADDED
@@ -0,0 +1,50 @@
+import sys
+from typing import Awaitable, TypeVar, cast
+
+import anyio
+
+from ._span import span
+
+if sys.version_info < (3, 11):
+    from exceptiongroup import ExceptionGroup
+
+
+T = TypeVar("T")
+
+
+async def collect(*tasks: Awaitable[T]) -> list[T]:
+    """Run and collect the results of one or more async coroutines.
+
+    Similar to [`asyncio.gather()`](https://docs.python.org/3/library/asyncio-task.html#asyncio.gather),
+    but also works when [Trio](https://trio.readthedocs.io/en/stable/) is the async backend.
+
+    Automatically includes each task in a `span()`, which
+    ensures that its events are grouped together in the transcript.
+
+    Using `collect()` in preference to `asyncio.gather()` is highly recommended
+    for both Trio compatibility and more legible transcript output.
+
+    Args:
+        *tasks: Tasks to run
+
+    Returns:
+        List of task results.
+    """
+    results: list[None | T] = [None] * len(tasks)
+
+    try:
+        async with anyio.create_task_group() as tg:
+
+            async def run_task(index: int, task: Awaitable[T]) -> None:
+                async with span(f"task-{index + 1}", type="task"):
+                    results[index] = await task
+
+            for i, task in enumerate(tasks):
+                tg.start_soon(run_task, i, task)
+    except ExceptionGroup as ex:
+        if len(ex.exceptions) == 1:
+            raise ex.exceptions[0] from None
+        else:
+            raise
+
+    return cast(list[T], results)
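A usage sketch for the new collect() (the two coroutines are placeholders for whatever async work a solver or agent needs to fan out):

from inspect_ai.util import collect

async def gather_evidence():
    results = await collect(
        search_web(),     # placeholder coroutine
        query_archive(),  # placeholder coroutine
    )
    return results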
inspect_ai/util/_span.py
ADDED
@@ -0,0 +1,58 @@
+import contextlib
+from contextvars import ContextVar
+from typing import AsyncIterator
+from uuid import uuid4
+
+
+@contextlib.asynccontextmanager
+async def span(name: str, *, type: str | None = None) -> AsyncIterator[None]:
+    """Context manager for establishing a transcript span.
+
+    Args:
+        name (str): Step name.
+        type (str | None): Optional span type.
+    """
+    from inspect_ai.log._transcript import (
+        SpanBeginEvent,
+        SpanEndEvent,
+        track_store_changes,
+        transcript,
+    )
+
+    # span id
+    id = uuid4().hex
+
+    # capture parent id
+    parent_id = _current_span_id.get()
+
+    # set new current span (reset at the end)
+    token = _current_span_id.set(id)
+
+    # run the span
+    try:
+        # span begin event
+        transcript()._event(
+            SpanBeginEvent(
+                id=id,
+                parent_id=parent_id,
+                type=type,
+                name=name,
+            )
+        )
+
+        # run span w/ store change events
+        with track_store_changes():
+            yield
+
+    finally:
+        # send end event
+        transcript()._event(SpanEndEvent(id=id))
+
+        _current_span_id.reset(token)
+
+
+def current_span_id() -> str | None:
+    return _current_span_id.get()
+
+
+_current_span_id: ContextVar[str | None] = ContextVar("_current_span_id", default=None)
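And a usage sketch for span(): events emitted inside the block (model calls, tool events, store changes) are bracketed by SpanBeginEvent/SpanEndEvent and grouped in the transcript. The agent body here is a placeholder:

from inspect_ai.util import span

async def plan_step(model, state):
    async with span("plan", type="agent"):
        # events recorded here are grouped under the "plan" span
        state.output = await model.generate(state.messages)
    return state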