langfun 0.1.2.dev202501010804__py3-none-any.whl → 0.1.2.dev202501060804__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to a supported public registry. It is provided for informational purposes only.
Files changed (33)
  1. langfun/core/__init__.py +0 -4
  2. langfun/core/eval/matching.py +2 -2
  3. langfun/core/eval/scoring.py +6 -2
  4. langfun/core/eval/v2/checkpointing.py +106 -72
  5. langfun/core/eval/v2/checkpointing_test.py +108 -3
  6. langfun/core/eval/v2/eval_test_helper.py +56 -0
  7. langfun/core/eval/v2/evaluation.py +25 -4
  8. langfun/core/eval/v2/evaluation_test.py +11 -0
  9. langfun/core/eval/v2/example.py +11 -1
  10. langfun/core/eval/v2/example_test.py +16 -2
  11. langfun/core/eval/v2/experiment.py +83 -19
  12. langfun/core/eval/v2/experiment_test.py +121 -3
  13. langfun/core/eval/v2/reporting.py +67 -20
  14. langfun/core/eval/v2/reporting_test.py +119 -2
  15. langfun/core/eval/v2/runners.py +7 -4
  16. langfun/core/llms/__init__.py +23 -24
  17. langfun/core/llms/anthropic.py +12 -0
  18. langfun/core/llms/cache/in_memory.py +6 -0
  19. langfun/core/llms/cache/in_memory_test.py +5 -0
  20. langfun/core/llms/gemini.py +507 -0
  21. langfun/core/llms/gemini_test.py +195 -0
  22. langfun/core/llms/google_genai.py +46 -310
  23. langfun/core/llms/google_genai_test.py +9 -204
  24. langfun/core/llms/openai.py +23 -37
  25. langfun/core/llms/vertexai.py +28 -348
  26. langfun/core/llms/vertexai_test.py +6 -166
  27. {langfun-0.1.2.dev202501010804.dist-info → langfun-0.1.2.dev202501060804.dist-info}/METADATA +7 -13
  28. {langfun-0.1.2.dev202501010804.dist-info → langfun-0.1.2.dev202501060804.dist-info}/RECORD +31 -31
  29. {langfun-0.1.2.dev202501010804.dist-info → langfun-0.1.2.dev202501060804.dist-info}/WHEEL +1 -1
  30. langfun/core/repr_utils.py +0 -204
  31. langfun/core/repr_utils_test.py +0 -90
  32. {langfun-0.1.2.dev202501010804.dist-info → langfun-0.1.2.dev202501060804.dist-info}/LICENSE +0 -0
  33. {langfun-0.1.2.dev202501010804.dist-info → langfun-0.1.2.dev202501060804.dist-info}/top_level.txt +0 -0
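The headline change is a consolidation: a new langfun/core/llms/gemini.py (entry 20, +507 lines) carries the shared Gemini REST implementation, while google_genai.py (entry 22, -310 lines) and vertexai.py (entry 25, -348 lines) shrink to thin provider-specific subclasses. A minimal sketch of the resulting shape, inferred from the diffs shown below (only the GenAI side appears on this page; the VertexAI side is presumably analogous but its diff body is not shown):

# Sketch inferred from the diffs below; '<key>' is a placeholder credential.
from langfun.core.llms import gemini
from langfun.core.llms import google_genai

lm = google_genai.GeminiFlash2_0Exp(api_key='<key>')
assert isinstance(lm, gemini.Gemini)  # REST logic now lives in the shared base
print(lm.model_id)  # -> 'GenAI(gemini-2.0-flash-exp)'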
langfun/core/llms/gemini_test.py
@@ -0,0 +1,195 @@
+ # Copyright 2025 The Langfun Authors
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Tests for Gemini API."""
+
+ import base64
+ from typing import Any
+ import unittest
+ from unittest import mock
+
+ import langfun.core as lf
+ from langfun.core import modalities as lf_modalities
+ from langfun.core.llms import gemini
+ import pyglove as pg
+ import requests
+
+
+ example_image = (
+     b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x18\x00\x00\x00\x18\x04'
+     b'\x03\x00\x00\x00\x12Y \xcb\x00\x00\x00\x18PLTE\x00\x00'
+     b'\x00fff_chaag_cg_ch^ci_ciC\xedb\x94\x00\x00\x00\x08tRNS'
+     b'\x00\n\x9f*\xd4\xff_\xf4\xe4\x8b\xf3a\x00\x00\x00>IDATx'
+     b'\x01c \x05\x08)"\xd8\xcc\xae!\x06pNz\x88k\x19\\Q\xa8"\x10'
+     b'\xc1\x14\x95\x01%\xc1\n\xa143Ta\xa8"D-\x84\x03QM\x98\xc3'
+     b'\x1a\x1a\x1a@5\x0e\x04\xa0q\x88\x05\x00\x07\xf8\x18\xf9'
+     b'\xdao\xd0|\x00\x00\x00\x00IEND\xaeB`\x82'
+ )
+
+
+ def mock_requests_post(url: str, json: dict[str, Any], **kwargs):
+   del url, kwargs
+   c = pg.Dict(json['generationConfig'])
+   content = json['contents'][0]['parts'][0]['text']
+   response = requests.Response()
+   response.status_code = 200
+   response._content = pg.to_json_str({
+       'candidates': [
+           {
+               'content': {
+                   'role': 'model',
+                   'parts': [
+                       {
+                           'text': (
+                               f'This is a response to {content} with '
+                               f'temperature={c.temperature}, '
+                               f'top_p={c.topP}, '
+                               f'top_k={c.topK}, '
+                               f'max_tokens={c.maxOutputTokens}, '
+                               f'stop={"".join(c.stopSequences)}.'
+                           ),
+                       },
+                       {
+                           'text': 'This is the thought.',
+                           'thought': True,
+                       }
+                   ],
+               },
+           },
+       ],
+       'usageMetadata': {
+           'promptTokenCount': 3,
+           'candidatesTokenCount': 4,
+       }
+   }).encode()
+   return response
+
+
+ class GeminiTest(unittest.TestCase):
+   """Tests for Vertex model with REST API."""
+
+   def test_content_from_message_text_only(self):
+     text = 'This is a beautiful day'
+     model = gemini.Gemini('gemini-1.5-pro', api_endpoint='')
+     chunks = model._content_from_message(lf.UserMessage(text))
+     self.assertEqual(chunks, {'role': 'user', 'parts': [{'text': text}]})
+
+   def test_content_from_message_mm(self):
+     image = lf_modalities.Image.from_bytes(example_image)
+     message = lf.UserMessage(
+         'This is an <<[[image]]>>, what is it?', image=image
+     )
+
+     # Non-multimodal model.
+     with self.assertRaisesRegex(lf.ModalityError, 'Unsupported modality'):
+       gemini.Gemini(
+           'gemini-1.0-pro', api_endpoint=''
+       )._content_from_message(message)
+
+     model = gemini.Gemini('gemini-1.5-pro', api_endpoint='')
+     content = model._content_from_message(message)
+     self.assertEqual(
+         content,
+         {
+             'role': 'user',
+             'parts': [
+                 {'text': 'This is an'},
+                 {
+                     'inlineData': {
+                         'data': base64.b64encode(example_image).decode(),
+                         'mimeType': 'image/png',
+                     }
+                 },
+                 {'text': ', what is it?'},
+             ],
+         },
+     )
+
+   def test_generation_config(self):
+     model = gemini.Gemini('gemini-1.5-pro', api_endpoint='')
+     json_schema = {
+         'type': 'object',
+         'properties': {
+             'name': {'type': 'string'},
+         },
+         'required': ['name'],
+         'title': 'Person',
+     }
+     actual = model._generation_config(
+         lf.UserMessage('hi', json_schema=json_schema),
+         lf.LMSamplingOptions(
+             temperature=2.0,
+             top_p=1.0,
+             top_k=20,
+             max_tokens=1024,
+             stop=['\n'],
+         ),
+     )
+     self.assertEqual(
+         actual,
+         dict(
+             candidateCount=1,
+             temperature=2.0,
+             topP=1.0,
+             topK=20,
+             maxOutputTokens=1024,
+             stopSequences=['\n'],
+             responseLogprobs=False,
+             logprobs=None,
+             seed=None,
+             responseMimeType='application/json',
+             responseSchema={
+                 'type': 'object',
+                 'properties': {
+                     'name': {'type': 'string'}
+                 },
+                 'required': ['name'],
+                 'title': 'Person',
+             }
+         ),
+     )
+     with self.assertRaisesRegex(
+         ValueError, '`json_schema` must be a dict, got'
+     ):
+       model._generation_config(
+           lf.UserMessage('hi', json_schema='not a dict'),
+           lf.LMSamplingOptions(),
+       )
+
+   def test_call_model(self):
+     with mock.patch('requests.Session.post') as mock_generate:
+       mock_generate.side_effect = mock_requests_post
+
+       lm = gemini.Gemini('gemini-1.5-pro', api_endpoint='')
+       r = lm(
+           'hello',
+           temperature=2.0,
+           top_p=1.0,
+           top_k=20,
+           max_tokens=1024,
+           stop='\n',
+       )
+       self.assertEqual(
+           r.text,
+           (
+               'This is a response to hello with temperature=2.0, '
+               'top_p=1.0, top_k=20, max_tokens=1024, stop=\n.'
+           ),
+       )
+       self.assertEqual(r.metadata.thought, 'This is the thought.')
+       self.assertEqual(r.metadata.usage.prompt_tokens, 3)
+       self.assertEqual(r.metadata.usage.completion_tokens, 4)
+
+
+ if __name__ == '__main__':
+   unittest.main()
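For orientation, mock_requests_post above also pins down the request body that gemini.Gemini posts to the generateContent endpoint. A hedged reconstruction of that shape for the plain-text call in test_call_model, using only the fields the mock reads and test_generation_config asserts (any additional keys the real client sends are not shown):

# Request JSON implied by the tests above (a reconstruction, not the actual
# payload builder): one user turn plus the camelCase generation config.
request_json = {
    'contents': [{
        'role': 'user',
        'parts': [{'text': 'hello'}],
    }],
    'generationConfig': {
        'candidateCount': 1,
        'temperature': 2.0,
        'topP': 1.0,
        'topK': 20,
        'maxOutputTokens': 1024,
        'stopSequences': ['\n'],
        'responseLogprobs': False,
        'logprobs': None,
        'seed': None,
    },
}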
langfun/core/llms/google_genai.py
@@ -1,4 +1,4 @@
- # Copyright 2024 The Langfun Authors
+ # Copyright 2025 The Langfun Authors
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
  # you may not use this file except in compliance with the License.
@@ -11,56 +11,21 @@
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  # See the License for the specific language governing permissions and
  # limitations under the License.
- """Gemini models exposed through Google Generative AI APIs."""
+ """Language models from Google GenAI."""

- import abc
- import functools
  import os
- from typing import Annotated, Any, Literal
+ from typing import Annotated, Literal

  import langfun.core as lf
- from langfun.core import modalities as lf_modalities
- from langfun.core.llms import vertexai
+ from langfun.core.llms import gemini
  import pyglove as pg


- try:
-   import google.generativeai as genai  # pylint: disable=g-import-not-at-top
-   BlobDict = genai.types.BlobDict
-   GenerativeModel = genai.GenerativeModel
-   Completion = getattr(genai.types, 'Completion', Any)
-   ChatResponse = getattr(genai.types, 'ChatResponse', Any)
-   GenerateContentResponse = getattr(genai.types, 'GenerateContentResponse', Any)
-   GenerationConfig = genai.GenerationConfig
- except ImportError:
-   genai = None
-   BlobDict = Any
-   GenerativeModel = Any
-   Completion = Any
-   ChatResponse = Any
-   GenerationConfig = Any
-   GenerateContentResponse = Any
-
-
  @lf.use_init_args(['model'])
- class GenAI(lf.LanguageModel):
+ @pg.members([('api_endpoint', pg.typing.Str().freeze(''))])
+ class GenAI(gemini.Gemini):
    """Language models provided by Google GenAI."""

-   model: Annotated[
-       Literal[
-           'gemini-2.0-flash-exp',
-           'gemini-exp-1206',
-           'gemini-exp-1114',
-           'gemini-1.5-pro-latest',
-           'gemini-1.5-flash-latest',
-           'gemini-pro',
-           'gemini-pro-vision',
-           'text-bison-001',
-           'chat-bison-001',
-       ],
-       'Model name.',
-   ]
-
    api_key: Annotated[
        str | None,
        (
@@ -70,26 +35,18 @@ class GenAI(lf.LanguageModel):
        ),
    ] = None

-   supported_modalities: Annotated[
-       list[str],
-       'A list of MIME types for supported modalities'
-   ] = []
+   api_version: Annotated[
+       Literal['v1beta', 'v1alpha'],
+       'The API version to use.'
+   ] = 'v1beta'

-   # Set the default max concurrency to 8 workers.
-   max_concurrency = 8
-
-   def _on_bound(self):
-     super()._on_bound()
-     if genai is None:
-       raise RuntimeError(
-           'Please install "langfun[llm-google-genai]" to use '
-           'Google Generative AI models.'
-       )
-     self.__dict__.pop('_api_initialized', None)
+   @property
+   def model_id(self) -> str:
+     """Returns a string to identify the model."""
+     return f'GenAI({self.model})'

-   @functools.cached_property
-   def _api_initialized(self):
-     assert genai is not None
+   @property
+   def api_endpoint(self) -> str:
      api_key = self.api_key or os.environ.get('GOOGLE_API_KEY', None)
      if not api_key:
        raise ValueError(
@@ -99,296 +56,75 @@ class GenAI(lf.LanguageModel):
            'https://cloud.google.com/api-keys/docs/create-manage-api-keys '
            'for more details.'
        )
-     genai.configure(api_key=api_key)
-     return True
-
-   @classmethod
-   def dir(cls) -> list[str]:
-     """Lists generative models."""
-     assert genai is not None
-     return [
-         m.name.lstrip('models/')
-         for m in genai.list_models()
-         if (
-             'generateContent' in m.supported_generation_methods
-             or 'generateText' in m.supported_generation_methods
-             or 'generateMessage' in m.supported_generation_methods
-         )
-     ]
-
-   @property
-   def model_id(self) -> str:
-     """Returns a string to identify the model."""
-     return self.model
-
-   @property
-   def resource_id(self) -> str:
-     """Returns a string to identify the resource for rate control."""
-     return self.model_id
-
-   def _generation_config(self, options: lf.LMSamplingOptions) -> dict[str, Any]:
-     """Creates generation config from langfun sampling options."""
-     return GenerationConfig(
-         candidate_count=options.n,
-         temperature=options.temperature,
-         top_p=options.top_p,
-         top_k=options.top_k,
-         max_output_tokens=options.max_tokens,
-         stop_sequences=options.stop,
-     )
-
-   def _content_from_message(
-       self, prompt: lf.Message
-   ) -> list[str | BlobDict]:
-     """Gets Evergreen formatted content from langfun message."""
-     formatted = lf.UserMessage(prompt.text)
-     formatted.source = prompt
-
-     chunks = []
-     for lf_chunk in formatted.chunk():
-       if isinstance(lf_chunk, str):
-         chunks.append(lf_chunk)
-       elif isinstance(lf_chunk, lf_modalities.Mime):
-         try:
-           modalities = lf_chunk.make_compatible(
-               self.supported_modalities + ['text/plain']
-           )
-           if isinstance(modalities, lf_modalities.Mime):
-             modalities = [modalities]
-           for modality in modalities:
-             if modality.is_text:
-               chunk = modality.to_text()
-             else:
-               chunk = BlobDict(
-                   data=modality.to_bytes(),
-                   mime_type=modality.mime_type
-               )
-             chunks.append(chunk)
-         except lf.ModalityError as e:
-           raise lf.ModalityError(f'Unsupported modality: {lf_chunk!r}') from e
-       else:
-         raise lf.ModalityError(f'Unsupported modality: {lf_chunk!r}')
-     return chunks
-
-   def _response_to_result(
-       self, response: GenerateContentResponse | pg.Dict
-   ) -> lf.LMSamplingResult:
-     """Parses generative response into message."""
-     samples = []
-     for candidate in response.candidates:
-       chunks = []
-       for part in candidate.content.parts:
-         # TODO(daiyip): support multi-modal parts when they are available via
-         # Gemini API.
-         if hasattr(part, 'text'):
-           chunks.append(part.text)
-       samples.append(lf.LMSample(lf.AIMessage.from_chunks(chunks), score=0.0))
-     return lf.LMSamplingResult(samples)
-
-   def _sample(self, prompts: list[lf.Message]) -> list[lf.LMSamplingResult]:
-     assert self._api_initialized, 'Vertex AI API is not initialized.'
-     return self._parallel_execute_with_currency_control(
-         self._sample_single,
-         prompts,
-     )
-
-   def _sample_single(self, prompt: lf.Message) -> lf.LMSamplingResult:
-     """Samples a single prompt."""
-     model = _GOOGLE_GENAI_MODEL_HUB.get(self.model)
-     input_content = self._content_from_message(prompt)
-     response = model.generate_content(
-         input_content,
-         generation_config=self._generation_config(self.sampling_options),
-     )
-     return self._response_to_result(response)
-
-
- class _LegacyGenerativeModel(pg.Object):
-   """Base for legacy GenAI generative model."""
-
-   model: str
-
-   def generate_content(
-       self,
-       input_content: list[str | BlobDict],
-       generation_config: GenerationConfig,
-   ) -> pg.Dict:
-     """Generate content."""
-     segments = []
-     for s in input_content:
-       if not isinstance(s, str):
-         raise ValueError(f'Unsupported modality: {s!r}')
-       segments.append(s)
-     return self.generate(' '.join(segments), generation_config)
-
-   @abc.abstractmethod
-   def generate(
-       self, prompt: str, generation_config: GenerationConfig) -> pg.Dict:
-     """Generate response based on prompt."""
-
-
- class _LegacyCompletionModel(_LegacyGenerativeModel):
-   """Legacy GenAI completion model."""
-
-   def generate(
-       self, prompt: str, generation_config: GenerationConfig
-   ) -> pg.Dict:
-     assert genai is not None
-     completion: Completion = genai.generate_text(
-         model=f'models/{self.model}',
-         prompt=prompt,
-         temperature=generation_config.temperature,
-         top_k=generation_config.top_k,
-         top_p=generation_config.top_p,
-         candidate_count=generation_config.candidate_count,
-         max_output_tokens=generation_config.max_output_tokens,
-         stop_sequences=generation_config.stop_sequences,
-     )
-     return pg.Dict(
-         candidates=[
-             pg.Dict(content=pg.Dict(parts=[pg.Dict(text=c['output'])]))
-             for c in completion.candidates
-         ]
-     )
-
-
- class _LegacyChatModel(_LegacyGenerativeModel):
-   """Legacy GenAI chat model."""
-
-   def generate(
-       self, prompt: str, generation_config: GenerationConfig
-   ) -> pg.Dict:
-     assert genai is not None
-     response: ChatResponse = genai.chat(
-         model=f'models/{self.model}',
-         messages=prompt,
-         temperature=generation_config.temperature,
-         top_k=generation_config.top_k,
-         top_p=generation_config.top_p,
-         candidate_count=generation_config.candidate_count,
-     )
-     return pg.Dict(
-         candidates=[
-             pg.Dict(content=pg.Dict(parts=[pg.Dict(text=c['content'])]))
-             for c in response.candidates
-         ]
+     return (
+         f'https://generativelanguage.googleapis.com/{self.api_version}'
+         f'/models/{self.model}:generateContent?'
+         f'key={api_key}'
      )


- class _ModelHub:
-   """Google Generative AI model hub."""
-
-   def __init__(self):
-     self._model_cache = {}
-
-   def get(
-       self, model_name: str
-   ) -> GenerativeModel | _LegacyGenerativeModel:
-     """Gets a generative model by model id."""
-     assert genai is not None
-     model = self._model_cache.get(model_name, None)
-     if model is None:
-       model_info = genai.get_model(f'models/{model_name}')
-       if 'generateContent' in model_info.supported_generation_methods:
-         model = genai.GenerativeModel(model_name)
-       elif 'generateText' in model_info.supported_generation_methods:
-         model = _LegacyCompletionModel(model_name)
-       elif 'generateMessage' in model_info.supported_generation_methods:
-         model = _LegacyChatModel(model_name)
-       else:
-         raise ValueError(f'Unsupported model: {model_name!r}')
-       self._model_cache[model_name] = model
-     return model
-
+ class GeminiFlash2_0ThinkingExp_20241219(GenAI):  # pylint: disable=invalid-name
+   """Gemini Flash 2.0 Thinking model launched on 12/19/2024."""

- _GOOGLE_GENAI_MODEL_HUB = _ModelHub()
-
-
- #
- # Public Gemini models.
- #
+   api_version = 'v1alpha'
+   model = 'gemini-2.0-flash-thinking-exp-1219'


  class GeminiFlash2_0Exp(GenAI):  # pylint: disable=invalid-name
-   """Gemini Experimental model launched on 12/06/2024."""
+   """Gemini Flash 2.0 model launched on 12/11/2024."""

    model = 'gemini-2.0-flash-exp'
-   supported_modalities = (
-       vertexai.DOCUMENT_TYPES
-       + vertexai.IMAGE_TYPES
-       + vertexai.AUDIO_TYPES
-       + vertexai.VIDEO_TYPES
-   )


  class GeminiExp_20241206(GenAI):  # pylint: disable=invalid-name
    """Gemini Experimental model launched on 12/06/2024."""

    model = 'gemini-exp-1206'
-   supported_modalities = (
-       vertexai.DOCUMENT_TYPES
-       + vertexai.IMAGE_TYPES
-       + vertexai.AUDIO_TYPES
-       + vertexai.VIDEO_TYPES
-   )


  class GeminiExp_20241114(GenAI):  # pylint: disable=invalid-name
    """Gemini Experimental model launched on 11/14/2024."""

    model = 'gemini-exp-1114'
-   supported_modalities = (
-       vertexai.DOCUMENT_TYPES
-       + vertexai.IMAGE_TYPES
-       + vertexai.AUDIO_TYPES
-       + vertexai.VIDEO_TYPES
-   )


  class GeminiPro1_5(GenAI):  # pylint: disable=invalid-name
    """Gemini Pro latest model."""

    model = 'gemini-1.5-pro-latest'
-   supported_modalities = (
-       vertexai.DOCUMENT_TYPES
-       + vertexai.IMAGE_TYPES
-       + vertexai.AUDIO_TYPES
-       + vertexai.VIDEO_TYPES
-   )


- class GeminiFlash1_5(GenAI):  # pylint: disable=invalid-name
-   """Gemini Flash latest model."""
+ class GeminiPro1_5_002(GenAI):  # pylint: disable=invalid-name
+   """Gemini Pro latest model."""
+
+   model = 'gemini-1.5-pro-002'

-   model = 'gemini-1.5-flash-latest'
-   supported_modalities = (
-       vertexai.DOCUMENT_TYPES
-       + vertexai.IMAGE_TYPES
-       + vertexai.AUDIO_TYPES
-       + vertexai.VIDEO_TYPES
-   )

+ class GeminiPro1_5_001(GenAI):  # pylint: disable=invalid-name
+   """Gemini Pro latest model."""

- class GeminiPro(GenAI):
-   """Gemini Pro model."""
+   model = 'gemini-1.5-pro-001'

-   model = 'gemini-pro'
+
+ class GeminiFlash1_5(GenAI):  # pylint: disable=invalid-name
+   """Gemini Flash latest model."""
+
+   model = 'gemini-1.5-flash-latest'


- class GeminiProVision(GenAI):
-   """Gemini Pro vision model."""
+ class GeminiFlash1_5_002(GenAI):  # pylint: disable=invalid-name
+   """Gemini Flash 1.5 model stable version 002."""

-   model = 'gemini-pro-vision'
-   supported_modalities = vertexai.IMAGE_TYPES + vertexai.VIDEO_TYPES
+   model = 'gemini-1.5-flash-002'


- class Palm2(GenAI):
-   """PaLM2 model."""
+ class GeminiFlash1_5_001(GenAI):  # pylint: disable=invalid-name
+   """Gemini Flash 1.5 model stable version 001."""

-   model = 'text-bison-001'
+   model = 'gemini-1.5-flash-001'


- class Palm2_IT(GenAI):  # pylint: disable=invalid-name
-   """PaLM2 instruction-tuned model."""
+ class GeminiPro1(GenAI):  # pylint: disable=invalid-name
+   """Gemini 1.0 Pro model."""

-   model = 'chat-bison-001'
+   model = 'gemini-1.0-pro'
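To make the new endpoint logic concrete, a small usage sketch (the key value is a placeholder for illustration; everything else follows the properties defined in the diff above):

import os

from langfun.core.llms import google_genai

os.environ['GOOGLE_API_KEY'] = 'my-key'  # placeholder, not a real credential
lm = google_genai.GeminiFlash2_0ThinkingExp_20241219()
print(lm.model_id)
# GenAI(gemini-2.0-flash-thinking-exp-1219)
print(lm.api_endpoint)
# https://generativelanguage.googleapis.com/v1alpha/models/
#   gemini-2.0-flash-thinking-exp-1219:generateContent?key=my-key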