google-genai 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,438 @@
1
+ # Copyright 2024 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+
16
+ """Transformers for Google GenAI SDK."""
17
+
18
+ import base64
19
+ from collections.abc import Iterable, Mapping
20
+ import inspect
21
+ import io
22
+ import re
23
+ import time
24
+ from typing import Any, Optional, Union
25
+
26
+ import PIL.Image
27
+
28
+ from . import _api_client
29
+ from . import types
30
+ from ._automatic_function_calling_util import function_to_declaration
31
+
32
+
33
+ def _resource_name(
34
+ client: _api_client.ApiClient,
35
+ resource_name: str,
36
+ *,
37
+ collection_identifier: str,
38
+ collection_hirearchy_depth: int = 2,
39
+ ):
40
+ # pylint: disable=line-too-long
41
+ """Prepends resource name with project, location, collection_identifier if needed.
42
+
43
+ The collection_identifier will only be prepended if it's not present
44
+ and the prepending won't violate the collection hierarchy depth.
45
+ When the prepending condition doesn't meet, returns the input
46
+ resource_name.
47
+
48
+ Args:
49
+ client: The API client.
50
+ resource_name: The user input resource name to be completed.
51
+ collection_identifier: The collection identifier to be prepended.
52
+ See collection identifiers in https://google.aip.dev/122.
53
+ collection_hirearchy_depth: The collection hierarchy depth.
54
+ Only set this field when the resource has nested collections.
55
+ For example, `users/vhugo1802/events/birthday-dinner-226`, the
56
+ collection_identifier is `users` and collection_hirearchy_depth is 4.
57
+ See nested collections in https://google.aip.dev/122.
58
+
59
+ Example:
60
+
61
+ resource_name = 'cachedContents/123'
62
+ client.vertexai = True
63
+ client.project = 'bar'
64
+ client.location = 'us-west1'
65
+ _resource_name(client, 'cachedContents/123', collection_identifier='cachedContents')
66
+ returns: 'projects/bar/locations/us-west1/cachedContents/123'
67
+
68
+ Example:
69
+
70
+ resource_name = 'projects/foo/locations/us-central1/cachedContents/123'
71
+ # resource_name = 'locations/us-central1/cachedContents/123'
72
+ client.vertexai = True
73
+ client.project = 'bar'
74
+ client.location = 'us-west1'
75
+ _resource_name(client, resource_name, collection_identifier='cachedContents')
76
+ returns: 'projects/foo/locations/us-central1/cachedContents/123'
77
+
78
+ Example:
79
+
80
+ resource_name = '123'
81
+ # resource_name = 'cachedContents/123'
82
+ client.vertexai = False
83
+ _resource_name(client, resource_name, collection_identifier='cachedContents')
84
+ returns 'cachedContents/123'
85
+
86
+ Example:
87
+ resource_name = 'some/wrong/cachedContents/resource/name/123'
88
+ resource_prefix = 'cachedContents'
89
+ client.vertexai = False
90
+ # client.vertexai = True
91
+ _resource_name(client, resource_name, collection_identifier='cachedContents')
92
+ returns: 'some/wrong/cachedContents/resource/name/123'
93
+
94
+ Returns:
95
+ The completed resource name.
96
+ """
97
+ should_prepend_collection_identifier = (
98
+ not resource_name.startswith(f'{collection_identifier}/')
99
+ # Check if prepending the collection identifier won't violate the
100
+ # collection hierarchy depth.
101
+ and f'{collection_identifier}/{resource_name}'.count('/') + 1
102
+ == collection_hirearchy_depth
103
+ )
104
+ if client.vertexai:
105
+ if resource_name.startswith('projects/'):
106
+ return resource_name
107
+ elif resource_name.startswith('locations/'):
108
+ return f'projects/{client.project}/{resource_name}'
109
+ elif resource_name.startswith(f'{collection_identifier}/'):
110
+ return f'projects/{client.project}/locations/{client.location}/{resource_name}'
111
+ elif should_prepend_collection_identifier:
112
+ return f'projects/{client.project}/locations/{client.location}/{collection_identifier}/{resource_name}'
113
+ else:
114
+ return resource_name
115
+ else:
116
+ if should_prepend_collection_identifier:
117
+ return f'{collection_identifier}/{resource_name}'
118
+ else:
119
+ return resource_name
120
+
121
+
122
def t_model(client: _api_client.ApiClient, model: str):
  """Expands a model identifier into the resource name the backend expects.

  Vertex AI clients: names starting with `projects/`, `models/` or
  `publishers/` are returned unchanged; `<publisher>/<model>` becomes
  `publishers/<publisher>/models/<model>`; a bare name is placed under the
  `google` publisher.

  Other clients: names starting with `models/` or `tunedModels/` are returned
  unchanged; anything else is prefixed with `models/`.

  Args:
    client: The API client; only `vertexai` is read.
    model: The user supplied model name.

  Returns:
    The expanded model resource name.

  Raises:
    ValueError: If `model` is empty or None.
  """
  if not model:
    raise ValueError('model is required.')

  if not client.vertexai:
    if model.startswith(('models/', 'tunedModels/')):
      return model
    return f'models/{model}'

  if model.startswith(('projects/', 'models/', 'publishers/')):
    return model
  # Only the first slash separates the publisher from the model id.
  publisher, separator, model_id = model.partition('/')
  if separator:
    return f'publishers/{publisher}/models/{model_id}'
  return f'publishers/google/models/{model}'
