dingo-python 1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,396 @@
1
+ import re
2
+ import threading
3
+ import warnings
4
+ from abc import abstractmethod
5
+ from copy import deepcopy
6
+ from time import sleep
7
+ from typing import Dict, List, Optional, Tuple, Union
8
+
9
+
10
+ from dingo.model.llm.common.base_llm import BaseLLMModel
11
+ from dingo.utils import log
12
+
13
+ PromptType = str
14
+
15
+
16
class BaseLLMAPIModel(BaseLLMModel):
    """Base class for API model wrapper.

    Args:
        path (str): The path to the model.
        query_per_second (int): The maximum queries allowed per second
            between two consecutive calls of the API. Defaults to 1.
        retry (int): Number of retries if the API call fails. Defaults to 2.
        max_seq_len (int): The maximum sequence length of the model. Defaults
            to 2048.
        meta_template (Dict, optional): The model's meta prompt
            template if needed, in case the requirement of injecting or
            wrapping of any meta instructions.
    """

    is_api: bool = True

    def __init__(self,
                 path: str,
                 query_per_second: int = 1,
                 retry: int = 2,
                 max_seq_len: int = 2048,
                 meta_template: Optional[Dict] = None):
        self.path = path
        self.max_seq_len = max_seq_len
        self.meta_template = meta_template
        self.retry = retry
        self.query_per_second = query_per_second
        # Rate limiter shared by every thread that issues requests through
        # this model instance.
        self.token_bucket = TokenBucket(query_per_second)
        self.template_parser = APITemplateParser(meta_template)

    @abstractmethod
    def generate(self, inputs: List[PromptType],
                 max_out_len: int) -> List[str]:
        """Generate results given a list of inputs.

        Args:
            inputs (List[str or PromptList]): A list of strings or PromptDicts.
                The PromptDict should be organized in OpenCompass'
                API format.
            max_out_len (int): The maximum length of the output.

        Returns:
            List[str]: A list of generated strings.
        """

    @abstractmethod
    def get_ppl(self,
                inputs: List[PromptType],
                mask_length: Optional[List[int]] = None) -> List[float]:
        """Get perplexity scores given a list of inputs.

        Args:
            inputs (List[str or PromptList]): A list of strings.
            mask_length (Optional[List[int]]): A list of mask lengths. If
                provided, the perplexity scores will be calculated with the
                first mask_length[i] tokens masked out. It's okay to skip
                its implementation if advanced features in PPLInferencer is
                not needed.

        Returns:
            List[float]: A list of perplexity scores.
        """

    def get_token_len(self, prompt: str) -> int:
        """Get lengths of the tokenized string. Only English and Chinese
        characters are counted for now. Users are encouraged to override this
        method if more accurate length is needed.

        Args:
            prompt (str): Input string.

        Returns:
            int: Length of the input tokens
        """
        # Each maximal run of ASCII letters/digits approximates one English
        # word (the runs can never contain whitespace, so counting matches is
        # equivalent to the former per-match split-and-sum).
        english_count = len(re.findall(r'[A-Za-z0-9]+', prompt))
        # Each CJK character in the BMP range counts as one token.
        chinese_count = sum(
            len(part) for part in re.findall(r'[\u4e00-\u9FFF]+', prompt))
        return english_count + chinese_count

    def wait(self):
        """Wait till the next query can be sent.

        Applicable in both single-thread and multi-thread environments.
        """
        return self.token_bucket.get_token()

    def to(self, device):
        # API models are not device-bound; accept the call for interface
        # compatibility with local models and do nothing.
        pass
