subset2evaluate 0.0.1a6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- subset2evaluate/__init__.py +6 -0
- subset2evaluate/evaluate.py +272 -0
- subset2evaluate/methods.py +363 -0
- subset2evaluate/methods_old.py +436 -0
- subset2evaluate/select_subset.py +101 -0
- subset2evaluate/test.py +66 -0
- subset2evaluate/utils.py +410 -0
- subset2evaluate-0.0.1a6.dist-info/METADATA +210 -0
- subset2evaluate-0.0.1a6.dist-info/RECORD +12 -0
- subset2evaluate-0.0.1a6.dist-info/WHEEL +5 -0
- subset2evaluate-0.0.1a6.dist-info/entry_points.txt +3 -0
- subset2evaluate-0.0.1a6.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,272 @@
|
|
|
1
|
+
from typing import Dict, List, Tuple
|
|
2
|
+
import numpy as np
|
|
3
|
+
import subset2evaluate
|
|
4
|
+
import subset2evaluate.utils as utils
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def run_evaluate_cluacc(data_new: List[Dict], data_old: List[Dict], metric="human", props: List[float]=utils.PROPS) -> Tuple[List[float], List[float]]:
    """
    Evaluate an ordered selection at several budget proportions.

    For each ``prop`` in ``props`` the first ``int(len(data_old) * prop)`` items of
    ``data_new`` are scored with ``eval_subset_clusters`` (cluster count) and
    ``eval_subset_accuracy`` (pairwise system-ranking accuracy against ``data_old``).

    Args:
        data_new: ordered selection (list of items or data descriptor).
        data_old: original data (list of items or data descriptor).
        metric: score key to evaluate against.
        props: budget proportions relative to ``len(data_old)``.

    Returns:
        ``(clusters, accuracies)``: two lists holding one value per proportion.
    """
    # both list or descriptor is fine
    data_new = utils.load_data(data_new)
    data_old = utils.load_data(data_old)

    clu_new = []
    acc_new = []
    for prop in props:
        # the budget is relative to the size of the ORIGINAL data
        k = int(len(data_old) * prop)
        clu_new.append(eval_subset_clusters(data_new[:k], metric=metric))
        acc_new.append(eval_subset_accuracy(data_new[:k], data_old, metric=metric))

    return clu_new, acc_new
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def run_evaluate_cluacc_par(
    data_new: List[Dict],
    data_old: List[Dict],
    clus_tgt: List[float],
    accs_tgt: List[float],
    metric="human",
    props: List[float]=utils.PROPS,
    workers=10,
) -> Tuple[float, float]:
    """
    Evaluates the proportion of data that is needed to achieve parity with target.

    For every proportion in ``props`` (paired with the matching entry of ``clus_tgt``
    and ``accs_tgt``), find the smallest prefix size ``k >= 5`` of ``data_new`` whose
    cluster count (resp. accuracy against ``data_old``) reaches the target. Each k is
    then divided by the budget ``len(data_old) * prop``. If a target is never reached,
    k falls back to ``len(data_new)``.

    Args:
        data_new: ordered selection (list of items or data descriptor).
        data_old: original data (list of items or data descriptor).
        clus_tgt: target cluster count per proportion.
        accs_tgt: target accuracy per proportion.
        metric: score key to evaluate against.
        props: budget proportions relative to ``len(data_old)``.
        workers: maximum number of threads.

    Returns:
        ``(average relative size for cluster parity, average relative size for accuracy parity)``
    """
    import multiprocessing.pool

    # both list or descriptor is fine
    data_new = utils.load_data(data_new)
    data_old = utils.load_data(data_old)

    def _par_clu(data_new, clu_tgt, metric):
        for k in range(5, len(data_new) + 1):
            if eval_subset_clusters(data_new[:k], metric=metric) >= clu_tgt:
                return k
        # Target never reached: fall back to all of data_new (same as the last k tried).
        # This also covers len(data_new) < 5, where `k` used to be unbound.
        return len(data_new)

    def _par_acc(data_new, data_old, acc_tgt, metric):
        for k in range(5, len(data_new) + 1):
            if eval_subset_accuracy(data_new[:k], data_old, metric=metric) >= acc_tgt:
                return k
        # see _par_clu
        return len(data_new)

    # multiprocess for each prop rather than k because the thread
    # orchestration would be more expensive otherwise
    with multiprocessing.pool.ThreadPool(min(workers, len(props))) as pool:
        ks_clu_par = pool.starmap(
            _par_clu,
            [(data_new, clu_tgt, metric) for _prop, clu_tgt in zip(props, clus_tgt)]
        )
        ks_clu_par = [k / (len(data_old) * prop) for k, prop in zip(ks_clu_par, props)]

        ks_acc_par = pool.starmap(
            _par_acc,
            [(data_new, data_old, acc_tgt, metric) for _prop, acc_tgt in zip(props, accs_tgt)]
        )
        ks_acc_par = [k / (len(data_old) * prop) for k, prop in zip(ks_acc_par, props)]

    return np.average(ks_clu_par), np.average(ks_acc_par)
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def precompute_randnorm(
    data_old: List[Dict],
    random_seeds=10,
    metric="human",
    workers=10,
) -> Tuple[Tuple[np.ndarray, np.ndarray], Tuple[float, float]]:
    """
    Precompute the random-selection baseline used by ``run_evaluate_cluacc_randnorm``.

    Args:
        data_old: original data (list of items or data descriptor).
        random_seeds: number of random orderings used in each of the two stages.
        metric: score key to evaluate against.
        workers: thread count forwarded to ``run_evaluate_cluacc_par``.

    Returns:
        ``((clu_random, acc_random), (par_clu_rand, par_acc_rand))``:

        - The first pair holds the per-proportion cluster counts and accuracies of
          random selection, averaged element-wise over seeds ``0 .. random_seeds-1``.
        - The second pair holds the average parity sizes that other random orderings
          (seeds ``random_seeds .. 2*random_seeds-1``) need to match that baseline.
    """
    import subset2evaluate.select_subset

    clu_random = []
    acc_random = []
    for seed in range(random_seeds):
        clu_new, acc_new = run_evaluate_cluacc(
            subset2evaluate.select_subset.run_select_subset(data_old, method="random", seed=seed),
            data_old,
            metric=metric,
        )
        clu_random.append(clu_new)
        acc_random.append(acc_new)
    # element-wise mean over seeds: one value per proportion
    clu_random = np.average(clu_random, axis=0)
    acc_random = np.average(acc_random, axis=0)

    pars_clu_rand = []
    pars_acc_rand = []

    # Disjoint seeds, presumably so that the parity is measured on random orderings
    # other than those that defined the baseline.
    for seed in range(random_seeds, 2 * random_seeds):
        par_clu_rand, par_acc_rand = run_evaluate_cluacc_par(
            subset2evaluate.select_subset.run_select_subset(data_old, method="random", seed=seed),
            data_old,
            clu_random,
            acc_random,
            metric=metric,
            workers=workers,
        )
        pars_clu_rand.append(par_clu_rand)
        pars_acc_rand.append(par_acc_rand)

    return (clu_random, acc_random), (np.average(pars_clu_rand), np.average(pars_acc_rand))
