dingo-python 1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dingo/__init__.py +0 -0
- dingo/config/__init__.py +1 -0
- dingo/config/config.py +47 -0
- dingo/convert/__init__.py +4 -0
- dingo/convert/base.py +147 -0
- dingo/exec/__init__.py +3 -0
- dingo/exec/base.py +54 -0
- dingo/exec/local.py +288 -0
- dingo/exec/spark.py +169 -0
- dingo/io/__init__.py +2 -0
- dingo/io/export.py +0 -0
- dingo/io/input.py +27 -0
- dingo/io/summary.py +28 -0
- dingo/model/__init__.py +3 -0
- dingo/model/llm/__init__.py +0 -0
- dingo/model/llm/base.py +12 -0
- dingo/model/llm/common/__init__.py +0 -0
- dingo/model/llm/common/base_llm.py +395 -0
- dingo/model/llm/common/base_llm_api.py +396 -0
- dingo/model/llm/common/openai_api.py +222 -0
- dingo/model/llm/common/turbomind_api.py +148 -0
- dingo/model/llm/gpt.py +62 -0
- dingo/model/llm/llama3.py +97 -0
- dingo/model/llm/perspective.py +68 -0
- dingo/model/model.py +227 -0
- dingo/model/rule/__init__.py +0 -0
- dingo/model/rule/base.py +14 -0
- dingo/model/rule/common_rule.py +551 -0
- dingo/model/rule/image_rule.py +81 -0
- dingo/model/rule/prompt_rule.py +39 -0
- dingo/model/rule/util.py +282 -0
- dingo/utils/__init__.py +1 -0
- dingo/utils/log_util/__init__.py +32 -0
- dingo/utils/log_util/logger.py +39 -0
- dingo_python-1.0.dist-info/LICENSE +201 -0
- dingo_python-1.0.dist-info/METADATA +221 -0
- dingo_python-1.0.dist-info/RECORD +39 -0
- dingo_python-1.0.dist-info/WHEEL +5 -0
- dingo_python-1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,396 @@
|
|
|
1
|
+
import re
|
|
2
|
+
import threading
|
|
3
|
+
import warnings
|
|
4
|
+
from abc import abstractmethod
|
|
5
|
+
from copy import deepcopy
|
|
6
|
+
from time import sleep
|
|
7
|
+
from typing import Dict, List, Optional, Tuple, Union
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
from dingo.model.llm.common.base_llm import BaseLLMModel
|
|
11
|
+
from dingo.utils import log
|
|
12
|
+
|
|
13
|
+
PromptType = str
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class BaseLLMAPIModel(BaseLLMModel):
    """Base class wrapping a remote, API-served language model.

    Args:
        path (str): The path/name of the model.
        query_per_second (int): The maximum queries allowed per second
            between two consecutive calls of the API. Defaults to 1.
        retry (int): Number of retries if the API call fails. Defaults to 2.
        max_seq_len (int): The maximum sequence length of the model.
            Defaults to 2048.
        meta_template (Dict, optional): The model's meta prompt template,
            in case meta instructions need to be injected or wrapped.
    """

    # Marks this wrapper as API-backed (as opposed to a locally hosted model).
    is_api: bool = True

    def __init__(self,
                 path: str,
                 query_per_second: int = 1,
                 retry: int = 2,
                 max_seq_len: int = 2048,
                 meta_template: Optional[Dict] = None):
        self.path = path
        self.max_seq_len = max_seq_len
        self.meta_template = meta_template
        self.retry = retry
        self.query_per_second = query_per_second
        # Rate limiter shared by every request issued through this wrapper.
        self.token_bucket = TokenBucket(query_per_second)
        # Converts intermediate prompt templates into API-style messages.
        self.template_parser = APITemplateParser(meta_template)

    @abstractmethod
    def generate(self, inputs: 'List[PromptType]',
                 max_out_len: int) -> List[str]:
        """Generate results given a list of inputs.

        Args:
            inputs (List[str or PromptList]): A list of strings or
                PromptDicts, organized in OpenCompass' API format.
            max_out_len (int): The maximum length of the output.

        Returns:
            List[str]: A list of generated strings.
        """

    @abstractmethod
    def get_ppl(self,
                inputs: 'List[PromptType]',
                mask_length: Optional[List[int]] = None) -> List[float]:
        """Get perplexity scores given a list of inputs.

        Args:
            inputs (List[str or PromptList]): A list of strings.
            mask_length (Optional[List[int]]): A list of mask lengths. If
                provided, the perplexity scores are calculated with the first
                mask_length[i] tokens masked out. Implementations may skip
                this if advanced PPLInferencer features are not needed.

        Returns:
            List[float]: A list of perplexity scores.
        """

    def get_token_len(self, prompt: str) -> int:
        """Roughly count tokens in ``prompt``.

        Only English and Chinese characters are counted for now; override
        this method if a more accurate length is needed.

        Args:
            prompt (str): Input string.

        Returns:
            int: Approximate number of tokens in the input.
        """
        # Runs of ASCII letters/digits each count as one word; every CJK
        # unified ideograph counts as one token.
        alnum_runs = re.findall(r'[A-Za-z0-9]+', prompt)
        han_runs = re.findall(r'[\u4e00-\u9FFF]+', prompt)

        word_total = sum(len(run.split()) for run in alnum_runs)
        han_total = sum(map(len, han_runs))

        return word_total + han_total

    def wait(self):
        """Block until the rate limiter permits the next query.

        Applicable in both single-thread and multi-thread environments.
        """
        return self.token_bucket.get_token()

    def to(self, device):
        # API models have no device placement; kept for interface parity
        # with local models.
        pass
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
class APITemplateParser:
    """Intermediate prompt template parser, specifically for API models.

    Args:
        meta_template (Dict): The meta template for the model. Must contain
            a ``round`` list; may contain a ``reserved_roles`` list.
    """

    def __init__(self, meta_template: Optional[Dict] = None):
        self.meta_template = meta_template
        # Fixed: always initialize the role map so methods that deepcopy
        # ``self.roles`` do not raise AttributeError when no meta template
        # was given.
        self.roles: Dict[str, dict] = dict()  # maps role name to config
        # Check meta template
        if meta_template:
            assert 'round' in meta_template, 'round is required in meta' \
                ' template'
            assert isinstance(meta_template['round'], list)
            keys_to_check = ['round']

            if 'reserved_roles' in meta_template:
                assert isinstance(meta_template['reserved_roles'], list)
                keys_to_check.append('reserved_roles')

            for meta_key in keys_to_check:
                for item in meta_template[meta_key]:
                    assert isinstance(item, (str, dict))
                    if isinstance(item, dict):
                        assert item['role'] not in self.roles, \
                            'role in meta prompt must be unique!'
                        self.roles[item['role']] = item.copy()

    def parse_template(self, prompt_template: 'PromptType',
                       mode: str) -> 'PromptType':
        """Parse the intermediate prompt template, and wrap it with meta
        template if applicable. When the meta template is set and the input is
        a structured prompt, the return value is a list containing the full
        conversation history, each item shaped like::

            {'role': 'user', 'prompt': '...'}

        Args:
            prompt_template (List[str or PromptList]): An intermediate prompt
                template (potentially before being wrapped by meta template).
            mode (str): Parsing mode. Choices are 'ppl' and 'gen'.

        Returns:
            List[str or PromptList]: The finalized prompt or a conversation.
        """
        assert isinstance(prompt_template, (str, list))

        # NOTE(review): every list input is handled here by element-wise
        # recursion, and every str input returns below — so the
        # meta-template branch is currently unreachable. It looks like a
        # PromptList type was removed from this check upstream; the branch
        # is kept (and repaired) for when structured prompts return.
        if not isinstance(prompt_template, (str)):
            return [self.parse_template(p, mode=mode) for p in prompt_template]

        assert mode in ['ppl', 'gen']
        if isinstance(prompt_template, str):
            return prompt_template
        if self.meta_template:

            # Fixed: was ``prompt = str`` (binding the *type object*), which
            # breaks the ``+=`` and ``append`` calls below. The accumulator
            # must be a list of role dicts.
            prompt = []
            # Whether to keep generating the prompt
            generate = True

            section_stack = []  # stores tuples: (section_name, start_idx)

            for i, item in enumerate(prompt_template):
                if not generate:
                    break
                if isinstance(item, str):
                    if item.strip():
                        # TODO: logger
                        warnings.warn('Non-empty string in prompt template '
                                      'will be ignored in API models.')
                elif isinstance(item, dict) and 'section' in item:
                    if item['pos'] == 'end':
                        section_name, start_idx = section_stack.pop(-1)
                        assert section_name == item['section']
                        if section_name in ['round', 'ice']:
                            dialogue = prompt_template[start_idx:i]
                            round_ranges = self._split_rounds(
                                dialogue, self.meta_template['round'])
                            # Consider inserting multiple round examples into
                            # template
                            for r in range(len(round_ranges) - 1):
                                start = round_ranges[r]
                                end = round_ranges[r + 1]
                                round_template = dialogue[start:end]
                                role_dict = self._update_role_dict(
                                    round_template)
                                api_prompts, generate = self._prompt2api(
                                    self.meta_template['round'],
                                    role_dict,
                                    # Start generating only when the mode is in
                                    # generation and the template reaches the
                                    # last round
                                    for_gen=mode == 'gen'
                                    and section_name == 'round'
                                    and r == len(round_ranges) - 2)
                                prompt += api_prompts
                    elif item['pos'] == 'begin':
                        assert item['section'] in [
                            'begin', 'round', 'end', 'ice'
                        ]
                        section_stack.append((item['section'], i + 1))
                    else:
                        raise ValueError(f'Invalid pos {item["pos"]}')
                elif section_stack[-1][0] in ['begin', 'end']:
                    role_dict = self._update_role_dict(item)
                    api_prompts, generate = self._prompt2api(
                        item, role_dict, for_gen=mode == 'gen')
                    prompt.append(api_prompts)

            # merge the consecutive prompts assigned to the same role
            # Fixed: was ``str([prompt[0]])`` which stringified the list
            # instead of seeding a new one.
            new_prompt = [prompt[0]]
            last_role = prompt[0]['role']
            for item in prompt[1:]:
                if item['role'] == last_role:
                    new_prompt[-1]['prompt'] += '\n' + item['prompt']
                else:
                    last_role = item['role']
                    new_prompt.append(item)
            prompt = new_prompt

        else:
            # in case the model does not have any meta template
            prompt = ''
            last_sep = ''
            for item in prompt_template:
                if isinstance(item, dict) and {'section', 'pos'} == set(
                        item.keys()):
                    continue
                if isinstance(item, str):
                    if item:
                        prompt += last_sep + item
                elif item.get('prompt', ''):
                    prompt += last_sep + item.get('prompt', '')
                last_sep = '\n'
        return prompt

    def _update_role_dict(self, prompts: Union[List, str]) -> Dict[str, Dict]:
        """Return a copy of the default role dict updated with the given
        prompts."""
        role_dict = deepcopy(self.roles)
        if isinstance(prompts, str):
            return role_dict
        elif isinstance(prompts, dict):
            prompts = [prompts]
        for prompt in prompts:
            if isinstance(prompt, dict):
                role = prompt['role']
                if role not in self.roles:
                    role = prompt.get('fallback_role', None)
                    if not role:
                        log.info(f'{prompt} neither has an appropriate role nor a fallback role.')
                        # Fixed: skip unknown roles instead of falling
                        # through to ``role_dict[None]`` and raising KeyError.
                        continue
                role_dict[role].update(prompt)
        return role_dict

    def _split_rounds(
            self, prompt_template: List[Union[str, Dict]],
            single_round_template: List[Union[str, Dict]]) -> List[int]:
        """Split the prompt template into rounds, based on single round
        template.

        Return the index ranges of each round. Specifically,
        prompt_template[res[i]:res[i+1]] represents the i-th round in the
        template.
        """
        role_idxs = {
            role_cfg['role']: i
            for i, role_cfg in enumerate(single_round_template)
            if not isinstance(role_cfg, str)
        }
        last_role_idx = -1
        cutoff_idxs = [0]
        for idx, template in enumerate(prompt_template):
            if isinstance(template, str):
                continue
            role_idx = role_idxs.get(template['role'], None)
            if role_idx is None:
                try:
                    role_idx = role_idxs[template['fallback_role']]
                except KeyError:
                    raise KeyError(f'{template} neither has an appropriate '
                                   'role nor a fallback role.')
            # A role index that does not advance means a new round started.
            if role_idx <= last_role_idx:
                cutoff_idxs.append(idx)
            last_role_idx = role_idx
        cutoff_idxs.append(len(prompt_template))
        return cutoff_idxs

    def _prompt2api(self,
                    prompts: Union[List, str],
                    role_dict: Dict[str, Dict],
                    for_gen: bool = False) -> Tuple[str, bool]:
        """Convert the prompts to API-style prompts, given an updated
        role_dict.

        Args:
            prompts (Union[List, str]): The prompts to be converted.
            role_dict (Dict[str, Dict]): The updated role dict.
            for_gen (bool): If True, the prompts will be converted for
                generation tasks. The conversion stops before the first
                role whose "generate" is set to True.

        Returns:
            Tuple[str, bool]: The converted result, and whether the follow-up
            conversion should proceed.
        """
        cont = True
        if isinstance(prompts, str):
            return prompts, cont
        elif isinstance(prompts, dict):
            api_role, cont = self._role2api_role(prompts, role_dict, for_gen)
            return api_role, cont

        res = []
        for prompt in prompts:
            if isinstance(prompt, str):
                # Fixed typo in the error message ('explictt').
                raise TypeError('Mixing str without explicit role is not '
                                'allowed in API models!')
            else:
                api_role, cont = self._role2api_role(prompt, role_dict,
                                                     for_gen)
                if api_role:
                    res.append(api_role)
                if not cont:
                    break
        return res, cont

    def _role2api_role(self,
                       role_prompt: Dict,
                       role_dict: Dict[str, Dict],
                       for_gen: bool = False) -> Tuple[str, bool]:
        """Convert a role prompt to an API message dict, given an updated
        role_dict.

        Args:
            role_prompt (Dict): The role prompt to be converted.
            role_dict (Dict[str, Dict]): The updated role dict.
            for_gen (bool): If True, the prompts will be converted for
                generation tasks. The conversion stops before the first
                role whose "generate" is set to True.

        Returns:
            Tuple[str, bool]: The converted message (or None), and whether the
            follow-up conversion should proceed.
        """
        merged_prompt = role_dict.get(
            role_prompt['role'],
            role_dict.get(role_prompt.get('fallback_role')))
        # Generation stops right before the role that is marked to generate.
        if for_gen and merged_prompt.get('generate', False):
            return None, False
        res = {}
        res['role'] = merged_prompt['api_role']
        res['prompt'] = merged_prompt.get('begin', '')
        res['prompt'] += merged_prompt.get('prompt', '')
        res['prompt'] += merged_prompt.get('end', '')
        return res, True
