eval-framework 0.2.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (170)
  1. eval_framework/__init__.py +7 -0
  2. eval_framework/base_config.py +36 -0
  3. eval_framework/context/__init__.py +0 -0
  4. eval_framework/context/determined.py +177 -0
  5. eval_framework/context/eval.py +121 -0
  6. eval_framework/context/local.py +78 -0
  7. eval_framework/evaluation_generator.py +234 -0
  8. eval_framework/exceptions.py +2 -0
  9. eval_framework/external/ifeval_impl/README.md +5 -0
  10. eval_framework/external/ifeval_impl/instructions.py +1523 -0
  11. eval_framework/external/ifeval_impl/instructions_registry.py +161 -0
  12. eval_framework/external/ifeval_impl/instructions_util.py +1689 -0
  13. eval_framework/external/ifeval_impl/utils.py +135 -0
  14. eval_framework/llm/__init__.py +0 -0
  15. eval_framework/llm/aleph_alpha.py +432 -0
  16. eval_framework/llm/base.py +180 -0
  17. eval_framework/llm/huggingface.py +418 -0
  18. eval_framework/llm/mistral.py +88 -0
  19. eval_framework/llm/models.py +28 -0
  20. eval_framework/llm/openai.py +400 -0
  21. eval_framework/llm/vllm.py +554 -0
  22. eval_framework/logger.py +3 -0
  23. eval_framework/main.py +166 -0
  24. eval_framework/metrics/__init__.py +0 -0
  25. eval_framework/metrics/base.py +40 -0
  26. eval_framework/metrics/completion/__init__.py +1 -0
  27. eval_framework/metrics/completion/accuracy_completion.py +16 -0
  28. eval_framework/metrics/completion/aidanbench.py +28 -0
  29. eval_framework/metrics/completion/bleu.py +76 -0
  30. eval_framework/metrics/completion/chrf.py +62 -0
  31. eval_framework/metrics/completion/code_assertion.py +44 -0
  32. eval_framework/metrics/completion/code_execution_pass_at_one.py +126 -0
  33. eval_framework/metrics/completion/comet.py +56 -0
  34. eval_framework/metrics/completion/concordance_index.py +38 -0
  35. eval_framework/metrics/completion/csv_format.py +102 -0
  36. eval_framework/metrics/completion/cwe_accuracy.py +49 -0
  37. eval_framework/metrics/completion/exponential_similarity.py +65 -0
  38. eval_framework/metrics/completion/f1.py +42 -0
  39. eval_framework/metrics/completion/format_checker.py +56 -0
  40. eval_framework/metrics/completion/grid_difference.py +77 -0
  41. eval_framework/metrics/completion/ifeval.py +73 -0
  42. eval_framework/metrics/completion/json_format.py +179 -0
  43. eval_framework/metrics/completion/language_checker.py +74 -0
  44. eval_framework/metrics/completion/length_control.py +83 -0
  45. eval_framework/metrics/completion/math_reasoning_completion.py +307 -0
  46. eval_framework/metrics/completion/niah_accuracy.py +163 -0
  47. eval_framework/metrics/completion/placeholder_checker.py +27 -0
  48. eval_framework/metrics/completion/repetition.py +88 -0
  49. eval_framework/metrics/completion/rouge_1.py +35 -0
  50. eval_framework/metrics/completion/rouge_2.py +45 -0
  51. eval_framework/metrics/completion/rouge_geometric_mean.py +36 -0
  52. eval_framework/metrics/completion/rouge_l.py +52 -0
  53. eval_framework/metrics/completion/struct_eval_metrics.py +248 -0
  54. eval_framework/metrics/completion/ter.py +67 -0
  55. eval_framework/metrics/completion/text_counter.py +182 -0
  56. eval_framework/metrics/efficiency/__init__.py +0 -0
  57. eval_framework/metrics/efficiency/bytes_per_sequence_position.py +48 -0
  58. eval_framework/metrics/llm/__init__.py +0 -0
  59. eval_framework/metrics/llm/base.py +34 -0
  60. eval_framework/metrics/llm/graders/chatbot_style_grader.py +92 -0
  61. eval_framework/metrics/llm/graders/coherence_grader.py +115 -0
  62. eval_framework/metrics/llm/graders/comparison_grader.py +198 -0
  63. eval_framework/metrics/llm/graders/conciseness_grader.py +93 -0
  64. eval_framework/metrics/llm/graders/contains_names_grader.py +71 -0
  65. eval_framework/metrics/llm/graders/format_correctness_grader.py +109 -0
  66. eval_framework/metrics/llm/graders/instruction_grader.py +177 -0
  67. eval_framework/metrics/llm/graders/language.py +56 -0
  68. eval_framework/metrics/llm/graders/long_context_grader.py +72 -0
  69. eval_framework/metrics/llm/graders/models.py +74 -0
  70. eval_framework/metrics/llm/graders/refusal_grader.py +57 -0
  71. eval_framework/metrics/llm/graders/sql_quality_grader.py +145 -0
  72. eval_framework/metrics/llm/graders/summary_world_knowledge_grader.py +103 -0
  73. eval_framework/metrics/llm/llm_judge_chatbot_style.py +36 -0
  74. eval_framework/metrics/llm/llm_judge_coherence.py +44 -0
  75. eval_framework/metrics/llm/llm_judge_completion_accuracy.py +39 -0
  76. eval_framework/metrics/llm/llm_judge_conciseness.py +37 -0
  77. eval_framework/metrics/llm/llm_judge_contains_names.py +36 -0
  78. eval_framework/metrics/llm/llm_judge_format_correctness.py +43 -0
  79. eval_framework/metrics/llm/llm_judge_instruction.py +58 -0
  80. eval_framework/metrics/llm/llm_judge_mtbench_pair.py +306 -0
  81. eval_framework/metrics/llm/llm_judge_mtbench_single.py +210 -0
  82. eval_framework/metrics/llm/llm_judge_refusal.py +35 -0
  83. eval_framework/metrics/llm/llm_judge_sql.py +394 -0
  84. eval_framework/metrics/llm/llm_judge_world_knowledge.py +37 -0
  85. eval_framework/metrics/llm/utils.py +20 -0
  86. eval_framework/metrics/loglikelihood/__init__.py +0 -0
  87. eval_framework/metrics/loglikelihood/accuracy_loglikelihood.py +51 -0
  88. eval_framework/metrics/loglikelihood/base.py +50 -0
  89. eval_framework/metrics/loglikelihood/confidence_weighted_accuracy.py +25 -0
  90. eval_framework/metrics/loglikelihood/dcs.py +43 -0
  91. eval_framework/metrics/loglikelihood/probability_mass.py +53 -0
  92. eval_framework/metrics/loglikelihood/ternary.py +42 -0
  93. eval_framework/py.typed +0 -0
  94. eval_framework/response_generator.py +351 -0
  95. eval_framework/result_processors/__init__.py +0 -0
  96. eval_framework/result_processors/base.py +88 -0
  97. eval_framework/result_processors/hf_uploader.py +75 -0
  98. eval_framework/result_processors/result_processor.py +129 -0
  99. eval_framework/result_processors/wandb_uploader.py +137 -0
  100. eval_framework/run.py +369 -0
  101. eval_framework/run_direct.py +42 -0
  102. eval_framework/shared/types.py +227 -0
  103. eval_framework/tasks/__init__.py +6 -0
  104. eval_framework/tasks/base.py +392 -0
  105. eval_framework/tasks/benchmarks/__init__.py +0 -0
  106. eval_framework/tasks/benchmarks/aidanbench.py +211 -0
  107. eval_framework/tasks/benchmarks/arc.py +70 -0
  108. eval_framework/tasks/benchmarks/arc_de.py +46 -0
  109. eval_framework/tasks/benchmarks/arc_fi.py +46 -0
  110. eval_framework/tasks/benchmarks/belebele.py +60 -0
  111. eval_framework/tasks/benchmarks/bigcodebench.py +155 -0
  112. eval_framework/tasks/benchmarks/casehold.py +47 -0
  113. eval_framework/tasks/benchmarks/chembench.py +85 -0
  114. eval_framework/tasks/benchmarks/copa.py +64 -0
  115. eval_framework/tasks/benchmarks/duc.py +91 -0
  116. eval_framework/tasks/benchmarks/flores200.py +133 -0
  117. eval_framework/tasks/benchmarks/flores_plus.py +84 -0
  118. eval_framework/tasks/benchmarks/gpqa.py +201 -0
  119. eval_framework/tasks/benchmarks/gsm8k.py +150 -0
  120. eval_framework/tasks/benchmarks/hellaswag.py +69 -0
  121. eval_framework/tasks/benchmarks/hellaswag_de.py +52 -0
  122. eval_framework/tasks/benchmarks/humaneval.py +97 -0
  123. eval_framework/tasks/benchmarks/ifeval.py +78 -0
  124. eval_framework/tasks/benchmarks/include.py +119 -0
  125. eval_framework/tasks/benchmarks/infinitebench.py +302 -0
  126. eval_framework/tasks/benchmarks/math_reasoning.py +580 -0
  127. eval_framework/tasks/benchmarks/mbpp.py +192 -0
  128. eval_framework/tasks/benchmarks/mmlu.py +215 -0
  129. eval_framework/tasks/benchmarks/mmlu_de.py +109 -0
  130. eval_framework/tasks/benchmarks/mmlu_pro.py +164 -0
  131. eval_framework/tasks/benchmarks/mmmlu.py +529 -0
  132. eval_framework/tasks/benchmarks/openbookqa.py +85 -0
  133. eval_framework/tasks/benchmarks/opengptx_eu20.py +363 -0
  134. eval_framework/tasks/benchmarks/pawsx.py +65 -0
  135. eval_framework/tasks/benchmarks/piqa.py +64 -0
  136. eval_framework/tasks/benchmarks/quality.py +56 -0
  137. eval_framework/tasks/benchmarks/sciq.py +110 -0
  138. eval_framework/tasks/benchmarks/sphyr.py +79 -0
  139. eval_framework/tasks/benchmarks/squad.py +211 -0
  140. eval_framework/tasks/benchmarks/struct_eval.py +116 -0
  141. eval_framework/tasks/benchmarks/tablebench.py +117 -0
  142. eval_framework/tasks/benchmarks/triviaqa.py +42 -0
  143. eval_framework/tasks/benchmarks/truthfulqa.py +119 -0
  144. eval_framework/tasks/benchmarks/winogender.py +64 -0
  145. eval_framework/tasks/benchmarks/winogrande.py +69 -0
  146. eval_framework/tasks/benchmarks/winox.py +57 -0
  147. eval_framework/tasks/benchmarks/wmt.py +160 -0
  148. eval_framework/tasks/benchmarks/zero_scrolls.py +197 -0
  149. eval_framework/tasks/eval_config.py +136 -0
  150. eval_framework/tasks/perturbation.py +83 -0
  151. eval_framework/tasks/registry.py +186 -0
  152. eval_framework/tasks/task_loader.py +81 -0
  153. eval_framework/tasks/task_names.py +324 -0
  154. eval_framework/tasks/utils.py +584 -0
  155. eval_framework/utils/constants.py +9 -0
  156. eval_framework/utils/file_ops.py +245 -0
  157. eval_framework/utils/generate_task_docs.py +244 -0
  158. eval_framework/utils/helpers.py +32 -0
  159. eval_framework/utils/logging.py +62 -0
  160. eval_framework/utils/packaging.py +52 -0
  161. eval_framework/utils/tqdm_handler.py +14 -0
  162. eval_framework-0.2.7.dist-info/METADATA +548 -0
  163. eval_framework-0.2.7.dist-info/RECORD +170 -0
  164. eval_framework-0.2.7.dist-info/WHEEL +4 -0
  165. eval_framework-0.2.7.dist-info/entry_points.txt +3 -0
  166. template_formatting/README.md +83 -0
  167. template_formatting/__init__.py +0 -0
  168. template_formatting/formatter.py +537 -0
  169. template_formatting/mistral_formatter.py +159 -0
  170. template_formatting/py.typed +0 -0
