google-genai 0.2.2__tar.gz → 0.4.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30)
  1. {google_genai-0.2.2/google_genai.egg-info → google_genai-0.4.0}/PKG-INFO +66 -18
  2. {google_genai-0.2.2 → google_genai-0.4.0}/README.md +65 -17
  3. {google_genai-0.2.2 → google_genai-0.4.0}/google/genai/__init__.py +2 -1
  4. {google_genai-0.2.2 → google_genai-0.4.0}/google/genai/_api_client.py +91 -38
  5. {google_genai-0.2.2 → google_genai-0.4.0}/google/genai/_automatic_function_calling_util.py +19 -22
  6. {google_genai-0.2.2 → google_genai-0.4.0}/google/genai/_replay_api_client.py +22 -28
  7. {google_genai-0.2.2 → google_genai-0.4.0}/google/genai/_transformers.py +15 -0
  8. {google_genai-0.2.2 → google_genai-0.4.0}/google/genai/batches.py +16 -16
  9. {google_genai-0.2.2 → google_genai-0.4.0}/google/genai/caches.py +48 -46
  10. {google_genai-0.2.2 → google_genai-0.4.0}/google/genai/chats.py +88 -15
  11. {google_genai-0.2.2 → google_genai-0.4.0}/google/genai/client.py +6 -3
  12. {google_genai-0.2.2 → google_genai-0.4.0}/google/genai/files.py +22 -22
  13. {google_genai-0.2.2 → google_genai-0.4.0}/google/genai/live.py +28 -5
  14. {google_genai-0.2.2 → google_genai-0.4.0}/google/genai/models.py +109 -77
  15. {google_genai-0.2.2 → google_genai-0.4.0}/google/genai/tunings.py +17 -17
  16. {google_genai-0.2.2 → google_genai-0.4.0}/google/genai/types.py +173 -90
  17. google_genai-0.4.0/google/genai/version.py +16 -0
  18. {google_genai-0.2.2 → google_genai-0.4.0/google_genai.egg-info}/PKG-INFO +66 -18
  19. {google_genai-0.2.2 → google_genai-0.4.0}/google_genai.egg-info/SOURCES.txt +1 -0
  20. {google_genai-0.2.2 → google_genai-0.4.0}/pyproject.toml +1 -1
  21. {google_genai-0.2.2 → google_genai-0.4.0}/LICENSE +0 -0
  22. {google_genai-0.2.2 → google_genai-0.4.0}/google/genai/_common.py +0 -0
  23. {google_genai-0.2.2 → google_genai-0.4.0}/google/genai/_extra_utils.py +0 -0
  24. {google_genai-0.2.2 → google_genai-0.4.0}/google/genai/_test_api_client.py +0 -0
  25. {google_genai-0.2.2 → google_genai-0.4.0}/google/genai/errors.py +0 -0
  26. {google_genai-0.2.2 → google_genai-0.4.0}/google/genai/pagers.py +0 -0
  27. {google_genai-0.2.2 → google_genai-0.4.0}/google_genai.egg-info/dependency_links.txt +0 -0
  28. {google_genai-0.2.2 → google_genai-0.4.0}/google_genai.egg-info/requires.txt +0 -0
  29. {google_genai-0.2.2 → google_genai-0.4.0}/google_genai.egg-info/top_level.txt +0 -0
  30. {google_genai-0.2.2 → google_genai-0.4.0}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
- Metadata-Version: 2.1
1
+ Metadata-Version: 2.2
2
2
  Name: google-genai
3
- Version: 0.2.2
3
+ Version: 0.4.0
4
4
  Summary: GenAI Python SDK
5
5
  Author-email: Google LLC <googleapis-packages@google.com>
6
6
  License: Apache-2.0
@@ -35,6 +35,12 @@ Requires-Dist: websockets<15.0dev,>=13.0
35
35
 
36
36
  -----
37
37
 
38
+ ## Installation
39
+
40
+ ``` cmd
41
+ pip install google-genai
42
+ ```
43
+
38
44
  ## Imports
