dingo-python 1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dingo/__init__.py +0 -0
- dingo/config/__init__.py +1 -0
- dingo/config/config.py +47 -0
- dingo/convert/__init__.py +4 -0
- dingo/convert/base.py +147 -0
- dingo/exec/__init__.py +3 -0
- dingo/exec/base.py +54 -0
- dingo/exec/local.py +288 -0
- dingo/exec/spark.py +169 -0
- dingo/io/__init__.py +2 -0
- dingo/io/export.py +0 -0
- dingo/io/input.py +27 -0
- dingo/io/summary.py +28 -0
- dingo/model/__init__.py +3 -0
- dingo/model/llm/__init__.py +0 -0
- dingo/model/llm/base.py +12 -0
- dingo/model/llm/common/__init__.py +0 -0
- dingo/model/llm/common/base_llm.py +395 -0
- dingo/model/llm/common/base_llm_api.py +396 -0
- dingo/model/llm/common/openai_api.py +222 -0
- dingo/model/llm/common/turbomind_api.py +148 -0
- dingo/model/llm/gpt.py +62 -0
- dingo/model/llm/llama3.py +97 -0
- dingo/model/llm/perspective.py +68 -0
- dingo/model/model.py +227 -0
- dingo/model/rule/__init__.py +0 -0
- dingo/model/rule/base.py +14 -0
- dingo/model/rule/common_rule.py +551 -0
- dingo/model/rule/image_rule.py +81 -0
- dingo/model/rule/prompt_rule.py +39 -0
- dingo/model/rule/util.py +282 -0
- dingo/utils/__init__.py +1 -0
- dingo/utils/log_util/__init__.py +32 -0
- dingo/utils/log_util/logger.py +39 -0
- dingo_python-1.0.dist-info/LICENSE +201 -0
- dingo_python-1.0.dist-info/METADATA +221 -0
- dingo_python-1.0.dist-info/RECORD +39 -0
- dingo_python-1.0.dist-info/WHEEL +5 -0
- dingo_python-1.0.dist-info/top_level.txt +1 -0
dingo/exec/spark.py
ADDED
|
@@ -0,0 +1,169 @@
|
|
|
1
|
+
import uuid
|
|
2
|
+
import json
|
|
3
|
+
import orjson
|
|
4
|
+
|
|
5
|
+
from abc import ABC, abstractmethod
|
|
6
|
+
from typing import Protocol, List, Dict, Any, Callable, Optional
|
|
7
|
+
|
|
8
|
+
from pyspark import SparkConf
|
|
9
|
+
from pyspark.sql import SparkSession, Row, DataFrame
|
|
10
|
+
from pyspark.sql.functions import explode, count, col, format_number
|
|
11
|
+
from pyspark.sql.types import StructType, StructField, StringType, BooleanType, ArrayType
|
|
12
|
+
|
|
13
|
+
from dingo.model import Model
|
|
14
|
+
from dingo.model.rule.base import BaseRule, ResModel as RuleResModel
|
|
15
|
+
|
|
16
|
+
# Module-level alias for Model.rule_metric_type_map. get_quality_signal() below
# iterates it as "quality-signal name -> collection of rule classes".
QUALITY_MAP = Model.rule_metric_type_map
|
|
17
|
+
|
|
18
|
+
class SparkExecutor():
    """Run dingo rule evaluation over a Spark DataFrame of JSON strings.

    Typical flow: provide a session (``set_spark`` / ``create_spark``) and an
    input DataFrame (``set_input_df``) whose rows expose a ``value`` attribute
    holding one JSON document each, then call ``convert_data``, ``execute`` and
    ``summarize`` in that order. Results are accumulated in ``self.summary``.
    """

    def __init__(self):
        self.spark: Optional[SparkSession] = None
        self.input_df: Optional[DataFrame] = None
        self.convert_df: Optional[DataFrame] = None
        self.output_df: Optional[DataFrame] = None
        # Aggregated result; 'error_ratio' maps quality-signal name -> ratio.
        self.summary = {
            'score': 0.0,
            'num_good': 0,
            'num_bad': 0,
            'total': 0,
            'error_ratio': {
                "QUALITY_SIGNAL_EFFECTIVENESS": 0.0,
                "QUALITY_SIGNAL_COMPLETENESS": 0.0,
                "QUALITY_SIGNAL_UNDERSTANDABILITY": 0.0,
                "QUALITY_SIGNAL_SIMILARITY": 0.0,
                "QUALITY_SIGNAL_FLUENCY": 0.0,
                "QUALITY_SIGNAL_RELEVANCE": 0.0,
                "QUALITY_SIGNAL_SECURITY": 0.0
            }
        }

    @staticmethod
    def _value_schema() -> StructType:
        """Return the single-column schema shared by the intermediate frames.

        Passing it explicitly to ``toDF`` avoids schema inference, which raises
        on an empty RDD (e.g. when no row fails any rule).
        """
        return StructType([StructField('value', StringType(), True)])

    def set_spark(self, spark: SparkSession):
        """Use an already created Spark session."""
        self.spark = spark

    def set_input_df(self, df: DataFrame):
        """Set the raw input DataFrame (one JSON string per row in ``value``)."""
        self.input_df = df

    def get_spark(self):
        """Return the Spark session, or None if none has been set/created."""
        return self.spark

    def get_input_df(self):
        """Return the raw input DataFrame."""
        return self.input_df

    def get_convert_df(self):
        """Return the DataFrame produced by ``convert_data``."""
        return self.convert_df

    def get_output_df(self):
        """Return the DataFrame of failing rows produced by ``execute``."""
        return self.output_df

    def get_summary(self):
        """Return the summary dictionary (mutated in place by the pipeline)."""
        return self.summary

    def create_spark(self, conf: SparkConf):
        """Create (or reuse) a Spark session from ``conf``.

        Hive support is attempted first; if that fails the session is created
        without it.
        """
        try:
            self.spark = SparkSession.builder.config(conf=conf).enableHiveSupport().getOrCreate()  # type: ignore
        except Exception:
            # Only ordinary errors trigger the fallback; KeyboardInterrupt and
            # SystemExit are allowed to propagate.
            self.spark = SparkSession.builder.config(conf=conf).getOrCreate()  # type: ignore

    def convert_data(
            self,
            column_content: List[str],
            column_id: Optional[List[str]] = None,
            column_prompt: Optional[List[str]] = None,
    ):
        """Normalise every input row to ``{data_id, prompt, content}``.

        Args:
            column_content: Key path to the content inside each JSON document.
            column_id: Key path to the id; a random UUID4 string is generated
                per row when omitted.
            column_prompt: Key path to the prompt; ``''`` is used when omitted.

        Side effects:
            Sets ``self.convert_df`` and ``self.summary['total']``.
        """
        def func(row: Row) -> Row:
            data = orjson.loads(row.value)
            new_data = {
                'data_id': find_nested_data(data, column_id) if column_id is not None else str(uuid.uuid4()),
                'prompt': find_nested_data(data, column_prompt) if column_prompt is not None else '',
                'content': find_nested_data(data, column_content),
            }
            return Row(value=orjson.dumps(new_data).decode("utf-8"))

        # Explicit schema: inference would fail on an empty input.
        convert_df = self.input_df.rdd.map(func).toDF(self._value_schema())
        self.summary['total'] = convert_df.count()
        self.convert_df = convert_df

    def summarize(self):
        """Fill ``num_good``, ``score`` and ``error_ratio`` in ``self.summary``.

        Reads ``self.output_df`` and ``self.spark``, so ``execute`` must have
        run and a session must be set. Each ratio is the number of failing rows
        carrying that quality signal divided by ``summary['total']``, stored as
        a float rounded to 6 decimals (consistent with the float defaults).
        """
        total = self.summary['total']
        self.summary['num_good'] = total - self.summary['num_bad']
        self.summary['score'] = round(self.summary['num_good'] / total * 100, 2) if total != 0 else 0

        def extract_error_info(row):
            data = orjson.loads(row.value)
            # Field names match the schema below, so the mapping does not
            # depend on positional ordering of Row fields.
            return Row(data_id=data['data_id'], quality_signals=data['quality_signals'])

        schema = StructType([
            StructField("data_id", StringType(), True),
            StructField("quality_signals", ArrayType(StringType()), True)
        ])

        df_error_info = self.spark.createDataFrame(self.output_df.rdd.map(extract_error_info), schema=schema)

        df_exploded = df_error_info.select("data_id", explode("quality_signals").alias("quality_signal"))
        df_grouped = df_exploded.groupBy("quality_signal").agg(count("*").alias("count"))

        for row in df_grouped.collect():
            ratio = round(row['count'] / total, 6) if total else 0.0
            self.summary['error_ratio'][row['quality_signal']] = ratio

    def execute(self, rule_list: List[str]):
        """Apply the named rules to every converted row and keep failing rows.

        Args:
            rule_list: Rule names, resolved by ``execute_rule``.

        Side effects:
            Sets ``self.output_df`` (rows whose ``error_status`` is True) and
            ``self.summary['num_bad']``.
        """
        def func_exec(row: Row):
            data = orjson.loads(row.value)
            new_data = execute_rule(rule_list, data)
            return Row(value=orjson.dumps(new_data).decode("utf-8"))

        def func_filter(row: Row):
            return orjson.loads(row.value)['error_status'] is True

        evaluated = self.convert_df.rdd.map(func_exec)
        # Explicit schema keeps this working when the filter leaves no rows.
        self.output_df = evaluated.filter(func_filter).toDF(self._value_schema())
        self.summary['num_bad'] = self.output_df.count()
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
def find_nested_data(jsn: json, levels: List[str]):
    """Follow the key path ``levels`` into ``jsn`` and return the value found.

    Each entry of ``levels`` is used to index the value reached so far; an
    empty path returns ``jsn`` itself. Lookup errors from the indexing
    propagate to the caller.
    """
    node = jsn
    for step in levels:
        node = node[step]
    return node
