langfun 0.1.2.dev202501080804__py3-none-any.whl → 0.1.2.dev202501240804__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langfun/core/__init__.py +1 -6
- langfun/core/coding/python/__init__.py +5 -11
- langfun/core/coding/python/correction.py +4 -7
- langfun/core/coding/python/correction_test.py +2 -3
- langfun/core/coding/python/execution.py +22 -211
- langfun/core/coding/python/execution_test.py +11 -90
- langfun/core/coding/python/generation.py +3 -2
- langfun/core/coding/python/generation_test.py +2 -2
- langfun/core/coding/python/parsing.py +108 -194
- langfun/core/coding/python/parsing_test.py +2 -105
- langfun/core/component.py +11 -273
- langfun/core/component_test.py +2 -29
- langfun/core/concurrent.py +187 -82
- langfun/core/concurrent_test.py +28 -19
- langfun/core/console.py +7 -3
- langfun/core/eval/base.py +2 -3
- langfun/core/eval/v2/evaluation.py +3 -1
- langfun/core/eval/v2/reporting.py +8 -4
- langfun/core/language_model.py +84 -8
- langfun/core/language_model_test.py +84 -29
- langfun/core/llms/__init__.py +46 -11
- langfun/core/llms/anthropic.py +1 -123
- langfun/core/llms/anthropic_test.py +0 -48
- langfun/core/llms/deepseek.py +117 -0
- langfun/core/llms/deepseek_test.py +61 -0
- langfun/core/llms/gemini.py +1 -1
- langfun/core/llms/groq.py +12 -99
- langfun/core/llms/groq_test.py +31 -137
- langfun/core/llms/llama_cpp.py +17 -54
- langfun/core/llms/llama_cpp_test.py +2 -34
- langfun/core/llms/openai.py +9 -147
- langfun/core/llms/openai_compatible.py +179 -0
- langfun/core/llms/openai_compatible_test.py +495 -0
- langfun/core/llms/openai_test.py +13 -423
- langfun/core/llms/rest_test.py +1 -1
- langfun/core/llms/vertexai.py +387 -18
- langfun/core/llms/vertexai_test.py +52 -0
- langfun/core/message_test.py +3 -3
- langfun/core/modalities/mime.py +8 -0
- langfun/core/modalities/mime_test.py +19 -4
- langfun/core/modality_test.py +0 -1
- langfun/core/structured/mapping.py +13 -13
- langfun/core/structured/mapping_test.py +2 -2
- langfun/core/structured/schema.py +16 -8
- langfun/core/structured/schema_generation.py +1 -1
- {langfun-0.1.2.dev202501080804.dist-info → langfun-0.1.2.dev202501240804.dist-info}/METADATA +13 -2
- {langfun-0.1.2.dev202501080804.dist-info → langfun-0.1.2.dev202501240804.dist-info}/RECORD +50 -52
- {langfun-0.1.2.dev202501080804.dist-info → langfun-0.1.2.dev202501240804.dist-info}/WHEEL +1 -1
- langfun/core/coding/python/errors.py +0 -108
- langfun/core/coding/python/errors_test.py +0 -99
- langfun/core/coding/python/permissions.py +0 -90
- langfun/core/coding/python/permissions_test.py +0 -86
- langfun/core/text_formatting.py +0 -168
- langfun/core/text_formatting_test.py +0 -65
- {langfun-0.1.2.dev202501080804.dist-info → langfun-0.1.2.dev202501240804.dist-info}/LICENSE +0 -0
- {langfun-0.1.2.dev202501080804.dist-info → langfun-0.1.2.dev202501240804.dist-info}/top_level.txt +0 -0
langfun/core/llms/vertexai.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright
|
1
|
+
# Copyright 2025 The Langfun Authors
|
2
2
|
#
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -15,10 +15,13 @@
|
|
15
15
|
|
16
16
|
import functools
|
17
17
|
import os
|
18
|
-
from typing import Annotated, Any
|
18
|
+
from typing import Annotated, Any, Literal
|
19
19
|
|
20
20
|
import langfun.core as lf
|
21
|
+
from langfun.core.llms import anthropic
|
21
22
|
from langfun.core.llms import gemini
|
23
|
+
from langfun.core.llms import openai_compatible
|
24
|
+
from langfun.core.llms import rest
|
22
25
|
import pyglove as pg
|
23
26
|
|
24
27
|
try:
|
@@ -36,10 +39,21 @@ except ImportError:
|
|
36
39
|
Credentials = Any
|
37
40
|
|
38
41
|
|
39
|
-
@
|
40
|
-
|
41
|
-
class VertexAI
|
42
|
-
|
42
|
+
@pg.use_init_args(['api_endpoint'])
|
43
|
+
class VertexAI(rest.REST):
|
44
|
+
"""Base class for VertexAI models.
|
45
|
+
|
46
|
+
This class handles the authentication of vertex AI models. Subclasses
|
47
|
+
should implement `request` and `result` methods, as well as the `api_endpoint`
|
48
|
+
property, or let users provide them as `__init__` arguments.
|
49
|
+
|
50
|
+
Please check out VertexAIGemini in `gemini.py` as an example.
|
51
|
+
"""
|
52
|
+
|
53
|
+
model: Annotated[
|
54
|
+
str | None,
|
55
|
+
'Model ID.'
|
56
|
+
] = None
|
43
57
|
|
44
58
|
project: Annotated[
|
45
59
|
str | None,
|
@@ -95,7 +109,7 @@ class VertexAI(gemini.Gemini):
|
|
95
109
|
credentials = self.credentials
|
96
110
|
if credentials is None:
|
97
111
|
# Use default credentials.
|
98
|
-
credentials = google_auth.default(
|
112
|
+
credentials, _ = google_auth.default(
|
99
113
|
scopes=['https://www.googleapis.com/auth/cloud-platform']
|
100
114
|
)
|
101
115
|
self._credentials = credentials
|
@@ -114,6 +128,17 @@ class VertexAI(gemini.Gemini):
|
|
114
128
|
s.headers.update(self.headers or {})
|
115
129
|
return s
|
116
130
|
|
131
|
+
|
132
|
+
#
|
133
|
+
# Gemini models served by Vertex AI.
|
134
|
+
#
|
135
|
+
|
136
|
+
|
137
|
+
@pg.use_init_args(['model'])
|
138
|
+
@pg.members([('api_endpoint', pg.typing.Str().freeze(''))])
|
139
|
+
class VertexAIGemini(VertexAI, gemini.Gemini):
|
140
|
+
"""Gemini models served by Vertex AI."""
|
141
|
+
|
117
142
|
@property
|
118
143
|
def api_endpoint(self) -> str:
|
119
144
|
assert self._api_initialized
|
@@ -124,7 +149,7 @@ class VertexAI(gemini.Gemini):
|
|
124
149
|
)
|
125
150
|
|
126
151
|
|
127
|
-
class VertexAIGeminiFlash2_0ThinkingExp_20241219(
|
152
|
+
class VertexAIGeminiFlash2_0ThinkingExp_20241219(VertexAIGemini): # pylint: disable=invalid-name
|
128
153
|
"""Vertex AI Gemini Flash 2.0 Thinking model launched on 12/19/2024."""
|
129
154
|
|
130
155
|
api_version = 'v1alpha'
|
@@ -132,61 +157,405 @@ class VertexAIGeminiFlash2_0ThinkingExp_20241219(VertexAI): # pylint: disable=i
|
|
132
157
|
timeout = None
|
133
158
|
|
134
159
|
|
135
|
-
class VertexAIGeminiFlash2_0Exp(
|
160
|
+
class VertexAIGeminiFlash2_0Exp(VertexAIGemini): # pylint: disable=invalid-name
|
136
161
|
"""Vertex AI Gemini 2.0 Flash model."""
|
137
162
|
|
138
163
|
model = 'gemini-2.0-flash-exp'
|
139
164
|
|
140
165
|
|
141
|
-
class VertexAIGeminiExp_20241206(
|
166
|
+
class VertexAIGeminiExp_20241206(VertexAIGemini): # pylint: disable=invalid-name
|
142
167
|
"""Vertex AI Gemini Experimental model launched on 12/06/2024."""
|
143
168
|
|
144
169
|
model = 'gemini-exp-1206'
|
145
170
|
|
146
171
|
|
147
|
-
class VertexAIGeminiExp_20241114(
|
172
|
+
class VertexAIGeminiExp_20241114(VertexAIGemini): # pylint: disable=invalid-name
|
148
173
|
"""Vertex AI Gemini Experimental model launched on 11/14/2024."""
|
149
174
|
|
150
175
|
model = 'gemini-exp-1114'
|
151
176
|
|
152
177
|
|
153
|
-
class VertexAIGeminiPro1_5(
|
178
|
+
class VertexAIGeminiPro1_5(VertexAIGemini): # pylint: disable=invalid-name
|
154
179
|
"""Vertex AI Gemini 1.5 Pro model."""
|
155
180
|
|
156
181
|
model = 'gemini-1.5-pro-latest'
|
157
182
|
|
158
183
|
|
159
|
-
class VertexAIGeminiPro1_5_002(
|
184
|
+
class VertexAIGeminiPro1_5_002(VertexAIGemini): # pylint: disable=invalid-name
|
160
185
|
"""Vertex AI Gemini 1.5 Pro model."""
|
161
186
|
|
162
187
|
model = 'gemini-1.5-pro-002'
|
163
188
|
|
164
189
|
|
165
|
-
class VertexAIGeminiPro1_5_001(
|
190
|
+
class VertexAIGeminiPro1_5_001(VertexAIGemini): # pylint: disable=invalid-name
|
166
191
|
"""Vertex AI Gemini 1.5 Pro model."""
|
167
192
|
|
168
193
|
model = 'gemini-1.5-pro-001'
|
169
194
|
|
170
195
|
|
171
|
-
class VertexAIGeminiFlash1_5(
|
196
|
+
class VertexAIGeminiFlash1_5(VertexAIGemini): # pylint: disable=invalid-name
|
172
197
|
"""Vertex AI Gemini 1.5 Flash model."""
|
173
198
|
|
174
199
|
model = 'gemini-1.5-flash'
|
175
200
|
|
176
201
|
|
177
|
-
class VertexAIGeminiFlash1_5_002(
|
202
|
+
class VertexAIGeminiFlash1_5_002(VertexAIGemini): # pylint: disable=invalid-name
|
178
203
|
"""Vertex AI Gemini 1.5 Flash model."""
|
179
204
|
|
180
205
|
model = 'gemini-1.5-flash-002'
|
181
206
|
|
182
207
|
|
183
|
-
class VertexAIGeminiFlash1_5_001(
|
208
|
+
class VertexAIGeminiFlash1_5_001(VertexAIGemini): # pylint: disable=invalid-name
|
184
209
|
"""Vertex AI Gemini 1.5 Flash model."""
|
185
210
|
|
186
211
|
model = 'gemini-1.5-flash-001'
|
187
212
|
|
188
213
|
|
189
|
-
class VertexAIGeminiPro1(
|
214
|
+
class VertexAIGeminiPro1(VertexAIGemini): # pylint: disable=invalid-name
|
190
215
|
"""Vertex AI Gemini 1.0 Pro model."""
|
191
216
|
|
192
217
|
model = 'gemini-1.0-pro'
|
218
|
+
|
219
|
+
|
220
|
+
#
|
221
|
+
# Anthropic models on Vertex AI.
|
222
|
+
#
|
223
|
+
|
224
|
+
|
225
|
+
@pg.use_init_args(['model'])
|
226
|
+
@pg.members([('api_endpoint', pg.typing.Str().freeze(''))])
|
227
|
+
class VertexAIAnthropic(VertexAI, anthropic.Anthropic):
|
228
|
+
"""Anthropic models on VertexAI."""
|
229
|
+
|
230
|
+
location: Annotated[
|
231
|
+
Literal['us-east5', 'europe-west1'],
|
232
|
+
'GCP location with Anthropic models hosted.'
|
233
|
+
] = 'us-east5'
|
234
|
+
|
235
|
+
api_version = 'vertex-2023-10-16'
|
236
|
+
|
237
|
+
@property
|
238
|
+
def headers(self):
|
239
|
+
return {
|
240
|
+
'Content-Type': 'application/json; charset=utf-8',
|
241
|
+
}
|
242
|
+
|
243
|
+
@property
|
244
|
+
def api_endpoint(self) -> str:
|
245
|
+
return (
|
246
|
+
f'https://{self.location}-aiplatform.googleapis.com/v1/projects/'
|
247
|
+
f'{self._project}/locations/{self.location}/publishers/anthropic/'
|
248
|
+
f'models/{self.model}:streamRawPredict'
|
249
|
+
)
|
250
|
+
|
251
|
+
def request(
|
252
|
+
self,
|
253
|
+
prompt: lf.Message,
|
254
|
+
sampling_options: lf.LMSamplingOptions
|
255
|
+
):
|
256
|
+
request = super().request(prompt, sampling_options)
|
257
|
+
request['anthropic_version'] = self.api_version
|
258
|
+
del request['model']
|
259
|
+
return request
|
260
|
+
|
261
|
+
|
262
|
+
# pylint: disable=invalid-name
|
263
|
+
|
264
|
+
|
265
|
+
class VertexAIClaude3_Opus_20240229(VertexAIAnthropic):
|
266
|
+
"""Anthropic's Claude 3 Opus model on VertexAI."""
|
267
|
+
model = 'claude-3-opus@20240229'
|
268
|
+
|
269
|
+
|
270
|
+
class VertexAIClaude3_5_Sonnet_20241022(VertexAIAnthropic):
|
271
|
+
"""Anthropic's Claude 3.5 Sonnet model on VertexAI."""
|
272
|
+
model = 'claude-3-5-sonnet-v2@20241022'
|
273
|
+
|
274
|
+
|
275
|
+
class VertexAIClaude3_5_Sonnet_20240620(VertexAIAnthropic):
|
276
|
+
"""Anthropic's Claude 3.5 Sonnet model on VertexAI."""
|
277
|
+
model = 'claude-3-5-sonnet@20240620'
|
278
|
+
|
279
|
+
|
280
|
+
class VertexAIClaude3_5_Haiku_20241022(VertexAIAnthropic):
|
281
|
+
"""Anthropic's Claude 3.5 Haiku model on VertexAI."""
|
282
|
+
model = 'claude-3-5-haiku@20241022'
|
283
|
+
|
284
|
+
# pylint: enable=invalid-name
|
285
|
+
|
286
|
+
#
|
287
|
+
# Llama models on Vertex AI.
|
288
|
+
# pylint: disable=line-too-long
|
289
|
+
# Pricing: https://cloud.google.com/vertex-ai/generative-ai/pricing#meta-models
|
290
|
+
# pylint: enable=line-too-long
|
291
|
+
|
292
|
+
LLAMA_MODELS = {
|
293
|
+
'llama-3.2-90b-vision-instruct-maas': pg.Dict(
|
294
|
+
latest_update='2024-09-25',
|
295
|
+
in_service=True,
|
296
|
+
rpm=0,
|
297
|
+
tpm=0,
|
298
|
+
# Free during preview.
|
299
|
+
cost_per_1m_input_tokens=None,
|
300
|
+
cost_per_1m_output_tokens=None,
|
301
|
+
),
|
302
|
+
'llama-3.1-405b-instruct-maas': pg.Dict(
|
303
|
+
latest_update='2024-09-25',
|
304
|
+
in_service=True,
|
305
|
+
rpm=0,
|
306
|
+
tpm=0,
|
307
|
+
# GA.
|
308
|
+
cost_per_1m_input_tokens=5,
|
309
|
+
cost_per_1m_output_tokens=16,
|
310
|
+
),
|
311
|
+
'llama-3.1-70b-instruct-maas': pg.Dict(
|
312
|
+
latest_update='2024-09-25',
|
313
|
+
in_service=True,
|
314
|
+
rpm=0,
|
315
|
+
tpm=0,
|
316
|
+
# Free during preview.
|
317
|
+
cost_per_1m_input_tokens=None,
|
318
|
+
cost_per_1m_output_tokens=None,
|
319
|
+
),
|
320
|
+
'llama-3.1-8b-instruct-maas': pg.Dict(
|
321
|
+
latest_update='2024-09-25',
|
322
|
+
in_service=True,
|
323
|
+
rpm=0,
|
324
|
+
tpm=0,
|
325
|
+
# Free during preview.
|
326
|
+
cost_per_1m_input_tokens=None,
|
327
|
+
cost_per_1m_output_tokens=None,
|
328
|
+
)
|
329
|
+
}
|
330
|
+
|
331
|
+
|
332
|
+
@pg.use_init_args(['model'])
|
333
|
+
@pg.members([('api_endpoint', pg.typing.Str().freeze(''))])
|
334
|
+
class VertexAILlama(VertexAI, openai_compatible.OpenAICompatible):
|
335
|
+
"""Llama models on VertexAI."""
|
336
|
+
|
337
|
+
model: pg.typing.Annotated[
|
338
|
+
pg.typing.Enum(pg.MISSING_VALUE, list(LLAMA_MODELS.keys())),
|
339
|
+
'Llama model ID.',
|
340
|
+
]
|
341
|
+
|
342
|
+
locations: Annotated[
|
343
|
+
Literal['us-central1'],
|
344
|
+
(
|
345
|
+
'GCP locations with Llama models hosted. '
|
346
|
+
'See https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/llama#regions-quotas'
|
347
|
+
)
|
348
|
+
] = 'us-central1'
|
349
|
+
|
350
|
+
@property
|
351
|
+
def api_endpoint(self) -> str:
|
352
|
+
assert self._api_initialized
|
353
|
+
return (
|
354
|
+
f'https://{self._location}-aiplatform.googleapis.com/v1beta1/projects/'
|
355
|
+
f'{self._project}/locations/{self._location}/endpoints/'
|
356
|
+
f'openapi/chat/completions'
|
357
|
+
)
|
358
|
+
|
359
|
+
def request(
|
360
|
+
self,
|
361
|
+
prompt: lf.Message,
|
362
|
+
sampling_options: lf.LMSamplingOptions
|
363
|
+
):
|
364
|
+
request = super().request(prompt, sampling_options)
|
365
|
+
request['model'] = f'meta/{self.model}'
|
366
|
+
return request
|
367
|
+
|
368
|
+
@property
|
369
|
+
def max_concurrency(self) -> int:
|
370
|
+
rpm = LLAMA_MODELS[self.model].get('rpm', 0)
|
371
|
+
tpm = LLAMA_MODELS[self.model].get('tpm', 0)
|
372
|
+
return self.rate_to_max_concurrency(
|
373
|
+
requests_per_min=rpm, tokens_per_min=tpm
|
374
|
+
)
|
375
|
+
|
376
|
+
def estimate_cost(
|
377
|
+
self,
|
378
|
+
num_input_tokens: int,
|
379
|
+
num_output_tokens: int
|
380
|
+
) -> float | None:
|
381
|
+
"""Estimate the cost based on usage."""
|
382
|
+
cost_per_1m_input_tokens = LLAMA_MODELS[self.model].get(
|
383
|
+
'cost_per_1m_input_tokens', None
|
384
|
+
)
|
385
|
+
cost_per_1m_output_tokens = LLAMA_MODELS[self.model].get(
|
386
|
+
'cost_per_1m_output_tokens', None
|
387
|
+
)
|
388
|
+
if cost_per_1m_output_tokens is None or cost_per_1m_input_tokens is None:
|
389
|
+
return None
|
390
|
+
return (
|
391
|
+
cost_per_1m_input_tokens * num_input_tokens
|
392
|
+
+ cost_per_1m_output_tokens * num_output_tokens
|
393
|
+
) / 1000_000
|
394
|
+
|
395
|
+
|
396
|
+
# pylint: disable=invalid-name
|
397
|
+
class VertexAILlama3_2_90B(VertexAILlama):
|
398
|
+
"""Llama 3.2 90B vision instruct model on VertexAI."""
|
399
|
+
|
400
|
+
model = 'llama-3.2-90b-vision-instruct-maas'
|
401
|
+
|
402
|
+
|
403
|
+
class VertexAILlama3_1_405B(VertexAILlama):
|
404
|
+
"""Llama 3.1 405B instruct model on VertexAI."""
|
405
|
+
|
406
|
+
model = 'llama-3.1-405b-instruct-maas'
|
407
|
+
|
408
|
+
|
409
|
+
class VertexAILlama3_1_70B(VertexAILlama):
|
410
|
+
"""Llama 3.1 70B instruct model on VertexAI."""
|
411
|
+
|
412
|
+
model = 'llama-3.1-70b-instruct-maas'
|
413
|
+
|
414
|
+
|
415
|
+
class VertexAILlama3_1_8B(VertexAILlama):
|
416
|
+
"""Llama 3.1 8B instruct model on VertexAI."""
|
417
|
+
|
418
|
+
model = 'llama-3.1-8b-instruct-maas'
|
419
|
+
# pylint: enable=invalid-name
|
420
|
+
|
421
|
+
#
|
422
|
+
# Mistral models on Vertex AI.
|
423
|
+
# pylint: disable=line-too-long
|
424
|
+
# Pricing: https://cloud.google.com/vertex-ai/generative-ai/pricing#mistral-models
|
425
|
+
# pylint: enable=line-too-long
|
426
|
+
|
427
|
+
|
428
|
+
MISTRAL_MODELS = {
|
429
|
+
'mistral-large-2411': pg.Dict(
|
430
|
+
latest_update='2024-11-21',
|
431
|
+
in_service=True,
|
432
|
+
rpm=0,
|
433
|
+
tpm=0,
|
434
|
+
# GA.
|
435
|
+
cost_per_1m_input_tokens=2,
|
436
|
+
cost_per_1m_output_tokens=6,
|
437
|
+
),
|
438
|
+
'mistral-large@2407': pg.Dict(
|
439
|
+
latest_update='2024-07-24',
|
440
|
+
in_service=True,
|
441
|
+
rpm=0,
|
442
|
+
tpm=0,
|
443
|
+
# GA.
|
444
|
+
cost_per_1m_input_tokens=2,
|
445
|
+
cost_per_1m_output_tokens=6,
|
446
|
+
),
|
447
|
+
'mistral-nemo@2407': pg.Dict(
|
448
|
+
latest_update='2024-07-24',
|
449
|
+
in_service=True,
|
450
|
+
rpm=0,
|
451
|
+
tpm=0,
|
452
|
+
# GA.
|
453
|
+
cost_per_1m_input_tokens=0.15,
|
454
|
+
cost_per_1m_output_tokens=0.15,
|
455
|
+
),
|
456
|
+
'codestral-2501': pg.Dict(
|
457
|
+
latest_update='2025-01-13',
|
458
|
+
in_service=True,
|
459
|
+
rpm=0,
|
460
|
+
tpm=0,
|
461
|
+
# GA.
|
462
|
+
cost_per_1m_input_tokens=0.3,
|
463
|
+
cost_per_1m_output_tokens=0.9,
|
464
|
+
),
|
465
|
+
'codestral@2405': pg.Dict(
|
466
|
+
latest_update='2024-05-29',
|
467
|
+
in_service=True,
|
468
|
+
rpm=0,
|
469
|
+
tpm=0,
|
470
|
+
# GA.
|
471
|
+
cost_per_1m_input_tokens=0.2,
|
472
|
+
cost_per_1m_output_tokens=0.6,
|
473
|
+
),
|
474
|
+
}
|
475
|
+
|
476
|
+
|
477
|
+
@pg.use_init_args(['model'])
|
478
|
+
@pg.members([('api_endpoint', pg.typing.Str().freeze(''))])
|
479
|
+
class VertexAIMistral(VertexAI, openai_compatible.OpenAICompatible):
|
480
|
+
"""Mistral AI models on VertexAI."""
|
481
|
+
|
482
|
+
model: pg.typing.Annotated[
|
483
|
+
pg.typing.Enum(pg.MISSING_VALUE, list(MISTRAL_MODELS.keys())),
|
484
|
+
'Mistral model ID.',
|
485
|
+
]
|
486
|
+
|
487
|
+
locations: Annotated[
|
488
|
+
Literal['us-central1', 'europe-west4'],
|
489
|
+
(
|
490
|
+
'GCP locations with Mistral models hosted. '
|
491
|
+
'See https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/mistral#regions-quotas'
|
492
|
+
)
|
493
|
+
] = 'us-central1'
|
494
|
+
|
495
|
+
@property
|
496
|
+
def api_endpoint(self) -> str:
|
497
|
+
assert self._api_initialized
|
498
|
+
return (
|
499
|
+
f'https://{self._location}-aiplatform.googleapis.com/v1/projects/'
|
500
|
+
f'{self._project}/locations/{self._location}/publishers/mistralai/'
|
501
|
+
f'models/{self.model}:rawPredict'
|
502
|
+
)
|
503
|
+
|
504
|
+
@property
|
505
|
+
def max_concurrency(self) -> int:
|
506
|
+
rpm = MISTRAL_MODELS[self.model].get('rpm', 0)
|
507
|
+
tpm = MISTRAL_MODELS[self.model].get('tpm', 0)
|
508
|
+
return self.rate_to_max_concurrency(
|
509
|
+
requests_per_min=rpm, tokens_per_min=tpm
|
510
|
+
)
|
511
|
+
|
512
|
+
def estimate_cost(
|
513
|
+
self,
|
514
|
+
num_input_tokens: int,
|
515
|
+
num_output_tokens: int
|
516
|
+
) -> float | None:
|
517
|
+
"""Estimate the cost based on usage."""
|
518
|
+
cost_per_1m_input_tokens = MISTRAL_MODELS[self.model].get(
|
519
|
+
'cost_per_1m_input_tokens', None
|
520
|
+
)
|
521
|
+
cost_per_1m_output_tokens = MISTRAL_MODELS[self.model].get(
|
522
|
+
'cost_per_1m_output_tokens', None
|
523
|
+
)
|
524
|
+
if cost_per_1m_output_tokens is None or cost_per_1m_input_tokens is None:
|
525
|
+
return None
|
526
|
+
return (
|
527
|
+
cost_per_1m_input_tokens * num_input_tokens
|
528
|
+
+ cost_per_1m_output_tokens * num_output_tokens
|
529
|
+
) / 1000_000
|
530
|
+
|
531
|
+
|
532
|
+
# pylint: disable=invalid-name
|
533
|
+
class VertexAIMistralLarge_20241121(VertexAIMistral):
|
534
|
+
"""Mistral Large model on VertexAI released on 2024/11/21."""
|
535
|
+
|
536
|
+
model = 'mistral-large-2411'
|
537
|
+
|
538
|
+
|
539
|
+
class VertexAIMistralLarge_20240724(VertexAIMistral):
|
540
|
+
"""Mistral Large model on VertexAI released on 2024/07/24."""
|
541
|
+
|
542
|
+
model = 'mistral-large@2407'
|
543
|
+
|
544
|
+
|
545
|
+
class VertexAIMistralNemo_20240724(VertexAIMistral):
|
546
|
+
"""Mistral Nemo model on VertexAI released on 2024/07/24."""
|
547
|
+
|
548
|
+
model = 'mistral-nemo@2407'
|
549
|
+
|
550
|
+
|
551
|
+
class VertexAICodestral_20250113(VertexAIMistral):
|
552
|
+
"""Codestral model on VertexAI released on 2025/01/13."""
|
553
|
+
|
554
|
+
model = 'codestral-2501'
|
555
|
+
|
556
|
+
|
557
|
+
class VertexAICodestral_20240529(VertexAIMistral):
|
558
|
+
"""Codestral model on VertexAI released on 2024/05/29."""
|
559
|
+
|
560
|
+
model = 'codestral@2405'
|
561
|
+
# pylint: enable=invalid-name
|
@@ -17,6 +17,8 @@ import os
|
|
17
17
|
import unittest
|
18
18
|
from unittest import mock
|
19
19
|
|
20
|
+
from google.auth import exceptions
|
21
|
+
import langfun.core as lf
|
20
22
|
from langfun.core.llms import vertexai
|
21
23
|
|
22
24
|
|
@@ -48,5 +50,55 @@ class VertexAITest(unittest.TestCase):
|
|
48
50
|
del os.environ['VERTEXAI_LOCATION']
|
49
51
|
|
50
52
|
|
53
|
+
class VertexAIAnthropicTest(unittest.TestCase):
|
54
|
+
"""Tests for VertexAI Anthropic models."""
|
55
|
+
|
56
|
+
def test_basics(self):
|
57
|
+
with self.assertRaisesRegex(ValueError, 'Please specify `project`'):
|
58
|
+
lm = vertexai.VertexAIClaude3_5_Sonnet_20241022()
|
59
|
+
lm('hi')
|
60
|
+
|
61
|
+
model = vertexai.VertexAIClaude3_5_Sonnet_20241022(project='langfun')
|
62
|
+
|
63
|
+
# NOTE(daiyip): For OSS users, default credentials are not available unless
|
64
|
+
# users have already set up their GCP project. Therefore we ignore the
|
65
|
+
# exception here.
|
66
|
+
try:
|
67
|
+
model._initialize()
|
68
|
+
except exceptions.DefaultCredentialsError:
|
69
|
+
pass
|
70
|
+
|
71
|
+
self.assertEqual(
|
72
|
+
model.api_endpoint,
|
73
|
+
(
|
74
|
+
'https://us-east5-aiplatform.googleapis.com/v1/projects/'
|
75
|
+
'langfun/locations/us-east5/publishers/anthropic/'
|
76
|
+
'models/claude-3-5-sonnet-v2@20241022:streamRawPredict'
|
77
|
+
)
|
78
|
+
)
|
79
|
+
self.assertEqual(
|
80
|
+
model.headers,
|
81
|
+
{
|
82
|
+
'Content-Type': 'application/json; charset=utf-8',
|
83
|
+
},
|
84
|
+
)
|
85
|
+
request = model.request(
|
86
|
+
lf.UserMessage('hi'), lf.LMSamplingOptions(temperature=0.0),
|
87
|
+
)
|
88
|
+
self.assertEqual(
|
89
|
+
request,
|
90
|
+
{
|
91
|
+
'anthropic_version': 'vertex-2023-10-16',
|
92
|
+
'max_tokens': 8192,
|
93
|
+
'messages': [
|
94
|
+
{'content': [{'text': 'hi', 'type': 'text'}], 'role': 'user'}
|
95
|
+
],
|
96
|
+
'stream': False,
|
97
|
+
'temperature': 0.0,
|
98
|
+
'top_k': 40,
|
99
|
+
},
|
100
|
+
)
|
101
|
+
|
102
|
+
|
51
103
|
if __name__ == '__main__':
|
52
104
|
unittest.main()
|