39
45
 
40
46
  ``` python
@@ -430,10 +436,10 @@ response1 = client.models.generate_image(
430
436
  model='imagen-3.0-generate-001',
431
437
  prompt='An umbrella in the foreground, and a rainy night sky in the background',
432
438
  config=types.GenerateImageConfig(
433
- negative_prompt= "human",
439
+ negative_prompt= 'human',
434
440
  number_of_images= 1,
435
441
  include_rai_reason= True,
436
- output_mime_type= "image/jpeg"
442
+ output_mime_type= 'image/jpeg'
437
443
  )
438
444
  )
439
445
  response1.generated_images[0].image.show()
@@ -448,13 +454,19 @@ Upscale image is not supported in Google AI.
448
454
  response2 = client.models.upscale_image(
449
455
  model='imagen-3.0-generate-001',
450
456
  image=response1.generated_images[0].image,
451
- config=types.UpscaleImageConfig(upscale_factor="x2")
457
+ upscale_factor='x2',
458
+ config=types.UpscaleImageConfig(
459
+ include_rai_reason= True,
460
+ output_mime_type= 'image/jpeg',
461
+ ),
452
462
  )
453
463
  response2.generated_images[0].image.show()
454
464
  ```
455
465
 
456
466
  #### Edit Image
457
467
 
468
+ Edit image uses a separate model from generate and upscale.
469
+
458
470
  Edit image is not supported in Google AI.
459
471
 
460
472
  ``` python
@@ -475,7 +487,7 @@ mask_ref_image = MaskReferenceImage(
475
487
  )
476
488
 
477
489
  response3 = client.models.edit_image(
478
- model='imagen-3.0-capability-preview-0930',
490
+ model='imagen-3.0-capability-001',
479
491
  prompt='Sunlight and clear sky',
480
492
  reference_images=[raw_ref_image, mask_ref_image],
481
493
  config=types.EditImageConfig(
@@ -489,6 +501,42 @@ response3 = client.models.edit_image(
489
501
  response3.generated_images[0].image.show()
490
502
  ```
491
503
 
504
+ ## Chats
505
+
506
+ Create a chat session to start a multi-turn conversation with the model.
507
+
508
+ ### Send Message
509
+
510
+ ```python
511
+ chat = client.chats.create(model='gemini-2.0-flash-exp')
512
+ response = chat.send_message('tell me a story')
513
+ print(response.text)
514
+ ```
515
+
516
+ ### Streaming
517
+
518
+ ```python
519
+ chat = client.chats.create(model='gemini-2.0-flash-exp')
520
+ for chunk in chat.send_message_stream('tell me a story'):
521
+ print(chunk.text)
522
+ ```
523
+
524
+ ### Async
525
+
526
+ ```python
527
+ chat = client.aio.chats.create(model='gemini-2.0-flash-exp')
528
+ response = await chat.send_message('tell me a story')
529
+ print(response.text)
530
+ ```
531
+
532
+ ### Async Streaming
533
+
534
+ ```python
535
+ chat = client.aio.chats.create(model='gemini-2.0-flash-exp')
536
+ async for chunk in chat.send_message_stream('tell me a story'):
537
+ print(chunk.text)
538
+ ```
539
+
492
540
  ## Files (Only Google AI)
493
541
 
