eval-framework 0.2.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (170)
  1. eval_framework/__init__.py +7 -0
  2. eval_framework/base_config.py +36 -0
  3. eval_framework/context/__init__.py +0 -0
  4. eval_framework/context/determined.py +177 -0
  5. eval_framework/context/eval.py +121 -0
  6. eval_framework/context/local.py +78 -0
  7. eval_framework/evaluation_generator.py +234 -0
  8. eval_framework/exceptions.py +2 -0
  9. eval_framework/external/ifeval_impl/README.md +5 -0
  10. eval_framework/external/ifeval_impl/instructions.py +1523 -0
  11. eval_framework/external/ifeval_impl/instructions_registry.py +161 -0
  12. eval_framework/external/ifeval_impl/instructions_util.py +1689 -0
  13. eval_framework/external/ifeval_impl/utils.py +135 -0
  14. eval_framework/llm/__init__.py +0 -0
  15. eval_framework/llm/aleph_alpha.py +432 -0
  16. eval_framework/llm/base.py +180 -0
  17. eval_framework/llm/huggingface.py +418 -0
  18. eval_framework/llm/mistral.py +88 -0
  19. eval_framework/llm/models.py +28 -0
  20. eval_framework/llm/openai.py +400 -0
  21. eval_framework/llm/vllm.py +554 -0
  22. eval_framework/logger.py +3 -0
  23. eval_framework/main.py +166 -0
  24. eval_framework/metrics/__init__.py +0 -0
  25. eval_framework/metrics/base.py +40 -0
  26. eval_framework/metrics/completion/__init__.py +1 -0
  27. eval_framework/metrics/completion/accuracy_completion.py +16 -0
  28. eval_framework/metrics/completion/aidanbench.py +28 -0
  29. eval_framework/metrics/completion/bleu.py +76 -0
  30. eval_framework/metrics/completion/chrf.py +62 -0
  31. eval_framework/metrics/completion/code_assertion.py +44 -0
  32. eval_framework/metrics/completion/code_execution_pass_at_one.py +126 -0
  33. eval_framework/metrics/completion/comet.py +56 -0
  34. eval_framework/metrics/completion/concordance_index.py +38 -0
  35. eval_framework/metrics/completion/csv_format.py +102 -0
  36. eval_framework/metrics/completion/cwe_accuracy.py +49 -0
  37. eval_framework/metrics/completion/exponential_similarity.py +65 -0
  38. eval_framework/metrics/completion/f1.py +42 -0
  39. eval_framework/metrics/completion/format_checker.py +56 -0
  40. eval_framework/metrics/completion/grid_difference.py +77 -0
  41. eval_framework/metrics/completion/ifeval.py +73 -0
  42. eval_framework/metrics/completion/json_format.py +179 -0
  43. eval_framework/metrics/completion/language_checker.py +74 -0
  44. eval_framework/metrics/completion/length_control.py +83 -0
  45. eval_framework/metrics/completion/math_reasoning_completion.py +307 -0
  46. eval_framework/metrics/completion/niah_accuracy.py +163 -0
  47. eval_framework/metrics/completion/placeholder_checker.py +27 -0
  48. eval_framework/metrics/completion/repetition.py +88 -0
  49. eval_framework/metrics/completion/rouge_1.py +35 -0
  50. eval_framework/metrics/completion/rouge_2.py +45 -0
  51. eval_framework/metrics/completion/rouge_geometric_mean.py +36 -0
  52. eval_framework/metrics/completion/rouge_l.py +52 -0
  53. eval_framework/metrics/completion/struct_eval_metrics.py +248 -0
  54. eval_framework/metrics/completion/ter.py +67 -0
  55. eval_framework/metrics/completion/text_counter.py +182 -0
  56. eval_framework/metrics/efficiency/__init__.py +0 -0
  57. eval_framework/metrics/efficiency/bytes_per_sequence_position.py +48 -0
  58. eval_framework/metrics/llm/__init__.py +0 -0
  59. eval_framework/metrics/llm/base.py +34 -0
  60. eval_framework/metrics/llm/graders/chatbot_style_grader.py +92 -0
  61. eval_framework/metrics/llm/graders/coherence_grader.py +115 -0
  62. eval_framework/metrics/llm/graders/comparison_grader.py +198 -0
  63. eval_framework/metrics/llm/graders/conciseness_grader.py +93 -0
  64. eval_framework/metrics/llm/graders/contains_names_grader.py +71 -0
  65. eval_framework/metrics/llm/graders/format_correctness_grader.py +109 -0
  66. eval_framework/metrics/llm/graders/instruction_grader.py +177 -0
  67. eval_framework/metrics/llm/graders/language.py +56 -0
  68. eval_framework/metrics/llm/graders/long_context_grader.py +72 -0
  69. eval_framework/metrics/llm/graders/models.py +74 -0
  70. eval_framework/metrics/llm/graders/refusal_grader.py +57 -0
  71. eval_framework/metrics/llm/graders/sql_quality_grader.py +145 -0
  72. eval_framework/metrics/llm/graders/summary_world_knowledge_grader.py +103 -0
  73. eval_framework/metrics/llm/llm_judge_chatbot_style.py +36 -0
  74. eval_framework/metrics/llm/llm_judge_coherence.py +44 -0
  75. eval_framework/metrics/llm/llm_judge_completion_accuracy.py +39 -0
  76. eval_framework/metrics/llm/llm_judge_conciseness.py +37 -0
  77. eval_framework/metrics/llm/llm_judge_contains_names.py +36 -0
  78. eval_framework/metrics/llm/llm_judge_format_correctness.py +43 -0
  79. eval_framework/metrics/llm/llm_judge_instruction.py +58 -0
  80. eval_framework/metrics/llm/llm_judge_mtbench_pair.py +306 -0
  81. eval_framework/metrics/llm/llm_judge_mtbench_single.py +210 -0
  82. eval_framework/metrics/llm/llm_judge_refusal.py +35 -0
  83. eval_framework/metrics/llm/llm_judge_sql.py +394 -0
  84. eval_framework/metrics/llm/llm_judge_world_knowledge.py +37 -0
  85. eval_framework/metrics/llm/utils.py +20 -0
  86. eval_framework/metrics/loglikelihood/__init__.py +0 -0
  87. eval_framework/metrics/loglikelihood/accuracy_loglikelihood.py +51 -0
  88. eval_framework/metrics/loglikelihood/base.py +50 -0
  89. eval_framework/metrics/loglikelihood/confidence_weighted_accuracy.py +25 -0
  90. eval_framework/metrics/loglikelihood/dcs.py +43 -0
  91. eval_framework/metrics/loglikelihood/probability_mass.py +53 -0
  92. eval_framework/metrics/loglikelihood/ternary.py +42 -0
  93. eval_framework/py.typed +0 -0
  94. eval_framework/response_generator.py +351 -0
  95. eval_framework/result_processors/__init__.py +0 -0
  96. eval_framework/result_processors/base.py +88 -0
  97. eval_framework/result_processors/hf_uploader.py +75 -0
  98. eval_framework/result_processors/result_processor.py +129 -0
  99. eval_framework/result_processors/wandb_uploader.py +137 -0
  100. eval_framework/run.py +369 -0
  101. eval_framework/run_direct.py +42 -0
  102. eval_framework/shared/types.py +227 -0
  103. eval_framework/tasks/__init__.py +6 -0
  104. eval_framework/tasks/base.py +392 -0
  105. eval_framework/tasks/benchmarks/__init__.py +0 -0
  106. eval_framework/tasks/benchmarks/aidanbench.py +211 -0
  107. eval_framework/tasks/benchmarks/arc.py +70 -0
  108. eval_framework/tasks/benchmarks/arc_de.py +46 -0
  109. eval_framework/tasks/benchmarks/arc_fi.py +46 -0
  110. eval_framework/tasks/benchmarks/belebele.py +60 -0
  111. eval_framework/tasks/benchmarks/bigcodebench.py +155 -0
  112. eval_framework/tasks/benchmarks/casehold.py +47 -0
  113. eval_framework/tasks/benchmarks/chembench.py +85 -0
  114. eval_framework/tasks/benchmarks/copa.py +64 -0
  115. eval_framework/tasks/benchmarks/duc.py +91 -0
  116. eval_framework/tasks/benchmarks/flores200.py +133 -0
  117. eval_framework/tasks/benchmarks/flores_plus.py +84 -0
  118. eval_framework/tasks/benchmarks/gpqa.py +201 -0
  119. eval_framework/tasks/benchmarks/gsm8k.py +150 -0
  120. eval_framework/tasks/benchmarks/hellaswag.py +69 -0
  121. eval_framework/tasks/benchmarks/hellaswag_de.py +52 -0
  122. eval_framework/tasks/benchmarks/humaneval.py +97 -0
  123. eval_framework/tasks/benchmarks/ifeval.py +78 -0
  124. eval_framework/tasks/benchmarks/include.py +119 -0
  125. eval_framework/tasks/benchmarks/infinitebench.py +302 -0
  126. eval_framework/tasks/benchmarks/math_reasoning.py +580 -0
  127. eval_framework/tasks/benchmarks/mbpp.py +192 -0
  128. eval_framework/tasks/benchmarks/mmlu.py +215 -0
  129. eval_framework/tasks/benchmarks/mmlu_de.py +109 -0
  130. eval_framework/tasks/benchmarks/mmlu_pro.py +164 -0
  131. eval_framework/tasks/benchmarks/mmmlu.py +529 -0
  132. eval_framework/tasks/benchmarks/openbookqa.py +85 -0
  133. eval_framework/tasks/benchmarks/opengptx_eu20.py +363 -0
  134. eval_framework/tasks/benchmarks/pawsx.py +65 -0
  135. eval_framework/tasks/benchmarks/piqa.py +64 -0
  136. eval_framework/tasks/benchmarks/quality.py +56 -0
  137. eval_framework/tasks/benchmarks/sciq.py +110 -0
  138. eval_framework/tasks/benchmarks/sphyr.py +79 -0
  139. eval_framework/tasks/benchmarks/squad.py +211 -0
  140. eval_framework/tasks/benchmarks/struct_eval.py +116 -0
  141. eval_framework/tasks/benchmarks/tablebench.py +117 -0
  142. eval_framework/tasks/benchmarks/triviaqa.py +42 -0
  143. eval_framework/tasks/benchmarks/truthfulqa.py +119 -0
  144. eval_framework/tasks/benchmarks/winogender.py +64 -0
  145. eval_framework/tasks/benchmarks/winogrande.py +69 -0
  146. eval_framework/tasks/benchmarks/winox.py +57 -0
  147. eval_framework/tasks/benchmarks/wmt.py +160 -0
  148. eval_framework/tasks/benchmarks/zero_scrolls.py +197 -0
  149. eval_framework/tasks/eval_config.py +136 -0
  150. eval_framework/tasks/perturbation.py +83 -0
  151. eval_framework/tasks/registry.py +186 -0
  152. eval_framework/tasks/task_loader.py +81 -0
  153. eval_framework/tasks/task_names.py +324 -0
  154. eval_framework/tasks/utils.py +584 -0
  155. eval_framework/utils/constants.py +9 -0
  156. eval_framework/utils/file_ops.py +245 -0
  157. eval_framework/utils/generate_task_docs.py +244 -0
  158. eval_framework/utils/helpers.py +32 -0
  159. eval_framework/utils/logging.py +62 -0
  160. eval_framework/utils/packaging.py +52 -0
  161. eval_framework/utils/tqdm_handler.py +14 -0
  162. eval_framework-0.2.7.dist-info/METADATA +548 -0
  163. eval_framework-0.2.7.dist-info/RECORD +170 -0
  164. eval_framework-0.2.7.dist-info/WHEEL +4 -0
  165. eval_framework-0.2.7.dist-info/entry_points.txt +3 -0
  166. template_formatting/README.md +83 -0
  167. template_formatting/__init__.py +0 -0
  168. template_formatting/formatter.py +537 -0
  169. template_formatting/mistral_formatter.py +159 -0
  170. template_formatting/py.typed +0 -0
