google-genai 0.5.0__py3-none-any.whl → 0.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- google/genai/_api_client.py +234 -131
- google/genai/_api_module.py +24 -0
- google/genai/_automatic_function_calling_util.py +43 -22
- google/genai/_common.py +37 -12
- google/genai/_extra_utils.py +25 -19
- google/genai/_replay_api_client.py +47 -35
- google/genai/_test_api_client.py +1 -1
- google/genai/_transformers.py +301 -51
- google/genai/batches.py +204 -165
- google/genai/caches.py +127 -144
- google/genai/chats.py +22 -18
- google/genai/client.py +32 -37
- google/genai/errors.py +1 -1
- google/genai/files.py +333 -165
- google/genai/live.py +16 -6
- google/genai/models.py +601 -283
- google/genai/tunings.py +91 -428
- google/genai/types.py +1190 -955
- google/genai/version.py +1 -1
- google_genai-0.7.0.dist-info/METADATA +1021 -0
- google_genai-0.7.0.dist-info/RECORD +26 -0
- google_genai-0.5.0.dist-info/METADATA +0 -888
- google_genai-0.5.0.dist-info/RECORD +0 -25
- {google_genai-0.5.0.dist-info → google_genai-0.7.0.dist-info}/LICENSE +0 -0
- {google_genai-0.5.0.dist-info → google_genai-0.7.0.dist-info}/WHEEL +0 -0
- {google_genai-0.5.0.dist-info → google_genai-0.7.0.dist-info}/top_level.txt +0 -0
google/genai/_transformers.py
CHANGED
@@ -17,25 +17,35 @@
|
|
17
17
|
|
18
18
|
import base64
|
19
19
|
from collections.abc import Iterable, Mapping
|
20
|
+
from enum import Enum, EnumMeta
|
20
21
|
import inspect
|
21
22
|
import io
|
22
23
|
import re
|
23
24
|
import time
|
24
|
-
|
25
|
+
import typing
|
26
|
+
from typing import Any, GenericAlias, Optional, Union
|
27
|
+
import sys
|
25
28
|
|
26
|
-
|
27
|
-
import PIL.
|
29
|
+
if typing.TYPE_CHECKING:
|
30
|
+
import PIL.Image
|
31
|
+
|
32
|
+
import pydantic
|
28
33
|
|
29
34
|
from . import _api_client
|
30
35
|
from . import types
|
31
36
|
|
37
|
+
if sys.version_info >= (3, 11):
|
38
|
+
from types import UnionType
|
39
|
+
else:
|
40
|
+
UnionType = typing._UnionGenericAlias
|
41
|
+
|
32
42
|
|
33
43
|
def _resource_name(
|
34
44
|
client: _api_client.ApiClient,
|
35
45
|
resource_name: str,
|
36
46
|
*,
|
37
47
|
collection_identifier: str,
|
38
|
-
|
48
|
+
collection_hierarchy_depth: int = 2,
|
39
49
|
):
|
40
50
|
# pylint: disable=line-too-long
|
41
51
|
"""Prepends resource name with project, location, collection_identifier if needed.
|
@@ -48,13 +58,13 @@ def _resource_name(
|
|
48
58
|
Args:
|
49
59
|
client: The API client.
|
50
60
|
resource_name: The user input resource name to be completed.
|
51
|
-
collection_identifier: The collection identifier to be prepended.
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
61
|
+
collection_identifier: The collection identifier to be prepended. See
|
62
|
+
collection identifiers in https://google.aip.dev/122.
|
63
|
+
collection_hierarchy_depth: The collection hierarchy depth. Only set this
|
64
|
+
field when the resource has nested collections. For example,
|
65
|
+
`users/vhugo1802/events/birthday-dinner-226`, the collection_identifier is
|
66
|
+
`users` and collection_hierarchy_depth is 4. See nested collections in
|
67
|
+
https://google.aip.dev/122.
|
58
68
|
|
59
69
|
Example:
|
60
70
|
|
@@ -62,7 +72,8 @@ def _resource_name(
|
|
62
72
|
client.vertexai = True
|
63
73
|
client.project = 'bar'
|
64
74
|
client.location = 'us-west1'
|
65
|
-
_resource_name(client, 'cachedContents/123',
|
75
|
+
_resource_name(client, 'cachedContents/123',
|
76
|
+
collection_identifier='cachedContents')
|
66
77
|
returns: 'projects/bar/locations/us-west1/cachedContents/123'
|
67
78
|
|
68
79
|
Example:
|
@@ -72,7 +83,8 @@ def _resource_name(
|
|
72
83
|
client.vertexai = True
|
73
84
|
client.project = 'bar'
|
74
85
|
client.location = 'us-west1'
|
75
|
-
_resource_name(client, resource_name,
|
86
|
+
_resource_name(client, resource_name,
|
87
|
+
collection_identifier='cachedContents')
|
76
88
|
returns: 'projects/foo/locations/us-central1/cachedContents/123'
|
77
89
|
|
78
90
|
Example:
|
@@ -80,7 +92,8 @@ def _resource_name(
|
|
80
92
|
resource_name = '123'
|
81
93
|
# resource_name = 'cachedContents/123'
|
82
94
|
client.vertexai = False
|
83
|
-
_resource_name(client, resource_name,
|
95
|
+
_resource_name(client, resource_name,
|
96
|
+
collection_identifier='cachedContents')
|
84
97
|
returns 'cachedContents/123'
|
85
98
|
|
86
99
|
Example:
|
@@ -88,7 +101,8 @@ def _resource_name(
|
|
88
101
|
resource_prefix = 'cachedContents'
|
89
102
|
client.vertexai = False
|
90
103
|
# client.vertexai = True
|
91
|
-
_resource_name(client, resource_name,
|
104
|
+
_resource_name(client, resource_name,
|
105
|
+
collection_identifier='cachedContents')
|
92
106
|
returns: 'some/wrong/cachedContents/resource/name/123'
|
93
107
|
|
94
108
|
Returns:
|
@@ -99,7 +113,7 @@ def _resource_name(
|
|
99
113
|
# Check if prepending the collection identifier won't violate the
|
100
114
|
# collection hierarchy depth.
|
101
115
|
and f'{collection_identifier}/{resource_name}'.count('/') + 1
|
102
|
-
==
|
116
|
+
== collection_hierarchy_depth
|
103
117
|
)
|
104
118
|
if client.vertexai:
|
105
119
|
if resource_name.startswith('projects/'):
|
@@ -142,6 +156,7 @@ def t_model(client: _api_client.ApiClient, model: str):
|
|
142
156
|
else:
|
143
157
|
return f'models/{model}'
|
144
158
|
|
159
|
+
|
145
160
|
def t_models_url(api_client: _api_client.ApiClient, base_models: bool) -> str:
|
146
161
|
if api_client.vertexai:
|
147
162
|
if base_models:
|
@@ -155,8 +170,12 @@ def t_models_url(api_client: _api_client.ApiClient, base_models: bool) -> str:
|
|
155
170
|
return 'tunedModels'
|
156
171
|
|
157
172
|
|
158
|
-
def t_extract_models(
|
159
|
-
|
173
|
+
def t_extract_models(
|
174
|
+
api_client: _api_client.ApiClient, response: dict
|
175
|
+
) -> list[types.Model]:
|
176
|
+
if not response:
|
177
|
+
return []
|
178
|
+
elif response.get('models') is not None:
|
160
179
|
return response.get('models')
|
161
180
|
elif response.get('tunedModels') is not None:
|
162
181
|
return response.get('tunedModels')
|
@@ -181,9 +200,15 @@ def t_caches_model(api_client: _api_client.ApiClient, model: str):
|
|
181
200
|
return model
|
182
201
|
|
183
202
|
|
184
|
-
def pil_to_blob(img):
|
203
|
+
def pil_to_blob(img) -> types.Blob:
|
204
|
+
try:
|
205
|
+
import PIL.PngImagePlugin
|
206
|
+
PngImagePlugin = PIL.PngImagePlugin
|
207
|
+
except ImportError:
|
208
|
+
PngImagePlugin = None
|
209
|
+
|
185
210
|
bytesio = io.BytesIO()
|
186
|
-
if isinstance(img,
|
211
|
+
if PngImagePlugin is not None and isinstance(img, PngImagePlugin.PngImageFile) or img.mode == 'RGBA':
|
187
212
|
img.save(bytesio, format='PNG')
|
188
213
|
mime_type = 'image/png'
|
189
214
|
else:
|
@@ -194,16 +219,26 @@ def pil_to_blob(img):
|
|
194
219
|
return types.Blob(mime_type=mime_type, data=data)
|
195
220
|
|
196
221
|
|
197
|
-
PartType = Union[types.Part, types.PartDict, str, PIL.Image.Image]
|
222
|
+
PartType = Union[types.Part, types.PartDict, str, 'PIL.Image.Image']
|
198
223
|
|
199
224
|
|
200
225
|
def t_part(client: _api_client.ApiClient, part: PartType) -> types.Part:
|
226
|
+
try:
|
227
|
+
import PIL.Image
|
228
|
+
PIL_Image = PIL.Image.Image
|
229
|
+
except ImportError:
|
230
|
+
PIL_Image = None
|
231
|
+
|
201
232
|
if not part:
|
202
233
|
raise ValueError('content part is required.')
|
203
234
|
if isinstance(part, str):
|
204
235
|
return types.Part(text=part)
|
205
|
-
if isinstance(part,
|
236
|
+
if PIL_Image is not None and isinstance(part, PIL_Image):
|
206
237
|
return types.Part(inline_data=pil_to_blob(part))
|
238
|
+
if isinstance(part, types.File):
|
239
|
+
if not part.uri or not part.mime_type:
|
240
|
+
raise ValueError('file uri and mime_type are required.')
|
241
|
+
return types.Part.from_uri(file_uri=part.uri, mime_type=part.mime_type)
|
207
242
|
else:
|
208
243
|
return part
|
209
244
|
|
@@ -282,32 +317,234 @@ def t_contents(
|
|
282
317
|
return [t_content(client, contents)]
|
283
318
|
|
284
319
|
|
285
|
-
def
|
286
|
-
|
287
|
-
|
288
|
-
|
289
|
-
|
290
|
-
|
291
|
-
|
292
|
-
|
320
|
+
def handle_null_fields(schema: dict[str, Any]):
|
321
|
+
"""Process null fields in the schema so it is compatible with OpenAPI.
|
322
|
+
|
323
|
+
The OpenAPI spec does not support 'type: 'null' in the schema. This function
|
324
|
+
handles this case by adding 'nullable: True' to the null field and removing
|
325
|
+
the {'type': 'null'} entry.
|
326
|
+
|
327
|
+
https://swagger.io/docs/specification/v3_0/data-models/data-types/#null
|
328
|
+
|
329
|
+
Example of schema properties before and after handling null fields:
|
330
|
+
Before:
|
331
|
+
{
|
332
|
+
"name": {
|
333
|
+
"title": "Name",
|
334
|
+
"type": "string"
|
335
|
+
},
|
336
|
+
"total_area_sq_mi": {
|
337
|
+
"anyOf": [
|
338
|
+
{
|
339
|
+
"type": "integer"
|
340
|
+
},
|
341
|
+
{
|
342
|
+
"type": "null"
|
343
|
+
}
|
344
|
+
],
|
345
|
+
"default": null,
|
346
|
+
"title": "Total Area Sq Mi"
|
347
|
+
}
|
348
|
+
}
|
349
|
+
|
350
|
+
After:
|
351
|
+
{
|
352
|
+
"name": {
|
353
|
+
"title": "Name",
|
354
|
+
"type": "string"
|
355
|
+
},
|
356
|
+
"total_area_sq_mi": {
|
357
|
+
"type": "integer",
|
358
|
+
"nullable": true,
|
359
|
+
"default": null,
|
360
|
+
"title": "Total Area Sq Mi"
|
361
|
+
}
|
362
|
+
}
|
363
|
+
"""
|
364
|
+
if (
|
365
|
+
isinstance(schema, dict)
|
366
|
+
and 'type' in schema
|
367
|
+
and schema['type'] == 'null'
|
368
|
+
):
|
369
|
+
schema['nullable'] = True
|
370
|
+
del schema['type']
|
371
|
+
elif 'anyOf' in schema:
|
372
|
+
for item in schema['anyOf']:
|
373
|
+
if 'type' in item and item['type'] == 'null':
|
374
|
+
schema['nullable'] = True
|
375
|
+
schema['anyOf'].remove({'type': 'null'})
|
376
|
+
if len(schema['anyOf']) == 1:
|
377
|
+
# If there is only one type left after removing null, remove the anyOf field.
|
378
|
+
field_type = schema['anyOf'][0]['type']
|
379
|
+
schema['type'] = field_type
|
380
|
+
del schema['anyOf']
|
381
|
+
|
382
|
+
|
383
|
+
def process_schema(
|
384
|
+
schema: dict[str, Any],
|
385
|
+
client: Optional[_api_client.ApiClient] = None,
|
386
|
+
defs: Optional[dict[str, Any]]=None):
|
387
|
+
"""Updates the schema and each sub-schema inplace to be API-compatible.
|
388
|
+
|
389
|
+
- Removes the `title` field from the schema if the client is not vertexai.
|
390
|
+
- Inlines the $defs.
|
391
|
+
|
392
|
+
Example of a schema before and after (with mldev):
|
393
|
+
Before:
|
394
|
+
|
395
|
+
`schema`
|
396
|
+
|
397
|
+
{
|
398
|
+
'items': {
|
399
|
+
'$ref': '#/$defs/CountryInfo'
|
400
|
+
},
|
401
|
+
'title': 'Placeholder',
|
402
|
+
'type': 'array'
|
403
|
+
}
|
404
|
+
|
405
|
+
|
406
|
+
`defs`
|
407
|
+
|
408
|
+
{
|
409
|
+
'CountryInfo': {
|
410
|
+
'properties': {
|
411
|
+
'continent': {
|
412
|
+
'title': 'Continent',
|
413
|
+
'type': 'string'
|
414
|
+
},
|
415
|
+
'gdp': {
|
416
|
+
'title': 'Gdp',
|
417
|
+
'type': 'integer'}
|
418
|
+
},
|
419
|
+
}
|
420
|
+
'required':['continent', 'gdp'],
|
421
|
+
'title': 'CountryInfo',
|
422
|
+
'type': 'object'
|
423
|
+
}
|
424
|
+
}
|
425
|
+
|
426
|
+
After:
|
427
|
+
|
428
|
+
`schema`
|
429
|
+
{
|
430
|
+
'items': {
|
431
|
+
'properties': {
|
432
|
+
'continent': {
|
433
|
+
'type': 'string'
|
434
|
+
},
|
435
|
+
'gdp': {
|
436
|
+
'type': 'integer'}
|
437
|
+
},
|
438
|
+
}
|
439
|
+
'required':['continent', 'gdp'],
|
440
|
+
'type': 'object'
|
441
|
+
},
|
442
|
+
'type': 'array'
|
443
|
+
}
|
444
|
+
"""
|
445
|
+
if client and not client.vertexai:
|
446
|
+
schema.pop('title', None)
|
447
|
+
|
448
|
+
if defs is None:
|
449
|
+
defs = schema.pop('$defs', {})
|
450
|
+
for _, sub_schema in defs.items():
|
451
|
+
process_schema(sub_schema, client, defs)
|
452
|
+
|
453
|
+
handle_null_fields(schema)
|
454
|
+
|
455
|
+
any_of = schema.get('anyOf', None)
|
456
|
+
if any_of is not None:
|
457
|
+
for sub_schema in any_of:
|
458
|
+
process_schema(sub_schema, client, defs)
|
459
|
+
return
|
460
|
+
|
461
|
+
schema_type = schema.get('type', None)
|
462
|
+
if isinstance(schema_type, Enum):
|
463
|
+
schema_type = schema_type.value
|
464
|
+
schema_type = schema_type.upper()
|
465
|
+
|
466
|
+
if schema_type == 'OBJECT':
|
467
|
+
properties = schema.get('properties', None)
|
468
|
+
if properties is None:
|
469
|
+
return
|
470
|
+
for name, sub_schema in properties.items():
|
471
|
+
ref_key = sub_schema.get('$ref', None)
|
472
|
+
if ref_key is None:
|
473
|
+
process_schema(sub_schema, client, defs)
|
293
474
|
else:
|
294
|
-
|
295
|
-
|
296
|
-
|
297
|
-
|
475
|
+
ref = defs[ref_key.split('defs/')[-1]]
|
476
|
+
process_schema(ref, client, defs)
|
477
|
+
properties[name] = ref
|
478
|
+
elif schema_type == 'ARRAY':
|
479
|
+
sub_schema = schema.get('items', None)
|
480
|
+
if sub_schema is None:
|
481
|
+
return
|
482
|
+
ref_key = sub_schema.get('$ref', None)
|
483
|
+
if ref_key is None:
|
484
|
+
process_schema(sub_schema, client, defs)
|
485
|
+
else:
|
486
|
+
ref = defs[ref_key.split('defs/')[-1]]
|
487
|
+
process_schema(ref, client, defs)
|
488
|
+
schema['items'] = ref
|
489
|
+
|
490
|
+
def _process_enum(
|
491
|
+
enum: EnumMeta, client: Optional[_api_client.ApiClient] = None
|
492
|
+
) -> types.Schema:
|
493
|
+
for member in enum:
|
494
|
+
if not isinstance(member.value, str):
|
495
|
+
raise TypeError(
|
496
|
+
f'Enum member {member.name} value must be a string, got'
|
497
|
+
f' {type(member.value)}'
|
498
|
+
)
|
499
|
+
class Placeholder(pydantic.BaseModel):
|
500
|
+
placeholder: enum
|
298
501
|
|
299
|
-
|
502
|
+
enum_schema = Placeholder.model_json_schema()
|
503
|
+
process_schema(enum_schema, client)
|
504
|
+
enum_schema = enum_schema['properties']['placeholder']
|
505
|
+
return types.Schema.model_validate(enum_schema)
|
300
506
|
|
301
507
|
|
302
508
|
def t_schema(
|
303
|
-
|
509
|
+
client: _api_client.ApiClient, origin: Union[types.SchemaUnionDict, Any]
|
304
510
|
) -> Optional[types.Schema]:
|
305
511
|
if not origin:
|
306
512
|
return None
|
307
513
|
if isinstance(origin, dict):
|
308
|
-
|
309
|
-
|
310
|
-
|
514
|
+
process_schema(origin, client)
|
515
|
+
return types.Schema.model_validate(origin)
|
516
|
+
if isinstance(origin, EnumMeta):
|
517
|
+
return _process_enum(origin, client)
|
518
|
+
if isinstance(origin, types.Schema):
|
519
|
+
if dict(origin) == dict(types.Schema()):
|
520
|
+
# response_schema value was coerced to an empty Schema instance because it did not adhere to the Schema field annotation
|
521
|
+
raise ValueError(f'Unsupported schema type.')
|
522
|
+
schema = origin.model_dump(exclude_unset=True)
|
523
|
+
process_schema(schema, client)
|
524
|
+
return types.Schema.model_validate(schema)
|
525
|
+
|
526
|
+
if (
|
527
|
+
# in Python 3.9 Generic alias list[int] counts as a type,
|
528
|
+
# and breaks issubclass because it's not a class.
|
529
|
+
not isinstance(origin, GenericAlias) and
|
530
|
+
isinstance(origin, type) and
|
531
|
+
issubclass(origin, pydantic.BaseModel)
|
532
|
+
):
|
533
|
+
schema = origin.model_json_schema()
|
534
|
+
process_schema(schema, client)
|
535
|
+
return types.Schema.model_validate(schema)
|
536
|
+
elif (
|
537
|
+
isinstance(origin, GenericAlias) or isinstance(origin, type) or isinstance(origin, UnionType)
|
538
|
+
):
|
539
|
+
class Placeholder(pydantic.BaseModel):
|
540
|
+
placeholder: origin
|
541
|
+
|
542
|
+
schema = Placeholder.model_json_schema()
|
543
|
+
process_schema(schema, client)
|
544
|
+
schema = schema['properties']['placeholder']
|
545
|
+
return types.Schema.model_validate(schema)
|
546
|
+
|
547
|
+
raise ValueError(f'Unsupported schema type: {origin}')
|
311
548
|
|
312
549
|
|
313
550
|
def t_speech_config(
|
@@ -343,10 +580,12 @@ def t_speech_config(
|
|
343
580
|
def t_tool(client: _api_client.ApiClient, origin) -> types.Tool:
|
344
581
|
if not origin:
|
345
582
|
return None
|
346
|
-
if inspect.isfunction(origin):
|
583
|
+
if inspect.isfunction(origin) or inspect.ismethod(origin):
|
347
584
|
return types.Tool(
|
348
585
|
function_declarations=[
|
349
|
-
types.FunctionDeclaration.
|
586
|
+
types.FunctionDeclaration.from_callable(
|
587
|
+
client=client, callable=origin
|
588
|
+
)
|
350
589
|
]
|
351
590
|
)
|
352
591
|
else:
|
@@ -456,10 +695,25 @@ def t_resolve_operation(api_client: _api_client.ApiClient, struct: dict):
|
|
456
695
|
return struct
|
457
696
|
|
458
697
|
|
459
|
-
def t_file_name(
|
460
|
-
|
461
|
-
|
462
|
-
|
698
|
+
def t_file_name(
|
699
|
+
api_client: _api_client.ApiClient, name: Union[str, types.File]
|
700
|
+
):
|
701
|
+
# Remove the files/ prefix since it's added to the url path.
|
702
|
+
if isinstance(name, types.File):
|
703
|
+
name = name.name
|
704
|
+
|
705
|
+
if name is None:
|
706
|
+
raise ValueError('File name is required.')
|
707
|
+
|
708
|
+
if name.startswith('https://'):
|
709
|
+
suffix = name.split('files/')[1]
|
710
|
+
match = re.match('[a-z0-9]+', suffix)
|
711
|
+
if match is None:
|
712
|
+
raise ValueError(f'Could not extract file name from URI: {name}')
|
713
|
+
name = match.group(0)
|
714
|
+
elif name.startswith('files/'):
|
715
|
+
name = name.split('files/')[1]
|
716
|
+
|
463
717
|
return name
|
464
718
|
|
465
719
|
|
@@ -481,12 +735,8 @@ def t_tuning_job_status(
|
|
481
735
|
# Some fields don't accept url safe base64 encoding.
|
482
736
|
# We shouldn't use this transformer if the backend adhere to Cloud Type
|
483
737
|
# format https://cloud.google.com/docs/discovery/type-format.
|
484
|
-
# TODO(b/389133914): Remove the hack after
|
738
|
+
# TODO(b/389133914,b/390320301): Remove the hack after backend fix the issue.
|
485
739
|
def t_bytes(api_client: _api_client.ApiClient, data: bytes) -> str:
|
486
740
|
if not isinstance(data, bytes):
|
487
741
|
return data
|
488
|
-
|
489
|
-
return base64.b64encode(data).decode('ascii')
|
490
|
-
else:
|
491
|
-
return base64.urlsafe_encode(data).decode('ascii')
|
492
|
-
|
742
|
+
return base64.b64encode(data).decode('ascii')
|