gemba-0.1.0-py3-none-any.whl
- gemba/__init__.py +3 -0
- gemba/gemba_da.py +62 -0
- gemba/gemba_esa.py +84 -0
- gemba/gemba_mqm_utils.py +234 -0
- gemba/gpt_api.py +174 -0
- gemba/mtme_tools.py +99 -0
- gemba/prompt.py +139 -0
- gemba/scores.py +103 -0
- gemba/testset.py +58 -0
- gemba/utils.py +78 -0
- gemba-0.1.0.dist-info/METADATA +136 -0
- gemba-0.1.0.dist-info/RECORD +14 -0
- gemba-0.1.0.dist-info/WHEEL +4 -0
- gemba-0.1.0.dist-info/licenses/LICENSE.md +427 -0
gemba/prompt.py
ADDED
@@ -0,0 +1,139 @@
import re
from termcolor import colored


def parse_and_check_numerical_answer(answer, min=None, max=None):
    attempt = parse_numerical_answer(answer, min, max)
    if attempt is not None:
        # guard unset bounds so a call without min/max cannot raise a TypeError
        if (min is not None and attempt < min) or (max is not None and attempt > max):
            return None
        return attempt

    return None


def parse_numerical_answer(answer, min=None, max=None):
    # get all numbers in a string
    numbers = re.findall(r'\d+', answer)
    if len(numbers) == 1:
        return int(numbers[0])

    # check if the answer is in form ['100'] and extract the number
    r1 = re.match(r"^\[['\"][0-9]*['\"]\]$", answer)
    if r1 is not None:
        return int(answer[2:-2])

    if max is not None:
        # check if the answer is in a form of 0/100
        r2 = re.match(rf"^[0-9]*/{max}$", answer)
        if r2 is not None:
            return int(answer.split("/")[0])

    return None


def validate_number(x, min=0, max=100):
    attempt = parse_and_check_numerical_answer(x, min, max)
    if attempt is not None:
        return attempt
    return None


def parse_classes(answer, classes):
    final_class = None
    for i in range(len(classes)):
        if classes[i].lower() in answer.lower():
            if final_class is None:
                final_class = i
            else:
                print(colored(f"Two classes found in answer {answer}", "red"))
                return None

    return final_class


def validate_stars(x):
    x = x.lower()
    # try to find all possible answers as sometimes it seems to be explaining itself
    possible_answers = set()

    # check if string x contains * characters
    if "*" in x:
        possible_answers.add(x.count("*"))
    if "★" in x:
        possible_answers.add(x.count("★"))

    x = f" {x} ".replace("\n", " ")
    # possible answers: "five stars", "5 stars", "five", "five stars: perfect translation", ...
    if " one " in x or "1 star" in x:
        possible_answers.add(1)
    if " two " in x or "2 star" in x:
        possible_answers.add(2)
    if " three " in x or "3 star" in x:
        possible_answers.add(3)
    if " four " in x or "4 star" in x:
        possible_answers.add(4)
    if " five " in x or "5 star" in x:
        possible_answers.add(5)

    numerical = parse_numerical_answer(x)
    if numerical is not None:
        possible_answers.add(numerical)

    if len(possible_answers) == 1:
        answer = possible_answers.pop()
        if 1 <= answer <= 5:
            return answer
    return None


language_codes = {
    "en": "English",
    "de": "German",
    "zh": "Chinese",
    "ru": "Russian",
}

prompts = {
    "GEMBA-DA": {
        "prompt": 'Score the following translation from {source_lang} to {target_lang} on a continuous scale from 0 to 100, where a score of zero means "no meaning preserved" and score of one hundred means "perfect meaning and grammar".\n\n{source_lang} source: "{source_seg}"\n{target_lang} translation: "{target_seg}"\nScore: ',
        "validate_answer": lambda x: validate_number(x),
        "use_ref": False},

    "GEMBA-DA_ref": {
        "prompt": 'Score the following translation from {source_lang} to {target_lang} with respect to human reference on a continuous scale 0 to 100 where score of zero means "no meaning preserved" and score of one hundred means "perfect meaning and grammar".\n\n{source_lang} source: "{source_seg}"\n{target_lang} human reference: {reference_seg}\n{target_lang} machine translation: "{target_seg}"\nScore: ',
        "validate_answer": lambda x: validate_number(x),
        "use_ref": True},

    "GEMBA-SQM": {
        "prompt": 'Score the following translation from {source_lang} to {target_lang} on a continuous scale from 0 to 100 that starts on "No meaning preserved", goes through "Some meaning preserved", then "Most meaning preserved and few grammar mistakes", up to "Perfect meaning and grammar".\n\n{source_lang} source: "{source_seg}"\n{target_lang} translation: "{target_seg}"\nScore (0-100): ',
        "validate_answer": lambda x: validate_number(x),
        "use_ref": False},

    "GEMBA-SQM_ref": {
        "prompt": 'Score the following machine translation from {source_lang} to {target_lang} with respect to the human reference on a continuous scale from 0 to 100 that starts with "No meaning preserved", goes through "Some meaning preserved", then "Most meaning preserved and few grammar mistakes", up to "Perfect meaning and grammar".\n\n{source_lang} source: "{source_seg}"\n{target_lang} human reference: "{reference_seg}"\n{target_lang} machine translation: "{target_seg}"\nScore (0-100): ',
        "validate_answer": lambda x: validate_number(x),
        "use_ref": True},

    "GEMBA-stars": {
        "prompt": 'Score the following translation from {source_lang} to {target_lang} with one to five stars. Where one star means "Nonsense/No meaning preserved", two stars mean "Some meaning preserved, but not understandable", three stars mean "Some meaning preserved and understandable", four stars mean "Most meaning preserved with possibly few grammar mistakes", and five stars mean "Perfect meaning and grammar".\n\n{source_lang} source: "{source_seg}"\n{target_lang} translation: "{target_seg}"\nStars: ',
        "validate_answer": lambda x: validate_stars(x),
        "use_ref": False},

    "GEMBA-stars_ref": {
        "prompt": 'Score the following translation from {source_lang} to {target_lang} with respect to the human reference with one to five stars. Where one star means "Nonsense/No meaning preserved", two stars mean "Some meaning preserved, but not understandable", three stars mean "Some meaning preserved and understandable", four stars mean "Most meaning preserved with possibly few grammar mistakes", and five stars mean "Perfect meaning and grammar".\n\n{source_lang} source: "{source_seg}"\n{target_lang} human reference: "{reference_seg}"\n{target_lang} translation: "{target_seg}"\nStars: ',
        "validate_answer": lambda x: validate_stars(x),
        "use_ref": True},

    "GEMBA-classes": {
        "prompt": 'Classify the quality of machine translation from {source_lang} to {target_lang} into one of following classes: "No meaning preserved", "Some meaning preserved, but not understandable", "Some meaning preserved and understandable", "Most meaning preserved, minor issues", "Perfect translation".\n\n{source_lang} source: "{source_seg}"\n{target_lang} machine translation: "{target_seg}"\nClass: ',
        "use_ref": False,
        "validate_answer": lambda x, classes=["No meaning preserved", "Some meaning preserved, but not understandable", "Some meaning preserved and understandable", "Most meaning preserved, minor issues", "Perfect translation"]: parse_classes(x, classes),
        "max_tokens": 100},

    "GEMBA-classes_ref": {
        "prompt": 'Classify the quality of machine translation from {source_lang} to {target_lang} with respect to the human reference into one of following classes: "No meaning preserved", "Some meaning preserved, but not understandable", "Some meaning preserved and understandable", "Most meaning preserved, minor issues", "Perfect translation".\n\n{source_lang} source: "{source_seg}"\n{target_lang} human reference: "{reference_seg}"\n{target_lang} machine translation: "{target_seg}"\nClass: ',
        "use_ref": True,
        "validate_answer": lambda x, classes=["No meaning preserved", "Some meaning preserved, but not understandable", "Some meaning preserved and understandable", "Most meaning preserved, minor issues", "Perfect translation"]: parse_classes(x, classes),
        "max_tokens": 100},
}
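For orientation, here is a minimal sketch (not part of the package) of how a `prompts` entry is consumed: the template is filled with segment data and the paired validator parses the raw model reply. The segment values are made up, and plain `str.format` stands in for the `apply_template` helper that `gemba/utils.py` below actually uses.

```python
from gemba.prompt import prompts

entry = prompts["GEMBA-DA"]
# fill the template placeholders (utils.py does this via apply_template)
prompt = entry["prompt"].format(
    source_lang="English", target_lang="German",
    source_seg="Hello, how are you?", target_seg="Hallo, wie geht es dir?",
)
# `reply` stands in for a raw model answer; forms like "87" or "87/100" both parse
reply = "87/100"
score = entry["validate_answer"](reply)  # -> 87, or None if the reply is unparseable
```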
gemba/scores.py
ADDED
@@ -0,0 +1,103 @@
from pathlib import Path
import os
import pandas as pd


class Scores:
    def __init__(self, name, testset, refname, output_path=None):
        self.name = name
        self.testset = testset
        self.refname = refname
        if output_path is None:
            output_path = testset.basepath

        self.output_path = output_path

        self.seg_scores = None
        self.metadata = None
        self.prefix = None
        self.load()

    def load(self):
        output_folder = f"{self.testset.basepath}/{self.testset.dataset}/metric-scores/{self.testset.lp}"
        Path(output_folder).mkdir(parents=True, exist_ok=True)

        if self.refname is not None:
            self.prefix = f"{output_folder}/{self.name}-{self.refname}"
        else:
            self.prefix = f"{output_folder}/{self.name}-src"

        seg_scores_file = self.get_seg_path()
        if os.path.isfile(seg_scores_file):
            self.seg_scores = pd.read_csv(seg_scores_file, sep="\t", names=["system", "score"], index_col=False)
        else:
            self.seg_scores = pd.DataFrame(columns=["system", "score"])

        if os.path.isfile(self.get_meta_path()):
            self.metadata = pd.read_csv(self.get_meta_path(), sep="\t", names=["system", "temperature"], index_col=False)
        else:
            self.metadata = pd.DataFrame(columns=["system", "temperature"])

        # generate placeholders for scores
        segment_count = len(self.testset.sources)
        for system in self.testset.systems.keys():
            if system not in self.seg_scores["system"].values:
                placeholder = pd.DataFrame([{"system": system, 'score': 'None'}] * segment_count)
                self.seg_scores = pd.concat([self.seg_scores, placeholder], ignore_index=True)

            if system not in self.metadata["system"].values:
                placeholder = pd.DataFrame([{"system": system, 'temperature': 'None'}] * segment_count)
                self.metadata = pd.concat([self.metadata, placeholder], ignore_index=True)

            # check that all systems are present and have the correct number of scores
            assert len(self.seg_scores[self.seg_scores.system == system]) == segment_count
            assert len(self.metadata[self.metadata.system == system]) == segment_count

    def get_seg_path(self):
        return f"{self.prefix}.seg.score"

    def get_sys_path(self):
        return f"{self.prefix}.sys.score"

    def get_domain_path(self):
        return f"{self.prefix}.domain.score"

    def get_meta_path(self):
        return f"{self.prefix}.seg.meta"

    def _remap_index(self, system, hypothesis_index):
        # the order of systems may be different;
        # get the index of the first hypothesis of the system
        index = (self.seg_scores["system"] == system).argmax()
        h = hypothesis_index % len(self.testset.sources)
        return index + h

    def get_score(self, system, hypothesis_index):
        index = self._remap_index(system, hypothesis_index)
        return self.seg_scores.iloc[index]['score']

    def assign_score(self, system, hypothesis_index, answer, temperature=None):
        index = self._remap_index(system, hypothesis_index)
        # write by position; chained `.iloc[index][col] = ...` may silently update a copy
        self.seg_scores.iloc[index, self.seg_scores.columns.get_loc('score')] = answer
        self.metadata.iloc[index, self.metadata.columns.get_loc('temperature')] = temperature

    def save(self):
        # segment level scores
        self.seg_scores.to_csv(self.get_seg_path(), sep="\t", index=False, header=False, na_rep="None")

        # system scores
        self.seg_scores.score = self.seg_scores.score.replace("None", None).astype(float)
        sys_scores_df = self.seg_scores.groupby(['system'], as_index=False, dropna=True).mean()
        sys_scores_df.to_csv(self.get_sys_path(), sep="\t", index=False, header=False, na_rep="None")

        # domain scores
        df = self.seg_scores.copy()
        documents = self.testset.documents
        total_systems = len(self.testset.systems)
        df["domains"] = pd.DataFrame([x.split("\t")[0] for x in documents] * total_systems)[0]
        df = df.groupby(["domains", 'system'], as_index=False, dropna=True).mean()
        df.to_csv(self.get_domain_path(), sep="\t", index=False, header=False, na_rep="None")

        # metadata
        self.metadata.to_csv(self.get_meta_path(), sep="\t", index=False, header=False, na_rep="None")
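A rough illustration of how `Scores` ties into the `Testset` class (defined in `gemba/testset.py` below) and where its files land; the dataset, language pair, and system names here are hypothetical, the path scheme follows `load()` and `get_seg_path()` above.

```python
from gemba.scores import Scores
from gemba.testset import Testset

# "mt-metrics-eval-v2", "wmt22", "en-de", "systemA" are illustrative names
testset = Testset("mt-metrics-eval-v2", "wmt22", "en-de")
scores = Scores("GEMBA-DA", testset, refname=None)  # refname=None -> "...-src" prefix
print(scores.get_seg_path())
# mt-metrics-eval-v2/wmt22/metric-scores/en-de/GEMBA-DA-src.seg.score
scores.assign_score("systemA", 0, 85, temperature=0)  # first segment of systemA
scores.save()  # writes .seg.score, .sys.score, .domain.score and .seg.meta files
```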
gemba/testset.py
ADDED
@@ -0,0 +1,58 @@
import glob
import os


class Testset:
    def __init__(self, basepath, dataset, lp):
        self.basepath = basepath
        self.dataset = dataset
        self.lp = lp

        self.sources = []
        self.references = {}
        self.systems = {}
        self.documents = []
        self.main_ref = None

        self.load()

    def load(self):
        dataset = f"{self.basepath}/{self.dataset}"

        self.sources = self.load_segment_files(f"{dataset}/sources/{self.lp}.txt")

        # list all files in references folder
        refs = glob.glob(f"{dataset}/references/{self.lp}.*.txt")
        for reffile in refs:
            refname = reffile.split('.')[-2]
            if self.main_ref is None:
                self.main_ref = refname
            self.references[refname] = self.load_segment_files(reffile)

        systems = f"{dataset}/system-outputs/{self.lp}"
        # keep systems in order
        all_systems = sorted(os.listdir(systems))
        for system in all_systems:
            systemname = system.replace(".txt", "")
            self.systems[systemname] = self.load_segment_files(f"{systems}/{system}")

        self.documents = self.load_segment_files(f"{dataset}/documents/{self.lp}.docs")

    def iterate_over_all(self, reference=None):
        for system in self.systems.keys():
            if reference is None:
                for src, hyp in zip(self.sources, self.systems[system]):
                    yield src, hyp, None, system
            else:
                for src, hyp, ref in zip(self.sources, self.systems[system], self.references[reference]):
                    yield src, hyp, ref, system

    def load_segment_files(self, path):
        segments = []
        with open(path, "r") as fh:
            for line in fh:
                segments.append(line.rstrip())
        return segments

    def segments_count(self):
        return len(self.sources) * len(self.systems)
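The paths in `load()` imply the mt-metrics-eval directory layout sketched below; "wmt22" and "en-de" are assumed example names, and the `.docs` format is inferred from how `Scores.save()` splits it on tabs.

```python
# {basepath}/wmt22/sources/en-de.txt           one source segment per line
# {basepath}/wmt22/references/en-de.refA.txt   one file per reference; "refA" becomes its refname
# {basepath}/wmt22/system-outputs/en-de/*.txt  one file per MT system
# {basepath}/wmt22/documents/en-de.docs        per-line document info; the tab-separated
#                                              first field is read as the domain
from gemba.testset import Testset

testset = Testset("mt-metrics-eval-v2", "wmt22", "en-de")
for src, hyp, ref, system in testset.iterate_over_all(reference=testset.main_ref):
    ...  # score each (source, hypothesis, reference) triple here
```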
gemba/utils.py
ADDED
@@ -0,0 +1,78 @@
import pandas as pd
import diskcache as dc
from gemba.gpt_api import GptApi
from gemba.gemba_mqm_utils import TEMPLATE_GEMBA_MQM, apply_template, parse_mqm_answer
from gemba.gemba_esa import TEMPLATE_GEMBA_ESA_ERROR_SPANS, TEMPLATE_GEMBA_ESA_RANKING
from gemba.prompt import prompts, validate_number


def get_gemba_scores(source, hypothesis, source_lang, target_lang, method="GEMBA-MQM_norm", model="gpt-4o", cache_dir=".cache"):
    """Get GEMBA scores for machine translation evaluation.

    This function evaluates machine translation quality using various GEMBA methods by leveraging
    large language models to analyze source and translated text pairs.

    Args:
        source (list): List of source language segments to be evaluated
        hypothesis (list): List of target language translations to be evaluated
        source_lang (str): Source language code (e.g. 'en' for English)
        target_lang (str): Target language code (e.g. 'de' for German)
        method (str): Evaluation method to use. One of:
            - "GEMBA-MQM": MQM-style error annotation and scoring
            - "GEMBA-MQM_norm": MQM-style error annotation and scoring with normalization
            - "GEMBA-DA": Direct assessment scoring
            - "GEMBA-DA_ref": Direct assessment with reference
            - "GEMBA-SQM": Scalar quality metrics
            - "GEMBA-SQM_ref": Scalar quality metrics with reference
            - "GEMBA-stars": Star rating evaluation
            - "GEMBA-stars_ref": Star rating with reference
            - "GEMBA-classes": Classification-based evaluation
            - "GEMBA-classes_ref": Classification with reference
            - "GEMBA-ESA": Error span annotation and ranking
        model (str): Name of the LLM model to use for evaluation
        cache_dir (str): Directory to store the cache in

    Returns:
        list: List of scores/evaluations for each source-hypothesis pair. The format depends on the method:
            - For MQM: Negative scores where higher is better (max 0, min -25)
            - For MQM_norm: Scores normalized to the 0-100 range
            - For DA/SQM: Numeric scores
            - For stars: 1-5 star ratings
            - For classes: Classification labels
            - For ESA: Numeric rankings based on error spans
        list: List of error classes for each source-hypothesis pair. Only returned for MQM methods.

    The function uses disk caching to store results and avoid redundant API calls. Cache is stored
    in a '{cache_dir}/{model}_{method}' directory.
    """

    df = pd.DataFrame({'source_seg': source, 'target_seg': hypothesis})
    df['source_lang'] = source_lang
    df['target_lang'] = target_lang

    cache = dc.Cache(f'{cache_dir}/{model}_{method}', expire=None, size_limit=int(10e10), cull_limit=0, eviction_policy='none')
    gptapi = GptApi()

    if method in ["GEMBA-MQM", "GEMBA-MQM_norm"]:
        df["prompt"] = df.apply(lambda x: apply_template(TEMPLATE_GEMBA_MQM, x), axis=1)
        parse_answer = lambda x: parse_mqm_answer(x, list_mqm_errors=True, full_desc=True, normalize=method == "GEMBA-MQM_norm")
        answers = gptapi.bulk_request(df, model, parse_answer, cache=cache, max_tokens=500)
    elif method in ["GEMBA-DA", "GEMBA-DA_ref", "GEMBA-SQM", "GEMBA-SQM_ref", "GEMBA-stars", "GEMBA-stars_ref", "GEMBA-classes", "GEMBA-classes_ref"]:
        df["prompt"] = df.apply(lambda x: apply_template(prompts[method]['prompt'], x), axis=1)
        parse_answer = prompts[method]["validate_answer"]
        answers = gptapi.bulk_request(df, model, parse_answer, cache=cache, max_tokens=500)
    elif method == "GEMBA-ESA":
        df["prompt"] = df.apply(lambda x: apply_template(TEMPLATE_GEMBA_ESA_ERROR_SPANS, x), axis=1)
        parse_answer = lambda x: x
        error_spans = gptapi.bulk_request(df, model, parse_answer, cache=cache)
        df['error_spans'] = pd.DataFrame(error_spans)['answer']

        df["prompt"] = df.apply(lambda x: apply_template(TEMPLATE_GEMBA_ESA_RANKING, x), axis=1)
        parse_answer = validate_number
        answers = gptapi.bulk_request(df, model, parse_answer, cache=cache)
    else:
        raise Exception(f"Method {method} not supported.")

    df = pd.DataFrame(answers)
    if "errors" in df.columns:
        # MQM-style methods also return the annotated error classes (see docstring)
        return df['answer'].tolist(), df['errors'].tolist()
    return df['answer'].tolist()
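Because `bulk_request` is routed through the diskcache set up above, repeated calls with identical inputs are served from disk. A minimal sketch (segments illustrative, and assuming API credentials are configured as in the README below):

```python
from gemba.utils import get_gemba_scores

answers, errors = get_gemba_scores(
    ["Hello, how are you?"], ["Hallo, wie geht es dir?"], "en", "de",
    method="GEMBA-MQM_norm", model="gpt-4o",
)
# a second identical call reads from .cache/gpt-4o_GEMBA-MQM_norm instead of hitting the API
```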
gemba-0.1.0.dist-info/METADATA
ADDED
@@ -0,0 +1,136 @@
Metadata-Version: 2.3
Name: gemba
Version: 0.1.0
Summary: GEMBA — GPT Estimation Metric Based Assessment
Project-URL: Homepage, https://github.com/joelniklaus/gemba
Author-email: Joel Niklaus <joel@niklaus.ai>
License: MIT
Classifier: License :: OSI Approved :: MIT License
Classifier: Operating System :: OS Independent
Classifier: Programming Language :: Python :: 3
Requires-Python: >=3.8
Requires-Dist: absl-py
Requires-Dist: diskcache
Requires-Dist: ipdb
Requires-Dist: openai>=1.0.0
Requires-Dist: pandas
Requires-Dist: pexpect
Requires-Dist: scipy
Requires-Dist: termcolor
Description-Content-Type: text/markdown

# GEMBA-MQM and GEMBA-DA

## Setup

Install the required packages with Python >= 3.8 (tested with 3.12.7):

```
pip install -r requirements.txt
```

Set up secrets either for the Azure API or the OpenAI API:

```
export OPENAI_AZURE_ENDPOINT=
export OPENAI_AZURE_KEY=
```

or

```
export OPENAI_API_KEY=
```

## Scoring with GEMBA

Install the gemba package with `pip install gemba` and use the following code:

```python
from gemba import get_gemba_scores

source = ["Hello, how are you?", "I am fine, thank you.", "I am not fine, thank you."]
hypothesis = ["Hallo, wie geht es dir?", "Ich bin gut, danke.", "Ich bin Adolf, wer bist du?"]
source_lang = "en"
target_lang = "de"

answers, errors = get_gemba_scores(source, hypothesis, source_lang, target_lang, method="GEMBA-MQM_norm", model="gpt-4o")

for answer, error in zip(answers, errors):
    print(answer, error)
```

Alternatively, you can run the main file on two text files with the same number of lines; it prints the score for each line pair:

```bash
python main.py --source=source.txt --hypothesis=hypothesis.txt --source_lang=English --target_lang=Czech --method="GEMBA-MQM" --model="gpt-4"
```

The main recommended methods are `GEMBA-MQM` and `GEMBA-DA` with the model `gpt-4`.

## Collecting and evaluating experiments for GEMBA-DA

Get mt-metrics-eval and download its resources:

```
git clone https://github.com/google-research/mt-metrics-eval.git
cd mt-metrics-eval
pip install .
alias mtme='python3 -m mt_metrics_eval.mtme'
mtme --download
cd ..
mv ~/.mt-metrics-eval/mt-metrics-eval-v2 mt-metrics-eval-v2
```

Collect data and run the scorer:

```
python gemba_da.py

export PYTHONPATH=mt-metrics-eval:$PYTHONPATH
python evaluate.py
```

## License
GEMBA code and data are released under the [CC BY-SA 4.0 license](https://github.com/MicrosoftTranslator/GEMBA/blob/main/LICENSE.md).

## Paper
You can read more about GEMBA-DA [in our arXiv paper](https://arxiv.org/pdf/2302.14520.pdf)
or GEMBA-MQM [in our arXiv paper](https://arxiv.org/pdf/2310.13988.pdf).

## How to Cite

### GEMBA-MQM

```
@inproceedings{kocmi-federmann-2023-gemba-mqm,
    title = {GEMBA-MQM: Detecting Translation Quality Error Spans with GPT-4},
    author = {Kocmi, Tom and Federmann, Christian},
    booktitle = "Proceedings of the Eighth Conference on Machine Translation",
    month = dec,
    year = "2023",
    address = "Singapore",
    publisher = "Association for Computational Linguistics",
}
```

### GEMBA-DA

```
@inproceedings{kocmi-federmann-2023-large,
    title = "Large Language Models Are State-of-the-Art Evaluators of Translation Quality",
    author = "Kocmi, Tom and Federmann, Christian",
    booktitle = "Proceedings of the 24th Annual Conference of the European Association for Machine Translation",
    month = jun,
    year = "2023",
    address = "Tampere, Finland",
    publisher = "European Association for Machine Translation",
    url = "https://aclanthology.org/2023.eamt-1.19",
    pages = "193--203",
}
```
gemba-0.1.0.dist-info/RECORD
ADDED
@@ -0,0 +1,14 @@
gemba/__init__.py,sha256=0ZuEumkUMWPI5wQMY7OxLolELI9GYYlup-iJw8SwBgc,67
gemba/gemba_da.py,sha256=YCOKKP7kZBL9e1d44Zr7aTa23BqLFvh4KDOfbNSMgOU,2360
gemba/gemba_esa.py,sha256=nBCeFjrS24wXLOcAXHRSmZFYJSkUzRS4hfp2LEqYwp8,4461
gemba/gemba_mqm_utils.py,sha256=qiIdJv7IDx0eeqpsTCHMoUeo8EUOhG6k-YfrzkRfxyw,9612
gemba/gpt_api.py,sha256=Igp8uQn6chKL1QWFMqKP2VR9Fbzxm8Xk83ELxk5NfM8,6671
gemba/mtme_tools.py,sha256=xpLxCzfnLHFIxsq_LOi1Lpb-gkyFGYqFXiq9y6O315Q,4667
gemba/prompt.py,sha256=AuPBhO2OBL3EB5I37p-GX10sx29gRw35xFAnB3bqtII,7578
gemba/scores.py,sha256=FmmBJ-ds-abExphcVUw9qaPMnKttPWobuXNwZKLAtEs,4388
gemba/testset.py,sha256=tDvi6xQIBXrODg02WWINrYg9jNQqruCmhBrxe9AaK48,1926
gemba/utils.py,sha256=Re5uW5dcFj3ITWIGpxjXdAKNDKQ7i4H-Tr_s74SQgmk,4311
gemba-0.1.0.dist-info/METADATA,sha256=9_jYmIPKmAz5cmPn-fTUB7a5xHLbYrlTXpdzhEYaSSw,3692
gemba-0.1.0.dist-info/WHEEL,sha256=C2FUgwZgiLbznR-k0b_5k3Ai_1aASOXDss3lzCUsUug,87
gemba-0.1.0.dist-info/licenses/LICENSE.md,sha256=XkNv-P-7d9hgciDpvOIMiRXYYAEP7rbB6-9ahWiOmzk,20137
gemba-0.1.0.dist-info/RECORD,,