langfun 0.1.2.dev202501160804__py3-none-any.whl → 0.1.2.dev202501180803__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -27,8 +27,15 @@ from langfun.core.llms.fake import StaticSequence
27
27
  # Compositional models.
28
28
  from langfun.core.llms.compositional import RandomChoice
29
29
 
30
- # REST-based models.
30
+ # Base models by request/response protocol.
31
31
  from langfun.core.llms.rest import REST
32
+ from langfun.core.llms.openai_compatible import OpenAICompatible
33
+ from langfun.core.llms.gemini import Gemini
34
+ from langfun.core.llms.anthropic import Anthropic
35
+
36
+ # Base models by serving platforms.
37
+ from langfun.core.llms.vertexai import VertexAI
38
+ from langfun.core.llms.groq import Groq
32
39
 
33
40
  # Gemini models.
34
41
  from langfun.core.llms.google_genai import GenAI
@@ -44,7 +51,7 @@ from langfun.core.llms.google_genai import GeminiFlash1_5_002
44
51
  from langfun.core.llms.google_genai import GeminiFlash1_5_001
45
52
  from langfun.core.llms.google_genai import GeminiPro1
46
53
 
47
- from langfun.core.llms.vertexai import VertexAI
54
+ from langfun.core.llms.vertexai import VertexAIGemini
48
55
  from langfun.core.llms.vertexai import VertexAIGeminiFlash2_0ThinkingExp_20241219
49
56
  from langfun.core.llms.vertexai import VertexAIGeminiFlash2_0Exp
50
57
  from langfun.core.llms.vertexai import VertexAIGeminiExp_20241206
@@ -57,9 +64,6 @@ from langfun.core.llms.vertexai import VertexAIGeminiFlash1_5_002
57
64
  from langfun.core.llms.vertexai import VertexAIGeminiFlash1_5_001
58
65
  from langfun.core.llms.vertexai import VertexAIGeminiPro1
59
66
 
60
- # Base for OpenAI-compatible models.
61
- from langfun.core.llms.openai_compatible import OpenAICompatible
62
-
63
67
  # OpenAI models.
64
68
  from langfun.core.llms.openai import OpenAI
65
69
 
@@ -114,20 +118,34 @@ from langfun.core.llms.openai import Gpt3Curie
114
118
  from langfun.core.llms.openai import Gpt3Babbage
115
119
  from langfun.core.llms.openai import Gpt3Ada
116
120
 
117
- from langfun.core.llms.anthropic import Anthropic
121
+ # Anthropic models.
122
+
118
123
  from langfun.core.llms.anthropic import Claude35Sonnet
119
124
  from langfun.core.llms.anthropic import Claude35Sonnet20241022
120
125
  from langfun.core.llms.anthropic import Claude35Sonnet20240620
121
126
  from langfun.core.llms.anthropic import Claude3Opus
122
127
  from langfun.core.llms.anthropic import Claude3Sonnet
123
128
  from langfun.core.llms.anthropic import Claude3Haiku
124
- from langfun.core.llms.anthropic import VertexAIAnthropic
125
- from langfun.core.llms.anthropic import VertexAIClaude3_5_Sonnet_20241022
126
- from langfun.core.llms.anthropic import VertexAIClaude3_5_Sonnet_20240620
127
- from langfun.core.llms.anthropic import VertexAIClaude3_5_Haiku_20241022
128
- from langfun.core.llms.anthropic import VertexAIClaude3_Opus_20240229
129
129
 
130
- from langfun.core.llms.groq import Groq
130
+ from langfun.core.llms.vertexai import VertexAIAnthropic
131
+ from langfun.core.llms.vertexai import VertexAIClaude3_5_Sonnet_20241022
132
+ from langfun.core.llms.vertexai import VertexAIClaude3_5_Sonnet_20240620
133
+ from langfun.core.llms.vertexai import VertexAIClaude3_5_Haiku_20241022
134
+ from langfun.core.llms.vertexai import VertexAIClaude3_Opus_20240229
135
+
136
+ # Misc open source models.
137
+
138
+ # Gemma models.
139
+ from langfun.core.llms.groq import GroqGemma2_9B_IT
140
+ from langfun.core.llms.groq import GroqGemma_7B_IT
141
+
142
+ # Llama models.
143
+ from langfun.core.llms.vertexai import VertexAILlama
144
+ from langfun.core.llms.vertexai import VertexAILlama3_2_90B
145
+ from langfun.core.llms.vertexai import VertexAILlama3_1_405B
146
+ from langfun.core.llms.vertexai import VertexAILlama3_1_70B
147
+ from langfun.core.llms.vertexai import VertexAILlama3_1_8B
148
+
131
149
  from langfun.core.llms.groq import GroqLlama3_2_3B
132
150
  from langfun.core.llms.groq import GroqLlama3_2_1B
133
151
  from langfun.core.llms.groq import GroqLlama3_1_70B
@@ -135,18 +153,28 @@ from langfun.core.llms.groq import GroqLlama3_1_8B
135
153
  from langfun.core.llms.groq import GroqLlama3_70B
136
154
  from langfun.core.llms.groq import GroqLlama3_8B
137
155
  from langfun.core.llms.groq import GroqLlama2_70B
156
+
157
+ # Mistral models.
158
+ from langfun.core.llms.vertexai import VertexAIMistral
159
+ from langfun.core.llms.vertexai import VertexAIMistralLarge_20241121
160
+ from langfun.core.llms.vertexai import VertexAIMistralLarge_20240724
161
+ from langfun.core.llms.vertexai import VertexAIMistralNemo_20240724
162
+ from langfun.core.llms.vertexai import VertexAICodestral_20250113
163
+ from langfun.core.llms.vertexai import VertexAICodestral_20240529
164
+
138
165
  from langfun.core.llms.groq import GroqMistral_8x7B
