llmcomp 1.0.0__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llmcomp/__init__.py +4 -0
- llmcomp/config.py +10 -15
- llmcomp/default_adapters.py +81 -0
- llmcomp/finetuning/__init__.py +2 -0
- llmcomp/finetuning/manager.py +473 -0
- llmcomp/finetuning/update_jobs.py +38 -0
- llmcomp/question/question.py +11 -31
- llmcomp/question/result.py +58 -6
- llmcomp/runner/chat_completion.py +0 -8
- llmcomp/runner/model_adapter.py +98 -0
- llmcomp/runner/runner.py +74 -63
- {llmcomp-1.0.0.dist-info → llmcomp-1.1.0.dist-info}/METADATA +85 -21
- llmcomp-1.1.0.dist-info/RECORD +19 -0
- llmcomp-1.1.0.dist-info/entry_points.txt +2 -0
- llmcomp-1.0.0.dist-info/RECORD +0 -13
- {llmcomp-1.0.0.dist-info → llmcomp-1.1.0.dist-info}/WHEEL +0 -0
- {llmcomp-1.0.0.dist-info → llmcomp-1.1.0.dist-info}/licenses/LICENSE +0 -0
llmcomp/question/question.py
CHANGED
@@ -1,7 +1,5 @@
 from __future__ import annotations

-import hashlib
-import json
 import os
 import warnings
 from abc import ABC, abstractmethod
@@ -29,10 +27,6 @@ if TYPE_CHECKING:


 class Question(ABC):
-    # Purpose of _version: it is used in the hash function so if some important part of the implementation changes,
-    # we can change the version here and it'll invalidate all the cached results.
-    _version = 1
-
     def __init__(
         self,
         name: str | None = "__unnamed",
@@ -315,9 +309,9 @@
             in_, out = payload
             data = results[models.index(model)]
             data[in_["_original_ix"]] = {
-                # Deepcopy because in_["messages"] is reused for multiple models
-                # side effects if someone later edits the messages
-                "messages": deepcopy(in_["messages"]),
+                # Deepcopy because in_["params"]["messages"] is reused for multiple models
+                # and we don't want weird side effects if someone later edits the messages
+                "messages": deepcopy(in_["params"]["messages"]),
                 "question": in_["_question"],
                 "answer": out,
                 "paraphrase_ix": in_["_paraphrase_ix"],
@@ -343,9 +337,11 @@
         messages_set = self.as_messages()
         runner_input = []
         for paraphrase_ix, messages in enumerate(messages_set):
+            params = {"messages": messages}
+            if self.logit_bias is not None:
+                params["logit_bias"] = self.logit_bias
             this_input = {
-                "
-                "logit_bias": self.logit_bias,
+                "params": params,
                 "_question": messages[-1]["content"],
                 "_paraphrase_ix": paraphrase_ix,
             }
@@ -371,21 +367,6 @@
         messages_set.append(messages)
         return messages_set

-    ###########################################################################
-    # OTHER STUFF
-    def hash(self):
-        """Unique identifier for caching. Changes when question parameters change.
-
-        Used to determine whether we can use cached results.
-        Excludes judges since they don't affect the raw LLM answers.
-        """
-        excluded = {"judges"}
-        attributes = {k: v for k, v in self.__dict__.items() if k not in excluded}
-        attributes["_version"] = self._version
-        json_str = json.dumps(attributes, sort_keys=True)
-        return hashlib.sha256(json_str.encode()).hexdigest()
-
-
 class FreeForm(Question):
     """Question type for free-form text generation.

@@ -440,8 +421,8 @@
     def get_runner_input(self) -> list[dict]:
         runner_input = super().get_runner_input()
         for el in runner_input:
-            el["temperature"] = self.temperature
-            el["max_tokens"] = self.max_tokens
+            el["params"]["temperature"] = self.temperature
+            el["params"]["max_tokens"] = self.max_tokens
         return runner_input

     def df(self, model_groups: dict[str, list[str]]) -> pd.DataFrame:
@@ -745,7 +726,7 @@ class Rating(Question):
     def get_runner_input(self) -> list[dict]:
         runner_input = super().get_runner_input()
         for el in runner_input:
-            el["top_logprobs"] = self.top_logprobs
+            el["params"]["top_logprobs"] = self.top_logprobs
         return runner_input

     def df(self, model_groups: dict[str, list[str]]) -> pd.DataFrame:
@@ -899,9 +880,8 @@ class NextToken(Question):

     def get_runner_input(self) -> list[dict]:
         runner_input = super().get_runner_input()
-
         for el in runner_input:
-            el["top_logprobs"] = self.top_logprobs
+            el["params"]["top_logprobs"] = self.top_logprobs
             el["convert_to_probs"] = self.convert_to_probs
             el["num_samples"] = self.num_samples
         return runner_input
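For orientation, here is a minimal sketch of the element shape that get_runner_input() now produces, as implied by the hunks above: API parameters are nested under "params", while underscore-prefixed keys stay runner-side metadata. The concrete values are illustrative, not taken from the package.

# Sketch (illustrative values) of one FreeForm runner-input element after this change:
element = {
    "params": {
        "messages": [{"role": "user", "content": "What is 2 + 2?"}],
        "temperature": 1.0,        # added by FreeForm.get_runner_input
        "max_tokens": 100,         # added by FreeForm.get_runner_input
        # "logit_bias": {...},     # only present when the question defines it
    },
    "_question": "What is 2 + 2?",  # last message content; ignored by the runner
    "_paraphrase_ix": 0,
}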
llmcomp/question/result.py
CHANGED
@@ -1,3 +1,4 @@
+import hashlib
 import json
 import os
 from dataclasses import dataclass
@@ -5,10 +6,61 @@ from datetime import datetime
 from typing import TYPE_CHECKING, Any

 from llmcomp.config import Config
+from llmcomp.runner.model_adapter import ModelAdapter

 if TYPE_CHECKING:
     from llmcomp.question.question import Question

+# Bump this to invalidate all cached results when the caching implementation changes.
+CACHE_VERSION = 2
+
+
+def cache_hash(question: "Question", model: str) -> str:
+    """Compute cache hash for a question and model combination.
+
+    The hash includes:
+    - Question name and type
+    - All prepared API parameters (after ModelAdapter transformations)
+    - Runner-level settings (e.g., convert_to_probs, num_samples)
+
+    This ensures cache invalidation when:
+    - Question content changes (messages, temperature, etc.)
+    - Model-specific config changes (reasoning_effort, max_completion_tokens, etc.)
+    - Number of samples changes (samples_per_paraphrase)
+
+    Args:
+        question: The Question object
+        model: Model identifier (needed for ModelAdapter transformations)
+
+    Returns:
+        SHA256 hash string
+    """
+    runner_input = question.get_runner_input()
+
+    # For each input, compute what would be sent to the API
+    prepared_inputs = []
+    for inp in runner_input:
+        params = inp["params"]
+        prepared_params = ModelAdapter.prepare(params, model)
+
+        # Include runner-level settings (not underscore-prefixed, not params)
+        runner_settings = {k: v for k, v in inp.items() if not k.startswith("_") and k != "params"}
+
+        prepared_inputs.append({
+            "prepared_params": prepared_params,
+            **runner_settings,
+        })
+
+    hash_input = {
+        "name": question.name,
+        "type": question.type(),
+        "inputs": prepared_inputs,
+        "_version": CACHE_VERSION,
+    }
+
+    json_str = json.dumps(hash_input, sort_keys=True)
+    return hashlib.sha256(json_str.encode()).hexdigest()
+

 @dataclass
 class Result:
@@ -25,7 +77,7 @@

     @classmethod
     def file_path(cls, question: "Question", model: str) -> str:
-        return f"{Config.cache_dir}/question/{question.name}/{question
+        return f"{Config.cache_dir}/question/{question.name}/{cache_hash(question, model)[:7]}.jsonl"

     def save(self):
         path = self.file_path(self.question, self.model)
@@ -50,7 +102,7 @@
         metadata = json.loads(lines[0])

         # Hash collision on 7-character prefix - extremely rare
-        if metadata["hash"] != question
+        if metadata["hash"] != cache_hash(question, model):
             os.remove(path)
             print(f"Rare hash collision detected for {question.name}/{model}. Cached result removed.")
             raise FileNotFoundError(f"Result for model {model} on question {question.name} not found in {path}")
@@ -63,7 +115,7 @@
             "name": self.question.name,
             "model": self.model,
             "last_update": datetime.now().isoformat(),
-            "hash": self.question.
+            "hash": cache_hash(self.question, self.model),
         }


@@ -101,7 +153,7 @@

     @classmethod
     def file_path(cls, judge: "Question") -> str:
-        return f"{Config.cache_dir}/judge/{judge.name}/{judge.
+        return f"{Config.cache_dir}/judge/{judge.name}/{cache_hash(judge, judge.model)[:7]}.json"

     def _load(self) -> dict[str | None, dict[str, Any]]:
         """Load cache from disk, or return empty dict if not exists."""
@@ -120,7 +172,7 @@
         metadata = file_data["metadata"]

         # Hash collision on 7-character prefix - extremely rare
-        if metadata["hash"] != self.judge.
+        if metadata["hash"] != cache_hash(self.judge, self.judge.model):
             os.remove(path)
             print(f"Rare hash collision detected for judge {self.judge.name}. Cached result removed.")
             self._data = {}
@@ -155,7 +207,7 @@
             "name": self.judge.name,
             "model": self.judge.model,
             "last_update": datetime.now().isoformat(),
-            "hash": self.judge.
+            "hash": cache_hash(self.judge, self.judge.model),
             "prompt": self.judge.paraphrases[0],
             "uses_question": self.judge.uses_question,
         }
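A short usage sketch of the new cache_hash: the cache file name and the stored metadata hash both come from cache_hash(question, model), so the hash now varies with the model (via ModelAdapter) and with runner-level settings. The question object and model names below are hypothetical.

# Sketch (hypothetical question object and model names):
from llmcomp.question.result import cache_hash

h_a = cache_hash(question, "model-a")
h_b = cache_hash(question, "model-b")
# If a ModelAdapter handler rewrites params differently for model-b
# (e.g. max_tokens -> max_completion_tokens), then h_a != h_b and the two
# models get separate cache files:
#   {Config.cache_dir}/question/{question.name}/{h_a[:7]}.jsonl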
llmcomp/runner/chat_completion.py
CHANGED
@@ -22,12 +22,4 @@ def on_backoff(details):
     on_backoff=on_backoff,
 )
 def openai_chat_completion(*, client, **kwargs):
-    if kwargs["model"].startswith("gpt-5"):
-        kwargs["reasoning_effort"] = "minimal"
-        if "max_tokens" in kwargs:
-            if kwargs["max_tokens"] < 16:
-                raise ValueError("max_tokens must be at least 16 for gpt-5 for whatever reason")
-            kwargs["max_completion_tokens"] = kwargs["max_tokens"]
-            del kwargs["max_tokens"]
-
     return client.chat.completions.create(**kwargs)
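The gpt-5 special-casing deleted above is presumably re-expressed as a registered adapter (the new llmcomp/default_adapters.py in the summary is the likely home, though its contents are not shown here). A hedged sketch of what such a handler could look like, based only on the removed logic and the ModelAdapter API introduced below:

# Sketch only -- the real default_adapters.py is not shown in this diff,
# so the handler name and details are assumptions.
from llmcomp.runner.model_adapter import ModelAdapter

def _gpt5_prepare(params: dict, model: str) -> dict:
    params = dict(params)
    params.setdefault("reasoning_effort", "minimal")
    if "max_tokens" in params:
        if params["max_tokens"] < 16:
            raise ValueError("max_tokens must be at least 16 for gpt-5")
        params["max_completion_tokens"] = params.pop("max_tokens")
    return params

ModelAdapter.register(lambda model: model.startswith("gpt-5"), _gpt5_prepare)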
llmcomp/runner/model_adapter.py
ADDED
@@ -0,0 +1,98 @@
+
+from typing import Callable
+
+ModelSelector = Callable[[str], bool]
+PrepareFunction = Callable[[dict, str], dict]
+
+
+class ModelAdapter:
+    """Adapts API request params for specific models.
+
+    Handlers can be registered to transform params for specific models.
+    All matching handlers are applied in registration order.
+    """
+
+    _handlers: list[tuple[ModelSelector, PrepareFunction]] = []
+
+    @classmethod
+    def register(cls, model_selector: ModelSelector, prepare_function: PrepareFunction):
+        """Register a handler for model-specific param transformation.
+
+        Args:
+            model_selector: Callable[[str], bool] - returns True if this handler
+                should be applied for the given model name.
+            prepare_function: Callable[[dict, str], dict] - transforms params.
+                Receives (params, model) and returns transformed params.
+
+        Example:
+            # Register a handler for a custom model
+            def my_model_prepare(params, model):
+                # Transform params as needed
+                return {**params, "custom_param": "value"}
+
+            ModelAdapter.register(
+                lambda model: model == "my-model",
+                my_model_prepare
+            )
+        """
+        cls._handlers.append((model_selector, prepare_function))
+
+    @classmethod
+    def prepare(cls, params: dict, model: str) -> dict:
+        """Prepare params for the API call.
+
+        Applies all registered handlers whose model_selector returns True.
+        Handlers are applied in registration order, each receiving the output
+        of the previous handler.
+
+        Args:
+            params: The params to transform.
+            model: The model name.
+
+        Returns:
+            Transformed params ready for the API call.
+        """
+        result = params
+        for model_selector, prepare_function in cls._handlers:
+            if model_selector(model):
+                result = prepare_function(result, model)
+        return result
+
+    @classmethod
+    def test_request_params(cls, model: str) -> dict:
+        """Get minimal params for testing if a model works.
+
+        Returns params for a minimal API request to verify connectivity.
+        Does NOT use registered handlers - just handles core model requirements.
+
+        Args:
+            model: The model name.
+
+        Returns:
+            Dict with model, messages, and appropriate token limit params.
+        """
+        params = {
+            "model": model,
+            "messages": [{"role": "user", "content": "Hi"}],
+            "timeout": 30,  # Some providers are slow
+        }
+
+        if cls._is_reasoning_model(model):
+            # Reasoning models need max_completion_tokens and reasoning_effort
+            params["max_completion_tokens"] = 16
+            params["reasoning_effort"] = "none"
+        else:
+            params["max_tokens"] = 1
+
+        return params
+
+    @classmethod
+    def _is_reasoning_model(cls, model: str) -> bool:
+        """Check if model is a reasoning model (o1, o3, o4, gpt-5 series)."""
+        return (
+            model.startswith("o1")
+            or model.startswith("o3")
+            or model.startswith("o4")
+            or model.startswith("gpt-5")
+        )
+
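A small usage sketch of the chaining behaviour documented in prepare(): every handler whose selector matches runs in registration order, each receiving the previous handler's output. The provider and model names are made up; note that a handler may also override the timeout the Runner injects (see runner.py below).

# Sketch (made-up model names): two matching handlers applied in order.
ModelAdapter.register(
    lambda m: m.startswith("my-provider/"),
    lambda params, m: {**params, "temperature": 0},
)
ModelAdapter.register(
    lambda m: m == "my-provider/slow-model",
    lambda params, m: {**params, "timeout": 120},
)

prepared = ModelAdapter.prepare(
    {"messages": [{"role": "user", "content": "Hi"}]},
    "my-provider/slow-model",
)
# -> {"messages": [...], "temperature": 0, "timeout": 120}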
llmcomp/runner/runner.py
CHANGED
@@ -8,6 +8,7 @@ from tqdm import tqdm

 from llmcomp.config import Config, NoClientForModel
 from llmcomp.runner.chat_completion import openai_chat_completion
+from llmcomp.runner.model_adapter import ModelAdapter

 NO_LOGPROBS_WARNING = """\
 Failed to get logprobs because {model} didn't send them.
@@ -32,31 +33,26 @@ class Runner:
             self._client = Config.client_for_model(self.model)
         return self._client

-    def
-
-
-
-
-
-
-
-
-            "client": self.client,
-            "model": self.model,
-            "messages": messages,
-            "temperature": temperature,
-            "timeout": Config.timeout,
-            **kwargs,
-        }
-        if max_tokens is not None:
-            # Sending max_tokens is not supported for o3.
-            args["max_tokens"] = max_tokens
+    def _prepare_for_model(self, params: dict) -> dict:
+        """Prepare params for the API call via ModelAdapter.
+
+        Also adds timeout from Config. Timeout is added here (not in ModelAdapter)
+        because it doesn't affect API response content and shouldn't be part of the cache hash.
+
+        Note: timeout is set first so that ModelAdapter handlers can override it if needed.
+        """
+        prepared = ModelAdapter.prepare(params, self.model)
+        return {"timeout": Config.timeout, **prepared}

-
-
+    def get_text(self, params: dict) -> str:
+        """Get a text completion from the model.

-
+        Args:
+            params: Dictionary of parameters for the API.
+                Must include 'messages'. Other common keys: 'temperature', 'max_tokens'.
+        """
+        prepared = self._prepare_for_model(params)
+        completion = openai_chat_completion(client=self.client, **prepared)
         try:
             return completion.choices[0].message.content
         except Exception:
@@ -65,15 +61,22 @@

     def single_token_probs(
         self,
-
-
+        params: dict,
+        *,
         num_samples: int = 1,
         convert_to_probs: bool = True,
-        **kwargs,
     ) -> dict:
+        """Get probability distribution of the next token, optionally averaged over multiple samples.
+
+        Args:
+            params: Dictionary of parameters for the API.
+                Must include 'messages'. Other common keys: 'top_logprobs', 'logit_bias'.
+            num_samples: Number of samples to average over. Default: 1.
+            convert_to_probs: If True, convert logprobs to probabilities. Default: True.
+        """
         probs = {}
         for _ in range(num_samples):
-            new_probs = self.single_token_probs_one_sample(
+            new_probs = self.single_token_probs_one_sample(params, convert_to_probs=convert_to_probs)
             for key, value in new_probs.items():
                 probs[key] = probs.get(key, 0) + value
         result = {key: value / num_samples for key, value in probs.items()}
@@ -82,23 +85,31 @@

     def single_token_probs_one_sample(
         self,
-
-
+        params: dict,
+        *,
         convert_to_probs: bool = True,
-        **kwargs,
     ) -> dict:
-        """
-
-
-
-
-
-
-
-
-
-
+        """Get probability distribution of the next token (single sample).
+
+        Args:
+            params: Dictionary of parameters for the API.
+                Must include 'messages'. Other common keys: 'top_logprobs', 'logit_bias'.
+            convert_to_probs: If True, convert logprobs to probabilities. Default: True.
+
+        Note: This function forces max_tokens=1, temperature=0, logprobs=True.
+        """
+        # Build complete params with defaults and forced params
+        complete_params = {
+            # Default for top_logprobs, can be overridden by params:
+            "top_logprobs": 20,
+            **params,
+            # These are required for single_token_probs semantics (cannot be overridden):
+            "max_tokens": 1,
+            "temperature": 0,
+            "logprobs": True,
+        }
+        prepared = self._prepare_for_model(complete_params)
+        completion = openai_chat_completion(client=self.client, **prepared)

         if completion.choices[0].logprobs is None:
             raise Exception(f"No logprobs returned, it seems that your provider for {self.model} doesn't support that.")
@@ -131,8 +142,8 @@
         FUNC is get_text or single_token_probs. Examples:

            kwargs_list = [
-                {"messages": [{"role": "user", "content": "Hello"}]},
-                {"messages": [{"role": "user", "content": "Bye"}], "temperature": 0.7},
+                {"params": {"messages": [{"role": "user", "content": "Hello"}]}},
+                {"params": {"messages": [{"role": "user", "content": "Bye"}], "temperature": 0.7}},
            ]
            for in_, out in runner.get_many(runner.get_text, kwargs_list):
                print(in_, "->", out)
@@ -140,8 +151,8 @@
        or

            kwargs_list = [
-                {"messages": [{"role": "user", "content": "Hello"}]},
-                {"messages": [{"role": "user", "content": "Bye"}]},
+                {"params": {"messages": [{"role": "user", "content": "Hello"}]}},
+                {"params": {"messages": [{"role": "user", "content": "Bye"}]}},
            ]
            for in_, out in runner.get_many(runner.single_token_probs, kwargs_list):
                print(in_, "->", out)
@@ -149,10 +160,10 @@
        (FUNC that is a different callable should also work)

        This function returns a generator that yields pairs (input, output),
-        where input is an element from
+        where input is an element from KWARGS_LIST and output is the thing returned by
        FUNC for this input.

-        Dictionaries in
+        Dictionaries in KWARGS_LIST might include optional keys starting with underscore,
        they are just ignored, but they are returned in the first element of the pair, so that's useful
        for passing some additional information that will be later paired with the output.
@@ -179,7 +190,8 @@
                 raise
             except Exception as e:
                 # Truncate messages for readability
-
+                params = func_kwargs.get("params", {})
+                messages = params.get("messages", [])
                 if messages:
                     last_msg = str(messages[-1].get("content", ""))[:100]
                     msg_info = f", last message: {last_msg!r}..."
@@ -208,15 +220,17 @@

     def sample_probs(
         self,
-
+        params: dict,
         *,
         num_samples: int,
-        max_tokens: int,
-        temperature: float = 1,
-        **kwargs,
     ) -> dict:
         """Sample answers NUM_SAMPLES times. Returns probabilities of answers.

+        Args:
+            params: Dictionary of parameters for the API.
+                Must include 'messages'. Other common keys: 'max_tokens', 'temperature'.
+            num_samples: Number of samples to collect.
+
        Works only if the API supports `n` parameter.

        Usecases:
@@ -228,16 +242,13 @@
         cnts = defaultdict(int)
         for i in range(((num_samples - 1) // 128) + 1):
             n = min(128, num_samples - i * 128)
-
-
-
-
-
-
-
-                timeout=Config.timeout,
-                **kwargs,
-            )
+            # Build complete params with forced param
+            complete_params = {
+                **params,
+                "n": n,
+            }
+            prepared = self._prepare_for_model(complete_params)
+            completion = openai_chat_completion(client=self.client, **prepared)
             for choice in completion.choices:
                 cnts[choice.message.content] += 1
             if sum(cnts.values()) != num_samples:
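Taken together, the runner now expects a single params dict per call; timeout injection and model-specific rewriting happen inside _prepare_for_model. A hedged sketch of the new calling convention (the Runner constructor signature and the model name are assumptions, not shown in this diff):

# Sketch -- constructor signature and model name are assumptions.
runner = Runner("some-model")

text = runner.get_text({
    "messages": [{"role": "user", "content": "Say hello"}],
    "temperature": 0.7,
    "max_tokens": 50,
})

probs = runner.single_token_probs(
    {"messages": [{"role": "user", "content": "Heads or tails?"}]},
    num_samples=5,
)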