inspect-ai 0.3.58__py3-none-any.whl → 0.3.60__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (166)
  1. inspect_ai/_cli/common.py +3 -1
  2. inspect_ai/_cli/eval.py +15 -9
  3. inspect_ai/_display/core/active.py +4 -1
  4. inspect_ai/_display/core/config.py +3 -3
  5. inspect_ai/_display/core/panel.py +7 -3
  6. inspect_ai/_display/plain/__init__.py +0 -0
  7. inspect_ai/_display/plain/display.py +203 -0
  8. inspect_ai/_display/rich/display.py +0 -5
  9. inspect_ai/_display/textual/widgets/port_mappings.py +110 -0
  10. inspect_ai/_display/textual/widgets/samples.py +79 -12
  11. inspect_ai/_display/textual/widgets/sandbox.py +37 -0
  12. inspect_ai/_eval/eval.py +10 -1
  13. inspect_ai/_eval/loader.py +79 -19
  14. inspect_ai/_eval/registry.py +6 -0
  15. inspect_ai/_eval/score.py +3 -1
  16. inspect_ai/_eval/task/results.py +51 -22
  17. inspect_ai/_eval/task/run.py +47 -13
  18. inspect_ai/_eval/task/sandbox.py +10 -5
  19. inspect_ai/_util/constants.py +1 -0
  20. inspect_ai/_util/port_names.py +61 -0
  21. inspect_ai/_util/text.py +23 -0
  22. inspect_ai/_view/www/App.css +31 -1
  23. inspect_ai/_view/www/dist/assets/index.css +31 -1
  24. inspect_ai/_view/www/dist/assets/index.js +25498 -2044
  25. inspect_ai/_view/www/log-schema.json +32 -2
  26. inspect_ai/_view/www/package.json +2 -0
  27. inspect_ai/_view/www/src/App.mjs +14 -16
  28. inspect_ai/_view/www/src/Types.mjs +1 -2
  29. inspect_ai/_view/www/src/api/Types.ts +133 -0
  30. inspect_ai/_view/www/src/api/{api-browser.mjs → api-browser.ts} +25 -13
  31. inspect_ai/_view/www/src/api/api-http.ts +219 -0
  32. inspect_ai/_view/www/src/api/api-shared.ts +47 -0
  33. inspect_ai/_view/www/src/api/{api-vscode.mjs → api-vscode.ts} +22 -19
  34. inspect_ai/_view/www/src/api/{client-api.mjs → client-api.ts} +93 -53
  35. inspect_ai/_view/www/src/api/index.ts +51 -0
  36. inspect_ai/_view/www/src/api/jsonrpc.ts +225 -0
  37. inspect_ai/_view/www/src/components/ChatView.mjs +133 -43
  38. inspect_ai/_view/www/src/components/DownloadButton.mjs +1 -1
  39. inspect_ai/_view/www/src/components/ExpandablePanel.mjs +0 -4
  40. inspect_ai/_view/www/src/components/LargeModal.mjs +19 -20
  41. inspect_ai/_view/www/src/components/TabSet.mjs +3 -1
  42. inspect_ai/_view/www/src/components/VirtualList.mjs +266 -84
  43. inspect_ai/_view/www/src/index.js +77 -4
  44. inspect_ai/_view/www/src/log/{remoteLogFile.mjs → remoteLogFile.ts} +62 -46
  45. inspect_ai/_view/www/src/navbar/Navbar.mjs +4 -1
  46. inspect_ai/_view/www/src/navbar/SecondaryBar.mjs +19 -10
  47. inspect_ai/_view/www/src/samples/SampleDialog.mjs +5 -1
  48. inspect_ai/_view/www/src/samples/SampleDisplay.mjs +23 -15
  49. inspect_ai/_view/www/src/samples/SampleList.mjs +19 -49
  50. inspect_ai/_view/www/src/samples/SampleScores.mjs +1 -1
  51. inspect_ai/_view/www/src/samples/SampleTranscript.mjs +8 -3
  52. inspect_ai/_view/www/src/samples/SamplesDescriptor.mjs +38 -26
  53. inspect_ai/_view/www/src/samples/SamplesTab.mjs +14 -11
  54. inspect_ai/_view/www/src/samples/SamplesTools.mjs +8 -8
  55. inspect_ai/_view/www/src/samples/tools/SampleFilter.mjs +712 -89
  56. inspect_ai/_view/www/src/samples/tools/SortFilter.mjs +2 -2
  57. inspect_ai/_view/www/src/samples/tools/filters.mjs +260 -87
  58. inspect_ai/_view/www/src/samples/transcript/ErrorEventView.mjs +24 -2
  59. inspect_ai/_view/www/src/samples/transcript/EventPanel.mjs +29 -24
  60. inspect_ai/_view/www/src/samples/transcript/EventRow.mjs +1 -1
  61. inspect_ai/_view/www/src/samples/transcript/InfoEventView.mjs +24 -2
  62. inspect_ai/_view/www/src/samples/transcript/InputEventView.mjs +24 -2
  63. inspect_ai/_view/www/src/samples/transcript/ModelEventView.mjs +31 -10
  64. inspect_ai/_view/www/src/samples/transcript/SampleInitEventView.mjs +24 -2
  65. inspect_ai/_view/www/src/samples/transcript/SampleLimitEventView.mjs +23 -2
  66. inspect_ai/_view/www/src/samples/transcript/ScoreEventView.mjs +24 -2
  67. inspect_ai/_view/www/src/samples/transcript/StepEventView.mjs +33 -3
  68. inspect_ai/_view/www/src/samples/transcript/SubtaskEventView.mjs +25 -2
  69. inspect_ai/_view/www/src/samples/transcript/ToolEventView.mjs +25 -2
  70. inspect_ai/_view/www/src/samples/transcript/TranscriptView.mjs +193 -11
  71. inspect_ai/_view/www/src/samples/transcript/Types.mjs +10 -0
  72. inspect_ai/_view/www/src/samples/transcript/state/StateEventView.mjs +26 -2
  73. inspect_ai/_view/www/src/types/log.d.ts +13 -2
  74. inspect_ai/_view/www/src/utils/Format.mjs +10 -3
  75. inspect_ai/_view/www/src/utils/{Json.mjs → json-worker.ts} +13 -9
  76. inspect_ai/_view/www/src/utils/vscode.ts +36 -0
  77. inspect_ai/_view/www/src/workspace/WorkSpace.mjs +11 -5
  78. inspect_ai/_view/www/vite.config.js +7 -0
  79. inspect_ai/_view/www/yarn.lock +116 -0
  80. inspect_ai/approval/_human/__init__.py +0 -0
  81. inspect_ai/approval/_human/manager.py +1 -1
  82. inspect_ai/approval/_policy.py +12 -6
  83. inspect_ai/log/_log.py +1 -1
  84. inspect_ai/log/_samples.py +16 -0
  85. inspect_ai/log/_transcript.py +4 -1
  86. inspect_ai/model/_call_tools.py +59 -0
  87. inspect_ai/model/_conversation.py +16 -7
  88. inspect_ai/model/_generate_config.py +12 -12
  89. inspect_ai/model/_model.py +117 -18
  90. inspect_ai/model/_model_output.py +22 -2
  91. inspect_ai/model/_openai.py +383 -0
  92. inspect_ai/model/_providers/anthropic.py +152 -55
  93. inspect_ai/model/_providers/azureai.py +21 -21
  94. inspect_ai/model/_providers/bedrock.py +37 -40
  95. inspect_ai/model/_providers/goodfire.py +248 -0
  96. inspect_ai/model/_providers/google.py +46 -54
  97. inspect_ai/model/_providers/groq.py +7 -3
  98. inspect_ai/model/_providers/hf.py +6 -0
  99. inspect_ai/model/_providers/mistral.py +13 -12
  100. inspect_ai/model/_providers/openai.py +51 -218
  101. inspect_ai/model/_providers/openai_o1.py +11 -12
  102. inspect_ai/model/_providers/providers.py +23 -1
  103. inspect_ai/model/_providers/together.py +12 -12
  104. inspect_ai/model/_providers/util/__init__.py +2 -3
  105. inspect_ai/model/_providers/util/hf_handler.py +1 -1
  106. inspect_ai/model/_providers/util/llama31.py +1 -1
  107. inspect_ai/model/_providers/util/util.py +0 -76
  108. inspect_ai/model/_providers/vertex.py +1 -4
  109. inspect_ai/scorer/_metric.py +3 -0
  110. inspect_ai/scorer/_reducer/reducer.py +1 -1
  111. inspect_ai/scorer/_scorer.py +4 -3
  112. inspect_ai/solver/__init__.py +4 -5
  113. inspect_ai/solver/_basic_agent.py +1 -1
  114. inspect_ai/solver/_bridge/__init__.py +3 -0
  115. inspect_ai/solver/_bridge/bridge.py +100 -0
  116. inspect_ai/solver/_bridge/patch.py +170 -0
  117. inspect_ai/solver/_prompt.py +35 -5
  118. inspect_ai/solver/_solver.py +6 -0
  119. inspect_ai/solver/_task_state.py +80 -38
  120. inspect_ai/tool/__init__.py +2 -0
  121. inspect_ai/tool/_tool.py +12 -1
  122. inspect_ai/tool/_tool_call.py +10 -0
  123. inspect_ai/tool/_tool_def.py +16 -5
  124. inspect_ai/tool/_tool_with.py +21 -4
  125. inspect_ai/tool/beta/__init__.py +5 -0
  126. inspect_ai/tool/beta/_computer/__init__.py +3 -0
  127. inspect_ai/tool/beta/_computer/_common.py +133 -0
  128. inspect_ai/tool/beta/_computer/_computer.py +155 -0
  129. inspect_ai/tool/beta/_computer/_computer_split.py +198 -0
  130. inspect_ai/tool/beta/_computer/_resources/Dockerfile +100 -0
  131. inspect_ai/tool/beta/_computer/_resources/README.md +30 -0
  132. inspect_ai/tool/beta/_computer/_resources/entrypoint/entrypoint.sh +18 -0
  133. inspect_ai/tool/beta/_computer/_resources/entrypoint/novnc_startup.sh +20 -0
  134. inspect_ai/tool/beta/_computer/_resources/entrypoint/x11vnc_startup.sh +48 -0
  135. inspect_ai/tool/beta/_computer/_resources/entrypoint/xfce_startup.sh +13 -0
  136. inspect_ai/tool/beta/_computer/_resources/entrypoint/xvfb_startup.sh +48 -0
  137. inspect_ai/tool/beta/_computer/_resources/image_home_dir/Desktop/Firefox Web Browser.desktop +10 -0
  138. inspect_ai/tool/beta/_computer/_resources/image_home_dir/Desktop/Visual Studio Code.desktop +10 -0
  139. inspect_ai/tool/beta/_computer/_resources/image_home_dir/Desktop/XPaint.desktop +10 -0
  140. inspect_ai/tool/beta/_computer/_resources/tool/__init__.py +0 -0
  141. inspect_ai/tool/beta/_computer/_resources/tool/_logger.py +22 -0
  142. inspect_ai/tool/beta/_computer/_resources/tool/_run.py +42 -0
  143. inspect_ai/tool/beta/_computer/_resources/tool/_tool_result.py +33 -0
  144. inspect_ai/tool/beta/_computer/_resources/tool/_x11_client.py +262 -0
  145. inspect_ai/tool/beta/_computer/_resources/tool/computer_tool.py +85 -0
  146. inspect_ai/tool/beta/_computer/_resources/tool/requirements.txt +0 -0
  147. inspect_ai/util/__init__.py +2 -0
  148. inspect_ai/util/_display.py +5 -0
  149. inspect_ai/util/_limit.py +26 -0
  150. inspect_ai/util/_sandbox/docker/docker.py +64 -1
  151. inspect_ai/util/_sandbox/docker/internal.py +3 -1
  152. inspect_ai/util/_sandbox/docker/prereqs.py +1 -1
  153. inspect_ai/util/_sandbox/environment.py +14 -0
  154. {inspect_ai-0.3.58.dist-info → inspect_ai-0.3.60.dist-info}/METADATA +3 -2
  155. {inspect_ai-0.3.58.dist-info → inspect_ai-0.3.60.dist-info}/RECORD +159 -126
  156. inspect_ai/_view/www/src/api/Types.mjs +0 -117
  157. inspect_ai/_view/www/src/api/api-http.mjs +0 -300
  158. inspect_ai/_view/www/src/api/api-shared.mjs +0 -10
  159. inspect_ai/_view/www/src/api/index.mjs +0 -49
  160. inspect_ai/_view/www/src/api/jsonrpc.mjs +0 -208
  161. inspect_ai/_view/www/src/samples/transcript/TranscriptState.mjs +0 -70
  162. inspect_ai/_view/www/src/utils/vscode.mjs +0 -16
  163. {inspect_ai-0.3.58.dist-info → inspect_ai-0.3.60.dist-info}/LICENSE +0 -0
  164. {inspect_ai-0.3.58.dist-info → inspect_ai-0.3.60.dist-info}/WHEEL +0 -0
  165. {inspect_ai-0.3.58.dist-info → inspect_ai-0.3.60.dist-info}/entry_points.txt +0 -0
  166. {inspect_ai-0.3.58.dist-info → inspect_ai-0.3.60.dist-info}/top_level.txt +0 -0

inspect_ai/model/_providers/goodfire.py
@@ -0,0 +1,248 @@
+ import os
+ from typing import Any, List, Literal, get_args
+
+ from goodfire import AsyncClient
+ from goodfire.api.chat.interfaces import ChatMessage as GoodfireChatMessage
+ from goodfire.api.exceptions import InvalidRequestException, RateLimitException
+ from goodfire.variants.variants import SUPPORTED_MODELS, Variant
+ from typing_extensions import override
+
+ from inspect_ai.tool._tool_choice import ToolChoice
+ from inspect_ai.tool._tool_info import ToolInfo
+
+ from .._chat_message import (
+     ChatMessage,
+     ChatMessageAssistant,
+     ChatMessageSystem,
+     ChatMessageTool,
+     ChatMessageUser,
+ )
+ from .._generate_config import GenerateConfig
+ from .._model import ModelAPI
+ from .._model_call import ModelCall
+ from .._model_output import (
+     ChatCompletionChoice,
+     ModelOutput,
+     ModelUsage,
+ )
+ from .util import environment_prerequisite_error, model_base_url
+
+ # Constants
+ GOODFIRE_API_KEY = "GOODFIRE_API_KEY"
+ DEFAULT_BASE_URL = "https://api.goodfire.ai"
+ DEFAULT_MAX_TOKENS = 4096
+ DEFAULT_TEMPERATURE = 1.0  # Standard sampling temperature (baseline)
+ DEFAULT_TOP_P = 1.0  # No nucleus sampling truncation (baseline)
+
+
+ class GoodfireAPI(ModelAPI):
+     """Goodfire API provider.
+
+     This provider implements the Goodfire API for LLM inference. It supports:
+     - Chat completions with standard message formats
+     - Basic parameter controls (temperature, top_p, etc.)
+     - Usage statistics tracking
+     - Stop reason handling
+
+     Does not currently support:
+     - Tool calls
+     - Feature analysis
+     - Streaming responses
+
+     Known limitations:
+     - Limited role support (system/user/assistant only)
+     - Tool messages converted to user messages
+     """
+
+     client: AsyncClient
+     variant: Variant
+     model_args: dict[str, Any]
+
+     def __init__(
+         self,
+         model_name: str,
+         base_url: str | None = None,
+         api_key: str | None = None,
+         config: GenerateConfig = GenerateConfig(),
+         **model_args: Any,
+     ) -> None:
+         """Initialize the Goodfire API provider.
+
+         Args:
+             model_name: Name of the model to use
+             base_url: Optional custom API base URL
+             api_key: Optional API key (will check env vars if not provided)
+             config: Generation config options
+             **model_args: Additional arguments passed to the API
+         """
+         super().__init__(
+             model_name=model_name,
+             base_url=base_url,
+             api_key=api_key,
+             api_key_vars=[GOODFIRE_API_KEY],
+             config=config,
+         )
+
+         # resolve api_key
+         if not self.api_key:
+             self.api_key = os.environ.get(GOODFIRE_API_KEY)
+         if not self.api_key:
+             raise environment_prerequisite_error("Goodfire", GOODFIRE_API_KEY)
+
+         # Validate model name against supported models
+         supported_models = list(get_args(SUPPORTED_MODELS))
+         if self.model_name not in supported_models:
+             raise ValueError(
+                 f"Model {self.model_name} not supported. Supported models: {supported_models}"
+             )
+
+         # Initialize client with minimal configuration
+         base_url_val = model_base_url(base_url, "GOODFIRE_BASE_URL")
+         assert isinstance(base_url_val, str) or base_url_val is None
+
+         # Store model args for use in generate
+         self.model_args = model_args
+
+         self.client = AsyncClient(
+             api_key=self.api_key,
+             base_url=base_url_val or DEFAULT_BASE_URL,
+         )
+
+         # Initialize variant directly with model name
+         self.variant = Variant(self.model_name)  # type: ignore
+
+     def _to_goodfire_message(self, message: ChatMessage) -> GoodfireChatMessage:
+         """Convert an Inspect message to a Goodfire message format.
+
+         Args:
+             message: The message to convert
+
+         Returns:
+             The converted message in Goodfire format
+
+         Raises:
+             ValueError: If the message type is unknown
+         """
+         role: Literal["system", "user", "assistant"] = "user"
+         if isinstance(message, ChatMessageSystem):
+             role = "system"
+         elif isinstance(message, ChatMessageUser):
+             role = "user"
+         elif isinstance(message, ChatMessageAssistant):
+             role = "assistant"
+         elif isinstance(message, ChatMessageTool):
+             role = "user"  # Convert tool messages to user messages
+         else:
+             raise ValueError(f"Unknown message type: {type(message)}")
+
+         content = str(message.content)
+         if isinstance(message, ChatMessageTool):
+             content = f"Tool {message.function}: {content}"
+
+         return GoodfireChatMessage(role=role, content=content)
+
+     def handle_error(self, ex: Exception) -> ModelOutput | Exception:
+         """Handle only errors that need special treatment for retry logic or model limits."""
+         # Handle token/context length errors
+         if isinstance(ex, InvalidRequestException):
+             error_msg = str(ex).lower()
+             if "context length" in error_msg or "max tokens" in error_msg:
+                 return ModelOutput.from_content(
+                     model=self.model_name,
+                     content=str(ex),
+                     stop_reason="model_length",
+                     error=error_msg,
+                 )
+
+         # Let all other errors propagate
+         return ex
+
+     @override
+     def is_rate_limit(self, ex: BaseException) -> bool:
+         """Check if exception is due to rate limiting."""
+         return isinstance(ex, RateLimitException)
+
+     @override
+     def connection_key(self) -> str:
+         """Return key for connection pooling."""
+         return f"goodfire:{self.api_key}"
+
+     @override
+     def max_tokens(self) -> int | None:
+         """Return maximum tokens supported by model."""
+         return DEFAULT_MAX_TOKENS  # Let Goodfire's Variant handle model-specific limits
+
+     async def generate(
+         self,
+         input: List[ChatMessage],
+         tools: List[ToolInfo],
+         tool_choice: ToolChoice,
+         config: GenerateConfig,
+         *,
+         cache: bool = True,
+     ) -> tuple[ModelOutput | Exception, ModelCall]:
+         """Generate output from the model."""
+         # Convert messages and prepare request params
+         messages = [self._to_goodfire_message(msg) for msg in input]
+         # Build request parameters with type hints
+         params: dict[str, Any] = {
+             "model": self.variant.base_model,  # Use base_model instead of stringifying the Variant
+             "messages": messages,
+             "max_completion_tokens": int(config.max_tokens)
+             if config.max_tokens
+             else DEFAULT_MAX_TOKENS,
+             "stream": False,
+         }
+
+         # Add generation parameters from config if not in model_args
+         if "temperature" not in self.model_args and config.temperature is not None:
+             params["temperature"] = float(config.temperature)
+         elif "temperature" not in self.model_args:
+             params["temperature"] = DEFAULT_TEMPERATURE
+
+         if "top_p" not in self.model_args and config.top_p is not None:
+             params["top_p"] = float(config.top_p)
+         elif "top_p" not in self.model_args:
+             params["top_p"] = DEFAULT_TOP_P
+
+         # Add any additional model args (highest priority)
+         api_params = {
+             k: v
+             for k, v in self.model_args.items()
+             if k not in ["api_key", "base_url", "model_args"]
+         }
+         params.update(api_params)
+
+         try:
+             # Use native async client
+             response = await self.client.chat.completions.create(**params)
+             response_dict = response.model_dump()
+
+             output = ModelOutput(
+                 model=self.model_name,
+                 choices=[
+                     ChatCompletionChoice(
+                         message=ChatMessageAssistant(
+                             content=response_dict["choices"][0]["message"]["content"]
+                         ),
+                         stop_reason="stop",
+                     )
+                 ],
+                 usage=ModelUsage(**response_dict["usage"])
+                 if "usage" in response_dict
+                 else None,
+             )
+             model_call = ModelCall.create(request=params, response=response_dict)
+             return (output, model_call)
+         except Exception as ex:
+             result = self.handle_error(ex)
+             model_call = ModelCall.create(
+                 request=params,
+                 response={},  # Empty response for error case
+             )
+             return (result, model_call)
+
+     @property
+     def name(self) -> str:
+         """Get provider name."""
+         return "goodfire"
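
For orientation, here is a minimal sketch of how the new provider would be consumed from Inspect, assuming it is registered under the `goodfire` prefix (matching the `name` property above) and that the illustrative model name appears in Goodfire's `SUPPORTED_MODELS`:

```python
# Hypothetical usage sketch (not part of the diff). Assumes GOODFIRE_API_KEY
# is set in the environment and the model name is in SUPPORTED_MODELS;
# otherwise __init__ above raises ValueError.
import asyncio

from inspect_ai.model import get_model


async def main() -> None:
    model = get_model("goodfire/meta-llama/Meta-Llama-3.1-8B-Instruct")
    output = await model.generate("Briefly explain nucleus sampling.")
    print(output.completion)


asyncio.run(main())
```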

inspect_ai/model/_providers/google.py
@@ -11,7 +11,6 @@ import proto  # type: ignore
  from google.ai.generativelanguage import (
      Blob,
      Candidate,
-     File,
      FunctionCall,
      FunctionCallingConfig,
      FunctionDeclaration,
@@ -29,29 +28,29 @@ from google.api_core.exceptions import (
      TooManyRequests,
  )
  from google.api_core.retry.retry_base import if_transient_error
- from google.generativeai import (  # type: ignore
-     GenerationConfig,
-     GenerativeModel,
-     configure,
-     get_file,
-     upload_file,
- )
- from google.generativeai.types import (  # type: ignore
-     AsyncGenerateContentResponse,
+ from google.generativeai.client import configure
+ from google.generativeai.files import get_file, upload_file
+ from google.generativeai.generative_models import GenerativeModel
+ from google.generativeai.types import (
      ContentDict,
-     HarmBlockThreshold,
-     HarmCategory,
+     GenerationConfig,
      PartDict,
      PartType,
-     SafetySettingDict,
      Tool,
  )
+ from google.generativeai.types.file_types import File
+ from google.generativeai.types.generation_types import AsyncGenerateContentResponse
+ from google.generativeai.types.safety_types import (
+     EasySafetySettingDict,
+     HarmBlockThreshold,
+     HarmCategory,
+ )
  from google.protobuf.json_format import MessageToDict, ParseDict
  from google.protobuf.struct_pb2 import Struct
  from pydantic import JsonValue
  from typing_extensions import override

- from inspect_ai._util.constants import BASE_64_DATA_REMOVED
+ from inspect_ai._util.constants import BASE_64_DATA_REMOVED, NO_CONTENT
  from inspect_ai._util.content import (
      Content,
      ContentAudio,
@@ -89,7 +88,7 @@ logger = getLogger(__name__)

  SAFETY_SETTINGS = "safety_settings"

- DEFAULT_SAFETY_SETTINGS: SafetySettingDict = {
+ DEFAULT_SAFETY_SETTINGS: EasySafetySettingDict = {
      HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
      HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
      HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
@@ -141,7 +140,7 @@ class GoogleAPI(ModelAPI):
          tools: list[ToolInfo],
          tool_choice: ToolChoice,
          config: GenerateConfig,
-     ) -> ModelOutput | tuple[ModelOutput, ModelCall]:
+     ) -> ModelOutput | tuple[ModelOutput | Exception, ModelCall]:
          parameters = GenerationConfig(
              temperature=config.temperature,
              top_p=config.top_p,
@@ -149,11 +148,8 @@
              max_output_tokens=config.max_tokens,
              stop_sequences=config.stop_seqs,
              candidate_count=config.num_choices,
-             seed=config.seed,
              presence_penalty=config.presence_penalty,
              frequency_penalty=config.frequency_penalty,
-             response_logprobs=config.logprobs,
-             logprobs=config.top_logprobs,
          )

          # google-native messages
@@ -176,18 +172,15 @@
                  response=response,
              )

-         # cast to AsyncGenerateContentResponse since we passed stream=False
          try:
-             response = cast(
-                 AsyncGenerateContentResponse,
-                 await self.model.generate_content_async(
-                     contents=contents,
-                     safety_settings=self.safety_settings,
-                     generation_config=parameters,
-                     tools=gemini_tools,
-                     tool_config=gemini_tool_config,
-                 ),
+             response = await self.model.generate_content_async(
+                 contents=contents,
+                 safety_settings=self.safety_settings,
+                 generation_config=parameters,
+                 tools=gemini_tools,
+                 tool_config=gemini_tool_config,
              )
+
          except InvalidArgument as ex:
              return self.handle_invalid_argument(ex), model_call()
@@ -205,15 +198,13 @@
          # return
          return output, model_call()

-     def handle_invalid_argument(self, ex: InvalidArgument) -> ModelOutput:
+     def handle_invalid_argument(self, ex: InvalidArgument) -> ModelOutput | Exception:
          if "size exceeds the limit" in ex.message.lower():
              return ModelOutput.from_content(
                  model=self.model_name, content=ex.message, stop_reason="model_length"
              )
          else:
-             return ModelOutput.from_content(
-                 model=self.model_name, content=ex.message, stop_reason="unknown"
-             )
+             return ex

      @override
      def is_rate_limit(self, ex: BaseException) -> bool:
@@ -231,7 +222,7 @@
  def build_model_call(
      contents: list[ContentDict],
      generation_config: GenerationConfig,
-     safety_settings: SafetySettingDict,
+     safety_settings: EasySafetySettingDict,
      tools: list[Tool] | None,
      tool_config: ToolConfig | None,
      response: AsyncGenerateContentResponse | None,
@@ -248,7 +239,7 @@ def build_model_call(
              if tool_config is not None
              else None,
          ),
-         response=response.to_dict() if response is not None else {},
+         response=response.to_dict() if response is not None else {},  # type: ignore[no-untyped-call]
          filter=model_call_filter,
      )

@@ -269,12 +260,12 @@ def model_call_content(content: ContentDict) -> ContentDict:

  def model_call_part(part: PartType) -> PartType:
      if isinstance(part, proto.Message):
-         return MessageToDict(part._pb)
+         return cast(PartDict, MessageToDict(part._pb))
      elif isinstance(part, dict):
          part = part.copy()
          keys = list(part.keys())
          for key in keys:
-             part[key] = model_call_part(part[key])
+             part[key] = model_call_part(part[key])  # type: ignore[literal-required]
          return part
      else:
          return part
@@ -316,9 +307,6 @@ def consective_tool_message_reducer(
      return messages


- NO_CONTENT = "(no content)"
-
-
  async def content_dict(
      message: ChatMessageUser | ChatMessageAssistant | ChatMessageTool,
  ) -> ContentDict:
@@ -326,13 +314,13 @@
          return ContentDict(
              role="user",
              parts=(
-                 [PartDict(text=message.content or NO_CONTENT)]
+                 [message.content or NO_CONTENT]
                  if isinstance(message.content, str)
                  else [await content_part(content) for content in message.content]
              ),
          )
      elif isinstance(message, ChatMessageAssistant):
-         content_parts: list[Part] = []
+         content_parts: list[PartType] = []
          # tool call parts
          if message.tool_calls is not None:
              content_parts.extend(
@@ -383,9 +371,9 @@ def dict_to_struct(x: dict[str, Any]) -> Struct:

  async def content_part(content: Content | str) -> PartType:
      if isinstance(content, str):
-         return PartDict(text=content or NO_CONTENT)
+         return content or NO_CONTENT
      elif isinstance(content, ContentText):
-         return PartDict(text=content.text or NO_CONTENT)
+         return content.text or NO_CONTENT
      else:
          return await chat_content_to_part(content)

@@ -404,7 +392,9 @@
      messages: list[ContentDict], system_messages: list[ChatMessageSystem]
  ) -> None:
      # create system_parts
-     system_parts = [Part(text=message.content) for message in system_messages]
+     system_parts: list[PartType] = [
+         Part(text=message.content) for message in system_messages
+     ]

      # we want the system messages to be prepended to the first user message
      # (if there is no first user message then prepend one)
@@ -476,6 +466,8 @@ def schema_from_param(param: ToolParam | ToolParams, nullable: bool = False) ->
              return schema_from_param(param.anyOf[0], nullable=True)
          else:
              return Schema(type=Type.TYPE_UNSPECIFIED)
+     elif param.enum:
+         return Schema(type=Type.STRING, format="enum", enum=param.enum)
      else:
          return Schema(type=Type.TYPE_UNSPECIFIED)
@@ -600,14 +592,14 @@ def gapi_should_retry(ex: BaseException) -> bool:

  def parse_safety_settings(
      safety_settings: Any,
- ) -> dict[HarmCategory, HarmBlockThreshold]:
+ ) -> EasySafetySettingDict:
      # ensure we have a dict
      if isinstance(safety_settings, str):
          safety_settings = json.loads(safety_settings)
      if not isinstance(safety_settings, dict):
          raise ValueError(f"{SAFETY_SETTINGS} must be dictionary.")

-     parsed_settings: dict[HarmCategory, HarmBlockThreshold] = {}
+     parsed_settings: EasySafetySettingDict = {}
      for key, value in safety_settings.items():
          if isinstance(key, str):
              key = str_to_harm_category(key)
@@ -623,23 +615,23 @@
      return parsed_settings


- def str_to_harm_category(category: str) -> HarmCategory:
+ def str_to_harm_category(category: str) -> int:
      category = category.upper()
      if "HARASSMENT" in category:
-         return HarmCategory.HARM_CATEGORY_HARASSMENT
+         return cast(int, HarmCategory.HARM_CATEGORY_HARASSMENT)
      elif "HATE_SPEECH" in category:
-         return HarmCategory.HARM_CATEGORY_HATE_SPEECH
+         return cast(int, HarmCategory.HARM_CATEGORY_HATE_SPEECH)
      elif "SEXUALLY_EXPLICIT" in category:
-         return HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT
+         return cast(int, HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT)
      elif "DANGEROUS_CONTENT" in category:
-         return HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT
+         return cast(int, HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT)
      else:
          # NOTE: Although there is an "UNSPECIFIED" category, in the
          # documentation, the API does not accept it.
          raise ValueError(f"Unknown HarmCategory: {category}")


- def str_to_harm_block_threshold(threshold: str) -> HarmBlockThreshold:
+ def str_to_harm_block_threshold(threshold: str) -> int:
      threshold = threshold.upper()
      if "LOW" in threshold:
          return HarmBlockThreshold.BLOCK_LOW_AND_ABOVE
@@ -673,7 +665,7 @@ async def file_for_content(content: ContentAudio | ContentVideo) -> File:
      uploaded_file = files_db.get(content_sha256)
      if uploaded_file:
          try:
-             upload = cast(File, get_file(uploaded_file))
+             upload = get_file(uploaded_file)
              if upload.state.name == "ACTIVE":
                  trace(f"Using uploaded file: {uploaded_file}")
                  return upload
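
One functional change worth calling out above: the new `elif param.enum` branch in `schema_from_param` maps enum-valued tool parameters to a string schema rather than falling through to `TYPE_UNSPECIFIED`. A minimal sketch of the resulting `Schema`, reusing the `google.ai.generativelanguage` types this module already imports (the enum values are illustrative):

```python
from google.ai.generativelanguage import Schema, Type

# Mirrors the new branch: an enum-valued tool parameter becomes a STRING
# schema that carries its allowed values, so Gemini can constrain the call.
units = Schema(type=Type.STRING, format="enum", enum=["celsius", "fahrenheit"])
```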

inspect_ai/model/_providers/groq.py
@@ -27,6 +27,7 @@ from inspect_ai._util.images import file_as_data_uri
  from inspect_ai._util.url import is_http_url
  from inspect_ai.tool import ToolCall, ToolChoice, ToolFunction, ToolInfo

+ from .._call_tools import parse_tool_call
  from .._chat_message import (
      ChatMessage,
      ChatMessageAssistant,
@@ -37,12 +38,15 @@ from .._chat_message import (
  from .._generate_config import GenerateConfig
  from .._model import ModelAPI
  from .._model_call import ModelCall
- from .._model_output import ChatCompletionChoice, ModelOutput, ModelUsage
- from .util import (
+ from .._model_output import (
+     ChatCompletionChoice,
+     ModelOutput,
+     ModelUsage,
      as_stop_reason,
+ )
+ from .util import (
      environment_prerequisite_error,
      model_base_url,
-     parse_tool_call,
  )

  GROQ_API_KEY = "GROQ_API_KEY"

inspect_ai/model/_providers/hf.py
@@ -150,6 +150,12 @@ class HuggingFaceAPI(ModelAPI):
              kwargs["output_logits"] = config.logprobs
          if "return_dict_in_generate" in kwargs:
              assert kwargs["return_dict_in_generate"]
+         if config.stop_seqs is not None:
+             from transformers.generation import StopStringCriteria  # type: ignore
+
+             stopping_criteria = [StopStringCriteria(self.tokenizer, config.stop_seqs)]
+             kwargs["stopping_criteria"] = stopping_criteria
+
          kwargs["return_dict_in_generate"] = True
          generator = functools.partial(self.model.generate, **kwargs)

inspect_ai/model/_providers/mistral.py
@@ -40,11 +40,13 @@ from typing_extensions import override
  # https://github.com/mistralai/client-python/blob/main/MIGRATION.md
  from inspect_ai._util.constants import (
      DEFAULT_TIMEOUT,
+     NO_CONTENT,
  )
  from inspect_ai._util.content import Content, ContentImage, ContentText
  from inspect_ai._util.images import file_as_data_uri
  from inspect_ai.tool import ToolCall, ToolChoice, ToolFunction, ToolInfo

+ from .._call_tools import parse_tool_call
  from .._chat_message import (
      ChatMessage,
      ChatMessageAssistant,
@@ -58,7 +60,7 @@ from .._model_output import (
      ModelUsage,
      StopReason,
  )
- from .util import environment_prerequisite_error, model_base_url, parse_tool_call
+ from .util import environment_prerequisite_error, model_base_url

  AZURE_MISTRAL_API_KEY = "AZURE_MISTRAL_API_KEY"
  AZUREAI_MISTRAL_API_KEY = "AZUREAI_MISTRAL_API_KEY"
@@ -122,7 +124,7 @@ class MistralAPI(ModelAPI):
          tools: list[ToolInfo],
          tool_choice: ToolChoice,
          config: GenerateConfig,
-     ) -> ModelOutput | tuple[ModelOutput, ModelCall]:
+     ) -> ModelOutput | tuple[ModelOutput | Exception, ModelCall]:
          # build request
          request: dict[str, Any] = dict(
              model=self.model_name,
@@ -146,7 +148,7 @@
              response = await self.client.chat.complete_async(**request)
          except SDKError as ex:
              if ex.status_code == 400:
-                 return self.handle_bad_request(ex)
+                 return self.handle_bad_request(ex), mistral_model_call(request, None)
              else:
                  raise ex

@@ -181,25 +183,27 @@
      def connection_key(self) -> str:
          return str(self.api_key)

-     def handle_bad_request(self, ex: SDKError) -> ModelOutput:
+     def handle_bad_request(self, ex: SDKError) -> ModelOutput | Exception:
+         body = json.loads(ex.body)
+         content = body.get("message", ex.body)
          if "maximum context length" in ex.body:
-             body = json.loads(ex.body)
-             content = body.get("message", ex.body)
              return ModelOutput.from_content(
                  model=self.model_name, content=content, stop_reason="model_length"
              )
          else:
-             raise ex
+             return ex


  def mistral_model_call(
-     request: dict[str, Any], response: MistralChatCompletionResponse
+     request: dict[str, Any], response: MistralChatCompletionResponse | None
  ) -> ModelCall:
      request = request.copy()
      request.update(messages=[message.model_dump() for message in request["messages"]])
      if request.get("tools", None) is not None:
          request["tools"] = [tool.model_dump() for tool in request["tools"]]
-     return ModelCall(request=request, response=response.model_dump())
+     return ModelCall(
+         request=request, response=response.model_dump() if response else {}
+     )


  def mistral_chat_tools(tools: list[ToolInfo]) -> list[MistralTool]:
@@ -326,9 +330,6 @@
      )


- NO_CONTENT = "(no content)"
-
-
  async def mistral_message_content(
      content: str | list[Content],
  ) -> str | list[ContentChunk]: