langfun 0.1.2.dev202412020805__py3-none-any.whl → 0.1.2.dev202412030804__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -13,6 +13,7 @@
 # limitations under the License.
 """Tests for OpenAI models."""
 
+from typing import Any
 import unittest
 from unittest import mock
 
@@ -20,86 +21,106 @@ import langfun.core as lf
 from langfun.core import modalities as lf_modalities
 from langfun.core.llms import openai
 import pyglove as pg
+import requests
 
 
-def mock_completion_query(prompt, *, n=1, **kwargs):
-  del kwargs
-  choices = []
-  for i, _ in enumerate(prompt):
-    for k in range(n):
-      choices.append(pg.Dict(
-          index=i,
-          text=f'Sample {k} for prompt {i}.',
-          logprobs=k / 10,
-      ))
-  return pg.Dict(
-      choices=choices,
-      usage=lf.LMSamplingUsage(
-          prompt_tokens=100,
-          completion_tokens=100,
-          total_tokens=200,
-      ),
-  )
-
-
-def mock_chat_completion_query(messages, *, n=1, **kwargs):
+def mock_chat_completion_request(url: str, json: dict[str, Any], **kwargs):
+  del url, kwargs
+  messages = json['messages']
   if len(messages) > 1:
     system_message = f' system={messages[0]["content"]}'
   else:
     system_message = ''
 
-  if 'response_format' in kwargs:
-    response_format = f' format={kwargs["response_format"]["type"]}'
+  if 'response_format' in json:
+    response_format = f' format={json["response_format"]["type"]}'
   else:
     response_format = ''
 
   choices = []