|
|
134
|
+
|
|
135
|
+
def get_quality_signal(rule: BaseRule):
    """Return the QUALITY_MAP key under which ``rule`` is registered.

    Matching is done by comparing ``__name__`` of ``rule`` with ``__name__``
    of each class listed in QUALITY_MAP; the first key containing a match wins.

    Raises:
        RuntimeError: If no entry in QUALITY_MAP has a matching name.
    """
    wanted = rule.__name__
    for signal_name in QUALITY_MAP:
        registered = QUALITY_MAP[signal_name]
        if any(candidate.__name__ == wanted for candidate in registered):
            return signal_name

    raise RuntimeError('this rule can not find its quality_signal: ' + wanted)
|
|
142
|
+
|
|
143
|
+
def execute_rule(rule_list: List[str], data: json) -> json:
    """Evaluate the named rules against one record and annotate it in place.

    ``data`` gets three keys (re)initialised: ``error_status`` (False),
    ``error_functions`` ([]) and ``quality_signals`` ([]). Every rule name is
    resolved through ``Model.rule_name_map`` before any rule runs. Rules whose
    class name starts with ``Prompt`` receive ``[prompt, content]``; all
    others receive ``[content]``. For each failing rule the class name is
    recorded and its quality signal is added once.

    Raises:
        KeyError: If a rule name is not present in ``Model.rule_name_map``.

    Returns:
        The same ``data`` object, mutated.
    """
    data['error_status'] = False
    data['error_functions'] = []
    data['quality_signals'] = []

    # Phase 1: resolve all names first so an unknown rule aborts before
    # anything is evaluated.
    resolved: List[BaseRule] = []
    for name in rule_list:
        assert isinstance(name, str)
        if name not in Model.rule_name_map:
            raise KeyError(f"{name} not in Model.rule_name_map, there are {str(Model.rule_name_map.keys())}")
        resolved.append(Model.rule_name_map[name])

    # Phase 2: run each rule and collect failures.
    for rule_cls in resolved:
        label = rule_cls.__name__
        inputs = [data["prompt"], data["content"]] if label.startswith('Prompt') else [data["content"]]
        outcome: RuleResModel = rule_cls.eval(inputs)

        if not outcome.error_status:
            continue

        data['error_status'] = True
        data['error_functions'].append(label)
        signal = get_quality_signal(rule_cls)
        if signal not in data['quality_signals']:
            data['quality_signals'].append(signal)

    return data
