inspect-ai 0.3.49__py3-none-any.whl → 0.3.50__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (93)
  1. inspect_ai/_cli/info.py +2 -2
  2. inspect_ai/_cli/log.py +2 -2
  3. inspect_ai/_cli/score.py +2 -2
  4. inspect_ai/_display/core/display.py +19 -0
  5. inspect_ai/_display/core/panel.py +37 -7
  6. inspect_ai/_display/core/progress.py +29 -2
  7. inspect_ai/_display/core/results.py +79 -40
  8. inspect_ai/_display/core/textual.py +21 -0
  9. inspect_ai/_display/rich/display.py +28 -8
  10. inspect_ai/_display/textual/app.py +107 -1
  11. inspect_ai/_display/textual/display.py +1 -1
  12. inspect_ai/_display/textual/widgets/samples.py +132 -91
  13. inspect_ai/_display/textual/widgets/task_detail.py +232 -0
  14. inspect_ai/_display/textual/widgets/tasks.py +74 -6
  15. inspect_ai/_display/textual/widgets/toggle.py +32 -0
  16. inspect_ai/_eval/context.py +2 -0
  17. inspect_ai/_eval/eval.py +4 -3
  18. inspect_ai/_eval/loader.py +1 -1
  19. inspect_ai/_eval/run.py +35 -2
  20. inspect_ai/_eval/task/log.py +13 -11
  21. inspect_ai/_eval/task/results.py +12 -3
  22. inspect_ai/_eval/task/run.py +139 -36
  23. inspect_ai/_eval/task/sandbox.py +2 -1
  24. inspect_ai/_util/_async.py +30 -1
  25. inspect_ai/_util/file.py +31 -4
  26. inspect_ai/_util/html.py +3 -0
  27. inspect_ai/_util/logger.py +6 -5
  28. inspect_ai/_util/platform.py +5 -6
  29. inspect_ai/_util/registry.py +1 -1
  30. inspect_ai/_view/server.py +9 -9
  31. inspect_ai/_view/www/App.css +2 -2
  32. inspect_ai/_view/www/dist/assets/index.css +2 -2
  33. inspect_ai/_view/www/dist/assets/index.js +352 -294
  34. inspect_ai/_view/www/log-schema.json +13 -0
  35. inspect_ai/_view/www/package.json +1 -0
  36. inspect_ai/_view/www/src/components/MessageBand.mjs +1 -1
  37. inspect_ai/_view/www/src/components/Tools.mjs +16 -13
  38. inspect_ai/_view/www/src/samples/SampleDisplay.mjs +1 -3
  39. inspect_ai/_view/www/src/samples/SampleScoreView.mjs +52 -77
  40. inspect_ai/_view/www/src/samples/SamplesDescriptor.mjs +38 -13
  41. inspect_ai/_view/www/src/samples/transcript/ModelEventView.mjs +15 -2
  42. inspect_ai/_view/www/src/samples/transcript/state/StateEventRenderers.mjs +4 -2
  43. inspect_ai/_view/www/src/types/log.d.ts +2 -0
  44. inspect_ai/_view/www/src/workspace/WorkSpace.mjs +2 -0
  45. inspect_ai/_view/www/yarn.lock +9 -4
  46. inspect_ai/approval/__init__.py +1 -1
  47. inspect_ai/approval/_human/approver.py +35 -0
  48. inspect_ai/approval/_human/console.py +62 -0
  49. inspect_ai/approval/_human/manager.py +108 -0
  50. inspect_ai/approval/_human/panel.py +233 -0
  51. inspect_ai/approval/_human/util.py +51 -0
  52. inspect_ai/dataset/_sources/hf.py +2 -2
  53. inspect_ai/dataset/_sources/util.py +1 -1
  54. inspect_ai/log/_file.py +106 -36
  55. inspect_ai/log/_recorders/eval.py +226 -158
  56. inspect_ai/log/_recorders/file.py +9 -6
  57. inspect_ai/log/_recorders/json.py +35 -12
  58. inspect_ai/log/_recorders/recorder.py +15 -15
  59. inspect_ai/log/_samples.py +52 -0
  60. inspect_ai/model/_model.py +14 -0
  61. inspect_ai/model/_model_output.py +4 -0
  62. inspect_ai/model/_providers/azureai.py +1 -1
  63. inspect_ai/model/_providers/hf.py +106 -4
  64. inspect_ai/model/_providers/util/__init__.py +2 -0
  65. inspect_ai/model/_providers/util/hf_handler.py +200 -0
  66. inspect_ai/scorer/_common.py +1 -1
  67. inspect_ai/solver/_plan.py +0 -8
  68. inspect_ai/solver/_task_state.py +18 -1
  69. inspect_ai/solver/_use_tools.py +9 -1
  70. inspect_ai/tool/_tool_def.py +2 -2
  71. inspect_ai/tool/_tool_info.py +14 -2
  72. inspect_ai/tool/_tool_params.py +2 -1
  73. inspect_ai/tool/_tools/_execute.py +1 -1
  74. inspect_ai/tool/_tools/_web_browser/_web_browser.py +6 -0
  75. inspect_ai/util/__init__.py +5 -6
  76. inspect_ai/util/_panel.py +91 -0
  77. inspect_ai/util/_sandbox/__init__.py +2 -6
  78. inspect_ai/util/_sandbox/context.py +4 -3
  79. inspect_ai/util/_sandbox/docker/compose.py +12 -2
  80. inspect_ai/util/_sandbox/docker/docker.py +19 -9
  81. inspect_ai/util/_sandbox/docker/util.py +10 -2
  82. inspect_ai/util/_sandbox/environment.py +47 -41
  83. inspect_ai/util/_sandbox/local.py +15 -10
  84. inspect_ai/util/_subprocess.py +43 -3
  85. {inspect_ai-0.3.49.dist-info → inspect_ai-0.3.50.dist-info}/METADATA +2 -2
  86. {inspect_ai-0.3.49.dist-info → inspect_ai-0.3.50.dist-info}/RECORD +90 -82
  87. inspect_ai/_view/www/node_modules/flatted/python/flatted.py +0 -149
  88. inspect_ai/_view/www/node_modules/flatted/python/test.py +0 -63
  89. inspect_ai/approval/_human.py +0 -123
  90. {inspect_ai-0.3.49.dist-info → inspect_ai-0.3.50.dist-info}/LICENSE +0 -0
  91. {inspect_ai-0.3.49.dist-info → inspect_ai-0.3.50.dist-info}/WHEEL +0 -0
  92. {inspect_ai-0.3.49.dist-info → inspect_ai-0.3.50.dist-info}/entry_points.txt +0 -0
  93. {inspect_ai-0.3.49.dist-info → inspect_ai-0.3.50.dist-info}/top_level.txt +0 -0

inspect_ai/log/_recorders/recorder.py
@@ -14,20 +14,27 @@ from inspect_ai.log._log import (
 
 
 class Recorder(abc.ABC):
+    @classmethod
+    @abc.abstractmethod
+    def handles_location(cls, location: str) -> bool: ...
+
+    @abc.abstractmethod
+    def default_log_buffer(self) -> int: ...
+
     @abc.abstractmethod
-    def log_init(self, eval: EvalSpec, location: str | None = None) -> str: ...
+    async def log_init(self, eval: EvalSpec, location: str | None = None) -> str: ...
 
     @abc.abstractmethod
-    def log_start(self, eval: EvalSpec, plan: EvalPlan) -> None: ...
+    async def log_start(self, eval: EvalSpec, plan: EvalPlan) -> None: ...
 
     @abc.abstractmethod
-    def log_sample(self, eval: EvalSpec, sample: EvalSample) -> None: ...
+    async def log_sample(self, eval: EvalSpec, sample: EvalSample) -> None: ...
 
     @abc.abstractmethod
-    def flush(self, eval: EvalSpec) -> None: ...
+    async def flush(self, eval: EvalSpec) -> None: ...
 
     @abc.abstractmethod
-    def log_finish(
+    async def log_finish(
         self,
         eval: EvalSpec,
         status: Literal["success", "cancelled", "error"],
@@ -37,23 +44,16 @@ class Recorder(abc.ABC):
         error: EvalError | None = None,
     ) -> EvalLog: ...
 
-    @abc.abstractmethod
-    def default_log_buffer(self) -> int: ...
-
-    @classmethod
-    @abc.abstractmethod
-    def handles_location(cls, location: str) -> bool: ...
-
     @classmethod
     @abc.abstractmethod
-    def read_log(cls, location: str, header_only: bool = False) -> EvalLog: ...
+    async def read_log(cls, location: str, header_only: bool = False) -> EvalLog: ...
 
     @classmethod
     @abc.abstractmethod
-    def read_log_sample(
+    async def read_log_sample(
         cls, location: str, id: str | int, epoch: int = 1
     ) -> EvalSample: ...
 
     @classmethod
     @abc.abstractmethod
-    def write_log(cls, location: str, log: EvalLog) -> None: ...
+    async def write_log(cls, location: str, log: EvalLog) -> None: ...
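
Since every Recorder method is now a coroutine, custom recorders and their call sites must move to async/await. A minimal sketch of the new calling convention (`spec`, `plan`, and `sample` are hypothetical stand-ins for real EvalSpec/EvalPlan/EvalSample values; `log_finish` takes further arguments elided in the hunk above, so it is omitted here):

    async def record_run(recorder: Recorder, spec, plan, sample) -> str:
        # each logging call must now be awaited
        location = await recorder.log_init(spec)
        await recorder.log_start(spec, plan)
        await recorder.log_sample(spec, sample)
        await recorder.flush(spec)
        return location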

inspect_ai/log/_samples.py
@@ -1,5 +1,6 @@
 import asyncio
 import contextlib
+from contextvars import ContextVar
 from datetime import datetime
 from typing import AsyncGenerator, Literal
 
@@ -15,10 +16,14 @@ from ._transcript import Transcript
 class ActiveSample:
     def __init__(
         self,
+        *,
         task: str,
         model: str,
         sample: Sample,
         epoch: int,
+        message_limit: int | None,
+        token_limit: int | None,
+        time_limit: int | None,
         fails_on_error: bool,
         transcript: Transcript,
         sandboxes: dict[str, SandboxConnection],
@@ -30,7 +35,12 @@ class ActiveSample:
         self.model = model
         self.sample = sample
         self.epoch = epoch
+        self.message_limit = message_limit
+        self.token_limit = token_limit
+        self.time_limit = time_limit
         self.fails_on_error = fails_on_error
+        self.total_messages = 0
+        self.total_tokens = 0
         self.transcript = transcript
         self.sandboxes = sandboxes
         self._sample_task = asyncio.current_task()
@@ -59,10 +69,14 @@ def init_active_samples() -> None:
 
 @contextlib.asynccontextmanager
 async def active_sample(
+    *,
     task: str,
     model: str,
     sample: Sample,
    epoch: int,
+    message_limit: int | None,
+    token_limit: int | None,
+    time_limit: int | None,
    fails_on_error: bool,
    transcript: Transcript,
 ) -> AsyncGenerator[ActiveSample, None]:
@@ -72,17 +86,55 @@ async def active_sample(
         model=model,
         sample=sample,
         epoch=epoch,
+        message_limit=message_limit,
+        token_limit=token_limit,
+        time_limit=time_limit,
         sandboxes=await sandbox_connections(),
         fails_on_error=fails_on_error,
         transcript=transcript,
     )
 
     _active_samples.append(active)
+    _sample_active.set(active)
     try:
         yield active
     finally:
         active.completed = datetime.now().timestamp()
         _active_samples.remove(active)
+        _sample_active.set(None)
+
+
+def sample_active() -> ActiveSample | None:
+    return _sample_active.get(None)
+
+
+def set_active_sample_token_limit(token_limit: int | None) -> None:
+    active = sample_active()
+    if active:
+        active.token_limit = token_limit
+
+
+def set_active_sample_total_tokens(total_tokens: int) -> None:
+    active = sample_active()
+    if active:
+        active.total_tokens = total_tokens
+
+
+def set_active_sample_message_limit(message_limit: int | None) -> None:
+    active = sample_active()
+    if active:
+        active.message_limit = message_limit
+
+
+def set_active_sample_total_messages(total_messages: int) -> None:
+    active = sample_active()
+    if active:
+        active.total_messages = total_messages
+
+
+_sample_active: ContextVar[ActiveSample | None] = ContextVar(
+    "_sample_active", default=None
+)
 
 
 def active_samples() -> list[ActiveSample]:
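
The `_sample_active` ContextVar gives each asyncio task its own notion of the active sample, so concurrently running samples cannot clobber each other's limits and counters. A self-contained sketch of the same pattern (plain dicts standing in for ActiveSample):

    import asyncio
    from contextvars import ContextVar

    _active: ContextVar[dict | None] = ContextVar("_active", default=None)

    def set_total_tokens(total_tokens: int) -> None:
        # mirrors set_active_sample_total_tokens(): a no-op unless a
        # sample is active in the current task's context
        active = _active.get(None)
        if active is not None:
            active["total_tokens"] = total_tokens

    async def run_sample(name: str) -> None:
        _active.set({"name": name, "total_tokens": 0})
        await asyncio.sleep(0)  # other samples run concurrently here
        set_total_tokens(42)
        print(name, _active.get())  # each task sees only its own sample

    async def main() -> None:
        await asyncio.gather(run_sample("a"), run_sample("b"))

    asyncio.run(main())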

inspect_ai/model/_model.py
@@ -4,6 +4,7 @@ import functools
 import json
 import logging
 import os
+import time
 from contextvars import ContextVar
 from copy import deepcopy
 from typing import Any, Callable, Literal, Type, cast
@@ -355,12 +356,14 @@ class Model:
 
         generate_id = uuid()
         logger.debug(f"model generate {generate_id} ({str(self)})")
+        time_start = time.perf_counter()
         result = await self.api.generate(
             input=input,
             tools=tools,
             tool_choice=tool_choice,
             config=config,
         )
+        time_elapsed = time.perf_counter() - time_start
         logger.debug(f"model generate {generate_id} (completed)")
         if isinstance(result, tuple):
             output, call = result
@@ -368,12 +371,18 @@ class Model:
             output = result
             call = None
 
+        # update output with time elapsed
+        output.time = time_elapsed
+
         # complete the transcript event
         complete(output, call)
 
         # record usage
         if output.usage:
+            # record usage
             record_model_usage(f"{self}", output.usage)
+
+            # send telemetry if it's hooked up
             await send_telemetry(
                 "model_usage",
                 json.dumps(dict(model=str(self), usage=output.usage.model_dump())),
@@ -762,6 +771,11 @@ def record_model_usage(model: str, usage: ModelUsage) -> None:
     set_model_usage(model, usage, sample_model_usage_context_var.get(None))
     set_model_usage(model, usage, model_usage_context_var.get(None))
 
+    # update active sample
+    from inspect_ai.log._samples import set_active_sample_total_tokens
+
+    set_active_sample_total_tokens(sample_total_tokens())
+
 
 def set_model_usage(
     model: str, usage: ModelUsage, model_usage: dict[str, ModelUsage] | None

inspect_ai/model/_model_output.py
@@ -100,7 +100,11 @@ class ModelOutput(BaseModel):
     usage: ModelUsage | None = Field(default=None)
     """Model token usage"""
 
+    time: float | None = Field(default=None)
+    """Time elapsed (in seconds) for call to generate."""
+
     metadata: dict[str, Any] | None = Field(default=None)
+    """Additional metadata associated with model output."""
 
     error: str | None = Field(default=None)
     """Error message in the case of content moderation refusals."""

inspect_ai/model/_providers/azureai.py
@@ -362,7 +362,7 @@ def chat_completion_assistant_message(
         return handler.parse_assistant_response(response.content, tools)
     else:
         return ChatMessageAssistant(
-            content=response.content,
+            content=response.content or "",
             tool_calls=[
                 chat_completion_tool_call(call, tools) for call in response.tool_calls
             ]

inspect_ai/model/_providers/hf.py
@@ -1,5 +1,7 @@
 import asyncio
+import copy
 import functools
+import json
 import os
 from dataclasses import dataclass
 from queue import Empty, Queue
@@ -18,6 +20,7 @@ from transformers import (  # type: ignore
 from typing_extensions import override
 
 from inspect_ai._util.constants import DEFAULT_MAX_TOKENS
+from inspect_ai._util.content import ContentText
 from inspect_ai.tool import ToolChoice, ToolInfo
 
 from .._chat_message import ChatMessage, ChatMessageAssistant
@@ -31,7 +34,7 @@ from .._model_output import (
     ModelUsage,
     TopLogprob,
 )
-from .util import chat_api_input
+from .util import ChatAPIHandler, HFHandler
 
 HF_TOKEN = "HF_TOKEN"
 
@@ -71,6 +74,9 @@ class HuggingFaceAPI(ModelAPI):
         tokenizer_path = collect_model_arg("tokenizer_path")
         self.batch_size = collect_model_arg("batch_size")
         self.chat_template = collect_model_arg("chat_template")
+        self.tokenizer_call_args = collect_model_arg("tokenizer_call_args")
+        if self.tokenizer_call_args is None:
+            self.tokenizer_call_args = {}
 
         # device
         if device:
@@ -113,11 +119,22 @@ class HuggingFaceAPI(ModelAPI):
         tool_choice: ToolChoice,
         config: GenerateConfig,
     ) -> ModelOutput:
+        # create handler
+        handler: ChatAPIHandler | None = (
+            HFHandler(self.model_name) if len(tools) > 0 else None
+        )
+
         # create chat
         chat = self.hf_chat(input, tools)
 
+        assert isinstance(self.tokenizer_call_args, dict)
         # prepare tokenizer
-        tokenizer = functools.partial(self.tokenizer, return_tensors="pt", padding=True)
+        tokenizer = functools.partial(
+            self.tokenizer,
+            return_tensors="pt",
+            padding=True,
+            **self.tokenizer_call_args,
+        )
 
         # prepare generator
         kwargs: dict[str, Any] = dict(do_sample=True)
@@ -172,6 +189,15 @@ class HuggingFaceAPI(ModelAPI):
             ),
         )
 
+        choice = ChatCompletionChoice(
+            message=chat_completion_assistant_message(
+                response, tools, handler, self.model_name
+            ),
+            logprobs=(
+                Logprobs(content=final_logprobs) if final_logprobs is not None else None
+            ),
+        )
+
         # return output
         return ModelOutput(
             model=self.model_name,
@@ -199,18 +225,94 @@ class HuggingFaceAPI(ModelAPI):
 
     def hf_chat(self, messages: list[ChatMessage], tools: list[ToolInfo]) -> str:
         # convert to hf format
-        hf_messages = chat_api_input(messages, tools)
+        tools_list = []
+        hf_messages = copy.deepcopy(messages)
+        if len(tools) > 0:
+            tools_list = [
+                json.loads(tool.model_dump_json(exclude_none=True, indent=2))
+                for tool in tools
+            ]
+            if "mistral" in self.model_name.lower():
+                hf_messages = shorten_tool_id(hf_messages)
+                tools_list = tools_to_mistral_format(tools_list)
+            elif "qwen" in self.model_name.lower():
+                hf_messages = inspect_tools_to_string(hf_messages)
+
         # apply chat template
         chat = self.tokenizer.apply_chat_template(
             hf_messages,
             add_generation_prompt=True,
             tokenize=False,
-            chat_template=self.chat_template,
+            tools=tools_list if len(tools_list) > 0 else None,
         )
         # return
         return cast(str, chat)
 
 
+def shorten_tool_id(messages: list[ChatMessage]) -> list[ChatMessage]:
+    """Shorten the tool_call_id in the messages to the last 9 characters for Mistral."""
+    for i, message in enumerate(messages):
+        if message.role == "tool":
+            # Trim tool_call_id in tool messages
+            if message.tool_call_id is not None:
+                message.tool_call_id = message.tool_call_id[-9:]
+        elif message.role == "assistant" and hasattr(message, "tool_calls"):
+            # Trim tool_call IDs inside tool_calls for assistant messages
+            for tool_call in message.tool_calls or []:
+                tool_call.id = tool_call.id[-9:]
+    return messages
+
+
+def tools_to_mistral_format(tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
+    """Convert tools to the format required for Mistral."""
+    mistral_tools = []
+    for tool in tools:
+        mistral_tools.append(
+            {
+                "function": {
+                    "name": tool["name"],
+                    "description": tool["description"],
+                    "parameters": {
+                        "type": tool["parameters"]["type"],
+                        "properties": tool["parameters"]["properties"],
+                        "required": tool["parameters"]["required"],
+                    },
+                }
+            }
+        )
+    return mistral_tools
+
+
+def inspect_tools_to_string(messages: list[ChatMessage]) -> list[ChatMessage]:
+    """Convert tools to a string for Qwen."""
+    for message in messages:
+        if message.role == "assistant":
+            # check if the message contains a tool call
+            tool_content = ""
+            if message.tool_calls:
+                for tool_call in message.tool_calls:
+                    tool_content += f'\n```json\n{{"name": "{tool_call.function}", "arguments": {json.dumps(tool_call.arguments)}}}\n```'
+                # remove the tool call from the message
+                message.tool_calls = None
+            if isinstance(message.content, str):
+                message.content += tool_content
+            else:
+                message.content.append(ContentText(text=tool_content))
+    return messages
+
+
+def chat_completion_assistant_message(
+    response: Any,
+    tools: list[ToolInfo],
+    handler: ChatAPIHandler | None,
+    model_name: str,
+) -> ChatMessageAssistant:
+    if handler:
+        return handler.parse_assistant_response(response.output, tools)
+    else:
+        return ChatMessageAssistant(content=response.output, source="generate")
+
+
 def set_random_seeds(seed: int | None = None) -> None:
     if seed is None:
         seed = np.random.default_rng().integers(2**32 - 1)
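
In `hf_chat`, tool schemas are now passed to `apply_chat_template` via its standard `tools=` argument (supported by recent transformers releases) rather than a custom `chat_template`. A standalone sketch of that call, with an illustrative model and tool schema:

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")
    tools = [{
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Look up current weather for a city.",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }]
    prompt = tokenizer.apply_chat_template(
        [{"role": "user", "content": "What's the weather in Paris?"}],
        tools=tools,
        add_generation_prompt=True,
        tokenize=False,
    )
    print(prompt)  # tool schema rendered by the model's own chat template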

inspect_ai/model/_providers/util/__init__.py
@@ -5,6 +5,7 @@ from .chatapi import (
     chat_api_request,
     is_chat_api_rate_limit,
 )
+from .hf_handler import HFHandler
 from .llama31 import Llama31Handler
 from .util import (
     as_stop_reason,
@@ -26,4 +27,5 @@ __all__ = [
     "ChatAPIHandler",
     "ChatAPIMessage",
     "Llama31Handler",
+    "HFHandler",
 ]

inspect_ai/model/_providers/util/hf_handler.py
@@ -0,0 +1,200 @@
+import json
+import re
+from logging import getLogger
+
+from shortuuid import uuid
+from typing_extensions import override
+
+from inspect_ai.tool._tool_call import ToolCall
+from inspect_ai.tool._tool_info import ToolInfo
+
+from ..._chat_message import ChatMessageAssistant
+from .chatapi import ChatAPIHandler
+from .util import parse_tool_call, tool_parse_error_message
+
+logger = getLogger(__name__)
+
+
+# Hugging Face handler currently supports Llama, Mistral and Qwen models, but will
+# work with any model that uses the same tool calling conventions
+
+
+class HFHandler(ChatAPIHandler):
+    def __init__(self, model_name: str) -> None:
+        self.model_name = model_name
+
+    @override
+    def parse_assistant_response(
+        self, response: str, tools: list[ToolInfo]
+    ) -> ChatMessageAssistant:
+        """Parse content and tool calls from a model response.
+
+        This method has an interdependency with `input_with_tools()` (as that is the
+        prompt that asks the model to use the <tool_call>...</tool_call> syntax)
+        """
+        # extract tool calls
+        content, tool_calls_content = model_specific_tool_parse(
+            response, self.model_name
+        )
+        # if there are tool calls proceed with parsing
+        if len(tool_calls_content) > 0:
+            # parse each tool call (any parsing errors that occur
+            # will be reported in the `parse_error` field of the ToolCall
+            # and ultimately reported back to the model)
+            tool_calls = [
+                parse_tool_call_content(content, tools)
+                for content in tool_calls_content
+            ]
+
+            # return the message
+            return ChatMessageAssistant(
+                content=content,
+                tool_calls=tool_calls,
+                source="generate",
+            )
+
+        # otherwise this is just an ordinary assistant message
+        else:
+            return ChatMessageAssistant(
+                content=filter_assistant_header(response), source="generate"
+            )
+
+
+def parse_tool_call_content(content: str, tools: list[ToolInfo]) -> ToolCall:
+    """Attempt to parse content from inside <tool_call> tags.
+
+    Content inside a <tool_call> should be a JSON dictionary with `name` and
+    `arguments` (which in turn should be a `dict[str,Any]` but in some cases
+    we've seen models pass `str`). This function attempts to extract this from
+    the passed content. A `ToolCall` is returned for all cases (if the
+    parsing fails then it will have a `parse_error`, which will be subsequently
+    reported to the model.)
+    """
+    try:
+        # parse raw JSON
+        tool_call_data = json.loads(content)
+        if "parameters" in tool_call_data:
+            tool_call_data["arguments"] = tool_call_data.pop("parameters")
+
+        # if it's not a dict then report an error
+        if not isinstance(tool_call_data, dict):
+            raise ValueError("The provided arguments are not a JSON dictionary.")
+
+        # see if we can get the fields (if not report error)
+        name = tool_call_data.get("name", None)
+        arguments = tool_call_data.get("arguments", None)
+        if not name or not arguments:
+            raise ValueError(
+                "Required 'name' and 'arguments' not provided in JSON dictionary."
+            )
+
+        # now perform the parse (we need to call this function because it includes
+        # the special handling for mapping arguments that are a plain `str`
+        # to the first parameter of the function)
+        unique_id = f"{name}_{uuid()}"
+        return parse_tool_call(unique_id, name, json.dumps(arguments), tools)
+
+    except Exception as ex:
+        # build error message
+        parse_error = tool_parse_error_message(content, ex)
+
+        # log it to 'info'
+        logger.info(parse_error)
+
+        # notify model
+        return ToolCall(
+            id="unknown",
+            function="unknown",
+            arguments={},
+            type="function",
+            parse_error=parse_error,
+        )
+
+
+def model_specific_tool_parse(response: str, model_name: str) -> tuple[str, list[str]]:
+    model_name = model_name.lower()
+
+    if "llama" in model_name:
+        if "name" in response and ("parameters" in response or "arguments" in response):
+            function_calls, content = json_extract_raw(response)
+        else:
+            content = response
+            function_calls = []
+    elif "mistral" in model_name:
+        if "name" in response and "arguments" in response:
+            content = ""
+            function_calls = [json.dumps(tool) for tool in json.loads(response)]
+        else:
+            content = response
+            function_calls = []
+    elif "qwen" in model_name and "coder" in model_name:
+        if "name" in response and "arguments" in response:
+            function_calls, content = json_extract(response)
+        else:
+            content = response
+            function_calls = []
+    elif "qwen" in model_name and "instruct" in model_name:
+        if "name" in response and "arguments" in response:
+            function_calls, content = xml_extract(response, "tool_call")
+        else:
+            content = response
+            function_calls = []
+    else:
+        try:
+            function_calls, content = parse_unknown_tool_calls(response)
+        except Exception:
+            raise ValueError(
+                f"Unsupported model: {model_name}. No tool parsing implemented. Check if any of the current parsings work with your tool calling conventions and add the model name to the correct elif block."
+            )
+    return content, function_calls
+
+
+def json_extract(raw_string: str) -> tuple[list[str], str]:
+    """Extract tools in form ```json{...}``` and return the remaining content."""
+    function_calls = re.findall(r"```json\s*(\{.*?\})\s*```", raw_string, re.DOTALL)
+
+    remaining_content = re.sub(
+        r"```json\s*\{.*?\}\s*```", "", raw_string, flags=re.DOTALL
+    ).strip()
+
+    return function_calls, remaining_content
+
+
+def json_extract_raw(raw_string: str) -> tuple[list[str], str]:
+    """Extract tools in form `{...}` and return the remaining content."""
+    # Regex to extract sequences starting with '{' and ending with '}}'
+    json_like_regex = r"\{.*?\}\}"
+    function_calls = re.findall(json_like_regex, raw_string)
+    remaining_content = re.sub(json_like_regex, "", raw_string).strip()
+
+    return function_calls, remaining_content
+
+
+def xml_extract(raw_string: str, tag: str) -> tuple[list[str], str]:
+    """Extract tools in form <tag>{...}</tag> and return the remaining content."""
+    tool_call_regex = rf"<{tag}>((?:.|\n)*?)</{tag}>"
+    function_calls = re.findall(tool_call_regex, raw_string)
+    tool_call_content_regex = rf"<{tag}>(?:.|\n)*?</{tag}>"
+    other_content = re.split(tool_call_content_regex, raw_string, flags=re.DOTALL)
+    other_content = [
+        str(content).strip() for content in other_content if str(content).strip()
+    ]
+    content = "\n\n".join(other_content)
+    return function_calls, content
+
+
+def parse_unknown_tool_calls(response: str) -> tuple[list[str], str]:
+    if "```json" in response:
+        return json_extract(response)
+    elif "<tool_call>" in response:
+        return xml_extract(response, "tool_call")
+    elif "<function>" in response:
+        return xml_extract(response, "function")
+    elif "{" in response and "}}" in response:
+        return json_extract_raw(response)
+    else:
+        return [], response
+
+
+def filter_assistant_header(message: str) -> str:
+    return re.sub(r"<\|start_header_id\|>assistant<\|end_header_id\|>", "", message)
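
The extraction helpers above are pure functions, so their behavior is easy to check directly; for example, for the Qwen-instruct `<tool_call>` convention:

    response = (
        "Checking the weather now.\n"
        '<tool_call>{"name": "get_weather", "arguments": {"city": "Paris"}}</tool_call>'
    )
    calls, content = xml_extract(response, "tool_call")
    print(calls)    # ['{"name": "get_weather", "arguments": {"city": "Paris"}}']
    print(content)  # Checking the weather now.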

inspect_ai/scorer/_common.py
@@ -61,7 +61,7 @@ def match_str(
     if ignore_case:
         v = v.casefold()
         t = t.casefold()
-    if numeric:
+    if numeric and t.isnumeric():
         # remove punctuation
         v = strip_numeric_punctuation(v)
         t = strip_numeric_punctuation(t)
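
The added `t.isnumeric()` guard means numeric normalization only applies when the (already casefolded) target is itself numeric; non-numeric targets now fall through to a plain string match:

    print("42".isnumeric())        # True  -> numeric comparison path
    print("about 42".isnumeric())  # False -> plain string match path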

inspect_ai/solver/_plan.py
@@ -57,7 +57,6 @@ class Plan(Solver):
 
         self.finish = finish
         self.cleanup = cleanup
-        self.progress: Callable[[], None] = lambda: None
         self._name = name
 
         if not internal:
@@ -106,14 +105,8 @@ class Plan(Solver):
                 state = await solver(state, generate)
                 st.complete(state)
 
-            # tick progress
-            self.progress()
-
             # check for completed
             if state.completed:
-                # tick rest of progress
-                for _ in range(index + 1, len(self.steps)):
-                    self.progress()
                 # exit loop
                 break
 
@@ -122,7 +115,6 @@ class Plan(Solver):
             with solver_transcript(self.finish, state) as st:
                 state = await self.finish(state, generate)
                 st.complete(state)
-            self.progress()
 
         # mark completed
        state.completed = True