gemba 0.1.0__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- gemba/__init__.py +3 -0
- gemba/gemba_da.py +62 -0
- gemba/gemba_esa.py +84 -0
- gemba/gemba_mqm_utils.py +234 -0
- gemba/gpt_api.py +174 -0
- gemba/mtme_tools.py +99 -0
- gemba/prompt.py +139 -0
- gemba/scores.py +103 -0
- gemba/testset.py +58 -0
- gemba/utils.py +78 -0
- gemba-0.1.0.dist-info/METADATA +136 -0
- gemba-0.1.0.dist-info/RECORD +14 -0
- gemba-0.1.0.dist-info/WHEEL +4 -0
- gemba-0.1.0.dist-info/licenses/LICENSE.md +427 -0
gemba/__init__.py
ADDED
gemba/gemba_da.py
ADDED
@@ -0,0 +1,62 @@
|
|
1
|
+
import diskcache as dc
|
2
|
+
from gemba.prompt import prompts, language_codes
|
3
|
+
from gemba.gpt_api import GptApi
|
4
|
+
from gemba.testset import Testset
|
5
|
+
from gemba.scores import Scores
|
6
|
+
|
7
|
+
|
8
|
+
def main():
    """Score WMT22 segments with the GEMBA-DA prompts and save the scores.

    For every (model, prompt name, test sets) scenario below, each hypothesis
    of each dataset/language pair is scored by prompting the model through
    ``GptApi``. Raw model answers are cached on disk under ``cache/`` and
    hypotheses for which ``Scores`` already holds a score are skipped
    (presumably so an interrupted run can be resumed -- confirm with
    gemba/scores.py). Scores are saved once per language pair.
    """
    # NOTE(review): text-davinci-003 is a legacy completions model, while
    # GptApi sends chat-completions requests -- confirm the model name is still valid.
    scenarios = [
        ["text-davinci-003", "GEMBA-DA", [["wmt22", "en-de"], ["wmt22", "zh-en"], ["wmt22", "en-ru"]], ],
        ["text-davinci-003", "GEMBA-DA_ref", [["wmt22", "en-de"], ["wmt22", "zh-en"], ["wmt22", "en-ru"]], ],
    ]

    gptapi = GptApi()
    for use_model, annotation, test_sets in scenarios:
        scoring_name = f"{annotation}_{use_model}"
        prompt_config = prompts[annotation]

        # Model availability is not pre-checked here: GptApi takes its
        # credentials from environment variables, and the former
        # `credentials["deployments"]` lookup used a name that is never
        # defined in this module (it raised NameError).

        # `with` guarantees the on-disk cache is closed after the scenario.
        with dc.Cache(f'cache/{use_model}_{annotation}', expire=None, size_limit=int(10e10), cull_limit=0, eviction_policy='none') as cache:
            for dataset, lp in test_sets:
                testset = Testset("mt-metrics-eval-v2", dataset, lp)
                refname = testset.main_ref if prompt_config["use_ref"] else None

                scores = Scores(scoring_name, testset, refname)

                total = testset.segments_count()
                for hypothesis_index, (src, hyp, ref, system) in enumerate(testset.iterate_over_all(refname)):
                    # NOTE(review): Scores appears to report unscored entries as
                    # the string 'None' -- confirm against gemba/scores.py.
                    if scores.get_score(system, hypothesis_index) != 'None':
                        continue

                    print(f"Processing hypothesis {hypothesis_index}/{total} for {scoring_name} on {dataset}/{lp}")

                    data = {
                        "source_seg": src,
                        "target_seg": hyp,
                        "reference_seg": ref,
                        "source_lang": language_codes[lp.split("-")[0]],
                        "target_lang": language_codes[lp.split("-")[1]],
                    }
                    prompt = prompt_config["prompt"].format(**data)
                    parsed_answers = gptapi.request(prompt, use_model, prompt_config["validate_answer"], cache=cache)

                    # the first parsed answer is the one that gets recorded
                    best = parsed_answers[0]
                    scores.assign_score(system, hypothesis_index, best['answer'], best['temperature'])

                scores.save()


if __name__ == '__main__':
    main()
