google-genai 0.4.0__py3-none-any.whl → 0.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- google/genai/_api_client.py +240 -71
- google/genai/_common.py +47 -31
- google/genai/_extra_utils.py +3 -3
- google/genai/_replay_api_client.py +51 -74
- google/genai/_transformers.py +197 -30
- google/genai/batches.py +74 -72
- google/genai/caches.py +104 -90
- google/genai/chats.py +5 -8
- google/genai/client.py +2 -1
- google/genai/errors.py +1 -1
- google/genai/files.py +302 -102
- google/genai/live.py +42 -30
- google/genai/models.py +379 -250
- google/genai/tunings.py +78 -76
- google/genai/types.py +563 -350
- google/genai/version.py +1 -1
- google_genai-0.6.0.dist-info/METADATA +973 -0
- google_genai-0.6.0.dist-info/RECORD +25 -0
- google_genai-0.4.0.dist-info/METADATA +0 -888
- google_genai-0.4.0.dist-info/RECORD +0 -25
- {google_genai-0.4.0.dist-info → google_genai-0.6.0.dist-info}/LICENSE +0 -0
- {google_genai-0.4.0.dist-info → google_genai-0.6.0.dist-info}/WHEEL +0 -0
- {google_genai-0.4.0.dist-info → google_genai-0.6.0.dist-info}/top_level.txt +0 -0
@@ -17,15 +17,15 @@
|
|
17
17
|
|
18
18
|
import base64
|
19
19
|
import copy
|
20
|
+
import datetime
|
20
21
|
import inspect
|
22
|
+
import io
|
21
23
|
import json
|
22
24
|
import os
|
23
25
|
import re
|
24
|
-
import datetime
|
25
26
|
from typing import Any, Literal, Optional, Union
|
26
27
|
|
27
28
|
import google.auth
|
28
|
-
from pydantic import BaseModel
|
29
29
|
from requests.exceptions import HTTPError
|
30
30
|
|
31
31
|
from . import errors
|
@@ -33,7 +33,8 @@ from ._api_client import ApiClient
|
|
33
33
|
from ._api_client import HttpOptions
|
34
34
|
from ._api_client import HttpRequest
|
35
35
|
from ._api_client import HttpResponse
|
36
|
-
from .
|
36
|
+
from ._common import BaseModel
|
37
|
+
|
37
38
|
|
38
39
|
def _redact_version_numbers(version_string: str) -> str:
|
39
40
|
"""Redacts version numbers in the form x.y.z from a string."""
|
@@ -72,6 +73,11 @@ def _redact_request_url(url: str) -> str:
|
|
72
73
|
'{VERTEX_URL_PREFIX}/',
|
73
74
|
url,
|
74
75
|
)
|
76
|
+
result = re.sub(
|
77
|
+
r'.*-aiplatform.googleapis.com/[^/]+/',
|
78
|
+
'{VERTEX_URL_PREFIX}/',
|
79
|
+
result,
|
80
|
+
)
|
75
81
|
result = re.sub(
|
76
82
|
r'https://generativelanguage.googleapis.com/[^/]+',
|
77
83
|
'{MLDEV_URL_PREFIX}',
|
@@ -140,6 +146,7 @@ class ReplayResponse(BaseModel):
|
|
140
146
|
status_code: int = 200
|
141
147
|
headers: dict[str, str]
|
142
148
|
body_segments: list[dict[str, object]]
|
149
|
+
byte_segments: Optional[list[bytes]] = None
|
143
150
|
sdk_response_segments: list[dict[str, object]]
|
144
151
|
|
145
152
|
def model_post_init(self, __context: Any) -> None:
|
@@ -259,26 +266,13 @@ class ReplayApiClient(ApiClient):
|
|
259
266
|
replay_file_path = self._get_replay_file_path()
|
260
267
|
os.makedirs(os.path.dirname(replay_file_path), exist_ok=True)
|
261
268
|
with open(replay_file_path, 'w') as f:
|
262
|
-
|
263
|
-
# Use for non-utf-8 bytes in image/video... output.
|
264
|
-
for interaction in replay_session_dict['interactions']:
|
265
|
-
segments = []
|
266
|
-
for response in interaction['response']['sdk_response_segments']:
|
267
|
-
segments.append(json.loads(json.dumps(
|
268
|
-
response, cls=ResponseJsonEncoder
|
269
|
-
)))
|
270
|
-
interaction['response']['sdk_response_segments'] = segments
|
271
|
-
f.write(
|
272
|
-
json.dumps(
|
273
|
-
replay_session_dict, indent=2, cls=RequestJsonEncoder
|
274
|
-
)
|
275
|
-
)
|
269
|
+
f.write(self.replay_session.model_dump_json(exclude_unset=True, indent=2))
|
276
270
|
self.replay_session = None
|
277
271
|
|
278
272
|
def _record_interaction(
|
279
273
|
self,
|
280
274
|
http_request: HttpRequest,
|
281
|
-
http_response: Union[HttpResponse, errors.APIError],
|
275
|
+
http_response: Union[HttpResponse, errors.APIError, bytes],
|
282
276
|
):
|
283
277
|
if not self._should_update_replay():
|
284
278
|
return
|
@@ -293,6 +287,9 @@ class ReplayApiClient(ApiClient):
|
|
293
287
|
response = ReplayResponse(
|
294
288
|
headers=dict(http_response.headers),
|
295
289
|
body_segments=list(http_response.segments()),
|
290
|
+
byte_segments=[
|
291
|
+
seg[:100] + b'...' for seg in http_response.byte_segments()
|
292
|
+
],
|
296
293
|
status_code=http_response.status_code,
|
297
294
|
sdk_response_segments=[],
|
298
295
|
)
|
@@ -326,11 +323,7 @@ class ReplayApiClient(ApiClient):
|
|
326
323
|
# so that the comparison is fair.
|
327
324
|
_redact_request_body(request_data_copy)
|
328
325
|
|
329
|
-
|
330
|
-
# Because the expected_request_body dict never contains bytes values.
|
331
|
-
actual_request_body = [
|
332
|
-
json.loads(json.dumps(request_data_copy, cls=RequestJsonEncoder))
|
333
|
-
]
|
326
|
+
actual_request_body = [request_data_copy]
|
334
327
|
expected_request_body = interaction.request.body_segments
|
335
328
|
assert actual_request_body == expected_request_body, (
|
336
329
|
'Request body mismatch:\n'
|
@@ -353,6 +346,7 @@ class ReplayApiClient(ApiClient):
|
|
353
346
|
json.dumps(segment)
|
354
347
|
for segment in interaction.response.body_segments
|
355
348
|
],
|
349
|
+
byte_stream=interaction.response.byte_segments,
|
356
350
|
)
|
357
351
|
|
358
352
|
def _verify_response(self, response_model: BaseModel):
|
@@ -371,15 +365,10 @@ class ReplayApiClient(ApiClient):
|
|
371
365
|
if isinstance(response_model, list):
|
372
366
|
response_model = response_model[0]
|
373
367
|
print('response_model: ', response_model.model_dump(exclude_none=True))
|
374
|
-
actual =
|
375
|
-
|
376
|
-
|
377
|
-
|
378
|
-
)
|
379
|
-
expected = json.dumps(
|
380
|
-
interaction.response.sdk_response_segments[self._sdk_response_index],
|
381
|
-
sort_keys=True,
|
382
|
-
)
|
368
|
+
actual = response_model.model_dump(exclude_none=True, mode='json')
|
369
|
+
expected = interaction.response.sdk_response_segments[
|
370
|
+
self._sdk_response_index
|
371
|
+
]
|
383
372
|
assert (
|
384
373
|
actual == expected
|
385
374
|
), f'SDK response mismatch:\nActual: {actual}\nExpected: {expected}'
|
@@ -413,10 +402,21 @@ class ReplayApiClient(ApiClient):
|
|
413
402
|
else:
|
414
403
|
return self._build_response_from_replay(http_request)
|
415
404
|
|
416
|
-
def upload_file(self, file_path: str, upload_url: str, upload_size: int):
|
417
|
-
|
418
|
-
|
419
|
-
|
405
|
+
def upload_file(self, file_path: Union[str, io.IOBase], upload_url: str, upload_size: int):
|
406
|
+
if isinstance(file_path, io.IOBase):
|
407
|
+
offset = file_path.tell()
|
408
|
+
content = file_path.read()
|
409
|
+
file_path.seek(offset, os.SEEK_SET)
|
410
|
+
request = HttpRequest(
|
411
|
+
method='POST',
|
412
|
+
url='',
|
413
|
+
data={'bytes': base64.b64encode(content).decode('utf-8')},
|
414
|
+
headers={}
|
415
|
+
)
|
416
|
+
else:
|
417
|
+
request = HttpRequest(
|
418
|
+
method='POST', url='', data={'file_path': file_path}, headers={}
|
419
|
+
)
|
420
420
|
if self._should_call_api():
|
421
421
|
try:
|
422
422
|
result = super().upload_file(file_path, upload_url, upload_size)
|
@@ -431,42 +431,19 @@ class ReplayApiClient(ApiClient):
|
|
431
431
|
else:
|
432
432
|
return self._build_response_from_replay(request).text
|
433
433
|
|
434
|
-
|
435
|
-
|
436
|
-
|
437
|
-
|
438
|
-
|
439
|
-
|
440
|
-
|
441
|
-
|
442
|
-
|
443
|
-
|
444
|
-
|
445
|
-
|
446
|
-
|
447
|
-
sent to server. For the bytes type, there is no base64 string in response
|
448
|
-
anymore, because SDK handles it internally. So bytes type in Response is
|
449
|
-
non-stringable. The ResponseJsonEncoder uses different encoding
|
450
|
-
strategy than the RequestJsonEncoder to deal with utf-8 JSON broken issue.
|
451
|
-
"""
|
452
|
-
def default(self, o):
|
453
|
-
if isinstance(o, bytes):
|
454
|
-
# Use base64.b64encode() to encode bytes to string so that the media bytes
|
455
|
-
# fields are serializable.
|
456
|
-
# o.decode(encoding='utf-8', errors='replace') doesn't work because it
|
457
|
-
# uses a fixed error string `\ufffd` for all non-utf-8 characters,
|
458
|
-
# which cannot be converted back to original bytes. And other languages
|
459
|
-
# only have the original bytes to compare with.
|
460
|
-
# Since we use base64.b64encoding() in replay test, a change that breaks
|
461
|
-
# native bytes can be captured by
|
462
|
-
# test_compute_tokens.py::test_token_bytes_deserialization.
|
463
|
-
return base64.b64encode(o).decode(encoding='utf-8')
|
464
|
-
elif isinstance(o, datetime.datetime):
|
465
|
-
# dt.isoformat() prints "2024-11-15T23:27:45.624657+00:00"
|
466
|
-
# but replay files want "2024-11-15T23:27:45.624657Z"
|
467
|
-
if o.isoformat().endswith('+00:00'):
|
468
|
-
return o.isoformat().replace('+00:00', 'Z')
|
469
|
-
else:
|
470
|
-
return o.isoformat()
|
434
|
+
def _download_file_request(self, request):
  """Downloads a file, either for real (and recorded) or from the replay.

  When `_should_call_api()` is true, the download is delegated to the parent
  client and the response is passed to `_record_interaction`. If the parent
  download raises `HTTPError`, an error response carrying the HTTP reason and
  status code is recorded before the original error is re-raised; previously
  this response was built and then thrown away. Otherwise the response is
  rebuilt from the replay session.

  Args:
    request: The HttpRequest describing the download.

  Returns:
    The HttpResponse from the API or from the replay session.

  Raises:
    HTTPError: Re-raised unchanged when the real download fails.
  """
  self._initialize_replay_session_if_not_loaded()
  if not self._should_call_api():
    return self._build_response_from_replay(request)

  try:
    result = super()._download_file_request(request)
  except HTTPError as e:
    error_response = HttpResponse(
        e.response.headers, [json.dumps({'reason': e.response.reason})]
    )
    error_response.status_code = e.response.status_code
    # Capture the failed request too, then propagate the original error.
    self._record_interaction(request, error_response)
    raise
  self._record_interaction(request, result)
  return result
|
449
|
+
|
google/genai/_transformers.py
CHANGED
@@ -21,10 +21,12 @@ import inspect
|
|
21
21
|
import io
|
22
22
|
import re
|
23
23
|
import time
|
24
|
-
|
24
|
+
import typing
|
25
|
+
from typing import Any, GenericAlias, Optional, Union
|
25
26
|
|
26
27
|
import PIL.Image
|
27
28
|
import PIL.PngImagePlugin
|
29
|
+
import pydantic
|
28
30
|
|
29
31
|
from . import _api_client
|
30
32
|
from . import types
|
@@ -35,7 +37,7 @@ def _resource_name(
|
|
35
37
|
resource_name: str,
|
36
38
|
*,
|
37
39
|
collection_identifier: str,
|
38
|
-
|
40
|
+
collection_hierarchy_depth: int = 2,
|
39
41
|
):
|
40
42
|
# pylint: disable=line-too-long
|
41
43
|
"""Prepends resource name with project, location, collection_identifier if needed.
|
@@ -48,13 +50,13 @@ def _resource_name(
|
|
48
50
|
Args:
|
49
51
|
client: The API client.
|
50
52
|
resource_name: The user input resource name to be completed.
|
51
|
-
collection_identifier: The collection identifier to be prepended.
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
53
|
+
collection_identifier: The collection identifier to be prepended. See
|
54
|
+
collection identifiers in https://google.aip.dev/122.
|
55
|
+
collection_hierarchy_depth: The collection hierarchy depth. Only set this
|
56
|
+
field when the resource has nested collections. For example,
|
57
|
+
`users/vhugo1802/events/birthday-dinner-226`, the collection_identifier is
|
58
|
+
`users` and collection_hierarchy_depth is 4. See nested collections in
|
59
|
+
https://google.aip.dev/122.
|
58
60
|
|
59
61
|
Example:
|
60
62
|
|
@@ -62,7 +64,8 @@ def _resource_name(
|
|
62
64
|
client.vertexai = True
|
63
65
|
client.project = 'bar'
|
64
66
|
client.location = 'us-west1'
|
65
|
-
_resource_name(client, 'cachedContents/123',
|
67
|
+
_resource_name(client, 'cachedContents/123',
|
68
|
+
collection_identifier='cachedContents')
|
66
69
|
returns: 'projects/bar/locations/us-west1/cachedContents/123'
|
67
70
|
|
68
71
|
Example:
|
@@ -72,7 +75,8 @@ def _resource_name(
|
|
72
75
|
client.vertexai = True
|
73
76
|
client.project = 'bar'
|
74
77
|
client.location = 'us-west1'
|
75
|
-
_resource_name(client, resource_name,
|
78
|
+
_resource_name(client, resource_name,
|
79
|
+
collection_identifier='cachedContents')
|
76
80
|
returns: 'projects/foo/locations/us-central1/cachedContents/123'
|
77
81
|
|
78
82
|
Example:
|
@@ -80,7 +84,8 @@ def _resource_name(
|
|
80
84
|
resource_name = '123'
|
81
85
|
# resource_name = 'cachedContents/123'
|
82
86
|
client.vertexai = False
|
83
|
-
_resource_name(client, resource_name,
|
87
|
+
_resource_name(client, resource_name,
|
88
|
+
collection_identifier='cachedContents')
|
84
89
|
returns 'cachedContents/123'
|
85
90
|
|
86
91
|
Example:
|
@@ -88,7 +93,8 @@ def _resource_name(
|
|
88
93
|
resource_prefix = 'cachedContents'
|
89
94
|
client.vertexai = False
|
90
95
|
# client.vertexai = True
|
91
|
-
_resource_name(client, resource_name,
|
96
|
+
_resource_name(client, resource_name,
|
97
|
+
collection_identifier='cachedContents')
|
92
98
|
returns: 'some/wrong/cachedContents/resource/name/123'
|
93
99
|
|
94
100
|
Returns:
|
@@ -99,7 +105,7 @@ def _resource_name(
|
|
99
105
|
# Check if prepending the collection identifier won't violate the
|
100
106
|
# collection hierarchy depth.
|
101
107
|
and f'{collection_identifier}/{resource_name}'.count('/') + 1
|
102
|
-
==
|
108
|
+
== collection_hierarchy_depth
|
103
109
|
)
|
104
110
|
if client.vertexai:
|
105
111
|
if resource_name.startswith('projects/'):
|
@@ -142,6 +148,35 @@ def t_model(client: _api_client.ApiClient, model: str):
|
|
142
148
|
else:
|
143
149
|
return f'models/{model}'
|
144
150
|
|
151
|
+
|
152
|
+
def t_models_url(api_client: _api_client.ApiClient, base_models: bool) -> str:
  """Returns the URL collection path used when listing models.

  Args:
    api_client: The API client; its `vertexai` flag selects the backend.
    base_models: Whether base models (truthy) or tuned models (falsy) are
      being listed.

  Returns:
    On Vertex AI: 'publishers/google/models' for base models, else 'models'.
    On the Gemini API: 'models' for base models, else 'tunedModels'.
  """
  url_by_backend_and_kind = {
      (True, True): 'publishers/google/models',
      (True, False): 'models',
      (False, True): 'models',
      (False, False): 'tunedModels',
  }
  return url_by_backend_and_kind[(bool(api_client.vertexai), bool(base_models))]
|
163
|
+
|
164
|
+
|
165
|
+
def t_extract_models(
    api_client: _api_client.ApiClient, response: dict
) -> list[types.Model]:
  """Pulls the list of models out of a raw list-models response.

  The keys `models`, `tunedModels` and `publisherModels` are tried in that
  order; the first one whose value is not None is returned.

  Args:
    api_client: The API client (not used).
    response: The decoded response body.

  Returns:
    The list stored under the first matching key, or [] when `response` is
    empty or falsy.

  Raises:
    ValueError: If none of the known keys holds a non-None value.
  """
  if not response:
    return []
  for key in ('models', 'tunedModels', 'publisherModels'):
    found = response.get(key)
    if found is not None:
      return found
  raise ValueError('Cannot determine the models type.')
|
178
|
+
|
179
|
+
|
145
180
|
def t_caches_model(api_client: _api_client.ApiClient, model: str):
|
146
181
|
model = t_model(api_client, model)
|
147
182
|
if not model:
|
@@ -180,6 +215,10 @@ def t_part(client: _api_client.ApiClient, part: PartType) -> types.Part:
|
|
180
215
|
return types.Part(text=part)
|
181
216
|
if isinstance(part, PIL.Image.Image):
|
182
217
|
return types.Part(inline_data=pil_to_blob(part))
|
218
|
+
if isinstance(part, types.File):
|
219
|
+
if not part.uri or not part.mime_type:
|
220
|
+
raise ValueError('file uri and mime_type are required.')
|
221
|
+
return types.Part.from_uri(part.uri, part.mime_type)
|
183
222
|
else:
|
184
223
|
return part
|
185
224
|
|
@@ -258,32 +297,135 @@ def t_contents(
|
|
258
297
|
return [t_content(client, contents)]
|
259
298
|
|
260
299
|
|
261
|
-
def process_schema(
|
300
|
+
def process_schema(
    data: dict[str, Any], client: Optional[_api_client.ApiClient] = None
):
  """Recursively normalizes a JSON schema in place for the target backend.

  For the Gemini API (a `client` whose `vertexai` flag is falsy) every `title`
  schema keyword is removed. The keys of a `properties` mapping are
  user-defined field names, not schema keywords, so a field that happens to
  be named `title` is kept and only the schemas it maps to are processed.
  With Vertex AI, or when no client is given, nothing is removed.

  Args:
    data: The schema to process. Nested values may be dicts, lists or
      scalars: dicts and lists are traversed, scalars are left untouched.
    client: The API client used to decide whether `title` is stripped.

  Returns:
    The same `data` object, modified in place.
  """
  if isinstance(data, dict):
    # Iterate over a copy of the keys so entries can be deleted in the loop.
    for key in list(data.keys()):
      # Only delete 'title' for the Gemini API.
      if client and not client.vertexai and key == 'title':
        del data[key]
      elif key == 'properties' and isinstance(data[key], dict):
        # Keys here are field names; recurse into each field's schema only.
        for field_schema in data[key].values():
          process_schema(field_schema, client)
      else:
        process_schema(data[key], client)
  elif isinstance(data, list):
    for item in data:
      process_schema(item, client)

  return data
|
276
316
|
|
277
317
|
|
318
|
+
def _build_schema(fname: str, fields_dict: dict[str, Any]) -> dict[str, Any]:
  """Returns the `$ref`-free JSON schema of the field named 'dummy'.

  A temporary pydantic model called `fname` is created from `fields_dict`.
  Its JSON schema's `$defs` are inlined, first inside every definition and
  then into the model's own properties. The schema of the `dummy` property
  is returned.

  Args:
    fname: Name for the temporary pydantic model.
    fields_dict: Field definitions for `pydantic.create_model`; must contain
      a field named `dummy`.

  Returns:
    The JSON schema of the `dummy` field with its references unpacked.
  """
  model_schema = pydantic.create_model(fname, **fields_dict).model_json_schema()
  definitions = model_schema.pop('$defs', {})
  for definition in definitions.values():
    unpack_defs(definition, definitions)
  unpack_defs(model_schema, definitions)
  return model_schema['properties']['dummy']
|
327
|
+
|
328
|
+
|
329
|
+
def unpack_defs(schema: dict[str, Any], defs: dict[str, Any]):
  """Inlines the `$ref`s in the properties of a pydantic-generated schema.

  pydantic's `model_json_schema()` describes nested models with pointers of
  the form `{'$ref': '#/$defs/Name'}` into a separate `$defs` mapping. This
  function replaces those pointers in `schema['properties']` with the
  referenced definitions, after unpacking those definitions in turn. A
  reference is resolved when it is the property schema itself, an entry of
  an `anyOf`, or an `items` schema. Resolution recurses through nested
  `anyOf`/`items`, for example `Optional[list[Model]]` (an `anyOf` holding
  an array whose `items` is a `$ref`) or `list[list[Model]]`.

  Definitions are inserted by identity, not copied. A self-referencing
  (recursive) model is not supported: unpacking it recurses until Python's
  recursion limit is hit.

  Example of a schema before and after unpacking:

  Before:

  `schema`

  {'properties': {
      'dummy': {
          'items': {
              '$ref': '#/$defs/CountryInfo'
          },
          'title': 'Dummy',
          'type': 'array'
      }
    },
    'required': ['dummy'],
    'title': 'dummy',
    'type': 'object'}

  `defs`

  {'CountryInfo': {'properties': {'continent': {'title': 'Continent', 'type':
  'string'}, 'gdp': {'title': 'Gdp', 'type': 'integer'}}, 'required':
  ['continent', 'gdp'], 'title': 'CountryInfo', 'type': 'object'}}

  After:

  `schema`

  {'properties': {
      'dummy': {
          'items': {
              'properties': {
                  'continent': {'title': 'Continent', 'type': 'string'},
                  'gdp': {'title': 'Gdp', 'type': 'integer'}
              },
              'required': ['continent', 'gdp'],
              'title': 'CountryInfo',
              'type': 'object'
          },
          'title': 'Dummy',
          'type': 'array'
      }
    },
    'required': ['dummy'],
    'title': 'dummy',
    'type': 'object'}

  Args:
    schema: The schema to unpack, modified in place. Nothing happens when it
      has no `properties`.
    defs: The `$defs` mapping of the generated schema, keyed by definition
      name.
  """
  properties = schema.get('properties', None)
  if properties is None:
    return

  # Only values are reassigned (no keys added/removed), so this is safe.
  for name, value in properties.items():
    properties[name] = _unpack_refs_in_subschema(value, defs)


def _unpack_refs_in_subschema(
    subschema: Any, defs: dict[str, Any]
) -> Any:
  """Returns `subschema` with `$ref`s (direct, in anyOf, in items) inlined."""
  if not isinstance(subschema, dict):
    return subschema

  ref_key = subschema.get('$ref', None)
  if ref_key is not None:
    ref = defs[ref_key.split('defs/')[-1]]
    unpack_defs(ref, defs)
    return ref

  anyof = subschema.get('anyOf', None)
  if anyof is not None:
    for i, atype in enumerate(anyof):
      anyof[i] = _unpack_refs_in_subschema(atype, defs)

  items = subschema.get('items', None)
  if isinstance(items, dict):
    subschema['items'] = _unpack_refs_in_subschema(items, defs)

  return subschema
|
398
|
+
|
399
|
+
|
278
400
|
def _is_pydantic_model_class(obj: Any) -> bool:
  """Returns whether `obj` is a subclass of `pydantic.BaseModel`."""
  try:
    return issubclass(obj, pydantic.BaseModel)
  except TypeError:
    # issubclass() rejects non-classes such as strings, instances,
    # `list[int]` or `typing.Union[...]`.
    return False


def t_schema(
    client: _api_client.ApiClient, origin: Union[types.SchemaUnionDict, Any]
) -> Optional[types.Schema]:
  """Converts a user-supplied response schema for the API.

  Supported inputs are:
    * a dict, processed in place by `process_schema` and returned as a dict;
    * a non-empty `types.Schema` instance;
    * a pydantic model class;
    * `list[Model]`, where `Model` is a pydantic model class.

  Args:
    client: The API client, forwarded to `process_schema`.
    origin: The schema to convert.

  Returns:
    None for a falsy `origin`, the processed dict for dict input, otherwise a
    validated `types.Schema`.

  Raises:
    ValueError: If `origin` is not one of the supported forms.
  """
  if not origin:
    return None
  if isinstance(origin, dict):
    return process_schema(origin, client)
  if isinstance(origin, types.Schema):
    if dict(origin) == dict(types.Schema()):
      # response_schema value was coerced to an empty Schema instance because
      # it did not adhere to the Schema field annotation.
      raise ValueError('Unsupported schema type.')
    schema = process_schema(origin.model_dump(exclude_unset=True), client)
    return types.Schema.model_validate(schema)
  if isinstance(origin, GenericAlias):
    if origin.__origin__ is list:
      item_type = origin.__args__[0]
      if isinstance(item_type, typing.types.UnionType):
        raise ValueError(f'Unsupported schema type: GenericAlias {origin}')
      if _is_pydantic_model_class(item_type):
        # Handle cases where response schema is `list[pydantic.BaseModel]`
        list_schema = _build_schema(
            'dummy', {'dummy': (origin, pydantic.Field())}
        )
        list_schema = process_schema(list_schema, client)
        return types.Schema.model_validate(list_schema)
    raise ValueError(f'Unsupported schema type: GenericAlias {origin}')
  if _is_pydantic_model_class(origin):
    schema = process_schema(origin.model_json_schema(), client)
    return types.Schema.model_validate(schema)
  raise ValueError(f'Unsupported schema type: {origin}')
|
287
429
|
|
288
430
|
|
289
431
|
def t_speech_config(
|
@@ -319,10 +461,10 @@ def t_speech_config(
|
|
319
461
|
def t_tool(client: _api_client.ApiClient, origin) -> types.Tool:
|
320
462
|
if not origin:
|
321
463
|
return None
|
322
|
-
if inspect.isfunction(origin):
|
464
|
+
if inspect.isfunction(origin) or inspect.ismethod(origin):
|
323
465
|
return types.Tool(
|
324
466
|
function_declarations=[
|
325
|
-
types.FunctionDeclaration.
|
467
|
+
types.FunctionDeclaration.from_callable(client, origin)
|
326
468
|
]
|
327
469
|
)
|
328
470
|
else:
|
@@ -432,10 +574,25 @@ def t_resolve_operation(api_client: _api_client.ApiClient, struct: dict):
|
|
432
574
|
return struct
|
433
575
|
|
434
576
|
|
435
|
-
def t_file_name(
|
436
|
-
|
437
|
-
|
438
|
-
|
577
|
+
def t_file_name(
    api_client: _api_client.ApiClient, name: Union[str, types.File]
):
  """Extracts the bare file ID used in Files API URL paths.

  Accepts a `types.File` (its `name` is used), a resource name such as
  `files/abc-123`, an `https://` URI containing `files/abc-123`, or a bare
  ID. The `files/` prefix is removed because it is added to the URL path
  separately.

  Args:
    api_client: The API client (not used).
    name: The file, file resource name, file URI or file ID.

  Returns:
    The file ID.

  Raises:
    ValueError: If no name is available, or if an `https://` name does not
      contain a `files/<id>` segment.
  """
  if isinstance(name, types.File):
    name = name.name

  if name is None:
    raise ValueError('File name is required.')

  if name.startswith('https://'):
    _, separator, suffix = name.partition('files/')
    if not separator:
      raise ValueError(f'Could not extract file name from URI: {name}')
    # File IDs consist of lowercase alphanumerics and dashes; stop at the
    # first other character (e.g. '/', ':' or '?').
    match = re.match(r'[a-z0-9-]+', suffix)
    if match is None:
      raise ValueError(f'Could not extract file name from URI: {name}')
    name = match.group(0)
  elif name.startswith('files/'):
    name = name.split('files/')[1]

  return name
|
440
597
|
|
441
598
|
|
@@ -452,3 +609,13 @@ def t_tuning_job_status(
|
|
452
609
|
return 'JOB_STATE_FAILED'
|
453
610
|
else:
|
454
611
|
return status
|
612
|
+
|
613
|
+
|
614
|
+
# Some fields don't accept URL-safe base64 encoding, so standard base64 is
# produced here. This transformer shouldn't be needed once the backend adheres
# to the Cloud Type format https://cloud.google.com/docs/discovery/type-format.
# TODO(b/389133914,b/390320301): Remove the hack after the backend fixes the
# issue.
def t_bytes(api_client: _api_client.ApiClient, data: bytes) -> str:
  """Encodes `bytes` as standard (non-URL-safe) base64 ASCII text.

  Args:
    api_client: The API client (not used).
    data: The value to encode.

  Returns:
    The base64 string when `data` is `bytes`; any other value is passed
    through unchanged.
  """
  if isinstance(data, bytes):
    return base64.b64encode(data).decode('ascii')
  return data
|