EuroEval 15.2.0 (euroeval-15.2.0-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of EuroEval might be problematic.

Files changed (40)
  1. euroeval/__init__.py +72 -0
  2. euroeval/benchmark_config_factory.py +358 -0
  3. euroeval/benchmark_modules/__init__.py +7 -0
  4. euroeval/benchmark_modules/base.py +354 -0
  5. euroeval/benchmark_modules/fresh.py +286 -0
  6. euroeval/benchmark_modules/hf.py +1185 -0
  7. euroeval/benchmark_modules/litellm.py +905 -0
  8. euroeval/benchmark_modules/vllm.py +1171 -0
  9. euroeval/benchmarker.py +1074 -0
  10. euroeval/callbacks.py +72 -0
  11. euroeval/cli.py +281 -0
  12. euroeval/constants.py +50 -0
  13. euroeval/data_loading.py +96 -0
  14. euroeval/data_models.py +474 -0
  15. euroeval/dataset_configs.py +2001 -0
  16. euroeval/enums.py +144 -0
  17. euroeval/exceptions.py +191 -0
  18. euroeval/finetuning.py +324 -0
  19. euroeval/generation.py +296 -0
  20. euroeval/human_evaluation.py +737 -0
  21. euroeval/languages.py +200 -0
  22. euroeval/model_cache.py +253 -0
  23. euroeval/model_config.py +77 -0
  24. euroeval/model_loading.py +78 -0
  25. euroeval/scores.py +90 -0
  26. euroeval/speed_benchmark.py +124 -0
  27. euroeval/task_utils/__init__.py +1 -0
  28. euroeval/task_utils/multiple_choice_classification.py +176 -0
  29. euroeval/task_utils/question_answering.py +698 -0
  30. euroeval/task_utils/sequence_classification.py +237 -0
  31. euroeval/task_utils/text_to_text.py +150 -0
  32. euroeval/task_utils/token_classification.py +464 -0
  33. euroeval/tasks.py +202 -0
  34. euroeval/types.py +97 -0
  35. euroeval/utils.py +574 -0
  36. euroeval-15.2.0.dist-info/METADATA +234 -0
  37. euroeval-15.2.0.dist-info/RECORD +40 -0
  38. euroeval-15.2.0.dist-info/WHEEL +4 -0
  39. euroeval-15.2.0.dist-info/entry_points.txt +4 -0
  40. euroeval-15.2.0.dist-info/licenses/LICENSE +21 -0
euroeval/benchmark_modules/litellm.py
@@ -0,0 +1,905 @@
+"""Generative models from an inference API, using the LiteLLM framework."""
+
+import collections.abc as c
+import itertools as it
+import json
+import logging
+import os
+import random
+import re
+import typing as t
+from functools import cached_property, partial
+from time import sleep
+
+import litellm
+from datasets import DatasetDict
+from huggingface_hub import HfApi
+from huggingface_hub.errors import (
+    HFValidationError,
+    RepositoryNotFoundError,
+    RevisionNotFoundError,
+)
+from litellm.exceptions import (
+    APIConnectionError,
+    APIError,
+    AuthenticationError,
+    BadRequestError,
+    InternalServerError,
+    NotFoundError,
+    ServiceUnavailableError,
+    Timeout,
+)
+from litellm.types.utils import ModelResponse
+from requests.exceptions import RequestException
+from transformers import Trainer
+
+from ..constants import (
+    MAX_LOGPROBS,
+    REASONING_MAX_TOKENS,
+    TASK_GROUPS_USING_LOGPROBS,
+    TASKS_USING_JSON,
+)
+from ..data_models import BenchmarkConfig, GenerativeModelOutput, ModelConfig, Task
+from ..enums import (
+    BatchingPreference,
+    GenerativeType,
+    InferenceBackend,
+    ModelType,
+    TaskGroup,
+)
+from ..exceptions import (
+    InvalidBenchmark,
+    NeedsAdditionalArgument,
+    NeedsEnvironmentVariable,
+    NeedsExtraInstalled,
+)
+from ..task_utils import (
+    question_answering,
+    sequence_classification,
+    text_to_text,
+    token_classification,
+)
+from ..types import ExtractLabelsFunction
+from ..utils import create_model_cache_dir
+from .base import BenchmarkModule
+from .hf import HuggingFaceEncoderModel, load_hf_model_config, load_tokenizer
+
+logger = logging.getLogger("euroeval")
+
+
+VOCAB_SIZE_MAPPING = {
+    # OpenAI models
+    "(text-)?(ada|babbage|curie|davinci)(-001)?": 50_257,
+    "(code|text)-davinci-00[2-9]": 50_281,
+    "gpt-3.5-turbo(-16k)?(-[0-9]{4})?": 100_256,
+    "gpt-4-(32k)?(-[0-9]{4})?": 100_256,
+    "gpt-4-[0-9]{4}-preview": 100_256,
+    "gpt-4-turbo(-[0-9]{4}-[0-9]{2}-[0-9]{2})?": 100_256,
+    "gpt-4-(vision|turbo)(-preview)?": 100_256,
+    "gpt-3.5-turbo-instruct(-[0-9]{4})?": 100_256,
+    "gpt-4o(-mini)?(-[0-9]{4}-[0-9]{2}-[0-9]{2})?": 200_019,
+    "o[1-9](-mini|-preview)?(-[0-9]{4}-[0-9]{2}-[0-9]{2})?": -1,
+    # Anthropic models
+    "claude-[1-9](-[1-9])?-(opus|sonnet|haiku)-[0-9]{8}": -1,
+}
+
+
+MODEL_MAX_LENGTH_MAPPING = {
+    # OpenAI models
+    "(text-)?(ada|babbage|curie|davinci)(-001)?": 2_050,
+    "text-davinci-00[2-9]": 4_098,
+    "code-davinci-00[1-9]": 8_002,
+    "gpt-3.5-turbo-0613": 4_096,
+    "gpt-3.5-turbo(-[0-9]{4})?": 16_385,
+    "gpt-3.5-turbo-16k(-[0-9]{4})?": 16_384,
+    "gpt-4(-[0-9]{4})?": 8_191,
+    "gpt-4-32k(-[0-9]{4})?": 32_767,
+    "gpt-4-[0-9]{4}-preview": 128_000,
+    "gpt-4-turbo(-[0-9]{4}-[0-9]{2}-[0-9]{2})?": 128_000,
+    "gpt-4-(vision|turbo)(-preview)?": 128_000,
+    "gpt-3.5-turbo-instruct(-[0-9]{4})?": 4_095,
+    "gpt-4o(-mini)?(-[0-9]{4}-[0-9]{2}-[0-9]{2})?": 128_000,
+    "o1-(mini|preview)(-[0-9]{4}-[0-9]{2}-[0-9]{2})?": 128_000,
+    "o1(-[0-9]{4}-[0-9]{2}-[0-9]{2})?": 200_000,
+    "o[2-9](-mini|-preview)?(-[0-9]{4}-[0-9]{2}-[0-9]{2})?": 200_000,
+    # Anthropic models
+    "claude-[1-9](-[1-9])?-(opus|sonnet|haiku)-[0-9]{8}": 200_000,
+}
+
+
+NUM_PARAMS_MAPPING = {
+    # OpenAI models
+    "(text-)?ada(-001)?": 350_000_000,
+    "(text-)?babbage(-001)?": 3_000_000_000,
+    "(text-)?curie(-001)?": 13_000_000_000,
+    "((text|code)-)?davinci(-00[1-9])?": 175_000_000_000,
+    "gpt-(3.5|4)-turbo-((16|32)k)?(-[0-9]{4})?": -1,
+    "gpt-4-[0-9]{4}-preview": -1,
+    "gpt-4-turbo(-[0-9]{4}-[0-9]{2}-[0-9]{2})?": -1,
+    "gpt-4-(vision|turbo)(-preview)?": -1,
+    "gpt-3.5-turbo-instruct(-[0-9]{4})?": -1,
+    "gpt-4o(-[0-9]{4}-[0-9]{2}-[0-9]{2})?": -1,
+    "gpt-4o-mini(-[0-9]{4}-[0-9]{2}-[0-9]{2})?": -1,
+    "o[1-9](-mini|-preview)?(-[0-9]{4}-[0-9]{2}-[0-9]{2})?": -1,
+    # Anthropic models
+    "claude-[1-9](-[1-9])?-(opus|sonnet|haiku)-[0-9]{8}": -1,
+}
+
+
+REASONING_MODELS = ["o[1-9](-mini|-preview)?(-[0-9]{4}-[0-9]{2}-[0-9]{2})?"]
+
+
+class LiteLLMModel(BenchmarkModule):
+    """A generative model from LiteLLM."""
+
+    fresh_model = False
+    batching_preference = BatchingPreference.SINGLE_SAMPLE
+    high_priority = False
+
+    @property
+    def generative_type(self) -> GenerativeType | None:
+        """Get the generative type of the model.
+
+        Returns:
+            The generative type of the model, or None if it has not been set yet.
+        """
+        if re.fullmatch(
+            pattern="|".join(REASONING_MODELS), string=self.model_config.model_id
+        ):
+            return GenerativeType.REASONING
+        else:
+            return GenerativeType.INSTRUCTION_TUNED
+
+    def generate(self, inputs: dict) -> GenerativeModelOutput:
+        """Generate outputs from the model.
+
+        Args:
+            inputs:
+                A batch of inputs to pass through the model.
+
+        Returns:
+            The generated model outputs.
+        """
+        assert "messages" in inputs, "The input must contain a 'messages' key."
+        assert len(inputs["messages"]) == 1, (
+            "API models only support single-sample batching."
+        )
+        messages = inputs["messages"][0]
+
+        generation_kwargs: dict[str, t.Any] = dict(
+            model=self.model_config.model_id,
+            max_completion_tokens=(
+                REASONING_MAX_TOKENS
+                if self.generative_type == GenerativeType.REASONING
+                else self.dataset_config.max_generated_tokens
+            ),
+            stop=[],
+            temperature=0.0,
+            seed=4242,
+            api_key=self.benchmark_config.api_key,
+            api_base=self.benchmark_config.api_base,
+            api_version=self.benchmark_config.api_version,
+        )
+
+        if self.dataset_config.task.task_group in TASK_GROUPS_USING_LOGPROBS:
+            generation_kwargs["logprobs"] = True
+            generation_kwargs["top_logprobs"] = MAX_LOGPROBS
+
+        if self.dataset_config.task in TASKS_USING_JSON:
+            assert "json" in messages[0]["content"].lower(), (
+                "Prompt must contain 'json' for JSON tasks."
+            )
+            generation_kwargs["response_format"] = dict(type="json_object")
+
+        # This drops generation kwargs that are not supported by the model
+        litellm.drop_params = True
+
+        # Extract the generated sequences from the model response. Some APIs cannot
+        # handle using newlines as stop sequences, so we try both.
+        num_attempts = 10
+        for _ in range(num_attempts):
+            try:
+                model_response = litellm.completion(
+                    messages=messages, max_retries=3, **generation_kwargs
+                )
+                break
+            except BadRequestError as e:
+                if "stop_sequences" in str(e).lower():
+                    generation_kwargs["stop"] = None
+                elif "you are not allowed to request logprobs" in str(e).lower():
+                    generation_kwargs.pop("logprobs")
+                    generation_kwargs.pop("top_logprobs")
+                elif (
+                    "'temperature' is not supported with this model." in str(e).lower()
+                ):
+                    generation_kwargs.pop("temperature")
+                else:
+                    raise InvalidBenchmark(
+                        f"Failed to generate text. The error message was: {e}"
+                    )
+            except (
+                Timeout,
+                ServiceUnavailableError,
+                APIConnectionError,
+                InternalServerError,
+            ):
+                logger.debug(
+                    "Service temporarily unavailable. Retrying in 5 seconds..."
+                )
+                sleep(5)
+            except APIError as e:
+                raise InvalidBenchmark(
+                    f"Failed to generate text. The error message was: {e}"
+                )
+            except AuthenticationError:
+                raise NeedsAdditionalArgument(
+                    cli_argument="--api-key",
+                    script_argument="api_key=<your-api-key>",
+                    run_with_cli=self.benchmark_config.run_with_cli,
+                )
+        else:
+            raise InvalidBenchmark(
+                message=f"Failed to generate text, after {num_attempts} attempts."
+            )
+
+        assert isinstance(model_response, ModelResponse)
+        model_response_choices = model_response.choices[0]
+        assert isinstance(model_response_choices, litellm.Choices)
+        generation_output = model_response_choices.message["content"] or ""
+        generation_output = generation_output.strip()
+
+        # Structure the model output as a GenerativeModelOutput object
+        model_output = GenerativeModelOutput(sequences=[generation_output])
+        if hasattr(model_response_choices, "logprobs"):
+            logprobs_list: list[list[tuple[str, float]]] = [
+                [
+                    (top_logprob.token, top_logprob.logprob)
+                    for top_logprob in content.top_logprobs
+                ]
+                for content in model_response_choices.logprobs.content or list()
+            ]
+            model_output.scores = [logprobs_list]
+
+        return model_output
+
+    @cached_property
+    def num_params(self) -> int:
+        """The number of parameters in the model.
+
+        Returns:
+            The number of parameters in the model.
+        """
+        for key, value in NUM_PARAMS_MAPPING.items():
+            if re.fullmatch(pattern=key, string=self.model_config.model_id) is not None:
+                return value
+
+        if self.model_config.model_id.startswith("huggingface/"):
+            model_id = self.model_config.model_id.split(sep="/", maxsplit=1)[-1]
+            if HuggingFaceEncoderModel.model_exists(
+                model_id=model_id, benchmark_config=self.benchmark_config
+            ):
+                hf_config = load_hf_model_config(
+                    model_id=model_id,
+                    num_labels=self.dataset_config.num_labels,
+                    id2label=self.dataset_config.id2label,
+                    label2id=self.dataset_config.label2id,
+                    revision=self.model_config.revision,
+                    model_cache_dir=self.model_config.model_cache_dir,
+                    api_key=self.benchmark_config.api_key,
+                    trust_remote_code=self.benchmark_config.trust_remote_code,
+                    run_with_cli=self.benchmark_config.run_with_cli,
+                )
+
+                hf_api = HfApi()
+                try:
+                    repo_info = hf_api.model_info(
+                        repo_id=model_id,
+                        revision=self.model_config.revision,
+                        token=os.getenv("HUGGINGFACE_API_KEY")
+                        or self.benchmark_config.api_key
+                        or True,
+                    )
+                except (
+                    RepositoryNotFoundError,
+                    RevisionNotFoundError,
+                    RequestException,
+                    HFValidationError,
+                ):
+                    repo_info = None
+
+                if (
+                    repo_info is not None
+                    and hasattr(repo_info, "safetensors")
+                    and repo_info.safetensors is not None
+                    and "total" in repo_info.safetensors
+                ):
+                    return repo_info.safetensors["total"]
+                elif (
+                    hasattr(hf_config, "num_params")
+                    and hf_config.num_params is not None
+                ):
+                    return hf_config.num_params
+
+        return -1
+
+    @cached_property
+    def vocab_size(self) -> int:
+        """The vocabulary size of the model.
+
+        Returns:
+            The vocabulary size of the model.
+        """
+        for key, value in VOCAB_SIZE_MAPPING.items():
+            if re.fullmatch(pattern=key, string=self.model_config.model_id) is not None:
+                return value
+
+        if self.model_config.model_id.startswith("huggingface/"):
+            model_id = self.model_config.model_id.split(sep="/", maxsplit=1)[-1]
+            if HuggingFaceEncoderModel.model_exists(
+                model_id=model_id, benchmark_config=self.benchmark_config
+            ):
+                hf_config = load_hf_model_config(
+                    model_id=model_id,
+                    num_labels=self.dataset_config.num_labels,
+                    id2label=self.dataset_config.id2label,
+                    label2id=self.dataset_config.label2id,
+                    revision=self.model_config.revision,
+                    model_cache_dir=self.model_config.model_cache_dir,
+                    api_key=self.benchmark_config.api_key,
+                    trust_remote_code=self.benchmark_config.trust_remote_code,
+                    run_with_cli=self.benchmark_config.run_with_cli,
+                )
+
+                tokenizer = load_tokenizer(
+                    model=None,
+                    model_id=model_id,
+                    trust_remote_code=self.benchmark_config.trust_remote_code,
+                )
+
+                if (
+                    hasattr(hf_config, "vocab_size")
+                    and hf_config.vocab_size is not None
+                ):
+                    vocab_size = hf_config.vocab_size
+                elif (
+                    hasattr(tokenizer, "vocab_size")
+                    and tokenizer.vocab_size is not None
+                ):
+                    vocab_size = tokenizer.vocab_size
+                else:
+                    vocab_size = -1
+                return vocab_size
+
+        return -1
+
+    @cached_property
+    def model_max_length(self) -> int:
+        """The maximum length of the model.
+
+        Returns:
+            The maximum length of the model.
+        """
+        for key, value in MODEL_MAX_LENGTH_MAPPING.items():
+            if re.fullmatch(pattern=key, string=self.model_config.model_id) is not None:
+                return value
+
+        if self.model_config.model_id.startswith("huggingface/"):
+            model_id = self.model_config.model_id.split(sep="/", maxsplit=1)[-1]
+            if HuggingFaceEncoderModel.model_exists(
+                model_id=model_id, benchmark_config=self.benchmark_config
+            ):
+                hf_config = load_hf_model_config(
+                    model_id=model_id,
+                    num_labels=self.dataset_config.num_labels,
+                    id2label=self.dataset_config.id2label,
+                    label2id=self.dataset_config.label2id,
+                    revision=self.model_config.revision,
+                    model_cache_dir=self.model_config.model_cache_dir,
+                    api_key=self.benchmark_config.api_key,
+                    trust_remote_code=self.benchmark_config.trust_remote_code,
+                    run_with_cli=self.benchmark_config.run_with_cli,
+                )
+
+                tokenizer = load_tokenizer(
+                    model=None,
+                    model_id=model_id,
+                    trust_remote_code=self.benchmark_config.trust_remote_code,
+                )
+
+                all_max_lengths: list[int] = list()
+
+                # Add the registered max length of the tokenizer
+                if hasattr(
+                    tokenizer, "model_max_length"
+                ) and tokenizer.model_max_length < int(1e30):
+                    all_max_lengths.append(tokenizer.model_max_length)
+
+                # Add the max length derived from the model's input sizes
+                if hasattr(tokenizer, "max_model_input_sizes"):
+                    all_max_lengths.extend(
+                        [
+                            size
+                            for size in tokenizer.max_model_input_sizes.values()
+                            if size is not None
+                        ]
+                    )
+
+                # Add max length candidates from the model's configuration
+                candidate_config_max_lengths = [
+                    "max_position_embeddings",
+                    "max_sequence_length",
+                    "model_max_length",
+                    "sliding_window",
+                    "sliding_window_size",
+                    "n_positions",
+                ]
+                for candidate_config_max_length in candidate_config_max_lengths:
+                    if (
+                        hasattr(hf_config, candidate_config_max_length)
+                        and (value := getattr(hf_config, candidate_config_max_length))
+                        is not None
+                    ):
+                        all_max_lengths.append(value)
+
+                # To avoid models having artificially low max lengths, we remove any max
+                # lengths that are less than 128
+                all_max_lengths = [
+                    max_length for max_length in all_max_lengths if max_length >= 128
+                ]
+
+                if len(list(all_max_lengths)) > 0:
+                    return min(list(all_max_lengths))
+
+        return -1
+
+    @property
+    def data_collator(self) -> c.Callable[[list[t.Any]], dict[str, t.Any]]:
+        """The data collator used to prepare samples during finetuning.
+
+        Returns:
+            The data collator.
+        """
+        raise NotImplementedError(
+            "The `data_collator` property has not been implemented for LiteLLM models."
+        )
+
+    @property
+    def extract_labels_from_generation(self) -> ExtractLabelsFunction:
+        """The function used to extract the labels from the generated output.
+
+        Returns:
+            The function used to extract the labels from the generated output.
+        """
+        match self.dataset_config.task.task_group:
+            case (
+                TaskGroup.SEQUENCE_CLASSIFICATION
+                | TaskGroup.MULTIPLE_CHOICE_CLASSIFICATION
+            ):
+                return partial(
+                    sequence_classification.extract_labels_from_generation,
+                    dataset_config=self.dataset_config,
+                )
+            case TaskGroup.TEXT_TO_TEXT:
+                return text_to_text.extract_labels_from_generation
+            case TaskGroup.TOKEN_CLASSIFICATION:
+                return partial(
+                    token_classification.extract_labels_from_generation,
+                    dataset_config=self.dataset_config,
+                )
+            case TaskGroup.QUESTION_ANSWERING:
+                return question_answering.extract_labels_from_generation
+            case _:
+                raise NotImplementedError(
+                    f"Unsupported task group: {self.dataset_config.task.task_group}."
+                )
+
+    @property
+    def trainer_class(self) -> t.Type["Trainer"]:
+        """The Trainer class to use for finetuning.
+
+        Returns:
+            The Trainer class.
+        """
+        raise NotImplementedError(
+            "The `trainer_class` property has not been implemented for LiteLLM models."
+        )
+
+    @classmethod
+    def model_exists(
+        cls, model_id: str, benchmark_config: BenchmarkConfig
+    ) -> bool | NeedsExtraInstalled | NeedsEnvironmentVariable:
+        """Check if a model exists.
+
+        Args:
+            model_id:
+                The model ID.
+            benchmark_config:
+                The benchmark configuration.
+
+        Returns:
+            Whether the model exists, or an error describing why we cannot check
+            whether the model exists.
+        """
+        if model_id in litellm.model_list:
+            return True
+
+        num_attempts = 10
+        for _ in range(num_attempts):
+            try:
+                litellm.completion(
+                    messages=[dict(role="user", content="X")],
+                    model=model_id,
+                    max_tokens=1,
+                    api_key=benchmark_config.api_key,
+                    api_base=benchmark_config.api_base,
+                    api_version=benchmark_config.api_version,
+                )
+                return True
+            except APIError as e:
+                if "'503 Service Unavailable" not in str(e):
+                    raise e
+                logger.warning(
+                    f"Failed to check if model {model_id!r} exists. Retrying in "
+                    f"{num_attempts} seconds..."
+                )
+                sleep(10)
+            except (BadRequestError, NotFoundError):
+                candidate_models = [
+                    candidate_model_id
+                    for candidate_model_id in litellm.model_list
+                    if candidate_model_id.startswith(model_id)
+                ]
+                match len(candidate_models):
+                    case 0:
+                        pass
+                    case 1:
+                        logger.warning(
+                            f"Could not find the model ID {model_id!r}. Did you mean "
+                            f"{candidate_models[0]!r}?"
+                        )
+                    case _:
+                        candidate_models_str = "', '".join(candidate_models)
+                        logger.warning(
+                            f"Could not find the model ID {model_id!r}. Did you mean "
+                            f"any of the following model IDs: '{candidate_models_str}'?"
+                        )
+                return False
+        else:
+            logger.error(
+                f"Failed to check if model {model_id!r} exists after {num_attempts} "
+                "attempts. Assuming it does not exist."
+            )
+            return False
+
+    @classmethod
+    def get_model_config(
+        cls, model_id: str, benchmark_config: BenchmarkConfig
+    ) -> ModelConfig:
+        """Fetch the model configuration.
+
+        Args:
+            model_id:
+                The model ID.
+            benchmark_config:
+                The benchmark configuration.
+
+        Returns:
+            The model configuration.
+        """
+        return ModelConfig(
+            model_id=model_id,
+            revision="main",
+            task="text-generation",
+            languages=list(),
+            merge=False,
+            inference_backend=InferenceBackend.LITELLM,
+            model_type=ModelType.GENERATIVE,
+            fresh=False,
+            model_cache_dir=create_model_cache_dir(
+                cache_dir=benchmark_config.cache_dir, model_id=model_id
+            ),
+            adapter_base_model_id=None,
+        )
+
+    def prepare_dataset(
+        self, dataset: DatasetDict, task: Task, itr_idx: int
+    ) -> DatasetDict:
+        """Prepare the dataset for the model.
+
+        This includes things like tokenisation.
+
+        Args:
+            dataset:
+                The dataset to prepare.
+            task:
+                The task to prepare the dataset for.
+            itr_idx:
+                The index of the dataset in the iterator.
+
+        Returns:
+            The prepared dataset.
+        """
+        if task.task_group == TaskGroup.QUESTION_ANSWERING:
+            dataset = dataset.map(
+                lambda examples: dict(
+                    label=[
+                        dict(
+                            id=id,
+                            answers=dict(
+                                answer_start=answer_dct["answer_start"],
+                                text=[
+                                    answer_text.lower()
+                                    for answer_text in answer_dct["text"]
+                                ],
+                            ),
+                        )
+                        for id, answer_dct in zip(examples["id"], examples["answers"])
+                    ]
+                ),
+                batched=True,
+                load_from_cache_file=False,
+                keep_in_memory=True,
+            )
+
+        if self.benchmark_config.few_shot:
+            few_shot_examples = self._extract_few_shot_examples(
+                dataset=dataset, task=task, itr_idx=itr_idx
+            )
+        else:
+            few_shot_examples = list()
+
+        dataset["test"] = dataset["test"].map(
+            partial(self._apply_prompt, few_shot_examples=few_shot_examples, task=task),
+            batched=True,
+            load_from_cache_file=False,
+            keep_in_memory=True,
+        )
+
+        return dataset
+
+    def _extract_few_shot_examples(
+        self, dataset: DatasetDict, task: Task, itr_idx: int
+    ) -> list[dict[str, t.Any]]:
+        """Extract few-shot examples from a dataset.
+
+        This will always extract the examples from the training split.
+
+        We ensure that the few-shot examples are unique by picking them one at a time.
+
+        Args:
+            dataset:
+                The dataset to extract the few-shot examples from.
+            task:
+                The task that is being benchmarked.
+            itr_idx:
+                The index of the dataset in the iterator.
+
+        Returns:
+            The few-shot examples.
+        """
+        random_seed = 4242 + itr_idx
+        num_few_shots = self.dataset_config.num_few_shot_examples
+        few_shot_examples: list[dict[str, t.Any]] = list()
+        shuffled_train = dataset["train"].shuffle(seed=random_seed)
+
+        match task.task_group:
+            case (
+                TaskGroup.SEQUENCE_CLASSIFICATION
+                | TaskGroup.MULTIPLE_CHOICE_CLASSIFICATION
+            ):
+                labels = it.cycle(self.dataset_config.labels)
+                while (
+                    len(few_shot_examples) < num_few_shots and len(shuffled_train) > 0
+                ):
+                    label = next(labels)
+                    possible_examples = shuffled_train.filter(
+                        lambda x: x["label"].lower() == label.lower()
+                    )
+                    if len(possible_examples) == 0:
+                        continue
+                    example = possible_examples.select(range(1))[0]
+                    few_shot_examples.append(example)
+                    shuffled_train = shuffled_train.filter(
+                        lambda x: x["text"] != example["text"]
+                    )
+
+            case TaskGroup.TEXT_TO_TEXT:
+                while (
+                    len(few_shot_examples) < num_few_shots and len(shuffled_train) > 0
+                ):
+                    example = shuffled_train.select(range(1))[0]
+                    few_shot_examples.append(example)
+                    shuffled_train = shuffled_train.filter(
+                        lambda x: x["text"] != example["text"]
+                    )
+
+            case TaskGroup.TOKEN_CLASSIFICATION:
+                labels = it.cycle(
+                    [
+                        label.lower()
+                        for label in self.dataset_config.labels
+                        if label.lower().startswith("b-")
+                    ]
+                )
+                while (
+                    len(few_shot_examples) < num_few_shots and len(shuffled_train) > 0
+                ):
+                    label = next(labels)
+                    possible_examples = shuffled_train.filter(
+                        lambda x: label in [tag.lower() for tag in x["labels"]]
+                    )
+                    if len(possible_examples) == 0:
+                        continue
+                    example = possible_examples.select(range(1))[0]
+                    few_shot_examples.append(example)
+                    shuffled_train = shuffled_train.filter(
+                        lambda x: x["tokens"] != example["tokens"]
+                    )
+
+            case TaskGroup.QUESTION_ANSWERING:
+                # Locate the maximum number of tokens that constitutes a short example
+                for max_num_tokens in [512, 1024, 2048, 4096, 8192]:
+                    train_with_short_examples = dataset["train"].filter(
+                        lambda example: len(example["context"]) < max_num_tokens
+                    )
+                    num_short_examples = len(train_with_short_examples)
+                    if num_short_examples >= self.dataset_config.num_few_shot_examples:
+                        break
+                else:
+                    raise InvalidBenchmark(
+                        "Could not find enough short examples for few-shot learning."
+                    )
+
+                shuffled_train = train_with_short_examples.shuffle(seed=random_seed)
+                while (
+                    len(few_shot_examples) < num_few_shots and len(shuffled_train) > 0
+                ):
+                    example = shuffled_train.select(range(1))[0]
+                    few_shot_examples.append(example)
+                    shuffled_train = shuffled_train.filter(
+                        lambda x: x["context"] != example["context"]
+                    )
+
+            case _:
+                raise NotImplementedError(f"Unsupported task group: {task.task_group}.")
+
+        random.seed(random_seed)
+        random.shuffle(few_shot_examples)
+        return few_shot_examples
+
+    def _apply_prompt(
+        self,
+        examples: dict[str, t.Any],
+        few_shot_examples: list[dict[str, t.Any]],
+        task: Task,
+    ) -> dict[str, t.Any]:
+        """Apply prompt template to an example, potentially with few-shot examples.
+
+        Args:
+            examples:
+                The examples to apply the few-shot examples to.
+            few_shot_examples:
+                The few-shot examples to apply.
+            task:
+                The task that is being benchmarked.
+
+        Returns:
+            The example with the few-shot examples applied.
+        """
+
+        def create_prompt(**kwargs: str) -> tuple[str, str]:
+            """Create a prompt from the given keyword arguments.
+
+            Args:
+                kwargs:
+                    The keyword arguments to use in the prompt.
+
+            Returns:
+                A pair (prompt, label), where "label" is an empty string if the model is
+                not instruction tuned (as in this case it is included in the prompt).
+            """
+            label_key = "label" if "label" in kwargs else "target_text"
+            label = kwargs.pop(label_key)
+            label_mapping = self.dataset_config.prompt_label_mapping
+            label = label_mapping.get(label, label)
+            prompt = self.dataset_config.instruction_prompt.format(**kwargs)
+            return prompt, label
+
+        match task.task_group:
+            case (
+                TaskGroup.SEQUENCE_CLASSIFICATION
+                | TaskGroup.MULTIPLE_CHOICE_CLASSIFICATION
+            ):
+                few_shot_sections = [
+                    create_prompt(
+                        text=example["text"].replace("\n", " ").strip(),
+                        label=example["label"].replace("\n", " ").strip(),
+                    )
+                    for example in few_shot_examples
+                ]
+                new_sections = [
+                    create_prompt(text=text.replace("\n", " ").strip(), label="")
+                    for text in examples["text"]
+                ]
+
+            case TaskGroup.TEXT_TO_TEXT:
+                few_shot_sections = [
+                    create_prompt(
+                        text=example["text"].replace("\n", " ").strip(),
+                        target_text=example["target_text"].replace("\n", " ").strip(),
+                    )
+                    for example in few_shot_examples
+                ]
+                new_sections = [
+                    create_prompt(text=text.replace("\n", " ").strip(), target_text="")
+                    for text in examples["text"]
+                ]
+
+            case TaskGroup.TOKEN_CLASSIFICATION:
+
+                def create_label(example: dict) -> str:
+                    prompt_labels = self.dataset_config.prompt_label_mapping.values()
+                    labels: dict[str, list[str]] = {
+                        prompt_label: list() for prompt_label in prompt_labels
+                    }
+                    for token, label in zip(example["tokens"], example["labels"]):
+                        label = label.lower()
+                        if label == "o":
+                            continue
+                        prompt_label = self.dataset_config.prompt_label_mapping[label]
+                        if label.startswith("b-"):
+                            labels[prompt_label].append(token)
+                        elif label.startswith("i-"):
+                            labels[prompt_label][-1] += " " + token
+                    return json.dumps(labels, ensure_ascii=False)
+
+                few_shot_sections = [
+                    create_prompt(
+                        text=" ".join(example["tokens"]).replace("\n", " ").strip(),
+                        label=create_label(example=example),
+                    )
+                    for example in few_shot_examples
+                ]
+                new_sections = [
+                    create_prompt(
+                        text=" ".join(tokens).replace("\n", " ").strip(), label=""
+                    )
+                    for tokens in examples["tokens"]
+                ]
+
+            case TaskGroup.QUESTION_ANSWERING:
+                few_shot_sections = [
+                    create_prompt(
+                        text=example["context"].replace("\n", " ").strip(),
+                        question=example["question"].replace("\n", " ").strip(),
+                        label=example["answers"]["text"][0].replace("\n", " "),
+                    )
+                    for example in few_shot_examples
+                ]
+                new_sections = [
+                    create_prompt(
+                        text=context.replace("\n", " ").strip(),
+                        question=question.replace("\n", " ").strip(),
+                        label="",
+                    )
+                    for context, question in zip(
+                        examples["context"], examples["question"]
+                    )
+                ]
+
+            case _:
+                raise NotImplementedError(f"Unsupported task group: {task.task_group}.")
+
+        few_shot_messages = [
+            dict(role=role, content=content)
+            for prompt, label in few_shot_sections
+            for role, content in [("user", prompt), ("assistant", label)]
+        ]
+
+        messages_list = [
+            few_shot_messages + [dict(role="user", content=prompt)]
+            for prompt, _ in new_sections
+        ]
+
+        examples["messages"] = messages_list
+        return examples
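
Note on the metadata properties: num_params, vocab_size and model_max_length all resolve hosted-API model IDs by full-matching them against the regex-keyed tables defined at the top of the module, falling back to -1 ("unknown") when nothing matches. The following is a minimal, self-contained sketch of that lookup pattern; the mapping is a trimmed copy of NUM_PARAMS_MAPPING and the helper name lookup_num_params is illustrative only, not part of the package.

import re

# Trimmed copy of NUM_PARAMS_MAPPING from the diff above; -1 means "unknown".
NUM_PARAMS_SAMPLE = {
    "(text-)?curie(-001)?": 13_000_000_000,
    "gpt-4o(-[0-9]{4}-[0-9]{2}-[0-9]{2})?": -1,
}

def lookup_num_params(model_id: str) -> int:
    # Mirrors the loop in LiteLLMModel.num_params: return the value of the
    # first pattern that matches the full model ID, otherwise fall back to -1.
    for pattern, value in NUM_PARAMS_SAMPLE.items():
        if re.fullmatch(pattern=pattern, string=model_id) is not None:
            return value
    return -1

print(lookup_num_params("text-curie-001"))       # 13000000000
print(lookup_num_params("gpt-4o-2024-05-13"))    # -1 (parameter count undisclosed)
print(lookup_num_params("my-org/custom-model"))  # -1 (no pattern matches)

In the actual module, num_params additionally falls back to Hugging Face Hub metadata for model IDs prefixed with "huggingface/", which this sketch omits.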