google-genai 1.31.0__tar.gz → 1.33.0__tar.gz
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- {google_genai-1.31.0/google_genai.egg-info → google_genai-1.33.0}/PKG-INFO +6 -6
- {google_genai-1.31.0 → google_genai-1.33.0}/README.md +5 -5
- {google_genai-1.31.0 → google_genai-1.33.0}/google/genai/_api_client.py +32 -9
- google_genai-1.33.0/google/genai/_base_transformers.py +26 -0
- google_genai-1.33.0/google/genai/_local_tokenizer_loader.py +223 -0
- google_genai-1.33.0/google/genai/_operations_converters.py +307 -0
- {google_genai-1.31.0 → google_genai-1.33.0}/google/genai/_replay_api_client.py +15 -0
- {google_genai-1.31.0 → google_genai-1.33.0}/google/genai/_transformers.py +0 -10
- {google_genai-1.31.0 → google_genai-1.33.0}/google/genai/caches.py +14 -2
- {google_genai-1.31.0 → google_genai-1.33.0}/google/genai/files.py +12 -2
- google_genai-1.33.0/google/genai/local_tokenizer.py +362 -0
- {google_genai-1.31.0 → google_genai-1.33.0}/google/genai/models.py +171 -196
- {google_genai-1.31.0 → google_genai-1.33.0}/google/genai/tunings.py +134 -0
- {google_genai-1.31.0 → google_genai-1.33.0}/google/genai/types.py +402 -304
- {google_genai-1.31.0 → google_genai-1.33.0}/google/genai/version.py +1 -1
- {google_genai-1.31.0 → google_genai-1.33.0/google_genai.egg-info}/PKG-INFO +6 -6
- {google_genai-1.31.0 → google_genai-1.33.0}/google_genai.egg-info/SOURCES.txt +4 -0
- {google_genai-1.31.0 → google_genai-1.33.0}/pyproject.toml +1 -1
- {google_genai-1.31.0 → google_genai-1.33.0}/LICENSE +0 -0
- {google_genai-1.31.0 → google_genai-1.33.0}/MANIFEST.in +0 -0
- {google_genai-1.31.0 → google_genai-1.33.0}/google/genai/__init__.py +0 -0
- {google_genai-1.31.0 → google_genai-1.33.0}/google/genai/_adapters.py +0 -0
- {google_genai-1.31.0 → google_genai-1.33.0}/google/genai/_api_module.py +0 -0
- {google_genai-1.31.0 → google_genai-1.33.0}/google/genai/_automatic_function_calling_util.py +0 -0
- {google_genai-1.31.0 → google_genai-1.33.0}/google/genai/_base_url.py +0 -0
- {google_genai-1.31.0 → google_genai-1.33.0}/google/genai/_common.py +0 -0
- {google_genai-1.31.0 → google_genai-1.33.0}/google/genai/_extra_utils.py +0 -0
- {google_genai-1.31.0 → google_genai-1.33.0}/google/genai/_live_converters.py +0 -0
- {google_genai-1.31.0 → google_genai-1.33.0}/google/genai/_mcp_utils.py +0 -0
- {google_genai-1.31.0 → google_genai-1.33.0}/google/genai/_test_api_client.py +0 -0
- {google_genai-1.31.0 → google_genai-1.33.0}/google/genai/_tokens_converters.py +0 -0
- {google_genai-1.31.0 → google_genai-1.33.0}/google/genai/batches.py +0 -0
- {google_genai-1.31.0 → google_genai-1.33.0}/google/genai/chats.py +0 -0
- {google_genai-1.31.0 → google_genai-1.33.0}/google/genai/client.py +0 -0
- {google_genai-1.31.0 → google_genai-1.33.0}/google/genai/errors.py +0 -0
- {google_genai-1.31.0 → google_genai-1.33.0}/google/genai/live.py +0 -0
- {google_genai-1.31.0 → google_genai-1.33.0}/google/genai/live_music.py +0 -0
- {google_genai-1.31.0 → google_genai-1.33.0}/google/genai/operations.py +0 -0
- {google_genai-1.31.0 → google_genai-1.33.0}/google/genai/pagers.py +0 -0
- {google_genai-1.31.0 → google_genai-1.33.0}/google/genai/py.typed +0 -0
- {google_genai-1.31.0 → google_genai-1.33.0}/google/genai/tokens.py +0 -0
- {google_genai-1.31.0 → google_genai-1.33.0}/google_genai.egg-info/dependency_links.txt +0 -0
- {google_genai-1.31.0 → google_genai-1.33.0}/google_genai.egg-info/requires.txt +0 -0
- {google_genai-1.31.0 → google_genai-1.33.0}/google_genai.egg-info/top_level.txt +0 -0
- {google_genai-1.31.0 → google_genai-1.33.0}/setup.cfg +0 -0
```diff
--- google_genai-1.31.0/google_genai.egg-info/PKG-INFO
+++ google_genai-1.33.0/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: google-genai
-Version: 1.31.0
+Version: 1.33.0
 Summary: GenAI Python SDK
 Author-email: Google LLC <googleapis-packages@google.com>
 License: Apache-2.0
```
```diff
@@ -183,7 +183,7 @@ Then, you can pass it through the following way:
 
 http_options = types.HttpOptions(
     client_args={'proxy': 'socks5://user:pass@host:port'},
-    async_client_args={'proxy': 'socks5://user:pass@host:port'}
+    async_client_args={'proxy': 'socks5://user:pass@host:port'},
 )
 
 client=Client(..., http_options=http_options)
```
```diff
@@ -388,7 +388,7 @@ The SDK converts all non function call parts into a content with a `user` role.
 [
   types.UserContent(parts=[
     types.Part.from_uri(
-
+      file_uri: 'gs://generativeai-downloads/images/scones.jpg',
       mime_type: 'image/jpeg',
     )
   ])
```
```diff
@@ -824,10 +824,10 @@ user_profile = {
 
 response = client.models.generate_content(
     model='gemini-2.0-flash',
-    contents='Give me
+    contents='Give me a random user profile.',
     config={
         'response_mime_type': 'application/json',
-        'response_json_schema':
+        'response_json_schema': user_profile
    },
 )
 print(response.parsed)
```
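The two restored lines belong to the structured-output example. A runnable sketch of the full call; the `user_profile` schema body is an illustrative assumption, since its actual definition (the `user_profile = {` hunk context) is not included in this diff:

```python
from google import genai

client = genai.Client()  # assumes GOOGLE_API_KEY is set in the environment

# Hypothetical JSON schema standing in for the README's user_profile dict.
user_profile = {
    'type': 'object',
    'properties': {
        'name': {'type': 'string'},
        'age': {'type': 'integer'},
    },
    'required': ['name', 'age'],
}

response = client.models.generate_content(
    model='gemini-2.0-flash',
    contents='Give me a random user profile.',
    config={
        'response_mime_type': 'application/json',
        'response_json_schema': user_profile,
    },
)
print(response.parsed)  # parsed dict matching the schema
```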
```diff
@@ -1588,7 +1588,7 @@ batch_job = client.batches.create(
             }],
             "role": "user",
         }],
-        "config
+        "config": {"response_modalities": ["text"]},
     }],
 )
 
```
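The restored `"config"` line completes the inline batch example. A hedged reconstruction of the whole call: only the lines visible in the hunk are confirmed; the model name and prompt text are assumptions.

```python
batch_job = client.batches.create(
    model='gemini-2.0-flash',  # assumed; not shown in the hunk
    src=[{
        'contents': [{
            'parts': [{'text': 'Tell me a short joke.'}],  # assumed prompt
            'role': 'user',
        }],
        'config': {'response_modalities': ['text']},
    }],
)
```

The four README.md hunks that follow repeat these same fixes, since the sdist's PKG-INFO embeds the README as its long description.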
```diff
--- google_genai-1.31.0/README.md
+++ google_genai-1.33.0/README.md
@@ -149,7 +149,7 @@ Then, you can pass it through the following way:
 
 http_options = types.HttpOptions(
     client_args={'proxy': 'socks5://user:pass@host:port'},
-    async_client_args={'proxy': 'socks5://user:pass@host:port'}
+    async_client_args={'proxy': 'socks5://user:pass@host:port'},
 )
 
 client=Client(..., http_options=http_options)
@@ -354,7 +354,7 @@ The SDK converts all non function call parts into a content with a `user` role.
 [
   types.UserContent(parts=[
     types.Part.from_uri(
-
+      file_uri: 'gs://generativeai-downloads/images/scones.jpg',
       mime_type: 'image/jpeg',
     )
   ])
@@ -790,10 +790,10 @@ user_profile = {
 
 response = client.models.generate_content(
     model='gemini-2.0-flash',
-    contents='Give me
+    contents='Give me a random user profile.',
     config={
         'response_mime_type': 'application/json',
-        'response_json_schema':
+        'response_json_schema': user_profile
    },
 )
 print(response.parsed)
@@ -1554,7 +1554,7 @@ batch_job = client.batches.create(
             }],
             "role": "user",
         }],
-        "config
+        "config": {"response_modalities": ["text"]},
     }],
 )
 
```
```diff
--- google_genai-1.31.0/google/genai/_api_client.py
+++ google_genai-1.33.0/google/genai/_api_client.py
@@ -584,13 +584,9 @@ class BaseApiClient:
     # Initialize the lock. This lock will be used to protect access to the
     # credentials. This is crucial for thread safety when multiple coroutines
     # might be accessing the credentials at the same time.
-
-
-
-    except RuntimeError:
-      asyncio.set_event_loop(asyncio.new_event_loop())
-    self._sync_auth_lock = threading.Lock()
-    self._async_auth_lock = asyncio.Lock()
+    self._sync_auth_lock = threading.Lock()
+    self._async_auth_lock: Optional[asyncio.Lock] = None
+    self._async_auth_lock_creation_lock: Optional[asyncio.Lock] = None
 
     # Handle when to use Vertex AI in express mode (api key).
     # Explicit initializer arguments are already validated above.
```
```diff
@@ -903,10 +899,36 @@ class BaseApiClient:
     else:
       raise RuntimeError('Could not resolve API token from the environment')
 
+  async def _get_async_auth_lock(self) -> asyncio.Lock:
+    """Lazily initializes and returns an asyncio.Lock for async authentication.
+
+    This method ensures that a single `asyncio.Lock` instance is created and
+    shared among all asynchronous operations that require authentication,
+    preventing race conditions when accessing or refreshing credentials.
+
+    The lock is created on the first call to this method. An internal async lock
+    is used to protect the creation of the main authentication lock to ensure
+    it's a singleton within the client instance.
+
+    Returns:
+      The asyncio.Lock instance for asynchronous authentication operations.
+    """
+    if self._async_auth_lock is None:
+      # Create async creation lock if needed
+      if self._async_auth_lock_creation_lock is None:
+        self._async_auth_lock_creation_lock = asyncio.Lock()
+
+      async with self._async_auth_lock_creation_lock:
+        if self._async_auth_lock is None:
+          self._async_auth_lock = asyncio.Lock()
+
+    return self._async_auth_lock
+
   async def _async_access_token(self) -> Union[str, Any]:
     """Retrieves the access token for the credentials asynchronously."""
     if not self._credentials:
-
+      async_auth_lock = await self._get_async_auth_lock()
+      async with async_auth_lock:
        # This ensures that only one coroutine can execute the auth logic at a
        # time for thread safety.
        if not self._credentials:
@@ -920,7 +942,8 @@ class BaseApiClient:
     if self._credentials:
       if self._credentials.expired or not self._credentials.token:
         # Only refresh when it needs to. Default expiration is 3600 seconds.
-
+        async_auth_lock = await self._get_async_auth_lock()
+        async with async_auth_lock:
           if self._credentials.expired or not self._credentials.token:
             # Double check that the credentials expired before refreshing.
             await asyncio.to_thread(refresh_auth, self._credentials)
```
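The net effect of these three `_api_client.py` hunks: `__init__` no longer creates (or spins up) an event loop, because the `asyncio.Lock` is now built lazily inside a running loop, using a double-checked pattern. A standalone sketch of that pattern (not the SDK code itself):

```python
import asyncio
from typing import Optional


class LazyLockHolder:
  """Minimal sketch of the double-checked lazy lock creation used above."""

  def __init__(self) -> None:
    # No asyncio.Lock is created here, so construction is safe to perform
    # outside of any running event loop.
    self._lock: Optional[asyncio.Lock] = None
    self._creation_lock: Optional[asyncio.Lock] = None

  async def _get_lock(self) -> asyncio.Lock:
    if self._lock is None:
      # Safe within a single event loop: there is no await between the check
      # and the assignment, so coroutines cannot interleave here.
      if self._creation_lock is None:
        self._creation_lock = asyncio.Lock()
      async with self._creation_lock:
        # Double-check: another coroutine may have created the lock while
        # this one was waiting on the creation lock.
        if self._lock is None:
          self._lock = asyncio.Lock()
    return self._lock


async def main() -> None:
  holder = LazyLockHolder()
  async with await holder._get_lock():
    pass  # critical section, e.g. refreshing credentials


asyncio.run(main())
```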
```diff
--- /dev/null
+++ google_genai-1.33.0/google/genai/_base_transformers.py
@@ -0,0 +1,26 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Base transformers for Google GenAI SDK."""
+import base64
+
+# Some fields don't accept url safe base64 encoding.
+# We shouldn't use this transformer if the backend adhere to Cloud Type
+# format https://cloud.google.com/docs/discovery/type-format.
+# TODO(b/389133914,b/390320301): Remove the hack after backend fix the issue.
+def t_bytes(data: bytes) -> str:
+  if not isinstance(data, bytes):
+    return data
+  return base64.b64encode(data).decode('ascii')
```
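The distinction `t_bytes` encodes matters because the standard and URL-safe base64 alphabets differ in exactly two characters (`+`/`/` versus `-`/`_`). A quick illustration:

```python
import base64

data = b'\xfb\xef\xbe'  # three bytes that hit the last characters of the alphabet

print(base64.b64encode(data).decode('ascii'))          # '++++'  (standard, what t_bytes emits)
print(base64.urlsafe_b64encode(data).decode('ascii'))  # '----'  (URL-safe variant)
```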
```diff
--- /dev/null
+++ google_genai-1.33.0/google/genai/_local_tokenizer_loader.py
@@ -0,0 +1,223 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import dataclasses
+import functools
+import hashlib
+import os
+import tempfile
+from typing import Optional, cast
+import uuid
+
+import requests  # type: ignore
+import sentencepiece as spm
+from sentencepiece import sentencepiece_model_pb2
+
+
+# Source of truth: https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models
+_GEMINI_MODELS_TO_TOKENIZER_NAMES = {
+    "gemini-1.0-pro": "gemma2",
+    "gemini-1.5-pro": "gemma2",
+    "gemini-1.5-flash": "gemma2",
+    "gemini-2.5-pro": "gemma3",
+    "gemini-2.5-flash": "gemma3",
+    "gemini-2.5-flash-lite": "gemma3",
+    "gemini-2.0-flash": "gemma3",
+    "gemini-2.0-flash-lite": "gemma3",
+}
+_GEMINI_STABLE_MODELS_TO_TOKENIZER_NAMES = {
+    "gemini-1.0-pro-001": "gemma2",
+    "gemini-1.0-pro-002": "gemma2",
+    "gemini-1.5-pro-001": "gemma2",
+    "gemini-1.5-flash-001": "gemma2",
+    "gemini-1.5-flash-002": "gemma2",
+    "gemini-1.5-pro-002": "gemma2",
+    "gemini-2.5-pro-preview-06-05": "gemma3",
+    "gemini-2.5-pro-preview-05-06": "gemma3",
+    "gemini-2.5-pro-exp-03-25": "gemma3",
+    "gemini-live-2.5-flash": "gemma3",
+    "gemini-2.5-flash-preview-05-20": "gemma3",
+    "gemini-2.5-flash-preview-04-17": "gemma3",
+    "gemini-2.5-flash-lite-preview-06-17": "gemma3",
+    "gemini-2.0-flash-001": "gemma3",
+    "gemini-2.0-flash-lite-001": "gemma3",
+}
+
+
+@dataclasses.dataclass(frozen=True)
+class _TokenizerConfig:
+  model_url: str
+  model_hash: str
+
+
+# TODO: update gemma3 tokenizer
+_TOKENIZERS = {
+    "gemma2": _TokenizerConfig(
+        model_url="https://raw.githubusercontent.com/google/gemma_pytorch/33b652c465537c6158f9a472ea5700e5e770ad3f/tokenizer/tokenizer.model",
+        model_hash=(
+            "61a7b147390c64585d6c3543dd6fc636906c9af3865a5548f27f31aee1d4c8e2"
+        ),
+    ),
+    "gemma3": _TokenizerConfig(
+        model_url="https://raw.githubusercontent.com/google/gemma_pytorch/014acb7ac4563a5f77c76d7ff98f31b568c16508/tokenizer/gemma3_cleaned_262144_v2.spiece.model",
+        model_hash=(
+            "1299c11d7cf632ef3b4e11937501358ada021bbdf7c47638d13c0ee982f2e79c"
+        ),
+    ),
+}
+
+
+def _load_file(file_url_path: str) -> bytes:
+  """Loads file bytes from the given file url path."""
+  resp = requests.get(file_url_path)
+  resp.raise_for_status()
+  return cast(bytes, resp.content)
+
+
+def _is_valid_model(*, model_data: bytes, expected_hash: str) -> bool:
+  """Returns true if the content is valid by checking the hash."""
+  if not expected_hash:
+    raise ValueError("expected_hash is required")
+  return hashlib.sha256(model_data).hexdigest() == expected_hash
+
+
+def _maybe_remove_file(file_path: str) -> None:
+  """Removes the file if exists."""
+  if not os.path.exists(file_path):
+    return
+  try:
+    os.remove(file_path)
+  except OSError:
+    # Don't raise if we cannot remove file.
+    pass
+
+
+def _maybe_load_from_cache(
+    *, file_path: str, expected_hash: str
+) -> Optional[bytes]:
+  """Loads the content from the cache path."""
+  if not os.path.exists(file_path):
+    return None
+  with open(file_path, "rb") as f:
+    content = f.read()
+  if _is_valid_model(model_data=content, expected_hash=expected_hash):
+    return content
+
+  # Cached file corrupted.
+  _maybe_remove_file(file_path)
+  return None
+
+
+def _maybe_save_to_cache(
+    *, cache_dir: str, cache_path: str, content: bytes
+) -> None:
+  """Saves the content to the cache path."""
+  try:
+    os.makedirs(cache_dir, exist_ok=True)
+    tmp_path = cache_dir + "." + str(uuid.uuid4()) + ".tmp"
+    with open(tmp_path, "wb") as f:
+      f.write(content)
+    os.rename(tmp_path, cache_path)
+  except OSError:
+    # Don't raise if we cannot write file.
+    pass
+
+
+def _load_from_url(*, file_url: str, expected_hash: str) -> bytes:
+  """Loads model bytes from the given file url."""
+  content = _load_file(file_url)
+  if not _is_valid_model(model_data=content, expected_hash=expected_hash):
+    actual_hash = hashlib.sha256(content).hexdigest()
+    raise ValueError(
+        "Downloaded model file is corrupted."
+        f" Expected hash {expected_hash}. Got file hash {actual_hash}."
+    )
+  return content
+
+
+def _load(*, file_url: str, expected_hash: str) -> bytes:
+  """Loads model bytes from the given file url.
+
+  1. If the find local cached file for the given url and the cached file hash
+    matches the expected hash, the cached file is returned.
+  2. If local cached file is not found or the hash does not match, the file is
+    downloaded from the given url. And write to local cache and return the
+    file bytes.
+  3. If the file downloaded from the given url does not match the expected
+    hash, raise ValueError.
+
+  Args:
+    file_url: The url of the file to load.
+    expected_hash: The expected hash of the file.
+
+  Returns:
+    The file bytes.
+  """
+  model_dir = os.path.join(tempfile.gettempdir(), "vertexai_tokenizer_model")
+  filename = hashlib.sha1(file_url.encode()).hexdigest()
+  model_path = os.path.join(model_dir, filename)
+
+  model_data = _maybe_load_from_cache(
+      file_path=model_path, expected_hash=expected_hash
+  )
+  if not model_data:
+    model_data = _load_from_url(file_url=file_url, expected_hash=expected_hash)
+
+    _maybe_save_to_cache(
+        cache_dir=model_dir, cache_path=model_path, content=model_data
+    )
+  return model_data
+
+
+def _load_model_proto_bytes(tokenizer_name: str) -> bytes:
+  """Loads model proto bytes from the given tokenizer name."""
+  if tokenizer_name not in _TOKENIZERS:
+    raise ValueError(
+        f"Tokenizer {tokenizer_name} is not supported."
+        f"Supported tokenizers: {list(_TOKENIZERS.keys())}"
+    )
+  return _load(
+      file_url=_TOKENIZERS[tokenizer_name].model_url,
+      expected_hash=_TOKENIZERS[tokenizer_name].model_hash,
+  )
+
+
+@functools.lru_cache()
+def load_model_proto(
+    tokenizer_name: str,
+) -> sentencepiece_model_pb2.ModelProto:
+  """Loads model proto from the given tokenizer name."""
+  model_proto = sentencepiece_model_pb2.ModelProto()
+  model_proto.ParseFromString(_load_model_proto_bytes(tokenizer_name))
+  return model_proto
+
+
+def get_tokenizer_name(model_name: str) -> str:
+  """Gets the tokenizer name for the given model name."""
+  if model_name in _GEMINI_MODELS_TO_TOKENIZER_NAMES.keys():
+    return _GEMINI_MODELS_TO_TOKENIZER_NAMES[model_name]
+  if model_name in _GEMINI_STABLE_MODELS_TO_TOKENIZER_NAMES.keys():
+    return _GEMINI_STABLE_MODELS_TO_TOKENIZER_NAMES[model_name]
+  raise ValueError(
+      f"Model {model_name} is not supported. Supported models: {', '.join(_GEMINI_MODELS_TO_TOKENIZER_NAMES.keys())}, {', '.join(_GEMINI_STABLE_MODELS_TO_TOKENIZER_NAMES.keys())}.\n"  # pylint: disable=line-too-long
+  )
+
+
+@functools.lru_cache()
+def get_sentencepiece(tokenizer_name: str) -> spm.SentencePieceProcessor:
+  """Loads sentencepiece tokenizer from the given tokenizer name."""
+  processor = spm.SentencePieceProcessor()
+  processor.LoadFromSerializedProto(_load_model_proto_bytes(tokenizer_name))
+  return processor
```
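These loader helpers back the new `local_tokenizer.py` module (listed above at +362 lines but not shown in this section), which is presumably the supported public entry point since this module is underscore-private. Still, a sketch of exercising the helpers directly, assuming `sentencepiece` is installed and network access is available for the first download (cached under the temp dir thereafter):

```python
from google.genai import _local_tokenizer_loader as loader

# Map a Gemini model to its tokenizer family, then load the SentencePiece
# processor (downloaded, SHA-256 verified, and cached on first use).
tokenizer_name = loader.get_tokenizer_name('gemini-2.0-flash')  # -> 'gemma3'
processor = loader.get_sentencepiece(tokenizer_name)

token_ids = processor.EncodeAsIds('Hello world!')
print(len(token_ids))  # local token count, no API call needed
```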