494
542
  ``` python
@@ -531,20 +579,20 @@ else:
531
579
 
532
580
  cached_content = client.caches.create(
533
581
  model='gemini-1.5-pro-002',
534
- contents=[
535
- types.Content(
536
- role='user',
537
- parts=[
538
- types.Part.from_uri(
539
- file_uri=file_uris[0],
540
- mime_type='application/pdf'),
541
- types.Part.from_uri(
542
- file_uri=file_uris[1],
543
- mime_type='application/pdf',)])
544
- ],
545
582
  config=types.CreateCachedContentConfig(
546
- display_name='test cache',
583
+ contents=[
584
+ types.Content(
585
+ role='user',
586
+ parts=[
587
+ types.Part.from_uri(
588
+ file_uri=file_uris[0],
589
+ mime_type='application/pdf'),
590
+ types.Part.from_uri(
591
+ file_uri=file_uris[1],
592
+ mime_type='application/pdf',)])
593
+ ],
547
594
  system_instruction='What is the sum of the two pdfs?',
595
+ display_name='test cache',
548
596
  ttl='3600s',
549
597
  ),
550
598
  )
@@ -7,6 +7,12 @@
7
7
 
8
8
  -----
9
9
 
10
+ ## Installation
11
+
12
+ ``` cmd
13
+ pip install google-genai
14
+ ```
15
+
10
16
  ## Imports
11
17
 
12
18
  ``` python
@@ -402,10 +408,10 @@ response1 = client.models.generate_image(
402
408
  model='imagen-3.0-generate-001',
403
409
  prompt='An umbrella in the foreground, and a rainy night sky in the background',
404
410
  config=types.GenerateImageConfig(
405
- negative_prompt= "human",
411
+ negative_prompt= 'human',
406
412
  number_of_images= 1,
407
413
  include_rai_reason= True,
408
- output_mime_type= "image/jpeg"
414
+ output_mime_type= 'image/jpeg'
409
415
  )
410
416
  )
411
417
  response1.generated_images[0].image.show()
@@ -420,13 +426,19 @@ Upscale image is not supported in Google AI.
420
426
  response2 = client.models.upscale_image(
421
427
  model='imagen-3.0-generate-001',
422
428
  image=response1.generated_images[0].image,
423
- config=types.UpscaleImageConfig(upscale_factor="x2")
429
+ upscale_factor='x2',
430
+ config=types.UpscaleImageConfig(
431
+ include_rai_reason= True,
432
+ output_mime_type= 'image/jpeg',
433
+ ),
424
434
  )
425
435
  response2.generated_images[0].image.show()
426
436
  ```
427
437
 
428
438
  #### Edit Image
429
439
 
440
+ Edit image uses a separate model from generate and upscale.
441
+
430
442
  Edit image is not supported in Google AI.
431
443
 
432
444
  ``` python
@@ -447,7 +459,7 @@ mask_ref_image = MaskReferenceImage(
447
459
  )
448
460
 
