llmcomp 1.2.4__py3-none-any.whl → 1.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,488 @@
1
+ """DataFrame viewer for browsing question results.
2
+
3
+ Spawns a local Streamlit server to interactively browse (api_kwargs, answer) pairs.
4
+
5
+ Usage:
6
+ from llmcomp import Question
7
+
8
+ question = Question.create(...)
9
+ df = question.df(models)
10
+ Question.view(df)
11
+ """
12
+
13
import json
import os
import shutil
import subprocess
import sys
import tempfile
import webbrowser
from pathlib import Path
from typing import Any
21
+
22
+ # Streamlit imports are inside functions to avoid import errors when streamlit isn't installed
23
+
24
+
25
+ def render_dataframe(
26
+ df: "pd.DataFrame",
27
+ sort_by: str | None = "__random__",
28
+ sort_ascending: bool = True,
29
+ open_browser: bool = True,
30
+ port: int = 8501,
31
+ ) -> None:
32
+ """Launch a Streamlit viewer for the DataFrame.
33
+
34
+ Args:
35
+ df: DataFrame with at least 'api_kwargs' and 'answer' columns.
36
+ Other columns (model, group, etc.) are displayed as metadata.
37
+ sort_by: Column name to sort by initially. Default: "__random__" for random
38
+ shuffling (new seed on each refresh). Use None for original order.
39
+ sort_ascending: Sort order. Default: True (ascending).
40
+ open_browser: If True, automatically open the viewer in default browser.
41
+ port: Port to run the Streamlit server on.
42
+
43
+ Raises:
44
+ ValueError: If required columns are missing.
45
+ """
46
+ # Validate required columns
47
+ if "api_kwargs" not in df.columns:
48
+ raise ValueError("DataFrame must have an 'api_kwargs' column")
49
+ if "answer" not in df.columns:
50
+ raise ValueError("DataFrame must have an 'answer' column")
51
+ if sort_by is not None and sort_by != "__random__" and sort_by not in df.columns:
52
+ raise ValueError(f"sort_by column '{sort_by}' not found in DataFrame")
53
+
54
+ # Save DataFrame to a temp file
55
+ temp_dir = tempfile.mkdtemp(prefix="llmcomp_viewer_")
56
+ temp_path = os.path.join(temp_dir, "data.jsonl")
57
+
58
+ # Convert DataFrame to JSONL
59
+ with open(temp_path, "w", encoding="utf-8") as f:
60
+ for _, row in df.iterrows():
61
+ row_dict = row.to_dict()
62
+ f.write(json.dumps(row_dict, default=str) + "\n")
63
+
64
+ url = f"http://localhost:{port}"
65
+ print(f"Starting viewer at {url}")
66
+ print(f"Data file: {temp_path}")
67
+ print("Press Ctrl+C to stop the server.\n")
68
+
69
+ if open_browser:
70
+ # Open browser after a short delay to let server start
71
+ import threading
72
+ threading.Timer(0.5, lambda: webbrowser.open(url)).start()
73
+
74
+ # Launch Streamlit
75
+ viewer_path = Path(__file__).resolve()
76
+ cmd = [
77
+ sys.executable, "-m", "streamlit", "run",
78
+ str(viewer_path),
79
+ "--server.port", str(port),
80
+ "--server.headless", "true",
81
+ "--", # Separator for script args
82
+ temp_path,
83
+ sort_by or "", # Empty string means no sorting
84
+ "asc" if sort_ascending else "desc",
85
+ ]
86
+
87
+ try:
88
+ subprocess.run(cmd, check=True)
89
+ except KeyboardInterrupt:
90
+ print("\nViewer stopped.")
91
+ finally:
92
+ # Clean up temp file
93
+ try:
94
+ os.remove(temp_path)
95
+ os.rmdir(temp_dir)
96
+ except OSError:
97
+ pass
98
+
99
+
100
+ # =============================================================================
101
+ # Streamlit App (runs when this file is executed by streamlit)
102
+ # =============================================================================
103
+
104
+ def _get_data_path() -> str | None:
105
+ """Get data file path from command line args."""
106
+ # Args after -- are passed to the script
107
+ if len(sys.argv) > 1:
108
+ return sys.argv[1]
109
+ return None
110
+
111
+
112
+ def _get_initial_sort() -> tuple[str | None, bool]:
113
+ """Get initial sort settings from command line args."""
114
+ sort_by = None
115
+ sort_ascending = True
116
+
117
+ if len(sys.argv) > 2:
118
+ sort_by = sys.argv[2] if sys.argv[2] else None
119
+ if len(sys.argv) > 3:
120
+ sort_ascending = sys.argv[3] != "desc"
121
+
122
+ return sort_by, sort_ascending
123
+
124
+
125
+ def _read_jsonl(path: str) -> list[dict[str, Any]]:
126
+ """Read JSONL file into a list of dicts."""
127
+ items = []
128
+ with open(path, "r", encoding="utf-8") as f:
129
+ for line in f:
130
+ line = line.strip()
131
+ if line:
132
+ items.append(json.loads(line))
133
+ return items
134
+
135
+
136
def _display_messages(messages: list[dict[str, str]]) -> None:
    """Render a list of chat messages with Streamlit's chat bubbles."""
    import streamlit as st

    for message in messages:
        role = message.get("role", "user")
        content = message.get("content", "")

        # Map the message role onto a Streamlit chat role + optional avatar.
        if role == "system":
            chat_role, avatar = "assistant", "⚙️"
        elif role == "assistant":
            chat_role, avatar = "assistant", None
        else:
            # user or any unknown role renders as a user bubble
            chat_role, avatar = "user", None

        with st.chat_message(chat_role, avatar=avatar):
            if role == "system":
                # Label system prompts explicitly so they aren't mistaken
                # for model output.
                st.markdown("**System**")
            st.text(content)