@@ -0,0 +1,554 @@
1
+ import gc
2
+ import logging
3
+ import math
4
+ import os
5
+ import warnings
6
+ from abc import ABC, abstractmethod
7
+ from collections.abc import Callable, Sequence
8
+ from dataclasses import dataclass
9
+ from functools import partial
10
+ from pathlib import Path
11
+ from typing import Any, Literal, Protocol, cast, override
12
+
13
+ import torch
14
+ from vllm import LLM, SamplingParams
15
+ from vllm.distributed.parallel_state import cleanup_dist_env_and_memory
16
+ from vllm.inputs.data import TokensPrompt
17
+ from vllm.outputs import RequestOutput
18
+ from vllm.transformers_utils.tokenizer import get_tokenizer
19
+
20
+ from eval_framework.llm.base import BaseLLM
21
+ from eval_framework.shared.types import (
22
+ ConcatCompression,
23
+ Error,
24
+ PromptTooLongException,
25
+ RawCompletion,
26
+ RawLoglikelihood,
27
+ )
28
+ from eval_framework.tasks.base import Sample
29
+ from eval_framework.tasks.utils import raise_errors
30
+ from eval_framework.utils.constants import RED, RESET
31
+ from template_formatting.formatter import BaseFormatter, HFFormatter, Message
32
+
33
# Module-level logger named after this module (``__name__``).
logger = logging.getLogger(__name__)
34
+
35
+
36
@dataclass
class TokenizedContainer:
    """Pair a prompt's token ids with the exact text they were produced from."""

    # Token ids as returned by the tokenizer.
    tokens: list[int]
    # The (formatted) text that was tokenized.
    text: str