449
461
  response3 = client.models.edit_image(
450
- model='imagen-3.0-capability-preview-0930',
462
+ model='imagen-3.0-capability-001',
451
463
  prompt='Sunlight and clear sky',
452
464
  reference_images=[raw_ref_image, mask_ref_image],
453
465
  config=types.EditImageConfig(
@@ -461,6 +473,42 @@ response3 = client.models.edit_image(
461
473
  response3.generated_images[0].image.show()
462
474
  ```
463
475
 
476
+ ## Chats
477
+
478
+ Create a chat session to start a multi-turn conversation with the model.
479
+
480
+ ### Send Message
481
+
482
+ ```python
483
+ chat = client.chats.create(model='gemini-2.0-flash-exp')
484
+ response = chat.send_message('tell me a story')
485
+ print(response.text)
486
+ ```
487
+
488
+ ### Streaming
489
+
490
+ ```python
491
+ chat = client.chats.create(model='gemini-2.0-flash-exp')
492
+ for chunk in chat.send_message_stream('tell me a story'):
493
+ print(chunk.text)
494
+ ```
495
+
496
+ ### Async
497
+
498
+ ```python
499
+ chat = client.aio.chats.create(model='gemini-2.0-flash-exp')
500
+ response = await chat.send_message('tell me a story')
501
+ print(response.text)
502
+ ```
503
+
504
+ ### Async Streaming
505
+
506
+ ```python
507
+ chat = client.aio.chats.create(model='gemini-2.0-flash-exp')
508
+ async for chunk in chat.send_message_stream('tell me a story'):
509
+ print(chunk.text)
510
+ ```
511
+
464
512
  ## Files (Only Google AI)
465
513
 
466
514
  ``` python
@@ -503,20 +551,20 @@ else:
503
551
 
504
552
  cached_content = client.caches.create(
505
553
  model='gemini-1.5-pro-002',
506
- contents=[
507
- types.Content(
508
- role='user',
509
- parts=[
510
- types.Part.from_uri(
511
- file_uri=file_uris[0],
512
- mime_type='application/pdf'),
513
- types.Part.from_uri(
514
- file_uri=file_uris[1],
515
- mime_type='application/pdf',)])
516
- ],
517
554
  config=types.CreateCachedContentConfig(
518
- display_name='test cache',
555
+ contents=[
556
+ types.Content(
557
+ role='user',
558
+ parts=[
559
+ types.Part.from_uri(
560
+ file_uri=file_uris[0],
561
+ mime_type='application/pdf'),
562
+ types.Part.from_uri(
563
+ file_uri=file_uris[1],
564
+ mime_type='application/pdf',)])
565
+ ],
519
566
  system_instruction='What is the sum of the two pdfs?',
567
+ display_name='test cache',
520
568
  ttl='3600s',
521
569
  ),
522
570
  )
@@ -809,4 +857,4 @@ print(async_pager[0])
809
857
  delete_job = client.batches.delete(name=job.name)
810
858
 
811
859
  delete_job
812
- ```
860
+ ```
@@ -16,7 +16,8 @@
16
16
  """Google Gen AI SDK"""
17
17
 
18
18
  from .client import Client
19
+ from . import version
19
20
 
20
- __version__ = '0.2.2'
21
+ __version__ = version.__version__
21
22
 
22
23
  __all__ = ['Client']
@@ -23,35 +23,66 @@ import datetime
23
23
  import json
24
24
  import os
25
25
  import sys
26
- from typing import Any, Optional, TypedDict, Union
26
+ from typing import Any, Optional, Tuple, TypedDict, Union
27
27
  from urllib.parse import urlparse, urlunparse
28
28
 
29
29
  import google.auth
30
30
  import google.auth.credentials
31
31
  from google.auth.transport.requests import AuthorizedSession
32
- from pydantic import BaseModel
32
+ from pydantic import BaseModel, ConfigDict, Field, ValidationError
33
33
  import requests
34
34
 
35
35
  from . import errors
36
+ from . import version
36
37
 
37
38
 
38
- class HttpOptions(TypedDict):
39
+ class HttpOptions(BaseModel):
40
+ """HTTP options for the api client."""
41
+ model_config = ConfigDict(extra='forbid')
42
+
43
+ base_url: Optional[str] = Field(
44
+ default=None,
45
+ description="""The base URL for the AI platform service endpoint.""",
46
+ )
47
+ api_version: Optional[str] = Field(
48
+ default=None,
49
+ description="""Specifies the version of the API to use.""",
50
+ )
51
+ headers: Optional[dict[str, str]] = Field(
52
+ default=None,
53
+ description="""Additional HTTP headers to be sent with the request.""",
54
+ )
55
+ response_payload: Optional[dict] = Field(
56
+ default=None,
57
+ description="""If set, the response payload will be returned int the supplied dict.""",
58
+ )
59
+ timeout: Optional[Union[float, Tuple[float, float]]] = Field(
60
+ default=None,
61
+ description="""Timeout for the request in seconds.""",
62
+ )
63
+
64
+
65
+ class HttpOptionsDict(TypedDict):
39
66
  """HTTP options for the api client."""
40
67
 
41
- base_url: str = None
68
+ base_url: Optional[str] = None
42
69
  """The base URL for the AI platform service endpoint."""
43
- api_version: str = None
70
+ api_version: Optional[str] = None
44
71
  """Specifies the version of the API to use."""
45
- headers: dict[str, dict] = None
72
+ headers: Optional[dict[str, Union[str, list[str]]]] = None
46
73
  """Additional HTTP headers to be sent with the request."""
47
- response_payload: dict = None
74
+ response_payload: Optional[dict] = None
48
75
  """If set, the response payload will be returned int the supplied dict."""
76
+ timeout: Optional[Union[float, Tuple[float, float]]] = None
77
+ """Timeout for the request in seconds."""
78
+
79
+
80
+ HttpOptionsOrDict = Union[HttpOptions, HttpOptionsDict]
49
81
 
50
82
 
51
83
  def _append_library_version_headers(headers: dict[str, str]) -> None:
52
84
  """Appends the telemetry header to the headers dict."""
53
- # TODO: Automate revisions to the SDK library version.
54
- library_label = f'google-genai-sdk/0.2.2'
85
+ library_label = f'google-genai-sdk/{version.__version__}'
55
86
  language_label = 'gl-python/' + sys.version.split()[0]
56
87
  version_header_value = f'{library_label} {language_label}'
57
88
  if (
@@ -71,20 +102,24 @@ def _append_library_version_headers(headers: dict[str, str]) -> None:
71
102
 
72
103
 
73
104
  def _patch_http_options(
74
- options: HttpOptions, patch_options: HttpOptions
75
- ) -> HttpOptions:
105
+ options: HttpOptionsDict, patch_options: HttpOptionsDict
106
+ ) -> HttpOptionsDict:
76
107
  # use shallow copy so we don't override the original objects.
77
- copy_option = HttpOptions()
108
+ copy_option = HttpOptionsDict()
78
109
  copy_option.update(options)
79
- for k, v in patch_options.items():
110
+ for patch_key, patch_value in patch_options.items():
80
111
  # if both are dicts, update the copy.
81
112
  # This is to handle cases like merging headers.
82
- if isinstance(v, dict) and isinstance(copy_option.get(k, None), dict):
83
- copy_option[k] = {}
84
- copy_option[k].update(options[k]) # shallow copy from original options.
85
- copy_option[k].update(v)
86
- elif v is not None: # Accept empty values.
87
- copy_option[k] = v
113
+ if isinstance(patch_value, dict) and isinstance(
114
+ copy_option.get(patch_key, None), dict
115
+ ):
116
+ copy_option[patch_key] = {}
117
+ copy_option[patch_key].update(
118
+ options[patch_key]
119
+ ) # shallow copy from original options.
120
+ copy_option[patch_key].update(patch_value)
121
+ elif patch_value is not None: # Accept empty values.
122
+ copy_option[patch_key] = patch_value
88
123
  _append_library_version_headers(copy_option['headers'])
89
124
  return copy_option
90
125
 
@@ -98,10 +133,11 @@ def _join_url_path(base_url: str, path: str) -> str:
98
133
 
99
134
  @dataclass
100
135
  class HttpRequest:
101
- headers: dict[str, str]
136
+ headers: dict[str, Union[str, list[str]]]
102
137
  url: str
103
138
  method: str
104
139
  data: Union[dict[str, object], bytes]
140
+ timeout: Optional[Union[float, Tuple[float, float]]] = None
105
141
 
106
142
 
107
143
  class HttpResponse:
@@ -147,7 +183,7 @@ class ApiClient:
147
183
  credentials: google.auth.credentials.Credentials = None,
148
184
  project: Union[str, None] = None,
149
185
  location: Union[str, None] = None,
150
- http_options: HttpOptions = None,
186
+ http_options: HttpOptionsOrDict = None,
151
187
  ):
152
188
  self.vertexai = vertexai
153
189
  if self.vertexai is None:
@@ -163,11 +199,20 @@ class ApiClient:
163
199
  'Project/location and API key are mutually exclusive in the client initializer.'
164
200
  )
165
201
 
202
+ # Validate http_options if a dict is provided.
203
+ if isinstance(http_options, dict):
204
+ try:
205
+ HttpOptions.model_validate(http_options)
206
+ except ValidationError as e:
207
+ raise ValueError(f'Invalid http_options: {e}')
208
+ elif(isinstance(http_options, HttpOptions)):
209
+ http_options = http_options.model_dump()
210
+
166
211
  self.api_key: Optional[str] = None
167
212
  self.project = project or os.environ.get('GOOGLE_CLOUD_PROJECT', None)
168
213
  self.location = location or os.environ.get('GOOGLE_CLOUD_LOCATION', None)
169
214
  self._credentials = credentials
170
- self._http_options = HttpOptions()
215
+ self._http_options = HttpOptionsDict()
171
216
 
172
217
  if self.vertexai:
173
218
  if not self.project:
@@ -208,7 +253,7 @@ class ApiClient:
208
253
  http_method: str,
209
254
  path: str,
210
255
  request_dict: dict[str, object],
211
- http_options: HttpOptions = None,
256
+ http_options: HttpOptionsDict = None,
212
257
  ) -> HttpRequest:
