langfun 0.1.2.dev202501060804__py3-none-any.whl → 0.1.2.dev202501100804__py3-none-any.whl

This diff represents the changes between two publicly released versions of the package, as they appear in their public registry, and is provided for informational purposes only.
Files changed (39)
  1. langfun/core/__init__.py +0 -5
  2. langfun/core/coding/python/correction.py +4 -3
  3. langfun/core/coding/python/errors.py +10 -9
  4. langfun/core/coding/python/execution.py +23 -12
  5. langfun/core/coding/python/execution_test.py +21 -2
  6. langfun/core/coding/python/generation.py +18 -9
  7. langfun/core/concurrent.py +2 -3
  8. langfun/core/console.py +8 -3
  9. langfun/core/eval/base.py +2 -3
  10. langfun/core/eval/v2/reporting.py +8 -4
  11. langfun/core/language_model.py +7 -4
  12. langfun/core/language_model_test.py +15 -0
  13. langfun/core/llms/__init__.py +7 -0
  14. langfun/core/llms/deepseek.py +117 -0
  15. langfun/core/llms/deepseek_test.py +61 -0
  16. langfun/core/llms/google_genai.py +1 -0
  17. langfun/core/llms/groq.py +12 -99
  18. langfun/core/llms/groq_test.py +31 -137
  19. langfun/core/llms/llama_cpp.py +17 -54
  20. langfun/core/llms/llama_cpp_test.py +2 -34
  21. langfun/core/llms/openai.py +14 -147
  22. langfun/core/llms/openai_compatible.py +179 -0
  23. langfun/core/llms/openai_compatible_test.py +480 -0
  24. langfun/core/llms/openai_test.py +13 -423
  25. langfun/core/llms/vertexai.py +6 -2
  26. langfun/core/llms/vertexai_test.py +1 -1
  27. langfun/core/modalities/mime.py +8 -0
  28. langfun/core/modalities/mime_test.py +19 -4
  29. langfun/core/modality_test.py +0 -1
  30. langfun/core/structured/mapping.py +13 -13
  31. langfun/core/structured/mapping_test.py +2 -2
  32. langfun/core/structured/schema.py +16 -8
  33. {langfun-0.1.2.dev202501060804.dist-info → langfun-0.1.2.dev202501100804.dist-info}/METADATA +13 -2
  34. {langfun-0.1.2.dev202501060804.dist-info → langfun-0.1.2.dev202501100804.dist-info}/RECORD +37 -35
  35. {langfun-0.1.2.dev202501060804.dist-info → langfun-0.1.2.dev202501100804.dist-info}/WHEEL +1 -1
  36. langfun/core/text_formatting.py +0 -168
  37. langfun/core/text_formatting_test.py +0 -65
  38. {langfun-0.1.2.dev202501060804.dist-info → langfun-0.1.2.dev202501100804.dist-info}/LICENSE +0 -0
  39. {langfun-0.1.2.dev202501060804.dist-info → langfun-0.1.2.dev202501100804.dist-info}/top_level.txt +0 -0
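The single file expanded below is the new test suite for openai_compatible.py, the module that the slimmed-down OpenAI, Groq, and llama.cpp clients in this release appear to build on. For orientation, the tests drive the model the same way as any other Langfun LM; here is a minimal usage sketch based only on the constructor fields and call pattern exercised in those tests (the endpoint URL and model name are placeholders, not a real server):

import langfun.core as lf
from langfun.core.llms import openai_compatible

# Point the generic client at any server that speaks the OpenAI
# chat-completions protocol. Both values below are placeholders.
lm = openai_compatible.OpenAICompatible(
    api_endpoint='https://my-server/v1/chat/completions',
    model='my-model',
)

# Invoked like any other Langfun language model.
response = lm(
    'hello',
    sampling_options=lf.LMSamplingOptions(temperature=1.0, n=1),
)
print(response.text)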
langfun/core/llms/openai_compatible_test.py
@@ -0,0 +1,480 @@
+ # Copyright 2023 The Langfun Authors
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Tests for OpenAI models."""
+
+ from typing import Any
+ import unittest
+ from unittest import mock
+
+ import langfun.core as lf
+ from langfun.core import modalities as lf_modalities
+ from langfun.core.llms import openai_compatible
+ import pyglove as pg
+ import requests
+
+
+ def mock_chat_completion_request(url: str, json: dict[str, Any], **kwargs):
+   del url, kwargs
+   messages = json['messages']
+   if len(messages) > 1:
+     system_message = f' system={messages[0]["content"]}'
+   else:
+     system_message = ''
+
+   if 'response_format' in json:
+     response_format = f' format={json["response_format"]["type"]}'
+   else:
+     response_format = ''
+
+   choices = []
+   for k in range(json['n']):
+     if json.get('logprobs'):
+       logprobs = dict(
+           content=[
+               dict(
+                   token='chosen_token',
+                   logprob=0.5,
+                   top_logprobs=[
+                       dict(
+                           token=f'alternative_token_{i + 1}',
+                           logprob=0.1
+                       ) for i in range(3)
+                   ]
+               )
+           ]
+       )
+     else:
+       logprobs = None
+
+     choices.append(dict(
+         message=dict(
+             content=(
+                 f'Sample {k} for message.{system_message}{response_format}'
+             )
+         ),
+         logprobs=logprobs,
+     ))
+   response = requests.Response()
+   response.status_code = 200
+   response._content = pg.to_json_str(
+       dict(
+           choices=choices,
+           usage=lf.LMSamplingUsage(
+               prompt_tokens=100,
+               completion_tokens=100,
+               total_tokens=200,
+           ),
+       )
+   ).encode()
+   return response
+
+
+ def mock_chat_completion_request_vision(
+     url: str, json: dict[str, Any], **kwargs
+ ):
+   del url, kwargs
+   choices = []
+   urls = [
+       c['image_url']['url']
+       for c in json['messages'][0]['content'] if c['type'] == 'image_url'
+   ]
+   for k in range(json['n']):
+     choices.append(pg.Dict(
+         message=pg.Dict(
+             content=f'Sample {k} for message: {"".join(urls)}'
+         ),
+         logprobs=None,
+     ))
+   response = requests.Response()
+   response.status_code = 200
+   response._content = pg.to_json_str(
+       dict(
+           choices=choices,
+           usage=lf.LMSamplingUsage(
+               prompt_tokens=100,
+               completion_tokens=100,
+               total_tokens=200,
+           ),
+       )
+   ).encode()
+   return response
+
+
+ class OpenAIComptibleTest(unittest.TestCase):
+   """Tests for OpenAI compatible language model."""
+
+   def test_request_args(self):
+     self.assertEqual(
+         openai_compatible.OpenAICompatible(
+             api_endpoint='https://test-server',
+             model='test-model'
+         )._request_args(
+             lf.LMSamplingOptions(
+                 temperature=1.0, stop=['\n'], n=1, random_seed=123
+             )
+         ),
+         dict(
+             model='test-model',
+             top_logprobs=None,
+             n=1,
+             temperature=1.0,
+             stop=['\n'],
+             seed=123,
+         ),
+     )
+
+   def test_call_chat_completion(self):
+     with mock.patch('requests.Session.post') as mock_request:
+       mock_request.side_effect = mock_chat_completion_request
+       lm = openai_compatible.OpenAICompatible(
+           api_endpoint='https://test-server', model='test-model',
+       )
+       self.assertEqual(
+           lm('hello', sampling_options=lf.LMSamplingOptions(n=2)),
+           'Sample 0 for message.',
+       )
+
+   def test_call_chat_completion_with_logprobs(self):
+     with mock.patch('requests.Session.post') as mock_request:
+       mock_request.side_effect = mock_chat_completion_request
+       lm = openai_compatible.OpenAICompatible(
+           api_endpoint='https://test-server', model='test-model',
+       )
+       results = lm.sample(['hello'], logprobs=True)
+       self.assertEqual(len(results), 1)
+       self.assertEqual(
+           results[0],
+           lf.LMSamplingResult(
+               [
+                   lf.LMSample(
+                       response=lf.AIMessage(
+                           text='Sample 0 for message.',
+                           metadata={
+                               'score': 0.0,
+                               'logprobs': [(
+                                   'chosen_token',
+                                   0.5,
+                                   [
+                                       ('alternative_token_1', 0.1),
+                                       ('alternative_token_2', 0.1),
+                                       ('alternative_token_3', 0.1),
+                                   ],
+                               )],
+                               'is_cached': False,
+                               'usage': lf.LMSamplingUsage(
+                                   prompt_tokens=100,
+                                   completion_tokens=100,
+                                   total_tokens=200,
+                                   estimated_cost=None,
+                               ),
+                           },
+                           tags=['lm-response'],
+                       ),
+                       logprobs=[(
+                           'chosen_token',
+                           0.5,
+                           [
+                               ('alternative_token_1', 0.1),
+                               ('alternative_token_2', 0.1),
+                               ('alternative_token_3', 0.1),
+                           ],
+                       )],
+                   )
+               ],
+               usage=lf.LMSamplingUsage(
+                   prompt_tokens=100,
+                   completion_tokens=100,
+                   total_tokens=200,
+                   estimated_cost=None,
+               ),
+           ),
+       )
+
+   def test_call_chat_completion_vision(self):
+     with mock.patch('requests.Session.post') as mock_request:
+       mock_request.side_effect = mock_chat_completion_request_vision
+       lm_1 = openai_compatible.OpenAICompatible(
+           api_endpoint='https://test-server',
+           model='test-model1',
+           multimodal=True
+       )
+       lm_2 = openai_compatible.OpenAICompatible(
+           api_endpoint='https://test-server',
+           model='test-model2',
+           multimodal=True
+       )
+       for lm in (lm_1, lm_2):
+         self.assertEqual(
+             lm(
+                 lf.UserMessage(
+                     'hello <<[[image]]>>',
+                     image=lf_modalities.Image.from_uri('https://fake/image')
+                 ),
+                 sampling_options=lf.LMSamplingOptions(n=2)
+             ),
+             'Sample 0 for message: https://fake/image',
+         )
+       lm_3 = openai_compatible.OpenAICompatible(
+           api_endpoint='https://test-server', model='test-model3'
+       )
+       with self.assertRaisesRegex(ValueError, 'Unsupported modality'):
+         lm_3(
+             lf.UserMessage(
+                 'hello <<[[image]]>>',
+                 image=lf_modalities.Image.from_uri('https://fake/image')
+             ),
+         )
+
+   def test_sample_chat_completion(self):
+     with mock.patch('requests.Session.post') as mock_request:
+       mock_request.side_effect = mock_chat_completion_request
+       lm = openai_compatible.OpenAICompatible(
+           api_endpoint='https://test-server', model='test-model'
+       )
+       results = lm.sample(
+           ['hello', 'bye'], sampling_options=lf.LMSamplingOptions(n=3)
+       )
+
+       self.assertEqual(len(results), 2)
+       self.assertEqual(
+           results[0],
+           lf.LMSamplingResult(
+               [
+                   lf.LMSample(
+                       lf.AIMessage(
+                           'Sample 0 for message.',
+                           score=0.0,
+                           logprobs=None,
+                           is_cached=False,
+                           usage=lf.LMSamplingUsage(
+                               prompt_tokens=33,
+                               completion_tokens=33,
+                               total_tokens=66,
+                               estimated_cost=None,
+                           ),
+                           tags=[lf.Message.TAG_LM_RESPONSE],
+                       ),
+                       score=0.0,
+                       logprobs=None,
+                   ),
+                   lf.LMSample(
+                       lf.AIMessage(
+                           'Sample 1 for message.',
+                           score=0.0,
+                           logprobs=None,
+                           is_cached=False,
+                           usage=lf.LMSamplingUsage(
+                               prompt_tokens=33,
+                               completion_tokens=33,
+                               total_tokens=66,
+                               estimated_cost=None,
+                           ),
+                           tags=[lf.Message.TAG_LM_RESPONSE],
+                       ),
+                       score=0.0,
+                       logprobs=None,
+                   ),
+                   lf.LMSample(
+                       lf.AIMessage(
+                           'Sample 2 for message.',
+                           score=0.0,
+                           logprobs=None,
+                           is_cached=False,
+                           usage=lf.LMSamplingUsage(
+                               prompt_tokens=33,
+                               completion_tokens=33,
+                               total_tokens=66,
+                               estimated_cost=None,
+                           ),
+                           tags=[lf.Message.TAG_LM_RESPONSE],
+                       ),
+                       score=0.0,
+                       logprobs=None,
+                   ),
+               ],
+               usage=lf.LMSamplingUsage(
+                   prompt_tokens=100, completion_tokens=100, total_tokens=200,
+                   estimated_cost=None,
+               ),
+           ),
+       )
+       self.assertEqual(
+           results[1],
+           lf.LMSamplingResult(
+               [
+                   lf.LMSample(
+                       lf.AIMessage(
+                           'Sample 0 for message.',
+                           score=0.0,
+                           logprobs=None,
+                           is_cached=False,
+                           usage=lf.LMSamplingUsage(
+                               prompt_tokens=33,
+                               completion_tokens=33,
+                               total_tokens=66,
+                               estimated_cost=None,
+                           ),
+                           tags=[lf.Message.TAG_LM_RESPONSE],
+                       ),
+                       score=0.0,
+                       logprobs=None,
+                   ),
+                   lf.LMSample(
+                       lf.AIMessage(
+                           'Sample 1 for message.',
+                           score=0.0,
+                           logprobs=None,
+                           is_cached=False,
+                           usage=lf.LMSamplingUsage(
+                               prompt_tokens=33,
+                               completion_tokens=33,
+                               total_tokens=66,
+                               estimated_cost=None,
+                           ),
+                           tags=[lf.Message.TAG_LM_RESPONSE],
+                       ),
+                       score=0.0,
+                       logprobs=None,
+                   ),
+                   lf.LMSample(
+                       lf.AIMessage(
+                           'Sample 2 for message.',
+                           score=0.0,
+                           logprobs=None,
+                           is_cached=False,
+                           usage=lf.LMSamplingUsage(
+                               prompt_tokens=33,
+                               completion_tokens=33,
+                               total_tokens=66,
+                               estimated_cost=None,
+                           ),
+                           tags=[lf.Message.TAG_LM_RESPONSE],
+                       ),
+                       score=0.0,
+                       logprobs=None,
+                   ),
+               ],
+               usage=lf.LMSamplingUsage(
+                   prompt_tokens=100, completion_tokens=100, total_tokens=200,
+                   estimated_cost=None,
+               ),
+           ),
+       )
+
+   def test_sample_with_contextual_options(self):
+     with mock.patch('requests.Session.post') as mock_request:
+       mock_request.side_effect = mock_chat_completion_request
+       lm = openai_compatible.OpenAICompatible(
+           api_endpoint='https://test-server', model='test-model'
+       )
+       with lf.use_settings(sampling_options=lf.LMSamplingOptions(n=2)):
+         results = lm.sample(['hello'])
+
+       self.assertEqual(len(results), 1)
+       self.assertEqual(
+           results[0],
+           lf.LMSamplingResult(
+               [
+                   lf.LMSample(
+                       lf.AIMessage(
+                           'Sample 0 for message.',
+                           score=0.0,
+                           logprobs=None,
+                           is_cached=False,
+                           usage=lf.LMSamplingUsage(
+                               prompt_tokens=50,
+                               completion_tokens=50,
+                               total_tokens=100,
+                           ),
+                           tags=[lf.Message.TAG_LM_RESPONSE],
+                       ),
+                       score=0.0,
+                       logprobs=None,
+                   ),
+                   lf.LMSample(
+                       lf.AIMessage(
+                           'Sample 1 for message.',
+                           score=0.0,
+                           logprobs=None,
+                           is_cached=False,
+                           usage=lf.LMSamplingUsage(
+                               prompt_tokens=50,
+                               completion_tokens=50,
+                               total_tokens=100,
+                           ),
+                           tags=[lf.Message.TAG_LM_RESPONSE],
+                       ),
+                       score=0.0,
+                       logprobs=None,
+                   ),
+               ],
+               usage=lf.LMSamplingUsage(
+                   prompt_tokens=100, completion_tokens=100, total_tokens=200
+               ),
+           )
+       )
+
+   def test_call_with_system_message(self):
+     with mock.patch('requests.Session.post') as mock_request:
+       mock_request.side_effect = mock_chat_completion_request
+       lm = openai_compatible.OpenAICompatible(
+           api_endpoint='https://test-server', model='test-model'
+       )
+       self.assertEqual(
+           lm(
+               lf.UserMessage(
+                   'hello',
+                   system_message='hi',
+               ),
+               sampling_options=lf.LMSamplingOptions(n=2)
+           ),
+           '''Sample 0 for message. system=[{'type': 'text', 'text': 'hi'}]''',
+       )
+
+   def test_call_with_json_schema(self):
+     with mock.patch('requests.Session.post') as mock_request:
+       mock_request.side_effect = mock_chat_completion_request
+       lm = openai_compatible.OpenAICompatible(
+           api_endpoint='https://test-server', model='test-model'
+       )
+       self.assertEqual(
+           lm(
+               lf.UserMessage(
+                   'hello',
+                   json_schema={
+                       'type': 'object',
+                       'properties': {
+                           'name': {'type': 'string'},
+                       },
+                       'required': ['name'],
+                       'title': 'Person',
+                   }
+               ),
+               sampling_options=lf.LMSamplingOptions(n=2)
+           ),
+           'Sample 0 for message. format=json_schema',
+       )
+
+       # Test bad json schema.
+       with self.assertRaisesRegex(ValueError, '`json_schema` must be a dict'):
+         lm(lf.UserMessage('hello', json_schema='foo'))
+
+       with self.assertRaisesRegex(
+           ValueError, 'The root of `json_schema` must have a `title` field'
+       ):
+         lm(lf.UserMessage('hello', json_schema={}))
+
+
+ if __name__ == '__main__':
+   unittest.main()