|
|
107
|
+
|
|
108
|
+
def run_evaluate_cluacc_randnorm(
    data_new: List[Dict],
    data_old: List[Dict],
    random_seeds=10,
    metric="human",
    cluacc_precomputed = None,
    workers=10,
) -> Tuple[float, float]:
    """
    Parity of ``data_new`` divided by the parity achieved by random selection.

    Args:
        data_new: ordered selection (list of items or data descriptor).
        data_old: original data (list of items or data descriptor).
        random_seeds: passed to ``precompute_randnorm`` when no precomputation is given.
        metric: score key to evaluate against.
        cluacc_precomputed: output of ``precompute_randnorm``; computed here when None.
        workers: thread count for the parity searches (defaults to the same 10 the
            callees use, so the default behaviour is unchanged).

    Returns:
        ``(cluster parity ratio, accuracy parity ratio)`` relative to random selection.
    """
    if cluacc_precomputed is not None:
        (clu_random, acc_random), (clu_random_norm, acc_random_norm) = cluacc_precomputed
    else:
        (clu_random, acc_random), (clu_random_norm, acc_random_norm) = precompute_randnorm(
            data_old, random_seeds=random_seeds, metric=metric, workers=workers,
        )

    # compute the parity of the new data
    par_clu, par_acc = run_evaluate_cluacc_par(
        data_new, data_old,
        clu_random, acc_random,
        metric=metric,
        workers=workers,
    )

    return par_clu / clu_random_norm, par_acc / acc_random_norm
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
def run_evaluate_top_timebudget(data_new, data_old, metric="human", props: List[float]=utils.PROPS) -> Tuple[List[float], List[float]]:
    """
    Evaluate an ordered selection under time budgets instead of item-count budgets.

    For each ``prop`` the budget is ``int(len(data_old) * prop)``, in the same units as
    ``item["time"]``. Items are taken from the front of ``data_new`` while their
    ``"time"`` fits in the remaining budget. Selection stops at the first item that
    does not fit, so the result is always a prefix of ``data_new``. That prefix is
    scored with ``eval_subset_clusters`` and ``eval_subset_accuracy``.

    Args:
        data_new: ordered selection (list of items or data descriptor); items need "time".
        data_old: original data (list of items or data descriptor).
        metric: score key to evaluate against.
        props: budget proportions (default ``utils.PROPS``, as before).

    Returns:
        ``(clusters, accuracies)``: two lists holding one value per proportion.
    """
    # both list or descriptor is fine
    data_old = utils.load_data(data_old)
    data_new = utils.load_data(data_new)

    clu_new = []
    acc_new = []
    for prop in props:
        budget = int(len(data_old) * prop)
        data_new_inbudget = []
        for item in data_new:
            if item["time"] <= budget:
                budget -= item["time"]
                data_new_inbudget.append(item)
            else:
                break
        # NOTE(review): if even the first item exceeds the budget, data_new_inbudget is
        # empty and eval_subset_accuracy -> get_sys_absolute indexes [0] on it.
        clu_new.append(eval_subset_clusters(data_new_inbudget, metric=metric))
        acc_new.append(eval_subset_accuracy(data_new_inbudget, data_old, metric=metric))

    return clu_new, acc_new
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
def eval_subset_accuracy(data_new: List[Dict], data_old: List[Dict], metric="human"):
    """
    Pairwise system-ranking accuracy of ``data_new`` measured against ``data_old``.

    Both datasets are reduced to per-system average ``metric`` scores. The result is
    the fraction of system pairs (systems taken from ``data_old``'s first item) whose
    strict ``<`` relation agrees between the two.
    """
    # the reference ordering comes from data_old
    reference_scores = get_sys_absolute(data_old, metric=metric)
    subset_scores = get_sys_absolute(data_new, metric=metric)
    return eval_order_accuracy(subset_scores, reference_scores)
