inspect-ai 0.3.68__py3-none-any.whl → 0.3.70__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, exactly as they appear in their public registry, and is provided for informational purposes only.
- inspect_ai/_cli/eval.py +13 -1
- inspect_ai/_display/plain/display.py +9 -11
- inspect_ai/_display/textual/app.py +5 -5
- inspect_ai/_display/textual/widgets/samples.py +47 -18
- inspect_ai/_display/textual/widgets/transcript.py +25 -12
- inspect_ai/_eval/eval.py +14 -2
- inspect_ai/_eval/evalset.py +6 -1
- inspect_ai/_eval/run.py +6 -0
- inspect_ai/_eval/task/run.py +44 -15
- inspect_ai/_eval/task/task.py +26 -3
- inspect_ai/_util/interrupt.py +15 -0
- inspect_ai/_util/logger.py +23 -0
- inspect_ai/_util/rich.py +7 -8
- inspect_ai/_util/text.py +301 -1
- inspect_ai/_util/transcript.py +10 -2
- inspect_ai/_util/working.py +46 -0
- inspect_ai/_view/www/dist/assets/index.css +56 -12
- inspect_ai/_view/www/dist/assets/index.js +905 -751
- inspect_ai/_view/www/log-schema.json +337 -2
- inspect_ai/_view/www/node_modules/flatted/python/flatted.py +149 -0
- inspect_ai/_view/www/node_modules/flatted/python/test.py +63 -0
- inspect_ai/_view/www/src/appearance/icons.ts +3 -1
- inspect_ai/_view/www/src/metadata/RenderedContent.tsx +0 -1
- inspect_ai/_view/www/src/samples/SampleDisplay.module.css +9 -1
- inspect_ai/_view/www/src/samples/SampleDisplay.tsx +28 -1
- inspect_ai/_view/www/src/samples/SampleSummaryView.module.css +4 -0
- inspect_ai/_view/www/src/samples/SampleSummaryView.tsx +23 -2
- inspect_ai/_view/www/src/samples/descriptor/score/ObjectScoreDescriptor.tsx +1 -1
- inspect_ai/_view/www/src/samples/transcript/SampleLimitEventView.tsx +4 -0
- inspect_ai/_view/www/src/samples/transcript/SandboxEventView.module.css +32 -0
- inspect_ai/_view/www/src/samples/transcript/SandboxEventView.tsx +152 -0
- inspect_ai/_view/www/src/samples/transcript/StepEventView.tsx +9 -2
- inspect_ai/_view/www/src/samples/transcript/TranscriptView.tsx +19 -1
- inspect_ai/_view/www/src/samples/transcript/event/EventPanel.tsx +6 -3
- inspect_ai/_view/www/src/samples/transcript/types.ts +3 -1
- inspect_ai/_view/www/src/types/log.d.ts +188 -108
- inspect_ai/_view/www/src/utils/format.ts +7 -4
- inspect_ai/_view/www/src/workspace/WorkSpaceView.tsx +9 -6
- inspect_ai/log/__init__.py +2 -0
- inspect_ai/log/_condense.py +1 -0
- inspect_ai/log/_log.py +72 -12
- inspect_ai/log/_samples.py +5 -5
- inspect_ai/log/_transcript.py +31 -1
- inspect_ai/model/_call_tools.py +1 -1
- inspect_ai/model/_conversation.py +1 -1
- inspect_ai/model/_model.py +35 -16
- inspect_ai/model/_model_call.py +10 -3
- inspect_ai/model/_providers/anthropic.py +13 -2
- inspect_ai/model/_providers/bedrock.py +7 -0
- inspect_ai/model/_providers/cloudflare.py +20 -7
- inspect_ai/model/_providers/google.py +358 -302
- inspect_ai/model/_providers/groq.py +57 -23
- inspect_ai/model/_providers/hf.py +6 -0
- inspect_ai/model/_providers/mistral.py +81 -52
- inspect_ai/model/_providers/openai.py +9 -0
- inspect_ai/model/_providers/providers.py +6 -6
- inspect_ai/model/_providers/util/tracker.py +92 -0
- inspect_ai/model/_providers/vllm.py +13 -5
- inspect_ai/solver/_basic_agent.py +1 -3
- inspect_ai/solver/_bridge/patch.py +0 -2
- inspect_ai/solver/_limit.py +4 -4
- inspect_ai/solver/_plan.py +3 -3
- inspect_ai/solver/_solver.py +3 -0
- inspect_ai/solver/_task_state.py +10 -1
- inspect_ai/tool/_tools/_web_search.py +3 -3
- inspect_ai/util/_concurrency.py +14 -8
- inspect_ai/util/_sandbox/context.py +15 -0
- inspect_ai/util/_sandbox/docker/cleanup.py +8 -3
- inspect_ai/util/_sandbox/docker/compose.py +5 -9
- inspect_ai/util/_sandbox/docker/docker.py +20 -6
- inspect_ai/util/_sandbox/docker/util.py +10 -1
- inspect_ai/util/_sandbox/environment.py +32 -1
- inspect_ai/util/_sandbox/events.py +149 -0
- inspect_ai/util/_sandbox/local.py +3 -3
- inspect_ai/util/_sandbox/self_check.py +2 -1
- inspect_ai/util/_subprocess.py +4 -1
- {inspect_ai-0.3.68.dist-info → inspect_ai-0.3.70.dist-info}/METADATA +5 -5
- {inspect_ai-0.3.68.dist-info → inspect_ai-0.3.70.dist-info}/RECORD +82 -74
- {inspect_ai-0.3.68.dist-info → inspect_ai-0.3.70.dist-info}/LICENSE +0 -0
- {inspect_ai-0.3.68.dist-info → inspect_ai-0.3.70.dist-info}/WHEEL +0 -0
- {inspect_ai-0.3.68.dist-info → inspect_ai-0.3.70.dist-info}/entry_points.txt +0 -0
- {inspect_ai-0.3.68.dist-info → inspect_ai-0.3.70.dist-info}/top_level.txt +0 -0
The remainder of this diff shows the rewrite of inspect_ai/model/_providers/google.py (+358 -302), which migrates the provider from the legacy `google-generativeai` SDK to the new `google-genai` SDK. (A few removed lines lost their content in extraction and appear as bare `-` markers.)

```diff
--- inspect_ai/model/_providers/google.py (0.3.68)
+++ inspect_ai/model/_providers/google.py (0.3.70)
@@ -2,103 +2,96 @@ import asyncio
 import functools
 import hashlib
 import json
+import os
 from copy import copy
 from io import BytesIO
 from logging import getLogger
-from typing import Any
+from typing import Any

-
-from google.
-
+# SDK Docs: https://googleapis.github.io/python-genai/
+from google.genai import Client  # type: ignore
+from google.genai.errors import APIError, ClientError  # type: ignore
+from google.genai.types import (  # type: ignore
     Candidate,
-
+    Content,
+    File,
+    FinishReason,
     FunctionCallingConfig,
     FunctionDeclaration,
     FunctionResponse,
+    GenerateContentConfig,
+    GenerateContentResponse,
+    GenerateContentResponsePromptFeedback,
+    GenerateContentResponseUsageMetadata,
+    GenerationConfig,
+    HarmBlockThreshold,
+    HarmCategory,
     Part,
+    SafetySetting,
+    SafetySettingDict,
     Schema,
+    Tool,
     ToolConfig,
     Type,
 )
-from google.api_core.exceptions import (
-    GatewayTimeout,
-    InternalServerError,
-    InvalidArgument,
-    ServiceUnavailable,
-    TooManyRequests,
-)
-from google.api_core.retry.retry_base import if_transient_error
-from google.generativeai.client import configure
-from google.generativeai.files import get_file, upload_file
-from google.generativeai.generative_models import GenerativeModel
-from google.generativeai.types import (
-    ContentDict,
-    GenerationConfig,
-    PartDict,
-    PartType,
-    Tool,
-)
-from google.generativeai.types.file_types import File
-from google.generativeai.types.generation_types import AsyncGenerateContentResponse
-from google.generativeai.types.safety_types import (
-    EasySafetySettingDict,
-    HarmBlockThreshold,
-    HarmCategory,
-)
-from google.protobuf.json_format import MessageToDict, ParseDict
-from google.protobuf.struct_pb2 import Struct
 from pydantic import JsonValue
 from typing_extensions import override

 from inspect_ai._util.constants import BASE_64_DATA_REMOVED, NO_CONTENT
+from inspect_ai._util.content import Content as InspectContent
 from inspect_ai._util.content import (
-    Content,
     ContentAudio,
     ContentImage,
     ContentText,
     ContentVideo,
 )
+from inspect_ai._util.error import PrerequisiteError
 from inspect_ai._util.images import file_as_data
 from inspect_ai._util.kvstore import inspect_kvstore
 from inspect_ai._util.trace import trace_message
-from inspect_ai.
-
-from .._chat_message import (
+from inspect_ai.model import (
+    ChatCompletionChoice,
     ChatMessage,
     ChatMessageAssistant,
-    ChatMessageSystem,
     ChatMessageTool,
     ChatMessageUser,
-)
-from .._generate_config import GenerateConfig
-from .._model import ModelAPI
-from .._model_call import ModelCall
-from .._model_output import (
-    ChatCompletionChoice,
+    GenerateConfig,
     Logprob,
     Logprobs,
+    ModelAPI,
     ModelOutput,
     ModelUsage,
     StopReason,
     TopLogprob,
 )
-from .
+from inspect_ai.model._model_call import ModelCall
+from inspect_ai.model._providers.util import model_base_url
+from inspect_ai.tool import (
+    ToolCall,
+    ToolChoice,
+    ToolFunction,
+    ToolInfo,
+    ToolParam,
+    ToolParams,
+)

 logger = getLogger(__name__)

-SAFETY_SETTINGS = "safety_settings"

-
-
-
-
+GOOGLE_API_KEY = "GOOGLE_API_KEY"
+VERTEX_API_KEY = "VERTEX_API_KEY"
+
+SAFETY_SETTINGS = "safety_settings"
+DEFAULT_SAFETY_SETTINGS = {
+    HarmCategory.HARM_CATEGORY_CIVIC_INTEGRITY: HarmBlockThreshold.BLOCK_NONE,
     HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
+    HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
+    HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
+    HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
 }

-GOOGLE_API_KEY = "GOOGLE_API_KEY"
-

-class GoogleAPI(ModelAPI):
+class GoogleGenAIAPI(ModelAPI):
     def __init__(
         self,
         model_name: str,
```
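The import block above captures the headline change in this diff: the provider moved off the deprecated `google-generativeai` SDK onto the unified `google-genai` SDK. A minimal sketch of the new client surface the rewritten provider builds on (the model name is a placeholder, and `GOOGLE_API_KEY` is assumed to be set):

```python
# Minimal sketch of the google-genai surface the rewritten provider targets.
# Assumes GOOGLE_API_KEY is set; the model name is a placeholder.
from google.genai import Client

client = Client()  # reads GOOGLE_API_KEY from the environment
response = client.models.generate_content(
    model="gemini-1.5-flash",
    contents="Say hello.",
)
print(response.text)
```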
```diff
@@ -111,11 +104,11 @@ class GoogleAPI(ModelAPI):
             model_name=model_name,
             base_url=base_url,
             api_key=api_key,
-            api_key_vars=[GOOGLE_API_KEY],
+            api_key_vars=[GOOGLE_API_KEY, VERTEX_API_KEY],
             config=config,
         )

-        # pick out
+        # pick out user-provided safety settings and merge against default
         self.safety_settings = DEFAULT_SAFETY_SETTINGS.copy()
         if SAFETY_SETTINGS in model_args:
             self.safety_settings.update(
```
```diff
@@ -123,22 +116,79 @@ class GoogleAPI(ModelAPI):
             )
             del model_args[SAFETY_SETTINGS]

-        #
-
-
+        # extract any service prefix from model name
+        parts = model_name.split("/")
+        if len(parts) > 1:
+            self.service: str | None = parts[0]
+            model_name = "/".join(parts[1:])
+        else:
+            self.service = None
+
+        # vertex can also be forced by the GOOGLE_GENAI_USE_VERTEX_AI flag
+        if self.service is None:
+            if os.environ.get("GOOGLE_GENAI_USE_VERTEXAI", "").lower() == "true":
+                self.service = "vertex"
+
+        # ensure we haven't specified an invalid service
+        if self.service is not None and self.service != "vertex":
+            raise RuntimeError(
+                f"Invalid service name for google: {self.service}. "
+                + "Currently 'vertex' is the only supported service."
+            )
+
+        # handle auth (vertex or standard google api key)
+        if self.is_vertex():
+            # see if we are running in express mode (propagate api key if we are)
+            # https://cloud.google.com/vertex-ai/generative-ai/docs/start/express-mode/overview
+            vertex_api_key = os.environ.get(VERTEX_API_KEY, None)
+            if vertex_api_key and not self.api_key:
+                self.api_key = vertex_api_key
+
+            # When not using express mode the GOOGLE_CLOUD_PROJECT and GOOGLE_CLOUD_LOCATION
+            # environment variables should be set, OR the 'project' and 'location' should be
+            # passed within the model_args.
+            # https://cloud.google.com/vertex-ai/generative-ai/docs/gemini-v2
+            if not vertex_api_key:
+                if not os.environ.get(
+                    "GOOGLE_CLOUD_PROJECT", None
+                ) and not model_args.get("project", None):
+                    raise PrerequisiteError(
+                        "Google provider requires either the GOOGLE_CLOUD_PROJECT environment variable "
+                        + "or the 'project' custom model arg (-M) when running against vertex."
+                    )
+                if not os.environ.get(
+                    "GOOGLE_CLOUD_LOCATION", None
+                ) and not model_args.get("location", None):
+                    raise PrerequisiteError(
+                        "Google provider requires either the GOOGLE_CLOUD_LOCATION environment variable "
+                        + "or the 'location' custom model arg (-M) when running against vertex."
+                    )
+
+        # normal google endpoint
+        else:
+            # read api key from env
+            if not self.api_key:
+                self.api_key = os.environ.get(GOOGLE_API_KEY, None)
+
+        # custom base_url
+        base_url = model_base_url(base_url, "GOOGLE_BASE_URL")
+
+        # create client
+        self.client = Client(
+            vertexai=self.is_vertex(),
             api_key=self.api_key,
-
+            http_options={"base_url": base_url},
             **model_args,
         )

-        # create model
-        self.model = GenerativeModel(self.model_name)
-
     @override
     async def close(self) -> None:
         # GenerativeModel uses a cached/shared client so there is no 'close'
         pass

+    def is_vertex(self) -> bool:
+        return self.service == "vertex"
+
     async def generate(
         self,
         input: list[ChatMessage],
```
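The constructor now resolves a `service` from the model name, so Vertex AI is reachable without a separate provider. A sketch of the selection routes implied by the code above (project, location, and key values are placeholders):

```python
# Sketch: ways the constructor above ends up with service == "vertex".
import os
from inspect_ai.model import get_model

# 1. explicit service prefix in the model name
vertex_model = get_model("google/vertex/gemini-1.5-pro")

# 2. the google-genai environment flag (no prefix needed)
os.environ["GOOGLE_GENAI_USE_VERTEXAI"] = "true"

# 3. express mode: VERTEX_API_KEY stands in for project/location
os.environ["VERTEX_API_KEY"] = "your-key"  # placeholder

# Without an API key, GOOGLE_CLOUD_PROJECT / GOOGLE_CLOUD_LOCATION (or the
# 'project' / 'location' model args) must be set, else PrerequisiteError.
```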
```diff
@@ -146,7 +196,11 @@ class GoogleAPI(ModelAPI):
         tool_choice: ToolChoice,
         config: GenerateConfig,
     ) -> ModelOutput | tuple[ModelOutput | Exception, ModelCall]:
-
+        # Create google-genai types.
+        gemini_contents = await as_chat_messages(self.client, input)
+        gemini_tools = chat_tools(tools) if len(tools) > 0 else None
+        gemini_tool_config = chat_tool_config(tool_choice) if len(tools) > 0 else None
+        parameters = GenerateContentConfig(
             temperature=config.temperature,
             top_p=config.top_p,
             top_k=config.top_k,
```
```diff
@@ -155,21 +209,19 @@
             candidate_count=config.num_choices,
             presence_penalty=config.presence_penalty,
             frequency_penalty=config.frequency_penalty,
+            safety_settings=safety_settings_to_list(self.safety_settings),
+            tools=gemini_tools,
+            tool_config=gemini_tool_config,
+            system_instruction=await extract_system_message_as_parts(
+                self.client, input
+            ),
         )

-
-        contents = await as_chat_messages(input)
-
-        # tools
-        gemini_tools = chat_tools(tools) if len(tools) > 0 else None
-        gemini_tool_config = chat_tool_config(tool_choice) if len(tools) > 0 else None
-
-        # response for ModelCall
-        response: AsyncGenerateContentResponse | None = None
+        response: GenerateContentResponse | None = None

         def model_call() -> ModelCall:
             return build_model_call(
-                contents=
+                contents=gemini_contents,
                 safety_settings=self.safety_settings,
                 generation_config=parameters,
                 tools=gemini_tools,
```
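Note how request assembly changed shape: generation parameters, safety settings, tools, and the system instruction are now all fields of a single `GenerateContentConfig` rather than separate arguments to `generate_content`. A standalone sketch (all values are placeholders):

```python
# Sketch: the consolidated request config used by the new generate() above.
from google.genai.types import GenerateContentConfig

config = GenerateContentConfig(
    temperature=0.7,
    max_output_tokens=1024,
    # replaces the old prepend-system-messages workaround
    system_instruction="You are a helpful assistant.",
)
```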
```diff
@@ -177,164 +229,149 @@
                 response=response,
             )

+        # TODO: would need to monkey patch AuthorizedSession.request
+
         try:
-            response = await self.
-
-
-
-                tools=gemini_tools,
-                tool_config=gemini_tool_config,
+            response = await self.client.aio.models.generate_content(
+                model=self.model_name,
+                contents=gemini_contents,
+                config=parameters,
             )
+        except ClientError as ex:
+            return self.handle_client_error(ex), model_call()

-        except InvalidArgument as ex:
-            return self.handle_invalid_argument(ex), model_call()
-
-        # build output
         output = ModelOutput(
             model=self.model_name,
-            choices=completion_choices_from_candidates(response
-            usage=
-                input_tokens=response.usage_metadata.prompt_token_count,
-                output_tokens=response.usage_metadata.candidates_token_count,
-                total_tokens=response.usage_metadata.total_token_count,
-            ),
+            choices=completion_choices_from_candidates(response),
+            usage=usage_metadata_to_model_usage(response.usage_metadata),
         )

-        # return
         return output, model_call()

-    def handle_invalid_argument(self, ex: InvalidArgument) -> ModelOutput | Exception:
-        if "size exceeds the limit" in ex.message.lower():
-            return ModelOutput.from_content(
-                model=self.model_name, content=ex.message, stop_reason="model_length"
-            )
-        else:
-            return ex
-
     @override
     def is_rate_limit(self, ex: BaseException) -> bool:
-        return isinstance(
-            ex,
-            TooManyRequests | InternalServerError | ServiceUnavailable | GatewayTimeout,
-        )
+        return isinstance(ex, APIError) and ex.code in (429, 500, 503, 504)

     @override
     def connection_key(self) -> str:
         """Scope for enforcing max_connections (could also use endpoint)."""
         return self.model_name

+    def handle_client_error(self, ex: ClientError) -> ModelOutput | Exception:
+        if (
+            ex.code == 400
+            and ex.message
+            and (
+                "maximum number of tokens" in ex.message
+                or "size exceeds the limit" in ex.message
+            )
+        ):
+            return ModelOutput.from_content(
+                self.model_name, content=ex.message, stop_reason="model_length"
+            )
+        else:
+            raise ex
+
+
+def safety_settings_to_list(safety_settings: SafetySettingDict) -> list[SafetySetting]:
+    return [
+        SafetySetting(
+            category=category,
+            threshold=threshold,
+        )
+        for category, threshold in safety_settings.items()
+    ]
+

 def build_model_call(
-    contents: list[
+    contents: list[Content],
     generation_config: GenerationConfig,
-    safety_settings:
+    safety_settings: SafetySettingDict,
     tools: list[Tool] | None,
     tool_config: ToolConfig | None,
-    response:
+    response: GenerateContentResponse | None,
 ) -> ModelCall:
     return ModelCall.create(
         request=dict(
-            contents=
+            contents=contents,
             generation_config=generation_config,
             safety_settings=safety_settings,
-            tools=
-            if
-            else None,
-            tool_config=MessageToDict(tool_config._pb)
-            if tool_config is not None
-            else None,
+            tools=tools if tools is not None else None,
+            tool_config=tool_config if tool_config is not None else None,
         ),
-        response=response
+        response=response if response is not None else {},
         filter=model_call_filter,
     )


 def model_call_filter(key: JsonValue | None, value: JsonValue) -> JsonValue:
-    # remove images from raw api call
     if key == "inline_data" and isinstance(value, dict) and "data" in value:
         value = copy(value)
         value.update(data=BASE_64_DATA_REMOVED)
     return value


-def
-
-
-
-
-
-    def model_call_part(part: PartType) -> PartType:
-        if isinstance(part, proto.Message):
-            return cast(PartDict, MessageToDict(part._pb))
-        elif isinstance(part, dict):
-            part = part.copy()
-            keys = list(part.keys())
-            for key in keys:
-                part[key] = model_call_part(part[key])  # type: ignore[literal-required]
-            return part
-        else:
-            return part
-
-
-async def as_chat_messages(messages: list[ChatMessage]) -> list[ContentDict]:
-    # google does not support system messages so filter them out to start with
-    system_messages = [message for message in messages if message.role == "system"]
+async def as_chat_messages(
+    client: Client, messages: list[ChatMessage]
+) -> list[Content]:
+    # There is no "system" role in the `google-genai` package. Instead, system messages
+    # are included in the `GenerateContentConfig` as a `system_instruction`. Strip any
+    # system messages out.
     supported_messages = [message for message in messages if message.role != "system"]

     # build google chat messages
-    chat_messages = [await
-
-    # we want the system messages to be prepended to the first user message
-    # (if there is no first user message then prepend one)
-    prepend_system_messages(chat_messages, system_messages)
+    chat_messages = [await content(client, message) for message in supported_messages]

     # combine consecutive tool messages
-    chat_messages = functools.reduce(
+    chat_messages = functools.reduce(
+        consecutive_tool_message_reducer, chat_messages, []
+    )

     # return messages
     return chat_messages


-def
-    messages: list[
-    message:
-) -> list[
+def consecutive_tool_message_reducer(
+    messages: list[Content],
+    message: Content,
+) -> list[Content]:
     if (
-        message
+        message.role == "function"
         and len(messages) > 0
-        and messages[-1]
+        and messages[-1].role == "function"
     ):
-        messages[-1] =
-            role="function", parts=messages[-1]
+        messages[-1] = Content(
+            role="function", parts=messages[-1].parts + message.parts
         )
     else:
         messages.append(message)
     return messages


-async def
+async def content(
+    client: Client,
     message: ChatMessageUser | ChatMessageAssistant | ChatMessageTool,
-) ->
+) -> Content:
     if isinstance(message, ChatMessageUser):
-
+        if isinstance(message.content, str):
+            return Content(
+                role="user", parts=[await content_part(client, message.content)]
+            )
+        return Content(
             role="user",
             parts=(
-                [message.content
-                if isinstance(message.content, str)
-                else [await content_part(content) for content in message.content]
+                [await content_part(client, content) for content in message.content]
             ),
         )
     elif isinstance(message, ChatMessageAssistant):
-        content_parts: list[
+        content_parts: list[Part] = []
         # tool call parts
         if message.tool_calls is not None:
             content_parts.extend(
                 [
-                    Part(
-
-
-                        args=dict_to_struct(tool_call.arguments),
-                    )
+                    Part.from_function_call(
+                        name=tool_call.function,
+                        args=tool_call.arguments,
                     )
                     for tool_call in message.tool_calls
                 ]
```
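The conversion helpers above rebuild Inspect's chat history as `google.genai.types.Content` values, merging consecutive `function`-role messages because the API expects a single function turn. A sketch of the shapes these helpers produce (the tool name and values are hypothetical, not from the diff):

```python
# Sketch of the Content/Part shapes produced by content() above
# (assumes google-genai; "add" and its args are hypothetical).
from google.genai.types import Content, Part

user_turn = Content(role="user", parts=[Part.from_text(text="What is 2 + 2?")])
model_turn = Content(
    role="model",
    parts=[Part.from_function_call(name="add", args={"x": 2, "y": 2})],
)
tool_turn = Content(
    role="function",
    parts=[Part.from_function_response(name="add", response={"content": "4"})],
)
```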
```diff
@@ -345,68 +382,62 @@ async def content_dict(
             content_parts.append(Part(text=message.content or NO_CONTENT))
         else:
             content_parts.extend(
-                [await content_part(content) for content in message.content]
+                [await content_part(client, content) for content in message.content]
             )

         # return parts
-        return
+        return Content(role="model", parts=content_parts)

     elif isinstance(message, ChatMessageTool):
         response = FunctionResponse(
             name=message.tool_call_id,
-            response=
-
-
-
-
-                    else message.text
-                )
-            },
-            message=Struct(),
-            ),
+            response={
+                "content": (
+                    message.error.message if message.error is not None else message.text
+                )
+            },
         )
-        return
-
+        return Content(role="function", parts=[Part(function_response=response)])

-def dict_to_struct(x: dict[str, Any]) -> Struct:
-    struct = Struct()
-    struct.update(x)
-    return struct

-
-async def content_part(content: Content | str) -> PartType:
+async def content_part(client: Client, content: InspectContent | str) -> Part:
     if isinstance(content, str):
-        return content or NO_CONTENT
+        return Part.from_text(text=content or NO_CONTENT)
     elif isinstance(content, ContentText):
-        return content.text or NO_CONTENT
+        return Part.from_text(text=content.text or NO_CONTENT)
     else:
-        return await chat_content_to_part(content)
+        return await chat_content_to_part(client, content)


 async def chat_content_to_part(
+    client: Client,
     content: ContentImage | ContentAudio | ContentVideo,
-) ->
+) -> Part:
     if isinstance(content, ContentImage):
         content_bytes, mime_type = await file_as_data(content.image)
-        return
-    else:
-        return await file_for_content(content)
-
-
-def prepend_system_messages(
-    messages: list[ContentDict], system_messages: list[ChatMessageSystem]
-) -> None:
-    # create system_parts
-    system_parts: list[PartType] = [
-        Part(text=message.text) for message in system_messages
-    ]
-
-    # we want the system messages to be prepended to the first user message
-    # (if there is no first user message then prepend one)
-    if len(messages) > 0 and messages[0].get("role") == "user":
-        messages[0]["parts"] = system_parts + messages[0].get("parts", [])
+        return Part.from_bytes(mime_type=mime_type, data=content_bytes)
     else:
-
+        return await file_for_content(client, content)
+
+
+async def extract_system_message_as_parts(
+    client: Client,
+    messages: list[ChatMessage],
+) -> list[Part] | None:
+    system_parts: list[Part] = []
+    for message in messages:
+        if message.role == "system":
+            content = message.content
+            if isinstance(content, str):
+                system_parts.append(Part.from_text(text=content))
+            elif isinstance(content, list):  # list[InspectContent]
+                system_parts.extend(
+                    [await content_part(client, content) for content in content]
+                )
+            else:
+                raise ValueError(f"Unsupported system message content: {content}")
+    # google-genai raises "ValueError: content is required." if the list is empty.
+    return system_parts or None


 def chat_tools(tools: list[ToolInfo]) -> list[Tool]:
```
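One detail worth noting above: inline images become `Part.from_bytes` parts, while audio and video round-trip through the Files API (`file_for_content`, at the end of this diff). A one-line sketch (the bytes and MIME type are placeholders):

```python
# Sketch: inline image data as a Part, as in chat_content_to_part() above.
from google.genai.types import Part

image_part = Part.from_bytes(mime_type="image/png", data=b"\x89PNG...")
```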
```diff
@@ -424,8 +455,6 @@ def chat_tools(tools: list[ToolInfo]) -> list[Tool]:


 # https://ai.google.dev/gemini-api/tutorials/extract_structured_data#define_the_schema
-
-
 def schema_from_param(param: ToolParam | ToolParams, nullable: bool = False) -> Schema:
     if isinstance(param, ToolParams):
         param = ToolParam(
```
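`schema_from_param` (continued below) maps Inspect's `ToolParam` JSON Schema onto Gemini `Schema` objects, per the structured-data tutorial linked above. A sketch of the target shape (field names are illustrative):

```python
# Sketch: the kind of Schema that schema_from_param() produces.
from google.genai.types import Schema, Type

schema = Schema(
    type=Type.OBJECT,
    properties={"city": Schema(type=Type.STRING, description="City name")},
    required=["city"],
)
```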
```diff
@@ -461,7 +490,7 @@ def schema_from_param(param: ToolParam | ToolParams, nullable: bool = False) -> Schema:
             description=param.description,
             properties={k: schema_from_param(v) for k, v in param.properties.items()}
             if param.properties is not None
-            else
+            else {},
             required=param.required,
             nullable=nullable,
         )
```
```diff
@@ -478,57 +507,56 @@ def schema_from_param(param: ToolParam | ToolParams, nullable: bool = False) -> Schema:


 def chat_tool_config(tool_choice: ToolChoice) -> ToolConfig:
-
-
-
-
-
-
-
-
-
-
-    # by commenting this back in and running pytest -k google_tools
-    #
-    # if isinstance(tool_choice, ToolFunction):
-    #     return ToolConfig(
-    #         function_calling_config=FunctionCallingConfig(
-    #             mode="ANY", allowed_function_names=[tool_choice.name]
-    #         )
-    #     )
-    # else:
-    #     return ToolConfig(
-    #         function_calling_config=FunctionCallingConfig(mode=tool_choice.upper())
-    #     )
+    if isinstance(tool_choice, ToolFunction):
+        return ToolConfig(
+            function_calling_config=FunctionCallingConfig(
+                mode="ANY", allowed_function_names=[tool_choice.name]
+            )
+        )
+    else:
+        return ToolConfig(
+            function_calling_config=FunctionCallingConfig(mode=tool_choice.upper())
+        )


 def completion_choice_from_candidate(candidate: Candidate) -> ChatCompletionChoice:
     # check for completion text
-    content = "
-
-
+    content = ""
+    # content can be None when the finish_reason is SAFETY
+    if candidate.content is not None:
+        content = " ".join(
+            [
+                part.text
+                for part in candidate.content.parts
+                if part.text is not None and candidate.content is not None
+            ]
+        )
+
+    # split reasoning
+    reasoning, content = split_reasoning(content)

     # now tool calls
     tool_calls: list[ToolCall] = []
-
-
-
-
-
-
-
-
-
+    if candidate.content is not None and candidate.content.parts is not None:
+        for part in candidate.content.parts:
+            if part.function_call:
+                tool_calls.append(
+                    ToolCall(
+                        type="function",
+                        id=part.function_call.name,
+                        function=part.function_call.name,
+                        arguments=part.function_call.args,
+                    )
                 )
-        )

     # stop reason
-    stop_reason =
+    stop_reason = finish_reason_to_stop_reason(candidate.finish_reason)

-    # build
+    # build choice
     choice = ChatCompletionChoice(
         message=ChatMessageAssistant(
             content=content,
+            reasoning=reasoning,
             tool_calls=tool_calls if len(tool_calls) > 0 else None,
             source="generate",
         ),
```
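The previously commented-out forced tool choice is now live: a specific `ToolFunction` maps to mode `"ANY"` with an allow-list. A sketch of the resulting config (the tool name is a placeholder):

```python
# Sketch: what chat_tool_config(ToolFunction(name="submit")) returns above.
from google.genai.types import FunctionCallingConfig, ToolConfig

forced = ToolConfig(
    function_calling_config=FunctionCallingConfig(
        mode="ANY", allowed_function_names=["submit"]  # placeholder tool name
    )
)
```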
```diff
@@ -558,111 +586,144 @@ def completion_choice_from_candidate(candidate: Candidate) -> ChatCompletionChoice:


 def completion_choices_from_candidates(
-
+    response: GenerateContentResponse,
 ) -> list[ChatCompletionChoice]:
+    candidates = response.candidates
     if candidates:
         candidates_list = sorted(candidates, key=lambda c: c.index)
         return [
             completion_choice_from_candidate(candidate) for candidate in candidates_list
         ]
-
+    elif response.prompt_feedback:
         return [
             ChatCompletionChoice(
                 message=ChatMessageAssistant(
-                    content=
+                    content=prompt_feedback_to_content(response.prompt_feedback),
                     source="generate",
                 ),
-                stop_reason="
+                stop_reason="content_filter",
             )
         ]
+    else:
+        raise RuntimeError(
+            "Google response includes no completion candidates and no block reason: "
+            + f"{response.model_dump_json(indent=2)}"
+        )


-
-
-
-
-
-
-
-
-    OTHER = 5
+def split_reasoning(content: str) -> tuple[str | None, str]:
+    separator = "\nFinal Answer: "
+    if separator in content:
+        parts = content.split(separator, 1)  # split only on first occurrence
+        return parts[0].strip(), separator.lstrip() + parts[1].strip()
+    else:
+        return None, content.strip()
+

+def prompt_feedback_to_content(
+    feedback: GenerateContentResponsePromptFeedback,
+) -> str:
+    content: list[str] = []
+    block_reason = str(feedback.block_reason) if feedback.block_reason else "UNKNOWN"
+    content.append(f"BLOCKED: {block_reason}")

-
+    if feedback.block_reason_message is not None:
+        content.append(feedback.block_reason_message)
+    if feedback.safety_ratings is not None:
+        content.extend(
+            [rating.model_dump_json(indent=2) for rating in feedback.safety_ratings]
+        )
+    return "\n".join(content)
+
+
+def usage_metadata_to_model_usage(
+    metadata: GenerateContentResponseUsageMetadata,
+) -> ModelUsage | None:
+    if metadata is None:
+        return None
+    return ModelUsage(
+        input_tokens=metadata.prompt_token_count or 0,
+        output_tokens=metadata.candidates_token_count or 0,
+        total_tokens=metadata.total_token_count or 0,
+    )
+
+
+def finish_reason_to_stop_reason(finish_reason: FinishReason) -> StopReason:
     match finish_reason:
         case FinishReason.STOP:
             return "stop"
         case FinishReason.MAX_TOKENS:
             return "max_tokens"
-        case
+        case (
+            FinishReason.SAFETY
+            | FinishReason.RECITATION
+            | FinishReason.BLOCKLIST
+            | FinishReason.PROHIBITED_CONTENT
+            | FinishReason.SPII
+        ):
             return "content_filter"
         case _:
             return "unknown"


-def gapi_should_retry(ex: BaseException) -> bool:
-    if isinstance(ex, Exception):
-        return if_transient_error(ex)
-    else:
-        return False
-
-
 def parse_safety_settings(
     safety_settings: Any,
-) ->
+) -> dict[HarmCategory, HarmBlockThreshold]:
     # ensure we have a dict
     if isinstance(safety_settings, str):
         safety_settings = json.loads(safety_settings)
     if not isinstance(safety_settings, dict):
         raise ValueError(f"{SAFETY_SETTINGS} must be dictionary.")

-    parsed_settings:
+    parsed_settings: dict[HarmCategory, HarmBlockThreshold] = {}
     for key, value in safety_settings.items():
-        if isinstance(key, str):
-            key = str_to_harm_category(key)
-        if not isinstance(key, HarmCategory):
+        if not isinstance(key, str):
             raise ValueError(f"Unexpected type for harm category: {key}")
-        if isinstance(value, str):
-            value = str_to_harm_block_threshold(value)
-        if not isinstance(value, HarmBlockThreshold):
+        if not isinstance(value, str):
             raise ValueError(f"Unexpected type for harm block threshold: {value}")
-
+        key = str_to_harm_category(key)
+        value = str_to_harm_block_threshold(value)
         parsed_settings[key] = value
-
     return parsed_settings


-def str_to_harm_category(category: str) ->
+def str_to_harm_category(category: str) -> HarmCategory:
     category = category.upper()
+    # `in` instead of `==` to allow users to pass in short version e.g. "HARASSMENT" or
+    # long version e.g. "HARM_CATEGORY_HARASSMENT" strings.
+    if "CIVIC_INTEGRITY" in category:
+        return HarmCategory.HARM_CATEGORY_CIVIC_INTEGRITY
+    if "DANGEROUS_CONTENT" in category:
+        return HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT
+    if "HATE_SPEECH" in category:
+        return HarmCategory.HARM_CATEGORY_HATE_SPEECH
     if "HARASSMENT" in category:
-        return
-
-        return
-
-        return
-
-        return cast(int, HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT)
-    else:
-        # NOTE: Although there is an "UNSPECIFIED" category, in the
-        # documentation, the API does not accept it.
-        raise ValueError(f"Unknown HarmCategory: {category}")
+        return HarmCategory.HARM_CATEGORY_HARASSMENT
+    if "SEXUALLY_EXPLICIT" in category:
+        return HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT
+    if "UNSPECIFIED" in category:
+        return HarmCategory.HARM_CATEGORY_UNSPECIFIED
+    raise ValueError(f"Unknown HarmCategory: {category}")


-def str_to_harm_block_threshold(threshold: str) ->
+def str_to_harm_block_threshold(threshold: str) -> HarmBlockThreshold:
     threshold = threshold.upper()
     if "LOW" in threshold:
         return HarmBlockThreshold.BLOCK_LOW_AND_ABOVE
-
+    if "MEDIUM" in threshold:
         return HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE
-
+    if "HIGH" in threshold:
         return HarmBlockThreshold.BLOCK_ONLY_HIGH
-
+    if "NONE" in threshold:
         return HarmBlockThreshold.BLOCK_NONE
-
-
+    if "OFF" in threshold:
+        return HarmBlockThreshold.OFF
+    raise ValueError(f"Unknown HarmBlockThreshold: {threshold}")


-async def file_for_content(
+async def file_for_content(
+    client: Client, content: ContentAudio | ContentVideo
+) -> File:
     # helper to write trace messages
     def trace(message: str) -> None:
         trace_message(logger, "Google Files", message)
```
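The parsers above are deliberately permissive: categories and thresholds match by substring on the upper-cased string, so short and long forms both work. A sketch using the functions defined in this file (assumed importable here for illustration):

```python
# Sketch: substring matching in the safety-settings parsers above
# (str_to_harm_category / str_to_harm_block_threshold are module-internal).
from google.genai.types import HarmBlockThreshold, HarmCategory

assert str_to_harm_category("harassment") == HarmCategory.HARM_CATEGORY_HARASSMENT
assert str_to_harm_block_threshold("only_high") == HarmBlockThreshold.BLOCK_ONLY_HIGH
```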
```diff
@@ -674,7 +735,6 @@ async def file_for_content(content: ContentAudio | ContentVideo) -> File:
         file = content.video
     content_bytes, mime_type = await file_as_data(file)
     content_sha256 = hashlib.sha256(content_bytes).hexdigest()
-
     # we cache uploads for re-use, open the db where we track that
     # (track up to 1 million previous uploads)
     with inspect_kvstore("google_files", 1000000) as files_db:
```
```diff
@@ -682,7 +742,7 @@
         uploaded_file = files_db.get(content_sha256)
         if uploaded_file:
             try:
-                upload =
+                upload: File = client.files.get(uploaded_file)
                 if upload.state.name == "ACTIVE":
                     trace(f"Using uploaded file: {uploaded_file}")
                     return upload
```
```diff
@@ -693,20 +753,16 @@
             except Exception as ex:
                 trace(f"Error attempting to access uploaded file: {ex}")
                 files_db.delete(content_sha256)
-
         # do the upload (and record it)
-        upload =
+        upload = client.files.upload(BytesIO(content_bytes), mime_type=mime_type)
         while upload.state.name == "PROCESSING":
             await asyncio.sleep(3)
-            upload =
-
+            upload = client.files.get(upload.name)
         if upload.state.name == "FAILED":
             trace(f"Failed to upload file '{upload.name}: {upload.error}")
             raise ValueError(f"Google file upload failed: {upload.error}")
-
         # trace and record it
         trace(f"Uploaded file: {upload.name}")
         files_db.put(content_sha256, upload.name)
-
         # return the file
         return upload
```
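`file_for_content` caches uploads by SHA-256 in a local kvstore so repeated samples reuse the same Google Files entry. A sketch of the underlying upload-and-poll flow, following the calls shown in the diff (the payload and MIME type are placeholders):

```python
# Sketch: the upload-and-poll flow that file_for_content() wraps above,
# minus the kvstore cache; calls mirror the diff rather than verified docs.
import time
from io import BytesIO

from google.genai import Client

client = Client()
upload = client.files.upload(BytesIO(b"..."), mime_type="audio/wav")
while upload.state.name == "PROCESSING":
    time.sleep(3)
    upload = client.files.get(upload.name)
```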