144
+
145
def t_caches_model(api_client: _api_client.ApiClient, model: str):
  """Expands a model name into the form used by the caches API.

  The name is first expanded by `t_model`. For Vertex AI clients, names that
  start with `publishers/` or `models/` are then qualified with the project
  and location of the client; other names, and all names for non-Vertex
  clients, are returned as produced by `t_model`.

  Args:
    api_client: The API client; `vertexai`, `project` and `location` are read.
    model: The user supplied model name.

  Returns:
    The expanded model name, or None if the expanded name is empty.
  """
  model = t_model(api_client, model)
  if not model:
    return None
  if not api_client.vertexai:
    return model

  if model.startswith('publishers/'):
    # vertex caches only support model name start with projects.
    return (
        f'projects/{api_client.project}/locations/{api_client.location}/{model}'
    )
  if model.startswith('models/'):
    return f'projects/{api_client.project}/locations/{api_client.location}/publishers/google/{model}'
  return model
158
+
159
+
160
def pil_to_blob(img):
  """Encodes a PIL image into a `types.Blob` holding the image bytes.

  Images loaded from PNG files, or whose mode is `RGBA`, are encoded as PNG;
  every other image is encoded as JPEG.

  Args:
    img: The PIL image to encode.

  Returns:
    A `types.Blob` with `mime_type` set to `image/png` or `image/jpeg` and
    `data` set to the encoded bytes.
  """
  # Importing `PIL.Image` does not by itself load the PNG plugin submodule
  # (Pillow loads format plugins lazily), so referring to
  # `PIL.PngImagePlugin` without importing it can raise AttributeError.
  from PIL import PngImagePlugin  # pylint: disable=g-import-not-at-top

  if isinstance(img, PngImagePlugin.PngImageFile) or img.mode == 'RGBA':
    image_format, mime_type = 'PNG', 'image/png'
  else:
    image_format, mime_type = 'JPEG', 'image/jpeg'

  buffer = io.BytesIO()
  img.save(buffer, format=image_format)
  return types.Blob(mime_type=mime_type, data=buffer.getvalue())
171
+
172
+
173
# Every value accepted wherever a single content part is expected.
PartType = Union[types.Part, types.PartDict, str, PIL.Image.Image]