155
+
156
+
157
def _display_answer(answer: Any, label: str | None = None) -> None:
    """Render an answer value, adapting the layout to its type."""
    import streamlit as st

    if label:
        st.markdown(f"**{label}**")

    if isinstance(answer, dict):
        # NextToken-style answers map token -> probability; list the most
        # likely tokens first. Non-numeric values sort as if zero.
        def rank(pair):
            value = pair[1]
            return -value if isinstance(value, (int, float)) else 0

        ranked = sorted(answer.items(), key=rank)
        for token, prob in ranked[:20]:  # show at most the top 20
            shown = f"{prob:.4f}" if isinstance(prob, float) else f"{prob}"
            st.text(f"  {token!r}: {shown}")
    elif isinstance(answer, str):
        st.text(answer)
    else:
        st.text(str(answer))
178
+
179
+
180
def _display_metadata(row: dict[str, Any], exclude_keys: set[str]) -> None:
    """Show all non-excluded columns of *row* inside a collapsed expander."""
    import streamlit as st

    remaining = {key: row[key] for key in row if key not in exclude_keys}
    if not remaining:
        return

    with st.expander("Metadata", expanded=False):
        for key, value in remaining.items():
            if not isinstance(value, (dict, list)):
                st.markdown(f"**{key}:** {value}")
                continue
            st.markdown(f"**{key}:**")
            # Collapse bulky raw-answer / probability dumps by default.
            is_bulky = key.endswith(("_raw_answer", "_probs"))
            st.json(value, expanded=not is_bulky)
