dingo-python 1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dingo/__init__.py +0 -0
- dingo/config/__init__.py +1 -0
- dingo/config/config.py +47 -0
- dingo/convert/__init__.py +4 -0
- dingo/convert/base.py +147 -0
- dingo/exec/__init__.py +3 -0
- dingo/exec/base.py +54 -0
- dingo/exec/local.py +288 -0
- dingo/exec/spark.py +169 -0
- dingo/io/__init__.py +2 -0
- dingo/io/export.py +0 -0
- dingo/io/input.py +27 -0
- dingo/io/summary.py +28 -0
- dingo/model/__init__.py +3 -0
- dingo/model/llm/__init__.py +0 -0
- dingo/model/llm/base.py +12 -0
- dingo/model/llm/common/__init__.py +0 -0
- dingo/model/llm/common/base_llm.py +395 -0
- dingo/model/llm/common/base_llm_api.py +396 -0
- dingo/model/llm/common/openai_api.py +222 -0
- dingo/model/llm/common/turbomind_api.py +148 -0
- dingo/model/llm/gpt.py +62 -0
- dingo/model/llm/llama3.py +97 -0
- dingo/model/llm/perspective.py +68 -0
- dingo/model/model.py +227 -0
- dingo/model/rule/__init__.py +0 -0
- dingo/model/rule/base.py +14 -0
- dingo/model/rule/common_rule.py +551 -0
- dingo/model/rule/image_rule.py +81 -0
- dingo/model/rule/prompt_rule.py +39 -0
- dingo/model/rule/util.py +282 -0
- dingo/utils/__init__.py +1 -0
- dingo/utils/log_util/__init__.py +32 -0
- dingo/utils/log_util/logger.py +39 -0
- dingo_python-1.0.dist-info/LICENSE +201 -0
- dingo_python-1.0.dist-info/METADATA +221 -0
- dingo_python-1.0.dist-info/RECORD +39 -0
- dingo_python-1.0.dist-info/WHEEL +5 -0
- dingo_python-1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,148 @@
|
|
|
1
|
+
import threading
|
|
2
|
+
import os
|
|
3
|
+
|
|
4
|
+
from concurrent.futures import ThreadPoolExecutor
|
|
5
|
+
from typing import Dict, List, Optional, Union
|
|
6
|
+
|
|
7
|
+
from dingo.model.llm.common.base_llm_api import BaseLLMModel
|
|
8
|
+
from dingo.model.llm.common.base_llm import LMTemplateParser
|
|
9
|
+
|
|
10
|
+
# from opencompass.utils.logging import get_logger
|
|
11
|
+
# from opencompass.utils.prompt import PromptList
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
# Type accepted as a prompt. The original spelled this ``Union[str]``, but a
# Union with a single member collapses to that member at runtime, so it is
# written directly as ``str`` (identical behavior, clearer intent).
PromptType = str
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def valid_str(string, coding='utf-8'):
    """Strip UTF-8 replacement characters from *string* and return clean text.

    The text is round-tripped through *coding*: encoded to bytes, every
    occurrence of the Unicode replacement character (U+FFFD, which is
    ``b'\\xef\\xbf\\xbd'`` in UTF-8) is removed, and the remainder is decoded
    back while silently dropping any bytes invalid in *coding*.
    """
    encoded = string.encode(coding)
    for marker in (b'\xef\xbf\xbd',):
        encoded = encoded.replace(marker, b'')
    return encoded.decode(coding, errors='ignore')
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class TurboMindAPIModel(BaseLLMModel):
    """Model wrapper for lmdeploy api server.

    Wraps lmdeploy's OpenAI-compatible HTTP API (``APIClient``) behind the
    ``BaseLLMModel`` interface and fans generation requests out over a
    thread pool.

    Args:
        api_addr (str): The address (ip:port format) of lmdeploy's
            api server. Defaults to 'http://0.0.0.0:23333'.
        max_seq_len (int): The maximum allowed sequence length of a model.
            Note that the length of prompt + generated tokens shall not exceed
            this value. Defaults to 2048.
        meta_template (Dict, optional): The model's meta prompt
            template if needed, in case the requirement of injecting or
            wrapping of any meta instructions.
        end_str (str, optional): Whether to trim generated strings with end_str
            if the model has special ending strings that are not handled well.
            Defaults to None.
    """

    # Marks this wrapper as a remote-API model (as opposed to a local one).
    is_api: bool = True

    def __init__(self,
                 api_addr: str = 'http://0.0.0.0:23333',
                 max_seq_len: int = 2048,
                 meta_template: Optional[Dict] = None,
                 end_str: Optional[str] = None,
                 **kwargs):
        super().__init__(path='',
                         max_seq_len=max_seq_len,
                         meta_template=meta_template)
        # lmdeploy is an optional dependency; imported lazily so the module
        # can be loaded without it installed.
        try:
            from lmdeploy.serve.openai.api_client import APIClient
        except ImportError:
            raise ImportError('lmdeploy is not installed, please install lmdeploy first.')
        self.chatbot = APIClient(api_addr)
        # Use whatever model the server exposes first.
        self.model_name = self.chatbot.available_models[0]
        # self.logger = get_logger()
        self.template_parser = LMTemplateParser(meta_template)
        self.eos_token_id = None
        # NOTE(review): token_bucket is never assigned a real rate limiter in
        # this class, so wait() below raises AttributeError unless a caller
        # injects one — confirm intended usage.
        self.token_bucket = None
        if meta_template and 'eos_token_id' in meta_template:
            self.eos_token_id = meta_template['eos_token_id']
        self.api_addr = api_addr
        self.end_str = end_str

    def generate(
        self,
        inputs: List[str],
        max_out_len: int = 512,
        temperature: float = 1.0,
    ) -> List[str]:
        """Generate results given a list of inputs.

        Each input is dispatched to the API server on its own worker thread;
        results are returned in input order.

        Args:
            inputs (List[str or PromptList]): A list of strings or PromptDicts.
                The PromptDict should be organized in OpenCompass'
                API format.
            max_out_len (int): The maximum length of the output.
            temperature (float): What sampling temperature to use,
                between 0 and 2. Higher values like 0.8 will make the output
                more random, while lower values like 0.2 will make it more
                focused and deterministic. Defaults to 1.0.
        Returns:
            List[str]: A list of generated strings.
        """

        with ThreadPoolExecutor() as executor:
            results = list(
                executor.map(self._generate, inputs,
                             [max_out_len] * len(inputs),
                             [temperature] * len(inputs),
                             [self.end_str] * len(inputs)))
        return results

    def get_token_len(self, prompt: str) -> int:
        # The server-side tokenizer reports (ids, length); only the length
        # is needed here.
        input_ids, length = self.chatbot.encode(prompt)
        return length

    def get_ppl(self,
                inputs: List[str],
                mask_length: Optional[List[int]] = None) -> List[float]:
        # Perplexity is not exposed by the completions API.
        raise NotImplementedError('Not implemented in TurboMindAPIModel.')

    def wait(self):
        """Wait till the next query can be sent.

        Applicable in both single-thread and multi-thread environments.
        """
        # NOTE(review): token_bucket is initialized to None in __init__ and
        # never set here, so this raises AttributeError as written — verify
        # that a rate limiter is installed elsewhere before relying on it.
        return self.token_bucket.get_token()

    def _generate(self, prompt: str, max_out_len: int,
                  temperature: float, end_str: str) -> str:
        """Generate results given a list of inputs.

        Args:
            prompt (str or PromptList): A string or PromptDict.
                The PromptDict should be organized in OpenCompass'
                API format.
            max_out_len (int): The maximum length of the output.
            temperature (float): What sampling temperature to use,
                between 0 and 2. Higher values like 0.8 will make the output
                more random, while lower values like 0.2 will make it more
                focused and deterministic.

        Returns:
            str: The generated string.
        """
        assert type(
            prompt) is str, 'We only support string for TurboMind RPC API'

        res = ''
        # Stream completion chunks and concatenate their text. The thread id
        # serves as a per-worker session id on the server.
        for output in self.chatbot.completions_v1(
                session_id=threading.current_thread().ident,
                prompt=prompt,
                model=self.model_name,
                max_tokens=max_out_len,
                temperature=temperature,
                top_p=1.0,
                top_k=10):
            res += output['choices'][0]['text']
        # Drop any UTF-8 replacement characters the server may emit.
        res = valid_str(res)
        # Optionally truncate at the model-specific terminator string.
        if end_str:
            res = res.split(end_str)[0]
        return res