def t_part(client: _api_client.ApiClient, part: PartType) -> types.Part:
  """Converts a single user supplied part into a part value.

  Strings become text parts and PIL images become inline-data parts; any
  other truthy value is returned unchanged.

  Args:
    client: The API client (not used by this transformer).
    part: The part to convert.

  Returns:
    The converted part.

  Raises:
    ValueError: If `part` is falsy (for example None or an empty string).
  """
  if not part:
    raise ValueError('content part is required.')
  if isinstance(part, str):
    return types.Part(text=part)
  if isinstance(part, PIL.Image.Image):
    image_blob = pil_to_blob(part)
    return types.Part(inline_data=image_blob)
  return part
185
+
186
+
187
def t_parts(
    client: _api_client.ApiClient, parts: Union[list, PartType]
) -> list[types.Part]:
  """Converts one part, or a list of parts, into a list of parts.

  Args:
    client: The API client, forwarded to `t_part`.
    parts: A single part or a list of parts.

  Returns:
    A list with every input part converted by `t_part`.

  Raises:
    ValueError: If `parts` is None.
  """
  if parts is None:
    raise ValueError('content parts are required.')
  # A lone part is handled as a one-element list.
  part_items = parts if isinstance(parts, list) else [parts]
  return [t_part(client, item) for item in part_items]
196
+
197
+
198
def t_image_predictions(
    client: _api_client.ApiClient,
    predictions: Optional[Iterable[Mapping[str, Any]]],
) -> Optional[list[types.GeneratedImage]]:
  """Converts raw image predictions into `types.GeneratedImage` objects.

  Predictions without a truthy `image` entry are skipped.

  Args:
    client: The API client (not used by this transformer).
    predictions: The raw predictions; each may hold an `image` mapping with
      `gcsUri` and/or `imageBytes` keys.

  Returns:
    None when `predictions` is empty or None, otherwise the list of generated
    images (possibly empty).
  """
  if not predictions:
    return None
  images = []
  for prediction in predictions:
    image = prediction.get('image')
    if not image:
      continue
    # Use .get() so that an image carrying only one of the two keys does not
    # raise KeyError.
    # NOTE(review): assumes `types.Image` accepts None for both fields —
    # confirm against the type definition.
    images.append(
        types.GeneratedImage(
            image=types.Image(
                gcs_uri=image.get('gcsUri'),
                image_bytes=image.get('imageBytes'),
            )
        )
    )
  return images
216
+
217
+
218
# Every value accepted wherever a single content is expected.
ContentType = Union[types.Content, types.ContentDict, PartType]


def t_content(
    client: _api_client.ApiClient,
    content: ContentType,
):
  """Converts a user supplied content value into a `types.Content`.

  `types.Content` instances are returned unchanged, dicts are validated into
  a `types.Content`, and anything else is treated as part data and wrapped
  in a content with role `user`.

  Args:
    client: The API client, forwarded to `t_parts`.
    content: The content to convert.

  Returns:
    The resulting content.

  Raises:
    ValueError: If `content` is falsy.
  """
  if not content:
    raise ValueError('content is required.')
  if isinstance(content, types.Content):
    return content
  if isinstance(content, dict):
    return types.Content.model_validate(content)
  user_parts = t_parts(client, content)
  return types.Content(role='user', parts=user_parts)
232
+
233
+
234
def t_contents_for_embed(
    client: _api_client.ApiClient,
    contents: Union[list[types.Content], list[types.ContentDict], ContentType],
):
  """Converts embedding input into the list form the backend expects.

  A single content is handled as a one-element list. For Vertex AI clients
  the result is the text of the first part of each content; for other
  clients it is the list of converted contents.

  Args:
    client: The API client; `vertexai` is read.
    contents: A single content or a list of contents.

  Returns:
    A list of strings (Vertex AI) or a list of contents (otherwise).
  """
  content_items = contents if isinstance(contents, list) else [contents]
  if client.vertexai:
    # TODO: Assert that only text is supported.
    return [t_content(client, item).parts[0].text for item in content_items]
  return [t_content(client, item) for item in content_items]