195
+
196
+
197
+ def _search_items(items: list[dict[str, Any]], query: str) -> list[dict[str, Any]]:
198
+ """Filter items by search query.
199
+
200
+ Supports:
201
+ - Regular search: "foo" - includes items containing "foo"
202
+ - Negative search: "-foo" - excludes items containing "foo"
203
+ - Combined: "foo -bar" - items with "foo" but not "bar"
204
+ """
205
+ if not query:
206
+ return items
207
+
208
+ # Parse query into positive and negative terms
209
+ terms = query.split()
210
+ positive_terms = []
211
+ negative_terms = []
212
+
213
+ for term in terms:
214
+ if term.startswith("-") and len(term) > 1:
215
+ negative_terms.append(term[1:].lower())
216
+ else:
217
+ positive_terms.append(term.lower())
218
+
219
+ results = []
220
+
221
+ for item in items:
222
+ # Build searchable text from item
223
+ api_kwargs = item.get("api_kwargs", {})
224
+ messages = api_kwargs.get("messages", []) if isinstance(api_kwargs, dict) else []
225
+ messages_text = " ".join(m.get("content", "") for m in messages)
226
+
227
+ answer = item.get("answer", "")
228
+ answer_text = str(answer) if not isinstance(answer, str) else answer
229
+
230
+ all_text = messages_text + " " + answer_text
231
+ all_text += " " + " ".join(str(v) for v in item.values() if isinstance(v, str))
232
+ all_text_lower = all_text.lower()
233
+
234
+ # Check positive terms (all must match)
235
+ if positive_terms and not all(term in all_text_lower for term in positive_terms):
236
+ continue
237
+
238
+ # Check negative terms (none must match)
239
+ if any(term in all_text_lower for term in negative_terms):
240
+ continue
241
+
242
+ results.append(item)
243
+
244
+ return results
245
+
246
+
247
def _streamlit_main():
    """Main Streamlit app.

    Loads the JSONL file named on the command line (written by
    ``render_dataframe``) and renders an interactive browser with search,
    primary/secondary sorting, optional random shuffling, and one-row-at-a-time
    navigation. UI state (sort settings, current row index, shuffle seed)
    is kept in ``st.session_state`` so it survives Streamlit reruns.
    """
    import streamlit as st

    st.set_page_config(
        page_title="llmcomp Viewer",
        page_icon="🔬",
        layout="wide",
    )

    st.title("🔬 llmcomp Viewer")

    # Get data path (sys.argv[1], forwarded after `--` by render_dataframe)
    data_path = _get_data_path()
    if data_path is None or not os.path.exists(data_path):
        st.error("No data file provided or file not found.")
        st.info("Use `Question.render(df)` to launch the viewer with data.")
        return

    # Load data once per path and cache in session state so reruns don't
    # re-read the file.
    cache_key = f"llmcomp_data_{data_path}"
    if cache_key not in st.session_state:
        st.session_state[cache_key] = _read_jsonl(data_path)

    items = st.session_state[cache_key]

    if not items:
        st.warning("No data to display.")
        return

    # Get sortable columns (numeric or string, exclude complex types).
    # Column detection uses the first item only — assumes rows share a schema.
    sortable_columns = ["(random)", "(none)"]
    if items:
        for key, value in items[0].items():
            if key not in ("api_kwargs",) and isinstance(value, (int, float, str, type(None))):
                sortable_columns.append(key)

    # Initialize sort settings from command line args (first run only)
    initial_sort_by, initial_sort_asc = _get_initial_sort()
    if "sort_by" not in st.session_state:
        # Map __random__ from CLI to (random) in UI
        if initial_sort_by == "__random__":
            st.session_state.sort_by = "(random)"
        elif initial_sort_by in sortable_columns:
            st.session_state.sort_by = initial_sort_by
        else:
            st.session_state.sort_by = "(none)"
        st.session_state.sort_ascending = initial_sort_asc

    # Initialize view index (which row is currently displayed)
    if "view_idx" not in st.session_state:
        st.session_state.view_idx = 0

    # Initialize secondary sort
    if "sort_by_2" not in st.session_state:
        st.session_state.sort_by_2 = "(none)"
        st.session_state.sort_ascending_2 = True

    # Search and sort controls
    col_search, col_sort, col_order = st.columns([3, 2, 1])

    with col_search:
        query = st.text_input("🔍 Search", placeholder="Filter... (use -term to exclude)")

    with col_sort:
        sort_by = st.selectbox(
            "Sort by",
            options=sortable_columns,
            index=sortable_columns.index(st.session_state.sort_by) if st.session_state.sort_by in sortable_columns else 0,
            key="sort_by_select",
        )
        if sort_by != st.session_state.sort_by:
            st.session_state.sort_by = sort_by
            st.session_state.view_idx = 0  # Reset to first item when sort changes

    with col_order:
        st.markdown("<br>", unsafe_allow_html=True)  # Align checkbox with selectbox
        sort_ascending = st.checkbox("Asc", value=st.session_state.sort_ascending, key="sort_asc_check")
        if sort_ascending != st.session_state.sort_ascending:
            st.session_state.sort_ascending = sort_ascending
            st.session_state.view_idx = 0

    # Reshuffle button for random sort: draws a new seed and reruns
    if st.session_state.sort_by == "(random)":
        import random
        col_reshuffle, _ = st.columns([1, 5])
        with col_reshuffle:
            if st.button("🔀 Reshuffle"):
                st.session_state.random_seed = random.randint(0, 2**32 - 1)
                st.session_state.view_idx = 0
                st.rerun()

    # Secondary sort (only show if primary sort is selected)
    if st.session_state.sort_by and st.session_state.sort_by != "(none)":
        col_spacer, col_sort2, col_order2 = st.columns([3, 2, 1])
        with col_sort2:
            sort_by_2 = st.selectbox(
                "Then by",
                options=sortable_columns,
                index=sortable_columns.index(st.session_state.sort_by_2) if st.session_state.sort_by_2 in sortable_columns else 0,
                key="sort_by_select_2",
            )
            if sort_by_2 != st.session_state.sort_by_2:
                st.session_state.sort_by_2 = sort_by_2
                st.session_state.view_idx = 0
        with col_order2:
            st.markdown("<br>", unsafe_allow_html=True)  # Align checkbox with selectbox
            sort_ascending_2 = st.checkbox("Asc", value=st.session_state.sort_ascending_2, key="sort_asc_check_2")
            if sort_ascending_2 != st.session_state.sort_ascending_2:
                st.session_state.sort_ascending_2 = sort_ascending_2
                st.session_state.view_idx = 0

    # Apply search
    filtered_items = _search_items(items, query)

    # Apply random shuffle if selected. The seed is stored in session state,
    # so the order is stable across reruns until Reshuffle is pressed.
    if st.session_state.sort_by == "(random)" and filtered_items:
        import random
        # Generate a new seed on first load or when explicitly reshuffled
        if "random_seed" not in st.session_state:
            st.session_state.random_seed = random.randint(0, 2**32 - 1)
        rng = random.Random(st.session_state.random_seed)
        filtered_items = filtered_items.copy()
        rng.shuffle(filtered_items)

    # Apply sorting (stable sort - secondary first, then primary).
    # Missing values (None) always sort last via the (is None, value) key.
    if st.session_state.sort_by and st.session_state.sort_by not in ("(none)", "(random)") and filtered_items:
        sort_key_2 = st.session_state.sort_by_2 if st.session_state.sort_by_2 != "(none)" else None

        # Secondary sort first (stable sort preserves this ordering within primary groups)
        if sort_key_2:
            filtered_items = sorted(
                filtered_items,
                key=lambda x: (x.get(sort_key_2) is None, x.get(sort_key_2)),
                reverse=not st.session_state.sort_ascending_2,
            )

        # Primary sort
        sort_key = st.session_state.sort_by
        filtered_items = sorted(
            filtered_items,
            key=lambda x: (x.get(sort_key) is None, x.get(sort_key)),
            reverse=not st.session_state.sort_ascending,
        )

    if not filtered_items:
        st.warning(f"No results found for '{query}'")
        return

    # Clamp view index to valid range (filtering can shrink the list)
    max_idx = len(filtered_items) - 1
    st.session_state.view_idx = max(0, min(st.session_state.view_idx, max_idx))

    # Navigation
    col1, col2, col3, col4 = st.columns([1, 1, 2, 2])

    with col1:
        if st.button("⬅️ Prev", use_container_width=True):
            st.session_state.view_idx = max(0, st.session_state.view_idx - 1)
            st.rerun()

    with col2:
        if st.button("Next ➡️", use_container_width=True):
            st.session_state.view_idx = min(max_idx, st.session_state.view_idx + 1)
            st.rerun()

    with col3:
        # Jump to specific index (1-based in the UI, 0-based internally)
        new_idx = st.number_input(
            "Go to",
            min_value=1,
            max_value=len(filtered_items),
            value=st.session_state.view_idx + 1,
            step=1,
            label_visibility="collapsed",
        )
        if new_idx - 1 != st.session_state.view_idx:
            st.session_state.view_idx = new_idx - 1
            st.rerun()

    with col4:
        st.markdown(f"**{st.session_state.view_idx + 1}** of **{len(filtered_items)}**")
        if query:
            st.caption(f"({len(items)} total)")

    st.divider()

    # Display current item
    current = filtered_items[st.session_state.view_idx]

    # Main content in two columns: prompt on the left, response on the right
    left_col, right_col = st.columns([1, 2])

    with left_col:
        st.subheader("💬 Messages")
        api_kwargs = current.get("api_kwargs", {})
        messages = api_kwargs.get("messages", []) if isinstance(api_kwargs, dict) else []
        if messages:
            _display_messages(messages)
        else:
            st.info("No messages")

    with right_col:
        model_name = current.get("model", "Response")
        st.subheader(f"🤖 {model_name}")
        answer = current.get("answer")
        if answer is not None:
            _display_answer(answer, label=None)
        else:
            st.info("No answer")

    # Display judge columns if present: anything that isn't a known base
    # column, private (_-prefixed), or a judge's auxiliary *_question /
    # *_raw_answer / *_probs column.
    judge_columns = [k for k in current.keys() if not k.startswith("_") and k not in {
        "api_kwargs", "answer", "question", "model", "group", "paraphrase_ix", "raw_answer"
    } and not k.endswith("_question") and not k.endswith("_raw_answer") and not k.endswith("_probs")]

    if judge_columns:
        st.markdown("---")
        for judge_col in judge_columns:
            value = current[judge_col]
            if isinstance(value, float):
                st.markdown(f"**{judge_col}:** {value:.2f}")
            else:
                st.markdown(f"**{judge_col}:** {value}")

    # Metadata at the bottom
    st.divider()
    # Show api_kwargs in metadata, but without messages (already displayed above)
    current_for_metadata = current.copy()
    if "api_kwargs" in current_for_metadata and isinstance(current_for_metadata["api_kwargs"], dict):
        api_kwargs_without_messages = {k: v for k, v in current_for_metadata["api_kwargs"].items() if k != "messages"}
        current_for_metadata["api_kwargs"] = api_kwargs_without_messages
    exclude_keys = {"answer", "question", "paraphrase_ix"} | set(judge_columns)
    _display_metadata(current_for_metadata, exclude_keys)

    # Keyboard navigation hint
    st.caption("💡 Tip: Use the navigation buttons or enter a number to jump to a specific row.")
