langfun 0.1.2.dev202502110804__py3-none-any.whl → 0.1.2.dev202502120804__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langfun/core/__init__.py +6 -2
- langfun/core/language_model.py +365 -22
- langfun/core/language_model_test.py +123 -35
- langfun/core/llms/__init__.py +50 -57
- langfun/core/llms/anthropic.py +434 -163
- langfun/core/llms/anthropic_test.py +20 -1
- langfun/core/llms/deepseek.py +90 -51
- langfun/core/llms/deepseek_test.py +15 -16
- langfun/core/llms/fake.py +6 -0
- langfun/core/llms/gemini.py +480 -390
- langfun/core/llms/gemini_test.py +27 -7
- langfun/core/llms/google_genai.py +80 -50
- langfun/core/llms/google_genai_test.py +11 -4
- langfun/core/llms/groq.py +268 -167
- langfun/core/llms/groq_test.py +9 -3
- langfun/core/llms/openai.py +839 -328
- langfun/core/llms/openai_compatible.py +3 -18
- langfun/core/llms/openai_compatible_test.py +20 -5
- langfun/core/llms/openai_test.py +14 -4
- langfun/core/llms/rest.py +11 -6
- langfun/core/llms/vertexai.py +238 -240
- langfun/core/llms/vertexai_test.py +35 -8
- {langfun-0.1.2.dev202502110804.dist-info → langfun-0.1.2.dev202502120804.dist-info}/METADATA +1 -1
- {langfun-0.1.2.dev202502110804.dist-info → langfun-0.1.2.dev202502120804.dist-info}/RECORD +27 -27
- {langfun-0.1.2.dev202502110804.dist-info → langfun-0.1.2.dev202502120804.dist-info}/LICENSE +0 -0
- {langfun-0.1.2.dev202502110804.dist-info → langfun-0.1.2.dev202502120804.dist-info}/WHEEL +0 -0
- {langfun-0.1.2.dev202502110804.dist-info → langfun-0.1.2.dev202502120804.dist-info}/top_level.txt +0 -0
@@ -19,6 +19,7 @@ from typing import Any
|
|
19
19
|
import unittest
|
20
20
|
from unittest import mock
|
21
21
|
|
22
|
+
import langfun.core as lf
|
22
23
|
from langfun.core import modalities as lf_modalities
|
23
24
|
from langfun.core.llms import anthropic
|
24
25
|
import pyglove as pg
|
@@ -119,6 +120,17 @@ class AnthropicTest(unittest.TestCase):
|
|
119
120
|
)
|
120
121
|
self.assertGreater(anthropic.Claude3Haiku().max_concurrency, 0)
|
121
122
|
|
123
|
+
def test_model_alias(self):
|
124
|
+
# Alias will be normalized to the official version.
|
125
|
+
self.assertEqual(
|
126
|
+
anthropic.Anthropic('claude-3-5-sonnet-20241022').model_id,
|
127
|
+
'claude-3-5-sonnet-20241022'
|
128
|
+
)
|
129
|
+
self.assertEqual(
|
130
|
+
anthropic.Anthropic('claude-3-5-sonnet-v2@20241022').model_id,
|
131
|
+
'claude-3-5-sonnet-20241022'
|
132
|
+
)
|
133
|
+
|
122
134
|
def test_api_key(self):
|
123
135
|
lm = anthropic.Claude3Haiku()
|
124
136
|
with self.assertRaisesRegex(ValueError, 'Please specify `api_key`'):
|
@@ -151,6 +163,7 @@ class AnthropicTest(unittest.TestCase):
|
|
151
163
|
self.assertIsNotNone(response.usage.prompt_tokens, 2)
|
152
164
|
self.assertIsNotNone(response.usage.completion_tokens, 1)
|
153
165
|
self.assertIsNotNone(response.usage.total_tokens, 3)
|
166
|
+
self.assertGreater(response.usage.estimated_cost, 0)
|
154
167
|
|
155
168
|
def test_mm_call(self):
|
156
169
|
with mock.patch('requests.Session.post') as mock_mm_request:
|
@@ -162,7 +175,7 @@ class AnthropicTest(unittest.TestCase):
|
|
162
175
|
def test_pdf_call(self):
|
163
176
|
with mock.patch('requests.Session.post') as mock_mm_request:
|
164
177
|
mock_mm_request.side_effect = mock_mm_requests_post
|
165
|
-
lm = anthropic.
|
178
|
+
lm = anthropic.Claude35Sonnet(api_key='fake_key')
|
166
179
|
response = lm(lf_modalities.PDF.from_bytes(pdf_content), lm=lm)
|
167
180
|
self.assertEqual(response.text, 'document: application/pdf')
|
168
181
|
|
@@ -182,6 +195,12 @@ class AnthropicTest(unittest.TestCase):
|
|
182
195
|
):
|
183
196
|
lm('hello', max_attempts=1)
|
184
197
|
|
198
|
+
def test_lm_get(self):
|
199
|
+
self.assertIsInstance(
|
200
|
+
lf.LanguageModel.get('claude-3-5-sonnet-latest'),
|
201
|
+
anthropic.Anthropic,
|
202
|
+
)
|
203
|
+
|
185
204
|
|
186
205
|
if __name__ == '__main__':
|
187
206
|
unittest.main()
|
langfun/core/llms/deepseek.py
CHANGED
@@ -13,33 +13,81 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
"""Language models from DeepSeek."""
|
15
15
|
|
16
|
+
import datetime
|
17
|
+
import functools
|
16
18
|
import os
|
17
|
-
from typing import Annotated, Any
|
19
|
+
from typing import Annotated, Any, Final
|
18
20
|
|
19
21
|
import langfun.core as lf
|
20
22
|
from langfun.core.llms import openai_compatible
|
21
23
|
import pyglove as pg
|
22
24
|
|
23
|
-
|
24
|
-
|
25
|
-
|
26
|
-
|
27
|
-
|
28
|
-
|
25
|
+
|
26
|
+
class DeepSeekModelInfo(lf.ModelInfo):
|
27
|
+
"""DeepSeek model info."""
|
28
|
+
|
29
|
+
LINKS = dict(
|
30
|
+
models='https://api-docs.deepseek.com/quick_start/pricing',
|
31
|
+
pricing='https://api-docs.deepseek.com/quick_start/pricing',
|
32
|
+
rate_limits='https://api-docs.deepseek.com/quick_start/rate_limit',
|
33
|
+
error_codes='https://api-docs.deepseek.com/quick_start/error_codes',
|
34
|
+
)
|
35
|
+
|
36
|
+
provider: Final[str] = 'DeepSeek' # pylint: disable=invalid-name
|
37
|
+
|
38
|
+
api_model_name: Annotated[
|
39
|
+
str,
|
40
|
+
'The model name used in the DeepSeek API.'
|
41
|
+
]
|
42
|
+
|
43
|
+
|
44
|
+
SUPPORTED_MODELS = [
|
45
|
+
DeepSeekModelInfo(
|
46
|
+
model_id='deepseek-r1',
|
29
47
|
in_service=True,
|
30
|
-
|
31
|
-
|
32
|
-
|
33
|
-
|
48
|
+
model_type='thinking',
|
49
|
+
api_model_name='deepseek-reasoner',
|
50
|
+
description='DeepSeek Reasoner model (01/20/2025).',
|
51
|
+
url='https://api-docs.deepseek.com/news/news250120',
|
52
|
+
release_date=datetime.datetime(2025, 1, 20),
|
53
|
+
input_modalities=lf.ModelInfo.TEXT_INPUT_ONLY,
|
54
|
+
context_length=lf.ModelInfo.ContextLength(
|
55
|
+
max_input_tokens=64_000,
|
56
|
+
max_output_tokens=8_000,
|
57
|
+
max_cot_tokens=32_000,
|
58
|
+
),
|
59
|
+
pricing=lf.ModelInfo.Pricing(
|
60
|
+
cost_per_1m_cached_input_tokens=0.14,
|
61
|
+
cost_per_1m_input_tokens=0.55,
|
62
|
+
cost_per_1m_output_tokens=2.19,
|
63
|
+
),
|
64
|
+
# No rate limits are enforced by DeepSeek for now.
|
65
|
+
rate_limits=None
|
34
66
|
),
|
35
|
-
|
67
|
+
DeepSeekModelInfo(
|
68
|
+
model_id='deepseek-v3',
|
36
69
|
in_service=True,
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
|
70
|
+
model_type='instruction-tuned',
|
71
|
+
api_model_name='deepseek-chat',
|
72
|
+
description='DeepSeek V3 model (12/26/2024).',
|
73
|
+
url='https://api-docs.deepseek.com/news/news1226',
|
74
|
+
release_date=datetime.datetime(2024, 12, 26),
|
75
|
+
input_modalities=lf.ModelInfo.TEXT_INPUT_ONLY,
|
76
|
+
context_length=lf.ModelInfo.ContextLength(
|
77
|
+
max_input_tokens=64_000,
|
78
|
+
max_output_tokens=8_000,
|
79
|
+
),
|
80
|
+
pricing=lf.ModelInfo.Pricing(
|
81
|
+
cost_per_1m_cached_input_tokens=0.07,
|
82
|
+
cost_per_1m_input_tokens=0.27,
|
83
|
+
cost_per_1m_output_tokens=1.1,
|
84
|
+
),
|
85
|
+
# No rate limits are enforced by DeepSeek for now.
|
86
|
+
rate_limits=None
|
41
87
|
),
|
42
|
-
|
88
|
+
]
|
89
|
+
|
90
|
+
_SUPPORTED_MODELS_BY_ID = {m.model_id: m for m in SUPPORTED_MODELS}
|
43
91
|
|
44
92
|
|
45
93
|
# DeepSeek API uses an API format compatible with OpenAI.
|
@@ -50,7 +98,7 @@ class DeepSeek(openai_compatible.OpenAICompatible):
|
|
50
98
|
|
51
99
|
model: pg.typing.Annotated[
|
52
100
|
pg.typing.Enum(
|
53
|
-
pg.MISSING_VALUE,
|
101
|
+
pg.MISSING_VALUE, [m.model_id for m in SUPPORTED_MODELS]
|
54
102
|
),
|
55
103
|
'The name of the model to use.',
|
56
104
|
]
|
@@ -79,56 +127,47 @@ class DeepSeek(openai_compatible.OpenAICompatible):
|
|
79
127
|
})
|
80
128
|
return headers
|
81
129
|
|
82
|
-
@
|
83
|
-
def
|
84
|
-
|
85
|
-
return f'DeepSeek({self.model})'
|
130
|
+
@functools.cached_property
|
131
|
+
def model_info(self) -> DeepSeekModelInfo:
|
132
|
+
return _SUPPORTED_MODELS_BY_ID[self.model]
|
86
133
|
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
-
|
91
|
-
|
92
|
-
|
93
|
-
|
94
|
-
|
95
|
-
def estimate_cost(
|
96
|
-
self, num_input_tokens: int, num_output_tokens: int
|
97
|
-
) -> float | None:
|
98
|
-
"""Estimate the cost based on usage."""
|
99
|
-
cost_per_1k_input_tokens = SUPPORTED_MODELS_AND_SETTINGS[self.model].get(
|
100
|
-
'cost_per_1k_input_tokens', None
|
101
|
-
)
|
102
|
-
cost_per_1k_output_tokens = SUPPORTED_MODELS_AND_SETTINGS[self.model].get(
|
103
|
-
'cost_per_1k_output_tokens', None
|
104
|
-
)
|
105
|
-
if cost_per_1k_output_tokens is None or cost_per_1k_input_tokens is None:
|
106
|
-
return None
|
107
|
-
return (
|
108
|
-
cost_per_1k_input_tokens * num_input_tokens
|
109
|
-
+ cost_per_1k_output_tokens * num_output_tokens
|
110
|
-
) / 1000
|
134
|
+
def _request_args(
|
135
|
+
self, options: lf.LMSamplingOptions) -> dict[str, Any]:
|
136
|
+
"""Returns a dict as request arguments."""
|
137
|
+
# NOTE(daiyip): Replace model name with the API model name instead of the
|
138
|
+
# model ID.
|
139
|
+
args = super()._request_args(options)
|
140
|
+
args['model'] = self.model_info.api_model_name
|
141
|
+
return args
|
111
142
|
|
112
143
|
@classmethod
|
113
144
|
def dir(cls):
|
114
|
-
return [
|
145
|
+
return [m.model_id for m in SUPPORTED_MODELS if m.in_service]
|
115
146
|
|
116
147
|
|
117
|
-
class
|
148
|
+
class DeepSeekR1(DeepSeek):
|
118
149
|
"""DeepSeek Reasoner model.
|
119
150
|
|
120
151
|
Currently it is powered by DeepSeek-R1 model, 64k input context, 8k max
|
121
152
|
output, 32k max CoT output.
|
122
153
|
"""
|
123
154
|
|
124
|
-
model = 'deepseek-
|
155
|
+
model = 'deepseek-r1'
|
125
156
|
|
126
157
|
|
127
|
-
class
|
158
|
+
class DeepSeekV3(DeepSeek):
|
128
159
|
"""DeepSeek Chat model.
|
129
160
|
|
130
161
|
Currently, it is powered by DeepSeek-V3 model, 64K input context window and
|
131
162
|
8k max output tokens.
|
132
163
|
"""
|
133
164
|
|
134
|
-
model = 'deepseek-
|
165
|
+
model = 'deepseek-v3'
|
166
|
+
|
167
|
+
|
168
|
+
def _register_deepseek_models():
|
169
|
+
"""Registers DeepSeek models."""
|
170
|
+
for m in SUPPORTED_MODELS:
|
171
|
+
lf.LanguageModel.register(m.model_id, DeepSeek)
|
172
|
+
|
173
|
+
_register_deepseek_models()
|
@@ -12,6 +12,7 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
import unittest
|
15
|
+
import langfun.core as lf
|
15
16
|
from langfun.core.llms import deepseek
|
16
17
|
|
17
18
|
|
@@ -19,13 +20,13 @@ class DeepSeekTest(unittest.TestCase):
|
|
19
20
|
"""Tests for DeepSeek language model."""
|
20
21
|
|
21
22
|
def test_dir(self):
|
22
|
-
self.assertIn('deepseek-
|
23
|
+
self.assertIn('deepseek-v3', deepseek.DeepSeek.dir())
|
23
24
|
|
24
25
|
def test_key(self):
|
25
26
|
with self.assertRaisesRegex(ValueError, 'Please specify `api_key`'):
|
26
|
-
_ = deepseek.
|
27
|
+
_ = deepseek.DeepSeekV3().headers
|
27
28
|
self.assertEqual(
|
28
|
-
deepseek.
|
29
|
+
deepseek.DeepSeekV3(api_key='test_key').headers,
|
29
30
|
{
|
30
31
|
'Content-Type': 'application/json',
|
31
32
|
'Authorization': 'Bearer test_key',
|
@@ -34,27 +35,25 @@ class DeepSeekTest(unittest.TestCase):
|
|
34
35
|
|
35
36
|
def test_model_id(self):
|
36
37
|
self.assertEqual(
|
37
|
-
deepseek.
|
38
|
-
'
|
38
|
+
deepseek.DeepSeekV3(api_key='test_key').model_id,
|
39
|
+
'deepseek-v3',
|
39
40
|
)
|
40
41
|
|
41
42
|
def test_resource_id(self):
|
42
43
|
self.assertEqual(
|
43
|
-
deepseek.
|
44
|
-
'
|
44
|
+
deepseek.DeepSeekV3(api_key='test_key').resource_id,
|
45
|
+
'deepseek://deepseek-v3',
|
45
46
|
)
|
46
47
|
|
47
|
-
def
|
48
|
-
|
49
|
-
|
48
|
+
def test_request(self):
|
49
|
+
request = deepseek.DeepSeekV3(api_key='test_key').request(
|
50
|
+
lf.UserMessage('hi'), lf.LMSamplingOptions(temperature=0.0),
|
50
51
|
)
|
52
|
+
self.assertEqual(request['model'], 'deepseek-chat')
|
51
53
|
|
52
|
-
def
|
53
|
-
self.
|
54
|
-
|
55
|
-
num_input_tokens=100, num_output_tokens=100
|
56
|
-
),
|
57
|
-
4.2e-5
|
54
|
+
def test_lm_get(self):
|
55
|
+
self.assertIsInstance(
|
56
|
+
lf.LanguageModel.get('deepseek-v3'), deepseek.DeepSeek
|
58
57
|
)
|
59
58
|
|
60
59
|
if __name__ == '__main__':
|
langfun/core/llms/fake.py
CHANGED
@@ -14,6 +14,7 @@
|
|
14
14
|
"""Fake LMs for testing."""
|
15
15
|
|
16
16
|
import abc
|
17
|
+
import functools
|
17
18
|
from typing import Annotated
|
18
19
|
import langfun.core as lf
|
19
20
|
|
@@ -44,6 +45,11 @@ class Fake(lf.LanguageModel):
|
|
44
45
|
)
|
45
46
|
return results
|
46
47
|
|
48
|
+
@functools.cached_property
|
49
|
+
def model_info(self) -> lf.ModelInfo:
|
50
|
+
"""Returns the specification of the model."""
|
51
|
+
return lf.ModelInfo(model_id=self.__class__.__name__)
|
52
|
+
|
47
53
|
@abc.abstractmethod
|
48
54
|
def _response_from(self, prompt: lf.Message) -> lf.Message:
|
49
55
|
"""Returns the response for the given prompt."""
|