247
+
248
+
249
def t_contents(
    client: _api_client.ApiClient,
    contents: Union[list[types.Content], list[types.ContentDict], ContentType],
):
  """Converts one content, or a list of contents, into a list of contents.

  Args:
    client: The API client, forwarded to `t_content`.
    contents: A single content or a list of contents.

  Returns:
    A list with every input converted by `t_content`.

  Raises:
    ValueError: If `contents` is falsy (for example None or an empty list).
  """
  if not contents:
    raise ValueError('contents are required.')
  # A lone content is handled as a one-element list.
  content_items = contents if isinstance(contents, list) else [contents]
  return [t_content(client, item) for item in content_items]
259
+
260
+
261
def process_schema(data: dict):
  """Normalizes a JSON schema in place: drops titles, upper-cases types.

  The structure is walked recursively through dicts and lists. A `title` or
  `type` key is only treated as a schema keyword when its value is a string;
  when the value is a dict or list (for example a property that happens to be
  named `title` or `type` inside `properties`) the entry is kept and
  processed recursively instead of being deleted or upper-cased.

  Args:
    data: The schema (or schema fragment) to normalize. Mutated in place.

  Returns:
    The same `data` object that was passed in.
  """
  if isinstance(data, dict):
    # Iterate over a copy of keys to allow deletion
    for key in list(data):
      value = data[key]
      if key == 'title' and isinstance(value, str):
        del data[key]
      elif key == 'type' and isinstance(value, str):
        data[key] = value.upper()
      else:
        process_schema(value)
  elif isinstance(data, list):
    for item in data:
      process_schema(item)

  return data
276
+
277
+
278
def t_schema(
    _: _api_client.ApiClient, origin: Union[types.SchemaDict, Any]
) -> Optional[types.Schema]:
  """Converts a user supplied schema description into a schema value.

  Dicts are returned unchanged. Any other truthy value must provide
  `model_json_schema()`; its output is normalized by `process_schema` and
  validated into a `types.Schema`.

  Args:
    _: The API client (unused).
    origin: A schema dict or an object exposing `model_json_schema()`.

  Returns:
    None for falsy input, the input itself for dicts, otherwise a
    `types.Schema`.
  """
  if not origin:
    return None
  if isinstance(origin, dict):
    return origin
  normalized_schema = process_schema(origin.model_json_schema())
  return types.Schema.model_validate(normalized_schema)
287
+
288
+
289
def t_speech_config(
    _: _api_client.ApiClient, origin: Union[types.SpeechConfigUnionDict, Any]
) -> Optional[types.SpeechConfig]:
  """Converts a speech configuration shorthand into a `types.SpeechConfig`.

  A `types.SpeechConfig` is returned unchanged; a string is used as the
  voice name of a prebuilt voice configuration.

  Args:
    _: The API client (unused).
    origin: A `types.SpeechConfig`, a voice name string, or a falsy value.

  Returns:
    None for falsy input, otherwise a `types.SpeechConfig`.

  Raises:
    ValueError: If `origin` is truthy but neither a `types.SpeechConfig` nor
      a string.
  """
  if not origin:
    return None
  if isinstance(origin, types.SpeechConfig):
    return origin
  if not isinstance(origin, str):
    raise ValueError(f'Unsupported speechConfig type: {type(origin)}')
  prebuilt_voice = types.PrebuiltVoiceConfig(voice_name=origin)
  voice = types.VoiceConfig(prebuilt_voice_config=prebuilt_voice)
  return types.SpeechConfig(voice_config=voice)
303
+
304
+
305
def t_tool(client: _api_client.ApiClient, origin) -> types.Tool:
  """Converts a tool argument into a tool value.

  Plain Python functions are wrapped in a `types.Tool` holding a single
  function declaration built by `function_to_declaration`; any other truthy
  value is returned unchanged.

  Args:
    client: The API client, forwarded to `function_to_declaration`.
    origin: A Python function or an already constructed tool.

  Returns:
    None for falsy input, otherwise the tool.
  """
  if not origin:
    return None
  if not inspect.isfunction(origin):
    return origin
  declaration = function_to_declaration(client, origin)
  return types.Tool(function_declarations=[declaration])