213
258
  # Remove all special dict keys such as _url and _query.
214
259
  keys_to_delete = [key for key in request_dict.keys() if key.startswith('_')]
@@ -232,6 +277,7 @@ class ApiClient:
232
277
  url=url,
233
278
  headers=patched_http_options['headers'],
234
279
  data=request_dict,
280
+ timeout=patched_http_options.get('timeout', None),
235
281
  )
236
282
 
237
283
  def _request(
@@ -241,17 +287,19 @@ class ApiClient:
241
287
  ) -> HttpResponse:
242
288
  if self.vertexai:
243
289
  if not self._credentials:
244
- self._credentials, _ = google.auth.default()
290
+ self._credentials, _ = google.auth.default(
291
+ scopes=["https://www.googleapis.com/auth/cloud-platform"],
292
+ )
245
293
  authed_session = AuthorizedSession(self._credentials)
246
294
  authed_session.stream = stream
247
295
  response = authed_session.request(
248
296
  http_request.method.upper(),
249
297
  http_request.url,
250
298
  headers=http_request.headers,
251
- data=json.dumps(http_request.data, cls=RequestJsonEncoder) if http_request.data else None,
252
- # TODO: support timeout in RequestOptions so it can be configured
253
- # per methods.
254
- timeout=None,
299
+ data=json.dumps(http_request.data, cls=RequestJsonEncoder)
300
+ if http_request.data
301
+ else None,
302
+ timeout=http_request.timeout,
255
303
  )
