inspect-ai 0.3.59__py3-none-any.whl → 0.3.60__py3-none-any.whl
This diff compares two publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in the public registry.
- inspect_ai/_cli/eval.py +0 -7
- inspect_ai/_display/textual/widgets/samples.py +1 -1
- inspect_ai/_eval/eval.py +10 -1
- inspect_ai/_eval/loader.py +79 -19
- inspect_ai/_eval/registry.py +6 -0
- inspect_ai/_eval/score.py +2 -1
- inspect_ai/_eval/task/results.py +6 -5
- inspect_ai/_eval/task/run.py +11 -11
- inspect_ai/_view/www/dist/assets/index.js +262 -303
- inspect_ai/_view/www/src/App.mjs +6 -6
- inspect_ai/_view/www/src/Types.mjs +1 -1
- inspect_ai/_view/www/src/api/Types.ts +133 -0
- inspect_ai/_view/www/src/api/{api-browser.mjs → api-browser.ts} +25 -13
- inspect_ai/_view/www/src/api/api-http.ts +219 -0
- inspect_ai/_view/www/src/api/api-shared.ts +47 -0
- inspect_ai/_view/www/src/api/{api-vscode.mjs → api-vscode.ts} +22 -19
- inspect_ai/_view/www/src/api/{client-api.mjs → client-api.ts} +93 -53
- inspect_ai/_view/www/src/api/index.ts +51 -0
- inspect_ai/_view/www/src/api/jsonrpc.ts +225 -0
- inspect_ai/_view/www/src/components/DownloadButton.mjs +1 -1
- inspect_ai/_view/www/src/index.js +2 -2
- inspect_ai/_view/www/src/log/{remoteLogFile.mjs → remoteLogFile.ts} +62 -46
- inspect_ai/_view/www/src/navbar/Navbar.mjs +1 -1
- inspect_ai/_view/www/src/navbar/SecondaryBar.mjs +1 -1
- inspect_ai/_view/www/src/samples/SampleList.mjs +1 -1
- inspect_ai/_view/www/src/samples/SampleScores.mjs +1 -1
- inspect_ai/_view/www/src/samples/SamplesDescriptor.mjs +14 -14
- inspect_ai/_view/www/src/samples/SamplesTab.mjs +10 -10
- inspect_ai/_view/www/src/samples/tools/SortFilter.mjs +2 -2
- inspect_ai/_view/www/src/utils/{Json.mjs → json-worker.ts} +1 -3
- inspect_ai/_view/www/src/utils/vscode.ts +36 -0
- inspect_ai/_view/www/src/workspace/WorkSpace.mjs +1 -1
- inspect_ai/approval/_human/manager.py +1 -1
- inspect_ai/model/_call_tools.py +55 -0
- inspect_ai/model/_conversation.py +1 -4
- inspect_ai/model/_generate_config.py +2 -8
- inspect_ai/model/_model_output.py +15 -0
- inspect_ai/model/_openai.py +383 -0
- inspect_ai/model/_providers/anthropic.py +52 -11
- inspect_ai/model/_providers/azureai.py +1 -1
- inspect_ai/model/_providers/goodfire.py +248 -0
- inspect_ai/model/_providers/groq.py +7 -3
- inspect_ai/model/_providers/hf.py +6 -0
- inspect_ai/model/_providers/mistral.py +2 -1
- inspect_ai/model/_providers/openai.py +36 -202
- inspect_ai/model/_providers/openai_o1.py +2 -4
- inspect_ai/model/_providers/providers.py +22 -0
- inspect_ai/model/_providers/together.py +4 -4
- inspect_ai/model/_providers/util/__init__.py +2 -3
- inspect_ai/model/_providers/util/hf_handler.py +1 -1
- inspect_ai/model/_providers/util/llama31.py +1 -1
- inspect_ai/model/_providers/util/util.py +0 -76
- inspect_ai/scorer/_metric.py +3 -0
- inspect_ai/scorer/_scorer.py +2 -1
- inspect_ai/solver/__init__.py +2 -0
- inspect_ai/solver/_basic_agent.py +1 -1
- inspect_ai/solver/_bridge/__init__.py +3 -0
- inspect_ai/solver/_bridge/bridge.py +100 -0
- inspect_ai/solver/_bridge/patch.py +170 -0
- inspect_ai/solver/_solver.py +6 -0
- inspect_ai/util/_display.py +5 -0
- inspect_ai/util/_sandbox/docker/prereqs.py +1 -1
- {inspect_ai-0.3.59.dist-info → inspect_ai-0.3.60.dist-info}/METADATA +3 -2
- {inspect_ai-0.3.59.dist-info → inspect_ai-0.3.60.dist-info}/RECORD +68 -63
- inspect_ai/_view/www/src/api/Types.mjs +0 -117
- inspect_ai/_view/www/src/api/api-http.mjs +0 -300
- inspect_ai/_view/www/src/api/api-shared.mjs +0 -10
- inspect_ai/_view/www/src/api/index.mjs +0 -49
- inspect_ai/_view/www/src/api/jsonrpc.mjs +0 -208
- inspect_ai/_view/www/src/utils/vscode.mjs +0 -16
- {inspect_ai-0.3.59.dist-info → inspect_ai-0.3.60.dist-info}/LICENSE +0 -0
- {inspect_ai-0.3.59.dist-info → inspect_ai-0.3.60.dist-info}/WHEEL +0 -0
- {inspect_ai-0.3.59.dist-info → inspect_ai-0.3.60.dist-info}/entry_points.txt +0 -0
- {inspect_ai-0.3.59.dist-info → inspect_ai-0.3.60.dist-info}/top_level.txt +0 -0
inspect_ai/model/_openai.py (new file):

