transformers-rb 0.1.0 → 0.1.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
  ---
  SHA256:
- metadata.gz: c8f34c5454e2a1ac18bbb9a4b290a43e994cd3984fa2b4125ff4af969b9d17ed
- data.tar.gz: 57c876fd1a4e62089fdc7bcbfcb9c155050166458a679991036894a6721ac168
+ metadata.gz: 2b7441c0cc400d85d3d0775e6c53a631c83887e481588148a7cb3242e50c0f08
+ data.tar.gz: df639eb5c3231d344ec8516ec861cb40239ac5ff0310faa5b1b102a79af97972
  SHA512:
- metadata.gz: 7458b1ba0303e0741abf16a63efc350b6cad5e5dff48c46dcba6f62858f562bbf83478eb918995c45f5882159cf5c22d696cc4b0360f813312c2263da7c28205
- data.tar.gz: a4d98b210a22d23bc55f452a93dd0e9df20c81cbc0d537d29eaa1069b157eec1df6cdc76fc7f398947e03ae841166d609bff0d250506eb0dd0745ef0fdf86efd
+ metadata.gz: 566a03e8a65255d3eb7e87bb3e68c0ac4d228ba4bc55e62eaad98f6677b8a5f29ee9201efaefc3edb309d279c1a986a664dfffbf3cc292114394be1c10fa949a
+ data.tar.gz: fdae84f4fa8aac90f1a627d738fb701ef91d384414173f4042bac56bfe5c78c6ceb98aaea8ef0f2158d324ec60d18cbd91f84bd30f42b49d63193392e0105e96
data/CHANGELOG.md CHANGED
@@ -1,3 +1,10 @@
+ ## 0.1.1 (2024-08-29)
+
+ - Added `embedding` pipeline
+ - Added experimental `fast_init` option
+ - Improved performance of loading models
+ - Fixed error with `aggregation_strategy` option
+
  ## 0.1.0 (2024-08-19)

  - First release
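
Taken together, the changelog entries above translate into usage roughly like the sketch below. This is only an illustrative sketch: the model id is an example and the gem is assumed to be loaded (e.g. via Bundler).

```ruby
# Experimental option added in 0.1.1: skip Torch weight initialization when
# loading pretrained weights (the code later in this diff notes it is not
# thread-safe). Default is false.
Transformers.fast_init = true

# New embedding pipeline added in 0.1.1; the model can now be passed positionally.
embed = Transformers.pipeline("embedding", "sentence-transformers/all-MiniLM-L6-v2")
embed.("We are very happy to show you the 🤗 Transformers library.")
```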
data/README.md CHANGED
@@ -2,6 +2,8 @@

  :slightly_smiling_face: State-of-the-art [transformers](https://github.com/huggingface/transformers) for Ruby

+ For fast inference, check out [Informers](https://github.com/ankane/informers) :fire:
+
  [![Build Status](https://github.com/ankane/transformers-ruby/actions/workflows/build.yml/badge.svg)](https://github.com/ankane/transformers-ruby/actions)

  ## Installation
@@ -21,6 +23,17 @@ gem "transformers-rb"

  ## Models

+ Embedding
+
+ - [sentence-transformers/all-MiniLM-L6-v2](#sentence-transformersall-MiniLM-L6-v2)
+ - [sentence-transformers/multi-qa-MiniLM-L6-cos-v1](#sentence-transformersmulti-qa-MiniLM-L6-cos-v1)
+ - [mixedbread-ai/mxbai-embed-large-v1](#mixedbread-aimxbai-embed-large-v1)
+ - [thenlper/gte-small](#thenlpergte-small)
+ - [intfloat/e5-base-v2](#intfloate5-base-v2)
+ - [BAAI/bge-base-en-v1.5](#baaibge-base-en-v15)
+ - [Snowflake/snowflake-arctic-embed-m-v1.5](#snowflakesnowflake-arctic-embed-m-v15)
+ - [opensearch-project/opensearch-neural-sparse-encoding-v1](#opensearch-projectopensearch-neural-sparse-encoding-v1)
+
  ### sentence-transformers/all-MiniLM-L6-v2

  [Docs](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2)
@@ -28,8 +41,8 @@ gem "transformers-rb"
  ```ruby
  sentences = ["This is an example sentence", "Each sentence is converted"]

- model = Transformers::SentenceTransformer.new("sentence-transformers/all-MiniLM-L6-v2")
- embeddings = model.encode(sentences)
+ model = Transformers.pipeline("embedding", "sentence-transformers/all-MiniLM-L6-v2")
+ embeddings = model.(sentences)
  ```

  ### sentence-transformers/multi-qa-MiniLM-L6-cos-v1
@@ -40,10 +53,10 @@ embeddings = model.encode(sentences)
  query = "How many people live in London?"
  docs = ["Around 9 Million people live in London", "London is known for its financial district"]

- model = Transformers::SentenceTransformer.new("sentence-transformers/multi-qa-MiniLM-L6-cos-v1")
- query_emb = model.encode(query)
- doc_emb = model.encode(docs)
- scores = Torch.mm(Torch.tensor([query_emb]), Torch.tensor(doc_emb).transpose(0, 1))[0].cpu.to_a
+ model = Transformers.pipeline("embedding", "sentence-transformers/multi-qa-MiniLM-L6-cos-v1")
+ query_embedding = model.(query)
+ doc_embeddings = model.(docs)
+ scores = doc_embeddings.map { |e| e.zip(query_embedding).sum { |d, q| d * q } }
  doc_score_pairs = docs.zip(scores).sort_by { |d, s| -s }
  ```

@@ -52,18 +65,78 @@ doc_score_pairs = docs.zip(scores).sort_by { |d, s| -s }
  [Docs](https://huggingface.co/mixedbread-ai/mxbai-embed-large-v1)

  ```ruby
- def transform_query(query)
- "Represent this sentence for searching relevant passages: #{query}"
- end
+ query_prefix = "Represent this sentence for searching relevant passages: "

- docs = [
- transform_query("puppy"),
+ input = [
  "The dog is barking",
- "The cat is purring"
+ "The cat is purring",
+ query_prefix + "puppy"
+ ]
+
+ model = Transformers.pipeline("embedding", "mixedbread-ai/mxbai-embed-large-v1")
+ embeddings = model.(input)
+ ```
+
+ ### thenlper/gte-small
+
+ [Docs](https://huggingface.co/thenlper/gte-small)
+
+ ```ruby
+ sentences = ["That is a happy person", "That is a very happy person"]
+
+ model = Transformers.pipeline("embedding", "thenlper/gte-small")
+ embeddings = model.(sentences)
+ ```
+
+ ### intfloat/e5-base-v2
+
+ [Docs](https://huggingface.co/intfloat/e5-base-v2)
+
+ ```ruby
+ doc_prefix = "passage: "
+ query_prefix = "query: "
+
+ input = [
+ doc_prefix + "Ruby is a programming language created by Matz",
+ query_prefix + "Ruby creator"
  ]

- model = Transformers::SentenceTransformer.new("mixedbread-ai/mxbai-embed-large-v1")
- embeddings = model.encode(docs)
+ model = Transformers.pipeline("embedding", "intfloat/e5-base-v2")
+ embeddings = model.(input)
+ ```
+
+ ### BAAI/bge-base-en-v1.5
+
+ [Docs](https://huggingface.co/BAAI/bge-base-en-v1.5)
+
+ ```ruby
+ query_prefix = "Represent this sentence for searching relevant passages: "
+
+ input = [
+ "The dog is barking",
+ "The cat is purring",
+ query_prefix + "puppy"
+ ]
+
+ model = Transformers.pipeline("embedding", "BAAI/bge-base-en-v1.5")
+ embeddings = model.(input)
+ ```
+
+ ### Snowflake/snowflake-arctic-embed-m-v1.5
+
+ [Docs](https://huggingface.co/Snowflake/snowflake-arctic-embed-m-v1.5)
+
+ ```ruby
+ query_prefix = "Represent this sentence for searching relevant passages: "
+
+ input = [
+ "The dog is barking",
+ "The cat is purring",
+ query_prefix + "puppy"
+ ]
+
+ model = Transformers.pipeline("embedding", "Snowflake/snowflake-arctic-embed-m-v1.5")
+ embeddings = model.(input, pooling: "cls")
  ```

  ### opensearch-project/opensearch-neural-sparse-encoding-v1
@@ -89,6 +162,13 @@ embeddings = values.to_a

  ## Pipelines

+ Embedding
+
+ ```ruby
+ embed = Transformers.pipeline("embedding")
+ embed.("We are very happy to show you the 🤗 Transformers library.")
+ ```
+
  Named-entity recognition

  ```ruby
@@ -207,10 +207,17 @@ module Transformers

  config = new(**config_dict)

+ to_remove = []
  kwargs.each do |key, value|
  if config.respond_to?("#{key}=")
  config.public_send("#{key}=", value)
  end
+ if key != :torch_dtype
+ to_remove << key
+ end
+ end
+ to_remove.each do |key|
+ kwargs.delete(key)
  end

  Transformers.logger.info("Model config #{config}")
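
The `to_remove` bookkeeping above collects the consumed option keys first and deletes them afterwards, keeping `:torch_dtype` for the later model-loading step and avoiding mutating the hash while it is being iterated. A minimal standalone sketch of the same pattern, with hypothetical option names:

```ruby
kwargs = { num_labels: 2, torch_dtype: "float32" }

to_remove = []
kwargs.each do |key, _value|
  # (the real code applies the value to the config via a setter here)
  to_remove << key if key != :torch_dtype
end
to_remove.each { |key| kwargs.delete(key) }

kwargs # => { torch_dtype: "float32" }
```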
@@ -22,6 +22,7 @@ module Transformers
  def to_h
  @data
  end
+ alias_method :to_hash, :to_h

  def [](item)
  @data[item]
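
Aliasing `to_hash` matters because Ruby's double-splat performs an implicit `to_hash` conversion on non-Hash objects; that is what lets the pipeline hunks later in this diff call `@model.(**model_inputs)` instead of `@model.(**model_inputs.to_h)`. A minimal sketch with a hypothetical wrapper class:

```ruby
class ModelInputs
  def initialize(data)
    @data = data
  end

  def to_h
    @data
  end
  alias_method :to_hash, :to_h
end

def forward(**inputs)
  inputs.keys
end

inputs = ModelInputs.new(input_ids: [[101, 102]], attention_mask: [[1, 1]])
forward(**inputs) # => [:input_ids, :attention_mask], via the implicit to_hash call
```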
@@ -14,6 +14,47 @@
  # limitations under the License.

  module Transformers
+ module ModelingUtils
+ TORCH_INIT_FUNCTIONS = {
+ "uniform!" => Torch::NN::Init.method(:uniform!),
+ "normal!" => Torch::NN::Init.method(:normal!),
+ # "trunc_normal!" => Torch::NN::Init.method(:trunc_normal!),
+ "constant!" => Torch::NN::Init.method(:constant!),
+ "xavier_uniform!" => Torch::NN::Init.method(:xavier_uniform!),
+ "xavier_normal!" => Torch::NN::Init.method(:xavier_normal!),
+ "kaiming_uniform!" => Torch::NN::Init.method(:kaiming_uniform!),
+ "kaiming_normal!" => Torch::NN::Init.method(:kaiming_normal!),
+ # "uniform" => Torch::NN::Init.method(:uniform),
+ # "normal" => Torch::NN::Init.method(:normal),
+ # "xavier_uniform" => Torch::NN::Init.method(:xavier_uniform),
+ # "xavier_normal" => Torch::NN::Init.method(:xavier_normal),
+ # "kaiming_uniform" => Torch::NN::Init.method(:kaiming_uniform),
+ # "kaiming_normal" => Torch::NN::Init.method(:kaiming_normal)
+ }
+
+ # private
+ # note: this improves loading time significantly, but is not thread-safe!
+ def self.no_init_weights
+ return yield unless Transformers.fast_init
+
+ _skip_init = lambda do |*args, **kwargs|
+ # pass
+ end
+ # Save the original initialization functions
+ TORCH_INIT_FUNCTIONS.each do |name, init_func|
+ Torch::NN::Init.singleton_class.undef_method(name)
+ Torch::NN::Init.define_singleton_method(name, &_skip_init)
+ end
+ yield
+ ensure
+ # Restore the original initialization functions
+ TORCH_INIT_FUNCTIONS.each do |name, init_func|
+ Torch::NN::Init.singleton_class.undef_method(name)
+ Torch::NN::Init.define_singleton_method(name, init_func)
+ end
+ end
+ end
+
  module ModuleUtilsMixin
  def get_extended_attention_mask(
  attention_mask,
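
The replace-and-restore dance in `no_init_weights` can be isolated into a small standalone sketch (the module and method names below are made up for illustration): a singleton method is undefined, replaced with a no-op for the duration of a block, and restored in an `ensure` so the original comes back even if the block raises. This is also why the real code is flagged as not thread-safe: the patch is global while the block runs.

```ruby
module Greeter
  def self.hello(name)
    "hello #{name}"
  end
end

def with_silenced(mod, name)
  original = mod.method(name)
  mod.singleton_class.undef_method(name)
  mod.define_singleton_method(name) { |*args, **kwargs| nil } # no-op stand-in
  yield
ensure
  # restore the saved Method object even if the block raised
  mod.singleton_class.undef_method(name)
  mod.define_singleton_method(name, original)
end

with_silenced(Greeter, :hello) { Greeter.hello("world") } # => nil
Greeter.hello("world")                                    # => "hello world"
```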
@@ -138,7 +179,11 @@ module Transformers
  end

  def _initialize_weights(mod)
+ if mod.instance_variable_defined?(:@is_hf_initialized)
+ return
+ end
  _init_weights(mod)
+ mod.instance_variable_set(:@is_hf_initialized, true)
  end

  def tie_weights
@@ -166,7 +211,9 @@ module Transformers
  prune_heads(@config.pruned_heads)
  end

- if true
+ # TODO implement no_init_weights context manager
+ _init_weights = false
+ if _init_weights
  # Initialize weights
  apply(method(:_initialize_weights))

@@ -512,8 +559,8 @@ module Transformers

  config.name_or_path = pretrained_model_name_or_path

- model_kwargs = {}
- model = new(config, *model_args, **model_kwargs)
+ # Instantiate model.
+ model = ModelingUtils.no_init_weights { new(config, *model_args, **model_kwargs) }

  # make sure we use the model's config since the __init__ call might have copied it
  config = model.config
@@ -683,6 +730,10 @@ module Transformers
  end
  end

+ if _fast_init
+ # TODO
+ end
+
  # Make sure we are able to load base models as well as derived models (with heads)
  start_prefix = ""
  model_to_load = model
@@ -756,28 +807,29 @@ module Transformers
  raise Todo
  end

+ model_class_name = model.class.name.split("::").last
  if unexpected_keys.length > 0
  archs = model.config.architectures.nil? ? [] : model.config.architectures
- warner = archs.include?(model.class.name) ? Transformers.logger.method(:warn) : Transformers.logger.method(:info)
+ warner = archs.include?(model_class_name) ? Transformers.logger.method(:warn) : Transformers.logger.method(:info)
  warner.(
  "Some weights of the model checkpoint at #{pretrained_model_name_or_path} were not used when" +
- " initializing #{model.class.name}: #{unexpected_keys}\n- This IS expected if you are" +
- " initializing #{model.class.name} from the checkpoint of a model trained on another task or" +
+ " initializing #{model_class_name}: #{unexpected_keys}\n- This IS expected if you are" +
+ " initializing #{model_class_name} from the checkpoint of a model trained on another task or" +
  " with another architecture (e.g. initializing a BertForSequenceClassification model from a" +
  " BertForPreTraining model).\n- This IS NOT expected if you are initializing" +
- " #{model.class.name} from the checkpoint of a model that you expect to be exactly identical" +
+ " #{model_class_name} from the checkpoint of a model that you expect to be exactly identical" +
  " (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
  )
  else
- Transformers.logger.info("All model checkpoint weights were used when initializing #{model.class.name}.\n")
+ Transformers.logger.info("All model checkpoint weights were used when initializing #{model_class_name}.\n")
  end
  if missing_keys.length > 0
- Transformers.logger.info("Some weights of #{model.class.name} were not initialized from the model checkpoint at #{pretrained_model_name_or_path} and are newly initialized: #{missing_keys}\nYou should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.")
+ Transformers.logger.info("Some weights of #{model_class_name} were not initialized from the model checkpoint at #{pretrained_model_name_or_path} and are newly initialized: #{missing_keys}\nYou should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.")
  elsif mismatched_keys.length == 0
  Transformers.logger.info(
- "All the weights of #{model.class.name} were initialized from the model checkpoint at" +
+ "All the weights of #{model_class_name} were initialized from the model checkpoint at" +
  " #{pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the checkpoint" +
- " was trained on, you can already use #{model.class.name} for predictions without further" +
+ " was trained on, you can already use #{model_class_name} for predictions without further" +
  " training."
  )
  end
@@ -97,7 +97,7 @@ module Transformers
  @position_embedding_type = position_embedding_type || config.position_embedding_type || "absolute"
  if @position_embedding_type == "relative_key" || @position_embedding_type == "relative_key_query"
  @max_position_embeddings = config.max_position_embeddings
- @distance_embedding = Torch:NN::Embedding.new(2 * config.max_position_embeddings - 1, @attention_head_size)
+ @distance_embedding = Torch::NN::Embedding.new(2 * config.max_position_embeddings - 1, @attention_head_size)
  end

  @is_decoder = config.is_decoder
@@ -639,8 +639,8 @@ module Transformers
  extended_attention_mask = get_extended_attention_mask(attention_mask, input_shape)
  end

- # # If a 2D or 3D attention mask is provided for the cross-attention
- # # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
+ # If a 2D or 3D attention mask is provided for the cross-attention
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
  if @config.is_decoder && !encoder_hidden_states.nil?
  encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size
  encoder_hidden_shape = [encoder_batch_size, encoder_sequence_length]
@@ -78,7 +78,17 @@ module Transformers
  }
  },
  "type" => "image"
- }
+ },
+ "embedding" => {
+ "impl" => EmbeddingPipeline,
+ "pt" => [AutoModel],
+ "default" => {
+ "model" => {
+ "pt" => ["sentence-transformers/all-MiniLM-L6-v2", "8b3219a"]
+ }
+ },
+ "type" => "text"
+ },
  }

  PIPELINE_REGISTRY = PipelineRegistry.new(supported_tasks: SUPPORTED_TASKS, task_aliases: TASK_ALIASES)
@@ -86,6 +96,7 @@ module Transformers
  class << self
  def pipeline(
  task,
+ model_arg = nil,
  model: nil,
  config: nil,
  tokenizer: nil,
@@ -103,6 +114,13 @@ module Transformers
  pipeline_class: nil,
  **kwargs
  )
+ if !model_arg.nil?
+ if !model.nil?
+ raise ArgumentError, "Cannot pass multiple models"
+ end
+ model = model_arg
+ end
+
  model_kwargs ||= {}
  # Make sure we only pass use_auth_token once as a kwarg (it used to be possible to pass it in model_kwargs,
  # this is to keep BC).
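
With this change the model can be given positionally (as the updated README examples do) or via the `model:` keyword, but not both. Illustrative calls, using an example model id:

```ruby
# Equivalent ways to select the model after this change
embed = Transformers.pipeline("embedding", "thenlper/gte-small")
embed = Transformers.pipeline("embedding", model: "thenlper/gte-small")

# Passing both positional and keyword forms raises
# ArgumentError, "Cannot pass multiple models"
```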
@@ -0,0 +1,46 @@
+ module Transformers
+ class EmbeddingPipeline < Pipeline
+ def _sanitize_parameters(**kwargs)
+ [{}, {}, kwargs]
+ end
+
+ def preprocess(inputs)
+ @tokenizer.(inputs, return_tensors: @framework)
+ end
+
+ def _forward(model_inputs)
+ {
+ last_hidden_state: @model.(**model_inputs)[0],
+ attention_mask: model_inputs[:attention_mask]
+ }
+ end
+
+ def postprocess(model_outputs, pooling: "mean", normalize: true)
+ output = model_outputs[:last_hidden_state]
+
+ case pooling
+ when "none"
+ # do nothing
+ when "mean"
+ output = mean_pooling(output, model_outputs[:attention_mask])
+ when "cls"
+ output = output[0.., 0]
+ else
+ raise Error, "Pooling method '#{pooling}' not supported."
+ end
+
+ if normalize
+ output = Torch::NN::Functional.normalize(output, p: 2, dim: 1)
+ end
+
+ output[0].to_a
+ end
+
+ private
+
+ def mean_pooling(output, attention_mask)
+ input_mask_expanded = attention_mask.unsqueeze(-1).expand(output.size).float
+ Torch.sum(output * input_mask_expanded, 1) / Torch.clamp(input_mask_expanded.sum(1), min: 1e-9)
+ end
+ end
+ end
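
Based on `postprocess` above, the new pipeline accepts `pooling` ("mean", "cls", or "none") and `normalize` options at call time; the Snowflake example in the README already uses `pooling: "cls"`. A short sketch (the model id is only an example):

```ruby
embed = Transformers.pipeline("embedding", "sentence-transformers/all-MiniLM-L6-v2")

embed.("a sentence")                                     # mean pooling + L2 normalization (defaults)
embed.("a sentence", pooling: "cls")                     # take the [CLS] token embedding
embed.("a sentence", pooling: "mean", normalize: false)  # raw mean-pooled vector
```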
@@ -45,7 +45,7 @@ module Transformers
  end

  def _forward(model_inputs)
- model_outputs = @model.(**model_inputs.to_h)
+ model_outputs = @model.(**model_inputs)
  model_outputs
  end

@@ -29,7 +29,7 @@ module Transformers
  end

  def _forward(model_inputs)
- model_outputs = @model.(**model_inputs.to_h)
+ model_outputs = @model.(**model_inputs)
  model_outputs
  end

@@ -34,7 +34,7 @@ module Transformers
  end

  if function_to_apply.is_a?(String)
- function_to_apply = ClassificationFunction.new(function_to_apply.upcase).to_s
+ function_to_apply = ClassificationFunction.new(function_to_apply.downcase).to_s
  end

  if !function_to_apply.nil?
@@ -62,7 +62,7 @@ module Transformers
  end

  def _forward(model_inputs)
- @model.(**model_inputs.to_h)
+ @model.(**model_inputs)
  end

  def postprocess(model_outputs, function_to_apply: nil, top_k: 1, _legacy: true)
@@ -62,7 +62,7 @@ module Transformers

  if !aggregation_strategy.nil?
  if aggregation_strategy.is_a?(String)
- aggregation_strategy = AggregationStrategy.new(aggregation_strategy.upcase).to_s
+ aggregation_strategy = AggregationStrategy.new(aggregation_strategy.downcase).to_s
  end
  if (
  [AggregationStrategy::FIRST, AggregationStrategy::MAX, AggregationStrategy::AVERAGE].include?(aggregation_strategy) &&
@@ -278,5 +278,80 @@ module Transformers
  end
  group_entities(entities)
  end
+
+ def aggregate_word(entities, aggregation_strategy)
+ raise Todo
+ end
+
+ def aggregate_words(entities, aggregation_strategy)
+ raise Todo
+ end
+
+ def group_sub_entities(entities)
+ # Get the first entity in the entity group
+ entity = entities[0][:entity].split("-", 2)[-1]
+ scores = entities.map { |entity| entity[:score] }
+ tokens = entities.map { |entity| entity[:word] }
+
+ entity_group = {
+ entity_group: entity,
+ score: scores.sum / scores.count.to_f,
+ word: @tokenizer.convert_tokens_to_string(tokens),
+ start: entities[0][:start],
+ end: entities[-1][:end]
+ }
+ entity_group
+ end
+
+ def get_tag(entity_name)
+ if entity_name.start_with?("B-")
+ bi = "B"
+ tag = entity_name[2..]
+ elsif entity_name.start_with?("I-")
+ bi = "I"
+ tag = entity_name[2..]
+ else
+ # It's not in B-, I- format
+ # Default to I- for continuation.
+ bi = "I"
+ tag = entity_name
+ end
+ [bi, tag]
+ end
+
+ def group_entities(entities)
+ entity_groups = []
+ entity_group_disagg = []
+
+ entities.each do |entity|
+ if entity_group_disagg.empty?
+ entity_group_disagg << entity
+ next
+ end
+
+ # If the current entity is similar and adjacent to the previous entity,
+ # append it to the disaggregated entity group
+ # The split is meant to account for the "B" and "I" prefixes
+ # Shouldn't merge if both entities are B-type
+ bi, tag = get_tag(entity[:entity])
+ _last_bi, last_tag = get_tag(entity_group_disagg[-1][:entity])
+
+ if tag == last_tag && bi != "B"
+ # Modify subword type to be previous_type
+ entity_group_disagg << entity
+ else
+ # If the current entity is different from the previous entity
+ # aggregate the disaggregated entity group
+ entity_groups << group_sub_entities(entity_group_disagg)
+ entity_group_disagg = [entity]
+ end
+ end
+ if entity_group_disagg.any?
+ # it's the last entity, add it to the entity groups
+ entity_groups << group_sub_entities(entity_group_disagg)
+ end
+
+ entity_groups
+ end
  end
  end
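
These helpers back the `aggregation_strategy` fix mentioned in the changelog: B-/I- tagged tokens with the same tag are merged by `group_entities` and summarized by `group_sub_entities`. A hedged sketch of how this might be exercised; whether the option is accepted at call time (as assumed here) or only when constructing the pipeline is not shown in this diff, and "ner" as a task alias is also assumed.

```ruby
ner = Transformers.pipeline("ner")
# Assumed call-time option; grouped results would carry :entity_group, :score,
# :word, :start, and :end keys, per group_sub_entities above.
ner.("Ruby is a programming language created by Matz", aggregation_strategy: "simple")
```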
@@ -2,36 +2,19 @@ module Transformers
  class SentenceTransformer
  def initialize(model_id)
  @model_id = model_id
- @tokenizer = Transformers::AutoTokenizer.from_pretrained(model_id)
- @model = Transformers::AutoModel.from_pretrained(model_id)
+ @model = Transformers.pipeline("embedding", model_id)
  end

  def encode(sentences)
- singular = sentences.is_a?(String)
- sentences = [sentences] if singular
-
- input = @tokenizer.(sentences, padding: true, truncation: true, return_tensors: "pt")
- output = Torch.no_grad { @model.(**input) }[0]
-
  # TODO check modules.json
  if [
  "sentence-transformers/all-MiniLM-L6-v2",
  "sentence-transformers/multi-qa-MiniLM-L6-cos-v1"
  ].include?(@model_id)
- output = mean_pooling(output, input[:attention_mask])
- output = Torch::NN::Functional.normalize(output, p: 2, dim: 1).to_a
+ @model.(sentences)
  else
- output = output[0.., 0].to_a
+ @model.(sentences, pooling: "cls", normalize: false)
  end
-
- singular ? output[0] : output
- end
-
- private
-
- def mean_pooling(output, attention_mask)
- input_mask_expanded = attention_mask.unsqueeze(-1).expand(output.size).float
- Torch.sum(output * input_mask_expanded, 1) / Torch.clamp(input_mask_expanded.sum(1), min: 1e-9)
  end
  end
  end
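
After this refactor, `SentenceTransformer#encode` is a thin wrapper over the new embedding pipeline, so for the models it special-cases the two calls below should be interchangeable (illustrative only):

```ruby
# old-style API, kept for compatibility
st = Transformers::SentenceTransformer.new("sentence-transformers/all-MiniLM-L6-v2")
embedding_a = st.encode("That is a happy person")

# equivalent call through the new pipeline
embed = Transformers.pipeline("embedding", "sentence-transformers/all-MiniLM-L6-v2")
embedding_b = embed.("That is a happy person")
```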
@@ -91,6 +91,10 @@ module Transformers
  get_vocab
  end

+ def backend_tokenizer
+ @tokenizer
+ end
+
  def convert_tokens_to_ids(tokens)
  if tokens.nil?
  return nil
@@ -130,6 +134,10 @@ module Transformers
  tokens
  end

+ def convert_tokens_to_string(tokens)
+ backend_tokenizer.decoder.decode(tokens)
+ end
+
  private

  def set_truncation_and_padding(
@@ -1,3 +1,3 @@
  module Transformers
- VERSION = "0.1.0"
+ VERSION = "0.1.1"
  end
data/lib/transformers.rb CHANGED
@@ -75,6 +75,7 @@ require_relative "transformers/models/vit/modeling_vit"
  # pipelines
  require_relative "transformers/pipelines/base"
  require_relative "transformers/pipelines/feature_extraction"
+ require_relative "transformers/pipelines/embedding"
  require_relative "transformers/pipelines/image_classification"
  require_relative "transformers/pipelines/image_feature_extraction"
  require_relative "transformers/pipelines/pt_utils"
@@ -97,4 +98,10 @@ module Transformers
  "not implemented yet"
  end
  end
+
+ class << self
+ # experimental
+ attr_accessor :fast_init
+ end
+ self.fast_init = false
  end
metadata CHANGED
@@ -1,14 +1,14 @@
  --- !ruby/object:Gem::Specification
  name: transformers-rb
  version: !ruby/object:Gem::Version
- version: 0.1.0
+ version: 0.1.1
  platform: ruby
  authors:
  - Andrew Kane
  autorequire:
  bindir: bin
  cert_chain: []
- date: 2024-08-19 00:00:00.000000000 Z
+ date: 2024-08-30 00:00:00.000000000 Z
  dependencies:
  - !ruby/object:Gem::Dependency
  name: numo-narray
@@ -44,14 +44,14 @@ dependencies:
  requirements:
  - - ">="
  - !ruby/object:Gem::Version
- version: '0.5'
+ version: 0.5.2
  type: :runtime
  prerelease: false
  version_requirements: !ruby/object:Gem::Requirement
  requirements:
  - - ">="
  - !ruby/object:Gem::Version
- version: '0.5'
+ version: 0.5.2
  - !ruby/object:Gem::Dependency
  name: torch-rb
  requirement: !ruby/object:Gem::Requirement
@@ -113,6 +113,7 @@ files:
  - lib/transformers/models/vit/modeling_vit.rb
  - lib/transformers/pipelines/_init.rb
  - lib/transformers/pipelines/base.rb
+ - lib/transformers/pipelines/embedding.rb
  - lib/transformers/pipelines/feature_extraction.rb
  - lib/transformers/pipelines/image_classification.rb
  - lib/transformers/pipelines/image_feature_extraction.rb