langfun 0.1.2.dev202501090804__py3-none-any.whl → 0.1.2.dev202501100804__py3-none-any.whl
- langfun/core/llms/__init__.py +3 -0
- langfun/core/llms/deepseek.py +8 -152
- langfun/core/llms/deepseek_test.py +12 -389
- langfun/core/llms/groq.py +12 -99
- langfun/core/llms/groq_test.py +31 -137
- langfun/core/llms/llama_cpp.py +17 -54
- langfun/core/llms/llama_cpp_test.py +2 -34
- langfun/core/llms/openai.py +9 -147
- langfun/core/llms/openai_compatible.py +179 -0
- langfun/core/llms/openai_compatible_test.py +480 -0
- langfun/core/llms/openai_test.py +13 -423
- langfun/core/modalities/mime.py +8 -0
- langfun/core/modalities/mime_test.py +19 -4
- langfun/core/modality_test.py +0 -1
- {langfun-0.1.2.dev202501090804.dist-info → langfun-0.1.2.dev202501100804.dist-info}/METADATA +1 -1
- {langfun-0.1.2.dev202501090804.dist-info → langfun-0.1.2.dev202501100804.dist-info}/RECORD +19 -17
- {langfun-0.1.2.dev202501090804.dist-info → langfun-0.1.2.dev202501100804.dist-info}/LICENSE +0 -0
- {langfun-0.1.2.dev202501090804.dist-info → langfun-0.1.2.dev202501100804.dist-info}/WHEEL +0 -0
- {langfun-0.1.2.dev202501090804.dist-info → langfun-0.1.2.dev202501100804.dist-info}/top_level.txt +0 -0
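The centerpiece of this release is the new langfun/core/llms/openai_compatible.py (+179 lines), which factors the request/response plumbing shared by the Groq, OpenAI and llama.cpp backends out of the per-provider classes. The new module itself is not shown in this diff; the following is a minimal sketch of what the base class plausibly provides, inferred from how the subclasses below use it (names, signatures and argument handling are assumptions, not the released code):

# A minimal sketch, NOT the released code: openai_compatible.py is not shown
# in this diff. Names and behavior are inferred from how Groq, OpenAI and
# LlamaCppRemote use the base class in the hunks below.
from typing import Any

import langfun.core as lf
from langfun.core.llms import rest


class OpenAICompatible(rest.REST):
  """Base class for models served behind an OpenAI-compatible chat API."""

  @property
  def headers(self) -> dict[str, Any]:
    # Subclasses extend this dict (e.g. Groq/OpenAI add `Authorization`).
    return {'Content-Type': 'application/json'}

  def _request_args(self, options: lf.LMSamplingOptions) -> dict[str, Any]:
    """Builds common chat-completion args; subclasses prune or override."""
    # `model` is assumed to be a field defined by the subclass.
    args = dict(model=self.model, n=options.n)
    if options.logprobs:
      args['logprobs'] = options.logprobs
    if options.top_logprobs is not None:
      args['top_logprobs'] = options.top_logprobs
    if options.temperature is not None:
      args['temperature'] = options.temperature
    if options.max_tokens is not None:
      args['max_completion_tokens'] = options.max_tokens
    if options.top_p is not None:
      args['top_p'] = options.top_p
    if options.stop:
      args['stop'] = options.stop
    if options.random_seed is not None:
      args['seed'] = options.random_seed
    return args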
langfun/core/llms/groq.py
CHANGED
@@ -17,8 +17,7 @@ import os
 from typing import Annotated, Any
 
 import langfun.core as lf
-from langfun.core import modalities as lf_modalities
-from langfun.core.llms import rest
+from langfun.core.llms import openai_compatible
 import pyglove as pg
 
 
@@ -95,7 +94,7 @@ SUPPORTED_MODELS_AND_SETTINGS = {
 
 
 @lf.use_init_args(['model'])
-class Groq(rest.REST):
+class Groq(openai_compatible.OpenAICompatible):
   """Groq LLMs through REST APIs (OpenAI compatible).
 
   See https://platform.openai.com/docs/api-reference/chat
@@ -108,10 +107,6 @@ class Groq(rest.REST):
       'The name of the model to use.',
   ]
 
-  multimodal: Annotated[bool, 'Whether this model has multimodal support.'] = (
-      False
-  )
-
   api_key: Annotated[
       str | None,
       (
@@ -122,25 +117,19 @@ class Groq(rest.REST):
 
   api_endpoint: str = 'https://api.groq.com/openai/v1/chat/completions'
 
-  def _on_bound(self):
-    super()._on_bound()
-    self._api_key = None
-
-  def _initialize(self):
+  @property
+  def headers(self) -> dict[str, Any]:
     api_key = self.api_key or os.environ.get('GROQ_API_KEY', None)
     if not api_key:
       raise ValueError(
           'Please specify `api_key` during `__init__` or set environment '
           'variable `GROQ_API_KEY` with your Groq API key.'
       )
-    self._api_key = api_key
-
-  @property
-  def headers(self) -> dict[str, Any]:
-    return {
-        'Authorization': f'Bearer {self._api_key}',
-        'Content-Type': 'application/json',
-    }
+    headers = super().headers
+    headers.update({
+        'Authorization': f'Bearer {api_key}',
+    })
+    return headers
 
   @property
   def model_id(self) -> str:
@@ -170,90 +159,14 @@ class Groq(rest.REST):
         + cost_per_1k_output_tokens * num_output_tokens
     ) / 1000
 
-  def request(
-      self,
-      prompt: lf.Message,
-      sampling_options: lf.LMSamplingOptions
-  ) -> dict[str, Any]:
-    """Returns the JSON input for a message."""
-    request = dict()
-    request.update(self._request_args(sampling_options))
-    request.update(
-        dict(
-            messages=[
-                dict(role='user', content=self._content_from_message(prompt))
-            ]
-        )
-    )
-    return request
-
   def _request_args(self, options: lf.LMSamplingOptions) -> dict[str, Any]:
     """Returns a dict as request arguments."""
     # `logprobs` and `top_logprobs` flags are not supported on Groq yet.
-    args = dict(
-        model=self.model,
-        n=options.n,
-        stream=False,
-    )
-
-    if options.temperature is not None:
-      args['temperature'] = options.temperature
-    if options.max_tokens is not None:
-      args['max_tokens'] = options.max_tokens
-    if options.top_p is not None:
-      args['top_p'] = options.top_p
-    if options.stop:
-      args['stop'] = options.stop
+    args = super()._request_args(options)
+    args.pop('logprobs', None)
+    args.pop('top_logprobs', None)
     return args
 
-  def _content_from_message(self, prompt: lf.Message) -> list[dict[str, Any]]:
-    """Converts an message to Groq's content protocol (list of dicts)."""
-    # Refer: https://platform.openai.com/docs/api-reference/chat/create
-    content = []
-    for chunk in prompt.chunk():
-      if isinstance(chunk, str):
-        item = dict(type='text', text=chunk)
-      elif (
-          self.multimodal
-          and isinstance(chunk, lf_modalities.Image)
-          and chunk.uri
-      ):
-        # NOTE(daiyip): Groq only support image URL.
-        item = dict(type='image_url', image_url=chunk.uri)
-      else:
-        raise ValueError(f'Unsupported modality object: {chunk!r}.')
-      content.append(item)
-    return content
-
-  def result(self, json: dict[str, Any]) -> lf.LMSamplingResult:
-    samples = [
-        lf.LMSample(self._message_from_choice(choice), score=0.0)
-        for choice in json['choices']
-    ]
-    usage = json['usage']
-    return lf.LMSamplingResult(
-        samples,
-        usage=lf.LMSamplingUsage(
-            prompt_tokens=usage['prompt_tokens'],
-            completion_tokens=usage['completion_tokens'],
-            total_tokens=usage['total_tokens'],
-            estimated_cost=self.estimate_cost(
-                num_input_tokens=usage['prompt_tokens'],
-                num_output_tokens=usage['completion_tokens'],
-            ),
-        ),
-    )
-
-  def _message_from_choice(self, choice: dict[str, Any]) -> lf.Message:
-    """Converts Groq's content protocol to message."""
-    # Refer: https://platform.openai.com/docs/api-reference/chat/create
-    content = choice['message']['content']
-    if isinstance(content, str):
-      return lf.AIMessage(content)
-    return lf.AIMessage.from_chunks(
-        [x['text'] for x in content if x['type'] == 'text']
-    )
-
 
 class GroqLlama3_2_3B(Groq):  # pylint: disable=invalid-name
   """Llama3.2-3B with 8K context window.
langfun/core/llms/groq_test.py
CHANGED
@@ -11,89 +11,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Tests for Groq models."""
-
 import os
-from typing import Any
 import unittest
-from unittest import mock
-
-from langfun.core import modalities as lf_modalities
+import langfun.core as lf
 from langfun.core.llms import groq
-import pyglove as pg
-import requests
-
-
-def mock_requests_post(url: str, json: dict[str, Any], **kwargs):
-  del url, kwargs
-
-  response = requests.Response()
-  response.status_code = 200
-  response._content = pg.to_json_str({
-      'choices': [{
-          'message': {
-              'content': [{
-                  'type': 'text',
-                  'text': (
-                      f'hello with temperature={json.get("temperature")}, '
-                      f'top_p={json.get("top_p")}, '
-                      f'max_tokens={json.get("max_tokens")}, '
-                      f'stop={json.get("stop")}.'
-                  ),
-              }],
-          }
-      }],
-      'usage': {
-          'prompt_tokens': 2,
-          'completion_tokens': 1,
-          'total_tokens': 3,
-      },
-  }).encode()
-  return response
-
-
-def mock_mm_requests_post(url: str, json: dict[str, Any], **kwargs):
-  del url, kwargs
-  v = json['messages'][0]['content'][0]
-  image = lf_modalities.Image.from_uri(v['image_url'])
-
-  response = requests.Response()
-  response.status_code = 200
-  response._content = pg.to_json_str({
-      'choices': [
-          {
-              'message': {
-                  'content': [{
-                      'type': 'text',
-                      'text': image.uri,
-                  }],
-              }
-          }
-      ],
-      'usage': {
-          'prompt_tokens': 2,
-          'completion_tokens': 1,
-          'total_tokens': 3,
-      },
-  }).encode()
-  return response
-
-
-def mock_requests_post_error(status_code, error_type, error_message):
-  def _mock_requests(url: str, json: dict[str, Any], **kwargs):
-    del url, json, kwargs
-    response = requests.Response()
-    response.status_code = status_code
-    response._content = pg.to_json_str(
-        {
-            'error': {
-                'type': error_type,
-                'message': error_message,
-            }
-        }
-    ).encode()
-    return response
-
-  return _mock_requests
-
 
 
 class AuthropicTest(unittest.TestCase):
@@ -101,69 +22,42 @@ class AuthropicTest(unittest.TestCase):
   def test_basics(self):
     self.assertEqual(groq.GroqMistral_8x7B().model_id, 'mixtral-8x7b-32768')
     self.assertEqual(groq.GroqMistral_8x7B().max_concurrency, 16)
+    self.assertEqual(groq.GroqMistral_8x7B().estimate_cost(100, 100), 4.8e-5)
+
+  def test_request_args(self):
+    args = groq.GroqMistral_8x7B()._request_args(
+        lf.LMSamplingOptions(
+            temperature=1.0, stop=['\n'], n=1, random_seed=123,
+            logprobs=True, top_logprobs=True
+        )
+    )
+    self.assertNotIn('logprobs', args)
+    self.assertNotIn('top_logprobs', args)
 
   def test_api_key(self):
     lm = groq.GroqMistral_8x7B()
     with self.assertRaisesRegex(ValueError, 'Please specify `api_key`'):
-      lm('hi')
-
-    with mock.patch('requests.Session.post') as mock_request:
-      mock_request.side_effect = mock_requests_post
-
-      lm = groq.GroqMistral_8x7B(api_key='fake key')
-      self.assertRegex(lm('hi').text, 'hello.*')
-
-      os.environ['GROQ_API_KEY'] = 'abc'
-      lm = groq.GroqMistral_8x7B()
-      self.assertRegex(lm('hi').text, 'hello.*')
-      del os.environ['GROQ_API_KEY']
-
-  def test_call(self):
-    with mock.patch('requests.Session.post') as mock_request:
-      mock_request.side_effect = mock_requests_post
-      lm = groq.GroqLlama3_70B(api_key='fake_key')
-      response = lm(
-          'hello',
-          temperature=0.0,
-          max_tokens=1024,
-          top_k=0.1,
-          top_p=0.2,
-          stop=['\n'],
-      )
-      self.assertEqual(
-          response.text,
-          (
-              'hello with temperature=0.0, top_p=0.2, '
-              "max_tokens=1024, stop=['\\n']."
-          ),
-      )
-      self.assertIsNotNone(response.usage)
-      self.assertIsNotNone(response.usage.prompt_tokens, 2)
-      self.assertIsNotNone(response.usage.completion_tokens, 1)
-      self.assertIsNotNone(response.usage.total_tokens, 3)
+      _ = lm.headers
 
-  def test_mm_call(self):
-    with mock.patch('requests.Session.post') as mock_mm_request:
-      mock_mm_request.side_effect = mock_mm_requests_post
-      lm = groq.GroqLlama3_70B(multimodal=True, api_key='fake_key')
-      response = lm(lf_modalities.Image.from_uri('https://fake/image.jpg'))
-      self.assertEqual(response.text, 'https://fake/image.jpg')
+    lm = groq.GroqMistral_8x7B(api_key='fake key')
+    self.assertEqual(
+        lm.headers,
+        {
+            'Content-Type': 'application/json',
+            'Authorization': 'Bearer fake key',
+        }
+    )
 
-  def test_call_errors(self):
-    for status_code, error_type, error_message in [
-        (429, 'rate_limit', 'Rate limit exceeded.'),
-        (503, 'service_unavailable', 'Service unavailable.'),
-        (500, 'bad_request', 'Bad request.'),
-    ]:
-      with mock.patch('requests.Session.post') as mock_mm_request:
-        mock_mm_request.side_effect = mock_requests_post_error(
-            status_code, error_type, error_message
-        )
-        lm = groq.GroqLlama3_70B(api_key='fake_key')
-        with self.assertRaisesRegex(
-            Exception, f'{status_code}:.*{error_type}'
-        ):
-          lm('hello', max_attempts=1)
+    os.environ['GROQ_API_KEY'] = 'abc'
+    lm = groq.GroqMistral_8x7B()
+    self.assertEqual(
+        lm.headers,
+        {
+            'Content-Type': 'application/json',
+            'Authorization': 'Bearer abc',
+        }
+    )
+    del os.environ['GROQ_API_KEY']
 
 
 if __name__ == '__main__':
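As a sanity check on the new test_basics assertion: estimate_cost(100, 100) evaluates (cost_per_1k_input_tokens * 100 + cost_per_1k_output_tokens * 100) / 1000 (see the context lines in the groq.py hunk above), so the expected 4.8e-5 implies the two per-1k rates for mixtral-8x7b-32768 sum to $0.00048, e.g. $0.00024 each; the actual rates live in SUPPORTED_MODELS_AND_SETTINGS, which this diff does not show.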
langfun/core/llms/llama_cpp.py
CHANGED
@@ -13,72 +13,35 @@
 # limitations under the License.
 """Language models from llama.cpp."""
 
-from typing import Annotated, Any
-
-import langfun.core as lf
-from langfun.core.llms import rest
+from typing import Annotated
+from langfun.core.llms import openai_compatible
 import pyglove as pg
 
 
-class LlamaCppRemote(rest.REST):
+@pg.use_init_args(['url', 'model'])
+@pg.members([('api_endpoint', pg.typing.Str().freeze(''))])
+class LlamaCppRemote(openai_compatible.OpenAICompatible):
  """The remote LLaMA C++ model.
 
   The Remote LLaMA C++ models can be launched via
   https://github.com/ggerganov/llama.cpp/tree/master/examples/server
   """
+  url: Annotated[
+      str,
+      'The URL of the LLaMA C++ server.',
+  ]
+
+  model: Annotated[
+      str,
+      'The name of the model to use.',
+  ] = ''
 
-  @…
-  def …
-  …
+  @property
+  def api_endpoint(self) -> str:
+    return self.url + '/completion'
 
   @property
   def model_id(self) -> str:
     """Returns a string to identify the model."""
     return f'LLaMAC++({self.model or ""})'
 
-  def request(
-      self, prompt: lf.Message, sampling_options: lf.LMSamplingOptions
-  ) -> dict[str, Any]:
-    """Returns the JSON input for a message."""
-    request = dict()
-    request.update(self._request_args(sampling_options))
-    # NOTE(daiyip): multi-modal is current not supported.
-    request['prompt'] = prompt.text
-    return request
-
-  def _request_args(self, options: lf.LMSamplingOptions) -> dict[str, Any]:
-    """Returns a dict as request arguments."""
-    args = dict(
-        n_predict=options.max_tokens or 1024,
-        top_k=options.top_k or 50,
-        top_p=options.top_p or 0.95,
-    )
-    if options.temperature is not None:
-      args['temperature'] = options.temperature
-    return args
-
-  def result(self, json: dict[str, Any]) -> lf.LMSamplingResult:
-    return lf.LMSamplingResult(
-        [lf.LMSample(item['content'], score=0.0) for item in json['items']]
-    )
-
-  def _sample_single(self, prompt: lf.Message) -> lf.LMSamplingResult:
-    request = self.request(prompt, self.sampling_options)
-
-    def _sample_one_example(request):
-      response = self._session.post(
-          self.api_endpoint,
-          json=request,
-          timeout=self.timeout,
-      )
-      if response.status_code == 200:
-        return response.json()
-      else:
-        error_cls = self._error_cls_from_status(response.status_code)
-        raise error_cls(f'{response.status_code}: {response.content}')
-
-    items = self._parallel_execute_with_currency_control(
-        _sample_one_example,
-        [request] * (self.sampling_options.n or 1),
-    )
-    return self.result(dict(items=items))
langfun/core/llms/llama_cpp_test.py
CHANGED
@@ -11,48 +11,16 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Tests for llama cpp models."""
-
-import typing
 import unittest
-from unittest import mock
-
 from langfun.core.llms import llama_cpp
 
 
-def mock_requests_post(url: str, json: typing.Dict[str, typing.Any], **kwargs):
-  del kwargs
-
-  class TEMP:
-    @property
-    def status_code(self):
-      return 200
-
-    def json(self):
-      return {"content": json["prompt"] + "\n" + url}
-
-  return TEMP()
-
-
 class LlamaCppRemoteTest(unittest.TestCase):
   """Tests for the LlamaCppRemote model."""
 
-  def test_sample(self):
-    with mock.patch("requests.Session.post") as mock_request:
-      mock_request.side_effect = mock_requests_post
-      lm = llama_cpp.LlamaCppRemote("http://127.0.0.1:8080")
-      [result] = lm.sample(["hello"], n=2)
-      self.assertEqual(
-          len(result.samples),
-          2
-      )
-      self.assertEqual(
-          str(result.samples[0].response),
-          "hello\nhttp://127.0.0.1:8080/completion",
-      )
-
-  def test_model_id(self):
+  def test_basics(self):
     lm = llama_cpp.LlamaCppRemote("http://127.0.0.1:8080")
+    self.assertEqual(lm.api_endpoint, "http://127.0.0.1:8080/completion")
     self.assertEqual(lm.model_id, "LLaMAC++()")
     lm = llama_cpp.LlamaCppRemote("xxx", model="x")
     self.assertEqual(lm.model_id, "LLaMAC++(x)")
langfun/core/llms/openai.py
CHANGED
@@ -17,8 +17,7 @@ import os
 from typing import Annotated, Any
 
 import langfun.core as lf
-from langfun.core import modalities as lf_modalities
-from langfun.core.llms import rest
+from langfun.core.llms import openai_compatible
 import pyglove as pg
 
 
@@ -299,7 +298,7 @@ SUPPORTED_MODELS_AND_SETTINGS = {
 
 
 @lf.use_init_args(['model'])
-class OpenAI(rest.REST):
+class OpenAI(openai_compatible.OpenAICompatible):
   """OpenAI model."""
 
   model: pg.typing.Annotated[
@@ -311,11 +310,6 @@ class OpenAI(rest.REST):
 
   api_endpoint: str = 'https://api.openai.com/v1/chat/completions'
 
-  multimodal: Annotated[
-      bool,
-      'Whether this model has multimodal support.'
-  ] = False
-
   api_key: Annotated[
       str | None,
       (
@@ -363,10 +357,9 @@ class OpenAI(rest.REST):
 
   @property
   def headers(self) -> dict[str, Any]:
-    headers = {
-        'Content-Type': 'application/json',
-        'Authorization': f'Bearer {self._api_key}',
-    }
+    assert self._api_initialized
+    headers = super().headers
+    headers['Authorization'] = f'Bearer {self._api_key}'
    if self._organization:
      headers['OpenAI-Organization'] = self._organization
    if self._project:
@@ -411,141 +404,10 @@ class OpenAI(rest.REST):
 
   def _request_args(
       self, options: lf.LMSamplingOptions) -> dict[str, Any]:
-    # Reference:
-    # https://platform.openai.com/docs/api-reference/chat/create
-    …
-    args = dict(
-        model=self.model,
-        n=options.n,
-        top_logprobs=options.top_logprobs,
-    )
-    if options.logprobs:
-      # Reasoning models (o1 series) does not support `logprobs` by 2024/09/12.
-      if self.model.startswith('o1-'):
-        raise RuntimeError('`logprobs` is not supported on {self.model!r}.')
-      args['logprobs'] = options.logprobs
-
-    if options.temperature is not None:
-      args['temperature'] = options.temperature
-    if options.max_tokens is not None:
-      args['max_completion_tokens'] = options.max_tokens
-    if options.top_p is not None:
-      args['top_p'] = options.top_p
-    if options.stop:
-      args['stop'] = options.stop
-    if options.random_seed is not None:
-      args['seed'] = options.random_seed
-    return args
-
-  def _content_from_message(self, message: lf.Message):
-    """Returns a OpenAI content object from a Langfun message."""
-    def _uri_from(chunk: lf.Modality) -> str:
-      if chunk.uri and chunk.uri.lower().startswith(
-          ('http:', 'https:', 'ftp:')
-      ):
-        return chunk.uri
-      return chunk.content_uri
-
-    content = []
-    for chunk in message.chunk():
-      if isinstance(chunk, str):
-        item = dict(type='text', text=chunk)
-      elif isinstance(chunk, lf_modalities.Image) and self.multimodal:
-        item = dict(type='image_url', image_url=dict(url=_uri_from(chunk)))
-      else:
-        raise ValueError(f'Unsupported modality: {chunk!r}.')
-      content.append(item)
-    return content
-
-  def request(
-      self,
-      prompt: lf.Message,
-      sampling_options: lf.LMSamplingOptions
-  ) -> dict[str, Any]:
-    """Returns the JSON input for a message."""
-    request_args = self._request_args(sampling_options)
-
-    # Users could use `metadata_json_schema` to pass additional
-    # request arguments.
-    json_schema = prompt.metadata.get('json_schema')
-    if json_schema is not None:
-      if not isinstance(json_schema, dict):
-        raise ValueError(
-            f'`json_schema` must be a dict, got {json_schema!r}.'
-        )
-      if 'title' not in json_schema:
-        raise ValueError(
-            f'The root of `json_schema` must have a `title` field, '
-            f'got {json_schema!r}.'
-        )
-      request_args.update(
-          response_format=dict(
-              type='json_schema',
-              json_schema=dict(
-                  schema=json_schema,
-                  name=json_schema['title'],
-                  strict=True,
-              )
-          )
-      )
-      prompt.metadata.formatted_text = (
-          prompt.text
-          + '\n\n [RESPONSE FORMAT (not part of prompt)]\n'
-          + pg.to_json_str(request_args['response_format'], json_indent=2)
-      )
-
-    # Prepare messages.
-    messages = []
-    # Users could use `metadata_system_message` to pass system message.
-    system_message = prompt.metadata.get('system_message')
-    if system_message:
-      system_message = lf.SystemMessage.from_value(system_message)
-      messages.append(
-          dict(role='system',
-               content=self._content_from_message(system_message))
-      )
-    messages.append(
-        dict(role='user', content=self._content_from_message(prompt))
-    )
-    request = dict()
-    request.update(request_args)
-    request['messages'] = messages
-    return request
-
-  def _parse_choice(self, choice: dict[str, Any]) -> lf.LMSample:
-    # Reference:
-    # https://platform.openai.com/docs/api-reference/chat/object
-    logprobs = None
-    choice_logprobs = choice.get('logprobs')
-    if choice_logprobs:
-      logprobs = [
-          (
-              t['token'],
-              t['logprob'],
-              [(tt['token'], tt['logprob']) for tt in t['top_logprobs']],
-          )
-          for t in choice_logprobs['content']
-      ]
-    return lf.LMSample(
-        choice['message']['content'],
-        score=0.0,
-        logprobs=logprobs,
-    )
-
-  def result(self, json: dict[str, Any]) -> lf.LMSamplingResult:
-    usage = json['usage']
-    return lf.LMSamplingResult(
-        samples=[self._parse_choice(choice) for choice in json['choices']],
-        usage=lf.LMSamplingUsage(
-            prompt_tokens=usage['prompt_tokens'],
-            completion_tokens=usage['completion_tokens'],
-            total_tokens=usage['total_tokens'],
-            estimated_cost=self.estimate_cost(
-                num_input_tokens=usage['prompt_tokens'],
-                num_output_tokens=usage['completion_tokens'],
-            )
-        ),
-    )
+    # Reasoning models (o1 series) does not support `logprobs` by 2024/09/12.
+    if options.logprobs and self.model.startswith(('o1-', 'o3-')):
+      raise RuntimeError('`logprobs` is not supported on {self.model!r}.')
+    return super()._request_args(options)
 
 
 class GptO1(OpenAI):