@@ -0,0 +1,28 @@
1
+ """This is just a default model file with some small models for testing.
2
+
3
+ Please define your own model file externally and pass it to the eval-framework entrypoint
4
+ to use it.
5
+ """
6
+
7
+ from eval_framework.utils.packaging import is_extra_installed
8
+
9
+ if is_extra_installed("api"):
10
+ from eval_framework.llm.aleph_alpha import AlephAlphaAPIModel # noqa F401
11
+
12
+ if is_extra_installed(extra="transformers"):
13
+ from eval_framework.llm.huggingface import ( # noqa F401
14
+ HFLLMRegistryModel,
15
+ Pythia410m,
16
+ SmolLM135M,
17
+ Smollm135MInstruct,
18
+ Qwen3_0_6B,
19
+ )
20
+
21
+ if is_extra_installed("mistral"):
22
+ from eval_framework.llm.mistral import MagistralVLLM # noqa F401
23
+
24
+ if is_extra_installed("openai"):
25
+ from eval_framework.llm.openai import OpenAIModel # noqa F401
26
+
27
+ if is_extra_installed("vllm"):
28
+ from eval_framework.llm.vllm import VLLMRegistryModel, Qwen3_0_6B_VLLM, Qwen3_0_6B_VLLM_No_Thinking # noqa F401
@@ -0,0 +1,400 @@
1
+ import concurrent.futures
2
+ import logging
3
+ import math
4
+ import os
5
+ import traceback
6
+ from collections.abc import Callable, Sequence
7
+ from functools import partial
8
+
9
+ import tiktoken
10
+ from openai import OpenAI
11
+ from openai.types.chat import ChatCompletionAssistantMessageParam, ChatCompletionUserMessageParam
12
+ from tokenizers import Tokenizer
13
+ from transformers import AutoTokenizer
14
+
15
+ from eval_framework.llm.base import BaseLLM
16
+ from eval_framework.shared.types import ConcatCompression, Error, RawCompletion, RawLoglikelihood
17
+ from eval_framework.tasks.base import Sample
18
+ from template_formatting.formatter import BaseFormatter, ConcatFormatter, HFFormatter, Message
19
+
20
# Module-level logger, named after this module; used to report which model is being instantiated.
logger = logging.getLogger(name=__name__)
21
+
22
+
23
class OpenAIModel(BaseLLM):
    """
    LLM wrapper for OpenAI API providing text/chat completions and log-probability evaluation output.

    Subclasses can act as named aliases by overriding ``LLM_NAME`` (and optionally ``DEFAULT_FORMATTER``).
    When a formatter is configured, messages are rendered into a single text prompt and sent to the
    completions endpoint; otherwise they are sent to the chat completions endpoint.
    """

    LLM_NAME: str | None = None
    DEFAULT_FORMATTER: Callable[[], BaseFormatter] | None = None
    BYTES_PER_TOKEN: float = 4.0  # rule of thumb according to https://platform.openai.com/tokenizer
    # tiktoken encoding used for token counting when tiktoken does not know the model name
    # (e.g. models served behind an OpenAI-compatible ``base_url``).
    FALLBACK_ENCODING: str = "cl100k_base"

    def __init__(
        self,
        model_name: str | None = None,
        formatter: BaseFormatter | None = None,
        temperature: float | None = None,
        api_key: str | None = None,
        organization: str | None = None,
        base_url: str | None = None,
        bytes_per_token: float | None = None,
    ) -> None:
        """
        Initialize the OpenAIModel.

        Args:
            model_name: OpenAI model name (e.g., "gpt-4o", "gpt-3.5-turbo"). If None, uses LLM_NAME class attribute.
            formatter: Optional message formatter. If None, ``DEFAULT_FORMATTER()`` is used when defined.
            temperature: Sampling temperature used when not passed to generate methods (from 0.0 to 2.0).
                Defaults to 0.0.
            api_key: OpenAI API key. If None, the OPENAI_API_KEY environment variable is read at construction
                time (an empty string is used when it is unset).
            organization: Optional OpenAI organization ID.
            base_url: Optional API base URL for Azure or alternate endpoints.
            bytes_per_token: Optional custom bytes per token scalar for non-standard models.

        Raises:
            AssertionError: If no model name is available or the temperature is outside [0.0, 2.0].
            ValueError: If ``bytes_per_token`` is not positive.
        """
        assert model_name is not None or self.LLM_NAME is not None, "A model name must be specified."
        self._model_name = model_name if model_name else self.LLM_NAME
        logger.info(f"Instantiating OpenAIModel with name: {self._model_name}")

        self._formatter = formatter or (self.DEFAULT_FORMATTER() if self.DEFAULT_FORMATTER is not None else None)
        self._temperature = temperature if temperature is not None else 0.0
        assert 0.0 <= self._temperature <= 2.0, "Temperature must be between 0.0 and 2.0"

        # Resolve the key when the model is built rather than when the module is imported, so that
        # environment changes made after import (e.g. loading a .env file) are honoured.
        if api_key is None:
            api_key = os.getenv("OPENAI_API_KEY", "")

        self._client = OpenAI(
            api_key=api_key,
            organization=organization,
            base_url=base_url,
        )

        # Initialize tokenizer for the model (subclasses may override _get_encoder).
        self._encoder = self._get_encoder()

        # set bytes_per_token_scalar for non-standard models; max_tokens is multiplied by it so that
        # models with fewer bytes per token still get a comparable amount of generated text.
        if bytes_per_token is not None and bytes_per_token <= 0:
            raise ValueError("bytes_per_token must be positive")
        self.bytes_per_token_scalar = 4.0 / (bytes_per_token if bytes_per_token is not None else self.BYTES_PER_TOKEN)

    def _get_encoder(self) -> tiktoken.Encoding:
        """
        Return the tiktoken encoding for the model.

        Falls back to ``FALLBACK_ENCODING`` (with a warning) when tiktoken cannot map the model name,
        so that models behind alternate OpenAI-compatible endpoints can still be instantiated.
        """
        assert self._model_name is not None
        try:
            return tiktoken.encoding_for_model(self._model_name)
        except KeyError:
            logger.warning(
                f"No tiktoken encoding registered for model '{self._model_name}'; "
                f"falling back to '{self.FALLBACK_ENCODING}' for token counting."
            )
            return tiktoken.get_encoding(self.FALLBACK_ENCODING)

    def _count_tokens(self, text: str) -> int:
        """
        Count tokens for the given text using the encoder.

        Args:
            text: Input string.

        Returns:
            Number of tokens.
        """
        # disallowed_special=() encodes strings such as "<|endoftext|>" as ordinary text instead of
        # raising ValueError; prompts and model outputs can legitimately contain them.
        return len(self._encoder.encode(text, disallowed_special=()))

    def generate_from_messages(
        self,
        messages: list[Sequence[Message]],
        stop_sequences: list[str] | None = None,
        max_tokens: int | None = None,
        temperature: float | None = None,
    ) -> list[RawCompletion]:
        """
        Generate completions for a list of message sequences concurrently.

        Uses text completion API when a formatter is configured, otherwise uses chat completion API.

        Args:
            messages: List of message sequences; one completion is generated per sequence.
            stop_sequences: Optional list of stop sequences.
            max_tokens: Optional maximum number of tokens to generate (scaled by ``bytes_per_token_scalar``).
            temperature: Sampling temperature; defaults to the instance temperature.

        Returns:
            List of RawCompletion objects containing prompts and completions, in input order.

        Raises:
            AssertionError: If the effective temperature is outside [0.0, 2.0].
        """
        from openai.types.chat import ChatCompletionMessageParam, ChatCompletionSystemMessageParam

        effective_temperature = temperature if temperature is not None else self._temperature
        assert 0.0 <= effective_temperature <= 2.0, "Temperature must be between 0.0 and 2.0"
        assert self._model_name is not None
        model_name = self._model_name

        # Adjust max tokens based on bytes_per_token_scalar so that non-standard models generate full
        # responses. The value is the same for every request, so compute it once.
        scaled_max_tokens = math.ceil(max_tokens * self.bytes_per_token_scalar) if max_tokens is not None else None

        def _to_chat_message(message: Message) -> ChatCompletionMessageParam:
            # Map framework roles onto chat roles. Messages without a role, and any role other than
            # user/system, are sent as assistant turns.
            role = message.role.value.lower() if message.role is not None else None
            if role == "user":
                return ChatCompletionUserMessageParam(role="user", content=message.content)
            if role == "system":
                return ChatCompletionSystemMessageParam(role="system", content=message.content)
            return ChatCompletionAssistantMessageParam(role="assistant", content=message.content)

        def _process_one(single_messages: Sequence[Message]) -> RawCompletion:
            if self._formatter is not None:
                # Use formatter and text completion API
                prompt = self._formatter.format(single_messages, output_mode="string")
                # documentation: https://platform.openai.com/docs/api-reference/completions/create
                response = self._client.completions.create(
                    model=model_name,
                    prompt=prompt,
                    temperature=effective_temperature,
                    max_tokens=scaled_max_tokens,
                    stop=stop_sequences,
                )
                completion = response.choices[0].text
                return RawCompletion(
                    prompt=prompt,
                    prompt_sequence_positions=self._count_tokens(prompt),
                    concat_compression=ConcatCompression.calculate(
                        single_messages, count_tokens=self._count_tokens, completion=completion
                    ),
                    completion=completion,
                    completion_sequence_positions=self._count_tokens(completion),
                )

            # Use chat completion API
            chat_messages: list[ChatCompletionMessageParam] = [_to_chat_message(m) for m in single_messages]
            chat_response = self._client.chat.completions.create(
                model=model_name,
                messages=chat_messages,
                temperature=effective_temperature,
                max_tokens=scaled_max_tokens,
                stop=stop_sequences,
            )
            # Human-readable rendering of the chat transcript, stored as the prompt.
            prompt = "\n".join([f"{m.get('role', '')}: {m.get('content', '')}" for m in chat_messages])
            # Prompt length as reported by the API (None if the response carries no usage info).
            prompt_tokens = getattr(chat_response.usage, "prompt_tokens", None)
            completion = chat_response.choices[0].message.content or ""
            return RawCompletion(
                prompt=prompt,
                prompt_sequence_positions=prompt_tokens,
                concat_compression=ConcatCompression.calculate(
                    single_messages, count_tokens=self._count_tokens, completion=completion
                ),
                completion=completion,
                completion_sequence_positions=self._count_tokens(completion),
            )

        # Requests are I/O bound, so issue them concurrently; executor.map preserves input order.
        with concurrent.futures.ThreadPoolExecutor() as executor:
            results = list(executor.map(_process_one, messages))
        return results

    def logprobs(self, samples: list[Sample]) -> list[RawLoglikelihood]:
        """
        Compute total log-probabilities for possible completions given each sample's prompt.

        Args:
            samples: List of Sample objects, each with prompt messages and possible completions.

        Returns:
            List of RawLoglikelihood objects mapping each prompt and completion to its log-probability.
            If any choice of a sample fails, that sample is reported with ``raw_loglikelihood_error`` set
            and empty log-likelihood dicts.

        Raises:
            AssertionError: If the model does not support prompt log-probs or no formatter is configured.

        Note:
            Uses the OpenAI completions API with echo=True; chat logprobs are not supported.
        """
        assert self._model_name in ["babbage-002", "davinci-002"], (
            "Log-probs for prompt tokens are only supported for a limited set of models."
        )
        # apparently OpenAI stopped providing logprobs of prompt tokens, see discussion in:
        # https://github.com/EleutherAI/lm-evaluation-harness/issues/1196

        assert self._formatter is not None, "Log-probs require a formatter to create text prompts."
        results: list[RawLoglikelihood] = []
        for sample in samples:
            prompt = self._formatter.format(sample.messages, output_mode="string") if sample.messages else ""
            # The prompt tokenization is shared by every candidate completion, so compute it once.
            num_prompt_tokens = len(self._encoder.encode(prompt, disallowed_special=()))
            choices_log_probs: dict[str, float] = {}
            choices_sequence_positions: dict[str, int] = {}
            prompt_sequence_positions: int | None = self._count_tokens(prompt)
            error: Error | None = None

            for choice in sample.possible_completions or []:
                num_completion_tokens = len(self._encoder.encode(choice, disallowed_special=()))
                full_text = prompt + choice

                try:
                    # max_tokens=0 with echo=True scores the given text without generating anything new.
                    response = self._client.completions.create(
                        model=self._model_name,
                        prompt=full_text,
                        echo=True,
                        max_tokens=0,
                        logprobs=1,
                        temperature=0,
                    )

                    choice_obj = response.choices[0]
                    if not hasattr(choice_obj, "logprobs") or choice_obj.logprobs is None:
                        raise ValueError("Logprobs not returned in response.")

                    all_tokens = getattr(choice_obj.logprobs, "tokens", None)
                    all_logprobs = getattr(choice_obj.logprobs, "token_logprobs", None)

                    if all_tokens is None or all_logprobs is None:
                        raise ValueError("Logprobs response missing expected 'tokens' or 'token_logprobs' fields.")

                    # The completion slice is located by token count, which is only valid if the server
                    # tokenized prompt+choice exactly as prompt and choice tokenize separately.
                    if len(all_tokens) != num_prompt_tokens + num_completion_tokens:
                        raise ValueError(
                            f"Token count mismatch: tokens in response ({len(all_tokens)}) != prompt+completion "
                            f"({num_prompt_tokens + num_completion_tokens})"
                        )

                    completion_logprobs = all_logprobs[num_prompt_tokens:]
                    if any(logprob is None for logprob in completion_logprobs):
                        # NOTE(review): presumably the first echoed token has no logprob (no preceding
                        # context), which lands in the completion slice when the prompt is empty.
                        raise ValueError("Log-prob missing for a completion token (is the prompt empty?).")

                    # Sum logprobs for the completion portion
                    choices_log_probs[choice] = sum(completion_logprobs)
                    choices_sequence_positions[choice] = num_completion_tokens

                except Exception as e:
                    error = Error(error_class=e.__class__.__name__, message=str(e), traceback=traceback.format_exc())
                    prompt_sequence_positions = None
                    choices_log_probs = {}
                    choices_sequence_positions = {}
                    # One failed choice invalidates the whole sample; skip the remaining choices.
                    break

            results.append(
                RawLoglikelihood(
                    prompt=prompt,
                    prompt_sequence_positions=prompt_sequence_positions,
                    loglikelihoods=choices_log_probs,
                    loglikelihoods_sequence_positions=choices_sequence_positions,
                    raw_loglikelihood_error=error,
                )
            )
        return results

    def __del__(self) -> None:
        # Best-effort cleanup: __del__ can run during interpreter shutdown or after a failed __init__,
        # when closing the client may fail; there is nothing useful to do with such an error here.
        if hasattr(self, "_client"):
            try:
                self._client.close()
            except Exception:
                pass
