transformers-rb 0.1.2 → 0.1.3

checksums.yaml CHANGED
@@ -1,7 +1,7 @@
  ---
  SHA256:
- metadata.gz: 3f29055705824ba101cba238960d4f10825c75bc7867b9eb0b611cda6a547612
- data.tar.gz: d0967f7742f7b2d6194376eb040a3be81e77a9ded94302aeb934de678959e434
+ metadata.gz: 3f070b9828c5c5ad71c75f46ca9daf1387a5ec3848cb406aac9e5f1bbc1d4531
+ data.tar.gz: 31b28a5a87c58db6fc3146e390e8a4a7bf1ffc34ede6d3cd6fcd7f3aa3df2d28
  SHA512:
- metadata.gz: 38b9ed4fd654ca593e3d6e7c7f20eb3c6b68ecfa5f86099fbc8d160f9093617cc79a571c331a0ec0c70a6770c8d9460194ba75e61d534f9e15931f22e5ae60c3
- data.tar.gz: 00ce437ce8fe419fafddd59b7f9f61050d2ddf5817816b53beb1c43badbced9fe77ea28b26f2d607e87a9be855a19c4b254bc50df2881b4126f2def5c6875c3d
+ metadata.gz: aa2055e44b9071a425ebfb59d6b2edbedce1f3cf97e0baa55d1280451c1c1db097a52b0b9615a188b1d96f0854e557fb4cb769b05cb3af4db229cd3fcdf8fb95
+ data.tar.gz: 1af002f238e9189a2e2a6b5f1aafc9201cfd5bc5f8afe4a80b81757b5d9f5d4fa52bc61a57b4fdd6920bd3692f704398aa58c7cc4fd797bd881ab9887c9c77f9
data/CHANGELOG.md CHANGED
@@ -1,3 +1,10 @@
+ ## 0.1.3 (2024-09-17)
+
+ - Added `reranking` pipeline
+ - Added DeBERTa-v2
+ - Added MPNet
+ - Added XLM-RoBERTa
+
  ## 0.1.2 (2024-09-10)

  - Fixed default revision for pipelines
data/README.md CHANGED
@@ -32,11 +32,17 @@ Embedding
  - [intfloat/e5-base-v2](#intfloate5-base-v2)
  - [BAAI/bge-base-en-v1.5](#baaibge-base-en-v15)
  - [Snowflake/snowflake-arctic-embed-m-v1.5](#snowflakesnowflake-arctic-embed-m-v15)
+ - [sentence-transformers/all-mpnet-base-v2](#sentence-transformersall-mpnet-base-v2)

  Sparse embedding

  - [opensearch-project/opensearch-neural-sparse-encoding-v1](#opensearch-projectopensearch-neural-sparse-encoding-v1)

+ Reranking
+
+ - [mixedbread-ai/mxbai-rerank-base-v1](#mixedbread-aimxbai-rerank-base-v1)
+ - [BAAI/bge-reranker-base](#baaibge-reranker-base)
+
  ### sentence-transformers/all-MiniLM-L6-v2

  [Docs](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2)
@@ -142,6 +148,17 @@ model = Transformers.pipeline("embedding", "Snowflake/snowflake-arctic-embed-m-v
  embeddings = model.(input, pooling: "cls")
  ```

+ ### sentence-transformers/all-mpnet-base-v2
+
+ [Docs](https://huggingface.co/sentence-transformers/all-mpnet-base-v2)
+
+ ```ruby
+ sentences = ["This is an example sentence", "Each sentence is converted"]
+
+ model = Transformers.pipeline("embedding", "sentence-transformers/all-mpnet-base-v2")
+ embeddings = model.(sentences)
+ ```
+
  ### opensearch-project/opensearch-neural-sparse-encoding-v1

  [Docs](https://huggingface.co/opensearch-project/opensearch-neural-sparse-encoding-v1)
@@ -163,8 +180,37 @@ values[0.., special_token_ids] = 0
  embeddings = values.to_a
  ```

+ ### mixedbread-ai/mxbai-rerank-base-v1
+
+ [Docs](https://huggingface.co/mixedbread-ai/mxbai-rerank-base-v1)
+
+ ```ruby
+ query = "How many people live in London?"
+ docs = ["Around 9 Million people live in London", "London is known for its financial district"]
+
+ model = Transformers.pipeline("reranking", "mixedbread-ai/mxbai-rerank-base-v1")
+ result = model.(query, docs)
+ ```
+
+ ### BAAI/bge-reranker-base
+
+ [Docs](https://huggingface.co/BAAI/bge-reranker-base)
+
+ ```ruby
+ query = "How many people live in London?"
+ docs = ["Around 9 Million people live in London", "London is known for its financial district"]
+
+ model = Transformers.pipeline("reranking", "BAAI/bge-reranker-base")
+ result = model.(query, docs)
+ ```
+
  ## Pipelines

+ - [Text](#text)
+ - [Vision](#vision)
+
+ ### Text
+
  Embedding

  ```ruby
@@ -172,6 +218,13 @@ embed = Transformers.pipeline("embedding")
  embed.("We are very happy to show you the 🤗 Transformers library.")
  ```

+ Reranking
+
+ ```ruby
+ rerank = Transformers.pipeline("reranking")
+ rerank.("Who created Ruby?", ["Matz created Ruby", "Another doc"])
+ ```
+
  Named-entity recognition

  ```ruby
@@ -200,27 +253,32 @@ extractor = Transformers.pipeline("feature-extraction")
  extractor.("We are very happy to show you the 🤗 Transformers library.")
  ```

+ ### Vision
+
  Image classification

  ```ruby
  classifier = Transformers.pipeline("image-classification")
- classifier.(URI("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"))
+ classifier.("image.jpg")
  ```

  Image feature extraction

  ```ruby
  extractor = Transformers.pipeline("image-feature-extraction")
- extractor.(URI("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"))
+ extractor.("image.jpg")
  ```

  ## API

- This library follows the [Transformers Python API](https://huggingface.co/docs/transformers/index). Only a few model architectures are currently supported:
+ This library follows the [Transformers Python API](https://huggingface.co/docs/transformers/index). The following model architectures are currently supported:

  - BERT
+ - DeBERTa-v2
  - DistilBERT
+ - MPNet
  - ViT
+ - XLM-RoBERTa

  ## History

@@ -91,10 +91,24 @@ module Transformers
  # Config hash
  @commit_hash = kwargs.delete(:_commit_hash)

- # TODO set kwargs
- @gradient_checkpointing = kwargs[:gradient_checkpointing]
- @output_past = kwargs[:output_past]
- @tie_weights_ = kwargs[:tie_weights_]
+ # Attention implementation to use, if relevant.
+ @attn_implementation_internal = kwargs.delete(:attn_implementation)
+
+ # Drop the transformers version info
+ @transformers_version = kwargs.delete(:transformers_version)
+
+ # Deal with gradient checkpointing
+ # if kwargs[:gradient_checkpointing] == false
+ #   warn(
+ #     "Passing `gradient_checkpointing` to a config initialization is deprecated and will be removed in v5 " +
+ #     "Transformers. Using `model.gradient_checkpointing_enable()` instead, or if you are using the " +
+ #     "`Trainer` API, pass `gradient_checkpointing: true` in your `TrainingArguments`."
+ #   )
+ # end
+
+ kwargs.each do |k, v|
+   instance_variable_set("@#{k}", v)
+ end
  end

  def name_or_path
@@ -182,6 +196,20 @@ module Transformers
  JSON.pretty_generate(config_dict.sort_by { |k, _| k }.to_h) + "\n"
  end

+ def getattr(key, default)
+   if respond_to?(key)
+     public_send(key)
+   elsif instance_variable_defined?("@#{key}")
+     instance_variable_get("@#{key}")
+   else
+     default
+   end
+ end
+
+ def hasattr(key)
+   respond_to?(key) || instance_variable_defined?("@#{key}")
+ end
+
  class << self
  def from_pretrained(
  pretrained_model_name_or_path,
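
The two hunks above make configs behave more like their Python counterparts: unrecognized keyword arguments are now stored as instance variables, and the new `getattr`/`hasattr` helpers read them back with a fallback. A minimal sketch of the expected behaviour (illustrative only: `custom_setting` is a made-up key, and this assumes the helpers are public on config instances as the hunk suggests):

```ruby
# DebertaV2Config (added in this release, see the new file below) forwards
# unknown kwargs to the base config initializer, which stores them as
# instance variables.
config = Transformers::DebertaV2::DebertaV2Config.new(custom_setting: 123)

config.hasattr(:custom_setting)       # => true  (instance variable set from kwargs)
config.getattr(:custom_setting, nil)  # => 123
config.getattr(:missing_key, "n/a")   # => "n/a" (falls back to the default)
config.getattr(:hidden_size, nil)     # => 1536  (served by the regular attr_reader)
```
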
@@ -207,7 +207,7 @@ module Transformers

  def init_weights
  # Prune heads if needed
- if @config.pruned_heads
+ if @config.pruned_heads.any?
  prune_heads(@config.pruned_heads)
  end

@@ -803,11 +803,18 @@ module Transformers
  raise Todo
  end

+ model_class_name = model.class.name.split("::").last
+
  if error_msgs.length > 0
- raise Todo
+   error_msg = error_msgs.join("\n\t")
+   if error_msg.include?("size mismatch")
+     error_msg += (
+       "\n\tYou may consider adding `ignore_mismatched_sizes: true` in the model `from_pretrained` method."
+     )
+   end
+   raise RuntimeError, "Error(s) in loading state_dict for #{model_class_name}:\n\t#{error_msg}"
  end

- model_class_name = model.class.name.split("::").last
  if unexpected_keys.length > 0
  archs = model.config.architectures.nil? ? [] : model.config.architectures
  warner = archs.include?(model_class_name) ? Transformers.logger.method(:warn) : Transformers.logger.method(:info)
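
A hypothetical illustration of the effect (the checkpoint name is made up; the class comes from the new sequence-classification mapping further down): a checkpoint whose tensor shapes don't match the model now fails with a descriptive `RuntimeError` whose message suggests the `ignore_mismatched_sizes: true` escape hatch, instead of hitting the unimplemented `Todo` path.

```ruby
begin
  model = Transformers::DebertaV2::DebertaV2ForSequenceClassification.from_pretrained(
    "example/mismatched-checkpoint"  # hypothetical checkpoint with wrong tensor shapes
  )
rescue RuntimeError => e
  puts e.message
  # Error(s) in loading state_dict for DebertaV2ForSequenceClassification:
  #   size mismatch for ...
  #   You may consider adding `ignore_mismatched_sizes: true` in the model `from_pretrained` method.
end
```
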
@@ -116,7 +116,7 @@ module Transformers
  def _load_attr_from_module(model_type, attr)
  module_name = model_type_to_module_name(model_type)
  if !@modules.include?(module_name)
- @modules[module_name] = Transformers.const_get(module_name.capitalize)
+ @modules[module_name] = Transformers.const_get(module_name.split("-").map(&:capitalize).join)
  end
  getattribute_from_module(@modules[module_name], attr)
  end
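
This change (and the identical ones in the config, model, and tokenizer lookups below) is needed because the new model types contain hyphens, so `String#capitalize` alone can no longer produce a valid module name. A quick plain-Ruby illustration; the resulting names line up with modules added in this release, such as `Transformers::DebertaV2`:

```ruby
"deberta-v2".capitalize                          # => "Deberta-v2"  (not a valid constant name)
"deberta-v2".split("-").map(&:capitalize).join   # => "DebertaV2"
"xlm-roberta".split("-").map(&:capitalize).join  # => "XlmRoberta"
"bert".split("-").map(&:capitalize).join         # => "Bert"        (unchanged for existing types)
```
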
@@ -15,8 +15,11 @@
  module Transformers
  CONFIG_MAPPING_NAMES = {
  "bert" => "BertConfig",
+ "deberta-v2" => "DebertaV2Config",
  "distilbert" => "DistilBertConfig",
- "vit" => "ViTConfig"
+ "mpnet" => "MPNetConfig",
+ "vit" => "ViTConfig",
+ "xlm-roberta" => "XLMRobertaConfig"
  }

  class LazyConfigMapping
@@ -30,7 +33,7 @@ module Transformers
  value = @mapping.fetch(key)
  module_name = model_type_to_module_name(key)
  if !@modules.include?(module_name)
- @modules[module_name] = Transformers.const_get(module_name.capitalize)
+ @modules[module_name] = Transformers.const_get(module_name.split("-").map(&:capitalize).join)
  end
  @modules[module_name].const_get(value)
  end
@@ -15,16 +15,22 @@
  module Transformers
  MODEL_MAPPING_NAMES = {
  "bert" => "BertModel",
+ "deberta-v2" => "DebertaV2Model",
  "distilbert" => "DistilBertModel",
- "vit" => "ViTModel"
+ "mpnet" => "MPNetModel",
+ "vit" => "ViTModel",
+ "xlm-roberta" => "XLMRobertaModel"
  }

  MODEL_FOR_MASKED_LM_MAPPING_NAMES = {
- "bert" => "BertForMaskedLM"
+ "bert" => "BertForMaskedLM",
+ "mpnet" => "MPNetForMaskedLM"
  }

  MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = {
- "distilbert" => "DistilBertForSequenceClassification"
+ "deberta-v2" => "DebertaV2ForSequenceClassification",
+ "distilbert" => "DistilBertForSequenceClassification",
+ "xlm-roberta" => "XLMRobertaForSequenceClassification"
  }

  MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = {
@@ -15,7 +15,10 @@
  module Transformers
  TOKENIZER_MAPPING_NAMES = {
  "bert" => ["BertTokenizer", "BertTokenizerFast"],
- "distilbert" => ["DistilBertTokenizer", "DistilBertTokenizerFast"]
+ "deberta-v2" => ["DebertaV2TokenizerFast"],
+ "distilbert" => ["DistilBertTokenizer", "DistilBertTokenizerFast"],
+ "mpnet" => ["MPNetTokenizerFast"],
+ "xlm-roberta" => ["XLMRobertaTokenizerFast"]
  }

  TOKENIZER_MAPPING = LazyAutoMapping.new(CONFIG_MAPPING_NAMES, TOKENIZER_MAPPING_NAMES)
@@ -98,7 +101,7 @@ module Transformers

  TOKENIZER_MAPPING_NAMES.each do |module_name, tokenizers|
  if tokenizers.include?(class_name)
- cls = Transformers.const_get(module_name.capitalize).const_get(class_name)
+ cls = Transformers.const_get(module_name.split("-").map(&:capitalize).join).const_get(class_name)
  raise Error, "Invalid tokenizer class: #{class_name}" unless cls < PreTrainedTokenizer || cls < PreTrainedTokenizerFast
  return cls
  end
@@ -0,0 +1,80 @@
+ # Copyright 2020, Microsoft and the HuggingFace Inc. team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ module Transformers
+   module DebertaV2
+     class DebertaV2Config < PretrainedConfig
+       self.model_type = "deberta-v2"
+
+       attr_reader :vocab_size, :hidden_size, :num_hidden_layers, :num_attention_heads,
+         :intermediate_size, :hidden_act, :hidden_dropout_prob, :attention_probs_dropout_prob,
+         :max_position_embeddings, :type_vocab_size, :initializer_range, :layer_norm_eps,
+         :relative_attention, :max_relative_positions, :pad_token_id, :position_biased_input,
+         :pos_att_type, :pooler_dropout, :pooler_hidden_act, :pooler_hidden_size
+
+       def initialize(
+         vocab_size: 128100,
+         hidden_size: 1536,
+         num_hidden_layers: 24,
+         num_attention_heads: 24,
+         intermediate_size: 6144,
+         hidden_act: "gelu",
+         hidden_dropout_prob: 0.1,
+         attention_probs_dropout_prob: 0.1,
+         max_position_embeddings: 512,
+         type_vocab_size: 0,
+         initializer_range: 0.02,
+         layer_norm_eps: 1e-07,
+         relative_attention: false,
+         max_relative_positions: -1,
+         pad_token_id: 0,
+         position_biased_input: true,
+         pos_att_type: nil,
+         pooler_dropout: 0,
+         pooler_hidden_act: "gelu",
+         **kwargs
+       )
+         super(**kwargs)
+
+         @hidden_size = hidden_size
+         @num_hidden_layers = num_hidden_layers
+         @num_attention_heads = num_attention_heads
+         @intermediate_size = intermediate_size
+         @hidden_act = hidden_act
+         @hidden_dropout_prob = hidden_dropout_prob
+         @attention_probs_dropout_prob = attention_probs_dropout_prob
+         @max_position_embeddings = max_position_embeddings
+         @type_vocab_size = type_vocab_size
+         @initializer_range = initializer_range
+         @relative_attention = relative_attention
+         @max_relative_positions = max_relative_positions
+         @pad_token_id = pad_token_id
+         @position_biased_input = position_biased_input
+
+         # Backwards compatibility
+         if pos_att_type.is_a?(String)
+           pos_att_type = pos_att_type.downcase.split("|").map { |x| x.strip }
+         end
+
+         @pos_att_type = pos_att_type
+         @vocab_size = vocab_size
+         @layer_norm_eps = layer_norm_eps
+
+         @pooler_hidden_size = kwargs[:pooler_hidden_size] || hidden_size
+         @pooler_dropout = pooler_dropout
+         @pooler_hidden_act = pooler_hidden_act
+       end
+     end
+   end
+ end
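
A short usage sketch for the new config class (illustrative; the readers shown return the defaults from the initializer above, and `pos_att_type` exercises the backwards-compatibility branch that splits a pipe-delimited string):

```ruby
config = Transformers::DebertaV2::DebertaV2Config.new(
  relative_attention: true,
  pos_att_type: "p2c|c2p"  # string form is normalized to an array below
)

config.hidden_size        # => 1536
config.num_hidden_layers  # => 24
config.relative_attention # => true
config.pos_att_type       # => ["p2c", "c2p"]
```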