google-genai 0.4.0__py3-none-any.whl → 0.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -17,15 +17,15 @@
 
 import base64
 import copy
+import datetime
 import inspect
+import io
 import json
 import os
 import re
-import datetime
 from typing import Any, Literal, Optional, Union
 
 import google.auth
-from pydantic import BaseModel
 from requests.exceptions import HTTPError
 
 from . import errors
@@ -33,7 +33,8 @@ from ._api_client import ApiClient
 from ._api_client import HttpOptions
 from ._api_client import HttpRequest
 from ._api_client import HttpResponse
-from ._api_client import RequestJsonEncoder
+from ._common import BaseModel
+
 
 def _redact_version_numbers(version_string: str) -> str:
   """Redacts version numbers in the form x.y.z from a string."""
@@ -72,6 +73,11 @@ def _redact_request_url(url: str) -> str:
       '{VERTEX_URL_PREFIX}/',
       url,
   )
+  result = re.sub(
+      r'.*-aiplatform.googleapis.com/[^/]+/',
+      '{VERTEX_URL_PREFIX}/',
+      result,
+  )
   result = re.sub(
       r'https://generativelanguage.googleapis.com/[^/]+',
       '{MLDEV_URL_PREFIX}',
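
Note: the added rule extends request-URL redaction to regional Vertex AI hosts of the form `{region}-aiplatform.googleapis.com`, collapsing the host and API version into the `{VERTEX_URL_PREFIX}` placeholder. A quick illustrative check (the sample URL is hypothetical):

    import re

    url = ('https://us-central1-aiplatform.googleapis.com/v1/'
           'projects/p/locations/us-central1/models/some-model:generateContent')
    print(re.sub(r'.*-aiplatform.googleapis.com/[^/]+/', '{VERTEX_URL_PREFIX}/', url))
    # {VERTEX_URL_PREFIX}/projects/p/locations/us-central1/models/some-model:generateContent
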
@@ -140,6 +146,7 @@ class ReplayResponse(BaseModel):
   status_code: int = 200
   headers: dict[str, str]
   body_segments: list[dict[str, object]]
+  byte_segments: Optional[list[bytes]] = None
   sdk_response_segments: list[dict[str, object]]
 
   def model_post_init(self, __context: Any) -> None:
@@ -259,26 +266,13 @@ class ReplayApiClient(ApiClient):
     replay_file_path = self._get_replay_file_path()
     os.makedirs(os.path.dirname(replay_file_path), exist_ok=True)
     with open(replay_file_path, 'w') as f:
-      replay_session_dict = self.replay_session.model_dump()
-      # Use for non-utf-8 bytes in image/video... output.
-      for interaction in replay_session_dict['interactions']:
-        segments = []
-        for response in interaction['response']['sdk_response_segments']:
-          segments.append(json.loads(json.dumps(
-              response, cls=ResponseJsonEncoder
-          )))
-        interaction['response']['sdk_response_segments'] = segments
-      f.write(
-          json.dumps(
-              replay_session_dict, indent=2, cls=RequestJsonEncoder
-          )
-      )
+      f.write(self.replay_session.model_dump_json(exclude_unset=True, indent=2))
     self.replay_session = None
 
   def _record_interaction(
       self,
       http_request: HttpRequest,
-      http_response: Union[HttpResponse, errors.APIError],
+      http_response: Union[HttpResponse, errors.APIError, bytes],
   ):
     if not self._should_update_replay():
       return
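
Note: the hand-rolled RequestJsonEncoder/ResponseJsonEncoder pair is replaced by pydantic's own serializer. A minimal sketch of the new call, assuming a plain pydantic v2 model (the SDK's `_common.BaseModel` may configure serialization further; `Rec` is a hypothetical stand-in):

    import pydantic

    class Rec(pydantic.BaseModel):
      status_code: int = 200
      note: str | None = None

    # exclude_unset=True drops fields that were never assigned, keeping
    # replay files limited to what was actually recorded.
    print(Rec(status_code=404).model_dump_json(exclude_unset=True, indent=2))
    # {
    #   "status_code": 404
    # }
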
@@ -293,6 +287,9 @@ class ReplayApiClient(ApiClient):
     response = ReplayResponse(
         headers=dict(http_response.headers),
         body_segments=list(http_response.segments()),
+        byte_segments=[
+            seg[:100] + b'...' for seg in http_response.byte_segments()
+        ],
         status_code=http_response.status_code,
         sdk_response_segments=[],
     )
@@ -326,11 +323,7 @@ class ReplayApiClient(ApiClient):
     # so that the comparison is fair.
     _redact_request_body(request_data_copy)
 
-    # Need to call dumps() and loads() to convert dict bytes values to strings.
-    # Because the expected_request_body dict never contains bytes values.
-    actual_request_body = [
-        json.loads(json.dumps(request_data_copy, cls=RequestJsonEncoder))
-    ]
+    actual_request_body = [request_data_copy]
     expected_request_body = interaction.request.body_segments
     assert actual_request_body == expected_request_body, (
         'Request body mismatch:\n'
@@ -353,6 +346,7 @@ class ReplayApiClient(ApiClient):
             json.dumps(segment)
             for segment in interaction.response.body_segments
         ],
+        byte_stream=interaction.response.byte_segments,
     )
 
   def _verify_response(self, response_model: BaseModel):
@@ -371,15 +365,10 @@ class ReplayApiClient(ApiClient):
     if isinstance(response_model, list):
       response_model = response_model[0]
     print('response_model: ', response_model.model_dump(exclude_none=True))
-    actual = json.dumps(
-        response_model.model_dump(exclude_none=True),
-        cls=ResponseJsonEncoder,
-        sort_keys=True,
-    )
-    expected = json.dumps(
-        interaction.response.sdk_response_segments[self._sdk_response_index],
-        sort_keys=True,
-    )
+    actual = response_model.model_dump(exclude_none=True, mode='json')
+    expected = interaction.response.sdk_response_segments[
+        self._sdk_response_index
+    ]
     assert (
         actual == expected
     ), f'SDK response mismatch:\nActual: {actual}\nExpected: {expected}'
@@ -413,10 +402,21 @@ class ReplayApiClient(ApiClient):
     else:
       return self._build_response_from_replay(http_request)
 
-  def upload_file(self, file_path: str, upload_url: str, upload_size: int):
-    request = HttpRequest(
-        method='POST', url='', data={'file_path': file_path}, headers={}
-    )
+  def upload_file(self, file_path: Union[str, io.IOBase], upload_url: str, upload_size: int):
+    if isinstance(file_path, io.IOBase):
+      offset = file_path.tell()
+      content = file_path.read()
+      file_path.seek(offset, os.SEEK_SET)
+      request = HttpRequest(
+          method='POST',
+          url='',
+          data={'bytes': base64.b64encode(content).decode('utf-8')},
+          headers={}
+      )
+    else:
+      request = HttpRequest(
+          method='POST', url='', data={'file_path': file_path}, headers={}
+      )
     if self._should_call_api():
       try:
         result = super().upload_file(file_path, upload_url, upload_size)
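
Note: `upload_file` now also accepts an open stream. The recording path snapshots the stream's remaining bytes without disturbing the caller's read position; the tell/read/seek pattern in isolation (stdlib only, illustrative):

    import base64, io, os

    stream = io.BytesIO(b'\x00\x01binary payload')
    offset = stream.tell()            # remember the caller's position
    content = stream.read()           # consume the rest for the recording
    stream.seek(offset, os.SEEK_SET)  # restore it so the real upload still works
    assert stream.read() == content
    encoded = base64.b64encode(content).decode('utf-8')
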
@@ -431,42 +431,19 @@ class ReplayApiClient(ApiClient):
     else:
       return self._build_response_from_replay(request).text
 
-
-class ResponseJsonEncoder(json.JSONEncoder):
-  """The replay test json encoder for response.
-
-  We need RequestJsonEncoder and ResponseJsonEncoder because:
-  1. In production, we only need RequestJsonEncoder to help json module
-  to convert non-stringable and stringable types to json string. Especially
-  for bytes type, the value of bytes field is encoded to base64 string so it
-  is always stringable and the RequestJsonEncoder doesn't have to deal with
-  utf-8 JSON broken issue.
-  2. In replay test, we also need ResponseJsonEncoder to help json module
-  convert non-stringable and stringable types to json string. But response
-  object returned from SDK method is different from the request api_client
-  sent to server. For the bytes type, there is no base64 string in response
-  anymore, because SDK handles it internally. So bytes type in Response is
-  non-stringable. The ResponseJsonEncoder uses different encoding
-  strategy than the RequestJsonEncoder to deal with utf-8 JSON broken issue.
-  """
-  def default(self, o):
-    if isinstance(o, bytes):
-      # Use base64.b64encode() to encode bytes to string so that the media bytes
-      # fields are serializable.
-      # o.decode(encoding='utf-8', errors='replace') doesn't work because it
-      # uses a fixed error string `\ufffd` for all non-utf-8 characters,
-      # which cannot be converted back to original bytes. And other languages
-      # only have the original bytes to compare with.
-      # Since we use base64.b64encoding() in replay test, a change that breaks
-      # native bytes can be captured by
-      # test_compute_tokens.py::test_token_bytes_deserialization.
-      return base64.b64encode(o).decode(encoding='utf-8')
-    elif isinstance(o, datetime.datetime):
-      # dt.isoformat() prints "2024-11-15T23:27:45.624657+00:00"
-      # but replay files want "2024-11-15T23:27:45.624657Z"
-      if o.isoformat().endswith('+00:00'):
-        return o.isoformat().replace('+00:00', 'Z')
-      else:
-        return o.isoformat()
+  def _download_file_request(self, request):
+    self._initialize_replay_session_if_not_loaded()
+    if self._should_call_api():
+      try:
+        result = super()._download_file_request(request)
+      except HTTPError as e:
+        result = HttpResponse(
+            e.response.headers, [json.dumps({'reason': e.response.reason})]
+        )
+        result.status_code = e.response.status_code
+        raise e
+      self._record_interaction(request, result)
+      return result
     else:
-      return super().default(o)
+      return self._build_response_from_replay(request)
+
@@ -21,10 +21,12 @@ import inspect
 import io
 import re
 import time
-from typing import Any, Optional, Union
+import typing
+from typing import Any, GenericAlias, Optional, Union
 
 import PIL.Image
 import PIL.PngImagePlugin
+import pydantic
 
 from . import _api_client
 from . import types
@@ -35,7 +37,7 @@ def _resource_name(
     resource_name: str,
     *,
     collection_identifier: str,
-    collection_hirearchy_depth: int = 2,
+    collection_hierarchy_depth: int = 2,
 ):
   # pylint: disable=line-too-long
   """Prepends resource name with project, location, collection_identifier if needed.
@@ -48,13 +50,13 @@ def _resource_name(
   Args:
     client: The API client.
     resource_name: The user input resource name to be completed.
-    collection_identifier: The collection identifier to be prepended.
-      See collection identifiers in https://google.aip.dev/122.
-    collection_hirearchy_depth: The collection hierarchy depth.
-      Only set this field when the resource has nested collections.
-      For example, `users/vhugo1802/events/birthday-dinner-226`, the
-      collection_identifier is `users` and collection_hirearchy_depth is 4.
-      See nested collections in https://google.aip.dev/122.
+    collection_identifier: The collection identifier to be prepended. See
+      collection identifiers in https://google.aip.dev/122.
+    collection_hierarchy_depth: The collection hierarchy depth. Only set this
+      field when the resource has nested collections. For example,
+      `users/vhugo1802/events/birthday-dinner-226`, the collection_identifier is
+      `users` and collection_hierarchy_depth is 4. See nested collections in
+      https://google.aip.dev/122.
 
   Example:
 
@@ -62,7 +64,8 @@ def _resource_name(
     client.vertexai = True
     client.project = 'bar'
     client.location = 'us-west1'
-    _resource_name(client, 'cachedContents/123', collection_identifier='cachedContents')
+    _resource_name(client, 'cachedContents/123',
+      collection_identifier='cachedContents')
     returns: 'projects/bar/locations/us-west1/cachedContents/123'
 
   Example:
@@ -72,7 +75,8 @@ def _resource_name(
     client.vertexai = True
     client.project = 'bar'
     client.location = 'us-west1'
-    _resource_name(client, resource_name, collection_identifier='cachedContents')
+    _resource_name(client, resource_name,
+      collection_identifier='cachedContents')
     returns: 'projects/foo/locations/us-central1/cachedContents/123'
 
   Example:
@@ -80,7 +84,8 @@ def _resource_name(
     resource_name = '123'
     # resource_name = 'cachedContents/123'
     client.vertexai = False
-    _resource_name(client, resource_name, collection_identifier='cachedContents')
+    _resource_name(client, resource_name,
+      collection_identifier='cachedContents')
     returns 'cachedContents/123'
 
   Example:
@@ -88,7 +93,8 @@ def _resource_name(
     resource_prefix = 'cachedContents'
     client.vertexai = False
     # client.vertexai = True
-    _resource_name(client, resource_name, collection_identifier='cachedContents')
+    _resource_name(client, resource_name,
+      collection_identifier='cachedContents')
     returns: 'some/wrong/cachedContents/resource/name/123'
 
   Returns:
@@ -99,7 +105,7 @@ def _resource_name(
       # Check if prepending the collection identifier won't violate the
       # collection hierarchy depth.
       and f'{collection_identifier}/{resource_name}'.count('/') + 1
-      == collection_hirearchy_depth
+      == collection_hierarchy_depth
   )
   if client.vertexai:
     if resource_name.startswith('projects/'):
@@ -142,6 +148,35 @@ def t_model(client: _api_client.ApiClient, model: str):
   else:
     return f'models/{model}'
 
+
+def t_models_url(api_client: _api_client.ApiClient, base_models: bool) -> str:
+  if api_client.vertexai:
+    if base_models:
+      return 'publishers/google/models'
+    else:
+      return 'models'
+  else:
+    if base_models:
+      return 'models'
+    else:
+      return 'tunedModels'
+
+
+def t_extract_models(
+    api_client: _api_client.ApiClient, response: dict
+) -> list[types.Model]:
+  if not response:
+    return []
+  elif response.get('models') is not None:
+    return response.get('models')
+  elif response.get('tunedModels') is not None:
+    return response.get('tunedModels')
+  elif response.get('publisherModels') is not None:
+    return response.get('publisherModels')
+  else:
+    raise ValueError('Cannot determine the models type.')
+
+
 def t_caches_model(api_client: _api_client.ApiClient, model: str):
   model = t_model(api_client, model)
   if not model:
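
Note: together the two new helpers resolve the model-listing endpoint. `t_models_url` reads only `client.vertexai`, so its four outcomes can be tabulated with a stub object (a sketch; `SimpleNamespace` stands in for the real ApiClient and `t_models_url` is assumed in scope):

    from types import SimpleNamespace

    for vertexai in (True, False):
      client = SimpleNamespace(vertexai=vertexai)
      for base_models in (True, False):
        print(vertexai, base_models, '->', t_models_url(client, base_models))
    # True True -> publishers/google/models
    # True False -> models
    # False True -> models
    # False False -> tunedModels
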
@@ -180,6 +215,10 @@ def t_part(client: _api_client.ApiClient, part: PartType) -> types.Part:
     return types.Part(text=part)
   if isinstance(part, PIL.Image.Image):
     return types.Part(inline_data=pil_to_blob(part))
+  if isinstance(part, types.File):
+    if not part.uri or not part.mime_type:
+      raise ValueError('file uri and mime_type are required.')
+    return types.Part.from_uri(part.uri, part.mime_type)
   else:
     return part
 
@@ -258,32 +297,135 @@ def t_contents(
   return [t_content(client, contents)]
 
 
-def process_schema(data: dict):
+def process_schema(
+    data: dict[str, Any], client: Optional[_api_client.ApiClient] = None
+):
   if isinstance(data, dict):
     # Iterate over a copy of keys to allow deletion
     for key in list(data.keys()):
-      if key == 'title':
+      # Only delete 'title' for the Gemini API
+      if client and not client.vertexai and key == 'title':
         del data[key]
-      elif key == 'type':
-        data[key] = data[key].upper()
       else:
-        process_schema(data[key])
+        process_schema(data[key], client)
   elif isinstance(data, list):
     for item in data:
-      process_schema(item)
+      process_schema(item, client)
 
   return data
 
 
+def _build_schema(fname: str, fields_dict: dict[str, Any]) -> dict[str, Any]:
+  parameters = pydantic.create_model(fname, **fields_dict).model_json_schema()
+  defs = parameters.pop('$defs', {})
+
+  for _, value in defs.items():
+    unpack_defs(value, defs)
+
+  unpack_defs(parameters, defs)
+  return parameters['properties']['dummy']
+
+
+def unpack_defs(schema: dict[str, Any], defs: dict[str, Any]):
+  """Unpacks the $defs values in the schema generated by pydantic so they can be understood by the API.
+
+  Example of a schema before and after unpacking:
+  Before:
+
+  `schema`
+
+  {'properties': {
+      'dummy': {
+          'items': {
+              '$ref': '#/$defs/CountryInfo'
+          },
+          'title': 'Dummy',
+          'type': 'array'
+      }
+  },
+  'required': ['dummy'],
+  'title': 'dummy',
+  'type': 'object'}
+
+  `defs`
+
+  {'CountryInfo': {'properties': {'continent': {'title': 'Continent', 'type':
+  'string'}, 'gdp': {'title': 'Gdp', 'type': 'integer'}}, 'required':
+  ['continent', 'gdp'], 'title': 'CountryInfo', 'type': 'object'}}
+
+  After:
+
+  `schema`
+  {'properties': {
+      'continent': {'title': 'Continent', 'type': 'string'},
+      'gdp': {'title': 'Gdp', 'type': 'integer'}
+  },
+  'required': ['continent', 'gdp'],
+  'title': 'CountryInfo',
+  'type': 'object'
+  }
+  """
+  properties = schema.get('properties', None)
+  if properties is None:
+    return
+
+  for name, value in properties.items():
+    ref_key = value.get('$ref', None)
+    if ref_key is not None:
+      ref = defs[ref_key.split('defs/')[-1]]
+      unpack_defs(ref, defs)
+      properties[name] = ref
+      continue
+
+    anyof = value.get('anyOf', None)
+    if anyof is not None:
+      for i, atype in enumerate(anyof):
+        ref_key = atype.get('$ref', None)
+        if ref_key is not None:
+          ref = defs[ref_key.split('defs/')[-1]]
+          unpack_defs(ref, defs)
+          anyof[i] = ref
+      continue
+
+    items = value.get('items', None)
+    if items is not None:
+      ref_key = items.get('$ref', None)
+      if ref_key is not None:
+        ref = defs[ref_key.split('defs/')[-1]]
+        unpack_defs(ref, defs)
+        value['items'] = ref
+      continue
+
+
 def t_schema(
-    _: _api_client.ApiClient, origin: Union[types.SchemaDict, Any]
+    client: _api_client.ApiClient, origin: Union[types.SchemaUnionDict, Any]
 ) -> Optional[types.Schema]:
   if not origin:
     return None
   if isinstance(origin, dict):
-    return origin
-  schema = process_schema(origin.model_json_schema())
-  return types.Schema.model_validate(schema)
+    return process_schema(origin, client)
+  if isinstance(origin, types.Schema):
+    if dict(origin) == dict(types.Schema()):
+      # response_schema value was coerced to an empty Schema instance because it did not adhere to the Schema field annotation
+      raise ValueError(f'Unsupported schema type.')
+    schema = process_schema(origin.model_dump(exclude_unset=True), client)
+    return types.Schema.model_validate(schema)
+  if isinstance(origin, GenericAlias):
+    if origin.__origin__ is list:
+      if isinstance(origin.__args__[0], typing.types.UnionType):
+        raise ValueError(f'Unsupported schema type: GenericAlias {origin}')
+      if issubclass(origin.__args__[0], pydantic.BaseModel):
+        # Handle cases where response schema is `list[pydantic.BaseModel]`
+        list_schema = _build_schema(
+            'dummy', {'dummy': (origin, pydantic.Field())}
+        )
+        list_schema = process_schema(list_schema, client)
+        return types.Schema.model_validate(list_schema)
+    raise ValueError(f'Unsupported schema type: GenericAlias {origin}')
+  if issubclass(origin, pydantic.BaseModel):
+    schema = process_schema(origin.model_json_schema(), client)
+    return types.Schema.model_validate(schema)
+  raise ValueError(f'Unsupported schema type: {origin}')
 
 
 def t_speech_config(
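
Note: `_build_schema` leans on pydantic by wrapping the annotation in a disposable single-field model named 'dummy', then inlining the `$defs` emitted by `model_json_schema()`. The "Before" shape from the `unpack_defs` docstring can be reproduced with plain pydantic v2 (a sketch):

    import pydantic

    class CountryInfo(pydantic.BaseModel):
      continent: str
      gdp: int

    params = pydantic.create_model(
        'dummy', dummy=(list[CountryInfo], pydantic.Field())
    ).model_json_schema()
    # params['properties']['dummy']['items'] == {'$ref': '#/$defs/CountryInfo'}
    # unpack_defs() replaces that $ref with the full inline CountryInfo schema.
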
@@ -319,10 +461,10 @@ def t_speech_config(
 def t_tool(client: _api_client.ApiClient, origin) -> types.Tool:
   if not origin:
     return None
-  if inspect.isfunction(origin):
+  if inspect.isfunction(origin) or inspect.ismethod(origin):
     return types.Tool(
         function_declarations=[
-            types.FunctionDeclaration.from_function(client, origin)
+            types.FunctionDeclaration.from_callable(client, origin)
         ]
     )
   else:
@@ -432,10 +574,25 @@ def t_resolve_operation(api_client: _api_client.ApiClient, struct: dict):
   return struct
 
 
-def t_file_name(api_client: _api_client.ApiClient, name: str):
-  # Remove the files/ prefx since it's added to the url path.
-  if name.startswith('files/'):
-    return name.split('files/')[1]
+def t_file_name(
+    api_client: _api_client.ApiClient, name: Union[str, types.File]
+):
+  # Remove the files/ prefix since it's added to the url path.
+  if isinstance(name, types.File):
+    name = name.name
+
+  if name is None:
+    raise ValueError('File name is required.')
+
+  if name.startswith('https://'):
+    suffix = name.split('files/')[1]
+    match = re.match('[a-z0-9]+', suffix)
+    if match is None:
+      raise ValueError(f'Could not extract file name from URI: {name}')
+    name = match.group(0)
+  elif name.startswith('files/'):
+    name = name.split('files/')[1]
+
   return name
 
 
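
Note: `t_file_name` now also accepts a `types.File` (its `.name` is used) or a full download URL, in addition to the bare `files/...` form. The string normalization mirrored standalone (the sample URL is illustrative):

    import re

    def normalize(name: str) -> str:
      if name.startswith('https://'):
        suffix = name.split('files/')[1]
        name = re.match('[a-z0-9]+', suffix).group(0)
      elif name.startswith('files/'):
        name = name.split('files/')[1]
      return name

    assert normalize('files/abc123') == 'abc123'
    assert normalize('https://generativelanguage.googleapis.com/v1beta/files/abc123') == 'abc123'
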
@@ -452,3 +609,13 @@ def t_tuning_job_status(
     return 'JOB_STATE_FAILED'
   else:
     return status
+
+
+# Some fields don't accept url safe base64 encoding.
+# We shouldn't use this transformer if the backend adhere to Cloud Type
+# format https://cloud.google.com/docs/discovery/type-format.
+# TODO(b/389133914,b/390320301): Remove the hack after backend fix the issue.
+def t_bytes(api_client: _api_client.ApiClient, data: bytes) -> str:
+  if not isinstance(data, bytes):
+    return data
+  return base64.b64encode(data).decode('ascii')
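
Note: `t_bytes` deliberately emits standard base64 (the `+`/`/` alphabet) rather than the URL-safe variant that some backend fields reject. The difference on a small payload:

    import base64

    data = b'\xfa\xfb\xfc'
    base64.b64encode(data).decode('ascii')          # '+vv8' (standard alphabet)
    base64.urlsafe_b64encode(data).decode('ascii')  # '-vv8' (URL-safe alphabet)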