139
- from langfun.core.llms.groq import GroqGemma2_9B_IT
140
- from langfun.core.llms.groq import GroqGemma_7B_IT
166
+
167
+ # DeepSeek models.
168
+ from langfun.core.llms.deepseek import DeepSeek
169
+ from langfun.core.llms.deepseek import DeepSeekChat
170
+
171
+ # Whisper models.
141
172
  from langfun.core.llms.groq import GroqWhisper_Large_v3
142
173
  from langfun.core.llms.groq import GroqWhisper_Large_v3Turbo
143
174
 
144
175
  # LLaMA C++ models.
145
176
  from langfun.core.llms.llama_cpp import LlamaCppRemote
146
177
 
147
- # DeepSeek models.
148
- from langfun.core.llms.deepseek import DeepSeek
149
- from langfun.core.llms.deepseek import DeepSeekChat
150
178
 
151
179
  # Placeholder for Google-internal imports.
152
180
 
@@ -14,9 +14,8 @@
14
14
  """Language models from Anthropic."""
15
15
 
16
16
  import base64
17
- import functools
18
17
  import os
19
- from typing import Annotated, Any, Literal
18
+ from typing import Annotated, Any
20
19
 
21
20
  import langfun.core as lf
22
21
  from langfun.core import modalities as lf_modalities
@@ -24,20 +23,6 @@ from langfun.core.llms import rest
24
23
  import pyglove as pg
25
24
 
26
25
 
27
- try:
28
- # pylint: disable=g-import-not-at-top
29
- from google import auth as google_auth
30
- from google.auth import credentials as credentials_lib
31
- from google.auth.transport import requests as auth_requests
32
- Credentials = credentials_lib.Credentials
33
- # pylint: enable=g-import-not-at-top
34
- except ImportError:
35
- google_auth = None
36
- auth_requests = None
37
- credentials_lib = None
38
- Credentials = Any # pylint: disable=invalid-name
39
-
40
-
41
26
  SUPPORTED_MODELS_AND_SETTINGS = {
42
27
  # See https://docs.anthropic.com/claude/docs/models-overview
43
28
  # Rate limits from https://docs.anthropic.com/claude/reference/rate-limits
@@ -379,110 +364,3 @@ class Claude21(Anthropic):
379
364
  class ClaudeInstant(Anthropic):
380
365
  """Cheapest small and fast model, 100K context window."""
381
366
  model = 'claude-instant-1.2'
382
-
383
-
384
- #
385
- # Authropic models on VertexAI.
386
- #
387
-
388
-
389
- class VertexAIAnthropic(Anthropic):
390
- """Anthropic models on VertexAI."""
391
-
392
- project: Annotated[
393
- str | None,
394
- 'Google Cloud project ID.',
395
- ] = None
396
-
397
- location: Annotated[
398
- Literal['us-east5', 'europe-west1'],
399
- 'GCP location with Anthropic models hosted.'
400
- ] = 'us-east5'
401
-
402
- credentials: Annotated[
403
- Credentials | None, # pytype: disable=invalid-annotation
404
- (
405
- 'Credentials to use. If None, the default credentials '
406
- 'to the environment will be used.'
407
- ),
408
- ] = None
409
-
410
- api_version = 'vertex-2023-10-16'
411
-
412
- def _on_bound(self):
413
- super()._on_bound()
414
- if google_auth is None:
415
- raise ValueError(
416
- 'Please install "langfun[llm-google-vertex]" to use Vertex AI models.'
417
- )
418
- self._project = None
419
- self._credentials = None
420
-
421
- def _initialize(self):
422
- project = self.project or os.environ.get('VERTEXAI_PROJECT', None)
423
- if not project:
424
- raise ValueError(
425
- 'Please specify `project` during `__init__` or set environment '
426
- 'variable `VERTEXAI_PROJECT` with your Vertex AI project ID.'
427
- )
428
- self._project = project
429
- credentials = self.credentials
430
- if credentials is None:
431
- # Use default credentials.
432
- credentials = google_auth.default(
433
- scopes=['https://www.googleapis.com/auth/cloud-platform']
434
- )
435
- self._credentials = credentials
436
-
437
- @functools.cached_property
438
- def _session(self):
439
- assert self._api_initialized
440
- assert self._credentials is not None
441
- assert auth_requests is not None
442
- s = auth_requests.AuthorizedSession(self._credentials)
443
- s.headers.update(self.headers or {})
444
- return s
445
-
446
- @property
447
- def headers(self):
448
- return {
449
- 'Content-Type': 'application/json; charset=utf-8',
450
- }
451
-
452
- @property
453
- def api_endpoint(self) -> str:
454
- return (
455
- f'https://{self.location}-aiplatform.googleapis.com/v1/projects/'
456
- f'{self._project}/locations/{self.location}/publishers/anthropic/'
457
- f'models/{self.model}:streamRawPredict'
458
- )
459
-
460
- def request(
461
- self,
462
- prompt: lf.Message,
463
- sampling_options: lf.LMSamplingOptions
464
- ):
465
- request = super().request(prompt, sampling_options)
466
- request['anthropic_version'] = self.api_version
467
- del request['model']
468
- return request
469
-
470
-
471
- class VertexAIClaude3_Opus_20240229(VertexAIAnthropic): # pylint: disable=invalid-name
472
- """Anthropic's Claude 3 Opus model on VertexAI."""
473
- model = 'claude-3-opus@20240229'
474
-
475
-
476
- class VertexAIClaude3_5_Sonnet_20241022(VertexAIAnthropic): # pylint: disable=invalid-name
477
- """Anthropic's Claude 3.5 Sonnet model on VertexAI."""
478
- model = 'claude-3-5-sonnet-v2@20241022'
479
-
480
-
481
- class VertexAIClaude3_5_Sonnet_20240620(VertexAIAnthropic): # pylint: disable=invalid-name
482
- """Anthropic's Claude 3.5 Sonnet model on VertexAI."""
483
- model = 'claude-3-5-sonnet@20240620'
484
-
485
-
486
- class VertexAIClaude3_5_Haiku_20241022(VertexAIAnthropic): # pylint: disable=invalid-name
487
- """Anthropic's Claude 3.5 Haiku model on VertexAI."""
488
- model = 'claude-3-5-haiku@20241022'
@@ -19,9 +19,6 @@ from typing import Any
19
19
  import unittest
20
20
  from unittest import mock
21
21
 
22
- from google.auth import exceptions
23
- from langfun.core import language_model
24
- from langfun.core import message as lf_message
25
22
  from langfun.core import modalities as lf_modalities
26
23
  from langfun.core.llms import anthropic
27
24
  import pyglove as pg
@@ -186,50 +183,5 @@ class AnthropicTest(unittest.TestCase):
186
183
  lm('hello', max_attempts=1)
187
184
 
188
185
 
189
- class VertexAIAnthropicTest(unittest.TestCase):
190
- """Tests for VertexAI Anthropic models."""
191
-
192
- def test_basics(self):
193
- with self.assertRaisesRegex(ValueError, 'Please specify `project`'):
194
- lm = anthropic.VertexAIClaude3_5_Sonnet_20241022()
195
- lm('hi')
196
-
197
- model = anthropic.VertexAIClaude3_5_Sonnet_20241022(project='langfun')
198
-
199
- # NOTE(daiyip): For OSS users, default credentials are not available unless
200
- # users have already set up their GCP project. Therefore we ignore the
201
- # exception here.
202
- try:
203
- model._initialize()
204
- except exceptions.DefaultCredentialsError:
205
- pass
206
-
207
- self.assertEqual(
208
- model.api_endpoint,
209
- (
210
- 'https://us-east5-aiplatform.googleapis.com/v1/projects/'
211
- 'langfun/locations/us-east5/publishers/anthropic/'
212
- 'models/claude-3-5-sonnet-v2@20241022:streamRawPredict'
213
- )
214
- )
215
- request = model.request(
216
- lf_message.UserMessage('hi'),
217
- language_model.LMSamplingOptions(temperature=0.0),
218
- )
219
- self.assertEqual(
220
- request,
221
- {
222
- 'anthropic_version': 'vertex-2023-10-16',
223
- 'max_tokens': 8192,
224
- 'messages': [
225
- {'content': [{'text': 'hi', 'type': 'text'}], 'role': 'user'}
226
- ],
227
- 'stream': False,
228
- 'temperature': 0.0,
229
- 'top_k': 40,
230
- },
231
- )
232
-
233
-
234
186
  if __name__ == '__main__':
235
187
  unittest.main()
@@ -380,7 +380,7 @@ class Gemini(rest.REST):
380
380
  return (
381
381
  cost_per_1m_input_tokens * num_input_tokens
382
382
  + cost_per_1m_output_tokens * num_output_tokens
383
- ) / 1000_1000
383
+ ) / 1000_000
384
384
 
