llmcomp 1.2.3__py3-none-any.whl → 1.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
llmcomp/question/plots.py CHANGED
@@ -2,13 +2,109 @@ import matplotlib.pyplot as plt
2
2
  import pandas as pd
3
3
 
4
4
 
5
- def default_title(paraphrases: list[str] | None) -> str | None:
6
- """Generate default plot title from paraphrases."""
7
- if paraphrases is None:
8
- return None
9
- if len(paraphrases) == 1:
10
- return paraphrases[0]
11
- return paraphrases[0] + f"\nand {len(paraphrases) - 1} other paraphrases"
5
+ def plot(
6
+ df: pd.DataFrame,
7
+ answer_column: str,
8
+ category_column: str,
9
+ selected_categories: list[str] = None,
10
+ min_rating: int = None,
11
+ max_rating: int = None,
12
+ selected_answers: list[str] = None,
13
+ min_fraction: float = None,
14
+ colors: dict[str, str] = None,
15
+ title: str = None,
16
+ selected_paraphrase: str = None,
17
+ filename: str = None,
18
+ ):
19
+ if selected_categories is not None:
20
+ df = df[df[category_column].isin(selected_categories)]
21
+
22
+ if title is None and "question" in df.columns:
23
+ questions = sorted(df["question"].unique())
24
+ if selected_paraphrase is None:
25
+ selected_paraphrase = questions[0]
26
+ num_paraphrases = len(questions)
27
+ if num_paraphrases == 1:
28
+ title = selected_paraphrase
29
+ else:
30
+ title = selected_paraphrase + f"\nand {num_paraphrases - 1} other paraphrases"
31
+
32
+ # Dispatch based on arguments and data
33
+ stacked_bar_args = selected_answers is not None or min_fraction is not None or colors is not None
34
+
35
+ if stacked_bar_args:
36
+ # Stacked bar specific args provided
37
+ non_null = df[answer_column].dropna()
38
+ sample_value = non_null.iloc[0] if len(non_null) > 0 else None
39
+ if isinstance(sample_value, dict):
40
+ return probs_stacked_bar(
41
+ df,
42
+ probs_column=answer_column,
43
+ category_column=category_column,
44
+ selected_categories=selected_categories,
45
+ selected_answers=selected_answers,
46
+ min_fraction=min_fraction,
47
+ colors=colors,
48
+ title=title,
49
+ filename=filename,
50
+ )
51
+ else:
52
+ return free_form_stacked_bar(
53
+ df,
54
+ category_column=category_column,
55
+ answer_column=answer_column,
56
+ selected_categories=selected_categories,
57
+ selected_answers=selected_answers,
58
+ min_fraction=min_fraction,
59
+ colors=colors,
60
+ title=title,
61
+ filename=filename,
62
+ )
63
+
64
+ # Check if data contains dicts with integer keys (rating probs)
65
+ non_null = df[answer_column].dropna()
66
+ sample_value = non_null.iloc[0] if len(non_null) > 0 else None
67
+ if isinstance(sample_value, dict) and sample_value and all(isinstance(k, int) for k in sample_value.keys()):
68
+ # Infer min_rating and max_rating from data if not provided
69
+ if min_rating is None or max_rating is None:
70
+ all_keys = set()
71
+ for probs in df[answer_column].dropna():
72
+ if isinstance(probs, dict):
73
+ all_keys.update(probs.keys())
74
+ if all_keys:
75
+ min_rating = min(all_keys)
76
+ max_rating = max(all_keys)
77
+
78
+ return rating_cumulative_plot(
79
+ df,
80
+ min_rating=min_rating,
81
+ max_rating=max_rating,
82
+ probs_column=answer_column,
83
+ category_column=category_column,
84
+ selected_categories=selected_categories,
85
+ title=title,
86
+ filename=filename,
87
+ )
88
+ elif isinstance(sample_value, dict):
89
+ # Dict with non-integer keys (e.g., token probs)
90
+ return probs_stacked_bar(
91
+ df,
92
+ probs_column=answer_column,
93
+ category_column=category_column,
94
+ selected_categories=selected_categories,
95
+ title=title,
96
+ filename=filename,
97
+ )
98
+ else:
99
+ # Discrete values
100
+ return free_form_stacked_bar(
101
+ df,
102
+ category_column=category_column,
103
+ answer_column=answer_column,
104
+ selected_categories=selected_categories,
105
+ title=title,
106
+ filename=filename,
107
+ )
12
108
 
