inspect-ai 0.3.59__py3-none-any.whl → 0.3.61__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- inspect_ai/_cli/eval.py +0 -8
- inspect_ai/_display/textual/widgets/samples.py +1 -1
- inspect_ai/_eval/eval.py +10 -1
- inspect_ai/_eval/loader.py +79 -19
- inspect_ai/_eval/registry.py +6 -0
- inspect_ai/_eval/score.py +2 -1
- inspect_ai/_eval/task/generate.py +41 -35
- inspect_ai/_eval/task/results.py +6 -5
- inspect_ai/_eval/task/run.py +21 -15
- inspect_ai/_util/hooks.py +17 -7
- inspect_ai/_view/www/dist/assets/index.js +262 -303
- inspect_ai/_view/www/package.json +1 -1
- inspect_ai/_view/www/src/App.mjs +6 -6
- inspect_ai/_view/www/src/Types.mjs +1 -1
- inspect_ai/_view/www/src/api/Types.ts +133 -0
- inspect_ai/_view/www/src/api/{api-browser.mjs → api-browser.ts} +25 -13
- inspect_ai/_view/www/src/api/api-http.ts +219 -0
- inspect_ai/_view/www/src/api/api-shared.ts +47 -0
- inspect_ai/_view/www/src/api/{api-vscode.mjs → api-vscode.ts} +22 -19
- inspect_ai/_view/www/src/api/{client-api.mjs → client-api.ts} +93 -53
- inspect_ai/_view/www/src/api/index.ts +51 -0
- inspect_ai/_view/www/src/api/jsonrpc.ts +225 -0
- inspect_ai/_view/www/src/components/DownloadButton.mjs +1 -1
- inspect_ai/_view/www/src/index.js +2 -2
- inspect_ai/_view/www/src/log/{remoteLogFile.mjs → remoteLogFile.ts} +62 -46
- inspect_ai/_view/www/src/navbar/Navbar.mjs +1 -1
- inspect_ai/_view/www/src/navbar/SecondaryBar.mjs +1 -1
- inspect_ai/_view/www/src/samples/SampleList.mjs +1 -1
- inspect_ai/_view/www/src/samples/SampleScores.mjs +1 -1
- inspect_ai/_view/www/src/samples/SamplesDescriptor.mjs +14 -14
- inspect_ai/_view/www/src/samples/SamplesTab.mjs +10 -10
- inspect_ai/_view/www/src/samples/tools/SortFilter.mjs +2 -2
- inspect_ai/_view/www/src/utils/{Json.mjs → json-worker.ts} +1 -3
- inspect_ai/_view/www/src/utils/vscode.ts +36 -0
- inspect_ai/_view/www/src/workspace/WorkSpace.mjs +1 -1
- inspect_ai/approval/_human/manager.py +1 -1
- inspect_ai/model/_call_tools.py +55 -0
- inspect_ai/model/_chat_message.py +2 -2
- inspect_ai/model/_conversation.py +1 -4
- inspect_ai/model/_generate_config.py +2 -8
- inspect_ai/model/_model.py +90 -25
- inspect_ai/model/_model_output.py +15 -0
- inspect_ai/model/_openai.py +383 -0
- inspect_ai/model/_providers/anthropic.py +52 -14
- inspect_ai/model/_providers/azureai.py +1 -1
- inspect_ai/model/_providers/goodfire.py +248 -0
- inspect_ai/model/_providers/groq.py +7 -3
- inspect_ai/model/_providers/hf.py +6 -0
- inspect_ai/model/_providers/mistral.py +2 -1
- inspect_ai/model/_providers/openai.py +36 -202
- inspect_ai/model/_providers/openai_o1.py +2 -4
- inspect_ai/model/_providers/providers.py +22 -0
- inspect_ai/model/_providers/together.py +4 -4
- inspect_ai/model/_providers/util/__init__.py +2 -3
- inspect_ai/model/_providers/util/hf_handler.py +1 -1
- inspect_ai/model/_providers/util/llama31.py +1 -1
- inspect_ai/model/_providers/util/util.py +0 -76
- inspect_ai/scorer/_metric.py +3 -0
- inspect_ai/scorer/_scorer.py +2 -1
- inspect_ai/solver/__init__.py +4 -0
- inspect_ai/solver/_basic_agent.py +65 -55
- inspect_ai/solver/_bridge/__init__.py +3 -0
- inspect_ai/solver/_bridge/bridge.py +100 -0
- inspect_ai/solver/_bridge/patch.py +170 -0
- inspect_ai/{util → solver}/_limit.py +13 -0
- inspect_ai/solver/_solver.py +6 -0
- inspect_ai/solver/_task_state.py +37 -7
- inspect_ai/tool/_tools/_web_browser/_web_browser.py +3 -1
- inspect_ai/tool/beta/_computer/_resources/Dockerfile +1 -3
- inspect_ai/tool/beta/_computer/_resources/entrypoint/x11vnc_startup.sh +1 -1
- inspect_ai/tool/beta/_computer/_resources/image_home_dir/.config/xfce4/xfconf/xfce-perchannel-xml/xfce4-screensaver.xml +10 -0
- inspect_ai/util/__init__.py +0 -2
- inspect_ai/util/_display.py +5 -0
- inspect_ai/util/_sandbox/docker/prereqs.py +1 -1
- inspect_ai/util/_sandbox/self_check.py +51 -28
- {inspect_ai-0.3.59.dist-info → inspect_ai-0.3.61.dist-info}/METADATA +3 -2
- {inspect_ai-0.3.59.dist-info → inspect_ai-0.3.61.dist-info}/RECORD +81 -76
- inspect_ai/_view/www/src/api/Types.mjs +0 -117
- inspect_ai/_view/www/src/api/api-http.mjs +0 -300
- inspect_ai/_view/www/src/api/api-shared.mjs +0 -10
- inspect_ai/_view/www/src/api/index.mjs +0 -49
- inspect_ai/_view/www/src/api/jsonrpc.mjs +0 -208
- inspect_ai/_view/www/src/utils/vscode.mjs +0 -16
- inspect_ai/tool/beta/_computer/_resources/image_home_dir/Desktop/XPaint.desktop +0 -10
- {inspect_ai-0.3.59.dist-info → inspect_ai-0.3.61.dist-info}/LICENSE +0 -0
- {inspect_ai-0.3.59.dist-info → inspect_ai-0.3.61.dist-info}/WHEEL +0 -0
- {inspect_ai-0.3.59.dist-info → inspect_ai-0.3.61.dist-info}/entry_points.txt +0 -0
- {inspect_ai-0.3.59.dist-info → inspect_ai-0.3.61.dist-info}/top_level.txt +0 -0
inspect_ai/model/_providers/anthropic.py

```diff
@@ -14,10 +14,13 @@ from anthropic import (
     APIConnectionError,
     AsyncAnthropic,
     AsyncAnthropicBedrock,
+    AsyncAnthropicVertex,
     BadRequestError,
     InternalServerError,
+    NotGiven,
     RateLimitError,
 )
+from anthropic._types import Body
 from anthropic.types import (
     ImageBlockParam,
     Message,
@@ -64,15 +67,25 @@ class AnthropicAPI(ModelAPI):
         base_url: str | None = None,
         api_key: str | None = None,
         config: GenerateConfig = GenerateConfig(),
-        bedrock: bool = False,
         **model_args: Any,
     ):
         # extract any service prefix from model name
         parts = model_name.split("/")
         if len(parts) > 1:
-            service = parts[0]
-            bedrock = service == "bedrock"
+            self.service: str | None = parts[0]
             model_name = "/".join(parts[1:])
+        else:
+            self.service = None
+
+        # collect generate model_args (then delete them so we can pass the rest on)
+        def collect_model_arg(name: str) -> Any | None:
+            nonlocal model_args
+            value = model_args.get(name, None)
+            if value is not None:
+                model_args.pop(name)
+            return value
+
+        self.extra_body: Body | None = collect_model_arg("extra_body")

         # call super
         super().__init__(
@@ -84,7 +97,7 @@ class AnthropicAPI(ModelAPI):
         )

         # create client
-        if bedrock:
+        if self.is_bedrock():
             base_url = model_base_url(
                 base_url, ["ANTHROPIC_BEDROCK_BASE_URL", "BEDROCK_ANTHROPIC_BASE_URL"]
             )
@@ -95,7 +108,9 @@
             if base_region is None:
                 aws_region = os.environ.get("AWS_DEFAULT_REGION", None)

-            self.client: AsyncAnthropic | AsyncAnthropicBedrock = AsyncAnthropicBedrock(
+            self.client: (
+                AsyncAnthropic | AsyncAnthropicBedrock | AsyncAnthropicVertex
+            ) = AsyncAnthropicBedrock(
                 base_url=base_url,
                 max_retries=(
                     config.max_retries if config.max_retries else DEFAULT_MAX_RETRIES
@@ -103,6 +118,21 @@
                 aws_region=aws_region,
                 **model_args,
             )
+        elif self.is_vertex():
+            base_url = model_base_url(
+                base_url, ["ANTHROPIC_VERTEX_BASE_URL", "VERTEX_ANTHROPIC_BASE_URL"]
+            )
+            region = os.environ.get("ANTHROPIC_VERTEX_REGION", NotGiven())
+            project_id = os.environ.get("ANTHROPIC_VERTEX_PROJECT_ID", NotGiven())
+            self.client = AsyncAnthropicVertex(
+                region=region,
+                project_id=project_id,
+                base_url=base_url,
+                max_retries=(
+                    config.max_retries if config.max_retries else DEFAULT_MAX_RETRIES
+                ),
+                **model_args,
+            )
         else:
             # resolve api_key
             if not self.api_key:
@@ -119,6 +149,12 @@
                 **model_args,
             )

+    def is_bedrock(self) -> bool:
+        return self.service == "bedrock"
+
+    def is_vertex(self) -> bool:
+        return self.service == "vertex"
+
     async def generate(
         self,
         input: list[ChatMessage],
```
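With this change a `vertex` service prefix routes to `AsyncAnthropicVertex`, with region and project resolved from `ANTHROPIC_VERTEX_REGION` and `ANTHROPIC_VERTEX_PROJECT_ID`. A minimal usage sketch (the model id and environment values below are illustrative assumptions, not taken from the diff):

```python
import os

from inspect_ai.model import get_model

# region/project fall back to NotGiven() when these are unset (per the hunk above)
os.environ["ANTHROPIC_VERTEX_REGION"] = "us-east5"  # assumed value
os.environ["ANTHROPIC_VERTEX_PROJECT_ID"] = "my-gcp-project"  # assumed value

# the "vertex" prefix selects the Vertex client; the remainder is the model id
model = get_model("anthropic/vertex/claude-3-5-sonnet-v2@20241022")
```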
```diff
@@ -163,6 +199,10 @@
         if computer_use:
             request["extra_headers"] = {"anthropic-beta": "computer-use-2024-10-22"}

+        # extra_body
+        if self.extra_body is not None:
+            request["extra_body"] = self.extra_body
+
         # make request
         message = await self.client.messages.create(**request, stream=False)

```
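Because `extra_body` is collected from `**model_args` in `__init__` (via the `collect_model_arg` helper above) and attached to every `messages.create()` call, arbitrary extra request fields can be supplied per model. A sketch, with an assumed payload:

```python
from inspect_ai.model import get_model

# extra_body is popped from model_args and forwarded on each request;
# the metadata payload shown here is illustrative
model = get_model(
    "anthropic/claude-3-5-sonnet-20241022",
    extra_body={"metadata": {"user_id": "eval-run-01"}},
)
```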
```diff
@@ -251,9 +291,6 @@
         elif "content filtering" in error:
             content = "Sorry, but I am unable to help with that request."
             stop_reason = "content_filter"
-        else:
-            content = error
-            stop_reason = "unknown"

         if content and stop_reason:
             return ModelOutput.from_content(
@@ -466,6 +503,12 @@ def message_tool_choice(tool_choice: ToolChoice) -> message_create_params.ToolCh


 async def message_param(message: ChatMessage) -> MessageParam:
+    # if content is empty that is going to result in an error when we replay
+    # this message to claude, so in that case insert a NO_CONTENT message
+    if isinstance(message.content, list) and len(message.content) == 0:
+        message = message.model_copy()
+        message.content = [ContentText(text=NO_CONTENT)]
+
     # no system role for anthropic (this is more like an assertion,
     # as these should have already been filtered out)
     if message.role == "system":
@@ -507,7 +550,7 @@ async def message_param(message: ChatMessage) -> MessageParam:
     elif message.role == "assistant" and message.tool_calls:
         # first include content (claude <thinking>)
         tools_content: list[TextBlockParam | ImageBlockParam | ToolUseBlockParam] = (
-            [TextBlockParam(type="text", text=message.content)]
+            [TextBlockParam(type="text", text=message.content or NO_CONTENT)]
             if isinstance(message.content, str)
             else (
                 [(await message_param_content(content)) for content in message.content]
@@ -576,11 +619,6 @@ def model_output_from_message(message: Message, tools: list[ToolInfo]) -> ModelO
                 )
             )

-    # if content is empty that is going to result in an error when we replay
-    # this message to claude, so in that case insert a NO_CONTENT message
-    if len(content) == 0:
-        content = [ContentText(text=NO_CONTENT)]
-
     # resolve choice
     choice = ChatCompletionChoice(
         message=ChatMessageAssistant(
```
inspect_ai/model/_providers/azureai.py

```diff
@@ -37,6 +37,7 @@ from inspect_ai.tool import ToolChoice, ToolInfo
 from inspect_ai.tool._tool_call import ToolCall
 from inspect_ai.tool._tool_choice import ToolFunction

+from .._call_tools import parse_tool_call
 from .._chat_message import (
     ChatMessage,
     ChatMessageAssistant,
@@ -60,7 +61,6 @@ from .util import (
 )
 from .util.chatapi import ChatAPIHandler
 from .util.llama31 import Llama31Handler
-from .util.util import parse_tool_call

 AZUREAI_API_KEY = "AZUREAI_API_KEY"
 AZUREAI_ENDPOINT_KEY = "AZUREAI_ENDPOINT_KEY"
```
inspect_ai/model/_providers/goodfire.py

```diff
@@ -0,0 +1,248 @@
+import os
+from typing import Any, List, Literal, get_args
+
+from goodfire import AsyncClient
+from goodfire.api.chat.interfaces import ChatMessage as GoodfireChatMessage
+from goodfire.api.exceptions import InvalidRequestException, RateLimitException
+from goodfire.variants.variants import SUPPORTED_MODELS, Variant
+from typing_extensions import override
+
+from inspect_ai.tool._tool_choice import ToolChoice
+from inspect_ai.tool._tool_info import ToolInfo
+
+from .._chat_message import (
+    ChatMessage,
+    ChatMessageAssistant,
+    ChatMessageSystem,
+    ChatMessageTool,
+    ChatMessageUser,
+)
+from .._generate_config import GenerateConfig
+from .._model import ModelAPI
+from .._model_call import ModelCall
+from .._model_output import (
+    ChatCompletionChoice,
+    ModelOutput,
+    ModelUsage,
+)
+from .util import environment_prerequisite_error, model_base_url
+
+# Constants
+GOODFIRE_API_KEY = "GOODFIRE_API_KEY"
+DEFAULT_BASE_URL = "https://api.goodfire.ai"
+DEFAULT_MAX_TOKENS = 4096
+DEFAULT_TEMPERATURE = 1.0  # Standard sampling temperature (baseline)
+DEFAULT_TOP_P = 1.0  # No nucleus sampling truncation (baseline)
+
+
+class GoodfireAPI(ModelAPI):
+    """Goodfire API provider.
+
+    This provider implements the Goodfire API for LLM inference. It supports:
+    - Chat completions with standard message formats
+    - Basic parameter controls (temperature, top_p, etc.)
+    - Usage statistics tracking
+    - Stop reason handling
+
+    Does not currently support:
+    - Tool calls
+    - Feature analysis
+    - Streaming responses
+
+    Known limitations:
+    - Limited role support (system/user/assistant only)
+    - Tool messages converted to user messages
+    """
+
+    client: AsyncClient
+    variant: Variant
+    model_args: dict[str, Any]
+
+    def __init__(
+        self,
+        model_name: str,
+        base_url: str | None = None,
+        api_key: str | None = None,
+        config: GenerateConfig = GenerateConfig(),
+        **model_args: Any,
+    ) -> None:
+        """Initialize the Goodfire API provider.
+
+        Args:
+            model_name: Name of the model to use
+            base_url: Optional custom API base URL
+            api_key: Optional API key (will check env vars if not provided)
+            config: Generation config options
+            **model_args: Additional arguments passed to the API
+        """
+        super().__init__(
+            model_name=model_name,
+            base_url=base_url,
+            api_key=api_key,
+            api_key_vars=[GOODFIRE_API_KEY],
+            config=config,
+        )
+
+        # resolve api_key
+        if not self.api_key:
+            self.api_key = os.environ.get(GOODFIRE_API_KEY)
+            if not self.api_key:
+                raise environment_prerequisite_error("Goodfire", GOODFIRE_API_KEY)
+
+        # Validate model name against supported models
+        supported_models = list(get_args(SUPPORTED_MODELS))
+        if self.model_name not in supported_models:
+            raise ValueError(
+                f"Model {self.model_name} not supported. Supported models: {supported_models}"
+            )
+
+        # Initialize client with minimal configuration
+        base_url_val = model_base_url(base_url, "GOODFIRE_BASE_URL")
+        assert isinstance(base_url_val, str) or base_url_val is None
+
+        # Store model args for use in generate
+        self.model_args = model_args
+
+        self.client = AsyncClient(
+            api_key=self.api_key,
+            base_url=base_url_val or DEFAULT_BASE_URL,
+        )
+
+        # Initialize variant directly with model name
+        self.variant = Variant(self.model_name)  # type: ignore
+
+    def _to_goodfire_message(self, message: ChatMessage) -> GoodfireChatMessage:
+        """Convert an Inspect message to a Goodfire message format.
+
+        Args:
+            message: The message to convert
+
+        Returns:
+            The converted message in Goodfire format
+
+        Raises:
+            ValueError: If the message type is unknown
+        """
+        role: Literal["system", "user", "assistant"] = "user"
+        if isinstance(message, ChatMessageSystem):
+            role = "system"
+        elif isinstance(message, ChatMessageUser):
+            role = "user"
+        elif isinstance(message, ChatMessageAssistant):
+            role = "assistant"
+        elif isinstance(message, ChatMessageTool):
+            role = "user"  # Convert tool messages to user messages
+        else:
+            raise ValueError(f"Unknown message type: {type(message)}")
+
+        content = str(message.content)
+        if isinstance(message, ChatMessageTool):
+            content = f"Tool {message.function}: {content}"
+
+        return GoodfireChatMessage(role=role, content=content)
+
+    def handle_error(self, ex: Exception) -> ModelOutput | Exception:
+        """Handle only errors that need special treatment for retry logic or model limits."""
+        # Handle token/context length errors
+        if isinstance(ex, InvalidRequestException):
+            error_msg = str(ex).lower()
+            if "context length" in error_msg or "max tokens" in error_msg:
+                return ModelOutput.from_content(
+                    model=self.model_name,
+                    content=str(ex),
+                    stop_reason="model_length",
+                    error=error_msg,
+                )
+
+        # Let all other errors propagate
+        return ex
+
+    @override
+    def is_rate_limit(self, ex: BaseException) -> bool:
+        """Check if exception is due to rate limiting."""
+        return isinstance(ex, RateLimitException)
+
+    @override
+    def connection_key(self) -> str:
+        """Return key for connection pooling."""
+        return f"goodfire:{self.api_key}"
+
+    @override
+    def max_tokens(self) -> int | None:
+        """Return maximum tokens supported by model."""
+        return DEFAULT_MAX_TOKENS  # Let Goodfire's Variant handle model-specific limits
+
+    async def generate(
+        self,
+        input: List[ChatMessage],
+        tools: List[ToolInfo],
+        tool_choice: ToolChoice,
+        config: GenerateConfig,
+        *,
+        cache: bool = True,
+    ) -> tuple[ModelOutput | Exception, ModelCall]:
+        """Generate output from the model."""
+        # Convert messages and prepare request params
+        messages = [self._to_goodfire_message(msg) for msg in input]
+        # Build request parameters with type hints
+        params: dict[str, Any] = {
+            "model": self.variant.base_model,  # Use base_model instead of stringifying the Variant
+            "messages": messages,
+            "max_completion_tokens": int(config.max_tokens)
+            if config.max_tokens
+            else DEFAULT_MAX_TOKENS,
+            "stream": False,
+        }
+
+        # Add generation parameters from config if not in model_args
+        if "temperature" not in self.model_args and config.temperature is not None:
+            params["temperature"] = float(config.temperature)
+        elif "temperature" not in self.model_args:
+            params["temperature"] = DEFAULT_TEMPERATURE
+
+        if "top_p" not in self.model_args and config.top_p is not None:
+            params["top_p"] = float(config.top_p)
+        elif "top_p" not in self.model_args:
+            params["top_p"] = DEFAULT_TOP_P
+
+        # Add any additional model args (highest priority)
+        api_params = {
+            k: v
+            for k, v in self.model_args.items()
+            if k not in ["api_key", "base_url", "model_args"]
+        }
+        params.update(api_params)
+
+        try:
+            # Use native async client
+            response = await self.client.chat.completions.create(**params)
+            response_dict = response.model_dump()
+
+            output = ModelOutput(
+                model=self.model_name,
+                choices=[
+                    ChatCompletionChoice(
+                        message=ChatMessageAssistant(
+                            content=response_dict["choices"][0]["message"]["content"]
+                        ),
+                        stop_reason="stop",
+                    )
+                ],
+                usage=ModelUsage(**response_dict["usage"])
+                if "usage" in response_dict
+                else None,
+            )
+            model_call = ModelCall.create(request=params, response=response_dict)
+            return (output, model_call)
+        except Exception as ex:
+            result = self.handle_error(ex)
+            model_call = ModelCall.create(
+                request=params,
+                response={},  # Empty response for error case
+            )
+            return (result, model_call)
+
+    @property
+    def name(self) -> str:
+        """Get provider name."""
+        return "goodfire"
```
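Given the provider registration added in `providers.py` (+22 lines in the file list), the new model would presumably be addressed with a `goodfire/` prefix; the variant name must appear in the goodfire package's `SUPPORTED_MODELS`. A usage sketch (the specific model id is an assumption):

```python
from inspect_ai.model import GenerateConfig, get_model

# "goodfire/" routes to GoodfireAPI; the variant id below is assumed to be
# one of goodfire's SUPPORTED_MODELS
model = get_model(
    "goodfire/meta-llama/Meta-Llama-3.1-8B-Instruct",
    config=GenerateConfig(temperature=0.7, max_tokens=1024),
)
```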
inspect_ai/model/_providers/groq.py

```diff
@@ -27,6 +27,7 @@ from inspect_ai._util.images import file_as_data_uri
 from inspect_ai._util.url import is_http_url
 from inspect_ai.tool import ToolCall, ToolChoice, ToolFunction, ToolInfo

+from .._call_tools import parse_tool_call
 from .._chat_message import (
     ChatMessage,
     ChatMessageAssistant,
@@ -37,12 +38,15 @@ from .._chat_message import (
 from .._generate_config import GenerateConfig
 from .._model import ModelAPI
 from .._model_call import ModelCall
-from .._model_output import ChatCompletionChoice, ModelOutput, ModelUsage
-from .util import (
+from .._model_output import (
+    ChatCompletionChoice,
+    ModelOutput,
+    ModelUsage,
     as_stop_reason,
+)
+from .util import (
     environment_prerequisite_error,
     model_base_url,
-    parse_tool_call,
 )

 GROQ_API_KEY = "GROQ_API_KEY"
```
inspect_ai/model/_providers/hf.py

```diff
@@ -150,6 +150,12 @@ class HuggingFaceAPI(ModelAPI):
             kwargs["output_logits"] = config.logprobs
         if "return_dict_in_generate" in kwargs:
             assert kwargs["return_dict_in_generate"]
+        if config.stop_seqs is not None:
+            from transformers.generation import StopStringCriteria  # type: ignore
+
+            stopping_criteria = [StopStringCriteria(self.tokenizer, config.stop_seqs)]
+            kwargs["stopping_criteria"] = stopping_criteria
+
         kwargs["return_dict_in_generate"] = True
         generator = functools.partial(self.model.generate, **kwargs)

```
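With this hunk, `stop_seqs` from `GenerateConfig` is honored by the Hugging Face provider by way of transformers' `StopStringCriteria`. A sketch (the model id is illustrative):

```python
from inspect_ai.model import GenerateConfig, get_model

# stop_seqs is translated into a StopStringCriteria stopping criterion
# for transformers' model.generate()
model = get_model(
    "hf/Qwen/Qwen2.5-0.5B-Instruct",  # assumed model id
    config=GenerateConfig(stop_seqs=["\n\n", "###"]),
)
```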
inspect_ai/model/_providers/mistral.py

```diff
@@ -46,6 +46,7 @@ from inspect_ai._util.content import Content, ContentImage, ContentText
 from inspect_ai._util.images import file_as_data_uri
 from inspect_ai.tool import ToolCall, ToolChoice, ToolFunction, ToolInfo

+from .._call_tools import parse_tool_call
 from .._chat_message import (
     ChatMessage,
     ChatMessageAssistant,
@@ -59,7 +60,7 @@ from .._model_output import (
     ModelUsage,
     StopReason,
 )
-from .util import environment_prerequisite_error, model_base_url, parse_tool_call
+from .util import environment_prerequisite_error, model_base_url

 AZURE_MISTRAL_API_KEY = "AZURE_MISTRAL_API_KEY"
 AZUREAI_MISTRAL_API_KEY = "AZUREAI_MISTRAL_API_KEY"
```