385
385
  @property
386
386
  def model_id(self) -> str:
@@ -1,4 +1,4 @@
1
- # Copyright 2023 The Langfun Authors
1
+ # Copyright 2025 The Langfun Authors
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -15,10 +15,13 @@
15
15
 
16
16
  import functools
17
17
  import os
18
- from typing import Annotated, Any
18
+ from typing import Annotated, Any, Literal
19
19
 
20
20
  import langfun.core as lf
21
+ from langfun.core.llms import anthropic
21
22
  from langfun.core.llms import gemini
23
+ from langfun.core.llms import openai_compatible
24
+ from langfun.core.llms import rest
22
25
  import pyglove as pg
23
26
 
24
27
  try:
@@ -36,10 +39,21 @@ except ImportError:
36
39
  Credentials = Any
37
40
 
38
41
 
39
- @lf.use_init_args(['model'])
40
- @pg.members([('api_endpoint', pg.typing.Str().freeze(''))])
41
- class VertexAI(gemini.Gemini):
42
- """Language model served on VertexAI with REST API."""
42
+ @pg.use_init_args(['api_endpoint'])
43
+ class VertexAI(rest.REST):
44
+ """Base class for VertexAI models.
45
+
46
+ This class handles the authentication of vertex AI models. Subclasses
47
+ should implement `request` and `result` methods, as well as the `api_endpoint`
48
+ property. Alternatively, users may provide them as __init__ arguments.
49
+
50
+ Please check out VertexAIGemini in this module as an example.
51
+ """
52
+
53
+ model: Annotated[
54
+ str | None,
55
+ 'Model ID.'
56
+ ] = None
43
57
 
