dingo-python 1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,148 @@
1
+ import threading
2
+ import os
3
+
4
+ from concurrent.futures import ThreadPoolExecutor
5
+ from typing import Dict, List, Optional, Union
6
+
7
+ from dingo.model.llm.common.base_llm_api import BaseLLMModel
8
+ from dingo.model.llm.common.base_llm import LMTemplateParser
9
+
10
+ # from opencompass.utils.logging import get_logger
11
+ # from opencompass.utils.prompt import PromptList
12
+
13
+
14
# typing.Union of a single member collapses to that member, so this alias
# is plain ``str``; keep the name for readability at call sites.
PromptType = str
15
+
16
+
17
def valid_str(string, coding='utf-8'):
    """Decode text according to its encoding type.

    Encodes *string* with *coding*, strips the byte sequence of the Unicode
    replacement character (U+FFFD), then decodes back, ignoring any bytes
    that cannot be decoded.
    """
    replacement_markers = (b'\xef\xbf\xbd',)
    encoded = string.encode(coding)
    for marker in replacement_markers:
        encoded = encoded.replace(marker, b'')
    return encoded.decode(encoding=coding, errors='ignore')
25
+
26
+
27
class TurboMindAPIModel(BaseLLMModel):
    """Model wrapper for lmdeploy api server.

    Args:
        api_addr (str): The address (ip:port format) of lmdeploy's
            api server.
        max_seq_len (int): The maximum allowed sequence length of a model.
            Note that the length of prompt + generated tokens shall not exceed
            this value. Defaults to 2048.
        meta_template (Dict, optional): The model's meta prompt
            template if needed, in case the requirement of injecting or
            wrapping of any meta instructions.
        end_str (str, optional): Whether to trim generated strings with end_str
            if the model has special ending strings that are not handled well.
            Defaults to None.
    """

    is_api: bool = True

    def __init__(self,
                 api_addr: str = 'http://0.0.0.0:23333',
                 max_seq_len: int = 2048,
                 meta_template: Optional[Dict] = None,
                 end_str: Optional[str] = None,
                 **kwargs):
        super().__init__(path='',
                         max_seq_len=max_seq_len,
                         meta_template=meta_template)
        try:
            from lmdeploy.serve.openai.api_client import APIClient
        except ImportError as e:
            # Chain the original exception so the real import failure stays visible.
            raise ImportError('lmdeploy is not installed, please install lmdeploy first.') from e
        self.chatbot = APIClient(api_addr)
        # Use whichever model the server exposes first.
        self.model_name = self.chatbot.available_models[0]
        self.template_parser = LMTemplateParser(meta_template)
        self.eos_token_id = None
        # NOTE(review): token_bucket is never assigned in this class, so wait()
        # raises AttributeError unless a subclass/caller sets it — confirm.
        self.token_bucket = None
        if meta_template and 'eos_token_id' in meta_template:
            self.eos_token_id = meta_template['eos_token_id']
        self.api_addr = api_addr
        self.end_str = end_str

    def generate(
        self,
        inputs: List[str],
        max_out_len: int = 512,
        temperature: float = 1.0,
    ) -> List[str]:
        """Generate results given a list of inputs.

        Args:
            inputs (List[str or PromptList]): A list of strings or PromptDicts.
                The PromptDict should be organized in OpenCompass'
                API format.
            max_out_len (int): The maximum length of the output.
            temperature (float): What sampling temperature to use,
                between 0 and 2. Higher values like 0.8 will make the output
                more random, while lower values like 0.2 will make it more
                focused and deterministic. Defaults to 1.0.

        Returns:
            List[str]: A list of generated strings.
        """
        # Fan the prompts out over a thread pool; each worker issues one
        # blocking HTTP completion request.
        with ThreadPoolExecutor() as executor:
            results = list(
                executor.map(self._generate, inputs,
                             [max_out_len] * len(inputs),
                             [temperature] * len(inputs),
                             [self.end_str] * len(inputs)))
        return results

    def get_token_len(self, prompt: str) -> int:
        """Return the token count of *prompt* using the server-side tokenizer."""
        _, length = self.chatbot.encode(prompt)
        return length

    def get_ppl(self,
                inputs: List[str],
                mask_length: Optional[List[int]] = None) -> List[float]:
        """Perplexity is not supported by the API backend."""
        raise NotImplementedError('Not implemented in TurboMindAPIModel.')

    def wait(self):
        """Wait till the next query can be sent.

        Applicable in both single-thread and multi-thread environments.
        """
        return self.token_bucket.get_token()

    def _generate(self, prompt: str, max_out_len: int,
                  temperature: float, end_str: str) -> str:
        """Generate a completion for a single prompt.

        Args:
            prompt (str or PromptList): A string or PromptDict.
                The PromptDict should be organized in OpenCompass'
                API format.
            max_out_len (int): The maximum length of the output.
            temperature (float): What sampling temperature to use,
                between 0 and 2. Higher values like 0.8 will make the output
                more random, while lower values like 0.2 will make it more
                focused and deterministic.

        Returns:
            str: The generated string.
        """
        assert type(
            prompt) is str, 'We only support string for TurboMind RPC API'

        res = ''
        # The v1 completions endpoint yields output chunks; concatenate them.
        for output in self.chatbot.completions_v1(
                session_id=threading.current_thread().ident,
                prompt=prompt,
                model=self.model_name,
                max_tokens=max_out_len,
                temperature=temperature,
                top_p=1.0,
                top_k=10):
            res += output['choices'][0]['text']
        # Strip U+FFFD replacement characters from partially decoded output.
        res = valid_str(res)
        if end_str:
            # Trim everything after the model's special end marker.
            res = res.split(end_str)[0]
        return res
