google-genai 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- google/genai/__init__.py +20 -0
- google/genai/_api_client.py +467 -0
- google/genai/_automatic_function_calling_util.py +341 -0
- google/genai/_common.py +256 -0
- google/genai/_extra_utils.py +295 -0
- google/genai/_replay_api_client.py +478 -0
- google/genai/_test_api_client.py +149 -0
- google/genai/_transformers.py +438 -0
- google/genai/batches.py +1041 -0
- google/genai/caches.py +1830 -0
- google/genai/chats.py +184 -0
- google/genai/client.py +277 -0
- google/genai/errors.py +110 -0
- google/genai/files.py +1211 -0
- google/genai/live.py +629 -0
- google/genai/models.py +5307 -0
- google/genai/pagers.py +245 -0
- google/genai/tunings.py +1366 -0
- google/genai/types.py +7639 -0
- google_genai-0.0.1.dist-info/LICENSE +202 -0
- google_genai-0.0.1.dist-info/METADATA +763 -0
- google_genai-0.0.1.dist-info/RECORD +24 -0
- google_genai-0.0.1.dist-info/WHEEL +5 -0
- google_genai-0.0.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,438 @@
|
|
1
|
+
# Copyright 2024 Google LLC
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
#
|
15
|
+
|
16
|
+
"""Transformers for Google GenAI SDK."""
|
17
|
+
|
18
|
+
import base64
|
19
|
+
from collections.abc import Iterable, Mapping
|
20
|
+
import inspect
|
21
|
+
import io
|
22
|
+
import re
|
23
|
+
import time
|
24
|
+
from typing import Any, Optional, Union
|
25
|
+
|
26
|
+
import PIL.Image
|
27
|
+
|
28
|
+
from . import _api_client
|
29
|
+
from . import types
|
30
|
+
from ._automatic_function_calling_util import function_to_declaration
|
31
|
+
|
32
|
+
|
33
|
+
def _resource_name(
|
34
|
+
client: _api_client.ApiClient,
|
35
|
+
resource_name: str,
|
36
|
+
*,
|
37
|
+
collection_identifier: str,
|
38
|
+
collection_hirearchy_depth: int = 2,
|
39
|
+
):
|
40
|
+
# pylint: disable=line-too-long
|
41
|
+
"""Prepends resource name with project, location, collection_identifier if needed.
|
42
|
+
|
43
|
+
The collection_identifier will only be prepended if it's not present
|
44
|
+
and the prepending won't violate the collection hierarchy depth.
|
45
|
+
When the prepending condition doesn't meet, returns the input
|
46
|
+
resource_name.
|
47
|
+
|
48
|
+
Args:
|
49
|
+
client: The API client.
|
50
|
+
resource_name: The user input resource name to be completed.
|
51
|
+
collection_identifier: The collection identifier to be prepended.
|
52
|
+
See collection identifiers in https://google.aip.dev/122.
|
53
|
+
collection_hirearchy_depth: The collection hierarchy depth.
|
54
|
+
Only set this field when the resource has nested collections.
|
55
|
+
For example, `users/vhugo1802/events/birthday-dinner-226`, the
|
56
|
+
collection_identifier is `users` and collection_hirearchy_depth is 4.
|
57
|
+
See nested collections in https://google.aip.dev/122.
|
58
|
+
|
59
|
+
Example:
|
60
|
+
|
61
|
+
resource_name = 'cachedContents/123'
|
62
|
+
client.vertexai = True
|
63
|
+
client.project = 'bar'
|
64
|
+
client.location = 'us-west1'
|
65
|
+
_resource_name(client, 'cachedContents/123', collection_identifier='cachedContents')
|
66
|
+
returns: 'projects/bar/locations/us-west1/cachedContents/123'
|
67
|
+
|
68
|
+
Example:
|
69
|
+
|
70
|
+
resource_name = 'projects/foo/locations/us-central1/cachedContents/123'
|
71
|
+
# resource_name = 'locations/us-central1/cachedContents/123'
|
72
|
+
client.vertexai = True
|
73
|
+
client.project = 'bar'
|
74
|
+
client.location = 'us-west1'
|
75
|
+
_resource_name(client, resource_name, collection_identifier='cachedContents')
|
76
|
+
returns: 'projects/foo/locations/us-central1/cachedContents/123'
|
77
|
+
|
78
|
+
Example:
|
79
|
+
|
80
|
+
resource_name = '123'
|
81
|
+
# resource_name = 'cachedContents/123'
|
82
|
+
client.vertexai = False
|
83
|
+
_resource_name(client, resource_name, collection_identifier='cachedContents')
|
84
|
+
returns 'cachedContents/123'
|
85
|
+
|
86
|
+
Example:
|
87
|
+
resource_name = 'some/wrong/cachedContents/resource/name/123'
|
88
|
+
resource_prefix = 'cachedContents'
|
89
|
+
client.vertexai = False
|
90
|
+
# client.vertexai = True
|
91
|
+
_resource_name(client, resource_name, collection_identifier='cachedContents')
|
92
|
+
returns: 'some/wrong/cachedContents/resource/name/123'
|
93
|
+
|
94
|
+
Returns:
|
95
|
+
The completed resource name.
|
96
|
+
"""
|
97
|
+
should_prepend_collection_identifier = (
|
98
|
+
not resource_name.startswith(f'{collection_identifier}/')
|
99
|
+
# Check if prepending the collection identifier won't violate the
|
100
|
+
# collection hierarchy depth.
|
101
|
+
and f'{collection_identifier}/{resource_name}'.count('/') + 1
|
102
|
+
== collection_hirearchy_depth
|
103
|
+
)
|
104
|
+
if client.vertexai:
|
105
|
+
if resource_name.startswith('projects/'):
|
106
|
+
return resource_name
|
107
|
+
elif resource_name.startswith('locations/'):
|
108
|
+
return f'projects/{client.project}/{resource_name}'
|
109
|
+
elif resource_name.startswith(f'{collection_identifier}/'):
|
110
|
+
return f'projects/{client.project}/locations/{client.location}/{resource_name}'
|
111
|
+
elif should_prepend_collection_identifier:
|
112
|
+
return f'projects/{client.project}/locations/{client.location}/{collection_identifier}/{resource_name}'
|
113
|
+
else:
|
114
|
+
return resource_name
|
115
|
+
else:
|
116
|
+
if should_prepend_collection_identifier:
|
117
|
+
return f'{collection_identifier}/{resource_name}'
|
118
|
+
else:
|
119
|
+
return resource_name
|
120
|
+
|
121
|
+
|
122
|
+
def t_model(client: _api_client.ApiClient, model: str):
  """Normalizes a user-supplied model name for the target API surface.

  On Vertex AI, full resource names pass through; `publisher/model`
  shorthand expands to `publishers/<publisher>/models/<model>`; a bare
  name is assumed to be a Google model. On the Gemini API, a bare name
  gets the `models/` prefix.

  Raises:
    ValueError: If `model` is empty.
  """
  if not model:
    raise ValueError('model is required.')

  if not client.vertexai:
    # Already a (tuned) model resource name — leave untouched.
    if model.startswith(('models/', 'tunedModels/')):
      return model
    return f'models/{model}'

  if model.startswith(('projects/', 'models/', 'publishers/')):
    return model
  if '/' in model:
    # `<publisher>/<model_id>` shorthand.
    publisher, _, model_id = model.partition('/')
    return f'publishers/{publisher}/models/{model_id}'
  return f'publishers/google/models/{model}'
