google-genai 1.4.0__py3-none-any.whl → 1.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -19,6 +19,7 @@
19
19
  The BaseApiClient is intended to be a private module and is subject to change.
20
20
  """
21
21
 
22
+ import anyio
22
23
  import asyncio
23
24
  import copy
24
25
  from dataclasses import dataclass
@@ -28,22 +29,21 @@ import json
28
29
  import logging
29
30
  import os
30
31
  import sys
31
- from typing import Any, AsyncIterator, Optional, Tuple, TypedDict, Union
32
+ from typing import Any, AsyncIterator, Optional, Tuple, Union
32
33
  from urllib.parse import urlparse, urlunparse
33
34
  import google.auth
34
35
  import google.auth.credentials
35
36
  from google.auth.credentials import Credentials
36
- from google.auth.transport.requests import AuthorizedSession
37
37
  from google.auth.transport.requests import Request
38
38
  import httpx
39
- from pydantic import BaseModel, ConfigDict, Field, ValidationError
40
- import requests
39
+ from pydantic import BaseModel, Field, ValidationError
41
40
  from . import _common
42
41
  from . import errors
43
42
  from . import version
44
43
  from .types import HttpOptions, HttpOptionsDict, HttpOptionsOrDict
45
44
 
46
45
  logger = logging.getLogger('google_genai._api_client')
46
+ CHUNK_SIZE = 8 * 1024 * 1024 # 8 MB chunk size
47
47
 
48
48
 
49
49
  def _append_library_version_headers(headers: dict[str, str]) -> None:
@@ -86,7 +86,8 @@ def _patch_http_options(
86
86
  copy_option[patch_key].update(patch_value)
87
87
  elif patch_value is not None: # Accept empty values.
88
88
  copy_option[patch_key] = patch_value
89
- _append_library_version_headers(copy_option['headers'])
89
+ if copy_option['headers']:
90
+ _append_library_version_headers(copy_option['headers'])
90
91
  return copy_option
91
92
 
92
93
 
@@ -101,7 +102,7 @@ def _join_url_path(base_url: str, path: str) -> str:
101
102
  return urlunparse(parsed_base._replace(path=base_path + '/' + path))
102
103
 
103
104
 
104
- def _load_auth(*, project: Union[str, None]) -> tuple[Credentials, str]:
105
+ def _load_auth(*, project: Union[str, None]) -> Tuple[Credentials, str]:
105
106
  """Loads google auth credentials and project id."""
106
107
  credentials, loaded_project_id = google.auth.default(
107
108
  scopes=['https://www.googleapis.com/auth/cloud-platform'],
@@ -118,8 +119,9 @@ def _load_auth(*, project: Union[str, None]) -> tuple[Credentials, str]:
118
119
  return credentials, project
119
120
 
120
121
 
121
- def _refresh_auth(credentials: Credentials) -> None:
122
+ def _refresh_auth(credentials: Credentials) -> Credentials:
122
123
  credentials.refresh(Request())
124
+ return credentials
123
125
 
124
126
 
125
127
  @dataclass
@@ -147,7 +149,7 @@ class HttpResponse:
147
149
 
148
150
  def __init__(
149
151
  self,
150
- headers: dict[str, str],
152
+ headers: Union[dict[str, str], httpx.Headers],
151
153
  response_stream: Union[Any, str] = None,
152
154
  byte_stream: Union[Any, bytes] = None,
153
155
  ):
@@ -200,14 +202,19 @@ class HttpResponse:
200
202
  yield c
201
203
  else:
202
204
  # Iterator of objects retrieved from the API.
203
- async for chunk in self.response_stream.aiter_lines():
204
- # This is httpx.Response.
205
- if chunk:
206
- # In async streaming mode, the chunk of JSON is prefixed with "data:"
207
- # which we must strip before parsing.
208
- if chunk.startswith('data: '):
209
- chunk = chunk[len('data: ') :]
210
- yield json.loads(chunk)
205
+ if hasattr(self.response_stream, 'aiter_lines'):
206
+ async for chunk in self.response_stream.aiter_lines():
207
+ # This is httpx.Response.
208
+ if chunk:
209
+ # In async streaming mode, the chunk of JSON is prefixed with "data:"
210
+ # which we must strip before parsing.
211
+ if chunk.startswith('data: '):
212
+ chunk = chunk[len('data: ') :]
213
+ yield json.loads(chunk)
214
+ else:
215
+ raise ValueError(
216
+ 'Error parsing streaming response.'
217
+ )
211
218
 
212
219
  def byte_segments(self):
213
220
  if isinstance(self.byte_stream, list):
@@ -265,7 +272,9 @@ class BaseApiClient:
265
272
  validated_http_options: dict[str, Any]
266
273
  if isinstance(http_options, dict):
267
274
  try:
268
- validated_http_options = HttpOptions.model_validate(http_options).model_dump()
275
+ validated_http_options = HttpOptions.model_validate(
276
+ http_options
277
+ ).model_dump()
269
278
  except ValidationError as e:
270
279
  raise ValueError(f'Invalid http_options: {e}')
271
280
  elif isinstance(http_options, HttpOptions):
@@ -351,7 +360,9 @@ class BaseApiClient:
351
360
  self._http_options['headers']['x-goog-api-key'] = self.api_key
352
361
  # Update the http options with the user provided http options.
353
362
  if http_options:
354
- self._http_options = _patch_http_options(self._http_options, validated_http_options)
363
+ self._http_options = _patch_http_options(
364
+ self._http_options, validated_http_options
365
+ )
355
366
  else:
356
367
  _append_library_version_headers(self._http_options['headers'])
357
368
 
@@ -359,8 +370,27 @@ class BaseApiClient:
359
370
  url_parts = urlparse(self._http_options['base_url'])
360
371
  return url_parts._replace(scheme='wss').geturl()
361
372
 
362
- async def _async_access_token(self) -> str:
373
+ def _access_token(self) -> str:
363
374
  """Retrieves the access token for the credentials."""
375
+ if not self._credentials:
376
+ self._credentials, project = _load_auth(project=self.project)
377
+ if not self.project:
378
+ self.project = project
379
+
380
+ if self._credentials:
381
+ if (
382
+ self._credentials.expired or not self._credentials.token
383
+ ):
384
+ # Only refresh when it needs to. Default expiration is 3600 seconds.
385
+ _refresh_auth(self._credentials)
386
+ if not self._credentials.token:
387
+ raise RuntimeError('Could not resolve API token from the environment')
388
+ return self._credentials.token
389
+ else:
390
+ raise RuntimeError('Could not resolve API token from the environment')
391
+
392
+ async def _async_access_token(self) -> str:
393
+ """Retrieves the access token for the credentials asynchronously."""
364
394
  if not self._credentials:
365
395
  async with self._auth_lock:
366
396
  # This ensures that only one coroutine can execute the auth logic at a
@@ -373,17 +403,22 @@ class BaseApiClient:
373
403
  if not self.project:
374
404
  self.project = project
375
405
 
376
- if self._credentials.expired or not self._credentials.token:
377
- # Only refresh when it needs to. Default expiration is 3600 seconds.
378
- async with self._auth_lock:
379
- if self._credentials.expired or not self._credentials.token:
380
- # Double check that the credentials expired before refreshing.
381
- await asyncio.to_thread(_refresh_auth, self._credentials)
406
+ if self._credentials:
407
+ if (
408
+ self._credentials.expired or not self._credentials.token
409
+ ):
410
+ # Only refresh when it needs to. Default expiration is 3600 seconds.
411
+ async with self._auth_lock:
412
+ if self._credentials.expired or not self._credentials.token:
413
+ # Double check that the credentials expired before refreshing.
414
+ await asyncio.to_thread(_refresh_auth, self._credentials)
382
415
 
383
- if not self._credentials.token:
384
- raise RuntimeError('Could not resolve API token from the environment')
416
+ if not self._credentials.token:
417
+ raise RuntimeError('Could not resolve API token from the environment')
385
418
 
386
- return self._credentials.token
419
+ return self._credentials.token
420
+ else:
421
+ raise RuntimeError('Could not resolve API token from the environment')
387
422
 
388
423
  def _build_request(
389
424
  self,
@@ -424,12 +459,16 @@ class BaseApiClient:
424
459
  ):
425
460
  path = f'projects/{self.project}/locations/{self.location}/' + path
426
461
  url = _join_url_path(
427
- patched_http_options['base_url'],
428
- patched_http_options['api_version'] + '/' + path,
462
+ patched_http_options.get('base_url', ''),
463
+ patched_http_options.get('api_version', '') + '/' + path,
429
464
  )
430
465
 
431
- timeout_in_seconds: Optional[Union[float, int]] = patched_http_options.get('timeout', None)
466
+ timeout_in_seconds: Optional[Union[float, int]] = patched_http_options.get(
467
+ 'timeout', None
468
+ )
432
469
  if timeout_in_seconds:
470
+ # HttpOptions.timeout is in milliseconds. But httpx.Client.request()
471
+ # expects seconds.
433
472
  timeout_in_seconds = timeout_in_seconds / 1000.0
434
473
  else:
435
474
  timeout_in_seconds = None
@@ -447,74 +486,80 @@ class BaseApiClient:
447
486
  http_request: HttpRequest,
448
487
  stream: bool = False,
449
488
  ) -> HttpResponse:
489
+ data: Optional[Union[str, bytes]] = None
450
490
  if self.vertexai and not self.api_key:
451
- if not self._credentials:
452
- self._credentials, _ = _load_auth(project=self.project)
453
- if self._credentials.quota_project_id:
491
+ http_request.headers['Authorization'] = (
492
+ f'Bearer {self._access_token()}'
493
+ )
494
+ if self._credentials and self._credentials.quota_project_id:
454
495
  http_request.headers['x-goog-user-project'] = (
455
496
  self._credentials.quota_project_id
456
497
  )
457
- authed_session = AuthorizedSession(self._credentials)
458
- authed_session.stream = stream
459
- response = authed_session.request(
460
- http_request.method.upper(),
461
- http_request.url,
498
+ data = json.dumps(http_request.data)
499
+ else:
500
+ if http_request.data:
501
+ if not isinstance(http_request.data, bytes):
502
+ data = json.dumps(http_request.data)
503
+ else:
504
+ data = http_request.data
505
+
506
+ if stream:
507
+ client = httpx.Client()
508
+ httpx_request = client.build_request(
509
+ method=http_request.method,
510
+ url=http_request.url,
511
+ content=data,
462
512
  headers=http_request.headers,
463
- data=json.dumps(http_request.data) if http_request.data else None,
464
513
  timeout=http_request.timeout,
465
514
  )
515
+ response = client.send(httpx_request, stream=stream)
466
516
  errors.APIError.raise_for_response(response)
467
517
  return HttpResponse(
468
518
  response.headers, response if stream else [response.text]
469
519
  )
470
520
  else:
471
- return self._request_unauthorized(http_request, stream)
472
-
473
- def _request_unauthorized(
474
- self,
475
- http_request: HttpRequest,
476
- stream: bool = False,
477
- ) -> HttpResponse:
478
- data: Optional[Union[str, bytes]] = None
479
- if http_request.data:
480
- if not isinstance(http_request.data, bytes):
481
- data = json.dumps(http_request.data)
482
- else:
483
- data = http_request.data
484
-
485
- http_session = requests.Session()
486
- response = http_session.request(
487
- method=http_request.method,
488
- url=http_request.url,
489
- headers=http_request.headers,
490
- data=data,
491
- timeout=http_request.timeout,
492
- stream=stream,
493
- )
494
- errors.APIError.raise_for_response(response)
495
- return HttpResponse(
496
- response.headers, response if stream else [response.text]
497
- )
521
+ with httpx.Client() as client:
522
+ response = client.request(
523
+ method=http_request.method,
524
+ url=http_request.url,
525
+ headers=http_request.headers,
526
+ content=data,
527
+ timeout=http_request.timeout,
528
+ )
529
+ errors.APIError.raise_for_response(response)
530
+ return HttpResponse(
531
+ response.headers, response if stream else [response.text]
532
+ )
498
533
 
499
534
  async def _async_request(
500
535
  self, http_request: HttpRequest, stream: bool = False
501
536
  ):
502
- if self.vertexai:
537
+ data: Optional[Union[str, bytes]] = None
538
+ if self.vertexai and not self.api_key:
503
539
  http_request.headers['Authorization'] = (
504
540
  f'Bearer {await self._async_access_token()}'
505
541
  )
506
- if self._credentials.quota_project_id:
542
+ if self._credentials and self._credentials.quota_project_id:
507
543
  http_request.headers['x-goog-user-project'] = (
508
544
  self._credentials.quota_project_id
509
545
  )
546
+ data = json.dumps(http_request.data)
547
+ else:
548
+ if http_request.data:
549
+ if not isinstance(http_request.data, bytes):
550
+ data = json.dumps(http_request.data)
551
+ else:
552
+ data = http_request.data
553
+
510
554
  if stream:
511
- httpx_request = httpx.Request(
555
+ aclient = httpx.AsyncClient()
556
+ httpx_request = aclient.build_request(
512
557
  method=http_request.method,
513
558
  url=http_request.url,
514
- content=json.dumps(http_request.data),
559
+ content=data,
515
560
  headers=http_request.headers,
561
+ timeout=http_request.timeout,
516
562
  )
517
- aclient = httpx.AsyncClient()
518
563
  response = await aclient.send(
519
564
  httpx_request,
520
565
  stream=stream,
@@ -529,7 +574,7 @@ class BaseApiClient:
529
574
  method=http_request.method,
530
575
  url=http_request.url,
531
576
  headers=http_request.headers,
532
- content=json.dumps(http_request.data) if http_request.data else None,
577
+ content=data,
533
578
  timeout=http_request.timeout,
534
579
  )
535
580
  errors.APIError.raise_for_response(response)
@@ -615,7 +660,7 @@ class BaseApiClient:
615
660
 
616
661
  def upload_file(
617
662
  self, file_path: Union[str, io.IOBase], upload_url: str, upload_size: int
618
- ) -> str:
663
+ ) -> dict[str, str]:
619
664
  """Transfers a file to the given URL.