dingo/model/llm/gpt.py ADDED
@@ -0,0 +1,62 @@
1
+ import json
2
+
3
+ from dingo.model import Model
4
+ from dingo.model.llm.common.openai_api import OpenAI
5
+ from dingo.model.llm.base import BaseLLM, ResModel
6
+ from dingo.utils import log
7
+
8
@Model.llm_register('gpt')
class GPT(BaseLLM):
    """LLM quality rater backed by the OpenAI GPT-4 chat API.

    Sends ``general_filter`` (with the input text substituted) to the model
    and parses the JSON reply into a :class:`ResModel`.
    """

    # API key; populated via Model.apply_config ('key' in llm_config).
    key = ''

    # Lazily-created shared OpenAI client.
    gpt_client = None
    general_filter = """
    Please rate the following sentences based on their fluency, completeness, and level of repetition.
    The scores from low to high indicate the quality of the sentences, with values ranging from 0 to 10 and reasons given.
    Please provide a JSON format reply containing the specified key and value.
    requirement:
    -The returned content must be in JSON format and there should be no extra content.
    -The first key returned is score, which is an integer between 0 and 10.
    -The second key returned is error, with a value of one of the following: unsmooth, incomplete, or repetitive. If the sentence is correct, this value is empty.
    -The third key returned is reason, and the value is the reason for scoring.
    -If the sentence is empty, please give it a score of 0.


    %s

    """

    @classmethod
    def create_client(cls):
        """Create the shared OpenAI client on first use."""
        if cls.gpt_client is None:
            cls.gpt_client = OpenAI('gpt-4', key=cls.key)

    @classmethod
    def check_key(cls, data: dict) -> bool:
        """Return True if *data* contains all required reply keys."""
        key_list = ['score', 'error', 'reason']
        for key in key_list:
            if key not in data:
                return False
        return True

    @classmethod
    def call_api(cls, input_data: str) -> ResModel:
        """Rate *input_data* via GPT-4; fall back to a zero-score result.

        Returns:
            ResModel: Parsed score/error/reason, or an 'API_LOSS' result when
            the model reply is malformed.
        """
        cls.create_client()
        response = cls.gpt_client.generate([cls.general_filter % input_data])
        log.debug(response)
        try:
            response = json.loads(response[0])
            if cls.check_key(response) is False:
                raise RuntimeError('miss key: score, error, reason')

            return ResModel(
                score=response['score'],
                error=response['error'],
                reason=response['reason']
            )
        except (RuntimeError, json.JSONDecodeError):
            # A non-JSON model reply raises JSONDecodeError, which the old
            # `except RuntimeError` did not catch and so crashed the caller;
            # treat it the same as a missing-key reply.
            return ResModel(
                score=0,
                error='API_LOSS',
                reason=''
            )
