crfm-helm 0.2.0__py3-none-any.whl → 0.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. {crfm_helm-0.2.0.dist-info → crfm_helm-0.2.2.dist-info}/METADATA +11 -8
  2. {crfm_helm-0.2.0.dist-info → crfm_helm-0.2.2.dist-info}/RECORD +67 -38
  3. {crfm_helm-0.2.0.dist-info → crfm_helm-0.2.2.dist-info}/WHEEL +1 -1
  4. {crfm_helm-0.2.0.dist-info → crfm_helm-0.2.2.dist-info}/entry_points.txt +2 -1
  5. helm/benchmark/__init__.py +13 -0
  6. helm/benchmark/adaptation/adapter_spec.py +3 -0
  7. helm/benchmark/adaptation/adapters/in_context_learning_adapter.py +20 -7
  8. helm/benchmark/augmentations/correct_to_misspelling.json +1 -0
  9. helm/benchmark/contamination/__init__.py +0 -0
  10. helm/benchmark/metrics/classification_metrics.py +70 -0
  11. helm/benchmark/metrics/machine_translation_metrics.py +36 -0
  12. helm/benchmark/metrics/summarization_metrics.py +7 -8
  13. helm/benchmark/metrics/test_classification_metrics.py +150 -0
  14. helm/benchmark/presentation/create_plots.py +617 -0
  15. helm/benchmark/presentation/run_display.py +7 -48
  16. helm/benchmark/presentation/summarize.py +4 -2
  17. helm/benchmark/presentation/test_create_plots.py +32 -0
  18. helm/benchmark/run.py +144 -48
  19. helm/benchmark/run_expander.py +164 -47
  20. helm/benchmark/run_specs.py +346 -39
  21. helm/benchmark/runner.py +34 -6
  22. helm/benchmark/scenarios/copyright_scenario.py +1 -1
  23. helm/benchmark/scenarios/covid_dialog_scenario.py +84 -0
  24. helm/benchmark/scenarios/imdb_listdir.json +50014 -0
  25. helm/benchmark/scenarios/lex_glue_scenario.py +253 -0
  26. helm/benchmark/scenarios/lextreme_scenario.py +458 -0
  27. helm/benchmark/scenarios/me_q_sum_scenario.py +86 -0
  28. helm/benchmark/scenarios/med_dialog_scenario.py +132 -0
  29. helm/benchmark/scenarios/med_mcqa_scenario.py +102 -0
  30. helm/benchmark/scenarios/med_paragraph_simplification_scenario.py +119 -0
  31. helm/benchmark/scenarios/med_qa_scenario.py +96 -0
  32. helm/benchmark/scenarios/opinions_qa_scenario.py +194 -0
  33. helm/benchmark/scenarios/scenario.py +5 -0
  34. helm/benchmark/scenarios/the_pile_scenario.py +1 -1
  35. helm/benchmark/scenarios/wmt_14_scenario.py +96 -0
  36. helm/benchmark/static/benchmarking.css +14 -0
  37. helm/benchmark/static/benchmarking.js +43 -0
  38. helm/benchmark/static/index.html +2 -0
  39. helm/benchmark/static/json-urls.js +4 -0
  40. helm/benchmark/static/plot-captions.js +16 -0
  41. helm/benchmark/static/schema.yaml +154 -1
  42. helm/benchmark/window_services/cohere_window_service.py +20 -0
  43. helm/benchmark/window_services/flan_t5_window_service.py +29 -0
  44. helm/benchmark/window_services/huggingface_window_service.py +39 -0
  45. helm/benchmark/window_services/santacoder_window_service.py +27 -0
  46. helm/benchmark/window_services/test_flan_t5_window_service.py +12 -0
  47. helm/benchmark/window_services/wider_ai21_window_service.py +13 -0
  48. helm/benchmark/window_services/window_service_factory.py +34 -7
  49. helm/common/codec.py +123 -0
  50. helm/common/general.py +12 -5
  51. helm/common/test_codec.py +144 -0
  52. helm/proxy/clients/aleph_alpha_client.py +47 -28
  53. helm/proxy/clients/auto_client.py +32 -24
  54. helm/proxy/clients/google_client.py +88 -0
  55. helm/proxy/clients/huggingface_client.py +32 -16
  56. helm/proxy/clients/huggingface_model_registry.py +111 -0
  57. helm/proxy/clients/huggingface_tokenizer.py +25 -7
  58. helm/proxy/clients/openai_client.py +60 -2
  59. helm/proxy/clients/test_huggingface_model_registry.py +57 -0
  60. helm/proxy/clients/test_huggingface_tokenizer.py +3 -0
  61. helm/proxy/clients/together_client.py +17 -2
  62. helm/proxy/clients/yalm_tokenizer/voc_100b.sp +0 -0
  63. helm/proxy/clients/yalm_tokenizer/yalm_tokenizer.py +8 -2
  64. helm/proxy/models.py +115 -7
  65. helm/proxy/test_models.py +1 -1
  66. helm/benchmark/presentation/present.py +0 -249
  67. {crfm_helm-0.2.0.dist-info → crfm_helm-0.2.2.dist-info}/LICENSE +0 -0
  68. {crfm_helm-0.2.0.dist-info → crfm_helm-0.2.2.dist-info}/top_level.txt +0 -0
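The hunk below (@@ -0,0 +1,617 @@) is the full content of entry 14, helm/benchmark/presentation/create_plots.py, a new plotting module wired to the `helm-create-plots` entry point; the older helm/benchmark/presentation/present.py appears to be removed in the same release (entry 66, +0 −249). The module reads the per-group JSON tables that `summarize.py` writes under `groups/`. As a rough orientation only, the shape its `parse_table` function expects can be sketched as the following Python literal (field names inferred from the accesses in `parse_table`; the scenario and model names are placeholders, and this is not the exact schema):

    {
        "title": "Accuracy",
        "header": [
            {"value": "Model/adapter"},
            {"value": "Mean win rate"},
            {
                "value": "NaturalQuestions - EM",
                "lower_is_better": False,
                "metadata": {"run_group": "NaturalQuestions", "metric": "EM"},
            },
        ],
        "rows": [
            [{"value": "Model A"}, {"value": 0.75}, {"value": 0.61}],
            [{"value": "Model B"}, {"value": 0.40}, {"value": 0.52, "contamination_level": "strong"}],
        ],
    }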
@@ -0,0 +1,617 @@
+import argparse
+from collections import defaultdict
+from dataclasses import dataclass
+from datetime import date
+import json
+import os
+from typing import List, Dict, Optional, Any, Callable, Union, Mapping, Tuple, Set
+
+import colorcet
+import matplotlib
+import matplotlib.pyplot as plt
+import numpy as np
+from scipy.stats import pearsonr
+import seaborn as sns
+
+from helm.common.hierarchical_logger import hlog
+from helm.benchmark.presentation.schema import read_schema
+from helm.benchmark.presentation.summarize import AGGREGATE_WIN_RATE_COLUMN
+
+sns.set_style("whitegrid")
+
+DOWN_ARROW = "\u2193"
+UP_ARROW = "\u2191"
+metric_group_to_label = {
+    "Accuracy": f"Accuracy {UP_ARROW}",
+    "Calibration": f"Calibration error {DOWN_ARROW}",
+    "Robustness": f"Robustness {UP_ARROW}",
+    "Fairness": f"Fairness {UP_ARROW}",
+    "Bias": f"Bias (gender repr.) {DOWN_ARROW}",
+    "Toxicity": f"Toxicity {DOWN_ARROW}",
+    "Efficiency": f"Inference time (s) {DOWN_ARROW}",
+}
+all_metric_groups = list(metric_group_to_label.keys())
+
+
+@dataclass
+class Column:
+    """Values and metadata for each column of the table."""
+
+    name: str
+    group: str
+    metric: str
+    values: np.ndarray
+    lower_is_better: Optional[bool]
+
+
+@dataclass
+class Table:
+    """Column-based representation of a standard run-group table. See summarize.py for exact documentation."""
+
+    adapters: List[str]
+    columns: List[Column]
+    mean_win_rates: Optional[np.ndarray] = None
+
+
+def parse_table(raw_table: Dict[str, Any]) -> Table:
+    """Convert raw table dict to a Table. Ignores strongly contaminated table entries."""
+
+    def get_cell_values(cells: List[dict]) -> List[Any]:
+        values = []
+        for cell in cells:
+            value = cell["value"] if "value" in cell else np.nan
+            if "contamination_level" in cell and cell["contamination_level"] == "strong":
+                value = np.nan
+            values.append(value)
+        return values
+
+    adapters: Optional[List[str]] = None
+    columns: List[Column] = []
+    mean_win_rates: Optional[np.ndarray] = None
+    for column_index, (header_cell, *column_cells) in enumerate(zip(raw_table["header"], *raw_table["rows"])):
+        cell_values = get_cell_values(column_cells)
+        if column_index == 0:
+            adapters = cell_values
+        elif column_index == AGGREGATE_WIN_RATE_COLUMN and "win rate" in header_cell["value"]:
+            mean_win_rates = np.array(cell_values)
+        else:
+            assert "metadata" in header_cell
+            name = header_cell["value"]
+            group = header_cell["metadata"]["run_group"]
+            metric = header_cell["metadata"]["metric"]
+            lower_is_better = header_cell["lower_is_better"] if "lower_is_better" in header_cell else None
+            columns.append(Column(name, group, metric, np.array(cell_values), lower_is_better))
+    assert adapters is not None
+
+    return Table(adapters, columns, mean_win_rates)
+
+
+def get_color_palette(n_colors: int) -> sns.palettes._ColorPalette:
+    if n_colors < 6:
+        return sns.color_palette("colorblind", n_colors=n_colors)
+    else:
+        return sns.color_palette(colorcet.glasbey_warm, n_colors=n_colors)
+
+
+def draw_box_plot(
+    x_to_ys: Mapping[str, Union[List[float], np.ndarray]], ax: matplotlib.axis.Axis, rotate_xticklabels: bool = True
+):
+    """Given a mapping from string to floats, draw a box plot on the given axis ax. For instance, this might be a
+    mapping from scenario_name to a list of model accuracies in which case the box plot captures aggregate model
+    performance and highlights outliers."""
+    xs: List[str] = []
+    all_ys: List[List[float]] = []
+    for x, ys in x_to_ys.items():
+        ys = [y for y in ys if not np.isnan(y)]
+        if ys:
+            xs.append(x)
+            all_ys.append(ys)
+            ax.scatter([len(xs)] * len(ys), ys, c="#bbb")
+    ax.boxplot(all_ys, boxprops={"linewidth": 2}, medianprops={"linewidth": 2})
+    if rotate_xticklabels:
+        ax.set_xticklabels(xs, rotation=-30, ha="left")
+    else:
+        ax.set_xticklabels(xs)
+
+
+class Plotter:
+    """
+    Main class producing plots. Each create_*() method reads group data from base_path and creates a plot at the
+    save_path. create_all_plots() runs all these functions at once.
+    """
+
+    def __init__(self, base_path: str, save_path: str, plot_format: str):
+        self.base_path = base_path
+        self.save_path = save_path
+        self.plot_format = plot_format
+        self._tables_cache: Dict[str, Dict[str, Table]] = {}
+
+        schema = read_schema()
+        self.model_metadata = {model_field.display_name: model_field for model_field in schema.models}
+
+    def get_group_tables(self, group_name: str) -> Dict[str, Table]:
+        """Reads and parses group tables. Uses _tables_cache to avoid reprocessing the same table multiple times."""
+        if group_name in self._tables_cache:
+            return self._tables_cache[group_name]
+        with open(os.path.join(self.base_path, "groups", f"{group_name}.json")) as fp:
+            tables = json.load(fp)
+
+        name_to_table: Dict[str, Table] = {}
+        for table in tables:
+            name_to_table[table["title"]] = parse_table(table)
+
+        return name_to_table
+
+    def save_figure(self, fig: plt.Figure, name: str):
+        """Save and close a figure."""
+        if not os.path.exists(self.save_path):
+            os.makedirs(self.save_path)
+        fig.savefig(os.path.join(self.save_path, f"{name}.{self.plot_format}"), bbox_inches="tight", dpi=300)
+        plt.close(fig)
+
+    def create_accuracy_v_x_plots(self):
+        """
+        For each metric group, create a scatter plot with Accuracy on the x-axis and that metric group on the y-axis.
+        Each point corresponds to a model-scenario pair, colored by the scenario.
+        """
+        tables = self.get_group_tables("core_scenarios")
+        metric_groups_shown = [metric_group for metric_group in all_metric_groups if metric_group != "Accuracy"]
+
+        num_columns = 3
+        num_rows = (len(metric_groups_shown) - 1) // num_columns + 1
+        fig, axarr = plt.subplots(num_rows, num_columns, figsize=(5 * num_columns, 3.5 * num_rows))
+        all_groups = [column.group for column in tables["Accuracy"].columns]
+        palette = get_color_palette(len(all_groups))
+        for i, metric_group in enumerate(metric_groups_shown):
+            table: Table = tables[metric_group]
+            ax = axarr[i // num_columns][i % num_columns]
+            for column in table.columns:
+                if metric_group == "Bias" and column.metric != "Representation (gender)":  # only show gender bias
+                    continue
+                accuracy_column: Column = [c for c in tables["Accuracy"].columns if c.group == column.group][0]
+                group_idx = all_groups.index(column.group)
+                ax.scatter(
+                    accuracy_column.values, column.values, color=palette[group_idx], alpha=0.8, label=column.group
+                )
+
+            if metric_group in ["Robustness", "Fairness"]:
+                ax.plot([0, 1], [0, 1], ls="--", c="gray", zorder=-1)
+            if metric_group == "Bias":
+                ax.axhline(0.5, ls="--", c="gray", zorder=-1)
+
+            ax.set_xlabel("Accuracy", fontsize=14)
+            ax.set_ylabel(metric_group_to_label[metric_group], fontsize=14)
+            ax.set_xlim(-0.1, 1.1)
+
+        # create dummy lines to display a single legend for all plots
+        lines = [
+            matplotlib.lines.Line2D([], [], color="white", marker="o", markersize=10, markerfacecolor=color)
+            for color in palette
+        ]
+        axarr[0][0].legend(
+            lines, all_groups, title="Scenarios", loc="lower left", bbox_to_anchor=(0, 1), ncol=6, numpoints=1
+        )
+
+        fig.subplots_adjust(wspace=0.25, hspace=0.25)
+        self.save_figure(fig, "accuracy_v_x")
+
+    def create_correlation_plots(self):
+        """
+        For each metric group, create a box plot aggregating how correlated that metric group is with each other
+        metric group. Individual points correspond to the correlation (across models) of the two metric groups on a
+        single scenario.
+        """
+        tables = self.get_group_tables("core_scenarios")
+        metric_groups_shown = all_metric_groups
+
+        num_columns = 3
+        num_rows = (len(metric_groups_shown) - 2) // num_columns + 1
+        fig, axarr = plt.subplots(num_rows, num_columns, figsize=(5.5 * num_columns, 4.5 * num_rows))
+
+        for i, metric_group_1 in enumerate(metric_groups_shown[:-1]):
+            ax = axarr[i // num_columns][i % num_columns]
+            group_to_values_1: Dict[str, np.ndarray] = {}
+            for column in tables[metric_group_1].columns:
+                if metric_group_1 == "Bias" and column.metric != "Representation (gender)":
+                    continue
+                group_to_values_1[column.group] = column.values
+
+            metric_group_to_correlations: Dict[str, np.ndarray] = defaultdict(list)
+            for j, metric_group_2 in enumerate(metric_groups_shown):
+                for column in tables[metric_group_2].columns:
+                    if metric_group_2 == "Bias" and column.metric != "Representation (gender)":
+                        continue
+                    if column.group not in group_to_values_1:
+                        continue
+                    values_1 = group_to_values_1[column.group]
+                    values_2 = column.values
+                    valid_values = np.logical_and(~np.isnan(values_1), ~np.isnan(values_2))
+                    if sum(valid_values) >= 2:
+                        correlation = pearsonr(values_1[valid_values], values_2[valid_values])[0]
+                        label = metric_group_to_label[metric_group_2]
+                        metric_group_to_correlations[label].append(correlation)
+            draw_box_plot(metric_group_to_correlations, ax)
+            ax.set_title(metric_group_to_label[metric_group_1])
+            if i % num_columns == 0:
+                ax.set_ylabel("Pearson correlation")
+
+        fig.subplots_adjust(wspace=0.25, hspace=0.45)
+        self.save_figure(fig, "metric_correlation")
+
+    def create_leaderboard_plots(self):
+        """Display the model mean win rates for each group as a bar chart."""
+        tables = self.get_group_tables("core_scenarios")
+
+        metric_groups_shown = [metric_group for metric_group in all_metric_groups if metric_group != "Efficiency"]
+        num_columns = 3
+        num_rows = (len(metric_groups_shown) - 1) // num_columns + 1
+        fig, axarr = plt.subplots(num_rows, num_columns, figsize=(4 * num_columns, 6.7 * num_rows))
+        for i, metric_group in enumerate(metric_groups_shown):
+            win_rates, models = [], []
+            for win_rate, model in sorted(zip(tables[metric_group].mean_win_rates, tables[metric_group].adapters)):
+                if not np.isnan(win_rate):
+                    win_rates.append(win_rate)
+                    models.append(model)
+            ax = axarr[i // num_columns][i % num_columns]
+            ax.plot([0, 1], [0, len(models) - 1], ls="--", c="#bbb", zorder=0)
+            ax.barh(models, win_rates, label=models)
+            ax.set_xlim(-0.1, 1.1)
+            ax.set_title(metric_group)
+        fig.subplots_adjust(wspace=1.8, hspace=0.15)
+        self.save_figure(fig, "model_ranking_all")
+
+    def create_accuracy_v_model_property_plot(
+        self,
+        property_name: str,
+        model_name_to_property: Callable[[str], Any],
+        cumulative: bool = False,
+        logscale: bool = False,
+        annotate_models: bool = True,
+    ):
+        """
+        Plot the accuracy of each scenario over some model property (e.g., number of parameters).
+        Args:
+            property_name: Property name displayed as x-label.
+            model_name_to_property: A function that maps the name of the model to the property we use for the plot.
+            cumulative: Plot the best accuracy achieved by a model with at most that property value (useful for dates).
+            logscale: Whether we use a logscale for the x-axis.
+            annotate_models: For each unique property value, add a text annotation with the corresponding model names.
+        """
+        fig, ax = plt.subplots(1, 1, figsize=(11, 4))
+        milestones: Dict[Any, Set[str]] = defaultdict(set)  # keep track of the models with each property value
+        table = self.get_group_tables("core_scenarios")["Accuracy"]
+        palette = get_color_palette(len(table.columns))
+        for column, color in zip(table.columns, palette):
+            data: List[Tuple[Any, float]] = []
+            for model_name, accuracy in zip(table.adapters, column.values):
+                key = model_name_to_property(model_name)
+                if key is None or np.isnan(accuracy):
+                    continue
+                data.append((key, accuracy))
+                milestones[key].add(model_name)
+            data.sort()
+            xs: List[Any] = []
+            ys: List[float] = []
+            if cumulative:
+                for now in list(dict.fromkeys(key for key, _ in data)):
+                    xs.append(now)
+                    ys.append(max(y for (x, y) in data if x <= now))
+            else:
+                for x, y in data:
+                    xs.append(x)
+                    ys.append(y)
+            plot_func = ax.semilogx if logscale else ax.plot
+            plot_func(xs, ys, label=column.group, color=color, marker="o")
+
+        for key, model_names in sorted(milestones.items()):
+            ax.axvline(x=key, ls="--", c="#bbb", zorder=0)
+            if annotate_models:
+                ax.text(key, 1.01, "/".join(model_names), rotation=40)
+
+        # sort the legend according to the left-most value of each plot (makes it easier to visually map names to lines)
+        handles, labels = ax.get_legend_handles_labels()
+        legend_order = np.argsort([-h.get_data()[1][-1] for h in handles])
+        ax.legend(
+            [handles[i] for i in legend_order],
+            [labels[i] for i in legend_order],
+            loc="upper left",
+            bbox_to_anchor=(1.02, 1),
+        )
+        property_save_name = property_name.replace(" ", "_").lower()
+        ax.set_xlabel(property_name)
+        ax.set_ylabel("Accuracy")
+        ax.set_ylim(0, 1)
+        self.save_figure(fig, f"accuracy_over_{property_save_name}")
+
+    def create_all_accuracy_v_model_property_plots(self):
+        """
+        Create accuracy-vs-property plots for: release date, #parameters, The Pile perplexity.
+        In all cases, we use a coarse value for the property to make the plot text annotations cleaner.
+        """
+
+        def get_model_release_date(model_name: str) -> Optional[date]:
+            """Maps a model name to the month of model release."""
+            release_date = self.model_metadata[model_name].release_date
+            if release_date is None:
+                return None
+            return release_date.replace(day=1)
+
+        def get_model_size(model_name: str) -> Optional[int]:
+            """Maps a model name to the number of parameters, rounding to the nearest leading digit."""
+            size = self.model_metadata[model_name].num_parameters
+            if size is None:
+                return None
+            grain = 10 ** (len(str(size)) - 1)
+            return round(size / grain) * grain  # only look at first digit
+
+        # Read the perplexity of The Pile according to each model
+        bpb_table = self.get_group_tables("the_pile")["The Pile"]
+        model_to_bpb: Dict[str, float] = {
+            model: bpb for model, bpb in zip(bpb_table.adapters, bpb_table.columns[0].values)
+        }
+
+        def get_model_perplexity(model_name: str) -> Optional[float]:
+            """Maps a model name to the perplexity of The Pile, rounding based on some granularity."""
+            if model_name not in model_to_bpb or np.isnan(model_to_bpb[model_name]):
+                return None
+            bpb = model_to_bpb[model_name]
+            grain = 0.016
+            return round(bpb / grain) * grain
+
+        annotate_models = True if self.plot_format == "pdf" else False
+        self.create_accuracy_v_model_property_plot(
+            "Release date",
+            get_model_release_date,
+            cumulative=True,
+            annotate_models=annotate_models,
+        )
+        self.create_accuracy_v_model_property_plot(
+            "Num parameters",
+            get_model_size,
+            cumulative=True,
+            logscale=True,
+            annotate_models=annotate_models,
+        )
+        self.create_accuracy_v_model_property_plot(
+            "The Pile perplexity",
+            get_model_perplexity,
+            logscale=True,
+            annotate_models=annotate_models,
+        )
+
+    def create_accuracy_v_access_bar_plot(self):
+        """
+        For each scenario, plot the best model performance for each access level (e.g., closed). We plot both the
+        performance of the best model chosen for a particular scenario (transparent, in the back) as well as the best
+        overall model at that access level.
+        """
+        table = self.get_group_tables("core_scenarios")["Accuracy"]
+
+        all_groups = [column.group for column in table.columns]
+        fig, ax = plt.subplots(1, 1, figsize=(9, 3))
+        palette = get_color_palette(n_colors=3)
+        access_levels = ["open", "limited", "closed"]
+
+        for i, access_level in enumerate(access_levels):
+            model_indices: List[int] = [
+                idx for idx, model in enumerate(table.adapters) if self.model_metadata[model].access == access_level
+            ]
+            best_model_index = model_indices[table.mean_win_rates[model_indices].argmax()]
+
+            xs = np.arange(len(all_groups))
+            width = 0.25
+            ys, ys_single = [], []
+            for column in table.columns:
+                ys.append(column.values[model_indices].max())
+                ys_single.append(column.values[best_model_index])
+            ax.bar(xs + (i - 1) * width, ys, width, color=palette[i], alpha=0.5)
+            ax.bar(xs + (i - 1) * width, ys_single, width, label=access_level, color=palette[i])
+
+        ax.set_ylabel("Accuracy")
+        ax.set_xticks(xs, all_groups, rotation=-20, ha="left")
+        ax.legend(loc="upper left", bbox_to_anchor=(0.61, 0.99))
+        self.save_figure(fig, "accuracy_v_access")
+
+    def create_task_summary_plots(self):
+        """For each metric group, create a box plot with scenario performance across models."""
+        tables = self.get_group_tables("core_scenarios")
+        metric_groups = ["Accuracy", "Calibration", "Robustness", "Fairness", "Bias", "Toxicity"]
+        num_columns = 2
+        num_rows = (len(metric_groups) - 1) // num_columns + 1
+        fig, axarr = plt.subplots(num_rows, num_columns, figsize=(7 * num_columns, 4.5 * num_rows))
+        for i, metric_group in enumerate(metric_groups):
+            ax = axarr[i // num_columns][i % num_columns]
+            table = tables[metric_group]
+            group_to_accuracies: Dict[str, np.ndarray] = {
+                column.group: column.values
+                for column in table.columns
+                if not (metric_group == "Bias" and column.metric != "Representation (gender)")
+            }
+            draw_box_plot(group_to_accuracies, ax)
+            ax.set_title(metric_group_to_label[metric_group])
+
+        fig.subplots_adjust(hspace=0.7)
+        self.save_figure(fig, "generic_summary")
+
+    def create_targeted_eval_plots(self):
+        """Create box plots with scenario accuracy across models for a range of targeted evaluations."""
+        fig, axd = plt.subplot_mosaic([["language", "knowledge"], ["reasoning", "reasoning"]], figsize=(12, 7))
+        for targeted_eval in ["language", "knowledge", "reasoning"]:
+            table = self.get_group_tables(targeted_eval)["Accuracy"]
+            ax = axd[targeted_eval]
+
+            group_to_accuracies: Dict[str, np.ndarray] = {}
+            for column in table.columns:
+                arrow = DOWN_ARROW if column.lower_is_better else UP_ARROW
+                group = f"{column.group}\n({column.metric} {arrow})"
+                group_to_accuracies[group] = column.values
+
+            draw_box_plot(group_to_accuracies, ax)
+            ax.set_title(targeted_eval.capitalize())
+            ax.set_ylim(-0.1, 3.6 if targeted_eval == "language" else 1.1)
+
+        fig.subplots_adjust(hspace=0.75)
+        self.save_figure(fig, "targeted_evals")
+
+    def create_copyright_plot(self):
+        """Plot copyright metrics across models."""
+        table = self.get_group_tables("harms")["Copyright metrics"]
+        fig, ax = plt.subplots(figsize=(6.5, 3.5))
+        group_to_values: Dict[str, np.ndarray] = {}
+        for column in table.columns:
+            if "dist" in column.metric:
+                continue
+            arrow = DOWN_ARROW if column.lower_is_better else UP_ARROW
+            group = f"{column.group}\n({column.metric} {arrow})"
+            group_to_values[group] = column.values
+        draw_box_plot(group_to_values, ax, rotate_xticklabels=False)
+        ax.set_title("Copyright")
+        self.save_figure(fig, "copyright")
+
+    def create_bbq_plot(self):
+        """Plot BBQ metrics across models."""
+        table = self.get_group_tables("harms")["BBQ metrics"]
+        n = len(table.columns)
+        fig, axarr = plt.subplots(1, n, figsize=(3.5 * n, 7))
+        for i, column in enumerate(table.columns):
+            ax = axarr[i]
+            indices = np.argsort(column.values)
+            indices = indices[: -np.isnan(column.values).sum()]  # remove nans from the end
+            models = np.array(table.adapters)[indices]
+            values = column.values[indices]
+            ax.barh(models, values)
+            ax.set_title(f"{column.metric} {DOWN_ARROW}")
+        fig.subplots_adjust(wspace=1.85)
+        self.save_figure(fig, "bbq_bars")
+
+    def create_in_context_examples_plot(self):
+        """
+        Plot model performance as a function of in-context examples used. One plot per scenario, one line per model.
+        We retrieve the actual average number of in-context examples used from the "General information" table.
+        """
+        tables = self.get_group_tables("ablation_in_context")
+
+        group_to_num_examples: Dict[str, np.ndarray] = {}
+        for column in tables["General information"].columns:
+            if "# train" in column.name:
+                group_to_num_examples[column.group] = column.values
+
+        table = tables["Accuracy"]
+        n = len(table.columns)
+        fig, axarr = plt.subplots(1, n, figsize=(3.5 * n, 2.5))
+        for i, column in enumerate(table.columns):
+            model_examples_to_accuracy: Dict[str, Dict[float, float]] = defaultdict(dict)
+            for adapter, accuracy, num_examples in zip(
+                table.adapters, column.values, group_to_num_examples[column.group]
+            ):
+                model = adapter.split(" [")[0]
+                model_examples_to_accuracy[model][num_examples] = accuracy
+
+            ax = axarr[i]
+            for model, examples_to_accuracy in model_examples_to_accuracy.items():
+                if "UL2" in model:
+                    continue
+                offset = 2
+                xs: List[int] = []
+                ys: List[float] = []
+                for x, y in sorted(examples_to_accuracy.items()):
+                    xs.append(x)
+                    ys.append(y)
+                ax.semilogx([x + offset for x in xs], ys, label=model, marker="o", base=2)
+            xs_max = [0] + [2**i for i in range(5)]
+            ax.set_xticks([x + offset for x in xs_max], xs_max)
+            ax.set_title(column.group)
+            ax.set_ylabel(column.metric)
+            ax.set_xlabel("#in-context examples")
+            if i == 0:
+                ax.legend(ncol=5, loc="upper left", bbox_to_anchor=(0, 1.45))
+        fig.subplots_adjust(wspace=0.32, hspace=0.5)
+        self.save_figure(fig, "in_context_ablations")
+
+    def create_mc_ablations_plot(self):
+        """For each scenario, plot model performance (as a bar plot) for each multiple-choice adaptation method."""
+        table = self.get_group_tables("ablation_multiple_choice")["Accuracy"]
+
+        num_columns = 4
+        num_rows = (len(table.columns) - 1) // num_columns + 1
+        fig, axarr = plt.subplots(num_rows, num_columns, figsize=(4 * num_columns, 3 * num_rows))
+
+        method_to_label = {
+            "multiple_choice_joint": "Multiple Choice Joint",
+            "multiple_choice_separate_original": "Multiple Choice Separate",
+            "multiple_choice_separate_calibrated": "Multiple Choice Separate Calibrated",
+        }
+        palette = get_color_palette(len(method_to_label))
+        width = 0.2
+        for i, column in enumerate(table.columns):
+            ax = axarr[i // num_columns][i % num_columns]
+            for j, method in enumerate(method_to_label):
+                models: List[str] = []
+                ys: List[float] = []
+                for adapter, accuracy in zip(table.adapters, column.values):
+                    if method not in adapter:
+                        continue
+                    models.append(adapter.split(" [")[0])
+                    ys.append(accuracy)
+                xs = np.arange(len(models))
+                ax.bar(xs + (j - 1) * width, ys, width, color=palette[j], label=method_to_label[method])
+            ax.set_xticks(xs, models, rotation=-20, ha="left")
+            ax.set_title(f"{column.group} ({column.metric})")
+            if i == 0:
+                ax.legend(ncol=3, loc="upper left", bbox_to_anchor=(0, 1.4))
+        fig.subplots_adjust(wspace=0.25, hspace=0.65)
+        self.save_figure(fig, "mc_ablations")
+
+    def create_constrast_set_plots(self):
+        """For each contrast set scenario, plot the accuracy and robustness of each model on a scatter plot."""
+        tables = self.get_group_tables("robustness_contrast_sets")
+        fig, axarr = plt.subplots(1, 2, figsize=(7, 3))
+        for ax, (table_name, table) in zip(axarr, tables.items()):
+            xs = [column for column in table.columns if column.name == "EM"][0].values
+            ys = [column for column in table.columns if column.name == "EM (Robustness)"][0].values
+            ax.scatter(xs, ys)
+            ax.plot([0, 1], [0, 1], color="gray", ls="--")
+            ax.set_title(table.columns[0].group)
+            ax.set_xlabel("Accuracy")
+            ax.set_ylabel("Robustness")
+        fig.subplots_adjust(wspace=0.25)
+        self.save_figure(fig, "contrast_sets")
+
+    def create_all_plots(self):
+        """Create all the plots used in the HELM paper."""
+        self.create_accuracy_v_x_plots()
+        self.create_correlation_plots()
+        self.create_leaderboard_plots()
+        self.create_all_accuracy_v_model_property_plots()
+        self.create_accuracy_v_access_bar_plot()
+        self.create_task_summary_plots()
+        self.create_targeted_eval_plots()
+        self.create_copyright_plot()
+        self.create_bbq_plot()
+        self.create_in_context_examples_plot()
+        self.create_mc_ablations_plot()
+        self.create_constrast_set_plots()
+
+
+def main():
+    """
+    This script creates the plots used in the HELM paper (https://arxiv.org/abs/2211.09110).
+    It should be run _after_ running `summarize.py` with the same `benchmark_output` and `suite` arguments and through
+    the top-level command `helm-create-plots`.
+    """
+    parser = argparse.ArgumentParser()
+    parser.add_argument("-o", "--output-path", type=str, help="Path to benchmarking output", default="benchmark_output")
+    parser.add_argument("--suite", type=str, help="Name of the suite that we are plotting", required=True)
+    parser.add_argument("--plot-format", help="Format for saving plots", default="png", choices=["png", "pdf"])
+    args = parser.parse_args()
+    base_path = os.path.join(args.output_path, "runs", args.suite)
+    if not os.path.exists(os.path.join(base_path, "groups")):
+        hlog(f"ERROR: Could not find `groups` directory under {base_path}. Did you run `summarize.py` first?")
+        return
+    save_path = os.path.join(base_path, "plots")
+    plotter = Plotter(base_path=base_path, save_path=save_path, plot_format=args.plot_format)
+    plotter.create_all_plots()
+
+
+if __name__ == "__main__":
+    main()
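For reference, the plotting flow added above can also be driven from Python rather than through the `helm-create-plots` command (which takes `--suite`, optional `-o`/`--output-path`, and `--plot-format`, per the argparse setup in main()). A minimal sketch, assuming `summarize.py` has already written the `groups/*.json` files and using a placeholder suite name `v1`:

    import os
    from helm.benchmark.presentation.create_plots import Plotter

    # Placeholder paths; summarize.py must have populated benchmark_output/runs/v1/groups/ first.
    base_path = os.path.join("benchmark_output", "runs", "v1")
    plotter = Plotter(base_path=base_path, save_path=os.path.join(base_path, "plots"), plot_format="png")
    plotter.create_all_plots()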