google-genai 1.2.0__py3-none-any.whl → 1.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -14,7 +14,10 @@
14
14
  #
15
15
 
16
16
 
17
- """Base client for calling HTTP APIs sending and receiving JSON."""
17
+ """Base client for calling HTTP APIs sending and receiving JSON.
18
+
19
+ The BaseApiClient is intended to be a private module and is subject to change.
20
+ """
18
21
 
19
22
  import asyncio
20
23
  import copy
@@ -25,19 +28,23 @@ import json
25
28
  import logging
26
29
  import os
27
30
  import sys
28
- from typing import Any, Optional, Tuple, TypedDict, Union
31
+ from typing import Any, AsyncIterator, Optional, Tuple, TypedDict, Union
29
32
  from urllib.parse import urlparse, urlunparse
30
-
31
33
  import google.auth
32
34
  import google.auth.credentials
35
+ from google.auth.credentials import Credentials
33
36
  from google.auth.transport.requests import AuthorizedSession
37
+ from google.auth.transport.requests import Request
38
+ import httpx
34
39
  from pydantic import BaseModel, ConfigDict, Field, ValidationError
35
40
  import requests
36
-
41
+ from . import _common
37
42
  from . import errors
38
43
  from . import version
39
44
  from .types import HttpOptions, HttpOptionsDict, HttpOptionsOrDict
40
45
 
46
+ logger = logging.getLogger('google_genai._api_client')
47
+
41
48
 
42
49
  def _append_library_version_headers(headers: dict[str, str]) -> None:
43
50
  """Appends the telemetry header to the headers dict."""
@@ -61,7 +68,7 @@ def _append_library_version_headers(headers: dict[str, str]) -> None:
61
68
 
62
69
 
63
70
  def _patch_http_options(
64
- options: HttpOptionsDict, patch_options: HttpOptionsDict
71
+ options: HttpOptionsDict, patch_options: dict[str, Any]
65
72
  ) -> HttpOptionsDict:
66
73
  # use shallow copy so we don't override the original objects.
67
74
  copy_option = HttpOptionsDict()
@@ -94,6 +101,27 @@ def _join_url_path(base_url: str, path: str) -> str:
94
101
  return urlunparse(parsed_base._replace(path=base_path + '/' + path))
95
102
 
96
103
 
104
+ def _load_auth(*, project: Union[str, None]) -> tuple[Credentials, str]:
105
+ """Loads google auth credentials and project id."""
106
+ credentials, loaded_project_id = google.auth.default(
107
+ scopes=['https://www.googleapis.com/auth/cloud-platform'],
108
+ )
109
+
110
+ if not project:
111
+ project = loaded_project_id
112
+
113
+ if not project:
114
+ raise ValueError(
115
+ 'Could not resolve project using application default credentials.'
116
+ )
117
+
118
+ return credentials, project
119
+
120
+
121
+ def _refresh_auth(credentials: Credentials) -> None:
122
+ credentials.refresh(Request())
123
+
124
+
97
125
  @dataclass
98
126
  class HttpRequest:
99
127
  headers: dict[str, str]
@@ -105,15 +133,14 @@ class HttpRequest:
105
133
 
106
134
  # TODO(b/394358912): Update this class to use a SDKResponse class that can be
107
135
  # generated and used for all languages.
108
- @dataclass
109
- class BaseResponse:
110
- http_headers: dict[str, str]
136
+ class BaseResponse(_common.BaseModel):
137
+ http_headers: Optional[dict[str, str]] = Field(
138
+ default=None, description='The http headers of the response.'
139
+ )
111
140
 
112
- @property
113
- def dict(self) -> dict[str, Any]:
114
- if isinstance(self, dict):
115
- return self
116
- return {'httpHeaders': self.http_headers}
141
+ json_payload: Optional[Any] = Field(
142
+ default=None, description='The json payload of the response.'
143
+ )
117
144
 
118
145
 
119
146
  class HttpResponse:
@@ -128,15 +155,15 @@ class HttpResponse:
128
155
  self.headers = headers
