transformers-rb 0.1.2 → 0.1.3
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +7 -0
- data/README.md +61 -3
- data/lib/transformers/configuration_utils.rb +32 -4
- data/lib/transformers/modeling_utils.rb +10 -3
- data/lib/transformers/models/auto/auto_factory.rb +1 -1
- data/lib/transformers/models/auto/configuration_auto.rb +5 -2
- data/lib/transformers/models/auto/modeling_auto.rb +9 -3
- data/lib/transformers/models/auto/tokenization_auto.rb +5 -2
- data/lib/transformers/models/deberta_v2/configuration_deberta_v2.rb +80 -0
- data/lib/transformers/models/deberta_v2/modeling_deberta_v2.rb +1210 -0
- data/lib/transformers/models/deberta_v2/tokenization_deberta_v2_fast.rb +78 -0
- data/lib/transformers/models/mpnet/configuration_mpnet.rb +61 -0
- data/lib/transformers/models/mpnet/modeling_mpnet.rb +792 -0
- data/lib/transformers/models/mpnet/tokenization_mpnet_fast.rb +106 -0
- data/lib/transformers/models/xlm_roberta/configuration_xlm_roberta.rb +68 -0
- data/lib/transformers/models/xlm_roberta/modeling_xlm_roberta.rb +1216 -0
- data/lib/transformers/models/xlm_roberta/tokenization_xlm_roberta_fast.rb +68 -0
- data/lib/transformers/pipelines/_init.rb +10 -0
- data/lib/transformers/pipelines/reranking.rb +33 -0
- data/lib/transformers/version.rb +1 -1
- data/lib/transformers.rb +16 -0
- metadata +14 -4
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 3f070b9828c5c5ad71c75f46ca9daf1387a5ec3848cb406aac9e5f1bbc1d4531
|
4
|
+
data.tar.gz: 31b28a5a87c58db6fc3146e390e8a4a7bf1ffc34ede6d3cd6fcd7f3aa3df2d28
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: aa2055e44b9071a425ebfb59d6b2edbedce1f3cf97e0baa55d1280451c1c1db097a52b0b9615a188b1d96f0854e557fb4cb769b05cb3af4db229cd3fcdf8fb95
|
7
|
+
data.tar.gz: 1af002f238e9189a2e2a6b5f1aafc9201cfd5bc5f8afe4a80b81757b5d9f5d4fa52bc61a57b4fdd6920bd3692f704398aa58c7cc4fd797bd881ab9887c9c77f9
|
data/CHANGELOG.md
CHANGED
data/README.md
CHANGED
@@ -32,11 +32,17 @@ Embedding
|
|
32
32
|
- [intfloat/e5-base-v2](#intfloate5-base-v2)
|
33
33
|
- [BAAI/bge-base-en-v1.5](#baaibge-base-en-v15)
|
34
34
|
- [Snowflake/snowflake-arctic-embed-m-v1.5](#snowflakesnowflake-arctic-embed-m-v15)
|
35
|
+
- [sentence-transformers/all-mpnet-base-v2](#sentence-transformersall-mpnet-base-v2)
|
35
36
|
|
36
37
|
Sparse embedding
|
37
38
|
|
38
39
|
- [opensearch-project/opensearch-neural-sparse-encoding-v1](#opensearch-projectopensearch-neural-sparse-encoding-v1)
|
39
40
|
|
41
|
+
Reranking
|
42
|
+
|
43
|
+
- [mixedbread-ai/mxbai-rerank-base-v1](#mixedbread-aimxbai-rerank-base-v1)
|
44
|
+
- [BAAI/bge-reranker-base](#baaibge-reranker-base)
|
45
|
+
|
40
46
|
### sentence-transformers/all-MiniLM-L6-v2
|
41
47
|
|
42
48
|
[Docs](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2)
|
@@ -142,6 +148,17 @@ model = Transformers.pipeline("embedding", "Snowflake/snowflake-arctic-embed-m-v
|
|
142
148
|
embeddings = model.(input, pooling: "cls")
|
143
149
|
```
|
144
150
|
|
151
|
+
### sentence-transformers/all-mpnet-base-v2
|
152
|
+
|
153
|
+
[Docs](https://huggingface.co/sentence-transformers/all-mpnet-base-v2)
|
154
|
+
|
155
|
+
```ruby
|
156
|
+
sentences = ["This is an example sentence", "Each sentence is converted"]
|
157
|
+
|
158
|
+
model = Transformers.pipeline("embedding", "sentence-transformers/all-mpnet-base-v2")
|
159
|
+
embeddings = model.(sentences)
|
160
|
+
```
|
161
|
+
|
145
162
|
### opensearch-project/opensearch-neural-sparse-encoding-v1
|
146
163
|
|
147
164
|
[Docs](https://huggingface.co/opensearch-project/opensearch-neural-sparse-encoding-v1)
|
@@ -163,8 +180,37 @@ values[0.., special_token_ids] = 0
|
|
163
180
|
embeddings = values.to_a
|
164
181
|
```
|
165
182
|
|
183
|
+
### mixedbread-ai/mxbai-rerank-base-v1
|
184
|
+
|
185
|
+
[Docs](https://huggingface.co/mixedbread-ai/mxbai-rerank-base-v1)
|
186
|
+
|
187
|
+
```ruby
|
188
|
+
query = "How many people live in London?"
|
189
|
+
docs = ["Around 9 Million people live in London", "London is known for its financial district"]
|
190
|
+
|
191
|
+
model = Transformers.pipeline("reranking", "mixedbread-ai/mxbai-rerank-base-v1")
|
192
|
+
result = model.(query, docs)
|
193
|
+
```
|
194
|
+
|
195
|
+
### BAAI/bge-reranker-base
|
196
|
+
|
197
|
+
[Docs](https://huggingface.co/BAAI/bge-reranker-base)
|
198
|
+
|
199
|
+
```ruby
|
200
|
+
query = "How many people live in London?"
|
201
|
+
docs = ["Around 9 Million people live in London", "London is known for its financial district"]
|
202
|
+
|
203
|
+
model = Transformers.pipeline("reranking", "BAAI/bge-reranker-base")
|
204
|
+
result = model.(query, docs)
|
205
|
+
```
|
206
|
+
|
166
207
|
## Pipelines
|
167
208
|
|
209
|
+
- [Text](#text)
|
210
|
+
- [Vision](#vision)
|
211
|
+
|
212
|
+
### Text
|
213
|
+
|
168
214
|
Embedding
|
169
215
|
|
170
216
|
```ruby
|
@@ -172,6 +218,13 @@ embed = Transformers.pipeline("embedding")
|
|
172
218
|
embed.("We are very happy to show you the 🤗 Transformers library.")
|
173
219
|
```
|
174
220
|
|
221
|
+
Reranking
|
222
|
+
|
223
|
+
```ruby
|
224
|
+
rerank = Transformers.pipeline("reranking")
|
225
|
+
rerank.("Who created Ruby?", ["Matz created Ruby", "Another doc"])
|
226
|
+
```
|
227
|
+
|
175
228
|
Named-entity recognition
|
176
229
|
|
177
230
|
```ruby
|
@@ -200,27 +253,32 @@ extractor = Transformers.pipeline("feature-extraction")
|
|
200
253
|
extractor.("We are very happy to show you the 🤗 Transformers library.")
|
201
254
|
```
|
202
255
|
|
256
|
+
### Vision
|
257
|
+
|
203
258
|
Image classification
|
204
259
|
|
205
260
|
```ruby
|
206
261
|
classifier = Transformers.pipeline("image-classification")
|
207
|
-
classifier.(
|
262
|
+
classifier.("image.jpg")
|
208
263
|
```
|
209
264
|
|
210
265
|
Image feature extraction
|
211
266
|
|
212
267
|
```ruby
|
213
268
|
extractor = Transformers.pipeline("image-feature-extraction")
|
214
|
-
extractor.(
|
269
|
+
extractor.("image.jpg")
|
215
270
|
```
|
216
271
|
|
217
272
|
## API
|
218
273
|
|
219
|
-
This library follows the [Transformers Python API](https://huggingface.co/docs/transformers/index).
|
274
|
+
This library follows the [Transformers Python API](https://huggingface.co/docs/transformers/index). The following model architectures are currently supported:
|
220
275
|
|
221
276
|
- BERT
|
277
|
+
- DeBERTa-v2
|
222
278
|
- DistilBERT
|
279
|
+
- MPNet
|
223
280
|
- ViT
|
281
|
+
- XLM-RoBERTa
|
224
282
|
|
225
283
|
## History
|
226
284
|
|
@@ -91,10 +91,24 @@ module Transformers
|
|
91
91
|
# Config hash
|
92
92
|
@commit_hash = kwargs.delete(:_commit_hash)
|
93
93
|
|
94
|
-
#
|
95
|
-
@
|
96
|
-
|
97
|
-
|
94
|
+
# Attention implementation to use, if relevant.
|
95
|
+
@attn_implementation_internal = kwargs.delete(:attn_implementation)
|
96
|
+
|
97
|
+
# Drop the transformers version info
|
98
|
+
@transformers_version = kwargs.delete(:transformers_version)
|
99
|
+
|
100
|
+
# Deal with gradient checkpointing
|
101
|
+
# if kwargs[:gradient_checkpointing] == false
|
102
|
+
# warn(
|
103
|
+
# "Passing `gradient_checkpointing` to a config initialization is deprecated and will be removed in v5 " +
|
104
|
+
# "Transformers. Using `model.gradient_checkpointing_enable()` instead, or if you are using the " +
|
105
|
+
# "`Trainer` API, pass `gradient_checkpointing: true` in your `TrainingArguments`."
|
106
|
+
# )
|
107
|
+
# end
|
108
|
+
|
109
|
+
kwargs.each do |k, v|
|
110
|
+
instance_variable_set("@#{k}", v)
|
111
|
+
end
|
98
112
|
end
|
99
113
|
|
100
114
|
def name_or_path
|
@@ -182,6 +196,20 @@ module Transformers
|
|
182
196
|
JSON.pretty_generate(config_dict.sort_by { |k, _| k }.to_h) + "\n"
|
183
197
|
end
|
184
198
|
|
199
|
+
def getattr(key, default)
|
200
|
+
if respond_to?(key)
|
201
|
+
public_send(key)
|
202
|
+
elsif instance_variable_defined?("@#{key}")
|
203
|
+
instance_variable_get("@#{key}")
|
204
|
+
else
|
205
|
+
default
|
206
|
+
end
|
207
|
+
end
|
208
|
+
|
209
|
+
def hasattr(key)
|
210
|
+
respond_to?(key) || instance_variable_defined?("@#{key}")
|
211
|
+
end
|
212
|
+
|
185
213
|
class << self
|
186
214
|
def from_pretrained(
|
187
215
|
pretrained_model_name_or_path,
|
@@ -207,7 +207,7 @@ module Transformers
|
|
207
207
|
|
208
208
|
def init_weights
|
209
209
|
# Prune heads if needed
|
210
|
-
if @config.pruned_heads
|
210
|
+
if @config.pruned_heads.any?
|
211
211
|
prune_heads(@config.pruned_heads)
|
212
212
|
end
|
213
213
|
|
@@ -803,11 +803,18 @@ module Transformers
|
|
803
803
|
raise Todo
|
804
804
|
end
|
805
805
|
|
806
|
+
model_class_name = model.class.name.split("::").last
|
807
|
+
|
806
808
|
if error_msgs.length > 0
|
807
|
-
|
809
|
+
error_msg = error_msgs.join("\n\t")
|
810
|
+
if error_msg.include?("size mismatch")
|
811
|
+
error_msg += (
|
812
|
+
"\n\tYou may consider adding `ignore_mismatched_sizes: true` in the model `from_pretrained` method."
|
813
|
+
)
|
814
|
+
end
|
815
|
+
raise RuntimeError, "Error(s) in loading state_dict for #{model_class_name}:\n\t#{error_msg}"
|
808
816
|
end
|
809
817
|
|
810
|
-
model_class_name = model.class.name.split("::").last
|
811
818
|
if unexpected_keys.length > 0
|
812
819
|
archs = model.config.architectures.nil? ? [] : model.config.architectures
|
813
820
|
warner = archs.include?(model_class_name) ? Transformers.logger.method(:warn) : Transformers.logger.method(:info)
|
@@ -116,7 +116,7 @@ module Transformers
|
|
116
116
|
def _load_attr_from_module(model_type, attr)
|
117
117
|
module_name = model_type_to_module_name(model_type)
|
118
118
|
if !@modules.include?(module_name)
|
119
|
-
@modules[module_name] = Transformers.const_get(module_name.capitalize)
|
119
|
+
@modules[module_name] = Transformers.const_get(module_name.split("-").map(&:capitalize).join)
|
120
120
|
end
|
121
121
|
getattribute_from_module(@modules[module_name], attr)
|
122
122
|
end
|
@@ -15,8 +15,11 @@
|
|
15
15
|
module Transformers
|
16
16
|
CONFIG_MAPPING_NAMES = {
|
17
17
|
"bert" => "BertConfig",
|
18
|
+
"deberta-v2" => "DebertaV2Config",
|
18
19
|
"distilbert" => "DistilBertConfig",
|
19
|
-
"
|
20
|
+
"mpnet" => "MPNetConfig",
|
21
|
+
"vit" => "ViTConfig",
|
22
|
+
"xlm-roberta" => "XLMRobertaConfig"
|
20
23
|
}
|
21
24
|
|
22
25
|
class LazyConfigMapping
|
@@ -30,7 +33,7 @@ module Transformers
|
|
30
33
|
value = @mapping.fetch(key)
|
31
34
|
module_name = model_type_to_module_name(key)
|
32
35
|
if !@modules.include?(module_name)
|
33
|
-
@modules[module_name] = Transformers.const_get(module_name.capitalize)
|
36
|
+
@modules[module_name] = Transformers.const_get(module_name.split("-").map(&:capitalize).join)
|
34
37
|
end
|
35
38
|
@modules[module_name].const_get(value)
|
36
39
|
end
|
@@ -15,16 +15,22 @@
|
|
15
15
|
module Transformers
|
16
16
|
MODEL_MAPPING_NAMES = {
|
17
17
|
"bert" => "BertModel",
|
18
|
+
"deberta-v2" => "DebertaV2Model",
|
18
19
|
"distilbert" => "DistilBertModel",
|
19
|
-
"
|
20
|
+
"mpnet" => "MPNetModel",
|
21
|
+
"vit" => "ViTModel",
|
22
|
+
"xlm-roberta" => "XLMRobertaModel"
|
20
23
|
}
|
21
24
|
|
22
25
|
MODEL_FOR_MASKED_LM_MAPPING_NAMES = {
|
23
|
-
"bert" => "BertForMaskedLM"
|
26
|
+
"bert" => "BertForMaskedLM",
|
27
|
+
"mpnet" => "MPNetForMaskedLM"
|
24
28
|
}
|
25
29
|
|
26
30
|
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = {
|
27
|
-
"
|
31
|
+
"deberta-v2" => "DebertaV2ForSequenceClassification",
|
32
|
+
"distilbert" => "DistilBertForSequenceClassification",
|
33
|
+
"xlm-roberta" => "XLMRobertaForSequenceClassification"
|
28
34
|
}
|
29
35
|
|
30
36
|
MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = {
|
@@ -15,7 +15,10 @@
|
|
15
15
|
module Transformers
|
16
16
|
TOKENIZER_MAPPING_NAMES = {
|
17
17
|
"bert" => ["BertTokenizer", "BertTokenizerFast"],
|
18
|
-
"
|
18
|
+
"deberta-v2" => ["DebertaV2TokenizerFast"],
|
19
|
+
"distilbert" => ["DistilBertTokenizer", "DistilBertTokenizerFast"],
|
20
|
+
"mpnet" => ["MPNetTokenizerFast"],
|
21
|
+
"xlm-roberta" => ["XLMRobertaTokenizerFast"]
|
19
22
|
}
|
20
23
|
|
21
24
|
TOKENIZER_MAPPING = LazyAutoMapping.new(CONFIG_MAPPING_NAMES, TOKENIZER_MAPPING_NAMES)
|
@@ -98,7 +101,7 @@ module Transformers
|
|
98
101
|
|
99
102
|
TOKENIZER_MAPPING_NAMES.each do |module_name, tokenizers|
|
100
103
|
if tokenizers.include?(class_name)
|
101
|
-
cls = Transformers.const_get(module_name.capitalize).const_get(class_name)
|
104
|
+
cls = Transformers.const_get(module_name.split("-").map(&:capitalize).join).const_get(class_name)
|
102
105
|
raise Error, "Invalid tokenizer class: #{class_name}" unless cls < PreTrainedTokenizer || cls < PreTrainedTokenizerFast
|
103
106
|
return cls
|
104
107
|
end
|
@@ -0,0 +1,80 @@
|
|
1
|
+
# Copyright 2020, Microsoft and the HuggingFace Inc. team.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
module Transformers
|
16
|
+
module DebertaV2
|
17
|
+
class DebertaV2Config < PretrainedConfig
|
18
|
+
self.model_type = "deberta-v2"
|
19
|
+
|
20
|
+
attr_reader :vocab_size, :hidden_size, :num_hidden_layers, :num_attention_heads,
|
21
|
+
:intermediate_size, :hidden_act, :hidden_dropout_prob, :attention_probs_dropout_prob,
|
22
|
+
:max_position_embeddings, :type_vocab_size, :initializer_range, :layer_norm_eps,
|
23
|
+
:relative_attention, :max_relative_positions, :pad_token_id, :position_biased_input,
|
24
|
+
:pos_att_type, :pooler_dropout, :pooler_hidden_act, :pooler_hidden_size
|
25
|
+
|
26
|
+
def initialize(
|
27
|
+
vocab_size: 128100,
|
28
|
+
hidden_size: 1536,
|
29
|
+
num_hidden_layers: 24,
|
30
|
+
num_attention_heads: 24,
|
31
|
+
intermediate_size: 6144,
|
32
|
+
hidden_act: "gelu",
|
33
|
+
hidden_dropout_prob: 0.1,
|
34
|
+
attention_probs_dropout_prob: 0.1,
|
35
|
+
max_position_embeddings: 512,
|
36
|
+
type_vocab_size: 0,
|
37
|
+
initializer_range: 0.02,
|
38
|
+
layer_norm_eps: 1e-07,
|
39
|
+
relative_attention: false,
|
40
|
+
max_relative_positions: -1,
|
41
|
+
pad_token_id: 0,
|
42
|
+
position_biased_input: true,
|
43
|
+
pos_att_type: nil,
|
44
|
+
pooler_dropout: 0,
|
45
|
+
pooler_hidden_act: "gelu",
|
46
|
+
**kwargs
|
47
|
+
)
|
48
|
+
super(**kwargs)
|
49
|
+
|
50
|
+
@hidden_size = hidden_size
|
51
|
+
@num_hidden_layers = num_hidden_layers
|
52
|
+
@num_attention_heads = num_attention_heads
|
53
|
+
@intermediate_size = intermediate_size
|
54
|
+
@hidden_act = hidden_act
|
55
|
+
@hidden_dropout_prob = hidden_dropout_prob
|
56
|
+
@attention_probs_dropout_prob = attention_probs_dropout_prob
|
57
|
+
@max_position_embeddings = max_position_embeddings
|
58
|
+
@type_vocab_size = type_vocab_size
|
59
|
+
@initializer_range = initializer_range
|
60
|
+
@relative_attention = relative_attention
|
61
|
+
@max_relative_positions = max_relative_positions
|
62
|
+
@pad_token_id = pad_token_id
|
63
|
+
@position_biased_input = position_biased_input
|
64
|
+
|
65
|
+
# Backwards compatibility
|
66
|
+
if pos_att_type.is_a?(String)
|
67
|
+
pos_att_type = pos_att_type.downcase.split("|").map { |x| x.strip }
|
68
|
+
end
|
69
|
+
|
70
|
+
@pos_att_type = pos_att_type
|
71
|
+
@vocab_size = vocab_size
|
72
|
+
@layer_norm_eps = layer_norm_eps
|
73
|
+
|
74
|
+
@pooler_hidden_size = kwargs[:pooler_hidden_size] || hidden_size
|
75
|
+
@pooler_dropout = pooler_dropout
|
76
|
+
@pooler_hidden_act = pooler_hidden_act
|
77
|
+
end
|
78
|
+
end
|
79
|
+
end
|
80
|
+
end
|