subset2evaluate 0.0.1a6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- subset2evaluate/__init__.py +6 -0
- subset2evaluate/evaluate.py +272 -0
- subset2evaluate/methods.py +363 -0
- subset2evaluate/methods_old.py +436 -0
- subset2evaluate/select_subset.py +101 -0
- subset2evaluate/test.py +66 -0
- subset2evaluate/utils.py +410 -0
- subset2evaluate-0.0.1a6.dist-info/METADATA +210 -0
- subset2evaluate-0.0.1a6.dist-info/RECORD +12 -0
- subset2evaluate-0.0.1a6.dist-info/WHEEL +5 -0
- subset2evaluate-0.0.1a6.dist-info/entry_points.txt +3 -0
- subset2evaluate-0.0.1a6.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,436 @@
|
|
|
1
|
+
# These are some subset selection methods that are not polished enough to be used in practice
|
|
2
|
+
from typing import Any, Callable, List, Tuple, Union
|
|
3
|
+
from functools import partial
|
|
4
|
+
import subset2evaluate.utils as utils
|
|
5
|
+
import subset2evaluate
|
|
6
|
+
import subset2evaluate.evaluate
|
|
7
|
+
import numpy as np
|
|
8
|
+
import random
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def _fn_information_content_old(item_irt, data_irt) -> float:
    """
    Total IRT information of one item, accumulated over all systems.

    Follows the simplified formula of Rodriguez et al. (2021): for every system
    ability ``theta`` in ``data_irt["systems"]``, add ``p * (1 - p) * disc**2``
    where ``p = utils.pred_irt(theta, item_irt)``.

    Args:
        item_irt: IRT parameters of the item (must contain ``"disc"``).
        data_irt: fitted IRT parameters; ``data_irt["systems"]`` maps systems to abilities.

    Returns:
        The summed information (0 when there are no systems).
    """
    total_information = 0
    for system_theta in data_irt["systems"].values():
        p_correct = utils.pred_irt(
            system_theta,
            item_irt,
        )
        total_information += p_correct * (1 - p_correct) * (item_irt["disc"] ** 2)
    return total_information
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def metric_consistency(data, metric, **kwargs) -> List[float]:
    """
    Score every item by how well it reproduces the dataset-level system order.

    The dataset-level score of each system comes from
    ``subset2evaluate.evaluate.get_sys_absolute(data, metric=metric)``. For each
    item and for every unordered pair of systems, the item contributes +1 if its
    own ``metric`` scores order the pair strictly the same way as the
    dataset-level scores. Otherwise it contributes -1, so ties count as
    disagreement. The sum is divided by the number of pairs, so each value lies
    in [-1, 1].

    Args:
        data: list of items with ``"i"`` and ``"scores"[system][metric]``.
        metric: metric name to compare.
        **kwargs: ignored; accepted for a uniform method interface.

    Returns:
        One consistency value per item, in the order of ``data``. With fewer
        than two systems there are no pairs, and every item gets 0.0. Previously
        this case raised ZeroDivisionError.
    """
    metric_scores = subset2evaluate.evaluate.get_sys_absolute(data, metric=metric)
    sys_names = list(metric_scores.keys())
    # number of unordered system pairs; identical for every item
    n_pairs = len(sys_names) * (len(sys_names) - 1) // 2

    rank_correlation = {}
    for example in data:
        consistency = 0
        for i, sys_a in enumerate(sys_names):
            for sys_b in sys_names[i + 1:]:
                global_diff = metric_scores[sys_a] - metric_scores[sys_b]
                local_diff = example['scores'][sys_a][metric] - example['scores'][sys_b][metric]
                # strictly same sign -> agreement; opposite sign or tie -> disagreement
                consistency += 1 if global_diff * local_diff > 0 else -1
        rank_correlation[example['i']] = consistency / n_pairs if n_pairs else 0.0

    # keyed by item id, then re-read in data order (last item wins on duplicate ids, as before)
    return [
        rank_correlation[item['i']]
        for item in data
    ]
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def nn_irt(data, metric, **kwargs):
    """
    Draft: neural IRT (CAIMIRA-style) training via the ``neural_irt`` package.

    Not usable: the first statement unconditionally raises
    ``NotImplementedError``. Everything after it is unreachable prototype code
    kept for reference. Even if enabled, the draft trains a model, prints the
    best checkpoint and returns ``None``. It does not return per-item utilities
    the way the other methods do.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError("This method is not yet implemented for use.")
    import torch
    import neural_irt.train
    from neural_irt.lit_module import IrtLitModule
    from neural_irt.data import collators, datasets
    from neural_irt.configs.common import IrtModelConfig
    from torch.utils import data as torch_data
    import wandb
    from lightning.pytorch.loggers import WandbLogger
    from pytorch_lightning.callbacks import ModelCheckpoint
    from sklearn.model_selection import train_test_split
    from typing import Any, Optional, Sequence
    from pydantic import BaseModel

    BATCH_SIZE = 256

    # system names are taken from the first item only
    systems = list(data[0]["scores"].keys())

    # Flatten items into one record per (item, system) response.
    # The item is represented by a one-hot vector of its "i" over len(data) classes.
    def wrangle_data(data_local):
        data_out = []
        for line in data_local:
            for sys, sys_v in line["scores"].items():
                data_out.append({
                    "agent_name": sys,
                    "agent_id": systems.index(sys),
                    "agent_type": "general",
                    "query_id": line["i"],
                    "query_rep": torch.nn.functional.one_hot(torch.tensor([line["i"]]), num_classes=len(data)).float(),
                    "ruling": sys_v[metric],
                })
        return data_out

    # NOTE(review): the `str | ...` annotation is evaluated when this def runs
    # and needs Python >= 3.10.
    def create_agent_indexer_from_dataset(
        dataset_or_path: str | Sequence[dict[str, Any]],
    ) -> neural_irt.train.AgentIndexer:
        dataset = dataset_or_path
        if isinstance(dataset, str):
            dataset = datasets.load_as_hf_dataset(dataset_or_path)

        # NOTE: this was entry["id"] and entry["type"] before
        agent_ids = [entry["agent_id"] for entry in dataset]
        agent_types = list({entry["agent_type"] for entry in dataset})
        agent_type_map = {entry["agent_id"]: entry["agent_type"] for entry in dataset}
        return neural_irt.train.AgentIndexer(agent_ids, agent_types, agent_type_map)

    # 90/10 split of the items (not of the flattened responses), fixed seed
    data_train, data_test = train_test_split(data, test_size=0.1, random_state=0)
    data_train = wrangle_data(data_train)
    data_test = wrangle_data(data_test)
    agent_indexer = create_agent_indexer_from_dataset(data_train + data_test)
    train_collator = collators.CaimiraCollator(agent_indexer, is_training=True)

    # TODO: very likely will fail here because the data is not wrangled properly
    train_loader = torch_data.DataLoader(
        data_train,
        batch_size=BATCH_SIZE,
        shuffle=True,
        collate_fn=train_collator,
        num_workers=1,
    )
    # NOTE(review): the validation loader also shuffles and reuses the
    # training collator (is_training=True); confirm this is intended.
    val_loaders_dict = {
        "val": torch_data.DataLoader(
            data_test,
            batch_size=BATCH_SIZE,
            shuffle=True,
            collate_fn=train_collator,
            num_workers=1,
        )
    }
    val_loader_names = list(val_loaders_dict.keys())
    val_loaders = [val_loaders_dict[name] for name in val_loader_names]

    class TrainerConfig(BaseModel):
        # Train time
        # TODO: we define max_epochs twice?
        # (the trainer below uses kwargs["max_epochs"], not this field)
        max_epochs: int = 100
        max_steps: Optional[int] = None
        sampler: Optional[str] = None
        batch_size: int = BATCH_SIZE

        # Optimizer
        optimizer: str = "Adam"  # [Adam, RMSprop, SGD]
        learning_rate: float = 1e-3
        cyclic_lr: bool = False

        second_optimizer: str = "SGD"
        second_learning_rate: float = 5e-4
        second_optimizer_start_epoch: Optional[int] = 75

        freeze_bias_after: Optional[int] = None

        ckpt_savedir: str = "./checkpoints/irt"

        c_reg_skill: float = 1e-6
        c_reg_difficulty: float = 1e-6
        c_reg_relevance: float = 1e-6

    class CaimiraConfig(IrtModelConfig):
        # Number of dimensions in item embeddings
        # (one per item, matching the one-hot "query_rep" built above)
        # TODO: turn this into real embeddings
        n_dim_item_embed: int = len(data)

        # Type of the relevance / difficulty heads
        rel_mode: str = "linear"  # [linear, mlp]
        dif_mode: str = "linear"  # [linear, mlp]

        # Number of hidden units for the MLPs if mode is mlp
        n_hidden_dif: int = 128
        n_hidden_rel: int = 128

        # Sparsity controls for importance [only used if fit_importance is True]
        # Temperature for importance
        rel_temperature: float = 0.5
        fast: bool = False

        @property
        def arch(self):
            return "caimira"

        n_agents: int = len(systems)
        # 1 for now because we don't know what it really is
        n_agent_types: int = 1
        n_dim: int = 32
        # NOTE(review): annotated as float but holds a bool default
        fit_guess_bias: float = False
        # TODO: turn off?
        fit_agent_type_embeddings: bool = True

    model = IrtLitModule(
        trainer_config=TrainerConfig(),
        model_or_config=CaimiraConfig(),
        val_dataloader_names=val_loader_names,
        agent_indexer=agent_indexer,
    )

    train_logger = None
    train_logger = WandbLogger(
        project="irt-mt-dev",
        name="nnirt base",
    )

    checkpoint_callback = ModelCheckpoint(
        save_top_k=3,
        monitor="val/acc",
        mode="max",  # val/acc is an accuracy: higher is better
        auto_insert_metric_name=False,
        filename="epoch={epoch}-acc={val/acc:.2f}",
    )
    checkpoint_callback.FILE_EXTENSION = ""
    trainer = neural_irt.train.CaimiraTrainer(
        max_epochs=kwargs["max_epochs"],
        accelerator="auto",
        logger=train_logger,
        callbacks=[checkpoint_callback],
    )

    trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loaders)
    loaded_model = IrtLitModule.load_from_checkpoint(
        checkpoint_callback.best_model_path
    )
    print(loaded_model)
    # NOTE(review): wandb.save() is called without a file/glob argument;
    # confirm it saves what is intended.
    wandb.save()
|
|
206
|
+
|
|
207
|
+
|
|
208
|
+
def our_irt(data, metric, **kwargs):
    """
    Train one of the ``irt_mt_dev`` IRT models on item x system scores and
    return a utility for every item.

    Args:
        data: list of items; ``item["scores"][system][metric]`` is the observed
            response. System names are taken from the first item.
        metric: name of the metric used as the IRT response.
        **kwargs: must contain ``"model"`` (``"scalar"``, ``"tfidf"`` or
            ``"embd"``). The whole dict is also forwarded to the model class.

    Returns:
        ``[model.fn_utility(x, data_irt["systems"]) for x in data_irt["items"]]``,
        where ``data_irt`` is the entry of ``model.params_log`` at the
        validation step with the highest ``subset_consistency_accuracy_metric``.

    Side effects:
        Calls ``logging.disable(logging.INFO)`` process-wide, sets
        ``WANDB_SILENT=true`` in ``os.environ`` and logs the run to Weights &
        Biases (project ``irt-mt-dev``). On a training exception (including
        ctrl-C) the process exits with status 1.
    """
    import torch
    import torch.utils
    import lightning as L
    from lightning.pytorch.callbacks.early_stopping import EarlyStopping
    from lightning.pytorch.callbacks.callback import Callback
    from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
    from irt_mt_dev.irt.scalar import IRTModelScalar
    from irt_mt_dev.irt.tfidf import IRTModelTFIDF
    from irt_mt_dev.irt.embd import IRTModelEmbd
    # turn off pesky pytorch logs
    import logging
    logging.disable(logging.INFO)
    import wandb
    import os
    os.environ["WANDB_SILENT"] = "true"
    from lightning.pytorch.loggers import WandbLogger

    systems = list(data[0]["scores"].keys())

    ModelClass = {
        "scalar": IRTModelScalar,
        "tfidf": IRTModelTFIDF,
        "embd": IRTModelEmbd,
    }[kwargs["model"]]
    model = ModelClass(data, systems, data_old=data, **kwargs)

    # one ((item index, system index), score) record per observed response
    data_flat = [
        ((sent_i, sys_i), sent["scores"][sys][metric])
        for sent_i, sent in enumerate(data)
        for sys_i, sys in enumerate(systems)
    ]

    # TODO: in the future run first training with dev set to find out the best epoch count
    # and then run again on full data with that epoch count

    # use all data for both training and validation for now
    data_train = torch.utils.data.DataLoader(
        data_flat,
        batch_size=len(data_flat),
        num_workers=24,
        shuffle=True,
        # fully move to GPU
        pin_memory=True,
        # don't kill workers because that's our bottleneck
        persistent_workers=True,
    )
    data_val = torch.utils.data.DataLoader(
        data_flat,
        batch_size=len(data_flat),
        num_workers=24,
        shuffle=False,
        # fully move to GPU
        pin_memory=True,
        # don't kill workers because that's our bottleneck
        persistent_workers=True,
    )

    # tiny handler to propagate ctrl-C instead of just stopping the training
    class PropagateExitCallback(Callback):
        def on_exception(self, trainer, pl_module, exception):
            print(exception)
            # Exit with a failure status. The previous bare `exit()` raised
            # SystemExit(None), which reports success (status 0) after a crash
            # or ctrl-C. It also depended on the `site`-injected `exit` builtin,
            # which is meant for interactive use only.
            raise SystemExit(1) from exception

    trainer = L.Trainer(
        max_epochs=2000,
        check_val_every_n_epoch=50,
        log_every_n_steps=1,
        enable_checkpointing=True,
        enable_progress_bar=False,
        enable_model_summary=False,
        # logger=False,
        callbacks=[
            EarlyStopping(monitor="cluster_count_metric", patience=20, verbose=True, mode="max"),
            ModelCheckpoint(filename='best_model', monitor='cluster_count_metric', mode='max', save_top_k=1),
            PropagateExitCallback(),
        ],
        logger=WandbLogger(project="irt-mt-dev"),
    )

    trainer.fit(
        model=model,
        train_dataloaders=data_train,
        val_dataloaders=data_val,
    )
    wandb.finish()

    # No checkpoint is reloaded here. `model` keeps the weights the trainer
    # ended with; the returned utilities come from the logged parameter
    # snapshots in `model.params_log`.
    model.eval()
    # model.load_state_dict(torch.load(trainer.checkpoint_callback.best_model_path, weights_only=True)["state_dict"])
    # model.validation_step(None, None)
    # data_irt = model.pack_irt_params()

    # NOTE(review): early stopping / checkpointing monitor "cluster_count_metric",
    # but the snapshot is picked by "subset_consistency_accuracy_metric". Confirm
    # which criterion is intended. This also assumes results_log[i] and
    # params_log[i] describe the same validation step.
    best_val_step = max(range(len(model.results_log)), key=lambda i: model.results_log[i]["subset_consistency_accuracy_metric"])
    # print("Best validation step was:", best_val_step)
    data_irt = model.params_log[best_val_step]

    return [
        model.fn_utility(x, data_irt["systems"])
        for x in data_irt["items"]
    ]
|
|
309
|
+
|
|
310
|
+
|
|
311
|
+
def get_nice_subset(data_old, target_size=100, step_size=10, metric="human") -> List[Any]:
    """
    Draft: greedily shrink ``data_old`` by dropping the items whose single-item
    system ordering scores lowest against the full-data ordering.

    Not usable: the first statement unconditionally raises
    ``NotImplementedError``. The code after it is unreachable and kept for
    reference.

    The draft repeats the following while more than ``target_size`` items
    remain:
      1. recompute the system ordering on the current items;
      2. sort the items ascending by ``eval_order_accuracy`` of their own
         single-item ordering;
      3. drop the first ``step_size`` items.
    It prints the average per-item accuracy before and after, and returns the
    remaining items themselves (not utility values).
    NOTE(review): the first ``data_old.sort`` mutates the caller's list in place.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError("This method is not yet implemented for use.")
    import numpy as np
    order_full = subset2evaluate.evaluate.get_sys_ordering(data_old, metric=metric)

    print(f"Previous average accuracy: {np.average([subset2evaluate.evaluate.eval_order_accuracy(order_full, subset2evaluate.evaluate.get_sys_ordering([line], metric=metric)) for line in data_old]):.1%}")

    while len(data_old) > target_size:
        order_full = subset2evaluate.evaluate.get_sys_ordering(data_old, metric=metric)
        # ascending: items with the lowest accuracy come first and are dropped below
        data_old.sort(key=lambda line: subset2evaluate.evaluate.eval_order_accuracy(order_full, subset2evaluate.evaluate.get_sys_ordering([line], metric=metric)))
        data_old = data_old[step_size:]

    # NOTE(review): order_full is the ordering computed before the last drop, not on the final subset
    print(f"New average accuracy: {np.average([subset2evaluate.evaluate.eval_order_accuracy(order_full, subset2evaluate.evaluate.get_sys_ordering([line], metric=metric)) for line in data_old]):.1%}")
    return data_old
