langfun 0.1.2.dev202502110804__py3-none-any.whl → 0.1.2.dev202502120804__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langfun/core/__init__.py +6 -2
- langfun/core/language_model.py +365 -22
- langfun/core/language_model_test.py +123 -35
- langfun/core/llms/__init__.py +50 -57
- langfun/core/llms/anthropic.py +434 -163
- langfun/core/llms/anthropic_test.py +20 -1
- langfun/core/llms/deepseek.py +90 -51
- langfun/core/llms/deepseek_test.py +15 -16
- langfun/core/llms/fake.py +6 -0
- langfun/core/llms/gemini.py +480 -390
- langfun/core/llms/gemini_test.py +27 -7
- langfun/core/llms/google_genai.py +80 -50
- langfun/core/llms/google_genai_test.py +11 -4
- langfun/core/llms/groq.py +268 -167
- langfun/core/llms/groq_test.py +9 -3
- langfun/core/llms/openai.py +839 -328
- langfun/core/llms/openai_compatible.py +3 -18
- langfun/core/llms/openai_compatible_test.py +20 -5
- langfun/core/llms/openai_test.py +14 -4
- langfun/core/llms/rest.py +11 -6
- langfun/core/llms/vertexai.py +238 -240
- langfun/core/llms/vertexai_test.py +35 -8
- {langfun-0.1.2.dev202502110804.dist-info → langfun-0.1.2.dev202502120804.dist-info}/METADATA +1 -1
- {langfun-0.1.2.dev202502110804.dist-info → langfun-0.1.2.dev202502120804.dist-info}/RECORD +27 -27
- {langfun-0.1.2.dev202502110804.dist-info → langfun-0.1.2.dev202502120804.dist-info}/LICENSE +0 -0
- {langfun-0.1.2.dev202502110804.dist-info → langfun-0.1.2.dev202502120804.dist-info}/WHEEL +0 -0
- {langfun-0.1.2.dev202502110804.dist-info → langfun-0.1.2.dev202502120804.dist-info}/top_level.txt +0 -0
langfun/core/llms/vertexai.py
CHANGED
@@ -13,6 +13,8 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
"""Vertex AI generative models."""
|
15
15
|
|
16
|
+
import datetime
|
17
|
+
import functools
|
16
18
|
import os
|
17
19
|
from typing import Annotated, Any, Literal
|
18
20
|
|
@@ -49,9 +51,23 @@ class VertexAI(rest.REST):
|
|
49
51
|
Please check out VertexAIGemini in `gemini.py` as an example.
|
50
52
|
"""
|
51
53
|
|
54
|
+
model: pg.typing.Annotated[
|
55
|
+
pg.typing.Enum(
|
56
|
+
pg.MISSING_VALUE,
|
57
|
+
[
|
58
|
+
m.model_id for m in gemini.SUPPORTED_MODELS
|
59
|
+
if m.provider == 'VertexAI' or (
|
60
|
+
isinstance(m.provider, pg.hyper.OneOf)
|
61
|
+
and 'VertexAI' in m.provider.candidates
|
62
|
+
)
|
63
|
+
]
|
64
|
+
),
|
65
|
+
'The name of the model to use.',
|
66
|
+
]
|
67
|
+
|
52
68
|
model: Annotated[
|
53
69
|
str | None,
|
54
|
-
'Model
|
70
|
+
'Model name.'
|
55
71
|
] = None
|
56
72
|
|
57
73
|
project: Annotated[
|
@@ -113,11 +129,6 @@ class VertexAI(rest.REST):
|
|
113
129
|
)
|
114
130
|
self._credentials = credentials
|
115
131
|
|
116
|
-
@property
|
117
|
-
def model_id(self) -> str:
|
118
|
-
"""Returns a string to identify the model."""
|
119
|
-
return f'VertexAI({self.model})'
|
120
|
-
|
121
132
|
def _session(self):
|
122
133
|
assert self._credentials is not None
|
123
134
|
assert auth_requests is not None
|
@@ -134,6 +145,9 @@ class VertexAI(rest.REST):
|
|
134
145
|
class VertexAIGemini(VertexAI, gemini.Gemini):
|
135
146
|
"""Gemini models served by Vertex AI.."""
|
136
147
|
|
148
|
+
# Set default location to us-central1.
|
149
|
+
location = 'us-central1'
|
150
|
+
|
137
151
|
@property
|
138
152
|
def api_endpoint(self) -> str:
|
139
153
|
assert self._api_initialized
|
@@ -143,104 +157,93 @@ class VertexAIGemini(VertexAI, gemini.Gemini):
|
|
143
157
|
f'models/{self.model}:generateContent'
|
144
158
|
)
|
145
159
|
|
160
|
+
@functools.cached_property
|
161
|
+
def model_info(self) -> gemini.GeminiModelInfo:
|
162
|
+
return super().model_info.clone(override=dict(provider='VertexAI'))
|
146
163
|
|
147
|
-
|
148
|
-
|
149
|
-
|
150
|
-
api_version = 'v1beta'
|
151
|
-
model = 'gemini-2.0-flash-001'
|
152
|
-
location = 'us-central1'
|
153
|
-
|
154
|
-
|
155
|
-
class VertexAIGemini2ProExp_20250205(VertexAIGemini): # pylint: disable=invalid-name
|
156
|
-
"""Gemini Flash 2.0 Pro model launched on 02/05/2025."""
|
157
|
-
|
158
|
-
api_version = 'v1beta'
|
159
|
-
model = 'gemini-2.0-pro-exp-02-05'
|
160
|
-
location = 'us-central1'
|
161
|
-
|
162
|
-
|
163
|
-
class VertexAIGemini2FlashThinkingExp_20250121(VertexAIGemini): # pylint: disable=invalid-name
|
164
|
-
"""Gemini Flash 2.0 Thinking model launched on 01/21/2025."""
|
165
|
-
|
166
|
-
api_version = 'v1beta'
|
167
|
-
model = 'gemini-2.0-flash-thinking-exp-01-21'
|
168
|
-
timeout = None
|
169
|
-
location = 'us-central1'
|
170
|
-
|
171
|
-
|
172
|
-
class VertexAIGeminiFlash2_0ThinkingExp_20241219(VertexAIGemini): # pylint: disable=invalid-name
|
173
|
-
"""Vertex AI Gemini Flash 2.0 Thinking model launched on 12/19/2024."""
|
164
|
+
#
|
165
|
+
# Production models.
|
166
|
+
#
|
174
167
|
|
175
|
-
api_version = 'v1alpha'
|
176
|
-
model = 'gemini-2.0-flash-thinking-exp-1219'
|
177
|
-
timeout = None
|
178
168
|
|
169
|
+
class VertexAIGemini2Flash(VertexAIGemini): # pylint: disable=invalid-name
|
170
|
+
"""Gemini Flash 2.0 model (latest stable)."""
|
171
|
+
model = 'gemini-2.0-flash'
|
179
172
|
|
180
|
-
class VertexAIGeminiFlash2_0Exp(VertexAIGemini): # pylint: disable=invalid-name
|
181
|
-
"""Vertex AI Gemini 2.0 Flash model."""
|
182
173
|
|
183
|
-
|
174
|
+
class VertexAIGemini2Flash_001(VertexAIGemini): # pylint: disable=invalid-name
|
175
|
+
"""Gemini Flash 2.0 model version 001."""
|
176
|
+
model = 'gemini-2.0-flash-001'
|
184
177
|
|
185
178
|
|
186
|
-
class
|
187
|
-
"""
|
179
|
+
class VertexAIGemini2FlashLitePreview_20250205(VertexAIGemini): # pylint: disable=invalid-name
|
180
|
+
"""Gemini 2.0 Flash lite preview model launched on 02/05/2025."""
|
181
|
+
model = 'gemini-2.0-flash-lite-preview-02-05'
|
188
182
|
|
189
|
-
model = 'gemini-exp-1206'
|
190
183
|
|
184
|
+
class VertexAIGemini15Pro(VertexAIGemini): # pylint: disable=invalid-name
|
185
|
+
"""Vertex AI Gemini 1.5 Pro model (latest stable)."""
|
186
|
+
model = 'gemini-1.5-pro'
|
191
187
|
|
192
|
-
class VertexAIGeminiExp_20241114(VertexAIGemini): # pylint: disable=invalid-name
|
193
|
-
"""Vertex AI Gemini Experimental model launched on 11/14/2024."""
|
194
188
|
|
195
|
-
|
189
|
+
class VertexAIGemini15Pro_002(VertexAIGemini): # pylint: disable=invalid-name
|
190
|
+
"""Vertex AI Gemini 1.5 Pro model (version 002)."""
|
191
|
+
model = 'gemini-1.5-pro-002'
|
196
192
|
|
197
193
|
|
198
|
-
class
|
199
|
-
"""Vertex AI Gemini 1.5 Pro model."""
|
194
|
+
class VertexAIGemini15Pro_001(VertexAIGemini): # pylint: disable=invalid-name
|
195
|
+
"""Vertex AI Gemini 1.5 Pro model (version 001)."""
|
196
|
+
model = 'gemini-1.5-pro-001'
|
200
197
|
|
201
|
-
model = 'gemini-1.5-pro-002'
|
202
|
-
location = 'us-central1'
|
203
198
|
|
199
|
+
class VertexAIGemini15Flash(VertexAIGemini): # pylint: disable=invalid-name
|
200
|
+
"""Vertex AI Gemini 1.5 Flash model (latest stable)."""
|
201
|
+
model = 'gemini-1.5-flash'
|
204
202
|
|
205
|
-
class VertexAIGeminiPro1_5_002(VertexAIGemini): # pylint: disable=invalid-name
|
206
|
-
"""Vertex AI Gemini 1.5 Pro model."""
|
207
203
|
|
208
|
-
|
209
|
-
|
204
|
+
class VertexAIGemini15Flash_002(VertexAIGemini): # pylint: disable=invalid-name
|
205
|
+
"""Vertex AI Gemini 1.5 Flash model (version 002)."""
|
206
|
+
model = 'gemini-1.5-flash-002'
|
210
207
|
|
211
208
|
|
212
|
-
class
|
213
|
-
"""Vertex AI Gemini 1.5
|
209
|
+
class VertexAIGemini15Flash_001(VertexAIGemini): # pylint: disable=invalid-name
|
210
|
+
"""Vertex AI Gemini 1.5 Flash model (version 001)."""
|
211
|
+
model = 'gemini-1.5-flash-001'
|
214
212
|
|
215
|
-
model = 'gemini-1.5-pro-001'
|
216
|
-
location = 'us-central1'
|
217
213
|
|
214
|
+
class VertexAIGemini15Flash8B(VertexAIGemini): # pylint: disable=invalid-name
|
215
|
+
"""Vertex AI Gemini 1.5 Flash 8B model (latest stable)."""
|
216
|
+
model = 'gemini-1.5-flash-8b'
|
218
217
|
|
219
|
-
class VertexAIGeminiFlash1_5(VertexAIGemini): # pylint: disable=invalid-name
|
220
|
-
"""Vertex AI Gemini 1.5 Flash model."""
|
221
218
|
|
222
|
-
|
223
|
-
|
219
|
+
class VertexAIGemini15Flash8B_001(VertexAIGemini): # pylint: disable=invalid-name
|
220
|
+
"""Vertex AI Gemini 1.5 Flash 8B model (version 001)."""
|
221
|
+
model = 'gemini-1.5-flash-8b-001'
|
224
222
|
|
223
|
+
#
|
224
|
+
# Experimental models.
|
225
|
+
#
|
225
226
|
|
226
|
-
class VertexAIGeminiFlash1_5_002(VertexAIGemini): # pylint: disable=invalid-name
|
227
|
-
"""Vertex AI Gemini 1.5 Flash model."""
|
228
227
|
|
229
|
-
|
230
|
-
|
228
|
+
class VertexAIGemini2ProExp_20250205(VertexAIGemini): # pylint: disable=invalid-name
|
229
|
+
"""Gemini Flash 2.0 Pro model launched on 02/05/2025."""
|
230
|
+
model = 'gemini-2.0-pro-exp-02-05'
|
231
231
|
|
232
232
|
|
233
|
-
class
|
234
|
-
"""
|
233
|
+
class VertexAIGemini2FlashThinkingExp_20250121(VertexAIGemini): # pylint: disable=invalid-name
|
234
|
+
"""Gemini Flash 2.0 Thinking model launched on 01/21/2025."""
|
235
|
+
model = 'gemini-2.0-flash-thinking-exp-01-21'
|
236
|
+
timeout = None
|
235
237
|
|
236
|
-
model = 'gemini-1.5-flash-001'
|
237
|
-
location = 'us-central1'
|
238
238
|
|
239
|
+
class VertexAIGeminiFlash2_0Exp(VertexAIGemini): # pylint: disable=invalid-name
|
240
|
+
"""Vertex AI Gemini 2.0 Flash model."""
|
241
|
+
model = 'gemini-2.0-flash-exp'
|
239
242
|
|
240
|
-
class VertexAIGeminiPro1(VertexAIGemini): # pylint: disable=invalid-name
|
241
|
-
"""Vertex AI Gemini 1.0 Pro model."""
|
242
243
|
|
243
|
-
|
244
|
+
class VertexAIGeminiExp_20241206(VertexAIGemini): # pylint: disable=invalid-name
|
245
|
+
"""Vertex AI Gemini Experimental model launched on 12/06/2024."""
|
246
|
+
model = 'gemini-exp-1206'
|
244
247
|
|
245
248
|
|
246
249
|
#
|
@@ -260,6 +263,17 @@ class VertexAIAnthropic(VertexAI, anthropic.Anthropic):
|
|
260
263
|
|
261
264
|
api_version = 'vertex-2023-10-16'
|
262
265
|
|
266
|
+
@functools.cached_property
|
267
|
+
def model_info(self) -> lf.ModelInfo:
|
268
|
+
mi = anthropic._SUPPORTED_MODELS_BY_MODEL_ID[self.model] # pylint: disable=protected-access
|
269
|
+
if mi.provider != 'VertexAI':
|
270
|
+
for m in anthropic.SUPPORTED_MODELS:
|
271
|
+
if m.provider == 'VertexAI' and m.alias_for == m.model_id:
|
272
|
+
mi = m
|
273
|
+
self.rebind(model=mi.model_id, skip_notification=True)
|
274
|
+
break
|
275
|
+
return mi
|
276
|
+
|
263
277
|
@property
|
264
278
|
def headers(self):
|
265
279
|
return {
|
@@ -288,71 +302,102 @@ class VertexAIAnthropic(VertexAI, anthropic.Anthropic):
|
|
288
302
|
# pylint: disable=invalid-name
|
289
303
|
|
290
304
|
|
291
|
-
class
|
292
|
-
"""Anthropic's Claude 3 Opus model on VertexAI."""
|
293
|
-
model = 'claude-3-opus@20240229'
|
294
|
-
|
295
|
-
|
296
|
-
class VertexAIClaude3_5_Sonnet_20241022(VertexAIAnthropic):
|
305
|
+
class VertexAIClaude35Sonnet_20241022(VertexAIAnthropic):
|
297
306
|
"""Anthropic's Claude 3.5 Sonnet model on VertexAI."""
|
298
307
|
model = 'claude-3-5-sonnet-v2@20241022'
|
299
308
|
|
300
309
|
|
301
|
-
class
|
302
|
-
"""Anthropic's Claude 3.5 Sonnet model on VertexAI."""
|
303
|
-
model = 'claude-3-5-sonnet@20240620'
|
304
|
-
|
305
|
-
|
306
|
-
class VertexAIClaude3_5_Haiku_20241022(VertexAIAnthropic):
|
310
|
+
class VertexAIClaude35Haiku_20241022(VertexAIAnthropic):
|
307
311
|
"""Anthropic's Claude 3.5 Haiku model on VertexAI."""
|
308
312
|
model = 'claude-3-5-haiku@20241022'
|
309
313
|
|
314
|
+
|
315
|
+
class VertexAIClaude3Opus_20240229(VertexAIAnthropic):
|
316
|
+
"""Anthropic's Claude 3 Opus model on VertexAI."""
|
317
|
+
model = 'claude-3-opus@20240229'
|
318
|
+
|
310
319
|
# pylint: enable=invalid-name
|
311
320
|
|
312
321
|
#
|
313
322
|
# Llama models on Vertex AI.
|
314
|
-
#
|
315
|
-
# Pricing: https://cloud.google.com/vertex-ai/generative-ai/pricing
|
316
|
-
# pylint: enable=line-too-long
|
323
|
+
# https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/llama
|
324
|
+
# Pricing: https://cloud.google.com/vertex-ai/generative-ai/pricing#meta-models
|
317
325
|
|
318
|
-
LLAMA_MODELS =
|
319
|
-
|
320
|
-
|
326
|
+
LLAMA_MODELS = [
|
327
|
+
lf.ModelInfo(
|
328
|
+
model_id='llama-3.1-405b-instruct-maas',
|
321
329
|
in_service=True,
|
322
|
-
|
323
|
-
|
324
|
-
|
325
|
-
|
326
|
-
|
330
|
+
model_type='instruction-tuned',
|
331
|
+
provider='VertexAI',
|
332
|
+
description=(
|
333
|
+
'Llama 3.2 405B vision instruct model on VertexAI (Preview)'
|
334
|
+
),
|
335
|
+
url='https://huggingface.co/meta-llama/Llama-3.1-405B',
|
336
|
+
release_date=datetime.datetime(2024, 7, 23),
|
337
|
+
context_length=lf.ModelInfo.ContextLength(
|
338
|
+
max_input_tokens=128_000,
|
339
|
+
max_output_tokens=8_192,
|
340
|
+
),
|
341
|
+
pricing=lf.ModelInfo.Pricing(
|
342
|
+
cost_per_1m_input_tokens=5.0,
|
343
|
+
cost_per_1m_output_tokens=16.0,
|
344
|
+
),
|
345
|
+
rate_limits=None,
|
327
346
|
),
|
328
|
-
|
329
|
-
|
347
|
+
lf.ModelInfo(
|
348
|
+
model_id='llama-3.2-90b-vision-instruct-maas',
|
330
349
|
in_service=True,
|
331
|
-
|
332
|
-
|
333
|
-
|
334
|
-
|
335
|
-
|
350
|
+
model_type='instruction-tuned',
|
351
|
+
provider='VertexAI',
|
352
|
+
description=(
|
353
|
+
'Llama 3.2 90B vision instruct model on VertexAI (Preview)'
|
354
|
+
),
|
355
|
+
release_date=datetime.datetime(2024, 7, 23),
|
356
|
+
context_length=lf.ModelInfo.ContextLength(
|
357
|
+
max_input_tokens=128_000,
|
358
|
+
max_output_tokens=8_192,
|
359
|
+
),
|
360
|
+
# Free during preview.
|
361
|
+
pricing=None,
|
362
|
+
rate_limits=None,
|
336
363
|
),
|
337
|
-
|
338
|
-
|
364
|
+
lf.ModelInfo(
|
365
|
+
model_id='llama-3.1-70b-instruct-maas',
|
339
366
|
in_service=True,
|
340
|
-
|
341
|
-
|
367
|
+
model_type='instruction-tuned',
|
368
|
+
provider='VertexAI',
|
369
|
+
description=(
|
370
|
+
'Llama 3.2 70B vision instruct model on VertexAI (Preview)'
|
371
|
+
),
|
372
|
+
release_date=datetime.datetime(2024, 7, 23),
|
373
|
+
context_length=lf.ModelInfo.ContextLength(
|
374
|
+
max_input_tokens=128_000,
|
375
|
+
max_output_tokens=8_192,
|
376
|
+
),
|
342
377
|
# Free during preview.
|
343
|
-
|
344
|
-
|
378
|
+
pricing=None,
|
379
|
+
rate_limits=None,
|
345
380
|
),
|
346
|
-
|
347
|
-
|
381
|
+
lf.ModelInfo(
|
382
|
+
model_id='llama-3.1-8b-instruct-maas',
|
348
383
|
in_service=True,
|
349
|
-
|
350
|
-
|
384
|
+
model_type='instruction-tuned',
|
385
|
+
provider='VertexAI',
|
386
|
+
description=(
|
387
|
+
'Llama 3.2 8B vision instruct model on VertexAI (Preview)'
|
388
|
+
),
|
389
|
+
release_date=datetime.datetime(2024, 7, 23),
|
390
|
+
context_length=lf.ModelInfo.ContextLength(
|
391
|
+
max_input_tokens=0,
|
392
|
+
max_output_tokens=0,
|
393
|
+
),
|
351
394
|
# Free during preview.
|
352
|
-
|
353
|
-
|
354
|
-
)
|
355
|
-
|
395
|
+
pricing=None,
|
396
|
+
rate_limits=None,
|
397
|
+
),
|
398
|
+
]
|
399
|
+
|
400
|
+
_LLAMA_MODELS_BY_MODEL_ID = {m.model_id: m for m in LLAMA_MODELS}
|
356
401
|
|
357
402
|
|
358
403
|
@pg.use_init_args(['model'])
|
@@ -361,7 +406,7 @@ class VertexAILlama(VertexAI, openai_compatible.OpenAICompatible):
|
|
361
406
|
"""Llama models on VertexAI."""
|
362
407
|
|
363
408
|
model: pg.typing.Annotated[
|
364
|
-
pg.typing.Enum(pg.MISSING_VALUE,
|
409
|
+
pg.typing.Enum(pg.MISSING_VALUE, [m.model_id for m in LLAMA_MODELS]),
|
365
410
|
'Llama model ID.',
|
366
411
|
]
|
367
412
|
|
@@ -373,6 +418,10 @@ class VertexAILlama(VertexAI, openai_compatible.OpenAICompatible):
|
|
373
418
|
)
|
374
419
|
] = 'us-central1'
|
375
420
|
|
421
|
+
@functools.cached_property
|
422
|
+
def model_info(self) -> lf.ModelInfo:
|
423
|
+
return _LLAMA_MODELS_BY_MODEL_ID[self.model]
|
424
|
+
|
376
425
|
@property
|
377
426
|
def api_endpoint(self) -> str:
|
378
427
|
assert self._api_initialized
|
@@ -391,113 +440,77 @@ class VertexAILlama(VertexAI, openai_compatible.OpenAICompatible):
|
|
391
440
|
request['model'] = f'meta/{self.model}'
|
392
441
|
return request
|
393
442
|
|
394
|
-
@property
|
395
|
-
def max_concurrency(self) -> int:
|
396
|
-
rpm = LLAMA_MODELS[self.model].get('rpm', 0)
|
397
|
-
tpm = LLAMA_MODELS[self.model].get('tpm', 0)
|
398
|
-
return self.rate_to_max_concurrency(
|
399
|
-
requests_per_min=rpm, tokens_per_min=tpm
|
400
|
-
)
|
401
|
-
|
402
|
-
def estimate_cost(
|
403
|
-
self,
|
404
|
-
num_input_tokens: int,
|
405
|
-
num_output_tokens: int
|
406
|
-
) -> float | None:
|
407
|
-
"""Estimate the cost based on usage."""
|
408
|
-
cost_per_1m_input_tokens = LLAMA_MODELS[self.model].get(
|
409
|
-
'cost_per_1m_input_tokens', None
|
410
|
-
)
|
411
|
-
cost_per_1m_output_tokens = LLAMA_MODELS[self.model].get(
|
412
|
-
'cost_per_1m_output_tokens', None
|
413
|
-
)
|
414
|
-
if cost_per_1m_output_tokens is None or cost_per_1m_input_tokens is None:
|
415
|
-
return None
|
416
|
-
return (
|
417
|
-
cost_per_1m_input_tokens * num_input_tokens
|
418
|
-
+ cost_per_1m_output_tokens * num_output_tokens
|
419
|
-
) / 1000_000
|
420
|
-
|
421
443
|
|
422
444
|
# pylint: disable=invalid-name
|
423
|
-
class
|
445
|
+
class VertexAILlama32_90B(VertexAILlama):
|
424
446
|
"""Llama 3.2 90B vision instruct model on VertexAI."""
|
425
|
-
|
426
447
|
model = 'llama-3.2-90b-vision-instruct-maas'
|
427
448
|
|
428
449
|
|
429
|
-
class
|
450
|
+
class VertexAILlama31_405B(VertexAILlama):
|
430
451
|
"""Llama 3.1 405B vision instruct model on VertexAI."""
|
431
|
-
|
432
452
|
model = 'llama-3.1-405b-instruct-maas'
|
433
453
|
|
434
454
|
|
435
|
-
class
|
455
|
+
class VertexAILlama31_70B(VertexAILlama):
|
436
456
|
"""Llama 3.1 70B vision instruct model on VertexAI."""
|
437
|
-
|
438
457
|
model = 'llama-3.1-70b-instruct-maas'
|
439
458
|
|
440
459
|
|
441
|
-
class
|
460
|
+
class VertexAILlama31_8B(VertexAILlama):
|
442
461
|
"""Llama 3.1 8B vision instruct model on VertexAI."""
|
443
|
-
|
444
462
|
model = 'llama-3.1-8b-instruct-maas'
|
463
|
+
|
464
|
+
|
445
465
|
# pylint: enable=invalid-name
|
446
466
|
|
447
467
|
#
|
448
468
|
# Mistral models on Vertex AI.
|
449
469
|
# pylint: disable=line-too-long
|
450
|
-
#
|
470
|
+
# Models: https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/mistral
|
471
|
+
# Pricing: https://cloud.google.com/vertex-ai/generative-ai/pricing#mistral-models
|
451
472
|
# pylint: enable=line-too-long
|
452
473
|
|
453
|
-
|
454
|
-
|
455
|
-
|
456
|
-
latest_update='2024-11-21',
|
457
|
-
in_service=True,
|
458
|
-
rpm=0,
|
459
|
-
tpm=0,
|
460
|
-
# GA.
|
461
|
-
cost_per_1m_input_tokens=2,
|
462
|
-
cost_per_1m_output_tokens=6,
|
463
|
-
),
|
464
|
-
'mistral-large@2407': pg.Dict(
|
465
|
-
latest_update='2024-07-24',
|
466
|
-
in_service=True,
|
467
|
-
rpm=0,
|
468
|
-
tpm=0,
|
469
|
-
# GA.
|
470
|
-
cost_per_1m_input_tokens=2,
|
471
|
-
cost_per_1m_output_tokens=6,
|
472
|
-
),
|
473
|
-
'mistral-nemo@2407': pg.Dict(
|
474
|
-
latest_update='2024-07-24',
|
474
|
+
MISTRAL_MODELS = [
|
475
|
+
lf.ModelInfo(
|
476
|
+
model_id='mistral-large-2411',
|
475
477
|
in_service=True,
|
476
|
-
|
477
|
-
|
478
|
-
|
479
|
-
|
480
|
-
|
478
|
+
model_type='instruction-tuned',
|
479
|
+
provider='VertexAI',
|
480
|
+
description='Mistral Large model on VertexAI (GA) version 11/21/2024',
|
481
|
+
release_date=datetime.datetime(2024, 11, 21),
|
482
|
+
context_length=lf.ModelInfo.ContextLength(
|
483
|
+
max_input_tokens=128_000,
|
484
|
+
max_output_tokens=8_192,
|
485
|
+
),
|
486
|
+
pricing=lf.ModelInfo.Pricing(
|
487
|
+
cost_per_1m_input_tokens=2.0,
|
488
|
+
cost_per_1m_output_tokens=6.0,
|
489
|
+
),
|
490
|
+
rate_limits=None,
|
481
491
|
),
|
482
|
-
|
483
|
-
|
492
|
+
lf.ModelInfo(
|
493
|
+
model_id='codestral-2501',
|
484
494
|
in_service=True,
|
485
|
-
|
486
|
-
|
487
|
-
|
488
|
-
|
489
|
-
|
495
|
+
model_type='instruction-tuned',
|
496
|
+
provider='VertexAI',
|
497
|
+
description=(
|
498
|
+
'Mistral Codestral model on VertexAI (GA) (version 01/13/2025)'
|
499
|
+
),
|
500
|
+
release_date=datetime.datetime(2025, 1, 13),
|
501
|
+
context_length=lf.ModelInfo.ContextLength(
|
502
|
+
max_input_tokens=128_000,
|
503
|
+
max_output_tokens=8_192,
|
504
|
+
),
|
505
|
+
pricing=lf.ModelInfo.Pricing(
|
506
|
+
cost_per_1m_input_tokens=0.3,
|
507
|
+
cost_per_1m_output_tokens=0.9,
|
508
|
+
),
|
509
|
+
rate_limits=None,
|
490
510
|
),
|
491
|
-
|
492
|
-
|
493
|
-
|
494
|
-
rpm=0,
|
495
|
-
tpm=0,
|
496
|
-
# GA.
|
497
|
-
cost_per_1m_input_tokens=0.2,
|
498
|
-
cost_per_1m_output_tokens=0.6,
|
499
|
-
),
|
500
|
-
}
|
511
|
+
]
|
512
|
+
|
513
|
+
_MISTRAL_MODELS_BY_MODEL_ID = {m.model_id: m for m in MISTRAL_MODELS}
|
501
514
|
|
502
515
|
|
503
516
|
@pg.use_init_args(['model'])
|
@@ -506,7 +519,7 @@ class VertexAIMistral(VertexAI, openai_compatible.OpenAICompatible):
|
|
506
519
|
"""Mistral AI models on VertexAI."""
|
507
520
|
|
508
521
|
model: pg.typing.Annotated[
|
509
|
-
pg.typing.Enum(pg.MISSING_VALUE,
|
522
|
+
pg.typing.Enum(pg.MISSING_VALUE, [m.model_id for m in MISTRAL_MODELS]),
|
510
523
|
'Mistral model ID.',
|
511
524
|
]
|
512
525
|
|
@@ -518,6 +531,10 @@ class VertexAIMistral(VertexAI, openai_compatible.OpenAICompatible):
|
|
518
531
|
)
|
519
532
|
] = 'us-central1'
|
520
533
|
|
534
|
+
@functools.cached_property
|
535
|
+
def model_info(self) -> lf.ModelInfo:
|
536
|
+
return _MISTRAL_MODELS_BY_MODEL_ID[self.model]
|
537
|
+
|
521
538
|
@property
|
522
539
|
def api_endpoint(self) -> str:
|
523
540
|
assert self._api_initialized
|
@@ -527,61 +544,42 @@ class VertexAIMistral(VertexAI, openai_compatible.OpenAICompatible):
|
|
527
544
|
f'models/{self.model}:rawPredict'
|
528
545
|
)
|
529
546
|
|
530
|
-
@property
|
531
|
-
def max_concurrency(self) -> int:
|
532
|
-
rpm = MISTRAL_MODELS[self.model].get('rpm', 0)
|
533
|
-
tpm = MISTRAL_MODELS[self.model].get('tpm', 0)
|
534
|
-
return self.rate_to_max_concurrency(
|
535
|
-
requests_per_min=rpm, tokens_per_min=tpm
|
536
|
-
)
|
537
|
-
|
538
|
-
def estimate_cost(
|
539
|
-
self,
|
540
|
-
num_input_tokens: int,
|
541
|
-
num_output_tokens: int
|
542
|
-
) -> float | None:
|
543
|
-
"""Estimate the cost based on usage."""
|
544
|
-
cost_per_1m_input_tokens = MISTRAL_MODELS[self.model].get(
|
545
|
-
'cost_per_1m_input_tokens', None
|
546
|
-
)
|
547
|
-
cost_per_1m_output_tokens = MISTRAL_MODELS[self.model].get(
|
548
|
-
'cost_per_1m_output_tokens', None
|
549
|
-
)
|
550
|
-
if cost_per_1m_output_tokens is None or cost_per_1m_input_tokens is None:
|
551
|
-
return None
|
552
|
-
return (
|
553
|
-
cost_per_1m_input_tokens * num_input_tokens
|
554
|
-
+ cost_per_1m_output_tokens * num_output_tokens
|
555
|
-
) / 1000_000
|
556
|
-
|
557
547
|
|
558
548
|
# pylint: disable=invalid-name
|
559
549
|
class VertexAIMistralLarge_20241121(VertexAIMistral):
|
560
550
|
"""Mistral Large model on VertexAI released on 2024/11/21."""
|
561
|
-
|
562
551
|
model = 'mistral-large-2411'
|
563
552
|
|
564
553
|
|
565
|
-
class
|
566
|
-
"""Mistral
|
567
|
-
|
568
|
-
model = 'mistral-large@2407'
|
554
|
+
class VertexAICodestral_20250113(VertexAIMistral):
|
555
|
+
"""Mistral Nemo model on VertexAI released on 2024/07/24."""
|
556
|
+
model = 'codestral-2501'
|
569
557
|
|
558
|
+
# pylint: enable=invalid-name
|
570
559
|
|
571
|
-
class VertexAIMistralNemo_20240724(VertexAIMistral):
|
572
|
-
"""Mistral Nemo model on VertexAI released on 2024/07/24."""
|
573
560
|
|
574
|
-
|
561
|
+
#
|
562
|
+
# Register Vertex AI models so they can be retrieved with LanguageModel.get().
|
563
|
+
#
|
575
564
|
|
576
565
|
|
577
|
-
|
578
|
-
"""
|
566
|
+
def _register_vertexai_models():
|
567
|
+
"""Register Vertex AI models."""
|
568
|
+
for m in gemini.SUPPORTED_MODELS:
|
569
|
+
if m.provider == 'VertexAI' or (
|
570
|
+
isinstance(m.provider, pg.hyper.OneOf)
|
571
|
+
and 'VertexAI' in m.provider.candidates
|
572
|
+
):
|
573
|
+
lf.LanguageModel.register(m.model_id, VertexAIGemini)
|
579
574
|
|
580
|
-
|
575
|
+
for m in anthropic.SUPPORTED_MODELS:
|
576
|
+
if m.provider == 'VertexAI':
|
577
|
+
lf.LanguageModel.register(m.model_id, VertexAIAnthropic)
|
581
578
|
|
579
|
+
for m in LLAMA_MODELS:
|
580
|
+
lf.LanguageModel.register(m.model_id, VertexAILlama)
|
582
581
|
|
583
|
-
|
584
|
-
|
582
|
+
for m in MISTRAL_MODELS:
|
583
|
+
lf.LanguageModel.register(m.model_id, VertexAIMistral)
|
585
584
|
|
586
|
-
|
587
|
-
# pylint: enable=invalid-name
|
585
|
+
_register_vertexai_models()
|