langfun 0.1.2.dev202501050804__py3-none-any.whl → 0.1.2.dev202501060804__py3-none-any.whl

This diff compares the contents of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registry.
@@ -11,223 +11,28 @@
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  # See the License for the specific language governing permissions and
  # limitations under the License.
- """Tests for Gemini models."""
+ """Tests for Google GenAI models."""

  import os
  import unittest
- from unittest import mock
-
- from google import generativeai as genai
- import langfun.core as lf
- from langfun.core import modalities as lf_modalities
  from langfun.core.llms import google_genai
- import pyglove as pg
-
-
- example_image = (
-     b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x18\x00\x00\x00\x18\x04'
-     b'\x03\x00\x00\x00\x12Y \xcb\x00\x00\x00\x18PLTE\x00\x00'
-     b'\x00fff_chaag_cg_ch^ci_ciC\xedb\x94\x00\x00\x00\x08tRNS'
-     b'\x00\n\x9f*\xd4\xff_\xf4\xe4\x8b\xf3a\x00\x00\x00>IDATx'
-     b'\x01c \x05\x08)"\xd8\xcc\xae!\x06pNz\x88k\x19\\Q\xa8"\x10'
-     b'\xc1\x14\x95\x01%\xc1\n\xa143Ta\xa8"D-\x84\x03QM\x98\xc3'
-     b'\x1a\x1a\x1a@5\x0e\x04\xa0q\x88\x05\x00\x07\xf8\x18\xf9'
-     b'\xdao\xd0|\x00\x00\x00\x00IEND\xaeB`\x82'
- )
-
-
- def mock_get_model(model_name, *args, **kwargs):
-   del args, kwargs
-   if 'gemini' in model_name:
-     method = 'generateContent'
-   elif 'chat' in model_name:
-     method = 'generateMessage'
-   else:
-     method = 'generateText'
-   return pg.Dict(supported_generation_methods=[method])
-
-
- def mock_generate_text(*, model, prompt, **kwargs):
-   return pg.Dict(
-       candidates=[pg.Dict(output=f'{prompt} to {model} with {kwargs}')]
-   )
-
-
- def mock_chat(*, model, messages, **kwargs):
-   return pg.Dict(
-       candidates=[pg.Dict(content=f'{messages} to {model} with {kwargs}')]
-   )
-
-
- def mock_generate_content(content, generation_config, **kwargs):
-   del kwargs
-   c = generation_config
-   return genai.types.GenerateContentResponse(
-       done=True,
-       iterator=None,
-       chunks=[],
-       result=pg.Dict(
-           prompt_feedback=pg.Dict(block_reason=None),
-           candidates=[
-               pg.Dict(
-                   content=pg.Dict(
-                       parts=[
-                           pg.Dict(
-                               text=(
-                                   f'This is a response to {content[0]} with '
-                                   f'n={c.candidate_count}, '
-                                   f'temperature={c.temperature}, '
-                                   f'top_p={c.top_p}, '
-                                   f'top_k={c.top_k}, '
-                                   f'max_tokens={c.max_output_tokens}, '
-                                   f'stop={c.stop_sequences}.'
-                               )
-                           )
-                       ]
-                   ),
-               ),
-           ],
-       ),
-   )


  class GenAITest(unittest.TestCase):
-   """Tests for Google GenAI model."""
-
-   def test_content_from_message_text_only(self):
-     text = 'This is a beautiful day'
-     model = google_genai.GeminiPro()
-     chunks = model._content_from_message(lf.UserMessage(text))
-     self.assertEqual(chunks, [text])
-
-   def test_content_from_message_mm(self):
-     message = lf.UserMessage(
-         'This is an <<[[image]]>>, what is it?',
-         image=lf_modalities.Image.from_bytes(example_image),
-     )
+   """Tests for GenAI model."""

-     # Non-multimodal model.
-     with self.assertRaisesRegex(lf.ModalityError, 'Unsupported modality'):
-       google_genai.GeminiPro()._content_from_message(message)
-
-     model = google_genai.GeminiProVision()
-     chunks = model._content_from_message(message)
-     self.maxDiff = None
-     self.assertEqual(
-         chunks,
-         [
-             'This is an',
-             genai.types.BlobDict(mime_type='image/png', data=example_image),
-             ', what is it?',
-         ],
-     )
-
-   def test_response_to_result_text_only(self):
-     response = genai.types.GenerateContentResponse(
-         done=True,
-         iterator=None,
-         chunks=[],
-         result=pg.Dict(
-             prompt_feedback=pg.Dict(block_reason=None),
-             candidates=[
-                 pg.Dict(
-                     content=pg.Dict(
-                         parts=[pg.Dict(text='This is response 1.')]
-                     ),
-                 ),
-                 pg.Dict(
-                     content=pg.Dict(parts=[pg.Dict(text='This is response 2.')])
-                 ),
-             ],
-         ),
-     )
-     model = google_genai.GeminiProVision()
-     result = model._response_to_result(response)
-     self.assertEqual(
-         result,
-         lf.LMSamplingResult([
-             lf.LMSample(lf.AIMessage('This is response 1.'), score=0.0),
-             lf.LMSample(lf.AIMessage('This is response 2.'), score=0.0),
-         ]),
-     )
-
-   def test_model_hub(self):
-     orig_get_model = genai.get_model
-     genai.get_model = mock_get_model
-
-     model = google_genai._GOOGLE_GENAI_MODEL_HUB.get('gemini-pro')
-     self.assertIsNotNone(model)
-     self.assertIs(google_genai._GOOGLE_GENAI_MODEL_HUB.get('gemini-pro'), model)
-
-     genai.get_model = orig_get_model
-
-   def test_api_key_check(self):
+   def test_basics(self):
      with self.assertRaisesRegex(ValueError, 'Please specify `api_key`'):
-       _ = google_genai.GeminiPro()._api_initialized
+       _ = google_genai.GeminiPro1_5().api_endpoint
+
+     self.assertIsNotNone(google_genai.GeminiPro1_5(api_key='abc').api_endpoint)

-     self.assertTrue(google_genai.GeminiPro(api_key='abc')._api_initialized)
      os.environ['GOOGLE_API_KEY'] = 'abc'
-     self.assertTrue(google_genai.GeminiPro()._api_initialized)
+     lm = google_genai.GeminiPro1_5()
+     self.assertIsNotNone(lm.api_endpoint)
+     self.assertTrue(lm.model_id.startswith('GenAI('))
      del os.environ['GOOGLE_API_KEY']

-   def test_call(self):
-     with mock.patch(
-         'google.generativeai.GenerativeModel.generate_content',
-     ) as mock_generate:
-       orig_get_model = genai.get_model
-       genai.get_model = mock_get_model
-       mock_generate.side_effect = mock_generate_content
-
-       lm = google_genai.GeminiPro(api_key='test_key')
-       self.maxDiff = None
-       self.assertEqual(
-           lm('hello', temperature=2.0, top_k=20, max_tokens=1024).text,
-           (
-               'This is a response to hello with n=1, temperature=2.0, '
-               'top_p=None, top_k=20, max_tokens=1024, stop=None.'
-           ),
-       )
-       genai.get_model = orig_get_model
-
-   def test_call_with_legacy_completion_model(self):
-     orig_get_model = genai.get_model
-     genai.get_model = mock_get_model
-     orig_generate_text = getattr(genai, 'generate_text', None)
-     if orig_generate_text is not None:
-       genai.generate_text = mock_generate_text
-
-     lm = google_genai.Palm2(api_key='test_key')
-     self.maxDiff = None
-     self.assertEqual(
-         lm('hello', temperature=2.0, top_k=20).text,
-         (
-             "hello to models/text-bison-001 with {'temperature': 2.0, "
-             "'top_k': 20, 'top_p': None, 'candidate_count': 1, "
-             "'max_output_tokens': None, 'stop_sequences': None}"
-         ),
-     )
-     genai.generate_text = orig_generate_text
-     genai.get_model = orig_get_model
-
-   def test_call_with_legacy_chat_model(self):
-     orig_get_model = genai.get_model
-     genai.get_model = mock_get_model
-     orig_chat = getattr(genai, 'chat', None)
-     if orig_chat is not None:
-       genai.chat = mock_chat
-
-     lm = google_genai.Palm2_IT(api_key='test_key')
-     self.maxDiff = None
-     self.assertEqual(
-         lm('hello', temperature=2.0, top_k=20).text,
-         (
-             "hello to models/chat-bison-001 with {'temperature': 2.0, "
-             "'top_k': 20, 'top_p': None, 'candidate_count': 1}"
-         ),
-     )
-     genai.chat = orig_chat
-     genai.get_model = orig_get_model
-

  if __name__ == '__main__':
    unittest.main()
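
The rewritten test above drops the `google.generativeai` mocks entirely: the GenAI models are now exercised through a REST-oriented surface (`api_endpoint`, `model_id`) instead of the removed `_api_initialized` flag. A minimal usage sketch consistent with the new test, assuming the new langfun version is installed; the `GOOGLE_API_KEY` value is a placeholder:

    import os
    from langfun.core.llms import google_genai

    os.environ['GOOGLE_API_KEY'] = 'my-api-key'  # placeholder credential
    lm = google_genai.GeminiPro1_5()
    print(lm.model_id)       # starts with 'GenAI(' per the new test
    print(lm.api_endpoint)   # resolves once an API key is set; otherwise
                             # raises ValueError('Please specify `api_key` ...')
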
@@ -13,14 +13,12 @@
  # limitations under the License.
  """Vertex AI generative models."""

- import base64
  import functools
  import os
  from typing import Annotated, Any

  import langfun.core as lf
- from langfun.core import modalities as lf_modalities
- from langfun.core.llms import rest
+ from langfun.core.llms import gemini
  import pyglove as pg

  try:
@@ -38,114 +36,11 @@ except ImportError:
    Credentials = Any


- # https://cloud.google.com/vertex-ai/generative-ai/pricing
- # describes that the average number of characters per token is about 4.
- AVGERAGE_CHARS_PER_TOKEN = 4
-
-
- # Price in US dollars,
- # from https://cloud.google.com/vertex-ai/generative-ai/pricing
- # as of 2024-10-10.
- SUPPORTED_MODELS_AND_SETTINGS = {
-     'gemini-1.5-pro-001': pg.Dict(
-         rpm=100,
-         cost_per_1k_input_chars=0.0003125,
-         cost_per_1k_output_chars=0.00125,
-     ),
-     'gemini-1.5-pro-002': pg.Dict(
-         rpm=100,
-         cost_per_1k_input_chars=0.0003125,
-         cost_per_1k_output_chars=0.00125,
-     ),
-     'gemini-1.5-flash-002': pg.Dict(
-         rpm=500,
-         cost_per_1k_input_chars=0.00001875,
-         cost_per_1k_output_chars=0.000075,
-     ),
-     'gemini-1.5-flash-001': pg.Dict(
-         rpm=500,
-         cost_per_1k_input_chars=0.00001875,
-         cost_per_1k_output_chars=0.000075,
-     ),
-     'gemini-1.5-pro': pg.Dict(
-         rpm=100,
-         cost_per_1k_input_chars=0.0003125,
-         cost_per_1k_output_chars=0.00125,
-     ),
-     'gemini-1.5-flash': pg.Dict(
-         rpm=500,
-         cost_per_1k_input_chars=0.00001875,
-         cost_per_1k_output_chars=0.000075,
-     ),
-     'gemini-1.5-pro-preview-0514': pg.Dict(
-         rpm=50,
-         cost_per_1k_input_chars=0.0003125,
-         cost_per_1k_output_chars=0.00125,
-     ),
-     'gemini-1.5-pro-preview-0409': pg.Dict(
-         rpm=50,
-         cost_per_1k_input_chars=0.0003125,
-         cost_per_1k_output_chars=0.00125,
-     ),
-     'gemini-1.5-flash-preview-0514': pg.Dict(
-         rpm=200,
-         cost_per_1k_input_chars=0.00001875,
-         cost_per_1k_output_chars=0.000075,
-     ),
-     'gemini-1.0-pro': pg.Dict(
-         rpm=300,
-         cost_per_1k_input_chars=0.000125,
-         cost_per_1k_output_chars=0.000375,
-     ),
-     'gemini-1.0-pro-vision': pg.Dict(
-         rpm=100,
-         cost_per_1k_input_chars=0.000125,
-         cost_per_1k_output_chars=0.000375,
-     ),
-     # TODO(sharatsharat): Update costs when published
-     'gemini-exp-1206': pg.Dict(
-         rpm=20,
-         cost_per_1k_input_chars=0.000,
-         cost_per_1k_output_chars=0.000,
-     ),
-     # TODO(sharatsharat): Update costs when published
-     'gemini-2.0-flash-exp': pg.Dict(
-         rpm=10,
-         cost_per_1k_input_chars=0.000,
-         cost_per_1k_output_chars=0.000,
-     ),
-     # TODO(yifenglu): Update costs when published
-     'gemini-2.0-flash-thinking-exp-1219': pg.Dict(
-         rpm=10,
-         cost_per_1k_input_chars=0.000,
-         cost_per_1k_output_chars=0.000,
-     ),
-     # TODO(chengrun): Set a more appropriate rpm for endpoint.
-     'vertexai-endpoint': pg.Dict(
-         rpm=20,
-         cost_per_1k_input_chars=0.0000125,
-         cost_per_1k_output_chars=0.0000375,
-     ),
- }
-
-
  @lf.use_init_args(['model'])
  @pg.members([('api_endpoint', pg.typing.Str().freeze(''))])
- class VertexAI(rest.REST):
+ class VertexAI(gemini.Gemini):
    """Language model served on VertexAI with REST API."""

-   model: pg.typing.Annotated[
-       pg.typing.Enum(
-           pg.MISSING_VALUE, list(SUPPORTED_MODELS_AND_SETTINGS.keys())
-       ),
-       (
-           'Vertex AI model name with REST API support. See '
-           'https://cloud.google.com/vertex-ai/generative-ai/docs/'
-           'model-reference/inference#supported-models'
-           ' for details.'
-       ),
-   ]
-
    project: Annotated[
      str | None,
      (
@@ -170,11 +65,6 @@ class VertexAI(rest.REST):
        ),
    ] = None

-   supported_modalities: Annotated[
-       list[str],
-       'A list of MIME types for supported modalities'
-   ] = []
-
    def _on_bound(self):
      super()._on_bound()
      if google_auth is None:
@@ -209,31 +99,9 @@ class VertexAI(rest.REST):
      self._credentials = credentials

    @property
-   def max_concurrency(self) -> int:
-     """Returns the maximum number of concurrent requests."""
-     return self.rate_to_max_concurrency(
-         requests_per_min=SUPPORTED_MODELS_AND_SETTINGS[self.model].rpm,
-         tokens_per_min=0,
-     )
-
-   def estimate_cost(
-       self,
-       num_input_tokens: int,
-       num_output_tokens: int
-   ) -> float | None:
-     """Estimate the cost based on usage."""
-     cost_per_1k_input_chars = SUPPORTED_MODELS_AND_SETTINGS[self.model].get(
-         'cost_per_1k_input_chars', None
-     )
-     cost_per_1k_output_chars = SUPPORTED_MODELS_AND_SETTINGS[self.model].get(
-         'cost_per_1k_output_chars', None
-     )
-     if cost_per_1k_output_chars is None or cost_per_1k_input_chars is None:
-       return None
-     return (
-         cost_per_1k_input_chars * num_input_tokens
-         + cost_per_1k_output_chars * num_output_tokens
-     ) * AVGERAGE_CHARS_PER_TOKEN / 1000
+   def model_id(self) -> str:
+     """Returns a string to identify the model."""
+     return f'VertexAI({self.model})'

    @functools.cached_property
    def _session(self):
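
For reference, the deleted character-based cost estimate (the removed `estimate_cost` above, together with the constants and per-model settings removed earlier in this file) amounted to the following arithmetic. A worked example using the deleted `gemini-1.5-pro` rates, with the constant's spelling corrected from the original `AVGERAGE_CHARS_PER_TOKEN`:

    # The removed formula scaled per-1k-character prices to tokens,
    # assuming ~4 characters per token.
    AVERAGE_CHARS_PER_TOKEN = 4
    cost_per_1k_input_chars = 0.0003125   # deleted 'gemini-1.5-pro' setting
    cost_per_1k_output_chars = 0.00125    # deleted 'gemini-1.5-pro' setting

    num_input_tokens, num_output_tokens = 1000, 1000
    cost = (
        cost_per_1k_input_chars * num_input_tokens
        + cost_per_1k_output_chars * num_output_tokens
    ) * AVERAGE_CHARS_PER_TOKEN / 1000
    print(cost)  # 0.00625 USD for 1k input + 1k output tokens
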
@@ -244,12 +112,6 @@ class VertexAI(rest.REST):
      s.headers.update(self.headers or {})
      return s

-   @property
-   def headers(self):
-     return {
-         'Content-Type': 'application/json; charset=utf-8',
-     }
-
    @property
    def api_endpoint(self) -> str:
      return (
@@ -258,263 +120,69 @@ class VertexAI(rest.REST):
          f'models/{self.model}:generateContent'
      )

-   def request(
-       self, prompt: lf.Message, sampling_options: lf.LMSamplingOptions
-   ) -> dict[str, Any]:
-     request = dict(
-         generationConfig=self._generation_config(prompt, sampling_options)
-     )
-     request['contents'] = [self._content_from_message(prompt)]
-     return request
-
-   def _generation_config(
-       self, prompt: lf.Message, options: lf.LMSamplingOptions
-   ) -> dict[str, Any]:
-     """Returns a dict as generation config for prompt and LMSamplingOptions."""
-     config = dict(
-         temperature=options.temperature,
-         maxOutputTokens=options.max_tokens,
-         candidateCount=options.n,
-         topK=options.top_k,
-         topP=options.top_p,
-         stopSequences=options.stop,
-         seed=options.random_seed,
-         responseLogprobs=options.logprobs,
-         logprobs=options.top_logprobs,
-     )

-     if json_schema := prompt.metadata.get('json_schema'):
-       if not isinstance(json_schema, dict):
-         raise ValueError(
-             f'`json_schema` must be a dict, got {json_schema!r}.'
-         )
-       json_schema = pg.to_json(json_schema)
-       config['responseSchema'] = json_schema
-       config['responseMimeType'] = 'application/json'
-       prompt.metadata.formatted_text = (
-           prompt.text
-           + '\n\n [RESPONSE FORMAT (not part of prompt)]\n'
-           + pg.to_json_str(json_schema, json_indent=2)
-       )
-     return config
-
-   def _content_from_message(self, prompt: lf.Message) -> dict[str, Any]:
-     """Gets generation content from langfun message."""
-     parts = []
-     for lf_chunk in prompt.chunk():
-       if isinstance(lf_chunk, str):
-         parts.append({'text': lf_chunk})
-       elif isinstance(lf_chunk, lf_modalities.Mime):
-         try:
-           modalities = lf_chunk.make_compatible(
-               self.supported_modalities + ['text/plain']
-           )
-           if isinstance(modalities, lf_modalities.Mime):
-             modalities = [modalities]
-           for modality in modalities:
-             if modality.is_text:
-               parts.append({'text': modality.to_text()})
-             else:
-               parts.append({
-                   'inlineData': {
-                       'data': base64.b64encode(modality.to_bytes()).decode(),
-                       'mimeType': modality.mime_type,
-                   }
-               })
-         except lf.ModalityError as e:
-           raise lf.ModalityError(f'Unsupported modality: {lf_chunk!r}') from e
-       else:
-         raise lf.ModalityError(f'Unsupported modality: {lf_chunk!r}')
-     return dict(role='user', parts=parts)
-
-   def result(self, json: dict[str, Any]) -> lf.LMSamplingResult:
-     messages = [
-         self._message_from_content_parts(candidate['content']['parts'])
-         for candidate in json['candidates']
-     ]
-     usage = json['usageMetadata']
-     input_tokens = usage['promptTokenCount']
-     output_tokens = usage['candidatesTokenCount']
-     return lf.LMSamplingResult(
-         [lf.LMSample(message) for message in messages],
-         usage=lf.LMSamplingUsage(
-             prompt_tokens=input_tokens,
-             completion_tokens=output_tokens,
-             total_tokens=input_tokens + output_tokens,
-             estimated_cost=self.estimate_cost(
-                 num_input_tokens=input_tokens,
-                 num_output_tokens=output_tokens,
-             ),
-         ),
-     )
+ class VertexAIGeminiFlash2_0ThinkingExp_20241219(VertexAI):  # pylint: disable=invalid-name
+   """Vertex AI Gemini Flash 2.0 Thinking model launched on 12/19/2024."""

-   def _message_from_content_parts(
-       self, parts: list[dict[str, Any]]
-   ) -> lf.Message:
-     """Converts Vertex AI's content parts protocol to message."""
-     chunks = []
-     for part in parts:
-       if text_part := part.get('text'):
-         chunks.append(text_part)
-       else:
-         raise ValueError(f'Unsupported part: {part}')
-     return lf.AIMessage.from_chunks(chunks)
-
-
- IMAGE_TYPES = [
-     'image/png',
-     'image/jpeg',
-     'image/webp',
-     'image/heic',
-     'image/heif',
- ]
-
- AUDIO_TYPES = [
-     'audio/aac',
-     'audio/flac',
-     'audio/mp3',
-     'audio/m4a',
-     'audio/mpeg',
-     'audio/mpga',
-     'audio/mp4',
-     'audio/opus',
-     'audio/pcm',
-     'audio/wav',
-     'audio/webm',
- ]
-
- VIDEO_TYPES = [
-     'video/mov',
-     'video/mpeg',
-     'video/mpegps',
-     'video/mpg',
-     'video/mp4',
-     'video/webm',
-     'video/wmv',
-     'video/x-flv',
-     'video/3gpp',
-     'video/quicktime',
- ]
-
- DOCUMENT_TYPES = [
-     'application/pdf',
-     'text/plain',
-     'text/csv',
-     'text/html',
-     'text/xml',
-     'text/x-script.python',
-     'application/json',
- ]
-
-
- class VertexAIGemini2_0(VertexAI):  # pylint: disable=invalid-name
-   """Vertex AI Gemini 2.0 model."""
-
-   supported_modalities: pg.typing.List(str).freeze(  # pytype: disable=invalid-annotation
-       DOCUMENT_TYPES + IMAGE_TYPES + AUDIO_TYPES + VIDEO_TYPES
-   )
-
-
- class VertexAIGeminiFlash2_0Exp(VertexAIGemini2_0):  # pylint: disable=invalid-name
+   api_version = 'v1alpha'
+   model = 'gemini-2.0-flash-thinking-exp-1219'
+
+
+ class VertexAIGeminiFlash2_0Exp(VertexAI):  # pylint: disable=invalid-name
    """Vertex AI Gemini 2.0 Flash model."""

    model = 'gemini-2.0-flash-exp'


- class VertexAIGeminiFlash2_0ThinkingExp(VertexAIGemini2_0):  # pylint: disable=invalid-name
-   """Vertex AI Gemini 2.0 Flash model."""
+ class VertexAIGeminiExp_20241206(VertexAI):  # pylint: disable=invalid-name
+   """Vertex AI Gemini Experimental model launched on 12/06/2024."""

-   model = 'gemini-2.0-flash-thinking-exp-1219'
+   model = 'gemini-exp-1206'


- class VertexAIGemini1_5(VertexAI):  # pylint: disable=invalid-name
-   """Vertex AI Gemini 1.5 model."""
+ class VertexAIGeminiExp_20241114(VertexAI):  # pylint: disable=invalid-name
+   """Vertex AI Gemini Experimental model launched on 11/14/2024."""

-   supported_modalities: pg.typing.List(str).freeze(  # pytype: disable=invalid-annotation
-       DOCUMENT_TYPES + IMAGE_TYPES + AUDIO_TYPES + VIDEO_TYPES
-   )
+   model = 'gemini-exp-1114'


- class VertexAIGeminiPro1_5(VertexAIGemini1_5):  # pylint: disable=invalid-name
+ class VertexAIGeminiPro1_5(VertexAI):  # pylint: disable=invalid-name
    """Vertex AI Gemini 1.5 Pro model."""

-   model = 'gemini-1.5-pro'
+   model = 'gemini-1.5-pro-latest'


- class VertexAIGeminiPro1_5_002(VertexAIGemini1_5):  # pylint: disable=invalid-name
+ class VertexAIGeminiPro1_5_002(VertexAI):  # pylint: disable=invalid-name
    """Vertex AI Gemini 1.5 Pro model."""

    model = 'gemini-1.5-pro-002'


- class VertexAIGeminiPro1_5_001(VertexAIGemini1_5):  # pylint: disable=invalid-name
+ class VertexAIGeminiPro1_5_001(VertexAI):  # pylint: disable=invalid-name
    """Vertex AI Gemini 1.5 Pro model."""

    model = 'gemini-1.5-pro-001'


- class VertexAIGeminiPro1_5_0514(VertexAIGemini1_5):  # pylint: disable=invalid-name
-   """Vertex AI Gemini 1.5 Pro preview model."""
-
-   model = 'gemini-1.5-pro-preview-0514'
-
-
- class VertexAIGeminiPro1_5_0409(VertexAIGemini1_5):  # pylint: disable=invalid-name
-   """Vertex AI Gemini 1.5 Pro preview model."""
-
-   model = 'gemini-1.5-pro-preview-0409'
-
-
- class VertexAIGeminiFlash1_5(VertexAIGemini1_5):  # pylint: disable=invalid-name
+ class VertexAIGeminiFlash1_5(VertexAI):  # pylint: disable=invalid-name
    """Vertex AI Gemini 1.5 Flash model."""

    model = 'gemini-1.5-flash'


- class VertexAIGeminiFlash1_5_002(VertexAIGemini1_5):  # pylint: disable=invalid-name
+ class VertexAIGeminiFlash1_5_002(VertexAI):  # pylint: disable=invalid-name
    """Vertex AI Gemini 1.5 Flash model."""

    model = 'gemini-1.5-flash-002'


- class VertexAIGeminiFlash1_5_001(VertexAIGemini1_5):  # pylint: disable=invalid-name
+ class VertexAIGeminiFlash1_5_001(VertexAI):  # pylint: disable=invalid-name
    """Vertex AI Gemini 1.5 Flash model."""

    model = 'gemini-1.5-flash-001'


- class VertexAIGeminiFlash1_5_0514(VertexAIGemini1_5):  # pylint: disable=invalid-name
-   """Vertex AI Gemini 1.5 Flash preview model."""
-
-   model = 'gemini-1.5-flash-preview-0514'
-
-
  class VertexAIGeminiPro1(VertexAI):  # pylint: disable=invalid-name
    """Vertex AI Gemini 1.0 Pro model."""

    model = 'gemini-1.0-pro'
-
-
- class VertexAIGeminiPro1Vision(VertexAI):  # pylint: disable=invalid-name
-   """Vertex AI Gemini 1.0 Pro Vision model."""
-
-   model = 'gemini-1.0-pro-vision'
-   supported_modalities: pg.typing.List(str).freeze(  # pytype: disable=invalid-annotation
-       IMAGE_TYPES + VIDEO_TYPES
-   )
-
-
- class VertexAIEndpoint(VertexAI):  # pylint: disable=invalid-name
-   """Vertex AI Endpoint model."""
-
-   model = 'vertexai-endpoint'
-
-   endpoint: Annotated[str, 'Vertex AI Endpoint ID.']
-
-   @property
-   def api_endpoint(self) -> str:
-     return (
-         f'https://{self.location}-aiplatform.googleapis.com/v1/projects/'
-         f'{self.project}/locations/{self.location}/'
-         f'endpoints/{self.endpoint}:generateContent'
-     )
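
Taken together, the Vertex AI hunks flatten the class hierarchy: every concrete model now subclasses `VertexAI` directly (itself now a `gemini.Gemini`), and the per-model modality lists, rate limits, and prices no longer live in this file. A minimal usage sketch consistent with the retained fields, assuming the module path `langfun.core.llms.vertexai` and valid Google Cloud credentials; `project` and `location` values are placeholders:

    from langfun.core.llms import vertexai

    lm = vertexai.VertexAIGeminiFlash2_0Exp(
        project='my-gcp-project',   # placeholder project id
        location='us-central1',
    )
    print(lm.model_id)      # 'VertexAI(gemini-2.0-flash-exp)' per the new property
    print(lm.api_endpoint)  # region-scoped URL ending in
                            # 'models/gemini-2.0-flash-exp:generateContent'
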