-  for k in range(n):
-    choices.append(pg.Dict(
-        message=pg.Dict(
+  for k in range(json['n']):
+    if json.get('logprobs'):
+      logprobs = dict(
+          content=[
+              dict(
+                  token='chosen_token',
+                  logprob=0.5,
+                  top_logprobs=[
+                      dict(
+                          token=f'alternative_token_{i + 1}',
+                          logprob=0.1
+                      ) for i in range(3)
+                  ]
+              )
+          ]
+      )
+    else:
+      logprobs = None
+
+    choices.append(dict(
+        message=dict(
             content=(
                 f'Sample {k} for message.{system_message}{response_format}'
             )
         ),
-        logprobs=None,
+        logprobs=logprobs,
     ))
-  return pg.Dict(
-      choices=choices,
-      usage=lf.LMSamplingUsage(
-          prompt_tokens=100,
-          completion_tokens=100,
-          total_tokens=200,
-      ),
-  )
+  response = requests.Response()
+  response.status_code = 200
+  response._content = pg.to_json_str(
+      dict(
+          choices=choices,
+          usage=lf.LMSamplingUsage(
+              prompt_tokens=100,
+              completion_tokens=100,
+              total_tokens=200,
+          ),
+      )
+  ).encode()
+  return response
 
 
-def mock_chat_completion_query_vision(messages, *, n=1, **kwargs):
-  del kwargs
+def mock_chat_completion_request_vision(
+    url: str, json: dict[str, Any], **kwargs
+):
+  del url, kwargs
   choices = []
   urls = [
       c['image_url']['url']
-      for c in messages[0]['content'] if c['type'] == 'image_url'
+      for c in json['messages'][0]['content'] if c['type'] == 'image_url'
   ]
-  for k in range(n):
+  for k in range(json['n']):
     choices.append(pg.Dict(
         message=pg.Dict(
             content=f'Sample {k} for message: {"".join(urls)}'
         ),
         logprobs=None,
     ))
-  return pg.Dict(
-      choices=choices,
-      usage=lf.LMSamplingUsage(
-          prompt_tokens=100,
-          completion_tokens=100,
-          total_tokens=200,
-      ),
-  )
+  response = requests.Response()
+  response.status_code = 200
+  response._content = pg.to_json_str(
+      dict(
+          choices=choices,
+          usage=lf.LMSamplingUsage(
+              prompt_tokens=100,
+              completion_tokens=100,
+              total_tokens=200,
+          ),
+      )
+  ).encode()
+  return response
 
 
 class OpenAITest(unittest.TestCase):
   """Tests for OpenAI language model."""
 
+  def test_dir(self):
+    self.assertIn('gpt-4-turbo', openai.OpenAI.dir())
+
+  def test_key(self):
+    with self.assertRaisesRegex(ValueError, 'Please specify `api_key`'):
+      openai.Gpt4()('hi')
+
   def test_model_id(self):
     self.assertEqual(
         openai.Gpt35(api_key='test_key').model_id, 'OpenAI(text-davinci-003)')
@@ -112,29 +133,9 @@ class OpenAITest(unittest.TestCase):
   def test_max_concurrency(self):
     self.assertGreater(openai.Gpt35(api_key='test_key').max_concurrency, 0)
 
-  def test_get_request_args(self):
-    self.assertEqual(
-        openai.Gpt35(api_key='test_key', timeout=90.0)._get_request_args(
-            lf.LMSamplingOptions(
-                temperature=2.0,
-                logprobs=True,
-                n=2,
-                max_tokens=4096,
-                top_p=1.0)),
-        dict(
-            engine='text-davinci-003',
-            logprobs=True,
-            top_logprobs=None,
-            n=2,
-            temperature=2.0,
-            max_tokens=4096,
-            stream=False,
-            timeout=90.0,
-            top_p=1.0,
-        )
-    )
+  def test_request_args(self):
     self.assertEqual(
-        openai.Gpt4(api_key='test_key')._get_request_args(
+        openai.Gpt4(api_key='test_key')._request_args(
             lf.LMSamplingOptions(
                 temperature=1.0, stop=['\n'], n=1, random_seed=123
             )
@@ -144,40 +145,93 @@ class OpenAITest(unittest.TestCase):
             top_logprobs=None,
             n=1,
             temperature=1.0,
-            stream=False,
-            timeout=120.0,
             stop=['\n'],
             seed=123,
         ),
     )
     with self.assertRaisesRegex(RuntimeError, '`logprobs` is not supported.*'):
-      openai.GptO1Preview(api_key='test_key')._get_request_args(
+      openai.GptO1Preview(api_key='test_key')._request_args(
           lf.LMSamplingOptions(
               temperature=1.0, logprobs=True
           )
       )
 
-  def test_call_completion(self):
-    with mock.patch('openai.Completion.create') as mock_completion:
-      mock_completion.side_effect = mock_completion_query
-      lm = openai.OpenAI(api_key='test_key', model='text-davinci-003')
+  def test_call_chat_completion(self):
+    with mock.patch('requests.Session.post') as mock_request:
+      mock_request.side_effect = mock_chat_completion_request
+      lm = openai.OpenAI(
+          model='gpt-4',
+          api_key='test_key',
+          organization='my_org',
+          project='my_project'
+      )
       self.assertEqual(
           lm('hello', sampling_options=lf.LMSamplingOptions(n=2)),
-          'Sample 0 for prompt 0.',
+          'Sample 0 for message.',
       )
 
-  def test_call_chat_completion(self):
-    with mock.patch('openai.ChatCompletion.create') as mock_chat_completion:
-      mock_chat_completion.side_effect = mock_chat_completion_query
-      lm = openai.OpenAI(api_key='test_key', model='gpt-4')
+  def test_call_chat_completion_with_logprobs(self):
+    with mock.patch('requests.Session.post') as mock_request:
+      mock_request.side_effect = mock_chat_completion_request
+      lm = openai.OpenAI(
+          model='gpt-4',
+          api_key='test_key',
+          organization='my_org',
+          project='my_project'
+      )
+      results = lm.sample(['hello'], logprobs=True)
+      self.assertEqual(len(results), 1)
       self.assertEqual(
-          lm('hello', sampling_options=lf.LMSamplingOptions(n=2)),
-          'Sample 0 for message.',
+          results[0],
+          lf.LMSamplingResult(
+              [
+                  lf.LMSample(
+                      response=lf.AIMessage(
+                          text='Sample 0 for message.',
+                          metadata={
+                              'score': 0.0,
+                              'logprobs': [(
+                                  'chosen_token',
+                                  0.5,
+                                  [
+                                      ('alternative_token_1', 0.1),
+                                      ('alternative_token_2', 0.1),
+                                      ('alternative_token_3', 0.1),
+                                  ],
+                              )],
+                              'is_cached': False,
+                              'usage': lf.LMSamplingUsage(
+                                  prompt_tokens=100,
+                                  completion_tokens=100,
+                                  total_tokens=200,
+                                  estimated_cost=0.009,
+                              ),
+                          },
+                          tags=['lm-response'],
+                      ),
+                      logprobs=[(
+                          'chosen_token',
+                          0.5,
+                          [
+                              ('alternative_token_1', 0.1),
+                              ('alternative_token_2', 0.1),
+                              ('alternative_token_3', 0.1),
+                          ],
+                      )],
+                  )
+              ],
+              usage=lf.LMSamplingUsage(
+                  prompt_tokens=100,
+                  completion_tokens=100,
+                  total_tokens=200,
+                  estimated_cost=0.009,
+              ),
+          ),
       )
 
   def test_call_chat_completion_vision(self):
-    with mock.patch('openai.ChatCompletion.create') as mock_chat_completion:
-      mock_chat_completion.side_effect = mock_chat_completion_query_vision
+    with mock.patch('requests.Session.post') as mock_request:
+      mock_request.side_effect = mock_chat_completion_request_vision
       lm_1 = openai.Gpt4Turbo(api_key='test_key')
       lm_2 = openai.Gpt4VisionPreview(api_key='test_key')
       for lm in (lm_1, lm_2):
@@ -191,136 +245,18 @@ class OpenAITest(unittest.TestCase):
             ),
             'Sample 0 for message: https://fake/image',
         )
-
-  def test_sample_completion(self):
-    with mock.patch('openai.Completion.create') as mock_completion:
-      mock_completion.side_effect = mock_completion_query
-      lm = openai.OpenAI(api_key='test_key', model='text-davinci-003')
-      results = lm.sample(
-          ['hello', 'bye'], sampling_options=lf.LMSamplingOptions(n=3)
+    lm_3 = openai.Gpt35Turbo(api_key='test_key')
+    with self.assertRaisesRegex(ValueError, 'Unsupported modality'):
+      lm_3(
+          lf.UserMessage(
+              'hello <<[[image]]>>',
+              image=lf_modalities.Image.from_uri('https://fake/image')
+          ),
       )
 
-      self.assertEqual(len(results), 2)
-      self.assertEqual(
-          results[0],
-          lf.LMSamplingResult(
-              [
-                  lf.LMSample(
-                      lf.AIMessage(
-                          'Sample 0 for prompt 0.',
-                          score=0.0,
-                          logprobs=None,
-                          is_cached=False,
-                          usage=lf.LMSamplingUsage(
-                              prompt_tokens=16,
-                              completion_tokens=16,
-                              total_tokens=33
-                          ),
-                          tags=[lf.Message.TAG_LM_RESPONSE],
-                      ),
-                      score=0.0,
-                      logprobs=None,
-                  ),
-                  lf.LMSample(
-                      lf.AIMessage(
-                          'Sample 1 for prompt 0.',
-                          score=0.1,
-                          logprobs=None,
-                          is_cached=False,
-                          usage=lf.LMSamplingUsage(
-                              prompt_tokens=16,
-                              completion_tokens=16,
-                              total_tokens=33
-                          ),
-                          tags=[lf.Message.TAG_LM_RESPONSE],
-                      ),
-                      score=0.1,
-                      logprobs=None,
-                  ),
-                  lf.LMSample(
-                      lf.AIMessage(
-                          'Sample 2 for prompt 0.',
-                          score=0.2,
-                          logprobs=None,
-                          is_cached=False,
-                          usage=lf.LMSamplingUsage(
-                              prompt_tokens=16,
-                              completion_tokens=16,
-                              total_tokens=33
-                          ),
-                          tags=[lf.Message.TAG_LM_RESPONSE],
-                      ),
-                      score=0.2,
-                      logprobs=None,
-                  ),
-              ],
-              usage=lf.LMSamplingUsage(
-                  prompt_tokens=50, completion_tokens=50, total_tokens=100
-              ),
-          ),
-      )
-      self.assertEqual(
-          results[1],
-          lf.LMSamplingResult(
-              [
-                  lf.LMSample(
-                      lf.AIMessage(
-                          'Sample 0 for prompt 1.',
-                          score=0.0,
-                          logprobs=None,
-                          is_cached=False,
-                          usage=lf.LMSamplingUsage(
-                              prompt_tokens=16,
-                              completion_tokens=16,
-                              total_tokens=33
-                          ),
-                          tags=[lf.Message.TAG_LM_RESPONSE],
-                      ),
-                      score=0.0,
-                      logprobs=None,
-                  ),
-                  lf.LMSample(
-                      lf.AIMessage(
-                          'Sample 1 for prompt 1.',
-                          score=0.1,
-                          logprobs=None,
-                          is_cached=False,
-                          usage=lf.LMSamplingUsage(
-                              prompt_tokens=16,
-                              completion_tokens=16,
-                              total_tokens=33
-                          ),
-                          tags=[lf.Message.TAG_LM_RESPONSE],
-                      ),
-                      score=0.1,
-                      logprobs=None,
-                  ),
-                  lf.LMSample(
-                      lf.AIMessage(
-                          'Sample 2 for prompt 1.',
-                          score=0.2,
-                          logprobs=None,
-                          is_cached=False,
-                          usage=lf.LMSamplingUsage(
-                              prompt_tokens=16,
-                              completion_tokens=16,
-                              total_tokens=33
-                          ),
-                          tags=[lf.Message.TAG_LM_RESPONSE],
-                      ),
-                      score=0.2,
-                      logprobs=None,
-                  ),
-              ],
-              usage=lf.LMSamplingUsage(
-                  prompt_tokens=50, completion_tokens=50, total_tokens=100
-              ),
-          ),
-      )
-
   def test_sample_chat_completion(self):
-    with mock.patch('openai.ChatCompletion.create') as mock_chat_completion:
-      mock_chat_completion.side_effect = mock_chat_completion_query
+    with mock.patch('requests.Session.post') as mock_request:
+      mock_request.side_effect = mock_chat_completion_request
       openai.SUPPORTED_MODELS_AND_SETTINGS['gpt-4'].update({
           'cost_per_1k_input_tokens': 1.0,
           'cost_per_1k_output_tokens': 1.0,
@@ -458,8 +394,8 @@ class OpenAITest(unittest.TestCase):
       )
 
   def test_sample_with_contextual_options(self):
-    with mock.patch('openai.Completion.create') as mock_completion:
-      mock_completion.side_effect = mock_completion_query
+    with mock.patch('requests.Session.post') as mock_request:
+      mock_request.side_effect = mock_chat_completion_request
       lm = openai.OpenAI(api_key='test_key', model='text-davinci-003')
       with lf.use_settings(sampling_options=lf.LMSamplingOptions(n=2)):
         results = lm.sample(['hello'])
@@ -471,7 +407,7 @@ class OpenAITest(unittest.TestCase):
               [
                   lf.LMSample(
                       lf.AIMessage(
-                          'Sample 0 for prompt 0.',
+                          'Sample 0 for message.',
                           score=0.0,
                           logprobs=None,
                           is_cached=False,
@@ -487,8 +423,8 @@ class OpenAITest(unittest.TestCase):
                   ),
                   lf.LMSample(
                       lf.AIMessage(
-                          'Sample 1 for prompt 0.',
-                          score=0.1,
+                          'Sample 1 for message.',
+                          score=0.0,
                           logprobs=None,
                           is_cached=False,
                           usage=lf.LMSamplingUsage(
@@ -498,19 +434,19 @@ class OpenAITest(unittest.TestCase):
                           ),
                           tags=[lf.Message.TAG_LM_RESPONSE],
                       ),
-                      score=0.1,
+                      score=0.0,
                       logprobs=None,
                   ),
               ],
               usage=lf.LMSamplingUsage(
                   prompt_tokens=100, completion_tokens=100, total_tokens=200
               ),
-          ),
+          )
       )
 
   def test_call_with_system_message(self):
-    with mock.patch('openai.ChatCompletion.create') as mock_chat_completion:
-      mock_chat_completion.side_effect = mock_chat_completion_query
+    with mock.patch('requests.Session.post') as mock_request:
+      mock_request.side_effect = mock_chat_completion_request
       lm = openai.OpenAI(api_key='test_key', model='gpt-4')
       self.assertEqual(
           lm(
@@ -520,12 +456,12 @@ class OpenAITest(unittest.TestCase):
               ),
               sampling_options=lf.LMSamplingOptions(n=2)
           ),
-          'Sample 0 for message. system=hi',
+          '''Sample 0 for message. system=[{'type': 'text', 'text': 'hi'}]''',
       )
 
   def test_call_with_json_schema(self):
-    with mock.patch('openai.ChatCompletion.create') as mock_chat_completion:
-      mock_chat_completion.side_effect = mock_chat_completion_query
+    with mock.patch('requests.Session.post') as mock_request:
+      mock_request.side_effect = mock_chat_completion_request
       lm = openai.OpenAI(api_key='test_key', model='gpt-4')
       self.assertEqual(
           lm(