langfun 0.1.2.dev202502110804__py3-none-any.whl → 0.1.2.dev202502120804__py3-none-any.whl

This diff shows the changes between two publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the differences between the versions as they appear in their respective public registries.
@@ -13,6 +13,8 @@
13
13
  # limitations under the License.
14
14
  """Vertex AI generative models."""
15
15
 
16
+ import datetime
17
+ import functools
16
18
  import os
17
19
  from typing import Annotated, Any, Literal
18
20
 
@@ -49,9 +51,23 @@ class VertexAI(rest.REST):
49
51
  Please check out VertexAIGemini in `gemini.py` as an example.
50
52
  """
51
53
 
54
+ model: pg.typing.Annotated[
55
+ pg.typing.Enum(
56
+ pg.MISSING_VALUE,
57
+ [
58
+ m.model_id for m in gemini.SUPPORTED_MODELS
59
+ if m.provider == 'VertexAI' or (
60
+ isinstance(m.provider, pg.hyper.OneOf)
61
+ and 'VertexAI' in m.provider.candidates
62
+ )
63
+ ]
64
+ ),
65
+ 'The name of the model to use.',
66
+ ]
67
+
52
68
  model: Annotated[
53
69
  str | None,
54
- 'Model ID.'
70
+ 'Model name.'
55
71
  ] = None
56
72
 
57
73
  project: Annotated[
@@ -113,11 +129,6 @@ class VertexAI(rest.REST):
113
129
  )
114
130
  self._credentials = credentials
115
131
 
116
- @property
117
- def model_id(self) -> str:
118
- """Returns a string to identify the model."""
119
- return f'VertexAI({self.model})'
120
-
121
132
  def _session(self):
122
133
  assert self._credentials is not None
123
134
  assert auth_requests is not None
@@ -134,6 +145,9 @@ class VertexAI(rest.REST):
134
145
  class VertexAIGemini(VertexAI, gemini.Gemini):
135
146
  """Gemini models served by Vertex AI.."""
136
147
 
148
+ # Set default location to us-central1.
149
+ location = 'us-central1'
150
+
137
151
  @property
138
152
  def api_endpoint(self) -> str:
139
153
  assert self._api_initialized
@@ -143,104 +157,93 @@ class VertexAIGemini(VertexAI, gemini.Gemini):
143
157
  f'models/{self.model}:generateContent'
144
158
  )
145
159
 
160
+ @functools.cached_property
161
+ def model_info(self) -> gemini.GeminiModelInfo:
162
+ return super().model_info.clone(override=dict(provider='VertexAI'))
146
163
 
147
- class VertexAIGemini2Flash(VertexAIGemini): # pylint: disable=invalid-name
148
- """Gemini Flash 2.0 model launched on 02/05/2025."""
149
-
150
- api_version = 'v1beta'
151
- model = 'gemini-2.0-flash-001'
152
- location = 'us-central1'
153
-
154
-
155
- class VertexAIGemini2ProExp_20250205(VertexAIGemini): # pylint: disable=invalid-name
156
- """Gemini Flash 2.0 Pro model launched on 02/05/2025."""
157
-
158
- api_version = 'v1beta'
159
- model = 'gemini-2.0-pro-exp-02-05'
160
- location = 'us-central1'
161
-
162
-
163
- class VertexAIGemini2FlashThinkingExp_20250121(VertexAIGemini): # pylint: disable=invalid-name
164
- """Gemini Flash 2.0 Thinking model launched on 01/21/2025."""
165
-
166
- api_version = 'v1beta'
167
- model = 'gemini-2.0-flash-thinking-exp-01-21'
168
- timeout = None
169
- location = 'us-central1'
170
-
171
-
172
- class VertexAIGeminiFlash2_0ThinkingExp_20241219(VertexAIGemini): # pylint: disable=invalid-name
173
- """Vertex AI Gemini Flash 2.0 Thinking model launched on 12/19/2024."""
164
+ #
165
+ # Production models.
166
+ #
174
167
 
175
- api_version = 'v1alpha'
176
- model = 'gemini-2.0-flash-thinking-exp-1219'
177
- timeout = None
178
168
 
169
+ class VertexAIGemini2Flash(VertexAIGemini): # pylint: disable=invalid-name
170
+ """Gemini Flash 2.0 model (latest stable)."""
171
+ model = 'gemini-2.0-flash'
179
172
 
180
- class VertexAIGeminiFlash2_0Exp(VertexAIGemini): # pylint: disable=invalid-name
181
- """Vertex AI Gemini 2.0 Flash model."""
182
173
 
183
- model = 'gemini-2.0-flash-exp'
174
+ class VertexAIGemini2Flash_001(VertexAIGemini): # pylint: disable=invalid-name
175
+ """Gemini Flash 2.0 model version 001."""
176
+ model = 'gemini-2.0-flash-001'
184
177
 
185
178
 
186
- class VertexAIGeminiExp_20241206(VertexAIGemini): # pylint: disable=invalid-name
187
- """Vertex AI Gemini Experimental model launched on 12/06/2024."""
179
+ class VertexAIGemini2FlashLitePreview_20250205(VertexAIGemini): # pylint: disable=invalid-name
180
+ """Gemini 2.0 Flash lite preview model launched on 02/05/2025."""
181
+ model = 'gemini-2.0-flash-lite-preview-02-05'
188
182
 
189
- model = 'gemini-exp-1206'
190
183
 
184
+ class VertexAIGemini15Pro(VertexAIGemini): # pylint: disable=invalid-name
185
+ """Vertex AI Gemini 1.5 Pro model (latest stable)."""
186
+ model = 'gemini-1.5-pro'
191
187
 
192
- class VertexAIGeminiExp_20241114(VertexAIGemini): # pylint: disable=invalid-name
193
- """Vertex AI Gemini Experimental model launched on 11/14/2024."""
194
188
 
195
- model = 'gemini-exp-1114'
189
+ class VertexAIGemini15Pro_002(VertexAIGemini): # pylint: disable=invalid-name
190
+ """Vertex AI Gemini 1.5 Pro model (version 002)."""
191
+ model = 'gemini-1.5-pro-002'
196
192
 
197
193
 
198
- class VertexAIGeminiPro1_5(VertexAIGemini): # pylint: disable=invalid-name
199
- """Vertex AI Gemini 1.5 Pro model."""
194
+ class VertexAIGemini15Pro_001(VertexAIGemini): # pylint: disable=invalid-name
195
+ """Vertex AI Gemini 1.5 Pro model (version 001)."""
196
+ model = 'gemini-1.5-pro-001'
200
197
 
201
- model = 'gemini-1.5-pro-002'
202
- location = 'us-central1'
203
198
 
199
+ class VertexAIGemini15Flash(VertexAIGemini): # pylint: disable=invalid-name
200
+ """Vertex AI Gemini 1.5 Flash model (latest stable)."""
201
+ model = 'gemini-1.5-flash'
204
202
 
205
- class VertexAIGeminiPro1_5_002(VertexAIGemini): # pylint: disable=invalid-name
206
- """Vertex AI Gemini 1.5 Pro model."""
207
203
 
208
- model = 'gemini-1.5-pro-002'
209
- location = 'us-central1'
204
+ class VertexAIGemini15Flash_002(VertexAIGemini): # pylint: disable=invalid-name
205
+ """Vertex AI Gemini 1.5 Flash model (version 002)."""
206
+ model = 'gemini-1.5-flash-002'
210
207
 
211
208
 
212
- class VertexAIGeminiPro1_5_001(VertexAIGemini): # pylint: disable=invalid-name
213
- """Vertex AI Gemini 1.5 Pro model."""
209
+ class VertexAIGemini15Flash_001(VertexAIGemini): # pylint: disable=invalid-name
210
+ """Vertex AI Gemini 1.5 Flash model (version 001)."""
211
+ model = 'gemini-1.5-flash-001'
214
212
 
215
- model = 'gemini-1.5-pro-001'
216
- location = 'us-central1'
217
213
 
214
+ class VertexAIGemini15Flash8B(VertexAIGemini): # pylint: disable=invalid-name
215
+ """Vertex AI Gemini 1.5 Flash 8B model (latest stable)."""
216
+ model = 'gemini-1.5-flash-8b'
218
217
 
219
- class VertexAIGeminiFlash1_5(VertexAIGemini): # pylint: disable=invalid-name
220
- """Vertex AI Gemini 1.5 Flash model."""
221
218
 
222
- model = 'gemini-1.5-flash'
223
- location = 'us-central1'
219
+ class VertexAIGemini15Flash8B_001(VertexAIGemini): # pylint: disable=invalid-name
220
+ """Vertex AI Gemini 1.5 Flash 8B model (version 001)."""
221
+ model = 'gemini-1.5-flash-8b-001'
224
222
 
223
+ #
224
+ # Experimental models.
225
+ #
225
226
 
226
- class VertexAIGeminiFlash1_5_002(VertexAIGemini): # pylint: disable=invalid-name
227
- """Vertex AI Gemini 1.5 Flash model."""
228
227
 
229
- model = 'gemini-1.5-flash-002'
230
- location = 'us-central1'
228
+ class VertexAIGemini2ProExp_20250205(VertexAIGemini): # pylint: disable=invalid-name
229
+ """Gemini Flash 2.0 Pro model launched on 02/05/2025."""
230
+ model = 'gemini-2.0-pro-exp-02-05'
231
231
 
232
232
 
233
- class VertexAIGeminiFlash1_5_001(VertexAIGemini): # pylint: disable=invalid-name
234
- """Vertex AI Gemini 1.5 Flash model."""
233
+ class VertexAIGemini2FlashThinkingExp_20250121(VertexAIGemini): # pylint: disable=invalid-name
234
+ """Gemini Flash 2.0 Thinking model launched on 01/21/2025."""
235
+ model = 'gemini-2.0-flash-thinking-exp-01-21'
236
+ timeout = None
235
237
 
236
- model = 'gemini-1.5-flash-001'
237
- location = 'us-central1'
238
238
 
239
+ class VertexAIGeminiFlash2_0Exp(VertexAIGemini): # pylint: disable=invalid-name
240
+ """Vertex AI Gemini 2.0 Flash model."""
241
+ model = 'gemini-2.0-flash-exp'
239
242
 
240
- class VertexAIGeminiPro1(VertexAIGemini): # pylint: disable=invalid-name
241
- """Vertex AI Gemini 1.0 Pro model."""
242
243
 
243
- model = 'gemini-1.0-pro'
244
+ class VertexAIGeminiExp_20241206(VertexAIGemini): # pylint: disable=invalid-name
245
+ """Vertex AI Gemini Experimental model launched on 12/06/2024."""
246
+ model = 'gemini-exp-1206'
244
247
 
245
248
 
246
249
  #
@@ -260,6 +263,17 @@ class VertexAIAnthropic(VertexAI, anthropic.Anthropic):
260
263
 
261
264
  api_version = 'vertex-2023-10-16'
262
265
 
266
+ @functools.cached_property
267
+ def model_info(self) -> lf.ModelInfo:
268
+ mi = anthropic._SUPPORTED_MODELS_BY_MODEL_ID[self.model] # pylint: disable=protected-access
269
+ if mi.provider != 'VertexAI':
270
+ for m in anthropic.SUPPORTED_MODELS:
271
+ if m.provider == 'VertexAI' and m.alias_for == m.model_id:
272
+ mi = m
273
+ self.rebind(model=mi.model_id, skip_notification=True)
274
+ break
275
+ return mi
276
+
263
277
  @property
264
278
  def headers(self):
265
279
  return {
@@ -288,71 +302,102 @@ class VertexAIAnthropic(VertexAI, anthropic.Anthropic):
288
302
  # pylint: disable=invalid-name
289
303
 
290
304
 
291
- class VertexAIClaude3_Opus_20240229(VertexAIAnthropic):
292
- """Anthropic's Claude 3 Opus model on VertexAI."""
293
- model = 'claude-3-opus@20240229'
294
-
295
-
296
- class VertexAIClaude3_5_Sonnet_20241022(VertexAIAnthropic):
305
+ class VertexAIClaude35Sonnet_20241022(VertexAIAnthropic):
297
306
  """Anthropic's Claude 3.5 Sonnet model on VertexAI."""
