inspect-ai 0.3.68__py3-none-any.whl → 0.3.69__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- inspect_ai/_display/plain/display.py +9 -11
- inspect_ai/_display/textual/app.py +3 -4
- inspect_ai/_display/textual/widgets/samples.py +43 -8
- inspect_ai/_util/interrupt.py +9 -0
- inspect_ai/_util/logger.py +4 -0
- inspect_ai/_util/text.py +288 -1
- inspect_ai/_view/www/dist/assets/index.js +1 -1
- inspect_ai/_view/www/src/samples/descriptor/score/ObjectScoreDescriptor.tsx +1 -1
- inspect_ai/log/_samples.py +0 -4
- inspect_ai/model/_model.py +3 -0
- inspect_ai/model/_providers/google.py +356 -302
- inspect_ai/model/_providers/mistral.py +10 -8
- inspect_ai/model/_providers/providers.py +5 -5
- inspect_ai/solver/_plan.py +3 -0
- inspect_ai/solver/_solver.py +3 -0
- inspect_ai/solver/_task_state.py +3 -1
- inspect_ai/util/_sandbox/docker/cleanup.py +8 -3
- inspect_ai/util/_sandbox/docker/compose.py +5 -9
- inspect_ai/util/_sandbox/docker/docker.py +14 -2
- inspect_ai/util/_sandbox/docker/util.py +10 -1
- inspect_ai/util/_sandbox/self_check.py +2 -1
- inspect_ai/util/_subprocess.py +4 -1
- {inspect_ai-0.3.68.dist-info → inspect_ai-0.3.69.dist-info}/METADATA +3 -3
- {inspect_ai-0.3.68.dist-info → inspect_ai-0.3.69.dist-info}/RECORD +28 -27
- {inspect_ai-0.3.68.dist-info → inspect_ai-0.3.69.dist-info}/LICENSE +0 -0
- {inspect_ai-0.3.68.dist-info → inspect_ai-0.3.69.dist-info}/WHEEL +0 -0
- {inspect_ai-0.3.68.dist-info → inspect_ai-0.3.69.dist-info}/entry_points.txt +0 -0
- {inspect_ai-0.3.68.dist-info → inspect_ai-0.3.69.dist-info}/top_level.txt +0 -0
@@ -2,103 +2,96 @@ import asyncio
 import functools
 import hashlib
 import json
+import os
 from copy import copy
 from io import BytesIO
 from logging import getLogger
-from typing import Any, cast
+from typing import Any