|
dingo/model/llm/gpt.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
1
|
+
import json
|
|
2
|
+
|
|
3
|
+
from dingo.model import Model
|
|
4
|
+
from dingo.model.llm.common.openai_api import OpenAI
|
|
5
|
+
from dingo.model.llm.base import BaseLLM, ResModel
|
|
6
|
+
from dingo.utils import log
|
|
7
|
+
|
|
8
|
+
@Model.llm_register('gpt')
class GPT(BaseLLM):
    """Quality evaluator backed by an OpenAI GPT-4 chat client.

    ``general_filter`` asks the model to rate a sentence (0-10) and reply
    with a JSON object containing ``score``, ``error`` and ``reason``;
    ``call_api`` parses that reply into a ``ResModel``.
    """

    # OpenAI API key; filled in via Model.apply_config from user config.
    key = ''

    # Shared client, created lazily on first use.
    gpt_client = None
    general_filter = """
    Please rate the following sentences based on their fluency, completeness, and level of repetition.
    The scores from low to high indicate the quality of the sentences, with values ranging from 0 to 10 and reasons given.
    Please provide a JSON format reply containing the specified key and value.
    requirement:
    -The returned content must be in JSON format and there should be no extra content.
    -The first key returned is score, which is an integer between 0 and 10.
    -The second key returned is error, with a value of one of the following: unsmooth, incomplete, or repetitive. If the sentence is correct, this value is empty.
    -The third key returned is reason, and the value is the reason for scoring.
    -If the sentence is empty, please give it a score of 0.


    %s

    """

    @classmethod
    def create_client(cls):
        """Create the shared OpenAI client on first use (idempotent)."""
        if cls.gpt_client is None:
            cls.gpt_client = OpenAI('gpt-4', key=cls.key)

    @classmethod
    def check_key(cls, data: dict) -> bool:
        """Return True when *data* contains every required result key."""
        key_list = ['score', 'error', 'reason']
        for key in key_list:
            if key not in data:
                return False
        return True

    @classmethod
    def call_api(cls, input_data: str) -> ResModel:
        """Score *input_data* with GPT-4 and wrap the reply in a ResModel.

        Returns:
            ResModel with the model's score/error/reason, or a
            ``score=0, error='API_LOSS'`` fallback when the reply cannot
            be parsed or is missing required keys.
        """
        cls.create_client()
        response = cls.gpt_client.generate([cls.general_filter % input_data])
        log.debug(response)
        try:
            response = json.loads(response[0])
            if cls.check_key(response) is False:
                raise RuntimeError('miss key: score, error, reason')

            return ResModel(
                score=response['score'],
                error=response['error'],
                reason=response['reason']
            )
        # Fix: json.JSONDecodeError is a ValueError, and an empty or
        # non-list reply raises IndexError/TypeError; the previous
        # `except RuntimeError` let all of those crash instead of falling
        # back to the intended API_LOSS result.
        except (RuntimeError, ValueError, TypeError, IndexError):
            return ResModel(
                score=0,
                error='API_LOSS',
                reason=''
            )