|
gemba/gemba_esa.py
ADDED
@@ -0,0 +1,84 @@
|
|
1
|
+
import ipdb
|
2
|
+
import json
|
3
|
+
import re
|
4
|
+
from collections import defaultdict
|
5
|
+
|
6
|
+
|
7
|
+
def esa_fewshot(few_shots):
    """Assemble the chat prompt for GEMBA-ESA error-span annotation.

    The conversation opens with a system message. It is followed by one
    user/assistant exchange per few-shot example: the user turn is the
    template filled with the example, and the assistant turn is the
    example's ``answer``. It closes with a user turn holding the raw,
    still-unformatted template, for the caller to fill in with the segment
    to annotate.

    Args:
        few_shots: iterable of dicts that provide the template placeholders
            (``source_lang``, ``source_seg``, ``target_lang``, ``target_seg``)
            and an ``answer`` string.

    Returns:
        list of ``{"role": ..., "content": ...}`` message dicts.
    """
    user_template = """{source_lang} source:
```{source_seg}```
{target_lang} translation:
```{target_seg}```

Based on the source segment and machine translation surrounded with triple backticks, identify error types in the translation and classify them. The categories of errors are: accuracy (addition, mistranslation, omission, untranslated text), fluency (character encoding, grammar, inconsistency, punctuation, register, spelling), style (awkward), terminology (inappropriate for context, inconsistent use), non-translation, other, or no-error.\nEach error is classified as one of two categories: major or minor. Major errors disrupt the flow and make the understandability of text difficult or impossible. Minor errors are errors that do not disrupt the flow significantly and what the text is trying to say is still understandable."""

    conversation = [{
        "role": "system",
        "content": "Your task is to identify machine translation errors and assess the quality of the translation.",
    }]
    for example in few_shots:
        conversation.append({"role": "user", "content": user_template.format(**example)})
        conversation.append({"role": "assistant", "content": example['answer']})

    # The final turn keeps its placeholders; they are filled per segment later.
    conversation.append({"role": "user", "content": user_template})
    return conversation
|
40
|
+
|
41
|
+
|
42
|
+
# Few-shot examples for GEMBA-ESA error-span annotation, keyed by language
# pair. Each entry supplies the placeholders used by the template in
# esa_fewshot() (source_lang, source_seg, target_lang, target_seg) plus the
# expected assistant reply under "answer". The answers use only the "Major:" and
# "Minor:" severity headings, matching the two-level scheme described in that
# template.
esa_few_shots = {
    # English -> German example
    "ende": {
        "source_lang": "English",
        "source_seg": "I do apologise about this, we must gain permission from the account holder to discuss an order with another person, I apologise if this was done previously, however, I would not be able to discuss this with yourself without the account holders permission.",
        "target_lang": "German",
        "target_seg": "Ich entschuldige mich dafür, wir müssen die Erlaubnis einholen, um eine Bestellung mit einer anderen Person zu besprechen. Ich entschuldige mich, falls dies zuvor geschehen wäre, aber ohne die Erlaubnis des Kontoinhabers wäre ich nicht in der Lage, dies mit dir involvement.",
        "answer": """Major:
accuracy/mistranslation - "involvement"
accuracy/omission - "the account holder"
Minor:
fluency/grammar - "wäre"
fluency/register - "dir"
""",
    },
    # English -> Czech example
    "encs": {
        "source_lang": "English",
        "source_seg": "Talks have resumed in Vienna to try to revive the nuclear pact, with both sides trying to gauge the prospects of success after the latest exchanges in the stop-start negotiations.",
        "target_lang": "Czech",
        "target_seg": "Ve Vídni se ve Vídni obnovily rozhovory o oživení jaderného paktu, přičemž obě partaje se snaží posoudit vyhlídky na úspěch po posledních výměnách v jednáních.",
        "answer": """Major:
accuracy/addition - "ve Vídni"
accuracy/omission - "the stop-start"
Minor:
terminology/inappropriate for context - "partaje"
""",
    },
    # Chinese -> English example
    "zhen": {
        "source_lang": "Chinese",
        "source_seg": "大众点评乌鲁木齐家居卖场频道为您提供高铁居然之家地址,电话,营业时间等最新商户信息,找装修公司,就上大众点评",
        "target_lang": "English",
        "target_seg": "Urumqi Home Furnishing Store Channel provides you with the latest business information such as the address, telephone number, business hours, etc., of high-speed rail, and find a decoration company, and go to the reviews.",
        "answer": """Major:
accuracy/addition - "of high-speed rail"
accuracy/mistranslation - "go to the reviews"
Minor:
style/awkward - "etc.,"
""",
    },
}

# Chat prompt for the error-span step: a system message, the three example
# exchanges above, then the unfilled template as the final user turn. That last
# turn still contains {source_lang}/{source_seg}/{target_lang}/{target_seg}
# placeholders for the caller to fill.
TEMPLATE_GEMBA_ESA_ERROR_SPANS = esa_fewshot([esa_few_shots['ende'], esa_few_shots['encs'], esa_few_shots['zhen']])