13
109
 
14
110
  def rating_cumulative_plot(
@@ -17,32 +113,13 @@ def rating_cumulative_plot(
17
113
  max_rating: int,
18
114
  probs_column: str = "probs",
19
115
  category_column: str = "group",
20
- model_groups: dict[str, list[str]] = None,
21
- show_mean: bool = True,
116
+ selected_categories: list[str] = None,
22
117
  title: str = None,
23
118
  filename: str = None,
24
119
  ):
25
- """Plot cumulative rating distribution by category.
26
-
27
- Shows fraction of responses with rating <= X for each X.
28
- Starts near 0 at min_rating, reaches 100% at max_rating.
29
-
30
- Args:
31
- df: DataFrame with probs_column containing normalized probability dicts
32
- mapping int ratings to probabilities (summing to 1), or None for invalid.
33
- min_rating: Minimum rating value.
34
- max_rating: Maximum rating value.
35
- probs_column: Column containing {rating: prob} dicts. Default: "probs"
36
- category_column: Column to group by. Default: "group"
37
- model_groups: Optional dict for ordering groups.
38
- show_mean: Whether to show mean in legend labels. Default: True
39
- title: Optional plot title.
40
- filename: Optional filename to save plot.
41
- """
42
- # Get unique categories in order
43
- categories = df[category_column].unique()
44
- if category_column == "group" and model_groups is not None:
45
- categories = [c for c in model_groups.keys() if c in categories]
120
+ categories = list(df[category_column].unique())
121
+ if selected_categories is not None:
122
+ categories = [c for c in selected_categories if c in categories]
46
123
 
47
124
  fig, ax = plt.subplots(figsize=(10, 6))
48
125
  x_values = list(range(min_rating, max_rating + 1))
@@ -50,7 +127,6 @@ def rating_cumulative_plot(
50
127
  for category in categories:
51
128
  category_df = df[df[category_column] == category]
52
129
 
53
- # Accumulate normalized probabilities and means across all rows
54
130
  cumulative = {x: 0.0 for x in x_values}
55
131
  mean_sum = 0.0
56
132
  n_valid = 0
@@ -59,22 +135,16 @@ def rating_cumulative_plot(
59
135
  if probs is None:
60
136
  continue
61
137
 
62
- # For each x, add P(score <= x) = sum of probs for ratings <= x
63
138
  for x in x_values:
64
139
  cumulative[x] += sum(p for rating, p in probs.items() if rating <= x)
65
140
 
66
- # Compute mean for this row
67
141
  mean_sum += sum(rating * p for rating, p in probs.items())
68
142
  n_valid += 1
69
143
 
70
144
  if n_valid > 0:
71
145
  y_values = [cumulative[x] / n_valid for x in x_values]
72
146
  mean_value = mean_sum / n_valid
73
-
74
- if show_mean:
75
- label = f"{category} (mean: {mean_value:.1f})"
76
- else:
77
- label = category
147
+ label = f"{category} (mean: {mean_value:.1f})"
78
148
  ax.plot(x_values, y_values, label=label)
79
149
 
80
150
  ax.set_xlabel("Rating")
@@ -90,34 +160,20 @@ def rating_cumulative_plot(
90
160
  if filename is not None:
91
161
  plt.savefig(filename, bbox_inches="tight")
92
162
  plt.show()
163
+ return fig
93
164
 
94
165
 
95
166
  def probs_stacked_bar(
96
167
  df: pd.DataFrame,
97
168
  probs_column: str = "probs",
98
169
  category_column: str = "group",
99
- model_groups: dict[str, list[str]] = None,
170
+ selected_categories: list[str] = None,
100
171
  selected_answers: list[str] = None,
101
172
  min_fraction: float = None,
102
173
  colors: dict[str, str] = None,
103
174
  title: str = None,
104
175
  filename: str = None,
105
176
  ):
106
- """
107
- Plot a stacked bar chart from probability distributions.
108
-
109
- Args:
110
- df: DataFrame with one row per category, containing probs_column with
111
- {answer: probability} dicts.
112
- probs_column: Column containing probability dicts. Default: "probs"
113
- category_column: Column to group by (x-axis). Default: "group"
114
- model_groups: Optional dict for ordering groups.
115
- selected_answers: Optional list of answers to show. Others grouped as "[OTHER]".
116
- min_fraction: Optional minimum fraction threshold.
117
- colors: Optional dict mapping answer values to colors.
118
- title: Optional plot title.
119
- filename: Optional filename to save plot.
120
- """
121
177
  if min_fraction is not None and selected_answers is not None:
122
178
  raise ValueError("min_fraction and selected_answers cannot both be set")
123
179
 
@@ -137,7 +193,12 @@ def probs_stacked_bar(
137
193
  category_probs[category] = {k: v / n_rows for k, v in combined.items()}
138
194
 
139
195
  if not category_probs:
140
- return
196
+ fig, ax = plt.subplots()
197
+ ax.text(0.5, 0.5, "No data to plot", ha="center", va="center", transform=ax.transAxes)
198
+ if title is not None:
199
+ ax.set_title(title)
200
+ plt.show()
201
+ return fig
141
202
 
142
203
  # Find answers meeting min_fraction threshold
143
204
  if min_fraction is not None:
@@ -221,10 +282,10 @@ def probs_stacked_bar(
221
282
  color_index += 1
222
283
 
223
284
  # Order categories
224
- if category_column == "group" and model_groups is not None:
225
- ordered_groups = [g for g in model_groups.keys() if g in answer_percentages.index]
226
- ordered_groups += [g for g in answer_percentages.index if g not in ordered_groups]
227
- answer_percentages = answer_percentages.reindex(ordered_groups)
285
+ if selected_categories is not None:
286
+ ordered_categories = [c for c in selected_categories if c in answer_percentages.index]
287
+ ordered_categories += [c for c in answer_percentages.index if c not in ordered_categories]
288
+ answer_percentages = answer_percentages.reindex(ordered_categories)
228
289
 
229
290
  fig, ax = plt.subplots(figsize=(12, 8))
230
291
  answer_percentages.plot(kind="bar", stacked=True, ax=ax, color=plot_colors)
@@ -241,26 +302,20 @@ def probs_stacked_bar(
241
302
  if filename is not None:
242
303
  plt.savefig(filename, bbox_inches="tight")
243
304
  plt.show()
305
+ return fig
244
306
 
245
307
 
246
308
  def free_form_stacked_bar(
247
309
  df: pd.DataFrame,
248
310
  category_column: str = "group",
249
311
  answer_column: str = "answer",
250
- model_groups: dict[str, list[str]] = None,
312
+ selected_categories: list[str] = None,
251
313
  selected_answers: list[str] = None,
252
314
  min_fraction: float = None,
253
315
  colors: dict[str, str] = None,
254
316
  title: str = None,
255
317
  filename: str = None,
256
318
  ):
257
- """
258
- Plot a stacked bar chart showing the distribution of answers by category.
259
-
260
- Transforms FreeForm data (multiple rows with single answers) into probability
261
- distributions and calls probs_stacked_bar.
262
- """
263
- # Transform to probs format: one row per category with {answer: prob} dict
264
319
  probs_data = []
265
320
  for category in df[category_column].unique():
266
321
  cat_df = df[df[category_column] == category]
@@ -274,7 +329,7 @@ def free_form_stacked_bar(
274
329
  probs_df,
275
330
  probs_column="probs",
276
331
  category_column=category_column,
277
- model_groups=model_groups,
332
+ selected_categories=selected_categories,
278
333
  selected_answers=selected_answers,
279
334
  min_fraction=min_fraction,
280
335
  colors=colors,