google-genai 1.30.0__py3-none-any.whl → 1.32.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -91,7 +91,7 @@ class EphemeralTokenAPIKeyError(ValueError):
91
91
 
92
92
  # This method checks for the API key in the environment variables. Google API
93
93
  # key is precedenced over Gemini API key.
94
- def _get_env_api_key() -> Optional[str]:
94
+ def get_env_api_key() -> Optional[str]:
95
95
  """Gets the API key from environment variables, prioritizing GOOGLE_API_KEY.
96
96
 
97
97
  Returns:
@@ -108,7 +108,7 @@ def _get_env_api_key() -> Optional[str]:
108
108
  return env_google_api_key or env_gemini_api_key or None
109
109
 
110
110
 
111
- def _append_library_version_headers(headers: dict[str, str]) -> None:
111
+ def append_library_version_headers(headers: dict[str, str]) -> None:
112
112
  """Appends the telemetry header to the headers dict."""
113
113
  library_label = f'google-genai-sdk/{version.__version__}'
114
114
  language_label = 'gl-python/' + sys.version.split()[0]
@@ -131,7 +131,7 @@ def _append_library_version_headers(headers: dict[str, str]) -> None:
131
131
  headers['x-goog-api-client'] = version_header_value
132
132
 
133
133
 
134
- def _patch_http_options(
134
+ def patch_http_options(
135
135
  options: HttpOptions, patch_options: HttpOptions
136
136
  ) -> HttpOptions:
137
137
  copy_option = options.model_copy()
@@ -155,11 +155,11 @@ def _patch_http_options(
155
155
  setattr(copy_option, key, getattr(options, key))
156
156
 
157
157
  if copy_option.headers is not None:
158
- _append_library_version_headers(copy_option.headers)
158
+ append_library_version_headers(copy_option.headers)
159
159
  return copy_option
160
160
 
161
161
 
162
- def _populate_server_timeout_header(
162
+ def populate_server_timeout_header(
163
163
  headers: dict[str, str], timeout_in_seconds: Optional[Union[float, int]]
164
164
  ) -> None:
165
165
  """Populates the server timeout header in the headers dict."""
@@ -167,7 +167,7 @@ def _populate_server_timeout_header(
167
167
  headers['X-Server-Timeout'] = str(math.ceil(timeout_in_seconds))
168
168
 
169
169
 
170
- def _join_url_path(base_url: str, path: str) -> str:
170
+ def join_url_path(base_url: str, path: str) -> str:
171
171
  parsed_base = urlparse(base_url)
172
172
  base_path = (
173
173
  parsed_base.path[:-1]
@@ -178,7 +178,7 @@ def _join_url_path(base_url: str, path: str) -> str:
178
178
  return urlunparse(parsed_base._replace(path=base_path + '/' + path))
179
179
 
180
180
 
181
- def _load_auth(*, project: Union[str, None]) -> Tuple[Credentials, str]:
181
+ def load_auth(*, project: Union[str, None]) -> Tuple[Credentials, str]:
182
182
  """Loads google auth credentials and project id."""
183
183
  credentials, loaded_project_id = google.auth.default( # type: ignore[no-untyped-call]
184
184
  scopes=['https://www.googleapis.com/auth/cloud-platform'],
@@ -195,12 +195,12 @@ def _load_auth(*, project: Union[str, None]) -> Tuple[Credentials, str]:
195
195
  return credentials, project
196
196
 
197
197
 
198
- def _refresh_auth(credentials: Credentials) -> Credentials:
198
+ def refresh_auth(credentials: Credentials) -> Credentials:
199
199
  credentials.refresh(Request()) # type: ignore[no-untyped-call]
200
200
  return credentials
201
201
 
202
202
 
203
- def _get_timeout_in_seconds(
203
+ def get_timeout_in_seconds(
204
204
  timeout: Optional[Union[float, int]],
205
205
  ) -> Optional[float]:
206
206
  """Converts the timeout to seconds."""
@@ -454,7 +454,7 @@ _RETRY_HTTP_STATUS_CODES = (
454
454
  )
455
455
 
456
456
 
457
- def _retry_args(options: Optional[HttpRetryOptions]) -> _common.StringDict:
457
+ def retry_args(options: Optional[HttpRetryOptions]) -> _common.StringDict:
458
458
  """Returns the retry args for the given http retry options.
459
459
 
460
460
  Args:
@@ -574,7 +574,7 @@ class BaseApiClient:
574
574
  # Retrieve implicitly set values from the environment.
575
575
  env_project = os.environ.get('GOOGLE_CLOUD_PROJECT', None)
576
576
  env_location = os.environ.get('GOOGLE_CLOUD_LOCATION', None)
577
- env_api_key = _get_env_api_key()
577
+ env_api_key = get_env_api_key()
578
578
  self.project = project or env_project
579
579
  self.location = location or env_location
580
580
  self.api_key = api_key or env_api_key
@@ -631,7 +631,7 @@ class BaseApiClient:
631
631
  and not self.api_key
632
632
  and not validated_http_options.base_url
633
633
  ):
634
- credentials, self.project = _load_auth(project=None)
634
+ credentials, self.project = load_auth(project=None)
635
635
  if not self._credentials:
636
636
  self._credentials = credentials
637
637
 
@@ -670,12 +670,12 @@ class BaseApiClient:
670
670
  self._http_options.headers['x-goog-api-key'] = self.api_key
671
671
  # Update the http options with the user provided http options.
672
672
  if http_options:
673
- self._http_options = _patch_http_options(
673
+ self._http_options = patch_http_options(
674
674
  self._http_options, validated_http_options
675
675
  )
676
676
  else:
677
677
  if self._http_options.headers is not None:
678
- _append_library_version_headers(self._http_options.headers)
678
+ append_library_version_headers(self._http_options.headers)
679
679
 
680
680
  client_args, async_client_args = self._ensure_httpx_ssl_ctx(
681
681
  self._http_options
@@ -689,7 +689,7 @@ class BaseApiClient:
689
689
  )
690
690
  self._websocket_ssl_ctx = self._ensure_websocket_ssl_ctx(self._http_options)
691
691
 
692
- retry_kwargs = _retry_args(self._http_options.retry_options)
692
+ retry_kwargs = retry_args(self._http_options.retry_options)
693
693
  self._retry = tenacity.Retrying(**retry_kwargs)
694
694
  self._async_retry = tenacity.AsyncRetrying(**retry_kwargs)
695
695
 
@@ -889,14 +889,14 @@ class BaseApiClient:
889
889
  """Retrieves the access token for the credentials."""
890
890
  with self._sync_auth_lock:
891
891
  if not self._credentials:
892
- self._credentials, project = _load_auth(project=self.project)
892
+ self._credentials, project = load_auth(project=self.project)
893
893
  if not self.project:
894
894
  self.project = project
895
895
 
896
896
  if self._credentials:
897
897
  if self._credentials.expired or not self._credentials.token:
898
898
  # Only refresh when it needs to. Default expiration is 3600 seconds.
899
- _refresh_auth(self._credentials)
899
+ refresh_auth(self._credentials)
900
900
  if not self._credentials.token:
901
901
  raise RuntimeError('Could not resolve API token from the environment')
902
902
  return self._credentials.token # type: ignore[no-any-return]
@@ -912,7 +912,7 @@ class BaseApiClient:
912
912
  if not self._credentials:
913
913
  # Double check that the credentials are not set before loading them.
914
914
  self._credentials, project = await asyncio.to_thread(
915
- _load_auth, project=self.project
915
+ load_auth, project=self.project
916
916
  )
917
917
  if not self.project:
918
918
  self.project = project
@@ -923,7 +923,7 @@ class BaseApiClient:
923
923
  async with self._async_auth_lock:
924
924
  if self._credentials.expired or not self._credentials.token:
925
925
  # Double check that the credentials expired before refreshing.
926
- await asyncio.to_thread(_refresh_auth, self._credentials)
926
+ await asyncio.to_thread(refresh_auth, self._credentials)
927
927
 
928
928
  if not self._credentials.token:
929
929
  raise RuntimeError('Could not resolve API token from the environment')
@@ -946,12 +946,12 @@ class BaseApiClient:
946
946
  # patch the http options with the user provided settings.
947
947
  if http_options:
948
948
  if isinstance(http_options, HttpOptions):
949
- patched_http_options = _patch_http_options(
949
+ patched_http_options = patch_http_options(
950
950
  self._http_options,
951
951
  http_options,
952
952
  )
953
953
  else:
954
- patched_http_options = _patch_http_options(
954
+ patched_http_options = patch_http_options(
955
955
  self._http_options, HttpOptions.model_validate(http_options)
956
956
  )
957
957
  else:
@@ -993,7 +993,7 @@ class BaseApiClient:
993
993
  request_dict, patched_http_options.extra_body
994
994
  )
995
995
 
996
- url = _join_url_path(
996
+ url = join_url_path(
997
997
  base_url,
998
998
  versioned_path,
999
999
  )
@@ -1003,11 +1003,11 @@ class BaseApiClient:
1003
1003
  'Ephemeral tokens can only be used with the live API.'
1004
1004
  )
1005
1005
 
1006
- timeout_in_seconds = _get_timeout_in_seconds(patched_http_options.timeout)
1006
+ timeout_in_seconds = get_timeout_in_seconds(patched_http_options.timeout)
1007
1007
 
1008
1008
  if patched_http_options.headers is None:
1009
1009
  raise ValueError('Request headers must be set.')
1010
- _populate_server_timeout_header(
1010
+ populate_server_timeout_header(
1011
1011
  patched_http_options.headers, timeout_in_seconds
1012
1012
  )
1013
1013
  return HttpRequest(
@@ -1079,7 +1079,7 @@ class BaseApiClient:
1079
1079
  )
1080
1080
  # Support per request retry options.
1081
1081
  if parameter_model.retry_options:
1082
- retry_kwargs = _retry_args(parameter_model.retry_options)
1082
+ retry_kwargs = retry_args(parameter_model.retry_options)
1083
1083
  retry = tenacity.Retrying(**retry_kwargs)
1084
1084
  return retry(self._request_once, http_request, stream) # type: ignore[no-any-return]
1085
1085
 
@@ -1239,7 +1239,7 @@ class BaseApiClient:
1239
1239
  )
1240
1240
  # Support per request retry options.
1241
1241
  if parameter_model.retry_options:
1242
- retry_kwargs = _retry_args(parameter_model.retry_options)
1242
+ retry_kwargs = retry_args(parameter_model.retry_options)
1243
1243
  retry = tenacity.AsyncRetrying(**retry_kwargs)
1244
1244
  return await retry(self._async_request_once, http_request, stream) # type: ignore[no-any-return]
1245
1245
  return await self._async_retry( # type: ignore[no-any-return]
@@ -1398,13 +1398,13 @@ class BaseApiClient:
1398
1398
  if isinstance(self._http_options, dict)
1399
1399
  else self._http_options.timeout
1400
1400
  )
1401
- timeout_in_seconds = _get_timeout_in_seconds(timeout)
1401
+ timeout_in_seconds = get_timeout_in_seconds(timeout)
1402
1402
  upload_headers = {
1403
1403
  'X-Goog-Upload-Command': upload_command,
1404
1404
  'X-Goog-Upload-Offset': str(offset),
1405
1405
  'Content-Length': str(chunk_size),
1406
1406
  }
1407
- _populate_server_timeout_header(upload_headers, timeout_in_seconds)
1407
+ populate_server_timeout_header(upload_headers, timeout_in_seconds)
1408
1408
  retry_count = 0
1409
1409
  while retry_count < MAX_RETRY_COUNT:
1410
1410
  response = self._httpx_client.request(
@@ -1558,13 +1558,13 @@ class BaseApiClient:
1558
1558
  if isinstance(self._http_options, dict)
1559
1559
  else self._http_options.timeout
1560
1560
  )
1561
- timeout_in_seconds = _get_timeout_in_seconds(timeout)
1561
+ timeout_in_seconds = get_timeout_in_seconds(timeout)
1562
1562
  upload_headers = {
1563
1563
  'X-Goog-Upload-Command': upload_command,
1564
1564
  'X-Goog-Upload-Offset': str(offset),
1565
1565
  'Content-Length': str(chunk_size),
1566
1566
  }
1567
- _populate_server_timeout_header(upload_headers, timeout_in_seconds)
1567
+ populate_server_timeout_header(upload_headers, timeout_in_seconds)
1568
1568
 
1569
1569
  retry_count = 0
1570
1570
  response = None
@@ -1634,13 +1634,13 @@ class BaseApiClient:
1634
1634
  if isinstance(self._http_options, dict)
1635
1635
  else self._http_options.timeout
1636
1636
  )
1637
- timeout_in_seconds = _get_timeout_in_seconds(timeout)
1637
+ timeout_in_seconds = get_timeout_in_seconds(timeout)
1638
1638
  upload_headers = {
1639
1639
  'X-Goog-Upload-Command': upload_command,
1640
1640
  'X-Goog-Upload-Offset': str(offset),
1641
1641
  'Content-Length': str(chunk_size),
1642
1642
  }
1643
- _populate_server_timeout_header(upload_headers, timeout_in_seconds)
1643
+ populate_server_timeout_header(upload_headers, timeout_in_seconds)
1644
1644
 
1645
1645
  retry_count = 0
1646
1646
  client_response = None
@@ -30,6 +30,18 @@ if sys.version_info >= (3, 10):
30
30
  else:
31
31
  VersionedUnionType = typing._UnionGenericAlias # type: ignore[attr-defined]
32
32
 
33
+
34
# NOTE(review): underscore-prefixed (conventionally private) helpers are
# deliberately listed in __all__ so that wildcard imports elsewhere in the
# SDK can reach them — confirm against the consuming modules before
# renaming or removing any entry.
__all__ = [
    '_py_builtin_type_to_schema_type',
    '_raise_for_unsupported_param',
    '_handle_params_as_deferred_annotations',
    '_add_unevaluated_items_to_fixed_len_tuple_schema',
    '_is_builtin_primitive_or_compound',
    '_is_default_value_compatible',
    '_parse_schema_from_parameter',
    '_get_required_fields',
]
44
+
33
45
  _py_builtin_type_to_schema_type = {
34
46
  str: types.Type.STRING,
35
47
  int: types.Type.INTEGER,
@@ -0,0 +1,26 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+
16
+ """Base transformers for Google GenAI SDK."""
17
+ import base64
18
+
19
# Some fields don't accept url-safe base64 encoding.
# This transformer should not be used once the backend adheres to the Cloud
# Type format: https://cloud.google.com/docs/discovery/type-format.
# TODO(b/389133914,b/390320301): Remove the hack after backend fix the issue.
def t_bytes(data: bytes) -> str:
  """Encodes raw bytes as standard (non-url-safe) base64 ASCII text."""
  if isinstance(data, bytes):
    return base64.b64encode(data).decode('ascii')
  # Non-bytes values pass through untouched.
  return data
@@ -16,6 +16,7 @@
16
16
  # Code generated by the Google Gen AI SDK generator DO NOT EDIT.
17
17
 
18
18
  from typing import Any, Optional, Union
19
+
19
20
  from . import _transformers as t
20
21
  from ._api_client import BaseApiClient
21
22
  from ._common import get_value_by_path as getv
@@ -0,0 +1,223 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+
16
+ import dataclasses
17
+ import functools
18
+ import hashlib
19
+ import os
20
+ import tempfile
21
+ from typing import Optional, cast
22
+ import uuid
23
+
24
+ import requests # type: ignore
25
+ import sentencepiece as spm
26
+ from sentencepiece import sentencepiece_model_pb2
27
+
28
+
29
# Source of truth: https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models
# Maps a Gemini model *family* name (no version suffix) to the name of the
# tokenizer (key of _TOKENIZERS) that model family uses.
_GEMINI_MODELS_TO_TOKENIZER_NAMES = {
    "gemini-1.0-pro": "gemma2",
    "gemini-1.5-pro": "gemma2",
    "gemini-1.5-flash": "gemma2",
    "gemini-2.5-pro": "gemma3",
    "gemini-2.5-flash": "gemma3",
    "gemini-2.5-flash-lite": "gemma3",
    "gemini-2.0-flash": "gemma3",
    "gemini-2.0-flash-lite": "gemma3",
}
# Same mapping for fully-versioned (stable/preview) model names; consulted
# as a fallback when the family-name lookup above misses.
_GEMINI_STABLE_MODELS_TO_TOKENIZER_NAMES = {
    "gemini-1.0-pro-001": "gemma2",
    "gemini-1.0-pro-002": "gemma2",
    "gemini-1.5-pro-001": "gemma2",
    "gemini-1.5-flash-001": "gemma2",
    "gemini-1.5-flash-002": "gemma2",
    "gemini-1.5-pro-002": "gemma2",
    "gemini-2.5-pro-preview-06-05": "gemma3",
    "gemini-2.5-pro-preview-05-06": "gemma3",
    "gemini-2.5-pro-exp-03-25": "gemma3",
    "gemini-live-2.5-flash": "gemma3",
    "gemini-2.5-flash-preview-05-20": "gemma3",
    "gemini-2.5-flash-preview-04-17": "gemma3",
    "gemini-2.5-flash-lite-preview-06-17": "gemma3",
    "gemini-2.0-flash-001": "gemma3",
    "gemini-2.0-flash-lite-001": "gemma3",
}
57
+
58
+
59
@dataclasses.dataclass(frozen=True)
class _TokenizerConfig:
  """Immutable download location and integrity info for a tokenizer model."""

  # URL of the serialized sentencepiece model file.
  model_url: str
  # Expected SHA-256 hex digest of the file at model_url.
  model_hash: str
63
+
64
+
65
# TODO: update gemma3 tokenizer
# Maps a tokenizer name to where its serialized sentencepiece model can be
# downloaded and the SHA-256 digest used to validate the download/cache.
_TOKENIZERS = {
    "gemma2": _TokenizerConfig(
        model_url="https://raw.githubusercontent.com/google/gemma_pytorch/33b652c465537c6158f9a472ea5700e5e770ad3f/tokenizer/tokenizer.model",
        model_hash=(
            "61a7b147390c64585d6c3543dd6fc636906c9af3865a5548f27f31aee1d4c8e2"
        ),
    ),
    "gemma3": _TokenizerConfig(
        model_url="https://raw.githubusercontent.com/google/gemma_pytorch/014acb7ac4563a5f77c76d7ff98f31b568c16508/tokenizer/gemma3_cleaned_262144_v2.spiece.model",
        model_hash=(
            "1299c11d7cf632ef3b4e11937501358ada021bbdf7c47638d13c0ee982f2e79c"
        ),
    ),
}
80
+
81
+
82
def _load_file(file_url_path: str) -> bytes:
  """Fetches the resource at file_url_path and returns its raw bytes.

  Raises:
    requests.HTTPError: If the server responds with an error status.
  """
  response = requests.get(file_url_path)
  response.raise_for_status()
  body = cast(bytes, response.content)
  return body
87
+
88
+
89
+ def _is_valid_model(*, model_data: bytes, expected_hash: str) -> bool:
90
+ """Returns true if the content is valid by checking the hash."""
91
+ if not expected_hash:
92
+ raise ValueError("expected_hash is required")
93
+ return hashlib.sha256(model_data).hexdigest() == expected_hash
94
+
95
+
96
+ def _maybe_remove_file(file_path: str) -> None:
97
+ """Removes the file if exists."""
98
+ if not os.path.exists(file_path):
99
+ return
100
+ try:
101
+ os.remove(file_path)
102
+ except OSError:
103
+ # Don't raise if we cannot remove file.
104
+ pass
105
+
106
+
107
def _maybe_load_from_cache(
    *, file_path: str, expected_hash: str
) -> Optional[bytes]:
  """Returns cached bytes at file_path when present and hash-valid.

  A cached file whose digest does not match expected_hash is treated as
  corrupted: it is removed (best-effort) and None is returned so the caller
  falls back to downloading.
  """
  if not os.path.exists(file_path):
    return None
  with open(file_path, "rb") as cache_file:
    cached = cache_file.read()
  if not _is_valid_model(model_data=cached, expected_hash=expected_hash):
    # Cached file corrupted — drop it and report a miss.
    _maybe_remove_file(file_path)
    return None
  return cached
121
+
122
+
123
+ def _maybe_save_to_cache(
124
+ *, cache_dir: str, cache_path: str, content: bytes
125
+ ) -> None:
126
+ """Saves the content to the cache path."""
127
+ try:
128
+ os.makedirs(cache_dir, exist_ok=True)
129
+ tmp_path = cache_dir + "." + str(uuid.uuid4()) + ".tmp"
130
+ with open(tmp_path, "wb") as f:
131
+ f.write(content)
132
+ os.rename(tmp_path, cache_path)
133
+ except OSError:
134
+ # Don't raise if we cannot write file.
135
+ pass
136
+
137
+
138
def _load_from_url(*, file_url: str, expected_hash: str) -> bytes:
  """Downloads the model from file_url and verifies its SHA-256 digest.

  Raises:
    ValueError: If the downloaded bytes do not match expected_hash.
  """
  payload = _load_file(file_url)
  if _is_valid_model(model_data=payload, expected_hash=expected_hash):
    return payload
  actual_hash = hashlib.sha256(payload).hexdigest()
  raise ValueError(
      "Downloaded model file is corrupted."
      f" Expected hash {expected_hash}. Got file hash {actual_hash}."
  )
148
+
149
+
150
def _load(*, file_url: str, expected_hash: str) -> bytes:
  """Loads model bytes for the given file url, using a validated local cache.

  1. If a local cached file exists for the given url and its hash matches
     the expected hash, the cached bytes are returned.
  2. Otherwise the file is downloaded from the given url, written to the
     local cache (best-effort), and its bytes are returned.
  3. If the file downloaded from the given url does not match the expected
     hash, ValueError is raised (by _load_from_url).

  Args:
    file_url: The url of the file to load.
    expected_hash: The expected SHA-256 hex digest of the file.

  Returns:
    The file bytes.
  """
  model_dir = os.path.join(tempfile.gettempdir(), "vertexai_tokenizer_model")
  # sha1 is only used to derive a stable cache filename from the url;
  # integrity is checked with SHA-256 via expected_hash.
  filename = hashlib.sha1(file_url.encode()).hexdigest()
  model_path = os.path.join(model_dir, filename)

  model_data = _maybe_load_from_cache(
      file_path=model_path, expected_hash=expected_hash
  )
  if model_data:
    # Cache hit: previously the file was re-written to the cache even here;
    # skipping the redundant save avoids pointless disk churn.
    return model_data

  model_data = _load_from_url(file_url=file_url, expected_hash=expected_hash)
  _maybe_save_to_cache(
      cache_dir=model_dir, cache_path=model_path, content=model_data
  )
  return model_data
182
+
183
+
184
def _load_model_proto_bytes(tokenizer_name: str) -> bytes:
  """Loads serialized sentencepiece model bytes for the given tokenizer name.

  Args:
    tokenizer_name: One of the keys of _TOKENIZERS (e.g. "gemma2").

  Returns:
    The serialized model file bytes (downloaded or cache-loaded).

  Raises:
    ValueError: If tokenizer_name is not a supported tokenizer.
  """
  if tokenizer_name not in _TOKENIZERS:
    # Fix: the two message fragments previously concatenated without a
    # separator ("...is not supported.Supported tokenizers: ...").
    raise ValueError(
        f"Tokenizer {tokenizer_name} is not supported."
        f" Supported tokenizers: {list(_TOKENIZERS.keys())}"
    )
  config = _TOKENIZERS[tokenizer_name]
  return _load(file_url=config.model_url, expected_hash=config.model_hash)
195
+
196
+
197
@functools.lru_cache()
def load_model_proto(
    tokenizer_name: str,
) -> sentencepiece_model_pb2.ModelProto:
  """Loads and parses the sentencepiece ModelProto for tokenizer_name.

  Results are memoized per tokenizer name via lru_cache.
  """
  serialized = _load_model_proto_bytes(tokenizer_name)
  proto = sentencepiece_model_pb2.ModelProto()
  proto.ParseFromString(serialized)
  return proto
205
+
206
+
207
def get_tokenizer_name(model_name: str) -> str:
  """Gets the tokenizer name for the given model name.

  Looks up the model-family map first, then the fully-versioned model map.

  Args:
    model_name: A Gemini model name, e.g. "gemini-2.0-flash".

  Returns:
    The tokenizer name (a key of _TOKENIZERS).

  Raises:
    ValueError: If model_name is not a known Gemini model.
  """
  # Idiom fix: membership is tested on the dict directly instead of the
  # redundant `.keys()` view.
  if model_name in _GEMINI_MODELS_TO_TOKENIZER_NAMES:
    return _GEMINI_MODELS_TO_TOKENIZER_NAMES[model_name]
  if model_name in _GEMINI_STABLE_MODELS_TO_TOKENIZER_NAMES:
    return _GEMINI_STABLE_MODELS_TO_TOKENIZER_NAMES[model_name]
  raise ValueError(
      f"Model {model_name} is not supported. Supported models: {', '.join(_GEMINI_MODELS_TO_TOKENIZER_NAMES.keys())}, {', '.join(_GEMINI_STABLE_MODELS_TO_TOKENIZER_NAMES.keys())}.\n"  # pylint: disable=line-too-long
  )
216
+
217
+
218
@functools.lru_cache()
def get_sentencepiece(tokenizer_name: str) -> spm.SentencePieceProcessor:
  """Builds a memoized SentencePieceProcessor for the given tokenizer name."""
  serialized_model = _load_model_proto_bytes(tokenizer_name)
  sp_processor = spm.SentencePieceProcessor()
  sp_processor.LoadFromSerializedProto(serialized_model)
  return sp_processor