transformers-rb 0.1.1 → 0.1.3
This diff shows the contents of publicly released package versions as they appear in their public registry, and is provided for informational purposes only.
- checksums.yaml +4 -4
- data/CHANGELOG.md +11 -0
- data/README.md +64 -3
- data/lib/transformers/configuration_utils.rb +32 -4
- data/lib/transformers/modeling_utils.rb +10 -3
- data/lib/transformers/models/auto/auto_factory.rb +1 -1
- data/lib/transformers/models/auto/configuration_auto.rb +5 -2
- data/lib/transformers/models/auto/modeling_auto.rb +9 -3
- data/lib/transformers/models/auto/tokenization_auto.rb +5 -2
- data/lib/transformers/models/deberta_v2/configuration_deberta_v2.rb +80 -0
- data/lib/transformers/models/deberta_v2/modeling_deberta_v2.rb +1210 -0
- data/lib/transformers/models/deberta_v2/tokenization_deberta_v2_fast.rb +78 -0
- data/lib/transformers/models/mpnet/configuration_mpnet.rb +61 -0
- data/lib/transformers/models/mpnet/modeling_mpnet.rb +792 -0
- data/lib/transformers/models/mpnet/tokenization_mpnet_fast.rb +106 -0
- data/lib/transformers/models/xlm_roberta/configuration_xlm_roberta.rb +68 -0
- data/lib/transformers/models/xlm_roberta/modeling_xlm_roberta.rb +1216 -0
- data/lib/transformers/models/xlm_roberta/tokenization_xlm_roberta_fast.rb +68 -0
- data/lib/transformers/pipelines/_init.rb +16 -5
- data/lib/transformers/pipelines/reranking.rb +33 -0
- data/lib/transformers/version.rb +1 -1
- data/lib/transformers.rb +16 -0
- metadata +15 -5

data/lib/transformers/models/xlm_roberta/tokenization_xlm_roberta_fast.rb
ADDED

@@ -0,0 +1,68 @@
+# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License
+
+module Transformers
+  module XlmRoberta
+    class XLMRobertaTokenizerFast < PreTrainedTokenizerFast
+      VOCAB_FILES_NAMES = {vocab_file: "sentencepiece.bpe.model", tokenizer_file: "tokenizer.json"}
+
+      self.vocab_files_names = VOCAB_FILES_NAMES
+      self.model_input_names = ["input_ids", "attention_mask"]
+      # self.slow_tokenizer_class = XLMRobertaTokenizer
+
+      def initialize(
+        vocab_file: nil,
+        tokenizer_file: nil,
+        bos_token: "<s>",
+        eos_token: "</s>",
+        sep_token: "</s>",
+        cls_token: "<s>",
+        unk_token: "<unk>",
+        pad_token: "<pad>",
+        mask_token: "<mask>",
+        **kwargs
+      )
+        # The mask token behaves like a normal word, i.e. it includes the space before it
+        mask_token = mask_token.is_a?(String) ? Tokenizers::AddedToken.new(mask_token, lstrip: true, rstrip: false) : mask_token
+
+        super(vocab_file, tokenizer_file: tokenizer_file, bos_token: bos_token, eos_token: eos_token, sep_token: sep_token, cls_token: cls_token, unk_token: unk_token, pad_token: pad_token, mask_token: mask_token, **kwargs)
+
+        @vocab_file = vocab_file
+      end
+
+      def can_save_slow_tokenizer
+        @vocab_file ? File.exist?(@vocab_file) : false
+      end
+
+      def build_inputs_with_special_tokens(token_ids_0, token_ids_1: nil)
+        if token_ids_1.nil?
+          return [@cls_token_id] + token_ids_0 + [@sep_token_id]
+        end
+        cls = [@cls_token_id]
+        sep = [@sep_token_id]
+        cls + token_ids_0 + sep + sep + token_ids_1 + sep
+      end
+
+      def create_token_type_ids_from_sequences(token_ids_0, token_ids_1: nil)
+        sep = [@sep_token_id]
+        cls = [@cls_token_id]
+
+        if token_ids_1.nil?
+          return [0] * (cls + token_ids_0 + sep).length
+        end
+        [0] * (cls + token_ids_0 + sep + sep + token_ids_1 + sep).length
+      end
+    end
+  end
+end
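
For reference, XLM-RoBERTa wraps a single sequence as <s> A </s> and a pair as <s> A </s></s> B </s>, and the model does not use token type IDs, which is why create_token_type_ids_from_sequences returns all zeros. A minimal sketch of what the two methods above produce, assuming the conventional XLM-RoBERTa IDs of 0 for <s> and 2 for </s> (the content token IDs below are made up):

    # Assumed special-token IDs: <s> => 0, </s> => 2 (standard in XLM-RoBERTa vocabularies)
    cls_id, sep_id = 0, 2
    a = [4734, 881]  # made-up token IDs for sequence A
    b = [2312]       # made-up token IDs for sequence B

    single = [cls_id] + a + [sep_id]                        # => [0, 4734, 881, 2]
    pair   = [cls_id] + a + [sep_id, sep_id] + b + [sep_id] # => [0, 4734, 881, 2, 2, 2312, 2]
    [0] * pair.length                                       # => [0, 0, 0, 0, 0, 0, 0]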

data/lib/transformers/pipelines/_init.rb
CHANGED

@@ -24,7 +24,7 @@ module Transformers
       "pt" => [AutoModel],
       "default" => {
         "model" => {
-          "pt" => ["distilbert/distilbert-base-cased", "
+          "pt" => ["distilbert/distilbert-base-cased", "6ea8117"]
         }
       },
       "type" => "multimodal"
@@ -34,7 +34,7 @@ module Transformers
       "pt" => [AutoModelForSequenceClassification],
       "default" => {
         "model" => {
-          "pt" => ["distilbert/distilbert-base-uncased-finetuned-sst-2-english", "
+          "pt" => ["distilbert/distilbert-base-uncased-finetuned-sst-2-english", "714eb0f"]
         }
       },
       "type" => "text"
@@ -44,7 +44,7 @@ module Transformers
       "pt" => [AutoModelForTokenClassification],
       "default" => {
         "model" => {
-          "pt" => ["dbmdz/bert-large-cased-finetuned-conll03-english", "
+          "pt" => ["dbmdz/bert-large-cased-finetuned-conll03-english", "4c53496"]
         }
       },
       "type" => "text"
@@ -54,7 +54,7 @@ module Transformers
       "pt" => [AutoModelForQuestionAnswering],
       "default" => {
         "model" => {
-          "pt" => ["distilbert/distilbert-base-cased-distilled-squad", "
+          "pt" => ["distilbert/distilbert-base-cased-distilled-squad", "564e9b5"]
         }
       },
       "type" => "text"
@@ -64,7 +64,7 @@ module Transformers
       "pt" => [AutoModelForImageClassification],
       "default" => {
         "model" => {
-          "pt" => ["google/vit-base-patch16-224", "
+          "pt" => ["google/vit-base-patch16-224", "3f49326"]
         }
       },
       "type" => "image"
@@ -89,6 +89,16 @@ module Transformers
       },
       "type" => "text"
     },
+    "reranking" => {
+      "impl" => RerankingPipeline,
+      "pt" => [AutoModelForSequenceClassification],
+      "default" => {
+        "model" => {
+          "pt" => ["mixedbread-ai/mxbai-rerank-base-v1", "03241da"]
+        }
+      },
+      "type" => "text"
+    }
   }

   PIPELINE_REGISTRY = PipelineRegistry.new(supported_tasks: SUPPORTED_TASKS, task_aliases: TASK_ALIASES)
@@ -227,6 +237,7 @@ module Transformers
         " #{revision} (#{Utils::Hub::HUGGINGFACE_CO_RESOLVE_ENDPOINT}/#{model}).\n" +
         "Using a pipeline without specifying a model name and revision in production is not recommended."
       )
+      hub_kwargs[:revision] = revision
       if config.nil? && model.is_a?(String)
         config = AutoConfig.from_pretrained(model, _from_pipeline: task, **hub_kwargs, **model_kwargs)
         hub_kwargs[:_commit_hash] = config._commit_hash
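
The one-line addition in the last hunk (hub_kwargs[:revision] = revision) forwards the pinned default revision named in the warning into the from_pretrained calls below it, so the exact commit listed in SUPPORTED_TASKS is what gets downloaded. With the new registry entry, the reranking task becomes available through the gem's pipeline helper; a minimal usage sketch, assuming the gem is loaded (e.g. via Bundler) and with illustrative scores:

    # Loads the pinned default model, mixedbread-ai/mxbai-rerank-base-v1 @ 03241da
    rerank = Transformers.pipeline("reranking")

    rerank.("Who wrote Hamlet?", [
      "Hamlet is a tragedy written by William Shakespeare.",
      "The Globe Theatre stands on the south bank of the Thames."
    ])
    # => [{index: 0, score: 0.99}, {index: 1, score: 0.02}]  (illustrative values)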

data/lib/transformers/pipelines/reranking.rb
ADDED

@@ -0,0 +1,33 @@
+module Transformers
+  class RerankingPipeline < Pipeline
+    def _sanitize_parameters(**kwargs)
+      [{}, {}, kwargs]
+    end
+
+    def preprocess(inputs)
+      @tokenizer.(
+        [inputs[:query]] * inputs[:documents].length,
+        text_pair: inputs[:documents],
+        return_tensors: @framework
+      )
+    end
+
+    def _forward(model_inputs)
+      model_outputs = @model.(**model_inputs)
+      model_outputs
+    end
+
+    def call(query, documents)
+      super({query: query, documents: documents})
+    end
+
+    def postprocess(model_outputs)
+      model_outputs[0]
+        .sigmoid
+        .squeeze
+        .to_a
+        .map.with_index { |s, i| {index: i, score: s} }
+        .sort_by { |v| -v[:score] }
+    end
+  end
+end
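
postprocess maps each (query, document) logit through a sigmoid and sorts the documents by descending score. A short walk-through of the same chain using torch-rb (already a runtime dependency of this gem), with made-up logits:

    require "torch"

    logits = Torch.tensor([[2.2], [-1.3]])  # one relevance logit per document
    scores = logits.sigmoid.squeeze.to_a    # => [0.9002..., 0.2142...]

    scores.map.with_index { |s, i| {index: i, score: s} }
          .sort_by { |v| -v[:score] }
    # => [{index: 0, score: 0.9002...}, {index: 1, score: 0.2142...}]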
    
data/lib/transformers/version.rb
CHANGED

@@ -1,3 +1,3 @@
 module Transformers
-  VERSION = "0.1.1"
+  VERSION = "0.1.3"
 end

data/lib/transformers.rb
CHANGED
    
@@ -61,17 +61,32 @@ require_relative "transformers/models/bert/modeling_bert"
 require_relative "transformers/models/bert/tokenization_bert"
 require_relative "transformers/models/bert/tokenization_bert_fast"

+# models deberta-v2
+require_relative "transformers/models/deberta_v2/configuration_deberta_v2"
+require_relative "transformers/models/deberta_v2/modeling_deberta_v2"
+require_relative "transformers/models/deberta_v2/tokenization_deberta_v2_fast"
+
 # models distilbert
 require_relative "transformers/models/distilbert/configuration_distilbert"
 require_relative "transformers/models/distilbert/modeling_distilbert"
 require_relative "transformers/models/distilbert/tokenization_distilbert"
 require_relative "transformers/models/distilbert/tokenization_distilbert_fast"

+# models mpnet
+require_relative "transformers/models/mpnet/configuration_mpnet"
+require_relative "transformers/models/mpnet/modeling_mpnet"
+require_relative "transformers/models/mpnet/tokenization_mpnet_fast"
+
 # models vit
 require_relative "transformers/models/vit/configuration_vit"
 require_relative "transformers/models/vit/image_processing_vit"
 require_relative "transformers/models/vit/modeling_vit"

+# models xlm-roberta
+require_relative "transformers/models/xlm_roberta/configuration_xlm_roberta"
+require_relative "transformers/models/xlm_roberta/modeling_xlm_roberta"
+require_relative "transformers/models/xlm_roberta/tokenization_xlm_roberta_fast"
+
 # pipelines
 require_relative "transformers/pipelines/base"
 require_relative "transformers/pipelines/feature_extraction"
@@ -80,6 +95,7 @@ require_relative "transformers/pipelines/image_classification"
 require_relative "transformers/pipelines/image_feature_extraction"
 require_relative "transformers/pipelines/pt_utils"
 require_relative "transformers/pipelines/question_answering"
+require_relative "transformers/pipelines/reranking"
 require_relative "transformers/pipelines/text_classification"
 require_relative "transformers/pipelines/token_classification"
 require_relative "transformers/pipelines/_init"
    
        metadata
    CHANGED
    
@@ -1,14 +1,14 @@
 --- !ruby/object:Gem::Specification
 name: transformers-rb
 version: !ruby/object:Gem::Version
-  version: 0.1.1
+  version: 0.1.3
 platform: ruby
 authors:
 - Andrew Kane
 autorequire:
 bindir: bin
 cert_chain: []
-date: 2024-
+date: 2024-09-17 00:00:00.000000000 Z
 dependencies:
 - !ruby/object:Gem::Dependency
   name: numo-narray
@@ -44,14 +44,14 @@ dependencies:
     requirements:
     - - ">="
       - !ruby/object:Gem::Version
-        version: 0.5.
+        version: 0.5.3
   type: :runtime
   prerelease: false
   version_requirements: !ruby/object:Gem::Requirement
     requirements:
     - - ">="
       - !ruby/object:Gem::Version
-        version: 0.5.
+        version: 0.5.3
 - !ruby/object:Gem::Dependency
   name: torch-rb
   requirement: !ruby/object:Gem::Requirement
@@ -104,13 +104,22 @@ files:
 - lib/transformers/models/bert/modeling_bert.rb
 - lib/transformers/models/bert/tokenization_bert.rb
 - lib/transformers/models/bert/tokenization_bert_fast.rb
+- lib/transformers/models/deberta_v2/configuration_deberta_v2.rb
+- lib/transformers/models/deberta_v2/modeling_deberta_v2.rb
+- lib/transformers/models/deberta_v2/tokenization_deberta_v2_fast.rb
 - lib/transformers/models/distilbert/configuration_distilbert.rb
 - lib/transformers/models/distilbert/modeling_distilbert.rb
 - lib/transformers/models/distilbert/tokenization_distilbert.rb
 - lib/transformers/models/distilbert/tokenization_distilbert_fast.rb
+- lib/transformers/models/mpnet/configuration_mpnet.rb
+- lib/transformers/models/mpnet/modeling_mpnet.rb
+- lib/transformers/models/mpnet/tokenization_mpnet_fast.rb
 - lib/transformers/models/vit/configuration_vit.rb
 - lib/transformers/models/vit/image_processing_vit.rb
 - lib/transformers/models/vit/modeling_vit.rb
+- lib/transformers/models/xlm_roberta/configuration_xlm_roberta.rb
+- lib/transformers/models/xlm_roberta/modeling_xlm_roberta.rb
+- lib/transformers/models/xlm_roberta/tokenization_xlm_roberta_fast.rb
 - lib/transformers/pipelines/_init.rb
 - lib/transformers/pipelines/base.rb
 - lib/transformers/pipelines/embedding.rb
@@ -119,6 +128,7 @@ files:
 - lib/transformers/pipelines/image_feature_extraction.rb
 - lib/transformers/pipelines/pt_utils.rb
 - lib/transformers/pipelines/question_answering.rb
+- lib/transformers/pipelines/reranking.rb
 - lib/transformers/pipelines/text_classification.rb
 - lib/transformers/pipelines/token_classification.rb
 - lib/transformers/ruby_utils.rb
@@ -155,7 +165,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
   - !ruby/object:Gem::Version
     version: '0'
 requirements: []
-rubygems_version: 3.5.
+rubygems_version: 3.5.16
 signing_key:
 specification_version: 4
 summary: State-of-the-art transformers for Ruby