|
|
168
|
+
|
|
169
|
+
|
|
170
|
+
def eval_subset_clusters(data: List[Dict], metric="human"):
    """
    Count clusters of statistically distinguishable systems on ``data``.

    Systems are ranked by average ``metric`` score, highest first. Walking down that
    ranking, each system is compared item-by-item with the last system of the current
    cluster using a one-sided Wilcoxon signed-rank test (``alternative="less"``).
    A p-value below 0.05 opens a new cluster; otherwise the system joins the current
    one. With fewer than 3 items the answer is always 1.
    """
    from scipy.stats import wilcoxon
    import warnings

    # with fewer than 3 items there is too little evidence to separate any systems
    if len(data) < 3:
        return 1

    averages = get_sys_absolute(data, metric=metric)
    ranked_names = [
        name
        for name, _ in sorted(averages.items(), key=lambda kv: kv[1], reverse=True)
    ]

    def item_scores(system):
        return [line["scores"][system][metric] for line in data]

    clusters = [[item_scores(ranked_names[0])]]
    for name in ranked_names[1:]:
        current = item_scores(name)
        deltas = [x - y for x, y in zip(current, clusters[-1][-1])]

        with warnings.catch_warnings(action="ignore"):
            # NOTE(review): identical per-item scores open a NEW cluster here. The
            # all-zero guard presumably exists because the test cannot be run on
            # all-zero differences. Confirm that "new cluster" (rather than "same
            # cluster") is the intended outcome.
            starts_new_cluster = (
                all([d == 0 for d in deltas])
                or wilcoxon(deltas, alternative="less").pvalue < 0.05
            )
            if starts_new_cluster:
                clusters.append([current])
            else:
                clusters[-1].append(current)

    return len(clusters)
