transformers-rb 0.1.0 → 0.1.2

checksums.yaml CHANGED
@@ -1,7 +1,7 @@
  ---
  SHA256:
- metadata.gz: c8f34c5454e2a1ac18bbb9a4b290a43e994cd3984fa2b4125ff4af969b9d17ed
- data.tar.gz: 57c876fd1a4e62089fdc7bcbfcb9c155050166458a679991036894a6721ac168
+ metadata.gz: 3f29055705824ba101cba238960d4f10825c75bc7867b9eb0b611cda6a547612
+ data.tar.gz: d0967f7742f7b2d6194376eb040a3be81e77a9ded94302aeb934de678959e434
  SHA512:
- metadata.gz: 7458b1ba0303e0741abf16a63efc350b6cad5e5dff48c46dcba6f62858f562bbf83478eb918995c45f5882159cf5c22d696cc4b0360f813312c2263da7c28205
- data.tar.gz: a4d98b210a22d23bc55f452a93dd0e9df20c81cbc0d537d29eaa1069b157eec1df6cdc76fc7f398947e03ae841166d609bff0d250506eb0dd0745ef0fdf86efd
+ metadata.gz: 38b9ed4fd654ca593e3d6e7c7f20eb3c6b68ecfa5f86099fbc8d160f9093617cc79a571c331a0ec0c70a6770c8d9460194ba75e61d534f9e15931f22e5ae60c3
+ data.tar.gz: 00ce437ce8fe419fafddd59b7f9f61050d2ddf5817816b53beb1c43badbced9fe77ea28b26f2d607e87a9be855a19c4b254bc50df2881b4126f2def5c6875c3d
data/CHANGELOG.md CHANGED
@@ -1,3 +1,14 @@
+ ## 0.1.2 (2024-09-10)
+
+ - Fixed default revision for pipelines
+
+ ## 0.1.1 (2024-08-29)
+
+ - Added `embedding` pipeline
+ - Added experimental `fast_init` option
+ - Improved performance of loading models
+ - Fixed error with `aggregation_strategy` option
+
  ## 0.1.0 (2024-08-19)

  - First release
data/README.md CHANGED
@@ -2,6 +2,8 @@

  :slightly_smiling_face: State-of-the-art [transformers](https://github.com/huggingface/transformers) for Ruby

+ For fast inference, check out [Informers](https://github.com/ankane/informers) :fire:
+
  [![Build Status](https://github.com/ankane/transformers-ruby/actions/workflows/build.yml/badge.svg)](https://github.com/ankane/transformers-ruby/actions)

  ## Installation
@@ -21,6 +23,20 @@ gem "transformers-rb"

  ## Models

+ Embedding
+
+ - [sentence-transformers/all-MiniLM-L6-v2](#sentence-transformersall-MiniLM-L6-v2)
+ - [sentence-transformers/multi-qa-MiniLM-L6-cos-v1](#sentence-transformersmulti-qa-MiniLM-L6-cos-v1)
+ - [mixedbread-ai/mxbai-embed-large-v1](#mixedbread-aimxbai-embed-large-v1)
+ - [thenlper/gte-small](#thenlpergte-small)
+ - [intfloat/e5-base-v2](#intfloate5-base-v2)
+ - [BAAI/bge-base-en-v1.5](#baaibge-base-en-v15)
+ - [Snowflake/snowflake-arctic-embed-m-v1.5](#snowflakesnowflake-arctic-embed-m-v15)
+
+ Sparse embedding
+
+ - [opensearch-project/opensearch-neural-sparse-encoding-v1](#opensearch-projectopensearch-neural-sparse-encoding-v1)
+
  ### sentence-transformers/all-MiniLM-L6-v2

  [Docs](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2)
@@ -28,8 +44,8 @@ gem "transformers-rb"
  ```ruby
  sentences = ["This is an example sentence", "Each sentence is converted"]

- model = Transformers::SentenceTransformer.new("sentence-transformers/all-MiniLM-L6-v2")
- embeddings = model.encode(sentences)
+ model = Transformers.pipeline("embedding", "sentence-transformers/all-MiniLM-L6-v2")
+ embeddings = model.(sentences)
  ```

  ### sentence-transformers/multi-qa-MiniLM-L6-cos-v1
@@ -40,10 +56,10 @@ embeddings = model.encode(sentences)
  query = "How many people live in London?"
  docs = ["Around 9 Million people live in London", "London is known for its financial district"]

- model = Transformers::SentenceTransformer.new("sentence-transformers/multi-qa-MiniLM-L6-cos-v1")
- query_emb = model.encode(query)
- doc_emb = model.encode(docs)
- scores = Torch.mm(Torch.tensor([query_emb]), Torch.tensor(doc_emb).transpose(0, 1))[0].cpu.to_a
+ model = Transformers.pipeline("embedding", "sentence-transformers/multi-qa-MiniLM-L6-cos-v1")
+ query_embedding = model.(query)
+ doc_embeddings = model.(docs)
+ scores = doc_embeddings.map { |e| e.zip(query_embedding).sum { |d, q| d * q } }
  doc_score_pairs = docs.zip(scores).sort_by { |d, s| -s }
  ```

@@ -52,18 +68,78 @@ doc_score_pairs = docs.zip(scores).sort_by { |d, s| -s }
  [Docs](https://huggingface.co/mixedbread-ai/mxbai-embed-large-v1)

  ```ruby
- def transform_query(query)
- "Represent this sentence for searching relevant passages: #{query}"
- end
+ query_prefix = "Represent this sentence for searching relevant passages: "

- docs = [
- transform_query("puppy"),
+ input = [
  "The dog is barking",
- "The cat is purring"
+ "The cat is purring",
+ query_prefix + "puppy"
  ]

- model = Transformers::SentenceTransformer.new("mixedbread-ai/mxbai-embed-large-v1")
- embeddings = model.encode(docs)
+ model = Transformers.pipeline("embedding", "mixedbread-ai/mxbai-embed-large-v1")
+ embeddings = model.(input)
+ ```
+
+ ### thenlper/gte-small
+
+ [Docs](https://huggingface.co/thenlper/gte-small)
+
+ ```ruby
+ sentences = ["That is a happy person", "That is a very happy person"]
+
+ model = Transformers.pipeline("embedding", "thenlper/gte-small")
+ embeddings = model.(sentences)
+ ```
+
+ ### intfloat/e5-base-v2
+
+ [Docs](https://huggingface.co/intfloat/e5-base-v2)
+
+ ```ruby
+ doc_prefix = "passage: "
+ query_prefix = "query: "
+
+ input = [
+ doc_prefix + "Ruby is a programming language created by Matz",
+ query_prefix + "Ruby creator"
+ ]
+
+ model = Transformers.pipeline("embedding", "intfloat/e5-base-v2")
+ embeddings = model.(input)
+ ```
+
+ ### BAAI/bge-base-en-v1.5
+
+ [Docs](https://huggingface.co/BAAI/bge-base-en-v1.5)
+
+ ```ruby
+ query_prefix = "Represent this sentence for searching relevant passages: "
+
+ input = [
+ "The dog is barking",
+ "The cat is purring",
+ query_prefix + "puppy"
+ ]
+
+ model = Transformers.pipeline("embedding", "BAAI/bge-base-en-v1.5")
+ embeddings = model.(input)
+ ```
+
+ ### Snowflake/snowflake-arctic-embed-m-v1.5
+
+ [Docs](https://huggingface.co/Snowflake/snowflake-arctic-embed-m-v1.5)
+
+ ```ruby
+ query_prefix = "Represent this sentence for searching relevant passages: "
+
+ input = [
+ "The dog is barking",
+ "The cat is purring",
+ query_prefix + "puppy"
+ ]
+
+ model = Transformers.pipeline("embedding", "Snowflake/snowflake-arctic-embed-m-v1.5")
+ embeddings = model.(input, pooling: "cls")
  ```

  ### opensearch-project/opensearch-neural-sparse-encoding-v1
@@ -89,6 +165,13 @@ embeddings = values.to_a

  ## Pipelines

+ Embedding
+
+ ```ruby
+ embed = Transformers.pipeline("embedding")
+ embed.("We are very happy to show you the 🤗 Transformers library.")
+ ```
+
  Named-entity recognition

  ```ruby
@@ -207,10 +207,17 @@ module Transformers

  config = new(**config_dict)

+ to_remove = []
  kwargs.each do |key, value|
  if config.respond_to?("#{key}=")
  config.public_send("#{key}=", value)
  end
+ if key != :torch_dtype
+ to_remove << key
+ end
+ end
+ to_remove.each do |key|
+ kwargs.delete(key)
  end

  Transformers.logger.info("Model config #{config}")
@@ -22,6 +22,7 @@ module Transformers
  def to_h
  @data
  end
+ alias_method :to_hash, :to_h

  def [](item)
  @data[item]
@@ -14,6 +14,47 @@
  # limitations under the License.

  module Transformers
+ module ModelingUtils
+ TORCH_INIT_FUNCTIONS = {
+ "uniform!" => Torch::NN::Init.method(:uniform!),
+ "normal!" => Torch::NN::Init.method(:normal!),
+ # "trunc_normal!" => Torch::NN::Init.method(:trunc_normal!),
+ "constant!" => Torch::NN::Init.method(:constant!),
+ "xavier_uniform!" => Torch::NN::Init.method(:xavier_uniform!),
+ "xavier_normal!" => Torch::NN::Init.method(:xavier_normal!),
+ "kaiming_uniform!" => Torch::NN::Init.method(:kaiming_uniform!),
+ "kaiming_normal!" => Torch::NN::Init.method(:kaiming_normal!),
+ # "uniform" => Torch::NN::Init.method(:uniform),
+ # "normal" => Torch::NN::Init.method(:normal),
+ # "xavier_uniform" => Torch::NN::Init.method(:xavier_uniform),
+ # "xavier_normal" => Torch::NN::Init.method(:xavier_normal),
+ # "kaiming_uniform" => Torch::NN::Init.method(:kaiming_uniform),
+ # "kaiming_normal" => Torch::NN::Init.method(:kaiming_normal)
+ }
+
+ # private
+ # note: this improves loading time significantly, but is not thread-safe!
+ def self.no_init_weights
+ return yield unless Transformers.fast_init
+
+ _skip_init = lambda do |*args, **kwargs|
+ # pass
+ end
+ # Save the original initialization functions
+ TORCH_INIT_FUNCTIONS.each do |name, init_func|
+ Torch::NN::Init.singleton_class.undef_method(name)
+ Torch::NN::Init.define_singleton_method(name, &_skip_init)
+ end
+ yield
+ ensure
+ # Restore the original initialization functions
+ TORCH_INIT_FUNCTIONS.each do |name, init_func|
+ Torch::NN::Init.singleton_class.undef_method(name)
+ Torch::NN::Init.define_singleton_method(name, init_func)
+ end
+ end
+ end
+
  module ModuleUtilsMixin
  def get_extended_attention_mask(
  attention_mask,
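The `no_init_weights` helper above is what the experimental `fast_init` flag (added to `lib/transformers.rb` later in this diff, default `false`) switches on: while enabled, the listed `Torch::NN::Init` functions are temporarily swapped for no-ops so `from_pretrained` skips redundant weight initialization. A minimal opt-in sketch; as the comment above notes, the swap is global state and not thread-safe, so keep it out of multi-threaded code:

```ruby
# Experimental: skip Torch weight initialization while loading, then restore the default.
Transformers.fast_init = true
embed = Transformers.pipeline("embedding", "sentence-transformers/all-MiniLM-L6-v2")
Transformers.fast_init = false

embed.("This is an example sentence")
```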
@@ -138,7 +179,11 @@ module Transformers
  end

  def _initialize_weights(mod)
+ if mod.instance_variable_defined?(:@is_hf_initialized)
+ return
+ end
  _init_weights(mod)
+ mod.instance_variable_set(:@is_hf_initialized, true)
  end

  def tie_weights
@@ -166,7 +211,9 @@ module Transformers
  prune_heads(@config.pruned_heads)
  end

- if true
+ # TODO implement no_init_weights context manager
+ _init_weights = false
+ if _init_weights
  # Initialize weights
  apply(method(:_initialize_weights))

@@ -512,8 +559,8 @@ module Transformers

  config.name_or_path = pretrained_model_name_or_path

- model_kwargs = {}
- model = new(config, *model_args, **model_kwargs)
+ # Instantiate model.
+ model = ModelingUtils.no_init_weights { new(config, *model_args, **model_kwargs) }

  # make sure we use the model's config since the __init__ call might have copied it
  config = model.config
@@ -683,6 +730,10 @@ module Transformers
  end
  end

+ if _fast_init
+ # TODO
+ end
+
  # Make sure we are able to load base models as well as derived models (with heads)
  start_prefix = ""
  model_to_load = model
@@ -756,28 +807,29 @@ module Transformers
  raise Todo
  end

+ model_class_name = model.class.name.split("::").last
  if unexpected_keys.length > 0
  archs = model.config.architectures.nil? ? [] : model.config.architectures
- warner = archs.include?(model.class.name) ? Transformers.logger.method(:warn) : Transformers.logger.method(:info)
+ warner = archs.include?(model_class_name) ? Transformers.logger.method(:warn) : Transformers.logger.method(:info)
  warner.(
  "Some weights of the model checkpoint at #{pretrained_model_name_or_path} were not used when" +
- " initializing #{model.class.name}: #{unexpected_keys}\n- This IS expected if you are" +
- " initializing #{model.class.name} from the checkpoint of a model trained on another task or" +
+ " initializing #{model_class_name}: #{unexpected_keys}\n- This IS expected if you are" +
+ " initializing #{model_class_name} from the checkpoint of a model trained on another task or" +
  " with another architecture (e.g. initializing a BertForSequenceClassification model from a" +
  " BertForPreTraining model).\n- This IS NOT expected if you are initializing" +
- " #{model.class.name} from the checkpoint of a model that you expect to be exactly identical" +
+ " #{model_class_name} from the checkpoint of a model that you expect to be exactly identical" +
  " (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
  )
  else
- Transformers.logger.info("All model checkpoint weights were used when initializing #{model.class.name}.\n")
+ Transformers.logger.info("All model checkpoint weights were used when initializing #{model_class_name}.\n")
  end
  if missing_keys.length > 0
- Transformers.logger.info("Some weights of #{model.class.name} were not initialized from the model checkpoint at #{pretrained_model_name_or_path} and are newly initialized: #{missing_keys}\nYou should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.")
+ Transformers.logger.info("Some weights of #{model_class_name} were not initialized from the model checkpoint at #{pretrained_model_name_or_path} and are newly initialized: #{missing_keys}\nYou should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.")
  elsif mismatched_keys.length == 0
  Transformers.logger.info(
- "All the weights of #{model.class.name} were initialized from the model checkpoint at" +
+ "All the weights of #{model_class_name} were initialized from the model checkpoint at" +
  " #{pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the checkpoint" +
- " was trained on, you can already use #{model.class.name} for predictions without further" +
+ " was trained on, you can already use #{model_class_name} for predictions without further" +
  " training."
  )
  end
@@ -97,7 +97,7 @@ module Transformers
  @position_embedding_type = position_embedding_type || config.position_embedding_type || "absolute"
  if @position_embedding_type == "relative_key" || @position_embedding_type == "relative_key_query"
  @max_position_embeddings = config.max_position_embeddings
- @distance_embedding = Torch:NN::Embedding.new(2 * config.max_position_embeddings - 1, @attention_head_size)
+ @distance_embedding = Torch::NN::Embedding.new(2 * config.max_position_embeddings - 1, @attention_head_size)
  end

  @is_decoder = config.is_decoder
@@ -639,8 +639,8 @@ module Transformers
  extended_attention_mask = get_extended_attention_mask(attention_mask, input_shape)
  end

- # # If a 2D or 3D attention mask is provided for the cross-attention
- # # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
+ # If a 2D or 3D attention mask is provided for the cross-attention
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
  if @config.is_decoder && !encoder_hidden_states.nil?
  encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size
  encoder_hidden_shape = [encoder_batch_size, encoder_sequence_length]
@@ -24,7 +24,7 @@ module Transformers
  "pt" => [AutoModel],
  "default" => {
  "model" => {
- "pt" => ["distilbert/distilbert-base-cased", "935ac13"]
+ "pt" => ["distilbert/distilbert-base-cased", "6ea8117"]
  }
  },
  "type" => "multimodal"
@@ -34,7 +34,7 @@ module Transformers
  "pt" => [AutoModelForSequenceClassification],
  "default" => {
  "model" => {
- "pt" => ["distilbert/distilbert-base-uncased-finetuned-sst-2-english", "af0f99b"]
+ "pt" => ["distilbert/distilbert-base-uncased-finetuned-sst-2-english", "714eb0f"]
  }
  },
  "type" => "text"
@@ -44,7 +44,7 @@ module Transformers
  "pt" => [AutoModelForTokenClassification],
  "default" => {
  "model" => {
- "pt" => ["dbmdz/bert-large-cased-finetuned-conll03-english", "f2482bf"]
+ "pt" => ["dbmdz/bert-large-cased-finetuned-conll03-english", "4c53496"]
  }
  },
  "type" => "text"
@@ -54,7 +54,7 @@ module Transformers
  "pt" => [AutoModelForQuestionAnswering],
  "default" => {
  "model" => {
- "pt" => ["distilbert/distilbert-base-cased-distilled-squad", "626af31"]
+ "pt" => ["distilbert/distilbert-base-cased-distilled-squad", "564e9b5"]
  }
  },
  "type" => "text"
@@ -64,7 +64,7 @@ module Transformers
  "pt" => [AutoModelForImageClassification],
  "default" => {
  "model" => {
- "pt" => ["google/vit-base-patch16-224", "5dca96d"]
+ "pt" => ["google/vit-base-patch16-224", "3f49326"]
  }
  },
  "type" => "image"
@@ -78,7 +78,17 @@ module Transformers
  }
  },
  "type" => "image"
- }
+ },
+ "embedding" => {
+ "impl" => EmbeddingPipeline,
+ "pt" => [AutoModel],
+ "default" => {
+ "model" => {
+ "pt" => ["sentence-transformers/all-MiniLM-L6-v2", "8b3219a"]
+ }
+ },
+ "type" => "text"
+ },
  }

  PIPELINE_REGISTRY = PipelineRegistry.new(supported_tasks: SUPPORTED_TASKS, task_aliases: TASK_ALIASES)
@@ -86,6 +96,7 @@ module Transformers
  class << self
  def pipeline(
  task,
+ model_arg = nil,
  model: nil,
  config: nil,
  tokenizer: nil,
@@ -103,6 +114,13 @@ module Transformers
  pipeline_class: nil,
  **kwargs
  )
+ if !model_arg.nil?
+ if !model.nil?
+ raise ArgumentError, "Cannot pass multiple models"
+ end
+ model = model_arg
+ end
+
  model_kwargs ||= {}
  # Make sure we only pass use_auth_token once as a kwarg (it used to be possible to pass it in model_kwargs,
  # this is to keep BC).
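The new optional `model_arg` positional parameter above mirrors the existing `model:` keyword and is what lets the README examples pass the model id as a second positional argument; passing both raises. A short sketch of the two equivalent call styles, using a model id that appears in this diff:

```ruby
# Equivalent ways to choose the model; mixing them raises ArgumentError.
embed = Transformers.pipeline("embedding", "sentence-transformers/all-MiniLM-L6-v2")
embed = Transformers.pipeline("embedding", model: "sentence-transformers/all-MiniLM-L6-v2")
```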
@@ -209,6 +227,7 @@ module Transformers
  " #{revision} (#{Utils::Hub::HUGGINGFACE_CO_RESOLVE_ENDPOINT}/#{model}).\n" +
  "Using a pipeline without specifying a model name and revision in production is not recommended."
  )
+ hub_kwargs[:revision] = revision
  if config.nil? && model.is_a?(String)
  config = AutoConfig.from_pretrained(model, _from_pipeline: task, **hub_kwargs, **model_kwargs)
  hub_kwargs[:_commit_hash] = config._commit_hash
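Forwarding `revision` into `hub_kwargs` above is the "Fixed default revision for pipelines" entry from the changelog: when no model is given, the task's pinned default revision now actually reaches the hub download calls. A minimal sketch using the `embedding` default registered earlier in this diff:

```ruby
# No model specified: falls back to "sentence-transformers/all-MiniLM-L6-v2"
# at the pinned revision "8b3219a", which is now passed through to the hub.
embed = Transformers.pipeline("embedding")
embed.("We are very happy to show you the 🤗 Transformers library.")
```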
@@ -0,0 +1,46 @@
+ module Transformers
+ class EmbeddingPipeline < Pipeline
+ def _sanitize_parameters(**kwargs)
+ [{}, {}, kwargs]
+ end
+
+ def preprocess(inputs)
+ @tokenizer.(inputs, return_tensors: @framework)
+ end
+
+ def _forward(model_inputs)
+ {
+ last_hidden_state: @model.(**model_inputs)[0],
+ attention_mask: model_inputs[:attention_mask]
+ }
+ end
+
+ def postprocess(model_outputs, pooling: "mean", normalize: true)
+ output = model_outputs[:last_hidden_state]
+
+ case pooling
+ when "none"
+ # do nothing
+ when "mean"
+ output = mean_pooling(output, model_outputs[:attention_mask])
+ when "cls"
+ output = output[0.., 0]
+ else
+ raise Error, "Pooling method '#{pooling}' not supported."
+ end
+
+ if normalize
+ output = Torch::NN::Functional.normalize(output, p: 2, dim: 1)
+ end
+
+ output[0].to_a
+ end
+
+ private
+
+ def mean_pooling(output, attention_mask)
+ input_mask_expanded = attention_mask.unsqueeze(-1).expand(output.size).float
+ Torch.sum(output * input_mask_expanded, 1) / Torch.clamp(input_mask_expanded.sum(1), min: 1e-9)
+ end
+ end
+ end
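A usage sketch for the new `EmbeddingPipeline`: `pooling` (`"mean"`, `"cls"`, or `"none"`) and `normalize` are the `postprocess` options defined above and, as in the README examples, can be passed at call time:

```ruby
embed = Transformers.pipeline("embedding", "sentence-transformers/all-MiniLM-L6-v2")

# Defaults: mean pooling over tokens, then L2 normalization
embed.("The dog is barking")

# CLS-token pooling without normalization, for models that expect it
embed.("The dog is barking", pooling: "cls", normalize: false)
```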
@@ -45,7 +45,7 @@ module Transformers
  end

  def _forward(model_inputs)
- model_outputs = @model.(**model_inputs.to_h)
+ model_outputs = @model.(**model_inputs)
  model_outputs
  end

@@ -29,7 +29,7 @@ module Transformers
  end

  def _forward(model_inputs)
- model_outputs = @model.(**model_inputs.to_h)
+ model_outputs = @model.(**model_inputs)
  model_outputs
  end

@@ -34,7 +34,7 @@ module Transformers
  end

  if function_to_apply.is_a?(String)
- function_to_apply = ClassificationFunction.new(function_to_apply.upcase).to_s
+ function_to_apply = ClassificationFunction.new(function_to_apply.downcase).to_s
  end

  if !function_to_apply.nil?
@@ -62,7 +62,7 @@ module Transformers
  end

  def _forward(model_inputs)
- @model.(**model_inputs.to_h)
+ @model.(**model_inputs)
  end

  def postprocess(model_outputs, function_to_apply: nil, top_k: 1, _legacy: true)
@@ -62,7 +62,7 @@ module Transformers

  if !aggregation_strategy.nil?
  if aggregation_strategy.is_a?(String)
- aggregation_strategy = AggregationStrategy.new(aggregation_strategy.upcase).to_s
+ aggregation_strategy = AggregationStrategy.new(aggregation_strategy.downcase).to_s
  end
  if (
  [AggregationStrategy::FIRST, AggregationStrategy::MAX, AggregationStrategy::AVERAGE].include?(aggregation_strategy) &&
@@ -278,5 +278,80 @@ module Transformers
  end
  group_entities(entities)
  end
+
+ def aggregate_word(entities, aggregation_strategy)
+ raise Todo
+ end
+
+ def aggregate_words(entities, aggregation_strategy)
+ raise Todo
+ end
+
+ def group_sub_entities(entities)
+ # Get the first entity in the entity group
+ entity = entities[0][:entity].split("-", 2)[-1]
+ scores = entities.map { |entity| entity[:score] }
+ tokens = entities.map { |entity| entity[:word] }
+
+ entity_group = {
+ entity_group: entity,
+ score: scores.sum / scores.count.to_f,
+ word: @tokenizer.convert_tokens_to_string(tokens),
+ start: entities[0][:start],
+ end: entities[-1][:end]
+ }
+ entity_group
+ end
+
+ def get_tag(entity_name)
+ if entity_name.start_with?("B-")
+ bi = "B"
+ tag = entity_name[2..]
+ elsif entity_name.start_with?("I-")
+ bi = "I"
+ tag = entity_name[2..]
+ else
+ # It's not in B-, I- format
+ # Default to I- for continuation.
+ bi = "I"
+ tag = entity_name
+ end
+ [bi, tag]
+ end
+
+ def group_entities(entities)
+ entity_groups = []
+ entity_group_disagg = []
+
+ entities.each do |entity|
+ if entity_group_disagg.empty?
+ entity_group_disagg << entity
+ next
+ end
+
+ # If the current entity is similar and adjacent to the previous entity,
+ # append it to the disaggregated entity group
+ # The split is meant to account for the "B" and "I" prefixes
+ # Shouldn't merge if both entities are B-type
+ bi, tag = get_tag(entity[:entity])
+ _last_bi, last_tag = get_tag(entity_group_disagg[-1][:entity])
+
+ if tag == last_tag && bi != "B"
+ # Modify subword type to be previous_type
+ entity_group_disagg << entity
+ else
+ # If the current entity is different from the previous entity
+ # aggregate the disaggregated entity group
+ entity_groups << group_sub_entities(entity_group_disagg)
+ entity_group_disagg = [entity]
+ end
+ end
+ if entity_group_disagg.any?
+ # it's the last entity, add it to the entity groups
+ entity_groups << group_sub_entities(entity_group_disagg)
+ end
+
+ entity_groups
+ end
  end
  end
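These grouping helpers back the `aggregation_strategy` option whose string handling is fixed above (`downcase` instead of `upcase`), which is the "Fixed error with `aggregation_strategy` option" changelog entry. A hedged sketch of how they are exercised, assuming the `"ner"` task name from the README's named-entity recognition example and call-time option passing as in the other pipelines:

```ruby
ner = Transformers.pipeline("ner")

# With an aggregation strategy, adjacent B-/I- tagged subword entities are
# merged by group_entities/group_sub_entities into one span per entity group.
ner.("Ruby is a programming language created by Matz", aggregation_strategy: "simple")
```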
@@ -2,36 +2,19 @@ module Transformers
  class SentenceTransformer
  def initialize(model_id)
  @model_id = model_id
- @tokenizer = Transformers::AutoTokenizer.from_pretrained(model_id)
- @model = Transformers::AutoModel.from_pretrained(model_id)
+ @model = Transformers.pipeline("embedding", model_id)
  end

  def encode(sentences)
- singular = sentences.is_a?(String)
- sentences = [sentences] if singular
-
- input = @tokenizer.(sentences, padding: true, truncation: true, return_tensors: "pt")
- output = Torch.no_grad { @model.(**input) }[0]
-
  # TODO check modules.json
  if [
  "sentence-transformers/all-MiniLM-L6-v2",
  "sentence-transformers/multi-qa-MiniLM-L6-cos-v1"
  ].include?(@model_id)
- output = mean_pooling(output, input[:attention_mask])
- output = Torch::NN::Functional.normalize(output, p: 2, dim: 1).to_a
+ @model.(sentences)
  else
- output = output[0.., 0].to_a
+ @model.(sentences, pooling: "cls", normalize: false)
  end
-
- singular ? output[0] : output
- end
-
- private
-
- def mean_pooling(output, attention_mask)
- input_mask_expanded = attention_mask.unsqueeze(-1).expand(output.size).float
- Torch.sum(output * input_mask_expanded, 1) / Torch.clamp(input_mask_expanded.sum(1), min: 1e-9)
  end
  end
  end
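`SentenceTransformer#encode` is kept as a thin wrapper over the new embedding pipeline, so 0.1.0-style callers keep working; a minimal sketch matching the old README usage:

```ruby
# Same interface as 0.1.0; internally this now delegates to
# Transformers.pipeline("embedding", ...) with model-appropriate pooling.
model = Transformers::SentenceTransformer.new("sentence-transformers/all-MiniLM-L6-v2")
embeddings = model.encode(["This is an example sentence", "Each sentence is converted"])
```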
@@ -91,6 +91,10 @@ module Transformers
  get_vocab
  end

+ def backend_tokenizer
+ @tokenizer
+ end
+
  def convert_tokens_to_ids(tokens)
  if tokens.nil?
  return nil
@@ -130,6 +134,10 @@ module Transformers
  tokens
  end

+ def convert_tokens_to_string(tokens)
+ backend_tokenizer.decoder.decode(tokens)
+ end
+
  private

  def set_truncation_and_padding(
@@ -1,3 +1,3 @@
  module Transformers
- VERSION = "0.1.0"
+ VERSION = "0.1.2"
  end
data/lib/transformers.rb CHANGED
@@ -75,6 +75,7 @@ require_relative "transformers/models/vit/modeling_vit"
  # pipelines
  require_relative "transformers/pipelines/base"
  require_relative "transformers/pipelines/feature_extraction"
+ require_relative "transformers/pipelines/embedding"
  require_relative "transformers/pipelines/image_classification"
  require_relative "transformers/pipelines/image_feature_extraction"
  require_relative "transformers/pipelines/pt_utils"
@@ -97,4 +98,10 @@ module Transformers
  "not implemented yet"
  end
  end
+
+ class << self
+ # experimental
+ attr_accessor :fast_init
+ end
+ self.fast_init = false
  end
metadata CHANGED
@@ -1,14 +1,14 @@
  --- !ruby/object:Gem::Specification
  name: transformers-rb
  version: !ruby/object:Gem::Version
- version: 0.1.0
+ version: 0.1.2
  platform: ruby
  authors:
  - Andrew Kane
  autorequire:
  bindir: bin
  cert_chain: []
- date: 2024-08-19 00:00:00.000000000 Z
+ date: 2024-09-10 00:00:00.000000000 Z
  dependencies:
  - !ruby/object:Gem::Dependency
  name: numo-narray
@@ -44,14 +44,14 @@ dependencies:
  requirements:
  - - ">="
  - !ruby/object:Gem::Version
- version: '0.5'
+ version: 0.5.2
  type: :runtime
  prerelease: false
  version_requirements: !ruby/object:Gem::Requirement
  requirements:
  - - ">="
  - !ruby/object:Gem::Version
- version: '0.5'
+ version: 0.5.2
  - !ruby/object:Gem::Dependency
  name: torch-rb
  requirement: !ruby/object:Gem::Requirement
@@ -113,6 +113,7 @@ files:
  - lib/transformers/models/vit/modeling_vit.rb
  - lib/transformers/pipelines/_init.rb
  - lib/transformers/pipelines/base.rb
+ - lib/transformers/pipelines/embedding.rb
  - lib/transformers/pipelines/feature_extraction.rb
  - lib/transformers/pipelines/image_classification.rb
  - lib/transformers/pipelines/image_feature_extraction.rb
@@ -154,7 +155,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
  - !ruby/object:Gem::Version
  version: '0'
  requirements: []
- rubygems_version: 3.5.11
+ rubygems_version: 3.5.16
  signing_key:
  specification_version: 4
  summary: State-of-the-art transformers for Ruby