inspect-ai 0.3.58__py3-none-any.whl → 0.3.60__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- inspect_ai/_cli/common.py +3 -1
- inspect_ai/_cli/eval.py +15 -9
- inspect_ai/_display/core/active.py +4 -1
- inspect_ai/_display/core/config.py +3 -3
- inspect_ai/_display/core/panel.py +7 -3
- inspect_ai/_display/plain/__init__.py +0 -0
- inspect_ai/_display/plain/display.py +203 -0
- inspect_ai/_display/rich/display.py +0 -5
- inspect_ai/_display/textual/widgets/port_mappings.py +110 -0
- inspect_ai/_display/textual/widgets/samples.py +79 -12
- inspect_ai/_display/textual/widgets/sandbox.py +37 -0
- inspect_ai/_eval/eval.py +10 -1
- inspect_ai/_eval/loader.py +79 -19
- inspect_ai/_eval/registry.py +6 -0
- inspect_ai/_eval/score.py +3 -1
- inspect_ai/_eval/task/results.py +51 -22
- inspect_ai/_eval/task/run.py +47 -13
- inspect_ai/_eval/task/sandbox.py +10 -5
- inspect_ai/_util/constants.py +1 -0
- inspect_ai/_util/port_names.py +61 -0
- inspect_ai/_util/text.py +23 -0
- inspect_ai/_view/www/App.css +31 -1
- inspect_ai/_view/www/dist/assets/index.css +31 -1
- inspect_ai/_view/www/dist/assets/index.js +25498 -2044
- inspect_ai/_view/www/log-schema.json +32 -2
- inspect_ai/_view/www/package.json +2 -0
- inspect_ai/_view/www/src/App.mjs +14 -16
- inspect_ai/_view/www/src/Types.mjs +1 -2
- inspect_ai/_view/www/src/api/Types.ts +133 -0
- inspect_ai/_view/www/src/api/{api-browser.mjs → api-browser.ts} +25 -13
- inspect_ai/_view/www/src/api/api-http.ts +219 -0
- inspect_ai/_view/www/src/api/api-shared.ts +47 -0
- inspect_ai/_view/www/src/api/{api-vscode.mjs → api-vscode.ts} +22 -19
- inspect_ai/_view/www/src/api/{client-api.mjs → client-api.ts} +93 -53
- inspect_ai/_view/www/src/api/index.ts +51 -0
- inspect_ai/_view/www/src/api/jsonrpc.ts +225 -0
- inspect_ai/_view/www/src/components/ChatView.mjs +133 -43
- inspect_ai/_view/www/src/components/DownloadButton.mjs +1 -1
- inspect_ai/_view/www/src/components/ExpandablePanel.mjs +0 -4
- inspect_ai/_view/www/src/components/LargeModal.mjs +19 -20
- inspect_ai/_view/www/src/components/TabSet.mjs +3 -1
- inspect_ai/_view/www/src/components/VirtualList.mjs +266 -84
- inspect_ai/_view/www/src/index.js +77 -4
- inspect_ai/_view/www/src/log/{remoteLogFile.mjs → remoteLogFile.ts} +62 -46
- inspect_ai/_view/www/src/navbar/Navbar.mjs +4 -1
- inspect_ai/_view/www/src/navbar/SecondaryBar.mjs +19 -10
- inspect_ai/_view/www/src/samples/SampleDialog.mjs +5 -1
- inspect_ai/_view/www/src/samples/SampleDisplay.mjs +23 -15
- inspect_ai/_view/www/src/samples/SampleList.mjs +19 -49
- inspect_ai/_view/www/src/samples/SampleScores.mjs +1 -1
- inspect_ai/_view/www/src/samples/SampleTranscript.mjs +8 -3
- inspect_ai/_view/www/src/samples/SamplesDescriptor.mjs +38 -26
- inspect_ai/_view/www/src/samples/SamplesTab.mjs +14 -11
- inspect_ai/_view/www/src/samples/SamplesTools.mjs +8 -8
- inspect_ai/_view/www/src/samples/tools/SampleFilter.mjs +712 -89
- inspect_ai/_view/www/src/samples/tools/SortFilter.mjs +2 -2
- inspect_ai/_view/www/src/samples/tools/filters.mjs +260 -87
- inspect_ai/_view/www/src/samples/transcript/ErrorEventView.mjs +24 -2
- inspect_ai/_view/www/src/samples/transcript/EventPanel.mjs +29 -24
- inspect_ai/_view/www/src/samples/transcript/EventRow.mjs +1 -1
- inspect_ai/_view/www/src/samples/transcript/InfoEventView.mjs +24 -2
- inspect_ai/_view/www/src/samples/transcript/InputEventView.mjs +24 -2
- inspect_ai/_view/www/src/samples/transcript/ModelEventView.mjs +31 -10
- inspect_ai/_view/www/src/samples/transcript/SampleInitEventView.mjs +24 -2
- inspect_ai/_view/www/src/samples/transcript/SampleLimitEventView.mjs +23 -2
- inspect_ai/_view/www/src/samples/transcript/ScoreEventView.mjs +24 -2
- inspect_ai/_view/www/src/samples/transcript/StepEventView.mjs +33 -3
- inspect_ai/_view/www/src/samples/transcript/SubtaskEventView.mjs +25 -2
- inspect_ai/_view/www/src/samples/transcript/ToolEventView.mjs +25 -2
- inspect_ai/_view/www/src/samples/transcript/TranscriptView.mjs +193 -11
- inspect_ai/_view/www/src/samples/transcript/Types.mjs +10 -0
- inspect_ai/_view/www/src/samples/transcript/state/StateEventView.mjs +26 -2
- inspect_ai/_view/www/src/types/log.d.ts +13 -2
- inspect_ai/_view/www/src/utils/Format.mjs +10 -3
- inspect_ai/_view/www/src/utils/{Json.mjs → json-worker.ts} +13 -9
- inspect_ai/_view/www/src/utils/vscode.ts +36 -0
- inspect_ai/_view/www/src/workspace/WorkSpace.mjs +11 -5
- inspect_ai/_view/www/vite.config.js +7 -0
- inspect_ai/_view/www/yarn.lock +116 -0
- inspect_ai/approval/_human/__init__.py +0 -0
- inspect_ai/approval/_human/manager.py +1 -1
- inspect_ai/approval/_policy.py +12 -6
- inspect_ai/log/_log.py +1 -1
- inspect_ai/log/_samples.py +16 -0
- inspect_ai/log/_transcript.py +4 -1
- inspect_ai/model/_call_tools.py +59 -0
- inspect_ai/model/_conversation.py +16 -7
- inspect_ai/model/_generate_config.py +12 -12
- inspect_ai/model/_model.py +117 -18
- inspect_ai/model/_model_output.py +22 -2
- inspect_ai/model/_openai.py +383 -0
- inspect_ai/model/_providers/anthropic.py +152 -55
- inspect_ai/model/_providers/azureai.py +21 -21
- inspect_ai/model/_providers/bedrock.py +37 -40
- inspect_ai/model/_providers/goodfire.py +248 -0
- inspect_ai/model/_providers/google.py +46 -54
- inspect_ai/model/_providers/groq.py +7 -3
- inspect_ai/model/_providers/hf.py +6 -0
- inspect_ai/model/_providers/mistral.py +13 -12
- inspect_ai/model/_providers/openai.py +51 -218
- inspect_ai/model/_providers/openai_o1.py +11 -12
- inspect_ai/model/_providers/providers.py +23 -1
- inspect_ai/model/_providers/together.py +12 -12
- inspect_ai/model/_providers/util/__init__.py +2 -3
- inspect_ai/model/_providers/util/hf_handler.py +1 -1
- inspect_ai/model/_providers/util/llama31.py +1 -1
- inspect_ai/model/_providers/util/util.py +0 -76
- inspect_ai/model/_providers/vertex.py +1 -4
- inspect_ai/scorer/_metric.py +3 -0
- inspect_ai/scorer/_reducer/reducer.py +1 -1
- inspect_ai/scorer/_scorer.py +4 -3
- inspect_ai/solver/__init__.py +4 -5
- inspect_ai/solver/_basic_agent.py +1 -1
- inspect_ai/solver/_bridge/__init__.py +3 -0
- inspect_ai/solver/_bridge/bridge.py +100 -0
- inspect_ai/solver/_bridge/patch.py +170 -0
- inspect_ai/solver/_prompt.py +35 -5
- inspect_ai/solver/_solver.py +6 -0
- inspect_ai/solver/_task_state.py +80 -38
- inspect_ai/tool/__init__.py +2 -0
- inspect_ai/tool/_tool.py +12 -1
- inspect_ai/tool/_tool_call.py +10 -0
- inspect_ai/tool/_tool_def.py +16 -5
- inspect_ai/tool/_tool_with.py +21 -4
- inspect_ai/tool/beta/__init__.py +5 -0
- inspect_ai/tool/beta/_computer/__init__.py +3 -0
- inspect_ai/tool/beta/_computer/_common.py +133 -0
- inspect_ai/tool/beta/_computer/_computer.py +155 -0
- inspect_ai/tool/beta/_computer/_computer_split.py +198 -0
- inspect_ai/tool/beta/_computer/_resources/Dockerfile +100 -0
- inspect_ai/tool/beta/_computer/_resources/README.md +30 -0
- inspect_ai/tool/beta/_computer/_resources/entrypoint/entrypoint.sh +18 -0
- inspect_ai/tool/beta/_computer/_resources/entrypoint/novnc_startup.sh +20 -0
- inspect_ai/tool/beta/_computer/_resources/entrypoint/x11vnc_startup.sh +48 -0
- inspect_ai/tool/beta/_computer/_resources/entrypoint/xfce_startup.sh +13 -0
- inspect_ai/tool/beta/_computer/_resources/entrypoint/xvfb_startup.sh +48 -0
- inspect_ai/tool/beta/_computer/_resources/image_home_dir/Desktop/Firefox Web Browser.desktop +10 -0
- inspect_ai/tool/beta/_computer/_resources/image_home_dir/Desktop/Visual Studio Code.desktop +10 -0
- inspect_ai/tool/beta/_computer/_resources/image_home_dir/Desktop/XPaint.desktop +10 -0
- inspect_ai/tool/beta/_computer/_resources/tool/__init__.py +0 -0
- inspect_ai/tool/beta/_computer/_resources/tool/_logger.py +22 -0
- inspect_ai/tool/beta/_computer/_resources/tool/_run.py +42 -0
- inspect_ai/tool/beta/_computer/_resources/tool/_tool_result.py +33 -0
- inspect_ai/tool/beta/_computer/_resources/tool/_x11_client.py +262 -0
- inspect_ai/tool/beta/_computer/_resources/tool/computer_tool.py +85 -0
- inspect_ai/tool/beta/_computer/_resources/tool/requirements.txt +0 -0
- inspect_ai/util/__init__.py +2 -0
- inspect_ai/util/_display.py +5 -0
- inspect_ai/util/_limit.py +26 -0
- inspect_ai/util/_sandbox/docker/docker.py +64 -1
- inspect_ai/util/_sandbox/docker/internal.py +3 -1
- inspect_ai/util/_sandbox/docker/prereqs.py +1 -1
- inspect_ai/util/_sandbox/environment.py +14 -0
- {inspect_ai-0.3.58.dist-info → inspect_ai-0.3.60.dist-info}/METADATA +3 -2
- {inspect_ai-0.3.58.dist-info → inspect_ai-0.3.60.dist-info}/RECORD +159 -126
- inspect_ai/_view/www/src/api/Types.mjs +0 -117
- inspect_ai/_view/www/src/api/api-http.mjs +0 -300
- inspect_ai/_view/www/src/api/api-shared.mjs +0 -10
- inspect_ai/_view/www/src/api/index.mjs +0 -49
- inspect_ai/_view/www/src/api/jsonrpc.mjs +0 -208
- inspect_ai/_view/www/src/samples/transcript/TranscriptState.mjs +0 -70
- inspect_ai/_view/www/src/utils/vscode.mjs +0 -16
- {inspect_ai-0.3.58.dist-info → inspect_ai-0.3.60.dist-info}/LICENSE +0 -0
- {inspect_ai-0.3.58.dist-info → inspect_ai-0.3.60.dist-info}/WHEEL +0 -0
- {inspect_ai-0.3.58.dist-info → inspect_ai-0.3.60.dist-info}/entry_points.txt +0 -0
- {inspect_ai-0.3.58.dist-info → inspect_ai-0.3.60.dist-info}/top_level.txt +0 -0
inspect_ai/model/_providers/goodfire.py

@@ -0,0 +1,248 @@
+import os
+from typing import Any, List, Literal, get_args
+
+from goodfire import AsyncClient
+from goodfire.api.chat.interfaces import ChatMessage as GoodfireChatMessage
+from goodfire.api.exceptions import InvalidRequestException, RateLimitException
+from goodfire.variants.variants import SUPPORTED_MODELS, Variant
+from typing_extensions import override
+
+from inspect_ai.tool._tool_choice import ToolChoice
+from inspect_ai.tool._tool_info import ToolInfo
+
+from .._chat_message import (
+    ChatMessage,
+    ChatMessageAssistant,
+    ChatMessageSystem,
+    ChatMessageTool,
+    ChatMessageUser,
+)
+from .._generate_config import GenerateConfig
+from .._model import ModelAPI
+from .._model_call import ModelCall
+from .._model_output import (
+    ChatCompletionChoice,
+    ModelOutput,
+    ModelUsage,
+)
+from .util import environment_prerequisite_error, model_base_url
+
+# Constants
+GOODFIRE_API_KEY = "GOODFIRE_API_KEY"
+DEFAULT_BASE_URL = "https://api.goodfire.ai"
+DEFAULT_MAX_TOKENS = 4096
+DEFAULT_TEMPERATURE = 1.0  # Standard sampling temperature (baseline)
+DEFAULT_TOP_P = 1.0  # No nucleus sampling truncation (baseline)
+
+
+class GoodfireAPI(ModelAPI):
+    """Goodfire API provider.
+
+    This provider implements the Goodfire API for LLM inference. It supports:
+    - Chat completions with standard message formats
+    - Basic parameter controls (temperature, top_p, etc.)
+    - Usage statistics tracking
+    - Stop reason handling
+
+    Does not currently support:
+    - Tool calls
+    - Feature analysis
+    - Streaming responses
+
+    Known limitations:
+    - Limited role support (system/user/assistant only)
+    - Tool messages converted to user messages
+    """
+
+    client: AsyncClient
+    variant: Variant
+    model_args: dict[str, Any]
+
+    def __init__(
+        self,
+        model_name: str,
+        base_url: str | None = None,
+        api_key: str | None = None,
+        config: GenerateConfig = GenerateConfig(),
+        **model_args: Any,
+    ) -> None:
+        """Initialize the Goodfire API provider.
+
+        Args:
+            model_name: Name of the model to use
+            base_url: Optional custom API base URL
+            api_key: Optional API key (will check env vars if not provided)
+            config: Generation config options
+            **model_args: Additional arguments passed to the API
+        """
+        super().__init__(
+            model_name=model_name,
+            base_url=base_url,
+            api_key=api_key,
+            api_key_vars=[GOODFIRE_API_KEY],
+            config=config,
+        )
+
+        # resolve api_key
+        if not self.api_key:
+            self.api_key = os.environ.get(GOODFIRE_API_KEY)
+        if not self.api_key:
+            raise environment_prerequisite_error("Goodfire", GOODFIRE_API_KEY)
+
+        # Validate model name against supported models
+        supported_models = list(get_args(SUPPORTED_MODELS))
+        if self.model_name not in supported_models:
+            raise ValueError(
+                f"Model {self.model_name} not supported. Supported models: {supported_models}"
+            )
+
+        # Initialize client with minimal configuration
+        base_url_val = model_base_url(base_url, "GOODFIRE_BASE_URL")
+        assert isinstance(base_url_val, str) or base_url_val is None
+
+        # Store model args for use in generate
+        self.model_args = model_args
+
+        self.client = AsyncClient(
+            api_key=self.api_key,
+            base_url=base_url_val or DEFAULT_BASE_URL,
+        )
+
+        # Initialize variant directly with model name
+        self.variant = Variant(self.model_name)  # type: ignore
+
+    def _to_goodfire_message(self, message: ChatMessage) -> GoodfireChatMessage:
+        """Convert an Inspect message to a Goodfire message format.
+
+        Args:
+            message: The message to convert
+
+        Returns:
+            The converted message in Goodfire format
+
+        Raises:
+            ValueError: If the message type is unknown
+        """
+        role: Literal["system", "user", "assistant"] = "user"
+        if isinstance(message, ChatMessageSystem):
+            role = "system"
+        elif isinstance(message, ChatMessageUser):
+            role = "user"
+        elif isinstance(message, ChatMessageAssistant):
+            role = "assistant"
+        elif isinstance(message, ChatMessageTool):
+            role = "user"  # Convert tool messages to user messages
+        else:
+            raise ValueError(f"Unknown message type: {type(message)}")
+
+        content = str(message.content)
+        if isinstance(message, ChatMessageTool):
+            content = f"Tool {message.function}: {content}"
+
+        return GoodfireChatMessage(role=role, content=content)
+
+    def handle_error(self, ex: Exception) -> ModelOutput | Exception:
+        """Handle only errors that need special treatment for retry logic or model limits."""
+        # Handle token/context length errors
+        if isinstance(ex, InvalidRequestException):
+            error_msg = str(ex).lower()
+            if "context length" in error_msg or "max tokens" in error_msg:
+                return ModelOutput.from_content(
+                    model=self.model_name,
+                    content=str(ex),
+                    stop_reason="model_length",
+                    error=error_msg,
+                )
+
+        # Let all other errors propagate
+        return ex
+
+    @override
+    def is_rate_limit(self, ex: BaseException) -> bool:
+        """Check if exception is due to rate limiting."""
+        return isinstance(ex, RateLimitException)
+
+    @override
+    def connection_key(self) -> str:
+        """Return key for connection pooling."""
+        return f"goodfire:{self.api_key}"
+
+    @override
+    def max_tokens(self) -> int | None:
+        """Return maximum tokens supported by model."""
+        return DEFAULT_MAX_TOKENS  # Let Goodfire's Variant handle model-specific limits
+
+    async def generate(
+        self,
+        input: List[ChatMessage],
+        tools: List[ToolInfo],
+        tool_choice: ToolChoice,
+        config: GenerateConfig,
+        *,
+        cache: bool = True,
+    ) -> tuple[ModelOutput | Exception, ModelCall]:
+        """Generate output from the model."""
+        # Convert messages and prepare request params
+        messages = [self._to_goodfire_message(msg) for msg in input]
+        # Build request parameters with type hints
+        params: dict[str, Any] = {
+            "model": self.variant.base_model,  # Use base_model instead of stringifying the Variant
+            "messages": messages,
+            "max_completion_tokens": int(config.max_tokens)
+            if config.max_tokens
+            else DEFAULT_MAX_TOKENS,
+            "stream": False,
+        }
+
+        # Add generation parameters from config if not in model_args
+        if "temperature" not in self.model_args and config.temperature is not None:
+            params["temperature"] = float(config.temperature)
+        elif "temperature" not in self.model_args:
+            params["temperature"] = DEFAULT_TEMPERATURE
+
+        if "top_p" not in self.model_args and config.top_p is not None:
+            params["top_p"] = float(config.top_p)
+        elif "top_p" not in self.model_args:
+            params["top_p"] = DEFAULT_TOP_P
+
+        # Add any additional model args (highest priority)
+        api_params = {
+            k: v
+            for k, v in self.model_args.items()
+            if k not in ["api_key", "base_url", "model_args"]
+        }
+        params.update(api_params)
+
+        try:
+            # Use native async client
+            response = await self.client.chat.completions.create(**params)
+            response_dict = response.model_dump()
+
+            output = ModelOutput(
+                model=self.model_name,
+                choices=[
+                    ChatCompletionChoice(
+                        message=ChatMessageAssistant(
+                            content=response_dict["choices"][0]["message"]["content"]
+                        ),
+                        stop_reason="stop",
+                    )
+                ],
+                usage=ModelUsage(**response_dict["usage"])
+                if "usage" in response_dict
+                else None,
+            )
+            model_call = ModelCall.create(request=params, response=response_dict)
+            return (output, model_call)
+        except Exception as ex:
+            result = self.handle_error(ex)
+            model_call = ModelCall.create(
+                request=params,
+                response={},  # Empty response for error case
+            )
+            return (result, model_call)
+
+    @property
+    def name(self) -> str:
+        """Get provider name."""
+        return "goodfire"
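Note: a minimal usage sketch for the provider added above (not part of the diff). The "goodfire/" prefix routes through inspect_ai's standard get_model() interface to the GoodfireAPI class; the model id shown is an assumption, substitute any id from goodfire's SUPPORTED_MODELS.

    from inspect_ai.model import get_model

    # hypothetical model id; must be one of goodfire's SUPPORTED_MODELS
    model = get_model("goodfire/meta-llama/Meta-Llama-3.1-8B-Instruct")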
inspect_ai/model/_providers/google.py

@@ -11,7 +11,6 @@ import proto # type: ignore
 from google.ai.generativelanguage import (
     Blob,
     Candidate,
-    File,
     FunctionCall,
     FunctionCallingConfig,
     FunctionDeclaration,
@@ -29,29 +28,29 @@ from google.api_core.exceptions import (
     TooManyRequests,
 )
 from google.api_core.retry.retry_base import if_transient_error
-from google.generativeai import (
-    GenerationConfig,
-    GenerativeModel,
-    configure,
-    get_file,
-    upload_file,
-)
-from google.generativeai.types import ( # type: ignore
-    AsyncGenerateContentResponse,
+from google.generativeai.client import configure
+from google.generativeai.files import get_file, upload_file
+from google.generativeai.generative_models import GenerativeModel
+from google.generativeai.types import (
     ContentDict,
-    HarmBlockThreshold,
-    HarmCategory,
+    GenerationConfig,
     PartDict,
     PartType,
-    SafetySettingDict,
     Tool,
 )
+from google.generativeai.types.file_types import File
+from google.generativeai.types.generation_types import AsyncGenerateContentResponse
+from google.generativeai.types.safety_types import (
+    EasySafetySettingDict,
+    HarmBlockThreshold,
+    HarmCategory,
+)
 from google.protobuf.json_format import MessageToDict, ParseDict
 from google.protobuf.struct_pb2 import Struct
 from pydantic import JsonValue
 from typing_extensions import override
 
-from inspect_ai._util.constants import BASE_64_DATA_REMOVED
+from inspect_ai._util.constants import BASE_64_DATA_REMOVED, NO_CONTENT
 from inspect_ai._util.content import (
     Content,
     ContentAudio,
@@ -89,7 +88,7 @@ logger = getLogger(__name__)
 
 SAFETY_SETTINGS = "safety_settings"
 
-DEFAULT_SAFETY_SETTINGS: SafetySettingDict = {
+DEFAULT_SAFETY_SETTINGS: EasySafetySettingDict = {
     HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
     HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
     HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
@@ -141,7 +140,7 @@ class GoogleAPI(ModelAPI):
         tools: list[ToolInfo],
         tool_choice: ToolChoice,
         config: GenerateConfig,
-    ) -> ModelOutput | tuple[ModelOutput, ModelCall]:
+    ) -> ModelOutput | tuple[ModelOutput | Exception, ModelCall]:
         parameters = GenerationConfig(
             temperature=config.temperature,
             top_p=config.top_p,
@@ -149,11 +148,8 @@ class GoogleAPI(ModelAPI):
             max_output_tokens=config.max_tokens,
             stop_sequences=config.stop_seqs,
             candidate_count=config.num_choices,
-            seed=config.seed,
             presence_penalty=config.presence_penalty,
             frequency_penalty=config.frequency_penalty,
-            response_logprobs=config.logprobs,
-            logprobs=config.top_logprobs,
         )
 
         # google-native messages
@@ -176,18 +172,15 @@ class GoogleAPI(ModelAPI):
             response=response,
         )
 
-        # cast to AsyncGenerateContentResponse since we passed stream=False
        try:
-            response = cast(
-                AsyncGenerateContentResponse,
-                await self.model.generate_content_async(
-                    contents=contents,
-                    safety_settings=self.safety_settings,
-                    generation_config=parameters,
-                    tools=gemini_tools,
-                    tool_config=gemini_tool_config,
-                ),
+            response = await self.model.generate_content_async(
+                contents=contents,
+                safety_settings=self.safety_settings,
+                generation_config=parameters,
+                tools=gemini_tools,
+                tool_config=gemini_tool_config,
             )
+
         except InvalidArgument as ex:
             return self.handle_invalid_argument(ex), model_call()
 
@@ -205,15 +198,13 @@ class GoogleAPI(ModelAPI):
         # return
         return output, model_call()
 
-    def handle_invalid_argument(self, ex: InvalidArgument) -> ModelOutput:
+    def handle_invalid_argument(self, ex: InvalidArgument) -> ModelOutput | Exception:
         if "size exceeds the limit" in ex.message.lower():
             return ModelOutput.from_content(
                 model=self.model_name, content=ex.message, stop_reason="model_length"
             )
         else:
-            return ModelOutput.from_content(
-                model=self.model_name, content=ex.message, stop_reason="unknown"
-            )
+            return ex
 
     @override
     def is_rate_limit(self, ex: BaseException) -> bool:
@@ -231,7 +222,7 @@ class GoogleAPI(ModelAPI):
 def build_model_call(
     contents: list[ContentDict],
     generation_config: GenerationConfig,
-    safety_settings: SafetySettingDict,
+    safety_settings: EasySafetySettingDict,
     tools: list[Tool] | None,
     tool_config: ToolConfig | None,
     response: AsyncGenerateContentResponse | None,
@@ -248,7 +239,7 @@ def build_model_call(
             if tool_config is not None
             else None,
         ),
-        response=response.to_dict() if response is not None else {},
+        response=response.to_dict() if response is not None else {},  # type: ignore[no-untyped-call]
         filter=model_call_filter,
     )
 
@@ -269,12 +260,12 @@ def model_call_content(content: ContentDict) -> ContentDict:
 
 def model_call_part(part: PartType) -> PartType:
     if isinstance(part, proto.Message):
-        return MessageToDict(part._pb)
+        return cast(PartDict, MessageToDict(part._pb))
     elif isinstance(part, dict):
         part = part.copy()
         keys = list(part.keys())
         for key in keys:
-            part[key] = model_call_part(part[key])
+            part[key] = model_call_part(part[key])  # type: ignore[literal-required]
         return part
     else:
         return part
@@ -316,9 +307,6 @@ def consective_tool_message_reducer(
     return messages
 
 
-NO_CONTENT = "(no content)"
-
-
 async def content_dict(
     message: ChatMessageUser | ChatMessageAssistant | ChatMessageTool,
 ) -> ContentDict:
@@ -326,13 +314,13 @@ async def content_dict(
         return ContentDict(
             role="user",
             parts=(
-                [message.content]
+                [message.content or NO_CONTENT]
                 if isinstance(message.content, str)
                 else [await content_part(content) for content in message.content]
             ),
         )
     elif isinstance(message, ChatMessageAssistant):
-        content_parts: list[
+        content_parts: list[PartType] = []
         # tool call parts
         if message.tool_calls is not None:
             content_parts.extend(
@@ -383,9 +371,9 @@ def dict_to_struct(x: dict[str, Any]) -> Struct:
 
 async def content_part(content: Content | str) -> PartType:
     if isinstance(content, str):
-        return content
+        return content or NO_CONTENT
     elif isinstance(content, ContentText):
-        return content.text
+        return content.text or NO_CONTENT
     else:
         return await chat_content_to_part(content)
 
@@ -404,7 +392,9 @@ def prepend_system_messages(
     messages: list[ContentDict], system_messages: list[ChatMessageSystem]
 ) -> None:
     # create system_parts
-    system_parts = [Part(text=message.content) for message in system_messages]
+    system_parts: list[PartType] = [
+        Part(text=message.content) for message in system_messages
+    ]
 
     # we want the system messages to be prepended to the first user message
     # (if there is no first user message then prepend one)
@@ -476,6 +466,8 @@ def schema_from_param(param: ToolParam | ToolParams, nullable: bool = False) ->
             return schema_from_param(param.anyOf[0], nullable=True)
         else:
             return Schema(type=Type.TYPE_UNSPECIFIED)
+    elif param.enum:
+        return Schema(type=Type.STRING, format="enum", enum=param.enum)
     else:
         return Schema(type=Type.TYPE_UNSPECIFIED)
 
@@ -600,14 +592,14 @@ def gapi_should_retry(ex: BaseException) -> bool:
 
 def parse_safety_settings(
     safety_settings: Any,
-) -> SafetySettingDict:
+) -> EasySafetySettingDict:
     # ensure we have a dict
     if isinstance(safety_settings, str):
         safety_settings = json.loads(safety_settings)
     if not isinstance(safety_settings, dict):
         raise ValueError(f"{SAFETY_SETTINGS} must be dictionary.")
 
-    parsed_settings: SafetySettingDict = {}
+    parsed_settings: EasySafetySettingDict = {}
     for key, value in safety_settings.items():
         if isinstance(key, str):
             key = str_to_harm_category(key)
@@ -623,23 +615,23 @@
     return parsed_settings
 
 
-def str_to_harm_category(category: str) -> HarmCategory:
+def str_to_harm_category(category: str) -> int:
     category = category.upper()
     if "HARASSMENT" in category:
-        return HarmCategory.HARM_CATEGORY_HARASSMENT
+        return cast(int, HarmCategory.HARM_CATEGORY_HARASSMENT)
     elif "HATE_SPEECH" in category:
-        return HarmCategory.HARM_CATEGORY_HATE_SPEECH
+        return cast(int, HarmCategory.HARM_CATEGORY_HATE_SPEECH)
     elif "SEXUALLY_EXPLICIT" in category:
-        return HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT
+        return cast(int, HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT)
     elif "DANGEROUS_CONTENT" in category:
-        return HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT
+        return cast(int, HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT)
     else:
         # NOTE: Although there is an "UNSPECIFIED" category, in the
         # documentation, the API does not accept it.
         raise ValueError(f"Unknown HarmCategory: {category}")
 
 
-def str_to_harm_block_threshold(threshold: str) -> HarmBlockThreshold:
+def str_to_harm_block_threshold(threshold: str) -> int:
     threshold = threshold.upper()
     if "LOW" in threshold:
         return HarmBlockThreshold.BLOCK_LOW_AND_ABOVE
@@ -673,7 +665,7 @@ async def file_for_content(content: ContentAudio | ContentVideo) -> File:
     uploaded_file = files_db.get(content_sha256)
     if uploaded_file:
         try:
-            upload =
+            upload = get_file(uploaded_file)
             if upload.state.name == "ACTIVE":
                 trace(f"Using uploaded file: {uploaded_file}")
                 return upload
inspect_ai/model/_providers/groq.py

@@ -27,6 +27,7 @@ from inspect_ai._util.images import file_as_data_uri
 from inspect_ai._util.url import is_http_url
 from inspect_ai.tool import ToolCall, ToolChoice, ToolFunction, ToolInfo
 
+from .._call_tools import parse_tool_call
 from .._chat_message import (
     ChatMessage,
     ChatMessageAssistant,
@@ -37,12 +38,15 @@ from .._chat_message import (
 from .._generate_config import GenerateConfig
 from .._model import ModelAPI
 from .._model_call import ModelCall
-from .._model_output import ModelOutput, ModelUsage
-from .util import (
+from .._model_output import (
+    ChatCompletionChoice,
+    ModelOutput,
+    ModelUsage,
     as_stop_reason,
+)
+from .util import (
     environment_prerequisite_error,
     model_base_url,
-    parse_tool_call,
 )
 
 GROQ_API_KEY = "GROQ_API_KEY"
inspect_ai/model/_providers/hf.py

@@ -150,6 +150,12 @@ class HuggingFaceAPI(ModelAPI):
             kwargs["output_logits"] = config.logprobs
         if "return_dict_in_generate" in kwargs:
             assert kwargs["return_dict_in_generate"]
+        if config.stop_seqs is not None:
+            from transformers.generation import StopStringCriteria # type: ignore
+
+            stopping_criteria = [StopStringCriteria(self.tokenizer, config.stop_seqs)]
+            kwargs["stopping_criteria"] = stopping_criteria
+
         kwargs["return_dict_in_generate"] = True
         generator = functools.partial(self.model.generate, **kwargs)
 
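Note: the hf.py change above wires config.stop_seqs into generation using transformers' StopStringCriteria. A standalone sketch of the same mechanism outside Inspect (the model id is an assumption; any causal LM works):

    from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteriaList
    from transformers.generation import StopStringCriteria

    tok = AutoTokenizer.from_pretrained("gpt2")  # illustrative model
    model = AutoModelForCausalLM.from_pretrained("gpt2")

    inputs = tok("The answer is", return_tensors="pt")
    # halt generation once "\n\n" appears, mirroring kwargs["stopping_criteria"] above
    out = model.generate(
        **inputs,
        max_new_tokens=50,
        stopping_criteria=StoppingCriteriaList([StopStringCriteria(tok, ["\n\n"])]),
    )
    print(tok.decode(out[0]))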
inspect_ai/model/_providers/mistral.py

@@ -40,11 +40,13 @@ from typing_extensions import override
 # https://github.com/mistralai/client-python/blob/main/MIGRATION.md
 from inspect_ai._util.constants import (
     DEFAULT_TIMEOUT,
+    NO_CONTENT,
 )
 from inspect_ai._util.content import Content, ContentImage, ContentText
 from inspect_ai._util.images import file_as_data_uri
 from inspect_ai.tool import ToolCall, ToolChoice, ToolFunction, ToolInfo
 
+from .._call_tools import parse_tool_call
 from .._chat_message import (
     ChatMessage,
     ChatMessageAssistant,
@@ -58,7 +60,7 @@ from .._model_output import (
     ModelUsage,
     StopReason,
 )
-from .util import environment_prerequisite_error, model_base_url, parse_tool_call
+from .util import environment_prerequisite_error, model_base_url
 
 AZURE_MISTRAL_API_KEY = "AZURE_MISTRAL_API_KEY"
 AZUREAI_MISTRAL_API_KEY = "AZUREAI_MISTRAL_API_KEY"
@@ -122,7 +124,7 @@ class MistralAPI(ModelAPI):
         tools: list[ToolInfo],
         tool_choice: ToolChoice,
         config: GenerateConfig,
-    ) -> ModelOutput | tuple[ModelOutput, ModelCall]:
+    ) -> ModelOutput | tuple[ModelOutput | Exception, ModelCall]:
         # build request
         request: dict[str, Any] = dict(
             model=self.model_name,
@@ -146,7 +148,7 @@ class MistralAPI(ModelAPI):
             response = await self.client.chat.complete_async(**request)
         except SDKError as ex:
             if ex.status_code == 400:
-                return self.handle_bad_request(ex)
+                return self.handle_bad_request(ex), mistral_model_call(request, None)
             else:
                 raise ex
 
@@ -181,25 +183,27 @@ class MistralAPI(ModelAPI):
     def connection_key(self) -> str:
         return str(self.api_key)
 
-    def handle_bad_request(self, ex: SDKError) -> ModelOutput:
+    def handle_bad_request(self, ex: SDKError) -> ModelOutput | Exception:
+        body = json.loads(ex.body)
+        content = body.get("message", ex.body)
         if "maximum context length" in ex.body:
-            body = json.loads(ex.body)
-            content = body.get("message", ex.body)
             return ModelOutput.from_content(
                 model=self.model_name, content=content, stop_reason="model_length"
             )
         else:
-            raise ex
+            return ex
 
 
 def mistral_model_call(
-    request: dict[str, Any], response: MistralChatCompletionResponse
+    request: dict[str, Any], response: MistralChatCompletionResponse | None
 ) -> ModelCall:
     request = request.copy()
     request.update(messages=[message.model_dump() for message in request["messages"]])
     if request.get("tools", None) is not None:
         request["tools"] = [tool.model_dump() for tool in request["tools"]]
-    return ModelCall(request=request, response=response.model_dump())
+    return ModelCall(
+        request=request, response=response.model_dump() if response else {}
+    )
 
 
 def mistral_chat_tools(tools: list[ToolInfo]) -> list[MistralTool]:
@@ -326,9 +330,6 @@ async def mistral_chat_message(
     )
 
 
-NO_CONTENT = "(no content)"
-
-
 async def mistral_message_content(
     content: str | list[Content],
 ) -> str | list[ContentChunk]: