transformers-rb 0.1.1 → 0.1.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +11 -0
- data/README.md +64 -3
- data/lib/transformers/configuration_utils.rb +32 -4
- data/lib/transformers/modeling_utils.rb +10 -3
- data/lib/transformers/models/auto/auto_factory.rb +1 -1
- data/lib/transformers/models/auto/configuration_auto.rb +5 -2
- data/lib/transformers/models/auto/modeling_auto.rb +9 -3
- data/lib/transformers/models/auto/tokenization_auto.rb +5 -2
- data/lib/transformers/models/deberta_v2/configuration_deberta_v2.rb +80 -0
- data/lib/transformers/models/deberta_v2/modeling_deberta_v2.rb +1210 -0
- data/lib/transformers/models/deberta_v2/tokenization_deberta_v2_fast.rb +78 -0
- data/lib/transformers/models/mpnet/configuration_mpnet.rb +61 -0
- data/lib/transformers/models/mpnet/modeling_mpnet.rb +792 -0
- data/lib/transformers/models/mpnet/tokenization_mpnet_fast.rb +106 -0
- data/lib/transformers/models/xlm_roberta/configuration_xlm_roberta.rb +68 -0
- data/lib/transformers/models/xlm_roberta/modeling_xlm_roberta.rb +1216 -0
- data/lib/transformers/models/xlm_roberta/tokenization_xlm_roberta_fast.rb +68 -0
- data/lib/transformers/pipelines/_init.rb +16 -5
- data/lib/transformers/pipelines/reranking.rb +33 -0
- data/lib/transformers/version.rb +1 -1
- data/lib/transformers.rb +16 -0
- metadata +15 -5
    
checksums.yaml CHANGED

@@ -1,7 +1,7 @@
 ---
 SHA256:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: 3f070b9828c5c5ad71c75f46ca9daf1387a5ec3848cb406aac9e5f1bbc1d4531
+  data.tar.gz: 31b28a5a87c58db6fc3146e390e8a4a7bf1ffc34ede6d3cd6fcd7f3aa3df2d28
 SHA512:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: aa2055e44b9071a425ebfb59d6b2edbedce1f3cf97e0baa55d1280451c1c1db097a52b0b9615a188b1d96f0854e557fb4cb769b05cb3af4db229cd3fcdf8fb95
+  data.tar.gz: 1af002f238e9189a2e2a6b5f1aafc9201cfd5bc5f8afe4a80b81757b5d9f5d4fa52bc61a57b4fdd6920bd3692f704398aa58c7cc4fd797bd881ab9887c9c77f9
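For reference, the new digests can be cross-checked locally with Ruby's standard library. A minimal sketch, assuming the usual RubyGems package layout and a downloaded `transformers-rb-0.1.3.gem` (the filename is an assumption, not part of this diff):

```ruby
require "digest"
require "rubygems/package"

# A .gem file is a tar archive whose entries include data.tar.gz and metadata.gz;
# checksums.yaml records their digests. SHA256 of data.tar.gz for 0.1.3 as listed above.
expected = "31b28a5a87c58db6fc3146e390e8a4a7bf1ffc34ede6d3cd6fcd7f3aa3df2d28"

File.open("transformers-rb-0.1.3.gem", "rb") do |io|
  Gem::Package::TarReader.new(io).each do |entry|
    next unless entry.full_name == "data.tar.gz"
    actual = Digest::SHA256.hexdigest(entry.read)
    puts(actual == expected ? "data.tar.gz matches checksums.yaml" : "checksum mismatch")
  end
end
```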
    
data/CHANGELOG.md CHANGED

data/README.md CHANGED

@@ -32,8 +32,17 @@ Embedding
 - [intfloat/e5-base-v2](#intfloate5-base-v2)
 - [BAAI/bge-base-en-v1.5](#baaibge-base-en-v15)
 - [Snowflake/snowflake-arctic-embed-m-v1.5](#snowflakesnowflake-arctic-embed-m-v15)
+- [sentence-transformers/all-mpnet-base-v2](#sentence-transformersall-mpnet-base-v2)
+
+Sparse embedding
+
 - [opensearch-project/opensearch-neural-sparse-encoding-v1](#opensearch-projectopensearch-neural-sparse-encoding-v1)
 
+Reranking
+
+- [mixedbread-ai/mxbai-rerank-base-v1](#mixedbread-aimxbai-rerank-base-v1)
+- [BAAI/bge-reranker-base](#baaibge-reranker-base)
+
 ### sentence-transformers/all-MiniLM-L6-v2
 
 [Docs](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2)
@@ -139,6 +148,17 @@ model = Transformers.pipeline("embedding", "Snowflake/snowflake-arctic-embed-m-v
 embeddings = model.(input, pooling: "cls")
 ```
 
+### sentence-transformers/all-mpnet-base-v2
+
+[Docs](https://huggingface.co/sentence-transformers/all-mpnet-base-v2)
+
+```ruby
+sentences = ["This is an example sentence", "Each sentence is converted"]
+
+model = Transformers.pipeline("embedding", "sentence-transformers/all-mpnet-base-v2")
+embeddings = model.(sentences)
+```
+
 ### opensearch-project/opensearch-neural-sparse-encoding-v1
 
 [Docs](https://huggingface.co/opensearch-project/opensearch-neural-sparse-encoding-v1)
@@ -160,8 +180,37 @@ values[0.., special_token_ids] = 0
 embeddings = values.to_a
 ```
 
+### mixedbread-ai/mxbai-rerank-base-v1
+
+[Docs](https://huggingface.co/mixedbread-ai/mxbai-rerank-base-v1)
+
+```ruby
+query = "How many people live in London?"
+docs = ["Around 9 Million people live in London", "London is known for its financial district"]
+
+model = Transformers.pipeline("reranking", "mixedbread-ai/mxbai-rerank-base-v1")
+result = model.(query, docs)
+```
+
+### BAAI/bge-reranker-base
+
+[Docs](https://huggingface.co/BAAI/bge-reranker-base)
+
+```ruby
+query = "How many people live in London?"
+docs = ["Around 9 Million people live in London", "London is known for its financial district"]
+
+model = Transformers.pipeline("reranking", "BAAI/bge-reranker-base")
+result = model.(query, docs)
+```
+
 ## Pipelines
 
+- [Text](#text)
+- [Vision](#vision)
+
+### Text
+
 Embedding
 
 ```ruby
@@ -169,6 +218,13 @@ embed = Transformers.pipeline("embedding")
 embed.("We are very happy to show you the 🤗 Transformers library.")
 ```
 
+Reranking
+
+```ruby
+rerank = Informers.pipeline("reranking")
+rerank.("Who created Ruby?", ["Matz created Ruby", "Another doc"])
+```
+
 Named-entity recognition
 
 ```ruby
@@ -197,27 +253,32 @@ extractor = Transformers.pipeline("feature-extraction")
 extractor.("We are very happy to show you the 🤗 Transformers library.")
 ```
 
+### Vision
+
 Image classification
 
 ```ruby
 classifier = Transformers.pipeline("image-classification")
-classifier.(
+classifier.("image.jpg")
 ```
 
 Image feature extraction
 
 ```ruby
 extractor = Transformers.pipeline("image-feature-extraction")
-extractor.(
+extractor.("image.jpg")
 ```
 
 ## API
 
-This library follows the [Transformers Python API](https://huggingface.co/docs/transformers/index).
+This library follows the [Transformers Python API](https://huggingface.co/docs/transformers/index). The following model architectures are currently supported:
 
 - BERT
+- DeBERTa-v2
 - DistilBERT
+- MPNet
 - ViT
+- XLM-RoBERTa
 
 ## History
 
data/lib/transformers/configuration_utils.rb CHANGED

@@ -91,10 +91,24 @@ module Transformers
       # Config hash
       @commit_hash = kwargs.delete(:_commit_hash)
 
-      #
-      @
-
-
+      # Attention implementation to use, if relevant.
+      @attn_implementation_internal = kwargs.delete(:attn_implementation)
+
+      # Drop the transformers version info
+      @transformers_version = kwargs.delete(:transformers_version)
+
+      # Deal with gradient checkpointing
+      # if kwargs[:gradient_checkpointing] == false
+      #   warn(
+      #     "Passing `gradient_checkpointing` to a config initialization is deprecated and will be removed in v5 " +
+      #     "Transformers. Using `model.gradient_checkpointing_enable()` instead, or if you are using the " +
+      #     "`Trainer` API, pass `gradient_checkpointing: true` in your `TrainingArguments`."
+      #   )
+      # end
+
+      kwargs.each do |k, v|
+        instance_variable_set("@#{k}", v)
+      end
     end
 
     def name_or_path
@@ -182,6 +196,20 @@ module Transformers
       JSON.pretty_generate(config_dict.sort_by { |k, _| k }.to_h) + "\n"
     end
 
+    def getattr(key, default)
+      if respond_to?(key)
+        public_send(key)
+      elsif instance_variable_defined?("@#{key}")
+        instance_variable_get("@#{key}")
+      else
+        default
+      end
+    end
+
+    def hasattr(key)
+      respond_to?(key) || instance_variable_defined?("@#{key}")
+    end
+
     class << self
       def from_pretrained(
         pretrained_model_name_or_path,
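The new `getattr`/`hasattr` helpers mirror Python's `getattr`/`hasattr` on config objects: a reader method wins, otherwise a bare instance variable (which is how unknown kwargs are now stored by the constructor change above), otherwise the default. A minimal illustration using the `DebertaV2Config` class added later in this diff; the `custom_setting` kwarg is hypothetical and the behavior is inferred from the code shown here:

```ruby
require "transformers"

# Unknown keyword arguments are stored as bare instance variables by the
# constructor catch-all above, so hasattr/getattr can see them.
config = Transformers::DebertaV2::DebertaV2Config.new(custom_setting: 123)

config.hasattr(:hidden_size)        # => true  (attr_reader)
config.getattr(:hidden_size, nil)   # => 1536  (the default)
config.hasattr(:custom_setting)     # => true  (stored as @custom_setting)
config.getattr(:custom_setting, 0)  # => 123
config.getattr(:missing_key, 0)     # => 0     (falls back to the default)
```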
data/lib/transformers/modeling_utils.rb CHANGED

@@ -207,7 +207,7 @@ module Transformers
 
     def init_weights
       # Prune heads if needed
-      if @config.pruned_heads
+      if @config.pruned_heads.any?
         prune_heads(@config.pruned_heads)
       end
 
@@ -803,11 +803,18 @@ module Transformers
             raise Todo
           end
 
+          model_class_name = model.class.name.split("::").last
+
           if error_msgs.length > 0
-
+            error_msg = error_msgs.join("\n\t")
+            if error_msg.include?("size mismatch")
+              error_msg += (
+                "\n\tYou may consider adding `ignore_mismatched_sizes: true` in the model `from_pretrained` method."
+              )
+            end
+            raise RuntimeError, "Error(s) in loading state_dict for #{model_class_name}:\n\t#{error_msg}"
           end
 
-          model_class_name = model.class.name.split("::").last
           if unexpected_keys.length > 0
             archs = model.config.architectures.nil? ? [] : model.config.architectures
             warner = archs.include?(model_class_name) ? Transformers.logger.method(:warn) : Transformers.logger.method(:info)
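The rebuilt error path now names the failing class and, on size mismatches, points at the `ignore_mismatched_sizes: true` option of `from_pretrained`. A hypothetical way to act on that hint; the `AutoModelForSequenceClassification` constant and the placeholder model id are assumptions based on the Python API this library mirrors, not something shown in this diff:

```ruby
require "transformers"

# Hypothetical recovery from a "size mismatch" error raised by the code above.
# ignore_mismatched_sizes is the option named in the new error message.
model = Transformers::AutoModelForSequenceClassification.from_pretrained(
  "example/some-checkpoint",     # placeholder model id
  ignore_mismatched_sizes: true
)
```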
data/lib/transformers/models/auto/auto_factory.rb CHANGED

@@ -116,7 +116,7 @@ module Transformers
     def _load_attr_from_module(model_type, attr)
       module_name = model_type_to_module_name(model_type)
       if !@modules.include?(module_name)
-        @modules[module_name] = Transformers.const_get(module_name.capitalize)
+        @modules[module_name] = Transformers.const_get(module_name.split("-").map(&:capitalize).join)
       end
       getattribute_from_module(@modules[module_name], attr)
     end
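The repeated `module_name.split("-").map(&:capitalize).join` change, here and in the auto config/tokenizer files below, exists because the new model types contain hyphens, and `String#capitalize` alone would produce an invalid constant name such as `Deberta-v2`. A small sketch of the naming rule; the helper name is hypothetical, and only `DebertaV2` is confirmed as a module name by this diff (the others follow the same convention):

```ruby
# Hypothetical helper (not part of the gem) showing the constant-name rule now
# used when resolving a model type to a module under Transformers.
def module_const_name(model_type)
  model_type.split("-").map(&:capitalize).join
end

module_const_name("bert")        # => "Bert"
module_const_name("deberta-v2")  # => "DebertaV2"
module_const_name("xlm-roberta") # => "XlmRoberta"

# Previously: "deberta-v2".capitalize # => "Deberta-v2", which is not a valid
# constant name, so Transformers.const_get raised NameError for hyphenated types.
```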
data/lib/transformers/models/auto/configuration_auto.rb CHANGED

@@ -15,8 +15,11 @@
 module Transformers
   CONFIG_MAPPING_NAMES = {
     "bert" => "BertConfig",
+    "deberta-v2" => "DebertaV2Config",
     "distilbert" => "DistilBertConfig",
-    "
+    "mpnet" => "MPNetConfig",
+    "vit" => "ViTConfig",
+    "xlm-roberta" => "XLMRobertaConfig"
   }
 
   class LazyConfigMapping
@@ -30,7 +33,7 @@ module Transformers
       value = @mapping.fetch(key)
       module_name = model_type_to_module_name(key)
       if !@modules.include?(module_name)
-        @modules[module_name] = Transformers.const_get(module_name.capitalize)
+        @modules[module_name] = Transformers.const_get(module_name.split("-").map(&:capitalize).join)
       end
       @modules[module_name].const_get(value)
     end
data/lib/transformers/models/auto/modeling_auto.rb CHANGED

@@ -15,16 +15,22 @@
 module Transformers
   MODEL_MAPPING_NAMES = {
     "bert" => "BertModel",
+    "deberta-v2" => "DebertaV2Model",
     "distilbert" => "DistilBertModel",
-    "
+    "mpnet" => "MPNetModel",
+    "vit" => "ViTModel",
+    "xlm-roberta" => "XLMRobertaModel"
   }
 
   MODEL_FOR_MASKED_LM_MAPPING_NAMES = {
-    "bert" => "BertForMaskedLM"
+    "bert" => "BertForMaskedLM",
+    "mpnet" => "MPNetForMaskedLM"
   }
 
   MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = {
-    "
+    "deberta-v2" => "DebertaV2ForSequenceClassification",
+    "distilbert" => "DistilBertForSequenceClassification",
+    "xlm-roberta" => "XLMRobertaForSequenceClassification"
   }
 
   MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = {
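These mapping constants are plain hashes keyed by `model_type`, which the auto classes consult when resolving a checkpoint to a concrete class. The new entries can be inspected directly; the values below are simply the ones added in this hunk:

```ruby
require "transformers"

# The constants live directly under the Transformers module.
Transformers::MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES["deberta-v2"]
# => "DebertaV2ForSequenceClassification"
Transformers::MODEL_FOR_MASKED_LM_MAPPING_NAMES["mpnet"]
# => "MPNetForMaskedLM"
Transformers::MODEL_MAPPING_NAMES["xlm-roberta"]
# => "XLMRobertaModel"
```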
data/lib/transformers/models/auto/tokenization_auto.rb CHANGED

@@ -15,7 +15,10 @@
 module Transformers
   TOKENIZER_MAPPING_NAMES = {
     "bert" => ["BertTokenizer", "BertTokenizerFast"],
-    "
+    "deberta-v2" => ["DebertaV2TokenizerFast"],
+    "distilbert" => ["DistilBertTokenizer", "DistilBertTokenizerFast"],
+    "mpnet" => ["MPNetTokenizerFast"],
+    "xlm-roberta" => ["XLMRobertaTokenizerFast"]
   }
 
   TOKENIZER_MAPPING = LazyAutoMapping.new(CONFIG_MAPPING_NAMES, TOKENIZER_MAPPING_NAMES)
@@ -98,7 +101,7 @@ module Transformers
 
         TOKENIZER_MAPPING_NAMES.each do |module_name, tokenizers|
           if tokenizers.include?(class_name)
-            cls = Transformers.const_get(module_name.capitalize).const_get(class_name)
+            cls = Transformers.const_get(module_name.split("-").map(&:capitalize).join).const_get(class_name)
             raise Error, "Invalid tokenizer class: #{class_name}" unless cls < PreTrainedTokenizer || cls < PreTrainedTokenizerFast
             return cls
           end
data/lib/transformers/models/deberta_v2/configuration_deberta_v2.rb ADDED

@@ -0,0 +1,80 @@
+# Copyright 2020, Microsoft and the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+module Transformers
+  module DebertaV2
+    class DebertaV2Config < PretrainedConfig
+      self.model_type = "deberta-v2"
+
+      attr_reader :vocab_size, :hidden_size, :num_hidden_layers, :num_attention_heads,
+        :intermediate_size, :hidden_act, :hidden_dropout_prob, :attention_probs_dropout_prob,
+        :max_position_embeddings, :type_vocab_size, :initializer_range, :layer_norm_eps,
+        :relative_attention, :max_relative_positions, :pad_token_id, :position_biased_input,
+        :pos_att_type, :pooler_dropout, :pooler_hidden_act, :pooler_hidden_size
+
+      def initialize(
+        vocab_size: 128100,
+        hidden_size: 1536,
+        num_hidden_layers: 24,
+        num_attention_heads: 24,
+        intermediate_size: 6144,
+        hidden_act: "gelu",
+        hidden_dropout_prob: 0.1,
+        attention_probs_dropout_prob: 0.1,
+        max_position_embeddings: 512,
+        type_vocab_size: 0,
+        initializer_range: 0.02,
+        layer_norm_eps: 1e-07,
+        relative_attention: false,
+        max_relative_positions: -1,
+        pad_token_id: 0,
+        position_biased_input: true,
+        pos_att_type: nil,
+        pooler_dropout: 0,
+        pooler_hidden_act: "gelu",
+        **kwargs
+      )
+        super(**kwargs)
+
+        @hidden_size = hidden_size
+        @num_hidden_layers = num_hidden_layers
+        @num_attention_heads = num_attention_heads
+        @intermediate_size = intermediate_size
+        @hidden_act = hidden_act
+        @hidden_dropout_prob = hidden_dropout_prob
+        @attention_probs_dropout_prob = attention_probs_dropout_prob
+        @max_position_embeddings = max_position_embeddings
+        @type_vocab_size = type_vocab_size
+        @initializer_range = initializer_range
+        @relative_attention = relative_attention
+        @max_relative_positions = max_relative_positions
+        @pad_token_id = pad_token_id
+        @position_biased_input = position_biased_input
+
+        # Backwards compatibility
+        if pos_att_type.is_a?(String)
+          pos_att_type = pos_att_type.downcase.split("|").map { |x| x.strip }
+        end
+
+        @pos_att_type = pos_att_type
+        @vocab_size = vocab_size
+        @layer_norm_eps = layer_norm_eps
+
+        @pooler_hidden_size = kwargs[:pooler_hidden_size] || hidden_size
+        @pooler_dropout = pooler_dropout
+        @pooler_hidden_act = pooler_hidden_act
+      end
+    end
+  end
+end
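A short usage sketch for the new config class; the values are just the constructor defaults shown above, and overriding follows from the keyword arguments (loading a real checkpoint would normally go through the auto classes instead):

```ruby
require "transformers"

config = Transformers::DebertaV2::DebertaV2Config.new
config.hidden_size          # => 1536
config.num_attention_heads  # => 24
config.relative_attention   # => false

# Keyword arguments override the defaults, as in the Python config class.
small = Transformers::DebertaV2::DebertaV2Config.new(hidden_size: 768, num_hidden_layers: 12)
small.hidden_size           # => 768
small.num_hidden_layers     # => 12
```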