112
+
113
+
114
class APITemplateParser:
    """Intermediate prompt template parser, specifically for API models.

    Args:
        meta_template (Dict): The meta template for the model.
    """

    def __init__(self, meta_template: Optional[Dict] = None):
        self.meta_template = meta_template
        # Always present so _update_role_dict can deepcopy it even when no
        # meta template was supplied.
        self.roles: Dict[str, dict] = dict()  # maps role name to config
        # Check meta template
        if meta_template:
            assert 'round' in meta_template, 'round is required in meta' \
                ' template'
            assert isinstance(meta_template['round'], list)
            keys_to_check = ['round']

            if 'reserved_roles' in meta_template:
                assert isinstance(meta_template['reserved_roles'], list)
                keys_to_check.append('reserved_roles')

            for meta_key in keys_to_check:
                for item in meta_template[meta_key]:
                    assert isinstance(item, (str, dict))
                    if isinstance(item, dict):
                        assert item['role'] not in self.roles, \
                            'role in meta prompt must be unique!'
                        self.roles[item['role']] = item.copy()

    def parse_template(self, prompt_template: PromptType,
                       mode: str) -> PromptType:
        """Parse the intermediate prompt template, and wrap it with meta
        template if applicable. When the meta template is set and the input is
        a PromptList, the return value will be a PromptList containing the full
        conversation history. Each item looks like:

        .. code-block:: python

            {'role': 'user', 'prompt': '...'}).

        Args:
            prompt_template (List[str or PromptList]): An intermediate prompt
                template (potentially before being wrapped by meta template).
            mode (str): Parsing mode. Choices are 'ppl' and 'gen'.

        Returns:
            List[str or PromptList]: The finalized prompt or a conversation.
        """
        assert isinstance(prompt_template, (str, list))

        if not isinstance(prompt_template, str):
            # A list of templates is parsed element-wise.
            return [self.parse_template(p, mode=mode) for p in prompt_template]

        assert mode in ['ppl', 'gen']
        if isinstance(prompt_template, str):
            return prompt_template
        # NOTE(review): every str has been returned above and every list
        # recursed on, so the code below is currently unreachable; it is kept
        # (and repaired) for when conversation-style inputs are routed here
        # again — TODO confirm the intended input type for this branch.
        if self.meta_template:

            # Accumulates API-style message dicts; must be a list so that
            # `+=`/`append` below work (was erroneously `str`).
            prompt = []
            # Whether to keep generating the prompt
            generate = True

            section_stack = []  # stores tuples: (section_name, start_idx)

            for i, item in enumerate(prompt_template):
                if not generate:
                    break
                if isinstance(item, str):
                    if item.strip():
                        # TODO: logger
                        warnings.warn('Non-empty string in prompt template '
                                      'will be ignored in API models.')
                elif isinstance(item, dict) and 'section' in item:
                    if item['pos'] == 'end':
                        section_name, start_idx = section_stack.pop(-1)
                        assert section_name == item['section']
                        if section_name in ['round', 'ice']:
                            dialogue = prompt_template[start_idx:i]
                            round_ranges = self._split_rounds(
                                dialogue, self.meta_template['round'])
                            # Consider inserting multiple round examples into
                            # template. Use a distinct index name so the outer
                            # enumerate variable `i` is not clobbered.
                            for ridx in range(len(round_ranges) - 1):
                                start = round_ranges[ridx]
                                end = round_ranges[ridx + 1]
                                round_template = dialogue[start:end]
                                role_dict = self._update_role_dict(
                                    round_template)
                                api_prompts, generate = self._prompt2api(
                                    self.meta_template['round'],
                                    role_dict,
                                    # Start generating only when the mode is
                                    # in generation and the template reaches
                                    # the last round
                                    for_gen=mode == 'gen'
                                    and section_name == 'round'
                                    and ridx == len(round_ranges) - 2)
                                prompt += api_prompts
                    elif item['pos'] == 'begin':
                        assert item['section'] in [
                            'begin', 'round', 'end', 'ice'
                        ]
                        section_stack.append((item['section'], i + 1))
                    else:
                        raise ValueError(f'Invalid pos {item["pos"]}')
                elif section_stack[-1][0] in ['begin', 'end']:
                    role_dict = self._update_role_dict(item)
                    api_prompts, generate = self._prompt2api(
                        item, role_dict, for_gen=mode == 'gen')
                    prompt.append(api_prompts)

            # merge the consecutive prompts assigned to the same role
            # (must be a list of dicts; was erroneously `str([...])`).
            new_prompt = [prompt[0]]
            last_role = prompt[0]['role']
            for item in prompt[1:]:
                if item['role'] == last_role:
                    new_prompt[-1]['prompt'] += '\n' + item['prompt']
                else:
                    last_role = item['role']
                    new_prompt.append(item)
            prompt = new_prompt

        else:
            # in case the model does not have any meta template
            prompt = ''
            last_sep = ''
            for item in prompt_template:
                # Skip pure section markers such as {'section': .., 'pos': ..}
                if isinstance(item, dict) and {'section', 'pos'} == set(
                        item.keys()):
                    continue
                if isinstance(item, str):
                    if item:
                        prompt += last_sep + item
                elif item.get('prompt', ''):
                    prompt += last_sep + item.get('prompt', '')
                last_sep = '\n'
        return prompt

    def _update_role_dict(self, prompts: Union[List, str]) -> Dict[str, Dict]:
        """Update the default role dict with the given prompts."""
        role_dict = deepcopy(self.roles)
        if isinstance(prompts, str):
            return role_dict
        elif isinstance(prompts, dict):
            prompts = [prompts]
        for prompt in prompts:
            if isinstance(prompt, dict):
                role = prompt['role']
                if role not in self.roles:
                    role = prompt.get('fallback_role', None)
                    if not role:
                        # NOTE(review): when no fallback exists this only logs
                        # and then the lookup below raises KeyError(None) —
                        # presumably it should raise a descriptive error;
                        # confirm intended behavior before changing.
                        log.info(f'{prompt} neither has an appropriate role nor a fallback role.')
                role_dict[role].update(prompt)
        return role_dict

    def _split_rounds(
            self, prompt_template: List[Union[str, Dict]],
            single_round_template: List[Union[str, Dict]]) -> List[int]:
        """Split the prompt template into rounds, based on single round
        template.

        Return the index ranges of each round. Specifically,
        prompt_template[res[i]:res[i+1]] represents the i-th round in the
        template.
        """
        role_idxs = {
            role_cfg['role']: i
            for i, role_cfg in enumerate(single_round_template)
            if not isinstance(role_cfg, str)
        }
        last_role_idx = -1
        cutoff_idxs = [0]
        for idx, template in enumerate(prompt_template):
            if isinstance(template, str):
                continue
            role_idx = role_idxs.get(template['role'], None)
            if role_idx is None:
                try:
                    role_idx = role_idxs[template['fallback_role']]
                except KeyError:
                    raise KeyError(f'{template} neither has an appropriate '
                                   'role nor a fallback role.')
            # A non-increasing role index means a new round has started.
            if role_idx <= last_role_idx:
                cutoff_idxs.append(idx)
            last_role_idx = role_idx
        cutoff_idxs.append(len(prompt_template))
        return cutoff_idxs

    def _prompt2api(self,
                    prompts: Union[List, str],
                    role_dict: Dict[str, Dict],
                    for_gen: bool = False) -> Tuple[str, bool]:
        """Convert the prompts to a API-style prompts, given an updated
        role_dict.

        Args:
            prompts (Union[List, str]): The prompts to be converted.
            role_dict (Dict[str, Dict]): The updated role dict.
            for_gen (bool): If True, the prompts will be converted for
                generation tasks. The conversion stops before the first
                role whose "generate" is set to True.

        Returns:
            Tuple[str, bool]: The converted string, and whether the follow-up
            conversion should be proceeded.
        """
        cont = True
        if isinstance(prompts, str):
            return prompts, cont
        elif isinstance(prompts, dict):
            api_role, cont = self._role2api_role(prompts, role_dict, for_gen)
            return api_role, cont

        res = []
        for prompt in prompts:
            if isinstance(prompt, str):
                raise TypeError('Mixing str without explicit role is not '
                                'allowed in API models!')
            else:
                api_role, cont = self._role2api_role(prompt, role_dict,
                                                     for_gen)
                if api_role:
                    res.append(api_role)
                if not cont:
                    break
        return res, cont

    def _role2api_role(self,
                       role_prompt: Dict,
                       role_dict: Dict[str, Dict],
                       for_gen: bool = False) -> Tuple[str, bool]:
        """Convert a role prompt to a string, given an updated role_dict.

        Args:
            role_prompt (Dict): The role prompt to be converted.
            role_dict (Dict[str, Dict]): The updated role dict.
            for_gen (bool): If True, the prompts will be converted for
                generation tasks. The conversion stops before the first
                role whose "generate" is set to True.

        Returns:
            Tuple[str, bool]: The converted string, and whether the follow-up
            conversion should be proceeded.
        """
        merged_prompt = role_dict.get(
            role_prompt['role'],
            role_dict.get(role_prompt.get('fallback_role')))
        # Stop before the role that is supposed to be generated by the model.
        if for_gen and merged_prompt.get('generate', False):
            return None, False
        res = {}
        res['role'] = merged_prompt['api_role']
        res['prompt'] = merged_prompt.get('begin', '')
        res['prompt'] += merged_prompt.get('prompt', '')
        res['prompt'] += merged_prompt.get('end', '')
        return res, True
