subset2evaluate 0.0.1a6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,6 @@
+ # default imports
+ # flake8: noqa F401
+ import subset2evaluate.utils
+ import subset2evaluate.evaluate
+ import subset2evaluate.select_subset
+ import subset2evaluate.methods
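
For orientation, the package's modules compose as follows. A minimal sketch, assuming the wmt23/en-cs descriptor that the CLI below uses as its default; run_select_subset is called exactly as precompute_randnorm calls it in the next file:

    import numpy as np
    import subset2evaluate

    # load the benchmark by descriptor (load_data also accepts an already-loaded list)
    data_old = subset2evaluate.utils.load_data("wmt23/en-cs")
    # order items by estimated utility, here with the random baseline
    data_new = subset2evaluate.select_subset.run_select_subset(data_old, method="random", seed=0)
    # meta-evaluate: cluster count and pairwise accuracy of subsets vs. the full data
    clu_new, acc_new = subset2evaluate.evaluate.run_evaluate_cluacc(data_new, data_old, metric="human")
    print(f"Clusters: {np.average(clu_new):.2f}, Accuracy: {np.average(acc_new):.1%}")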
@@ -0,0 +1,272 @@
+ from typing import Dict, List, Tuple
+ import numpy as np
+ import subset2evaluate
+ import subset2evaluate.utils as utils
+
+
+ def run_evaluate_cluacc(data_new: List[Dict], data_old: List[Dict], metric="human", props: List[float] = utils.PROPS) -> Tuple[List[float], List[float]]:
+     # both a loaded list and a dataset descriptor are fine
+     data_new = utils.load_data(data_new)
+     data_old = utils.load_data(data_old)
+
+     clu_new = []
+     acc_new = []
+     for prop in props:
+         k = int(len(data_old) * prop)
+         clu_new.append(eval_subset_clusters(data_new[:k], metric=metric))
+         acc_new.append(eval_subset_accuracy(data_new[:k], data_old, metric=metric))
+
+     return clu_new, acc_new
+
+
+ def run_evaluate_cluacc_par(
+     data_new: List[Dict],
+     data_old: List[Dict],
+     clus_tgt: List[float],
+     accs_tgt: List[float],
+     metric="human",
+     props: List[float] = utils.PROPS,
+     workers=10,
+ ) -> Tuple[float, float]:
+     """
+     Evaluates the proportion of data that is needed to achieve parity with the target.
+     """
+     import multiprocessing.pool
+
+     # both a loaded list and a dataset descriptor are fine
+     data_new = utils.load_data(data_new)
+     data_old = utils.load_data(data_old)
+
+     def _par_clu(data_new, clu_tgt, metric):
+         for k in range(5, len(data_new) + 1):
+             if eval_subset_clusters(data_new[:k], metric=metric) >= clu_tgt:
+                 break
+         return k
+
+     def _par_acc(data_new, data_old, acc_tgt, metric):
+         for k in range(5, len(data_new) + 1):
+             if eval_subset_accuracy(data_new[:k], data_old, metric=metric) >= acc_tgt:
+                 break
+         return k
+
+     # parallelize across props rather than across k because the thread
+     # orchestration would otherwise be more expensive
+     with multiprocessing.pool.ThreadPool(min(workers, len(props))) as pool:
+         ks_clu_par = pool.starmap(
+             _par_clu,
+             [(data_new, clu_tgt, metric) for clu_tgt in clus_tgt]
+         )
+         ks_clu_par = [k / (len(data_old) * prop) for k, prop in zip(ks_clu_par, props)]
+
+         ks_acc_par = pool.starmap(
+             _par_acc,
+             [(data_new, data_old, acc_tgt, metric) for acc_tgt in accs_tgt]
+         )
+         ks_acc_par = [k / (len(data_old) * prop) for k, prop in zip(ks_acc_par, props)]
+
+     return np.average(ks_clu_par), np.average(ks_acc_par)
+
+
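+ # Random-baseline normalization, described briefly: average the cluster and
+ # accuracy curves of random subsets (seeds 0..random_seeds-1), then measure on
+ # fresh seeds what proportion of data a random subset needs to reach those
+ # same targets. run_evaluate_cluacc_randnorm divides a method's parity by
+ # this normalizer.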
+ def precompute_randnorm(
+     data_old: List[Dict],
+     random_seeds=10,
+     metric="human",
+     workers=10,
+ ) -> Tuple[Tuple[List[float], List[float]], Tuple[float, float]]:
+     import subset2evaluate.select_subset
+
+     clu_random = []
+     acc_random = []
+     for seed in range(random_seeds):
+         clu_new, acc_new = run_evaluate_cluacc(
+             subset2evaluate.select_subset.run_select_subset(data_old, method="random", seed=seed),
+             data_old,
+             metric=metric,
+         )
+         clu_random.append(clu_new)
+         acc_random.append(acc_new)
+     clu_random = np.average(clu_random, axis=0)
+     acc_random = np.average(acc_random, axis=0)
+
+     pars_clu_rand = []
+     pars_acc_rand = []
+
+     for seed in range(random_seeds, 2 * random_seeds):
+         par_clu_rand, par_acc_rand = run_evaluate_cluacc_par(
+             subset2evaluate.select_subset.run_select_subset(data_old, method="random", seed=seed),
+             data_old,
+             clu_random,
+             acc_random,
+             metric=metric,
+             workers=workers,
+         )
+         pars_clu_rand.append(par_clu_rand)
+         pars_acc_rand.append(par_acc_rand)
+
+     return (clu_random, acc_random), (np.average(pars_clu_rand), np.average(pars_acc_rand))
+
+
+ def run_evaluate_cluacc_randnorm(
+     data_new: List[Dict],
+     data_old: List[Dict],
+     random_seeds=10,
+     metric="human",
+     cluacc_precomputed=None,
+ ) -> Tuple[float, float]:
+     if cluacc_precomputed is not None:
+         (clu_random, acc_random), (clu_random_norm, acc_random_norm) = cluacc_precomputed
+     else:
+         (clu_random, acc_random), (clu_random_norm, acc_random_norm) = precompute_randnorm(data_old, random_seeds=random_seeds, metric=metric)
+
+     # compute the parity of the new data and normalize by the random baseline
+     par_clu, par_acc = run_evaluate_cluacc_par(
+         data_new, data_old,
+         clu_random, acc_random,
+         metric=metric,
+     )
+
+     return par_clu / clu_random_norm, par_acc / acc_random_norm
+
+
+ def run_evaluate_top_timebudget(data_new, data_old, metric="human"):
+     # both a loaded list and a dataset descriptor are fine
+     data_old = utils.load_data(data_old)
+     data_new = utils.load_data(data_new)
+
+     clu_new = []
+     acc_new = []
+     for prop in utils.PROPS:
+         # the time budget is the number of items that proportion would afford
+         budget = int(len(data_old) * prop)
+         data_new_inbudget = []
+         for item in data_new:
+             if item["time"] <= budget:
+                 budget -= item["time"]
+                 data_new_inbudget.append(item)
+             else:
+                 break
+         clu_new.append(eval_subset_clusters(data_new_inbudget, metric=metric))
+         acc_new.append(eval_subset_accuracy(data_new_inbudget, data_old, metric=metric))
+
+     return clu_new, acc_new
+
+
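+ # Pairwise ranking accuracy: the fraction of system pairs that the subset
+ # orders the same way as the full data, i.e. the average over pairs (s1, s2)
+ # of [(old[s1] < old[s2]) == (new[s1] < new[s2])].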
+ def eval_subset_accuracy(data_new: List[Dict], data_old: List[Dict], metric="human"):
+     # evaluates against the ordering from data_old
+     import itertools
+
+     systems = list(data_old[0]["scores"].keys())
+
+     scores_old = get_sys_absolute(data_old, metric=metric)
+     scores_new = get_sys_absolute(data_new, metric=metric)
+
+     result = []
+     for sys1, sys2 in itertools.combinations(systems, 2):
+         result.append((scores_old[sys1] < scores_old[sys2]) == (scores_new[sys1] < scores_new[sys2]))
+
+     return np.average(result)
+
+
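+ # Cluster count: sort systems from best to worst, then greedily grow the
+ # current cluster; a system starts a new cluster when a one-sided Wilcoxon
+ # test finds its scores significantly below those of the cluster's last
+ # member (p < 0.05). More clusters means a more discriminative subset.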
+ def eval_subset_clusters(data: List[Dict], metric="human"):
+     from scipy.stats import wilcoxon
+     import warnings
+
+     # with fewer than three items we can't meaningfully test for clusters
+     if len(data) < 3:
+         return 1
+
+     # sort systems from best to worst
+     sys_ord = list(get_sys_absolute(data, metric=metric).items())
+     sys_ord.sort(key=lambda x: x[1], reverse=True)
+     sys_ord = [sys for sys, _ in sys_ord]
+
+     def get_scores(system):
+         return [line["scores"][system][metric] for line in data]
+
+     clusters = [[get_scores(sys_ord.pop(0))]]
+     while sys_ord:
+         sys_scores = get_scores(sys_ord.pop(0))
+         diffs = [x - y for x, y in zip(sys_scores, clusters[-1][-1])]
+
+         with warnings.catch_warnings(action="ignore"):
+             if all([d == 0 for d in diffs]) or wilcoxon(diffs, alternative="less").pvalue < 0.05:
+                 clusters.append([sys_scores])
+             else:
+                 clusters[-1].append(sys_scores)
+
+     return len(clusters)
+
+
+ def get_sys_absolute(data_new, metric="human") -> Dict[str, float]:
+     import collections
+
+     scores_new = collections.defaultdict(list)
+
+     systems = list(data_new[0]["scores"].keys())
+     for line in data_new:
+         for sys in systems:
+             scores_new[sys].append(line["scores"][sys][metric])
+
+     scores_new = {
+         sys: np.average(scores_new[sys])
+         for sys in systems
+     }
+
+     return scores_new
+
+
+ def get_sys_ordering(data_new: List[Dict], metric="human"):
+     scores_new = get_sys_absolute(data_new, metric)
+
+     # sort from the highest score to get the ordering
+     scores_new = list(scores_new.items())
+     scores_new.sort(key=lambda x: x[1], reverse=True)
+
+     sys_ord = {
+         sys: sys_i
+         for sys_i, (sys, sys_v) in enumerate(scores_new)
+     }
+
+     return sys_ord
+
+
+ def eval_order_accuracy(scores_new: Dict[str, float], scores_old: Dict[str, float]):
+     # evaluates against the ordering from scores_old
+     import itertools
+
+     systems = list(scores_old.keys())
+
+     result = []
+     for sys1, sys2 in itertools.combinations(systems, 2):
+         result.append((scores_old[sys1] < scores_old[sys2]) == (scores_new[sys1] < scores_new[sys2]))
+
+     return np.average(result)
+
+
+ def main_cli():
+     import argparse
+
+     args = argparse.ArgumentParser(
+         description="Meta-evaluate subset selection methods with cluster count and system accuracy."
+     )
+     args.add_argument(
+         'data_old', type=str, default='wmt23/en-cs',
+         help="Original data descriptor or path."
+     )
+     args.add_argument(
+         'data_new', type=str, default='wmt23/en-cs',
+         help="Path to the new, ordered data."
+     )
+     args.add_argument(
+         '--metric', type=str, default='human',
+         help="Metric to evaluate against, e.g., human or human_consistency. Can also be an automatic metric instead of a human score."
+     )
+     args = args.parse_args()
+
+     # note the argument order: run_evaluate_cluacc takes the new data first
+     clu_new, acc_new = run_evaluate_cluacc(args.data_new, args.data_old, metric=args.metric)
+
+     print(f"Clusters: {np.average(clu_new):.2f}")
+     print(f"Accuracy: {np.average(acc_new):.1%}")
@@ -0,0 +1,363 @@
+ from typing import Any, List, Tuple, Union
+ from functools import partial
+ import numpy as np
+ import os
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+
+ def random_subset(data, seed=None, **kwargs) -> List[float]:
+     import random
+     r = random.Random(seed)
+     return [r.random() for _ in data]
+
+
+ def metric_avg(data, metric, **kwargs) -> List[float]:
+     # flip the sign so that low-scoring (harder) items get high utility
+     return [
+         -np.average([sys_v[metric] for sys_v in item["scores"].values()])
+         for item in data
+     ]
+
+
+ def metric_var(data, metric, **kwargs) -> List[float]:
+     return [
+         np.var([sys_v[metric] for sys_v in item["scores"].values()])
+         for item in data
+     ]
+
+
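+ # Fisher information of a 2PL item, summed over the estimated system
+ # abilities. Algebraically, disc^2 * x1 / (x2 + x3)^2 below equals
+ # disc^2 * p * (1 - p) with p = sigmoid(disc * (theta - diff)).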
+ def _fn_information_content(item_old, item_irt, data_irt) -> float:
+     information = 0
+     for theta in data_irt["systems"].values():
+         x1 = np.exp(item_irt["disc"] * (theta + item_irt["diff"]))
+         x2 = np.exp(item_irt["disc"] * item_irt["diff"])
+         x3 = np.exp(item_irt["disc"] * theta)
+         information += (item_irt["disc"] ** 2) * x1 / (x2 + x3) ** 2
+     return information
+
+
+ def fn_irt_utility(item_old, item_irt, data_irt, fn_utility) -> float:
+     if fn_utility == "fisher_information_content":
+         return _fn_information_content(item_old, item_irt, data_irt)
+     elif fn_utility == "diff":
+         return -item_irt["diff"]
+     elif fn_utility == "disc":
+         return -item_irt["disc"]
+     elif fn_utility == "diffdisc":
+         return item_irt["diff"] * item_irt["disc"]
+     elif fn_utility == "feas":
+         return item_irt["feas"]
+     # unknown utilities (e.g. the "experiment" placeholder) fall through to None
+
+
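+ # Fit an IRT model over the system-by-item response matrix with py-irt and
+ # score each item by the utility chosen via kwargs["fn_utility"]. The
+ # continuous *_score models consume metric scores directly; the other models
+ # see responses binarized at the global median.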
+ def pyirt(
+     data, metric, return_model=False, load_model=None, model="4pl_score",
+     dropout=0.25, epochs=1000, enforce_positive_disc=False, **kwargs
+ ) -> Union[List[float], Tuple[List[float], Any]]:
+     import py_irt
+     import py_irt.config
+     import py_irt.dataset
+     import py_irt.io
+     import py_irt.training
+     import py_irt.models
+     import py_irt.models.abstract_model
+     import pandas as pd
+
+     if model not in py_irt.models.abstract_model._IRT_REGISTRY:
+         raise Exception("Please install py-irt with `pip install git+https://github.com/zouharvi/py-irt.git`")
+
+     systems = list(data[0]["scores"].keys())
+
+     if load_model is not None:
+         data_irt = load_model
+     else:
+         # models other than the continuous *_score ones need the responses
+         # binarized at the global median
+         median = np.median([
+             system_v[metric]
+             for line in data
+             for system_v in line["scores"].values()
+         ])
+         dataset = pd.DataFrame({
+             "system": systems,
+             **{
+                 f"item_{line['i']}": [
+                     line["scores"][system][metric]
+                     if "_score" in model else
+                     line["scores"][system][metric] >= median
+                     for system in systems
+                 ]
+                 for line in data
+             }
+         })
+
+         embeddings = None
+         if "amortized_" in model:
+             import sentence_transformers
+             embd_model = sentence_transformers.SentenceTransformer("paraphrase-MiniLM-L12-v2")
+             embeddings = embd_model.encode([line["src"] for line in data])
+             embeddings = {f"item_{line['i']}": emb.tolist() for line, emb in zip(data, embeddings)}
+             del embd_model
+
+         dataset = py_irt.dataset.Dataset.from_pandas(
+             dataset,
+             subject_column="system",
+             item_columns=[f"item_{line['i']}" for line in data],
+             embeddings=embeddings,
+         )
+
+         config = py_irt.config.IrtConfig(
+             model_type=model,
+             log_every=100,
+             dropout=dropout,
+             priors="hiearchical",  # (sic)
+             seed=0,
+             deterministic=True,
+         )
+         trainer = py_irt.training.IrtModelTrainer(
+             config=config,
+             data_path=None,
+             dataset=dataset,
+             verbose=False,
+         )
+         trainer.train(epochs=epochs, device='cuda')
+
+         params = trainer.best_params
+
+         # flipping all signs at once should not affect the predictions
+         if enforce_positive_disc and np.average(params["disc"]) < 0:
+             params["disc"] = -np.array(params["disc"])
+             params["ability"] = -np.array(params["ability"])
+             params["diff"] = -np.array(params["diff"])
+
+         # normalize naming
+         if "lambdas" in params:
+             params["feas"] = params.pop("lambdas")
+
+         # TODO: cross-check that the predictions are computed the same way the models were trained
+         if "feas" in params:
+             # 3PL/4PL model
+             data_irt = {
+                 "systems": {sys: sys_v for sys, sys_v in zip(systems, params["ability"])},
+                 "items": [
+                     {"disc": disc, "diff": diff, "feas": feas}
+                     for disc, diff, feas in zip(
+                         params["disc"],
+                         params["diff"],
+                         params["feas"],
+                     )
+                 ]
+             }
+         elif "disc" in params:
+             # 2PL model
+             data_irt = {
+                 "systems": {sys: sys_v for sys, sys_v in zip(systems, params["ability"])},
+                 "items": [
+                     {"disc": disc, "diff": diff}
+                     for disc, diff in zip(
+                         params["disc"],
+                         params["diff"],
+                     )
+                 ]
+             }
+         else:
+             # 1PL model
+             data_irt = {
+                 "systems": {sys: sys_v for sys, sys_v in zip(systems, params["ability"])},
+                 "items": [
+                     {"diff": diff}
+                     for diff in params["diff"]
+                 ]
+             }
+
+     scores = [
+         fn_irt_utility(item_old, item_irt, data_irt, kwargs["fn_utility"])
+         for item_old, item_irt in zip(data, data_irt["items"])
+     ]
+
+     if return_model:
+         return scores, data_irt
+     else:
+         return scores
+
+
+ def _assert_comet_version():
+     import comet
+     if "HypothesislessRegression" not in dir(comet.models):
+         raise Exception("Please install COMET with `pip install git+https://github.com/zouharvi/comet-src.git`")
+
+
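+ # The precomet* methods predict item utility from the source text alone via
+ # PreCOMET models (COMET checkpoints that only require "src"). Judging from
+ # the METHODS registry below, reverse=True is used for models whose raw
+ # prediction should be minimized, so the sign flip makes higher always mean
+ # more useful.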
+ def precomet(data, model_path, return_model=False, load_model=None, reverse=False, **kwargs) -> Union[List, Tuple[List, Any]]:
+     import os
+     prev_tqdm_setting = os.environ.get("TQDM_DISABLE", None)
+     os.environ["TQDM_DISABLE"] = "1"
+
+     import logging
+     import comet
+     import warnings
+
+     logging.disable(logging.INFO)
+     _assert_comet_version()
+
+     with warnings.catch_warnings(action="ignore"):
+         if load_model is not None:
+             model = load_model
+         elif os.path.exists(model_path):
+             model = comet.load_from_checkpoint(model_path)
+         else:
+             model = comet.load_from_checkpoint(comet.download_model(model_path))
+         scores = model.predict([
+             {"src": line["src"]}
+             for line in data
+         ], progress_bar=False).scores
+     if reverse:
+         scores = [-x for x in scores]
+
+     logging.disable(logging.NOTSET)
+     if prev_tqdm_setting is not None:
+         os.environ["TQDM_DISABLE"] = prev_tqdm_setting
+     else:
+         os.environ.pop("TQDM_DISABLE")
+
+     if return_model:
+         return scores, model
+     else:
+         return scores
+
+
+ def precomet_dual(data, model_path1, model_path2, return_model=False, load_model=None, reverse=False, **kwargs) -> Union[List, Tuple[List, Any]]:
+     import os
+     tqdm_disable_prev = os.environ.get("TQDM_DISABLE", None)
+     os.environ["TQDM_DISABLE"] = "1"
+
+     import comet
+     import warnings
+     import logging
+
+     logging.disable(logging.INFO)
+     _assert_comet_version()
+
+     with warnings.catch_warnings(action="ignore"):
+         if load_model is not None:
+             model1, model2 = load_model
+         else:
+             if os.path.exists(model_path1):
+                 model1 = comet.load_from_checkpoint(model_path1)
+             else:
+                 model1 = comet.load_from_checkpoint(comet.download_model(model_path1))
+
+             if os.path.exists(model_path2):
+                 model2 = comet.load_from_checkpoint(model_path2)
+             else:
+                 model2 = comet.load_from_checkpoint(comet.download_model(model_path2))
+         scores1 = model1.predict([
+             {"src": line["src"]}
+             for line in data
+         ], progress_bar=False).scores
+         scores2 = model2.predict([
+             {"src": line["src"]}
+             for line in data
+         ], progress_bar=False).scores
+
+     if reverse:
+         scores = [-s1 * s2 for s1, s2 in zip(scores1, scores2)]
+     else:
+         scores = [s1 * s2 for s1, s2 in zip(scores1, scores2)]
+
+     logging.disable(logging.NOTSET)
+     if tqdm_disable_prev is not None:
+         os.environ["TQDM_DISABLE"] = tqdm_disable_prev
+     else:
+         os.environ.pop("TQDM_DISABLE")
+
+     if return_model:
+         return scores, (model1, model2)
+     else:
+         return scores
+
+
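+ # The diversity_* methods score an item by how dissimilar the systems'
+ # outputs are to each other: lower similarity means higher utility, hence
+ # the sign flips below. The unigram variant is a Dice overlap of token
+ # counts between output pairs.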
+ def diversity_unigram(data, **kwargs) -> List[float]:
+     import itertools
+     import collections
+
+     def _f(line):
+         out = []
+         for text_a, text_b in itertools.combinations(line["tgt"].values(), 2):
+             text_a = collections.Counter(text_a.split())
+             text_b = collections.Counter(text_b.split())
+             if text_a.total() == 0 or text_b.total() == 0:
+                 out.append(1)
+             else:
+                 out.append(2 * (text_a & text_b).total() / (text_a.total() + text_b.total()))
+         return np.average(out)
+
+     # we prefer the smallest similarity, so flip the sign
+     return [
+         -_f(line)
+         for line in data
+     ]
+
+
+ def diversity_bleu(data, **kwargs) -> List[float]:
+     import itertools
+     import sacrebleu
+     metric = sacrebleu.metrics.BLEU(effective_order=True)
+
+     def _f(line):
+         # the product includes self-pairs, which adds only a constant offset
+         return np.average([
+             metric.sentence_score(
+                 text_a,
+                 [text_b],
+             ).score
+             for text_a, text_b in itertools.product(line["tgt"].values(), line["tgt"].values())
+         ])
+
+     # we prefer the smallest similarity, so flip the sign
+     return [
+         -_f(line)
+         for line in data
+     ]
+
+
+ def diversity_chrf(data, **kwargs) -> List[float]:
+     import itertools
+     import sacrebleu
+     metric = sacrebleu.metrics.CHRF()
+
+     def _f(line):
+         return np.average([
+             metric.sentence_score(
+                 text_a,
+                 [text_b],
+             ).score
+             for text_a, text_b in itertools.product(line["tgt"].values(), line["tgt"].values())
+         ])
+
+     # we prefer the smallest similarity, so flip the sign
+     return [
+         -_f(line)
+         for line in data
+     ]
+
+
+ METHODS = {
+     "random": random_subset,
+     "metric_avg": metric_avg,
+     "metric_var": metric_var,
+     "diversity_bleu": diversity_bleu,
+     "diversity_chrf": diversity_chrf,
+     "diversity_unigram": diversity_unigram,
+
+     "pyirt_diff": partial(pyirt, fn_utility="diff"),
+     "pyirt_disc": partial(pyirt, fn_utility="disc"),
+     "pyirt_diffdisc": partial(pyirt, fn_utility="diffdisc"),
+     "pyirt_feas": partial(pyirt, fn_utility="feas"),
+     "pyirt_fic": partial(pyirt, fn_utility="fisher_information_content"),
+     "pyirt_experiment": partial(pyirt, fn_utility="experiment"),
+
+     "precomet_var": partial(precomet, model_path="zouharvi/PreCOMET-var", reverse=True),
+     "precomet_avg": partial(precomet, model_path="zouharvi/PreCOMET-avg", reverse=True),
+     "precomet_diversity": partial(precomet, model_path="zouharvi/PreCOMET-diversity", reverse=True),
+
+     "precomet_diff": partial(precomet, model_path="zouharvi/PreCOMET-diff", reverse=False),
+     "precomet_disc": partial(precomet, model_path="zouharvi/PreCOMET-disc", reverse=True),
+     "precomet_diffdisc_direct": partial(precomet, model_path="zouharvi/PreCOMET-diffdisc_direct", reverse=False),
+     "precomet_diffdisc": partial(
+         precomet_dual,
+         model_path1="zouharvi/PreCOMET-diff",
+         model_path2="zouharvi/PreCOMET-disc",
+         reverse=False,
+     ),
+ }
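
Each METHODS entry maps a method name to a callable returning one utility score per item. A sketch of the calling convention, assuming this file is the subset2evaluate.methods module imported in the __init__ above (the selection wrapper run_select_subset lives in subset2evaluate.select_subset, which is not part of this diff):

    import subset2evaluate.methods as methods
    import subset2evaluate.utils as utils

    data = utils.load_data("wmt23/en-cs")
    # each method takes the data plus method-specific kwargs
    utility = methods.METHODS["metric_var"](data, metric="human")
    # order items with the highest utility first, e.g. to take a top-k subset
    data_ordered = [item for _, item in sorted(zip(utility, data), key=lambda x: -x[0])]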