# Single-string prompt for the scoring step: asks for a 0-100 score of the
# translation given the annotated error spans ({error_spans}; presumably the
# model output obtained with TEMPLATE_GEMBA_ESA_ERROR_SPANS -- confirm at the call site).
TEMPLATE_GEMBA_ESA_RANKING = 'Given the translation from {source_lang} to {target_lang} and the annotated error spans, assign a score on a continuous scale from 0 to 100. The scale has following reference points: 0="No meaning preserved", 33="Some meaning preserved", 66="Most meaning preserved and few grammar mistakes", up to 100="Perfect meaning and grammar".\n\nScore the following translation from {source_lang} source:\n```{source_seg}```\n{target_lang} translation:\n```{target_seg}```\nAnnotated error spans:\n```{error_spans}```\nScore (0-100): '
|
gemba/gemba_mqm_utils.py
ADDED
@@ -0,0 +1,234 @@
|
|
1
|
+
import ipdb
|
2
|
+
import json
|
3
|
+
import re
|
4
|
+
from collections import defaultdict
|
5
|
+
|
6
|
+
def apply_template(template, data):
    """Fill a prompt template with ``data`` using ``str.format``.

    Args:
        template: either a single format string, or a chat-style list of
            message dicts whose ``"content"`` value is a format string.
        data: mapping from placeholder names to values.

    Returns:
        The formatted string. For a list template, a new list is returned
        holding a shallow copy of each message with its ``"content"``
        formatted; the input template is left unchanged.

    Raises:
        ValueError: if ``template`` is neither a ``str`` nor a ``list``.
    """
    if isinstance(template, str):
        return template.format(**data)

    if not isinstance(template, list):
        raise ValueError(f"Unknown template type {type(template)}")

    def fill_turn(turn):
        # copy so the shared module-level template is never mutated
        filled = turn.copy()
        filled['content'] = filled['content'].format(**data)
        return filled

    return [fill_turn(turn) for turn in template]
|
18
|
+
|
19
|
+
def parse_broken_json(x):
    """Salvage a malformed ``{"improved translation": ..., "errors": ...}`` answer.

    The text is split at ``", "errors": ``. The improved translation is taken
    from the part before the split. The part after it (minus its last
    character, the outer closing brace) is parsed as JSON. If that JSON is
    also broken, the errors are only counted: each ``class`` word seen after
    a ``critical``/``major``/``minor`` keyword adds a ``{"class": "other"}``
    placeholder under that severity.

    Args:
        x: raw model answer string.

    Returns:
        dict with ``"improved translation"`` (``""`` when it cannot be
        recovered) and ``"errors"`` (the parsed JSON value, or a
        ``defaultdict(list)`` of placeholders; empty when nothing was found).
    """
    improved_translation = ""
    errors = defaultdict(list)
    if '"errors": ' not in x or "improved translation" not in x:
        return {"improved translation": improved_translation, "errors": errors}

    data = x.split('", "errors": ')
    if len(data) != 2:
        return {"improved translation": improved_translation, "errors": errors}
    head, tail = data

    # Accept any whitespace after the colon. Splitting on the exact
    # '"improved translation": "' prefix raised IndexError whenever the model
    # used different spacing, because the guard above does not check for it.
    match = re.search(r'"improved translation":\s*"', head)
    if match:
        improved_translation = head[match.end():]

    # remove last character (closing brace of the outer object)
    tail = tail[:-1]

    try:
        errors = json.loads(tail)
    except json.JSONDecodeError:
        # just try to get error count
        last_key = None
        for word in re.findall(r'\b\w+\b', tail.lower()):
            if word in ('critical', 'major', 'minor'):
                last_key = word
            elif last_key is not None and word == "class":
                errors[last_key].append({"class": "other"})

    return {"improved translation": improved_translation, "errors": errors}
|
46
|
+
|
47
|
+
|
48
|
+
def parse_error_class(error):
    """Map an MQM error description to a coarse class name.

    Recognised categories (checked in this order) are accuracy, fluency,
    locale convention, style, terminology, non-translation and other. For
    categories with known subclasses, the result is ``"<category>-<subclass>"``.
    When several subclasses occur, the last one in the list wins.

    Args:
        error: an error description string, e.g.
            ``'accuracy/mistranslation - "involvement"'``. A dict entry (as
            found in JSON answers; ``parse_broken_json`` builds
            ``{"class": ...}``) is classified by its ``"class"`` value.

    Returns:
        the class name, or ``"unknown"`` if no category matches.
    """
    # Substring tests on a dict would only look at its keys, so every JSON
    # error used to come out as "unknown".
    if isinstance(error, dict):
        error = error.get("class", "")
    # Match case-insensitively; plain-text answers are already lowercased by
    # parse_mqm_answer, but JSON answers and direct callers may not be.
    error = str(error).lower()

    categories = [
        ("accuracy", ["addition", "mistranslation", "omission", "untranslated text"]),
        ("fluency", ["character encoding", "grammar", "inconsistency", "punctuation", "register", "spelling"]),
        ("locale convention", ["currency", "date", "name", "telephone", "time"]),
        ("style", []),
        ("terminology", ["inappropriate", "inconsistent"]),
        ("non-translation", []),
        ("other", []),
    ]
    for category, subclasses in categories:
        if category not in error:
            continue
        class_name = category
        for subclass in subclasses:
            if subclass in error:
                class_name = f"{category}-{subclass}"
        return class_name

    return "unknown"