@@ -0,0 +1,97 @@
1
+ import json
2
+ from pydantic import BaseModel
3
+ from transformers import AutoTokenizer, AutoModelForCausalLM
4
+
5
+ from dingo.model import Model
6
+ from dingo.model.llm.base import BaseLLM, ResModel
7
+
8
+ try:
9
+ import torch
10
+ except ImportError as e:
11
+ raise ImportError("You need to install `torch`, try `pip install torch`")
12
+
13
@Model.llm_register('llama3')
class LLaMa3(BaseLLM):
    """LLM quality rater backed by a local LLaMA-3 checkpoint.

    Lazily loads the model/tokenizer from ``path`` and parses the model's
    JSON reply into a :class:`ResModel`.
    """

    # Checkpoint path; populated via Model.apply_config ('path' in llm_config).
    path = ''

    # Lazily-created shared model and tokenizer.
    model = None
    tokenizer = None
    general_filter = """
    Please rate the following sentences based on their fluency, completeness, and level of repetition.
    The scores from low to high indicate the quality of the sentences, with values ranging from 0 to 10 and reasons given.
    Please provide a JSON format reply containing the specified key and value.
    requirement:
    -The returned content must be in JSON format and there should be no extra content.
    -The first key returned is score, which is an integer between 0 and 10.
    -The second key returned is error, with a value of one of the following: unsmooth, incomplete, or repetitive. If the sentence is correct, this value is empty.
    -The third key returned is reason, and the value is the reason for scoring.
    -If the sentence is empty, please give it a score of 0.


    %s

    """

    @classmethod
    def generate_words(cls, input_data: str) -> dict:
        """Run the chat model on *input_data* and parse its reply as JSON.

        Raises:
            json.JSONDecodeError: If the model's reply is not valid JSON.
        """
        if cls.model is None:
            cls.model = AutoModelForCausalLM.from_pretrained(
                cls.path,
                torch_dtype=torch.bfloat16,
                device_map="auto",
            )
        if cls.tokenizer is None:
            cls.tokenizer = AutoTokenizer.from_pretrained(cls.path)

        messages = [
            {"role": "system", "content": input_data},
        ]

        input_ids = cls.tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            return_tensors="pt"
        ).to(cls.model.device)

        # Stop on either the tokenizer EOS or LLaMA-3's end-of-turn token.
        terminators = [
            cls.tokenizer.eos_token_id,
            cls.tokenizer.convert_tokens_to_ids("<|eot_id|>")
        ]

        outputs = cls.model.generate(
            input_ids,
            max_new_tokens=256,
            eos_token_id=terminators,
            do_sample=True,
            temperature=0.6,
            top_p=0.9,
        )
        # Drop the prompt tokens; decode only the newly generated tail.
        response = outputs[0][input_ids.shape[-1]:]
        return json.loads(cls.tokenizer.decode(response, skip_special_tokens=True))

    @classmethod
    def check_key(cls, data: dict) -> bool:
        """Return True if *data* contains all required reply keys."""
        key_list = ['score', 'error', 'reason']
        for key in key_list:
            if key not in data:
                return False
        return True

    @classmethod
    def call_api(cls, input_data: str) -> ResModel:
        """Rate *input_data*; fall back to a zero-score 'API_LOSS' result."""
        try:
            response = cls.generate_words(cls.general_filter % input_data)
            if cls.check_key(response) is False:
                raise RuntimeError('miss key: score, error, reason')

            return ResModel(
                score=response['score'],
                error=response['error'],
                reason=response['reason']
            )
        except (RuntimeError, json.JSONDecodeError):
            # generate_words raises JSONDecodeError on a non-JSON reply; the
            # old `except RuntimeError` let it escape and crash the caller.
            return ResModel(
                score=0,
                error='API_LOSS',
                reason=''
            )
