inspect-ai 0.3.49__py3-none-any.whl → 0.3.50__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- inspect_ai/_cli/info.py +2 -2
- inspect_ai/_cli/log.py +2 -2
- inspect_ai/_cli/score.py +2 -2
- inspect_ai/_display/core/display.py +19 -0
- inspect_ai/_display/core/panel.py +37 -7
- inspect_ai/_display/core/progress.py +29 -2
- inspect_ai/_display/core/results.py +79 -40
- inspect_ai/_display/core/textual.py +21 -0
- inspect_ai/_display/rich/display.py +28 -8
- inspect_ai/_display/textual/app.py +107 -1
- inspect_ai/_display/textual/display.py +1 -1
- inspect_ai/_display/textual/widgets/samples.py +132 -91
- inspect_ai/_display/textual/widgets/task_detail.py +232 -0
- inspect_ai/_display/textual/widgets/tasks.py +74 -6
- inspect_ai/_display/textual/widgets/toggle.py +32 -0
- inspect_ai/_eval/context.py +2 -0
- inspect_ai/_eval/eval.py +4 -3
- inspect_ai/_eval/loader.py +1 -1
- inspect_ai/_eval/run.py +35 -2
- inspect_ai/_eval/task/log.py +13 -11
- inspect_ai/_eval/task/results.py +12 -3
- inspect_ai/_eval/task/run.py +139 -36
- inspect_ai/_eval/task/sandbox.py +2 -1
- inspect_ai/_util/_async.py +30 -1
- inspect_ai/_util/file.py +31 -4
- inspect_ai/_util/html.py +3 -0
- inspect_ai/_util/logger.py +6 -5
- inspect_ai/_util/platform.py +5 -6
- inspect_ai/_util/registry.py +1 -1
- inspect_ai/_view/server.py +9 -9
- inspect_ai/_view/www/App.css +2 -2
- inspect_ai/_view/www/dist/assets/index.css +2 -2
- inspect_ai/_view/www/dist/assets/index.js +352 -294
- inspect_ai/_view/www/log-schema.json +13 -0
- inspect_ai/_view/www/package.json +1 -0
- inspect_ai/_view/www/src/components/MessageBand.mjs +1 -1
- inspect_ai/_view/www/src/components/Tools.mjs +16 -13
- inspect_ai/_view/www/src/samples/SampleDisplay.mjs +1 -3
- inspect_ai/_view/www/src/samples/SampleScoreView.mjs +52 -77
- inspect_ai/_view/www/src/samples/SamplesDescriptor.mjs +38 -13
- inspect_ai/_view/www/src/samples/transcript/ModelEventView.mjs +15 -2
- inspect_ai/_view/www/src/samples/transcript/state/StateEventRenderers.mjs +4 -2
- inspect_ai/_view/www/src/types/log.d.ts +2 -0
- inspect_ai/_view/www/src/workspace/WorkSpace.mjs +2 -0
- inspect_ai/_view/www/yarn.lock +9 -4
- inspect_ai/approval/__init__.py +1 -1
- inspect_ai/approval/_human/approver.py +35 -0
- inspect_ai/approval/_human/console.py +62 -0
- inspect_ai/approval/_human/manager.py +108 -0
- inspect_ai/approval/_human/panel.py +233 -0
- inspect_ai/approval/_human/util.py +51 -0
- inspect_ai/dataset/_sources/hf.py +2 -2
- inspect_ai/dataset/_sources/util.py +1 -1
- inspect_ai/log/_file.py +106 -36
- inspect_ai/log/_recorders/eval.py +226 -158
- inspect_ai/log/_recorders/file.py +9 -6
- inspect_ai/log/_recorders/json.py +35 -12
- inspect_ai/log/_recorders/recorder.py +15 -15
- inspect_ai/log/_samples.py +52 -0
- inspect_ai/model/_model.py +14 -0
- inspect_ai/model/_model_output.py +4 -0
- inspect_ai/model/_providers/azureai.py +1 -1
- inspect_ai/model/_providers/hf.py +106 -4
- inspect_ai/model/_providers/util/__init__.py +2 -0
- inspect_ai/model/_providers/util/hf_handler.py +200 -0
- inspect_ai/scorer/_common.py +1 -1
- inspect_ai/solver/_plan.py +0 -8
- inspect_ai/solver/_task_state.py +18 -1
- inspect_ai/solver/_use_tools.py +9 -1
- inspect_ai/tool/_tool_def.py +2 -2
- inspect_ai/tool/_tool_info.py +14 -2
- inspect_ai/tool/_tool_params.py +2 -1
- inspect_ai/tool/_tools/_execute.py +1 -1
- inspect_ai/tool/_tools/_web_browser/_web_browser.py +6 -0
- inspect_ai/util/__init__.py +5 -6
- inspect_ai/util/_panel.py +91 -0
- inspect_ai/util/_sandbox/__init__.py +2 -6
- inspect_ai/util/_sandbox/context.py +4 -3
- inspect_ai/util/_sandbox/docker/compose.py +12 -2
- inspect_ai/util/_sandbox/docker/docker.py +19 -9
- inspect_ai/util/_sandbox/docker/util.py +10 -2
- inspect_ai/util/_sandbox/environment.py +47 -41
- inspect_ai/util/_sandbox/local.py +15 -10
- inspect_ai/util/_subprocess.py +43 -3
- {inspect_ai-0.3.49.dist-info → inspect_ai-0.3.50.dist-info}/METADATA +2 -2
- {inspect_ai-0.3.49.dist-info → inspect_ai-0.3.50.dist-info}/RECORD +90 -82
- inspect_ai/_view/www/node_modules/flatted/python/flatted.py +0 -149
- inspect_ai/_view/www/node_modules/flatted/python/test.py +0 -63
- inspect_ai/approval/_human.py +0 -123
- {inspect_ai-0.3.49.dist-info → inspect_ai-0.3.50.dist-info}/LICENSE +0 -0
- {inspect_ai-0.3.49.dist-info → inspect_ai-0.3.50.dist-info}/WHEEL +0 -0
- {inspect_ai-0.3.49.dist-info → inspect_ai-0.3.50.dist-info}/entry_points.txt +0 -0
- {inspect_ai-0.3.49.dist-info → inspect_ai-0.3.50.dist-info}/top_level.txt +0 -0
inspect_ai/log/_recorders/recorder.py
CHANGED
@@ -14,20 +14,27 @@ from inspect_ai.log._log import (
 
 
 class Recorder(abc.ABC):
+    @classmethod
+    @abc.abstractmethod
+    def handles_location(cls, location: str) -> bool: ...
+
+    @abc.abstractmethod
+    def default_log_buffer(self) -> int: ...
+
     @abc.abstractmethod
-    def log_init(self, eval: EvalSpec, location: str | None = None) -> str: ...
+    async def log_init(self, eval: EvalSpec, location: str | None = None) -> str: ...
 
     @abc.abstractmethod
-    def log_start(self, eval: EvalSpec, plan: EvalPlan) -> None: ...
+    async def log_start(self, eval: EvalSpec, plan: EvalPlan) -> None: ...
 
     @abc.abstractmethod
-    def log_sample(self, eval: EvalSpec, sample: EvalSample) -> None: ...
+    async def log_sample(self, eval: EvalSpec, sample: EvalSample) -> None: ...
 
     @abc.abstractmethod
-    def flush(self, eval: EvalSpec) -> None: ...
+    async def flush(self, eval: EvalSpec) -> None: ...
 
     @abc.abstractmethod
-    def log_finish(
+    async def log_finish(
         self,
         eval: EvalSpec,
         status: Literal["success", "cancelled", "error"],
@@ -37,23 +44,16 @@ class Recorder(abc.ABC):
         error: EvalError | None = None,
     ) -> EvalLog: ...
 
-    @abc.abstractmethod
-    def default_log_buffer(self) -> int: ...
-
-    @classmethod
-    @abc.abstractmethod
-    def handles_location(cls, location: str) -> bool: ...
-
     @classmethod
     @abc.abstractmethod
-    def read_log(cls, location: str, header_only: bool = False) -> EvalLog: ...
+    async def read_log(cls, location: str, header_only: bool = False) -> EvalLog: ...
 
     @classmethod
     @abc.abstractmethod
-    def read_log_sample(
+    async def read_log_sample(
         cls, location: str, id: str | int, epoch: int = 1
     ) -> EvalSample: ...
 
     @classmethod
     @abc.abstractmethod
-    def write_log(cls, location: str, log: EvalLog) -> None: ...
+    async def write_log(cls, location: str, log: EvalLog) -> None: ...
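The practical upshot of this change is that every logging method on Recorder is now a coroutine: implementations declare them with async def and callers await them, while handles_location and default_log_buffer remain synchronous. A minimal, hypothetical sketch of the new shape (InMemoryRecorder is illustrative only, not part of the package):

import asyncio


class InMemoryRecorder:
    """Toy recorder mirroring the async method shapes introduced above."""

    def __init__(self) -> None:
        self.events: list[tuple[str, object]] = []

    @classmethod
    def handles_location(cls, location: str) -> bool:
        # still synchronous in 0.3.50
        return location.startswith("memory://")

    def default_log_buffer(self) -> int:
        return 10

    async def log_init(self, eval: object, location: str | None = None) -> str:
        self.events.append(("init", eval))
        return location or "memory://log"

    async def log_start(self, eval: object, plan: object) -> None:
        self.events.append(("start", plan))

    async def flush(self, eval: object) -> None:
        self.events.append(("flush", eval))


async def main() -> None:
    recorder = InMemoryRecorder()
    location = await recorder.log_init(eval={"task": "demo"})
    await recorder.log_start(eval={"task": "demo"}, plan=None)
    await recorder.flush(eval={"task": "demo"})
    print(location, recorder.events)


asyncio.run(main())

The bundled eval and json recorders were reworked along the same lines in this release (the +226/-158 and +35/-12 entries in the file list above).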
inspect_ai/log/_samples.py
CHANGED
@@ -1,5 +1,6 @@
 import asyncio
 import contextlib
+from contextvars import ContextVar
 from datetime import datetime
 from typing import AsyncGenerator, Literal
 
@@ -15,10 +16,14 @@ from ._transcript import Transcript
 class ActiveSample:
     def __init__(
         self,
+        *,
         task: str,
         model: str,
         sample: Sample,
         epoch: int,
+        message_limit: int | None,
+        token_limit: int | None,
+        time_limit: int | None,
         fails_on_error: bool,
         transcript: Transcript,
         sandboxes: dict[str, SandboxConnection],
@@ -30,7 +35,12 @@ class ActiveSample:
         self.model = model
         self.sample = sample
         self.epoch = epoch
+        self.message_limit = message_limit
+        self.token_limit = token_limit
+        self.time_limit = time_limit
         self.fails_on_error = fails_on_error
+        self.total_messages = 0
+        self.total_tokens = 0
         self.transcript = transcript
         self.sandboxes = sandboxes
         self._sample_task = asyncio.current_task()
@@ -59,10 +69,14 @@ def init_active_samples() -> None:
 
 
 @contextlib.asynccontextmanager
 async def active_sample(
+    *,
     task: str,
     model: str,
     sample: Sample,
     epoch: int,
+    message_limit: int | None,
+    token_limit: int | None,
+    time_limit: int | None,
     fails_on_error: bool,
     transcript: Transcript,
 ) -> AsyncGenerator[ActiveSample, None]:
@@ -72,17 +86,55 @@ async def active_sample(
         model=model,
         sample=sample,
         epoch=epoch,
+        message_limit=message_limit,
+        token_limit=token_limit,
+        time_limit=time_limit,
         sandboxes=await sandbox_connections(),
         fails_on_error=fails_on_error,
         transcript=transcript,
     )
 
     _active_samples.append(active)
+    _sample_active.set(active)
     try:
         yield active
     finally:
         active.completed = datetime.now().timestamp()
         _active_samples.remove(active)
+        _sample_active.set(None)
+
+
+def sample_active() -> ActiveSample | None:
+    return _sample_active.get(None)
+
+
+def set_active_sample_token_limit(token_limit: int | None) -> None:
+    active = sample_active()
+    if active:
+        active.token_limit = token_limit
+
+
+def set_active_sample_total_tokens(total_tokens: int) -> None:
+    active = sample_active()
+    if active:
+        active.total_tokens = total_tokens
+
+
+def set_active_sample_message_limit(message_limit: int | None) -> None:
+    active = sample_active()
+    if active:
+        active.message_limit = message_limit
+
+
+def set_active_sample_total_messages(total_messages: int) -> None:
+    active = sample_active()
+    if active:
+        active.total_messages = total_messages
+
+
+_sample_active: ContextVar[ActiveSample | None] = ContextVar(
+    "_sample_active", default=None
+)
 
 
 def active_samples() -> list[ActiveSample]:
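For orientation, the additions above follow the standard ContextVar pattern: a running sample registers itself as the active sample for the current task, and the set_active_sample_*() helpers quietly no-op when nothing is active. A self-contained sketch of the same idea with hypothetical names (DemoSample is not package code):

from contextvars import ContextVar
from dataclasses import dataclass


@dataclass
class DemoSample:
    token_limit: int | None = None
    total_tokens: int = 0


_current: ContextVar[DemoSample | None] = ContextVar("_current", default=None)


def set_total_tokens(total: int) -> None:
    # mirrors set_active_sample_total_tokens(): a silent no-op when no
    # sample is active in this context
    sample = _current.get(None)
    if sample:
        sample.total_tokens = total


sample = DemoSample(token_limit=1000)
_current.set(sample)
set_total_tokens(250)
print(sample.total_tokens)  # 250

This is what lets record_model_usage() in _model.py (below) push token totals into the active sample without threading them through the call stack.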
inspect_ai/model/_model.py
CHANGED
@@ -4,6 +4,7 @@ import functools
 import json
 import logging
 import os
+import time
 from contextvars import ContextVar
 from copy import deepcopy
 from typing import Any, Callable, Literal, Type, cast
@@ -355,12 +356,14 @@ class Model:
 
         generate_id = uuid()
         logger.debug(f"model generate {generate_id} ({str(self)})")
+        time_start = time.perf_counter()
         result = await self.api.generate(
             input=input,
             tools=tools,
             tool_choice=tool_choice,
             config=config,
         )
+        time_elapsed = time.perf_counter() - time_start
         logger.debug(f"model generate {generate_id} (completed)")
         if isinstance(result, tuple):
             output, call = result
@@ -368,12 +371,18 @@
             output = result
             call = None
 
+        # update output with time elapsed
+        output.time = time_elapsed
+
         # complete the transcript event
         complete(output, call)
 
         # record usage
         if output.usage:
+            # record usage
             record_model_usage(f"{self}", output.usage)
+
+            # send telemetry if its hooked up
             await send_telemetry(
                 "model_usage",
                 json.dumps(dict(model=str(self), usage=output.usage.model_dump())),
@@ -762,6 +771,11 @@ def record_model_usage(model: str, usage: ModelUsage) -> None:
     set_model_usage(model, usage, sample_model_usage_context_var.get(None))
     set_model_usage(model, usage, model_usage_context_var.get(None))
 
+    # update active sample
+    from inspect_ai.log._samples import set_active_sample_total_tokens
+
+    set_active_sample_total_tokens(sample_total_tokens())
+
 
 def set_model_usage(
     model: str, usage: ModelUsage, model_usage: dict[str, ModelUsage] | None
inspect_ai/model/_model_output.py
CHANGED
@@ -100,7 +100,11 @@ class ModelOutput(BaseModel):
     usage: ModelUsage | None = Field(default=None)
     """Model token usage"""
 
+    time: float | None = Field(default=None)
+    """Time elapsed (in seconds) for call to generate."""
+
     metadata: dict[str, Any] | None = Field(default=None)
+    """Additional metadata associated with model output."""
 
     error: str | None = Field(default=None)
     """Error message in the case of content moderation refusals."""
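Taken together with the _model.py change above, every ModelOutput now records the wall-clock duration of its generate call in the new time field (None for outputs written by older versions). A small, hypothetical helper showing one way the field might be consumed:

def tokens_per_second(completion_tokens: int, time: float | None) -> float | None:
    # `time` is None for outputs produced before 0.3.50 or read from old logs
    if time is None or time <= 0:
        return None
    return completion_tokens / time


print(tokens_per_second(128, 3.2))   # 40.0
print(tokens_per_second(128, None))  # None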
inspect_ai/model/_providers/azureai.py
CHANGED
@@ -362,7 +362,7 @@ def chat_completion_assistant_message(
         return handler.parse_assistant_response(response.content, tools)
     else:
         return ChatMessageAssistant(
-            content=response.content,
+            content=response.content or "",
             tool_calls=[
                 chat_completion_tool_call(call, tools) for call in response.tool_calls
             ]
inspect_ai/model/_providers/hf.py
CHANGED
@@ -1,5 +1,7 @@
 import asyncio
+import copy
 import functools
+import json
 import os
 from dataclasses import dataclass
 from queue import Empty, Queue
@@ -18,6 +20,7 @@ from transformers import ( # type: ignore
 from typing_extensions import override
 
 from inspect_ai._util.constants import DEFAULT_MAX_TOKENS
+from inspect_ai._util.content import ContentText
 from inspect_ai.tool import ToolChoice, ToolInfo
 
 from .._chat_message import ChatMessage, ChatMessageAssistant
@@ -31,7 +34,7 @@ from .._model_output import (
     ModelUsage,
     TopLogprob,
 )
-from .util import
+from .util import ChatAPIHandler, HFHandler
 
 HF_TOKEN = "HF_TOKEN"
 
@@ -71,6 +74,9 @@ class HuggingFaceAPI(ModelAPI):
         tokenizer_path = collect_model_arg("tokenizer_path")
         self.batch_size = collect_model_arg("batch_size")
         self.chat_template = collect_model_arg("chat_template")
+        self.tokenizer_call_args = collect_model_arg("tokenizer_call_args")
+        if self.tokenizer_call_args is None:
+            self.tokenizer_call_args = {}
 
         # device
         if device:
@@ -113,11 +119,22 @@ class HuggingFaceAPI(ModelAPI):
         tool_choice: ToolChoice,
         config: GenerateConfig,
     ) -> ModelOutput:
+        # create handler
+        handler: ChatAPIHandler | None = (
+            HFHandler(self.model_name) if len(tools) > 0 else None
+        )
+
         # create chat
         chat = self.hf_chat(input, tools)
 
+        assert isinstance(self.tokenizer_call_args, dict)
         # prepare tokenizer
-        tokenizer = functools.partial(
+        tokenizer = functools.partial(
+            self.tokenizer,
+            return_tensors="pt",
+            padding=True,
+            **self.tokenizer_call_args,
+        )
 
         # prepare generator
         kwargs: dict[str, Any] = dict(do_sample=True)
@@ -172,6 +189,15 @@ class HuggingFaceAPI(ModelAPI):
             ),
         )
 
+        choice = ChatCompletionChoice(
+            message=chat_completion_assistant_message(
+                response, tools, handler, self.model_name
+            ),
+            logprobs=(
+                Logprobs(content=final_logprobs) if final_logprobs is not None else None
+            ),
+        )
+
         # return output
         return ModelOutput(
             model=self.model_name,
@@ -199,18 +225,94 @@
 
     def hf_chat(self, messages: list[ChatMessage], tools: list[ToolInfo]) -> str:
         # convert to hf format
-
+        tools_list = []
+        hf_messages = copy.deepcopy(messages)
+        if len(tools) > 0:
+            tools_list = [
+                json.loads(tool.model_dump_json(exclude_none=True, indent=2))
+                for tool in tools
+            ]
+            if "mistral" in self.model_name.lower():
+                hf_messages = shorten_tool_id(hf_messages)
+                tools_list = tools_to_mistral_format(tools_list)
+            elif "qwen" in self.model_name.lower():
+                hf_messages = inspect_tools_to_string(hf_messages)
+
         # apply chat template
         chat = self.tokenizer.apply_chat_template(
             hf_messages,
             add_generation_prompt=True,
             tokenize=False,
-
+            tools=tools_list if len(tools_list) > 0 else None,
         )
         # return
         return cast(str, chat)
 
 
+def shorten_tool_id(messages: list[ChatMessage]) -> list[ChatMessage]:
+    """Shorten the tool_call_id in the messages to the last 9 characters for Mistral."""
+    for i, message in enumerate(messages):
+        if message.role == "tool":
+            # Trim tool_call_id in tool messages
+            if message.tool_call_id is not None:
+                message.tool_call_id = message.tool_call_id[-9:]
+        elif message.role == "assistant" and hasattr(message, "tool_calls"):
+            # Trim tool_call IDs inside tool_calls for assistant messages
+            for tool_call in message.tool_calls or []:
+                tool_call.id = tool_call.id[-9:]
+    return messages
+
+
+def tools_to_mistral_format(tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
+    """Convert tools to the format required for Mistral."""
+    mistral_tools = []
+    for tool in tools:
+        mistral_tools.append(
+            {
+                "function": {
+                    "name": tool["name"],
+                    "description": tool["description"],
+                    "parameters": {
+                        "type": tool["parameters"]["type"],
+                        "properties": tool["parameters"]["properties"],
+                        "required": tool["parameters"]["required"],
+                    },
+                }
+            }
+        )
+    return mistral_tools
+
+
+def inspect_tools_to_string(messages: list[ChatMessage]) -> list[ChatMessage]:
+    """Convert tools to a string for Qwen."""
+    for message in messages:
+        if message.role == "assistant":
+            # check if the message contains a tool call
+            tool_content = ""
+            if message.tool_calls:
+                for tool_call in message.tool_calls:
+                    tool_content += f'\n```json\n{{"name": "{tool_call.function}", "arguments": {json.dumps(tool_call.arguments)}}}\n```'
+            # remove the tool call from the message
+            message.tool_calls = None
+            if isinstance(message.content, str):
+                message.content += tool_content
+            else:
+                message.content.append(ContentText(text=tool_content))
+    return messages
+
+
+def chat_completion_assistant_message(
+    response: Any,
+    tools: list[ToolInfo],
+    handler: ChatAPIHandler | None,
+    model_name: str,
+) -> ChatMessageAssistant:
+    if handler:
+        return handler.parse_assistant_response(response.output, tools)
+    else:
+        return ChatMessageAssistant(content=response.output, source="generate")
+
+
 def set_random_seeds(seed: int | None = None) -> None:
     if seed is None:
         seed = np.random.default_rng().integers(2**32 - 1)
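The new tokenizer_call_args model argument (read via collect_model_arg, the same channel as batch_size and chat_template) is merged into the tokenizer invocation alongside the provider defaults return_tensors="pt" and padding=True. A self-contained sketch of that merge, using a stand-in tokenizer rather than a real transformers one:

import functools


def fake_tokenizer(text: str, **kwargs: object) -> dict[str, object]:
    # stand-in for self.tokenizer; just echoes what it was called with
    return {"text": text, **kwargs}


tokenizer_call_args = {"truncation": True, "max_length": 4096}
tokenize = functools.partial(
    fake_tokenizer,
    return_tensors="pt",
    padding=True,
    **tokenizer_call_args,
)
print(tokenize("hello world"))
# {'text': 'hello world', 'return_tensors': 'pt', 'padding': True,
#  'truncation': True, 'max_length': 4096}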
inspect_ai/model/_providers/util/__init__.py
CHANGED
@@ -5,6 +5,7 @@ from .chatapi import (
     chat_api_request,
     is_chat_api_rate_limit,
 )
+from .hf_handler import HFHandler
 from .llama31 import Llama31Handler
 from .util import (
     as_stop_reason,
@@ -26,4 +27,5 @@ __all__ = [
     "ChatAPIHandler",
     "ChatAPIMessage",
     "Llama31Handler",
+    "HFHandler",
 ]
inspect_ai/model/_providers/util/hf_handler.py
ADDED
@@ -0,0 +1,200 @@
+import json
+import re
+from logging import getLogger
+
+from shortuuid import uuid
+from typing_extensions import override
+
+from inspect_ai.tool._tool_call import ToolCall
+from inspect_ai.tool._tool_info import ToolInfo
+
+from ..._chat_message import ChatMessageAssistant
+from .chatapi import ChatAPIHandler
+from .util import parse_tool_call, tool_parse_error_message
+
+logger = getLogger(__name__)
+
+
+# Hugging Face handler currently supports LLama, Mistral and Qwen models, but will
+# work with any model that uses the same tool calling conventions
+
+
+class HFHandler(ChatAPIHandler):
+    def __init__(self, model_name: str) -> None:
+        self.model_name = model_name
+
+    @override
+    def parse_assistant_response(
+        self, response: str, tools: list[ToolInfo]
+    ) -> ChatMessageAssistant:
+        """Parse content and tool calls from a model response.
+
+        This method has an interdependency with `input_with_tools()` (as that is the
+        prompt that asks the model to use the <tool_call>...</tool_call> syntax)
+        """
+        # extract tool calls
+        content, tool_calls_content = model_specific_tool_parse(
+            response, self.model_name
+        )
+        # if there are tool calls proceed with parsing
+        if len(tool_calls_content) > 0:
+            # parse each tool call (if there are parsing error that occur
+            # this will be reported in the `parse_error` field of the ToolCall
+            # and ultimately reported back to the model)
+            tool_calls = [
+                parse_tool_call_content(content, tools)
+                for content in tool_calls_content
+            ]
+
+            # return the message
+            return ChatMessageAssistant(
+                content=content,
+                tool_calls=tool_calls,
+                source="generate",
+            )
+
+        # otherwise this is just an ordinary assistant message
+        else:
+            return ChatMessageAssistant(
+                content=filter_assistant_header(response), source="generate"
+            )
+
+
+def parse_tool_call_content(content: str, tools: list[ToolInfo]) -> ToolCall:
+    """Attempt to parse content from inside <tool_call> tags.
+
+    Content inside a <tool_call> should be a JSON dictionary with `name` and
+    `arguments` (which in turn should be a `dict[str,Any]` but in some cases
+    we've seen models pass `str`). This function attempts to extract this from
+    the passed tcontentext. A `ToolCall` is returned for all cases (if the
+    parsing fails then it will have a `parse_error`, which will be subsequently
+    reported to the model.
+    """
+    try:
+        # parse raw JSON
+        tool_call_data = json.loads(content)
+        if "parameters" in tool_call_data:
+            tool_call_data["arguments"] = tool_call_data.pop("parameters")
+
+        # if its not a dict then report error
+        if not isinstance(tool_call_data, dict):
+            raise ValueError("The provided arguments are not a JSON dictionary.")
+
+        # see if we can get the fields (if not report error)
+        name = tool_call_data.get("name", None)
+        arguments = tool_call_data.get("arguments", None)
+        if not name or not arguments:
+            raise ValueError(
+                "Required 'name' and 'arguments' not provided in JSON dictionary."
+            )
+
+        # now perform the parse (we need to call thi function because it includes
+        # the special handling to for mapping arguments that are a plain `str`
+        # to the first parameter of the function)
+        unique_id = f"{name}_{uuid()}"
+        return parse_tool_call(unique_id, name, json.dumps(arguments), tools)
+
+    except Exception as ex:
+        # buld error message
+        parse_error = tool_parse_error_message(content, ex)
+
+        # log it to 'info'
+        logger.info(parse_error)
+
+        # notify model
+        return ToolCall(
+            id="unknown",
+            function="unknown",
+            arguments={},
+            type="function",
+            parse_error=parse_error,
+        )
+
+
+def model_specific_tool_parse(response: str, model_name: str) -> tuple[str, list[str]]:
+    model_name = model_name.lower()
+
+    if "llama" in model_name:
+        if "name" in response and ("parameters" in response or "arguments" in response):
+            function_calls, content = json_extract_raw(response)
+        else:
+            content = response
+            function_calls = []
+    elif "mistral" in model_name:
+        if "name" in response and "arguments" in response:
+            content = ""
+            function_calls = [json.dumps(tool) for tool in json.loads(response)]
+        else:
+            content = response
+            function_calls = []
+    elif "qwen" in model_name and "coder" in model_name:
+        if "name" in response and "arguments" in response:
+            function_calls, content = json_extract(response)
+        else:
+            content = response
+            function_calls = []
+    elif "qwen" in model_name and "instruct" in model_name:
+        if "name" in response and "arguments" in response:
+            function_calls, content = xml_extract(response, "tool_call")
+        else:
+            content = response
+            function_calls = []
+    else:
+        try:
+            function_calls, content = parse_unknown_tool_calls(response)
+        except Exception:
+            raise ValueError(
+                f"Unsupported model: {model_name}. No tool parsing implemented. Check if any of the current parsings work with your tool calling conventions and add the model name to the correct elif block."
+            )
+    return content, function_calls
+
+
+def json_extract(raw_string: str) -> tuple[list[str], str]:
+    """Extract tools in form ```json{...}``` and return the remaining content."""
+    function_calls = re.findall(r"```json\s*(\{.*?\})\s*```", raw_string, re.DOTALL)
+
+    remaining_content = re.sub(
+        r"```json\s*\{.*?\}\s*```", "", raw_string, flags=re.DOTALL
+    ).strip()
+
+    return function_calls, remaining_content
+
+
+def json_extract_raw(raw_string: str) -> tuple[list[str], str]:
+    """Extract tools in form `{...}` and return the remaining content."""
+    # Regex to extract sequences starting with '{' and ending with '}}'
+    json_like_regex = r"\{.*?\}\}"
+    function_calls = re.findall(json_like_regex, raw_string)
+    remaining_content = re.sub(json_like_regex, "", raw_string).strip()
+
+    return function_calls, remaining_content
+
+
+def xml_extract(raw_string: str, tag: str) -> tuple[list[str], str]:
+    """Extract tools in form <tag>{...}</tag> and return the remaining content."""
+    tool_call_regex = rf"<{tag}>((?:.|\n)*?)</{tag}>"
+    function_calls = re.findall(tool_call_regex, raw_string)
+    tool_call_content_regex = rf"<{tag}>(?:.|\n)*?</{tag}>"
+    other_content = re.split(tool_call_content_regex, raw_string, flags=re.DOTALL)
+    other_content = [
+        str(content).strip() for content in other_content if str(content).strip()
+    ]
+    content = "\n\n".join(other_content)
+    return function_calls, content
+
+
+def parse_unknown_tool_calls(response: str) -> tuple[list[str], str]:
+    if "```json" in response:
+        return json_extract(response)
+    elif "<tool_call>" in response:
+        return xml_extract(response, "tool_call")
+    elif "<function>" in response:
+        return xml_extract(response, "function")
+    elif "{" in response and "}}" in response:
+        return json_extract_raw(response)
+    else:
+        return [], response
+
+
+def filter_assistant_header(message: str) -> str:
+    return re.sub(r"<\|start_header_id\|>assistant<\|end_header_id\|>", "", message)
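Since hf_handler.py is a brand-new module, here is a quick illustration of its extraction helpers. It imports from a private path, so treat it as exploratory only; the module layout may change in later releases:

from inspect_ai.model._providers.util.hf_handler import json_extract, xml_extract

response = (
    "I'll check the weather.\n"
    '<tool_call>{"name": "get_weather", "arguments": {"city": "Berlin"}}</tool_call>'
)
calls, content = xml_extract(response, "tool_call")
print(calls)    # ['{"name": "get_weather", "arguments": {"city": "Berlin"}}']
print(content)  # I'll check the weather.

fenced = 'Sure:\n```json\n{"name": "ls", "arguments": {"path": "/"}}\n```'
calls, content = json_extract(fenced)
print(calls)    # ['{"name": "ls", "arguments": {"path": "/"}}']
print(content)  # Sure: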
inspect_ai/scorer/_common.py
CHANGED
inspect_ai/solver/_plan.py
CHANGED
@@ -57,7 +57,6 @@ class Plan(Solver):
 
         self.finish = finish
         self.cleanup = cleanup
-        self.progress: Callable[[], None] = lambda: None
         self._name = name
 
         if not internal:
@@ -106,14 +105,8 @@
                 state = await solver(state, generate)
                 st.complete(state)
 
-            # tick progress
-            self.progress()
-
             # check for completed
             if state.completed:
-                # tick rest of progress
-                for _ in range(index + 1, len(self.steps)):
-                    self.progress()
                 # exit loop
                 break
 
@@ -122,7 +115,6 @@
         with solver_transcript(self.finish, state) as st:
             state = await self.finish(state, generate)
             st.complete(state)
-            self.progress()
 
         # mark completed
         state.completed = True