45
+
46
+ class VLLMTokenizerAPI[prompt_type: (list[Message], str)](ABC):
47
+ """
48
+ Protocol for tokenizer interface that defines required methods.
49
+ Needed for type checking because of the vllm tokenizer.
50
+ """
51
+
52
+ @abstractmethod
53
+ def encode_formatted_struct(self, struct: prompt_type) -> TokenizedContainer:
54
+ """Encode prompt to token IDs."""
55
+ pass
56
+
57
+ @abstractmethod
58
+ def encode_plain_text(self, text: str) -> TokenizedContainer:
59
+ pass
60
+
61
+ @property
62
+ def chat_template(self) -> str | None:
63
+ return None
64
+
65
+
66
+ class HFTokenizerProtocol(Protocol):
67
+ def encode(self, text: str, add_special_tokens: bool = False) -> list[int]:
68
+ """Encode text to token IDs."""
69
+ ...
70
+
71
+ def decode(self, tokens: list[int]) -> str:
72
+ """Decode token IDs to text."""
73
+ ...
74
+
75
+ @property
76
+ def chat_template(self) -> str | None:
77
+ """Chat template for the tokenizer."""
78
+ ...
79
+
80
+
81
class VLLMTokenizer(VLLMTokenizerAPI[str]):
    """``VLLMTokenizerAPI`` for string prompts, backed by vLLM's ``get_tokenizer``."""

    def __init__(self, target_mdl: str | Path) -> None:
        """Load the tokenizer for ``target_mdl`` (a model name or a local path)."""
        # Narrow the loosely-typed vLLM tokenizer to the methods this module actually calls.
        self.tokenizer = cast(HFTokenizerProtocol, get_tokenizer(target_mdl))

    def _encode_text(self, text: str) -> TokenizedContainer:
        # Special tokens are never added here; presumably the formatter already emits any
        # BOS/template tokens it needs — TODO confirm against the formatters.
        token_ids = self.tokenizer.encode(text, add_special_tokens=False)
        return TokenizedContainer(tokens=token_ids, text=text)

    def encode_formatted_struct(self, struct: str) -> TokenizedContainer:
        """Tokenize a formatter-produced prompt string."""
        return self._encode_text(struct)

    def encode_plain_text(self, text: str) -> TokenizedContainer:
        """Tokenize raw text such as a candidate completion."""
        return self._encode_text(text)

    def decode(self, tokens: list[int]) -> str:
        """Convert token ids back to text using the wrapped tokenizer."""
        return self.tokenizer.decode(tokens)

    @override
    @property
    def chat_template(self) -> str | None:
        """Chat template reported by the wrapped tokenizer (may be ``None``)."""
        return self.tokenizer.chat_template
