eval_framework-0.2.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (161)
  1. eval_framework/__init__.py +7 -0
  2. eval_framework/base_config.py +36 -0
  3. eval_framework/context/__init__.py +0 -0
  4. eval_framework/context/determined.py +170 -0
  5. eval_framework/context/eval.py +114 -0
  6. eval_framework/context/local.py +52 -0
  7. eval_framework/evaluation_generator.py +231 -0
  8. eval_framework/exceptions.py +2 -0
  9. eval_framework/external/ifeval_impl/README.md +5 -0
  10. eval_framework/external/ifeval_impl/instructions.py +1523 -0
  11. eval_framework/external/ifeval_impl/instructions_registry.py +161 -0
  12. eval_framework/external/ifeval_impl/instructions_util.py +1689 -0
  13. eval_framework/external/ifeval_impl/utils.py +135 -0
  14. eval_framework/llm/__init__.py +0 -0
  15. eval_framework/llm/aleph_alpha.py +323 -0
  16. eval_framework/llm/base.py +58 -0
  17. eval_framework/llm/huggingface.py +332 -0
  18. eval_framework/llm/mistral.py +73 -0
  19. eval_framework/llm/models.py +16 -0
  20. eval_framework/llm/openai.py +205 -0
  21. eval_framework/llm/vllm.py +438 -0
  22. eval_framework/logger.py +3 -0
  23. eval_framework/main.py +187 -0
  24. eval_framework/metrics/__init__.py +0 -0
  25. eval_framework/metrics/base.py +40 -0
  26. eval_framework/metrics/completion/__init__.py +1 -0
  27. eval_framework/metrics/completion/accuracy_completion.py +16 -0
  28. eval_framework/metrics/completion/bleu.py +76 -0
  29. eval_framework/metrics/completion/chrf.py +62 -0
  30. eval_framework/metrics/completion/code_assertion.py +44 -0
  31. eval_framework/metrics/completion/code_execution_pass_at_one.py +126 -0
  32. eval_framework/metrics/completion/comet.py +56 -0
  33. eval_framework/metrics/completion/concordance_index.py +38 -0
  34. eval_framework/metrics/completion/csv_format.py +102 -0
  35. eval_framework/metrics/completion/cwe_accuracy.py +49 -0
  36. eval_framework/metrics/completion/exponential_similarity.py +65 -0
  37. eval_framework/metrics/completion/f1.py +42 -0
  38. eval_framework/metrics/completion/format_checker.py +56 -0
  39. eval_framework/metrics/completion/grid_difference.py +77 -0
  40. eval_framework/metrics/completion/ifeval.py +73 -0
  41. eval_framework/metrics/completion/json_format.py +171 -0
  42. eval_framework/metrics/completion/language_checker.py +74 -0
  43. eval_framework/metrics/completion/length_control.py +83 -0
  44. eval_framework/metrics/completion/math_reasoning_completion.py +303 -0
  45. eval_framework/metrics/completion/niah_accuracy.py +163 -0
  46. eval_framework/metrics/completion/placeholder_checker.py +27 -0
  47. eval_framework/metrics/completion/repetition.py +88 -0
  48. eval_framework/metrics/completion/rouge_1.py +35 -0
  49. eval_framework/metrics/completion/rouge_2.py +45 -0
  50. eval_framework/metrics/completion/rouge_geometric_mean.py +36 -0
  51. eval_framework/metrics/completion/rouge_l.py +52 -0
  52. eval_framework/metrics/completion/struct_eval_metrics.py +248 -0
  53. eval_framework/metrics/completion/ter.py +67 -0
  54. eval_framework/metrics/completion/text_counter.py +182 -0
  55. eval_framework/metrics/efficiency/__init__.py +0 -0
  56. eval_framework/metrics/efficiency/bytes_per_sequence_position.py +48 -0
  57. eval_framework/metrics/llm/__init__.py +0 -0
  58. eval_framework/metrics/llm/base.py +8 -0
  59. eval_framework/metrics/llm/graders/chatbot_style_grader.py +92 -0
  60. eval_framework/metrics/llm/graders/comparison_grader.py +146 -0
  61. eval_framework/metrics/llm/graders/conciseness_grader.py +93 -0
  62. eval_framework/metrics/llm/graders/contains_names_grader.py +71 -0
  63. eval_framework/metrics/llm/graders/format_correctness_grader.py +109 -0
  64. eval_framework/metrics/llm/graders/instruction_grader.py +177 -0
  65. eval_framework/metrics/llm/graders/language.py +56 -0
  66. eval_framework/metrics/llm/graders/long_context_grader.py +72 -0
  67. eval_framework/metrics/llm/graders/models.py +74 -0
  68. eval_framework/metrics/llm/graders/refusal_grader.py +57 -0
  69. eval_framework/metrics/llm/graders/sql_quality_grader.py +145 -0
  70. eval_framework/metrics/llm/graders/summary_world_knowledge_grader.py +103 -0
  71. eval_framework/metrics/llm/llm_judge_chatbot_style.py +36 -0
  72. eval_framework/metrics/llm/llm_judge_completion_accuracy.py +39 -0
  73. eval_framework/metrics/llm/llm_judge_conciseness.py +37 -0
  74. eval_framework/metrics/llm/llm_judge_contains_names.py +36 -0
  75. eval_framework/metrics/llm/llm_judge_format_correctness.py +43 -0
  76. eval_framework/metrics/llm/llm_judge_instruction.py +58 -0
  77. eval_framework/metrics/llm/llm_judge_mtbench_pair.py +205 -0
  78. eval_framework/metrics/llm/llm_judge_mtbench_single.py +188 -0
  79. eval_framework/metrics/llm/llm_judge_refusal.py +35 -0
  80. eval_framework/metrics/llm/llm_judge_sql.py +394 -0
  81. eval_framework/metrics/llm/llm_judge_world_knowledge.py +37 -0
  82. eval_framework/metrics/loglikelihood/__init__.py +0 -0
  83. eval_framework/metrics/loglikelihood/accuracy_loglikelihood.py +51 -0
  84. eval_framework/metrics/loglikelihood/probability_mass.py +56 -0
  85. eval_framework/py.typed +0 -0
  86. eval_framework/response_generator.py +416 -0
  87. eval_framework/result_processors/__init__.py +0 -0
  88. eval_framework/result_processors/base.py +74 -0
  89. eval_framework/result_processors/hf_processor.py +87 -0
  90. eval_framework/result_processors/result_processor.py +129 -0
  91. eval_framework/run.py +314 -0
  92. eval_framework/run_direct.py +42 -0
  93. eval_framework/shared/types.py +227 -0
  94. eval_framework/tasks/__init__.py +6 -0
  95. eval_framework/tasks/base.py +314 -0
  96. eval_framework/tasks/benchmarks/__init__.py +0 -0
  97. eval_framework/tasks/benchmarks/arc.py +46 -0
  98. eval_framework/tasks/benchmarks/arc_de.py +46 -0
  99. eval_framework/tasks/benchmarks/arc_fi.py +46 -0
  100. eval_framework/tasks/benchmarks/belebele.py +60 -0
  101. eval_framework/tasks/benchmarks/bigcodebench.py +155 -0
  102. eval_framework/tasks/benchmarks/casehold.py +47 -0
  103. eval_framework/tasks/benchmarks/chembench.py +85 -0
  104. eval_framework/tasks/benchmarks/copa.py +39 -0
  105. eval_framework/tasks/benchmarks/duc.py +91 -0
  106. eval_framework/tasks/benchmarks/flores200.py +62 -0
  107. eval_framework/tasks/benchmarks/flores_plus.py +84 -0
  108. eval_framework/tasks/benchmarks/gpqa.py +177 -0
  109. eval_framework/tasks/benchmarks/gsm8k.py +148 -0
  110. eval_framework/tasks/benchmarks/hellaswag.py +44 -0
  111. eval_framework/tasks/benchmarks/hellaswag_de.py +52 -0
  112. eval_framework/tasks/benchmarks/humaneval.py +97 -0
  113. eval_framework/tasks/benchmarks/ifeval.py +78 -0
  114. eval_framework/tasks/benchmarks/include.py +119 -0
  115. eval_framework/tasks/benchmarks/infinitebench.py +302 -0
  116. eval_framework/tasks/benchmarks/math_reasoning.py +569 -0
  117. eval_framework/tasks/benchmarks/mbpp.py +192 -0
  118. eval_framework/tasks/benchmarks/mmlu.py +190 -0
  119. eval_framework/tasks/benchmarks/mmlu_de.py +109 -0
  120. eval_framework/tasks/benchmarks/mmlu_pro.py +139 -0
  121. eval_framework/tasks/benchmarks/mmmlu.py +529 -0
  122. eval_framework/tasks/benchmarks/openbookqa.py +37 -0
  123. eval_framework/tasks/benchmarks/opengptx_eu20.py +363 -0
  124. eval_framework/tasks/benchmarks/pawsx.py +65 -0
  125. eval_framework/tasks/benchmarks/piqa.py +39 -0
  126. eval_framework/tasks/benchmarks/quality.py +56 -0
  127. eval_framework/tasks/benchmarks/sciq.py +44 -0
  128. eval_framework/tasks/benchmarks/sphyr.py +75 -0
  129. eval_framework/tasks/benchmarks/squad.py +89 -0
  130. eval_framework/tasks/benchmarks/struct_eval.py +110 -0
  131. eval_framework/tasks/benchmarks/tablebench.py +117 -0
  132. eval_framework/tasks/benchmarks/triviaqa.py +42 -0
  133. eval_framework/tasks/benchmarks/truthfulqa.py +95 -0
  134. eval_framework/tasks/benchmarks/winogender.py +39 -0
  135. eval_framework/tasks/benchmarks/winogrande.py +44 -0
  136. eval_framework/tasks/benchmarks/winox.py +57 -0
  137. eval_framework/tasks/benchmarks/wmt.py +160 -0
  138. eval_framework/tasks/benchmarks/zero_scrolls.py +197 -0
  139. eval_framework/tasks/eval_config.py +112 -0
  140. eval_framework/tasks/perturbation.py +83 -0
  141. eval_framework/tasks/registry.py +186 -0
  142. eval_framework/tasks/task_loader.py +80 -0
  143. eval_framework/tasks/task_names.py +138 -0
  144. eval_framework/tasks/utils.py +578 -0
  145. eval_framework/utils/constants.py +9 -0
  146. eval_framework/utils/generate_task_docs.py +229 -0
  147. eval_framework/utils/helpers.py +3 -0
  148. eval_framework/utils/logging.py +50 -0
  149. eval_framework/utils/packaging.py +52 -0
  150. eval_framework-0.2.0.dist-info/METADATA +514 -0
  151. eval_framework-0.2.0.dist-info/RECORD +161 -0
  152. eval_framework-0.2.0.dist-info/WHEEL +4 -0
  153. eval_framework-0.2.0.dist-info/entry_points.txt +3 -0
  154. template_formatting/README.md +83 -0
  155. template_formatting/__init__.py +0 -0
  156. template_formatting/formatter.py +536 -0
  157. template_formatting/mistral_formatter.py +159 -0
  158. template_formatting/py.typed +0 -0
  159. template_formatting/tests/test_formatter_eval.py +408 -0
  160. template_formatting/tests/test_formatter_scaling.py +253 -0
  161. template_formatting/tests/test_mistral_formatter.py +136 -0
eval_framework/llm/huggingface.py
@@ -0,0 +1,332 @@
+ import logging
+ from collections.abc import Callable, Sequence
+ from functools import partial
+ from typing import Any
+
+ import torch
+ from tokenizers import Tokenizer
+ from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteria, StoppingCriteriaList
+
+ from eval_framework.llm.base import BaseLLM
+ from eval_framework.shared.types import (
+     ConcatCompression,
+     Error,
+     PromptTooLongException,
+     RawCompletion,
+     RawLoglikelihood,
+ )
+ from eval_framework.tasks.base import Sample
+ from eval_framework.tasks.utils import raise_errors
+ from eval_framework.utils.constants import RED, RESET
+ from template_formatting.formatter import BaseFormatter, ConcatFormatter, HFFormatter, Llama3Formatter, Message
+
+ logger = logging.getLogger(__name__)
+
+
+ class StopSequenceCriteria(StoppingCriteria):
+     def __init__(self, tokenizer: Tokenizer, stop_sequences: list[str], prompt_token_count: int) -> None:
+         self.tokenizer = tokenizer
+         self.stop_sequences = stop_sequences
+         self.prompt_token_count = prompt_token_count
+         # (relatively weak) upper bound for the number of tokens that
+         # need to be decoded to check for stop sequences
+         self.token_history_length = max(map(len, stop_sequences))
+
+     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs: Any) -> bool:
+         sequence = input_ids[0].tolist()
+         sequence = sequence[self.prompt_token_count :]
+         if len(sequence) > self.token_history_length:
+             sequence = sequence[-self.token_history_length :]
+         decoded_text = self.tokenizer.decode(sequence, skip_special_tokens=True)
+
+         for stop_sequence in self.stop_sequences:
+             if stop_sequence in decoded_text:
+                 return True
+         return False
+
+
+ class RepeatedTokenSequenceCriteria(StoppingCriteria):
+     def __init__(self, tokenizer: Tokenizer, completion_start_index: int) -> None:
+         self.tokenizer = tokenizer
+         # Initialize with an empty string to store the last line
+         self.last_line = ""
+         self.completion_start_index = completion_start_index
+
+     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs: Any) -> torch.Tensor:
+         # Decode only the completion tokens generated so far
+         current_text = self.tokenizer.decode(input_ids[0][self.completion_start_index :])
+         # Split the text into lines
+         lines = current_text.split("\n")
+
+         # Check if the last full line (ignoring the last if it's incomplete) is repeated
+         if len(lines) > 1 and lines[-2] == lines[-1] and not (lines[-1] == "" and lines[-2] == ""):
+             return torch.BoolTensor([True]).to(input_ids.device)  # Stop generation if a repeated line is found
+
+         return torch.BoolTensor([False]).to(input_ids.device)
+
+
+ class HFLLM(BaseLLM):
+     LLM_NAME: str
+     DEFAULT_FORMATTER: Callable[[], BaseFormatter] | None = None
+     SEQ_LENGTH: int | None = None
+
+     def __init__(self, formatter: BaseFormatter | None = None) -> None:
+         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+         self.tokenizer = AutoTokenizer.from_pretrained(self.LLM_NAME)
+         self.model = AutoModelForCausalLM.from_pretrained(self.LLM_NAME, device_map="auto")
+         logger.info(f"{RED}[ Model initialized --------------------- {RESET}{self.LLM_NAME} {RED}]{RESET}")
+         self._set_formatter(formatter)
+
+     def _set_formatter(self, formatter: BaseFormatter | None = None) -> None:
+         # if a formatter is passed at initialization time, use it
+         if formatter is not None:
+             self._formatter = formatter
+         # if no formatter is passed but DEFAULT_FORMATTER was specified, use that
+         elif self.DEFAULT_FORMATTER is not None:
+             self._formatter = self.DEFAULT_FORMATTER()
+         # if neither is available, fall back to the HF chat template if the tokenizer has one
+         elif self.tokenizer.chat_template is not None:
+             self._formatter = HFFormatter(self.LLM_NAME)
+         # without a formatter, a default formatter or a chat template, prompts cannot be formatted
+         else:
+             raise ValueError("No formatter specified and no default formatter available.")
+
+         logger.info(
+             f"{RED}[ Using formatter --------------------- {RESET}{self._formatter.__class__.__name__} {RED}]{RESET}"  # noqa: E501
+         )
+
+     def count_tokens(self, text: str, /) -> int:
+         """Count the number of tokens in a string."""
+         return len(self.tokenizer(text, add_special_tokens=False)["input_ids"])
+
+     def generate_from_messages(
+         self,
+         messages: list[Sequence[Message]],
+         stop_sequences: list[str] | None = None,
+         max_tokens: int | None = None,
+         temperature: float | None = None,
+     ) -> list[RawCompletion]:
+         if temperature is None:
+             effective_temperature = 0.0  # Current default, TODO: refactor to use model's default
+             logger.info(
+                 f"Using default temperature value: {effective_temperature} as no custom temperature value was provided"
+             )
+         else:
+             effective_temperature = temperature
+
+         raw_completions = []
+         for single_messages in messages:
+             # format the messages into a single prompt string
+             prompt = self._formatter.format(single_messages, output_mode="string")
+             # add_special_tokens must be False, otherwise a second BOS token would be added
+             inputs = self.tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(self.device)
+
+             prompt_token_count = len(inputs["input_ids"][0])
+             pad_token_id = self.tokenizer.eos_token_id
+
+             # Prepare stopping criteria
+             stopping_criteria = StoppingCriteriaList()
+             if stop_sequences is not None:
+                 stopping_criteria.append(StopSequenceCriteria(self.tokenizer, stop_sequences, prompt_token_count))  # type: ignore[attr-defined]
+
+             stopping_criteria.append(  # type: ignore[attr-defined]
+                 RepeatedTokenSequenceCriteria(
+                     self.tokenizer,
+                     prompt_token_count,
+                 )
+             )
+
+             min_seq_length = min(filter(None, [self.seq_length, self.SEQ_LENGTH]))
+
+             # Calculate the maximum number of tokens to generate
+             max_tokens_to_generate = min_seq_length - prompt_token_count
+             # If max_tokens is specified, use the smaller of the two
+             max_tokens_to_generate = min(filter(None, [max_tokens_to_generate, max_tokens]))
+
+             if max_tokens_to_generate < 1:
+                 if raise_errors():
+                     raise PromptTooLongException("Prompt exceeded context size.")
+                 raw_completions.append(
+                     RawCompletion(
+                         prompt=prompt,
+                         prompt_sequence_positions=prompt_token_count,
+                         completion="",
+                         completion_sequence_positions=0,
+                         raw_completion_error=Error(
+                             error_class=PromptTooLongException.__name__,
+                             message="Prompt exceeded context size.",
+                             traceback="",
+                         ),
+                     )
+                 )
+                 continue
+
+             completion, completion_token_count = self._model_generate(
+                 redis_key=(prompt, stop_sequences, max_tokens_to_generate, effective_temperature),
+                 prompt_token_count=prompt_token_count,
+                 inputs=inputs["input_ids"],
+                 max_new_tokens=max_tokens_to_generate,
+                 stopping_criteria=stopping_criteria,
+                 num_return_sequences=1,
+                 pad_token_id=pad_token_id,
+                 return_dict_in_generate=False,
+                 output_scores=False,
+                 do_sample=effective_temperature > 0,
+                 temperature=effective_temperature if effective_temperature > 0 else None,
+             )
+
+             raw_completions.append(
+                 RawCompletion(
+                     prompt=prompt,
+                     prompt_sequence_positions=prompt_token_count,
+                     concat_compression=ConcatCompression.calculate(
+                         single_messages, count_tokens=self.count_tokens, completion=completion
+                     ),
+                     completion=completion,
+                     completion_sequence_positions=completion_token_count,
+                 )
+             )
+         return raw_completions
+
+     def _model_generate(self, redis_key: Any, prompt_token_count: int, **kwargs: Any) -> tuple[str, int]:
+         outputs = self.model.generate(**kwargs)[0]
+         completion = self.tokenizer.decode(outputs[prompt_token_count:], skip_special_tokens=True)
+
+         if kwargs["stopping_criteria"][0].__class__.__name__ == "StopSequenceCriteria":
+             # Truncate at the first stop sequence, which may have been decoded in full
+             for stop_sequence in kwargs["stopping_criteria"][0].stop_sequences:
+                 completion = completion.split(stop_sequence)[0]
+         return completion, len(outputs[prompt_token_count:])
+
+     def logprobs(self, samples: list[Sample]) -> list[RawLoglikelihood]:
+         results = []
+         for sample in samples:
+             # format the messages into a single prompt string
+             prompt = self._formatter.format(sample.messages, output_mode="string")
+             choices_log_probs: dict[str, float] = {}
+             choices_log_probs_sequence_positions: dict[str, float] = {}
+             error: Error | None = None
+
+             for choice in sample.possible_completions or []:
+                 num_choice_tokens = len(self.tokenizer.encode(choice, add_special_tokens=False))
+                 prompt_and_choice = f"{prompt}{choice}"
+
+                 total_tokens_count = len(self.tokenizer.encode(prompt_and_choice, add_special_tokens=False))
+
+                 min_max_tokens = min(filter(None, [self.SEQ_LENGTH, self.seq_length]))
+
+                 if min_max_tokens < total_tokens_count:
+                     if raise_errors():
+                         raise PromptTooLongException("Prompt exceeded context size.")
+                     choices_log_probs = {}
+                     choices_log_probs_sequence_positions = {}
+                     error = Error(
+                         error_class=PromptTooLongException.__name__,
+                         message="Prompt and choice exceeded context size.",
+                         traceback="",
+                     )
+                     break
+                 else:
+                     # Calculate log-likelihoods for each token in the completion
+                     sum_log_probs = self._model_log_probs(prompt_and_choice, num_choice_tokens)
+
+                     choices_log_probs.update({choice: sum_log_probs})
+                     choices_log_probs_sequence_positions.update({choice: num_choice_tokens})
+
+             results.append(
+                 RawLoglikelihood(
+                     prompt=prompt,
+                     prompt_sequence_positions=len(self.tokenizer.encode(prompt, add_special_tokens=False)),
+                     concat_compression=ConcatCompression.calculate(
+                         sample.messages, count_tokens=self.count_tokens, choices=sample.possible_completions
+                     ),
+                     loglikelihoods=choices_log_probs,
+                     loglikelihoods_sequence_positions=choices_log_probs_sequence_positions,
+                     raw_loglikelihood_error=error,
+                 )
+             )
+         return results
+
+     def _model_log_probs(self, prompt: str, num_choice_tokens: int) -> float:
+         with torch.no_grad():
+             inputs = self.tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(self.device)
+             outputs = self.model(**inputs, labels=inputs["input_ids"])
+             # Logits at position i predict the token at position i + 1
+             logits = outputs.logits[:, :-1, :].squeeze(0)
+             target_ids = inputs["input_ids"][:, 1:].squeeze(0)
+
+             token_loglikelihoods = []
+             for i in range(len(target_ids)):
+                 token_id = target_ids[i].item()
+                 token = self.tokenizer.decode([token_id])
+                 loglikelihood = torch.log_softmax(logits[i], dim=-1)[token_id].item()
+                 token_loglikelihoods.append({token: loglikelihood})
+
+         # Sum the log-likelihoods of the choice tokens only (the tail of the sequence)
+         return sum(list(log_prob.values())[0] for log_prob in token_loglikelihoods[-num_choice_tokens:])
+
+     @property
+     def seq_length(self) -> int | None:
+         config = self.model.config
+         return config.max_position_embeddings if hasattr(config, "max_position_embeddings") else None
+
+
+ class HFLLM_from_name(HFLLM):
+     """
+     A generic class to create HFLLM instances from a given model name.
+     """
+
+     def __init__(self, model_name: str | None = None, formatter: str = "Llama3Formatter", **kwargs: Any) -> None:
+         if model_name is None:
+             raise ValueError("model_name is required")
+
+         self.LLM_NAME = model_name
+         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+         self.tokenizer = AutoTokenizer.from_pretrained(self.LLM_NAME)
+         self.model = AutoModelForCausalLM.from_pretrained(self.LLM_NAME, device_map="auto")
+
+         # Lazy formatter initialization - only create the one we need
+         selected_formatter = self._get_formatter(formatter, model_name)
+
+         logger.info(f"{RED}[ Model initialized --------------------- {RESET}{self.LLM_NAME} {RED}]{RESET}")
+         logger.info(f"{RED}[ Formatter: {formatter} ]{RESET}")
+         self._set_formatter(selected_formatter)
+
+     def _get_formatter(self, formatter: str, model_name: str) -> Any:
+         """Get a formatter instance based on the formatter name."""
+         if formatter == "Llama3Formatter":
+             return Llama3Formatter()
+         elif formatter == "MistralFormatter":
+             from eval_framework.llm.mistral import MagistralFormatter
+
+             return MagistralFormatter(model_name)
+         elif formatter == "ConcatFormatter":
+             return ConcatFormatter()
+         elif formatter == "HFFormatter":
+             return HFFormatter(model_name)
+         else:
+             supported = ["Llama3Formatter", "MistralFormatter", "ConcatFormatter", "HFFormatter"]
+             raise ValueError(f"Unsupported formatter: {formatter}. Supported formatters: {supported}")
+
+
+ class Pythia410m(HFLLM):
+     LLM_NAME = "EleutherAI/pythia-410m"
+     DEFAULT_FORMATTER = ConcatFormatter
+
+
+ class SmolLM135M(HFLLM):
+     LLM_NAME = "HuggingFaceTB/SmolLM-135M"
+     DEFAULT_FORMATTER = ConcatFormatter
+
+
+ class Smollm135MInstruct(HFLLM):
+     LLM_NAME = "HuggingFaceTB/SmolLM-135M-Instruct"
+     DEFAULT_FORMATTER = partial(HFFormatter, LLM_NAME)
+
+
+ class Qwen3_0_6B(HFLLM):
+     LLM_NAME = "Qwen/Qwen3-0.6B"
+     DEFAULT_FORMATTER = partial(HFFormatter, LLM_NAME, chat_template_kwargs={"enable_thinking": True})
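
For orientation, a minimal usage sketch of the interface added above (the prompt is illustrative; this assumes a Role.USER member exists alongside the Role.SYSTEM used elsewhere in this package, and that the transformers extra is installed):

    from eval_framework.llm.huggingface import HFLLM_from_name
    from template_formatting.formatter import Message, Role

    # Load any HF causal LM by name; the formatter is selected by its registry name.
    llm = HFLLM_from_name(model_name="HuggingFaceTB/SmolLM-135M-Instruct", formatter="HFFormatter")
    completions = llm.generate_from_messages(
        [[Message(role=Role.USER, content="Name three primary colors.")]],
        stop_sequences=["\n\n"],
        max_tokens=64,
    )
    print(completions[0].completion)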
eval_framework/llm/mistral.py
@@ -0,0 +1,73 @@
+ from functools import partial
+ from typing import Any, Literal, override
+
+ from vllm import SamplingParams
+
+ from eval_framework.llm.vllm import TokenizedContainer, VLLMModel, VLLMTokenizerAPI
+ from template_formatting.formatter import BaseFormatter, Message
+ from template_formatting.mistral_formatter import MagistralFormatter, MistralSerializer
+
+ __all__ = [
+     "MistralAdapter",
+     "MistralVLLM",
+ ]
+
+
+ class MistralAdapter(VLLMTokenizerAPI[list[Message]]):
+     def __init__(self, target_mdl: str) -> None:
+         self.serializer = MistralSerializer(llm_target=target_mdl)
+         self.tokenizer = self.serializer.get_tokenizer()
+
+     def encode_formatted_struct(self, struct: list[Message]) -> TokenizedContainer:
+         mistral_msg_lst = self.serializer.convert_from_aa(msg_lst=struct)
+         mistral_request = self.serializer.build_mistral_request(mistral_msg_lst=mistral_msg_lst)
+         mistral_tokenized_obj = self.tokenizer.encode_instruct(mistral_request)
+         return TokenizedContainer(tokens=mistral_tokenized_obj.tokens, text=mistral_tokenized_obj.text)
+
+     def encode_plain_text(self, text: str) -> TokenizedContainer:
+         choice_tokens = self.tokenizer.tokenizer.encode(text, False, False)
+         return TokenizedContainer(tokens=choice_tokens, text=text)
+
+
+ class MistralVLLM(VLLMModel):
+     def __init__(
+         self,
+         formatter: BaseFormatter | None = None,
+         max_model_len: int | None = None,
+         tensor_parallel_size: int = 1,
+         gpu_memory_utilization: float = 0.9,
+         batch_size: int = 1,
+         checkpoint_path: str | None = None,
+         checkpoint_name: str | None = None,
+         sampling_params: SamplingParams | dict[str, Any] | None = None,
+         **kwargs: Any,
+     ) -> None:
+         model_args = {"tokenizer_mode": "mistral", "config_format": "mistral", "load_format": "mistral"}
+         super().__init__(
+             formatter,
+             max_model_len,
+             tensor_parallel_size,
+             gpu_memory_utilization,
+             batch_size,
+             checkpoint_path,
+             checkpoint_name,
+             sampling_params,
+             **{**model_args, **kwargs},
+         )
+
+     @property
+     @override
+     def tokenizer(self) -> VLLMTokenizerAPI:
+         if self._tokenizer is None:
+             self._tokenizer = MistralAdapter(target_mdl=self.LLM_NAME)
+         return self._tokenizer
+
+     @property
+     def formatter_output_mode(self) -> Literal["string", "list"]:
+         """Determine the correct output mode for the formatter based on the tokenizer type."""
+         return "list"
+
+
+ class MagistralVLLM(MistralVLLM):
+     LLM_NAME = "mistralai/Magistral-Small-2506"
+     DEFAULT_FORMATTER = partial(MagistralFormatter, "mistralai/Magistral-Small-2506")
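
Following the MagistralVLLM pattern above, wiring up a different Mistral checkpoint is a two-line subclass (a sketch; the model name below is illustrative, not one shipped with the package):

    from functools import partial

    from eval_framework.llm.mistral import MistralVLLM
    from template_formatting.mistral_formatter import MagistralFormatter

    class MySmallMistralVLLM(MistralVLLM):
        # Illustrative checkpoint; any weights loadable with load_format="mistral" should fit this pattern.
        LLM_NAME = "mistralai/Mistral-Small-Instruct-2409"
        DEFAULT_FORMATTER = partial(MagistralFormatter, "mistralai/Mistral-Small-Instruct-2409")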
eval_framework/llm/models.py
@@ -0,0 +1,16 @@
+ """This is just a default model file with some small models for testing.
+
+ Please define your own model file externally and pass it to the eval-framework entrypoint
+ to use it.
+ """
+
+ from eval_framework.utils.packaging import is_extra_installed
+
+ if is_extra_installed(extra="transformers"):
+     from eval_framework.llm.huggingface import Pythia410m, Qwen3_0_6B, SmolLM135M, Smollm135MInstruct  # noqa: F401
+
+ if is_extra_installed(extra="mistral"):
+     from eval_framework.llm.mistral import MagistralVLLM  # noqa: F401
+
+ if is_extra_installed(extra="vllm"):
+     from eval_framework.llm.vllm import Qwen3_0_6B_VLLM, Qwen3_0_6B_VLLM_No_Thinking  # noqa: F401
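
The docstring above asks users to supply their own model file. A minimal external model file could look like this (a sketch following the subclass pattern from huggingface.py; the class and checkpoint are illustrative):

    # my_models.py - pass this file to the eval-framework entrypoint
    from eval_framework.llm.huggingface import HFLLM
    from template_formatting.formatter import ConcatFormatter

    class MyPythia(HFLLM):
        LLM_NAME = "EleutherAI/pythia-410m"  # any HF causal LM name works here
        DEFAULT_FORMATTER = ConcatFormatter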
eval_framework/llm/openai.py
@@ -0,0 +1,205 @@
+ import json
+ import logging
+ import os
+ from collections.abc import Callable, Sequence
+ from typing import Any
+
+ import tiktoken  # OpenAI's official tokenizer library
+ from openai import OpenAI
+
+ from eval_framework.llm.base import BaseLLM
+ from eval_framework.shared.types import ConcatCompression, RawCompletion, RawLoglikelihood
+ from eval_framework.tasks.base import Sample
+ from template_formatting.formatter import BaseFormatter, Message, Role
+
+ logger = logging.getLogger(__name__)
+
+
+ class OpenAIModel(BaseLLM):
+     DEFAULT_FORMATTER: Callable[[], BaseFormatter] | None = None
+
+     def __init__(
+         self,
+         model_name: str = "gpt-4o",
+         formatter: BaseFormatter | None = None,
+         temperature: float | None = None,
+         api_key: str | None = None,
+         organization: str | None = None,
+         base_url: str | None = None,
+     ) -> None:
+         """Initialize the OpenAI API client.
+
+         Args:
+             model_name: Name of the OpenAI model to use (e.g., "gpt-4", "gpt-3.5-turbo")
+             formatter: Optional message formatter
+             temperature: Sampling temperature (0.0 to 2.0)
+             api_key: OpenAI API key (defaults to the OPENAI_API_KEY env variable)
+             organization: Optional organization ID
+             base_url: Optional API base URL for Azure or other endpoints
+         """
+         self._model_name = model_name
+         logger.info(f"Using {model_name} as a judge")
+         # Parenthesized so that an explicitly passed formatter is kept even when no default is set
+         self._formatter = formatter or (self.DEFAULT_FORMATTER() if self.DEFAULT_FORMATTER is not None else None)
+         self._temperature = temperature
+         # Initialize the OpenAI client
+         self._client = OpenAI(
+             api_key=api_key or os.getenv("OPENAI_API_KEY", ""),
+             organization=organization,
+             base_url=base_url,
+         )
+
+         # Initialize the tiktoken tokenizer for the model
+         self._encoding = tiktoken.encoding_for_model(self._model_name)
+
+     def _count_tokens(self, text: str) -> int:
+         """Helper method to count tokens using tiktoken."""
+         return len(self._encoding.encode(text))
+
+     def generate_from_messages(
+         self,
+         messages: list[Sequence[Message]],
+         stop_sequences: list[str] | None = None,
+         max_tokens: int | None = None,
+         temperature: float | None = None,
+     ) -> list[RawCompletion]:
+         """Generate completions from messages.
+
+         Args:
+             messages: list of message sequences
+             stop_sequences: Optional list of stop sequences
+             max_tokens: Optional maximum number of tokens to generate
+             temperature: Optional sampling temperature
+         Returns:
+             list of RawCompletion objects
+         """
+         if temperature is None:
+             effective_temperature = 0.0  # Current default, TODO: refactor to use model's default
+             logger.info(
+                 f"Using default temperature value: {effective_temperature} as no custom temperature value was provided"
+             )
+         else:
+             effective_temperature = temperature
+
+         results = []
+         for single_messages in messages:
+             if self._formatter is not None:
+                 # Use the formatter together with the text completion API
+                 prompt = self._formatter.format(single_messages, output_mode="string")
+                 response = self._client.completions.create(
+                     model=self._model_name,
+                     prompt=prompt,
+                     temperature=effective_temperature,
+                     max_tokens=max_tokens,
+                     stop=stop_sequences,
+                 )
+
+                 prompt_sequence_positions: int | None = self._count_tokens(prompt)
+                 completion = response.choices[0].text
+                 completion_sequence_positions = self._count_tokens(completion)
+
+                 results.append(
+                     RawCompletion(
+                         prompt=prompt,
+                         prompt_sequence_positions=prompt_sequence_positions,
+                         concat_compression=ConcatCompression.calculate(
+                             single_messages, count_tokens=self._count_tokens, completion=completion
+                         ),
+                         completion=completion,
+                         completion_sequence_positions=completion_sequence_positions,
+                     )
+                 )
+             else:
+                 # Use the chat completion API
+                 from openai.types.chat import ChatCompletionAssistantMessageParam, ChatCompletionUserMessageParam
+
+                 # Note: all non-user roles (including system) are mapped to assistant messages here
+                 chat_messages = [
+                     (
+                         ChatCompletionUserMessageParam(role="user", content=m.content)
+                         if m.role is not None and m.role.value.lower() == "user"
+                         else ChatCompletionAssistantMessageParam(role="assistant", content=m.content)
+                     )
+                     for m in single_messages
+                 ]
+
+                 chat_response = self._client.chat.completions.create(
+                     model=self._model_name,
+                     messages=chat_messages,
+                     temperature=effective_temperature,
+                     max_tokens=max_tokens,
+                     stop=stop_sequences,
+                 )
+
+                 # Reconstruct the prompt (since the OpenAI API does not return it)
+                 prompt = "\n".join([f"{m['role']}: {m['content']}" for m in chat_messages])
+
+                 prompt_sequence_positions = (
+                     chat_response.usage.prompt_tokens if chat_response.usage else None
+                 )  # the OpenAI API reports the prompt token count
+                 completion = (
+                     chat_response.choices[0].message.content if chat_response.choices[0].message.content else ""
+                 )
+                 completion_sequence_positions = self._count_tokens(completion)
+
+                 results.append(
+                     RawCompletion(
+                         prompt=prompt,
+                         prompt_sequence_positions=prompt_sequence_positions,
+                         concat_compression=ConcatCompression.calculate(
+                             single_messages, count_tokens=self._count_tokens, completion=completion
+                         ),
+                         completion=completion,
+                         completion_sequence_positions=completion_sequence_positions,
+                     )
+                 )
+         return results
+
+     def logprobs(self, samples: list[Sample]) -> list[RawLoglikelihood]:
+         """Get log probabilities for possible completions.
+
+         Args:
+             samples: list of Sample containing possible completions
+         Raises:
+             NotImplementedError: logprobs are not yet implemented for the OpenAI API
+         """
+         raise NotImplementedError("Logprobs not yet implemented for OpenAI API")
+
+     def generate_structured_output(
+         self,
+         messages: list[Sequence[Message]],
+         stop_sequences: list[str] | None = None,
+         max_tokens: int | None = None,
+         temperature: float = 0.0,
+     ) -> Any:
+         """Generate structured output (e.g. JSON) from messages.
+
+         This implementation nudges the model towards valid JSON via a system message
+         and raises if a response cannot be parsed.
+         Args:
+             messages: list of message sequences
+             stop_sequences: Optional stop sequences
+             max_tokens: Optional max tokens
+             temperature: Sampling temperature
+         Returns:
+             list of parsed JSON responses
+         """
+         list_json_messages: list[Sequence[Message]] = []
+         for single_messages in messages:
+             # Add a system message to encourage JSON output
+             json_messages = list(single_messages)
+             if not any(m.role == Role.SYSTEM for m in single_messages):
+                 json_messages.insert(
+                     0,
+                     Message(
+                         role=Role.SYSTEM, content="You are a helpful assistant that always responds with valid JSON."
+                     ),
+                 )
+             list_json_messages.append(json_messages)
+         # Generate completions for all message sequences at once
+         completions = self.generate_from_messages(
+             messages=list_json_messages,
+             stop_sequences=stop_sequences,
+             max_tokens=max_tokens,
+             temperature=temperature,
+         )
+         responses = []
+         for completion in completions:
+             try:
+                 # Parse the JSON response
+                 responses.append(json.loads(completion.completion))
+             except json.JSONDecodeError as e:
+                 logger.info(f"Warning: Failed to parse JSON response: {e}")
+                 logger.info(f"Raw response: {completion.completion}")
+                 raise
+         return responses
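
A short usage sketch of the wrapper above (assumes OPENAI_API_KEY is set and a Role.USER member exists; the prompt is illustrative). With no formatter set, the chat completion branch is used, and generate_structured_output prepends the JSON system message before parsing the reply:

    from eval_framework.llm.openai import OpenAIModel
    from template_formatting.formatter import Message, Role

    model = OpenAIModel(model_name="gpt-4o")  # no formatter -> chat completions API
    parsed = model.generate_structured_output(
        [[Message(role=Role.USER, content='Reply with JSON of the form {"sum": <2 + 3>}.')]],
        max_tokens=50,
    )
    print(parsed[0])  # e.g. {"sum": 5}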