langfun 0.1.2.dev202502110804__py3-none-any.whl → 0.1.2.dev202502120804__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -19,6 +19,7 @@ from typing import Any
19
19
  import unittest
20
20
  from unittest import mock
21
21
 
22
+ import langfun.core as lf
22
23
  from langfun.core import modalities as lf_modalities
23
24
  from langfun.core.llms import anthropic
24
25
  import pyglove as pg
@@ -119,6 +120,17 @@ class AnthropicTest(unittest.TestCase):
119
120
  )
120
121
  self.assertGreater(anthropic.Claude3Haiku().max_concurrency, 0)
121
122
 
123
+ def test_model_alias(self):
124
+ # Alias will be normalized to the official version.
125
+ self.assertEqual(
126
+ anthropic.Anthropic('claude-3-5-sonnet-20241022').model_id,
127
+ 'claude-3-5-sonnet-20241022'
128
+ )
129
+ self.assertEqual(
130
+ anthropic.Anthropic('claude-3-5-sonnet-v2@20241022').model_id,
131
+ 'claude-3-5-sonnet-20241022'
132
+ )
133
+
122
134
  def test_api_key(self):
123
135
  lm = anthropic.Claude3Haiku()
124
136
  with self.assertRaisesRegex(ValueError, 'Please specify `api_key`'):
@@ -151,6 +163,7 @@ class AnthropicTest(unittest.TestCase):
151
163
  self.assertIsNotNone(response.usage.prompt_tokens, 2)
152
164
  self.assertIsNotNone(response.usage.completion_tokens, 1)
153
165
  self.assertIsNotNone(response.usage.total_tokens, 3)
166
+ self.assertGreater(response.usage.estimated_cost, 0)
154
167
 
155
168
  def test_mm_call(self):
156
169
  with mock.patch('requests.Session.post') as mock_mm_request:
@@ -162,7 +175,7 @@ class AnthropicTest(unittest.TestCase):
162
175
  def test_pdf_call(self):
163
176
  with mock.patch('requests.Session.post') as mock_mm_request:
164
177
  mock_mm_request.side_effect = mock_mm_requests_post
165
- lm = anthropic.Claude3Haiku(api_key='fake_key')
178
+ lm = anthropic.Claude35Sonnet(api_key='fake_key')
166
179
  response = lm(lf_modalities.PDF.from_bytes(pdf_content), lm=lm)
167
180
  self.assertEqual(response.text, 'document: application/pdf')
168
181
 
@@ -182,6 +195,12 @@ class AnthropicTest(unittest.TestCase):
182
195
  ):
183
196
  lm('hello', max_attempts=1)
184
197
 
198
+ def test_lm_get(self):
199
+ self.assertIsInstance(
200
+ lf.LanguageModel.get('claude-3-5-sonnet-latest'),
201
+ anthropic.Anthropic,
202
+ )
203
+
185
204
 
186
205
  if __name__ == '__main__':
187
206
  unittest.main()
@@ -13,33 +13,81 @@
13
13
  # limitations under the License.
14
14
  """Language models from DeepSeek."""
15
15
 
16
+ import datetime
17
+ import functools
16
18
  import os
17
- from typing import Annotated, Any
19
+ from typing import Annotated, Any, Final
18
20
 
19
21
  import langfun.core as lf
20
22
  from langfun.core.llms import openai_compatible
21
23
  import pyglove as pg
22
24
 
