langfun 0.1.2.dev202501050804__py3-none-any.whl → 0.1.2.dev202501060804__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langfun/core/eval/v2/reporting.py +7 -2
- langfun/core/llms/__init__.py +21 -26
- langfun/core/llms/cache/in_memory.py +6 -0
- langfun/core/llms/cache/in_memory_test.py +5 -0
- langfun/core/llms/gemini.py +507 -0
- langfun/core/llms/gemini_test.py +195 -0
- langfun/core/llms/google_genai.py +45 -320
- langfun/core/llms/google_genai_test.py +9 -204
- langfun/core/llms/vertexai.py +25 -357
- langfun/core/llms/vertexai_test.py +6 -166
- {langfun-0.1.2.dev202501050804.dist-info → langfun-0.1.2.dev202501060804.dist-info}/METADATA +7 -13
- {langfun-0.1.2.dev202501050804.dist-info → langfun-0.1.2.dev202501060804.dist-info}/RECORD +15 -13
- {langfun-0.1.2.dev202501050804.dist-info → langfun-0.1.2.dev202501060804.dist-info}/WHEEL +1 -1
- {langfun-0.1.2.dev202501050804.dist-info → langfun-0.1.2.dev202501060804.dist-info}/LICENSE +0 -0
- {langfun-0.1.2.dev202501050804.dist-info → langfun-0.1.2.dev202501060804.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,195 @@
|
|
1
|
+
# Copyright 2025 The Langfun Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
"""Tests for Gemini API."""
|
15
|
+
|
16
|
+
import base64
|
17
|
+
from typing import Any
|
18
|
+
import unittest
|
19
|
+
from unittest import mock
|
20
|
+
|
21
|
+
import langfun.core as lf
|
22
|
+
from langfun.core import modalities as lf_modalities
|
23
|
+
from langfun.core.llms import gemini
|
24
|
+
import pyglove as pg
|
25
|
+
import requests
|
26
|
+
|
27
|
+
|
28
|
+
example_image = (
|
29
|
+
b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x18\x00\x00\x00\x18\x04'
|
30
|
+
b'\x03\x00\x00\x00\x12Y \xcb\x00\x00\x00\x18PLTE\x00\x00'
|
31
|
+
b'\x00fff_chaag_cg_ch^ci_ciC\xedb\x94\x00\x00\x00\x08tRNS'
|
32
|
+
b'\x00\n\x9f*\xd4\xff_\xf4\xe4\x8b\xf3a\x00\x00\x00>IDATx'
|
33
|
+
b'\x01c \x05\x08)"\xd8\xcc\xae!\x06pNz\x88k\x19\\Q\xa8"\x10'
|
34
|
+
b'\xc1\x14\x95\x01%\xc1\n\xa143Ta\xa8"D-\x84\x03QM\x98\xc3'
|
35
|
+
b'\x1a\x1a\x1a@5\x0e\x04\xa0q\x88\x05\x00\x07\xf8\x18\xf9'
|
36
|
+
b'\xdao\xd0|\x00\x00\x00\x00IEND\xaeB`\x82'
|
37
|
+
)
|
38
|
+
|
39
|
+
|
40
|
+
def mock_requests_post(url: str, json: dict[str, Any], **kwargs):
|
41
|
+
del url, kwargs
|
42
|
+
c = pg.Dict(json['generationConfig'])
|
43
|
+
content = json['contents'][0]['parts'][0]['text']
|
44
|
+
response = requests.Response()
|
45
|
+
response.status_code = 200
|
46
|
+
response._content = pg.to_json_str({
|
47
|
+
'candidates': [
|
48
|
+
{
|
49
|
+
'content': {
|
50
|
+
'role': 'model',
|
51
|
+
'parts': [
|
52
|
+
{
|
53
|
+
'text': (
|
54
|
+
f'This is a response to {content} with '
|
55
|
+
f'temperature={c.temperature}, '
|
56
|
+
f'top_p={c.topP}, '
|
57
|
+
f'top_k={c.topK}, '
|
58
|
+
f'max_tokens={c.maxOutputTokens}, '
|
59
|
+
f'stop={"".join(c.stopSequences)}.'
|
60
|
+
),
|
61
|
+
},
|
62
|
+
{
|
63
|
+
'text': 'This is the thought.',
|
64
|
+
'thought': True,
|
65
|
+
}
|
66
|
+
],
|
67
|
+
},
|
68
|
+
},
|
69
|
+
],
|
70
|
+
'usageMetadata': {
|
71
|
+
'promptTokenCount': 3,
|
72
|
+
'candidatesTokenCount': 4,
|
73
|
+
}
|
74
|
+
}).encode()
|
75
|
+
return response
|
76
|
+
|
77
|
+
|
78
|
+
class GeminiTest(unittest.TestCase):
|
79
|
+
"""Tests for Gemini model with REST API."""
|
80
|
+
|
81
|
+
def test_content_from_message_text_only(self):
|
82
|
+
text = 'This is a beautiful day'
|
83
|
+
model = gemini.Gemini('gemini-1.5-pro', api_endpoint='')
|
84
|
+
chunks = model._content_from_message(lf.UserMessage(text))
|
85
|
+
self.assertEqual(chunks, {'role': 'user', 'parts': [{'text': text}]})
|
86
|
+
|
87
|
+
def test_content_from_message_mm(self):
|
88
|
+
image = lf_modalities.Image.from_bytes(example_image)
|
89
|
+
message = lf.UserMessage(
|
90
|
+
'This is an <<[[image]]>>, what is it?', image=image
|
91
|
+
)
|
92
|
+
|
93
|
+
# Non-multimodal model.
|
94
|
+
with self.assertRaisesRegex(lf.ModalityError, 'Unsupported modality'):
|
95
|
+
gemini.Gemini(
|
96
|
+
'gemini-1.0-pro', api_endpoint=''
|
97
|
+
)._content_from_message(message)
|
98
|
+
|
99
|
+
model = gemini.Gemini('gemini-1.5-pro', api_endpoint='')
|
100
|
+
content = model._content_from_message(message)
|
101
|
+
self.assertEqual(
|
102
|
+
content,
|
103
|
+
{
|
104
|
+
'role': 'user',
|
105
|
+
'parts': [
|
106
|
+
{'text': 'This is an'},
|
107
|
+
{
|
108
|
+
'inlineData': {
|
109
|
+
'data': base64.b64encode(example_image).decode(),
|
110
|
+
'mimeType': 'image/png',
|
111
|
+
}
|
112
|
+
},
|
113
|
+
{'text': ', what is it?'},
|
114
|
+
],
|
115
|
+
},
|
116
|
+
)
|
117
|
+
|
118
|
+
def test_generation_config(self):
|
119
|
+
model = gemini.Gemini('gemini-1.5-pro', api_endpoint='')
|
120
|
+
json_schema = {
|
121
|
+
'type': 'object',
|
122
|
+
'properties': {
|
123
|
+
'name': {'type': 'string'},
|
124
|
+
},
|
125
|
+
'required': ['name'],
|
126
|
+
'title': 'Person',
|
127
|
+
}
|
128
|
+
actual = model._generation_config(
|
129
|
+
lf.UserMessage('hi', json_schema=json_schema),
|
130
|
+
lf.LMSamplingOptions(
|
131
|
+
temperature=2.0,
|
132
|
+
top_p=1.0,
|
133
|
+
top_k=20,
|
134
|
+
max_tokens=1024,
|
135
|
+
stop=['\n'],
|
136
|
+
),
|
137
|
+
)
|
138
|
+
self.assertEqual(
|
139
|
+
actual,
|
140
|
+
dict(
|
141
|
+
candidateCount=1,
|
142
|
+
temperature=2.0,
|
143
|
+
topP=1.0,
|
144
|
+
topK=20,
|
145
|
+
maxOutputTokens=1024,
|
146
|
+
stopSequences=['\n'],
|
147
|
+
responseLogprobs=False,
|
148
|
+
logprobs=None,
|
149
|
+
seed=None,
|
150
|
+
responseMimeType='application/json',
|
151
|
+
responseSchema={
|
152
|
+
'type': 'object',
|
153
|
+
'properties': {
|
154
|
+
'name': {'type': 'string'}
|
155
|
+
},
|
156
|
+
'required': ['name'],
|
157
|
+
'title': 'Person',
|
158
|
+
}
|
159
|
+
),
|
160
|
+
)
|
161
|
+
with self.assertRaisesRegex(
|
162
|
+
ValueError, '`json_schema` must be a dict, got'
|
163
|
+
):
|
164
|
+
model._generation_config(
|
165
|
+
lf.UserMessage('hi', json_schema='not a dict'),
|
166
|
+
lf.LMSamplingOptions(),
|
167
|
+
)
|
168
|
+
|
169
|
+
def test_call_model(self):
|
170
|
+
with mock.patch('requests.Session.post') as mock_generate:
|
171
|
+
mock_generate.side_effect = mock_requests_post
|
172
|
+
|
173
|
+
lm = gemini.Gemini('gemini-1.5-pro', api_endpoint='')
|
174
|
+
r = lm(
|
175
|
+
'hello',
|
176
|
+
temperature=2.0,
|
177
|
+
top_p=1.0,
|
178
|
+
top_k=20,
|
179
|
+
max_tokens=1024,
|
180
|
+
stop='\n',
|
181
|
+
)
|
182
|
+
self.assertEqual(
|
183
|
+
r.text,
|
184
|
+
(
|
185
|
+
'This is a response to hello with temperature=2.0, '
|
186
|
+
'top_p=1.0, top_k=20, max_tokens=1024, stop=\n.'
|
187
|
+
),
|
188
|
+
)
|
189
|
+
self.assertEqual(r.metadata.thought, 'This is the thought.')
|
190
|
+
self.assertEqual(r.metadata.usage.prompt_tokens, 3)
|
191
|
+
self.assertEqual(r.metadata.usage.completion_tokens, 4)
|
192
|
+
|
193
|
+
|
194
|
+
if __name__ == '__main__':
|
195
|
+
unittest.main()
|
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright
|
1
|
+
# Copyright 2025 The Langfun Authors
|
2
2
|
#
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -11,57 +11,21 @@
|
|
11
11
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
|
-
"""
|
14
|
+
"""Language models from Google GenAI."""
|
15
15
|
|
16
|
-
import abc
|
17
|
-
import functools
|
18
16
|
import os
|
19
|
-
from typing import Annotated,
|
17
|
+
from typing import Annotated, Literal
|
20
18
|
|
21
19
|
import langfun.core as lf
|
22
|
-
from langfun.core import
|
23
|
-
from langfun.core.llms import vertexai
|
20
|
+
from langfun.core.llms import gemini
|
24
21
|
import pyglove as pg
|
25
22
|
|
26
23
|
|
27
|
-
try:
|
28
|
-
import google.generativeai as genai # pylint: disable=g-import-not-at-top
|
29
|
-
BlobDict = genai.types.BlobDict
|
30
|
-
GenerativeModel = genai.GenerativeModel
|
31
|
-
Completion = getattr(genai.types, 'Completion', Any)
|
32
|
-
ChatResponse = getattr(genai.types, 'ChatResponse', Any)
|
33
|
-
GenerateContentResponse = getattr(genai.types, 'GenerateContentResponse', Any)
|
34
|
-
GenerationConfig = genai.GenerationConfig
|
35
|
-
except ImportError:
|
36
|
-
genai = None
|
37
|
-
BlobDict = Any
|
38
|
-
GenerativeModel = Any
|
39
|
-
Completion = Any
|
40
|
-
ChatResponse = Any
|
41
|
-
GenerationConfig = Any
|
42
|
-
GenerateContentResponse = Any
|
43
|
-
|
44
|
-
|
45
24
|
@lf.use_init_args(['model'])
|
46
|
-
|
25
|
+
@pg.members([('api_endpoint', pg.typing.Str().freeze(''))])
|
26
|
+
class GenAI(gemini.Gemini):
|
47
27
|
"""Language models provided by Google GenAI."""
|
48
28
|
|
49
|
-
model: Annotated[
|
50
|
-
Literal[
|
51
|
-
'gemini-2.0-flash-thinking-exp-1219',
|
52
|
-
'gemini-2.0-flash-exp',
|
53
|
-
'gemini-exp-1206',
|
54
|
-
'gemini-exp-1114',
|
55
|
-
'gemini-1.5-pro-latest',
|
56
|
-
'gemini-1.5-flash-latest',
|
57
|
-
'gemini-pro',
|
58
|
-
'gemini-pro-vision',
|
59
|
-
'text-bison-001',
|
60
|
-
'chat-bison-001',
|
61
|
-
],
|
62
|
-
'Model name.',
|
63
|
-
]
|
64
|
-
|
65
29
|
api_key: Annotated[
|
66
30
|
str | None,
|
67
31
|
(
|
@@ -71,26 +35,18 @@ class GenAI(lf.LanguageModel):
|
|
71
35
|
),
|
72
36
|
] = None
|
73
37
|
|
74
|
-
|
75
|
-
|
76
|
-
'
|
77
|
-
] =
|
38
|
+
api_version: Annotated[
|
39
|
+
Literal['v1beta', 'v1alpha'],
|
40
|
+
'The API version to use.'
|
41
|
+
] = 'v1beta'
|
78
42
|
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
super()._on_bound()
|
84
|
-
if genai is None:
|
85
|
-
raise RuntimeError(
|
86
|
-
'Please install "langfun[llm-google-genai]" to use '
|
87
|
-
'Google Generative AI models.'
|
88
|
-
)
|
89
|
-
self.__dict__.pop('_api_initialized', None)
|
43
|
+
@property
|
44
|
+
def model_id(self) -> str:
|
45
|
+
"""Returns a string to identify the model."""
|
46
|
+
return f'GenAI({self.model})'
|
90
47
|
|
91
|
-
@
|
92
|
-
def
|
93
|
-
assert genai is not None
|
48
|
+
@property
|
49
|
+
def api_endpoint(self) -> str:
|
94
50
|
api_key = self.api_key or os.environ.get('GOOGLE_API_KEY', None)
|
95
51
|
if not api_key:
|
96
52
|
raise ValueError(
|
@@ -100,306 +56,75 @@ class GenAI(lf.LanguageModel):
|
|
100
56
|
'https://cloud.google.com/api-keys/docs/create-manage-api-keys '
|
101
57
|
'for more details.'
|
102
58
|
)
|
103
|
-
|
104
|
-
|
105
|
-
|
106
|
-
|
107
|
-
def dir(cls) -> list[str]:
|
108
|
-
"""Lists generative models."""
|
109
|
-
assert genai is not None
|
110
|
-
return [
|
111
|
-
m.name.lstrip('models/')
|
112
|
-
for m in genai.list_models()
|
113
|
-
if (
|
114
|
-
'generateContent' in m.supported_generation_methods
|
115
|
-
or 'generateText' in m.supported_generation_methods
|
116
|
-
or 'generateMessage' in m.supported_generation_methods
|
117
|
-
)
|
118
|
-
]
|
119
|
-
|
120
|
-
@property
|
121
|
-
def model_id(self) -> str:
|
122
|
-
"""Returns a string to identify the model."""
|
123
|
-
return self.model
|
124
|
-
|
125
|
-
@property
|
126
|
-
def resource_id(self) -> str:
|
127
|
-
"""Returns a string to identify the resource for rate control."""
|
128
|
-
return self.model_id
|
129
|
-
|
130
|
-
def _generation_config(self, options: lf.LMSamplingOptions) -> dict[str, Any]:
|
131
|
-
"""Creates generation config from langfun sampling options."""
|
132
|
-
return GenerationConfig(
|
133
|
-
candidate_count=options.n,
|
134
|
-
temperature=options.temperature,
|
135
|
-
top_p=options.top_p,
|
136
|
-
top_k=options.top_k,
|
137
|
-
max_output_tokens=options.max_tokens,
|
138
|
-
stop_sequences=options.stop,
|
59
|
+
return (
|
60
|
+
f'https://generativelanguage.googleapis.com/{self.api_version}'
|
61
|
+
f'/models/{self.model}:generateContent?'
|
62
|
+
f'key={api_key}'
|
139
63
|
)
|
140
64
|
|
141
|
-
def _content_from_message(
|
142
|
-
self, prompt: lf.Message
|
143
|
-
) -> list[str | BlobDict]:
|
144
|
-
"""Gets Evergreen formatted content from langfun message."""
|
145
|
-
formatted = lf.UserMessage(prompt.text)
|
146
|
-
formatted.source = prompt
|
147
|
-
|
148
|
-
chunks = []
|
149
|
-
for lf_chunk in formatted.chunk():
|
150
|
-
if isinstance(lf_chunk, str):
|
151
|
-
chunks.append(lf_chunk)
|
152
|
-
elif isinstance(lf_chunk, lf_modalities.Mime):
|
153
|
-
try:
|
154
|
-
modalities = lf_chunk.make_compatible(
|
155
|
-
self.supported_modalities + ['text/plain']
|
156
|
-
)
|
157
|
-
if isinstance(modalities, lf_modalities.Mime):
|
158
|
-
modalities = [modalities]
|
159
|
-
for modality in modalities:
|
160
|
-
if modality.is_text:
|
161
|
-
chunk = modality.to_text()
|
162
|
-
else:
|
163
|
-
chunk = BlobDict(
|
164
|
-
data=modality.to_bytes(),
|
165
|
-
mime_type=modality.mime_type
|
166
|
-
)
|
167
|
-
chunks.append(chunk)
|
168
|
-
except lf.ModalityError as e:
|
169
|
-
raise lf.ModalityError(f'Unsupported modality: {lf_chunk!r}') from e
|
170
|
-
else:
|
171
|
-
raise lf.ModalityError(f'Unsupported modality: {lf_chunk!r}')
|
172
|
-
return chunks
|
173
|
-
|
174
|
-
def _response_to_result(
|
175
|
-
self, response: GenerateContentResponse | pg.Dict
|
176
|
-
) -> lf.LMSamplingResult:
|
177
|
-
"""Parses generative response into message."""
|
178
|
-
samples = []
|
179
|
-
for candidate in response.candidates:
|
180
|
-
chunks = []
|
181
|
-
for part in candidate.content.parts:
|
182
|
-
# TODO(daiyip): support multi-modal parts when they are available via
|
183
|
-
# Gemini API.
|
184
|
-
if hasattr(part, 'text'):
|
185
|
-
chunks.append(part.text)
|
186
|
-
samples.append(lf.LMSample(lf.AIMessage.from_chunks(chunks), score=0.0))
|
187
|
-
return lf.LMSamplingResult(samples)
|
188
|
-
|
189
|
-
def _sample(self, prompts: list[lf.Message]) -> list[lf.LMSamplingResult]:
|
190
|
-
assert self._api_initialized, 'Vertex AI API is not initialized.'
|
191
|
-
return self._parallel_execute_with_currency_control(
|
192
|
-
self._sample_single,
|
193
|
-
prompts,
|
194
|
-
)
|
195
65
|
|
196
|
-
|
197
|
-
|
198
|
-
model = _GOOGLE_GENAI_MODEL_HUB.get(self.model)
|
199
|
-
input_content = self._content_from_message(prompt)
|
200
|
-
response = model.generate_content(
|
201
|
-
input_content,
|
202
|
-
generation_config=self._generation_config(self.sampling_options),
|
203
|
-
)
|
204
|
-
return self._response_to_result(response)
|
205
|
-
|
206
|
-
|
207
|
-
class _LegacyGenerativeModel(pg.Object):
|
208
|
-
"""Base for legacy GenAI generative model."""
|
209
|
-
|
210
|
-
model: str
|
211
|
-
|
212
|
-
def generate_content(
|
213
|
-
self,
|
214
|
-
input_content: list[str | BlobDict],
|
215
|
-
generation_config: GenerationConfig,
|
216
|
-
) -> pg.Dict:
|
217
|
-
"""Generate content."""
|
218
|
-
segments = []
|
219
|
-
for s in input_content:
|
220
|
-
if not isinstance(s, str):
|
221
|
-
raise ValueError(f'Unsupported modality: {s!r}')
|
222
|
-
segments.append(s)
|
223
|
-
return self.generate(' '.join(segments), generation_config)
|
224
|
-
|
225
|
-
@abc.abstractmethod
|
226
|
-
def generate(
|
227
|
-
self, prompt: str, generation_config: GenerationConfig) -> pg.Dict:
|
228
|
-
"""Generate response based on prompt."""
|
229
|
-
|
230
|
-
|
231
|
-
class _LegacyCompletionModel(_LegacyGenerativeModel):
|
232
|
-
"""Legacy GenAI completion model."""
|
233
|
-
|
234
|
-
def generate(
|
235
|
-
self, prompt: str, generation_config: GenerationConfig
|
236
|
-
) -> pg.Dict:
|
237
|
-
assert genai is not None
|
238
|
-
completion: Completion = genai.generate_text(
|
239
|
-
model=f'models/{self.model}',
|
240
|
-
prompt=prompt,
|
241
|
-
temperature=generation_config.temperature,
|
242
|
-
top_k=generation_config.top_k,
|
243
|
-
top_p=generation_config.top_p,
|
244
|
-
candidate_count=generation_config.candidate_count,
|
245
|
-
max_output_tokens=generation_config.max_output_tokens,
|
246
|
-
stop_sequences=generation_config.stop_sequences,
|
247
|
-
)
|
248
|
-
return pg.Dict(
|
249
|
-
candidates=[
|
250
|
-
pg.Dict(content=pg.Dict(parts=[pg.Dict(text=c['output'])]))
|
251
|
-
for c in completion.candidates
|
252
|
-
]
|
253
|
-
)
|
254
|
-
|
255
|
-
|
256
|
-
class _LegacyChatModel(_LegacyGenerativeModel):
|
257
|
-
"""Legacy GenAI chat model."""
|
258
|
-
|
259
|
-
def generate(
|
260
|
-
self, prompt: str, generation_config: GenerationConfig
|
261
|
-
) -> pg.Dict:
|
262
|
-
assert genai is not None
|
263
|
-
response: ChatResponse = genai.chat(
|
264
|
-
model=f'models/{self.model}',
|
265
|
-
messages=prompt,
|
266
|
-
temperature=generation_config.temperature,
|
267
|
-
top_k=generation_config.top_k,
|
268
|
-
top_p=generation_config.top_p,
|
269
|
-
candidate_count=generation_config.candidate_count,
|
270
|
-
)
|
271
|
-
return pg.Dict(
|
272
|
-
candidates=[
|
273
|
-
pg.Dict(content=pg.Dict(parts=[pg.Dict(text=c['content'])]))
|
274
|
-
for c in response.candidates
|
275
|
-
]
|
276
|
-
)
|
277
|
-
|
278
|
-
|
279
|
-
class _ModelHub:
|
280
|
-
"""Google Generative AI model hub."""
|
281
|
-
|
282
|
-
def __init__(self):
|
283
|
-
self._model_cache = {}
|
284
|
-
|
285
|
-
def get(
|
286
|
-
self, model_name: str
|
287
|
-
) -> GenerativeModel | _LegacyGenerativeModel:
|
288
|
-
"""Gets a generative model by model id."""
|
289
|
-
assert genai is not None
|
290
|
-
model = self._model_cache.get(model_name, None)
|
291
|
-
if model is None:
|
292
|
-
model_info = genai.get_model(f'models/{model_name}')
|
293
|
-
if 'generateContent' in model_info.supported_generation_methods:
|
294
|
-
model = genai.GenerativeModel(model_name)
|
295
|
-
elif 'generateText' in model_info.supported_generation_methods:
|
296
|
-
model = _LegacyCompletionModel(model_name)
|
297
|
-
elif 'generateMessage' in model_info.supported_generation_methods:
|
298
|
-
model = _LegacyChatModel(model_name)
|
299
|
-
else:
|
300
|
-
raise ValueError(f'Unsupported model: {model_name!r}')
|
301
|
-
self._model_cache[model_name] = model
|
302
|
-
return model
|
303
|
-
|
304
|
-
|
305
|
-
_GOOGLE_GENAI_MODEL_HUB = _ModelHub()
|
306
|
-
|
307
|
-
|
308
|
-
#
|
309
|
-
# Public Gemini models.
|
310
|
-
#
|
311
|
-
class GeminiFlash2_0ThinkingExp(GenAI): # pylint: disable=invalid-name
|
312
|
-
"""Gemini 2.0 Flash Thinking Experimental model."""
|
66
|
+
class GeminiFlash2_0ThinkingExp_20241219(GenAI): # pylint: disable=invalid-name
|
67
|
+
"""Gemini Flash 2.0 Thinking model launched on 12/19/2024."""
|
313
68
|
|
69
|
+
api_version = 'v1alpha'
|
314
70
|
model = 'gemini-2.0-flash-thinking-exp-1219'
|
315
|
-
supported_modalities = (
|
316
|
-
vertexai.DOCUMENT_TYPES
|
317
|
-
+ vertexai.IMAGE_TYPES
|
318
|
-
+ vertexai.AUDIO_TYPES
|
319
|
-
+ vertexai.VIDEO_TYPES
|
320
|
-
)
|
321
71
|
|
322
72
|
|
323
73
|
class GeminiFlash2_0Exp(GenAI): # pylint: disable=invalid-name
|
324
|
-
"""Gemini
|
74
|
+
"""Gemini Flash 2.0 model launched on 12/11/2024."""
|
325
75
|
|
326
76
|
model = 'gemini-2.0-flash-exp'
|
327
|
-
supported_modalities = (
|
328
|
-
vertexai.DOCUMENT_TYPES
|
329
|
-
+ vertexai.IMAGE_TYPES
|
330
|
-
+ vertexai.AUDIO_TYPES
|
331
|
-
+ vertexai.VIDEO_TYPES
|
332
|
-
)
|
333
77
|
|
334
78
|
|
335
79
|
class GeminiExp_20241206(GenAI): # pylint: disable=invalid-name
|
336
80
|
"""Gemini Experimental model launched on 12/06/2024."""
|
337
81
|
|
338
82
|
model = 'gemini-exp-1206'
|
339
|
-
supported_modalities = (
|
340
|
-
vertexai.DOCUMENT_TYPES
|
341
|
-
+ vertexai.IMAGE_TYPES
|
342
|
-
+ vertexai.AUDIO_TYPES
|
343
|
-
+ vertexai.VIDEO_TYPES
|
344
|
-
)
|
345
83
|
|
346
84
|
|
347
85
|
class GeminiExp_20241114(GenAI): # pylint: disable=invalid-name
|
348
86
|
"""Gemini Experimental model launched on 11/14/2024."""
|
349
87
|
|
350
88
|
model = 'gemini-exp-1114'
|
351
|
-
supported_modalities = (
|
352
|
-
vertexai.DOCUMENT_TYPES
|
353
|
-
+ vertexai.IMAGE_TYPES
|
354
|
-
+ vertexai.AUDIO_TYPES
|
355
|
-
+ vertexai.VIDEO_TYPES
|
356
|
-
)
|
357
89
|
|
358
90
|
|
359
91
|
class GeminiPro1_5(GenAI): # pylint: disable=invalid-name
|
360
92
|
"""Gemini Pro latest model."""
|
361
93
|
|
362
94
|
model = 'gemini-1.5-pro-latest'
|
363
|
-
supported_modalities = (
|
364
|
-
vertexai.DOCUMENT_TYPES
|
365
|
-
+ vertexai.IMAGE_TYPES
|
366
|
-
+ vertexai.AUDIO_TYPES
|
367
|
-
+ vertexai.VIDEO_TYPES
|
368
|
-
)
|
369
95
|
|
370
96
|
|
371
|
-
class
|
372
|
-
"""Gemini
|
97
|
+
class GeminiPro1_5_002(GenAI): # pylint: disable=invalid-name
|
98
|
+
"""Gemini Pro 1.5 model stable version 002."""
|
99
|
+
|
100
|
+
model = 'gemini-1.5-pro-002'
|
373
101
|
|
374
|
-
model = 'gemini-1.5-flash-latest'
|
375
|
-
supported_modalities = (
|
376
|
-
vertexai.DOCUMENT_TYPES
|
377
|
-
+ vertexai.IMAGE_TYPES
|
378
|
-
+ vertexai.AUDIO_TYPES
|
379
|
-
+ vertexai.VIDEO_TYPES
|
380
|
-
)
|
381
102
|
|
103
|
+
class GeminiPro1_5_001(GenAI): # pylint: disable=invalid-name
|
104
|
+
"""Gemini Pro 1.5 model stable version 001."""
|
382
105
|
|
383
|
-
|
384
|
-
"""Gemini Pro model."""
|
106
|
+
model = 'gemini-1.5-pro-001'
|
385
107
|
|
386
|
-
|
108
|
+
|
109
|
+
class GeminiFlash1_5(GenAI): # pylint: disable=invalid-name
|
110
|
+
"""Gemini Flash latest model."""
|
111
|
+
|
112
|
+
model = 'gemini-1.5-flash-latest'
|
387
113
|
|
388
114
|
|
389
|
-
class
|
390
|
-
"""Gemini
|
115
|
+
class GeminiFlash1_5_002(GenAI): # pylint: disable=invalid-name
|
116
|
+
"""Gemini Flash 1.5 model stable version 002."""
|
391
117
|
|
392
|
-
model = 'gemini-
|
393
|
-
supported_modalities = vertexai.IMAGE_TYPES + vertexai.VIDEO_TYPES
|
118
|
+
model = 'gemini-1.5-flash-002'
|
394
119
|
|
395
120
|
|
396
|
-
class
|
397
|
-
"""
|
121
|
+
class GeminiFlash1_5_001(GenAI): # pylint: disable=invalid-name
|
122
|
+
"""Gemini Flash 1.5 model stable version 001."""
|
398
123
|
|
399
|
-
model = '
|
124
|
+
model = 'gemini-1.5-flash-001'
|
400
125
|
|
401
126
|
|
402
|
-
class
|
403
|
-
"""
|
127
|
+
class GeminiPro1(GenAI): # pylint: disable=invalid-name
|
128
|
+
"""Gemini 1.0 Pro model."""
|
404
129
|
|
405
|
-
model = '
|
130
|
+
model = 'gemini-1.0-pro'
|