298
307
  model = 'claude-3-5-sonnet-v2@20241022'
299
308
 
300
309
 
301
- class VertexAIClaude3_5_Sonnet_20240620(VertexAIAnthropic):
302
- """Anthropic's Claude 3.5 Sonnet model on VertexAI."""
303
- model = 'claude-3-5-sonnet@20240620'
304
-
305
-
306
- class VertexAIClaude3_5_Haiku_20241022(VertexAIAnthropic):
310
+ class VertexAIClaude35Haiku_20241022(VertexAIAnthropic):
307
311
  """Anthropic's Claude 3.5 Haiku model on VertexAI."""
308
312
  model = 'claude-3-5-haiku@20241022'
309
313
 
314
+
315
+ class VertexAIClaude3Opus_20240229(VertexAIAnthropic):
316
+ """Anthropic's Claude 3 Opus model on VertexAI."""
317
+ model = 'claude-3-opus@20240229'
318
+
310
319
  # pylint: enable=invalid-name
311
320
 
312
321
  #
313
322
  # Llama models on Vertex AI.
314
- # pylint: disable=line-too-long
315
- # Pricing: https://cloud.google.com/vertex-ai/generative-ai/pricing?_gl=1*ukuk6u*_ga*MjEzMjc4NjM2My4xNzMzODg4OTg3*_ga_WH2QY8WWF5*MTczNzEzNDU1Mi4xMjQuMS4xNzM3MTM0NzczLjU5LjAuMA..#meta-models
316
- # pylint: enable=line-too-long
323
+ # https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/llama
324
+ # Pricing: https://cloud.google.com/vertex-ai/generative-ai/pricing#meta-models
317
325
 