620
665
 
621
666
  Args:
@@ -637,7 +682,7 @@ class BaseApiClient:
637
682
 
638
683
  def _upload_fd(
639
684
  self, file: io.IOBase, upload_url: str, upload_size: int
640
- ) -> str:
685
+ ) -> dict[str, str]:
641
686
  """Transfers a file to the given URL.
642
687
 
643
688
  Args:
@@ -652,7 +697,7 @@ class BaseApiClient:
652
697
  offset = 0
653
698
  # Upload the file in chunks
654
699
  while True:
655
- file_chunk = file.read(1024 * 1024 * 8) # 8 MB chunk size
700
+ file_chunk = file.read(CHUNK_SIZE)
656
701
  chunk_size = 0
657
702
  if file_chunk:
658
703
  chunk_size = len(file_chunk)
@@ -671,7 +716,7 @@ class BaseApiClient:
671
716
  data=file_chunk,
672
717
  )
673
718
 
674
- response = self._request_unauthorized(request, stream=False)
719
+ response = self._request(request, stream=False)
675
720
  offset += chunk_size
676
721
  if response.headers['X-Goog-Upload-Status'] != 'active':
677
722
  break # upload is complete or it has been interrupted.
@@ -679,13 +724,12 @@ class BaseApiClient:
679
724
  if upload_size <= offset: # Status is not finalized.
680
725
  raise ValueError(
681
726
  'All content has been uploaded, but the upload status is not'
682
- f' finalized. {response.headers}, body: {response.json}'
727
+ f' finalized.'
683
728
  )
684
729
 
685
730
  if response.headers['X-Goog-Upload-Status'] != 'final':
686
731
  raise ValueError(
687
- 'Failed to upload file: Upload status is not finalized. headers:'
688
- f' {response.headers}, body: {response.json}'
732
+ 'Failed to upload file: Upload status is not finalized.'
689
733
  )
690
734
  return response.json
691
735
 
@@ -708,32 +752,31 @@ class BaseApiClient:
708
752
  self,
709
753
  http_request: HttpRequest,
710
754
  ) -> HttpResponse:
711
- data: str | bytes | None = None
755
+ data: Optional[Union[str, bytes]] = None
712
756
  if http_request.data:
713
757
  if not isinstance(http_request.data, bytes):
714
758
  data = json.dumps(http_request.data)
715
759
  else:
716
760
  data = http_request.data
717
761
 
718
- http_session = requests.Session()
719
- response = http_session.request(
720
- method=http_request.method,
721
- url=http_request.url,
722
- headers=http_request.headers,
723
- data=data,
724
- timeout=http_request.timeout,
725
- stream=False,
726
- )
762
+ with httpx.Client(follow_redirects=True) as client:
763
+ response = client.request(
764
+ method=http_request.method,
765
+ url=http_request.url,
766
+ headers=http_request.headers,
767
+ content=data,
768
+ timeout=http_request.timeout,
769
+ )
727
770
 
728
- errors.APIError.raise_for_response(response)
729
- return HttpResponse(response.headers, byte_stream=[response.content])
771
+ errors.APIError.raise_for_response(response)
772
+ return HttpResponse(response.headers, byte_stream=[response.read()])
730
773
 
731
774
  async def async_upload_file(
732
775
  self,
733
776
  file_path: Union[str, io.IOBase],
734
777
  upload_url: str,
735
778
  upload_size: int,
736
- ) -> str:
779
+ ) -> dict[str, str]:
737
780
  """Transfers a file asynchronously to the given URL.
