gemba-0.1.0-py3-none-any.whl

gemba/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from .utils import get_gemba_scores
+
+ __all__ = ['get_gemba_scores']
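
The wrapper exposed here, get_gemba_scores, lives in gemba/utils.py, which is not part of this listing, so the call below is only a sketch of the intended entry point; the argument names, the method value, and the model name are assumptions rather than something confirmed by these files.

# Hypothetical usage sketch; gemba/utils.py is not shown above, so the exact
# signature of get_gemba_scores is assumed here.
from gemba import get_gemba_scores

source = ["How are you?", "I would like a cup of coffee."]
hypothesis = ["Wie geht es dir?", "Ich hätte gerne eine Tasse Kaffee."]

# "GEMBA-DA" appears as a prompt name in gemba_da.py; the model name is illustrative.
answers = get_gemba_scores(source, hypothesis, "en", "de", method="GEMBA-DA", model="gpt-4")
print(answers)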
gemba/gemba_da.py ADDED
@@ -0,0 +1,62 @@
+ import diskcache as dc
+ from gemba.prompt import prompts, language_codes
+ from gemba.gpt_api import GptApi
+ from gemba.testset import Testset
+ from gemba.scores import Scores
+
+
+ def main():
+     scenarios = [
+         ["text-davinci-003", "GEMBA-DA", [["wmt22", "en-de"], ["wmt22", "zh-en"], ["wmt22", "en-ru"]], ],
+         ["text-davinci-003", "GEMBA-DA_ref", [["wmt22", "en-de"], ["wmt22", "zh-en"], ["wmt22", "en-ru"]], ],
+     ]
+
+     gptapi = GptApi()
+     for scenario in scenarios:
+         use_model = scenario[0]
+         annotation = scenario[1]
+         cache = dc.Cache(f'cache/{use_model}_{annotation}', expire=None, size_limit=int(10e10), cull_limit=0, eviction_policy='none')
+
+         scoring_name = f"{annotation}_{use_model}"
+
+         if use_model not in credentials["deployments"].keys():
+             print(f"Model {use_model} not supported by credentials")
+             continue
+
+         for dataset, lp in scenario[2]:
+             testset = Testset("mt-metrics-eval-v2", dataset, lp)
+             if prompts[annotation]["use_ref"]:
+                 refname = testset.main_ref
+             else:
+                 refname = None
+
+             scores = Scores(scoring_name, testset, refname)
+
+             # starts with -1 as it is incremented before the first request
+             hypothesis_index = -1
+             total = testset.segments_count()
+             for src, hyp, ref, system in testset.iterate_over_all(refname):
+                 hypothesis_index += 1
+
+                 if scores.get_score(system, hypothesis_index) != 'None':
+                     continue
+
+                 print(f"Processing hypothesis {hypothesis_index}/{total} for {scoring_name} on {dataset}/{lp}")
+
+                 data = {
+                     "source_seg": src,
+                     "target_seg": hyp,
+                     "reference_seg": ref,
+                     "source_lang": language_codes[lp.split("-")[0]],
+                     "target_lang": language_codes[lp.split("-")[1]],
+                 }
+                 prompt = prompts[annotation]["prompt"].format(**data)
+                 parsed_answers = gptapi.request(prompt, use_model, prompts[annotation]["validate_answer"], cache=cache)
+
+                 scores.assign_score(system, hypothesis_index, parsed_answers[0]['answer'], parsed_answers[0]['temperature'])
+
+             scores.save()
+
+
+ if __name__ == '__main__':
+     main()
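
Two caveats about gemba_da.py as listed: credentials (a dict with a "deployments" mapping of usable models) is referenced but never imported, so it must be supplied by the surrounding project, and gemba/prompt.py, gemba/testset.py, and gemba/scores.py are not included in this listing. With those assumptions, the inner loop reduces to the following single-segment sketch (values are illustrative, and OPENAI_API_KEY or the Azure variables must be set for GptApi):

import diskcache as dc
from gemba.gpt_api import GptApi
from gemba.prompt import prompts, language_codes  # module not shown in this listing

gptapi = GptApi()
cache = dc.Cache("cache/text-davinci-003_GEMBA-DA")

data = {
    "source_seg": "How are you?",
    "target_seg": "Wie geht es dir?",
    "reference_seg": "Wie geht es Ihnen?",
    "source_lang": language_codes["en"],
    "target_lang": language_codes["de"],
}
prompt = prompts["GEMBA-DA"]["prompt"].format(**data)
parsed = gptapi.request(prompt, "text-davinci-003", prompts["GEMBA-DA"]["validate_answer"], cache=cache)
print(parsed[0]["answer"], parsed[0]["temperature"])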
gemba/gemba_esa.py ADDED
@@ -0,0 +1,84 @@
+ import ipdb
+ import json
+ import re
+ from collections import defaultdict
+
+
+ def esa_fewshot(few_shots):
+     prompts = [
+         {
+             "role": "system",
+             "content": f"Your task is to identify machine translation errors and assess the quality of the translation."
+         }
+     ]
+
+     template = """{source_lang} source:
+ ```{source_seg}```
+ {target_lang} translation:
+ ```{target_seg}```
+
+ Based on the source segment and machine translation surrounded with triple backticks, identify error types in the translation and classify them. The categories of errors are: accuracy (addition, mistranslation, omission, untranslated text), fluency (character encoding, grammar, inconsistency, punctuation, register, spelling), style (awkward), terminology (inappropriate for context, inconsistent use), non-translation, other, or no-error.\nEach error is classified as one of two categories: major or minor. Major errors disrupt the flow and make the understandability of text difficult or impossible. Minor errors are errors that do not disrupt the flow significantly and what the text is trying to say is still understandable."""
+
+     for shot in few_shots:
+         prompts.append({
+             "role": "user",
+             "content": template.format(**shot)
+         })
+         answer = shot['answer']
+
+         prompts.append({
+             "role": "assistant",
+             "content": answer
+         })
+
+     prompts.append({
+         "role": "user",
+         "content": template
+     })
+
+     return prompts
+
+
+ esa_few_shots = {
+     "ende": {
+         "source_lang": "English",
+         "source_seg": "I do apologise about this, we must gain permission from the account holder to discuss an order with another person, I apologise if this was done previously, however, I would not be able to discuss this with yourself without the account holders permission.",
+         "target_lang": "German",
+         "target_seg": "Ich entschuldige mich dafür, wir müssen die Erlaubnis einholen, um eine Bestellung mit einer anderen Person zu besprechen. Ich entschuldige mich, falls dies zuvor geschehen wäre, aber ohne die Erlaubnis des Kontoinhabers wäre ich nicht in der Lage, dies mit dir involvement.",
+         "answer": """Major:
+ accuracy/mistranslation - "involvement"
+ accuracy/omission - "the account holder"
+ Minor:
+ fluency/grammar - "wäre"
+ fluency/register - "dir"
+ """,
+     },
+     "encs": {
+         "source_lang": "English",
+         "source_seg": "Talks have resumed in Vienna to try to revive the nuclear pact, with both sides trying to gauge the prospects of success after the latest exchanges in the stop-start negotiations.",
+         "target_lang": "Czech",
+         "target_seg": "Ve Vídni se ve Vídni obnovily rozhovory o oživení jaderného paktu, přičemž obě partaje se snaží posoudit vyhlídky na úspěch po posledních výměnách v jednáních.",
+         "answer": """Major:
+ accuracy/addition - "ve Vídni"
+ accuracy/omission - "the stop-start"
+ Minor:
+ terminology/inappropriate for context - "partaje"
+ """,
+     },
+     "zhen": {
+         "source_lang": "Chinese",
+         "source_seg": "大众点评乌鲁木齐家居卖场频道为您提供高铁居然之家地址,电话,营业时间等最新商户信息,找装修公司,就上大众点评",
+         "target_lang": "English",
+         "target_seg": "Urumqi Home Furnishing Store Channel provides you with the latest business information such as the address, telephone number, business hours, etc., of high-speed rail, and find a decoration company, and go to the reviews.",
+         "answer": """Major:
+ accuracy/addition - "of high-speed rail"
+ accuracy/mistranslation - "go to the reviews"
+ Minor:
+ style/awkward - "etc.,"
+ """,
+     },
+ }
+
+ TEMPLATE_GEMBA_ESA_ERROR_SPANS = esa_fewshot([esa_few_shots['ende'], esa_few_shots['encs'], esa_few_shots['zhen']])
+
+ TEMPLATE_GEMBA_ESA_RANKING = 'Given the translation from {source_lang} to {target_lang} and the annotated error spans, assign a score on a continuous scale from 0 to 100. The scale has following reference points: 0="No meaning preserved", 33="Some meaning preserved", 66="Most meaning preserved and few grammar mistakes", up to 100="Perfect meaning and grammar".\n\nScore the following translation from {source_lang} source:\n```{source_seg}```\n{target_lang} translation:\n```{target_seg}```\nAnnotated error spans:\n```{error_spans}```\nScore (0-100): '
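
GEMBA-ESA is a two-stage prompt: TEMPLATE_GEMBA_ESA_ERROR_SPANS is a chat-style few-shot conversation whose final user turn still carries the {source_lang}, {source_seg}, {target_lang}, and {target_seg} placeholders, and TEMPLATE_GEMBA_ESA_RANKING then converts the annotated error spans into a 0-100 score. A minimal sketch of wiring the two stages together, mirroring what apply_template in the next file does; the error_spans value would normally be the model's answer to the first prompt and is a made-up placeholder here:

data = {
    "source_lang": "English",
    "source_seg": "Talks have resumed in Vienna to try to revive the nuclear pact.",
    "target_lang": "Czech",
    "target_seg": "Ve Vídni se obnovily rozhovory o oživení jaderného paktu.",
}

# stage 1: fill the placeholders of every turn; the few-shot turns contain no
# braces, so only the final user turn actually changes
error_span_messages = [
    {"role": m["role"], "content": m["content"].format(**data)}
    for m in TEMPLATE_GEMBA_ESA_ERROR_SPANS
]

# stage 2: score the translation given the error spans returned in stage 1
example_spans = 'Minor:\nterminology/inappropriate for context - "partaje"'
ranking_prompt = TEMPLATE_GEMBA_ESA_RANKING.format(**data, error_spans=example_spans)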
@@ -0,0 +1,234 @@
+ import ipdb
+ import json
+ import re
+ from collections import defaultdict
+
+ def apply_template(template, data):
+     if isinstance(template, str):
+         return template.format(**data)
+     elif isinstance(template, list):
+         prompt = []
+         for conversation_turn in template:
+             p = conversation_turn.copy()
+             p['content'] = p['content'].format(**data)
+             prompt.append(p)
+         return prompt
+     else:
+         raise ValueError(f"Unknown template type {type(template)}")
+
+ def parse_broken_json(x):
+     improved_translation = ""
+     errors = defaultdict(list)
+     if '"errors": ' in x and "improved translation" in x:
+         data = x.split('", "errors": ')
+         if len(data) != 2:
+             return {"improved translation": improved_translation, "errors": errors}
+         # from data[0] parse improved translation
+         improved_translation = data[0].split('"improved translation": "')[1]
+         # remove last character from data[1]
+         data[1] = data[1][:-1]
+
+         try:
+             errors = json.loads(data[1])
+         except:
+             # just try to get error count
+             words = re.findall(r'\b\w+\b', data[1].lower())
+             keywords = ['critical', 'major', 'minor']
+
+             last_key = None
+             for word in words:
+                 if word in keywords:
+                     last_key = word
+                 elif last_key is not None and word == "class":
+                     errors[last_key].append({"class": "other"})
+
+     return {"improved translation": improved_translation, "errors": errors}
+
+
+ def parse_error_class(error):
+     # parse error from error description, errors are ['accuracy', 'fluency', 'locale convention', 'style', 'terminology', 'non-translation', 'other']
+     # locale convention (currency, date, name, telephone, or time format), style (awkward), terminology (inappropriate for context, inconsistent use),
+     class_name = "unknown"
+     if "accuracy" in error:
+         class_name = "accuracy"
+         for subclass in ["addition", "mistranslation", "omission", "untranslated text"]:
+             if subclass in error:
+                 class_name = f"accuracy-{subclass}"
+     elif "fluency" in error:
+         class_name = "fluency"
+         for subclass in ["character encoding", "grammar", "inconsistency", "punctuation", "register", "spelling"]:
+             if subclass in error:
+                 class_name = f"fluency-{subclass}"
+     elif "locale convention" in error:
+         class_name = "locale convention"
+         for subclass in ["currency", "date", "name", "telephone", "time"]:
+             if subclass in error:
+                 class_name = f"locale convention-{subclass}"
+     elif "style" in error:
+         class_name = "style"
+     elif "terminology" in error:
+         class_name = "terminology"
+         for subclass in ["inappropriate", "inconsistent"]:
+             if subclass in error:
+                 class_name = f"terminology-{subclass}"
+     elif "non-translation" in error:
+         class_name = "non-translation"
+     elif "other" in error:
+         class_name = "other"
+
+     return class_name
+
+
+ def parse_mqm_answer(x, list_mqm_errors=False, full_desc=True, normalize=True):
+     if x is None:
+         return None
+
+     x = str(x)
+     if x.startswith('{"improved translation"'):
+         try:
+             x = json.loads(x)
+         except:
+             x = parse_broken_json(x)
+         errors = x["errors"]
+
+
+     else:
+         x = x.lower()
+         errors = {'critical': [], 'major': [], 'minor': []}
+         error_level = None
+         for line in x.split('\n'):
+             line = line.strip()
+             if "no-error" in line or "no error" in line or "" == line:
+                 continue
+             if "critical:" == line:
+                 error_level = "critical"
+                 continue
+             elif "major:" == line:
+                 error_level = "major"
+                 continue
+             elif "minor:" == line:
+                 error_level = "minor"
+                 continue
+
+             if "critical" in line or "major" in line or "minor" in line:
+                 if not any([line.startswith(x) for x in ['accuracy', 'fluency', 'locale convention', 'style', 'terminology', 'non-translation', 'other']]):
+                     print(line)
+
+             if error_level is None:
+                 print(f"No error level for {line}")
+                 continue
+
+             if "non-translation" in line:
+                 errors["critical"].append(line)
+             else:
+                 errors[error_level].append(line)
+
+     error_classes = defaultdict(list)
+     final_score = 0
+     error_counter = 0
+     for error_level in ['critical', 'major', 'minor']:
+         if error_level not in errors:
+             continue
+         for error in errors[error_level]:
+             if error_counter < 5:
+                 final_score += 25 if error_level == 'critical' else 5 if error_level == 'major' else 1
+                 error_counter += 1
+
+             if full_desc:
+                 error_classes[error_level].append(error)
+             else:
+                 class_name = parse_error_class(error)
+                 error_classes[error_level].append(class_name)
+     if final_score > 25:
+         final_score = 25
+
+     # negative score is to normalize that higher score is better
+     return_score = (-final_score * 4 + 100) if normalize else -final_score
+     if list_mqm_errors:
+         return return_score, error_classes
+     else:
+         return return_score
+
+
+ def mqm_fewshot(few_shots):
+     prompts = [
+         {
+             "role": "system",
+             "content": f"You are an annotator for the quality of machine translation. Your task is to identify errors and assess the quality of the translation."
+         }
+     ]
+
+     template = """{source_lang} source:
+ ```{source_seg}```
+ {target_lang} translation:
+ ```{target_seg}```
+
+ Based on the source segment and machine translation surrounded with triple backticks, identify error types in the translation and classify them. The categories of errors are: accuracy (addition, mistranslation, omission, untranslated text), fluency (character encoding, grammar, inconsistency, punctuation, register, spelling), style (awkward), terminology (inappropriate for context, inconsistent use), non-translation, other, or no-error.\nEach error is classified as one of three categories: critical, major, and minor. Critical errors inhibit comprehension of the text. Major errors disrupt the flow, but what the text is trying to say is still understandable. Minor errors are technically errors, but do not disrupt the flow or hinder comprehension."""
+
+     for shot in few_shots:
+         prompts.append({
+             "role": "user",
+             "content": template.format(**shot)
+         })
+         answer = shot['answer']
+
+         prompts.append({
+             "role": "assistant",
+             "content": answer
+         })
+
+     prompts.append({
+         "role": "user",
+         "content": template
+     })
+
+     return prompts
+
+
+ few_shots = {
+     "ende": {
+         "source_lang": "English",
+         "source_seg": "I do apologise about this, we must gain permission from the account holder to discuss an order with another person, I apologise if this was done previously, however, I would not be able to discuss this with yourself without the account holders permission.",
+         "target_lang": "German",
+         "target_seg": "Ich entschuldige mich dafür, wir müssen die Erlaubnis einholen, um eine Bestellung mit einer anderen Person zu besprechen. Ich entschuldige mich, falls dies zuvor geschehen wäre, aber ohne die Erlaubnis des Kontoinhabers wäre ich nicht in der Lage, dies mit dir involvement.",
+         "answer": """Critical:
+ no-error
+ Major:
+ accuracy/mistranslation - "involvement"
+ accuracy/omission - "the account holder"
+ Minor:
+ fluency/grammar - "wäre"
+ fluency/register - "dir"
+ """,
+     },
+     "encs": {
+         "source_lang": "English",
+         "source_seg": "Talks have resumed in Vienna to try to revive the nuclear pact, with both sides trying to gauge the prospects of success after the latest exchanges in the stop-start negotiations.",
+         "target_lang": "Czech",
+         "target_seg": "Ve Vídni se ve Vídni obnovily rozhovory o oživení jaderného paktu, přičemž obě partaje se snaží posoudit vyhlídky na úspěch po posledních výměnách v jednáních.",
+         "answer": """Critical:
+ no-error
+ Major:
+ accuracy/addition - "ve Vídni"
+ accuracy/omission - "the stop-start"
+ Minor:
+ terminology/inappropriate for context - "partaje"
+ """,
+     },
+     "zhen": {
+         "source_lang": "Chinese",
+         "source_seg": "大众点评乌鲁木齐家居卖场频道为您提供高铁居然之家地址,电话,营业时间等最新商户信息,找装修公司,就上大众点评",
+         "target_lang": "English",
+         "target_seg": "Urumqi Home Furnishing Store Channel provides you with the latest business information such as the address, telephone number, business hours, etc., of high-speed rail, and find a decoration company, and go to the reviews.",
+         "answer": """Critical:
+ accuracy/addition - "of high-speed rail"
+ Major:
+ accuracy/mistranslation - "go to the reviews"
+ Minor:
+ style/awkward - "etc.,"
+ """,
+     },
+ }
+
+ TEMPLATE_GEMBA_MQM = mqm_fewshot([few_shots['ende'], few_shots['encs'], few_shots['zhen']])
+
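
The scoring implemented in parse_mqm_answer above weights critical errors at 25, major at 5, and minor at 1, counts at most the first five errors, caps the penalty at 25, and with normalize=True maps the result onto a 0-100 scale via -penalty * 4 + 100. A small worked example against that code:

answer = """Critical:
no-error
Major:
accuracy/mistranslation - "involvement"
accuracy/omission - "the account holder"
Minor:
fluency/grammar - "wäre"
"""

# two major errors and one minor error: penalty 5 + 5 + 1 = 11
# normalized score: -11 * 4 + 100 = 56
print(parse_mqm_answer(answer))                        # 56
print(parse_mqm_answer(answer, normalize=False))       # -11
print(parse_mqm_answer(answer, list_mqm_errors=True))  # (56, defaultdict with the major/minor lines)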
gemba/gpt_api.py ADDED
@@ -0,0 +1,174 @@
+ import os
+ import sys
+ import time
+ import ipdb
+ import logging
+ from termcolor import colored
+ from datetime import datetime
+ import openai
+ import tqdm
+
+
+ # class for calling OpenAI API and handling cache
+ class GptApi:
+     def __init__(self, verbose=False):
+         self.verbose = verbose
+
+         if "OPENAI_AZURE_ENDPOINT" in os.environ:
+             assert "OPENAI_AZURE_KEY" in os.environ, "OPENAI_AZURE_KEY not found in environment"
+
+             # Azure API access
+             self.client = openai.AzureOpenAI(
+                 api_key=os.environ["OPENAI_AZURE_KEY"],
+                 azure_endpoint=os.environ["OPENAI_AZURE_ENDPOINT"],
+                 api_version="2023-07-01-preview"
+             )
+         elif "OPENAI_API_KEY" in os.environ:
+             # OpenAI API access
+             self.client = openai.OpenAI(
+                 api_key=os.environ["OPENAI_API_KEY"]
+             )
+         else:
+             raise Exception("OPENAI_API_KEY or OPENAI_AZURE_KEY not found in environment")
+
+         logging.getLogger().setLevel(logging.CRITICAL)  # in order to suppress all these HTTP INFO log messages
+
+     # answer_id is used for determining if it was the top answer or how deep in the list it was
+     def request(self, prompt, model, parse_response, temperature=0, answer_id=-1, cache=None, max_tokens=None):
+         request = {"model": model, "temperature": temperature, "prompt": prompt}
+
+         if request in cache and cache[request] is not None and len(cache[request]) > 0:
+             answers = cache[request]
+         else:
+             answers = self.request_api(prompt, model, temperature, max_tokens)
+             cache[request] = answers
+
+         # there is no valid answer
+         if len(answers) == 0:
+             return [{
+                 "temperature": temperature,
+                 "answer_id": answer_id,
+                 "answer": None,
+                 "prompt": prompt,
+                 "finish_reason": None,
+                 "model": model,
+             }]
+
+         parsed_answers = []
+         for full_answer in answers:
+             finish_reason = full_answer["finish_reason"]
+             full_answer = full_answer["answer"]
+             answer_id += 1
+             answer = parse_response(full_answer)
+             if isinstance(answer, tuple):
+                 answer, errors = answer
+             else:
+                 errors = None
+             if self.verbose or temperature > 0:
+                 print(f"Answer (t={temperature}): " + colored(answer, "yellow") + " (" + colored(full_answer, "blue") + ")", file=sys.stderr)
+             if answer is None:
+                 continue
+             parsed_answers.append(
+                 {
+                     "temperature": temperature,
+                     "answer_id": answer_id,
+                     "answer": answer,
+                     "errors": errors,
+                     "prompt": prompt,
+                     "finish_reason": finish_reason,
+                     "model": model,
+                 }
+             )
+
+         # there was no valid answer, increase temperature and try again
+         if len(parsed_answers) == 0:
+             return self.request(prompt, model, parse_response, temperature=temperature + 1, answer_id=answer_id, cache=cache)
+
+         return parsed_answers
+
+     def request_api(self, prompt, model, temperature=0, max_tokens=None):
+         if temperature > 10:
+             return []
+
+         while True:
+             try:
+                 response = self.call_api(prompt, model, temperature, max_tokens)
+                 break
+             except Exception as e:
+                 # response was filtered
+                 if hasattr(e, 'code'):
+                     if e.code == 'content_filter':
+                         return []
+                     print(e.code, file=sys.stderr)
+                 if hasattr(e, 'error') and e.error['code'] == 'invalid_model_output':
+                     return []
+
+                 # frequent error is reaching the API limit
+                 print(colored("Error, retrying...", "red"), file=sys.stderr)
+                 print(e, file=sys.stderr)
+                 time.sleep(1)
+
+         answers = []
+         for choice in response.choices:
+             if choice.message.content is None:
+                 return []
+             if hasattr(choice, "message"):
+                 answer = choice.message.content.strip()
+             else:
+                 answer = choice.text.strip()
+
+             # one of the responses didn't finish, we need to request more tokens
+             if choice.finish_reason != "stop":
+                 if self.verbose:
+                     print(colored(f"Increasing max tokens to fit answers.", "red") + colored(answer, "blue"), file=sys.stderr)
+                     print(f"Finish reason: {choice.finish_reason}", file=sys.stderr)
+                 if max_tokens is None:
+                     return []
+                 return self.request_api(prompt, model, temperature=temperature, max_tokens=max_tokens + 200)
+
+             answers.append({
+                 "answer": answer,
+                 "finish_reason": choice.finish_reason,
+             })
+
+         if len(answers) > 1:
+             # remove duplicate answers
+             answers = [dict(t) for t in {tuple(d.items()) for d in answers}]
+
+         return answers
+
+     def call_api(self, prompt, model, temperature, max_tokens):
+         parameters = {
+             "temperature": temperature/10,
+             "top_p": 1,
+             "n": 1,
+             "frequency_penalty": 0,
+             "presence_penalty": 0,
+             "stop": None,
+             "model": model
+         }
+
+         if max_tokens is not None:
+             parameters["max_tokens"] = max_tokens
+
+         if isinstance(prompt, list):
+             # check that prompt contain list of dictionaries with role and content
+             assert all(isinstance(p, dict) for p in prompt), "Prompts must be a list of dictionaries."
+             assert all("role" in p and "content" in p for p in prompt), "Prompts must be a list of dictionaries with role and content."
+
+             parameters["messages"] = prompt
+         else:
+             parameters["messages"] = [{
+                 "role": "user",
+                 "content": prompt,
+             }]
+
+         return self.client.chat.completions.create(**parameters)
+
+     def bulk_request(self, df, model, parse_mqm_answer, cache, max_tokens=None):
+         answers = []
+         for i, row in tqdm.tqdm(df.iterrows(), total=len(df), file=sys.stderr):
+             prompt = row["prompt"]
+             parsed_answers = self.request(prompt, model, parse_mqm_answer, cache=cache, max_tokens=max_tokens)
+             answers += parsed_answers
+         return answers
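
Two details of GptApi worth noting: temperature is handled as an integer that call_api divides by 10 before sending (so the retry path in request raises it in 0.1 steps and request_api gives up once it exceeds 10), and request expects a diskcache object, with the request dict serving as the cache key. A minimal standalone call, assuming OPENAI_API_KEY is set; the parser and model name below are illustrative:

import diskcache as dc
from gemba.gpt_api import GptApi

def parse_answer(text):
    # hypothetical validator: keep any non-empty answer as-is
    return text.strip() if text and text.strip() else None

api = GptApi(verbose=True)
cache = dc.Cache("cache/example")
answers = api.request("Rate this translation from 0 to 100: ...", "gpt-4", parse_answer, cache=cache)
print(answers[0]["answer"], answers[0]["finish_reason"])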
gemba/mtme_tools.py ADDED
@@ -0,0 +1,99 @@
+ from mt_metrics_eval import data
+ import scipy
+
+ ######
+ # Functions in this script are copied from mt-metrics-eval/wmt22_metrics.ipynb
+ ######
+
+
+ def eval_metrics(eval_sets, langs, levels, primary_only, k, gold_name='std',
+                  include_domains=True, seg_level_no_avg=False,
+                  include_human_with_acc=False):
+     """Evaluate all metrics for eval sets, across multiple task settings.
+
+     Args:
+       eval_sets: Map from lang-pair to eval_set objects.
+       langs: List of language pairs (eg 'en-de') for which to compute results.
+       levels: List of levels for which to compute results, allowed elements are
+         'sys' and 'seg'.
+       primary_only: Include only primary metrics.
+       k: Number of bootstrap draws. If 0, no significance tests for metric-score
+         differences are run, and execution is much faster.
+       gold_name: Name of gold scores to use, standard scores if 'std'.
+       include_domains: Generate domain-specific results in addition to global
+         results.
+       seg_level_no_avg: If True, use only the average_by=None setting for segment-
+         level correlations.
+       include_human_with_acc: If True, include human outputs in accuracy tasks.
+
+     Returns:
+       Map from task names to metric -> (rank, corr, sig_string) stats.
+     """
+     results = {}
+
+     # First task is global accuracy, iff more than one language is given.
+     if len(langs) > 0:
+         evs_list = [eval_sets[lp] for lp in langs]
+         main_refs = [{evs.std_ref} for evs in evs_list]
+         close_refs = [set() for evs in evs_list]
+         if gold_name == 'std':
+             gold = evs_list[0].StdHumanScoreName('sys')
+         else:
+             gold = gold_name
+         humans = [True, False] if include_human_with_acc else [False]
+         for human in humans:
+             taskname = data.MakeTaskName(
+                 'wmt22', langs, None, 'sys', human, 'none', 'accuracy', k, gold,
+                 main_refs, close_refs, False, primary_only)
+             print(taskname)
+             res = data.CompareMetricsWithGlobalAccuracy(
+                 evs_list, main_refs, close_refs, include_human=human,
+                 include_outliers=False, gold_name=gold,
+                 primary_metrics=primary_only,
+                 domain=None, k=k, pval=0.05)
+             results[taskname] = reformat(res)
+
+     # Remaining tasks are specific to language, domain, etc.
+     for lp in langs:
+         evs = eval_sets[lp]
+         main_refs = {evs.std_ref}
+         close_refs = set()
+         for domain in [None] + (list(evs.domain_names) if include_domains else []):
+             for level in levels:
+                 gold = evs.StdHumanScoreName(level) if gold_name == 'std' else gold_name
+                 for avg in 'none', 'sys', 'item':
+                     if (level == 'sys' or seg_level_no_avg) and avg != 'none':
+                         continue
+                     for human in True, False:
+                         if human == True and len(evs.ref_names) == 1:
+                             continue  # Single ref
+                         for corr in 'pearson', 'kendall':
+                             corr_fcn = {'pearson': scipy.stats.pearsonr,
+                                         'kendall': scipy.stats.kendalltau}[corr]
+                             taskname = data.MakeTaskName(
+                                 'wmt22', lp, domain, level, human, avg, corr, k, gold,
+                                 main_refs, close_refs, False, primary=primary_only)
+                             print(taskname)
+                             corrs = data.GetCorrelations(
+                                 evs=evs, level=level, main_refs={evs.std_ref},
+                                 close_refs=close_refs, include_human=human,
+                                 include_outliers=False, gold_name=gold_name,
+                                 primary_metrics=primary_only, domain=domain)
+                             metrics, sig_matrix = data.CompareMetrics(
+                                 corrs, corr_fcn, average_by=avg, k=k, pval=0.05)
+                             # Make compatible with accuracy results.
+                             metrics = {evs.DisplayName(m): v for m, v in metrics.items()}
+                             results[taskname] = reformat((metrics, sig_matrix))
+
+     return results
+
+
+ def reformat(results):
+     """Reformat CompareMetrics() results to match mtme's format."""
+     metrics, sig_matrix = results
+     res = {}
+     for i, (m, (corr, rank)) in enumerate(metrics.items()):
+         sigs = ['1' if p < 0.05 else '0' for p in sig_matrix[i]]
+         sigs = ['x'] * (i + 1) + sigs[i + 1:]
+         res[m] = (rank, corr, ' '.join(sigs))
+     return res
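
eval_metrics expects the WMT22 eval sets to be loaded through the mt_metrics_eval package beforehand. A sketch of driving it for the three language pairs used in gemba_da.py, assuming the mt-metrics-eval data has been downloaded locally and that the package's usual data.EvalSet constructor is available:

from mt_metrics_eval import data
from gemba.mtme_tools import eval_metrics

langs = ['en-de', 'zh-en', 'en-ru']
eval_sets = {lp: data.EvalSet('wmt22', lp) for lp in langs}

# k=0 skips the bootstrap significance tests, which keeps the run fast
results = eval_metrics(eval_sets, langs, levels=['sys'], primary_only=False, k=0)
for task, stats in results.items():
    print(task, stats)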