318
- LLAMA_MODELS = {
319
- 'llama-3.2-90b-vision-instruct-maas': pg.Dict(
320
- latest_update='2024-09-25',
326
+ LLAMA_MODELS = [
327
+ lf.ModelInfo(
328
+ model_id='llama-3.1-405b-instruct-maas',
321
329
  in_service=True,
322
- rpm=0,
323
- tpm=0,
324
- # Free during preview.
325
- cost_per_1m_input_tokens=None,
326
- cost_per_1m_output_tokens=None,
330
+ model_type='instruction-tuned',
331
+ provider='VertexAI',
332
+ description=(
333
+ 'Llama 3.2 405B vision instruct model on VertexAI (Preview)'
334
+ ),
335
+ url='https://huggingface.co/meta-llama/Llama-3.1-405B',
336
+ release_date=datetime.datetime(2024, 7, 23),
337
+ context_length=lf.ModelInfo.ContextLength(
338
+ max_input_tokens=128_000,
339
+ max_output_tokens=8_192,
340
+ ),
341
+ pricing=lf.ModelInfo.Pricing(
342
+ cost_per_1m_input_tokens=5.0,
343
+ cost_per_1m_output_tokens=16.0,
344
+ ),
345
+ rate_limits=None,
327
346
  ),
328
- 'llama-3.1-405b-instruct-maas': pg.Dict(
329
- latest_update='2024-09-25',
347
+ lf.ModelInfo(
348
+ model_id='llama-3.2-90b-vision-instruct-maas',
330
349
  in_service=True,
331
- rpm=0,
332
- tpm=0,
333
- # GA.
334
- cost_per_1m_input_tokens=5,
335
- cost_per_1m_output_tokens=16,
350
+ model_type='instruction-tuned',
351
+ provider='VertexAI',
352
+ description=(
353
+ 'Llama 3.2 90B vision instruct model on VertexAI (Preview)'
354
+ ),
355
+ release_date=datetime.datetime(2024, 7, 23),
356
+ context_length=lf.ModelInfo.ContextLength(
357
+ max_input_tokens=128_000,
358
+ max_output_tokens=8_192,
359
+ ),
360
+ # Free during preview.
361
+ pricing=None,
362
+ rate_limits=None,
336
363
  ),
337
- 'llama-3.1-70b-instruct-maas': pg.Dict(
338
- latest_update='2024-09-25',
364
+ lf.ModelInfo(
365
+ model_id='llama-3.1-70b-instruct-maas',
339
366
  in_service=True,
340
- rpm=0,
341
- tpm=0,
367
+ model_type='instruction-tuned',
368
+ provider='VertexAI',
369
+ description=(
370
+ 'Llama 3.2 70B vision instruct model on VertexAI (Preview)'
371
+ ),
372
+ release_date=datetime.datetime(2024, 7, 23),
373
+ context_length=lf.ModelInfo.ContextLength(
374
+ max_input_tokens=128_000,
375
+ max_output_tokens=8_192,
376
+ ),
342
377
  # Free during preview.
343
- cost_per_1m_input_tokens=None,
344
- cost_per_1m_output_tokens=None,
378
+ pricing=None,
379
+ rate_limits=None,
345
380
  ),
346
- 'llama-3.1-8b-instruct-maas': pg.Dict(
347
- latest_update='2024-09-25',
381
+ lf.ModelInfo(
382
+ model_id='llama-3.1-8b-instruct-maas',
348
383
  in_service=True,
349
- rpm=0,
350
- tpm=0,
384
+ model_type='instruction-tuned',
385
+ provider='VertexAI',
386
+ description=(
387
+ 'Llama 3.2 8B vision instruct model on VertexAI (Preview)'
388
+ ),
389
+ release_date=datetime.datetime(2024, 7, 23),
390
+ context_length=lf.ModelInfo.ContextLength(
391
+ max_input_tokens=0,
392
+ max_output_tokens=0,
393
+ ),
351
394
  # Free during preview.
352
- cost_per_1m_input_tokens=None,
353
- cost_per_1m_output_tokens=None,
354
- )
355
- }
395
+ pricing=None,
396
+ rate_limits=None,
397
+ ),
398
+ ]
399
+
400
+ _LLAMA_MODELS_BY_MODEL_ID = {m.model_id: m for m in LLAMA_MODELS}
356
401
 
357
402
 
358
403
  @pg.use_init_args(['model'])
@@ -361,7 +406,7 @@ class VertexAILlama(VertexAI, openai_compatible.OpenAICompatible):
361
406
  """Llama models on VertexAI."""
362
407
 
363
408
  model: pg.typing.Annotated[
364
- pg.typing.Enum(pg.MISSING_VALUE, list(LLAMA_MODELS.keys())),
409
+ pg.typing.Enum(pg.MISSING_VALUE, [m.model_id for m in LLAMA_MODELS]),
365
410
  'Llama model ID.',
366
411
  ]
367
412
 
@@ -373,6 +418,10 @@ class VertexAILlama(VertexAI, openai_compatible.OpenAICompatible):
373
418
  )
374
419
  ] = 'us-central1'
375
420
 
421
+ @functools.cached_property
422
+ def model_info(self) -> lf.ModelInfo:
423
+ return _LLAMA_MODELS_BY_MODEL_ID[self.model]
424
+
376
425
  @property
377
426
  def api_endpoint(self) -> str:
378
427
  assert self._api_initialized
@@ -391,113 +440,77 @@ class VertexAILlama(VertexAI, openai_compatible.OpenAICompatible):
391
440
  request['model'] = f'meta/{self.model}'
392
441
  return request
393
442
 
394
- @property
395
- def max_concurrency(self) -> int:
396
- rpm = LLAMA_MODELS[self.model].get('rpm', 0)
397
- tpm = LLAMA_MODELS[self.model].get('tpm', 0)
398
- return self.rate_to_max_concurrency(
399
- requests_per_min=rpm, tokens_per_min=tpm
400
- )
401
-
402
- def estimate_cost(
403
- self,
404
- num_input_tokens: int,
405
- num_output_tokens: int
406
- ) -> float | None:
407
- """Estimate the cost based on usage."""
408
- cost_per_1m_input_tokens = LLAMA_MODELS[self.model].get(
409
- 'cost_per_1m_input_tokens', None
410
- )
411
- cost_per_1m_output_tokens = LLAMA_MODELS[self.model].get(
412
- 'cost_per_1m_output_tokens', None
413
- )
414
- if cost_per_1m_output_tokens is None or cost_per_1m_input_tokens is None:
415
- return None
416
- return (
417
- cost_per_1m_input_tokens * num_input_tokens
418
- + cost_per_1m_output_tokens * num_output_tokens
419
- ) / 1000_000
420
-
421
443
 
422
444
  # pylint: disable=invalid-name
423
- class VertexAILlama3_2_90B(VertexAILlama):
445
+ class VertexAILlama32_90B(VertexAILlama):
424
446
  """Llama 3.2 90B vision instruct model on VertexAI."""
425
-
426
447
  model = 'llama-3.2-90b-vision-instruct-maas'
427
448
 
428
449
 
429
- class VertexAILlama3_1_405B(VertexAILlama):
450
+ class VertexAILlama31_405B(VertexAILlama):
430
451
  """Llama 3.1 405B vision instruct model on VertexAI."""
431
-
432
452
  model = 'llama-3.1-405b-instruct-maas'
433
453
 
434
454
 
435
- class VertexAILlama3_1_70B(VertexAILlama):
455
+ class VertexAILlama31_70B(VertexAILlama):
436
456
  """Llama 3.1 70B vision instruct model on VertexAI."""
437
-
438
457
  model = 'llama-3.1-70b-instruct-maas'
439
458
 
440
459
 
441
- class VertexAILlama3_1_8B(VertexAILlama):
460
+ class VertexAILlama31_8B(VertexAILlama):
442
461
  """Llama 3.1 8B vision instruct model on VertexAI."""
443
-
444
462
  model = 'llama-3.1-8b-instruct-maas'
463
+
464
+
445
465
  # pylint: enable=invalid-name
446
466
 
447
467
  #
448
468
  # Mistral models on Vertex AI.
449
469
  # pylint: disable=line-too-long
450
- # Pricing: https://cloud.google.com/vertex-ai/generative-ai/pricing?_gl=1*ukuk6u*_ga*MjEzMjc4NjM2My4xNzMzODg4OTg3*_ga_WH2QY8WWF5*MTczNzEzNDU1Mi4xMjQuMS4xNzM3MTM0NzczLjU5LjAuMA..#mistral-models
470
+ # Models: https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/mistral
471
+ # Pricing: https://cloud.google.com/vertex-ai/generative-ai/pricing#mistral-models
451
472
  # pylint: enable=line-too-long
452
473
 
453
-
454
- MISTRAL_MODELS = {
455
- 'mistral-large-2411': pg.Dict(
456
- latest_update='2024-11-21',
457
- in_service=True,
458
- rpm=0,
459
- tpm=0,
460
- # GA.
461
- cost_per_1m_input_tokens=2,
462
- cost_per_1m_output_tokens=6,
463
- ),
464
- 'mistral-large@2407': pg.Dict(
465
- latest_update='2024-07-24',
466
- in_service=True,
467
- rpm=0,
468
- tpm=0,
469
- # GA.
470
- cost_per_1m_input_tokens=2,
471
- cost_per_1m_output_tokens=6,
472
- ),
473
- 'mistral-nemo@2407': pg.Dict(
474
- latest_update='2024-07-24',
474
+ MISTRAL_MODELS = [
475
+ lf.ModelInfo(
476
+ model_id='mistral-large-2411',
475
477
  in_service=True,
476
- rpm=0,
477
- tpm=0,
478
- # GA.
479
- cost_per_1m_input_tokens=0.15,
480
- cost_per_1m_output_tokens=0.15,
478
+ model_type='instruction-tuned',
479
+ provider='VertexAI',
480
+ description='Mistral Large model on VertexAI (GA) version 11/21/2024',
481
+ release_date=datetime.datetime(2024, 11, 21),
482
+ context_length=lf.ModelInfo.ContextLength(
483
+ max_input_tokens=128_000,
484
+ max_output_tokens=8_192,
485
+ ),
486
+ pricing=lf.ModelInfo.Pricing(
487
+ cost_per_1m_input_tokens=2.0,
488
+ cost_per_1m_output_tokens=6.0,
489
+ ),
490
+ rate_limits=None,
481
491
  ),
482
- 'codestral-2501': pg.Dict(
483
- latest_update='2025-01-13',
492
+ lf.ModelInfo(
493
+ model_id='codestral-2501',
484
494
  in_service=True,
485
- rpm=0,
486
- tpm=0,
487
- # GA.
488
- cost_per_1m_input_tokens=0.3,
489
- cost_per_1m_output_tokens=0.9,
495
+ model_type='instruction-tuned',
496
+ provider='VertexAI',
497
+ description=(
498
+ 'Mistral Codestral model on VertexAI (GA) (version 01/13/2025)'
499
+ ),
500
+ release_date=datetime.datetime(2025, 1, 13),
501
+ context_length=lf.ModelInfo.ContextLength(
502
+ max_input_tokens=128_000,
503
+ max_output_tokens=8_192,
504
+ ),
505
+ pricing=lf.ModelInfo.Pricing(
506
+ cost_per_1m_input_tokens=0.3,
507
+ cost_per_1m_output_tokens=0.9,
508
+ ),
509
+ rate_limits=None,
490
510
  ),
491
- 'codestral@2405': pg.Dict(
492
- latest_update='2024-05-29',
493
- in_service=True,
494
- rpm=0,
495
- tpm=0,
496
- # GA.
497
- cost_per_1m_input_tokens=0.2,
498
- cost_per_1m_output_tokens=0.6,
499
- ),
500
- }
511
+ ]
512
+
513
+ _MISTRAL_MODELS_BY_MODEL_ID = {m.model_id: m for m in MISTRAL_MODELS}
501
514
 
502
515
 
503
516
  @pg.use_init_args(['model'])
@@ -506,7 +519,7 @@ class VertexAIMistral(VertexAI, openai_compatible.OpenAICompatible):
506
519
  """Mistral AI models on VertexAI."""
507
520
 
508
521
  model: pg.typing.Annotated[
509
- pg.typing.Enum(pg.MISSING_VALUE, list(MISTRAL_MODELS.keys())),
522
+ pg.typing.Enum(pg.MISSING_VALUE, [m.model_id for m in MISTRAL_MODELS]),
510
523
  'Mistral model ID.',
511
524
  ]
512
525
 
@@ -518,6 +531,10 @@ class VertexAIMistral(VertexAI, openai_compatible.OpenAICompatible):
518
531
  )
519
532
  ] = 'us-central1'
520
533
 
534
+ @functools.cached_property
535
+ def model_info(self) -> lf.ModelInfo:
536
+ return _MISTRAL_MODELS_BY_MODEL_ID[self.model]
537
+
521
538
  @property
522
539
  def api_endpoint(self) -> str:
523
540
  assert self._api_initialized
@@ -527,61 +544,42 @@ class VertexAIMistral(VertexAI, openai_compatible.OpenAICompatible):
527
544
  f'models/{self.model}:rawPredict'
528
545
  )