738
781
 
739
782
  Args:
@@ -746,19 +789,20 @@ class BaseApiClient:
746
789
  returns:
747
790
  The response json object from the finalize request.
748
791
  """
749
- return await asyncio.to_thread(
750
- self.upload_file,
751
- file_path,
752
- upload_url,
753
- upload_size,
754
- )
792
+ if isinstance(file_path, io.IOBase):
793
+ return await self._async_upload_fd(file_path, upload_url, upload_size)
794
+ else:
795
+ file = anyio.Path(file_path)
796
+ fd = await file.open('rb')
797
+ async with fd:
798
+ return await self._async_upload_fd(fd, upload_url, upload_size)
755
799
 
756
800
  async def _async_upload_fd(
757
801
  self,
758
- file: io.IOBase,
802
+ file: Union[io.IOBase, anyio.AsyncFile],
759
803
  upload_url: str,
760
804
  upload_size: int,
761
- ) -> str:
805
+ ) -> dict[str, str]:
762
806
  """Transfers a file asynchronously to the given URL.
763
807
 
764
808
  Args:
@@ -770,12 +814,45 @@ class BaseApiClient:
770
814
  returns:
771
815
  The response json object from the finalize request.
772
816
  """
773
- return await asyncio.to_thread(
774
- self._upload_fd,
775
- file,
776
- upload_url,
777
- upload_size,
778
- )
817
+ async with httpx.AsyncClient() as aclient:
818
+ offset = 0
819
+ # Upload the file in chunks
820
+ while True:
821
+ if isinstance(file, io.IOBase):
822
+ file_chunk = file.read(CHUNK_SIZE)
823
+ else:
824
+ file_chunk = await file.read(CHUNK_SIZE)
825
+ chunk_size = 0
826
+ if file_chunk:
827
+ chunk_size = len(file_chunk)
828
+ upload_command = 'upload'
829
+ # If last chunk, finalize the upload.
830
+ if chunk_size + offset >= upload_size:
831
+ upload_command += ', finalize'
832
+ response = await aclient.request(
833
+ method='POST',
834
+ url=upload_url,
835
+ content=file_chunk,
836
+ headers={
837
+ 'X-Goog-Upload-Command': upload_command,
838
+ 'X-Goog-Upload-Offset': str(offset),
839
+ 'Content-Length': str(chunk_size),
840
+ },
841
+ )
842
+ offset += chunk_size
843
+ if response.headers.get('x-goog-upload-status') != 'active':
844
+ break # upload is complete or it has been interrupted.
845
+
846
+ if upload_size <= offset: # Status is not finalized.
847
+ raise ValueError(
848
+ 'All content has been uploaded, but the upload status is not'
849
+ f' finalized.'
850
+ )
851
+ if response.headers.get('x-goog-upload-status') != 'final':
852
+ raise ValueError(
853
+ 'Failed to upload file: Upload status is not finalized.'
854
+ )
855
+ return response.json()
779
856
 
780
857
  async def async_download_file(self, path: str, http_options):
781
858
  """Downloads the file data.