|
|
198
|
+
|
|
199
|
+
|
|
200
|
+
def get_sys_absolute(data_new, metric="human") -> Dict[str, float]:
    """
    Average ``metric`` per system over all items of ``data_new``.

    The set of systems and the key order come from the first item, so ``data_new``
    must be non-empty and every item must score at least those systems.
    """
    import numpy as np

    system_names = list(data_new[0]["scores"].keys())
    columns = {
        name: [line["scores"][name][metric] for line in data_new]
        for name in system_names
    }
    return {name: np.average(values) for name, values in columns.items()}
|
|
217
|
+
|
|
218
|
+
|
|
219
|
+
def get_sys_ordering(data_new: List[Dict], metric="human"):
    """
    Rank systems by their average ``metric`` score on ``data_new``.

    Returns a dict mapping each system name to its 0-based rank (0 = highest average).
    The sort is stable, so tied systems keep the order produced by ``get_sys_absolute``.
    """
    averages = get_sys_absolute(data_new, metric)
    ranked = sorted(averages.items(), key=lambda kv: kv[1], reverse=True)
    return {name: rank for rank, (name, _score) in enumerate(ranked)}
|
|
233
|
+
|
|
234
|
+
|
|
235
|
+
def eval_order_accuracy(scores_new: Dict[str, float], scores_old: Dict[str, float]):
    """
    Fraction of system pairs whose strict ``<`` relation agrees between two score maps.

    Pairs are drawn from the systems in ``scores_old`` (the reference), and every such
    system must also be present in ``scores_new``.
    """
    import itertools
    import numpy as np

    agreements = [
        (scores_old[first] < scores_old[second]) == (scores_new[first] < scores_new[second])
        for first, second in itertools.combinations(list(scores_old.keys()), 2)
    ]
    return np.average(agreements)
|
|
247
|
+
|
|
248
|
+
|
|
249
|
+
def main_cli():
    """Command-line entry point: print the average cluster count and accuracy of an ordered subset."""
    import argparse

    parser = argparse.ArgumentParser(
        description="Meta-evaluate subset selection methods with cluster count and system accuracy."
    )
    parser.add_argument(
        'data_old', type=str, default='wmt23/en-cs',
        help="Original data descriptor or path."
    )
    parser.add_argument(
        'data_new', type=str, default='wmt23/en-cs',
        help="Path to new ordered data."
    )
    parser.add_argument(
        '--metric', type=str, default='human',
        help="Metric to evaluate against, e.g., human or human_consistency. Can also be a metric and not human score."
    )
    args = parser.parse_args()

    # Pass by keyword: run_evaluate_cluacc's signature is (data_new, data_old, metric),
    # and the former positional call (data_old, data_new, ...) swapped the datasets.
    clu_new, acc_new = run_evaluate_cluacc(
        data_new=args.data_new,
        data_old=args.data_old,
        metric=args.metric,
    )

    print(f"Clusters: {np.average(clu_new):.2f}")
    print(f"Accuracy: {np.average(acc_new):.1%}")