268
+
269
+
270
class OpenAIEmbeddingModel(BaseLLM):
    """
    Embedding-only wrapper around the OpenAI embeddings API.

    Text generation and log-probabilities are not supported; use ``generate_embeddings``.
    """

    def __init__(
        self,
        model_name: str = "text-embedding-3-large",
        formatter: BaseFormatter | None = None,
        api_key: str | None = None,
        organization: str | None = None,
        base_url: str | None = None,
    ) -> None:
        """Initialize OpenAI API client.

        Args:
            model_name: Name of the OpenAI model to use (e.g., "text-embedding-3-large")
            formatter: Must be None; formatters are not supported for embedding models.
            api_key: OpenAI API key (defaults to OPENAI_API_KEY env variable, or "" if unset)
            organization: Optional organization ID
            base_url: Optional API base URL for Azure or other endpoints

        Raises:
            ValueError: If a formatter is passed.
        """
        if formatter is not None:
            raise ValueError("Formatter is not supported for embedding model.")
        self._model_name = model_name
        logger.info(f"Using {model_name} as embedding model")
        self._client = OpenAI(
            api_key=api_key or os.getenv("OPENAI_API_KEY", ""),
            organization=organization,
            base_url=base_url,
        )

    def generate_from_messages(
        self,
        messages: list[Sequence[Message]],
        stop_sequences: list[str] | None = None,
        max_tokens: int | None = None,
        temperature: float | None = None,
    ) -> list[RawCompletion]:
        """Not supported for embedding models.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(
            "Embedding model does not support generate_from_messages. Use generate_embeddings instead."
        )

    def generate_embeddings(
        self,
        messages: list[Sequence[Message]],
    ) -> list[list[float]]:
        """
        Embed each message sequence with one API request per sequence.

        The contents of a sequence are concatenated without any separator before embedding.

        Args:
            messages: List of message sequences.

        Returns:
            One embedding vector per input sequence, in input order.
        """
        embeddings: list[list[float]] = []
        for single_messages in messages:
            prompt = "".join([m.content for m in single_messages])
            response = self._client.embeddings.create(model=self._model_name, input=[prompt])
            embeddings.append(response.data[0].embedding)
        return embeddings

    def logprobs(self, samples: list[Sample]) -> list[RawLoglikelihood]:
        """Not supported for embedding models.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("Embedding model cannot return logprobs.")

    def __del__(self) -> None:
        # Close the client exactly once, best-effort: __del__ can run during interpreter shutdown or
        # after a failed __init__ (hence the hasattr check), when closing may fail.
        if hasattr(self, "_client"):
            try:
                self._client.close()
            except Exception:
                pass