@@ -787,12 +864,31 @@ class BaseApiClient:
787
864
  returns:
788
865
  The file bytes
789
866
  """
790
- return await asyncio.to_thread(
791
- self.download_file,
792
- path,
793
- http_options,
867
+ http_request = self._build_request(
868
+ 'get', path=path, request_dict={}, http_options=http_options
794
869
  )
795
870
 
871
+ data: Optional[Union[str, bytes]] = None
872
+ if http_request.data:
873
+ if not isinstance(http_request.data, bytes):
874
+ data = json.dumps(http_request.data)
875
+ else:
876
+ data = http_request.data
877
+
878
+ async with httpx.AsyncClient(follow_redirects=True) as aclient:
879
+ response = await aclient.request(
880
+ method=http_request.method,
881
+ url=http_request.url,
882
+ headers=http_request.headers,
883
+ content=data,
884
+ timeout=http_request.timeout,
885
+ )
886
+ errors.APIError.raise_for_response(response)
887
+
888
+ return HttpResponse(
889
+ response.headers, byte_stream=[response.read()]
890
+ ).byte_stream[0]
891
+
796
892
  # This method does nothing in the real api client. It is used in the
797
893
  # replay_api_client to verify the response from the SDK method matches the
798
894
  # recorded response.
@@ -17,7 +17,7 @@ import inspect
17
17
  import sys
18
18
  import types as builtin_types
19
19
  import typing
20
- from typing import _GenericAlias, Any, Callable, get_args, get_origin, Literal, Union
20
+ from typing import _GenericAlias, Any, Callable, get_args, get_origin, Literal, Optional, Union # type: ignore[attr-defined]
21
21
 
22
22
  import pydantic
23
23
 
@@ -41,19 +41,11 @@ _py_builtin_type_to_schema_type = {
41
41
 
42
42
 
43
43
  def _is_builtin_primitive_or_compound(
44
- annotation: inspect.Parameter.annotation,
44
+ annotation: inspect.Parameter.annotation, # type: ignore[valid-type]
45
45
  ) -> bool:
46
46
  return annotation in _py_builtin_type_to_schema_type.keys()
47
47
 
48
48
 
49
- def _raise_for_any_of_if_mldev(schema: types.Schema):
50
- if schema.any_of:
51
- raise ValueError(
52
- 'AnyOf is not supported in function declaration schema for'
53
- ' the Gemini API.'
54
- )
55
-
56
-
57
49
  def _raise_for_default_if_mldev(schema: types.Schema):
58
50
  if schema.default is not None:
59
51
  raise ValueError(
@@ -64,12 +56,11 @@ def _raise_for_default_if_mldev(schema: types.Schema):
64
56
 
65
57
  def _raise_if_schema_unsupported(api_option: Literal['VERTEX_AI', 'GEMINI_API'], schema: types.Schema):
66
58
  if api_option == 'GEMINI_API':
67
- _raise_for_any_of_if_mldev(schema)
68
59
  _raise_for_default_if_mldev(schema)
69
60
 
70
61
 
71
62
  def _is_default_value_compatible(
72
- default_value: Any, annotation: inspect.Parameter.annotation
63
+ default_value: Any, annotation: inspect.Parameter.annotation # type: ignore[valid-type]
73
64
  ) -> bool:
74
65
  # None type is expected to be handled external to this function
75
66
  if _is_builtin_primitive_or_compound(annotation):
@@ -292,8 +283,7 @@ def _parse_schema_from_parameter(
292
283
  ),
293
284
  func_name,
294
285
  )
295
- if api_option == 'VERTEX_AI':
296
- schema.required = _get_required_fields(schema)
286
+ schema.required = _get_required_fields(schema)
297
287
  _raise_if_schema_unsupported(api_option, schema)
298
288
  return schema
299
289
  raise ValueError(
@@ -304,9 +294,9 @@ def _parse_schema_from_parameter(
304
294
  )
305
295
 
306
296
 
307
- def _get_required_fields(schema: types.Schema) -> list[str]:
297
+ def _get_required_fields(schema: types.Schema) -> Optional[list[str]]:
308
298
  if not schema.properties:
309
- return
299
+ return None
310
300
  return [
311
301
  field_name
312
302
  for field_name, field_schema in schema.properties.items()
google/genai/_common.py CHANGED
@@ -188,6 +188,8 @@ def _remove_extra_fields(
188
188
  if isinstance(item, dict):
189
189
  _remove_extra_fields(typing.get_args(annotation)[0], item)
190
190
 
191
+ T = typing.TypeVar('T', bound='BaseModel')
192
+
191
193
 
192
194
  class BaseModel(pydantic.BaseModel):
193
195
 
@@ -201,12 +203,13 @@ class BaseModel(pydantic.BaseModel):
201
203
  arbitrary_types_allowed=True,
202
204
  ser_json_bytes='base64',
203
205
  val_json_bytes='base64',
206
+ ignored_types=(typing.TypeVar,)
204
207
  )
205
208
 
206
209
  @classmethod
207
210
  def _from_response(
208
- cls, *, response: dict[str, object], kwargs: dict[str, object]
209
- ) -> 'BaseModel':
211
+ cls: typing.Type[T], *, response: dict[str, object], kwargs: dict[str, object]
212
+ ) -> T:
210
213
  # To maintain forward compatibility, we need to remove extra fields from
211
214
  # the response.
212
215
  # We will provide another mechanism to allow users to access these fields.