|
|
325
|
+
|
|
326
|
+
|
|
327
|
+
def premlp_other(data, data_train, fn_utility: Callable, **kwargs) -> List[float]:
    """
    Predict a per-item utility from the source text alone.

    An MLP regressor (hidden layers 128 and 16) is trained to map sentence
    embeddings (``paraphrase-MiniLM-L12-v2``) of ``line["src"]`` in
    ``data_train`` to ``fn_utility(line)``. It is then applied to every item
    of ``data``.

    Args:
        data: items to score; each needs a ``"src"`` field.
        data_train: training items; need ``"src"`` plus whatever ``fn_utility`` reads.
        fn_utility: maps a training item to its regression target.
        **kwargs: ignored; accepted for a uniform method interface.

    Returns:
        Predicted utilities, aligned with ``data``.
    """
    import warnings

    # Silence sentence-transformers / sklearn warnings only for the duration of
    # this call. The previous filterwarnings('ignore') + resetwarnings() pair
    # wiped every warning filter in the process, including the caller's and the
    # interpreter defaults. It also left warnings off if anything raised in between.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        import sentence_transformers
        import sklearn.neural_network

        embd_model = sentence_transformers.SentenceTransformer("paraphrase-MiniLM-L12-v2")
        data_x_train = embd_model.encode([line["src"] for line in data_train])
        data_y_train = [fn_utility(line) for line in data_train]

        model = sklearn.neural_network.MLPRegressor(
            hidden_layer_sizes=(128, 16),
            max_iter=1000,
            verbose=False,
        )
        model.fit(data_x_train, data_y_train)
        data_x_test = embd_model.encode([line["src"] for line in data])
        data_y_test = model.predict(data_x_test)

    return list(data_y_test)