|
80
|
+
|
81
|
+
|
82
|
+
def parse_mqm_answer(x, list_mqm_errors=False, full_desc=True, normalize=True):
    """Turn a GEMBA-MQM model answer into an MQM-style quality score.

    Two answer formats are accepted:
      * JSON starting with ``{"improved translation"`` whose ``"errors"``
        maps severity ("critical"/"major"/"minor") to a list of errors.
        Malformed JSON is salvaged with ``parse_broken_json``.
      * Plain text (lowercased before parsing) with ``Critical:``/``Major:``/
        ``Minor:`` headings followed by one error per line. Lines mentioning
        "non-translation" always count as critical, and "no-error"/"no error"
        lines are ignored.

    Each of the first five errors (critical first, then major, then minor)
    costs 25/5/1 points; the total penalty is capped at 25.

    Args:
        x: raw model answer; ``None`` is passed through.
        list_mqm_errors: also return the errors grouped by severity.
        full_desc: keep the full error descriptions; otherwise reduce each
            one to a class name with ``parse_error_class``.
        normalize: return ``100 - 4 * penalty`` (0..100, higher is better)
            instead of ``-penalty``.

    Returns:
        ``None`` for ``None`` input, otherwise the score. When
        ``list_mqm_errors`` is true, returns ``(score, errors_by_severity)``.
    """
    if x is None:
        return None

    x = str(x)
    if x.startswith('{"improved translation"'):
        try:
            x = json.loads(x)
        except json.JSONDecodeError:
            x = parse_broken_json(x)
        # A well-formed answer may still omit "errors" or set it to null or a
        # non-object; treat all of those as "no errors" instead of crashing.
        errors = x.get("errors")
        if not isinstance(errors, dict):
            errors = {}
    else:
        x = x.lower()
        errors = {'critical': [], 'major': [], 'minor': []}
        error_level = None
        for line in x.split('\n'):
            line = line.strip()
            if "no-error" in line or "no error" in line or "" == line:
                continue
            if line in ("critical:", "major:", "minor:"):
                # severity heading: the following lines belong to this level
                error_level = line[:-1]
                continue

            if "critical" in line or "major" in line or "minor" in line:
                if not line.startswith(('accuracy', 'fluency', 'locale convention', 'style', 'terminology', 'non-translation', 'other')):
                    # severity word outside a heading and not a known class: report for inspection
                    print(line)

            if error_level is None:
                print(f"No error level for {line}")
                continue

            if "non-translation" in line:
                errors["critical"].append(line)
            else:
                errors[error_level].append(line)

    severity_weights = {'critical': 25, 'major': 5, 'minor': 1}
    error_classes = defaultdict(list)
    final_score = 0
    error_counter = 0
    for error_level in ['critical', 'major', 'minor']:
        # `or []` tolerates a missing or null severity list in JSON answers
        for error in errors.get(error_level) or []:
            # only the first five errors contribute to the penalty
            if error_counter < 5:
                final_score += severity_weights[error_level]
                error_counter += 1

            if full_desc:
                error_classes[error_level].append(error)
            else:
                error_classes[error_level].append(parse_error_class(error))
    final_score = min(final_score, 25)

    # negative score is to normalize that higher score is better
    return_score = (100 - 4 * final_score) if normalize else -final_score
    if list_mqm_errors:
        return return_score, error_classes
    return return_score
|
151
|
+
|
152
|
+
|
153
|
+
def mqm_fewshot(few_shots):
    """Assemble the chat prompt for GEMBA-MQM error annotation.

    The result starts with a system message. It is followed by a
    user/assistant pair for every few-shot example, where the user turn is
    the template formatted with the example and the assistant turn is its
    ``answer``. It ends with the unformatted template as the last user turn,
    for the caller to fill in with the segment under evaluation.

    Args:
        few_shots: iterable of dicts that provide ``source_lang``,
            ``source_seg``, ``target_lang``, ``target_seg`` and ``answer``.

    Returns:
        list of ``{"role": ..., "content": ...}`` message dicts.
    """
    user_template = """{source_lang} source:
```{source_seg}```
{target_lang} translation:
```{target_seg}```

Based on the source segment and machine translation surrounded with triple backticks, identify error types in the translation and classify them. The categories of errors are: accuracy (addition, mistranslation, omission, untranslated text), fluency (character encoding, grammar, inconsistency, punctuation, register, spelling), style (awkward), terminology (inappropriate for context, inconsistent use), non-translation, other, or no-error.\nEach error is classified as one of three categories: critical, major, and minor. Critical errors inhibit comprehension of the text. Major errors disrupt the flow, but what the text is trying to say is still understandable. Minor errors are technically errors, but do not disrupt the flow or hinder comprehension."""

    messages = [{
        "role": "system",
        "content": "You are an annotator for the quality of machine translation. Your task is to identify errors and assess the quality of the translation.",
    }]
    for example in few_shots:
        messages += [
            {"role": "user", "content": user_template.format(**example)},
            {"role": "assistant", "content": example['answer']},
        ]

    # left unformatted: the segment to be annotated is substituted later
    messages.append({"role": "user", "content": user_template})
    return messages
