langfun 0.1.2.dev202501080804__py3-none-any.whl → 0.1.2.dev202501240804__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langfun/core/__init__.py +1 -6
- langfun/core/coding/python/__init__.py +5 -11
- langfun/core/coding/python/correction.py +4 -7
- langfun/core/coding/python/correction_test.py +2 -3
- langfun/core/coding/python/execution.py +22 -211
- langfun/core/coding/python/execution_test.py +11 -90
- langfun/core/coding/python/generation.py +3 -2
- langfun/core/coding/python/generation_test.py +2 -2
- langfun/core/coding/python/parsing.py +108 -194
- langfun/core/coding/python/parsing_test.py +2 -105
- langfun/core/component.py +11 -273
- langfun/core/component_test.py +2 -29
- langfun/core/concurrent.py +187 -82
- langfun/core/concurrent_test.py +28 -19
- langfun/core/console.py +7 -3
- langfun/core/eval/base.py +2 -3
- langfun/core/eval/v2/evaluation.py +3 -1
- langfun/core/eval/v2/reporting.py +8 -4
- langfun/core/language_model.py +84 -8
- langfun/core/language_model_test.py +84 -29
- langfun/core/llms/__init__.py +46 -11
- langfun/core/llms/anthropic.py +1 -123
- langfun/core/llms/anthropic_test.py +0 -48
- langfun/core/llms/deepseek.py +117 -0
- langfun/core/llms/deepseek_test.py +61 -0
- langfun/core/llms/gemini.py +1 -1
- langfun/core/llms/groq.py +12 -99
- langfun/core/llms/groq_test.py +31 -137
- langfun/core/llms/llama_cpp.py +17 -54
- langfun/core/llms/llama_cpp_test.py +2 -34
- langfun/core/llms/openai.py +9 -147
- langfun/core/llms/openai_compatible.py +179 -0
- langfun/core/llms/openai_compatible_test.py +495 -0
- langfun/core/llms/openai_test.py +13 -423
- langfun/core/llms/rest_test.py +1 -1
- langfun/core/llms/vertexai.py +387 -18
- langfun/core/llms/vertexai_test.py +52 -0
- langfun/core/message_test.py +3 -3
- langfun/core/modalities/mime.py +8 -0
- langfun/core/modalities/mime_test.py +19 -4
- langfun/core/modality_test.py +0 -1
- langfun/core/structured/mapping.py +13 -13
- langfun/core/structured/mapping_test.py +2 -2
- langfun/core/structured/schema.py +16 -8
- langfun/core/structured/schema_generation.py +1 -1
- {langfun-0.1.2.dev202501080804.dist-info → langfun-0.1.2.dev202501240804.dist-info}/METADATA +13 -2
- {langfun-0.1.2.dev202501080804.dist-info → langfun-0.1.2.dev202501240804.dist-info}/RECORD +50 -52
- {langfun-0.1.2.dev202501080804.dist-info → langfun-0.1.2.dev202501240804.dist-info}/WHEEL +1 -1
- langfun/core/coding/python/errors.py +0 -108
- langfun/core/coding/python/errors_test.py +0 -99
- langfun/core/coding/python/permissions.py +0 -90
- langfun/core/coding/python/permissions_test.py +0 -86
- langfun/core/text_formatting.py +0 -168
- langfun/core/text_formatting_test.py +0 -65
- {langfun-0.1.2.dev202501080804.dist-info → langfun-0.1.2.dev202501240804.dist-info}/LICENSE +0 -0
- {langfun-0.1.2.dev202501080804.dist-info → langfun-0.1.2.dev202501240804.dist-info}/top_level.txt +0 -0
langfun/core/language_model.py
CHANGED
@@ -81,6 +81,61 @@ class LMSample(pg.Object):
   ] = None


+class RetryStats(pg.Object):
+  """Retry stats, which is aggregated across multiple retry entries."""
+
+  num_occurences: Annotated[
+      int,
+      'Total number of retry attempts on LLM (excluding the first attempt).',
+  ] = 0
+  total_wait_interval: Annotated[
+      float, 'Total wait interval in seconds due to retry.'
+  ] = 0
+  total_call_interval: Annotated[
+      float, 'Total LLM call interval in seconds.'
+  ] = 0
+  errors: Annotated[
+      dict[str, int],
+      'A Counter of error types encountered during the retry attempts.',
+  ] = {}
+
+  @classmethod
+  def from_retry_entries(
+      cls, retry_entries: Sequence[concurrent.RetryEntry]
+  ) -> 'RetryStats':
+    """Creates a RetryStats from a sequence of RetryEntry."""
+    if not retry_entries:
+      return RetryStats()
+    errors = {}
+    for retry in retry_entries:
+      if retry.error is not None:
+        errors[retry.error.__class__.__name__] = (
+            errors.get(retry.error.__class__.__name__, 0) + 1
+        )
+    return RetryStats(
+        num_occurences=len(retry_entries) - 1,
+        total_wait_interval=sum(e.wait_interval for e in retry_entries),
+        total_call_interval=sum(e.call_interval for e in retry_entries),
+        errors=errors,
+    )
+
+  def __add__(self, other: 'RetryStats') -> 'RetryStats':
+    errors = self.errors.copy()
+    for error, count in other.errors.items():
+      errors[error] = errors.get(error, 0) + count
+    return RetryStats(
+        num_occurences=self.num_occurences + other.num_occurences,
+        total_wait_interval=self.total_wait_interval
+        + other.total_wait_interval,
+        total_call_interval=self.total_call_interval
+        + other.total_call_interval,
+        errors=errors,
+    )
+
+  def __radd__(self, other: 'RetryStats') -> 'RetryStats':
+    return self + other
+
+
 class LMSamplingUsage(pg.Object):
   """Usage information per completion."""

@@ -93,8 +148,9 @@ class LMSamplingUsage(pg.Object):
       (
           'Estimated cost in US dollars. If None, cost estimating is not '
           'suppported on the model being queried.'
-      )
+      ),
   ] = None
+  retry_stats: RetryStats = RetryStats()

   def __bool__(self) -> bool:
     return self.num_requests > 0
@@ -136,6 +192,7 @@ class LMSamplingUsage(pg.Object):
         total_tokens=self.total_tokens + other.total_tokens,
         num_requests=self.num_requests + other.num_requests,
         estimated_cost=estimated_cost,
+        retry_stats=self.retry_stats + other.retry_stats,
     )

   def __radd__(self, other: Optional['LMSamplingUsage']) -> 'LMSamplingUsage':
@@ -511,7 +568,18 @@ class LanguageModel(component.Component):
             total_tokens=usage.total_tokens // n,
             estimated_cost=(
                 usage.estimated_cost / n if usage.estimated_cost else None
-            )
+            ),
+            retry_stats=RetryStats(
+                num_occurences=usage.retry_stats.num_occurences // n,
+                total_wait_interval=usage.retry_stats.total_wait_interval
+                / n,
+                total_call_interval=usage.retry_stats.total_call_interval
+                / n,
+                errors={
+                    error: count // n
+                    for error, count in usage.retry_stats.errors.items()
+                },
+            ),
         )

         # Track usage.
@@ -584,16 +652,16 @@ class LanguageModel(component.Component):

   def _parallel_execute_with_currency_control(
       self,
-      action: Callable[...,
+      action: Callable[..., LMSamplingResult],
       inputs: Sequence[Any],
       retry_on_errors: Union[
           None,
           Union[Type[BaseException], Tuple[Type[BaseException], str]],
           Sequence[Union[Type[BaseException], Tuple[Type[BaseException], str]]],
       ] = RetryableLMError,
-  ) -> Any:
+  ) -> list[Any]:
     """Helper method for subclasses for implementing _sample."""
-
+    executed_jobs = concurrent.concurrent_execute(
         action,
         inputs,
         executor=self.resource_id if self.max_concurrency else None,
@@ -603,7 +671,15 @@ class LanguageModel(component.Component):
         retry_interval=self.retry_interval,
         exponential_backoff=self.exponential_backoff,
         max_retry_interval=self.max_retry_interval,
+        return_jobs=True,
     )
+    for job in executed_jobs:
+      if isinstance(job.result, LMSamplingResult):
+        job.result.usage.rebind(
+            retry_stats=RetryStats.from_retry_entries(job.retry_entries),
+            skip_notification=True,
+        )
+    return [job.result for job in executed_jobs]

   def __call__(
       self, prompt: message_lib.Message, *, cache_seed: int = 0, **kwargs
@@ -653,7 +729,7 @@ class LanguageModel(component.Component):
     """Outputs debugging information about the model."""
     title_suffix = ''
     if usage.total_tokens != 0:
-      title_suffix =
+      title_suffix = pg.colored(
           f' (total {usage.total_tokens} tokens)', 'red'
       )

@@ -672,7 +748,7 @@ class LanguageModel(component.Component):
     """Outputs debugging information about the prompt."""
     title_suffix = ''
     if usage.prompt_tokens != 0:
-      title_suffix =
+      title_suffix = pg.colored(f' ({usage.prompt_tokens} tokens)', 'red')

     console.write(
         # We use metadata 'formatted_text' for scenarios where the prompt text
@@ -703,7 +779,7 @@ class LanguageModel(component.Component):
     if usage.completion_tokens != 0:
       title_suffix += f'{usage.completion_tokens} tokens '
     title_suffix += f'in {elapse:.2f} seconds)'
-    title_suffix =
+    title_suffix = pg.colored(title_suffix, 'red')

     console.write(
         str(response) + '\n',
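Net effect of the language_model.py hunks: every `LMSamplingUsage` now carries a `RetryStats` populated from the retry entries of the underlying `concurrent_execute` jobs, and the stats stay additive when usages are merged. Below is a minimal runnable sketch of those aggregation semantics; plain dataclasses stand in for `pg.Object` and `concurrent.RetryEntry`, so this is an illustration of the diff above, not langfun's actual classes.

```python
# Sketch of the RetryStats semantics introduced above (illustrative only).
import dataclasses
from typing import Optional, Sequence


@dataclasses.dataclass
class RetryEntry:
  """One attempt at calling the LLM (stand-in for concurrent.RetryEntry)."""
  call_interval: float = 0.0
  wait_interval: float = 0.0
  error: Optional[BaseException] = None


@dataclasses.dataclass
class RetryStats:
  num_occurences: int = 0  # retries only; spelled as in the source
  total_wait_interval: float = 0.0
  total_call_interval: float = 0.0
  errors: dict = dataclasses.field(default_factory=dict)

  @classmethod
  def from_retry_entries(cls, entries: Sequence[RetryEntry]) -> 'RetryStats':
    # One entry per attempt; the first attempt does not count as a retry.
    if not entries:
      return cls()
    errors = {}
    for e in entries:
      if e.error is not None:
        name = e.error.__class__.__name__
        errors[name] = errors.get(name, 0) + 1
    return cls(
        num_occurences=len(entries) - 1,
        total_wait_interval=sum(e.wait_interval for e in entries),
        total_call_interval=sum(e.call_interval for e in entries),
        errors=errors,
    )

  def __add__(self, other: 'RetryStats') -> 'RetryStats':
    # Addition merges counters, which is what lets LMSamplingUsage.__add__
    # simply do `self.retry_stats + other.retry_stats`.
    errors = dict(self.errors)
    for name, count in other.errors.items():
      errors[name] = errors.get(name, 0) + count
    return RetryStats(
        self.num_occurences + other.num_occurences,
        self.total_wait_interval + other.total_wait_interval,
        self.total_call_interval + other.total_call_interval,
        errors,
    )


# Two attempts: the first fails with ValueError, the second succeeds.
entries = [
    RetryEntry(call_interval=0.5, wait_interval=1.0, error=ValueError()),
    RetryEntry(call_interval=0.4),
]
stats = RetryStats.from_retry_entries(entries)
assert stats.num_occurences == 1
assert stats.errors == {'ValueError': 1}
assert (stats + stats).total_wait_interval == 2.0
```

Note that `num_occurences` counts retries excluding the first attempt, so a call that succeeds immediately contributes zero to the aggregate.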
langfun/core/language_model_test.py
CHANGED
@@ -35,34 +35,34 @@ class MockModel(lm_lib.LanguageModel):
   ) -> list[lm_lib.LMSamplingResult]:
     context = pg.Dict(attempt=0)

-    def fake_sample(
+    def fake_sample(prompt):
       if context.attempt >= self.failures_before_attempt:
-        return
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        ]
-      context.attempt += 1
+        return lm_lib.LMSamplingResult(
+            [
+                lm_lib.LMSample(  # pylint: disable=g-complex-comprehension
+                    response=prompt.text * self.sampling_options.top_k,
+                    score=self.sampling_options.temperature or -1.0,
+                )
+            ],
+            usage=lm_lib.LMSamplingUsage(
+                prompt_tokens=100,
+                completion_tokens=100,
+                total_tokens=200,
+                estimated_cost=1.0,
+            ),
+        )
+      else:
+        context.attempt += 1
       raise ValueError('Failed to sample prompts.')

-
-    fake_sample,
-
-
-
+    results = self._parallel_execute_with_currency_control(
+        fake_sample, prompts, retry_on_errors=ValueError
+    )
+    for result in results:
+      result.usage.retry_stats.rebind(
+          total_call_interval=0, skip_notification=True
+      )
+    return results

   @property
   def model_id(self) -> str:
@@ -448,13 +448,50 @@ class LanguageModelTest(unittest.TestCase):

   def test_retry(self):
     lm = MockModel(
-        failures_before_attempt=1, top_k=1,
+        failures_before_attempt=1, top_k=1, max_attempts=2, retry_interval=1
     )
     with self.assertRaisesRegex(
         concurrent.RetryError, 'Calling .* failed after 1 attempts'
     ):
       lm('foo', max_attempts=1)
-
+
+    usage = lm_lib.LMSamplingUsage(
+        prompt_tokens=100,
+        completion_tokens=100,
+        total_tokens=200,
+        num_requests=1,
+        estimated_cost=1.0,
+        retry_stats=lm_lib.RetryStats(
+            num_occurences=1,
+            total_wait_interval=1,
+            errors={'ValueError': 1},
+        ),
+    )
+    out = lm.sample(['foo'])
+    self.assertEqual(
+        # lm.sample(['foo'], max_attempts=2),
+        out,
+        [
+            lm_lib.LMSamplingResult(
+                [
+                    lm_lib.LMSample(
+                        message_lib.AIMessage(
+                            'foo',
+                            score=-1.0,
+                            logprobs=None,
+                            is_cached=False,
+                            usage=usage,
+                            tags=['lm-response'],
+                        ),
+                        score=-1.0,
+                        logprobs=None,
+                    )
+                ],
+                usage=usage,
+                is_cached=False,
+            )
+        ],
+    )

   def test_debug(self):
     class Image(modality.Modality):
@@ -755,16 +792,34 @@ class LMSamplingUsageTest(unittest.TestCase):

   def test_add(self):
     usage1 = lm_lib.LMSamplingUsage(100, 200, 300, 4, 5.0)
+    usage1.rebind(retry_stats=lm_lib.RetryStats(1, 3, 4, {'e1': 1}))
     usage2 = lm_lib.LMSamplingUsage(100, 200, 300, 4, 5.0)
     self.assertEqual(usage1 + usage2, usage1 + usage2)
     self.assertIs(usage1 + None, usage1)
     self.assertIs(None + usage1, usage1)
     usage3 = lm_lib.LMSamplingUsage(100, 200, 300, 4, None)
+    usage3.rebind(retry_stats=lm_lib.RetryStats(2, 4, 5, {'e1': 2, 'e2': 3}))
     self.assertEqual(
-        usage1 + usage3,
+        usage1 + usage3,
+        lm_lib.LMSamplingUsage(
+            200,
+            400,
+            600,
+            8,
+            5.0,
+            retry_stats=lm_lib.RetryStats(3, 7, 9, {'e1': 3, 'e2': 3}),
+        ),
     )
     self.assertEqual(
-        usage3 + usage1,
+        usage3 + usage1,
+        lm_lib.LMSamplingUsage(
+            200,
+            400,
+            600,
+            8,
+            5.0,
+            retry_stats=lm_lib.RetryStats(3, 7, 9, {'e1': 3, 'e2': 3}),
+        ),
    )

   def test_usage_not_available(self):
langfun/core/llms/__init__.py
CHANGED
@@ -27,8 +27,15 @@ from langfun.core.llms.fake import StaticSequence
 # Compositional models.
 from langfun.core.llms.compositional import RandomChoice

-#
+# Base models by request/response protocol.
 from langfun.core.llms.rest import REST
+from langfun.core.llms.openai_compatible import OpenAICompatible
+from langfun.core.llms.gemini import Gemini
+from langfun.core.llms.anthropic import Anthropic
+
+# Base models by serving platforms.
+from langfun.core.llms.vertexai import VertexAI
+from langfun.core.llms.groq import Groq

 # Gemini models.
 from langfun.core.llms.google_genai import GenAI
@@ -44,7 +51,7 @@ from langfun.core.llms.google_genai import GeminiFlash1_5_002
 from langfun.core.llms.google_genai import GeminiFlash1_5_001
 from langfun.core.llms.google_genai import GeminiPro1

-from langfun.core.llms.vertexai import
+from langfun.core.llms.vertexai import VertexAIGemini
 from langfun.core.llms.vertexai import VertexAIGeminiFlash2_0ThinkingExp_20241219
 from langfun.core.llms.vertexai import VertexAIGeminiFlash2_0Exp
 from langfun.core.llms.vertexai import VertexAIGeminiExp_20241206
@@ -111,20 +118,34 @@ from langfun.core.llms.openai import Gpt3Curie
 from langfun.core.llms.openai import Gpt3Babbage
 from langfun.core.llms.openai import Gpt3Ada

-
+# Anthropic models.
+
 from langfun.core.llms.anthropic import Claude35Sonnet
 from langfun.core.llms.anthropic import Claude35Sonnet20241022
 from langfun.core.llms.anthropic import Claude35Sonnet20240620
 from langfun.core.llms.anthropic import Claude3Opus
 from langfun.core.llms.anthropic import Claude3Sonnet
 from langfun.core.llms.anthropic import Claude3Haiku
-from langfun.core.llms.anthropic import VertexAIAnthropic
-from langfun.core.llms.anthropic import VertexAIClaude3_5_Sonnet_20241022
-from langfun.core.llms.anthropic import VertexAIClaude3_5_Sonnet_20240620
-from langfun.core.llms.anthropic import VertexAIClaude3_5_Haiku_20241022
-from langfun.core.llms.anthropic import VertexAIClaude3_Opus_20240229

-from langfun.core.llms.
+from langfun.core.llms.vertexai import VertexAIAnthropic
+from langfun.core.llms.vertexai import VertexAIClaude3_5_Sonnet_20241022
+from langfun.core.llms.vertexai import VertexAIClaude3_5_Sonnet_20240620
+from langfun.core.llms.vertexai import VertexAIClaude3_5_Haiku_20241022
+from langfun.core.llms.vertexai import VertexAIClaude3_Opus_20240229
+
+# Misc open source models.
+
+# Gemma models.
+from langfun.core.llms.groq import GroqGemma2_9B_IT
+from langfun.core.llms.groq import GroqGemma_7B_IT
+
+# Llama models.
+from langfun.core.llms.vertexai import VertexAILlama
+from langfun.core.llms.vertexai import VertexAILlama3_2_90B
+from langfun.core.llms.vertexai import VertexAILlama3_1_405B
+from langfun.core.llms.vertexai import VertexAILlama3_1_70B
+from langfun.core.llms.vertexai import VertexAILlama3_1_8B
+
 from langfun.core.llms.groq import GroqLlama3_2_3B
 from langfun.core.llms.groq import GroqLlama3_2_1B
 from langfun.core.llms.groq import GroqLlama3_1_70B
@@ -132,15 +153,29 @@ from langfun.core.llms.groq import GroqLlama3_1_8B
 from langfun.core.llms.groq import GroqLlama3_70B
 from langfun.core.llms.groq import GroqLlama3_8B
 from langfun.core.llms.groq import GroqLlama2_70B
+
+# Mistral models.
+from langfun.core.llms.vertexai import VertexAIMistral
+from langfun.core.llms.vertexai import VertexAIMistralLarge_20241121
+from langfun.core.llms.vertexai import VertexAIMistralLarge_20240724
+from langfun.core.llms.vertexai import VertexAIMistralNemo_20240724
+from langfun.core.llms.vertexai import VertexAICodestral_20250113
+from langfun.core.llms.vertexai import VertexAICodestral_20240529
+
 from langfun.core.llms.groq import GroqMistral_8x7B
-
-
+
+# DeepSeek models.
+from langfun.core.llms.deepseek import DeepSeek
+from langfun.core.llms.deepseek import DeepSeekChat
+
+# Whisper models.
 from langfun.core.llms.groq import GroqWhisper_Large_v3
 from langfun.core.llms.groq import GroqWhisper_Large_v3Turbo

 # LLaMA C++ models.
 from langfun.core.llms.llama_cpp import LlamaCppRemote

+
 # Placeholder for Google-internal imports.

 # Include cache as sub-module.
langfun/core/llms/anthropic.py
CHANGED
@@ -14,9 +14,8 @@
 """Language models from Anthropic."""

 import base64
-import functools
 import os
-from typing import Annotated, Any
+from typing import Annotated, Any

 import langfun.core as lf
 from langfun.core import modalities as lf_modalities
@@ -24,20 +23,6 @@ from langfun.core.llms import rest
 import pyglove as pg


-try:
-  # pylint: disable=g-import-not-at-top
-  from google import auth as google_auth
-  from google.auth import credentials as credentials_lib
-  from google.auth.transport import requests as auth_requests
-  Credentials = credentials_lib.Credentials
-  # pylint: enable=g-import-not-at-top
-except ImportError:
-  google_auth = None
-  auth_requests = None
-  credentials_lib = None
-  Credentials = Any  # pylint: disable=invalid-name
-
-
 SUPPORTED_MODELS_AND_SETTINGS = {
     # See https://docs.anthropic.com/claude/docs/models-overview
     # Rate limits from https://docs.anthropic.com/claude/reference/rate-limits
@@ -379,110 +364,3 @@ class Claude21(Anthropic):
 class ClaudeInstant(Anthropic):
   """Cheapest small and fast model, 100K context window."""
   model = 'claude-instant-1.2'
-
-
-#
-# Authropic models on VertexAI.
-#
-
-
-class VertexAIAnthropic(Anthropic):
-  """Anthropic models on VertexAI."""
-
-  project: Annotated[
-      str | None,
-      'Google Cloud project ID.',
-  ] = None
-
-  location: Annotated[
-      Literal['us-east5', 'europe-west1'],
-      'GCP location with Anthropic models hosted.'
-  ] = 'us-east5'
-
-  credentials: Annotated[
-      Credentials | None,  # pytype: disable=invalid-annotation
-      (
-          'Credentials to use. If None, the default credentials '
-          'to the environment will be used.'
-      ),
-  ] = None
-
-  api_version = 'vertex-2023-10-16'
-
-  def _on_bound(self):
-    super()._on_bound()
-    if google_auth is None:
-      raise ValueError(
-          'Please install "langfun[llm-google-vertex]" to use Vertex AI models.'
-      )
-    self._project = None
-    self._credentials = None
-
-  def _initialize(self):
-    project = self.project or os.environ.get('VERTEXAI_PROJECT', None)
-    if not project:
-      raise ValueError(
-          'Please specify `project` during `__init__` or set environment '
-          'variable `VERTEXAI_PROJECT` with your Vertex AI project ID.'
-      )
-    self._project = project
-    credentials = self.credentials
-    if credentials is None:
-      # Use default credentials.
-      credentials = google_auth.default(
-          scopes=['https://www.googleapis.com/auth/cloud-platform']
-      )
-    self._credentials = credentials
-
-  @functools.cached_property
-  def _session(self):
-    assert self._api_initialized
-    assert self._credentials is not None
-    assert auth_requests is not None
-    s = auth_requests.AuthorizedSession(self._credentials)
-    s.headers.update(self.headers or {})
-    return s
-
-  @property
-  def headers(self):
-    return {
-        'Content-Type': 'application/json; charset=utf-8',
-    }
-
-  @property
-  def api_endpoint(self) -> str:
-    return (
-        f'https://{self.location}-aiplatform.googleapis.com/v1/projects/'
-        f'{self._project}/locations/{self.location}/publishers/anthropic/'
-        f'models/{self.model}:streamRawPredict'
-    )
-
-  def request(
-      self,
-      prompt: lf.Message,
-      sampling_options: lf.LMSamplingOptions
-  ):
-    request = super().request(prompt, sampling_options)
-    request['anthropic_version'] = self.api_version
-    del request['model']
-    return request
-
-
-class VertexAIClaude3_Opus_20240229(VertexAIAnthropic):  # pylint: disable=invalid-name
-  """Anthropic's Claude 3 Opus model on VertexAI."""
-  model = 'claude-3-opus@20240229'
-
-
-class VertexAIClaude3_5_Sonnet_20241022(VertexAIAnthropic):  # pylint: disable=invalid-name
-  """Anthropic's Claude 3.5 Sonnet model on VertexAI."""
-  model = 'claude-3-5-sonnet-v2@20241022'
-
-
-class VertexAIClaude3_5_Sonnet_20240620(VertexAIAnthropic):  # pylint: disable=invalid-name
-  """Anthropic's Claude 3.5 Sonnet model on VertexAI."""
-  model = 'claude-3-5-sonnet@20240620'
-
-
-class VertexAIClaude3_5_Haiku_20241022(VertexAIAnthropic):  # pylint: disable=invalid-name
-  """Anthropic's Claude 3.5 Haiku model on VertexAI."""
-  model = 'claude-3-5-haiku@20241022'
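These deletions pair with the +387 lines added to vertexai.py in the file list: the Claude-on-Vertex classes moved rather than disappeared. A migration sketch for code that used the old direct import path; the new path is taken from the __init__.py hunks above, and behavioral equivalence is assumed rather than verified beyond this diff:

```python
# Before, in 0.1.2.dev202501080804 (removed by the deletions above):
# from langfun.core.llms.anthropic import VertexAIClaude3_5_Sonnet_20241022

# After, in 0.1.2.dev202501240804 (per the __init__.py hunks):
from langfun.core.llms.vertexai import VertexAIClaude3_5_Sonnet_20241022

# The package-level re-export is unchanged, so this keeps working either way:
import langfun.core.llms as llms
lm = llms.VertexAIClaude3_5_Sonnet_20241022(project='my-gcp-project')  # 'project' arg assumed kept
```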
langfun/core/llms/anthropic_test.py
CHANGED
@@ -19,9 +19,6 @@ from typing import Any
 import unittest
 from unittest import mock

-from google.auth import exceptions
-from langfun.core import language_model
-from langfun.core import message as lf_message
 from langfun.core import modalities as lf_modalities
 from langfun.core.llms import anthropic
 import pyglove as pg
@@ -186,50 +183,5 @@ class AnthropicTest(unittest.TestCase):
       lm('hello', max_attempts=1)


-class VertexAIAnthropicTest(unittest.TestCase):
-  """Tests for VertexAI Anthropic models."""
-
-  def test_basics(self):
-    with self.assertRaisesRegex(ValueError, 'Please specify `project`'):
-      lm = anthropic.VertexAIClaude3_5_Sonnet_20241022()
-      lm('hi')
-
-    model = anthropic.VertexAIClaude3_5_Sonnet_20241022(project='langfun')
-
-    # NOTE(daiyip): For OSS users, default credentials are not available unless
-    # users have already set up their GCP project. Therefore we ignore the
-    # exception here.
-    try:
-      model._initialize()
-    except exceptions.DefaultCredentialsError:
-      pass
-
-    self.assertEqual(
-        model.api_endpoint,
-        (
-            'https://us-east5-aiplatform.googleapis.com/v1/projects/'
-            'langfun/locations/us-east5/publishers/anthropic/'
-            'models/claude-3-5-sonnet-v2@20241022:streamRawPredict'
-        )
-    )
-    request = model.request(
-        lf_message.UserMessage('hi'),
-        language_model.LMSamplingOptions(temperature=0.0),
-    )
-    self.assertEqual(
-        request,
-        {
-            'anthropic_version': 'vertex-2023-10-16',
-            'max_tokens': 8192,
-            'messages': [
-                {'content': [{'text': 'hi', 'type': 'text'}], 'role': 'user'}
-            ],
-            'stream': False,
-            'temperature': 0.0,
-            'top_k': 40,
-        },
-    )
-
-
 if __name__ == '__main__':
   unittest.main()