484
+
485
+
486
# Entry point when this file is executed as a Streamlit app:
# `streamlit run <this file> -- <data_path> <sort_by> <asc|desc>`
# (see the `cmd` built in render_dataframe).
if __name__ == "__main__":
    _streamlit_main()
llmcomp/runner/runner.py CHANGED
@@ -51,12 +51,15 @@ class Runner:
51
51
  prepared = ModelAdapter.prepare(params, self.model)
52
52
  return {"timeout": Config.timeout, **prepared}
53
53
 
54
- def get_text(self, params: dict) -> str:
54
+ def get_text(self, params: dict) -> tuple[str, dict]:
55
55
  """Get a text completion from the model.
56
56
 
57
57
  Args:
58
58
  params: Dictionary of parameters for the API.
59
59
  Must include 'messages'. Other common keys: 'temperature', 'max_tokens'.
60
+
61
+ Returns:
62
+ Tuple of (content, prepared_kwargs) where prepared_kwargs is what was sent to the API.
60
63
  """
61
64
  prepared = self._prepare_for_model(params)
62
65
  completion = openai_chat_completion(client=self.client, **prepared)
@@ -72,8 +75,8 @@ class Runner:
72
75
  # refusal="I'm sorry, I'm unable to fulfill that request.",
73
76
  # ...))])
