transformers-rb 0.1.0

This diff represents the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65)
  1. checksums.yaml +7 -0
  2. data/CHANGELOG.md +3 -0
  3. data/LICENSE.txt +203 -0
  4. data/README.md +163 -0
  5. data/lib/transformers/activations.rb +57 -0
  6. data/lib/transformers/configuration_utils.rb +285 -0
  7. data/lib/transformers/convert_slow_tokenizer.rb +90 -0
  8. data/lib/transformers/data/processors/squad.rb +115 -0
  9. data/lib/transformers/dynamic_module_utils.rb +25 -0
  10. data/lib/transformers/feature_extraction_utils.rb +110 -0
  11. data/lib/transformers/hf_hub/constants.rb +71 -0
  12. data/lib/transformers/hf_hub/errors.rb +11 -0
  13. data/lib/transformers/hf_hub/file_download.rb +764 -0
  14. data/lib/transformers/hf_hub/utils/_errors.rb +94 -0
  15. data/lib/transformers/hf_hub/utils/_headers.rb +109 -0
  16. data/lib/transformers/image_processing_base.rb +169 -0
  17. data/lib/transformers/image_processing_utils.rb +63 -0
  18. data/lib/transformers/image_transforms.rb +208 -0
  19. data/lib/transformers/image_utils.rb +165 -0
  20. data/lib/transformers/modeling_outputs.rb +81 -0
  21. data/lib/transformers/modeling_utils.rb +888 -0
  22. data/lib/transformers/models/auto/auto_factory.rb +138 -0
  23. data/lib/transformers/models/auto/configuration_auto.rb +61 -0
  24. data/lib/transformers/models/auto/feature_extraction_auto.rb +20 -0
  25. data/lib/transformers/models/auto/image_processing_auto.rb +104 -0
  26. data/lib/transformers/models/auto/modeling_auto.rb +80 -0
  27. data/lib/transformers/models/auto/tokenization_auto.rb +160 -0
  28. data/lib/transformers/models/bert/configuration_bert.rb +65 -0
  29. data/lib/transformers/models/bert/modeling_bert.rb +836 -0
  30. data/lib/transformers/models/bert/tokenization_bert.rb +115 -0
  31. data/lib/transformers/models/bert/tokenization_bert_fast.rb +52 -0
  32. data/lib/transformers/models/distilbert/configuration_distilbert.rb +63 -0
  33. data/lib/transformers/models/distilbert/modeling_distilbert.rb +616 -0
  34. data/lib/transformers/models/distilbert/tokenization_distilbert.rb +114 -0
  35. data/lib/transformers/models/distilbert/tokenization_distilbert_fast.rb +71 -0
  36. data/lib/transformers/models/vit/configuration_vit.rb +60 -0
  37. data/lib/transformers/models/vit/image_processing_vit.rb +170 -0
  38. data/lib/transformers/models/vit/modeling_vit.rb +506 -0
  39. data/lib/transformers/pipelines/_init.rb +348 -0
  40. data/lib/transformers/pipelines/base.rb +301 -0
  41. data/lib/transformers/pipelines/feature_extraction.rb +47 -0
  42. data/lib/transformers/pipelines/image_classification.rb +110 -0
  43. data/lib/transformers/pipelines/image_feature_extraction.rb +56 -0
  44. data/lib/transformers/pipelines/pt_utils.rb +53 -0
  45. data/lib/transformers/pipelines/question_answering.rb +508 -0
  46. data/lib/transformers/pipelines/text_classification.rb +123 -0
  47. data/lib/transformers/pipelines/token_classification.rb +282 -0
  48. data/lib/transformers/ruby_utils.rb +33 -0
  49. data/lib/transformers/sentence_transformer.rb +37 -0
  50. data/lib/transformers/tokenization_utils.rb +152 -0
  51. data/lib/transformers/tokenization_utils_base.rb +937 -0
  52. data/lib/transformers/tokenization_utils_fast.rb +386 -0
  53. data/lib/transformers/torch_utils.rb +25 -0
  54. data/lib/transformers/utils/_init.rb +31 -0
  55. data/lib/transformers/utils/generic.rb +107 -0
  56. data/lib/transformers/utils/hub.rb +209 -0
  57. data/lib/transformers/utils/import_utils.rb +45 -0
  58. data/lib/transformers/utils/logging.rb +52 -0
  59. data/lib/transformers/version.rb +3 -0
  60. data/lib/transformers-rb.rb +1 -0
  61. data/lib/transformers.rb +100 -0
  62. data/licenses/LICENSE-huggingface-hub.txt +201 -0
  63. data/licenses/LICENSE-sentence-transformers.txt +201 -0
  64. data/licenses/NOTICE-sentence-transformers.txt +5 -0
  65. metadata +161 -0
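
The listing above covers the gem's pipeline, tokenizer, and model layers. As a quick orientation, here is a minimal usage sketch of the pipeline entry point shipped in this release (data/lib/transformers/pipelines/_init.rb); the task name, input text, and result shape are illustrative assumptions rather than content of this diff — data/README.md in the release is the authoritative reference.

require "transformers-rb"

# Hypothetical sketch of the pipeline API ported in this release; the task name
# and result shape are assumptions, not taken from the diff below.
classifier = Transformers.pipeline("sentiment-analysis")
result = classifier.("This gem brings Hugging Face pipelines to Ruby.")
# result is expected to be a Hash along the lines of {label: "POSITIVE", score: 0.99}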
data/lib/transformers/tokenization_utils_base.rb
@@ -0,0 +1,937 @@
+ # Copyright 2020 The HuggingFace Inc. team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ module Transformers
+ VERY_LARGE_INTEGER = 1e30.to_i # This is used to set the max input length for a model with infinite size input
+ LARGE_INTEGER = 1e20.to_i # This is used when we need something big but slightly smaller than VERY_LARGE_INTEGER
+
+ # Slow tokenizers used to be saved in three separated files
+ SPECIAL_TOKENS_MAP_FILE = "special_tokens_map.json"
+ ADDED_TOKENS_FILE = "added_tokens.json"
+ TOKENIZER_CONFIG_FILE = "tokenizer_config.json"
+
+ # Fast tokenizers (provided by HuggingFace tokenizer's library) can be saved in a single file
+ FULL_TOKENIZER_FILE = "tokenizer.json"
+
+ class TruncationStrategy < ExplicitEnum
+ ONLY_FIRST = "only_first"
+ ONLY_SECOND = "only_second"
+ LONGEST_FIRST = "longest_first"
+ DO_NOT_TRUNCATE = "do_not_truncate"
+ end
+
+ class BatchEncoding
+ def initialize(
+ data: nil,
+ encoding: nil,
+ tensor_type: nil,
+ prepend_batch_axis: false,
+ n_sequences: nil
+ )
+ @data = data
+
+ @encodings = encoding
+
+ convert_to_tensors(tensor_type: tensor_type, prepend_batch_axis: prepend_batch_axis)
+ end
+
+ def convert_to_tensors(tensor_type: nil, prepend_batch_axis: false)
+ if tensor_type.nil?
+ return self
+ end
+
+ if !tensor_type.is_a?(TensorType)
+ tensor_type = TensorType.new(tensor_type)
+ end
+
+ is_tensor = Torch.method(:tensor?)
+
+ as_tensor = lambda do |value, dtype: nil|
+ if value.is_a?(Array) && value[0].is_a?(Numo::NArray)
+ return Torch.tensor(Numo::NArray.cast(value))
+ end
+ Torch.tensor(value)
+ end
+
+ items.each do |key, value|
+ if prepend_batch_axis
+ value = [value]
+ end
+
+ if !is_tensor.(value)
+ tensor = as_tensor.(value)
+ @data[key] = tensor
+ end
+ end
+ end
+
+ def [](item)
+ if item.is_a?(String)
+ @data[item]
+ elsif item.is_a?(Symbol)
+ @data[item.to_s]
+ elsif !@encodings.nil?
+ @encodings[item]
+ elsif item.is_a?(Range)
+ @data.keys.to_h { |key| [key, @data[key][item]] }
+ else
+ raise KeyError, "Invalid key. Only three types of key are available: (1) string, (2) integers for backend Encoding, and (3) ranges for data subsetting."
+ end
+ end
+
+ def include?(item)
+ @data.include?(item.to_s)
+ end
+
+ def delete(item)
+ @data.delete(item.to_s)
+ end
+
+ def items
+ @data
+ end
+
+ def encodings
+ @encodings
+ end
+
+ def sequence_ids(batch_index = 0)
+ if !@encodings
+ raise ArgumentError,
+ "sequence_ids() is not available when using non-fast tokenizers (e.g. instance of a `XxxTokenizerFast`" +
+ " class)."
+ end
+ @encodings[batch_index].sequence_ids
+ end
+
+ def to_h
+ @data.transform_keys(&:to_sym)
+ end
+ alias_method :to_hash, :to_h
+ end
+
+ module SpecialTokensMixin
+ SPECIAL_TOKENS_ATTRIBUTES = [
+ :bos_token,
+ :eos_token,
+ :unk_token,
+ :sep_token,
+ :pad_token,
+ :cls_token,
+ :mask_token,
+ :additional_special_tokens
+ ]
+ attr_reader(*SPECIAL_TOKENS_ATTRIBUTES)
+
+ def initialize(**kwargs)
+ SPECIAL_TOKENS_ATTRIBUTES.each do |k|
+ instance_variable_set("@#{k}", kwargs[k])
+ end
+ end
+
+ def bos_token_id
+ if @bos_token.nil?
+ return nil
+ end
+ convert_tokens_to_ids(@bos_token)
+ end
+
+ def eos_token_id
+ if @eos_token.nil?
+ return nil
+ end
+ convert_tokens_to_ids(@eos_token)
+ end
+
+ def unk_token_id
+ if @unk_token.nil?
+ return nil
+ end
+ convert_tokens_to_ids(@unk_token)
+ end
+
+ def sep_token_id
+ if @sep_token.nil?
+ return nil
+ end
+ convert_tokens_to_ids(@sep_token)
+ end
+
+ def pad_token_id
+ if @pad_token.nil?
+ return nil
+ end
+ convert_tokens_to_ids(@pad_token)
+ end
+
+ def cls_token_id
+ if @cls_token.nil?
+ return nil
+ end
+ convert_tokens_to_ids(@cls_token)
+ end
+
+ def special_tokens_map
+ set_attr = {}
+ SPECIAL_TOKENS_ATTRIBUTES.each do |attr|
+ attr_value = send(attr)
+ if attr_value
+ set_attr[attr] = attr_value
+ end
+ end
+ set_attr
+ end
+ end
+
+ class PreTrainedTokenizerBase
+ include SpecialTokensMixin
+ extend ClassAttribute
+
+ class_attribute :vocab_files_names, {}
+
+ class_attribute :model_input_names, ["input_ids", "token_type_ids", "attention_mask"]
+ class_attribute :padding_side, "right"
+ class_attribute :truncation_side, "right"
+ class_attribute :slow_tokenizer_class
+
+ attr_reader :init_kwargs, :model_max_length
+
+ def initialize(**kwargs)
+ # inputs and kwargs for saving and re-loading (see ``from_pretrained`` and ``save_pretrained``)
+ @init_inputs = []
+ @init_kwargs = kwargs.dup # copy.deepcopy(kwargs)
+ @name_or_path = kwargs.delete(:name_or_path) { "" }
+ @processor_class = kwargs.delete(:processor_class)
+
+ # For backward compatibility we fallback to set model_max_length from max_len if provided
+ model_max_length = kwargs.delete(:model_max_length) { kwargs.delete(:max_len) }
+ @model_max_length = !model_max_length.nil? ? model_max_length : VERY_LARGE_INTEGER
+
+ # Padding and truncation side are right by default and overridden in subclasses. If specified in the kwargs, it
+ # is changed.
+ @padding_side = kwargs.delete(:padding_side) { self.class.padding_side }
+ if !["right", "left"].include?(@padding_side)
+ raise ArgumentError, "Padding side should be selected between 'right' and 'left', current value: #{@padding_side}"
+ end
+
+ @truncation_side = kwargs.delete(:truncation_side) { self.class.truncation_side }
+ if !["right", "left"].include?(@truncation_side)
+ raise ArgumentError, "Truncation side should be selected between 'right' and 'left', current value: #{@truncation_side}"
+ end
+
+ @model_input_names = kwargs.delete(:model_input_names) { self.class.model_input_names }
+
+ # By default, cleaning tokenization spaces for both fast and slow tokenizers
+ @clean_up_tokenization_spaces = kwargs.delete(:clean_up_tokenization_spaces) { true }
+
+ # By default, do not split special tokens for both fast and slow tokenizers
+ @split_special_tokens = kwargs.delete(:split_special_tokens) { false }
+
+ @deprecation_warnings = {}
+ @in_target_context_manager = false
+
+ # Stores a Jinja template that formats chat histories into tokenizable strings
+ @chat_template = kwargs.delete(:chat_template)
+ if @chat_template.is_a?(Array)
+ # Chat templates are stored as lists of dicts with fixed key names,
+ # we reconstruct that into a single dict while loading them.
+ @chat_template = @chat_template.to_h { |template| [template["name"], template["template"]] }
+ end
+
+ super
+ end
+
+ def _eventual_warn_about_too_long_sequence(ids, max_length, verbose)
+ if max_length.nil? && ids.length > @model_max_length && verbose
+ raise Todo
+ end
+ end
+
+ def call(
+ text,
+ text_pair: nil,
+ text_target: nil,
+ text_pair_target: nil,
+ add_special_tokens: true,
+ padding: false,
+ truncation: nil,
+ max_length: nil,
+ stride: 0,
+ is_split_into_words: false,
+ pad_to_multiple_of: nil,
+ return_tensors: nil,
+ return_token_type_ids: nil,
+ return_attention_mask: nil,
+ return_overflowing_tokens: false,
+ return_special_tokens_mask: false,
+ return_offsets_mapping: false,
+ return_length: false,
+ verbose: true,
+ **kwargs
+ )
+ # To avoid duplicating
+ all_kwargs = {
+ add_special_tokens: add_special_tokens,
+ padding: padding,
+ truncation: truncation,
+ max_length: max_length,
+ stride: stride,
+ is_split_into_words: is_split_into_words,
+ pad_to_multiple_of: pad_to_multiple_of,
+ return_tensors: return_tensors,
+ return_token_type_ids: return_token_type_ids,
+ return_attention_mask: return_attention_mask,
+ return_overflowing_tokens: return_overflowing_tokens,
+ return_special_tokens_mask: return_special_tokens_mask,
+ return_offsets_mapping: return_offsets_mapping,
+ return_length: return_length,
+ verbose: verbose
+ }
+ all_kwargs.merge!(kwargs)
+ if text.nil? && text_target.nil?
+ raise ArgumentError, "You need to specify either `text` or `text_target`."
+ end
+ if !text.nil?
+ # The context manager will send the inputs as normal texts and not text_target, but we shouldn't change the
+ # input mode in this case.
+ if !@in_target_context_manager
+ _switch_to_input_mode
+ end
+ encodings = _call_one(text: text, text_pair: text_pair, **all_kwargs)
+ end
+ if !text_target.nil?
+ _switch_to_target_mode
+ target_encodings = _call_one(text: text_target, text_pair: text_pair_target, **all_kwargs)
+ end
+ # Leave back tokenizer in input mode
+ _switch_to_input_mode
+
+ if text_target.nil?
+ encodings
+ elsif text.nil?
+ target_encodings
+ else
+ encodings["labels"] = target_encodings["input_ids"]
+ encodings
+ end
+ end
+
+ protected
+
+ def _switch_to_input_mode
+ end
+
+ def _switch_to_target_mode
+ end
+
+ private
+
+ def _call_one(
+ text:,
+ text_pair: nil,
+ add_special_tokens: true,
+ padding: false,
+ truncation: nil,
+ max_length: nil,
+ stride: 0,
+ is_split_into_words: false,
+ pad_to_multiple_of: nil,
+ return_tensors: nil,
+ return_token_type_ids: nil,
+ return_attention_mask: nil,
+ return_overflowing_tokens: false,
+ return_special_tokens_mask: false,
+ return_offsets_mapping: false,
+ return_length: false,
+ verbose: true,
+ **kwargs
+ )
+ # Input type checking for clearer error
+ _is_valid_text_input = lambda do |t|
+ if t.is_a?(String)
+ # Strings are fine
+ true
+ elsif t.is_a?(Array)
+ # List are fine as long as they are...
+ if t.length == 0
+ # ... empty
+ true
+ elsif t[0].is_a?(String)
+ # ... list of strings
+ true
+ elsif t[0].is_a?(Array)
+ # ... list with an empty list or with a list of strings
+ t[0].length == 0 || t[0][0].is_a?(String)
+ else
+ false
+ end
+ else
+ false
+ end
+ end
+
+ if !_is_valid_text_input.(text)
+ raise ArgumentError, "text input must be of type `str` (single example), `List[str]` (batch or single pretokenized example) or `List[List[str]]` (batch of pretokenized examples)."
+ end
+
+ if !text_pair.nil? && !_is_valid_text_input.(text_pair)
+ raise ArgumentError, "text input must be of type `str` (single example), `List[str]` (batch or single pretokenized example) or `List[List[str]]` (batch of pretokenized examples)."
+ end
+
+ if is_split_into_words
+ is_batched = text.is_a?(Array) && text[0].is_a?(Array)
+ else
+ is_batched = text.is_a?(Array)
+ end
+
+ if is_batched
+ if text_pair.is_a?(String)
+ raise TypeError, "when tokenizing batches of text, `text_pair` must be a list or tuple with the same length as `text`."
+ end
+ if !text_pair.nil? && text.length != text_pair.length
+ raise ArgumentError, "batch length of `text`: #{text.length} does not match batch length of `text_pair`: #{text_pair.length}."
+ end
+ batch_text_or_text_pairs = !text_pair.nil? ? text.zip(text_pair).to_a : text
+ batch_encode_plus(
+ batch_text_or_text_pairs: batch_text_or_text_pairs,
+ add_special_tokens: add_special_tokens,
+ padding: padding,
+ truncation: truncation,
+ max_length: max_length,
+ stride: stride,
+ is_split_into_words: is_split_into_words,
+ pad_to_multiple_of: pad_to_multiple_of,
+ return_tensors: return_tensors,
+ return_token_type_ids: return_token_type_ids,
+ return_attention_mask: return_attention_mask,
+ return_overflowing_tokens: return_overflowing_tokens,
+ return_special_tokens_mask: return_special_tokens_mask,
+ return_offsets_mapping: return_offsets_mapping,
+ return_length: return_length,
+ verbose: verbose,
+ **kwargs
+ )
+ else
+ encode_plus(
+ text: text,
+ text_pair: text_pair,
+ add_special_tokens: add_special_tokens,
+ padding: padding,
+ truncation: truncation,
+ max_length: max_length,
+ stride: stride,
+ is_split_into_words: is_split_into_words,
+ pad_to_multiple_of: pad_to_multiple_of,
+ return_tensors: return_tensors,
+ return_token_type_ids: return_token_type_ids,
+ return_attention_mask: return_attention_mask,
+ return_overflowing_tokens: return_overflowing_tokens,
+ return_special_tokens_mask: return_special_tokens_mask,
+ return_offsets_mapping: return_offsets_mapping,
+ return_length: return_length,
+ verbose: verbose,
+ **kwargs
+ )
+ end
+ end
+
+ def encode_plus(
+ text:,
+ text_pair: nil,
+ add_special_tokens: true,
+ padding: false,
+ truncation: nil,
+ max_length: nil,
+ stride: 0,
+ is_split_into_words: false,
+ pad_to_multiple_of: nil,
+ return_tensors: nil,
+ return_token_type_ids: nil,
+ return_attention_mask: nil,
+ return_overflowing_tokens: false,
+ return_special_tokens_mask: false,
+ return_offsets_mapping: false,
+ return_length: false,
+ verbose: true,
+ **kwargs
+ )
+ # Backward compatibility for 'truncation_strategy', 'pad_to_max_length'
+ padding_strategy, truncation_strategy, max_length, kwargs =
+ _get_padding_truncation_strategies(
+ padding: padding,
+ truncation: truncation,
+ max_length: max_length,
+ pad_to_multiple_of: pad_to_multiple_of,
+ verbose: verbose,
+ **kwargs
+ )
+
+ _encode_plus(
+ text: text,
+ text_pair: text_pair,
+ add_special_tokens: add_special_tokens,
+ padding_strategy: padding_strategy,
+ truncation_strategy: truncation_strategy,
+ max_length: max_length,
+ stride: stride,
+ is_split_into_words: is_split_into_words,
+ pad_to_multiple_of: pad_to_multiple_of,
+ return_tensors: return_tensors,
+ return_token_type_ids: return_token_type_ids,
+ return_attention_mask: return_attention_mask,
+ return_overflowing_tokens: return_overflowing_tokens,
+ return_special_tokens_mask: return_special_tokens_mask,
+ return_offsets_mapping: return_offsets_mapping,
+ return_length: return_length,
+ verbose: verbose,
+ **kwargs
+ )
+ end
+
+ def batch_encode_plus(
+ batch_text_or_text_pairs:,
+ add_special_tokens: true,
+ padding: false,
+ truncation: nil,
+ max_length: nil,
+ stride: 0,
+ is_split_into_words: false,
+ pad_to_multiple_of: nil,
+ return_tensors: nil,
+ return_token_type_ids: nil,
+ return_attention_mask: nil,
+ return_overflowing_tokens: false,
+ return_special_tokens_mask: false,
+ return_offsets_mapping: false,
+ return_length: false,
+ verbose: true,
+ **kwargs
+ )
+ # Backward compatibility for 'truncation_strategy', 'pad_to_max_length'
+ padding_strategy, truncation_strategy, max_length, kwargs =
+ _get_padding_truncation_strategies(
+ padding: padding,
+ truncation: truncation,
+ max_length: max_length,
+ pad_to_multiple_of: pad_to_multiple_of,
+ verbose: verbose,
+ **kwargs
+ )
+
+ _batch_encode_plus(
+ batch_text_or_text_pairs,
+ add_special_tokens: add_special_tokens,
+ padding_strategy: padding_strategy,
+ truncation_strategy: truncation_strategy,
+ max_length: max_length,
+ stride: stride,
+ is_split_into_words: is_split_into_words,
+ pad_to_multiple_of: pad_to_multiple_of,
+ return_tensors: return_tensors,
+ return_token_type_ids: return_token_type_ids,
+ return_attention_mask: return_attention_mask,
+ return_overflowing_tokens: return_overflowing_tokens,
+ return_special_tokens_mask: return_special_tokens_mask,
+ return_offsets_mapping: return_offsets_mapping,
+ return_length: return_length,
+ verbose: verbose,
+ **kwargs
+ )
+ end
+
+ def _get_padding_truncation_strategies(
+ padding: false,
+ truncation: nil,
+ max_length: nil,
+ pad_to_multiple_of: nil,
+ verbose: true,
+ **kwargs
+ )
+ padding_strategy = PaddingStrategy::DO_NOT_PAD
+ truncation_strategy = TruncationStrategy::DO_NOT_TRUNCATE
+
+ old_truncation_strategy = kwargs.delete(:truncation_strategy) || "do_not_truncate"
+ old_pad_to_max_length = kwargs.delete(:pad_to_max_length) || false
+
+ # Backward compatibility for previous behavior, maybe we should deprecate it:
+ # If you only set max_length, it activates truncation for max_length
+ if !max_length.nil? && padding == false && truncation.nil?
+ raise Todo
+ end
+
+ # Get padding strategy
+ if padding == false && old_pad_to_max_length
+ if verbose
+ raise Todo
+ end
+ if max_length.nil?
+ padding_strategy = PaddingStrategy::LONGEST
+ else
+ padding_strategy = PaddingStrategy::MAX_LENGTH
+ end
+ elsif padding != false
+ if padding == true
+ if verbose
+ # raise Todo
+ end
+ padding_strategy = PaddingStrategy::LONGEST # Default to pad to the longest sequence in the batch
+ elsif !padding.is_a?(PaddingStrategy)
+ padding_strategy = PaddingStrategy.new(padding)
+ elsif padding.is_a?(PaddingStrategy)
+ padding_strategy = padding
+ end
+ else
+ padding_strategy = PaddingStrategy::DO_NOT_PAD
+ end
+
+ # Get truncation strategy
+ if truncation.nil? && old_truncation_strategy != "do_not_truncate"
+ if verbose
+ raise Todo
+ end
+ truncation_strategy = TruncationStrategy.new(old_truncation_strategy).to_s
+ elsif truncation != false && !truncation.nil?
+ if truncation == true
+ truncation_strategy = (
+ TruncationStrategy::LONGEST_FIRST
+ ) # Default to truncate the longest sequences in pairs of inputs
+ elsif !truncation.is_a?(TruncationStrategy)
+ truncation_strategy = TruncationStrategy.new(truncation).to_s
+ else
+ truncation_strategy = truncation
+ end
+ else
+ truncation_strategy = TruncationStrategy::DO_NOT_TRUNCATE
+ end
+
+ # Set max length if needed
+ if max_length.nil?
+ if padding_strategy == PaddingStrategy::MAX_LENGTH
+ if @model_max_length > LARGE_INTEGER
+ if verbose
+ raise Todo
+ end
+ padding_strategy = PaddingStrategy::DO_NOT_PAD
+ else
+ max_length = @model_max_length
+ end
+ end
+
+ if truncation_strategy != TruncationStrategy::DO_NOT_TRUNCATE
+ if @model_max_length > LARGE_INTEGER
+ if verbose
+ raise Todo
+ end
+ truncation_strategy = TruncationStrategy::DO_NOT_TRUNCATE
+ else
+ max_length = @model_max_length
+ end
+ end
+ end
+
+ # Test if we have a padding token
+ if padding_strategy != PaddingStrategy::DO_NOT_PAD && (@pad_token.nil? || pad_token_id < 0)
+ raise ArgumentError,
+ "Asking to pad but the tokenizer does not have a padding token. " +
+ "Please select a token to use as `pad_token` `(tokenizer.pad_token = tokenizer.eos_token e.g.)` " +
+ "or add a new pad token via `tokenizer.add_special_tokens({'pad_token': '[PAD]'})`."
+ end
+
+ # Check that we will truncate to a multiple of pad_to_multiple_of if both are provided
+ if (
+ truncation_strategy != TruncationStrategy::DO_NOT_TRUNCATE &&
+ padding_strategy != PaddingStrategy::DO_NOT_PAD &&
+ !pad_to_multiple_of.nil? &&
+ !max_length.nil? &&
+ (max_length % pad_to_multiple_of != 0)
+ )
+ raise ArgumentError,
+ "Truncation and padding are both activated but " +
+ "truncation length (#{max_length}) is not a multiple of pad_to_multiple_of (#{pad_to_multiple_of})."
+ end
+
+ [padding_strategy, truncation_strategy, max_length, kwargs]
+ end
+
+ class << self
+ def from_pretrained(
+ pretrained_model_name_or_path,
+ *init_inputs,
+ cache_dir: nil,
+ force_download: false,
+ local_files_only: false,
+ token: nil,
+ revision: "main",
+ trust_remote_code: false,
+ **kwargs
+ )
+ resume_download = kwargs.delete(:resume_download) { false }
+ proxies = kwargs.delete(:proxies)
+ subfolder = kwargs.delete(:subfolder)
+ from_pipeline = kwargs.delete(:_from_pipeline)
+ from_auto_class = kwargs.delete(:_from_auto) { false }
+ commit_hash = kwargs.delete(:_commit_hash)
+
+ user_agent = {file_type: "tokenizer", from_auto_class: from_auto_class, is_fast: name.include?("Fast")}
+ if !from_pipeline.nil?
+ user_agent[:using_pipeline] = from_pipeline
+ end
+
+ if Utils::Hub.is_offline_mode && !local_files_only
+ Transformers.logger.info("Offline mode: forcing local_files_only: true")
+ local_files_only = true
+ end
+
+ pretrained_model_name_or_path = pretrained_model_name_or_path.to_s
+ vocab_files = {}
+ init_configuration = {}
+
+ is_local = Dir.exist?(pretrained_model_name_or_path)
+ single_file_id = nil
+ if File.exist?(pretrained_model_name_or_path)
+ raise Todo
+ end
+
+ # At this point pretrained_model_name_or_path is either a directory or a model identifier name
+ additional_files_names = {
+ added_tokens_file: ADDED_TOKENS_FILE, # kept only for legacy
+ special_tokens_map_file: SPECIAL_TOKENS_MAP_FILE, # kept only for legacy
+ tokenizer_config_file: TOKENIZER_CONFIG_FILE,
+ # tokenizer_file used to initialize a slow from a fast. Properly copy the `addedTokens` instead of adding in random orders
+ tokenizer_file: FULL_TOKENIZER_FILE
+ }
+ vocab_files = vocab_files_names.merge(additional_files_names)
+ if vocab_files[:tokenizer_file]
+ # Try to get the tokenizer config to see if there are versioned tokenizer files.
+ fast_tokenizer_file = FULL_TOKENIZER_FILE
+ resolved_config_file =
+ Utils::Hub.cached_file(
+ pretrained_model_name_or_path,
+ TOKENIZER_CONFIG_FILE,
+ cache_dir: cache_dir,
+ force_download: force_download,
+ resume_download: resume_download,
+ proxies: proxies,
+ token: token,
+ revision: revision,
+ local_files_only: local_files_only,
+ subfolder: subfolder,
+ user_agent: user_agent,
+ _raise_exceptions_for_gated_repo: false,
+ _raise_exceptions_for_missing_entries: false,
+ _raise_exceptions_for_connection_errors: false,
+ _commit_hash: commit_hash
+ )
+ commit_hash = Utils::Hub.extract_commit_hash(resolved_config_file, commit_hash)
+ if !resolved_config_file.nil?
+ tokenizer_config = JSON.load_file(resolved_config_file)
+ if tokenizer_config["fast_tokenizer_files"]
+ fast_tokenizer_file = get_fast_tokenizer_file(tokenizer_config["fast_tokenizer_files"])
+ end
+ end
+ vocab_files[:tokenizer_file] = fast_tokenizer_file
+ end
+
+ # Get files from url, cache, or disk depending on the case
+ resolved_vocab_files = {}
+ unresolved_files = []
+ vocab_files.each do |file_id, file_path|
+ if file_path.nil?
+ resolved_vocab_files[file_id] = nil
+ elsif single_file_id == file_id
+ if File.exist?(file_path)
+ resolved_vocab_files[file_id] = file_path
+ else
+ raise Todo
+ end
+ else
+ resolved_vocab_files[file_id] =
+ Utils::Hub.cached_file(
+ pretrained_model_name_or_path,
+ file_path,
+ cache_dir: cache_dir,
+ force_download: force_download,
+ proxies: proxies,
+ resume_download: resume_download,
+ local_files_only: local_files_only,
+ token: token,
+ user_agent: user_agent,
+ revision: revision,
+ subfolder: subfolder,
+ _raise_exceptions_for_gated_repo: false,
+ _raise_exceptions_for_missing_entries: false,
+ _raise_exceptions_for_connection_errors: false,
+ _commit_hash: commit_hash,
+ )
+ commit_hash = Utils::Hub.extract_commit_hash(resolved_vocab_files[file_id], commit_hash)
+ end
+ end
+
+ # not used?
+ if unresolved_files.length > 0
+ raise Todo
+ end
+
+ vocab_files.each do |file_id, file_path|
+ if !resolved_vocab_files.include?(file_id)
+ next
+ end
+
+ if is_local
+ Transformers.logger.info("loading file #{file_path}")
+ else
+ Transformers.logger.info("loading file #{file_path} from cache at #{resolved_vocab_files[file_id] || "nil"}")
+ end
+ end
+
+ _from_pretrained(
+ resolved_vocab_files,
+ pretrained_model_name_or_path,
+ init_configuration,
+ *init_inputs,
+ token: token,
+ cache_dir: cache_dir,
+ local_files_only: local_files_only,
+ _commit_hash: commit_hash,
+ _is_local: is_local,
+ trust_remote_code: trust_remote_code,
+ **kwargs
+ )
+ end
+
+ def _from_pretrained(
+ resolved_vocab_files,
+ pretrained_model_name_or_path,
+ init_configuration,
+ *init_inputs,
+ token: nil,
+ cache_dir: nil,
+ local_files_only: false,
+ _commit_hash: nil,
+ _is_local: false,
+ trust_remote_code: false,
+ **kwargs
+ )
+ # We instantiate fast tokenizers based on a slow tokenizer if we don't have access to the tokenizer.json
+ # file or if `from_slow` is set to True.
+ from_slow = kwargs.delete(:from_slow) { false }
+ has_tokenizer_file = !resolved_vocab_files[:tokenizer_file].nil?
+ if (from_slow || !has_tokenizer_file) && !slow_tokenizer_class.nil?
+ slow_tokenizer =
+ slow_tokenizer_class._from_pretrained(
+ Copy.deepcopy(resolved_vocab_files),
+ pretrained_model_name_or_path,
+ Copy.deepcopy(init_configuration),
+ *init_inputs,
+ token: token,
+ cache_dir: cache_dir,
+ local_files_only: local_files_only,
+ _commit_hash: _commit_hash,
+ **Copy.deepcopy(kwargs)
+ )
+ else
+ slow_tokenizer = nil
+ end
+
+ # Prepare tokenizer initialization kwargs
+ # Did we saved some inputs and kwargs to reload ?
+ tokenizer_config_file = resolved_vocab_files.delete(:tokenizer_config_file)
+ if !tokenizer_config_file.nil?
+ init_kwargs = JSON.load_file(tokenizer_config_file).transform_keys(&:to_sym)
+ # First attempt. We get tokenizer_class from tokenizer_config to check mismatch between tokenizers.
+ config_tokenizer_class = init_kwargs[:tokenizer_class]
+ init_kwargs.delete(:tokenizer_class)
+ if !has_tokenizer_file
+ init_kwargs.delete(:tokenizer_file)
+ end
+ saved_init_inputs = init_kwargs.delete(:init_inputs) { [] }
+ if init_inputs.empty?
+ init_inputs = saved_init_inputs
+ end
+ else
+ config_tokenizer_class = nil
+ init_kwargs = init_configuration
+ end
+
+ if config_tokenizer_class.nil?
+ config =
+ AutoConfig.from_pretrained(
+ pretrained_model_name_or_path,
+ token: token,
+ cache_dir: cache_dir,
+ local_files_only: local_files_only,
+ trust_remote_code: trust_remote_code,
+ _commit_hash: _commit_hash,
+ )
+ config_tokenizer_class = config.tokenizer_class
+
+ if config_tokenizer_class.nil?
+ # Third attempt. If we have not yet found the original type of the tokenizer,
+ # we are loading we see if we can infer it from the type of the configuration file
+ if config.class.model_type
+ model_type = config.class.model_type
+ else
+ # Fallback: use pattern matching on the string.
+ model_type = nil
+ TOKENIZER_MAPPING_NAMES.each_key do |pattern|
+ if pretrained_model_name_or_path.to_s.include?(pattern)
+ model_type = pattern
+ break
+ end
+ end
+ end
+
+ if !model_type.nil?
+ config_tokenizer_class, config_tokenizer_class_fast =
+ TOKENIZER_MAPPING_NAMES.fetch(model_type, [nil, nil])
+
+ if config_tokenizer_class.nil?
+ config_tokenizer_class = config_tokenizer_class_fast
+ end
+ end
+ end
+ end
+
+ if !config_tokenizer_class.nil?
+ if name.split("::").last.gsub("Fast", "") != config_tokenizer_class.gsub("Fast", "")
+ raise Todo
+ end
+ end
+
+ # Update with newly provided kwargs
+ init_kwargs.merge!(kwargs)
+
+ # Merge resolved_vocab_files arguments in init_kwargs.
+ _added_tokens_file = resolved_vocab_files.delete(:added_tokens_file)
+ _special_tokens_map_file = resolved_vocab_files.delete(:special_tokens_map_file)
+ resolved_vocab_files.each do |args_name, file_path|
+ if !init_kwargs.include?(args_name)
+ init_kwargs[args_name] = file_path
+ end
+ end
+ _tokenizer_file = resolved_vocab_files.delete(:tokenizer_file)
+
+ if !slow_tokenizer.nil?
+ init_kwargs[:__slow_tokenizer] = slow_tokenizer
+ end
+ init_kwargs[:name_or_path] = pretrained_model_name_or_path
+
+ # Instantiate the tokenizer.
+ tokenizer = new(*init_inputs, **init_kwargs)
+
+ tokenizer
+ end
+ end
+ end
+ end
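
The file expanded above defines PreTrainedTokenizerBase, the shared base class behind the gem's slow and fast tokenizers: from_pretrained resolves vocab and tokenizer config files through Utils::Hub.cached_file, and call dispatches to encode_plus or batch_encode_plus before wrapping the result in a BatchEncoding. A minimal sketch of how that surface is typically driven, assuming a concrete subclass such as Transformers::BertTokenizerFast (shipped in this release) and an illustrative model id:

# Hypothetical usage sketch; the model id and the exact result shapes are assumptions.
tokenizer = Transformers::BertTokenizerFast.from_pretrained("bert-base-uncased")

# `call` (invoked via `.()`) routes single inputs to encode_plus and arrays to
# batch_encode_plus, returning a BatchEncoding whose [] accepts string or symbol keys.
enc = tokenizer.("Hello world", text_pair: "How are you?", padding: true, return_tensors: "pt")
enc[:input_ids]       # token ids (a Torch tensor when return_tensors: "pt" is requested)
enc[:attention_mask]  # attention mask over padded positions
enc.sequence_ids(0)   # per-token sequence ids, available only for fast (backed) encodings

Padding and truncation keywords are normalized by _get_padding_truncation_strategies, so padding: true maps to PaddingStrategy::LONGEST and truncation: true to TruncationStrategy::LONGEST_FIRST.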