transformers-rb 0.1.3 → 0.1.4
This diff shows the published contents of transformers-rb 0.1.3 and 0.1.4 as they appear in their public registry, and is provided for informational purposes only.
- checksums.yaml +4 -4
- data/CHANGELOG.md +4 -0
- data/README.md +24 -12
- data/lib/transformers/configuration_utils.rb +3 -1
- data/lib/transformers/hf_hub/constants.rb +1 -1
- data/lib/transformers/modeling_utils.rb +1 -1
- data/lib/transformers/models/auto/configuration_auto.rb +1 -1
- data/lib/transformers/models/auto/modeling_auto.rb +1 -0
- data/lib/transformers/models/bert/modeling_bert.rb +93 -3
- data/lib/transformers/models/distilbert/configuration_distilbert.rb +1 -1
- data/lib/transformers/tokenization_utils.rb +0 -1
- data/lib/transformers/utils/hub.rb +3 -3
- data/lib/transformers/version.rb +1 -1
- metadata +2 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
 ---
 SHA256:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: 16b409e954f6bcc45fd4f3f2db94dc92c87d47c4c06936162978cc7d7e54fc09
+  data.tar.gz: c40ea58b7531e89a041ce2782dea89536f85227903e2dc9a60113afe041bb9f7
 SHA512:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: 0576500ca9fe9379aae4c2cc050aa34c90eea7b2d5251b6139c48d88b5107086e6197aad5210bc74d7969a5bfa0458d31ce7faf561df69b2ae9b2a6400280ce0
+  data.tar.gz: b5b0a865acdd37bcde11571e365a1d39b473f3dcbbe737a082f37002f886fe38f0f57c77f44236f859939d1d3b1df9659bb524a483cd0b9048dbc4cd472a3355
data/CHANGELOG.md
CHANGED
data/README.md
CHANGED
@@ -27,12 +27,13 @@ Embedding
 
 - [sentence-transformers/all-MiniLM-L6-v2](#sentence-transformersall-MiniLM-L6-v2)
 - [sentence-transformers/multi-qa-MiniLM-L6-cos-v1](#sentence-transformersmulti-qa-MiniLM-L6-cos-v1)
+- [sentence-transformers/all-mpnet-base-v2](#sentence-transformersall-mpnet-base-v2)
+- [sentence-transformers/paraphrase-MiniLM-L6-v2](#sentence-transformersparaphrase-minilm-l6-v2)
 - [mixedbread-ai/mxbai-embed-large-v1](#mixedbread-aimxbai-embed-large-v1)
 - [thenlper/gte-small](#thenlpergte-small)
 - [intfloat/e5-base-v2](#intfloate5-base-v2)
 - [BAAI/bge-base-en-v1.5](#baaibge-base-en-v15)
 - [Snowflake/snowflake-arctic-embed-m-v1.5](#snowflakesnowflake-arctic-embed-m-v15)
-- [sentence-transformers/all-mpnet-base-v2](#sentence-transformersall-mpnet-base-v2)
 
 Sparse embedding
 
@@ -69,6 +70,28 @@ scores = doc_embeddings.map { |e| e.zip(query_embedding).sum { |d, q| d * q } }
 doc_score_pairs = docs.zip(scores).sort_by { |d, s| -s }
 ```
 
+### sentence-transformers/all-mpnet-base-v2
+
+[Docs](https://huggingface.co/sentence-transformers/all-mpnet-base-v2)
+
+```ruby
+sentences = ["This is an example sentence", "Each sentence is converted"]
+
+model = Transformers.pipeline("embedding", "sentence-transformers/all-mpnet-base-v2")
+embeddings = model.(sentences)
+```
+
+### sentence-transformers/paraphrase-MiniLM-L6-v2
+
+[Docs](https://huggingface.co/sentence-transformers/paraphrase-MiniLM-L6-v2)
+
+```ruby
+sentences = ["This is an example sentence", "Each sentence is converted"]
+
+model = Transformers.pipeline("embedding", "sentence-transformers/paraphrase-MiniLM-L6-v2")
+embeddings = model.(sentences)
+```
+
 ### mixedbread-ai/mxbai-embed-large-v1
 
 [Docs](https://huggingface.co/mixedbread-ai/mxbai-embed-large-v1)
 
@@ -148,17 +171,6 @@ model = Transformers.pipeline("embedding", "Snowflake/snowflake-arctic-embed-m-v
 embeddings = model.(input, pooling: "cls")
 ```
 
-### sentence-transformers/all-mpnet-base-v2
-
-[Docs](https://huggingface.co/sentence-transformers/all-mpnet-base-v2)
-
-```ruby
-sentences = ["This is an example sentence", "Each sentence is converted"]
-
-model = Transformers.pipeline("embedding", "sentence-transformers/all-mpnet-base-v2")
-embeddings = model.(sentences)
-```
-
 ### opensearch-project/opensearch-neural-sparse-encoding-v1
 
 [Docs](https://huggingface.co/opensearch-project/opensearch-neural-sparse-encoding-v1)
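Both new README entries return plain Ruby arrays of floats, and, as in the existing examples, scoring is left to the caller. A minimal cosine-similarity sketch over those embeddings, in plain Ruby with no extra dependencies (`embeddings` is the array from either new example):

```ruby
# Cosine similarity between two embedding vectors (arrays of Floats).
def cosine_similarity(a, b)
  dot = a.zip(b).sum { |x, y| x * y }
  dot / (Math.sqrt(a.sum { |x| x * x }) * Math.sqrt(b.sum { |x| x * x }))
end

cosine_similarity(embeddings[0], embeddings[1])
```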
data/lib/transformers/configuration_utils.rb
CHANGED
@@ -36,7 +36,9 @@ module Transformers
 
     attr_reader :output_hidden_states, :output_attentions, :pruned_heads, :tie_word_embeddings, :tokenizer_class,
       :chunk_size_feed_forward, :pad_token_id, :is_decoder, :add_cross_attention,
-      :
+      :id2label, :architectures, :is_encoder_decoder, :tie_encoder_decoder, :_commit_hash
+
+    attr_accessor :problem_type
 
     def initialize(**kwargs)
       @return_dict = kwargs.delete(:return_dict) { true }
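`problem_type` moves out of the `attr_reader` list and becomes an `attr_accessor` because the new `BertForSequenceClassification#forward` (below) writes the inferred problem type back onto the config. A small sketch of what this permits, assuming `config` is any loaded `PretrainedConfig`:

```ruby
# problem_type is now writable; setting it explicitly skips the
# label-dtype inference in BertForSequenceClassification#forward.
config.problem_type ||= "single_label_classification"
```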
data/lib/transformers/modeling_utils.rb
CHANGED
@@ -23,7 +23,7 @@ module Transformers
       "xavier_uniform!" => Torch::NN::Init.method(:xavier_uniform!),
       "xavier_normal!" => Torch::NN::Init.method(:xavier_normal!),
       "kaiming_uniform!" => Torch::NN::Init.method(:kaiming_uniform!),
-      "kaiming_normal!" => Torch::NN::Init.method(:kaiming_normal!)
+      "kaiming_normal!" => Torch::NN::Init.method(:kaiming_normal!)
       # "uniform" => Torch::NN::Init.method(:uniform),
       # "normal" => Torch::NN::Init.method(:normal),
       # "xavier_uniform" => Torch::NN::Init.method(:xavier_uniform),
data/lib/transformers/models/auto/configuration_auto.rb
CHANGED
@@ -55,7 +55,7 @@ module Transformers
         config_dict, unused_kwargs = PretrainedConfig.get_config_dict(pretrained_model_name_or_path, **kwargs)
         if config_dict[:model_type]
           config_class = CONFIG_MAPPING[config_dict[:model_type]]
-
+          config_class.from_dict(config_dict, **unused_kwargs)
         else
           raise Todo
         end
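The replaced line builds the model-specific config with `from_dict`, so the lookup through `CONFIG_MAPPING` now yields a usable object. A hedged usage sketch, assuming this method backs `Transformers::AutoConfig.from_pretrained` as in the Python library:

```ruby
# Assumption: AutoConfig.from_pretrained is the public entry point here.
config = Transformers::AutoConfig.from_pretrained("bert-base-uncased")
# config is now the model-specific config built via from_dict
```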
data/lib/transformers/models/auto/modeling_auto.rb
CHANGED
@@ -28,6 +28,7 @@ module Transformers
     }
 
     MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = {
+      "bert" => "BertForSequenceClassification",
       "deberta-v2" => "DebertaV2ForSequenceClassification",
       "distilbert" => "DistilBertForSequenceClassification",
       "xlm-roberta" => "XLMRobertaForSequenceClassification"
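With `"bert"` added to `MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES`, auto-loading can resolve BERT checkpoints that ship a classification head. A sketch using the README's pipeline style; the model id is hypothetical:

```ruby
# Hypothetical fine-tuned checkpoint; any BERT model with a
# sequence-classification head should now resolve through this mapping.
classifier = Transformers.pipeline("sentiment-analysis", "user/bert-finetuned-sst2")
classifier.("A genuinely useful release")
```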
data/lib/transformers/models/bert/modeling_bert.rb
CHANGED
@@ -353,7 +353,7 @@ module Transformers
       def feed_forward_chunk(attention_output)
         intermediate_output = @intermediate.(attention_output)
         layer_output = @output.(intermediate_output, attention_output)
-
+        layer_output
       end
     end
 
@@ -370,7 +370,7 @@ module Transformers
         attention_mask: nil,
         head_mask: nil,
         encoder_hidden_states: nil,
-        encoder_attention_mask:nil,
+        encoder_attention_mask: nil,
         past_key_values: nil,
         use_cache: nil,
         output_attentions: false,
@@ -814,7 +814,7 @@ module Transformers
         loss = nil
         if !labels.nil?
           loss_fct = CrossEntropyLoss.new
-          loss = loss_fct.(logits.view(-1
+          loss = loss_fct.(logits.view(-1, @num_labels), labels.view(-1))
         end
 
         if !return_dict
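This hunk sits in what appears to be the token-classification head, where logits arrive as (batch, seq_len, num_labels); `view(-1, @num_labels)` flattens batch and sequence so `CrossEntropyLoss` sees one row per token. A standalone torch.rb shape check (dimensions are illustrative):

```ruby
require "torch"

logits = Torch.randn(2, 8, 5)              # (batch, seq_len, num_labels)
labels = Torch.zeros(2, 8, dtype: :int64)  # one class id per token

flat_logits = logits.view(-1, 5)           # shape [16, 5]
flat_labels = labels.view(-1)              # shape [16]
loss = Torch::NN::CrossEntropyLoss.new.(flat_logits, flat_labels)
```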
@@ -829,8 +829,98 @@ module Transformers
         )
       end
     end
+
+    class BertForSequenceClassification < BertPreTrainedModel
+      def initialize(config)
+        super
+        @num_labels = config.num_labels
+        @config = config
+
+        @bert = BertModel.new(config, add_pooling_layer: true)
+        classifier_dropout = (
+          config.classifier_dropout.nil? ? config.hidden_dropout_prob : config.classifier_dropout
+        )
+        @dropout = Torch::NN::Dropout.new(p: classifier_dropout)
+        @classifier = Torch::NN::Linear.new(config.hidden_size, config.num_labels)
+
+        # Initialize weights and apply final processing
+        post_init
+      end
+
+      def forward(
+        input_ids: nil,
+        attention_mask: nil,
+        token_type_ids: nil,
+        position_ids: nil,
+        head_mask: nil,
+        inputs_embeds: nil,
+        labels: nil,
+        output_attentions: nil,
+        output_hidden_states: nil,
+        return_dict: nil
+      )
+        return_dict = @config.use_return_dict if return_dict.nil?
+
+        outputs = @bert.(
+          input_ids: input_ids,
+          attention_mask: attention_mask,
+          token_type_ids: token_type_ids,
+          position_ids: position_ids,
+          head_mask: head_mask,
+          inputs_embeds: inputs_embeds,
+          output_attentions: output_attentions,
+          output_hidden_states: output_hidden_states,
+          return_dict: return_dict
+        )
+
+        pooled_output = outputs[1]
+
+        pooled_output = @dropout.(pooled_output)
+        logits = @classifier.(pooled_output)
+
+        loss = nil
+        if !labels.nil?
+          if @config.problem_type.nil?
+            if @num_labels == 1
+              @config.problem_type = "regression"
+            elsif @num_labels > 1 && (labels.dtype == Torch.long || labels.dtype == Torch.int)
+              @config.problem_type = "single_label_classification"
+            else
+              @config.problem_type = "multi_label_classification"
+            end
+          end
+
+          if @config.problem_type == "regression"
+            loss_fct = Torch::NN::MSELoss.new
+            if @num_labels == 1
+              loss = loss_fct.(logits.squeeze, labels.squeeze)
+            else
+              loss = loss_fct.(logits, labels)
+            end
+          elsif @config.problem_type == "single_label_classification"
+            loss_fct = Torch::NN::CrossEntropyLoss.new
+            loss = loss_fct.(logits.view(-1, @num_labels), labels.view(-1))
+          elsif @config.problem_type == "multi_label_classification"
+            loss_fct = Torch::NN::BCEWithLogitsLoss.new
+            loss = loss_fct.(logits, labels)
+          end
+        end
+
+        if !return_dict
+          raise Todo
+        end
+
+        SequenceClassifierOutput.new(
+          loss: loss,
+          logits: logits,
+          hidden_states: outputs.hidden_states,
+          attentions: outputs.attentions
+        )
+      end
+    end
   end
 
   BertModel = Bert::BertModel
   BertForTokenClassification = Bert::BertForTokenClassification
+  BertForSequenceClassification = Bert::BertForSequenceClassification
 end
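The loss selection mirrors the upstream Python implementation: with no explicit `problem_type`, a single output means regression, integer labels mean single-label classification, and anything else is treated as multi-label. A standalone torch.rb sketch of the three branches (shapes and values are illustrative):

```ruby
require "torch"

logits = Torch.randn(4, 3)  # batch of 4, 3 labels

# single-label: integer class ids -> CrossEntropyLoss
ce_loss = Torch::NN::CrossEntropyLoss.new.(logits.view(-1, 3), Torch.tensor([0, 2, 1, 0]).view(-1))

# multi-label: float 0/1 targets -> BCEWithLogitsLoss
bce_loss = Torch::NN::BCEWithLogitsLoss.new.(logits, Torch.rand(4, 3).round)

# regression: num_labels == 1 -> MSELoss on squeezed logits
mse_loss = Torch::NN::MSELoss.new.(Torch.randn(4, 1).squeeze, Torch.randn(4))
```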
data/lib/transformers/models/distilbert/configuration_distilbert.rb
CHANGED
@@ -19,7 +19,7 @@ module Transformers
     self.attribute_map = {
       hidden_size: "dim",
       num_attention_heads: "n_heads",
-      num_hidden_layers: "n_layers"
+      num_hidden_layers: "n_layers"
     }
 
     attr_reader :vocab_size, :max_position_embeddings, :sinusoidal_pos_embds, :n_layers, :n_heads,
data/lib/transformers/utils/hub.rb
CHANGED
@@ -181,9 +181,9 @@ module Transformers
         proxies: proxies,
         timeout: 10
       )
-
+      true
     rescue HfHub::OfflineModeIsEnabled
-
+      has_file_in_cache
     rescue HfHub::GatedRepoError => e
       Transformers.logger.error(e)
       raise EnvironmentError,
@@ -200,7 +200,7 @@ module Transformers
         "#{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for this " +
         "model name. Check the model page at 'https://huggingface.co/#{path_or_repo}' for available revisions."
     rescue HfHub::EntryNotFoundError
-
+      false # File does not exist
     end
   end
 end
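The three replaced lines make the method's intent explicit: `true` when the remote check succeeds, the cached answer when offline, `false` when the entry is missing; the old lines are blank in this diff, so the return values were previously implicit. A generic sketch of the Ruby pattern (not the gem's actual API):

```ruby
class OfflineError < StandardError; end
class NotFoundError < StandardError; end

# Each rescue clause is an expression; its last value becomes the
# method's return value, so every path yields an explicit boolean.
def remote_file_exists?(cached:)
  yield
  true               # the check ran without raising
rescue OfflineError
  cached             # offline: trust what the local cache knows
rescue NotFoundError
  false              # the remote says the file does not exist
end

remote_file_exists?(cached: true) { raise OfflineError } # => true
```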
data/lib/transformers/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
 --- !ruby/object:Gem::Specification
 name: transformers-rb
 version: !ruby/object:Gem::Version
-  version: 0.1.3
+  version: 0.1.4
 platform: ruby
 authors:
 - Andrew Kane
 autorequire:
 bindir: bin
 cert_chain: []
-date: 2024-
+date: 2024-10-22 00:00:00.000000000 Z
 dependencies:
 - !ruby/object:Gem::Dependency
   name: numo-narray