google-genai 0.5.0__py3-none-any.whl → 0.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -17,25 +17,35 @@
17
17
 
18
18
  import base64
19
19
  from collections.abc import Iterable, Mapping
20
+ from enum import Enum, EnumMeta
20
21
  import inspect
21
22
  import io
22
23
  import re
23
24
  import time
24
- from typing import Any, Optional, Union
25
+ import typing
26
+ from typing import Any, GenericAlias, Optional, Union
27
+ import sys
25
28
 
26
- import PIL.Image
27
- import PIL.PngImagePlugin
29
+ if typing.TYPE_CHECKING:
30
+ import PIL.Image
31
+
32
+ import pydantic
28
33
 
29
34
  from . import _api_client
30
35
  from . import types
31
36
 
37
+ if sys.version_info >= (3, 11):
38
+ from types import UnionType
39
+ else:
40
+ UnionType = typing._UnionGenericAlias
41
+
32
42
 
33
43
  def _resource_name(
34
44
  client: _api_client.ApiClient,
35
45
  resource_name: str,
36
46
  *,
37
47
  collection_identifier: str,
38
- collection_hirearchy_depth: int = 2,
48
+ collection_hierarchy_depth: int = 2,
39
49
  ):
40
50
  # pylint: disable=line-too-long
41
51
  """Prepends resource name with project, location, collection_identifier if needed.
@@ -48,13 +58,13 @@ def _resource_name(
48
58
  Args:
49
59
  client: The API client.
50
60
  resource_name: The user input resource name to be completed.
51
- collection_identifier: The collection identifier to be prepended.
52
- See collection identifiers in https://google.aip.dev/122.
53
- collection_hirearchy_depth: The collection hierarchy depth.
54
- Only set this field when the resource has nested collections.
55
- For example, `users/vhugo1802/events/birthday-dinner-226`, the
56
- collection_identifier is `users` and collection_hirearchy_depth is 4.
57
- See nested collections in https://google.aip.dev/122.
61
+ collection_identifier: The collection identifier to be prepended. See
62
+ collection identifiers in https://google.aip.dev/122.
63
+ collection_hierarchy_depth: The collection hierarchy depth. Only set this
64
+ field when the resource has nested collections. For example,
65
+ `users/vhugo1802/events/birthday-dinner-226`, the collection_identifier is
66
+ `users` and collection_hierarchy_depth is 4. See nested collections in
67
+ https://google.aip.dev/122.
58
68
 
59
69
  Example:
60
70
 
@@ -62,7 +72,8 @@ def _resource_name(
62
72
  client.vertexai = True
63
73
  client.project = 'bar'
64
74
  client.location = 'us-west1'
65
- _resource_name(client, 'cachedContents/123', collection_identifier='cachedContents')
75
+ _resource_name(client, 'cachedContents/123',
76
+ collection_identifier='cachedContents')
66
77
  returns: 'projects/bar/locations/us-west1/cachedContents/123'
67
78
 
68
79
  Example:
@@ -72,7 +83,8 @@ def _resource_name(
72
83
  client.vertexai = True
73
84
  client.project = 'bar'
74
85
  client.location = 'us-west1'
75
- _resource_name(client, resource_name, collection_identifier='cachedContents')
86
+ _resource_name(client, resource_name,
87
+ collection_identifier='cachedContents')
76
88
  returns: 'projects/foo/locations/us-central1/cachedContents/123'
77
89
 
78
90
  Example:
@@ -80,7 +92,8 @@ def _resource_name(
80
92
  resource_name = '123'
81
93
  # resource_name = 'cachedContents/123'
82
94
  client.vertexai = False
83
- _resource_name(client, resource_name, collection_identifier='cachedContents')
95
+ _resource_name(client, resource_name,
96
+ collection_identifier='cachedContents')
84
97
  returns 'cachedContents/123'
85
98
 
86
99
  Example:
@@ -88,7 +101,8 @@ def _resource_name(
88
101
  resource_prefix = 'cachedContents'
89
102
  client.vertexai = False
90
103
  # client.vertexai = True
91
- _resource_name(client, resource_name, collection_identifier='cachedContents')
104
+ _resource_name(client, resource_name,
105
+ collection_identifier='cachedContents')
92
106
  returns: 'some/wrong/cachedContents/resource/name/123'
93
107
 
94
108
  Returns:
@@ -99,7 +113,7 @@ def _resource_name(
99
113
  # Check if prepending the collection identifier won't violate the
100
114
  # collection hierarchy depth.
101
115
  and f'{collection_identifier}/{resource_name}'.count('/') + 1
102
- == collection_hirearchy_depth
116
+ == collection_hierarchy_depth
103
117
  )
104
118
  if client.vertexai:
105
119
  if resource_name.startswith('projects/'):
@@ -142,6 +156,7 @@ def t_model(client: _api_client.ApiClient, model: str):
142
156
  else:
143
157
  return f'models/{model}'
144
158
 
159
+
145
160
  def t_models_url(api_client: _api_client.ApiClient, base_models: bool) -> str:
146
161
  if api_client.vertexai:
147
162
  if base_models:
@@ -155,8 +170,12 @@ def t_models_url(api_client: _api_client.ApiClient, base_models: bool) -> str:
155
170
  return 'tunedModels'
156
171
 
157
172
 
158
- def t_extract_models(api_client: _api_client.ApiClient, response: dict) -> list[types.Model]:
159
- if response.get('models') is not None:
173
+ def t_extract_models(
174
+ api_client: _api_client.ApiClient, response: dict
175
+ ) -> list[types.Model]:
176
+ if not response:
177
+ return []
178
+ elif response.get('models') is not None:
160
179
  return response.get('models')
161
180
  elif response.get('tunedModels') is not None:
162
181
  return response.get('tunedModels')
@@ -181,9 +200,15 @@ def t_caches_model(api_client: _api_client.ApiClient, model: str):
181
200
  return model
182
201
 
183
202
 
184
- def pil_to_blob(img):
203
+ def pil_to_blob(img) -> types.Blob:
204
+ try:
205
+ import PIL.PngImagePlugin
206
+ PngImagePlugin = PIL.PngImagePlugin
207
+ except ImportError:
208
+ PngImagePlugin = None
209
+
185
210
  bytesio = io.BytesIO()
186
- if isinstance(img, PIL.PngImagePlugin.PngImageFile) or img.mode == 'RGBA':
211
+ if PngImagePlugin is not None and isinstance(img, PngImagePlugin.PngImageFile) or img.mode == 'RGBA':
187
212
  img.save(bytesio, format='PNG')
188
213
  mime_type = 'image/png'
189
214
  else:
@@ -194,16 +219,26 @@ def pil_to_blob(img):
194
219
  return types.Blob(mime_type=mime_type, data=data)
195
220
 
196
221
 
197
- PartType = Union[types.Part, types.PartDict, str, PIL.Image.Image]
222
+ PartType = Union[types.Part, types.PartDict, str, 'PIL.Image.Image']
198
223
 
199
224
 
200
225
  def t_part(client: _api_client.ApiClient, part: PartType) -> types.Part:
226
+ try:
227
+ import PIL.Image
228
+ PIL_Image = PIL.Image.Image
229
+ except ImportError:
230
+ PIL_Image = None
231
+
201
232
  if not part:
202
233
  raise ValueError('content part is required.')
203
234
  if isinstance(part, str):
204
235
  return types.Part(text=part)
205
- if isinstance(part, PIL.Image.Image):
236
+ if PIL_Image is not None and isinstance(part, PIL_Image):
206
237
  return types.Part(inline_data=pil_to_blob(part))
238
+ if isinstance(part, types.File):
239
+ if not part.uri or not part.mime_type:
240
+ raise ValueError('file uri and mime_type are required.')
241
+ return types.Part.from_uri(file_uri=part.uri, mime_type=part.mime_type)
207
242
  else:
208
243
  return part
209
244
 
@@ -282,32 +317,234 @@ def t_contents(
282
317
  return [t_content(client, contents)]
283
318
 
284
319
 
285
- def process_schema(data: dict):
286
- if isinstance(data, dict):
287
- # Iterate over a copy of keys to allow deletion
288
- for key in list(data.keys()):
289
- if key == 'title':
290
- del data[key]
291
- elif key == 'type':
292
- data[key] = data[key].upper()
320
+ def handle_null_fields(schema: dict[str, Any]):
321
+ """Process null fields in the schema so it is compatible with OpenAPI.
322
+
323
+ The OpenAPI spec does not support 'type: 'null' in the schema. This function
324
+ handles this case by adding 'nullable: True' to the null field and removing
325
+ the {'type': 'null'} entry.
326
+
327
+ https://swagger.io/docs/specification/v3_0/data-models/data-types/#null
328
+
329
+ Example of schema properties before and after handling null fields:
330
+ Before:
331
+ {
332
+ "name": {
333
+ "title": "Name",
334
+ "type": "string"
335
+ },
336
+ "total_area_sq_mi": {
337
+ "anyOf": [
338
+ {
339
+ "type": "integer"
340
+ },
341
+ {
342
+ "type": "null"
343
+ }
344
+ ],
345
+ "default": null,
346
+ "title": "Total Area Sq Mi"
347
+ }
348
+ }
349
+
350
+ After:
351
+ {
352
+ "name": {
353
+ "title": "Name",
354
+ "type": "string"
355
+ },
356
+ "total_area_sq_mi": {
357
+ "type": "integer",
358
+ "nullable": true,
359
+ "default": null,
360
+ "title": "Total Area Sq Mi"
361
+ }
362
+ }
363
+ """
364
+ if (
365
+ isinstance(schema, dict)
366
+ and 'type' in schema
367
+ and schema['type'] == 'null'
368
+ ):
369
+ schema['nullable'] = True
370
+ del schema['type']
371
+ elif 'anyOf' in schema:
372
+ for item in schema['anyOf']:
373
+ if 'type' in item and item['type'] == 'null':
374
+ schema['nullable'] = True
375
+ schema['anyOf'].remove({'type': 'null'})
376
+ if len(schema['anyOf']) == 1:
377
+ # If there is only one type left after removing null, remove the anyOf field.
378
+ field_type = schema['anyOf'][0]['type']
379
+ schema['type'] = field_type
380
+ del schema['anyOf']
381
+
382
+
383
+ def process_schema(
384
+ schema: dict[str, Any],
385
+ client: Optional[_api_client.ApiClient] = None,
386
+ defs: Optional[dict[str, Any]]=None):
387
+ """Updates the schema and each sub-schema inplace to be API-compatible.
388
+
389
+ - Removes the `title` field from the schema if the client is not vertexai.
390
+ - Inlines the $defs.
391
+
392
+ Example of a schema before and after (with mldev):
393
+ Before:
394
+
395
+ `schema`
396
+
397
+ {
398
+ 'items': {
399
+ '$ref': '#/$defs/CountryInfo'
400
+ },
401
+ 'title': 'Placeholder',
402
+ 'type': 'array'
403
+ }
404
+
405
+
406
+ `defs`
407
+
408
+ {
409
+ 'CountryInfo': {
410
+ 'properties': {
411
+ 'continent': {
412
+ 'title': 'Continent',
413
+ 'type': 'string'
414
+ },
415
+ 'gdp': {
416
+ 'title': 'Gdp',
417
+ 'type': 'integer'}
418
+ },
419
+ }
420
+ 'required':['continent', 'gdp'],
421
+ 'title': 'CountryInfo',
422
+ 'type': 'object'
423
+ }
424
+ }
425
+
426
+ After:
427
+
428
+ `schema`
429
+ {
430
+ 'items': {
431
+ 'properties': {
432
+ 'continent': {
433
+ 'type': 'string'
434
+ },
435
+ 'gdp': {
436
+ 'type': 'integer'}
437
+ },
438
+ }
439
+ 'required':['continent', 'gdp'],
440
+ 'type': 'object'
441
+ },
442
+ 'type': 'array'
443
+ }
444
+ """
445
+ if client and not client.vertexai:
446
+ schema.pop('title', None)
447
+
448
+ if defs is None:
449
+ defs = schema.pop('$defs', {})
450
+ for _, sub_schema in defs.items():
451
+ process_schema(sub_schema, client, defs)
452
+
453
+ handle_null_fields(schema)
454
+
455
+ any_of = schema.get('anyOf', None)
456
+ if any_of is not None:
457
+ for sub_schema in any_of:
458
+ process_schema(sub_schema, client, defs)
459
+ return
460
+
461
+ schema_type = schema.get('type', None)
462
+ if isinstance(schema_type, Enum):
463
+ schema_type = schema_type.value
464
+ schema_type = schema_type.upper()
465
+
466
+ if schema_type == 'OBJECT':
467
+ properties = schema.get('properties', None)
468
+ if properties is None:
469
+ return
470
+ for name, sub_schema in properties.items():
471
+ ref_key = sub_schema.get('$ref', None)
472
+ if ref_key is None:
473
+ process_schema(sub_schema, client, defs)
293
474
  else:
294
- process_schema(data[key])
295
- elif isinstance(data, list):
296
- for item in data:
297
- process_schema(item)
475
+ ref = defs[ref_key.split('defs/')[-1]]
476
+ process_schema(ref, client, defs)
477
+ properties[name] = ref
478
+ elif schema_type == 'ARRAY':
479
+ sub_schema = schema.get('items', None)
480
+ if sub_schema is None:
481
+ return
482
+ ref_key = sub_schema.get('$ref', None)
483
+ if ref_key is None:
484
+ process_schema(sub_schema, client, defs)
485
+ else:
486
+ ref = defs[ref_key.split('defs/')[-1]]
487
+ process_schema(ref, client, defs)
488
+ schema['items'] = ref
489
+
490
+ def _process_enum(
491
+ enum: EnumMeta, client: Optional[_api_client.ApiClient] = None
492
+ ) -> types.Schema:
493
+ for member in enum:
494
+ if not isinstance(member.value, str):
495
+ raise TypeError(
496
+ f'Enum member {member.name} value must be a string, got'
497
+ f' {type(member.value)}'
498
+ )
499
+ class Placeholder(pydantic.BaseModel):
500
+ placeholder: enum
298
501
 
299
- return data
502
+ enum_schema = Placeholder.model_json_schema()
503
+ process_schema(enum_schema, client)
504
+ enum_schema = enum_schema['properties']['placeholder']
505
+ return types.Schema.model_validate(enum_schema)
300
506
 
301
507
 
302
508
  def t_schema(
303
- _: _api_client.ApiClient, origin: Union[types.SchemaDict, Any]
509
+ client: _api_client.ApiClient, origin: Union[types.SchemaUnionDict, Any]
304
510
  ) -> Optional[types.Schema]:
305
511
  if not origin:
306
512
  return None
307
513
  if isinstance(origin, dict):
308
- return origin
309
- schema = process_schema(origin.model_json_schema())
310
- return types.Schema.model_validate(schema)
514
+ process_schema(origin, client)
515
+ return types.Schema.model_validate(origin)
516
+ if isinstance(origin, EnumMeta):
517
+ return _process_enum(origin, client)
518
+ if isinstance(origin, types.Schema):
519
+ if dict(origin) == dict(types.Schema()):
520
+ # response_schema value was coerced to an empty Schema instance because it did not adhere to the Schema field annotation
521
+ raise ValueError(f'Unsupported schema type.')
522
+ schema = origin.model_dump(exclude_unset=True)
523
+ process_schema(schema, client)
524
+ return types.Schema.model_validate(schema)
525
+
526
+ if (
527
+ # in Python 3.9 Generic alias list[int] counts as a type,
528
+ # and breaks issubclass because it's not a class.
529
+ not isinstance(origin, GenericAlias) and
530
+ isinstance(origin, type) and
531
+ issubclass(origin, pydantic.BaseModel)
532
+ ):
533
+ schema = origin.model_json_schema()
534
+ process_schema(schema, client)
535
+ return types.Schema.model_validate(schema)
536
+ elif (
537
+ isinstance(origin, GenericAlias) or isinstance(origin, type) or isinstance(origin, UnionType)
538
+ ):
539
+ class Placeholder(pydantic.BaseModel):
540
+ placeholder: origin
541
+
542
+ schema = Placeholder.model_json_schema()
543
+ process_schema(schema, client)
544
+ schema = schema['properties']['placeholder']
545
+ return types.Schema.model_validate(schema)
546
+
547
+ raise ValueError(f'Unsupported schema type: {origin}')
311
548
 
312
549
 
313
550
  def t_speech_config(
@@ -343,10 +580,12 @@ def t_speech_config(
343
580
  def t_tool(client: _api_client.ApiClient, origin) -> types.Tool:
344
581
  if not origin:
345
582
  return None
346
- if inspect.isfunction(origin):
583
+ if inspect.isfunction(origin) or inspect.ismethod(origin):
347
584
  return types.Tool(
348
585
  function_declarations=[
349
- types.FunctionDeclaration.from_function(client, origin)
586
+ types.FunctionDeclaration.from_callable(
587
+ client=client, callable=origin
588
+ )
350
589
  ]
351
590
  )
352
591
  else:
@@ -456,10 +695,25 @@ def t_resolve_operation(api_client: _api_client.ApiClient, struct: dict):
456
695
  return struct
457
696
 
458
697
 
459
- def t_file_name(api_client: _api_client.ApiClient, name: str):
460
- # Remove the files/ prefx since it's added to the url path.
461
- if name.startswith('files/'):
462
- return name.split('files/')[1]
698
+ def t_file_name(
699
+ api_client: _api_client.ApiClient, name: Union[str, types.File]
700
+ ):
701
+ # Remove the files/ prefix since it's added to the url path.
702
+ if isinstance(name, types.File):
703
+ name = name.name
704
+
705
+ if name is None:
706
+ raise ValueError('File name is required.')
707
+
708
+ if name.startswith('https://'):
709
+ suffix = name.split('files/')[1]
710
+ match = re.match('[a-z0-9]+', suffix)
711
+ if match is None:
712
+ raise ValueError(f'Could not extract file name from URI: {name}')
713
+ name = match.group(0)
714
+ elif name.startswith('files/'):
715
+ name = name.split('files/')[1]
716
+
463
717
  return name
464
718
 
465
719
 
@@ -481,12 +735,8 @@ def t_tuning_job_status(
481
735
  # Some fields don't accept url safe base64 encoding.
482
736
  # We shouldn't use this transformer if the backend adhere to Cloud Type
483
737
  # format https://cloud.google.com/docs/discovery/type-format.
484
- # TODO(b/389133914): Remove the hack after Vertex backend fix the issue.
738
+ # TODO(b/389133914,b/390320301): Remove the hack after backend fix the issue.
485
739
  def t_bytes(api_client: _api_client.ApiClient, data: bytes) -> str:
486
740
  if not isinstance(data, bytes):
487
741
  return data
488
- if api_client.vertexai:
489
- return base64.b64encode(data).decode('ascii')
490
- else:
491
- return base64.urlsafe_encode(data).decode('ascii')
492
-
742
+ return base64.b64encode(data).decode('ascii')