74
77
  warnings.warn(f"API sent None as content. Returning empty string.\n{completion}", stacklevel=2)
75
- return ""
76
- return content
78
+ return "", prepared
79
+ return content, prepared
77
80
  except Exception:
78
81
  warnings.warn(f"Unexpected error.\n{completion}")
79
82
  raise
@@ -84,7 +87,7 @@ class Runner:
84
87
  *,
85
88
  num_samples: int = 1,
86
89
  convert_to_probs: bool = True,
87
- ) -> dict:
90
+ ) -> tuple[dict, dict]:
88
91
  """Get probability distribution of the next token, optionally averaged over multiple samples.
89
92
 
90
93
  Args:
@@ -92,22 +95,26 @@ class Runner:
92
95
  Must include 'messages'. Other common keys: 'top_logprobs', 'logit_bias'.
93
96
  num_samples: Number of samples to average over. Default: 1.
94
97
  convert_to_probs: If True, convert logprobs to probabilities. Default: True.
98
+
99
+ Returns:
100
+ Tuple of (probs_dict, prepared_kwargs) where prepared_kwargs is what was sent to the API.
95
101
  """
96
102
  probs = {}
103
+ prepared = None
97
104
  for _ in range(num_samples):
98
- new_probs = self.single_token_probs_one_sample(params, convert_to_probs=convert_to_probs)
105
+ new_probs, prepared = self.single_token_probs_one_sample(params, convert_to_probs=convert_to_probs)
99
106
  for key, value in new_probs.items():
