transformers-rb 0.1.3 → 0.1.4

This diff shows the changes between publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
  ---
  SHA256:
- metadata.gz: 3f070b9828c5c5ad71c75f46ca9daf1387a5ec3848cb406aac9e5f1bbc1d4531
- data.tar.gz: 31b28a5a87c58db6fc3146e390e8a4a7bf1ffc34ede6d3cd6fcd7f3aa3df2d28
+ metadata.gz: 16b409e954f6bcc45fd4f3f2db94dc92c87d47c4c06936162978cc7d7e54fc09
+ data.tar.gz: c40ea58b7531e89a041ce2782dea89536f85227903e2dc9a60113afe041bb9f7
  SHA512:
- metadata.gz: aa2055e44b9071a425ebfb59d6b2edbedce1f3cf97e0baa55d1280451c1c1db097a52b0b9615a188b1d96f0854e557fb4cb769b05cb3af4db229cd3fcdf8fb95
- data.tar.gz: 1af002f238e9189a2e2a6b5f1aafc9201cfd5bc5f8afe4a80b81757b5d9f5d4fa52bc61a57b4fdd6920bd3692f704398aa58c7cc4fd797bd881ab9887c9c77f9
+ metadata.gz: 0576500ca9fe9379aae4c2cc050aa34c90eea7b2d5251b6139c48d88b5107086e6197aad5210bc74d7969a5bfa0458d31ce7faf561df69b2ae9b2a6400280ce0
+ data.tar.gz: b5b0a865acdd37bcde11571e365a1d39b473f3dcbbe737a082f37002f886fe38f0f57c77f44236f859939d1d3b1df9659bb524a483cd0b9048dbc4cd472a3355
data/CHANGELOG.md CHANGED
@@ -1,3 +1,7 @@
+ ## 0.1.4 (2024-10-22)
+
+ - Added `BertForSequenceClassification`
+
  ## 0.1.3 (2024-09-17)

  - Added `reranking` pipeline
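Usage note (not part of the released files): with `BertForSequenceClassification` available, a BERT checkpoint fine-tuned for classification can be loaded through the existing pipeline API. A minimal sketch, assuming the `sentiment-analysis` task route and an illustrative model id:

```ruby
# Minimal sketch: the model id below is illustrative, not shipped with this release.
classifier = Transformers.pipeline("sentiment-analysis", "nlptown/bert-base-multilingual-uncased-sentiment")
classifier.("This release adds a useful feature")
# => hash with :label and :score for the predicted class (label set depends on the checkpoint)
```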
data/README.md CHANGED
@@ -27,12 +27,13 @@ Embedding

  - [sentence-transformers/all-MiniLM-L6-v2](#sentence-transformersall-MiniLM-L6-v2)
  - [sentence-transformers/multi-qa-MiniLM-L6-cos-v1](#sentence-transformersmulti-qa-MiniLM-L6-cos-v1)
+ - [sentence-transformers/all-mpnet-base-v2](#sentence-transformersall-mpnet-base-v2)
+ - [sentence-transformers/paraphrase-MiniLM-L6-v2](#sentence-transformersparaphrase-minilm-l6-v2)
  - [mixedbread-ai/mxbai-embed-large-v1](#mixedbread-aimxbai-embed-large-v1)
  - [thenlper/gte-small](#thenlpergte-small)
  - [intfloat/e5-base-v2](#intfloate5-base-v2)
  - [BAAI/bge-base-en-v1.5](#baaibge-base-en-v15)
  - [Snowflake/snowflake-arctic-embed-m-v1.5](#snowflakesnowflake-arctic-embed-m-v15)
- - [sentence-transformers/all-mpnet-base-v2](#sentence-transformersall-mpnet-base-v2)

  Sparse embedding

@@ -69,6 +70,28 @@ scores = doc_embeddings.map { |e| e.zip(query_embedding).sum { |d, q| d * q } }
  doc_score_pairs = docs.zip(scores).sort_by { |d, s| -s }
  ```

+ ### sentence-transformers/all-mpnet-base-v2
+
+ [Docs](https://huggingface.co/sentence-transformers/all-mpnet-base-v2)
+
+ ```ruby
+ sentences = ["This is an example sentence", "Each sentence is converted"]
+
+ model = Transformers.pipeline("embedding", "sentence-transformers/all-mpnet-base-v2")
+ embeddings = model.(sentences)
+ ```
+
+ ### sentence-transformers/paraphrase-MiniLM-L6-v2
+
+ [Docs](https://huggingface.co/sentence-transformers/paraphrase-MiniLM-L6-v2)
+
+ ```ruby
+ sentences = ["This is an example sentence", "Each sentence is converted"]
+
+ model = Transformers.pipeline("embedding", "sentence-transformers/paraphrase-MiniLM-L6-v2")
+ embeddings = model.(sentences)
+ ```
+
  ### mixedbread-ai/mxbai-embed-large-v1

  [Docs](https://huggingface.co/mixedbread-ai/mxbai-embed-large-v1)
@@ -148,17 +171,6 @@ model = Transformers.pipeline("embedding", "Snowflake/snowflake-arctic-embed-m-v
  embeddings = model.(input, pooling: "cls")
  ```

- ### sentence-transformers/all-mpnet-base-v2
-
- [Docs](https://huggingface.co/sentence-transformers/all-mpnet-base-v2)
-
- ```ruby
- sentences = ["This is an example sentence", "Each sentence is converted"]
-
- model = Transformers.pipeline("embedding", "sentence-transformers/all-mpnet-base-v2")
- embeddings = model.(sentences)
- ```
-
  ### opensearch-project/opensearch-neural-sparse-encoding-v1

  [Docs](https://huggingface.co/opensearch-project/opensearch-neural-sparse-encoding-v1)
@@ -36,7 +36,9 @@ module Transformers

  attr_reader :output_hidden_states, :output_attentions, :pruned_heads, :tie_word_embeddings, :tokenizer_class,
  :chunk_size_feed_forward, :pad_token_id, :is_decoder, :add_cross_attention,
- :problem_type, :id2label, :architectures, :is_encoder_decoder, :tie_encoder_decoder, :_commit_hash
+ :id2label, :architectures, :is_encoder_decoder, :tie_encoder_decoder, :_commit_hash
+
+ attr_accessor :problem_type

  def initialize(**kwargs)
  @return_dict = kwargs.delete(:return_dict) { true }
@@ -45,7 +45,7 @@ module Transformers

  REPO_TYPES_URL_PREFIXES = {
  REPO_TYPE_DATASET => "datasets/",
- REPO_TYPE_SPACE => "spaces/",
+ REPO_TYPE_SPACE => "spaces/"
  }

  # default cache
@@ -23,7 +23,7 @@ module Transformers
  "xavier_uniform!" => Torch::NN::Init.method(:xavier_uniform!),
  "xavier_normal!" => Torch::NN::Init.method(:xavier_normal!),
  "kaiming_uniform!" => Torch::NN::Init.method(:kaiming_uniform!),
- "kaiming_normal!" => Torch::NN::Init.method(:kaiming_normal!),
+ "kaiming_normal!" => Torch::NN::Init.method(:kaiming_normal!)
  # "uniform" => Torch::NN::Init.method(:uniform),
  # "normal" => Torch::NN::Init.method(:normal),
  # "xavier_uniform" => Torch::NN::Init.method(:xavier_uniform),
@@ -55,7 +55,7 @@ module Transformers
  config_dict, unused_kwargs = PretrainedConfig.get_config_dict(pretrained_model_name_or_path, **kwargs)
  if config_dict[:model_type]
  config_class = CONFIG_MAPPING[config_dict[:model_type]]
- return config_class.from_dict(config_dict, **unused_kwargs)
+ config_class.from_dict(config_dict, **unused_kwargs)
  else
  raise Todo
  end
@@ -28,6 +28,7 @@ module Transformers
  }

  MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = {
+ "bert" => "BertForSequenceClassification",
  "deberta-v2" => "DebertaV2ForSequenceClassification",
  "distilbert" => "DistilBertForSequenceClassification",
  "xlm-roberta" => "XLMRobertaForSequenceClassification"
@@ -353,7 +353,7 @@ module Transformers
  def feed_forward_chunk(attention_output)
  intermediate_output = @intermediate.(attention_output)
  layer_output = @output.(intermediate_output, attention_output)
- return layer_output
+ layer_output
  end
  end

@@ -370,7 +370,7 @@ module Transformers
  attention_mask: nil,
  head_mask: nil,
  encoder_hidden_states: nil,
- encoder_attention_mask:nil,
+ encoder_attention_mask: nil,
  past_key_values: nil,
  use_cache: nil,
  output_attentions: false,
@@ -814,7 +814,7 @@ module Transformers
  loss = nil
  if !labels.nil?
  loss_fct = CrossEntropyLoss.new
- loss = loss_fct.(logits.view(-1,@num_labels), labels.view(-1))
+ loss = loss_fct.(logits.view(-1, @num_labels), labels.view(-1))
  end

  if !return_dict
@@ -829,8 +829,98 @@ module Transformers
  )
  end
  end
+
+ class BertForSequenceClassification < BertPreTrainedModel
+ def initialize(config)
+ super
+ @num_labels = config.num_labels
+ @config = config
+
+ @bert = BertModel.new(config, add_pooling_layer: true)
+ classifier_dropout = (
+ config.classifier_dropout.nil? ? config.hidden_dropout_prob : config.classifier_dropout
+ )
+ @dropout = Torch::NN::Dropout.new(p: classifier_dropout)
+ @classifier = Torch::NN::Linear.new(config.hidden_size, config.num_labels)
+
+ # Initialize weights and apply final processing
+ post_init
+ end
+
+ def forward(
+ input_ids: nil,
+ attention_mask: nil,
+ token_type_ids: nil,
+ position_ids: nil,
+ head_mask: nil,
+ inputs_embeds: nil,
+ labels: nil,
+ output_attentions: nil,
+ output_hidden_states: nil,
+ return_dict: nil
+ )
+ return_dict = @config.use_return_dict if return_dict.nil?
+
+ outputs = @bert.(
+ input_ids: input_ids,
+ attention_mask: attention_mask,
+ token_type_ids: token_type_ids,
+ position_ids: position_ids,
+ head_mask: head_mask,
+ inputs_embeds: inputs_embeds,
+ output_attentions: output_attentions,
+ output_hidden_states: output_hidden_states,
+ return_dict: return_dict
+ )
+
+ pooled_output = outputs[1]
+
+ pooled_output = @dropout.(pooled_output)
+ logits = @classifier.(pooled_output)
+
+ loss = nil
+ if !labels.nil?
+ if @config.problem_type.nil?
+ if @num_labels == 1
+ @config.problem_type = "regression"
+ elsif @num_labels > 1 && (labels.dtype == Torch.long || labels.dtype == Torch.int)
+ @config.problem_type = "single_label_classification"
+ else
+ @config.problem_type = "multi_label_classification"
+ end
+ end
+
+ if @config.problem_type == "regression"
+ loss_fct = Torch::NN::MSELoss.new
+ if @num_labels == 1
+ loss = loss_fct.(logits.squeeze, labels.squeeze)
+ else
+ loss = loss_fct.(logits, labels)
+ end
+ elsif @config.problem_type == "single_label_classification"
+ loss_fct = Torch::NN::CrossEntropyLoss.new
+ loss = loss_fct.(logits.view(-1, @num_labels), labels.view(-1))
+ elsif @config.problem_type == "multi_label_classification"
+ loss_fct = Torch::NN::BCEWithLogitsLoss.new
+ loss = loss_fct.(logits, labels)
+ end
+ end
+
+ if !return_dict
+ raise Todo
+ end
+
+ SequenceClassifierOutput.new(
+ loss: loss,
+ logits: logits,
+ hidden_states: outputs.hidden_states,
+ attentions: outputs.attentions
+ )
+ end
+ end
  end

  BertModel = Bert::BertModel
  BertForTokenClassification = Bert::BertForTokenClassification
+ BertForSequenceClassification = Bert::BertForSequenceClassification
  end
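For reference, a minimal sketch of calling the new head directly rather than through a pipeline; the checkpoint id is illustrative, and passing integer labels exercises the `single_label_classification` branch above (which is why `problem_type` became an `attr_accessor`):

```ruby
# Hedged sketch, assuming an illustrative BERT classification checkpoint.
model = Transformers::BertForSequenceClassification.from_pretrained("textattack/bert-base-uncased-SST-2")
tokenizer = Transformers::AutoTokenizer.from_pretrained("textattack/bert-base-uncased-SST-2")

enc = tokenizer.("an enjoyable release", return_tensors: "pt")
output = model.(
  input_ids: enc[:input_ids],
  attention_mask: enc[:attention_mask],
  labels: Torch.tensor([1]) # integer labels -> problem_type "single_label_classification"
)
output.logits # raw per-label scores
output.loss   # cross-entropy loss for the provided label
```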
@@ -19,7 +19,7 @@ module Transformers
  self.attribute_map = {
  hidden_size: "dim",
  num_attention_heads: "n_heads",
- num_hidden_layers: "n_layers",
+ num_hidden_layers: "n_layers"
  }

  attr_reader :vocab_size, :max_position_embeddings, :sinusoidal_pos_embds, :n_layers, :n_heads,
@@ -15,7 +15,6 @@
  module Transformers
  class PreTrainedTokenizer < PreTrainedTokenizerBase
  def initialize(**kwargs)
-
  # 2. init `_added_tokens_decoder` if child class did not
  if !instance_variable_defined?(:@added_tokens_decoder)
  @added_tokens_decoder = {}
@@ -181,9 +181,9 @@ module Transformers
  proxies: proxies,
  timeout: 10
  )
- return true
+ true
  rescue HfHub::OfflineModeIsEnabled
- return has_file_in_cache
+ has_file_in_cache
  rescue HfHub::GatedRepoError => e
  Transformers.logger.error(e)
  raise EnvironmentError,
@@ -200,7 +200,7 @@ module Transformers
  "#{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for this " +
  "model name. Check the model page at 'https://huggingface.co/#{path_or_repo}' for available revisions."
  rescue HfHub::EntryNotFoundError
- return false # File does not exist
+ false # File does not exist
  end
  end
  end
@@ -1,3 +1,3 @@
  module Transformers
- VERSION = "0.1.3"
+ VERSION = "0.1.4"
  end
metadata CHANGED
@@ -1,14 +1,14 @@
  --- !ruby/object:Gem::Specification
  name: transformers-rb
  version: !ruby/object:Gem::Version
- version: 0.1.3
+ version: 0.1.4
  platform: ruby
  authors:
  - Andrew Kane
  autorequire:
  bindir: bin
  cert_chain: []
- date: 2024-09-17 00:00:00.000000000 Z
+ date: 2024-10-22 00:00:00.000000000 Z
  dependencies:
  - !ruby/object:Gem::Dependency
  name: numo-narray