|
144
|
+
|
145
|
+
def t_caches_model(api_client: _api_client.ApiClient, model: str):
  """Resolves a model name into the form accepted by the caches API.

  Vertex caches only accept project-scoped model resource names, so
  publisher- or model-prefixed names are expanded under the client's
  project and location. Other names pass through unchanged.
  """
  resolved = t_model(api_client, model)
  if not resolved:
    return None
  if not api_client.vertexai:
    return resolved

  location_prefix = (
      f'projects/{api_client.project}/locations/{api_client.location}'
  )
  if resolved.startswith('publishers/'):
    # vertex caches only support model name start with projects.
    return f'{location_prefix}/{resolved}'
  if resolved.startswith('models/'):
    return f'{location_prefix}/publishers/google/{resolved}'
  return resolved
|
158
|
+
|
159
|
+
|
160
|
+
def pil_to_blob(img):
  """Encodes a PIL image into a `types.Blob`.

  PNG files and images with an alpha channel are kept as PNG; everything
  else is re-encoded as JPEG (which cannot represent alpha).

  Args:
    img: A `PIL.Image.Image` instance.

  Returns:
    A `types.Blob` with the encoded bytes and matching mime type.
  """
  # Pillow loads format plugins lazily: `import PIL.Image` alone does not
  # guarantee that `PIL.PngImagePlugin` exists as an attribute of the PIL
  # package, so import it explicitly before the isinstance check below.
  import PIL.PngImagePlugin

  bytesio = io.BytesIO()
  if isinstance(img, PIL.PngImagePlugin.PngImageFile) or img.mode == 'RGBA':
    img.save(bytesio, format='PNG')
    mime_type = 'image/png'
  else:
    img.save(bytesio, format='JPEG')
    mime_type = 'image/jpeg'
  bytesio.seek(0)
  data = bytesio.read()
  return types.Blob(mime_type=mime_type, data=data)