256
304
  errors.APIError.raise_for_response(response)
257
305
  return HttpResponse(
@@ -273,13 +321,14 @@ class ApiClient:
273
321
  data = http_request.data
274
322
 
275
323
  http_session = requests.Session()
276
- request = requests.Request(
324
+ response = http_session.request(
277
325
  method=http_request.method,
278
326
  url=http_request.url,
279
327
  headers=http_request.headers,
280
328
  data=data,
281
- ).prepare()
282
- response = http_session.send(request, stream=stream)
329
+ timeout=http_request.timeout,
330
+ stream=stream,
331
+ )
283
332
  errors.APIError.raise_for_response(response)
284
333
  return HttpResponse(
285
334
  response.headers, response if stream else [response.text]
@@ -290,7 +339,9 @@ class ApiClient:
290
339
  ):
291
340
  if self.vertexai:
292
341
  if not self._credentials:
293
- self._credentials, _ = google.auth.default()
342
+ self._credentials, _ = google.auth.default(
343
+ scopes=["https://www.googleapis.com/auth/cloud-platform"],
344
+ )
294
345
  return await asyncio.to_thread(
295
346
  self._request,
296
347
  http_request,
@@ -303,8 +354,10 @@ class ApiClient:
303
354
  stream=stream,
304
355
  )
305
356
 
306
- def get_read_only_http_options(self) -> HttpOptions:
307
- copied = HttpOptions()
357
+ def get_read_only_http_options(self) -> HttpOptionsDict:
358
+ copied = HttpOptionsDict()
359
+ if isinstance(self._http_options, BaseModel):
360
+ self._http_options = self._http_options.model_dump()
308
361
  copied.update(self._http_options)
309
362
  return copied
310
363
 
@@ -313,7 +366,7 @@ class ApiClient:
313
366
  http_method: str,
314
367
  path: str,
315
368
  request_dict: dict[str, object],
316
- http_options: HttpOptions = None,
369
+ http_options: HttpOptionsDict = None,
317
370
  ):
