eval-framework 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- eval_framework/__init__.py +7 -0
- eval_framework/base_config.py +36 -0
- eval_framework/context/__init__.py +0 -0
- eval_framework/context/determined.py +170 -0
- eval_framework/context/eval.py +114 -0
- eval_framework/context/local.py +52 -0
- eval_framework/evaluation_generator.py +231 -0
- eval_framework/exceptions.py +2 -0
- eval_framework/external/ifeval_impl/README.md +5 -0
- eval_framework/external/ifeval_impl/instructions.py +1523 -0
- eval_framework/external/ifeval_impl/instructions_registry.py +161 -0
- eval_framework/external/ifeval_impl/instructions_util.py +1689 -0
- eval_framework/external/ifeval_impl/utils.py +135 -0
- eval_framework/llm/__init__.py +0 -0
- eval_framework/llm/aleph_alpha.py +323 -0
- eval_framework/llm/base.py +58 -0
- eval_framework/llm/huggingface.py +332 -0
- eval_framework/llm/mistral.py +73 -0
- eval_framework/llm/models.py +16 -0
- eval_framework/llm/openai.py +205 -0
- eval_framework/llm/vllm.py +438 -0
- eval_framework/logger.py +3 -0
- eval_framework/main.py +187 -0
- eval_framework/metrics/__init__.py +0 -0
- eval_framework/metrics/base.py +40 -0
- eval_framework/metrics/completion/__init__.py +1 -0
- eval_framework/metrics/completion/accuracy_completion.py +16 -0
- eval_framework/metrics/completion/bleu.py +76 -0
- eval_framework/metrics/completion/chrf.py +62 -0
- eval_framework/metrics/completion/code_assertion.py +44 -0
- eval_framework/metrics/completion/code_execution_pass_at_one.py +126 -0
- eval_framework/metrics/completion/comet.py +56 -0
- eval_framework/metrics/completion/concordance_index.py +38 -0
- eval_framework/metrics/completion/csv_format.py +102 -0
- eval_framework/metrics/completion/cwe_accuracy.py +49 -0
- eval_framework/metrics/completion/exponential_similarity.py +65 -0
- eval_framework/metrics/completion/f1.py +42 -0
- eval_framework/metrics/completion/format_checker.py +56 -0
- eval_framework/metrics/completion/grid_difference.py +77 -0
- eval_framework/metrics/completion/ifeval.py +73 -0
- eval_framework/metrics/completion/json_format.py +171 -0
- eval_framework/metrics/completion/language_checker.py +74 -0
- eval_framework/metrics/completion/length_control.py +83 -0
- eval_framework/metrics/completion/math_reasoning_completion.py +303 -0
- eval_framework/metrics/completion/niah_accuracy.py +163 -0
- eval_framework/metrics/completion/placeholder_checker.py +27 -0
- eval_framework/metrics/completion/repetition.py +88 -0
- eval_framework/metrics/completion/rouge_1.py +35 -0
- eval_framework/metrics/completion/rouge_2.py +45 -0
- eval_framework/metrics/completion/rouge_geometric_mean.py +36 -0
- eval_framework/metrics/completion/rouge_l.py +52 -0
- eval_framework/metrics/completion/struct_eval_metrics.py +248 -0
- eval_framework/metrics/completion/ter.py +67 -0
- eval_framework/metrics/completion/text_counter.py +182 -0
- eval_framework/metrics/efficiency/__init__.py +0 -0
- eval_framework/metrics/efficiency/bytes_per_sequence_position.py +48 -0
- eval_framework/metrics/llm/__init__.py +0 -0
- eval_framework/metrics/llm/base.py +8 -0
- eval_framework/metrics/llm/graders/chatbot_style_grader.py +92 -0
- eval_framework/metrics/llm/graders/comparison_grader.py +146 -0
- eval_framework/metrics/llm/graders/conciseness_grader.py +93 -0
- eval_framework/metrics/llm/graders/contains_names_grader.py +71 -0
- eval_framework/metrics/llm/graders/format_correctness_grader.py +109 -0
- eval_framework/metrics/llm/graders/instruction_grader.py +177 -0
- eval_framework/metrics/llm/graders/language.py +56 -0
- eval_framework/metrics/llm/graders/long_context_grader.py +72 -0
- eval_framework/metrics/llm/graders/models.py +74 -0
- eval_framework/metrics/llm/graders/refusal_grader.py +57 -0
- eval_framework/metrics/llm/graders/sql_quality_grader.py +145 -0
- eval_framework/metrics/llm/graders/summary_world_knowledge_grader.py +103 -0
- eval_framework/metrics/llm/llm_judge_chatbot_style.py +36 -0
- eval_framework/metrics/llm/llm_judge_completion_accuracy.py +39 -0
- eval_framework/metrics/llm/llm_judge_conciseness.py +37 -0
- eval_framework/metrics/llm/llm_judge_contains_names.py +36 -0
- eval_framework/metrics/llm/llm_judge_format_correctness.py +43 -0
- eval_framework/metrics/llm/llm_judge_instruction.py +58 -0
- eval_framework/metrics/llm/llm_judge_mtbench_pair.py +205 -0
- eval_framework/metrics/llm/llm_judge_mtbench_single.py +188 -0
- eval_framework/metrics/llm/llm_judge_refusal.py +35 -0
- eval_framework/metrics/llm/llm_judge_sql.py +394 -0
- eval_framework/metrics/llm/llm_judge_world_knowledge.py +37 -0
- eval_framework/metrics/loglikelihood/__init__.py +0 -0
- eval_framework/metrics/loglikelihood/accuracy_loglikelihood.py +51 -0
- eval_framework/metrics/loglikelihood/probability_mass.py +56 -0
- eval_framework/py.typed +0 -0
- eval_framework/response_generator.py +416 -0
- eval_framework/result_processors/__init__.py +0 -0
- eval_framework/result_processors/base.py +74 -0
- eval_framework/result_processors/hf_processor.py +87 -0
- eval_framework/result_processors/result_processor.py +129 -0
- eval_framework/run.py +314 -0
- eval_framework/run_direct.py +42 -0
- eval_framework/shared/types.py +227 -0
- eval_framework/tasks/__init__.py +6 -0
- eval_framework/tasks/base.py +314 -0
- eval_framework/tasks/benchmarks/__init__.py +0 -0
- eval_framework/tasks/benchmarks/arc.py +46 -0
- eval_framework/tasks/benchmarks/arc_de.py +46 -0
- eval_framework/tasks/benchmarks/arc_fi.py +46 -0
- eval_framework/tasks/benchmarks/belebele.py +60 -0
- eval_framework/tasks/benchmarks/bigcodebench.py +155 -0
- eval_framework/tasks/benchmarks/casehold.py +47 -0
- eval_framework/tasks/benchmarks/chembench.py +85 -0
- eval_framework/tasks/benchmarks/copa.py +39 -0
- eval_framework/tasks/benchmarks/duc.py +91 -0
- eval_framework/tasks/benchmarks/flores200.py +62 -0
- eval_framework/tasks/benchmarks/flores_plus.py +84 -0
- eval_framework/tasks/benchmarks/gpqa.py +177 -0
- eval_framework/tasks/benchmarks/gsm8k.py +148 -0
- eval_framework/tasks/benchmarks/hellaswag.py +44 -0
- eval_framework/tasks/benchmarks/hellaswag_de.py +52 -0
- eval_framework/tasks/benchmarks/humaneval.py +97 -0
- eval_framework/tasks/benchmarks/ifeval.py +78 -0
- eval_framework/tasks/benchmarks/include.py +119 -0
- eval_framework/tasks/benchmarks/infinitebench.py +302 -0
- eval_framework/tasks/benchmarks/math_reasoning.py +569 -0
- eval_framework/tasks/benchmarks/mbpp.py +192 -0
- eval_framework/tasks/benchmarks/mmlu.py +190 -0
- eval_framework/tasks/benchmarks/mmlu_de.py +109 -0
- eval_framework/tasks/benchmarks/mmlu_pro.py +139 -0
- eval_framework/tasks/benchmarks/mmmlu.py +529 -0
- eval_framework/tasks/benchmarks/openbookqa.py +37 -0
- eval_framework/tasks/benchmarks/opengptx_eu20.py +363 -0
- eval_framework/tasks/benchmarks/pawsx.py +65 -0
- eval_framework/tasks/benchmarks/piqa.py +39 -0
- eval_framework/tasks/benchmarks/quality.py +56 -0
- eval_framework/tasks/benchmarks/sciq.py +44 -0
- eval_framework/tasks/benchmarks/sphyr.py +75 -0
- eval_framework/tasks/benchmarks/squad.py +89 -0
- eval_framework/tasks/benchmarks/struct_eval.py +110 -0
- eval_framework/tasks/benchmarks/tablebench.py +117 -0
- eval_framework/tasks/benchmarks/triviaqa.py +42 -0
- eval_framework/tasks/benchmarks/truthfulqa.py +95 -0
- eval_framework/tasks/benchmarks/winogender.py +39 -0
- eval_framework/tasks/benchmarks/winogrande.py +44 -0
- eval_framework/tasks/benchmarks/winox.py +57 -0
- eval_framework/tasks/benchmarks/wmt.py +160 -0
- eval_framework/tasks/benchmarks/zero_scrolls.py +197 -0
- eval_framework/tasks/eval_config.py +112 -0
- eval_framework/tasks/perturbation.py +83 -0
- eval_framework/tasks/registry.py +186 -0
- eval_framework/tasks/task_loader.py +80 -0
- eval_framework/tasks/task_names.py +138 -0
- eval_framework/tasks/utils.py +578 -0
- eval_framework/utils/constants.py +9 -0
- eval_framework/utils/generate_task_docs.py +229 -0
- eval_framework/utils/helpers.py +3 -0
- eval_framework/utils/logging.py +50 -0
- eval_framework/utils/packaging.py +52 -0
- eval_framework-0.2.0.dist-info/METADATA +514 -0
- eval_framework-0.2.0.dist-info/RECORD +161 -0
- eval_framework-0.2.0.dist-info/WHEEL +4 -0
- eval_framework-0.2.0.dist-info/entry_points.txt +3 -0
- template_formatting/README.md +83 -0
- template_formatting/__init__.py +0 -0
- template_formatting/formatter.py +536 -0
- template_formatting/mistral_formatter.py +159 -0
- template_formatting/py.typed +0 -0
- template_formatting/tests/test_formatter_eval.py +408 -0
- template_formatting/tests/test_formatter_scaling.py +253 -0
- template_formatting/tests/test_mistral_formatter.py +136 -0
|
@@ -0,0 +1,332 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from collections.abc import Callable, Sequence
|
|
3
|
+
from functools import partial
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
from tokenizers import Tokenizer
|
|
8
|
+
from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteria, StoppingCriteriaList
|
|
9
|
+
|
|
10
|
+
from eval_framework.llm.base import BaseLLM
|
|
11
|
+
from eval_framework.shared.types import (
|
|
12
|
+
ConcatCompression,
|
|
13
|
+
Error,
|
|
14
|
+
PromptTooLongException,
|
|
15
|
+
RawCompletion,
|
|
16
|
+
RawLoglikelihood,
|
|
17
|
+
)
|
|
18
|
+
from eval_framework.tasks.base import Sample
|
|
19
|
+
from eval_framework.tasks.utils import raise_errors
|
|
20
|
+
from eval_framework.utils.constants import RED, RESET
|
|
21
|
+
from template_formatting.formatter import BaseFormatter, ConcatFormatter, HFFormatter, Llama3Formatter, Message
|
|
22
|
+
|
|
23
|
+
logger = logging.getLogger(__name__)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class StopSequenceCriteria(StoppingCriteria):
    """Stop generation once any of the given stop sequences appears in the generated text.

    Only tokens generated after the prompt are inspected, and only the trailing
    ``token_history_length`` of them are decoded on each generation step.
    """

    def __init__(self, tokenizer: Tokenizer, stop_sequences: list[str], prompt_token_count: int) -> None:
        """
        Args:
            tokenizer: Tokenizer used to decode generated token ids.
            stop_sequences: Strings that terminate generation. An empty list never stops.
            prompt_token_count: Number of prompt tokens at the start of ``input_ids``.
        """
        self.tokenizer = tokenizer
        self.stop_sequences = stop_sequences
        self.prompt_token_count = prompt_token_count
        # (relatively weak) upper bound for the number of tokens that need to be decoded to
        # check for stop sequences; presumably relies on every token decoding to >= 1 character.
        # ``default=0`` keeps an empty ``stop_sequences`` list from raising in ``max``.
        self.token_history_length = max(map(len, stop_sequences), default=0)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs: Any) -> bool:
        """Return True when the first sequence of the batch contains any stop sequence."""
        if not self.stop_sequences:
            return False

        generated = input_ids[0].tolist()[self.prompt_token_count :]
        # Decode only the tail that could contain a stop sequence (slicing a shorter list
        # simply returns all of it).
        tail = generated[-self.token_history_length :]
        decoded_text = self.tokenizer.decode(tail, skip_special_tokens=True)
        return any(stop_sequence in decoded_text for stop_sequence in self.stop_sequences)
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
class RepeatedTokenSequenceCriteria(StoppingCriteria):
    """Stop generation once the model emits the same non-empty line twice in a row.

    Only complete lines (terminated by ``\\n``) are compared. The trailing, possibly unfinished
    line is ignored, so a line that merely *starts* like the previous one (e.g. ``"10"`` while
    the model is writing ``"100"``) does not cause a premature stop.
    """

    def __init__(self, tokenizer: Tokenizer, completion_start_index: int) -> None:
        """
        Args:
            tokenizer: Tokenizer used to decode generated token ids.
            completion_start_index: Index in ``input_ids`` where the completion begins
                (i.e. the prompt length in tokens).
        """
        self.tokenizer = tokenizer
        # Not read by this class; kept so existing attribute access keeps working.
        self.last_line = ""
        self.completion_start_index = completion_start_index

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs: Any) -> torch.Tensor:
        """Return a one-element bool tensor: True when the last two complete lines are equal."""
        completion_text = self.tokenizer.decode(input_ids[0][self.completion_start_index :])

        # ``split`` yields the complete lines followed by the unfinished tail (possibly "").
        *complete_lines, _unfinished_line = completion_text.split("\n")

        repeated = (
            len(complete_lines) >= 2
            and complete_lines[-1] == complete_lines[-2]
            # Consecutive blank lines are allowed and do not count as a repetition.
            and complete_lines[-1] != ""
        )
        return torch.BoolTensor([repeated]).to(input_ids.device)
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
class HFLLM(BaseLLM):
    """Run a Hugging Face ``transformers`` causal LM locally for generation and log-likelihood scoring.

    Subclasses configure the model through class attributes:

    * ``LLM_NAME`` - model id or path passed to ``from_pretrained``.
    * ``DEFAULT_FORMATTER`` - optional zero-argument factory returning the prompt formatter.
    * ``SEQ_LENGTH`` - optional context-length cap; the smaller of it and the model config's
      ``max_position_embeddings`` is used.
    """

    LLM_NAME: str
    DEFAULT_FORMATTER: Callable[[], BaseFormatter] | None = None
    SEQ_LENGTH: int | None = None

    def __init__(self, formatter: BaseFormatter | None = None) -> None:
        """Load tokenizer and model for ``LLM_NAME`` and select the prompt formatter.

        Args:
            formatter: Explicit formatter. When ``None``, ``_set_formatter`` picks one.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.tokenizer = AutoTokenizer.from_pretrained(self.LLM_NAME)
        self.model = AutoModelForCausalLM.from_pretrained(self.LLM_NAME, device_map="auto")
        logger.info(f"{RED}[ Model initialized --------------------- {RESET}{self.LLM_NAME} {RED}]{RESET}")
        self._set_formatter(formatter)

    def _set_formatter(self, formatter: BaseFormatter | None = None) -> None:
        """Choose the formatter that turns message lists into prompt strings.

        Precedence: the explicit ``formatter`` argument, then ``DEFAULT_FORMATTER()``, then an
        ``HFFormatter`` when the tokenizer provides a chat template.

        Raises:
            ValueError: If none of the above is available.
        """
        if formatter is not None:
            self._formatter = formatter
        elif self.DEFAULT_FORMATTER is not None:
            self._formatter = self.DEFAULT_FORMATTER()
        elif self.tokenizer.chat_template is not None:
            self._formatter = HFFormatter(self.LLM_NAME)
        else:
            # No explicit formatter, no class default and no chat template: refuse to guess.
            raise ValueError("No formatter specified and no default formatter available.")

        logger.info(
            f"{RED}[ Using default formatter --------------------- {RESET}{self._formatter.__class__.__name__} {RED}]{RESET}"  # noqa: E501
        )

    def count_tokens(self, text: str, /) -> int:
        """Count the number of tokens in a string (without special tokens)."""
        return len(self.tokenizer(text, add_special_tokens=False)["input_ids"])

    def _context_length(self) -> int:
        """Return the usable context length: the smaller of ``seq_length`` and ``SEQ_LENGTH``.

        Unset (``None``) or zero values are ignored.

        Raises:
            ValueError: If neither the model config nor ``SEQ_LENGTH`` defines a context length.
        """
        limits = [limit for limit in (self.seq_length, self.SEQ_LENGTH) if limit]
        if not limits:
            raise ValueError(
                f"Cannot determine the context length of {self.LLM_NAME}: the model config has no "
                "max_position_embeddings and SEQ_LENGTH is not set."
            )
        return min(limits)

    def generate_from_messages(
        self,
        messages: list[Sequence[Message]],
        stop_sequences: list[str] | None = None,
        max_tokens: int | None = None,
        temperature: float | None = None,
    ) -> list[RawCompletion]:
        """Generate one completion per message sequence.

        Args:
            messages: One message sequence per requested completion.
            stop_sequences: Strings that end generation; the returned completion is cut before
                the first occurrence of any of them.
            max_tokens: Cap on generated tokens. ``None`` (or 0) means "whatever still fits in
                the context window".
            temperature: Sampling temperature; ``None`` uses 0.0 (greedy decoding).

        Returns:
            One ``RawCompletion`` per input. A prompt that leaves no room for generation yields an
            empty completion carrying a ``PromptTooLongException`` error record.

        Raises:
            PromptTooLongException: For such a prompt when ``raise_errors()`` is true.
            ValueError: If the model's context length cannot be determined.
        """
        if temperature is None:
            effective_temperature = 0.0  # Current default, TODO: refactor to use model's default
            logger.info(
                f"Using default temperature value: {effective_temperature} as no custom temperature value was provided"
            )
        else:
            effective_temperature = temperature

        raw_completions = []
        for single_messages in messages:
            prompt = self._formatter.format(single_messages, output_mode="string")
            # add_special_tokens would add a second BOS token without explicitly setting it False
            inputs = self.tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(self.device)

            prompt_token_count = len(inputs["input_ids"][0])
            pad_token_id = self.tokenizer.eos_token_id

            stopping_criteria = StoppingCriteriaList()
            if stop_sequences:  # an empty list needs no stop-sequence check
                stopping_criteria.append(StopSequenceCriteria(self.tokenizer, stop_sequences, prompt_token_count))  # type: ignore[attr-defined]
            stopping_criteria.append(  # type: ignore[attr-defined]
                RepeatedTokenSequenceCriteria(self.tokenizer, prompt_token_count)
            )

            # Tokens still available in the context window after the prompt. This is computed
            # separately from ``max_tokens`` so that a zero budget is never silently discarded.
            context_budget = self._context_length() - prompt_token_count
            # A falsy max_tokens (None or 0) means "no additional cap".
            max_tokens_to_generate = min(context_budget, max_tokens) if max_tokens else context_budget

            if max_tokens_to_generate < 1:
                if raise_errors():
                    raise PromptTooLongException("Prompt exceeded context size.")
                raw_completions.append(
                    RawCompletion(
                        prompt=prompt,
                        prompt_sequence_positions=prompt_token_count,
                        completion="",
                        completion_sequence_positions=0,
                        raw_completion_error=Error(
                            error_class=PromptTooLongException.__name__,
                            message="Prompt exceeded context size.",
                            traceback="",
                        ),
                    )
                )
                continue

            completion, completion_token_count = self._model_generate(
                redis_key=(prompt, stop_sequences, max_tokens_to_generate, effective_temperature),
                prompt_token_count=prompt_token_count,
                inputs=inputs["input_ids"],
                max_new_tokens=max_tokens_to_generate,
                stopping_criteria=stopping_criteria,
                num_return_sequences=1,
                pad_token_id=pad_token_id,
                return_dict_in_generate=False,
                output_scores=False,
                do_sample=effective_temperature > 0,
                temperature=effective_temperature if effective_temperature > 0 else None,
            )

            raw_completions.append(
                RawCompletion(
                    prompt=prompt,
                    prompt_sequence_positions=prompt_token_count,
                    concat_compression=ConcatCompression.calculate(
                        single_messages, count_tokens=self.count_tokens, completion=completion
                    ),
                    completion=completion,
                    completion_sequence_positions=completion_token_count,
                )
            )
        return raw_completions

    def _model_generate(self, redis_key: Any, prompt_token_count: int, **kwargs: Any) -> tuple[str, int]:
        """Run ``model.generate`` and return the decoded completion and its token count.

        Args:
            redis_key: Not used by this implementation; presumably a cache key for subclasses
                that memoize generations (TODO confirm).
            prompt_token_count: Number of prompt tokens to strip from the generated sequence.
            **kwargs: Forwarded to ``model.generate``. Any ``StopSequenceCriteria`` found in
                ``stopping_criteria`` is also used to truncate the decoded text.

        Returns:
            ``(completion, generated_token_count)``. The count covers every generated token,
            including any produced after a stop sequence.
        """
        outputs = self.model.generate(**kwargs)[0]
        generated_ids = outputs[prompt_token_count:]
        completion = self.tokenizer.decode(generated_ids, skip_special_tokens=True)

        # Generation only halts after the stop sequence was produced, so cut the text at the
        # earliest occurrence of any stop sequence (sequential splitting achieves exactly that).
        for criterion in kwargs.get("stopping_criteria") or []:
            if isinstance(criterion, StopSequenceCriteria):
                for stop_sequence in criterion.stop_sequences:
                    if stop_sequence:  # str.split rejects an empty separator
                        completion = completion.split(stop_sequence)[0]
        return completion, len(generated_ids)

    def logprobs(self, samples: list[Sample]) -> list[RawLoglikelihood]:
        """Score each sample's candidate completions by summed token log-probabilities.

        For every choice in ``sample.possible_completions`` the model scores ``prompt + choice``,
        and the log-probabilities of its last ``len(tokenize(choice))`` tokens are summed. If any
        ``prompt + choice`` exceeds the context length, the sample's scores are discarded and an
        error record is attached instead.

        Raises:
            PromptTooLongException: On an over-long prompt+choice when ``raise_errors()`` is true.
            ValueError: If the model's context length cannot be determined.
        """
        results = []
        for sample in samples:
            prompt = self._formatter.format(sample.messages, output_mode="string")
            choices_log_probs: dict[str, float] = {}
            choices_log_probs_sequence_positions: dict[str, float] = {}
            error: Error | None = None

            for choice in sample.possible_completions or []:
                num_choice_tokens = len(self.tokenizer.encode(choice, add_special_tokens=False))
                prompt_and_choice = f"{prompt}{choice}"
                total_tokens_count = len(self.tokenizer.encode(prompt_and_choice, add_special_tokens=False))

                if self._context_length() < total_tokens_count:
                    if raise_errors():
                        raise PromptTooLongException("Prompt exceeded context size.")
                    # Partial results would be misleading, so drop every score for this sample.
                    choices_log_probs = {}
                    choices_log_probs_sequence_positions = {}
                    error = Error(
                        error_class=PromptTooLongException.__name__,
                        message="Prompt and choice exceeded context size.",
                        traceback="",
                    )
                    break

                choices_log_probs[choice] = self._model_log_probs(prompt_and_choice, num_choice_tokens)
                choices_log_probs_sequence_positions[choice] = num_choice_tokens

            results.append(
                RawLoglikelihood(
                    prompt=prompt,
                    prompt_sequence_positions=len(self.tokenizer.encode(prompt, add_special_tokens=False)),
                    concat_compression=ConcatCompression.calculate(
                        sample.messages, count_tokens=self.count_tokens, choices=sample.possible_completions
                    ),
                    loglikelihoods=choices_log_probs,
                    loglikelihoods_sequence_positions=choices_log_probs_sequence_positions,
                    raw_loglikelihood_error=error,
                )
            )
        return results

    def _model_log_probs(self, prompt: str, num_choice_tokens: int) -> float:
        """Return the summed log-probability of the last ``num_choice_tokens`` tokens of ``prompt``.

        Logit position ``i`` predicts input token ``i + 1``, so logits and targets are shifted by
        one; the very first input token is therefore never scored.
        """
        if num_choice_tokens <= 0:
            # An empty choice contributes nothing (a ``[-0:]`` slice would select everything).
            return 0.0

        with torch.no_grad():
            inputs = self.tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(self.device)
            logits = self.model(**inputs).logits[0, :-1, :]
            target_ids = inputs["input_ids"][0, 1:]

            # Only the trailing positions that belong to the choice are needed.
            logits = logits[-num_choice_tokens:]
            # With device_map="auto" the logits may live on a different device than the inputs.
            target_ids = target_ids[-num_choice_tokens:].to(logits.device)

            token_log_probs = torch.log_softmax(logits, dim=-1).gather(-1, target_ids.unsqueeze(-1)).squeeze(-1)
            # Accumulate as Python floats rather than in the model's (possibly low-precision) dtype.
            return float(sum(token_log_probs.tolist()))

    @property
    def seq_length(self) -> int | None:
        """Context length from the model config (``max_position_embeddings``), or None if absent."""
        return getattr(self.model.config, "max_position_embeddings", None)
|
|
275
|
+
|
|
276
|
+
|
|
277
|
+
class HFLLM_from_name(HFLLM):
    """
    A generic class to create HFLLM instances from a given model name.

    Args:
        model_name: Hugging Face model id or local path. Required despite the ``None`` default.
        formatter: Name of the prompt formatter: "Llama3Formatter" (default), "MistralFormatter",
            "ConcatFormatter" or "HFFormatter".
        **kwargs: Accepted for call-site compatibility; not used.

    Raises:
        ValueError: If ``model_name`` is missing or ``formatter`` is not a supported name.
    """

    def __init__(self, model_name: str | None = None, formatter: str = "Llama3Formatter", **kwargs: Any) -> None:
        if model_name is None:
            raise ValueError("model_name is required")

        # Per-instance model name; must be set before HFLLM.__init__ loads from LLM_NAME.
        self.LLM_NAME = model_name
        # Resolve the formatter first so an unsupported name fails before any weights are loaded.
        selected_formatter = self._get_formatter(formatter, model_name)
        logger.info(f"{RED}[ Formatter: {formatter} ]{RESET}")
        super().__init__(formatter=selected_formatter)

    def _get_formatter(self, formatter: str, model_name: str) -> Any:
        """Instantiate the formatter named ``formatter`` for ``model_name``.

        Raises:
            ValueError: If ``formatter`` is not one of the supported names.
        """
        if formatter == "Llama3Formatter":
            return Llama3Formatter()
        if formatter == "MistralFormatter":
            # Import straight from template_formatting (the module eval_framework.llm.mistral
            # re-exports it from) so that picking this formatter does not require vLLM.
            from template_formatting.mistral_formatter import MagistralFormatter

            return MagistralFormatter(model_name)
        if formatter == "ConcatFormatter":
            return ConcatFormatter()
        if formatter == "HFFormatter":
            return HFFormatter(model_name)

        supported = ["Llama3Formatter", "MistralFormatter", "ConcatFormatter", "HFFormatter"]
        raise ValueError(f"Unsupported formatter: {formatter}. Supported formatters: {supported}")
|
|
313
|
+
|
|
314
|
+
|
|
315
|
+
class Pythia410m(HFLLM):
    """EleutherAI Pythia-410M base model; prompts are built with ``ConcatFormatter``."""

    DEFAULT_FORMATTER = ConcatFormatter
    LLM_NAME = "EleutherAI/pythia-410m"
|
|
318
|
+
|
|
319
|
+
|
|
320
|
+
class SmolLM135M(HFLLM):
    """HuggingFaceTB SmolLM-135M base model; prompts are built with ``ConcatFormatter``."""

    DEFAULT_FORMATTER = ConcatFormatter
    LLM_NAME = "HuggingFaceTB/SmolLM-135M"
|
|
323
|
+
|
|
324
|
+
|
|
325
|
+
class Smollm135MInstruct(HFLLM):
    """HuggingFaceTB SmolLM-135M-Instruct; prompts use the model's own chat template."""

    LLM_NAME = "HuggingFaceTB/SmolLM-135M-Instruct"
    # Zero-argument factory: builds an HFFormatter bound to this model's name on demand.
    DEFAULT_FORMATTER = partial(HFFormatter, LLM_NAME)