|
|
@@ -0,0 +1,97 @@
|
|
|
1
|
+
import json
|
|
2
|
+
from pydantic import BaseModel
|
|
3
|
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
|
4
|
+
|
|
5
|
+
from dingo.model import Model
|
|
6
|
+
from dingo.model.llm.base import BaseLLM, ResModel
|
|
7
|
+
|
|
8
|
+
try:
|
|
9
|
+
import torch
|
|
10
|
+
except ImportError as e:
|
|
11
|
+
raise ImportError("You need to install `torch`, try `pip install torch`")
|
|
12
|
+
|
|
13
|
+
@Model.llm_register('llama3')
class LLaMa3(BaseLLM):
    """Quality evaluator backed by a locally loaded LLaMA-3 chat model.

    ``general_filter`` asks the model to rate a sentence (0-10) and reply
    with a JSON object containing ``score``, ``error`` and ``reason``;
    ``call_api`` parses that reply into a ``ResModel``.
    """

    # Local checkpoint path; filled in via Model.apply_config.
    path = ''

    # Model and tokenizer are loaded lazily and shared across calls.
    model = None
    tokenizer = None
    general_filter = """
    Please rate the following sentences based on their fluency, completeness, and level of repetition.
    The scores from low to high indicate the quality of the sentences, with values ranging from 0 to 10 and reasons given.
    Please provide a JSON format reply containing the specified key and value.
    requirement:
    -The returned content must be in JSON format and there should be no extra content.
    -The first key returned is score, which is an integer between 0 and 10.
    -The second key returned is error, with a value of one of the following: unsmooth, incomplete, or repetitive. If the sentence is correct, this value is empty.
    -The third key returned is reason, and the value is the reason for scoring.
    -If the sentence is empty, please give it a score of 0.


    %s

    """

    @classmethod
    def generate_words(cls, input_data: str) -> dict:
        """Run the chat model on *input_data* and parse its reply as JSON.

        Raises:
            json.JSONDecodeError: if the model output is not valid JSON.
        """
        # Lazy one-time load of model and tokenizer from cls.path.
        if cls.model is None:
            cls.model = AutoModelForCausalLM.from_pretrained(
                cls.path,
                torch_dtype=torch.bfloat16,
                device_map="auto",
            )
        if cls.tokenizer is None:
            cls.tokenizer = AutoTokenizer.from_pretrained(cls.path)

        messages = [
            {"role": "system", "content": input_data},
        ]

        input_ids = cls.tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            return_tensors="pt"
        ).to(cls.model.device)

        # Stop on either the tokenizer's EOS or LLaMA-3's end-of-turn token.
        terminators = [
            cls.tokenizer.eos_token_id,
            cls.tokenizer.convert_tokens_to_ids("<|eot_id|>")
        ]

        outputs = cls.model.generate(
            input_ids,
            max_new_tokens=256,
            eos_token_id=terminators,
            do_sample=True,
            temperature=0.6,
            top_p=0.9,
        )
        # Keep only the newly generated tokens (strip the prompt prefix).
        response = outputs[0][input_ids.shape[-1]:]
        return json.loads(cls.tokenizer.decode(response, skip_special_tokens=True))

    @classmethod
    def check_key(cls, data: dict) -> bool:
        """Return True when *data* contains every required result key."""
        key_list = ['score', 'error', 'reason']
        for key in key_list:
            if key not in data:
                return False
        return True

    @classmethod
    def call_api(cls, input_data: str) -> ResModel:
        """Score *input_data* with the local model and return a ResModel.

        Returns:
            ResModel with the model's score/error/reason, or a
            ``score=0, error='API_LOSS'`` fallback when the reply cannot
            be parsed or is missing required keys.
        """
        try:
            response = cls.generate_words(cls.general_filter % input_data)
            if cls.check_key(response) is False:
                raise RuntimeError('miss key: score, error, reason')

            return ResModel(
                score=response['score'],
                error=response['error'],
                reason=response['reason']
            )
        # Fix: generate_words raises json.JSONDecodeError (a ValueError)
        # whenever sampling produces non-JSON text, which is common; the
        # previous `except RuntimeError` let that crash instead of
        # returning the intended API_LOSS fallback.
        except (RuntimeError, ValueError, TypeError, KeyError):
            return ResModel(
                score=0,
                error='API_LOSS',
                reason=''
            )