330
+
331
+
332
class DeepseekModel(OpenAIModel):
    """
    General Deepseek model wrapper using OpenAI-compatible API for deepseek-chat and deepseek-reasoner models.

    Using the deepseek API: https://api-docs.deepseek.com/quick_start/pricing

    Token counting uses a Hugging Face tokenizer (``tokenizer_name``) instead of tiktoken.
    """

    def __init__(
        self,
        model_name: str | None = None,
        formatter: BaseFormatter | None = None,
        temperature: float | None = None,
        api_key: str | None = None,
        organization: str | None = None,
        base_url: str | None = None,
        tokenizer_name: str | None = None,
        bytes_per_token: float | None = None,
    ) -> None:
        """
        Initialize the DeepseekModel.

        Args:
            model_name: Deepseek model name; if None, uses the LLM_NAME class attribute.
            formatter: Optional message formatter (enables the text completion API).
            temperature: Default sampling temperature (0.0 to 2.0).
            api_key: Deepseek API key; defaults to the DEEPSEEK_API_KEY env variable ("" if unset).
            organization: Optional organization ID.
            base_url: API base URL; defaults to "https://api.deepseek.com/beta".
            tokenizer_name: Hugging Face tokenizer used for token counting;
                defaults to "deepseek-ai/DeepSeek-V3.2-Exp".
            bytes_per_token: Optional custom bytes per token scalar, forwarded to OpenAIModel.
        """
        # Must be set BEFORE the parent constructor runs: OpenAIModel.__init__ calls self._get_encoder(),
        # which (overridden below) reads this attribute.
        self._tokenizer_name = tokenizer_name if tokenizer_name is not None else "deepseek-ai/DeepSeek-V3.2-Exp"
        super().__init__(
            model_name=model_name,
            formatter=formatter,
            temperature=temperature,
            api_key=api_key if api_key is not None else os.getenv("DEEPSEEK_API_KEY", ""),
            organization=organization,
            base_url=base_url if base_url is not None else "https://api.deepseek.com/beta",
            bytes_per_token=bytes_per_token,
        )

    def _get_encoder(self) -> Tokenizer:
        """Load the Hugging Face tokenizer named by ``self._tokenizer_name``."""
        return AutoTokenizer.from_pretrained(self._tokenizer_name)

    def _count_tokens(self, text: str) -> int:
        """Count the tokens of ``text`` itself, without the special (e.g. BOS) tokens HF adds by default."""
        return len(self._encoder.encode(text, add_special_tokens=False))