370
+
371
+
372
class TokenBucket:
    """A token bucket for rate limiting.

    Args:
        query_per_second (float): The rate of the token bucket.
    """

    def __init__(self, rate):
        self._rate = rate
        # Starts empty; the background filler thread releases tokens.
        self._tokens = threading.Semaphore(0)
        self.started = False

    def _add_tokens(self):
        """Continuously replenish tokens, capped at the configured rate."""
        interval = 1 / self._rate
        while True:
            # Only top up while the bucket holds fewer than `rate` tokens.
            if self._tokens._value < self._rate:
                self._tokens.release()
            sleep(interval)

    def get_token(self):
        """Block until a token is available, starting the filler lazily."""
        if not self.started:
            self.started = True
            filler = threading.Thread(target=self._add_tokens, daemon=True)
            filler.start()
        self._tokens.acquire()
@@ -0,0 +1,222 @@
1
+ import json
2
+ import os
3
+ import time
4
+ import requests
5
+ from concurrent.futures import ThreadPoolExecutor
6
+ from threading import Lock
7
+ from typing import Dict, List, Optional, Union
8
+
9
+ from dingo.model.llm.common.base_llm_api import BaseLLMAPIModel, PromptType
10
+ from dingo.utils import log
11
+
12
+ OPENAI_API_BASE = 'https://api.openai.com/v1/chat/completions'
13
+
14
+
15
class OpenAI(BaseLLMAPIModel):
    """Model wrapper around OpenAI's models.

    Args:
        path (str): The name of OpenAI's model.
        max_seq_len (int): The maximum allowed sequence length of a model.
            Note that the length of prompt + generated tokens shall not exceed
            this value. Defaults to 2048.
        query_per_second (int): The maximum queries allowed per second
            between two consecutive calls of the API. Defaults to 1.
        retry (int): Number of retries if the API call fails. Defaults to 2.
        key (str or List[str]): OpenAI key(s). In particular, when it
            is set to "ENV", the key will be fetched from the environment
            variable $OPENAI_API_KEY, as how openai defaults to be. If it's a
            list, the keys will be used in round-robin manner. Defaults to
            'ENV'.
        org (str or List[str], optional): OpenAI organization(s). If not
            specified, OpenAI uses the default organization bound to each API
            key. If specified, the orgs will be posted with each request in
            round-robin manner. Defaults to None.
        meta_template (Dict, optional): The model's meta prompt
            template if needed, in case the requirement of injecting or
            wrapping of any meta instructions.
        openai_api_base (str): The base url of OpenAI's API. Defaults to
            'https://api.openai.com/v1/chat/completions'.
        temperature (float, optional): What sampling temperature to use.
            If not None, will override the temperature in the `generate()`
            call. Defaults to None.
    """

    is_api: bool = True

    def __init__(self,
                 path: str,
                 max_seq_len: int = 2048,
                 query_per_second: int = 1,
                 retry: int = 2,
                 key: Union[str, List[str]] = "ENV",
                 org: Optional[Union[str, List[str]]] = None,
                 meta_template: Optional[Dict] = None,
                 openai_api_base: str = OPENAI_API_BASE,
                 temperature: Optional[float] = None):

        super().__init__(path=path,
                         max_seq_len=max_seq_len,
                         meta_template=meta_template,
                         query_per_second=query_per_second,
                         retry=retry)
        try:
            import tiktoken
        except ImportError:
            raise ImportError('tiktoken is not installed, please install tiktoken.')
        self.tiktoken = tiktoken
        self.temperature = temperature

        if isinstance(key, str):
            self.keys = [os.getenv('OPENAI_API_KEY') if key == 'ENV' else key]
        else:
            self.keys = key
        self.key_ctr = 0
        if isinstance(org, str):
            self.orgs = [org]
        else:
            self.orgs = org
        self.org_ctr = 0
        self.url = openai_api_base
        # One shared lock protecting the round-robin counters; constructing a
        # fresh Lock() per request (as before) synchronized nothing.
        self._ctr_lock = Lock()

    def generate(
        self,
        inputs: List[str],
        max_out_len: int = 512,
        temperature: float = 0.7,
    ) -> List[str]:
        """Generate results given a list of inputs.

        Args:
            inputs (List[str or PromptList]): A list of strings or PromptDicts.
                The PromptDict should be organized in OpenCompass'
                API format.
            max_out_len (int): The maximum length of the output.
            temperature (float): What sampling temperature to use,
                between 0 and 2. Higher values like 0.8 will make the output
                more random, while lower values like 0.2 will make it more
                focused and deterministic. Defaults to 0.7.

        Returns:
            List[str]: A list of generated strings.
        """
        # An instance-level temperature, if set, overrides the argument.
        if self.temperature is not None:
            temperature = self.temperature

        # Fan out one API call per input; order of results matches inputs.
        with ThreadPoolExecutor() as executor:
            results = list(
                executor.map(self._generate, inputs,
                             [max_out_len] * len(inputs),
                             [temperature] * len(inputs)))
        return results

    def _generate(self, input: str, max_out_len: int,
                  temperature: float) -> str:
        """Generate a result for a single input.

        Args:
            input (str or PromptList): A string or PromptDict.
                The PromptDict should be organized in OpenCompass'
                API format.
            max_out_len (int): The maximum length of the output.
            temperature (float): What sampling temperature to use,
                between 0 and 2. Higher values like 0.8 will make the output
                more random, while lower values like 0.2 will make it more
                focused and deterministic.

        Returns:
            str: The generated string.

        Raises:
            RuntimeError: If the API call still fails after `self.retry`
                attempts.
        """
        # Accept both plain strings and conversation lists; the previous
        # `isinstance(input, (str))` assert made the list branch unreachable.
        assert isinstance(input, (str, list))

        if isinstance(input, str):
            messages = [{'role': 'user', 'content': input}]
        else:
            messages = []
            for item in input:
                msg = {'content': item['prompt']}
                if item['role'] == 'HUMAN':
                    msg['role'] = 'user'
                elif item['role'] == 'BOT':
                    msg['role'] = 'assistant'
                elif item['role'] == 'SYSTEM':
                    msg['role'] = 'system'
                messages.append(msg)

        # max num token for gpt-3.5-turbo is 4097
        max_out_len = min(max_out_len, 4000 - self.get_token_len(str(input)))
        if max_out_len <= 0:
            return ''

        max_num_retries = 0
        while max_num_retries < self.retry:
            # Count every attempt up front so the loop always terminates; the
            # previous code `continue`d past the counter on connection,
            # JSON-decode and rate-limit errors and could spin forever.
            max_num_retries += 1
            self.wait()
            if hasattr(self, 'keys'):
                with self._ctr_lock:
                    self.key_ctr += 1
                    if self.key_ctr == len(self.keys):
                        self.key_ctr = 0
                header = {
                    'Authorization': f'Bearer {self.keys[self.key_ctr]}',
                    'content-type': 'application/json',
                }
            if self.orgs:
                with self._ctr_lock:
                    self.org_ctr += 1
                    if self.org_ctr == len(self.orgs):
                        self.org_ctr = 0
                header['OpenAI-Organization'] = self.orgs[self.org_ctr]

            try:
                data = dict(
                    model=self.path,
                    messages=messages,
                    # response_format={"type": "json_object"},
                    max_tokens=max_out_len,
                    n=1,
                    stop=None,
                    temperature=temperature,
                )
                raw_response = requests.post(self.url,
                                             headers=header,
                                             data=json.dumps(data))
            except requests.ConnectionError:
                log.error('Got connection error, retrying...')
                continue
            try:
                response = raw_response.json()
            except requests.JSONDecodeError:
                log.error('JsonDecode error, got ' + str(raw_response.content))
                continue
            try:
                return response['choices'][0]['message']['content'].strip()
            except KeyError:
                if 'error' in response:
                    if response['error']['code'] == 'rate_limit_exceeded':
                        time.sleep(1)
                        continue
                    log.error('Find error message in response: ' + str(response['error']))

        raise RuntimeError('Calling OpenAI failed after retrying for '
                           f'{max_num_retries} times. Check the logs for '
                           'details.')

    def get_ppl(self,
                inputs: List[PromptType],
                mask_length: Optional[List[int]] = None) -> List[float]:
        # Chat-completions endpoints do not expose token log-probabilities
        # needed for perplexity.
        raise NotImplementedError('get_ppl is not implemented for `OpenAI`')

    def get_token_len(self, prompt: str) -> int:
        """Get lengths of the tokenized string. Uses the tiktoken encoding
        that matches ``self.path``, overriding the base class's heuristic
        character counting.

        Args:
            prompt (str): Input string.

        Returns:
            int: Length of the input tokens
        """
        enc = self.tiktoken.encoding_for_model(self.path)
        return len(enc.encode(prompt))