|
186
|
+
|
187
|
+
|
188
|
+
# Few-shot examples for GEMBA-MQM, keyed by language pair. Each entry supplies
# the placeholders of the template in mqm_fewshot() (source_lang, source_seg,
# target_lang, target_seg) plus the expected assistant reply under "answer".
# The answers use the three-level Critical/Major/Minor layout that the
# plain-text branch of parse_mqm_answer() reads back ("no-error" lines are
# skipped there).
few_shots = {
    # English -> German example
    "ende": {
        "source_lang": "English",
        "source_seg": "I do apologise about this, we must gain permission from the account holder to discuss an order with another person, I apologise if this was done previously, however, I would not be able to discuss this with yourself without the account holders permission.",
        "target_lang": "German",
        "target_seg": "Ich entschuldige mich dafür, wir müssen die Erlaubnis einholen, um eine Bestellung mit einer anderen Person zu besprechen. Ich entschuldige mich, falls dies zuvor geschehen wäre, aber ohne die Erlaubnis des Kontoinhabers wäre ich nicht in der Lage, dies mit dir involvement.",
        "answer": """Critical:
no-error
Major:
accuracy/mistranslation - "involvement"
accuracy/omission - "the account holder"
Minor:
fluency/grammar - "wäre"
fluency/register - "dir"
""",
    },
    # English -> Czech example
    "encs": {
        "source_lang": "English",
        "source_seg": "Talks have resumed in Vienna to try to revive the nuclear pact, with both sides trying to gauge the prospects of success after the latest exchanges in the stop-start negotiations.",
        "target_lang": "Czech",
        "target_seg": "Ve Vídni se ve Vídni obnovily rozhovory o oživení jaderného paktu, přičemž obě partaje se snaží posoudit vyhlídky na úspěch po posledních výměnách v jednáních.",
        "answer": """Critical:
no-error
Major:
accuracy/addition - "ve Vídni"
accuracy/omission - "the stop-start"
Minor:
terminology/inappropriate for context - "partaje"
""",
    },
    # Chinese -> English example
    "zhen": {
        "source_lang": "Chinese",
        "source_seg": "大众点评乌鲁木齐家居卖场频道为您提供高铁居然之家地址,电话,营业时间等最新商户信息,找装修公司,就上大众点评",
        "target_lang": "English",
        "target_seg": "Urumqi Home Furnishing Store Channel provides you with the latest business information such as the address, telephone number, business hours, etc., of high-speed rail, and find a decoration company, and go to the reviews.",
        "answer": """Critical:
accuracy/addition - "of high-speed rail"
Major:
accuracy/mistranslation - "go to the reviews"
Minor:
style/awkward - "etc.,"
""",
    },
}