100
107
  probs[key] = probs.get(key, 0) + value
101
108
  result = {key: value / num_samples for key, value in probs.items()}
102
109
  result = dict(sorted(result.items(), key=lambda x: x[1], reverse=True))
103
- return result
110
+ return result, prepared
104
111
 
105
112
  def single_token_probs_one_sample(
106
113
  self,
107
114
  params: dict,
108
115
  *,
109
116
  convert_to_probs: bool = True,
110
- ) -> dict:
117
+ ) -> tuple[dict, dict]:
111
118
  """Get probability distribution of the next token (single sample).
112
119
 
113
120
  Args:
@@ -115,6 +122,9 @@ class Runner:
115
122
  Must include 'messages'. Other common keys: 'top_logprobs', 'logit_bias'.
116
123
  convert_to_probs: If True, convert logprobs to probabilities. Default: True.
117
124
 
125
+ Returns:
126
+ Tuple of (probs_dict, prepared_kwargs) where prepared_kwargs is what was sent to the API.
127
+
118
128
  Note: This function forces max_tokens=1, temperature=0, logprobs=True.
119
129
  """
120
130
  # Build complete params with defaults and forced params
@@ -138,7 +148,7 @@ class Runner:
138
148
  except IndexError:
139
149
  # This should not happen according to the API docs. But it sometimes does.
140
150
  print(NO_LOGPROBS_WARNING.format(model=self.model, completion=completion))
141
- return {}
151
+ return {}, prepared
142
152
 
143
153
  # Check for duplicate tokens - this shouldn't happen with OpenAI but might with other providers
144
154
  tokens = [el.token for el in logprobs]
