inspect-ai 0.3.59__py3-none-any.whl → 0.3.61__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
Files changed (88)
  1. inspect_ai/_cli/eval.py +0 -8
  2. inspect_ai/_display/textual/widgets/samples.py +1 -1
  3. inspect_ai/_eval/eval.py +10 -1
  4. inspect_ai/_eval/loader.py +79 -19
  5. inspect_ai/_eval/registry.py +6 -0
  6. inspect_ai/_eval/score.py +2 -1
  7. inspect_ai/_eval/task/generate.py +41 -35
  8. inspect_ai/_eval/task/results.py +6 -5
  9. inspect_ai/_eval/task/run.py +21 -15
  10. inspect_ai/_util/hooks.py +17 -7
  11. inspect_ai/_view/www/dist/assets/index.js +262 -303
  12. inspect_ai/_view/www/package.json +1 -1
  13. inspect_ai/_view/www/src/App.mjs +6 -6
  14. inspect_ai/_view/www/src/Types.mjs +1 -1
  15. inspect_ai/_view/www/src/api/Types.ts +133 -0
  16. inspect_ai/_view/www/src/api/{api-browser.mjs → api-browser.ts} +25 -13
  17. inspect_ai/_view/www/src/api/api-http.ts +219 -0
  18. inspect_ai/_view/www/src/api/api-shared.ts +47 -0
  19. inspect_ai/_view/www/src/api/{api-vscode.mjs → api-vscode.ts} +22 -19
  20. inspect_ai/_view/www/src/api/{client-api.mjs → client-api.ts} +93 -53
  21. inspect_ai/_view/www/src/api/index.ts +51 -0
  22. inspect_ai/_view/www/src/api/jsonrpc.ts +225 -0
  23. inspect_ai/_view/www/src/components/DownloadButton.mjs +1 -1
  24. inspect_ai/_view/www/src/index.js +2 -2
  25. inspect_ai/_view/www/src/log/{remoteLogFile.mjs → remoteLogFile.ts} +62 -46
  26. inspect_ai/_view/www/src/navbar/Navbar.mjs +1 -1
  27. inspect_ai/_view/www/src/navbar/SecondaryBar.mjs +1 -1
  28. inspect_ai/_view/www/src/samples/SampleList.mjs +1 -1
  29. inspect_ai/_view/www/src/samples/SampleScores.mjs +1 -1
  30. inspect_ai/_view/www/src/samples/SamplesDescriptor.mjs +14 -14
  31. inspect_ai/_view/www/src/samples/SamplesTab.mjs +10 -10
  32. inspect_ai/_view/www/src/samples/tools/SortFilter.mjs +2 -2
  33. inspect_ai/_view/www/src/utils/{Json.mjs → json-worker.ts} +1 -3
  34. inspect_ai/_view/www/src/utils/vscode.ts +36 -0
  35. inspect_ai/_view/www/src/workspace/WorkSpace.mjs +1 -1
  36. inspect_ai/approval/_human/manager.py +1 -1
  37. inspect_ai/model/_call_tools.py +55 -0
  38. inspect_ai/model/_chat_message.py +2 -2
  39. inspect_ai/model/_conversation.py +1 -4
  40. inspect_ai/model/_generate_config.py +2 -8
  41. inspect_ai/model/_model.py +90 -25
  42. inspect_ai/model/_model_output.py +15 -0
  43. inspect_ai/model/_openai.py +383 -0
  44. inspect_ai/model/_providers/anthropic.py +52 -14
  45. inspect_ai/model/_providers/azureai.py +1 -1
  46. inspect_ai/model/_providers/goodfire.py +248 -0
  47. inspect_ai/model/_providers/groq.py +7 -3
  48. inspect_ai/model/_providers/hf.py +6 -0
  49. inspect_ai/model/_providers/mistral.py +2 -1
  50. inspect_ai/model/_providers/openai.py +36 -202
  51. inspect_ai/model/_providers/openai_o1.py +2 -4
  52. inspect_ai/model/_providers/providers.py +22 -0
  53. inspect_ai/model/_providers/together.py +4 -4
  54. inspect_ai/model/_providers/util/__init__.py +2 -3
  55. inspect_ai/model/_providers/util/hf_handler.py +1 -1
  56. inspect_ai/model/_providers/util/llama31.py +1 -1
  57. inspect_ai/model/_providers/util/util.py +0 -76
  58. inspect_ai/scorer/_metric.py +3 -0
  59. inspect_ai/scorer/_scorer.py +2 -1
  60. inspect_ai/solver/__init__.py +4 -0
  61. inspect_ai/solver/_basic_agent.py +65 -55
  62. inspect_ai/solver/_bridge/__init__.py +3 -0
  63. inspect_ai/solver/_bridge/bridge.py +100 -0
  64. inspect_ai/solver/_bridge/patch.py +170 -0
  65. inspect_ai/{util → solver}/_limit.py +13 -0
  66. inspect_ai/solver/_solver.py +6 -0
  67. inspect_ai/solver/_task_state.py +37 -7
  68. inspect_ai/tool/_tools/_web_browser/_web_browser.py +3 -1
  69. inspect_ai/tool/beta/_computer/_resources/Dockerfile +1 -3
  70. inspect_ai/tool/beta/_computer/_resources/entrypoint/x11vnc_startup.sh +1 -1
  71. inspect_ai/tool/beta/_computer/_resources/image_home_dir/.config/xfce4/xfconf/xfce-perchannel-xml/xfce4-screensaver.xml +10 -0
  72. inspect_ai/util/__init__.py +0 -2
  73. inspect_ai/util/_display.py +5 -0
  74. inspect_ai/util/_sandbox/docker/prereqs.py +1 -1
  75. inspect_ai/util/_sandbox/self_check.py +51 -28
  76. {inspect_ai-0.3.59.dist-info → inspect_ai-0.3.61.dist-info}/METADATA +3 -2
  77. {inspect_ai-0.3.59.dist-info → inspect_ai-0.3.61.dist-info}/RECORD +81 -76
  78. inspect_ai/_view/www/src/api/Types.mjs +0 -117
  79. inspect_ai/_view/www/src/api/api-http.mjs +0 -300
  80. inspect_ai/_view/www/src/api/api-shared.mjs +0 -10
  81. inspect_ai/_view/www/src/api/index.mjs +0 -49
  82. inspect_ai/_view/www/src/api/jsonrpc.mjs +0 -208
  83. inspect_ai/_view/www/src/utils/vscode.mjs +0 -16
  84. inspect_ai/tool/beta/_computer/_resources/image_home_dir/Desktop/XPaint.desktop +0 -10
  85. {inspect_ai-0.3.59.dist-info → inspect_ai-0.3.61.dist-info}/LICENSE +0 -0
  86. {inspect_ai-0.3.59.dist-info → inspect_ai-0.3.61.dist-info}/WHEEL +0 -0
  87. {inspect_ai-0.3.59.dist-info → inspect_ai-0.3.61.dist-info}/entry_points.txt +0 -0
  88. {inspect_ai-0.3.59.dist-info → inspect_ai-0.3.61.dist-info}/top_level.txt +0 -0

inspect_ai/model/_providers/anthropic.py
@@ -14,10 +14,13 @@ from anthropic import (
     APIConnectionError,
     AsyncAnthropic,
     AsyncAnthropicBedrock,
+    AsyncAnthropicVertex,
     BadRequestError,
     InternalServerError,
+    NotGiven,
     RateLimitError,
 )
+from anthropic._types import Body
 from anthropic.types import (
     ImageBlockParam,
     Message,
@@ -64,15 +67,25 @@ class AnthropicAPI(ModelAPI):
         base_url: str | None = None,
         api_key: str | None = None,
         config: GenerateConfig = GenerateConfig(),
-        bedrock: bool = False,
         **model_args: Any,
     ):
         # extract any service prefix from model name
         parts = model_name.split("/")
         if len(parts) > 1:
-            service = parts[0]
-            bedrock = service == "bedrock"
+            self.service: str | None = parts[0]
             model_name = "/".join(parts[1:])
+        else:
+            self.service = None
+
+        # collect generate model_args (then delete them so we can pass the rest on)
+        def collect_model_arg(name: str) -> Any | None:
+            nonlocal model_args
+            value = model_args.get(name, None)
+            if value is not None:
+                model_args.pop(name)
+            return value
+
+        self.extra_body: Body | None = collect_model_arg("extra_body")

         # call super
         super().__init__(
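
The `collect_model_arg` helper above pops provider-specific keys out of `**model_args` before the remainder is forwarded to the Anthropic client constructor. A minimal usage sketch of the new `extra_body` pass-through (the payload shown is hypothetical, purely for illustration):

    from inspect_ai.model import get_model

    # "extra_body" is collected by AnthropicAPI at construction time and
    # attached to each messages.create() request (see the later hunk);
    # all other model_args still flow through to the client constructor.
    model = get_model(
        "anthropic/claude-3-5-sonnet-20241022",
        extra_body={"metadata": {"user_id": "example-user"}},  # hypothetical payload
    )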
@@ -84,7 +97,7 @@ class AnthropicAPI(ModelAPI):
         )

         # create client
-        if bedrock:
+        if self.is_bedrock():
             base_url = model_base_url(
                 base_url, ["ANTHROPIC_BEDROCK_BASE_URL", "BEDROCK_ANTHROPIC_BASE_URL"]
             )
@@ -95,7 +108,9 @@ class AnthropicAPI(ModelAPI):
             if base_region is None:
                 aws_region = os.environ.get("AWS_DEFAULT_REGION", None)

-            self.client: AsyncAnthropic | AsyncAnthropicBedrock = AsyncAnthropicBedrock(
+            self.client: (
+                AsyncAnthropic | AsyncAnthropicBedrock | AsyncAnthropicVertex
+            ) = AsyncAnthropicBedrock(
                 base_url=base_url,
                 max_retries=(
                     config.max_retries if config.max_retries else DEFAULT_MAX_RETRIES
@@ -103,6 +118,21 @@ class AnthropicAPI(ModelAPI):
                 aws_region=aws_region,
                 **model_args,
             )
+        elif self.is_vertex():
+            base_url = model_base_url(
+                base_url, ["ANTHROPIC_VERTEX_BASE_URL", "VERTEX_ANTHROPIC_BASE_URL"]
+            )
+            region = os.environ.get("ANTHROPIC_VERTEX_REGION", NotGiven())
+            project_id = os.environ.get("ANTHROPIC_VERTEX_PROJECT_ID", NotGiven())
+            self.client = AsyncAnthropicVertex(
+                region=region,
+                project_id=project_id,
+                base_url=base_url,
+                max_retries=(
+                    config.max_retries if config.max_retries else DEFAULT_MAX_RETRIES
+                ),
+                **model_args,
+            )
         else:
             # resolve api_key
             if not self.api_key:
@@ -119,6 +149,12 @@ class AnthropicAPI(ModelAPI):
                 **model_args,
             )

+    def is_bedrock(self) -> bool:
+        return self.service == "bedrock"
+
+    def is_vertex(self) -> bool:
+        return self.service == "vertex"
+
     async def generate(
         self,
         input: list[ChatMessage],
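
With the prefix parsed into `self.service`, the backing client is chosen by model name rather than a constructor flag. A sketch of the naming scheme (the model IDs are illustrative; Vertex additionally reads `ANTHROPIC_VERTEX_REGION` and `ANTHROPIC_VERTEX_PROJECT_ID` from the environment, per the hunk above):

    from inspect_ai.model import get_model

    # a service segment between provider and model id selects the client class
    bedrock = get_model("anthropic/bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0")
    vertex = get_model("anthropic/vertex/claude-3-5-sonnet-v2@20241022")
    direct = get_model("anthropic/claude-3-5-sonnet-20241022")  # first-party API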
@@ -163,6 +199,10 @@ class AnthropicAPI(ModelAPI):
         if computer_use:
             request["extra_headers"] = {"anthropic-beta": "computer-use-2024-10-22"}

+        # extra_body
+        if self.extra_body is not None:
+            request["extra_body"] = self.extra_body
+
         # make request
         message = await self.client.messages.create(**request, stream=False)
@@ -251,9 +291,6 @@ class AnthropicAPI(ModelAPI):
         elif "content filtering" in error:
             content = "Sorry, but I am unable to help with that request."
             stop_reason = "content_filter"
-        else:
-            content = error
-            stop_reason = "unknown"

         if content and stop_reason:
             return ModelOutput.from_content(
@@ -466,6 +503,12 @@ def message_tool_choice(tool_choice: ToolChoice) -> message_create_params.ToolCh


 async def message_param(message: ChatMessage) -> MessageParam:
+    # if content is empty that is going to result in an error when we replay
+    # this message to claude, so in that case insert a NO_CONTENT message
+    if isinstance(message.content, list) and len(message.content) == 0:
+        message = message.model_copy()
+        message.content = [ContentText(text=NO_CONTENT)]
+
     # no system role for anthropic (this is more like an assertion,
     # as these should have already been filtered out)
     if message.role == "system":
@@ -507,7 +550,7 @@ async def message_param(message: ChatMessage) -> MessageParam:
     elif message.role == "assistant" and message.tool_calls:
         # first include content (claude <thinking>)
         tools_content: list[TextBlockParam | ImageBlockParam | ToolUseBlockParam] = (
-            [TextBlockParam(type="text", text=message.content)]
+            [TextBlockParam(type="text", text=message.content or NO_CONTENT)]
             if isinstance(message.content, str)
             else (
                 [(await message_param_content(content)) for content in message.content]
@@ -576,11 +619,6 @@ def model_output_from_message(message: Message, tools: list[ToolInfo]) -> ModelOutput:
             )
         )

-    # if content is empty that is going to result in an error when we replay
-    # this message to claude, so in that case insert a NO_CONTENT message
-    if len(content) == 0:
-        content = [ContentText(text=NO_CONTENT)]
-
     # resolve choice
     choice = ChatCompletionChoice(
         message=ChatMessageAssistant(

inspect_ai/model/_providers/azureai.py
@@ -37,6 +37,7 @@ from inspect_ai.tool import ToolChoice, ToolInfo
 from inspect_ai.tool._tool_call import ToolCall
 from inspect_ai.tool._tool_choice import ToolFunction

+from .._call_tools import parse_tool_call
 from .._chat_message import (
     ChatMessage,
     ChatMessageAssistant,
@@ -60,7 +61,6 @@ from .util import (
 )
 from .util.chatapi import ChatAPIHandler
 from .util.llama31 import Llama31Handler
-from .util.util import parse_tool_call

 AZUREAI_API_KEY = "AZUREAI_API_KEY"
 AZUREAI_ENDPOINT_KEY = "AZUREAI_ENDPOINT_KEY"
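
This is one of three providers (azureai, groq, mistral) that swap the same import in this release: `parse_tool_call` moves from `model/_providers/util/util.py` into the shared `model/_call_tools.py` (files 37 and 57 in the list above). A hedged sketch of the call shape, with the signature assumed from its usage in these providers:

    from inspect_ai.model._call_tools import parse_tool_call

    # parse a provider's raw tool invocation into an inspect ToolCall,
    # validating the JSON arguments against the declared tools
    call = parse_tool_call(
        id="call_0",                 # provider-assigned call id
        function="submit",           # tool name returned by the model
        arguments='{"answer": 42}',  # raw JSON arguments string
        tools=[],                    # list[ToolInfo]; empty here for illustration
    )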

inspect_ai/model/_providers/goodfire.py (new file)
@@ -0,0 +1,248 @@
+import os
+from typing import Any, List, Literal, get_args
+
+from goodfire import AsyncClient
+from goodfire.api.chat.interfaces import ChatMessage as GoodfireChatMessage
+from goodfire.api.exceptions import InvalidRequestException, RateLimitException
+from goodfire.variants.variants import SUPPORTED_MODELS, Variant
+from typing_extensions import override
+
+from inspect_ai.tool._tool_choice import ToolChoice
+from inspect_ai.tool._tool_info import ToolInfo
+
+from .._chat_message import (
+    ChatMessage,
+    ChatMessageAssistant,
+    ChatMessageSystem,
+    ChatMessageTool,
+    ChatMessageUser,
+)
+from .._generate_config import GenerateConfig
+from .._model import ModelAPI
+from .._model_call import ModelCall
+from .._model_output import (
+    ChatCompletionChoice,
+    ModelOutput,
+    ModelUsage,
+)
+from .util import environment_prerequisite_error, model_base_url
+
+# Constants
+GOODFIRE_API_KEY = "GOODFIRE_API_KEY"
+DEFAULT_BASE_URL = "https://api.goodfire.ai"
+DEFAULT_MAX_TOKENS = 4096
+DEFAULT_TEMPERATURE = 1.0  # Standard sampling temperature (baseline)
+DEFAULT_TOP_P = 1.0  # No nucleus sampling truncation (baseline)
+
+
+class GoodfireAPI(ModelAPI):
+    """Goodfire API provider.
+
+    This provider implements the Goodfire API for LLM inference. It supports:
+    - Chat completions with standard message formats
+    - Basic parameter controls (temperature, top_p, etc.)
+    - Usage statistics tracking
+    - Stop reason handling
+
+    Does not currently support:
+    - Tool calls
+    - Feature analysis
+    - Streaming responses
+
+    Known limitations:
+    - Limited role support (system/user/assistant only)
+    - Tool messages converted to user messages
+    """
+
+    client: AsyncClient
+    variant: Variant
+    model_args: dict[str, Any]
+
+    def __init__(
+        self,
+        model_name: str,
+        base_url: str | None = None,
+        api_key: str | None = None,
+        config: GenerateConfig = GenerateConfig(),
+        **model_args: Any,
+    ) -> None:
+        """Initialize the Goodfire API provider.
+
+        Args:
+            model_name: Name of the model to use
+            base_url: Optional custom API base URL
+            api_key: Optional API key (will check env vars if not provided)
+            config: Generation config options
+            **model_args: Additional arguments passed to the API
+        """
+        super().__init__(
+            model_name=model_name,
+            base_url=base_url,
+            api_key=api_key,
+            api_key_vars=[GOODFIRE_API_KEY],
+            config=config,
+        )
+
+        # resolve api_key
+        if not self.api_key:
+            self.api_key = os.environ.get(GOODFIRE_API_KEY)
+        if not self.api_key:
+            raise environment_prerequisite_error("Goodfire", GOODFIRE_API_KEY)
+
+        # Validate model name against supported models
+        supported_models = list(get_args(SUPPORTED_MODELS))
+        if self.model_name not in supported_models:
+            raise ValueError(
+                f"Model {self.model_name} not supported. Supported models: {supported_models}"
+            )
+
+        # Initialize client with minimal configuration
+        base_url_val = model_base_url(base_url, "GOODFIRE_BASE_URL")
+        assert isinstance(base_url_val, str) or base_url_val is None
+
+        # Store model args for use in generate
+        self.model_args = model_args
+
+        self.client = AsyncClient(
+            api_key=self.api_key,
+            base_url=base_url_val or DEFAULT_BASE_URL,
+        )
+
+        # Initialize variant directly with model name
+        self.variant = Variant(self.model_name)  # type: ignore
+
+    def _to_goodfire_message(self, message: ChatMessage) -> GoodfireChatMessage:
+        """Convert an Inspect message to a Goodfire message format.
+
+        Args:
+            message: The message to convert
+
+        Returns:
+            The converted message in Goodfire format
+
+        Raises:
+            ValueError: If the message type is unknown
+        """
+        role: Literal["system", "user", "assistant"] = "user"
+        if isinstance(message, ChatMessageSystem):
+            role = "system"
+        elif isinstance(message, ChatMessageUser):
+            role = "user"
+        elif isinstance(message, ChatMessageAssistant):
+            role = "assistant"
+        elif isinstance(message, ChatMessageTool):
+            role = "user"  # Convert tool messages to user messages
+        else:
+            raise ValueError(f"Unknown message type: {type(message)}")
+
+        content = str(message.content)
+        if isinstance(message, ChatMessageTool):
+            content = f"Tool {message.function}: {content}"
+
+        return GoodfireChatMessage(role=role, content=content)
+
+    def handle_error(self, ex: Exception) -> ModelOutput | Exception:
+        """Handle only errors that need special treatment for retry logic or model limits."""
+        # Handle token/context length errors
+        if isinstance(ex, InvalidRequestException):
+            error_msg = str(ex).lower()
+            if "context length" in error_msg or "max tokens" in error_msg:
+                return ModelOutput.from_content(
+                    model=self.model_name,
+                    content=str(ex),
+                    stop_reason="model_length",
+                    error=error_msg,
+                )
+
+        # Let all other errors propagate
+        return ex
+
+    @override
+    def is_rate_limit(self, ex: BaseException) -> bool:
+        """Check if exception is due to rate limiting."""
+        return isinstance(ex, RateLimitException)
+
+    @override
+    def connection_key(self) -> str:
+        """Return key for connection pooling."""
+        return f"goodfire:{self.api_key}"
+
+    @override
+    def max_tokens(self) -> int | None:
+        """Return maximum tokens supported by model."""
+        return DEFAULT_MAX_TOKENS  # Let Goodfire's Variant handle model-specific limits
+
+    async def generate(
+        self,
+        input: List[ChatMessage],
+        tools: List[ToolInfo],
+        tool_choice: ToolChoice,
+        config: GenerateConfig,
+        *,
+        cache: bool = True,
+    ) -> tuple[ModelOutput | Exception, ModelCall]:
+        """Generate output from the model."""
+        # Convert messages and prepare request params
+        messages = [self._to_goodfire_message(msg) for msg in input]
+        # Build request parameters with type hints
+        params: dict[str, Any] = {
+            "model": self.variant.base_model,  # Use base_model instead of stringifying the Variant
+            "messages": messages,
+            "max_completion_tokens": int(config.max_tokens)
+            if config.max_tokens
+            else DEFAULT_MAX_TOKENS,
+            "stream": False,
+        }
+
+        # Add generation parameters from config if not in model_args
+        if "temperature" not in self.model_args and config.temperature is not None:
+            params["temperature"] = float(config.temperature)
+        elif "temperature" not in self.model_args:
+            params["temperature"] = DEFAULT_TEMPERATURE
+
+        if "top_p" not in self.model_args and config.top_p is not None:
+            params["top_p"] = float(config.top_p)
+        elif "top_p" not in self.model_args:
+            params["top_p"] = DEFAULT_TOP_P
+
+        # Add any additional model args (highest priority)
+        api_params = {
+            k: v
+            for k, v in self.model_args.items()
+            if k not in ["api_key", "base_url", "model_args"]
+        }
+        params.update(api_params)
+
+        try:
+            # Use native async client
+            response = await self.client.chat.completions.create(**params)
+            response_dict = response.model_dump()
+
+            output = ModelOutput(
+                model=self.model_name,
+                choices=[
+                    ChatCompletionChoice(
+                        message=ChatMessageAssistant(
+                            content=response_dict["choices"][0]["message"]["content"]
+                        ),
+                        stop_reason="stop",
+                    )
+                ],
+                usage=ModelUsage(**response_dict["usage"])
+                if "usage" in response_dict
+                else None,
+            )
+            model_call = ModelCall.create(request=params, response=response_dict)
+            return (output, model_call)
+        except Exception as ex:
+            result = self.handle_error(ex)
+            model_call = ModelCall.create(
+                request=params,
+                response={},  # Empty response for error case
+            )
+            return (result, model_call)
+
+    @property
+    def name(self) -> str:
+        """Get provider name."""
+        return "goodfire"

inspect_ai/model/_providers/groq.py
@@ -27,6 +27,7 @@ from inspect_ai._util.images import file_as_data_uri
 from inspect_ai._util.url import is_http_url
 from inspect_ai.tool import ToolCall, ToolChoice, ToolFunction, ToolInfo

+from .._call_tools import parse_tool_call
 from .._chat_message import (
     ChatMessage,
     ChatMessageAssistant,
@@ -37,12 +38,15 @@ from .._chat_message import (
 from .._generate_config import GenerateConfig
 from .._model import ModelAPI
 from .._model_call import ModelCall
-from .._model_output import ChatCompletionChoice, ModelOutput, ModelUsage
-from .util import (
+from .._model_output import (
+    ChatCompletionChoice,
+    ModelOutput,
+    ModelUsage,
     as_stop_reason,
+)
+from .util import (
     environment_prerequisite_error,
     model_base_url,
-    parse_tool_call,
 )

 GROQ_API_KEY = "GROQ_API_KEY"

inspect_ai/model/_providers/hf.py
@@ -150,6 +150,12 @@ class HuggingFaceAPI(ModelAPI):
             kwargs["output_logits"] = config.logprobs
         if "return_dict_in_generate" in kwargs:
             assert kwargs["return_dict_in_generate"]
+        if config.stop_seqs is not None:
+            from transformers.generation import StopStringCriteria  # type: ignore
+
+            stopping_criteria = [StopStringCriteria(self.tokenizer, config.stop_seqs)]
+            kwargs["stopping_criteria"] = stopping_criteria
+
         kwargs["return_dict_in_generate"] = True
         generator = functools.partial(self.model.generate, **kwargs)

inspect_ai/model/_providers/mistral.py
@@ -46,6 +46,7 @@ from inspect_ai._util.content import Content, ContentImage, ContentText
 from inspect_ai._util.images import file_as_data_uri
 from inspect_ai.tool import ToolCall, ToolChoice, ToolFunction, ToolInfo

+from .._call_tools import parse_tool_call
 from .._chat_message import (
     ChatMessage,
     ChatMessageAssistant,
@@ -59,7 +60,7 @@ from .._model_output import (
     ModelUsage,
     StopReason,
 )
-from .util import environment_prerequisite_error, model_base_url, parse_tool_call
+from .util import environment_prerequisite_error, model_base_url

 AZURE_MISTRAL_API_KEY = "AZURE_MISTRAL_API_KEY"
 AZUREAI_MISTRAL_API_KEY = "AZUREAI_MISTRAL_API_KEY"