129
156
  self.response_stream = response_stream
130
157
  self.byte_stream = byte_stream
131
- self.segment_iterator = self.segments()
132
158
 
133
159
  # Async iterator for async streaming.
134
160
  def __aiter__(self):
161
+ self.segment_iterator = self.async_segments()
135
162
  return self
136
163
 
137
164
  async def __anext__(self):
138
165
  try:
139
- return next(self.segment_iterator)
166
+ return await self.segment_iterator.__anext__()
140
167
  except StopIteration:
141
168
  raise StopAsyncIteration
142
169
 
@@ -163,6 +190,25 @@ class HttpResponse:
163
190
  chunk = chunk[len(b'data: ') :]
164
191
  yield json.loads(str(chunk, 'utf-8'))
165
192
 
193
+ async def async_segments(self) -> AsyncIterator[Any]:
194
+ if isinstance(self.response_stream, list):
195
+ # list of objects retrieved from replay or from non-streaming API.
196
+ for chunk in self.response_stream:
197
+ yield json.loads(chunk) if chunk else {}
198
+ elif self.response_stream is None:
199
+ async for c in []:
200
+ yield c
201
+ else:
202
+ # Iterator of objects retrieved from the API.
203
+ async for chunk in self.response_stream.aiter_lines():
204
+ # This is httpx.Response.
205
+ if chunk:
206
+ # In async streaming mode, the chunk of JSON is prefixed with "data:"
207
+ # which we must strip before parsing.
208
+ if chunk.startswith('data: '):
209
+ chunk = chunk[len('data: ') :]
210
+ yield json.loads(chunk)
211
+
166
212
  def byte_segments(self):
167
213
  if isinstance(self.byte_stream, list):
168
214
  # list of objects retrieved from replay or from non-streaming API.
@@ -181,17 +227,17 @@ class HttpResponse:
181
227
  response_payload[attribute] = copy.deepcopy(getattr(self, attribute))
182
228
 
183
229
 
184
- class ApiClient:
230
+ class BaseApiClient:
185
231
  """Client for calling HTTP APIs sending and receiving JSON."""
186
232
 
187
233
  def __init__(
188
234
  self,
189
- vertexai: Union[bool, None] = None,
190
- api_key: Union[str, None] = None,
191
- credentials: google.auth.credentials.Credentials = None,
192
- project: Union[str, None] = None,
193
- location: Union[str, None] = None,
194
- http_options: HttpOptionsOrDict = None,
235
+ vertexai: Optional[bool] = None,
236
+ api_key: Optional[str] = None,
237
+ credentials: Optional[google.auth.credentials.Credentials] = None,
238
+ project: Optional[str] = None,
239
+ location: Optional[str] = None,
240
+ http_options: Optional[HttpOptionsOrDict] = None,
195
241
  ):
196
242
  self.vertexai = vertexai
197
243
  if self.vertexai is None:
@@ -215,14 +261,15 @@ class ApiClient:
215
261
  ' initializer.'
216
262
  )
217
263
 
218
- # Validate http_options if a dict is provided.
264
+ # Validate http_options if it is provided.
265
+ validated_http_options: dict[str, Any]
219
266
  if isinstance(http_options, dict):
220
267
  try:
221
- HttpOptions.model_validate(http_options)
268
+ validated_http_options = HttpOptions.model_validate(http_options).model_dump()
222
269
  except ValidationError as e:
223
270
  raise ValueError(f'Invalid http_options: {e}')
224
271
  elif isinstance(http_options, HttpOptions):
225
- http_options = http_options.model_dump()
272
+ validated_http_options = http_options.model_dump()
226
273
 
227
274
  # Retrieve implicitly set values from the environment.
228
275
  env_project = os.environ.get('GOOGLE_CLOUD_PROJECT', None)
@@ -234,20 +281,24 @@ class ApiClient:
234
281
 
235
282
  self._credentials = credentials
236
283
  self._http_options = HttpOptionsDict()
