llmcomp-1.2.4-py3-none-any.whl → llmcomp-1.3.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
llmcomp/question/viewer.py ADDED
@@ -0,0 +1,459 @@
+ """DataFrame viewer for browsing question results.
+
+ Spawns a local Streamlit server to interactively browse (api_kwargs, answer) pairs.
+
+ Usage:
+     from llmcomp import Question
+
+     question = Question.create(...)
+     df = question.df(models)
+     Question.view(df)
+ """
+
+ import json
+ import os
+ import subprocess
+ import sys
+ import tempfile
+ import webbrowser
+ from pathlib import Path
+ from typing import Any
+
+ # Streamlit imports are inside functions to avoid import errors when streamlit isn't installed
+
+
+ def render_dataframe(
+     df: "pd.DataFrame",
+     sort_by: str | None = None,
+     sort_ascending: bool = True,
+     open_browser: bool = True,
+     port: int = 8501,
+ ) -> None:
+     """Launch a Streamlit viewer for the DataFrame.
+
+     Args:
+         df: DataFrame with at least 'api_kwargs' and 'answer' columns.
+             Other columns (model, group, etc.) are displayed as metadata.
+         sort_by: Column name to sort by initially. If None, keeps original order.
+         sort_ascending: Sort order. Default: True (ascending).
+         open_browser: If True, automatically open the viewer in the default browser.
+         port: Port to run the Streamlit server on.
+
+     Raises:
+         ValueError: If required columns are missing.
+     """
+     # Validate required columns
+     if "api_kwargs" not in df.columns:
+         raise ValueError("DataFrame must have an 'api_kwargs' column")
+     if "answer" not in df.columns:
+         raise ValueError("DataFrame must have an 'answer' column")
+     if sort_by is not None and sort_by not in df.columns:
+         raise ValueError(f"sort_by column '{sort_by}' not found in DataFrame")
+
+     # Save DataFrame to a temp file
+     temp_dir = tempfile.mkdtemp(prefix="llmcomp_viewer_")
+     temp_path = os.path.join(temp_dir, "data.jsonl")
+
+     # Convert DataFrame to JSONL
+     with open(temp_path, "w", encoding="utf-8") as f:
+         for _, row in df.iterrows():
+             row_dict = row.to_dict()
+             f.write(json.dumps(row_dict, default=str) + "\n")
+
+     url = f"http://localhost:{port}"
+     print(f"Starting viewer at {url}")
+     print(f"Data file: {temp_path}")
+     print("Press Ctrl+C to stop the server.\n")
+
+     if open_browser:
+         # Open browser after a short delay to let server start
+         import threading
+         threading.Timer(1.5, lambda: webbrowser.open(url)).start()
+
+     # Launch Streamlit
+     viewer_path = Path(__file__).resolve()
+     cmd = [
+         sys.executable, "-m", "streamlit", "run",
+         str(viewer_path),
+         "--server.port", str(port),
+         "--server.headless", "true",
+         "--",  # Separator for script args
+         temp_path,
+         sort_by or "",  # Empty string means no sorting
+         "asc" if sort_ascending else "desc",
+     ]
+
+     try:
+         subprocess.run(cmd, check=True)
+     except KeyboardInterrupt:
+         print("\nViewer stopped.")
+     finally:
+         # Clean up temp file
+         try:
+             os.remove(temp_path)
+             os.rmdir(temp_dir)
+         except OSError:
+             pass
+
+
+ # =============================================================================
+ # Streamlit App (runs when this file is executed by streamlit)
+ # =============================================================================
+
+ def _get_data_path() -> str | None:
+     """Get data file path from command line args."""
+     # Args after -- are passed to the script
+     if len(sys.argv) > 1:
+         return sys.argv[1]
+     return None
+
+
+ def _get_initial_sort() -> tuple[str | None, bool]:
+     """Get initial sort settings from command line args."""
+     sort_by = None
+     sort_ascending = True
+
+     if len(sys.argv) > 2:
+         sort_by = sys.argv[2] if sys.argv[2] else None
+     if len(sys.argv) > 3:
+         sort_ascending = sys.argv[3] != "desc"
+
+     return sort_by, sort_ascending
+
+
+ def _read_jsonl(path: str) -> list[dict[str, Any]]:
+     """Read JSONL file into a list of dicts."""
+     items = []
+     with open(path, "r", encoding="utf-8") as f:
+         for line in f:
+             line = line.strip()
+             if line:
+                 items.append(json.loads(line))
+     return items
+
+
+ def _display_messages(messages: list[dict[str, str]]) -> None:
+     """Display a list of chat messages in Streamlit chat format."""
+     import streamlit as st
+
+     for msg in messages:
+         role = msg.get("role", "user")
+         content = msg.get("content", "")
+
+         # Map roles to streamlit chat_message roles
+         if role == "system":
+             with st.chat_message("assistant", avatar="⚙️"):
+                 st.markdown("**System**")
+                 st.text(content)
+         elif role == "assistant":
+             with st.chat_message("assistant"):
+                 st.text(content)
+         else:  # user or other
+             with st.chat_message("user"):
+                 st.text(content)
+
+
+ def _display_answer(answer: Any, label: str | None = None) -> None:
+     """Display the answer, handling different types."""
+     import streamlit as st
+
+     if label:
+         st.markdown(f"**{label}**")
+
+     if isinstance(answer, dict):
+         # For NextToken questions, answer is {token: probability}
+         # Sort by probability descending
+         sorted_items = sorted(answer.items(), key=lambda x: -x[1] if isinstance(x[1], (int, float)) else 0)
+         # Display as a table-like format
+         for token, prob in sorted_items[:20]:  # Show top 20
+             if isinstance(prob, float):
+                 st.text(f" {token!r}: {prob:.4f}")
+             else:
+                 st.text(f" {token!r}: {prob}")
+     elif isinstance(answer, str):
+         st.text(answer)
+     else:
+         st.text(str(answer))
+
+
+ def _display_metadata(row: dict[str, Any], exclude_keys: set[str]) -> None:
+     """Display metadata columns."""
+     import streamlit as st
+
+     metadata = {k: v for k, v in row.items() if k not in exclude_keys}
+     if metadata:
+         with st.expander("Metadata", expanded=False):
+             for key, value in metadata.items():
+                 if isinstance(value, (dict, list)):
+                     st.markdown(f"**{key}:**")
+                     st.json(value)
+                 else:
+                     st.markdown(f"**{key}:** {value}")
+
+
+ def _search_items(items: list[dict[str, Any]], query: str) -> list[dict[str, Any]]:
+     """Filter items by search query.
+
+     Supports:
+     - Regular search: "foo" - includes items containing "foo"
+     - Negative search: "-foo" - excludes items containing "foo"
+     - Combined: "foo -bar" - items with "foo" but not "bar"
+     """
+     if not query:
+         return items
+
+     # Parse query into positive and negative terms
+     terms = query.split()
+     positive_terms = []
+     negative_terms = []
+
+     for term in terms:
+         if term.startswith("-") and len(term) > 1:
+             negative_terms.append(term[1:].lower())
+         else:
+             positive_terms.append(term.lower())
+
+     results = []
+
+     for item in items:
+         # Build searchable text from item
+         api_kwargs = item.get("api_kwargs", {})
+         messages = api_kwargs.get("messages", []) if isinstance(api_kwargs, dict) else []
+         messages_text = " ".join(m.get("content", "") for m in messages)
+
+         answer = item.get("answer", "")
+         answer_text = str(answer) if not isinstance(answer, str) else answer
+
+         all_text = messages_text + " " + answer_text
+         all_text += " " + " ".join(str(v) for v in item.values() if isinstance(v, str))
+         all_text_lower = all_text.lower()
+
+         # Check positive terms (all must match)
+         if positive_terms and not all(term in all_text_lower for term in positive_terms):
+             continue
+
+         # Check negative terms (none must match)
+         if any(term in all_text_lower for term in negative_terms):
+             continue
+
+         results.append(item)
+
+     return results
+
+
+ def _streamlit_main():
+     """Main Streamlit app."""
+     import streamlit as st
+
+     st.set_page_config(
+         page_title="llmcomp Viewer",
+         page_icon="🔬",
+         layout="wide",
+     )
+
+     st.title("🔬 llmcomp Viewer")
+
+     # Get data path
+     data_path = _get_data_path()
+     if data_path is None or not os.path.exists(data_path):
+         st.error("No data file provided or file not found.")
+         st.info("Use `Question.view(df)` to launch the viewer with data.")
+         return
+
+     # Load data (cache in session state)
+     cache_key = f"llmcomp_data_{data_path}"
+     if cache_key not in st.session_state:
+         st.session_state[cache_key] = _read_jsonl(data_path)
+
+     items = st.session_state[cache_key]
+
+     if not items:
+         st.warning("No data to display.")
+         return
+
+     # Get sortable columns (numeric or string, exclude complex types)
+     sortable_columns = ["(none)"]
+     if items:
+         for key, value in items[0].items():
+             if key not in ("api_kwargs",) and isinstance(value, (int, float, str, type(None))):
+                 sortable_columns.append(key)
+
+     # Initialize sort settings from command line args
+     initial_sort_by, initial_sort_asc = _get_initial_sort()
+     if "sort_by" not in st.session_state:
+         st.session_state.sort_by = initial_sort_by if initial_sort_by in sortable_columns else "(none)"
+         st.session_state.sort_ascending = initial_sort_asc
+
+     # Initialize view index
+     if "view_idx" not in st.session_state:
+         st.session_state.view_idx = 0
+
+     # Initialize secondary sort
+     if "sort_by_2" not in st.session_state:
+         st.session_state.sort_by_2 = "(none)"
+         st.session_state.sort_ascending_2 = True
+
+     # Search and sort controls
+     col_search, col_sort, col_order = st.columns([3, 2, 1])
+
+     with col_search:
+         query = st.text_input("🔍 Search", placeholder="Filter... (use -term to exclude)")
+
+     with col_sort:
+         sort_by = st.selectbox(
+             "Sort by",
+             options=sortable_columns,
+             index=sortable_columns.index(st.session_state.sort_by) if st.session_state.sort_by in sortable_columns else 0,
+             key="sort_by_select",
+         )
+         if sort_by != st.session_state.sort_by:
+             st.session_state.sort_by = sort_by
+             st.session_state.view_idx = 0  # Reset to first item when sort changes
+
+     with col_order:
+         st.markdown("<br>", unsafe_allow_html=True)  # Align checkbox with selectbox
+         sort_ascending = st.checkbox("Asc", value=st.session_state.sort_ascending, key="sort_asc_check")
+         if sort_ascending != st.session_state.sort_ascending:
+             st.session_state.sort_ascending = sort_ascending
+             st.session_state.view_idx = 0
+
+     # Secondary sort (only show if primary sort is selected)
+     if st.session_state.sort_by and st.session_state.sort_by != "(none)":
+         col_spacer, col_sort2, col_order2 = st.columns([3, 2, 1])
+         with col_sort2:
+             sort_by_2 = st.selectbox(
+                 "Then by",
+                 options=sortable_columns,
+                 index=sortable_columns.index(st.session_state.sort_by_2) if st.session_state.sort_by_2 in sortable_columns else 0,
+                 key="sort_by_select_2",
+             )
+             if sort_by_2 != st.session_state.sort_by_2:
+                 st.session_state.sort_by_2 = sort_by_2
+                 st.session_state.view_idx = 0
+         with col_order2:
+             st.markdown("<br>", unsafe_allow_html=True)  # Align checkbox with selectbox
+             sort_ascending_2 = st.checkbox("Asc", value=st.session_state.sort_ascending_2, key="sort_asc_check_2")
+             if sort_ascending_2 != st.session_state.sort_ascending_2:
+                 st.session_state.sort_ascending_2 = sort_ascending_2
+                 st.session_state.view_idx = 0
+
+     # Apply search
+     filtered_items = _search_items(items, query)
+
+     # Apply sorting (stable sort - secondary first, then primary)
+     if st.session_state.sort_by and st.session_state.sort_by != "(none)" and filtered_items:
+         sort_key_2 = st.session_state.sort_by_2 if st.session_state.sort_by_2 != "(none)" else None
+
+         # Secondary sort first (stable sort preserves this ordering within primary groups)
+         if sort_key_2:
+             filtered_items = sorted(
+                 filtered_items,
+                 key=lambda x: (x.get(sort_key_2) is None, x.get(sort_key_2)),
+                 reverse=not st.session_state.sort_ascending_2,
+             )
+
+         # Primary sort
+         sort_key = st.session_state.sort_by
+         filtered_items = sorted(
+             filtered_items,
+             key=lambda x: (x.get(sort_key) is None, x.get(sort_key)),
+             reverse=not st.session_state.sort_ascending,
+         )
+
+     if not filtered_items:
+         st.warning(f"No results found for '{query}'")
+         return
+
+     # Clamp view index to valid range
+     max_idx = len(filtered_items) - 1
+     st.session_state.view_idx = max(0, min(st.session_state.view_idx, max_idx))
+
+     # Navigation
+     col1, col2, col3, col4 = st.columns([1, 1, 2, 2])
+
+     with col1:
+         if st.button("⬅️ Prev", use_container_width=True):
+             st.session_state.view_idx = max(0, st.session_state.view_idx - 1)
+             st.rerun()
+
+     with col2:
+         if st.button("Next ➡️", use_container_width=True):
+             st.session_state.view_idx = min(max_idx, st.session_state.view_idx + 1)
+             st.rerun()
+
+     with col3:
+         # Jump to specific index
+         new_idx = st.number_input(
+             "Go to",
+             min_value=1,
+             max_value=len(filtered_items),
+             value=st.session_state.view_idx + 1,
+             step=1,
+             label_visibility="collapsed",
+         )
+         if new_idx - 1 != st.session_state.view_idx:
+             st.session_state.view_idx = new_idx - 1
+             st.rerun()
+
+     with col4:
+         st.markdown(f"**{st.session_state.view_idx + 1}** of **{len(filtered_items)}**")
+         if query:
+             st.caption(f"({len(items)} total)")
+
+     st.divider()
+
+     # Display current item
+     current = filtered_items[st.session_state.view_idx]
+
+     # Main content in two columns
+     left_col, right_col = st.columns([1, 2])
+
+     with left_col:
+         st.subheader("💬 Messages")
+         api_kwargs = current.get("api_kwargs", {})
+         messages = api_kwargs.get("messages", []) if isinstance(api_kwargs, dict) else []
+         if messages:
+             _display_messages(messages)
+         else:
+             st.info("No messages")
+
+     with right_col:
+         model_name = current.get("model", "Response")
+         st.subheader(f"🤖 {model_name}")
+         answer = current.get("answer")
+         if answer is not None:
+             _display_answer(answer, label=None)
+         else:
+             st.info("No answer")
+
+     # Display judge columns if present
+     judge_columns = [k for k in current.keys() if not k.startswith("_") and k not in {
+         "api_kwargs", "answer", "question", "model", "group", "paraphrase_ix", "raw_answer"
+     } and not k.endswith("_question") and not k.endswith("_raw_answer")]
+
+     if judge_columns:
+         st.markdown("---")
+         for judge_col in judge_columns:
+             value = current[judge_col]
+             if isinstance(value, float):
+                 st.markdown(f"**{judge_col}:** {value:.2f}")
+             else:
+                 st.markdown(f"**{judge_col}:** {value}")
+
+     # Metadata at the bottom
+     st.divider()
+     # Show api_kwargs in metadata, but without messages (already displayed above)
+     current_for_metadata = current.copy()
+     if "api_kwargs" in current_for_metadata and isinstance(current_for_metadata["api_kwargs"], dict):
+         api_kwargs_without_messages = {k: v for k, v in current_for_metadata["api_kwargs"].items() if k != "messages"}
+         current_for_metadata["api_kwargs"] = api_kwargs_without_messages
+     exclude_keys = {"answer", "question", "paraphrase_ix"} | set(judge_columns)
+     _display_metadata(current_for_metadata, exclude_keys)
+
+     # Keyboard navigation hint
+     st.caption("💡 Tip: Use the navigation buttons or enter a number to jump to a specific row.")
+
+
+ # Entry point when run by Streamlit
+ if __name__ == "__main__":
+     _streamlit_main()
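
The file above ships as `llmcomp/question/viewer.py` (see the RECORD at the end of this diff). A minimal sketch of driving it directly; `render_dataframe` and its required columns come from the code above, while the toy DataFrame and model name here are purely illustrative:

```python
# Illustrative driver for the new viewer module. The two required columns
# ('api_kwargs' and 'answer') are validated by render_dataframe; extra
# columns such as 'model' show up as metadata and sort keys.
import pandas as pd

from llmcomp.question.viewer import render_dataframe

df = pd.DataFrame([
    {
        "api_kwargs": {"messages": [{"role": "user", "content": "Say hi"}]},
        "answer": "hi",
        "model": "hypothetical-model",  # placeholder name, not from this diff
    },
])

# Writes df to a temp JSONL file, spawns `python -m streamlit run viewer.py`
# in a subprocess, opens the browser, and blocks until Ctrl+C.
render_dataframe(df, sort_by="model", port=8501)
```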
llmcomp/runner/runner.py CHANGED
@@ -51,12 +51,15 @@ class Runner:
          prepared = ModelAdapter.prepare(params, self.model)
          return {"timeout": Config.timeout, **prepared}

-     def get_text(self, params: dict) -> str:
+     def get_text(self, params: dict) -> tuple[str, dict]:
          """Get a text completion from the model.

          Args:
              params: Dictionary of parameters for the API.
                  Must include 'messages'. Other common keys: 'temperature', 'max_tokens'.
+
+         Returns:
+             Tuple of (content, prepared_kwargs) where prepared_kwargs is what was sent to the API.
          """
          prepared = self._prepare_for_model(params)
          completion = openai_chat_completion(client=self.client, **prepared)
@@ -72,8 +75,8 @@
                  # refusal="I'm sorry, I'm unable to fulfill that request.",
                  # ...))])
                  warnings.warn(f"API sent None as content. Returning empty string.\n{completion}", stacklevel=2)
-                 return ""
-             return content
+                 return "", prepared
+             return content, prepared
          except Exception:
              warnings.warn(f"Unexpected error.\n{completion}")
              raise
@@ -84,7 +87,7 @@
          *,
          num_samples: int = 1,
          convert_to_probs: bool = True,
-     ) -> dict:
+     ) -> tuple[dict, dict]:
          """Get probability distribution of the next token, optionally averaged over multiple samples.

          Args:
@@ -92,22 +95,26 @@
                  Must include 'messages'. Other common keys: 'top_logprobs', 'logit_bias'.
              num_samples: Number of samples to average over. Default: 1.
              convert_to_probs: If True, convert logprobs to probabilities. Default: True.
+
+         Returns:
+             Tuple of (probs_dict, prepared_kwargs) where prepared_kwargs is what was sent to the API.
          """
          probs = {}
+         prepared = None
          for _ in range(num_samples):
-             new_probs = self.single_token_probs_one_sample(params, convert_to_probs=convert_to_probs)
+             new_probs, prepared = self.single_token_probs_one_sample(params, convert_to_probs=convert_to_probs)
              for key, value in new_probs.items():
                  probs[key] = probs.get(key, 0) + value
          result = {key: value / num_samples for key, value in probs.items()}
          result = dict(sorted(result.items(), key=lambda x: x[1], reverse=True))
-         return result
+         return result, prepared

      def single_token_probs_one_sample(
          self,
          params: dict,
          *,
          convert_to_probs: bool = True,
-     ) -> dict:
+     ) -> tuple[dict, dict]:
          """Get probability distribution of the next token (single sample).

          Args:
@@ -115,6 +122,9 @@
                  Must include 'messages'. Other common keys: 'top_logprobs', 'logit_bias'.
              convert_to_probs: If True, convert logprobs to probabilities. Default: True.

+         Returns:
+             Tuple of (probs_dict, prepared_kwargs) where prepared_kwargs is what was sent to the API.
+
          Note: This function forces max_tokens=1, temperature=0, logprobs=True.
          """
          # Build complete params with defaults and forced params
@@ -138,7 +148,7 @@
          except IndexError:
              # This should not happen according to the API docs. But it sometimes does.
              print(NO_LOGPROBS_WARNING.format(model=self.model, completion=completion))
-             return {}
+             return {}, prepared

          # Check for duplicate tokens - this shouldn't happen with OpenAI but might with other providers
          tokens = [el.token for el in logprobs]
@@ -153,7 +163,7 @@
          for el in logprobs:
              result[el.token] = math.exp(el.logprob) if convert_to_probs else el.logprob

-         return result
+         return result, prepared

      def get_many(
          self,
@@ -173,8 +183,8 @@
              {"params": {"messages": [{"role": "user", "content": "Hello"}]}},
              {"params": {"messages": [{"role": "user", "content": "Bye"}], "temperature": 0.7}},
          ]
-         for in_, out in runner.get_many(runner.get_text, kwargs_list):
-             print(in_, "->", out)
+         for in_, (out, prepared_kwargs) in runner.get_many(runner.get_text, kwargs_list):
+             print(in_, "->", out, prepared_kwargs)

          or

@@ -182,14 +192,14 @@
              {"params": {"messages": [{"role": "user", "content": "Hello"}]}},
              {"params": {"messages": [{"role": "user", "content": "Bye"}]}},
          ]
-         for in_, out in runner.get_many(runner.single_token_probs, kwargs_list):
-             print(in_, "->", out)
+         for in_, (out, prepared_kwargs) in runner.get_many(runner.single_token_probs, kwargs_list):
+             print(in_, "->", out, prepared_kwargs)

          (FUNC that is a different callable should also work)

          This function returns a generator that yields pairs (input, output),
-         where input is an element from KWARGS_LIST and output is the thing returned by
-         FUNC for this input.
+         where input is an element from KWARGS_LIST and output is the tuple (result, prepared_kwargs)
+         returned by FUNC. prepared_kwargs contains the actual parameters sent to the API.

          Dictionaries in KWARGS_LIST might include optional keys starting with underscore,
          they are just ignored, but they are returned in the first element of the pair, so that's useful
@@ -230,7 +240,7 @@
                      f"Model: {self.model}, function: {func.__name__}{msg_info}. "
                      f"Error: {type(e).__name__}: {e}"
                  )
-                 result = None
+                 result = (None, {})
              return kwargs, result

          futures = [executor.submit(get_data, kwargs) for kwargs in kwargs_list]
@@ -251,7 +261,7 @@
          params: dict,
          *,
          num_samples: int,
-     ) -> dict:
+     ) -> tuple[dict, dict]:
          """Sample answers NUM_SAMPLES times. Returns probabilities of answers.

          Args:
@@ -259,6 +269,9 @@
                  Must include 'messages'. Other common keys: 'max_tokens', 'temperature'.
              num_samples: Number of samples to collect.

+         Returns:
+             Tuple of (probs_dict, prepared_kwargs) where prepared_kwargs is what was sent to the API.
+
          Works only if the API supports `n` parameter.

          Usecases:
@@ -268,6 +281,7 @@
              for Runner.single_token_probs.
          """
          cnts = defaultdict(int)
+         prepared = None
          for i in range(((num_samples - 1) // 128) + 1):
              n = min(128, num_samples - i * 128)
              # Build complete params with forced param
@@ -285,4 +299,4 @@
          )
          result = {key: val / num_samples for key, val in cnts.items()}
          result = dict(sorted(result.items(), key=lambda x: x[1], reverse=True))
-         return result
+         return result, prepared
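
The common thread in the hunks above: every `Runner` accessor now returns a `(result, prepared_kwargs)` tuple instead of the bare result, and `get_many` propagates the pair (yielding `(None, {})` as the result on error). A sketch of migrating a 1.2.x call site; the `Runner` constructor is not part of this diff, so its arguments below are a guess:

```python
from llmcomp.runner.runner import Runner  # module path per the RECORD below

runner = Runner("some-model")  # constructor args are assumed, not shown in this diff
params = {"messages": [{"role": "user", "content": "2 + 2 = ?"}]}

# 1.2.4: text = runner.get_text(params)
# 1.3.0: the kwargs actually sent to the API ride along with the result.
text, prepared = runner.get_text(params)

# Same unpacking for the probability helpers:
probs, prepared = runner.single_token_probs(params, num_samples=4)

# get_many now yields (input_kwargs, (result, prepared_kwargs)).
for in_, (out, sent_kwargs) in runner.get_many(runner.get_text, [{"params": params}]):
    print(in_, "->", out)
```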
llmcomp-1.3.0.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: llmcomp
- Version: 1.2.4
+ Version: 1.3.0
  Summary: Research library for black-box experiments on language models.
  Project-URL: Homepage, https://github.com/johny-b/llmcomp
  Project-URL: Repository, https://github.com/johny-b/llmcomp
@@ -15,6 +15,7 @@ Requires-Dist: openai>=1.0.0
  Requires-Dist: pandas
  Requires-Dist: pyyaml
  Requires-Dist: requests
+ Requires-Dist: streamlit>=1.20.0
  Requires-Dist: tqdm
  Description-Content-Type: text/markdown

@@ -49,9 +50,9 @@ question = Question.create(
      samples_per_paraphrase=100,
      temperature=1,
  )
- question.plot(MODELS, min_fraction=0.03)
- df = question.df(MODELS)
- print(df.head(1).iloc[0])
+ df = question.df(MODELS)  # Dataframe with the results
+ question.plot(MODELS, min_fraction=0.03)  # Aggregated bar chart
+ question.view(MODELS)  # Interactive browser for individual responses
  ```

  ## Main features
@@ -61,6 +62,7 @@ print(df.head(1).iloc[0])
  * **Parallel requests** - configurable concurrency across models
  * **Multi-key support** - use `OPENAI_API_KEY_0`, `OPENAI_API_KEY_1`, etc. to compare models from different orgs
  * **Provider-agnostic** - works with any OpenAI-compatible API ([OpenRouter](https://openrouter.ai/docs/quickstart#using-the-openai-sdk), [Tinker](https://tinker-docs.thinkingmachines.ai/compatible-apis/openai), etc.)
+ * **Built-in viewer** - browse answers interactively with `question.view(MODELS)`
  * **Extensible** - highly configurable as long as your goal is comparing LLMs

  ## Cookbook
@@ -148,7 +150,7 @@ You can send more parallel requests by increasing `Config.max_workers`.
  Suppose you have many prompts you want to send to models. There are three options:
  1. Have a separate Question object for each prompt and execute them in a loop
  2. Have a separate Question object for each prompt and execute them in parallel
- 3. Have a single Question object with many paraphrases and then split the resulting dataframe (using any of the `paraphrase_ix`, `question` or `messages` columns)
+ 3. Have a single Question object with many paraphrases and then split the resulting dataframe (using any of the `paraphrase_ix` or `question` columns)

  Option 1 will be slow - the more quick questions you have, the worse.
  Option 2 will be fast, but you need to write parallelization yourself. Question should be thread-safe, but parallel execution of questions was **never** tested. One thing that won't work: `llmcomp.Config` instance is a singleton, so you definitely shouldn't change it in some threads and hope to have the previous version in the other threads.
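
On the README hunk above: 1.3.0 no longer lists the `messages` column as a split key, so option 3 reduces to grouping on the two remaining columns. A hedged sketch, assuming `question` and `MODELS` from the README quick-start are in scope:

```python
# Option 3 from the README: one Question with many paraphrases, split afterwards.
df = question.df(MODELS)
for ix, sub_df in df.groupby("paraphrase_ix"):  # or df.groupby("question")
    print(ix, len(sub_df))
```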
llmcomp-1.3.0.dist-info/RECORD ADDED
@@ -0,0 +1,21 @@
+ llmcomp/__init__.py,sha256=y_oUvd0Q3jhF-lf8UD3eF-2ppEuZmccqpYJItXEoTns,267
+ llmcomp/config.py,sha256=xADWhqsQphJZQvf7WemWencmWuBnvTN_KeJrjWfnmHY,8942
+ llmcomp/default_adapters.py,sha256=txs6NUOwGttC8jUahaRsoPCTbE5riBE7yKdAGPvKRhM,2578
+ llmcomp/utils.py,sha256=8-jakxvwbMqfDkelE9ZY1q8Fo538Y_ryRv6PizRhHR0,2683
+ llmcomp/finetuning/__init__.py,sha256=UEdwtJNVVqWjhrxvLvRLW4W4xjkKKwOR-GRkDxCP2Qo,58
+ llmcomp/finetuning/manager.py,sha256=6G0CW3NWK8vdfBoAjH0HATx_g16wwq5oU0mlHs-q28o,19083
+ llmcomp/finetuning/update_jobs.py,sha256=blsHzg_ViTa2hBJtWCqR5onttehTtmXn3vmCTNd_hJw,980
+ llmcomp/finetuning/validation.py,sha256=v4FoFw8woo5No9A01ktuALsMsXdgb3N2rS58ttBUmHY,14047
+ llmcomp/question/judge.py,sha256=tNY94AHqncrbl2gf-g_Y3lepJ_HrahJRH-WgQyokegk,6568
+ llmcomp/question/plots.py,sha256=Izp9jxWzQDgRgycgM7_-lhIkqx7yr_WBQedUcUcpaFA,11164
+ llmcomp/question/question.py,sha256=cLOVp8ZD0O-Y1UI8RVpi6ZD3ulRtY8PeFwEgeAnLzvs,41100
+ llmcomp/question/result.py,sha256=psc9tQpwEEhS4LGxaI7GhqCE1CSAmCo39yrKap9cLjA,8216
+ llmcomp/question/viewer.py,sha256=hMHWr5cONWXF37ybXJTI_kudSz3xaA0shkQFRoNRZmI,16380
+ llmcomp/runner/chat_completion.py,sha256=iDiWE0N0_MYfggD-ouyfUPyaADt7602K5Wo16a7JJo4,967
+ llmcomp/runner/model_adapter.py,sha256=Dua98E7aBVrCaZ2Ep44vl164oFkpH1P78YqImQkns4U,3406
+ llmcomp/runner/runner.py,sha256=B8p9b3At9JWWIW-mlADwyelJKqHxW4CIorSWyaD3gHM,12294
+ llmcomp-1.3.0.dist-info/METADATA,sha256=CWC5sdrfuvQWWFOwjj7RJIzk0Rgb3EKCRPA75D5Wu4U,12963
+ llmcomp-1.3.0.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
+ llmcomp-1.3.0.dist-info/entry_points.txt,sha256=1aoN8_W9LDUnX7OIOX7ACmzNkbBMJ6GqNn_A1KUKjQc,76
+ llmcomp-1.3.0.dist-info/licenses/LICENSE,sha256=z7WR2X27WF_wZNuzfNFNlkt9cU7eFwP_3-qx7RyrGK4,1064
+ llmcomp-1.3.0.dist-info/RECORD,,