|
|
@@ -0,0 +1,363 @@
|
|
|
1
|
+
from typing import Any, List, Tuple, Union
|
|
2
|
+
from functools import partial
|
|
3
|
+
import numpy as np
|
|
4
|
+
import os
|
|
5
|
+
# Set at import time, so this changes the environment of the whole process.
# Presumably silences HuggingFace `tokenizers` parallelism warnings/fork issues
# (sentence_transformers is used by `pyirt` for amortized models) — TODO confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def random_subset(data, seed=None, **kwargs) -> List[float]:
    """Return one uniform [0, 1) score per item, reproducible for a given ``seed``."""
    import random

    rng = random.Random(seed)
    scores = []
    for _item in data:
        scores.append(rng.random())
    return scores
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def metric_avg(data, metric, **kwargs) -> List[float]:
    """Negated mean of ``metric`` across systems for each item (lower average -> higher score)."""
    utilities = []
    for item in data:
        per_system = [system_scores[metric] for system_scores in item["scores"].values()]
        utilities.append(-np.average(per_system))
    return utilities
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def metric_var(data, metric, **kwargs) -> List[float]:
    """Variance of ``metric`` across systems for each item (higher disagreement -> higher score)."""
    utilities = []
    for item in data:
        per_system = [system_scores[metric] for system_scores in item["scores"].values()]
        utilities.append(np.var(per_system))
    return utilities
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def _fn_information_content(item_old, item_irt, data_irt) -> float:
    """
    Sum, over every system ability in ``data_irt["systems"]``, this item's information term.

    This is the utility behind ``fn_utility="fisher_information_content"``. For ability
    ``theta`` the term is
    ``disc**2 * exp(disc*(theta+diff)) / (exp(disc*diff) + exp(disc*theta))**2``.
    ``item_old`` is unused; it is kept so all utilities share one signature.
    """
    a = item_irt["disc"]
    b = item_irt["diff"]

    def _term(theta):
        numerator = np.exp(a * (theta + b))
        denominator = (np.exp(a * b) + np.exp(a * theta)) ** 2
        return (a ** 2) * numerator / denominator

    return sum((_term(theta) for theta in data_irt["systems"].values()), 0)
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def fn_irt_utility(item_old, item_irt, data_irt, fn_utility) -> float:
    """
    Map one fitted IRT item to a scalar utility selected by ``fn_utility``.

    Supported values:
        "fisher_information_content": ``_fn_information_content`` summed over all systems.
        "diff": negated difficulty.
        "disc": negated discrimination.
        "diffdisc": difficulty times discrimination.
        "feas": the item's "feas" parameter (only present for models that fit it).

    Raises:
        ValueError: for any other ``fn_utility``. Previously the function fell through
            and returned ``None``, which propagated ``None`` scores to the caller.
    """
    if fn_utility == "fisher_information_content":
        return _fn_information_content(item_old, item_irt, data_irt)
    elif fn_utility == "diff":
        return -item_irt["diff"]
    elif fn_utility == "disc":
        return -item_irt["disc"]
    elif fn_utility == "diffdisc":
        return item_irt["diff"] * item_irt["disc"]
    elif fn_utility == "feas":
        return item_irt["feas"]
    raise ValueError(f"Unknown IRT utility function: {fn_utility!r}")
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def pyirt(data, metric, return_model=False, load_model=None, model="4pl_score", dropout=0.25, epochs=1000, enforce_positive_disc=False, device="cuda", **kwargs) -> Union[List[float], Tuple[List[float], Any]]:
    """
    Score items with an Item Response Theory model fitted by py-irt.

    Args:
        data: list of items; each needs "i" and "scores" ({system: {metric: value}}),
            plus "src" when an ``amortized_*`` model is used.
        metric: key inside each system's score dict that is fed to the IRT model.
        return_model: also return the fitted parameters (``data_irt``).
        load_model: a previously returned ``data_irt``; when given, no training happens.
        model: py-irt model type. Names without "_score" train on responses binarized
            at the global median of ``metric``.
        dropout, epochs: forwarded to the py-irt config / trainer.
        enforce_positive_disc: if the mean discrimination is negative, flip the signs of
            disc, ability and diff together.
        device: device passed to ``IrtModelTrainer.train``. It was previously hard-coded
            to "cuda", which remains the default.
        **kwargs: must contain ``fn_utility`` (see ``fn_irt_utility``).

    Returns:
        Per-item utility scores, or ``(scores, data_irt)`` if ``return_model``.

    Raises:
        Exception: if ``model`` is not registered in py-irt (the fork is required).
    """
    import py_irt
    import py_irt.config
    import py_irt.dataset
    import py_irt.io
    import py_irt.training
    import py_irt.models
    import py_irt.models.abstract_model
    import pandas as pd

    if model not in py_irt.models.abstract_model._IRT_REGISTRY:
        raise Exception("Please install py-irt with `pip install git+https://github.com/zouharvi/py-irt.git")

    systems = list(data[0]["scores"].keys())

    if load_model is not None:
        data_irt = load_model
    else:
        # we need median binarization if we are not using 4pl_score model
        median = np.median([
            system_v[metric]
            for line in data
            for system_v in line["scores"].values()
        ])
        dataset = pd.DataFrame({
            "system": systems,
            **{
                f"item_{line['i']}": [
                    line["scores"][system][metric]
                    if "_score" in model else
                    line["scores"][system][metric] >= median
                    for system in systems
                ]
                for line in data
            }
        })

        embeddings = None
        if "amortized_" in model:
            import sentence_transformers
            embd_model = sentence_transformers.SentenceTransformer("paraphrase-MiniLM-L12-v2")
            embeddings = embd_model.encode([line["src"] for line in data])
            embeddings = {f"item_{line['i']}": emb.tolist() for line, emb in zip(data, embeddings)}
            del embd_model

        dataset = py_irt.dataset.Dataset.from_pandas(
            dataset,
            subject_column="system",
            item_columns=[f"item_{line['i']}" for line in data],
            embeddings=embeddings,
        )

        config = py_irt.config.IrtConfig(
            model_type=model,
            log_every=100,
            dropout=dropout,
            # NOTE(review): "hiearchical" looks like a misspelling of "hierarchical";
            # kept as-is because the accepted values of the py-irt fork are not visible here.
            priors="hiearchical",
            seed=0,
            deterministic=True,
        )
        trainer = py_irt.training.IrtModelTrainer(
            config=config,
            data_path=None,
            dataset=dataset,
            verbose=False
        )
        trainer.train(epochs=epochs, device=device)

        params = trainer.best_params

        # this flipping should not affect the predictions
        if enforce_positive_disc and np.average(params["disc"]) < 0:
            params["disc"] = -np.array(params["disc"])
            params["ability"] = -np.array(params["ability"])
            params["diff"] = -np.array(params["diff"])

        # normalize naming
        if "lambdas" in params:
            params["feas"] = params.pop("lambdas")

        # TODO: cross-check make sure that we do the predictions as the models were trained
        if "feas" in params:
            # 3PL/4PL
            item_keys = ["disc", "diff", "feas"]
        elif "disc" in params:
            item_keys = ["disc", "diff"]
        else:
            item_keys = ["diff"]

        data_irt = {
            "systems": dict(zip(systems, params["ability"])),
            "items": [
                dict(zip(item_keys, values))
                for values in zip(*(params[key] for key in item_keys))
            ],
        }

    scores = [
        fn_irt_utility(item_old, item_irt, data_irt, kwargs["fn_utility"])
        for item_old, item_irt in zip(data, data_irt["items"])
    ]

    if return_model:
        return scores, data_irt
    else:
        return scores
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
def _assert_comet_version():
    """Raise unless the installed COMET exposes ``HypothesislessRegression`` (provided by the fork)."""
    import comet

    available_models = dir(comet.models)
    if "HypothesislessRegression" in available_models:
        return
    raise Exception("Please install COMET with `pip install git+https://github.com/zouharvi/comet-src.git`")
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
def precomet(data, model_path, return_model=False, load_model=None, reverse=False, **kwargs) -> Union[List, Tuple[List, Any]]:
    """
    Score each item's source text with a source-only (PreCOMET) COMET model.

    Args:
        data: items with a "src" field.
        model_path: local checkpoint path, or a name passed to ``comet.download_model``.
        return_model: also return the loaded model so it can be reused via ``load_model``.
        load_model: an already loaded model; skips loading from ``model_path``.
        reverse: negate the predicted scores.

    Returns:
        Per-item scores, or ``(scores, model)`` if ``return_model``.

    TQDM_DISABLE and logging are silenced while COMET runs. Both are restored
    afterwards, including when importing, loading or predicting raises (previously
    they leaked on error).
    """
    import os
    import logging
    import warnings

    prev_tqdm_setting = os.environ.get("TQDM_DISABLE", None)
    # Set before `import comet`, presumably so tqdm picks it up at import time.
    os.environ["TQDM_DISABLE"] = "1"
    logging.disable(logging.INFO)
    try:
        import comet

        _assert_comet_version()

        with warnings.catch_warnings(action="ignore"):
            if load_model is not None:
                model = load_model
            elif os.path.exists(model_path):
                model = comet.load_from_checkpoint(model_path)
            else:
                model = comet.load_from_checkpoint(comet.download_model(model_path))
            scores = model.predict([
                {"src": line["src"]}
                for line in data
            ], progress_bar=False).scores
    finally:
        logging.disable(logging.NOTSET)
        if prev_tqdm_setting is not None:
            os.environ["TQDM_DISABLE"] = prev_tqdm_setting
        else:
            os.environ.pop("TQDM_DISABLE", None)

    if reverse:
        scores = [-x for x in scores]

    if return_model:
        return scores, model
    else:
        return scores