23
- SUPPORTED_MODELS_AND_SETTINGS = {
24
- # pylint: disable=g-line-too-long
25
- # TODO(yifenglu): The RPM and TPM are arbitrary numbers. Update them once DeepSeek provides concrete guidelines.
26
- # DeepSeek doesn't control the rate limit at the moment: https://api-docs.deepseek.com/quick_start/rate_limit
27
- # The cost is based on: https://api-docs.deepseek.com/quick_start/pricing
28
- 'deepseek-reasoner': pg.Dict(
25
+
26
+ class DeepSeekModelInfo(lf.ModelInfo):
27
+ """DeepSeek model info."""
28
+
29
+ LINKS = dict(
30
+ models='https://api-docs.deepseek.com/quick_start/pricing',
31
+ pricing='https://api-docs.deepseek.com/quick_start/pricing',
32
+ rate_limits='https://api-docs.deepseek.com/quick_start/rate_limit',
33
+ error_codes='https://api-docs.deepseek.com/quick_start/error_codes',
34
+ )
35
+
36
+ provider: Final[str] = 'DeepSeek' # pylint: disable=invalid-name
37
+
38
+ api_model_name: Annotated[
39
+ str,
40
+ 'The model name used in the DeepSeek API.'
41
+ ]
42
+
43
+
44
+ SUPPORTED_MODELS = [
45
+ DeepSeekModelInfo(
46
+ model_id='deepseek-r1',
29
47
  in_service=True,
30
- rpm=100,
31
- tpm=1000000,
32
- cost_per_1k_input_tokens=0.00055,
33
- cost_per_1k_output_tokens=0.00219,
48
+ model_type='thinking',
49
+ api_model_name='deepseek-reasoner',
50
+ description='DeepSeek Reasoner model (01/20/2025).',
51
+ url='https://api-docs.deepseek.com/news/news250120',
52
+ release_date=datetime.datetime(2025, 1, 20),
53
+ input_modalities=lf.ModelInfo.TEXT_INPUT_ONLY,
54
+ context_length=lf.ModelInfo.ContextLength(
55
+ max_input_tokens=64_000,
56
+ max_output_tokens=8_000,
57
+ max_cot_tokens=32_000,
58
+ ),
59
+ pricing=lf.ModelInfo.Pricing(
60
+ cost_per_1m_cached_input_tokens=0.14,
61
+ cost_per_1m_input_tokens=0.55,
62
+ cost_per_1m_output_tokens=2.19,
63
+ ),
64
+ # No rate limits are enforced by DeepSeek for now.
65
+ rate_limits=None
34
66
  ),
35
- 'deepseek-chat': pg.Dict(
67
+ DeepSeekModelInfo(
68
+ model_id='deepseek-v3',
36
69
  in_service=True,
37
- rpm=100,
38
- tpm=1000000,
39
- cost_per_1k_input_tokens=0.00014,
40
- cost_per_1k_output_tokens=0.00028,
70
+ model_type='instruction-tuned',
71
+ api_model_name='deepseek-chat',
72
+ description='DeepSeek V3 model (12/26/2024).',
73
+ url='https://api-docs.deepseek.com/news/news1226',
74
+ release_date=datetime.datetime(2024, 12, 26),
75
+ input_modalities=lf.ModelInfo.TEXT_INPUT_ONLY,
76
+ context_length=lf.ModelInfo.ContextLength(
77
+ max_input_tokens=64_000,
78
+ max_output_tokens=8_000,
79
+ ),
80
+ pricing=lf.ModelInfo.Pricing(
81
+ cost_per_1m_cached_input_tokens=0.07,
82
+ cost_per_1m_input_tokens=0.27,
83
+ cost_per_1m_output_tokens=1.1,
84
+ ),
85
+ # No rate limits are enforced by DeepSeek for now.
86
+ rate_limits=None
41
87
  ),
42
- }
88
+ ]
89
+
90
+ _SUPPORTED_MODELS_BY_ID = {m.model_id: m for m in SUPPORTED_MODELS}
43
91
 
44
92
 
45
93
  # DeepSeek API uses an API format compatible with OpenAI.
@@ -50,7 +98,7 @@ class DeepSeek(openai_compatible.OpenAICompatible):
50
98
 
51
99
  model: pg.typing.Annotated[
52
100
  pg.typing.Enum(
53
- pg.MISSING_VALUE, list(SUPPORTED_MODELS_AND_SETTINGS.keys())
101
+ pg.MISSING_VALUE, [m.model_id for m in SUPPORTED_MODELS]
54
102
  ),
55
103
  'The name of the model to use.',
56
104
  ]
@@ -79,56 +127,47 @@ class DeepSeek(openai_compatible.OpenAICompatible):
79
127
  })
80
128
  return headers
81
129
 
82
- @property
83
- def model_id(self) -> str:
84
- """Returns a string to identify the model."""
85
- return f'DeepSeek({self.model})'
130
+ @functools.cached_property
131
+ def model_info(self) -> DeepSeekModelInfo:
132
+ return _SUPPORTED_MODELS_BY_ID[self.model]
86
133
 
87
- @property
88
- def max_concurrency(self) -> int:
89
- rpm = SUPPORTED_MODELS_AND_SETTINGS[self.model].get('rpm', 0)
90
- tpm = SUPPORTED_MODELS_AND_SETTINGS[self.model].get('tpm', 0)
91
- return self.rate_to_max_concurrency(
92
- requests_per_min=rpm, tokens_per_min=tpm
93
- )
94
-
95
- def estimate_cost(
96
- self, num_input_tokens: int, num_output_tokens: int
97
- ) -> float | None:
98
- """Estimate the cost based on usage."""
99
- cost_per_1k_input_tokens = SUPPORTED_MODELS_AND_SETTINGS[self.model].get(
100
- 'cost_per_1k_input_tokens', None
101
- )
102
- cost_per_1k_output_tokens = SUPPORTED_MODELS_AND_SETTINGS[self.model].get(
103
- 'cost_per_1k_output_tokens', None
104
- )
105
- if cost_per_1k_output_tokens is None or cost_per_1k_input_tokens is None:
106
- return None
107
- return (
108
- cost_per_1k_input_tokens * num_input_tokens
109
- + cost_per_1k_output_tokens * num_output_tokens
110
- ) / 1000
134
+ def _request_args(
135
+ self, options: lf.LMSamplingOptions) -> dict[str, Any]:
136
+ """Returns a dict as request arguments."""
137
+ # NOTE(daiyip): Replace model name with the API model name instead of the
138
+ # model ID.
139
+ args = super()._request_args(options)
140
+ args['model'] = self.model_info.api_model_name
141
+ return args
111
142
 
112
143
  @classmethod
113
144
  def dir(cls):
