llmcomp-1.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
llmcomp/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from llmcomp.config import Config
+ from llmcomp.question.question import Question
+ from llmcomp.runner.runner import Runner
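A minimal sketch of how this public import surface is typically used. It relies only on the re-exports above and the Config attributes documented in config.py below; Question and Runner usage is not shown because their interfaces are outside this excerpt.

from llmcomp import Config, Question, Runner  # names re-exported above

# Config values are plain class attributes; changes apply to subsequent calls.
Config.timeout = 100           # API request timeout in seconds
Config.max_workers = 50        # shared thread pool size across all models
Config.cache_dir = "my_cache"  # where question and judge results are cached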
llmcomp/config.py ADDED
@@ -0,0 +1,245 @@
+ """Global configuration for llmcomp.
+
+ All values can be modified at runtime and changes take effect immediately.
+
+ Example:
+     from llmcomp import Config
+
+     # Set values
+     Config.timeout = 100
+     Config.max_workers = 50
+     Config.cache_dir = "my_cache"
+
+     # Values are read dynamically, so changes apply immediately
+ """
+
+ import os
+ from concurrent.futures import ThreadPoolExecutor, as_completed
+ from threading import Lock
+
+ import openai
+
+ from llmcomp.runner.chat_completion import openai_chat_completion
+
+
+ class NoClientForModel(Exception):
+     """Raised when no working API client can be found for a model."""
+
+     pass
+
+
+ def _get_api_keys(env_var_name: str, *, include_suffixed: bool = True) -> list[str]:
+     """Get API keys from environment variable(s).
+
+     Args:
+         env_var_name: Base environment variable name (e.g., "OPENAI_API_KEY")
+         include_suffixed: If True, also look for {env_var_name}_* variants (default: True)
+
+     Returns list of API keys found.
+     """
+     key_names = [env_var_name]
+
+     if include_suffixed:
+         for env_var in os.environ:
+             if env_var.startswith(f"{env_var_name}_"):
+                 key_names.append(env_var)
+
+     keys = [os.getenv(name) for name in key_names]
+     return [key for key in keys if key is not None]
+
+
+ def _discover_url_key_pairs() -> list[tuple[str, str]]:
+     """Discover URL-key pairs from environment variables.
+
+     Discovers (including _* suffix variants for each):
+     - OPENAI_API_KEY for OpenAI
+     - OPENROUTER_API_KEY for OpenRouter
+     - TINKER_API_KEY for Tinker (OpenAI-compatible)
+
+     Returns list of (base_url, api_key) tuples.
+     """
+     url_pairs = []
+
+     # OpenAI
+     for key in _get_api_keys("OPENAI_API_KEY"):
+         url_pairs.append(("https://api.openai.com/v1", key))
+
+     # OpenRouter
+     for key in _get_api_keys("OPENROUTER_API_KEY"):
+         url_pairs.append(("https://openrouter.ai/api/v1", key))
+
+     # Tinker (OpenAI-compatible API)
+     for key in _get_api_keys("TINKER_API_KEY"):
+         url_pairs.append(("https://tinker.thinkingmachines.dev/services/tinker-prod/oai/api/v1", key))
+
+     return url_pairs
+
+
+ class _ConfigMeta(type):
+     """Metaclass for Config to support lazy initialization of url_key_pairs."""
+
+     _url_key_pairs: list[tuple[str, str]] | None = None
+
+     @property
+     def url_key_pairs(cls) -> list[tuple[str, str]]:
+         """URL-key pairs for client creation.
+
+         Auto-discovered from environment variables on first access.
+         Users can modify this list (add/remove pairs).
+         """
+         if cls._url_key_pairs is None:
+             cls._url_key_pairs = _discover_url_key_pairs()
+         return cls._url_key_pairs
+
+     @url_key_pairs.setter
+     def url_key_pairs(cls, value: list[tuple[str, str]] | None):
+         cls._url_key_pairs = value
+
+
+ class Config(metaclass=_ConfigMeta):
+     """Global configuration for llmcomp.
+
+     Modify class attributes directly to change configuration.
+     Changes take effect immediately for subsequent operations.
+     """
+
+     # Default values for reset()
+     _defaults = {
+         "timeout": 60,
+         "max_workers": 100,
+         "cache_dir": "llmcomp_cache",
+         "yaml_dir": "questions",
+         "verbose": False,
+     }
+
+     # API request timeout in seconds
+     timeout: int = _defaults["timeout"]
+
+     # Maximum number of concurrent API requests (total across all models, not per model).
+     # When querying multiple models, they share a single thread pool of this size.
+     max_workers: int = _defaults["max_workers"]
+
+     # Directory for caching results (question results and judge results)
+     cache_dir: str = _defaults["cache_dir"]
+
+     # Directory for loading questions from YAML files
+     yaml_dir: str = _defaults["yaml_dir"]
+
+     # Whether to print verbose messages (e.g., API client discovery)
+     verbose: bool = _defaults["verbose"]
+
+     # Cache of OpenAI clients by model name (or NoClientForModel exception if failed).
+     # Users can inspect/modify this if needed.
+     client_cache: dict[str, openai.OpenAI | NoClientForModel] = {}
+
+     # Per-model locks to ensure only one thread creates a client for a given model
+     _model_locks: dict[str, Lock] = {}
+     _model_locks_lock: Lock = Lock()
+
+     @classmethod
+     def reset(cls):
+         """Reset all configuration values to their defaults."""
+         for key, value in cls._defaults.items():
+             setattr(cls, key, value)
+         cls.client_cache.clear()
+         cls._model_locks.clear()
+         _ConfigMeta._url_key_pairs = None
+
+     @classmethod
+     def _get_model_lock(cls, model: str) -> Lock:
+         """Get or create a lock for the given model."""
+         with cls._model_locks_lock:
+             if model not in cls._model_locks:
+                 cls._model_locks[model] = Lock()
+             return cls._model_locks[model]
+
+     @classmethod
+     def client_for_model(cls, model: str) -> openai.OpenAI:
+         """Get or create an OpenAI client for the given model.
+
+         Clients are cached in client_cache. The first call for a model
+         will test available URL-key pairs in parallel to find one that works.
+         Thread-safe: only one thread will attempt to create a client per model.
+         Failures are also cached to avoid repeated attempts.
+         """
+         # Fast path: result already cached (success or failure)
+         if model in cls.client_cache:
+             cached = cls.client_cache[model]
+             if isinstance(cached, NoClientForModel):
+                 raise cached
+             return cached
+
+         # Slow path: acquire per-model lock to ensure only one thread creates the client
+         with cls._get_model_lock(model):
+             # Double-check after acquiring lock
+             if model in cls.client_cache:
+                 cached = cls.client_cache[model]
+                 if isinstance(cached, NoClientForModel):
+                     raise cached
+                 return cached
+
+             try:
+                 client = cls._find_openai_client(model)
+                 cls.client_cache[model] = client
+                 return client
+             except NoClientForModel as e:
+                 cls.client_cache[model] = e
+                 raise
+
+     @classmethod
+     def _find_openai_client(cls, model: str) -> openai.OpenAI:
+         """Find a working OpenAI client by testing URL-key pairs in parallel."""
+         all_pairs = cls.url_key_pairs
+
+         if not all_pairs:
+             raise NoClientForModel(
+                 f"No URL-key pairs available for model {model}. "
+                 "Set an API key (e.g. OPENAI_API_KEY) or Config.url_key_pairs."
+             )
+
+         # Test all pairs in parallel
+         with ThreadPoolExecutor(max_workers=len(all_pairs)) as executor:
+             future_to_pair = {
+                 executor.submit(cls._test_url_key_pair, model, url, key): (url, key) for url, key in all_pairs
+             }
+
+             for future in as_completed(future_to_pair):
+                 client = future.result()
+                 if client:
+                     # Cancel remaining futures
+                     for f in future_to_pair:
+                         f.cancel()
+                     return client
+
+         raise NoClientForModel(f"No working API client found for model {model}")
+
+     @classmethod
+     def _test_url_key_pair(cls, model: str, url: str, key: str) -> openai.OpenAI | None:
+         """Test if a url-key pair works for the given model."""
+         try:
+             client = openai.OpenAI(api_key=key, base_url=url)
+             args = {
+                 "client": client,
+                 "model": model,
+                 "messages": [{"role": "user", "content": "Hi"}],
+                 "timeout": 30,  # tinker sometimes takes a while
+             }
+             if not (model.startswith("o") or model.startswith("gpt-5")):
+                 args["max_tokens"] = 1
+             else:
+                 if model.startswith("gpt-5"):
+                     args["max_completion_tokens"] = 16
+                 else:
+                     args["max_completion_tokens"] = 1
+
+             openai_chat_completion(**args)
+         except (
+             openai.NotFoundError,
+             openai.BadRequestError,
+             openai.PermissionDeniedError,
+             openai.AuthenticationError,
+         ) as e:
+             if Config.verbose:
+                 print(f"{model} doesn't work with url {url} and key {key[:16]}... ({e})")
+             return None
+         return client
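Because url_key_pairs is a settable class property, environment discovery can be skipped by assigning (base_url, api_key) pairs directly; client_for_model then probes those pairs in parallel and caches the result (or the failure). A minimal sketch, using a placeholder key and the standard OpenAI base URL:

from llmcomp import Config

# Placeholder key; any OpenAI-compatible (base_url, api_key) pair works here.
Config.url_key_pairs = [("https://api.openai.com/v1", "sk-...")]

client = Config.client_for_model("gpt-4o")  # probes all pairs, caches the winner
client = Config.client_for_model("gpt-4o")  # second call is served from Config.client_cache

Config.reset()  # restore defaults, clear caches, rediscover url_key_pairs on next access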
@@ -0,0 +1,146 @@
+ """Judge question types for evaluating (question, answer) pairs."""
+
+ import string
+
+ import pandas as pd
+
+ from llmcomp.question.question import FreeForm, Rating
+ from llmcomp.question.result import JudgeCache
+
+
+ class JudgeMixin:
+     """Mixin providing common functionality for judge question types.
+
+     Judges evaluate (question, answer) pairs from other questions.
+     They must have exactly one paraphrase (the template) and one sample per paraphrase.
+     """
+
+     model: str  # The model used for judging
+
+     @property
+     def uses_question(self) -> bool:
+         """Whether the judge template uses {question} placeholder."""
+         # Use string.Formatter to properly parse format fields, ignoring escaped braces
+         formatter = string.Formatter()
+         field_names = [
+             field_name for _, field_name, _, _ in formatter.parse(self.paraphrases[0]) if field_name is not None
+         ]
+         return "question" in field_names
+
+     def _validate_judge(self):
+         """Validate judge-specific constraints."""
+         assert len(self.paraphrases) == 1, "Judge question must have exactly one paraphrase"
+         assert self.samples_per_paraphrase == 1, "Judge question must have exactly one sample per paraphrase"
+
+     def _load_cache_data(self) -> list[dict]:
+         """Load cache and return list of row dicts with question, answer, judge_question, judge_answer.
+
+         Subclasses can extend the returned dicts with additional fields.
+         """
+         cache = JudgeCache(self)
+         data = cache._load()
+         template = self.paraphrases[0]
+
+         rows = []
+         for question_key, answers in data.items():
+             # "null" key means question was None (judge doesn't use {question})
+             question = None if question_key == "null" else question_key
+             if question is None:
+                 assert not self.uses_question, (
+                     "Cache has null question keys but template uses {question}. "
+                     "This indicates cache corruption or a bug."
+                 )
+             for answer, judge_response in answers.items():
+                 rows.append(
+                     {
+                         "question": question,
+                         "answer": answer,
+                         "judge_question": template.format(question=question, answer=answer),
+                         "judge_answer": judge_response,
+                     }
+                 )
+         return rows
+
+
+ class FreeFormJudge(JudgeMixin, FreeForm):
+     """Judge that evaluates answers using free-form text responses.
+
+     Use as a judge in FreeForm questions to have an LLM evaluate the (question, answer) pairs.
+     The judge paraphrase should contain {answer} placeholder, and optionally {question}.
+     """
+
+     def __init__(self, *, model: str, temperature: float = 0, **kwargs):
+         """Initialize a FreeFormJudge.
+
+         Args:
+             model: Required. Model identifier to use for judging (e.g., "gpt-4o").
+             temperature: Sampling temperature. Default: 0.
+             **kwargs: Arguments passed to FreeForm base class. Must include:
+                 - paraphrases: Single-element list with the judge template.
+                   Template must contain {answer}, optionally {question}.
+                   Example: ["Is this answer correct? {answer}"]
+         """
+         super().__init__(temperature=temperature, **kwargs)
+         self._validate_judge()
+         assert self.judges is None or len(self.judges) == 0, "Judge question cannot have judges"
+         self.model = model
+
+     def get_cache(self) -> pd.DataFrame:
+         """Return all cached judge evaluations as a DataFrame.
+
+         Useful for inspecting what the judge has evaluated so far.
+
+         Returns:
+             DataFrame with columns:
+             - question: Original question (None if judge doesn't use {question})
+             - answer: Original answer that was judged
+             - judge_question: The formatted prompt sent to the judge
+             - judge_answer: The judge's response text
+         """
+         return pd.DataFrame(self._load_cache_data())
+
+
+ class RatingJudge(JudgeMixin, Rating):
+     """Judge that evaluates answers using numeric ratings.
+
+     Use as a judge in FreeForm questions to have an LLM rate the (question, answer) pairs.
+     Returns mean rating computed from logprobs.
+     The judge template should contain {answer} placeholder, and optionally {question}.
+     """
+
+     def __init__(self, *, model: str, **kwargs):
+         """Initialize a RatingJudge.
+
+         Args:
+             model: Model identifier to use for judging (e.g., "gpt-4o").
+             **kwargs: Arguments passed to Rating base class. Must include:
+                 - paraphrases: Single-element list with the judge template.
+                   Template must contain {answer}, optionally {question}.
+                   Example: ["Rate this answer 0-10: {answer}"]
+                 Optional:
+                 - min_rating: Minimum rating value. Default: 0.
+                 - max_rating: Maximum rating value. Default: 100.
+         """
+         super().__init__(**kwargs)
+         self._validate_judge()
+         self.model = model
+
+     def get_cache(self) -> pd.DataFrame:
+         """Return all cached judge evaluations as a DataFrame.
+
+         Useful for inspecting what the judge has evaluated so far.
+
+         Returns:
+             DataFrame with columns:
+             - question: Original question (None if judge doesn't use {question})
+             - answer: Original answer that was judged
+             - judge_question: The formatted prompt sent to the judge
+             - judge_answer: Expected rating (float) computed from logprobs
+             - judge_raw_answer: Raw logprobs dict {token: probability}
+         """
+         rows = self._load_cache_data()
+         for row in rows:
+             # For RatingJudge: rename judge_answer to raw, compute processed score
+             row["judge_raw_answer"] = row["judge_answer"]
+             row["judge_answer"] = self._compute_expected_rating(row["judge_raw_answer"])
+         return pd.DataFrame(rows)
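A hedged sketch of constructing the two judge types. The module path and the exact set of required base-class keyword arguments are not visible in this diff, so the import line and the minimal argument set below are assumptions; the docstrings above only state that model and a single-element paraphrases list containing {answer} are needed.

from llmcomp.question.judge import FreeFormJudge, RatingJudge  # module path assumed

# Assumption: the FreeForm/Rating base classes accept these keyword arguments as
# described in the docstrings; they may require more (e.g. a question id) than shown.
correctness_judge = FreeFormJudge(
    model="gpt-4o",
    paraphrases=["Is this answer correct? {answer}"],
)

quality_judge = RatingJudge(
    model="gpt-4o",
    paraphrases=["Rate this answer 0-100: {answer}"],
    min_rating=0,
    max_rating=100,
)

# Cached evaluations come back as DataFrames with question/answer/judge_* columns.
cached = quality_judge.get_cache()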
@@ -0,0 +1,283 @@
+ import matplotlib.pyplot as plt
+ import pandas as pd
+
+
+ def default_title(paraphrases: list[str] | None) -> str | None:
+     """Generate default plot title from paraphrases."""
+     if paraphrases is None:
+         return None
+     if len(paraphrases) == 1:
+         return paraphrases[0]
+     return paraphrases[0] + f"\nand {len(paraphrases) - 1} other paraphrases"
+
+
+ def rating_cumulative_plot(
+     df: pd.DataFrame,
+     min_rating: int,
+     max_rating: int,
+     probs_column: str = "probs",
+     category_column: str = "group",
+     model_groups: dict[str, list[str]] = None,
+     show_mean: bool = True,
+     title: str = None,
+     filename: str = None,
+ ):
+     """Plot cumulative rating distribution by category.
+
+     Shows fraction of responses with rating <= X for each X.
+     Starts near 0 at min_rating, reaches 100% at max_rating.
+
+     Args:
+         df: DataFrame with probs_column containing normalized probability dicts
+             mapping int ratings to probabilities (summing to 1), or None for invalid.
+         min_rating: Minimum rating value.
+         max_rating: Maximum rating value.
+         probs_column: Column containing {rating: prob} dicts. Default: "probs"
+         category_column: Column to group by. Default: "group"
+         model_groups: Optional dict for ordering groups.
+         show_mean: Whether to show mean in legend labels. Default: True
+         title: Optional plot title.
+         filename: Optional filename to save plot.
+     """
+     # Get unique categories in order
+     categories = df[category_column].unique()
+     if category_column == "group" and model_groups is not None:
+         categories = [c for c in model_groups.keys() if c in categories]
+
+     fig, ax = plt.subplots(figsize=(10, 6))
+     x_values = list(range(min_rating, max_rating + 1))
+
+     for category in categories:
+         category_df = df[df[category_column] == category]
+
+         # Accumulate normalized probabilities and means across all rows
+         cumulative = {x: 0.0 for x in x_values}
+         mean_sum = 0.0
+         n_valid = 0
+
+         for probs in category_df[probs_column]:
+             if probs is None:
+                 continue
+
+             # For each x, add P(score <= x) = sum of probs for ratings <= x
+             for x in x_values:
+                 cumulative[x] += sum(p for rating, p in probs.items() if rating <= x)
+
+             # Compute mean for this row
+             mean_sum += sum(rating * p for rating, p in probs.items())
+             n_valid += 1
+
+         if n_valid > 0:
+             y_values = [cumulative[x] / n_valid for x in x_values]
+             mean_value = mean_sum / n_valid
+
+             if show_mean:
+                 label = f"{category} (mean: {mean_value:.1f})"
+             else:
+                 label = category
+             ax.plot(x_values, y_values, label=label)
+
+     ax.set_xlabel("Rating")
+     ax.set_ylabel("Fraction with score ≤ X")
+     ax.set_xlim(min_rating, max_rating)
+     ax.set_ylim(0, 1)
+     ax.legend()
+
+     if title is not None:
+         ax.set_title(title)
+
+     plt.tight_layout()
+     if filename is not None:
+         plt.savefig(filename, bbox_inches="tight")
+     plt.show()
+
+
+ def probs_stacked_bar(
+     df: pd.DataFrame,
+     probs_column: str = "probs",
+     category_column: str = "group",
+     model_groups: dict[str, list[str]] = None,
+     selected_answers: list[str] = None,
+     min_fraction: float = None,
+     colors: dict[str, str] = None,
+     title: str = None,
+     filename: str = None,
+ ):
+     """
+     Plot a stacked bar chart from probability distributions.
+
+     Args:
+         df: DataFrame with one row per category, containing probs_column with
+             {answer: probability} dicts.
+         probs_column: Column containing probability dicts. Default: "probs"
+         category_column: Column to group by (x-axis). Default: "group"
+         model_groups: Optional dict for ordering groups.
+         selected_answers: Optional list of answers to show. Others grouped as "[OTHER]".
+         min_fraction: Optional minimum fraction threshold.
+         colors: Optional dict mapping answer values to colors.
+         title: Optional plot title.
+         filename: Optional filename to save plot.
+     """
+     if min_fraction is not None and selected_answers is not None:
+         raise ValueError("min_fraction and selected_answers cannot both be set")
+
+     # Aggregate probs across rows for each category
+     category_probs = {}
+     for category in df[category_column].unique():
+         cat_df = df[df[category_column] == category]
+         combined = {}
+         n_rows = 0
+         for probs in cat_df[probs_column]:
+             if probs is None:
+                 continue
+             for answer, prob in probs.items():
+                 combined[answer] = combined.get(answer, 0) + prob
+             n_rows += 1
+         if n_rows > 0:
+             category_probs[category] = {k: v / n_rows for k, v in combined.items()}
+
+     if not category_probs:
+         return
+
+     # Find answers meeting min_fraction threshold
+     if min_fraction is not None:
+         selected_answers_set = set()
+         for probs in category_probs.values():
+             for answer, prob in probs.items():
+                 if prob >= min_fraction:
+                     selected_answers_set.add(answer)
+         selected_answers = list(selected_answers_set)
+
+     # Group non-selected answers into "[OTHER]"
+     if selected_answers is not None:
+         for category in category_probs:
+             probs = category_probs[category]
+             other_prob = sum(p for a, p in probs.items() if a not in selected_answers)
+             category_probs[category] = {a: p for a, p in probs.items() if a in selected_answers}
+             if other_prob > 0:
+                 category_probs[category]["[OTHER]"] = other_prob
+
+     # Build percentages DataFrame
+     all_answers = set()
+     for probs in category_probs.values():
+         all_answers.update(probs.keys())
+
+     data = {cat: {a: probs.get(a, 0) * 100 for a in all_answers} for cat, probs in category_probs.items()}
+     answer_percentages = pd.DataFrame(data).T
+
+     # Color setup
+     if colors is None:
+         colors = {}
+     if "[OTHER]" in all_answers and "[OTHER]" not in colors:
+         colors["[OTHER]"] = "grey"
+
+     color_palette = [
+         "red",
+         "blue",
+         "green",
+         "orange",
+         "purple",
+         "brown",
+         "pink",
+         "olive",
+         "cyan",
+         "magenta",
+         "yellow",
+         "navy",
+         "lime",
+         "maroon",
+         "teal",
+         "silver",
+         "gold",
+         "indigo",
+         "coral",
+         "crimson",
+     ]
+
+     # Order answers
+     column_answers = list(answer_percentages.columns)
+     if selected_answers is not None:
+         ordered_answers = [a for a in selected_answers if a in column_answers]
+         extras = sorted([a for a in column_answers if a not in selected_answers])
+         ordered_answers += extras
+     elif colors:
+         ordered_answers = [a for a in colors.keys() if a in column_answers]
+         extras = sorted([a for a in column_answers if a not in ordered_answers])
+         ordered_answers += extras
+     else:
+         ordered_answers = sorted(column_answers)
+     answer_percentages = answer_percentages.reindex(columns=ordered_answers)
+
+     # Build colors list
+     plot_colors = []
+     color_index = 0
+     for answer in ordered_answers:
+         if answer in colors:
+             plot_colors.append(colors[answer])
+         elif answer == "[OTHER]":
+             plot_colors.append("grey")
+         else:
+             plot_colors.append(color_palette[color_index % len(color_palette)])
+             color_index += 1
+
+     # Order categories
+     if category_column == "group" and model_groups is not None:
+         ordered_groups = [g for g in model_groups.keys() if g in answer_percentages.index]
+         ordered_groups += [g for g in answer_percentages.index if g not in ordered_groups]
+         answer_percentages = answer_percentages.reindex(ordered_groups)
+
+     fig, ax = plt.subplots(figsize=(12, 8))
+     answer_percentages.plot(kind="bar", stacked=True, ax=ax, color=plot_colors)
+
+     plt.xlabel(category_column)
+     plt.ylabel("Percentage")
+     plt.legend(title="answer")
+     plt.xticks(rotation=45, ha="right")
+
+     if title is not None:
+         plt.title(title)
+
+     plt.tight_layout()
+     if filename is not None:
+         plt.savefig(filename, bbox_inches="tight")
+     plt.show()
+
+
+ def free_form_stacked_bar(
+     df: pd.DataFrame,
+     category_column: str = "group",
+     answer_column: str = "answer",
+     model_groups: dict[str, list[str]] = None,
+     selected_answers: list[str] = None,
+     min_fraction: float = None,
+     colors: dict[str, str] = None,
+     title: str = None,
+     filename: str = None,
+ ):
+     """
+     Plot a stacked bar chart showing the distribution of answers by category.
+
+     Transforms FreeForm data (multiple rows with single answers) into probability
+     distributions and calls probs_stacked_bar.
+     """
+     # Transform to probs format: one row per category with {answer: prob} dict
+     probs_data = []
+     for category in df[category_column].unique():
+         cat_df = df[df[category_column] == category]
+         counts = cat_df[answer_column].value_counts()
+         probs = (counts / counts.sum()).to_dict()
+         probs_data.append({category_column: category, "probs": probs})
+
+     probs_df = pd.DataFrame(probs_data)
+
+     return probs_stacked_bar(
+         probs_df,
+         probs_column="probs",
+         category_column=category_column,
+         model_groups=model_groups,
+         selected_answers=selected_answers,
+         min_fraction=min_fraction,
+         colors=colors,
+         title=title,
+         filename=filename,
+     )
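A small usage sketch for free_form_stacked_bar. It relies only on the signature shown above (one row per sampled answer, grouped by category_column); the group and answer labels are toy values.

import pandas as pd

# Toy FreeForm-style results: one row per sampled answer.
df = pd.DataFrame(
    {
        "group": ["model-a", "model-a", "model-a", "model-b", "model-b"],
        "answer": ["yes", "yes", "no", "no", "no"],
    }
)

# Per-group counts are normalized to probabilities and forwarded to probs_stacked_bar.
free_form_stacked_bar(
    df,
    category_column="group",
    answer_column="answer",
    colors={"yes": "green", "no": "red"},
    title="Answer distribution by group",
)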