|
dingo/io/__init__.py
ADDED
dingo/io/export.py
ADDED
|
File without changes
|
dingo/io/input.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
import time
|
|
2
|
+
from typing import Optional, List
|
|
3
|
+
from pydantic import BaseModel
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class InputModel(BaseModel):
    """
    Input model, output of converter.
    """
    # Identifier of the record.
    data_id: str
    # Prompt text associated with the record.
    prompt: str
    # Content text of the record.
    content: str
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class RawInputModel(BaseModel):
    """
    Dataset model, output of converter.
    """
    # NOTE: the default is computed once, when this class body is executed,
    # not per instance; instances that do not pass dataset_id share that
    # timestamp (local time, format YYYYMMDD_HHMMSS).
    dataset_id: str = time.strftime('%Y%m%d_%H%M%S', time.localtime())
    # Names of the evaluation models to use.
    eval_models: List[str] = ['default']
    # Location of the input data.
    input_path: str = "data/inputs/test_data1.json"
    # Location where results are written.
    output_path: str = "data/outputs/"
    # Format of the input data.
    data_type: str = "json"
    # Key paths into each record for the content, id and prompt values.
    # NOTE(review): presumably consumed like the column_* arguments of
    # SparkExecutor.convert_data -- confirm against the callers.
    column_content: List[str] = []
    column_id: List[str] = []
    column_prompt: List[str] = []
    # Optional path to a custom configuration file.
    custom_config_path: Optional[str] = None