@@ -0,0 +1,383 @@
+import json
+from typing import Literal
+
+from openai.types.chat import (
+    ChatCompletion,
+    ChatCompletionAssistantMessageParam,
+    ChatCompletionContentPartImageParam,
+    ChatCompletionContentPartInputAudioParam,
+    ChatCompletionContentPartParam,
+    ChatCompletionContentPartRefusalParam,
+    ChatCompletionContentPartTextParam,
+    ChatCompletionDeveloperMessageParam,
+    ChatCompletionMessage,
+    ChatCompletionMessageParam,
+    ChatCompletionMessageToolCall,
+    ChatCompletionMessageToolCallParam,
+    ChatCompletionNamedToolChoiceParam,
+    ChatCompletionSystemMessageParam,
+    ChatCompletionToolChoiceOptionParam,
+    ChatCompletionToolMessageParam,
+    ChatCompletionToolParam,
+    ChatCompletionUserMessageParam,
+)
+from openai.types.chat.chat_completion import Choice, ChoiceLogprobs
+from openai.types.chat.chat_completion_message_tool_call import Function
+from openai.types.completion_usage import CompletionUsage
+from openai.types.shared_params.function_definition import FunctionDefinition
+
+from inspect_ai._util.content import Content, ContentAudio, ContentImage, ContentText
+from inspect_ai._util.images import file_as_data_uri
+from inspect_ai._util.url import is_http_url
+from inspect_ai.model._call_tools import parse_tool_call
+from inspect_ai.model._model_output import ChatCompletionChoice, Logprobs
+from inspect_ai.tool import ToolCall, ToolChoice, ToolFunction, ToolInfo
+
+from ._chat_message import (
+    ChatMessage,
+    ChatMessageAssistant,
+    ChatMessageSystem,
+    ChatMessageTool,
+    ChatMessageUser,
+)
+from ._model_output import ModelUsage, StopReason, as_stop_reason
+
+
+def is_o1(name: str) -> bool:
+    return name.startswith("o1")
+
+
+def is_o1_full(name: str) -> bool:
+    return is_o1(name) and not is_o1_mini(name) and not is_o1_preview(name)
+
+
+def is_o1_mini(name: str) -> bool:
+    return name.startswith("o1-mini")
+
+
+def is_o1_preview(name: str) -> bool:
+    return name.startswith("o1-preview")
+
+
+def openai_chat_tool_call(tool_call: ToolCall) -> ChatCompletionMessageToolCall:
+    return ChatCompletionMessageToolCall(
+        type="function",
+        id=tool_call.id,
+        function=Function(
+            name=tool_call.function, arguments=json.dumps(tool_call.arguments)
+        ),
+    )
+
+
+def openai_chat_tool_call_param(
+    tool_call: ToolCall,
+) -> ChatCompletionMessageToolCallParam:
+    return ChatCompletionMessageToolCallParam(
+        id=tool_call.id,
+        function=dict(
+            name=tool_call.function, arguments=json.dumps(tool_call.arguments)
+        ),
+        type=tool_call.type,
+    )
+
+
+async def openai_chat_completion_part(
+    content: Content,
+) -> ChatCompletionContentPartParam:
+    if content.type == "text":
+        return ChatCompletionContentPartTextParam(type="text", text=content.text)
+    elif content.type == "image":
+        # API takes URL or base64 encoded file. If it's a remote file or
+        # data URL leave it alone, otherwise encode it
+        image_url = content.image
+        detail = content.detail
+
+        if not is_http_url(image_url):
+            image_url = await file_as_data_uri(image_url)
+
+        return ChatCompletionContentPartImageParam(
+            type="image_url",
+            image_url=dict(url=image_url, detail=detail),
+        )
+    elif content.type == "audio":
+        audio_data = await file_as_data_uri(content.audio)
+
+        return ChatCompletionContentPartInputAudioParam(
+            type="input_audio", input_audio=dict(data=audio_data, format=content.format)
+        )
+
+    else:
+        raise RuntimeError(
+            "Video content is not currently supported by Open AI chat models."
+        )
+
+
+async def openai_chat_message(
+    message: ChatMessage, model: str
+) -> ChatCompletionMessageParam:
+    if message.role == "system":
+        if is_o1(model):
+            return ChatCompletionDeveloperMessageParam(
+                role="developer", content=message.text
+            )
+        else:
+            return ChatCompletionSystemMessageParam(
+                role=message.role, content=message.text
+            )
+    elif message.role == "user":
+        return ChatCompletionUserMessageParam(
+            role=message.role,
+            content=(
+                message.content
+                if isinstance(message.content, str)
+                else [
+                    await openai_chat_completion_part(content)
+                    for content in message.content
+                ]
+            ),
+        )
+    elif message.role == "assistant":
+        if message.tool_calls:
+            return ChatCompletionAssistantMessageParam(
+                role=message.role,
+                content=message.text,
+                tool_calls=[
+                    openai_chat_tool_call_param(call) for call in message.tool_calls
+                ],
+            )
+        else:
+            return ChatCompletionAssistantMessageParam(
+                role=message.role, content=message.text
+            )
+    elif message.role == "tool":
+        return ChatCompletionToolMessageParam(
+            role=message.role,
+            content=(
+                f"Error: {message.error.message}" if message.error else message.text
+            ),
+            tool_call_id=str(message.tool_call_id),
+        )
+    else:
+        raise ValueError(f"Unexpected message role {message.role}")
+
+
+async def openai_chat_messages(
+    messages: list[ChatMessage], model: str
+) -> list[ChatCompletionMessageParam]:
+    return [await openai_chat_message(message, model) for message in messages]
+
+
+def openai_chat_choices(choices: list[ChatCompletionChoice]) -> list[Choice]:
+    oai_choices: list[Choice] = []
+
+    for index, choice in enumerate(choices):
+        if isinstance(choice.message.content, str):
+            content = choice.message.content
+        else:
+            content = "\n".join(
+                [c.text for c in choice.message.content if c.type == "text"]
+            )
+        if choice.message.tool_calls:
+            tool_calls = [openai_chat_tool_call(tc) for tc in choice.message.tool_calls]
+        else:
+            tool_calls = None
+        message = ChatCompletionMessage(
+            role="assistant", content=content, tool_calls=tool_calls
+        )
+        oai_choices.append(
+            Choice(
+                finish_reason=openai_finish_reason(choice.stop_reason),
+                index=index,
+                message=message,
+                logprobs=ChoiceLogprobs(**choice.logprobs.model_dump())
+                if choice.logprobs is not None
+                else None,
+            )
+        )
+
+    return oai_choices
+
+
+def openai_completion_usage(usage: ModelUsage) -> CompletionUsage:
+    return CompletionUsage(
+        completion_tokens=usage.output_tokens,
+        prompt_tokens=usage.input_tokens,
+        total_tokens=usage.total_tokens,
+    )
+
+
+def openai_finish_reason(
+    stop_reason: StopReason,
+) -> Literal["stop", "length", "tool_calls", "content_filter", "function_call"]:
+    match stop_reason:
+        case "stop" | "tool_calls" | "content_filter":
+            return stop_reason
+        case "model_length":
+            return "length"
+        case _:
+            return "stop"
+
+
+def openai_chat_tool_param(tool: ToolInfo) -> ChatCompletionToolParam:
+    function = FunctionDefinition(
+        name=tool.name,
+        description=tool.description,
+        parameters=tool.parameters.model_dump(exclude_none=True),
+    )
+    return ChatCompletionToolParam(type="function", function=function)
+
+
+def openai_chat_tools(tools: list[ToolInfo]) -> list[ChatCompletionToolParam]:
+    return [openai_chat_tool_param(tool) for tool in tools]
+
+
+def openai_chat_tool_choice(
+    tool_choice: ToolChoice,
+) -> ChatCompletionToolChoiceOptionParam:
+    if isinstance(tool_choice, ToolFunction):
+        return ChatCompletionNamedToolChoiceParam(
+            type="function", function=dict(name=tool_choice.name)
+        )
+    # openai supports 'any' via the 'required' keyword
+    elif tool_choice == "any":
+        return "required"
+    else:
+        return tool_choice
+
+
+def chat_tool_calls_from_openai(
+    message: ChatCompletionMessage, tools: list[ToolInfo]
+) -> list[ToolCall] | None:
+    if message.tool_calls:
+        return [
+            parse_tool_call(call.id, call.function.name, call.function.arguments, tools)
+            for call in message.tool_calls
+        ]
+    else:
+        return None
+
+
+def chat_messages_from_openai(
+    messages: list[ChatCompletionMessageParam],
+) -> list[ChatMessage]:
+    # track tool names by id
+    tool_names: dict[str, str] = {}
+
+    chat_messages: list[ChatMessage] = []
+
+    for message in messages:
+        if message["role"] == "system" or message["role"] == "developer":
+            sys_content = message["content"]
+            if isinstance(sys_content, str):
+                chat_messages.append(ChatMessageSystem(content=sys_content))
+            else:
+                chat_messages.append(
+                    ChatMessageSystem(
+                        content=[content_from_openai(c) for c in sys_content]
+                    )
+                )
+        elif message["role"] == "user":
+            user_content = message["content"]
+            if isinstance(user_content, str):
+                chat_messages.append(ChatMessageUser(content=user_content))
+            else:
+                chat_messages.append(
+                    ChatMessageUser(
+                        content=[content_from_openai(c) for c in user_content]
+                    )
+                )
+        elif message["role"] == "assistant":
+            # resolve content
+            asst_content = message["content"]
+            if isinstance(asst_content, str):
+                content: str | list[Content] = asst_content
+            elif asst_content is None:
+                content = message.get("refusal", None) or ""
+            else:
+                content = [content_from_openai(c) for c in asst_content]
+
+            # return message
+            if "tool_calls" in message:
+                tool_calls: list[ToolCall] = []
+                for tc in message["tool_calls"]:
+                    tool_calls.append(tool_call_from_openai(tc))
+                    tool_names[tc["id"]] = tc["function"]["name"]
+
+            else:
+                tool_calls = []
+            chat_messages.append(
+                ChatMessageAssistant(content=content, tool_calls=tool_calls or None)
+            )
+        elif message["role"] == "tool":
+            tool_content = message.get("content", None) or ""
+            if isinstance(tool_content, str):
+                content = tool_content
+            else:
+                content = [content_from_openai(c) for c in tool_content]
+            chat_messages.append(
+                ChatMessageTool(
+                    content=content,
+                    tool_call_id=message["tool_call_id"],
+                    function=tool_names.get(message["tool_call_id"], ""),
+                )
+            )
+        else:
+            raise ValueError(f"Unexpected message param type: {type(message)}")
+
+    return chat_messages
+
+
+def tool_call_from_openai(tool_call: ChatCompletionMessageToolCallParam) -> ToolCall:
+    return parse_tool_call(
+        tool_call["id"],
+        tool_call["function"]["name"],
+        tool_call["function"]["arguments"],
+    )
+
+
+def content_from_openai(
+    content: ChatCompletionContentPartParam | ChatCompletionContentPartRefusalParam,
+) -> Content:
+    if content["type"] == "text":
+        return ContentText(text=content["text"])
+    elif content["type"] == "image_url":
+        return ContentImage(
+            image=content["image_url"]["url"], detail=content["image_url"]["detail"]
+        )
+    elif content["type"] == "input_audio":
+        return ContentAudio(
+            audio=content["input_audio"]["data"],
+            format=content["input_audio"]["format"],
+        )
+    elif content["type"] == "refusal":
+        return ContentText(text=content["refusal"])
+
+
+def chat_message_assistant_from_openai(
+    message: ChatCompletionMessage, tools: list[ToolInfo]
+) -> ChatMessageAssistant:
+    refusal = getattr(message, "refusal", None)
+    return ChatMessageAssistant(
+        content=refusal or message.content or "",
+        source="generate",
+        tool_calls=chat_tool_calls_from_openai(message, tools),
+    )
+
+
+def chat_choices_from_openai(
+    response: ChatCompletion, tools: list[ToolInfo]
+) -> list[ChatCompletionChoice]:
+    choices = list(response.choices)
+    choices.sort(key=lambda c: c.index)
+    return [
+        ChatCompletionChoice(
+            message=chat_message_assistant_from_openai(choice.message, tools),
+            stop_reason=as_stop_reason(choice.finish_reason),
+            logprobs=(
+                Logprobs(**choice.logprobs.model_dump())
+                if choice.logprobs is not None
+                else None
+            ),
+        )
+        for choice in choices
+    ]
inspect_ai/model/_providers/anthropic.py:

@@ -14,10 +14,13 @@ from anthropic import (
     APIConnectionError,
     AsyncAnthropic,
     AsyncAnthropicBedrock,
+    AsyncAnthropicVertex,
     BadRequestError,
     InternalServerError,
+    NotGiven,
     RateLimitError,
 )
+from anthropic._types import Body
 from anthropic.types import (
     ImageBlockParam,
     Message,
@@ -64,15 +67,25 @@ class AnthropicAPI(ModelAPI):
         base_url: str | None = None,
         api_key: str | None = None,
         config: GenerateConfig = GenerateConfig(),
-        bedrock: bool = False,
         **model_args: Any,
     ):
         # extract any service prefix from model name
         parts = model_name.split("/")
         if len(parts) > 1:
-            service = parts[0]
-            bedrock = service == "bedrock"
+            self.service: str | None = parts[0]
             model_name = "/".join(parts[1:])
+        else:
+            self.service = None
+
+        # collect generate model_args (then delete them so we can pass the rest on)
+        def collect_model_arg(name: str) -> Any | None:
+            nonlocal model_args
+            value = model_args.get(name, None)
+            if value is not None:
+                model_args.pop(name)
+            return value
+
+        self.extra_body: Body | None = collect_model_arg("extra_body")
 
         # call super
         super().__init__(
@@ -84,7 +97,7 @@ class AnthropicAPI(ModelAPI):
         )
 
         # create client
