langfun 0.1.2.dev202501080804__py3-none-any.whl → 0.1.2.dev202501240804__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56)
  1. langfun/core/__init__.py +1 -6
  2. langfun/core/coding/python/__init__.py +5 -11
  3. langfun/core/coding/python/correction.py +4 -7
  4. langfun/core/coding/python/correction_test.py +2 -3
  5. langfun/core/coding/python/execution.py +22 -211
  6. langfun/core/coding/python/execution_test.py +11 -90
  7. langfun/core/coding/python/generation.py +3 -2
  8. langfun/core/coding/python/generation_test.py +2 -2
  9. langfun/core/coding/python/parsing.py +108 -194
  10. langfun/core/coding/python/parsing_test.py +2 -105
  11. langfun/core/component.py +11 -273
  12. langfun/core/component_test.py +2 -29
  13. langfun/core/concurrent.py +187 -82
  14. langfun/core/concurrent_test.py +28 -19
  15. langfun/core/console.py +7 -3
  16. langfun/core/eval/base.py +2 -3
  17. langfun/core/eval/v2/evaluation.py +3 -1
  18. langfun/core/eval/v2/reporting.py +8 -4
  19. langfun/core/language_model.py +84 -8
  20. langfun/core/language_model_test.py +84 -29
  21. langfun/core/llms/__init__.py +46 -11
  22. langfun/core/llms/anthropic.py +1 -123
  23. langfun/core/llms/anthropic_test.py +0 -48
  24. langfun/core/llms/deepseek.py +117 -0
  25. langfun/core/llms/deepseek_test.py +61 -0
  26. langfun/core/llms/gemini.py +1 -1
  27. langfun/core/llms/groq.py +12 -99
  28. langfun/core/llms/groq_test.py +31 -137
  29. langfun/core/llms/llama_cpp.py +17 -54
  30. langfun/core/llms/llama_cpp_test.py +2 -34
  31. langfun/core/llms/openai.py +9 -147
  32. langfun/core/llms/openai_compatible.py +179 -0
  33. langfun/core/llms/openai_compatible_test.py +495 -0
  34. langfun/core/llms/openai_test.py +13 -423
  35. langfun/core/llms/rest_test.py +1 -1
  36. langfun/core/llms/vertexai.py +387 -18
  37. langfun/core/llms/vertexai_test.py +52 -0
  38. langfun/core/message_test.py +3 -3
  39. langfun/core/modalities/mime.py +8 -0
  40. langfun/core/modalities/mime_test.py +19 -4
  41. langfun/core/modality_test.py +0 -1
  42. langfun/core/structured/mapping.py +13 -13
  43. langfun/core/structured/mapping_test.py +2 -2
  44. langfun/core/structured/schema.py +16 -8
  45. langfun/core/structured/schema_generation.py +1 -1
  46. {langfun-0.1.2.dev202501080804.dist-info → langfun-0.1.2.dev202501240804.dist-info}/METADATA +13 -2
  47. {langfun-0.1.2.dev202501080804.dist-info → langfun-0.1.2.dev202501240804.dist-info}/RECORD +50 -52
  48. {langfun-0.1.2.dev202501080804.dist-info → langfun-0.1.2.dev202501240804.dist-info}/WHEEL +1 -1
  49. langfun/core/coding/python/errors.py +0 -108
  50. langfun/core/coding/python/errors_test.py +0 -99
  51. langfun/core/coding/python/permissions.py +0 -90
  52. langfun/core/coding/python/permissions_test.py +0 -86
  53. langfun/core/text_formatting.py +0 -168
  54. langfun/core/text_formatting_test.py +0 -65
  55. {langfun-0.1.2.dev202501080804.dist-info → langfun-0.1.2.dev202501240804.dist-info}/LICENSE +0 -0
  56. {langfun-0.1.2.dev202501080804.dist-info → langfun-0.1.2.dev202501240804.dist-info}/top_level.txt +0 -0
langfun/core/language_model.py

@@ -81,6 +81,61 @@ class LMSample(pg.Object):
   ] = None
 
 
+class RetryStats(pg.Object):
+  """Retry stats, which is aggregated across multiple retry entries."""
+
+  num_occurences: Annotated[
+      int,
+      'Total number of retry attempts on LLM (excluding the first attempt).',
+  ] = 0
+  total_wait_interval: Annotated[
+      float, 'Total wait interval in seconds due to retry.'
+  ] = 0
+  total_call_interval: Annotated[
+      float, 'Total LLM call interval in seconds.'
+  ] = 0
+  errors: Annotated[
+      dict[str, int],
+      'A Counter of error types encountered during the retry attempts.',
+  ] = {}
+
+  @classmethod
+  def from_retry_entries(
+      cls, retry_entries: Sequence[concurrent.RetryEntry]
+  ) -> 'RetryStats':
+    """Creates a RetryStats from a sequence of RetryEntry."""
+    if not retry_entries:
+      return RetryStats()
+    errors = {}
+    for retry in retry_entries:
+      if retry.error is not None:
+        errors[retry.error.__class__.__name__] = (
+            errors.get(retry.error.__class__.__name__, 0) + 1
+        )
+    return RetryStats(
+        num_occurences=len(retry_entries) - 1,
+        total_wait_interval=sum(e.wait_interval for e in retry_entries),
+        total_call_interval=sum(e.call_interval for e in retry_entries),
+        errors=errors,
+    )
+
+  def __add__(self, other: 'RetryStats') -> 'RetryStats':
+    errors = self.errors.copy()
+    for error, count in other.errors.items():
+      errors[error] = errors.get(error, 0) + count
+    return RetryStats(
+        num_occurences=self.num_occurences + other.num_occurences,
+        total_wait_interval=self.total_wait_interval
+        + other.total_wait_interval,
+        total_call_interval=self.total_call_interval
+        + other.total_call_interval,
+        errors=errors,
+    )
+
+  def __radd__(self, other: 'RetryStats') -> 'RetryStats':
+    return self + other
+
+
 class LMSamplingUsage(pg.Object):
   """Usage information per completion."""
 
@@ -93,8 +148,9 @@ class LMSamplingUsage(pg.Object):
       (
           'Estimated cost in US dollars. If None, cost estimating is not '
           'suppported on the model being queried.'
-      )
+      ),
   ] = None
+  retry_stats: RetryStats = RetryStats()
 
   def __bool__(self) -> bool:
     return self.num_requests > 0
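RetryStats is additive: __add__ sums the counters and merges the error histograms key-wise, which is what lets the new retry_stats field on LMSamplingUsage ride along through the usage-aggregation arithmetic in the hunks below. A minimal sketch of the semantics, using only names defined in this diff (assumes the langfun version from this wheel is installed):

    from langfun.core import language_model as lm_lib

    a = lm_lib.RetryStats(
        num_occurences=1, total_wait_interval=2.0, errors={'ValueError': 1}
    )
    b = lm_lib.RetryStats(
        num_occurences=2,
        total_wait_interval=3.0,
        errors={'ValueError': 1, 'TimeoutError': 1},
    )
    merged = a + b  # Counters are summed; error histograms are merged.
    assert merged.num_occurences == 3
    assert merged.total_wait_interval == 5.0
    assert merged.errors == {'ValueError': 2, 'TimeoutError': 1}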
@@ -136,6 +192,7 @@ class LMSamplingUsage(pg.Object):
         total_tokens=self.total_tokens + other.total_tokens,
         num_requests=self.num_requests + other.num_requests,
         estimated_cost=estimated_cost,
+        retry_stats=self.retry_stats + other.retry_stats,
     )
 
   def __radd__(self, other: Optional['LMSamplingUsage']) -> 'LMSamplingUsage':
@@ -511,7 +568,18 @@ class LanguageModel(component.Component):
             total_tokens=usage.total_tokens // n,
             estimated_cost=(
                 usage.estimated_cost / n if usage.estimated_cost else None
-            )
+            ),
+            retry_stats=RetryStats(
+                num_occurences=usage.retry_stats.num_occurences // n,
+                total_wait_interval=usage.retry_stats.total_wait_interval
+                / n,
+                total_call_interval=usage.retry_stats.total_call_interval
+                / n,
+                errors={
+                    error: count // n
+                    for error, count in usage.retry_stats.errors.items()
+                },
+            ),
         )
 
     # Track usage.
@@ -584,16 +652,16 @@ class LanguageModel(component.Component):
 
   def _parallel_execute_with_currency_control(
       self,
-      action: Callable[..., Any],
+      action: Callable[..., LMSamplingResult],
       inputs: Sequence[Any],
       retry_on_errors: Union[
          None,
          Union[Type[BaseException], Tuple[Type[BaseException], str]],
          Sequence[Union[Type[BaseException], Tuple[Type[BaseException], str]]],
      ] = RetryableLMError,
-  ) -> Any:
+  ) -> list[Any]:
     """Helper method for subclasses for implementing _sample."""
-    return concurrent.concurrent_execute(
+    executed_jobs = concurrent.concurrent_execute(
         action,
         inputs,
         executor=self.resource_id if self.max_concurrency else None,
@@ -603,7 +671,15 @@
         retry_interval=self.retry_interval,
         exponential_backoff=self.exponential_backoff,
         max_retry_interval=self.max_retry_interval,
+        return_jobs=True,
     )
+    for job in executed_jobs:
+      if isinstance(job.result, LMSamplingResult):
+        job.result.usage.rebind(
+            retry_stats=RetryStats.from_retry_entries(job.retry_entries),
+            skip_notification=True,
+        )
+    return [job.result for job in executed_jobs]
 
   def __call__(
       self, prompt: message_lib.Message, *, cache_seed: int = 0, **kwargs
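With return_jobs=True, the helper now receives the Job objects back from concurrent.concurrent_execute and rebinds each result's usage.retry_stats from that job's retry entries, so retry accounting happens in one place instead of in every subclass. A sketch of how a subclass's _sample is expected to use the helper, modeled on the MockModel change in language_model_test.py further down (_call_once and _sample_single_prompt are hypothetical names, not part of the API):

    def _sample(self, prompts):
      def _call_once(prompt):
        # Returns one LMSamplingResult for `prompt`; raising a retryable
        # error here triggers the helper's retry machinery.
        return self._sample_single_prompt(prompt)  # hypothetical helper

      # One LMSamplingResult per prompt, each with usage.retry_stats
      # populated from the retry entries of its own job.
      return self._parallel_execute_with_currency_control(_call_once, prompts)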
@@ -653,7 +729,7 @@
     """Outputs debugging information about the model."""
     title_suffix = ''
     if usage.total_tokens != 0:
-      title_suffix = console.colored(
+      title_suffix = pg.colored(
           f' (total {usage.total_tokens} tokens)', 'red'
       )
 
@@ -672,7 +748,7 @@ class LanguageModel(component.Component):
     """Outputs debugging information about the prompt."""
     title_suffix = ''
     if usage.prompt_tokens != 0:
-      title_suffix = console.colored(f' ({usage.prompt_tokens} tokens)', 'red')
+      title_suffix = pg.colored(f' ({usage.prompt_tokens} tokens)', 'red')
 
     console.write(
         # We use metadata 'formatted_text' for scenarios where the prompt text
@@ -703,7 +779,7 @@
     if usage.completion_tokens != 0:
       title_suffix += f'{usage.completion_tokens} tokens '
     title_suffix += f'in {elapse:.2f} seconds)'
-    title_suffix = console.colored(title_suffix, 'red')
+    title_suffix = pg.colored(title_suffix, 'red')
 
     console.write(
         str(response) + '\n',
langfun/core/language_model_test.py

@@ -35,34 +35,34 @@ class MockModel(lm_lib.LanguageModel):
   ) -> list[lm_lib.LMSamplingResult]:
     context = pg.Dict(attempt=0)
 
-    def fake_sample(prompts):
+    def fake_sample(prompt):
       if context.attempt >= self.failures_before_attempt:
-        return [
-            lm_lib.LMSamplingResult(
-                [
-                    lm_lib.LMSample(  # pylint: disable=g-complex-comprehension
-                        response=prompt.text * self.sampling_options.top_k,
-                        score=self.sampling_options.temperature or -1.0,
-                    )
-                ],
-                usage=lm_lib.LMSamplingUsage(
-                    prompt_tokens=100,
-                    completion_tokens=100,
-                    total_tokens=200,
-                    estimated_cost=1.0,
-                ),
-            )
-            for prompt in prompts
-        ]
-      context.attempt += 1
+        return lm_lib.LMSamplingResult(
+            [
+                lm_lib.LMSample(  # pylint: disable=g-complex-comprehension
+                    response=prompt.text * self.sampling_options.top_k,
+                    score=self.sampling_options.temperature or -1.0,
+                )
+            ],
+            usage=lm_lib.LMSamplingUsage(
+                prompt_tokens=100,
+                completion_tokens=100,
+                total_tokens=200,
+                estimated_cost=1.0,
+            ),
+        )
+      else:
+        context.attempt += 1
       raise ValueError('Failed to sample prompts.')
 
-    return concurrent.with_retry(
-        fake_sample,
-        retry_on_errors=ValueError,
-        max_attempts=self.max_attempts,
-        retry_interval=1,
-    )(prompts)
+    results = self._parallel_execute_with_currency_control(
+        fake_sample, prompts, retry_on_errors=ValueError
+    )
+    for result in results:
+      result.usage.retry_stats.rebind(
+          total_call_interval=0, skip_notification=True
+      )
+    return results
 
   @property
   def model_id(self) -> str:
@@ -448,13 +448,50 @@ class LanguageModelTest(unittest.TestCase):
 
   def test_retry(self):
     lm = MockModel(
-        failures_before_attempt=1, top_k=1,
+        failures_before_attempt=1, top_k=1, max_attempts=2, retry_interval=1
     )
     with self.assertRaisesRegex(
         concurrent.RetryError, 'Calling .* failed after 1 attempts'
     ):
       lm('foo', max_attempts=1)
-    self.assertEqual(lm('foo', max_attempts=2), 'foo')
+
+    usage = lm_lib.LMSamplingUsage(
+        prompt_tokens=100,
+        completion_tokens=100,
+        total_tokens=200,
+        num_requests=1,
+        estimated_cost=1.0,
+        retry_stats=lm_lib.RetryStats(
+            num_occurences=1,
+            total_wait_interval=1,
+            errors={'ValueError': 1},
+        ),
+    )
+    out = lm.sample(['foo'])
+    self.assertEqual(
+        # lm.sample(['foo'], max_attempts=2),
+        out,
+        [
+            lm_lib.LMSamplingResult(
+                [
+                    lm_lib.LMSample(
+                        message_lib.AIMessage(
+                            'foo',
+                            score=-1.0,
+                            logprobs=None,
+                            is_cached=False,
+                            usage=usage,
+                            tags=['lm-response'],
+                        ),
+                        score=-1.0,
+                        logprobs=None,
+                    )
+                ],
+                usage=usage,
+                is_cached=False,
+            )
+        ],
+    )
 
   def test_debug(self):
     class Image(modality.Modality):
@@ -755,16 +792,34 @@ class LMSamplingUsageTest(unittest.TestCase):
 
   def test_add(self):
     usage1 = lm_lib.LMSamplingUsage(100, 200, 300, 4, 5.0)
+    usage1.rebind(retry_stats=lm_lib.RetryStats(1, 3, 4, {'e1': 1}))
     usage2 = lm_lib.LMSamplingUsage(100, 200, 300, 4, 5.0)
     self.assertEqual(usage1 + usage2, usage1 + usage2)
     self.assertIs(usage1 + None, usage1)
     self.assertIs(None + usage1, usage1)
     usage3 = lm_lib.LMSamplingUsage(100, 200, 300, 4, None)
+    usage3.rebind(retry_stats=lm_lib.RetryStats(2, 4, 5, {'e1': 2, 'e2': 3}))
     self.assertEqual(
-        usage1 + usage3, lm_lib.LMSamplingUsage(200, 400, 600, 8, 5.0)
+        usage1 + usage3,
+        lm_lib.LMSamplingUsage(
+            200,
+            400,
+            600,
+            8,
+            5.0,
+            retry_stats=lm_lib.RetryStats(3, 7, 9, {'e1': 3, 'e2': 3}),
+        ),
     )
     self.assertEqual(
-        usage3 + usage1, lm_lib.LMSamplingUsage(200, 400, 600, 8, 5.0)
+        usage3 + usage1,
+        lm_lib.LMSamplingUsage(
+            200,
+            400,
+            600,
+            8,
+            5.0,
+            retry_stats=lm_lib.RetryStats(3, 7, 9, {'e1': 3, 'e2': 3}),
+        ),
     )
 
   def test_usage_not_available(self):
langfun/core/llms/__init__.py

@@ -27,8 +27,15 @@ from langfun.core.llms.fake import StaticSequence
 # Compositional models.
 from langfun.core.llms.compositional import RandomChoice
 
-# REST-based models.
+# Base models by request/response protocol.
 from langfun.core.llms.rest import REST
+from langfun.core.llms.openai_compatible import OpenAICompatible
+from langfun.core.llms.gemini import Gemini
+from langfun.core.llms.anthropic import Anthropic
+
+# Base models by serving platforms.
+from langfun.core.llms.vertexai import VertexAI
+from langfun.core.llms.groq import Groq
 
 # Gemini models.
 from langfun.core.llms.google_genai import GenAI
@@ -44,7 +51,7 @@ from langfun.core.llms.google_genai import GeminiFlash1_5_002
 from langfun.core.llms.google_genai import GeminiFlash1_5_001
 from langfun.core.llms.google_genai import GeminiPro1
 
-from langfun.core.llms.vertexai import VertexAI
+from langfun.core.llms.vertexai import VertexAIGemini
 from langfun.core.llms.vertexai import VertexAIGeminiFlash2_0ThinkingExp_20241219
 from langfun.core.llms.vertexai import VertexAIGeminiFlash2_0Exp
 from langfun.core.llms.vertexai import VertexAIGeminiExp_20241206
@@ -111,20 +118,34 @@ from langfun.core.llms.openai import Gpt3Curie
 from langfun.core.llms.openai import Gpt3Babbage
 from langfun.core.llms.openai import Gpt3Ada
 
-from langfun.core.llms.anthropic import Anthropic
+# Anthropic models.
+
 from langfun.core.llms.anthropic import Claude35Sonnet
 from langfun.core.llms.anthropic import Claude35Sonnet20241022
 from langfun.core.llms.anthropic import Claude35Sonnet20240620
 from langfun.core.llms.anthropic import Claude3Opus
 from langfun.core.llms.anthropic import Claude3Sonnet
 from langfun.core.llms.anthropic import Claude3Haiku
-from langfun.core.llms.anthropic import VertexAIAnthropic
-from langfun.core.llms.anthropic import VertexAIClaude3_5_Sonnet_20241022
-from langfun.core.llms.anthropic import VertexAIClaude3_5_Sonnet_20240620
-from langfun.core.llms.anthropic import VertexAIClaude3_5_Haiku_20241022
-from langfun.core.llms.anthropic import VertexAIClaude3_Opus_20240229
 
-from langfun.core.llms.groq import Groq
+from langfun.core.llms.vertexai import VertexAIAnthropic
+from langfun.core.llms.vertexai import VertexAIClaude3_5_Sonnet_20241022
+from langfun.core.llms.vertexai import VertexAIClaude3_5_Sonnet_20240620
+from langfun.core.llms.vertexai import VertexAIClaude3_5_Haiku_20241022
+from langfun.core.llms.vertexai import VertexAIClaude3_Opus_20240229
+
+# Misc open source models.
+
+# Gemma models.
+from langfun.core.llms.groq import GroqGemma2_9B_IT
+from langfun.core.llms.groq import GroqGemma_7B_IT
+
+# Llama models.
+from langfun.core.llms.vertexai import VertexAILlama
+from langfun.core.llms.vertexai import VertexAILlama3_2_90B
+from langfun.core.llms.vertexai import VertexAILlama3_1_405B
+from langfun.core.llms.vertexai import VertexAILlama3_1_70B
+from langfun.core.llms.vertexai import VertexAILlama3_1_8B
+
 from langfun.core.llms.groq import GroqLlama3_2_3B
 from langfun.core.llms.groq import GroqLlama3_2_1B
 from langfun.core.llms.groq import GroqLlama3_1_70B
@@ -132,15 +153,29 @@ from langfun.core.llms.groq import GroqLlama3_1_8B
 from langfun.core.llms.groq import GroqLlama3_70B
 from langfun.core.llms.groq import GroqLlama3_8B
 from langfun.core.llms.groq import GroqLlama2_70B
+
+# Mistral models.
+from langfun.core.llms.vertexai import VertexAIMistral
+from langfun.core.llms.vertexai import VertexAIMistralLarge_20241121
+from langfun.core.llms.vertexai import VertexAIMistralLarge_20240724
+from langfun.core.llms.vertexai import VertexAIMistralNemo_20240724
+from langfun.core.llms.vertexai import VertexAICodestral_20250113
+from langfun.core.llms.vertexai import VertexAICodestral_20240529
+
 from langfun.core.llms.groq import GroqMistral_8x7B
-from langfun.core.llms.groq import GroqGemma2_9B_IT
-from langfun.core.llms.groq import GroqGemma_7B_IT
+
+# DeepSeek models.
+from langfun.core.llms.deepseek import DeepSeek
+from langfun.core.llms.deepseek import DeepSeekChat
+
+# Whisper models.
 from langfun.core.llms.groq import GroqWhisper_Large_v3
 from langfun.core.llms.groq import GroqWhisper_Large_v3Turbo
 
 # LLaMA C++ models.
 from langfun.core.llms.llama_cpp import LlamaCppRemote
 
+
 # Placeholder for Google-internal imports.
 
 # Include cache as sub-module.
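The reorganized surface groups base classes by request/response protocol (REST, OpenAICompatible, Gemini, Anthropic) and by serving platform (VertexAI, Groq), and introduces the DeepSeek entry points backed by the new openai_compatible.py. A usage sketch; the api_key keyword is an assumption mirroring the other REST-based models here, so check deepseek.py in this release for the exact signature:

    import langfun as lf

    # New in this release: DeepSeek models, served over an
    # OpenAI-compatible API (see openai_compatible.py above).
    lm = lf.llms.DeepSeekChat(api_key='...')  # api_key kwarg is assumed

    # Base classes are importable for wiring up custom endpoints.
    from langfun.core.llms import OpenAICompatible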
langfun/core/llms/anthropic.py

@@ -14,9 +14,8 @@
 """Language models from Anthropic."""
 
 import base64
-import functools
 import os
-from typing import Annotated, Any, Literal
+from typing import Annotated, Any
 
 import langfun.core as lf
 from langfun.core import modalities as lf_modalities
@@ -24,20 +23,6 @@ from langfun.core.llms import rest
 import pyglove as pg
 
 
-try:
-  # pylint: disable=g-import-not-at-top
-  from google import auth as google_auth
-  from google.auth import credentials as credentials_lib
-  from google.auth.transport import requests as auth_requests
-  Credentials = credentials_lib.Credentials
-  # pylint: enable=g-import-not-at-top
-except ImportError:
-  google_auth = None
-  auth_requests = None
-  credentials_lib = None
-  Credentials = Any  # pylint: disable=invalid-name
-
-
 SUPPORTED_MODELS_AND_SETTINGS = {
     # See https://docs.anthropic.com/claude/docs/models-overview
     # Rate limits from https://docs.anthropic.com/claude/reference/rate-limits
@@ -379,110 +364,3 @@ class Claude21(Anthropic):
 class ClaudeInstant(Anthropic):
   """Cheapest small and fast model, 100K context window."""
   model = 'claude-instant-1.2'
-
-
-#
-# Authropic models on VertexAI.
-#
-
-
-class VertexAIAnthropic(Anthropic):
-  """Anthropic models on VertexAI."""
-
-  project: Annotated[
-      str | None,
-      'Google Cloud project ID.',
-  ] = None
-
-  location: Annotated[
-      Literal['us-east5', 'europe-west1'],
-      'GCP location with Anthropic models hosted.'
-  ] = 'us-east5'
-
-  credentials: Annotated[
-      Credentials | None,  # pytype: disable=invalid-annotation
-      (
-          'Credentials to use. If None, the default credentials '
-          'to the environment will be used.'
-      ),
-  ] = None
-
-  api_version = 'vertex-2023-10-16'
-
-  def _on_bound(self):
-    super()._on_bound()
-    if google_auth is None:
-      raise ValueError(
-          'Please install "langfun[llm-google-vertex]" to use Vertex AI models.'
-      )
-    self._project = None
-    self._credentials = None
-
-  def _initialize(self):
-    project = self.project or os.environ.get('VERTEXAI_PROJECT', None)
-    if not project:
-      raise ValueError(
-          'Please specify `project` during `__init__` or set environment '
-          'variable `VERTEXAI_PROJECT` with your Vertex AI project ID.'
-      )
-    self._project = project
-    credentials = self.credentials
-    if credentials is None:
-      # Use default credentials.
-      credentials = google_auth.default(
-          scopes=['https://www.googleapis.com/auth/cloud-platform']
-      )
-    self._credentials = credentials
-
-  @functools.cached_property
-  def _session(self):
-    assert self._api_initialized
-    assert self._credentials is not None
-    assert auth_requests is not None
-    s = auth_requests.AuthorizedSession(self._credentials)
-    s.headers.update(self.headers or {})
-    return s
-
-  @property
-  def headers(self):
-    return {
-        'Content-Type': 'application/json; charset=utf-8',
-    }
-
-  @property
-  def api_endpoint(self) -> str:
-    return (
-        f'https://{self.location}-aiplatform.googleapis.com/v1/projects/'
-        f'{self._project}/locations/{self.location}/publishers/anthropic/'
-        f'models/{self.model}:streamRawPredict'
-    )
-
-  def request(
-      self,
-      prompt: lf.Message,
-      sampling_options: lf.LMSamplingOptions
-  ):
-    request = super().request(prompt, sampling_options)
-    request['anthropic_version'] = self.api_version
-    del request['model']
-    return request
-
-
-class VertexAIClaude3_Opus_20240229(VertexAIAnthropic):  # pylint: disable=invalid-name
-  """Anthropic's Claude 3 Opus model on VertexAI."""
-  model = 'claude-3-opus@20240229'
-
-
-class VertexAIClaude3_5_Sonnet_20241022(VertexAIAnthropic):  # pylint: disable=invalid-name
-  """Anthropic's Claude 3.5 Sonnet model on VertexAI."""
-  model = 'claude-3-5-sonnet-v2@20241022'
-
-
-class VertexAIClaude3_5_Sonnet_20240620(VertexAIAnthropic):  # pylint: disable=invalid-name
-  """Anthropic's Claude 3.5 Sonnet model on VertexAI."""
-  model = 'claude-3-5-sonnet@20240620'
-
-
-class VertexAIClaude3_5_Haiku_20241022(VertexAIAnthropic):  # pylint: disable=invalid-name
-  """Anthropic's Claude 3.5 Haiku model on VertexAI."""
-  model = 'claude-3-5-haiku@20241022'
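This removal is a relocation rather than a deletion: per the llms/__init__.py hunks above, the VertexAIAnthropic family is now imported from vertexai.py, so call sites that go through the package root keep working. A sketch of such a call site (the project value is illustrative; the removed implementation also fell back to the VERTEXAI_PROJECT environment variable, and the relocated one presumably keeps that behavior):

    from langfun.core.llms import VertexAIClaude3_5_Sonnet_20241022

    lm = VertexAIClaude3_5_Sonnet_20241022(project='my-gcp-project')
    print(lm('hi'))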
langfun/core/llms/anthropic_test.py

@@ -19,9 +19,6 @@ from typing import Any
 import unittest
 from unittest import mock
 
-from google.auth import exceptions
-from langfun.core import language_model
-from langfun.core import message as lf_message
 from langfun.core import modalities as lf_modalities
 from langfun.core.llms import anthropic
 import pyglove as pg
@@ -186,50 +183,5 @@ class AnthropicTest(unittest.TestCase):
       lm('hello', max_attempts=1)
 
 
-class VertexAIAnthropicTest(unittest.TestCase):
-  """Tests for VertexAI Anthropic models."""
-
-  def test_basics(self):
-    with self.assertRaisesRegex(ValueError, 'Please specify `project`'):
-      lm = anthropic.VertexAIClaude3_5_Sonnet_20241022()
-      lm('hi')
-
-    model = anthropic.VertexAIClaude3_5_Sonnet_20241022(project='langfun')
-
-    # NOTE(daiyip): For OSS users, default credentials are not available unless
-    # users have already set up their GCP project. Therefore we ignore the
-    # exception here.
-    try:
-      model._initialize()
-    except exceptions.DefaultCredentialsError:
-      pass
-
-    self.assertEqual(
-        model.api_endpoint,
-        (
-            'https://us-east5-aiplatform.googleapis.com/v1/projects/'
-            'langfun/locations/us-east5/publishers/anthropic/'
-            'models/claude-3-5-sonnet-v2@20241022:streamRawPredict'
-        )
-    )
-    request = model.request(
-        lf_message.UserMessage('hi'),
-        language_model.LMSamplingOptions(temperature=0.0),
-    )
-    self.assertEqual(
-        request,
-        {
-            'anthropic_version': 'vertex-2023-10-16',
-            'max_tokens': 8192,
-            'messages': [
-                {'content': [{'text': 'hi', 'type': 'text'}], 'role': 'user'}
-            ],
-            'stream': False,
-            'temperature': 0.0,
-            'top_k': 40,
-        },
-    )
-
-
 if __name__ == '__main__':
   unittest.main()