364
+
365
+
366
+ ### Model Aliases ###
367
+
368
+
369
class OpenAI_gpt_4o_mini(OpenAIModel):
    """Alias pinned to the ``gpt-4o-mini-2024-07-18`` snapshot; no default formatter, so the chat API is used."""

    LLM_NAME = "gpt-4o-mini-2024-07-18"
371
+
372
+
373
class OpenAI_gpt_4o_mini_with_ConcatFormatter(OpenAIModel):
    """Alias for ``gpt-4o-mini-2024-07-18`` that renders prompts with ConcatFormatter (text completion API)."""

    DEFAULT_FORMATTER = ConcatFormatter
    LLM_NAME = "gpt-4o-mini-2024-07-18"
376
+
377
+
378
class OpenAI_davinci_002(OpenAIModel):
    """Alias for ``davinci-002`` with ConcatFormatter; one of the models accepted by ``logprobs``."""

    DEFAULT_FORMATTER = ConcatFormatter
    LLM_NAME = "davinci-002"
381
+
382
+
383
class Deepseek_reasoner(DeepseekModel):
    """
    Alias for ``deepseek-reasoner`` (DeepSeek-V3.2-Exp, thinking mode).

    No default formatter is set, so the chat completion API is used; this model does not support the
    completion API. Multi-round conversations for the reasoning model are documented at
    https://api-docs.deepseek.com/guides/reasoning_model#api-example
    """

    LLM_NAME = "deepseek-reasoner"
388
+
389
+
390
class Deepseek_chat(DeepseekModel):
    """Alias for ``deepseek-chat`` (DeepSeek-V3.2-Exp, non-thinking mode) via the chat completion API."""

    LLM_NAME = "deepseek-chat"
392
+
393
+
394
class Deepseek_chat_with_formatter(DeepseekModel):
    """
    Alias for ``deepseek-chat`` (DeepSeek-V3.2-Exp, non-thinking mode) that renders prompts with the
    model's Hugging Face chat template and uses the text completion API.
    """

    LLM_NAME = "deepseek-chat"
    # Example of a rendered prompt:
    #   <|begin▁of▁sentence|><|User|>Question: What color is the night sky?
    #   <|Assistant|></think>Answer:
    DEFAULT_FORMATTER = partial(HFFormatter, "deepseek-ai/DeepSeek-V3.2-Exp")