|
|
@@ -0,0 +1,68 @@
|
|
|
1
|
+
import pprint
|
|
2
|
+
|
|
3
|
+
from dingo.model import Model
|
|
4
|
+
from dingo.model.llm.base import BaseLLM, ResModel
|
|
5
|
+
|
|
6
|
+
try:
|
|
7
|
+
from googleapiclient import discovery
|
|
8
|
+
except ImportError:
|
|
9
|
+
raise ImportError('googleapiclient not installed, please install it with `pip install google-api-python-client`')
|
|
10
|
+
|
|
11
|
+
@Model.llm_register('perspective')
class Perspective(BaseLLM):
    """Toxicity evaluator backed by Google's Perspective (commentanalyzer) API.

    Scores a text against several toxicity attributes; any attribute with a
    summary score above 0.6 marks the text as bad (score 0), otherwise the
    text gets score 10.
    """

    # Google API developer key; filled in via Model.apply_config.
    key = ''
    api_url = 'https://commentanalyzer.googleapis.com/$discovery/rest?version=v1alpha1'

    # Shared discovery client, created lazily on first use.
    client = None

    @classmethod
    def create_client(cls):
        """Create the shared commentanalyzer client on first use (idempotent)."""
        if cls.client is None:
            cls.client = discovery.build(
                "commentanalyzer",
                "v1alpha1",
                developerKey=cls.key,
                discoveryServiceUrl=cls.api_url,
                static_discovery=False,
            )

    @classmethod
    def call_api(cls, input_data: str) -> ResModel:
        """Analyze *input_data* with the Perspective API.

        Returns:
            ResModel with score 10 when no attribute exceeds the 0.6
            threshold, score 0 with error 'ERROR_PERSPECTIVE' and the
            offending attributes otherwise, or a ``score=0,
            error='API_LOSS'`` fallback when the request fails.
        """
        cls.create_client()
        analyze_request = {
            'comment': {'text': input_data},
            'requestedAttributes': {
                'TOXICITY': {},
                'SEVERE_TOXICITY': {},
                'IDENTITY_ATTACK': {},
                'INSULT': {},
                'PROFANITY': {},
                'THREAT': {},
                # 'SEXUALLY_EXPLICIT': {},
                # 'FLIRTATION': {},
            }
        }

        try:
            response = cls.client.comments().analyze(body=analyze_request).execute()
            result = ResModel()
            is_good = True
            error_list = []
            for e in response['attributeScores']:
                value = response['attributeScores'][e]['summaryScore']['value']
                # 0.6 is the project's fixed toxicity threshold.
                if value > 0.6:
                    is_good = False
                    error_list.append(e)
            if is_good is True:
                result.score = 10
            else:
                result.score = 0
                result.error = 'ERROR_PERSPECTIVE'
                result.reason = ",".join(error_list)
            return result
        # Fix: the API raises googleapiclient.errors.HttpError (not
        # RuntimeError) on failed requests, and a malformed response raises
        # KeyError; the previous `except RuntimeError` never matched, so
        # failures crashed instead of returning the intended API_LOSS
        # fallback. Catch broadly at this boundary.
        except Exception:
            return ResModel(
                score=0,
                error='API_LOSS',
                reason=''
            )