102
+
103
+
104
class BaseVLLMModel(BaseLLM):
    """Shared implementation for evaluating models through a vLLM ``LLM`` engine.

    Provides batched text generation (``generate_from_messages``) and batched scoring of
    candidate completions (``logprobs``). Subclasses set ``LLM_NAME`` (model name or path
    used when no ``checkpoint_path`` is given). They may also override ``DEFAULT_FORMATTER``,
    ``SEQ_LENGTH`` and ``BYTES_PER_TOKEN``.
    """

    LLM_NAME: str
    DEFAULT_FORMATTER: Callable[[], BaseFormatter] | None = None
    SEQ_LENGTH: int | None = None
    BYTES_PER_TOKEN: float = 4.0  # rule of thumb according to https://platform.openai.com/tokenizer

    def __init__(
        self,
        formatter: BaseFormatter | None = None,
        max_model_len: int | None = None,
        tensor_parallel_size: int = 1,
        gpu_memory_utilization: float = 0.9,
        batch_size: int = 1,
        checkpoint_path: str | Path | None = None,
        checkpoint_name: str | None = None,
        sampling_params: SamplingParams | dict[str, Any] | None = None,
        bytes_per_token: float | None = None,
        **kwargs: Any,
    ) -> None:
        """Create the vLLM engine and configure sampling defaults, formatter and token scaling.

        Args:
            formatter: Prompt formatter; when ``None`` it is resolved by ``_set_formatter``.
            max_model_len: Context length forwarded to vLLM; falls back to ``SEQ_LENGTH``.
            tensor_parallel_size: Forwarded to vLLM unchanged.
            gpu_memory_utilization: Forwarded to vLLM unchanged.
            batch_size: Forwarded to vLLM as ``max_num_seqs``.
            checkpoint_path: Checkpoint to load instead of ``LLM_NAME``.
            checkpoint_name: Optional label used by the ``name`` property.
            sampling_params: Default sampling parameters, as a ``SamplingParams`` or a dict of its
                keyword arguments; ``None`` uses the model's own defaults.
            bytes_per_token: Average bytes per token of this model's tokenizer. It scales
                ``max_tokens`` relative to the 4-bytes-per-token baseline.
            **kwargs: Forwarded unchanged to ``vllm.LLM``.

        Raises:
            ValueError: If ``bytes_per_token`` is not positive, or no formatter can be resolved.
        """
        # Validate cheap arguments first so a bad value fails fast instead of after the
        # slow, memory-hungry engine load.
        if bytes_per_token is not None and bytes_per_token <= 0:
            raise ValueError("bytes_per_token must be positive")

        # Store the max_model_len for later use
        self._max_model_len = max_model_len
        self.checkpoint_name = checkpoint_name
        self.checkpoint_path = checkpoint_path

        model_args = {
            "model": str(self.checkpoint_path) if self.checkpoint_path else self.LLM_NAME,
            "max_model_len": max_model_len or self.SEQ_LENGTH,
            "max_num_seqs": batch_size,
            "tensor_parallel_size": tensor_parallel_size,
            "gpu_memory_utilization": gpu_memory_utilization,
            **kwargs,
        }

        device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

        self.batch_size = batch_size
        self._tokenizer: None | VLLMTokenizerAPI = None

        self.model = LLM(**model_args, device=device)

        self.sampling_params: SamplingParams = self._process_sampling_params(sampling_params)

        logger.info(
            f"{RED}[ Model initialized ------------------- {RESET}{self.checkpoint_path or self.LLM_NAME} {RED}]{RESET}"
        )
        self._set_formatter(formatter)
        # Scale factor relative to the 4-bytes-per-token baseline, for non-standard tokenizers.
        effective_bytes_per_token = bytes_per_token if bytes_per_token is not None else self.BYTES_PER_TOKEN
        self.bytes_per_token_scalar = 4.0 / effective_bytes_per_token

    def _process_sampling_params(self, sampling_params: SamplingParams | dict[str, Any] | None) -> SamplingParams:
        """Normalize the user-supplied sampling parameters into a ``SamplingParams`` instance.

        A dict is expanded into ``SamplingParams(**dict)``, an existing instance is used as-is,
        and ``None`` falls back to the engine's ``get_default_sampling_params()``.
        """
        if isinstance(sampling_params, dict):
            processed_sampling_params = SamplingParams(**sampling_params)
            logger.info(f"Converted sampling_params dict to SamplingParams: {processed_sampling_params}")
            return processed_sampling_params
        if sampling_params is not None:
            return sampling_params
        return self.model.get_default_sampling_params()

    def _set_formatter(self, formatter: BaseFormatter | None = None) -> None:
        """Select the prompt formatter.

        The order of preference is: the explicit ``formatter`` argument, then the class-level
        ``DEFAULT_FORMATTER`` factory, then an ``HFFormatter`` when the tokenizer ships a chat template.

        Raises:
            ValueError: If none of these options is available.
        """
        if formatter is not None:
            self._formatter = formatter
        elif self.DEFAULT_FORMATTER is not None:
            self._formatter = self.DEFAULT_FORMATTER()
        elif self.tokenizer.chat_template is not None:
            self._formatter = HFFormatter(self.checkpoint_path or self.LLM_NAME)
        else:
            raise ValueError("No formatter specified and no default formatter available.")

        logger.info(
            f"{RED}[ Using default formatter --------------------- {RESET}{self._formatter.__class__.__name__} {RED}]{RESET}"  # noqa: E501
        )

    @property
    def tokenizer(self) -> VLLMTokenizerAPI:
        """Lazily created tokenizer for the checkpoint path (or ``LLM_NAME``)."""
        if self._tokenizer is None:
            self._tokenizer = VLLMTokenizer(target_mdl=self.checkpoint_path or self.LLM_NAME)
        return self._tokenizer

    def count_tokens(self, text: str, /) -> int:
        """Return the number of tokens in ``text`` (no special tokens added)."""
        return len(self.tokenizer.encode_plain_text(text).tokens)

    @property
    def formatter_output_mode(self) -> Literal["string", "list"]:
        """vLLM prompts are tokenized from a single formatted string."""
        return "string"

    @property
    def name(self) -> str:
        """Class name, suffixed with ``_checkpoint_<checkpoint_name>`` when a checkpoint name is set."""
        if self.checkpoint_name:
            return f"{self.__class__.__name__}_checkpoint_{self.checkpoint_name}"
        return self.__class__.__name__

    def build_redis_key_from_prompt_objs(
        self, prompt_objs: list[TokenizedContainer], sampling_params: SamplingParams
    ) -> Any:
        """
        Build a redis key from a list of prompt objects and sampling parameters.
        TokenizedContainers are not serializable so we just pass the tokens and sampling params.
        """
        return ([obj.tokens for obj in prompt_objs], sampling_params)

    def __del__(self) -> None:
        """Shut down the vLLM engine and release distributed state and cached GPU memory.

        Does nothing when ``__init__`` failed before ``self.model`` was assigned.
        """
        if not hasattr(self, "model"):
            return
        # Shut the engine core down explicitly when this vLLM version exposes one.
        if hasattr(self.model, "llm_engine") and hasattr(self.model.llm_engine, "engine_core"):
            self.model.llm_engine.engine_core.shutdown()
        del self.model
        cleanup_dist_env_and_memory()
        gc.collect()
        torch.cuda.empty_cache()

    def generate_from_messages(
        self,
        messages: list[Sequence[Message]],
        stop_sequences: list[str] | None = None,
        max_tokens: int | None = None,
        temperature: float | None = None,
    ) -> list[RawCompletion]:
        """Generate one completion per conversation in ``messages``, as a single vLLM batch.

        Args:
            messages: Conversations to complete; each is rendered with the model's formatter.
            stop_sequences: Stop strings applied to every request (``None`` for none).
            max_tokens: Generation budget before bytes-per-token scaling. ``None`` leaves
                ``max_tokens`` unset in the sampling parameters.
            temperature: Overrides the default sampling temperature when given.

        Returns:
            One ``RawCompletion`` per input, in input order. A prompt that leaves no room for a
            single generated token within ``max_seq_length`` gets an empty completion carrying a
            ``PromptTooLongException`` error.

        Raises:
            PromptTooLongException: For an over-long prompt when ``raise_errors()`` is true.
        """
        raw_completions: list[RawCompletion | None] = [None] * len(messages)
        prompt_objs: list[TokenizedContainer] = []
        prompt_info: list[tuple[int, Sequence[Message]]] = []

        # Adjust max tokens based on bytes_per_token_scalar so that non-standard models generate full responses
        scaled_max_tokens = math.ceil(max_tokens * self.bytes_per_token_scalar) if max_tokens is not None else None

        sampling_params = self._resolve_sampling_params(
            self.sampling_params, scaled_max_tokens, stop_sequences, temperature
        )

        output_mode = self.formatter_output_mode
        for i, single_messages in enumerate(messages):
            prompt: str | list[Message] = self._formatter.format(single_messages, output_mode=output_mode)
            prompt_obj: TokenizedContainer = self.tokenizer.encode_formatted_struct(prompt)
            prompt_token_count = len(prompt_obj.tokens)

            # Check the remaining context budget directly: a budget of exactly 0 must be rejected too.
            # Filtering out falsy values before min() would let a 0 budget slip through, and would
            # crash with an empty min() when no max_tokens is given.
            remaining_context = self.max_seq_length - prompt_token_count
            if remaining_context < 1:
                if raise_errors():
                    raise PromptTooLongException("Prompt exceeded context size.")

                raw_completions[i] = RawCompletion(
                    prompt=prompt_obj.text,
                    prompt_sequence_positions=prompt_token_count,
                    completion="",
                    completion_sequence_positions=0,
                    raw_completion_error=Error(
                        error_class=PromptTooLongException.__name__,
                        message="Prompt exceeded context size.",
                        traceback="",
                    ),
                )
                continue

            prompt_objs.append(prompt_obj)
            prompt_info.append((i, single_messages))

        if prompt_objs:
            model_outputs = self._model_generate(prompt_objs=prompt_objs, sampling_params=sampling_params)

            # strict=True: a count mismatch would otherwise silently leave None entries behind.
            for (original_index, single_messages), prompt_obj, output in zip(
                prompt_info, prompt_objs, model_outputs, strict=True
            ):
                completion = output.outputs[0]
                raw_completions[original_index] = RawCompletion(
                    prompt=prompt_obj.text,
                    prompt_sequence_positions=len(output.prompt_token_ids) if output.prompt_token_ids else 0,
                    concat_compression=ConcatCompression.calculate(
                        single_messages, count_tokens=self.count_tokens, completion=completion.text
                    ),
                    completion=completion.text,
                    completion_sequence_positions=len(completion.token_ids) if completion.token_ids else 0,
                    raw_completion_error=None,
                )

        # Ensure all positions are filled (should never be None at this point)
        return cast(list[RawCompletion], raw_completions)

    @staticmethod
    def _resolve_sampling_params(
        sampling_params: SamplingParams,
        max_tokens: int | None,
        stop_sequences: list[str] | None,
        temperature: float | None,
    ) -> SamplingParams:
        """Return a per-call copy of ``sampling_params`` with the request overrides applied.

        The input is not modified. It is usually ``self.sampling_params``, so mutating it would
        leak one call's ``temperature``/``stop``/``max_tokens`` into every later call.
        """
        resolved = sampling_params.clone()
        resolved.max_tokens = max_tokens
        resolved.stop = stop_sequences
        if temperature is not None:
            logger.warning(
                f"Overriding sampling params temperature {resolved.temperature} with custom value {temperature}"
            )
            resolved.temperature = temperature
        else:
            logger.info(
                f"Using sampling params temperature value: {resolved.temperature} "
                f"as no custom temperature value was provided"
            )
        return resolved

    def _model_generate(
        self,
        prompt_objs: list[TokenizedContainer],
        sampling_params: SamplingParams,
    ) -> list[RequestOutput]:
        """Run vLLM generation on pre-tokenized prompts (one ``RequestOutput`` per prompt)."""
        vllm_token_prompt = [TokensPrompt(prompt_token_ids=prompt_obj.tokens) for prompt_obj in prompt_objs]
        return self.model.generate(vllm_token_prompt, sampling_params)

    def logprobs(self, samples: list[Sample]) -> list[RawLoglikelihood]:
        """Score every possible completion of every sample with one batched vLLM call.

        Returns:
            One ``RawLoglikelihood`` per sample, in input order. If any choice of a sample
            does not fit into ``max_seq_length`` together with the prompt, that sample gets an
            error result with empty loglikelihoods, and none of its choices are scored.

        Raises:
            PromptTooLongException: For an over-long prompt/choice when ``raise_errors()`` is true.
        """
        results: list[RawLoglikelihood | None] = [None] * len(samples)

        # Collect all prompt-choice combinations
        batch_data: list[tuple[TokenizedContainer, TokenizedContainer]] = []
        sample_choice_indices: list[tuple[int, str]] = []  # Maps batch index back to (sample_index, choice)

        output_mode = self.formatter_output_mode
        for sample_idx, sample in enumerate(samples):
            prompt: str | list[Message] = self._formatter.format(sample.messages, output_mode=output_mode)
            prompt_obj: TokenizedContainer = self.tokenizer.encode_formatted_struct(prompt)
            prompt_token_count = len(prompt_obj.tokens)

            # Stage this sample's choices locally and commit them to the batch only when all of them
            # fit. Otherwise the choices before an overflowing one would still be scored and written
            # into this sample's error result.
            pending: list[tuple[TokenizedContainer, str]] = []
            sequence_positions: dict[str, int] = {}
            exceeded_context = False

            for choice in sample.possible_completions or []:
                choice_obj: TokenizedContainer = self.tokenizer.encode_plain_text(choice)
                if prompt_token_count + len(choice_obj.tokens) > self.max_seq_length:
                    if raise_errors():
                        raise PromptTooLongException("Prompt exceeded context size.")
                    exceeded_context = True
                    break
                pending.append((choice_obj, choice))
                sequence_positions[choice] = len(choice_obj.tokens)

            if exceeded_context:
                results[sample_idx] = RawLoglikelihood(
                    prompt=prompt_obj.text,
                    prompt_sequence_positions=prompt_token_count,
                    loglikelihoods={},
                    loglikelihoods_sequence_positions={},
                    raw_loglikelihood_error=Error(
                        error_class=PromptTooLongException.__name__,
                        message="Prompt and choice exceeded context size.",
                        traceback="",
                    ),
                )
                continue

            for choice_obj, choice in pending:
                batch_data.append((prompt_obj, choice_obj))
                sample_choice_indices.append((sample_idx, choice))

            # loglikelihoods starts empty and is filled in after the batched model call below.
            results[sample_idx] = RawLoglikelihood(
                prompt=prompt_obj.text,
                prompt_sequence_positions=prompt_token_count,
                loglikelihoods={},
                loglikelihoods_sequence_positions=sequence_positions,
                raw_loglikelihood_error=None,
                concat_compression=ConcatCompression.calculate(
                    sample.messages, count_tokens=self.count_tokens, choices=[choice for _, choice in pending]
                ),
            )

        # Process batch if we have valid data
        if batch_data:
            batch_logprobs = self._model_log_probs(batch_data)

            # Distribute results back to samples
            for (sample_idx, choice), logprob in zip(sample_choice_indices, batch_logprobs, strict=True):
                result = results[sample_idx]
                if result is not None:
                    result.loglikelihoods[choice] = logprob

        return cast(list[RawLoglikelihood], results)

    def _model_log_probs(self, batch_data: list[tuple[TokenizedContainer, TokenizedContainer]]) -> list[float]:
        """Return, for each (prompt, choice) pair, the summed log-probability of the choice tokens.

        Each pair is submitted as the concatenated token sequence with ``prompt_logprobs=1``.
        The per-position log-probabilities of the trailing choice tokens are then summed.
        """
        sampling_params = SamplingParams(
            max_tokens=1,
            temperature=0.0,
            prompt_logprobs=1,
            detokenize=False,
        )

        vllm_token_prompts = [
            TokensPrompt(prompt_token_ids=prompt_obj.tokens + choice_obj.tokens)
            for prompt_obj, choice_obj in batch_data
        ]

        outputs = self.model.generate(vllm_token_prompts, sampling_params)

        results: list[float] = []
        for (_, choice_obj), output in zip(batch_data, outputs, strict=True):
            assert output.prompt_logprobs is not None

            # Slice from an explicit start index: `[-len(tokens):]` would return the whole list
            # for an empty choice (the result is still 0.0, but only by accident).
            choice_start = len(output.prompt_logprobs) - len(choice_obj.tokens)
            choice_logprobs = output.prompt_logprobs[choice_start:]
            total_logprob = 0.0

            # VLLM guarantees the actual token's logprob is included in the output
            for token_id, logprob_obj in zip(choice_obj.tokens, choice_logprobs, strict=True):
                assert logprob_obj is not None, f"logprob_obj is None: {logprob_obj}"
                logprob_value = logprob_obj[token_id].logprob
                assert logprob_value is not None, f"logprob_value is None: {logprob_value}"
                total_logprob += logprob_value

            results.append(total_logprob)

        return results

    @property
    def max_seq_length(self) -> int:
        """
        Returns the maximum sequence length for this model.
        Priority order:
        1. max_model_len parameter passed to __init__
        2. SEQ_LENGTH class attribute
        3. Model's actual max_model_len from config
        4. Default fallback of 2048
        """
        if self._max_model_len is not None:
            return self._max_model_len

        if self.SEQ_LENGTH is not None:
            return self.SEQ_LENGTH

        if hasattr(self.model, "llm_engine") and hasattr(self.model.llm_engine, "model_config"):
            return self.model.llm_engine.model_config.max_model_len

        return 2048

    @property
    def seq_length(self) -> int | None:
        """
        Kept for backward compatibility.
        """
        return self.max_seq_length