|
171
|
+
|
172
|
+
|
173
|
+
# Accepted forms for a single content part: an SDK Part, its dict form,
# plain text, or a PIL image (converted to a Blob by pil_to_blob).
PartType = Union[types.Part, types.PartDict, str, PIL.Image.Image]
|
174
|
+
|
175
|
+
|
176
|
+
def t_part(client: _api_client.ApiClient, part: PartType) -> types.Part:
  """Coerces a part-like value into a `types.Part`.

  Strings become text parts and PIL images become inline-data parts;
  anything else (Part or dict form) is returned unchanged.

  Raises:
    ValueError: If `part` is falsy (None, empty string, ...).
  """
  if not part:
    raise ValueError('content part is required.')
  if isinstance(part, str):
    return types.Part(text=part)
  if isinstance(part, PIL.Image.Image):
    return types.Part(inline_data=pil_to_blob(part))
  return part
|
185
|
+
|
186
|
+
|
187
|
+
def t_parts(
    client: _api_client.ApiClient, parts: Union[list, PartType]
) -> list[types.Part]:
  """Coerces a part or list of parts into a list of `types.Part`.

  Raises:
    ValueError: If `parts` is None.
  """
  if parts is None:
    raise ValueError('content parts are required.')
  if not isinstance(parts, list):
    return [t_part(client, parts)]
  return [t_part(client, entry) for entry in parts]
|
196
|
+
|
197
|
+
|
198
|
+
def t_image_predictions(
    client: _api_client.ApiClient,
    predictions: Optional[Iterable[Mapping[str, Any]]],
) -> Optional[list[types.GeneratedImage]]:
  """Converts raw prediction mappings into `types.GeneratedImage` objects.

  Args:
    client: The API client (unused, kept for transformer signature parity).
    predictions: Raw prediction mappings from the API response.

  Returns:
    A list of generated images, or None when `predictions` is empty/None.
    Predictions without an 'image' entry are skipped.
  """
  if not predictions:
    return None
  images = []
  for prediction in predictions:
    image = prediction.get('image')
    if image:
      images.append(
          types.GeneratedImage(
              image=types.Image(
                  # A prediction may carry a GCS URI, inline bytes, or
                  # both; use .get() so a missing field yields None
                  # instead of raising KeyError.
                  gcs_uri=image.get('gcsUri'),
                  image_bytes=image.get('imageBytes'),
              )
          )
      )
  return images