314
+
315
+
316
# Only support functions now.
def t_tools(
    client: _api_client.ApiClient, origin: list[Any]
) -> list[types.Tool]:
  """Converts a list of tool arguments into a list of tools.

  Every item is converted by `t_tool`. The function declarations of all
  converted tools are merged into one tool, which is appended after the
  tools that carry no function declarations.

  Args:
    client: The API client, forwarded to `t_tool`.
    origin: The tool arguments to convert.

  Returns:
    The converted tools; an empty list when `origin` is falsy.
  """
  if not origin:
    return []
  merged_function_tool = types.Tool(function_declarations=[])
  converted_tools = []
  for item in origin:
    tool = t_tool(client, item)
    if not tool.function_declarations:
      converted_tools.append(tool)
      continue
    # All functions should be merged into one tool.
    merged_function_tool.function_declarations += tool.function_declarations
  if merged_function_tool.function_declarations:
    converted_tools.append(merged_function_tool)
  return converted_tools
336
+
337
+
338
def t_cached_content_name(client: _api_client.ApiClient, name: str):
  """Completes a cached content name using the `cachedContents` collection.

  Args:
    client: The API client, forwarded to `_resource_name`.
    name: The user supplied cached content name.

  Returns:
    The name as completed by `_resource_name`.
  """
  collection = 'cachedContents'
  return _resource_name(client, name, collection_identifier=collection)
340
+
341
+
342
def t_batch_job_source(client: _api_client.ApiClient, src: str):
  """Builds a `types.BatchJobSource` from a source URI.

  Args:
    client: The API client (not used by this transformer).
    src: A `gs://` URI (jsonl source) or a `bq://` URI (bigquery source).

  Returns:
    The batch job source describing `src`.

  Raises:
    ValueError: If `src` starts with neither `gs://` nor `bq://`.
  """
  if src.startswith('gs://'):
    return types.BatchJobSource(format='jsonl', gcs_uri=[src])
  if src.startswith('bq://'):
    return types.BatchJobSource(format='bigquery', bigquery_uri=src)
  raise ValueError(f'Unsupported source: {src}')
355
+
356
+
357
def t_batch_job_destination(client: _api_client.ApiClient, dest: str):
  """Builds a `types.BatchJobDestination` from a destination URI.

  Args:
    client: The API client (not used by this transformer).
    dest: A `gs://` URI (jsonl output) or a `bq://` URI (bigquery output).

  Returns:
    The batch job destination describing `dest`.

  Raises:
    ValueError: If `dest` starts with neither `gs://` nor `bq://`.
  """
  if dest.startswith('gs://'):
    return types.BatchJobDestination(format='jsonl', gcs_uri=dest)
  if dest.startswith('bq://'):
    return types.BatchJobDestination(format='bigquery', bigquery_uri=dest)
  raise ValueError(f'Unsupported destination: {dest}')
370
+
371
+
372
def t_batch_job_name(client: _api_client.ApiClient, name: str):
  """Reduces a batch job name to the form used in request paths.

  For non-Vertex clients the name is returned unchanged. For Vertex AI
  clients a full `projects/*/locations/*/batchPredictionJobs/*` name is
  reduced to its last path segment, and an all-digit name is returned as is.

  Args:
    client: The API client; only `vertexai` is read.
    name: The user supplied batch job name.

  Returns:
    The batch job name or id.

  Raises:
    ValueError: If the client targets Vertex AI and `name` is neither a full
      batch prediction job name nor all digits.
  """
  if not client.vertexai:
    return name

  full_name_pattern = (
      r'^projects/[^/]+/locations/[^/]+/batchPredictionJobs/[^/]+$'
  )
  if re.match(full_name_pattern, name):
    # Keep only the trailing job id.
    return name.rpartition('/')[2]
  if name.isdigit():
    return name
  raise ValueError(f'Invalid batch job name: {name}.')