|
dingo/model/model.py
ADDED
|
@@ -0,0 +1,227 @@
|
|
|
1
|
+
from functools import wraps
|
|
2
|
+
from typing import Dict, List, Callable, Optional
|
|
3
|
+
import os
|
|
4
|
+
import importlib
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
from pydantic import BaseModel
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
from dingo.config import GlobalConfig
|
|
11
|
+
from dingo.model.llm.base import BaseLLM
|
|
12
|
+
from dingo.model.rule.base import BaseRule
|
|
13
|
+
from dingo.utils import log
|
|
14
|
+
from dingo.model.llm.common.base_llm import BaseLLMModel
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class BaseEvalModel(BaseModel):
    """Minimal descriptor of an evaluator selected for a run."""
    # Identifier of the evaluator (a rule group or llm model id).
    name: str
    # Kind of evaluator the name refers to — presumably "rule" or "llm";
    # TODO(review): confirm the exact values against callers.
    type: str
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class Model:
    """
    Model configuration class.

    Central registry for quality rules and llm evaluators. Rules register
    through :meth:`rule_register` (indexed by metric type, group name and
    class name); llm models register through :meth:`llm_register`.
    :meth:`load_model` imports every module under ``dingo/model/rule`` and
    ``dingo/model/llm`` so their registration decorators run.
    """
    # Guard so load_model scans the package only once.
    module_loaded = False
    # Maps metric type -> registered rule classes; only these keys are valid.
    rule_metric_type_map = {
        'QUALITY_SIGNAL_EFFECTIVENESS': [],  # Effectiveness
        'QUALITY_SIGNAL_COMPLETENESS': [],  # Completeness
        'QUALITY_SIGNAL_UNDERSTANDABILITY': [],  # Understandability
        'QUALITY_SIGNAL_SIMILARITY': [],  # Similarity
        'QUALITY_SIGNAL_FLUENCY': [],  # Fluency
        'QUALITY_SIGNAL_RELEVANCE': [],  # Relevance
        'QUALITY_SIGNAL_SECURITY': [],  # Security
    }
    # Maps group name -> rule classes.
    rule_groups = {}
    # Maps rule class name -> rule class.
    rule_name_map = {}
    # Maps llm id -> llm class.
    llm_models = {}

    def __init__(self):
        return

    @classmethod
    def get_rule_metric_type_map(cls) -> Dict[str, List[Callable]]:
        """
        Returns the rule metric type map.

        Returns:
            Rule metric type map ( { rule_metric_type: [rules] } )
        """
        return cls.rule_metric_type_map

    @classmethod
    def get_rule_group(cls, rule_group_name: str) -> List[Callable]:
        """
        Returns the rule groups by rule_group_name.

        Returns:
            Rule groups ( [rules] ).
        """
        return cls.rule_groups[rule_group_name]

    @classmethod
    def get_rule_groups(cls) -> Dict[str, List[Callable]]:
        """
        Returns the rule groups.

        Returns:
            Rule groups map ( { rule_group_id: [rules] } ).
        """
        return cls.rule_groups

    @classmethod
    def get_rule_by_name(cls, name: str) -> Callable:
        """
        Returns rule by name.

        Returns:
            Rule function.
        """
        return cls.rule_name_map[name]

    @classmethod
    def get_llm_models(cls) -> Dict[str, BaseLLMModel]:
        """
        Returns the llm models.

        Returns:
            LLM models class List
        """
        return cls.llm_models

    @classmethod
    def get_llm_model(cls, llm_model_name: str) -> BaseLLMModel:
        """
        Returns the llm model by llm_model_name.
        Args:
            llm_model_name (str): The name of the llm model.

        Returns:
            LLM model class
        """
        return cls.llm_models[llm_model_name]

    @classmethod
    def print_rule_list(cls) -> None:
        """
        Print the rule list.

        Returns:
            List of rules.
        """
        rule_list = []
        for rule_name in cls.rule_name_map:
            rule_list.append(rule_name)
        print(rule_list)

    @classmethod
    def get_all_info(cls):
        """
        Returns rules' map and llm models' map
        """
        raise NotImplementedError()

    @classmethod
    def rule_register(cls, metric_type: str, group: List[str]) -> Callable:
        """
        Register a rule class. (decorator)
        Args:
            metric_type (str): The metric type (quality map).
            group (List[str]): The group names.
        """
        def decorator(root_class):
            # group
            for group_name in group:
                if group_name not in cls.rule_groups:
                    cls.rule_groups[group_name] = []
                cls.rule_groups[group_name].append(root_class)
            cls.rule_name_map[root_class.__name__] = root_class

            # metric_type
            if metric_type not in cls.rule_metric_type_map:
                raise KeyError(f'Metric type "{metric_type}" can not be registered.')
            cls.rule_metric_type_map[metric_type].append(root_class)

            # Fix: return the class itself instead of a functools.wraps
            # function wrapper — the wrapper supported only instantiation
            # and hid classmethods, attributes and isinstance checks on the
            # decorated name.
            return root_class

        return decorator

    @classmethod
    def llm_register(cls, llm_id: str) -> Callable:
        """
        Register an llm model class. (decorator)
        Args:
            llm_id (str): Name of llm model class.
        """
        def decorator(root_method):
            cls.llm_models[llm_id] = root_method

            # Fix: return the registered object itself (see rule_register):
            # the previous function wrapper broke classmethod access such as
            # GPT.check_key on the decorated name.
            return root_method

        return decorator

    @classmethod
    def apply_config(cls, custom_config_path: Optional[str]):
        """Load the user config file and push per-rule / per-llm parameters
        onto the registered classes."""
        GlobalConfig.read_config_file(custom_config_path)
        if GlobalConfig.config and GlobalConfig.config.rule_config:
            for rule, params in GlobalConfig.config.rule_config.items():
                # Silently skip config entries for rules that are not loaded.
                if rule not in cls.rule_name_map:
                    continue
                assert isinstance(rule, str)
                for param_name in ['threshold', 'pattern', 'key_list', 'file_path']:
                    param_value = getattr(params, param_name)
                    if not param_value:
                        continue
                    log.debug(f"[Rule config]: config {param_name} for {rule}")
                    cls_rule: BaseRule = cls.rule_name_map[rule]
                    setattr(cls_rule, param_name, param_value)
        if GlobalConfig.config and GlobalConfig.config.llm_config:
            for llm, params in GlobalConfig.config.llm_config.items():
                if llm not in cls.llm_models.keys():
                    continue
                assert isinstance(llm, str)
                for param_name in ['path', 'key', 'api_url']:
                    param_value = getattr(params, param_name)
                    if not param_value:
                        continue
                    log.debug(f"[LLM config]: config {param_name} for {llm}")
                    cls_llm: BaseLLM = cls.llm_models[llm]
                    setattr(cls_llm, param_name, param_value)

    @classmethod
    def load_model(cls):
        """Import every rule and llm module so their register decorators run.

        Modules whose optional dependencies are missing are skipped with a
        debug log instead of aborting the scan.
        """
        if cls.module_loaded:
            return
        this_module_directory = os.path.dirname(os.path.abspath(__file__))
        # rule auto register
        for file in os.listdir(os.path.join(this_module_directory, 'rule')):
            path = os.path.join(this_module_directory, 'rule', file)
            if os.path.isfile(path) and file.endswith('.py') and not file == '__init__.py':
                try:
                    importlib.import_module('dingo.model.rule.' + file.split('.')[0])
                except ModuleNotFoundError as e:
                    log.debug(e)
                # Fix: mirror the llm loop below — a rule module with an
                # unmet optional dependency (plain ImportError) is skipped
                # instead of crashing the whole scan.
                except ImportError as e:
                    log.debug("=" * 30 + " ImportError " + "=" * 30)
                    log.debug(f'module {file.split(".")[0]} not imported because: \n{e}')
                    log.debug("=" * 73)

        # llm auto register
        for file in os.listdir(os.path.join(this_module_directory, 'llm')):
            path = os.path.join(this_module_directory, 'llm', file)
            if os.path.isfile(path) and file.endswith('.py') and not file == '__init__.py':
                try:
                    importlib.import_module('dingo.model.llm.' + file.split('.')[0])
                except ModuleNotFoundError as e:
                    log.debug(e)
                except ImportError as e:
                    log.debug("=" * 30 + " ImportError " + "=" * 30)
                    log.debug(f'module {file.split(".")[0]} not imported because: \n{e}')
                    log.debug("=" * 73)
        cls.module_loaded = True
|
File without changes
|
dingo/model/rule/base.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
from typing import Protocol, List, Union
|
|
2
|
+
from pydantic import BaseModel
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class ResModel(BaseModel):
    """Result of evaluating one input against a rule."""
    # True when the rule flagged the input as bad.
    error_status: bool = False
    # Explanation of why the input was flagged; empty when error_status is
    # False — presumably, TODO(review): confirm against rule implementations.
    error_reason: str = ''
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class BaseRule(Protocol):
    """Structural (duck-typed) interface every quality rule must satisfy."""

    @classmethod
    def eval(cls, input_data: List[str]) -> ResModel:
        # Evaluate *input_data* and report whether the rule was violated.
        ...