-        if bedrock:
+        if self.is_bedrock():
             base_url = model_base_url(
                 base_url, ["ANTHROPIC_BEDROCK_BASE_URL", "BEDROCK_ANTHROPIC_BASE_URL"]
             )
@@ -95,7 +108,9 @@ class AnthropicAPI(ModelAPI):
             if base_region is None:
                 aws_region = os.environ.get("AWS_DEFAULT_REGION", None)
 
-            self.client: AsyncAnthropic | AsyncAnthropicBedrock = AsyncAnthropicBedrock(
+            self.client: (
+                AsyncAnthropic | AsyncAnthropicBedrock | AsyncAnthropicVertex
+            ) = AsyncAnthropicBedrock(
                 base_url=base_url,
                 max_retries=(
                     config.max_retries if config.max_retries else DEFAULT_MAX_RETRIES
@@ -103,6 +118,21 @@ class AnthropicAPI(ModelAPI):
                 aws_region=aws_region,
                 **model_args,
             )
+        elif self.is_vertex():
+            base_url = model_base_url(
+                base_url, ["ANTHROPIC_VERTEX_BASE_URL", "VERTEX_ANTHROPIC_BASE_URL"]
+            )
+            region = os.environ.get("ANTHROPIC_VERTEX_REGION", NotGiven())
+            project_id = os.environ.get("ANTHROPIC_VERTEX_PROJECT_ID", NotGiven())
+            self.client = AsyncAnthropicVertex(
+                region=region,
+                project_id=project_id,
+                base_url=base_url,
+                max_retries=(
+                    config.max_retries if config.max_retries else DEFAULT_MAX_RETRIES
+                ),
+                **model_args,
+            )
         else:
             # resolve api_key
             if not self.api_key:
@@ -119,6 +149,12 @@ class AnthropicAPI(ModelAPI):
                 **model_args,
             )
 
+    def is_bedrock(self) -> bool:
+        return self.service == "bedrock"
+
+    def is_vertex(self) -> bool:
+        return self.service == "vertex"
+
     async def generate(
         self,
         input: list[ChatMessage],
@@ -163,6 +199,10 @@ class AnthropicAPI(ModelAPI):
         if computer_use:
             request["extra_headers"] = {"anthropic-beta": "computer-use-2024-10-22"}
 
+        # extra_body
+        if self.extra_body is not None:
+            request["extra_body"] = self.extra_body
+
         # make request
         message = await self.client.messages.create(**request, stream=False)
 
@@ -466,6 +506,12 @@ def message_tool_choice(tool_choice: ToolChoice) -> message_create_params.ToolCh
 
 
 async def message_param(message: ChatMessage) -> MessageParam:
+    # if content is empty that is going to result in an error when we replay
+    # this message to claude, so in that case insert a NO_CONTENT message
+    if isinstance(message.content, list) and len(message.content) == 0:
+        message = message.model_copy()
+        message.content = [ContentText(text=NO_CONTENT)]
+
     # no system role for anthropic (this is more like an assertion,
     # as these should have already been filtered out)
     if message.role == "system":
@@ -507,7 +553,7 @@ async def message_param(message: ChatMessage) -> MessageParam:
     elif message.role == "assistant" and message.tool_calls:
         # first include content (claude <thinking>)
         tools_content: list[TextBlockParam | ImageBlockParam | ToolUseBlockParam] = (
-            [TextBlockParam(type="text", text=message.content)]
+            [TextBlockParam(type="text", text=message.content or NO_CONTENT)]
             if isinstance(message.content, str)
             else (
                 [(await message_param_content(content)) for content in message.content]
@@ -576,11 +622,6 @@ def model_output_from_message(message: Message, tools: list[ToolInfo]) -> ModelO
         )
     )
 
-    # if content is empty that is going to result in an error when we replay
-    # this message to claude, so in that case insert a NO_CONTENT message
-    if len(content) == 0:
-        content = [ContentText(text=NO_CONTENT)]
-
     # resolve choice
     choice = ChatCompletionChoice(
         message=ChatMessageAssistant(
inspect_ai/model/_providers/azureai.py:

@@ -37,6 +37,7 @@ from inspect_ai.tool import ToolChoice, ToolInfo
 from inspect_ai.tool._tool_call import ToolCall
 from inspect_ai.tool._tool_choice import ToolFunction
 
+from .._call_tools import parse_tool_call
 from .._chat_message import (
     ChatMessage,
     ChatMessageAssistant,
@@ -60,7 +61,6 @@ from .util import (
 )
 from .util.chatapi import ChatAPIHandler
 from .util.llama31 import Llama31Handler
-from .util.util import parse_tool_call
 
 AZUREAI_API_KEY = "AZUREAI_API_KEY"
 AZUREAI_ENDPOINT_KEY = "AZUREAI_ENDPOINT_KEY"