google-genai 1.3.0__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -17,9 +17,9 @@
17
17
 
18
18
  import inspect
19
19
  import logging
20
+ import sys
20
21
  import typing
21
22
  from typing import Any, Callable, Dict, Optional, Union, get_args, get_origin
22
- import sys
23
23
 
24
24
  import pydantic
25
25
 
@@ -37,6 +37,15 @@ _DEFAULT_MAX_REMOTE_CALLS_AFC = 10
37
37
  logger = logging.getLogger('google_genai.models')
38
38
 
39
39
 
40
+ def _create_generate_content_config_model(
41
+ config: types.GenerateContentConfigOrDict,
42
+ ) -> types.GenerateContentConfig:
43
+ if isinstance(config, dict):
44
+ return types.GenerateContentConfig(**config)
45
+ else:
46
+ return config
47
+
48
+
40
49
  def format_destination(
41
50
  src: str,
42
51
  config: Optional[types.CreateBatchJobConfigOrDict] = None,
@@ -69,16 +78,12 @@ def format_destination(
69
78
 
70
79
  def get_function_map(
71
80
  config: Optional[types.GenerateContentConfigOrDict] = None,
72
- ) -> dict[str, object]:
81
+ ) -> dict[str, Callable]:
73
82
  """Returns a function map from the config."""
74
- config_model = (
75
- types.GenerateContentConfig(**config)
76
- if config and isinstance(config, dict)
77
- else config
78
- )
79
- function_map = {}
80
- if not config_model:
83
+ function_map: dict[str, Callable] = {}
84
+ if not config:
81
85
  return function_map
86
+ config_model = _create_generate_content_config_model(config)
82
87
  if config_model.tools:
83
88
  for tool in config_model.tools:
84
89
  if callable(tool):
@@ -92,6 +97,16 @@ def get_function_map(
92
97
  return function_map
93
98
 
94
99
 
100
+ def convert_number_values_for_dict_function_call_args(
101
+ args: dict[str, Any],
102
+ ) -> dict[str, Any]:
103
+ """Converts float values in dict with no decimal to integers."""
104
+ return {
105
+ key: convert_number_values_for_function_call_args(value)
106
+ for key, value in args.items()
107
+ }
108
+
109
+
95
110
  def convert_number_values_for_function_call_args(
96
111
  args: Union[dict[str, object], list[object], object],
97
112
  ) -> Union[dict[str, object], list[object], object]:
@@ -210,25 +225,35 @@ def invoke_function_from_dict_args(
210
225
 
211
226
  def get_function_response_parts(
212
227
  response: types.GenerateContentResponse,
213
- function_map: dict[str, object],
228
+ function_map: dict[str, Callable],
214
229
  ) -> list[types.Part]:
215
230
  """Returns the function response parts from the response."""
216
231
  func_response_parts = []
217
- for part in response.candidates[0].content.parts:
218
- if not part.function_call:
219
- continue
220
- func_name = part.function_call.name
221
- func = function_map[func_name]
222
- args = convert_number_values_for_function_call_args(part.function_call.args)
223
- try:
224
- response = {'result': invoke_function_from_dict_args(args, func)}
225
- except Exception as e: # pylint: disable=broad-except
226
- response = {'error': str(e)}
227
- func_response = types.Part.from_function_response(
228
- name=func_name, response=response
229
- )
230
-
231
- func_response_parts.append(func_response)
232
+ if (
233
+ response.candidates is not None
234
+ and isinstance(response.candidates[0].content, types.Content)
235
+ and response.candidates[0].content.parts is not None
236
+ ):
237
+ for part in response.candidates[0].content.parts:
238
+ if not part.function_call:
239
+ continue
240
+ func_name = part.function_call.name
241
+ if func_name is not None and part.function_call.args is not None:
242
+ func = function_map[func_name]
243
+ args = convert_number_values_for_dict_function_call_args(
244
+ part.function_call.args
245
+ )
246
+ func_response: dict[str, Any]
247
+ try:
248
+ func_response = {
249
+ 'result': invoke_function_from_dict_args(args, func)
250
+ }
251
+ except Exception as e: # pylint: disable=broad-except
252
+ func_response = {'error': str(e)}
253
+ func_response_part = types.Part.from_function_response(
254
+ name=func_name, response=func_response
255
+ )
256
+ func_response_parts.append(func_response_part)
232
257
  return func_response_parts
233
258
 
234
259
 
@@ -236,12 +261,9 @@ def should_disable_afc(
236
261
  config: Optional[types.GenerateContentConfigOrDict] = None,
237
262
  ) -> bool:
238
263
  """Returns whether automatic function calling is enabled."""
239
- config_model = (
240
- types.GenerateContentConfig(**config)
241
- if config and isinstance(config, dict)
242
- else config
243
- )
244
-
264
+ if not config:
265
+ return False
266
+ config_model = _create_generate_content_config_model(config)
245
267
  # If max_remote_calls is less or equal to 0, warn and disable AFC.
246
268
  if (
247
269
  config_model
@@ -260,8 +282,7 @@ def should_disable_afc(
260
282
 
261
283
  # Default to enable AFC if not specified.
262
284
  if (
263
- not config_model
264
- or not config_model.automatic_function_calling
285
+ not config_model.automatic_function_calling
265
286
  or config_model.automatic_function_calling.disable is None
266
287
  ):
267
288
  return False
@@ -294,20 +315,17 @@ def should_disable_afc(
294
315
  def get_max_remote_calls_afc(
295
316
  config: Optional[types.GenerateContentConfigOrDict] = None,
296
317
  ) -> int:
318
+ if not config:
319
+ return _DEFAULT_MAX_REMOTE_CALLS_AFC
297
320
  """Returns the remaining remote calls for automatic function calling."""
298
321
  if should_disable_afc(config):
299
322
  raise ValueError(
300
323
  'automatic function calling is not enabled, but SDK is trying to get'
301
324
  ' max remote calls.'
302
325
  )
303
- config_model = (
304
- types.GenerateContentConfig(**config)
305
- if config and isinstance(config, dict)
306
- else config
307
- )
326
+ config_model = _create_generate_content_config_model(config)
308
327
  if (
309
- not config_model
310
- or not config_model.automatic_function_calling
328
+ not config_model.automatic_function_calling
311
329
  or config_model.automatic_function_calling.maximum_remote_calls is None
312
330
  ):
313
331
  return _DEFAULT_MAX_REMOTE_CALLS_AFC
@@ -317,11 +335,9 @@ def get_max_remote_calls_afc(
317
335
  def should_append_afc_history(
318
336
  config: Optional[types.GenerateContentConfigOrDict] = None,
319
337
  ) -> bool:
320
- config_model = (
321
- types.GenerateContentConfig(**config)
322
- if config and isinstance(config, dict)
323
- else config
324
- )
325
- if not config_model or not config_model.automatic_function_calling:
338
+ if not config:
339
+ return True
340
+ config_model = _create_generate_content_config_model(config)
341
+ if not config_model.automatic_function_calling:
326
342
  return True
327
343
  return not config_model.automatic_function_calling.ignore_call_history
@@ -29,7 +29,7 @@ import google.auth
29
29
  from requests.exceptions import HTTPError
30
30
 
31
31
  from . import errors
32
- from ._api_client import ApiClient
32
+ from ._api_client import BaseApiClient
33
33
  from ._api_client import HttpOptions
34
34
  from ._api_client import HttpRequest
35
35
  from ._api_client import HttpResponse
@@ -109,7 +109,8 @@ def _redact_project_location_path(path: str) -> str:
109
109
  return path
110
110
 
111
111
 
112
- def _redact_request_body(body: dict[str, object]) -> dict[str, object]:
112
+ def _redact_request_body(body: dict[str, object]):
113
+ """Redacts fields in the request body in place."""
113
114
  for key, value in body.items():
114
115
  if isinstance(value, str):
115
116
  body[key] = _redact_project_location_path(value)
@@ -179,7 +180,7 @@ class ReplayFile(BaseModel):
179
180
  interactions: list[ReplayInteraction]
180
181
 
181
182
 
182
- class ReplayApiClient(ApiClient):
183
+ class ReplayApiClient(BaseApiClient):
183
184
  """For integration testing, send recorded response or records a response."""
184
185
 
185
186
  def __init__(
@@ -302,13 +303,24 @@ class ReplayApiClient(ApiClient):
302
303
  status_code=http_response.status_code,
303
304
  sdk_response_segments=[],
304
305
  )
305
- else:
306
+ elif isinstance(http_response, errors.APIError):
306
307
  response = ReplayResponse(
307
308
  headers=dict(http_response.response.headers),
308
309
  body_segments=[http_response._to_replay_record()],
309
310
  status_code=http_response.code,
310
311
  sdk_response_segments=[],
311
312
  )
313
+ elif isinstance(http_response, bytes):
314
+ response = ReplayResponse(
315
+ headers={},
316
+ body_segments=[],
317
+ byte_segments=[http_response],
318
+ sdk_response_segments=[],
319
+ )
320
+ else:
321
+ raise ValueError(
322
+ 'Unsupported http_response type: ' + str(type(http_response))
323
+ )
312
324
  self.replay_session.interactions.append(
313
325
  ReplayInteraction(request=request, response=response)
314
326
  )
@@ -457,6 +469,7 @@ class ReplayApiClient(ApiClient):
457
469
  method='POST', url='', data={'file_path': file_path}, headers={}
458
470
  )
459
471
  if self._should_call_api():
472
+ result: Union[str, HttpResponse]
460
473
  try:
461
474
  result = super().upload_file(file_path, upload_url, upload_size)
462
475
  except HTTPError as e:
@@ -470,6 +483,43 @@ class ReplayApiClient(ApiClient):
470
483
  else:
471
484
  return self._build_response_from_replay(request).json
472
485
 
486
+ async def async_upload_file(
487
+ self,
488
+ file_path: Union[str, io.IOBase],
489
+ upload_url: str,
490
+ upload_size: int,
491
+ ) -> str:
492
+ if isinstance(file_path, io.IOBase):
493
+ offset = file_path.tell()
494
+ content = file_path.read()
495
+ file_path.seek(offset, os.SEEK_SET)
496
+ request = HttpRequest(
497
+ method='POST',
498
+ url='',
499
+ data={'bytes': base64.b64encode(content).decode('utf-8')},
500
+ headers={},
501
+ )
502
+ else:
503
+ request = HttpRequest(
504
+ method='POST', url='', data={'file_path': file_path}, headers={}
505
+ )
506
+ if self._should_call_api():
507
+ result: Union[str, HttpResponse]
508
+ try:
509
+ result = await super().async_upload_file(
510
+ file_path, upload_url, upload_size
511
+ )
512
+ except HTTPError as e:
513
+ result = HttpResponse(
514
+ e.response.headers, [json.dumps({'reason': e.response.reason})]
515
+ )
516
+ result.status_code = e.response.status_code
517
+ raise e
518
+ self._record_interaction(request, HttpResponse({}, [json.dumps(result)]))
519
+ return result
520
+ else:
521
+ return self._build_response_from_replay(request).json
522
+
473
523
  def _download_file_request(self, request):
474
524
  self._initialize_replay_session_if_not_loaded()
475
525
  if self._should_call_api():
@@ -485,3 +535,22 @@ class ReplayApiClient(ApiClient):
485
535
  return result
486
536
  else:
487
537
  return self._build_response_from_replay(request)
538
+
539
+ async def async_download_file(self, path: str, http_options):
540
+ self._initialize_replay_session_if_not_loaded()
541
+ request = self._build_request(
542
+ 'get', path=path, request_dict={}, http_options=http_options
543
+ )
544
+ if self._should_call_api():
545
+ try:
546
+ result = await super().async_download_file(path, http_options)
547
+ except HTTPError as e:
548
+ result = HttpResponse(
549
+ e.response.headers, [json.dumps({'reason': e.response.reason})]
550
+ )
551
+ result.status_code = e.response.status_code
552
+ raise e
553
+ self._record_interaction(request, result)
554
+ return result
555
+ else:
556
+ return self._build_response_from_replay(request).byte_stream[0]
@@ -17,13 +17,13 @@ import asyncio
17
17
  import time
18
18
  from unittest.mock import MagicMock, patch
19
19
  import pytest
20
- from .api_client import ApiClient
20
+ from .api_client import BaseApiClient
21
21
 
22
22
 
23
- @patch('genai.api_client.ApiClient._build_request')
24
- @patch('genai.api_client.ApiClient._request')
23
+ @patch('genai.api_client.BaseApiClient._build_request')
24
+ @patch('genai.api_client.BaseApiClient._request')
25
25
  def test_request_streamed_non_blocking(mock_request, mock_build_request):
26
- api_client = ApiClient(api_key='test_api_key')
26
+ api_client = BaseApiClient(api_key='test_api_key')
27
27
  http_method = 'GET'
28
28
  path = 'test/path'
29
29
  request_dict = {'key': 'value'}
@@ -56,8 +56,8 @@ def test_request_streamed_non_blocking(mock_request, mock_build_request):
56
56
  assert end_time - start_time > 0.3
57
57
 
58
58
 
59
- @patch('genai.api_client.ApiClient._build_request')
60
- @patch('genai.api_client.ApiClient._async_request')
59
+ @patch('genai.api_client.BaseApiClient._build_request')
60
+ @patch('genai.api_client.BaseApiClient._async_request')
61
61
  @pytest.mark.asyncio
62
62
  async def test_async_request(mock_async_request, mock_build_request):
63
63
  api_client = ApiClient(api_key='test_api_key')
@@ -99,8 +99,8 @@ async def test_async_request(mock_async_request, mock_build_request):
99
99
  assert 0.1 <= end_time - start_time < 0.15
100
100
 
101
101
 
102
- @patch('genai.api_client.ApiClient._build_request')
103
- @patch('genai.api_client.ApiClient._async_request')
102
+ @patch('genai.api_client.BaseApiClient._build_request')
103
+ @patch('genai.api_client.BaseApiClient._async_request')
104
104
  @pytest.mark.asyncio
105
105
  async def test_async_request_streamed_non_blocking(
106
106
  mock_async_request, mock_build_request