llmcomp 1.0.0__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff compares two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
@@ -1,7 +1,5 @@
  from __future__ import annotations

- import hashlib
- import json
  import os
  import warnings
  from abc import ABC, abstractmethod
@@ -29,10 +27,6 @@ if TYPE_CHECKING:


  class Question(ABC):
- # Purpose of _version: it is used in the hash function so if some important part of the implementation changes,
- # we can change the version here and it'll invalidate all the cached results.
- _version = 1
-
  def __init__(
  self,
  name: str | None = "__unnamed",
@@ -315,9 +309,9 @@ class Question(ABC):
  in_, out = payload
  data = results[models.index(model)]
  data[in_["_original_ix"]] = {
- # Deepcopy because in_["messages"] is reused for multiple models and we don't want weird
- # side effects if someone later edits the messages in the resulting dataframe
- "messages": deepcopy(in_["messages"]),
+ # Deepcopy because in_["params"]["messages"] is reused for multiple models
+ # and we don't want weird side effects if someone later edits the messages
+ "messages": deepcopy(in_["params"]["messages"]),
  "question": in_["_question"],
  "answer": out,
  "paraphrase_ix": in_["_paraphrase_ix"],
@@ -343,9 +337,11 @@ class Question(ABC):
  messages_set = self.as_messages()
  runner_input = []
  for paraphrase_ix, messages in enumerate(messages_set):
+ params = {"messages": messages}
+ if self.logit_bias is not None:
+ params["logit_bias"] = self.logit_bias
  this_input = {
- "messages": messages,
- "logit_bias": self.logit_bias,
+ "params": params,
  "_question": messages[-1]["content"],
  "_paraphrase_ix": paraphrase_ix,
  }
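After this hunk, each element returned by get_runner_input() carries the raw API parameters under a single "params" key, while underscore-prefixed keys remain runner-side metadata that is only echoed back with the result. A rough sketch of one element's shape (the message content and logit_bias value are invented for illustration):

# Illustrative only - one runner_input element after this refactor (values invented):
example_element = {
    "params": {
        "messages": [{"role": "user", "content": "2 + 2 = ?"}],
        "logit_bias": {"1234": 100},  # present only when the question sets logit_bias
        # FreeForm also puts "temperature"/"max_tokens" here; Rating and NextToken
        # put "top_logprobs" here (see the get_runner_input hunks below).
    },
    "_question": "2 + 2 = ?",        # underscore keys are metadata, never sent to the API
    "_paraphrase_ix": 0,
}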
@@ -371,21 +367,6 @@ class Question(ABC):
  messages_set.append(messages)
  return messages_set

- ###########################################################################
- # OTHER STUFF
- def hash(self):
- """Unique identifier for caching. Changes when question parameters change.
-
- Used to determine whether we can use cached results.
- Excludes judges since they don't affect the raw LLM answers.
- """
- excluded = {"judges"}
- attributes = {k: v for k, v in self.__dict__.items() if k not in excluded}
- attributes["_version"] = self._version
- json_str = json.dumps(attributes, sort_keys=True)
- return hashlib.sha256(json_str.encode()).hexdigest()
-
-
  class FreeForm(Question):
  """Question type for free-form text generation.

@@ -440,8 +421,8 @@ class FreeForm(Question):
  def get_runner_input(self) -> list[dict]:
  runner_input = super().get_runner_input()
  for el in runner_input:
- el["temperature"] = self.temperature
- el["max_tokens"] = self.max_tokens
+ el["params"]["temperature"] = self.temperature
+ el["params"]["max_tokens"] = self.max_tokens
  return runner_input

  def df(self, model_groups: dict[str, list[str]]) -> pd.DataFrame:
@@ -745,7 +726,7 @@ class Rating(Question):
  def get_runner_input(self) -> list[dict]:
  runner_input = super().get_runner_input()
  for el in runner_input:
- el["top_logprobs"] = self.top_logprobs
+ el["params"]["top_logprobs"] = self.top_logprobs
  return runner_input

  def df(self, model_groups: dict[str, list[str]]) -> pd.DataFrame:
@@ -899,9 +880,8 @@ class NextToken(Question):

  def get_runner_input(self) -> list[dict]:
  runner_input = super().get_runner_input()
-
  for el in runner_input:
- el["top_logprobs"] = self.top_logprobs
+ el["params"]["top_logprobs"] = self.top_logprobs
  el["convert_to_probs"] = self.convert_to_probs
  el["num_samples"] = self.num_samples
  return runner_input
@@ -1,3 +1,4 @@
+ import hashlib
  import json
  import os
  from dataclasses import dataclass
@@ -5,10 +6,61 @@ from datetime import datetime
  from typing import TYPE_CHECKING, Any

  from llmcomp.config import Config
+ from llmcomp.runner.model_adapter import ModelAdapter

  if TYPE_CHECKING:
  from llmcomp.question.question import Question

+ # Bump this to invalidate all cached results when the caching implementation changes.
+ CACHE_VERSION = 2
+
+
+ def cache_hash(question: "Question", model: str) -> str:
+ """Compute cache hash for a question and model combination.
+
+ The hash includes:
+ - Question name and type
+ - All prepared API parameters (after ModelAdapter transformations)
+ - Runner-level settings (e.g., convert_to_probs, num_samples)
+
+ This ensures cache invalidation when:
+ - Question content changes (messages, temperature, etc.)
+ - Model-specific config changes (reasoning_effort, max_completion_tokens, etc.)
+ - Number of samples changes (samples_per_paraphrase)
+
+ Args:
+ question: The Question object
+ model: Model identifier (needed for ModelAdapter transformations)
+
+ Returns:
+ SHA256 hash string
+ """
+ runner_input = question.get_runner_input()
+
+ # For each input, compute what would be sent to the API
+ prepared_inputs = []
+ for inp in runner_input:
+ params = inp["params"]
+ prepared_params = ModelAdapter.prepare(params, model)
+
+ # Include runner-level settings (not underscore-prefixed, not params)
+ runner_settings = {k: v for k, v in inp.items() if not k.startswith("_") and k != "params"}
+
+ prepared_inputs.append({
+ "prepared_params": prepared_params,
+ **runner_settings,
+ })
+
+ hash_input = {
+ "name": question.name,
+ "type": question.type(),
+ "inputs": prepared_inputs,
+ "_version": CACHE_VERSION,
+ }
+
+ json_str = json.dumps(hash_input, sort_keys=True)
+ return hashlib.sha256(json_str.encode()).hexdigest()
+

  @dataclass
  class Result:
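For concreteness, here is a sketch of the dictionary that cache_hash serializes and hashes for a hypothetical single-paraphrase question. All values below are invented, and the exact string returned by question.type() is an assumption:

# Illustrative only - the structure passed to json.dumps(..., sort_keys=True) and SHA256-hashed.
hash_input = {
    "name": "example_question",              # hypothetical question name
    "type": "free_form",                     # whatever question.type() returns (assumption)
    "inputs": [
        {
            # params after ModelAdapter.prepare(), so model-specific rewrites
            # (e.g. max_tokens -> max_completion_tokens) change the hash too
            "prepared_params": {
                "messages": [{"role": "user", "content": "Name a prime number."}],
                "temperature": 1,
                "max_tokens": 50,
            },
            # runner-level settings such as num_samples / convert_to_probs would sit here
        },
    ],
    "_version": 2,                           # CACHE_VERSION
}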
@@ -25,7 +77,7 @@ class Result:

  @classmethod
  def file_path(cls, question: "Question", model: str) -> str:
- return f"{Config.cache_dir}/question/{question.name}/{question.hash()[:7]}/{model}.jsonl"
+ return f"{Config.cache_dir}/question/{question.name}/{cache_hash(question, model)[:7]}.jsonl"

  def save(self):
  path = self.file_path(self.question, self.model)
@@ -50,7 +102,7 @@ class Result:
  metadata = json.loads(lines[0])

  # Hash collision on 7-character prefix - extremely rare
- if metadata["hash"] != question.hash():
+ if metadata["hash"] != cache_hash(question, model):
  os.remove(path)
  print(f"Rare hash collision detected for {question.name}/{model}. Cached result removed.")
  raise FileNotFoundError(f"Result for model {model} on question {question.name} not found in {path}")
@@ -63,7 +115,7 @@ class Result:
  "name": self.question.name,
  "model": self.model,
  "last_update": datetime.now().isoformat(),
- "hash": self.question.hash(),
+ "hash": cache_hash(self.question, self.model),
  }


@@ -101,7 +153,7 @@ class JudgeCache:

  @classmethod
  def file_path(cls, judge: "Question") -> str:
- return f"{Config.cache_dir}/judge/{judge.name}/{judge.hash()[:7]}.json"
+ return f"{Config.cache_dir}/judge/{judge.name}/{cache_hash(judge, judge.model)[:7]}.json"

  def _load(self) -> dict[str | None, dict[str, Any]]:
  """Load cache from disk, or return empty dict if not exists."""
@@ -120,7 +172,7 @@ class JudgeCache:
  metadata = file_data["metadata"]

  # Hash collision on 7-character prefix - extremely rare
- if metadata["hash"] != self.judge.hash():
+ if metadata["hash"] != cache_hash(self.judge, self.judge.model):
  os.remove(path)
  print(f"Rare hash collision detected for judge {self.judge.name}. Cached result removed.")
  self._data = {}
@@ -155,7 +207,7 @@ class JudgeCache:
  "name": self.judge.name,
  "model": self.judge.model,
  "last_update": datetime.now().isoformat(),
- "hash": self.judge.hash(),
+ "hash": cache_hash(self.judge, self.judge.model),
  "prompt": self.judge.paraphrases[0],
  "uses_question": self.judge.uses_question,
  }
@@ -22,12 +22,4 @@ def on_backoff(details):
  on_backoff=on_backoff,
  )
  def openai_chat_completion(*, client, **kwargs):
- if kwargs["model"].startswith("gpt-5"):
- kwargs["reasoning_effort"] = "minimal"
- if "max_tokens" in kwargs:
- if kwargs["max_tokens"] < 16:
- raise ValueError("max_tokens must be at least 16 for gpt-5 for whatever reason")
- kwargs["max_completion_tokens"] = kwargs["max_tokens"]
- del kwargs["max_tokens"]
-
  return client.chat.completions.create(**kwargs)
@@ -0,0 +1,98 @@
+
+ from typing import Callable
+
+ ModelSelector = Callable[[str], bool]
+ PrepareFunction = Callable[[dict, str], dict]
+
+
+ class ModelAdapter:
+ """Adapts API request params for specific models.
+
+ Handlers can be registered to transform params for specific models.
+ All matching handlers are applied in registration order.
+ """
+
+ _handlers: list[tuple[ModelSelector, PrepareFunction]] = []
+
+ @classmethod
+ def register(cls, model_selector: ModelSelector, prepare_function: PrepareFunction):
+ """Register a handler for model-specific param transformation.
+
+ Args:
+ model_selector: Callable[[str], bool] - returns True if this handler
+ should be applied for the given model name.
+ prepare_function: Callable[[dict, str], dict] - transforms params.
+ Receives (params, model) and returns transformed params.
+
+ Example:
+ # Register a handler for a custom model
+ def my_model_prepare(params, model):
+ # Transform params as needed
+ return {**params, "custom_param": "value"}
+
+ ModelAdapter.register(
+ lambda model: model == "my-model",
+ my_model_prepare
+ )
+ """
+ cls._handlers.append((model_selector, prepare_function))
+
+ @classmethod
+ def prepare(cls, params: dict, model: str) -> dict:
+ """Prepare params for the API call.
+
+ Applies all registered handlers whose model_selector returns True.
+ Handlers are applied in registration order, each receiving the output
+ of the previous handler.
+
+ Args:
+ params: The params to transform.
+ model: The model name.
+
+ Returns:
+ Transformed params ready for the API call.
+ """
+ result = params
+ for model_selector, prepare_function in cls._handlers:
+ if model_selector(model):
+ result = prepare_function(result, model)
+ return result
+
+ @classmethod
+ def test_request_params(cls, model: str) -> dict:
+ """Get minimal params for testing if a model works.
+
+ Returns params for a minimal API request to verify connectivity.
+ Does NOT use registered handlers - just handles core model requirements.
+
+ Args:
+ model: The model name.
+
+ Returns:
+ Dict with model, messages, and appropriate token limit params.
+ """
+ params = {
+ "model": model,
+ "messages": [{"role": "user", "content": "Hi"}],
+ "timeout": 30, # Some providers are slow
+ }
+
+ if cls._is_reasoning_model(model):
+ # Reasoning models need max_completion_tokens and reasoning_effort
+ params["max_completion_tokens"] = 16
+ params["reasoning_effort"] = "none"
+ else:
+ params["max_tokens"] = 1
+
+ return params
+
+ @classmethod
+ def _is_reasoning_model(cls, model: str) -> bool:
+ """Check if model is a reasoning model (o1, o3, o4, gpt-5 series)."""
+ return (
+ model.startswith("o1")
+ or model.startswith("o3")
+ or model.startswith("o4")
+ or model.startswith("gpt-5")
+ )
+
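The gpt-5 special-casing removed from openai_chat_completion above now has a natural home behind this hook. The diff does not show where llmcomp registers its built-in handlers, so the following is only a sketch of how that removed logic could be expressed with ModelAdapter.register:

# Sketch only: re-expressing the removed gpt-5 handling as a registered handler.
from llmcomp.runner.model_adapter import ModelAdapter


def _gpt5_prepare(params: dict, model: str) -> dict:
    params = dict(params)  # work on a copy instead of mutating the caller's dict
    params["reasoning_effort"] = "minimal"
    if "max_tokens" in params:
        if params["max_tokens"] < 16:
            raise ValueError("max_tokens must be at least 16 for gpt-5")
        # gpt-5 expects max_completion_tokens rather than max_tokens
        params["max_completion_tokens"] = params.pop("max_tokens")
    return params


ModelAdapter.register(lambda model: model.startswith("gpt-5"), _gpt5_prepare)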
llmcomp/runner/runner.py CHANGED
@@ -8,6 +8,7 @@ from tqdm import tqdm

  from llmcomp.config import Config, NoClientForModel
  from llmcomp.runner.chat_completion import openai_chat_completion
+ from llmcomp.runner.model_adapter import ModelAdapter

  NO_LOGPROBS_WARNING = """\
  Failed to get logprobs because {model} didn't send them.
@@ -32,31 +33,26 @@ class Runner:
  self._client = Config.client_for_model(self.model)
  return self._client

- def get_text(
- self,
- messages: list[dict],
- temperature=1,
- max_tokens=None,
- max_completion_tokens=None,
- **kwargs,
- ) -> str:
- """Just a simple text request. Might get more arguments later."""
- args = {
- "client": self.client,
- "model": self.model,
- "messages": messages,
- "temperature": temperature,
- "timeout": Config.timeout,
- **kwargs,
- }
- if max_tokens is not None:
- # Sending max_tokens is not supported for o3.
- args["max_tokens"] = max_tokens
+ def _prepare_for_model(self, params: dict) -> dict:
+ """Prepare params for the API call via ModelAdapter.
+
+ Also adds timeout from Config. Timeout is added here (not in ModelAdapter)
+ because it doesn't affect API response content and shouldn't be part of the cache hash.
+
+ Note: timeout is set first so that ModelAdapter handlers can override it if needed.
+ """
+ prepared = ModelAdapter.prepare(params, self.model)
+ return {"timeout": Config.timeout, **prepared}

- if max_completion_tokens is not None:
- args["max_completion_tokens"] = max_completion_tokens
+ def get_text(self, params: dict) -> str:
+ """Get a text completion from the model.

- completion = openai_chat_completion(**args)
+ Args:
+ params: Dictionary of parameters for the API.
+ Must include 'messages'. Other common keys: 'temperature', 'max_tokens'.
+ """
+ prepared = self._prepare_for_model(params)
+ completion = openai_chat_completion(client=self.client, **prepared)
  try:
  return completion.choices[0].message.content
  except Exception:
@@ -65,15 +61,22 @@ class Runner:

  def single_token_probs(
  self,
- messages: list[dict],
- top_logprobs: int = 20,
+ params: dict,
+ *,
  num_samples: int = 1,
  convert_to_probs: bool = True,
- **kwargs,
  ) -> dict:
+ """Get probability distribution of the next token, optionally averaged over multiple samples.
+
+ Args:
+ params: Dictionary of parameters for the API.
+ Must include 'messages'. Other common keys: 'top_logprobs', 'logit_bias'.
+ num_samples: Number of samples to average over. Default: 1.
+ convert_to_probs: If True, convert logprobs to probabilities. Default: True.
+ """
  probs = {}
  for _ in range(num_samples):
- new_probs = self.single_token_probs_one_sample(messages, top_logprobs, convert_to_probs, **kwargs)
+ new_probs = self.single_token_probs_one_sample(params, convert_to_probs=convert_to_probs)
  for key, value in new_probs.items():
  probs[key] = probs.get(key, 0) + value
  result = {key: value / num_samples for key, value in probs.items()}
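With the params-dict convention, callers build one dictionary of API parameters and hand it to the runner; _prepare_for_model then applies the ModelAdapter handlers and adds Config.timeout before the request is sent. A brief usage sketch, assuming a Runner constructed for a single model name (the model name and message contents are placeholders):

# Illustrative usage of the params-based calling convention (placeholder values).
runner = Runner("gpt-4o-mini")  # assumption: Runner takes a model name

text = runner.get_text({
    "messages": [{"role": "user", "content": "Say hi."}],
    "temperature": 0,
    "max_tokens": 20,
})

probs = runner.single_token_probs(
    {"messages": [{"role": "user", "content": "Heads or tails?"}], "top_logprobs": 5},
    num_samples=2,
)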
@@ -82,23 +85,31 @@ class Runner:

  def single_token_probs_one_sample(
  self,
- messages: list[dict],
- top_logprobs: int = 20,
+ params: dict,
+ *,
  convert_to_probs: bool = True,
- **kwargs,
  ) -> dict:
- """Returns probabilities of the next token. Always samples 1 token."""
- completion = openai_chat_completion(
- client=self.client,
- model=self.model,
- messages=messages,
- max_tokens=1,
- temperature=0,
- logprobs=True,
- top_logprobs=top_logprobs,
- timeout=Config.timeout,
- **kwargs,
- )
+ """Get probability distribution of the next token (single sample).
+
+ Args:
+ params: Dictionary of parameters for the API.
+ Must include 'messages'. Other common keys: 'top_logprobs', 'logit_bias'.
+ convert_to_probs: If True, convert logprobs to probabilities. Default: True.
+
+ Note: This function forces max_tokens=1, temperature=0, logprobs=True.
+ """
+ # Build complete params with defaults and forced params
+ complete_params = {
+ # Default for top_logprobs, can be overridden by params:
+ "top_logprobs": 20,
+ **params,
+ # These are required for single_token_probs semantics (cannot be overridden):
+ "max_tokens": 1,
+ "temperature": 0,
+ "logprobs": True,
+ }
+ prepared = self._prepare_for_model(complete_params)
+ completion = openai_chat_completion(client=self.client, **prepared)

  if completion.choices[0].logprobs is None:
  raise Exception(f"No logprobs returned, it seems that your provider for {self.model} doesn't support that.")
@@ -131,8 +142,8 @@ class Runner:
  FUNC is get_text or single_token_probs. Examples:

  kwargs_list = [
- {"messages": [{"role": "user", "content": "Hello"}]},
- {"messages": [{"role": "user", "content": "Bye"}], "temperature": 0.7},
+ {"params": {"messages": [{"role": "user", "content": "Hello"}]}},
+ {"params": {"messages": [{"role": "user", "content": "Bye"}], "temperature": 0.7}},
  ]
  for in_, out in runner.get_many(runner.get_text, kwargs_list):
  print(in_, "->", out)
@@ -140,8 +151,8 @@ class Runner:
  or

  kwargs_list = [
- {"messages": [{"role": "user", "content": "Hello"}]},
- {"messages": [{"role": "user", "content": "Bye"}]},
+ {"params": {"messages": [{"role": "user", "content": "Hello"}]}},
+ {"params": {"messages": [{"role": "user", "content": "Bye"}]}},
  ]
  for in_, out in runner.get_many(runner.single_token_probs, kwargs_list):
  print(in_, "->", out)
@@ -149,10 +160,10 @@ class Runner:
  (FUNC that is a different callable should also work)

  This function returns a generator that yields pairs (input, output),
- where input is an element from KWARGS_SET and output is the thing returned by
+ where input is an element from KWARGS_LIST and output is the thing returned by
  FUNC for this input.

- Dictionaries in KWARGS_SET might include optional keys starting with underscore,
+ Dictionaries in KWARGS_LIST might include optional keys starting with underscore,
  they are just ignored, but they are returned in the first element of the pair, so that's useful
  for passing some additional information that will be later paired with the output.

@@ -179,7 +190,8 @@ class Runner:
  raise
  except Exception as e:
  # Truncate messages for readability
- messages = func_kwargs.get("messages", [])
+ params = func_kwargs.get("params", {})
+ messages = params.get("messages", [])
  if messages:
  last_msg = str(messages[-1].get("content", ""))[:100]
  msg_info = f", last message: {last_msg!r}..."
@@ -208,15 +220,17 @@ class Runner:

  def sample_probs(
  self,
- messages: list[dict],
+ params: dict,
  *,
  num_samples: int,
- max_tokens: int,
- temperature: float = 1,
- **kwargs,
  ) -> dict:
  """Sample answers NUM_SAMPLES times. Returns probabilities of answers.

+ Args:
+ params: Dictionary of parameters for the API.
+ Must include 'messages'. Other common keys: 'max_tokens', 'temperature'.
+ num_samples: Number of samples to collect.
+
  Works only if the API supports `n` parameter.

  Usecases:
@@ -228,16 +242,13 @@ class Runner:
  cnts = defaultdict(int)
  for i in range(((num_samples - 1) // 128) + 1):
  n = min(128, num_samples - i * 128)
- completion = openai_chat_completion(
- client=self.client,
- model=self.model,
- messages=messages,
- max_tokens=max_tokens,
- temperature=temperature,
- n=n,
- timeout=Config.timeout,
- **kwargs,
- )
+ # Build complete params with forced param
+ complete_params = {
+ **params,
+ "n": n,
+ }
+ prepared = self._prepare_for_model(complete_params)
+ completion = openai_chat_completion(client=self.client, **prepared)
  for choice in completion.choices:
  cnts[choice.message.content] += 1
  if sum(cnts.values()) != num_samples:
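The batching above is unchanged by this release: requests are issued in chunks of at most 128 completions via the n parameter. As a worked example of that arithmetic, num_samples=300 produces batch sizes 128, 128, and 44:

# Worked example of the batching arithmetic used by sample_probs.
num_samples = 300
batches = [min(128, num_samples - i * 128) for i in range(((num_samples - 1) // 128) + 1)]
assert batches == [128, 128, 44] and sum(batches) == num_samples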