|
|
218
|
+
|
|
219
|
+
|
|
220
|
+
def precomet_dual(data, model_path1, model_path2, return_model=False, load_model=None, reverse=False, **kwargs) -> Union[List, Tuple[List, Any]]:
    """
    Score each item by the product of two source-only (PreCOMET) COMET models.

    Args:
        data: items with a "src" field.
        model_path1, model_path2: local checkpoint paths, or names passed to
            ``comet.download_model``.
        return_model: also return ``(model1, model2)`` for reuse via ``load_model``.
        load_model: an already loaded ``(model1, model2)`` pair.
        reverse: negate the products.

    Returns:
        Per-item scores, or ``(scores, (model1, model2))`` if ``return_model``.

    TQDM_DISABLE and logging are silenced while COMET runs. Both are restored
    afterwards, including when importing, loading or predicting raises (previously
    they leaked on error).
    """
    import os
    import warnings
    import logging

    tqdm_disable_prev = os.environ.get("TQDM_DISABLE", None)
    # Set before `import comet`, presumably so tqdm picks it up at import time.
    os.environ["TQDM_DISABLE"] = "1"
    logging.disable(logging.INFO)
    try:
        import comet

        _assert_comet_version()

        with warnings.catch_warnings(action="ignore"):
            if load_model is not None:
                model1, model2 = load_model
            else:
                if os.path.exists(model_path1):
                    model1 = comet.load_from_checkpoint(model_path1)
                else:
                    model1 = comet.load_from_checkpoint(comet.download_model(model_path1))

                if os.path.exists(model_path2):
                    model2 = comet.load_from_checkpoint(model_path2)
                else:
                    model2 = comet.load_from_checkpoint(comet.download_model(model_path2))
            scores1 = model1.predict([
                {"src": line["src"]}
                for line in data
            ], progress_bar=False).scores
            scores2 = model2.predict([
                {"src": line["src"]}
                for line in data
            ], progress_bar=False).scores
    finally:
        logging.disable(logging.NOTSET)
        if tqdm_disable_prev is not None:
            os.environ["TQDM_DISABLE"] = tqdm_disable_prev
        else:
            os.environ.pop("TQDM_DISABLE", None)

    if reverse:
        scores = [-s1 * s2 for s1, s2 in zip(scores1, scores2)]
    else:
        scores = [s1 * s2 for s1, s2 in zip(scores1, scores2)]

    if return_model:
        return scores, (model1, model2)
    else:
        return scores
