langfun 0.1.2.dev202501090804__py3-none-any.whl → 0.1.2.dev202501110804__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -57,6 +57,9 @@ from langfun.core.llms.vertexai import VertexAIGeminiFlash1_5_002
  from langfun.core.llms.vertexai import VertexAIGeminiFlash1_5_001
  from langfun.core.llms.vertexai import VertexAIGeminiPro1

+ # Base for OpenAI-compatible models.
+ from langfun.core.llms.openai_compatible import OpenAICompatible
+
  # OpenAI models.
  from langfun.core.llms.openai import OpenAI

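The hunk above re-exports the new OpenAICompatible base class from langfun.core.llms. A hypothetical sketch of what that enables, assuming the base class sends OpenAI-style chat completion requests to a configurable endpoint as the DeepSeek changes below suggest (the class name, endpoint URL and model name here are invented for illustration):

    from typing import Annotated

    from langfun.core.llms import openai_compatible


    class LocalServer(openai_compatible.OpenAICompatible):
      """Model served behind an OpenAI-style chat completions endpoint (sketch)."""

      model: Annotated[str, 'Model name sent with each request.'] = 'my-local-model'
      api_endpoint: str = 'http://localhost:8000/v1/chat/completions'


    lm = LocalServer()
    # Calling lm('hello') would POST an OpenAI-style chat completion request
    # to the configured api_endpoint.
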
@@ -17,8 +17,7 @@ import os
  from typing import Annotated, Any

  import langfun.core as lf
- from langfun.core import modalities as lf_modalities
- from langfun.core.llms import rest
+ from langfun.core.llms import openai_compatible
  import pyglove as pg

  SUPPORTED_MODELS_AND_SETTINGS = {
@@ -39,7 +38,7 @@ SUPPORTED_MODELS_AND_SETTINGS = {
  # DeepSeek API uses an API format compatible with OpenAI.
  # Reference: https://api-docs.deepseek.com/
  @lf.use_init_args(['model'])
- class DeepSeek(rest.REST):
+ class DeepSeek(openai_compatible.OpenAICompatible):
    """DeepSeek model."""

    model: pg.typing.Annotated[
@@ -51,10 +50,6 @@ class DeepSeek(rest.REST):

    api_endpoint: str = 'https://api.deepseek.com/chat/completions'

-   multimodal: Annotated[bool, 'Whether this model has multimodal support.'] = (
-       False
-   )
-
    api_key: Annotated[
        str | None,
        (
@@ -63,25 +58,18 @@ class DeepSeek(rest.REST):
        ),
    ] = None

-   def _on_bound(self):
-     super()._on_bound()
-     self._api_key = None
-
-   def _initialize(self):
+   @property
+   def headers(self) -> dict[str, Any]:
      api_key = self.api_key or os.environ.get('DEEPSEEK_API_KEY', None)
      if not api_key:
        raise ValueError(
            'Please specify `api_key` during `__init__` or set environment '
            'variable `DEEPSEEK_API_KEY` with your DeepSeek API key.'
        )
-     self._api_key = api_key
-
-   @property
-   def headers(self) -> dict[str, Any]:
-     headers = {
-         'Content-Type': 'application/json',
-         'Authorization': f'Bearer {self._api_key}',
-     }
+     headers = super().headers
+     headers.update({
+         'Authorization': f'Bearer {api_key}',
+     })
      return headers

    @property
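The headers refactor above is the heart of the change: DeepSeek no longer caches the key in _initialize(); it derives the Authorization header lazily and composes with the base class's default headers (which, per the updated test below, already carry 'Content-Type': 'application/json'). A hedged sketch of the same pattern for another bearer-token provider (the provider name, endpoint and environment variable are invented for illustration):

    import os
    from typing import Annotated, Any

    from langfun.core.llms import openai_compatible


    class ExampleProvider(openai_compatible.OpenAICompatible):
      """Hypothetical OpenAI-compatible provider, modeled on the DeepSeek diff."""

      model: Annotated[str, 'Model name.'] = 'example-model'
      api_endpoint: str = 'https://api.example.com/v1/chat/completions'
      api_key: Annotated[str | None, 'API key.'] = None

      @property
      def headers(self) -> dict[str, Any]:
        api_key = self.api_key or os.environ.get('EXAMPLE_API_KEY', None)
        if not api_key:
          raise ValueError('Please specify `api_key`.')
        # Start from the base headers and add bearer auth, mirroring the diff.
        headers = super().headers
        headers.update({'Authorization': f'Bearer {api_key}'})
        return headers
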
@@ -118,138 +106,6 @@ class DeepSeek(rest.REST):
    def dir(cls):
      return [k for k, v in SUPPORTED_MODELS_AND_SETTINGS.items() if v.in_service]

-   def _request_args(self, options: lf.LMSamplingOptions) -> dict[str, Any]:
-     # Reference:
-     # https://platform.openai.com/docs/api-reference/completions/create
-     # NOTE(daiyip): options.top_k is not applicable.
-     args = dict(
-         model=self.model,
-         n=options.n,
-         top_logprobs=options.top_logprobs,
-     )
-     if options.logprobs:
-       args['logprobs'] = options.logprobs
-
-     if options.temperature is not None:
-       args['temperature'] = options.temperature
-     if options.max_tokens is not None:
-       args['max_completion_tokens'] = options.max_tokens
-     if options.top_p is not None:
-       args['top_p'] = options.top_p
-     if options.stop:
-       args['stop'] = options.stop
-     if options.random_seed is not None:
-       args['seed'] = options.random_seed
-     return args
-
-   def _content_from_message(self, message: lf.Message):
-     """Returns a OpenAI content object from a Langfun message."""
-
-     def _uri_from(chunk: lf.Modality) -> str:
-       if chunk.uri and chunk.uri.lower().startswith(
-           ('http:', 'https:', 'ftp:')
-       ):
-         return chunk.uri
-       return chunk.content_uri
-
-     content = []
-     for chunk in message.chunk():
-       if isinstance(chunk, str):
-         item = dict(type='text', text=chunk)
-       elif isinstance(chunk, lf_modalities.Image) and self.multimodal:
-         item = dict(type='image_url', image_url=dict(url=_uri_from(chunk)))
-       else:
-         raise ValueError(f'Unsupported modality: {chunk!r}.')
-       content.append(item)
-     return content
-
-   def request(
-       self, prompt: lf.Message, sampling_options: lf.LMSamplingOptions
-   ) -> dict[str, Any]:
-     """Returns the JSON input for a message."""
-     request_args = self._request_args(sampling_options)
-
-     # Users could use `metadata_json_schema` to pass additional
-     # request arguments.
-     json_schema = prompt.metadata.get('json_schema')
-     if json_schema is not None:
-       if not isinstance(json_schema, dict):
-         raise ValueError(f'`json_schema` must be a dict, got {json_schema!r}.')
-       if 'title' not in json_schema:
-         raise ValueError(
-             'The root of `json_schema` must have a `title` field, '
-             f'got {json_schema!r}.'
-         )
-       request_args.update(
-           response_format=dict(
-               type='json_schema',
-               json_schema=dict(
-                   schema=json_schema,
-                   name=json_schema['title'],
-                   strict=True,
-               ),
-           )
-       )
-       prompt.metadata.formatted_text = (
-           prompt.text
-           + '\n\n [RESPONSE FORMAT (not part of prompt)]\n'
-           + pg.to_json_str(request_args['response_format'], json_indent=2)
-       )
-
-     # Prepare messages.
-     messages = []
-     # Users could use `metadata_system_message` to pass system message.
-     system_message = prompt.metadata.get('system_message')
-     if system_message:
-       system_message = lf.SystemMessage.from_value(system_message)
-       messages.append(
-           dict(
-               role='system', content=self._content_from_message(system_message)
-           )
-       )
-     messages.append(
-         dict(role='user', content=self._content_from_message(prompt))
-     )
-     request = dict()
-     request.update(request_args)
-     request['messages'] = messages
-     return request
-
-   def _parse_choice(self, choice: dict[str, Any]) -> lf.LMSample:
-     # Reference:
-     # https://platform.openai.com/docs/api-reference/chat/object
-     logprobs = None
-     choice_logprobs = choice.get('logprobs')
-     if choice_logprobs:
-       logprobs = [
-           (
-               t['token'],
-               t['logprob'],
-               [(tt['token'], tt['logprob']) for tt in t['top_logprobs']],
-           )
-           for t in choice_logprobs['content']
-       ]
-     return lf.LMSample(
-         choice['message']['content'],
-         score=0.0,
-         logprobs=logprobs,
-     )
-
-   def result(self, json: dict[str, Any]) -> lf.LMSamplingResult:
-     usage = json['usage']
-     return lf.LMSamplingResult(
-         samples=[self._parse_choice(choice) for choice in json['choices']],
-         usage=lf.LMSamplingUsage(
-             prompt_tokens=usage['prompt_tokens'],
-             completion_tokens=usage['completion_tokens'],
-             total_tokens=usage['total_tokens'],
-             estimated_cost=self.estimate_cost(
-                 num_input_tokens=usage['prompt_tokens'],
-                 num_output_tokens=usage['completion_tokens'],
-             ),
-         ),
-     )
-

  class DeepSeekChat(DeepSeek):
    """DeepSeek Chat model.
@@ -11,72 +11,8 @@
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  # See the License for the specific language governing permissions and
  # limitations under the License.
- """Tests for OpenAI models."""
-
- from typing import Any
  import unittest
- from unittest import mock
-
- import langfun.core as lf
  from langfun.core.llms import deepseek
- import pyglove as pg
- import requests
-
-
- def mock_chat_completion_request(url: str, json: dict[str, Any], **kwargs):
-   del url, kwargs
-   messages = json['messages']
-   if len(messages) > 1:
-     system_message = f' system={messages[0]["content"]}'
-   else:
-     system_message = ''
-
-   if 'response_format' in json:
-     response_format = f' format={json["response_format"]["type"]}'
-   else:
-     response_format = ''
-
-   choices = []
-   for k in range(json['n']):
-     if json.get('logprobs'):
-       logprobs = dict(
-           content=[
-               dict(
-                   token='chosen_token',
-                   logprob=0.5,
-                   top_logprobs=[
-                       dict(
-                           token=f'alternative_token_{i + 1}',
-                           logprob=0.1
-                       ) for i in range(3)
-                   ]
-               )
-           ]
-       )
-     else:
-       logprobs = None
-
-     choices.append(dict(
-         message=dict(
-             content=(
-                 f'Sample {k} for message.{system_message}{response_format}'
-             )
-         ),
-         logprobs=logprobs,
-     ))
-   response = requests.Response()
-   response.status_code = 200
-   response._content = pg.to_json_str(
-       dict(
-           choices=choices,
-           usage=lf.LMSamplingUsage(
-               prompt_tokens=100,
-               completion_tokens=100,
-               total_tokens=200,
-           ),
-       )
-   ).encode()
-   return response


  class DeepSeekTest(unittest.TestCase):
@@ -87,7 +23,14 @@ class DeepSeekTest(unittest.TestCase):

    def test_key(self):
      with self.assertRaisesRegex(ValueError, 'Please specify `api_key`'):
-       deepseek.DeepSeekChat()('hi')
+       _ = deepseek.DeepSeekChat().headers
+     self.assertEqual(
+         deepseek.DeepSeekChat(api_key='test_key').headers,
+         {
+             'Content-Type': 'application/json',
+             'Authorization': 'Bearer test_key',
+         }
+     )

    def test_model_id(self):
      self.assertEqual(
@@ -106,333 +49,13 @@ class DeepSeekTest(unittest.TestCase):
          deepseek.DeepSeekChat(api_key='test_key').max_concurrency, 0
      )

-   def test_request_args(self):
-     self.assertEqual(
-         deepseek.DeepSeekChat(api_key='test_key')._request_args(
-             lf.LMSamplingOptions(
-                 temperature=1.0, stop=['\n'], n=1, random_seed=123
-             )
-         ),
-         dict(
-             model='deepseek-chat',
-             top_logprobs=None,
-             n=1,
-             temperature=1.0,
-             stop=['\n'],
-             seed=123,
-         ),
-     )
-
-   def test_call_chat_completion(self):
-     with mock.patch('requests.Session.post') as mock_request:
-       mock_request.side_effect = mock_chat_completion_request
-       lm = deepseek.DeepSeek(model='deepseek-chat', api_key='test_key')
-       self.assertEqual(
-           lm('hello', sampling_options=lf.LMSamplingOptions(n=2)),
-           'Sample 0 for message.',
-       )
-
-   def test_call_chat_completion_with_logprobs(self):
-     with mock.patch('requests.Session.post') as mock_request:
-       mock_request.side_effect = mock_chat_completion_request
-       lm = deepseek.DeepSeek(model='deepseek-chat', api_key='test_key')
-       results = lm.sample(['hello'], logprobs=True)
-       self.assertEqual(len(results), 1)
-       expected = lf.LMSamplingResult(
-           [
-               lf.LMSample(
-                   response=lf.AIMessage(
-                       text='Sample 0 for message.',
-                       metadata={
-                           'score': 0.0,
-                           'logprobs': [(
-                               'chosen_token',
-                               0.5,
-                               [
-                                   ('alternative_token_1', 0.1),
-                                   ('alternative_token_2', 0.1),
-                                   ('alternative_token_3', 0.1),
-                               ],
-                           )],
-                           'is_cached': False,
-                           'usage': lf.LMSamplingUsage(
-                               prompt_tokens=100,
-                               completion_tokens=100,
-                               total_tokens=200,
-                               estimated_cost=4.2e-05,
-                           ),
-                       },
-                       tags=['lm-response'],
-                   ),
-                   logprobs=[(
-                       'chosen_token',
-                       0.5,
-                       [
-                           ('alternative_token_1', 0.1),
-                           ('alternative_token_2', 0.1),
-                           ('alternative_token_3', 0.1),
-                       ],
-                   )],
-               )
-           ],
-           usage=lf.LMSamplingUsage(
-               prompt_tokens=100,
-               completion_tokens=100,
-               total_tokens=200,
-               estimated_cost=4.2e-05,
-           ),
-       )
-       self.assertTrue(pg.eq(results[0], expected))
-
-   def test_sample_chat_completion(self):
-     with mock.patch('requests.Session.post') as mock_request:
-       mock_request.side_effect = mock_chat_completion_request
-       deepseek.SUPPORTED_MODELS_AND_SETTINGS['deepseek-chat'].update({
-           'cost_per_1k_input_tokens': 1.0,
-           'cost_per_1k_output_tokens': 1.0,
-       })
-       lm = deepseek.DeepSeek(api_key='test_key', model='deepseek-chat')
-       results = lm.sample(
-           ['hello', 'bye'], sampling_options=lf.LMSamplingOptions(n=3)
-       )
-
-       self.assertEqual(len(results), 2)
-       print(results[0])
-       self.assertEqual(
-           results[0],
-           lf.LMSamplingResult(
-               [
-                   lf.LMSample(
-                       lf.AIMessage(
-                           'Sample 0 for message.',
-                           score=0.0,
-                           logprobs=None,
-                           is_cached=False,
-                           usage=lf.LMSamplingUsage(
-                               prompt_tokens=33,
-                               completion_tokens=33,
-                               total_tokens=66,
-                               estimated_cost=0.2 / 3,
-                           ),
-                           tags=[lf.Message.TAG_LM_RESPONSE],
-                       ),
-                       score=0.0,
-                       logprobs=None,
-                   ),
-                   lf.LMSample(
-                       lf.AIMessage(
-                           'Sample 1 for message.',
-                           score=0.0,
-                           logprobs=None,
-                           is_cached=False,
-                           usage=lf.LMSamplingUsage(
-                               prompt_tokens=33,
-                               completion_tokens=33,
-                               total_tokens=66,
-                               estimated_cost=0.2 / 3,
-                           ),
-                           tags=[lf.Message.TAG_LM_RESPONSE],
-                       ),
-                       score=0.0,
-                       logprobs=None,
-                   ),
-                   lf.LMSample(
-                       lf.AIMessage(
-                           'Sample 2 for message.',
-                           score=0.0,
-                           logprobs=None,
-                           is_cached=False,
-                           usage=lf.LMSamplingUsage(
-                               prompt_tokens=33,
-                               completion_tokens=33,
-                               total_tokens=66,
-                               estimated_cost=0.2 / 3,
-                           ),
-                           tags=[lf.Message.TAG_LM_RESPONSE],
-                       ),
-                       score=0.0,
-                       logprobs=None,
-                   ),
-               ],
-               usage=lf.LMSamplingUsage(
-                   prompt_tokens=100, completion_tokens=100, total_tokens=200,
-                   estimated_cost=0.2,
-               ),
-           ),
-       )
+   def test_estimate_cost(self):
      self.assertEqual(
-         results[1],
-         lf.LMSamplingResult(
-             [
-                 lf.LMSample(
-                     lf.AIMessage(
-                         'Sample 0 for message.',
-                         score=0.0,
-                         logprobs=None,
-                         is_cached=False,
-                         usage=lf.LMSamplingUsage(
-                             prompt_tokens=33,
-                             completion_tokens=33,
-                             total_tokens=66,
-                             estimated_cost=0.2 / 3,
-                         ),
-                         tags=[lf.Message.TAG_LM_RESPONSE],
-                     ),
-                     score=0.0,
-                     logprobs=None,
-                 ),
-                 lf.LMSample(
-                     lf.AIMessage(
-                         'Sample 1 for message.',
-                         score=0.0,
-                         logprobs=None,
-                         is_cached=False,
-                         usage=lf.LMSamplingUsage(
-                             prompt_tokens=33,
-                             completion_tokens=33,
-                             total_tokens=66,
-                             estimated_cost=0.2 / 3,
-                         ),
-                         tags=[lf.Message.TAG_LM_RESPONSE],
-                     ),
-                     score=0.0,
-                     logprobs=None,
-                 ),
-                 lf.LMSample(
-                     lf.AIMessage(
-                         'Sample 2 for message.',
-                         score=0.0,
-                         logprobs=None,
-                         is_cached=False,
-                         usage=lf.LMSamplingUsage(
-                             prompt_tokens=33,
-                             completion_tokens=33,
-                             total_tokens=66,
-                             estimated_cost=0.2 / 3,
-                         ),
-                         tags=[lf.Message.TAG_LM_RESPONSE],
-                     ),
-                     score=0.0,
-                     logprobs=None,
-                 ),
-             ],
-             usage=lf.LMSamplingUsage(
-                 prompt_tokens=100, completion_tokens=100, total_tokens=200,
-                 estimated_cost=0.2,
-             ),
+         deepseek.DeepSeekChat(api_key='test_key').estimate_cost(
+             num_input_tokens=100, num_output_tokens=100
          ),
+         4.2e-5
      )

-   def test_sample_with_contextual_options(self):
-     with mock.patch('requests.Session.post') as mock_request:
-       mock_request.side_effect = mock_chat_completion_request
-       lm = deepseek.DeepSeek(api_key='test_key', model='deepseek-chat')
-       with lf.use_settings(sampling_options=lf.LMSamplingOptions(n=2)):
-         results = lm.sample(['hello'])
-
-       self.assertEqual(len(results), 1)
-       expected = lf.LMSamplingResult(
-           samples=[
-               lf.LMSample(
-                   response=lf.AIMessage(
-                       text='Sample 0 for message.',
-                       sender='AI',
-                       metadata=pg.Dict(
-                           score=0.0,
-                           logprobs=None,
-                           is_cached=False,
-                           usage=lf.LMSamplingUsage(
-                               prompt_tokens=50,
-                               completion_tokens=50,
-                               total_tokens=100,
-                               num_requests=1,
-                               estimated_cost=0.1,
-                           ),
-                       ),
-                       tags=['lm-response'],
-                   ),
-                   score=0.0,
-                   logprobs=None,
-               ),
-               lf.LMSample(
-                   response=lf.AIMessage(
-                       text='Sample 1 for message.',
-                       sender='AI',
-                       metadata=pg.Dict(
-                           score=0.0,
-                           logprobs=None,
-                           is_cached=False,
-                           usage=lf.LMSamplingUsage(
-                               prompt_tokens=50,
-                               completion_tokens=50,
-                               total_tokens=100,
-                               num_requests=1,
-                               estimated_cost=0.1,
-                           ),
-                       ),
-                       tags=['lm-response'],
-                   ),
-                   score=0.0,
-                   logprobs=None,
-               ),
-           ],
-           usage=lf.LMSamplingUsage(
-               prompt_tokens=100,
-               completion_tokens=100,
-               total_tokens=200,
-               num_requests=1,
-               estimated_cost=0.2,
-           ),
-           is_cached=False,
-       )
-       self.assertTrue(pg.eq(results[0], expected))
-
-   def test_call_with_system_message(self):
-     with mock.patch('requests.Session.post') as mock_request:
-       mock_request.side_effect = mock_chat_completion_request
-       lm = deepseek.DeepSeek(api_key='test_key', model='deepseek-chat')
-       self.assertEqual(
-           lm(
-               lf.UserMessage(
-                   'hello',
-                   system_message='hi',
-               ),
-               sampling_options=lf.LMSamplingOptions(n=2)
-           ),
-           '''Sample 0 for message. system=[{'type': 'text', 'text': 'hi'}]''',
-       )
-
-   def test_call_with_json_schema(self):
-     with mock.patch('requests.Session.post') as mock_request:
-       mock_request.side_effect = mock_chat_completion_request
-       lm = deepseek.DeepSeek(api_key='test_key', model='deepseek-chat')
-       self.assertEqual(
-           lm(
-               lf.UserMessage(
-                   'hello',
-                   json_schema={
-                       'type': 'object',
-                       'properties': {
-                           'name': {'type': 'string'},
-                       },
-                       'required': ['name'],
-                       'title': 'Person',
-                   }
-               ),
-               sampling_options=lf.LMSamplingOptions(n=2)
-           ),
-           'Sample 0 for message. format=json_schema',
-       )
-
-       # Test bad json schema.
-       with self.assertRaisesRegex(ValueError, '`json_schema` must be a dict'):
-         lm(lf.UserMessage('hello', json_schema='foo'))
-
-       with self.assertRaisesRegex(
-           ValueError, 'The root of `json_schema` must have a `title` field'
-       ):
-         lm(lf.UserMessage('hello', json_schema={}))
-
-
  if __name__ == '__main__':
    unittest.main()
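Taken together, the test changes drop the HTTP-level mocking (the request/response plumbing is presumably now covered by tests for the shared OpenAI-compatible base) and assert only on DeepSeek-specific surface: auth headers and cost estimation. A hedged usage sketch mirroring the assertions above; the per-token rates named in the comment are an assumption that happens to be consistent with the asserted 4.2e-5 figure, not something stated in this diff:

    from langfun.core.llms import deepseek

    lm = deepseek.DeepSeekChat(api_key='test_key')
    print(lm.headers)
    # {'Content-Type': 'application/json', 'Authorization': 'Bearer test_key'}

    # 100 input + 100 output tokens -> 4.2e-5, per the test expectation above;
    # consistent with e.g. $0.14 per million input tokens and $0.28 per million
    # output tokens: 100 * 0.14e-6 + 100 * 0.28e-6 = 4.2e-5.
    print(lm.estimate_cost(num_input_tokens=100, num_output_tokens=100))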