transformers-rb 0.1.3 → 0.1.4
- checksums.yaml +4 -4
- data/CHANGELOG.md +4 -0
- data/README.md +24 -12
- data/lib/transformers/configuration_utils.rb +3 -1
- data/lib/transformers/hf_hub/constants.rb +1 -1
- data/lib/transformers/modeling_utils.rb +1 -1
- data/lib/transformers/models/auto/configuration_auto.rb +1 -1
- data/lib/transformers/models/auto/modeling_auto.rb +1 -0
- data/lib/transformers/models/bert/modeling_bert.rb +93 -3
- data/lib/transformers/models/distilbert/configuration_distilbert.rb +1 -1
- data/lib/transformers/tokenization_utils.rb +0 -1
- data/lib/transformers/utils/hub.rb +3 -3
- data/lib/transformers/version.rb +1 -1
- metadata +2 -2
checksums.yaml CHANGED

@@ -1,7 +1,7 @@
 ---
 SHA256:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: 16b409e954f6bcc45fd4f3f2db94dc92c87d47c4c06936162978cc7d7e54fc09
+  data.tar.gz: c40ea58b7531e89a041ce2782dea89536f85227903e2dc9a60113afe041bb9f7
 SHA512:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: 0576500ca9fe9379aae4c2cc050aa34c90eea7b2d5251b6139c48d88b5107086e6197aad5210bc74d7969a5bfa0458d31ce7faf561df69b2ae9b2a6400280ce0
+  data.tar.gz: b5b0a865acdd37bcde11571e365a1d39b473f3dcbbe737a082f37002f886fe38f0f57c77f44236f859939d1d3b1df9659bb524a483cd0b9048dbc4cd472a3355
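
These SHA256/SHA512 pairs cover the two archives packed inside the .gem file. A minimal verification sketch, assuming a downloaded transformers-rb-0.1.4.gem in the working directory (Ruby stdlib only):

```ruby
require "digest"
require "rubygems/package"

# A .gem is a tar archive containing metadata.gz, data.tar.gz, and
# checksums.yaml.gz; hash the first two and compare with the values above.
File.open("transformers-rb-0.1.4.gem", "rb") do |io|
  Gem::Package::TarReader.new(io) do |tar|
    tar.each do |entry|
      next unless ["metadata.gz", "data.tar.gz"].include?(entry.full_name)
      puts "#{entry.full_name}: #{Digest::SHA256.hexdigest(entry.read)}"
    end
  end
end
```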
data/CHANGELOG.md CHANGED
data/README.md CHANGED

@@ -27,12 +27,13 @@ Embedding
 
 - [sentence-transformers/all-MiniLM-L6-v2](#sentence-transformersall-MiniLM-L6-v2)
 - [sentence-transformers/multi-qa-MiniLM-L6-cos-v1](#sentence-transformersmulti-qa-MiniLM-L6-cos-v1)
+- [sentence-transformers/all-mpnet-base-v2](#sentence-transformersall-mpnet-base-v2)
+- [sentence-transformers/paraphrase-MiniLM-L6-v2](#sentence-transformersparaphrase-minilm-l6-v2)
 - [mixedbread-ai/mxbai-embed-large-v1](#mixedbread-aimxbai-embed-large-v1)
 - [thenlper/gte-small](#thenlpergte-small)
 - [intfloat/e5-base-v2](#intfloate5-base-v2)
 - [BAAI/bge-base-en-v1.5](#baaibge-base-en-v15)
 - [Snowflake/snowflake-arctic-embed-m-v1.5](#snowflakesnowflake-arctic-embed-m-v15)
-- [sentence-transformers/all-mpnet-base-v2](#sentence-transformersall-mpnet-base-v2)
 
 Sparse embedding
 
@@ -69,6 +70,28 @@ scores = doc_embeddings.map { |e| e.zip(query_embedding).sum { |d, q| d * q } }
 doc_score_pairs = docs.zip(scores).sort_by { |d, s| -s }
 ```
 
+### sentence-transformers/all-mpnet-base-v2
+
+[Docs](https://huggingface.co/sentence-transformers/all-mpnet-base-v2)
+
+```ruby
+sentences = ["This is an example sentence", "Each sentence is converted"]
+
+model = Transformers.pipeline("embedding", "sentence-transformers/all-mpnet-base-v2")
+embeddings = model.(sentences)
+```
+
+### sentence-transformers/paraphrase-MiniLM-L6-v2
+
+[Docs](https://huggingface.co/sentence-transformers/paraphrase-MiniLM-L6-v2)
+
+```ruby
+sentences = ["This is an example sentence", "Each sentence is converted"]
+
+model = Transformers.pipeline("embedding", "sentence-transformers/paraphrase-MiniLM-L6-v2")
+embeddings = model.(sentences)
+```
+
 ### mixedbread-ai/mxbai-embed-large-v1
 
 [Docs](https://huggingface.co/mixedbread-ai/mxbai-embed-large-v1)
@@ -148,17 +171,6 @@ model = Transformers.pipeline("embedding", "Snowflake/snowflake-arctic-embed-m-v
 embeddings = model.(input, pooling: "cls")
 ```
 
-### sentence-transformers/all-mpnet-base-v2
-
-[Docs](https://huggingface.co/sentence-transformers/all-mpnet-base-v2)
-
-```ruby
-sentences = ["This is an example sentence", "Each sentence is converted"]
-
-model = Transformers.pipeline("embedding", "sentence-transformers/all-mpnet-base-v2")
-embeddings = model.(sentences)
-```
-
 ### opensearch-project/opensearch-neural-sparse-encoding-v1
 
 [Docs](https://huggingface.co/opensearch-project/opensearch-neural-sparse-encoding-v1)
data/lib/transformers/configuration_utils.rb CHANGED

@@ -36,7 +36,9 @@ module Transformers
 
     attr_reader :output_hidden_states, :output_attentions, :pruned_heads, :tie_word_embeddings, :tokenizer_class,
       :chunk_size_feed_forward, :pad_token_id, :is_decoder, :add_cross_attention,
-      :
+      :id2label, :architectures, :is_encoder_decoder, :tie_encoder_decoder, :_commit_hash
+
+    attr_accessor :problem_type
 
     def initialize(**kwargs)
      @return_dict = kwargs.delete(:return_dict) { true }
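
The new `attr_accessor :problem_type` is what lets `BertForSequenceClassification` (below) both read and lazily set the loss mode on the shared config. A short sketch of the three accepted values, assuming `AutoConfig.from_pretrained` mirrors the Python API; the checkpoint name is illustrative:

```ruby
config = Transformers::AutoConfig.from_pretrained("bert-base-uncased")

# problem_type selects the loss branch in classification heads:
#   "regression"                  -> MSELoss
#   "single_label_classification" -> CrossEntropyLoss
#   "multi_label_classification"  -> BCEWithLogitsLoss
config.problem_type = "multi_label_classification"
```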
data/lib/transformers/modeling_utils.rb CHANGED

@@ -23,7 +23,7 @@ module Transformers
       "xavier_uniform!" => Torch::NN::Init.method(:xavier_uniform!),
       "xavier_normal!" => Torch::NN::Init.method(:xavier_normal!),
       "kaiming_uniform!" => Torch::NN::Init.method(:kaiming_uniform!),
-      "kaiming_normal!" => Torch::NN::Init.method(:kaiming_normal!)
+      "kaiming_normal!" => Torch::NN::Init.method(:kaiming_normal!)
       # "uniform" => Torch::NN::Init.method(:uniform),
       # "normal" => Torch::NN::Init.method(:normal),
       # "xavier_uniform" => Torch::NN::Init.method(:xavier_uniform),
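
This hash maps initializer names to torch.rb's `Torch::NN::Init` methods; the bang variants mutate a tensor in place. A small sketch using one of the mapped functions directly (torch-rb only):

```ruby
require "torch"

# In-place Kaiming-normal initialization, matching the "kaiming_normal!" entry.
w = Torch.empty(3, 4)
Torch::NN::Init.kaiming_normal!(w)
```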
data/lib/transformers/models/auto/configuration_auto.rb CHANGED

@@ -55,7 +55,7 @@ module Transformers
       config_dict, unused_kwargs = PretrainedConfig.get_config_dict(pretrained_model_name_or_path, **kwargs)
       if config_dict[:model_type]
         config_class = CONFIG_MAPPING[config_dict[:model_type]]
-
+        config_class.from_dict(config_dict, **unused_kwargs)
       else
         raise Todo
       end
data/lib/transformers/models/auto/modeling_auto.rb CHANGED

@@ -28,6 +28,7 @@ module Transformers
     }
 
     MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = {
+      "bert" => "BertForSequenceClassification",
       "deberta-v2" => "DebertaV2ForSequenceClassification",
       "distilbert" => "DistilBertForSequenceClassification",
       "xlm-roberta" => "XLMRobertaForSequenceClassification"
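
Registering "bert" here is the piece that lets the auto classes resolve a BERT checkpoint's `model_type` to the new head. A hedged sketch, assuming `AutoModelForSequenceClassification.from_pretrained` follows the same convention as the gem's other auto classes; the checkpoint name is illustrative:

```ruby
# Looks up config.model_type ("bert") in the mapping above and
# instantiates BertForSequenceClassification with the checkpoint weights.
model = Transformers::AutoModelForSequenceClassification.from_pretrained(
  "textattack/bert-base-uncased-SST-2"
)
```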
data/lib/transformers/models/bert/modeling_bert.rb CHANGED

@@ -353,7 +353,7 @@ module Transformers
       def feed_forward_chunk(attention_output)
         intermediate_output = @intermediate.(attention_output)
         layer_output = @output.(intermediate_output, attention_output)
-
+        layer_output
       end
     end
 
@@ -370,7 +370,7 @@ module Transformers
         attention_mask: nil,
         head_mask: nil,
         encoder_hidden_states: nil,
-        encoder_attention_mask:nil,
+        encoder_attention_mask: nil,
         past_key_values: nil,
         use_cache: nil,
         output_attentions: false,
@@ -814,7 +814,7 @@ module Transformers
         loss = nil
         if !labels.nil?
           loss_fct = CrossEntropyLoss.new
-          loss = loss_fct.(logits.view(-1
+          loss = loss_fct.(logits.view(-1, @num_labels), labels.view(-1))
         end
 
         if !return_dict
@@ -829,8 +829,98 @@ module Transformers
         )
       end
     end
+
+    class BertForSequenceClassification < BertPreTrainedModel
+      def initialize(config)
+        super
+        @num_labels = config.num_labels
+        @config = config
+
+        @bert = BertModel.new(config, add_pooling_layer: true)
+        classifier_dropout = (
+          config.classifier_dropout.nil? ? config.hidden_dropout_prob : config.classifier_dropout
+        )
+        @dropout = Torch::NN::Dropout.new(p: classifier_dropout)
+        @classifier = Torch::NN::Linear.new(config.hidden_size, config.num_labels)
+
+        # Initialize weights and apply final processing
+        post_init
+      end
+
+      def forward(
+        input_ids: nil,
+        attention_mask: nil,
+        token_type_ids: nil,
+        position_ids: nil,
+        head_mask: nil,
+        inputs_embeds: nil,
+        labels: nil,
+        output_attentions: nil,
+        output_hidden_states: nil,
+        return_dict: nil
+      )
+        return_dict = @config.use_return_dict if return_dict.nil?
+
+        outputs = @bert.(
+          input_ids: input_ids,
+          attention_mask: attention_mask,
+          token_type_ids: token_type_ids,
+          position_ids: position_ids,
+          head_mask: head_mask,
+          inputs_embeds: inputs_embeds,
+          output_attentions: output_attentions,
+          output_hidden_states: output_hidden_states,
+          return_dict: return_dict
+        )
+
+        pooled_output = outputs[1]
+
+        pooled_output = @dropout.(pooled_output)
+        logits = @classifier.(pooled_output)
+
+        loss = nil
+        if !labels.nil?
+          if @config.problem_type.nil?
+            if @num_labels == 1
+              @config.problem_type = "regression"
+            elsif @num_labels > 1 && (labels.dtype == Torch.long || labels.dtype == Torch.int)
+              @config.problem_type = "single_label_classification"
+            else
+              @config.problem_type = "multi_label_classification"
+            end
+          end
+
+          if @config.problem_type == "regression"
+            loss_fct = Torch::NN::MSELoss.new
+            if @num_labels == 1
+              loss = loss_fct.(logits.squeeze, labels.squeeze)
+            else
+              loss = loss_fct.(logits, labels)
+            end
+          elsif @config.problem_type == "single_label_classification"
+            loss_fct = Torch::NN::CrossEntropyLoss.new
+            loss = loss_fct.(logits.view(-1, @num_labels), labels.view(-1))
+          elsif @config.problem_type == "multi_label_classification"
+            loss_fct = Torch::NN::BCEWithLogitsLoss.new
+            loss = loss_fct.(logits, labels)
+          end
+        end
+
+        if !return_dict
+          raise Todo
+        end
+
+        SequenceClassifierOutput.new(
+          loss: loss,
+          logits: logits,
+          hidden_states: outputs.hidden_states,
+          attentions: outputs.attentions
+        )
+      end
+    end
   end
 
   BertModel = Bert::BertModel
   BertForTokenClassification = Bert::BertForTokenClassification
+  BertForSequenceClassification = Bert::BertForSequenceClassification
 end
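
A hedged end-to-end sketch of the new class. It assumes the tokenizer and model calling conventions used elsewhere in the gem (`tokenizer.(...)` returning an encoding indexable by symbol, models invoked with `.()`); the fine-tuned checkpoint name is illustrative:

```ruby
repo = "textattack/bert-base-uncased-SST-2" # illustrative checkpoint
model = Transformers::BertForSequenceClassification.from_pretrained(repo)
tokenizer = Transformers::AutoTokenizer.from_pretrained(repo)

inputs = tokenizer.("This release works great!", return_tensors: "pt")
outputs = model.(input_ids: inputs[:input_ids], attention_mask: inputs[:attention_mask])

# One row of logits per input; for a single input, argmax is the label id.
puts Torch.argmax(outputs.logits).item
```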
data/lib/transformers/models/distilbert/configuration_distilbert.rb CHANGED

@@ -19,7 +19,7 @@ module Transformers
     self.attribute_map = {
       hidden_size: "dim",
       num_attention_heads: "n_heads",
-      num_hidden_layers: "n_layers"
+      num_hidden_layers: "n_layers"
     }
 
     attr_reader :vocab_size, :max_position_embeddings, :sinusoidal_pos_embds, :n_layers, :n_heads,
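
The `attribute_map` aliases the canonical config names to DistilBERT's native ones, so code written against the standard names still works. A sketch of the intent, assuming `AutoConfig` resolution (illustrative checkpoint):

```ruby
config = Transformers::AutoConfig.from_pretrained("distilbert-base-uncased")
config.hidden_size # resolves to the stored "dim" via the attribute_map above
```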
data/lib/transformers/utils/hub.rb CHANGED

@@ -181,9 +181,9 @@ module Transformers
         proxies: proxies,
         timeout: 10
       )
-
+      true
     rescue HfHub::OfflineModeIsEnabled
-
+      has_file_in_cache
     rescue HfHub::GatedRepoError => e
       Transformers.logger.error(e)
       raise EnvironmentError,
@@ -200,7 +200,7 @@ module Transformers
         "#{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for this " +
         "model name. Check the model page at 'https://huggingface.co/#{path_or_repo}' for available revisions."
     rescue HfHub::EntryNotFoundError
-
+      false # File does not exist
     end
   end
 end
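
Net effect of these two hunks: the file-existence check now returns an explicit boolean on every path, `true` after a successful request, the cached answer (`has_file_in_cache`) when offline mode is enabled, and `false` when the hub reports the entry missing. A sketch of a caller, with the `has_file` name and module path assumed from this file's context rather than confirmed:

```ruby
# Hypothetical caller: prefer safetensors weights when the repo has them.
weights = if Transformers::Utils::Hub.has_file("bert-base-uncased", "model.safetensors")
  "model.safetensors"
else
  "pytorch_model.bin"
end
```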
data/lib/transformers/version.rb CHANGED
metadata CHANGED

@@ -1,14 +1,14 @@
 --- !ruby/object:Gem::Specification
 name: transformers-rb
 version: !ruby/object:Gem::Version
-  version: 0.1.3
+  version: 0.1.4
 platform: ruby
 authors:
 - Andrew Kane
 autorequire:
 bindir: bin
 cert_chain: []
-date: 2024-
+date: 2024-10-22 00:00:00.000000000 Z
 dependencies:
 - !ruby/object:Gem::Dependency
   name: numo-narray