|
|
269
|
+
|
|
270
|
+
|
|
271
|
+
def diversity_unigram(data, **kwargs) -> List[float]:
|
|
272
|
+
import itertools
|
|
273
|
+
import collections
|
|
274
|
+
|
|
275
|
+
def _f(line):
|
|
276
|
+
out = []
|
|
277
|
+
for text_a, text_b in itertools.combinations(line["tgt"].values(), 2):
|
|
278
|
+
text_a = collections.Counter(text_a.split())
|
|
279
|
+
text_b = collections.Counter(text_b.split())
|
|
280
|
+
if text_a.total() == 0 or text_b.total() == 0:
|
|
281
|
+
out.append(1)
|
|
282
|
+
else:
|
|
283
|
+
out.append(2 * (text_a & text_b).total() / (text_a.total() + text_b.total()))
|
|
284
|
+
return np.average(out)
|
|
285
|
+
|
|
286
|
+
# we prefer smallest similarity so flip
|
|
287
|
+
return [
|
|
288
|
+
-_f(line)
|
|
289
|
+
for line in data
|
|
290
|
+
]
|
|
291
|
+
|
|
292
|
+
|
|
293
|
+
def _pairwise_similarity(data, metric) -> List[float]:
    """
    Negated average pairwise ``metric.sentence_score`` over the outputs of each item.

    Every ordered pair from ``line["tgt"]`` is scored as ``(hypothesis, [reference])``,
    including each output against itself, in ``itertools.product`` order. The average
    is negated because lower similarity (more diversity) is preferred.
    """
    import itertools

    def _avg_similarity(line):
        outputs = list(line["tgt"].values())
        return np.average([
            metric.sentence_score(text_a, [text_b]).score
            for text_a, text_b in itertools.product(outputs, outputs)
        ])

    # we prefer smallest similarity so flip
    return [-_avg_similarity(line) for line in data]


