transformers-rb 0.1.0

Files changed (65)
  1. checksums.yaml +7 -0
  2. data/CHANGELOG.md +3 -0
  3. data/LICENSE.txt +203 -0
  4. data/README.md +163 -0
  5. data/lib/transformers/activations.rb +57 -0
  6. data/lib/transformers/configuration_utils.rb +285 -0
  7. data/lib/transformers/convert_slow_tokenizer.rb +90 -0
  8. data/lib/transformers/data/processors/squad.rb +115 -0
  9. data/lib/transformers/dynamic_module_utils.rb +25 -0
  10. data/lib/transformers/feature_extraction_utils.rb +110 -0
  11. data/lib/transformers/hf_hub/constants.rb +71 -0
  12. data/lib/transformers/hf_hub/errors.rb +11 -0
  13. data/lib/transformers/hf_hub/file_download.rb +764 -0
  14. data/lib/transformers/hf_hub/utils/_errors.rb +94 -0
  15. data/lib/transformers/hf_hub/utils/_headers.rb +109 -0
  16. data/lib/transformers/image_processing_base.rb +169 -0
  17. data/lib/transformers/image_processing_utils.rb +63 -0
  18. data/lib/transformers/image_transforms.rb +208 -0
  19. data/lib/transformers/image_utils.rb +165 -0
  20. data/lib/transformers/modeling_outputs.rb +81 -0
  21. data/lib/transformers/modeling_utils.rb +888 -0
  22. data/lib/transformers/models/auto/auto_factory.rb +138 -0
  23. data/lib/transformers/models/auto/configuration_auto.rb +61 -0
  24. data/lib/transformers/models/auto/feature_extraction_auto.rb +20 -0
  25. data/lib/transformers/models/auto/image_processing_auto.rb +104 -0
  26. data/lib/transformers/models/auto/modeling_auto.rb +80 -0
  27. data/lib/transformers/models/auto/tokenization_auto.rb +160 -0
  28. data/lib/transformers/models/bert/configuration_bert.rb +65 -0
  29. data/lib/transformers/models/bert/modeling_bert.rb +836 -0
  30. data/lib/transformers/models/bert/tokenization_bert.rb +115 -0
  31. data/lib/transformers/models/bert/tokenization_bert_fast.rb +52 -0
  32. data/lib/transformers/models/distilbert/configuration_distilbert.rb +63 -0
  33. data/lib/transformers/models/distilbert/modeling_distilbert.rb +616 -0
  34. data/lib/transformers/models/distilbert/tokenization_distilbert.rb +114 -0
  35. data/lib/transformers/models/distilbert/tokenization_distilbert_fast.rb +71 -0
  36. data/lib/transformers/models/vit/configuration_vit.rb +60 -0
  37. data/lib/transformers/models/vit/image_processing_vit.rb +170 -0
  38. data/lib/transformers/models/vit/modeling_vit.rb +506 -0
  39. data/lib/transformers/pipelines/_init.rb +348 -0
  40. data/lib/transformers/pipelines/base.rb +301 -0
  41. data/lib/transformers/pipelines/feature_extraction.rb +47 -0
  42. data/lib/transformers/pipelines/image_classification.rb +110 -0
  43. data/lib/transformers/pipelines/image_feature_extraction.rb +56 -0
  44. data/lib/transformers/pipelines/pt_utils.rb +53 -0
  45. data/lib/transformers/pipelines/question_answering.rb +508 -0
  46. data/lib/transformers/pipelines/text_classification.rb +123 -0
  47. data/lib/transformers/pipelines/token_classification.rb +282 -0
  48. data/lib/transformers/ruby_utils.rb +33 -0
  49. data/lib/transformers/sentence_transformer.rb +37 -0
  50. data/lib/transformers/tokenization_utils.rb +152 -0
  51. data/lib/transformers/tokenization_utils_base.rb +937 -0
  52. data/lib/transformers/tokenization_utils_fast.rb +386 -0
  53. data/lib/transformers/torch_utils.rb +25 -0
  54. data/lib/transformers/utils/_init.rb +31 -0
  55. data/lib/transformers/utils/generic.rb +107 -0
  56. data/lib/transformers/utils/hub.rb +209 -0
  57. data/lib/transformers/utils/import_utils.rb +45 -0
  58. data/lib/transformers/utils/logging.rb +52 -0
  59. data/lib/transformers/version.rb +3 -0
  60. data/lib/transformers-rb.rb +1 -0
  61. data/lib/transformers.rb +100 -0
  62. data/licenses/LICENSE-huggingface-hub.txt +201 -0
  63. data/licenses/LICENSE-sentence-transformers.txt +201 -0
  64. data/licenses/NOTICE-sentence-transformers.txt +5 -0
  65. metadata +161 -0
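
The bulk of the release is a port of the Python tokenizer and pipeline stack; the largest new file, data/lib/transformers/tokenization_utils_base.rb (+937), is shown in full below. For orientation, typical top-level usage looks roughly like the following sketch. It assumes the gem's entry point mirrors the Python library (a Transformers.pipeline helper backed by the pipelines/ files above); the task name and input are illustrative, not taken from this diff.

  require "transformers-rb"

  # Hypothetical smoke test: build a sentiment pipeline and run one input.
  # The default model for the task is fetched from the Hugging Face Hub on first use.
  classifier = Transformers.pipeline("sentiment-analysis")
  p classifier.("This gem works!") # result shape (label/score) assumed to match the Python library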
data/lib/transformers/tokenization_utils_base.rb @@ -0,0 +1,937 @@
# Copyright 2020 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

module Transformers
  VERY_LARGE_INTEGER = 1e30.to_i # This is used to set the max input length for a model with infinite size input
  LARGE_INTEGER = 1e20.to_i # This is used when we need something big but slightly smaller than VERY_LARGE_INTEGER

  # Slow tokenizers used to be saved in three separated files
  SPECIAL_TOKENS_MAP_FILE = "special_tokens_map.json"
  ADDED_TOKENS_FILE = "added_tokens.json"
  TOKENIZER_CONFIG_FILE = "tokenizer_config.json"

  # Fast tokenizers (provided by HuggingFace tokenizer's library) can be saved in a single file
  FULL_TOKENIZER_FILE = "tokenizer.json"

  class TruncationStrategy < ExplicitEnum
    ONLY_FIRST = "only_first"
    ONLY_SECOND = "only_second"
    LONGEST_FIRST = "longest_first"
    DO_NOT_TRUNCATE = "do_not_truncate"
  end

  class BatchEncoding
    def initialize(
      data: nil,
      encoding: nil,
      tensor_type: nil,
      prepend_batch_axis: false,
      n_sequences: nil
    )
      @data = data

      @encodings = encoding

      convert_to_tensors(tensor_type: tensor_type, prepend_batch_axis: prepend_batch_axis)
    end

    def convert_to_tensors(tensor_type: nil, prepend_batch_axis: false)
      if tensor_type.nil?
        return self
      end

      if !tensor_type.is_a?(TensorType)
        tensor_type = TensorType.new(tensor_type)
      end

      is_tensor = Torch.method(:tensor?)

      as_tensor = lambda do |value, dtype: nil|
        if value.is_a?(Array) && value[0].is_a?(Numo::NArray)
          return Torch.tensor(Numo::NArray.cast(value))
        end
        Torch.tensor(value)
      end

      items.each do |key, value|
        if prepend_batch_axis
          value = [value]
        end

        if !is_tensor.(value)
          tensor = as_tensor.(value)
          @data[key] = tensor
        end
      end
    end

    def [](item)
      if item.is_a?(String)
        @data[item]
      elsif item.is_a?(Symbol)
        @data[item.to_s]
      elsif !@encodings.nil?
        @encodings[item]
      elsif item.is_a?(Range)
        @data.keys.to_h { |key| [key, @data[key][item]] }
      else
        raise KeyError, "Invalid key. Only three types of key are available: (1) string, (2) integers for backend Encoding, and (3) ranges for data subsetting."
      end
    end

    def include?(item)
      @data.include?(item.to_s)
    end

    def delete(item)
      @data.delete(item.to_s)
    end

    def items
      @data
    end

    def encodings
      @encodings
    end

    def sequence_ids(batch_index = 0)
      if !@encodings
        raise ArgumentError,
          "sequence_ids() is not available when using non-fast tokenizers (e.g. instance of a `XxxTokenizerFast`" +
          " class)."
      end
      @encodings[batch_index].sequence_ids
    end

    def to_h
      @data.transform_keys(&:to_sym)
    end
    alias_method :to_hash, :to_h
  end

  module SpecialTokensMixin
    SPECIAL_TOKENS_ATTRIBUTES = [
      :bos_token,
      :eos_token,
      :unk_token,
      :sep_token,
      :pad_token,
      :cls_token,
      :mask_token,
      :additional_special_tokens
    ]
    attr_reader(*SPECIAL_TOKENS_ATTRIBUTES)

    def initialize(**kwargs)
      SPECIAL_TOKENS_ATTRIBUTES.each do |k|
        instance_variable_set("@#{k}", kwargs[k])
      end
    end

    def bos_token_id
      if @bos_token.nil?
        return nil
      end
      convert_tokens_to_ids(@bos_token)
    end

    def eos_token_id
      if @eos_token.nil?
        return nil
      end
      convert_tokens_to_ids(@eos_token)
    end

    def unk_token_id
      if @unk_token.nil?
        return nil
      end
      convert_tokens_to_ids(@unk_token)
    end

    def sep_token_id
      if @sep_token.nil?
        return nil
      end
      convert_tokens_to_ids(@sep_token)
    end

    def pad_token_id
      if @pad_token.nil?
        return nil
      end
      convert_tokens_to_ids(@pad_token)
    end

    def cls_token_id
      if @cls_token.nil?
        return nil
      end
      convert_tokens_to_ids(@cls_token)
    end

    def special_tokens_map
      set_attr = {}
      SPECIAL_TOKENS_ATTRIBUTES.each do |attr|
        attr_value = send(attr)
        if attr_value
          set_attr[attr] = attr_value
        end
      end
      set_attr
    end
  end

  class PreTrainedTokenizerBase
    include SpecialTokensMixin
    extend ClassAttribute

    class_attribute :vocab_files_names, {}

    class_attribute :model_input_names, ["input_ids", "token_type_ids", "attention_mask"]
    class_attribute :padding_side, "right"
    class_attribute :truncation_side, "right"
    class_attribute :slow_tokenizer_class

    attr_reader :init_kwargs, :model_max_length

    def initialize(**kwargs)
      # inputs and kwargs for saving and re-loading (see ``from_pretrained`` and ``save_pretrained``)
      @init_inputs = []
      @init_kwargs = kwargs.dup # copy.deepcopy(kwargs)
      @name_or_path = kwargs.delete(:name_or_path) { "" }
      @processor_class = kwargs.delete(:processor_class)

      # For backward compatibility we fallback to set model_max_length from max_len if provided
      model_max_length = kwargs.delete(:model_max_length) { kwargs.delete(:max_len) }
      @model_max_length = !model_max_length.nil? ? model_max_length : VERY_LARGE_INTEGER

      # Padding and truncation side are right by default and overridden in subclasses. If specified in the kwargs, it
      # is changed.
      @padding_side = kwargs.delete(:padding_side) { self.class.padding_side }
      if !["right", "left"].include?(@padding_side)
        raise ArgumentError, "Padding side should be selected between 'right' and 'left', current value: #{@padding_side}"
      end

      @truncation_side = kwargs.delete(:truncation_side) { self.class.truncation_side }
      if !["right", "left"].include?(@truncation_side)
        raise ArgumentError, "Truncation side should be selected between 'right' and 'left', current value: #{@truncation_side}"
      end

      @model_input_names = kwargs.delete(:model_input_names) { self.class.model_input_names }

      # By default, cleaning tokenization spaces for both fast and slow tokenizers
      @clean_up_tokenization_spaces = kwargs.delete(:clean_up_tokenization_spaces) { true }

      # By default, do not split special tokens for both fast and slow tokenizers
      @split_special_tokens = kwargs.delete(:split_special_tokens) { false }

      @deprecation_warnings = {}
      @in_target_context_manager = false

      # Stores a Jinja template that formats chat histories into tokenizable strings
      @chat_template = kwargs.delete(:chat_template)
      if @chat_template.is_a?(Array)
        # Chat templates are stored as lists of dicts with fixed key names,
        # we reconstruct that into a single dict while loading them.
        @chat_template = @chat_template.to_h { |template| [template["name"], template["template"]] }
      end

      super
    end

    def _eventual_warn_about_too_long_sequence(ids, max_length, verbose)
      if max_length.nil? && ids.length > @model_max_length && verbose
        raise Todo
      end
    end

    def call(
      text,
      text_pair: nil,
      text_target: nil,
      text_pair_target: nil,
      add_special_tokens: true,
      padding: false,
      truncation: nil,
      max_length: nil,
      stride: 0,
      is_split_into_words: false,
      pad_to_multiple_of: nil,
      return_tensors: nil,
      return_token_type_ids: nil,
      return_attention_mask: nil,
      return_overflowing_tokens: false,
      return_special_tokens_mask: false,
      return_offsets_mapping: false,
      return_length: false,
      verbose: true,
      **kwargs
    )
      # To avoid duplicating
      all_kwargs = {
        add_special_tokens: add_special_tokens,
        padding: padding,
        truncation: truncation,
        max_length: max_length,
        stride: stride,
        is_split_into_words: is_split_into_words,
        pad_to_multiple_of: pad_to_multiple_of,
        return_tensors: return_tensors,
        return_token_type_ids: return_token_type_ids,
        return_attention_mask: return_attention_mask,
        return_overflowing_tokens: return_overflowing_tokens,
        return_special_tokens_mask: return_special_tokens_mask,
        return_offsets_mapping: return_offsets_mapping,
        return_length: return_length,
        verbose: verbose
      }
      all_kwargs.merge!(kwargs)
      if text.nil? && text_target.nil?
        raise ArgumentError, "You need to specify either `text` or `text_target`."
      end
      if !text.nil?
        # The context manager will send the inputs as normal texts and not text_target, but we shouldn't change the
        # input mode in this case.
        if !@in_target_context_manager
          _switch_to_input_mode
        end
        encodings = _call_one(text: text, text_pair: text_pair, **all_kwargs)
      end
      if !text_target.nil?
        _switch_to_target_mode
        target_encodings = _call_one(text: text_target, text_pair: text_pair_target, **all_kwargs)
      end
      # Leave back tokenizer in input mode
      _switch_to_input_mode

      if text_target.nil?
        encodings
      elsif text.nil?
        target_encodings
      else
        encodings["labels"] = target_encodings["input_ids"]
        encodings
      end
    end

    protected

    def _switch_to_input_mode
    end

    def _switch_to_target_mode
    end

    private

    def _call_one(
      text:,
      text_pair: nil,
      add_special_tokens: true,
      padding: false,
      truncation: nil,
      max_length: nil,
      stride: 0,
      is_split_into_words: false,
      pad_to_multiple_of: nil,
      return_tensors: nil,
      return_token_type_ids: nil,
      return_attention_mask: nil,
      return_overflowing_tokens: false,
      return_special_tokens_mask: false,
      return_offsets_mapping: false,
      return_length: false,
      verbose: true,
      **kwargs
    )
      # Input type checking for clearer error
      _is_valid_text_input = lambda do |t|
        if t.is_a?(String)
          # Strings are fine
          true
        elsif t.is_a?(Array)
          # List are fine as long as they are...
          if t.length == 0
            # ... empty
            true
          elsif t[0].is_a?(String)
            # ... list of strings
            true
          elsif t[0].is_a?(Array)
            # ... list with an empty list or with a list of strings
            t[0].length == 0 || t[0][0].is_a?(String)
          else
            false
          end
        else
          false
        end
      end

      if !_is_valid_text_input.(text)
        raise ArgumentError, "text input must be of type `str` (single example), `List[str]` (batch or single pretokenized example) or `List[List[str]]` (batch of pretokenized examples)."
      end

      if !text_pair.nil? && !_is_valid_text_input.(text_pair)
        raise ArgumentError, "text input must be of type `str` (single example), `List[str]` (batch or single pretokenized example) or `List[List[str]]` (batch of pretokenized examples)."
      end

      if is_split_into_words
        is_batched = text.is_a?(Array) && text[0].is_a?(Array)
      else
        is_batched = text.is_a?(Array)
      end

      if is_batched
        if text_pair.is_a?(String)
          raise TypeError, "when tokenizing batches of text, `text_pair` must be a list or tuple with the same length as `text`."
        end
        if !text_pair.nil? && text.length != text_pair.length
          raise ArgumentError, "batch length of `text`: #{text.length} does not match batch length of `text_pair`: #{text_pair.length}."
        end
        batch_text_or_text_pairs = !text_pair.nil? ? text.zip(text_pair).to_a : text
        batch_encode_plus(
          batch_text_or_text_pairs: batch_text_or_text_pairs,
          add_special_tokens: add_special_tokens,
          padding: padding,
          truncation: truncation,
          max_length: max_length,
          stride: stride,
          is_split_into_words: is_split_into_words,
          pad_to_multiple_of: pad_to_multiple_of,
          return_tensors: return_tensors,
          return_token_type_ids: return_token_type_ids,
          return_attention_mask: return_attention_mask,
          return_overflowing_tokens: return_overflowing_tokens,
          return_special_tokens_mask: return_special_tokens_mask,
          return_offsets_mapping: return_offsets_mapping,
          return_length: return_length,
          verbose: verbose,
          **kwargs
        )
      else
        encode_plus(
          text: text,
          text_pair: text_pair,
          add_special_tokens: add_special_tokens,
          padding: padding,
          truncation: truncation,
          max_length: max_length,
          stride: stride,
          is_split_into_words: is_split_into_words,
          pad_to_multiple_of: pad_to_multiple_of,
          return_tensors: return_tensors,
          return_token_type_ids: return_token_type_ids,
          return_attention_mask: return_attention_mask,
          return_overflowing_tokens: return_overflowing_tokens,
          return_special_tokens_mask: return_special_tokens_mask,
          return_offsets_mapping: return_offsets_mapping,
          return_length: return_length,
          verbose: verbose,
          **kwargs
        )
      end
    end

    def encode_plus(
      text:,
      text_pair: nil,
      add_special_tokens: true,
      padding: false,
      truncation: nil,
      max_length: nil,
      stride: 0,
      is_split_into_words: false,
      pad_to_multiple_of: nil,
      return_tensors: nil,
      return_token_type_ids: nil,
      return_attention_mask: nil,
      return_overflowing_tokens: false,
      return_special_tokens_mask: false,
      return_offsets_mapping: false,
      return_length: false,
      verbose: true,
      **kwargs
    )
      # Backward compatibility for 'truncation_strategy', 'pad_to_max_length'
      padding_strategy, truncation_strategy, max_length, kwargs =
        _get_padding_truncation_strategies(
          padding: padding,
          truncation: truncation,
          max_length: max_length,
          pad_to_multiple_of: pad_to_multiple_of,
          verbose: verbose,
          **kwargs
        )

      _encode_plus(
        text: text,
        text_pair: text_pair,
        add_special_tokens: add_special_tokens,
        padding_strategy: padding_strategy,
        truncation_strategy: truncation_strategy,
        max_length: max_length,
        stride: stride,
        is_split_into_words: is_split_into_words,
        pad_to_multiple_of: pad_to_multiple_of,
        return_tensors: return_tensors,
        return_token_type_ids: return_token_type_ids,
        return_attention_mask: return_attention_mask,
        return_overflowing_tokens: return_overflowing_tokens,
        return_special_tokens_mask: return_special_tokens_mask,
        return_offsets_mapping: return_offsets_mapping,
        return_length: return_length,
        verbose: verbose,
        **kwargs
      )
    end

    def batch_encode_plus(
      batch_text_or_text_pairs:,
      add_special_tokens: true,
      padding: false,
      truncation: nil,
      max_length: nil,
      stride: 0,
      is_split_into_words: false,
      pad_to_multiple_of: nil,
      return_tensors: nil,
      return_token_type_ids: nil,
      return_attention_mask: nil,
      return_overflowing_tokens: false,
      return_special_tokens_mask: false,
      return_offsets_mapping: false,
      return_length: false,
      verbose: true,
      **kwargs
    )
      # Backward compatibility for 'truncation_strategy', 'pad_to_max_length'
      padding_strategy, truncation_strategy, max_length, kwargs =
        _get_padding_truncation_strategies(
          padding: padding,
          truncation: truncation,
          max_length: max_length,
          pad_to_multiple_of: pad_to_multiple_of,
          verbose: verbose,
          **kwargs
        )

      _batch_encode_plus(
        batch_text_or_text_pairs,
        add_special_tokens: add_special_tokens,
        padding_strategy: padding_strategy,
        truncation_strategy: truncation_strategy,
        max_length: max_length,
        stride: stride,
        is_split_into_words: is_split_into_words,
        pad_to_multiple_of: pad_to_multiple_of,
        return_tensors: return_tensors,
        return_token_type_ids: return_token_type_ids,
        return_attention_mask: return_attention_mask,
        return_overflowing_tokens: return_overflowing_tokens,
        return_special_tokens_mask: return_special_tokens_mask,
        return_offsets_mapping: return_offsets_mapping,
        return_length: return_length,
        verbose: verbose,
        **kwargs
      )
    end

    def _get_padding_truncation_strategies(
      padding: false,
      truncation: nil,
      max_length: nil,
      pad_to_multiple_of: nil,
      verbose: true,
      **kwargs
    )
      padding_strategy = PaddingStrategy::DO_NOT_PAD
      truncation_strategy = TruncationStrategy::DO_NOT_TRUNCATE

      old_truncation_strategy = kwargs.delete(:truncation_strategy) || "do_not_truncate"
      old_pad_to_max_length = kwargs.delete(:pad_to_max_length) || false

      # Backward compatibility for previous behavior, maybe we should deprecate it:
      # If you only set max_length, it activates truncation for max_length
      if !max_length.nil? && padding == false && truncation.nil?
        raise Todo
      end

      # Get padding strategy
      if padding == false && old_pad_to_max_length
        if verbose
          raise Todo
        end
        if max_length.nil?
          padding_strategy = PaddingStrategy::LONGEST
        else
          padding_strategy = PaddingStrategy::MAX_LENGTH
        end
      elsif padding != false
        if padding == true
          if verbose
            # raise Todo
          end
          padding_strategy = PaddingStrategy::LONGEST # Default to pad to the longest sequence in the batch
        elsif !padding.is_a?(PaddingStrategy)
          padding_strategy = PaddingStrategy.new(padding)
        elsif padding.is_a?(PaddingStrategy)
          padding_strategy = padding
        end
      else
        padding_strategy = PaddingStrategy::DO_NOT_PAD
      end

      # Get truncation strategy
      if truncation.nil? && old_truncation_strategy != "do_not_truncate"
        if verbose
          raise Todo
        end
        truncation_strategy = TruncationStrategy.new(old_truncation_strategy).to_s
      elsif truncation != false && !truncation.nil?
        if truncation == true
          truncation_strategy = (
            TruncationStrategy::LONGEST_FIRST
          ) # Default to truncate the longest sequences in pairs of inputs
        elsif !truncation.is_a?(TruncationStrategy)
          truncation_strategy = TruncationStrategy.new(truncation).to_s
        else
          truncation_strategy = truncation
        end
      else
        truncation_strategy = TruncationStrategy::DO_NOT_TRUNCATE
      end

      # Set max length if needed
      if max_length.nil?
        if padding_strategy == PaddingStrategy::MAX_LENGTH
          if @model_max_length > LARGE_INTEGER
            if verbose
              raise Todo
            end
            padding_strategy = PaddingStrategy::DO_NOT_PAD
          else
            max_length = @model_max_length
          end
        end

        if truncation_strategy != TruncationStrategy::DO_NOT_TRUNCATE
          if @model_max_length > LARGE_INTEGER
            if verbose
              raise Todo
            end
            truncation_strategy = TruncationStrategy::DO_NOT_TRUNCATE
          else
            max_length = @model_max_length
          end
        end
      end

      # Test if we have a padding token
      if padding_strategy != PaddingStrategy::DO_NOT_PAD && (@pad_token.nil? || pad_token_id < 0)
        raise ArgumentError,
          "Asking to pad but the tokenizer does not have a padding token. " +
          "Please select a token to use as `pad_token` `(tokenizer.pad_token = tokenizer.eos_token e.g.)` " +
          "or add a new pad token via `tokenizer.add_special_tokens({'pad_token': '[PAD]'})`."
      end

      # Check that we will truncate to a multiple of pad_to_multiple_of if both are provided
      if (
        truncation_strategy != TruncationStrategy::DO_NOT_TRUNCATE &&
        padding_strategy != PaddingStrategy::DO_NOT_PAD &&
        !pad_to_multiple_of.nil? &&
        !max_length.nil? &&
        (max_length % pad_to_multiple_of != 0)
      )
        raise ArgumentError,
          "Truncation and padding are both activated but " +
          "truncation length (#{max_length}) is not a multiple of pad_to_multiple_of (#{pad_to_multiple_of})."
      end

      [padding_strategy, truncation_strategy, max_length, kwargs]
    end

    class << self
      def from_pretrained(
        pretrained_model_name_or_path,
        *init_inputs,
        cache_dir: nil,
        force_download: false,
        local_files_only: false,
        token: nil,
        revision: "main",
        trust_remote_code: false,
        **kwargs
      )
        resume_download = kwargs.delete(:resume_download) { false }
        proxies = kwargs.delete(:proxies)
        subfolder = kwargs.delete(:subfolder)
        from_pipeline = kwargs.delete(:_from_pipeline)
        from_auto_class = kwargs.delete(:_from_auto) { false }
        commit_hash = kwargs.delete(:_commit_hash)

        user_agent = {file_type: "tokenizer", from_auto_class: from_auto_class, is_fast: name.include?("Fast")}
        if !from_pipeline.nil?
          user_agent[:using_pipeline] = from_pipeline
        end

        if Utils::Hub.is_offline_mode && !local_files_only
          Transformers.logger.info("Offline mode: forcing local_files_only: true")
          local_files_only = true
        end

        pretrained_model_name_or_path = pretrained_model_name_or_path.to_s
        vocab_files = {}
        init_configuration = {}

        is_local = Dir.exist?(pretrained_model_name_or_path)
        single_file_id = nil
        if File.exist?(pretrained_model_name_or_path)
          raise Todo
        end

        # At this point pretrained_model_name_or_path is either a directory or a model identifier name
        additional_files_names = {
          added_tokens_file: ADDED_TOKENS_FILE, # kept only for legacy
          special_tokens_map_file: SPECIAL_TOKENS_MAP_FILE, # kept only for legacy
          tokenizer_config_file: TOKENIZER_CONFIG_FILE,
          # tokenizer_file used to initialize a slow from a fast. Properly copy the `addedTokens` instead of adding in random orders
          tokenizer_file: FULL_TOKENIZER_FILE
        }
        vocab_files = vocab_files_names.merge(additional_files_names)
        if vocab_files[:tokenizer_file]
          # Try to get the tokenizer config to see if there are versioned tokenizer files.
          fast_tokenizer_file = FULL_TOKENIZER_FILE
          resolved_config_file =
            Utils::Hub.cached_file(
              pretrained_model_name_or_path,
              TOKENIZER_CONFIG_FILE,
              cache_dir: cache_dir,
              force_download: force_download,
              resume_download: resume_download,
              proxies: proxies,
              token: token,
              revision: revision,
              local_files_only: local_files_only,
              subfolder: subfolder,
              user_agent: user_agent,
              _raise_exceptions_for_gated_repo: false,
              _raise_exceptions_for_missing_entries: false,
              _raise_exceptions_for_connection_errors: false,
              _commit_hash: commit_hash
            )
          commit_hash = Utils::Hub.extract_commit_hash(resolved_config_file, commit_hash)
          if !resolved_config_file.nil?
            tokenizer_config = JSON.load_file(resolved_config_file)
            if tokenizer_config["fast_tokenizer_files"]
              fast_tokenizer_file = get_fast_tokenizer_file(tokenizer_config["fast_tokenizer_files"])
            end
          end
          vocab_files[:tokenizer_file] = fast_tokenizer_file
        end

        # Get files from url, cache, or disk depending on the case
        resolved_vocab_files = {}
        unresolved_files = []
        vocab_files.each do |file_id, file_path|
          if file_path.nil?
            resolved_vocab_files[file_id] = nil
          elsif single_file_id == file_id
            if File.exist?(file_path)
              resolved_vocab_files[file_id] = file_path
            else
              raise Todo
            end
          else
            resolved_vocab_files[file_id] =
              Utils::Hub.cached_file(
                pretrained_model_name_or_path,
                file_path,
                cache_dir: cache_dir,
                force_download: force_download,
                proxies: proxies,
                resume_download: resume_download,
                local_files_only: local_files_only,
                token: token,
                user_agent: user_agent,
                revision: revision,
                subfolder: subfolder,
                _raise_exceptions_for_gated_repo: false,
                _raise_exceptions_for_missing_entries: false,
                _raise_exceptions_for_connection_errors: false,
                _commit_hash: commit_hash,
              )
            commit_hash = Utils::Hub.extract_commit_hash(resolved_vocab_files[file_id], commit_hash)
          end
        end

        # not used?
        if unresolved_files.length > 0
          raise Todo
        end

        vocab_files.each do |file_id, file_path|
          if !resolved_vocab_files.include?(file_id)
            next
          end

          if is_local
            Transformers.logger.info("loading file #{file_path}")
          else
            Transformers.logger.info("loading file #{file_path} from cache at #{resolved_vocab_files[file_id] || "nil"}")
          end
        end

        _from_pretrained(
          resolved_vocab_files,
          pretrained_model_name_or_path,
          init_configuration,
          *init_inputs,
          token: token,
          cache_dir: cache_dir,
          local_files_only: local_files_only,
          _commit_hash: commit_hash,
          _is_local: is_local,
          trust_remote_code: trust_remote_code,
          **kwargs
        )
      end

      def _from_pretrained(
        resolved_vocab_files,
        pretrained_model_name_or_path,
        init_configuration,
        *init_inputs,
        token: nil,
        cache_dir: nil,
        local_files_only: false,
        _commit_hash: nil,
        _is_local: false,
        trust_remote_code: false,
        **kwargs
      )
        # We instantiate fast tokenizers based on a slow tokenizer if we don't have access to the tokenizer.json
        # file or if `from_slow` is set to True.
        from_slow = kwargs.delete(:from_slow) { false }
        has_tokenizer_file = !resolved_vocab_files[:tokenizer_file].nil?
        if (from_slow || !has_tokenizer_file) && !slow_tokenizer_class.nil?
          slow_tokenizer =
            slow_tokenizer_class._from_pretrained(
              Copy.deepcopy(resolved_vocab_files),
              pretrained_model_name_or_path,
              Copy.deepcopy(init_configuration),
              *init_inputs,
              token: token,
              cache_dir: cache_dir,
              local_files_only: local_files_only,
              _commit_hash: _commit_hash,
              **Copy.deepcopy(kwargs)
            )
        else
          slow_tokenizer = nil
        end

        # Prepare tokenizer initialization kwargs
        # Did we saved some inputs and kwargs to reload ?
        tokenizer_config_file = resolved_vocab_files.delete(:tokenizer_config_file)
        if !tokenizer_config_file.nil?
          init_kwargs = JSON.load_file(tokenizer_config_file).transform_keys(&:to_sym)
          # First attempt. We get tokenizer_class from tokenizer_config to check mismatch between tokenizers.
          config_tokenizer_class = init_kwargs[:tokenizer_class]
          init_kwargs.delete(:tokenizer_class)
          if !has_tokenizer_file
            init_kwargs.delete(:tokenizer_file)
          end
          saved_init_inputs = init_kwargs.delete(:init_inputs) { [] }
          if init_inputs.empty?
            init_inputs = saved_init_inputs
          end
        else
          config_tokenizer_class = nil
          init_kwargs = init_configuration
        end

        if config_tokenizer_class.nil?
          config =
            AutoConfig.from_pretrained(
              pretrained_model_name_or_path,
              token: token,
              cache_dir: cache_dir,
              local_files_only: local_files_only,
              trust_remote_code: trust_remote_code,
              _commit_hash: _commit_hash,
            )
          config_tokenizer_class = config.tokenizer_class

          if config_tokenizer_class.nil?
            # Third attempt. If we have not yet found the original type of the tokenizer,
            # we are loading we see if we can infer it from the type of the configuration file
            if config.class.model_type
              model_type = config.class.model_type
            else
              # Fallback: use pattern matching on the string.
              model_type = nil
              TOKENIZER_MAPPING_NAMES.each_key do |pattern|
                if pretrained_model_name_or_path.to_s.include?(pattern)
                  model_type = pattern
                  break
                end
              end
            end

            if !model_type.nil?
              config_tokenizer_class, config_tokenizer_class_fast =
                TOKENIZER_MAPPING_NAMES.fetch(model_type, [nil, nil])

              if config_tokenizer_class.nil?
                config_tokenizer_class = config_tokenizer_class_fast
              end
            end
          end
        end

        if !config_tokenizer_class.nil?
          if name.split("::").last.gsub("Fast", "") != config_tokenizer_class.gsub("Fast", "")
            raise Todo
          end
        end

        # Update with newly provided kwargs
        init_kwargs.merge!(kwargs)

        # Merge resolved_vocab_files arguments in init_kwargs.
        _added_tokens_file = resolved_vocab_files.delete(:added_tokens_file)
        _special_tokens_map_file = resolved_vocab_files.delete(:special_tokens_map_file)
        resolved_vocab_files.each do |args_name, file_path|
          if !init_kwargs.include?(args_name)
            init_kwargs[args_name] = file_path
          end
        end
        _tokenizer_file = resolved_vocab_files.delete(:tokenizer_file)

        if !slow_tokenizer.nil?
          init_kwargs[:__slow_tokenizer] = slow_tokenizer
        end
        init_kwargs[:name_or_path] = pretrained_model_name_or_path

        # Instantiate the tokenizer.
        tokenizer = new(*init_inputs, **init_kwargs)

        tokenizer
      end
    end
  end
end
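
PreTrainedTokenizerBase only provides the shared plumbing; concrete classes such as the BERT and DistilBERT tokenizers listed above inherit from it via tokenization_utils_fast.rb. A rough sketch of how the keyword arguments handled by call and _get_padding_truncation_strategies surface to a caller, assuming Transformers::AutoTokenizer (models/auto/tokenization_auto.rb) behaves like its Python counterpart; the model id is illustrative:

  require "transformers-rb"

  # Assumption: AutoTokenizer resolves this id to the fast BERT tokenizer shipped in this gem.
  tokenizer = Transformers::AutoTokenizer.from_pretrained("bert-base-uncased")

  # tokenizer.(...) dispatches to call, which normalizes these options via
  # _get_padding_truncation_strategies before encoding.
  enc = tokenizer.(
    ["a short sentence", "a second, somewhat longer sentence"],
    padding: true,        # becomes PaddingStrategy::LONGEST
    truncation: true,     # becomes TruncationStrategy::LONGEST_FIRST
    max_length: 32,
    return_tensors: "pt"  # BatchEncoding#convert_to_tensors builds Torch tensors ("pt" assumed to follow the Python convention)
  )

  p enc["input_ids"]      # BatchEncoding#[] accepts string keys...
  p enc[:attention_mask]  # ...and symbols, which are converted with to_s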