529
546
 
530
- @property
531
- def max_concurrency(self) -> int:
532
- rpm = MISTRAL_MODELS[self.model].get('rpm', 0)
533
- tpm = MISTRAL_MODELS[self.model].get('tpm', 0)
534
- return self.rate_to_max_concurrency(
535
- requests_per_min=rpm, tokens_per_min=tpm
536
- )
537
-
538
- def estimate_cost(
539
- self,
540
- num_input_tokens: int,
541
- num_output_tokens: int
542
- ) -> float | None:
543
- """Estimate the cost based on usage."""
544
- cost_per_1m_input_tokens = MISTRAL_MODELS[self.model].get(
545
- 'cost_per_1m_input_tokens', None
546
- )
547
- cost_per_1m_output_tokens = MISTRAL_MODELS[self.model].get(
548
- 'cost_per_1m_output_tokens', None
549
- )
550
- if cost_per_1m_output_tokens is None or cost_per_1m_input_tokens is None:
551
- return None
552
- return (
553
- cost_per_1m_input_tokens * num_input_tokens
554
- + cost_per_1m_output_tokens * num_output_tokens
555
- ) / 1000_000
556
-
557
547
 
558
548
  # pylint: disable=invalid-name
559
549
  class VertexAIMistralLarge_20241121(VertexAIMistral):