@@ -153,7 +163,7 @@ class Runner:
153
163
  for el in logprobs:
154
164
  result[el.token] = math.exp(el.logprob) if convert_to_probs else el.logprob
155
165
 
156
- return result
166
+ return result, prepared
157
167
 
158
168
  def get_many(
159
169
  self,
@@ -173,8 +183,8 @@ class Runner:
173
183
  {"params": {"messages": [{"role": "user", "content": "Hello"}]}},
174
184
  {"params": {"messages": [{"role": "user", "content": "Bye"}], "temperature": 0.7}},
175
185
  ]
176
- for in_, out in runner.get_many(runner.get_text, kwargs_list):
177
- print(in_, "->", out)
186
+ for in_, (out, prepared_kwargs) in runner.get_many(runner.get_text, kwargs_list):
187
+ print(in_, "->", out, prepared_kwargs)
178
188
 
179
189
  or
180
190
 
@@ -182,14 +192,14 @@ class Runner:
182
192
  {"params": {"messages": [{"role": "user", "content": "Hello"}]}},
183
193
  {"params": {"messages": [{"role": "user", "content": "Bye"}]}},
184
194
  ]
185
- for in_, out in runner.get_many(runner.single_token_probs, kwargs_list):
186
- print(in_, "->", out)
195
+ for in_, (out, prepared_kwargs) in runner.get_many(runner.single_token_probs, kwargs_list):
196
+ print(in_, "->", out, prepared_kwargs)
187
197
 
188
198
  (FUNC that is a different callable should also work)
189
199
 
190
200
  This function returns a generator that yields pairs (input, output),
191
- where input is an element from KWARGS_LIST and output is the thing returned by
192
- FUNC for this input.
201
+ where input is an element from KWARGS_LIST and output is the tuple (result, prepared_kwargs)
202
+ returned by FUNC. prepared_kwargs contains the actual parameters sent to the API.
193
203
 
194
204
  Dictionaries in KWARGS_LIST might include optional keys starting with underscore,
195
205
  they are just ignored, but they are returned in the first element of the pair, so that's useful
@@ -230,7 +240,7 @@ class Runner:
230
240
  f"Model: {self.model}, function: {func.__name__}{msg_info}. "
231
241
  f"Error: {type(e).__name__}: {e}"
232
242
  )
233
- result = None
243
+ result = (None, {})
234
244
  return kwargs, result
235
245
 
236
246
  futures = [executor.submit(get_data, kwargs) for kwargs in kwargs_list]
@@ -251,7 +261,7 @@ class Runner:
251
261
  params: dict,
252
262
  *,
253
263
  num_samples: int,
254
- ) -> dict:
264
+ ) -> tuple[dict, dict]:
255
265
  """Sample answers NUM_SAMPLES times. Returns probabilities of answers.
256
266
 
257
267
  Args:
@@ -259,6 +269,9 @@ class Runner:
259
269
  Must include 'messages'. Other common keys: 'max_tokens', 'temperature'.
260
270
  num_samples: Number of samples to collect.
261
271
 
272
+ Returns:
273
+ Tuple of (probs_dict, prepared_kwargs) where prepared_kwargs is what was sent to the API.
274
+
262
275
  Works only if the API supports `n` parameter.
263
276
 
264
277
  Usecases:
@@ -268,6 +281,7 @@ class Runner:
268
281
  for Runner.single_token_probs.
269
282
  """
270
283
  cnts = defaultdict(int)
284
+ prepared = None
271
285
  for i in range(((num_samples - 1) // 128) + 1):
272
286
  n = min(128, num_samples - i * 128)
273
287
  # Build complete params with forced param
@@ -285,4 +299,4 @@ class Runner:
285
299
  )
286
300
  result = {key: val / num_samples for key, val in cnts.items()}
287
301
  result = dict(sorted(result.items(), key=lambda x: x[1], reverse=True))
288
- return result
302
+ return result, prepared