383
+
384
+
385
# Polling schedule for long-running operations: the delay between polls
# starts at the initial value, is multiplied after every poll and is capped
# at the maximum; polling is abandoned once the accumulated delay exceeds
# the timeout.
LRO_POLLING_INITIAL_DELAY_SECONDS = 1.0
LRO_POLLING_MAXIMUM_DELAY_SECONDS = 20.0
LRO_POLLING_TIMEOUT_SECONDS = 900.0
LRO_POLLING_MULTIPLIER = 1.5


def t_resolve_operation(api_client: _api_client.ApiClient, struct: dict):
  """Waits for a long-running operation and returns its response.

  `struct` is treated as an operation when its `name` contains
  `/operations/`; anything else is returned unchanged. An operation is
  polled with GET requests on its name until `done` is True.

  Args:
    api_client: The API client used to poll the operation.
    struct: The response payload, possibly describing an operation.

  Returns:
    `struct` itself when it is not an operation, otherwise the `response`
    entry of the finished operation.

  Raises:
    RuntimeError: If the accumulated polling delay exceeds
      `LRO_POLLING_TIMEOUT_SECONDS`, or if the finished operation carries an
      `error` entry.
  """
  name = struct.get('name')
  if not name or '/operations/' not in name:
    return struct

  operation: dict[str, Any] = struct
  total_seconds = 0.0
  delay_seconds = LRO_POLLING_INITIAL_DELAY_SECONDS
  while operation.get('done') is not True:
    if total_seconds > LRO_POLLING_TIMEOUT_SECONDS:
      raise RuntimeError(f'Operation {name} timed out.\n{operation}')
    # TODO(b/374433890): Replace with LRO module once it's available.
    operation = api_client.request(
        http_method='GET', path=name, request_dict={}
    )
    if operation.get('done') is True:
      # Finished: do not sleep again before returning the result.
      break
    time.sleep(delay_seconds)
    # Accumulate the time actually waited; adding total_seconds to itself
    # would leave it at zero and the timeout could never trigger.
    total_seconds += delay_seconds
    # Exponential backoff
    delay_seconds = min(
        delay_seconds * LRO_POLLING_MULTIPLIER,
        LRO_POLLING_MAXIMUM_DELAY_SECONDS,
    )
  if error := operation.get('error'):
    raise RuntimeError(
        f'Operation {name} failed with error: {error}.\n{operation}'
    )
  return operation.get('response')
417
+
418
+
419
def t_file_name(api_client: _api_client.ApiClient, name: str):
  """Strips a leading `files/` from a file name.

  Only the leading prefix is removed, so a name that contains `files/` again
  later on keeps the rest of its text intact.

  Args:
    api_client: The API client (not used by this transformer).
    name: The file name, with or without the `files/` prefix.

  Returns:
    The file name without the leading `files/`.
  """
  # Remove the files/ prefix since it's added to the url path.
  return name.removeprefix('files/')
424
+
425
+
426
def t_tuning_job_status(
    api_client: _api_client.ApiClient, status: str
) -> types.JobState:
  """Maps a tuning status string onto the matching job state string.

  Args:
    api_client: The API client (not used by this transformer).
    status: The status to translate.

  Returns:
    The corresponding `JOB_STATE_*` string, or `status` unchanged when it is
    not one of the translated values.
  """
  status_to_job_state = (
      ('STATE_UNSPECIFIED', 'JOB_STATE_UNSPECIFIED'),
      ('CREATING', 'JOB_STATE_RUNNING'),
      ('ACTIVE', 'JOB_STATE_SUCCEEDED'),
      ('FAILED', 'JOB_STATE_FAILED'),
  )
  for known_status, job_state in status_to_job_state:
    if status == known_status:
      return job_state
  return status