|
|
370
|
+
|
|
371
|
+
|
|
372
|
+
class TokenBucket:
    """A token bucket for rate limiting.

    Tokens are replenished by a background daemon thread at ``rate`` per
    second and capped at ``rate``; :meth:`get_token` blocks until one is
    available.

    Args:
        rate (float): The rate of the token bucket (tokens per second).
    """

    def __init__(self, rate):
        self._rate = rate
        self._tokens = threading.Semaphore(0)
        # Guards the lazy start of the refill thread.
        self._start_lock = threading.Lock()
        self.started = False

    def _add_tokens(self):
        """Add tokens to the bucket, forever, at the configured rate."""
        while True:
            # Cap the bucket so long idle periods cannot accumulate an
            # unbounded burst allowance.
            # NOTE(review): reads the semaphore's private counter; works on
            # CPython but is not a public API.
            if self._tokens._value < self._rate:
                self._tokens.release()
            sleep(1 / self._rate)

    def get_token(self):
        """Get a token from the bucket, blocking until one is available."""
        # Double-checked locking: the original unguarded check let two
        # threads both observe ``started == False`` and spawn two refill
        # threads, doubling the effective rate. This class is used from a
        # ThreadPoolExecutor, so the race is real.
        if not self.started:
            with self._start_lock:
                if not self.started:
                    self.started = True
                    threading.Thread(target=self._add_tokens,
                                     daemon=True).start()
        self._tokens.acquire()
