transformers-rb 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/CHANGELOG.md +3 -0
- data/LICENSE.txt +203 -0
- data/README.md +163 -0
- data/lib/transformers/activations.rb +57 -0
- data/lib/transformers/configuration_utils.rb +285 -0
- data/lib/transformers/convert_slow_tokenizer.rb +90 -0
- data/lib/transformers/data/processors/squad.rb +115 -0
- data/lib/transformers/dynamic_module_utils.rb +25 -0
- data/lib/transformers/feature_extraction_utils.rb +110 -0
- data/lib/transformers/hf_hub/constants.rb +71 -0
- data/lib/transformers/hf_hub/errors.rb +11 -0
- data/lib/transformers/hf_hub/file_download.rb +764 -0
- data/lib/transformers/hf_hub/utils/_errors.rb +94 -0
- data/lib/transformers/hf_hub/utils/_headers.rb +109 -0
- data/lib/transformers/image_processing_base.rb +169 -0
- data/lib/transformers/image_processing_utils.rb +63 -0
- data/lib/transformers/image_transforms.rb +208 -0
- data/lib/transformers/image_utils.rb +165 -0
- data/lib/transformers/modeling_outputs.rb +81 -0
- data/lib/transformers/modeling_utils.rb +888 -0
- data/lib/transformers/models/auto/auto_factory.rb +138 -0
- data/lib/transformers/models/auto/configuration_auto.rb +61 -0
- data/lib/transformers/models/auto/feature_extraction_auto.rb +20 -0
- data/lib/transformers/models/auto/image_processing_auto.rb +104 -0
- data/lib/transformers/models/auto/modeling_auto.rb +80 -0
- data/lib/transformers/models/auto/tokenization_auto.rb +160 -0
- data/lib/transformers/models/bert/configuration_bert.rb +65 -0
- data/lib/transformers/models/bert/modeling_bert.rb +836 -0
- data/lib/transformers/models/bert/tokenization_bert.rb +115 -0
- data/lib/transformers/models/bert/tokenization_bert_fast.rb +52 -0
- data/lib/transformers/models/distilbert/configuration_distilbert.rb +63 -0
- data/lib/transformers/models/distilbert/modeling_distilbert.rb +616 -0
- data/lib/transformers/models/distilbert/tokenization_distilbert.rb +114 -0
- data/lib/transformers/models/distilbert/tokenization_distilbert_fast.rb +71 -0
- data/lib/transformers/models/vit/configuration_vit.rb +60 -0
- data/lib/transformers/models/vit/image_processing_vit.rb +170 -0
- data/lib/transformers/models/vit/modeling_vit.rb +506 -0
- data/lib/transformers/pipelines/_init.rb +348 -0
- data/lib/transformers/pipelines/base.rb +301 -0
- data/lib/transformers/pipelines/feature_extraction.rb +47 -0
- data/lib/transformers/pipelines/image_classification.rb +110 -0
- data/lib/transformers/pipelines/image_feature_extraction.rb +56 -0
- data/lib/transformers/pipelines/pt_utils.rb +53 -0
- data/lib/transformers/pipelines/question_answering.rb +508 -0
- data/lib/transformers/pipelines/text_classification.rb +123 -0
- data/lib/transformers/pipelines/token_classification.rb +282 -0
- data/lib/transformers/ruby_utils.rb +33 -0
- data/lib/transformers/sentence_transformer.rb +37 -0
- data/lib/transformers/tokenization_utils.rb +152 -0
- data/lib/transformers/tokenization_utils_base.rb +937 -0
- data/lib/transformers/tokenization_utils_fast.rb +386 -0
- data/lib/transformers/torch_utils.rb +25 -0
- data/lib/transformers/utils/_init.rb +31 -0
- data/lib/transformers/utils/generic.rb +107 -0
- data/lib/transformers/utils/hub.rb +209 -0
- data/lib/transformers/utils/import_utils.rb +45 -0
- data/lib/transformers/utils/logging.rb +52 -0
- data/lib/transformers/version.rb +3 -0
- data/lib/transformers-rb.rb +1 -0
- data/lib/transformers.rb +100 -0
- data/licenses/LICENSE-huggingface-hub.txt +201 -0
- data/licenses/LICENSE-sentence-transformers.txt +201 -0
- data/licenses/NOTICE-sentence-transformers.txt +5 -0
- metadata +161 -0
@@ -0,0 +1,386 @@
+# Copyright 2020 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+module Transformers
+  class PreTrainedTokenizerFast < PreTrainedTokenizerBase
+    def initialize(*args, **kwargs)
+      tokenizer_object = kwargs.delete(:tokenizer_object)
+      slow_tokenizer = kwargs.delete(:__slow_tokenizer)
+      fast_tokenizer_file = kwargs.delete(:tokenizer_file)
+      from_slow = kwargs.delete(:from_slow) { false }
+      _added_tokens_decoder = kwargs.delete(:added_tokens_decoder)
+
+      if !tokenizer_object.nil?
+        fast_tokenizer = Copy.deepcopy(tokenizer_object)
+      elsif !fast_tokenizer_file.nil? && !from_slow
+        # We have a serialization from tokenizers which let us directly build the backend
+        fast_tokenizer = Tokenizers::Tokenizer.from_file(fast_tokenizer_file)
+      elsif !slow_tokenizer.nil?
+        # We need to convert a slow tokenizer to build the backend
+        fast_tokenizer = ConvertSlowTokenizer.convert_slow_tokenizer(slow_tokenizer)
+      elsif !@slow_tokenizer_class.nil?
+        # We need to create and convert a slow tokenizer to build the backend
+        slow_tokenizer = @slow_tokenizer_class.new(*args, **kwargs)
+        fast_tokenizer = ConvertSlowTokenizer.convert_slow_tokenizer(slow_tokenizer)
+      else
+        raise ArgumentError, <<~MSG
+          Couldn't instantiate the backend tokenizer from one of:
+          (1) a `tokenizers` library serialization file,
+          (2) a slow tokenizer instance to convert or
+          (3) an equivalent slow tokenizer class to instantiate and convert.
+          You need to have sentencepiece installed to convert a slow tokenizer to a fast one.
+        MSG
+      end
+
+      @tokenizer = fast_tokenizer
+
+      if !slow_tokenizer.nil?
+        kwargs.merge!(slow_tokenizer.init_kwargs)
+      end
+
+      @decode_use_source_tokenizer = false
+
+      _truncation = @tokenizer.truncation
+
+      if !_truncation.nil?
+        _truncation = _truncation.transform_keys(&:to_sym)
+        @tokenizer.enable_truncation(_truncation[:max_length], **_truncation.except(:max_length))
+        kwargs[:max_length] ||= _truncation[:max_length]
+        kwargs[:truncation_side] ||= _truncation[:direction]
+        kwargs[:stride] ||= _truncation[:stride]
+        kwargs[:truncation_strategy] ||= _truncation[:strategy]
+      else
+        @tokenizer.no_truncation
+      end
+
+      _padding = @tokenizer.padding
+      if !_padding.nil?
+        _padding = _padding.transform_keys(&:to_sym)
+        @tokenizer.enable_padding(**_padding)
+        kwargs[:pad_token] ||= _padding[:pad_token]
+        kwargs[:pad_token_type_id] ||= _padding[:pad_token_type_id]
+        kwargs[:padding_side] ||= _padding[:direction]
+        kwargs[:max_length] ||= _padding[:length]
+        kwargs[:pad_to_multiple_of] ||= _padding[:pad_to_multiple_of]
+      end
+
+      # We call this after having initialized the backend tokenizer because we update it.
+      super(**kwargs)
+    end
+
+    def is_fast
+      true
+    end
+
+    def get_vocab
+      @tokenizer.vocab(with_added_tokens: true)
+    end
+
+    def vocab
+      get_vocab
+    end
+
+    def convert_tokens_to_ids(tokens)
+      if tokens.nil?
+        return nil
+      end
+
+      if tokens.is_a?(String)
+        return _convert_token_to_id_with_added_voc(tokens)
+      end
+
+      ids = []
+      tokens.each do |token|
+        ids << _convert_token_to_id_with_added_voc(token)
+      end
+      ids
+    end
+
+    def _convert_token_to_id_with_added_voc(token)
+      index = @tokenizer.token_to_id(token)
+      if index.nil?
+        return unk_token_id
+      end
+      index
+    end
+
+    def convert_ids_to_tokens(ids, skip_special_tokens: false)
+      if ids.is_a?(Integer)
+        return @tokenizer.id_to_token(ids)
+      end
+      tokens = []
+      ids.each do |index|
+        index = index.to_i
+        if skip_special_tokens && @all_special_ids.include?(index)
+          next
+        end
+        tokens << @tokenizer.id_to_token(index)
+      end
+      tokens
+    end
+
+    private
+
+    def set_truncation_and_padding(
+      padding_strategy:,
+      truncation_strategy:,
+      max_length:,
+      stride:,
+      pad_to_multiple_of:
+    )
+      _truncation = @tokenizer.truncation
+      _padding = @tokenizer.padding
+      # Set truncation and padding on the backend tokenizer
+      if truncation_strategy == TruncationStrategy::DO_NOT_TRUNCATE
+        if !_truncation.nil?
+          @tokenizer.no_truncation
+        end
+      else
+        target = {
+          max_length: max_length,
+          stride: stride,
+          strategy: truncation_strategy,
+          direction: @truncation_side
+        }
+
+        # _truncation might contain more keys that the target `transformers`
+        # supports. Use only the target keys to trigger `enable_truncation`.
+        # This should enable this code to works on various `tokenizers`
+        # targets.
+        if _truncation.nil?
+          current = nil
+        else
+          current = target.to_h { |k, _| [k, _truncation[k]] }
+        end
+
+        if current != target
+          @tokenizer.enable_truncation(target.delete(:max_length), **target)
+        end
+      end
+
+      if padding_strategy == PaddingStrategy::DO_NOT_PAD
+        if !_padding.nil?
+          @tokenizer.no_padding
+        end
+      else
+        length = padding_strategy == PaddingStrategy::MAX_LENGTH ? max_length : nil
+        target = {
+          length: length,
+          direction: @padding_side,
+          pad_id: @pad_token_id,
+          pad_token: @pad_token,
+          pad_type_id: @pad_token_type_id,
+          pad_to_multiple_of: pad_to_multiple_of
+        }
+        if _padding != target
+          @tokenizer.enable_padding(**target)
+        end
+      end
+    end
+
+    def _batch_encode_plus(
+      batch_text_or_text_pairs,
+      add_special_tokens: true,
+      padding_strategy: PaddingStrategy::DO_NOT_PAD,
+      truncation_strategy: TruncationStrategy::DO_NOT_TRUNCATE,
+      max_length: nil,
+      stride: 0,
+      is_split_into_words: false,
+      pad_to_multiple_of: nil,
+      return_tensors: nil,
+      return_token_type_ids: nil,
+      return_attention_mask: nil,
+      return_overflowing_tokens: false,
+      return_special_tokens_mask: false,
+      return_offsets_mapping: false,
+      return_length: false,
+      verbose: true
+    )
+      if !batch_text_or_text_pairs.is_a?(Array)
+        raise TypeError, "batch_text_or_text_pairs has to be an array (got #{batch_text_or_text_pairs.class.name})"
+      end
+
+      # Set the truncation and padding strategy and restore the initial configuration
+      set_truncation_and_padding(
+        padding_strategy: padding_strategy,
+        truncation_strategy: truncation_strategy,
+        max_length: max_length,
+        stride: stride,
+        pad_to_multiple_of: pad_to_multiple_of
+      )
+
+      encodings =
+        @tokenizer.encode_batch(
+          batch_text_or_text_pairs,
+          add_special_tokens: add_special_tokens,
+          is_pretokenized: is_split_into_words,
+        )
+
+      # Convert encoding to dict
+      # `Tokens` has type: Tuple[
+      #                       List[Dict[str, List[List[int]]]] or List[Dict[str, 2D-Tensor]],
+      #                       List[EncodingFast]
+      #                    ]
+      # with nested dimensions corresponding to batch, overflows, sequence length
+      tokens_and_encodings =
+        encodings.map do |encoding|
+          _convert_encoding(
+            encoding: encoding,
+            return_token_type_ids: return_token_type_ids,
+            return_attention_mask: return_attention_mask,
+            return_overflowing_tokens: return_overflowing_tokens,
+            return_special_tokens_mask: return_special_tokens_mask,
+            return_offsets_mapping: return_offsets_mapping,
+            return_length: return_length,
+            verbose: verbose
+          )
+        end
+
+      # Convert the output to have dict[list] from list[dict] and remove the additional overflows dimension
+      # From (variable) shape (batch, overflows, sequence length) to ~ (batch * overflows, sequence length)
+      # (we say ~ because the number of overflow varies with the example in the batch)
+      #
+      # To match each overflowing sample with the original sample in the batch
+      # we add an overflow_to_sample_mapping array (see below)
+      sanitized_tokens = {}
+      tokens_and_encodings[0][0].each_key do |key|
+        stack = tokens_and_encodings.map { |item, _| item[key][0] }
+        sanitized_tokens[key] = stack
+      end
+      sanitized_encodings = tokens_and_encodings.map { |_, item| item[0] }
+
+      # If returning overflowing tokens, we need to return a mapping
+      # from the batch idx to the original sample
+      if return_overflowing_tokens
+        overflow_to_sample_mapping = []
+        tokens_and_encodings.each_with_index do |(toks, _), i|
+          overflow_to_sample_mapping += [i] * toks["input_ids"].length
+        end
+        sanitized_tokens["overflow_to_sample_mapping"] = overflow_to_sample_mapping
+      end
+
+      sanitized_tokens["input_ids"].each do |input_ids|
+        _eventual_warn_about_too_long_sequence(input_ids, max_length, verbose)
+      end
+
+      BatchEncoding.new(data: sanitized_tokens, encoding: sanitized_encodings, tensor_type: return_tensors)
+    end
+
+    def _convert_encoding(
+      encoding:,
+      return_token_type_ids: nil,
+      return_attention_mask: nil,
+      return_overflowing_tokens: false,
+      return_special_tokens_mask: false,
+      return_offsets_mapping: false,
+      return_length: false,
+      verbose: true
+    )
+      if return_token_type_ids.nil?
+        return_token_type_ids = self.class.model_input_names.include?("token_type_ids")
+      end
+      if return_attention_mask.nil?
+        return_attention_mask = self.class.model_input_names.include?("attention_mask")
+      end
+
+      if return_overflowing_tokens && !encoding.overflowing.nil?
+        encodings = [encoding] + encoding.overflowing
+      else
+        encodings = [encoding]
+      end
+
+      encoding_dict = Hash.new { |h, k| h[k] = [] }
+      encodings.each do |e|
+        encoding_dict["input_ids"] << e.ids
+
+        if return_token_type_ids
+          encoding_dict["token_type_ids"] << e.type_ids
+        end
+        if return_attention_mask
+          encoding_dict["attention_mask"] << e.attention_mask
+        end
+        if return_special_tokens_mask
+          encoding_dict["special_tokens_mask"] << e.special_tokens_mask
+        end
+        if return_offsets_mapping
+          encoding_dict["offset_mapping"] << e.offsets
+        end
+        if return_length
+          encoding_dict["length"] << e.ids.length
+        end
+      end
+
+      [encoding_dict, encodings]
+    end
+
+    def _encode_plus(
+      text:,
+      text_pair: nil,
+      add_special_tokens: true,
+      padding_strategy: PaddingStrategy::DO_NOT_PAD,
+      truncation_strategy: TruncationStrategy::DO_NOT_TRUNCATE,
+      max_length: nil,
+      stride: 0,
+      is_split_into_words: false,
+      pad_to_multiple_of: nil,
+      return_tensors: nil,
+      return_token_type_ids: nil,
+      return_attention_mask: nil,
+      return_overflowing_tokens: false,
+      return_special_tokens_mask: false,
+      return_offsets_mapping: false,
+      return_length: false,
+      verbose: true,
+      **kwargs
+    )
+      batched_input = text_pair ? [[text, text_pair]] : [text]
+      batched_output =
+        _batch_encode_plus(
+          batched_input,
+          is_split_into_words: is_split_into_words,
+          add_special_tokens: add_special_tokens,
+          padding_strategy: padding_strategy,
+          truncation_strategy: truncation_strategy,
+          max_length: max_length,
+          stride: stride,
+          pad_to_multiple_of: pad_to_multiple_of,
+          return_tensors: return_tensors,
+          return_token_type_ids: return_token_type_ids,
+          return_attention_mask: return_attention_mask,
+          return_overflowing_tokens: return_overflowing_tokens,
+          return_special_tokens_mask: return_special_tokens_mask,
+          return_offsets_mapping: return_offsets_mapping,
+          return_length: return_length,
+          verbose: verbose,
+          **kwargs
+        )
+
+      # Return tensor is None, then we can remove the leading batch axis
+      # Overflowing tokens are returned as a batch of output so we keep them in this case
+      if return_tensors.nil? && !return_overflowing_tokens
+        batched_output =
+          BatchEncoding.new(
+            data: batched_output.items.to_h { |key, value|
+              [key, value.length > 0 && value[0].is_a?(Array) ? value[0] : value]
+            },
+            encoding: batched_output.encodings
+          )
+      end
+
+      _eventual_warn_about_too_long_sequence(batched_output["input_ids"], max_length, verbose)
+
+      batched_output
+    end
+  end
+end
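The hunk above adds `Transformers::PreTrainedTokenizerFast`, a wrapper around a `tokenizers` backend object. A minimal usage sketch (not taken from the gem; the `tokenizer.json` path is hypothetical, and in normal use this object is built for you by the auto/tokenizer loading code listed above rather than constructed by hand):

```ruby
# Minimal sketch, assuming a local tokenizer.json produced by the `tokenizers` library.
require "transformers-rb"

backend = Tokenizers::Tokenizer.from_file("tokenizer.json")
tokenizer = Transformers::PreTrainedTokenizerFast.new(tokenizer_object: backend)

ids = tokenizer.convert_tokens_to_ids(["hello", "world"])  # unknown tokens fall back to unk_token_id
tokens = tokenizer.convert_ids_to_tokens(ids)              # round-trips through the backend vocab
vocab_size = tokenizer.get_vocab.size                      # includes added tokens
```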
@@ -0,0 +1,25 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+module Transformers
+  module TorchUtils
+    def self.apply_chunking_to_forward(forward_fn, chunk_size, chunk_dim, *input_tensors)
+      if chunk_size > 0
+        raise Todo
+      end
+
+      forward_fn.(*input_tensors)
+    end
+  end
+end
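`TorchUtils.apply_chunking_to_forward` only supports the unchunked path in this release (a positive `chunk_size` raises `Todo`); otherwise it simply calls the forward proc on the inputs. A small sketch of that path, assuming torch-rb is available:

```ruby
# Sketch only: with chunk_size 0 the helper just forwards the tensors.
require "torch"

forward = ->(x) { x * 2 }
x = Torch.tensor([1.0, 2.0, 3.0])
Transformers::TorchUtils.apply_chunking_to_forward(forward, 0, 0, x)
# => tensor([2.0, 4.0, 6.0])
```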
@@ -0,0 +1,31 @@
+# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+module Transformers
+  WEIGHTS_NAME = "pytorch_model.bin"
+  WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"
+  TF2_WEIGHTS_NAME = "tf_model.h5"
+  TF2_WEIGHTS_INDEX_NAME = "tf_model.h5.index.json"
+  TF_WEIGHTS_NAME = "model.ckpt"
+  FLAX_WEIGHTS_NAME = "flax_model.msgpack"
+  FLAX_WEIGHTS_INDEX_NAME = "flax_model.msgpack.index.json"
+  SAFE_WEIGHTS_NAME = "model.safetensors"
+  SAFE_WEIGHTS_INDEX_NAME = "model.safetensors.index.json"
+  CONFIG_NAME = "config.json"
+  FEATURE_EXTRACTOR_NAME = "preprocessor_config.json"
+  IMAGE_PROCESSOR_NAME = FEATURE_EXTRACTOR_NAME
+  PROCESSOR_NAME = "processor_config.json"
+  GENERATION_CONFIG_NAME = "generation_config.json"
+  MODEL_CARD_NAME = "modelcard.json"
+end
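The constants above mirror the standard Hugging Face Hub file names. An illustrative (hypothetical) use, probing a local model directory for the preferred weights file:

```ruby
# Illustrative only; the directory path is hypothetical.
dir = "/path/to/model"
weights =
  [Transformers::SAFE_WEIGHTS_NAME, Transformers::WEIGHTS_NAME]
    .map { |name| File.join(dir, name) }
    .find { |path| File.exist?(path) }
config = File.join(dir, Transformers::CONFIG_NAME)
```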
@@ -0,0 +1,107 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+module Transformers
+  class ModelOutput
+    def self.attributes
+      @attributes ||= []
+    end
+
+    def self.attribute(attribute)
+      attributes << attribute.to_sym
+
+      define_method(attribute) do
+        self[attribute]
+      end
+    end
+
+    def initialize(**kwargs)
+      @data = kwargs
+    end
+
+    def [](k)
+      if k.is_a?(String) || k.is_a?(Symbol)
+        @data[k.to_sym]
+      else
+        to_tuple[k]
+      end
+    end
+
+    def to_tuple
+      self.class.attributes.map { |k| @data[k] }.compact
+    end
+  end
+
+  class ExplicitEnum
+    def initialize(value)
+      expected = self.class.constants.map { |k| self.class.const_get(k) }
+      unless expected.include?(value)
+        raise ArgumentError, "#{value} is not a valid #{self.class.name}, please select one of #{expected.inspect}"
+      end
+      @value = value
+    end
+
+    def to_s
+      @value
+    end
+  end
+
+  class PaddingStrategy < ExplicitEnum
+    LONGEST = "longest"
+    MAX_LENGTH = "max_length"
+    DO_NOT_PAD = "do_not_pad"
+  end
+
+  class TensorType < ExplicitEnum
+    PYTORCH = "pt"
+    TENSORFLOW = "tf"
+    NUMPY = "np"
+    JAX = "jax"
+    MLX = "mlx"
+  end
+
+  module Utils
+    def self.infer_framework(model_class)
+      if model_class < Torch::NN::Module
+        "pt"
+      else
+        raise TypeError, "Could not infer framework from class #{model_class}."
+      end
+    end
+
+    def self._is_numo(x)
+      x.is_a?(Numo::NArray)
+    end
+
+    def self.is_numo_array(x)
+      _is_numo(x)
+    end
+
+    def self._is_torch(x)
+      x.is_a?(Torch::Tensor)
+    end
+
+    def self.is_torch_tensor(x)
+      _is_torch(x)
+    end
+
+    def self._is_torch_device(x)
+      x.is_a?(Torch::Device)
+    end
+
+    def self.is_torch_device(x)
+      _is_torch_device(x)
+    end
+  end
+end
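The `ModelOutput` base class above stores its keyword arguments in a hash and exposes them by name or by position, while `ExplicitEnum` subclasses such as `PaddingStrategy` and `TensorType` validate their values on construction. A hypothetical subclass, to illustrate the `ModelOutput` behavior:

```ruby
# Hypothetical subclass for illustration: `attribute` registers a reader backed by
# the kwargs hash, and `[]` accepts a name (String/Symbol) or an index into to_tuple.
class ExampleOutput < Transformers::ModelOutput
  attribute :logits
  attribute :hidden_states
end

out = ExampleOutput.new(logits: [0.1, 0.9])
out.logits                # => [0.1, 0.9]
out["logits"]             # => [0.1, 0.9] (string keys are symbolized)
out[0]                    # => [0.1, 0.9] (positional; nil attributes are compacted out)
ExampleOutput.attributes  # => [:logits, :hidden_states]
```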