114
- return [k for k, v in SUPPORTED_MODELS_AND_SETTINGS.items() if v.in_service]
145
+ return [m.model_id for m in SUPPORTED_MODELS if m.in_service]
115
146
 
116
147
 
117
- class DeepSeekReasoner(DeepSeek):
148
+ class DeepSeekR1(DeepSeek):
118
149
  """DeepSeek Reasoner model.
119
150
 
120
151
  Currently it is powered by DeepSeek-R1 model, 64k input context, 8k max
121
152
  output, 32k max CoT output.
122
153
  """
123
154
 
124
- model = 'deepseek-reasoner'
155
+ model = 'deepseek-r1'
125
156
 
126
157
 
127
- class DeepSeekChat(DeepSeek):
158
+ class DeepSeekV3(DeepSeek):
128
159
  """DeepSeek Chat model.
129
160
 
130
161
  Currently, it is powered by DeepSeek-V3 model, 64K input context window and
131
162
  8k max output tokens.
132
163
  """
133
164
 
134
- model = 'deepseek-chat'
165
+ model = 'deepseek-v3'
166
+
167
+
168
+ def _register_deepseek_models():
169
+ """Registers DeepSeek models."""
170
+ for m in SUPPORTED_MODELS:
171
+ lf.LanguageModel.register(m.model_id, DeepSeek)
172
+
173
+ _register_deepseek_models()
@@ -12,6 +12,7 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  import unittest
15
+ import langfun.core as lf
15
16
  from langfun.core.llms import deepseek
16
17
 
17
18
 
@@ -19,13 +20,13 @@ class DeepSeekTest(unittest.TestCase):
19
20
  """Tests for DeepSeek language model."""
20
21
 
21
22
  def test_dir(self):
22
- self.assertIn('deepseek-chat', deepseek.DeepSeek.dir())
23
+ self.assertIn('deepseek-v3', deepseek.DeepSeek.dir())
23
24
 
24
25
  def test_key(self):
25
26
  with self.assertRaisesRegex(ValueError, 'Please specify `api_key`'):
26
- _ = deepseek.DeepSeekChat().headers
27
+ _ = deepseek.DeepSeekV3().headers
27
28
  self.assertEqual(
28
- deepseek.DeepSeekChat(api_key='test_key').headers,
29
+ deepseek.DeepSeekV3(api_key='test_key').headers,
29
30
  {
30
31
  'Content-Type': 'application/json',
31
32
  'Authorization': 'Bearer test_key',
@@ -34,27 +35,25 @@ class DeepSeekTest(unittest.TestCase):
34
35
 
35
36
  def test_model_id(self):
36
37
  self.assertEqual(
37
- deepseek.DeepSeekChat(api_key='test_key').model_id,
38
- 'DeepSeek(deepseek-chat)',
38
+ deepseek.DeepSeekV3(api_key='test_key').model_id,
39
+ 'deepseek-v3',
39
40
  )
40
41
 
41
42
  def test_resource_id(self):
42
43
  self.assertEqual(
43
- deepseek.DeepSeekChat(api_key='test_key').resource_id,
44
- 'DeepSeek(deepseek-chat)',
44
+ deepseek.DeepSeekV3(api_key='test_key').resource_id,
45
+ 'deepseek://deepseek-v3',
45
46
  )
46
47
 
47
- def test_max_concurrency(self):
48
- self.assertGreater(
49
- deepseek.DeepSeekChat(api_key='test_key').max_concurrency, 0
48
+ def test_request(self):
49
+ request = deepseek.DeepSeekV3(api_key='test_key').request(
50
+ lf.UserMessage('hi'), lf.LMSamplingOptions(temperature=0.0),
50
51
  )
52
+ self.assertEqual(request['model'], 'deepseek-chat')
51
53
 
52
- def test_estimate_cost(self):
53
- self.assertEqual(
54
- deepseek.DeepSeekChat(api_key='test_key').estimate_cost(
55
- num_input_tokens=100, num_output_tokens=100
56
- ),
57
- 4.2e-5
54
+ def test_lm_get(self):
55
+ self.assertIsInstance(
56
+ lf.LanguageModel.get('deepseek-v3'), deepseek.DeepSeek
58
57
  )
59
58
 
60
59
  if __name__ == '__main__':
langfun/core/llms/fake.py CHANGED
@@ -14,6 +14,7 @@
14
14
  """Fake LMs for testing."""
15
15
 
16
16
  import abc
17
+ import functools
17
18
  from typing import Annotated
18
19
  import langfun.core as lf
19
20
 
@@ -44,6 +45,11 @@ class Fake(lf.LanguageModel):
44
45
  )
45
46
  return results
46
47
 
48
+ @functools.cached_property
49
+ def model_info(self) -> lf.ModelInfo:
50
+ """Returns the specification of the model."""
51
+ return lf.ModelInfo(model_id=self.__class__.__name__)
52
+
47
53
  @abc.abstractmethod
48
54
  def _response_from(self, prompt: lf.Message) -> lf.Message:
49
55
  """Returns the response for the given prompt."""