llmcomp 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,193 @@
+ import json
+ import os
+ from dataclasses import dataclass
+ from datetime import datetime
+ from typing import TYPE_CHECKING, Any
+
+ from llmcomp.config import Config
+
+ if TYPE_CHECKING:
+     from llmcomp.question.question import Question
+
+
+ @dataclass
+ class Result:
+     """Cache for question results per model.
+
+     Storage format (JSONL):
+         Line 1: metadata dict
+         Lines 2+: one JSON object per result entry
+     """
+
+     question: "Question"
+     model: str
+     data: list[dict]
+
+     @classmethod
+     def file_path(cls, question: "Question", model: str) -> str:
+         return f"{Config.cache_dir}/question/{question.name}/{question.hash()[:7]}/{model}.jsonl"
+
+     def save(self):
+         path = self.file_path(self.question, self.model)
+         os.makedirs(os.path.dirname(path), exist_ok=True)
+         with open(path, "w") as f:
+             f.write(json.dumps(self._metadata()) + "\n")
+             for d in self.data:
+                 f.write(json.dumps(d) + "\n")
+
+     @classmethod
+     def load(cls, question: "Question", model: str) -> "Result":
+         path = cls.file_path(question, model)
+
+         if not os.path.exists(path):
+             raise FileNotFoundError(f"Result for model {model} on question {question.name} not found in {path}")
+
+         with open(path, "r") as f:
+             lines = f.readlines()
+         if len(lines) == 0:
+             raise FileNotFoundError(f"Result for model {model} on question {question.name} is empty.")
+
+         metadata = json.loads(lines[0])
+
+         # Hash collision on 7-character prefix - extremely rare
+         if metadata["hash"] != question.hash():
+             os.remove(path)
+             print(f"Rare hash collision detected for {question.name}/{model}. Cached result removed.")
+             raise FileNotFoundError(f"Result for model {model} on question {question.name} not found in {path}")
+
+         data = [json.loads(line) for line in lines[1:]]
+         return cls(question, model, data)
+
+     def _metadata(self) -> dict:
+         return {
+             "name": self.question.name,
+             "model": self.model,
+             "last_update": datetime.now().isoformat(),
+             "hash": self.question.hash(),
+         }
+
+
+ class JudgeCache:
+     """Key-value cache for judge results.
+
+     Storage format (JSON):
+         {
+             "metadata": {
+                 "name": "...",
+                 "model": "...",
+                 "last_update": "...",
+                 "hash": "...",
+                 "prompt": "...",
+                 "uses_question": true/false
+             },
+             "data": {
+                 "<question>": {
+                     "<answer>": <judge_response>,
+                     ...
+                 },
+                 ...
+             }
+         }
+
+     The key is the (question, answer) pair.
+
+     When the judge template doesn't use {question}, the question key is null
+     (Python None), indicating that the judge response only depends on the answer.
+     """
+
+     def __init__(self, judge: "Question"):
+         self.judge = judge
+         self._data: dict[str | None, dict[str, Any]] | None = None
+
+     @classmethod
+     def file_path(cls, judge: "Question") -> str:
+         return f"{Config.cache_dir}/judge/{judge.name}/{judge.hash()[:7]}.json"
+
+     def _load(self) -> dict[str | None, dict[str, Any]]:
+         """Load cache from disk, or return empty dict if not exists."""
+         if self._data is not None:
+             return self._data
+
+         path = self.file_path(self.judge)
+
+         if not os.path.exists(path):
+             self._data = {}
+             return self._data
+
+         with open(path, "r") as f:
+             file_data = json.load(f)
+
+         metadata = file_data["metadata"]
+
+         # Hash collision on 7-character prefix - extremely rare
+         if metadata["hash"] != self.judge.hash():
+             os.remove(path)
+             print(f"Rare hash collision detected for judge {self.judge.name}. Cached result removed.")
+             self._data = {}
+             return self._data
+
+         # Sanity check: prompt should match (if hash matches, this should always pass)
+         if metadata.get("prompt") != self.judge.paraphrases[0]:
+             os.remove(path)
+             print(f"Judge prompt mismatch for {self.judge.name}. Cached result removed.")
+             self._data = {}
+             return self._data
+
+         self._data = file_data["data"]
+         return self._data
+
+     def save(self):
+         """Save cache to disk."""
+         if self._data is None:
+             return
+
+         path = self.file_path(self.judge)
+         os.makedirs(os.path.dirname(path), exist_ok=True)
+         file_data = {
+             "metadata": self._metadata(),
+             "data": self._data,
+         }
+         with open(path, "w") as f:
+             json.dump(file_data, f, indent=2)
+
+     def _metadata(self) -> dict:
+         return {
+             "name": self.judge.name,
+             "model": self.judge.model,
+             "last_update": datetime.now().isoformat(),
+             "hash": self.judge.hash(),
+             "prompt": self.judge.paraphrases[0],
+             "uses_question": self.judge.uses_question,
+         }
+
+     def _key(self, question: str | None) -> str:
+         """Convert question to cache key. None becomes 'null' string for JSON compatibility."""
+         # JSON serializes None as null, which becomes the string key "null" when loaded
+         # We handle this by using the string "null" internally
+         return "null" if question is None else question
+
+     def get(self, question: str | None, answer: str) -> Any | None:
+         """Get the judge response for a (question, answer) pair."""
+         data = self._load()
+         key = self._key(question)
+         if key not in data:
+             return None
+         return data[key].get(answer)
+
+     def get_uncached(self, pairs: list[tuple[str | None, str]]) -> list[tuple[str | None, str]]:
+         """Return list of (question, answer) pairs that are NOT in cache."""
+         data = self._load()
+         uncached = []
+         for q, a in pairs:
+             key = self._key(q)
+             if key not in data or a not in data[key]:
+                 uncached.append((q, a))
+         return uncached
+
+     def set(self, question: str | None, answer: str, judge_response: Any):
+         """Add a single entry to cache."""
+         data = self._load()
+         key = self._key(question)
+         if key not in data:
+             data[key] = {}
+         data[key][answer] = judge_response
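The two classes above are read-through caches keyed by the question name and a 7-character hash prefix. As a rough illustration (not part of the package diff), this is how JudgeCache would typically be driven; `judge` is assumed to be a Question instance and `run_judge` is a hypothetical helper that actually calls the judge model:

    cache = JudgeCache(judge)
    pairs = [("What is 2+2?", "4"), (None, "I refuse to answer.")]
    for q, a in cache.get_uncached(pairs):      # only the pairs missing from the cache
        cache.set(q, a, run_judge(q, a))        # run_judge is hypothetical, not part of llmcomp
    cache.save()                                # writes metadata + data as a single JSON file
    print(cache.get("What is 2+2?", "4"))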
@@ -0,0 +1,33 @@
+ import backoff
+ import openai
+
+
+ def on_backoff(details):
+     """We don't print connection errors because there are sometimes a lot of them and they're not interesting."""
+     exception_details = details["exception"]
+     if not str(exception_details).startswith("Connection error."):
+         print(exception_details)
+
+
+ @backoff.on_exception(
+     wait_gen=backoff.expo,
+     exception=(
+         openai.RateLimitError,
+         openai.APIConnectionError,
+         openai.APITimeoutError,
+         openai.InternalServerError,
+     ),
+     max_value=60,
+     factor=1.5,
+     on_backoff=on_backoff,
+ )
+ def openai_chat_completion(*, client, **kwargs):
+     if kwargs["model"].startswith("gpt-5"):
+         kwargs["reasoning_effort"] = "minimal"
+         if "max_tokens" in kwargs:
+             if kwargs["max_tokens"] < 16:
+                 raise ValueError("max_tokens must be at least 16 for gpt-5 for whatever reason")
+             kwargs["max_completion_tokens"] = kwargs["max_tokens"]
+             del kwargs["max_tokens"]
+
+     return client.chat.completions.create(**kwargs)
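For context (not part of the package diff), a minimal usage sketch of the retry wrapper above; the model name is a placeholder and the client assumes an OPENAI_API_KEY in the environment:

    import openai

    client = openai.OpenAI()
    completion = openai_chat_completion(
        client=client,
        model="gpt-4o-mini",  # placeholder model name
        messages=[{"role": "user", "content": "Say hi"}],
        max_tokens=16,        # for gpt-5 models the wrapper renames this to max_completion_tokens
    )
    print(completion.choices[0].message.content)

Rate limits, connection errors, timeouts, and internal server errors are retried with exponential backoff (capped at 60 seconds between attempts); any other exception propagates immediately.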
@@ -0,0 +1,249 @@
+ import math
+ import warnings
+ from collections import defaultdict
+ from concurrent.futures import ThreadPoolExecutor, as_completed
+ from threading import Lock
+
+ from tqdm import tqdm
+
+ from llmcomp.config import Config, NoClientForModel
+ from llmcomp.runner.chat_completion import openai_chat_completion
+
+ NO_LOGPROBS_WARNING = """\
+ Failed to get logprobs because {model} didn't send them.
+ Returning empty dict, I hope you can handle it.
+
+ Last completion has empty logprobs.content:
+ {completion}
+ """
+
+
+ class Runner:
+     def __init__(self, model: str):
+         self.model = model
+         self._client = None
+         self._get_client_lock = Lock()
+
+     @property
+     def client(self):
+         if self._client is None:
+             with self._get_client_lock:
+                 if self._client is None:
+                     self._client = Config.client_for_model(self.model)
+         return self._client
+
+     def get_text(
+         self,
+         messages: list[dict],
+         temperature=1,
+         max_tokens=None,
+         max_completion_tokens=None,
+         **kwargs,
+     ) -> str:
+         """Just a simple text request. Might get more arguments later."""
+         args = {
+             "client": self.client,
+             "model": self.model,
+             "messages": messages,
+             "temperature": temperature,
+             "timeout": Config.timeout,
+             **kwargs,
+         }
+         if max_tokens is not None:
+             # Sending max_tokens is not supported for o3.
+             args["max_tokens"] = max_tokens
+
+         if max_completion_tokens is not None:
+             args["max_completion_tokens"] = max_completion_tokens
+
+         completion = openai_chat_completion(**args)
+         try:
+             return completion.choices[0].message.content
+         except Exception:
+             print(completion)
+             raise
+
+     def single_token_probs(
+         self,
+         messages: list[dict],
+         top_logprobs: int = 20,
+         num_samples: int = 1,
+         convert_to_probs: bool = True,
+         **kwargs,
+     ) -> dict:
+         probs = {}
+         for _ in range(num_samples):
+             new_probs = self.single_token_probs_one_sample(messages, top_logprobs, convert_to_probs, **kwargs)
+             for key, value in new_probs.items():
+                 probs[key] = probs.get(key, 0) + value
+         result = {key: value / num_samples for key, value in probs.items()}
+         result = dict(sorted(result.items(), key=lambda x: x[1], reverse=True))
+         return result
+
+     def single_token_probs_one_sample(
+         self,
+         messages: list[dict],
+         top_logprobs: int = 20,
+         convert_to_probs: bool = True,
+         **kwargs,
+     ) -> dict:
+         """Returns probabilities of the next token. Always samples 1 token."""
+         completion = openai_chat_completion(
+             client=self.client,
+             model=self.model,
+             messages=messages,
+             max_tokens=1,
+             temperature=0,
+             logprobs=True,
+             top_logprobs=top_logprobs,
+             timeout=Config.timeout,
+             **kwargs,
+         )
+
+         if completion.choices[0].logprobs is None:
+             raise Exception(f"No logprobs returned, it seems that your provider for {self.model} doesn't support that.")
+
+         try:
+             logprobs = completion.choices[0].logprobs.content[0].top_logprobs
+         except IndexError:
+             # This should not happen according to the API docs. But it sometimes does.
+             print(NO_LOGPROBS_WARNING.format(model=self.model, completion=completion))
+             return {}
+
+         result = {}
+         for el in logprobs:
+             result[el.token] = math.exp(el.logprob) if convert_to_probs else el.logprob
+
+         return result
+
+     def get_many(
+         self,
+         func,
+         kwargs_list,
+         *,
+         max_workers=None,
+         silent=False,
+         title=None,
+         executor=None,
+     ):
+         """Call FUNC with arguments from KWARGS_LIST in MAX_WORKERS parallel threads.
+
+         FUNC is get_text or single_token_probs. Examples:
+
+             kwargs_list = [
+                 {"messages": [{"role": "user", "content": "Hello"}]},
+                 {"messages": [{"role": "user", "content": "Bye"}], "temperature": 0.7},
+             ]
+             for in_, out in runner.get_many(runner.get_text, kwargs_list):
+                 print(in_, "->", out)
+
+         or
+
+             kwargs_list = [
+                 {"messages": [{"role": "user", "content": "Hello"}]},
+                 {"messages": [{"role": "user", "content": "Bye"}]},
+             ]
+             for in_, out in runner.get_many(runner.single_token_probs, kwargs_list):
+                 print(in_, "->", out)
+
+         (Any other callable passed as FUNC should also work.)
+
+         This function returns a generator that yields pairs (input, output),
+         where input is an element of KWARGS_LIST and output is the value returned by
+         FUNC for that input.
+
+         Dictionaries in KWARGS_LIST might include optional keys starting with an underscore.
+         They are ignored when calling FUNC, but are returned in the first element of the pair,
+         which is useful for passing additional information to be paired with the output later.
+
+         Other parameters:
+         - MAX_WORKERS: number of parallel threads, overrides Config.max_workers.
+         - SILENT: passed to tqdm.
+         - TITLE: passed to tqdm as desc.
+         - EXECUTOR: optional ThreadPoolExecutor instance, if you want many calls to get_many to run within
+           the same executor. MAX_WORKERS and Config.max_workers are then ignored.
+         """
+         if max_workers is None:
+             max_workers = Config.max_workers
+
+         executor_created = False
+         if executor is None:
+             executor = ThreadPoolExecutor(max_workers)
+             executor_created = True
+
+         def get_data(kwargs):
+             func_kwargs = {key: val for key, val in kwargs.items() if not key.startswith("_")}
+             try:
+                 result = func(**func_kwargs)
+             except NoClientForModel:
+                 raise
+             except Exception as e:
+                 # Truncate messages for readability
+                 messages = func_kwargs.get("messages", [])
+                 if messages:
+                     last_msg = str(messages[-1].get("content", ""))[:100]
+                     msg_info = f", last message: {last_msg!r}..."
+                 else:
+                     msg_info = ""
+                 warnings.warn(
+                     f"Unexpected error (probably API-related), runner returns None. "
+                     f"Model: {self.model}, function: {func.__name__}{msg_info}. "
+                     f"Error: {type(e).__name__}: {e}"
+                 )
+                 result = None
+             return kwargs, result
+
+         futures = [executor.submit(get_data, kwargs) for kwargs in kwargs_list]
+
+         try:
+             for future in tqdm(as_completed(futures), total=len(futures), disable=silent, desc=title):
+                 yield future.result()
+         except (Exception, KeyboardInterrupt):
+             for fut in futures:
+                 fut.cancel()
+             raise
+         finally:
+             if executor_created:
+                 executor.shutdown(wait=False)
+
+     def sample_probs(
+         self,
+         messages: list[dict],
+         *,
+         num_samples: int,
+         max_tokens: int,
+         temperature: float = 1,
+         **kwargs,
+     ) -> dict:
+         """Sample answers NUM_SAMPLES times. Returns probabilities of answers.
+
+         Works only if the API supports the `n` parameter.
+
+         Use cases:
+         * It should be faster and cheaper than get_many + get_text
+           (uses the `n` parameter so you don't pay for input tokens for each request separately).
+         * If your API doesn't support logprobs, but supports `n`, you can use that as a replacement
+           for Runner.single_token_probs.
+         """
+         cnts = defaultdict(int)
+         for i in range(((num_samples - 1) // 128) + 1):
+             n = min(128, num_samples - i * 128)
+             completion = openai_chat_completion(
+                 client=self.client,
+                 model=self.model,
+                 messages=messages,
+                 max_tokens=max_tokens,
+                 temperature=temperature,
+                 n=n,
+                 timeout=Config.timeout,
+                 **kwargs,
+             )
+             for choice in completion.choices:
+                 cnts[choice.message.content] += 1
+         if sum(cnts.values()) != num_samples:
+             raise Exception(
+                 f"Something weird happened. Expected {num_samples} samples, got {sum(cnts.values())}. Maybe n parameter is ignored for {self.model}?"
+             )
+         result = {key: val / num_samples for key, val in cnts.items()}
+         result = dict(sorted(result.items(), key=lambda x: x[1], reverse=True))
+         return result
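To tie the pieces together, a hedged end-to-end sketch of the Runner API above (not part of the package diff); the model name and prompts are placeholders, and Config is assumed to know how to build a client for that model:

    runner = Runner("gpt-4o-mini")  # placeholder model name

    # Parallel requests; keys starting with "_" are skipped when calling the function
    # but come back in the input half of each (input, output) pair.
    kwargs_list = [
        {"messages": [{"role": "user", "content": f"Is {n} prime? Answer yes or no."}], "_n": n}
        for n in range(2, 12)
    ]
    for in_, out in runner.get_many(runner.get_text, kwargs_list, title="primality"):
        print(in_["_n"], "->", out)

    # Empirical answer distribution: 300 samples are requested as batches of 128, 128 and 44
    # through the `n` parameter, then normalized to frequencies.
    probs = runner.sample_probs(
        [{"role": "user", "content": "Heads or tails?"}], num_samples=300, max_tokens=2
    )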
llmcomp/utils.py ADDED
@@ -0,0 +1,97 @@
+ """Utility functions for llmcomp."""
+
+ import json
+ from pathlib import Path
+ from typing import Any
+
+
+ def write_jsonl(path: str | Path, data: list[dict[str, Any]]) -> None:
+     """Write a list of dictionaries to a JSONL file.
+
+     Each dictionary is written as a JSON object on a separate line.
+
+     Args:
+         path: Path to the output JSONL file
+         data: List of dictionaries to write
+
+     Example:
+         >>> data = [{"name": "Alice", "age": 30}, {"name": "Bob", "age": 25}]
+         >>> write_jsonl("people.jsonl", data)
+     """
+     path = Path(path)
+     path.parent.mkdir(parents=True, exist_ok=True)
+
+     with open(path, "w", encoding="utf-8") as f:
+         for item in data:
+             f.write(json.dumps(item) + "\n")
+
+
+ def read_jsonl(path: str | Path) -> list[dict[str, Any]]:
+     """Read a JSONL file and return a list of dictionaries.
+
+     Each line is parsed as a JSON object.
+
+     Args:
+         path: Path to the input JSONL file
+
+     Returns:
+         List of dictionaries, one per line in the file
+
+     Example:
+         >>> data = read_jsonl("people.jsonl")
+         >>> print(data)
+         [{'name': 'Alice', 'age': 30}, {'name': 'Bob', 'age': 25}]
+     """
+     path = Path(path)
+     data = []
+
+     with open(path, "r", encoding="utf-8") as f:
+         for line in f:
+             line = line.strip()
+             if line:  # Skip empty lines
+                 data.append(json.loads(line))
+
+     return data
+
+
+ def get_error_bars(fraction_list, rng=None, alpha=0.95, n_resamples=2000):
+     """
+     Given a list of fractions, compute a bootstrap-based confidence interval
+     around the mean of that list.
+
+     Returns:
+         (center, lower_err, upper_err)
+     where:
+         - center = mean of fraction_list
+         - lower_err = center - lower_CI
+         - upper_err = upper_CI - center
+
+     So if you want to pass these to plt.errorbar as yerr:
+         yerr = [[lower_err], [upper_err]]
+     """
+     import numpy as np
+
+     if rng is None:
+         rng = np.random.default_rng(0)
+     fractions = np.array(fraction_list, dtype=float)
+
+     # Edge cases
+     if len(fractions) == 0:
+         return (0.0, 0.0, 0.0)
+     if len(fractions) == 1:
+         return (fractions[0], 0.0, 0.0)
+
+     boot_means = []
+     for _ in range(n_resamples):
+         sample = rng.choice(fractions, size=len(fractions), replace=True)
+         boot_means.append(np.mean(sample))
+     boot_means = np.array(boot_means)
+
+     lower_bound = np.percentile(boot_means, (1 - alpha) / 2 * 100)
+     upper_bound = np.percentile(boot_means, (1 - (1 - alpha) / 2) * 100)
+     center = np.mean(fractions)
+
+     lower_err = center - lower_bound
+     upper_err = upper_bound - center
+
+     return (center, lower_err, upper_err)
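To make the return convention concrete, a small sketch (not part of the package diff) that feeds get_error_bars into matplotlib, as the docstring suggests; the fractions are made-up numbers:

    import matplotlib.pyplot as plt

    fractions = [0.62, 0.58, 0.71, 0.66, 0.60]   # e.g. per-paraphrase success rates
    center, lower_err, upper_err = get_error_bars(fractions)

    plt.errorbar([0], [center], yerr=[[lower_err], [upper_err]], fmt="o", capsize=4)
    plt.ylabel("fraction")
    plt.show()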