langfun 0.1.2.dev202412020805__py3-none-any.whl → 0.1.2.dev202412030804__py3-none-any.whl
This diff compares two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in the public registry.
- langfun/core/llms/__init__.py +1 -7
- langfun/core/llms/openai.py +142 -207
- langfun/core/llms/openai_test.py +160 -224
- langfun/core/llms/vertexai.py +23 -422
- langfun/core/llms/vertexai_test.py +21 -335
- {langfun-0.1.2.dev202412020805.dist-info → langfun-0.1.2.dev202412030804.dist-info}/METADATA +1 -12
- {langfun-0.1.2.dev202412020805.dist-info → langfun-0.1.2.dev202412030804.dist-info}/RECORD +10 -10
- {langfun-0.1.2.dev202412020805.dist-info → langfun-0.1.2.dev202412030804.dist-info}/LICENSE +0 -0
- {langfun-0.1.2.dev202412020805.dist-info → langfun-0.1.2.dev202412030804.dist-info}/WHEEL +0 -0
- {langfun-0.1.2.dev202412020805.dist-info → langfun-0.1.2.dev202412030804.dist-info}/top_level.txt +0 -0
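This release rewrites `langfun/core/llms/openai.py` (and its tests) around direct REST calls made with `requests`: as the diff below shows, the tests now stub `requests.Session.post` and return canned `requests.Response` objects instead of patching SDK methods as the old mocks did. A minimal, self-contained sketch of that mocking pattern, independent of langfun (the `fake_chat_completion` helper and its one-line payload are illustrative assumptions, not code from the package):

    from unittest import mock

    import requests


    def fake_chat_completion(url, json=None, **kwargs):
      # Stand-in for requests.Session.post: returns a canned
      # chat-completions-style payload with HTTP 200.
      del url, kwargs
      response = requests.Response()
      response.status_code = 200
      response._content = (
          b'{"choices": [{"message": {"content": "hello from the mock"}}]}'
      )
      return response


    with mock.patch('requests.Session.post') as mock_post:
      mock_post.side_effect = fake_chat_completion
      reply = requests.Session().post(
          'https://api.openai.com/v1/chat/completions', json={}
      )
      assert reply.json()['choices'][0]['message']['content'] == (
          'hello from the mock'
      )

Stubbing at the `Session.post` boundary keeps the fake response format-compatible with what the HTTP layer actually returns, which is why the rewritten tests below build real `requests.Response` objects rather than ad-hoc dicts.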
langfun/core/llms/openai_test.py
CHANGED
@@ -13,6 +13,7 @@
 # limitations under the License.
 """Tests for OpenAI models."""
 
+from typing import Any
 import unittest
 from unittest import mock
 
@@ -20,86 +21,106 @@ import langfun.core as lf
 from langfun.core import modalities as lf_modalities
 from langfun.core.llms import openai
 import pyglove as pg
+import requests
 
 
-def mock_completion_query(prompt, *, n=1, **kwargs):
-  del kwargs
-  choices = []
-  for i, _ in enumerate(prompt):
-    for k in range(n):
-      choices.append(pg.Dict(
-          index=i,
-          text=f'Sample {k} for prompt {i}.',
-          logprobs=k / 10,
-      ))
-  return pg.Dict(
-      choices=choices,
-      usage=lf.LMSamplingUsage(
-          prompt_tokens=100,
-          completion_tokens=100,
-          total_tokens=200,
-      ),
-  )
-
-
-def mock_chat_completion_query(messages, *, n=1, **kwargs):
+def mock_chat_completion_request(url: str, json: dict[str, Any], **kwargs):
+  del url, kwargs
+  messages = json['messages']
   if len(messages) > 1:
     system_message = f' system={messages[0]["content"]}'
   else:
     system_message = ''
 
-  if 'response_format' in kwargs:
-    response_format = f' format={kwargs["response_format"]["type"]}'
+  if 'response_format' in json:
+    response_format = f' format={json["response_format"]["type"]}'
   else:
     response_format = ''
 
   choices = []
-  for k in range(n):
-    choices.append(pg.Dict(
-        message=pg.Dict(
+  for k in range(json['n']):
+    if json.get('logprobs'):
+      logprobs = dict(
+          content=[
+              dict(
+                  token='chosen_token',
+                  logprob=0.5,
+                  top_logprobs=[
+                      dict(
+                          token=f'alternative_token_{i + 1}',
+                          logprob=0.1
+                      ) for i in range(3)
+                  ]
+              )
+          ]
+      )
+    else:
+      logprobs = None
+
+    choices.append(dict(
+        message=dict(
             content=(
                 f'Sample {k} for message.{system_message}{response_format}'
            )
        ),
-        logprobs=None,
+        logprobs=logprobs,
    ))
-  return pg.Dict(
-      choices=choices,
-      usage=lf.LMSamplingUsage(
-          prompt_tokens=100,
-          completion_tokens=100,
-          total_tokens=200,
-      ),
-  )
+  response = requests.Response()
+  response.status_code = 200
+  response._content = pg.to_json_str(
+      dict(
+          choices=choices,
+          usage=lf.LMSamplingUsage(
+              prompt_tokens=100,
+              completion_tokens=100,
+              total_tokens=200,
+          ),
+      )
+  ).encode()
+  return response
 
 
-def mock_chat_completion_query_vision(messages, *, n=1, **kwargs):
-  del kwargs
+def mock_chat_completion_request_vision(
+    url: str, json: dict[str, Any], **kwargs
+):
+  del url, kwargs
   choices = []
   urls = [
       c['image_url']['url']
-      for c in messages[0]['content'] if c['type'] == 'image_url'
+      for c in json['messages'][0]['content'] if c['type'] == 'image_url'
   ]
-  for k in range(n):
+  for k in range(json['n']):
     choices.append(pg.Dict(
         message=pg.Dict(
             content=f'Sample {k} for message: {"".join(urls)}'
         ),
         logprobs=None,
     ))
-  return pg.Dict(
-      choices=choices,
-      usage=lf.LMSamplingUsage(
-          prompt_tokens=100,
-          completion_tokens=100,
-          total_tokens=200,
-      ),
-  )
+  response = requests.Response()
+  response.status_code = 200
+  response._content = pg.to_json_str(
+      dict(
+          choices=choices,
+          usage=lf.LMSamplingUsage(
+              prompt_tokens=100,
+              completion_tokens=100,
+              total_tokens=200,
+          ),
+      )
+  ).encode()
+  return response
 
 
 class OpenAITest(unittest.TestCase):
   """Tests for OpenAI language model."""
 
+  def test_dir(self):
+    self.assertIn('gpt-4-turbo', openai.OpenAI.dir())
+
+  def test_key(self):
+    with self.assertRaisesRegex(ValueError, 'Please specify `api_key`'):
+      openai.Gpt4()('hi')
+
   def test_model_id(self):
     self.assertEqual(
         openai.Gpt35(api_key='test_key').model_id, 'OpenAI(text-davinci-003)')
@@ -112,29 +133,9 @@ class OpenAITest(unittest.TestCase):
   def test_max_concurrency(self):
     self.assertGreater(openai.Gpt35(api_key='test_key').max_concurrency, 0)
 
-  def test_get_request_args(self):
-    self.assertEqual(
-        openai.Gpt35(api_key='test_key', timeout=90.0)._get_request_args(
-            lf.LMSamplingOptions(
-                temperature=2.0,
-                logprobs=True,
-                n=2,
-                max_tokens=4096,
-                top_p=1.0)),
-        dict(
-            engine='text-davinci-003',
-            logprobs=True,
-            top_logprobs=None,
-            n=2,
-            temperature=2.0,
-            max_tokens=4096,
-            stream=False,
-            timeout=90.0,
-            top_p=1.0,
-        )
-    )
+  def test_request_args(self):
     self.assertEqual(
-        openai.Gpt4(api_key='test_key')._get_request_args(
+        openai.Gpt4(api_key='test_key')._request_args(
             lf.LMSamplingOptions(
                 temperature=1.0, stop=['\n'], n=1, random_seed=123
             )
@@ -144,40 +145,93 @@ class OpenAITest(unittest.TestCase):
             top_logprobs=None,
             n=1,
             temperature=1.0,
-            stream=False,
-            timeout=120.0,
             stop=['\n'],
             seed=123,
         ),
     )
     with self.assertRaisesRegex(RuntimeError, '`logprobs` is not supported.*'):
-      openai.GptO1Preview(api_key='test_key')._get_request_args(
+      openai.GptO1Preview(api_key='test_key')._request_args(
           lf.LMSamplingOptions(
              temperature=1.0, logprobs=True
          )
      )
 
-  def test_call_completion(self):
-    with mock.patch('openai.Completion.create') as mock_completion:
-      mock_completion.side_effect = mock_completion_query
-      lm = openai.OpenAI(api_key='test_key', model='text-davinci-003')
+  def test_call_chat_completion(self):
+    with mock.patch('requests.Session.post') as mock_request:
+      mock_request.side_effect = mock_chat_completion_request
+      lm = openai.OpenAI(
+          model='gpt-4',
+          api_key='test_key',
+          organization='my_org',
+          project='my_project'
+      )
       self.assertEqual(
           lm('hello', sampling_options=lf.LMSamplingOptions(n=2)),
-          'Sample 0 for prompt 0.',
+          'Sample 0 for message.',
       )
 
-  def test_call_chat_completion(self):
-    with mock.patch('openai.ChatCompletion.create') as mock_chat_completion:
-      mock_chat_completion.side_effect = mock_chat_completion_query
-      lm = openai.OpenAI(api_key='test_key', model='gpt-4')
+  def test_call_chat_completion_with_logprobs(self):
+    with mock.patch('requests.Session.post') as mock_request:
+      mock_request.side_effect = mock_chat_completion_request
+      lm = openai.OpenAI(
+          model='gpt-4',
+          api_key='test_key',
+          organization='my_org',
+          project='my_project'
+      )
+      results = lm.sample(['hello'], logprobs=True)
+      self.assertEqual(len(results), 1)
       self.assertEqual(
-          lm('hello', sampling_options=lf.LMSamplingOptions(n=2)),
-          'Sample 0 for message.',
+          results[0],
+          lf.LMSamplingResult(
+              [
+                  lf.LMSample(
+                      response=lf.AIMessage(
+                          text='Sample 0 for message.',
+                          metadata={
+                              'score': 0.0,
+                              'logprobs': [(
+                                  'chosen_token',
+                                  0.5,
+                                  [
+                                      ('alternative_token_1', 0.1),
+                                      ('alternative_token_2', 0.1),
+                                      ('alternative_token_3', 0.1),
+                                  ],
+                              )],
+                              'is_cached': False,
+                              'usage': lf.LMSamplingUsage(
+                                  prompt_tokens=100,
+                                  completion_tokens=100,
+                                  total_tokens=200,
+                                  estimated_cost=0.009,
+                              ),
+                          },
+                          tags=['lm-response'],
+                      ),
+                      logprobs=[(
+                          'chosen_token',
+                          0.5,
+                          [
+                              ('alternative_token_1', 0.1),
+                              ('alternative_token_2', 0.1),
+                              ('alternative_token_3', 0.1),
+                          ],
+                      )],
+                  )
+              ],
+              usage=lf.LMSamplingUsage(
+                  prompt_tokens=100,
+                  completion_tokens=100,
+                  total_tokens=200,
+                  estimated_cost=0.009,
+              ),
+          ),
       )
 
   def test_call_chat_completion_vision(self):
-    with mock.patch('openai.ChatCompletion.create') as mock_chat_completion:
-      mock_chat_completion.side_effect = mock_chat_completion_query_vision
+    with mock.patch('requests.Session.post') as mock_request:
+      mock_request.side_effect = mock_chat_completion_request_vision
       lm_1 = openai.Gpt4Turbo(api_key='test_key')
       lm_2 = openai.Gpt4VisionPreview(api_key='test_key')
       for lm in (lm_1, lm_2):
@@ -191,136 +245,18 @@ class OpenAITest(unittest.TestCase):
             ),
             'Sample 0 for message: https://fake/image',
         )
-
-  def test_sample_completion(self):
-    with mock.patch('openai.Completion.create') as mock_completion:
-      mock_completion.side_effect = mock_completion_query
-      lm = openai.OpenAI(api_key='test_key', model='text-davinci-003')
-      results = lm.sample(
-          ['hello', 'bye'], sampling_options=lf.LMSamplingOptions(n=3)
+    lm_3 = openai.Gpt35Turbo(api_key='test_key')
+    with self.assertRaisesRegex(ValueError, 'Unsupported modality'):
+      lm_3(
+          lf.UserMessage(
+              'hello <<[[image]]>>',
+              image=lf_modalities.Image.from_uri('https://fake/image')
+          ),
       )
 
-      self.assertEqual(len(results), 2)
-      self.assertEqual(
-          results[0],
-          lf.LMSamplingResult(
-              [
-                  lf.LMSample(
-                      lf.AIMessage(
-                          'Sample 0 for prompt 0.',
-                          score=0.0,
-                          logprobs=None,
-                          is_cached=False,
-                          usage=lf.LMSamplingUsage(
-                              prompt_tokens=16,
-                              completion_tokens=16,
-                              total_tokens=33
-                          ),
-                          tags=[lf.Message.TAG_LM_RESPONSE],
-                      ),
-                      score=0.0,
-                      logprobs=None,
-                  ),
-                  lf.LMSample(
-                      lf.AIMessage(
-                          'Sample 1 for prompt 0.',
-                          score=0.1,
-                          logprobs=None,
-                          is_cached=False,
-                          usage=lf.LMSamplingUsage(
-                              prompt_tokens=16,
-                              completion_tokens=16,
-                              total_tokens=33
-                          ),
-                          tags=[lf.Message.TAG_LM_RESPONSE],
-                      ),
-                      score=0.1,
-                      logprobs=None,
-                  ),
-                  lf.LMSample(
-                      lf.AIMessage(
-                          'Sample 2 for prompt 0.',
-                          score=0.2,
-                          logprobs=None,
-                          is_cached=False,
-                          usage=lf.LMSamplingUsage(
-                              prompt_tokens=16,
-                              completion_tokens=16,
-                              total_tokens=33
-                          ),
-                          tags=[lf.Message.TAG_LM_RESPONSE],
-                      ),
-                      score=0.2,
-                      logprobs=None,
-                  ),
-              ],
-              usage=lf.LMSamplingUsage(
-                  prompt_tokens=50, completion_tokens=50, total_tokens=100
-              ),
-          ),
-      )
-      self.assertEqual(
-          results[1],
-          lf.LMSamplingResult(
-              [
-                  lf.LMSample(
-                      lf.AIMessage(
-                          'Sample 0 for prompt 1.',
-                          score=0.0,
-                          logprobs=None,
-                          is_cached=False,
-                          usage=lf.LMSamplingUsage(
-                              prompt_tokens=16,
-                              completion_tokens=16,
-                              total_tokens=33
-                          ),
-                          tags=[lf.Message.TAG_LM_RESPONSE],
-                      ),
-                      score=0.0,
-                      logprobs=None,
-                  ),
-                  lf.LMSample(
-                      lf.AIMessage(
-                          'Sample 1 for prompt 1.',
-                          score=0.1,
-                          logprobs=None,
-                          is_cached=False,
-                          usage=lf.LMSamplingUsage(
-                              prompt_tokens=16,
-                              completion_tokens=16,
-                              total_tokens=33
-                          ),
-                          tags=[lf.Message.TAG_LM_RESPONSE],
-                      ),
-                      score=0.1,
-                      logprobs=None,
-                  ),
-                  lf.LMSample(
-                      lf.AIMessage(
-                          'Sample 2 for prompt 1.',
-                          score=0.2,
-                          logprobs=None,
-                          is_cached=False,
-                          usage=lf.LMSamplingUsage(
-                              prompt_tokens=16,
-                              completion_tokens=16,
-                              total_tokens=33
-                          ),
-                          tags=[lf.Message.TAG_LM_RESPONSE],
-                      ),
-                      score=0.2,
-                      logprobs=None,
-                  ),
-              ],
-              usage=lf.LMSamplingUsage(
-                  prompt_tokens=50, completion_tokens=50, total_tokens=100
-              ),
-          ),
-      )
-
   def test_sample_chat_completion(self):
-    with mock.patch('openai.ChatCompletion.create') as mock_chat_completion:
-      mock_chat_completion.side_effect = mock_chat_completion_query
+    with mock.patch('requests.Session.post') as mock_request:
+      mock_request.side_effect = mock_chat_completion_request
       openai.SUPPORTED_MODELS_AND_SETTINGS['gpt-4'].update({
           'cost_per_1k_input_tokens': 1.0,
          'cost_per_1k_output_tokens': 1.0,
|
@@ -458,8 +394,8 @@ class OpenAITest(unittest.TestCase):
|
|
458
394
|
)
|
459
395
|
|
460
396
|
def test_sample_with_contextual_options(self):
|
461
|
-
with mock.patch('
|
462
|
-
|
397
|
+
with mock.patch('requests.Session.post') as mock_request:
|
398
|
+
mock_request.side_effect = mock_chat_completion_request
|
463
399
|
lm = openai.OpenAI(api_key='test_key', model='text-davinci-003')
|
464
400
|
with lf.use_settings(sampling_options=lf.LMSamplingOptions(n=2)):
|
465
401
|
results = lm.sample(['hello'])
|
@@ -471,7 +407,7 @@ class OpenAITest(unittest.TestCase):
             [
                 lf.LMSample(
                     lf.AIMessage(
-                        'Sample 0 for prompt 0.',
+                        'Sample 0 for message.',
                         score=0.0,
                         logprobs=None,
                         is_cached=False,
@@ -487,8 +423,8 @@ class OpenAITest(unittest.TestCase):
                 ),
                 lf.LMSample(
                     lf.AIMessage(
-                        'Sample 1 for prompt 0.',
-                        score=0.1,
+                        'Sample 1 for message.',
+                        score=0.0,
                         logprobs=None,
                         is_cached=False,
                         usage=lf.LMSamplingUsage(
@@ -498,19 +434,19 @@ class OpenAITest(unittest.TestCase):
                         ),
                         tags=[lf.Message.TAG_LM_RESPONSE],
                     ),
-                    score=0.1,
+                    score=0.0,
                     logprobs=None,
                 ),
             ],
             usage=lf.LMSamplingUsage(
                 prompt_tokens=100, completion_tokens=100, total_tokens=200
            ),
-        )
+        ),
     )
 
   def test_call_with_system_message(self):
-    with mock.patch('openai.ChatCompletion.create') as mock_chat_completion:
-      mock_chat_completion.side_effect = mock_chat_completion_query
+    with mock.patch('requests.Session.post') as mock_request:
+      mock_request.side_effect = mock_chat_completion_request
       lm = openai.OpenAI(api_key='test_key', model='gpt-4')
       self.assertEqual(
           lm(
@@ -520,12 +456,12 @@ class OpenAITest(unittest.TestCase):
               ),
               sampling_options=lf.LMSamplingOptions(n=2)
           ),
-          'Sample 0 for message. system=hi',
+          '''Sample 0 for message. system=[{'type': 'text', 'text': 'hi'}]''',
       )
 
   def test_call_with_json_schema(self):
-    with mock.patch('openai.ChatCompletion.create') as mock_chat_completion:
-      mock_chat_completion.side_effect = mock_chat_completion_query
+    with mock.patch('requests.Session.post') as mock_request:
+      mock_request.side_effect = mock_chat_completion_request
       lm = openai.OpenAI(api_key='test_key', model='gpt-4')
       self.assertEqual(
           lm(
|