google-genai 1.30.0__py3-none-any.whl → 1.32.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- google/genai/_api_client.py +32 -32
- google/genai/_automatic_function_calling_util.py +12 -0
- google/genai/_base_transformers.py +26 -0
- google/genai/_live_converters.py +1 -0
- google/genai/_local_tokenizer_loader.py +223 -0
- google/genai/_operations_converters.py +307 -0
- google/genai/_tokens_converters.py +1 -0
- google/genai/_transformers.py +0 -10
- google/genai/batches.py +141 -0
- google/genai/caches.py +15 -2
- google/genai/files.py +11 -2
- google/genai/local_tokenizer.py +362 -0
- google/genai/models.py +518 -17
- google/genai/operations.py +1 -0
- google/genai/tunings.py +135 -0
- google/genai/types.py +781 -323
- google/genai/version.py +1 -1
- {google_genai-1.30.0.dist-info → google_genai-1.32.0.dist-info}/METADATA +6 -6
- google_genai-1.32.0.dist-info/RECORD +39 -0
- google_genai-1.30.0.dist-info/RECORD +0 -35
- {google_genai-1.30.0.dist-info → google_genai-1.32.0.dist-info}/WHEEL +0 -0
- {google_genai-1.30.0.dist-info → google_genai-1.32.0.dist-info}/licenses/LICENSE +0 -0
- {google_genai-1.30.0.dist-info → google_genai-1.32.0.dist-info}/top_level.txt +0 -0
google/genai/_api_client.py
CHANGED
@@ -91,7 +91,7 @@ class EphemeralTokenAPIKeyError(ValueError):
|
|
91
91
|
|
92
92
|
# This method checks for the API key in the environment variables. Google API
|
93
93
|
# key takes precedence over the Gemini API key.
|
94
|
-
def
|
94
|
+
def get_env_api_key() -> Optional[str]:
|
95
95
|
"""Gets the API key from environment variables, prioritizing GOOGLE_API_KEY.
|
96
96
|
|
97
97
|
Returns:
|
@@ -108,7 +108,7 @@ def _get_env_api_key() -> Optional[str]:
|
|
108
108
|
return env_google_api_key or env_gemini_api_key or None
|
109
109
|
|
110
110
|
|
111
|
-
def
|
111
|
+
def append_library_version_headers(headers: dict[str, str]) -> None:
|
112
112
|
"""Appends the telemetry header to the headers dict."""
|
113
113
|
library_label = f'google-genai-sdk/{version.__version__}'
|
114
114
|
language_label = 'gl-python/' + sys.version.split()[0]
|
@@ -131,7 +131,7 @@ def _append_library_version_headers(headers: dict[str, str]) -> None:
|
|
131
131
|
headers['x-goog-api-client'] = version_header_value
|
132
132
|
|
133
133
|
|
134
|
-
def
|
134
|
+
def patch_http_options(
|
135
135
|
options: HttpOptions, patch_options: HttpOptions
|
136
136
|
) -> HttpOptions:
|
137
137
|
copy_option = options.model_copy()
|
@@ -155,11 +155,11 @@ def _patch_http_options(
|
|
155
155
|
setattr(copy_option, key, getattr(options, key))
|
156
156
|
|
157
157
|
if copy_option.headers is not None:
|
158
|
-
|
158
|
+
append_library_version_headers(copy_option.headers)
|
159
159
|
return copy_option
|
160
160
|
|
161
161
|
|
162
|
-
def
|
162
|
+
def populate_server_timeout_header(
|
163
163
|
headers: dict[str, str], timeout_in_seconds: Optional[Union[float, int]]
|
164
164
|
) -> None:
|
165
165
|
"""Populates the server timeout header in the headers dict."""
|
@@ -167,7 +167,7 @@ def _populate_server_timeout_header(
|
|
167
167
|
headers['X-Server-Timeout'] = str(math.ceil(timeout_in_seconds))
|
168
168
|
|
169
169
|
|
170
|
-
def
|
170
|
+
def join_url_path(base_url: str, path: str) -> str:
|
171
171
|
parsed_base = urlparse(base_url)
|
172
172
|
base_path = (
|
173
173
|
parsed_base.path[:-1]
|
@@ -178,7 +178,7 @@ def _join_url_path(base_url: str, path: str) -> str:
|
|
178
178
|
return urlunparse(parsed_base._replace(path=base_path + '/' + path))
|
179
179
|
|
180
180
|
|
181
|
-
def
|
181
|
+
def load_auth(*, project: Union[str, None]) -> Tuple[Credentials, str]:
|
182
182
|
"""Loads google auth credentials and project id."""
|
183
183
|
credentials, loaded_project_id = google.auth.default( # type: ignore[no-untyped-call]
|
184
184
|
scopes=['https://www.googleapis.com/auth/cloud-platform'],
|
@@ -195,12 +195,12 @@ def _load_auth(*, project: Union[str, None]) -> Tuple[Credentials, str]:
|
|
195
195
|
return credentials, project
|
196
196
|
|
197
197
|
|
198
|
-
def
|
198
|
+
def refresh_auth(credentials: Credentials) -> Credentials:
|
199
199
|
credentials.refresh(Request()) # type: ignore[no-untyped-call]
|
200
200
|
return credentials
|
201
201
|
|
202
202
|
|
203
|
-
def
|
203
|
+
def get_timeout_in_seconds(
|
204
204
|
timeout: Optional[Union[float, int]],
|
205
205
|
) -> Optional[float]:
|
206
206
|
"""Converts the timeout to seconds."""
|
@@ -454,7 +454,7 @@ _RETRY_HTTP_STATUS_CODES = (
|
|
454
454
|
)
|
455
455
|
|
456
456
|
|
457
|
-
def
|
457
|
+
def retry_args(options: Optional[HttpRetryOptions]) -> _common.StringDict:
|
458
458
|
"""Returns the retry args for the given http retry options.
|
459
459
|
|
460
460
|
Args:
|
@@ -574,7 +574,7 @@ class BaseApiClient:
|
|
574
574
|
# Retrieve implicitly set values from the environment.
|
575
575
|
env_project = os.environ.get('GOOGLE_CLOUD_PROJECT', None)
|
576
576
|
env_location = os.environ.get('GOOGLE_CLOUD_LOCATION', None)
|
577
|
-
env_api_key =
|
577
|
+
env_api_key = get_env_api_key()
|
578
578
|
self.project = project or env_project
|
579
579
|
self.location = location or env_location
|
580
580
|
self.api_key = api_key or env_api_key
|
@@ -631,7 +631,7 @@ class BaseApiClient:
|
|
631
631
|
and not self.api_key
|
632
632
|
and not validated_http_options.base_url
|
633
633
|
):
|
634
|
-
credentials, self.project =
|
634
|
+
credentials, self.project = load_auth(project=None)
|
635
635
|
if not self._credentials:
|
636
636
|
self._credentials = credentials
|
637
637
|
|
@@ -670,12 +670,12 @@ class BaseApiClient:
|
|
670
670
|
self._http_options.headers['x-goog-api-key'] = self.api_key
|
671
671
|
# Update the http options with the user provided http options.
|
672
672
|
if http_options:
|
673
|
-
self._http_options =
|
673
|
+
self._http_options = patch_http_options(
|
674
674
|
self._http_options, validated_http_options
|
675
675
|
)
|
676
676
|
else:
|
677
677
|
if self._http_options.headers is not None:
|
678
|
-
|
678
|
+
append_library_version_headers(self._http_options.headers)
|
679
679
|
|
680
680
|
client_args, async_client_args = self._ensure_httpx_ssl_ctx(
|
681
681
|
self._http_options
|
@@ -689,7 +689,7 @@ class BaseApiClient:
|
|
689
689
|
)
|
690
690
|
self._websocket_ssl_ctx = self._ensure_websocket_ssl_ctx(self._http_options)
|
691
691
|
|
692
|
-
retry_kwargs =
|
692
|
+
retry_kwargs = retry_args(self._http_options.retry_options)
|
693
693
|
self._retry = tenacity.Retrying(**retry_kwargs)
|
694
694
|
self._async_retry = tenacity.AsyncRetrying(**retry_kwargs)
|
695
695
|
|
@@ -889,14 +889,14 @@ class BaseApiClient:
|
|
889
889
|
"""Retrieves the access token for the credentials."""
|
890
890
|
with self._sync_auth_lock:
|
891
891
|
if not self._credentials:
|
892
|
-
self._credentials, project =
|
892
|
+
self._credentials, project = load_auth(project=self.project)
|
893
893
|
if not self.project:
|
894
894
|
self.project = project
|
895
895
|
|
896
896
|
if self._credentials:
|
897
897
|
if self._credentials.expired or not self._credentials.token:
|
898
898
|
# Only refresh when it needs to. Default expiration is 3600 seconds.
|
899
|
-
|
899
|
+
refresh_auth(self._credentials)
|
900
900
|
if not self._credentials.token:
|
901
901
|
raise RuntimeError('Could not resolve API token from the environment')
|
902
902
|
return self._credentials.token # type: ignore[no-any-return]
|
@@ -912,7 +912,7 @@ class BaseApiClient:
|
|
912
912
|
if not self._credentials:
|
913
913
|
# Double check that the credentials are not set before loading them.
|
914
914
|
self._credentials, project = await asyncio.to_thread(
|
915
|
-
|
915
|
+
load_auth, project=self.project
|
916
916
|
)
|
917
917
|
if not self.project:
|
918
918
|
self.project = project
|
@@ -923,7 +923,7 @@ class BaseApiClient:
|
|
923
923
|
async with self._async_auth_lock:
|
924
924
|
if self._credentials.expired or not self._credentials.token:
|
925
925
|
# Double check that the credentials expired before refreshing.
|
926
|
-
await asyncio.to_thread(
|
926
|
+
await asyncio.to_thread(refresh_auth, self._credentials)
|
927
927
|
|
928
928
|
if not self._credentials.token:
|
929
929
|
raise RuntimeError('Could not resolve API token from the environment')
|
@@ -946,12 +946,12 @@ class BaseApiClient:
|
|
946
946
|
# patch the http options with the user provided settings.
|
947
947
|
if http_options:
|
948
948
|
if isinstance(http_options, HttpOptions):
|
949
|
-
patched_http_options =
|
949
|
+
patched_http_options = patch_http_options(
|
950
950
|
self._http_options,
|
951
951
|
http_options,
|
952
952
|
)
|
953
953
|
else:
|
954
|
-
patched_http_options =
|
954
|
+
patched_http_options = patch_http_options(
|
955
955
|
self._http_options, HttpOptions.model_validate(http_options)
|
956
956
|
)
|
957
957
|
else:
|
@@ -993,7 +993,7 @@ class BaseApiClient:
|
|
993
993
|
request_dict, patched_http_options.extra_body
|
994
994
|
)
|
995
995
|
|
996
|
-
url =
|
996
|
+
url = join_url_path(
|
997
997
|
base_url,
|
998
998
|
versioned_path,
|
999
999
|
)
|
@@ -1003,11 +1003,11 @@ class BaseApiClient:
|
|
1003
1003
|
'Ephemeral tokens can only be used with the live API.'
|
1004
1004
|
)
|
1005
1005
|
|
1006
|
-
timeout_in_seconds =
|
1006
|
+
timeout_in_seconds = get_timeout_in_seconds(patched_http_options.timeout)
|
1007
1007
|
|
1008
1008
|
if patched_http_options.headers is None:
|
1009
1009
|
raise ValueError('Request headers must be set.')
|
1010
|
-
|
1010
|
+
populate_server_timeout_header(
|
1011
1011
|
patched_http_options.headers, timeout_in_seconds
|
1012
1012
|
)
|
1013
1013
|
return HttpRequest(
|
@@ -1079,7 +1079,7 @@ class BaseApiClient:
|
|
1079
1079
|
)
|
1080
1080
|
# Support per request retry options.
|
1081
1081
|
if parameter_model.retry_options:
|
1082
|
-
retry_kwargs =
|
1082
|
+
retry_kwargs = retry_args(parameter_model.retry_options)
|
1083
1083
|
retry = tenacity.Retrying(**retry_kwargs)
|
1084
1084
|
return retry(self._request_once, http_request, stream) # type: ignore[no-any-return]
|
1085
1085
|
|
@@ -1239,7 +1239,7 @@ class BaseApiClient:
|
|
1239
1239
|
)
|
1240
1240
|
# Support per request retry options.
|
1241
1241
|
if parameter_model.retry_options:
|
1242
|
-
retry_kwargs =
|
1242
|
+
retry_kwargs = retry_args(parameter_model.retry_options)
|
1243
1243
|
retry = tenacity.AsyncRetrying(**retry_kwargs)
|
1244
1244
|
return await retry(self._async_request_once, http_request, stream) # type: ignore[no-any-return]
|
1245
1245
|
return await self._async_retry( # type: ignore[no-any-return]
|
@@ -1398,13 +1398,13 @@ class BaseApiClient:
|
|
1398
1398
|
if isinstance(self._http_options, dict)
|
1399
1399
|
else self._http_options.timeout
|
1400
1400
|
)
|
1401
|
-
timeout_in_seconds =
|
1401
|
+
timeout_in_seconds = get_timeout_in_seconds(timeout)
|
1402
1402
|
upload_headers = {
|
1403
1403
|
'X-Goog-Upload-Command': upload_command,
|
1404
1404
|
'X-Goog-Upload-Offset': str(offset),
|
1405
1405
|
'Content-Length': str(chunk_size),
|
1406
1406
|
}
|
1407
|
-
|
1407
|
+
populate_server_timeout_header(upload_headers, timeout_in_seconds)
|
1408
1408
|
retry_count = 0
|
1409
1409
|
while retry_count < MAX_RETRY_COUNT:
|
1410
1410
|
response = self._httpx_client.request(
|
@@ -1558,13 +1558,13 @@ class BaseApiClient:
|
|
1558
1558
|
if isinstance(self._http_options, dict)
|
1559
1559
|
else self._http_options.timeout
|
1560
1560
|
)
|
1561
|
-
timeout_in_seconds =
|
1561
|
+
timeout_in_seconds = get_timeout_in_seconds(timeout)
|
1562
1562
|
upload_headers = {
|
1563
1563
|
'X-Goog-Upload-Command': upload_command,
|
1564
1564
|
'X-Goog-Upload-Offset': str(offset),
|
1565
1565
|
'Content-Length': str(chunk_size),
|
1566
1566
|
}
|
1567
|
-
|
1567
|
+
populate_server_timeout_header(upload_headers, timeout_in_seconds)
|
1568
1568
|
|
1569
1569
|
retry_count = 0
|
1570
1570
|
response = None
|
@@ -1634,13 +1634,13 @@ class BaseApiClient:
|
|
1634
1634
|
if isinstance(self._http_options, dict)
|
1635
1635
|
else self._http_options.timeout
|
1636
1636
|
)
|
1637
|
-
timeout_in_seconds =
|
1637
|
+
timeout_in_seconds = get_timeout_in_seconds(timeout)
|
1638
1638
|
upload_headers = {
|
1639
1639
|
'X-Goog-Upload-Command': upload_command,
|
1640
1640
|
'X-Goog-Upload-Offset': str(offset),
|
1641
1641
|
'Content-Length': str(chunk_size),
|
1642
1642
|
}
|
1643
|
-
|
1643
|
+
populate_server_timeout_header(upload_headers, timeout_in_seconds)
|
1644
1644
|
|
1645
1645
|
retry_count = 0
|
1646
1646
|
client_response = None
|
@@ -30,6 +30,18 @@ if sys.version_info >= (3, 10):
|
|
30
30
|
else:
|
31
31
|
VersionedUnionType = typing._UnionGenericAlias # type: ignore[attr-defined]
|
32
32
|
|
33
|
+
|
34
|
+
# NOTE(review): these names keep a leading underscore yet are exported via
# __all__ — presumably so star-imports in sibling SDK modules still pick
# them up; confirm against callers before renaming any of them.
__all__ = [
    '_py_builtin_type_to_schema_type',
    '_raise_for_unsupported_param',
    '_handle_params_as_deferred_annotations',
    '_add_unevaluated_items_to_fixed_len_tuple_schema',
    '_is_builtin_primitive_or_compound',
    '_is_default_value_compatible',
    '_parse_schema_from_parameter',
    '_get_required_fields',
]
|
44
|
+
|
33
45
|
_py_builtin_type_to_schema_type = {
|
34
46
|
str: types.Type.STRING,
|
35
47
|
int: types.Type.INTEGER,
|
@@ -0,0 +1,26 @@
|
|
1
|
+
# Copyright 2025 Google LLC
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
#
|
15
|
+
|
16
|
+
"""Base transformers for Google GenAI SDK."""
|
17
|
+
import base64
|
18
|
+
|
19
|
+
# Some fields don't accept url safe base64 encoding.
|
20
|
+
# We shouldn't use this transformer if the backend adhere to Cloud Type
|
21
|
+
# format https://cloud.google.com/docs/discovery/type-format.
|
22
|
+
# TODO(b/389133914,b/390320301): Remove the hack after backend fix the issue.
|
23
|
+
def t_bytes(data: bytes) -> str:
  """Encodes raw bytes as standard (non-url-safe) base64 ASCII text.

  Values that are not ``bytes`` are passed through unchanged, so callers
  holding already-encoded strings do not get double-encoded.
  """
  if isinstance(data, bytes):
    return base64.b64encode(data).decode('ascii')
  return data
|
google/genai/_live_converters.py
CHANGED
@@ -0,0 +1,223 @@
|
|
1
|
+
# Copyright 2025 Google LLC
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
#
|
15
|
+
|
16
|
+
import dataclasses
|
17
|
+
import functools
|
18
|
+
import hashlib
|
19
|
+
import os
|
20
|
+
import tempfile
|
21
|
+
from typing import Optional, cast
|
22
|
+
import uuid
|
23
|
+
|
24
|
+
import requests # type: ignore
|
25
|
+
import sentencepiece as spm
|
26
|
+
from sentencepiece import sentencepiece_model_pb2
|
27
|
+
|
28
|
+
|
29
|
+
# Source of truth: https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models
# Maps unversioned Gemini model names to the tokenizer family ("gemma2" or
# "gemma3") whose sentencepiece model is used for local token counting.
_GEMINI_MODELS_TO_TOKENIZER_NAMES = {
    "gemini-1.0-pro": "gemma2",
    "gemini-1.5-pro": "gemma2",
    "gemini-1.5-flash": "gemma2",
    "gemini-2.5-pro": "gemma3",
    "gemini-2.5-flash": "gemma3",
    "gemini-2.5-flash-lite": "gemma3",
    "gemini-2.0-flash": "gemma3",
    "gemini-2.0-flash-lite": "gemma3",
}
# Same mapping for fully-versioned (stable/preview/exp) model names.
_GEMINI_STABLE_MODELS_TO_TOKENIZER_NAMES = {
    "gemini-1.0-pro-001": "gemma2",
    "gemini-1.0-pro-002": "gemma2",
    "gemini-1.5-pro-001": "gemma2",
    "gemini-1.5-flash-001": "gemma2",
    "gemini-1.5-flash-002": "gemma2",
    "gemini-1.5-pro-002": "gemma2",
    "gemini-2.5-pro-preview-06-05": "gemma3",
    "gemini-2.5-pro-preview-05-06": "gemma3",
    "gemini-2.5-pro-exp-03-25": "gemma3",
    "gemini-live-2.5-flash": "gemma3",
    "gemini-2.5-flash-preview-05-20": "gemma3",
    "gemini-2.5-flash-preview-04-17": "gemma3",
    "gemini-2.5-flash-lite-preview-06-17": "gemma3",
    "gemini-2.0-flash-001": "gemma3",
    "gemini-2.0-flash-lite-001": "gemma3",
}
|
57
|
+
|
58
|
+
|
59
|
+
@dataclasses.dataclass(frozen=True)
class _TokenizerConfig:
  """Download location and integrity hash for one tokenizer model."""

  # URL where the serialized sentencepiece tokenizer model can be fetched.
  model_url: str
  # Expected SHA-256 hex digest of the downloaded model file (checked by
  # _is_valid_model).
  model_hash: str
|
63
|
+
|
64
|
+
|
65
|
+
# TODO: update gemma3 tokenizer
# Registry of supported tokenizer families. Keys are the values of the
# model-name maps above; each entry pins a tokenizer model file to an exact
# git revision and its SHA-256 hash.
_TOKENIZERS = {
    "gemma2": _TokenizerConfig(
        model_url="https://raw.githubusercontent.com/google/gemma_pytorch/33b652c465537c6158f9a472ea5700e5e770ad3f/tokenizer/tokenizer.model",
        model_hash=(
            "61a7b147390c64585d6c3543dd6fc636906c9af3865a5548f27f31aee1d4c8e2"
        ),
    ),
    "gemma3": _TokenizerConfig(
        model_url="https://raw.githubusercontent.com/google/gemma_pytorch/014acb7ac4563a5f77c76d7ff98f31b568c16508/tokenizer/gemma3_cleaned_262144_v2.spiece.model",
        model_hash=(
            "1299c11d7cf632ef3b4e11937501358ada021bbdf7c47638d13c0ee982f2e79c"
        ),
    ),
}
|
80
|
+
|
81
|
+
|
82
|
+
def _load_file(file_url_path: str) -> bytes:
  """Downloads and returns the raw bytes at ``file_url_path``.

  Args:
    file_url_path: URL of the file to fetch.

  Returns:
    The response body as bytes.

  Raises:
    requests.HTTPError: If the server responds with a non-2xx status.
  """
  # A bounded timeout keeps a hung server from blocking tokenizer loading
  # forever (requests has no default timeout); 300s is generous for a
  # multi-megabyte tokenizer model download.
  resp = requests.get(file_url_path, timeout=300)
  resp.raise_for_status()
  return cast(bytes, resp.content)
|
87
|
+
|
88
|
+
|
89
|
+
def _is_valid_model(*, model_data: bytes, expected_hash: str) -> bool:
|
90
|
+
"""Returns true if the content is valid by checking the hash."""
|
91
|
+
if not expected_hash:
|
92
|
+
raise ValueError("expected_hash is required")
|
93
|
+
return hashlib.sha256(model_data).hexdigest() == expected_hash
|
94
|
+
|
95
|
+
|
96
|
+
def _maybe_remove_file(file_path: str) -> None:
|
97
|
+
"""Removes the file if exists."""
|
98
|
+
if not os.path.exists(file_path):
|
99
|
+
return
|
100
|
+
try:
|
101
|
+
os.remove(file_path)
|
102
|
+
except OSError:
|
103
|
+
# Don't raise if we cannot remove file.
|
104
|
+
pass
|
105
|
+
|
106
|
+
|
107
|
+
def _maybe_load_from_cache(
    *, file_path: str, expected_hash: str
) -> Optional[bytes]:
  """Returns cached bytes at ``file_path`` if present and hash-valid.

  A cached file whose SHA-256 does not match ``expected_hash`` is treated
  as corrupted: it is removed (best effort) and None is returned.
  """
  if not os.path.exists(file_path):
    return None
  with open(file_path, "rb") as cache_file:
    cached = cache_file.read()
  if _is_valid_model(model_data=cached, expected_hash=expected_hash):
    return cached

  # Cached file corrupted.
  _maybe_remove_file(file_path)
  return None
|
121
|
+
|
122
|
+
|
123
|
+
def _maybe_save_to_cache(
|
124
|
+
*, cache_dir: str, cache_path: str, content: bytes
|
125
|
+
) -> None:
|
126
|
+
"""Saves the content to the cache path."""
|
127
|
+
try:
|
128
|
+
os.makedirs(cache_dir, exist_ok=True)
|
129
|
+
tmp_path = cache_dir + "." + str(uuid.uuid4()) + ".tmp"
|
130
|
+
with open(tmp_path, "wb") as f:
|
131
|
+
f.write(content)
|
132
|
+
os.rename(tmp_path, cache_path)
|
133
|
+
except OSError:
|
134
|
+
# Don't raise if we cannot write file.
|
135
|
+
pass
|
136
|
+
|
137
|
+
|
138
|
+
def _load_from_url(*, file_url: str, expected_hash: str) -> bytes:
  """Downloads the model from ``file_url`` and verifies its SHA-256 hash.

  Raises:
    ValueError: If the downloaded bytes do not match ``expected_hash``.
  """
  downloaded = _load_file(file_url)
  if _is_valid_model(model_data=downloaded, expected_hash=expected_hash):
    return downloaded
  actual_hash = hashlib.sha256(downloaded).hexdigest()
  raise ValueError(
      "Downloaded model file is corrupted."
      f" Expected hash {expected_hash}. Got file hash {actual_hash}."
  )
|
148
|
+
|
149
|
+
|
150
|
+
def _load(*, file_url: str, expected_hash: str) -> bytes:
  """Loads model bytes for ``file_url``, using a temp-dir cache.

  1. If a locally cached file exists for the given url and its hash matches
     the expected hash, the cached bytes are returned.
  2. Otherwise the file is downloaded from the url, written back to the
     local cache (best effort), and its bytes returned.
  3. If the downloaded file does not match the expected hash, ValueError
     is raised.

  Args:
    file_url: The url of the file to load.
    expected_hash: The expected hash of the file.

  Returns:
    The file bytes.
  """
  cache_dir = os.path.join(tempfile.gettempdir(), "vertexai_tokenizer_model")
  # sha1 only derives a stable cache *file name* from the url; integrity is
  # enforced separately through the sha256 expected_hash.
  cache_name = hashlib.sha1(file_url.encode()).hexdigest()
  cache_path = os.path.join(cache_dir, cache_name)

  data = _maybe_load_from_cache(
      file_path=cache_path, expected_hash=expected_hash
  )
  if not data:
    data = _load_from_url(file_url=file_url, expected_hash=expected_hash)
    _maybe_save_to_cache(
        cache_dir=cache_dir, cache_path=cache_path, content=data
    )
  return data
|
182
|
+
|
183
|
+
|
184
|
+
def _load_model_proto_bytes(tokenizer_name: str) -> bytes:
  """Returns the serialized sentencepiece model for ``tokenizer_name``.

  Raises:
    ValueError: If the tokenizer name is not in the supported registry.
  """
  config = _TOKENIZERS.get(tokenizer_name)
  if config is None:
    raise ValueError(
        f"Tokenizer {tokenizer_name} is not supported."
        f"Supported tokenizers: {list(_TOKENIZERS.keys())}"
    )
  return _load(file_url=config.model_url, expected_hash=config.model_hash)
|
195
|
+
|
196
|
+
|
197
|
+
@functools.lru_cache()
def load_model_proto(
    tokenizer_name: str,
) -> sentencepiece_model_pb2.ModelProto:
  """Parses (and memoizes) the sentencepiece ModelProto for a tokenizer."""
  proto = sentencepiece_model_pb2.ModelProto()
  proto.ParseFromString(_load_model_proto_bytes(tokenizer_name))
  return proto
|
205
|
+
|
206
|
+
|
207
|
+
def get_tokenizer_name(model_name: str) -> str:
  """Gets the tokenizer name for the given model name.

  Args:
    model_name: A Gemini model name, either unversioned (e.g.
      "gemini-2.5-flash") or fully versioned (e.g. "gemini-2.0-flash-001").

  Returns:
    The tokenizer family name used by the model ("gemma2" or "gemma3").

  Raises:
    ValueError: If the model is not in either supported-model table.
  """
  # Membership test directly on the dict — `in d.keys()` builds an
  # unnecessary keys view for the same semantics.
  if model_name in _GEMINI_MODELS_TO_TOKENIZER_NAMES:
    return _GEMINI_MODELS_TO_TOKENIZER_NAMES[model_name]
  if model_name in _GEMINI_STABLE_MODELS_TO_TOKENIZER_NAMES:
    return _GEMINI_STABLE_MODELS_TO_TOKENIZER_NAMES[model_name]
  raise ValueError(
      f"Model {model_name} is not supported. Supported models: {', '.join(_GEMINI_MODELS_TO_TOKENIZER_NAMES.keys())}, {', '.join(_GEMINI_STABLE_MODELS_TO_TOKENIZER_NAMES.keys())}.\n"  # pylint: disable=line-too-long
  )
|
216
|
+
|
217
|
+
|
218
|
+
@functools.lru_cache()
def get_sentencepiece(tokenizer_name: str) -> spm.SentencePieceProcessor:
  """Builds (and memoizes) a SentencePieceProcessor for a tokenizer name."""
  sp_processor = spm.SentencePieceProcessor()
  sp_processor.LoadFromSerializedProto(_load_model_proto_bytes(tokenizer_name))
  return sp_processor