crfm-helm 0.2.0__py3-none-any.whl → 0.2.2__py3-none-any.whl
This diff shows the changes between publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
- {crfm_helm-0.2.0.dist-info → crfm_helm-0.2.2.dist-info}/METADATA +11 -8
- {crfm_helm-0.2.0.dist-info → crfm_helm-0.2.2.dist-info}/RECORD +67 -38
- {crfm_helm-0.2.0.dist-info → crfm_helm-0.2.2.dist-info}/WHEEL +1 -1
- {crfm_helm-0.2.0.dist-info → crfm_helm-0.2.2.dist-info}/entry_points.txt +2 -1
- helm/benchmark/__init__.py +13 -0
- helm/benchmark/adaptation/adapter_spec.py +3 -0
- helm/benchmark/adaptation/adapters/in_context_learning_adapter.py +20 -7
- helm/benchmark/augmentations/correct_to_misspelling.json +1 -0
- helm/benchmark/contamination/__init__.py +0 -0
- helm/benchmark/metrics/classification_metrics.py +70 -0
- helm/benchmark/metrics/machine_translation_metrics.py +36 -0
- helm/benchmark/metrics/summarization_metrics.py +7 -8
- helm/benchmark/metrics/test_classification_metrics.py +150 -0
- helm/benchmark/presentation/create_plots.py +617 -0
- helm/benchmark/presentation/run_display.py +7 -48
- helm/benchmark/presentation/summarize.py +4 -2
- helm/benchmark/presentation/test_create_plots.py +32 -0
- helm/benchmark/run.py +144 -48
- helm/benchmark/run_expander.py +164 -47
- helm/benchmark/run_specs.py +346 -39
- helm/benchmark/runner.py +34 -6
- helm/benchmark/scenarios/copyright_scenario.py +1 -1
- helm/benchmark/scenarios/covid_dialog_scenario.py +84 -0
- helm/benchmark/scenarios/imdb_listdir.json +50014 -0
- helm/benchmark/scenarios/lex_glue_scenario.py +253 -0
- helm/benchmark/scenarios/lextreme_scenario.py +458 -0
- helm/benchmark/scenarios/me_q_sum_scenario.py +86 -0
- helm/benchmark/scenarios/med_dialog_scenario.py +132 -0
- helm/benchmark/scenarios/med_mcqa_scenario.py +102 -0
- helm/benchmark/scenarios/med_paragraph_simplification_scenario.py +119 -0
- helm/benchmark/scenarios/med_qa_scenario.py +96 -0
- helm/benchmark/scenarios/opinions_qa_scenario.py +194 -0
- helm/benchmark/scenarios/scenario.py +5 -0
- helm/benchmark/scenarios/the_pile_scenario.py +1 -1
- helm/benchmark/scenarios/wmt_14_scenario.py +96 -0
- helm/benchmark/static/benchmarking.css +14 -0
- helm/benchmark/static/benchmarking.js +43 -0
- helm/benchmark/static/index.html +2 -0
- helm/benchmark/static/json-urls.js +4 -0
- helm/benchmark/static/plot-captions.js +16 -0
- helm/benchmark/static/schema.yaml +154 -1
- helm/benchmark/window_services/cohere_window_service.py +20 -0
- helm/benchmark/window_services/flan_t5_window_service.py +29 -0
- helm/benchmark/window_services/huggingface_window_service.py +39 -0
- helm/benchmark/window_services/santacoder_window_service.py +27 -0
- helm/benchmark/window_services/test_flan_t5_window_service.py +12 -0
- helm/benchmark/window_services/wider_ai21_window_service.py +13 -0
- helm/benchmark/window_services/window_service_factory.py +34 -7
- helm/common/codec.py +123 -0
- helm/common/general.py +12 -5
- helm/common/test_codec.py +144 -0
- helm/proxy/clients/aleph_alpha_client.py +47 -28
- helm/proxy/clients/auto_client.py +32 -24
- helm/proxy/clients/google_client.py +88 -0
- helm/proxy/clients/huggingface_client.py +32 -16
- helm/proxy/clients/huggingface_model_registry.py +111 -0
- helm/proxy/clients/huggingface_tokenizer.py +25 -7
- helm/proxy/clients/openai_client.py +60 -2
- helm/proxy/clients/test_huggingface_model_registry.py +57 -0
- helm/proxy/clients/test_huggingface_tokenizer.py +3 -0
- helm/proxy/clients/together_client.py +17 -2
- helm/proxy/clients/yalm_tokenizer/voc_100b.sp +0 -0
- helm/proxy/clients/yalm_tokenizer/yalm_tokenizer.py +8 -2
- helm/proxy/models.py +115 -7
- helm/proxy/test_models.py +1 -1
- helm/benchmark/presentation/present.py +0 -249
- {crfm_helm-0.2.0.dist-info → crfm_helm-0.2.2.dist-info}/LICENSE +0 -0
- {crfm_helm-0.2.0.dist-info → crfm_helm-0.2.2.dist-info}/top_level.txt +0 -0
helm/benchmark/presentation/create_plots.py
@@ -0,0 +1,617 @@
+import argparse
+from collections import defaultdict
+from dataclasses import dataclass
+from datetime import date
+import json
+import os
+from typing import List, Dict, Optional, Any, Callable, Union, Mapping, Tuple, Set
+
+import colorcet
+import matplotlib
+import matplotlib.pyplot as plt
+import numpy as np
+from scipy.stats import pearsonr
+import seaborn as sns
+
+from helm.common.hierarchical_logger import hlog
+from helm.benchmark.presentation.schema import read_schema
+from helm.benchmark.presentation.summarize import AGGREGATE_WIN_RATE_COLUMN
+
+sns.set_style("whitegrid")
+
+DOWN_ARROW = "\u2193"
+UP_ARROW = "\u2191"
+metric_group_to_label = {
+    "Accuracy": f"Accuracy {UP_ARROW}",
+    "Calibration": f"Calibration error {DOWN_ARROW}",
+    "Robustness": f"Robustness {UP_ARROW}",
+    "Fairness": f"Fairness {UP_ARROW}",
+    "Bias": f"Bias (gender repr.) {DOWN_ARROW}",
+    "Toxicity": f"Toxicity {DOWN_ARROW}",
+    "Efficiency": f"Inference time (s) {DOWN_ARROW}",
+}
+all_metric_groups = list(metric_group_to_label.keys())
+
+
+@dataclass
+class Column:
+    """Values and metadata for each column of the table."""
+
+    name: str
+    group: str
+    metric: str
+    values: np.ndarray
+    lower_is_better: Optional[bool]
+
+
+@dataclass
+class Table:
+    """Column-based representation of a standard run-group table. See summarize.py for exact documentation."""
+
+    adapters: List[str]
+    columns: List[Column]
+    mean_win_rates: Optional[np.ndarray] = None
+
+
+def parse_table(raw_table: Dict[str, Any]) -> Table:
+    """Convert raw table dict to a Table. Ignores strongly contaminated table entries."""
+
+    def get_cell_values(cells: List[dict]) -> List[Any]:
+        values = []
+        for cell in cells:
+            value = cell["value"] if "value" in cell else np.nan
+            if "contamination_level" in cell and cell["contamination_level"] == "strong":
+                value = np.nan
+            values.append(value)
+        return values
+
+    adapters: Optional[List[str]] = None
+    columns: List[Column] = []
+    mean_win_rates: Optional[np.ndarray] = None
+    for column_index, (header_cell, *column_cells) in enumerate(zip(raw_table["header"], *raw_table["rows"])):
+        cell_values = get_cell_values(column_cells)
+        if column_index == 0:
+            adapters = cell_values
+        elif column_index == AGGREGATE_WIN_RATE_COLUMN and "win rate" in header_cell["value"]:
+            mean_win_rates = np.array(cell_values)
+        else:
+            assert "metadata" in header_cell
+            name = header_cell["value"]
+            group = header_cell["metadata"]["run_group"]
+            metric = header_cell["metadata"]["metric"]
+            lower_is_better = header_cell["lower_is_better"] if "lower_is_better" in header_cell else None
+            columns.append(Column(name, group, metric, np.array(cell_values), lower_is_better))
+    assert adapters is not None
+
+    return Table(adapters, columns, mean_win_rates)
+
+
+def get_color_palette(n_colors: int) -> sns.palettes._ColorPalette:
+    if n_colors < 6:
+        return sns.color_palette("colorblind", n_colors=n_colors)
+    else:
+        return sns.color_palette(colorcet.glasbey_warm, n_colors=n_colors)
+
+
+def draw_box_plot(
+    x_to_ys: Mapping[str, Union[List[float], np.ndarray]], ax: matplotlib.axis.Axis, rotate_xticklabels: bool = True
+):
+    """Given a mapping from string to floats, draw a box plot on the given axis ax. For instance, this might be a
+    mapping from scenario_name to a list of model accuracies in which case the box plot captures aggregate model
+    performance and highlights outliers."""
+    xs: List[str] = []
+    all_ys: List[List[float]] = []
+    for x, ys in x_to_ys.items():
+        ys = [y for y in ys if not np.isnan(y)]
+        if ys:
+            xs.append(x)
+            all_ys.append(ys)
+            ax.scatter([len(xs)] * len(ys), ys, c="#bbb")
+    ax.boxplot(all_ys, boxprops={"linewidth": 2}, medianprops={"linewidth": 2})
+    if rotate_xticklabels:
+        ax.set_xticklabels(xs, rotation=-30, ha="left")
+    else:
+        ax.set_xticklabels(xs)
+
+
+class Plotter:
+    """
+    Main class producing plots. Each create_*() method reads group data from base_path and creates a plot at the
+    save_path. create_all_plots() runs all these functions at once.
+    """
+
+    def __init__(self, base_path: str, save_path: str, plot_format: str):
+        self.base_path = base_path
+        self.save_path = save_path
+        self.plot_format = plot_format
+        self._tables_cache: Dict[str, Dict[str, Table]] = {}
+
+        schema = read_schema()
+        self.model_metadata = {model_field.display_name: model_field for model_field in schema.models}
+
+    def get_group_tables(self, group_name: str) -> Dict[str, Table]:
+        """Reads and parses group tables. Uses _tables_cache to avoid reprocessing the same table multiple times."""
+        if group_name in self._tables_cache:
+            return self._tables_cache[group_name]
+        with open(os.path.join(self.base_path, "groups", f"{group_name}.json")) as fp:
+            tables = json.load(fp)
+
+        name_to_table: Dict[str, Table] = {}
+        for table in tables:
+            name_to_table[table["title"]] = parse_table(table)
+
+        return name_to_table
+
+    def save_figure(self, fig: plt.Figure, name: str):
+        """Save and close a figure."""
+        if not os.path.exists(self.save_path):
+            os.makedirs(self.save_path)
+        fig.savefig(os.path.join(self.save_path, f"{name}.{self.plot_format}"), bbox_inches="tight", dpi=300)
+        plt.close(fig)
+
+    def create_accuracy_v_x_plots(self):
+        """
+        For each metric group, create a scatter plot with Accuracy on the x-axis and that metric group on the y-axis.
+        Each point corresponds to a model-scenario pair, colored by the scenario.
+        """
+        tables = self.get_group_tables("core_scenarios")
+        metric_groups_shown = [metric_group for metric_group in all_metric_groups if metric_group != "Accuracy"]
+
+        num_columns = 3
+        num_rows = (len(metric_groups_shown) - 1) // num_columns + 1
+        fig, axarr = plt.subplots(num_rows, num_columns, figsize=(5 * num_columns, 3.5 * num_rows))
+        all_groups = [column.group for column in tables["Accuracy"].columns]
+        palette = get_color_palette(len(all_groups))
+        for i, metric_group in enumerate(metric_groups_shown):
+            table: Table = tables[metric_group]
+            ax = axarr[i // num_columns][i % num_columns]
+            for column in table.columns:
+                if metric_group == "Bias" and column.metric != "Representation (gender)":  # only show gender bias
+                    continue
+                accuracy_column: Column = [c for c in tables["Accuracy"].columns if c.group == column.group][0]
+                group_idx = all_groups.index(column.group)
+                ax.scatter(
+                    accuracy_column.values, column.values, color=palette[group_idx], alpha=0.8, label=column.group
+                )
+
+            if metric_group in ["Robustness", "Fairness"]:
+                ax.plot([0, 1], [0, 1], ls="--", c="gray", zorder=-1)
+            if metric_group == "Bias":
+                ax.axhline(0.5, ls="--", c="gray", zorder=-1)
+
+            ax.set_xlabel("Accuracy", fontsize=14)
+            ax.set_ylabel(metric_group_to_label[metric_group], fontsize=14)
+            ax.set_xlim(-0.1, 1.1)
+
+        # create dummy lines to display a single legend for all plots
+        lines = [
+            matplotlib.lines.Line2D([], [], color="white", marker="o", markersize=10, markerfacecolor=color)
+            for color in palette
+        ]
+        axarr[0][0].legend(
+            lines, all_groups, title="Scenarios", loc="lower left", bbox_to_anchor=(0, 1), ncol=6, numpoints=1
+        )
+
+        fig.subplots_adjust(wspace=0.25, hspace=0.25)
+        self.save_figure(fig, "accuracy_v_x")
+
+    def create_correlation_plots(self):
+        """
+        For each metric group, create a box-plot aggregating how correlated that metric group is with each other
+        metric_group. Individual point correspond to the correlation (across models) of the two metric groups on a
+        single scenario.
+        """
+        tables = self.get_group_tables("core_scenarios")
+        metric_groups_shown = all_metric_groups
+
+        num_columns = 3
+        num_rows = (len(metric_groups_shown) - 2) // num_columns + 1
+        fig, axarr = plt.subplots(num_rows, num_columns, figsize=(5.5 * num_columns, 4.5 * num_rows))
+
+        for i, metric_group_1 in enumerate(metric_groups_shown[:-1]):
+            ax = axarr[i // num_columns][i % num_columns]
+            group_to_values_1: Dict[str, np.ndarray] = {}
+            for column in tables[metric_group_1].columns:
+                if metric_group_1 == "Bias" and column.metric != "Representation (gender)":
+                    continue
+                group_to_values_1[column.group] = column.values
+
+            metric_group_to_correlations: Dict[str, np.ndarray] = defaultdict(list)
+            for j, metric_group_2 in enumerate(metric_groups_shown):
+                for column in tables[metric_group_2].columns:
+                    if metric_group_2 == "Bias" and column.metric != "Representation (gender)":
+                        continue
+                    if column.group not in group_to_values_1:
+                        continue
+                    values_1 = group_to_values_1[column.group]
+                    values_2 = column.values
+                    valid_values = np.logical_and(~np.isnan(values_1), ~np.isnan(values_2))
+                    if sum(valid_values) >= 2:
+                        correlation = pearsonr(values_1[valid_values], values_2[valid_values])[0]
+                        label = metric_group_to_label[metric_group_2]
+                        metric_group_to_correlations[label].append(correlation)
+            draw_box_plot(metric_group_to_correlations, ax)
+            ax.set_title(metric_group_to_label[metric_group_1])
+            if i % num_columns == 0:
+                ax.set_ylabel("Pearson correlation")
+
+        fig.subplots_adjust(wspace=0.25, hspace=0.45)
+        self.save_figure(fig, "metric_correlation")
+
+    def create_leaderboard_plots(self):
+        """Display the model mean win rates for each group as a bar chart."""
+        tables = self.get_group_tables("core_scenarios")
+
+        metric_groups_shown = [metric_group for metric_group in all_metric_groups if metric_group != "Efficiency"]
+        num_columns = 3
+        num_rows = (len(metric_groups_shown) - 1) // num_columns + 1
+        fig, axarr = plt.subplots(num_rows, num_columns, figsize=(4 * num_columns, 6.7 * num_rows))
+        for i, metric_group in enumerate(metric_groups_shown):
+            win_rates, models = [], []
+            for win_rate, model in sorted(zip(tables[metric_group].mean_win_rates, tables[metric_group].adapters)):
+                if not np.isnan(win_rate):
+                    win_rates.append(win_rate)
+                    models.append(model)
+            ax = axarr[i // num_columns][i % num_columns]
+            ax.plot([0, 1], [0, len(models) - 1], ls="--", c="#bbb", zorder=0)
+            ax.barh(models, win_rates, label=models)
+            ax.set_xlim(-0.1, 1.1)
+            ax.set_title(metric_group)
+        fig.subplots_adjust(wspace=1.8, hspace=0.15)
+        self.save_figure(fig, "model_ranking_all")
+
+    def create_accuracy_v_model_property_plot(
+        self,
+        property_name: str,
+        model_name_to_property: Callable[[str], Any],
+        cumulative: bool = False,
+        logscale: bool = False,
+        annotate_models: bool = True,
+    ):
+        """
+        Plot the accuracy of each scenario over some model property (e.g., number of parameters).
+        Args:
+            property_name: Property name displayed as x-label.
+            model_name_to_property: A function that maps the name of the model to the property we use for the plot.
+            cumulative: Plot the best accuracy achieved by a model with at most that property values (useful for dates).
+            logscale: Whether we use a logscale for the x-axis.
+            annotate_models: For each unique property value, add a text annotation with the corresponding model names.
+        """
+        fig, ax = plt.subplots(1, 1, figsize=(11, 4))
+        milestones: Dict[Any, Set[str]] = defaultdict(set)  # keep track of the models with each property value
+        table = self.get_group_tables("core_scenarios")["Accuracy"]
+        palette = get_color_palette(len(table.columns))
+        for column, color in zip(table.columns, palette):
+            data: List[Tuple[Any, float]] = []
+            for model_name, accuracy in zip(table.adapters, column.values):
+                key = model_name_to_property(model_name)
+                if key is None or np.isnan(accuracy):
+                    continue
+                data.append((key, accuracy))
+                milestones[key].add(model_name)
+            data.sort()
+            xs: List[Any] = []
+            ys: List[float] = []
+            if cumulative:
+                for now in list(dict.fromkeys(key for key, _ in data)):
+                    xs.append(now)
+                    ys.append(max(y for (x, y) in data if x <= now))
+            else:
+                for x, y in data:
+                    xs.append(x)
+                    ys.append(y)
+            plot_func = ax.semilogx if logscale else ax.plot
+            plot_func(xs, ys, label=column.group, color=color, marker="o")
+
+        for key, model_names in sorted(milestones.items()):
+            ax.axvline(x=key, ls="--", c="#bbb", zorder=0)
+            if annotate_models:
+                ax.text(key, 1.01, "/".join(model_names), rotation=40)
+
+        # sort the legend according to the left-most value of each plot (makes it easier to visually map names to lines)
+        handles, labels = ax.get_legend_handles_labels()
+        legend_order = np.argsort([-h.get_data()[1][-1] for h in handles])
+        ax.legend(
+            [handles[i] for i in legend_order],
+            [labels[i] for i in legend_order],
+            loc="upper left",
+            bbox_to_anchor=(1.02, 1),
+        )
+        property_save_name = property_name.replace(" ", "_").lower()
+        ax.set_xlabel(property_name)
+        ax.set_ylabel("Accuracy")
+        ax.set_ylim(0, 1)
+        self.save_figure(fig, f"accuracy_over_{property_save_name}")
+
+    def create_all_accuracy_v_model_property_plots(self):
+        """
+        Create accuracy-vs-property plots for: release date, #parameters, thePile perplexity.
+        In all cases, we use a coarse value for the property to make the plot text annotations cleaner.
+        """
+
+        def get_model_release_date(model_name: str) -> Optional[date]:
+            """Maps a model name to the month of model release."""
+            release_date = self.model_metadata[model_name].release_date
+            if release_date is None:
+                return None
+            return release_date.replace(day=1)
+
+        def get_model_size(model_name: str) -> Optional[int]:
+            """Maps a model name to the number of parameters, rounding to the nearest leading digit."""
+            size = self.model_metadata[model_name].num_parameters
+            if size is None:
+                return None
+            grain = 10 ** (len(str(size)) - 1)
+            return round(size / grain) * grain  # only look at first digit
+
+        # Read the perplexity of The Pile according to each model
+        bpb_table = self.get_group_tables("the_pile")["The Pile"]
+        model_to_bpb: Dict[str, float] = {
+            model: bpb for model, bpb in zip(bpb_table.adapters, bpb_table.columns[0].values)
+        }
+
+        def get_model_perplexity(model_name: str) -> Optional[float]:
+            """Maps a model name to the perplexity of The Pile of parameters, rounding based on some granularity."""
+            if model_name not in model_to_bpb or np.isnan(model_to_bpb[model_name]):
+                return None
+            bpb = model_to_bpb[model_name]
+            grain = 0.016
+            return round(bpb / grain) * grain
+
+        annotate_models = True if self.plot_format == "pdf" else False
+        self.create_accuracy_v_model_property_plot(
+            "Release date",
+            get_model_release_date,
+            cumulative=True,
+            annotate_models=annotate_models,
+        )
+        self.create_accuracy_v_model_property_plot(
+            "Num parameters",
+            get_model_size,
+            cumulative=True,
+            logscale=True,
+            annotate_models=annotate_models,
+        )
+        self.create_accuracy_v_model_property_plot(
+            "The Pile perplexity",
+            get_model_perplexity,
+            logscale=True,
+            annotate_models=annotate_models,
+        )
+
+    def create_accuracy_v_access_bar_plot(self):
+        """
+        For each scenario, plot the best model performance for each access level (e.g., closed). We plot both the
+        performance of the best model chosen for a particular scenario (transparent, in the back) as well as the best
+        overall model at that access level.
+        """
+        table = self.get_group_tables("core_scenarios")["Accuracy"]
+
+        all_groups = [column.group for column in table.columns]
+        fig, ax = plt.subplots(1, 1, figsize=(9, 3))
+        palette = get_color_palette(n_colors=3)
+        access_levels = ["open", "limited", "closed"]
+
+        for i, access_level in enumerate(access_levels):
+            model_indices: List[int] = [
+                idx for idx, model in enumerate(table.adapters) if self.model_metadata[model].access == access_level
+            ]
+            best_model_index = model_indices[table.mean_win_rates[model_indices].argmax()]
+
+            xs = np.arange(len(all_groups))
+            width = 0.25
+            ys, ys_single = [], []
+            for column in table.columns:
+                ys.append(column.values[model_indices].max())
+                ys_single.append(column.values[best_model_index])
+            ax.bar(xs + (i - 1) * width, ys, width, color=palette[i], alpha=0.5)
+            ax.bar(xs + (i - 1) * width, ys_single, width, label=access_level, color=palette[i])
+
+        ax.set_ylabel("Accuracy")
+        ax.set_xticks(xs, all_groups, rotation=-20, ha="left")
+        ax.legend(loc="upper left", bbox_to_anchor=(0.61, 0.99))
+        self.save_figure(fig, "accuracy_v_access")
+
+    def create_task_summary_plots(self):
+        """For each metric group, create a box plot with scenario performance across models."""
+        tables = self.get_group_tables("core_scenarios")
+        metric_groups = ["Accuracy", "Calibration", "Robustness", "Fairness", "Bias", "Toxicity"]
+        num_columns = 2
+        num_rows = (len(metric_groups) - 1) // num_columns + 1
+        fig, axarr = plt.subplots(num_rows, num_columns, figsize=(7 * num_columns, 4.5 * num_rows))
+        for i, metric_group in enumerate(metric_groups):
+            ax = axarr[i // num_columns][i % num_columns]
+            table = tables[metric_group]
+            group_to_accuracies: Dict[str, np.ndarray] = {
+                column.group: column.values
+                for column in table.columns
+                if not (metric_group == "Bias" and column.metric != "Representation (gender)")
+            }
+            draw_box_plot(group_to_accuracies, ax)
+            ax.set_title(metric_group_to_label[metric_group])
+
+        fig.subplots_adjust(hspace=0.7)
+        self.save_figure(fig, "generic_summary")
+
+    def create_targeted_eval_plots(self):
+        """Create a box plots with scenario accuracy across models for a range of targeted evaluations."""
+        fig, axd = plt.subplot_mosaic([["language", "knowledge"], ["reasoning", "reasoning"]], figsize=(12, 7))
+        for targeted_eval in ["language", "knowledge", "reasoning"]:
+            table = self.get_group_tables(targeted_eval)["Accuracy"]
+            ax = axd[targeted_eval]
+
+            group_to_accuracies: Dict[str, np.ndarray] = {}
+            for column in table.columns:
+                arrow = DOWN_ARROW if column.lower_is_better else UP_ARROW
+                group = f"{column.group}\n({column.metric} {arrow})"
+                group_to_accuracies[group] = column.values
+
+            draw_box_plot(group_to_accuracies, ax)
+            ax.set_title(targeted_eval.capitalize())
+            ax.set_ylim(-0.1, 3.6 if targeted_eval == "language" else 1.1)
+
+        fig.subplots_adjust(hspace=0.75)
+        self.save_figure(fig, "targeted_evals")
+
+    def create_copyright_plot(self):
+        """Plot copyright metrics across models."""
+        table = self.get_group_tables("harms")["Copyright metrics"]
+        fig, ax = plt.subplots(figsize=(6.5, 3.5))
+        group_to_values: Dict[str, np.ndarray] = {}
+        for column in table.columns:
+            if "dist" in column.metric:
+                continue
+            arrow = DOWN_ARROW if column.lower_is_better else UP_ARROW
+            group = f"{column.group}\n({column.metric} {arrow})"
+            group_to_values[group] = column.values
+        draw_box_plot(group_to_values, ax, rotate_xticklabels=False)
+        ax.set_title("Copyright")
+        self.save_figure(fig, "copyright")
+
+    def create_bbq_plot(self):
+        """Plot BBQ metrics across models."""
+        table = self.get_group_tables("harms")["BBQ metrics"]
+        n = len(table.columns)
+        fig, axarr = plt.subplots(1, n, figsize=(3.5 * n, 7))
+        for i, column in enumerate(table.columns):
+            ax = axarr[i]
+            indices = np.argsort(column.values)
+            indices = indices[: -np.isnan(column.values).sum()]  # remove nans from the end
+            models = np.array(table.adapters)[indices]
+            values = column.values[indices]
+            ax.barh(models, values)
+            ax.set_title(f"{column.metric} {DOWN_ARROW}")
+        fig.subplots_adjust(wspace=1.85)
+        self.save_figure(fig, "bbq_bars")
+
+    def create_in_context_examples_plot(self):
+        """
+        Plot model performance as a function of in-context examples used. One plot per scenario, one line per model.
+        We retrieve the actual average number of in-context examples used from the "General information" table.
+        """
+        tables = self.get_group_tables("ablation_in_context")
+
+        group_to_num_examples: Dict[str, np.ndarray] = {}
+        for column in tables["General information"].columns:
+            if "# train" in column.name:
+                group_to_num_examples[column.group] = column.values
+
+        table = tables["Accuracy"]
+        n = len(table.columns)
+        fig, axarr = plt.subplots(1, n, figsize=(3.5 * n, 2.5))
+        for i, column in enumerate(table.columns):
+            model_examples_to_accuracy: Dict[str, Dict[float, float]] = defaultdict(dict)
+            for adapter, accuracy, num_examples in zip(
+                table.adapters, column.values, group_to_num_examples[column.group]
+            ):
+                model = adapter.split(" [")[0]
+                model_examples_to_accuracy[model][num_examples] = accuracy
+
+            ax = axarr[i]
+            for model, examples_to_accuracy in model_examples_to_accuracy.items():
+                if "UL2" in model:
+                    continue
+                offset = 2
+                xs: List[int] = []
+                ys: List[float] = []
+                for x, y in sorted(examples_to_accuracy.items()):
+                    xs.append(x)
+                    ys.append(y)
+                ax.semilogx([x + offset for x in xs], ys, label=model, marker="o", base=2)
+            xs_max = [0] + [2**i for i in range(5)]
+            ax.set_xticks([x + offset for x in xs_max], xs_max)
+            ax.set_title(column.group)
+            ax.set_ylabel(column.metric)
+            ax.set_xlabel("#in-context examples")
+            if i == 0:
+                ax.legend(ncol=5, loc="upper left", bbox_to_anchor=(0, 1.45))
+        fig.subplots_adjust(wspace=0.32, hspace=0.5)
+        self.save_figure(fig, "in_context_ablations")
+
+    def create_mc_ablations_plot(self):
+        """For each scenario, plot model performance (as a bar plot) for each multiple-choice adaptation method."""
+        table = self.get_group_tables("ablation_multiple_choice")["Accuracy"]
+
+        num_columns = 4
+        num_rows = (len(table.columns) - 1) // num_columns + 1
+        fig, axarr = plt.subplots(num_rows, num_columns, figsize=(4 * num_columns, 3 * num_rows))
+
+        method_to_label = {
+            "multiple_choice_joint": "Multiple Choice Joint",
+            "multiple_choice_separate_original": "Multiple Choice Separate",
+            "multiple_choice_separate_calibrated": "Multiple Choice Separate Calibrated",
+        }
+        palette = get_color_palette(len(method_to_label))
+        width = 0.2
+        for i, column in enumerate(table.columns):
+            ax = axarr[i // num_columns][i % num_columns]
+            for j, method in enumerate(method_to_label):
+                models: List[str] = []
+                ys: List[float] = []
+                for adapter, accuracy in zip(table.adapters, column.values):
+                    if method not in adapter:
+                        continue
+                    models.append(adapter.split(" [")[0])
+                    ys.append(accuracy)
+                xs = np.arange(len(models))
+                ax.bar(xs + (j - 1) * width, ys, width, color=palette[j], label=method_to_label[method])
+            ax.set_xticks(xs, models, rotation=-20, ha="left")
+            ax.set_title(f"{column.group} ({column.metric})")
+            if i == 0:
+                ax.legend(ncol=3, loc="upper left", bbox_to_anchor=(0, 1.4))
+        fig.subplots_adjust(wspace=0.25, hspace=0.65)
+        self.save_figure(fig, "mc_ablations")
+
+    def create_constrast_set_plots(self):
+        """For each contrast set scenario, plot the accuracy and robustness of each model on a scatter plot."""
+        tables = self.get_group_tables("robustness_contrast_sets")
+        fig, axarr = plt.subplots(1, 2, figsize=(7, 3))
+        for ax, (table_name, table) in zip(axarr, tables.items()):
+            xs = [column for column in table.columns if column.name == "EM"][0].values
+            ys = [column for column in table.columns if column.name == "EM (Robustness)"][0].values
+            ax.scatter(xs, ys)
+            ax.plot([0, 1], [0, 1], color="gray", ls="--")
+            ax.set_title(table.columns[0].group)
+            ax.set_xlabel("Accuracy")
+            ax.set_ylabel("Robustness")
+        fig.subplots_adjust(wspace=0.25)
+        self.save_figure(fig, "contrast_sets")
+
+    def create_all_plots(self):
+        """Create all the plots used in the HELM paper."""
+        self.create_accuracy_v_x_plots()
+        self.create_correlation_plots()
+        self.create_leaderboard_plots()
+        self.create_all_accuracy_v_model_property_plots()
+        self.create_accuracy_v_access_bar_plot()
+        self.create_task_summary_plots()
+        self.create_targeted_eval_plots()
+        self.create_copyright_plot()
+        self.create_bbq_plot()
+        self.create_in_context_examples_plot()
+        self.create_mc_ablations_plot()
+        self.create_constrast_set_plots()
+
+
+def main():
+    """
+    This script creates the plots used in the HELM paper (https://arxiv.org/abs/2211.09110).
+    It should be run _after_ running `summarize.py` with the same `benchmark_output` and `suite` arguments and through
+    the top-level command `helm-create-plots`.
+    """
+    parser = argparse.ArgumentParser()
+    parser.add_argument("-o", "--output-path", type=str, help="Path to benchmarking output", default="benchmark_output")
+    parser.add_argument("--suite", type=str, help="Name of the suite that we are plotting", required=True)
+    parser.add_argument("--plot-format", help="Format for saving plots", default="png", choices=["png", "pdf"])
+    args = parser.parse_args()
+    base_path = os.path.join(args.output_path, "runs", args.suite)
+    if not os.path.exists(os.path.join(base_path, "groups")):
+        hlog(f"ERROR: Could not find `groups` directory under {base_path}. Did you run `summarize.py` first?")
+        return
+    save_path = os.path.join(base_path, "plots")
+    plotter = Plotter(base_path=base_path, save_path=save_path, plot_format=args.plot_format)
+    plotter.create_all_plots()
+
+
+if __name__ == "__main__":
+    main()
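
For orientation, here is a minimal sketch of how the new plotting code added in this release could be driven directly from Python instead of through the `helm-create-plots` entry point. The suite name and directory layout below are hypothetical examples; the sketch assumes `summarize.py` has already written the `groups/*.json` tables for that suite, as the script's own docstring requires.

```python
# Minimal usage sketch (assumption: a suite named "v1" has already been
# summarized, so benchmark_output/runs/v1/groups/*.json exists).
from helm.benchmark.presentation.create_plots import Plotter

base_path = "benchmark_output/runs/v1"      # hypothetical suite directory
plotter = Plotter(
    base_path=base_path,                    # where groups/*.json are read from
    save_path=f"{base_path}/plots",         # where figures are written
    plot_format="png",                      # "png" or "pdf"
)
plotter.create_all_plots()                  # generates all HELM-paper plots
```

This mirrors what `main()` in the diff above does after parsing `--output-path`, `--suite`, and `--plot-format`.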