# Chat prompt for GEMBA-MQM: a system message, the three example exchanges
# above, then the unfilled template as the final user turn (its placeholders
# are filled per segment, e.g. with apply_template()).
TEMPLATE_GEMBA_MQM = mqm_fewshot([few_shots['ende'], few_shots['encs'], few_shots['zhen']])
|
234
|
+
|
gemba/gpt_api.py
ADDED
@@ -0,0 +1,174 @@
|
|
1
|
+
import os
|
2
|
+
import sys
|
3
|
+
import time
|
4
|
+
import ipdb
|
5
|
+
import logging
|
6
|
+
from termcolor import colored
|
7
|
+
from datetime import datetime
|
8
|
+
import openai
|
9
|
+
import tqdm
|
10
|
+
|
11
|
+
|
12
|
+
# class for calling OpenAI API and handling cache
class GptApi:
    """Client for the OpenAI / Azure OpenAI chat-completions API.

    On top of the raw API call it provides:
      * optional caching of raw answers (``cache`` must support ``in``, item
        get and item set with a dict key, e.g. a ``diskcache.Cache``);
      * answer parsing through a caller-supplied ``parse_response``;
      * retries at a higher temperature when no answer could be parsed.
    """

    def __init__(self, verbose=False):
        """Create the API client from environment variables.

        Azure is used when ``OPENAI_AZURE_ENDPOINT`` is set, in which case
        ``OPENAI_AZURE_KEY`` is also required. Otherwise ``OPENAI_API_KEY``
        selects the public OpenAI API.

        Args:
            verbose: print parsed answers and truncation notices to stderr.

        Raises:
            AssertionError: Azure endpoint is set but ``OPENAI_AZURE_KEY`` is missing.
            Exception: no credentials were found in the environment.
        """
        self.verbose = verbose

        if "OPENAI_AZURE_ENDPOINT" in os.environ:
            assert "OPENAI_AZURE_KEY" in os.environ, "OPENAI_AZURE_KEY not found in environment"

            # Azure API access
            self.client = openai.AzureOpenAI(
                api_key=os.environ["OPENAI_AZURE_KEY"],
                azure_endpoint=os.environ["OPENAI_AZURE_ENDPOINT"],
                api_version="2023-07-01-preview"
            )
        elif "OPENAI_API_KEY" in os.environ:
            # OpenAI API access
            self.client = openai.OpenAI(
                api_key=os.environ["OPENAI_API_KEY"]
            )
        else:
            raise Exception("OPENAI_API_KEY or OPENAI_AZURE_KEY not found in environment")

        # NOTE: this raises the level of the *root* logger, i.e. for the whole process.
        logging.getLogger().setLevel(logging.CRITICAL)  # in order to suppress all these HTTP INFO log messages

    # answer_id is used for determining if it was the top answer or how deep in the list it was
    def request(self, prompt, model, parse_response, temperature=0, answer_id=-1, cache=None, max_tokens=None):
        """Query ``model`` and return the answers that ``parse_response`` accepts.

        Args:
            prompt: a string (sent as one user message) or a list of
                ``{"role", "content"}`` message dicts.
            model: model / deployment name passed to the API.
            parse_response: callable mapping the raw answer text to a parsed
                value, to ``None`` when unparseable, or to a ``(value, errors)`` tuple.
            temperature: integer temperature step; the API receives ``temperature / 10``.
            answer_id: id given to the previous answer. Every raw answer
                advances it by one, including unparseable ones.
            cache: optional cache of raw API answers keyed by
                (model, temperature, prompt); ``None`` disables caching.
            max_tokens: optional completion-length limit.

        Returns:
            list of dicts with keys ``temperature``, ``answer_id``, ``answer``,
            ``errors``, ``prompt``, ``finish_reason`` and ``model``. When no
            parseable answer is found, the request is retried with
            ``temperature + 1``. Once the API yields nothing (content filter,
            truncation without ``max_tokens``, or ``temperature > 10``), a
            single dict with ``answer`` set to ``None`` is returned.
        """
        request = {"model": model, "temperature": temperature, "prompt": prompt}

        answers = cache[request] if cache is not None and request in cache else None
        # a missing or empty cached result counts as a miss and is re-requested
        if not answers:
            answers = self.request_api(prompt, model, temperature, max_tokens)
            if cache is not None:
                cache[request] = answers

        # there is no valid answer
        if len(answers) == 0:
            return [{
                "temperature": temperature,
                "answer_id": answer_id,
                "answer": None,
                "errors": None,  # same keys as a successful answer
                "prompt": prompt,
                "finish_reason": None,
                "model": model,
            }]

        parsed_answers = []
        for full_answer in answers:
            finish_reason = full_answer["finish_reason"]
            full_answer = full_answer["answer"]
            answer_id += 1
            answer = parse_response(full_answer)
            if isinstance(answer, tuple):
                answer, errors = answer
            else:
                errors = None
            if self.verbose or temperature > 0:
                print(f"Answer (t={temperature}): " + colored(answer, "yellow") + " (" + colored(full_answer, "blue") + ")", file=sys.stderr)
            if answer is None:
                continue
            parsed_answers.append(
                {
                    "temperature": temperature,
                    "answer_id": answer_id,
                    "answer": answer,
                    "errors": errors,
                    "prompt": prompt,
                    "finish_reason": finish_reason,
                    "model": model,
                }
            )

        # there was no valid answer, increase temperature and try again
        # (max_tokens is forwarded; it used to be dropped on retry)
        if len(parsed_answers) == 0:
            return self.request(prompt, model, parse_response, temperature=temperature + 1, answer_id=answer_id, cache=cache, max_tokens=max_tokens)

        return parsed_answers

    def request_api(self, prompt, model, temperature=0, max_tokens=None):
        """Call the API, retrying on errors, and return the raw answers.

        A truncated answer (finish reason other than "stop") triggers a new
        request with 200 more tokens when ``max_tokens`` is set. Any other
        exception is retried indefinitely, one second apart.

        Returns:
            list of ``{"answer": str, "finish_reason": str}`` dicts with
            duplicates removed (first occurrence kept). Returns ``[]`` when
            ``temperature > 10``, when the response was content-filtered or
            flagged as invalid model output, when a choice has no content,
            or when an answer was truncated and ``max_tokens`` is ``None``.
        """
        if temperature > 10:
            return []

        while True:
            try:
                response = self.call_api(prompt, model, temperature, max_tokens)
                break
            except Exception as e:
                # response was filtered
                if hasattr(e, 'code'):
                    if e.code == 'content_filter':
                        return []
                    print(e.code, file=sys.stderr)
                # `e.error` is not guaranteed to be a dict; indexing it blindly
                # could raise from inside this handler and abort the retry loop.
                error_body = getattr(e, 'error', None)
                if isinstance(error_body, dict) and error_body.get('code') == 'invalid_model_output':
                    return []

                # frequent error is reaching the API limit
                print(colored("Error, retrying...", "red"), file=sys.stderr)
                print(e, file=sys.stderr)
                time.sleep(1)

        answers = []
        for choice in response.choices:
            # chat completions carry the text in `message.content`; fall back to `text`
            message = getattr(choice, "message", None)
            content = message.content if message is not None else getattr(choice, "text", None)
            if content is None:
                return []
            answer = content.strip()

            # one of the responses didn't finish, we need to request more tokens
            if choice.finish_reason != "stop":
                if self.verbose:
                    print(colored("Increasing max tokens to fit answers.", "red") + colored(answer, "blue"), file=sys.stderr)
                print(f"Finish reason: {choice.finish_reason}", file=sys.stderr)
                if max_tokens is None:
                    return []
                return self.request_api(prompt, model, temperature=temperature, max_tokens=max_tokens + 200)

            answers.append({
                "answer": answer,
                "finish_reason": choice.finish_reason,
            })

        if len(answers) > 1:
            # remove duplicate answers; dict.fromkeys keeps first-seen order so
            # the result is deterministic (iterating a set of str tuples is not)
            answers = [dict(t) for t in dict.fromkeys(tuple(d.items()) for d in answers)]

        return answers

    def call_api(self, prompt, model, temperature, max_tokens):
        """Send one chat-completions request and return the raw response.

        A string ``prompt`` is wrapped as a single user message; a list must
        already contain ``{"role", "content"}`` dicts. ``temperature`` is an
        integer step that is divided by 10 before it is sent.
        """
        parameters = {
            "temperature": temperature / 10,
            "top_p": 1,
            "n": 1,
            "frequency_penalty": 0,
            "presence_penalty": 0,
            "stop": None,
            "model": model
        }

        if max_tokens is not None:
            parameters["max_tokens"] = max_tokens

        if isinstance(prompt, list):
            # check that prompt contain list of dictionaries with role and content
            assert all(isinstance(p, dict) for p in prompt), "Prompts must be a list of dictionaries."
            assert all("role" in p and "content" in p for p in prompt), "Prompts must be a list of dictionaries with role and content."

            parameters["messages"] = prompt
        else:
            parameters["messages"] = [{
                "role": "user",
                "content": prompt,
            }]

        return self.client.chat.completions.create(**parameters)

    def bulk_request(self, df, model, parse_mqm_answer, cache, max_tokens=None):
        """Run ``request`` for every row of ``df`` and concatenate the answers.

        ``df`` is iterated with ``iterrows()`` and must have a ``"prompt"``
        column (presumably a pandas DataFrame). ``parse_mqm_answer`` is used as
        the ``parse_response`` callback. Returns one flat list of answer dicts.
        """
        answers = []
        for _, row in tqdm.tqdm(df.iterrows(), total=len(df), file=sys.stderr):
            prompt = row["prompt"]
            parsed_answers = self.request(prompt, model, parse_mqm_answer, cache=cache, max_tokens=max_tokens)
            answers += parsed_answers
        return answers