|
|
349
|
+
|
|
350
|
+
|
|
351
|
+
def premlp_irt(data, data_train, load_model=None, return_model=False, **kwargs) -> Union[List[Any], Tuple[List[Any], Any]]:
    """
    Predict IRT difficulty/discrimination of each item from its source text and
    order the items by the resulting IRT utility.

    Two MLP regressors map sentence embeddings (``paraphrase-MiniLM-L12-v2``)
    of ``line["src"]`` to ``line["irt"]["diff"]`` and ``line["irt"]["disc"]``.
    The predicted parameters of every item in ``data`` are scored with
    ``subset2evaluate.methods.fn_irt_utility(item, params, None, kwargs["fn_utility"])``.

    Args:
        data: items to rank; each needs a ``"src"`` field.
        data_train: training items with ``"src"`` and ``"irt"`` (``"diff"``,
            ``"disc"``). Only used when ``load_model`` is None.
        load_model: optional ``(model_diff, model_disc)`` pair of fitted
            regressors, e.g. from an earlier ``return_model=True`` call. They are
            used as-is. Previously they were re-fitted, which re-initializes an
            MLPRegressor and discarded the loaded weights.
        return_model: if True, also return ``(model_diff, model_disc)``.
        **kwargs: must contain ``"fn_utility"``.

    Returns:
        The items of ``data`` sorted by predicted utility, highest first (items,
        not utility values), optionally paired with the two regressors.
    """
    import warnings

    # Scope the warning suppression to this call. A bare
    # filterwarnings('ignore') would silence warnings in the whole process from
    # then on.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        import sklearn.neural_network
        import sentence_transformers
        import subset2evaluate.methods

        embd_model = sentence_transformers.SentenceTransformer("paraphrase-MiniLM-L12-v2")

        def _new_regressor():
            # fresh, unfitted regressor with the shared architecture
            return sklearn.neural_network.MLPRegressor(
                hidden_layer_sizes=(128, 16),
                max_iter=1000,
                verbose=False,
            )

        if load_model is not None:
            # Use the loaded regressors as they are. Calling .fit() would
            # re-initialize them (warm_start=False) and throw the weights away.
            model_diff, model_disc = load_model
        else:
            model_diff = _new_regressor()
            model_disc = _new_regressor()
            data_x_train = embd_model.encode([line["src"] for line in data_train])
            model_diff.fit(data_x_train, [line["irt"]["diff"] for line in data_train])
            model_disc.fit(data_x_train, [line["irt"]["disc"] for line in data_train])

        data_x_test = embd_model.encode([line["src"] for line in data])
        data_y_diff = model_diff.predict(data_x_test)
        data_y_disc = model_disc.predict(data_x_test)

        data_irt_items = [
            {"diff": diff, "disc": disc}
            for diff, disc in zip(data_y_diff, data_y_disc)
        ]

        items_joint = list(zip(data, data_irt_items))
        items_joint.sort(
            key=lambda x: subset2evaluate.methods.fn_irt_utility(x[0], x[1], None, kwargs["fn_utility"]),
            reverse=True
        )

    items = [x[0] for x in items_joint]

    if return_model:
        return items, (model_diff, model_disc)
    return items
