transformers-rb 0.1.1 → 0.1.3

checksums.yaml CHANGED
@@ -1,7 +1,7 @@
  ---
  SHA256:
- metadata.gz: 2b7441c0cc400d85d3d0775e6c53a631c83887e481588148a7cb3242e50c0f08
- data.tar.gz: df639eb5c3231d344ec8516ec861cb40239ac5ff0310faa5b1b102a79af97972
+ metadata.gz: 3f070b9828c5c5ad71c75f46ca9daf1387a5ec3848cb406aac9e5f1bbc1d4531
+ data.tar.gz: 31b28a5a87c58db6fc3146e390e8a4a7bf1ffc34ede6d3cd6fcd7f3aa3df2d28
  SHA512:
- metadata.gz: 566a03e8a65255d3eb7e87bb3e68c0ac4d228ba4bc55e62eaad98f6677b8a5f29ee9201efaefc3edb309d279c1a986a664dfffbf3cc292114394be1c10fa949a
- data.tar.gz: fdae84f4fa8aac90f1a627d738fb701ef91d384414173f4042bac56bfe5c78c6ceb98aaea8ef0f2158d324ec60d18cbd91f84bd30f42b49d63193392e0105e96
+ metadata.gz: aa2055e44b9071a425ebfb59d6b2edbedce1f3cf97e0baa55d1280451c1c1db097a52b0b9615a188b1d96f0854e557fb4cb769b05cb3af4db229cd3fcdf8fb95
+ data.tar.gz: 1af002f238e9189a2e2a6b5f1aafc9201cfd5bc5f8afe4a80b81757b5d9f5d4fa52bc61a57b4fdd6920bd3692f704398aa58c7cc4fd797bd881ab9887c9c77f9
data/CHANGELOG.md CHANGED
@@ -1,3 +1,14 @@
+ ## 0.1.3 (2024-09-17)
+
+ - Added `reranking` pipeline
+ - Added DeBERTa-v2
+ - Added MPNet
+ - Added XLM-RoBERTa
+
+ ## 0.1.2 (2024-09-10)
+
+ - Fixed default revision for pipelines
+
  ## 0.1.1 (2024-08-29)

  - Added `embedding` pipeline
data/README.md CHANGED
@@ -32,8 +32,17 @@ Embedding
  - [intfloat/e5-base-v2](#intfloate5-base-v2)
  - [BAAI/bge-base-en-v1.5](#baaibge-base-en-v15)
  - [Snowflake/snowflake-arctic-embed-m-v1.5](#snowflakesnowflake-arctic-embed-m-v15)
+ - [sentence-transformers/all-mpnet-base-v2](#sentence-transformersall-mpnet-base-v2)
+
+ Sparse embedding
+
  - [opensearch-project/opensearch-neural-sparse-encoding-v1](#opensearch-projectopensearch-neural-sparse-encoding-v1)

+ Reranking
+
+ - [mixedbread-ai/mxbai-rerank-base-v1](#mixedbread-aimxbai-rerank-base-v1)
+ - [BAAI/bge-reranker-base](#baaibge-reranker-base)
+
  ### sentence-transformers/all-MiniLM-L6-v2

  [Docs](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2)
@@ -139,6 +148,17 @@ model = Transformers.pipeline("embedding", "Snowflake/snowflake-arctic-embed-m-v
  embeddings = model.(input, pooling: "cls")
  ```

+ ### sentence-transformers/all-mpnet-base-v2
+
+ [Docs](https://huggingface.co/sentence-transformers/all-mpnet-base-v2)
+
+ ```ruby
+ sentences = ["This is an example sentence", "Each sentence is converted"]
+
+ model = Transformers.pipeline("embedding", "sentence-transformers/all-mpnet-base-v2")
+ embeddings = model.(sentences)
+ ```
+
  ### opensearch-project/opensearch-neural-sparse-encoding-v1

  [Docs](https://huggingface.co/opensearch-project/opensearch-neural-sparse-encoding-v1)
@@ -160,8 +180,37 @@ values[0.., special_token_ids] = 0
  embeddings = values.to_a
  ```

+ ### mixedbread-ai/mxbai-rerank-base-v1
+
+ [Docs](https://huggingface.co/mixedbread-ai/mxbai-rerank-base-v1)
+
+ ```ruby
+ query = "How many people live in London?"
+ docs = ["Around 9 Million people live in London", "London is known for its financial district"]
+
+ model = Transformers.pipeline("reranking", "mixedbread-ai/mxbai-rerank-base-v1")
+ result = model.(query, docs)
+ ```
+
+ ### BAAI/bge-reranker-base
+
+ [Docs](https://huggingface.co/BAAI/bge-reranker-base)
+
+ ```ruby
+ query = "How many people live in London?"
+ docs = ["Around 9 Million people live in London", "London is known for its financial district"]
+
+ model = Transformers.pipeline("reranking", "BAAI/bge-reranker-base")
+ result = model.(query, docs)
+ ```
+
  ## Pipelines

+ - [Text](#text)
+ - [Vision](#vision)
+
+ ### Text
+
  Embedding

  ```ruby
@@ -169,6 +218,13 @@ embed = Transformers.pipeline("embedding")
  embed.("We are very happy to show you the 🤗 Transformers library.")
  ```

+ Reranking
+
+ ```ruby
+ rerank = Transformers.pipeline("reranking")
+ rerank.("Who created Ruby?", ["Matz created Ruby", "Another doc"])
+ ```
+
  Named-entity recognition

  ```ruby
@@ -197,27 +253,32 @@ extractor = Transformers.pipeline("feature-extraction")
  extractor.("We are very happy to show you the 🤗 Transformers library.")
  ```

+ ### Vision
+
  Image classification

  ```ruby
  classifier = Transformers.pipeline("image-classification")
- classifier.(URI("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"))
+ classifier.("image.jpg")
  ```

  Image feature extraction

  ```ruby
  extractor = Transformers.pipeline("image-feature-extraction")
- extractor.(URI("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"))
+ extractor.("image.jpg")
  ```

  ## API

- This library follows the [Transformers Python API](https://huggingface.co/docs/transformers/index). Only a few model architectures are currently supported:
+ This library follows the [Transformers Python API](https://huggingface.co/docs/transformers/index). The following model architectures are currently supported:

  - BERT
+ - DeBERTa-v2
  - DistilBERT
+ - MPNet
  - ViT
+ - XLM-RoBERTa

  ## History

@@ -91,10 +91,24 @@ module Transformers
  # Config hash
  @commit_hash = kwargs.delete(:_commit_hash)

- # TODO set kwargs
- @gradient_checkpointing = kwargs[:gradient_checkpointing]
- @output_past = kwargs[:output_past]
- @tie_weights_ = kwargs[:tie_weights_]
+ # Attention implementation to use, if relevant.
+ @attn_implementation_internal = kwargs.delete(:attn_implementation)
+
+ # Drop the transformers version info
+ @transformers_version = kwargs.delete(:transformers_version)
+
+ # Deal with gradient checkpointing
+ # if kwargs[:gradient_checkpointing] == false
+ #   warn(
+ #     "Passing `gradient_checkpointing` to a config initialization is deprecated and will be removed in v5 " +
+ #     "Transformers. Using `model.gradient_checkpointing_enable()` instead, or if you are using the " +
+ #     "`Trainer` API, pass `gradient_checkpointing: true` in your `TrainingArguments`."
+ #   )
+ # end
+
+ kwargs.each do |k, v|
+   instance_variable_set("@#{k}", v)
+ end
  end

  def name_or_path
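With this change, configuration keys that `PretrainedConfig` does not handle explicitly are no longer dropped: each remaining kwarg is stored as an instance variable. A minimal sketch of the effect, using the `DebertaV2Config` class added later in this release (the `:extra_option` key is made up for illustration):

```ruby
# Illustrative only: :extra_option is not a declared keyword argument, so it
# falls through to PretrainedConfig#initialize, which now stores it via
# instance_variable_set instead of discarding it.
config = Transformers::DebertaV2::DebertaV2Config.new(extra_option: 123)
config.instance_variable_get(:@extra_option) # => 123
```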
@@ -182,6 +196,20 @@ module Transformers
  JSON.pretty_generate(config_dict.sort_by { |k, _| k }.to_h) + "\n"
  end

+ def getattr(key, default)
+   if respond_to?(key)
+     public_send(key)
+   elsif instance_variable_defined?("@#{key}")
+     instance_variable_get("@#{key}")
+   else
+     default
+   end
+ end
+
+ def hasattr(key)
+   respond_to?(key) || instance_variable_defined?("@#{key}")
+ end
+
  class << self
    def from_pretrained(
      pretrained_model_name_or_path,
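The new `getattr` and `hasattr` helpers mirror Python's built-ins: they check for a reader method first, then for a bare instance variable, and otherwise fall back to a default. A small sketch using the `DebertaV2Config` defaults added in this release (the `:made_up_attribute` key is illustrative):

```ruby
config = Transformers::DebertaV2::DebertaV2Config.new

config.hasattr(:hidden_size)          # => true  (attr_reader defined on the class)
config.getattr(:hidden_size, nil)     # => 1536  (the config default)
config.hasattr(:made_up_attribute)    # => false
config.getattr(:made_up_attribute, 0) # => 0     (falls back to the default)
```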
@@ -207,7 +207,7 @@ module Transformers

  def init_weights
    # Prune heads if needed
-   if @config.pruned_heads
+   if @config.pruned_heads.any?
      prune_heads(@config.pruned_heads)
    end

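The `.any?` check matters because, unlike Python, an empty Hash is truthy in Ruby, so the old condition always took the pruning branch. A standalone illustration (not from the gem):

```ruby
pruned_heads = {}
puts "would prune" if pruned_heads       # prints: an empty Hash is truthy in Ruby
puts "would prune" if pruned_heads.any?  # prints nothing for an empty Hash
```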
@@ -803,11 +803,18 @@ module Transformers
    raise Todo
  end

+ model_class_name = model.class.name.split("::").last
+
  if error_msgs.length > 0
-   raise Todo
+   error_msg = error_msgs.join("\n\t")
+   if error_msg.include?("size mismatch")
+     error_msg += (
+       "\n\tYou may consider adding `ignore_mismatched_sizes: true` in the model `from_pretrained` method."
+     )
+   end
+   raise RuntimeError, "Error(s) in loading state_dict for #{model_class_name}:\n\t#{error_msg}"
  end

- model_class_name = model.class.name.split("::").last
  if unexpected_keys.length > 0
    archs = model.config.architectures.nil? ? [] : model.config.architectures
    warner = archs.include?(model_class_name) ? Transformers.logger.method(:warn) : Transformers.logger.method(:info)
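The rebuilt error message points readers at the `ignore_mismatched_sizes:` option. A hedged sketch of the call it suggests; the auto class name and model id here are assumptions for illustration, not taken from this diff:

```ruby
# Sketch only: reload a checkpoint whose head dimensions differ from the config,
# skipping the weights that produced the "size mismatch" error above.
model = Transformers::AutoModelForSequenceClassification.from_pretrained(
  "BAAI/bge-reranker-base",
  ignore_mismatched_sizes: true
)
```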
@@ -116,7 +116,7 @@ module Transformers
  def _load_attr_from_module(model_type, attr)
    module_name = model_type_to_module_name(model_type)
    if !@modules.include?(module_name)
-     @modules[module_name] = Transformers.const_get(module_name.capitalize)
+     @modules[module_name] = Transformers.const_get(module_name.split("-").map(&:capitalize).join)
    end
    getattribute_from_module(@modules[module_name], attr)
  end
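The `capitalize`-only lookup broke for the new hyphenated model types; the replacement camel-cases each hyphen-separated part before resolving the module constant. The string transformation itself:

```ruby
# How the new expression maps model types to constant names:
"bert".split("-").map(&:capitalize).join        # => "Bert"
"deberta-v2".split("-").map(&:capitalize).join  # => "DebertaV2"
"xlm-roberta".split("-").map(&:capitalize).join # => "XlmRoberta"
```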
@@ -15,8 +15,11 @@
  module Transformers
    CONFIG_MAPPING_NAMES = {
      "bert" => "BertConfig",
+     "deberta-v2" => "DebertaV2Config",
      "distilbert" => "DistilBertConfig",
-     "vit" => "ViTConfig"
+     "mpnet" => "MPNetConfig",
+     "vit" => "ViTConfig",
+     "xlm-roberta" => "XLMRobertaConfig"
    }

    class LazyConfigMapping
@@ -30,7 +33,7 @@ module Transformers
    value = @mapping.fetch(key)
    module_name = model_type_to_module_name(key)
    if !@modules.include?(module_name)
-     @modules[module_name] = Transformers.const_get(module_name.capitalize)
+     @modules[module_name] = Transformers.const_get(module_name.split("-").map(&:capitalize).join)
    end
    @modules[module_name].const_get(value)
  end
@@ -15,16 +15,22 @@
  module Transformers
    MODEL_MAPPING_NAMES = {
      "bert" => "BertModel",
+     "deberta-v2" => "DebertaV2Model",
      "distilbert" => "DistilBertModel",
-     "vit" => "ViTModel"
+     "mpnet" => "MPNetModel",
+     "vit" => "ViTModel",
+     "xlm-roberta" => "XLMRobertaModel"
    }

    MODEL_FOR_MASKED_LM_MAPPING_NAMES = {
-     "bert" => "BertForMaskedLM"
+     "bert" => "BertForMaskedLM",
+     "mpnet" => "MPNetForMaskedLM"
    }

    MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = {
-     "distilbert" => "DistilBertForSequenceClassification"
+     "deberta-v2" => "DebertaV2ForSequenceClassification",
+     "distilbert" => "DistilBertForSequenceClassification",
+     "xlm-roberta" => "XLMRobertaForSequenceClassification"
    }

    MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = {
@@ -15,7 +15,10 @@
  module Transformers
    TOKENIZER_MAPPING_NAMES = {
      "bert" => ["BertTokenizer", "BertTokenizerFast"],
-     "distilbert" => ["DistilBertTokenizer", "DistilBertTokenizerFast"]
+     "deberta-v2" => ["DebertaV2TokenizerFast"],
+     "distilbert" => ["DistilBertTokenizer", "DistilBertTokenizerFast"],
+     "mpnet" => ["MPNetTokenizerFast"],
+     "xlm-roberta" => ["XLMRobertaTokenizerFast"]
    }

    TOKENIZER_MAPPING = LazyAutoMapping.new(CONFIG_MAPPING_NAMES, TOKENIZER_MAPPING_NAMES)
@@ -98,7 +101,7 @@ module Transformers

  TOKENIZER_MAPPING_NAMES.each do |module_name, tokenizers|
    if tokenizers.include?(class_name)
-     cls = Transformers.const_get(module_name.capitalize).const_get(class_name)
+     cls = Transformers.const_get(module_name.split("-").map(&:capitalize).join).const_get(class_name)
      raise Error, "Invalid tokenizer class: #{class_name}" unless cls < PreTrainedTokenizer || cls < PreTrainedTokenizerFast
      return cls
    end
@@ -0,0 +1,80 @@
+ # Copyright 2020, Microsoft and the HuggingFace Inc. team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ module Transformers
+   module DebertaV2
+     class DebertaV2Config < PretrainedConfig
+       self.model_type = "deberta-v2"
+
+       attr_reader :vocab_size, :hidden_size, :num_hidden_layers, :num_attention_heads,
+         :intermediate_size, :hidden_act, :hidden_dropout_prob, :attention_probs_dropout_prob,
+         :max_position_embeddings, :type_vocab_size, :initializer_range, :layer_norm_eps,
+         :relative_attention, :max_relative_positions, :pad_token_id, :position_biased_input,
+         :pos_att_type, :pooler_dropout, :pooler_hidden_act, :pooler_hidden_size
+
+       def initialize(
+         vocab_size: 128100,
+         hidden_size: 1536,
+         num_hidden_layers: 24,
+         num_attention_heads: 24,
+         intermediate_size: 6144,
+         hidden_act: "gelu",
+         hidden_dropout_prob: 0.1,
+         attention_probs_dropout_prob: 0.1,
+         max_position_embeddings: 512,
+         type_vocab_size: 0,
+         initializer_range: 0.02,
+         layer_norm_eps: 1e-07,
+         relative_attention: false,
+         max_relative_positions: -1,
+         pad_token_id: 0,
+         position_biased_input: true,
+         pos_att_type: nil,
+         pooler_dropout: 0,
+         pooler_hidden_act: "gelu",
+         **kwargs
+       )
+         super(**kwargs)
+
+         @hidden_size = hidden_size
+         @num_hidden_layers = num_hidden_layers
+         @num_attention_heads = num_attention_heads
+         @intermediate_size = intermediate_size
+         @hidden_act = hidden_act
+         @hidden_dropout_prob = hidden_dropout_prob
+         @attention_probs_dropout_prob = attention_probs_dropout_prob
+         @max_position_embeddings = max_position_embeddings
+         @type_vocab_size = type_vocab_size
+         @initializer_range = initializer_range
+         @relative_attention = relative_attention
+         @max_relative_positions = max_relative_positions
+         @pad_token_id = pad_token_id
+         @position_biased_input = position_biased_input
+
+         # Backwards compatibility
+         if pos_att_type.is_a?(String)
+           pos_att_type = pos_att_type.downcase.split("|").map { |x| x.strip }
+         end
+
+         @pos_att_type = pos_att_type
+         @vocab_size = vocab_size
+         @layer_norm_eps = layer_norm_eps
+
+         @pooler_hidden_size = kwargs[:pooler_hidden_size] || hidden_size
+         @pooler_dropout = pooler_dropout
+         @pooler_hidden_act = pooler_hidden_act
+       end
+     end
+   end
+ end
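For reference, a short sketch of the new config in use; the values follow the keyword defaults above, and constructing the config directly (rather than loading it via `from_pretrained`) is only for illustration:

```ruby
config = Transformers::DebertaV2::DebertaV2Config.new
config.vocab_size         # => 128100
config.hidden_size        # => 1536
config.pooler_hidden_size # => 1536 (falls back to hidden_size)

config = Transformers::DebertaV2::DebertaV2Config.new(hidden_size: 768, relative_attention: true)
config.hidden_size        # => 768
config.relative_attention # => true
```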