EuroEval 15.5.0__py3-none-any.whl → 15.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of EuroEval might be problematic.

Files changed (53)
  1. euroeval/benchmark_modules/base.py +3 -2
  2. euroeval/benchmark_modules/fresh.py +8 -6
  3. euroeval/benchmark_modules/hf.py +33 -31
  4. euroeval/benchmark_modules/litellm.py +120 -56
  5. euroeval/benchmark_modules/vllm.py +41 -26
  6. euroeval/benchmarker.py +23 -21
  7. euroeval/callbacks.py +2 -2
  8. euroeval/constants.py +1 -1
  9. euroeval/data_models.py +257 -42
  10. euroeval/dataset_configs/__init__.py +61 -0
  11. euroeval/dataset_configs/danish.py +120 -0
  12. euroeval/dataset_configs/dutch.py +123 -0
  13. euroeval/dataset_configs/english.py +88 -0
  14. euroeval/dataset_configs/faroese.py +53 -0
  15. euroeval/dataset_configs/french.py +83 -0
  16. euroeval/dataset_configs/german.py +91 -0
  17. euroeval/dataset_configs/icelandic.py +148 -0
  18. euroeval/dataset_configs/italian.py +81 -0
  19. euroeval/dataset_configs/norwegian.py +178 -0
  20. euroeval/dataset_configs/spanish.py +78 -0
  21. euroeval/dataset_configs/swedish.py +100 -0
  22. euroeval/exceptions.py +10 -10
  23. euroeval/finetuning.py +6 -10
  24. euroeval/generation.py +1 -0
  25. euroeval/human_evaluation.py +2 -2
  26. euroeval/languages.py +20 -13
  27. euroeval/model_cache.py +1 -1
  28. euroeval/model_loading.py +1 -12
  29. euroeval/prompt_templates/__init__.py +8 -0
  30. euroeval/prompt_templates/linguistic_acceptability.py +112 -0
  31. euroeval/prompt_templates/multiple_choice.py +97 -0
  32. euroeval/prompt_templates/named_entity_recognition.py +257 -0
  33. euroeval/prompt_templates/reading_comprehension.py +118 -0
  34. euroeval/prompt_templates/sentiment_classification.py +137 -0
  35. euroeval/prompt_templates/summarization.py +97 -0
  36. euroeval/speed_benchmark.py +1 -1
  37. euroeval/{task_utils → task_group_utils}/multiple_choice_classification.py +19 -11
  38. euroeval/{task_utils → task_group_utils}/question_answering.py +31 -30
  39. euroeval/{task_utils → task_group_utils}/sequence_classification.py +1 -1
  40. euroeval/{task_utils → task_group_utils}/text_to_text.py +1 -1
  41. euroeval/{task_utils → task_group_utils}/token_classification.py +3 -2
  42. euroeval/tasks.py +54 -0
  43. euroeval/tokenization_utils.py +343 -0
  44. euroeval/types.py +3 -1
  45. euroeval/utils.py +2 -347
  46. {euroeval-15.5.0.dist-info → euroeval-15.6.0.dist-info}/METADATA +30 -9
  47. euroeval-15.6.0.dist-info/RECORD +59 -0
  48. euroeval/dataset_configs.py +0 -2408
  49. euroeval-15.5.0.dist-info/RECORD +0 -40
  50. /euroeval/{task_utils → task_group_utils}/__init__.py +0 -0
  51. {euroeval-15.5.0.dist-info → euroeval-15.6.0.dist-info}/WHEEL +0 -0
  52. {euroeval-15.5.0.dist-info → euroeval-15.6.0.dist-info}/entry_points.txt +0 -0
  53. {euroeval-15.5.0.dist-info → euroeval-15.6.0.dist-info}/licenses/LICENSE +0 -0
euroeval/benchmark_modules/base.py

@@ -10,7 +10,8 @@ from functools import cached_property, partial
 from datasets import DatasetDict
 from torch import nn
 from tqdm.auto import tqdm
-from transformers import PreTrainedTokenizer, Trainer
+from transformers.tokenization_utils import PreTrainedTokenizer
+from transformers.trainer import Trainer
 
 from ..data_models import (
     BenchmarkConfig,
@@ -21,7 +22,7 @@ from ..data_models import (
 )
 from ..enums import BatchingPreference, GenerativeType, TaskGroup
 from ..exceptions import NeedsEnvironmentVariable, NeedsExtraInstalled
-from ..task_utils import (
+from ..task_group_utils import (
     question_answering,
     sequence_classification,
     text_to_text,
euroeval/benchmark_modules/fresh.py

@@ -4,19 +4,21 @@ import os
 from functools import cached_property
 from json import JSONDecodeError
 
-from transformers import (
-    AutoConfig,
-    AutoTokenizer,
+from transformers.configuration_utils import PretrainedConfig
+from transformers.modeling_utils import PreTrainedModel
+from transformers.models.auto.configuration_auto import AutoConfig
+from transformers.models.auto.tokenization_auto import AutoTokenizer
+from transformers.models.electra import (
     ElectraForQuestionAnswering,
     ElectraForSequenceClassification,
     ElectraForTokenClassification,
-    PretrainedConfig,
-    PreTrainedModel,
-    PreTrainedTokenizer,
+)
+from transformers.models.xlm_roberta import (
     XLMRobertaForQuestionAnswering,
     XLMRobertaForSequenceClassification,
     XLMRobertaForTokenClassification,
 )
+from transformers.tokenization_utils import PreTrainedTokenizer
 
 from ..data_models import BenchmarkConfig, DatasetConfig, ModelConfig
 from ..enums import InferenceBackend, ModelType, TaskGroup
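These changes (and the matching ones in hf.py, litellm.py and vllm.py below) swap top-level transformers imports for imports from the concrete submodules. As a quick sanity-check sketch (not part of the diff, assuming a standard transformers installation), both import paths should resolve to the same class, since the package root merely re-exports it:

from transformers import PreTrainedTokenizer as from_root
from transformers.tokenization_utils import PreTrainedTokenizer as from_submodule

# The package root lazily re-exports the class, so the two names should point
# at the same object; only the import path changes.
assert from_root is from_submodule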
euroeval/benchmark_modules/hf.py

@@ -13,31 +13,29 @@ import torch
 from datasets import DatasetDict
 from huggingface_hub import HfApi
 from huggingface_hub import whoami as hf_whoami
-from huggingface_hub.hf_api import ModelInfo as HfApiModelInfo
-from huggingface_hub.hf_api import RepositoryNotFoundError, RevisionNotFoundError
-from huggingface_hub.utils import (
+from huggingface_hub.errors import (
     GatedRepoError,
     HFValidationError,
     LocalTokenNotFoundError,
+    RepositoryNotFoundError,
+    RevisionNotFoundError,
 )
+from huggingface_hub.hf_api import ModelInfo as HfApiModelInfo
 from peft import PeftConfig
 from requests.exceptions import RequestException
 from torch import nn
-from transformers import (
-    AutoConfig,
-    AutoTokenizer,
-    BatchEncoding,
+from transformers.configuration_utils import PretrainedConfig
+from transformers.data.data_collator import (
     DataCollatorForTokenClassification,
     DataCollatorWithPadding,
-    PretrainedConfig,
-    PreTrainedModel,
-    PreTrainedTokenizer,
-    Trainer,
 )
 from transformers.modelcard import TASK_MAPPING
-from transformers.models.auto.modeling_auto import (
-    MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES,
-)
+from transformers.modeling_utils import PreTrainedModel
+from transformers.models.auto.configuration_auto import AutoConfig
+from transformers.models.auto.tokenization_auto import AutoTokenizer
+from transformers.tokenization_utils import PreTrainedTokenizer
+from transformers.tokenization_utils_base import BatchEncoding
+from transformers.trainer import Trainer
 from urllib3.exceptions import RequestError
 
 from ..constants import (
@@ -65,18 +63,17 @@ from ..exceptions import (
     NoInternetConnection,
 )
 from ..languages import get_all_languages
-from ..task_utils import (
+from ..task_group_utils import (
     multiple_choice_classification,
     question_answering,
     token_classification,
 )
+from ..tokenization_utils import get_bos_token, get_eos_token
 from ..types import ExtractLabelsFunction
 from ..utils import (
     block_terminal_output,
     create_model_cache_dir,
-    get_bos_token,
     get_class_by_name,
-    get_eos_token,
     internet_connection_available,
     log_once,
 )
@@ -690,7 +687,7 @@ def load_model_and_tokenizer(
     assert model is not None, "The model should not be None."
 
     model.eval()
-    model.to(benchmark_config.device)
+    model.to(benchmark_config.device)  # type: ignore[arg-type]
 
     if (
         isinstance(model, PreTrainedModel)
@@ -797,12 +794,6 @@ def get_model_repo_info(
     tags += base_model_info.tags or list()
     tags = list(set(tags))
 
-    # TEMP: This extends the `TASK_MAPPING` dictionary to include the missing
-    # 'image-text-to-text' pipeline tag. This will be added as part of `TASK_MAPPING`
-    # when this PR has been merged in and published:
-    # https://github.com/huggingface/transformers/pull/37107
-    TASK_MAPPING["image-text-to-text"] = MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES
-
     # Get the pipeline tag for the model. If it is not specified, then we determine it
     # by checking the model's architecture as written in the model's Hugging Face config
     pipeline_tag = model_info.pipeline_tag
@@ -824,7 +815,7 @@ def get_model_repo_info(
     generative_class_names = [
         class_name
         for tag in GENERATIVE_PIPELINE_TAGS
-        for class_name in TASK_MAPPING.get(tag, dict()).values()
+        for class_name in TASK_MAPPING.get(tag, dict()).values()  # type: ignore[attr-defined]
     ]
     if class_names is not None and any(
         class_name in generative_class_names for class_name in class_names
@@ -1083,17 +1074,20 @@ def setup_model_for_question_answering(model: "PreTrainedModel") -> "PreTrainedM
     for attribute in attribute_list:
         token_type_embeddings = getattr(token_type_embeddings, attribute)
 
+    token_type_embedding_tensor = token_type_embeddings.weight.data
+    assert isinstance(token_type_embedding_tensor, torch.Tensor)
+
     # If the token type embeddings has shape (1, ...) then set the shape to
     # (2, ...) by randomly initializing the second token type embedding
-    if token_type_embeddings.weight.data.shape[0] == 1:
+    if token_type_embedding_tensor.shape[0] == 1:
         token_type_embeddings.weight.data = torch.cat(
             (
-                token_type_embeddings.weight.data,
-                torch.rand_like(token_type_embeddings.weight.data),
+                token_type_embedding_tensor,
+                torch.rand_like(token_type_embedding_tensor),
             ),
             dim=0,
         )
-        token_type_embeddings.num_embeddings = 2
+        token_type_embeddings.num_embeddings = 2  # type: ignore[assignment]
 
     # Set the model config to use the new type vocab size
     model.config.type_vocab_size = 2
@@ -1160,7 +1154,7 @@ def align_model_and_tokenizer(
     # Move the model to the CPU, since otherwise we can't catch the IndexErrors when
     # finding the maximum sequence length of the model
     model_device = model.device
-    model.to(torch.device("cpu"))
+    model.to(torch.device("cpu"))  # type: ignore[arg-type]
 
     # Manually check that this model max length is valid for the model, and adjust
     # otherwise
@@ -1182,8 +1176,16 @@
         except IndexError:
             continue
 
+        except ValueError as e:
+            # This happens when the model is using Triton, such as with ModernBERT,
+            # which doesn't work with CPU tensors at all
+            if "cpu tensor" in str(e):
+                break
+            else:
+                raise e
+
     # Move the model back to the original device
-    model.to(model_device)
+    model.to(model_device)  # type: ignore[arg-type]
 
     # If there is a mismatch between the vocab size according to the tokenizer and
     # the vocab size according to the model, we raise an error
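The setup_model_for_question_answering change above expands a single-row token-type embedding to two rows by concatenating a randomly initialised copy. A minimal, self-contained sketch of that operation (hypothetical sizes, not taken from the package):

import torch

# Hypothetical token-type embedding weight of shape (1, hidden_size)
weight = torch.zeros(1, 8)

# Concatenate a randomly initialised second row along dim 0, mirroring the
# torch.cat call in setup_model_for_question_answering
expanded = torch.cat((weight, torch.rand_like(weight)), dim=0)
print(expanded.shape)  # torch.Size([2, 8])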
euroeval/benchmark_modules/litellm.py

@@ -32,10 +32,10 @@ from litellm.exceptions import (
     Timeout,
 )
 from litellm.llms.vertex_ai.common_utils import VertexAIError
-from litellm.types.utils import ModelResponse
+from litellm.types.utils import ChoiceLogprobs, ModelResponse
 from requests.exceptions import RequestException
 from tqdm.auto import tqdm
-from transformers import Trainer
+from transformers.trainer import Trainer
 
 from ..constants import MAX_LOGPROBS, REASONING_MAX_TOKENS, TASKS_USING_JSON
 from ..data_models import (
@@ -59,14 +59,15 @@ from ..exceptions import (
     NeedsEnvironmentVariable,
     NeedsExtraInstalled,
 )
-from ..task_utils import (
+from ..task_group_utils import (
     question_answering,
     sequence_classification,
     text_to_text,
     token_classification,
 )
+from ..tokenization_utils import get_first_label_token_mapping
 from ..types import ExtractLabelsFunction
-from ..utils import create_model_cache_dir, get_first_label_token_mapping, log_once
+from ..utils import create_model_cache_dir, log_once
 from .base import BenchmarkModule
 from .hf import HuggingFaceEncoderModel, load_hf_model_config, load_tokenizer
 
@@ -316,7 +317,7 @@ class LiteLLMModel(BenchmarkModule):
             elif isinstance(e, RateLimitError):
                 raise InvalidModel(
                     "You have encountered your rate limit for model "
-                    f"{self.model_config.model_id!r}. The error message was: {e}"
+                    f"{self.model_config.model_id!r}. Skipping."
                 )
             else:
                 raise InvalidBenchmark(
@@ -366,14 +367,22 @@ class LiteLLMModel(BenchmarkModule):
         # Structure the model output as a GenerativeModelOutput object
         model_output = GenerativeModelOutput(sequences=[generation_output])
         if hasattr(model_response_choices, "logprobs"):
-            logprobs_list: list[list[tuple[str, float]]] = [
-                [
-                    (top_logprob.token, top_logprob.logprob)
-                    for top_logprob in content.top_logprobs
+            logprobs_obj = model_response_choices.logprobs
+            if isinstance(logprobs_obj, ChoiceLogprobs):
+                logprobs_list: list[list[tuple[str, float]]] = [
+                    [
+                        (top_logprob.token, top_logprob.logprob)
+                        for top_logprob in content.top_logprobs
+                    ]
+                    for content in model_response_choices.logprobs.content or list()
                 ]
-                for content in model_response_choices.logprobs.content or list()
-            ]
-            model_output.scores = [logprobs_list]
+                model_output.scores = [logprobs_list]
+            else:
+                log_once(
+                    "The logprobs object is malformed, so we won't use logprobs to "
+                    "determine the labels.",
+                    level=logging.WARNING,
+                )
 
         return model_output
 
@@ -403,7 +412,7 @@ class LiteLLMModel(BenchmarkModule):
         # get the number of parameters from the Hugging Face model configuration from
         # the Hugging Face Hub
         if self.model_config.model_id.startswith("huggingface/"):
-            model_id = self.model_config.model_id.split(sep="/", maxsplit=1)[-1]
+            model_id = "/".join(self.model_config.model_id.split(sep="/")[-2:])
             if HuggingFaceEncoderModel.model_exists(
                 model_id=model_id, benchmark_config=self.benchmark_config
             ):
@@ -467,7 +476,7 @@ class LiteLLMModel(BenchmarkModule):
         # get the vocabulary size from the Hugging Face model configuration from the
         # Hugging Face Hub
         if self.model_config.model_id.startswith("huggingface/"):
-            model_id = self.model_config.model_id.split(sep="/", maxsplit=1)[-1]
+            model_id = "/".join(self.model_config.model_id.split(sep="/")[-2:])
             if HuggingFaceEncoderModel.model_exists(
                 model_id=model_id, benchmark_config=self.benchmark_config
             ):
@@ -547,7 +556,7 @@ class LiteLLMModel(BenchmarkModule):
         # get the maximum length from the Hugging Face model configuration from the
         # Hugging Face Hub
        if self.model_config.model_id.startswith("huggingface/"):
-            model_id = self.model_config.model_id.split(sep="/", maxsplit=1)[-1]
+            model_id = "/".join(self.model_config.model_id.split(sep="/")[-2:])
             if HuggingFaceEncoderModel.model_exists(
                 model_id=model_id, benchmark_config=self.benchmark_config
             ):
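The same one-line change appears in all three properties above: rather than stripping only the leading "huggingface/" prefix, the code now keeps just the last two path segments of the model ID. A small sketch of the difference, using a hypothetical LiteLLM-style ID with an extra provider segment:

# Hypothetical model ID; the exact ID format depends on the LiteLLM provider
model_id = "huggingface/some-provider/example-org/example-model"

# Old behaviour: everything after the first slash, which can still contain a
# provider segment and then isn't a valid Hugging Face repo ID
old = model_id.split(sep="/", maxsplit=1)[-1]
print(old)  # some-provider/example-org/example-model

# New behaviour: only the trailing "org/model" pair is kept
new = "/".join(model_id.split(sep="/")[-2:])
print(new)  # example-org/example-model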
@@ -688,42 +697,11 @@ class LiteLLMModel(BenchmarkModule):
        if model_id in litellm.model_list:
            return True
 
-        # If it is an Ollama model then try to download it
+        # Separate check for Ollama models
        if model_id.startswith("ollama/") or model_id.startswith("ollama_chat/"):
-            ollama_model_id = "/".join(model_id.split("/")[1:])
-            downloaded_ollama_models: list[str] = [
-                model_obj.model
-                for model_obj in ollama.list().models
-                if model_obj.model is not None
-            ]
-            if ollama_model_id not in downloaded_ollama_models:
-                try:
-                    response = ollama.pull(model=ollama_model_id, stream=True)
-                    with tqdm(
-                        desc=f"Downloading {ollama_model_id}",
-                        unit_scale=True,
-                        unit="B",
-                        leave=False,
-                    ) as pbar:
-                        for status in response:
-                            if status.total is not None:
-                                pbar.total = status.total
-                            if status.completed is not None:
-                                pbar.update(status.completed - pbar.n)
-                except ollama.ResponseError as e:
-                    if "file does not exist" in str(e).lower():
-                        return False
-                    else:
-                        raise InvalidModel(
-                            f"Failed to download Ollama model {ollama_model_id}. The "
-                            f"error message was: {e}"
-                        )
-            else:
-                log_once(
-                    f"Ollama model {ollama_model_id!r} already downloaded, so skipping "
-                    "download.",
-                    level=logging.DEBUG,
-                )
+            ollama_model_exists = try_download_ollama_model(model_id=model_id)
+            if ollama_model_exists:
+                return ollama_model_exists
 
        num_attempts = 10
        for _ in range(num_attempts):
@@ -737,6 +715,10 @@ class LiteLLMModel(BenchmarkModule):
                    api_version=benchmark_config.api_version,
                )
                return True
+            # A rate limit indicates that the model *does* exist, but we are being rate
+            # limited.
+            except RateLimitError:
+                return True
            except (
                APIConnectionError,
                Timeout,
@@ -748,12 +730,6 @@ class LiteLLMModel(BenchmarkModule):
                    "Retrying in 10 seconds..."
                )
                sleep(5)
-            except RateLimitError:
-                logger.warning(
-                    f"Rate limit exceeded for model {model_id!r}. Retrying in 10 "
-                    "seconds..."
-                )
-                sleep(10)
            except APIError as e:
                if "'503 Service Unavailable" not in str(e):
                    raise e
@@ -1155,3 +1131,91 @@ def raise_if_wrong_params(
        msg += " No parameters are allowed."
        raise InvalidModel(msg)
    return
+
+
+def try_download_ollama_model(model_id: str) -> bool:
+    """Try to download an Ollama model.
+
+    Args:
+        model_id:
+            The model ID. If the model does not start with "ollama/" or "ollama_chat/"
+            then this function will return False.
+
+    Returns:
+        Whether the model was downloaded successfully.
+    """
+    if not (model_id.startswith("ollama/") or model_id.startswith("ollama_chat/")):
+        return False
+
+    if model_id.startswith("ollama/"):
+        log_once(
+            "You're trying to benchmark a model with the old 'ollama/' prefix, which "
+            "probably results in bad performance, as it doesn't use the model's chat "
+            "template. If the model is not a chat model then just disregard this "
+            "warning, but if it is a chat model then please cancel this run and "
+            "use the 'ollama_chat/' prefix instead.",
+            level=logging.WARNING,
+        )
+
+    downloaded_ollama_models: list[str] = [
+        model_obj.model
+        for model_obj in ollama.list().models
+        if model_obj.model is not None
+    ]
+
+    ollama_model_id = "/".join(model_id.split("/")[1:])
+    if ollama_model_id not in downloaded_ollama_models:
+        # Try fetching the model info
+        try:
+            response = ollama.pull(model=ollama_model_id, stream=True)
+        except ollama.ResponseError as e:
+            if "file does not exist" in str(e).lower():
+                # Check if the model exists if we prepend "hf.co/"
+                try:
+                    ollama_model_id_with_prefix = f"hf.co/{ollama_model_id}"
+                    model_id_with_prefix = (
+                        f"{model_id.split('/')[0]}/{ollama_model_id_with_prefix}"
+                    )
+                    ollama.pull(model=ollama_model_id_with_prefix, stream=True)
+                    log_once(
+                        f"The model {model_id!r} cannot be found on Ollama, but the "
+                        f"model {model_id_with_prefix} *was* found, so we would "
+                        "recommend you cancelling this run and trying the evaluation "
+                        "with that model ID instead."
+                    )
+                    return False
+                except ollama.ResponseError as inner_e:
+                    if "file does not exist" in str(inner_e).lower():
+                        return False
+                    else:
+                        raise InvalidModel(
+                            f"Failed to download Ollama model {ollama_model_id}. "
+                            f"The error message was: {inner_e}"
+                        )
+            else:
+                raise InvalidModel(
+                    f"Failed to download Ollama model {ollama_model_id}. "
+                    f"The error message was: {e}"
+                )
+
+        # Download the model
+        with tqdm(
+            desc=f"Downloading {ollama_model_id}",
+            unit_scale=True,
+            unit="B",
+            leave=False,
+        ) as pbar:
+            for status in response:
+                if status.total is not None:
+                    pbar.total = status.total
+                if status.completed is not None:
+                    pbar.update(status.completed - pbar.n)
+        return True
+
+    else:
+        log_once(
+            f"Ollama model {ollama_model_id!r} already downloaded, so skipping "
+            "download.",
+            level=logging.DEBUG,
+        )
+        return True
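The new try_download_ollama_model helper derives the Ollama-native model name by dropping the LiteLLM prefix, and falls back to an "hf.co/"-prefixed ID when the plain name is unknown to Ollama. A quick sketch of just those string manipulations (hypothetical IDs, no network calls):

# Hypothetical LiteLLM-style Ollama model ID
model_id = "ollama_chat/example-org/example-model"

# Drop the "ollama_chat/" (or "ollama/") prefix to get the Ollama model name
ollama_model_id = "/".join(model_id.split("/")[1:])
print(ollama_model_id)  # example-org/example-model

# Fallback candidate that points Ollama at a Hugging Face repository instead
ollama_model_id_with_prefix = f"hf.co/{ollama_model_id}"
model_id_with_prefix = f"{model_id.split('/')[0]}/{ollama_model_id_with_prefix}"
print(model_id_with_prefix)  # ollama_chat/hf.co/example-org/example-model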
euroeval/benchmark_modules/vllm.py

@@ -1,6 +1,7 @@
 """Generative models using the vLLM inference framework."""
 
 import collections.abc as c
+import contextlib
 import importlib.util
 import itertools as it
 import json
@@ -20,7 +21,10 @@ from datasets import DatasetDict
 from huggingface_hub import snapshot_download
 from pydantic import conlist, create_model
 from tqdm.auto import tqdm
-from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer, Trainer
+from transformers.models.auto.configuration_auto import AutoConfig
+from transformers.models.auto.tokenization_auto import AutoTokenizer
+from transformers.tokenization_utils import PreTrainedTokenizer
+from transformers.trainer import Trainer
 from urllib3.exceptions import RequestError
 
 from ..constants import (
@@ -53,40 +57,39 @@ from ..exceptions import (
     NeedsExtraInstalled,
 )
 from ..languages import get_all_languages
-from ..task_utils import (
+from ..task_group_utils import (
     question_answering,
     sequence_classification,
     text_to_text,
     token_classification,
 )
-from ..types import ExtractLabelsFunction
-from ..utils import (
-    clear_memory,
-    create_model_cache_dir,
+from ..tokenization_utils import (
     get_bos_token,
     get_end_of_chat_token_ids,
     get_eos_token,
     get_first_label_token_mapping,
+    should_prompts_be_stripped,
+)
+from ..types import ExtractLabelsFunction
+from ..utils import (
+    clear_memory,
+    create_model_cache_dir,
     get_min_cuda_compute_capability,
     log_once,
-    should_prompts_be_stripped,
 )
 from .hf import HuggingFaceEncoderModel, get_model_repo_info, load_hf_model_config
 
 if t.TYPE_CHECKING or importlib.util.find_spec("vllm") is not None:
     from vllm import LLM, RequestOutput, SamplingParams
+    from vllm.distributed.parallel_state import (
+        destroy_distributed_environment,
+        destroy_model_parallel,
+    )
     from vllm.lora.request import LoRARequest
 
-    try:
-        from vllm.model_executor.parallel_utils.parallel_state import (
-            destroy_model_parallel,
-        )
-    except ImportError:
-        from vllm.distributed.parallel_state import destroy_model_parallel
-
 if t.TYPE_CHECKING or importlib.util.find_spec("outlines") is not None:
     from outlines.models.vllm import adapt_tokenizer
-    from outlines.processors import JSONLogitsProcessor
+    from outlines.processors.structured import JSONLogitsProcessor
 
 if t.TYPE_CHECKING or importlib.util.find_spec("ray") is not None:
     import ray
@@ -156,6 +159,14 @@ class VLLMModel(HuggingFaceEncoderModel):
                lora_name="adapter", lora_int_id=1, lora_path=adapter_path
            )
 
+    def __del__(self) -> None:
+        """Clean up the model and tokenizer."""
+        clear_vllm()
+        if hasattr(self, "_model"):
+            del self._model
+        if hasattr(self, "_tokenizer"):
+            del self._tokenizer
+
    @property
    def generative_type(self) -> GenerativeType | None:
        """Get the generative type of the model.
@@ -330,7 +341,7 @@ class VLLMModel(HuggingFaceEncoderModel):
        pydantic_class = create_model("AnswerFormat", **keys_and_their_types)
        logits_processor = JSONLogitsProcessor(
            schema=pydantic_class,
-            tokenizer=adapt_tokenizer(tokenizer=self._tokenizer),  #  type: ignore
+            tokenizer=adapt_tokenizer(tokenizer=self._tokenizer),  # type: ignore
            whitespace_pattern=r" ?",
        )
        log_once(
@@ -982,19 +993,19 @@ def load_model_and_tokenizer(
 
    clear_vllm()
 
-    executor_backend = "ray" if torch.cuda.device_count() > 1 else "mp"
-
    try:
        model = LLM(
            model=model_id,
            tokenizer=model_id,
-            gpu_memory_utilization=0.95,
+            gpu_memory_utilization=0.9,
            max_model_len=min(true_max_model_len, MAX_CONTEXT_LENGTH),
            download_dir=download_dir,
            trust_remote_code=benchmark_config.trust_remote_code,
            revision=revision,
            seed=4242,
-            distributed_executor_backend=executor_backend,
+            distributed_executor_backend=(
+                "ray" if torch.cuda.device_count() > 1 else "mp"
+            ),
            tensor_parallel_size=torch.cuda.device_count(),
            disable_custom_all_reduce=True,
            quantization=quantization,
@@ -1145,13 +1156,16 @@ def _run_engine_with_fixed_progress_bars(
 
 def clear_vllm() -> None:
     """Clear the GPU memory used by the vLLM model, enabling re-initialisation."""
-    try:
+    with contextlib.suppress(ValueError):
         destroy_model_parallel()
-    except ImportError:
-        pass
-    clear_memory()
+        destroy_distributed_environment()
     if ray.is_initialized():
         ray.shutdown()
+    with contextlib.suppress(AssertionError):
+        torch.distributed.destroy_process_group()
+    if ray.is_initialized():
+        ray.shutdown()
+    clear_memory()
 
 
 def get_end_of_reasoning_token_id(
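clear_vllm now uses contextlib.suppress in place of explicit try/except blocks. For reference, the two forms behave the same for the suppressed exception type (generic Python, not taken from the package):

import contextlib

def might_fail() -> None:
    raise ValueError("boom")

# try/except form
try:
    might_fail()
except ValueError:
    pass

# equivalent contextlib.suppress form, as used in the new clear_vllm
with contextlib.suppress(ValueError):
    might_fail()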
@@ -1175,12 +1189,13 @@ def get_end_of_reasoning_token_id(
    if tokenizer.chat_template is None:
        prompt = "What is your name?"
    else:
-        prompt = tokenizer.apply_chat_template(
+        templated_prompt = tokenizer.apply_chat_template(
            conversation=[dict(role="user", content="What is your name?")],
            add_generation_prompt=True,
            tokenize=False,
        )
-        assert isinstance(prompt, str)
+        assert isinstance(templated_prompt, str)
+        prompt = templated_prompt
 
    # Generate a completion and remove the BOS token from it, to not confuse it with the