459
+
460
+
461
class VLLMModel(BaseVLLMModel):
    """Build a vLLM model from a checkpoint path, a HuggingFace model name or a Wandb artifact."""

    def __init__(
        self,
        # Model source (3 options: file path, HuggingFace model name, Wandb artifact name):
        checkpoint_path: str | Path | None = None,
        model_name: str | None = None,
        artifact_name: str | None = None,
        # Formatter (2 options):
        formatter: BaseFormatter | None = None,
        formatter_name: str | None = None,
        formatter_kwargs: dict[str, Any] | None = None,
        # Explicit name for the `name` property:
        checkpoint_name: str | None = None,
        # VLLM parameters (not complete):
        max_model_len: int | None = None,
        tensor_parallel_size: int = 1,
        gpu_memory_utilization: float = 0.9,
        batch_size: int = 1,
        sampling_params: SamplingParams | dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> None:
        """Resolve the model source and formatter, then initialize the vLLM engine.

        ``checkpoint_name`` wins over the name derived from the resolved source. The derived
        name has "/" and ":" replaced by "_" and surrounding underscores stripped.
        """
        # NOTE(review): `_get_final_checkpoint` / `_get_final_formatter` are not defined in this
        # module; presumably they are inherited from BaseLLM — TODO confirm.
        resolved_path, candidate_name = self._get_final_checkpoint(checkpoint_path, model_name, artifact_name)

        if resolved_path:
            # An instance attribute shadows any class-level LLM_NAME for this object.
            self.LLM_NAME = str(resolved_path)

        display_name = checkpoint_name
        if display_name is None and candidate_name is not None:
            # sanitize pathname
            display_name = candidate_name.replace("/", "_").replace(":", "_").strip("_")

        resolved_formatter = self._get_final_formatter(formatter, formatter_name, formatter_kwargs)

        super().__init__(
            formatter=resolved_formatter,
            checkpoint_path=resolved_path,
            checkpoint_name=display_name,
            max_model_len=max_model_len,
            tensor_parallel_size=tensor_parallel_size,
            gpu_memory_utilization=gpu_memory_utilization,
            batch_size=batch_size,
            sampling_params=sampling_params,
            **kwargs,
        )
506
+
507
+
508
class VLLMRegistryModel(VLLMModel):  # deprecated
    """
    A class to create VLLM instances from registered models in Wandb registry.
    Downloads the model artifacts from Wandb and creates a local VLLM instance.
    """

    def __init__(
        self,
        artifact_name: str,
        version: str = "latest",
        formatter: str = "",
        formatter_identifier: str = "",
        **kwargs: Any,
    ) -> None:
        """
        Initialize VLLM from a Wandb registered model artifact.

        Args:
            artifact_name: Name of the artifact in the Wandb registry
            version: Version of the artifact to download (default: "latest")
            formatter: Name of the formatter to use, forwarded as ``formatter_name`` (default: "")
            formatter_identifier: If non-empty, forwarded to the formatter as ``hf_llm_name``
                (default: "")
            **kwargs: Additional arguments passed to VLLMModel. A ``download_path`` entry
                (str or path-like) is consumed here. It sets ``WANDB_ARTIFACT_DIR`` unless that
                environment variable is already defined.
        """
        # stacklevel=2 attributes the warning to the code that instantiates this class,
        # not to this line inside the library.
        warnings.warn("`VLLMRegistryModel` is deprecated, please use `VLLMModel`.", DeprecationWarning, stacklevel=2)

        download_path = kwargs.pop("download_path", None)
        if download_path is not None and os.getenv("WANDB_ARTIFACT_DIR") is None:
            # os.environ only accepts str values; download_path may be a pathlib.Path.
            os.environ["WANDB_ARTIFACT_DIR"] = str(download_path)

        super().__init__(
            artifact_name=f"{artifact_name}:{version}",
            formatter_name=formatter,
            formatter_kwargs={"hf_llm_name": formatter_identifier} if formatter_identifier else {},
            checkpoint_name=f"{artifact_name}/{version}",
            **kwargs,
        )
545
+
546
+
547
class Qwen3_0_6B_VLLM(VLLMModel):
    """Qwen3-0.6B served through vLLM, with the chat template's thinking mode enabled."""

    LLM_NAME = "Qwen/Qwen3-0.6B"
    # Zero-argument factory; BaseVLLMModel._set_formatter calls it when no formatter is supplied.
    DEFAULT_FORMATTER = partial(HFFormatter, LLM_NAME, chat_template_kwargs=dict(enable_thinking=True))
550
+
551
+
552
class Qwen3_0_6B_VLLM_No_Thinking(VLLMModel):
    """Qwen3-0.6B served through vLLM, with the chat template's thinking mode disabled."""

    LLM_NAME = "Qwen/Qwen3-0.6B"
    # Zero-argument factory; BaseVLLMModel._set_formatter calls it when no formatter is supplied.
    DEFAULT_FORMATTER = partial(HFFormatter, LLM_NAME, chat_template_kwargs=dict(enable_thinking=False))
@@ -0,0 +1,3 @@
1
+ import logging
2
+
3
# Package-wide logger, looked up under the fixed name "eval_framework".
logger = logging.getLogger("eval_framework")