|
216
|
+
|
217
|
+
|
218
|
+
# Accepted forms for a single content: an SDK Content, its dict form, or
# any single part-like value (wrapped into a user-role Content).
ContentType = Union[types.Content, types.ContentDict, PartType]
|
219
|
+
|
220
|
+
|
221
|
+
def t_content(
    client: _api_client.ApiClient,
    content: ContentType,
):
  """Coerces a content-like value into a `types.Content`.

  Dicts are validated into Content; any bare part-like value (str, Part,
  PIL image, ...) becomes a single-part user turn.

  Raises:
    ValueError: If `content` is falsy.
  """
  if not content:
    raise ValueError('content is required.')
  if isinstance(content, types.Content):
    return content
  if isinstance(content, dict):
    return types.Content.model_validate(content)
  return types.Content(role='user', parts=t_parts(client, content))
|
232
|
+
|
233
|
+
|
234
|
+
def t_contents_for_embed(
    client: _api_client.ApiClient,
    contents: Union[list[types.Content], list[types.ContentDict], ContentType],
):
  """Normalizes contents for the embedding API.

  On Vertex AI the embedding endpoint takes bare strings, so the text of
  each content's first part is extracted; on the Gemini API full Content
  objects are returned. A non-list input is treated as a single item.
  """
  items = contents if isinstance(contents, list) else [contents]
  if client.vertexai:
    # TODO: Assert that only text is supported.
    return [t_content(client, item).parts[0].text for item in items]
  return [t_content(client, item) for item in items]
|
247
|
+
|
248
|
+
|
249
|
+
def t_contents(
    client: _api_client.ApiClient,
    contents: Union[list[types.Content], list[types.ContentDict], ContentType],
):
  """Coerces a content or list of contents into a list of `types.Content`.

  Raises:
    ValueError: If `contents` is falsy.
  """
  if not contents:
    raise ValueError('contents are required.')
  if not isinstance(contents, list):
    contents = [contents]
  return [t_content(client, item) for item in contents]
|
259
|
+
|
260
|
+
|
261
|
+
def process_schema(data: dict):
  """Recursively rewrites a JSON schema in place for the GenAI API.

  Drops every 'title' entry and upper-cases 'type' values (e.g. 'object'
  -> 'OBJECT'), descending into nested dicts and lists.

  Returns:
    The same (mutated) object, for call-chaining convenience.
  """
  if isinstance(data, list):
    for element in data:
      process_schema(element)
  elif isinstance(data, dict):
    # Snapshot the keys so entries can be deleted while iterating.
    for key in tuple(data):
      value = data[key]
      if key == 'title':
        del data[key]
      elif key == 'type':
        data[key] = value.upper()
      else:
        process_schema(value)

  return data
|
276
|
+
|
277
|
+
|
278
|
+
def t_schema(
    _: _api_client.ApiClient, origin: Union[types.SchemaDict, Any]
) -> Optional[types.Schema]:
  """Coerces a schema-like value into a schema the API understands.

  Dicts pass through untouched; anything else is assumed to be a pydantic
  model class exposing `model_json_schema()`, whose output is normalized
  via process_schema and validated into a `types.Schema`. Falsy input
  yields None.
  """
  if not origin:
    return None
  if isinstance(origin, dict):
    return origin
  raw_schema = process_schema(origin.model_json_schema())
  return types.Schema.model_validate(raw_schema)