284
+ # Initialize the lock. This lock will be used to protect access to the
285
+ # credentials. This is crucial for thread safety when multiple coroutines
286
+ # might be accessing the credentials at the same time.
287
+ self._auth_lock = asyncio.Lock()
237
288
 
238
289
  # Handle when to use Vertex AI in express mode (api key).
239
290
  # Explicit initializer arguments are already validated above.
240
291
  if self.vertexai:
241
292
  if credentials:
242
293
  # Explicit credentials take precedence over implicit api_key.
243
- logging.info(
294
+ logger.info(
244
295
  'The user provided Google Cloud credentials will take precedence'
245
296
  + ' over the API key from the environment variable.'
246
297
  )
247
298
  self.api_key = None
248
299
  elif (env_location or env_project) and api_key:
249
300
  # Explicit api_key takes precedence over implicit project/location.
250
- logging.info(
301
+ logger.info(
251
302
  'The user provided Vertex AI API key will take precedence over the'
252
303
  + ' project/location from the environment variables.'
253
304
  )
@@ -255,20 +306,22 @@ class ApiClient:
255
306
  self.location = None
256
307
  elif (project or location) and env_api_key:
257
308
  # Explicit project/location takes precedence over implicit api_key.
258
- logging.info(
309
+ logger.info(
259
310
  'The user provided project/location will take precedence over the'
260
311
  + ' Vertex AI API key from the environment variable.'
261
312
  )
262
313
  self.api_key = None
263
314
  elif (env_location or env_project) and env_api_key:
264
315
  # Implicit project/location takes precedence over implicit api_key.
265
- logging.info(
316
+ logger.info(
266
317
  'The project/location from the environment variables will take'
267
318
  + ' precedence over the API key from the environment variables.'
268
319
  )
269
320
  self.api_key = None
270
321
  if not self.project and not self.api_key:
271
- self.project = google.auth.default()[1]
322
+ credentials, self.project = _load_auth(project=None)
323
+ if not self._credentials:
324
+ self._credentials = credentials
272
325
  if not ((self.project and self.location) or self.api_key):
273
326
  raise ValueError(
274
327
  'Project and location or API key must be set when using the Vertex '
@@ -298,7 +351,7 @@ class ApiClient:
298
351
  self._http_options['headers']['x-goog-api-key'] = self.api_key
299
352
  # Update the http options with the user provided http options.
300
353
  if http_options:
301
- self._http_options = _patch_http_options(self._http_options, http_options)
354
+ self._http_options = _patch_http_options(self._http_options, validated_http_options)
302
355
  else:
303
356
  _append_library_version_headers(self._http_options['headers'])
304
357
 
@@ -306,12 +359,38 @@ class ApiClient:
306
359
  url_parts = urlparse(self._http_options['base_url'])
307
360
  return url_parts._replace(scheme='wss').geturl()
308
361
 
362
+ async def _async_access_token(self) -> str:
363
+ """Retrieves the access token for the credentials."""
364
+ if not self._credentials:
365
+ async with self._auth_lock:
366
+ # This ensures that only one coroutine can execute the auth logic at a
367
+ # time for thread safety.
368
+ if not self._credentials:
369
+ # Double check that the credentials are not set before loading them.
370
+ self._credentials, project = await asyncio.to_thread(
371
+ _load_auth, project=self.project
372
+ )
373
+ if not self.project:
374
+ self.project = project
375
+
376
+ if self._credentials.expired or not self._credentials.token:
377
+ # Only refresh when it needs to. Default expiration is 3600 seconds.
378
+ async with self._auth_lock:
379
+ if self._credentials.expired or not self._credentials.token:
380
+ # Double check that the credentials expired before refreshing.
381
+ await asyncio.to_thread(_refresh_auth, self._credentials)
382
+
383
+ if not self._credentials.token:
384
+ raise RuntimeError('Could not resolve API token from the environment')
385
+
386
+ return self._credentials.token
387
+
309
388
  def _build_request(
310
389
  self,
311
390
  http_method: str,
312
391
  path: str,
313
392
  request_dict: dict[str, object],
314
- http_options: HttpOptionsOrDict = None,
393
+ http_options: Optional[HttpOptionsOrDict] = None,
315
394
  ) -> HttpRequest:
316
395
  # Remove all special dict keys such as _url and _query.
317
396
  keys_to_delete = [key for key in request_dict.keys() if key.startswith('_')]
@@ -349,7 +428,7 @@ class ApiClient:
349
428
  patched_http_options['api_version'] + '/' + path,
350
429
  )
351
430
 
352
- timeout_in_seconds = patched_http_options.get('timeout', None)
431
+ timeout_in_seconds: Optional[Union[float, int]] = patched_http_options.get('timeout', None)
353
432
  if timeout_in_seconds:
354
433
  timeout_in_seconds = timeout_in_seconds / 1000.0
355
434
  else:
@@ -370,8 +449,10 @@ class ApiClient:
370
449
  ) -> HttpResponse:
371
450
  if self.vertexai and not self.api_key:
372
451
  if not self._credentials:
373
- self._credentials, _ = google.auth.default(
374
- scopes=['https://www.googleapis.com/auth/cloud-platform'],
452
+ self._credentials, _ = _load_auth(project=self.project)
453
+ if self._credentials.quota_project_id:
454
+ http_request.headers['x-goog-user-project'] = (
455
+ self._credentials.quota_project_id
375
456
  )
376
457
  authed_session = AuthorizedSession(self._credentials)
377
458
  authed_session.stream = stream
@@ -394,7 +475,7 @@ class ApiClient:
394
475
  http_request: HttpRequest,
395
476
  stream: bool = False,
396
477
  ) -> HttpResponse:
397
- data = None
478
+ data: Optional[Union[str, bytes]] = None
398
479
  if http_request.data:
399
480
  if not isinstance(http_request.data, bytes):
400
481
  data = json.dumps(http_request.data)
@@ -419,21 +500,42 @@ class ApiClient:
419
500
  self, http_request: HttpRequest, stream: bool = False
420
501
  ):
421
502
  if self.vertexai:
422
- if not self._credentials:
423
- self._credentials, _ = google.auth.default(
424
- scopes=['https://www.googleapis.com/auth/cloud-platform'],
503
+ http_request.headers['Authorization'] = (
504
+ f'Bearer {await self._async_access_token()}'
505
+ )
506
+ if self._credentials.quota_project_id:
507
+ http_request.headers['x-goog-user-project'] = (
508
+ self._credentials.quota_project_id
425
509
  )
426
- return await asyncio.to_thread(
427
- self._request,
428
- http_request,
429
- stream=stream,
510
+ if stream:
511
+ httpx_request = httpx.Request(
512
+ method=http_request.method,
513
+ url=http_request.url,
514
+ content=json.dumps(http_request.data),
515
+ headers=http_request.headers,
430
516
  )
431
- else:
432
- return await asyncio.to_thread(
433
- self._request,
434
- http_request,
517
+ aclient = httpx.AsyncClient()
518
+ response = await aclient.send(
519
+ httpx_request,
435
520
  stream=stream,
436
521
  )
522
+ errors.APIError.raise_for_response(response)
523
+ return HttpResponse(
524
+ response.headers, response if stream else [response.text]
525
+ )
526
+ else:
527
+ async with httpx.AsyncClient() as aclient:
528
+ response = await aclient.request(
529
+ method=http_request.method,
530
+ url=http_request.url,
531
+ headers=http_request.headers,
532
+ content=json.dumps(http_request.data) if http_request.data else None,
533
+ timeout=http_request.timeout,
534
+ )
535
+ errors.APIError.raise_for_response(response)
536
+ return HttpResponse(
537
+ response.headers, response if stream else [response.text]
538
+ )
437
539
 
438
540
  def get_read_only_http_options(self) -> HttpOptionsDict:
439
541
  copied = HttpOptionsDict()
@@ -447,7 +549,7 @@ class ApiClient:
447
549
  http_method: str,
448
550
  path: str,
449
551
  request_dict: dict[str, object],
450
- http_options: HttpOptionsOrDict = None,
552
+ http_options: Optional[HttpOptionsOrDict] = None,
451
553
  ):
452
554
  http_request = self._build_request(
453
555
  http_method, path, request_dict, http_options
@@ -455,9 +557,9 @@ class ApiClient:
455
557
  response = self._request(http_request, stream=False)
456
558
  json_response = response.json
457
559
  if not json_response:
458
- base_response = BaseResponse(response.headers).dict
459
- return base_response
460
-
560
+ return BaseResponse(http_headers=response.headers).model_dump(
561
+ by_alias=True
562
+ )
461
563
  return json_response
462
564
 
463
565
  def request_streamed(
@@ -465,7 +567,7 @@ class ApiClient:
465
567
  http_method: str,
466
568
  path: str,
467
569
  request_dict: dict[str, object],
468
- http_options: HttpOptionsDict = None,
570
+ http_options: Optional[HttpOptionsDict] = None,
469
571
  ):
470
572
  http_request = self._build_request(
471
573
  http_method, path, request_dict, http_options
@@ -480,7 +582,7 @@ class ApiClient:
480
582
  http_method: str,
481
583
  path: str,
482
584
  request_dict: dict[str, object],
483
- http_options: HttpOptionsDict = None,
585
+ http_options: Optional[HttpOptionsOrDict] = None,
484
586
  ) -> dict[str, object]:
485
587
  http_request = self._build_request(
486
588
  http_method, path, request_dict, http_options
@@ -489,8 +591,7 @@ class ApiClient:
489
591
  result = await self._async_request(http_request=http_request, stream=False)
490
592
  json_response = result.json
491
593
  if not json_response:
492
- base_response = BaseResponse(result.headers).dict
493
- return base_response
594
+ return BaseResponse(http_headers=result.headers).model_dump(by_alias=True)
494
595
  return json_response
495
596
 
496
597
  async def async_request_streamed(
@@ -498,7 +599,7 @@ class ApiClient:
498
599
  http_method: str,
499
600
  path: str,
500
601
  request_dict: dict[str, object],
501
- http_options: HttpOptionsDict = None,
602
+ http_options: Optional[HttpOptionsDict] = None,
502
603
  ):
503
604
  http_request = self._build_request(
504
605
  http_method, path, request_dict, http_options
@@ -607,10 +708,10 @@ class ApiClient:
607
708
  self,
608
709
  http_request: HttpRequest,
609
710
  ) -> HttpResponse:
610
- data = None
711
+ data: str | bytes | None = None
611
712
  if http_request.data:
612
713
  if not isinstance(http_request.data, bytes):
613
- data = json.dumps(http_request.data, cls=RequestJsonEncoder)
714
+ data = json.dumps(http_request.data)
614
715
  else:
615
716
  data = http_request.data
616
717
 
@@ -695,5 +796,5 @@ class ApiClient:
695
796
  # This method does nothing in the real api client. It is used in the
696
797
  # replay_api_client to verify the response from the SDK method matches the
697
798
  # recorded response.
698
- def _verify_response(self, response_model: BaseModel):
799
+ def _verify_response(self, response_model: _common.BaseModel):
699
800
  pass
@@ -15,10 +15,15 @@
15
15
 
16
16
  """Utilities for the API Modules of the Google Gen AI SDK."""
17
17
 
18
+ from typing import Optional
18
19
  from . import _api_client
19
20
 
20
21
 
21
22
  class BaseModule:
22
23
 
23
- def __init__(self, api_client_: _api_client.ApiClient):
24
+ def __init__(self, api_client_: _api_client.BaseApiClient):
24
25
  self._api_client = api_client_
26
+
27
+ @property
28
+ def vertexai(self) -> Optional[bool]:
29
+ return self._api_client.vertexai
@@ -31,12 +31,12 @@ else:
31
31
  VersionedUnionType = typing._UnionGenericAlias
32
32
 
33
33
  _py_builtin_type_to_schema_type = {
34
- str: 'STRING',
35
- int: 'INTEGER',
36
- float: 'NUMBER',
37
- bool: 'BOOLEAN',
38
- list: 'ARRAY',
39
- dict: 'OBJECT',
34
+ str: types.Type.STRING,
35
+ int: types.Type.INTEGER,
36
+ float: types.Type.NUMBER,
37
+ bool: types.Type.BOOLEAN,
38
+ list: types.Type.ARRAY,
39
+ dict: types.Type.OBJECT,
40
40
  }
41
41
 
42
42
 
@@ -145,7 +145,7 @@ def _parse_schema_from_parameter(
145
145
  for arg in get_args(param.annotation)
146
146
  )
147
147
  ):
148
- schema.type = 'OBJECT'
148
+ schema.type = _py_builtin_type_to_schema_type[dict]
149
149
  schema.any_of = []
150
150
  unique_types = set()
151
151
  for arg in get_args(param.annotation):
@@ -183,7 +183,7 @@ def _parse_schema_from_parameter(
183
183
  origin = get_origin(param.annotation)
184
184
  args = get_args(param.annotation)
185
185
  if origin is dict:
186
- schema.type = 'OBJECT'
186
+ schema.type = _py_builtin_type_to_schema_type[dict]
187
187
  if param.default is not inspect.Parameter.empty:
188
188
  if not _is_default_value_compatible(param.default, param.annotation):
189
189
  raise ValueError(default_value_error_msg)
@@ -195,7 +195,7 @@ def _parse_schema_from_parameter(
195
195
  raise ValueError(
196
196
  f'Literal type {param.annotation} must be a list of strings.'
197
197
  )
198
- schema.type = 'STRING'
198
+ schema.type = _py_builtin_type_to_schema_type[str]
199
199
  schema.enum = list(args)
200
200
  if param.default is not inspect.Parameter.empty:
201
201
  if not _is_default_value_compatible(param.default, param.annotation):
@@ -204,7 +204,7 @@ def _parse_schema_from_parameter(
204
204
  _raise_if_schema_unsupported(api_option, schema)
205
205
  return schema
206
206
  if origin is list:
207
- schema.type = 'ARRAY'
207
+ schema.type = _py_builtin_type_to_schema_type[list]
208
208
  schema.items = _parse_schema_from_parameter(
209
209
  api_option,
210
210
  inspect.Parameter(
@@ -222,7 +222,7 @@ def _parse_schema_from_parameter(
222
222
  return schema
223
223
  if origin is Union:
224
224
  schema.any_of = []
225
- schema.type = 'OBJECT'
225
+ schema.type = _py_builtin_type_to_schema_type[dict]
226
226
  unique_types = set()
227
227
  for arg in args:
228
228
  # The first check is for NoneType in Python 3.9, since the __name__
@@ -280,7 +280,7 @@ def _parse_schema_from_parameter(
280
280
  and param.default is not None
281
281
  ):
282
282
  schema.default = param.default
283
- schema.type = 'OBJECT'
283
+ schema.type = _py_builtin_type_to_schema_type[dict]
284
284
  schema.properties = {}
285
285
  for field_name, field_info in param.annotation.model_fields.items():
286
286
  schema.properties[field_name] = _parse_schema_from_parameter(
google/genai/_common.py CHANGED
@@ -60,6 +60,12 @@ def set_value_by_path(data, keys, value):
60
60
  for d in data[key_name]:
61
61
  set_value_by_path(d, keys[i + 1 :], value)
62
62
  return
63
+ elif key.endswith('[0]'):
64
+ key_name = key[:-3]
65
+ if key_name not in data:
66
+ data[key_name] = [{}]
67
+ set_value_by_path(data[key_name][0], keys[i + 1 :], value)
68
+ return
63
69
 
64
70
  data = data.setdefault(key, {})
65
71
 
@@ -106,6 +112,12 @@ def get_value_by_path(data: object, keys: list[str]):
106
112
  return [get_value_by_path(d, keys[i + 1 :]) for d in data[key_name]]
107
113
  else:
108
114
  return None
115
+ elif key.endswith('[0]'):
116
+ key_name = key[:-3]
117
+ if key_name in data and data[key_name]:
118
+ return get_value_by_path(data[key_name][0], keys[i + 1 :])
119
+ else:
120
+ return None
109
121
  else:
110
122
  if key in data:
111
123
  data = data[key]
@@ -193,7 +205,7 @@ class BaseModel(pydantic.BaseModel):
193
205
 
194
206
  @classmethod
195
207
  def _from_response(
196
- cls, response: dict[str, object], kwargs: dict[str, object]
208
+ cls, *, response: dict[str, object], kwargs: dict[str, object]
197
209
  ) -> 'BaseModel':
198
210
  # To maintain forward compatibility, we need to remove extra fields from
199
211
  # the response.
@@ -254,7 +266,7 @@ def encode_unserializable_types(data: dict[str, object]) -> dict[str, object]:
254
266
  A dictionary with json.dumps() incompatible type (e.g. bytes datetime)
255
267
  to compatible type (e.g. base64 encoded string, isoformat date string).
256
268
  """
257
- processed_data = {}
269
+ processed_data: dict[str, object] = {}
258
270
  if not isinstance(data, dict):
259
271
  return data
260
272
  for key, value in data.items():
@@ -34,6 +34,8 @@ else:
34
34
 
35
35
  _DEFAULT_MAX_REMOTE_CALLS_AFC = 10
36
36
 
37
+ logger = logging.getLogger('google_genai.models')
38
+
37
39
 
38
40
  def format_destination(
39
41
  src: str,
@@ -74,7 +76,7 @@ def get_function_map(
74
76
  if config and isinstance(config, dict)
75
77
  else config
76
78
  )
77
- function_map = {}
79
+ function_map: dict[str, object] = {}
78
80
  if not config_model:
79
81
  return function_map
80
82
  if config_model.tools:
@@ -218,15 +220,16 @@ def get_function_response_parts(
218
220
  func_name = part.function_call.name
219
221
  func = function_map[func_name]
220
222
  args = convert_number_values_for_function_call_args(part.function_call.args)
223
+ func_response: dict[str, Any]
221
224
  try:
222
- response = {'result': invoke_function_from_dict_args(args, func)}
225
+ func_response = {'result': invoke_function_from_dict_args(args, func)}
223
226
  except Exception as e: # pylint: disable=broad-except
224
- response = {'error': str(e)}
225
- func_response = types.Part.from_function_response(
226
- name=func_name, response=response
227
+ func_response = {'error': str(e)}
228
+ func_response_part = types.Part.from_function_response(
229
+ name=func_name, response=func_response
227
230
  )
228
231
 
229
- func_response_parts.append(func_response)
232
+ func_response_parts.append(func_response_part)
230
233
  return func_response_parts
231
234
 
232
235
 
@@ -248,7 +251,7 @@ def should_disable_afc(
248
251
  is not None
249
252
  and int(config_model.automatic_function_calling.maximum_remote_calls) <= 0
250
253
  ):
251
- logging.warning(
254
+ logger.warning(
252
255
  'max_remote_calls in automatic_function_calling_config'
253
256
  f' {config_model.automatic_function_calling.maximum_remote_calls} is'
254
257
  ' less than or equal to 0. Disabling automatic function calling.'
@@ -268,9 +271,12 @@ def should_disable_afc(
268
271
  config_model.automatic_function_calling.disable
269
272
  and config_model.automatic_function_calling.maximum_remote_calls
270
273
  is not None
274
+ # exclude the case where max_remote_calls is set to 10 by default.
275
+ and 'maximum_remote_calls'
276
+ in config_model.automatic_function_calling.model_fields_set
271
277
  and int(config_model.automatic_function_calling.maximum_remote_calls) > 0
272
278
  ):
273
- logging.warning(
279
+ logger.warning(
274
280
  '`automatic_function_calling.disable` is set to `True`. And'
275
281
  ' `automatic_function_calling.maximum_remote_calls` is a'
276
282
  ' positive number'