|
gemba/mtme_tools.py
ADDED
@@ -0,0 +1,99 @@
|
|
1
|
+
from mt_metrics_eval import data
|
2
|
+
import scipy
|
3
|
+
|
4
|
+
######
|
5
|
+
# Functions in this script are copied from mt-metrics-eval/wmt22_metrics.ipynb
|
6
|
+
######
|
7
|
+
|
8
|
+
|
9
|
+
def eval_metrics(eval_sets, langs, levels, primary_only, k, gold_name='std',
                 include_domains=True, seg_level_no_avg=False,
                 include_human_with_acc=False):
    """Evaluate all metrics for eval sets, across multiple task settings.

    Args:
      eval_sets: Map from lang-pair to eval_set objects.
      langs: List of language pairs (eg 'en-de') for which to compute results.
      levels: List of levels for which to compute results, allowed elements are
        'sys' and 'seg'.
      primary_only: Include only primary metrics.
      k: Number of bootstrap draws. If 0, no significance tests for metric-score
        differences are run, and execution is much faster.
      gold_name: Name of gold scores to use, standard scores if 'std'.
      include_domains: Generate domain-specific results in addition to global
        results.
      seg_level_no_avg: If True, use only the average_by=None setting for segment-
        level correlations
      include_human_with_acc: If True, include human outputs in accuracy tasks.

    Returns:
      Map from task names to metric -> (rank, corr, sig_string) stats.
    """
    # Import the submodule explicitly. A bare `import scipy` only exposes
    # `scipy.stats` on SciPy releases with lazy submodule loading, or when some
    # other module happens to have imported it first.
    import scipy.stats

    corr_fcns = {'pearson': scipy.stats.pearsonr,
                 'kendall': scipy.stats.kendalltau}
    results = {}

    # First task is global accuracy pooled over all given language pairs; it is
    # computed whenever at least one language pair is given.
    if langs:
        evs_list = [eval_sets[lp] for lp in langs]
        main_refs = [{evs.std_ref} for evs in evs_list]
        close_refs = [set() for _ in evs_list]
        gold = evs_list[0].StdHumanScoreName('sys') if gold_name == 'std' else gold_name
        humans = [True, False] if include_human_with_acc else [False]
        for human in humans:
            taskname = data.MakeTaskName(
                'wmt22', langs, None, 'sys', human, 'none', 'accuracy', k, gold,
                main_refs, close_refs, False, primary_only)
            print(taskname)
            res = data.CompareMetricsWithGlobalAccuracy(
                evs_list, main_refs, close_refs, include_human=human,
                include_outliers=False, gold_name=gold,
                primary_metrics=primary_only,
                domain=None, k=k, pval=0.05)
            results[taskname] = reformat(res)

    # Remaining tasks are specific to language, domain, etc.
    for lp in langs:
        evs = eval_sets[lp]
        main_refs = {evs.std_ref}
        close_refs = set()
        domains = [None] + (list(evs.domain_names) if include_domains else [])
        for domain in domains:
            for level in levels:
                gold = evs.StdHumanScoreName(level) if gold_name == 'std' else gold_name
                for avg in ('none', 'sys', 'item'):
                    # system level (and seg level with seg_level_no_avg) is only evaluated without averaging
                    if (level == 'sys' or seg_level_no_avg) and avg != 'none':
                        continue
                    for human in (True, False):
                        if human and len(evs.ref_names) == 1:
                            continue  # Single ref
                        for corr, corr_fcn in corr_fcns.items():
                            taskname = data.MakeTaskName(
                                'wmt22', lp, domain, level, human, avg, corr, k, gold,
                                main_refs, close_refs, False, primary=primary_only)
                            print(taskname)
                            # NOTE(review): gold_name (possibly 'std') is passed here
                            # rather than the resolved `gold`, as in the upstream
                            # notebook -- confirm GetCorrelations resolves 'std'.
                            corrs = data.GetCorrelations(
                                evs=evs, level=level, main_refs={evs.std_ref},
                                close_refs=close_refs, include_human=human,
                                include_outliers=False, gold_name=gold_name,
                                primary_metrics=primary_only, domain=domain)
                            metrics, sig_matrix = data.CompareMetrics(
                                corrs, corr_fcn, average_by=avg, k=k, pval=0.05)
                            # Make compatible with accuracy results.
                            metrics = {evs.DisplayName(m): v for m, v in metrics.items()}
                            results[taskname] = reformat((metrics, sig_matrix))

    return results
|
89
|
+
|
90
|
+
|
91
|
+
def reformat(results):
    """Reformat CompareMetrics() results to match mtme's format.

    Args:
        results: ``(metrics, sig_matrix)``. ``metrics`` maps a metric name to
            ``(corr, rank)`` in iteration order. Row ``i`` of ``sig_matrix``
            holds values compared against 0.05 for the i-th metric, i.e.
            treated as p-values.

    Returns:
        dict mapping metric name to ``(rank, corr, sig_string)``. In
        ``sig_string``, the first ``i + 1`` cells of row ``i`` are ``'x'`` and
        the remaining cells are ``'1'`` (p < 0.05) or ``'0'``, space-separated.
    """
    metrics, sig_matrix = results
    formatted = {}
    for row_idx, (name, (corr, rank)) in enumerate(metrics.items()):
        verdicts = ['1' if p < 0.05 else '0' for p in sig_matrix[row_idx]]
        # mask the diagonal and everything before it
        cells = ['x'] * (row_idx + 1) + verdicts[row_idx + 1:]
        formatted[name] = (rank, corr, ' '.join(cells))
    return formatted
|