def diversity_bleu(data, **kwargs) -> List[float]:
    """Negated average pairwise sentence BLEU (``effective_order=True``) between each item's outputs."""
    import sacrebleu
    return _pairwise_similarity(data, sacrebleu.metrics.BLEU(effective_order=True))


def diversity_chrf(data, **kwargs) -> List[float]:
    """Negated average pairwise sentence chrF between each item's outputs."""
    import sacrebleu
    return _pairwise_similarity(data, sacrebleu.metrics.CHRF())
|
|
333
|
+
|
|
334
|
+
|
|
335
|
+
# Registry of subset-selection methods: name -> callable taking ``data`` (plus keyword
# arguments) and returning one score per item. Key order matches the original literal.
METHODS = dict(
    random=random_subset,
    metric_avg=metric_avg,
    metric_var=metric_var,
    diversity_bleu=diversity_bleu,
    diversity_chrf=diversity_chrf,
    diversity_unigram=diversity_unigram,

    pyirt_diff=partial(pyirt, fn_utility="diff"),
    pyirt_disc=partial(pyirt, fn_utility="disc"),
    pyirt_diffdisc=partial(pyirt, fn_utility="diffdisc"),
    pyirt_feas=partial(pyirt, fn_utility="feas"),
    pyirt_fic=partial(pyirt, fn_utility="fisher_information_content"),
    # NOTE(review): fn_irt_utility has no "experiment" branch, so this entry does not
    # produce usable scores.
    pyirt_experiment=partial(pyirt, fn_utility="experiment"),

    precomet_var=partial(precomet, model_path="zouharvi/PreCOMET-var", reverse=True),
    precomet_avg=partial(precomet, model_path="zouharvi/PreCOMET-avg", reverse=True),
    precomet_diversity=partial(precomet, model_path="zouharvi/PreCOMET-diversity", reverse=True),

    precomet_diff=partial(precomet, model_path="zouharvi/PreCOMET-diff", reverse=False),
    precomet_disc=partial(precomet, model_path="zouharvi/PreCOMET-disc", reverse=True),
    precomet_diffdisc_direct=partial(precomet, model_path="zouharvi/PreCOMET-diffdisc_direct", reverse=False),
    precomet_diffdisc=partial(
        precomet_dual,
        model_path1="zouharvi/PreCOMET-diff",
        model_path2="zouharvi/PreCOMET-disc",
        reverse=False,
    ),
)
|