@@ -0,0 +1,68 @@
1
+ import pprint
2
+
3
+ from dingo.model import Model
4
+ from dingo.model.llm.base import BaseLLM, ResModel
5
+
6
+ try:
7
+ from googleapiclient import discovery
8
+ except ImportError:
9
+ raise ImportError('googleapiclient not installed, please install it with `pip install google-api-python-client`')
10
+
11
@Model.llm_register('perspective')
class Perspective(BaseLLM):
    """Toxicity rater backed by Google's Perspective (commentanalyzer) API.

    Flags the text when any requested attribute's summary score exceeds 0.6.
    """

    # API key; populated via Model.apply_config ('key' in llm_config).
    key = ''
    api_url = 'https://commentanalyzer.googleapis.com/$discovery/rest?version=v1alpha1'

    # Lazily-created shared API client.
    client = None

    @classmethod
    def create_client(cls):
        """Build the shared commentanalyzer client on first use."""
        if cls.client is None:
            cls.client = discovery.build(
                "commentanalyzer",
                "v1alpha1",
                developerKey=cls.key,
                discoveryServiceUrl=cls.api_url,
                static_discovery=False,
            )

    @classmethod
    def call_api(cls, input_data: str) -> ResModel:
        """Analyze *input_data*; return score 10 if clean, 0 otherwise.

        Returns:
            ResModel: score 10 for clean text; score 0 with the offending
            attribute names joined in ``reason`` otherwise; an 'API_LOSS'
            result when the API call fails.
        """
        cls.create_client()
        analyze_request = {
            'comment': {'text': input_data},
            'requestedAttributes': {
                'TOXICITY': {},
                'SEVERE_TOXICITY': {},
                'IDENTITY_ATTACK': {},
                'INSULT': {},
                'PROFANITY': {},
                'THREAT': {},
                # 'SEXUALLY_EXPLICIT': {},
                # 'FLIRTATION': {},
            }
        }

        try:
            response = cls.client.comments().analyze(body=analyze_request).execute()
            result = ResModel()
            is_good = True
            error_list = []
            for attr in response['attributeScores']:
                value = response['attributeScores'][attr]['summaryScore']['value']
                if value > 0.6:  # flag attributes scored above this threshold
                    is_good = False
                    error_list.append(attr)
            if is_good is True:
                result.score = 10
            else:
                result.score = 0
                result.error = 'ERROR_PERSPECTIVE'
                result.reason = ",".join(error_list)
            return result
        except Exception:
            # The API raises googleapiclient HttpError / network errors, and a
            # malformed response raises KeyError — none of which are
            # RuntimeError, so the previous `except RuntimeError` was dead
            # code. Treat any failure at this boundary as a best-effort loss.
            return ResModel(
                score=0,
                error='API_LOSS',
                reason=''
            )
dingo/model/model.py ADDED
@@ -0,0 +1,227 @@
1
+ from functools import wraps
2
+ from typing import Dict, List, Callable, Optional
3
+ import os
4
+ import importlib
5
+
6
+
7
+ from pydantic import BaseModel
8
+
9
+
10
+ from dingo.config import GlobalConfig
11
+ from dingo.model.llm.base import BaseLLM
12
+ from dingo.model.rule.base import BaseRule
13
+ from dingo.utils import log
14
+ from dingo.model.llm.common.base_llm import BaseLLMModel
15
+
16
+
17
class BaseEvalModel(BaseModel):
    """Descriptor for one evaluation model: its name and its type."""
    name: str  # model identifier; presumably a rule name or llm id — confirm with callers
    type: str  # model category — confirm expected values with callers