-import proto  # type: ignore
-from google.ai.generativelanguage import (
-    Blob,
+# SDK Docs: https://googleapis.github.io/python-genai/
+from google.genai import Client  # type: ignore
+from google.genai.errors import APIError, ClientError  # type: ignore
+from google.genai.types import (  # type: ignore
     Candidate,
-    FunctionCall,
+    Content,
+    File,
+    FinishReason,
     FunctionCallingConfig,
     FunctionDeclaration,
     FunctionResponse,
+    GenerateContentConfig,
+    GenerateContentResponse,
+    GenerateContentResponsePromptFeedback,
+    GenerateContentResponseUsageMetadata,
+    GenerationConfig,
+    HarmBlockThreshold,
+    HarmCategory,
     Part,
+    SafetySetting,
+    SafetySettingDict,
     Schema,
+    Tool,
     ToolConfig,
     Type,
 )
-from google.api_core.exceptions import (
-    GatewayTimeout,
-    InternalServerError,
-    InvalidArgument,
-    ServiceUnavailable,
-    TooManyRequests,
-)
-from google.api_core.retry.retry_base import if_transient_error
-from google.generativeai.client import configure
-from google.generativeai.files import get_file, upload_file
-from google.generativeai.generative_models import GenerativeModel
-from google.generativeai.types import (
-    ContentDict,
-    GenerationConfig,
-    PartDict,
-    PartType,
-    Tool,
-)
-from google.generativeai.types.file_types import File
-from google.generativeai.types.generation_types import AsyncGenerateContentResponse
-from google.generativeai.types.safety_types import (
-    EasySafetySettingDict,
-    HarmBlockThreshold,
-    HarmCategory,
-)
-from google.protobuf.json_format import MessageToDict, ParseDict
-from google.protobuf.struct_pb2 import Struct
 from pydantic import JsonValue
 from typing_extensions import override

 from inspect_ai._util.constants import BASE_64_DATA_REMOVED, NO_CONTENT
+from inspect_ai._util.content import Content as InspectContent
 from inspect_ai._util.content import (
-    Content,
     ContentAudio,
     ContentImage,
     ContentText,
     ContentVideo,
 )
+from inspect_ai._util.error import PrerequisiteError
 from inspect_ai._util.images import file_as_data
 from inspect_ai._util.kvstore import inspect_kvstore
 from inspect_ai._util.trace import trace_message
-from inspect_ai.tool import ToolCall, ToolChoice, ToolFunction, ToolInfo, ToolParam, ToolParams
-
-from .._chat_message import (
+from inspect_ai.model import (
+    ChatCompletionChoice,
     ChatMessage,
     ChatMessageAssistant,
-    ChatMessageSystem,
     ChatMessageTool,
     ChatMessageUser,
-)
-from .._generate_config import GenerateConfig
-from .._model import ModelAPI
-from .._model_call import ModelCall
-from .._model_output import (
-    ChatCompletionChoice,
+    GenerateConfig,
     Logprob,
     Logprobs,
+    ModelAPI,
     ModelOutput,
     ModelUsage,
     StopReason,
     TopLogprob,
 )
-from .util import model_base_url
+from inspect_ai.model._model_call import ModelCall
+from inspect_ai.model._providers.util import model_base_url
+from inspect_ai.tool import (
+    ToolCall,
+    ToolChoice,
+    ToolFunction,
+    ToolInfo,
+    ToolParam,
+    ToolParams,
+)

 logger = getLogger(__name__)

-SAFETY_SETTINGS = "safety_settings"

-DEFAULT_SAFETY_SETTINGS = {
-    HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
-    HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
-    HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
+GOOGLE_API_KEY = "GOOGLE_API_KEY"
+VERTEX_API_KEY = "VERTEX_API_KEY"
+
+SAFETY_SETTINGS = "safety_settings"
+DEFAULT_SAFETY_SETTINGS = {
+    HarmCategory.HARM_CATEGORY_CIVIC_INTEGRITY: HarmBlockThreshold.BLOCK_NONE,
     HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
+    HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
+    HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
+    HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
 }

-GOOGLE_API_KEY = "GOOGLE_API_KEY"
-

-class GoogleAPI(ModelAPI):
+class GoogleGenAIAPI(ModelAPI):
     def __init__(
         self,
         model_name: str,
@@ -111,11 +104,11 @@ class GoogleAPI(ModelAPI):
             model_name=model_name,
             base_url=base_url,
             api_key=api_key,
-            api_key_vars=[GOOGLE_API_KEY],
+            api_key_vars=[GOOGLE_API_KEY, VERTEX_API_KEY],
             config=config,
         )

-        # pick out safety settings and merge against default
+        # pick out user-provided safety settings and merge against default
         self.safety_settings = DEFAULT_SAFETY_SETTINGS.copy()
         if SAFETY_SETTINGS in model_args:
             self.safety_settings.update(
@@ -123,22 +116,79 @@ class GoogleAPI(ModelAPI):
             )
             del model_args[SAFETY_SETTINGS]

-        # configure genai client
-        base_url = model_base_url(base_url, "GOOGLE_BASE_URL")
-        configure(
+        # extract any service prefix from model name
+        parts = model_name.split("/")
+        if len(parts) > 1:
+            self.service: str | None = parts[0]
+            model_name = "/".join(parts[1:])
+        else:
+            self.service = None
+
+        # vertex can also be forced by the GOOGLE_GENAI_USE_VERTEX_AI flag
+        if self.service is None:
+            if os.environ.get("GOOGLE_GENAI_USE_VERTEXAI", "").lower() == "true":
+                self.service = "vertex"
+
+        # ensure we haven't specified an invalid service
+        if self.service is not None and self.service != "vertex":
+            raise RuntimeError(
+                f"Invalid service name for google: {self.service}. "
+                + "Currently 'vertex' is the only supported service."
+            )
+
+        # handle auth (vertex or standard google api key)
+        if self.is_vertex():
+            # see if we are running in express mode (propagate api key if we are)
+            # https://cloud.google.com/vertex-ai/generative-ai/docs/start/express-mode/overview
+            vertex_api_key = os.environ.get(VERTEX_API_KEY, None)
+            if vertex_api_key and not self.api_key:
+                self.api_key = vertex_api_key
+
+            # When not using express mode the GOOGLE_CLOUD_PROJECT and GOOGLE_CLOUD_LOCATION
+            # environment variables should be set, OR the 'project' and 'location' should be
+            # passed within the model_args.
+            # https://cloud.google.com/vertex-ai/generative-ai/docs/gemini-v2
+            if not vertex_api_key:
+                if not os.environ.get(
+                    "GOOGLE_CLOUD_PROJECT", None
+                ) and not model_args.get("project", None):
+                    raise PrerequisiteError(
+                        "Google provider requires either the GOOGLE_CLOUD_PROJECT environment variable "
+                        + "or the 'project' custom model arg (-M) when running against vertex."
+                    )
+                if not os.environ.get(
+                    "GOOGLE_CLOUD_LOCATION", None
+                ) and not model_args.get("location", None):
+                    raise PrerequisiteError(
+                        "Google provider requires either the GOOGLE_CLOUD_LOCATION environment variable "
+                        + "or the 'location' custom model arg (-M) when running against vertex."
+                    )
+
+        # normal google endpoint
+        else:
+            # read api key from env
+            if not self.api_key:
+                self.api_key = os.environ.get(GOOGLE_API_KEY, None)
+
+            # custom base_url
+            base_url = model_base_url(base_url, "GOOGLE_BASE_URL")
+
+        # create client
+        self.client = Client(
+            vertexai=self.is_vertex(),
             api_key=self.api_key,
-            client_options=dict(api_endpoint=base_url),
+            http_options={"base_url": base_url},
             **model_args,
         )

-        # create model
-        self.model = GenerativeModel(self.model_name)
-
     @override
     async def close(self) -> None:
         # GenerativeModel uses a cached/shared client so there is no 'close'
         pass

+    def is_vertex(self) -> bool:
+        return self.service == "vertex"
+
     async def generate(
         self,
         input: list[ChatMessage],
@@ -146,7 +196,11 @@ class GoogleAPI(ModelAPI):
         tool_choice: ToolChoice,
         config: GenerateConfig,
     ) -> ModelOutput | tuple[ModelOutput | Exception, ModelCall]:
-        parameters = GenerationConfig(
+        # Create google-genai types.
+        gemini_contents = await as_chat_messages(self.client, input)
+        gemini_tools = chat_tools(tools) if len(tools) > 0 else None
+        gemini_tool_config = chat_tool_config(tool_choice) if len(tools) > 0 else None
+        parameters = GenerateContentConfig(
             temperature=config.temperature,
             top_p=config.top_p,
             top_k=config.top_k,
@@ -155,21 +209,19 @@ class GoogleAPI(ModelAPI):
             candidate_count=config.num_choices,
             presence_penalty=config.presence_penalty,
             frequency_penalty=config.frequency_penalty,
+            safety_settings=safety_settings_to_list(self.safety_settings),
+            tools=gemini_tools,
+            tool_config=gemini_tool_config,
+            system_instruction=await extract_system_message_as_parts(
+                self.client, input
+            ),
         )

-        # google-native messages
-        contents = await as_chat_messages(input)
-
-        # tools
-        gemini_tools = chat_tools(tools) if len(tools) > 0 else None
-        gemini_tool_config = chat_tool_config(tool_choice) if len(tools) > 0 else None
-
-        # response for ModelCall
-        response: AsyncGenerateContentResponse | None = None
+        response: GenerateContentResponse | None = None

         def model_call() -> ModelCall:
             return build_model_call(
-                contents=contents,
+                contents=gemini_contents,
                 safety_settings=self.safety_settings,
                 generation_config=parameters,
                 tools=gemini_tools,
@@ -178,163 +230,146 @@ class GoogleAPI(ModelAPI):
             )

         try:
-            response = await self.model.generate_content_async(
-                contents=contents,
-                safety_settings=self.safety_settings,
-                generation_config=parameters,
-                tools=gemini_tools,
-                tool_config=gemini_tool_config,
+            response = await self.client.aio.models.generate_content(
+                model=self.model_name,
+                contents=gemini_contents,
+                config=parameters,
             )
+        except ClientError as ex:
+            return self.handle_client_error(ex), model_call()

-        except InvalidArgument as ex:
-            return self.handle_invalid_argument(ex), model_call()
-
-        # build output
         output = ModelOutput(
             model=self.model_name,
-            choices=completion_choices_from_candidates(response.candidates),
-            usage=ModelUsage(
-                input_tokens=response.usage_metadata.prompt_token_count,
-                output_tokens=response.usage_metadata.candidates_token_count,
-                total_tokens=response.usage_metadata.total_token_count,
-            ),
+            choices=completion_choices_from_candidates(response),
+            usage=usage_metadata_to_model_usage(response.usage_metadata),
         )

-        # return output
         return output, model_call()

-    def handle_invalid_argument(self, ex: InvalidArgument) -> ModelOutput | Exception:
-        if "size exceeds the limit" in ex.message.lower():
-            return ModelOutput.from_content(
-                model=self.model_name, content=ex.message, stop_reason="model_length"
-            )
-        else:
-            return ex
-
     @override
     def is_rate_limit(self, ex: BaseException) -> bool:
-        return isinstance(
-            ex,
-            TooManyRequests | InternalServerError | ServiceUnavailable | GatewayTimeout,
-        )
+        return isinstance(ex, APIError) and ex.code in (429, 500, 503, 504)

     @override
     def connection_key(self) -> str:
         """Scope for enforcing max_connections (could also use endpoint)."""
         return self.model_name

+    def handle_client_error(self, ex: ClientError) -> ModelOutput | Exception:
+        if (
+            ex.code == 400
+            and ex.message
+            and (
+                "maximum number of tokens" in ex.message
+                or "size exceeds the limit" in ex.message
+            )
+        ):
+            return ModelOutput.from_content(
+                self.model_name, content=ex.message, stop_reason="model_length"
+            )
+        else:
+            raise ex
+
+
+def safety_settings_to_list(safety_settings: SafetySettingDict) -> list[SafetySetting]:
+    return [
+        SafetySetting(
+            category=category,
+            threshold=threshold,
+        )
+        for category, threshold in safety_settings.items()
+    ]
+

 def build_model_call(
-    contents: list[ContentDict],
+    contents: list[Content],
     generation_config: GenerationConfig,
-    safety_settings: EasySafetySettingDict,
+    safety_settings: SafetySettingDict,
     tools: list[Tool] | None,
     tool_config: ToolConfig | None,
-    response: AsyncGenerateContentResponse | None,
+    response: GenerateContentResponse | None,
 ) -> ModelCall:
     return ModelCall.create(
         request=dict(
-            contents=[model_call_content(content) for content in contents],
+            contents=contents,
             generation_config=generation_config,
             safety_settings=safety_settings,
-            tools=[MessageToDict(tool._pb) for tool in tools]
-            if tools is not None
-            else None,
-            tool_config=MessageToDict(tool_config._pb)
-            if tool_config is not None
-            else None,
+            tools=tools if tools is not None else None,
+            tool_config=tool_config if tool_config is not None else None,
         ),
-        response=response.to_dict() if response is not None else {},
+        response=response if response is not None else {},
         filter=model_call_filter,
     )


 def model_call_filter(key: JsonValue | None, value: JsonValue) -> JsonValue:
-    # remove images from raw api call
     if key == "inline_data" and isinstance(value, dict) and "data" in value:
         value = copy(value)
         value.update(data=BASE_64_DATA_REMOVED)
     return value


-def model_call_content(content: ContentDict) -> ContentDict:
-    content = content.copy()
-    content["parts"] = [model_call_part(part) for part in content.get("parts", [])]
-    return content
-
-
-def model_call_part(part: PartType) -> PartType:
-    if isinstance(part, proto.Message):
-        return cast(PartDict, MessageToDict(part._pb))
-    elif isinstance(part, dict):
-        part = part.copy()
-        keys = list(part.keys())
-        for key in keys:
-            part[key] = model_call_part(part[key])  # type: ignore[literal-required]
-        return part
-    else:
-        return part
-
-
-async def as_chat_messages(messages: list[ChatMessage]) -> list[ContentDict]:
-    # google does not support system messages so filter them out to start with
-    system_messages = [message for message in messages if message.role == "system"]
+async def as_chat_messages(
+    client: Client, messages: list[ChatMessage]
+) -> list[Content]:
+    # There is no "system" role in the `google-genai` package. Instead, system messages
+    # are included in the `GenerateContentConfig` as a `system_instruction`. Strip any
+    # system messages out.
     supported_messages = [message for message in messages if message.role != "system"]

     # build google chat messages
-    chat_messages = [await content_dict(message) for message in supported_messages]
-
-    # we want the system messages to be prepended to the first user message
-    # (if there is no first user message then prepend one)
-    prepend_system_messages(chat_messages, system_messages)
+    chat_messages = [await content(client, message) for message in supported_messages]

     # combine consecutive tool messages
-    chat_messages = functools.reduce(consecutive_tool_message_reducer, chat_messages, [])
+    chat_messages = functools.reduce(
+        consecutive_tool_message_reducer, chat_messages, []
+    )

     # return messages
     return chat_messages


-def consecutive_tool_message_reducer(
-    messages: list[ContentDict],
-    message: ContentDict,
-) -> list[ContentDict]:
+def consecutive_tool_message_reducer(
+    messages: list[Content],
+    message: Content,
+) -> list[Content]:
     if (
-        message.get("role") == "function"
+        message.role == "function"
         and len(messages) > 0
-        and messages[-1].get("role") == "function"
+        and messages[-1].role == "function"
     ):
-        messages[-1] = ContentDict(
-            role="function", parts=messages[-1]["parts"] + message["parts"]
+        messages[-1] = Content(
+            role="function", parts=messages[-1].parts + message.parts
         )
     else:
         messages.append(message)
     return messages


-async def content_dict(
+async def content(
+    client: Client,
     message: ChatMessageUser | ChatMessageAssistant | ChatMessageTool,
-) -> ContentDict:
+) -> Content:
     if isinstance(message, ChatMessageUser):
-        return ContentDict(
+        if isinstance(message.content, str):
+            return Content(
+                role="user", parts=[await content_part(client, message.content)]
+            )
+        return Content(
             role="user",
             parts=(
-                [message.content]
-                if isinstance(message.content, str)
-                else [await content_part(content) for content in message.content]
+                [await content_part(client, content) for content in message.content]
             ),
         )
     elif isinstance(message, ChatMessageAssistant):
-        content_parts: list[PartType] = []
+        content_parts: list[Part] = []
         # tool call parts
         if message.tool_calls is not None:
             content_parts.extend(
                 [
-                    Part(
-                        function_call=FunctionCall(
-                            name=tool_call.function,
-                            args=dict_to_struct(tool_call.arguments),
-                        )
+                    Part.from_function_call(
+                        name=tool_call.function,
+                        args=tool_call.arguments,
                     )
                     for tool_call in message.tool_calls
                 ]
@@ -345,68 +380,62 @@ async def content_dict(
             content_parts.append(Part(text=message.content or NO_CONTENT))
         else:
             content_parts.extend(
-                [await content_part(content) for content in message.content]
+                [await content_part(client, content) for content in message.content]
             )

         # return parts
-        return ContentDict(role="model", parts=content_parts)
+        return Content(role="model", parts=content_parts)

     elif isinstance(message, ChatMessageTool):
         response = FunctionResponse(
             name=message.tool_call_id,
-            response=ParseDict(
-                js_dict={
-                    "content": (
-                        message.error.message
-                        if message.error is not None
-                        else message.text
-                    )
-                },
-                message=Struct(),
-            ),
+            response={
+                "content": (
+                    message.error.message if message.error is not None else message.text
+                )
+            },
         )
-        return ContentDict(role="function", parts=[Part(function_response=response)])
-
+        return Content(role="function", parts=[Part(function_response=response)])

-def dict_to_struct(x: dict[str, Any]) -> Struct:
-    struct = Struct()
-    struct.update(x)
-    return struct

-
-async def content_part(content: Content | str) -> PartType:
+async def content_part(client: Client, content: InspectContent | str) -> Part:
     if isinstance(content, str):
-        return content or NO_CONTENT
+        return Part.from_text(text=content or NO_CONTENT)
     elif isinstance(content, ContentText):
-        return content.text or NO_CONTENT
+        return Part.from_text(text=content.text or NO_CONTENT)
     else:
-        return await chat_content_to_part(content)
+        return await chat_content_to_part(client, content)


 async def chat_content_to_part(
+    client: Client,
     content: ContentImage | ContentAudio | ContentVideo,
-) -> PartType:
+) -> Part:
     if isinstance(content, ContentImage):
         content_bytes, mime_type = await file_as_data(content.image)
-        return Blob(mime_type=mime_type, data=content_bytes)
-    else:
-        return await file_for_content(content)
-
-
-def prepend_system_messages(
-    messages: list[ContentDict], system_messages: list[ChatMessageSystem]
-) -> None:
-    # create system_parts
-    system_parts: list[PartType] = [
-        Part(text=message.text) for message in system_messages
-    ]
-
-    # we want the system messages to be prepended to the first user message
-    # (if there is no first user message then prepend one)
-    if len(messages) > 0 and messages[0].get("role") == "user":
-        messages[0]["parts"] = system_parts + messages[0].get("parts", [])
+        return Part.from_bytes(mime_type=mime_type, data=content_bytes)
     else:
-        messages.insert(0, ContentDict(role="user", parts=system_parts))
+        return await file_for_content(client, content)
+
+
+async def extract_system_message_as_parts(
+    client: Client,
+    messages: list[ChatMessage],
+) -> list[Part] | None:
+    system_parts: list[Part] = []
+    for message in messages:
+        if message.role == "system":
+            content = message.content
+            if isinstance(content, str):
+                system_parts.append(Part.from_text(text=content))
+            elif isinstance(content, list):  # list[InspectContent]
+                system_parts.extend(
+                    [await content_part(client, content) for content in content]
+                )
+            else:
+                raise ValueError(f"Unsupported system message content: {content}")
+    # google-genai raises "ValueError: content is required." if the list is empty.
+    return system_parts or None


 def chat_tools(tools: list[ToolInfo]) -> list[Tool]:
@@ -424,8 +453,6 @@ def chat_tools(tools: list[ToolInfo]) -> list[Tool]:


 # https://ai.google.dev/gemini-api/tutorials/extract_structured_data#define_the_schema
-
-
 def schema_from_param(param: ToolParam | ToolParams, nullable: bool = False) -> Schema:
     if isinstance(param, ToolParams):
         param = ToolParam(
@@ -461,7 +488,7 @@ def schema_from_param(param: ToolParam | ToolParams, nullable: bool = False) ->
         description=param.description,
         properties={k: schema_from_param(v) for k, v in param.properties.items()}
         if param.properties is not None
-        else None,
+        else {},
         required=param.required,
         nullable=nullable,
     )
@@ -478,57 +505,56 @@ def schema_from_param(param: ToolParam | ToolParams, nullable: bool = False) ->


 def chat_tool_config(tool_choice: ToolChoice) -> ToolConfig:
-
-
-
-
-
-
-
-
-
-
-    # by commenting this back in and running pytest -k google_tools
-    #
-    # if isinstance(tool_choice, ToolFunction):
-    #     return ToolConfig(
-    #         function_calling_config=FunctionCallingConfig(
-    #             mode="ANY", allowed_function_names=[tool_choice.name]
-    #         )
-    #     )
-    # else:
-    #     return ToolConfig(
-    #         function_calling_config=FunctionCallingConfig(mode=tool_choice.upper())
-    #     )
+    if isinstance(tool_choice, ToolFunction):
+        return ToolConfig(
+            function_calling_config=FunctionCallingConfig(
+                mode="ANY", allowed_function_names=[tool_choice.name]
+            )
+        )
+    else:
+        return ToolConfig(
+            function_calling_config=FunctionCallingConfig(mode=tool_choice.upper())
+        )


 def completion_choice_from_candidate(candidate: Candidate) -> ChatCompletionChoice:
     # check for completion text
-    content = " ".join(
-        [part.text for part in candidate.content.parts]
-    )
+    content = ""
+    # content can be None when the finish_reason is SAFETY
+    if candidate.content is not None:
+        content = " ".join(
+            [
+                part.text
+                for part in candidate.content.parts
+                if part.text is not None and candidate.content is not None
+            ]
+        )
+
+    # split reasoning
+    reasoning, content = split_reasoning(content)

     # now tool calls
     tool_calls: list[ToolCall] = []
-    for part in candidate.content.parts:
-        if part.function_call:
-            arguments = MessageToDict(part.function_call._pb)["args"]
-            tool_calls.append(
-                ToolCall(
-                    type="function",
-                    id=part.function_call.name,
-                    function=part.function_call.name,
-                    arguments=arguments,
-                )
+    if candidate.content is not None and candidate.content.parts is not None:
+        for part in candidate.content.parts:
+            if part.function_call:
+                tool_calls.append(
+                    ToolCall(
+                        type="function",
+                        id=part.function_call.name,
+                        function=part.function_call.name,
+                        arguments=part.function_call.args,
+                    )
                 )
-            )

     # stop reason
-    stop_reason = candidate_stop_reason(candidate.finish_reason)
+    stop_reason = finish_reason_to_stop_reason(candidate.finish_reason)

-    # build the choice
+    # build choice
     choice = ChatCompletionChoice(
         message=ChatMessageAssistant(
             content=content,
+            reasoning=reasoning,
             tool_calls=tool_calls if len(tool_calls) > 0 else None,
             source="generate",
         ),
@@ -558,111 +584,144 @@ def completion_choice_from_candidate(candidate: Candidate) -> ChatCompletionChoi


 def completion_choices_from_candidates(
-    candidates: list[Candidate],
+    response: GenerateContentResponse,
 ) -> list[ChatCompletionChoice]:
+    candidates = response.candidates
     if candidates:
         candidates_list = sorted(candidates, key=lambda c: c.index)
         return [
             completion_choice_from_candidate(candidate) for candidate in candidates_list
         ]
-    else:
+    elif response.prompt_feedback:
         return [
             ChatCompletionChoice(
                 message=ChatMessageAssistant(
-                    content=NO_CONTENT,
+                    content=prompt_feedback_to_content(response.prompt_feedback),
                     source="generate",
                 ),
-                stop_reason="unknown",
+                stop_reason="content_filter",
             )
         ]
+    else:
+        raise RuntimeError(
+            "Google response includes no completion candidates and no block reason: "
+            + f"{response.model_dump_json(indent=2)}"
+        )


-class FinishReason:
-    """Candidate finish reasons (mirrors the google enum)."""
-
-    FINISH_REASON_UNSPECIFIED = 0
-    STOP = 1
-    MAX_TOKENS = 2
-    SAFETY = 3
-    RECITATION = 4
-    OTHER = 5
+def split_reasoning(content: str) -> tuple[str | None, str]:
+    separator = "\nFinal Answer: "
+    if separator in content:
+        parts = content.split(separator, 1)  # split only on first occurrence
+        return parts[0].strip(), separator.lstrip() + parts[1].strip()
+    else:
+        return None, content.strip()
+

+def prompt_feedback_to_content(
+    feedback: GenerateContentResponsePromptFeedback,
+) -> str:
+    content: list[str] = []
+    block_reason = str(feedback.block_reason) if feedback.block_reason else "UNKNOWN"
+    content.append(f"BLOCKED: {block_reason}")

-def candidate_stop_reason(finish_reason: FinishReason) -> StopReason:
+    if feedback.block_reason_message is not None:
+        content.append(feedback.block_reason_message)
+    if feedback.safety_ratings is not None:
+        content.extend(
+            [rating.model_dump_json(indent=2) for rating in feedback.safety_ratings]
+        )
+    return "\n".join(content)
+
+
+def usage_metadata_to_model_usage(
+    metadata: GenerateContentResponseUsageMetadata,
+) -> ModelUsage | None:
+    if metadata is None:
+        return None
+    return ModelUsage(
+        input_tokens=metadata.prompt_token_count or 0,
+        output_tokens=metadata.candidates_token_count or 0,
+        total_tokens=metadata.total_token_count or 0,
+    )
+
+
+def finish_reason_to_stop_reason(finish_reason: FinishReason) -> StopReason:
     match finish_reason:
         case FinishReason.STOP:
             return "stop"
         case FinishReason.MAX_TOKENS:
             return "max_tokens"
-        case FinishReason.SAFETY | FinishReason.RECITATION:
+        case (
+            FinishReason.SAFETY
+            | FinishReason.RECITATION
+            | FinishReason.BLOCKLIST
+            | FinishReason.PROHIBITED_CONTENT
+            | FinishReason.SPII
+        ):
             return "content_filter"
         case _:
             return "unknown"


-def gapi_should_retry(ex: BaseException) -> bool:
-    if isinstance(ex, Exception):
-        return if_transient_error(ex)
-    else:
-        return False
-
-
 def parse_safety_settings(
     safety_settings: Any,
-) -> EasySafetySettingDict:
+) -> dict[HarmCategory, HarmBlockThreshold]:
     # ensure we have a dict
     if isinstance(safety_settings, str):
         safety_settings = json.loads(safety_settings)
     if not isinstance(safety_settings, dict):
         raise ValueError(f"{SAFETY_SETTINGS} must be dictionary.")

-    parsed_settings: EasySafetySettingDict = {}
+    parsed_settings: dict[HarmCategory, HarmBlockThreshold] = {}
     for key, value in safety_settings.items():
-        if isinstance(key, str):
-            key = str_to_harm_category(key)
-        if not isinstance(key, HarmCategory):
+        if not isinstance(key, str):
             raise ValueError(f"Unexpected type for harm category: {key}")
-        if isinstance(value, str):
-            value = str_to_harm_block_threshold(value)
-        if not isinstance(value, HarmBlockThreshold):
+        if not isinstance(value, str):
             raise ValueError(f"Unexpected type for harm block threshold: {value}")
-
+        key = str_to_harm_category(key)
+        value = str_to_harm_block_threshold(value)
         parsed_settings[key] = value
-
     return parsed_settings


-def str_to_harm_category(category: str) -> int:
+def str_to_harm_category(category: str) -> HarmCategory:
     category = category.upper()
+    # `in` instead of `==` to allow users to pass in short version e.g. "HARASSMENT" or
+    # long version e.g. "HARM_CATEGORY_HARASSMENT" strings.
+    if "CIVIC_INTEGRITY" in category:
+        return HarmCategory.HARM_CATEGORY_CIVIC_INTEGRITY
+    if "DANGEROUS_CONTENT" in category:
+        return HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT
+    if "HATE_SPEECH" in category:
+        return HarmCategory.HARM_CATEGORY_HATE_SPEECH
     if "HARASSMENT" in category:
-        return cast(int, HarmCategory.HARM_CATEGORY_HARASSMENT)
-    elif "HATE_SPEECH" in category:
-        return cast(int, HarmCategory.HARM_CATEGORY_HATE_SPEECH)
-    elif "SEXUALLY_EXPLICIT" in category:
-        return cast(int, HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT)
-    elif "DANGEROUS" in category:
-        return cast(int, HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT)
-    else:
-        # NOTE: Although there is an "UNSPECIFIED" category, in the
-        # documentation, the API does not accept it.
-        raise ValueError(f"Unknown HarmCategory: {category}")
+        return HarmCategory.HARM_CATEGORY_HARASSMENT
+    if "SEXUALLY_EXPLICIT" in category:
+        return HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT
+    if "UNSPECIFIED" in category:
+        return HarmCategory.HARM_CATEGORY_UNSPECIFIED
+    raise ValueError(f"Unknown HarmCategory: {category}")


-def str_to_harm_block_threshold(threshold: str) -> int:
+def str_to_harm_block_threshold(threshold: str) -> HarmBlockThreshold:
     threshold = threshold.upper()
     if "LOW" in threshold:
         return HarmBlockThreshold.BLOCK_LOW_AND_ABOVE
-    elif "MEDIUM" in threshold:
+    if "MEDIUM" in threshold:
         return HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE
-    elif "HIGH" in threshold:
+    if "HIGH" in threshold:
         return HarmBlockThreshold.BLOCK_ONLY_HIGH
-    elif "NONE" in threshold:
+    if "NONE" in threshold:
         return HarmBlockThreshold.BLOCK_NONE
-    else:
-        raise ValueError(f"Unknown HarmBlockThreshold: {threshold}")
+    if "OFF" in threshold:
+        return HarmBlockThreshold.OFF
+    raise ValueError(f"Unknown HarmBlockThreshold: {threshold}")


-async def file_for_content(content: ContentAudio | ContentVideo) -> File:
+async def file_for_content(
+    client: Client, content: ContentAudio | ContentVideo
+) -> File:
     # helper to write trace messages
     def trace(message: str) -> None:
         trace_message(logger, "Google Files", message)
@@ -674,7 +733,6 @@ async def file_for_content(content: ContentAudio | ContentVideo) -> File:
         file = content.video
     content_bytes, mime_type = await file_as_data(file)
     content_sha256 = hashlib.sha256(content_bytes).hexdigest()
-
     # we cache uploads for re-use, open the db where we track that
     # (track up to 1 million previous uploads)
     with inspect_kvstore("google_files", 1000000) as files_db:
@@ -682,7 +740,7 @@ async def file_for_content(content: ContentAudio | ContentVideo) -> File:
         uploaded_file = files_db.get(content_sha256)
         if uploaded_file:
             try:
-                upload = get_file(uploaded_file)
+                upload: File = client.files.get(uploaded_file)
                 if upload.state.name == "ACTIVE":
                     trace(f"Using uploaded file: {uploaded_file}")
                     return upload
@@ -693,20 +751,16 @@ async def file_for_content(content: ContentAudio | ContentVideo) -> File:
             except Exception as ex:
                 trace(f"Error attempting to access uploaded file: {ex}")
                 files_db.delete(content_sha256)
-
         # do the upload (and record it)
-        upload = upload_file(BytesIO(content_bytes), mime_type=mime_type)
+        upload = client.files.upload(BytesIO(content_bytes), mime_type=mime_type)
         while upload.state.name == "PROCESSING":
             await asyncio.sleep(3)
-            upload = get_file(upload.name)
-
+            upload = client.files.get(upload.name)
         if upload.state.name == "FAILED":
             trace(f"Failed to upload file '{upload.name}: {upload.error}")
             raise ValueError(f"Google file upload failed: {upload.error}")
-
         # trace and record it
         trace(f"Uploaded file: {upload.name}")
         files_db.put(content_sha256, upload.name)
-
         # return the file
         return upload