transformers-rb 0.1.1 → 0.1.3
This diff shows the changes between publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.
- checksums.yaml +4 -4
- data/CHANGELOG.md +11 -0
- data/README.md +64 -3
- data/lib/transformers/configuration_utils.rb +32 -4
- data/lib/transformers/modeling_utils.rb +10 -3
- data/lib/transformers/models/auto/auto_factory.rb +1 -1
- data/lib/transformers/models/auto/configuration_auto.rb +5 -2
- data/lib/transformers/models/auto/modeling_auto.rb +9 -3
- data/lib/transformers/models/auto/tokenization_auto.rb +5 -2
- data/lib/transformers/models/deberta_v2/configuration_deberta_v2.rb +80 -0
- data/lib/transformers/models/deberta_v2/modeling_deberta_v2.rb +1210 -0
- data/lib/transformers/models/deberta_v2/tokenization_deberta_v2_fast.rb +78 -0
- data/lib/transformers/models/mpnet/configuration_mpnet.rb +61 -0
- data/lib/transformers/models/mpnet/modeling_mpnet.rb +792 -0
- data/lib/transformers/models/mpnet/tokenization_mpnet_fast.rb +106 -0
- data/lib/transformers/models/xlm_roberta/configuration_xlm_roberta.rb +68 -0
- data/lib/transformers/models/xlm_roberta/modeling_xlm_roberta.rb +1216 -0
- data/lib/transformers/models/xlm_roberta/tokenization_xlm_roberta_fast.rb +68 -0
- data/lib/transformers/pipelines/_init.rb +16 -5
- data/lib/transformers/pipelines/reranking.rb +33 -0
- data/lib/transformers/version.rb +1 -1
- data/lib/transformers.rb +16 -0
- metadata +15 -5
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
 ---
 SHA256:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: 3f070b9828c5c5ad71c75f46ca9daf1387a5ec3848cb406aac9e5f1bbc1d4531
+  data.tar.gz: 31b28a5a87c58db6fc3146e390e8a4a7bf1ffc34ede6d3cd6fcd7f3aa3df2d28
 SHA512:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: aa2055e44b9071a425ebfb59d6b2edbedce1f3cf97e0baa55d1280451c1c1db097a52b0b9615a188b1d96f0854e557fb4cb769b05cb3af4db229cd3fcdf8fb95
+  data.tar.gz: 1af002f238e9189a2e2a6b5f1aafc9201cfd5bc5f8afe4a80b81757b5d9f5d4fa52bc61a57b4fdd6920bd3692f704398aa58c7cc4fd797bd881ab9887c9c77f9
data/CHANGELOG.md
CHANGED
data/README.md
CHANGED
@@ -32,8 +32,17 @@ Embedding
 - [intfloat/e5-base-v2](#intfloate5-base-v2)
 - [BAAI/bge-base-en-v1.5](#baaibge-base-en-v15)
 - [Snowflake/snowflake-arctic-embed-m-v1.5](#snowflakesnowflake-arctic-embed-m-v15)
+- [sentence-transformers/all-mpnet-base-v2](#sentence-transformersall-mpnet-base-v2)
+
+Sparse embedding
+
 - [opensearch-project/opensearch-neural-sparse-encoding-v1](#opensearch-projectopensearch-neural-sparse-encoding-v1)
 
+Reranking
+
+- [mixedbread-ai/mxbai-rerank-base-v1](#mixedbread-aimxbai-rerank-base-v1)
+- [BAAI/bge-reranker-base](#baaibge-reranker-base)
+
 ### sentence-transformers/all-MiniLM-L6-v2
 
 [Docs](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2)
@@ -139,6 +148,17 @@ model = Transformers.pipeline("embedding", "Snowflake/snowflake-arctic-embed-m-v1.5")
 embeddings = model.(input, pooling: "cls")
 ```
 
+### sentence-transformers/all-mpnet-base-v2
+
+[Docs](https://huggingface.co/sentence-transformers/all-mpnet-base-v2)
+
+```ruby
+sentences = ["This is an example sentence", "Each sentence is converted"]
+
+model = Transformers.pipeline("embedding", "sentence-transformers/all-mpnet-base-v2")
+embeddings = model.(sentences)
+```
+
 ### opensearch-project/opensearch-neural-sparse-encoding-v1
 
 [Docs](https://huggingface.co/opensearch-project/opensearch-neural-sparse-encoding-v1)
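The new all-mpnet-base-v2 snippet returns one embedding per input sentence. As a hedged follow-up sketch (assuming the pipeline returns plain Ruby arrays of floats, as the sparse-encoding example's `values.to_a` suggests), the two vectors can be compared with cosine similarity:

```ruby
# Toy vectors standing in for model.(sentences) output (assumed Array<Array<Float>>).
embeddings = [[0.1, 0.3, 0.5], [0.2, 0.1, 0.4]]

# Cosine similarity between two embedding vectors, in plain Ruby.
def cosine_similarity(a, b)
  dot = a.zip(b).sum { |x, y| x * y }
  dot / (Math.sqrt(a.sum { |x| x * x }) * Math.sqrt(b.sum { |x| x * x }))
end

cosine_similarity(embeddings[0], embeddings[1]) # => ~0.92, higher means more similar
```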
@@ -160,8 +180,37 @@ values[0.., special_token_ids] = 0
 embeddings = values.to_a
 ```
 
+### mixedbread-ai/mxbai-rerank-base-v1
+
+[Docs](https://huggingface.co/mixedbread-ai/mxbai-rerank-base-v1)
+
+```ruby
+query = "How many people live in London?"
+docs = ["Around 9 Million people live in London", "London is known for its financial district"]
+
+model = Transformers.pipeline("reranking", "mixedbread-ai/mxbai-rerank-base-v1")
+result = model.(query, docs)
+```
+
+### BAAI/bge-reranker-base
+
+[Docs](https://huggingface.co/BAAI/bge-reranker-base)
+
+```ruby
+query = "How many people live in London?"
+docs = ["Around 9 Million people live in London", "London is known for its financial district"]
+
+model = Transformers.pipeline("reranking", "BAAI/bge-reranker-base")
+result = model.(query, docs)
+```
+
 ## Pipelines
 
+- [Text](#text)
+- [Vision](#vision)
+
+### Text
+
 Embedding
 
 ```ruby
@@ -169,6 +218,13 @@ embed = Transformers.pipeline("embedding")
 embed.("We are very happy to show you the 🤗 Transformers library.")
 ```
 
+Reranking
+
+```ruby
+rerank = Transformers.pipeline("reranking")
+rerank.("Who created Ruby?", ["Matz created Ruby", "Another doc"])
+```
+
 Named-entity recognition
 
 ```ruby
@@ -197,27 +253,32 @@ extractor = Transformers.pipeline("feature-extraction")
 extractor.("We are very happy to show you the 🤗 Transformers library.")
 ```
 
+### Vision
+
 Image classification
 
 ```ruby
 classifier = Transformers.pipeline("image-classification")
-classifier.(
+classifier.("image.jpg")
 ```
 
 Image feature extraction
 
 ```ruby
 extractor = Transformers.pipeline("image-feature-extraction")
-extractor.(
+extractor.("image.jpg")
 ```
 
 ## API
 
-This library follows the [Transformers Python API](https://huggingface.co/docs/transformers/index).
+This library follows the [Transformers Python API](https://huggingface.co/docs/transformers/index). The following model architectures are currently supported:
 
 - BERT
+- DeBERTa-v2
 - DistilBERT
+- MPNet
 - ViT
+- XLM-RoBERTa
 
 ## History
 
data/lib/transformers/configuration_utils.rb
CHANGED
@@ -91,10 +91,24 @@ module Transformers
     # Config hash
     @commit_hash = kwargs.delete(:_commit_hash)
 
-    #
-    @
-
-
+    # Attention implementation to use, if relevant.
+    @attn_implementation_internal = kwargs.delete(:attn_implementation)
+
+    # Drop the transformers version info
+    @transformers_version = kwargs.delete(:transformers_version)
+
+    # Deal with gradient checkpointing
+    # if kwargs[:gradient_checkpointing] == false
+    #   warn(
+    #     "Passing `gradient_checkpointing` to a config initialization is deprecated and will be removed in v5 " +
+    #     "Transformers. Using `model.gradient_checkpointing_enable()` instead, or if you are using the " +
+    #     "`Trainer` API, pass `gradient_checkpointing: true` in your `TrainingArguments`."
+    #   )
+    # end
+
+    kwargs.each do |k, v|
+      instance_variable_set("@#{k}", v)
+    end
   end
 
   def name_or_path
@@ -182,6 +196,20 @@ module Transformers
     JSON.pretty_generate(config_dict.sort_by { |k, _| k }.to_h) + "\n"
   end
 
+  def getattr(key, default)
+    if respond_to?(key)
+      public_send(key)
+    elsif instance_variable_defined?("@#{key}")
+      instance_variable_get("@#{key}")
+    else
+      default
+    end
+  end
+
+  def hasattr(key)
+    respond_to?(key) || instance_variable_defined?("@#{key}")
+  end
+
   class << self
     def from_pretrained(
       pretrained_model_name_or_path,
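The new `getattr`/`hasattr` helpers mirror Python-style attribute access: prefer a public reader, fall back to the instance variable, then to a default. A minimal usage sketch, assuming the gem is loaded and using `BertConfig` purely as an illustrative config class (not part of this diff):

```ruby
require "transformers-rb"

config = Transformers::Bert::BertConfig.new
config.getattr(:hidden_size, nil)  # reader is defined, so public_send is used
config.getattr(:no_such_key, 42)   # => 42, falls back to the default
config.hasattr(:vocab_size)        # => true
config.hasattr(:no_such_key)       # => false
```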
data/lib/transformers/modeling_utils.rb
CHANGED
@@ -207,7 +207,7 @@ module Transformers
 
     def init_weights
       # Prune heads if needed
-      if @config.pruned_heads
+      if @config.pruned_heads.any?
         prune_heads(@config.pruned_heads)
       end
 
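This change fixes a truthiness bug: `pruned_heads` is a collection (empty when nothing is pruned), and in Ruby even an empty Hash is truthy, so the old condition always took the branch. A quick illustration:

```ruby
pruned_heads = {}
puts "pruning" if pruned_heads       # prints: an empty Hash is still truthy
puts "pruning" if pruned_heads.any?  # prints nothing: no entries to prune
```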
@@ -803,11 +803,18 @@ module Transformers
       raise Todo
     end
 
+    model_class_name = model.class.name.split("::").last
+
     if error_msgs.length > 0
-
+      error_msg = error_msgs.join("\n\t")
+      if error_msg.include?("size mismatch")
+        error_msg += (
+          "\n\tYou may consider adding `ignore_mismatched_sizes: true` in the model `from_pretrained` method."
+        )
+      end
+      raise RuntimeError, "Error(s) in loading state_dict for #{model_class_name}:\n\t#{error_msg}"
    end
 
-    model_class_name = model.class.name.split("::").last
     if unexpected_keys.length > 0
       archs = model.config.architectures.nil? ? [] : model.config.architectures
       warner = archs.include?(model_class_name) ? Transformers.logger.method(:warn) : Transformers.logger.method(:info)
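The reworked branch now raises instead of falling through silently, and appends a hint when the failure is a size mismatch. A toy run of the same string-building logic (the error text and model name here are hypothetical, just to show the resulting message):

```ruby
# Hypothetical error list, mimicking what state_dict loading might collect.
error_msgs = ["size mismatch for embeddings.weight"]

error_msg = error_msgs.join("\n\t")
if error_msg.include?("size mismatch")
  error_msg += "\n\tYou may consider adding `ignore_mismatched_sizes: true` in the model `from_pretrained` method."
end

begin
  raise RuntimeError, "Error(s) in loading state_dict for BertModel:\n\t#{error_msg}"
rescue RuntimeError => e
  puts e.message # includes the size-mismatch hint appended above
end
```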
data/lib/transformers/models/auto/auto_factory.rb
CHANGED
@@ -116,7 +116,7 @@ module Transformers
     def _load_attr_from_module(model_type, attr)
       module_name = model_type_to_module_name(model_type)
       if !@modules.include?(module_name)
-        @modules[module_name] = Transformers.const_get(module_name.capitalize)
+        @modules[module_name] = Transformers.const_get(module_name.split("-").map(&:capitalize).join)
       end
       getattribute_from_module(@modules[module_name], attr)
     end
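This is the fix that makes hyphenated model types resolvable: plain `capitalize` turns `"deberta-v2"` into `"Deberta-v2"`, which is not a valid constant name. The new expression capitalizes each hyphen-separated segment and joins them. What it computes, in plain Ruby:

```ruby
# How the new lookup derives a module constant name from a model type.
["bert", "deberta-v2", "xlm-roberta"].map do |model_type|
  model_type.split("-").map(&:capitalize).join
end
# => ["Bert", "DebertaV2", "XlmRoberta"]
```

The same one-line change is repeated below in configuration_auto.rb and tokenization_auto.rb.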
data/lib/transformers/models/auto/configuration_auto.rb
CHANGED
@@ -15,8 +15,11 @@
 module Transformers
   CONFIG_MAPPING_NAMES = {
     "bert" => "BertConfig",
+    "deberta-v2" => "DebertaV2Config",
     "distilbert" => "DistilBertConfig",
-    "
+    "mpnet" => "MPNetConfig",
+    "vit" => "ViTConfig",
+    "xlm-roberta" => "XLMRobertaConfig"
   }
 
   class LazyConfigMapping
@@ -30,7 +33,7 @@ module Transformers
       value = @mapping.fetch(key)
       module_name = model_type_to_module_name(key)
       if !@modules.include?(module_name)
-        @modules[module_name] = Transformers.const_get(module_name.capitalize)
+        @modules[module_name] = Transformers.const_get(module_name.split("-").map(&:capitalize).join)
       end
       @modules[module_name].const_get(value)
     end
data/lib/transformers/models/auto/modeling_auto.rb
CHANGED
@@ -15,16 +15,22 @@
 module Transformers
   MODEL_MAPPING_NAMES = {
     "bert" => "BertModel",
+    "deberta-v2" => "DebertaV2Model",
     "distilbert" => "DistilBertModel",
-    "
+    "mpnet" => "MPNetModel",
+    "vit" => "ViTModel",
+    "xlm-roberta" => "XLMRobertaModel"
   }
 
   MODEL_FOR_MASKED_LM_MAPPING_NAMES = {
-    "bert" => "BertForMaskedLM"
+    "bert" => "BertForMaskedLM",
+    "mpnet" => "MPNetForMaskedLM"
   }
 
   MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = {
-    "
+    "deberta-v2" => "DebertaV2ForSequenceClassification",
+    "distilbert" => "DistilBertForSequenceClassification",
+    "xlm-roberta" => "XLMRobertaForSequenceClassification"
   }
 
   MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = {
data/lib/transformers/models/auto/tokenization_auto.rb
CHANGED
@@ -15,7 +15,10 @@
 module Transformers
   TOKENIZER_MAPPING_NAMES = {
     "bert" => ["BertTokenizer", "BertTokenizerFast"],
-    "
+    "deberta-v2" => ["DebertaV2TokenizerFast"],
+    "distilbert" => ["DistilBertTokenizer", "DistilBertTokenizerFast"],
+    "mpnet" => ["MPNetTokenizerFast"],
+    "xlm-roberta" => ["XLMRobertaTokenizerFast"]
   }
 
   TOKENIZER_MAPPING = LazyAutoMapping.new(CONFIG_MAPPING_NAMES, TOKENIZER_MAPPING_NAMES)
@@ -98,7 +101,7 @@ module Transformers
 
       TOKENIZER_MAPPING_NAMES.each do |module_name, tokenizers|
         if tokenizers.include?(class_name)
-          cls = Transformers.const_get(module_name.capitalize).const_get(class_name)
+          cls = Transformers.const_get(module_name.split("-").map(&:capitalize).join).const_get(class_name)
           raise Error, "Invalid tokenizer class: #{class_name}" unless cls < PreTrainedTokenizer || cls < PreTrainedTokenizerFast
           return cls
         end
data/lib/transformers/models/deberta_v2/configuration_deberta_v2.rb
ADDED
@@ -0,0 +1,80 @@
+# Copyright 2020, Microsoft and the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+module Transformers
+  module DebertaV2
+    class DebertaV2Config < PretrainedConfig
+      self.model_type = "deberta-v2"
+
+      attr_reader :vocab_size, :hidden_size, :num_hidden_layers, :num_attention_heads,
+        :intermediate_size, :hidden_act, :hidden_dropout_prob, :attention_probs_dropout_prob,
+        :max_position_embeddings, :type_vocab_size, :initializer_range, :layer_norm_eps,
+        :relative_attention, :max_relative_positions, :pad_token_id, :position_biased_input,
+        :pos_att_type, :pooler_dropout, :pooler_hidden_act, :pooler_hidden_size
+
+      def initialize(
+        vocab_size: 128100,
+        hidden_size: 1536,
+        num_hidden_layers: 24,
+        num_attention_heads: 24,
+        intermediate_size: 6144,
+        hidden_act: "gelu",
+        hidden_dropout_prob: 0.1,
+        attention_probs_dropout_prob: 0.1,
+        max_position_embeddings: 512,
+        type_vocab_size: 0,
+        initializer_range: 0.02,
+        layer_norm_eps: 1e-07,
+        relative_attention: false,
+        max_relative_positions: -1,
+        pad_token_id: 0,
+        position_biased_input: true,
+        pos_att_type: nil,
+        pooler_dropout: 0,
+        pooler_hidden_act: "gelu",
+        **kwargs
+      )
+        super(**kwargs)
+
+        @hidden_size = hidden_size
+        @num_hidden_layers = num_hidden_layers
+        @num_attention_heads = num_attention_heads
+        @intermediate_size = intermediate_size
+        @hidden_act = hidden_act
+        @hidden_dropout_prob = hidden_dropout_prob
+        @attention_probs_dropout_prob = attention_probs_dropout_prob
+        @max_position_embeddings = max_position_embeddings
+        @type_vocab_size = type_vocab_size
+        @initializer_range = initializer_range
+        @relative_attention = relative_attention
+        @max_relative_positions = max_relative_positions
+        @pad_token_id = pad_token_id
+        @position_biased_input = position_biased_input
+
+        # Backwards compatibility
+        if pos_att_type.is_a?(String)
+          pos_att_type = pos_att_type.downcase.split("|").map { |x| x.strip }
+        end
+
+        @pos_att_type = pos_att_type
+        @vocab_size = vocab_size
+        @layer_norm_eps = layer_norm_eps
+
+        @pooler_hidden_size = kwargs[:pooler_hidden_size] || hidden_size
+        @pooler_dropout = pooler_dropout
+        @pooler_hidden_act = pooler_hidden_act
+      end
+    end
+  end
+end
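A minimal sketch of using the new config, showing the backwards-compatibility branch that normalizes a pipe-separated `pos_att_type` string into an array, and the `pooler_hidden_size` default. It assumes the gem is loaded; the override values are illustrative:

```ruby
require "transformers-rb"

# A pipe-separated string is downcased, split, and stripped into an array.
config = Transformers::DebertaV2::DebertaV2Config.new(pos_att_type: "p2c|c2p")
config.pos_att_type # => ["p2c", "c2p"]

# pooler_hidden_size falls back to hidden_size when not passed explicitly.
config = Transformers::DebertaV2::DebertaV2Config.new(hidden_size: 768)
config.pooler_hidden_size # => 768
```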