llmcomp-1.2.4-py3-none-any.whl → llmcomp-1.3.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llmcomp/finetuning/manager.py +21 -0
- llmcomp/finetuning/validation.py +406 -0
- llmcomp/question/judge.py +11 -0
- llmcomp/question/plots.py +150 -71
- llmcomp/question/question.py +255 -190
- llmcomp/question/result.py +33 -10
- llmcomp/question/viewer.py +488 -0
- llmcomp/runner/runner.py +32 -18
- {llmcomp-1.2.4.dist-info → llmcomp-1.3.1.dist-info}/METADATA +8 -5
- llmcomp-1.3.1.dist-info/RECORD +21 -0
- llmcomp-1.2.4.dist-info/RECORD +0 -19
- {llmcomp-1.2.4.dist-info → llmcomp-1.3.1.dist-info}/WHEEL +0 -0
- {llmcomp-1.2.4.dist-info → llmcomp-1.3.1.dist-info}/entry_points.txt +0 -0
- {llmcomp-1.2.4.dist-info → llmcomp-1.3.1.dist-info}/licenses/LICENSE +0 -0
llmcomp/question/viewer.py
ADDED
@@ -0,0 +1,488 @@
+"""DataFrame viewer for browsing question results.
+
+Spawns a local Streamlit server to interactively browse (api_kwargs, answer) pairs.
+
+Usage:
+    from llmcomp import Question
+
+    question = Question.create(...)
+    df = question.df(models)
+    Question.view(df)
+"""
+
+import json
+import os
+import subprocess
+import sys
+import tempfile
+import webbrowser
+from pathlib import Path
+from typing import Any
+
+# Streamlit imports are inside functions to avoid import errors when streamlit isn't installed
+
+
+def render_dataframe(
+    df: "pd.DataFrame",
+    sort_by: str | None = "__random__",
+    sort_ascending: bool = True,
+    open_browser: bool = True,
+    port: int = 8501,
+) -> None:
+    """Launch a Streamlit viewer for the DataFrame.
+
+    Args:
+        df: DataFrame with at least 'api_kwargs' and 'answer' columns.
+            Other columns (model, group, etc.) are displayed as metadata.
+        sort_by: Column name to sort by initially. Default: "__random__" for random
+            shuffling (new seed on each refresh). Use None for original order.
+        sort_ascending: Sort order. Default: True (ascending).
+        open_browser: If True, automatically open the viewer in default browser.
+        port: Port to run the Streamlit server on.
+
+    Raises:
+        ValueError: If required columns are missing.
+    """
+    # Validate required columns
+    if "api_kwargs" not in df.columns:
+        raise ValueError("DataFrame must have an 'api_kwargs' column")
+    if "answer" not in df.columns:
+        raise ValueError("DataFrame must have an 'answer' column")
+    if sort_by is not None and sort_by != "__random__" and sort_by not in df.columns:
+        raise ValueError(f"sort_by column '{sort_by}' not found in DataFrame")
+
+    # Save DataFrame to a temp file
+    temp_dir = tempfile.mkdtemp(prefix="llmcomp_viewer_")
+    temp_path = os.path.join(temp_dir, "data.jsonl")
+
+    # Convert DataFrame to JSONL
+    with open(temp_path, "w", encoding="utf-8") as f:
+        for _, row in df.iterrows():
+            row_dict = row.to_dict()
+            f.write(json.dumps(row_dict, default=str) + "\n")
+
+    url = f"http://localhost:{port}"
+    print(f"Starting viewer at {url}")
+    print(f"Data file: {temp_path}")
+    print("Press Ctrl+C to stop the server.\n")
+
+    if open_browser:
+        # Open browser after a short delay to let server start
+        import threading
+        threading.Timer(0.5, lambda: webbrowser.open(url)).start()
+
+    # Launch Streamlit
+    viewer_path = Path(__file__).resolve()
+    cmd = [
+        sys.executable, "-m", "streamlit", "run",
+        str(viewer_path),
+        "--server.port", str(port),
+        "--server.headless", "true",
+        "--",  # Separator for script args
+        temp_path,
+        sort_by or "",  # Empty string means no sorting
+        "asc" if sort_ascending else "desc",
+    ]
+
+    try:
+        subprocess.run(cmd, check=True)
+    except KeyboardInterrupt:
+        print("\nViewer stopped.")
+    finally:
+        # Clean up temp file
+        try:
+            os.remove(temp_path)
+            os.rmdir(temp_dir)
+        except OSError:
+            pass
+
+
+# =============================================================================
+# Streamlit App (runs when this file is executed by streamlit)
+# =============================================================================
+
+def _get_data_path() -> str | None:
+    """Get data file path from command line args."""
+    # Args after -- are passed to the script
+    if len(sys.argv) > 1:
+        return sys.argv[1]
+    return None
+
+
+def _get_initial_sort() -> tuple[str | None, bool]:
+    """Get initial sort settings from command line args."""
+    sort_by = None
+    sort_ascending = True
+
+    if len(sys.argv) > 2:
+        sort_by = sys.argv[2] if sys.argv[2] else None
+    if len(sys.argv) > 3:
+        sort_ascending = sys.argv[3] != "desc"
+
+    return sort_by, sort_ascending
+
+
+def _read_jsonl(path: str) -> list[dict[str, Any]]:
+    """Read JSONL file into a list of dicts."""
+    items = []
+    with open(path, "r", encoding="utf-8") as f:
+        for line in f:
+            line = line.strip()
+            if line:
+                items.append(json.loads(line))
+    return items
+
+
+def _display_messages(messages: list[dict[str, str]]) -> None:
+    """Display a list of chat messages in Streamlit chat format."""
+    import streamlit as st
+
+    for msg in messages:
+        role = msg.get("role", "user")
+        content = msg.get("content", "")
+
+        # Map roles to streamlit chat_message roles
+        if role == "system":
+            with st.chat_message("assistant", avatar="⚙️"):
+                st.markdown("**System**")
+                st.text(content)
+        elif role == "assistant":
+            with st.chat_message("assistant"):
+                st.text(content)
+        else:  # user or other
+            with st.chat_message("user"):
+                st.text(content)
+
+
+def _display_answer(answer: Any, label: str | None = None) -> None:
+    """Display the answer, handling different types."""
+    import streamlit as st
+
+    if label:
+        st.markdown(f"**{label}**")
+
+    if isinstance(answer, dict):
+        # For NextToken questions, answer is {token: probability}
+        # Sort by probability descending
+        sorted_items = sorted(answer.items(), key=lambda x: -x[1] if isinstance(x[1], (int, float)) else 0)
+        # Display as a table-like format
+        for token, prob in sorted_items[:20]:  # Show top 20
+            if isinstance(prob, float):
+                st.text(f" {token!r}: {prob:.4f}")
+            else:
+                st.text(f" {token!r}: {prob}")
+    elif isinstance(answer, str):
+        st.text(answer)
+    else:
+        st.text(str(answer))
+
+
+def _display_metadata(row: dict[str, Any], exclude_keys: set[str]) -> None:
+    """Display metadata columns."""
+    import streamlit as st
+
+    metadata = {k: v for k, v in row.items() if k not in exclude_keys}
+    if metadata:
+        with st.expander("Metadata", expanded=False):
+            for key, value in metadata.items():
+                if isinstance(value, (dict, list)):
+                    st.markdown(f"**{key}:**")
+                    # Collapse _raw_answer and _probs dicts by default
+                    collapsed = key.endswith("_raw_answer") or key.endswith("_probs")
+                    st.json(value, expanded=not collapsed)
+                else:
+                    st.markdown(f"**{key}:** {value}")
+
+
+def _search_items(items: list[dict[str, Any]], query: str) -> list[dict[str, Any]]:
+    """Filter items by search query.
+
+    Supports:
+    - Regular search: "foo" - includes items containing "foo"
+    - Negative search: "-foo" - excludes items containing "foo"
+    - Combined: "foo -bar" - items with "foo" but not "bar"
+    """
+    if not query:
+        return items
+
+    # Parse query into positive and negative terms
+    terms = query.split()
+    positive_terms = []
+    negative_terms = []
+
+    for term in terms:
+        if term.startswith("-") and len(term) > 1:
+            negative_terms.append(term[1:].lower())
+        else:
+            positive_terms.append(term.lower())
+
+    results = []
+
+    for item in items:
+        # Build searchable text from item
+        api_kwargs = item.get("api_kwargs", {})
+        messages = api_kwargs.get("messages", []) if isinstance(api_kwargs, dict) else []
+        messages_text = " ".join(m.get("content", "") for m in messages)
+
+        answer = item.get("answer", "")
+        answer_text = str(answer) if not isinstance(answer, str) else answer
+
+        all_text = messages_text + " " + answer_text
+        all_text += " " + " ".join(str(v) for v in item.values() if isinstance(v, str))
+        all_text_lower = all_text.lower()
+
+        # Check positive terms (all must match)
+        if positive_terms and not all(term in all_text_lower for term in positive_terms):
+            continue
+
+        # Check negative terms (none must match)
+        if any(term in all_text_lower for term in negative_terms):
+            continue
+
+        results.append(item)
+
+    return results
+
+
+def _streamlit_main():
+    """Main Streamlit app."""
+    import streamlit as st
+
+    st.set_page_config(
+        page_title="llmcomp Viewer",
+        page_icon="🔬",
+        layout="wide",
+    )
+
+    st.title("🔬 llmcomp Viewer")
+
+    # Get data path
+    data_path = _get_data_path()
+    if data_path is None or not os.path.exists(data_path):
+        st.error("No data file provided or file not found.")
+        st.info("Use `Question.render(df)` to launch the viewer with data.")
+        return
+
+    # Load data (cache in session state)
+    cache_key = f"llmcomp_data_{data_path}"
+    if cache_key not in st.session_state:
+        st.session_state[cache_key] = _read_jsonl(data_path)
+
+    items = st.session_state[cache_key]
+
+    if not items:
+        st.warning("No data to display.")
+        return
+
+    # Get sortable columns (numeric or string, exclude complex types)
+    sortable_columns = ["(random)", "(none)"]
+    if items:
+        for key, value in items[0].items():
+            if key not in ("api_kwargs",) and isinstance(value, (int, float, str, type(None))):
+                sortable_columns.append(key)
+
+    # Initialize sort settings from command line args
+    initial_sort_by, initial_sort_asc = _get_initial_sort()
+    if "sort_by" not in st.session_state:
+        # Map __random__ from CLI to (random) in UI
+        if initial_sort_by == "__random__":
+            st.session_state.sort_by = "(random)"
+        elif initial_sort_by in sortable_columns:
+            st.session_state.sort_by = initial_sort_by
+        else:
+            st.session_state.sort_by = "(none)"
+        st.session_state.sort_ascending = initial_sort_asc
+
+    # Initialize view index
+    if "view_idx" not in st.session_state:
+        st.session_state.view_idx = 0
+
+    # Initialize secondary sort
+    if "sort_by_2" not in st.session_state:
+        st.session_state.sort_by_2 = "(none)"
+        st.session_state.sort_ascending_2 = True
+
+    # Search and sort controls
+    col_search, col_sort, col_order = st.columns([3, 2, 1])
+
+    with col_search:
+        query = st.text_input("🔍 Search", placeholder="Filter... (use -term to exclude)")
+
+    with col_sort:
+        sort_by = st.selectbox(
+            "Sort by",
+            options=sortable_columns,
+            index=sortable_columns.index(st.session_state.sort_by) if st.session_state.sort_by in sortable_columns else 0,
+            key="sort_by_select",
+        )
+        if sort_by != st.session_state.sort_by:
+            st.session_state.sort_by = sort_by
+            st.session_state.view_idx = 0  # Reset to first item when sort changes
+
+    with col_order:
+        st.markdown("<br>", unsafe_allow_html=True)  # Align checkbox with selectbox
+        sort_ascending = st.checkbox("Asc", value=st.session_state.sort_ascending, key="sort_asc_check")
+        if sort_ascending != st.session_state.sort_ascending:
+            st.session_state.sort_ascending = sort_ascending
+            st.session_state.view_idx = 0
+
+    # Reshuffle button for random sort
+    if st.session_state.sort_by == "(random)":
+        import random
+        col_reshuffle, _ = st.columns([1, 5])
+        with col_reshuffle:
+            if st.button("🔀 Reshuffle"):
+                st.session_state.random_seed = random.randint(0, 2**32 - 1)
+                st.session_state.view_idx = 0
+                st.rerun()
+
+    # Secondary sort (only show if primary sort is selected)
+    if st.session_state.sort_by and st.session_state.sort_by != "(none)":
+        col_spacer, col_sort2, col_order2 = st.columns([3, 2, 1])
+        with col_sort2:
+            sort_by_2 = st.selectbox(
+                "Then by",
+                options=sortable_columns,
+                index=sortable_columns.index(st.session_state.sort_by_2) if st.session_state.sort_by_2 in sortable_columns else 0,
+                key="sort_by_select_2",
+            )
+            if sort_by_2 != st.session_state.sort_by_2:
+                st.session_state.sort_by_2 = sort_by_2
+                st.session_state.view_idx = 0
+        with col_order2:
+            st.markdown("<br>", unsafe_allow_html=True)  # Align checkbox with selectbox
+            sort_ascending_2 = st.checkbox("Asc", value=st.session_state.sort_ascending_2, key="sort_asc_check_2")
+            if sort_ascending_2 != st.session_state.sort_ascending_2:
+                st.session_state.sort_ascending_2 = sort_ascending_2
+                st.session_state.view_idx = 0
+
+    # Apply search
+    filtered_items = _search_items(items, query)
+
+    # Apply random shuffle if selected (new seed on each refresh via Reshuffle button)
+    if st.session_state.sort_by == "(random)" and filtered_items:
+        import random
+        # Generate a new seed on first load or when explicitly reshuffled
+        if "random_seed" not in st.session_state:
+            st.session_state.random_seed = random.randint(0, 2**32 - 1)
+        rng = random.Random(st.session_state.random_seed)
+        filtered_items = filtered_items.copy()
+        rng.shuffle(filtered_items)
+
+    # Apply sorting (stable sort - secondary first, then primary)
+    if st.session_state.sort_by and st.session_state.sort_by not in ("(none)", "(random)") and filtered_items:
+        sort_key_2 = st.session_state.sort_by_2 if st.session_state.sort_by_2 != "(none)" else None
+
+        # Secondary sort first (stable sort preserves this ordering within primary groups)
+        if sort_key_2:
+            filtered_items = sorted(
+                filtered_items,
+                key=lambda x: (x.get(sort_key_2) is None, x.get(sort_key_2)),
+                reverse=not st.session_state.sort_ascending_2,
+            )
+
+        # Primary sort
+        sort_key = st.session_state.sort_by
+        filtered_items = sorted(
+            filtered_items,
+            key=lambda x: (x.get(sort_key) is None, x.get(sort_key)),
+            reverse=not st.session_state.sort_ascending,
+        )
+
+    if not filtered_items:
+        st.warning(f"No results found for '{query}'")
+        return
+
+    # Clamp view index to valid range
+    max_idx = len(filtered_items) - 1
+    st.session_state.view_idx = max(0, min(st.session_state.view_idx, max_idx))
+
+    # Navigation
+    col1, col2, col3, col4 = st.columns([1, 1, 2, 2])
+
+    with col1:
+        if st.button("⬅️ Prev", use_container_width=True):
+            st.session_state.view_idx = max(0, st.session_state.view_idx - 1)
+            st.rerun()
+
+    with col2:
+        if st.button("Next ➡️", use_container_width=True):
+            st.session_state.view_idx = min(max_idx, st.session_state.view_idx + 1)
+            st.rerun()
+
+    with col3:
+        # Jump to specific index
+        new_idx = st.number_input(
+            "Go to",
+            min_value=1,
+            max_value=len(filtered_items),
+            value=st.session_state.view_idx + 1,
+            step=1,
+            label_visibility="collapsed",
+        )
+        if new_idx - 1 != st.session_state.view_idx:
+            st.session_state.view_idx = new_idx - 1
+            st.rerun()
+
+    with col4:
+        st.markdown(f"**{st.session_state.view_idx + 1}** of **{len(filtered_items)}**")
+        if query:
+            st.caption(f"({len(items)} total)")
+
+    st.divider()
+
+    # Display current item
+    current = filtered_items[st.session_state.view_idx]
+
+    # Main content in two columns
+    left_col, right_col = st.columns([1, 2])
+
+    with left_col:
+        st.subheader("💬 Messages")
+        api_kwargs = current.get("api_kwargs", {})
+        messages = api_kwargs.get("messages", []) if isinstance(api_kwargs, dict) else []
+        if messages:
+            _display_messages(messages)
+        else:
+            st.info("No messages")
+
+    with right_col:
+        model_name = current.get("model", "Response")
+        st.subheader(f"🤖 {model_name}")
+        answer = current.get("answer")
+        if answer is not None:
+            _display_answer(answer, label=None)
+        else:
+            st.info("No answer")
+
+        # Display judge columns if present
+        judge_columns = [k for k in current.keys() if not k.startswith("_") and k not in {
+            "api_kwargs", "answer", "question", "model", "group", "paraphrase_ix", "raw_answer"
+        } and not k.endswith("_question") and not k.endswith("_raw_answer") and not k.endswith("_probs")]
+
+        if judge_columns:
+            st.markdown("---")
+            for judge_col in judge_columns:
+                value = current[judge_col]
+                if isinstance(value, float):
+                    st.markdown(f"**{judge_col}:** {value:.2f}")
+                else:
+                    st.markdown(f"**{judge_col}:** {value}")
+
+    # Metadata at the bottom
+    st.divider()
+    # Show api_kwargs in metadata, but without messages (already displayed above)
+    current_for_metadata = current.copy()
+    if "api_kwargs" in current_for_metadata and isinstance(current_for_metadata["api_kwargs"], dict):
+        api_kwargs_without_messages = {k: v for k, v in current_for_metadata["api_kwargs"].items() if k != "messages"}
+        current_for_metadata["api_kwargs"] = api_kwargs_without_messages
+    exclude_keys = {"answer", "question", "paraphrase_ix"} | set(judge_columns)
+    _display_metadata(current_for_metadata, exclude_keys)
+
+    # Keyboard navigation hint
+    st.caption("💡 Tip: Use the navigation buttons or enter a number to jump to a specific row.")
+
+
+# Entry point when run by Streamlit
+if __name__ == "__main__":
+    _streamlit_main()
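Editor's note: the module docstring above shows the intended entry point, `Question.view(df)`, and `render_dataframe` is presumably the function behind it. A minimal sketch of calling `render_dataframe` directly, assuming streamlit and pandas are installed; the toy DataFrame below is illustrative, not from the package (real frames come from `question.df(models)`):

import pandas as pd

from llmcomp.question.viewer import render_dataframe

# Toy frame with the two required columns ('api_kwargs' and 'answer');
# extra columns like 'model' are shown as metadata in the viewer.
df = pd.DataFrame([
    {
        "api_kwargs": {"messages": [{"role": "user", "content": "Hello"}]},
        "answer": "Hi there!",
        "model": "placeholder-model",  # hypothetical name
    },
])

# sort_by=None keeps the original row order; open_browser=False just prints the URL.
render_dataframe(df, sort_by=None, open_browser=False, port=8501)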
llmcomp/runner/runner.py
CHANGED
@@ -51,12 +51,15 @@ class Runner:
         prepared = ModelAdapter.prepare(params, self.model)
         return {"timeout": Config.timeout, **prepared}
 
-    def get_text(self, params: dict) -> str:
+    def get_text(self, params: dict) -> tuple[str, dict]:
         """Get a text completion from the model.
 
         Args:
             params: Dictionary of parameters for the API.
                 Must include 'messages'. Other common keys: 'temperature', 'max_tokens'.
+
+        Returns:
+            Tuple of (content, prepared_kwargs) where prepared_kwargs is what was sent to the API.
         """
         prepared = self._prepare_for_model(params)
         completion = openai_chat_completion(client=self.client, **prepared)
@@ -72,8 +75,8 @@ class Runner:
                 # refusal="I'm sorry, I'm unable to fulfill that request.",
                 # ...))])
                 warnings.warn(f"API sent None as content. Returning empty string.\n{completion}", stacklevel=2)
-                return ""
-            return content
+                return "", prepared
+            return content, prepared
         except Exception:
             warnings.warn(f"Unexpected error.\n{completion}")
             raise
@@ -84,7 +87,7 @@ class Runner:
         *,
         num_samples: int = 1,
         convert_to_probs: bool = True,
-    ) -> dict:
+    ) -> tuple[dict, dict]:
         """Get probability distribution of the next token, optionally averaged over multiple samples.
 
         Args:
@@ -92,22 +95,26 @@ class Runner:
                 Must include 'messages'. Other common keys: 'top_logprobs', 'logit_bias'.
             num_samples: Number of samples to average over. Default: 1.
             convert_to_probs: If True, convert logprobs to probabilities. Default: True.
+
+        Returns:
+            Tuple of (probs_dict, prepared_kwargs) where prepared_kwargs is what was sent to the API.
         """
         probs = {}
+        prepared = None
         for _ in range(num_samples):
-            new_probs = self.single_token_probs_one_sample(params, convert_to_probs=convert_to_probs)
+            new_probs, prepared = self.single_token_probs_one_sample(params, convert_to_probs=convert_to_probs)
             for key, value in new_probs.items():
                 probs[key] = probs.get(key, 0) + value
         result = {key: value / num_samples for key, value in probs.items()}
         result = dict(sorted(result.items(), key=lambda x: x[1], reverse=True))
-        return result
+        return result, prepared
 
     def single_token_probs_one_sample(
         self,
         params: dict,
         *,
         convert_to_probs: bool = True,
-    ) -> dict:
+    ) -> tuple[dict, dict]:
         """Get probability distribution of the next token (single sample).
 
         Args:
@@ -115,6 +122,9 @@ class Runner:
                 Must include 'messages'. Other common keys: 'top_logprobs', 'logit_bias'.
             convert_to_probs: If True, convert logprobs to probabilities. Default: True.
 
+        Returns:
+            Tuple of (probs_dict, prepared_kwargs) where prepared_kwargs is what was sent to the API.
+
         Note: This function forces max_tokens=1, temperature=0, logprobs=True.
         """
         # Build complete params with defaults and forced params
@@ -138,7 +148,7 @@ class Runner:
         except IndexError:
             # This should not happen according to the API docs. But it sometimes does.
             print(NO_LOGPROBS_WARNING.format(model=self.model, completion=completion))
-            return {}
+            return {}, prepared
 
         # Check for duplicate tokens - this shouldn't happen with OpenAI but might with other providers
         tokens = [el.token for el in logprobs]
@@ -153,7 +163,7 @@ class Runner:
         for el in logprobs:
             result[el.token] = math.exp(el.logprob) if convert_to_probs else el.logprob
 
-        return result
+        return result, prepared
 
     def get_many(
         self,
@@ -173,8 +183,8 @@ class Runner:
                {"params": {"messages": [{"role": "user", "content": "Hello"}]}},
                {"params": {"messages": [{"role": "user", "content": "Bye"}], "temperature": 0.7}},
            ]
-           for in_, out in runner.get_many(runner.get_text, kwargs_list):
-               print(in_, "->", out)
+           for in_, (out, prepared_kwargs) in runner.get_many(runner.get_text, kwargs_list):
+               print(in_, "->", out, prepared_kwargs)
 
         or
 
@@ -182,14 +192,14 @@ class Runner:
                {"params": {"messages": [{"role": "user", "content": "Hello"}]}},
                {"params": {"messages": [{"role": "user", "content": "Bye"}]}},
            ]
-           for in_, out in runner.get_many(runner.single_token_probs, kwargs_list):
-               print(in_, "->", out)
+           for in_, (out, prepared_kwargs) in runner.get_many(runner.single_token_probs, kwargs_list):
+               print(in_, "->", out, prepared_kwargs)
 
         (FUNC that is a different callable should also work)
 
         This function returns a generator that yields pairs (input, output),
-        where input is an element from KWARGS_LIST and output is the
-        FUNC
+        where input is an element from KWARGS_LIST and output is the tuple (result, prepared_kwargs)
+        returned by FUNC. prepared_kwargs contains the actual parameters sent to the API.
 
         Dictionaries in KWARGS_LIST might include optional keys starting with underscore,
         they are just ignored, but they are returned in the first element of the pair, so that's useful
@@ -230,7 +240,7 @@ class Runner:
                     f"Model: {self.model}, function: {func.__name__}{msg_info}. "
                     f"Error: {type(e).__name__}: {e}"
                 )
-                result = None
+                result = (None, {})
                 return kwargs, result
 
         futures = [executor.submit(get_data, kwargs) for kwargs in kwargs_list]
@@ -251,7 +261,7 @@ class Runner:
         params: dict,
         *,
         num_samples: int,
-    ) -> dict:
+    ) -> tuple[dict, dict]:
         """Sample answers NUM_SAMPLES times. Returns probabilities of answers.
 
         Args:
@@ -259,6 +269,9 @@ class Runner:
                 Must include 'messages'. Other common keys: 'max_tokens', 'temperature'.
             num_samples: Number of samples to collect.
 
+        Returns:
+            Tuple of (probs_dict, prepared_kwargs) where prepared_kwargs is what was sent to the API.
+
         Works only if the API supports `n` parameter.
 
         Usecases:
@@ -268,6 +281,7 @@ class Runner:
         for Runner.single_token_probs.
         """
         cnts = defaultdict(int)
+        prepared = None
         for i in range(((num_samples - 1) // 128) + 1):
             n = min(128, num_samples - i * 128)
             # Build complete params with forced param
@@ -285,4 +299,4 @@ class Runner:
         )
         result = {key: val / num_samples for key, val in cnts.items()}
         result = dict(sorted(result.items(), key=lambda x: x[1], reverse=True))
-        return result
+        return result, prepared
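Editor's note: taken together, these runner.py changes switch every Runner method from returning a bare result to returning a (result, prepared_kwargs) tuple, where prepared_kwargs records exactly what was sent to the API; this is presumably what viewer.py surfaces as the api_kwargs column. A minimal sketch of the new call convention, based on the docstrings above; the Runner constructor arguments are an assumption, since they do not appear in this diff:

from llmcomp.runner.runner import Runner

runner = Runner("placeholder-model")  # hypothetical constructor call; not shown in this diff

# get_text now returns the completion text plus the prepared kwargs (timeout, model, ...).
text, prepared = runner.get_text({"messages": [{"role": "user", "content": "Hello"}]})

# single_token_probs returns (probs_dict, prepared_kwargs) as well.
probs, prepared = runner.single_token_probs(
    {"messages": [{"role": "user", "content": "Heads or tails?"}]},
    num_samples=4,
)

# get_many yields (input, (result, prepared_kwargs)); failed calls now yield (None, {}).
kwargs_list = [{"params": {"messages": [{"role": "user", "content": "Bye"}]}}]
for in_, (out, prepared_kwargs) in runner.get_many(runner.get_text, kwargs_list):
    print(in_, "->", out)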