|
dingo/io/summary.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
from typing import List, Dict
|
|
2
|
+
|
|
3
|
+
from pydantic import BaseModel
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class SummaryModel(BaseModel):
    # Result summary of one evaluation run: where the data came from, where
    # results went, and the aggregated quality numbers.
    dataset_id: str
    input_model: str
    input_path: str
    output_path: str
    score: float
    num_good: int
    num_bad: int
    total: int
    error_ratio: Dict[str, float]

    def to_dict(self):
        # Export the fields as a plain dict, in declaration order.
        exported = (
            'dataset_id',
            'input_model',
            'input_path',
            'output_path',
            'score',
            'num_good',
            'num_bad',
            'total',
            'error_ratio',
        )
        return {field_name: getattr(self, field_name) for field_name in exported}
|
dingo/model/__init__.py
ADDED
|
File without changes
|
dingo/model/llm/base.py
ADDED
|
File without changes
|
|
@@ -0,0 +1,395 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
from copy import deepcopy
|
|
3
|
+
from typing import Dict, List, Optional, Tuple, Union
|
|
4
|
+
|
|
5
|
+
from dingo.utils import log
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
# Alias used in annotations for prompt templates; currently plain ``str``.
PromptType = str
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class BaseLLMModel(ABC):
    """Abstract wrapper from which language-model backends derive.

    Args:
        path (str): Location of the model.
        max_seq_len (int): Longest sequence the model handles. Defaults to
            2048.
        tokenizer_only (bool): When True, only the tokenizer is meant to be
            initialised. Defaults to False.
        meta_template (Dict, optional): Meta prompt template of the model,
            for cases where meta instructions must be injected or wrapped
            around prompts.
    """

    is_api: bool = False

    def __init__(self,
                 path: str,
                 max_seq_len: int = 2048,
                 tokenizer_only: bool = False,
                 meta_template: Optional[Dict] = None):
        self.model = None
        self.path = path
        self.max_seq_len = max_seq_len
        self.tokenizer_only = tokenizer_only
        # The parser owns all meta-template handling.
        self.template_parser = LMTemplateParser(meta_template)
        has_eos = bool(meta_template) and 'eos_token_id' in meta_template
        self.eos_token_id = meta_template['eos_token_id'] if has_eos else None

    @abstractmethod
    def generate(self, inputs: List[str], max_out_len: int) -> List[str]:
        """Produce one completion per input string.

        Args:
            inputs (List[str]): Input strings.
            max_out_len (int): Upper bound on the output length.

        Returns:
            List[str]: The generated strings.
        """

    @abstractmethod
    def get_ppl(self,
                inputs: List[str],
                mask_length: Optional[List[int]] = None) -> List[float]:
        """Compute a perplexity score per input string.

        Args:
            inputs (List[str]): Input strings.
            mask_length (Optional[List[int]]): Per-input mask lengths. When
                given, the first ``mask_length[i]`` tokens of input ``i`` are
                masked out of the score. Implementations may skip this if
                they do not need that feature.

        Returns:
            List[float]: The perplexity scores.
        """

    @abstractmethod
    def get_token_len(self, prompt: str) -> int:
        """Return the number of tokens in ``prompt``.

        Args:
            prompt (str): Input string.

        Returns:
            int: Token count of the input.
        """

    def parse_template(self, prompt_template: PromptType, mode: str) -> str:
        """Delegate template parsing to ``self.template_parser``.

        Args:
            prompt_template: A prompt template or a list of them.
            mode (str): Parsing mode, either 'ppl' or 'gen'.

        Returns:
            Whatever the parser returns for the given template.
        """
        parser = self.template_parser
        return parser.parse_template(prompt_template, mode)

    def get_ppl_from_template(self,
                              templates: List[PromptType],
                              mask_length=None):
        """Parse ``templates`` in 'ppl' mode and score the result.

        Args:
            templates (List[PromptType]): Templates to score.
            mask_length (List[int]): Optional mask lengths forwarded to
                ``get_ppl``.
        """
        parsed_inputs = self.parse_template(templates, mode='ppl')
        return self.get_ppl(parsed_inputs, mask_length)

    def generate_from_template(self, templates: List[PromptType],
                               max_out_len: int, **kwargs):
        """Parse ``templates`` in 'gen' mode and generate completions.

        Args:
            templates (List[PromptType]): Templates to complete.
            max_out_len (int): Upper bound on the output length.
            **kwargs: Forwarded unchanged to ``generate``.
        """
        parsed_inputs = self.parse_template(templates, mode='gen')
        return self.generate(parsed_inputs, max_out_len=max_out_len, **kwargs)

    def get_token_len_from_template(
            self,
            templates: Union[PromptType, List[PromptType]],
            mode: str = 'ppl') -> Union[List[int], int]:
        """Return token length(s) of the parsed template(s).

        Args:
            templates (Union[List[str], str]): One template or a list of them.
            mode (str): Parsing mode, either 'ppl' or 'gen'.

        Returns:
            Union[List[int], int]: A list of lengths when the parsed result
            is a list, otherwise a single int.
        """
        parsed = self.parse_template(templates, mode=mode)
        assert isinstance(parsed, (list, str)), 'tokens must be list or str'
        if isinstance(parsed, list):
            # Stringify everything first, then measure each entry.
            texts = [str(entry) for entry in parsed]
            return [self.get_token_len(text) for text in texts]
        return self.get_token_len(str(parsed))

    def to(self, device):
        """Forward ``.to(device)`` to the wrapped ``self.model``."""
        wrapped = self.model
        wrapped.to(device)
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
class LMTemplateParser:
    """Intermediate prompt template parser, specifically for language models.

    Args:
        meta_template (Dict, optional): The meta template for the model. When
            given it must contain a ``round`` list and may contain a
            ``reserved_roles`` list; dict items in those lists are registered
            in ``self.roles`` under their ``role`` name.
    """

    def __init__(self, meta_template: Optional[Dict] = None):
        self.meta_template = meta_template
        # Always defined (maps role name to config) so the helper methods are
        # usable even when no meta template is supplied.
        self.roles: Dict[str, dict] = {}
        if not meta_template:
            return

        assert 'round' in meta_template, 'round is required in meta' \
            ' template'
        assert isinstance(meta_template['round'], list)
        keys_to_check = ['round']

        if 'reserved_roles' in meta_template:
            assert isinstance(meta_template['reserved_roles'], list)
            keys_to_check.append('reserved_roles')

        for meta_key in keys_to_check:
            for item in meta_template[meta_key]:
                assert isinstance(item, (str, dict))
                if not isinstance(item, dict):
                    continue
                assert item['role'] not in self.roles, \
                    'role in meta prompt must be unique!'
                role_cfg = item.copy()
                self.roles[item['role']] = role_cfg
                # convert list of string and int into a raw string
                # for the ease of future prompt processing
                for key in ['begin', 'end']:
                    value = role_cfg.get(key, '')
                    if isinstance(value, list):
                        role_cfg[key] = self._encode_speical_tokens(value)

    def parse_template(self, prompt_template: Union[str, List],
                       mode: str) -> Union[str, List]:
        """Parse a prompt template.

        String templates are returned unchanged (after ``mode`` is
        validated); lists are parsed element by element. Since only strings
        and lists are accepted, no meta-template wrapping is applied here.

        Args:
            prompt_template (str or list): A prompt template, or a (possibly
                nested) list of them.
            mode (str): Parsing mode. Choices are 'ppl' and 'gen'.

        Returns:
            The template itself for a string input, or a list of parsed
            results for a list input.
        """
        assert isinstance(prompt_template, (str, list))
        if not isinstance(prompt_template, str):
            return [self.parse_template(p, mode=mode) for p in prompt_template]

        assert mode in ['ppl', 'gen']
        return prompt_template

    def _split_rounds(
            self, prompt_template: List[Union[str, Dict]],
            single_round_template: List[Union[str, Dict]]) -> List[int]:
        """Split the prompt template into rounds, based on single round
        template.

        Return the index ranges of each round. Specifically,
        prompt_template[res[i]:res[i+1]] represents the i-th round in the
        template.
        """
        # Position of every role inside one round of the meta template.
        role_idxs = {
            role_cfg['role']: i
            for i, role_cfg in enumerate(single_round_template)
            if not isinstance(role_cfg, str)
        }
        last_role_idx = -1
        cutoff_idxs = [0]
        for idx, template in enumerate(prompt_template):
            if isinstance(template, str):
                continue
            role_idx = role_idxs[template['role']]
            # A role that does not come later than the previous one starts a
            # new round.
            if role_idx <= last_role_idx:
                cutoff_idxs.append(idx)
            last_role_idx = role_idx
        cutoff_idxs.append(len(prompt_template))
        return cutoff_idxs

    def _update_role_dict(self, prompt: Union[List, str,
                                              Dict]) -> Dict[str, Dict]:
        """Update the default role dict with the given prompt(s).

        Returns a deep copy of ``self.roles`` in which every dict prompt has
        been merged into the entry of its role (or of its ``fallback_role``
        when the role is unknown).

        Raises:
            KeyError: If neither the role nor the fallback role of a prompt
                is a registered role.
        """
        assert isinstance(prompt, (str, list, dict))
        role_dict = deepcopy(self.roles)
        if isinstance(prompt, str):
            return role_dict
        if isinstance(prompt, dict):
            prompt = [prompt]
        for p in prompt:
            if isinstance(p, dict):
                role = p['role']
                if role not in self.roles:
                    role = p.get('fallback_role', None)
                    if not role:
                        log.info(f'{p} neither has an appropriate role nor a fallback role.')
                role_dict[role].update(p)
        return role_dict

    def _prompt2str(self,
                    prompt: Union[List, str, Dict],
                    role_dict: Dict[str, Dict],
                    for_gen: bool = False) -> Tuple[str, bool]:
        """Convert the prompts to a string, given an updated role_dict.

        Args:
            prompts (Union[List, str, dict]): The prompt(s) to be converted.
            role_dict (Dict[str, Dict]): The updated role dict.
            for_gen (bool): If True, the prompts will be converted for
                generation tasks. The conversion stops before the first
                role whose "generate" is set to True.

        Returns:
            Tuple[str, bool]: The converted string, and whether the follow-up
            conversion should be proceeded. An empty list yields
            ``('', True)``.
        """
        assert isinstance(prompt, (list, str, dict))

        if isinstance(prompt, str):
            return prompt, True
        if isinstance(prompt, dict):
            return self._role2str(prompt, role_dict, for_gen)

        res = ''
        # Initialised so an empty list does not leave ``cont`` unbound.
        cont = True
        for p in prompt:
            new_str, cont = self._prompt2str(p, role_dict, for_gen)
            res += new_str
            if not cont:
                break
        return res, cont

    def _role2str(self,
                  role_prompt: Dict,
                  role_dict: Dict[str, Dict],
                  for_gen: bool = False) -> Tuple[str, bool]:
        """Convert a role prompt to a string, given an updated role_dict.

        Args:
            role_prompt (Dict): The role prompt to be converted.
            role_dict (Dict[str, Dict]): The updated role dict.
            for_gen (bool): If True, the prompts will be converted for
                generation tasks. The conversion stops before the first
                role whose "generate" is set to True.

        Returns:
            Tuple[str, bool]: The converted string, and whether the follow-up
            conversion should be proceeded.
        """
        # Look the role up directly, falling back to its fallback role.
        merged_prompt = role_dict.get(
            role_prompt['role'],
            role_dict.get(role_prompt.get('fallback_role')))
        res = merged_prompt.get('begin', '')
        if for_gen and merged_prompt.get('generate', False):
            # Stop right after the opening marker so the model continues
            # from there.
            return res, False
        res += merged_prompt.get('prompt', '') + merged_prompt.get('end', '')
        return res, True

    def _encode_speical_tokens(self, prompt: List[Union[str, int]]) -> str:
        """Encode the special tokens in the prompt.

        Not supported yet: always raises.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError('Using List[str|int] as the begin or end '
                                  'of a prompt is not supported yet.')
|