44
58
  project: Annotated[
45
59
  str | None,
@@ -95,7 +109,7 @@ class VertexAI(gemini.Gemini):
95
109
  credentials = self.credentials
96
110
  if credentials is None:
97
111
  # Use default credentials.
98
- credentials = google_auth.default(
112
+ credentials, _ = google_auth.default(
99
113
  scopes=['https://www.googleapis.com/auth/cloud-platform']
100
114
  )
101
115
  self._credentials = credentials
@@ -114,6 +128,17 @@ class VertexAI(gemini.Gemini):
114
128
  s.headers.update(self.headers or {})
115
129
  return s
116
130
 
131
+
132
+ #
133
+ # Gemini models served by Vertex AI.
134
+ #
135
+
136
+
137
+ @pg.use_init_args(['model'])
138
+ @pg.members([('api_endpoint', pg.typing.Str().freeze(''))])
139
+ class VertexAIGemini(VertexAI, gemini.Gemini):
140
+ """Gemini models served by Vertex AI."""
141
+
117
142
  @property
118
143
  def api_endpoint(self) -> str:
119
144
  assert self._api_initialized
@@ -124,7 +149,7 @@ class VertexAI(gemini.Gemini):
124
149
  )
125
150
 
126
151
 
127
- class VertexAIGeminiFlash2_0ThinkingExp_20241219(VertexAI): # pylint: disable=invalid-name
152
+ class VertexAIGeminiFlash2_0ThinkingExp_20241219(VertexAIGemini): # pylint: disable=invalid-name
128
153
  """Vertex AI Gemini Flash 2.0 Thinking model launched on 12/19/2024."""
129
154
 
130
155
  api_version = 'v1alpha'
@@ -132,61 +157,405 @@ class VertexAIGeminiFlash2_0ThinkingExp_20241219(VertexAI): # pylint: disable=i
132
157
  timeout = None
133
158
 
134
159
 
135
- class VertexAIGeminiFlash2_0Exp(VertexAI): # pylint: disable=invalid-name
160
+ class VertexAIGeminiFlash2_0Exp(VertexAIGemini): # pylint: disable=invalid-name
136
161
  """Vertex AI Gemini 2.0 Flash model."""
137
162
 
138
163
  model = 'gemini-2.0-flash-exp'
139
164
 
140
165
 
141
- class VertexAIGeminiExp_20241206(VertexAI): # pylint: disable=invalid-name
166
+ class VertexAIGeminiExp_20241206(VertexAIGemini): # pylint: disable=invalid-name
142
167
  """Vertex AI Gemini Experimental model launched on 12/06/2024."""
143
168
 
144
169
  model = 'gemini-exp-1206'
145
170
 
146
171
 
147
- class VertexAIGeminiExp_20241114(VertexAI): # pylint: disable=invalid-name
172
+ class VertexAIGeminiExp_20241114(VertexAIGemini): # pylint: disable=invalid-name
148
173
  """Vertex AI Gemini Experimental model launched on 11/14/2024."""
149
174
 
150
175
  model = 'gemini-exp-1114'
151
176
 
152
177
 
153
- class VertexAIGeminiPro1_5(VertexAI): # pylint: disable=invalid-name
178
+ class VertexAIGeminiPro1_5(VertexAIGemini): # pylint: disable=invalid-name
154
179
  """Vertex AI Gemini 1.5 Pro model."""
155
180
 
156
181
  model = 'gemini-1.5-pro-latest'
157
182
 
158
183
 
159
- class VertexAIGeminiPro1_5_002(VertexAI): # pylint: disable=invalid-name
184
+ class VertexAIGeminiPro1_5_002(VertexAIGemini): # pylint: disable=invalid-name
160
185
  """Vertex AI Gemini 1.5 Pro model."""
161
186
 
162
187
  model = 'gemini-1.5-pro-002'
163
188
 
164
189
 
165
- class VertexAIGeminiPro1_5_001(VertexAI): # pylint: disable=invalid-name
190
+ class VertexAIGeminiPro1_5_001(VertexAIGemini): # pylint: disable=invalid-name
166
191
  """Vertex AI Gemini 1.5 Pro model."""
167
192
 
168
193
  model = 'gemini-1.5-pro-001'
169
194
 
170
195
 
171
- class VertexAIGeminiFlash1_5(VertexAI): # pylint: disable=invalid-name
196
+ class VertexAIGeminiFlash1_5(VertexAIGemini): # pylint: disable=invalid-name
172
197
  """Vertex AI Gemini 1.5 Flash model."""
173
198
 
174
199
  model = 'gemini-1.5-flash'
175
200
 
176
201
 
177
- class VertexAIGeminiFlash1_5_002(VertexAI): # pylint: disable=invalid-name
202
+ class VertexAIGeminiFlash1_5_002(VertexAIGemini): # pylint: disable=invalid-name
178
203
  """Vertex AI Gemini 1.5 Flash model."""
179
204
 
180
205
  model = 'gemini-1.5-flash-002'
181
206
 
182
207
 
183
- class VertexAIGeminiFlash1_5_001(VertexAI): # pylint: disable=invalid-name
208
+ class VertexAIGeminiFlash1_5_001(VertexAIGemini): # pylint: disable=invalid-name
184
209
  """Vertex AI Gemini 1.5 Flash model."""
185
210
 
186
211
  model = 'gemini-1.5-flash-001'
187
212
 
188
213
 
189
- class VertexAIGeminiPro1(VertexAI): # pylint: disable=invalid-name
214
+ class VertexAIGeminiPro1(VertexAIGemini): # pylint: disable=invalid-name
190
215
  """Vertex AI Gemini 1.0 Pro model."""
191
216
 
192
217
  model = 'gemini-1.0-pro'
218
+
219
+
220
+ #
221
+ # Anthropic models on Vertex AI.
222
+ #
223
+
224
+
225
+ @pg.use_init_args(['model'])
226
+ @pg.members([('api_endpoint', pg.typing.Str().freeze(''))])
227
+ class VertexAIAnthropic(VertexAI, anthropic.Anthropic):
228
+ """Anthropic models on VertexAI."""
229
+
230
+ location: Annotated[
231
+ Literal['us-east5', 'europe-west1'],
232
+ 'GCP location with Anthropic models hosted.'
233
+ ] = 'us-east5'
234
+
235
+ api_version = 'vertex-2023-10-16'
236
+
237
+ @property
238
+ def headers(self):
239
+ return {
240
+ 'Content-Type': 'application/json; charset=utf-8',
241
+ }
242
+
243
+ @property
244
+ def api_endpoint(self) -> str:
245
+ return (
246
+ f'https://{self.location}-aiplatform.googleapis.com/v1/projects/'
247
+ f'{self._project}/locations/{self.location}/publishers/anthropic/'
248
+ f'models/{self.model}:streamRawPredict'
249
+ )
250
+
251
+ def request(
252
+ self,
253
+ prompt: lf.Message,
254
+ sampling_options: lf.LMSamplingOptions
255
+ ):
256
+ request = super().request(prompt, sampling_options)
257
+ request['anthropic_version'] = self.api_version
258
+ del request['model']
259
+ return request
260
+
261
+
262
+ # pylint: disable=invalid-name
263
+
264
+
265
+ class VertexAIClaude3_Opus_20240229(VertexAIAnthropic):
266
+ """Anthropic's Claude 3 Opus model on VertexAI."""
267
+ model = 'claude-3-opus@20240229'
268
+
269
+
270
+ class VertexAIClaude3_5_Sonnet_20241022(VertexAIAnthropic):
271
+ """Anthropic's Claude 3.5 Sonnet model on VertexAI."""
272
+ model = 'claude-3-5-sonnet-v2@20241022'
273
+
274
+
275
+ class VertexAIClaude3_5_Sonnet_20240620(VertexAIAnthropic):
276
+ """Anthropic's Claude 3.5 Sonnet model on VertexAI."""
277
+ model = 'claude-3-5-sonnet@20240620'
278
+
279
+
280
+ class VertexAIClaude3_5_Haiku_20241022(VertexAIAnthropic):
281
+ """Anthropic's Claude 3.5 Haiku model on VertexAI."""
282
+ model = 'claude-3-5-haiku@20241022'
283
+
284
+ # pylint: enable=invalid-name
285
+
286
+ #
287
+ # Llama models on Vertex AI.
288
+ # pylint: disable=line-too-long
289
+ # Pricing: https://cloud.google.com/vertex-ai/generative-ai/pricing?_gl=1*ukuk6u*_ga*MjEzMjc4NjM2My4xNzMzODg4OTg3*_ga_WH2QY8WWF5*MTczNzEzNDU1Mi4xMjQuMS4xNzM3MTM0NzczLjU5LjAuMA..#meta-models
290
+ # pylint: enable=line-too-long
291
+
292
+ LLAMA_MODELS = {
293
+ 'llama-3.2-90b-vision-instruct-maas': pg.Dict(
294
+ latest_update='2024-09-25',
295
+ in_service=True,
296
+ rpm=0,
297
+ tpm=0,
298
+ # Free during preview.
299
+ cost_per_1m_input_tokens=None,
300
+ cost_per_1m_output_tokens=None,
301
+ ),
302
+ 'llama-3.1-405b-instruct-maas': pg.Dict(
303
+ latest_update='2024-09-25',
304
+ in_service=True,
305
+ rpm=0,
306
+ tpm=0,
307
+ # GA.
308
+ cost_per_1m_input_tokens=5,
309
+ cost_per_1m_output_tokens=16,
310
+ ),
311
+ 'llama-3.1-70b-instruct-maas': pg.Dict(
312
+ latest_update='2024-09-25',
313
+ in_service=True,
314
+ rpm=0,
315
+ tpm=0,
316
+ # Free during preview.
317
+ cost_per_1m_input_tokens=None,
318
+ cost_per_1m_output_tokens=None,
319
+ ),
320
+ 'llama-3.1-8b-instruct-maas': pg.Dict(
321
+ latest_update='2024-09-25',
322
+ in_service=True,
323
+ rpm=0,
324
+ tpm=0,
325
+ # Free during preview.
326
+ cost_per_1m_input_tokens=None,
327
+ cost_per_1m_output_tokens=None,
328
+ )
329
+ }
330
+
331
+
332
+ @pg.use_init_args(['model'])
333
+ @pg.members([('api_endpoint', pg.typing.Str().freeze(''))])
334
+ class VertexAILlama(VertexAI, openai_compatible.OpenAICompatible):
335
+ """Llama models on VertexAI."""
336
+
337
+ model: pg.typing.Annotated[
338
+ pg.typing.Enum(pg.MISSING_VALUE, list(LLAMA_MODELS.keys())),
339
+ 'Llama model ID.',
340
+ ]
341
+
342
+ locations: Annotated[
343
+ Literal['us-central1'],
344
+ (
345
+ 'GCP locations with Llama models hosted. '
346
+ 'See https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/llama#regions-quotas'
347
+ )
348
+ ] = 'us-central1'
349
+
350
+ @property
351
+ def api_endpoint(self) -> str:
352
+ assert self._api_initialized
353
+ return (
354
+ f'https://{self._location}-aiplatform.googleapis.com/v1beta1/projects/'
355
+ f'{self._project}/locations/{self._location}/endpoints/'
356
+ f'openapi/chat/completions'
357
+ )
358
+
359
+ def request(
360
+ self,
361
+ prompt: lf.Message,
362
+ sampling_options: lf.LMSamplingOptions
363
+ ):
364
+ request = super().request(prompt, sampling_options)
365
+ request['model'] = f'meta/{self.model}'
366
+ return request
367
+
368
+ @property
369
+ def max_concurrency(self) -> int:
370
+ rpm = LLAMA_MODELS[self.model].get('rpm', 0)
371
+ tpm = LLAMA_MODELS[self.model].get('tpm', 0)
372
+ return self.rate_to_max_concurrency(
373
+ requests_per_min=rpm, tokens_per_min=tpm
374
+ )
375
+
376
+ def estimate_cost(
377
+ self,
378
+ num_input_tokens: int,
379
+ num_output_tokens: int
380
+ ) -> float | None:
381
+ """Estimate the cost based on usage."""
382
+ cost_per_1m_input_tokens = LLAMA_MODELS[self.model].get(
383
+ 'cost_per_1m_input_tokens', None
384
+ )
385
+ cost_per_1m_output_tokens = LLAMA_MODELS[self.model].get(
386
+ 'cost_per_1m_output_tokens', None
387
+ )
388
+ if cost_per_1m_output_tokens is None or cost_per_1m_input_tokens is None:
389
+ return None
390
+ return (
391
+ cost_per_1m_input_tokens * num_input_tokens
392
+ + cost_per_1m_output_tokens * num_output_tokens
393
+ ) / 1000_000
394
+
395
+
396
+ # pylint: disable=invalid-name
397
+ class VertexAILlama3_2_90B(VertexAILlama):
398
+ """Llama 3.2 90B vision instruct model on VertexAI."""
399
+
400
+ model = 'llama-3.2-90b-vision-instruct-maas'
401
+
402
+
403
+ class VertexAILlama3_1_405B(VertexAILlama):
404
+ """Llama 3.1 405B instruct model on VertexAI."""
405
+
406
+ model = 'llama-3.1-405b-instruct-maas'
407
+
408
+
409
+ class VertexAILlama3_1_70B(VertexAILlama):
410
+ """Llama 3.1 70B instruct model on VertexAI."""
411
+
412
+ model = 'llama-3.1-70b-instruct-maas'
413
+
414
+
415
+ class VertexAILlama3_1_8B(VertexAILlama):
416
+ """Llama 3.1 8B instruct model on VertexAI."""
417
+
418
+ model = 'llama-3.1-8b-instruct-maas'
419
+ # pylint: enable=invalid-name
420
+
421
+ #
422
+ # Mistral models on Vertex AI.
423
+ # pylint: disable=line-too-long
424
+ # Pricing: https://cloud.google.com/vertex-ai/generative-ai/pricing?_gl=1*ukuk6u*_ga*MjEzMjc4NjM2My4xNzMzODg4OTg3*_ga_WH2QY8WWF5*MTczNzEzNDU1Mi4xMjQuMS4xNzM3MTM0NzczLjU5LjAuMA..#mistral-models
425
+ # pylint: enable=line-too-long
426
+
427
+
428
+ MISTRAL_MODELS = {
429
+ 'mistral-large-2411': pg.Dict(
430
+ latest_update='2024-11-21',
431
+ in_service=True,
432
+ rpm=0,
433
+ tpm=0,
434
+ # GA.
435
+ cost_per_1m_input_tokens=2,
436
+ cost_per_1m_output_tokens=6,
437
+ ),
438
+ 'mistral-large@2407': pg.Dict(
439
+ latest_update='2024-07-24',
440
+ in_service=True,
441
+ rpm=0,
442
+ tpm=0,
443
+ # GA.
444
+ cost_per_1m_input_tokens=2,
445
+ cost_per_1m_output_tokens=6,
446
+ ),
447
+ 'mistral-nemo@2407': pg.Dict(
448
+ latest_update='2024-07-24',
449
+ in_service=True,
450
+ rpm=0,
451
+ tpm=0,
452
+ # GA.
453
+ cost_per_1m_input_tokens=0.15,
454
+ cost_per_1m_output_tokens=0.15,
455
+ ),
456
+ 'codestral-2501': pg.Dict(
457
+ latest_update='2025-01-13',
458
+ in_service=True,
459
+ rpm=0,
460
+ tpm=0,
461
+ # GA.
462
+ cost_per_1m_input_tokens=0.3,
463
+ cost_per_1m_output_tokens=0.9,
464
+ ),
465
+ 'codestral@2405': pg.Dict(
466
+ latest_update='2024-05-29',
467
+ in_service=True,
468
+ rpm=0,
469
+ tpm=0,
470
+ # GA.
471
+ cost_per_1m_input_tokens=0.2,
472
+ cost_per_1m_output_tokens=0.6,
473
+ ),
474
+ }
475
+
476
+
477
+ @pg.use_init_args(['model'])
478
+ @pg.members([('api_endpoint', pg.typing.Str().freeze(''))])
479
+ class VertexAIMistral(VertexAI, openai_compatible.OpenAICompatible):
480
+ """Mistral AI models on VertexAI."""
481
+
482
+ model: pg.typing.Annotated[
483
+ pg.typing.Enum(pg.MISSING_VALUE, list(MISTRAL_MODELS.keys())),
484
+ 'Mistral model ID.',
485
+ ]
486
+
487
+ locations: Annotated[
488
+ Literal['us-central1', 'europe-west4'],
489
+ (
490
+ 'GCP locations with Mistral models hosted. '
491
+ 'See https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/mistral#regions-quotas'
492
+ )
493
+ ] = 'us-central1'
494
+
495
+ @property
496
+ def api_endpoint(self) -> str:
497
+ assert self._api_initialized
498
+ return (
499
+ f'https://{self._location}-aiplatform.googleapis.com/v1/projects/'
500
+ f'{self._project}/locations/{self._location}/publishers/mistralai/'
501
+ f'models/{self.model}:rawPredict'
502
+ )
503
+
504
+ @property
505
+ def max_concurrency(self) -> int:
506
+ rpm = MISTRAL_MODELS[self.model].get('rpm', 0)
507
+ tpm = MISTRAL_MODELS[self.model].get('tpm', 0)
508
+ return self.rate_to_max_concurrency(
509
+ requests_per_min=rpm, tokens_per_min=tpm
510
+ )
511
+
512
+ def estimate_cost(
513
+ self,
514
+ num_input_tokens: int,
515
+ num_output_tokens: int
516
+ ) -> float | None:
517
+ """Estimate the cost based on usage."""
518
+ cost_per_1m_input_tokens = MISTRAL_MODELS[self.model].get(
519
+ 'cost_per_1m_input_tokens', None
520
+ )
521
+ cost_per_1m_output_tokens = MISTRAL_MODELS[self.model].get(
522
+ 'cost_per_1m_output_tokens', None
523
+ )
524
+ if cost_per_1m_output_tokens is None or cost_per_1m_input_tokens is None:
525
+ return None
526
+ return (
527
+ cost_per_1m_input_tokens * num_input_tokens
528
+ + cost_per_1m_output_tokens * num_output_tokens
529
+ ) / 1000_000
530
+
531
+
532
+ # pylint: disable=invalid-name
533
+ class VertexAIMistralLarge_20241121(VertexAIMistral):
534
+ """Mistral Large model on VertexAI released on 2024/11/21."""
535
+
536
+ model = 'mistral-large-2411'
537
+
538
+
539
+ class VertexAIMistralLarge_20240724(VertexAIMistral):
540
+ """Mistral Large model on VertexAI released on 2024/07/24."""
541
+
542
+ model = 'mistral-large@2407'
543
+
544
+
545
+ class VertexAIMistralNemo_20240724(VertexAIMistral):
546
+ """Mistral Nemo model on VertexAI released on 2024/07/24."""
547
+
548
+ model = 'mistral-nemo@2407'
549
+
550
+
551
+ class VertexAICodestral_20250113(VertexAIMistral):
552
+ """Codestral model on VertexAI released on 2025/01/13."""
553
+
554
+ model = 'codestral-2501'
555
+
556
+
557
+ class VertexAICodestral_20240529(VertexAIMistral):
558
+ """Codestral model on VertexAI released on 2024/05/29."""
559
+
560
+ model = 'codestral@2405'
561
+ # pylint: enable=invalid-name
@@ -17,6 +17,8 @@ import os
17
17
  import unittest
18
18
  from unittest import mock
19
19
 
20
+ from google.auth import exceptions
21
+ import langfun.core as lf
20
22
  from langfun.core.llms import vertexai
21
23
 
22
24
 
@@ -48,5 +50,55 @@ class VertexAITest(unittest.TestCase):
48
50
  del os.environ['VERTEXAI_LOCATION']
49
51
 
50
52
 
53
+ class VertexAIAnthropicTest(unittest.TestCase):
54
+ """Tests for VertexAI Anthropic models."""
55
+
56
+ def test_basics(self):
57
+ with self.assertRaisesRegex(ValueError, 'Please specify `project`'):
58
+ lm = vertexai.VertexAIClaude3_5_Sonnet_20241022()
59
+ lm('hi')
60
+
61
+ model = vertexai.VertexAIClaude3_5_Sonnet_20241022(project='langfun')
62
+
63
+ # NOTE(daiyip): For OSS users, default credentials are not available unless
64
+ # users have already set up their GCP project. Therefore we ignore the
65
+ # exception here.
66
+ try:
67
+ model._initialize()
68
+ except exceptions.DefaultCredentialsError:
69
+ pass
70
+
71
+ self.assertEqual(
72
+ model.api_endpoint,
73
+ (
74
+ 'https://us-east5-aiplatform.googleapis.com/v1/projects/'
75
+ 'langfun/locations/us-east5/publishers/anthropic/'
76
+ 'models/claude-3-5-sonnet-v2@20241022:streamRawPredict'
77
+ )
78
+ )
79
+ self.assertEqual(
80
+ model.headers,
81
+ {
82
+ 'Content-Type': 'application/json; charset=utf-8',
83
+ },
84
+ )
85
+ request = model.request(
86
+ lf.UserMessage('hi'), lf.LMSamplingOptions(temperature=0.0),
87
+ )
88
+ self.assertEqual(
89
+ request,
90
+ {
91
+ 'anthropic_version': 'vertex-2023-10-16',
92
+ 'max_tokens': 8192,
93
+ 'messages': [
94
+ {'content': [{'text': 'hi', 'type': 'text'}], 'role': 'user'}
95
+ ],
96
+ 'stream': False,
97
+ 'temperature': 0.0,
98
+ 'top_k': 40,
99
+ },
100
+ )
101
+
102
+
51
103
  if __name__ == '__main__':
52
104
  unittest.main()
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: langfun
3
- Version: 0.1.2.dev202501160804
3
+ Version: 0.1.2.dev202501180803
4
4
  Summary: Langfun: Language as Functions.
5
5
  Home-page: https://github.com/google/langfun
6
6
  Author: Langfun Authors
@@ -71,16 +71,16 @@ langfun/core/eval/v2/reporting.py,sha256=QOp5jX761Esvi5w_UIRLDqPY_XRO6ru02-DOrdq
71
71
  langfun/core/eval/v2/reporting_test.py,sha256=UmYSAQvD3AIXsSyWQ-WD2uLtEISYpmBeoKY5u5Qwc8E,5696
72
72
  langfun/core/eval/v2/runners.py,sha256=DKEmSlGXjOXKWFdBhTpLy7tMsBHZHd1Brl3hWIngsSQ,15931
73
73
  langfun/core/eval/v2/runners_test.py,sha256=A37fKK2MvAVTiShsg_laluJzJ9AuAQn52k7HPbfD0Ks,11666
74
- langfun/core/llms/__init__.py,sha256=Ntr0kvHc17VEZ5EV9fCoYY1kzRvQxCoZrtDRYNiMWCs,6742
75
- langfun/core/llms/anthropic.py,sha256=a5MmnFsBA0CbfvwzXT1v_0fqLRMrhUNdh1tx6469PQ4,14357
76
- langfun/core/llms/anthropic_test.py,sha256=-2U4kc_pgBM7wqxu8RuxzyHPGww1EAWqKUvN4PW8Btw,8058
74
+ langfun/core/llms/__init__.py,sha256=50mJagAgkIhMwhOyHxGq_O5st4HhpnE-okeYzc7GU6c,7667
75
+ langfun/core/llms/anthropic.py,sha256=z_DWDpR1VKNzv6wq-9CXLzWdqCDXRKuVFacJNpgBqAs,10826
76
+ langfun/core/llms/anthropic_test.py,sha256=zZ2eSP8hhVv-RDSWxT7wX-NS5DfGfQmCjS9P0pusAHM,6556
77
77
  langfun/core/llms/compositional.py,sha256=csW_FLlgL-tpeyCOTVvfUQkMa_zCN5Y2I-YbSNuK27U,2872
78
78
  langfun/core/llms/compositional_test.py,sha256=4eTnOer-DncRKGaIJW2ZQQMLnt5r2R0UIx_DYOvGAQo,2027
79
79
  langfun/core/llms/deepseek.py,sha256=Y7DlLUWrukbPVyBMesppd-m75Q-PxD0b3KnMKaoY_8I,3744
80
80
  langfun/core/llms/deepseek_test.py,sha256=dS72i52bwMpCN4dJDvpJI59AnNChpwxS5eYYFrhGh90,1843
81
81
  langfun/core/llms/fake.py,sha256=gCHBYBLvBCsC78HI1hpoqXCS-p1FMTgY1P1qh_sGBPk,3070
82
82
  langfun/core/llms/fake_test.py,sha256=2h13qkwEz_JR0mtUDPxdAhQo7MueXaFSwsD2DIRDW9g,7653
83
- langfun/core/llms/gemini.py,sha256=tfM4vrt0WnvnrxRhWXZWh7Gp8dYYfMnSbi9uOstkSak,17399
83
+ langfun/core/llms/gemini.py,sha256=itwTCmQHRjwSjt7_UzFfaat23gyRL-El4qmJrg-OGVA,17398
84
84
  langfun/core/llms/gemini_test.py,sha256=2ERhYWCJwnfDTQbCaZHFuB1TdWJFrOBS7yyCBInIdQk,6129
85
85
  langfun/core/llms/google_genai.py,sha256=85Vmx5QmsziON03PRsFQINSu5NF6pAAuFFhUdDteWGc,3662
86
86
  langfun/core/llms/google_genai_test.py,sha256=JZf_cbQ4GGGpwiQCLjFJn7V4jxBBqgZhIx91AzbGKVo,1250
@@ -94,8 +94,8 @@ langfun/core/llms/openai_compatible_test.py,sha256=0uFYhCiuHo2Wrlgj16-GRG6rW8P6E
94
94
  langfun/core/llms/openai_test.py,sha256=m85YjGCvWvV5ZYagjC0FqI0FcqyCEVCbUUs8Wm3iUrc,2475
95
95
  langfun/core/llms/rest.py,sha256=sWbYUV8S3SuOg9giq7xwD-xDRfaF7NP_ig7bI52-Rj4,3442
96
96
  langfun/core/llms/rest_test.py,sha256=zWGiI08f9gXsoQPJS9TlX1zD2uQLrJUB-1VpAJXRHfs,3475
97
- langfun/core/llms/vertexai.py,sha256=MuwLPTJ6-9x2uRDCSM1_biPK6M76FFlL1ezf5OmobDA,5504
98
- langfun/core/llms/vertexai_test.py,sha256=iXjmQs7TNiwcueoaRGpdp4KnASkDJaTP__Z9QroN8zQ,1787
97
+ langfun/core/llms/vertexai.py,sha256=SVvLTqQZ6Ha8wZh3azkh4g3O838CpNkuP3XlgIrLMKo,15751
98
+ langfun/core/llms/vertexai_test.py,sha256=6eLQOyeL5iGZOIWb39sFcf1TgYD_6TBGYdMO4UIvhf4,3333
99
99
  langfun/core/llms/cache/__init__.py,sha256=QAo3InUMDM_YpteNnVCSejI4zOsnjSMWKJKzkb3VY64,993
100
100
  langfun/core/llms/cache/base.py,sha256=rt3zwmyw0y9jsSGW-ZbV1vAfLxQ7_3AVk0l2EySlse4,3918
101
101
  langfun/core/llms/cache/in_memory.py,sha256=i58oiQL28RDsq37dwqgVpC2mBETJjIEFS20yHiV5MKU,5185
@@ -146,8 +146,8 @@ langfun/core/templates/demonstration.py,sha256=vCrgYubdZM5Umqcgp8NUVGXgr4P_c-fik
146
146
  langfun/core/templates/demonstration_test.py,sha256=SafcDQ0WgI7pw05EmPI2S4v1t3ABKzup8jReCljHeK4,2162
147
147
  langfun/core/templates/selfplay.py,sha256=yhgrJbiYwq47TgzThmHrDQTF4nDrTI09CWGhuQPNv-s,2273
148
148
  langfun/core/templates/selfplay_test.py,sha256=Ot__1P1M8oJfoTp-M9-PQ6HUXqZKyMwvZ5f7yQ3yfyM,2326
149
- langfun-0.1.2.dev202501160804.dist-info/LICENSE,sha256=WNHhf_5RCaeuKWyq_K39vmp9F28LxKsB4SpomwSZ2L0,11357
150
- langfun-0.1.2.dev202501160804.dist-info/METADATA,sha256=_XM3ancZIb8-33gpRxLKmdJOBZsMfd1_2-4otzha19Q,8172
151
- langfun-0.1.2.dev202501160804.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
152
- langfun-0.1.2.dev202501160804.dist-info/top_level.txt,sha256=RhlEkHxs1qtzmmtWSwYoLVJAc1YrbPtxQ52uh8Z9VvY,8
153
- langfun-0.1.2.dev202501160804.dist-info/RECORD,,
149
+ langfun-0.1.2.dev202501180803.dist-info/LICENSE,sha256=WNHhf_5RCaeuKWyq_K39vmp9F28LxKsB4SpomwSZ2L0,11357
150
+ langfun-0.1.2.dev202501180803.dist-info/METADATA,sha256=W9jkpCCOZx-Tl8sNz3y1IdVZNG48qcjs21airG2TTI0,8172
151
+ langfun-0.1.2.dev202501180803.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
152
+ langfun-0.1.2.dev202501180803.dist-info/top_level.txt,sha256=RhlEkHxs1qtzmmtWSwYoLVJAc1YrbPtxQ52uh8Z9VvY,8
153
+ langfun-0.1.2.dev202501180803.dist-info/RECORD,,