transformers-rb 0.1.1 → 0.1.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
  ---
  SHA256:
- metadata.gz: 2b7441c0cc400d85d3d0775e6c53a631c83887e481588148a7cb3242e50c0f08
- data.tar.gz: df639eb5c3231d344ec8516ec861cb40239ac5ff0310faa5b1b102a79af97972
+ metadata.gz: 3f070b9828c5c5ad71c75f46ca9daf1387a5ec3848cb406aac9e5f1bbc1d4531
+ data.tar.gz: 31b28a5a87c58db6fc3146e390e8a4a7bf1ffc34ede6d3cd6fcd7f3aa3df2d28
  SHA512:
- metadata.gz: 566a03e8a65255d3eb7e87bb3e68c0ac4d228ba4bc55e62eaad98f6677b8a5f29ee9201efaefc3edb309d279c1a986a664dfffbf3cc292114394be1c10fa949a
- data.tar.gz: fdae84f4fa8aac90f1a627d738fb701ef91d384414173f4042bac56bfe5c78c6ceb98aaea8ef0f2158d324ec60d18cbd91f84bd30f42b49d63193392e0105e96
+ metadata.gz: aa2055e44b9071a425ebfb59d6b2edbedce1f3cf97e0baa55d1280451c1c1db097a52b0b9615a188b1d96f0854e557fb4cb769b05cb3af4db229cd3fcdf8fb95
+ data.tar.gz: 1af002f238e9189a2e2a6b5f1aafc9201cfd5bc5f8afe4a80b81757b5d9f5d4fa52bc61a57b4fdd6920bd3692f704398aa58c7cc4fd797bd881ab9887c9c77f9
data/CHANGELOG.md CHANGED
@@ -1,3 +1,14 @@
+ ## 0.1.3 (2024-09-17)
+
+ - Added `reranking` pipeline
+ - Added DeBERTa-v2
+ - Added MPNet
+ - Added XLM-RoBERTa
+
+ ## 0.1.2 (2024-09-10)
+
+ - Fixed default revision for pipelines
+
  ## 0.1.1 (2024-08-29)

  - Added `embedding` pipeline
data/README.md CHANGED
@@ -32,8 +32,17 @@ Embedding
  - [intfloat/e5-base-v2](#intfloate5-base-v2)
  - [BAAI/bge-base-en-v1.5](#baaibge-base-en-v15)
  - [Snowflake/snowflake-arctic-embed-m-v1.5](#snowflakesnowflake-arctic-embed-m-v15)
+ - [sentence-transformers/all-mpnet-base-v2](#sentence-transformersall-mpnet-base-v2)
+
+ Sparse embedding
+
  - [opensearch-project/opensearch-neural-sparse-encoding-v1](#opensearch-projectopensearch-neural-sparse-encoding-v1)

+ Reranking
+
+ - [mixedbread-ai/mxbai-rerank-base-v1](#mixedbread-aimxbai-rerank-base-v1)
+ - [BAAI/bge-reranker-base](#baaibge-reranker-base)
+
  ### sentence-transformers/all-MiniLM-L6-v2

  [Docs](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2)
@@ -139,6 +148,17 @@ model = Transformers.pipeline("embedding", "Snowflake/snowflake-arctic-embed-m-v
  embeddings = model.(input, pooling: "cls")
  ```

+ ### sentence-transformers/all-mpnet-base-v2
+
+ [Docs](https://huggingface.co/sentence-transformers/all-mpnet-base-v2)
+
+ ```ruby
+ sentences = ["This is an example sentence", "Each sentence is converted"]
+
+ model = Transformers.pipeline("embedding", "sentence-transformers/all-mpnet-base-v2")
+ embeddings = model.(sentences)
+ ```
+
  ### opensearch-project/opensearch-neural-sparse-encoding-v1

  [Docs](https://huggingface.co/opensearch-project/opensearch-neural-sparse-encoding-v1)
@@ -160,8 +180,37 @@ values[0.., special_token_ids] = 0
  embeddings = values.to_a
  ```

+ ### mixedbread-ai/mxbai-rerank-base-v1
+
+ [Docs](https://huggingface.co/mixedbread-ai/mxbai-rerank-base-v1)
+
+ ```ruby
+ query = "How many people live in London?"
+ docs = ["Around 9 Million people live in London", "London is known for its financial district"]
+
+ model = Transformers.pipeline("reranking", "mixedbread-ai/mxbai-rerank-base-v1")
+ result = model.(query, docs)
+ ```
+
+ ### BAAI/bge-reranker-base
+
+ [Docs](https://huggingface.co/BAAI/bge-reranker-base)
+
+ ```ruby
+ query = "How many people live in London?"
+ docs = ["Around 9 Million people live in London", "London is known for its financial district"]
+
+ model = Transformers.pipeline("reranking", "BAAI/bge-reranker-base")
+ result = model.(query, docs)
+ ```
+
  ## Pipelines

+ - [Text](#text)
+ - [Vision](#vision)
+
+ ### Text
+
  Embedding

  ```ruby
@@ -169,6 +218,13 @@ embed = Transformers.pipeline("embedding")
  embed.("We are very happy to show you the 🤗 Transformers library.")
  ```

+ Reranking
+
+ ```ruby
+ rerank = Transformers.pipeline("reranking")
+ rerank.("Who created Ruby?", ["Matz created Ruby", "Another doc"])
+ ```
+
  Named-entity recognition

  ```ruby
@@ -197,27 +253,32 @@ extractor = Transformers.pipeline("feature-extraction")
  extractor.("We are very happy to show you the 🤗 Transformers library.")
  ```

+ ### Vision
+
  Image classification

  ```ruby
  classifier = Transformers.pipeline("image-classification")
- classifier.(URI("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"))
+ classifier.("image.jpg")
  ```

  Image feature extraction

  ```ruby
  extractor = Transformers.pipeline("image-feature-extraction")
- extractor.(URI("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"))
+ extractor.("image.jpg")
  ```

  ## API

- This library follows the [Transformers Python API](https://huggingface.co/docs/transformers/index). Only a few model architectures are currently supported:
+ This library follows the [Transformers Python API](https://huggingface.co/docs/transformers/index). The following model architectures are currently supported:

  - BERT
+ - DeBERTa-v2
  - DistilBERT
+ - MPNet
  - ViT
+ - XLM-RoBERTa

  ## History
@@ -91,10 +91,24 @@ module Transformers
  # Config hash
  @commit_hash = kwargs.delete(:_commit_hash)

- # TODO set kwargs
- @gradient_checkpointing = kwargs[:gradient_checkpointing]
- @output_past = kwargs[:output_past]
- @tie_weights_ = kwargs[:tie_weights_]
+ # Attention implementation to use, if relevant.
+ @attn_implementation_internal = kwargs.delete(:attn_implementation)
+
+ # Drop the transformers version info
+ @transformers_version = kwargs.delete(:transformers_version)
+
+ # Deal with gradient checkpointing
+ # if kwargs[:gradient_checkpointing] == false
+ # warn(
+ # "Passing `gradient_checkpointing` to a config initialization is deprecated and will be removed in v5 " +
+ # "Transformers. Using `model.gradient_checkpointing_enable()` instead, or if you are using the " +
+ # "`Trainer` API, pass `gradient_checkpointing: true` in your `TrainingArguments`."
+ # )
+ # end
+
+ kwargs.each do |k, v|
+ instance_variable_set("@#{k}", v)
+ end
  end

  def name_or_path
@@ -182,6 +196,20 @@ module Transformers
  JSON.pretty_generate(config_dict.sort_by { |k, _| k }.to_h) + "\n"
  end

+ def getattr(key, default)
+ if respond_to?(key)
+ public_send(key)
+ elsif instance_variable_defined?("@#{key}")
+ instance_variable_get("@#{key}")
+ else
+ default
+ end
+ end
+
+ def hasattr(key)
+ respond_to?(key) || instance_variable_defined?("@#{key}")
+ end
+
  class << self
  def from_pretrained(
  pretrained_model_name_or_path,
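A note on the `PretrainedConfig` changes above: keyword arguments that are not consumed during initialization are now stored as instance variables, and the new `getattr`/`hasattr` helpers mirror Python-style attribute lookup. A rough sketch of how this could behave, assuming a concrete config class such as `BertConfig` forwards unknown keywords to `PretrainedConfig` via `super(**kwargs)` (`custom_option` is a made-up key for illustration):

```ruby
# Hypothetical illustration only; "custom_option" is not a real config field.
config = Transformers::BertConfig.new(custom_option: 123)

config.hasattr(:custom_option)       # => true (stored via instance_variable_set)
config.getattr(:custom_option, nil)  # => 123 (read back via instance_variable_get)
config.getattr(:missing_option, 42)  # => 42 (falls back to the supplied default)
config.getattr(:hidden_size, nil)    # resolved through respond_to? + public_send
```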
@@ -207,7 +207,7 @@ module Transformers

  def init_weights
  # Prune heads if needed
- if @config.pruned_heads
+ if @config.pruned_heads.any?
  prune_heads(@config.pruned_heads)
  end
@@ -803,11 +803,18 @@ module Transformers
  raise Todo
  end

+ model_class_name = model.class.name.split("::").last
+
  if error_msgs.length > 0
- raise Todo
+ error_msg = error_msgs.join("\n\t")
+ if error_msg.include?("size mismatch")
+ error_msg += (
+ "\n\tYou may consider adding `ignore_mismatched_sizes: true` in the model `from_pretrained` method."
+ )
+ end
+ raise RuntimeError, "Error(s) in loading state_dict for #{model_class_name}:\n\t#{error_msg}"
  end

- model_class_name = model.class.name.split("::").last
  if unexpected_keys.length > 0
  archs = model.config.architectures.nil? ? [] : model.config.architectures
  warner = archs.include?(model_class_name) ? Transformers.logger.method(:warn) : Transformers.logger.method(:info)
@@ -116,7 +116,7 @@ module Transformers
  def _load_attr_from_module(model_type, attr)
  module_name = model_type_to_module_name(model_type)
  if !@modules.include?(module_name)
- @modules[module_name] = Transformers.const_get(module_name.capitalize)
+ @modules[module_name] = Transformers.const_get(module_name.split("-").map(&:capitalize).join)
  end
  getattribute_from_module(@modules[module_name], attr)
  end
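The `capitalize` → `split("-").map(&:capitalize).join` change here (repeated in the auto config and tokenizer mappings below) is what lets hyphenated model types such as `deberta-v2` and `xlm-roberta` resolve to a module constant; `String#capitalize` alone cannot produce a valid constant name:

```ruby
# Old derivation: contains a hyphen, so Transformers.const_get would raise NameError.
"deberta-v2".capitalize                          # => "Deberta-v2"

# New derivation: camel-cases each hyphen-separated part.
"deberta-v2".split("-").map(&:capitalize).join   # => "DebertaV2"
"xlm-roberta".split("-").map(&:capitalize).join  # => "XlmRoberta"
```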
@@ -15,8 +15,11 @@
  module Transformers
  CONFIG_MAPPING_NAMES = {
  "bert" => "BertConfig",
+ "deberta-v2" => "DebertaV2Config",
  "distilbert" => "DistilBertConfig",
- "vit" => "ViTConfig"
+ "mpnet" => "MPNetConfig",
+ "vit" => "ViTConfig",
+ "xlm-roberta" => "XLMRobertaConfig"
  }

  class LazyConfigMapping
@@ -30,7 +33,7 @@ module Transformers
  value = @mapping.fetch(key)
  module_name = model_type_to_module_name(key)
  if !@modules.include?(module_name)
- @modules[module_name] = Transformers.const_get(module_name.capitalize)
+ @modules[module_name] = Transformers.const_get(module_name.split("-").map(&:capitalize).join)
  end
  @modules[module_name].const_get(value)
  end
@@ -15,16 +15,22 @@
  module Transformers
  MODEL_MAPPING_NAMES = {
  "bert" => "BertModel",
+ "deberta-v2" => "DebertaV2Model",
  "distilbert" => "DistilBertModel",
- "vit" => "ViTModel"
+ "mpnet" => "MPNetModel",
+ "vit" => "ViTModel",
+ "xlm-roberta" => "XLMRobertaModel"
  }

  MODEL_FOR_MASKED_LM_MAPPING_NAMES = {
- "bert" => "BertForMaskedLM"
+ "bert" => "BertForMaskedLM",
+ "mpnet" => "MPNetForMaskedLM"
  }

  MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = {
- "distilbert" => "DistilBertForSequenceClassification"
+ "deberta-v2" => "DebertaV2ForSequenceClassification",
+ "distilbert" => "DistilBertForSequenceClassification",
+ "xlm-roberta" => "XLMRobertaForSequenceClassification"
  }

  MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = {
@@ -15,7 +15,10 @@
  module Transformers
  TOKENIZER_MAPPING_NAMES = {
  "bert" => ["BertTokenizer", "BertTokenizerFast"],
- "distilbert" => ["DistilBertTokenizer", "DistilBertTokenizerFast"]
+ "deberta-v2" => ["DebertaV2TokenizerFast"],
+ "distilbert" => ["DistilBertTokenizer", "DistilBertTokenizerFast"],
+ "mpnet" => ["MPNetTokenizerFast"],
+ "xlm-roberta" => ["XLMRobertaTokenizerFast"]
  }

  TOKENIZER_MAPPING = LazyAutoMapping.new(CONFIG_MAPPING_NAMES, TOKENIZER_MAPPING_NAMES)
@@ -98,7 +101,7 @@ module Transformers

  TOKENIZER_MAPPING_NAMES.each do |module_name, tokenizers|
  if tokenizers.include?(class_name)
- cls = Transformers.const_get(module_name.capitalize).const_get(class_name)
+ cls = Transformers.const_get(module_name.split("-").map(&:capitalize).join).const_get(class_name)
  raise Error, "Invalid tokenizer class: #{class_name}" unless cls < PreTrainedTokenizer || cls < PreTrainedTokenizerFast
  return cls
  end
@@ -0,0 +1,80 @@
+ # Copyright 2020, Microsoft and the HuggingFace Inc. team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ module Transformers
+ module DebertaV2
+ class DebertaV2Config < PretrainedConfig
+ self.model_type = "deberta-v2"
+
+ attr_reader :vocab_size, :hidden_size, :num_hidden_layers, :num_attention_heads,
+ :intermediate_size, :hidden_act, :hidden_dropout_prob, :attention_probs_dropout_prob,
+ :max_position_embeddings, :type_vocab_size, :initializer_range, :layer_norm_eps,
+ :relative_attention, :max_relative_positions, :pad_token_id, :position_biased_input,
+ :pos_att_type, :pooler_dropout, :pooler_hidden_act, :pooler_hidden_size
+
+ def initialize(
+ vocab_size: 128100,
+ hidden_size: 1536,
+ num_hidden_layers: 24,
+ num_attention_heads: 24,
+ intermediate_size: 6144,
+ hidden_act: "gelu",
+ hidden_dropout_prob: 0.1,
+ attention_probs_dropout_prob: 0.1,
+ max_position_embeddings: 512,
+ type_vocab_size: 0,
+ initializer_range: 0.02,
+ layer_norm_eps: 1e-07,
+ relative_attention: false,
+ max_relative_positions: -1,
+ pad_token_id: 0,
+ position_biased_input: true,
+ pos_att_type: nil,
+ pooler_dropout: 0,
+ pooler_hidden_act: "gelu",
+ **kwargs
+ )
+ super(**kwargs)
+
+ @hidden_size = hidden_size
+ @num_hidden_layers = num_hidden_layers
+ @num_attention_heads = num_attention_heads
+ @intermediate_size = intermediate_size
+ @hidden_act = hidden_act
+ @hidden_dropout_prob = hidden_dropout_prob
+ @attention_probs_dropout_prob = attention_probs_dropout_prob
+ @max_position_embeddings = max_position_embeddings
+ @type_vocab_size = type_vocab_size
+ @initializer_range = initializer_range
+ @relative_attention = relative_attention
+ @max_relative_positions = max_relative_positions
+ @pad_token_id = pad_token_id
+ @position_biased_input = position_biased_input
+
+ # Backwards compatibility
+ if pos_att_type.is_a?(String)
+ pos_att_type = pos_att_type.downcase.split("|").map { |x| x.strip }
+ end
+
+ @pos_att_type = pos_att_type
+ @vocab_size = vocab_size
+ @layer_norm_eps = layer_norm_eps
+
+ @pooler_hidden_size = kwargs[:pooler_hidden_size] || hidden_size
+ @pooler_dropout = pooler_dropout
+ @pooler_hidden_act = pooler_hidden_act
+ end
+ end
+ end
+ end
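Finally, a minimal sketch of instantiating the new `DebertaV2Config` directly; in practice it would usually be built through the auto classes or `from_pretrained`, and the overridden values below are arbitrary:

```ruby
# Override a couple of fields; everything else keeps the defaults defined above.
config = Transformers::DebertaV2::DebertaV2Config.new(
  hidden_size: 768,
  num_attention_heads: 12
)

config.hidden_size          # => 768
config.num_attention_heads  # => 12
config.vocab_size           # => 128100 (default)
config.relative_attention   # => false (default)
```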