20
+
21
+
22
class Model:
    """
    Model configuration class.

    Central registry for rule classes and llm model classes.  Rules register
    through :meth:`rule_register`, llms through :meth:`llm_register`, and
    :meth:`load_model` imports every module under ``dingo/model/rule`` and
    ``dingo/model/llm`` so those decorators run.
    """
    # Guard so load_model's directory scan only runs once per process.
    module_loaded = False
    # Quality-signal metric type -> registered rule classes.
    rule_metric_type_map = {
        'QUALITY_SIGNAL_EFFECTIVENESS': [],  # Effectiveness
        'QUALITY_SIGNAL_COMPLETENESS': [],  # Completeness
        'QUALITY_SIGNAL_UNDERSTANDABILITY': [],  # Understandability
        'QUALITY_SIGNAL_SIMILARITY': [],  # Similarity
        'QUALITY_SIGNAL_FLUENCY': [],  # Fluency
        'QUALITY_SIGNAL_RELEVANCE': [],  # Relevance
        'QUALITY_SIGNAL_SECURITY': [],  # Security
    }
    # Rule group name -> registered rule classes.
    rule_groups = {}
    # Rule class name -> rule class.
    rule_name_map = {}
    # llm id -> llm class.
    llm_models = {}

    def __init__(self):
        return

    @classmethod
    def get_rule_metric_type_map(cls) -> Dict[str, List[Callable]]:
        """
        Returns the rule metric type map.

        Returns:
            Rule metric type map ( { rule_metric_type: [rules] } )
        """
        return cls.rule_metric_type_map

    @classmethod
    def get_rule_group(cls, rule_group_name: str) -> List[Callable]:
        """
        Returns the rule groups by rule_group_name.

        Returns:
            Rule groups ( [rules] ).
        """
        return cls.rule_groups[rule_group_name]

    @classmethod
    def get_rule_groups(cls) -> Dict[str, List[Callable]]:
        """
        Returns the rule groups.

        Returns:
            Rule groups map ( { rule_group_id: [rules] } ).
        """
        return cls.rule_groups

    @classmethod
    def get_rule_by_name(cls, name: str) -> Callable:
        """
        Returns rule by name.

        Returns:
            Rule function.
        """
        return cls.rule_name_map[name]

    @classmethod
    def get_llm_models(cls) -> Dict[str, "BaseLLMModel"]:
        """
        Returns the llm models.

        Returns:
            LLM models class List
        """
        return cls.llm_models

    @classmethod
    def get_llm_model(cls, llm_model_name: str) -> "BaseLLMModel":
        """
        Returns the llm model by llm_model_name.
        Args:
            llm_model_name (str): The name of the llm model.

        Returns:
            LLM model class
        """
        return cls.llm_models[llm_model_name]

    @classmethod
    def print_rule_list(cls) -> None:
        """
        Print the names of all registered rules.
        """
        rule_list = []
        for rule_name in cls.rule_name_map:
            rule_list.append(rule_name)
        print(rule_list)

    @classmethod
    def get_all_info(cls):
        """
        Returns rules' map and llm models' map.

        NOTE: not implemented yet; always raises NotImplementedError.
        """
        raise NotImplementedError()

    @classmethod
    def rule_register(cls, metric_type: str, group: List[str]) -> Callable:
        """
        Register a rule class. (decorator factory)
        Args:
            metric_type (str): The metric type (quality map).
            group (List[str]): The group names.

        Raises:
            KeyError: If metric_type is not a known quality signal.
        """
        def decorator(root_class):
            # Register the rule under every requested group and by its name.
            for group_name in group:
                if group_name not in cls.rule_groups:
                    cls.rule_groups[group_name] = []
                cls.rule_groups[group_name].append(root_class)
            cls.rule_name_map[root_class.__name__] = root_class

            # metric_type must be one of the predefined quality signals.
            if metric_type not in cls.rule_metric_type_map:
                raise KeyError(f'Metric type "{metric_type}" can not be registered.')
            cls.rule_metric_type_map[metric_type].append(root_class)

            @wraps(root_class)
            def wrapped_function(*args, **kwargs):
                return root_class(*args, **kwargs)

            return wrapped_function

        return decorator

    @classmethod
    def llm_register(cls, llm_id: str) -> Callable:
        """
        Register an llm model class. (decorator factory)
        Args:
            llm_id (str): Name of llm model class.
        """
        def decorator(root_method):
            cls.llm_models[llm_id] = root_method

            @wraps(root_method)
            def wrapped_function(*args, **kwargs):
                return root_method(*args, **kwargs)

            return wrapped_function

        return decorator

    @classmethod
    def apply_config(cls, custom_config_path: Optional[str]):
        """
        Read the (optional) custom config file and push configured parameter
        values onto the matching registered rule and llm classes.
        """
        GlobalConfig.read_config_file(custom_config_path)
        if GlobalConfig.config and GlobalConfig.config.rule_config:
            for rule, params in GlobalConfig.config.rule_config.items():
                if rule not in cls.rule_name_map:
                    continue
                assert isinstance(rule, str)
                for param_name in ['threshold', 'pattern', 'key_list', 'file_path']:
                    param_value = getattr(params, param_name)
                    if not param_value:
                        continue
                    log.debug(f"[Rule config]: config {param_name} for {rule}")
                    cls_rule: BaseRule = cls.rule_name_map[rule]
                    setattr(cls_rule, param_name, param_value)
        if GlobalConfig.config and GlobalConfig.config.llm_config:
            for llm, params in GlobalConfig.config.llm_config.items():
                if llm not in cls.llm_models.keys():
                    continue
                assert isinstance(llm, str)
                for param_name in ['path', 'key', 'api_url']:
                    param_value = getattr(params, param_name)
                    if not param_value:
                        continue
                    log.debug(f"[LLM config]: config {param_name} for {llm}")
                    cls_llm: BaseLLM = cls.llm_models[llm]
                    setattr(cls_llm, param_name, param_value)

    @classmethod
    def _import_submodules(cls, sub_module: str) -> None:
        """Import every .py module under dingo/model/<sub_module>.

        Importing runs the register decorators in each module.  Modules that
        fail to import (e.g. optional dependencies missing) are skipped with
        a debug log instead of aborting the scan.
        """
        this_module_directory = os.path.dirname(os.path.abspath(__file__))
        for file in os.listdir(os.path.join(this_module_directory, sub_module)):
            path = os.path.join(this_module_directory, sub_module, file)
            if not (os.path.isfile(path) and file.endswith('.py')) or file == '__init__.py':
                continue
            module_name = file.split('.')[0]
            try:
                importlib.import_module('dingo.model.' + sub_module + '.' + module_name)
            except ImportError as e:  # ModuleNotFoundError is a subclass
                log.debug("=" * 30 + " ImportError " + "=" * 30)
                log.debug(f'module {module_name} not imported because: \n{e}')
                log.debug("=" * 73)

    @classmethod
    def load_model(cls):
        """Scan and import all rule and llm modules exactly once."""
        if cls.module_loaded:
            return
        # rule auto register
        cls._import_submodules('rule')
        # llm auto register
        cls._import_submodules('llm')
        cls.module_loaded = True
File without changes
@@ -0,0 +1,14 @@
1
+ from typing import Protocol, List, Union
2
+ from pydantic import BaseModel
3
+
4
+
5
class ResModel(BaseModel):
    """Result of a rule evaluation."""
    error_status: bool = False  # presumably True when the rule flagged the input — confirm with rule implementations
    error_reason: str = ''  # human-readable explanation accompanying error_status
8
+
9
+
10
class BaseRule(Protocol):
    """Structural (duck-typed) interface that every rule class satisfies."""

    @classmethod
    def eval(cls, input_data: List[str]) -> ResModel:
        """Evaluate *input_data* and report the outcome as a ResModel."""
        ...