|
|
402
|
+
|
|
403
|
+
|
|
404
|
+
def _run_simulation(args):
    """
    One random trial for ``synthetic_simulation``.

    Args:
        args: pair ``(data_old, data_prior)``, i.e. the remaining pool and the
            items chosen so far. It is packed into one argument so the function
            works with ``Pool.map``.

    Returns:
        ``(candidate, clusters)``. ``candidate`` is up to 20 items sampled from
        the pool followed by ``data_prior``. ``clusters`` is
        ``eval_subset_clusters`` of that candidate on ``MetricX-23-c``.
    """
    pool_items, chosen_items = args
    sample_size = min(len(pool_items), 20)
    sampled = random.sample(pool_items, k=sample_size)
    cluster_count = subset2evaluate.evaluate.eval_subset_clusters(
        sampled + chosen_items,
        metric="MetricX-23-c",
    )
    return sampled + chosen_items, cluster_count
|
|
409
|
+
|
|
410
|
+
|
|
411
|
+
def synthetic_simulation(data, **kwargs):
    """
    Draft: grow a subset greedily from random trials.

    Not usable: the first statement unconditionally raises
    ``NotImplementedError``. The code after it is unreachable and kept for
    reference.

    Each round runs 1000 ``_run_simulation`` trials on a 20-process pool. Each
    trial samples up to 20 new items and appends the current selection. The
    round keeps the candidate with the largest cluster value, removes its items
    (by ``"i"``) from the pool, and repeats until the pool is empty. The draft
    returns the final candidate list (items, not utilities).

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError("This method is not yet implemented for use.")
    import multiprocessing

    data_new = []
    while len(data) > 0:
        print("Remaining", len(data))
        # NOTE(review): a new pool is created every round, and data is pickled once per trial
        with multiprocessing.Pool(20) as pool:
            results = pool.map(_run_simulation, [[data, data_new]] * 1000)

        # take best clustering but evaluate on human data
        # NOTE(review): _run_simulation clusters on "MetricX-23-c", not on human scores
        data_new = max(results, key=lambda x: x[1])[0]
        data_best_i = {line["i"] for line in data_new}
        data = [line for line in data if line["i"] not in data_best_i]

    return data_new
|
|
427
|
+
|
|
428
|
+
|
|
429
|
+
def _human_score_var(line):
    """Variance of the per-system "human" scores of one item."""
    return np.var([sys_v["human"] for sys_v in line["scores"].values()])


def _human_score_avg(line):
    """Mean of the per-system "human" scores of one item."""
    return np.average([sys_v["human"] for sys_v in line["scores"].values()])


# Registry of the experimental methods in this module: method name -> callable.
METHODS = {
    "synthetic_simulation": synthetic_simulation,
    "premlp_irt_diffdisc": partial(premlp_irt, fn_utility="diffdisc"),
    "premlp_irt_diff": partial(premlp_irt, fn_utility="diff"),
    "premlp_irt_disc": partial(premlp_irt, fn_utility="disc"),
    "premlp_var": partial(premlp_other, fn_utility=_human_score_var),
    "premlp_avg": partial(premlp_other, fn_utility=_human_score_avg),
}
|
|
@@ -0,0 +1,101 @@
|
|
|
1
|
+
from typing import List, Any, Union, Tuple, Dict
|
|
2
|
+
import subset2evaluate.utils as utils
|
|
3
|
+
import subset2evaluate.methods as methods
|
|
4
|
+
import copy
|
|
5
|
+
import sys
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def run_select_subset(
    data: Union[List, str],
    method: str,
    metric=None,
    return_model=False,
    load_model=None,
    retry_on_error=False,
    **kwargs
) -> Union[List, Tuple[List, Any]]:
    """
    Rank the items of ``data`` by the utility a subset-selection method assigns.

    Args:
        data: list of items, or a descriptor/path understood by ``utils.load_data``.
        method: key of ``methods.METHODS``.
        metric: forwarded to the method.
        return_model: if True, the method must return ``(utilities, model)``, and
            the model is returned as well. Not all methods support this.
        load_model: forwarded to the method.
        retry_on_error: if True, re-run the method whenever it raises, printing
            the error to stderr, until one run succeeds. A persistent error
            retries forever.
        **kwargs: forwarded to the method.

    Returns:
        A deep copy of the data, with each item's utility stored in
        ``"subset2evaluate_utility"``, sorted from highest to lowest utility.
        When ``return_model`` is True, returns ``(data, model)``.

    Raises:
        ValueError: if ``method`` is unknown. It is a subclass of Exception, so
            existing ``except Exception`` handlers still catch it.
    """
    # both list or descriptor is fine
    data = utils.load_data(data)

    if method not in methods.METHODS:
        raise ValueError(f"Method {method} not found")
    method_fn = methods.METHODS[method]

    # methods might mutate data, make sure we keep it clean:
    # the method and the utility annotation below both work on this copy,
    # so the caller's objects are never modified
    data = copy.deepcopy(data)

    def out_fn():
        return method_fn(
            data,
            metric=metric,
            return_model=return_model,
            load_model=load_model,
            **kwargs
        )

    # pyirt does not handle divergence well and just crashes
    # on that occasion, let's just restart
    if retry_on_error:
        while True:
            try:
                out = out_fn()
            except Exception as e:
                print(e, file=sys.stderr)
                continue
            # success: stop retrying (this break was missing, so the loop never ended)
            break
    else:
        out = out_fn()

    if return_model:
        out, model = out

    # out: one utility per item of data, in the same order
    out: List[float]

    # store utilities and sort from highest to lowest
    for score, item in zip(out, data):
        item["subset2evaluate_utility"] = score
    data.sort(key=lambda x: x["subset2evaluate_utility"], reverse=True)

    if return_model:
        return data, model
    else:
        return data
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def main_cli():
    """
    Command-line entry point.

    Ranks the data with the chosen method via ``run_select_subset`` and prints
    every item as one JSON object per line, highest utility first.
    """
    import argparse
    import json
    import ast

    parser = argparse.ArgumentParser(
        description="""
    Select subset of data. The returned data is ordered by the method's utility in descending order (first is best).
    The segment utility is also stored in the 'subset2evaluate_utility' field of each item.
    """
    )
    parser.add_argument(
        'data', type=str,
        # nargs='?' makes the positional optional so the default below is
        # actually used. Without it, argparse ignores `default` and requires
        # the argument.
        nargs='?',
        default='wmt23/en-cs',
        help="Either descriptor of data, such as wmt22/en-de, or summeval, or path to JSON file with data."
    )
    parser.add_argument(
        '--method', default="metric_var",
        choices=methods.METHODS.keys(),
        help="Subset selection method.",
    )
    parser.add_argument(
        '--args',
        default='{"metric": "MetricX-23-c"}',
        help="Additional optional arguments for the method as a Python dictionary."
    )
    args = parser.parse_args()

    # literal_eval accepts only Python literals, so --args cannot execute code
    data_new = run_select_subset(
        args.data,
        method=args.method,
        **ast.literal_eval(args.args)
    )

    for item in data_new:
        print(json.dumps(item, ensure_ascii=False))
|
subset2evaluate/test.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import subset2evaluate
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
def test_wmt_loader():
    """Loading "wmt23/all" yields a dict of 33 language pairs; en-cs has 1098 items."""
    data = subset2evaluate.utils.load_data("wmt23/all")
    assert isinstance(data, dict)
    assert len(data) == 33

    en_cs = data[("wmt23", "en-cs")]
    assert len(en_cs) == 1098
    first_item = en_cs[0]
    for required_key in ("src", "tgt", "scores"):
        assert required_key in first_item
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def test_wmt_method_random():
    """Seeded random selection on WMT23 en-cs reproduces reference cluster/accuracy values."""
    dataset = "wmt23/en-cs"
    # random is usually random but we fix the seed
    ranked = subset2evaluate.select_subset.run_select_subset(dataset, method="random", seed=0)
    clusters, accuracies = subset2evaluate.evaluate.run_evaluate_cluacc(ranked, dataset, metric="human")
    assert abs(np.average(clusters) - 1.4000) < 0.01
    assert abs(np.average(accuracies) - 0.8104) < 0.01
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def test_wmt_method_metric_var():
    """metric_var with MetricX-23-c on WMT23 en-cs reproduces reference cluster/accuracy values."""
    dataset = "wmt23/en-cs"
    ranked = subset2evaluate.select_subset.run_select_subset(dataset, method="metric_var", metric="MetricX-23-c")
    clusters, accuracies = subset2evaluate.evaluate.run_evaluate_cluacc(ranked, dataset, metric="human")
    assert abs(np.average(clusters) - 1.8000) < 0.01
    assert abs(np.average(accuracies) - 0.8552) < 0.01
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def test_wmt_method_diversity():
    """diversity_bleu on WMT23 en-de reproduces reference cluster/accuracy values."""
    dataset = "wmt23/en-de"
    ranked = subset2evaluate.select_subset.run_select_subset(dataset, method="diversity_bleu")
    clusters, accuracies = subset2evaluate.evaluate.run_evaluate_cluacc(ranked, dataset, metric="human")
    assert abs(np.average(clusters) - 2.3000) < 0.01
    assert abs(np.average(accuracies) - 0.9152) < 0.01
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def test_summeval_loader():
    """SummEval loads as a flat list of 100 items that carry targets and scores."""
    data = subset2evaluate.utils.load_data("summeval")
    assert isinstance(data, list)
    assert len(data) == 100

    first_item = data[0]
    for required_key in ("tgt", "scores"):
        assert required_key in first_item
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def test_summeval_method_random():
    """Seeded random selection on SummEval stays near reference cluster/accuracy values."""
    dataset = "summeval"
    # random is usually random but we fix the seed
    ranked = subset2evaluate.select_subset.run_select_subset(dataset, method="random", seed=0)
    clusters, accuracies = subset2evaluate.evaluate.run_evaluate_cluacc(ranked, dataset, metric="human_all")
    # it is a bit different on GitHub actions, therefore higher error margin
    assert abs(np.average(clusters) - 1.6000) < 0.2
    assert abs(np.average(accuracies) - 0.9279) < 0.2
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def test_summeval_method_metric_var():
    """metric_var with the coverage metric on SummEval reproduces reference values."""
    dataset = "summeval"
    ranked = subset2evaluate.select_subset.run_select_subset(dataset, method="metric_var", metric="coverage")
    clusters, accuracies = subset2evaluate.evaluate.run_evaluate_cluacc(ranked, dataset, metric="human_all")
    assert abs(np.average(clusters) - 2.3000) < 0.01
    assert abs(np.average(accuracies) - 0.9220) < 0.01
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def test_summeval_method_diversity():
    """diversity_bleu on SummEval stays near reference cluster/accuracy values."""
    dataset = "summeval"
    ranked = subset2evaluate.select_subset.run_select_subset(dataset, method="diversity_bleu")
    clusters, accuracies = subset2evaluate.evaluate.run_evaluate_cluacc(ranked, dataset, metric="human_all")
    # it is a bit different on GitHub actions, therefore higher error margin
    assert abs(np.average(clusters) - 2.9000) < 0.2
    assert abs(np.average(accuracies) - 0.8934) < 0.2
|