|
287
|
+
|
288
|
+
|
289
|
+
def t_speech_config(
    _: _api_client.ApiClient, origin: Union[types.SpeechConfigUnionDict, Any]
) -> Optional[types.SpeechConfig]:
  """Coerces a speech-config-like value into a `types.SpeechConfig`.

  A plain string is interpreted as a prebuilt voice name. Falsy input
  yields None.

  Raises:
    ValueError: If `origin` is neither a SpeechConfig nor a string.
  """
  if not origin:
    return None
  if isinstance(origin, types.SpeechConfig):
    return origin
  if isinstance(origin, str):
    voice = types.PrebuiltVoiceConfig(voice_name=origin)
    return types.SpeechConfig(
        voice_config=types.VoiceConfig(prebuilt_voice_config=voice)
    )
  raise ValueError(f'Unsupported speechConfig type: {type(origin)}')
|
303
|
+
|
304
|
+
|
305
|
+
def t_tool(client: _api_client.ApiClient, origin) -> types.Tool:
  """Wraps a plain Python function in a `types.Tool`.

  Non-function values (already-built Tool objects, dicts, ...) pass
  through unchanged; falsy input yields None.
  """
  if not origin:
    return None
  if not inspect.isfunction(origin):
    return origin
  declaration = function_to_declaration(client, origin)
  return types.Tool(function_declarations=[declaration])
|
314
|
+
|
315
|
+
|
316
|
+
# Only support functions now.
def t_tools(
    client: _api_client.ApiClient, origin: list[Any]
) -> list[types.Tool]:
  """Transforms a heterogeneous tool list for the API.

  Every function-declaration-bearing tool is merged into one combined
  Tool (appended last); all other tools keep their order.
  """
  if not origin:
    return []

  merged_functions = types.Tool(function_declarations=[])
  result = []
  for raw_tool in origin:
    tool = t_tool(client, raw_tool)
    if tool.function_declarations:
      # All functions should be merged into one tool.
      merged_functions.function_declarations += tool.function_declarations
    else:
      result.append(tool)
  if merged_functions.function_declarations:
    result.append(merged_functions)
  return result
|
336
|
+
|
337
|
+
|
338
|
+
def t_cached_content_name(client: _api_client.ApiClient, name: str):
  """Completes a cached content name with project/location prefixes as needed."""
  return _resource_name(client, name, collection_identifier='cachedContents')
|
340
|
+
|
341
|
+
|
342
|
+
def t_batch_job_source(client: _api_client.ApiClient, src: str):
  """Builds a `types.BatchJobSource` from a GCS or BigQuery URI.

  `gs://` URIs become JSONL sources (the API takes a list of URIs);
  `bq://` URIs become BigQuery sources.

  Raises:
    ValueError: For any other URI scheme.
  """
  if src.startswith('gs://'):
    return types.BatchJobSource(format='jsonl', gcs_uri=[src])
  if src.startswith('bq://'):
    return types.BatchJobSource(format='bigquery', bigquery_uri=src)
  raise ValueError(f'Unsupported source: {src}')
|
355
|
+
|
356
|
+
|
357
|
+
def t_batch_job_destination(client: _api_client.ApiClient, dest: str):
  """Builds a `types.BatchJobDestination` from a GCS or BigQuery URI.

  `gs://` URIs become JSONL destinations; `bq://` URIs become BigQuery
  destinations.

  Raises:
    ValueError: For any other URI scheme.
  """
  if dest.startswith('gs://'):
    return types.BatchJobDestination(format='jsonl', gcs_uri=dest)
  if dest.startswith('bq://'):
    return types.BatchJobDestination(format='bigquery', bigquery_uri=dest)
  raise ValueError(f'Unsupported destination: {dest}')
