transformers-rb 0.1.2 → 0.1.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +7 -0
- data/README.md +61 -3
- data/lib/transformers/configuration_utils.rb +32 -4
- data/lib/transformers/modeling_utils.rb +10 -3
- data/lib/transformers/models/auto/auto_factory.rb +1 -1
- data/lib/transformers/models/auto/configuration_auto.rb +5 -2
- data/lib/transformers/models/auto/modeling_auto.rb +9 -3
- data/lib/transformers/models/auto/tokenization_auto.rb +5 -2
- data/lib/transformers/models/deberta_v2/configuration_deberta_v2.rb +80 -0
- data/lib/transformers/models/deberta_v2/modeling_deberta_v2.rb +1210 -0
- data/lib/transformers/models/deberta_v2/tokenization_deberta_v2_fast.rb +78 -0
- data/lib/transformers/models/mpnet/configuration_mpnet.rb +61 -0
- data/lib/transformers/models/mpnet/modeling_mpnet.rb +792 -0
- data/lib/transformers/models/mpnet/tokenization_mpnet_fast.rb +106 -0
- data/lib/transformers/models/xlm_roberta/configuration_xlm_roberta.rb +68 -0
- data/lib/transformers/models/xlm_roberta/modeling_xlm_roberta.rb +1216 -0
- data/lib/transformers/models/xlm_roberta/tokenization_xlm_roberta_fast.rb +68 -0
- data/lib/transformers/pipelines/_init.rb +10 -0
- data/lib/transformers/pipelines/reranking.rb +33 -0
- data/lib/transformers/version.rb +1 -1
- data/lib/transformers.rb +16 -0
- metadata +14 -4
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
 ---
 SHA256:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: 3f070b9828c5c5ad71c75f46ca9daf1387a5ec3848cb406aac9e5f1bbc1d4531
+  data.tar.gz: 31b28a5a87c58db6fc3146e390e8a4a7bf1ffc34ede6d3cd6fcd7f3aa3df2d28
 SHA512:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: aa2055e44b9071a425ebfb59d6b2edbedce1f3cf97e0baa55d1280451c1c1db097a52b0b9615a188b1d96f0854e557fb4cb769b05cb3af4db229cd3fcdf8fb95
+  data.tar.gz: 1af002f238e9189a2e2a6b5f1aafc9201cfd5bc5f8afe4a80b81757b5d9f5d4fa52bc61a57b4fdd6920bd3692f704398aa58c7cc4fd797bd881ab9887c9c77f9
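The two digests above cover the `metadata.gz` and `data.tar.gz` entries inside the published `.gem` archive. A minimal verification sketch, assuming the gem has been downloaded locally as `transformers-rb-0.1.3.gem` (the filename and flow are illustrative, not part of the diff):

```ruby
# Illustrative only: recompute the SHA256 digests listed in checksums.yaml
# from a locally downloaded gem (a .gem file is a plain tar archive).
require "digest"
require "rubygems/package"

File.open("transformers-rb-0.1.3.gem", "rb") do |io|
  Gem::Package::TarReader.new(io) do |tar|
    tar.each do |entry|
      next unless ["metadata.gz", "data.tar.gz"].include?(entry.full_name)
      puts "#{entry.full_name}: #{Digest::SHA256.hexdigest(entry.read)}"
    end
  end
end
```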
data/CHANGELOG.md
CHANGED
data/README.md
CHANGED
@@ -32,11 +32,17 @@ Embedding
 - [intfloat/e5-base-v2](#intfloate5-base-v2)
 - [BAAI/bge-base-en-v1.5](#baaibge-base-en-v15)
 - [Snowflake/snowflake-arctic-embed-m-v1.5](#snowflakesnowflake-arctic-embed-m-v15)
+- [sentence-transformers/all-mpnet-base-v2](#sentence-transformersall-mpnet-base-v2)
 
 Sparse embedding
 
 - [opensearch-project/opensearch-neural-sparse-encoding-v1](#opensearch-projectopensearch-neural-sparse-encoding-v1)
 
+Reranking
+
+- [mixedbread-ai/mxbai-rerank-base-v1](#mixedbread-aimxbai-rerank-base-v1)
+- [BAAI/bge-reranker-base](#baaibge-reranker-base)
+
 ### sentence-transformers/all-MiniLM-L6-v2
 
 [Docs](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2)
@@ -142,6 +148,17 @@ model = Transformers.pipeline("embedding", "Snowflake/snowflake-arctic-embed-m-v1.5")
 embeddings = model.(input, pooling: "cls")
 ```
 
+### sentence-transformers/all-mpnet-base-v2
+
+[Docs](https://huggingface.co/sentence-transformers/all-mpnet-base-v2)
+
+```ruby
+sentences = ["This is an example sentence", "Each sentence is converted"]
+
+model = Transformers.pipeline("embedding", "sentence-transformers/all-mpnet-base-v2")
+embeddings = model.(sentences)
+```
+
 ### opensearch-project/opensearch-neural-sparse-encoding-v1
 
 [Docs](https://huggingface.co/opensearch-project/opensearch-neural-sparse-encoding-v1)
@@ -163,8 +180,37 @@ values[0.., special_token_ids] = 0
 embeddings = values.to_a
 ```
 
+### mixedbread-ai/mxbai-rerank-base-v1
+
+[Docs](https://huggingface.co/mixedbread-ai/mxbai-rerank-base-v1)
+
+```ruby
+query = "How many people live in London?"
+docs = ["Around 9 Million people live in London", "London is known for its financial district"]
+
+model = Transformers.pipeline("reranking", "mixedbread-ai/mxbai-rerank-base-v1")
+result = model.(query, docs)
+```
+
+### BAAI/bge-reranker-base
+
+[Docs](https://huggingface.co/BAAI/bge-reranker-base)
+
+```ruby
+query = "How many people live in London?"
+docs = ["Around 9 Million people live in London", "London is known for its financial district"]
+
+model = Transformers.pipeline("reranking", "BAAI/bge-reranker-base")
+result = model.(query, docs)
+```
+
 ## Pipelines
 
+- [Text](#text)
+- [Vision](#vision)
+
+### Text
+
 Embedding
 
 ```ruby
@@ -172,6 +218,13 @@ embed = Transformers.pipeline("embedding")
 embed.("We are very happy to show you the 🤗 Transformers library.")
 ```
 
+Reranking
+
+```ruby
+rerank = Informers.pipeline("reranking")
+rerank.("Who created Ruby?", ["Matz created Ruby", "Another doc"])
+```
+
 Named-entity recognition
 
 ```ruby
@@ -200,27 +253,32 @@ extractor = Transformers.pipeline("feature-extraction")
 extractor.("We are very happy to show you the 🤗 Transformers library.")
 ```
 
+### Vision
+
 Image classification
 
 ```ruby
 classifier = Transformers.pipeline("image-classification")
-classifier.(
+classifier.("image.jpg")
 ```
 
 Image feature extraction
 
 ```ruby
 extractor = Transformers.pipeline("image-feature-extraction")
-extractor.(
+extractor.("image.jpg")
 ```
 
 ## API
 
-This library follows the [Transformers Python API](https://huggingface.co/docs/transformers/index).
+This library follows the [Transformers Python API](https://huggingface.co/docs/transformers/index). The following model architectures are currently supported:
 
 - BERT
+- DeBERTa-v2
 - DistilBERT
+- MPNet
 - ViT
+- XLM-RoBERTa
 
 ## History
 
data/lib/transformers/configuration_utils.rb
CHANGED
@@ -91,10 +91,24 @@ module Transformers
       # Config hash
       @commit_hash = kwargs.delete(:_commit_hash)
 
-      #
-      @
-
-
+      # Attention implementation to use, if relevant.
+      @attn_implementation_internal = kwargs.delete(:attn_implementation)
+
+      # Drop the transformers version info
+      @transformers_version = kwargs.delete(:transformers_version)
+
+      # Deal with gradient checkpointing
+      # if kwargs[:gradient_checkpointing] == false
+      #   warn(
+      #     "Passing `gradient_checkpointing` to a config initialization is deprecated and will be removed in v5 " +
+      #     "Transformers. Using `model.gradient_checkpointing_enable()` instead, or if you are using the " +
+      #     "`Trainer` API, pass `gradient_checkpointing: true` in your `TrainingArguments`."
+      #   )
+      # end
+
+      kwargs.each do |k, v|
+        instance_variable_set("@#{k}", v)
+      end
     end
 
     def name_or_path
@@ -182,6 +196,20 @@ module Transformers
       JSON.pretty_generate(config_dict.sort_by { |k, _| k }.to_h) + "\n"
     end
 
+    def getattr(key, default)
+      if respond_to?(key)
+        public_send(key)
+      elsif instance_variable_defined?("@#{key}")
+        instance_variable_get("@#{key}")
+      else
+        default
+      end
+    end
+
+    def hasattr(key)
+      respond_to?(key) || instance_variable_defined?("@#{key}")
+    end
+
     class << self
       def from_pretrained(
         pretrained_model_name_or_path,
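The new `getattr`/`hasattr` helpers give configs Python-style attribute lookup with a default. A minimal usage sketch, using the `DebertaV2Config` class introduced in this release (the example itself is illustrative, not taken from the package docs):

```ruby
# Illustrative only: exercising the getattr/hasattr helpers added above.
require "transformers"

config = Transformers::DebertaV2::DebertaV2Config.new(hidden_size: 768)

config.hasattr(:hidden_size)         # => true, a public reader exists
config.getattr(:hidden_size, nil)    # => 768, read via the public reader
config.getattr(:missing_option, 42)  # => 42, falls back to the supplied default
```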
data/lib/transformers/modeling_utils.rb
CHANGED
@@ -207,7 +207,7 @@ module Transformers
 
     def init_weights
       # Prune heads if needed
-      if @config.pruned_heads
+      if @config.pruned_heads.any?
         prune_heads(@config.pruned_heads)
       end
 
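The `.any?` guard matters because an empty Hash is truthy in Ruby, so the old condition entered the pruning branch even when no heads were configured. A plain-Ruby illustration (not from the package):

```ruby
pruned_heads = {}
puts "pruning" if pruned_heads        # prints "pruning" — an empty Hash is truthy
puts "pruning" if pruned_heads.any?   # prints nothing — .any? is false for an empty Hash
```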
@@ -803,11 +803,18 @@ module Transformers
          raise Todo
        end
 
+        model_class_name = model.class.name.split("::").last
+
        if error_msgs.length > 0
-
+          error_msg = error_msgs.join("\n\t")
+          if error_msg.include?("size mismatch")
+            error_msg += (
+              "\n\tYou may consider adding `ignore_mismatched_sizes: true` in the model `from_pretrained` method."
+            )
+          end
+          raise RuntimeError, "Error(s) in loading state_dict for #{model_class_name}:\n\t#{error_msg}"
        end
 
-        model_class_name = model.class.name.split("::").last
        if unexpected_keys.length > 0
          archs = model.config.architectures.nil? ? [] : model.config.architectures
          warner = archs.include?(model_class_name) ? Transformers.logger.method(:warn) : Transformers.logger.method(:info)
data/lib/transformers/models/auto/auto_factory.rb
CHANGED
@@ -116,7 +116,7 @@ module Transformers
     def _load_attr_from_module(model_type, attr)
       module_name = model_type_to_module_name(model_type)
       if !@modules.include?(module_name)
-        @modules[module_name] = Transformers.const_get(module_name.capitalize)
+        @modules[module_name] = Transformers.const_get(module_name.split("-").map(&:capitalize).join)
       end
       getattribute_from_module(@modules[module_name], attr)
     end
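The same constant-lookup fix appears in the auto config and tokenizer mappings below; it is what lets hyphenated model types like `deberta-v2` and `xlm-roberta` resolve to a module constant. A sketch of the string transformation in plain Ruby (the resulting constant names are shown for illustration only):

```ruby
"deberta-v2".capitalize                          # => "Deberta-v2" — not a valid constant name
"deberta-v2".split("-").map(&:capitalize).join   # => "DebertaV2"
"xlm-roberta".split("-").map(&:capitalize).join  # => "XlmRoberta"
```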
data/lib/transformers/models/auto/configuration_auto.rb
CHANGED
@@ -15,8 +15,11 @@
 module Transformers
   CONFIG_MAPPING_NAMES = {
     "bert" => "BertConfig",
+    "deberta-v2" => "DebertaV2Config",
     "distilbert" => "DistilBertConfig",
-    "
+    "mpnet" => "MPNetConfig",
+    "vit" => "ViTConfig",
+    "xlm-roberta" => "XLMRobertaConfig"
   }
 
   class LazyConfigMapping
@@ -30,7 +33,7 @@ module Transformers
       value = @mapping.fetch(key)
       module_name = model_type_to_module_name(key)
       if !@modules.include?(module_name)
-        @modules[module_name] = Transformers.const_get(module_name.capitalize)
+        @modules[module_name] = Transformers.const_get(module_name.split("-").map(&:capitalize).join)
       end
       @modules[module_name].const_get(value)
     end
data/lib/transformers/models/auto/modeling_auto.rb
CHANGED
@@ -15,16 +15,22 @@
 module Transformers
   MODEL_MAPPING_NAMES = {
     "bert" => "BertModel",
+    "deberta-v2" => "DebertaV2Model",
     "distilbert" => "DistilBertModel",
-    "
+    "mpnet" => "MPNetModel",
+    "vit" => "ViTModel",
+    "xlm-roberta" => "XLMRobertaModel"
   }
 
   MODEL_FOR_MASKED_LM_MAPPING_NAMES = {
-    "bert" => "BertForMaskedLM"
+    "bert" => "BertForMaskedLM",
+    "mpnet" => "MPNetForMaskedLM"
   }
 
   MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = {
-    "
+    "deberta-v2" => "DebertaV2ForSequenceClassification",
+    "distilbert" => "DistilBertForSequenceClassification",
+    "xlm-roberta" => "XLMRobertaForSequenceClassification"
   }
 
   MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = {
data/lib/transformers/models/auto/tokenization_auto.rb
CHANGED
@@ -15,7 +15,10 @@
 module Transformers
   TOKENIZER_MAPPING_NAMES = {
     "bert" => ["BertTokenizer", "BertTokenizerFast"],
-    "
+    "deberta-v2" => ["DebertaV2TokenizerFast"],
+    "distilbert" => ["DistilBertTokenizer", "DistilBertTokenizerFast"],
+    "mpnet" => ["MPNetTokenizerFast"],
+    "xlm-roberta" => ["XLMRobertaTokenizerFast"]
   }
 
   TOKENIZER_MAPPING = LazyAutoMapping.new(CONFIG_MAPPING_NAMES, TOKENIZER_MAPPING_NAMES)
@@ -98,7 +101,7 @@ module Transformers
 
     TOKENIZER_MAPPING_NAMES.each do |module_name, tokenizers|
       if tokenizers.include?(class_name)
-        cls = Transformers.const_get(module_name.capitalize).const_get(class_name)
+        cls = Transformers.const_get(module_name.split("-").map(&:capitalize).join).const_get(class_name)
         raise Error, "Invalid tokenizer class: #{class_name}" unless cls < PreTrainedTokenizer || cls < PreTrainedTokenizerFast
         return cls
       end
data/lib/transformers/models/deberta_v2/configuration_deberta_v2.rb
ADDED
@@ -0,0 +1,80 @@
+# Copyright 2020, Microsoft and the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+module Transformers
+  module DebertaV2
+    class DebertaV2Config < PretrainedConfig
+      self.model_type = "deberta-v2"
+
+      attr_reader :vocab_size, :hidden_size, :num_hidden_layers, :num_attention_heads,
+        :intermediate_size, :hidden_act, :hidden_dropout_prob, :attention_probs_dropout_prob,
+        :max_position_embeddings, :type_vocab_size, :initializer_range, :layer_norm_eps,
+        :relative_attention, :max_relative_positions, :pad_token_id, :position_biased_input,
+        :pos_att_type, :pooler_dropout, :pooler_hidden_act, :pooler_hidden_size
+
+      def initialize(
+        vocab_size: 128100,
+        hidden_size: 1536,
+        num_hidden_layers: 24,
+        num_attention_heads: 24,
+        intermediate_size: 6144,
+        hidden_act: "gelu",
+        hidden_dropout_prob: 0.1,
+        attention_probs_dropout_prob: 0.1,
+        max_position_embeddings: 512,
+        type_vocab_size: 0,
+        initializer_range: 0.02,
+        layer_norm_eps: 1e-07,
+        relative_attention: false,
+        max_relative_positions: -1,
+        pad_token_id: 0,
+        position_biased_input: true,
+        pos_att_type: nil,
+        pooler_dropout: 0,
+        pooler_hidden_act: "gelu",
+        **kwargs
+      )
+        super(**kwargs)
+
+        @hidden_size = hidden_size
+        @num_hidden_layers = num_hidden_layers
+        @num_attention_heads = num_attention_heads
+        @intermediate_size = intermediate_size
+        @hidden_act = hidden_act
+        @hidden_dropout_prob = hidden_dropout_prob
+        @attention_probs_dropout_prob = attention_probs_dropout_prob
+        @max_position_embeddings = max_position_embeddings
+        @type_vocab_size = type_vocab_size
+        @initializer_range = initializer_range
+        @relative_attention = relative_attention
+        @max_relative_positions = max_relative_positions
+        @pad_token_id = pad_token_id
+        @position_biased_input = position_biased_input
+
+        # Backwards compatibility
+        if pos_att_type.is_a?(String)
+          pos_att_type = pos_att_type.downcase.split("|").map { |x| x.strip }
+        end
+
+        @pos_att_type = pos_att_type
+        @vocab_size = vocab_size
+        @layer_norm_eps = layer_norm_eps
+
+        @pooler_hidden_size = kwargs[:pooler_hidden_size] || hidden_size
+        @pooler_dropout = pooler_dropout
+        @pooler_hidden_act = pooler_hidden_act
+      end
+    end
+  end
+end