560
550
  """Mistral Large model on VertexAI released on 2024/11/21."""
561
-
562
551
  model = 'mistral-large-2411'
563
552
 
564
553
 
565
- class VertexAIMistralLarge_20240724(VertexAIMistral):
566
- """Mistral Large model on VertexAI released on 2024/07/24."""
567
-
568
- model = 'mistral-large@2407'
554
+ class VertexAICodestral_20250113(VertexAIMistral):
555
+ """Mistral Nemo model on VertexAI released on 2024/07/24."""
556
+ model = 'codestral-2501'
569
557
 
558
+ # pylint: enable=invalid-name
570
559
 
571
- class VertexAIMistralNemo_20240724(VertexAIMistral):
572
- """Mistral Nemo model on VertexAI released on 2024/07/24."""
573
560
 
574
- model = 'mistral-nemo@2407'
561
+ #
562
+ # Register Vertex AI models so they can be retrieved with LanguageModel.get().
563
+ #
575
564
 
576
565
 
577
- class VertexAICodestral_20250113(VertexAIMistral):
578
- """Mistral Nemo model on VertexAI released on 2024/07/24."""
566
+ def _register_vertexai_models():
567
+ """Register Vertex AI models."""
568
+ for m in gemini.SUPPORTED_MODELS:
569
+ if m.provider == 'VertexAI' or (
570
+ isinstance(m.provider, pg.hyper.OneOf)
571
+ and 'VertexAI' in m.provider.candidates
572
+ ):
573
+ lf.LanguageModel.register(m.model_id, VertexAIGemini)
579
574
 
580
- model = 'codestral-2501'
575
+ for m in anthropic.SUPPORTED_MODELS:
576
+ if m.provider == 'VertexAI':
577
+ lf.LanguageModel.register(m.model_id, VertexAIAnthropic)
581
578
 
579
+ for m in LLAMA_MODELS:
580
+ lf.LanguageModel.register(m.model_id, VertexAILlama)
582
581
 
583
- class VertexAICodestral_20240529(VertexAIMistral):
584
- """Mistral Nemo model on VertexAI released on 2024/05/29."""
582
+ for m in MISTRAL_MODELS:
583
+ lf.LanguageModel.register(m.model_id, VertexAIMistral)
585
584
 
586
- model = 'codestral@2405'
587
- # pylint: enable=invalid-name
585
+ _register_vertexai_models()