318
371
  http_request = self._build_request(
319
372
  http_method, path, request_dict, http_options
@@ -328,7 +381,7 @@ class ApiClient:
328
381
  http_method: str,
329
382
  path: str,
330
383
  request_dict: dict[str, object],
331
- http_options: HttpOptions = None,
384
+ http_options: HttpOptionsDict = None,
332
385
  ):
333
386
  http_request = self._build_request(
334
387
  http_method, path, request_dict, http_options
@@ -345,7 +398,7 @@ class ApiClient:
345
398
  http_method: str,
346
399
  path: str,
347
400
  request_dict: dict[str, object],
348
- http_options: HttpOptions = None,
401
+ http_options: HttpOptionsDict = None,
349
402
  ) -> dict[str, object]:
350
403
  http_request = self._build_request(
351
404
  http_method, path, request_dict, http_options
@@ -361,7 +414,7 @@ class ApiClient:
361
414
  http_method: str,
362
415
  path: str,
363
416
  request_dict: dict[str, object],
364
- http_options: HttpOptions = None,
417
+ http_options: HttpOptionsDict = None,
365
418
  ):
366
419
  http_request = self._build_request(
367
420
  http_method, path, request_dict, http_options
@@ -58,8 +58,8 @@ def _raise_for_nullable_if_mldev(schema: types.Schema):
58
58
  )
59
59
 
60
60
 
61
- def _raise_if_schema_unsupported(client, schema: types.Schema):
62
- if not client.vertexai:
61
+ def _raise_if_schema_unsupported(variant: str, schema: types.Schema):
62
+ if not variant == 'VERTEX_AI':
63
63
  _raise_for_any_of_if_mldev(schema)
64
64
  _raise_for_default_if_mldev(schema)
65
65
  _raise_for_nullable_if_mldev(schema)