|
370
|
+
|
371
|
+
|
372
|
+
def t_batch_job_name(client: _api_client.ApiClient, name: str):
  """Extracts the short batch job ID expected by the Vertex URL path.

  On the Gemini API the name passes through unchanged. On Vertex AI a
  full `projects/.../batchPredictionJobs/<id>` resource name is reduced
  to its trailing ID, and an all-digit ID is accepted as-is.

  Raises:
    ValueError: If a Vertex name is neither form.
  """
  if not client.vertexai:
    return name

  full_resource = r'^projects/[^/]+/locations/[^/]+/batchPredictionJobs/[^/]+$'
  if re.match(full_resource, name):
    # Keep only the trailing job ID segment.
    return name.rsplit('/', 1)[-1]
  if name.isdigit():
    return name
  raise ValueError(f'Invalid batch job name: {name}.')
|
383
|
+
|
384
|
+
|
385
|
+
# Long-running-operation polling parameters used by t_resolve_operation.
LRO_POLLING_INITIAL_DELAY_SECONDS = 1.0  # first delay between polls
LRO_POLLING_MAXIMUM_DELAY_SECONDS = 20.0  # cap for the backoff delay
LRO_POLLING_TIMEOUT_SECONDS = 900.0  # give up after this much total waiting
LRO_POLLING_MULTIPLIER = 1.5  # exponential backoff factor
|
389
|
+
|
390
|
+
|
391
|
+
def t_resolve_operation(api_client: _api_client.ApiClient, struct: dict):
  """Resolves a long-running operation to its final response.

  If `struct` is not an operation (its 'name' lacks '/operations/'), it
  is returned unchanged. Otherwise the operation is polled with
  exponential backoff until it reports done.

  Raises:
    RuntimeError: If polling exceeds LRO_POLLING_TIMEOUT_SECONDS, or the
      finished operation carries an 'error'.
  """
  if (name := struct.get('name')) and '/operations/' in name:
    operation: dict[str, Any] = struct
    total_seconds = 0.0
    delay_seconds = LRO_POLLING_INITIAL_DELAY_SECONDS
    while not operation.get('done'):
      if total_seconds > LRO_POLLING_TIMEOUT_SECONDS:
        raise RuntimeError(f'Operation {name} timed out.\n{operation}')
      # TODO(b/374433890): Replace with LRO module once it's available.
      operation = api_client.request(
          http_method='GET', path=name, request_dict={}
      )
      time.sleep(delay_seconds)
      # Bug fix: accumulate the delay actually slept. The original added
      # `total_seconds` to itself, which keeps it at 0.0 forever and
      # makes the timeout check above unreachable.
      total_seconds += delay_seconds
      # Exponential backoff, capped at the maximum delay.
      delay_seconds = min(
          delay_seconds * LRO_POLLING_MULTIPLIER,
          LRO_POLLING_MAXIMUM_DELAY_SECONDS,
      )
    if error := operation.get('error'):
      raise RuntimeError(
          f'Operation {name} failed with error: {error}.\n{operation}'
      )
    return operation.get('response')
  else:
    return struct
|
417
|
+
|
418
|
+
|
419
|
+
def t_file_name(api_client: _api_client.ApiClient, name: str):
  """Strips the leading 'files/' from a file resource name.

  The prefix is added back by the URL path, so only the bare ID is
  needed; names without the prefix are returned unchanged.
  """
  # Bug fix: str.removeprefix strips only a true leading prefix. The
  # original name.split('files/')[1] mangles names in which 'files/'
  # occurs again later (e.g. 'files/abcfiles/d' -> 'abc').
  return name.removeprefix('files/')
|
424
|
+
|
425
|
+
|
426
|
+
def t_tuning_job_status(
    api_client: _api_client.ApiClient, status: str
) -> types.JobState:
  """Maps a tuning-job status string onto the shared JobState vocabulary.

  Statuses with no mapping are passed through unchanged.
  """
  status_map = {
      'STATE_UNSPECIFIED': 'JOB_STATE_UNSPECIFIED',
      'CREATING': 'JOB_STATE_RUNNING',
      'ACTIVE': 'JOB_STATE_SUCCEEDED',
      'FAILED': 'JOB_STATE_FAILED',
  }
  return status_map.get(status, status)
|