langfun 0.1.2.dev202501090804__py3-none-any.whl → 0.1.2.dev202501100804__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langfun/core/llms/__init__.py +3 -0
- langfun/core/llms/deepseek.py +8 -152
- langfun/core/llms/deepseek_test.py +12 -389
- langfun/core/llms/groq.py +12 -99
- langfun/core/llms/groq_test.py +31 -137
- langfun/core/llms/llama_cpp.py +17 -54
- langfun/core/llms/llama_cpp_test.py +2 -34
- langfun/core/llms/openai.py +9 -147
- langfun/core/llms/openai_compatible.py +179 -0
- langfun/core/llms/openai_compatible_test.py +480 -0
- langfun/core/llms/openai_test.py +13 -423
- langfun/core/modalities/mime.py +8 -0
- langfun/core/modalities/mime_test.py +19 -4
- langfun/core/modality_test.py +0 -1
- {langfun-0.1.2.dev202501090804.dist-info → langfun-0.1.2.dev202501100804.dist-info}/METADATA +1 -1
- {langfun-0.1.2.dev202501090804.dist-info → langfun-0.1.2.dev202501100804.dist-info}/RECORD +19 -17
- {langfun-0.1.2.dev202501090804.dist-info → langfun-0.1.2.dev202501100804.dist-info}/LICENSE +0 -0
- {langfun-0.1.2.dev202501090804.dist-info → langfun-0.1.2.dev202501100804.dist-info}/WHEEL +0 -0
- {langfun-0.1.2.dev202501090804.dist-info → langfun-0.1.2.dev202501100804.dist-info}/top_level.txt +0 -0
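The common thread in this diff is a new OpenAI-compatible base class: `langfun/core/llms/openai_compatible.py` introduces `OpenAICompatible`, and the DeepSeek, Groq, llama.cpp and OpenAI wrappers are trimmed to subclass it instead of carrying their own REST request/response plumbing. A minimal usage sketch (not part of the diff; the constructor arguments mirror those used in the tests below, and the call pattern is ordinary Langfun usage):

    from langfun.core.llms import OpenAICompatible        # newly exported in __init__.py
    from langfun.core.llms.deepseek import DeepSeekChat

    lm = DeepSeekChat(api_key='my-key')      # or set DEEPSEEK_API_KEY in the environment
    assert isinstance(lm, OpenAICompatible)  # DeepSeek now derives from the shared base
    print(lm('hello'))                       # sends an OpenAI-style chat completion request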
langfun/core/llms/__init__.py
CHANGED
@@ -57,6 +57,9 @@ from langfun.core.llms.vertexai import VertexAIGeminiFlash1_5_002
 from langfun.core.llms.vertexai import VertexAIGeminiFlash1_5_001
 from langfun.core.llms.vertexai import VertexAIGeminiPro1
 
+# Base for OpenAI-compatible models.
+from langfun.core.llms.openai_compatible import OpenAICompatible
+
 # OpenAI models.
 from langfun.core.llms.openai import OpenAI
 
langfun/core/llms/deepseek.py
CHANGED
@@ -17,8 +17,7 @@ import os
 from typing import Annotated, Any
 
 import langfun.core as lf
-from langfun.core import modalities as lf_modalities
-from langfun.core.llms import rest
+from langfun.core.llms import openai_compatible
 import pyglove as pg
 
 SUPPORTED_MODELS_AND_SETTINGS = {
@@ -39,7 +38,7 @@ SUPPORTED_MODELS_AND_SETTINGS = {
 # DeepSeek API uses an API format compatible with OpenAI.
 # Reference: https://api-docs.deepseek.com/
 @lf.use_init_args(['model'])
-class DeepSeek(rest.REST):
+class DeepSeek(openai_compatible.OpenAICompatible):
   """DeepSeek model."""
 
   model: pg.typing.Annotated[
@@ -51,10 +50,6 @@ class DeepSeek(rest.REST):
 
   api_endpoint: str = 'https://api.deepseek.com/chat/completions'
 
-  multimodal: Annotated[bool, 'Whether this model has multimodal support.'] = (
-      False
-  )
-
   api_key: Annotated[
       str | None,
       (
@@ -63,25 +58,18 @@ class DeepSeek(rest.REST):
       ),
   ] = None
 
-
-
-    self._api_key = None
-
-  def _initialize(self):
+  @property
+  def headers(self) -> dict[str, Any]:
     api_key = self.api_key or os.environ.get('DEEPSEEK_API_KEY', None)
     if not api_key:
       raise ValueError(
           'Please specify `api_key` during `__init__` or set environment '
           'variable `DEEPSEEK_API_KEY` with your DeepSeek API key.'
       )
-
-
-
-
-    headers = {
-        'Content-Type': 'application/json',
-        'Authorization': f'Bearer {self._api_key}',
-    }
+    headers = super().headers
+    headers.update({
+        'Authorization': f'Bearer {api_key}',
+    })
     return headers
 
   @property
@@ -118,138 +106,6 @@ class DeepSeek(rest.REST):
   def dir(cls):
     return [k for k, v in SUPPORTED_MODELS_AND_SETTINGS.items() if v.in_service]
 
-  def _request_args(self, options: lf.LMSamplingOptions) -> dict[str, Any]:
-    # Reference:
-    # https://platform.openai.com/docs/api-reference/completions/create
-    # NOTE(daiyip): options.top_k is not applicable.
-    args = dict(
-        model=self.model,
-        n=options.n,
-        top_logprobs=options.top_logprobs,
-    )
-    if options.logprobs:
-      args['logprobs'] = options.logprobs
-
-    if options.temperature is not None:
-      args['temperature'] = options.temperature
-    if options.max_tokens is not None:
-      args['max_completion_tokens'] = options.max_tokens
-    if options.top_p is not None:
-      args['top_p'] = options.top_p
-    if options.stop:
-      args['stop'] = options.stop
-    if options.random_seed is not None:
-      args['seed'] = options.random_seed
-    return args
-
-  def _content_from_message(self, message: lf.Message):
-    """Returns a OpenAI content object from a Langfun message."""
-
-    def _uri_from(chunk: lf.Modality) -> str:
-      if chunk.uri and chunk.uri.lower().startswith(
-          ('http:', 'https:', 'ftp:')
-      ):
-        return chunk.uri
-      return chunk.content_uri
-
-    content = []
-    for chunk in message.chunk():
-      if isinstance(chunk, str):
-        item = dict(type='text', text=chunk)
-      elif isinstance(chunk, lf_modalities.Image) and self.multimodal:
-        item = dict(type='image_url', image_url=dict(url=_uri_from(chunk)))
-      else:
-        raise ValueError(f'Unsupported modality: {chunk!r}.')
-      content.append(item)
-    return content
-
-  def request(
-      self, prompt: lf.Message, sampling_options: lf.LMSamplingOptions
-  ) -> dict[str, Any]:
-    """Returns the JSON input for a message."""
-    request_args = self._request_args(sampling_options)
-
-    # Users could use `metadata_json_schema` to pass additional
-    # request arguments.
-    json_schema = prompt.metadata.get('json_schema')
-    if json_schema is not None:
-      if not isinstance(json_schema, dict):
-        raise ValueError(f'`json_schema` must be a dict, got {json_schema!r}.')
-      if 'title' not in json_schema:
-        raise ValueError(
-            'The root of `json_schema` must have a `title` field, '
-            f'got {json_schema!r}.'
-        )
-      request_args.update(
-          response_format=dict(
-              type='json_schema',
-              json_schema=dict(
-                  schema=json_schema,
-                  name=json_schema['title'],
-                  strict=True,
-              ),
-          )
-      )
-      prompt.metadata.formatted_text = (
-          prompt.text
-          + '\n\n [RESPONSE FORMAT (not part of prompt)]\n'
-          + pg.to_json_str(request_args['response_format'], json_indent=2)
-      )
-
-    # Prepare messages.
-    messages = []
-    # Users could use `metadata_system_message` to pass system message.
-    system_message = prompt.metadata.get('system_message')
-    if system_message:
-      system_message = lf.SystemMessage.from_value(system_message)
-      messages.append(
-          dict(
-              role='system', content=self._content_from_message(system_message)
-          )
-      )
-    messages.append(
-        dict(role='user', content=self._content_from_message(prompt))
-    )
-    request = dict()
-    request.update(request_args)
-    request['messages'] = messages
-    return request
-
-  def _parse_choice(self, choice: dict[str, Any]) -> lf.LMSample:
-    # Reference:
-    # https://platform.openai.com/docs/api-reference/chat/object
-    logprobs = None
-    choice_logprobs = choice.get('logprobs')
-    if choice_logprobs:
-      logprobs = [
-          (
-              t['token'],
-              t['logprob'],
-              [(tt['token'], tt['logprob']) for tt in t['top_logprobs']],
-          )
-          for t in choice_logprobs['content']
-      ]
-    return lf.LMSample(
-        choice['message']['content'],
-        score=0.0,
-        logprobs=logprobs,
-    )
-
-  def result(self, json: dict[str, Any]) -> lf.LMSamplingResult:
-    usage = json['usage']
-    return lf.LMSamplingResult(
-        samples=[self._parse_choice(choice) for choice in json['choices']],
-        usage=lf.LMSamplingUsage(
-            prompt_tokens=usage['prompt_tokens'],
-            completion_tokens=usage['completion_tokens'],
-            total_tokens=usage['total_tokens'],
-            estimated_cost=self.estimate_cost(
-                num_input_tokens=usage['prompt_tokens'],
-                num_output_tokens=usage['completion_tokens'],
-            ),
-        ),
-    )
-
 
 class DeepSeekChat(DeepSeek):
   """DeepSeek Chat model.
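The net effect of the deepseek.py change is that the subclass now only contributes provider specifics: the endpoint, the model table, and the `Authorization` header layered on top of the base class's defaults. The new test below pins this down; as a standalone sketch (not part of the diff):

    from langfun.core.llms import deepseek

    lm = deepseek.DeepSeekChat(api_key='test_key')
    # `Content-Type` comes from OpenAICompatible.headers; the subclass adds the bearer token.
    assert lm.headers == {
        'Content-Type': 'application/json',
        'Authorization': 'Bearer test_key',
    }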
langfun/core/llms/deepseek_test.py
CHANGED
@@ -11,72 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Tests for OpenAI models."""
-
-from typing import Any
 import unittest
-from unittest import mock
-
-import langfun.core as lf
 from langfun.core.llms import deepseek
-import pyglove as pg
-import requests
-
-
-def mock_chat_completion_request(url: str, json: dict[str, Any], **kwargs):
-  del url, kwargs
-  messages = json['messages']
-  if len(messages) > 1:
-    system_message = f' system={messages[0]["content"]}'
-  else:
-    system_message = ''
-
-  if 'response_format' in json:
-    response_format = f' format={json["response_format"]["type"]}'
-  else:
-    response_format = ''
-
-  choices = []
-  for k in range(json['n']):
-    if json.get('logprobs'):
-      logprobs = dict(
-          content=[
-              dict(
-                  token='chosen_token',
-                  logprob=0.5,
-                  top_logprobs=[
-                      dict(
-                          token=f'alternative_token_{i + 1}',
-                          logprob=0.1
-                      ) for i in range(3)
-                  ]
-              )
-          ]
-      )
-    else:
-      logprobs = None
-
-    choices.append(dict(
-        message=dict(
-            content=(
-                f'Sample {k} for message.{system_message}{response_format}'
-            )
-        ),
-        logprobs=logprobs,
-    ))
-  response = requests.Response()
-  response.status_code = 200
-  response._content = pg.to_json_str(
-      dict(
-          choices=choices,
-          usage=lf.LMSamplingUsage(
-              prompt_tokens=100,
-              completion_tokens=100,
-              total_tokens=200,
-          ),
-      )
-  ).encode()
-  return response
 
 
 class DeepSeekTest(unittest.TestCase):
@@ -87,7 +23,14 @@ class DeepSeekTest(unittest.TestCase):
 
   def test_key(self):
     with self.assertRaisesRegex(ValueError, 'Please specify `api_key`'):
-      deepseek.DeepSeekChat()
+      _ = deepseek.DeepSeekChat().headers
+    self.assertEqual(
+        deepseek.DeepSeekChat(api_key='test_key').headers,
+        {
+            'Content-Type': 'application/json',
+            'Authorization': 'Bearer test_key',
+        }
+    )
 
   def test_model_id(self):
     self.assertEqual(
@@ -106,333 +49,13 @@ class DeepSeekTest(unittest.TestCase):
         deepseek.DeepSeekChat(api_key='test_key').max_concurrency, 0
     )
 
-  def test_request_args(self):
-    self.assertEqual(
-        deepseek.DeepSeekChat(api_key='test_key')._request_args(
-            lf.LMSamplingOptions(
-                temperature=1.0, stop=['\n'], n=1, random_seed=123
-            )
-        ),
-        dict(
-            model='deepseek-chat',
-            top_logprobs=None,
-            n=1,
-            temperature=1.0,
-            stop=['\n'],
-            seed=123,
-        ),
-    )
-
-  def test_call_chat_completion(self):
-    with mock.patch('requests.Session.post') as mock_request:
-      mock_request.side_effect = mock_chat_completion_request
-      lm = deepseek.DeepSeek(model='deepseek-chat', api_key='test_key')
-      self.assertEqual(
-          lm('hello', sampling_options=lf.LMSamplingOptions(n=2)),
-          'Sample 0 for message.',
-      )
-
-  def test_call_chat_completion_with_logprobs(self):
-    with mock.patch('requests.Session.post') as mock_request:
-      mock_request.side_effect = mock_chat_completion_request
-      lm = deepseek.DeepSeek(model='deepseek-chat', api_key='test_key')
-      results = lm.sample(['hello'], logprobs=True)
-      self.assertEqual(len(results), 1)
-      expected = lf.LMSamplingResult(
-          [
-              lf.LMSample(
-                  response=lf.AIMessage(
-                      text='Sample 0 for message.',
-                      metadata={
-                          'score': 0.0,
-                          'logprobs': [(
-                              'chosen_token',
-                              0.5,
-                              [
-                                  ('alternative_token_1', 0.1),
-                                  ('alternative_token_2', 0.1),
-                                  ('alternative_token_3', 0.1),
-                              ],
-                          )],
-                          'is_cached': False,
-                          'usage': lf.LMSamplingUsage(
-                              prompt_tokens=100,
-                              completion_tokens=100,
-                              total_tokens=200,
-                              estimated_cost=4.2e-05,
-                          ),
-                      },
-                      tags=['lm-response'],
-                  ),
-                  logprobs=[(
-                      'chosen_token',
-                      0.5,
-                      [
-                          ('alternative_token_1', 0.1),
-                          ('alternative_token_2', 0.1),
-                          ('alternative_token_3', 0.1),
-                      ],
-                  )],
-              )
-          ],
-          usage=lf.LMSamplingUsage(
-              prompt_tokens=100,
-              completion_tokens=100,
-              total_tokens=200,
-              estimated_cost=4.2e-05,
-          ),
-      )
-      self.assertTrue(pg.eq(results[0], expected))
-
-  def test_sample_chat_completion(self):
-    with mock.patch('requests.Session.post') as mock_request:
-      mock_request.side_effect = mock_chat_completion_request
-      deepseek.SUPPORTED_MODELS_AND_SETTINGS['deepseek-chat'].update({
-          'cost_per_1k_input_tokens': 1.0,
-          'cost_per_1k_output_tokens': 1.0,
-      })
-      lm = deepseek.DeepSeek(api_key='test_key', model='deepseek-chat')
-      results = lm.sample(
-          ['hello', 'bye'], sampling_options=lf.LMSamplingOptions(n=3)
-      )
-
-      self.assertEqual(len(results), 2)
-      print(results[0])
-      self.assertEqual(
-          results[0],
-          lf.LMSamplingResult(
-              [
-                  lf.LMSample(
-                      lf.AIMessage(
-                          'Sample 0 for message.',
-                          score=0.0,
-                          logprobs=None,
-                          is_cached=False,
-                          usage=lf.LMSamplingUsage(
-                              prompt_tokens=33,
-                              completion_tokens=33,
-                              total_tokens=66,
-                              estimated_cost=0.2 / 3,
-                          ),
-                          tags=[lf.Message.TAG_LM_RESPONSE],
-                      ),
-                      score=0.0,
-                      logprobs=None,
-                  ),
-                  lf.LMSample(
-                      lf.AIMessage(
-                          'Sample 1 for message.',
-                          score=0.0,
-                          logprobs=None,
-                          is_cached=False,
-                          usage=lf.LMSamplingUsage(
-                              prompt_tokens=33,
-                              completion_tokens=33,
-                              total_tokens=66,
-                              estimated_cost=0.2 / 3,
-                          ),
-                          tags=[lf.Message.TAG_LM_RESPONSE],
-                      ),
-                      score=0.0,
-                      logprobs=None,
-                  ),
-                  lf.LMSample(
-                      lf.AIMessage(
-                          'Sample 2 for message.',
-                          score=0.0,
-                          logprobs=None,
-                          is_cached=False,
-                          usage=lf.LMSamplingUsage(
-                              prompt_tokens=33,
-                              completion_tokens=33,
-                              total_tokens=66,
-                              estimated_cost=0.2 / 3,
-                          ),
-                          tags=[lf.Message.TAG_LM_RESPONSE],
-                      ),
-                      score=0.0,
-                      logprobs=None,
-                  ),
-              ],
-              usage=lf.LMSamplingUsage(
-                  prompt_tokens=100, completion_tokens=100, total_tokens=200,
-                  estimated_cost=0.2,
-              ),
-          ),
-      )
+  def test_estimate_cost(self):
     self.assertEqual(
-          results[1],
-          lf.LMSamplingResult(
-              [
-                  lf.LMSample(
-                      lf.AIMessage(
-                          'Sample 0 for message.',
-                          score=0.0,
-                          logprobs=None,
-                          is_cached=False,
-                          usage=lf.LMSamplingUsage(
-                              prompt_tokens=33,
-                              completion_tokens=33,
-                              total_tokens=66,
-                              estimated_cost=0.2 / 3,
-                          ),
-                          tags=[lf.Message.TAG_LM_RESPONSE],
-                      ),
-                      score=0.0,
-                      logprobs=None,
-                  ),
-                  lf.LMSample(
-                      lf.AIMessage(
-                          'Sample 1 for message.',
-                          score=0.0,
-                          logprobs=None,
-                          is_cached=False,
-                          usage=lf.LMSamplingUsage(
-                              prompt_tokens=33,
-                              completion_tokens=33,
-                              total_tokens=66,
-                              estimated_cost=0.2 / 3,
-                          ),
-                          tags=[lf.Message.TAG_LM_RESPONSE],
-                      ),
-                      score=0.0,
-                      logprobs=None,
-                  ),
-                  lf.LMSample(
-                      lf.AIMessage(
-                          'Sample 2 for message.',
-                          score=0.0,
-                          logprobs=None,
-                          is_cached=False,
-                          usage=lf.LMSamplingUsage(
-                              prompt_tokens=33,
-                              completion_tokens=33,
-                              total_tokens=66,
-                              estimated_cost=0.2 / 3,
-                          ),
-                          tags=[lf.Message.TAG_LM_RESPONSE],
-                      ),
-                      score=0.0,
-                      logprobs=None,
-                  ),
-              ],
-              usage=lf.LMSamplingUsage(
-                  prompt_tokens=100, completion_tokens=100, total_tokens=200,
-                  estimated_cost=0.2,
-              ),
+        deepseek.DeepSeekChat(api_key='test_key').estimate_cost(
+            num_input_tokens=100, num_output_tokens=100
         ),
+        4.2e-5
     )
 
-  def test_sample_with_contextual_options(self):
-    with mock.patch('requests.Session.post') as mock_request:
-      mock_request.side_effect = mock_chat_completion_request
-      lm = deepseek.DeepSeek(api_key='test_key', model='deepseek-chat')
-      with lf.use_settings(sampling_options=lf.LMSamplingOptions(n=2)):
-        results = lm.sample(['hello'])
-
-      self.assertEqual(len(results), 1)
-      expected = lf.LMSamplingResult(
-          samples=[
-              lf.LMSample(
-                  response=lf.AIMessage(
-                      text='Sample 0 for message.',
-                      sender='AI',
-                      metadata=pg.Dict(
-                          score=0.0,
-                          logprobs=None,
-                          is_cached=False,
-                          usage=lf.LMSamplingUsage(
-                              prompt_tokens=50,
-                              completion_tokens=50,
-                              total_tokens=100,
-                              num_requests=1,
-                              estimated_cost=0.1,
-                          ),
-                      ),
-                      tags=['lm-response'],
-                  ),
-                  score=0.0,
-                  logprobs=None,
-              ),
-              lf.LMSample(
-                  response=lf.AIMessage(
-                      text='Sample 1 for message.',
-                      sender='AI',
-                      metadata=pg.Dict(
-                          score=0.0,
-                          logprobs=None,
-                          is_cached=False,
-                          usage=lf.LMSamplingUsage(
-                              prompt_tokens=50,
-                              completion_tokens=50,
-                              total_tokens=100,
-                              num_requests=1,
-                              estimated_cost=0.1,
-                          ),
-                      ),
-                      tags=['lm-response'],
-                  ),
-                  score=0.0,
-                  logprobs=None,
-              ),
-          ],
-          usage=lf.LMSamplingUsage(
-              prompt_tokens=100,
-              completion_tokens=100,
-              total_tokens=200,
-              num_requests=1,
-              estimated_cost=0.2,
-          ),
-          is_cached=False,
-      )
-      self.assertTrue(pg.eq(results[0], expected))
-
-  def test_call_with_system_message(self):
-    with mock.patch('requests.Session.post') as mock_request:
-      mock_request.side_effect = mock_chat_completion_request
-      lm = deepseek.DeepSeek(api_key='test_key', model='deepseek-chat')
-      self.assertEqual(
-          lm(
-              lf.UserMessage(
-                  'hello',
-                  system_message='hi',
-              ),
-              sampling_options=lf.LMSamplingOptions(n=2)
-          ),
-          '''Sample 0 for message. system=[{'type': 'text', 'text': 'hi'}]''',
-      )
-
-  def test_call_with_json_schema(self):
-    with mock.patch('requests.Session.post') as mock_request:
-      mock_request.side_effect = mock_chat_completion_request
-      lm = deepseek.DeepSeek(api_key='test_key', model='deepseek-chat')
-      self.assertEqual(
-          lm(
-              lf.UserMessage(
-                  'hello',
-                  json_schema={
-                      'type': 'object',
-                      'properties': {
-                          'name': {'type': 'string'},
-                      },
-                      'required': ['name'],
-                      'title': 'Person',
-                  }
-              ),
-              sampling_options=lf.LMSamplingOptions(n=2)
-          ),
-          'Sample 0 for message. format=json_schema',
-      )
-
-      # Test bad json schema.
-      with self.assertRaisesRegex(ValueError, '`json_schema` must be a dict'):
-        lm(lf.UserMessage('hello', json_schema='foo'))
-
-      with self.assertRaisesRegex(
-          ValueError, 'The root of `json_schema` must have a `title` field'
-      ):
-        lm(lf.UserMessage('hello', json_schema={}))
-
-
 if __name__ == '__main__':
   unittest.main()