EuroEval: euroeval-15.12.0-py3-none-any.whl → euroeval-16.7.1-py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
Files changed (87)
  1. euroeval/__init__.py +32 -14
  2. euroeval/benchmark_config_factory.py +92 -180
  3. euroeval/benchmark_modules/base.py +49 -39
  4. euroeval/benchmark_modules/fresh.py +35 -21
  5. euroeval/benchmark_modules/hf.py +280 -244
  6. euroeval/benchmark_modules/litellm.py +752 -312
  7. euroeval/benchmark_modules/vllm.py +570 -268
  8. euroeval/benchmarker.py +651 -528
  9. euroeval/caching_utils.py +79 -0
  10. euroeval/callbacks.py +5 -7
  11. euroeval/cli.py +49 -38
  12. euroeval/constants.py +44 -25
  13. euroeval/data_loading.py +111 -55
  14. euroeval/data_models.py +490 -323
  15. euroeval/dataset_configs/__init__.py +26 -4
  16. euroeval/dataset_configs/bosnian.py +39 -0
  17. euroeval/dataset_configs/bulgarian.py +56 -0
  18. euroeval/dataset_configs/croatian.py +56 -0
  19. euroeval/dataset_configs/czech.py +75 -0
  20. euroeval/dataset_configs/danish.py +78 -50
  21. euroeval/dataset_configs/dutch.py +74 -44
  22. euroeval/dataset_configs/english.py +71 -36
  23. euroeval/dataset_configs/estonian.py +111 -0
  24. euroeval/dataset_configs/faroese.py +25 -18
  25. euroeval/dataset_configs/finnish.py +63 -26
  26. euroeval/dataset_configs/french.py +65 -32
  27. euroeval/dataset_configs/german.py +77 -36
  28. euroeval/dataset_configs/greek.py +64 -0
  29. euroeval/dataset_configs/icelandic.py +68 -57
  30. euroeval/dataset_configs/italian.py +68 -36
  31. euroeval/dataset_configs/latvian.py +87 -0
  32. euroeval/dataset_configs/lithuanian.py +64 -0
  33. euroeval/dataset_configs/norwegian.py +98 -72
  34. euroeval/dataset_configs/polish.py +96 -0
  35. euroeval/dataset_configs/portuguese.py +63 -40
  36. euroeval/dataset_configs/serbian.py +64 -0
  37. euroeval/dataset_configs/slovak.py +55 -0
  38. euroeval/dataset_configs/slovene.py +56 -0
  39. euroeval/dataset_configs/spanish.py +68 -34
  40. euroeval/dataset_configs/swedish.py +82 -41
  41. euroeval/dataset_configs/ukrainian.py +64 -0
  42. euroeval/enums.py +12 -6
  43. euroeval/exceptions.py +21 -1
  44. euroeval/finetuning.py +34 -26
  45. euroeval/generation.py +76 -41
  46. euroeval/generation_utils.py +169 -34
  47. euroeval/languages.py +1020 -188
  48. euroeval/logging_utils.py +268 -0
  49. euroeval/metrics/__init__.py +6 -0
  50. euroeval/metrics/base.py +85 -0
  51. euroeval/metrics/huggingface.py +216 -0
  52. euroeval/metrics/llm_as_a_judge.py +260 -0
  53. euroeval/metrics/pipeline.py +289 -0
  54. euroeval/metrics/speed.py +48 -0
  55. euroeval/model_cache.py +40 -21
  56. euroeval/model_config.py +4 -5
  57. euroeval/model_loading.py +3 -0
  58. euroeval/prompt_templates/__init__.py +2 -0
  59. euroeval/prompt_templates/classification.py +206 -0
  60. euroeval/prompt_templates/linguistic_acceptability.py +157 -22
  61. euroeval/prompt_templates/multiple_choice.py +159 -17
  62. euroeval/prompt_templates/named_entity_recognition.py +318 -21
  63. euroeval/prompt_templates/reading_comprehension.py +207 -16
  64. euroeval/prompt_templates/sentiment_classification.py +205 -22
  65. euroeval/prompt_templates/summarization.py +122 -22
  66. euroeval/prompt_templates/token_classification.py +279 -0
  67. euroeval/scores.py +20 -9
  68. euroeval/speed_benchmark.py +11 -12
  69. euroeval/task_group_utils/multiple_choice_classification.py +21 -12
  70. euroeval/task_group_utils/question_answering.py +101 -73
  71. euroeval/task_group_utils/sequence_classification.py +144 -61
  72. euroeval/task_group_utils/text_to_text.py +33 -12
  73. euroeval/task_group_utils/token_classification.py +86 -89
  74. euroeval/tasks.py +75 -16
  75. euroeval/tokenisation_utils.py +603 -0
  76. euroeval/types.py +17 -11
  77. euroeval/utils.py +332 -137
  78. euroeval-16.7.1.dist-info/METADATA +623 -0
  79. euroeval-16.7.1.dist-info/RECORD +84 -0
  80. {euroeval-15.12.0.dist-info → euroeval-16.7.1.dist-info}/entry_points.txt +0 -1
  81. euroeval/human_evaluation.py +0 -737
  82. euroeval/metrics.py +0 -452
  83. euroeval/tokenization_utils.py +0 -498
  84. euroeval-15.12.0.dist-info/METADATA +0 -285
  85. euroeval-15.12.0.dist-info/RECORD +0 -63
  86. {euroeval-15.12.0.dist-info → euroeval-16.7.1.dist-info}/WHEEL +0 -0
  87. {euroeval-15.12.0.dist-info → euroeval-16.7.1.dist-info}/licenses/LICENSE +0 -0
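
Several internal modules are renamed or reorganised in this release: euroeval/tokenization_utils.py becomes euroeval/tokenisation_utils.py, the single euroeval/metrics.py module is split into a euroeval/metrics/ package, and the logging helpers move into the new euroeval/logging_utils.py. A rough sketch of how code importing these internals would need to change (the keyword rename follows the hunks below; everything else is illustrative):

# EuroEval 15.x (modules removed in 16.x)
# from euroeval.tokenization_utils import get_bos_token, get_eos_token
# from euroeval.utils import block_terminal_output, log_once

# EuroEval 16.x, per the import hunks below
from euroeval.tokenisation_utils import get_bos_token, get_eos_token
from euroeval.logging_utils import block_terminal_output, log, log_once

# The tokenizer keyword argument is renamed as well, e.g.
# bos_token, bos_token_id = get_bos_token(tokeniser=tokeniser)
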
@@ -2,7 +2,7 @@

  import collections.abc as c
  import logging
- import os
+ import re
  import typing as t
  from functools import cached_property, partial
  from json import JSONDecodeError
@@ -15,6 +15,7 @@ from huggingface_hub import HfApi
  from huggingface_hub import whoami as hf_whoami
  from huggingface_hub.errors import (
  GatedRepoError,
+ HfHubHTTPError,
  HFValidationError,
  LocalTokenNotFoundError,
  RepositoryNotFoundError,
@@ -35,6 +36,7 @@ from transformers.models.auto.tokenization_auto import AutoTokenizer
  from transformers.trainer import Trainer
  from urllib3.exceptions import RequestError

+ from ..caching_utils import cache_arguments
  from ..constants import (
  DUMMY_FILL_VALUE,
  GENERATIVE_PIPELINE_TAGS,
@@ -42,7 +44,7 @@ from ..constants import (
  MAX_CONTEXT_LENGTH,
  MERGE_TAGS,
  )
- from ..data_models import HFModelInfo, ModelConfig
+ from ..data_models import HashableDict, HFModelInfo, ModelConfig
  from ..enums import (
  BatchingPreference,
  GenerativeType,
@@ -57,19 +59,21 @@ from ..exceptions import (
  NeedsEnvironmentVariable,
  NeedsExtraInstalled,
  )
+ from ..generation_utils import raise_if_wrong_params
  from ..languages import get_all_languages
+ from ..logging_utils import block_terminal_output, log, log_once
  from ..task_group_utils import (
  multiple_choice_classification,
  question_answering,
  token_classification,
  )
- from ..tokenization_utils import get_bos_token, get_eos_token
+ from ..tokenisation_utils import get_bos_token, get_eos_token
  from ..utils import (
- block_terminal_output,
  create_model_cache_dir,
  get_class_by_name,
+ get_hf_token,
  internet_connection_available,
- log_once,
+ split_model_id,
  )
  from .base import BenchmarkModule

@@ -81,8 +85,6 @@ if t.TYPE_CHECKING:
  from ..data_models import BenchmarkConfig, DatasetConfig, Task
  from ..types import ExtractLabelsFunction

- logger = logging.getLogger("euroeval")
-

  class HuggingFaceEncoderModel(BenchmarkModule):
  """An encoder model from the Hugging Face Hub."""
@@ -90,12 +92,14 @@ class HuggingFaceEncoderModel(BenchmarkModule):
  fresh_model = False
  batching_preference = BatchingPreference.NO_PREFERENCE
  high_priority = True
+ allowed_params = {re.compile(r".*"): ["slow-tokenizer"]}

  def __init__(
  self,
  model_config: "ModelConfig",
  dataset_config: "DatasetConfig",
  benchmark_config: "BenchmarkConfig",
+ log_metadata: bool = True,
  ) -> None:
  """Initialise the model.

@@ -106,18 +110,24 @@ class HuggingFaceEncoderModel(BenchmarkModule):
  The dataset configuration.
  benchmark_config:
  The benchmark configuration.
+ log_metadata:
+ Whether to log the model metadata.
  """
- model, tokenizer = load_model_and_tokenizer(
+ raise_if_wrong_params(
+ model_config=model_config, allowed_params=self.allowed_params
+ )
+
+ model, tokeniser = load_model_and_tokeniser(
  model_config=model_config,
  dataset_config=dataset_config,
  benchmark_config=benchmark_config,
  )
  self._model: "PreTrainedModel" = model
- self._tokenizer: "PreTrainedTokenizer" = tokenizer
+ self._tokeniser: "PreTrainedTokenizer" = tokeniser

- self._model, self._tokenizer = align_model_and_tokenizer(
+ self._model, self._tokeniser = align_model_and_tokeniser(
  model=self._model,
- tokenizer=self._tokenizer,
+ tokeniser=self._tokeniser,
  model_max_length=self.model_max_length,
  raise_errors=benchmark_config.raise_errors,
  )
@@ -126,6 +136,7 @@ class HuggingFaceEncoderModel(BenchmarkModule):
  model_config=model_config,
  dataset_config=dataset_config,
  benchmark_config=benchmark_config,
+ log_metadata=log_metadata,
  )

  @cached_property
@@ -135,23 +146,25 @@ class HuggingFaceEncoderModel(BenchmarkModule):
  Returns:
  The number of parameters in the model.
  """
- token = (
- self.benchmark_config.api_key or os.getenv("HUGGINGFACE_API_KEY") or True
- )
- hf_api = HfApi(token=token)
- try:
- repo_info = hf_api.model_info(
- repo_id=self.model_config.adapter_base_model_id
- or self.model_config.model_id,
- revision=self.model_config.revision,
- )
- except (
- RepositoryNotFoundError,
- RevisionNotFoundError,
- RequestException,
- HFValidationError,
- ):
+ # No need to try to use the API if we have no internet.
+ if not internet_connection_available():
  repo_info = None
+ else:
+ token = get_hf_token(api_key=self.benchmark_config.api_key)
+ hf_api = HfApi(token=token)
+ try:
+ repo_info = hf_api.model_info(
+ repo_id=self.model_config.adapter_base_model_id
+ or self.model_config.model_id,
+ revision=self.model_config.revision,
+ )
+ except (
+ RepositoryNotFoundError,
+ RevisionNotFoundError,
+ RequestException,
+ HFValidationError,
+ ):
+ repo_info = None

  if (
  repo_info is not None
@@ -168,12 +181,13 @@ class HuggingFaceEncoderModel(BenchmarkModule):
  elif hasattr(self._model, "parameters"):
  num_params = sum(p.numel() for p in self._model.parameters())
  else:
- logger.warning(
+ log(
  "The number of parameters could not be determined for the model, since "
  "the model is not stored in the safetensors format. If this is your "
  "own model, then you can use this Hugging Face Space to convert your "
  "model to the safetensors format: "
- "https://huggingface.co/spaces/safetensors/convert."
+ "https://huggingface.co/spaces/safetensors/convert.",
+ level=logging.WARNING,
  )
  num_params = -1
  return num_params
@@ -191,10 +205,10 @@ class HuggingFaceEncoderModel(BenchmarkModule):
  ):
  vocab_size = self._model.config.vocab_size
  elif (
- hasattr(self._tokenizer, "vocab_size")
- and self._tokenizer.vocab_size is not None
+ hasattr(self._tokeniser, "vocab_size")
+ and self._tokeniser.vocab_size is not None
  ):
- vocab_size = self._tokenizer.vocab_size
+ vocab_size = self._tokeniser.vocab_size
  else:
  vocab_size = -1
  return vocab_size
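
The repeated inline token lookup api_key or os.getenv("HUGGINGFACE_API_KEY") or True is replaced throughout by a get_hf_token helper imported from euroeval/utils.py. Its implementation is not part of this file's diff; a minimal sketch that reproduces the behaviour it replaces would be:

import os

def get_hf_token(api_key: str | None) -> str | bool:
    # Prefer an explicit API key, then the HUGGINGFACE_API_KEY environment
    # variable, and finally True, which tells huggingface_hub to use the
    # locally stored token.
    return api_key or os.getenv("HUGGINGFACE_API_KEY") or True
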
@@ -208,18 +222,18 @@ class HuggingFaceEncoderModel(BenchmarkModule):
  """
  all_max_lengths: list[int] = list()

- # Add the registered max length of the tokenizer
+ # Add the registered max length of the tokeniser
  if hasattr(
- self._tokenizer, "model_max_length"
- ) and self._tokenizer.model_max_length < int(1e30):
- all_max_lengths.append(self._tokenizer.model_max_length)
+ self._tokeniser, "model_max_length"
+ ) and self._tokeniser.model_max_length < int(1e30):
+ all_max_lengths.append(self._tokeniser.model_max_length)

  # Add the max length derived from the model's input sizes
- if hasattr(self._tokenizer, "max_model_input_sizes"):
+ if hasattr(self._tokeniser, "max_model_input_sizes"):
  all_max_lengths.extend(
  [
  size
- for size in self._tokenizer.max_model_input_sizes.values()
+ for size in self._tokeniser.max_model_input_sizes.values()
  if size is not None
  ]
  )
@@ -245,15 +259,6 @@ class HuggingFaceEncoderModel(BenchmarkModule):
  max_length for max_length in all_max_lengths if max_length >= 128
  ]

- # We remove the upper cap of maximum context length for the model, as it is
- # highly unlikely that this is the model's actual maximum context length - we
- # would rather not report a value than report an incorrect one.
- all_max_lengths = [
- max_length
- for max_length in all_max_lengths
- if max_length != MAX_CONTEXT_LENGTH
- ]
-
  if len(list(all_max_lengths)) > 0:
  model_max_length = min(list(all_max_lengths))
  else:
@@ -262,7 +267,7 @@ class HuggingFaceEncoderModel(BenchmarkModule):
  return model_max_length

  @property
- def data_collator(self) -> c.Callable[[list[t.Any]], dict[str, t.Any]]:
+ def data_collator(self) -> c.Callable[[c.Sequence[t.Any]], dict[str, t.Any]]:
  """The data collator used to prepare samples during finetuning.

  Returns:
@@ -275,10 +280,10 @@ class HuggingFaceEncoderModel(BenchmarkModule):
  | TaskGroup.QUESTION_ANSWERING
  | TaskGroup.MULTIPLE_CHOICE_CLASSIFICATION
  ):
- return DataCollatorWithPadding(self._tokenizer, padding="longest")
+ return DataCollatorWithPadding(self._tokeniser, padding="longest")
  case TaskGroup.TOKEN_CLASSIFICATION:
  return DataCollatorForTokenClassification(
- tokenizer=self._tokenizer, label_pad_token_id=-100
+ tokenizer=self._tokeniser, label_pad_token_id=-100
  )
  case _:
  raise NotImplementedError(
@@ -357,16 +362,16 @@ class HuggingFaceEncoderModel(BenchmarkModule):
  self._model.config.label2id[lbl.lower()]
  for lbl in examples["label"]
  ]
- except KeyError:
+ except KeyError as e:
  raise InvalidBenchmark(
  f"One of the labels in the dataset, "
  f"{examples['label'].lower()}, does not occur in the "
  f"label2id dictionary {self._model.config.label2id}."
- )
+ ) from e
  return examples

  def tokenise(examples: dict) -> "BatchEncoding":
- return self._tokenizer(text=examples["text"], truncation=True, padding=True)
+ return self._tokeniser(text=examples["text"], truncation=True, padding=True)

  match task.task_group:
  case TaskGroup.SEQUENCE_CLASSIFICATION:
@@ -376,39 +381,20 @@ class HuggingFaceEncoderModel(BenchmarkModule):

  case TaskGroup.MULTIPLE_CHOICE_CLASSIFICATION:
  dataset = DatasetDict(
- train=dataset["train"].map(
- partial(
- multiple_choice_classification.prepare_examples,
- tokenizer=self._tokenizer,
- ),
- batched=True,
- batch_size=10,
- remove_columns=dataset["train"].column_names,
- load_from_cache_file=False,
- keep_in_memory=True,
- ),
- val=dataset["val"].map(
- partial(
- multiple_choice_classification.prepare_examples,
- tokenizer=self._tokenizer,
- ),
- batched=True,
- batch_size=10,
- remove_columns=dataset["val"].column_names,
- load_from_cache_file=False,
- keep_in_memory=True,
- ),
- test=dataset["test"].map(
- partial(
- multiple_choice_classification.prepare_examples,
- tokenizer=self._tokenizer,
- ),
- batched=True,
- batch_size=10,
- remove_columns=dataset["test"].column_names,
- load_from_cache_file=False,
- keep_in_memory=True,
- ),
+ {
+ split_name: split.map(
+ partial(
+ multiple_choice_classification.prepare_examples,
+ tokeniser=self._tokeniser,
+ ),
+ batched=True,
+ batch_size=10,
+ remove_columns=split.column_names,
+ load_from_cache_file=False,
+ keep_in_memory=True,
+ )
+ for split_name, split in dataset.items()
+ }
  )

  case TaskGroup.TEXT_TO_TEXT:
@@ -423,7 +409,7 @@ class HuggingFaceEncoderModel(BenchmarkModule):
  dataset = dataset.map(
  partial(
  token_classification.tokenize_and_align_labels,
- tokenizer=self._tokenizer,
+ tokeniser=self._tokeniser,
  label2id=self._model.config.label2id,
  ),
  batched=True,
@@ -432,43 +418,44 @@ class HuggingFaceEncoderModel(BenchmarkModule):
  )

  case TaskGroup.QUESTION_ANSWERING:
- dataset = DatasetDict(
- dict(
- train=dataset["train"].map(
- partial(
- question_answering.prepare_train_examples,
- tokenizer=self._tokenizer,
- ),
- batched=True,
- batch_size=10,
- remove_columns=dataset["test"].column_names,
- load_from_cache_file=False,
- keep_in_memory=True,
+ data_dict = dict()
+ if "train" in dataset:
+ data_dict["train"] = dataset["train"].map(
+ partial(
+ question_answering.prepare_train_examples,
+ tokeniser=self._tokeniser,
  ),
- val=dataset["val"].map(
- partial(
- question_answering.prepare_train_examples,
- tokenizer=self._tokenizer,
- ),
- batched=True,
- batch_size=10,
- remove_columns=dataset["test"].column_names,
- load_from_cache_file=False,
- keep_in_memory=True,
+ batched=True,
+ batch_size=10,
+ remove_columns=dataset["test"].column_names,
+ load_from_cache_file=False,
+ keep_in_memory=True,
+ )
+ if "val" in dataset:
+ data_dict["val"] = dataset["val"].map(
+ partial(
+ question_answering.prepare_train_examples,
+ tokeniser=self._tokeniser,
  ),
- test=dataset["test"].map(
- partial(
- question_answering.prepare_test_examples,
- tokenizer=self._tokenizer,
- ),
- batched=True,
- batch_size=10,
- remove_columns=dataset["test"].column_names,
- load_from_cache_file=False,
- keep_in_memory=True,
+ batched=True,
+ batch_size=10,
+ remove_columns=dataset["test"].column_names,
+ load_from_cache_file=False,
+ keep_in_memory=True,
+ )
+ if "test" in dataset:
+ data_dict["test"] = dataset["test"].map(
+ partial(
+ question_answering.prepare_test_examples,
+ tokeniser=self._tokeniser,
  ),
+ batched=True,
+ batch_size=10,
+ remove_columns=dataset["test"].column_names,
+ load_from_cache_file=False,
+ keep_in_memory=True,
  )
- )
+ dataset = DatasetDict(data_dict)

  # The Trainer hides the columns that are not used by the model (here
  # `id` and `offset_mapping` which we will need for our post-processing),
@@ -499,11 +486,15 @@ class HuggingFaceEncoderModel(BenchmarkModule):
  Whether the model exists, or an error describing why we cannot check
  whether the model exists.
  """
- model_id, revision = (
- model_id.split("@") if "@" in model_id else (model_id, "main")
- )
+ model_id_components = split_model_id(model_id=model_id)
  model_info = get_model_repo_info(
- model_id=model_id, revision=revision, benchmark_config=benchmark_config
+ model_id=model_id_components.model_id,
+ revision=model_id_components.revision,
+ api_key=benchmark_config.api_key,
+ cache_dir=benchmark_config.cache_dir,
+ trust_remote_code=benchmark_config.trust_remote_code,
+ requires_safetensors=benchmark_config.requires_safetensors,
+ run_with_cli=benchmark_config.run_with_cli,
  )
  return (
  model_info is not None
@@ -525,11 +516,15 @@ class HuggingFaceEncoderModel(BenchmarkModule):
  Returns:
  The model configuration.
  """
- model_id, revision = (
- model_id.split("@") if "@" in model_id else (model_id, "main")
- )
+ model_id_components = split_model_id(model_id=model_id)
  model_info = get_model_repo_info(
- model_id=model_id, revision=revision, benchmark_config=benchmark_config
+ model_id=model_id_components.model_id,
+ revision=model_id_components.revision,
+ api_key=benchmark_config.api_key,
+ cache_dir=benchmark_config.cache_dir,
+ trust_remote_code=benchmark_config.trust_remote_code,
+ requires_safetensors=benchmark_config.requires_safetensors,
+ run_with_cli=benchmark_config.run_with_cli,
  )
  if model_info is None:
  raise InvalidModel(f"The model {model_id!r} could not be found.")
@@ -538,8 +533,9 @@ class HuggingFaceEncoderModel(BenchmarkModule):
  language_codes = list(language_mapping.keys())

  model_config = ModelConfig(
- model_id=model_id,
- revision=revision,
+ model_id=model_id_components.model_id,
+ revision=model_id_components.revision,
+ param=model_id_components.param,
  task=model_info.pipeline_tag,
  languages=[
  language_mapping[tag]
@@ -559,12 +555,12 @@ class HuggingFaceEncoderModel(BenchmarkModule):
  return model_config


- def load_model_and_tokenizer(
+ def load_model_and_tokeniser(
  model_config: "ModelConfig",
  dataset_config: "DatasetConfig",
  benchmark_config: "BenchmarkConfig",
  ) -> tuple["PreTrainedModel", "PreTrainedTokenizer"]:
- """Load the model and tokenizer.
+ """Load the model and tokeniser.

  Args:
  model_config:
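
Parsing of model identifiers is centralised in a split_model_id helper from euroeval/utils.py, replacing the inline model_id.split("@") with its "main" fallback and adding an optional parameter component (used below for values such as "slow-tokenizer"). The attribute names follow the call sites in this diff; the return type and the "#" separator are illustrative assumptions only:

from typing import NamedTuple

class ModelIdComponents(NamedTuple):
    model_id: str
    revision: str
    param: str | None

def split_model_id(model_id: str) -> ModelIdComponents:
    # Illustrative sketch assuming a "repo@revision#param" style syntax.
    model_id, _, param = model_id.partition("#")
    model_id, _, revision = model_id.partition("@")
    return ModelIdComponents(
        model_id=model_id, revision=revision or "main", param=param or None
    )
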
@@ -575,7 +571,7 @@ def load_model_and_tokenizer(
  The benchmark configuration

  Returns:
- The loaded model and tokenizer.
+ A pair (model, tokeniser), with the loaded model and tokeniser
  """
  config: "PretrainedConfig"
  block_terminal_output()
@@ -594,8 +590,8 @@ def load_model_and_tokenizer(
  config = load_hf_model_config(
  model_id=model_id,
  num_labels=len(id2label),
- id2label=id2label,
- label2id={label: idx for idx, label in id2label.items()},
+ id2label=HashableDict(id2label),
+ label2id=HashableDict({label: idx for idx, label in id2label.items()}),
  revision=model_config.revision,
  model_cache_dir=model_config.model_cache_dir,
  api_key=benchmark_config.api_key,
@@ -607,23 +603,20 @@ def load_model_and_tokenizer(
  config=config,
  ignore_mismatched_sizes=ignore_mismatched_sizes,
  revision=model_config.revision,
- token=benchmark_config.api_key or os.getenv("HUGGINGFACE_API_KEY") or True,
+ token=get_hf_token(api_key=benchmark_config.api_key),
  cache_dir=model_config.model_cache_dir,
  trust_remote_code=benchmark_config.trust_remote_code,
- torch_dtype=get_torch_dtype(
+ dtype=get_dtype(
  device=benchmark_config.device,
- torch_dtype_is_set=config.to_dict().get("torch_dtype") is not None,
+ dtype_is_set=config.to_dict().get("dtype") is not None,
  bf16_available=(
  torch.cuda.is_available() and torch.cuda.is_bf16_supported()
  ),
  ),
  )

- # These are used when a timeout occurs
- attempts_left = 5
-
  model: "PreTrainedModel | None" = None
- while True:
+ for _ in range(num_attempts := 5):
  # Get the model class associated with the task group
  model_cls_or_none: t.Type["PreTrainedModel"] | None = get_class_by_name(
  class_name=task_group_to_class_name(task_group=task_group),
@@ -650,36 +643,41 @@ def load_model_and_tokenizer(
  break
  except (KeyError, RuntimeError) as e:
  if not model_kwargs["ignore_mismatched_sizes"]:
- logger.debug(
+ log(
  f"{type(e).__name__} occurred during the loading "
  f"of the {model_id!r} model. Retrying with "
- "`ignore_mismatched_sizes` set to True."
+ "`ignore_mismatched_sizes` set to True.",
+ level=logging.DEBUG,
  )
  model_kwargs["ignore_mismatched_sizes"] = True
  continue
  else:
- raise InvalidModel(str(e))
+ raise InvalidModel(str(e)) from e
  except (TimeoutError, RequestError):
- attempts_left -= 1
- if attempts_left == 0:
- raise InvalidModel("The model could not be loaded after 5 attempts.")
- logger.info(f"Couldn't load the model {model_id!r}. Retrying.")
+ log(
+ f"Couldn't load the model {model_id!r}. Retrying.",
+ level=logging.WARNING,
+ )
  sleep(5)
  continue
  except (OSError, ValueError) as e:
  if "checkpoint seems to be incorrect" in str(e):
  raise InvalidModel(
  f"The model {model_id!r} has an incorrect checkpoint."
- )
+ ) from e
  if "trust_remote_code" in str(e):
  raise InvalidModel(
  f"Loading the model {model_id!r} needs to trust remote code. "
  "If you trust the suppliers of this model, then you can enable "
  "this by setting the `--trust-remote-code` flag."
- )
+ ) from e
  raise InvalidModel(
  f"The model {model_id!r} could not be loaded. The error was {e!r}."
- )
+ ) from e
+ else:
+ raise InvalidModel(
+ f"Could not load the model {model_id!r} after {num_attempts} attempts."
+ )

  if isinstance(model_or_tuple, tuple):
  model = model_or_tuple[0]
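
The unbounded while True retry loops with hand-maintained attempts_left counters are replaced by bounded for ... else loops: the else branch runs only when the loop finishes without hitting break, i.e. when every attempt has failed. Reduced to its essentials (the loading call and the exception types are placeholders), the pattern is:

from time import sleep

for _ in range(num_attempts := 5):
    try:
        model = load_model()  # placeholder for the actual loading call
        break
    except (TimeoutError, ConnectionError):
        sleep(5)
        continue
else:
    # Reached only if no attempt succeeded (no break was executed).
    raise RuntimeError(f"Could not load the model after {num_attempts} attempts.")
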
@@ -697,17 +695,25 @@ def load_model_and_tokenizer(
  ):
  model = setup_model_for_question_answering(model=model)

- tokenizer = load_tokenizer(
+ tokeniser = load_tokeniser(
  model=model,
  model_id=model_id,
  trust_remote_code=benchmark_config.trust_remote_code,
+ model_config=model_config,
  )

- return model, tokenizer
+ return model, tokeniser


+ @cache_arguments("model_id", "revision")
  def get_model_repo_info(
- model_id: str, revision: str, benchmark_config: "BenchmarkConfig"
+ model_id: str,
+ revision: str,
+ api_key: str | None,
+ cache_dir: str,
+ trust_remote_code: bool,
+ requires_safetensors: bool,
+ run_with_cli: bool,
  ) -> "HFModelInfo | None":
  """Get the information about the model from the HF Hub or a local directory.

@@ -716,28 +722,30 @@ def get_model_repo_info(
  The model ID.
  revision:
  The revision of the model.
- benchmark_config:
- The benchmark configuration.

  Returns:
  The information about the model, or None if the model could not be found.
  """
- token = benchmark_config.api_key or os.getenv("HUGGINGFACE_API_KEY") or True
+ token = get_hf_token(api_key=api_key)
  hf_api = HfApi(token=token)
- model_id, revision = model_id.split("@") if "@" in model_id else (model_id, "main")

  # Get information on the model.
  # The first case is when the model is a local model, in which case we create a dummy
  # model info object.
  model_info: HfApiModelInfo | None = None
  if Path(model_id).is_dir():
- logger.debug(f"Checking for local model in {model_id}.")
+ log(f"Checking for local model in {model_id}.", level=logging.DEBUG)
  if all(
  (Path(model_id) / required_file).exists()
  for required_file in LOCAL_MODELS_REQUIRED_FILES
  ):
  model_info = HfApiModelInfo(id=model_id, tags=None, pipeline_tag=None)

+ # If we have not internet, and the model_id is not a directory for a local model
+ # we also just create a dummy model info object.
+ elif not internet_connection_available():
+ model_info = HfApiModelInfo(id=model_id, tags=None, pipeline_tag=None)
+
  # If the model does not exist locally, then we get the model info from the Hugging
  # Face Hub, if possible
  if model_info is None:
@@ -752,35 +760,39 @@ def get_model_repo_info(
  except (GatedRepoError, LocalTokenNotFoundError) as e:
  try:
  hf_whoami(token=token)
- logger.debug(
+ log(
  f"Could not access the model {model_id} with the revision "
- f"{revision}. The error was {str(e)!r}."
+ f"{revision}. The error was {str(e)!r}.",
+ level=logging.DEBUG,
  )
  return None
  except LocalTokenNotFoundError:
- logger.debug(
+ log(
  f"Could not access the model {model_id} with the revision "
  f"{revision}. The error was {str(e)!r}. Please set the "
  "`HUGGINGFACE_API_KEY` environment variable or use the "
- "`--api-key` argument."
+ "`--api-key` argument.",
+ level=logging.DEBUG,
  )
  return None
- except (RepositoryNotFoundError, HFValidationError):
+ except (RepositoryNotFoundError, HFValidationError, HfHubHTTPError):
  return None
  except (OSError, RequestException) as e:
  if internet_connection_available():
  errors.append(e)
  continue
- logger.debug(
+ log(
  "Could not access the Hugging Face Hub. Please check your internet "
- "connection."
+ "connection.",
+ level=logging.DEBUG,
  )
  return None
  else:
- logger.debug(
+ log(
  f"Could not access model info for the model {model_id!r} from the "
  f"Hugging Face Hub, after {num_attempts} attempts. The errors "
- f"encountered were {errors!r}."
+ f"encountered were {errors!r}.",
+ level=logging.DEBUG,
  )
  return None

@@ -800,12 +812,7 @@ def get_model_repo_info(
  level=logging.DEBUG,
  )
  if base_model_id is not None:
- base_model_info = hf_api.model_info(
- repo_id=base_model_id,
- token=benchmark_config.api_key
- or os.getenv("HUGGINGFACE_API_KEY")
- or True,
- )
+ base_model_info = hf_api.model_info(repo_id=base_model_id, token=token)
  tags += base_model_info.tags or list()
  tags = list(set(tags))

@@ -816,15 +823,15 @@ def get_model_repo_info(
  hf_config = load_hf_model_config(
  model_id=base_model_id or model_id,
  num_labels=0,
- id2label=dict(),
- label2id=dict(),
+ id2label=HashableDict(),
+ label2id=HashableDict(),
  revision=revision,
  model_cache_dir=create_model_cache_dir(
- cache_dir=benchmark_config.cache_dir, model_id=model_id
+ cache_dir=cache_dir, model_id=model_id
  ),
- api_key=benchmark_config.api_key,
- trust_remote_code=benchmark_config.trust_remote_code,
- run_with_cli=benchmark_config.run_with_cli,
+ api_key=api_key,
+ trust_remote_code=trust_remote_code,
+ run_with_cli=run_with_cli,
  )
  class_names = hf_config.architectures
  generative_class_names = [
@@ -839,19 +846,19 @@ def get_model_repo_info(
  else:
  pipeline_tag = "fill-mask"

- if benchmark_config.only_allow_safetensors:
+ if requires_safetensors:
  repo_files = hf_api.list_repo_files(repo_id=model_id, revision=revision)
  has_safetensors = any(f.endswith(".safetensors") for f in repo_files)
  if not has_safetensors:
  msg = f"Model {model_id} does not have safetensors weights available. "
- if benchmark_config.run_with_cli:
+ if run_with_cli:
  msg += "Skipping since the `--only-allow-safetensors` flag is set."
  else:
  msg += (
- "Skipping since the `only_allow_safetensors` argument is set "
+ "Skipping since the `requires_safetensors` argument is set "
  "to `True`."
  )
- logger.warning(msg)
+ log(msg, level=logging.WARNING)
  return None

  # Also check base model if we are evaluating an adapter
@@ -865,11 +872,11 @@ def get_model_repo_info(
  f"Base model {base_model_id} does not have safetensors weights "
  "available."
  )
- if benchmark_config.run_with_cli:
+ if run_with_cli:
  msg += " Skipping since the `--only-allow-safetensors` flag is set."
  else:
  msg += (
- " Skipping since the `only_allow_safetensors` argument is set "
+ " Skipping since the `requires_safetensors` argument is set "
  "to `True`."
  )
  logging.warning(msg)
@@ -880,10 +887,13 @@ def get_model_repo_info(
  )


- def load_tokenizer(
- model: "PreTrainedModel | None", model_id: str, trust_remote_code: bool
+ def load_tokeniser(
+ model: "PreTrainedModel | None",
+ model_id: str,
+ trust_remote_code: bool,
+ model_config: "ModelConfig",
  ) -> "PreTrainedTokenizer":
- """Load the tokenizer.
+ """Load the tokeniser.

  Args:
  model:
@@ -893,16 +903,19 @@ def load_tokenizer(
  The model identifier. Used for logging.
  trust_remote_code:
  Whether to trust remote code.
+ model_config:
+ The model configuration.

  Returns:
- The loaded tokenizer.
+ The loaded tokeniser.
  """
  loading_kwargs: dict[str, bool | str] = dict(
- use_fast=True,
+ use_fast=False if model_config.param == "slow-tokenizer" else True,
  verbose=False,
  trust_remote_code=trust_remote_code,
  padding_side="right",
  truncation_side="right",
+ cache_dir=model_config.model_cache_dir,
  )

  # If the model is a subclass of a certain model types then we have to add a prefix
@@ -918,45 +931,51 @@ def load_tokenizer(
  num_retries = 5
  for _ in range(num_retries):
  try:
- tokenizer = AutoTokenizer.from_pretrained(model_id, **loading_kwargs)
+ tokeniser = AutoTokenizer.from_pretrained(model_id, **loading_kwargs)
  break
- except (JSONDecodeError, OSError, TypeError):
- raise InvalidModel(f"Could not load tokenizer for model {model_id!r}.")
+ except (JSONDecodeError, OSError, TypeError) as e:
+ raise InvalidModel(
+ f"Could not load tokeniser for model {model_id!r}."
+ ) from e
  except (TimeoutError, RequestError):
- logger.info(f"Couldn't load tokenizer for {model_id!r}. Retrying.")
+ log(
+ f"Couldn't load tokeniser for {model_id!r}. Retrying.",
+ level=logging.WARNING,
+ )
  sleep(5)
  continue
  else:
  raise InvalidModel(
- f"Could not load tokenizer for model {model_id!r} after {num_retries} "
+ f"Could not load tokeniser for model {model_id!r} after {num_retries} "
  "attempts."
  )

  # Ensure that BOS, EOS and PAD tokens are set
- tokenizer.bos_token, tokenizer.bos_token_id = get_bos_token(tokenizer=tokenizer)
- tokenizer.eos_token, tokenizer.eos_token_id = get_eos_token(tokenizer=tokenizer)
+ tokeniser.bos_token, tokeniser.bos_token_id = get_bos_token(tokeniser=tokeniser)
+ tokeniser.eos_token, tokeniser.eos_token_id = get_eos_token(tokeniser=tokeniser)

- return tokenizer
+ return tokeniser


- def get_torch_dtype(
- device: torch.device, torch_dtype_is_set: bool, bf16_available: bool
+ @cache_arguments()
+ def get_dtype(
+ device: torch.device, dtype_is_set: bool, bf16_available: bool
  ) -> str | torch.dtype:
  """Get the torch dtype, used for loading the model.

  Args:
  device:
  The device to use.
- torch_dtype_is_set:
- Whether the torch data type is set in the model configuration.
+ dtype_is_set:
+ Whether the data type is set in the model configuration.
  bf16_available:
  Whether bfloat16 is available.

  Returns:
- The torch dtype.
+ The dtype.
  """
  using_cuda = device == torch.device("cuda")
- if using_cuda and torch_dtype_is_set:
+ if using_cuda and dtype_is_set:
  return "auto"
  elif using_cuda and bf16_available:
  return torch.bfloat16
@@ -965,6 +984,7 @@
  return torch.float32


+ @cache_arguments("model_id", "revision", "num_labels", "id2label", "label2id")
  def load_hf_model_config(
  model_id: str,
  num_labels: int,
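
The new @cache_arguments(...) decorator from euroeval/caching_utils.py memoises a function's result, keyed on the named arguments listed in the decorator call; this is also why plain dicts are wrapped in HashableDict before being passed to load_hf_model_config. The real implementation is not shown in this diff; a minimal sketch of the idea, for keyword-argument calls like the ones in this file, could look like:

import functools
import typing as t

def cache_arguments(*arg_names: str) -> t.Callable:
    """Cache results keyed on the given keyword arguments (all of them if none are given)."""
    def decorator(func: t.Callable) -> t.Callable:
        cache: dict[tuple, t.Any] = {}

        @functools.wraps(func)
        def wrapper(**kwargs: t.Any) -> t.Any:
            keys = arg_names or tuple(sorted(kwargs))
            cache_key = tuple(kwargs.get(name) for name in keys)
            if cache_key not in cache:
                cache[cache_key] = func(**kwargs)
            return cache[cache_key]

        return wrapper
    return decorator

# Usage mirroring the diff: @cache_arguments("model_id", "revision")
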
@@ -1001,7 +1021,7 @@ def load_hf_model_config(
  Returns:
  The Hugging Face model configuration.
  """
- while True:
+ for _ in range(num_attempts := 5):
  try:
  config = AutoConfig.from_pretrained(
  model_id,
@@ -1009,35 +1029,36 @@ def load_hf_model_config(
  id2label=id2label,
  label2id=label2id,
  revision=revision,
- token=api_key or os.getenv("HUGGINGFACE_API_KEY") or True,
+ token=get_hf_token(api_key=api_key),
  trust_remote_code=trust_remote_code,
  cache_dir=model_cache_dir,
+ local_files_only=not internet_connection_available(),
  )
- if config.eos_token_id is not None and config.pad_token_id is None:
- if isinstance(config.eos_token_id, list):
- config.pad_token_id = config.eos_token_id[0]
- else:
- config.pad_token_id = config.eos_token_id
- return config
+ break
  except KeyError as e:
  key = e.args[0]
  raise InvalidModel(
  f"The model config for the model {model_id!r} could not be "
  f"loaded, as the key {key!r} was not found in the config."
- )
+ ) from e
  except (OSError, GatedRepoError) as e:
- # TEMP: When the model is gated then we cannot set cache dir, for some
- # reason (since transformers v4.38.2, still a problem in v4.48.0). This
- # should be included back in when this is fixed.
- if "gated repo" in str(e):
- model_cache_dir = None
- continue
+ if isinstance(e, GatedRepoError) or "gated repo" in str(e).lower():
+ raise InvalidModel(
+ f"The model {model_id!r} is a gated repository. Please ensure "
+ "that you are logged in with `hf auth login` or have provided a "
+ "valid Hugging Face access token with the `HUGGINGFACE_API_KEY` "
+ "environment variable or the `--api-key` argument. Also check that "
+ "your account has access to this model."
+ ) from e
  raise InvalidModel(
  f"Couldn't load model config for {model_id!r}. The error was "
  f"{e!r}. Skipping"
- )
+ ) from e
  except (TimeoutError, RequestError):
- logger.info(f"Couldn't load model config for {model_id!r}. Retrying.")
+ log(
+ f"Couldn't load model config for {model_id!r}. Retrying.",
+ level=logging.WARNING,
+ )
  sleep(5)
  continue
  except ValueError as e:
@@ -1045,17 +1066,31 @@ def load_hf_model_config(
  raise InvalidModel(
  f"The model {model_id!r} is awaiting a review from the repository "
  "authors. Please try again later."
- )
+ ) from e
  if "trust_remote_code" in str(e):
  raise NeedsAdditionalArgument(
  cli_argument="--trust-remote-code",
  script_argument="trust_remote_code=True",
  run_with_cli=run_with_cli,
- )
+ ) from e
  raise InvalidModel(
  f"The config for the model {model_id!r} could not be loaded. The "
  f"error was {e!r}."
- )
+ ) from e
+ else:
+ raise InvalidModel(
+ f"Couldn't load model config for {model_id!r} after {num_attempts} "
+ "attempts."
+ )
+
+ # Ensure that the PAD token ID is set
+ if config.eos_token_id is not None and config.pad_token_id is None:
+ if isinstance(config.eos_token_id, list):
+ config.pad_token_id = config.eos_token_id[0]
+ else:
+ config.pad_token_id = config.eos_token_id
+
+ return config


  def setup_model_for_question_answering(model: "PreTrainedModel") -> "PreTrainedModel":
@@ -1140,33 +1175,33 @@ def get_children_of_module(
  return submodules


- def align_model_and_tokenizer(
+ def align_model_and_tokeniser(
  model: "PreTrainedModel",
- tokenizer: "PreTrainedTokenizer",
+ tokeniser: "PreTrainedTokenizer",
  model_max_length: int,
  raise_errors: bool = False,
  ) -> tuple["PreTrainedModel", "PreTrainedTokenizer"]:
- """Aligns the model and the tokenizer.
+ """Aligns the model and the tokeniser.

  Args:
  model:
  The model to fix.
- tokenizer:
- The tokenizer to fix.
+ tokeniser:
+ The tokeniser to fix.
  model_max_length:
  The maximum length of the model.
  raise_errors:
  Whether to raise errors instead of trying to fix them silently.

  Returns:
- The fixed model and tokenizer.
+ The fixed model and tokeniser.
  """
  model_max_length = min(model_max_length, MAX_CONTEXT_LENGTH)

  if model_max_length > 0:
- tokenizer.model_max_length = model_max_length
+ tokeniser.model_max_length = model_max_length
  else:
- tokenizer.model_max_length = 512
+ tokeniser.model_max_length = 512

  # Move the model to the CPU, since otherwise we can't catch the IndexErrors when
  # finding the maximum sequence length of the model
@@ -1175,9 +1210,9 @@ def align_model_and_tokenizer(

  # Manually check that this model max length is valid for the model, and adjust
  # otherwise
- initial_max_length = tokenizer.model_max_length
+ initial_max_length = tokeniser.model_max_length
  for max_length in range(initial_max_length, 0, -1):
- tokenizer.model_max_length = max_length
+ tokeniser.model_max_length = max_length
  dummy_inputs = torch.full(
  size=(1, max_length),
  fill_value=DUMMY_FILL_VALUE,
@@ -1204,26 +1239,27 @@ def align_model_and_tokenizer(
  # Move the model back to the original device
  model.to(model_device) # type: ignore[arg-type]

- # If there is a mismatch between the vocab size according to the tokenizer and
+ # If there is a mismatch between the vocab size according to the tokeniser and
  # the vocab size according to the model, we raise an error
  if hasattr(model.config, "vocab_size"):
- if model.config.vocab_size < len(tokenizer):
+ if model.config.vocab_size < len(tokeniser):
  if raise_errors:
  raise InvalidModel(
- "The vocab size of the tokenizer is larger than the vocab size of "
+ "The vocab size of the tokeniser is larger than the vocab size of "
  "the model. As the --raise-errors option was specified, the "
  "embeddings of the model will not be automatically adjusted."
  )
  if hasattr(model, "resize_token_embeddings"):
- model.resize_token_embeddings(new_num_tokens=tokenizer.vocab_size + 1)
+ model.resize_token_embeddings(new_num_tokens=tokeniser.vocab_size + 1)

- if tokenizer.bos_token is None and tokenizer.eos_token is not None:
- tokenizer.bos_token = tokenizer.eos_token
- tokenizer.bos_token_id = tokenizer.eos_token_id
+ if tokeniser.bos_token is None and tokeniser.eos_token is not None:
+ tokeniser.bos_token = tokeniser.eos_token
+ tokeniser.bos_token_id = tokeniser.eos_token_id

- return model, tokenizer
+ return model, tokeniser


+ @cache_arguments()
  def task_group_to_class_name(task_group: TaskGroup) -> str:
  """Convert a task group to a class name.