langfun 0.1.2.dev202501080804__py3-none-any.whl → 0.1.2.dev202501240804__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56) hide show
  1. langfun/core/__init__.py +1 -6
  2. langfun/core/coding/python/__init__.py +5 -11
  3. langfun/core/coding/python/correction.py +4 -7
  4. langfun/core/coding/python/correction_test.py +2 -3
  5. langfun/core/coding/python/execution.py +22 -211
  6. langfun/core/coding/python/execution_test.py +11 -90
  7. langfun/core/coding/python/generation.py +3 -2
  8. langfun/core/coding/python/generation_test.py +2 -2
  9. langfun/core/coding/python/parsing.py +108 -194
  10. langfun/core/coding/python/parsing_test.py +2 -105
  11. langfun/core/component.py +11 -273
  12. langfun/core/component_test.py +2 -29
  13. langfun/core/concurrent.py +187 -82
  14. langfun/core/concurrent_test.py +28 -19
  15. langfun/core/console.py +7 -3
  16. langfun/core/eval/base.py +2 -3
  17. langfun/core/eval/v2/evaluation.py +3 -1
  18. langfun/core/eval/v2/reporting.py +8 -4
  19. langfun/core/language_model.py +84 -8
  20. langfun/core/language_model_test.py +84 -29
  21. langfun/core/llms/__init__.py +46 -11
  22. langfun/core/llms/anthropic.py +1 -123
  23. langfun/core/llms/anthropic_test.py +0 -48
  24. langfun/core/llms/deepseek.py +117 -0
  25. langfun/core/llms/deepseek_test.py +61 -0
  26. langfun/core/llms/gemini.py +1 -1
  27. langfun/core/llms/groq.py +12 -99
  28. langfun/core/llms/groq_test.py +31 -137
  29. langfun/core/llms/llama_cpp.py +17 -54
  30. langfun/core/llms/llama_cpp_test.py +2 -34
  31. langfun/core/llms/openai.py +9 -147
  32. langfun/core/llms/openai_compatible.py +179 -0
  33. langfun/core/llms/openai_compatible_test.py +495 -0
  34. langfun/core/llms/openai_test.py +13 -423
  35. langfun/core/llms/rest_test.py +1 -1
  36. langfun/core/llms/vertexai.py +387 -18
  37. langfun/core/llms/vertexai_test.py +52 -0
  38. langfun/core/message_test.py +3 -3
  39. langfun/core/modalities/mime.py +8 -0
  40. langfun/core/modalities/mime_test.py +19 -4
  41. langfun/core/modality_test.py +0 -1
  42. langfun/core/structured/mapping.py +13 -13
  43. langfun/core/structured/mapping_test.py +2 -2
  44. langfun/core/structured/schema.py +16 -8
  45. langfun/core/structured/schema_generation.py +1 -1
  46. {langfun-0.1.2.dev202501080804.dist-info → langfun-0.1.2.dev202501240804.dist-info}/METADATA +13 -2
  47. {langfun-0.1.2.dev202501080804.dist-info → langfun-0.1.2.dev202501240804.dist-info}/RECORD +50 -52
  48. {langfun-0.1.2.dev202501080804.dist-info → langfun-0.1.2.dev202501240804.dist-info}/WHEEL +1 -1
  49. langfun/core/coding/python/errors.py +0 -108
  50. langfun/core/coding/python/errors_test.py +0 -99
  51. langfun/core/coding/python/permissions.py +0 -90
  52. langfun/core/coding/python/permissions_test.py +0 -86
  53. langfun/core/text_formatting.py +0 -168
  54. langfun/core/text_formatting_test.py +0 -65
  55. {langfun-0.1.2.dev202501080804.dist-info → langfun-0.1.2.dev202501240804.dist-info}/LICENSE +0 -0
  56. {langfun-0.1.2.dev202501080804.dist-info → langfun-0.1.2.dev202501240804.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,4 @@
1
- # Copyright 2023 The Langfun Authors
1
+ # Copyright 2025 The Langfun Authors
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -15,10 +15,13 @@
15
15
 
16
16
  import functools
17
17
  import os
18
- from typing import Annotated, Any
18
+ from typing import Annotated, Any, Literal
19
19
 
20
20
  import langfun.core as lf
21
+ from langfun.core.llms import anthropic
21
22
  from langfun.core.llms import gemini
23
+ from langfun.core.llms import openai_compatible
24
+ from langfun.core.llms import rest
22
25
  import pyglove as pg
23
26
 
24
27
  try:
@@ -36,10 +39,21 @@ except ImportError:
36
39
  Credentials = Any
37
40
 
38
41
 
39
- @lf.use_init_args(['model'])
40
- @pg.members([('api_endpoint', pg.typing.Str().freeze(''))])
41
- class VertexAI(gemini.Gemini):
42
- """Language model served on VertexAI with REST API."""
42
+ @pg.use_init_args(['api_endpoint'])
43
+ class VertexAI(rest.REST):
44
+ """Base class for VertexAI models.
45
+
46
+ This class handles the authentication of vertex AI models. Subclasses
47
+ should implement `request` and `result` methods, as well as the `api_endpoint`
48
+ property. Or let users provide them as __init__ arguments.
49
+
50
+ Please check out VertexAIGemini in `gemini.py` as an example.
51
+ """
52
+
53
+ model: Annotated[
54
+ str | None,
55
+ 'Model ID.'
56
+ ] = None
43
57
 
44
58
  project: Annotated[
45
59
  str | None,
@@ -95,7 +109,7 @@ class VertexAI(gemini.Gemini):
95
109
  credentials = self.credentials
96
110
  if credentials is None:
97
111
  # Use default credentials.
98
- credentials = google_auth.default(
112
+ credentials, _ = google_auth.default(
99
113
  scopes=['https://www.googleapis.com/auth/cloud-platform']
100
114
  )
101
115
  self._credentials = credentials
@@ -114,6 +128,17 @@ class VertexAI(gemini.Gemini):
114
128
  s.headers.update(self.headers or {})
115
129
  return s
116
130
 
131
+
132
+ #
133
+ # Gemini models served by Vertex AI.
134
+ #
135
+
136
+
137
+ @pg.use_init_args(['model'])
138
+ @pg.members([('api_endpoint', pg.typing.Str().freeze(''))])
139
+ class VertexAIGemini(VertexAI, gemini.Gemini):
140
+ """Gemini models served by Vertex AI."""
141
+
117
142
  @property
118
143
  def api_endpoint(self) -> str:
119
144
  assert self._api_initialized
@@ -124,7 +149,7 @@ class VertexAI(gemini.Gemini):
124
149
  )
125
150
 
126
151
 
127
- class VertexAIGeminiFlash2_0ThinkingExp_20241219(VertexAI): # pylint: disable=invalid-name
152
+ class VertexAIGeminiFlash2_0ThinkingExp_20241219(VertexAIGemini): # pylint: disable=invalid-name
128
153
  """Vertex AI Gemini Flash 2.0 Thinking model launched on 12/19/2024."""
129
154
 
130
155
  api_version = 'v1alpha'
@@ -132,61 +157,405 @@ class VertexAIGeminiFlash2_0ThinkingExp_20241219(VertexAI): # pylint: disable=i
132
157
  timeout = None
133
158
 
134
159
 
135
- class VertexAIGeminiFlash2_0Exp(VertexAI): # pylint: disable=invalid-name
160
+ class VertexAIGeminiFlash2_0Exp(VertexAIGemini): # pylint: disable=invalid-name
136
161
  """Vertex AI Gemini 2.0 Flash model."""
137
162
 
138
163
  model = 'gemini-2.0-flash-exp'
139
164
 
140
165
 
141
- class VertexAIGeminiExp_20241206(VertexAI): # pylint: disable=invalid-name
166
+ class VertexAIGeminiExp_20241206(VertexAIGemini): # pylint: disable=invalid-name
142
167
  """Vertex AI Gemini Experimental model launched on 12/06/2024."""
143
168
 
144
169
  model = 'gemini-exp-1206'
145
170
 
146
171
 
147
- class VertexAIGeminiExp_20241114(VertexAI): # pylint: disable=invalid-name
172
+ class VertexAIGeminiExp_20241114(VertexAIGemini): # pylint: disable=invalid-name
148
173
  """Vertex AI Gemini Experimental model launched on 11/14/2024."""
149
174
 
150
175
  model = 'gemini-exp-1114'
151
176
 
152
177
 
153
- class VertexAIGeminiPro1_5(VertexAI): # pylint: disable=invalid-name
178
+ class VertexAIGeminiPro1_5(VertexAIGemini): # pylint: disable=invalid-name
154
179
  """Vertex AI Gemini 1.5 Pro model."""
155
180
 
156
181
  model = 'gemini-1.5-pro-latest'
157
182
 
158
183
 
159
- class VertexAIGeminiPro1_5_002(VertexAI): # pylint: disable=invalid-name
184
+ class VertexAIGeminiPro1_5_002(VertexAIGemini): # pylint: disable=invalid-name
160
185
  """Vertex AI Gemini 1.5 Pro model."""
161
186
 
162
187
  model = 'gemini-1.5-pro-002'
163
188
 
164
189
 
165
- class VertexAIGeminiPro1_5_001(VertexAI): # pylint: disable=invalid-name
190
+ class VertexAIGeminiPro1_5_001(VertexAIGemini): # pylint: disable=invalid-name
166
191
  """Vertex AI Gemini 1.5 Pro model."""
167
192
 
168
193
  model = 'gemini-1.5-pro-001'
169
194
 
170
195
 
171
- class VertexAIGeminiFlash1_5(VertexAI): # pylint: disable=invalid-name
196
+ class VertexAIGeminiFlash1_5(VertexAIGemini): # pylint: disable=invalid-name
172
197
  """Vertex AI Gemini 1.5 Flash model."""
173
198
 
174
199
  model = 'gemini-1.5-flash'
175
200
 
176
201
 
177
- class VertexAIGeminiFlash1_5_002(VertexAI): # pylint: disable=invalid-name
202
+ class VertexAIGeminiFlash1_5_002(VertexAIGemini): # pylint: disable=invalid-name
178
203
  """Vertex AI Gemini 1.5 Flash model."""
179
204
 
180
205
  model = 'gemini-1.5-flash-002'
181
206
 
182
207
 
183
- class VertexAIGeminiFlash1_5_001(VertexAI): # pylint: disable=invalid-name
208
+ class VertexAIGeminiFlash1_5_001(VertexAIGemini): # pylint: disable=invalid-name
184
209
  """Vertex AI Gemini 1.5 Flash model."""
185
210
 
186
211
  model = 'gemini-1.5-flash-001'
187
212
 
188
213
 
189
- class VertexAIGeminiPro1(VertexAI): # pylint: disable=invalid-name
214
+ class VertexAIGeminiPro1(VertexAIGemini): # pylint: disable=invalid-name
190
215
  """Vertex AI Gemini 1.0 Pro model."""
191
216
 
192
217
  model = 'gemini-1.0-pro'
218
+
219
+
220
+ #
221
+ # Anthropic models on Vertex AI.
222
+ #
223
+
224
+
225
+ @pg.use_init_args(['model'])
226
+ @pg.members([('api_endpoint', pg.typing.Str().freeze(''))])
227
+ class VertexAIAnthropic(VertexAI, anthropic.Anthropic):
228
+ """Anthropic models on VertexAI."""
229
+
230
+ location: Annotated[
231
+ Literal['us-east5', 'europe-west1'],
232
+ 'GCP location with Anthropic models hosted.'
233
+ ] = 'us-east5'
234
+
235
+ api_version = 'vertex-2023-10-16'
236
+
237
+ @property
238
+ def headers(self):
239
+ return {
240
+ 'Content-Type': 'application/json; charset=utf-8',
241
+ }
242
+
243
+ @property
244
+ def api_endpoint(self) -> str:
245
+ return (
246
+ f'https://{self.location}-aiplatform.googleapis.com/v1/projects/'
247
+ f'{self._project}/locations/{self.location}/publishers/anthropic/'
248
+ f'models/{self.model}:streamRawPredict'
249
+ )
250
+
251
+ def request(
252
+ self,
253
+ prompt: lf.Message,
254
+ sampling_options: lf.LMSamplingOptions
255
+ ):
256
+ request = super().request(prompt, sampling_options)
257
+ request['anthropic_version'] = self.api_version
258
+ del request['model']
259
+ return request
260
+
261
+
262
+ # pylint: disable=invalid-name
263
+
264
+
265
+ class VertexAIClaude3_Opus_20240229(VertexAIAnthropic):
266
+ """Anthropic's Claude 3 Opus model on VertexAI."""
267
+ model = 'claude-3-opus@20240229'
268
+
269
+
270
+ class VertexAIClaude3_5_Sonnet_20241022(VertexAIAnthropic):
271
+ """Anthropic's Claude 3.5 Sonnet model on VertexAI."""
272
+ model = 'claude-3-5-sonnet-v2@20241022'
273
+
274
+
275
+ class VertexAIClaude3_5_Sonnet_20240620(VertexAIAnthropic):
276
+ """Anthropic's Claude 3.5 Sonnet model on VertexAI."""
277
+ model = 'claude-3-5-sonnet@20240620'
278
+
279
+
280
+ class VertexAIClaude3_5_Haiku_20241022(VertexAIAnthropic):
281
+ """Anthropic's Claude 3.5 Haiku model on VertexAI."""
282
+ model = 'claude-3-5-haiku@20241022'
283
+
284
+ # pylint: enable=invalid-name
285
+
286
+ #
287
+ # Llama models on Vertex AI.
288
+ # pylint: disable=line-too-long
289
+ # Pricing: https://cloud.google.com/vertex-ai/generative-ai/pricing?_gl=1*ukuk6u*_ga*MjEzMjc4NjM2My4xNzMzODg4OTg3*_ga_WH2QY8WWF5*MTczNzEzNDU1Mi4xMjQuMS4xNzM3MTM0NzczLjU5LjAuMA..#meta-models
290
+ # pylint: enable=line-too-long
291
+
292
+ LLAMA_MODELS = {
293
+ 'llama-3.2-90b-vision-instruct-maas': pg.Dict(
294
+ latest_update='2024-09-25',
295
+ in_service=True,
296
+ rpm=0,
297
+ tpm=0,
298
+ # Free during preview.
299
+ cost_per_1m_input_tokens=None,
300
+ cost_per_1m_output_tokens=None,
301
+ ),
302
+ 'llama-3.1-405b-instruct-maas': pg.Dict(
303
+ latest_update='2024-09-25',
304
+ in_service=True,
305
+ rpm=0,
306
+ tpm=0,
307
+ # GA.
308
+ cost_per_1m_input_tokens=5,
309
+ cost_per_1m_output_tokens=16,
310
+ ),
311
+ 'llama-3.1-70b-instruct-maas': pg.Dict(
312
+ latest_update='2024-09-25',
313
+ in_service=True,
314
+ rpm=0,
315
+ tpm=0,
316
+ # Free during preview.
317
+ cost_per_1m_input_tokens=None,
318
+ cost_per_1m_output_tokens=None,
319
+ ),
320
+ 'llama-3.1-8b-instruct-maas': pg.Dict(
321
+ latest_update='2024-09-25',
322
+ in_service=True,
323
+ rpm=0,
324
+ tpm=0,
325
+ # Free during preview.
326
+ cost_per_1m_input_tokens=None,
327
+ cost_per_1m_output_tokens=None,
328
+ )
329
+ }
330
+
331
+
332
+ @pg.use_init_args(['model'])
333
+ @pg.members([('api_endpoint', pg.typing.Str().freeze(''))])
334
+ class VertexAILlama(VertexAI, openai_compatible.OpenAICompatible):
335
+ """Llama models on VertexAI."""
336
+
337
+ model: pg.typing.Annotated[
338
+ pg.typing.Enum(pg.MISSING_VALUE, list(LLAMA_MODELS.keys())),
339
+ 'Llama model ID.',
340
+ ]
341
+
342
+ locations: Annotated[
343
+ Literal['us-central1'],
344
+ (
345
+ 'GCP locations with Llama models hosted. '
346
+ 'See https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/llama#regions-quotas'
347
+ )
348
+ ] = 'us-central1'
349
+
350
+ @property
351
+ def api_endpoint(self) -> str:
352
+ assert self._api_initialized
353
+ return (
354
+ f'https://{self._location}-aiplatform.googleapis.com/v1beta1/projects/'
355
+ f'{self._project}/locations/{self._location}/endpoints/'
356
+ f'openapi/chat/completions'
357
+ )
358
+
359
+ def request(
360
+ self,
361
+ prompt: lf.Message,
362
+ sampling_options: lf.LMSamplingOptions
363
+ ):
364
+ request = super().request(prompt, sampling_options)
365
+ request['model'] = f'meta/{self.model}'
366
+ return request
367
+
368
+ @property
369
+ def max_concurrency(self) -> int:
370
+ rpm = LLAMA_MODELS[self.model].get('rpm', 0)
371
+ tpm = LLAMA_MODELS[self.model].get('tpm', 0)
372
+ return self.rate_to_max_concurrency(
373
+ requests_per_min=rpm, tokens_per_min=tpm
374
+ )
375
+
376
+ def estimate_cost(
377
+ self,
378
+ num_input_tokens: int,
379
+ num_output_tokens: int
380
+ ) -> float | None:
381
+ """Estimate the cost based on usage."""
382
+ cost_per_1m_input_tokens = LLAMA_MODELS[self.model].get(
383
+ 'cost_per_1m_input_tokens', None
384
+ )
385
+ cost_per_1m_output_tokens = LLAMA_MODELS[self.model].get(
386
+ 'cost_per_1m_output_tokens', None
387
+ )
388
+ if cost_per_1m_output_tokens is None or cost_per_1m_input_tokens is None:
389
+ return None
390
+ return (
391
+ cost_per_1m_input_tokens * num_input_tokens
392
+ + cost_per_1m_output_tokens * num_output_tokens
393
+ ) / 1000_000
394
+
395
+
396
+ # pylint: disable=invalid-name
397
+ class VertexAILlama3_2_90B(VertexAILlama):
398
+ """Llama 3.2 90B vision instruct model on VertexAI."""
399
+
400
+ model = 'llama-3.2-90b-vision-instruct-maas'
401
+
402
+
403
+ class VertexAILlama3_1_405B(VertexAILlama):
404
+ """Llama 3.1 405B instruct model on VertexAI."""
405
+
406
+ model = 'llama-3.1-405b-instruct-maas'
407
+
408
+
409
+ class VertexAILlama3_1_70B(VertexAILlama):
410
+ """Llama 3.1 70B instruct model on VertexAI."""
411
+
412
+ model = 'llama-3.1-70b-instruct-maas'
413
+
414
+
415
+ class VertexAILlama3_1_8B(VertexAILlama):
416
+ """Llama 3.1 8B instruct model on VertexAI."""
417
+
418
+ model = 'llama-3.1-8b-instruct-maas'
419
+ # pylint: enable=invalid-name
420
+
421
+ #
422
+ # Mistral models on Vertex AI.
423
+ # pylint: disable=line-too-long
424
+ # Pricing: https://cloud.google.com/vertex-ai/generative-ai/pricing?_gl=1*ukuk6u*_ga*MjEzMjc4NjM2My4xNzMzODg4OTg3*_ga_WH2QY8WWF5*MTczNzEzNDU1Mi4xMjQuMS4xNzM3MTM0NzczLjU5LjAuMA..#mistral-models
425
+ # pylint: enable=line-too-long
426
+
427
+
428
+ MISTRAL_MODELS = {
429
+ 'mistral-large-2411': pg.Dict(
430
+ latest_update='2024-11-21',
431
+ in_service=True,
432
+ rpm=0,
433
+ tpm=0,
434
+ # GA.
435
+ cost_per_1m_input_tokens=2,
436
+ cost_per_1m_output_tokens=6,
437
+ ),
438
+ 'mistral-large@2407': pg.Dict(
439
+ latest_update='2024-07-24',
440
+ in_service=True,
441
+ rpm=0,
442
+ tpm=0,
443
+ # GA.
444
+ cost_per_1m_input_tokens=2,
445
+ cost_per_1m_output_tokens=6,
446
+ ),
447
+ 'mistral-nemo@2407': pg.Dict(
448
+ latest_update='2024-07-24',
449
+ in_service=True,
450
+ rpm=0,
451
+ tpm=0,
452
+ # GA.
453
+ cost_per_1m_input_tokens=0.15,
454
+ cost_per_1m_output_tokens=0.15,
455
+ ),
456
+ 'codestral-2501': pg.Dict(
457
+ latest_update='2025-01-13',
458
+ in_service=True,
459
+ rpm=0,
460
+ tpm=0,
461
+ # GA.
462
+ cost_per_1m_input_tokens=0.3,
463
+ cost_per_1m_output_tokens=0.9,
464
+ ),
465
+ 'codestral@2405': pg.Dict(
466
+ latest_update='2024-05-29',
467
+ in_service=True,
468
+ rpm=0,
469
+ tpm=0,
470
+ # GA.
471
+ cost_per_1m_input_tokens=0.2,
472
+ cost_per_1m_output_tokens=0.6,
473
+ ),
474
+ }
475
+
476
+
477
+ @pg.use_init_args(['model'])
478
+ @pg.members([('api_endpoint', pg.typing.Str().freeze(''))])
479
+ class VertexAIMistral(VertexAI, openai_compatible.OpenAICompatible):
480
+ """Mistral AI models on VertexAI."""
481
+
482
+ model: pg.typing.Annotated[
483
+ pg.typing.Enum(pg.MISSING_VALUE, list(MISTRAL_MODELS.keys())),
484
+ 'Mistral model ID.',
485
+ ]
486
+
487
+ locations: Annotated[
488
+ Literal['us-central1', 'europe-west4'],
489
+ (
490
+ 'GCP locations with Mistral models hosted. '
491
+ 'See https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/mistral#regions-quotas'
492
+ )
493
+ ] = 'us-central1'
494
+
495
+ @property
496
+ def api_endpoint(self) -> str:
497
+ assert self._api_initialized
498
+ return (
499
+ f'https://{self._location}-aiplatform.googleapis.com/v1/projects/'
500
+ f'{self._project}/locations/{self._location}/publishers/mistralai/'
501
+ f'models/{self.model}:rawPredict'
502
+ )
503
+
504
+ @property
505
+ def max_concurrency(self) -> int:
506
+ rpm = MISTRAL_MODELS[self.model].get('rpm', 0)
507
+ tpm = MISTRAL_MODELS[self.model].get('tpm', 0)
508
+ return self.rate_to_max_concurrency(
509
+ requests_per_min=rpm, tokens_per_min=tpm
510
+ )
511
+
512
+ def estimate_cost(
513
+ self,
514
+ num_input_tokens: int,
515
+ num_output_tokens: int
516
+ ) -> float | None:
517
+ """Estimate the cost based on usage."""
518
+ cost_per_1m_input_tokens = MISTRAL_MODELS[self.model].get(
519
+ 'cost_per_1m_input_tokens', None
520
+ )
521
+ cost_per_1m_output_tokens = MISTRAL_MODELS[self.model].get(
522
+ 'cost_per_1m_output_tokens', None
523
+ )
524
+ if cost_per_1m_output_tokens is None or cost_per_1m_input_tokens is None:
525
+ return None
526
+ return (
527
+ cost_per_1m_input_tokens * num_input_tokens
528
+ + cost_per_1m_output_tokens * num_output_tokens
529
+ ) / 1000_000
530
+
531
+
532
+ # pylint: disable=invalid-name
533
+ class VertexAIMistralLarge_20241121(VertexAIMistral):
534
+ """Mistral Large model on VertexAI released on 2024/11/21."""
535
+
536
+ model = 'mistral-large-2411'
537
+
538
+
539
+ class VertexAIMistralLarge_20240724(VertexAIMistral):
540
+ """Mistral Large model on VertexAI released on 2024/07/24."""
541
+
542
+ model = 'mistral-large@2407'
543
+
544
+
545
+ class VertexAIMistralNemo_20240724(VertexAIMistral):
546
+ """Mistral Nemo model on VertexAI released on 2024/07/24."""
547
+
548
+ model = 'mistral-nemo@2407'
549
+
550
+
551
+ class VertexAICodestral_20250113(VertexAIMistral):
552
+ """Codestral model on VertexAI released on 2025/01/13."""
553
+
554
+ model = 'codestral-2501'
555
+
556
+
557
+ class VertexAICodestral_20240529(VertexAIMistral):
558
+ """Codestral model on VertexAI released on 2024/05/29."""
559
+
560
+ model = 'codestral@2405'
561
+ # pylint: enable=invalid-name
@@ -17,6 +17,8 @@ import os
17
17
  import unittest
18
18
  from unittest import mock
19
19
 
20
+ from google.auth import exceptions
21
+ import langfun.core as lf
20
22
  from langfun.core.llms import vertexai
21
23
 
22
24
 
@@ -48,5 +50,55 @@ class VertexAITest(unittest.TestCase):
48
50
  del os.environ['VERTEXAI_LOCATION']
49
51
 
50
52
 
53
+ class VertexAIAnthropicTest(unittest.TestCase):
54
+ """Tests for VertexAI Anthropic models."""
55
+
56
+ def test_basics(self):
57
+ with self.assertRaisesRegex(ValueError, 'Please specify `project`'):
58
+ lm = vertexai.VertexAIClaude3_5_Sonnet_20241022()
59
+ lm('hi')
60
+
61
+ model = vertexai.VertexAIClaude3_5_Sonnet_20241022(project='langfun')
62
+
63
+ # NOTE(daiyip): For OSS users, default credentials are not available unless
64
+ # users have already set up their GCP project. Therefore we ignore the
65
+ # exception here.
66
+ try:
67
+ model._initialize()
68
+ except exceptions.DefaultCredentialsError:
69
+ pass
70
+
71
+ self.assertEqual(
72
+ model.api_endpoint,
73
+ (
74
+ 'https://us-east5-aiplatform.googleapis.com/v1/projects/'
75
+ 'langfun/locations/us-east5/publishers/anthropic/'
76
+ 'models/claude-3-5-sonnet-v2@20241022:streamRawPredict'
77
+ )
78
+ )
79
+ self.assertEqual(
80
+ model.headers,
81
+ {
82
+ 'Content-Type': 'application/json; charset=utf-8',
83
+ },
84
+ )
85
+ request = model.request(
86
+ lf.UserMessage('hi'), lf.LMSamplingOptions(temperature=0.0),
87
+ )
88
+ self.assertEqual(
89
+ request,
90
+ {
91
+ 'anthropic_version': 'vertex-2023-10-16',
92
+ 'max_tokens': 8192,
93
+ 'messages': [
94
+ {'content': [{'text': 'hi', 'type': 'text'}], 'role': 'user'}
95
+ ],
96
+ 'stream': False,
97
+ 'temperature': 0.0,
98
+ 'top_k': 40,
99
+ },
100
+ )
101
+
102
+
51
103
  if __name__ == '__main__':
52
104
  unittest.main()