eval-framework 0.2.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- eval_framework/__init__.py +7 -0
- eval_framework/base_config.py +36 -0
- eval_framework/context/__init__.py +0 -0
- eval_framework/context/determined.py +177 -0
- eval_framework/context/eval.py +121 -0
- eval_framework/context/local.py +78 -0
- eval_framework/evaluation_generator.py +234 -0
- eval_framework/exceptions.py +2 -0
- eval_framework/external/ifeval_impl/README.md +5 -0
- eval_framework/external/ifeval_impl/instructions.py +1523 -0
- eval_framework/external/ifeval_impl/instructions_registry.py +161 -0
- eval_framework/external/ifeval_impl/instructions_util.py +1689 -0
- eval_framework/external/ifeval_impl/utils.py +135 -0
- eval_framework/llm/__init__.py +0 -0
- eval_framework/llm/aleph_alpha.py +432 -0
- eval_framework/llm/base.py +180 -0
- eval_framework/llm/huggingface.py +418 -0
- eval_framework/llm/mistral.py +88 -0
- eval_framework/llm/models.py +28 -0
- eval_framework/llm/openai.py +400 -0
- eval_framework/llm/vllm.py +554 -0
- eval_framework/logger.py +3 -0
- eval_framework/main.py +166 -0
- eval_framework/metrics/__init__.py +0 -0
- eval_framework/metrics/base.py +40 -0
- eval_framework/metrics/completion/__init__.py +1 -0
- eval_framework/metrics/completion/accuracy_completion.py +16 -0
- eval_framework/metrics/completion/aidanbench.py +28 -0
- eval_framework/metrics/completion/bleu.py +76 -0
- eval_framework/metrics/completion/chrf.py +62 -0
- eval_framework/metrics/completion/code_assertion.py +44 -0
- eval_framework/metrics/completion/code_execution_pass_at_one.py +126 -0
- eval_framework/metrics/completion/comet.py +56 -0
- eval_framework/metrics/completion/concordance_index.py +38 -0
- eval_framework/metrics/completion/csv_format.py +102 -0
- eval_framework/metrics/completion/cwe_accuracy.py +49 -0
- eval_framework/metrics/completion/exponential_similarity.py +65 -0
- eval_framework/metrics/completion/f1.py +42 -0
- eval_framework/metrics/completion/format_checker.py +56 -0
- eval_framework/metrics/completion/grid_difference.py +77 -0
- eval_framework/metrics/completion/ifeval.py +73 -0
- eval_framework/metrics/completion/json_format.py +179 -0
- eval_framework/metrics/completion/language_checker.py +74 -0
- eval_framework/metrics/completion/length_control.py +83 -0
- eval_framework/metrics/completion/math_reasoning_completion.py +307 -0
- eval_framework/metrics/completion/niah_accuracy.py +163 -0
- eval_framework/metrics/completion/placeholder_checker.py +27 -0
- eval_framework/metrics/completion/repetition.py +88 -0
- eval_framework/metrics/completion/rouge_1.py +35 -0
- eval_framework/metrics/completion/rouge_2.py +45 -0
- eval_framework/metrics/completion/rouge_geometric_mean.py +36 -0
- eval_framework/metrics/completion/rouge_l.py +52 -0
- eval_framework/metrics/completion/struct_eval_metrics.py +248 -0
- eval_framework/metrics/completion/ter.py +67 -0
- eval_framework/metrics/completion/text_counter.py +182 -0
- eval_framework/metrics/efficiency/__init__.py +0 -0
- eval_framework/metrics/efficiency/bytes_per_sequence_position.py +48 -0
- eval_framework/metrics/llm/__init__.py +0 -0
- eval_framework/metrics/llm/base.py +34 -0
- eval_framework/metrics/llm/graders/chatbot_style_grader.py +92 -0
- eval_framework/metrics/llm/graders/coherence_grader.py +115 -0
- eval_framework/metrics/llm/graders/comparison_grader.py +198 -0
- eval_framework/metrics/llm/graders/conciseness_grader.py +93 -0
- eval_framework/metrics/llm/graders/contains_names_grader.py +71 -0
- eval_framework/metrics/llm/graders/format_correctness_grader.py +109 -0
- eval_framework/metrics/llm/graders/instruction_grader.py +177 -0
- eval_framework/metrics/llm/graders/language.py +56 -0
- eval_framework/metrics/llm/graders/long_context_grader.py +72 -0
- eval_framework/metrics/llm/graders/models.py +74 -0
- eval_framework/metrics/llm/graders/refusal_grader.py +57 -0
- eval_framework/metrics/llm/graders/sql_quality_grader.py +145 -0
- eval_framework/metrics/llm/graders/summary_world_knowledge_grader.py +103 -0
- eval_framework/metrics/llm/llm_judge_chatbot_style.py +36 -0
- eval_framework/metrics/llm/llm_judge_coherence.py +44 -0
- eval_framework/metrics/llm/llm_judge_completion_accuracy.py +39 -0
- eval_framework/metrics/llm/llm_judge_conciseness.py +37 -0
- eval_framework/metrics/llm/llm_judge_contains_names.py +36 -0
- eval_framework/metrics/llm/llm_judge_format_correctness.py +43 -0
- eval_framework/metrics/llm/llm_judge_instruction.py +58 -0
- eval_framework/metrics/llm/llm_judge_mtbench_pair.py +306 -0
- eval_framework/metrics/llm/llm_judge_mtbench_single.py +210 -0
- eval_framework/metrics/llm/llm_judge_refusal.py +35 -0
- eval_framework/metrics/llm/llm_judge_sql.py +394 -0
- eval_framework/metrics/llm/llm_judge_world_knowledge.py +37 -0
- eval_framework/metrics/llm/utils.py +20 -0
- eval_framework/metrics/loglikelihood/__init__.py +0 -0
- eval_framework/metrics/loglikelihood/accuracy_loglikelihood.py +51 -0
- eval_framework/metrics/loglikelihood/base.py +50 -0
- eval_framework/metrics/loglikelihood/confidence_weighted_accuracy.py +25 -0
- eval_framework/metrics/loglikelihood/dcs.py +43 -0
- eval_framework/metrics/loglikelihood/probability_mass.py +53 -0
- eval_framework/metrics/loglikelihood/ternary.py +42 -0
- eval_framework/py.typed +0 -0
- eval_framework/response_generator.py +351 -0
- eval_framework/result_processors/__init__.py +0 -0
- eval_framework/result_processors/base.py +88 -0
- eval_framework/result_processors/hf_uploader.py +75 -0
- eval_framework/result_processors/result_processor.py +129 -0
- eval_framework/result_processors/wandb_uploader.py +137 -0
- eval_framework/run.py +369 -0
- eval_framework/run_direct.py +42 -0
- eval_framework/shared/types.py +227 -0
- eval_framework/tasks/__init__.py +6 -0
- eval_framework/tasks/base.py +392 -0
- eval_framework/tasks/benchmarks/__init__.py +0 -0
- eval_framework/tasks/benchmarks/aidanbench.py +211 -0
- eval_framework/tasks/benchmarks/arc.py +70 -0
- eval_framework/tasks/benchmarks/arc_de.py +46 -0
- eval_framework/tasks/benchmarks/arc_fi.py +46 -0
- eval_framework/tasks/benchmarks/belebele.py +60 -0
- eval_framework/tasks/benchmarks/bigcodebench.py +155 -0
- eval_framework/tasks/benchmarks/casehold.py +47 -0
- eval_framework/tasks/benchmarks/chembench.py +85 -0
- eval_framework/tasks/benchmarks/copa.py +64 -0
- eval_framework/tasks/benchmarks/duc.py +91 -0
- eval_framework/tasks/benchmarks/flores200.py +133 -0
- eval_framework/tasks/benchmarks/flores_plus.py +84 -0
- eval_framework/tasks/benchmarks/gpqa.py +201 -0
- eval_framework/tasks/benchmarks/gsm8k.py +150 -0
- eval_framework/tasks/benchmarks/hellaswag.py +69 -0
- eval_framework/tasks/benchmarks/hellaswag_de.py +52 -0
- eval_framework/tasks/benchmarks/humaneval.py +97 -0
- eval_framework/tasks/benchmarks/ifeval.py +78 -0
- eval_framework/tasks/benchmarks/include.py +119 -0
- eval_framework/tasks/benchmarks/infinitebench.py +302 -0
- eval_framework/tasks/benchmarks/math_reasoning.py +580 -0
- eval_framework/tasks/benchmarks/mbpp.py +192 -0
- eval_framework/tasks/benchmarks/mmlu.py +215 -0
- eval_framework/tasks/benchmarks/mmlu_de.py +109 -0
- eval_framework/tasks/benchmarks/mmlu_pro.py +164 -0
- eval_framework/tasks/benchmarks/mmmlu.py +529 -0
- eval_framework/tasks/benchmarks/openbookqa.py +85 -0
- eval_framework/tasks/benchmarks/opengptx_eu20.py +363 -0
- eval_framework/tasks/benchmarks/pawsx.py +65 -0
- eval_framework/tasks/benchmarks/piqa.py +64 -0
- eval_framework/tasks/benchmarks/quality.py +56 -0
- eval_framework/tasks/benchmarks/sciq.py +110 -0
- eval_framework/tasks/benchmarks/sphyr.py +79 -0
- eval_framework/tasks/benchmarks/squad.py +211 -0
- eval_framework/tasks/benchmarks/struct_eval.py +116 -0
- eval_framework/tasks/benchmarks/tablebench.py +117 -0
- eval_framework/tasks/benchmarks/triviaqa.py +42 -0
- eval_framework/tasks/benchmarks/truthfulqa.py +119 -0
- eval_framework/tasks/benchmarks/winogender.py +64 -0
- eval_framework/tasks/benchmarks/winogrande.py +69 -0
- eval_framework/tasks/benchmarks/winox.py +57 -0
- eval_framework/tasks/benchmarks/wmt.py +160 -0
- eval_framework/tasks/benchmarks/zero_scrolls.py +197 -0
- eval_framework/tasks/eval_config.py +136 -0
- eval_framework/tasks/perturbation.py +83 -0
- eval_framework/tasks/registry.py +186 -0
- eval_framework/tasks/task_loader.py +81 -0
- eval_framework/tasks/task_names.py +324 -0
- eval_framework/tasks/utils.py +584 -0
- eval_framework/utils/constants.py +9 -0
- eval_framework/utils/file_ops.py +245 -0
- eval_framework/utils/generate_task_docs.py +244 -0
- eval_framework/utils/helpers.py +32 -0
- eval_framework/utils/logging.py +62 -0
- eval_framework/utils/packaging.py +52 -0
- eval_framework/utils/tqdm_handler.py +14 -0
- eval_framework-0.2.7.dist-info/METADATA +548 -0
- eval_framework-0.2.7.dist-info/RECORD +170 -0
- eval_framework-0.2.7.dist-info/WHEEL +4 -0
- eval_framework-0.2.7.dist-info/entry_points.txt +3 -0
- template_formatting/README.md +83 -0
- template_formatting/__init__.py +0 -0
- template_formatting/formatter.py +537 -0
- template_formatting/mistral_formatter.py +159 -0
- template_formatting/py.typed +0 -0
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
"""This is just a default model file with some small models for testing.
|
|
2
|
+
|
|
3
|
+
Please define your own model file externally and pass it to the eval-framework entrypoint
|
|
4
|
+
to use it.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from eval_framework.utils.packaging import is_extra_installed
|
|
8
|
+
|
|
9
|
+
if is_extra_installed("api"):
|
|
10
|
+
from eval_framework.llm.aleph_alpha import AlephAlphaAPIModel # noqa F401
|
|
11
|
+
|
|
12
|
+
if is_extra_installed(extra="transformers"):
|
|
13
|
+
from eval_framework.llm.huggingface import ( # noqa F401
|
|
14
|
+
HFLLMRegistryModel,
|
|
15
|
+
Pythia410m,
|
|
16
|
+
SmolLM135M,
|
|
17
|
+
Smollm135MInstruct,
|
|
18
|
+
Qwen3_0_6B,
|
|
19
|
+
)
|
|
20
|
+
|
|
21
|
+
if is_extra_installed("mistral"):
|
|
22
|
+
from eval_framework.llm.mistral import MagistralVLLM # noqa F401
|
|
23
|
+
|
|
24
|
+
if is_extra_installed("openai"):
|
|
25
|
+
from eval_framework.llm.openai import OpenAIModel # noqa F401
|
|
26
|
+
|
|
27
|
+
if is_extra_installed("vllm"):
|
|
28
|
+
from eval_framework.llm.vllm import VLLMRegistryModel, Qwen3_0_6B_VLLM, Qwen3_0_6B_VLLM_No_Thinking # noqa F401
|
|
@@ -0,0 +1,400 @@
|
|
|
1
|
+
import concurrent.futures
|
|
2
|
+
import logging
|
|
3
|
+
import math
|
|
4
|
+
import os
|
|
5
|
+
import traceback
|
|
6
|
+
from collections.abc import Callable, Sequence
|
|
7
|
+
from functools import partial
|
|
8
|
+
|
|
9
|
+
import tiktoken
|
|
10
|
+
from openai import OpenAI
|
|
11
|
+
from openai.types.chat import ChatCompletionAssistantMessageParam, ChatCompletionUserMessageParam
|
|
12
|
+
from tokenizers import Tokenizer
|
|
13
|
+
from transformers import AutoTokenizer
|
|
14
|
+
|
|
15
|
+
from eval_framework.llm.base import BaseLLM
|
|
16
|
+
from eval_framework.shared.types import ConcatCompression, Error, RawCompletion, RawLoglikelihood
|
|
17
|
+
from eval_framework.tasks.base import Sample
|
|
18
|
+
from template_formatting.formatter import BaseFormatter, ConcatFormatter, HFFormatter, Message
|
|
19
|
+
|
|
20
|
+
logger = logging.getLogger(__name__)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class OpenAIModel(BaseLLM):
    """
    LLM wrapper for OpenAI API providing text/chat completions and log-probability evaluation output.

    Subclasses may set ``LLM_NAME`` (default model name) and ``DEFAULT_FORMATTER`` (formatter factory) to define
    model aliases. With a formatter, generation uses the text completions API; without one, it uses the chat
    completions API.
    """

    LLM_NAME: str | None = None
    DEFAULT_FORMATTER: Callable[[], BaseFormatter] | None = None
    BYTES_PER_TOKEN: float = 4.0  # rule of thumb according to https://platform.openai.com/tokenizer
    # tiktoken encoding used for token counting when tiktoken cannot map the model name (e.g. custom endpoints).
    FALLBACK_ENCODING: str = "o200k_base"

    def __init__(
        self,
        model_name: str | None = None,
        formatter: BaseFormatter | None = None,
        temperature: float | None = None,
        api_key: str | None = None,
        organization: str | None = None,
        base_url: str | None = None,
        bytes_per_token: float | None = None,
    ) -> None:
        """
        Initialize the OpenAIModel.

        Args:
            model_name: OpenAI model name (e.g., "gpt-4o", "gpt-3.5-turbo"). If None, uses LLM_NAME class attribute.
            formatter: Optional message formatter. If None, DEFAULT_FORMATTER is instantiated when it is set.
            temperature: Sampling temperature used when not passed to generate methods (from 0.0 to 2.0).
                Defaults to 0.0.
            api_key: OpenAI API key. If None, the OPENAI_API_KEY environment variable is read when the instance
                is created (empty string if unset).
            organization: Optional OpenAI organization ID.
            base_url: Optional API base URL for Azure or alternate endpoints.
            bytes_per_token: Optional custom bytes per token scalar for non-standard models. Must be positive.

        Raises:
            ValueError: If bytes_per_token is not positive.
        """
        assert model_name is not None or self.LLM_NAME is not None, "A model name must be specified."
        self._model_name = model_name if model_name else self.LLM_NAME
        logger.info(f"Instantiating OpenAIModel with name: {self._model_name}")

        self._formatter = formatter or (self.DEFAULT_FORMATTER() if self.DEFAULT_FORMATTER is not None else None)
        self._temperature = temperature if temperature is not None else 0.0
        assert 0.0 <= self._temperature <= 2.0, "Temperature must be between 0.0 and 2.0"

        # Resolve the key here: a signature default would be frozen at import time and miss later env changes.
        self._client = OpenAI(
            api_key=api_key if api_key is not None else os.getenv("OPENAI_API_KEY", ""),
            organization=organization,
            base_url=base_url,
        )

        # Initialize tokenizer for the model (subclasses may override _get_encoder / _count_tokens).
        self._encoder = self._get_encoder()

        # set bytes_per_token_scalar for non-standard models
        if bytes_per_token is not None and bytes_per_token <= 0:
            raise ValueError("bytes_per_token must be positive")
        # Ratio of the standard 4 bytes/token to this model's bytes/token; scales max_tokens during generation.
        self.bytes_per_token_scalar = (
            4.0 / bytes_per_token if bytes_per_token is not None else 4.0 / self.BYTES_PER_TOKEN
        )

    def _get_encoder(self) -> tiktoken.Encoding:
        """
        Return the tiktoken encoding for the configured model.

        Falls back to ``FALLBACK_ENCODING`` (with a warning) when tiktoken has no mapping for the model name,
        which is common for models served through alternate OpenAI-compatible endpoints.
        """
        assert self._model_name is not None
        try:
            return tiktoken.encoding_for_model(self._model_name)
        except KeyError:
            logger.warning(
                f"tiktoken cannot map model '{self._model_name}' to an encoding; "
                f"falling back to '{self.FALLBACK_ENCODING}' for token counting."
            )
            return tiktoken.get_encoding(self.FALLBACK_ENCODING)

    def _count_tokens(self, text: str) -> int:
        """
        Count tokens for the given text using the encoder.

        Args:
            text: Input string.

        Returns:
            Number of tokens.
        """
        return len(self._encoder.encode(text))

    def generate_from_messages(
        self,
        messages: list[Sequence[Message]],
        stop_sequences: list[str] | None = None,
        max_tokens: int | None = None,
        temperature: float | None = None,
    ) -> list[RawCompletion]:
        """
        Generate completions for a list of message sequences concurrently.

        Uses text completion API when a formatter is configured, otherwise uses chat completion API.

        Args:
            messages: List of message sequences; one completion is generated per sequence.
            stop_sequences: Optional list of stop sequences.
            max_tokens: Optional maximum number of tokens to generate (scaled by bytes_per_token_scalar).
            temperature: Sampling temperature; falls back to the instance temperature when None.

        Returns:
            List of RawCompletion objects containing prompts and completions, in the order of ``messages``.
        """
        from openai.types.chat import ChatCompletionSystemMessageParam

        effective_temperature = temperature if temperature is not None else self._temperature
        assert 0.0 <= effective_temperature <= 2.0, "Temperature must be between 0.0 and 2.0"

        # Adjust max tokens based on bytes_per_token_scalar so that non-standard models generate full responses.
        # The value is the same for every request, so compute it once.
        scaled_max_tokens = math.ceil(max_tokens * self.bytes_per_token_scalar) if max_tokens is not None else None

        def _to_chat_message(
            m: Message,
        ) -> ChatCompletionSystemMessageParam | ChatCompletionUserMessageParam | ChatCompletionAssistantMessageParam:
            """Map a framework Message to the matching OpenAI chat message param."""
            role = m.role.value.lower() if m.role is not None else None
            if role == "user":
                return ChatCompletionUserMessageParam(role="user", content=m.content)
            if role == "system":
                # System prompts must keep the system role; sending them as assistant turns changes their meaning.
                return ChatCompletionSystemMessageParam(role="system", content=m.content)
            # Assistant turns, and messages without a role, are sent as assistant messages.
            return ChatCompletionAssistantMessageParam(role="assistant", content=m.content)

        def _process_one(single_messages: Sequence[Message]) -> RawCompletion:
            if self._formatter is not None:
                # Use formatter and text completion API
                prompt = self._formatter.format(single_messages, output_mode="string")
                # documentation: https://platform.openai.com/docs/api-reference/completions/create
                assert self._model_name is not None
                response = self._client.completions.create(
                    model=self._model_name,
                    prompt=prompt,
                    temperature=effective_temperature,
                    max_tokens=scaled_max_tokens,
                    stop=stop_sequences,
                )
                completion = response.choices[0].text
                return RawCompletion(
                    prompt=prompt,
                    prompt_sequence_positions=self._count_tokens(prompt),
                    concat_compression=ConcatCompression.calculate(
                        single_messages, count_tokens=self._count_tokens, completion=completion
                    ),
                    completion=completion,
                    completion_sequence_positions=self._count_tokens(completion),
                )

            # Use chat completion API
            chat_messages = [_to_chat_message(m) for m in single_messages]
            assert self._model_name is not None
            chat_response = self._client.chat.completions.create(
                model=self._model_name,
                messages=chat_messages,
                temperature=effective_temperature,
                max_tokens=scaled_max_tokens,
                stop=stop_sequences,
            )
            # Human-readable rendering of the chat request, kept for logging/inspection.
            prompt = "\n".join([f"{m.get('role', '')}: {m.get('content', '')}" for m in chat_messages])
            # usage may be absent, in which case the prompt length is reported as None.
            prompt_tokens = getattr(chat_response.usage, "prompt_tokens", None)
            completion = chat_response.choices[0].message.content or ""
            return RawCompletion(
                prompt=prompt,
                prompt_sequence_positions=prompt_tokens,
                concat_compression=ConcatCompression.calculate(
                    single_messages, count_tokens=self._count_tokens, completion=completion
                ),
                completion=completion,
                completion_sequence_positions=self._count_tokens(completion),
            )

        with concurrent.futures.ThreadPoolExecutor() as executor:
            results = list(executor.map(_process_one, messages))
        return results

    def logprobs(self, samples: list[Sample]) -> list[RawLoglikelihood]:
        """
        Compute total log-probabilities for possible completions given each sample's prompt.

        Args:
            samples: List of Sample objects, each with prompt messages and possible completions.

        Returns:
            List of RawLoglikelihood objects mapping each prompt and completion to its log-probability.
            If any request for a sample fails, that sample's result carries the error and empty dicts.

        Note:
            Uses the OpenAI completions API with echo=True; chat logprobs are not supported.
        """
        assert self._model_name in ["babbage-002", "davinci-002"], (
            "Log-probs for prompt tokens are only supported for a limited set of models."
        )
        # apparently OpenAI stopped providing logprobs of prompt tokens, see discussion in:
        # https://github.com/EleutherAI/lm-evaluation-harness/issues/1196

        assert self._formatter is not None, "Log-probs require a formatter to create text prompts."
        results: list[RawLoglikelihood] = []
        for sample in samples:
            prompt = self._formatter.format(sample.messages, output_mode="string") if sample.messages else ""
            # The prompt tokenization is identical for every choice, so compute it once per sample.
            prompt_tokens = self._encoder.encode(prompt)
            choices_log_probs: dict[str, float] = {}
            choices_sequence_positions: dict[str, int] = {}
            prompt_sequence_positions: int | None = self._count_tokens(prompt)
            error: Error | None = None

            for choice in sample.possible_completions or []:
                completion_tokens = self._encoder.encode(choice)
                full_text = prompt + choice

                try:
                    response = self._client.completions.create(
                        model=self._model_name,
                        prompt=full_text,
                        echo=True,
                        max_tokens=0,
                        logprobs=1,
                        temperature=0,
                    )

                    choice_obj = response.choices[0]
                    if not hasattr(choice_obj, "logprobs") or choice_obj.logprobs is None:
                        raise ValueError("Logprobs not returned in response.")

                    all_tokens = getattr(choice_obj.logprobs, "tokens", None)
                    all_logprobs = getattr(choice_obj.logprobs, "token_logprobs", None)

                    if all_tokens is None or all_logprobs is None:
                        raise ValueError("Logprobs response missing expected 'tokens' or 'token_logprobs' fields.")

                    # Tokenizing prompt and choice separately must line up with the server's joint tokenization,
                    # otherwise the slice below would not isolate the completion.
                    if len(all_tokens) != len(prompt_tokens) + len(completion_tokens):
                        raise ValueError(
                            f"Token count mismatch: tokens in response ({len(all_tokens)}) != prompt+completion "
                            f"({len(prompt_tokens) + len(completion_tokens)})"
                        )

                    completion_logprobs = all_logprobs[len(prompt_tokens) :]
                    if any(lp is None for lp in completion_logprobs):
                        # With echo=True the first token usually has a null logprob (no preceding context); it
                        # only falls into the completion slice when the prompt is empty.
                        raise ValueError("Logprobs missing for some completion tokens (is the prompt empty?).")

                    # Sum logprobs for the completion portion
                    choices_log_probs[choice] = sum(completion_logprobs)
                    choices_sequence_positions[choice] = len(completion_tokens)

                except Exception as e:
                    error = Error(error_class=e.__class__.__name__, message=str(e), traceback=traceback.format_exc())
                    prompt_sequence_positions = None
                    choices_log_probs = {}
                    choices_sequence_positions = {}
                    # One failed choice invalidates the whole sample; skip the remaining choices.
                    break

            results.append(
                RawLoglikelihood(
                    prompt=prompt,
                    prompt_sequence_positions=prompt_sequence_positions,
                    loglikelihoods=choices_log_probs,
                    loglikelihoods_sequence_positions=choices_sequence_positions,
                    raw_loglikelihood_error=error,
                )
            )
        return results

    def __del__(self) -> None:
        """Close the HTTP client on a best-effort basis."""
        if hasattr(self, "_client"):
            try:
                self._client.close()
            except Exception:
                # Finalizers may run during interpreter shutdown; never let cleanup errors escape.
                pass
|
|
268
|
+
|
|
269
|
+
|
|
270
|
+
class OpenAIEmbeddingModel(BaseLLM):
    """
    Embedding-only wrapper around the OpenAI embeddings API.

    Only ``generate_embeddings`` is supported; ``generate_from_messages`` and ``logprobs`` raise
    NotImplementedError.
    """

    def __init__(
        self,
        model_name: str = "text-embedding-3-large",
        formatter: BaseFormatter | None = None,
        api_key: str | None = None,
        organization: str | None = None,
        base_url: str | None = None,
    ) -> None:
        """Initialize OpenAI API client.

        Args:
            model_name: Name of the OpenAI model to use (e.g., "text-embedding-3-large")
            formatter: Must be None; formatters are not supported for embedding models
            api_key: OpenAI API key (falls back to OPENAI_API_KEY env variable when None or empty)
            organization: Optional organization ID
            base_url: Optional API base URL for Azure or other endpoints

        Raises:
            ValueError: If a formatter is passed.
        """
        if formatter is not None:
            raise ValueError("Formatter is not supported for embedding model.")
        self._model_name = model_name
        logger.info(f"Using {model_name} as embedding model")
        self._client = OpenAI(
            api_key=api_key or os.getenv("OPENAI_API_KEY", ""),
            organization=organization,
            base_url=base_url,
        )

    def generate_from_messages(
        self,
        messages: list[Sequence[Message]],
        stop_sequences: list[str] | None = None,
        max_tokens: int | None = None,
        temperature: float | None = None,
    ) -> list[RawCompletion]:
        """Not supported for embedding models; always raises NotImplementedError."""
        raise NotImplementedError(
            "Embedding model does not support generate_from_messages. Use generate_embeddings instead."
        )

    def generate_embeddings(
        self,
        messages: list[Sequence[Message]],
    ) -> list[list[float]]:
        """
        Embed each message sequence.

        Args:
            messages: List of message sequences; the contents of each sequence are concatenated (no separator)
                into a single input string.

        Returns:
            One embedding vector per message sequence, in input order.
        """
        embeddings: list[list[float]] = []
        for single_messages in messages:
            prompt = "".join([m.content for m in single_messages])
            # One request per sequence, so response.data holds exactly one embedding.
            response = self._client.embeddings.create(model=self._model_name, input=[prompt])
            embeddings.append(response.data[0].embedding)
        return embeddings

    def logprobs(self, samples: list[Sample]) -> list[RawLoglikelihood]:
        """Not supported for embedding models; always raises NotImplementedError."""
        raise NotImplementedError("Embedding model cannot return logprobs.")

    def __del__(self) -> None:
        """Close the HTTP client exactly once, on a best-effort basis."""
        if hasattr(self, "_client"):
            try:
                self._client.close()
            except Exception:
                # Finalizers may run during interpreter shutdown; never let cleanup errors escape.
                pass
|
|
330
|
+
|
|
331
|
+
|
|
332
|
+
class DeepseekModel(OpenAIModel):
    """
    General Deepseek model wrapper using OpenAI-compatible API for deepseek-chat and deepseek-reasoner models.

    Using the deepseek API: https://api-docs.deepseek.com/quick_start/pricing

    Token counting uses a Hugging Face tokenizer (``tokenizer_name``) instead of tiktoken.
    """

    DEFAULT_BASE_URL: str = "https://api.deepseek.com/beta"
    DEFAULT_TOKENIZER_NAME: str = "deepseek-ai/DeepSeek-V3.2-Exp"

    def __init__(
        self,
        model_name: str | None = None,
        formatter: BaseFormatter | None = None,
        temperature: float | None = None,
        api_key: str | None = None,
        organization: str | None = None,
        base_url: str | None = None,
        tokenizer_name: str | None = None,
        bytes_per_token: float | None = None,
    ) -> None:
        """
        Initialize the DeepseekModel.

        Args:
            model_name: Deepseek model name. If None, uses LLM_NAME class attribute.
            formatter: Optional message formatter (selects the text completions API when set).
            temperature: Default sampling temperature (from 0.0 to 2.0).
            api_key: DeepSeek API key. If None, the DEEPSEEK_API_KEY environment variable is used
                (empty string if unset).
            organization: Optional organization ID.
            base_url: API base URL. If None, DEFAULT_BASE_URL is used.
            tokenizer_name: Hugging Face tokenizer used for token counting. If None, DEFAULT_TOKENIZER_NAME.
            bytes_per_token: Optional custom bytes per token scalar, forwarded to OpenAIModel.
        """
        # Must be set before super().__init__(), which calls self._get_encoder() and therefore reads it.
        self._tokenizer_name = tokenizer_name if tokenizer_name is not None else self.DEFAULT_TOKENIZER_NAME
        super().__init__(
            model_name=model_name,
            formatter=formatter,
            temperature=temperature,
            api_key=api_key if api_key is not None else os.getenv("DEEPSEEK_API_KEY", ""),
            organization=organization,
            base_url=base_url if base_url is not None else self.DEFAULT_BASE_URL,
            bytes_per_token=bytes_per_token,
        )

    def _get_encoder(self) -> Tokenizer:
        """Load the Hugging Face tokenizer named by ``self._tokenizer_name``."""
        return AutoTokenizer.from_pretrained(self._tokenizer_name)

    def _count_tokens(self, text: str) -> int:
        """
        Count the tokens in ``text``.

        Special tokens the tokenizer would add on its own (such as a BOS token) are excluded, so that counts
        for completions and already-templated prompts reflect only the text itself.
        """
        return len(self._encoder.encode(text, add_special_tokens=False))
|
|
364
|
+
|
|
365
|
+
|
|
366
|
+
### Model Aliases ###
|
|
367
|
+
|
|
368
|
+
|
|
369
|
+
class OpenAI_gpt_4o_mini(OpenAIModel):
    """Alias for the pinned ``gpt-4o-mini-2024-07-18`` snapshot.

    No default formatter is set, so unless one is passed at construction, generation uses the chat completions API.
    """

    LLM_NAME = "gpt-4o-mini-2024-07-18"
|
|
371
|
+
|
|
372
|
+
|
|
373
|
+
class OpenAI_gpt_4o_mini_with_ConcatFormatter(OpenAIModel):
    """Alias for ``gpt-4o-mini-2024-07-18`` that formats prompts with ConcatFormatter by default.

    Because a formatter is configured, generation goes through the text completions endpoint.
    NOTE(review): confirm this model snapshot is served by that endpoint.
    """

    DEFAULT_FORMATTER = ConcatFormatter
    LLM_NAME = "gpt-4o-mini-2024-07-18"
|
|
376
|
+
|
|
377
|
+
|
|
378
|
+
class OpenAI_davinci_002(OpenAIModel):
    """Alias for ``davinci-002`` with ConcatFormatter as the default prompt formatter.

    This is one of the model names for which ``logprobs`` is permitted (prompt echo with log-probabilities).
    """

    DEFAULT_FORMATTER = ConcatFormatter
    LLM_NAME = "davinci-002"
|
|
381
|
+
|
|
382
|
+
|
|
383
|
+
class Deepseek_reasoner(DeepseekModel):
    """Alias for ``deepseek-reasoner`` (DeepSeek-V3.2-Exp, thinking mode).

    The reasoning model does not support the completion API, so no default formatter is set and
    requests go through the chat completions API. Multi-round conversations are documented at
    https://api-docs.deepseek.com/guides/reasoning_model#api-example
    """

    LLM_NAME = "deepseek-reasoner"
|
|
388
|
+
|
|
389
|
+
|
|
390
|
+
class Deepseek_chat(DeepseekModel):
    """Alias for ``deepseek-chat`` (DeepSeek-V3.2-Exp, non-thinking mode), used via the chat completions API."""

    LLM_NAME = "deepseek-chat"
|
|
392
|
+
|
|
393
|
+
|
|
394
|
+
class Deepseek_chat_with_formatter(DeepseekModel):
    """
    ``deepseek-chat`` (DeepSeek-V3.2-Exp, non-thinking mode) driven through the text completions API.

    Prompts are rendered with the model's Hugging Face chat template, for example:

        <|begin▁of▁sentence|><|User|>Question: What color is the night sky?
        <|Assistant|></think>Answer:
    """

    LLM_NAME = "deepseek-chat"  # DeepSeek-V3.2-Exp (Non-thinking Mode)
    DEFAULT_FORMATTER = partial(HFFormatter, "deepseek-ai/DeepSeek-V3.2-Exp")
|