|
|
@@ -0,0 +1,222 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import os
|
|
3
|
+
import time
|
|
4
|
+
import requests
|
|
5
|
+
from concurrent.futures import ThreadPoolExecutor
|
|
6
|
+
from threading import Lock
|
|
7
|
+
from typing import Dict, List, Optional, Union
|
|
8
|
+
|
|
9
|
+
from dingo.model.llm.common.base_llm_api import BaseLLMAPIModel, PromptType
|
|
10
|
+
from dingo.utils import log
|
|
11
|
+
|
|
12
|
+
OPENAI_API_BASE = 'https://api.openai.com/v1/chat/completions'
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class OpenAI(BaseLLMAPIModel):
    """Model wrapper around OpenAI's models.

    Args:
        path (str): The name of OpenAI's model.
        max_seq_len (int): The maximum allowed sequence length of a model.
            Note that the length of prompt + generated tokens shall not exceed
            this value. Defaults to 2048.
        query_per_second (int): The maximum queries allowed per second
            between two consecutive calls of the API. Defaults to 1.
        retry (int): Number of retries if the API call fails. Defaults to 2.
        key (str or List[str]): OpenAI key(s). In particular, when it
            is set to "ENV", the key will be fetched from the environment
            variable $OPENAI_API_KEY, as how openai defaults to be. If it's a
            list, the keys will be used in round-robin manner. Defaults to
            'ENV'.
        org (str or List[str], optional): OpenAI organization(s). If not
            specified, OpenAI uses the default organization bound to each API
            key. If specified, the orgs will be posted with each request in
            round-robin manner. Defaults to None.
        meta_template (Dict, optional): The model's meta prompt
            template if needed, in case the requirement of injecting or
            wrapping of any meta instructions.
        openai_api_base (str): The base url of OpenAI's API. Defaults to
            'https://api.openai.com/v1/chat/completions'.
        temperature (float, optional): What sampling temperature to use.
            If not None, will override the temperature in the `generate()`
            call. Defaults to None.
    """

    is_api: bool = True

    # One lock shared by all instances, guarding the round-robin key/org
    # counters against the worker threads spawned in ``generate``. (The
    # original code created a fresh ``Lock()`` per request, which
    # synchronized nothing.)
    _rr_lock = Lock()

    def __init__(self,
                 path: str,
                 max_seq_len: int = 2048,
                 query_per_second: int = 1,
                 retry: int = 2,
                 key: Union[str, List[str]] = "ENV",
                 org: Optional[Union[str, List[str]]] = None,
                 meta_template: Optional[Dict] = None,
                 openai_api_base: str = OPENAI_API_BASE,
                 temperature: Optional[float] = None):

        super().__init__(path=path,
                         max_seq_len=max_seq_len,
                         meta_template=meta_template,
                         query_per_second=query_per_second,
                         retry=retry)
        try:
            import tiktoken
        except ImportError:
            raise ImportError('tiktoken is not installed, please install tiktoken.')
        self.tiktoken = tiktoken
        self.temperature = temperature

        if isinstance(key, str):
            # NOTE(review): if key == 'ENV' and $OPENAI_API_KEY is unset,
            # this stores [None] and requests go out as "Bearer None".
            self.keys = [os.getenv('OPENAI_API_KEY') if key == 'ENV' else key]
        else:
            self.keys = key
        self.key_ctr = 0
        if isinstance(org, str):
            self.orgs = [org]
        else:
            self.orgs = org
        self.org_ctr = 0
        self.url = openai_api_base

    def generate(
        self,
        inputs: List[str],
        max_out_len: int = 512,
        temperature: float = 0.7,
    ) -> List[str]:
        """Generate results given a list of inputs.

        Args:
            inputs (List[str or PromptList]): A list of strings or PromptDicts.
                The PromptDict should be organized in OpenCompass'
                API format.
            max_out_len (int): The maximum length of the output.
            temperature (float): What sampling temperature to use,
                between 0 and 2. Higher values like 0.8 will make the output
                more random, while lower values like 0.2 will make it more
                focused and deterministic. Defaults to 0.7. Ignored when the
                instance was constructed with a fixed temperature.

        Returns:
            List[str]: A list of generated strings.
        """
        if self.temperature is not None:
            temperature = self.temperature

        # Fan each prompt out to its own worker; rate limiting is enforced
        # per-request inside ``_generate`` via ``self.wait()``.
        with ThreadPoolExecutor() as executor:
            results = list(
                executor.map(self._generate, inputs,
                             [max_out_len] * len(inputs),
                             [temperature] * len(inputs)))
        return results

    def _generate(self, input: str, max_out_len: int,
                  temperature: float) -> str:
        """Generate a result for a single input.

        Args:
            input (str or PromptList): A string or PromptDict.
                The PromptDict should be organized in OpenCompass'
                API format.
            max_out_len (int): The maximum length of the output.
            temperature (float): What sampling temperature to use,
                between 0 and 2. Higher values like 0.8 will make the output
                more random, while lower values like 0.2 will make it more
                focused and deterministic.

        Returns:
            str: The generated string ('' if no output budget remains).

        Raises:
            RuntimeError: If the call still fails after ``self.retry``
                attempts.
        """
        assert isinstance(input, (str))

        if isinstance(input, str):
            messages = [{'role': 'user', 'content': input}]
        else:
            # Map OpenCompass roles onto the OpenAI chat roles.
            messages = []
            for item in input:
                msg = {'content': item['prompt']}
                if item['role'] == 'HUMAN':
                    msg['role'] = 'user'
                elif item['role'] == 'BOT':
                    msg['role'] = 'assistant'
                elif item['role'] == 'SYSTEM':
                    msg['role'] = 'system'
                messages.append(msg)

        # max num token for gpt-3.5-turbo is 4097
        max_out_len = min(max_out_len, 4000 - self.get_token_len(str(input)))
        if max_out_len <= 0:
            return ''

        max_num_retries = 0
        while max_num_retries < self.retry:
            # Fixed: count every attempt up front. Previously connection
            # errors, JSON-decode errors and rate-limit responses hit
            # ``continue`` before the counter was incremented, so a
            # persistent failure retried forever.
            max_num_retries += 1
            self.wait()
            if hasattr(self, 'keys'):
                with OpenAI._rr_lock:
                    self.key_ctr += 1
                    if self.key_ctr == len(self.keys):
                        self.key_ctr = 0
                    key = self.keys[self.key_ctr]
            header = {
                'Authorization': f'Bearer {key}',
                'content-type': 'application/json',
            }
            if self.orgs:
                with OpenAI._rr_lock:
                    self.org_ctr += 1
                    if self.org_ctr == len(self.orgs):
                        self.org_ctr = 0
                    header['OpenAI-Organization'] = self.orgs[self.org_ctr]

            try:
                data = dict(
                    model=self.path,
                    messages=messages,
                    max_tokens=max_out_len,
                    n=1,
                    stop=None,
                    temperature=temperature,
                )
                raw_response = requests.post(self.url,
                                             headers=header,
                                             data=json.dumps(data))
            except requests.ConnectionError:
                log.error('Got connection error, retrying...')
                continue
            try:
                response = raw_response.json()
            except requests.JSONDecodeError:
                log.error('JsonDecode error, got ' + str(raw_response.content))
                continue
            try:
                return response['choices'][0]['message']['content'].strip()
            except KeyError:
                if 'error' in response:
                    if response['error']['code'] == 'rate_limit_exceeded':
                        # Back off briefly; this attempt still counts toward
                        # the retry budget.
                        time.sleep(1)
                        continue
                    log.error('Find error message in response: ' + str(response['error']))

        raise RuntimeError('Calling OpenAI failed after retrying for '
                           f'{max_num_retries} times. Check the logs for '
                           'details.')

    def get_ppl(self,
                inputs: List[PromptType],
                mask_length: Optional[List[int]] = None) -> List[float]:
        # Perplexity requires token-level logprobs, which the chat
        # completions endpoint does not expose.
        raise NotImplementedError('get_ppl is not implemented for `OpenAI`')

    def get_token_len(self, prompt: str) -> int:
        """Get the exact length of the tokenized string, using the tiktoken
        encoding registered for this model name. (Overrides the base class's
        character-count heuristic.)

        Args:
            prompt (str): Input string.

        Returns:
            int: Length of the input tokens
        """
        enc = self.tiktoken.encoding_for_model(self.path)
        return len(enc.encode(prompt))
|