@@ -112,7 +112,7 @@ def _is_default_value_compatible(
112
112
 
113
113
 
114
114
  def _parse_schema_from_parameter(
115
- client, param: inspect.Parameter, func_name: str
115
+ variant: str, param: inspect.Parameter, func_name: str
116
116
  ) -> types.Schema:
117
117
  """parse schema from parameter.
118
118
 
@@ -130,7 +130,7 @@ def _parse_schema_from_parameter(
130
130
  raise ValueError(default_value_error_msg)
131
131
  schema.default = param.default
132
132
  schema.type = _py_builtin_type_to_schema_type[param.annotation]
133
- _raise_if_schema_unsupported(client, schema)
133
+ _raise_if_schema_unsupported(variant, schema)
134
134
  return schema
135
135
  if (
136
136
  isinstance(param.annotation, typing_types.UnionType)
@@ -149,7 +149,7 @@ def _parse_schema_from_parameter(
149
149
  schema.nullable = True
150
150
  continue
151
151
  schema_in_any_of = _parse_schema_from_parameter(
152
- client,
152
+ variant,
153
153
  inspect.Parameter(
154
154
  'item', inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=arg
155
155
  ),
@@ -170,9 +170,8 @@ def _parse_schema_from_parameter(
170
170
  ):
171
171
  if not _is_default_value_compatible(param.default, param.annotation):
172
172
  raise ValueError(default_value_error_msg)
173
- # TODO: b/379715133 - handle pydantic model default value
174
173
  schema.default = param.default
175
- _raise_if_schema_unsupported(client, schema)
174
+ _raise_if_schema_unsupported(variant, schema)
176
175
  return schema
177
176
  if isinstance(param.annotation, _GenericAlias) or isinstance(
178
177
  param.annotation, typing_types.GenericAlias
@@ -185,7 +184,7 @@ def _parse_schema_from_parameter(
185
184
  if not _is_default_value_compatible(param.default, param.annotation):
186
185
  raise ValueError(default_value_error_msg)
187
186
  schema.default = param.default
188
- _raise_if_schema_unsupported(client, schema)
187
+ _raise_if_schema_unsupported(variant, schema)
189
188
  return schema
190
189
  if origin is Literal:
191
190
  if not all(isinstance(arg, str) for arg in args):
@@ -198,12 +197,12 @@ def _parse_schema_from_parameter(
198
197
  if not _is_default_value_compatible(param.default, param.annotation):
199
198
  raise ValueError(default_value_error_msg)
200
199
  schema.default = param.default
201
- _raise_if_schema_unsupported(client, schema)
200
+ _raise_if_schema_unsupported(variant, schema)
202
201
  return schema
203
202
  if origin is list:
204
203
  schema.type = 'ARRAY'
205
204
  schema.items = _parse_schema_from_parameter(
206
- client,
205
+ variant,
207
206
  inspect.Parameter(
208
207
  'item',
209
208
  inspect.Parameter.POSITIONAL_OR_KEYWORD,
@@ -215,7 +214,7 @@ def _parse_schema_from_parameter(
215
214
  if not _is_default_value_compatible(param.default, param.annotation):
216
215
  raise ValueError(default_value_error_msg)
217
216
  schema.default = param.default
218
- _raise_if_schema_unsupported(client, schema)
217
+ _raise_if_schema_unsupported(variant, schema)
219
218
  return schema
220
219
  if origin is Union:
221
220
  schema.any_of = []
@@ -226,7 +225,7 @@ def _parse_schema_from_parameter(
226
225
  schema.nullable = True
227
226
  continue
228
227
  schema_in_any_of = _parse_schema_from_parameter(
229
- client,
228
+ variant,
230
229
  inspect.Parameter(
231
230
  'item',
232
231
  inspect.Parameter.POSITIONAL_OR_KEYWORD,
@@ -250,7 +249,7 @@ def _parse_schema_from_parameter(
250
249
  if not _is_default_value_compatible(param.default, param.annotation):
251
250
  raise ValueError(default_value_error_msg)
252
251
  schema.default = param.default
253
- _raise_if_schema_unsupported(client, schema)
252
+ _raise_if_schema_unsupported(variant, schema)
254
253
  return schema
255
254
  # all other generic alias will be invoked in raise branch
256
255
  if (
@@ -258,17 +257,16 @@ def _parse_schema_from_parameter(
258
257
  # for user defined class, we only support pydantic model
259
258
  and issubclass(param.annotation, pydantic.BaseModel)
260
259
  ):
261
- if param.default is not inspect.Parameter.empty:
262
- # TODO: b/379715133 - handle pydantic model default value
263
- raise ValueError(
264
- f'Default value {param.default} of Pydantic model{param} of function'
265
- f' {func_name} is not supported.'
266
- )
260
+ if (
261
+ param.default is not inspect.Parameter.empty
262
+ and param.default is not None
263
+ ):
264
+ schema.default = param.default
267
265
  schema.type = 'OBJECT'
268
266
  schema.properties = {}
269
267
  for field_name, field_info in param.annotation.model_fields.items():
270
268
  schema.properties[field_name] = _parse_schema_from_parameter(
271
- client,
269
+ variant,
272
270
  inspect.Parameter(
273
271
  field_name,
274
272
  inspect.Parameter.POSITIONAL_OR_KEYWORD,
@@ -276,7 +274,7 @@ def _parse_schema_from_parameter(
276
274
  ),
277
275
  func_name,
278
276
  )
279
- _raise_if_schema_unsupported(client, schema)
277
+ _raise_if_schema_unsupported(variant, schema)
280
278
  return schema
281
279
  raise ValueError(
282
280
  f'Failed to parse the parameter {param} of function {func_name} for'
@@ -294,4 +292,3 @@ def _get_required_fields(schema: types.Schema) -> list[str]:
294
292
  for field_name, field_schema in schema.properties.items()
295
293
  if not field_schema.nullable and field_schema.default is None
296
294
  ]
297
-