|
|
328
|
+
|
|
329
|
+
|
|
330
|
+
class Qwen3_0_6B(HFLLM):
    """Qwen3-0.6B with its chat template, rendered with ``enable_thinking`` switched on."""

    LLM_NAME = "Qwen/Qwen3-0.6B"
    DEFAULT_FORMATTER = partial(
        HFFormatter,
        LLM_NAME,
        chat_template_kwargs={"enable_thinking": True},
    )
|
|
@@ -0,0 +1,73 @@
|
|
|
1
|
+
from functools import partial
|
|
2
|
+
from typing import Any, Literal, override
|
|
3
|
+
|
|
4
|
+
from vllm import SamplingParams
|
|
5
|
+
|
|
6
|
+
from eval_framework.llm.vllm import TokenizedContainer, VLLMModel, VLLMTokenizerAPI
|
|
7
|
+
from template_formatting.formatter import BaseFormatter, Message
|
|
8
|
+
from template_formatting.mistral_formatter import MagistralFormatter, MistralSerializer
|
|
9
|
+
|
|
10
|
+
# Public API of this module. MagistralVLLM is included because the default model file imports it.
__all__ = [
    "MagistralVLLM",
    "MistralAdapter",
    "MistralVLLM",
]
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class MistralAdapter(VLLMTokenizerAPI[list[Message]]):
    """Expose a Mistral tokenizer (obtained via ``MistralSerializer``) through ``VLLMTokenizerAPI``.

    Args:
        target_mdl: Model identifier handed to ``MistralSerializer`` as ``llm_target``.
    """

    def __init__(self, target_mdl: str) -> None:
        self.serializer = MistralSerializer(llm_target=target_mdl)
        self.tokenizer = self.serializer.get_tokenizer()

    def encode_formatted_struct(self, struct: list[Message]) -> TokenizedContainer:
        """Convert framework messages to a Mistral request and tokenize it as an instruct prompt."""
        request = self.serializer.build_mistral_request(
            mistral_msg_lst=self.serializer.convert_from_aa(msg_lst=struct),
        )
        encoded = self.tokenizer.encode_instruct(request)
        return TokenizedContainer(tokens=encoded.tokens, text=encoded.text)

    def encode_plain_text(self, text: str) -> TokenizedContainer:
        """Tokenize raw text with the underlying text tokenizer."""
        # The two positional False flags presumably disable BOS/EOS insertion — verify against
        # the mistral_common tokenizer signature.
        token_ids = self.tokenizer.tokenizer.encode(text, False, False)
        return TokenizedContainer(tokens=token_ids, text=text)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class MistralVLLM(VLLMModel):
    """``VLLMModel`` configured for Mistral-format checkpoints.

    vLLM is told to use the ``mistral`` tokenizer mode, config format and load format unless the
    caller overrides those keys in ``kwargs``. Prompts are tokenized through ``MistralAdapter``.
    """

    def __init__(
        self,
        formatter: BaseFormatter | None = None,
        max_model_len: int | None = None,
        tensor_parallel_size: int = 1,
        gpu_memory_utilization: float = 0.9,
        batch_size: int = 1,
        checkpoint_path: str | None = None,
        checkpoint_name: str | None = None,
        sampling_params: SamplingParams | dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> None:
        mistral_engine_args = {"tokenizer_mode": "mistral", "config_format": "mistral", "load_format": "mistral"}
        # Right-hand side wins: explicit caller kwargs override the Mistral defaults.
        engine_args = mistral_engine_args | kwargs
        super().__init__(
            formatter,
            max_model_len,
            tensor_parallel_size,
            gpu_memory_utilization,
            batch_size,
            checkpoint_path,
            checkpoint_name,
            sampling_params,
            **engine_args,
        )

    @override
    @property
    def tokenizer(self) -> VLLMTokenizerAPI:
        """Create the ``MistralAdapter`` for ``LLM_NAME`` on first access and reuse it afterwards."""
        if self._tokenizer is None:
            self._tokenizer = MistralAdapter(target_mdl=self.LLM_NAME)
        return self._tokenizer

    @property
    def formatter_output_mode(self) -> Literal["string", "list"]:
        """Formatter output mode for this model: always a message list, never a flat string."""
        return "list"
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
class MagistralVLLM(MistralVLLM):
    """Magistral-Small-2506 served through vLLM, formatted with ``MagistralFormatter``."""

    LLM_NAME = "mistralai/Magistral-Small-2506"
    # Reuse LLM_NAME so the formatter always targets the same checkpoint as the model.
    DEFAULT_FORMATTER = partial(MagistralFormatter, LLM_NAME)
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
"""This is just a default model file with some small models for testing.
|
|
2
|
+
|
|
3
|
+
Please define your own model file externally and pass it to the eval-framework entrypoint
|
|
4
|
+
to use it.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from eval_framework.utils.packaging import is_extra_installed
|
|
8
|
+
|
|
9
|
+
# Re-export the default test models, but only for the optional extras that are installed.
if is_extra_installed(extra="transformers"):
    from eval_framework.llm.huggingface import Pythia410m, SmolLM135M, Smollm135MInstruct, Qwen3_0_6B  # noqa: F401

if is_extra_installed(extra="mistral"):
    from eval_framework.llm.mistral import MagistralVLLM  # noqa: F401

if is_extra_installed(extra="vllm"):
    from eval_framework.llm.vllm import Qwen3_0_6B_VLLM, Qwen3_0_6B_VLLM_No_Thinking  # noqa: F401
|
|
@@ -0,0 +1,205 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import logging
|
|
3
|
+
import os
|
|
4
|
+
from collections.abc import Callable, Sequence
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
import tiktoken # OpenAI's official tokenizer library
|
|
8
|
+
from openai import OpenAI
|
|
9
|
+
|
|
10
|
+
from eval_framework.llm.base import BaseLLM
|
|
11
|
+
from eval_framework.shared.types import ConcatCompression, RawCompletion, RawLoglikelihood
|
|
12
|
+
from eval_framework.tasks.base import Sample
|
|
13
|
+
from template_formatting.formatter import BaseFormatter, Message, Role
|
|
14
|
+
|
|
15
|
+
logger = logging.getLogger(__name__)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class OpenAIModel(BaseLLM):
|
|
19
|
+
DEFAULT_FORMATTER: Callable[[], BaseFormatter] | None = None
|
|
20
|
+
|
|
21
|
+
def __init__(
    self,
    model_name: str = "gpt-4o",
    formatter: BaseFormatter | None = None,
    temperature: float | None = None,
    api_key: str | None = None,
    organization: str | None = None,
    base_url: str | None = None,
) -> None:
    """Initialize OpenAI API client.

    Args:
        model_name: Name of the OpenAI model to use (e.g., "gpt-4", "gpt-3.5-turbo")
        formatter: Optional message formatter. An explicit formatter takes precedence over
            ``DEFAULT_FORMATTER``; when both are absent no formatter is used.
        temperature: Sampling temperature (0.0 to 2.0)
        api_key: OpenAI API key (defaults to OPENAI_API_KEY env variable)
        organization: Optional organization ID
        base_url: Optional API base URL for Azure or other endpoints
    """
    self._model_name = model_name
    logger.info(f"Using {model_name} as a judge")

    # Explicit formatter first, then the class default. Spelled out because the one-liner
    # `a or b() if c else None` parses as `(a or b()) if c else None`.
    self._formatter: BaseFormatter | None = formatter
    if self._formatter is None and self.DEFAULT_FORMATTER is not None:
        self._formatter = self.DEFAULT_FORMATTER()

    self._temperature = temperature
    # Initialize OpenAI client
    self._client = OpenAI(
        api_key=api_key or os.getenv("OPENAI_API_KEY", ""),
        organization=organization,
        base_url=base_url,
    )

    # Initialize tiktoken tokenizer for the model. tiktoken raises KeyError for model names it
    # does not know (common with custom base_url endpoints); fall back to a generic encoding,
    # which makes token counts approximate instead of failing construction.
    try:
        self._encoding = tiktoken.encoding_for_model(self._model_name)
    except KeyError:
        logger.warning(
            f"tiktoken has no encoding registered for {self._model_name}; "
            "falling back to cl100k_base, token counts will be approximate"
        )
        self._encoding = tiktoken.get_encoding("cl100k_base")
|
|
52
|
+
|
|
53
|
+
def _count_tokens(self, text: str) -> int:
|
|
54
|
+
"""Helper method to count tokens using tiktoken."""
|
|
55
|
+
return len(self._encoding.encode(text))
|
|
56
|
+
|
|
57
|
+
def generate_from_messages(
|
|
58
|
+
self,
|
|
59
|
+
messages: list[Sequence[Message]],
|
|
60
|
+
stop_sequences: list[str] | None = None,
|
|
61
|
+
max_tokens: int | None = None,
|
|
62
|
+
temperature: float | None = None,
|
|
63
|
+
) -> list[RawCompletion]:
|
|
64
|
+
if temperature is None:
|
|
65
|
+
effective_temperature = 0.0 # Current default, TODO: refactor to use model's default
|
|
66
|
+
logger.info(
|
|
67
|
+
f"Using default temperature value: {effective_temperature} as no custom temperature value was provided"
|
|
68
|
+
)
|
|
69
|
+
else:
|
|
70
|
+
effective_temperature = temperature
|
|
71
|
+
"""Generate completion from messages.
|
|
72
|
+
Args:
|
|
73
|
+
messages: Sequence of messages
|
|
74
|
+
stop_sequences: Optional list of stop sequences
|
|
75
|
+
max_tokens: Optional maximum number of tokens to generate
|
|
76
|
+
Returns:
|
|
77
|
+
Tuple of (prompt, completion)
|
|
78
|
+
"""
|
|
79
|
+
results = []
|
|
80
|
+
for single_messages in messages:
|
|
81
|
+
if self._formatter is not None:
|
|
82
|
+
# Use formatter for text completion API
|
|
83
|
+
prompt = self._formatter.format(single_messages, output_mode="string")
|
|
84
|
+
response = self._client.completions.create(
|
|
85
|
+
model=self._model_name,
|
|
86
|
+
prompt=prompt,
|
|
87
|
+
temperature=effective_temperature,
|
|
88
|
+
max_tokens=max_tokens,
|
|
89
|
+
stop=stop_sequences,
|
|
90
|
+
)
|
|
91
|
+
|
|
92
|
+
prompt_sequence_positions: int | None = self._count_tokens(prompt)
|
|
93
|
+
completion = response.choices[0].text
|
|
94
|
+
completion_sequence_positions = self._count_tokens(completion)
|
|
95
|
+
|
|
96
|
+
results.append(
|
|
97
|
+
RawCompletion(
|
|
98
|
+
prompt=prompt,
|
|
99
|
+
prompt_sequence_positions=prompt_sequence_positions,
|
|
100
|
+
concat_compression=ConcatCompression.calculate(
|
|
101
|
+
single_messages, count_tokens=self._count_tokens, completion=completion
|
|
102
|
+
),
|
|
103
|
+
completion=completion,
|
|
104
|
+
completion_sequence_positions=completion_sequence_positions,
|
|
105
|
+
)
|
|
106
|
+
)
|
|
107
|
+
else:
|
|
108
|
+
# Use chat completion API
|
|
109
|
+
from openai.types.chat import ChatCompletionAssistantMessageParam, ChatCompletionUserMessageParam
|
|
110
|
+
|
|
111
|
+
chat_messages = [
|
|
112
|
+
(
|
|
113
|
+
ChatCompletionUserMessageParam(role="user", content=m.content)
|
|
114
|
+
if m.role is not None and m.role.value.lower() == "user"
|
|
115
|
+
else ChatCompletionAssistantMessageParam(role="assistant", content=m.content)
|
|
116
|
+
)
|
|
117
|
+
for m in single_messages
|
|
118
|
+
]
|
|
119
|
+
|
|
120
|
+
chat_response = self._client.chat.completions.create(
|
|
121
|
+
model=self._model_name,
|
|
122
|
+
messages=chat_messages,
|
|
123
|
+
temperature=effective_temperature,
|
|
124
|
+
max_tokens=max_tokens,
|
|
125
|
+
stop=stop_sequences,
|
|
126
|
+
)
|
|
127
|
+
|
|
128
|
+
# Reconstruct the prompt (since OpenAI API does not return it)
|
|
129
|
+
prompt = "\n".join([f"{m['role']}: {m['content']}" for m in chat_messages])
|
|
130
|
+
|
|
131
|
+
prompt_sequence_positions = (
|
|
132
|
+
chat_response.usage.prompt_tokens if chat_response.usage else None
|
|
133
|
+
) # OpenAI API gives token count
|
|
134
|
+
completion = (
|
|
135
|
+
chat_response.choices[0].message.content if chat_response.choices[0].message.content else ""
|
|
136
|
+
)
|
|
137
|
+
completion_sequence_positions = self._count_tokens(completion)
|
|
138
|
+
|
|
139
|
+
results.append(
|
|
140
|
+
RawCompletion(
|
|
141
|
+
prompt=prompt,
|
|
142
|
+
prompt_sequence_positions=prompt_sequence_positions,
|
|
143
|
+
concat_compression=ConcatCompression.calculate(
|
|
144
|
+
single_messages, count_tokens=self._count_tokens, completion=completion
|
|
145
|
+
),
|
|
146
|
+
completion=completion,
|
|
147
|
+
completion_sequence_positions=completion_sequence_positions,
|
|
148
|
+
)
|
|
149
|
+
)
|
|
150
|
+
return results
|
|
151
|
+
|
|
152
|
+
def logprobs(self, samples: list[Sample]) -> list[RawLoglikelihood]:
|
|
153
|
+
"""Get log probabilities for possible completions.
|
|
154
|
+
Args:
|
|
155
|
+
samples: list of Sample containing possible completions
|
|
156
|
+
Returns:
|
|
157
|
+
list of Tuple of (prompt, dict of completion->logprob)
|
|
158
|
+
Raises:
|
|
159
|
+
NotImplementedError: Logprobs not yet implemented
|
|
160
|
+
"""
|
|
161
|
+
raise NotImplementedError("Logprobs not yet implemented for OpenAI API")
|
|
162
|
+
|
|
163
|
+
def generate_structured_output(
|
|
164
|
+
self,
|
|
165
|
+
messages: list[Sequence[Message]],
|
|
166
|
+
stop_sequences: list[str] | None = None,
|
|
167
|
+
max_tokens: int | None = None,
|
|
168
|
+
temperature: float = 0.0,
|
|
169
|
+
) -> Any:
|
|
170
|
+
"""Generate structured output (e.g. JSON) from messages.
|
|
171
|
+
This implementation ensures the model returns valid JSON.
|
|
172
|
+
Args:
|
|
173
|
+
messages: list of Sequence of messages
|
|
174
|
+
stop_sequences: Optional stop sequences
|
|
175
|
+
max_tokens: Optional max tokens
|
|
176
|
+
Returns:
|
|
177
|
+
Parsed JSON response
|
|
178
|
+
"""
|
|
179
|
+
completions = []
|
|
180
|
+
list_json_messages: list[Sequence[Message]] = []
|
|
181
|
+
for single_messages in messages:
|
|
182
|
+
# Add system message to encourage JSON output
|
|
183
|
+
json_messages = list(single_messages)
|
|
184
|
+
if not any(m.role == Role.SYSTEM for m in single_messages):
|
|
185
|
+
json_messages.insert(
|
|
186
|
+
0,
|
|
187
|
+
Message(
|
|
188
|
+
role=Role.SYSTEM, content="You are a helpful assistant that always responds with valid JSON."
|
|
189
|
+
),
|
|
190
|
+
)
|
|
191
|
+
list_json_messages.append(json_messages)
|
|
192
|
+
# Generate completion
|
|
193
|
+
completions = self.generate_from_messages(
|
|
194
|
+
messages=list_json_messages, stop_sequences=stop_sequences, max_tokens=max_tokens
|
|
195
|
+
)
|
|
196
|
+
responses = []
|
|
197
|
+
for completion in completions:
|
|
198
|
+
try:
|
|
199
|
+
# Parse JSON responses
|
|
200
|
+
responses.append(json.loads(completion.completion))
|
|
201
|
+
except json.JSONDecodeError as e:
|
|
202
|
+
logger.info(f"Warning: Failed to parse JSON response: {e}")
|
|
203
|
+
logger.info(f"Raw response: {completion.completion}")
|
|
204
|
+
raise
|
|
205
|
+
return responses
|