transformers-rb 0.1.0 → 0.1.1

checksums.yaml CHANGED
@@ -1,7 +1,7 @@
  ---
  SHA256:
- metadata.gz: c8f34c5454e2a1ac18bbb9a4b290a43e994cd3984fa2b4125ff4af969b9d17ed
- data.tar.gz: 57c876fd1a4e62089fdc7bcbfcb9c155050166458a679991036894a6721ac168
+ metadata.gz: 2b7441c0cc400d85d3d0775e6c53a631c83887e481588148a7cb3242e50c0f08
+ data.tar.gz: df639eb5c3231d344ec8516ec861cb40239ac5ff0310faa5b1b102a79af97972
  SHA512:
- metadata.gz: 7458b1ba0303e0741abf16a63efc350b6cad5e5dff48c46dcba6f62858f562bbf83478eb918995c45f5882159cf5c22d696cc4b0360f813312c2263da7c28205
- data.tar.gz: a4d98b210a22d23bc55f452a93dd0e9df20c81cbc0d537d29eaa1069b157eec1df6cdc76fc7f398947e03ae841166d609bff0d250506eb0dd0745ef0fdf86efd
+ metadata.gz: 566a03e8a65255d3eb7e87bb3e68c0ac4d228ba4bc55e62eaad98f6677b8a5f29ee9201efaefc3edb309d279c1a986a664dfffbf3cc292114394be1c10fa949a
+ data.tar.gz: fdae84f4fa8aac90f1a627d738fb701ef91d384414173f4042bac56bfe5c78c6ceb98aaea8ef0f2158d324ec60d18cbd91f84bd30f42b49d63193392e0105e96
data/CHANGELOG.md CHANGED
@@ -1,3 +1,10 @@
+ ## 0.1.1 (2024-08-29)
+
+ - Added `embedding` pipeline
+ - Added experimental `fast_init` option
+ - Improved performance of loading models
+ - Fixed error with `aggregation_strategy` option
+
  ## 0.1.0 (2024-08-19)
 
  - First release
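
A minimal sketch of how the two new features combine, based only on the code further down in this diff (the `embedding` pipeline task and the `Transformers.fast_init` accessor); illustrative, not official usage:

```ruby
# Sketch: the new 0.1.1 features together (assumes the gem is already loaded).
Transformers.fast_init = true # experimental: skips random weight init while loading

embed = Transformers.pipeline("embedding", "sentence-transformers/all-MiniLM-L6-v2")
embeddings = embed.(["This is an example sentence", "Each sentence is converted"])
# => array of embedding vectors (one per sentence)
```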
data/README.md CHANGED
@@ -2,6 +2,8 @@
 
  :slightly_smiling_face: State-of-the-art [transformers](https://github.com/huggingface/transformers) for Ruby
 
+ For fast inference, check out [Informers](https://github.com/ankane/informers) :fire:
+
  [![Build Status](https://github.com/ankane/transformers-ruby/actions/workflows/build.yml/badge.svg)](https://github.com/ankane/transformers-ruby/actions)
 
  ## Installation
@@ -21,6 +23,17 @@ gem "transformers-rb"
 
  ## Models
 
+ Embedding
+
+ - [sentence-transformers/all-MiniLM-L6-v2](#sentence-transformersall-MiniLM-L6-v2)
+ - [sentence-transformers/multi-qa-MiniLM-L6-cos-v1](#sentence-transformersmulti-qa-MiniLM-L6-cos-v1)
+ - [mixedbread-ai/mxbai-embed-large-v1](#mixedbread-aimxbai-embed-large-v1)
+ - [thenlper/gte-small](#thenlpergte-small)
+ - [intfloat/e5-base-v2](#intfloate5-base-v2)
+ - [BAAI/bge-base-en-v1.5](#baaibge-base-en-v15)
+ - [Snowflake/snowflake-arctic-embed-m-v1.5](#snowflakesnowflake-arctic-embed-m-v15)
+ - [opensearch-project/opensearch-neural-sparse-encoding-v1](#opensearch-projectopensearch-neural-sparse-encoding-v1)
+
  ### sentence-transformers/all-MiniLM-L6-v2
 
  [Docs](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2)
@@ -28,8 +41,8 @@ gem "transformers-rb"
  ```ruby
  sentences = ["This is an example sentence", "Each sentence is converted"]
 
- model = Transformers::SentenceTransformer.new("sentence-transformers/all-MiniLM-L6-v2")
- embeddings = model.encode(sentences)
+ model = Transformers.pipeline("embedding", "sentence-transformers/all-MiniLM-L6-v2")
+ embeddings = model.(sentences)
  ```
 
  ### sentence-transformers/multi-qa-MiniLM-L6-cos-v1
@@ -40,10 +53,10 @@ embeddings = model.encode(sentences)
  query = "How many people live in London?"
  docs = ["Around 9 Million people live in London", "London is known for its financial district"]
 
- model = Transformers::SentenceTransformer.new("sentence-transformers/multi-qa-MiniLM-L6-cos-v1")
- query_emb = model.encode(query)
- doc_emb = model.encode(docs)
- scores = Torch.mm(Torch.tensor([query_emb]), Torch.tensor(doc_emb).transpose(0, 1))[0].cpu.to_a
+ model = Transformers.pipeline("embedding", "sentence-transformers/multi-qa-MiniLM-L6-cos-v1")
+ query_embedding = model.(query)
+ doc_embeddings = model.(docs)
+ scores = doc_embeddings.map { |e| e.zip(query_embedding).sum { |d, q| d * q } }
  doc_score_pairs = docs.zip(scores).sort_by { |d, s| -s }
  ```
 
@@ -52,18 +65,78 @@ doc_score_pairs = docs.zip(scores).sort_by { |d, s| -s }
  [Docs](https://huggingface.co/mixedbread-ai/mxbai-embed-large-v1)
 
  ```ruby
- def transform_query(query)
-   "Represent this sentence for searching relevant passages: #{query}"
- end
+ query_prefix = "Represent this sentence for searching relevant passages: "
 
- docs = [
-   transform_query("puppy"),
+ input = [
    "The dog is barking",
-   "The cat is purring"
+   "The cat is purring",
+   query_prefix + "puppy"
+ ]
+
+ model = Transformers.pipeline("embedding", "mixedbread-ai/mxbai-embed-large-v1")
+ embeddings = model.(input)
+ ```
+
+ ### thenlper/gte-small
+
+ [Docs](https://huggingface.co/thenlper/gte-small)
+
+ ```ruby
+ sentences = ["That is a happy person", "That is a very happy person"]
+
+ model = Transformers.pipeline("embedding", "thenlper/gte-small")
+ embeddings = model.(sentences)
+ ```
+
+ ### intfloat/e5-base-v2
+
+ [Docs](https://huggingface.co/intfloat/e5-base-v2)
+
+ ```ruby
+ doc_prefix = "passage: "
+ query_prefix = "query: "
+
+ input = [
+   doc_prefix + "Ruby is a programming language created by Matz",
+   query_prefix + "Ruby creator"
  ]
 
- model = Transformers::SentenceTransformer.new("mixedbread-ai/mxbai-embed-large-v1")
- embeddings = model.encode(docs)
+ model = Transformers.pipeline("embedding", "intfloat/e5-base-v2")
+ embeddings = model.(input)
+ ```
+
+ ### BAAI/bge-base-en-v1.5
+
+ [Docs](https://huggingface.co/BAAI/bge-base-en-v1.5)
+
+ ```ruby
+ query_prefix = "Represent this sentence for searching relevant passages: "
+
+ input = [
+   "The dog is barking",
+   "The cat is purring",
+   query_prefix + "puppy"
+ ]
+
+ model = Transformers.pipeline("embedding", "BAAI/bge-base-en-v1.5")
+ embeddings = model.(input)
+ ```
+
+ ### Snowflake/snowflake-arctic-embed-m-v1.5
+
+ [Docs](https://huggingface.co/Snowflake/snowflake-arctic-embed-m-v1.5)
+
+ ```ruby
+ query_prefix = "Represent this sentence for searching relevant passages: "
+
+ input = [
+   "The dog is barking",
+   "The cat is purring",
+   query_prefix + "puppy"
+ ]
+
+ model = Transformers.pipeline("embedding", "Snowflake/snowflake-arctic-embed-m-v1.5")
+ embeddings = model.(input, pooling: "cls")
  ```
 
  ### opensearch-project/opensearch-neural-sparse-encoding-v1
@@ -89,6 +162,13 @@ embeddings = values.to_a
 
  ## Pipelines
 
+ Embedding
+
+ ```ruby
+ embed = Transformers.pipeline("embedding")
+ embed.("We are very happy to show you the 🤗 Transformers library.")
+ ```
+
  Named-entity recognition
 
  ```ruby
@@ -207,10 +207,17 @@ module Transformers
 
  config = new(**config_dict)
 
+ to_remove = []
  kwargs.each do |key, value|
    if config.respond_to?("#{key}=")
      config.public_send("#{key}=", value)
    end
+   if key != :torch_dtype
+     to_remove << key
+   end
+ end
+ to_remove.each do |key|
+   kwargs.delete(key)
  end
 
  Transformers.logger.info("Model config #{config}")
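
The change above collects the keys to drop while iterating and deletes them afterwards, so the hash is never modified mid-iteration, and `:torch_dtype` is the one key left behind in `kwargs`. A standalone sketch of the same collect-then-delete pattern, with hypothetical names:

```ruby
# Hypothetical illustration of the collect-then-delete pattern used above.
config = { hidden_size: 768 }
kwargs = { hidden_size: 1024, torch_dtype: :float16, unused: true }

to_remove = []
kwargs.each do |key, value|
  config[key] = value if config.key?(key) # apply recognized overrides
  to_remove << key if key != :torch_dtype # keep :torch_dtype for the caller
end
to_remove.each { |key| kwargs.delete(key) }

p config # => {:hidden_size=>1024}
p kwargs # => {:torch_dtype=>:float16}
```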
@@ -22,6 +22,7 @@ module Transformers
  def to_h
    @data
  end
+ alias_method :to_hash, :to_h
 
  def [](item)
    @data[item]
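
Adding `to_hash` is what lets the batch object be double-splatted directly; the `**model_inputs.to_h` → `**model_inputs` simplifications in the pipeline diffs below rely on it. A tiny standalone sketch of the Ruby behavior (the `Bag` class is illustrative, not part of the gem):

```ruby
# Ruby calls #to_hash for implicit conversion when an object is double-splatted.
class Bag
  def initialize(data)
    @data = data
  end

  def to_h
    @data
  end
  alias_method :to_hash, :to_h
end

def show(**kwargs)
  p kwargs
end

bag = Bag.new({input_ids: [101, 102], attention_mask: [1, 1]})
show(**bag) # => {:input_ids=>[101, 102], :attention_mask=>[1, 1]}
```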
@@ -14,6 +14,47 @@
  # limitations under the License.
 
  module Transformers
+   module ModelingUtils
+     TORCH_INIT_FUNCTIONS = {
+       "uniform!" => Torch::NN::Init.method(:uniform!),
+       "normal!" => Torch::NN::Init.method(:normal!),
+       # "trunc_normal!" => Torch::NN::Init.method(:trunc_normal!),
+       "constant!" => Torch::NN::Init.method(:constant!),
+       "xavier_uniform!" => Torch::NN::Init.method(:xavier_uniform!),
+       "xavier_normal!" => Torch::NN::Init.method(:xavier_normal!),
+       "kaiming_uniform!" => Torch::NN::Init.method(:kaiming_uniform!),
+       "kaiming_normal!" => Torch::NN::Init.method(:kaiming_normal!),
+       # "uniform" => Torch::NN::Init.method(:uniform),
+       # "normal" => Torch::NN::Init.method(:normal),
+       # "xavier_uniform" => Torch::NN::Init.method(:xavier_uniform),
+       # "xavier_normal" => Torch::NN::Init.method(:xavier_normal),
+       # "kaiming_uniform" => Torch::NN::Init.method(:kaiming_uniform),
+       # "kaiming_normal" => Torch::NN::Init.method(:kaiming_normal)
+     }
+
+     # private
+     # note: this improves loading time significantly, but is not thread-safe!
+     def self.no_init_weights
+       return yield unless Transformers.fast_init
+
+       _skip_init = lambda do |*args, **kwargs|
+         # pass
+       end
+       # Save the original initialization functions
+       TORCH_INIT_FUNCTIONS.each do |name, init_func|
+         Torch::NN::Init.singleton_class.undef_method(name)
+         Torch::NN::Init.define_singleton_method(name, &_skip_init)
+       end
+       yield
+     ensure
+       # Restore the original initialization functions
+       TORCH_INIT_FUNCTIONS.each do |name, init_func|
+         Torch::NN::Init.singleton_class.undef_method(name)
+         Torch::NN::Init.define_singleton_method(name, init_func)
+       end
+     end
+   end
+
    module ModuleUtilsMixin
      def get_extended_attention_mask(
        attention_mask,
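
The `no_init_weights` helper above temporarily swaps the `Torch::NN::Init` methods for no-ops, runs the block, and restores the originals in an `ensure`. A plain-Ruby sketch of the same swap-and-restore pattern, with made-up names and no torch-rb dependency:

```ruby
# Illustrative only: the Slow module and init! method are stand-ins.
module Slow
  def self.init!(n)
    sleep(0.01) # pretend this is an expensive initializer
    n
  end
end

def with_noop_init
  original = Slow.method(:init!)            # keep a handle to the real method
  Slow.singleton_class.undef_method(:init!) # remove the existing definition
  Slow.define_singleton_method(:init!) { |*args, **kwargs| nil } # no-op stand-in
  yield
ensure
  # Restore the real method even if the block raises
  Slow.singleton_class.undef_method(:init!)
  Slow.define_singleton_method(:init!, original)
end

with_noop_init { 100.times { Slow.init!(1) } } # fast: initializer skipped
Slow.init!(1)                                  # back to the real implementation
```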
@@ -138,7 +179,11 @@ module Transformers
  end
 
  def _initialize_weights(mod)
+   if mod.instance_variable_defined?(:@is_hf_initialized)
+     return
+   end
    _init_weights(mod)
+   mod.instance_variable_set(:@is_hf_initialized, true)
  end
 
  def tie_weights
@@ -166,7 +211,9 @@ module Transformers
    prune_heads(@config.pruned_heads)
  end
 
- if true
+ # TODO implement no_init_weights context manager
+ _init_weights = false
+ if _init_weights
    # Initialize weights
    apply(method(:_initialize_weights))
 
@@ -512,8 +559,8 @@ module Transformers
 
  config.name_or_path = pretrained_model_name_or_path
 
- model_kwargs = {}
- model = new(config, *model_args, **model_kwargs)
+ # Instantiate model.
+ model = ModelingUtils.no_init_weights { new(config, *model_args, **model_kwargs) }
 
  # make sure we use the model's config since the __init__ call might have copied it
  config = model.config
@@ -683,6 +730,10 @@ module Transformers
    end
  end
 
+ if _fast_init
+   # TODO
+ end
+
  # Make sure we are able to load base models as well as derived models (with heads)
  start_prefix = ""
  model_to_load = model
@@ -756,28 +807,29 @@ module Transformers
    raise Todo
  end
 
+ model_class_name = model.class.name.split("::").last
  if unexpected_keys.length > 0
    archs = model.config.architectures.nil? ? [] : model.config.architectures
-   warner = archs.include?(model.class.name) ? Transformers.logger.method(:warn) : Transformers.logger.method(:info)
+   warner = archs.include?(model_class_name) ? Transformers.logger.method(:warn) : Transformers.logger.method(:info)
    warner.(
      "Some weights of the model checkpoint at #{pretrained_model_name_or_path} were not used when" +
-       " initializing #{model.class.name}: #{unexpected_keys}\n- This IS expected if you are" +
-       " initializing #{model.class.name} from the checkpoint of a model trained on another task or" +
+       " initializing #{model_class_name}: #{unexpected_keys}\n- This IS expected if you are" +
+       " initializing #{model_class_name} from the checkpoint of a model trained on another task or" +
        " with another architecture (e.g. initializing a BertForSequenceClassification model from a" +
        " BertForPreTraining model).\n- This IS NOT expected if you are initializing" +
-       " #{model.class.name} from the checkpoint of a model that you expect to be exactly identical" +
+       " #{model_class_name} from the checkpoint of a model that you expect to be exactly identical" +
        " (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
    )
  else
-   Transformers.logger.info("All model checkpoint weights were used when initializing #{model.class.name}.\n")
+   Transformers.logger.info("All model checkpoint weights were used when initializing #{model_class_name}.\n")
  end
  if missing_keys.length > 0
-   Transformers.logger.info("Some weights of #{model.class.name} were not initialized from the model checkpoint at #{pretrained_model_name_or_path} and are newly initialized: #{missing_keys}\nYou should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.")
+   Transformers.logger.info("Some weights of #{model_class_name} were not initialized from the model checkpoint at #{pretrained_model_name_or_path} and are newly initialized: #{missing_keys}\nYou should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.")
  elsif mismatched_keys.length == 0
    Transformers.logger.info(
-     "All the weights of #{model.class.name} were initialized from the model checkpoint at" +
+     "All the weights of #{model_class_name} were initialized from the model checkpoint at" +
      " #{pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the checkpoint" +
-     " was trained on, you can already use #{model.class.name} for predictions without further" +
+     " was trained on, you can already use #{model_class_name} for predictions without further" +
      " training."
    )
  end
@@ -97,7 +97,7 @@ module Transformers
  @position_embedding_type = position_embedding_type || config.position_embedding_type || "absolute"
  if @position_embedding_type == "relative_key" || @position_embedding_type == "relative_key_query"
    @max_position_embeddings = config.max_position_embeddings
-   @distance_embedding = Torch:NN::Embedding.new(2 * config.max_position_embeddings - 1, @attention_head_size)
+   @distance_embedding = Torch::NN::Embedding.new(2 * config.max_position_embeddings - 1, @attention_head_size)
  end
 
  @is_decoder = config.is_decoder
@@ -639,8 +639,8 @@ module Transformers
    extended_attention_mask = get_extended_attention_mask(attention_mask, input_shape)
  end
 
- # # If a 2D or 3D attention mask is provided for the cross-attention
- # # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
+ # If a 2D or 3D attention mask is provided for the cross-attention
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
  if @config.is_decoder && !encoder_hidden_states.nil?
    encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size
    encoder_hidden_shape = [encoder_batch_size, encoder_sequence_length]
@@ -78,7 +78,17 @@ module Transformers
      }
    },
    "type" => "image"
- }
+ },
+ "embedding" => {
+   "impl" => EmbeddingPipeline,
+   "pt" => [AutoModel],
+   "default" => {
+     "model" => {
+       "pt" => ["sentence-transformers/all-MiniLM-L6-v2", "8b3219a"]
+     }
+   },
+   "type" => "text"
+ },
  }
 
  PIPELINE_REGISTRY = PipelineRegistry.new(supported_tasks: SUPPORTED_TASKS, task_aliases: TASK_ALIASES)
@@ -86,6 +96,7 @@ module Transformers
  class << self
    def pipeline(
      task,
+     model_arg = nil,
      model: nil,
      config: nil,
      tokenizer: nil,
@@ -103,6 +114,13 @@ module Transformers
      pipeline_class: nil,
      **kwargs
    )
+     if !model_arg.nil?
+       if !model.nil?
+         raise ArgumentError, "Cannot pass multiple models"
+       end
+       model = model_arg
+     end
+
      model_kwargs ||= {}
      # Make sure we only pass use_auth_token once as a kwarg (it used to be possible to pass it in model_kwargs,
      # this is to keep BC).
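
With the new positional argument, the model can be given either positionally (as in the README examples above) or via the `model:` keyword, but not both. A brief sketch of the calling conventions this enables:

```ruby
# Sketch: equivalent ways to pick a model for a task.
embed = Transformers.pipeline("embedding", "sentence-transformers/all-MiniLM-L6-v2")
embed = Transformers.pipeline("embedding", model: "sentence-transformers/all-MiniLM-L6-v2")

# Passing both raises ArgumentError ("Cannot pass multiple models"):
# Transformers.pipeline("embedding", "model-a", model: "model-b")
```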
@@ -0,0 +1,46 @@
+ module Transformers
+   class EmbeddingPipeline < Pipeline
+     def _sanitize_parameters(**kwargs)
+       [{}, {}, kwargs]
+     end
+
+     def preprocess(inputs)
+       @tokenizer.(inputs, return_tensors: @framework)
+     end
+
+     def _forward(model_inputs)
+       {
+         last_hidden_state: @model.(**model_inputs)[0],
+         attention_mask: model_inputs[:attention_mask]
+       }
+     end
+
+     def postprocess(model_outputs, pooling: "mean", normalize: true)
+       output = model_outputs[:last_hidden_state]
+
+       case pooling
+       when "none"
+         # do nothing
+       when "mean"
+         output = mean_pooling(output, model_outputs[:attention_mask])
+       when "cls"
+         output = output[0.., 0]
+       else
+         raise Error, "Pooling method '#{pooling}' not supported."
+       end
+
+       if normalize
+         output = Torch::NN::Functional.normalize(output, p: 2, dim: 1)
+       end
+
+       output[0].to_a
+     end
+
+     private
+
+     def mean_pooling(output, attention_mask)
+       input_mask_expanded = attention_mask.unsqueeze(-1).expand(output.size).float
+       Torch.sum(output * input_mask_expanded, 1) / Torch.clamp(input_mask_expanded.sum(1), min: 1e-9)
+     end
+   end
+ end
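
For readers unfamiliar with masked mean pooling, here is a plain-Ruby sketch of what `mean_pooling` computes for one sequence: token embeddings are averaged, with padding positions (attention mask 0) excluded. It mirrors the Torch version above without the tensor ops or the 1e-9 clamp:

```ruby
# Plain-Ruby sketch of masked mean pooling (illustrative, not part of the gem).
def mean_pool(token_embeddings, attention_mask)
  dim = token_embeddings.first.length
  sums = Array.new(dim, 0.0)
  count = 0
  token_embeddings.each_with_index do |embedding, i|
    next if attention_mask[i].zero? # skip padding tokens
    embedding.each_with_index { |value, j| sums[j] += value }
    count += 1
  end
  sums.map { |s| s / [count, 1].max } # average over the real tokens
end

embeddings = [[1.0, 2.0], [3.0, 4.0], [0.0, 0.0]] # last row is padding
mask = [1, 1, 0]
p mean_pool(embeddings, mask) # => [2.0, 3.0]
```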
@@ -45,7 +45,7 @@ module Transformers
  end
 
  def _forward(model_inputs)
-   model_outputs = @model.(**model_inputs.to_h)
+   model_outputs = @model.(**model_inputs)
    model_outputs
  end
 
@@ -29,7 +29,7 @@ module Transformers
  end
 
  def _forward(model_inputs)
-   model_outputs = @model.(**model_inputs.to_h)
+   model_outputs = @model.(**model_inputs)
    model_outputs
  end
 
@@ -34,7 +34,7 @@ module Transformers
  end
 
  if function_to_apply.is_a?(String)
-   function_to_apply = ClassificationFunction.new(function_to_apply.upcase).to_s
+   function_to_apply = ClassificationFunction.new(function_to_apply.downcase).to_s
  end
 
  if !function_to_apply.nil?
@@ -62,7 +62,7 @@ module Transformers
  end
 
  def _forward(model_inputs)
-   @model.(**model_inputs.to_h)
+   @model.(**model_inputs)
  end
 
  def postprocess(model_outputs, function_to_apply: nil, top_k: 1, _legacy: true)
@@ -62,7 +62,7 @@ module Transformers
 
  if !aggregation_strategy.nil?
    if aggregation_strategy.is_a?(String)
-     aggregation_strategy = AggregationStrategy.new(aggregation_strategy.upcase).to_s
+     aggregation_strategy = AggregationStrategy.new(aggregation_strategy.downcase).to_s
    end
    if (
      [AggregationStrategy::FIRST, AggregationStrategy::MAX, AggregationStrategy::AVERAGE].include?(aggregation_strategy) &&
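
The option is now downcased before validation, which is the `aggregation_strategy` fix noted in the changelog. A hedged usage sketch, assuming the port mirrors the Python strategy names ("none", "simple", "first", "average", "max"):

```ruby
# Sketch: grouping NER results into entity spans (strategy name assumed).
ner = Transformers.pipeline("ner")
ner.("Ruby is a programming language created by Matz", aggregation_strategy: "simple")
# => entries shaped like {entity_group: ..., score: ..., word: ..., start: ..., end: ...}
# (see group_sub_entities below; the output shown here is illustrative)
```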
@@ -278,5 +278,80 @@ module Transformers
    end
    group_entities(entities)
  end
+
+ def aggregate_word(entities, aggregation_strategy)
+   raise Todo
+ end
+
+ def aggregate_words(entities, aggregation_strategy)
+   raise Todo
+ end
+
+ def group_sub_entities(entities)
+   # Get the first entity in the entity group
+   entity = entities[0][:entity].split("-", 2)[-1]
+   scores = entities.map { |entity| entity[:score] }
+   tokens = entities.map { |entity| entity[:word] }
+
+   entity_group = {
+     entity_group: entity,
+     score: scores.sum / scores.count.to_f,
+     word: @tokenizer.convert_tokens_to_string(tokens),
+     start: entities[0][:start],
+     end: entities[-1][:end]
+   }
+   entity_group
+ end
+
+ def get_tag(entity_name)
+   if entity_name.start_with?("B-")
+     bi = "B"
+     tag = entity_name[2..]
+   elsif entity_name.start_with?("I-")
+     bi = "I"
+     tag = entity_name[2..]
+   else
+     # It's not in B-, I- format
+     # Default to I- for continuation.
+     bi = "I"
+     tag = entity_name
+   end
+   [bi, tag]
+ end
+
+ def group_entities(entities)
+   entity_groups = []
+   entity_group_disagg = []
+
+   entities.each do |entity|
+     if entity_group_disagg.empty?
+       entity_group_disagg << entity
+       next
+     end
+
+     # If the current entity is similar and adjacent to the previous entity,
+     # append it to the disaggregated entity group
+     # The split is meant to account for the "B" and "I" prefixes
+     # Shouldn't merge if both entities are B-type
+     bi, tag = get_tag(entity[:entity])
+     _last_bi, last_tag = get_tag(entity_group_disagg[-1][:entity])
+
+     if tag == last_tag && bi != "B"
+       # Modify subword type to be previous_type
+       entity_group_disagg << entity
+     else
+       # If the current entity is different from the previous entity
+       # aggregate the disaggregated entity group
+       entity_groups << group_sub_entities(entity_group_disagg)
+       entity_group_disagg = [entity]
+     end
+   end
+   if entity_group_disagg.any?
+     # it's the last entity, add it to the entity groups
+     entity_groups << group_sub_entities(entity_group_disagg)
+   end
+
+   entity_groups
+ end
  end
  end
@@ -2,36 +2,19 @@ module Transformers
  class SentenceTransformer
    def initialize(model_id)
      @model_id = model_id
-     @tokenizer = Transformers::AutoTokenizer.from_pretrained(model_id)
-     @model = Transformers::AutoModel.from_pretrained(model_id)
+     @model = Transformers.pipeline("embedding", model_id)
    end
 
    def encode(sentences)
-     singular = sentences.is_a?(String)
-     sentences = [sentences] if singular
-
-     input = @tokenizer.(sentences, padding: true, truncation: true, return_tensors: "pt")
-     output = Torch.no_grad { @model.(**input) }[0]
-
      # TODO check modules.json
      if [
        "sentence-transformers/all-MiniLM-L6-v2",
        "sentence-transformers/multi-qa-MiniLM-L6-cos-v1"
      ].include?(@model_id)
-       output = mean_pooling(output, input[:attention_mask])
-       output = Torch::NN::Functional.normalize(output, p: 2, dim: 1).to_a
+       @model.(sentences)
      else
-       output = output[0.., 0].to_a
+       @model.(sentences, pooling: "cls", normalize: false)
      end
-
-     singular ? output[0] : output
-   end
-
-   private
-
-   def mean_pooling(output, attention_mask)
-     input_mask_expanded = attention_mask.unsqueeze(-1).expand(output.size).float
-     Torch.sum(output * input_mask_expanded, 1) / Torch.clamp(input_mask_expanded.sum(1), min: 1e-9)
    end
  end
  end
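
After this refactor, `SentenceTransformer#encode` is a thin wrapper around the embedding pipeline, so the legacy API and the new pipeline should produce the same vectors. A small sketch, based only on the code above:

```ruby
# Sketch: legacy wrapper vs. new pipeline for the same model.
sentences = ["This is an example sentence", "Each sentence is converted"]

legacy = Transformers::SentenceTransformer.new("sentence-transformers/all-MiniLM-L6-v2")
a = legacy.encode(sentences)

embed = Transformers.pipeline("embedding", "sentence-transformers/all-MiniLM-L6-v2")
b = embed.(sentences)
# a and b should match (mean pooling + L2 normalization for this model)
```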
@@ -91,6 +91,10 @@ module Transformers
    get_vocab
  end
 
+ def backend_tokenizer
+   @tokenizer
+ end
+
  def convert_tokens_to_ids(tokens)
    if tokens.nil?
      return nil
@@ -130,6 +134,10 @@ module Transformers
    tokens
  end
 
+ def convert_tokens_to_string(tokens)
+   backend_tokenizer.decoder.decode(tokens)
+ end
+
  private
 
  def set_truncation_and_padding(
@@ -1,3 +1,3 @@
  module Transformers
-   VERSION = "0.1.0"
+   VERSION = "0.1.1"
  end
data/lib/transformers.rb CHANGED
@@ -75,6 +75,7 @@ require_relative "transformers/models/vit/modeling_vit"
  # pipelines
  require_relative "transformers/pipelines/base"
  require_relative "transformers/pipelines/feature_extraction"
+ require_relative "transformers/pipelines/embedding"
  require_relative "transformers/pipelines/image_classification"
  require_relative "transformers/pipelines/image_feature_extraction"
  require_relative "transformers/pipelines/pt_utils"
@@ -97,4 +98,10 @@ module Transformers
      "not implemented yet"
    end
  end
+
+ class << self
+   # experimental
+   attr_accessor :fast_init
+ end
+ self.fast_init = false
  end
metadata CHANGED
@@ -1,14 +1,14 @@
  --- !ruby/object:Gem::Specification
  name: transformers-rb
  version: !ruby/object:Gem::Version
-   version: 0.1.0
+   version: 0.1.1
  platform: ruby
  authors:
  - Andrew Kane
  autorequire:
  bindir: bin
  cert_chain: []
- date: 2024-08-19 00:00:00.000000000 Z
+ date: 2024-08-30 00:00:00.000000000 Z
  dependencies:
  - !ruby/object:Gem::Dependency
    name: numo-narray
@@ -44,14 +44,14 @@ dependencies:
    requirements:
    - - ">="
      - !ruby/object:Gem::Version
-       version: '0.5'
+       version: 0.5.2
    type: :runtime
    prerelease: false
    version_requirements: !ruby/object:Gem::Requirement
      requirements:
      - - ">="
        - !ruby/object:Gem::Version
-         version: '0.5'
+         version: 0.5.2
  - !ruby/object:Gem::Dependency
    name: torch-rb
    requirement: !ruby/object:Gem::Requirement
@@ -113,6 +113,7 @@ files:
  - lib/transformers/models/vit/modeling_vit.rb
  - lib/transformers/pipelines/_init.rb
  - lib/transformers/pipelines/base.rb
+ - lib/transformers/pipelines/embedding.rb
  - lib/transformers/pipelines/feature_extraction.rb
  - lib/transformers/pipelines/image_classification.rb
  - lib/transformers/pipelines/image_feature_extraction.rb