langfun 0.1.2.dev202501060804__py3-none-any.whl → 0.1.2.dev202501090804__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (29)
  1. langfun/core/__init__.py +0 -5
  2. langfun/core/coding/python/correction.py +4 -3
  3. langfun/core/coding/python/errors.py +10 -9
  4. langfun/core/coding/python/execution.py +23 -12
  5. langfun/core/coding/python/execution_test.py +21 -2
  6. langfun/core/coding/python/generation.py +18 -9
  7. langfun/core/concurrent.py +2 -3
  8. langfun/core/console.py +8 -3
  9. langfun/core/eval/base.py +2 -3
  10. langfun/core/eval/v2/reporting.py +8 -4
  11. langfun/core/language_model.py +7 -4
  12. langfun/core/language_model_test.py +15 -0
  13. langfun/core/llms/__init__.py +4 -0
  14. langfun/core/llms/deepseek.py +261 -0
  15. langfun/core/llms/deepseek_test.py +438 -0
  16. langfun/core/llms/google_genai.py +1 -0
  17. langfun/core/llms/openai.py +5 -0
  18. langfun/core/llms/vertexai.py +6 -2
  19. langfun/core/llms/vertexai_test.py +1 -1
  20. langfun/core/structured/mapping.py +13 -13
  21. langfun/core/structured/mapping_test.py +2 -2
  22. langfun/core/structured/schema.py +16 -8
  23. {langfun-0.1.2.dev202501060804.dist-info → langfun-0.1.2.dev202501090804.dist-info}/METADATA +13 -2
  24. {langfun-0.1.2.dev202501060804.dist-info → langfun-0.1.2.dev202501090804.dist-info}/RECORD +27 -27
  25. {langfun-0.1.2.dev202501060804.dist-info → langfun-0.1.2.dev202501090804.dist-info}/WHEEL +1 -1
  26. langfun/core/text_formatting.py +0 -168
  27. langfun/core/text_formatting_test.py +0 -65
  28. {langfun-0.1.2.dev202501060804.dist-info → langfun-0.1.2.dev202501090804.dist-info}/LICENSE +0 -0
  29. {langfun-0.1.2.dev202501060804.dist-info → langfun-0.1.2.dev202501090804.dist-info}/top_level.txt +0 -0
langfun/core/llms/deepseek.py
@@ -0,0 +1,261 @@
+# Copyright 2024 The Langfun Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Language models from DeepSeek."""
+
+import os
+from typing import Annotated, Any
+
+import langfun.core as lf
+from langfun.core import modalities as lf_modalities
+from langfun.core.llms import rest
+import pyglove as pg
+
+SUPPORTED_MODELS_AND_SETTINGS = {
+    # pylint: disable=g-line-too-long
+    # TODO(yifenglu): The RPM and TPM are arbitrary numbers. Update them once DeepSeek provides concrete guidelines.
+    # DeepSeek doesn't control the rate limit at the moment: https://api-docs.deepseek.com/quick_start/rate_limit
+    # The cost is based on: https://api-docs.deepseek.com/quick_start/pricing
+    'deepseek-chat': pg.Dict(
+        in_service=True,
+        rpm=100,
+        tpm=1000000,
+        cost_per_1k_input_tokens=0.00014,
+        cost_per_1k_output_tokens=0.00028,
+    ),
+}
+
+
+# DeepSeek API uses an API format compatible with OpenAI.
+# Reference: https://api-docs.deepseek.com/
+@lf.use_init_args(['model'])
+class DeepSeek(rest.REST):
+  """DeepSeek model."""
+
+  model: pg.typing.Annotated[
+      pg.typing.Enum(
+          pg.MISSING_VALUE, list(SUPPORTED_MODELS_AND_SETTINGS.keys())
+      ),
+      'The name of the model to use.',
+  ]
+
+  api_endpoint: str = 'https://api.deepseek.com/chat/completions'
+
+  multimodal: Annotated[bool, 'Whether this model has multimodal support.'] = (
+      False
+  )
+
+  api_key: Annotated[
+      str | None,
+      (
+          'API key. If None, the key will be read from environment variable '
+          "'DEEPSEEK_API_KEY'."
+      ),
+  ] = None
+
+  def _on_bound(self):
+    super()._on_bound()
+    self._api_key = None
+
+  def _initialize(self):
+    api_key = self.api_key or os.environ.get('DEEPSEEK_API_KEY', None)
+    if not api_key:
+      raise ValueError(
+          'Please specify `api_key` during `__init__` or set environment '
+          'variable `DEEPSEEK_API_KEY` with your DeepSeek API key.'
+      )
+    self._api_key = api_key
+
+  @property
+  def headers(self) -> dict[str, Any]:
+    headers = {
+        'Content-Type': 'application/json',
+        'Authorization': f'Bearer {self._api_key}',
+    }
+    return headers
+
+  @property
+  def model_id(self) -> str:
+    """Returns a string to identify the model."""
+    return f'DeepSeek({self.model})'
+
+  @property
+  def max_concurrency(self) -> int:
+    rpm = SUPPORTED_MODELS_AND_SETTINGS[self.model].get('rpm', 0)
+    tpm = SUPPORTED_MODELS_AND_SETTINGS[self.model].get('tpm', 0)
+    return self.rate_to_max_concurrency(
+        requests_per_min=rpm, tokens_per_min=tpm
+    )
+
+  def estimate_cost(
+      self, num_input_tokens: int, num_output_tokens: int
+  ) -> float | None:
+    """Estimate the cost based on usage."""
+    cost_per_1k_input_tokens = SUPPORTED_MODELS_AND_SETTINGS[self.model].get(
+        'cost_per_1k_input_tokens', None
+    )
+    cost_per_1k_output_tokens = SUPPORTED_MODELS_AND_SETTINGS[self.model].get(
+        'cost_per_1k_output_tokens', None
+    )
+    if cost_per_1k_output_tokens is None or cost_per_1k_input_tokens is None:
+      return None
+    return (
+        cost_per_1k_input_tokens * num_input_tokens
+        + cost_per_1k_output_tokens * num_output_tokens
+    ) / 1000
+
+  @classmethod
+  def dir(cls):
+    return [k for k, v in SUPPORTED_MODELS_AND_SETTINGS.items() if v.in_service]
+
+  def _request_args(self, options: lf.LMSamplingOptions) -> dict[str, Any]:
+    # Reference:
+    # https://platform.openai.com/docs/api-reference/completions/create
+    # NOTE(daiyip): options.top_k is not applicable.
+    args = dict(
+        model=self.model,
+        n=options.n,
+        top_logprobs=options.top_logprobs,
+    )
+    if options.logprobs:
+      args['logprobs'] = options.logprobs
+
+    if options.temperature is not None:
+      args['temperature'] = options.temperature
+    if options.max_tokens is not None:
+      args['max_completion_tokens'] = options.max_tokens
+    if options.top_p is not None:
+      args['top_p'] = options.top_p
+    if options.stop:
+      args['stop'] = options.stop
+    if options.random_seed is not None:
+      args['seed'] = options.random_seed
+    return args
+
+  def _content_from_message(self, message: lf.Message):
+    """Returns an OpenAI-compatible content object from a Langfun message."""
+
+    def _uri_from(chunk: lf.Modality) -> str:
+      if chunk.uri and chunk.uri.lower().startswith(
+          ('http:', 'https:', 'ftp:')
+      ):
+        return chunk.uri
+      return chunk.content_uri
+
+    content = []
+    for chunk in message.chunk():
+      if isinstance(chunk, str):
+        item = dict(type='text', text=chunk)
+      elif isinstance(chunk, lf_modalities.Image) and self.multimodal:
+        item = dict(type='image_url', image_url=dict(url=_uri_from(chunk)))
+      else:
+        raise ValueError(f'Unsupported modality: {chunk!r}.')
+      content.append(item)
+    return content
+
+  def request(
+      self, prompt: lf.Message, sampling_options: lf.LMSamplingOptions
+  ) -> dict[str, Any]:
+    """Returns the JSON input for a message."""
+    request_args = self._request_args(sampling_options)
+
+    # Users could use `metadata_json_schema` to pass additional
+    # request arguments.
+    json_schema = prompt.metadata.get('json_schema')
+    if json_schema is not None:
+      if not isinstance(json_schema, dict):
+        raise ValueError(f'`json_schema` must be a dict, got {json_schema!r}.')
+      if 'title' not in json_schema:
+        raise ValueError(
+            'The root of `json_schema` must have a `title` field, '
+            f'got {json_schema!r}.'
+        )
+      request_args.update(
+          response_format=dict(
+              type='json_schema',
+              json_schema=dict(
+                  schema=json_schema,
+                  name=json_schema['title'],
+                  strict=True,
+              ),
+          )
+      )
+      prompt.metadata.formatted_text = (
+          prompt.text
+          + '\n\n [RESPONSE FORMAT (not part of prompt)]\n'
+          + pg.to_json_str(request_args['response_format'], json_indent=2)
+      )
+
+    # Prepare messages.
+    messages = []
+    # Users could use `metadata_system_message` to pass system message.
+    system_message = prompt.metadata.get('system_message')
+    if system_message:
+      system_message = lf.SystemMessage.from_value(system_message)
+      messages.append(
+          dict(
+              role='system', content=self._content_from_message(system_message)
+          )
+      )
+    messages.append(
+        dict(role='user', content=self._content_from_message(prompt))
+    )
+    request = dict()
+    request.update(request_args)
+    request['messages'] = messages
+    return request
+
+  def _parse_choice(self, choice: dict[str, Any]) -> lf.LMSample:
+    # Reference:
+    # https://platform.openai.com/docs/api-reference/chat/object
+    logprobs = None
+    choice_logprobs = choice.get('logprobs')
+    if choice_logprobs:
+      logprobs = [
+          (
+              t['token'],
+              t['logprob'],
+              [(tt['token'], tt['logprob']) for tt in t['top_logprobs']],
+          )
+          for t in choice_logprobs['content']
+      ]
+    return lf.LMSample(
+        choice['message']['content'],
+        score=0.0,
+        logprobs=logprobs,
+    )
+
+  def result(self, json: dict[str, Any]) -> lf.LMSamplingResult:
+    usage = json['usage']
+    return lf.LMSamplingResult(
+        samples=[self._parse_choice(choice) for choice in json['choices']],
+        usage=lf.LMSamplingUsage(
+            prompt_tokens=usage['prompt_tokens'],
+            completion_tokens=usage['completion_tokens'],
+            total_tokens=usage['total_tokens'],
+            estimated_cost=self.estimate_cost(
+                num_input_tokens=usage['prompt_tokens'],
+                num_output_tokens=usage['completion_tokens'],
+            ),
+        ),
+    )
+
+
+class DeepSeekChat(DeepSeek):
+  """DeepSeek Chat model.
+
+  Currently, it is powered by the DeepSeek-V3 model, with a 64K input context
+  window and 8K max output tokens.
+  """
+
+  model = 'deepseek-chat'
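
For orientation, a minimal usage sketch of the new module, using only names introduced in this file and the direct-call pattern exercised by the tests below; the API key and prompt are placeholders:

from langfun.core.llms import deepseek

# Placeholder key; per the code above, setting DEEPSEEK_API_KEY in the
# environment works as well.
lm = deepseek.DeepSeekChat(api_key='<your-api-key>')

# Direct call, as in the tests: the returned message's text is the first
# sampled response.
print(lm('hello'))

# Pricing sanity check against SUPPORTED_MODELS_AND_SETTINGS:
# (0.00014 * 100 + 0.00028 * 100) / 1000 = 4.2e-05, which is the
# `estimated_cost` the tests expect for 100 input + 100 output tokens.
print(lm.estimate_cost(num_input_tokens=100, num_output_tokens=100))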
langfun/core/llms/deepseek_test.py
@@ -0,0 +1,438 @@
+# Copyright 2023 The Langfun Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests for DeepSeek models."""
+
+from typing import Any
+import unittest
+from unittest import mock
+
+import langfun.core as lf
+from langfun.core.llms import deepseek
+import pyglove as pg
+import requests
+
+
+def mock_chat_completion_request(url: str, json: dict[str, Any], **kwargs):
+  del url, kwargs
+  messages = json['messages']
+  if len(messages) > 1:
+    system_message = f' system={messages[0]["content"]}'
+  else:
+    system_message = ''
+
+  if 'response_format' in json:
+    response_format = f' format={json["response_format"]["type"]}'
+  else:
+    response_format = ''
+
+  choices = []
+  for k in range(json['n']):
+    if json.get('logprobs'):
+      logprobs = dict(
+          content=[
+              dict(
+                  token='chosen_token',
+                  logprob=0.5,
+                  top_logprobs=[
+                      dict(
+                          token=f'alternative_token_{i + 1}',
+                          logprob=0.1
+                      ) for i in range(3)
+                  ]
+              )
+          ]
+      )
+    else:
+      logprobs = None
+
+    choices.append(dict(
+        message=dict(
+            content=(
+                f'Sample {k} for message.{system_message}{response_format}'
+            )
+        ),
+        logprobs=logprobs,
+    ))
+  response = requests.Response()
+  response.status_code = 200
+  response._content = pg.to_json_str(
+      dict(
+          choices=choices,
+          usage=lf.LMSamplingUsage(
+              prompt_tokens=100,
+              completion_tokens=100,
+              total_tokens=200,
+          ),
+      )
+  ).encode()
+  return response
+
+
+class DeepSeekTest(unittest.TestCase):
+  """Tests for DeepSeek language model."""
+
+  def test_dir(self):
+    self.assertIn('deepseek-chat', deepseek.DeepSeek.dir())
+
+  def test_key(self):
+    with self.assertRaisesRegex(ValueError, 'Please specify `api_key`'):
+      deepseek.DeepSeekChat()('hi')
+
+  def test_model_id(self):
+    self.assertEqual(
+        deepseek.DeepSeekChat(api_key='test_key').model_id,
+        'DeepSeek(deepseek-chat)',
+    )
+
+  def test_resource_id(self):
+    self.assertEqual(
+        deepseek.DeepSeekChat(api_key='test_key').resource_id,
+        'DeepSeek(deepseek-chat)',
+    )
+
+  def test_max_concurrency(self):
+    self.assertGreater(
+        deepseek.DeepSeekChat(api_key='test_key').max_concurrency, 0
+    )
+
+  def test_request_args(self):
+    self.assertEqual(
+        deepseek.DeepSeekChat(api_key='test_key')._request_args(
+            lf.LMSamplingOptions(
+                temperature=1.0, stop=['\n'], n=1, random_seed=123
+            )
+        ),
+        dict(
+            model='deepseek-chat',
+            top_logprobs=None,
+            n=1,
+            temperature=1.0,
+            stop=['\n'],
+            seed=123,
+        ),
+    )
+
+  def test_call_chat_completion(self):
+    with mock.patch('requests.Session.post') as mock_request:
+      mock_request.side_effect = mock_chat_completion_request
+      lm = deepseek.DeepSeek(model='deepseek-chat', api_key='test_key')
+      self.assertEqual(
+          lm('hello', sampling_options=lf.LMSamplingOptions(n=2)),
+          'Sample 0 for message.',
+      )
+
+  def test_call_chat_completion_with_logprobs(self):
+    with mock.patch('requests.Session.post') as mock_request:
+      mock_request.side_effect = mock_chat_completion_request
+      lm = deepseek.DeepSeek(model='deepseek-chat', api_key='test_key')
+      results = lm.sample(['hello'], logprobs=True)
+      self.assertEqual(len(results), 1)
+      expected = lf.LMSamplingResult(
+          [
+              lf.LMSample(
+                  response=lf.AIMessage(
+                      text='Sample 0 for message.',
+                      metadata={
+                          'score': 0.0,
+                          'logprobs': [(
+                              'chosen_token',
+                              0.5,
+                              [
+                                  ('alternative_token_1', 0.1),
+                                  ('alternative_token_2', 0.1),
+                                  ('alternative_token_3', 0.1),
+                              ],
+                          )],
+                          'is_cached': False,
+                          'usage': lf.LMSamplingUsage(
+                              prompt_tokens=100,
+                              completion_tokens=100,
+                              total_tokens=200,
+                              estimated_cost=4.2e-05,
+                          ),
+                      },
+                      tags=['lm-response'],
+                  ),
+                  logprobs=[(
+                      'chosen_token',
+                      0.5,
+                      [
+                          ('alternative_token_1', 0.1),
+                          ('alternative_token_2', 0.1),
+                          ('alternative_token_3', 0.1),
+                      ],
+                  )],
+              )
+          ],
+          usage=lf.LMSamplingUsage(
+              prompt_tokens=100,
+              completion_tokens=100,
+              total_tokens=200,
+              estimated_cost=4.2e-05,
+          ),
+      )
+      self.assertTrue(pg.eq(results[0], expected))
+
+  def test_sample_chat_completion(self):
+    with mock.patch('requests.Session.post') as mock_request:
+      mock_request.side_effect = mock_chat_completion_request
+      deepseek.SUPPORTED_MODELS_AND_SETTINGS['deepseek-chat'].update({
+          'cost_per_1k_input_tokens': 1.0,
+          'cost_per_1k_output_tokens': 1.0,
+      })
+      lm = deepseek.DeepSeek(api_key='test_key', model='deepseek-chat')
+      results = lm.sample(
+          ['hello', 'bye'], sampling_options=lf.LMSamplingOptions(n=3)
+      )
+
+      self.assertEqual(len(results), 2)
+      print(results[0])
+      self.assertEqual(
+          results[0],
+          lf.LMSamplingResult(
+              [
+                  lf.LMSample(
+                      lf.AIMessage(
+                          'Sample 0 for message.',
+                          score=0.0,
+                          logprobs=None,
+                          is_cached=False,
+                          usage=lf.LMSamplingUsage(
+                              prompt_tokens=33,
+                              completion_tokens=33,
+                              total_tokens=66,
+                              estimated_cost=0.2 / 3,
+                          ),
+                          tags=[lf.Message.TAG_LM_RESPONSE],
+                      ),
+                      score=0.0,
+                      logprobs=None,
+                  ),
+                  lf.LMSample(
+                      lf.AIMessage(
+                          'Sample 1 for message.',
+                          score=0.0,
+                          logprobs=None,
+                          is_cached=False,
+                          usage=lf.LMSamplingUsage(
+                              prompt_tokens=33,
+                              completion_tokens=33,
+                              total_tokens=66,
+                              estimated_cost=0.2 / 3,
+                          ),
+                          tags=[lf.Message.TAG_LM_RESPONSE],
+                      ),
+                      score=0.0,
+                      logprobs=None,
+                  ),
+                  lf.LMSample(
+                      lf.AIMessage(
+                          'Sample 2 for message.',
+                          score=0.0,
+                          logprobs=None,
+                          is_cached=False,
+                          usage=lf.LMSamplingUsage(
+                              prompt_tokens=33,
+                              completion_tokens=33,
+                              total_tokens=66,
+                              estimated_cost=0.2 / 3,
+                          ),
+                          tags=[lf.Message.TAG_LM_RESPONSE],
+                      ),
+                      score=0.0,
+                      logprobs=None,
+                  ),
+              ],
+              usage=lf.LMSamplingUsage(
+                  prompt_tokens=100, completion_tokens=100, total_tokens=200,
+                  estimated_cost=0.2,
+              ),
+          ),
+      )
+      self.assertEqual(
+          results[1],
+          lf.LMSamplingResult(
+              [
+                  lf.LMSample(
+                      lf.AIMessage(
+                          'Sample 0 for message.',
+                          score=0.0,
+                          logprobs=None,
+                          is_cached=False,
+                          usage=lf.LMSamplingUsage(
+                              prompt_tokens=33,
+                              completion_tokens=33,
+                              total_tokens=66,
+                              estimated_cost=0.2 / 3,
+                          ),
+                          tags=[lf.Message.TAG_LM_RESPONSE],
+                      ),
+                      score=0.0,
+                      logprobs=None,
+                  ),
+                  lf.LMSample(
+                      lf.AIMessage(
+                          'Sample 1 for message.',
+                          score=0.0,
+                          logprobs=None,
+                          is_cached=False,
+                          usage=lf.LMSamplingUsage(
+                              prompt_tokens=33,
+                              completion_tokens=33,
+                              total_tokens=66,
+                              estimated_cost=0.2 / 3,
+                          ),
+                          tags=[lf.Message.TAG_LM_RESPONSE],
+                      ),
+                      score=0.0,
+                      logprobs=None,
+                  ),
+                  lf.LMSample(
+                      lf.AIMessage(
+                          'Sample 2 for message.',
+                          score=0.0,
+                          logprobs=None,
+                          is_cached=False,
+                          usage=lf.LMSamplingUsage(
+                              prompt_tokens=33,
+                              completion_tokens=33,
+                              total_tokens=66,
+                              estimated_cost=0.2 / 3,
+                          ),
+                          tags=[lf.Message.TAG_LM_RESPONSE],
+                      ),
+                      score=0.0,
+                      logprobs=None,
+                  ),
+              ],
+              usage=lf.LMSamplingUsage(
+                  prompt_tokens=100, completion_tokens=100, total_tokens=200,
+                  estimated_cost=0.2,
+              ),
+          ),
+      )
+
+  def test_sample_with_contextual_options(self):
+    with mock.patch('requests.Session.post') as mock_request:
+      mock_request.side_effect = mock_chat_completion_request
+      lm = deepseek.DeepSeek(api_key='test_key', model='deepseek-chat')
+      with lf.use_settings(sampling_options=lf.LMSamplingOptions(n=2)):
+        results = lm.sample(['hello'])
+
+      self.assertEqual(len(results), 1)
+      expected = lf.LMSamplingResult(
+          samples=[
+              lf.LMSample(
+                  response=lf.AIMessage(
+                      text='Sample 0 for message.',
+                      sender='AI',
+                      metadata=pg.Dict(
+                          score=0.0,
+                          logprobs=None,
+                          is_cached=False,
+                          usage=lf.LMSamplingUsage(
+                              prompt_tokens=50,
+                              completion_tokens=50,
+                              total_tokens=100,
+                              num_requests=1,
+                              estimated_cost=0.1,
+                          ),
+                      ),
+                      tags=['lm-response'],
+                  ),
+                  score=0.0,
+                  logprobs=None,
+              ),
+              lf.LMSample(
+                  response=lf.AIMessage(
+                      text='Sample 1 for message.',
+                      sender='AI',
+                      metadata=pg.Dict(
+                          score=0.0,
+                          logprobs=None,
+                          is_cached=False,
+                          usage=lf.LMSamplingUsage(
+                              prompt_tokens=50,
+                              completion_tokens=50,
+                              total_tokens=100,
+                              num_requests=1,
+                              estimated_cost=0.1,
+                          ),
+                      ),
+                      tags=['lm-response'],
+                  ),
+                  score=0.0,
+                  logprobs=None,
+              ),
+          ],
+          usage=lf.LMSamplingUsage(
+              prompt_tokens=100,
+              completion_tokens=100,
+              total_tokens=200,
+              num_requests=1,
+              estimated_cost=0.2,
+          ),
+          is_cached=False,
+      )
+      self.assertTrue(pg.eq(results[0], expected))
+
+  def test_call_with_system_message(self):
+    with mock.patch('requests.Session.post') as mock_request:
+      mock_request.side_effect = mock_chat_completion_request
+      lm = deepseek.DeepSeek(api_key='test_key', model='deepseek-chat')
+      self.assertEqual(
+          lm(
+              lf.UserMessage(
+                  'hello',
+                  system_message='hi',
+              ),
+              sampling_options=lf.LMSamplingOptions(n=2)
+          ),
+          '''Sample 0 for message. system=[{'type': 'text', 'text': 'hi'}]''',
+      )
+
+  def test_call_with_json_schema(self):
+    with mock.patch('requests.Session.post') as mock_request:
+      mock_request.side_effect = mock_chat_completion_request
+      lm = deepseek.DeepSeek(api_key='test_key', model='deepseek-chat')
+      self.assertEqual(
+          lm(
+              lf.UserMessage(
+                  'hello',
+                  json_schema={
+                      'type': 'object',
+                      'properties': {
+                          'name': {'type': 'string'},
+                      },
+                      'required': ['name'],
+                      'title': 'Person',
+                  }
+              ),
+              sampling_options=lf.LMSamplingOptions(n=2)
+          ),
+          'Sample 0 for message. format=json_schema',
+      )

+      # Test bad json schema.
+      with self.assertRaisesRegex(ValueError, '`json_schema` must be a dict'):
+        lm(lf.UserMessage('hello', json_schema='foo'))
+
+      with self.assertRaisesRegex(
+          ValueError, 'The root of `json_schema` must have a `title` field'
+      ):
+        lm(lf.UserMessage('hello', json_schema={}))
+
+
+if __name__ == '__main__':
+  unittest.main()
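
The last two tests above double as documentation for the metadata-driven request paths. A sketch of the same calls from user code, with an illustrative prompt and schema mirroring `test_call_with_system_message` and `test_call_with_json_schema`:

import langfun.core as lf
from langfun.core.llms import deepseek

lm = deepseek.DeepSeek(model='deepseek-chat', api_key='<your-api-key>')

# `system_message` metadata is emitted as a role='system' entry in the
# request payload.
reply = lm(lf.UserMessage('hello', system_message='You are terse.'))

# `json_schema` metadata switches the request to
# response_format=json_schema. The schema must be a dict whose root has
# a `title` field, otherwise request() raises ValueError.
reply = lm(
    lf.UserMessage(
        'hello',
        json_schema={
            'type': 'object',
            'properties': {'name': {'type': 'string'}},
            'required': ['name'],
            'title': 'Person',
        },
    )
)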
langfun/core/llms/google_genai.py
@@ -68,6 +68,7 @@ class GeminiFlash2_0ThinkingExp_20241219(GenAI):  # pylint: disable=invalid-name
 
   api_version = 'v1alpha'
   model = 'gemini-2.0-flash-thinking-exp-1219'
+  timeout = None
 
 
 class GeminiFlash2_0Exp(GenAI):  # pylint: disable=invalid-name